Merge pull request #72936 from 372046933:range_dtype

PiperOrigin-RevId: 662422313
diff --git a/.bazelrc b/.bazelrc
index 11783a8..edb4659 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -573,6 +573,9 @@
 build:rbe_win_clang --linkopt=/FORCE:MULTIPLE
 build:rbe_win_clang --host_linkopt=/FORCE:MULTIPLE
 
+# TODO(belitskiy): Rename `rbe_win_clang` to this, once done switching presubmits.
+build:rbe_windows_x86_cpu --config=rbe_win_clang
+
 # END TF REMOTE BUILD EXECUTION OPTIONS
 
 # TFLite build configs for generic embedded Linux
@@ -787,17 +790,19 @@
 
 # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over
 # the whole TF code base. These are usually run continuously or upon presubmit.
-# CPU PYCPP:
+# LINUX CPU PYCPP:
 test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
 test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
 test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium
 test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
-# CUDA PYCPP:
+
+# LINUX CUDA PYCPP:
 test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11
 test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11
 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium
 test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
-# ARM64 PYCPP
+
+# LINUX ARM64 PYCPP
 # In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on
 # Linux x86 so that we can use RBE. Since tests still need to run on the single
 # host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit.
@@ -830,6 +835,13 @@
 # CROSS-COMPILE MACOS X86 PYCPP
 build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test
 build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test
+# WINDOWS X86-64 CPU PYCPP
+test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test
+test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test
+test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600"
+test:windows_x86_cpu_pycpp_test_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off --build_tests_only
+test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/...
+
 # END TF TEST SUITE OPTIONS
 
 # START CROSS-COMPILE CONFIGS
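
As a usage sketch (an assumption for illustration, not something this change adds): like the other PYCPP suites above, the new Windows suite is selected by config name, which is how ci/official/pycpp.sh invokes it further below.

  # Hypothetical direct invocation from a TensorFlow checkout; flags and targets
  # come from the windows_x86_cpu_pycpp_test* configs defined above.
  bazel test --config=windows_x86_cpu_pycpp_test
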
diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml
index 60af559..2013b8c 100644
--- a/.github/workflows/osv-scanner-scheduled.yml
+++ b/.github/workflows/osv-scanner-scheduled.yml
@@ -28,7 +28,7 @@
 jobs:
   scan-scheduled:
     if: github.repository == 'tensorflow/tensorflow'
-    uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.8.1"
+    uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.8.2"
     with:
       scan-args: |-
         --lockfile=requirements.txt:./requirements_lock_3_9.txt
diff --git a/.github/workflows/pylint-presubmit.yml b/.github/workflows/pylint-presubmit.yml
index f457db5..8b1a034 100644
--- a/.github/workflows/pylint-presubmit.yml
+++ b/.github/workflows/pylint-presubmit.yml
@@ -38,7 +38,7 @@
       run: |
         echo Changed files: ${{ steps.get_file_changes.outputs.files }}
     - name: Set up Python 3.9
-      uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0
+      uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
       with:
         python-version: "3.9"
     - name: Install Python dependencies
diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml
index e72eab8..ceb213e 100644
--- a/.github/workflows/scorecards-analysis.yml
+++ b/.github/workflows/scorecards-analysis.yml
@@ -46,7 +46,7 @@
           persist-credentials: false
 
       - name: "Run analysis"
-        uses: ossf/scorecard-action@dc50aa9510b46c811795eb24b2f1ba02a914e534 # v2.3.3
+        uses: ossf/scorecard-action@62b2cac7ed8198b15735ed49ab1e5cf35480ba46 # v2.4.0
         with:
           results_file: results.sarif
           results_format: sarif
@@ -55,7 +55,7 @@
       # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
       # format to the repository Actions tab.
       - name: "Upload artifact"
-        uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3
+        uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4
         with:
           name: SARIF file
           path: results.sarif
@@ -64,6 +64,6 @@
       # Upload the results to GitHub's code scanning dashboard (optional).
       # Commenting out will disable upload of results to your repo's Code Scanning dashboard
       - name: "Upload to code-scanning"
-        uses: github/codeql-action/upload-sarif@b611370bb5703a7efb587f9d136a52ea24c5c38c # v3.25.11
+        uses: github/codeql-action/upload-sarif@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15
         with:
           sarif_file: results.sarif
diff --git a/.github/workflows/sigbuild-docker-branch.yml b/.github/workflows/sigbuild-docker-branch.yml
index 2c81873..c72cc98 100644
--- a/.github/workflows/sigbuild-docker-branch.yml
+++ b/.github/workflows/sigbuild-docker-branch.yml
@@ -43,16 +43,16 @@
         uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
       -
         name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0
+        uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1
       -
         name: Login to DockerHub
-        uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0
+        uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
       -
         name: Login to GCR
-        uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0
+        uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
         with:
           registry: gcr.io
           username: _json_key
@@ -67,7 +67,7 @@
       -
         name: Build and push
         id: docker_build
-        uses: docker/build-push-action@15560696de535e4014efeff63c48f16952e52dd1 # v6.2.0
+        uses: docker/build-push-action@5176d81f87c23d6fc96624dfdbcd9f3830bbe445 # v6.5.0
         with:
           push: true
           context: ./tensorflow/tools/tf_sig_build_dockerfiles
diff --git a/.github/workflows/sigbuild-docker-presubmit.yml b/.github/workflows/sigbuild-docker-presubmit.yml
index 7de12e1..e21ddb0 100644
--- a/.github/workflows/sigbuild-docker-presubmit.yml
+++ b/.github/workflows/sigbuild-docker-presubmit.yml
@@ -47,16 +47,25 @@
         uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
       -
         name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0
+        uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1
       -
         name: Login to GCR
         if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging')
-        uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0
+        uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
         with:
           registry: gcr.io
           username: _json_key
           password: ${{ secrets.GCP_CREDS }}
       -
+        name: Login to AR
+        # Once this is verified, change the label's name. For now, we will piggyback on gcr.io actions.
+        if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging')
+        uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
+        with:
+          registry: us-central1-docker.pkg.dev
+          username: _json_key
+          password: ${{ secrets.GCP_CREDS }}
+      -
         name: Grab the date to do cache busting (assumes same day OK to keep)
         run: |
           echo "DATE=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
@@ -64,7 +73,7 @@
       -
         name: Build containers, and push to GCR only if the 'build and push to gcr.io for staging' label is applied
         id: docker_build
-        uses: docker/build-push-action@15560696de535e4014efeff63c48f16952e52dd1 # v6.2.0
+        uses: docker/build-push-action@5176d81f87c23d6fc96624dfdbcd9f3830bbe445 # v6.5.0
         with:
           push: ${{ contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') }}
           context: ./tensorflow/tools/tf_sig_build_dockerfiles
@@ -74,6 +83,7 @@
             CACHEBUSTER=${{ steps.date.outputs.DATE }}
           tags: |
             gcr.io/tensorflow-sigs/build:${{ github.event.number }}-${{ matrix.python-version }}
+            us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:${{ github.event.number }}-${{ matrix.python-version }}
           cache-from: |
             type=registry,ref=tensorflow/build:latest-${{ matrix.python-version }}
             type=registry,ref=gcr.io/tensorflow-sigs/build:${{ github.event.number }}-${{ matrix.python-version }}
diff --git a/.github/workflows/sigbuild-docker.yml b/.github/workflows/sigbuild-docker.yml
index 062338e..78e7fd7 100644
--- a/.github/workflows/sigbuild-docker.yml
+++ b/.github/workflows/sigbuild-docker.yml
@@ -46,21 +46,29 @@
         uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
       -
         name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0
+        uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1
       -
         name: Login to DockerHub
-        uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0
+        uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
       -
         name: Login to GCR
-        uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0
+        uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
         with:
           registry: gcr.io
           username: _json_key
           password: ${{ secrets.GCP_CREDS }}
       -
+        name: Login to AR
+        # Once this is verified, remove gcr.io actions.
+        uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
+        with:
+          registry: us-central1-docker.pkg.dev
+          username: _json_key
+          password: ${{ secrets.GCP_CREDS }}
+      -
         name: Grab the upcoming TF version to tag this container
         run: |
           # [[:digit:]] searches for numbers and \+ joins them together
@@ -74,7 +82,7 @@
       -
         name: Build and push
         id: docker_build
-        uses: docker/build-push-action@15560696de535e4014efeff63c48f16952e52dd1 # v6.2.0
+        uses: docker/build-push-action@5176d81f87c23d6fc96624dfdbcd9f3830bbe445 # v6.5.0
         with:
           push: true
           context: ./tensorflow/tools/tf_sig_build_dockerfiles
@@ -87,6 +95,8 @@
             tensorflow/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }}
             gcr.io/tensorflow-sigs/build:latest-${{ matrix.python-version }}
             gcr.io/tensorflow-sigs/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }}
+            us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:latest-${{ matrix.python-version }}
+            us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }}
           cache-from: type=registry,ref=tensorflow/build:latest-${{ matrix.python-version }}
           cache-to: type=inline
       -
diff --git a/RELEASE.md b/RELEASE.md
index 51b3c60..eb2d939 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -35,6 +35,11 @@
       should run synchronously, as opposed to be parallelizable when
       `options.experimental_optimization.map_parallelization=True`. This saves
       memory compared to setting `num_parallel_calls=1`.
+    * Add optional `use_unbounded_threadpool` argument to `map`, to specify that
+      the `map` should use an unbounded threadpool instead of the default pool
+      that is based on the number of cores on the machine. This can improve
+      throughput for map functions which perform IO or otherwise release the
+      CPU.
 * `tf.lite`
     * `Dequantize` op supports `TensorType_INT4`.
         * This change includes per-channel dequantization.
diff --git a/ci/devinfra/docker_windows/Dockerfile b/ci/devinfra/docker_windows/Dockerfile
deleted file mode 100644
index 540f82a..0000000
--- a/ci/devinfra/docker_windows/Dockerfile
+++ /dev/null
@@ -1,256 +0,0 @@
-FROM mcr.microsoft.com/dotnet/framework/sdk:4.8-windowsservercore-ltsc2019@sha256:c1b2be17aa0c1a5d9493a306395a6f07141aae8d7897f7ba319183f28719c990
-
-# Set default powershell policy for this script (ProgressPreference='SilentlyContinue' makes
-# downloads with Invoke-WebRequest not show the progress bar and is MUCH faster).
-SHELL ["powershell.exe", "-ExecutionPolicy", "Bypass", "-Command", "$ErrorActionPreference='Stop'; $ProgressPreference='SilentlyContinue'; $VerbosePreference = 'Continue';"]
-
-# Workaround for networking (b/112379377) was closed as won't fix for MTU setting.
-# Remaining lines handle making the metadata server on the VM accessible inside docker.
-RUN Get-NetAdapter | Where-Object Name -like "*Ethernet*" | ForEach-Object { \
-        & netsh interface ipv4 set subinterface $_.InterfaceIndex mtu=1460 store=persistent }; \
-    $gateway = (Get-NetRoute | Where { $_.DestinationPrefix -eq \"0.0.0.0/0\" } | Sort-Object RouteMetric \
-        | Select NextHop).NextHop; \
-    $ifIndex = (Get-NetAdapter -InterfaceDescription \"Hyper-V Virtual Ethernet*\" | Sort-Object \
-        | Select ifIndex).ifIndex; \
-    New-NetRoute -DestinationPrefix 169.254.169.254/32 -InterfaceIndex $ifIndex -NextHop $gateway
-
-# Enable Long Paths for Win32 File/Folder APIs.
-RUN New-ItemProperty -Path HKLM:\SYSTEM\CurrentControlSet\Control\FileSystem \
-        -Name LongPathsEnabled -Value 1 -PropertyType DWORD -Force
-
-# Install Visual C++ Redistributable for Visual Studio 2015-2022.
-RUN New-Item -Path "C:/" -Name "TEMP" -ItemType "directory"; \
-    Invoke-WebRequest "https://aka.ms/vs/17/release/vc_redist.x64.exe" \
-        -OutFile C:/TEMP/vc_redist.x64.exe -UseBasicParsing; \
-    Start-Process -filepath C:/TEMP/vc_redist.x64.exe -ArgumentList '/install', '/passive', '/norestart' -Wait; \
-    Remove-Item C:/TEMP/vc_redist.x64.exe
-
-# Install Visual Studio 2022 Build Tools. Install ManagedDesktopBuildTools separately to ensure all Optional workloads are installed too.
-RUN Invoke-WebRequest "https://aka.ms/vs/17/release/vs_buildtools.exe" \
-        -OutFile C:/TEMP/vs_buildtools.exe -UseBasicParsing; \
-    Start-Process -FilePath C:/TEMP/vs_buildtools.exe -ArgumentList "--installPath", "C:/VS", \
-        "--quiet", "--wait", "--nocache", \
-        "--add", "Microsoft.VisualStudio.Workload.VCTools", \
-        "--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", \
-        "--add", "Microsoft.VisualStudio.Component.Windows10SDK.19041" -Wait; \
-    Start-Process -FilePath C:/TEMP/vs_buildtools.exe -ArgumentList "--installPath", "C:/VS", \
-        "--quiet", "--wait", "--nocache", "--includeOptional", \
-        "--add", "Microsoft.VisualStudio.Workload.ManagedDesktopBuildTools" -Wait; \
-    Remove-Item C:/TEMP/vs_buildtools.exe; \
-    [Environment]::SetEnvironmentVariable(\"BAZEL_VC\", \"C:\VS\VC\", \"Machine\"); \
-    $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
-    [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\VS\VC\Tools\MSVC\14.33.31629\bin\Hostx64\x64;C:\VS\Common7\Tools;C:\VS\MSBuild\Current\Bin\", \"Machine\");
-
-# Add signtool.exe to the PATH. Note this path may need to be edited if updates
-# are made to the Windows 10 SDK.
-RUN $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
-    [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Program Files (x86)\Windows Kits\10\App Certification Kit\", \"Machine\");
-
-# Install WiX toolset (v4) - Necessary for MSI Installer/Signing builds
-RUN dotnet tool install --global wix
-
-# Install msys2, packages and add to path.
-RUN [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; \
-    Invoke-WebRequest "https://repo.msys2.org/distrib/x86_64/msys2-base-x86_64-20220319.sfx.exe" \
-        -OutFile msys2_install.exe -UseBasicParsing; \
-    .\msys2_install.exe -y -oC:\; \
-    Remove-Item msys2_install.exe; \
-    function msys() { C:\msys64\usr\bin\bash.exe @('-lc') + @Args; } \
-    msys ' '; \
-    msys 'pacman --noconfirm -Syy bsdcpio bsdtar bzip2'; \
-    msys 'pacman --noconfirm -Syy coreutils curl dash file filesystem findutils'; \
-    msys 'pacman --noconfirm -Syy flex gawk gcc-libs grep gzip inetutils info'; \
-    msys 'pacman --noconfirm -Syy less lndir mintty ncurses pactoys-git patch'; \
-    msys 'pacman --noconfirm -Syy pax-git pkgfile rebase sed tar tftp-hpa time tzcode util-linux which'; \
-    $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
-    [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\msys64;C:\msys64\usr\bin\", \"Machine\");
-
-# Install Go 1.19.1
-RUN Invoke-WebRequest "https://go.dev/dl/go1.19.1.windows-amd64.msi" \
-        -OutFile C:/TEMP/go_install.msi -UseBasicParsing; \
-    Start-Process C:/TEMP/go_install.msi -ArgumentList "/quiet", "/log", "C:/TEMP/go_install_log.txt", \
-        "InstallAllUsers=1", "PrependPath=1" -wait; \
-    Remove-Item C:/TEMP/go_install.msi; \
-    Remove-Item C:/TEMP/go_install_log.txt
-
-# Install Python 3.
-RUN Invoke-WebRequest "https://www.python.org/ftp/python/3.10.4/python-3.10.4-amd64.exe" \
-        -OutFile C:/TEMP/python_install.exe -UseBasicParsing; \
-    Start-Process C:/TEMP/python_install.exe -ArgumentList "/quiet", "/log", "C:/TEMP/python_install_log.txt", \
-        "InstallAllUsers=1", "PrependPath=1" -wait; \
-    Remove-Item C:/TEMP/python_install.exe; \
-    Remove-Item C:/TEMP/python_install_log.txt
-
-# Install JDK 17
-RUN Add-Type -AssemblyName "System.IO.Compression.FileSystem"; \
-    $zulu_url = \"https://cdn.azul.com/zulu/bin/zulu17.32.13-ca-jdk17.0.2-win_x64.zip\"; \
-    $zulu_zip = \"c:/temp/jdk_install.zip\"; \
-    $zulu_extracted_path = \"c:/temp/\" + [IO.Path]::GetFileNameWithoutExtension($zulu_url); \
-    $zulu_root = \"c:/openjdk\"; \
-    (New-Object Net.WebClient).DownloadFile($zulu_url, $zulu_zip); \
-    [System.IO.Compression.ZipFile]::ExtractToDirectory($zulu_zip, \"c:/temp\"); \
-    Move-Item $zulu_extracted_path -Destination $zulu_root; \
-    Remove-Item $zulu_zip; \
-    $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
-    [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";${zulu_root}\bin\", \"Machine\"); \
-    [Environment]::SetEnvironmentVariable(\"JAVA_HOME\", $zulu_root, \"Machine\")
-
-# Install gcloud (install.bat installs directly into bin folder of extracted zip contents)
-# Install needed gcloud components
-RUN Add-Type -AssemblyName "System.IO.Compression.FileSystem"; \
-    $pkg_url = \"https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-396.0.0-windows-x86_64.zip\"; \
-    $pkg_zip = \"c:/temp/gcloud.zip\"; \
-    $pkg_extracted_path = \"c:/google-cloud-sdk\"; \
-    (New-Object Net.WebClient).DownloadFile($pkg_url, $pkg_zip); \
-    [System.IO.Compression.ZipFile]::ExtractToDirectory($pkg_zip, \"c:/\"); \
-    Start-Process cmd.exe -ArgumentList "/c", "/s", "$pkg_extracted_path/install.bat", "-q" -wait; \
-    Remove-Item $pkg_zip; \
-    $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
-    [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";${pkg_extracted_path}\bin\", \"Machine\"); \
-    $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine'); \
-    gcloud components install docker-credential-gcr kubectl gsutil;
-
-# Install cygwin and packages
-# Running a seperate ps1 file since when running inside a Dockerfile, it does
-# not work.
-COPY install/install_cygwin.ps1 c:/
-RUN c:/install_cygwin.ps1; \
-    $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
-    [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Cygwin64\bin\", \"Machine\");
-RUN Remove-Item c:/install_cygwin.ps1
-
-# Install Chocolatey and packages
-RUN Invoke-Expression ((New-Object Net.WebClient).DownloadString('https://chocolatey.org/install.ps1')); \
-    $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine'); \
-    choco feature enable -n allowGlobalConfirmation; \
-    choco install 7zip; \
-    choco install 7zip.install; \
-    choco install 7zip.portable; \
-    choco install anaconda2 --version 5.0.1; \
-    choco install anaconda3 --version 5.0.1; \
-    choco install android-sdk --version 25.2.3.1; \
-    choco install AndroidStudio --version 3.0.1.0; \
-    choco install ant --version 1.10.1; \
-    choco install ccleaner; \
-    choco install chocolatey; \
-    choco install chocolatey-core.extension; \
-    choco install chocolatey-visualstudio.extension; \
-    choco install chocolatey-windowsupdate.extension; \
-    choco install cmake.install; \
-    choco install dotnetcore-sdk; \
-    choco install git; \
-    choco install git.install; \
-    choco install GoogleChrome; \
-    choco install gradle --version 4.4.1; \
-    choco install jdk8; \
-    choco install KB2533623; \
-    choco install KB2919355; \
-    choco install KB2919442; \
-    choco install KB2999226; \
-    choco install KB3033929; \
-    choco install KB3035131; \
-    choco install maven; \
-    choco install ninja; \
-    choco install nodejs --version 9.3.0; \
-    choco install nodejs.install --version 9.3.0; \
-    choco install nuget.commandline; \
-    choco install openjdk11; \
-    choco install peazip; \
-    choco install peazip.install; \
-    choco install peazip.portable; \
-    choco install php --version 7.2.0; \
-    choco install protoc --version 3.2.0; \
-    choco install ruby --version 2.5.0.1; \
-    choco install swig --version 3.0.9; \
-    choco install sysinternals; \
-    choco install unrar; \
-    choco install unzip; \
-    choco install vcredist140; \
-    choco install vcredist2015; \
-    choco install vim; \
-    choco install winrar; \
-    choco install zip; \
-    choco install Firefox; \
-    choco install iisexpress;
-
-RUN cmd /c 'mklink /J c:\Anaconda c:\tools\anaconda2';
-RUN cmd /c 'mklink c:\programdata\chocolatey\bin\rar.exe \"c:\program files\winrar\rar.exe\"';
-
-# Installing pip packages
-RUN pip install --upgrade setuptools; \
-    pip install altgraph appdirs cachetools certifi cffi chardet colorama \
-    cryptography cycler Cython decorator google-api-python-client \
-    google-auth google-auth-httplib2 grpcio httplib2 idna ipython-genutils \
-    kiwisolver macholib matplotlib nose numpy packaging pandas pickleshare pip \
-    prompt-toolkit protobuf psutil pyasn1 pyasn1-modules pycparser Pygments \
-    pyparsing pyreadline python-dateutil pytz pywin32 requests rsa setuptools \
-    simplegeneric six Tempita traitlets uritemplate urllib3 virtualenv wcwidth \
-    wheel win-unicode-console;
-
-# Hardcoding Android license since I did not find any solution on accepting it
-# through the docker build command. If the licensing agreement changes, this 
-# will need to be updated as well.
-RUN New-Item -ItemType Directory -Path C:\Android\android-sdk\licenses; \
-    Set-Content -Path .\Android\android-sdk\licenses\android-sdk-license -Value "`n24333f8a63b6825ea9c5514f83c2829b004d1fee" -NoNewLine;
-
-# Add sdkmanager to PATH
-RUN $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
-    [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Android\android-sdk\tools\bin\", \"Machine\");
-
-# Install android packages
-RUN $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine'); \
-    New-Item C:\Users\ContainerAdministrator\.android\repositories.cfg; \
-    sdkmanager 'ndk-bundle'; \
-    sdkmanager 'platforms;android-33'; \
-    sdkmanager 'add-ons;addon-google_apis-google-24'; \
-    sdkmanager 'cmake;3.10.2.4988404'; \
-    sdkmanager 'cmake;3.18.1'; \
-    sdkmanager 'cmake;3.22.1'; \
-    sdkmanager 'cmake;3.6.4111459'; \
-    sdkmanager 'emulator'; \
-    sdkmanager 'system-images;android-27;google_apis;x86'; \
-    sdkmanager 'sources;android-27'; \
-    sdkmanager 'extras;google;Android_Emulator_Hypervisor_Driver'; \
-    sdkmanager 'extras;google;auto'; \
-    sdkmanager 'extras;google;google_play_services'; \
-    sdkmanager 'extras;google;instantapps'; \
-    sdkmanager 'extras;google;m2repository'; \
-    sdkmanager 'extras;google;market_apk_expansion'; \
-    sdkmanager 'extras;google;market_licensing'; \
-    sdkmanager 'extras;google;simulators'; \
-    sdkmanager 'extras;google;usb_driver'; \
-    sdkmanager 'extras;google;webdriver'; \
-    sdkmanager 'extras;android;m2repository'; \
-    sdkmanager 'extras;intel;Hardware_Accelerated_Execution_Manager'; \
-    sdkmanager 'extras;m2repository;com;android;support;constraint;constraint-layout;1.0.0'; \
-    sdkmanager 'extras;m2repository;com;android;support;constraint;constraint-layout-solver;1.0.2'; \
-    sdkmanager 'patcher;v4'; \
-    sdkmanager 'ndk;25.1.8937393'; \
-    sdkmanager 'build-tools;27.0.3';
-
-# Install Scoop and packages
-RUN iex \"& {$(irm get.scoop.sh)} -RunAsAdmin\"; \
-    scoop install perl; \
-    scoop install bazel; \
-    scoop install cuda; \
-    scoop install azure-functions-core-tools; \
-    scoop install azure-cli;
-
-# Setting environment variables
-RUN [Environment]::SetEnvironmentVariable('CYGWIN', 'winsymlinks:native', 'Machine'); \
-    [Environment]::SetEnvironmentVariable('HOME', 'C:\Users\ContainerAdministrator\', 'Machine'); \
-    [Environment]::SetEnvironmentVariable('HOMEDRIVE', 'C:', 'Machine'); \
-    [Environment]::SetEnvironmentVariable('HOMEPATH', '\Users\ContainerAdministrator\', 'Machine'); \
-    [Environment]::SetEnvironmentVariable('GOROOT', 'C:\Program Files\Go\', 'Machine'); \
-    [Environment]::SetEnvironmentVariable('KOKORO_POSIX_ROOT', '/tmpfs', 'Machine'); \
-    [Environment]::SetEnvironmentVariable('KOKORO_ROOT', 'T:\', 'Machine'); \
-    [Environment]::SetEnvironmentVariable('SHELL', '/bin/bash', 'Machine'); \
-    $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \
-    [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Program Files\CMake\bin\", \"Machine\");
-
-
-# Restore default shell for Windows containers.
-SHELL ["cmd.exe", "/s", "/c"]
-
-# Default to PowerShell if no other command specified.
-CMD ["powershell.exe", "-NoLogo", "-ExecutionPolicy", "Bypass"]
diff --git a/ci/official/containers/linux_arm64/build.sh b/ci/official/containers/linux_arm64/build.sh
index 611d5f4..ffead7f 100755
--- a/ci/official/containers/linux_arm64/build.sh
+++ b/ci/official/containers/linux_arm64/build.sh
@@ -16,8 +16,8 @@
 
 # Builds the following Docker images for Linux ARM64. See the accompanying
 # Dockerfile for more details:
-# - gcr.io/tensorflow-sigs/build-arm64:jax-latest-multi-python
-# - gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python
+# - us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:jax-latest-multi-python
+# - us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:tf-latest-multi-python
 
 set -exo pipefail
 
@@ -40,16 +40,14 @@
   fi
 fi
 
-# TODO(b/341050361): When these steps are verified, removed the GCR image code.
 AR_IMAGE_PATH="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64"
 
 # Build for both JAX and TF usage.  We do these in one place because they share
 # almost all of the same cache layers
 export DOCKER_BUILDKIT=1
 for target in jax tf; do
-  IMAGE="gcr.io/tensorflow-sigs/build-arm64:$target-$TAG"
   AR_IMAGE="$AR_IMAGE_PATH:$target-$TAG"
-  docker pull "$IMAGE" || true
+  docker pull "$AR_IMAGE" || true
   # Due to some flakiness of resources pulled in the build, allow the docker
   # command to reattempt build a few times in the case of failure (b/302558736)
   set +e
@@ -58,8 +56,8 @@
     docker build \
     --build-arg REQUIREMENTS_FILE=jax.requirements.txt \
     --target=$target \
-    --cache-from "$IMAGE" \
-    -t "$IMAGE" -t "$AR_IMAGE" . && break
+    --cache-from "$AR_IMAGE" \
+    -t "$AR_IMAGE" . && break
   done
   final=$?
   if [ $final -ne 0 ]; then
@@ -68,8 +66,6 @@
   set -e
 
   if [[ -n "$KOKORO_BUILD_ID" ]]; then
-    gcloud auth configure-docker
-    docker push "$IMAGE"
     gcloud auth configure-docker us-central1-docker.pkg.dev
     docker push "$AR_IMAGE"
   fi
diff --git a/ci/official/envs/rbe b/ci/official/envs/rbe
index 12cc600..35f8173 100644
--- a/ci/official/envs/rbe
+++ b/ci/official/envs/rbe
@@ -33,7 +33,17 @@
 fi
 TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config rbe_$TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX"
 
-# These flags share the user's gcloud credentials with the container, so that bazel
-# inside the container can authenticate. Note: TF's CI does not have any credential
-# stored here.
-TFCI_DOCKER_ARGS="$TFCI_DOCKER_ARGS -v $HOME/.config/gcloud:/root/.config/gcloud"
+if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then
+  # Docker on Windows doesn't support the `host` networking mode, and so
+  # port-forwarding is required for the container to detect it's running on GCE.
+  export IP_ADDR=$(powershell -command "(Get-NetIPAddress -AddressFamily IPv4 -InterfaceAlias 'vEthernet (nat)').IPAddress")
+  netsh interface portproxy add v4tov4 listenaddress=$IP_ADDR listenport=80 connectaddress=169.254.169.254 connectport=80
+  # A local firewall rule for the container is added in
+  # ci/official/utilities/setup_docker.sh.
+else
+  # The volume mapping flag below shares the user's gcloud credentials, if any,
+  # with the container, so that Bazel inside the container can authenticate
+  # for RBE. Note: TF's CI does not have any credentials stored there.
+  TFCI_DOCKER_ARGS="$TFCI_DOCKER_ARGS -v $HOME/.config/gcloud:/root/.config/gcloud"
+fi
diff --git a/ci/official/envs/windows_x86 b/ci/official/envs/windows_x86
new file mode 100644
index 0000000..568a47f
--- /dev/null
+++ b/ci/official/envs/windows_x86
@@ -0,0 +1,20 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+TFCI_DOCKER_ENABLE=1
+TFCI_DOCKER_PULL_ENABLE=1
+TFCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc"
+TFCI_BAZEL_COMMON_ARGS="--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION"
+TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=windows_x86_cpu
+TFCI_OUTPUT_DIR=build_output
diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh
index cf34600..f6f2090 100755
--- a/ci/official/pycpp.sh
+++ b/ci/official/pycpp.sh
@@ -15,12 +15,19 @@
 # ==============================================================================
 source "${BASH_SOURCE%/*}/utilities/setup.sh"
 
-if [[ $TFCI_PYCPP_SWAP_TO_BUILD_ENABLE == 1 ]]; then
-   tfrun bazel build $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test"
+if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then
+  PROFILE_JSON_PATH=$(replace_drive_letter_with_c "$TFCI_OUTPUT_DIR")
+  PROFILE_JSON_PATH="$PROFILE_JSON_PATH/profile.json.gz"
 else
-   tfrun bazel test $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test"
+  PROFILE_JSON_PATH="$TFCI_OUTPUT_DIR/profile.json.gz"
+fi
+
+if [[ $TFCI_PYCPP_SWAP_TO_BUILD_ENABLE == 1 ]]; then
+   tfrun bazel build $TFCI_BAZEL_COMMON_ARGS --profile "$PROFILE_JSON_PATH" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test"
+else
+   tfrun bazel test $TFCI_BAZEL_COMMON_ARGS --profile "$PROFILE_JSON_PATH" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test"
 fi
 
 # Note: the profile can be viewed by visiting chrome://tracing in a Chrome browser.
 # See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling
-tfrun bazel analyze-profile "$TFCI_OUTPUT_DIR/profile.json.gz"
+tfrun bazel analyze-profile "$PROFILE_JSON_PATH"
diff --git a/ci/official/utilities/cleanup_summary.sh b/ci/official/utilities/cleanup_summary.sh
index dbe2203..6b6fdfa 100755
--- a/ci/official/utilities/cleanup_summary.sh
+++ b/ci/official/utilities/cleanup_summary.sh
@@ -23,8 +23,9 @@
 can view more detailed results that are probably easier to read than this log.
 Try the links below:
 EOF
-  # Find any "Streaming build results to" line, then print the last word in it,
-  # and don't print duplicates
+  # Find any "Streaming build results to" lines,
+  # de-duplicate,
+  # and print the last word from each
   awk '/Streaming build results to/ {print $NF}' "$TFCI_OUTPUT_DIR/script.log" | uniq
 }
 
@@ -32,14 +33,15 @@
 # Each failed target there will have its own representation, making failures
 # easier to find and read.
 function resultstore_extract {
-  local \
-    XML_PATH="$TFCI_OUTPUT_DIR/Bazel_Test_and_Build_Results/sponge_log.xml"
+  local PYTHON_BIN XML_PATH
+  PYTHON_BIN=$(which python3 2>/dev/null || which python)
+  XML_PATH="$TFCI_OUTPUT_DIR/Bazel_Test_and_Build_Results/sponge_log.xml"
 
-  python3 \
+  "$PYTHON_BIN" \
     "$TFCI_GIT_DIR/ci/official/utilities/extract_resultstore_links.py" \
     "$TFCI_OUTPUT_DIR/script.log" \
     --print \
-    --xml-out-path "$XML_PATH" || resultstore_extract_fallback
+    --xml-out-path "$XML_PATH"
 }
 
 if grep -q "Streaming build results to" "$TFCI_OUTPUT_DIR/script.log"; then
diff --git a/ci/official/utilities/convert_msys_paths_to_win_paths.py b/ci/official/utilities/convert_msys_paths_to_win_paths.py
new file mode 100644
index 0000000..ed1dd3b
--- /dev/null
+++ b/ci/official/utilities/convert_msys_paths_to_win_paths.py
@@ -0,0 +1,76 @@
+#!/usr/bin/python3
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+"""Converts MSYS Linux-like paths stored in env variables to Windows paths.
+
+This is necessary on Windows, because some applications (for example, Docker)
+do not understand or handle the Linux-like paths that MSYS uses.
+"""
+
+import argparse
+import os
+
+
+def should_convert(var_name: str,
+                   blacklist: list[str] | None,
+                   whitelist_prefix: list[str] | None):
+  """Check the variable name against white/black lists."""
+  if blacklist and var_name in blacklist:
+    return False
+  if not whitelist_prefix:
+    return True
+
+  for prefix in whitelist_prefix:
+    if var_name.startswith(prefix):
+      return True
+  return False
+
+
+def main(parsed_args: argparse.Namespace):
+  converted_vars = {}
+
+  for var, value in os.environ.items():
+    if not value or not should_convert(var,
+                                       parsed_args.blacklist,
+                                       parsed_args.whitelist_prefix):
+      continue
+
+    # In Python, MSYS Linux-like paths are automatically read as Windows paths
+    # with forward slashes, e.g. 'C:/Program Files' instead of
+    # '/c/Program Files', so they are converted simply by virtue of having
+    # been read.
+    converted_vars[var] = value
+
+  var_str = '\n'.join(f'{k}="{v}"'
+                      for k, v in converted_vars.items())
+  # The string can then be piped into `source`, to re-set the
+  # 'converted' variables.
+  print(var_str)
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser(description=(
+      'Convert MSYS paths in environment variables to Windows paths.'))
+  parser.add_argument('--blacklist',
+                      nargs='*',
+                      help='List of variables to ignore')
+  parser.add_argument('--whitelist-prefix',
+                      nargs='*',
+                      help='Prefix for variables to include')
+  args = parser.parse_args()
+
+  main(args)
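
A minimal usage sketch under MSYS bash (an assumption for illustration; the invocation mirrors the one added to ci/official/utilities/setup.sh below):

  # Re-import the converted TFCI_* variables into the current shell; variables
  # that held MSYS-style paths (e.g. '/c/...') come back as forward-slash
  # Windows paths (e.g. 'C:/...').
  source <(python ./ci/official/utilities/convert_msys_paths_to_win_paths.py --whitelist-prefix TFCI_)
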
diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh
index 2277b75..55d0e07 100755
--- a/ci/official/utilities/setup.sh
+++ b/ci/official/utilities/setup.sh
@@ -118,6 +118,12 @@
 # functionality instead.
 tfrun() { "$@"; }
 
+if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then
+  source ./ci/official/utilities/windows.sh
+  echo 'Converting MSYS Linux-like paths to Windows paths (for Docker, Python, etc.)'
+  source <(python ./ci/official/utilities/convert_msys_paths_to_win_paths.py --whitelist-prefix TFCI_)
+fi
+
 # Run all "tfrun" commands under Docker. See setup_docker.sh for details
 if [[ "$TFCI_DOCKER_ENABLE" == 1 ]]; then
   source ./ci/official/utilities/setup_docker.sh
diff --git a/ci/official/utilities/setup_docker.sh b/ci/official/utilities/setup_docker.sh
index 91618c7..61db7c2 100755
--- a/ci/official/utilities/setup_docker.sh
+++ b/ci/official/utilities/setup_docker.sh
@@ -37,10 +37,30 @@
   # Pass all existing TFCI_ variables into the Docker container
   env_file=$(mktemp)
   env | grep ^TFCI_ > "$env_file"
-  docker run $TFCI_DOCKER_ARGS --name tf -w "$TFCI_GIT_DIR" -itd --rm \
-      -v "$TFCI_GIT_DIR:$TFCI_GIT_DIR" \
+
+  WORKING_DIR="$TFCI_GIT_DIR"
+  if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then
+    env_file=$(cygpath -m $env_file)
+    # Host dirs can only be mapped to an existing drive inside the container, so
+    # T:\ is replaced with C:\.
+    _TFCI_OUTPUT_DIR_WIN=$(replace_drive_letter_with_c "$TFCI_OUTPUT_DIR")
+    sed -iE 's|^TFCI_OUTPUT_DIR=.*|TFCI_OUTPUT_DIR='"$_TFCI_OUTPUT_DIR_WIN"'|g' $env_file
+    WORKING_DIR=$(replace_drive_letter_with_c "$TFCI_GIT_DIR")
+    echo "GCE_METADATA_HOST=$IP_ADDR" > $env_file
+  fi
+
+  docker run $TFCI_DOCKER_ARGS --name tf -w "$WORKING_DIR" -itd --rm \
+      -v "$TFCI_GIT_DIR:$WORKING_DIR" \
       --env-file "$env_file" \
       "$TFCI_DOCKER_IMAGE" \
     bash
+
+  if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then
+    # Allow requests from the container.
+    # Additional setup is contained in ci/official/envs/rbe.
+    CONTAINER_IP_ADDR=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' tf)
+    netsh advfirewall firewall add rule name="Allow Metadata Proxy" dir=in action=allow protocol=TCP localport=80 remoteip="$CONTAINER_IP_ADDR"
+  fi
+
 fi
 tfrun() { docker exec tf "$@"; }
diff --git a/ci/official/utilities/windows.sh b/ci/official/utilities/windows.sh
new file mode 100644
index 0000000..1ab2d89
--- /dev/null
+++ b/ci/official/utilities/windows.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Windows-specific utilities.
+#
+
+# Docker on Windows has difficulty using volumes other than C:\ when it comes
+# to setting up volume mappings.
+# Thus, the drive letter is replaced with C:\ in case it's something else
+# (e.g. T:), which is frequently the case inside Kokoro jobs.
+function replace_drive_letter_with_c () {
+  sed -E "s|^[a-zA-Z]:|C:|g" <<< $1
+}
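
A quick illustrative check of the helper above (the paths are hypothetical):

  source ci/official/utilities/windows.sh
  replace_drive_letter_with_c "T:/src/tensorflow"   # prints C:/src/tensorflow
  replace_drive_letter_with_c "C:/tmp"              # already C:, printed unchanged
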
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 3556df1..c96cd8c 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -1445,7 +1445,7 @@
         "//tensorflow/lite/kernels/internal:compatibility",
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/kernels:reference_ops",
-        "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
         "//tensorflow/lite/toco/logging:conversion_log_util",
         "//tensorflow/lite/toco/logging:toco_conversion_log_proto_cc",
         "//tensorflow/lite/toco:model_flags_proto_cc",
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc
index 86b201f..1123ccb 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc
@@ -24,6 +24,7 @@
 #include "tensorflow/c/tf_buffer.h"
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/function.h"
@@ -32,7 +33,6 @@
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace parallel_device {
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
index ec2ce95..88ef5c1 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
@@ -27,8 +27,8 @@
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/c/tf_status_internal.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 // NOTE(allenl): These tests currently go through TFE_Execute and so are
 // integration testing rather than purely testing the parallel device. They
diff --git a/tensorflow/c/experimental/grappler/grappler_test.cc b/tensorflow/c/experimental/grappler/grappler_test.cc
index 57f1a65..357432e 100644
--- a/tensorflow/c/experimental/grappler/grappler_test.cc
+++ b/tensorflow/c/experimental/grappler/grappler_test.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/c/tf_buffer.h"
 #include "tensorflow/c/tf_buffer_internal.h"
 #include "tensorflow/c/tf_status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op.h"
@@ -33,7 +34,6 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/protobuf/error_codes.pb.h"
 
 namespace tensorflow {
diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD
index 56586f7..45c55c3 100644
--- a/tensorflow/c/experimental/next_pluggable_device/BUILD
+++ b/tensorflow/c/experimental/next_pluggable_device/BUILD
@@ -87,7 +87,6 @@
         "//tensorflow/core/tfrt/common:pjrt_util",
         "@com_google_absl//absl/log:check",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:casts",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
@@ -98,5 +97,6 @@
         "@local_xla//xla/pjrt/c:pjrt_c_api_hdrs",
         "@local_xla//xla/pjrt/c:pjrt_c_api_wrapper_impl",
         "@local_xla//xla/pjrt/cpu:cpu_client",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc
index 7f45fd9..1952364 100644
--- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc
+++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc
@@ -30,10 +30,10 @@
 #include "xla/pjrt/pjrt_c_api_client.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/tfrt/common/async_value_tensor.h"
 #include "tensorflow/core/tfrt/common/pjrt_util.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/casts.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/protobuf/error_codes.pb.h"
diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD
index 7284261..76f1db6 100644
--- a/tensorflow/c/experimental/ops/BUILD
+++ b/tensorflow/c/experimental/ops/BUILD
@@ -22,11 +22,11 @@
         "//tensorflow/c/eager:abstract_context",
         "//tensorflow/c/eager:abstract_operation",
         "//tensorflow/c/eager:abstract_tensor_handle",
-        "//tensorflow/c/eager:c_api_unified_internal",
         "//tensorflow/c/eager:tracing_utils",
-        "//tensorflow/core:framework",
-        "//tensorflow/core/platform:errors",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:status",
         "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
     ],
 )
 
diff --git a/tensorflow/c/experimental/ops/array_ops.cc b/tensorflow/c/experimental/ops/array_ops.cc
index db9464d..23deef1 100644
--- a/tensorflow/c/experimental/ops/array_ops.cc
+++ b/tensorflow/c/experimental/ops/array_ops.cc
@@ -17,11 +17,14 @@
 
 #include "tensorflow/c/experimental/ops/array_ops.h"
 
+#include "absl/types/span.h"
 #include "tensorflow/c/eager/abstract_context.h"
+#include "tensorflow/c/eager/abstract_operation.h"
 #include "tensorflow/c/eager/abstract_tensor_handle.h"
 #include "tensorflow/c/eager/tracing_utils.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/status.h"
+#include "tsl/platform/errors.h"
 
 using tensorflow::tracing::MaybeSetOpName;
 
diff --git a/tensorflow/c/experimental/ops/array_ops.h b/tensorflow/c/experimental/ops/array_ops.h
index f4d170a..466c36f 100644
--- a/tensorflow/c/experimental/ops/array_ops.h
+++ b/tensorflow/c/experimental/ops/array_ops.h
@@ -18,8 +18,11 @@
 #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
 #define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
 
+#include "absl/types/span.h"
 #include "tensorflow/c/eager/abstract_context.h"
 #include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
 namespace ops {
diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc
index 3a6de51..463f64c 100644
--- a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc
+++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/c/experimental/saved_model/core/test_utils.h"
 #include "tensorflow/c/tensor_interface.h"
 #include "tensorflow/cc/saved_model/constants.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/framework/types.pb.h"
@@ -26,7 +27,6 @@
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc
index d06608f..18d7498 100644
--- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc
+++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc
@@ -20,12 +20,12 @@
 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
 #include "tensorflow/c/experimental/saved_model/core/test_utils.h"
 #include "tensorflow/c/tensor_interface.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc
index 94441a7..ece9057 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor.cc
+++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc
@@ -230,7 +230,7 @@
     return std::make_unique<HostMemoryAllocation>(buffer, size, this);
   }
 
-  void HostMemoryDeallocate(void* mem, uint64_t size) override {
+  void HostMemoryDeallocate(void* mem) override {
     stream_executor_->host_memory_deallocate(&device_, mem);
   }
 
@@ -432,7 +432,6 @@
       name_(platform.name) {}
 
 CPlatform::~CPlatform() {
-  executor_cache_.DestroyAllExecutors();
   platform_fns_.destroy_device_fns(&platform_, &device_fns_);
   platform_fns_.destroy_stream_executor(&platform_, &stream_executor_);
   platform_fns_.destroy_timer_fns(&platform_, &timer_fns_);
@@ -454,6 +453,11 @@
   config.ordinal = ordinal;
   return GetExecutor(config);
 }
+absl::StatusOr<StreamExecutor*> CPlatform::FindExisting(int ordinal) {
+  stream_executor::StreamExecutorConfig config;
+  config.ordinal = ordinal;
+  return executor_cache_.Get(config);
+}
 absl::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
     const StreamExecutorConfig& config) {
   return executor_cache_.GetOrCreate(
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
index e3e025c..d87794a 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
+++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
@@ -19,6 +19,7 @@
 #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
 
 #include <cstdint>
+#include <memory>
 #include <string>
 #include <utility>
 
@@ -99,12 +100,15 @@
   absl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
   absl::StatusOr<StreamExecutor*> GetExecutor(
       const StreamExecutorConfig& config) override;
-  absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
-      const StreamExecutorConfig& config) override;
-
-  void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); }
+  absl::StatusOr<StreamExecutor*> FindExisting(int ordinal) override;
 
  private:
+  // Returns a device constructed with the options specified in "config" without
+  // looking in or storing to the Platform's executor cache.
+  // Ownership IS transferred to the caller.
+  absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+      const StreamExecutorConfig& config);
+
   SP_Platform platform_;
   void (*destroy_platform_)(SP_Platform*);
   SP_PlatformFns platform_fns_;
diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc
index 0082c65..20aded8 100644
--- a/tensorflow/c/kernels.cc
+++ b/tensorflow/c/kernels.cc
@@ -794,10 +794,7 @@
 #else
   const auto* device = reinterpret_cast<const tensorflow::Device*>(
       device_base->UnderlyingDevice());
-  const absl::StatusOr<int> id = tsl::GetDeviceIdFromDeviceParsedName(
-      device->parsed_name(), tensorflow::DeviceType(device->device_type()));
-  if (!id.ok()) return -1;
-  return *id;
+  return tsl::GetDeviceIdFromDeviceParsedName(device->parsed_name());
 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
 }
 
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index da27e61..02fd678 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -588,12 +588,12 @@
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:protobuf",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:test",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/cc/saved_model/bundle_v2_test.cc b/tensorflow/cc/saved_model/bundle_v2_test.cc
index a0bbb82..1380282 100644
--- a/tensorflow/cc/saved_model/bundle_v2_test.cc
+++ b/tensorflow/cc/saved_model/bundle_v2_test.cc
@@ -28,10 +28,10 @@
 #include "json/reader.h"
 #include "json/value.h"
 #include "tensorflow/cc/saved_model/metrics.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/statusor.h"
 
 namespace tensorflow {
diff --git a/tensorflow/cc/saved_model/fingerprinting_utils_test.cc b/tensorflow/cc/saved_model/fingerprinting_utils_test.cc
index 1f6b0e1..3182afc 100644
--- a/tensorflow/cc/saved_model/fingerprinting_utils_test.cc
+++ b/tensorflow/cc/saved_model/fingerprinting_utils_test.cc
@@ -25,6 +25,7 @@
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/protobuf/meta_graph.pb.h"
@@ -32,7 +33,6 @@
 #include "tensorflow/tools/proto_splitter/cc/util.h"
 #include "tensorflow/tools/proto_splitter/chunk.pb.h"
 #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/statusor.h"
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
index d74c0cd..eb4ef40 100644
--- a/tensorflow/cc/tools/freeze_saved_model_test.cc
+++ b/tensorflow/cc/tools/freeze_saved_model_test.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/cc/ops/resource_variable_ops.h"
 #include "tensorflow/cc/ops/state_ops.h"
 #include "tensorflow/cc/saved_model/loader.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -35,7 +36,6 @@
 #include "tensorflow/core/protobuf/meta_graph.pb.h"
 #include "tensorflow/core/public/session.h"
 #include "tensorflow/core/public/session_options.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 
 namespace tensorflow {
diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc
index e3f9f78..211fd1e 100644
--- a/tensorflow/cc/training/coordinator_test.cc
+++ b/tensorflow/cc/training/coordinator_test.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/cc/training/coordinator.h"
 
 #include "absl/status/status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/env.h"
@@ -23,7 +24,6 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/threadpool.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/protobuf/error_codes.pb.h"
 
 namespace tensorflow {
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index 9a5f612..f4de69b 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -26,6 +26,7 @@
 #include "tensorflow/cc/ops/random_ops.h"
 #include "tensorflow/cc/ops/state_ops.h"
 #include "tensorflow/cc/training/coordinator.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/cost_graph.pb.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -40,7 +41,6 @@
 #include "tensorflow/core/protobuf/queue_runner.pb.h"
 #include "tensorflow/core/public/session.h"
 #include "tensorflow/core/public/session_options.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 #include "tsl/protobuf/error_codes.pb.h"
 
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index eddd237..6efe665 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -518,7 +518,6 @@
         ":internal",
         # We reuse VariableInfo in TFRT's implementation of TpuExecuteOp.
         "//learning/brain/tfrt/tf_tpu:__pkg__",
-        "//learning/brain/tfrt/tpu_plugin:__pkg__",
         "//learning/brain/tfrt/tpu_common:__pkg__",
         "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__",
     ],
@@ -539,9 +538,6 @@
         ":internal",
         # We reuse VariableInfo in TFRT's implementation of TpuExecuteOp.
         "//learning/brain/tfrt/tf_tpu:__pkg__",
-        "//learning/brain/tfrt/tpu_plugin:__pkg__",
-        "//learning/brain/tfrt/tpu_common:__pkg__",
-        "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__",
     ],
     deps = [
         ":variable_info",
@@ -612,8 +608,6 @@
         # We reuse VariableInfo in TFRT's implementation of TpuExecuteOp.
         "//learning/brain/tfrt/tf_tpu:__pkg__",
         "//learning/brain/tfrt/tpu_plugin:__pkg__",
-        "//learning/brain/tfrt/tpu_common:__pkg__",
-        "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__",
         "//tensorflow/core/tfrt/gpu/kernel:__pkg__",
     ],
     deps = [
@@ -678,7 +672,6 @@
         "//tensorflow/core/tfrt/common:pjrt_util",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:statusor",
         "@local_xla//xla/pjrt:pjrt_client",
@@ -686,6 +679,7 @@
         "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client",
         "@local_xla//xla/tests:literal_test_util",
         "@local_xla//xla/tsl/framework:device_id_utils",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -712,11 +706,11 @@
         "//tensorflow/core/tfrt/common:pjrt_util",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:statusor",
         "@local_xla//xla/pjrt:pjrt_client",
         "@local_xla//xla/tests:literal_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -726,7 +720,6 @@
     hdrs = ["xla_compile_util.h"],
     visibility = [
         ":internal",
-        "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__",
         "//tensorflow/core/tfrt/gpu/kernel:__pkg__",
     ],
     deps = [
@@ -770,10 +763,7 @@
     name = "device_compiler",
     hdrs = ["device_compiler.h"],
     copts = tf_copts(),
-    visibility = [
-        ":internal",
-        "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__",
-    ],
+    visibility = [":internal"],
     deps = [
         ":device_compilation_cache",
         ":device_compilation_cluster_signature",
@@ -1118,7 +1108,6 @@
     ],
     visibility = [
         ":internal",
-        "//tensorflow/core/tfrt/utils:__pkg__",
         "//third_party/cloud_tpu/inference_converter:__pkg__",
         "//waymo/onboard/ml/chauffeur_net:__pkg__",
     ],
@@ -1564,10 +1553,7 @@
     name = "device_compiler_client",
     srcs = ["device_compiler_client.cc"],
     hdrs = ["device_compiler_client.h"],
-    visibility = [
-        ":internal",
-        "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__",
-    ],
+    visibility = [":internal"],
     deps = [
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/core/util:determinism",
@@ -1596,6 +1582,7 @@
 
 cc_library(
     name = "device_executable_persistor",
+    srcs = ["device_executable_persistor.cc"],
     hdrs = ["device_executable_persistor.h"],
     deps = [
         ":xla_compilation_cache_proto_cc",
@@ -1608,6 +1595,8 @@
         "//tensorflow/core/platform:status",
         "//tensorflow/core/platform:statusor",
         "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:statusor",
         "@local_xla//xla:util",
         "@local_xla//xla/pjrt:pjrt_client",
diff --git a/tensorflow/compiler/jit/device_context_test.cc b/tensorflow/compiler/jit/device_context_test.cc
index be85ff9..d02337d 100644
--- a/tensorflow/compiler/jit/device_context_test.cc
+++ b/tensorflow/compiler/jit/device_context_test.cc
@@ -21,8 +21,8 @@
 #include <gtest/gtest.h>
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/compiler/jit/device_executable_persistor.cc b/tensorflow/compiler/jit/device_executable_persistor.cc
new file mode 100644
index 0000000..b673af7
--- /dev/null
+++ b/tensorflow/compiler/jit/device_executable_persistor.cc
@@ -0,0 +1,37 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/device_executable_persistor.h"
+
+#include <string>
+
+#include "absl/strings/str_cat.h"
+
+namespace tensorflow {
+
+std::string XlaSerializedCacheKeyToFileName(const XlaSerializedCacheKey& key) {
+  static constexpr char kXlaSerializedCacheKeySeparator[] = "__";
+  return absl::StrCat(
+      key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator,
+      key.signature_fingerprint(), kXlaSerializedCacheKeySeparator,
+      key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator,
+      key.device_type(),
+      key.compiled_using_pjrt()
+          ? absl::StrCat(kXlaSerializedCacheKeySeparator, "pjrt")
+          : "",
+      ".pb");
+}
+
+}  // namespace tensorflow
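For illustration, a minimal sketch of how the new helper composes a cache file name. The field values are made up and the setter names merely mirror the accessors used above; treat the numeric fingerprint setters as an assumption about the XlaSerializedCacheKey proto rather than verified API:

    XlaSerializedCacheKey key;
    key.set_prefix("my_model");          // optional; empty drops the leading separator
    key.set_signature_fingerprint(123);  // assumed numeric fingerprint fields
    key.set_cluster_fingerprint(456);
    key.set_device_type("GPU_JIT");
    key.set_compiled_using_pjrt(true);
    // Expected: "my_model__123__456__GPU_JIT__pjrt.pb"
    std::string file_name = tensorflow::XlaSerializedCacheKeyToFileName(key);

GetFilePath() in the header below joins this name with persistent_cache_directory_, so the final cache file path is the same as before; only the place where the ".pb" suffix is appended has moved into the shared helper.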
diff --git a/tensorflow/compiler/jit/device_executable_persistor.h b/tensorflow/compiler/jit/device_executable_persistor.h
index 78d2089..0f546c0 100644
--- a/tensorflow/compiler/jit/device_executable_persistor.h
+++ b/tensorflow/compiler/jit/device_executable_persistor.h
@@ -20,6 +20,7 @@
 #include <string>
 
 #include "absl/log/log.h"
+#include "absl/status/status.h"
 #include "tensorflow/compiler/jit/xla_compilation_cache.pb.h"
 #include "tensorflow/compiler/jit/xla_device_compiler_client.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -35,6 +36,9 @@
 
 namespace tensorflow {
 
+// Returns the persisted compilation cache file name for the given key.
+std::string XlaSerializedCacheKeyToFileName(const XlaSerializedCacheKey& key);
+
 // Offers a way to persist and/or load compiled `ExecutableType`s along with the
 // corresponding HLO (`CompilationResult`) to/from `persistent_cache_directory`
 // (if one was provided during construction) on disk  using `ClientType`.
@@ -142,8 +146,6 @@
                                 const xla::HloModuleProto& hlo_module,
                                 const XlaSerializedCacheEntry& entry) const;
 
-  std::string XlaSerializedCacheKeyToString(
-      const XlaSerializedCacheKey& key) const;
   std::string GetFilePath(const XlaSerializedCacheKey& key) const;
 
   const DeviceType device_type_;
@@ -173,24 +175,9 @@
           config.persistent_cache_directory_read_only) {}
 
 template <typename ExecutableType, typename ClientType>
-std::string DeviceExecutablePersistor<ExecutableType, ClientType>::
-    XlaSerializedCacheKeyToString(const XlaSerializedCacheKey& key) const {
-  static constexpr char kXlaSerializedCacheKeySeparator[] = "__";
-  return absl::StrCat(
-      key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator,
-      key.signature_fingerprint(), kXlaSerializedCacheKeySeparator,
-      key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator,
-      key.device_type(),
-      key.compiled_using_pjrt()
-          ? absl::StrCat(kXlaSerializedCacheKeySeparator, "pjrt")
-          : "");
-}
-
-template <typename ExecutableType, typename ClientType>
 std::string DeviceExecutablePersistor<ExecutableType, ClientType>::GetFilePath(
     const XlaSerializedCacheKey& key) const {
-  const std::string file_name =
-      absl::StrCat(XlaSerializedCacheKeyToString(key), ".pb");
+  const std::string file_name = XlaSerializedCacheKeyToFileName(key);
   return io::JoinPath(persistent_cache_directory_, file_name);
 }
 
@@ -299,9 +286,10 @@
 
   // Write to temp location, then when that completes, atomically move into the
   // final location.
-  std::string temp_path = io::JoinPath(
-      persistent_cache_directory_, XlaSerializedCacheKeyToString(entry.key()));
-  if (!env->CreateUniqueFileName(&temp_path, ".pb.tmp")) {
+  std::string temp_path =
+      io::JoinPath(persistent_cache_directory_,
+                   XlaSerializedCacheKeyToFileName(entry.key()));
+  if (!env->CreateUniqueFileName(&temp_path, ".tmp")) {
     return absl::UnavailableError(absl::StrCat(
         "Could not create a unique file inside ", persistent_cache_directory_));
   }
diff --git a/tensorflow/compiler/jit/pjrt_device_context.cc b/tensorflow/compiler/jit/pjrt_device_context.cc
index 51b6e57..794f32d 100644
--- a/tensorflow/compiler/jit/pjrt_device_context.cc
+++ b/tensorflow/compiler/jit/pjrt_device_context.cc
@@ -52,12 +52,10 @@
                           cpu_tensor->shape(), cpu_tensor->dtype(),
                           /*fast_mem=*/false, layout_preference));
   const xla::Layout* device_layout = &(shape.layout());
-  // The device id should match the local_hardware_id in
+  // The device id should match the local_device_id in
   // tensorflow/compiler/xla/pjrt/pjrt_client.h.
-  TF_ASSIGN_OR_RETURN(
-      const int pjrt_device_id,
-      tsl::GetDeviceIdFromDeviceParsedName(device->parsed_name(),
-                                           DeviceType(device->device_type())));
+  const int pjrt_device_id =
+      tsl::GetDeviceIdFromDeviceParsedName(device->parsed_name());
   TF_ASSIGN_OR_RETURN(xla::PjRtDevice * pjrt_device,
                       pjrt_client->LookupAddressableDevice(
                           xla::PjRtLocalDeviceId(pjrt_device_id)));
@@ -260,12 +258,10 @@
   xla::PjRtBuffer* src_device_buffer =
       tensorflow::AsyncValueTensor::FromTensor(input)->GetBuffer().get();
 
-  // The device id should match the local_hardware_id in
+  // The device id should match the local_device_id in
   // tensorflow/compiler/xla/pjrt/pjrt_client.h.
   const int pjrt_dst_device_id =
-      tsl::GetDeviceIdFromDeviceParsedName(dst->parsed_name(),
-                                           DeviceType(dst->device_type()))
-          .value();
+      tsl::GetDeviceIdFromDeviceParsedName(dst->parsed_name());
   xla::PjRtDevice* pjrt_dst_device =
       (*pjrt_dst_client)
           ->LookupAddressableDevice(xla::PjRtLocalDeviceId(pjrt_dst_device_id))
diff --git a/tensorflow/compiler/jit/shape_inference_test.cc b/tensorflow/compiler/jit/shape_inference_test.cc
index 3f96101..eaabf18 100644
--- a/tensorflow/compiler/jit/shape_inference_test.cc
+++ b/tensorflow/compiler/jit/shape_inference_test.cc
@@ -29,6 +29,7 @@
 #include "tensorflow/cc/ops/resource_variable_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/compiler/jit/test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -36,7 +37,6 @@
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc
index 62da04c..bec124f 100644
--- a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc
+++ b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc
@@ -26,8 +26,8 @@
 #include "xla/stream_executor/platform_manager.h"
 #include "xla/stream_executor/stream.h"
 #include "xla/stream_executor/stream_executor.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index ebeeaef..27a8f16 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -853,9 +853,8 @@
                                           ->use_pjrt_tensor_buffer;
 
   const DeviceType& device_type = GetDeviceType(ctx);
-  TF_ASSIGN_OR_RETURN(const int pjrt_device_id,
-                      tsl::GetDeviceIdFromDeviceParsedName(
-                          ctx->device()->parsed_name(), device_type));
+  const int pjrt_device_id =
+      tsl::GetDeviceIdFromDeviceParsedName(ctx->device()->parsed_name());
   TF_ASSIGN_OR_RETURN(xla::PjRtDevice * device,
                       pjrt_client->LookupAddressableDevice(
                           xla::PjRtLocalDeviceId(pjrt_device_id)));
diff --git a/tensorflow/compiler/jit/xla_launch_util_gpu_test.cc b/tensorflow/compiler/jit/xla_launch_util_gpu_test.cc
index 0ba66c2..563e75c 100644
--- a/tensorflow/compiler/jit/xla_launch_util_gpu_test.cc
+++ b/tensorflow/compiler/jit/xla_launch_util_gpu_test.cc
@@ -39,6 +39,7 @@
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/tests/literal_test_util.h"
 #include "xla/tsl/framework/allocator.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/device.h"
 #include "tensorflow/core/framework/fake_input.h"
@@ -50,7 +51,6 @@
 #include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/tfrt/common/create_pjrt_client_util.h"
 #include "tensorflow/core/tfrt/common/pjrt_util.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc
index d19e4fc..443fdf3 100644
--- a/tensorflow/compiler/jit/xla_launch_util_test.cc
+++ b/tensorflow/compiler/jit/xla_launch_util_test.cc
@@ -34,6 +34,7 @@
 #include "xla/tests/literal_test_util.h"
 #include "xla/tsl/framework/allocator.h"
 #include "xla/tsl/framework/device_id_utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/device.h"
 #include "tensorflow/core/framework/fake_input.h"
@@ -45,7 +46,6 @@
 #include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/tfrt/common/create_pjrt_client_util.h"
 #include "tensorflow/core/tfrt/common/pjrt_util.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
 
@@ -674,9 +674,8 @@
                                           ->tensorflow_accelerator_device_info()
                                           ->use_pjrt_tensor_buffer;
   const DeviceType& device_type = GetDeviceType(context_.get());
-  TF_ASSERT_OK_AND_ASSIGN(const int pjrt_device_id,
-                          tsl::GetDeviceIdFromDeviceParsedName(
-                              context_->device()->parsed_name(), device_type));
+  const int pjrt_device_id =
+      tsl::GetDeviceIdFromDeviceParsedName(context_->device()->parsed_name());
   TF_ASSERT_OK_AND_ASSIGN(xla::PjRtDevice * pjrt_device,
                           pjrt_client_->LookupAddressableDevice(
                               xla::PjRtLocalDeviceId(pjrt_device_id)));
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index a9b1e1b..75ca89d 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -318,6 +318,13 @@
     ],
 )
 
+cc_library(
+    name = "stateful_error_reporter",
+    hdrs = ["stateful_error_reporter.h"],
+    compatible_with = get_compatible_with_portable(),
+    deps = ["//tensorflow/compiler/mlir/lite/core/api:error_reporter"],
+)
+
 gentbl_cc_library(
     name = "tensorflow_lite_canonicalize_inc_gen",
     compatible_with = get_compatible_with_portable(),
@@ -1095,8 +1102,8 @@
     name = "flatbuffer_to_string",
     srcs = ["flatbuffer_to_string.cc"],
     deps = [
+        "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
         "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_reflection",
-        "//tensorflow/lite/core:model_builder",
         "@flatbuffers",
     ],
 )
@@ -1146,7 +1153,6 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:portable_gif_internal",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/lite/core:framework",
         "//tensorflow/lite/toco:toco_flags_proto_cc",
         "//tensorflow/lite/tools/versioning",
         "//tensorflow/lite/tools/versioning:gpu_compatibility",
@@ -1189,8 +1195,10 @@
         ":size_utils",
         ":tensorflow_lite",
         "//tensorflow/compiler/mlir/lite:control_edges",
+        "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
         "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util",
         "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
+        "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
         "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable",
         "//tensorflow/compiler/mlir/lite/schema:schema_utils",
         "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom",
@@ -1204,8 +1212,6 @@
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:errors",
-        "//tensorflow/core/platform:status",
-        "//tensorflow/lite:model_builder",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/status",
@@ -1357,7 +1363,6 @@
         "//tensorflow/core:core_cpu_base",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:errors",
-        "//tensorflow/lite:framework",
         "//tensorflow/lite/toco:toco_flags_proto_cc",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
diff --git a/tensorflow/compiler/mlir/lite/core/BUILD b/tensorflow/compiler/mlir/lite/core/BUILD
index 04e37c6..4816761 100644
--- a/tensorflow/compiler/mlir/lite/core/BUILD
+++ b/tensorflow/compiler/mlir/lite/core/BUILD
@@ -1,13 +1,59 @@
 load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
+load("//tensorflow/compiler/mlir/lite:build_def.bzl", "tflite_copts_warnings")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
     licenses = ["notice"],
 )
 
+exports_files(
+    [
+        "model_builder_base.h",
+    ],
+    visibility = ["//tensorflow/lite/core:__pkg__"],
+)
+
 cc_library(
     name = "macros",
     hdrs = ["macros.h"],
     compatible_with = get_compatible_with_portable(),
     visibility = ["//visibility:public"],
 )
+
+cc_library(
+    name = "model_builder_base",
+    srcs = ["model_builder_base.cc"],
+    hdrs = ["model_builder_base.h"],
+    compatible_with = get_compatible_with_portable(),
+    copts = tflite_copts_warnings(),
+    visibility = [
+        "//tensorflow/compiler/mlir/lite:__subpackages__",
+        "//tensorflow/lite/core:__pkg__",
+    ],
+    deps = [
+        ":macros",
+        "//tensorflow/compiler/mlir/lite/core/api:error_reporter",
+        "//tensorflow/compiler/mlir/lite/core/api:verifier",
+        "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
+        "//tensorflow/lite:allocation",
+        "@com_google_absl//absl/strings",
+        "@flatbuffers",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "absl_error_model_builder",
+    srcs = ["absl_error_model_builder.cc"],
+    hdrs = ["absl_error_model_builder.h"],
+    compatible_with = get_compatible_with_portable(),
+    copts = tflite_copts_warnings(),
+    visibility = [
+        "//tensorflow/compiler/mlir/lite:__subpackages__",
+    ],
+    deps = [
+        ":model_builder_base",
+        "//tensorflow/compiler/mlir/lite/core/api:error_reporter",
+        "@com_google_absl//absl/log",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc
new file mode 100644
index 0000000..269d81e
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc
@@ -0,0 +1,40 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
+
+#include <cstdarg>
+#include <cstdio>
+
+#include "absl/log/log.h"
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+
+namespace mlir::TFL {
+
+int AbslErrorReporter::Report(const char* format, va_list args) {
+  char buffer[1024];
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wformat-nonliteral"
+  vsnprintf(buffer, sizeof(buffer), format, args);
+#pragma clang diagnostic pop
+  LOG(ERROR) << buffer;
+  return 0;
+}
+
+tflite::ErrorReporter* GetAbslErrorReporter() {
+  static AbslErrorReporter* error_reporter = new AbslErrorReporter;
+  return error_reporter;
+}
+
+}  // namespace mlir::TFL
diff --git a/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h
new file mode 100644
index 0000000..c3d76e2
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h
@@ -0,0 +1,47 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_ABSL_ERROR_MODEL_BUILDER_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_ABSL_ERROR_MODEL_BUILDER_H_
+
+#include <cstdarg>
+
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+#include "tensorflow/compiler/mlir/lite/core/model_builder_base.h"
+
+namespace mlir::TFL {
+
+// An error reporter that uses absl logging.
+class AbslErrorReporter : public tflite::ErrorReporter {
+  int Report(const char* format, va_list args) override;
+};
+
+tflite::ErrorReporter* GetAbslErrorReporter();
+
+class FlatBufferModelAbslError
+    : public tflite::impl::FlatBufferModelBase<FlatBufferModelAbslError> {
+ public:
+  // Use the absl-logging error reporter as the default error reporter.
+  static tflite::ErrorReporter* GetDefaultErrorReporter() {
+    return GetAbslErrorReporter();
+  }
+
+  // Inherit all constructors from FlatBufferModelBase since inherited factory
+  // methods refer to them.
+  using FlatBufferModelBase<FlatBufferModelAbslError>::FlatBufferModelBase;
+};
+
+}  // namespace mlir::TFL
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_CORE_ABSL_ERROR_MODEL_BUILDER_H_
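A minimal usage sketch for the new model builder (hypothetical file path; the factory methods come from FlatBufferModelBase, added later in this change, and errors are routed to absl logging through GetAbslErrorReporter()):

    std::unique_ptr<mlir::TFL::FlatBufferModelAbslError> model =
        mlir::TFL::FlatBufferModelAbslError::BuildFromFile("model.tflite");
    if (model == nullptr) {
      // Failure details were already emitted via LOG(ERROR).
    }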
diff --git a/tensorflow/compiler/mlir/lite/core/api/BUILD b/tensorflow/compiler/mlir/lite/core/api/BUILD
new file mode 100644
index 0000000..cc2d519
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/api/BUILD
@@ -0,0 +1,44 @@
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
+load("//tensorflow/compiler/mlir/lite:build_def.bzl", "tflite_copts")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = [
+        "//tensorflow/compiler/mlir/lite:__subpackages__",
+        "//tensorflow/lite:__subpackages__",
+    ],
+    licenses = ["notice"],
+)
+
+exports_files(["error_reporter.h"])
+
+cc_library(
+    name = "error_reporter",
+    srcs = ["error_reporter.cc"],
+    hdrs = ["error_reporter.h"],
+    compatible_with = get_compatible_with_portable(),
+    copts = tflite_copts(),
+    deps = [],
+)
+
+exports_files(["verifier.h"])
+
+cc_library(
+    name = "verifier",
+    hdrs = ["verifier.h"],
+    compatible_with = get_compatible_with_portable(),
+    copts = tflite_copts(),
+    visibility = ["//visibility:public"],
+    deps = [":error_reporter"],
+)
+
+tf_cc_test(
+    name = "error_reporter_test",
+    size = "small",
+    srcs = ["error_reporter_test.cc"],
+    deps = [
+        ":error_reporter",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/lite/core/api/error_reporter.cc b/tensorflow/compiler/mlir/lite/core/api/error_reporter.cc
new file mode 100644
index 0000000..96f7561
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/api/error_reporter.cc
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+
+#include <cstdarg>
+
+namespace tflite {
+
+int ErrorReporter::Report(const char* format, ...) {
+  va_list args;
+  va_start(args, format);
+  int code = Report(format, args);
+  va_end(args);
+  return code;
+}
+
+// TODO(aselle): Make the name of ReportError on context the same, so
+// we can use the ensure functions w/o a context and w/ a reporter.
+int ErrorReporter::ReportError(void*, const char* format, ...) {
+  va_list args;
+  va_start(args, format);
+  int code = Report(format, args);
+  va_end(args);
+  return code;
+}
+
+}  // namespace tflite
diff --git a/tensorflow/compiler/mlir/lite/core/api/error_reporter.h b/tensorflow/compiler/mlir/lite/core/api/error_reporter.h
new file mode 100644
index 0000000..79c9fc9
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/api/error_reporter.h
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_
+
+#include <cstdarg>
+
+namespace tflite {
+
+/// A functor that reports errors to a supporting system. Invoked similarly to
+/// printf.
+///
+/// Usage:
+///  ErrorReporter foo;
+///  foo.Report("test %d", 5);
+/// or
+///  va_list args;
+///  foo.Report("test %d", args); // where args is va_list
+///
+/// Subclass ErrorReporter to provide another reporting destination.
+/// For example, if you have a GUI program, you might redirect to a buffer
+/// that drives a GUI error log box.
+class ErrorReporter {
+ public:
+  virtual ~ErrorReporter() = default;
+  /// Converts `args` to character equivalents according to the `format`
+  /// string, constructs the error string, and reports it.
+  /// Returns the number of characters written or zero on success, and a
+  /// negative number on error.
+  virtual int Report(const char* format, va_list args) = 0;
+
+  /// Converts arguments to character equivalents according to the `format`
+  /// string, constructs the error string, and reports it.
+  /// Returns the number of characters written or zero on success, and a
+  /// negative number on error.
+  int Report(const char* format, ...);
+
+  /// Equivalent to `Report` above. The additional `void*` parameter is unused.
+  /// This method is for compatibility with macros that take `TfLiteContext`,
+  /// like TF_LITE_ENSURE and related macros.
+  int ReportError(void*, const char* format, ...);
+};
+
+}  // namespace tflite
+
+// You should not make bare calls to the error reporter; instead, use the
+// TF_LITE_REPORT_ERROR macro, since this allows message strings to be
+// stripped when the binary size has to be optimized. If you are looking to
+// reduce binary size, define TF_LITE_STRIP_ERROR_STRINGS when compiling and
+// every call will be stubbed out, taking no memory.
+#ifndef TF_LITE_STRIP_ERROR_STRINGS
+#define TF_LITE_REPORT_ERROR(reporter, ...)                               \
+  do {                                                                    \
+    static_cast<::tflite::ErrorReporter*>(reporter)->Report(__VA_ARGS__); \
+  } while (false)
+#else  // TF_LITE_STRIP_ERROR_STRINGS
+#define TF_LITE_REPORT_ERROR(reporter, ...)
+#endif  // TF_LITE_STRIP_ERROR_STRINGS
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_
diff --git a/tensorflow/compiler/mlir/lite/core/api/error_reporter_test.cc b/tensorflow/compiler/mlir/lite/core/api/error_reporter_test.cc
new file mode 100644
index 0000000..ca7c4a2
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/api/error_reporter_test.cc
@@ -0,0 +1,61 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+
+#include <cstdio>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+  MockErrorReporter() { buffer_[0] = 0; }
+  int Report(const char* format, va_list args) override {
+    vsnprintf(buffer_, kBufferSize, format, args);
+    return 0;
+  }
+  char* GetBuffer() { return buffer_; }
+
+ private:
+  static constexpr int kBufferSize = 256;
+  char buffer_[kBufferSize];
+};
+
+TEST(ErrorReporter, TestReport) {
+  MockErrorReporter mock_reporter;
+  ErrorReporter* reporter = &mock_reporter;
+  reporter->Report("Error: %d", 23);
+  EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
+}
+
+TEST(ErrorReporter, TestReportMacro) {
+  MockErrorReporter mock_reporter;
+  // Only define the reporter if it's used, to avoid warnings.
+#ifndef TF_LITE_STRIP_ERROR_STRINGS
+  ErrorReporter* reporter = &mock_reporter;
+#endif  // TF_LITE_STRIP_ERROR_STRINGS
+
+  TF_LITE_REPORT_ERROR(reporter, "Error: %d", 23);
+
+#ifndef TF_LITE_STRIP_ERROR_STRINGS
+  EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
+#else   // TF_LITE_STRIP_ERROR_STRINGS
+  EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), ""));
+#endif  // TF_LITE_STRIP_ERROR_STRINGS
+}
+
+}  // namespace tflite
diff --git a/tensorflow/compiler/mlir/lite/core/api/verifier.h b/tensorflow/compiler/mlir/lite/core/api/verifier.h
new file mode 100644
index 0000000..2e24347
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/api/verifier.h
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// \file
+///
+/// Abstract interface for verifying a model.
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_
+
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+
+namespace tflite {
+
+/// Abstract interface that verifies whether a given model is legit.
+/// It supports the use case of verifying and building a model without loading
+/// it twice.
+/// (See also "tensorflow/lite/tools/verifier.h".)
+class TfLiteVerifier {
+ public:
+  /// Returns true if the model is legit.
+  virtual bool Verify(const char* data, int length,
+                      ErrorReporter* reporter) = 0;
+  virtual ~TfLiteVerifier() {}
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_
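As a rough sketch, an extra verifier plugged into this interface might look like the following (hypothetical class; a real verifier would inspect the flatbuffer contents rather than just the buffer size):

    class NonEmptyVerifier : public tflite::TfLiteVerifier {
     public:
      bool Verify(const char* data, int length,
                  tflite::ErrorReporter* reporter) override {
        if (data == nullptr || length <= 0) {
          TF_LITE_REPORT_ERROR(reporter, "Model buffer is empty");
          return false;
        }
        return true;
      }
    };

An instance of such a class can be passed as the extra_verifier argument of the VerifyAndBuildFrom* factory methods defined in model_builder_base.h below.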
diff --git a/tensorflow/compiler/mlir/lite/core/model_builder_base.cc b/tensorflow/compiler/mlir/lite/core/model_builder_base.cc
new file mode 100644
index 0000000..1537b90
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/model_builder_base.cc
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/core/model_builder_base.h"
+
+#include <stddef.h>
+
+#include <memory>
+#include <string>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/allocation.h"
+
+namespace tflite {
+
+#ifndef TFLITE_MCU
+// Loads a model from `filename`. Uses mmap when supported by the platform,
+// otherwise makes a copy of the model in a buffer.
+std::unique_ptr<Allocation> GetAllocationFromFile(
+    const char* filename, ErrorReporter* error_reporter) {
+  std::unique_ptr<Allocation> allocation;
+  if (MMAPAllocation::IsSupported()) {
+    allocation = std::make_unique<MMAPAllocation>(filename, error_reporter);
+  } else {
+    allocation = std::make_unique<FileCopyAllocation>(filename, error_reporter);
+  }
+  return allocation;
+}
+
+// Loads a model from the file descriptor `fd`. Uses mmap when supported,
+// otherwise makes a copy of the model in a buffer.
+std::unique_ptr<Allocation> GetAllocationFromFile(
+    int fd, ErrorReporter* error_reporter) {
+  std::unique_ptr<Allocation> allocation;
+  if (MMAPAllocation::IsSupported()) {
+    allocation = std::make_unique<MMAPAllocation>(fd, error_reporter);
+  } else {
+    allocation = std::make_unique<FileCopyAllocation>(
+        absl::StrCat("/proc/self/fd/", fd).c_str(), error_reporter);
+  }
+  return allocation;
+}
+
+#endif
+
+}  // namespace tflite
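A brief usage sketch of the allocation helpers (hypothetical path; the Allocation accessors used here are the ones exercised by VerifyAndBuildFromAllocation in the header below):

    tflite::ErrorReporter* reporter = mlir::TFL::GetAbslErrorReporter();
    std::unique_ptr<tflite::Allocation> allocation =
        tflite::GetAllocationFromFile("model.tflite", reporter);
    if (allocation != nullptr && allocation->valid()) {
      // allocation->base() and allocation->bytes() expose the raw flatbuffer.
    }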
diff --git a/tensorflow/compiler/mlir/lite/core/model_builder_base.h b/tensorflow/compiler/mlir/lite/core/model_builder_base.h
new file mode 100644
index 0000000..4e40e61
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/core/model_builder_base.h
@@ -0,0 +1,614 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// \file
+///
+/// Deserialization infrastructure for tflite. Provides functionality
+/// to go from a serialized tflite model in flatbuffer format to an
+/// in-memory representation of the model.
+///
+/// WARNING: Users of TensorFlow Lite should not include this file directly,
+/// but should instead include "third_party/tensorflow/lite/model_builder.h".
+/// Only the TensorFlow Lite implementation itself should include this
+/// file directly.
+// IWYU pragma: private, include "third_party/tensorflow/lite/model_builder.h"
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_MODEL_BUILDER_BASE_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_MODEL_BUILDER_BASE_H_
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "flatbuffers/base.h"  // from @flatbuffers
+#include "flatbuffers/buffer.h"  // from @flatbuffers
+#include "flatbuffers/vector.h"  // from @flatbuffers
+#include "flatbuffers/verifier.h"  // from @flatbuffers
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+#include "tensorflow/compiler/mlir/lite/core/api/verifier.h"
+#include "tensorflow/compiler/mlir/lite/core/macros.h"
+#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
+#include "tensorflow/lite/allocation.h"
+
+namespace tflite {
+
+std::unique_ptr<Allocation> GetAllocationFromFile(
+    const char* filename, ErrorReporter* error_reporter);
+
+std::unique_ptr<Allocation> GetAllocationFromFile(
+    int fd, ErrorReporter* error_reporter);
+
+/// An RAII object that represents a read-only tflite model, copied from disk,
+/// or mmapped. This uses flatbuffers as the serialization format.
+///
+/// NOTE: The current API requires that a FlatBufferModelBase instance be kept
+/// alive by the client as long as it is in use by any dependent Interpreter
+/// instances. As the FlatBufferModelBase instance is effectively immutable
+/// after creation, the client may safely use a single model with multiple
+/// dependent Interpreter instances, even across multiple threads (though note
+/// that each Interpreter instance is *not* thread-safe).
+///
+/// <pre><code>
+/// using namespace tflite;
+/// StderrReporter error_reporter;
+/// auto model = FlatBufferModelBase::BuildFromFile("interesting_model.tflite",
+///                                             &error_reporter);
+/// MyOpResolver resolver;  // You need to subclass OpResolver to provide
+///                         // implementations.
+/// InterpreterBuilder builder(*model, resolver);
+/// std::unique_ptr<Interpreter> interpreter;
+/// if(builder(&interpreter) == kTfLiteOk) {
+///   .. run model inference with interpreter
+/// }
+/// </code></pre>
+///
+/// OpResolver must be defined to provide your kernel implementations to the
+/// interpreter. This is environment specific and may consist of just the
+/// builtin ops, or some custom operators you defined to extend tflite.
+namespace impl {
+
+template <typename T>
+class FlatBufferModelBase {
+ public:
+  /// Builds a model based on a file.
+  /// Caller retains ownership of `error_reporter` and must ensure its lifetime
+  /// is longer than the FlatBufferModelBase instance.
+  /// Returns a nullptr in case of failure.
+  static std::unique_ptr<T> BuildFromFile(
+      const char* filename,
+      ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+    error_reporter = ValidateErrorReporter(error_reporter);
+    std::unique_ptr<T> model = BuildFromAllocation(
+        GetAllocationFromFile(filename, error_reporter), error_reporter);
+#if FLATBUFFERS_LITTLEENDIAN == 1
+    return model;
+#else
+    return ByteConvertModel(std::move(model), error_reporter);
+#endif
+  }
+
+  /// Verifies whether the content of the file is legit, then builds a model
+  /// based on the file.
+  /// The extra_verifier argument is an additional optional verifier for the
+  /// file contents. By default, we always check with tflite::VerifyModelBuffer.
+  /// If extra_verifier is supplied, the file contents are also checked against
+  /// the extra_verifier after the check against tflite::VerifyModelBuffer.
+  /// Caller retains ownership of `error_reporter` and must ensure its lifetime
+  /// is longer than the FlatBufferModelBase instance.
+  /// Returns a nullptr in case of failure.
+  static std::unique_ptr<T> VerifyAndBuildFromFile(
+      const char* filename, TfLiteVerifier* extra_verifier = nullptr,
+      ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+    error_reporter = ValidateErrorReporter(error_reporter);
+    std::unique_ptr<T> model = VerifyAndBuildFromAllocation(
+        GetAllocationFromFile(filename, error_reporter), extra_verifier,
+        error_reporter);
+#if FLATBUFFERS_LITTLEENDIAN == 1
+    return model;
+#else
+    return ByteConvertModel(std::move(model), error_reporter);
+#endif
+  }
+
+  /// Builds a model based on a file descriptor.
+  /// Caller retains ownership of `error_reporter` and must ensure its lifetime
+  /// is longer than the FlatBufferModelBase instance. Caller retains ownership
+  /// of `fd` and must ensure it is closed after BuildFromFileDescriptor
+  /// returns. Returns a nullptr in case of failure.
+  static std::unique_ptr<T> BuildFromFileDescriptor(
+      int fd, ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+    error_reporter = ValidateErrorReporter(error_reporter);
+    std::unique_ptr<T> model = BuildFromAllocation(
+        GetAllocationFromFile(fd, error_reporter), error_reporter);
+#if FLATBUFFERS_LITTLEENDIAN == 1
+    return model;
+#else
+    return ByteConvertModel(std::move(model), error_reporter);
+#endif
+  }
+
+  /// Verifies whether the content of the file descriptor is legit, then builds
+  /// a model based on the file.
+  /// The extra_verifier argument is an additional optional verifier for the
+  /// file contents. By default, we always check with tflite::VerifyModelBuffer.
+  /// If extra_verifier is supplied, the file contents are also checked against
+  /// the extra_verifier after the check against tflite::VerifyModelBuffer.
+  /// Caller retains ownership of `error_reporter` and must ensure its lifetime
+  /// is longer than the FlatBufferModelBase instance.
+  /// Returns a nullptr in case of failure.
+  static std::unique_ptr<T> VerifyAndBuildFromFileDescriptor(
+      int fd, TfLiteVerifier* extra_verifier = nullptr,
+      ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+    error_reporter = ValidateErrorReporter(error_reporter);
+    std::unique_ptr<T> model = VerifyAndBuildFromAllocation(
+        GetAllocationFromFile(fd, error_reporter), extra_verifier,
+        error_reporter);
+#if FLATBUFFERS_LITTLEENDIAN == 1
+    return model;
+#else
+    return ByteConvertModel(std::move(model), error_reporter);
+#endif
+  }
+
+  /// Builds a model based on a pre-loaded flatbuffer.
+  /// Caller retains ownership of the buffer and should keep it alive until
+  /// the returned object is destroyed. Caller also retains ownership of
+  /// `error_reporter` and must ensure its lifetime is longer than the
+  /// FlatBufferModelBase instance.
+  /// Returns a nullptr in case of failure.
+  /// NOTE: this does NOT validate the buffer so it should NOT be called on
+  /// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case.
+  static std::unique_ptr<T> BuildFromBuffer(
+      const char* caller_owned_buffer, size_t buffer_size,
+      ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+    error_reporter = ValidateErrorReporter(error_reporter);
+    std::unique_ptr<Allocation> allocation(
+        new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
+    return BuildFromAllocation(std::move(allocation), error_reporter);
+  }
+
+  /// Verifies whether the content of the buffer is legit, then builds a model
+  /// based on the pre-loaded flatbuffer.
+  /// The extra_verifier argument is an additional optional verifier for the
+  /// buffer. By default, we always check with tflite::VerifyModelBuffer. If
+  /// extra_verifier is supplied, the buffer is checked against the
+  /// extra_verifier after the check against tflite::VerifyModelBuffer. The
+  /// caller retains ownership of the buffer and should keep it alive until the
+  /// returned object is destroyed. Caller retains ownership of `error_reporter`
+  /// and must ensure its lifetime is longer than the FlatBufferModelBase
+  /// instance. Returns a nullptr in case of failure.
+  static std::unique_ptr<T> VerifyAndBuildFromBuffer(
+      const char* caller_owned_buffer, size_t buffer_size,
+      TfLiteVerifier* extra_verifier = nullptr,
+      ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+    error_reporter = ValidateErrorReporter(error_reporter);
+    std::unique_ptr<Allocation> allocation(
+        new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
+    return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier,
+                                        error_reporter);
+  }
+
+#if FLATBUFFERS_LITTLEENDIAN == 0
+
+  void ByteSwapSerializedModel(std::string* serialized_model,
+                               bool from_big_endian) {
+    const uint8_t* buffer =
+        reinterpret_cast<const uint8_t*>(serialized_model->c_str());
+    const tflite::Model* input_model = tflite::GetModel(buffer);
+    ByteSwapTFLiteModel(input_model, from_big_endian);
+  }
+
+  void ByteSwapBuffer(int8_t tensor_type, size_t buffer_size, uint8_t* buffer,
+                      bool from_big_endian) {
+    switch (tensor_type) {
+      case tflite::TensorType_STRING: {
+        auto bp = reinterpret_cast<int32_t*>(buffer);
+        int num_of_strings =
+            from_big_endian ? bp[0] : flatbuffers::EndianSwap(bp[0]);
+        for (int i = 0; i < num_of_strings + 2; i++)
+          bp[i] = flatbuffers::EndianSwap(bp[i]);
+        break;
+      }
+      // 16-bit types
+      case tflite::TensorType_FLOAT16:
+      case tflite::TensorType_INT16:
+      case tflite::TensorType_UINT16: {
+        auto bp = reinterpret_cast<uint16_t*>(buffer);
+        for (int i = 0; i < buffer_size / 2; i++)
+          bp[i] = flatbuffers::EndianSwap(bp[i]);
+        break;
+      }
+      // 32-bit types
+      case tflite::TensorType_FLOAT32:
+      case tflite::TensorType_INT32:
+      case tflite::TensorType_UINT32:
+      case tflite::TensorType_COMPLEX64: {
+        auto bp = reinterpret_cast<uint32_t*>(buffer);
+        for (int i = 0; i < buffer_size / 4; i++)
+          bp[i] = flatbuffers::EndianSwap(bp[i]);
+        break;
+      }
+      // 64-bit types
+      case tflite::TensorType_INT64:
+      case tflite::TensorType_FLOAT64:
+      case tflite::TensorType_UINT64:
+      case tflite::TensorType_COMPLEX128: {
+        auto bp = reinterpret_cast<uint64_t*>(buffer);
+        for (int i = 0; i < buffer_size / 8; i++)
+          bp[i] = flatbuffers::EndianSwap(bp[i]);
+        break;
+      }
+      default:
+        break;
+    }
+  }
+
+  void ByteSwapTFLiteModel(const tflite::Model* tfl_model,
+                           bool from_big_endian) {
+    bool buffer_swapped[tfl_model->buffers()->size()] = {};
+    for (size_t subgraph_idx = 0; subgraph_idx < tfl_model->subgraphs()->size();
+         subgraph_idx++) {
+      const tflite::SubGraph* subgraph =
+          tfl_model->subgraphs()->Get(subgraph_idx);
+      for (size_t ts_idx = 0; ts_idx < subgraph->tensors()->size(); ts_idx++) {
+        const tflite::Tensor* tensor = subgraph->tensors()->Get(ts_idx);
+        if (tensor->buffer() > 0 &&
+            tensor->buffer() < tfl_model->buffers()->size() &&
+            !buffer_swapped[tensor->buffer()]) {
+          const tflite::Buffer* buffer_ =
+              (*tfl_model->buffers())[tensor->buffer()];
+          if (!buffer_ || !buffer_->data()) continue;
+          auto* buffer = buffer_->data();
+          uint8_t* buff_ = const_cast<uint8_t*>(buffer->data());
+          ByteSwapBuffer(tensor->type(), buffer->size(), buff_,
+                         from_big_endian);
+          buffer_swapped[tensor->buffer()] = true;
+        }
+      }
+    }
+  }
+
+  std::unique_ptr<T> ByteConvertModel(std::unique_ptr<T> model,
+                                      ErrorReporter* error_reporter,
+                                      bool from_big_endian) {
+    if (model == nullptr) return model;
+    auto tfl_model = model->GetModel();
+    if (tfl_model->subgraphs()->size() == 0) return model;
+    if (tfl_model->subgraphs()->Get(0)->tensors()->size() == 0) return model;
+    if (tfl_model->buffers()->size() < 2) return model;
+    return ByteSwapFlatBufferModelBase<T>(std::move(model), error_reporter,
+                                          from_big_endian);
+  }
+
+  std::unique_ptr<T> ByteSwapFlatBufferModelBase(std::unique_ptr<T> model,
+                                                 ErrorReporter* error_reporter,
+                                                 bool from_big_endian) {
+    FlatBufferModelBase<T>* modelp = model.release();
+    auto tflite_model = modelp->GetModel();
+    auto copied_model = std::make_unique<tflite::ModelT>();
+    tflite_model->UnPackTo(copied_model.get(), nullptr);
+    ByteSwapTFLiteModelT(copied_model.get(), from_big_endian);
+    std::unique_ptr<flatbuffers::FlatBufferBuilder> builder(
+        new flatbuffers::FlatBufferBuilder());
+    auto packed_model = tflite::Model::Pack(*builder, copied_model.get());
+    tflite::FinishModelBuffer(*builder, packed_model);
+    flatbuffers::FlatBufferBuilder* builder_ = builder.release();
+    return BuildFromBuffer(
+        reinterpret_cast<const char*>(builder_->GetBufferPointer()),
+        builder_->GetSize(), error_reporter);
+  }
+
+  void ByteSwapTFLiteModelT(tflite::ModelT* tfl_modelt, bool from_big_endian) {
+    size_t bytes_per_elem = 0;
+    bool buffer_swapped[tfl_modelt->buffers.size()] = {};
+    for (size_t subgraph_idx = 0; subgraph_idx < tfl_modelt->subgraphs.size();
+         subgraph_idx++) {
+      tflite::SubGraphT* subgraph =
+          tfl_modelt->subgraphs.at(subgraph_idx).get();
+      for (size_t ts_idx = 0; ts_idx < subgraph->tensors.size(); ts_idx++) {
+        tflite::TensorT* tensor = subgraph->tensors[ts_idx].get();
+        if (tensor->buffer > 0 && tensor->buffer < tfl_modelt->buffers.size() &&
+            !buffer_swapped[tensor->buffer]) {
+          const auto* buffer =
+              &(tfl_modelt->buffers[tensor->buffer].get()->data);
+          if (buffer && buffer->data()) {
+            uint8_t* buff_ = const_cast<uint8_t*>(buffer->data());
+            ByteSwapBuffer(tensor->type, buffer->size(), buff_,
+                           from_big_endian);
+            buffer_swapped[tensor->buffer] = true;
+          }
+        }
+      }
+    }
+  }
+
+#endif
+
+  /// Builds a model directly from an allocation.
+  /// Ownership of the allocation is passed to the model, but the caller
+  /// retains ownership of `error_reporter` and must ensure its lifetime is
+  /// longer than the FlatBufferModelBase instance.
+  /// Returns a nullptr in case of failure (e.g., the allocation is invalid).
+  static std::unique_ptr<T> BuildFromAllocation(
+      std::unique_ptr<Allocation> allocation,
+      ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+    std::unique_ptr<T> model(
+        new T(std::move(allocation), ValidateErrorReporter(error_reporter)));
+    if (!model->initialized()) {
+      model.reset();
+    } else {
+      model->ValidateModelBuffers(error_reporter);
+    }
+    return model;
+  }
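+
+  // Illustrative usage sketch (assumes a concrete subclass T such as
+  // FlatBufferModel and a caller-created Allocation; names are placeholders):
+  //   std::unique_ptr<Allocation> alloc = /* e.g. a memory-mapped allocation */;
+  //   std::unique_ptr<T> model = T::BuildFromAllocation(std::move(alloc));
+  //   if (model == nullptr) { /* the allocation was invalid */ }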
+
+  /// Verifies that the content of the allocation is a valid model, then builds
+  /// a model based on the provided allocation.
+  /// The extra_verifier argument is an additional optional verifier for the
+  /// buffer. By default, we always check with tflite::VerifyModelBuffer. If
+  /// extra_verifier is supplied, the buffer is also checked against the
+  /// extra_verifier after the check against tflite::VerifyModelBuffer.
+  /// Ownership of the allocation is passed to the model, but the caller
+  /// retains ownership of `error_reporter` and must ensure its lifetime is
+  /// longer than the FlatBufferModelBase instance.
+  /// Returns a nullptr in case of failure.
+  static std::unique_ptr<T> VerifyAndBuildFromAllocation(
+      std::unique_ptr<Allocation> allocation,
+      TfLiteVerifier* extra_verifier = nullptr,
+      ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+    error_reporter = ValidateErrorReporter(error_reporter);
+    if (!allocation || !allocation->valid()) {
+      TF_LITE_REPORT_ERROR(error_reporter,
+                           "The model allocation is null/empty");
+      return nullptr;
+    }
+
+    {
+      // Flatbuffers must be smaller than 2GB. The file format appends some
+      // data after the actual flatbuffer. We truncate the allocation size to
+      // 2GB so that the verifier doesn't exit early on us.
+      size_t allocation_size =
+          std::min(allocation->bytes(),
+                   static_cast<size_t>(FLATBUFFERS_MAX_BUFFER_SIZE - 1));
+      flatbuffers::Verifier base_verifier(
+          reinterpret_cast<const uint8_t*>(allocation->base()),
+          allocation_size);
+      if (!VerifyModelBuffer(base_verifier)) {
+        TF_LITE_REPORT_ERROR(error_reporter,
+                             "The model is not a valid Flatbuffer buffer");
+        return nullptr;
+      }
+
+      if (extra_verifier &&
+          !extra_verifier->Verify(static_cast<const char*>(allocation->base()),
+                                  allocation_size, error_reporter)) {
+        // The verifier will have already logged an appropriate error message.
+        return nullptr;
+      }
+    }
+
+    return BuildFromAllocation(std::move(allocation), error_reporter);
+  }
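+
+  // Illustrative usage sketch (placeholder names; no extra verifier supplied,
+  // so only tflite::VerifyModelBuffer runs):
+  //   auto model = T::VerifyAndBuildFromAllocation(std::move(alloc));
+  //   if (model == nullptr) { /* invalid allocation or verification failed */ }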
+
+  /// Builds a model directly from a flatbuffer pointer.
+  /// Caller retains ownership of the buffer and should keep it alive until the
+  /// returned object is destroyed. Caller retains ownership of `error_reporter`
+  /// and must ensure its lifetime is longer than the FlatBufferModelBase
+  /// instance. Returns a nullptr in case of failure.
+  static std::unique_ptr<T> BuildFromModel(
+      const tflite::Model* caller_owned_model_spec,
+      ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
+    error_reporter = ValidateErrorReporter(error_reporter);
+
+    if (CheckBufferOutsideModel(caller_owned_model_spec)) {
+      TF_LITE_REPORT_ERROR(error_reporter,
+                           "The model contains weights not accessible from "
+                           "tflite::Model *, please use other api");
+      return nullptr;
+    }
+
+    std::unique_ptr<T> model(new T(caller_owned_model_spec, error_reporter));
+    if (!model->initialized()) {
+      model.reset();
+    } else {
+      model->ValidateModelBuffers(error_reporter);
+    }
+    return model;
+  }
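+
+  // Illustrative usage sketch (assumes the caller already holds a
+  // tflite::Model* that outlives the returned object):
+  //   const tflite::Model* spec = ::tflite::GetModel(buffer_start);  // hypothetical buffer
+  //   std::unique_ptr<T> model = T::BuildFromModel(spec);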
+
+  // Releases memory or unmaps mmapped memory.
+  ~FlatBufferModelBase() = default;
+
+  // Copying or assignment is disallowed to simplify ownership semantics.
+  FlatBufferModelBase(const FlatBufferModelBase&) = delete;
+  FlatBufferModelBase& operator=(const FlatBufferModelBase&) = delete;
+
+  bool initialized() const { return model_ != nullptr; }
+  const tflite::Model* operator->() const { return model_; }
+  const tflite::Model* GetModel() const { return model_; }
+  ErrorReporter* error_reporter() const { return error_reporter_; }
+  const Allocation* allocation() const { return allocation_.get(); }
+
+  // Returns the minimum runtime version from the flatbuffer. This runtime
+  // version encodes the minimum required interpreter version to run the
+  // flatbuffer model. If the minimum version can't be determined, an empty
+  // string will be returned.
+  // Note that the returned minimum version is a lower-bound but not a strict
+  // lower-bound; ops in the graph may not have an associated runtime version,
+  // in which case the actual required runtime might be greater than the
+  // reported minimum.
+  std::string GetMinimumRuntime() const {
+    if (!model_ || !model_->metadata()) return "";
+
+    for (int i = 0; i < model_->metadata()->size(); ++i) {
+      auto metadata = model_->metadata()->Get(i);
+      if (metadata->name()->str() == tflite_metadata_min_runtime_version) {
+        auto buf = metadata->buffer();
+        auto* buffer = (*model_->buffers())[buf];
+        auto* array = buffer->data();
+        // Get the real length of the runtime string, since there might be
+        // trailing '\0's in the buffer.
+        for (int len = 0; len < array->size(); ++len) {
+          if (array->data()[len] == '\0') {
+            return std::string(reinterpret_cast<const char*>(array->data()),
+                               len);
+          }
+        }
+        // If there is no '\0' in the buffer, this indicates that the flatbuffer
+        // is malformed.
+        TF_LITE_REPORT_ERROR(
+            error_reporter_,
+            "Min_runtime_version in model metadata is malformed");
+        break;
+      }
+    }
+    return "";
+  }
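+
+  // Illustrative usage sketch (hypothetical `model` built via one of the
+  // factories above):
+  //   std::string min_runtime = model->GetMinimumRuntime();
+  //   if (min_runtime.empty()) { /* no minimum runtime version recorded */ }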
+
+  // Return model metadata as a mapping of name & buffer strings.
+  // See Metadata table in TFLite schema.
+  std::map<std::string, std::string> ReadAllMetadata() const {
+    return ReadAllMetadata(model_);
+  }
+
+  // Return model metadata as a mapping of name & buffer strings.
+  // See Metadata table in TFLite schema.
+  static std::map<std::string, std::string> ReadAllMetadata(
+      const ::tflite::Model* model) {
+    std::map<std::string, std::string> keys_values;
+    if (!model || !model->metadata() || !model->buffers()) return keys_values;
+
+    for (int i = 0; i < model->metadata()->size(); ++i) {
+      auto metadata = model->metadata()->Get(i);
+      auto buf = metadata->buffer();
+      if (buf >= model->buffers()->size()) continue;
+      const tflite::Buffer* buffer = (*model->buffers())[buf];
+      if (!buffer || !buffer->data()) continue;
+      const flatbuffers::Vector<uint8_t>* array = buffer->data();
+      if (!array) continue;
+      std::string val = std::string(
+          reinterpret_cast<const char*>(array->data()), array->size());
+      // Skip if key or value of metadata is empty.
+      if (!metadata->name() || val.empty()) continue;
+      keys_values[metadata->name()->str()] = val;
+    }
+    return keys_values;
+  }
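+
+  // Illustrative usage sketch (hypothetical `model` built via one of the
+  // factories above):
+  //   auto metadata = model->ReadAllMetadata();
+  //   auto it = metadata.find(tflite_metadata_min_runtime_version);
+  //   if (it != metadata.end()) { /* use it->second */ }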
+
+  // Validates if the FlatBufferModelBase's buffer is well-formed. Specifically,
+  // it checks if the 0th entry of the model buffers is an empty buffer
+  // (sentinel). This is a convention so that tensors without a buffer can
+  // provide 0 as their buffer. NOTE: The function doesn't explicitly fail for
+  // backward compatibility reasons; it just provides a warning in case of
+  // failures.
+  void ValidateModelBuffers(ErrorReporter* error_reporter) {
+    auto buffers = model_->buffers();
+    if (buffers && buffers->size() > 0) {
+      auto first_buffer = buffers->Get(0);
+      if (first_buffer && first_buffer->size() != 0) {
+        // Note the 0th entry of this array must be an empty buffer (sentinel).
+        // This is a convention so that tensors without a buffer can provide 0
+        // as their buffer.
+        TF_LITE_REPORT_ERROR(
+            error_reporter,
+            "The 0th entry of the model buffer must be an empty buffer.");
+      }
+    }
+  }
+
+  /// Returns true if the model identifier is correct (otherwise false and
+  /// reports an error).
+  bool CheckModelIdentifier() const {
+    if (allocation_->bytes() < 7) {
+      TF_LITE_REPORT_ERROR(
+          error_reporter_,
+          "Model provided must have at least 7 bytes to hold identifier.\n");
+      return false;
+    }
+    if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
+      const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
+      // Suppress unused variable warning.
+      (void)ident;
+      TF_LITE_REPORT_ERROR(
+          error_reporter_,
+          "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
+          ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
+      return false;
+    }
+    return true;
+  }
+
+  /// Checks whether the buffers are stored as part of the flatbuffer or
+  /// outside of it. Returns false if the buffers are part of the flatbuffer.
+  static bool CheckBufferOutsideModel(const tflite::Model* model) {
+    if (!model || !model->metadata()) return false;
+
+    for (int i = 0; i < model->metadata()->size(); ++i) {
+      auto metadata = model->metadata()->Get(i);
+      if (metadata->name()->str() == tflite_metadata_buffer_location) {
+        return true;
+      }
+    }
+    return false;
+  }
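+
+  // Illustrative note: BuildFromModel above relies on this check to reject
+  // models whose weight buffers live outside the flatbuffer, e.g.
+  //   if (CheckBufferOutsideModel(caller_owned_model_spec)) { /* reject */ }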
+
+ protected:
+  /// Loads a model from a given allocation. FlatBufferModelBase will take over
+  /// the ownership of `allocation` and delete it in the destructor. Ownership
+  /// of `error_reporter` remains with the caller, which must keep it alive at
+  /// least as long as the FlatBufferModelBase. This allows multiple models to
+  /// share the same ErrorReporter instance.
+  explicit FlatBufferModelBase(
+      std::unique_ptr<Allocation> allocation,
+      ErrorReporter* error_reporter = T::GetDefaultErrorReporter())
+      : error_reporter_(ValidateErrorReporter(error_reporter)),
+        allocation_(std::move(allocation)) {
+    if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) {
+      return;
+    }
+
+    model_ = ::tflite::GetModel(allocation_->base());
+  }
+
+  /// Loads a model from a Model flatbuffer. The `model` has to remain alive and
+  /// unchanged until the end of this flatbuffer model's lifetime.
+  FlatBufferModelBase(const Model* model, ErrorReporter* error_reporter)
+      : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
+
+  static ErrorReporter* ValidateErrorReporter(ErrorReporter* error_reporter) {
+    return error_reporter ? error_reporter : T::GetDefaultErrorReporter();
+  }
+
+  /// Flatbuffer traverser pointer. (Model* points into the memory owned by the
+  /// allocation's internals.)
+  const tflite::Model* model_ = nullptr;
+  /// The error reporter to use for model errors and subsequent errors when
+  /// the interpreter is created.
+  ErrorReporter* error_reporter_;
+  /// The allocator used for holding memory of the model. Note that this will
+  /// be null if the client provides a tflite::Model directly.
+  std::unique_ptr<Allocation> allocation_;
+};
+
+}  // namespace impl
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_CORE_MODEL_BUILDER_BASE_H_
diff --git a/tensorflow/compiler/mlir/lite/debug/BUILD b/tensorflow/compiler/mlir/lite/debug/BUILD
index 2516cb4..5c179eb 100644
--- a/tensorflow/compiler/mlir/lite/debug/BUILD
+++ b/tensorflow/compiler/mlir/lite/debug/BUILD
@@ -59,8 +59,8 @@
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:path",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
diff --git a/tensorflow/compiler/mlir/lite/debug/debug_test.cc b/tensorflow/compiler/mlir/lite/debug/debug_test.cc
index 371e318..5d1ed84 100644
--- a/tensorflow/compiler/mlir/lite/debug/debug_test.cc
+++ b/tensorflow/compiler/mlir/lite/debug/debug_test.cc
@@ -46,8 +46,8 @@
 #include "mlir/Support/TypeID.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/path.h"
 
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
index 6164222..269f5cd 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
@@ -113,7 +113,6 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/tstring.h"
-#include "tensorflow/lite/core/interpreter.h"
 #include "tensorflow/lite/toco/toco_flags.pb.h"
 #include "tensorflow/lite/tools/versioning/gpu_compatibility.h"
 #include "tensorflow/lite/tools/versioning/op_version.h"
@@ -159,6 +158,11 @@
 
 ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
 
+// LINT.IfChange(optional_tensor)
+// Taken from third_party/tensorflow/lite/core/c/common.h
+constexpr int kTfLiteMigrationOptionalTensor = -1;
+// LINT.ThenChange(//tensorflow/lite/core/c/common.h:optional_tensor)
+
 // Use initial buffer size in flatbuffer builder to be same as the initial size
 // used by the TOCO export. (It does not explain rationale for this choice.)
 constexpr size_t kInitialBufferSize = 10240;
@@ -3024,7 +3028,7 @@
     operands.reserve(real_inst->getNumOperands());
     for (auto operand : real_inst->getOperands()) {
       if (mlir::isa<NoneType>(operand.getType()))
-        operands.push_back(kTfLiteOptionalTensor);
+        operands.push_back(kTfLiteMigrationOptionalTensor);
       else if (auto stats_op =
                    llvm::dyn_cast_or_null<mlir::quantfork::StatisticsOp>(
                        operand.getDefiningOp()))
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index a03e988..a289126 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -44,6 +44,7 @@
 #include "llvm/Analysis/AssumeBundleQueries.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"  // from @llvm-project
 #include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"  // from @llvm-project
@@ -70,16 +71,17 @@
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
-#include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Tools/mlir-translate/Translation.h"  // from @llvm-project
 #include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo
 #include "stablehlo/dialect/VhloOps.h"  // from @stablehlo
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
 #include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/offset_buffer.h"
 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
 #include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h"
+#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h"
@@ -97,8 +99,6 @@
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/status.h"
-#include "tensorflow/lite/model_builder.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
 
@@ -625,7 +625,7 @@
     const std::vector<std::string>& func_names,
     const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
     OpBuilder builder,
-    const std::unique_ptr<tflite::FlatBufferModel>& model_ptr) {
+    const std::unique_ptr<tfl::FlatBufferModelAbslError>& model_ptr) {
   llvm::SmallVector<Value, 4> operands;
   llvm::SmallVector<mlir::Type, 2> outputTypes;
 
@@ -1116,7 +1116,7 @@
     bool experimental_prune_unreachable_nodes_unconditionally,
     const tflite::SignatureDefT* signature,
     const tflite::ControlEdges& control_edges,
-    const std::unique_ptr<tflite::FlatBufferModel>& model_ptr,
+    const std::unique_ptr<tfl::FlatBufferModelAbslError>& model_ptr,
     bool use_stablehlo_constant) {
   // Populate from metadata.
   ControlNodes control_nodes;
@@ -1518,8 +1518,8 @@
       mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect,
       mlir::stablehlo::StablehloDialect, mlir::vhlo::VhloDialect>();
 
-  auto model_ptr =
-      FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
+  auto model_ptr = tfl::FlatBufferModelAbslError::VerifyAndBuildFromBuffer(
+      buffer.data(), buffer.length());
   if (nullptr == model_ptr) {
     return emitError(base_loc, "couldn't parse flatbuffer"), nullptr;
   }
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc b/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc
index df28f50..b393a88 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc
@@ -27,7 +27,7 @@
 #include "flatbuffers/minireflect.h"  // from @flatbuffers
 #include "tensorflow/compiler/mlir/lite/schema/reflection/schema_generated.h"
 #if FLATBUFFERS_LITTLEENDIAN == 0
-#include "tensorflow/lite/core/model_builder.h"
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
 #endif
 
 namespace tflite {
@@ -144,7 +144,8 @@
   // If the flatbuffer model comes from stdin, convert its tensor content from
   // BE to LE to ensure the output text string is the same as on LE platforms.
   if (std::string(argv[1]) == "-")
-    tflite::FlatBufferModel::ByteSwapSerializedModel(&serialized_model, true);
+    mlir::TFL::FlatBufferModelAbslError::ByteSwapSerializedModel(
+        &serialized_model, true);
 #endif
   tflite::ToString(serialized_model);
   return 0;
diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD
index 6d2de49..299bb9e 100644
--- a/tensorflow/compiler/mlir/lite/python/BUILD
+++ b/tensorflow/compiler/mlir/lite/python/BUILD
@@ -195,8 +195,10 @@
         ":saved_model_to_tfl_flatbuffer",
         "//tensorflow/c:kernels",
         "//tensorflow/c:tf_status_headers",
+        "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
         "//tensorflow/compiler/mlir/lite/debug:debug_options_proto_cc",
         "//tensorflow/compiler/mlir/lite/metrics:error_collector",
+        "//tensorflow/compiler/mlir/lite/python/interpreter_wrapper:python_error_reporter",
         "//tensorflow/compiler/mlir/lite/python/interpreter_wrapper:python_utils",
         "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model",
         "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
@@ -205,8 +207,6 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/lite:model_builder",
-        "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter",
         "//tensorflow/lite/toco:model",
         "//tensorflow/lite/toco:model_flags_proto_cc",
         "//tensorflow/lite/toco:toco_convert",
diff --git a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc
index 31ec151..351846c 100644
--- a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc
+++ b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc
@@ -31,10 +31,12 @@
 #include "google/protobuf/text_format.h"
 #include "tensorflow/c/kernels.h"
 #include "tensorflow/c/tf_status.h"
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
 #include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h"
 #include "tensorflow/compiler/mlir/lite/metrics/error_collector.h"
 #include "tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.h"
 #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
+#include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h"
 #include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_utils.h"
 #include "tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h"
 #include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
@@ -46,8 +48,6 @@
 #include "tensorflow/core/framework/op_def.pb.h"
 #include "tensorflow/core/framework/op_def_builder.h"
 #include "tensorflow/core/platform/status.h"
-#include "tensorflow/lite/model_builder.h"
-#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
 #include "tensorflow/lite/toco/logging/conversion_log_util.h"
 #include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h"
 #include "tensorflow/lite/toco/model.h"
@@ -309,7 +309,7 @@
                             bool enable_variable_quantization,
                             bool disable_per_channel_for_dense_layers,
                             PyObject* debug_options_proto_txt_raw) {
-  using tflite::interpreter_wrapper::PythonErrorReporter;
+  using tflite_migration::interpreter_wrapper::PythonErrorReporter;
   char* buf = nullptr;
   Py_ssize_t length;
   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
@@ -362,9 +362,9 @@
     return nullptr;
   }
 
-  std::unique_ptr<tflite::FlatBufferModel> model =
-      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
-                                               error_reporter.get());
+  std::unique_ptr<mlir::TFL::FlatBufferModelAbslError> model =
+      mlir::TFL::FlatBufferModelAbslError::BuildFromBuffer(
+          buf, length, error_reporter.get());
   if (!model) {
     PyErr_Format(PyExc_ValueError, "Invalid model");
     return nullptr;
@@ -399,7 +399,7 @@
 }
 
 PyObject* MlirSparsifyModel(PyObject* data) {
-  using tflite::interpreter_wrapper::PythonErrorReporter;
+  using tflite_migration::interpreter_wrapper::PythonErrorReporter;
   char* buf = nullptr;
   Py_ssize_t length;
   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
@@ -408,9 +408,9 @@
     PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
     return nullptr;
   }
-  std::unique_ptr<tflite::FlatBufferModel> model =
-      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
-                                               error_reporter.get());
+  std::unique_ptr<mlir::TFL::FlatBufferModelAbslError> model =
+      mlir::TFL::FlatBufferModelAbslError::BuildFromBuffer(
+          buf, length, error_reporter.get());
   if (!model) {
     PyErr_Format(PyExc_ValueError, "Invalid model");
     return nullptr;
diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD
index 8d2cb7a..9268de7 100644
--- a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD
+++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD
@@ -15,3 +15,14 @@
         "//third_party/python_runtime:headers",  # buildcleaner: keep
     ],
 )
+
+cc_library(
+    name = "python_error_reporter",
+    srcs = ["python_error_reporter.cc"],
+    hdrs = ["python_error_reporter.h"],
+    compatible_with = get_compatible_with_portable(),
+    deps = [
+        "//tensorflow/compiler/mlir/lite:stateful_error_reporter",
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+    ],
+)
diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc
new file mode 100644
index 0000000..75f9222
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc
@@ -0,0 +1,47 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h"
+
+#include <cstdarg>
+#include <cstdio>
+#include <string>
+
+namespace tflite_migration {
+namespace interpreter_wrapper {
+
+// Report an error message
+int PythonErrorReporter::Report(const char* format, va_list args) {
+  char buf[1024];
+  int formatted = vsnprintf(buf, sizeof(buf), format, args);
+  buffer_ << buf;
+  return formatted;
+}
+
+// Sets a Python runtime exception with the last error.
+PyObject* PythonErrorReporter::exception() {
+  std::string last_message = message();
+  PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
+  return nullptr;
+}
+
+// Gets the last error message and clears the buffer.
+std::string PythonErrorReporter::message() {
+  std::string value = buffer_.str();
+  // str("") resets the accumulated contents; clear() alone only resets the
+  // stream's error state flags.
+  buffer_.str("");
+  buffer_.clear();
+  return value;
+}
+}  // namespace interpreter_wrapper
+}  // namespace tflite_migration
diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h
new file mode 100644
index 0000000..f98a352
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h
@@ -0,0 +1,50 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
+
+#include <Python.h>
+
+#include <cstdarg>
+#include <sstream>
+#include <string>
+
+#include "tensorflow/compiler/mlir/lite/stateful_error_reporter.h"
+
+namespace tflite_migration {
+namespace interpreter_wrapper {
+
+class PythonErrorReporter : public tflite_migration::StatefulErrorReporter {
+ public:
+  PythonErrorReporter() = default;
+
+  // Report an error message
+  int Report(const char* format, va_list args) override;
+
+  // Sets a Python runtime exception with the last error and
+  // clears the error message buffer.
+  PyObject* exception();
+
+  // Gets the last error message and clears the buffer.
+  std::string message() override;
+
+ private:
+  std::stringstream buffer_;
+};
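+
+// Illustrative usage sketch (placeholder `buf`/`length`):
+//   auto reporter = std::make_unique<PythonErrorReporter>();
+//   auto model = mlir::TFL::FlatBufferModelAbslError::BuildFromBuffer(
+//       buf, length, reporter.get());
+//   if (model == nullptr) return reporter->exception();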
+
+}  // namespace interpreter_wrapper
+}  // namespace tflite_migration
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
index e57c2b3..c269b41 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
@@ -168,16 +168,16 @@
     deps = [
         ":quantize_model",
         ":test_util",
+        "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
         "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
         "//tensorflow/compiler/mlir/lite/schema:schema_utils",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
-        "//tensorflow/lite:framework",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest",
         "@flatbuffers",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -205,11 +205,11 @@
     deps = [
         ":quantize_weights",
         ":test_util",
+        "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
         "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
         "//tensorflow/compiler/mlir/lite/schema:schema_utils",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
-        "//tensorflow/lite:framework",
         "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest",
         "@flatbuffers",
@@ -223,9 +223,7 @@
     srcs = ["test_util.cc"],
     hdrs = ["test_util.h"],
     deps = [
-        "//tensorflow/lite:framework",
-        "//tensorflow/lite/core/api",
+        "//tensorflow/compiler/mlir/lite/core/api:error_reporter",
         "@com_google_googletest//:gtest",
-        "@flatbuffers",
     ],
 )
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc
index 1e7cdcd..371f452 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc
@@ -30,15 +30,15 @@
 #include "absl/status/status.h"
 #include "flatbuffers/flatbuffer_builder.h"  // from @flatbuffers
 #include "flatbuffers/vector.h"  // from @flatbuffers
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
 #include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/command_line_flags.h"
-#include "tensorflow/lite/model_builder.h"
-#include "tsl/lib/core/status_test_util.h"
 
 // Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc
 
@@ -50,6 +50,7 @@
 namespace optimize {
 namespace {
 
+using mlir::TFL::FlatBufferModelAbslError;
 using testing::Eq;
 using testing::FloatEq;
 using testing::FloatNear;
@@ -100,7 +101,7 @@
     return status;
   }
 
-  auto flatbuffer_model = FlatBufferModel::BuildFromBuffer(
+  auto flatbuffer_model = FlatBufferModelAbslError::BuildFromBuffer(
       output_buffer.data(), output_buffer.size());
   *model = UnPackFlatBufferModel(*flatbuffer_model->GetModel());
   return absl::OkStatus();
@@ -157,9 +158,10 @@
                        disable_per_channel_for_dense_layers);
 }
 
-std::unique_ptr<FlatBufferModel> ReadModel(const std::string& model_name) {
+std::unique_ptr<FlatBufferModelAbslError> ReadModel(
+    const std::string& model_name) {
   auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, model_name);
-  return FlatBufferModel::BuildFromFile(model_path.c_str());
+  return FlatBufferModelAbslError::BuildFromFile(model_path.c_str());
 }
 
 template <typename T>
@@ -198,7 +200,7 @@
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 
-  std::unique_ptr<FlatBufferModel> input_model_;
+  std::unique_ptr<FlatBufferModelAbslError> input_model_;
   const Model* readonly_model_;
   tflite::ModelT model_;
   std::string output_buffer_;  // Raw buffer for quantized output model.
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc
index 7a42e74..db124c8 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc
@@ -27,6 +27,7 @@
 #include "flatbuffers/buffer.h"  // from @flatbuffers
 #include "flatbuffers/flatbuffer_builder.h"  // from @flatbuffers
 #include "flatbuffers/vector.h"  // from @flatbuffers
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
 #include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
@@ -34,7 +35,6 @@
 #include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/command_line_flags.h"
-#include "tensorflow/lite/model_builder.h"
 #include "tsl/platform/logging.h"
 
 // Note: branched from tensorflow/lite/tools/optimize/quantize_weights_test.cc
@@ -50,6 +50,7 @@
 using mlir::lite::BufferType;
 using mlir::lite::CustomOpMap;
 using mlir::lite::QuantizeWeights;
+using mlir::TFL::FlatBufferModelAbslError;
 constexpr bool kUseUpdatedHybridSchemeDefault = true;
 
 std::unique_ptr<ModelT> CreateMutableModelFromFile(const Model* input_model) {
@@ -58,28 +59,28 @@
   return copied_model;
 }
 
-std::unique_ptr<FlatBufferModel> ReadTestModel() {
+std::unique_ptr<FlatBufferModelAbslError> ReadTestModel() {
   auto model_path = tensorflow::io::JoinPath(
       *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights);
-  return FlatBufferModel::BuildFromFile(model_path.c_str());
+  return FlatBufferModelAbslError::BuildFromFile(model_path.c_str());
 }
 
-std::unique_ptr<FlatBufferModel> ReadSharedWeightsTestModel() {
+std::unique_ptr<FlatBufferModelAbslError> ReadSharedWeightsTestModel() {
   auto model_path = tensorflow::io::JoinPath(
       *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights);
-  return FlatBufferModel::BuildFromFile(model_path.c_str());
+  return FlatBufferModelAbslError::BuildFromFile(model_path.c_str());
 }
 
-std::unique_ptr<FlatBufferModel> ReadGatherTestModel() {
+std::unique_ptr<FlatBufferModelAbslError> ReadGatherTestModel() {
   auto model_path = tensorflow::io::JoinPath(
       *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather);
-  return FlatBufferModel::BuildFromFile(model_path.c_str());
+  return FlatBufferModelAbslError::BuildFromFile(model_path.c_str());
 }
 
-std::unique_ptr<FlatBufferModel> ReadCustomOpTestModel() {
+std::unique_ptr<FlatBufferModelAbslError> ReadCustomOpTestModel() {
   auto model_path = tensorflow::io::JoinPath(
       *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp);
-  return FlatBufferModel::BuildFromFile(model_path.c_str());
+  return FlatBufferModelAbslError::BuildFromFile(model_path.c_str());
 }
 
 template <typename T>
@@ -111,7 +112,7 @@
     model_ = input_model_->GetModel();
   }
 
-  std::unique_ptr<FlatBufferModel> input_model_;
+  std::unique_ptr<FlatBufferModelAbslError> input_model_;
   const Model* model_;
 
   bool IsModelInputOrOutput(const Model* model, uint32_t tensor_idx) {
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc
index e096868..66c1ade 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc
@@ -14,6 +14,9 @@
 ==============================================================================*/
 #include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
 
+#include <cstdarg>
+#include <cstdio>
+
 #include <gtest/gtest.h>
 
 namespace mlir {
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h
index b4e317c1..8953a38 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h
@@ -15,7 +15,9 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_
 
-#include "tensorflow/lite/core/api/error_reporter.h"
+#include <cstdarg>
+
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
 
 namespace mlir {
 namespace lite {
diff --git a/tensorflow/compiler/mlir/lite/schema/BUILD b/tensorflow/compiler/mlir/lite/schema/BUILD
index 7dd8eec..14e80f4 100644
--- a/tensorflow/compiler/mlir/lite/schema/BUILD
+++ b/tensorflow/compiler/mlir/lite/schema/BUILD
@@ -11,7 +11,15 @@
 )
 
 exports_files(
-    srcs = ["schema.fbs"],
+    srcs = [
+        "schema.fbs",
+        "schema_v0.fbs",
+        "schema_v1.fbs",
+        "schema_v2.fbs",
+        "schema_v3.fbs",
+        "schema_v3a.fbs",
+        "schema_v3b.fbs",
+    ],
 )
 
 filegroup(
diff --git a/tensorflow/lite/schema/schema_v0.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v0.fbs
similarity index 100%
rename from tensorflow/lite/schema/schema_v0.fbs
rename to tensorflow/compiler/mlir/lite/schema/schema_v0.fbs
diff --git a/tensorflow/lite/schema/schema_v1.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v1.fbs
similarity index 100%
rename from tensorflow/lite/schema/schema_v1.fbs
rename to tensorflow/compiler/mlir/lite/schema/schema_v1.fbs
diff --git a/tensorflow/lite/schema/schema_v2.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v2.fbs
similarity index 100%
rename from tensorflow/lite/schema/schema_v2.fbs
rename to tensorflow/compiler/mlir/lite/schema/schema_v2.fbs
diff --git a/tensorflow/lite/schema/schema_v3.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v3.fbs
similarity index 100%
rename from tensorflow/lite/schema/schema_v3.fbs
rename to tensorflow/compiler/mlir/lite/schema/schema_v3.fbs
diff --git a/tensorflow/lite/schema/schema_v3a.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v3a.fbs
similarity index 100%
rename from tensorflow/lite/schema/schema_v3a.fbs
rename to tensorflow/compiler/mlir/lite/schema/schema_v3a.fbs
diff --git a/tensorflow/lite/schema/schema_v3c.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v3c.fbs
similarity index 100%
rename from tensorflow/lite/schema/schema_v3c.fbs
rename to tensorflow/compiler/mlir/lite/schema/schema_v3c.fbs
diff --git a/tensorflow/compiler/mlir/lite/sparsity/BUILD b/tensorflow/compiler/mlir/lite/sparsity/BUILD
index 92e6aaa..7c15ac2 100644
--- a/tensorflow/compiler/mlir/lite/sparsity/BUILD
+++ b/tensorflow/compiler/mlir/lite/sparsity/BUILD
@@ -54,9 +54,9 @@
     ],
     deps = [
         ":sparsify_model",
+        "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
         "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
         "//tensorflow/compiler/mlir/lite/tools/optimize:reduced_precision_metadata",
-        "//tensorflow/lite/core:model_builder",
         "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
         "@flatbuffers",
diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc
index 0d1339d..cc557b5 100644
--- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc
+++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc
@@ -27,9 +27,9 @@
 #include <gtest/gtest.h>
 #include "absl/status/status.h"
 #include "flatbuffers/flatbuffer_builder.h"  // from @flatbuffers
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
 #include "tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h"
-#include "tensorflow/lite/core/model_builder.h"
 
 namespace mlir {
 namespace lite {
@@ -41,7 +41,7 @@
   std::string expected_value = "test_data";
 
   // Load input model
-  auto input_fbm = tflite::FlatBufferModel::BuildFromFile(
+  auto input_fbm = mlir::TFL::FlatBufferModelAbslError::BuildFromFile(
       "tensorflow/compiler/mlir/lite/sparsity/testdata/"
       "sparse_tensor.bin");
   tflite::ModelT input_model;
@@ -60,7 +60,7 @@
   // Sparsify and create output model
   flatbuffers::FlatBufferBuilder output_builder;
   ASSERT_TRUE(SparsifyModel(input_model, &output_builder).ok());
-  auto output_fbm = tflite::FlatBufferModel::BuildFromBuffer(
+  auto output_fbm = mlir::TFL::FlatBufferModelAbslError::BuildFromBuffer(
       reinterpret_cast<const char*>(output_builder.GetCurrentBufferPointer()),
       output_builder.GetSize());
   tflite::ModelT output_model;
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD
index 2ba51a7..56c3d89 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD
+++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD
@@ -643,8 +643,12 @@
         "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:gather",
         "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:pad",
         "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:reduce",
+        "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:reduce_window",
+        "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:slice",
+        "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:sort",
         "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:util",
         "//tensorflow/compiler/mlir/tensorflow",
+        "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
@@ -691,7 +695,8 @@
         "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:conv",
         "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:conv_util",
         "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:pad_util",
-        "@llvm-project//llvm:Support",
+        "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:reduce_window",
+        "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:slice",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir
index 6e69c40..db09eca 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir
+++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir
@@ -164,6 +164,31 @@
 // CHECK-SAME: [0, 0], [0, 0]
 // CHECK-SAME: (tensor<1x256x256x3xf32>, tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32>
 
+// -----
+
+// CHECK-LABEL: conv2d_nhwc_ohwi_nhwc_asymmetric_padded
+func.func @conv2d_nhwc_ohwi_nhwc_asymmetric_padded(%input: tensor<1x255x255x3xf32>, %filter: tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> {
+  %0 = "mhlo.convolution"(%input, %filter) {
+    dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f]>,
+    batch_group_count = 1 : i64,
+    feature_group_count = 1 : i64,
+    window_strides = dense<1> : tensor<2xi64>,
+    padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
+    rhs_dilation = dense<[1, 1]> : tensor<2xi64>,
+    lhs_dilation = dense<[1, 1]> : tensor<2xi64>
+  } : (tensor<1x255x255x3xf32>, tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32>
+  func.return %0 : tensor<1x256x256x2xf32>
+}
+
+// CHECK:      %[[PADDED_LHS:.*]] = "mhlo.pad"
+// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]>
+// CHECK-SAME: edge_padding_low = dense<0>
+// CHECK-SAME: interior_padding = dense<0>
+// CHECK:      mhlo.convolution(%[[PADDED_LHS]]
+// CHECK-SAME: pad
+// CHECK-SAME: [0, 0], [0, 0]
+// CHECK-SAME: (tensor<1x256x256x3xf32>, tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32>
+
 
 // -----
 
@@ -417,7 +442,6 @@
 // 1D
 //=--
 
-// TODO: b/351437662 - Add support for conv1d.
 // CHECK-LABEL: conv1d_nsc_osi_nsc
 func.func @conv1d_nsc_osi_nsc(%arg0: tensor<16x32x256xf32>, %arg1: tensor<256x1x256xf32>) -> tensor<16x32x256xf32> {
 	%0 = "mhlo.convolution"(%arg0, %arg1) {
@@ -546,3 +570,106 @@
 // CHECK-SAME: edge_padding_low = dense<[0, 1, 0]>
 // CHECK-SAME: (tensor<2x2x3xf32>, tensor<f32>) -> tensor<3x3x3xf32>
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.reduce_window
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: reduce_window_valid_channel_first
+func.func @reduce_window_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
+  // "0xFF800000" represents -INF for f32.
+  %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+  %1 = "mhlo.reduce_window"(%arg0, %0) ({
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
+    mhlo.return %2 : tensor<f32>
+  }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<0> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>,
+    window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<4x3x16x16xf32>, tensor<f32>) -> tensor<4x3x7x7xf32>
+  func.return %1 : tensor<4x3x7x7xf32>
+}
+
+// CHECK:      %[[INIT_CST:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
+// CHECK:      %[[TPOSE_IN:.*]] = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32>
+// CHECK:      %[[RW:.*]] = "mhlo.reduce_window"(%[[TPOSE_IN]], %[[INIT_CST]])
+// CHECK-SAME: window_dimensions = dense<[1, 3, 3, 1]>
+// CHECK-SAME: window_strides = dense<[1, 2, 2, 1]>
+// CHECK:      %3 = "mhlo.transpose"(%[[RW]]) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x7x7x3xf32>) -> tensor<4x3x7x7xf32>
+
+// -----
+
+// CHECK-LABEL: reduce_window_same_channel_first
+func.func @reduce_window_same_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x8x8xf32> {
+  // "0xFF800000" represents -INF for f32.
+  %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+  %1 = "mhlo.reduce_window"(%arg0, %0) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.maximum %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 0], [0, 1], [0, 1]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>,
+    window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<4x3x16x16xf32>, tensor<f32>) -> tensor<4x3x8x8xf32>
+  func.return %1 : tensor<4x3x8x8xf32>
+}
+
+// CHECK:      %[[INIT_CST:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
+// CHECK:      %[[TPOSE_IN:.*]] = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32>
+// CHECK:      %[[RW:.*]] = "mhlo.reduce_window"(%[[TPOSE_IN]], %[[INIT_CST]])
+// CHECK-SAME: padding
+// CHECK-SAME: [0, 0], [0, 1], [0, 1], [0, 0]
+// CHECK-SAME: window_dimensions = dense<[1, 3, 3, 1]>
+// CHECK-SAME: window_strides = dense<[1, 2, 2, 1]>
+// CHECK:      %3 = "mhlo.transpose"(%[[RW]]) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x8x8x3xf32>) -> tensor<4x3x8x8xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.dynamic_slice
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: dynamic_slice
+func.func @dynamic_slice(%arg0: tensor<7x3xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<4x2xf32> {
+  %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<i32>, tensor<i32>) -> tensor<4x2xf32>
+  func.return %0 : tensor<4x2xf32>
+}
+
+// CHECK:      mhlo.dynamic_slice
+// CHECK-SAME: (tensor<7x3xf32>, tensor<i32>, tensor<i32>) -> tensor<4x2xf32>
+
+// -----
+
+// CHECK-LABEL: dynamic_slice_ui32
+func.func @dynamic_slice_ui32(%arg0: tensor<7x3xf32>, %arg1: tensor<ui32>, %arg2: tensor<ui32>) -> tensor<4x2xf32> {
+  %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<ui32>, tensor<ui32>) -> tensor<4x2xf32>
+  func.return %0 : tensor<4x2xf32>
+}
+
+// CHECK:      mhlo.dynamic_slice
+// CHECK-SAME: (tensor<7x3xf32>, tensor<i32>, tensor<i32>) -> tensor<4x2xf32>
+
+// CHECK-LABEL: dynamic_slice_ui64
+func.func @dynamic_slice_ui64(%arg0: tensor<7x3xf32>, %arg1: tensor<ui64>, %arg2: tensor<ui64>) -> tensor<4x2xf32> {
+  %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<ui64>, tensor<ui64>) -> tensor<4x2xf32>
+  func.return %0 : tensor<4x2xf32>
+}
+
+// CHECK:      mhlo.dynamic_slice
+// CHECK-SAME: (tensor<7x3xf32>, tensor<i64>, tensor<i64>) -> tensor<4x2xf32>
+
+// -----
+
+// CHECK-LABEL: dynamic_slice_i64
+func.func @dynamic_slice_i64(%arg0: tensor<7x3xf32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<4x2xf32> {
+  %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<i64>, tensor<i64>) -> tensor<4x2xf32>
+  func.return %0 : tensor<4x2xf32>
+}
+
+// CHECK:      mhlo.dynamic_slice
+// CHECK-SAME: (tensor<7x3xf32>, tensor<i64>, tensor<i64>) -> tensor<4x2xf32>
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir
index 8ca6f44..55f2b65 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir
@@ -72,13 +72,13 @@
   func.return %0 : tensor<3x5x1x4xf32>
 }
 
-// CHECK:     %[[TRANSPOSED_0:.*]] = "tfl.transpose"
-// CHECK:     %[[TRANSPOSED_1:.*]] = "tfl.transpose"
-// CHECK-NEXT: %[[RESHAPED_0:.*]] = mhlo.reshape %[[TRANSPOSED_0]]
-// CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %[[TRANSPOSED_1]]
-// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
-// CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]]
-// CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<3x5x1x4xf32>
+// CHECK: %[[TRANSPOSED_0:.*]] = "tfl.transpose"
+// CHECK: %[[TRANSPOSED_1:.*]] = "tfl.transpose"
+// CHECK: %[[RESHAPED_0:.*]] = "tfl.reshape"(%[[TRANSPOSED_0]]
+// CHECK: %[[RESHAPED_1:.*]] = "tfl.reshape"(%[[TRANSPOSED_1]]
+// CHECK: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
+// CHECK: %[[RESHAPED_BMM:.*]] = "tfl.reshape"(%[[BMM_0]]
+// CHECK: return %[[RESHAPED_BMM]] : tensor<3x5x1x4xf32>
 
 // -----
 
@@ -96,11 +96,10 @@
   func.return %0 : tensor<1x1x1024xf32>
 }
 
-// CHECK:     %[[RESHAPED_0:.*]] = mhlo.reshape %arg0
-// CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %arg1
-// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x1024xf32>
-// CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]]
-// CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<1x1x1024xf32>
+// CHECK:  %[[RESHAPED_0:.*]] = "tfl.reshape"(%arg0
+// CHECK: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x1024xf32>
+// CHECK: %[[RESHAPED_BMM:.*]] = "tfl.reshape"(%[[BMM_0]]
+// CHECK: return %[[RESHAPED_BMM]] : tensor<1x1x1024xf32>
 
 // -----
 
@@ -115,11 +114,10 @@
   func.return %0 : tensor<8xi32>
 }
 
-// CHECK:      %[[RESHAPED_0:.*]] = mhlo.reshape %arg0
-// CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %arg1
-// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x8xi32>
-// CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]]
-// CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<8xi32>
+// CHECK: %[[RESHAPED_0:.*]] = "tfl.reshape"(%arg0
+// CHECK: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x8xi32>
+// CHECK: %[[RESHAPED_BMM:.*]] = "tfl.reshape"(%[[BMM_0]]
+// CHECK: return %[[RESHAPED_BMM]] : tensor<8xi32>
 
 // -----
 
@@ -135,29 +133,30 @@
   func.return %0 : tensor<4x4x?xf32>
 }
 
-// CHECK:      %0 = "tfl.pseudo_const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
-// CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<3xi64>) -> tensor<3xi32>
-// CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x256x?xf32>
-// CHECK-NEXT: %3 = mhlo.reshape %arg0 : (tensor<4x4x256xf32>) -> tensor<4x4x256xf32>
-// CHECK-NEXT: %4 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
-// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK-NEXT: %7 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%4, %5, %7) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %9 = "tfl.unsorted_segment_prod"(%4, %6, %7) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %10 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
-// CHECK-NEXT: %11 = "tfl.concatenation"(%10, %9, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
-// CHECK-NEXT: %12 = mhlo.dynamic_reshape %2, %11 : (tensor<4x256x?xf32>, tensor<3xi32>) -> tensor<4x256x?xf32>
-// CHECK-NEXT: %13 = "tfl.batch_matmul"(%3, %12) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32>
-// CHECK-NEXT: %14 = "tfl.shape"(%arg0) : (tensor<4x4x256xf32>) -> tensor<3xi32>
-// CHECK-NEXT: %15 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
-// CHECK-NEXT: %16 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
-// CHECK-NEXT: %17 = "tfl.gather"(%14, %16) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32>
-// CHECK-NEXT: %18 = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
-// CHECK-NEXT: %19 = "tfl.gather"(%15, %18) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32>
-// CHECK-NEXT: %20 = "tfl.concatenation"(%17, %19) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32>
-// CHECK-NEXT: %21 = mhlo.dynamic_reshape %13, %20 : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32>
-// CHECK-NEXT: return %21 : tensor<4x4x?xf32>
+// CHECK:     %0 = "tfl.pseudo_const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
+// CHECK:     %1 = "tfl.cast"(%0) : (tensor<3xi64>) -> tensor<3xi32>
+// CHECK:     %2 = "tfl.transpose"(%arg1, %1) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x256x?xf32>
+// CHECK:     %3 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
+// CHECK-DAG: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:     %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %9 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK:     %10 = "tfl.concatenation"(%9, %8, %7) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+// CHECK:     %11 = "tfl.cast"(%10) : (tensor<3xi32>) -> tensor<3xi32>
+// CHECK:     %12 = "tfl.reshape"(%2, %11) : (tensor<4x256x?xf32>, tensor<3xi32>) -> tensor<4x256x?xf32>
+// CHECK:     %13 = "tfl.batch_matmul"(%arg0, %12) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32>
+// CHECK:     %14 = "tfl.shape"(%arg0) : (tensor<4x4x256xf32>) -> tensor<3xi32>
+// CHECK:     %15 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
+// CHECK:     %16 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK:     %17 = "tfl.gather"(%14, %16) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32>
+// CHECK:     %18 = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:     %19 = "tfl.gather"(%15, %18) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32>
+// CHECK:     %20 = "tfl.concatenation"(%17, %19) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32>
+// CHECK:     %21 = "tfl.cast"(%20) : (tensor<3xi32>) -> tensor<3xi32>
+// CHECK:     %22 = "tfl.reshape"(%13, %21) : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32>
+// CHECK:     return %22 : tensor<4x4x?xf32>
 
 // -----
 
@@ -173,43 +172,45 @@
   func.return %0 : tensor<2x?x2x4xf32>
 }
 
-// CHECK:      %0 = "tfl.pseudo_const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
-// CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32>
-// CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x?x4x3xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32>
-// CHECK-NEXT: %3 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %4 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %9 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
-// CHECK-NEXT: %10 = "tfl.gather"(%3, %9) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
-// CHECK-NEXT: %11 = "tfl.concatenation"(%10, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
-// CHECK-NEXT: %12 = mhlo.dynamic_reshape %arg0, %11 : (tensor<2x?x2x3xf32>, tensor<4xi32>) -> tensor<2x?x2x3xf32>
-// CHECK-NEXT: %13 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %14 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %16 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %17 = "tfl.unsorted_segment_prod"(%13, %14, %16) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %18 = "tfl.unsorted_segment_prod"(%13, %15, %16) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %19 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
-// CHECK-NEXT: %20 = "tfl.gather"(%13, %19) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
-// CHECK-NEXT: %21 = "tfl.concatenation"(%20, %18, %17) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
-// CHECK-NEXT: %22 = mhlo.dynamic_reshape %2, %21 : (tensor<2x?x3x4xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32>
-// CHECK-NEXT: %23 = "tfl.batch_matmul"(%12, %22) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32>
-// CHECK-NEXT: %24 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %25 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %26 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64>
-// CHECK-NEXT: %27 = "tfl.gather"(%24, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
-// CHECK-NEXT: %28 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
-// CHECK-NEXT: %29 = "tfl.gather"(%25, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32>
-// CHECK-NEXT: %30 = "tfl.concatenation"(%27, %29) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32>
-// CHECK-NEXT: %31 = mhlo.dynamic_reshape %23, %30 : (tensor<2x?x2x4xf32>, tensor<4xi32>) -> tensor<2x?x2x4xf32>
-// CHECK-NEXT: return %31 : tensor<2x?x2x4xf32>
+// CHECK:     %0 = "tfl.pseudo_const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
+// CHECK:     %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32>
+// CHECK:     %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x?x4x3xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32>
+// CHECK:     %3 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32>
+// CHECK-DAG: %4 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:     %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %9 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK:     %10 = "tfl.gather"(%3, %9) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
+// CHECK:     %11 = "tfl.concatenation"(%10, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
+// CHECK:     %12 = "tfl.cast"(%11) : (tensor<4xi32>) -> tensor<4xi32>
+// CHECK:     %13 = "tfl.reshape"(%arg0, %12) : (tensor<2x?x2x3xf32>, tensor<4xi32>) -> tensor<2x?x2x3xf32>
+// CHECK:     %14 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32>
+// CHECK-DAG: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %16 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %17 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:     %18 = "tfl.unsorted_segment_prod"(%14, %15, %17) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %19 = "tfl.unsorted_segment_prod"(%14, %16, %17) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %20 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK:     %21 = "tfl.gather"(%14, %20) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
+// CHECK:     %22 = "tfl.concatenation"(%21, %19, %18) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
+// CHECK:     %23 = "tfl.cast"(%22) : (tensor<4xi32>) -> tensor<4xi32>
+// CHECK:     %24 = "tfl.reshape"(%2, %23) : (tensor<2x?x3x4xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32>
+// CHECK:     %25 = "tfl.batch_matmul"(%13, %24) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32>
+// CHECK:     %26 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32>
+// CHECK:     %27 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32>
+// CHECK:     %28 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64>
+// CHECK:     %29 = "tfl.gather"(%26, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
+// CHECK:     %30 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:     %31 = "tfl.gather"(%27, %30) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32>
+// CHECK:     %32 = "tfl.concatenation"(%29, %31) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32>
+// CHECK:     %33 = "tfl.cast"(%32) : (tensor<4xi32>) -> tensor<4xi32>
+// CHECK:     %34 = "tfl.reshape"(%25, %33) : (tensor<2x?x2x4xf32>, tensor<4xi32>) -> tensor<2x?x2x4xf32>
+// CHECK:     return %34 : tensor<2x?x2x4xf32>
 
 // -----
 
-
 // CHECK-LABEL: dot_general_dynamic_lhs_rhs_out_dims
 func.func @dot_general_dynamic_lhs_rhs_out_dims(%arg0: tensor<2x2x?x3xf32>, %arg1: tensor<2x4x?x3xf32>) -> tensor<2x2x?x4x?xf32> {
   %0 = "mhlo.dot_general"(%arg0, %arg1) {
@@ -222,37 +223,40 @@
   func.return %0 : tensor<2x2x?x4x?xf32>
 }
 
-// CHECK:      %0 = "tfl.pseudo_const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
-// CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32>
-// CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x4x?x3xf32>, tensor<4xi32>) -> tensor<2x3x4x?xf32>
-// CHECK-NEXT: %3 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %9 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
-// CHECK-NEXT: %10 = "tfl.concatenation"(%9, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
-// CHECK-NEXT: %11 = mhlo.dynamic_reshape %arg0, %10 : (tensor<2x2x?x3xf32>, tensor<3xi32>) -> tensor<2x?x3xf32>
-// CHECK-NEXT: %12 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %13 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %14 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %16 = "tfl.unsorted_segment_prod"(%12, %13, %15) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %17 = "tfl.unsorted_segment_prod"(%12, %14, %15) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %18 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
-// CHECK-NEXT: %19 = "tfl.concatenation"(%18, %17, %16) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
-// CHECK-NEXT: %20 = mhlo.dynamic_reshape %2, %19 : (tensor<2x3x4x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32>
-// CHECK-NEXT: %21 = "tfl.batch_matmul"(%11, %20) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32>
-// CHECK-NEXT: %22 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %23 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32>
-// CHECK-NEXT: %24 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64>
-// CHECK-NEXT: %25 = "tfl.gather"(%22, %24) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
-// CHECK-NEXT: %26 = "tfl.pseudo_const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
-// CHECK-NEXT: %27 = "tfl.gather"(%23, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
-// CHECK-NEXT: %28 = "tfl.concatenation"(%25, %27) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32>
-// CHECK-NEXT: %29 = mhlo.dynamic_reshape %21, %28 : (tensor<2x?x?xf32>, tensor<5xi32>) -> tensor<2x2x?x4x?xf32>
-// CHECK-NEXT: return %29 : tensor<2x2x?x4x?xf32>
+// CHECK:     %0 = "tfl.pseudo_const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
+// CHECK:     %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32>
+// CHECK:     %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x4x?x3xf32>, tensor<4xi32>) -> tensor<2x3x4x?xf32>
+// CHECK:     %3 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32>
+// CHECK-DAG: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:     %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %9 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK:     %10 = "tfl.concatenation"(%9, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+// CHECK:     %11 = "tfl.cast"(%10) : (tensor<3xi32>) -> tensor<3xi32>
+// CHECK:     %12 = "tfl.reshape"(%arg0, %11) : (tensor<2x2x?x3xf32>, tensor<3xi32>) -> tensor<2x?x3xf32>
+// CHECK:     %13 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32>
+// CHECK-DAG: %14 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %16 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:     %17 = "tfl.unsorted_segment_prod"(%13, %14, %16) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %18 = "tfl.unsorted_segment_prod"(%13, %15, %16) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %19 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK:     %20 = "tfl.concatenation"(%19, %18, %17) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+// CHECK:     %21 = "tfl.cast"(%20) : (tensor<3xi32>) -> tensor<3xi32>
+// CHECK:     %22 = "tfl.reshape"(%2, %21) : (tensor<2x3x4x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32>
+// CHECK:     %23 = "tfl.batch_matmul"(%12, %22) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32>
+// CHECK:     %24 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32>
+// CHECK:     %25 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32>
+// CHECK:     %26 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64>
+// CHECK:     %27 = "tfl.gather"(%24, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
+// CHECK:     %28 = "tfl.pseudo_const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK:     %29 = "tfl.gather"(%25, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
+// CHECK:     %30 = "tfl.concatenation"(%27, %29) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32>
+// CHECK:     %31 = "tfl.cast"(%30) : (tensor<5xi32>) -> tensor<5xi32>
+// CHECK:     %32 = "tfl.reshape"(%23, %31) : (tensor<2x?x?xf32>, tensor<5xi32>) -> tensor<2x2x?x4x?xf32>
+// CHECK:     return %32 : tensor<2x2x?x4x?xf32>
 
 // -----
 
@@ -268,27 +272,28 @@
   func.return %0 : tensor<4x4x256xf32>
 }
 
-// CHECK:      %0 = "tfl.shape"(%arg0) : (tensor<4x4x?xf32>) -> tensor<3xi32>
-// CHECK-NEXT: %1 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK-NEXT: %2 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK-NEXT: %3 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %4 = "tfl.unsorted_segment_prod"(%0, %1, %3) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %5 = "tfl.unsorted_segment_prod"(%0, %2, %3) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
-// CHECK-NEXT: %7 = "tfl.concatenation"(%6, %4, %5) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
-// CHECK-NEXT: %8 = mhlo.dynamic_reshape %arg0, %7 : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32>
-// CHECK-NEXT: %9 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
-// CHECK-NEXT: %10 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK-NEXT: %11 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK-NEXT: %12 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK-NEXT: %13 = "tfl.unsorted_segment_prod"(%9, %10, %12) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %14 = "tfl.unsorted_segment_prod"(%9, %11, %12) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
-// CHECK-NEXT: %16 = "tfl.concatenation"(%15, %14, %13) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
-// CHECK-NEXT: %17 = mhlo.dynamic_reshape %arg1, %16 : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x?x256xf32>
-// CHECK-NEXT: %18 = "tfl.batch_matmul"(%8, %17) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32>
-// CHECK-NEXT: %19 = mhlo.reshape %18 : (tensor<4x4x256xf32>) -> tensor<4x4x256xf32>
-// CHECK-NEXT: return %19 : tensor<4x4x256xf32>
+// CHECK:     %0 = "tfl.shape"(%arg0) : (tensor<4x4x?xf32>) -> tensor<3xi32>
+// CHECK-DAG: %1 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: %2 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: %3 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:     %4 = "tfl.unsorted_segment_prod"(%0, %1, %3) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %5 = "tfl.unsorted_segment_prod"(%0, %2, %3) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %6 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK:     %7 = "tfl.concatenation"(%6, %4, %5) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+// CHECK:     %8 = "tfl.cast"(%7) : (tensor<3xi32>) -> tensor<3xi32>
+// CHECK:     %9 = "tfl.reshape"(%arg0, %8) : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32>
+// CHECK:     %10 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
+// CHECK-DAG: %11 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: %12 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG: %13 = "tfl.pseudo_const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:     %14 = "tfl.unsorted_segment_prod"(%10, %11, %13) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %15 = "tfl.unsorted_segment_prod"(%10, %12, %13) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
+// CHECK:     %16 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK:     %17 = "tfl.concatenation"(%16, %15, %14) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+// CHECK:     %18 = "tfl.cast"(%17) : (tensor<3xi32>) -> tensor<3xi32>
+// CHECK:     %19 = "tfl.reshape"(%arg1, %18) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x?x256xf32>
+// CHECK:     %20 = "tfl.batch_matmul"(%9, %19) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32>
+// CHECK:     return %20 : tensor<4x4x256xf32>
 
 // -----
 
@@ -568,13 +573,15 @@
 }
 
 // CHECK-DAG: %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
-// CHECK:     %1 = mhlo.constant dense<0> : tensor<i32>
+// CHECK-DAG: %1 = mhlo.constant dense<0> : tensor<i32>
 // CHECK:     %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<32xi32>
-// CHECK:     %3 = mhlo.reshape %2 : (tensor<32xi32>) -> tensor<1x32x1xi32>
-// CHECK:     %cst = arith.constant dense<1> : tensor<1xi32>
-// CHECK:     %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xf32>
-// CHECK:     %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xi32>
-// CHECK:     return %4, %5 : tensor<1x1xf32>, tensor<1x1xi32>
+// CHECK:     %cst = arith.constant dense<[1, 32, 1]> : tensor<3xi64>
+// CHECK:     %3 = "tfl.cast"(%cst) : (tensor<3xi64>) -> tensor<3xi32>
+// CHECK:     %4 = "tfl.reshape"(%2, %3) : (tensor<32xi32>, tensor<3xi32>) -> tensor<1x32x1xi32>
+// CHECK:     %cst_0 = arith.constant dense<1> : tensor<1xi32>
+// CHECK:     %5 = "tfl.reduce_max"(%arg0, %cst_0) <{keep_dims = false}> : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xf32>
+// CHECK:     %6 = "tfl.arg_max"(%arg0, %cst_0) : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xi32>
+// CHECK:     return %5, %6 : tensor<1x1xf32>, tensor<1x1xi32>
 
 // -----
 
@@ -597,14 +604,16 @@
   func.return %4#1 : tensor<1xi32>
 }
 
-// CHECK:     %0 = mhlo.constant dense<0> : tensor<i32>
+// CHECK-DAG: %0 = mhlo.constant dense<0> : tensor<i32>
 // CHECK-DAG: %1 = mhlo.constant dense<-2147483648> : tensor<i32>
 // CHECK:     %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<9xi32>
-// CHECK:     %3 = mhlo.reshape %2 : (tensor<9xi32>) -> tensor<1x9xi32>
-// CHECK:     %cst = arith.constant dense<1> : tensor<1xi32>
-// CHECK:     %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32>
-// CHECK:     %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32>
-// CHECK:     return %5 : tensor<1xi32>
+// CHECK:     %cst = arith.constant dense<[1, 9]> : tensor<2xi64>
+// CHECK:     %3 = "tfl.cast"(%cst) : (tensor<2xi64>) -> tensor<2xi32>
+// CHECK:     %4 = "tfl.reshape"(%2, %3) : (tensor<9xi32>, tensor<2xi32>) -> tensor<1x9xi32>
+// CHECK:     %cst_0 = arith.constant dense<1> : tensor<1xi32>
+// CHECK:     %5 = "tfl.reduce_max"(%arg0, %cst_0) <{keep_dims = false}> : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32>
+// CHECK:     %6 = "tfl.arg_max"(%arg0, %cst_0) : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32>
+// CHECK:     return %6 : tensor<1xi32>
 
 // -----
 
@@ -618,11 +627,11 @@
   func.return %0 : tensor<1x32x1xf32>
 }
 
-// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<f32>
-// CHECK: %cst_0 = arith.constant dense<3.000000e+00> : tensor<f32>
-// CHECK: %0 = tfl.div %cst, %cst_0 {fused_activation_function = "NONE"} : tensor<f32>
-// CHECK: %1 = tfl.pow(%arg0, %0) : (tensor<1x32x1xf32>, tensor<f32>) -> tensor<1x32x1xf32>
-// CHECK: return %1 : tensor<1x32x1xf32>
+// CHECK-DAG: %cst = arith.constant dense<1.000000e+00> : tensor<f32>
+// CHECK-DAG: %cst_0 = arith.constant dense<3.000000e+00> : tensor<f32>
+// CHECK:     %0 = tfl.div %cst, %cst_0 {fused_activation_function = "NONE"} : tensor<f32>
+// CHECK:     %1 = tfl.pow(%arg0, %0) : (tensor<1x32x1xf32>, tensor<f32>) -> tensor<1x32x1xf32>
+// CHECK:     return %1 : tensor<1x32x1xf32>
 
 // -----
 
@@ -637,6 +646,44 @@
 // -----
 
 //===----------------------------------------------------------------------===//
+// mhlo.(dynamic)reshape
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: reshape
+func.func @reshape(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
+  %0 = "mhlo.reshape"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2xf32>
+  func.return %0 : tensor<3x2xf32>
+}
+
+// CHECK: %cst = arith.constant dense<[3, 2]> : tensor<2xi64>
+// CHECK: %0 = "tfl.cast"(%cst) : (tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+
+// -----
+
+// CHECK-LABEL: dynamic_reshape_i32
+func.func @dynamic_reshape_i32(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
+  %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  func.return %0 : tensor<?x?xf32>
+}
+
+// CHECK: %0 = "tfl.cast"(%arg1) : (tensor<2xi32>) -> tensor<2xi32>
+// CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+
+// -----
+
+// CHECK-LABEL: dynamic_reshape_i64
+func.func @dynamic_reshape_i64(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi64>) -> tensor<?x?xf32> {
+  %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<?x?xf32>
+  func.return %0 : tensor<?x?xf32>
+}
+
+// CHECK: %0 = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // mhlo.convolution
 //===----------------------------------------------------------------------===//
 
@@ -1473,7 +1520,7 @@
   func.return %0 : tensor<4x64x128xf32>
 }
 
-// CHECK: %[[VAL_0:.*]] = mhlo.reshape %arg1 : (tensor<4x64xi32>) -> tensor<4x64x1xi32>
+// CHECK: %[[VAL_0:.*]] = "tfl.reshape"(%arg1, %0) : (tensor<4x64xi32>, tensor<3xi32>) -> tensor<4x64x1xi32
 // CHECK: %[[VAL_1:.*]] = "tfl.gather_nd"(%arg0, %[[VAL_0]]) : (tensor<98x128xf32>, tensor<4x64x1xi32>) -> tensor<4x64x128xf32>
 
 // -----
@@ -1582,3 +1629,1209 @@
 }
 
 // CHECK: %0 = "tfl.gather_nd"(%arg0, %arg1) : (tensor<256000xf32>, tensor<?x?x1xi32>) -> tensor<?x?xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.reduce_window -> avg pool
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: avgpool_same_channel_first
+func.func @avgpool_same_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x8x8xf32> {
+  %0 = mhlo.constant dense<1.000000e+00> : tensor<4x16x16x3xf32>
+  %1 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %2 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32>
+  %3 = "mhlo.reduce_window"(%2, %1) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %8 = mhlo.add %arg1, %arg2 : tensor<f32>
+    mhlo.return %8 : tensor<f32>
+  }) : (tensor<4x16x16x3xf32>, tensor<f32>) -> tensor<4x8x8x3xf32>
+  %4 = "mhlo.transpose"(%3) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x8x8x3xf32>) -> tensor<4x3x8x8xf32>
+  %5 = "mhlo.reduce_window"(%0, %1) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %8 = mhlo.add %arg1, %arg2 : tensor<f32>
+    mhlo.return %8 : tensor<f32>
+  }) : (tensor<4x16x16x3xf32>, tensor<f32>) -> tensor<4x8x8x3xf32>
+  %6 = "mhlo.transpose"(%5) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x8x8x3xf32>) -> tensor<4x3x8x8xf32>
+  %7 = mhlo.divide %4, %6 : tensor<4x3x8x8xf32>
+  return %7 : tensor<4x3x8x8xf32>
+}
+
+// CHECK:      %[[TPOSED_ARG0:.*]] = "tfl.transpose"(%arg0
+// CHECK-SAME: (tensor<4x3x16x16xf32>, tensor<4xi32>) -> tensor<4x16x16x3xf32>
+// CHECK:      %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%[[TPOSED_ARG0]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x3xf32>) -> tensor<4x8x8x3xf32>
+// CHECK:      %[[TPOSED_OUT:.*]] = "tfl.transpose"(%[[POOL_OUT]]
+// CHECK-SAME: (tensor<4x8x8x3xf32>, tensor<4xi32>) -> tensor<4x3x8x8xf32>
+// CHECK:      return %[[TPOSED_OUT]]
+
+// -----
+
+// CHECK-LABEL: avgpool_valid_channel_first
+func.func @avgpool_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
+  %0 = mhlo.constant dense<9.000000e+00> : tensor<4x3x7x7xf32>
+  %1 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %2 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32>
+  %3 = "mhlo.reduce_window"(%2, %1) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+    mhlo.return %6 : tensor<f32>
+  }) : (tensor<4x16x16x3xf32>, tensor<f32>) -> tensor<4x7x7x3xf32>
+  %4 = "mhlo.transpose"(%3) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x7x7x3xf32>) -> tensor<4x3x7x7xf32>
+  %5 = mhlo.divide %4, %0 : tensor<4x3x7x7xf32>
+  return %5 : tensor<4x3x7x7xf32>
+}
+
+// CHECK:      %[[TPOSED_ARG0:.*]] = "tfl.transpose"(%arg0
+// CHECK-SAME: (tensor<4x3x16x16xf32>, tensor<4xi32>) -> tensor<4x16x16x3xf32>
+// CHECK:      %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%[[TPOSED_ARG0]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x3xf32>) -> tensor<4x7x7x3xf32>
+// CHECK:      %[[TPOSED_OUT:.*]] = "tfl.transpose"(%[[POOL_OUT]]
+// CHECK-SAME: (tensor<4x7x7x3xf32>, tensor<4xi32>) -> tensor<4x3x7x7xf32>
+// CHECK:      return %[[TPOSED_OUT]]
+
+// -----
+
+func.func @avgpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+  %0 = mhlo.constant dense<0.0> : tensor<f32>
+  %1 = mhlo.constant dense<9.0> : tensor<4x7x7x8xf32>
+  %2 = "mhlo.reduce_window"(%arg0, %0) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %5 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%5) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<0> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  %3 = mhlo.divide %2, %1 : tensor<4x7x7x8xf32>
+  func.return %3 : tensor<4x7x7x8xf32>
+}
+
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK: return %[[POOL_OUT]]
+
+// -----
+
+// CHECK-LABEL: avgpool_valid_broadcasted_divisor
+func.func @avgpool_valid_broadcasted_divisor(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+  %0 = mhlo.constant dense<0.0> : tensor<f32>
+  %1 = mhlo.constant dense<9.0> : tensor<f32>
+  %2 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>) -> tensor<4x7x7x8xf32>
+  %3 = "mhlo.reduce_window"(%arg0, %0) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %5 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%5) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<0> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  %4 = mhlo.divide %3, %2 : tensor<4x7x7x8xf32>
+  func.return %4 : tensor<4x7x7x8xf32>
+}
+
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK: return %[[POOL_OUT]]
+
+// -----
+
+// CHECK-LABEL: avgpool_valid_rw
+func.func @avgpool_valid_rw(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+  %0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
+  %1 = mhlo.constant dense<0.0> : tensor<f32>
+  %2 = "mhlo.reduce_window"(%arg0, %1) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  %3 = "mhlo.reduce_window"(%0, %1) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  %4 = mhlo.divide %2, %3 : tensor<4x7x7x8xf32>
+  func.return %4 : tensor<4x7x7x8xf32>
+}
+
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK: return %[[POOL_OUT]]
+
+// -----
+
+// CHECK-LABEL: avgpool_valid_rw_broadcasted_const_lhs
+func.func @avgpool_valid_rw_broadcasted_const_lhs(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+  %0 = mhlo.constant dense<1.0> : tensor<f32>
+  %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>) -> tensor<4x16x16x8xf32>
+  %2 = mhlo.constant dense<0.0> : tensor<f32>
+  %3 = "mhlo.reduce_window"(%arg0, %2) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  %4 = "mhlo.reduce_window"(%1, %2) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  %5 = mhlo.divide %3, %4 : tensor<4x7x7x8xf32>
+  func.return %5 : tensor<4x7x7x8xf32>
+}
+
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK: return %[[POOL_OUT]]
+
+// -----
+
+// CHECK-LABEL: avgpool_same
+func.func @avgpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
+  %0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
+  %1 = mhlo.constant dense<0.0> : tensor<f32>
+  %2 = "mhlo.reduce_window"(%arg0, %1) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
+  %3 = "mhlo.reduce_window"(%0, %1) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
+  %4 = mhlo.divide %2, %3 : tensor<4x8x8x8xf32>
+  func.return %4 : tensor<4x8x8x8xf32>
+}
+
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
+// CHECK: return %[[POOL_OUT]]
+
+// -----
+
+// CHECK-LABEL: avgpool_reshape_broadcast
+func.func @avgpool_reshape_broadcast(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
+  %0 = mhlo.constant dense<1.000000e+00> : tensor<1x16x16x1xf32>
+  %1 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %2 = "mhlo.reduce_window"(%arg0, %1) ({
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %7 = mhlo.add %arg1, %arg2 : tensor<f32>
+    mhlo.return %7 : tensor<f32>
+  }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
+  %3 = "mhlo.reduce_window"(%0, %1) ({
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %7 = mhlo.add %arg1, %arg2 : tensor<f32>
+    mhlo.return %7 : tensor<f32>
+  }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x16x16x1xf32>, tensor<f32>) -> tensor<1x8x8x1xf32>
+  %4 = mhlo.reshape %3 : (tensor<1x8x8x1xf32>) -> tensor<8x8xf32>
+  %5 = "mhlo.broadcast_in_dim"(%4) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> : (tensor<8x8xf32>) -> tensor<4x8x8x8xf32>
+  %6 = mhlo.divide %2, %5 : tensor<4x8x8x8xf32>
+  return %6 : tensor<4x8x8x8xf32>
+}
+
+// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
+// CHECK: return %[[POOL_OUT]]
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.reduce_window -> max pool
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: maxpool_same
+func.func @maxpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
+  // "0xFF800000" represents -INF for f32.
+  %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+  %1 = "mhlo.reduce_window"(%arg0, %0) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.maximum %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
+  func.return %1 : tensor<4x8x8x8xf32>
+}
+
+// CHECK: %1 = "tfl.max_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
+
+// -----
+
+// CHECK-LABEL: maxpool_valid
+func.func @maxpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+  // "0xFF800000" represents -INF for f32.
+  %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+  %1 = "mhlo.reduce_window"(%arg0, %0) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %6 = mhlo.maximum %arg1, %arg2 : tensor<f32>
+      "mhlo.return"(%6) : (tensor<f32>) -> ()
+    }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  func.return %1 : tensor<4x7x7x8xf32>
+}
+
+// CHECK: %1 = "tfl.max_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+
+// -----
+
+// CHECK-LABEL: maxpool_valid_channel_first
+func.func @maxpool_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
+  // "0xFF800000" represents -INF for f32.
+  %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+  %1 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32>
+  %2 = "mhlo.reduce_window"(%1, %0) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %4 = mhlo.maximum %arg1, %arg2 : tensor<f32>
+    mhlo.return %4 : tensor<f32>
+  }) : (tensor<4x16x16x3xf32>, tensor<f32>) -> tensor<4x7x7x3xf32>
+  %3 = "mhlo.transpose"(%2) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x7x7x3xf32>) -> tensor<4x3x7x7xf32>
+  return %3 : tensor<4x3x7x7xf32>
+}
+
+// CHECK:      %[[TPOSED_ARG0:.*]] = "tfl.transpose"(%arg0
+// CHECK:      "tfl.max_pool_2d"(%[[TPOSED_ARG0]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x3xf32>) -> tensor<4x7x7x3xf32>
+// CHECK:      return
+// CHECK-SAME: tensor<4x3x7x7xf32>
+
+// -----
+
+// CHECK-LABEL: maxpool_same_channel_first
+func.func @maxpool_same_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x8x8xf32> {
+  // "0xFF800000" represents -INF for f32.
+  %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
+  %1 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32>
+  %2 = "mhlo.reduce_window"(%1, %0) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %4 = mhlo.maximum %arg1, %arg2 : tensor<f32>
+    mhlo.return %4 : tensor<f32>
+  }) : (tensor<4x16x16x3xf32>, tensor<f32>) -> tensor<4x8x8x3xf32>
+  %3 = "mhlo.transpose"(%2) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x8x8x3xf32>) -> tensor<4x3x8x8xf32>
+  return %3 : tensor<4x3x8x8xf32>
+}
+
+// CHECK:      %[[TPOSED_ARG0:.*]] = "tfl.transpose"(%arg0
+// CHECK:      "tfl.max_pool_2d"(%[[TPOSED_ARG0]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x3xf32>) -> tensor<4x8x8x3xf32>
+// CHECK:      return
+// CHECK-SAME: tensor<4x3x8x8xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.reduce_window -> tfl.cumsum
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: reduce_window_sum
+func.func @reduce_window_sum(%arg0: tensor<4x12xf32>) -> tensor<4x12xf32> {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %1 = "mhlo.reduce_window"(%arg0, %0) ({
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %2 = mhlo.add %arg1, %arg2 : tensor<f32>
+    "mhlo.return"(%2) : (tensor<f32>) -> ()
+  }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[3, 0], [0, 0]]> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[4, 1]> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<4x12xf32>, tensor<f32>) -> tensor<4x12xf32>
+  func.return %1 : tensor<4x12xf32>
+}
+
+// CHECK: %[[AXIS:.*]] = arith.constant dense<0> : tensor<i32>
+// CHECK: "tfl.cumsum"(%arg0, %[[AXIS]]) <{exclusive = false, reverse = false}> : (tensor<4x12xf32>, tensor<i32>) -> tensor<4x12xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.slice
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: slice
+func.func @slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> {
+  %0 = "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<1x4672xf32>) -> tensor<1x519xf32>
+  func.return %0 : tensor<1x519xf32>
+}
+
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 4153]> : tensor<2xi64>
+// CHECK: %[[CST_0:.*]] = arith.constant dense<[1, 4672]> : tensor<2xi64>
+// CHECK: %[[CST_1:.*]] = arith.constant dense<1> : tensor<2xi64>
+// CHECK: %[[VAL_0:.*]] = "tfl.cast"(%[[CST]]) : (tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %[[VAL_1:.*]] = "tfl.cast"(%[[CST_0]]) : (tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %[[VAL_2:.*]] = "tfl.cast"(%[[CST_1]]) : (tensor<2xi64>) -> tensor<2xi32>
+// CHECK: %[[VAL_3:.*]] = "tfl.strided_slice"(%arg0, %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<1x4672xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x519xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.sort
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: sort_to_topk_iota_broadcast
+func.func @sort_to_topk_iota_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
+  %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<6xi32>
+  %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>, name = "broadcast.0"}> : (tensor<6xi32>) -> tensor<3x6xi32>
+  %2:2 = "mhlo.sort"(%arg0, %1) ({
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
+    %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "mhlo.return"(%3) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+  func.return %2#0, %2#1 : tensor<3x6xf32>, tensor<3x6xi32>
+}
+
+// CHECK: %cst = arith.constant dense<6> : tensor<i32>
+// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %cst) : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+
+// -----
+
+// CHECK-LABEL: sort_to_topk_iota_cst_broadcast
+func.func @sort_to_topk_iota_cst_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
+  %0 = mhlo.constant dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi32>
+  %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>, name = "broadcast.0"}> : (tensor<6xi32>) -> tensor<3x6xi32>
+  %2:2 = "mhlo.sort"(%arg0, %1) ({
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
+    %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "mhlo.return"(%3) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+  func.return %2#0, %2#1 : tensor<3x6xf32>, tensor<3x6xi32>
+}
+
+// CHECK: %cst = arith.constant dense<6> : tensor<i32>
+// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %cst) : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+
+// -----
+
+// CHECK-LABEL: sort_to_topk_const
+func.func @sort_to_topk_const(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
+  %0 = mhlo.constant dense<[[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]> : tensor<3x6xi32>
+  %1:2 = "mhlo.sort"(%arg0, %0) ({
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
+    %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "mhlo.return"(%3) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+  func.return %1#0, %1#1 : tensor<3x6xf32>, tensor<3x6xi32>
+}
+
+// CHECK: %cst = arith.constant dense<6> : tensor<i32>
+// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %cst) : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.dynamic_slice
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: dynamic_slice
+func.func @dynamic_slice(%arg0: tensor<7x3xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<4x2xf32> {
+  %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<i32>, tensor<i32>) -> tensor<4x2xf32>
+  func.return %0 : tensor<4x2xf32>
+}
+
+// CHECK-DAG: %[[CST_IS_0:.*]] = arith.constant dense<0> : tensor<i32>
+// CHECK-DAG: %[[CST_IS_3:.*]] = arith.constant dense<3> : tensor<i32>
+// CHECK:     %[[MAX_1:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK:     %[[MIN_1:.*]] = "tfl.minimum"(%[[CST_IS_3]], %[[MAX_1]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK:     %[[CST_IS_1:.*]] = arith.constant dense<1> : tensor<i32>
+// CHECK:     %[[MAX_2:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK:     %[[MIN_2:.*]] = "tfl.minimum"(%[[CST_IS_1]], %[[MAX_2]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK:     %[[PACK:.*]] = "tfl.pack"(%[[MIN_1]], %[[MIN_2]]) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
+// CHECK:     %[[SLICE_SIZE:.*]] = arith.constant dense<[4, 2]> : tensor<2xi64>
+// CHECK:     "tfl.slice"(%arg0, %[[PACK]], %[[SLICE_SIZE]]) : (tensor<7x3xf32>, tensor<2xi32>, tensor<2xi64>) -> tensor<4x2xf32>
+
+// -----
+
+// CHECK-LABEL: dynamic_slice_i64
+func.func @dynamic_slice_i64(%arg0: tensor<7x3xf32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<4x2xf32> {
+  %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<i64>, tensor<i64>) -> tensor<4x2xf32>
+  func.return %0 : tensor<4x2xf32>
+}
+
+// CHECK-DAG: %[[CST_IS_0:.*]] = arith.constant dense<0> : tensor<i64>
+// CHECK-DAG: %[[CST_IS_3:.*]] = arith.constant dense<3> : tensor<i64>
+// CHECK:     %[[MAX_1:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+// CHECK:     %[[MIN_1:.*]] = "tfl.minimum"(%[[CST_IS_3]], %[[MAX_1]]) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+// CHECK:     %[[CST_IS_1:.*]] = arith.constant dense<1> : tensor<i64>
+// CHECK:     %[[MAX_2:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg2) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+// CHECK:     %[[MIN_2:.*]] = "tfl.minimum"(%[[CST_IS_1]], %[[MAX_2]]) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+// CHECK:     %[[PACK:.*]] = "tfl.pack"(%[[MIN_1]], %[[MIN_2]]) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor<i64>, tensor<i64>) -> tensor<2xi64>
+// CHECK:     %[[SLICE_SIZE:.*]] = arith.constant dense<[4, 2]> : tensor<2xi64>
+// CHECK:     "tfl.slice"(%arg0, %[[PACK]], %[[SLICE_SIZE]]) : (tensor<7x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<4x2xf32>
+
+// -----
+
+// CHECK-LABEL: dynamic_slice_splat_sizes
+func.func @dynamic_slice_splat_sizes(%arg0: tensor<7x3xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<2x2xf32> {
+  %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<2> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<i32>, tensor<i32>) -> tensor<2x2xf32>
+  func.return %0 : tensor<2x2xf32>
+}
+
+// CHECK-DAG: %[[CST_IS_0:.*]] = arith.constant dense<0> : tensor<i32>
+// CHECK-DAG: %[[CST_IS_5:.*]] = arith.constant dense<5> : tensor<i32>
+// CHECK:     %[[MAX_1:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK:     %[[MIN_1:.*]] = "tfl.minimum"(%[[CST_IS_5]], %[[MAX_1]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK:     %[[CST_IS_1:.*]] = arith.constant dense<1> : tensor<i32>
+// CHECK:     %[[MAX_2:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK:     %[[MIN_2:.*]] = "tfl.minimum"(%[[CST_IS_1]], %[[MAX_2]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK:     %[[PACK:.*]] = "tfl.pack"(%[[MIN_1]], %[[MIN_2]]) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
+// CHECK:     %[[SLICE_SIZE:.*]] = arith.constant dense<2> : tensor<2xi64>
+// CHECK:     "tfl.slice"(%arg0, %[[PACK]], %[[SLICE_SIZE]]) : (tensor<7x3xf32>, tensor<2xi32>, tensor<2xi64>) -> tensor<2x2xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// rounding
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: round
+func.func @round(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> {
+  %0 = mhlo.constant dense<2.000000e+00> : tensor<8x128xf32>
+  %1 = mhlo.constant dense<5.000000e-01> : tensor<8x128xf32>
+  %2 = mhlo.constant dense<1.000000e+00> : tensor<8x128xf32>
+  %3 = "mhlo.floor"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32>
+  %4 = mhlo.subtract %arg0, %3 : tensor<8x128xf32>
+  %5 = "mhlo.compare"(%4, %1) {comparison_direction = #mhlo<comparison_direction GT>} : (tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xi1>
+  %6 = "mhlo.compare"(%4, %1) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xi1>
+  %7 = mhlo.multiply %arg0, %1 : tensor<8x128xf32>
+  %8 = "mhlo.floor"(%7) : (tensor<8x128xf32>) -> tensor<8x128xf32>
+  %9 = mhlo.multiply %8, %0 : tensor<8x128xf32>
+  %10 = mhlo.subtract %3, %9 : tensor<8x128xf32>
+  %11 = "mhlo.compare"(%10, %2) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xi1>
+  %12 = mhlo.and %6, %11 : tensor<8x128xi1>
+  %13 = mhlo.or %5, %12 : tensor<8x128xi1>
+  %14 = mhlo.add %3, %2 : tensor<8x128xf32>
+  %15 = "mhlo.select"(%13, %14, %3) : (tensor<8x128xi1>, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32>
+  func.return %15 : tensor<8x128xf32>
+}
+
+// CHECK: "tfl.round"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32>
+
+// -----
+
+// CHECK-LABEL: floor_mod_float
+func.func @floor_mod_float(%arg0: tensor<192x8xf32>, %arg1: tensor<192x8xf32>) -> tensor<192x8xf32> {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<192x8xf32>
+  %1 = mhlo.remainder %arg0, %arg1 : tensor<192x8xf32>
+  %2 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1>
+  %3 = "mhlo.compare"(%arg1, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1>
+  %4 = "mhlo.compare"(%2, %3) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xi1>, tensor<192x8xi1>) -> tensor<192x8xi1>
+  %5 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1>
+  %6 = mhlo.and %4, %5 : tensor<192x8xi1>
+  %7 = mhlo.add %1, %arg1 : tensor<192x8xf32>
+  %8 = "mhlo.select"(%6, %7, %1) : (tensor<192x8xi1>, tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32>
+  func.return %8 : tensor<192x8xf32>
+}
+
+// CHECK: "tfl.floor_mod"(%arg0, %arg1) : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32>
+
+// -----
+
+// CHECK-LABEL: floor_mod_int
+func.func @floor_mod_int(%arg0: tensor<192x8xi32>, %arg1: tensor<192x8xi32>) -> tensor<192x8xi32> {
+  %0 = mhlo.constant dense<0> : tensor<192x8xi32>
+  %1 = mhlo.remainder %arg0, %arg1 : tensor<192x8xi32>
+  %2 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
+  %3 = "mhlo.compare"(%arg1, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
+  %4 = "mhlo.compare"(%2, %3) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xi1>, tensor<192x8xi1>) -> tensor<192x8xi1>
+  %5 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
+  %6 = mhlo.and %4, %5 : tensor<192x8xi1>
+  %7 = mhlo.add %1, %arg1 : tensor<192x8xi32>
+  %8 = "mhlo.select"(%6, %7, %1) : (tensor<192x8xi1>, tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
+  func.return %8 : tensor<192x8xi32>
+}
+
+// CHECK: "tfl.floor_mod"(%arg0, %arg1) : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
+
+// -----
+
+// CHECK-LABEL: floor_mod_float_cst
+func.func @floor_mod_float_cst(%arg0: tensor<192x8xf32>) -> tensor<192x8xf32> {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<192x8xf32>
+  %1 = mhlo.constant dense<2.000000e+00> : tensor<192x8xf32>
+  %2 = mhlo.remainder %arg0, %1 : tensor<192x8xf32>
+  %3 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1>
+  %4 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1>
+  %5 = mhlo.and %3, %4 : tensor<192x8xi1>
+  %6 = mhlo.add %2, %1 : tensor<192x8xf32>
+  %7 = "mhlo.select"(%5, %6, %2) : (tensor<192x8xi1>, tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32>
+  func.return %7 : tensor<192x8xf32>
+}
+
+// CHECK: %cst = arith.constant dense<2.000000e+00> : tensor<192x8xf32>
+// CHECK: "tfl.floor_mod"(%arg0, %cst) : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32>
+
+// -----
+
+// CHECK-LABEL: floor_mod_int_cst
+func.func @floor_mod_int_cst(%arg0: tensor<192x8xi32>) -> tensor<192x8xi32> {
+  %0 = mhlo.constant dense<0> : tensor<192x8xi32>
+  %1 = mhlo.constant dense<2> : tensor<192x8xi32>
+  %2 = mhlo.remainder %arg0, %1 : tensor<192x8xi32>
+  %3 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
+  %4 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
+  %5 = mhlo.and %3, %4 : tensor<192x8xi1>
+  %6 = mhlo.add %2, %1 : tensor<192x8xi32>
+  %7 = "mhlo.select"(%5, %6, %2) : (tensor<192x8xi1>, tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
+  func.return %7 : tensor<192x8xi32>
+}
+
+// CHECK: %cst = arith.constant dense<2> : tensor<192x8xi32>
+// CHECK: "tfl.floor_mod"(%arg0, %cst) : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
+
+// -----
+
+// CHECK-LABEL: floor_div
+func.func @floor_div(%arg0: tensor<10x10xf32>, %arg1: tensor<10x10xf32>) -> tensor<10x10xf32> {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<10x10xf32>
+  %1 = mhlo.constant dense<-1.000000e+00> : tensor<10x10xf32>
+  %2 = mhlo.remainder %arg0, %arg1 : tensor<10x10xf32>
+  %3 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
+  %4 = "mhlo.sign"(%arg1) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+  %5 = "mhlo.sign"(%2) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+  %6 = "mhlo.compare"(%4, %5) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
+  %7 = mhlo.and %3, %6 : tensor<10x10xi1>
+  %8 = mhlo.subtract %arg0, %2 : tensor<10x10xf32>
+  %9 = mhlo.divide %8, %arg1 : tensor<10x10xf32>
+  %10 = mhlo.add %9, %1 : tensor<10x10xf32>
+  %11 = "mhlo.select"(%7, %10, %9) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  %12 = "mhlo.round_nearest_afz"(%11) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+  %13 = "mhlo.tuple"(%12) : (tensor<10x10xf32>) -> tuple<tensor<10x10xf32>>
+  func.return %12 : tensor<10x10xf32>
+}
+
+// CHECK: tfl.floor_div %arg0, %arg1 : tensor<10x10xf32>
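+// (x - remainder(x, y)) / y with a -1 correction when the remainder and
+// divisor signs disagree reconstructs floor(x / y), hence tfl.floor_div.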
+
+// -----
+
+// CHECK-LABEL: floor_div_cst
+func.func @floor_div_cst(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
+  %0 = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32>
+  %1 = mhlo.constant dense<0.000000e+00> : tensor<10x10xf32>
+  %2 = mhlo.constant dense<1.000000e+00> : tensor<10x10xf32>
+  %3 = mhlo.constant dense<5.000000e-01> : tensor<10x10xf32>
+  %4 = mhlo.constant dense<-1.000000e+00> : tensor<10x10xf32>
+  %5 = mhlo.remainder %arg0, %0 : tensor<10x10xf32>
+  %6 = "mhlo.compare"(%5, %1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
+  %7 = "mhlo.sign"(%5) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+  %8 = "mhlo.compare"(%2, %7) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
+  %9 = mhlo.and %6, %8 : tensor<10x10xi1>
+  %10 = mhlo.subtract %arg0, %5 : tensor<10x10xf32>
+  %11 = mhlo.multiply %10, %3 : tensor<10x10xf32>
+  %12 = mhlo.add %11, %4 : tensor<10x10xf32>
+  %13 = "mhlo.select"(%9, %12, %11) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  %14 = "mhlo.round_nearest_afz"(%13) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+  %15 = "mhlo.tuple"(%14) : (tensor<10x10xf32>) -> tuple<tensor<10x10xf32>>
+  func.return %14 : tensor<10x10xf32>
+}
+
+// CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32>
+// CHECK: tfl.floor_div %arg0, %[[CST]] : tensor<10x10xf32>
+
+// -----
+
+// CHECK-LABEL: floor_div_cst2
+func.func @floor_div_cst2(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
+  %0 = mhlo.constant dense<1.000000e+00> : tensor<10x10xf32>
+  %1 = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32>
+  %2 = mhlo.constant dense<0.000000e+00> : tensor<10x10xf32>
+  %3 = mhlo.constant dense<-1.000000e+00> : tensor<10x10xf32>
+  %4 = mhlo.remainder %arg0, %1 : tensor<10x10xf32>
+  %5 = "mhlo.compare"(%4, %2) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
+  %6 = "mhlo.sign"(%4) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+  %7 = "mhlo.compare"(%0, %6) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
+  %8 = mhlo.and %5, %7 : tensor<10x10xi1>
+  %9 = mhlo.subtract %arg0, %4 : tensor<10x10xf32>
+  %10 = mhlo.divide %9, %1 : tensor<10x10xf32>
+  %11 = mhlo.add %10, %3 : tensor<10x10xf32>
+  %12 = "mhlo.select"(%8, %11, %10) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  %13 = "mhlo.round_nearest_afz"(%12) : (tensor<10x10xf32>) -> tensor<10x10xf32>
+  %14 = "mhlo.tuple"(%13) : (tensor<10x10xf32>) -> tuple<tensor<10x10xf32>>
+  func.return %13 : tensor<10x10xf32>
+}
+
+// CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32>
+// CHECK: tfl.floor_div %arg0, %[[CST]] : tensor<10x10xf32>
+
+// -----
+
+// CHECK-LABEL: floor_div_broadcast_cst
+func.func @floor_div_broadcast_cst(%arg0: tensor<10x8xf32>) -> tensor<10x8xf32> {
+  %0 = mhlo.constant dense<1.000000e+00> : tensor<10x8xf32>
+  %1 = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00, 1.600000e+01, 3.200000e+01, 6.400000e+01, 1.280000e+02]> : tensor<8xf32>
+  %2 = mhlo.constant dense<0.000000e+00> : tensor<10x8xf32>
+  %3 = mhlo.constant dense<-1.000000e+00> : tensor<10x8xf32>
+  %5 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<8xf32>) -> tensor<10x8xf32>
+  %6 = mhlo.remainder %arg0, %5 : tensor<10x8xf32>
+  %7 = "mhlo.compare"(%6, %2) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xi1>
+  %8 = "mhlo.sign"(%6) : (tensor<10x8xf32>) -> tensor<10x8xf32>
+  %9 = "mhlo.compare"(%0, %8) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xi1>
+  %10 = mhlo.and %7, %9 : tensor<10x8xi1>
+  %11 = mhlo.subtract %arg0, %6 : tensor<10x8xf32>
+  %12 = mhlo.divide %11, %5 : tensor<10x8xf32>
+  %13 = mhlo.add %12, %3 : tensor<10x8xf32>
+  %14 = "mhlo.select"(%10, %13, %12) : (tensor<10x8xi1>, tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xf32>
+  %15 = "mhlo.round_nearest_afz"(%14) : (tensor<10x8xf32>) -> tensor<10x8xf32>
+  %16 = "mhlo.tuple"(%15) : (tensor<10x8xf32>) -> tuple<tensor<10x8xf32>>
+  func.return %15 : tensor<10x8xf32>
+}
+
+// CHECK: %[[BCAST:.*]] = "mhlo.broadcast_in_dim"(%1)
+// CHECK: tfl.floor_div %arg0, %[[BCAST]] : tensor<10x8xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// unary elementwise
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: convert_i32_f32
+func.func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> {
+  %0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.cast
+
+// -----
+
+// CHECK-LABEL: abs
+func.func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.abs
+
+// -----
+
+// CHECK-LABEL: abs_dynamic
+func.func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.abs
+
+// -----
+
+// CHECK-LABEL: ceil
+func.func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.ceil
+
+// -----
+
+// CHECK-LABEL: ceil_dynamic
+func.func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.ceil
+
+// -----
+
+// CHECK-LABEL: complex_abs
+func.func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
+  %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK-NOT: tfl
+
+// -----
+
+func.func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> {
+  %0 = "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
+  func.return %0 : tensor<2xi1>
+}
+
+// CHECK: %0 = tfl.sub %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2xf32>
+// CHECK: %cst = arith.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %1 = "tfl.equal"(%0, %cst) : (tensor<2xf32>, tensor<f32>) -> tensor<2xi1>
+// CHECK: return %1 : tensor<2xi1>
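+// x - x is 0 for finite x and NaN for infinities and NaNs, so comparing the
+// difference against 0 reproduces mhlo.is_finite.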
+
+// -----
+
+func.func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> {
+  %0 = "mhlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
+  func.return %0 : tensor<?xi1>
+}
+
+// CHECK: %0 = tfl.sub %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<?xf32>
+// CHECK: %cst = arith.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %1 = "tfl.equal"(%0, %cst) : (tensor<?xf32>, tensor<f32>) -> tensor<?xi1>
+
+// -----
+
+// CHECK-LABEL: cos
+func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.cos
+
+// -----
+
+// CHECK-LABEL: cos_dynamic
+func.func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.cos
+
+// -----
+
+// CHECK-LABEL: logistic
+func.func @logistic(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.logistic"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.logistic
+
+// -----
+
+// CHECK-LABEL: exp
+func.func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.exp
+
+// -----
+
+// CHECK-LABEL: exp_dynamic
+func.func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.exp
+
+// -----
+
+// CHECK-LABEL: expm1
+func.func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.exponential_minus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: %0 = "tfl.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %1 = tfl.sub(%0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
+
+// -----
+
+// CHECK-LABEL: floor
+func.func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.floor
+
+// -----
+
+// CHECK-LABEL: floor_dynamic
+func.func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.floor
+
+// -----
+
+// CHECK-LABEL: log
+func.func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.log
+
+// -----
+
+// CHECK-LABEL: log_dynamic
+func.func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.log
+
+// -----
+
+// CHECK-LABEL: log1p
+func.func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
+// CHECK: %1 = "tfl.log"(%0) : (tensor<2xf32>) -> tensor<2xf32>
+
+// -----
+
+// CHECK-LABEL: log1p_dynamic
+func.func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
+// CHECK: %1 = "tfl.log"(%0) : (tensor<?xf32>) -> tensor<?xf32>
+
+// -----
+
+// CHECK-LABEL: neg
+func.func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.neg
+
+// -----
+
+// CHECK-LABEL: neg_dynamic
+func.func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.neg
+
+// -----
+
+// CHECK-LABEL: sin
+func.func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.sin
+
+// -----
+
+// CHECK-LABEL: sin_dynamic
+func.func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.sin
+
+// -----
+
+// CHECK-LABEL: rsqrt
+func.func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.rsqrt
+
+// -----
+
+// CHECK-LABEL: rsqrt_dynamic
+func.func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.rsqrt
+
+// -----
+
+// CHECK-LABEL: @sqrt
+func.func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.sqrt
+
+// -----
+
+// CHECK-LABEL: sqrt_dynamic
+func.func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.sqrt
+
+// -----
+
+// CHECK-LABEL: tanh
+func.func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.tanh
+
+// -----
+
+// CHECK-LABEL: tanh_dynamic
+func.func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.tanh
+
+// -----
+
+// CHECK-LABEL: bitcast
+func.func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+
+// CHECK: tfl.bitcast
+
+// -----
+
+// CHECK-LABEL: bitcast_dynamic
+func.func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+// CHECK: tfl.bitcast
+
+// -----
+
+// CHECK-LABEL: bitcast_same_widths
+func.func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> {
+  %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
+  func.return %0 : tensor<2xi32>
+}
+
+// CHECK: tfl.bitcast
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// logical and bitwise ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: not
+func.func @not(%arg0: tensor<5x3x1xi1>) -> tensor<5x3x1xi1> {
+  %0 = "mhlo.not"(%arg0): (tensor<5x3x1xi1>) -> (tensor<5x3x1xi1>)
+  func.return %0 : tensor<5x3x1xi1>
+}
+
+// CHECK: %0 = "tfl.logical_not"(%arg0) : (tensor<5x3x1xi1>) -> tensor<5x3x1xi1>
+
+// -----
+
+// CHECK-LABEL: not_i8
+func.func @not_i8(%arg0: tensor<7x9x11xi8>) -> tensor<7x9x11xi8> {
+  %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi8>) -> (tensor<7x9x11xi8>)
+  func.return %0 : tensor<7x9x11xi8>
+}
+
+// CHECK: %cst = arith.constant dense<-1> : tensor<i8>
+// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xi8>, tensor<i8>) -> tensor<7x9x11xi8>
+
+// -----
+
+// CHECK-LABEL: not_i16
+func.func @not_i16(%arg0: tensor<7x9x11xi16>) -> tensor<7x9x11xi16> {
+  %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi16>) -> (tensor<7x9x11xi16>)
+  func.return %0 : tensor<7x9x11xi16>
+}
+
+// CHECK: %cst = arith.constant dense<-1> : tensor<i16>
+// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xi16>, tensor<i16>) -> tensor<7x9x11xi16>
+
+// -----
+
+// CHECK-LABEL: not_i32
+func.func @not_i32(%arg0: tensor<7x9x11xi32>) -> tensor<7x9x11xi32> {
+  %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi32>) -> (tensor<7x9x11xi32>)
+  func.return %0 : tensor<7x9x11xi32>
+}
+
+// CHECK: %cst = arith.constant dense<-1> : tensor<i32>
+// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xi32>, tensor<i32>) -> tensor<7x9x11xi32>
+
+// -----
+
+// CHECK-LABEL: not_ui8
+func.func @not_ui8(%arg0: tensor<7x9x11xui8>) -> tensor<7x9x11xui8> {
+  %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui8>) -> (tensor<7x9x11xui8>)
+  func.return %0 : tensor<7x9x11xui8>
+}
+
+// CHECK: %cst = arith.constant dense<255> : tensor<ui8>
+// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xui8>, tensor<ui8>) -> tensor<7x9x11xui8>
+
+// -----
+
+// CHECK-LABEL: not_ui16
+func.func @not_ui16(%arg0: tensor<7x9x11xui16>) -> tensor<7x9x11xui16> {
+  %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui16>) -> (tensor<7x9x11xui16>)
+  func.return %0 : tensor<7x9x11xui16>
+}
+
+// CHECK: %cst = arith.constant dense<65535> : tensor<ui16>
+// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xui16>, tensor<ui16>) -> tensor<7x9x11xui16>
+
+// -----
+
+// CHECK-LABEL: not_ui32
+func.func @not_ui32(%arg0: tensor<7x9x11xui32>) -> tensor<7x9x11xui32> {
+  %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui32>) -> (tensor<7x9x11xui32>)
+  func.return %0 : tensor<7x9x11xui32>
+}
+
+// CHECK: %cst = arith.constant dense<4294967295> : tensor<ui32>
+// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xui32>, tensor<ui32>) -> tensor<7x9x11xui32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// binary ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: remainder
+func.func @remainder(%arg0: tensor<10x8xi32>, %arg1: tensor<10x8xi32>) -> tensor<10x8xi32> {
+  %0 = mhlo.remainder %arg0, %arg1 : tensor<10x8xi32>
+  func.return %0 : tensor<10x8xi32>
+}
+
+// CHECK: %0 = "tfl.floor_mod"(%arg0, %arg1) : (tensor<10x8xi32>, tensor<10x8xi32>) -> tensor<10x8xi32>
+
+// -----
+
+// CHECK-LABEL: shift_right_arith
+func.func @shift_right_arith(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
+  %0 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
+  func.return %0 : tensor<4xi32>
+}
+
+// CHECK: %0 = "tfl.right_shift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+
+// -----
+
+// CHECK-LABEL: shift_right_logical
+func.func @shift_right_logical(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
+  %0 = mhlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
+  func.return %0 : tensor<4xi32>
+}
+
+// CHECK: %0 = "tfl.right_shift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// mhlo.compare
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: greater_unsupported_compare_type
+func.func @greater_unsupported_compare_type(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xi1> {
+  %0 = "mhlo.compare"(%arg0, %arg1) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
+  func.return %0 : tensor<2xi1>
+}
+
+// CHECK-NOT: tfl
+// CHECK:     mhlo.compare
+
+// -----
+
+// CHECK-LABEL: equal
+func.func @equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  func.return %0 : tensor<2xi1>
+}
+
+// CHECK: tfl.equal
+
+// -----
+
+// CHECK-LABEL: notequal
+func.func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  func.return %0 : tensor<2xi1>
+}
+
+// CHECK: tfl.not_equal
+
+// -----
+
+// CHECK-LABEL: greater
+func.func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction GT>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  func.return %0 : tensor<2xi1>
+}
+
+// CHECK: tfl.greater
+
+// -----
+
+// CHECK-LABEL: greater_equal
+func.func @greater_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction GE>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  func.return %0 : tensor<2xi1>
+}
+
+// CHECK: tfl.greater_equal
+
+// -----
+
+// CHECK-LABEL: less
+func.func @less(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  func.return %0 : tensor<2xi1>
+}
+
+// CHECK: tfl.less
+
+// -----
+
+// CHECK-LABEL: less_equal
+func.func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  func.return %0 : tensor<2xi1>
+}
+
+// CHECK: tfl.less_equal
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD
index 19d9863..a277c0b 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD
@@ -21,7 +21,6 @@
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "//tensorflow/compiler/mlir/tensorflow",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
@@ -140,6 +139,7 @@
     srcs = ["conv_util.cc"],
     hdrs = ["conv_util.h"],
     deps = [
+        ":op_util_common",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:IR",
@@ -153,6 +153,7 @@
     srcs = ["pad.cc"],
     hdrs = ["pad.h"],
     deps = [
+        ":op_util_common",
         ":pad_util",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "@llvm-project//llvm:Support",
@@ -169,6 +170,7 @@
     srcs = ["pad_util.cc"],
     hdrs = ["pad_util.h"],
     deps = [
+        ":op_util_common",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Support",
@@ -191,3 +193,77 @@
         "@local_xla//xla/mlir_hlo",
     ],
 )
+
+cc_library(
+    name = "reduce_window",
+    srcs = ["reduce_window.cc"],
+    hdrs = ["reduce_window.h"],
+    deps = [
+        ":op_util_common",
+        ":reduce_window_util",
+        ":util",
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+        "@local_xla//xla/mlir_hlo",
+    ],
+)
+
+cc_library(
+    name = "op_util_common",
+    srcs = ["op_util_common.cc"],
+    hdrs = ["op_util_common.h"],
+    deps = [
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+    ],
+)
+
+cc_library(
+    name = "reduce_window_util",
+    srcs = ["reduce_window_util.cc"],
+    hdrs = ["reduce_window_util.h"],
+    deps = [
+        ":op_util_common",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@local_xla//xla/mlir_hlo",
+    ],
+)
+
+cc_library(
+    name = "slice",
+    srcs = ["slice.cc"],
+    hdrs = ["slice.h"],
+    deps = [
+        ":op_util_common",
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+        "@local_xla//xla/mlir_hlo",
+    ],
+)
+
+cc_library(
+    name = "sort",
+    srcs = ["sort.cc"],
+    hdrs = ["sort.h"],
+    deps = [
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "//tensorflow/compiler/mlir/lite/stablehlo:hlo_matchers",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+        "@local_xla//xla/mlir_hlo",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc
index 33b62b5..87a429d 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc
@@ -44,27 +44,27 @@
   return llvm::all_of(shape, [](int64_t d) { return d >= 0; });
 }
 
-bool AreShapesSupported(const ConvData& data) {
+bool AreShapesSupported(const ConvView& data) {
   return IsShapeFullyStatic(data.InputShape()) &&
          IsShapeFullyStatic(data.KernelShape()) &&
          IsShapeFullyStatic(data.OutputShape());
 }
 
-bool IsPaddingSupported(const ConvData& data) {
+bool IsPaddingSupported(const ConvView& data) {
   return llvm::all_of(data.Padding(), [](const DimPadding& p) {
     return p.Hi() == 0 && p.Lo() == 0;
   });
 }
 
-bool IsInputDilationSupported(const ConvData& data) {
+bool IsInputDilationSupported(const ConvView& data) {
   return llvm::all_of(data.InputDilations(), [](int64_t v) { return v == 1; });
 }
 
-bool IsBatchGroupSupported(const ConvData& data) {
+bool IsBatchGroupSupported(const ConvView& data) {
   return data.BatchGroupCount() == 1;
 }
 
-bool IsWindowReversalSupported(const ConvData& data) {
+bool IsWindowReversalSupported(const ConvView& data) {
   return llvm::all_of(data.WindowReversal(), [](bool b) { return !b; });
 }
 
@@ -72,7 +72,7 @@
 // Used externally to setup a ConversionTarget with dynamically legal
 // mhlo.convolution. Doubles as matching predicate during legalization.
 bool IsConvLegal(mhlo::ConvolutionOp op) {
-  const ConvData data(op);
+  const ConvView data(op);
 
   const bool supported_conv_type =
       IsStandardConv(data) || IsDepthwiseConv(data);
@@ -89,7 +89,7 @@
 
 // Bias is a zero tensor of shape [output_channels].
 arith::ConstantOp BuildEmptyBias(OpBuilder& b, Location loc,
-                                 const ConvData& data) {
+                                 const ConvView& data) {
   auto bias_type = RankedTensorType::get(
       {data.OutputLayout().SpecialDim2(data.OutputShape())},
       data.ElementType());
@@ -109,7 +109,7 @@
     mhlo::ConvolutionOp op, OpAdaptor adaptor,
     ConversionPatternRewriter& rewriter) const {
   // Parse mhlo.convolution attrs into cc types.
-  const ConvData data(op);
+  const ConvView data(op);
 
   if (IsConvLegal(op) || !IsStandardConv(data) ||
       data.InputLayout().Rank() != 4) {
@@ -168,7 +168,7 @@
     mhlo::ConvolutionOp op, OpAdaptor adaptor,
     ConversionPatternRewriter& rewriter) const {
   // Parse mhlo.convolution attrs into cc types.
-  const ConvData data(op);
+  const ConvView data(op);
 
   if (IsConvLegal(op) || !IsDepthwiseConv(data)) {
     return failure();
@@ -236,7 +236,7 @@
     mhlo::ConvolutionOp op, OpAdaptor adaptor,
     ConversionPatternRewriter& rewriter) const {
   // Parse mhlo.convolution attrs into cc types.
-  const ConvData data(op);
+  const ConvView data(op);
 
   if (IsConvLegal(op) || !IsStandardConv(data) ||
       data.InputLayout().Rank() != 5) {
@@ -324,7 +324,7 @@
 
 LogicalResult Conv1DToConv2D::matchAndRewrite(mhlo::ConvolutionOp op,
                                               PatternRewriter& rewriter) const {
-  const ConvData view(op);
+  const ConvView view(op);
 
   if (view.InputLayout().Rank() != 3) {
     return rewriter.notifyMatchFailure(op, "Not 1D conv.");
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc
index 7d8c85a..e2dff5f 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc
@@ -18,92 +18,16 @@
 #include <cstdint>
 #include <optional>
 
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 
 namespace mlir::odml {
 
-llvm::SmallVector<int64_t, 4> Layout::GetPermForReLayout(
-    const Layout& to_layout) const {
-  llvm::SmallVector<int64_t, 4> perm(to_layout.Rank());
-  perm[to_layout.SpecialDim1()] = SpecialDim1();
-  perm[to_layout.SpecialDim2()] = SpecialDim2();
-  for (const auto [to_spatial, from_spatial] :
-       llvm::zip(to_layout.Spatials(), Spatials())) {
-    perm[to_spatial] = from_spatial;
-  }
-  return perm;
-}
-
-llvm::SmallVector<int64_t, 4> Layout::PermuteShape(
-    const Layout& to_layout, llvm::ArrayRef<int64_t> shape) const {
-  llvm::SmallVector<int64_t, 4> new_shape(to_layout.Rank());
-  const auto perm = GetPermForReLayout(to_layout);
-  for (const auto [ind, val] : llvm::enumerate(perm)) {
-    new_shape[ind] = shape[val];
-  }
-  return new_shape;
-}
-
-bool Layout::HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const {
-  return SpecialDim1() == special_dim1 && SpecialDim2() == special_dim2;
-}
-
-bool Layout::AreSpatialsIota() const {
-  llvm::ArrayRef<int64_t> spatials = Spatials();
-  return llvm::all_of(llvm::enumerate(spatials), [&](const auto& it) {
-    return it.index() == 0 || (it.value() == spatials[it.index() - 1] + 1);
-  });
-}
-
-llvm::SmallVector<int64_t, 2> ResolveStridesOrDilations(
-    const int64_t num_spatials,
-    std::optional<mlir::DenseIntElementsAttr> opt_attr) {
-  if (!opt_attr.has_value()) {
-    return llvm::SmallVector<int64_t, 2>(num_spatials, 1);
-  }
-  auto attr = opt_attr.value();
-  if (attr.isSplat()) {
-    return llvm::SmallVector<int64_t, 2>(num_spatials,
-                                         attr.getSplatValue<int64_t>());
-  }
-  return llvm::SmallVector<int64_t, 2>(attr.getValues<int64_t>());
-}
-
-llvm::SmallVector<DimPadding, 2> ResolvePadding(
-    const int64_t num_spatials,
-    std::optional<mlir::DenseIntElementsAttr> opt_padding) {
-  llvm::SmallVector<DimPadding, 2> res;
-  if (!opt_padding.has_value()) {
-    for (int i = 0; i < num_spatials; ++i) {
-      res.push_back(DimPadding(0, 0));
-    }
-    return res;
-  }
-  auto padding = opt_padding.value();
-  if (padding.isSplat()) {
-    const int64_t val = padding.getSplatValue<int64_t>();
-    for (int i = 0; i < num_spatials; ++i) {
-      res.push_back(DimPadding(val, val));
-    }
-    return res;
-  }
-  int64_t prev;
-  for (const auto [ind, val] : llvm::enumerate(padding.getValues<int64_t>())) {
-    const int64_t side = ind % 2;
-    if (side == 1) {
-      res.push_back(DimPadding(prev, val));
-    }
-    prev = val;
-  }
-  return res;
-}
-
 llvm::SmallVector<bool, 2> ResolveWindowReversal(
     const int64_t num_spatials,
     std::optional<mlir::DenseElementsAttr> opt_reversals) {
@@ -118,7 +42,7 @@
   return llvm::SmallVector<bool, 2>(reversals.getValues<bool>());
 }
 
-ConvData::ConvData(mhlo::ConvolutionOp op)
+ConvView::ConvView(mhlo::ConvolutionOp op)
     : input_layout_(
           Layout{op.getDimensionNumbers().getInputBatchDimension(),
                  op.getDimensionNumbers().getInputFeatureDimension(),
@@ -156,7 +80,7 @@
 }
 
 Value CreatePadOpFromConvPadding(OpBuilder& b, mhlo::ConvolutionOp op) {
-  const ConvData data(op);
+  const ConvView data(op);
   const auto rank = data.InputLayout().Rank();
   auto input_spatials = data.InputLayout().Spatials();
 
@@ -185,8 +109,8 @@
       b.create<arith::ConstantOp>(op->getLoc(), padding_value_attr);
 
   auto pad_op = b.create<mhlo::PadOp>(padding_value_op->getLoc(), op.getLhs(),
-                                      padding_value_op, hi_padding_attr,
-                                      lo_padding_attr, interior_padding_attr);
+                                      padding_value_op, lo_padding_attr,
+                                      hi_padding_attr, interior_padding_attr);
 
   return pad_op;
 }
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h
index 63ebbce..d20ad08 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 
 // Helpers for working with mhlo.convolution attrs in the mlir api as
@@ -26,91 +27,7 @@
 
 namespace mlir::odml {
 
-// Generic class that wraps the "layout" of a convolution parameter.
-// Both kernel (e.g. [o, 0, 1, i]) and input/output (e.g. [b, 0, 1, f])
-// share the same structure just with different terminology for the
-// batch/feature/input_feature/output_feature dims.
-class Layout {
- public:
-  llvm::ArrayRef<int64_t> Spatials() const { return spatials_; }
-
-  int64_t NumSpatials() const { return spatials_.size(); }
-
-  int64_t Rank() const { return NumSpatials() + 2; }
-
-  Layout(int64_t special_dim1, int64_t special_dim2, ArrayRef<int64_t> spatials)
-      : special_dim1_(special_dim1),
-        special_dim2_(special_dim2),
-        spatials_(spatials) {}
-
-  // Gets index of first special dim. The batch dim for input and outputs,
-  // or the output feature dim for the kernel.
-  int64_t SpecialDim1() const { return special_dim1_; }
-
-  // Conveniance accesor for getting the dimension size of the first
-  // special dimension from a shape.
-  int64_t SpecialDim1(llvm::ArrayRef<int64_t> shape) const {
-    return shape[special_dim1_];
-  }
-
-  // Gets index of second special dim. The feature dim for input and outputs,
-  // or the input feature dim for the kernel.
-  int64_t SpecialDim2() const { return special_dim2_; }
-
-  // Convenience accesor for getting the dimension size of the second
-  // special dimension from a shape.
-  int64_t SpecialDim2(llvm::ArrayRef<int64_t> shape) const {
-    return shape[special_dim2_];
-  }
-
-  // Conveniance method for equality checking special dims.
-  bool HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const;
-
-  // Determines if the spatial dimensions are all adjacent and in
-  // ascending order (HWD).
-  bool AreSpatialsIota() const;
-
-  // Gets a "permutation array" to be used for transposing a tensor
-  // of "this" layout to the given layout. A permutation array is some
-  // permutation of [0, 1, i...] for i < rank(layout). Assumes
-  // "this" and given layout have the same rank.
-  llvm::SmallVector<int64_t, 4> GetPermForReLayout(
-      const Layout& to_layout) const;
-
-  // Permutes given shape based on the permutaion implied to take this Layout to
-  // the given one.
-  llvm::SmallVector<int64_t, 4> PermuteShape(const Layout& to_layout,
-                                             ArrayRef<int64_t> shape) const;
-
-  bool operator==(const Layout& other) const {
-    return SpecialDim1() == other.SpecialDim1() &&
-           SpecialDim2() == other.SpecialDim2() &&
-           Spatials() == other.Spatials();
-  }
-
-  bool operator!=(const Layout& other) const { return !(*this == other); }
-
- private:
-  int64_t special_dim1_;
-  int64_t special_dim2_;
-  llvm::SmallVector<int64_t> spatials_;
-};
-
-// Wrapper for the padding attrs along a single dimension.
-class DimPadding {
- public:
-  int64_t Hi() const { return hi_; }
-
-  int64_t Lo() const { return lo_; }
-
-  DimPadding(int64_t hi, int64_t lo) : hi_(hi), lo_(lo) {}
-
- private:
-  int64_t hi_;
-  int64_t lo_;
-};
-
-class ConvData {
+class ConvView {
  public:
   // int for each spatial dim. Default 1.
   llvm::ArrayRef<int64_t> Strides() const { return strides_; }
@@ -145,7 +62,7 @@
 
   mlir::Type ElementType() const { return element_type_; }
 
-  explicit ConvData(mhlo::ConvolutionOp op);
+  explicit ConvView(mhlo::ConvolutionOp op);
 
  private:
   llvm::SmallVector<int64_t, 2> strides_;
@@ -171,11 +88,11 @@
   mlir::Type element_type_;
 };
 
-inline bool HasSupportedRank(const ConvData& data) {
+inline bool HasSupportedRank(const ConvView& data) {
   return data.InputLayout().Rank() == 4 || data.InputLayout().Rank() == 5;
 }
 
-inline bool HasSupportedOutFeatureDims(const ConvData& data) {
+inline bool HasSupportedOutFeatureDims(const ConvView& data) {
   const int64_t kernel_out_features =
       data.KernelLayout().SpecialDim2(data.KernelShape());
   const int64_t out_features =
@@ -183,7 +100,7 @@
   return kernel_out_features == out_features;
 }
 
-inline bool IsNonTrivialConv(const ConvData& data) {
+inline bool IsNonTrivialConv(const ConvView& data) {
   return llvm::all_of(data.InputDilations(), [](auto d) { return d == 1; });
 }
 
@@ -191,7 +108,7 @@
 // Standard conv predicates
 //=-----
 
-inline bool HasStandardConvInFeatureDims(const ConvData& data) {
+inline bool HasStandardConvInFeatureDims(const ConvView& data) {
   // kernel_in_features * feature_groups = input_features by definition.
   const int64_t input_features =
       data.InputLayout().SpecialDim2(data.InputShape());
@@ -204,7 +121,7 @@
   return !trivial_kernel_in_features && (!is_grouped_conv || rank == 4);
 }
 
-inline bool IsStandardConv(const ConvData& data) {
+inline bool IsStandardConv(const ConvView& data) {
   return HasSupportedRank(data) && IsNonTrivialConv(data) &&
          HasStandardConvInFeatureDims(data) && HasSupportedOutFeatureDims(data);
 }
@@ -212,7 +129,7 @@
 // Does this convolution map to a standard conv_2d or conv_3d
 // (not depthwise or tranpose conv)?
 inline bool IsStandardConv(mhlo::ConvolutionOp op) {
-  const ConvData data(op);
+  const ConvView data(op);
   return IsStandardConv(data);
 }
 
@@ -220,7 +137,7 @@
 // Depthwise conv predicates
 //=-----
 
-inline bool IsDepthwiseConv(const ConvData& data) {
+inline bool IsDepthwiseConv(const ConvView& data) {
   const bool valid_rank = data.InputLayout().Rank() == 4;
   if (!valid_rank || !HasSupportedOutFeatureDims(data) ||
       !IsNonTrivialConv(data)) {
@@ -233,7 +150,7 @@
 
 // Does this convolution map to depthwise conv?
 inline bool IsDepthwiseConv(mhlo::ConvolutionOp op) {
-  const ConvData data(op);
+  const ConvView data(op);
   return IsDepthwiseConv(data);
 }
 
@@ -273,7 +190,7 @@
   return GetTFLNativeStandardConvKernelLayout(DnumRank(dnums));
 }
 
-inline bool IsTFLNativeLayout(const ConvData& data) {
+inline bool IsTFLNativeLayout(const ConvView& data) {
   const int64_t rank = data.KernelLayout().Rank();
   const auto native_io_layout = GetTFLNativeInputOrOutputLayout(rank);
 
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.cc
new file mode 100644
index 0000000..3d67bbf
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.cc
@@ -0,0 +1,111 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
+
+#include <cstdint>
+#include <optional>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+
+namespace mlir::odml {
+
+llvm::SmallVector<int64_t, 4> Layout::GetPermForReLayout(
+    const Layout& to_layout) const {
+  llvm::SmallVector<int64_t, 4> perm(to_layout.Rank());
+  perm[to_layout.SpecialDim1()] = SpecialDim1();
+  perm[to_layout.SpecialDim2()] = SpecialDim2();
+  for (const auto [to_spatial, from_spatial] :
+       llvm::zip(to_layout.Spatials(), Spatials())) {
+    perm[to_spatial] = from_spatial;
+  }
+  return perm;
+}
+
+llvm::SmallVector<int64_t, 4> Layout::PermuteShape(
+    const Layout& to_layout, llvm::ArrayRef<int64_t> shape) const {
+  llvm::SmallVector<int64_t, 4> new_shape(to_layout.Rank());
+  const auto perm = GetPermForReLayout(to_layout);
+  for (const auto [ind, val] : llvm::enumerate(perm)) {
+    new_shape[ind] = shape[val];
+  }
+  return new_shape;
+}
+
+bool Layout::HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const {
+  return SpecialDim1() == special_dim1 && SpecialDim2() == special_dim2;
+}
+
+bool Layout::AreSpatialsIota() const {
+  llvm::ArrayRef<int64_t> spatials = Spatials();
+  return llvm::all_of(llvm::enumerate(spatials), [&](const auto& it) {
+    return it.index() == 0 || (it.value() == spatials[it.index() - 1] + 1);
+  });
+}
+
+llvm::SmallVector<int64_t, 4> ResolveStridesOrDilations(
+    int64_t rank, std::optional<mlir::DenseIntElementsAttr> opt_attr) {
+  if (!opt_attr.has_value()) {
+    return llvm::SmallVector<int64_t, 4>(rank, 1);
+  }
+  auto attr = opt_attr.value();
+  if (attr.isSplat()) {
+    return llvm::SmallVector<int64_t, 4>(rank, attr.getSplatValue<int64_t>());
+  }
+  return llvm::SmallVector<int64_t, 4>(attr.getValues<int64_t>());
+}
+
+llvm::SmallVector<DimPadding, 2> ResolvePadding(
+    int64_t rank, std::optional<mlir::DenseIntElementsAttr> opt_padding) {
+  llvm::SmallVector<DimPadding, 4> res;
+  if (!opt_padding.has_value()) {
+    for (int i = 0; i < rank; ++i) {
+      res.push_back(DimPadding(0, 0));
+    }
+    return res;
+  }
+  auto padding = opt_padding.value();
+  if (padding.isSplat()) {
+    const int64_t val = padding.getSplatValue<int64_t>();
+    for (int i = 0; i < rank; ++i) {
+      res.push_back(DimPadding(val, val));
+    }
+    return res;
+  }
+  int64_t prev;
+  for (const auto [ind, val] : llvm::enumerate(padding.getValues<int64_t>())) {
+    const int64_t side = ind % 2;
+    if (side == 1) {
+      res.push_back(DimPadding(prev, val));
+    }
+    prev = val;
+  }
+  return res;
+}
+
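+// Illustrative example: in = 5, k = 3, stride = 2, dilate = 1 gives
+// out = ceil(5 / 2) = 3 and effective_filter = 3, so "SAME" holds when
+// (3 - 1) * 2 + 3 == 5 + pad_total, i.e. pad_total == 2 (lo = 1, hi = 1).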
+bool IsSamePaddingOnDim(int64_t in, int64_t dilate, int64_t stride, int64_t k,
+                        const DimPadding& pad) {
+  const int64_t pad_diff = pad.Hi() - pad.Lo();
+  if (pad_diff > 1 || pad_diff < 0) {
+    return false;
+  }
+  const int64_t pad_total = pad.Lo() + pad.Hi();
+  const int64_t out = (in + stride - 1) / stride;
+  const int effective_filter = (k - 1) * dilate + 1;
+  return ((out - 1) * stride + effective_filter) == in + pad_total;
+}
+
+}  // namespace mlir::odml
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h
new file mode 100644
index 0000000..e3f8941
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h
@@ -0,0 +1,139 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_
+
+#include <cstdint>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+
+namespace mlir::odml {
+
+// Class that encodes the "layout" of a tensor. Layouts, generically
+// are some naming of the dimensions of a tensor. In all cases, 2 dimensions
+// are "special" (e.g. batch / feature) and the rest are referred to as "spatial
+// dims". When the special dims are batch and feature, batch is special dim 1
+// and feature is special dim 2. When special dims are input and output features
+// (conv filter), input features is special dim 1 and output features is special
+// dim 2.
+class Layout {
+ public:
+  llvm::ArrayRef<int64_t> Spatials() const { return spatials_; }
+
+  int64_t NumSpatials() const { return spatials_.size(); }
+
+  int64_t Rank() const { return NumSpatials() + 2; }
+
+  Layout(int64_t special_dim1, int64_t special_dim2, ArrayRef<int64_t> spatials)
+      : special_dim1_(special_dim1),
+        special_dim2_(special_dim2),
+        spatials_(spatials) {}
+
+  // TODO: b/351437662 - Consider just using 2 arrays for the case where
+  // there are more than 2 special dims.
+  int64_t SpecialDim1() const { return special_dim1_; }
+
+  // Convenience accessor for getting the dimension size of the first
+  // special dimension from a shape.
+  int64_t SpecialDim1(llvm::ArrayRef<int64_t> shape) const {
+    return shape[special_dim1_];
+  }
+
+  int64_t SpecialDim2() const { return special_dim2_; }
+
+  // Convenience accessor for getting the dimension size of the second
+  // special dimension from a shape.
+  int64_t SpecialDim2(llvm::ArrayRef<int64_t> shape) const {
+    return shape[special_dim2_];
+  }
+
+  // Convenience method for checking equality of the special dims.
+  bool HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const;
+
+  // Determines if the spatial dimensions are all adjacent and in
+  // ascending order.
+  bool AreSpatialsIota() const;
+
+  // Gets a "permutation array" to be used for transposing a tensor
+  // of "this" layout to the given layout. A permutation array is some
+  // permutation of [0, 1, i...] for i < rank(layout). Assumes
+  // "this" and given layout have the same rank.
+  llvm::SmallVector<int64_t, 4> GetPermForReLayout(
+      const Layout& to_layout) const;
+
+  // Permutes the given shape based on the permutation implied by relaying
+  // out this Layout to the given one.
+  llvm::SmallVector<int64_t, 4> PermuteShape(const Layout& to_layout,
+                                             ArrayRef<int64_t> shape) const;
+
+  bool operator==(const Layout& other) const {
+    return SpecialDim1() == other.SpecialDim1() &&
+           SpecialDim2() == other.SpecialDim2() &&
+           Spatials() == other.Spatials();
+  }
+
+  bool operator!=(const Layout& other) const { return !(*this == other); }
+
+ private:
+  int64_t special_dim1_;
+  int64_t special_dim2_;
+  llvm::SmallVector<int64_t> spatials_;
+};
+
+// Wrapper for the padding attrs along a single dimension.
+class DimPadding {
+ public:
+  int64_t Hi() const { return hi_; }
+
+  int64_t Lo() const { return lo_; }
+
+  bool Trivial() const { return Hi() == 0 && Lo() == 0; }
+
+  DimPadding(int64_t lo, int64_t hi) : lo_(lo), hi_(hi) {}
+
+ private:
+  int64_t lo_;
+  int64_t hi_;
+};
+
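+// Materializes an i64 elements attr as a vector, expanding splats to their
+// full length (e.g. a splat of 2 over 3 elements becomes [2, 2, 2]).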
+inline llvm::SmallVector<int64_t> UnrollI64Splat(DenseElementsAttr data) {
+  if (!data.isSplat()) {
+    return llvm::SmallVector<int64_t>(data.getValues<int64_t>());
+  }
+  return llvm::SmallVector<int64_t>(data.getType().getNumElements(),
+                                    data.getSplatValue<int64_t>());
+}
+
+// Resolves an optional strides or dilations attribute. If not present,
+// returns a trivial all-1's vector.
+llvm::SmallVector<int64_t, 4> ResolveStridesOrDilations(
+    int64_t rank, std::optional<mlir::DenseIntElementsAttr> opt_attr);
+
+// Resolves an optional padding attribute. If not present, returns
+// trivial [0, 0] padding on each dim.
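+// The flat attribute is read as interleaved (lo, hi) pairs, one per dim.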
+llvm::SmallVector<DimPadding, 2> ResolvePadding(
+    int64_t rank, std::optional<mlir::DenseIntElementsAttr> opt_padding);
+
+// Does the padding correspond to "SAME" on the given dimension configuration?
+// Assumes the given dimension configuration is well formed.
+bool IsSamePaddingOnDim(int64_t in, int64_t dilate, int64_t stride, int64_t k,
+                        const DimPadding& pad);
+
+}  // namespace mlir::odml
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc
index ac5e3ff..c25f27a 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc
@@ -28,6 +28,7 @@
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"  // IWYU pragma: keep
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 
@@ -49,8 +50,8 @@
 }
 
 DenseIntElementsAttr BuildTFLPaddingAttr(OpBuilder& b, mhlo::PadOp op) {
-  auto lows = UnrollSplat(op.getEdgePaddingLow());
-  auto highs = UnrollSplat(op.getEdgePaddingHigh());
+  auto lows = UnrollI64Splat(op.getEdgePaddingLow());
+  auto highs = UnrollI64Splat(op.getEdgePaddingHigh());
 
   llvm::SmallVector<int64_t> res;
   for (auto [l, h] : llvm::zip(lows, highs)) {
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.cc
index 859cdfe..cb004d3b 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.cc
@@ -20,6 +20,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 
 namespace mlir::odml {
@@ -28,16 +29,8 @@
   return op.getEdgePaddingLow().getType();
 }
 
-llvm::SmallVector<int64_t> UnrollSplat(DenseElementsAttr data) {
-  if (!data.isSplat()) {
-    return llvm::SmallVector<int64_t>(data.getValues<int64_t>());
-  }
-  return llvm::SmallVector<int64_t>(data.getType().getNumElements(),
-                                    data.getSplatValue<int64_t>());
-}
-
 DenseIntElementsAttr SliceStartFromNegPadLows(mhlo::PadOp op) {
-  auto vals = UnrollSplat(op.getEdgePaddingLow());
+  auto vals = UnrollI64Splat(op.getEdgePaddingLow());
   auto starts = llvm::map_range(
       vals, [](auto v) -> int64_t { return (v >= 0) ? 0 : -1 * v; });
   return DenseIntElementsAttr::get(GetPaddingAttrType(op),
@@ -45,7 +38,7 @@
 }
 
 DenseIntElementsAttr SliceEndFromNegPadHighs(mhlo::PadOp op) {
-  auto vals = UnrollSplat(op.getEdgePaddingHigh());
+  auto vals = UnrollI64Splat(op.getEdgePaddingHigh());
   auto zip = llvm::zip(vals, op.getOperand().getType().getShape());
   auto ends = llvm::map_range(zip, [](auto it) -> int64_t {
     return (std::get<0>(it) >= 0) ? std::get<1>(it)
@@ -56,7 +49,7 @@
 }
 
 DenseIntElementsAttr ReplaceNegsWithZero(DenseElementsAttr data) {
-  auto vals = UnrollSplat(data);
+  auto vals = UnrollI64Splat(data);
   auto res =
       llvm::map_range(vals, [](auto v) -> int64_t { return (v < 0) ? 0 : v; });
   return DenseIntElementsAttr::get(data.getType(), llvm::to_vector(res));
@@ -64,8 +57,8 @@
 
 bool AnyNegativePads(mhlo::PadOp op) {
   auto is_neg = [](int64_t v) { return v < 0; };
-  auto lows_data = UnrollSplat(op.getEdgePaddingLow());
-  auto highs_data = UnrollSplat(op.getEdgePaddingHigh());
+  auto lows_data = UnrollI64Splat(op.getEdgePaddingLow());
+  auto highs_data = UnrollI64Splat(op.getEdgePaddingHigh());
   return llvm::any_of(lows_data, is_neg) || llvm::any_of(highs_data, is_neg);
 }
 
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h
index aa0428f..5041903 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h
@@ -21,8 +21,6 @@
 
 namespace mlir::odml {
 
-llvm::SmallVector<int64_t> UnrollSplat(DenseElementsAttr data);
-
 // Gets elements corresponding to slice starts from negative padding
 // values.
 DenseIntElementsAttr SliceStartFromNegPadLows(mhlo::PadOp op);
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc
new file mode 100644
index 0000000..a00ee33
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc
@@ -0,0 +1,742 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h"
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <tuple>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/IRMapping.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"  // IWYU pragma: keep
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace mlir::odml {
+namespace {
+
+// filter_h, filter_w, stride_h, stride_w, padding, fused activation (faf).
+using TFLPoolAttrsT = std::tuple<IntegerAttr, IntegerAttr, IntegerAttr,
+                                 IntegerAttr, StringAttr, StringAttr>;
+
+bool AreDilationsSupported(const ReduceWindowView& op) {
+  auto is_one = [](int64_t v) { return v == 1; };
+  return llvm::all_of(op.BaseDilations(), is_one) &&
+         llvm::all_of(op.WindowDilations(), is_one);
+}
+
+bool IsRankSupported(const ReduceWindowView& op) { return op.Rank() == 4; }
+
+std::optional<std::tuple<ReduceWindowView, Layout>> GetViewIfAttrsSupported(
+    mhlo::ReduceWindowOp op) {
+  const ReduceWindowView view(op);
+
+  if (!IsRankSupported(view)) {
+    return std::nullopt;
+  }
+
+  if (!AreDilationsSupported(view)) {
+    return std::nullopt;
+  }
+
+  auto opt_layout = view.GuessLayout();
+  if (!opt_layout.has_value()) {
+    return std::nullopt;
+  }
+  auto layout = opt_layout.value();
+
+  const int64_t batch = layout.SpecialDim1();
+  if (!view.Paddings()[batch].Trivial()) {
+    return std::nullopt;
+  }
+
+  const int64_t chan = layout.SpecialDim2();
+  if (!view.Paddings()[chan].Trivial()) {
+    return std::nullopt;
+  }
+
+  return std::tuple(view, layout);
+}
+
+std::optional<bool> IsReduceWindowLegal(mhlo::ReduceWindowOp op) {
+  return std::nullopt;
+}
+
+std::optional<bool> IsDivideLegal(mhlo::DivOp op) { return std::nullopt; }
+
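+// The layout expected by TFL pooling ops: batch first, channel last, with
+// the spatial dims in between (NHWC when rank == 4).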
+Layout TFLNativePoolingLayout(int64_t rank) {
+  return Layout(0, rank - 1, llvm::to_vector(llvm::seq<int64_t>(1, rank - 1)));
+}
+
+bool IsCstFloatZero(Value val) {
+  DenseFPElementsAttr initial_value;
+  return matchPattern(val, m_Constant(&initial_value)) &&
+         initial_value.getNumElements() == 1 &&
+         initial_value.getValues<APFloat>()[0].isZero();
+}
+
+bool IsCstIntZero(Value val) {
+  DenseIntElementsAttr initial_value;
+  return matchPattern(val, m_Constant(&initial_value)) &&
+         initial_value.getNumElements() == 1 &&
+         initial_value.getValues<APInt>()[0].isZero();
+}
+
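+// Gathers `data` according to `perm`, i.e. res[i] = data[perm[i]].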
+llvm::SmallVector<int64_t> Permute(llvm::ArrayRef<int64_t> data,
+                                   llvm::ArrayRef<int64_t> perm) {
+  llvm::SmallVector<int64_t> res(data.size());
+  for (int i = 0; i < data.size(); ++i) {
+    res[i] = data[perm[i]];
+  }
+  return res;
+}
+
+Value TransposeTensor(OpBuilder& b, Value tensor,
+                      llvm::SmallVector<int64_t> perm) {
+  const int64_t perm_size = perm.size();
+  auto perm_attr_type = RankedTensorType::get({perm_size}, b.getI64Type());
+  auto perm_attr = DenseIntElementsAttr::get(perm_attr_type, perm);
+  return b.create<mhlo::TransposeOp>(tensor.getLoc(), tensor, perm_attr);
+}
+
+DenseIntElementsAttr BuildDenseI64(OpBuilder& b, ArrayRef<int64_t> shape,
+                                   ArrayRef<int64_t> data) {
+  return DenseIntElementsAttr::get(RankedTensorType::get(shape, b.getI64Type()),
+                                   data);
+}
+
+DenseIntElementsAttr BuildDenseI64(OpBuilder& b, ArrayRef<int64_t> data) {
+  const int64_t dim = data.size();
+  return BuildDenseI64(b, {dim}, data);
+}
+
+std::optional<std::tuple<Value, Value>> GetInputAndInitIfValid(
+    mhlo::ReduceWindowOp op) {
+  if (op->getNumResults() != 1) {
+    return std::nullopt;
+  }
+  if (op.getInputs().size() > 1) {
+    return std::nullopt;
+  }
+  if (op.getInitValues().size() > 1) {
+    return std::nullopt;
+  }
+  auto init_val = op.getInitValues().front();
+  if (llvm::dyn_cast<ShapedType>(init_val.getType()).getNumElements() != 1) {
+    return std::nullopt;
+  }
+  return std::tuple(op.getInputs().front(), op.getInitValues().front());
+}
+
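+// Maps the given paddings onto the TFL "SAME"/"VALID" padding strings,
+// considering only the spatial dims (1 .. rank - 2). Returns std::nullopt
+// if the padding is neither.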
+std::optional<std::string> GetTFLPadding(ArrayRef<DimPadding> paddings,
+                                         ArrayRef<int64_t> window_strides,
+                                         ArrayRef<int64_t> in_shape,
+                                         ArrayRef<int64_t> window_dims) {
+  const int64_t rank = paddings.size();
+  std::string tfl_padding = "VALID";
+  for (int i = 1; i < rank - 1; ++i) {
+    const auto& dim_pad = paddings[i];
+    if (dim_pad.Trivial()) {
+      continue;
+    }
+    if (!IsSamePaddingOnDim(in_shape[i], 1, window_strides[i], window_dims[i],
+                            dim_pad)) {
+      return std::nullopt;
+    }
+    tfl_padding = "SAME";
+  }
+  return tfl_padding;
+}
+
+TFLPoolAttrsT BuildTFLPoolAttrs(OpBuilder& b, const ReduceWindowView& view,
+                                StringRef padding) {
+  const int32_t filter_h = view.WindowDims()[1];
+  auto filter_h_attr = b.getI32IntegerAttr(filter_h);
+
+  const int32_t filter_w = view.WindowDims()[2];
+  auto filter_w_attr = b.getI32IntegerAttr(filter_w);
+
+  const int32_t stride_h = view.WindowStrides()[1];
+  auto stride_h_attr = b.getI32IntegerAttr(stride_h);
+
+  const int32_t stride_w = view.WindowStrides()[2];
+  auto stride_w_attr = b.getI32IntegerAttr(stride_w);
+
+  auto padding_attr = b.getStringAttr(padding);
+  auto faf_attr = b.getStringAttr("NONE");
+
+  return std::tuple(filter_h_attr, filter_w_attr, stride_h_attr, stride_w_attr,
+                    padding_attr, faf_attr);
+}
+
+//===------------------------------------------------------------------------===
+// relayout reduce_window to channel last
+//===------------------------------------------------------------------------===
+
+class RelayoutReduceWindow : public OpRewritePattern<mhlo::ReduceWindowOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(mhlo::ReduceWindowOp op,
+                                PatternRewriter& rewriter) const final;
+};
+
+LogicalResult RelayoutReduceWindow::matchAndRewrite(
+    mhlo::ReduceWindowOp op, PatternRewriter& rewriter) const {
+  //
+  // check and parse attributes
+  //=-----
+
+  auto opt_view = GetViewIfAttrsSupported(op);
+  if (!opt_view.has_value()) {
+    return rewriter.notifyMatchFailure(
+        op, "Reduce window attributes not supported.");
+  }
+  const auto [view, layout] = opt_view.value();
+
+  //
+  // get the input and init value if there is exactly one of each
+  //=-----
+
+  auto opt_input_and_init = GetInputAndInitIfValid(op);
+  if (!opt_input_and_init.has_value()) {
+    return rewriter.notifyMatchFailure(
+        op, "Reduce window has wrong number of inputs or init values.");
+  }
+  auto [input, init_val] = opt_input_and_init.value();
+
+  //
+  // figure out permutations for layout change
+  //=-----
+
+  const auto target_layout = TFLNativePoolingLayout(view.Rank());
+  if (layout == target_layout) {
+    return rewriter.notifyMatchFailure(
+        op, "Reduce window does not need layout change");
+  }
+
+  llvm::SmallVector<int64_t> perm_for_inputs =
+      layout.GetPermForReLayout(target_layout);
+
+  //
+  // permute layout sensitive attrs
+  //=-----
+
+  // permute paddings
+  auto paddings = view.Paddings();
+  llvm::SmallVector<int64_t> new_paddings(paddings.size() * 2);
+  for (int i = 0; i < new_paddings.size() / 2; ++i) {
+    const auto& dim_pad = paddings[perm_for_inputs[i]];
+    new_paddings[2 * i] = dim_pad.Lo();
+    new_paddings[2 * i + 1] = dim_pad.Hi();
+  }
+  const int64_t new_paddings_size = paddings.size();
+  auto new_paddings_type =
+      RankedTensorType::get({new_paddings_size, 2}, rewriter.getI64Type());
+  auto new_paddings_attr =
+      DenseIntElementsAttr::get(new_paddings_type, new_paddings);
+
+  // permute window dims
+  llvm::SmallVector<int64_t> new_window_dims =
+      Permute(view.WindowDims(), perm_for_inputs);
+  auto new_window_dims_attr = BuildDenseI64(rewriter, new_window_dims);
+
+  // permute window strides
+  llvm::SmallVector<int64_t> new_window_strides =
+      Permute(view.WindowStrides(), perm_for_inputs);
+  auto new_window_strides_attr = BuildDenseI64(rewriter, new_window_strides);
+
+  //
+  // permute params and build new op
+  //=-----
+
+  // figure out permuted result type
+  llvm::SmallVector<int64_t> perm_for_outputs =
+      target_layout.GetPermForReLayout(layout);
+  auto cur_out_type = llvm::dyn_cast<ShapedType>(op.getResult(0).getType());
+  llvm::SmallVector<int64_t> new_rw_out_shape =
+      layout.PermuteShape(target_layout, cur_out_type.getShape());
+  auto new_out_type = cur_out_type.clone(new_rw_out_shape);
+
+  // transpose input and build new reduce_window
+  auto new_input = TransposeTensor(rewriter, input, perm_for_inputs);
+  auto new_rw = rewriter.create<mhlo::ReduceWindowOp>(
+      op.getLoc(), new_out_type, new_input, init_val, new_window_dims_attr,
+      new_window_strides_attr, BuildDenseI64(rewriter, view.BaseDilations()),
+      BuildDenseI64(rewriter, view.WindowDilations()), new_paddings_attr);
+  IRMapping ir_map;
+  op.getBody().cloneInto(&new_rw.getBody(), ir_map);
+
+  // transpose output and update ir
+  auto new_output =
+      TransposeTensor(rewriter, new_rw.getResult(0), perm_for_outputs);
+  rewriter.replaceOp(op, new_output);
+
+  return success();
+}
+
+//===------------------------------------------------------------------------===
+// mhlo.reduce_window -> tfl.cum_sum
+//===------------------------------------------------------------------------===
+
+class LegalizeCumSum : public OpConversionPattern<mhlo::ReduceWindowOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::ReduceWindowOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const final;
+};
+
+LogicalResult LegalizeCumSum::matchAndRewrite(
+    mhlo::ReduceWindowOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter) const {
+  //
+  // check singular params and trivial attrs
+  //=-----
+
+  auto opt_input_init = GetInputAndInitIfValid(op);
+  if (!opt_input_init.has_value()) {
+    return rewriter.notifyMatchFailure(op,
+                                       "Must have 1 input, init and result.");
+  }
+  auto [input, init] = opt_input_init.value();
+
+  if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(op.getBody()))) {
+    return rewriter.notifyMatchFailure(op, "Requires scalar add in region.");
+  }
+
+  if (!IsCstFloatZero(init) && !IsCstIntZero(init)) {
+    return rewriter.notifyMatchFailure(op, "Requires 0 for init value.");
+  }
+
+  const ReduceWindowView view(op);
+
+  auto trivial = [](int64_t v) { return v == 1; };
+  const bool trivial_window_dilate =
+      llvm::all_of(view.WindowDilations(), trivial);
+  const bool trivial_base_dilate = llvm::all_of(view.BaseDilations(), trivial);
+  const bool trivial_stride = llvm::all_of(view.WindowStrides(), trivial);
+  if (!trivial_window_dilate || !trivial_stride || !trivial_base_dilate) {
+    return rewriter.notifyMatchFailure(
+        op, "Requires trivial strides and dilations attributes.");
+  }
+
+  //
+  // figure out the implicit axis of reduction
+  //=-----
+
+  auto input_type = llvm::cast<ShapedType>(input.getType());
+  if (view.WindowDims().size() != input_type.getRank()) {
+    return rewriter.notifyMatchFailure(op, "Splat window dims not supported.");
+  }
+  int64_t axis = -1;
+  for (auto [ind, val] : llvm::enumerate(view.WindowDims())) {
+    if (val == 1) {
+      continue;
+    }
+
+    if (axis != -1) {
+      return rewriter.notifyMatchFailure(op, "Multiple non 1 dimensions.");
+    }
+
+    if (val != input_type.getShape()[ind]) {
+      return rewriter.notifyMatchFailure(
+          op, "Axis dimension requires size be same as input shape's.");
+    }
+    axis = ind;
+  }
+
+  if (axis == -1) {
+    return rewriter.notifyMatchFailure(op, "Could not identify axis.");
+  }
+
+  const int64_t axis_size = input_type.getShape()[axis];
+
+  //
+  // validate padding is [N-1, 0] on axis and zero elsewhere
+  //=-----
+
+  for (const auto& [ind, dim_pad] : llvm::enumerate(view.Paddings())) {
+    if (dim_pad.Hi() != 0) {
+      return rewriter.notifyMatchFailure(op, "Has non trivial high padding.");
+    }
+
+    if (ind != axis) {
+      if (!dim_pad.Trivial()) {
+        return rewriter.notifyMatchFailure(
+            op, "Has non trivial padding on non axis dim.");
+      }
+    } else {
+      if (dim_pad.Lo() != axis_size - 1) {
+        return rewriter.notifyMatchFailure(
+            op, "Requires low padding on axis dim to be N - 1.");
+      }
+    }
+  }
+
+  //
+  // build axis constant and tfl op
+  //=-----
+
+  auto axis_cst_attr = DenseIntElementsAttr::get(
+      RankedTensorType::get({}, rewriter.getI32Type()),
+      static_cast<int32_t>(axis));
+  auto axis_cst =
+      rewriter.create<arith::ConstantOp>(op->getLoc(), axis_cst_attr);
+
+  auto tfl_exclusive_attr = rewriter.getBoolAttr(false);
+  auto tfl_reverse_attr = rewriter.getBoolAttr(false);
+
+  rewriter.replaceOpWithNewOp<TFL::CumsumOp>(op, op->getResultTypes()[0], input,
+                                             axis_cst, tfl_exclusive_attr,
+                                             tfl_reverse_attr);
+
+  return success();
+}
+
+//===------------------------------------------------------------------------===
+// mhlo.reduce_window -> tfl.max_pool
+//===------------------------------------------------------------------------===
+
+bool isFloatMinusInfinity(Value value) {
+  DenseFPElementsAttr float_value;
+  if (!matchPattern(value, m_Constant(&float_value))) {
+    return false;
+  }
+  if (float_value.getNumElements() != 1) {
+    return false;
+  }
+  APFloat element = float_value.getValues<APFloat>()[0];
+  return element.isInfinity() && element.isNegative();
+}
+
+class LegalizeMaxPool : public OpConversionPattern<mhlo::ReduceWindowOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      mhlo::ReduceWindowOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const final;
+};
+
+LogicalResult LegalizeMaxPool::matchAndRewrite(
+    mhlo::ReduceWindowOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter) const {
+  //
+  // parse and validate lhs reduce window
+  //=-----
+
+  const auto opt_view = GetViewIfAttrsSupported(op);
+  if (!opt_view.has_value()) {
+    return rewriter.notifyMatchFailure(op, "Reduce window is not valid.");
+  }
+  const auto [view, layout] = opt_view.value();
+  if (layout != TFLNativePoolingLayout(layout.Rank())) {
+    return rewriter.notifyMatchFailure(op, "Not tfl standard layout.");
+  }
+
+  // Check that the reduce-window is a max-reduce-window.
+  if (failed(MatchBinaryReduceFunction<mhlo::MaxOp>(op.getBody()))) {
+    return rewriter.notifyMatchFailure(op, "Must be a max pool.");
+  }
+
+  auto type = mlir::dyn_cast<ShapedType>(op.getResult(0).getType());
+  if (!mlir::isa<FloatType>(type.getElementType())) {
+    return rewriter.notifyMatchFailure(op, "Not a floating point pool.");
+  }
+
+  //
+  // validate inputs and init
+  //=-----
+
+  auto opt_inputs_and_init = GetInputAndInitIfValid(op);
+  if (!opt_inputs_and_init.has_value()) {
+    return rewriter.notifyMatchFailure(op, "Too many inputs or inits.");
+  }
+  auto [input, init] = opt_inputs_and_init.value();
+  auto input_type = llvm::dyn_cast<ShapedType>(input.getType());
+
+  if (!isFloatMinusInfinity(init)) {
+    return rewriter.notifyMatchFailure(op, "Init not minus infinity.");
+  }
+
+  //
+  // build tfl
+  //=-----
+
+  auto opt_tfl_padding =
+      GetTFLPadding(view.Paddings(), view.WindowStrides(),
+                    input_type.getShape(), view.WindowDims());
+  if (!opt_tfl_padding.has_value()) {
+    return rewriter.notifyMatchFailure(op, "Padding not SAME or VALID.");
+  }
+  const auto& tfl_padding = opt_tfl_padding.value();
+
+  auto [fh, fw, sh, sw, p, faf] =
+      BuildTFLPoolAttrs(rewriter, view, tfl_padding);
+  rewriter.replaceOpWithNewOp<TFL::MaxPool2DOp>(op, type, input, p, sw, sh, fw,
+                                                fh, faf);
+
+  return success();
+}
+
+//===------------------------------------------------------------------------===
+// mhlo.div(mhlo.reduce_window, cst | mhlo.reduce_window) -> tfl.avg_pool
+//===------------------------------------------------------------------------===
+
+void ReplaceWithAvgPool(mhlo::DivOp op, Value rw_lhs_input,
+                        const ReduceWindowView& lhs_view,
+                        llvm::StringRef padding, PatternRewriter& rewriter,
+                        mhlo::TransposeOp opt_final_tpose) {
+  Type out_type =
+      opt_final_tpose ? opt_final_tpose.getOperand().getType() : op.getType();
+
+  auto [fh, fw, sh, sw, p, faf] =
+      BuildTFLPoolAttrs(rewriter, lhs_view, padding);
+  Value final_op = rewriter.create<TFL::AveragePool2DOp>(
+      op->getLoc(), out_type, rw_lhs_input, fh, fw, p, sh, sw, faf);
+
+  if (opt_final_tpose) {
+    final_op = rewriter
+                   .create<mhlo::TransposeOp>(final_op.getLoc(), final_op,
+                                              opt_final_tpose.getPermutation())
+                   .getResult();
+  }
+
+  rewriter.replaceOp(op, final_op);
+}
+
+// Walks up the use-def chain from the value, skipping all preceding ops of
+// the given types Tys. Returns the first value whose defining op is not one
+// of Tys (or which has no defining op).
+template <typename... Tys>
+Value RecursivelyWalkUp(Value op) {
+  while (llvm::isa_and_nonnull<Tys...>(op.getDefiningOp())) {
+    Operation* producer = op.getDefiningOp();
+    op = producer->getOperand(/*idx=*/0);
+  }
+
+  return op;
+}
+
+class LegalizeAvgPool : public OpConversionPattern<mhlo::DivOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      mhlo::DivOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const final;
+};
+
+LogicalResult LegalizeAvgPool::matchAndRewrite(
+    mhlo::DivOp div_op, OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter) const {
+  //
+  // parse and validate lhs reduce window
+  //=-----
+
+  auto div_lhs = div_op.getLhs();
+  // If the div's input is transposed, save it to chain onto the new pool op.
+  mhlo::TransposeOp opt_final_tpose;
+  if (auto div_lhs_op = div_lhs.getDefiningOp()) {
+    opt_final_tpose = llvm::dyn_cast_or_null<mhlo::TransposeOp>(div_lhs_op);
+  }
+
+  auto rw_lhs_val = RecursivelyWalkUp<mhlo::TransposeOp>(div_lhs);
+  auto rw_lhs =
+      llvm::dyn_cast_or_null<mhlo::ReduceWindowOp>(rw_lhs_val.getDefiningOp());
+  if (!rw_lhs) {
+    return rewriter.notifyMatchFailure(
+        div_op, "Could not match lhs of div on reduce window.");
+  }
+
+  const auto opt_rw_lhs_view = GetViewIfAttrsSupported(rw_lhs);
+  if (!opt_rw_lhs_view.has_value()) {
+    return rewriter.notifyMatchFailure(div_op, "Lhs rw is not valid.");
+  }
+  const auto [rw_lhs_view, rw_lhs_layout] = opt_rw_lhs_view.value();
+  if (rw_lhs_layout != TFLNativePoolingLayout(rw_lhs_layout.Rank())) {
+    return rewriter.notifyMatchFailure(
+        div_op, "Lhs reduce window not tfl standard layout.");
+  }
+
+  // Check that the reduce-window is a sum-reduce-window.
+  if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_lhs.getBody()))) {
+    return rewriter.notifyMatchFailure(div_op,
+                                       "Failed to match rw lhs binary func.");
+  }
+
+  //
+  // validate inputs and init val
+  //=-----
+
+  auto opt_rw_lhs_input_and_init = GetInputAndInitIfValid(rw_lhs);
+  if (!opt_rw_lhs_input_and_init.has_value()) {
+    return rewriter.notifyMatchFailure(
+        div_op, "Lhs reduce window has wrong number of inputs or init values.");
+  }
+  auto [rw_lhs_input, rw_lhs_init_val] = opt_rw_lhs_input_and_init.value();
+  auto rw_lhs_input_type = llvm::dyn_cast<ShapedType>(rw_lhs_input.getType());
+
+  auto rw_lhs_type =
+      mlir::dyn_cast<RankedTensorType>(rw_lhs.getResult(0).getType());
+  if (!mlir::isa<FloatType>(rw_lhs_type.getElementType())) {
+    return rewriter.notifyMatchFailure(div_op,
+                                       "Reduce window lhs must be float type.");
+  }
+
+  // If the init value isn't zero then it can't be an average pool.
+  if (!IsCstFloatZero(rw_lhs_init_val)) {
+    return rewriter.notifyMatchFailure(
+        div_op, "Reduce window lhs init value is not zero.");
+  }
+
+  //
+  // case 1: rhs is splat const with val == window_size
+  //=-----
+
+  auto opt_tfl_padding =
+      GetTFLPadding(rw_lhs_view.Paddings(), rw_lhs_view.WindowStrides(),
+                    rw_lhs_input_type.getShape(), rw_lhs_view.WindowDims());
+  if (!opt_tfl_padding.has_value()) {
+    return rewriter.notifyMatchFailure(div_op,
+                                       "Padding must be VALID or SAME.");
+  }
+  const auto& tfl_padding = opt_tfl_padding.value();
+
+  {
+    DenseFPElementsAttr divisor;
+    auto div_rhs = RecursivelyWalkUp<mhlo::BroadcastInDimOp, mhlo::TransposeOp>(
+        div_op.getRhs());
+    if (matchPattern(div_rhs, m_Constant(&divisor))) {
+      if (!divisor.isSplat()) {
+        return failure();
+      }
+
+      if (!divisor.getSplatValue<APFloat>().isExactlyValue(
+              rw_lhs_view.WindowSize())) {
+        return rewriter.notifyMatchFailure(
+            div_op, "Rhs splat const is not equal to window size.");
+      }
+
+      if (tfl_padding != "VALID") {
+        return rewriter.notifyMatchFailure(div_op,
+                                           "Matching on rhs splat const where "
+                                           "rw lhs has non-trivial padding.");
+      }
+
+      ReplaceWithAvgPool(div_op, rw_lhs_input, rw_lhs_view, tfl_padding,
+                         rewriter, opt_final_tpose);
+      return success();
+    }
+  }
+
+  //
+  // case 2: rhs is another reduce window over 1's with same config as lhs
+  //=-----
+
+  {
+    Value divisor = RecursivelyWalkUp<mhlo::BroadcastInDimOp, mhlo::ReshapeOp,
+                                      mhlo::TransposeOp>(div_op.getRhs());
+    auto rw_rhs =
+        dyn_cast_or_null<mhlo::ReduceWindowOp>(divisor.getDefiningOp());
+    if (!rw_rhs) {
+      return rewriter.notifyMatchFailure(
+          div_op, "Rhs of div op is not a reduce window.");
+    }
+
+    const auto opt_rw_rhs_view = GetViewIfAttrsSupported(rw_rhs);
+    if (!opt_rw_rhs_view.has_value()) {
+      return rewriter.notifyMatchFailure(div_op, "Rhs rw is not valid.");
+    }
+    const auto [rw_rhs_view, rw_rhs_layout] = opt_rw_rhs_view.value();
+    if (rw_rhs_layout != TFLNativePoolingLayout(rw_rhs_layout.Rank())) {
+      return rewriter.notifyMatchFailure(
+          div_op, "Rhs reduce window not tfl standard layout.");
+    }
+
+    // Check that RHS is a sum-reduce-window.
+    if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.getBody()))) {
+      return rewriter.notifyMatchFailure(
+          div_op, "Rhs rw body function is not an add op.");
+    }
+
+    auto opt_rw_rhs_input_and_init = GetInputAndInitIfValid(rw_rhs);
+    if (!opt_rw_rhs_input_and_init.has_value()) {
+      return rewriter.notifyMatchFailure(
+          div_op,
+          "Rhs reduce window has wrong number of inputs or init values.");
+    }
+    auto [rw_rhs_input, rw_rhs_init_val] = opt_rw_rhs_input_and_init.value();
+
+    if (!IsCstFloatZero(rw_rhs_init_val)) {
+      return rewriter.notifyMatchFailure(div_op,
+                                         "Rhs rw init val is not zero.");
+    }
+
+    rw_rhs_input = RecursivelyWalkUp<mhlo::BroadcastInDimOp, mhlo::TransposeOp>(
+        rw_rhs_input);
+    DenseFPElementsAttr rhs_input_data;
+    if (!matchPattern(rw_rhs_input, m_Constant(&rhs_input_data)) ||
+        !rhs_input_data.isSplat() ||
+        !rhs_input_data.getSplatValue<APFloat>().isExactlyValue(1.0)) {
+      return rewriter.notifyMatchFailure(div_op,
+                                         "Rw rhs input is not splat of 1.0.");
+    }
+
+    // Check that the two reduce windows have the same window configuration.
+    if (rw_lhs.getWindowDimensions() != rw_rhs.getWindowDimensions() ||
+        rw_lhs.getWindowStrides() != rw_rhs.getWindowStrides() ||
+        rw_lhs.getPadding() != rw_rhs.getPadding()) {
+      return rewriter.notifyMatchFailure(
+          div_op, "Lhs rw and Rhs rw do not have the same config.");
+    }
+
+    ReplaceWithAvgPool(div_op, rw_lhs_input, rw_lhs_view, tfl_padding, rewriter,
+                       opt_final_tpose);
+    return success();
+  }
+
+  return failure();
+}
+
+}  // namespace
+
+void PopulateLegalizeReduceWindowPatterns(MLIRContext* ctx,
+                                          RewritePatternSet& patterns,
+                                          ConversionTarget& target) {
+  patterns.add<LegalizeAvgPool, LegalizeMaxPool, LegalizeCumSum>(ctx);
+  target.addDynamicallyLegalOp<mhlo::ReduceWindowOp>(IsReduceWindowLegal);
+  target.addDynamicallyLegalOp<mhlo::DivOp>(IsDivideLegal);
+}
+
+void PopulatePrepareReduceWindowPatterns(MLIRContext* ctx,
+                                         RewritePatternSet& patterns) {
+  patterns.add<RelayoutReduceWindow>(ctx);
+}
+
+}  // namespace mlir::odml
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h
new file mode 100644
index 0000000..ccc9c27f
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h
@@ -0,0 +1,45 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_H_
+
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+
+namespace mlir::odml {
+
+// Patterns to legalize mhlo.reduce_window to TFL.
+//
+// Maps the following representations of AvgPool in MHLO into a tfl.avg_pool
+// operation when they cleanly map to 2D or 3D average pool with VALID or SAME
+// padding:
+// * div(reduce_sum_window(x), constant(sizeof(window)))
+// * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
+//
+// Emits: tfl.average_pool2d
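+//
+// Illustrative example of the first pattern (pseudo-IR, not exact syntax):
+//   %sum = mhlo.reduce_window(%x, 0) { mhlo.add } {window = [1, 3, 3, 1]}
+//   %avg = mhlo.div(%sum, mhlo.constant dense<9.0>)
+// which lowers to a single average pool with a 3x3 filter.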
+void PopulateLegalizeReduceWindowPatterns(MLIRContext* ctx,
+                                          RewritePatternSet& patterns,
+                                          ConversionTarget& target);
+
+// Patterns to prepare mhlo.reduce_window for legalization.
+// Transposes reduce_windows to be NHWC.
+//
+// Emits: tfl.transpose
+void PopulatePrepareReduceWindowPatterns(MLIRContext* ctx,
+                                         RewritePatternSet& patterns);
+
+}  // namespace mlir::odml
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_H_
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.cc
new file mode 100644
index 0000000..67e4db6
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.cc
@@ -0,0 +1,72 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h"
+
+#include <cstdint>
+#include <optional>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace mlir::odml {
+
+ReduceWindowView::ReduceWindowView(mhlo::ReduceWindowOp op) {
+  rank_ = op.getWindowDimensions().size();
+  window_dims_ =
+      SmallVector<int64_t, 4>(op.getWindowDimensions().getValues<int64_t>());
+  window_strides_ = ResolveStridesOrDilations(rank_, op.getWindowStrides());
+  window_dilations_ = ResolveStridesOrDilations(rank_, op.getWindowDilations());
+  base_dilations_ = ResolveStridesOrDilations(rank_, op.getBaseDilations());
+  paddings_ = ResolvePadding(rank_, op.getPadding());
+  window_size_ = 1;
+  for (auto d : window_dims_) {
+    window_size_ *= d;
+  }
+}
+
+std::optional<Layout> ReduceWindowView::GuessLayout() const {
+  auto zip_dims_strides = llvm::zip(WindowDims(), WindowStrides());
+  auto simple_window_dims =
+      llvm::to_vector(llvm::map_range(zip_dims_strides, [](auto it) {
+        return std::get<0>(it) == 1 && std::get<1>(it) == 1;
+      }));
+
+  if (llvm::count(simple_window_dims, 1) < 2) {
+    return std::nullopt;
+  }
+
+  const bool is_channel_last =
+      simple_window_dims[0] && simple_window_dims[Rank() - 1];
+  if (is_channel_last) {
+    return Layout(0, Rank() - 1,
+                  llvm::to_vector(llvm::seq<int64_t>(1, Rank() - 1)));
+  }
+
+  const bool is_channel_first = simple_window_dims[0] && simple_window_dims[1];
+  if (is_channel_first) {
+    return Layout(0, 1, llvm::to_vector(llvm::seq<int64_t>(2, Rank())));
+  }
+
+  // In theory, we can support any layout with at least two 1's in
+  // `simple_window_dims` by permuting the layout so that the 1's land in
+  // the first and last positions. It is unclear if such a case ever comes up.
+  return std::nullopt;
+}
+
+}  // namespace mlir::odml
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h
new file mode 100644
index 0000000..512389b
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h
@@ -0,0 +1,62 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_UTIL_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_UTIL_H_
+
+#include <optional>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+// Helpers for working with mhlo.reduce_window attrs in the MLIR API as
+// native C++ types.
+
+namespace mlir::odml {
+
+class ReduceWindowView {
+ public:
+  explicit ReduceWindowView(mhlo::ReduceWindowOp op);
+
+  llvm::ArrayRef<int64_t> WindowDims() const { return window_dims_; }
+  int64_t WindowSize() const { return window_size_; }
+  llvm::ArrayRef<int64_t> WindowStrides() const { return window_strides_; }
+  llvm::ArrayRef<DimPadding> Paddings() const { return paddings_; }
+  llvm::ArrayRef<int64_t> WindowDilations() const { return window_dilations_; }
+  llvm::ArrayRef<int64_t> BaseDilations() const { return base_dilations_; }
+  int64_t Rank() const { return rank_; }
+
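+  // Attempts to infer the data layout from the window configuration: dims
+  // whose window size and stride are both 1 are candidate batch/channel dims.
+  // Recognizes channel-last and channel-first; returns std::nullopt otherwise.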
+  std::optional<Layout> GuessLayout() const;
+
+ private:
+  int64_t rank_;
+
+  llvm::SmallVector<int64_t, 4> window_dims_;
+  llvm::SmallVector<int64_t, 4> window_strides_;
+  llvm::SmallVector<int64_t, 4> window_dilations_;
+
+  llvm::SmallVector<DimPadding, 4> paddings_;
+
+  llvm::SmallVector<int64_t, 4> base_dilations_;
+
+  int64_t window_size_;
+};
+
+}  // namespace mlir::odml
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_UTIL_H_
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc
new file mode 100644
index 0000000..76c843e
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc
@@ -0,0 +1,214 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h"
+
+#include <cstdint>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"  // IWYU pragma: keep
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace mlir::odml {
+namespace {
+
+//===----------------------------------------------------------------------===//
+// mhlo.slice
+//===----------------------------------------------------------------------===//
+
+// Cast the value to i32.
+Value BuildTFLCastOp(OpBuilder& b, Value value) {
+  return b.create<TFL::CastOp>(
+      value.getLoc(),
+      RankedTensorType::get(llvm::cast<ShapedType>(value.getType()).getShape(),
+                            b.getI32Type()),
+      value);
+}
+
+class LegalizeSliceOp : public OpConversionPattern<mhlo::SliceOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::SliceOp slice_op, OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const final {
+    auto begin = rewriter.create<arith::ConstantOp>(slice_op.getLoc(),
+                                                    slice_op.getStartIndices());
+    auto end = rewriter.create<arith::ConstantOp>(slice_op.getLoc(),
+                                                  slice_op.getLimitIndices());
+    auto strides = rewriter.create<arith::ConstantOp>(slice_op.getLoc(),
+                                                      slice_op.getStrides());
+    auto zero = rewriter.getIntegerAttr(rewriter.getI32Type(), 0);
+    auto no_offset = rewriter.getBoolAttr(false);
+
+    rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
+        slice_op, slice_op.getType(), slice_op.getOperand(),
+        BuildTFLCastOp(rewriter, begin), BuildTFLCastOp(rewriter, end),
+        BuildTFLCastOp(rewriter, strides), zero, zero, zero, zero, zero,
+        no_offset);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// mhlo.dynamic_slice
+//===----------------------------------------------------------------------===//
+
+class CastSliceIndicesToSignless
+    : public OpRewritePattern<mhlo::DynamicSliceOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::DynamicSliceOp op,
+                                PatternRewriter& rewriter) const final;
+};
+
+LogicalResult CastSliceIndicesToSignless::matchAndRewrite(
+    mhlo::DynamicSliceOp op, PatternRewriter& rewriter) const {
+  // All start inds have the same element type.
+  auto start_type =
+      llvm::cast<ShapedType>(op.getStartIndices().front().getType());
+  auto start_e_type = start_type.getElementType();
+
+  if (start_e_type.isSignlessIntOrFloat()) {
+    return rewriter.notifyMatchFailure(op, "Already signless.");
+  }
+  auto new_start_e_type =
+      rewriter.getIntegerType(start_e_type.getIntOrFloatBitWidth());
+
+  llvm::SmallVector<Value> casted_start_inds;
+  for (auto start_ind_opr : op.getStartIndices()) {
+    auto casted_start_ind_opr = rewriter.create<mhlo::ConvertOp>(
+        start_ind_opr.getLoc(), start_ind_opr, new_start_e_type);
+    casted_start_inds.push_back(casted_start_ind_opr.getResult());
+  }
+
+  rewriter.replaceOpWithNewOp<mhlo::DynamicSliceOp>(
+      op, op.getOperand(), casted_start_inds, op.getSliceSizes());
+
+  return success();
+}
+
+bool IsDynamicSliceLegal(mhlo::DynamicSliceOp op) {
+  return !llvm::cast<ShapedType>(op.getOperand().getType()).hasStaticShape();
+}
+
+class LegalizeDynamicSliceOp
+    : public OpConversionPattern<mhlo::DynamicSliceOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::DynamicSliceOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const final;
+};
+
+LogicalResult LegalizeDynamicSliceOp::matchAndRewrite(
+    mhlo::DynamicSliceOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter) const {
+  auto start_type =
+      llvm::cast<ShapedType>(op.getStartIndices().front().getType());
+  auto start_e_type = start_type.getElementType();
+  if (!start_e_type.isSignlessIntOrFloat()) {
+    return rewriter.notifyMatchFailure(
+        op, "Must be signless integer for start indices.");
+  }
+
+  auto input_type = llvm::cast<ShapedType>(op.getOperand().getType());
+  if (!input_type.hasStaticShape()) {
+    return rewriter.notifyMatchFailure(op, "Input must be statically shaped.");
+  }
+
+  //
+  // clamp start indices between zero and shape(operand) - slice_sizes
+  //=-----
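+  // Illustrative: for a dim of size 10 and a slice size of 4, a start index
+  // of 9 is clamped to 10 - 4 = 6 so the slice stays in bounds.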
+
+  Value clamp_left_cst = rewriter.create<arith::ConstantOp>(
+      op->getLoc(), rewriter.getZeroAttr(start_type));
+
+  llvm::SmallVector<Value> new_start_indices;
+  const auto stride_sizes = UnrollI64Splat(op.getSliceSizes());
+
+  for (auto [dim_size, start_ind_opr, stride_size] :
+       llvm::zip(input_type.getShape(), op.getStartIndices(), stride_sizes)) {
+    const int64_t clamp_right_val = dim_size - stride_size;
+    auto clamp_right_cst = rewriter.create<arith::ConstantOp>(
+        op->getLoc(),
+        DenseElementsAttr::get(start_type, rewriter.getIntegerAttr(
+                                               start_e_type, clamp_right_val)));
+
+    Value new_start_ind = rewriter.create<TFL::MaximumOp>(
+        op->getLoc(), start_type, clamp_left_cst, start_ind_opr);
+    new_start_ind = rewriter.create<TFL::MinimumOp>(
+        op->getLoc(), start_type, clamp_right_cst, new_start_ind);
+
+    new_start_indices.push_back(new_start_ind);
+  }
+
+  //
+  // pack variadic scalar start indices into one tensor
+  //=-----
+
+  const int64_t packed_start_indices_size = new_start_indices.size();
+  auto packed_start_indices_type =
+      RankedTensorType::get({packed_start_indices_size}, start_e_type);
+
+  auto values_count_attr =
+      rewriter.getI32IntegerAttr(packed_start_indices_size);
+  auto pack_axis_attr = rewriter.getI32IntegerAttr(0);
+
+  auto packed_start_inds = rewriter.create<TFL::PackOp>(
+      op->getLoc(), packed_start_indices_type, new_start_indices,
+      values_count_attr, pack_axis_attr);
+
+  //
+  // build tfl
+  //=-----
+
+  auto slice_sizes_cst =
+      rewriter.create<arith::ConstantOp>(op->getLoc(), op.getSliceSizes());
+
+  rewriter.replaceOpWithNewOp<TFL::SliceOp>(op, op.getType(), op.getOperand(),
+                                            packed_start_inds, slice_sizes_cst);
+
+  return success();
+}
+
+}  // namespace
+
+void PopulateLegalizeSlicePatterns(MLIRContext* ctx,
+                                   RewritePatternSet& patterns,
+                                   ConversionTarget& target) {
+  patterns.add<LegalizeSliceOp, LegalizeDynamicSliceOp>(ctx);
+
+  target.addIllegalOp<mhlo::SliceOp>();
+  target.addDynamicallyLegalOp<mhlo::DynamicSliceOp>(IsDynamicSliceLegal);
+}
+
+void PopulatePrepareSlicePatterns(MLIRContext* ctx,
+                                  RewritePatternSet& patterns) {
+  patterns.add<CastSliceIndicesToSignless>(ctx);
+}
+
+}  // namespace mlir::odml
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h
new file mode 100644
index 0000000..024cbb4
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h
@@ -0,0 +1,34 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SLICE_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SLICE_H_
+
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace mlir::odml {
+
+// Patterns to legalize mhlo.slice to TFL.
+void PopulateLegalizeSlicePatterns(MLIRContext* ctx,
+                                   RewritePatternSet& patterns,
+                                   ConversionTarget& target);
+
+void PopulatePrepareSlicePatterns(MLIRContext* ctx,
+                                  RewritePatternSet& patterns);
+
+}  // namespace mlir::odml
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SLICE_H_
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.cc
new file mode 100644
index 0000000..43477cc
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.cc
@@ -0,0 +1,146 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h"
+
+#include <cstdint>
+
+#include "llvm/ADT/ilist.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
+#include "mlir/IR/Block.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/Region.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace mlir::odml {
+namespace {
+
+using OpListType = llvm::iplist<Operation>;
+
+template <typename ReturnOpType>
+bool MatchTopKComparator(Region& comparator) {
+  if (!comparator.hasOneBlock()) return false;
+  Block& comparator_blk = comparator.front();
+
+  OpListType& operations = comparator_blk.getOperations();
+  if (operations.size() != 2) return false;
+
+  auto compare_op =
+      llvm::dyn_cast_or_null<mhlo::CompareOp>(&operations.front());
+  auto return_op = llvm::dyn_cast_or_null<ReturnOpType>(&operations.back());
+  if (!compare_op || !return_op) return false;
+
+  if (compare_op.getComparisonDirection() != mhlo::ComparisonDirection::GT) {
+    return false;
+  }
+
+  if (compare_op.getOperands()[0] != comparator_blk.getArgument(0) ||
+      compare_op.getOperands()[1] != comparator_blk.getArgument(1)) {
+    return false;
+  }
+
+  return return_op.getOperands().front() == compare_op.getResult();
+}
+
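+// Returns true if the mhlo.sort op cannot be interpreted as a TopK: it must
+// sort (keys, iota indices) pairs along the last dimension with a
+// greater-than comparator, with statically shaped operands and i32 indices.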
+bool IsSortOpNotTopK(mhlo::SortOp op) {
+  if (op->getNumOperands() != 2) {
+    return true;
+  }
+
+  auto keys_opr = op.getInputs().front();
+  auto keys_type = llvm::cast<ShapedType>(keys_opr.getType());
+
+  if (!keys_type.hasStaticShape() ||
+      !keys_type.getElementType().isIntOrFloat()) {
+    return true;
+  }
+
+  auto indices_opr = op.getInputs().back();
+  auto indices_type = llvm::cast<ShapedType>(indices_opr.getType());
+
+  if (!indices_type.hasStaticShape() ||
+      !indices_type.getElementType().isInteger(32)) {
+    return true;
+  }
+
+  const int64_t sort_dim = op.getDimension();
+  const auto k = indices_type.getDimSize(sort_dim);
+  const auto rank = keys_type.getRank();
+
+  if (sort_dim != rank - 1 || k < 1) {
+    return true;
+  }
+
+  OpBuilder b(op->getContext());
+  if (!MatchIota(b.getI64TensorAttr({sort_dim}), indices_opr)) {
+    return true;
+  }
+
+  if (!MatchTopKComparator<mhlo::ReturnOp>(op.getComparator())) {
+    return true;
+  }
+
+  return false;
+}
+
+class LegalizeSortOp : public OpConversionPattern<mhlo::SortOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::SortOp sort_op, OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const final;
+};
+
+LogicalResult LegalizeSortOp::matchAndRewrite(
+    mhlo::SortOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter) const {
+  if (IsSortOpNotTopK(op)) {
+    return failure();
+  }
+
+  auto keys = op.getInputs().front();
+  auto indices = op.getInputs().back();
+  auto indices_type = llvm::cast<ShapedType>(indices.getType());
+
+  const int32_t k = indices_type.getShape().back();
+  auto k_cst_attr = DenseIntElementsAttr::get(
+      RankedTensorType::get({}, rewriter.getI32Type()), k);
+  auto k_cst = rewriter.create<arith::ConstantOp>(op->getLoc(), k_cst_attr);
+
+  rewriter.replaceOpWithNewOp<TFL::TopKV2Op>(op, keys.getType(),
+                                             indices.getType(), keys, k_cst);
+
+  return success();
+}
+
+}  // namespace
+
+void PopulateSortPatterns(MLIRContext* ctx, RewritePatternSet& patterns,
+                          ConversionTarget& target) {
+  patterns.add<LegalizeSortOp>(ctx);
+  target.addDynamicallyLegalOp<mhlo::SortOp>(IsSortOpNotTopK);
+}
+
+}  // namespace mlir::odml
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h
new file mode 100644
index 0000000..9bbb1f3
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h
@@ -0,0 +1,28 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_
+
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+
+namespace mlir::odml {
+
+void PopulateSortPatterns(MLIRContext* ctx, RewritePatternSet& patterns,
+                          ConversionTarget& target);
+
+}  // namespace mlir::odml
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_
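
A rough usage sketch for this header (not part of the patch; everything except PopulateSortPatterns and the dialect/op names is illustrative): populate the patterns plus the dynamic legality rule on a ConversionTarget, then run a partial conversion.

#include <utility>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h"

mlir::LogicalResult LegalizeTopKSorts(mlir::ModuleOp module) {
  mlir::MLIRContext* ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::ConversionTarget target(*ctx);
  // Ops produced by the rewrite must be legal in the target.
  target.addLegalDialect<mlir::TFL::TensorFlowLiteDialect,
                         mlir::arith::ArithDialect>();
  // Adds LegalizeSortOp and marks mhlo.sort dynamically legal unless it is
  // the TopK idiom described in sort.cc.
  mlir::odml::PopulateSortPatterns(ctx, patterns, target);
  return mlir::applyPartialConversion(module, target, std::move(patterns));
}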
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc
index 5a7b976..9ff5a6e 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc
@@ -17,7 +17,6 @@
 #include <utility>
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project  // IWYU pragma: keep
-#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -29,6 +28,8 @@
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h"
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h"  // IWYU pragma: keep
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h"  // IWYU pragma: keep
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"  // IWYU pragma: keep
 
 namespace mlir {
@@ -54,6 +55,8 @@
   populateWithGenerated(patterns);
 
   PopulatePrepareConvPatterns(context, patterns);
+  PopulatePrepareReduceWindowPatterns(context, patterns);
+  PopulatePrepareSlicePatterns(context, patterns);
 
   if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
     signalPassFailure();
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc
index 8e22b34..e8a2bc8 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc
@@ -148,6 +148,7 @@
         case flexbuffers::FBT_VECTOR_INT: {
           const auto& vector = value.AsTypedVector();
           std::vector<int64_t> vec;
+          vec.reserve(vector.size());
           for (size_t i = 0; i < vector.size(); i++) {
             vec.push_back(vector[i].AsInt64());
           }
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc
index 7aff6c1..5395af5 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc
@@ -14,12 +14,16 @@
 ==============================================================================*/
 
 // The kept headers are provided for the included file `passes.h.inc`.
+#include <cstdint>
 #include <memory>
 #include <optional>
 #include <utility>
 
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
@@ -37,6 +41,9 @@
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gather.h"
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.h"
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h"
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h"  // IWYU pragma: keep
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"  // IWYU pragma: keep
@@ -47,9 +54,137 @@
 namespace odml {
 namespace {
 
+// Returns the shape of the given value as a 1-D i64 arith::ConstantOp.
+arith::ConstantOp ShapeToConst(PatternRewriter& rewriter, Value value) {
+  ArrayRef<int64_t> shape = mlir::cast<ShapedType>(value.getType()).getShape();
+  auto attr_type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
+                                         rewriter.getIntegerType(64));
+  auto attr = DenseElementsAttr::get(attr_type, shape);
+  return rewriter.create<arith::ConstantOp>(value.getLoc(), attr_type, attr);
+}
+
+// Returns true if `sign` holds the sign of `a`: -1 for negative values, 1 for
+// positive values, and equal to `a` when `a` is zero (or NaN in the float
+// overload below).
+bool IsSign(APInt a, APInt sign) {
+  if (a.isZero()) return a == sign;
+  if (a.isNegative()) return sign == -1;
+  return sign == 1;
+}
+
+bool IsSign(APFloat a, APFloat sign) {
+  if (a.isNaN() || a.isZero()) return a == sign;
+  if (a.isNegative()) return sign.isExactlyValue(-1.0);
+  return sign.isExactlyValue(1.0);
+}
+
+bool IsDenseSplatIntAttr(ElementsAttr float_or_int) {
+  return mlir::isa<SplatElementsAttr>(float_or_int) &&
+         mlir::isa<DenseIntElementsAttr>(float_or_int);
+}
+
+bool IsDenseSplatFloatAttr(ElementsAttr float_or_int) {
+  return mlir::isa<SplatElementsAttr>(float_or_int) &&
+         mlir::isa<DenseFPElementsAttr>(float_or_int);
+}
+
+bool ValueEquals(ElementsAttr float_or_int, double rhs) {
+  if (IsDenseSplatFloatAttr(float_or_int)) {
+    return mlir::cast<SplatElementsAttr>(float_or_int)
+        .getSplatValue<APFloat>()
+        .isExactlyValue(rhs);
+  } else if (IsDenseSplatIntAttr(float_or_int)) {
+    return mlir::cast<SplatElementsAttr>(float_or_int).getSplatValue<APInt>() ==
+           static_cast<int>(rhs);
+  }
+  return false;
+}
+
+// Returns whether the splat constant `sgn_cst` holds the elementwise sign of
+// the int or float tensor constant `float_or_int`.
+bool TensorIsSign(PatternRewriter& rewriter, ElementsAttr float_or_int,
+                  ElementsAttr sgn_cst) {
+  auto sgn_splat = llvm::dyn_cast<SplatElementsAttr>(sgn_cst);
+  if (!sgn_splat) return false;
+
+  auto splat = dyn_cast<SplatElementsAttr>(float_or_int);
+  if (auto float_spl = llvm::dyn_cast_if_present<FloatAttr>(splat),
+      sgn_cst_spl = llvm::dyn_cast_if_present<FloatAttr>(sgn_splat);
+      float_spl && sgn_cst_spl) {
+    return IsSign(float_spl.getValue(), sgn_cst_spl.getValue());
+  }
+  if (auto int_spl = llvm::dyn_cast_if_present<IntegerAttr>(splat),
+      sgn_cst_spl = llvm::dyn_cast_if_present<IntegerAttr>(sgn_splat);
+      int_spl && sgn_cst_spl) {
+    return IsSign(int_spl.getValue(), sgn_cst_spl.getValue());
+  }
+  if (mlir::isa<DenseFPElementsAttr>(float_or_int)) {
+    auto sgn_splat_value = sgn_splat.getSplatValue<APFloat>();
+    return llvm::all_of(float_or_int.getValues<APFloat>(), [&](APFloat value) {
+      return IsSign(value, sgn_splat_value);
+    });
+  }
+  if (mlir::isa<DenseIntElementsAttr>(float_or_int)) {
+    auto sgn_splat_value = sgn_splat.getSplatValue<APInt>();
+    return llvm::all_of(float_or_int.getValues<APInt>(), [&](APInt value) {
+      return IsSign(value, sgn_splat_value);
+    });
+  }
+  return false;
+}
+
+bool SameTypeOrDefaultCompare(mhlo::ComparisonTypeAttr comparison_type_attr,
+                              ElementsAttr cst) {
+  if (!comparison_type_attr) return true;
+  auto comparison_type_attr_value = comparison_type_attr.getValue();
+  if (comparison_type_attr_value == mhlo::ComparisonType::FLOAT &&
+      IsDenseSplatFloatAttr(cst)) {
+    return true;
+  }
+  if ((comparison_type_attr_value == mhlo::ComparisonType::SIGNED ||
+       comparison_type_attr_value == mhlo::ComparisonType::UNSIGNED) &&
+      IsDenseSplatIntAttr(cst)) {
+    return true;
+  }
+  return false;
+}
+
+bool ValueIsReciprocal(ElementsAttr float_or_int, ElementsAttr rhs) {
+  if (IsDenseSplatFloatAttr(float_or_int) &&
+      IsDenseSplatFloatAttr(rhs)) {
+    return (mlir::cast<SplatElementsAttr>(float_or_int)
+                .getSplatValue<APFloat>() *
+            mlir::cast<SplatElementsAttr>(rhs).getSplatValue<APFloat>())
+        .isExactlyValue(1.0);
+  } else if (IsDenseSplatIntAttr(float_or_int) &&
+             IsDenseSplatIntAttr(rhs)) {
+    return (mlir::cast<SplatElementsAttr>(float_or_int).getSplatValue<APInt>() *
+            mlir::cast<SplatElementsAttr>(rhs).getSplatValue<APInt>()) == 1;
+  }
+  return false;
+}
+
+bool ValueGreaterThanZero(ElementsAttr float_or_int) {
+  if (IsDenseSplatIntAttr(float_or_int)) {
+    auto value =
+        mlir::cast<SplatElementsAttr>(float_or_int).getSplatValue<APInt>();
+    return !value.isNegative() && !value.isZero();
+  } else if (IsDenseSplatFloatAttr(float_or_int)) {
+    auto value =
+        mlir::cast<SplatElementsAttr>(float_or_int).getSplatValue<APFloat>();
+    return !value.isNaN() && !value.isNegative() && !value.isZero();
+  }
+  return false;
+}
+
 #define GEN_PASS_DEF_LEGALIZEHLOTOTFLITEPASS
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc"
 
+bool SupportedComparisonType(mhlo::ComparisonTypeAttr comp_type) {
+  if (!comp_type) return true;
+  auto c_ty = comp_type.getValue();
+  return c_ty == mhlo::ComparisonType::FLOAT ||
+         c_ty == mhlo::ComparisonType::SIGNED ||
+         c_ty == mhlo::ComparisonType::UNSIGNED ||
+         c_ty == mhlo::ComparisonType::NOTYPE;
+}
+
 class LegalizeHloToTfLitePass
     : public impl::LegalizeHloToTfLitePassBase<LegalizeHloToTfLitePass> {
  public:
@@ -62,10 +197,43 @@
   return !op.getType().getElementType().isF32();
 }
 
+// 64-bit mhlo.not stays legal (unconverted): the xor-based TFL lowerings only
+// cover 1- to 32-bit integer element types.
+bool IsNotOpLegal(mhlo::NotOp op) {
+  return op.getType().getElementType().isInteger(64);
+}
+
+// Mark possible target ops from rounding patterns as having "unknown"
+// legality. This is required to schedule patterns on these ops even
+// though MhloDialect is explicitly marked legal (which cannot be changed
+// easily).
+void AddRoundingOpsAsUnknown(ConversionTarget& target) {
+  target.addDynamicallyLegalOp<
+      mhlo::FloorOp, mhlo::SubtractOp, mhlo::AndOp, mhlo::SelectOp, mhlo::RemOp,
+      mhlo::AddOp, mhlo::SignOp, mhlo::MulOp, mhlo::DivOp, mhlo::OrOp,
+      mhlo::BroadcastInDimOp, mhlo::ConstantOp, mhlo::RoundOp, mhlo::TupleOp>(
+      [](Operation* op) { return std::nullopt; });
+}
+
+bool IsCompareLegal(mhlo::CompareOp op) {
+  return !SupportedComparisonType(op.getCompareTypeAttr());
+}
+
+// Marks the listed unary elementwise ops as legal (left unconverted) when
+// their operand element type is not int or float.
+void SetUnaryOpLegal(ConversionTarget& target) {
+  auto is_legal = [](Operation* op) {
+    return !llvm::cast<ShapedType>(op->getOperand(0).getType())
+                .getElementType()
+                .isIntOrFloat();
+  };
+  target.addDynamicallyLegalOp<
+      mhlo::AbsOp, mhlo::BitcastConvertOp, mhlo::CeilOp, mhlo::IsFiniteOp,
+      mhlo::CosineOp, mhlo::ExpOp, mhlo::Expm1Op, mhlo::FloorOp, mhlo::ImagOp,
+      mhlo::LogOp, mhlo::NegOp, mhlo::RealOp, mhlo::Log1pOp, mhlo::RsqrtOp,
+      mhlo::SineOp, mhlo::LogisticOp, mhlo::SignOp, mhlo::SqrtOp, mhlo::TanhOp,
+      mhlo::ConvertOp>(is_legal);
+}
+
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/generated_tflite_legalize_hlo.inc"
 void LegalizeHloToTfLitePass::runOnOperation() {
   MLIRContext* context = &getContext();
-
   RewritePatternSet patterns(context);
   patterns.add<odml::ConvertCustomCallOp, odml::LowerDotGeneralOp>(context);
   populateWithGenerated(patterns);
@@ -73,14 +241,25 @@
   ConversionTarget target(*context);
   target.addLegalDialect<TFL::TensorFlowLiteDialect, mhlo::MhloDialect>();
   target.addLegalOp<func::CallOp, func::ConstantOp, arith::ConstantOp>();
+
   target.addDynamicallyLegalOp<mhlo::CustomCallOp>(IsCustomCallLegal);
   target.addDynamicallyLegalOp<mhlo::CbrtOp>(IsCbrtLegal);
-  target.addIllegalOp<mhlo::DotGeneralOp, mhlo::DotOp, mhlo::TransposeOp>();
+  target.addIllegalOp<mhlo::DotGeneralOp, mhlo::DotOp, mhlo::TransposeOp,
+                      mhlo::ShiftRightArithmeticOp, mhlo::ShiftRightLogicalOp,
+                      mhlo::RemOp, mhlo::ReshapeOp, mhlo::DynamicReshapeOp>();
+  target.addDynamicallyLegalOp<mhlo::NotOp>(IsNotOpLegal);
+  target.addDynamicallyLegalOp<mhlo::CompareOp>(IsCompareLegal);
+
+  AddRoundingOpsAsUnknown(target);
+  SetUnaryOpLegal(target);
 
   PopulatePadPatterns(context, patterns, target);
   PopulateReducePatterns(context, patterns, target);
+  PopulateLegalizeReduceWindowPatterns(context, patterns, target);
   PopulateGatherPatterns(context, patterns, target);
   PopulateLegalizeConvPatterns(context, patterns, target);
+  PopulateLegalizeSlicePatterns(context, patterns, target);
+  PopulateSortPatterns(context, patterns, target);
 
   if (failed(applyPartialConversion(getOperation(), target,
                                     std::move(patterns)))) {
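
The splat-constant helpers above (ValueEquals, ValueIsReciprocal, ValueGreaterThanZero, TensorIsSign) all expect ElementsAttrs that are both splat and dense int/float. Below is a small sketch of the attributes they operate on and the results they produce; it is illustrative only, and since the helpers live in an anonymous namespace it would have to sit in the same translation unit (e.g. a test appended to this file):

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

void SplatHelperExamples(mlir::MLIRContext& context) {
  mlir::Builder b(&context);
  auto scalar_f32 = mlir::RankedTensorType::get({}, b.getF32Type());
  // Single-element dense attributes are stored as splats.
  auto half = mlir::DenseElementsAttr::get(scalar_f32, llvm::ArrayRef<float>{0.5f});
  auto two = mlir::DenseElementsAttr::get(scalar_f32, llvm::ArrayRef<float>{2.0f});
  // Expected results with the helpers defined in this file:
  //   ValueEquals(half, 0.5)       -> true  (splat float equal to 0.5)
  //   ValueIsReciprocal(half, two) -> true  (0.5 * 2.0 == 1.0)
  //   ValueGreaterThanZero(half)   -> true  (a splat of -1.0 would be false)
  (void)half;
  (void)two;
}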
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td
index 671115d..7d41cbe 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td
@@ -17,10 +17,15 @@
 include "mlir/Dialect/Func/IR/FuncOps.td"
 include "mhlo/IR/hlo_ops.td"
 include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
 include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
 include "tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td"
 include "mlir/Dialect/Arith/IR/ArithOps.td"
 
+
+def ShapeToConst : NativeCodeCall<"ShapeToConst($_builder, $0)">;
+
 def CreateTFLCastToInt32Op : NativeCodeCall<
   "CreateCastToInt32($0, $_loc, $_builder)">;
 
@@ -28,6 +33,95 @@
                             (TFL_TransposeOp $arg,
                             (CreateTFLCastToInt32Op (TFL_ConstOp $perm)))>;
 
+
+def : Pat<(MHLO_ShiftRightArithmeticOp $l, $r), (TFL_RightShiftOp $l, $r)>;
+def : Pat<(MHLO_ShiftRightLogicalOp $l, $r), (TFL_RightShiftOp $l, $r)>;
+def : Pat<(MHLO_RemOp $l, $r), (TFL_FloorModOp $l, $r)>;
+
+def LegalizeReshape : Pat<(MHLO_ReshapeOp:$output $input),
+          (TFL_ReshapeOp $input, 
+            (CreateTFLCastToInt32Op (ShapeToConst $output)))>;
+
+def LegalizeDynamicReshape : Pat<(MHLO_DynamicReshapeOp $input, $shape),
+          (TFL_ReshapeOp $input, (CreateTFLCastToInt32Op $shape))>;
+
+//===----------------------------------------------------------------------===//
+// logical and bitwise ops
+//===----------------------------------------------------------------------===//
+
+class GetRankedScalarAttr<string prefix, int width, string signed, string value> :
+  NativeCodeCall<"DenseElementsAttr::get<" # prefix # "int" # width # "_t>("
+    "RankedTensorType::get({}, $_builder.getIntegerType("
+      # width # signed # "))," # value # ")">;
+
+def : Pat<(MHLO_NotOp I1Tensor:$input), (TFL_LogicalNotOp $input)>;
+
+// TFL does not support bitwise negation. not(x) is equivalent to xor(x, y)
+// where y has a 1 in every bit position (xor(1, 1) = 0 and xor(0, 1) = 1).
+
+// Signed: the two's complement representation of -1 has a 1 in every bit position.
+def : Pat<(MHLO_NotOp I8Tensor:$input),
+            (TFL_BitwiseXorOp $input,
+              (Arith_ConstantOp
+                (GetRankedScalarAttr<"", 8, "", "-1">)))>;
+
+def : Pat<(MHLO_NotOp I16Tensor:$input),
+            (TFL_BitwiseXorOp $input,
+              (Arith_ConstantOp
+                (GetRankedScalarAttr<"", 16, "", "-1">)))>;
+
+def : Pat<(MHLO_NotOp I32Tensor:$input),
+            (TFL_BitwiseXorOp $input,
+              (Arith_ConstantOp
+                (GetRankedScalarAttr<"", 32, "", "-1">)))>;
+
+
+// Unsigned: 0xFFF... has a 1 in every bit position.
+def : Pat<(MHLO_NotOp TensorOf<[UI8]>:$input),
+            (TFL_BitwiseXorOp $input,
+              (Arith_ConstantOp
+                (GetRankedScalarAttr<"u", 8, ", false", "0xFFU">)))>;
+
+def : Pat<(MHLO_NotOp TensorOf<[UI16]>:$input),
+            (TFL_BitwiseXorOp $input,
+              (Arith_ConstantOp
+                (GetRankedScalarAttr<"u", 16, ", false", "0xFFFFU">)))>;
+
+def : Pat<(MHLO_NotOp TensorOf<[UI32]>:$input),
+            (TFL_BitwiseXorOp $input,
+              (Arith_ConstantOp
+                (GetRankedScalarAttr<"u", 32, ", false", "0xFFFFFFFFUL">)))>;
+
+//===----------------------------------------------------------------------===//
+// comparison ops
+//===----------------------------------------------------------------------===//
+
+// Check implicit bool cast of `$_self` to ensure Attribute is non-null before
+// casting.
+def HasSupportedComparisonType : AttrConstraint<
+  CPred<"!$_self || SupportedComparisonType($_self.cast<mhlo::ComparisonTypeAttr>())">>;
+
+class MHLO_ComparisonDirectionValue<string enumStr> :
+  ConstantAttr<MHLO_ComparisonDirectionAttr,
+    "::mlir::mhlo::ComparisonDirection::" # enumStr>;
+
+foreach p = [
+  [TFL_EqualOp,        MHLO_ComparisonDirectionValue<"EQ">],
+  [TFL_NotEqualOp,     MHLO_ComparisonDirectionValue<"NE">],
+  [TFL_GreaterEqualOp, MHLO_ComparisonDirectionValue<"GE">],
+  [TFL_LessEqualOp,    MHLO_ComparisonDirectionValue<"LE">],
+  [TFL_GreaterOp,      MHLO_ComparisonDirectionValue<"GT">],
+  [TFL_LessOp,         MHLO_ComparisonDirectionValue<"LT">]]
+in {
+  def : Pat<
+    (MHLO_CompareOp $l, $r, p[1], HasSupportedComparisonType),
+    (p[0] $l, $r)>;
+}
+
+//===----------------------------------------------------------------------===//
+// unary element-wise op
+//===----------------------------------------------------------------------===//
+
 def LowerCbrt : Pat<(MHLO_CbrtOp $opr),
   (TFL_PowOp $opr,
     (TFL_DivOp
@@ -35,3 +129,429 @@
       (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
       TFL_AF_None)),
       [(F32Tensor $opr)]>;
+
+
+foreach pair = [
+  [MHLO_AbsOp, TFL_AbsOp],
+  [MHLO_BitcastConvertOp, TFL_BitcastOp],
+  [MHLO_CeilOp, TFL_CeilOp],
+  [MHLO_CosineOp, TFL_CosOp],
+  [MHLO_ExpOp, TFL_ExpOp],
+  [MHLO_FloorOp, TFL_FloorOp],
+  [MHLO_ImagOp, TFL_ImagOp],
+  [MHLO_LogOp, TFL_LogOp],
+  [MHLO_LogisticOp, TFL_LogisticOp],
+  [MHLO_NegOp, TFL_NegOp],
+  [MHLO_RealOp, TFL_RealOp],
+  [MHLO_RsqrtOp, TFL_RsqrtOp],
+  [MHLO_SineOp, TFL_SinOp],
+  [MHLO_SignOp, TFL_SignOp],
+  [MHLO_SqrtOp, TFL_SqrtOp],
+  [MHLO_TanhOp, TFL_TanhOp]
+] in {
+  def : Pat<
+    (pair[0] $input),
+    (pair[1] $input)>;
+}
+
+def : Pat<
+  (MHLO_ConvertOp $input),
+  (TFL_CastOp $input)>;
+
+def : Pat<
+  (MHLO_Expm1Op F32Tensor:$x),
+  (TFL_SubOp
+    (TFL_ExpOp $x),
+    (Arith_ConstantOp
+      ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">),
+    TFL_AF_None)>;
+
+def : Pat<
+  (MHLO_IsFiniteOp F32Tensor:$x),
+  (TFL_EqualOp
+    (TFL_SubOp $x, $x, TFL_AF_None),
+    (Arith_ConstantOp
+      ConstantAttr<RankedF32ElementsAttr<[]>, "0.0f">))>;
+
+def : Pat<
+  (MHLO_Log1pOp F32Tensor:$x),
+  (TFL_LogOp
+    (TFL_AddOp
+      $x,
+      (Arith_ConstantOp
+        ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">),
+      TFL_AF_None))>;
+
+//===----------------------------------------------------------------------===//
+// rounding
+//===----------------------------------------------------------------------===//
+
+class ValueEquals<string val> :
+  Constraint<CPred<"ValueEquals($0, " # val # ")">>;
+
+def SameValue :
+  Constraint<CPred<"$0 == $1">>;
+
+def FloatOrDefaultCompare :
+  Constraint<CPred<"!$0 || $0.getValue() == ::mlir::mhlo::ComparisonType::FLOAT">>;
+
+def SameTypeOrDefaultCompare :
+  Constraint<CPred<"SameTypeOrDefaultCompare($0, $1)">>;
+
+def ValueIsReciprocal :
+  Constraint<CPred<"ValueIsReciprocal($0, $1)">>;
+
+def TensorIsSign :
+  Constraint<CPred<"TensorIsSign($_builder, $0, $1)">>;
+
+def ValueGreaterThanZero :
+  Constraint<CPred<"ValueGreaterThanZero($0)">>;
+
+
+// Converts a dag of HLOs representing banker's rounding (round x.5 to nearest
+// even) to tfl.round. This only supports float types because mhlo.floor only
+// supports float types. tf.round with integer input type will become an
+// identity op, so we will never face an mhlo.floor with an integer input type.
+// The pattern matched executes the following computation:
+// frac = x - floor(x)
+// to_even = (floor(x) - 2 * floor(0.5 * x)) == 1
+// if frac > 0.5 || (frac == 0.5 && to_even)
+//   return floor(x) + 1
+// else
+//   return floor(x)
+def Round : Pat<(MHLO_SelectOp
+            (MHLO_OrOp
+              (MHLO_CompareOp (MHLO_SubtractOp:$frac
+                               $input,
+                               (MHLO_FloorOp:$floor $input)),
+                             (MHLO_ConstantOp $half),
+                             MHLO_ComparisonDirectionValue<"GT">,
+                             $compare_type0),
+              (MHLO_AndOp
+                (MHLO_CompareOp
+                  $frac1,
+                  (MHLO_ConstantOp $half1),
+                  MHLO_ComparisonDirectionValue<"EQ">,
+                  $compare_type1),
+                (MHLO_CompareOp
+                  (MHLO_SubtractOp
+                    $floor1,
+                    (MHLO_MulOp
+                      (MHLO_FloorOp (MHLO_MulOp $input, (MHLO_ConstantOp $half2))),
+                      (MHLO_ConstantOp $two))),
+                  (MHLO_ConstantOp $one1),
+                  MHLO_ComparisonDirectionValue<"EQ">,
+                  $compare_type2))),
+            (MHLO_AddOp $floor2, (MHLO_ConstantOp $one)),
+            $floor3),
+          (TFL_RoundOp $input),
+          [(ValueEquals<"1.0"> $one),
+           (ValueEquals<"1.0"> $one1),
+           (ValueEquals<"2.0"> $two),
+           (ValueEquals<"0.5"> $half),
+           (ValueEquals<"0.5"> $half1),
+           (ValueEquals<"0.5"> $half2),
+           (SameValue $floor, $floor1),
+           (SameValue $floor, $floor2),
+           (SameValue $floor, $floor3),
+           (SameValue $frac, $frac1),
+           (FloatOrDefaultCompare $compare_type0),
+           (FloatOrDefaultCompare $compare_type1),
+           (FloatOrDefaultCompare $compare_type2)]>;
+
+// Converts a dag of HLOs representing floor_mod to tfl.floor_mod.
+// The pattern matched executes the following computation:
+//
+// rem = remainder(arg0, arg1)
+// for i in 0 to len(arg1):
+//    if (rem[i] != 0 && (rem[i] < 0) != (arg1[i] < 0))
+//       rem[i] += arg1[i]
+// return rem
+def : Pat<(MHLO_SelectOp
+            (MHLO_AndOp
+              (MHLO_CompareOp
+                (MHLO_CompareOp:$rltz
+                  (MHLO_RemOp:$rem $arg, $arg1),
+                  (MHLO_ConstantOp $cst),
+                  MHLO_ComparisonDirectionValue<"LT">,
+                  $compare_type),
+                (MHLO_CompareOp:$arg1ltz $arg1, (MHLO_ConstantOp $cst1), MHLO_ComparisonDirectionValue<"LT">, $compare_type1),
+                MHLO_ComparisonDirectionValue<"NE">,
+                $compare_type2),
+              (MHLO_CompareOp:$rnz $rem1, (MHLO_ConstantOp $cst2), MHLO_ComparisonDirectionValue<"NE">, $compare_type3)),
+            (MHLO_AddOp $rem2, $arg1),
+            $rem3),
+          (TFL_FloorModOp $arg, $arg1),
+          [(ValueEquals<"0.0"> $cst),
+           (ValueEquals<"0.0"> $cst1),
+           (ValueEquals<"0.0"> $cst2),
+           (SameValue $rem, $rem1),
+           (SameValue $rem, $rem2),
+           (SameValue $rem, $rem3),
+           (SameTypeOrDefaultCompare $compare_type, $cst),
+           (SameTypeOrDefaultCompare $compare_type1, $cst1)]>;
+
+// Converts a dag of HLOs representing floor_mod with a constant to
+// tfl.floor_mod. The pattern matched executes the following computation:
+//
+// cst = value that is > 0
+// rem = remainder(arg0, cst)
+// for i in 0 to len(arg1):
+//    if (rem[i] < 0 && rem[i] != 0)
+//       rem[i] += cst
+// return rem
+def : Pat<(MHLO_SelectOp
+            (MHLO_AndOp
+              (MHLO_CompareOp:$rltz
+                (MHLO_RemOp:$rem $arg, (MHLO_ConstantOp $cst)),
+                (MHLO_ConstantOp $cst1),
+                MHLO_ComparisonDirectionValue<"LT">,
+                $compare_type),
+              (MHLO_CompareOp:$rnz $rem1, (MHLO_ConstantOp $cst2), MHLO_ComparisonDirectionValue<"NE">, $compare_type3)),
+            (MHLO_AddOp $rem2, (MHLO_ConstantOp $cst3)),
+            $rem3),
+          (TFL_FloorModOp $arg, (Arith_ConstantOp $cst3)),
+          [(ValueGreaterThanZero $cst),
+           (ValueEquals<"0.0"> $cst1),
+           (ValueEquals<"0.0"> $cst2),
+           (SameValue $cst, $cst3),
+           (SameValue $rem, $rem1),
+           (SameValue $rem, $rem2),
+           (SameValue $rem, $rem3),
+           (SameTypeOrDefaultCompare $compare_type, $cst1),
+           (SameTypeOrDefaultCompare $compare_type3, $cst2)]>;
+
+// Converts a dag of HLOs representing floor_div to tfl.floor_div.
+// The pattern matched executes the following computation:
+//
+// rem = remainder(arg0, arg1)
+// for i in 0 to len(arg1):
+//    rem[i] = (arg0[i] - rem[i]) / arg1[i]
+//    if (rem[i] != 0 && sign(arg1[i]) != sign(rem[i]))
+//       rem[i] -= 1.0
+// return round_nearest_afz(rem)
+// As a dag this looks like the following:
+//                       round
+//                         |
+//              -------- select
+//              |          |    \
+//             &&          +    div
+//           /   |        / \
+//        !=    !=       div -1
+//       / |   / |      / |
+//    rem 0.0 sn sn1    -  $1
+//   / |      |  |    / |
+// $0 $1     $1 rem  $0 rem
+// Note that named operators like 'sn' and 'sn1' are different values produced
+// by the same function, in this case the sign function. Named values like
+// 'div' refer to the same value produced by the same function, in this case
+// division. Mathematical symbols do not indicate a re-use of the value.
+def : Pat<(MHLO_RoundOp
+            (MHLO_SelectOp
+              (MHLO_AndOp
+                (MHLO_CompareOp
+                  (MHLO_RemOp:$rem $arg0, $arg1),
+                  (MHLO_ConstantOp $cst),
+                  MHLO_ComparisonDirectionValue<"NE">,
+                  $compare_type),
+                (MHLO_CompareOp
+                  (MHLO_SignOp $arg1),
+                  (MHLO_SignOp $rem1),
+                  MHLO_ComparisonDirectionValue<"NE">,
+                  $compare_type1)),
+              (MHLO_AddOp
+                (MHLO_DivOp:$div
+                  (MHLO_SubtractOp $arg0, $rem2),
+                  $arg1b),
+                (MHLO_ConstantOp $cst_neg1)),
+              $div1)),
+          (TFL_FloorDivOp $arg0, $arg1),
+          [(ValueEquals<"0.0"> $cst),
+           (ValueEquals<"-1.0"> $cst_neg1),
+           (SameValue $div, $div1),
+           (SameValue $rem, $rem1),
+           (SameValue $rem, $rem2),
+           (FloatOrDefaultCompare $compare_type, $cst),
+           (FloatOrDefaultCompare $compare_type1, $cst)]>;
+
+// Converts a dag of HLOs representing floor_div with a splat constant to
+// tfl.floor_div. This particular pattern matches multiplication by the
+// reciprocal of the constant instead of division by the constant. The
+// pattern matched executes the following computation:
+// rem = remainder(arg0, cst)
+// for i in 0 to len(arg0):
+//    rem[i] = (arg0[i] - rem[i]) * 1 / cst
+//    if (rem[i] != 0 && sign(cst) != sign(rem[i]))
+//       rem[i] += -1.0
+// return round_nearest_afz(rem)
+// As a dag this looks like the following:
+//                       round
+//                         |
+//              -------- select
+//              |           |     \
+//             &&           +     mul
+//           /   |         / \
+//        !=    !=        mul -1
+//       / |   /  |       / |
+//    rem 0.0 cs1 sn1   -  cs2
+//   / |           |   / |
+// $0 cst         rem $0 rem
+// cs1 == sign(cst)
+// cs2 = 1 / cst i.e. the reciprocal
+// Note that named operators like 'sn' and 'sn1' are different values produced
+// by the same function, in this case the sign function. Named values like
+// 'div' refer to the same value produced by the same function, in this case
+// division. Mathematical symbols do not indicate a re-use of the value.
+def : Pat<(MHLO_RoundOp
+            (MHLO_SelectOp
+              (MHLO_AndOp
+                (MHLO_CompareOp
+                  (MHLO_RemOp:$rem $arg0, (MHLO_ConstantOp:$cst $cstv)),
+                  (MHLO_ConstantOp $cst_zero),
+                  MHLO_ComparisonDirectionValue<"NE">,
+                  $compare_type),
+                (MHLO_CompareOp
+                  (MHLO_ConstantOp $cst_sgn),
+                  (MHLO_SignOp $rem1),
+                  MHLO_ComparisonDirectionValue<"NE">,
+                  $compare_type1)),
+              (MHLO_AddOp
+                (MHLO_MulOp:$mul
+                  (MHLO_SubtractOp $arg0, $rem2),
+                  (MHLO_ConstantOp $cst_recip)),
+                (MHLO_ConstantOp $cst_neg1)),
+              $mul1)),
+          (TFL_FloorDivOp $arg0, $cst),
+          [(ValueEquals<"0.0"> $cst_zero),
+           (ValueEquals<"-1.0"> $cst_neg1),
+           (TensorIsSign $cstv, $cst_sgn),
+           (ValueIsReciprocal $cstv, $cst_recip),
+           (SameValue $mul, $mul1),
+           (SameValue $rem, $rem1),
+           (SameValue $rem, $rem2),
+           (FloatOrDefaultCompare $compare_type, $cst_zero),
+           (FloatOrDefaultCompare $compare_type1, $cst_sgn)]>;
+
+// Converts a dag of HLOs representing floor_div with a splat constant to
+// tfl.floor_div. This particular pattern matches division by the constant
+// directly (rather than multiplication by its reciprocal). The pattern
+// matched executes the following computation:
+// rem = remainder(arg0, cst)
+// for i in 0 to len(arg0):
+//    rem[i] = (arg0[i] - rem[i]) / cst
+//    if (rem[i] != 0 && sign(cst) != sign(rem[i]))
+//       rem[i] -= 1.0
+// return round_nearest_afz(rem)
+// As a dag this looks like the following:
+//                       round
+//                         |
+//              -------- select
+//              |           |     \
+//             &&           +     div
+//           /   |         / \
+//        !=    !=        div -1
+//       / |   /  |       / |
+//    rem 0.0 cs1 sn1   -  cs2
+//   / |           |   / |
+// $0 cst         rem $0 rem
+// cs1 == sign(cst)
+// cs2 == cst (the same splat constant, enforced by the SameValue constraint)
+// Note that named operators like 'sn' and 'sn1' are different values produced
+// by the same function, in this case the sign function. Named values like
+// 'div' refer to the same value produced by the same function, in this case
+// division. Mathematical symbols do not indicate a re-use of the value.
+def : Pat<(MHLO_RoundOp
+            (MHLO_SelectOp
+              (MHLO_AndOp
+                (MHLO_CompareOp
+                  (MHLO_RemOp:$rem $arg0, (MHLO_ConstantOp:$cst $cstv)),
+                  (MHLO_ConstantOp $cst_zero),
+                  MHLO_ComparisonDirectionValue<"NE">,
+                  $compare_type),
+                (MHLO_CompareOp
+                  (MHLO_ConstantOp $cst_sgn),
+                  (MHLO_SignOp $rem1),
+                  MHLO_ComparisonDirectionValue<"NE">,
+                  $compare_type1)),
+              (MHLO_AddOp
+                (MHLO_DivOp:$div
+                  (MHLO_SubtractOp $arg0, $rem2),
+                  (MHLO_ConstantOp $cstv1)),
+                (MHLO_ConstantOp $cst_neg1)),
+              $div1)),
+          (TFL_FloorDivOp $arg0, $cst),
+          [(ValueEquals<"0.0"> $cst_zero),
+           (ValueEquals<"-1.0"> $cst_neg1),
+           (TensorIsSign $cstv, $cst_sgn),
+           (SameValue $div, $div1),
+           (SameValue $rem, $rem1),
+           (SameValue $rem, $rem2),
+           (SameValue $cstv1, $cstv),
+           (FloatOrDefaultCompare $compare_type, $cst_zero),
+           (FloatOrDefaultCompare $compare_type1, $cst_sgn)]>;
+
+// Converts a dag of HLOs representing floor_div with a broadcasted vector
+// constant to tfl.floor_div. The pattern matched executes the following
+// computation:
+// scs = sign(cst)
+// bcst = broadcast(cst)
+// rem = remainder(arg0, bcst)
+// for i in 0 to len(arg0):
+//    rem[i] = (arg0[i] - rem[i]) / bcst[i]
+//    if (rem[i] != 0 && scs != sign(rem[i]))
+//       rem[i] -= 1.0
+// return round_nearest_afz(rem)
+// Here scs is the splat constant produced by constant-folding sign() on the
+// unbroadcasted constant.
+//
+// As a dag this looks like the following:
+//                       round
+//                         |
+//              -------- select
+//              |           |     \
+//             &&           +     div
+//           /   |         / \
+//        !=    !=        div -1
+//       / |   /  |       / |
+//    rem 0.0 scs sn1   -  bcst
+//   / |           |   / |
+// $0 bcst        rem $0 rem
+//     |
+//    cst
+// scs == sign(cst) == sign(bcst)
+// Note that named operators like 'sn' and 'sn1' are different values produced
+// by the same function, in this case the sign function. Named values like
+// 'div' refer to the same value produced by the same function, in this case
+// division. Mathematical symbols do not indicate a re-use of the value.
+def : Pat<(MHLO_RoundOp
+            (MHLO_SelectOp
+              (MHLO_AndOp
+                (MHLO_CompareOp
+                  (MHLO_RemOp:$rem $arg0,
+                    (MHLO_BroadcastInDimOp:$bcst
+                      (MHLO_ConstantOp $cstv),
+                      $broadcast_dimension)),
+                  (MHLO_ConstantOp $cst_zero),
+                  MHLO_ComparisonDirectionValue<"NE">,
+                  $compare_type),
+                (MHLO_CompareOp
+                  (MHLO_ConstantOp $cst_sgn),
+                  (MHLO_SignOp $rem1),
+                  MHLO_ComparisonDirectionValue<"NE">,
+                  $compare_type1)),
+              (MHLO_AddOp
+                (MHLO_DivOp:$div
+                  (MHLO_SubtractOp $arg0, $rem2),
+                  $bcst1),
+                (MHLO_ConstantOp $cst_neg1)),
+              $div1)),
+          (TFL_FloorDivOp $arg0, $bcst),
+          [(ValueEquals<"0.0"> $cst_zero),
+           (ValueEquals<"-1.0"> $cst_neg1),
+           (TensorIsSign $cstv, $cst_sgn),
+           (SameValue $bcst, $bcst1),
+           (SameValue $div, $div1),
+           (SameValue $rem, $rem1),
+           (SameValue $rem, $rem2),
+           (FloatOrDefaultCompare $compare_type, $cst_zero),
+           (FloatOrDefaultCompare $compare_type1, $cst_sgn)]>;
+
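
The rounding, floor_mod, and floor_div patterns above are documented with pseudocode. The following self-contained scalar model (illustrative only, not part of the patch) restates those computations in C++ and shows that they match the intended tfl.round, tfl.floor_mod, and tfl.floor_div semantics:

#include <cmath>
#include <cstdio>

// Round-half-to-even, as matched by the `Round` pattern.
float RoundHalfToEvenModel(float x) {
  const float fl = std::floor(x);
  const float frac = x - fl;
  // floor(x) - 2 * floor(x / 2) == 1 exactly when floor(x) is odd.
  const bool floor_is_odd = (fl - 2.0f * std::floor(0.5f * x)) == 1.0f;
  if (frac > 0.5f || (frac == 0.5f && floor_is_odd)) return fl + 1.0f;
  return fl;
}

// floor_mod: shift the truncated remainder into the sign of the divisor.
float FloorModModel(float arg0, float arg1) {
  float rem = std::fmod(arg0, arg1);
  if (rem != 0.0f && ((rem < 0.0f) != (arg1 < 0.0f))) rem += arg1;
  return rem;
}

// floor_div: exact quotient of (arg0 - rem) by arg1, decremented by one when
// the remainder and divisor disagree in sign, then rounded.
float FloorDivModel(float arg0, float arg1) {
  const float rem = std::fmod(arg0, arg1);
  float q = (arg0 - rem) / arg1;
  if (rem != 0.0f && ((arg1 < 0.0f) != (rem < 0.0f))) q -= 1.0f;
  return std::round(q);  // round_nearest_afz on an integral value is a no-op.
}

int main() {
  // Expected output: 2 2 -3
  std::printf("%g %g %g\n", RoundHalfToEvenModel(2.5f),
              FloorModModel(-7.0f, 3.0f), FloorDivModel(-7.0f, 3.0f));
}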
diff --git a/tensorflow/compiler/mlir/lite/stateful_error_reporter.h b/tensorflow/compiler/mlir/lite/stateful_error_reporter.h
new file mode 100644
index 0000000..fbb82d3
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stateful_error_reporter.h
@@ -0,0 +1,36 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_
+
+// LINT.IfChange
+#include <string>
+
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
+
+namespace tflite_migration {
+
+// Similar to tflite::ErrorReporter, except that it allows callers to get the
+// last error message.
+class StatefulErrorReporter : public tflite::ErrorReporter {
+ public:
+  // Returns the last error message, or an empty string if no error has been reported.
+  virtual std::string message() = 0;
+};
+
+}  // namespace tflite_migration
+// LINT.ThenChange(//tensorflow/lite/stateful_error_reporter.h)
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_
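
StatefulErrorReporter only adds the message() accessor on top of the Report(format, va_list) hook inherited from tflite::ErrorReporter. A minimal sketch of a concrete subclass (illustrative only; the class name and buffer size are not from this change, and the base-class signature should be checked against core/api/error_reporter.h):

#include <cstdarg>
#include <cstdio>
#include <string>

#include "tensorflow/compiler/mlir/lite/stateful_error_reporter.h"

class LastMessageReporter : public tflite_migration::StatefulErrorReporter {
 public:
  // Formats each report into a local buffer and remembers the last one.
  int Report(const char* format, va_list args) override {
    char buf[1024];
    const int written = vsnprintf(buf, sizeof(buf), format, args);
    message_ = buf;
    return written;
  }
  std::string message() override { return message_; }

 private:
  std::string message_;
};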
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD
index 6cc9a62..4e7fa53 100644
--- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD
@@ -56,9 +56,9 @@
         "importer_test_min_max.cc",
     ],
     deps = [
+        "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder",
         "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
         "//tensorflow/compiler/mlir/lite/schema:schema_utils",
-        "//tensorflow/lite:framework",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
     ],
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc
index 30890fb..6231088 100644
--- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc
@@ -16,17 +16,15 @@
 #include <iostream>
 #include <memory>
 #include <optional>
-#include <system_error>
 
-#include "absl/strings/string_view.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/PrettyStackTrace.h"
 #include "llvm/Support/raw_ostream.h"
+#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
-#include "tensorflow/lite/model.h"
 
 using llvm::cl::opt;
 
@@ -52,7 +50,7 @@
 namespace {
 std::optional<std::unique_ptr<tflite::ModelT>> InjectStatsToFullyConnected(
     llvm::StringRef buffer) {
-  auto model_ptr = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
+  auto model_ptr = TFL::FlatBufferModelAbslError::VerifyAndBuildFromBuffer(
       buffer.data(), buffer.size());
   if (nullptr == model_ptr) {
     return std::nullopt;
diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
index 853606c..02c4ac4 100644
--- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
@@ -107,7 +107,6 @@
 using mlir::ModuleOp;
 using mlir::Operation;
 using mlir::OwningOpRef;
-using ::stablehlo::quantization::QuantizationConfig;
 using ::tensorflow::quantization::PyFunctionLibrary;
 
 bool IsControlFlowV1Op(Operation* op) {
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index d3830d3..2b5b753 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -81,25 +81,6 @@
 
 namespace {
 
-// TODO(b/355062942): This a temporary solution to unblock LLVM intergration.
-// https://github.com/llvm/llvm-project/commit/bbd4af5da2b741672a8e6f625eb12ea5c2d6220f
-// changed the behavior of `applySignatureConversion`. Before, an op adaptor
-// would have the new block arguments directly as operands. Now, there is an
-// `UnrealizedConversionCastOp` inserts from the new type to the old type.
-// The new behaviour is correct, but passes in this file depended on the old
-// bahavior and worked by coincidence.
-llvm::SmallVector<Value, 4> GetOperandsAndSkipUnrealizedConversionCasts(
-    ValueRange operands) {
-  llvm::SmallVector<Value, 4> result;
-  for (Value operand : operands) {
-    if (auto cast = operand.getDefiningOp<UnrealizedConversionCastOp>()) {
-      operand = cast.getInputs().front();
-    }
-    result.push_back(operand);
-  }
-  return result;
-}
-
 /// Lower TensorList ops in functions for subsequent legalization.
 struct LowerStaticTensorListPass
     : public impl::LowerStaticTensorListPassBase<LowerStaticTensorListPass> {
@@ -371,9 +352,7 @@
       ConversionPatternRewriter &rewriter) const {
     Location loc = op.getLoc();
 
-    auto operands =
-        GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
-
+    auto operands = adaptor.getOperands();
     Value input = operands[0];
     Value index = operands[1];
     Value item = operands[2];
@@ -433,8 +412,7 @@
       ConversionPatternRewriter &rewriter) const {
     Location loc = op.getLoc();
 
-    auto operands =
-        GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
+    auto operands = adaptor.getOperands();
     Value input = operands[0];
     Value index = operands[1];
     Value item = operands[2];
@@ -721,8 +699,7 @@
   LogicalResult matchAndRewrite(
       TF::TensorListPushBackOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    auto operands =
-        GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
+    auto operands = adaptor.getOperands();
     Value input_handle = operands[0];
     Value item = operands[1];
 
@@ -764,8 +741,7 @@
   LogicalResult matchAndRewrite(
       TF::TensorListResizeOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    auto operands =
-        GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
+    auto operands = adaptor.getOperands();
     Value input_handle = operands[0];
     Value size = operands[1];
 
@@ -929,9 +905,7 @@
   LogicalResult matchAndRewrite(
       TF::TensorListGetItemOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    auto operands =
-        GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
-
+    auto operands = adaptor.getOperands();
     Value input = operands[0];
     Value index = operands[1];
     rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
@@ -948,8 +922,7 @@
       TF::TensorListLengthOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Value input_handle =
-        GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands())[0];
+    Value input_handle = adaptor.getOperands()[0];
 
     BoolAttr true_attr = rewriter.getBoolAttr(true);
     auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
@@ -970,8 +943,7 @@
       ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
-    auto operands =
-        GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
+    auto operands = adaptor.getOperands();
     Value input = operands[0];
     Value element_shape = operands[1];
 
@@ -1021,8 +993,7 @@
       ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
-    auto operands =
-        GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
+    auto operands = adaptor.getOperands();
     Value input = operands[0];
     Value element_shape = operands[1];
 
@@ -1084,8 +1055,7 @@
   LogicalResult matchAndRewrite(
       TF::IdentityOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    Value input =
-        GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands())[0];
+    Value input = adaptor.getOperands()[0];
     rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), input,
                                                 op->getAttrs());
     return success();
@@ -1098,9 +1068,7 @@
   LogicalResult matchAndRewrite(
       func::ReturnOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    auto operands =
-        GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
-
+    auto operands = adaptor.getOperands();
     rewriter.replaceOpWithNewOp<func::ReturnOp>(op, ValueRange{}, operands,
                                                 op->getAttrs());
     return success();
@@ -1113,8 +1081,7 @@
   LogicalResult matchAndRewrite(
       TF::YieldOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    auto operands =
-        GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands());
+    auto operands = adaptor.getOperands();
     rewriter.replaceOpWithNewOp<TF::YieldOp>(op, operands);
     return success();
   }
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc
index 6f4c2dd..1bbf673 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc
@@ -129,7 +129,7 @@
 
 void AddStablehloQuantToIntPasses(OpPassManager& pm) {
   pm.addNestedPass<func::FuncOp>(
-      mlir::stablehlo::createStablehloLegalizeQuantToIntPass());
+      mlir::stablehlo::createStablehloLegalizeQuantToMathPass());
   // StableHLO -> MHLO legalization.
   pm.addPass(mhlo::createStablehloLegalizeToHloPass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc
index 318ccbd..e71af87 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc
@@ -27,9 +27,9 @@
 #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h"
 #include "tensorflow/compiler/mlir/register_common_dialects.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/monitoring/cell_reader.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/statusor.h"
 
 namespace mlir::quant::stablehlo {
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc
index 8cbd48d..2e9d7a8 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc
@@ -25,12 +25,12 @@
 #include "xla/shape.h"
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/platform_manager.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace mlir::quant::stablehlo {
 namespace {
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc
index e96ab83..d8878ff 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc
@@ -31,7 +31,7 @@
   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
   pm.addPass(mhlo::createHloLegalizeToStablehloPass());
   pm.addNestedPass<mlir::func::FuncOp>(
-      mlir::stablehlo::createStablehloLegalizeQuantToIntPass());
+      mlir::stablehlo::createStablehloLegalizeQuantToMathPass());
   pm.addPass(mhlo::createStablehloLegalizeToHloPass());
   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
   pm.addNestedPass<mlir::func::FuncOp>(CreateVerifyQuantLegalizationPass());
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD
index b465fe1..80167b4 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD
@@ -51,11 +51,11 @@
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:path",
         "@local_tsl//tsl/platform:test",
         "@local_tsl//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
         "@stablehlo//:stablehlo_ops",
     ],
 )
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc
index c3034f4..be49ddb 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc
@@ -32,7 +32,7 @@
 #include "mlir/Support/TypeID.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
 #include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/path.h"
 #include "tsl/platform/test.h"
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 7b90a43..cbde974 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -1137,7 +1137,7 @@
     DefaultValuedOptionalAttr<I64ArrayAttr, "{}">:$low_priority_allowed_batch_sizes,
     DefaultValuedOptionalAttr<I64Attr, "0">:$low_priority_max_enqueued_batches,
     DefaultValuedOptionalAttr<TF_AnyStrAttrOf<["low_priority_padding_with_max_batch_size", "low_priority_padding_with_next_allowed_batch_size", "priority_isolation"]>, "\"low_priority_padding_with_max_batch_size\"">:$mixed_priority_policy,
-    DefaultValuedOptionalAttr<TF_AnyStrAttrOf<["PAD_UP"]>, "\"PAD_UP\"">:$batch_padding_policy,
+    DefaultValuedOptionalAttr<TF_AnyStrAttrOf<["PAD_UP", "BATCH_DOWN", "MINIMIZE_TPU_COST_PER_REQUEST"]>, "\"PAD_UP\"">:$batch_padding_policy,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$enable_large_batch_splitting
   );
 
@@ -11293,6 +11293,7 @@
     DefaultValuedOptionalAttr<BoolAttr, "true">:$use_inter_op_parallelism,
     DefaultValuedOptionalAttr<StrAttr, "\"default\"">:$deterministic,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$preserve_cardinality,
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$use_unbounded_threadpool,
     DefaultValuedOptionalAttr<StrAttr, "\"\"">:$metadata
   );
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir
index 4e82022..a458a20 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir
@@ -424,7 +424,7 @@
   }
 
   // CHECK: func private @f_callee(%[[ARG0:.*]]: tensor<0xf32>) -> tensor<0xf32>
-  // CHECK-SAME: tf._input_shapes = [#tf_type.shape<00>]
+  // CHECK-SAME: tf._input_shapes = [#tf_type.shape<0>]
   func.func private @f_callee(%arg0: tensor<0xf32>, %arg1: tensor<*x!tf_type.resource>) -> tensor<0xf32> attributes {tf._input_shapes = [#tf_type.shape<0>, #tf_type.shape<>]} {
     %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>) -> tensor<0xf32>
     %1 = "tf.AddV2"(%arg0, %0) : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir
index 1789b1d..4a31291 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir
@@ -289,7 +289,7 @@
     %1 = "tf.AddV2"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
     func.return %1 : tensor<f32>
   }
-  // CHECK: func.func private @f_callee(%arg0: tensor<f32>) -> tensor<f32> attributes {tf._input_shapes = [#tf_type.shape<00>]} {
+  // CHECK: func.func private @f_callee(%arg0: tensor<f32>) -> tensor<f32> attributes {tf._input_shapes = [#tf_type.shape<0>]} {
   // CHECK:   %cst = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK:   %0 = "tf.AddV2"(%arg0, %cst) : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK:   return %0 : tensor<f32>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
index a7ee2b1..f400424 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
@@ -829,6 +829,7 @@
         "@local_xla//xla/service:shape_inference",
         "@local_xla//xla/translate/hlo_to_mhlo:hlo_utils",
         "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape",
+        "@local_xla//xla/tsl/util:env_var",
     ],
 )
 
@@ -1094,3 +1095,14 @@
         "@local_tsl//tsl/platform:path",
     ],
 )
+
+tf_cc_test(
+    name = "shape_inference_test",
+    srcs = ["shape_inference_test.cc"],
+    deps = [
+        ":shape_inference_pass",
+        "@com_google_googletest//:gtest_main",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:env",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc
index 54fb7cf..77f9361 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc
@@ -36,13 +36,13 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "xla/tsl/framework/device_type.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/monitoring/cell_reader.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/resource_loader.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/tpu/tpu_defs.h"
 #include "tensorflow/core/util/debug_data_dumper.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace tfrt_compiler {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 4bcd7e0..d9110a2 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -87,6 +87,7 @@
 #include "xla/shape.h"
 #include "xla/translate/hlo_to_mhlo/hlo_utils.h"
 #include "xla/translate/mhlo_to_hlo/type_to_shape.h"
+#include "xla/tsl/util/env_var.h"
 #include "xla/window_util.h"
 #include "xla/xla_data.pb.h"
 #include "tensorflow/core/framework/shape_inference.h"
@@ -110,6 +111,20 @@
 namespace TF {
 namespace {
 
+MLIRContext::Threading GetMlirContextThreading() {
+  bool enable_single_thread_mlir_context = []() {
+    bool result = false;
+    if (auto status = tsl::ReadBoolFromEnvVar(kMLIRContextSingleThreadVar,
+                                              /*default_val=*/false, &result);
+        status.ok()) {
+      return result;
+    }
+    return false;
+  }();
+  return enable_single_thread_mlir_context ? MLIRContext::Threading::DISABLED
+                                           : MLIRContext::Threading::ENABLED;
+}
+
 // Compute a refined type between two types `lhs` and `rhs`, the result type
 // is always more refined (i.e. has more static information) than `lhs`
 // This method will actually merge the information contained in the
@@ -443,6 +458,11 @@
 }
 }  // namespace
 
+// Create an MLIRContext based on the threading setup in the env var.
+std::unique_ptr<MLIRContext> MakeMLIRContextWithThreading() {
+  return std::make_unique<MLIRContext>(GetMlirContextThreading());
+}
+
 // Returns whether type can be further refined.
 bool CanBeRefined(Type type) {
   auto shape_type = mlir::dyn_cast<ShapedType>(type);
@@ -1024,7 +1044,7 @@
   // each `XlaCallModule` op. Uses its own MLIRContext since the loader needs to
   // load additional dialects, which is not allowed for the main context since
   // shape inference may be called from a pass.
-  MLIRContext xla_call_module_context_;
+  std::unique_ptr<MLIRContext> xla_call_module_context_;
   DenseMap<XlaCallModuleOp, std::unique_ptr<tensorflow::XlaCallModuleLoader>>
       xla_call_module_loaders_;
 };
@@ -1036,6 +1056,7 @@
       symbol_users_(symbol_table_, module),
       graph_version_(graph_version),
       propagate_caller_callee_constants_(propagate_caller_callee_constants) {
+  xla_call_module_context_ = MakeMLIRContextWithThreading();
   for (const auto& op_type : ops_to_skip) {
     ops_to_skip_.insert(op_type);
   }
@@ -1242,10 +1263,10 @@
     mlir::DialectRegistry registry;
     registry.insert<mlir::func::FuncDialect>();
     mlir::func::registerAllExtensions(registry);
-    xla_call_module_context_.appendDialectRegistry(registry);
+    xla_call_module_context_->appendDialectRegistry(registry);
 
     auto l = tensorflow::XlaCallModuleLoader::Create(
-        &xla_call_module_context_, op.getVersion(), op.getModule().str(),
+        xla_call_module_context_.get(), op.getVersion(), op.getModule().str(),
         std::move(disabled_checks), std::move(platforms),
         /*num_invocation_args=*/op.getArgs().size(),
         op.getHasTokenInputOutput());
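
The helper above only disables multithreading when TF_USE_SINGLE_THREAD_MLIR_CONTEXT is set to a truthy value; an unset variable or a read error falls back to the multithreaded default. A minimal standalone sketch of that behavior, mirroring the unit tests added below (the main() wrapper is illustrative and not part of the change):

    #include <cstdlib>
    #include <memory>

    #include "mlir/IR/MLIRContext.h"  // from @llvm-project
    #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"

    int main() {
      // Env var unset: the context keeps MLIR's default multithreading.
      std::unique_ptr<mlir::MLIRContext> ctx =
          mlir::TF::MakeMLIRContextWithThreading();
      const bool multithreaded = ctx->isMultithreadingEnabled();  // true

      // Opt into a single-threaded context via the env var.
      setenv(mlir::TF::kMLIRContextSingleThreadVar, "true", /*overwrite=*/1);
      std::unique_ptr<mlir::MLIRContext> single_ctx =
          mlir::TF::MakeMLIRContextWithThreading();
      const bool single = !single_ctx->isMultithreadingEnabled();  // true

      return (multithreaded && single) ? 0 : 1;
    }
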
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
index 46c1bc9..9075754 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
@@ -17,6 +17,7 @@
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SHAPE_INFERENCE_H_
 
 #include <cstdint>
+#include <memory>
 
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
@@ -31,6 +32,9 @@
 namespace mlir {
 namespace TF {
 
+inline constexpr char kMLIRContextSingleThreadVar[] =
+    "TF_USE_SINGLE_THREAD_MLIR_CONTEXT";
+
 // Returns whether type can be further refined.
 bool CanBeRefined(Type type);
 
@@ -71,6 +75,9 @@
                                       int64_t max_iterations = 10,
                                       ArrayRef<TypeID> ops_to_skip = {});
 
+// Create a MLIRContext based on the threading setup in the env var.
+std::unique_ptr<MLIRContext> MakeMLIRContextWithThreading();
+
 }  // namespace TF
 
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_test.cc
new file mode 100644
index 0000000..416807c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_test.cc
@@ -0,0 +1,39 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+
+namespace mlir {
+namespace TF {
+namespace {
+
+TEST(ShapeInferenceTest, CreateMultiThreadedMLIRContext) {
+  std::unique_ptr<MLIRContext> ctx = MakeMLIRContextWithThreading();
+  EXPECT_TRUE(ctx->isMultithreadingEnabled());
+}
+
+TEST(ShapeInferenceTest, CreateSingleThreadedMLIRContext) {
+  setenv(kMLIRContextSingleThreadVar, "true", 1);
+  std::unique_ptr<MLIRContext> ctx = MakeMLIRContextWithThreading();
+  EXPECT_FALSE(ctx->isMultithreadingEnabled());
+}
+
+}  // namespace
+}  // namespace TF
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc
index 9262f87..3720a09 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc
@@ -22,8 +22,8 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/errors.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/statusor.h"
 
 namespace mlir::TF {
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
index 20f771d..4763746 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
@@ -97,12 +97,12 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:statusor",
         "@local_xla//xla:shape_util",
         "@local_xla//xla/client:xla_builder",
         "@local_xla//xla/client:xla_computation",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -179,7 +179,6 @@
         "//tensorflow/core/tpu/kernels:tpu_compile_op_support",
         "//tensorflow/core/tpu/kernels/xla:host_compute_ops",
         "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/lib/monitoring:test_utils",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:statusor",
@@ -187,6 +186,7 @@
         "@local_xla//xla/client:client_library",
         "@local_xla//xla/stream_executor:platform_manager",
         "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -247,8 +247,8 @@
         "@com_google_googletest//:gtest_main",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -299,7 +299,7 @@
         "@com_google_googletest//:gtest_main",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc
index 916b568..bce19ea 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc
@@ -28,9 +28,9 @@
 #include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/register_common_dialects.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/monitoring/cell_reader.h"
 #include "tensorflow/core/platform/resource_loader.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc
index 71640e8..57769d2 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc
@@ -39,6 +39,7 @@
 #include "xla/client/xla_builder.h"
 #include "xla/client/xla_computation.h"
 #include "xla/shape_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/fake_input.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def_builder.h"
@@ -49,7 +50,6 @@
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/monitoring/cell_reader.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc
index 06208be..c3598e6 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc
@@ -29,11 +29,11 @@
 #include "xla/shape.h"
 #include "xla/stream_executor/platform_manager.h"
 #include "xla/translate/mhlo_to_hlo/type_to_shape.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/lib/monitoring/cell_reader.h"
 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/lib/monitoring/test_utils.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/statusor.h"
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc
index cad1edf..1da6d58c 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc
@@ -27,9 +27,9 @@
 #include "mlir/IR/OwningOpRef.h"  // from @llvm-project
 #include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/register_common_dialects.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/monitoring/cell_reader.h"
 #include "tensorflow/core/platform/resource_loader.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
index 74219b2..49b12fa 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
@@ -168,8 +168,8 @@
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -227,7 +227,7 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc
index 233b112..20da430 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc
@@ -33,9 +33,9 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
 #include "tensorflow/compiler/mlir/tf2xla/api/v2/testing/utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/monitoring/cell_reader.h"
 #include "tensorflow/core/platform/resource_loader.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc
index 3f58643..5830676 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc
@@ -29,6 +29,7 @@
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "xla/client/client_library.h"
 #include "xla/stream_executor/platform_manager.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/monitoring/cell_reader.h"
 #include "tensorflow/core/lib/monitoring/test_utils.h"
 #include "tensorflow/core/platform/env.h"
@@ -37,7 +38,6 @@
 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
 #include "tensorflow/core/util/debug_data_dumper.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/lib/monitoring/test_utils.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc
index 4e5199b..22d36ea 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc
@@ -32,9 +32,9 @@
 #include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/register_common_dialects.h"
 #include "tensorflow/compiler/mlir/tf2xla/api/v2/testing/utils.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/monitoring/cell_reader.h"
 #include "tensorflow/core/platform/resource_loader.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace tf2xla {
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/BUILD
index ea2bf17..e635213 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/internal/BUILD
@@ -244,8 +244,8 @@
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Transforms",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -296,7 +296,7 @@
         "//tensorflow/core:testlib",
         "//tensorflow/core/platform:enable_tf2_utils",
         # "//tensorflow/core/platform:resource_loader",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc
index 840d4c9..3365a85 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc
@@ -29,10 +29,10 @@
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/register_common_dialects.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/resource_loader.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/file_statistics.h"
 #include "tsl/platform/status.h"
 
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc
index 6cbc67d..78d027a 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc
@@ -30,6 +30,7 @@
 #include "tensorflow/cc/ops/tpu_functional_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_defs.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/op.h"
@@ -38,7 +39,6 @@
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/platform/enable_tf2_utils.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD
index 6ef95e7..8be5db9 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD
@@ -129,8 +129,8 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:Pass",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:statusor",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -390,7 +390,6 @@
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:statusor",
@@ -399,6 +398,7 @@
         "@local_xla//xla/client:xla_builder",
         "@local_xla//xla/client:xla_computation",
         "@local_xla//xla/mlir_hlo",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -507,8 +507,8 @@
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Support",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:statusor",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc
index 9021095..a17ef43 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc
@@ -34,7 +34,7 @@
 #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc
index aecf9db..c9d0680 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc
@@ -42,9 +42,9 @@
 #include "xla/client/xla_computation.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 #include "xla/shape_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "xla/xla_data.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir
index dbb7773..730d096 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir
@@ -11,6 +11,7 @@
 //
 // CHECK:  func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>)
 // CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { } is_bounded_dynamic_dim: false } retvals { sharding { } } num_replicas: 1 num_cores_per_replica: 1 "
+// CHECK-SAME: device_assignment = []
 // CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64
 // CHECK:      return
 
@@ -38,7 +39,8 @@
 // CHECK:    return
 //
 // CHECK:  func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>)
-// CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } is_bounded_dynamic_dim: false } num_replicas: 1 num_cores_per_replica: 2 device_assignment { replica_count: 1 computation_count: 2 computation_devices { replica_device_ids: 0 } computation_devices { replica_device_ids: 1 } } use_spmd_for_xla_partitioning: true "
+// CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } is_bounded_dynamic_dim: false } num_replicas: 1 num_cores_per_replica: 2 use_spmd_for_xla_partitioning: true "
+// CHECK-SAME: device_assignment = [0, 0, 0, 0, 0, 0, 0, 1]
 // CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64
 // CHECK:      return
 
@@ -70,6 +72,7 @@
 // CHECK:    return
 //
 // CHECK:  func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32>
+// CHECK-SAME:      device_assignment = [0, 0, 0, 0, 0, 0, 0, 1]
 // CHECK-SAME:      tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64
 // CHECK-NEXT:     %0 = "tf.MatMul"(%arg0, %arg1)
 // CHECK:          return
@@ -102,6 +105,7 @@
 // CHECK:    return
 //
 // CHECK:  func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32>
+// CHECK-SAME:      device_assignment = []
 // CHECK-SAME:      tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64
 // CHECK-NEXT:     %0 = "tf.MatMul"(%arg0, %arg1)
 // CHECK:          return
diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_device_cleanup.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_device_cleanup.mlir
new file mode 100644
index 0000000..02afa96
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_device_cleanup.mlir
@@ -0,0 +1,8 @@
+// RUN: tf-tfrt-opt %s -tf-device-cleanup | FileCheck %s
+
+// CHECK-LABEL: func @ops_with_device
+func.func @ops_with_device() {
+  %0 = "tf.VarHandleOp"() {container = "", shared_name = "var", device = "/device/..."} : () -> tensor<!tf_type.resource<tensor<1xf32>>>
+  // CHECK-NOT: device = "/device/..."
+  func.return
+}
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD
index 80969fe..2ec0fdd 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD
@@ -69,6 +69,7 @@
         "lower_to_ifrt_restore_variable.cc",
         "rewrite_cluster_to_ifrt_call.cc",
         "sink_variable_as_named_array.cc",
+        "tf_device_cleanup.cc",
         "tf_identity_propagation.cc",
         "tf_ifrt_passes.cc",
         "tf_restore_merging.cc",
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc
index 6da9eda..3b190d3 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc
@@ -31,6 +31,7 @@
 #include "xla/python/ifrt/client.h"
 #include "xla/python/ifrt/test_util.h"
 #include "xla/tsl/framework/test_util/mock_serving_device_selector.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/resource_loader.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
@@ -38,7 +39,6 @@
 #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h"
 #include "tensorflow/core/tfrt/runtime/runtime.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/threadpool.h"
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td
index 7cdc557..9c37c58 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td
@@ -129,3 +129,14 @@
   let constructor = "CreateTfIdentityPropagationPass()";
 }
 
+def TfDeviceCleanupPass : Pass<"tf-device-cleanup", "mlir::func::FuncOp"> {
+  let summary = "Cleans up device attributes from all ops";
+
+  let description = [{
+    This pass removes `device` attributes from all TF ops, since some serving
+    paths do not rely on `device` attributes from the SavedModel.
+  }];
+
+  let constructor = "CreateTfDeviceCleanupPass()";
+}
+
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc
index aba37ac..2fc2c17 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc
@@ -27,6 +27,7 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
@@ -128,39 +129,9 @@
              << " is missing";
     int num_cores_per_replica = num_cores_per_replica_attr.getInt();
 
-    std::optional<xla::DeviceAssignmentProto> xla_device_assignment;
-    auto topology_attr = cluster_func->getAttrOfType<mlir::StringAttr>(
-        tensorflow::kTopologyAttr);
-    // Get device assignment.
-    auto device_assignment_attr = cluster_func->getAttrOfType<mlir::ArrayAttr>(
-        tensorflow::kDeviceAssignmentAttr);
-    if (topology_attr && device_assignment_attr && !topology_attr.empty() &&
-        !device_assignment_attr.empty()) {
-      auto device_coordinates =
-          tensorflow::GetDeviceCoordinates(device_assignment_attr);
-      if (!device_coordinates.ok())
-        return cluster_func.emitError()
-               << "error in parsing tpu device coordinates: "
-               << device_coordinates.status().message();
-
-      absl::StatusOr<xla::DeviceAssignmentProto>
-          xla_device_assignment_from_device_assignment_attr =
-              tensorflow::GetXlaDeviceAssignmentProto(
-                  topology_attr.getValue(), num_replicas, num_cores_per_replica,
-                  *device_coordinates);
-      if (!xla_device_assignment_from_device_assignment_attr.ok()) {
-        return cluster_func.emitError()
-               << "error in getting xla device assignment: "
-               << xla_device_assignment_from_device_assignment_attr.status()
-                      .message();
-      }
-      xla_device_assignment =
-          *xla_device_assignment_from_device_assignment_attr;
-    }
-
     return mlir::TFTPU::SetMetadataProtoFromClusterFuncOp(
         cluster_func, num_replicas, num_cores_per_replica,
-        std::move(xla_device_assignment), metadata);
+        /*xla_device_assignment=*/std::nullopt, metadata);
   }
 
   void Rewrite(mlir::SymbolTable &symbol_table,
@@ -194,10 +165,16 @@
 
       auto metadata_attr =
           ifrt_program->getAttrOfType<mlir::StringAttr>(kMetadataTextAttrName);
-      if (!metadata_attr) {
+      auto device_assignment_attr =
+          ifrt_program->getAttrOfType<mlir::ArrayAttr>(kDeviceAssignmentAttr);
+      if (!metadata_attr || !device_assignment_attr) {
         return signalPassFailure();
       }
+
+      // For better debuggability, attach attributes such as
+      // tpu_compile_metadata and device_assignment to IfrtCallOp.
       ifrt_call_op->setAttr(kMetadataTextAttrName, metadata_attr);
+      ifrt_call_op->setAttr(kDeviceAssignmentAttr, device_assignment_attr);
 
       // TODO(b/304839793): populate variable names after adding a variable
       // hoisting pass.
@@ -230,6 +207,13 @@
     cloned_ifrt_program->setAttr(kMetadataTextAttrName,
                                  builder.getStringAttr(serialized_metadata));
 
+    auto device_assignment_attr =
+        cluster_func->getAttrOfType<mlir::ArrayAttr>(kDeviceAssignmentAttr);
+    if (!device_assignment_attr) {
+      device_assignment_attr = builder.getArrayAttr({});
+    }
+    cloned_ifrt_program->setAttr(kDeviceAssignmentAttr, device_assignment_attr);
+
     cloned_ifrt_program.setName(ifrt_program_name);
 
     int64_t program_id = NewProgramId();
@@ -250,10 +234,11 @@
     // hoisting pass.
     ifrt_call_op.setVariableArgIndicesAttr(builder.getI32ArrayAttr({}));
     ifrt_call_op.setProgramId(program_id);
-    // Additionally attach tpu_compile_metadata to IfrtCallOp. Some subsequent
-    // pass such as SinkVariableAsNamedArrayPass relies on this attribute.
+    // For better debuggability, attach attributes such as tpu_compile_metadata
+    // and device_assignment to IfrtCallOp.
     ifrt_call_op->setAttr(kMetadataTextAttrName,
                           builder.getStringAttr(serialized_metadata));
+    ifrt_call_op->setAttr(kDeviceAssignmentAttr, device_assignment_attr);
 
     cluster_func->replaceAllUsesWith(ifrt_call_op.getResults());
     cluster_func->erase();
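
Because the compile metadata and device assignment are now mirrored onto each tf.IfrtCall op, a later pass or debugging hook can read them straight off the call site. A rough sketch of such an inspection, assuming the attribute names "__tpu_compile_metadata_text" and "device_assignment" seen in the CHECK lines of the .mlir test above (the constants kMetadataTextAttrName and kDeviceAssignmentAttr are defined outside this hunk, and the helper below is hypothetical):

    #include "llvm/Support/raw_ostream.h"
    #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
    #include "mlir/IR/Operation.h"  // from @llvm-project

    // Hypothetical debugging helper: dumps the attributes that
    // RewriteClusterToIfrtCallPass attaches to a tf.IfrtCall op.
    void DumpIfrtCallDebugAttrs(mlir::Operation* ifrt_call_op) {
      if (auto metadata = ifrt_call_op->getAttrOfType<mlir::StringAttr>(
              "__tpu_compile_metadata_text")) {
        llvm::errs() << "tpu_compile_metadata: " << metadata.getValue() << "\n";
      }
      if (auto device_assignment = ifrt_call_op->getAttrOfType<mlir::ArrayAttr>(
              "device_assignment")) {
        llvm::errs() << "device_assignment entries: " << device_assignment.size()
                     << "\n";
      }
    }
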
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc
index b201370..1de61ab 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc
@@ -37,9 +37,9 @@
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "xla/python/ifrt/client.h"
 #include "xla/python/ifrt/test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/resource_loader.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/protobuf.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_device_cleanup.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_device_cleanup.cc
new file mode 100644
index 0000000..b40c94e
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_device_cleanup.cc
@@ -0,0 +1,51 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+namespace {
+
+#define GEN_PASS_DEF_TFDEVICECLEANUPPASS
+#define GEN_PASS_DECL_TFDEVICECLEANUPPASS
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc"  // IWYU pragma: keep
+
+class TfDeviceCleanupPass
+    : public impl::TfDeviceCleanupPassBase<TfDeviceCleanupPass> {
+ public:
+  void runOnOperation() override {
+    mlir::func::FuncOp func = getOperation();
+    func.walk([](mlir::Operation* op) {
+      if (llvm::isa<mlir::TF::TensorFlowDialect>(op->getDialect())) {
+        op->removeAttr("device");
+      }
+    });
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateTfDeviceCleanupPass() {
+  return std::make_unique<TfDeviceCleanupPass>();
+}
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc
index 2802cb5..6d49f9a 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc
@@ -81,6 +81,10 @@
 
   pm.addPass(CreateRewriteClusterToIfrtCallPass());
 
+  // After the device program is extracted, we can clean up device attributes
+  // from all ops.
+  pm.addNestedPass<mlir::func::FuncOp>(CreateTfDeviceCleanupPass());
+
   // Sink VarHandle with ReadVariableOp: subsequent SinkVariableAsNamedArrayPass
   // rely on the co-existence of VarHandle and ReadVariable in the same
   // function.
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h
index 93713fb..92d9b06 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h
@@ -57,6 +57,10 @@
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 CreateLowerToIfrtRestoreVariablePass();
 
+// Creates a pass that cleans up device attributes from all ops.
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateTfDeviceCleanupPass();
+
 #define GEN_PASS_REGISTRATION
 #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc"  // IWYU pragma: keep
 
diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD
index cb517d1..83b70c2 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD
@@ -20,11 +20,15 @@
     deps = [
         "//tensorflow/core/tfrt/mlrt/bytecode",
         "//tensorflow/core/tfrt/mlrt/bytecode:executable",
+        "//tensorflow/core/tfrt/mlrt/bytecode:function",
+        "//tensorflow/core/tfrt/mlrt/bytecode:kernel",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
@@ -38,10 +42,15 @@
     data = glob(["testdata/**"]),
     deps = [
         ":mlir_to_bytecode",
+        "//tensorflow/core/tfrt/mlrt/bytecode",
         "//tensorflow/core/tfrt/mlrt/bytecode:executable",
         "//tensorflow/core/tfrt/mlrt/interpreter:attribute_span",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest_main",
         "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Support",
         "@local_tsl//tsl/platform:resource_loader",
@@ -57,10 +66,15 @@
     hdrs = ["test_utils.h"],
     deps = [
         # copybara:uncomment "//learning/brain/experimental/tfrt/native_lowering/stubs:tfrt_native_lowering_impl",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
         "//tensorflow/core:core_cpu_base",
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/tfrt/graph_executor:sync_resource_state",
         "//tensorflow/core/tfrt/mlrt/attribute",
         "//tensorflow/core/tfrt/mlrt/bytecode",
         "//tensorflow/core/tfrt/mlrt/bytecode:kernel",
@@ -70,7 +84,9 @@
         "//tensorflow/core/tfrt/stubs:tfrt_native_lowering_stub",
         "//tensorflow/core/tfrt/utils:tensor_util",
         "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
         "@tf_runtime//:hostcontext",
+        "@tf_runtime//:support",
         "@tf_runtime//:tensor",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc
index d3b19eb..52b1826 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc
+++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc
@@ -25,14 +25,26 @@
 #include "absl/container/flat_hash_map.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/Region.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
 #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/function.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h"
 
 namespace mlrt {
 namespace {
diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h
index 7f5416d..9508656 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h
+++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h
@@ -22,9 +22,16 @@
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringRef.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
 #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
 
 namespace mlrt {
diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc
index 9f02f1d..d7d3065d 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc
+++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc
@@ -19,9 +19,20 @@
 #include <vector>
 
 #include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/DialectRegistry.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
 #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h"
 #include "tensorflow/core/tfrt/mlrt/interpreter/attribute_span.h"
 #include "tsl/platform/resource_loader.h"
diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc
index b5a3cb9..e4f9e6f 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc
+++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc
@@ -22,10 +22,19 @@
 #include <utility>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/tfrt/mlrt/attribute/attribute.h"
 #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
 #include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h"
 #include "tensorflow/core/tfrt/mlrt/interpreter/context.h"
 #include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 
 namespace mlrt {
 namespace testing {
diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h
index d569f32..6140c71 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h
+++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h
@@ -21,10 +21,15 @@
 #include <utility>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h"
 #include "tensorflow/core/tfrt/mlrt/attribute/attribute.h"
 #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
 #include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h"
@@ -34,10 +39,13 @@
 #include "tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h"
 #include "tensorflow/core/tfrt/utils/tensor_util.h"
 #include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 #include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
 #include "tfrt/host_context/execution_context.h"  // from @tf_runtime
 #include "tfrt/host_context/host_allocator.h"  // from @tf_runtime
 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
+#include "tfrt/support/string_util.h"  // from @tf_runtime
+#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
 #include "tfrt/tensor/dense_tensor_utils.h"  // from @tf_runtime
 
 namespace mlrt {
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index fc31b8f..0972f67 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1,6 +1,7 @@
 load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test")
 load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_cuda_cc_test")
 load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
+load("//tensorflow/compiler/tests:build_combined_defs.bzl", "tf_xla_combined_py_test")
 load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites", "tf_xla_py_strict_test")
 load(
     "//tensorflow/core/platform:build_config_root.bzl",
@@ -88,15 +89,20 @@
     ],
 )
 
-tf_xla_py_strict_test(
-    name = "adadelta_test",
+tf_xla_combined_py_test(
+    name = "ops_test_mlir_false",
     size = "medium",
-    srcs = ["adadelta_test.py"],
     enable_mlir_bridge = False,
+    package = "tensorflow.compiler.tests",
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
     ],
+    test_files = [
+        # go/keep-sorted start
+        "adadelta_test.py",
+        # go/keep-sorted end
+    ],
     deps = [
         ":xla_test",
         "//tensorflow/python/framework:constant_op",
diff --git a/tensorflow/compiler/tests/build_combined_defs.bzl b/tensorflow/compiler/tests/build_combined_defs.bzl
new file mode 100644
index 0000000..0463fe1
--- /dev/null
+++ b/tensorflow/compiler/tests/build_combined_defs.bzl
@@ -0,0 +1,46 @@
+"""Build rule for combining Tensorflow/XLA tests."""
+
+load("//tensorflow:strict.default.bzl", "py_strict_test")
+load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
+
+def tf_xla_combined_py_test(name = "", package = None, test_files = [], **kwargs):
+    """Generates combined tf_xla_py_test targets, one per XLA backend.
+
+    All tests found in test_files are combined into one new test, which is then
+    passed on to tf_xla_py_test to create a new target per XLA backend.
+
+    Args:
+      name: Name of the target.
+      package: The package that all tests in test_files belong to.
+      test_files: The test files to be combined and tested.
+      **kwargs: keyword arguments passed onto the tf_xla_py_test rule.
+    """
+
+    test_file = name + ".py"
+
+    # run the generator to create the combined test file containing all the tests in test_files
+    # redirecting the output of the generator to test_file.
+    native.genrule(
+        name = name + "_gen",
+        testonly = 1,
+        srcs = test_files,
+        outs = [test_file],
+        cmd = """
+mkdir -p $(@D) && cat > $@ << EOF
+from tensorflow.python.platform import test
+%s
+
+if __name__ == "__main__":
+  test.main()
+EOF
+        """ % "\n".join(["from %s.%s import *" % (package, test[:-3]) for test in test_files]),
+        tools = [],
+        tags = ["generated_python_test=%s.%s" % (package, name)],
+    )
+
+    tf_xla_py_test(
+        name = name,
+        test_rule = py_strict_test,
+        srcs = [test_file] + test_files,
+        **kwargs
+    )
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index 1a9215b..ffab429 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -68,6 +68,7 @@
 #include "absl/types/span.h"
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/flags.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "xla/xla_data.pb.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/framework/device.h"
@@ -97,7 +98,6 @@
 #include "tensorflow/core/util/device_name_utils.h"
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
 
diff --git a/tensorflow/compiler/tests/unary_ops_composition_test.cc b/tensorflow/compiler/tests/unary_ops_composition_test.cc
index 40a5300..58878b8 100644
--- a/tensorflow/compiler/tests/unary_ops_composition_test.cc
+++ b/tensorflow/compiler/tests/unary_ops_composition_test.cc
@@ -21,6 +21,7 @@
 
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/device_factory.h"
@@ -35,7 +36,6 @@
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/util/port.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index c1ac5b1..75e1eb4 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -668,6 +668,10 @@
     hdrs = [
         "xla_resource.h",
     ],
+    visibility = [
+        ":internal",
+        "//learning/deepmind/tensorflow/tpufunc:__pkg__",
+    ],
     deps = [
         ":common",
         ":sharding_util",
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 9ba3ded..b65ac38 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -211,7 +211,6 @@
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/lib/strings:proto_serialization",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:fingerprint",
         "@local_tsl//tsl/platform:statusor",
@@ -226,6 +225,7 @@
         "@local_xla//xla/stream_executor/gpu:gpu_executor_header",
         "@local_xla//xla/stream_executor/gpu:gpu_stream_header",
         "@local_xla//xla/stream_executor/gpu:gpu_types_header",
+        "@local_xla//xla/tsl/lib/strings:proto_serialization",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc
index 4971fd0..6697b06 100644
--- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc
+++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc
@@ -56,6 +56,7 @@
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/platform_manager.h"
 #include "xla/stream_executor/stream.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 #include "xla/util.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
 #include "tensorflow/core/common_runtime/process_state.h"
@@ -73,7 +74,6 @@
 #include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/strings/proto_serialization.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/fingerprint.h"
 #include "tsl/platform/statusor.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
index 39dd10d..31a6b68 100644
--- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -138,6 +138,7 @@
     std::vector<bool> dims_are_dynamic;
     const auto& dims = shape.dims();
     dims_are_dynamic.reserve(dims);
+    output_dim_sizes.reserve(dims);
     for (int64_t i = 0; i < dims; ++i) {
       output_dim_sizes.push_back(
           xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {}));
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 26d3cff..09d6898 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -405,6 +405,7 @@
         std::vector<xla::XlaOp> dynamic_dims;
         const xla::Shape& shape = list_shape.tuple_shapes(i);
         auto sub_element = xla::GetTupleElement(list, i);
+        dynamic_dims.reserve(shape.dimensions_size());
         for (int64_t dim = 0; dim < shape.dimensions_size(); ++dim) {
           dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
         }
diff --git a/tensorflow/compiler/tf2xla/kernels/where_op.cc b/tensorflow/compiler/tf2xla/kernels/where_op.cc
index 73c6c34..1845b9b 100644
--- a/tensorflow/compiler/tf2xla/kernels/where_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/where_op.cc
@@ -275,6 +275,7 @@
   // and then scatter iotas[out_idxs] into the output.
   std::vector<XlaOp> iotas_to_concat;
   auto iota_shape = xla::ShapeUtil::MakeShape(S32, input_shape.dimensions());
+  iotas_to_concat.reserve(iota_shape.rank());
   for (int64_t axis = 0; axis < iota_shape.rank(); ++axis) {
     iotas_to_concat.push_back(
         xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1}));
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
index 787d676..48f0622 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
@@ -336,11 +336,6 @@
       return nullptr;
     }
 
-    absl::StatusOr<std::unique_ptr<se::StreamExecutor>> GetUncachedExecutor(
-        const se::StreamExecutorConfig& config) override {
-      return std::unique_ptr<se::StreamExecutor>(nullptr);
-    }
-
    private:
     string name_;
   };
diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index fc32bda..0877e8a 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -3427,10 +3427,10 @@
         "//tensorflow/core/framework:function_testlib",
         "//tensorflow/core/framework:optimized_function_graph_proto_cc",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:test",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -3477,9 +3477,9 @@
         "//tensorflow/core/kernels:function_ops",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:test",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -3497,7 +3497,7 @@
         "//tensorflow/core:ops",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/core/common_runtime/arg_ret_placement_test.cc b/tensorflow/core/common_runtime/arg_ret_placement_test.cc
index 8aea657..11b8bdb 100644
--- a/tensorflow/core/common_runtime/arg_ret_placement_test.cc
+++ b/tensorflow/core/common_runtime/arg_ret_placement_test.cc
@@ -20,6 +20,7 @@
 
 #include <gtest/gtest.h>
 #include "tensorflow/cc/framework/scope.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/full_type.pb.h"
 #include "tensorflow/core/framework/function.h"
@@ -29,7 +30,6 @@
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
 
diff --git a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc
index a16b90c..6e78f32 100644
--- a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc
+++ b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc
@@ -19,6 +19,7 @@
 #include <string>
 
 #include "tensorflow/cc/framework/scope.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/graph_def_builder_util.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/config/flag_defs.h"
@@ -29,7 +30,6 @@
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index a4712a5..675ffc6 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -206,7 +206,7 @@
                         const Tensor* input, Tensor* output,
                         int dev_to_dev_stream_index, StatusCallback done,
                         bool sync_dst_compute) {
-  profiler::ScopedAnnotation annotation(
+  tsl::profiler::ScopedAnnotation annotation(
       [&] { return absl::StrCat("#edge_name=", edge_name, "#"); });
   VLOG(4) << "Copy " << edge_name;
 
diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
index d2b51ee..1f8c8cd 100644
--- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
+++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
@@ -136,49 +136,6 @@
   std::string key_prefix_;
 };
 
-// Remove LocalDeviceState objects from
-// info->local_device_states that have unique hardware IDs
-// (i.e. ignore duplicate virtual devices) and return them in a map.
-static std::map<int, std::unique_ptr<xla::LocalDeviceState>>
-GetUniqueDeviceStates(PjRtGpuClientCreationInfo* info) {
-  // Only consider each hardware device once. In test environments, one
-  // physical GPU (e.g. hardware_id 0) might be shared as virtual GPUs (e.g.
-  // local_id 0 and 1) by multiple workers (multiple processes on the same
-  // computer). If there is a need to not ignore these for an actual case, a
-  // possible solution is to add a flag to only enable the use of
-  // hardware_id_to_local_id for tests.
-
-  auto input_states = std::move(info->local_device_states);
-
-  absl::flat_hash_map<int, int> hardware_id_to_local_id;
-  for (const auto& id_state : input_states) {
-    int local_id = id_state.second->local_device_id().value();
-    int hardware_id = id_state.second->local_hardware_id().value();
-    if (hardware_id_to_local_id.contains(hardware_id)) {
-      if (hardware_id_to_local_id[hardware_id] > local_id) {
-        // Use the device with the smallest local_id, ignore others.
-        hardware_id_to_local_id[hardware_id] = local_id;
-      }
-    } else {
-      hardware_id_to_local_id[hardware_id] = local_id;
-    }
-  }
-  std::map<int, std::unique_ptr<xla::LocalDeviceState>> local_device_states;
-  for (auto& id_state : input_states) {
-    int local_id = id_state.second->local_device_id().value();
-    int hardware_id = id_state.second->local_hardware_id().value();
-    if (hardware_id_to_local_id[hardware_id] != local_id) {
-      VLOG(1) << "For hardware_id=" << hardware_id
-              << ", ignoring redundant local_id=" << local_id
-              << ". local_id=" << hardware_id_to_local_id[hardware_id]
-              << " will be used instead.";
-      continue;
-    }
-    local_device_states.emplace(id_state.first, std::move(id_state.second));
-  }
-  return local_device_states;
-}
-
 // Coordinate creation of a PjRt GPU client with distributed devices when there
 // are multiple threads (which typically occurs in test environments that use
 // multiple threads to simulate multiple workers).
@@ -319,10 +276,9 @@
 
   auto kv_store =
       std::make_shared<XlaKeyValueStore>(coordination_service_agent);
-  std::map<int, std::unique_ptr<xla::LocalDeviceState>>
-      unique_local_device_states;
+  std::map<int, std::unique_ptr<xla::LocalDeviceState>> local_device_states;
   if (use_creation_info) {
-    unique_local_device_states = GetUniqueDeviceStates(info);
+    local_device_states = std::move(info->local_device_states);
   }
   if (use_creation_info) {
     // Tell any other threads that are waiting to call BuildDistributedDevices to
@@ -330,7 +286,7 @@
     creation_state->SetReady();
   }
   auto device_topology_pair = BuildDistributedDevices(
-      platform_name, std::move(unique_local_device_states), node_id, num_nodes,
+      platform_name, std::move(local_device_states), node_id, num_nodes,
       gpu_run_options.get(), kv_store, /*enable_mock_nccl=*/false);
   if (!device_topology_pair.ok()) {
     if (use_creation_info) {
diff --git a/tensorflow/core/common_runtime/eager/context_test.cc b/tensorflow/core/common_runtime/eager/context_test.cc
index 1758e31..4c50c62 100644
--- a/tensorflow/core/common_runtime/eager/context_test.cc
+++ b/tensorflow/core/common_runtime/eager/context_test.cc
@@ -26,6 +26,7 @@
 #include "tensorflow/c/eager/immediate_execution_context.h"
 #include "tensorflow/c/eager/immediate_execution_operation.h"
 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/composite_device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
@@ -39,7 +40,6 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/public/session_options.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/common_runtime/eager/eager_executor_test.cc b/tensorflow/core/common_runtime/eager/eager_executor_test.cc
index e933a04..1650dbf 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor_test.cc
+++ b/tensorflow/core/common_runtime/eager/eager_executor_test.cc
@@ -17,10 +17,10 @@
 #include <memory>
 #include <utility>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/status_matchers.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 #include "tsl/protobuf/error_codes.pb.h"
 
diff --git a/tensorflow/core/common_runtime/eager/placement_utils_test.cc b/tensorflow/core/common_runtime/eager/placement_utils_test.cc
index 8803d74..6220cc9 100644
--- a/tensorflow/core/common_runtime/eager/placement_utils_test.cc
+++ b/tensorflow/core/common_runtime/eager/placement_utils_test.cc
@@ -21,11 +21,11 @@
 
 #include <gtest/gtest.h>
 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/common_runtime/eager/execute_node.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 #define DEVICE_CPU0 "/job:localhost/replica:0/task:0/device:CPU:0"
 #define DEVICE_CPU0_TASK1 "/job:localhost/replica:0/task:1/device:CPU:0"
diff --git a/tensorflow/core/common_runtime/eager/summary_optimizer_test.cc b/tensorflow/core/common_runtime/eager/summary_optimizer_test.cc
index 1038ba9..efa597f 100644
--- a/tensorflow/core/common_runtime/eager/summary_optimizer_test.cc
+++ b/tensorflow/core/common_runtime/eager/summary_optimizer_test.cc
@@ -19,12 +19,12 @@
 #include <string>
 #include <vector>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index 3f2665e..6aa62de 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -25,13 +25,13 @@
 #include "xla/stream_executor/gpu/gpu_init.h"
 #include "xla/tests/test_macros.h"
 #include "xla/tsl/framework/device_id.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/random.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 #ifdef TF_GPU_USE_PJRT
 #include "xla/pjrt/pjrt_client.h"
diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc
index 66109ae..78c2713 100644
--- a/tensorflow/core/common_runtime/graph_constructor.cc
+++ b/tensorflow/core/common_runtime/graph_constructor.cc
@@ -112,7 +112,7 @@
         : allow_internal_ops(false),
           expect_device_spec(false),
           propagate_device_spec(in.propagate_device_spec),
-          prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
+          prefix(in.prefix.empty() || absl::EndsWith(in.prefix, "/")
                      ? in.prefix
                      : in.prefix + "/"),
           uniquify_names(in.uniquify_names),
diff --git a/tensorflow/core/common_runtime/int32_fulltype_test.cc b/tensorflow/core/common_runtime/int32_fulltype_test.cc
index e6ead59..5d2c0e0 100644
--- a/tensorflow/core/common_runtime/int32_fulltype_test.cc
+++ b/tensorflow/core/common_runtime/int32_fulltype_test.cc
@@ -18,6 +18,7 @@
 #include <string>
 #include <unordered_map>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/graph_def_builder_util.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/framework/full_type.pb.h"
@@ -30,7 +31,6 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/common_runtime/lower_while_op_test.cc b/tensorflow/core/common_runtime/lower_while_op_test.cc
index b57145c..31c1e40 100644
--- a/tensorflow/core/common_runtime/lower_while_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_while_op_test.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
 #include "tensorflow/cc/ops/function_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/graph_constructor.h"
 #include "tensorflow/core/common_runtime/graph_runner.h"
 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
@@ -39,7 +40,6 @@
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/core/common_runtime/next_pluggable_device/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/BUILD
index 74f843c..4a36e8f 100644
--- a/tensorflow/core/common_runtime/next_pluggable_device/BUILD
+++ b/tensorflow/core/common_runtime/next_pluggable_device/BUILD
@@ -329,7 +329,6 @@
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/time",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:env_impl",
         "@local_tsl//tsl/platform:errors",
@@ -341,6 +340,7 @@
         "@local_xla//xla/tsl/distributed_runtime:call_options",
         "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_client",
         "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD
index 7feb974..7862391 100644
--- a/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD
+++ b/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD
@@ -169,8 +169,8 @@
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc
index b5a37df..02ea581 100644
--- a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc
+++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc
@@ -26,6 +26,7 @@
 #include "absl/strings/str_cat.h"
 #include "absl/synchronization/notification.h"
 #include "xla/tsl/framework/allocator.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.h"
 #include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.h"
@@ -45,7 +46,6 @@
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc
index f61be2e..5d62f8c 100644
--- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc
+++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc
@@ -27,8 +27,8 @@
 #include "xla/tsl/distributed_runtime/call_options.h"
 #include "xla/tsl/distributed_runtime/coordination/coordination_client.h"
 #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/status.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/test.h"
 #include "tsl/protobuf/coordination_config.pb.h"
diff --git a/tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc b/tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc
index b2cf3d1..52925d9 100644
--- a/tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc
+++ b/tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc
@@ -21,6 +21,7 @@
 #include <gtest/gtest.h>
 #include "absl/strings/string_view.h"
 #include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_set.h"
@@ -29,7 +30,6 @@
 #include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/public/session_options.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/status.h"
 
diff --git a/tensorflow/core/common_runtime/optimized_function_graph_info.h b/tensorflow/core/common_runtime/optimized_function_graph_info.h
index dd05b02..b15790d 100644
--- a/tensorflow/core/common_runtime/optimized_function_graph_info.h
+++ b/tensorflow/core/common_runtime/optimized_function_graph_info.h
@@ -71,10 +71,10 @@
   OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo& info) = delete;
   OptimizedFunctionGraphInfo& operator=(OptimizedFunctionGraphInfo& info) =
       delete;
-  OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo&& info) noexcept =
-      default;
-  OptimizedFunctionGraphInfo& operator=(
-      OptimizedFunctionGraphInfo&& info) noexcept = default;
+  OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo&& info) =
+      default;  // NOLINT
+  OptimizedFunctionGraphInfo& operator=(OptimizedFunctionGraphInfo&& info) =
+      default;  // NOLINT
 
   // Converts from the struct to OptimizedFunctionGraph proto.
   static OptimizedFunctionGraph ToProto(const OptimizedFunctionGraphInfo& info);
diff --git a/tensorflow/core/common_runtime/optimized_function_graph_info_test.cc b/tensorflow/core/common_runtime/optimized_function_graph_info_test.cc
index 800da5f..cab15e6 100644
--- a/tensorflow/core/common_runtime/optimized_function_graph_info_test.cc
+++ b/tensorflow/core/common_runtime/optimized_function_graph_info_test.cc
@@ -23,11 +23,11 @@
 #include "absl/status/statusor.h"
 #include "absl/strings/substitute.h"
 #include "third_party/protobuf/text_format.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/optimized_function_graph.pb.h"
 #include "tensorflow/core/graph/node_builder.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/status_matchers.h"
diff --git a/tensorflow/core/common_runtime/replicate_constants_pass_test.cc b/tensorflow/core/common_runtime/replicate_constants_pass_test.cc
index 346ba17..bf335df 100644
--- a/tensorflow/core/common_runtime/replicate_constants_pass_test.cc
+++ b/tensorflow/core/common_runtime/replicate_constants_pass_test.cc
@@ -22,6 +22,7 @@
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/const_op.h"
 #include "tensorflow/cc/ops/math_ops.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/config/flag_defs.h"
 #include "tensorflow/core/config/flags.h"
@@ -29,7 +30,6 @@
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/test.h"
 
diff --git a/tensorflow/core/config/flag_defs.h b/tensorflow/core/config/flag_defs.h
index a773fbb..23e9989 100644
--- a/tensorflow/core/config/flag_defs.h
+++ b/tensorflow/core/config/flag_defs.h
@@ -64,6 +64,9 @@
   // TODO(b/341325107): Make this behavior the default and remove the flag.
   TF_DECLARE_FLAG(enable_function_pruning_before_inlining, false,
                   "If true, functions will be pruned before inlining.")
+  TF_DECLARE_FLAG(enable_skip_encapsulation_for_non_tpu_graphs, false,
+                  "If true, TF2XLA encapsulation will be skipped for non-TPU "
+                  "graphs.")
   // LINT.ThenChange(//tensorflow/core/config/flags_api_wrapper.cc)
 };
 
diff --git a/tensorflow/core/config/flags_api_wrapper.cc b/tensorflow/core/config/flags_api_wrapper.cc
index 096d48c..060ede3 100644
--- a/tensorflow/core/config/flags_api_wrapper.cc
+++ b/tensorflow/core/config/flags_api_wrapper.cc
@@ -55,5 +55,6 @@
   TF_PY_DECLARE_FLAG(enable_colocation_key_propagation_in_while_op_lowering);
   TF_PY_DECLARE_FLAG(enable_tf2min_ici_weight)
   TF_PY_DECLARE_FLAG(enable_function_pruning_before_inlining)
+  TF_PY_DECLARE_FLAG(enable_skip_encapsulation_for_non_tpu_graphs)
   // LINT.ThenChange(//tensorflow/core/config/flag_defs.h)
 };
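
Once declared in flag_defs.h and exported in flags_api_wrapper.cc, the new
flag can be read from C++ through the global flags object. A minimal sketch,
assuming the flags::Global() accessor and per-flag value() method from
tensorflow/core/config/flags.h behave as they do for the existing flags; the
helper function below is hypothetical and not part of this change:

    #include "tensorflow/core/config/flag_defs.h"
    #include "tensorflow/core/config/flags.h"

    // Hypothetical helper: decide whether to skip TF2XLA encapsulation for a
    // non-TPU graph based on the new flag.
    bool ShouldSkipEncapsulation(bool is_tpu_graph) {
      return !is_tpu_graph &&
             tensorflow::flags::Global()
                 .enable_skip_encapsulation_for_non_tpu_graphs.value();
    }
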
diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD
index 1ec7a6f..748dfc1 100644
--- a/tensorflow/core/data/BUILD
+++ b/tensorflow/core/data/BUILD
@@ -218,7 +218,9 @@
         "//tensorflow/core:framework",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:statusor",
diff --git a/tensorflow/core/data/captured_function.cc b/tensorflow/core/data/captured_function.cc
index 6ddf987..2206cb0 100644
--- a/tensorflow/core/data/captured_function.cc
+++ b/tensorflow/core/data/captured_function.cc
@@ -943,8 +943,10 @@
 }
 
 void InstantiatedCapturedFunction::RunAsync(
-    IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
-    FunctionLibraryRuntime::DoneCallback done,
+    std::function<void(std::function<void()>)> runner,
+    CancellationManager* parent_cancellation_manager,
+    CollectiveExecutor* collective_executor, std::vector<Tensor>&& args,
+    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done,
     const std::shared_ptr<model::Node>& node) const {
   auto& info = captured_func_->short_circuit_info();
   if (!info.indices.empty()) {
@@ -952,7 +954,7 @@
     // potentially do a non-trivial amount of (e.g. copying) work, and we may
     // want to run that concurrently with the next invocation.
     Status s = RunShortCircuit(info, std::move(args), captured_func_, rets);
-    (*ctx->runner())(
+    runner(
         std::bind([s](FunctionLibraryRuntime::DoneCallback& done) { done(s); },
                   std::move(done)));
     return;
@@ -971,18 +973,18 @@
         resource_mgr->Cleanup(name).IgnoreError();
       });
   f_opts.step_container = step_container;
-  f_opts.runner = ctx->runner();
+  f_opts.runner = &runner;
   f_opts.create_rendezvous = ShouldCreateRendezvous();
   auto cancellation_manager =
-      std::make_unique<CancellationManager>(ctx->cancellation_manager());
+      std::make_unique<CancellationManager>(parent_cancellation_manager);
   f_opts.cancellation_manager = cancellation_manager.get();
-  f_opts.collective_executor = ctx->collective_executor();
+  f_opts.collective_executor = collective_executor;
 
   std::shared_ptr<SimpleStepStatsCollector> stats_collector;
-  if (node || ctx->stats_aggregator()) {
+  if (node) {
     stats_collector = std::make_shared<SimpleStepStatsCollector>();
   }
-  const bool collect_usage = node && ctx->model();
+  const bool collect_usage = node != nullptr;
   f_opts.stats_collector = stats_collector.get();
 
   // Transfer ownership of the cancellation manager to `callback`.
@@ -992,7 +994,6 @@
       [this, rets, step_container, raw_cancellation_manager, frame, node,
        collect_usage](
           const FunctionLibraryRuntime::DoneCallback& done,
-          IteratorContext* ctx,
           const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
           // Begin unbound arguments.
           Status s) {
@@ -1003,18 +1004,6 @@
         }
         delete frame;
         if (node) {
-          // TODO(b/129085499) Utilize the `node_name` which would be unique
-          // than the prefix for the function execution time statistics.
-          // prefix_with_func_name would then be node_name + func_name.
-          if (ctx->stats_aggregator()) {
-            string prefix_with_func_name =
-                strings::StrCat(node->name(), stats_utils::kDelimiter,
-                                captured_func_->func().name());
-            ctx->stats_aggregator()->AddToHistogram(
-                stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
-                {static_cast<float>(stats_collector->processing_time())},
-                node->num_elements());
-          }
           node->add_processing_time(stats_collector->processing_time());
         }
         if (collect_usage) {
@@ -1025,7 +1014,7 @@
           node->record_stop(EnvTime::NowNanos());
         }
       },
-      std::move(done), ctx, std::move(stats_collector), std::placeholders::_1);
+      std::move(done), std::move(stats_collector), std::placeholders::_1);
 
   tsl::profiler::TraceMe activity(
       [&] {
diff --git a/tensorflow/core/data/captured_function.h b/tensorflow/core/data/captured_function.h
index e415c54..854d9fc 100644
--- a/tensorflow/core/data/captured_function.h
+++ b/tensorflow/core/data/captured_function.h
@@ -288,6 +288,18 @@
   void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
                 std::vector<Tensor>* rets,
                 FunctionLibraryRuntime::DoneCallback done,
+                const std::shared_ptr<model::Node>& node) const {
+    RunAsync(*(ctx->runner()), ctx->cancellation_manager(),
+             ctx->collective_executor(), std::move(args), rets, done, node);
+  }
+
+  // A version of `RunAsync` that takes a runner, a cancellation manager, and
+  // a collective executor instead of an `IteratorContext`.
+  void RunAsync(std::function<void(std::function<void()>)> runner,
+                CancellationManager* parent_cancellation_manager,
+                CollectiveExecutor* collective_executor,
+                std::vector<Tensor>&& args, std::vector<Tensor>* rets,
+                FunctionLibraryRuntime::DoneCallback done,
                 const std::shared_ptr<model::Node>& node) const;
 
   std::string func_name() const { return captured_func_->func().name(); }
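
The new RunAsync overload lets callers that have no IteratorContext supply the
runner, cancellation manager, and collective executor themselves. A rough
sketch under assumptions (the helper below and its parameter names are not
part of this change; the cancellation manager and rets must outlive the
asynchronous call):

    #include <functional>
    #include <utility>
    #include <vector>

    #include "tensorflow/core/data/captured_function.h"
    #include "tensorflow/core/framework/cancellation.h"
    #include "tensorflow/core/framework/function.h"
    #include "tensorflow/core/framework/tensor.h"

    // Hypothetical caller: drives the IteratorContext-free overload with an
    // inline runner. A real caller would hand closures to a thread pool.
    void RunWithoutIteratorContext(
        const tensorflow::data::InstantiatedCapturedFunction& fn,
        tensorflow::CancellationManager* cancellation_manager,
        std::vector<tensorflow::Tensor>&& args,
        std::vector<tensorflow::Tensor>* rets,
        tensorflow::FunctionLibraryRuntime::DoneCallback done) {
      std::function<void(std::function<void()>)> runner =
          [](std::function<void()> work) { work(); };
      fn.RunAsync(runner, cancellation_manager,
                  /*collective_executor=*/nullptr, std::move(args), rets,
                  std::move(done), /*node=*/nullptr);
    }
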
diff --git a/tensorflow/core/data/dataset_test_base.cc b/tensorflow/core/data/dataset_test_base.cc
index 7e295e3..e770b4f 100644
--- a/tensorflow/core/data/dataset_test_base.cc
+++ b/tensorflow/core/data/dataset_test_base.cc
@@ -348,7 +348,7 @@
 
 Status DatasetOpsTestBase::CreateDatasetContext(
     OpKernel* const dateset_kernel,
-    gtl::InlinedVector<TensorValue, 4>* const inputs,
+    absl::InlinedVector<TensorValue, 4>* const inputs,
     std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
     std::unique_ptr<OpKernelContext>* dataset_context) {
   Status status = CheckOpKernelInput(*dateset_kernel, *inputs);
@@ -515,13 +515,13 @@
 }
 
 Status DatasetOpsTestBase::CreateOpKernelContext(
-    OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
+    OpKernel* kernel, absl::InlinedVector<TensorValue, 4>* inputs,
     std::unique_ptr<OpKernelContext>* context) {
   return CreateOpKernelContext(kernel, inputs, &params_, context);
 }
 
 Status DatasetOpsTestBase::CreateOpKernelContext(
-    OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
+    OpKernel* kernel, absl::InlinedVector<TensorValue, 4>* inputs,
     std::unique_ptr<OpKernelContext::Params>* context_params,
     std::unique_ptr<OpKernelContext>* context) {
   auto params = std::make_unique<OpKernelContext::Params>();
@@ -565,7 +565,7 @@
 }
 
 Status DatasetOpsTestBase::CheckOpKernelInput(
-    const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
+    const OpKernel& kernel, const absl::InlinedVector<TensorValue, 4>& inputs) {
   if (kernel.num_inputs() != inputs.size()) {
     return errors::InvalidArgument("The number of input elements should be ",
                                    kernel.num_inputs(),
@@ -575,7 +575,7 @@
 }
 
 Status DatasetOpsTestBase::AddDatasetInput(
-    gtl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
+    absl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
     DataType dtype, const TensorShape& shape) {
   if (input_types.size() < inputs->size()) {
     return errors::InvalidArgument("Adding more inputs than types: ",
@@ -862,7 +862,7 @@
     input_datasets.push_back(t.get());
     created_tensors->push_back(std::move(t));
   }
-  gtl::InlinedVector<TensorValue, 4> inputs;
+  absl::InlinedVector<TensorValue, 4> inputs;
   inputs.reserve(input_datasets.size());
   for (auto input_dataset : input_datasets) {
     inputs.emplace_back(TensorValue(input_dataset));
@@ -985,7 +985,7 @@
   TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
 
   auto input_tensors = dataset_params.GetInputTensors();
-  gtl::InlinedVector<TensorValue, 4> inputs;
+  absl::InlinedVector<TensorValue, 4> inputs;
   inputs.reserve(input_datasets.size() + input_tensors.size());
   for (auto input_dataset : input_datasets) {
     inputs.emplace_back(TensorValue(input_dataset));
@@ -1165,7 +1165,7 @@
     const std::vector<Tensor>& input_components) {
   std::vector<PartialTensorShape> shapes;
   for (const auto& component : input_components) {
-    gtl::InlinedVector<int64_t, 4> partial_dim_sizes;
+    absl::InlinedVector<int64_t, 4> partial_dim_sizes;
     for (int i = 1; i < component.dims(); ++i) {
       partial_dim_sizes.push_back(component.dim_size(i));
     }
diff --git a/tensorflow/core/data/dataset_test_base.h b/tensorflow/core/data/dataset_test_base.h
index ec9805b..e727823 100644
--- a/tensorflow/core/data/dataset_test_base.h
+++ b/tensorflow/core/data/dataset_test_base.h
@@ -766,7 +766,7 @@
 
   // Creates a new op kernel context.
   Status CreateDatasetContext(
-      OpKernel* dateset_kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
+      OpKernel* dateset_kernel, absl::InlinedVector<TensorValue, 4>* inputs,
       std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
       std::unique_ptr<OpKernelContext>* dataset_context);
 
@@ -798,16 +798,16 @@
 
   // Checks that the size of `inputs` matches the requirement of the op kernel.
   Status CheckOpKernelInput(const OpKernel& kernel,
-                            const gtl::InlinedVector<TensorValue, 4>& inputs);
+                            const absl::InlinedVector<TensorValue, 4>& inputs);
 
   // Creates a new context for running the dataset operation.
   Status CreateOpKernelContext(OpKernel* kernel,
-                               gtl::InlinedVector<TensorValue, 4>* inputs,
+                               absl::InlinedVector<TensorValue, 4>* inputs,
                                std::unique_ptr<OpKernelContext>* context);
 
   // Creates a new context for running the dataset operation.
   Status CreateOpKernelContext(OpKernel* kernel,
-                               gtl::InlinedVector<TensorValue, 4>* inputs,
+                               absl::InlinedVector<TensorValue, 4>* inputs,
                                std::unique_ptr<OpKernelContext::Params>* params,
                                std::unique_ptr<OpKernelContext>* context);
 
@@ -856,7 +856,7 @@
 
   // Adds an empty tensor with the specified dtype and shape to the input
   // vector.
-  Status AddDatasetInput(gtl::InlinedVector<TensorValue, 4>* inputs,
+  Status AddDatasetInput(absl::InlinedVector<TensorValue, 4>* inputs,
                          DataTypeVector input_types, DataType dtype,
                          const TensorShape& shape);
 
diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc
index cc7ed17..1934599 100644
--- a/tensorflow/core/data/dataset_utils.cc
+++ b/tensorflow/core/data/dataset_utils.cc
@@ -1018,7 +1018,7 @@
                             AllTasks);
 REGISTER_DATASET_EXPERIMENT("no_compression", RandomJobSamplePercentage<0>,
                             AllTasks);
-REGISTER_DATASET_EXPERIMENT("no_compression_v2", RandomJobSamplePercentage<50>,
+REGISTER_DATASET_EXPERIMENT("no_compression_v2", RandomJobSamplePercentage<0>,
                             AllTasks);
 REGISTER_DATASET_EXPERIMENT("inject_io_prefetch", RandomJobSamplePercentage<0>,
                             AllTasks);
diff --git a/tensorflow/core/data/dataset_utils_test.cc b/tensorflow/core/data/dataset_utils_test.cc
index e581f6e..2e107eb 100644
--- a/tensorflow/core/data/dataset_utils_test.cc
+++ b/tensorflow/core/data/dataset_utils_test.cc
@@ -359,11 +359,10 @@
   auto opt_ins = test_case.opt_ins;
   auto opt_outs = test_case.opt_outs;
   if (!opt_ins.empty()) {
-    setenv("TF_DATA_EXPERIMENT_OPT_IN", str_util::Join(opt_ins, ",").c_str(),
-           1);
+    setenv("TF_DATA_EXPERIMENT_OPT_IN", absl::StrJoin(opt_ins, ",").c_str(), 1);
   }
   if (!opt_outs.empty()) {
-    setenv("TF_DATA_EXPERIMENT_OPT_OUT", str_util::Join(opt_outs, ",").c_str(),
+    setenv("TF_DATA_EXPERIMENT_OPT_OUT", absl::StrJoin(opt_outs, ",").c_str(),
            1);
   }
   const std::string job_name = "job";
@@ -376,14 +375,14 @@
   for (const auto& experiment : test_case.expected_in) {
     EXPECT_TRUE(experiment_set.find(experiment) != experiment_set.end())
         << "experiment=" << experiment << " opt_ins={"
-        << str_util::Join(opt_ins, ",") << "} opt_outs={"
-        << str_util::Join(opt_outs, ",") << "}";
+        << absl::StrJoin(opt_ins, ",") << "} opt_outs={"
+        << absl::StrJoin(opt_outs, ",") << "}";
   }
   for (const auto& experiment : test_case.expected_out) {
     EXPECT_TRUE(experiment_set.find(experiment) == experiment_set.end())
         << "experiment=" << experiment << " opt_ins={"
-        << str_util::Join(opt_ins, ",") << "} opt_outs={"
-        << str_util::Join(opt_outs, ",") << "}";
+        << absl::StrJoin(opt_ins, ",") << "} opt_outs={"
+        << absl::StrJoin(opt_outs, ",") << "}";
   }
 
   if (!opt_ins.empty()) {
diff --git a/tensorflow/core/data/global_shuffle_utils.cc b/tensorflow/core/data/global_shuffle_utils.cc
index 132a35f..dc42563 100644
--- a/tensorflow/core/data/global_shuffle_utils.cc
+++ b/tensorflow/core/data/global_shuffle_utils.cc
@@ -16,10 +16,13 @@
 
 #include <cstdint>
 #include <optional>
+#include <string>
 #include <vector>
 
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -29,6 +32,13 @@
 namespace tensorflow {
 namespace data {
 
+namespace {
+
+constexpr absl::string_view kGlobalShuffleIteratorNextIndex =
+    "global_shuffle_iterator_next_index";
+
+}  // namespace
+
 IteratorContextWithIndexMapper::IteratorContextWithIndexMapper(
     IteratorContext* ctx, const IteratorBase* iterator)
     : ctx_(ctx) {
@@ -60,10 +70,22 @@
   }
 
   absl::MutexLock l(&mu_);
-  TF_ASSIGN_OR_RETURN(int64_t output_index,
-                      ctx->index_mapper()(element_count_++));
+  absl::StatusOr<int64_t> shuffled_index =
+      absl::NotFoundError("Default not found");
+
+  // The index mapper may report NotFound for positions that carry no element;
+  // keep advancing `element_count_` until it yields a usable index or a
+  // terminal status such as OutOfRange.
+  while (absl::IsNotFound(shuffled_index.status())) {
+    shuffled_index = ctx->index_mapper()(element_count_++);
+  }
+
+  if (absl::IsOutOfRange(shuffled_index.status())) {
+    *end_of_sequence = true;
+    return absl::OkStatus();
+  }
+
+  TF_RETURN_IF_ERROR(shuffled_index.status());
+
   absl::Status status =
-      dataset_->Get(AnyContext(ctx), output_index, out_tensors);
+      dataset_->Get(AnyContext(ctx), shuffled_index.value(), out_tensors);
   if (absl::IsOutOfRange(status)) {
     *end_of_sequence = true;
     return absl::OkStatus();
@@ -73,7 +95,18 @@
   return absl::OkStatus();
 }
 
-absl::Status GlobalShuffleIterator::Restore(IteratorContext* ctx) {
+absl::Status GlobalShuffleIterator::Save(
+    const std::string& parent_iterator_prefix, SerializationContext* ctx,
+    IteratorStateWriter* writer) {
+  absl::MutexLock l(&mu_);
+  TF_RETURN_IF_ERROR(writer->WriteScalar(
+      parent_iterator_prefix, kGlobalShuffleIteratorNextIndex, element_count_));
+  return absl::OkStatus();
+}
+
+absl::Status GlobalShuffleIterator::Restore(
+    const std::string& parent_iterator_prefix, IteratorContext* ctx,
+    IteratorStateReader* reader) {
   if (!ctx->restored_element_count().has_value()) {
     return absl::FailedPreconditionError(absl::StrCat(
         "Trying to restore random element count for dataset ",
@@ -81,7 +114,9 @@
   }
 
   absl::MutexLock l(&mu_);
-  element_count_ = *(ctx->restored_element_count());
+  TF_RETURN_IF_ERROR(reader->ReadScalar(parent_iterator_prefix,
+                                        kGlobalShuffleIteratorNextIndex,
+                                        &element_count_));
   return absl::OkStatus();
 }
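
The GetNext change above treats the index mapper as possibly sparse: NotFound
marks a position with no element and is skipped, OutOfRange ends the sequence.
A toy mapper illustrating that assumed contract (not part of this change):

    #include <cstdint>

    #include "absl/status/status.h"
    #include "absl/status/statusor.h"

    // Maps logical positions to a shuffled order of three elements with one
    // hole. GetNext skips the NotFound slot and stops at OutOfRange.
    absl::StatusOr<int64_t> ToyIndexMapper(int64_t element_position) {
      static constexpr int64_t kShuffled[] = {2, -1, 0, 1};  // -1 is a hole
      if (element_position >= 4) {
        return absl::OutOfRangeError("End of shuffled sequence");
      }
      if (kShuffled[element_position] < 0) {
        return absl::NotFoundError("No element at this position");
      }
      return kShuffled[element_position];
    }
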
 
diff --git a/tensorflow/core/data/global_shuffle_utils.h b/tensorflow/core/data/global_shuffle_utils.h
index 91b4fa0..c7513a0 100644
--- a/tensorflow/core/data/global_shuffle_utils.h
+++ b/tensorflow/core/data/global_shuffle_utils.h
@@ -75,9 +75,13 @@
   absl::Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                        bool* end_of_sequence);
 
+  // Saves the element count under `parent_iterator_prefix`.
+  absl::Status Save(const std::string& parent_iterator_prefix,
+                    SerializationContext* ctx, IteratorStateWriter* writer);
+
   // Restores the element count.
   // REQUIRES: ctx->restored_element_count() != nullopt.
-  absl::Status Restore(IteratorContext* ctx);
+  absl::Status Restore(const std::string& parent_iterator_prefix,
+                       IteratorContext* ctx, IteratorStateReader* reader);
 
  private:
   const DatasetBase* const dataset_;
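
With the prefix-aware Save/Restore above, a parent iterator can delegate its
checkpointing directly to the GlobalShuffleIterator. A sketch of that
delegation, assuming the usual DatasetIterator hooks (SaveInternal,
RestoreInternal, prefix()); the member name global_shuffle_iterator_ is
hypothetical:

    // Member functions of a hypothetical DatasetIterator subclass that owns a
    // GlobalShuffleIterator named global_shuffle_iterator_.
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // Persist the element count under this iterator's prefix.
      return global_shuffle_iterator_.Save(prefix(), ctx, writer);
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      // Requires ctx->restored_element_count() to be set by the caller.
      return global_shuffle_iterator_.Restore(prefix(), ctx, reader);
    }
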
diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index fe265ba..3bfbac2 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -680,8 +680,6 @@
     # copybara:uncomment copts = ["-Wthread-safety-analysis"],
     deps = [
         ":credentials_factory",
-        "//tensorflow/core:framework",
-        "//tensorflow/core/data:dataset_utils",
     ],
 )
 
@@ -758,8 +756,8 @@
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:statusor",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/core/data/service/auto_scaler_test.cc b/tensorflow/core/data/service/auto_scaler_test.cc
index c04ea49..299715d 100644
--- a/tensorflow/core/data/service/auto_scaler_test.cc
+++ b/tensorflow/core/data/service/auto_scaler_test.cc
@@ -18,9 +18,9 @@
 #include <optional>
 
 #include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/lib/monitoring/cell_reader.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status_matchers.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/data/service/client/BUILD b/tensorflow/core/data/service/client/BUILD
index 16bd0ef..34b0b7a 100644
--- a/tensorflow/core/data/service/client/BUILD
+++ b/tensorflow/core/data/service/client/BUILD
@@ -120,10 +120,10 @@
         "//tensorflow/core/data/service:dispatcher_client",
         "//tensorflow/core/data/service:test_cluster",
         "//tensorflow/core/data/service:test_util",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/protobuf:protos_all_cc",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ] + tf_grpc_cc_dependencies() + tf_protos_profiler_service(),
 )
 
diff --git a/tensorflow/core/data/service/client/data_service_client_test.cc b/tensorflow/core/data/service/client/data_service_client_test.cc
index 8ec654b..09b1ede 100644
--- a/tensorflow/core/data/service/client/data_service_client_test.cc
+++ b/tensorflow/core/data/service/client/data_service_client_test.cc
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/service/client/common.h"
 #include "tensorflow/core/data/service/common.h"
 #include "tensorflow/core/data/service/test_cluster.h"
@@ -33,7 +34,6 @@
 #include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/protobuf/data_service.pb.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace data {
diff --git a/tensorflow/core/data/service/client/utils_test.cc b/tensorflow/core/data/service/client/utils_test.cc
index c3d9451..8729bff 100644
--- a/tensorflow/core/data/service/client/utils_test.cc
+++ b/tensorflow/core/data/service/client/utils_test.cc
@@ -17,13 +17,13 @@
 #include <optional>
 #include <string>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/service/dispatcher_client.h"
 #include "tensorflow/core/data/service/test_cluster.h"
 #include "tensorflow/core/data/service/test_util.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/data_service.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/protobuf/error_codes.pb.h"
diff --git a/tensorflow/core/data/service/dispatcher_state_test.cc b/tensorflow/core/data/service/dispatcher_state_test.cc
index b1341be..e561ecb 100644
--- a/tensorflow/core/data/service/dispatcher_state_test.cc
+++ b/tensorflow/core/data/service/dispatcher_state_test.cc
@@ -21,13 +21,13 @@
 
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/journal.pb.h"
 #include "tensorflow/core/platform/random.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/data_service.pb.h"
 #include "tensorflow/core/protobuf/service_config.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status_matchers.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/data/service/py_utils.cc b/tensorflow/core/data/service/py_utils.cc
index d14e1c9..be5308d 100644
--- a/tensorflow/core/data/service/py_utils.cc
+++ b/tensorflow/core/data/service/py_utils.cc
@@ -17,9 +17,7 @@
 
 #include <string>
 
-#include "tensorflow/core/data/dataset_utils.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
-#include "tensorflow/core/framework/metrics.h"
 
 namespace tensorflow {
 namespace data {
@@ -39,17 +37,5 @@
   return "grpc";
 }
 
-bool DisableCompressionAtRegistrationTime() {
-#if defined(PLATFORM_GOOGLE)
-  if (!GetExperiments().contains("no_compression_v2")) {
-    return false;
-  }
-  metrics::RecordTFDataServiceCompressionAction(
-      "disabled_at_registration_time");
-  return true;
-#endif  // PLATFORM_GOOGLE
-  return false;
-}
-
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/data/service/py_utils.h b/tensorflow/core/data/service/py_utils.h
index 010c155..b0ea892 100644
--- a/tensorflow/core/data/service/py_utils.h
+++ b/tensorflow/core/data/service/py_utils.h
@@ -27,10 +27,6 @@
 // Returns the default protocol to use for tf.data service control flow.
 std::string DefaultProtocol();
 
-// Returns `true` if tf.data service compression is to be disabled at
-// registration time.
-bool DisableCompressionAtRegistrationTime();
-
 }  // namespace data
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/data/service/snapshot/BUILD b/tensorflow/core/data/service/snapshot/BUILD
index 40b5cba..35d28b0 100644
--- a/tensorflow/core/data/service/snapshot/BUILD
+++ b/tensorflow/core/data/service/snapshot/BUILD
@@ -150,13 +150,13 @@
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:path",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:test",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc
index 71e9bfc..8974964 100644
--- a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc
+++ b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc
@@ -20,6 +20,7 @@
 
 #include "absl/status/status.h"
 #include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/service/dispatcher_client.h"
 #include "tensorflow/core/data/service/snapshot/path_utils.h"
 #include "tensorflow/core/data/service/snapshot/test_utils.h"
@@ -28,7 +29,6 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/protobuf/snapshot.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/lib/io/compression.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/path.h"
diff --git a/tensorflow/core/data/service/snapshot/file_utils_test.cc b/tensorflow/core/data/service/snapshot/file_utils_test.cc
index 9bf1e52..9582cab 100644
--- a/tensorflow/core/data/service/snapshot/file_utils_test.cc
+++ b/tensorflow/core/data/service/snapshot/file_utils_test.cc
@@ -18,13 +18,13 @@
 #include <string>
 #include <vector>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/dataset_test_base.h"
 #include "tensorflow/core/data/service/test_util.h"
 #include "tensorflow/core/data/snapshot_utils.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/lib/io/compression.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
diff --git a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc
index d55e7d1..43944c6 100644
--- a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc
+++ b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc
@@ -30,10 +30,10 @@
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/service/byte_size.h"
 #include "tensorflow/core/data/snapshot_utils.h"
 #include "tensorflow/core/framework/tensor.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/lib/io/compression.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
diff --git a/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc b/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc
index 0f0c9f9..1e019a1 100644
--- a/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc
+++ b/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc
@@ -31,13 +31,13 @@
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "absl/synchronization/mutex.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/split_provider.h"
 #include "tensorflow/core/data/service/test_util.h"
 #include "tensorflow/core/data/snapshot_utils.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/tensor.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/lib/io/compression.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc
index 4dc6f23..e40fd0a 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc
@@ -26,13 +26,13 @@
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "absl/synchronization/mutex.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/serialization_utils.h"
 #include "tensorflow/core/data/service/snapshot/file_utils.h"
 #include "tensorflow/core/data/service/snapshot/path_utils.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/variant_tensor_data.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/path.h"
diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc
index 9e11653..65b3c59 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc
@@ -17,13 +17,13 @@
 #include <memory>
 #include <string>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/snapshot/path_utils.h"
 #include "tensorflow/core/data/service/test_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/status_matchers.h"
diff --git a/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc b/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc
index 96c7c3e..5a6f820 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc
@@ -49,8 +49,9 @@
 constexpr char kRepetitionIndex[] = "repetition_index";
 
 absl::StatusOr<int64_t> GetRepetitionIndex(const std::string& split_file) {
-  tsl::StringPiece repetition_dir_path = tsl::io::Dirname(split_file);
-  tsl::StringPiece repetition_dir_name = tsl::io::Basename(repetition_dir_path);
+  absl::string_view repetition_dir_path = tsl::io::Dirname(split_file);
+  absl::string_view repetition_dir_name =
+      tsl::io::Basename(repetition_dir_path);
   return ParseRepetitionDirectoryName(repetition_dir_name);
 }
 }  // namespace
diff --git a/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc b/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc
index 2d6a9fd..b9b9f3d 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc
@@ -23,6 +23,7 @@
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/serialization_utils.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/dispatcher_client.h"
@@ -34,7 +35,6 @@
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/variant_tensor_data.h"
 #include "tensorflow/core/protobuf/snapshot.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/lib/io/compression.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc
index 85d401b..071c7e1 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc
@@ -21,6 +21,7 @@
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/service/byte_size.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/snapshot/path_utils.h"
@@ -28,7 +29,6 @@
 #include "tensorflow/core/data/service/snapshot/test_utils.h"
 #include "tensorflow/core/data/service/task_runner.h"
 #include "tensorflow/core/data/service/test_util.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/lib/io/compression.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/random.h"
diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc
index 84243d8..f918341 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc
@@ -25,6 +25,7 @@
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/service/byte_size.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/snapshot/file_utils.h"
@@ -35,7 +36,6 @@
 #include "tensorflow/core/data/snapshot_utils.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/framework/tensor.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/lib/io/compression.h"
 #include "tsl/lib/monitoring/cell_reader.h"
 #include "tsl/platform/env.h"
diff --git a/tensorflow/core/data/service/split_provider_test.cc b/tensorflow/core/data/service/split_provider_test.cc
index 08adc90..d311db2 100644
--- a/tensorflow/core/data/service/split_provider_test.cc
+++ b/tensorflow/core/data/service/split_provider_test.cc
@@ -21,10 +21,10 @@
 #include <vector>
 
 #include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/test_util.h"
 #include "tensorflow/core/framework/dataset.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
 
diff --git a/tensorflow/core/data/standalone_save_restore_test.cc b/tensorflow/core/data/standalone_save_restore_test.cc
index fd16369..9798021 100644
--- a/tensorflow/core/data/standalone_save_restore_test.cc
+++ b/tensorflow/core/data/standalone_save_restore_test.cc
@@ -17,11 +17,11 @@
 #include <utility>
 #include <vector>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/test_util.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/framework/tensor.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/status_matchers.h"
diff --git a/tensorflow/core/data/standalone_test.cc b/tensorflow/core/data/standalone_test.cc
index 54f438b..fac2a9e 100644
--- a/tensorflow/core/data/standalone_test.cc
+++ b/tensorflow/core/data/standalone_test.cc
@@ -19,10 +19,10 @@
 #include <optional>
 #include <vector>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace data {
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index 6c82115..fb01f41 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -182,11 +182,8 @@
         ":debug_grpc_testlib",
         ":debug_io_utils",
         ":debug_node_key",
-        ":debug_service_proto_cc",
         ":debugger_event_metadata_proto_cc",
-        "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
-        "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
@@ -194,7 +191,6 @@
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
-        "//tensorflow/core/platform/default/build_config:platformlib",
     ],
 )
 
@@ -260,7 +256,6 @@
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
-        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
     ],
diff --git a/tensorflow/core/debug/debug_grpc_io_utils_test.cc b/tensorflow/core/debug/debug_grpc_io_utils_test.cc
index 3eaf365..87aea15 100644
--- a/tensorflow/core/debug/debug_grpc_io_utils_test.cc
+++ b/tensorflow/core/debug/debug_grpc_io_utils_test.cc
@@ -13,6 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <memory>
+
 #include "tensorflow/core/debug/debug_graph_utils.h"
 #include "tensorflow/core/debug/debug_grpc_testlib.h"
 #include "tensorflow/core/debug/debug_io_utils.h"
@@ -47,10 +49,10 @@
                             int64_t server_start_delay_micros) {
     server_data->port = testing::PickUnusedPortOrDie();
     server_data->url = strings::StrCat("grpc://localhost:", server_data->port);
-    server_data->server.reset(new test::TestEventListenerImpl());
+    server_data->server = std::make_unique<test::TestEventListenerImpl>();
 
-    server_data->thread_pool.reset(
-        new thread::ThreadPool(Env::Default(), "test_server", 1));
+    server_data->thread_pool =
+        std::make_unique<thread::ThreadPool>(Env::Default(), "test_server", 1);
     server_data->thread_pool->Schedule(
         [server_data, server_start_delay_micros]() {
           Env::Default()->SleepForMicroseconds(server_start_delay_micros);
diff --git a/tensorflow/core/debug/debug_grpc_testlib.h b/tensorflow/core/debug/debug_grpc_testlib.h
index 18009a3..2a57df8 100644
--- a/tensorflow/core/debug/debug_grpc_testlib.h
+++ b/tensorflow/core/debug/debug_grpc_testlib.h
@@ -39,7 +39,7 @@
   ::grpc::Status SendEvents(
       ::grpc::ServerContext* context,
       ::grpc::ServerReaderWriter< ::tensorflow::EventReply,
-                                  ::tensorflow::Event>* stream);
+                                  ::tensorflow::Event>* stream) override;
 
   // Clear debug data (e.g., Tensors) received so far.
   void ClearReceivedDebugData();
diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc
index dad4360..74d5758 100644
--- a/tensorflow/core/debug/debug_io_utils_test.cc
+++ b/tensorflow/core/debug/debug_io_utils_test.cc
@@ -13,11 +13,12 @@
 limitations under the License.
 ==============================================================================*/
 
-#include <cstdlib>
-#include <unordered_set>
-
 #include "tensorflow/core/debug/debug_io_utils.h"
 
+#include <cstdlib>
+#include <memory>
+#include <unordered_set>
+
 #include "tensorflow/core/debug/debug_callback_registry.h"
 #include "tensorflow/core/debug/debug_node_key.h"
 #include "tensorflow/core/debug/debugger_event_metadata.pb.h"
@@ -40,7 +41,7 @@
   void Initialize() {
     env_ = Env::Default();
 
-    tensor_a_.reset(new Tensor(DT_FLOAT, TensorShape({2, 2})));
+    tensor_a_ = std::make_unique<Tensor>(DT_FLOAT, TensorShape({2, 2}));
     tensor_a_->flat<float>()(0) = 5.0;
     tensor_a_->flat<float>()(1) = 3.0;
     tensor_a_->flat<float>()(2) = -1.0;
diff --git a/tensorflow/core/debug/debugger_state_impl.h b/tensorflow/core/debug/debugger_state_impl.h
index 9de9bdc..4114d68 100644
--- a/tensorflow/core/debug/debugger_state_impl.h
+++ b/tensorflow/core/debug/debugger_state_impl.h
@@ -26,7 +26,7 @@
 class DebuggerState : public DebuggerStateInterface {
  public:
   DebuggerState(const DebugOptions& debug_options);
-  virtual ~DebuggerState();
+  ~DebuggerState() override;
 
   // Publish metadata about the debugged Session::Run() call.
   //
@@ -47,7 +47,7 @@
  public:
   DebugGraphDecorator(const DebugOptions& debug_options)
       : debug_options_(debug_options) {}
-  virtual ~DebugGraphDecorator() {}
+  ~DebugGraphDecorator() override {}
 
   Status DecorateGraph(Graph* graph, Device* device) override;
   Status PublishGraph(const Graph& graph, const string& device_name) override;
diff --git a/tensorflow/core/distributed_runtime/integration_test/BUILD b/tensorflow/core/distributed_runtime/integration_test/BUILD
index 4927d6f..7408bcb 100644
--- a/tensorflow/core/distributed_runtime/integration_test/BUILD
+++ b/tensorflow/core/distributed_runtime/integration_test/BUILD
@@ -52,7 +52,7 @@
         "//tensorflow/core/platform:blocking_counter",
         "//tensorflow/core/platform:env",
         "@com_google_absl//absl/time",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -168,6 +168,6 @@
         "//tensorflow/core/distributed_runtime:server_lib",
         "//tensorflow/core/platform:env",
         "@com_google_absl//absl/time",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc
index 4e0bd6f..356f0a0 100644
--- a/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc
+++ b/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_test_util.h"
 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/distributed_runtime/server_lib.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/platform/blocking_counter.h"
@@ -31,7 +32,6 @@
 #include "tensorflow/core/protobuf/cluster.pb.h"
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/protobuf/coordination_config.pb.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc
index ba48750..cffe93d 100644
--- a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc
+++ b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc
@@ -185,8 +185,7 @@
     if (worker_id == 0) {
       TFE_TensorHandle* in = TestMatrixTensorHandle(ctx);
       const std::string& op_name =
-          tensorflow::str_util::StrContains(send_device, "GPU") ? "Send"
-                                                                : "_HostSend";
+          absl::StrContains(send_device, "GPU") ? "Send" : "_HostSend";
       TFE_Op* sendop = SendOp(ctx, in, op_name, send_device, recv_device,
                               send_device_incarnation);
       TFE_TensorHandle* retvals[1];
@@ -197,8 +196,7 @@
       TFE_DeleteTensorHandle(in);
     } else {
       const std::string& op_name =
-          tensorflow::str_util::StrContains(send_device, "GPU") ? "Recv"
-                                                                : "_HostRecv";
+          absl::StrContains(send_device, "GPU") ? "Recv" : "_HostRecv";
       TFE_Op* recvop = RecvOp(ctx, op_name, send_device, recv_device,
                               send_device_incarnation);
       TFE_TensorHandle* retvals[1];
diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc
index dde7f65..3d9ff3c 100644
--- a/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc
+++ b/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc
@@ -24,6 +24,7 @@
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_test_util.h"
 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/distributed_runtime/server_lib.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/platform/strcat.h"
@@ -31,7 +32,6 @@
 #include "tensorflow/core/protobuf/cluster.pb.h"
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/protobuf/coordination_config.pb.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 7777bc6..2fcbc72 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -1868,14 +1868,17 @@
       DebuggerStateRegistry::CreateState(debug_options, debugger_state));
 
   std::vector<string> input_names;
+  input_names.reserve(req.num_feeds());
   for (size_t i = 0; i < req.num_feeds(); ++i) {
     input_names.push_back(req.feed_name(i));
   }
   std::vector<string> output_names;
+  output_names.reserve(req.num_fetches());
   for (size_t i = 0; i < req.num_fetches(); ++i) {
     output_names.push_back(req.fetch_name(i));
   }
   std::vector<string> target_names;
+  target_names.reserve(req.num_targets());
   for (size_t i = 0; i < req.num_targets(); ++i) {
     target_names.push_back(req.target_name(i));
   }
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index 56eb09e..c1026dc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -17,6 +17,7 @@
 
 #include <string>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
 #include "tensorflow/core/framework/graph.pb.h"
@@ -33,7 +34,6 @@
 #include "tensorflow/core/protobuf/error_codes.pb.h"
 #include "tensorflow/core/public/session.h"
 #include "tensorflow/core/util/port.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index a0b39c7..526085c 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -1948,7 +1948,7 @@
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/security/fuzzing/cc/core/framework:tensor_shape_domains",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -1962,7 +1962,7 @@
         "//tensorflow/security/fuzzing/cc/core/framework:datatype_domains",
         "//tensorflow/security/fuzzing/cc/core/framework:tensor_domains",
         "//tensorflow/security/fuzzing/cc/core/framework:tensor_shape_domains",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 9ebcd90..ca9e5a6 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -91,6 +91,14 @@
 
 // Maps the index of dataset elements to a globally shuffled index. See the
 // comment for IteratorContext::Params::index_mapper for more details.
+// Notes:
+// * `absl::OutOfRangeError` indicates the input index argument exceeds
+//   the cardinality of the dataset.
+// * `absl::NotFoundError` indicates this element should be skipped.
+//   This happens when multiple datasets are mixed into one, for example
+//   `dataset1.concatenate(dataset2)`.
+// See go/tf-data-random-access-iterator and
+// go/tf-data-random-access-iterator-for-concatenate for more info.
 using IndexMapperFn = std::function<absl::StatusOr<size_t>(size_t)>;
 
 constexpr char kTFDataFunction[] = "_tf_data_function";
@@ -905,6 +913,10 @@
 
   IndexMapperFn index_mapper() const { return params_.index_mapper; }
 
+  void set_restored_element_count(size_t element_count) {
+    params_.restored_element_count.emplace(element_count);
+  }
+
   std::optional<int64_t> restored_element_count() const {
     return params_.restored_element_count;
   }
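
The `IndexMapperFn` comment above distinguishes two error outcomes. A minimal sketch of a mapper for a concatenated dataset follows, with hypothetical cardinalities and the actual shuffle computation elided; the function name and constants are illustrative only:

    #include <cstddef>

    #include "absl/status/status.h"
    #include "absl/status/statusor.h"

    // Hypothetical mapper over dataset1.concatenate(dataset2), where dataset1
    // has 10 elements and the concatenation has 15 in total. The signature
    // matches IndexMapperFn: absl::StatusOr<size_t>(size_t).
    absl::StatusOr<size_t> ExampleIndexMapper(size_t element_position) {
      constexpr size_t kFirstInputCardinality = 10;
      constexpr size_t kTotalCardinality = 15;
      if (element_position >= kTotalCardinality) {
        // Index exceeds the cardinality of the dataset.
        return absl::OutOfRangeError("element_position exceeds cardinality");
      }
      if (element_position >= kFirstInputCardinality) {
        // Element belongs to the other input; this iterator skips it.
        return absl::NotFoundError("element handled by the second input");
      }
      // A real mapper would return a globally shuffled position here.
      return element_position;
    }
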
diff --git a/tensorflow/core/framework/dataset_test.cc b/tensorflow/core/framework/dataset_test.cc
index a632551..66213ea 100644
--- a/tensorflow/core/framework/dataset_test.cc
+++ b/tensorflow/core/framework/dataset_test.cc
@@ -21,11 +21,11 @@
 #include <gtest/gtest.h>
 #include "absl/container/flat_hash_set.h"
 #include "absl/status/status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace data {
diff --git a/tensorflow/core/framework/partial_tensor_shape_test.cc b/tensorflow/core/framework/partial_tensor_shape_test.cc
index e20a585..581989c 100644
--- a/tensorflow/core/framework/partial_tensor_shape_test.cc
+++ b/tensorflow/core/framework/partial_tensor_shape_test.cc
@@ -17,13 +17,13 @@
 
 #include <limits>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/status_matchers.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index d743669..71d856e 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -1288,6 +1288,10 @@
 
 bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
+  CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
+  CHECK_LT(idx, output_handle_shapes_and_types_.size())
+      << "Got idx: " << idx << " but only "
+      << output_handle_shapes_and_types_.size() << " outputs.";
   if (output_handle_shapes_and_types_[idx] == nullptr) {
     output_handle_shapes_and_types_[idx].reset(
         new std::vector<ShapeAndType>(shapes_and_types));
@@ -1299,6 +1303,10 @@
 
 bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
+  CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
+  CHECK_LT(idx, input_handle_shapes_and_types_.size())
+      << "Got idx: " << idx << " but only "
+      << input_handle_shapes_and_types_.size() << " inputs.";
   if (input_handle_shapes_and_types_[idx] == nullptr) {
     input_handle_shapes_and_types_[idx].reset(
         new std::vector<ShapeAndType>(shapes_and_types));
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index f00dac8..6ed932e 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -683,6 +683,10 @@
 
   void set_input_handle_shapes_and_types(
       int idx, const std::vector<ShapeAndType>& shapes_and_types) {
+    CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
+    CHECK_LT(idx, input_handle_shapes_and_types_.size())
+        << "Got idx: " << idx << " but only "
+        << input_handle_shapes_and_types_.size() << " inputs.";
     input_handle_shapes_and_types_[idx] =
         absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
   }
@@ -690,17 +694,29 @@
   // Returns the output handle shapes and types, for the resource tensor output
   // at index <idx>. Returns NULL if the shape and types were never set.
   const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
+    CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
+    CHECK_LT(idx, output_handle_shapes_and_types_.size())
+        << "Got idx: " << idx << " but only "
+        << output_handle_shapes_and_types_.size() << " outputs.";
     return output_handle_shapes_and_types_[idx].get();
   }
 
   // Returns the inputs handle shapes and types, for the resource tensor input
   // at index <idx>. Returns NULL if the shape and types were not available.
   const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
+    CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
+    CHECK_LT(idx, input_handle_shapes_and_types_.size())
+        << "Got idx: " << idx << " but only "
+        << input_handle_shapes_and_types_.size() << " inputs.";
     return input_handle_shapes_and_types_[idx].get();
   }
 
   void set_output_handle_shapes_and_types(
       int idx, const std::vector<ShapeAndType>& shapes_and_types) {
+    CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
+    CHECK_LT(idx, output_handle_shapes_and_types_.size())
+        << "Got idx: " << idx << " but only "
+        << output_handle_shapes_and_types_.size() << " outputs.";
     output_handle_shapes_and_types_[idx] =
         absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
   }
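
The accessors above each repeat the same pair of CHECKs. A hypothetical helper capturing the pattern is sketched below; it is not part of this change and assumes the glog-style CHECK macros from tensorflow/core/platform/logging.h:

    #include <vector>

    #include "tensorflow/core/platform/logging.h"

    // Hypothetical helper mirroring the checks added above: validate idx before
    // indexing into a per-input/per-output vector of handle shapes and types.
    template <typename T>
    void CheckHandleIndex(int idx, const std::vector<T>& handle_data,
                          const char* what) {
      CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << ".";
      CHECK_LT(idx, handle_data.size())
          << "Got idx: " << idx << " but only " << handle_data.size() << " "
          << what << ".";
    }
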
diff --git a/tensorflow/core/framework/tensor_fuzz.cc b/tensorflow/core/framework/tensor_fuzz.cc
index 5665185..49f91b0 100644
--- a/tensorflow/core/framework/tensor_fuzz.cc
+++ b/tensorflow/core/framework/tensor_fuzz.cc
@@ -13,13 +13,13 @@
 limitations under the License.
 ==============================================================================*/
 #include "fuzztest/fuzztest.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/security/fuzzing/cc/core/framework/datatype_domains.h"
 #include "tensorflow/security/fuzzing/cc/core/framework/tensor_domains.h"
 #include "tensorflow/security/fuzzing/cc/core/framework/tensor_shape_domains.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow::fuzzing {
 namespace {
diff --git a/tensorflow/core/framework/tensor_shape_fuzz.cc b/tensorflow/core/framework/tensor_shape_fuzz.cc
index 7a0351a..d14284e 100644
--- a/tensorflow/core/framework/tensor_shape_fuzz.cc
+++ b/tensorflow/core/framework/tensor_shape_fuzz.cc
@@ -18,9 +18,9 @@
 #include <vector>
 
 #include "fuzztest/fuzztest.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/security/fuzzing/cc/core/framework/tensor_shape_domains.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace fuzzing {
diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc
index e55cefa..c13a16f 100644
--- a/tensorflow/core/framework/tensor_shape_test.cc
+++ b/tensorflow/core/framework/tensor_shape_test.cc
@@ -18,6 +18,7 @@
 #include <cstdint>
 #include <limits>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/random/simple_philox.h"
@@ -28,7 +29,6 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 class TensorShapeTestHelper {
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index e0981fe..2bf4de1 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -444,7 +444,7 @@
 }
 
 bool IsQueue(const NodeDef& node) {
-  return str_util::EndsWith(node.op(), "QueueV2");
+  return absl::EndsWith(node.op(), "QueueV2");
 }
 
 bool IsRandomShuffle(const NodeDef& node) {
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index bfed969..e23e1c0 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -1022,7 +1022,7 @@
         "//tensorflow/core/platform:status",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
index f9d4f06..a212e25 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
@@ -178,7 +178,8 @@
 NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name,
                               StringPiece num_parallel_calls_node_name,
                               StringPiece function_name,
-                              StringPiece deterministic) {
+                              StringPiece deterministic,
+                              bool use_unbounded_threadpool) {
   return test::function::NDef(
       name, "ParallelMapDatasetV2",
       {string(input_node_name), string(num_parallel_calls_node_name)},
@@ -188,6 +189,7 @@
           {"output_shapes", absl::Span<const TensorShape>{}},
           {"output_types", absl::Span<const DataType>{}},
           {"deterministic", string(deterministic)},
+          {"use_unbounded_threadpool", use_unbounded_threadpool},
       });
 }
 
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
index 7341329..c5823d1 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
@@ -89,7 +89,8 @@
 NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name,
                               StringPiece num_parallel_calls_node_name,
                               StringPiece function_name,
-                              StringPiece deterministic);
+                              StringPiece deterministic,
+                              bool use_unbounded_threadpool);
 
 // Creates a test NodeDef for ParseExampleDataset.
 NodeDef MakeParseExampleNode(StringPiece name, StringPiece input_node_name,
diff --git a/tensorflow/core/grappler/optimizers/data/inject_io_prefetch_test.cc b/tensorflow/core/grappler/optimizers/data/inject_io_prefetch_test.cc
index 54bf5fe..5cb93fa 100644
--- a/tensorflow/core/grappler/optimizers/data/inject_io_prefetch_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/inject_io_prefetch_test.cc
@@ -18,6 +18,7 @@
 #include <string>
 
 #include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/function_testlib.h"
@@ -27,7 +28,6 @@
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -101,7 +101,7 @@
             {{"value", 1}, {"dtype", DT_INT32}}),
        graph_tests_utils::MakeParallelMapV2Node(
            "map_1", "io_1", "num_parallel_calls_1", "noop_1",
-           /*deterministic=*/"default"),
+           /*deterministic=*/"default", /*use_unbounded_threadpool=*/false),
 
        NDef("files_2", "Const", {},
             {{"value", "file1file2"}, {"dtype", DT_STRING}}),
@@ -114,7 +114,7 @@
             {{"value", 1}, {"dtype", DT_INT32}}),
        graph_tests_utils::MakeParallelMapV2Node(
            "map_2", "io_2", "num_parallel_calls_2", "noop_2",
-           /*deterministic=*/"default"),
+           /*deterministic=*/"default", /*use_unbounded_threadpool=*/false),
 
        NDef("zip", "ZipDataset", {"map_1", "map_2"}, {}),
        NDef("Sink", "Identity", {"zip"}, {})},
diff --git a/tensorflow/core/grappler/optimizers/data/make_deterministic_test.cc b/tensorflow/core/grappler/optimizers/data/make_deterministic_test.cc
index 1b76fee..1ff66f3 100644
--- a/tensorflow/core/grappler/optimizers/data/make_deterministic_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/make_deterministic_test.cc
@@ -85,7 +85,7 @@
   } else {
     orig_map_node_def = graph_tests_utils::MakeParallelMapV2Node(
         "map", "range", "num_parallel_calls", "MyFunction",
-        deterministic ? "true" : "false");
+        deterministic ? "true" : "false", /*use_unbounded_threadpool=*/false);
   }
   orig_map_node_def.add_input("^start");
   AttrValue* attr_val = &(*orig_map_node_def.mutable_attr())["Targuments"];
@@ -321,7 +321,8 @@
             {{"value", 1}, {"dtype", DT_INT32}}),
        graph_tests_utils::MakeParallelMapV2Node(
            "map", "range", "num_parallel_calls", func_name,
-           deterministic ? "true" : "false")},
+           deterministic ? "true" : "false",
+           /*use_unbounded_threadpool=*/false)},
       // FunctionLib
       {test::function::XTimesTwo(), OuterXTimesTwo()});
 
@@ -387,7 +388,8 @@
             {{"value", Tensor(int64_t{1})}, {"dtype", DT_INT64}}),
        graph_tests_utils::MakeParallelMapV2Node(
            "map", "range", "num_parallel_calls", func_name,
-           deterministic ? "true" : "false"),
+           deterministic ? "true" : "false",
+           /*use_unbounded_threadpool=*/false),
        graph_tests_utils::MakePrefetchNode("prefetch", "map", "buffer_size")},
       // FunctionLib
       {test::function::RandomUniform(), OuterRandomUniform()});
@@ -485,7 +487,7 @@
 
   NodeDef map_node_def = graph_tests_utils::MakeParallelMapV2Node(
       "map", "range", "num_parallel_calls", func_name,
-      deterministic ? "true" : "false");
+      deterministic ? "true" : "false", /*use_unbounded_threadpool=*/false);
   map_node_def.add_input("^start");
 
   // Rewrite occurs due to parallelism in map function
@@ -587,7 +589,8 @@
             {{"value", Tensor(int64_t{1})}, {"dtype", DT_INT64}}),
        graph_tests_utils::MakeParallelMapV2Node(
            "map", "range", "num_parallel_calls", func_name,
-           deterministic ? "true" : "false"),
+           deterministic ? "true" : "false",
+           /*use_unbounded_threadpool=*/false),
        graph_tests_utils::MakePrefetchNode("prefetch", "map", "buffer_size")},
       // FunctionLib
       {test::function::ReadResourceVariable(), OuterReadResourceVariable()});
diff --git a/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc b/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc
index 207a9cd..bf15420 100644
--- a/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc
@@ -21,7 +21,6 @@
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -147,7 +146,7 @@
             {{"value", 1}, {"dtype", DT_INT32}}),
        graph_tests_utils::MakeParallelMapV2Node(
            "map", "range", "num_parallel_calls", "XTimesTwo",
-           /*deterministic=*/"default")},
+           /*deterministic=*/"default", /*use_unbounded_threadpool=*/false)},
       // FunctionLib
       {
           test::function::XTimesTwo(),
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 69943e8..091e94d 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -124,6 +124,11 @@
     if (node2->op() != "MapDataset" && !IsParallelMap(*node2)) {
       continue;
     }
+    // Do not fuse a ParallelMap node that uses the unbounded thread pool.
+    if (node2->attr().find("use_unbounded_threadpool") != node2->attr().end() &&
+        node2->attr().at("use_unbounded_threadpool").b()) {
+      continue;
+    }
     // Use a more descriptive variable name now that we know the node type.
     NodeDef* map_node = node2;
 
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
index 74947cb..077123e 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
@@ -402,6 +402,71 @@
   EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output));
 }
 
+TEST(MapAndBatchFusionTest, NoChange_UnboundedThreadpoolParallelMap) {
+  GrapplerItem item;
+  MutableGraphView graph(&item.graph);
+  NodeDef *start_node = graph_utils::AddScalarConstNode<int64_t>(0, &graph);
+  NodeDef *stop_node = graph_utils::AddScalarConstNode<int64_t>(10, &graph);
+  NodeDef *step_node = graph_utils::AddScalarConstNode<int64_t>(1, &graph);
+
+  std::vector<string> range_inputs(3);
+  range_inputs[0] = start_node->name();
+  range_inputs[1] = stop_node->name();
+  range_inputs[2] = step_node->name();
+  std::vector<std::pair<string, AttrValue>> range_attrs;
+  NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
+                                             range_attrs, &graph);
+  NodeDef *captured_input_node =
+      graph_utils::AddScalarConstNode<StringPiece>("hello", &graph);
+  NodeDef *num_parallel_calls_node =
+      graph_utils::AddScalarConstNode<int>(2, &graph);
+
+  NodeDef *map_node;
+  {
+    std::vector<string> map_inputs(3);
+    map_inputs[0] = range_node->name();
+    map_inputs[1] = captured_input_node->name();
+    map_inputs[2] = num_parallel_calls_node->name();
+    std::vector<std::pair<string, AttrValue>> map_attrs(3);
+    AttrValue f_attr;
+    SetAttrValue("f", &f_attr);
+    map_attrs[0] = std::make_pair("f", f_attr);
+    AttrValue args_attr;
+    SetAttrValue("Targuments", &args_attr);
+    map_attrs[1] = std::make_pair("Targuments", args_attr);
+    AttrValue use_unbounded_threadpool_attr;
+    SetAttrValue(true, &use_unbounded_threadpool_attr);
+    map_attrs[2] = std::make_pair("use_unbounded_threadpool",
+                                  use_unbounded_threadpool_attr);
+    map_node = graph_utils::AddNode("", "ParallelMapDataset", map_inputs,
+                                    map_attrs, &graph);
+  }
+
+  NodeDef *batch_size_node =
+      graph_utils::AddScalarConstNode<int64_t>(5, &graph);
+  NodeDef *batch_node;
+  {
+    std::vector<string> batch_inputs(2);
+    batch_inputs[0] = map_node->name();
+    batch_inputs[1] = batch_size_node->name();
+    std::vector<std::pair<string, AttrValue>> batch_attrs(2);
+    AttrValue shapes_attr;
+    SetAttrValue("output_shapes", &shapes_attr);
+    batch_attrs[0] = std::make_pair("output_shapes", shapes_attr);
+    AttrValue types_attr;
+    SetAttrValue("output_types", &types_attr);
+    batch_attrs[1] = std::make_pair("output_types", types_attr);
+    batch_node = graph_utils::AddNode("", "BatchDataset", batch_inputs,
+                                      batch_attrs, &graph);
+  }
+
+  MapAndBatchFusion optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output));
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
index d2bf6a3..78e9eba 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
@@ -216,10 +216,22 @@
   for (const NodeDef& node : sorted_old_graph.node()) {
     const NodeDef* map_node = get_map_node(node);
     if (!map_node) continue;
+    // Do not fuse a ParallelMap node that uses the unbounded thread pool.
+    if (map_node->attr().find("use_unbounded_threadpool") !=
+            map_node->attr().end() &&
+        map_node->attr().at("use_unbounded_threadpool").b()) {
+      continue;
+    }
 
     const NodeDef* parent_map_node =
         get_map_node(*graph_utils::GetInputNode(*map_node, graph));
     if (!parent_map_node) continue;
+    // Do not fuse a parent ParallelMap node that uses the unbounded thread pool.
+    if (parent_map_node->attr().find("use_unbounded_threadpool") !=
+            parent_map_node->attr().end() &&
+        parent_map_node->attr().at("use_unbounded_threadpool").b()) {
+      continue;
+    }
 
     // TODO(b/148614504): Support fusing different types of map operations.
     if (parent_map_node->op() != map_node->op()) continue;
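
Both fusion passes above skip nodes whose optional `use_unbounded_threadpool` attr is present and set to true. A small standalone sketch of that lookup follows; the helper name is hypothetical, and the passes inline the check instead:

    #include "tensorflow/core/framework/attr_value.pb.h"
    #include "tensorflow/core/framework/node_def.pb.h"

    // True iff the node carries the optional use_unbounded_threadpool attr and
    // it is set to true. Older graphs simply lack the entry, so find() guards
    // the at() lookup, exactly as the passes above do.
    bool UsesUnboundedThreadpool(const tensorflow::NodeDef& node) {
      const auto it = node.attr().find("use_unbounded_threadpool");
      return it != node.attr().end() && it->second.b();
    }
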
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
index c81191e..a773d9b 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
@@ -20,6 +20,7 @@
 
 #include <gtest/gtest.h>
 #include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/grappler/grappler_item.h"
@@ -28,7 +29,6 @@
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 
 namespace tensorflow {
@@ -88,9 +88,11 @@
        NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
        num_parallel_calls_node,
        MakeParallelMapV2Node("map1", "range", num_parallel_calls_node.name(),
-                             "XTimesTwo", "default"),
+                             "XTimesTwo", "default",
+                             /*use_unbounded_threadpool=*/false),
        MakeParallelMapV2Node("map2", "map1", num_parallel_calls_node.name(),
-                             "XTimesTwo", "default")},
+                             "XTimesTwo", "default",
+                             /*use_unbounded_threadpool=*/false)},
       // FunctionLib
       {
           test::function::XTimesTwo(),
@@ -171,9 +173,11 @@
        NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
        num_parallel_calls_node,
        MakeParallelMapV2Node("map1", "range", num_parallel_calls_node.name(),
-                             "XTimesTwo", "default"),
+                             "XTimesTwo", "default",
+                             /*use_unbounded_threadpool=*/false),
        MakeParallelMapV2Node("map2", "map1", num_parallel_calls_node.name(),
-                             "XTimesTwo", "default")},
+                             "XTimesTwo", "default",
+                             /*use_unbounded_threadpool=*/false)},
       // FunctionLib
       {
           test::function::XTimesTwo(),
@@ -187,6 +191,36 @@
   EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
 }
 
+TEST(MapFusionTest, NoChange_UnboundedThreadpoolParallelMap) {
+  using test::function::NDef;
+  GrapplerItem item;
+  NodeDef num_parallel_calls_node = CreateScalarConstNodeHelper(
+      "num_parallel_calls", DT_INT64,
+      [](TensorProto* proto) { proto->add_int64_val(-1); });
+  item.graph = test::function::GDef(
+      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+       num_parallel_calls_node,
+       MakeParallelMapV2Node("map1", "range", num_parallel_calls_node.name(),
+                             "XTimesTwo", "default",
+                             /*use_unbounded_threadpool=*/true),
+       MakeParallelMapV2Node("map2", "map1", num_parallel_calls_node.name(),
+                             "XTimesTwo", "default",
+                             /*use_unbounded_threadpool=*/false)},
+      // FunctionLib
+      {
+          test::function::XTimesTwo(),
+      });
+
+  MapFusion optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(OptimizeWithMapFusion(item, &output, true));
+  EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output));
+  EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
 TEST(MapFusionTest, FusedNodesAndFunctionsAreNamedAfterOldNodesAndFunctions) {
   using test::function::NDef;
   NodeDef num_parallel_calls_node = CreateScalarConstNodeHelper(
@@ -209,10 +243,11 @@
          num_parallel_calls_node,
          MakeParallelMapV2Node(parent_map_node_name, "range",
                                num_parallel_calls_node.name(),
-                               parent_function_name, "default"),
+                               parent_function_name, "default",
+                               /*use_unbounded_threadpool=*/false),
          MakeParallelMapV2Node(map_node_name, parent_map_node_name,
                                num_parallel_calls_node.name(), function_name,
-                               "default")},
+                               "default", /*use_unbounded_threadpool=*/false)},
         // FunctionLib
         {parent_fn, fn});
   };
diff --git a/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc b/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc
index 25d86a1..2060b0e 100644
--- a/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc
@@ -17,6 +17,7 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/graph.pb.h"
@@ -25,7 +26,6 @@
 #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/platform/status_matchers.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/protobuf/error_codes.pb.h"
 
 namespace tensorflow {
@@ -74,7 +74,8 @@
            /*input_node_name=*/"RangeDataset/_3",
            /*num_parallel_calls_node_name=*/"Const/_4",
            /*function_name=*/"__inference_Dataset_map_lambda_10",
-           /*deterministic=*/"default"),
+           /*deterministic=*/"default",
+           /*use_unbounded_threadpool=*/false),
 
        NDef("dataset",                    // name
             "_Retval",                    // op
diff --git a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc
index 076357a..eba3fce 100644
--- a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc
@@ -19,6 +19,7 @@
 
 #include <gtest/gtest.h>
 #include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/dataset_utils.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function.pb.h"
@@ -32,7 +33,6 @@
 #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/platform/status.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace grappler {
diff --git a/tensorflow/core/grappler/utils/pattern_utils_test.cc b/tensorflow/core/grappler/utils/pattern_utils_test.cc
index 6b8f689..22fe41b 100644
--- a/tensorflow/core/grappler/utils/pattern_utils_test.cc
+++ b/tensorflow/core/grappler/utils/pattern_utils_test.cc
@@ -184,7 +184,7 @@
   bool all_indices_matched = true;
   for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.begin();
        it++) {
-    auto label = str_util::StripPrefix(it->first, "my_");
+    auto label = absl::StripPrefix(it->first, "my_");
     int matched_node_idx = it->second;
     int expected_node_idx = graph_view.GetNode(label)->node_index();
     if (matched_node_idx != expected_node_idx) {
@@ -268,7 +268,7 @@
   bool all_indices_matched = true;
   for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.begin();
        it++) {
-    auto label = str_util::StripPrefix(it->first, "my_");
+    auto label = absl::StripPrefix(it->first, "my_");
     int matched_node_idx = it->second;
     int expected_node_idx = graph_view.GetNode(label)->node_index();
     if (matched_node_idx != expected_node_idx) {
@@ -387,7 +387,7 @@
   bool all_indices_matched = true;
   for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.begin();
        it++) {
-    auto label = str_util::StripPrefix(it->first, "my_");
+    auto label = absl::StripPrefix(it->first, "my_");
     int matched_node_idx = it->second;
     int expected_node_idx = graph_view.GetNode(label)->node_index();
     if (matched_node_idx != expected_node_idx) {
@@ -561,7 +561,7 @@
       bool all_indices_matched = true;
       for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.begin();
            it++) {
-        auto label = str_util::StripPrefix(it->first, "my_");
+        auto label = absl::StripPrefix(it->first, "my_");
         int matched_node_idx = it->second;
         int expected_node_idx = graph_view.GetNode(label)->node_index();
         if (matched_node_idx != expected_node_idx) {
diff --git a/tensorflow/core/ir/types/dialect.cc b/tensorflow/core/ir/types/dialect.cc
index 481c9f0..9805e17 100644
--- a/tensorflow/core/ir/types/dialect.cc
+++ b/tensorflow/core/ir/types/dialect.cc
@@ -32,6 +32,7 @@
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
@@ -366,17 +367,10 @@
   os << "<";
   if (hasRank()) {
     auto print_dim = [&](int64_t dim) {
-      if (dim != ShapedType::kDynamic) {
-        if (dim == 0) {
-          // In order to avoid the parseInteger below from confusing a dimension
-          // list with '0x' as hex integer, we use 00 for a 0 sized dimension.
-          os << "00";
-        } else {
-          os << dim;
-        }
-      } else {
+      if (dim != ShapedType::kDynamic)
+        os << dim;
+      else
         os << "?";
-      }
     };
     llvm::interleave(getShape(), os, print_dim, "x");
   } else {
@@ -405,7 +399,7 @@
       llvm::SMLoc loc = parser.getCurrentLocation();
       if (succeeded(parser.parseOptionalQuestion())) {
         shape.back() = ShapedType::kDynamic;
-      } else if (failed(parser.parseInteger(shape.back()))) {
+      } else if (failed(parser.parseDecimalInteger(shape.back()))) {
         parser.emitError(loc)
             << "expected an integer or `?` when parsing a tf.shape attribute";
         return failure();
diff --git a/tensorflow/core/ir/types/dialect_test.cc b/tensorflow/core/ir/types/dialect_test.cc
index 84a301a..4fb014d 100644
--- a/tensorflow/core/ir/types/dialect_test.cc
+++ b/tensorflow/core/ir/types/dialect_test.cc
@@ -62,7 +62,7 @@
 TEST(TFTypesDialect, ParsesDimensionListWithZero) {
   // Test that a dimension list with zero can be parsed.
   const char *const code = R"mlir(
-  "test.op"() {shape = #tf_type.shape<00x128>} : () -> ()
+  "test.op"() {shape = #tf_type.shape<0x128>} : () -> ()
 )mlir";
 
   MLIRContext context;
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 9dd9d5e..213c83a 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -701,6 +701,7 @@
         "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler",
         "//tensorflow/core/kernels/batching_util:batch_resource_base",
         "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs",
+        "//tensorflow/core/kernels/batching_util:batch_scheduler_utils",
         "//tensorflow/core/kernels/batching_util:bounded_executor",
         "//tensorflow/core/kernels/batching_util:concat_split_util",
         "//tensorflow/core/kernels/batching_util:periodic_function_dynamic",
@@ -1647,7 +1648,7 @@
 
 tf_cc_test(
     name = "batch_kernels_test",
-    size = "medium",
+    size = "small",
     srcs = ["batch_kernels_test.cc"],
     features = ["-layering_check"],
     deps = [
@@ -4672,7 +4673,7 @@
         "spacetobatch_functor.h",
         "spacetobatch_functor_gpu.cu.cc",
     ],
-    visibility = [":friends"],
+    visibility = ["//visibility:private"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -4715,7 +4716,7 @@
         "spacetodepth_op.h",
         "spacetodepth_op_gpu.cu.cc",
     ],
-    visibility = [":friends"],
+    visibility = ["//visibility:private"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index 8e1e97d..bd93c1e 100644
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -38,6 +38,7 @@
 #include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
 #include "tensorflow/core/kernels/batching_util/bounded_executor.h"
 #include "tensorflow/core/kernels/batching_util/concat_split_util.h"
 #include "tensorflow/core/kernels/batching_util/periodic_function.h"
@@ -177,7 +178,8 @@
                   /*mixed_priority_batching_policy=*/
                   serving::MixedPriorityBatchingPolicy::
                       kLowPriorityPaddingWithMaxBatchSize,
-                  enable_large_batch_splitting, resource);
+                  enable_large_batch_splitting,
+                  /*batch_padding_policy=*/"PAD_UP", resource);
   }
 
   static Status Create(
@@ -190,7 +192,7 @@
       int32_t low_priority_max_enqueued_batches,
       const std::vector<int32>& low_priority_allowed_batch_sizes,
       serving::MixedPriorityBatchingPolicy mixed_priority_batching_policy,
-      bool enable_large_batch_splitting,
+      bool enable_large_batch_splitting, absl::string_view batch_padding_policy,
       std::unique_ptr<BatchResource>* resource) {
     BatcherT::Options batcher_options;
     batcher_options.num_batch_threads = num_batch_threads;
@@ -203,8 +205,8 @@
             num_batch_threads, max_execution_batch_size, batch_timeout_micros,
             max_enqueued_batches, allowed_batch_sizes,
             enable_large_batch_splitting,
-            /*disable_padding=*/false, low_priority_max_batch_size,
-            low_priority_batch_timeout_micros,
+            /*disable_padding=*/false, batch_padding_policy,
+            low_priority_max_batch_size, low_priority_batch_timeout_micros,
             low_priority_max_enqueued_batches, low_priority_allowed_batch_sizes,
             mixed_priority_batching_policy),
         allowed_batch_sizes));
@@ -439,7 +441,7 @@
           low_priority_batch_timeout_micros_,
           low_priority_max_enqueued_batches_, low_priority_allowed_batch_sizes_,
           mixed_priority_batching_policy, enable_large_batch_splitting_,
-          &new_resource));
+          batch_padding_policy_, &new_resource));
       if (session_metadata) {
         new_resource->set_session_metadata(*session_metadata);
       }
diff --git a/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc b/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc
index e3601cf..7e66f9b 100644
--- a/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc
+++ b/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc
@@ -13,8 +13,6 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/batch_kernels.h"
-
 #include <cstdint>
 #include <memory>
 #include <utility>
@@ -22,6 +20,7 @@
 
 #include <gtest/gtest.h>
 #include "absl/strings/match.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/framework/device_factory.h"
 #include "tensorflow/core/framework/function.h"
@@ -30,6 +29,7 @@
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/kernels/batch_kernel_test_util.h"
+#include "tensorflow/core/kernels/batch_kernels.h"
 #include "tensorflow/core/kernels/batching_util/warmup.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/platform/env.h"
@@ -37,7 +37,6 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/public/version.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/blocking_counter.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/refcount.h"
diff --git a/tensorflow/core/kernels/batch_kernels_env_test.cc b/tensorflow/core/kernels/batch_kernels_env_test.cc
index 508c0e8..5a8bfec 100644
--- a/tensorflow/core/kernels/batch_kernels_env_test.cc
+++ b/tensorflow/core/kernels/batch_kernels_env_test.cc
@@ -14,11 +14,11 @@
 ==============================================================================*/
 
 #include <gmock/gmock.h>
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/kernels/batch_kernel_test_util.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/status_matchers.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/core/kernels/batch_kernels_test.cc b/tensorflow/core/kernels/batch_kernels_test.cc
index 62666c0..9aaeb5a 100644
--- a/tensorflow/core/kernels/batch_kernels_test.cc
+++ b/tensorflow/core/kernels/batch_kernels_test.cc
@@ -17,12 +17,15 @@
 
 #include <cstdint>
 #include <memory>
+#include <string>
 #include <utility>
 #include <vector>
 
 #include <gtest/gtest.h>
 #include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/framework/device_factory.h"
 #include "tensorflow/core/framework/function.h"
@@ -39,12 +42,13 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/public/version.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/blocking_counter.h"
 #include "tsl/platform/criticality.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/refcount.h"
 #include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/protobuf/error_codes.pb.h"
 
 namespace tensorflow {
 namespace {
@@ -84,19 +88,13 @@
           return absl::OkStatus();
         }});
   }
-};
 
-class BatchFunctionTestState : public SharedBatchFunctionTestState {
- public:
-  // Init test fixture with a batch kernel instance. The caller guarantees that
-  // the device pointer is valid throughout the life of this class.
-  absl::Status Init(Device *device, bool enable_low_priority_queue,
-                    absl::string_view mixed_priority_policy,
-                    int64_t expected_batch_size) {
-    // Override the per-test/per-op device with a given device so that it can
-    // be shared between ops.
-    device_ = device;
-
+ protected:
+  // Create common batch function op for testing.
+  absl::StatusOr<NodeDefBuilder> CreateBatchFunctionBuilder(
+      const std::vector<int> &allowed_batch_sizes, int max_batch_size,
+      absl::string_view padding_policy,
+      const TensorShape &expected_output_shape) {
     NameAttrList f;
     f.set_name("ShapeEnforcingFunction");
     FunctionDef func = FunctionDefHelper::Create(
@@ -112,8 +110,7 @@
         {{{"o"},
           "EnsureShape",
           {"x"},
-          {{"T", DataType::DT_INT64},
-           {"shape", TensorShape({expected_batch_size, 2})}}}},
+          {{"T", DataType::DT_INT64}, {"shape", expected_output_shape}}}},
         // ret_def
         {{"o", "o:output"}});
     TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(func));
@@ -121,13 +118,40 @@
 
     std::vector<NodeDefBuilder::NodeOut> inputs(
         {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})});
-    TF_RETURN_IF_ERROR(NodeDefBuilder("BatchTPUInput", "BatchFunction")
-                           .Attr("max_batch_size", 8)
-                           .Attr("num_batch_threads", 8)
-                           .Attr("allowed_batch_sizes", {4, 8})
-                           .Attr("batch_timeout_micros", 1000000)
-                           .Attr("max_enqueued_batches", 10)
-                           .Attr("enable_large_batch_splitting", true)
+    return NodeDefBuilder(absl::StrCat("BatchTPUInput", padding_policy),
+                          "BatchFunction")
+        .Attr("max_batch_size", max_batch_size)
+        .Attr("num_batch_threads", 8)
+        .Attr("allowed_batch_sizes", allowed_batch_sizes)
+        .Attr("batch_timeout_micros", 1000000)
+        .Attr("max_enqueued_batches", 10)
+        .Attr("enable_large_batch_splitting", true)
+        .Attr("batch_padding_policy", padding_policy)
+        .Attr("Tin", {DataType::DT_INT64})
+        .Input(inputs)
+        .Attr("Tcaptured", std::vector<DataType>{})
+        .Input(std::vector<NodeDefBuilder::NodeOut>{})
+        .Attr("Tout", std::vector<DataType>{DT_INT64})
+        .Attr("f", f);
+  }
+};
+
+class BatchFunctionTestState : public SharedBatchFunctionTestState {
+ public:
+  // Init test fixture with a batch kernel instance. The caller guarantees that
+  // the device pointer is valid throughout the life of this class.
+  absl::Status Init(Device *device, bool enable_low_priority_queue,
+                    absl::string_view mixed_priority_policy,
+                    int64_t expected_batch_size) {
+    // Override the per-test/per-op device with a given device so that it can
+    // be shared between ops.
+    device_ = device;
+
+    const TensorShape expected_output_shape({expected_batch_size, 2});
+    TF_ASSIGN_OR_RETURN(
+        NodeDefBuilder builder,
+        CreateBatchFunctionBuilder({4, 8}, 8, "PAD_UP", expected_output_shape));
+    TF_RETURN_IF_ERROR(builder
                            .Attr("low_priority_max_batch_size",
                                  enable_low_priority_queue ? 8 : 0)
                            .Attr("low_priority_batch_timeout_micros",
@@ -139,14 +163,8 @@
                            .Attr("low_priority_max_enqueued_batches",
                                  enable_low_priority_queue ? 2 : 0)
                            .Attr("mixed_priority_policy", mixed_priority_policy)
-                           .Attr("batch_padding_policy", "PAD_UP")
-                           .Attr("Tin", {DataType::DT_INT64})
-                           .Input(inputs)
-                           .Attr("Tcaptured", std::vector<DataType>{})
-                           .Input(std::vector<NodeDefBuilder::NodeOut>{})
-                           .Attr("Tout", std::vector<DataType>{DT_INT64})
-                           .Attr("f", f)
                            .Finalize(node_def()));
+
     return OpsTestBase::InitOp();
   }
 
@@ -576,48 +594,13 @@
     // be shared between ops.
     device_ = cpu_device;
 
-    NameAttrList f;
-    f.set_name("BatchFunctionKernelParallelWarmupTestStateFunc");
-    FunctionDef func = FunctionDefHelper::Create(
-        // function_name
-        f.name(),
-        // in_def
-        {"x:int64"},
-        // out_def
-        {"o:int64"},
-        // attr_def
-        {},
-        // node_def
-        {{{"o"},
-          "EnsureShape",
-          {"x"},
-          {{"T", DataType::DT_INT64}, {"shape", TensorShape({2})}}}},
-        // ret_def
-        {{"o", "o:output"}});
-    TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(func));
-    SharedBatchFunctionTestState::CreateFunctionLibraryRuntime();
+    const TensorShape expected_output_shape({2});
+    TF_ASSIGN_OR_RETURN(
+        NodeDefBuilder builder,
+        CreateBatchFunctionBuilder({2, 4, 8}, enable_splitting ? 16 : 8,
+                                   "PAD_UP", expected_output_shape));
+    TF_RETURN_IF_ERROR(builder.Finalize(node_def()));
 
-    std::vector<NodeDefBuilder::NodeOut> inputs(
-        {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})});
-    TF_RETURN_IF_ERROR(NodeDefBuilder("BatchTPUInput", "BatchFunction")
-                           .Attr("max_batch_size", enable_splitting ? 16 : 8)
-                           .Attr("num_batch_threads", 8)
-                           .Attr("allowed_batch_sizes", {2, 4, 8})
-                           .Attr("batch_timeout_micros", 1000000)
-                           .Attr("max_enqueued_batches", 10)
-                           .Attr("enable_large_batch_splitting", true)
-                           .Attr("low_priority_max_batch_size", 64)
-                           .Attr("low_priority_batch_timeout_micros", 8000)
-                           .Attr("low_priority_allowed_batch_sizes", {32, 64})
-                           .Attr("low_priority_max_enqueued_batches", 1000)
-                           .Attr("batch_padding_policy", "PAD_UP")
-                           .Attr("Tin", {DataType::DT_INT64})
-                           .Input(inputs)
-                           .Attr("Tcaptured", std::vector<DataType>{})
-                           .Input(std::vector<NodeDefBuilder::NodeOut>{})
-                           .Attr("Tout", std::vector<DataType>{DT_INT64})
-                           .Attr("f", f)
-                           .Finalize(node_def()));
     return OpsTestBase::InitOp();
   }
 
@@ -688,5 +671,80 @@
                          BatchFunctionKernelParallelWarmupTest,
                          ::testing::Bool());
 
+class BatchFunctionKernelPaddingTestState
+    : public SharedBatchFunctionTestState {
+ public:
+  // Init test fixture with a batch kernel instance.
+  absl::Status Init(absl::string_view padding_policy, int expected_batch_size) {
+    static auto *const cpu_device = []() {
+      auto device =
+          DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
+      return device.release();
+    }();
+
+    // Override the per-test/per-op device with a global device so that it can
+    // be shared between ops.
+    device_ = cpu_device;
+
+    const TensorShape expected_output_shape({expected_batch_size, 2});
+    TF_RETURN_IF_ERROR(CreateBatchFunctionBuilder({4, 8}, 8, padding_policy,
+                                                  expected_output_shape)
+                           ->Finalize(node_def()));
+
+    return OpsTestBase::InitOp();
+  }
+
+  void TestBody() override {}
+};
+
+class BatchFunctionKernelPaddingTest
+    : public ::testing::TestWithParam<std::string> {};
+
+TEST_P(BatchFunctionKernelPaddingTest, PadUp) {
+  SessionMetadata session_metadata;
+  session_metadata.set_name("test_model");
+  session_metadata.set_version(123);
+
+  // Send 5 requests in parallel and check that the given batch padding
+  // policy behaves as expected: with allowed batch sizes {4, 8}, PAD_UP pads
+  // the 5 requests up to a batch of 8, while BATCH_DOWN trims them down to a
+  // batch of 4.
+  int64_t num_requests = 5;
+  int64_t expected_batch_size = 0;
+  std::string padding_policy = GetParam();
+  if (padding_policy == "PAD_UP") {
+    expected_batch_size = 8;
+  } else if (padding_policy == "BATCH_DOWN") {
+    expected_batch_size = 4;
+  } else if (padding_policy == "MINIMIZE_TPU_COST_PER_REQUEST") {
+    expected_batch_size = 8;
+  } else {
+    FAIL() << "Unsupported padding policy: " << padding_policy;
+  }
+
+  {
+    tsl::BlockingCounter blocking_counter(num_requests);
+    for (int i = 0; i < num_requests; ++i) {
+      Env::Default()->SchedClosure([&]() {
+        BatchFunctionKernelPaddingTestState test_state;
+        test_state.set_session_metadata(session_metadata);
+        TF_CHECK_OK(test_state.Init(padding_policy, expected_batch_size));
+        test_state.AddInputFromList<int64_t>(TensorShape({1, 2}), {123, 456});
+        TF_EXPECT_OK(test_state.RunOpKernel());
+
+        test::ExpectTensorEqual<int64_t>(
+            *test_state.GetOutput(0),
+            test::AsTensor<int64_t>({123, 456}, TensorShape({1, 2})));
+        blocking_counter.DecrementCount();
+      });
+    }
+
+    blocking_counter.Wait();
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(BatchFunctionKernelPaddingTestSuite,
+                         BatchFunctionKernelPaddingTest,
+                         ::testing::Values("PAD_UP", "BATCH_DOWN",
+                                           "MINIMIZE_TPU_COST_PER_REQUEST"));
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index 6cf3ef0..4173632 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -121,10 +121,10 @@
         "//tensorflow/core/lib/core:notification",
         "//tensorflow/core/lib/core:status",
         "//tensorflow/core/platform:thread_annotations",
-        "//tensorflow/core/profiler/lib:traceme",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:string_view",
         "@local_tsl//tsl/platform:criticality",
+        "@local_tsl//tsl/profiler/lib:traceme",
     ],
 )
 
@@ -134,12 +134,12 @@
     hdrs = ["batch_scheduler.h"],
     deps = [
         "//tensorflow/core:lib",
-        "//tensorflow/core/profiler/lib:traceme",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/strings:string_view",
         "@local_tsl//tsl/platform:criticality",
+        "@local_tsl//tsl/profiler/lib:traceme",
     ],
 )
 
@@ -148,8 +148,12 @@
     srcs = ["batch_scheduler_utils.cc"],
     hdrs = ["batch_scheduler_utils.h"],
     deps = [
+        ":batch_scheduler_hdrs",
+        ":batch_stats",
         "//tensorflow/core:portable_gif_internal",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/time",
     ],
 )
 
@@ -183,7 +187,10 @@
     name = "batch_scheduler_utils_test",
     srcs = ["batch_scheduler_utils_test.cc"],
     deps = [
+        ":batch_scheduler_hdrs",
         ":batch_scheduler_utils",
+        ":batch_stats",
+        "@com_google_absl//absl/time",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -195,6 +202,7 @@
         ":batch_input_task",
         ":batch_scheduler_hdrs",
         ":batch_scheduler_utils",
+        ":batch_stats",
         ":periodic_function_dynamic",
         "//tensorflow/core:framework_lite",
         "//tensorflow/core:lib",
@@ -209,7 +217,6 @@
         "//tensorflow/core/profiler/lib:context_types_hdrs",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/profiler/lib:traceme_encode",
-        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/time",
@@ -225,13 +232,13 @@
         ":batch_input_task",
         ":batch_scheduler",
         ":batch_scheduler_utils",
+        ":batch_stats",
         ":periodic_function_dynamic",
         "//tensorflow/core:lib",
         "//tensorflow/core/profiler/lib:connected_traceme",
         "//tensorflow/core/profiler/lib:context_types_hdrs",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/profiler/lib:traceme_encode",
-        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/time",
@@ -246,6 +253,7 @@
     srcs = ["shared_batch_scheduler_test.cc"],
     deps = [
         ":batch_scheduler",
+        ":batch_scheduler_utils",
         ":fake_clock_env",
         ":shared_batch_scheduler",
         "//tensorflow/core:lib",
@@ -481,18 +489,30 @@
     srcs = ["batch_resource_base_test.cc"],
     deps = [
         ":batch_resource_base",
+        ":batch_scheduler_hdrs",
+        ":batch_scheduler_utils",
         ":batch_stats",
+        ":shared_batch_scheduler",
+        "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
+        "//tensorflow/core:ops",
+        "//tensorflow/core:portable_gif_internal",
         "//tensorflow/core/common_runtime:cost_constants",
         "//tensorflow/core/common_runtime:cost_measurement",
         "//tensorflow/core/common_runtime:cost_measurement_registry",
         "//tensorflow/core/common_runtime:no_op_cost_measurement",
         "//tensorflow/core/common_runtime:request_cost",
         "//tensorflow/core/framework:types_proto_cc",
+        "//tensorflow/core/kernels:batch_kernels",
+        "//tensorflow/core/lib/monitoring:cell_reader",
+        "//tensorflow/core/platform:notification",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/time",
+        "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest_main",
         "@local_tsl//tsl/platform:criticality",
+        "@local_tsl//tsl/platform:status",
     ],
 )
 
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
index 630ff22..81cbe41 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
@@ -233,6 +233,16 @@
   cell->GetCell(model_name, op_name)->Set(max_batch_size);
 }
 
+void RecordBatchParamPaddingPolicy(const string& batch_padding_policy,
+                                   const string& model_name,
+                                   const string& op_name) {
+  static auto* cell = monitoring::Gauge<string, 2>::New(
+      "/tensorflow/serving/batching/configured_batch_padding_policy",
+      "The value of BatchFunction.batch_padding_policy attribute.",
+      "model_name", "op_name");
+  cell->GetCell(model_name, op_name)->Set(batch_padding_policy);
+}
+
 void RecordBatchParamMaxEnqueuedBatches(int64_t max_enqueued_batches,
                                         const string& model_name,
                                         const string& op_name) {
@@ -406,6 +416,9 @@
     RecordBatchParamMaxEnqueuedBatches(
         batcher_queue_options_.max_enqueued_batches, GetModelName(context),
         context->op_kernel().name());
+    RecordBatchParamPaddingPolicy(
+        this->batcher_queue_options_.batch_padding_policy,
+        GetModelName(context), context->op_kernel().name());
   } else if (adaptive_batcher_) {
     RecordBatchParamBatchTimeoutMicros(
         adaptive_batcher_queue_options_.batch_timeout_micros,
@@ -472,8 +485,10 @@
   }
 
   BatcherQueueT* batcher_queue;
-  TF_RETURN_IF_ERROR(
-      LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
+  TF_RETURN_IF_ERROR(LookupOrCreateBatcherQueue(
+      /* queue_name= */ batcher_queue_name,
+      /* model_name= */ GetModelName(context),
+      /* op_name= */ context->op_kernel().name(), /* queue= */ &batcher_queue));
 
   if (!session_metadata().name().empty()) {
     absl::MutexLock lock(&outstanding_batch_mu_);
@@ -500,7 +515,9 @@
   return GetBatcherQueueOptions(
       num_batch_threads, max_batch_size, batch_timeout_micros,
       max_enqueued_batches, allowed_batch_sizes, enable_large_batch_splitting,
-      disable_padding, /*low_priority_max_batch_size=*/0,
+      disable_padding,
+      /*batch_padding_policy=*/kPadUpPolicy,
+      /*low_priority_max_batch_size=*/0,
       /*low_priority_batch_timeout_micros=*/0,
       /*low_priority_max_enqueued_batches=*/0,
       /*low_priority_allowed_batch_sizes=*/{},
@@ -514,7 +531,7 @@
     int32_t batch_timeout_micros, int32_t max_enqueued_batches,
     const std::vector<int32>& allowed_batch_sizes,
     bool enable_large_batch_splitting, bool disable_padding,
-    int32_t low_priority_max_batch_size,
+    absl::string_view batch_padding_policy, int32_t low_priority_max_batch_size,
     int32_t low_priority_batch_timeout_micros,
     int32_t low_priority_max_enqueued_batches,
     const std::vector<int32>& low_priority_allowed_batch_sizes,
@@ -523,6 +540,8 @@
   batcher_queue_options.input_batch_size_limit = max_batch_size;
   batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
   batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
+  batcher_queue_options.batch_padding_policy =
+      std::string(batch_padding_policy);
   if (low_priority_max_batch_size > 0) {
     batcher_queue_options.enable_priority_queue = true;
   }
@@ -1172,9 +1191,9 @@
   }
 }
 
-// Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
-// creates it.
 Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name,
+                                                     const string& model_name,
+                                                     const string& op_name,
                                                      BatcherQueueT** queue) {
   mutex_lock l(batcher_queues_mu_);
 
@@ -1186,8 +1205,12 @@
 
   std::unique_ptr<BatcherQueueT> new_queue;
   if (batcher_) {
+    BatcherT::QueueOptions batcher_queue_options = batcher_queue_options_;
+    batcher_queue_options.model_batch_stats = &GlobalBatchStatsRegistry().model(
+        /* model_name= */ model_name, /* op_name= */ op_name);
+
     TF_RETURN_IF_ERROR(batcher_->AddQueue(
-        batcher_queue_options_,
+        batcher_queue_options,
         absl::bind_front(&BatchResourceBase::ProcessBatchCallBack, this),
         &new_queue));
   } else if (adaptive_batcher_) {
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h
index e8b3926..c50b29f 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.h
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h
@@ -25,6 +25,7 @@
 #include <vector>
 
 #include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "tensorflow/core/common_runtime/cost_measurement_registry.h"
 #include "tensorflow/core/common_runtime/request_cost.h"
@@ -34,6 +35,7 @@
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
 #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/threadsafe_status.h"
 #include "tensorflow/core/platform/context.h"
@@ -52,6 +54,7 @@
   int32_t batch_timeout_micros;
   int32_t max_enqueued_batches;
   std::vector<int32_t> allowed_batch_sizes;
+  std::string batch_padding_policy{kPadUpPolicy};
   int32_t low_priority_max_batch_size;
   int32_t low_priority_batch_timeout_micros;
   int32_t low_priority_max_enqueued_batches;
@@ -213,6 +216,7 @@
       int32_t batch_timeout_micros, int32_t max_enqueued_batches,
       const std::vector<int32>& allowed_batch_sizes,
       bool enable_large_batch_splitting, bool disable_padding,
+      absl::string_view batch_padding_policy,
       int32_t low_priority_max_batch_size,
       int32_t low_priority_batch_timeout_micros,
       int32_t low_priority_max_enqueued_batches,
@@ -332,9 +336,14 @@
   static Status EmitIndexTensor(OpKernelContext* context, const BatchT& batch,
                                 int output_index);
 
-  // Looks up the batcher queue for 'queue_name'. If it did't previously exist,
+  // Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
   // creates it.
+  //
+  // The model_name and op_name are the names of the current model and
+  // operation, respectively.
   Status LookupOrCreateBatcherQueue(const string& queue_name,
+                                    const string& model_name,
+                                    const string& op_name,
                                     BatcherQueueT** queue);
 
   SessionMetadata session_metadata_;
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc
index 3900711..fa4fad9 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc
@@ -16,22 +16,40 @@
 #include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
 
 #include <cstdint>
+#include <functional>
 #include <memory>
+#include <string>
 #include <vector>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/status/status.h"
 #include "absl/strings/string_view.h"
 #include "absl/time/time.h"
+#include "absl/types/span.h"
 #include "tensorflow/core/common_runtime/cost_constants.h"
 #include "tensorflow/core/common_runtime/cost_measurement.h"
 #include "tensorflow/core/common_runtime/cost_measurement_registry.h"
 #include "tensorflow/core/common_runtime/request_cost.h"
+#include "tensorflow/core/framework/device.h"
+#include "tensorflow/core/framework/device_factory.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
 #include "tensorflow/core/kernels/batching_util/batch_stats.h"
+#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
+#include "tensorflow/core/lib/monitoring/cell_reader.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
 #include "tsl/platform/criticality.h"
+#include "tsl/platform/status.h"
 
 namespace tensorflow {
 namespace serving {
@@ -400,6 +418,210 @@
             original_cumulative_processed_size + 4);
 }
 
+class BatchResourceBaseTest : public ::testing::Test {
+ protected:
+  // Like BatchResourceBase but overrides the abstract methods; one of them
+  // signals the exposed process_func_batch_called() Notification.
+  class MyBatchResource : public BatchResourceBase {
+   public:
+    using BatchResourceBase::BatchResourceBase;
+
+    std::string DebugString() const override { return ""; }
+
+    void ProcessFuncBatchImpl(
+        const BatchResourceBase::BatchTask& /* last_task */,
+        absl::Span<const Tensor> /* inputs */,
+        std::vector<Tensor>* /* combined_outputs */,
+        std::function<void(const absl::Status&)> /* done */) const override {
+      process_func_batch_called_.Notify();
+    }
+
+    Notification& process_func_batch_called() {
+      return process_func_batch_called_;
+    }
+
+   private:
+    mutable Notification process_func_batch_called_;
+  };
+
+  BatchResourceBaseTest() {
+    // The whole point of this test fixture is to create a usable batch function
+    // context, context_.
+
+    // Create device_.
+    device_ = DeviceFactory::NewDevice("CPU", SessionOptions{},
+                                       "/job:a/replica:0/task:0");
+
+    // Create batch_kernel_node_def.
+    NodeDefBuilder batch_function_builder("my_batch_node", "BatchFunction");
+    batch_function_builder.Attr("max_batch_size", 128);
+    batch_function_builder.Attr("num_batch_threads", 8);
+    batch_function_builder.Attr("allowed_batch_sizes", {2, 4, 8});
+    batch_function_builder.Attr("batch_timeout_micros", 100);
+    batch_function_builder.Attr("max_enqueued_batches", 100);
+    batch_function_builder.Attr("enable_large_batch_splitting", true);
+    std::vector<DataType> input_dtypes = {DataType::DT_INT64,
+                                          DataType::DT_INT64};
+    std::vector<NodeDefBuilder::NodeOut> inputs;
+    inputs.push_back(NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64}));
+    inputs.push_back(NodeDefBuilder::NodeOut({"n2", 1, DataType::DT_INT64}));
+    batch_function_builder.Attr("Tin", input_dtypes);
+    batch_function_builder.Input(inputs);
+    batch_function_builder.Attr("Tcaptured", {DataType::DT_INT64});
+    batch_function_builder.Input(std::vector<NodeDefBuilder::NodeOut>{
+        NodeDefBuilder::NodeOut({"n3", 1, DataType::DT_INT64})});
+    batch_function_builder.Attr("Tout", {DataType::DT_INT64});
+    NameAttrList f;
+    f.set_name("func_to_batch");
+    batch_function_builder.Attr("f", f);
+    NodeDef batch_kernel_node_def;
+    TF_CHECK_OK(batch_function_builder.Finalize(&batch_kernel_node_def));
+
+    // Create batch_kernel_.
+    absl::Status op_kernel_creation_status;
+    batch_kernel_ =
+        CreateOpKernel(DEVICE_CPU, device_.get(), device_->GetAllocator({}),
+                       batch_kernel_node_def, TF_GRAPH_DEF_VERSION,
+                       &op_kernel_creation_status);
+    TF_CHECK_OK(op_kernel_creation_status);
+    CHECK(batch_kernel_ != nullptr);
+
+    // Create input tensors.
+    input_tensor_ = Tensor(DataType::DT_INT64, TensorShape({5, 2, 1}));
+    input_tensor_values_ = {
+        TensorValue(&input_tensor_),
+        TensorValue(&input_tensor_),
+        TensorValue(&input_tensor_),
+    };
+
+    // Fill in session_metadata_.
+    session_metadata_.set_name("my_model_name");
+
+    // Fill in params_.
+    params_.device = device_.get();
+    params_.op_kernel = batch_kernel_.get();
+    params_.inputs = input_tensor_values_;
+    params_.session_metadata = &session_metadata_;
+
+    // Create context_.
+    context_ = std::make_unique<OpKernelContext>(&params_);
+  }
+
+  std::unique_ptr<Device> device_;
+
+  std::unique_ptr<OpKernel> batch_kernel_;
+
+  Tensor input_tensor_;
+  std::vector<TensorValue> input_tensor_values_;
+
+  SessionMetadata session_metadata_;
+
+  OpKernelContext::Params params_;
+
+  std::unique_ptr<OpKernelContext> context_;
+};
+
+TEST_F(BatchResourceBaseTest, PassesCorrectModelBatchStatsToSbs) {
+  using BatchTask = BatchResourceBase::BatchTask;
+  using SharedBatchScheduler = SharedBatchScheduler<BatchTask>;
+
+  // Like SharedBatchScheduler but exposes the last QueueOptions passed to
+  // AddQueue as queue_options().
+  class MySharedBatchScheduler : public SharedBatchScheduler {
+   public:
+    MySharedBatchScheduler() : SharedBatchScheduler::SharedBatchScheduler({}) {}
+
+    absl::Status AddQueue(
+        const QueueOptions& options,
+        ProcessBatchCallback process_batch_callback,
+        std::unique_ptr<BatchScheduler<BatchTask>>* queue) override {
+      queue_options_ = options;
+      return SharedBatchScheduler::AddQueue(options, process_batch_callback,
+                                            queue);
+    }
+
+    const QueueOptions& queue_options() const { return queue_options_; }
+
+   private:
+    QueueOptions queue_options_;
+  };
+
+  auto batcher = std::make_shared<MySharedBatchScheduler>();
+
+  MyBatchResource* my_batch_resource = new MyBatchResource(
+      /* has_process_batch_function */ true,
+      /* batcher= */ batcher,
+      /* batcher_queue_options */ {},
+      /* allowed_batch_sizes */ {});
+
+  TF_CHECK_OK(my_batch_resource->RegisterInput(
+      /* guid= */
+      0,
+      /* context= */ context_.get(),
+      /* batcher_queue_name= */ "batcher_queue_name",
+      /* create_batch_task_fn= */
+      []() -> absl::StatusOr<std::unique_ptr<BatchResourceBase::BatchTask>> {
+        return std::make_unique<BatchResourceBase::BatchTask>();
+      },
+      /* done_callback= */ [] {}, /* forced_warmup_batch_size= */ 0));
+
+  EXPECT_EQ(batcher->queue_options().model_batch_stats,
+            &GlobalBatchStatsRegistry().model(/* model_name= */ "my_model_name",
+                                              /* op_name= */ "my_batch_node"));
+
+  // Wait for the batch timeout to expire and the scheduler to dump the only
+  // scheduled task back to the batch resource. If we don't do this, the
+  // scheduler will do this itself on destruction, when the resource has already
+  // been destroyed.
+  my_batch_resource->process_func_batch_called().WaitForNotificationWithTimeout(
+      absl::Seconds(1));
+
+  // This is how we have to destroy the BatchResource.
+  my_batch_resource->Unref();
+}
+
+TEST_F(BatchResourceBaseTest, ConfiguredBatchPaddingPolicyMetric) {
+  tensorflow::monitoring::testing::CellReader<std::string> metric(
+      "/tensorflow/serving/batching/configured_batch_padding_policy");
+
+  std::shared_ptr<SharedBatchScheduler<BatchResourceBase::BatchTask>> batcher;
+  TF_CHECK_OK(
+      SharedBatchScheduler<BatchResourceBase::BatchTask>::Create({}, &batcher));
+
+  MyBatchResource* my_batch_resource = new MyBatchResource(
+      /* has_process_batch_function */ true,
+      /* batcher= */ batcher,
+      /* batcher_queue_options */
+      MyBatchResource::BatcherT::QueueOptions{
+          .batch_padding_policy{kMinimizeTpuCostPerRequestPolicy},
+      },
+      /* allowed_batch_sizes */ {});
+
+  TF_CHECK_OK(my_batch_resource->RegisterInput(
+      /* guid= */
+      0, /* context= */ context_.get(),
+      /* batcher_queue_name= */ "batcher_queue_name",
+      /* create_batch_task_fn= */
+      []() -> absl::StatusOr<std::unique_ptr<BatchResourceBase::BatchTask>> {
+        return std::make_unique<BatchResourceBase::BatchTask>();
+      },
+      /* done_callback= */ [] {}, /* forced_warmup_batch_size= */ 0));
+
+  EXPECT_EQ(metric.Read(/* model_name= */ "my_model_name",
+                        /* op_name= */ "my_batch_node"),
+            kMinimizeTpuCostPerRequestPolicy);
+
+  // Wait for the batch timeout to expire and the scheduler to dump the only
+  // scheduled task back to the batch resource. If we don't do this, the
+  // scheduler will do this itself on destruction, when the resource has already
+  // been destroyed.
+  my_batch_resource->process_func_batch_called().WaitForNotificationWithTimeout(
+      absl::Seconds(1));
+
+  // This is how we have to destroy the BatchResource.
+  my_batch_resource->Unref();
+}
+
 }  // namespace
 }  // namespace serving
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler.h b/tensorflow/core/kernels/batching_util/batch_scheduler.h
index c70972c..1c50c55 100644
--- a/tensorflow/core/kernels/batching_util/batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/batch_scheduler.h
@@ -32,7 +32,7 @@
 #include <atomic>
 #include <cstddef>
 #include <deque>
-#include <functional>
+#include <iterator>
 #include <memory>
 #include <optional>
 #include <utility>
@@ -43,12 +43,11 @@
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/lib/traceme.h"
 #include "tsl/platform/criticality.h"
+#include "tsl/profiler/lib/traceme.h"
 
 namespace tensorflow {
 namespace serving {
@@ -252,7 +251,7 @@
 // accept new tasks; a closed one cannot. A batch is monotonic: initially it is
 // open and tasks can be added to it; then it is closed and its set of tasks
 // remains fixed for the remainder of its life. A closed batch cannot be re-
-// opened. Tasks can never be removed from a batch.
+// opened.
 //
 // Type parameter TaskType must be a subclass of BatchTask.
 template <typename TaskType>
@@ -304,6 +303,15 @@
   // Returns the TraceMe context id of this batch.
   uint64 traceme_context_id() const;
 
+  // Attempts to trim this batch to a new, smaller size (not to be confused with
+  // the number of tasks in the batch). On success, the trimmed tasks go into
+  // 'out_trimmed_tasks' in the same order the tasks were in this batch.
+  //
+  // The method might not succeed if it needs to split a large task to hit the
+  // correct size.
+  void TryTrimToNewSize(
+      int new_size, std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks);
+
  private:
   mutable mutex mu_;
 
@@ -505,6 +513,45 @@
   return traceme_context_id_;
 }
 
+template <typename TaskType>
+void Batch<TaskType>::TryTrimToNewSize(
+    int new_size, std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks) {
+  mutex_lock l(mu_);
+  DCHECK_GT(new_size, 0);
+  DCHECK_LT(new_size, size_);
+  DCHECK(out_trimmed_tasks.empty());
+
+  // Index of the first task to trim away. It is possible that it is the index
+  // of a task of size larger than 1 that will have to be split in order to get
+  // to the target new_size.
+  int32 first_task_to_move = 0;
+  // The sum of sizes of tasks i, where i < first_task_to_move.
+  int32 size_of_previous_tasks = 0;
+  while (size_of_previous_tasks + tasks_[first_task_to_move]->size() <=
+         new_size) {
+    size_of_previous_tasks += tasks_[first_task_to_move]->size();
+    first_task_to_move++;
+    // The loop must always stop before this check is tripped because new_size
+    // is required to be strictly smaller than the size of the batch.
+    DCHECK_LT(first_task_to_move, tasks_.size());
+  }
+
+  // Check whether task 'first_task_to_move' will have to be split.
+  if (size_of_previous_tasks < new_size) {
+    // TODO: b/325954758 - Consider supporting splitting large tasks and then
+    // drop 'Try' from the method name.
+    return;
+  }
+  DCHECK_EQ(size_of_previous_tasks, new_size);
+
+  // Actually trim.
+  out_trimmed_tasks.reserve(tasks_.size() - first_task_to_move);
+  std::move(tasks_.begin() + first_task_to_move, tasks_.end(),
+            std::back_inserter(out_trimmed_tasks));
+  tasks_.resize(first_task_to_move);
+  size_ = new_size;
+}
+
 }  // namespace serving
 }  // namespace tensorflow
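For context, a minimal standalone sketch of the trimming semantics above (illustration only, not part of this patch; FakeTask is a hypothetical task type with a size() override, mirroring the tests in the next file):

    // Tasks of sizes 3 and 5 form a batch of size 8.
    Batch<FakeTask> batch;
    batch.AddTask(std::make_unique<FakeTask>(3));
    batch.AddTask(std::make_unique<FakeTask>(5));

    std::vector<std::unique_ptr<FakeTask>> trimmed;
    // Trimming to 3 succeeds: the size-5 task is moved into 'trimmed' and the
    // batch keeps only the size-3 task.
    batch.TryTrimToNewSize(/* new_size= */ 3, trimmed);
    // batch.size() == 3, trimmed.size() == 1.
    // Trimming the original batch to 4 would have been a no-op, because that
    // would require splitting the size-5 task.
    batch.Close();  // Batch::~Batch blocks until the batch is closed.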
 
diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc
index e159c437..2f9c903 100644
--- a/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc
+++ b/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc
@@ -21,12 +21,13 @@
 #include <optional>
 #include <string>
 #include <tuple>
+#include <vector>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "absl/status/status.h"
 #include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/notification.h"
 #include "tensorflow/core/platform/status_matchers.h"
 #include "tensorflow/core/platform/test.h"
 #include "tsl/platform/criticality.h"
@@ -37,6 +38,7 @@
 
 using ::testing::ElementsAre;
 using ::testing::Eq;
+using ::testing::Pointer;
 using ::testing::Property;
 
 TEST(MixedPriorityBatchingPolicyTest, InvalidAttrValueError) {
@@ -386,6 +388,53 @@
   EXPECT_THAT(batch.RemoveAllTasks(), ::testing::IsEmpty());  // third call
 }
 
+TEST(BatchTest, TryTrimToNewSizeTrimsAndReturnsTrimmedElementsInOrder) {
+  Batch<FakeTask> batch;
+
+  auto task0 = new FakeTask(3);
+  batch.AddTask(std::unique_ptr<FakeTask>(task0));
+
+  auto task1 = new FakeTask(5);
+  batch.AddTask(std::unique_ptr<FakeTask>(task1));
+
+  auto task2 = new FakeTask(7);
+  batch.AddTask(std::unique_ptr<FakeTask>(task2));
+
+  auto task3 = new FakeTask(9);
+  batch.AddTask(std::unique_ptr<FakeTask>(task3));
+
+  std::vector<std::unique_ptr<FakeTask>> trimmed_tasks;
+  batch.TryTrimToNewSize(/* new_size= */ 8,
+                         /* out_trimmed_tasks= */ trimmed_tasks);
+
+  EXPECT_EQ(batch.size(), 8);
+  EXPECT_EQ(batch.num_tasks(), 2);
+
+  EXPECT_THAT(trimmed_tasks, ElementsAre(Pointer(task2), Pointer(task3)));
+
+  batch.Close();  // Batch::~Batch blocks until the batch is closed.
+}
+
+TEST(BatchTest, TryTrimToNewSizeDoesNotTrimWhenItWouldNeedToSplitATask) {
+  Batch<FakeTask> batch;
+
+  auto task0 = new FakeTask(3);
+  batch.AddTask(std::unique_ptr<FakeTask>(task0));
+
+  auto task1 = new FakeTask(5);
+  batch.AddTask(std::unique_ptr<FakeTask>(task1));
+
+  std::vector<std::unique_ptr<FakeTask>> trimmed_tasks;
+  batch.TryTrimToNewSize(/* new_size= */ 4,
+                         /* out_trimmed_tasks= */ trimmed_tasks);
+
+  EXPECT_EQ(batch.size(), 8);
+  EXPECT_EQ(batch.num_tasks(), 2);
+  EXPECT_TRUE(trimmed_tasks.empty());
+
+  batch.Close();  // Batch::~Batch blocks until the batch is closed.
+}
+
 }  // namespace
 }  // namespace serving
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h
index 7e4382a..9a6deb1 100644
--- a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h
+++ b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h
@@ -16,8 +16,15 @@
 #ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_
 #define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_
 
+#include <memory>
+#include <optional>
 #include <vector>
 
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_stats.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
@@ -36,6 +43,114 @@
                             const std::vector<int32>& allowed_batch_sizes,
                             bool disable_padding);
 
+// Constants containing possible values for the batch_padding_policy argument
+// of MaybeBatchDown. This argument specifies the policy that a batch scheduler
+// is using when deciding what to do when, say, 18 requests need to be batched,
+// but only 16 and 32 batch sizes are allowed. The following options are
+// available.
+//
+//   - PAD_UP: pad to size 32.
+//   - BATCH_DOWN: schedule a batch of size 16 and leave 2 requests in the
+//     batch buffer.
+//   - MINIMIZE_TPU_COST_PER_REQUEST: a smarter greedy policy that chooses
+//     to either PAD_UP or BATCH_DOWN so as to minimize the TPU costs per
+//     real request. In this case, it would compare (batch_16_cost / 16) and
+//     (batch_32_cost / 18).
+//
+inline constexpr absl::string_view kBatchDownPolicy = "BATCH_DOWN";
+inline constexpr absl::string_view kPadUpPolicy = "PAD_UP";
+inline constexpr absl::string_view kMinimizeTpuCostPerRequestPolicy =
+    "MINIMIZE_TPU_COST_PER_REQUEST";
+
+// Trims the batch to the next allowed batch size when possible and when
+// configured by batch_padding_policy.
+//
+// When trimming, this function moves the trimmed tasks into the
+// out_trimmed_tasks vector in the same order as they were in the batch.
+template <typename TaskType>
+void MaybeBatchDown(Batch<TaskType>& batch,
+                    const std::vector<int32>& allowed_batch_sizes,
+                    bool disable_padding,
+                    absl::string_view batch_padding_policy,
+                    ModelBatchStats* model_batch_stats,
+                    std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks) {
+  if (batch_padding_policy == kPadUpPolicy) {
+    // This is the default behavior of the batch resource when it is given a
+    // batch size that doesn't match any of the allowed batch sizes.
+    return;
+  }
+  bool minimize_tpu_cost_per_request;
+  if (batch_padding_policy == kBatchDownPolicy) {
+    minimize_tpu_cost_per_request = false;
+  } else if (batch_padding_policy == kMinimizeTpuCostPerRequestPolicy) {
+    if (model_batch_stats == nullptr) {
+      LOG_FIRST_N(ERROR, 1)
+          << kMinimizeTpuCostPerRequestPolicy
+          << " batch padding policy has been chosen "
+             "but no ModelBatchStats passed to the batch scheduler; will "
+             "fall back on the "
+          << kPadUpPolicy << " policy.";
+      return;
+    }
+    minimize_tpu_cost_per_request = true;
+  } else {
+    LOG_FIRST_N(ERROR, 1) << "Unsupported batch_padding_policy: "
+                          << batch_padding_policy << ", falling back on the "
+                          << kPadUpPolicy << " policy.";
+    return;
+  }
+
+  int32 batch_size = batch.size();
+
+  int32 pad_up_size =
+      GetNextAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding);
+  if (pad_up_size == batch_size) {
+    return;  // Good, no padding is necessary.
+  }
+
+  int32 batch_down_size =
+      GetPrevAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding);
+  if (batch_down_size == batch_size) {
+    return;  // Can't batch down (e.g. no smaller batch size available).
+  }
+
+  if (minimize_tpu_cost_per_request) {
+    // TODO: b/325954758 - Consider logging a warning here or elsewhere if
+    // a larger batch isn't meaningfully cheaper than a smaller batch.
+    // TODO: b/325954758 - Consider logging a warning here or elsewhere if a
+    // smaller batch is unreasonably cheaper than a larger one (assuming
+    // a batch cost model = constant_cost + batch_size * per_element_cost).
+    // TODO: b/325954758 - Consider occasionally picking either batch size so
+    // that we learn fresh costs of each batch size. For this code, it is not a
+    // large priority though because if we are in between two allowed batch
+    // sizes (say, 16 and 32), chances are that we will occasionally get
+    // batches of exactly 16 and 32 organically (and then we pick those
+    // unconditionally). But if we explicitly occasionally explored other batch
+    // sizes, we wouldn't have to rely on this "chances are". For other
+    // applications of batch costs, we might also want to occasionally explore
+    // all allowed batch sizes and not just 16 and 32 from this example.
+    std::optional<absl::Duration> down_batch_cost =
+        model_batch_stats->batch_size(batch_down_size).tpu_cost().mean();
+    std::optional<absl::Duration> up_batch_cost =
+        model_batch_stats->batch_size(pad_up_size).tpu_cost().mean();
+    if (!down_batch_cost.has_value() || !up_batch_cost.has_value()) {
+      // We have no data about batch costs, so just do nothing.
+      return;
+    }
+
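+    // The pad-up cost is amortized over the real number of requests
+    // (batch_size) rather than the padded size, since padding adds no real
+    // requests; the batch-down cost is amortized over the trimmed size.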
+    auto batch_down_cost_per_request = *down_batch_cost / batch_down_size;
+    auto pad_up_cost_per_request = *up_batch_cost / batch_size;
+
+    if (pad_up_cost_per_request < batch_down_cost_per_request) {
+      // Abort batching down because it's cheaper to pad up.
+      return;
+    }
+  }
+
+  // Batch down.
+  batch.TryTrimToNewSize(batch_down_size, out_trimmed_tasks);
+}
+
 }  // namespace serving
 }  // namespace tensorflow
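To make the MINIMIZE_TPU_COST_PER_REQUEST arithmetic above concrete, a small worked sketch (illustration only, not part of this patch; the cost numbers are invented, and the ModelBatchStats calls mirror the tests in the next file):

    ModelBatchStats stats;
    stats.batch_size(16).tpu_cost().Register(absl::Milliseconds(2));
    stats.batch_size(32).tpu_cost().Register(absl::Milliseconds(3));
    // For a batch holding 18 real requests with allowed sizes {16, 32}:
    //   BATCH_DOWN cost per request = 2ms / 16  = 0.125ms
    //   PAD_UP     cost per request = 3ms / 18 ~= 0.167ms
    // Batching down is cheaper per request, so MaybeBatchDown trims the batch
    // down to size 16 and leaves the remaining 2 requests for the next batch.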
 
diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc b/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc
index 2bff515..e45cb46 100644
--- a/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc
+++ b/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc
@@ -15,7 +15,14 @@
 
 #include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
 
+#include <cstddef>
+#include <memory>
+#include <vector>
+
 #include <gtest/gtest.h>
+#include "absl/time/time.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_stats.h"
 
 namespace tensorflow {
 namespace serving {
@@ -66,6 +73,208 @@
   EXPECT_EQ(GetPrevAllowedBatchSize(10, {2, 4, 8}, false), 8);
 }
 
+class FakeTask : public BatchTask {
+ public:
+  explicit FakeTask(size_t size) : size_(size) {}
+
+  size_t size() const override { return size_; }
+
+ private:
+  const size_t size_;
+};
+
+TEST(MaybeBatchDownTest, PadUp) {
+  Batch<FakeTask> batch;
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.Close();
+
+  std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+
+  MaybeBatchDown(
+      /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8},
+      /* disable_padding= */ false,
+      /* batch_padding_policy= */ kPadUpPolicy,
+      /* model_batch_stats= */ nullptr,
+      /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+  // The batch must stay unchanged (the batch resource will then pad it to the
+  // next allowed batch size, which is the pad-up behavior).
+  EXPECT_EQ(batch.size(), 3);
+}
+
+TEST(MaybeBatchDownTest, BatchDown) {
+  Batch<FakeTask> batch;
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.Close();
+
+  std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+
+  MaybeBatchDown(
+      /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8},
+      /* disable_padding= */ false,
+      /* batch_padding_policy= */ kBatchDownPolicy,
+      /* model_batch_stats= */ nullptr,
+      /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+  // The scheduler should trim the batch to a smaller allowed size that requires
+  // no padding.
+  EXPECT_EQ(batch.size(), 2);
+  // The trimmed part.
+  EXPECT_EQ(out_trimmed_tasks.size(), 1);
+}
+
+TEST(MaybeBatchDownTest, BatchDownDoesNotSplitTasks) {
+  // Add tasks with a total size of 3; the second task is large and would have
+  // to be split to batch down.
+  Batch<FakeTask> batch;
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(2));
+  batch.Close();
+
+  std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+
+  MaybeBatchDown(
+      /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8},
+      /* disable_padding= */ false,
+      /* batch_padding_policy= */ kBatchDownPolicy,
+      /* model_batch_stats= */ nullptr,
+      /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+  // The batch must stay unchanged because the current implementation doesn't
+  // support splitting large tasks.
+  EXPECT_EQ(batch.size(), 3);
+}
+
+TEST(MaybeBatchDownTest, BatchDownDoesNothingWhenTheBatchSizeIsAlreadyAllowed) {
+  Batch<FakeTask> batch;
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.Close();
+
+  std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+
+  MaybeBatchDown(
+      /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8},
+      /* disable_padding= */ false,
+      /* batch_padding_policy= */ kBatchDownPolicy,
+      /* model_batch_stats= */ nullptr,
+      /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+  // The batch should stay unchanged because it's already of an allowed size.
+  EXPECT_EQ(batch.size(), 4);
+}
+
+TEST(MaybeBatchDownTest, BatchDownDoesNothingWhenNoSmallerAllowedSize) {
+  Batch<FakeTask> batch;
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.Close();
+
+  std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+
+  MaybeBatchDown(
+      /* batch= */ batch, /* allowed_batch_sizes= */ {4, 8},
+      /* disable_padding= */ false,
+      /* batch_padding_policy= */ kBatchDownPolicy,
+      /* model_batch_stats= */ nullptr,
+      /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+  // Can't batch down because there is no smaller allowed size.
+  EXPECT_EQ(batch.size(), 3);
+}
+
+TEST(MaybeBatchDownTest, MinimizeTpuCostPerRequestPicksBatchDown) {
+  Batch<FakeTask> batch;
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.Close();
+
+  ModelBatchStats model_batch_stats;
+  model_batch_stats.batch_size(2).tpu_cost().Register(absl::Seconds(2));
+  model_batch_stats.batch_size(4).tpu_cost().Register(absl::Seconds(3.1));
+
+  std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+  MaybeBatchDown(
+      /* batch= */ batch, /* allowed_batch_sizes= */ {2, 4},
+      /* disable_padding= */ false,
+      /* batch_padding_policy= */ kMinimizeTpuCostPerRequestPolicy,
+      /* model_batch_stats= */ &model_batch_stats,
+      /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+  EXPECT_EQ(batch.size(), 2);
+}
+
+TEST(MaybeBatchDownTest, MinimizeTpuCostPerRequestPicksPadUp) {
+  Batch<FakeTask> batch;
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.Close();
+
+  ModelBatchStats model_batch_stats;
+  model_batch_stats.batch_size(2).tpu_cost().Register(absl::Seconds(2));
+  model_batch_stats.batch_size(4).tpu_cost().Register(absl::Seconds(2.9));
+
+  std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+  MaybeBatchDown(
+      /* batch= */ batch, /* allowed_batch_sizes= */ {2, 4},
+      /* disable_padding= */ false,
+      /* batch_padding_policy= */ kMinimizeTpuCostPerRequestPolicy,
+      /* model_batch_stats= */ &model_batch_stats,
+      /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+  EXPECT_EQ(batch.size(), 3);
+}
+
+TEST(MaybeBatchDownTest, MinimizeTpuCostPerRequestIsOkWithMissingCosts) {
+  Batch<FakeTask> batch;
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.Close();
+
+  ModelBatchStats model_batch_stats;
+  model_batch_stats.batch_size(2).tpu_cost().Register(absl::Seconds(2));
+  // Not adding costs for batch 4.
+
+  std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+  MaybeBatchDown(
+      /* batch= */ batch, /* allowed_batch_sizes= */ {2, 4},
+      /* disable_padding= */ false,
+      /* batch_padding_policy= */ kMinimizeTpuCostPerRequestPolicy,
+      /* model_batch_stats= */ &model_batch_stats,
+      /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+  // No assertions here because we do not require a particular choice; we just
+  // care that we don't crash.
+}
+
+TEST(MaybeBatchDownTest, MinimizeTpuCostPerRequestDoesPadUpWhenNoModelStats) {
+  Batch<FakeTask> batch;
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.AddTask(std::make_unique<FakeTask>(1));
+  batch.Close();
+
+  std::vector<std::unique_ptr<FakeTask>> out_trimmed_tasks;
+  MaybeBatchDown(
+      /* batch= */ batch, /* allowed_batch_sizes= */ {2, 4},
+      /* disable_padding= */ false,
+      /* batch_padding_policy= */ kMinimizeTpuCostPerRequestPolicy,
+      /* model_batch_stats= */ nullptr,
+      /* out_trimmed_tasks= */ out_trimmed_tasks);
+
+  EXPECT_EQ(batch.size(), 3);
+}
+
 }  // namespace
 
 }  // namespace serving
diff --git a/tensorflow/core/kernels/batching_util/batch_stats.h b/tensorflow/core/kernels/batching_util/batch_stats.h
index b430663..87c36fc 100644
--- a/tensorflow/core/kernels/batching_util/batch_stats.h
+++ b/tensorflow/core/kernels/batching_util/batch_stats.h
@@ -66,6 +66,10 @@
 
 namespace tensorflow::serving {
 
+// Default values for when there is no recorded statistic in ModelBatchStats.
+constexpr int64_t kNumBatchThreadsUnknown = -1;
+constexpr int64_t kBatchTimeoutMicrosUnknown = -1;
+
 // Tracks the average cost of registered samples.
 //
 // Thread-safe.
@@ -167,6 +171,23 @@
     return result;
   }
 
+  void SetNumBatchThreads(int64_t num_batch_threads) {
+    num_batch_threads_.store(num_batch_threads, std::memory_order_relaxed);
+  }
+
+  int64_t num_batch_threads() const {
+    return num_batch_threads_.load(std::memory_order_relaxed);
+  }
+
+  void SetBatchTimeoutMicros(int64_t batch_timeout_micros) {
+    batch_timeout_micros_.store(batch_timeout_micros,
+                                std::memory_order_relaxed);
+  }
+
+  int64_t batch_timeout_micros() const {
+    return batch_timeout_micros_.load(std::memory_order_relaxed);
+  }
+
  private:
   mutable mutex mu_;
 
@@ -184,6 +205,13 @@
   // Can be used to generate an internal load metric per model. See
   // RegisterQuerySize for more details.
   std::atomic<int64_t> cumulative_processed_size_ = 0;
+
+  // The number of batch threads assigned to this model.
+  std::atomic<int64_t> num_batch_threads_ = kNumBatchThreadsUnknown;
+
+  // The timeout in microseconds for this model (after which the current batch
+  // is sent to be processed by the TPU).
+  std::atomic<int64_t> batch_timeout_micros_ = kBatchTimeoutMicrosUnknown;
 };
 
 // Tracks batch statistics for all models.
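A brief usage sketch for the two new per-model fields above (illustration only, not part of this patch; the model and op names are placeholders, and the registry accessor is the one used elsewhere in this change):

    ModelBatchStats& stats = GlobalBatchStatsRegistry().model(
        /* model_name= */ "my_model", /* op_name= */ "my_batch_node");
    stats.SetNumBatchThreads(8);
    stats.SetBatchTimeoutMicros(1000);
    // Until set, the fields report the kNumBatchThreadsUnknown /
    // kBatchTimeoutMicrosUnknown sentinels (-1).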
diff --git a/tensorflow/core/kernels/batching_util/batch_stats_test.cc b/tensorflow/core/kernels/batching_util/batch_stats_test.cc
index 223cde6..5f5168c 100644
--- a/tensorflow/core/kernels/batching_util/batch_stats_test.cc
+++ b/tensorflow/core/kernels/batching_util/batch_stats_test.cc
@@ -126,6 +126,28 @@
   ASSERT_THAT(stats.BatchSizes(), UnorderedElementsAre(1, 2, 4));
 }
 
+TEST(BatchStatsTest, BatchTimeoutIsCorrect) {
+  ModelBatchStats stats;
+
+  // Originally the batch timeout is -1 if unassigned.
+  ASSERT_EQ(stats.batch_timeout_micros(), -1);
+
+  // Assign a batch timeout of 100 microseconds.
+  stats.SetBatchTimeoutMicros(100);
+  ASSERT_EQ(stats.batch_timeout_micros(), 100);
+}
+
+TEST(BatchStatsTest, NumBatchThreadsIsCorrect) {
+  ModelBatchStats stats;
+
+  // Originally the number of batch threads is -1 if unassigned.
+  ASSERT_EQ(stats.num_batch_threads(), -1);
+
+  // Assign a number of per-model batch threads.
+  stats.SetNumBatchThreads(16);
+  ASSERT_EQ(stats.num_batch_threads(), 16);
+}
+
 }  // namespace
 
 }  // namespace tensorflow::serving
diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
index 93ac0c9..acea649 100644
--- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
@@ -29,13 +29,13 @@
 #include <variant>
 #include <vector>
 
-#include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_format.h"
 #include "absl/time/clock.h"
 #include "tensorflow/core/kernels/batching_util/batch_input_task.h"
 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
+#include "tensorflow/core/kernels/batching_util/batch_stats.h"
 #include "tensorflow/core/kernels/batching_util/periodic_function.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -43,6 +43,7 @@
 #include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/notification.h"
 #include "tensorflow/core/platform/thread_annotations.h"
@@ -149,7 +150,7 @@
       const Options& options,
       std::shared_ptr<SharedBatchScheduler<TaskType>>* scheduler);
 
-  ~SharedBatchScheduler();
+  virtual ~SharedBatchScheduler();
 
   // Adds a queue to which tasks may be submitted. The returned queue implements
   // the BatchScheduler API. Each queue has its own set of scheduling options,
@@ -240,6 +241,18 @@
     // If true, the padding will not be appended.
     bool disable_padding = false;
 
+    // The padding policy to use.
+    //
+    // See the documentation for kPadUpPolicy for details.
+    string batch_padding_policy = string(kPadUpPolicy);
+
+    // A pointer to a ModelBatchStats instance for this model. To be used for
+    // cost-based padding policy selection.
+    //
+    // If null, some other padding policy will be used if a cost-based one is
+    // requested.
+    ModelBatchStats* model_batch_stats = nullptr;
+
     // If true, queue implementation would split high priority and low priority
     // inputs into two sub queues.
     bool enable_priority_queue = false;
@@ -270,13 +283,15 @@
     MixedPriorityBatchingPolicy mixed_priority_batching_policy =
         MixedPriorityBatchingPolicy::kLowPriorityPaddingWithMaxBatchSize;
   };
-  Status AddQueue(const QueueOptions& options,
-                  ProcessBatchCallback process_batch_callback,
-                  std::unique_ptr<BatchScheduler<TaskType>>* queue);
+  // This method is marked virtual for testing purposes only.
+  virtual Status AddQueue(const QueueOptions& options,
+                          ProcessBatchCallback process_batch_callback,
+                          std::unique_ptr<BatchScheduler<TaskType>>* queue);
 
- private:
+ protected:
   explicit SharedBatchScheduler(const Options& options);
 
+ private:
   void GetNextWorkItem_Locked(internal::Queue<TaskType>** queue_for_batch_out,
                               BatchUniquePtr* batch_to_process_out)
       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
@@ -588,6 +603,9 @@
   // The time at which the first task was added to the open (back-most) batch
   // in 'high_priority_batches_'. Valid iff that batch contains at least one
   // task.
+  //
+  // Note that when using a batch padding policy other than PAD_UP, this field
+  // might contain an approximate value (see ScheduleBatchWithEagerSplit).
   uint64 open_batch_start_time_micros_ TF_GUARDED_BY(mu_);
 
   // Whether this queue contains a batch that is eligible to be scheduled.
@@ -920,7 +938,7 @@
 
 template <typename TaskType>
 Status Queue<TaskType>::ScheduleWithLazySplit(std::unique_ptr<TaskType>* task) {
-  profiler::TraceMe trace_me([task] {
+  tsl::profiler::TraceMe trace_me([task] {
     return profiler::TraceMeEncode(
         "ScheduleWithLazySplit",
         {{"batching_input_task_size", (*task)->size()}});
@@ -1055,7 +1073,7 @@
 Status Queue<TaskType>::ScheduleWithoutOrEagerSplit(
     std::unique_ptr<TaskType>* task) {
   const bool large_batch_splitting = options_.enable_large_batch_splitting;
-  profiler::TraceMe trace_me([task, large_batch_splitting] {
+  tsl::profiler::TraceMe trace_me([task, large_batch_splitting] {
     return profiler::TraceMeEncode(
         large_batch_splitting ? "ScheduleWithEagerSplit"
                               : "ScheduleWithoutSplit",
@@ -1223,7 +1241,37 @@
     std::deque<std::unique_ptr<Batch<TaskType>>>& batches = GetBatches();
     // Consider closing the open batch at this time, to schedule it.
     if (batches.size() == 1 && IsOpenBatchSchedulable()) {
+      // Support BatchPaddingPolicy::kBatchDown and
+      // BatchPaddingPolicy::kMinimizeTpuCostPerRequest. We do this before
+      // starting a new batch because starting a new batch will close the old
+      // batch, making it read-only.
+      std::vector<std::unique_ptr<TaskType>> trimmed_tasks;
+      MaybeBatchDown(
+          /* batch= */ *batches[0],
+          /* allowed_batch_sizes= */ options_.allowed_batch_sizes,
+          /* disable_padding= */ options_.disable_padding,
+          /* batch_padding_policy= */ options_.batch_padding_policy,
+          /* model_batch_stats= */ options_.model_batch_stats,
+          /* out_trimmed_tasks= */ trimmed_tasks);
+
       StartNewBatch();
+
+      // Move the trimmed tasks, if any, into the new batch.
+      Batch<TaskType>& new_batch = *batches[1];
+      for (std::unique_ptr<TaskType>& task : trimmed_tasks) {
+        new_batch.AddTask(std::move(task));
+      }
+      if (!new_batch.empty()) {
+        // TODO - b/325954758: Reconsider the starting time of a trimmed batch.
+        //
+        // Ideally, we'd set open_batch_start_time_micros_ to the time we
+        // received the first task, but we don't have this information here,
+        // so we're using NOW as the timestamp. An alternative that doesn't
+        // require adding time to each task would be to assume that requests
+        // arrived at a steady rate and therefore use a point between the old
+        // value of open_batch_start_time_micros_ and NOW.
+        open_batch_start_time_micros_ = env_->NowMicros();
+      }
     }
 
     if (batches.size() >= 2) {
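For callers configuring a queue, a minimal sketch showing the two new QueueOptions fields together (illustration only, not part of this patch; MyTask, the names, and the commented-out AddQueue call are placeholders):

    using Scheduler = SharedBatchScheduler<MyTask>;
    Scheduler::QueueOptions options;
    options.allowed_batch_sizes = {16, 32};
    options.batch_padding_policy = kMinimizeTpuCostPerRequestPolicy;
    options.model_batch_stats = &GlobalBatchStatsRegistry().model(
        /* model_name= */ "my_model", /* op_name= */ "my_batch_node");
    // scheduler->AddQueue(options, process_batch_callback, &queue);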
diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc
index 1c4073a..2a5afae 100644
--- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc
+++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc
@@ -27,11 +27,14 @@
 #include "absl/container/fixed_array.h"
 #include "absl/status/status.h"
 #include "absl/time/time.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
 #include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/status.h"
@@ -39,7 +42,6 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/criticality.h"
 
 namespace tensorflow {
@@ -1052,6 +1054,79 @@
   }
 }
 
+TEST_P(SharedBatchSchedulerTest, BatchPaddingPolicyBatchDown) {
+  if (enable_lazy_split()) {
+    GTEST_SKIP()
+        << "BatchPaddingPolicy::kBatchDown is not supported for lazy split.";
+  }
+
+  // Set up a fake clock, which only advances when we explicitly tell it to.
+  test_util::FakeClockEnv env(Env::Default());
+  Notification start_teardown, stop_teardown;
+  std::unique_ptr<Thread> teardown_thread =
+      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
+
+  {
+    Notification first_batch_processed;
+    Notification second_batch_processed;
+    auto callback = [&](std::unique_ptr<Batch<FakeTask>> batch) {
+      if (!first_batch_processed.HasBeenNotified()) {
+        // This is the main expectation of the test.
+        //
+        // The scheduler should have trimmed the batch to a smaller allowed
+        // size that requires no padding.
+        EXPECT_EQ(batch->size(), 2);
+
+        first_batch_processed.Notify();
+        return;
+      }
+
+      if (!second_batch_processed.HasBeenNotified()) {
+        // Leftovers after the first batch.
+        EXPECT_EQ(batch->size(), 1);
+
+        second_batch_processed.Notify();
+        return;
+      }
+
+      ADD_FAILURE() << "Batch callback must not be invoked more times than expected";
+    };
+
+    auto scheduler = CreateSharedBatchScheduler(1, &env);
+
+    QueueOptions options =
+        CreateQueueOptions(/* max_execution_batch_size= */ 10,
+                           /* input_batch_size_limit= */ 10,
+                           /* batch_timeout_micros= */ 10,
+                           /* max_enqueued_batches= */ 10);
+
+    // The most interesting option for this test.
+    options.allowed_batch_sizes = {1, 2, 4, 8};
+    options.batch_padding_policy = kBatchDownPolicy;
+
+    auto queue = CreateQueue(scheduler, options, callback);
+
+    // Schedule some tasks and ensure the scheduler calls the callback after a
+    // batch timeout has expired.
+    TF_ASSERT_OK(ScheduleTask(1, queue.get()));
+    TF_ASSERT_OK(ScheduleTask(1, queue.get()));
+    TF_ASSERT_OK(ScheduleTask(1, queue.get()));
+    env.AdvanceByMicroseconds(options.batch_timeout_micros);
+    first_batch_processed.WaitForNotification();
+
+    // Ensure the scheduler correctly updates the starting time of the new
+    // batch.
+    env.AdvanceByMicroseconds(options.batch_timeout_micros - 1);
+    EXPECT_FALSE(second_batch_processed.WaitForNotificationWithTimeout(
+        absl::Milliseconds(10)));
+    env.AdvanceByMicroseconds(1);
+    second_batch_processed.WaitForNotification();
+
+    start_teardown.Notify();
+  }
+  stop_teardown.Notify();
+}
+
 // TODO(b/161857471):
 // Add test coverage when input-split and no-split returns differently.
 INSTANTIATE_TEST_SUITE_P(
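
For readers of the test above, the trimming behavior it exercises comes down to choosing the largest allowed batch size that does not exceed the number of queued tasks. A minimal standalone sketch of that selection, with hypothetical names and under the assumption that the allowed sizes are sorted ascending (this is not the scheduler's actual implementation):

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

// Returns the largest allowed size that does not exceed `batch_size`, or
// `batch_size` itself if no allowed size fits (so nothing can be trimmed).
size_t BatchDownSize(size_t batch_size, const std::vector<size_t>& allowed) {
  auto it = std::upper_bound(allowed.begin(), allowed.end(), batch_size);
  if (it == allowed.begin()) return batch_size;
  return *std::prev(it);
}

// With allowed sizes {1, 2, 4, 8} and 3 queued tasks, BatchDownSize returns 2,
// matching the test's expectation of a first batch of size 2 and a leftover
// batch of size 1.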
diff --git a/tensorflow/core/kernels/checkpoint_callback_manager_test.cc b/tensorflow/core/kernels/checkpoint_callback_manager_test.cc
index 250f843..cb39718 100644
--- a/tensorflow/core/kernels/checkpoint_callback_manager_test.cc
+++ b/tensorflow/core/kernels/checkpoint_callback_manager_test.cc
@@ -19,6 +19,7 @@
 
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/resource_handle.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/path.h"
@@ -26,7 +27,6 @@
 #include "tensorflow/core/platform/statusor.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace checkpoint {
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index a5df473..c447763 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -121,6 +121,10 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/data:name_utils",
         "//tensorflow/core/data:split_utils",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:str_format",
+        "@local_tsl//tsl/platform:mutex",
     ],
 )
 
@@ -851,6 +855,7 @@
         "//tensorflow/core/data:dataset_utils",
         "//tensorflow/core/data:name_utils",
         "//tensorflow/core/data:stats_utils",
+        "//tensorflow/core/data:unbounded_thread_pool",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/profiler/lib:traceme_encode",
         "@com_google_absl//absl/base",
@@ -1480,7 +1485,7 @@
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index 2387067..cd8acd1 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -288,17 +288,24 @@
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
+      int64_t input_empty;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(prefix(), kInputImplEmpty, &input_empty));
+
       if (ctx->restored_element_count().has_value()) {
         IteratorContext::Params params(ctx);
         params.restored_element_count =
             *ctx->restored_element_count() * dataset()->batch_size_;
         IteratorContext ctx_copy(params);
-        return RestoreInput(&ctx_copy, reader, input_impl_);
+        if (!static_cast<bool>(input_empty)) {
+          TF_RETURN_IF_ERROR(RestoreInput(&ctx_copy, reader, input_impl_));
+          ctx->MergeCheckpoint(ctx_copy.checkpoint());
+        } else {
+          input_impl_.reset();
+        }
+        return absl::OkStatus();
       }
 
-      int64_t input_empty;
-      TF_RETURN_IF_ERROR(
-          reader->ReadScalar(prefix(), kInputImplEmpty, &input_empty));
       if (!static_cast<bool>(input_empty)) {
         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
       } else {
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index de8f50d..b77af19 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -880,13 +880,14 @@
         TF_RETURN_IF_ERROR(
             WriteElementsToCheckpoint(writer, prefix(), cache_->data()));
       }
+      TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer));
       return SaveInput(ctx, writer, iterator_);
     }
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       if (ctx->restored_element_count().has_value()) {
-        return global_shuffle_iterator_.Restore(ctx);
+        return global_shuffle_iterator_.Restore(prefix(), ctx, reader);
       }
       mutex_lock l(mu_);
       iterator_.reset();
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index 8f380a2..77ed1ce 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -14,13 +14,21 @@
 ==============================================================================*/
 #include "tensorflow/core/kernels/data/concatenate_dataset_op.h"
 
-#include <string>
+#include <algorithm>
+#include <cstddef>
 #include <utility>
 
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
 #include "tensorflow/core/data/name_utils.h"
 #include "tensorflow/core/data/split_utils.h"
-#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/mutex.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/thread_annotations.h"
 
 namespace tensorflow {
 namespace data {
@@ -36,6 +44,30 @@
 
 constexpr char kIndex[] = "i";
 constexpr char kInputImplUninitialized[] = "input_impl_uninitialized";
+constexpr char kElementCount[] = "element_count";
+
+namespace {
+
+// Gets the next shuffled index by repeatedly calling `index_mapper` until the
+// result is one of the following:
+// 1. a valid index (i.e. not a `NotFoundError`),
+// 2. an `OutOfRangeError`, or
+// 3. any error other than `NotFoundError` or `OutOfRangeError`.
+absl::StatusOr<size_t> GetNextShuffledIndex(const IndexMapperFn& index_mapper,
+                                            size_t& element_count) {
+  absl::StatusOr<size_t> shuffled_index = absl::NotFoundError("default");
+
+  while (absl::IsNotFound(shuffled_index.status())) {
+    shuffled_index = index_mapper(element_count++);
+    if (absl::IsOutOfRange(shuffled_index.status())) {
+      return shuffled_index.status();
+    }
+    if (!absl::IsNotFound(shuffled_index.status()) && !shuffled_index.ok()) {
+      return shuffled_index.status();
+    }
+  }
+  return shuffled_index;
+}
+}  // namespace
 
 class ConcatenateDatasetOp::Dataset : public DatasetBase {
  public:
@@ -58,6 +90,12 @@
                                                  &output_tensorshape));
       output_shapes_.push_back(output_tensorshape);
     }
+    if (input_ != nullptr && !input_->RandomIndexingCompatible().ok()) {
+      random_indexing_compatible_ = input->RandomIndexingCompatible();
+    } else if (to_concatenate_ != nullptr &&
+               !to_concatenate_->RandomIndexingCompatible().ok()) {
+      random_indexing_compatible_ = to_concatenate_->RandomIndexingCompatible();
+    }
   }
   ~Dataset() override {
     input_->Unref();
@@ -126,6 +164,10 @@
     return absl::OkStatus();
   }
 
+  absl::Status RandomIndexingCompatible() const override {
+    return random_indexing_compatible_;
+  }
+
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
@@ -149,11 +191,15 @@
     bool SymbolicCheckpointCompatible() const override { return true; }
 
     Status Initialize(IteratorContext* ctx) override {
+      mutex_lock l(mu_);
+      input_impls_.resize(2);
+
       TF_ASSIGN_OR_RETURN(input_contexts_,
                           CreateInputIteratorContexts(ctx, dataset()));
       TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
           &input_contexts_[0], this, strings::StrCat(prefix(), "[0]"),
-          &input_impl_));
+          &input_impls_[0]));
+
       ctx->MergeCheckpoint(input_contexts_[0].checkpoint());
       return absl::OkStatus();
     }
@@ -162,25 +208,115 @@
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);
-      if (!input_impl_) {
+      if (!input_impls_[0] && !input_impls_[1]) {
         *end_of_sequence = true;
         return absl::OkStatus();
       }
-      while (i_ < 2) {
-        TF_RETURN_IF_ERROR(input_impl_->GetNext(&input_contexts_[i_],
-                                                out_tensors, end_of_sequence));
+      // Global shuffling
+      if (ctx->index_mapper()) {
+        if (input_impls_[1] == nullptr) {
+          // Creates the second iterator immediately in the case of
+          // global random shuffling.
+          TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
+              &input_contexts_[1], this, strings::StrCat(prefix(), "[1]"),
+              &input_impls_[1]));
+          ctx->MergeCheckpoint(input_contexts_[1].checkpoint());
+        }
+
+        if (input_contexts_[0].index_mapper() == nullptr) {
+          IndexMapperFn left_index_mapper =
+              [index_mapper = ctx->index_mapper(),
+               left_cardinality = dataset()->input_cardinality_,
+               right_cardinality = dataset()->to_concatenate_cardinality_](
+                  size_t to_idx) -> absl::StatusOr<size_t> {
+            TF_ASSIGN_OR_RETURN(size_t from_idx, index_mapper(to_idx));
+
+            if (from_idx >= left_cardinality + right_cardinality) {
+              return absl::OutOfRangeError("Running out of elements.");
+            }
+            if (from_idx >= left_cardinality) {
+              // This has to return a status so that the upstream global
+              // shuffle iterator does not treat it as an end of sequence.
+              return absl::NotFoundError("Skipping this element.");
+            }
+            return from_idx;
+          };
+
+          IndexMapperFn right_index_mapper =
+              [index_mapper = ctx->index_mapper(),
+               left_cardinality = dataset()->input_cardinality_,
+               right_cardinality = dataset()->to_concatenate_cardinality_](
+                  size_t to_idx) -> absl::StatusOr<size_t> {
+            TF_ASSIGN_OR_RETURN(size_t from_idx, index_mapper(to_idx));
+
+            if (from_idx >= left_cardinality + right_cardinality) {
+              return absl::OutOfRangeError("Running out of elements.");
+            }
+            if (from_idx < left_cardinality) {
+              // This has to return a status so that the upstream global
+              // shuffle iterator does not treat it as an end of sequence.
+              return absl::NotFoundError("Skipping this element.");
+            }
+            return from_idx - left_cardinality;
+          };
+
+          input_contexts_[0].SetIndexMapper(left_index_mapper);
+          input_contexts_[1].SetIndexMapper(right_index_mapper);
+        }
+
+        // Materializes the shuffled index because we need this information
+        // to determine which iterator we need to call later.
+
+        absl::StatusOr<size_t> shuffled_index =
+            GetNextShuffledIndex(ctx->index_mapper(), element_count_);
+
+        if (absl::IsOutOfRange(shuffled_index.status())) {
+          *end_of_sequence = true;
+          return absl::OkStatus();
+        }
+
+        TF_RETURN_IF_ERROR(shuffled_index.status());
+
+        // Routes the shuffled index to the correct input iterator.
+        bool temp_end_of_sequence = false;
+        absl::Status status = absl::OkStatus();
+        if (shuffled_index.value() < dataset()->input_cardinality_) {
+          status = input_impls_[0]->GetNext(&input_contexts_[0], out_tensors,
+                                            &temp_end_of_sequence);
+          ctx->MergeCheckpoint(input_contexts_[0].checkpoint());
+        } else {
+          status = input_impls_[1]->GetNext(&input_contexts_[1], out_tensors,
+                                            &temp_end_of_sequence);
+          ctx->MergeCheckpoint(input_contexts_[1].checkpoint());
+        }
+        TF_RETURN_IF_ERROR(status);
+
+        *end_of_sequence = temp_end_of_sequence;
+        return absl::OkStatus();
+      }
+
+      for (; i_ < 2; ++i_) {
+        TF_RETURN_IF_ERROR(input_impls_[i_]->GetNext(
+            &input_contexts_[i_], out_tensors, end_of_sequence));
         ctx->MergeCheckpoint(input_contexts_[i_].checkpoint());
         if (!*end_of_sequence) {
           return absl::OkStatus();
         }
-        if (++i_ < 2) {
+        if (i_ == 0) {
+          // Creates the second iterator only when the first iterator
+          // is exhausted, to save memory.
           TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
-              &input_contexts_[i_], this, strings::StrCat(prefix(), "[1]"),
-              &input_impl_));
+              &input_contexts_[1], this, strings::StrCat(prefix(), "[1]"),
+              &input_impls_[1]));
+          ctx->MergeCheckpoint(input_contexts_[1].checkpoint());
         }
       }
       *end_of_sequence = true;
-      input_impl_.reset();
+      input_impls_[0].reset();
+      input_impls_[1].reset();
       return absl::OkStatus();
     }
 
@@ -196,10 +332,18 @@
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kIndex, i_));
       TF_RETURN_IF_ERROR(
-          writer->WriteScalar(prefix(), kInputImplUninitialized,
-                              static_cast<int64_t>(!input_impl_)));
-      if (input_impl_) {
-        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
+          writer->WriteScalar(prefix(), kElementCount, element_count_));
+      TF_RETURN_IF_ERROR(writer->WriteScalar(
+          prefix(), absl::StrFormat("%s[%d]", kInputImplUninitialized, 0),
+          static_cast<int64_t>(!input_impls_[0])));
+      if (input_impls_[0]) {
+        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impls_[0]));
+      }
+      TF_RETURN_IF_ERROR(writer->WriteScalar(
+          prefix(), absl::StrFormat("%s[%d]", kInputImplUninitialized, 1),
+          static_cast<int64_t>(!input_impls_[1])));
+      if (input_impls_[1]) {
+        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impls_[1]));
       }
       return absl::OkStatus();
     }
@@ -207,33 +351,96 @@
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
-      TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kIndex, &i_));
-      int64_t input_uninitialized;
-      TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kInputImplUninitialized,
-                                            &input_uninitialized));
-      if (static_cast<bool>(input_uninitialized)) {
-        input_impl_.reset();
+
+      int64_t input_uninitialized[2];
+      TF_RETURN_IF_ERROR(reader->ReadScalar(
+          prefix(), absl::StrFormat("%s[%d]", kInputImplUninitialized, 0),
+          &input_uninitialized[0]));
+      if (static_cast<bool>(input_uninitialized[0])) {
+        input_impls_[0].reset();
+      }
+      TF_RETURN_IF_ERROR(reader->ReadScalar(
+          prefix(), absl::StrFormat("%s[%d]", kInputImplUninitialized, 1),
+          &input_uninitialized[1]));
+      if (static_cast<bool>(input_uninitialized[1])) {
+        input_impls_[1].reset();
+      }
+
+      if (ctx->restored_element_count()) {
+        if (input_impls_.size() != 2) {
+          return absl::FailedPreconditionError(
+              "`Initialize` should be called before restoring from the "
+              "checkpoint.");
+        }
+        {
+          int64_t tmp_element_count;
+          TF_RETURN_IF_ERROR(
+              reader->ReadScalar(prefix(), kElementCount, &tmp_element_count));
+          if (tmp_element_count < 0) {
+            return absl::FailedPreconditionError(absl::StrFormat(
+                "element_count should be >= 0. Got %d", tmp_element_count));
+          }
+          element_count_ = static_cast<size_t>(tmp_element_count);
+        }
+
+        if (!static_cast<bool>(input_uninitialized[0])) {
+          if (!input_impls_[0]) {
+            return absl::FailedPreconditionError(
+                "Something went wrong internally. The first iterator should "
+                "exist because of `Initialize`.");
+          }
+          input_contexts_[0].set_restored_element_count(
+              *ctx->restored_element_count());
+          TF_RETURN_IF_ERROR(
+              RestoreInput(&input_contexts_[0], reader, input_impls_[0]));
+          ctx->MergeCheckpoint(input_contexts_[0].checkpoint());
+        }
+
+        if (!static_cast<bool>(input_uninitialized[1])) {
+          TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
+              &input_contexts_[1], this, strings::StrCat(prefix(), "[1]"),
+              &input_impls_[1]));
+
+          input_contexts_[1].set_restored_element_count(
+              *ctx->restored_element_count());
+
+          TF_RETURN_IF_ERROR(
+              RestoreInput(&input_contexts_[1], reader, input_impls_[1]));
+          ctx->MergeCheckpoint(input_contexts_[1].checkpoint());
+        }
         return absl::OkStatus();
       }
+
+      TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kIndex, &i_));
+
       if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2))
         return errors::InvalidArgument("i_ must be in range [0, 2].");
-      if (i_ == 1) {
+
+      if (!static_cast<bool>(input_uninitialized[0])) {
+        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impls_[0]));
+      }
+      if (!static_cast<bool>(input_uninitialized[1])) {
         TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
-            ctx, this, strings::StrCat(prefix(), "[1]"), &input_impl_));
-      } else if (i_ == 2) {
-        input_impl_.reset();
+            &input_contexts_[1], this, strings::StrCat(prefix(), "[1]"),
+            &input_impls_[1]));
+        ctx->MergeCheckpoint(input_contexts_[1].checkpoint());
+
+        TF_RETURN_IF_ERROR(
+            RestoreInput(&input_contexts_[1], reader, input_impls_[1]));
+        ctx->MergeCheckpoint(input_contexts_[1].checkpoint());
       }
-      if (input_impl_) {
-        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
-      }
+
       return absl::OkStatus();
     }
 
    private:
     mutex mu_;
     int64_t i_ TF_GUARDED_BY(mu_);
-    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
-    std::vector<IteratorContext> input_contexts_;
+    std::vector<std::unique_ptr<IteratorBase>> input_impls_ TF_GUARDED_BY(mu_);
+    std::vector<IteratorContext> input_contexts_ TF_GUARDED_BY(mu_);
+    // Indicates `ctx->index_mapper()(element_count_)` is the next
+    // shuffled index.
+    size_t element_count_ TF_GUARDED_BY(mu_) = 0;
   };
 
   Status MostSpecificCompatibleShape(const PartialTensorShape& ts1,
@@ -257,6 +464,7 @@
   const int64_t input_cardinality_;
   const int64_t to_concatenate_cardinality_;
   std::vector<PartialTensorShape> output_shapes_;
+  absl::Status random_indexing_compatible_ = absl::OkStatus();
 };
 
 ConcatenateDatasetOp::ConcatenateDatasetOp(OpKernelConstruction* ctx)
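
To make the index-mapper plumbing above easier to follow: a globally shuffled position is routed to one of the two concatenated inputs by comparing it against the first input's cardinality, and positions past both inputs are reported as out of range. A self-contained sketch with a hypothetical helper (not the op's code):

#include <cstddef>
#include <optional>
#include <utility>

// Returns {input_index, local_index} for a shuffled global position, or
// std::nullopt when the position is past the end of both inputs (the op
// reports that case as an OutOfRangeError).
std::optional<std::pair<int, size_t>> RouteConcatIndex(
    size_t global_index, size_t left_cardinality, size_t right_cardinality) {
  if (global_index >= left_cardinality + right_cardinality) {
    return std::nullopt;
  }
  if (global_index < left_cardinality) {
    return std::make_pair(0, global_index);
  }
  return std::make_pair(1, global_index - left_cardinality);
}

// RouteConcatIndex(7, 5, 4) yields {1, 2}: the eighth shuffled element comes
// from the second input at local index 2.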
diff --git a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc
index f15bd39..936ad1f 100644
--- a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc
@@ -173,17 +173,8 @@
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
-      if (ctx->restored_element_count().has_value()) {
-        num_elements_ = *(ctx->restored_element_count());
-        // If the dataset has reached the end of sequence, the restored element
-        // count could be cardinality + 1.
-        if (num_elements_ > dataset()->Cardinality()) {
-          num_elements_ = dataset()->Cardinality();
-        }
-      } else {
-        TF_RETURN_IF_ERROR(
-            reader->ReadScalar(full_name("num_elements"), &num_elements_));
-      }
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(full_name("num_elements"), &num_elements_));
       return RestoreInput(ctx, reader, input_impl_);
     }
 
diff --git a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc
index 8bfc9ad..91b596c 100644
--- a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc
@@ -251,7 +251,7 @@
           const auto& t_flat = t.flat<T>();
           // TODO(mrry): Replace with a memcpy or something more
           // efficient. (Maybe an Eigen assign op?)
-          gtl::InlinedVector<int64_t, 4> strides(row_ndims);
+          absl::InlinedVector<int64_t, 4UL> strides(row_ndims);
           if (!strides.empty()) {
             strides[row_ndims - 1] = 1;
             for (int64_t row_dim = strides.size() - 2; row_dim >= 0;
diff --git a/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc b/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc
index c937f61..20d65e8 100644
--- a/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc
@@ -235,15 +235,8 @@
         TF_ASSIGN_OR_RETURN(element_position,
                             parent_index_mapper(element_position));
       }
-      // This could happen if the source dataset generates more elements than
-      // needed by the intermediate transformations. For example, when shuffling
-      // `range(10).batch(3, drop_remainder=True)`, the last element of `range`
-      // has index 9, which maps to the 4th batched element. However, since
-      // `batch` drops remainders, the cardinality is 3. In this case, the
-      // element position exceeds the max index. The caller is responsible to
-      // handle this case properly.
       if (element_position > max_index) {
-        return element_position;
+        return absl::OutOfRangeError("Out of range");
       }
       if (max_index == 0) {
         return 0;
@@ -265,6 +258,7 @@
     TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kSeed, seed_));
     TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kSeed2, seed2_));
     TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kSeed3, seed3_));
+    TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
     return absl::OkStatus();
   }
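
The change above replaces a silently out-of-range index with an explicit OutOfRangeError, so callers can treat OutOfRange as end of input instead of comparing the returned index against the cardinality themselves. A toy sketch of the resulting contract, with a hypothetical mapper (not the op's code):

#include <cstddef>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Probing a position past `max_index` now yields an error the caller can
// check with absl::IsOutOfRange; valid positions would be permuted by the
// real mapper, which this sketch leaves as the identity.
absl::StatusOr<size_t> ClampedShuffleIndex(size_t element_position,
                                           size_t max_index) {
  if (element_position > max_index) {
    return absl::OutOfRangeError("Out of range");
  }
  return element_position;
}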
 
diff --git a/tensorflow/core/kernels/data/experimental/list_dataset_op.cc b/tensorflow/core/kernels/data/experimental/list_dataset_op.cc
index c93ccce8..2852b44 100644
--- a/tensorflow/core/kernels/data/experimental/list_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/list_dataset_op.cc
@@ -186,14 +186,16 @@
 
     Status SaveInternal(SerializationContext* ctx,
                         IteratorStateWriter* writer) override {
-      return split_provider_->Save(
-          [this](const std::string& key) { return full_name(key); }, writer);
+      TF_RETURN_IF_ERROR(split_provider_->Save(
+          [this](const std::string& key) { return full_name(key); }, writer));
+      TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer));
+      return absl::OkStatus();
     }
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       if (ctx->restored_element_count().has_value()) {
-        return global_shuffle_iterator_.Restore(ctx);
+        return global_shuffle_iterator_.Restore(prefix(), ctx, reader);
       }
       return split_provider_->Restore(
           [this](const std::string& key) { return full_name(key); }, reader);
diff --git a/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc
index 86f4b00..44e25cd 100644
--- a/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc
@@ -96,7 +96,7 @@
       const std::vector<std::vector<Tensor>>& input_elements) {
     std::vector<PartialTensorShape> output_shapes;
     for (const auto& tensor : input_elements.front()) {
-      gtl::InlinedVector<int64_t, 4> partial_dim_sizes;
+      absl::InlinedVector<int64_t, 4UL> partial_dim_sizes;
       partial_dim_sizes.reserve(tensor.dims());
       for (int i = 0; i < tensor.dims(); ++i) {
         partial_dim_sizes.push_back(tensor.dim_size(i));
diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
index 7021a23..cd9bb2a 100644
--- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
@@ -164,7 +164,7 @@
                             DatasetGraphDefBuilder* b,
                             Node** output) const override {
     std::vector<std::pair<size_t, Node*>> inputs;
-    std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
+    std::vector<std::pair<size_t, absl::Span<Node* const>>> list_inputs;
     int input_index = 0;
 
     Node* input_node;
diff --git a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc
index 3cc3320..5682d19 100644
--- a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc
@@ -72,7 +72,7 @@
           if (batch_size_ < 0 && shape.dim_size(0) >= 0) {
             batch_size_ = shape.dim_size(0);
           }
-          gtl::InlinedVector<int64_t, 4> partial_dim_sizes;
+          absl::InlinedVector<int64_t, 4UL> partial_dim_sizes;
           for (int i = 1; i < shape.dims(); ++i) {
             partial_dim_sizes.push_back(shape.dim_size(i));
           }
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index bdd95f0..cd03a09 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -141,7 +141,12 @@
   }
 
   absl::Status RandomIndexingCompatible() const override {
-    return random_indexing_compatible_;
+    return absl::UnimplementedError(
+        "Please consider applying the map to each dataset, concatenating them "
+        "into one dataset, and applying the global shuffle dataset op to the "
+        "resulting dataset to achieve the same result as flat map with global "
+        "shuffling.");
   }
 
  protected:
@@ -358,7 +363,10 @@
       return absl::OkStatus();
     }
 
-    // TODO(b/325112575): Refactor and reuse this code from weighted flat map.
+    // TODO: b/355241367 - This implementation is incorrect because the
+    // IndexMapper should be stateless; otherwise it is not compatible with
+    // the batch dataset op.
+    // See go/tf-data-random-access-iterator-for-concatenate for more info.
     IndexMapperFn GetFlatMapIndexMapper(IndexMapperFn parent_index_mapper,
                                         size_t input_dataset_index)
         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
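
The TODO above rests on the distinction between a stateless and a stateful index mapper. A toy illustration with hypothetical code (not tf.data's): a stateless mapper depends only on the queried index, so a caller such as a batching transformation can probe indices out of order or more than once, while a stateful mapper also depends on call history and breaks under such callers.

#include <cstddef>
#include <functional>
#include <memory>

using IndexFn = std::function<size_t(size_t)>;

// Stateless: the result is a pure function of the queried index.
IndexFn MakeStatelessMapper(size_t offset) {
  return [offset](size_t i) { return i + offset; };
}

// Stateful: the result also depends on how many times the mapper has been
// called, so querying the same index twice gives different answers.
IndexFn MakeStatefulMapper() {
  auto calls = std::make_shared<size_t>(0);
  return [calls](size_t i) { return i + (*calls)++; };
}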
diff --git a/tensorflow/core/kernels/data/map_defun_op_test.cc b/tensorflow/core/kernels/data/map_defun_op_test.cc
index d48650c..aaf292e 100644
--- a/tensorflow/core/kernels/data/map_defun_op_test.cc
+++ b/tensorflow/core/kernels/data/map_defun_op_test.cc
@@ -104,9 +104,10 @@
   }
 
   // Creates a new `MapDefun` op kernel context.
-  Status CreateMapDefunContext(OpKernel* const op_kernel,
-                               gtl::InlinedVector<TensorValue, 4>* const inputs,
-                               std::unique_ptr<OpKernelContext>* context) {
+  Status CreateMapDefunContext(
+      OpKernel* const op_kernel,
+      absl::InlinedVector<TensorValue, 4UL>* const inputs,
+      std::unique_ptr<OpKernelContext>* context) {
     TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
     TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
     return absl::OkStatus();
@@ -243,7 +244,7 @@
   TestCase test_case = GetParam();
   TF_ASSERT_OK(InitializeRuntime(test_case.map_defun_op_params));
   auto input_tensors = test_case.map_defun_op_params.GetInputTensors();
-  gtl::InlinedVector<TensorValue, 4> input_values;
+  absl::InlinedVector<TensorValue, 4UL> input_values;
   for (auto& input : input_tensors) {
     input_values.push_back(TensorValue(&input));
   }
@@ -272,7 +273,7 @@
   for (auto& test_case : test_cases) {
     TF_ASSERT_OK(InitializeRuntime(test_case.map_defun_op_params));
     auto input_tensors = test_case.map_defun_op_params.GetInputTensors();
-    gtl::InlinedVector<TensorValue, 4> input_values;
+    absl::InlinedVector<TensorValue, 4UL> input_values;
     for (auto& input : input_tensors) {
       input_values.push_back(TensorValue(&input));
     }
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index fcb0c11..72b230c 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -17,6 +17,7 @@
 #include <cstddef>
 #include <deque>
 #include <functional>
+#include <limits>
 #include <memory>
 #include <optional>
 #include <string>
@@ -31,6 +32,7 @@
 #include "tensorflow/core/data/dataset_utils.h"
 #include "tensorflow/core/data/name_utils.h"
 #include "tensorflow/core/data/stats_utils.h"
+#include "tensorflow/core/data/unbounded_thread_pool.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/framework/model.h"
@@ -92,17 +94,19 @@
           const std::vector<PartialTensorShape>& output_shapes,
           DeterminismPolicy deterministic,
           std::unique_ptr<CapturedFunction> captured_func,
-          bool preserve_cardinality, int op_version)
+          bool preserve_cardinality, bool use_unbounded_threadpool,
+          int op_version)
       : Dataset(DatasetContext(ctx), input, num_parallel_calls, output_types,
                 output_shapes, deterministic, std::move(captured_func),
-                preserve_cardinality, op_version) {}
+                preserve_cardinality, use_unbounded_threadpool, op_version) {}
 
   Dataset(DatasetContext dataset_context, const DatasetBase* input,
           int64_t num_parallel_calls, const DataTypeVector& output_types,
           const std::vector<PartialTensorShape>& output_shapes,
           DeterminismPolicy deterministic,
           std::unique_ptr<CapturedFunction> captured_func,
-          bool preserve_cardinality, int op_version)
+          bool preserve_cardinality, bool use_unbounded_threadpool,
+          int op_version)
       : DatasetBase(std::move(dataset_context)),
         input_(input),
         num_parallel_calls_(num_parallel_calls),
@@ -110,6 +114,7 @@
         output_shapes_(output_shapes),
         deterministic_(deterministic),
         preserve_cardinality_(preserve_cardinality),
+        use_unbounded_threadpool_(use_unbounded_threadpool),
         captured_func_(std::move(captured_func)),
         op_version_(op_version) {
     input_->Ref();
@@ -235,6 +240,12 @@
     b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr);
     attrs.emplace_back(kPreserveCardinality, preserve_cardinality_attr);
 
+    // Attr: use_unbounded_threadpool
+    AttrValue use_unbounded_threadpool_attr;
+    b->BuildAttrValue(use_unbounded_threadpool_,
+                      &use_unbounded_threadpool_attr);
+    attrs.emplace_back(kUseUnboundedThreadpool, use_unbounded_threadpool_attr);
+
     TF_RETURN_IF_ERROR(b->AddDataset(
         this,
         {std::make_pair(0, input_graph_node),
@@ -256,6 +267,7 @@
           deterministic_(params.dataset->deterministic_.IsDeterministic() ||
                          params.dataset->deterministic_.IsDefault()),
           preserve_cardinality_(params.dataset->preserve_cardinality_),
+          use_unbounded_threadpool_(params.dataset->use_unbounded_threadpool_),
           autotune_(params.dataset->num_parallel_calls_ == model::kAutotune) {}
 
     ~Iterator() override {
@@ -271,7 +283,10 @@
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(*mu_);
       interleave_depth_ = ctx->interleave_depth();
-
+      if (use_unbounded_threadpool_) {
+        unbounded_thread_pool_ = std::make_unique<UnboundedThreadPool>(
+            ctx->env(), "tf_data_map_unbounded_thread_pool");
+      }
       if (num_parallel_calls_->value == model::kAutotune) {
         num_parallel_calls_->value = GetAutotuneDefaultParallelism(ctx);
       }
@@ -323,11 +338,17 @@
     std::shared_ptr<model::Node> CreateNode(
         IteratorContext* ctx, model::Node::Args args) const override {
       std::shared_ptr<model::Parameter> parameter;
+      // If the unbounded threadpool is used, sets the max of
+      // `num_parallel_calls` to infinity and lets Autotune find the right
+      // value under the RAM budget.
+      double max_parallelism_value = use_unbounded_threadpool_
+                                         ? std::numeric_limits<double>::max()
+                                         : ctx->runner_threadpool_size();
       if (num_parallel_calls_ &&
           dataset()->num_parallel_calls_ == model::kAutotune) {
         parameter = model::MakeParameter(
             "parallelism", num_parallel_calls_, /*min=*/1,
-            /*max=*/ctx->runner_threadpool_size(),
+            /*max=*/max_parallelism_value,
             // This is to ensure before this op has seen its first element,
             // `MaximumBufferedBytes()` can use the correct `parameter->value`
             // to estimate the maximum buffer bytes.
@@ -335,7 +356,7 @@
       } else {
         parameter =
             model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
-                                 /*max=*/ctx->runner_threadpool_size());
+                                 /*max=*/max_parallelism_value);
       }
       std::optional<int64_t> estimated_element_size =
           dataset()->GetEstimatedElementSize();
@@ -394,10 +415,6 @@
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
-      if (ctx->restored_element_count().has_value()) {
-        return RestoreInput(ctx, reader, input_impl_);
-      }
-
       mutex_lock l(*mu_);
       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
       DCHECK(invocation_results_.empty());
@@ -456,6 +473,9 @@
           std::make_pair("autotune", autotune_ ? "true" : "false"));
       result.push_back(
           std::make_pair("deterministic", deterministic_ ? "true" : "false"));
+      result.push_back(
+          std::make_pair("use_unbounded_threadpool",
+                         use_unbounded_threadpool_ ? "true" : "false"));
       result.push_back(std::make_pair(
           "parallelism",
           parallelism == -1
@@ -543,7 +563,15 @@
 
       // Apply the map function on `input_element`, storing the result in
       // `result->return_values`, and invoking `done` when finished.
-      if (dataset()->captured_func_->use_inter_op_parallelism()) {
+      if (use_unbounded_threadpool_) {
+        auto runner_fn = [this](std::function<void()> fn) {
+          this->unbounded_thread_pool_->Schedule(fn);
+        };
+        instantiated_captured_func_->RunAsync(
+            runner_fn, ctx->cancellation_manager(), ctx->collective_executor(),
+            std::move(input_element), &result->return_values, done,
+            model_node());
+      } else if (dataset()->captured_func_->use_inter_op_parallelism()) {
         instantiated_captured_func_->RunAsync(
             ctx.get(), std::move(input_element), &result->return_values,
             std::move(done), model_node());
@@ -751,6 +779,7 @@
     const std::shared_ptr<model::SharedState> num_parallel_calls_;
     const bool deterministic_;
     const bool preserve_cardinality_;
+    const bool use_unbounded_threadpool_;
     const bool autotune_;
     // Counts the number of outstanding calls.
     int64_t num_calls_ TF_GUARDED_BY(*mu_) = 0;
@@ -767,6 +796,7 @@
     bool cancelled_ TF_GUARDED_BY(*mu_) = false;
     std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
     std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);
+    std::unique_ptr<UnboundedThreadPool> unbounded_thread_pool_;
 
     // Method for deregistering the cancellation callback.
     std::function<void()> deregister_fn_;
@@ -784,6 +814,7 @@
   const std::vector<PartialTensorShape> output_shapes_;
   const DeterminismPolicy deterministic_;
   const bool preserve_cardinality_;
+  const bool use_unbounded_threadpool_;
   const std::unique_ptr<CapturedFunction> captured_func_;
   const int op_version_;
   // This is used for random access provided by Get().
@@ -812,12 +843,15 @@
     } else {
       deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
     }
+    use_unbounded_threadpool_ = false;
   }
   if (op_version_ == 2) {
     std::string deterministic;
     OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
     OP_REQUIRES_OK(
         ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
+    OP_REQUIRES_OK(
+        ctx, ctx->GetAttr(kUseUnboundedThreadpool, &use_unbounded_threadpool_));
   }
   OP_REQUIRES_OK(ctx,
                  ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_));
@@ -849,10 +883,10 @@
     metrics::RecordTFDataAutotune(kDatasetType);
   }
 
-  *output =
-      new Dataset(ctx, input, num_parallel_calls, output_types_, output_shapes_,
-                  deterministic_, std::move(captured_func),
-                  preserve_cardinality_, op_version_);
+  *output = new Dataset(ctx, input, num_parallel_calls, output_types_,
+                        output_shapes_, deterministic_,
+                        std::move(captured_func), preserve_cardinality_,
+                        use_unbounded_threadpool_, op_version_);
 }
 
 std::unique_ptr<DatasetBase> MakeDataServiceUncompressDataset(
@@ -867,7 +901,8 @@
       /*num_parallel_calls=*/model::kAutotune, output_types, output_shapes,
       DeterminismPolicy(DeterminismPolicy::Type::kDefault),
       std::move(captured_function),
-      /*preserve_cardinality=*/true, /*op_version=*/2);
+      /*preserve_cardinality=*/true,
+      /*use_unbounded_threadpool=*/false, /*op_version=*/2);
 }
 
 namespace {
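
A note on the runner indirection introduced above: the asynchronous call only needs a callable that accepts work items, so any thread pool can be plugged in by wrapping its Schedule method in a lambda of that shape. A minimal sketch of the pattern (hypothetical types, not the op's code):

#include <functional>
#include <iostream>

using Runner = std::function<void(std::function<void()>)>;

// Hands `work` to whatever execution context the runner encapsulates; the op
// wraps UnboundedThreadPool::Schedule in a lambda of this shape.
void RunAsyncWork(const Runner& runner, std::function<void()> work) {
  runner(std::move(work));
}

int main() {
  // A trivial runner that executes work inline; a real one would schedule the
  // closure on a thread pool instead.
  Runner inline_runner = [](std::function<void()> fn) { fn(); };
  RunAsyncWork(inline_runner, [] { std::cout << "mapped element\n"; });
  return 0;
}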
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.h b/tensorflow/core/kernels/data/parallel_map_dataset_op.h
index 4e1e564..efdf633 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.h
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.h
@@ -38,6 +38,8 @@
   static constexpr const char* const kSloppy = "sloppy";
   static constexpr const char* const kPreserveCardinality =
       "preserve_cardinality";
+  static constexpr const char* const kUseUnboundedThreadpool =
+      "use_unbounded_threadpool";
 
   explicit ParallelMapDatasetOp(OpKernelConstruction* ctx);
 
@@ -54,6 +56,7 @@
   bool sloppy_;
   bool preserve_cardinality_;
   DeterminismPolicy deterministic_;
+  bool use_unbounded_threadpool_;
 
   friend std::unique_ptr<DatasetBase> MakeDataServiceUncompressDataset(
       DatasetBase* input, std::unique_ptr<CapturedFunction> captured_function,
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc
index 357e279..cedc0e8 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc
@@ -12,10 +12,10 @@
 #include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
 
 #include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/dataset_test_base.h"
 #include "tensorflow/core/data/name_utils.h"
 #include "tensorflow/core/framework/tensor_shape.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace data {
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 0a7779c..c6238a0 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -297,11 +297,6 @@
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
-      if (ctx->restored_element_count().has_value()) {
-        tsl::mutex_lock l(input_mu_);
-        return RestoreInput(ctx, reader, input_impl_);
-      }
-
       mutex_lock input_l(input_mu_);
       mutex_lock l(*mu_);
       DCHECK(!prefetch_thread_);
diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc
index bf71d67..5834494 100644
--- a/tensorflow/core/kernels/data/range_dataset_op.cc
+++ b/tensorflow/core/kernels/data/range_dataset_op.cc
@@ -332,13 +332,14 @@
         TF_RETURN_IF_ERROR(
             writer->WriteScalar(prefix(), kNext, counter_->Peek()));
       }
+      TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer));
       return absl::OkStatus();
     }
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       if (ctx->restored_element_count().has_value()) {
-        return global_shuffle_iterator_.Restore(ctx);
+        return global_shuffle_iterator_.Restore(prefix(), ctx, reader);
       }
       if (reader->Contains(prefix(), kHasSplitProvider)) {
         TF_RETURN_IF_ERROR(split_provider_->Restore(
diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc
index 820fc9b..555bdba 100644
--- a/tensorflow/core/kernels/data/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc
@@ -319,7 +319,7 @@
         if (element_position >= input_cardinality) {
          // The input element position is out-of-range. The caller is
          // responsible for handling this case (e.g., returning
          // end_of_sequence).
-          return element_position;
+          return absl::OutOfRangeError("Finite repeat is out of range");
         }
 
         // First, maps the input indices from
@@ -356,28 +356,37 @@
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
+      int64_t input_empty;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(prefix(), kInputImplEmpty, &input_empty));
+      TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kCurIteration, &i_));
+
       if (ctx->restored_element_count().has_value()) {
         CardinalityOptions options;
         options.set_compute_level(
             CardinalityOptions::CARDINALITY_COMPUTE_MODERATE);
         const int64_t input_cardinality =
             dataset()->input_->Cardinality(std::move(options));
-        i_ = *ctx->restored_element_count() / input_cardinality;
         // For upstream iterators, the restored element count should be the
         // element count within the current repetition.
         IteratorContext::Params params(ctx);
         params.restored_element_count =
-            *ctx->restored_element_count() % input_cardinality;
+            *ctx->restored_element_count() % input_cardinality;
         params.index_mapper = GetIndexMapper(ctx->index_mapper());
         IteratorContext ctx_with_restored_element_count(params);
-        return RestoreInput(&ctx_with_restored_element_count, reader,
-                            input_impl_);
+        if (!input_empty) {
+          // Needs to re-`MakeIterator` because `i_` might have changed.
+          TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
+              ctx, this, nested_prefix(prefix(), i_), &input_impl_));
+          TF_RETURN_IF_ERROR(RestoreInput(&ctx_with_restored_element_count,
+                                          reader, input_impl_));
+          ctx->MergeCheckpoint(ctx_with_restored_element_count.checkpoint());
+        } else {
+          input_impl_.reset();
+        }
+        return absl::OkStatus();
       }
 
-      TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kCurIteration, &i_));
-      int64_t input_empty;
-      TF_RETURN_IF_ERROR(
-          reader->ReadScalar(prefix(), kInputImplEmpty, &input_empty));
       if (static_cast<bool>(!input_empty)) {
         TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
             ctx, this, nested_prefix(prefix(), i_), &input_impl_));
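
For the finite-repeat restore path above, the key arithmetic is splitting a global restored element count into a repetition index and a position within the current repetition; the op now reads the repetition index from the checkpoint and applies only the modulo. A plain sketch with a hypothetical helper (not the op's code):

#include <cstdint>
#include <utility>

// Given the total number of elements produced so far and the cardinality of
// the repeated input, returns {repetition_index, index_within_repetition}.
std::pair<int64_t, int64_t> SplitRepeatCount(int64_t restored_element_count,
                                             int64_t input_cardinality) {
  return {restored_element_count / input_cardinality,
          restored_element_count % input_cardinality};
}

// SplitRepeatCount(7, 3) yields {2, 1}: the iterator is in its third
// repetition and has already produced one element of it.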
diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
index 695263e..950be9b 100644
--- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
@@ -280,7 +280,7 @@
                                 "is not currently supported."));
       previous_batch_index = next_batch_index;
     }
-    gtl::InlinedVector<int64_t, 8> std_order(dense_shape->NumElements(), 0);
+    absl::InlinedVector<int64_t, 8UL> std_order(dense_shape->NumElements(), 0);
     TensorShape shape;
     OP_REQUIRES_OK(ctx, TensorShape::BuildTensorShape(
                             dense_shape->vec<int64_t>(), &shape));
diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc
index a08890e..d910271 100644
--- a/tensorflow/core/kernels/data/take_dataset_op.cc
+++ b/tensorflow/core/kernels/data/take_dataset_op.cc
@@ -187,12 +187,6 @@
 
   Status RestoreInternal(IteratorContext* ctx,
                          IteratorStateReader* reader) override {
-    if (ctx->restored_element_count().has_value()) {
-      mutex_lock l(mu_);
-      i_ = *ctx->restored_element_count();
-      return RestoreInput(ctx, reader, input_impl_);
-    }
-
     mutex_lock l(mu_);
     TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kCurIndex, &i_));
     int64_t input_empty;
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index 02736e5..fe2d564 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -184,13 +184,14 @@
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kProduced,
                                              static_cast<int64_t>(produced_)));
+      TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer));
       return absl::OkStatus();
     }
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       if (ctx->restored_element_count().has_value()) {
-        return global_shuffle_iterator_.Restore(ctx);
+        return global_shuffle_iterator_.Restore(prefix(), ctx, reader);
       }
 
       mutex_lock l(mu_);
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
index dad1e8c..3e2374b 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
@@ -55,7 +55,7 @@
         replicate_on_split_(replicate_on_split) {
     for (const Tensor& t : tensors_) {
       dtypes_.push_back(t.dtype());
-      gtl::InlinedVector<int64_t, 4> element_dim_sizes;
+      absl::InlinedVector<int64_t, 4UL> element_dim_sizes;
       // Handle scalar here. Check that everyone matches here? Or fail
       // at runtime?
       for (int i = 1; i < t.dims(); ++i) {
@@ -206,14 +206,16 @@
 
     Status SaveInternal(SerializationContext* ctx,
                         IteratorStateWriter* writer) override {
-      return split_provider_->Save(
-          [this](const std::string& key) { return full_name(key); }, writer);
+      TF_RETURN_IF_ERROR(split_provider_->Save(
+          [this](const std::string& key) { return full_name(key); }, writer));
+      TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer));
+      return absl::OkStatus();
     }
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       if (ctx->restored_element_count().has_value()) {
-        return global_shuffle_iterator_.Restore(ctx);
+        return global_shuffle_iterator_.Restore(prefix(), ctx, reader);
       }
       return split_provider_->Restore(
           [this](const std::string& key) { return full_name(key); }, reader);
diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc b/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc
index 3c5ed02..99a29eb 100644
--- a/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc
@@ -20,6 +20,7 @@
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/dataset_test_base.h"
 #include "tensorflow/core/data/name_utils.h"
 #include "tensorflow/core/framework/dataset.h"
@@ -34,7 +35,6 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/tstring.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
 
diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc
index a5c0e49..c9d2a5a 100644
--- a/tensorflow/core/kernels/data/window_dataset.cc
+++ b/tensorflow/core/kernels/data/window_dataset.cc
@@ -177,6 +177,7 @@
     std::vector<std::vector<Tensor>> elements;
     for (size_t i = 0; i < num_elements; ++i) {
       std::vector<Tensor> element;
+      element.reserve(element_size);
       for (size_t j = 0; j < element_size; ++j) {
         element.push_back(std::move(inputs[i * element_size + j]));
       }
diff --git a/tensorflow/core/kernels/data/window_dataset_op_test.cc b/tensorflow/core/kernels/data/window_dataset_op_test.cc
index 3252f52..6dfc803 100644
--- a/tensorflow/core/kernels/data/window_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op_test.cc
@@ -16,6 +16,7 @@
 
 #include <gtest/gtest.h>
 #include "absl/status/status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/dataset_test_base.h"
 #include "tensorflow/core/data/name_utils.h"
 #include "tensorflow/core/data/serialization_utils.h"
@@ -27,7 +28,6 @@
 #include "tensorflow/core/framework/variant_tensor_data.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace data {
diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc
index e23c181..f2891dd 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op.cc
@@ -272,6 +272,9 @@
       mutex_lock l(mu_);
       // Note: When restoring, `SaveInternal` would not be called
       // if there is a global_shuffle_dataset_op.cc above this op.
+      int64_t inputs_empty;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(prefix(), kInputImplsEmpty, &inputs_empty));
       if (ctx->restored_element_count()) {
         if (input_impls_.size() != dataset()->inputs_.size()) {
           return absl::FailedPreconditionError(
@@ -283,14 +286,19 @@
               "ctx->index_mapper() should be provided along with "
               "ctx->restored_element_count() when restoring.");
         }
-        for (const auto& input_impl : input_impls_) {
-          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl));
+        if (static_cast<bool>(inputs_empty)) {
+          input_impls_.clear();
+        } else {
+          for (int i = 0; i < input_impls_.size(); ++i) {
+            input_contexts_[i].set_restored_element_count(
+                ctx->restored_element_count().value());
+            TF_RETURN_IF_ERROR(
+                RestoreInput(&input_contexts_[i], reader, input_impls_[i]));
+            ctx->MergeCheckpoint(input_contexts_[i].checkpoint());
+          }
         }
         return absl::OkStatus();
       }
-      int64_t inputs_empty;
-      TF_RETURN_IF_ERROR(
-          reader->ReadScalar(prefix(), kInputImplsEmpty, &inputs_empty));
       if (static_cast<bool>(inputs_empty)) {
         input_impls_.clear();
       } else {
diff --git a/tensorflow/core/kernels/data/zip_dataset_op_test.cc b/tensorflow/core/kernels/data/zip_dataset_op_test.cc
index c864287..9d8eed6 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op_test.cc
@@ -17,6 +17,7 @@
 #include <gtest/gtest.h>
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/data/dataset_test_base.h"
 #include "tensorflow/core/data/name_utils.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -25,7 +26,6 @@
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace data {
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index d7c0c76..15ff88c 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -190,7 +190,7 @@
         LOG(ERROR) << "Debug node of watch key "
                    << debug_watch_key_->debug_node_name
                    << " failed to publish debug tensor data to all URLs "
-                   << str_util::Join(debug_urls_, ", ")
+                   << absl::StrJoin(debug_urls_, ", ")
                    << ", due to: " << status.message();
       }
       return status;
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 7c865b6..0871890 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -255,7 +255,7 @@
       args.push_back(ctx->input(i));
     }
     std::vector<Tensor>* rets = new std::vector<Tensor>;
-    profiler::TraceMe trace_me("SymbolicGradientOp");
+    tsl::profiler::TraceMe trace_me("SymbolicGradientOp");
     lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) {
       if (!status.ok()) {
         ctx->SetStatus(status);
@@ -319,12 +319,12 @@
       handle = cached_entry->second;
     } else {
       VLOG(1) << "Instantiating " << func_name << " on " << target_device;
-      profiler::TraceMe activity(
+      tsl::profiler::TraceMe activity(
           [&] {
             return strings::StrCat("RemoteCall: Instantiate: ", func_name,
                                    " on ", target_device);
           },
-          profiler::TraceMeLevel::kInfo);
+          tsl::profiler::TraceMeLevel::kInfo);
       FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
       const auto* config = (ctx->function_library())
                                ? ctx->function_library()->config_proto()
@@ -398,24 +398,24 @@
   auto* rets = new std::vector<Tensor>;
   VLOG(1) << "Running " << func_name << " on " << target_device
           << " with handle: " << handle;
-  profiler::TraceMe trace_me(
+  tsl::profiler::TraceMe trace_me(
       [&] {
         return profiler::TraceMeEncode(
             "RemoteCallOp",
             {{"func_name", func_name}, {"device", target_device}});
       },
-      profiler::TraceMeLevel::kInfo);
+      tsl::profiler::TraceMeLevel::kInfo);
   lib->Run(
       opts, handle, args, rets,
       [rets, done = std::move(done), func_name, ctx, cancel_mgr,
        target_device = std::move(function_target.first)](const Status& status) {
-        profiler::TraceMe activity(
+        tsl::profiler::TraceMe activity(
             [&] {
               return profiler::TraceMeEncode(
                   "RemoteCallOpDone",
                   {{"func_name", func_name}, {"device", target_device}});
             },
-            profiler::TraceMeLevel::kInfo);
+            tsl::profiler::TraceMeLevel::kInfo);
         if (!status.ok()) {
           ctx->SetStatus(status);
         } else {
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index 79c393f..7cf465b 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -199,7 +199,7 @@
     void Start() {
       FHandle handle = cond_ ? then_handle_ : else_handle_;
       rets_.clear();
-      profiler::TraceMe trace_me("IfOp");
+      tsl::profiler::TraceMe trace_me("IfOp");
       lib_->Run(
           // Evaluate one of the branch.
           opts_, handle, args_, &rets_,
@@ -378,7 +378,7 @@
         branch = branch_handles_.size() - 1;
       }
       rets_.clear();
-      profiler::TraceMe trace_me("CaseOp");
+      tsl::profiler::TraceMe trace_me("CaseOp");
       lib_->Run(
           // Evaluate one of the branch.
           opts_, branch_handles_[branch], args_, &rets_,
@@ -633,7 +633,7 @@
     std::unique_ptr<BodyFuncCallFrame> body_frame_;
 
     void EvalCond() {
-      profiler::TraceMe trace_me("WhileOp-EvalCond");
+      tsl::profiler::TraceMe trace_me("WhileOp-EvalCond");
       lib_->Run(
           // Evaluate the condition.
           opts_, cond_handle_, args_, &rets_,
@@ -669,7 +669,7 @@
       }
       rets_.clear();
       rets_.resize(args_.size());
-      profiler::TraceMe trace_me("WhileOp-StartBody");
+      tsl::profiler::TraceMe trace_me("WhileOp-StartBody");
       lib_->Run(
           // Evaluate the body.
           opts_, body_handle_, body_frame_.get(),
@@ -724,7 +724,7 @@
     do {
       // Evaluate the cond function on the current loop variables.
       {
-        profiler::TraceMe trace_me("WhileOp-EvalCond");
+        tsl::profiler::TraceMe trace_me("WhileOp-EvalCond");
         TF_RETURN_IF_ERROR(lib->RunSync(opts, cond_handle, args, &cond_rets));
       }
       if (cond_rets.size() != 1) {
@@ -745,7 +745,7 @@
       // Evaluate the body function on the current loop variables, to get an
       // updated vector of loop variables.
       {
-        profiler::TraceMe trace_me("WhileOp-StartBody");
+        tsl::profiler::TraceMe trace_me("WhileOp-StartBody");
         body_rets.resize(num_loop_vars);
         BodyFuncCallFrame call_frame(&args, &body_rets, loop_var_types);
         TF_RETURN_IF_ERROR(lib->RunSync(opts, body_handle, &call_frame));
@@ -982,7 +982,7 @@
         args_[1 + i] = std::move(rets_[i]);
       }
       rets_.clear();
-      profiler::TraceMe trace_me("ForOp");
+      tsl::profiler::TraceMe trace_me("ForOp");
       lib_->Run(opts_, body_handle_, args_, &rets_, [this](const Status& s) {
         if (s.ok()) {
           *iter_ += delta_;
diff --git a/tensorflow/core/kernels/gather_nd_op_test.cc b/tensorflow/core/kernels/gather_nd_op_test.cc
index 3212068..2758fbb 100644
--- a/tensorflow/core/kernels/gather_nd_op_test.cc
+++ b/tensorflow/core/kernels/gather_nd_op_test.cc
@@ -18,6 +18,7 @@
 #include <vector>
 
 #include "absl/strings/match.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/fake_input.h"
@@ -37,7 +38,6 @@
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/image/resize_bicubic_op_test.cc b/tensorflow/core/kernels/image/resize_bicubic_op_test.cc
index 77f1725..209dbbd 100644
--- a/tensorflow/core/kernels/image/resize_bicubic_op_test.cc
+++ b/tensorflow/core/kernels/image/resize_bicubic_op_test.cc
@@ -92,7 +92,13 @@
                                    std::array<float, 4>* weights,
                                    std::array<int64_t, 4>* indices) {
     const int64_t in_loc = scale * out_loc;
-    const float delta = scale * out_loc - in_loc;
+    // Keep the following calculation in a float so that it is rounded the
+    // same way as in the optimised case. Merging it into the next line would
+    // let the intermediate value stay at higher precision, and that leads to
+    // a divergence in the result, so keep the following two lines separate to
+    // ensure the calculation is rounded as expected.
+    const float in_loc_float = scale * out_loc;
+    const float delta = in_loc_float - in_loc;
     const int64_t offset = lrintf(delta * kTableSize);
     const float* coeffs_tab = GetCoeffsTable();
     *weights = {{coeffs_tab[offset * 2 + 1], coeffs_tab[offset * 2],
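The split above matters because the fused expression scale * out_loc - in_loc may be evaluated at a higher intermediate precision (for example via FP contraction) before it is rounded to float. A minimal standalone sketch of the behaviour the comment describes, with made-up values:

    const float scale = 0.3333333f;              // hypothetical scale factor
    const int64_t out_loc = 7;
    const int64_t in_loc = scale * out_loc;      // product truncated to an integer
    const float in_loc_float = scale * out_loc;  // product rounded to float here
    const float delta = in_loc_float - in_loc;   // subtraction sees the rounded value

Routing the product through a named float forces one rounding to 32 bits before the subtraction, which matches the arithmetic of the optimised kernel.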
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 7619201..d07b4b9 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -257,7 +257,7 @@
 
   std::vector<Tensor>* rets = new std::vector<Tensor>;
   const string& func_name = func_->name();
-  profiler::TraceMe trace_me("PartitionedCallOp");
+  tsl::profiler::TraceMe trace_me("PartitionedCallOp");
   lib->Run(run_opts, handle, inputs, rets,
            [rets, done = std::move(done), ctx, func_name,
             step_container](const Status& status) {
diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc
index 2e88088..02fa44f 100644
--- a/tensorflow/core/kernels/scatter_nd_op_test.cc
+++ b/tensorflow/core/kernels/scatter_nd_op_test.cc
@@ -20,6 +20,7 @@
 
 #include "absl/status/status.h"
 #include "absl/strings/match.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/fake_input.h"
 #include "tensorflow/core/framework/node_def_builder.h"
@@ -38,7 +39,6 @@
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc
index 7f76cb4..d15cc7f 100644
--- a/tensorflow/core/kernels/sendrecv_ops.cc
+++ b/tensorflow/core/kernels/sendrecv_ops.cc
@@ -120,8 +120,8 @@
   auto dst_it = attr.find("_dst");
   const string& src = src_it != attr.end() ? src_it->second.s() : "";
   const string& dst = dst_it != attr.end() ? dst_it->second.s() : "";
-  string op = profiler::TraceMeOp(name_view(), type_string_view());
-  return profiler::TraceMeEncode(
+  string op = tsl::profiler::TraceMeOp(name_view(), type_string_view());
+  return tsl::profiler::TraceMeEncode(
       std::move(op),
       {{"from", src}, {"to", dst}, {"key", parsed_key_.FullKey()}});
 }
@@ -166,8 +166,8 @@
   auto dst_it = attr.find("_dst");
   const string& src = src_it != attr.end() ? src_it->second.s() : "";
   const string& dst = dst_it != attr.end() ? dst_it->second.s() : "";
-  string op = profiler::TraceMeOp(name_view(), type_string_view());
-  return profiler::TraceMeEncode(
+  string op = tsl::profiler::TraceMeOp(name_view(), type_string_view());
+  return tsl::profiler::TraceMeEncode(
       std::move(op),
       {{"from", src}, {"to", dst}, {"key", parsed_key_.FullKey()}});
 }
diff --git a/tensorflow/core/kernels/spectrogram_op_test.cc b/tensorflow/core/kernels/spectrogram_op_test.cc
index 791fdda..650024b 100644
--- a/tensorflow/core/kernels/spectrogram_op_test.cc
+++ b/tensorflow/core/kernels/spectrogram_op_test.cc
@@ -25,13 +25,13 @@
 #include "tensorflow/cc/ops/audio_ops.h"
 #include "tensorflow/cc/ops/const_op.h"
 #include "tensorflow/cc/ops/math_ops.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status_matchers.h"
 
diff --git a/tensorflow/core/kernels/stochastic_cast_op_test.cc b/tensorflow/core/kernels/stochastic_cast_op_test.cc
index b0a5835..10d9eae1 100644
--- a/tensorflow/core/kernels/stochastic_cast_op_test.cc
+++ b/tensorflow/core/kernels/stochastic_cast_op_test.cc
@@ -21,6 +21,7 @@
 
 #include <gtest/gtest.h>
 #include "Eigen/Core"  // from @eigen_archive
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/fake_input.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -30,7 +31,6 @@
 #include "tensorflow/core/lib/random/random_distributions.h"
 #include "tensorflow/core/platform/bfloat16.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/lib/random/philox_random.h"
 
 namespace Eigen {
diff --git a/tensorflow/core/kernels/uniform_quant_ops/BUILD b/tensorflow/core/kernels/uniform_quant_ops/BUILD
index 5c158fe..507a857 100644
--- a/tensorflow/core/kernels/uniform_quant_ops/BUILD
+++ b/tensorflow/core/kernels/uniform_quant_ops/BUILD
@@ -191,7 +191,7 @@
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:test",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -205,6 +205,6 @@
         "//tensorflow/core/platform:test",
         "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
diff --git a/tensorflow/core/kernels/uniform_quant_ops/math_utils_test.cc b/tensorflow/core/kernels/uniform_quant_ops/math_utils_test.cc
index 4d2e869..a331b8b 100644
--- a/tensorflow/core/kernels/uniform_quant_ops/math_utils_test.cc
+++ b/tensorflow/core/kernels/uniform_quant_ops/math_utils_test.cc
@@ -17,10 +17,10 @@
 #include <limits>
 
 #include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/uniform_quant_ops/tensor_utils_test.cc b/tensorflow/core/kernels/uniform_quant_ops/tensor_utils_test.cc
index f3d2f2c..c4f0ea5 100644
--- a/tensorflow/core/kernels/uniform_quant_ops/tensor_utils_test.cc
+++ b/tensorflow/core/kernels/uniform_quant_ops/tensor_utils_test.cc
@@ -19,9 +19,9 @@
 
 #include <gtest/gtest.h>
 #include "absl/status/status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops_test.cc b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops_test.cc
index 7ffc629..73121de 100644
--- a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops_test.cc
+++ b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops_test.cc
@@ -17,6 +17,7 @@
 #include <vector>
 
 #include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/fake_input.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/numeric_types.h"
@@ -27,7 +28,6 @@
 #include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/quantization/uniform_quant_ops_attr.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/lib/core/status_test_util.h b/tensorflow/core/lib/core/status_test_util.h
index ef333ef..3c604ee 100644
--- a/tensorflow/core/lib/core/status_test_util.h
+++ b/tensorflow/core/lib/core/status_test_util.h
@@ -16,7 +16,7 @@
 #ifndef TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_
 #define TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 #endif  // TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_
diff --git a/tensorflow/core/lib/db/sqlite_test.cc b/tensorflow/core/lib/db/sqlite_test.cc
index 39ea38e..0a0042f 100644
--- a/tensorflow/core/lib/db/sqlite_test.cc
+++ b/tensorflow/core/lib/db/sqlite_test.cc
@@ -17,11 +17,11 @@
 #include <array>
 #include <climits>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/core/lib/gtl/edit_distance.h b/tensorflow/core/lib/gtl/edit_distance.h
index 9481ba7..818ec69 100644
--- a/tensorflow/core/lib/gtl/edit_distance.h
+++ b/tensorflow/core/lib/gtl/edit_distance.h
@@ -59,7 +59,7 @@
   if (s == t) return 0;
 
   // Create work vector
-  gtl::InlinedVector<int64_t, 32> scratch_holder(t_size);
+  absl::InlinedVector<int64_t, 32UL> scratch_holder(t_size);
 
   int64_t* scratch = scratch_holder.data();
 
diff --git a/tensorflow/core/lib/strings/BUILD b/tensorflow/core/lib/strings/BUILD
index 72eb0a6..d8f4e6d 100644
--- a/tensorflow/core/lib/strings/BUILD
+++ b/tensorflow/core/lib/strings/BUILD
@@ -51,7 +51,7 @@
     name = "proto_serialization",
     hdrs = ["proto_serialization.h"],
     deps = [
-        "@local_tsl//tsl/lib/strings:proto_serialization",
+        "@local_xla//xla/tsl/lib/strings:proto_serialization",
     ],
 )
 
@@ -116,7 +116,7 @@
         "ordered_code.cc",
         "ordered_code.h",
         "proto_serialization.h",
-        "@local_tsl//tsl/lib/strings:mobile_srcs_only_runtime",
+        "@local_xla//xla/tsl/lib/strings:mobile_srcs_only_runtime",
     ],
     visibility = ["//tensorflow/core:__pkg__"],
 )
@@ -133,7 +133,7 @@
         "str_util.h",
         "strcat.h",
         "stringprintf.h",
-        "@local_tsl//tsl/lib/strings:legacy_lib_strings_all_headers",
+        "@local_xla//xla/tsl/lib/strings:legacy_lib_strings_all_headers",
     ],
     visibility = ["//tensorflow/core:__pkg__"],
 )
@@ -165,7 +165,7 @@
         "str_util.h",
         "strcat.h",
         "stringprintf.h",
-        "@local_tsl//tsl/lib/strings:legacy_lib_string_headers",
+        "@local_xla//xla/tsl/lib/strings:legacy_lib_string_headers",
     ],
     visibility = ["//tensorflow/core:__pkg__"],
 )
@@ -178,7 +178,7 @@
         "proto_serialization.h",
         "proto_text_util.h",
         "scanner.h",
-        "@local_tsl//tsl/lib/strings:legacy_lib_internal_public_string_headers",
+        "@local_xla//xla/tsl/lib/strings:legacy_lib_internal_public_string_headers",
     ],
     visibility = ["//tensorflow/core:__pkg__"],
 )
diff --git a/tensorflow/core/lib/strings/proto_serialization.h b/tensorflow/core/lib/strings/proto_serialization.h
index 0c01708..e0c253f 100644
--- a/tensorflow/core/lib/strings/proto_serialization.h
+++ b/tensorflow/core/lib/strings/proto_serialization.h
@@ -15,7 +15,7 @@
 #ifndef TENSORFLOW_CORE_LIB_STRINGS_PROTO_SERIALIZATION_H_
 #define TENSORFLOW_CORE_LIB_STRINGS_PROTO_SERIALIZATION_H_
 
-#include "tsl/lib/strings/proto_serialization.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 
 namespace tensorflow {
 // NOLINTBEGIN(misc-unused-using-decls)
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index ebaade2..b05c412 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -157,7 +157,6 @@
     offset_i.push_back(strings::StrCat("offset:offset:", i));
     dx_i.push_back(strings::StrCat("dx_", i, ":output:0"));
   }
-  DataTypeVector dtype_list(N, T);
 
   // ConcatGrad(dim, x, dy):
   //   for i in range(N):
diff --git a/tensorflow/core/ops/batch_ops.cc b/tensorflow/core/ops/batch_ops.cc
index 99d4551..6d21ee4 100644
--- a/tensorflow/core/ops/batch_ops.cc
+++ b/tensorflow/core/ops/batch_ops.cc
@@ -76,9 +76,17 @@
     // allowed. The following options are available.
     //
     //   - PAD_UP: pad to size 32.
+    //   - BATCH_DOWN: schedule a batch of size 16 and leave 2 requests in the
+    //     batch buffer.
+    //   - MINIMIZE_TPU_COST_PER_REQUEST: a smarter greedy policy that chooses
+    //     either PAD_UP or BATCH_DOWN so as to minimize the TPU cost per real
+    //     request. In this case, it would compare (batch_16_cost / 16) and
+    //     (batch_32_cost / 18).
+    //
+    // WARNING: Some batch schedulers may not support this attribute.
     .Attr(
         "batch_padding_policy: "
-        "{'PAD_UP'} = 'PAD_UP'")
+        "{'PAD_UP', 'BATCH_DOWN', 'MINIMIZE_TPU_COST_PER_REQUEST'} = 'PAD_UP'")
     .Attr("Tin: list(type)")
     .Attr("Tcaptured: list(type) >= 0")
     .Attr("Tout: list(type)")
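To make the MINIMIZE_TPU_COST_PER_REQUEST comparison above concrete: with 18 queued requests and allowed batch sizes of 16 and 32, the policy picks whichever option is cheaper per real request. A hedged sketch of that decision, where the per-batch costs are invented numbers rather than anything the scheduler exposes:

    const double batch_16_cost = 1.0;                           // hypothetical cost of a size-16 batch
    const double batch_32_cost = 1.6;                           // hypothetical cost of a size-32 batch
    const int real_requests = 18;
    const double batch_down_cost = batch_16_cost / 16;          // 0.0625 per real request
    const double pad_up_cost = batch_32_cost / real_requests;   // ~0.089 per real request
    const bool pad_up = pad_up_cost <= batch_down_cost;         // false: BATCH_DOWN wins here

With these example numbers BATCH_DOWN is the cheaper choice; with a flatter cost curve (say batch_32_cost = 1.1) PAD_UP would win instead.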
diff --git a/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt
index 8fecdf6..d743b8e 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt
@@ -802,3 +802,152 @@
   }
   is_distributed_communication: true
 }
+op {
+  name: "BatchFunction"
+  input_arg {
+    name: "in_tensors"
+    type_list_attr: "Tin"
+  }
+  input_arg {
+    name: "captured_tensors"
+    type_list_attr: "Tcaptured"
+  }
+  output_arg {
+    name: "out_tensors"
+    type_list_attr: "Tout"
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "num_batch_threads"
+    type: "int"
+  }
+  attr {
+    name: "max_batch_size"
+    type: "int"
+  }
+  attr {
+    name: "batch_timeout_micros"
+    type: "int"
+  }
+  attr {
+    name: "max_enqueued_batches"
+    type: "int"
+    default_value {
+      i: 10
+    }
+  }
+  attr {
+    name: "allowed_batch_sizes"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "batching_queue"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "low_priority_max_batch_size"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "low_priority_batch_timeout_micros"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "low_priority_allowed_batch_sizes"
+    type: "list(int)"
+    default_value {
+      list {
+      }
+    }
+  }
+  attr {
+    name: "low_priority_max_enqueued_batches"
+    type: "int"
+    default_value {
+      i: 0
+    }
+  }
+  attr {
+    name: "mixed_priority_policy"
+    type: "string"
+    default_value {
+      s: "low_priority_padding_with_max_batch_size"
+    }
+    allowed_values {
+      list {
+        s: "low_priority_padding_with_max_batch_size"
+        s: "low_priority_padding_with_next_allowed_batch_size"
+        s: "priority_isolation"
+      }
+    }
+  }
+  attr {
+    name: "batch_padding_policy"
+    type: "string"
+    default_value {
+      s: "PAD_UP"
+    }
+    allowed_values {
+      list {
+        s: "PAD_UP"
+        s: "BATCH_DOWN"
+        s: "MINIMIZE_TPU_COST_PER_REQUEST"
+      }
+    }
+  }
+  attr {
+    name: "Tin"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Tcaptured"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tout"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "enable_large_batch_splitting"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_distributed_communication: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/ParallelMapDatasetV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/ParallelMapDatasetV2.pbtxt
index 55e73b7..1f3016b 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/ParallelMapDatasetV2.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/ParallelMapDatasetV2.pbtxt
@@ -290,3 +290,98 @@
     }
   }
 }
+op {
+  name: "ParallelMapDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+    experimental_full_type {
+      type_id: TFT_DATASET
+      args {
+        type_id: TFT_FOR_EACH
+        args {
+          type_id: TFT_PRODUCT
+        }
+        args {
+          type_id: TFT_TENSOR
+          args {
+            type_id: TFT_VAR
+            s: "output_types"
+          }
+        }
+        args {
+          type_id: TFT_VAR
+          s: "output_types"
+        }
+      }
+    }
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "deterministic"
+    type: "string"
+    default_value {
+      s: "default"
+    }
+  }
+  attr {
+    name: "preserve_cardinality"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "use_unbounded_threadpool"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "metadata"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+}
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index d347d0c..7e81212 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -206,6 +206,7 @@
     // "true", "false", or "default".
     .Attr("deterministic: string = 'default'")
     .Attr("preserve_cardinality: bool = false")
+    .Attr("use_unbounded_threadpool: bool = false")
     .Attr("metadata: string = ''")
     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                            "output_types"))
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index d05df09..dcf9e2f 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -4112,6 +4112,8 @@
     allowed_values {
       list {
         s: "PAD_UP"
+        s: "BATCH_DOWN"
+        s: "MINIMIZE_TPU_COST_PER_REQUEST"
       }
     }
   }
@@ -33011,6 +33013,13 @@
     }
   }
   attr {
+    name: "use_unbounded_threadpool"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
     name: "metadata"
     type: "string"
     default_value {
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 2f179f7..1dc0e20 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -317,7 +317,7 @@
         ":test_main",
         "//tensorflow/core:protos_all_cc",
         "@eigen_archive//:eigen3",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 178a68c..2059e3b 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -19,6 +19,7 @@
 
 #include <memory>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/platform/cord.h"
@@ -29,7 +30,6 @@
 #include "tensorflow/core/platform/strcat.h"
 #include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tsl {
 
diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h b/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h
index 916b530..66de83f 100644
--- a/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h
+++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h
@@ -541,7 +541,7 @@
       separator.Add();
       output->Append(R"({"args":{"name":)", JsonEscape(device.name()),
                      R"(},"name":"process_name","ph":"M","pid":)", device_id,
-                     "}");
+                     R"(,"thread_count":)", device.resources_size(), "}");
     }
     separator.Add();
     output->Append(R"({"args":{"sort_index":)", device_id,
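With the thread_count field added above, each process_name metadata event in the exported JSON also carries the device's resource count. An illustrative (not captured) event, assuming a device named /device:GPU:0 with 8 resources and pid 2:

    {"args":{"name":"/device:GPU:0"},"name":"process_name","ph":"M","pid":2,"thread_count":8}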
diff --git a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc
index 3c06bd3..6c80add 100644
--- a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc
@@ -24,6 +24,7 @@
 #include <gtest/gtest.h>
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/file_system.h"
 #include "tensorflow/core/platform/statusor.h"
@@ -31,7 +32,6 @@
 #include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 #include "tensorflow/core/profiler/utils/xplane_utils.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 #include "tsl/profiler/protobuf/xplane.pb.h"
 
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
index 5609fd7..8b22847 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
@@ -227,8 +227,6 @@
   XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace);
   using OpMetricBySymbol =
       absl::flat_hash_map</*symbol_id=*/uint64_t, OpMetrics>;
-  absl::flat_hash_map</*program_id=*/uint64_t, OpMetricBySymbol> flat_op_metric;
-
   XEventsOpMetricsDbBuilder builder;
 
   plane.ForEachLine([&](const XLineVisitor& line) {
diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h
index 70d5efc..d6efbd1 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.h
+++ b/tensorflow/core/profiler/utils/xplane_schema.h
@@ -54,6 +54,7 @@
 using tsl::profiler::kPythonTracerPlaneName;         // NOLINT
 using tsl::profiler::kRoctracerApiPlaneName;         // NOLINT
 using tsl::profiler::kSourceLineName;                // NOLINT
+using tsl::profiler::kSparseCorePlaneRegex;          // NOLINT
 using tsl::profiler::kStepLineName;                  // NOLINT
 using tsl::profiler::kTensorFlowNameScopeLineName;   // NOLINT
 using tsl::profiler::kTensorFlowOpLineName;          // NOLINT
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index cd5d889..fbf446c 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -108,7 +108,7 @@
 
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 1940  // Updated: 2024/7/31
+#define TF_GRAPH_DEF_VERSION 1952  // Updated: 2024/8/12
 
 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
diff --git a/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc b/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc
index bdb6f9e..5b88167 100644
--- a/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc
+++ b/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc
@@ -16,10 +16,10 @@
 
 #include <vector>
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tfrt/core_runtime/op_attr_type.h"  // from @tf_runtime
 #include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
 #include "tfrt/support/forward_decls.h"  // from @tf_runtime
diff --git a/tensorflow/core/runtime_fallback/runtime/BUILD b/tensorflow/core/runtime_fallback/runtime/BUILD
index ece9e5d..45f433d 100644
--- a/tensorflow/core/runtime_fallback/runtime/BUILD
+++ b/tensorflow/core/runtime_fallback/runtime/BUILD
@@ -195,6 +195,7 @@
         "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler",
         "//tensorflow/core/kernels/batching_util:batch_resource_base",
         "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs",
+        "//tensorflow/core/kernels/batching_util:batch_stats",
         "//tensorflow/core/kernels/batching_util:bounded_executor",
         "//tensorflow/core/kernels/batching_util:warmup",
         "//tensorflow/core/lib/core:refcount",
@@ -205,7 +206,6 @@
         "//tensorflow/core/tfrt/fallback:op_kernel_runner",
         "//tensorflow/core/tfrt/utils:error_util",
         "//tensorflow/core/tfrt/utils:fallback_tensor",
-        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:status",
diff --git a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h
index dae6eb3..86772a2 100644
--- a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h
+++ b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h
@@ -20,7 +20,6 @@
 #include <string>
 #include <vector>
 
-#include "absl/log/check.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -30,10 +29,12 @@
 #include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/batch_stats.h"
 #include "tensorflow/core/kernels/batching_util/warmup.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/random.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
@@ -215,6 +216,7 @@
       batch_resource_options.batch_timeout_micros = batch_timeout_micros_;
       batch_resource_options.max_enqueued_batches = max_enqueued_batches_;
       batch_resource_options.allowed_batch_sizes = allowed_batch_sizes_;
+      batch_resource_options.batch_padding_policy = batch_padding_policy_;
       batch_resource_options.low_priority_max_batch_size =
           low_priority_max_batch_size_;
       batch_resource_options.low_priority_batch_timeout_micros =
@@ -224,6 +226,13 @@
       batch_resource_options.low_priority_allowed_batch_sizes =
           low_priority_allowed_batch_sizes_;
 
+      serving::ModelBatchStats& model_batch_stats =
+          serving::GlobalBatchStatsRegistry().model(
+              /* model_name= */ std::string(GetModelName(c)),
+              /* op_name= */ c->op_kernel().name());
+      model_batch_stats.SetBatchTimeoutMicros(batch_timeout_micros_);
+      model_batch_stats.SetNumBatchThreads(num_batch_threads_);
+
       std::unique_ptr<BatchResourceType> new_resource;
       auto status = BatchResourceType::Create(
           c, batch_resource_options, batch_function_,
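The registry calls added above record the kernel's static batching parameters next to its runtime statistics. A condensed sketch of the pattern, with placeholder model and op names and example values in place of the kernel's member fields:

    serving::ModelBatchStats& stats =
        serving::GlobalBatchStatsRegistry().model(
            /* model_name= */ "example_model",
            /* op_name= */ "batch_function_0");
    stats.SetBatchTimeoutMicros(5000);  // example timeout in microseconds
    stats.SetNumBatchThreads(8);        // example thread count

Recording these values here presumably lets cost-aware batching policies consult them later.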
diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc
index e953c70..38b8fc3 100644
--- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc
+++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc
@@ -137,7 +137,8 @@
             options.num_batch_threads, options.max_batch_size,
             options.batch_timeout_micros, options.max_enqueued_batches,
             options.allowed_batch_sizes, enable_large_batch_splitting,
-            disable_padding, options.low_priority_max_batch_size,
+            disable_padding, options.batch_padding_policy,
+            options.low_priority_max_batch_size,
             options.low_priority_batch_timeout_micros,
             options.low_priority_max_enqueued_batches,
             options.low_priority_allowed_batch_sizes,
@@ -437,7 +438,7 @@
     // BatchFunction in core/ops/batch_ops.cc.
     .Attr(
         "batch_padding_policy: "
-        "{'PAD_UP'} = 'PAD_UP'")
+        "{'PAD_UP', 'BATCH_DOWN', 'MINIMIZE_TPU_COST_PER_REQUEST'} = 'PAD_UP'")
     .Attr("Tin: list(type)")
     .Attr("Tcaptured: list(type) >= 0")
     .Attr("Tout: list(type)")
diff --git a/tensorflow/core/tfrt/common/BUILD b/tensorflow/core/tfrt/common/BUILD
index ba79007..eb9724f 100644
--- a/tensorflow/core/tfrt/common/BUILD
+++ b/tensorflow/core/tfrt/common/BUILD
@@ -188,13 +188,13 @@
         ":pjrt_state",
         ":pjrt_util",
         "//tensorflow/core:framework",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:test",
         "@local_tsl//tsl/platform:test_main",
         "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
         "@local_xla//xla/pjrt/cpu:cpu_client",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -214,13 +214,13 @@
         "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/core:framework",
         "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:test_main",
         "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
         "@local_xla//xla/pjrt:pjrt_client",
         "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client",
         "@local_xla//xla/service:gpu_plugin",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/core/tfrt/common/pjrt_state_test.cc b/tensorflow/core/tfrt/common/pjrt_state_test.cc
index fddd72e..03dcdb7 100644
--- a/tensorflow/core/tfrt/common/pjrt_state_test.cc
+++ b/tensorflow/core/tfrt/common/pjrt_state_test.cc
@@ -21,10 +21,10 @@
 #include <gtest/gtest.h>
 #include "xla/pjrt/cpu/cpu_client.h"
 #include "xla/pjrt/pjrt_client.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/tensorflow/core/tfrt/common/pjrt_util_test.cc b/tensorflow/core/tfrt/common/pjrt_util_test.cc
index 1361b72..48f7743 100644
--- a/tensorflow/core/tfrt/common/pjrt_util_test.cc
+++ b/tensorflow/core/tfrt/common/pjrt_util_test.cc
@@ -18,9 +18,9 @@
 #include <utility>
 
 #include "xla/pjrt/cpu/cpu_client.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/tfrt/common/pjrt_state.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
diff --git a/tensorflow/core/tfrt/fallback/fallback_state_test.cc b/tensorflow/core/tfrt/fallback/fallback_state_test.cc
index d7d5531..2111171 100644
--- a/tensorflow/core/tfrt/fallback/fallback_state_test.cc
+++ b/tensorflow/core/tfrt/fallback/fallback_state_test.cc
@@ -19,10 +19,10 @@
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/const_op.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/status_matchers.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/core/tfrt/gpu/kernel/BUILD b/tensorflow/core/tfrt/gpu/kernel/BUILD
index bd4f861..fef0e58 100644
--- a/tensorflow/core/tfrt/gpu/kernel/BUILD
+++ b/tensorflow/core/tfrt/gpu/kernel/BUILD
@@ -13,20 +13,16 @@
     deps = [
         ":gpu_runner",
         "//tensorflow/core:framework",
-        "//tensorflow/core/common_runtime:copy_tensor",
         "//tensorflow/core/framework:tensor",
         "//tensorflow/core/platform:status",
         "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
-        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_utils",
         "//tensorflow/core/runtime_fallback/kernel:tensor_util",
         "//tensorflow/core/tfrt/utils:fallback_tensor",
         "//tensorflow/core/tfrt/utils:gpu_variables_table",
-        "//tensorflow/core/tfrt/utils:tensor_util",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
-        "@tf_runtime//:core_runtime",
         "@tf_runtime//:hostcontext",
         "@tf_runtime//:support",
         "@tf_runtime//:tensor_alwayslink",
@@ -47,9 +43,11 @@
         "//tensorflow/compiler/jit:xla_launch_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/core:framework",
+        "//tensorflow/core/framework:attr_value_proto_cc",
+        "//tensorflow/core/framework:function_proto_cc",
+        "//tensorflow/core/framework:types_proto_cc",
         "//tensorflow/core/platform:notification",
         "//tensorflow/core/platform:status",
-        "//tensorflow/core/platform:statusor",
         "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
         "//tensorflow/core/tfrt/common:global_state",
         "//tensorflow/core/tfrt/utils:fallback_tensor",
@@ -59,6 +57,7 @@
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
@@ -122,6 +121,7 @@
         "//tensorflow/core/common_runtime/gpu:gpu_serving_device_selector",
         "//tensorflow/core/platform:status",
         "//tensorflow/core/tfrt/runtime",
+        "@com_google_absl//absl/status",
         "@local_xla//xla/tsl/framework:serving_device_selector_policies",
         "@tf_runtime//:hostcontext",
     ],
diff --git a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc
index 3143b8b..d4047d4 100644
--- a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc
+++ b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc
@@ -41,15 +41,18 @@
 #include "xla/tsl/framework/device_id_manager.h"
 #include "xla/tsl/framework/serving_device_selector.h"
 #include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/device.h"
 #include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/notification.h"
 #include "tensorflow/core/platform/status.h"
-#include "tensorflow/core/platform/statusor.h"
 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
 #include "tensorflow/core/tfrt/common/global_state.h"
 #include "tensorflow/core/tfrt/utils/fallback_tensor.h"
@@ -231,7 +234,7 @@
     int device_idx, const llvm::SmallVector<tfrt_stub::FallbackTensor>& args,
     tfrt::ArrayRef<int64_t> resource_indices, Device* cpu_device,
     absl::flat_hash_map<int, Device*> gpu_devices,
-    tfrt::gpu::GpuVariablesTable& vars_table,
+    tfrt::gpu::GpuVariablesTable& vars_table, bool variables_are_shared,
     const tfrt::ExecutionContext& exec_ctx) {
   llvm::SmallVector<tfrt::AsyncValueRef<tfrt_stub::FallbackTensor>> results;
 
@@ -244,35 +247,51 @@
   TF_ASSIGN_OR_RETURN(const std::vector<tsl::TfDeviceId> devices_on_platform,
                       tsl::DeviceIdManager::GetTfDevicesOnPlatform(
                           device_type, platform_device_id));
-  const int platform_idx = platform_device_id.value();
   absl::flat_hash_set<int64_t> resource_indices_set(resource_indices.begin(),
                                                     resource_indices.end());
 
+  // If variables are shared, there is only one copy of variables for all
+  // logical devices on the same physical GPU device; otherwise, each logical
+  // device has its own copy of variables.
+  const int cache_copy_idx =
+      variables_are_shared ? platform_device_id.value() : device_idx;
+
   for (int i = 0, resource_idx = 0; i < args.size(); ++i) {
     if (resource_indices_set.contains(i)) {
       // Transfer resources.
+      VLOG(2) << "Transfer resource arg[" << i << "].";
       tfrt::AsyncValueRef<tfrt_stub::FallbackTensor> device_tensor;
       auto cached_device_variable =
-          vars_table.GetDeviceVariable(args[i], platform_idx);
+          vars_table.GetDeviceVariable(args[i], cache_copy_idx);
       if (cached_device_variable) {
-        VLOG(2) << "Cache hit for resource arg[" << i << "]";
+        VLOG(2) << "Cache hit for resource arg[" << i << "].";
         device_tensor = cached_device_variable.CopyRef();
       } else {
-        VLOG(2) << "Cache miss for resource arg[" << i << "]";
-        // Distribute variables on virtual devices on the same GPU.
-        const int idx = resource_idx % devices_on_platform.size();
-        const int gpu_device_idx = devices_on_platform[idx].value();
+        VLOG(2) << "Cache miss for resource arg[" << i << "].";
+
+        int gpu_device_idx;
+        if (variables_are_shared) {
+          // Distribute variables on logical devices on the same GPU.
+          const int idx = resource_idx % devices_on_platform.size();
+          gpu_device_idx = devices_on_platform[idx].value();
+        } else {
+          gpu_device_idx = device_idx;
+        }
+
+        VLOG(2) << "Transfer the resource arg[" << i << "] to device "
+                << gpu_device_idx << ".";
         device_tensor = TransferTensorToDevice(exec_ctx, args[i],
                                                gpu_devices.at(gpu_device_idx));
-        vars_table.AddOrUpdateDeviceVariable(args[i], platform_idx,
+        vars_table.AddOrUpdateDeviceVariable(args[i], cache_copy_idx,
                                              std::move(device_tensor));
         device_tensor =
-            vars_table.GetDeviceVariable(args[i], platform_idx).CopyRef();
+            vars_table.GetDeviceVariable(args[i], cache_copy_idx).CopyRef();
       }
       results.push_back(device_tensor);
       ++resource_idx;
     } else {
       // Transfer inputs.
+      VLOG(2) << "Transfer input arg[" << i << "].";
       tfrt::AsyncValueRef<tfrt_stub::FallbackTensor> device_tensor =
           TransferTensorToDevice(exec_ctx, args[i], gpu_devices.at(device_idx));
       results.push_back(device_tensor);
@@ -356,6 +375,7 @@
   tsl::DeviceReservation device_reservation =
       serving_device_selector_->ReserveDevice(absl::StrCat(fingerprint));
   const int device_idx = device_reservation.device_index();
+  VLOG(1) << "GpuRunner selected device " << device_idx << ".";
 
   // Compile the program.
   const XlaCompiler::CompilationResult* compilation_result;
@@ -368,10 +388,10 @@
   TF_ASSIGN_OR_RETURN(
       llvm::SmallVector<tfrt::AsyncValueRef<tfrt_stub::FallbackTensor>>
           transferred_args,
-      TransferVariablesAndInputs(device_idx, *run_inputs.args,
-                                 run_inputs.resource_indices,
-                                 run_inputs.cpu_device, *run_inputs.gpu_devices,
-                                 vars_table_, *run_inputs.exec_ctx));
+      TransferVariablesAndInputs(
+          device_idx, *run_inputs.args, run_inputs.resource_indices,
+          run_inputs.cpu_device, *run_inputs.gpu_devices, vars_table_,
+          /*variables_are_shared=*/false, *run_inputs.exec_ctx));
 
   llvm::SmallVector<tfrt::RCReference<tfrt::AsyncValue>, 4>
       transferred_args_to_wait;
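The caching change above keys the per-variable cache differently depending on ownership. A reduced sketch of that choice, reusing the names from the diff:

    // One cached copy per physical GPU when variables are shared, one per
    // logical device otherwise.
    const int cache_copy_idx =
        variables_are_shared ? platform_device_id.value() : device_idx;
    auto cached = vars_table.GetDeviceVariable(args[i], cache_copy_idx);

The updated TransferVariablesAndInputs call at the end of this file's diff passes variables_are_shared=false, so that path now keeps a per-logical-device copy instead of one copy shared per physical GPU.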
diff --git a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h
index fc61eff..d292fed 100644
--- a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h
+++ b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h
@@ -18,7 +18,10 @@
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "llvm/ADT/SmallVector.h"
 #include "xla/tsl/framework/serving_device_selector.h"
+#include "tensorflow/core/framework/device.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/platform/status.h"
@@ -27,6 +30,7 @@
 #include "tensorflow/core/tfrt/utils/gpu_variables_table.h"
 #include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
 #include "tfrt/host_context/execution_context.h"  // from @tf_runtime
+#include "tfrt/support/forward_decls.h"  // from @tf_runtime
 
 namespace tensorflow {
 namespace gpu {
diff --git a/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc b/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc
index 8cc6a62..43cb013 100644
--- a/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc
+++ b/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc
@@ -21,19 +21,15 @@
 #include "absl/strings/str_cat.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "tensorflow/core/common_runtime/copy_tensor.h"
-#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/device.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
-#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.h"
 #include "tensorflow/core/runtime_fallback/kernel/tensor_util.h"
 #include "tensorflow/core/tfrt/gpu/kernel/gpu_runner.h"
 #include "tensorflow/core/tfrt/utils/fallback_tensor.h"
 #include "tensorflow/core/tfrt/utils/gpu_variables_table.h"
-#include "tensorflow/core/tfrt/utils/tensor_util.h"
-#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
 #include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
 #include "tfrt/host_context/attribute_utils.h"  // from @tf_runtime
 #include "tfrt/host_context/execution_context.h"  // from @tf_runtime
diff --git a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc
index 94e52ad..48f3160 100644
--- a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc
+++ b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc
@@ -17,6 +17,7 @@
 #include <memory>
 #include <utility>
 
+#include "absl/status/status.h"
 #include "xla/tsl/framework/serving_device_selector_policies.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h"
 #include "tensorflow/core/platform/status.h"
diff --git a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h
index bb99022..452ccdd 100644
--- a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h
+++ b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h
@@ -15,6 +15,7 @@
 #ifndef TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_
 #define TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_
 #include "xla/tsl/framework/serving_device_selector_policies.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/tfrt/runtime/runtime.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD
index 3a5d945..61d869f 100644
--- a/tensorflow/core/tfrt/graph_executor/BUILD
+++ b/tensorflow/core/tfrt/graph_executor/BUILD
@@ -246,9 +246,9 @@
         ":test_config_proto_cc",
         "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:statusor",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/core/tfrt/graph_executor/config_test.cc b/tensorflow/core/tfrt/graph_executor/config_test.cc
index bc3d186..fc1b54f 100644
--- a/tensorflow/core/tfrt/graph_executor/config_test.cc
+++ b/tensorflow/core/tfrt/graph_executor/config_test.cc
@@ -17,9 +17,9 @@
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "absl/status/status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/tfrt/graph_executor/config.pb.h"
 #include "tensorflow/core/tfrt/graph_executor/test_config.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc
index 0a0f073..c0e07b3 100644
--- a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc
+++ b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc
@@ -31,6 +31,7 @@
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/array_ops.h"
 #include "tensorflow/cc/ops/const_op.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/op.h"
@@ -50,7 +51,6 @@
 #include "tensorflow/core/tfrt/mlrt/interpreter/value.h"
 #include "tensorflow/core/tfrt/mlrt/kernel/kernel.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/statusor.h"
 #include "tfrt/cpp_tests/test_util.h"  // from @tf_runtime
 #include "tfrt/host_context/resource_context.h"  // from @tf_runtime
diff --git a/tensorflow/core/tfrt/ifrt/BUILD b/tensorflow/core/tfrt/ifrt/BUILD
index dcc4c9d..94ca5ab 100644
--- a/tensorflow/core/tfrt/ifrt/BUILD
+++ b/tensorflow/core/tfrt/ifrt/BUILD
@@ -114,6 +114,7 @@
     hdrs = ["ifrt_serving_executable.h"],
     deps = [
         ":ifrt_config_proto_cc",
+        ":ifrt_device_utils",
         ":ifrt_loaded_variable_registry",
         ":ifrt_loaded_variable_utils",
         ":ifrt_restore_tensor_registry",
@@ -143,7 +144,10 @@
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:statusor",
@@ -216,12 +220,22 @@
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@local_tsl//tsl/platform:statusor",
+        "@local_xla//xla/hlo/ir:hlo",
         "@local_xla//xla/python/ifrt",
         "@local_xla//xla/tsl/concurrency:ref_count",
     ],
 )
 
 cc_library(
+    name = "ifrt_model_restore_context",
+    hdrs = ["ifrt_model_restore_context.h"],
+    deps = [
+        ":checkpoint_loader",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+cc_library(
     name = "ifrt_model_context",
     srcs = ["ifrt_model_context.cc"],
     hdrs = ["ifrt_model_context.h"],
@@ -406,10 +420,10 @@
         "//tensorflow/core/framework:types_proto_cc",
         "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:statusor",
         "@local_xla//xla/python/ifrt",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -429,16 +443,17 @@
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:test",
         "@local_xla//xla:xla_data_proto_cc",
+        "@local_xla//xla/hlo/ir:hlo",
         "@local_xla//xla/python/ifrt",
         "@local_xla//xla/python/ifrt:test_util",
         "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib",
         "@local_xla//xla/tsl/concurrency:ref_count",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
         "@tf_runtime//:hostcontext",
     ],
 )
@@ -593,3 +608,37 @@
         "@tf_runtime//backends/cpu:tf_ops_alwayslink",
     ],
 )
+
+cc_library(
+    name = "checkpoint_loader",
+    srcs = ["checkpoint_loader.cc"],
+    hdrs = ["checkpoint_loader.h"],
+    deps = [
+        ":ifrt_loaded_variable_utils",
+        ":ifrt_restore_tensor_registry",
+        "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/common_runtime:function",
+        "//tensorflow/core/framework:attr_value_proto_cc",
+        "//tensorflow/core/framework:node_def_util",
+        "//tensorflow/core/framework:tensor",
+        "//tensorflow/core/tfrt/fallback:op_kernel_runner",
+        "//tensorflow/core/tfrt/mlrt/bytecode",
+        "//tensorflow/core/tfrt/mlrt/kernel:context",
+        "//tensorflow/core/tfrt/mlrt/kernel:kernel_runner_utils",
+        "//tensorflow/core/tfrt/mlrt/kernel:shard_restore_util",
+        "//tensorflow/core/tfrt/utils:fallback_tensor",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:tstring",
+        "@local_xla//xla/python/ifrt",
+        "@tf_runtime//:hostcontext",
+    ],
+)
diff --git a/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc b/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc
new file mode 100644
index 0000000..7085efd
--- /dev/null
+++ b/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc
@@ -0,0 +1,358 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
+#include "xla/python/ifrt/future.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_handle.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
+#include "tensorflow/core/tfrt/mlrt/kernel/context.h"
+#include "tensorflow/core/tfrt/mlrt/kernel/kernel_runner_utils.h"
+#include "tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.h"
+#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/tstring.h"
+#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+namespace {
+
+static constexpr int kNumRestoreClusters = 4;
+
+// A shard of variables to be restored.
+struct RestoreVariableShard {
+  tensorflow::Tensor prefix;
+  tensorflow::Tensor tensor_names;
+  tensorflow::Tensor shape_and_slices;
+  std::vector<tensorflow::tfrt_stub::FallbackTensor> var_handles;
+  tensorflow::AttrValue dtypes_attr_value;
+  std::vector<tensorflow::DataType> restored_dtypes;
+  std::vector<bool> truncate_in_cast;
+};
+
+struct AsyncState {
+  explicit AsyncState(
+      const std::vector<tensorflow::TensorValue>& input_tf_tensor_values,
+      const OpKernelContext::Params& params, int num_outputs,
+      const tensorflow::DeviceMgr& device_manager,
+      const tensorflow::ProcessFunctionLibraryRuntime&
+          process_function_library_runtime)
+      : run_state(input_tf_tensor_values, params),
+        context(&run_state.params, num_outputs),
+        device_manager(device_manager),
+        process_function_library_runtime(process_function_library_runtime) {}
+
+  tfrt_stub::OpKernelRunState run_state;
+  OpKernelContext context;
+  const tensorflow::DeviceMgr& device_manager;
+  const tensorflow::ProcessFunctionLibraryRuntime&
+      process_function_library_runtime;
+
+  std::vector<xla::ifrt::Promise<tensorflow::Tensor>> results;
+};
+
+// Returns a casted tensor if successful.
+absl::StatusOr<tensorflow::Tensor> Cast(
+    tensorflow::Tensor& in_tensor, tensorflow::DataType restored_dtype,
+    tensorflow::DataType cast_dtype, bool truncate_in_cast,
+    const tensorflow::DeviceMgr& device_manager,
+    const tensorflow::ProcessFunctionLibraryRuntime&
+        process_function_library_runtime,
+    OpKernelContext::Params& params) {
+  auto runner =
+      tfrt_stub::OpKernelRunner::Create(
+          /*op_name=*/
+          "Cast", /*node_name=*/"Cast", params.device->name(),
+          /*num_args=*/1,
+          [&](tensorflow::AttrValueMap* attr_value_map) {
+            tensorflow::AttrValue restored_dtype_attr_value;
+            restored_dtype_attr_value.set_type(restored_dtype);
+            attr_value_map->insert({"SrcT", restored_dtype_attr_value});
+
+            tensorflow::AttrValue cast_dtype_attr_value;
+            cast_dtype_attr_value.set_type(cast_dtype);
+            attr_value_map->insert({"DstT", cast_dtype_attr_value});
+
+            tensorflow::AttrValue truncate_attr_value;
+            truncate_attr_value.set_b(truncate_in_cast);
+            attr_value_map->insert({"Truncate", truncate_attr_value});
+            return absl::OkStatus();
+          },
+          device_manager, process_function_library_runtime)
+          .value();
+
+  std::vector<tensorflow::TensorValue> input_tf_tensor_values;
+  input_tf_tensor_values.push_back(tensorflow::TensorValue(&in_tensor));
+
+  tf_mlrt::SetUpParams(runner, input_tf_tensor_values, params);
+  // Use persistent device instead of the per request device.
+
+  OpKernelContext op_kernel_context(&params, /*num_outputs=*/1);
+
+  runner.Run(&op_kernel_context);
+
+  if (!op_kernel_context.status().ok()) {
+    return op_kernel_context.status();
+  }
+  DCHECK_EQ(op_kernel_context.num_outputs(), 1);
+  return *(op_kernel_context.mutable_output(0));
+}
+
+absl::Status RunShard(RestoreVariableShard shard,
+                      IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry,
+                      tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue,
+                      tf_mlrt::Context& context) {
+  if (!ifrt_restore_tensor_registry) {
+    return absl::InternalError("ifrt_restore_tensor_registry must not be null");
+  }
+  if (!checkpoint_loader_work_queue) {
+    return absl::InternalError("checkpoint_loader_work_queue must not be null");
+  }
+  const int num_outputs = shard.var_handles.size();
+  DCHECK_EQ(num_outputs, shard.tensor_names.NumElements());
+  auto& fallback_request_state = context.fallback_request_state();
+
+  // Use `tf.RestoreV2` to restore tensors. This will also populate
+  // tensorflow::ResourceManager.
+  // TODO(b/319045348): avoid populating tensorflow::ResourceManager if the
+  // variable is only used by device/IFRT.
+  // TODO(b/319045348): consider directly calling a restore function, such as
+  // the one in /tensorflow/core/kernels/save_restore_v2_ops.cc.
+  auto runner =
+      tfrt_stub::OpKernelRunner::Create(
+          /*op_name=*/
+          "RestoreV2", /*node_name=*/"RestoreV2",
+          context.params().device->name(),
+          /*num_args=*/3,
+          [&](tensorflow::AttrValueMap* attr_value_map) {
+            attr_value_map->insert({"dtypes", shard.dtypes_attr_value});
+            return absl::OkStatus();
+          },
+          fallback_request_state.device_manager(),
+          fallback_request_state.process_function_library_runtime())
+          .value();
+
+  // Prepare the input tensors.
+  std::vector<tensorflow::TensorValue> input_tf_tensor_values;
+  static constexpr int kNumInputArgs = 3;
+  input_tf_tensor_values.resize(kNumInputArgs);
+  // We need to keep these tensors alive.
+  input_tf_tensor_values[0].tensor = &shard.prefix;
+  input_tf_tensor_values[1].tensor = &shard.tensor_names;
+  input_tf_tensor_values[2].tensor = &shard.shape_and_slices;
+
+  auto& params = context.params();
+  tf_mlrt::SetUpParams(runner, input_tf_tensor_values, params);
+  // Use persistent device instead of the per request device.
+  params.device = context.fallback_request_state().device_manager().HostCPU();
+
+  auto async_state = std::make_unique<AsyncState>(
+      input_tf_tensor_values, params, num_outputs,
+      fallback_request_state.device_manager(),
+      fallback_request_state.process_function_library_runtime());
+
+  for (int i = 0; i < num_outputs; ++i) {
+    auto promise = xla::ifrt::Future<tensorflow::Tensor>::CreatePromise();
+    auto future = xla::ifrt::Future<tensorflow::Tensor>(promise);
+    const ResourceHandle& var_handle =
+        shard.var_handles[i].tensor().scalar<tensorflow::ResourceHandle>()();
+
+    TF_ASSIGN_OR_RETURN(ifrt_serving::DtypeAndShape dtype_and_shape,
+                        ifrt_serving::GetDtypeAndShape(var_handle));
+
+    std::string runtime_name =
+        ifrt_serving::GetRuntimeNameFromVarHandle(var_handle);
+
+    ifrt_serving::IfrtRestoreTensorRegistry::RestoredTensorInfo
+        restored_tensor_info = {false, std::move(dtype_and_shape),
+                                std::move(future)};
+    if (auto status = ifrt_restore_tensor_registry->TryRegister(
+            runtime_name, restored_tensor_info);
+        !status.ok()) {
+      // Propagate errors so that if already-registered futures are being waited
+      // on, they can be unblocked.
+      for (auto& result : async_state->results) {
+        std::move(result).Set(status);
+      }
+      return status;
+    }
+    async_state->results.push_back(std::move(promise));
+  }
+
+  // Use dedicated work queue for restore operation.
+  checkpoint_loader_work_queue->AddTask([runner = std::move(runner),
+                                         async_state = std::move(async_state),
+                                         shard = std::move(shard)]() {
+    // Keep the input tensors alive in `shard`.
+    auto* op_kernel_context_ptr = &async_state->context;
+    runner.Run(op_kernel_context_ptr);
+
+    auto& op_kernel_context = async_state->context;
+    if (!op_kernel_context.status().ok()) {
+      for (auto& result : async_state->results) {
+        std::move(result).Set(op_kernel_context.status());
+      }
+      return;
+    }
+    DCHECK_EQ(shard.var_handles.size(), op_kernel_context.num_outputs());
+    DCHECK_EQ(shard.truncate_in_cast.size(), op_kernel_context.num_outputs());
+
+    // TODO(b/343964091): consider running multiple casts in parallel.
+    for (int i = 0; i < op_kernel_context.num_outputs(); ++i) {
+      DCHECK(op_kernel_context.mutable_output(i));
+
+      if (op_kernel_context.mutable_output(i)->dtype() !=
+          shard.restored_dtypes[i]) {
+        std::move(async_state->results[i])
+            .Set(absl::InvalidArgumentError(absl::StrCat(
+                "The restored tensor has a different dtype than the "
+                "variable handle: ",
+                op_kernel_context.mutable_output(i)->dtype(), " vs. ",
+                shard.restored_dtypes[i])));
+        return;
+      }
+      const ResourceHandle& var_handle =
+          shard.var_handles[i].tensor().scalar<tensorflow::ResourceHandle>()();
+
+      if (shard.restored_dtypes[i] == var_handle.dtypes_and_shapes()[0].dtype) {
+        std::move(async_state->results[i])
+            .Set(*std::move(op_kernel_context.mutable_output(i)));
+      } else {
+        absl::StatusOr<tensorflow::Tensor> cast_output =
+            Cast(*op_kernel_context.mutable_output(i), shard.restored_dtypes[i],
+                 var_handle.dtypes_and_shapes()[0].dtype,
+                 shard.truncate_in_cast[i], async_state->device_manager,
+                 async_state->process_function_library_runtime,
+                 async_state->run_state.params);
+        if (!cast_output.ok()) {
+          std::move(async_state->results[i]).Set(cast_output.status());
+        } else {
+          std::move(async_state->results[i]).Set(*std::move(cast_output));
+        }
+      }
+    }
+  });
+  return absl::OkStatus();
+}
+
+int64_t GetSizeFromVarHandle(const ResourceHandle& handle) {
+  int64_t size = 0;
+  for (auto& dtype_and_shape : handle.dtypes_and_shapes()) {
+    size += DataTypeSize(dtype_and_shape.dtype) *
+            dtype_and_shape.shape.num_elements();
+  }
+  return size;
+}
+
+}  // namespace
+
+absl::Status CheckpointLoader::PrepareRestore(
+    mlir::OwningOpRef<mlir::ModuleOp> module) {
+  VLOG(1) << "Skip CheckpointLoader::PrepareRestore";
+  return absl::OkStatus();
+}
+
+absl::Status CheckpointLoader::Load(
+    const tensorflow::tfrt_stub::FallbackTensor& prefix,
+    const std::vector<tensorflow::tfrt_stub::FallbackTensor>& var_handles,
+    const tensorflow::tfrt_stub::FallbackTensor& tensor_names,
+    const tensorflow::tfrt_stub::FallbackTensor& shape_and_slices,
+    const mlrt::bc::Vector<tensorflow::DataType>& restored_dtypes,
+    const mlrt::bc::Vector<bool>& truncate_in_cast, tf_mlrt::Context& context) {
+  std::vector<int64_t> variable_sizes;
+  variable_sizes.reserve(var_handles.size());
+  for (auto& handle : var_handles) {
+    variable_sizes.push_back(GetSizeFromVarHandle(
+        handle.tensor().scalar<tensorflow::ResourceHandle>()()));
+  }
+
+  std::vector<std::vector<int>> sharded_indices = tf_mlrt::ShardVariables(
+      kNumRestoreClusters, absl::MakeSpan(variable_sizes));
+
+  // Converts the names and slices back to tensors.
+  auto vector_to_tensor = [](const std::vector<tsl::tstring>& vec) {
+    tensorflow::Tensor tensor(tensorflow::DT_STRING,
+                              TensorShape({static_cast<int>(vec.size())}));
+    for (int i = 0; i < vec.size(); ++i) {
+      tensor.flat<tsl::tstring>()(i) = vec[i];
+    }
+    return tensor;
+  };
+
+  const auto& tensor_names_flat = tensor_names.tensor().flat<tsl::tstring>();
+  const auto& shape_and_slices_flat =
+      shape_and_slices.tensor().flat<tsl::tstring>();
+
+  std::vector<RestoreVariableShard> shards;
+  shards.reserve(sharded_indices.size());
+  for (auto& sharded_index : sharded_indices) {
+    RestoreVariableShard shard;
+    shard.var_handles.reserve(sharded_index.size());
+    shard.truncate_in_cast.reserve(sharded_index.size());
+    shard.restored_dtypes.reserve(sharded_index.size());
+    std::vector<tsl::tstring> tensor_names;
+    std::vector<tsl::tstring> shape_and_slices;
+    shape_and_slices.reserve(sharded_index.size());
+    tensor_names.reserve(sharded_index.size());
+    for (int index : sharded_index) {
+      tensor_names.push_back(tensor_names_flat(index));
+      shape_and_slices.push_back(shape_and_slices_flat(index));
+      shard.dtypes_attr_value.mutable_list()->add_type(restored_dtypes[index]);
+      shard.var_handles.push_back(var_handles[index]);
+      shard.restored_dtypes.push_back(restored_dtypes[index]);
+      shard.truncate_in_cast.push_back(truncate_in_cast[index]);
+    }
+    shard.prefix = prefix.tensor();
+    shard.tensor_names = vector_to_tensor(tensor_names);
+    shard.shape_and_slices = vector_to_tensor(shape_and_slices);
+    shards.push_back(std::move(shard));
+  }
+  for (const auto& shard : shards) {
+    TF_RETURN_IF_ERROR(RunShard(shard, ifrt_restore_tensor_registry_,
+                                checkpoint_loader_work_queue_, context));
+  }
+  return absl::OkStatus();
+}
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
diff --git a/tensorflow/core/tfrt/ifrt/checkpoint_loader.h b/tensorflow/core/tfrt/ifrt/checkpoint_loader.h
new file mode 100644
index 0000000..cd835d4
--- /dev/null
+++ b/tensorflow/core/tfrt/ifrt/checkpoint_loader.h
@@ -0,0 +1,66 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_
+#define TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_
+
+#include <vector>
+
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
+#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
+#include "tensorflow/core/tfrt/mlrt/kernel/context.h"
+#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
+#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+// TODO(b/352551302): Move the restore unit test in ifrt_ops_kernel so that it
+// tests this class's APIs.
+// Implements the `CheckpointLoaderInterface` using RestoreV2.
+class CheckpointLoader {
+ public:
+  explicit CheckpointLoader(
+      IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry,
+      tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue)
+      : ifrt_restore_tensor_registry_(ifrt_restore_tensor_registry),
+        checkpoint_loader_work_queue_(checkpoint_loader_work_queue) {}
+  virtual ~CheckpointLoader() = default;
+
+  // Called before `Load` to do some preparation work.
+  virtual absl::Status PrepareRestore(mlir::OwningOpRef<mlir::ModuleOp> module);
+
+  // Load the checkpoint. This API is designed to be compatible with the
+  // `tf_mlrt.ifrt_restore_variable` kernel.
+  virtual absl::Status Load(
+      const tensorflow::tfrt_stub::FallbackTensor& prefix,
+      const std::vector<tensorflow::tfrt_stub::FallbackTensor>& var_handles,
+      const tensorflow::tfrt_stub::FallbackTensor& tensor_names,
+      const tensorflow::tfrt_stub::FallbackTensor& shape_and_slices,
+      const mlrt::bc::Vector<tensorflow::DataType>& restored_dtypes,
+      const mlrt::bc::Vector<bool>& truncate_in_cast,
+      tf_mlrt::Context& context);
+
+  IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry_;
+  tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue_;
+};
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_
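
A brief usage sketch of the API declared above; it is illustrative only and not part of the patch. It assumes an `IfrtRestoreTensorRegistry` and a `tfrt::ConcurrentWorkQueue` are already available from the surrounding serving setup, and it exercises only the base-class `PrepareRestore`, which logs and returns OK.

#include <utility>

#include "absl/status/status.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime

namespace example {

absl::Status PrepareLoaderSketch(
    tensorflow::ifrt_serving::IfrtRestoreTensorRegistry* registry,
    tfrt::ConcurrentWorkQueue* restore_work_queue) {
  // The loader keeps non-owning pointers to the registry and the work queue.
  tensorflow::ifrt_serving::CheckpointLoader loader(registry,
                                                    restore_work_queue);

  // The base-class PrepareRestore is a hook for subclasses; here it only logs.
  mlir::MLIRContext mlir_context;
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(&mlir_context));
  return loader.PrepareRestore(std::move(module));
}

}  // namespace example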
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h
index e799c57..2c6a566 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h
+++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h
@@ -17,6 +17,7 @@
 #define TENSORFLOW_CORE_TFRT_IFRT_IFRT_LOADED_VARIABLE_REGISTRY_H_
 
 #include <string>
+#include <vector>
 
 #include "absl/base/thread_annotations.h"
 #include "absl/container/flat_hash_map.h"
@@ -25,6 +26,7 @@
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/synchronization/mutex.h"
+#include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/python/ifrt/array.h"
 #include "xla/python/ifrt/future.h"
 #include "xla/tsl/concurrency/ref_count.h"
@@ -38,19 +40,29 @@
   // The key is per variable tensor per device assignment. For single-device
   // program, variables can be loaded on multiple devices with core selection.
   // For SPMD program, we currently assume all devices will be used, so we use
-  // set to make it compatible with SPMD.
+  // vector to make it compatible with SPMD.
   struct Key {
-    // We use a set to make it compatible with SPMD.
-    absl::flat_hash_set<int> device_ids;
+    // We use a vector to make it compatible with SPMD because the order of the
+    // devices used for sharding must match the order of the devices used for
+    // xla compilation.
+    std::vector<int> device_ids;
     std::string input_name;
+    xla::HloSharding hlo_sharding;
     template <typename H>
     friend H AbslHashValue(H h, const Key& key) {
-      h = H::combine(std::move(h), key.input_name, key.device_ids);
+      h = H::combine(std::move(h), key.input_name, key.device_ids,
+                     key.hlo_sharding);
       return h;
     }
 
     friend bool operator==(const Key& x, const Key& y) {
-      return x.input_name == y.input_name && x.device_ids == y.device_ids;
+      return x.input_name == y.input_name && x.device_ids == y.device_ids &&
+             x.hlo_sharding == y.hlo_sharding;
+    }
+
+    std::string ToString() const {
+      return absl::StrCat(input_name, ":", absl::StrJoin(device_ids, ","), ":",
+                          hlo_sharding.ToString());
     }
   };
 
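Since the key now carries the sharding in both hashing and equality, a small illustrative sketch (not part of this change) may help; it assumes only the headers already used here plus the absl logging macros.

#include "absl/log/check.h"
#include "absl/log/log.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h"

void LoadedVariableKeySketch() {
  using Key = tensorflow::ifrt_serving::IfrtLoadedVariableRegistry::Key;
  Key a{.device_ids = {0, 1},
        .input_name = "var_x",
        .hlo_sharding = xla::HloSharding::Replicate()};
  Key b = a;
  CHECK(a == b);  // Same device order, name, and sharding compare equal.

  // Because device_ids is now an ordered vector, reordering the devices yields
  // a different key, matching the requirement that the sharding device order
  // agree with the XLA compilation order.
  b.device_ids = {1, 0};
  CHECK(!(a == b));

  LOG(INFO) << a.ToString();  // e.g. "var_x:0,1:{replicated}"
}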
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc
index 7a17a7e..0af5363 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc
+++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc
@@ -53,14 +53,10 @@
     std::shared_ptr<xla::ifrt::Client> ifrt_client,
     const tsl::thread::ThreadPool& thread_pool,
     const tensorflow::Tensor& variable,
-    const VariableDeviceShardingConfigProto& sharding_config) {
-  std::vector<int> device_ids{sharding_config.device_ids().begin(),
-                              sharding_config.device_ids().end()};
-  TF_ASSIGN_OR_RETURN(xla::HloSharding hlo_sharding,
-                      xla::HloSharding::FromProto(sharding_config.sharding()));
+    const VariableDeviceShardingConfig& sharding_config) {
   return tensorflow::ifrt_serving::MakeArrayFromTensor(
-      *ifrt_client, variable, sharding_config.device_ids(), hlo_sharding,
-      thread_pool);
+      *ifrt_client, variable, sharding_config.device_ids,
+      sharding_config.hlo_sharding, thread_pool);
 }
 
 }  // namespace
@@ -97,12 +93,11 @@
     const ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry,
     ifrt_serving::IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry,
     tfrt::ConcurrentWorkQueue* checkpoint_loader_queue,
-    const VariableDeviceShardingConfigProto& sharding_config) {
-  absl::flat_hash_set<int> device_ids{sharding_config.device_ids().begin(),
-                                      sharding_config.device_ids().end()};
+    const VariableDeviceShardingConfig& sharding_config) {
   IfrtLoadedVariableRegistry::Key loaded_variable_key{
-      .device_ids = std::move(device_ids),
+      .device_ids = sharding_config.device_ids,
       .input_name = std::string(runtime_name),
+      .hlo_sharding = sharding_config.hlo_sharding,
   };
   if (ifrt_loaded_variable_registry.GetLoadedVariable(loaded_variable_key)
           .ok()) {
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h
index 4d07d1a..6fea3a5 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h
+++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h
@@ -18,11 +18,13 @@
 
 #include <memory>
 #include <string>
+#include <vector>
 
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
+#include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/python/ifrt/client.h"
 #include "tensorflow/core/framework/resource_handle.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
@@ -37,6 +39,12 @@
 // An index to indicate a non per-core executable bundle cache.
 inline constexpr int kNoCoreSelectedIndex = -1;
 
+// TODO(b/352551302) Delete VariableDeviceShardingConfigProto.
+struct VariableDeviceShardingConfig {
+  std::vector<int> device_ids;
+  xla::HloSharding hlo_sharding;
+};
+
 absl::StatusOr<ifrt_serving::DtypeAndShape> GetDtypeAndShape(
     const ResourceHandle& resource_handle);
 
@@ -57,7 +65,7 @@
     const ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry,
     ifrt_serving::IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry,
     tfrt::ConcurrentWorkQueue* checkpoint_loader_queue,
-    const VariableDeviceShardingConfigProto& sharding_config);
+    const VariableDeviceShardingConfig& sharding_config);
 
 }  // namespace ifrt_serving
 }  // namespace tensorflow
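While the proto type still exists (see the TODO above), a hypothetical conversion helper along the following lines could centralize the migration; it simply mirrors the proto handling that this change removes from ifrt_loaded_variable_utils.cc, and the helper itself is not part of the patch.

#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h"
#include "tsl/platform/statusor.h"

namespace tensorflow {
namespace ifrt_serving {

// Hypothetical helper: build the new struct from the legacy proto.
absl::StatusOr<VariableDeviceShardingConfig> ShardingConfigFromProto(
    const VariableDeviceShardingConfigProto& proto) {
  std::vector<int> device_ids{proto.device_ids().begin(),
                              proto.device_ids().end()};
  TF_ASSIGN_OR_RETURN(xla::HloSharding hlo_sharding,
                      xla::HloSharding::FromProto(proto.sharding()));
  return VariableDeviceShardingConfig{
      .device_ids = std::move(device_ids),
      .hlo_sharding = std::move(hlo_sharding),
  };
}

}  // namespace ifrt_serving
}  // namespace tensorflow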
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc
index 4777d0a3c..fe8e988 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc
+++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc
@@ -23,12 +23,14 @@
 #include <gtest/gtest.h>
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/python/ifrt/array.h"
 #include "xla/python/ifrt/client.h"
 #include "xla/python/ifrt/device.h"
 #include "xla/python/ifrt/future.h"
 #include "xla/python/ifrt/test_util.h"
 #include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "xla/xla_data.pb.h"
 #include "tensorflow/core/framework/resource_handle.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -39,7 +41,6 @@
 #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/statusor.h"
@@ -77,8 +78,10 @@
   auto restore_work_queue = tfrt::CreateMultiThreadedWorkQueue(
       /*num_threads=*/4, /*num_blocking_threads=*/4);
 
-  VariableDeviceShardingConfigProto sharding_config;
-  sharding_config.add_device_ids(0);
+  VariableDeviceShardingConfig sharding_config = {
+      .device_ids = {0},
+      .hlo_sharding = xla::HloSharding::Replicate(),
+  };
 
   auto promise = xla::ifrt::Future<tensorflow::Tensor>::CreatePromise();
   auto future = xla::ifrt::Future<tensorflow::Tensor>(promise);
@@ -120,8 +123,10 @@
   auto restore_work_queue = tfrt::CreateMultiThreadedWorkQueue(
       /*num_threads=*/4, /*num_blocking_threads=*/4);
 
-  VariableDeviceShardingConfigProto sharding_config;
-  sharding_config.add_device_ids(0);
+  VariableDeviceShardingConfig sharding_config{
+      .device_ids = {0},
+      .hlo_sharding = xla::HloSharding::Replicate(),
+  };
 
   auto promise = xla::ifrt::Future<tensorflow::Tensor>::CreatePromise();
   auto future = xla::ifrt::Future<tensorflow::Tensor>(promise);
@@ -140,6 +145,7 @@
   IfrtLoadedVariableRegistry::Key key{
       .device_ids = {0},
       .input_name = "var_x",
+      .hlo_sharding = sharding_config.hlo_sharding,
   };
   TF_ASSERT_OK_AND_ASSIGN(auto v,
                           loaded_variable_registry.GetLoadedVariable(key));
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h b/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h
new file mode 100644
index 0000000..da9528e
--- /dev/null
+++ b/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h
@@ -0,0 +1,50 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_
+#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_
+
+#include <memory>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+inline constexpr absl::string_view kIfrtModelRestoreContextName =
+    "IfrtModelRestoreContext";
+
+// A resource context that holds the `CheckpointLoader` for a model. We need a
+// different context than `IfrtModelContext` because `IfrtModelContext` is too
+// large to be a dependency of other libraries.
+class IfrtModelRestoreContext {
+ public:
+  explicit IfrtModelRestoreContext(
+      std::unique_ptr<CheckpointLoader> checkpoint_loader)
+      : checkpoint_loader_(std::move(checkpoint_loader)) {}
+
+  CheckpointLoader* checkpoint_loader() const {
+    return checkpoint_loader_.get();
+  }
+
+ private:
+  std::unique_ptr<CheckpointLoader> checkpoint_loader_;
+};
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_
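
For orientation, a minimal wiring sketch follows; it is illustrative only, assumes the registry and work queue already exist in the serving setup, and nothing in it is taken from the patch beyond the constructors declared above.

#include <memory>

#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime

namespace example {

tensorflow::ifrt_serving::IfrtModelRestoreContext MakeRestoreContext(
    tensorflow::ifrt_serving::IfrtRestoreTensorRegistry* registry,
    tfrt::ConcurrentWorkQueue* restore_work_queue) {
  // The context takes ownership of the loader; a serving setup would typically
  // register the context under kIfrtModelRestoreContextName so that kernels
  // can later fetch the loader via checkpoint_loader().
  return tensorflow::ifrt_serving::IfrtModelRestoreContext(
      std::make_unique<tensorflow::ifrt_serving::CheckpointLoader>(
          registry, restore_work_queue));
}

}  // namespace example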
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry_test.cc
index de0a27a..3251962 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry_test.cc
+++ b/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry_test.cc
@@ -21,11 +21,11 @@
 #include "absl/status/status.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
 #include "xla/python/ifrt/future.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc
index b5e0770..a6d35bf 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc
+++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc
@@ -34,8 +34,12 @@
 #include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/types/span.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/OwningOpRef.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h"
@@ -63,6 +67,7 @@
 #include "xla/tsl/framework/serving_device_selector.h"
 #include "xla/xla_data.pb.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/example/feature.pb.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -70,6 +75,7 @@
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_device_utils.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
@@ -119,39 +125,61 @@
   return dtypes_and_shapes;
 }
 
-absl::StatusOr<xla::DeviceAssignment> GetXlaDeviceAssignment(
-    const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata) {
-  if (!compile_metadata.has_device_assignment()) {
-    return absl::InternalError("No device assignment found.");
+// Returns the device assignment from the given IFRT device list.
+absl::StatusOr<xla::DeviceAssignment> GetRuntimeXlaDeviceAssignment(
+    const xla::ifrt::DeviceList& devices, int num_replicas,
+    int num_cores_per_replica) {
+  const int num_devices = num_replicas * num_cores_per_replica;
+  if (devices.size() != num_devices) {
+    return absl::InternalError(
+        absl::StrCat("Device assignment has ", devices.size(),
+                     " devices, but expected ", num_devices));
   }
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<xla::DeviceAssignment> da,
-      xla::DeviceAssignment::Deserialize(compile_metadata.device_assignment()));
-  return *da;
-}
-
-absl::StatusOr<std::vector<xla::ifrt::Device*>> GetAssignedDevices(
-    const xla::ifrt::Client& ifrt_client,
-    const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata) {
-  TF_ASSIGN_OR_RETURN(auto device_assignment,
-                      GetXlaDeviceAssignment(compile_metadata));
-  const int num_devices =
-      device_assignment.replica_count() * device_assignment.computation_count();
-  std::vector<xla::ifrt::Device*> devices;
-  devices.reserve(num_devices);
-  for (int replica_idx = 0; replica_idx < device_assignment.replica_count();
-       replica_idx++) {
-    for (int computation_idx = 0;
-         computation_idx < device_assignment.computation_count();
-         computation_idx++) {
-      auto device_id = device_assignment(replica_idx, computation_idx);
-      TF_ASSIGN_OR_RETURN(
-          xla::ifrt::Device * device,
-          ifrt_client.LookupDevice(xla::ifrt::DeviceId(device_id)));
-      devices.push_back(device);
+  xla::DeviceAssignment da(num_replicas, num_cores_per_replica);
+  int device_index = 0;
+  for (int replica_idx = 0; replica_idx < num_replicas; replica_idx++) {
+    for (int core_idx = 0; core_idx < num_cores_per_replica;
+         core_idx++, device_index++) {
+      da(replica_idx, core_idx) = devices[device_index]->Id().value();
+      VLOG(3) << "Added IFRT device id: " << da(replica_idx, core_idx);
     }
   }
-  return devices;
+  return da;
+}
+
+static constexpr absl::string_view kDeviceAssignmentAttr = "device_assignment";
+static constexpr absl::string_view kEntryFuncName = "main";
+
+absl::StatusOr<std::vector<xla::ifrt::Device*>> GetAssignedDevices(
+    mlir::ModuleOp module, const xla::ifrt::Client& ifrt_client,
+    int num_replicas, int num_cores_per_replica) {
+  auto op = module.lookupSymbol<mlir::func::FuncOp>(kEntryFuncName);
+  if (!op) {
+    return absl::InternalError("Could not find entry function in MLIR Module.");
+  }
+
+  auto device_assignment_attr =
+      op->getAttrOfType<mlir::ArrayAttr>(kDeviceAssignmentAttr);
+  std::optional<std::vector<int>> device_assignment_attr_val;
+
+  if (device_assignment_attr && !device_assignment_attr.getValue().empty()) {
+    std::vector<int> coords;
+    coords.reserve(num_replicas * num_cores_per_replica);
+    for (auto coord_attr : device_assignment_attr.getValue()) {
+      auto coord_attr_val = mlir::dyn_cast<mlir::IntegerAttr>(coord_attr);
+      if (!coord_attr_val) {
+        return absl::InternalError(
+            llvm::formatv("Device assignment attribute is not an integer: {0}",
+                          device_assignment_attr)
+                .str());
+      }
+      coords.push_back(coord_attr_val.getInt());
+    }
+    device_assignment_attr_val = std::move(coords);
+  }
+  return GetAssignedIfrtDevices(ifrt_client, num_replicas,
+                                num_cores_per_replica,
+                                device_assignment_attr_val);
 }
 
 }  // namespace
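To make the row-major fill above concrete, here is a small standalone sketch (not part of the patch) for num_replicas = 2 and num_cores_per_replica = 2 with arbitrarily chosen device ids.

#include "xla/service/computation_placer.h"

void DeviceAssignmentLayoutSketch() {
  // Mirrors the fill loop above with flat device ids {10, 11, 12, 13}.
  xla::DeviceAssignment da(/*replica_count=*/2, /*computation_count=*/2);
  const int ids[] = {10, 11, 12, 13};
  int device_index = 0;
  for (int replica = 0; replica < 2; ++replica) {
    for (int core = 0; core < 2; ++core, ++device_index) {
      da(replica, core) = ids[device_index];
    }
  }
  // Result: da(0, 0) == 10, da(0, 1) == 11, da(1, 0) == 12, da(1, 1) == 13.
}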
@@ -173,12 +201,21 @@
       tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata,
       GetCompileMetadata(*module, *client));
 
+  TF_ASSIGN_OR_RETURN(
+      std::vector<xla::ifrt::Device*> assigned_devices,
+      GetAssignedDevices(*module, *client,
+                         original_compile_metadata.num_replicas(),
+                         original_compile_metadata.num_cores_per_replica()));
+
   auto executable = absl::WrapUnique(new IfrtServingExecutable(
       program_id, model_name, signature_name, std::move(module),
       std::move(client), thread_pool, ifrt_loaded_variable_registry,
       ifrt_restore, checkpoint_loader_queue, device_mgr,
       std::move(shape_representation_fn), ifrt_serving_core_selector,
-      std::move(original_compile_metadata), compilation_environement_proto));
+      std::move(original_compile_metadata),
+      xla::ifrt::DeviceList(xla::ifrt::DeviceList::Devices(
+          assigned_devices.begin(), assigned_devices.end())),
+      compilation_environement_proto));
 
   return executable;
 }
@@ -367,14 +404,17 @@
   xla_compile_options.executable_build_options.set_num_partitions(
       num_partitions);
 
-  xla_compile_options.executable_build_options.set_use_spmd_partitioning(true);
+  xla_compile_options.executable_build_options.set_use_spmd_partitioning(
+      original_compile_metadata_.use_spmd_for_xla_partitioning());
   xla_compile_options.parameter_is_tupled_arguments = false;
   // Use portable execution for single device + core selection.
   if (UsePortableExecution(compile_metadata)) {
     xla_compile_options.compile_portable_executable = true;
   } else {
-    TF_ASSIGN_OR_RETURN(xla::DeviceAssignment da,
-                        GetXlaDeviceAssignment(tf2hlo_result.compile_metadata));
+    TF_ASSIGN_OR_RETURN(
+        xla::DeviceAssignment da,
+        GetRuntimeXlaDeviceAssignment(assigned_device_list_, num_replicas,
+                                      num_partitions));
     VLOG(2) << "Device assignment :" << da.ToString();
     xla_compile_options.executable_build_options.set_device_assignment(da);
   }
@@ -516,7 +556,7 @@
 
   // `device_reservation` should be alive before the end of the execution.
   tsl::DeviceReservation device_reservation(kNoCoreSelectedIndex, nullptr);
-  std::vector<xla ::ifrt::Device*> devices;
+  xla::ifrt::DeviceList device_list;
   if (UsePortableExecution(compile_metadata)) {
     device_reservation =
         ifrt_serving_core_selector_->ReserveDevice(program_id_);
@@ -526,19 +566,16 @@
     TF_ASSIGN_OR_RETURN(xla::ifrt::Device * device,
                         ifrt_client_->LookupDevice(xla::ifrt::DeviceId(
                             device_reservation.device_index())));
-    devices.push_back(device);
+    device_list =
+        xla::ifrt::DeviceList(xla::ifrt::DeviceList::Devices({device}));
   } else {
-    TF_ASSIGN_OR_RETURN(devices,
-                        GetAssignedDevices(*ifrt_client_, compile_metadata));
+    device_list = assigned_device_list_;
   }
   TF_ASSIGN_OR_RETURN(SharedCachedExecutableBundle executable_bundle,
                       LookUpOrCreateExecutable(
                           compile_metadata, absl::MakeSpan(dtypes_and_shapes))
                           .Await());
 
-  xla::ifrt::DeviceList device_list(
-      xla::ifrt::DeviceList::Devices(devices.begin(), devices.end()));
-
   if (executable_bundle->compile_metadata.args().size() !=
       dtypes_and_shapes.size()) {
     return absl::InternalError(absl::StrCat(
@@ -548,7 +585,7 @@
 
   // Asynchronously load the restored variable tensors to Ifrt array.
   TF_RETURN_IF_ERROR(AsyncLoadIfrtArray(inputs, variable_arg_indices,
-                                        *executable_bundle, devices));
+                                        *executable_bundle, device_list));
 
   std::vector<tsl::RCReference<xla::ifrt::Array>> args;
   args.reserve(inputs.size());
@@ -556,13 +593,19 @@
   for (int i = 0; i < inputs.size(); i++) {
     if (variable_index < variable_arg_indices.size() &&
         i == variable_arg_indices[variable_index]) {
-      absl::flat_hash_set<int> device_ids;
-      for (const auto& device : devices) {
-        device_ids.insert(device->Id().value());
+      std::vector<int> device_ids;
+      device_ids.reserve(device_list.size());
+      for (const auto& device : device_list) {
+        device_ids.push_back(device->Id().value());
       }
+      TF_ASSIGN_OR_RETURN(
+          xla::HloSharding hlo_sharding,
+          xla::HloSharding::FromProto(
+              executable_bundle->compile_metadata.args()[i].sharding()));
       IfrtLoadedVariableRegistry::Key key{
           .device_ids = std::move(device_ids),
           .input_name = inputs[i].scalar<tsl::tstring>()(),
+          .hlo_sharding = std::move(hlo_sharding),
       };
       TF_ASSIGN_OR_RETURN(
           auto loaded_variable,
@@ -640,7 +683,7 @@
     absl::Span<const tensorflow::Tensor> inputs,
     absl::Span<const int> variable_arg_indices,
     const CachedExecutableBundle& executable_bundle,
-    const std::vector<xla::ifrt::Device*>& devices) {
+    const xla::ifrt::DeviceList& devices) {
   for (const int i : variable_arg_indices) {
     if (inputs[i].dtype() != tensorflow::DT_STRING ||
         !tensorflow::TensorShapeUtils::IsScalar(inputs[i].shape())) {
@@ -652,11 +695,15 @@
     }
     std::string runtime_name = inputs[i].scalar<tsl::tstring>()();
     // TODO(b/339521818): Add test cases for OpSharding on variables.
-    VariableDeviceShardingConfigProto sharding_config;
-    *sharding_config.mutable_sharding() =
-        executable_bundle.compile_metadata.args()[i].sharding();
+    TF_ASSIGN_OR_RETURN(
+        xla::HloSharding hlo_sharding,
+        xla::HloSharding::FromProto(
+            executable_bundle.compile_metadata.args()[i].sharding()));
+    VariableDeviceShardingConfig sharding_config{
+        .hlo_sharding = std::move(hlo_sharding),
+    };
     for (const auto& device : devices) {
-      sharding_config.add_device_ids(device->Id().value());
+      sharding_config.device_ids.push_back(device->Id().value());
     }
 
     TF_RETURN_IF_ERROR(
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h
index f3983b6..9dfc225 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h
+++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h
@@ -19,7 +19,6 @@
 #include <algorithm>
 #include <cstdint>
 #include <memory>
-#include <optional>
 #include <string>
 #include <utility>
 #include <vector>
@@ -144,12 +143,14 @@
       tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
       IfrtServingCoreSelector* ifrt_serving_core_selector,
       tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata,
+      xla::ifrt::DeviceList assigned_device_list,
       tsl::protobuf::Message* compilation_environment_proto)
       : program_id_(program_id),
         model_name_(std::string(model_name)),
         signature_name_(std::string(signature_name)),
         module_(std::move(module)),
         original_compile_metadata_(std::move(original_compile_metadata)),
+        assigned_device_list_(std::move(assigned_device_list)),
         ifrt_client_(std::move(client)),
         thread_pool_(*thread_pool),
         ifrt_loaded_variable_registry_(*ifrt_loaded_variable_registry),
@@ -168,9 +169,10 @@
 
   mlir::OwningOpRef<mlir::ModuleOp> module_ ABSL_GUARDED_BY(mutex_);
   // The original compile metadata. We need to keep it around to be able to
-  // test portable execution condition even if the Module itsel is already
+  // test portable execution condition even if the Module itself is already
   // released.
   tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata_;
+  const xla::ifrt::DeviceList assigned_device_list_;
 
   std::shared_ptr<xla::ifrt::Client> ifrt_client_;
   tsl::thread::ThreadPool& thread_pool_;
@@ -196,7 +198,7 @@
       absl::Span<const tensorflow::Tensor> inputs,
       absl::Span<const int> variable_arg_indices,
       const CachedExecutableBundle& executable_bundle,
-      const std::vector<xla::ifrt::Device*>& devices);
+      const xla::ifrt::DeviceList& devices);
 
   absl::StatusOr<tsl::RCReference<xla::ifrt::Array>> ConvertTensorToArray(
       const tensorflow::Tensor& tensor,
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc
index c0bdef1..28d3efb 100644
--- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc
+++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc
@@ -33,6 +33,7 @@
 #include "xla/python/ifrt/test_util.h"
 #include "xla/tsl/framework/serving_device_selector.h"
 #include "xla/tsl/framework/test_util/mock_serving_device_selector.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_matcher.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -41,7 +42,6 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/tstring.h"
 
diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc
index 7ec27bb..77043fc 100644
--- a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc
+++ b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc
@@ -34,13 +34,13 @@
 #include "xla/python/ifrt/test_util.h"
 #include "xla/python/pjrt_ifrt/xla_sharding.h"
 #include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "xla/xla_data.pb.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_matcher.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/ml_dtypes.h"
 #include "tsl/platform/status_matchers.h"
diff --git a/tensorflow/core/tfrt/kernels/BUILD b/tensorflow/core/tfrt/kernels/BUILD
index 9716f4b..817b848 100644
--- a/tensorflow/core/tfrt/kernels/BUILD
+++ b/tensorflow/core/tfrt/kernels/BUILD
@@ -62,7 +62,6 @@
         "@com_google_absl//absl/strings:string_view",
         "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:statusor",
         "@local_xla//xla/pjrt/cpu:cpu_client",
@@ -72,6 +71,7 @@
         "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib",
         "@local_xla//xla/tsl/framework:serving_device_selector",
         "@local_xla//xla/tsl/framework/test_util:mock_serving_device_selector",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc b/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc
index 3ae4d09..cd29511 100644
--- a/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc
+++ b/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc
@@ -32,6 +32,7 @@
 #include "xla/python/pjrt_ifrt/pjrt_client.h"
 #include "xla/tsl/framework/serving_device_selector.h"
 #include "xla/tsl/framework/test_util/mock_serving_device_selector.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/fake_input.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -45,7 +46,6 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/tensorflow/core/tfrt/mlrt/interpreter/BUILD b/tensorflow/core/tfrt/mlrt/interpreter/BUILD
index 552959b..0b1eee7 100644
--- a/tensorflow/core/tfrt/mlrt/interpreter/BUILD
+++ b/tensorflow/core/tfrt/mlrt/interpreter/BUILD
@@ -189,9 +189,9 @@
         "@com_google_absl//absl/types:span",
         "@com_google_benchmark//:benchmark",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:test_benchmark",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
         "@tf_runtime//:hostcontext",
     ],
 )
diff --git a/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc b/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc
index 0434f3c..97982e7 100644
--- a/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc
+++ b/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc
@@ -29,6 +29,7 @@
 #include "absl/synchronization/notification.h"
 #include "absl/types/span.h"
 #include "benchmark/benchmark.h"  // from @com_google_benchmark
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
 #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h"
 #include "tensorflow/core/tfrt/mlrt/interpreter/async_handle.h"
@@ -39,7 +40,6 @@
 #include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h"
 #include "tensorflow/core/tfrt/mlrt/interpreter/register_span.h"
 #include "tensorflow/core/tfrt/mlrt/interpreter/value.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/test_benchmark.h"
 #include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
diff --git a/tensorflow/core/tfrt/mlrt/kernel/BUILD b/tensorflow/core/tfrt/mlrt/kernel/BUILD
index 9da1e45..5e377f8 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/BUILD
+++ b/tensorflow/core/tfrt/mlrt/kernel/BUILD
@@ -10,6 +10,7 @@
         # copybara:uncomment "//learning/brain/tfrt:__subpackages__",
         # copybara:uncomment "//learning/serving/servables/tfrt:__subpackages__",
         "//tensorflow/core/tfrt/graph_executor:__subpackages__",
+        "//tensorflow/core/tfrt/ifrt:__subpackages__",
         "//tensorflow/core/tfrt/saved_model:__subpackages__",
         "//tensorflow/core/tfrt/tfrt_session:__subpackages__",
     ],
@@ -67,19 +68,16 @@
     deps = [
         ":context",
         ":kernel",
-        ":kernel_runner_utils",
-        ":shard_restore_util",
-        "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
-        "//tensorflow/core/common_runtime:function",
         "//tensorflow/core/framework:attr_value_proto_cc",
         "//tensorflow/core/framework:types_proto_cc",
         "//tensorflow/core/platform:protobuf",
-        "//tensorflow/core/tfrt/fallback:op_kernel_runner",
+        "//tensorflow/core/tfrt/ifrt:checkpoint_loader",
         "//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc",
         "//tensorflow/core/tfrt/ifrt:ifrt_loaded_variable_utils",
         "//tensorflow/core/tfrt/ifrt:ifrt_model_context",
+        "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context",
         "//tensorflow/core/tfrt/ifrt:ifrt_restore_tensor_registry",
         "//tensorflow/core/tfrt/mlrt/bytecode",
         "//tensorflow/core/tfrt/mlrt/interpreter:context",
@@ -89,13 +87,10 @@
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:tstring",
         "@local_xla//xla:xla_data_proto_cc",
         "@local_xla//xla/python/ifrt",
-        "@tf_runtime//:hostcontext",
     ],
     alwayslink = 1,
 )
@@ -178,8 +173,8 @@
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:status_matchers",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
         "@tf_runtime//:hostcontext",
         "@tf_runtime//:ref_count",
     ],
@@ -210,9 +205,11 @@
         "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
         "//tensorflow/core/tfrt/fallback:fallback_state",
         "//tensorflow/core/tfrt/fallback:op_kernel_runner",
+        "//tensorflow/core/tfrt/ifrt:checkpoint_loader",
         "//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc",
         "//tensorflow/core/tfrt/ifrt:ifrt_loaded_variable_registry",
         "//tensorflow/core/tfrt/ifrt:ifrt_model_context",
+        "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context",
         "//tensorflow/core/tfrt/ifrt:ifrt_restore_tensor_registry",
         "//tensorflow/core/tfrt/ifrt:ifrt_serving_core_selector",
         "//tensorflow/core/tfrt/mlrt/bytecode",
@@ -230,7 +227,7 @@
         "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest_main",
         "@eigen_archive//:eigen3",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@llvm-project//mlir:IR",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:refcount",
         "@local_tsl//tsl/platform:status",
@@ -242,6 +239,7 @@
         "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib",
         "@local_xla//xla/tsl/framework:serving_device_selector",
         "@local_xla//xla/tsl/framework/test_util:mock_serving_device_selector",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
         "@tf_runtime//:hostcontext",
     ],
 )
diff --git a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc
index 9d30f31..5a58488 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc
@@ -234,7 +234,9 @@
             options.num_batch_threads, options.max_batch_size,
             options.batch_timeout_micros, options.max_enqueued_batches,
             options.allowed_batch_sizes, enable_large_batch_splitting,
-            disable_padding, options.low_priority_max_batch_size,
+            disable_padding,
+            /* batch_padding_policy= */ options.batch_padding_policy,
+            options.low_priority_max_batch_size,
             options.low_priority_batch_timeout_micros,
             options.low_priority_max_enqueued_batches,
             options.low_priority_allowed_batch_sizes,
diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc
index ca9dd22..e5c7dbd 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc
@@ -25,39 +25,31 @@
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
 #include "xla/python/ifrt/future.h"
 #include "xla/xla_data.pb.h"
-#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/device_base.h"
-#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_handle.h"
 #include "tensorflow/core/framework/resource_var.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/protobuf.h"  // IWYU pragma: keep
-#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
+#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
 #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
 #include "tensorflow/core/tfrt/mlrt/interpreter/context.h"
 #include "tensorflow/core/tfrt/mlrt/interpreter/future.h"
 #include "tensorflow/core/tfrt/mlrt/kernel/context.h"
 #include "tensorflow/core/tfrt/mlrt/kernel/kernel.h"
-#include "tensorflow/core/tfrt/mlrt/kernel/kernel_runner_utils.h"
-#include "tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.h"
 #include "tensorflow/core/tfrt/utils/fallback_tensor.h"
 #include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
 #include "tsl/platform/tstring.h"
-#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
 
 using tensorflow::ifrt_serving::IfrtModelContext;
 
@@ -65,14 +57,6 @@
 namespace tf_mlrt {
 
 namespace {
-int64_t GetSizeFromVarHandle(const ResourceHandle& handle) {
-  int size = 0;
-  for (auto& dtype_and_shape : handle.dtypes_and_shapes()) {
-    size += DataTypeSize(dtype_and_shape.dtype) *
-            dtype_and_shape.shape.num_elements();
-  }
-  return size;
-}
 
 struct MlrtIfrtRestoreVariableKernel : mlrt::KernelFrame {
   using KernelFrame::KernelFrame;
@@ -119,20 +103,8 @@
   // dynamically decide it based on the size of the variables.
   static constexpr int kNumRestoreClusters = 4;
 
-  // A shard of variables to be restored.
-  struct RestoreVariableShard {
-    tensorflow::Tensor prefix;
-    tensorflow::Tensor tensor_names;
-    tensorflow::Tensor shape_and_slices;
-    std::vector<tensorflow::tfrt_stub::FallbackTensor> var_handles;
-    tensorflow::AttrValue dtypes_attr_value;
-    std::vector<tensorflow::DataType> restored_dtypes;
-    std::vector<bool> truncate_in_cast;
-  };
-
   absl::Status InvokeHelper();
 
-  absl::Status RunShard(RestoreVariableShard shard);
   absl::Status ValidateInput();
 };
 
@@ -144,218 +116,6 @@
   }
 }
 
-// Returns a casted tensor if successful.
-absl::StatusOr<tensorflow::Tensor> Cast(
-    tensorflow::Tensor& in_tensor, tensorflow::DataType restored_dtype,
-    tensorflow::DataType cast_dtype, bool truncate_in_cast,
-    const tensorflow::DeviceMgr& device_manager,
-    const tensorflow::ProcessFunctionLibraryRuntime&
-        process_function_library_runtime,
-    OpKernelContext::Params& params) {
-  auto runner =
-      tfrt_stub::OpKernelRunner::Create(
-          /*op_name=*/
-          "Cast", /*node_name=*/"Cast", params.device->name(),
-          /*num_args=*/1,
-          [&](tensorflow::AttrValueMap* attr_value_map) {
-            tensorflow::AttrValue restored_dtype_attr_value;
-            restored_dtype_attr_value.set_type(restored_dtype);
-            attr_value_map->insert({"SrcT", restored_dtype_attr_value});
-
-            tensorflow::AttrValue cast_dtype_attr_value;
-            cast_dtype_attr_value.set_type(cast_dtype);
-            attr_value_map->insert({"DstT", cast_dtype_attr_value});
-
-            tensorflow::AttrValue truncate_attr_value;
-            truncate_attr_value.set_b(truncate_in_cast);
-            attr_value_map->insert({"Truncate", truncate_attr_value});
-            return absl::OkStatus();
-          },
-          device_manager, process_function_library_runtime)
-          .value();
-
-  std::vector<tensorflow::TensorValue> input_tf_tensor_values;
-  input_tf_tensor_values.push_back(tensorflow::TensorValue(&in_tensor));
-
-  SetUpParams(runner, input_tf_tensor_values, params);
-  // Use persistent device instead of the per request device.
-
-  OpKernelContext op_kernel_context(&params, /*num_outputs=*/1);
-
-  runner.Run(&op_kernel_context);
-
-  if (!op_kernel_context.status().ok()) {
-    return op_kernel_context.status();
-  }
-  DCHECK_EQ(op_kernel_context.num_outputs(), 1);
-  return *(op_kernel_context.mutable_output(0));
-}
-
-absl::Status MlrtIfrtRestoreVariableKernel::RunShard(
-    RestoreVariableShard shard) {
-  std::optional<IfrtModelContext*> ifrt_model_context =
-      context().resource_context().GetResource<IfrtModelContext>(
-          "IfrtModelContext");
-  if (!ifrt_model_context.has_value()) {
-    return absl::FailedPreconditionError(
-        "RestoreVariableOp: failed to fetch IfrtModelContext");
-  }
-  const int num_outputs = shard.var_handles.size();
-  DCHECK_EQ(num_outputs, shard.tensor_names.NumElements());
-  auto& fallback_request_state = context().fallback_request_state();
-
-  // Use `tf.RestoreV2` to restore tensor. This will also populate
-  // tensorflow::ResourceManager.
-  // TODO(b/319045348): avoid populating tensorflow::ResourceManager if the
-  // variable is only used by device/IFRT.
-  // TODO(b/319045348): consider directly calling restore function such as that
-  // in /tensorflow/core/kernels/save_restore_v2_ops.cc
-  auto runner =
-      tfrt_stub::OpKernelRunner::Create(
-          /*op_name=*/
-          "RestoreV2", /*node_name=*/"RestoreV2",
-          context().params().device->name(),
-          /*num_args=*/3,
-          [&](tensorflow::AttrValueMap* attr_value_map) {
-            attr_value_map->insert({"dtypes", shard.dtypes_attr_value});
-            return absl::OkStatus();
-          },
-          fallback_request_state.device_manager(),
-          fallback_request_state.process_function_library_runtime())
-          .value();
-
-  // Prepare the input tensors.
-  std::vector<tensorflow::TensorValue> input_tf_tensor_values;
-  static constexpr int kNumInputArgs = 3;
-  input_tf_tensor_values.resize(kNumInputArgs);
-  // We need to keep these tensor alive
-  input_tf_tensor_values[0].tensor = &shard.prefix;
-  input_tf_tensor_values[1].tensor = &shard.tensor_names;
-  input_tf_tensor_values[2].tensor = &shard.shape_and_slices;
-
-  auto& params = context().params();
-  SetUpParams(runner, input_tf_tensor_values, params);
-  // Use persistent device instead of the per request device.
-  params.device = context().fallback_request_state().device_manager().HostCPU();
-
-  struct AsyncState {
-    explicit AsyncState(
-        const std::vector<tensorflow::TensorValue>& input_tf_tensor_values,
-        const OpKernelContext::Params& params, int num_outputs,
-        const tensorflow::DeviceMgr& device_manager,
-        const tensorflow::ProcessFunctionLibraryRuntime&
-            process_function_library_runtime)
-        : run_state(input_tf_tensor_values, params),
-          context(&run_state.params, num_outputs),
-          device_manager(device_manager),
-          process_function_library_runtime(process_function_library_runtime) {}
-
-    tfrt_stub::OpKernelRunState run_state;
-    OpKernelContext context;
-    const tensorflow::DeviceMgr& device_manager;
-    const tensorflow::ProcessFunctionLibraryRuntime&
-        process_function_library_runtime;
-
-    std::vector<xla::ifrt::Promise<tensorflow::Tensor>> results;
-  };
-  auto async_state = std::make_unique<AsyncState>(
-      input_tf_tensor_values, params, num_outputs,
-      fallback_request_state.device_manager(),
-      fallback_request_state.process_function_library_runtime());
-
-  ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry =
-      (*ifrt_model_context)->GetRestoreTensorRegistry();
-  for (int i = 0; i < num_outputs; ++i) {
-    auto promise = xla::ifrt::Future<tensorflow::Tensor>::CreatePromise();
-    auto future = xla::ifrt::Future<tensorflow::Tensor>(promise);
-    const ResourceHandle& var_handle =
-        shard.var_handles[i].tensor().scalar<tensorflow::ResourceHandle>()();
-
-    TF_ASSIGN_OR_RETURN(ifrt_serving::DtypeAndShape dtype_and_shape,
-                        ifrt_serving::GetDtypeAndShape(var_handle));
-
-    std::string runtime_name =
-        ifrt_serving::GetRuntimeNameFromVarHandle(var_handle);
-
-    ifrt_serving::IfrtRestoreTensorRegistry::RestoredTensorInfo
-        restored_tensor_info = {false, std::move(dtype_and_shape),
-                                std::move(future)};
-    if (auto status = ifrt_restore_tensor_registry.TryRegister(
-            runtime_name, restored_tensor_info);
-        !status.ok()) {
-      // Propagate errors so that if already-registered futures are being waited
-      // on, they can be unblocked.
-      for (auto& result : async_state->results) {
-        std::move(result).Set(status);
-      };
-      return status;
-    }
-    async_state->results.push_back(std::move(promise));
-  }
-
-  // Use dedicated work queue for restore operation.
-  DCHECK((*ifrt_model_context)->checkpoint_loader_queue() != nullptr);
-  (*ifrt_model_context)
-      ->checkpoint_loader_queue()
-      ->AddTask([runner = std::move(runner),
-                 async_state = std::move(async_state),
-                 shard = std::move(shard)]() {
-        // Keep input tensor alive in `shard`.
-        auto* op_kernel_context_ptr = &async_state->context;
-        runner.Run(op_kernel_context_ptr);
-
-        auto& op_kernel_context = async_state->context;
-        if (!op_kernel_context.status().ok()) {
-          for (auto& result : async_state->results) {
-            std::move(result).Set(op_kernel_context.status());
-          }
-          return;
-        }
-        DCHECK_EQ(shard.var_handles.size(), op_kernel_context.num_outputs());
-        DCHECK_EQ(shard.truncate_in_cast.size(),
-                  op_kernel_context.num_outputs());
-
-        // TODO(b/343964091): consider to run multiple casts in parallel.
-        for (int i = 0; i < op_kernel_context.num_outputs(); ++i) {
-          DCHECK(op_kernel_context.mutable_output(i));
-
-          if (op_kernel_context.mutable_output(i)->dtype() !=
-              shard.restored_dtypes[i]) {
-            std::move(async_state->results[i])
-                .Set(absl::InvalidArgumentError(absl::StrCat(
-                    "The restored tensor has a different dtype than the "
-                    "variable handle: ",
-                    op_kernel_context.mutable_output(i)->dtype(), " vs. ",
-                    shard.restored_dtypes[i])));
-            return;
-          }
-          const ResourceHandle& var_handle =
-              shard.var_handles[i]
-                  .tensor()
-                  .scalar<tensorflow::ResourceHandle>()();
-
-          if (shard.restored_dtypes[i] ==
-              var_handle.dtypes_and_shapes()[0].dtype) {
-            std::move(async_state->results[i])
-                .Set(*std::move(op_kernel_context.mutable_output(i)));
-          } else {
-            absl::StatusOr<tensorflow::Tensor> cast_output = Cast(
-                *op_kernel_context.mutable_output(i), shard.restored_dtypes[i],
-                var_handle.dtypes_and_shapes()[0].dtype,
-                shard.truncate_in_cast[i], async_state->device_manager,
-                async_state->process_function_library_runtime,
-                async_state->run_state.params);
-            if (!cast_output.ok()) {
-              std::move(async_state->results[i]).Set(cast_output.status());
-            } else {
-              std::move(async_state->results[i]).Set(*std::move(cast_output));
-            }
-          }
-        }
-      });
-  return absl::OkStatus();
-}
-
 absl::Status MlrtIfrtRestoreVariableKernel::ValidateInput() {
   if (prefix().tensor().NumElements() != 1) {
     return absl::InvalidArgumentError(
@@ -398,65 +158,26 @@
 }
 
 absl::Status MlrtIfrtRestoreVariableKernel::InvokeHelper() {
-  TF_RETURN_IF_ERROR(ValidateInput());
-
-  std::vector<int64_t> variable_sizes;
-  variable_sizes.reserve(var_handles().size());
-  for (auto& handle : var_handles()) {
-    variable_sizes.push_back(GetSizeFromVarHandle(
-        handle.tensor().scalar<tensorflow::ResourceHandle>()()));
+  std::optional<ifrt_serving::IfrtModelRestoreContext*> model_restore_context =
+      context()
+          .resource_context()
+          .GetResource<ifrt_serving::IfrtModelRestoreContext>(
+              ifrt_serving::kIfrtModelRestoreContextName);
+  if (!model_restore_context.has_value()) {
+    return absl::InternalError(
+        "Did not find IfrtModelRestoreContext resource.");
   }
-
-  std::vector<std::vector<int>> sharded_indices =
-      ShardVariables(kNumRestoreClusters, absl::MakeSpan(variable_sizes));
-
-  // Converts the names and slices back to the tensor.
-  auto vector_to_tensor = [](const std::vector<tsl::tstring>& vec) {
-    tensorflow::Tensor tensor(tensorflow::DT_STRING,
-                              TensorShape({static_cast<int>(vec.size())}));
-    for (int i = 0; i < vec.size(); ++i) {
-      tensor.flat<tsl::tstring>()(i) = vec[i];
-    }
-    return tensor;
-  };
-
-  const auto& tensor_names_flat = tensor_names().tensor().flat<tsl::tstring>();
-  const auto& shape_and_slices_flat =
-      shape_and_slices().tensor().flat<tsl::tstring>();
-
-  std::vector<RestoreVariableShard> shards;
-  shards.reserve(sharded_indices.size());
-  for (auto& sharded_index : sharded_indices) {
-    RestoreVariableShard shard;
-    shard.var_handles.reserve(sharded_index.size());
-    shard.truncate_in_cast.reserve(sharded_index.size());
-    shard.restored_dtypes.reserve(sharded_index.size());
-
-    std::vector<tsl::tstring> tensor_names;
-    std::vector<tsl::tstring> shape_and_slices;
-    shape_and_slices.reserve(sharded_index.size());
-    tensor_names.reserve(sharded_index.size());
-    for (int index : sharded_index) {
-      tensor_names.push_back(tensor_names_flat(index));
-      shape_and_slices.push_back(shape_and_slices_flat(index));
-      shard.dtypes_attr_value.mutable_list()->add_type(
-          restored_dtypes()[index]);
-
-      shard.var_handles.push_back(var_handles()[index]);
-      shard.restored_dtypes.push_back(restored_dtypes()[index]);
-      shard.truncate_in_cast.push_back(truncate_in_cast()[index]);
-    }
-
-    shard.prefix = prefix().tensor();
-    shard.tensor_names = vector_to_tensor(tensor_names);
-    shard.shape_and_slices = vector_to_tensor(shape_and_slices);
-    shards.push_back(std::move(shard));
+  if (*model_restore_context == nullptr) {
+    return absl::InternalError("IfrtModelRestoreContext must not be null.");
   }
-
-  for (const auto& shard : shards) {
-    TF_RETURN_IF_ERROR(RunShard(shard));
+  ifrt_serving::CheckpointLoader* checkpoint_loader =
+      (*model_restore_context)->checkpoint_loader();
+  if (!checkpoint_loader) {
+    return absl::InternalError("CheckpointLoader must not be null.");
   }
-  return absl::OkStatus();
+  return checkpoint_loader->Load(prefix(), var_handles(), tensor_names(),
+                                 shape_and_slices(), restored_dtypes(),
+                                 truncate_in_cast(), context());
 }
 
 class MlrtIfrtLoadVariableKernel : public mlrt::KernelFrame {
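The rewritten InvokeHelper above (and the saved_model.cc change later in this diff) follow the same lookup pattern: GetResource returns std::optional<T*>, so both a missing entry and a registered-but-null pointer have to be rejected before dereferencing. A hedged standalone sketch of that pattern; the helper name and the templated context parameter are illustrative, not part of the TFRT API:

  #include <optional>

  #include "absl/status/status.h"
  #include "absl/status/statusor.h"
  #include "absl/strings/str_cat.h"

  // Looks up a named resource and converts "absent" or "null" into an error.
  template <typename T, typename ResourceContextT>
  absl::StatusOr<T*> GetRequiredResource(ResourceContextT& resource_context,
                                         const char* name) {
    std::optional<T*> resource =
        resource_context.template GetResource<T>(name);
    if (!resource.has_value()) {
      return absl::InternalError(
          absl::StrCat("Did not find ", name, " resource."));
    }
    if (*resource == nullptr) {
      return absl::InternalError(absl::StrCat(name, " must not be null."));
    }
    return *resource;
  }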
diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc
index 83b5876..07fb83b 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc
@@ -33,6 +33,7 @@
 #include "xla/python/ifrt/future.h"
 #include "xla/python/ifrt/test_util.h"
 #include "xla/tsl/framework/test_util/mock_serving_device_selector.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/resource_var.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_matcher.h"
@@ -43,8 +44,10 @@
 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
 #include "tensorflow/core/tfrt/fallback/fallback_state.h"
 #include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
+#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h"
 #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
@@ -57,7 +60,6 @@
 #include "tensorflow/core/tfrt/mlrt/kernel/context.h"
 #include "tensorflow/core/tfrt/mlrt/kernel/kernel.h"
 #include "tensorflow/core/tfrt/utils/fallback_tensor.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/refcount.h"
 #include "tsl/platform/status.h"
@@ -403,6 +405,13 @@
             .value();
     ifrt_model_context_->set_checkpoint_loader_queue(restore_work_queue_.get());
 
+    resource_context_
+        .CreateResource<tensorflow::ifrt_serving::IfrtModelRestoreContext>(
+            ifrt_serving::kIfrtModelRestoreContextName,
+            std::make_unique<tensorflow::ifrt_serving::CheckpointLoader>(
+                &ifrt_model_context_->GetRestoreTensorRegistry(),
+                ifrt_model_context_->checkpoint_loader_queue()));
+
     serving_device_selector_ =
         std::make_unique<tsl::test_util::MockServingDeviceSelector>();
     ifrt_core_selector_ =
diff --git a/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc
index f996695..a9aa89a 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc
@@ -29,6 +29,7 @@
 #include "absl/strings/substitute.h"
 #include "absl/synchronization/notification.h"
 #include "absl/types/span.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/device.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/tfrt/fallback/device_with_custom_allocator.h"
@@ -39,7 +40,6 @@
 #include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h"
 #include "tensorflow/core/tfrt/mlrt/kernel/batch_kernel.h"
 #include "tensorflow/core/tfrt/mlrt/kernel/context.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status_matchers.h"
 #include "tfrt/concurrency/ref_count.h"  // from @tf_runtime
 #include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
diff --git a/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc b/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc
index cd3f49f..16293c2 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc
@@ -66,7 +66,7 @@
   };
   std::priority_queue<RestoreVariableCluster,
                       std::vector<RestoreVariableCluster>, decltype(cmp)>
-      min_heap;
+      min_heap(cmp);
   for (int i = 0; i < num_shards; ++i) {
     min_heap.push(RestoreVariableCluster());
   }
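Passing `cmp` to the priority_queue constructor makes the queue use the supplied comparator object instead of a default-constructed one; for lambda comparators this is also what lets the declaration compile before C++20, since closure types were not default-constructible. A minimal standalone sketch (not the TensorFlow code):

  #include <queue>
  #include <vector>

  int main() {
    // Comparator ordered so that top() yields the smallest element.
    auto cmp = [](int a, int b) { return a > b; };
    std::priority_queue<int, std::vector<int>, decltype(cmp)> min_heap(cmp);
    min_heap.push(3);
    min_heap.push(1);
    min_heap.push(2);
    return min_heap.top();  // 1
  }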
diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD
index 9e85c14..5261546 100644
--- a/tensorflow/core/tfrt/saved_model/BUILD
+++ b/tensorflow/core/tfrt/saved_model/BUILD
@@ -135,6 +135,8 @@
         "//tensorflow/core/tfrt/graph_executor",
         "//tensorflow/core/tfrt/graph_executor:export_mlir",
         "//tensorflow/core/tfrt/graph_executor:graph_execution_options",
+        "//tensorflow/core/tfrt/ifrt:checkpoint_loader",
+        "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context",
         "//tensorflow/core/tfrt/mlrt/bytecode",
         "//tensorflow/core/tfrt/mlrt/bytecode:executable",
         "//tensorflow/core/tfrt/mlrt/interpreter:context",
diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc
index 62ad655..84fbbff 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model.cc
+++ b/tensorflow/core/tfrt/saved_model/saved_model.cc
@@ -16,13 +16,11 @@
 
 #include <algorithm>
 #include <cstddef>
-#include <cstdint>
 #include <functional>
 #include <iterator>
 #include <memory>
 #include <optional>
 #include <string>
-#include <unordered_set>
 #include <utility>
 #include <vector>
 
@@ -70,6 +68,8 @@
 #include "tensorflow/core/tfrt/graph_executor/export_mlir.h"
 #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
 #include "tensorflow/core/tfrt/graph_executor/graph_executor.h"
+#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h"
 #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
 #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h"
 #include "tensorflow/core/tfrt/mlrt/interpreter/context.h"
@@ -134,6 +134,34 @@
         "/tensorflow/tfrt/saved_model/input_spec_validation_failure",
         "Record the models that failed input spec validation.", "model_name");
 
+absl::Status PrepareRestore(mlir::MLIRContext* context,
+                            ModelRuntimeContext* model_runtime_context,
+                            const tensorflow::MetaGraphDef& meta_graph_def,
+                            FallbackState& fallback_state,
+                            const std::string& saved_model_dir,
+                            const SavedModel::Options& options,
+                            ifrt_serving::CheckpointLoader* checkpoint_loader) {
+  // Import the global MLIR with `import_user_signatures` set to true so that
+  // we can analyze it and retrieve the data needed for restore.
+  mlir::OwningOpRef<mlir::ModuleOp> mlir_module_restore_analysis;
+  ASSIGN_OR_RETURN_IN_IMPORT(
+      mlir_module_restore_analysis,
+      ImportSavedModel(
+          context, meta_graph_def, fallback_state, saved_model_dir,
+          /*import_user_signatures=*/true,
+          options.graph_execution_options.run_placer_grappler_on_functions));
+
+  if (!checkpoint_loader) {
+    return absl::InternalError("Missing checkpoint loader.");
+  }
+
+  TF_RETURN_IF_ERROR(checkpoint_loader->PrepareRestore(
+      std::move(mlir_module_restore_analysis)));
+
+  LOG(INFO) << "Complete set restore metadata.";
+  return absl::OkStatus();
+}
+
 tensorflow::Status RunBytecodeInitializers(
     const GraphExecutionOptions& options,
     const InitializersAndSignatures& initializers_and_signatures,
@@ -596,6 +624,25 @@
     model_context.set_callable_options(nullptr);
   }
 
+  if (options.graph_execution_options.use_ifrt) {
+    std::optional<ifrt_serving::IfrtModelRestoreContext*>
+        model_restore_context =
+            model_context.resource_context()
+                .GetResource<ifrt_serving::IfrtModelRestoreContext>(
+                    ifrt_serving::kIfrtModelRestoreContextName);
+    if (!model_restore_context.has_value()) {
+      return absl::InternalError(
+          "Did not find IfrtModelRestoreContext resource.");
+    }
+    if (*model_restore_context == nullptr) {
+      return absl::InternalError("IfrtModelRestoreContexts must not be null.");
+    }
+    TF_RETURN_IF_ERROR(
+        PrepareRestore(&context, &model_context, meta_graph_def,
+                       *fallback_state, std::string(saved_model_dir), options,
+                       (*model_restore_context)->checkpoint_loader()));
+  }
+
   GetDefaultInputValue(meta_graph_def.signature_def(), model_context,
                        initializers_and_signatures.signature_map);
 
diff --git a/tensorflow/core/tfrt/saved_model/tests/BUILD b/tensorflow/core/tfrt/saved_model/tests/BUILD
index 3dfc07d..c026800 100644
--- a/tensorflow/core/tfrt/saved_model/tests/BUILD
+++ b/tensorflow/core/tfrt/saved_model/tests/BUILD
@@ -649,7 +649,9 @@
         "//tensorflow/core/platform:resource_loader",
         "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink",
         "//tensorflow/core/tfrt:ifrt_program_ops_op_lib",
+        "//tensorflow/core/tfrt/ifrt:checkpoint_loader",
         "//tensorflow/core/tfrt/ifrt:ifrt_model_context",
+        "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context",
         "//tensorflow/core/tfrt/ifrt:ifrt_serving_core_selector",
         "//tensorflow/core/tfrt/mlrt/kernel:ifrt_ops_kernel",
         "//tensorflow/core/tfrt/runtime",
diff --git a/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc b/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc
index 8eaf225..4f4caf0 100644
--- a/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc
+++ b/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc
@@ -26,14 +26,16 @@
 #include "xla/python/ifrt/client.h"
 #include "xla/python/ifrt/test_util.h"
 #include "xla/tsl/framework/test_util/mock_serving_device_selector.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/platform/resource_loader.h"
+#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h"
 #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h"
 #include "tensorflow/core/tfrt/runtime/runtime.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/threadpool.h"
@@ -77,10 +79,17 @@
                 "IfrtModelContext", client, &core_selector, &GetThreadPool(),
                 /*compilation_environment_proto=*/nullptr);
 
-        (*model_context.resource_context()
-              .GetResource<tensorflow::ifrt_serving::IfrtModelContext>(
-                  "IfrtModelContext"))
-            ->set_checkpoint_loader_queue(work_queue.get());
+        tensorflow::ifrt_serving::IfrtModelContext* ifrt_model_context =
+            (*model_context.resource_context()
+                  .GetResource<tensorflow::ifrt_serving::IfrtModelContext>(
+                      "IfrtModelContext"));
+        ifrt_model_context->set_checkpoint_loader_queue(work_queue.get());
+        model_context.resource_context()
+            .CreateResource<tensorflow::ifrt_serving::IfrtModelRestoreContext>(
+                ifrt_serving::kIfrtModelRestoreContextName,
+                std::make_unique<tensorflow::ifrt_serving::CheckpointLoader>(
+                    &ifrt_model_context->GetRestoreTensorRegistry(),
+                    ifrt_model_context->checkpoint_loader_queue()));
 
         return absl::OkStatus();
       });
diff --git a/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc b/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc
index 183b418..605e441 100644
--- a/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc
+++ b/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc
@@ -34,6 +34,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tfrt/backend_compiler.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.pb.h"
@@ -45,7 +46,6 @@
 #include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model_util.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
 #include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
diff --git a/tensorflow/core/tfrt/saved_model/utils/BUILD b/tensorflow/core/tfrt/saved_model/utils/BUILD
index b76aa6f..008fea2 100644
--- a/tensorflow/core/tfrt/saved_model/utils/BUILD
+++ b/tensorflow/core/tfrt/saved_model/utils/BUILD
@@ -60,8 +60,8 @@
         "@com_google_googletest//:gtest_main",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:env",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
         "@tf_runtime//:bef",
     ],
 )
diff --git a/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc b/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc
index deaf171..2fc9c28 100644
--- a/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc
+++ b/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc
@@ -26,6 +26,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h"
 #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/resource_loader.h"
 #include "tensorflow/core/tfrt/fallback/fallback_state.h"
@@ -33,7 +34,6 @@
 #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model_util.h"
 #include "tensorflow/core/tfrt/utils/utils.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
 
diff --git a/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc b/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc
index 2b63268..b63fc76 100644
--- a/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc
+++ b/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc
@@ -29,6 +29,7 @@
 #include "tensorflow/cc/ops/array_ops.h"
 #include "tensorflow/cc/ops/const_op.h"
 #include "tensorflow/cc/saved_model/reader.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
@@ -44,7 +45,6 @@
 #include "tensorflow/core/tfrt/runtime/runtime.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
 #include "tensorflow/core/tfrt/utils/thread_pool.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/protobuf.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc b/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc
index 330adde..f5af931 100644
--- a/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc
+++ b/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc
@@ -34,6 +34,7 @@
 #include "tensorflow/cc/ops/array_ops.h"
 #include "tensorflow/cc/ops/math_ops.h"
 #include "tensorflow/cc/saved_model/reader.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/common_runtime/function_utils.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -45,7 +46,6 @@
 #include "tensorflow/core/protobuf/meta_graph.pb.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/path.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc
index 498d07a..c18230e 100644
--- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc
+++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc
@@ -250,9 +250,9 @@
     DumpGraphDefToFile("before_pruning", graph_def);
   }
 
-  TF_ASSIGN_OR_RETURN(
-      result.graph,
-      CreatePrunedGraph(graph_def, build_graph_options.callable_options));
+  TF_ASSIGN_OR_RETURN(result.graph,
+                      CreatePrunedGraph(std::move(graph_def),
+                                        build_graph_options.callable_options));
   DCHECK(result.graph);
 
   if (VLOG_IS_ON(1)) {
diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD
index 273c822..73fbacd 100644
--- a/tensorflow/core/tpu/graph_rewrite/BUILD
+++ b/tensorflow/core/tpu/graph_rewrite/BUILD
@@ -1,6 +1,10 @@
 # Contains graph rewrites for TPU runtimes and optimizations.
 
 load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_cc_test",
+)
+load(
     "//tensorflow/core/platform:build_config_root.bzl",
     "if_static",
 )
@@ -119,6 +123,7 @@
         "//tensorflow/core:session_options",
         "//tensorflow/core/common_runtime:function_body",
         "//tensorflow/core/common_runtime:function_utils",
+        "//tensorflow/core/config:flag_defs",
         "//tensorflow/core/tpu:tpu_compile_interface",
         "//tensorflow/core/tpu:tpu_defs",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -131,6 +136,7 @@
         "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:status",
         "@local_xla//xla:status_macros",
+        "@local_xla//xla/tsl/util:env_var",
     ] + if_static(
         [
             "//tensorflow/core/common_runtime:function",
@@ -140,6 +146,26 @@
     ),
 )
 
+tf_cc_test(
+    name = "encapsulate_tpu_computations_pass_test",
+    srcs = ["encapsulate_tpu_computations_pass_test.cc"],
+    deps = [
+        ":encapsulate_tpu_computations_pass",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:ops",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/common_runtime:optimization_registry",
+        "//tensorflow/core/config:flag_defs",
+    ],
+)
+
 cc_library(
     name = "distributed_tpu_rewrite_pass_internal",
     srcs = ["distributed_tpu_rewrite_pass_internal.cc"],
diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc
index 62a4c45..9370cd6 100644
--- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc
@@ -48,6 +48,7 @@
 #include "tensorflow/core/common_runtime/function_utils.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/config/flag_defs.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
@@ -2481,10 +2482,32 @@
   return absl::OkStatus();
 }
 
+// TODO(b/355263902): Encapsulation fails for some non-TPU graphs that are
+// missing full variable shape information. Remove this path once the
+// underlying issue is fixed.
+bool ShouldSkipEncapsulationForNonTPUGraph() {
+  return flags::Global().enable_skip_encapsulation_for_non_tpu_graphs.value();
+}
+
 }  // namespace
 
 /*static*/ Status EncapsulateTPUComputationsPass::Encapsulate(
     std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+  // If the graph does not contain any TPU computations, there is nothing to do.
+  if (ShouldSkipEncapsulationForNonTPUGraph()) {
+    bool found_tpu_replicate = false;
+    for (const Node* n : (*graph)->nodes()) {
+      if (n->attrs().Find(kTPUReplicateAttr) != nullptr) {
+        found_tpu_replicate = true;
+        break;
+      }
+    }
+    if (!found_tpu_replicate) {
+      VLOG(1) << "No TPU replicate found, skipping encapsulation";
+      return absl::OkStatus();
+    }
+  }
+
   // Check for undeclared outputs before Encapsulation, so we can give a better
   // error message.
   // TODO(phawkins): merge this with the encapsulation code to avoid the extra
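The new early-exit reduces to one question: does any node carry the TPU-replicate attribute? A hedged, condensed sketch of that check (illustrative helper only; the real pass keeps the loop inline, guards it behind flags::Global().enable_skip_encapsulation_for_non_tpu_graphs, and uses the same kTPUReplicateAttr constant as the rest of the file):

  #include "tensorflow/core/graph/graph.h"

  namespace tensorflow {

  // Returns true if any node in the graph carries the given TPU-replicate
  // attribute; the attribute name is a parameter to keep the sketch standalone.
  bool GraphHasTpuReplicateNodes(Graph* graph, const char* tpu_replicate_attr) {
    for (const Node* n : graph->nodes()) {
      if (n->attrs().Find(tpu_replicate_attr) != nullptr) return true;
    }
    return false;
  }

  }  // namespace tensorflow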
diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc
new file mode 100644
index 0000000..a21cdae
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h"
+
+#include <memory>
+
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/config/flag_defs.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+std::unique_ptr<Graph> CreateGraph() {
+  // c = a + b
+  auto g = std::make_unique<Graph>(OpRegistry::Global());
+  auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT);
+  auto in1 = test::graph::Arg(g.get(), 1, DT_FLOAT);
+  auto tmp = test::graph::Add(g.get(), in0, in1);
+  auto ret = test::graph::Retval(g.get(), 0, tmp);
+  g->AddControlEdge(in1, ret);
+  FixupSourceAndSinkEdges(g.get());
+  return g;
+}
+
+TEST(EncapsulateTPUComputationsPassTest, NonTPUGraph) {
+  auto g = CreateGraph();
+  GraphOptimizationPassOptions options;
+  options.graph = &g;
+  options.flib_def = g->mutable_flib_def();
+
+  EncapsulateTPUComputationsPass pass;
+  TF_ASSERT_OK(pass.Run(options));
+
+  int nodes_meeting_expectations = 0;
+
+  for (const auto* node : g->nodes()) {
+    if (!IsSource(node) && !IsSink(node)) {
+      ASSERT_TRUE(node->attrs().Find("_xla_inferred_shapes"));
+      ++nodes_meeting_expectations;
+    }
+  }
+  EXPECT_EQ(nodes_meeting_expectations, 4);
+}
+
+TEST(EncapsulateTPUComputationsPassTest, SkipEncapsulationForNonTPUGraph) {
+  flags::Global().enable_skip_encapsulation_for_non_tpu_graphs.reset(true);
+  auto g = CreateGraph();
+  GraphOptimizationPassOptions options;
+  options.graph = &g;
+  options.flib_def = g->mutable_flib_def();
+
+  EncapsulateTPUComputationsPass pass;
+  TF_ASSERT_OK(pass.Run(options));
+
+  int nodes_meeting_expectations = 0;
+
+  for (const auto* node : g->nodes()) {
+    if (!IsSource(node) && !IsSink(node)) {
+      ASSERT_FALSE(node->attrs().Find("_xla_inferred_shapes"));
+      ++nodes_meeting_expectations;
+    }
+  }
+  EXPECT_EQ(nodes_meeting_expectations, 4);
+
+  flags::Global().enable_skip_encapsulation_for_non_tpu_graphs.reset(false);
+}
+
+}  // namespace
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc b/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc
index e62e9d7..084802a 100644
--- a/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc
+++ b/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc
@@ -25,6 +25,7 @@
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/types/span.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -39,7 +40,6 @@
 #include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/public/session.h"
 #include "tensorflow/core/public/session_options.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/protobuf/error_codes.pb.h"
 
diff --git a/tensorflow/core/tpu/kernels/sharding_utils_test.cc b/tensorflow/core/tpu/kernels/sharding_utils_test.cc
index cd583df..552a637 100644
--- a/tensorflow/core/tpu/kernels/sharding_utils_test.cc
+++ b/tensorflow/core/tpu/kernels/sharding_utils_test.cc
@@ -26,11 +26,11 @@
 #include "absl/strings/str_cat.h"
 #include "absl/types/span.h"
 #include "unsupported/Eigen/CXX11/Tensor"  // from @eigen_archive
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/platform/status.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.cc b/tensorflow/core/tpu/kernels/sparse_core_layout.cc
index 2f4a945b..bc4c416 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_layout.cc
+++ b/tensorflow/core/tpu/kernels/sparse_core_layout.cc
@@ -72,10 +72,10 @@
       activation_mem_bytes_limit_(GetXlaSparseCoreStackingMemLimit()),
       variable_shard_bytes_limit_(GetXlaSparseCoreStackingTableShardLimit()) {}
 
-absl::Status SparseCoreLayoutStacker::AddTable(tsl::StringPiece table_name,
+absl::Status SparseCoreLayoutStacker::AddTable(absl::string_view table_name,
                                                int64_t table_height,
                                                int64_t table_width,
-                                               tsl::StringPiece group,
+                                               absl::string_view group,
                                                int64_t output_samples) {
   if (stacks_by_group_.empty()) {  // First call?
     VLOG(1) << "Stacking parameters: stacking_enabled_ = " << stacking_enabled_
diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.h b/tensorflow/core/tpu/kernels/sparse_core_layout.h
index c1d22f3..9f4697c 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_layout.h
+++ b/tensorflow/core/tpu/kernels/sparse_core_layout.h
@@ -84,8 +84,8 @@
   //
   // Be sure you call AddTable in a deterministic order; the details of the
   // stacking will depend on the order you call AddTable.
-  absl::Status AddTable(tsl::StringPiece table_name, int64_t table_height,
-                        int64_t table_width, tsl::StringPiece group,
+  absl::Status AddTable(absl::string_view table_name, int64_t table_height,
+                        int64_t table_width, absl::string_view group,
                         int64_t output_samples);
 
   // Get the information about each table out.
diff --git a/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite_test.cc b/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite_test.cc
index ca2b1d5..d854eab 100644
--- a/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite_test.cc
+++ b/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite_test.cc
@@ -21,11 +21,11 @@
 #include <gtest/gtest.h>
 #include "absl/status/status.h"
 #include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/status_matchers.h"
 #include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/protobuf.h"  // IWYU pragma: keep
 #include "tsl/platform/test.h"
 
diff --git a/tensorflow/core/tpu/tpu_embedding_errors_test.cc b/tensorflow/core/tpu/tpu_embedding_errors_test.cc
index 3dbb182..f0a8d86 100644
--- a/tensorflow/core/tpu/tpu_embedding_errors_test.cc
+++ b/tensorflow/core/tpu/tpu_embedding_errors_test.cc
@@ -21,8 +21,8 @@
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/errors.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/protobuf/error_codes.pb.h"
 
diff --git a/tensorflow/core/util/autotune_maps/BUILD b/tensorflow/core/util/autotune_maps/BUILD
index 8b89487..990edbe 100644
--- a/tensorflow/core/util/autotune_maps/BUILD
+++ b/tensorflow/core/util/autotune_maps/BUILD
@@ -52,8 +52,8 @@
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
-        "@local_tsl//tsl/lib/strings:proto_serialization",
         "@local_tsl//tsl/protobuf:dnn_proto_cc",
+        "@local_xla//xla/tsl/lib/strings:proto_serialization",
     ],
 )
 
@@ -118,7 +118,7 @@
         "conv_parameters.h",
     ],
     cuda_deps = [
-        "@local_tsl//tsl/lib/strings:proto_serialization",
+        "@local_xla//xla/tsl/lib/strings:proto_serialization",
     ],
     deps = [
         ":conv_parameters_proto_cc",
@@ -182,12 +182,12 @@
         "//tensorflow/core:framework",
         "//tensorflow/core/platform:status",
         "//tensorflow/core/platform:str_util",
-        "@local_tsl//tsl/lib/strings:proto_serialization",
         "@local_tsl//tsl/protobuf:dnn_proto_cc",
         "@local_xla//xla:status_macros",
         "@local_xla//xla/stream_executor:dnn",
         "@local_xla//xla/stream_executor:platform_manager",
         "@local_xla//xla/stream_executor/gpu:gpu_init",
+        "@local_xla//xla/tsl/lib/strings:proto_serialization",
     ],
 )
 
diff --git a/tensorflow/core/util/autotune_maps/autotune_serialize.cc b/tensorflow/core/util/autotune_maps/autotune_serialize.cc
index 63470c0..c601502 100644
--- a/tensorflow/core/util/autotune_maps/autotune_serialize.cc
+++ b/tensorflow/core/util/autotune_maps/autotune_serialize.cc
@@ -25,13 +25,13 @@
 #include "xla/stream_executor/dnn.h"
 #include "xla/stream_executor/gpu/gpu_init.h"
 #include "xla/stream_executor/platform_manager.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/util/activation_mode.h"
 #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h"
 #include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h"
 #include "tensorflow/core/util/autotune_maps/conv_parameters.h"
 #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
 #include "tsl/protobuf/dnn.pb.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc b/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc
index baa68aa..0bd1122 100644
--- a/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc
+++ b/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc
@@ -19,9 +19,9 @@
 
 #include "absl/log/check.h"
 #include "absl/status/status.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h"
 #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
 #include "tsl/protobuf/dnn.pb.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/util/autotune_maps/conv_parameters.cc b/tensorflow/core/util/autotune_maps/conv_parameters.cc
index 6343693..a620e39 100644
--- a/tensorflow/core/util/autotune_maps/conv_parameters.cc
+++ b/tensorflow/core/util/autotune_maps/conv_parameters.cc
@@ -19,9 +19,9 @@
 #include <vector>
 
 #include "absl/strings/str_format.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 #include "tensorflow/core/platform/hash.h"
 #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h
index 5c8a5db..61d1fb5 100644
--- a/tensorflow/core/util/bcast.h
+++ b/tensorflow/core/util/bcast.h
@@ -199,7 +199,6 @@
     prev_is_one[i] = false;
     current_is_one[i] = false;
   }
-  Vec output;
   bool output_dim_set = false;
   int64_t output_dim = -1;
   bool none_is_one = true;
diff --git a/tensorflow/core/util/dump_graph_test.cc b/tensorflow/core/util/dump_graph_test.cc
index d24eccf..935ca41 100644
--- a/tensorflow/core/util/dump_graph_test.cc
+++ b/tensorflow/core/util/dump_graph_test.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/core/util/dump_graph.h"
 
 #include "absl/strings/match.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
@@ -23,7 +24,6 @@
 #include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/util/quantization/uniform_quant_ops_params_test.cc b/tensorflow/core/util/quantization/uniform_quant_ops_params_test.cc
index fe27a6c..ad28eeb 100644
--- a/tensorflow/core/util/quantization/uniform_quant_ops_params_test.cc
+++ b/tensorflow/core/util/quantization/uniform_quant_ops_params_test.cc
@@ -16,9 +16,9 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/util/quantization/uniform_quant_ops_attr.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/core/util/strided_slice_op_test.cc b/tensorflow/core/util/strided_slice_op_test.cc
index cbe0976..6eb961c 100644
--- a/tensorflow/core/util/strided_slice_op_test.cc
+++ b/tensorflow/core/util/strided_slice_op_test.cc
@@ -22,12 +22,12 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/test.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status_matchers.h"
 
diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc
index 6e9c20d..05f5d0f 100644
--- a/tensorflow/core/util/util.cc
+++ b/tensorflow/core/util/util.cc
@@ -151,10 +151,12 @@
   } else if (dt == DT_HALF) {
     // Float16 is not supported in oneDNN v2.x
 #ifdef ENABLE_ONEDNN_V3
-    result = (TestCPUFeature(port::CPUFeature::AVX512BW) &&
-              (TestCPUFeature(port::CPUFeature::AVX512_FP16) ||
-               TestCPUFeature(port::CPUFeature::AMX_FP16) ||
-               TestCPUFeature(port::CPUFeature::AVX_NE_CONVERT)));
+    // Some CPUs that don't support AVX-512 use AVX-NE-CONVERT to cast to and
+    // from FP32.
+    result = ((TestCPUFeature(port::CPUFeature::AVX512BW) &&
+               (TestCPUFeature(port::CPUFeature::AVX512_FP16) ||
+                TestCPUFeature(port::CPUFeature::AMX_FP16))) ||
+              TestCPUFeature(port::CPUFeature::AVX_NE_CONVERT));
     if (result) VLOG(2) << "CPU supports " << DataType_Name(dt);
 #endif  // ENABLE_ONEDNN_V3
   } else {
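To make the precedence change above explicit, here is a hedged side-by-side of the old and new conditions, with plain bools standing in for the TestCPUFeature probes (the function names are illustrative):

  // Old behavior: AVX512BW was required even on the AVX-NE-CONVERT path.
  bool OneDnnFp16SupportedOld(bool avx512bw, bool avx512_fp16, bool amx_fp16,
                              bool avx_ne_convert) {
    return avx512bw && (avx512_fp16 || amx_fp16 || avx_ne_convert);
  }

  // New behavior: AVX-NE-CONVERT alone is sufficient; AVX512BW is only needed
  // together with AVX512-FP16 or AMX-FP16.
  bool OneDnnFp16SupportedNew(bool avx512bw, bool avx512_fp16, bool amx_fp16,
                              bool avx_ne_convert) {
    return (avx512bw && (avx512_fp16 || amx_fp16)) || avx_ne_convert;
  }

  // Example: a CPU with AVX-NE-CONVERT but without AVX512BW.
  //   OneDnnFp16SupportedOld(false, false, false, true) == false
  //   OneDnnFp16SupportedNew(false, false, false, true) == true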
diff --git a/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc b/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc
index 802d46f..2e24e5d 100644
--- a/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc
+++ b/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc
@@ -165,7 +165,7 @@
       mesh_name = "";
     }
     const std::vector<int>& tpu_core_ids = Mesh::tpu_core_ids()[mesh_name];
-    VLOG(1) << "tpu_core_ids: " << str_util::Join(tpu_core_ids, ", ");
+    VLOG(1) << "tpu_core_ids: " << absl::StrJoin(tpu_core_ids, ", ");
 
     xla::DeviceAssignmentProto device_assignment;
     device_assignment.set_replica_count(1);
@@ -223,7 +223,7 @@
       mesh_name = "";
     }
     const std::vector<int>& tpu_core_ids = Mesh::tpu_core_ids()[mesh_name];
-    VLOG(1) << "tpu_core_ids: " << str_util::Join(tpu_core_ids, ", ");
+    VLOG(1) << "tpu_core_ids: " << absl::StrJoin(tpu_core_ids, ", ");
 
     xla::DeviceAssignmentProto device_assignment;
     device_assignment.set_replica_count(num_replicas);
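The two VLOG changes above swap tensorflow::str_util::Join for absl::StrJoin, which formats arithmetic types directly. A minimal standalone usage sketch matching the replacement (not the DTensor code):

  #include <string>
  #include <vector>

  #include "absl/strings/str_join.h"

  int main() {
    std::vector<int> tpu_core_ids = {0, 1, 2, 3};
    // Produces "0, 1, 2, 3", same as the logging statements above.
    std::string joined = absl::StrJoin(tpu_core_ids, ", ");
    return joined.empty() ? 1 : 0;
  }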
diff --git a/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc b/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc
index ff29775..475e08c 100644
--- a/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc
+++ b/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc
@@ -22,11 +22,11 @@
 #include "absl/strings/str_cat.h"
 #include "benchmark/benchmark.h"  // from @com_google_benchmark
 #include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "xla/xla_data.pb.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/dtensor/cc/dstatus.h"
 #include "tensorflow/dtensor/cc/tensor_layout.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/statusor.h"
 
 namespace tensorflow {
diff --git a/tensorflow/examples/speech_commands/accuracy_utils_test.cc b/tensorflow/examples/speech_commands/accuracy_utils_test.cc
index cf4f5ba..7edd1b4 100644
--- a/tensorflow/examples/speech_commands/accuracy_utils_test.cc
+++ b/tensorflow/examples/speech_commands/accuracy_utils_test.cc
@@ -15,11 +15,11 @@
 
 #include "tensorflow/examples/speech_commands/accuracy_utils.h"
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/examples/speech_commands/recognize_commands_test.cc b/tensorflow/examples/speech_commands/recognize_commands_test.cc
index 1730d06..1f13e24 100644
--- a/tensorflow/examples/speech_commands/recognize_commands_test.cc
+++ b/tensorflow/examples/speech_commands/recognize_commands_test.cc
@@ -15,12 +15,12 @@
 
 #include "tensorflow/examples/speech_commands/recognize_commands.h"
 
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index 4a1d5bd..eb41fb0 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -830,7 +830,6 @@
     deps = [
         ":minimal_logging",
         "//tensorflow/lite/core/api:error_reporter",
-        "//tensorflow/lite/core/c:common",
     ],
 )
 
diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt
index eb93434..6c288ad 100644
--- a/tensorflow/lite/CMakeLists.txt
+++ b/tensorflow/lite/CMakeLists.txt
@@ -373,7 +373,9 @@
   list(APPEND TFLITE_DELEGATES_GPU_SRCS
     ${TFLITE_SOURCE_DIR}/delegates/gpu/api.cc
     ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.cc
+    ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.h
     ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate_options.cc
+    ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate_options.h
     ${TFLITE_SOURCE_DIR}/delegates/gpu/tflite_profile.cc
     ${TFLITE_SOURCE_DIR}/experimental/acceleration/compatibility/android_info.cc
     ${TFLITE_DELEGATES_GPU_CL_SRCS}
@@ -681,6 +683,10 @@
   ${TF_SOURCE_DIR}/compiler/mlir/lite/utils/string_utils.h
   ${TF_SOURCE_DIR}/compiler/mlir/lite/experimental/remat/metadata_util.h
   ${TF_SOURCE_DIR}/compiler/mlir/lite/experimental/remat/metadata_util.cc
+  ${TF_SOURCE_DIR}/compiler/mlir/lite/core/model_builder_base.h
+  ${TF_SOURCE_DIR}/compiler/mlir/lite/core/model_builder_base.cc
+  ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/error_reporter.h
+  ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/error_reporter.cc
   ${TFLITE_SOURCE_DIR}/schema/schema_generated.h
 )
 add_library(tensorflow-lite
diff --git a/tensorflow/lite/allocation.cc b/tensorflow/lite/allocation.cc
index b187ef0..bbc41fa 100644
--- a/tensorflow/lite/allocation.cc
+++ b/tensorflow/lite/allocation.cc
@@ -21,6 +21,8 @@
 
 #include <cstdint>
 #include <cstdio>
+#include <cstdlib>
+#include <cstring>
 #include <memory>
 
 #include "tensorflow/lite/core/api/error_reporter.h"
@@ -100,11 +102,37 @@
   }
 #endif  // __arm__
 
+// `android_local_test` doesn't support zipalign (b/356640509), so we need this
+// workaround to keep our clients working.
+// TODO: b/356413060 - Remove the workaround once b/356640509 is fixed.
+#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER)
+  if ((reinterpret_cast<uintptr_t>(ptr) & 0x3) != 0) {
+    aligned_ptr_ = ::aligned_alloc(4, num_bytes);
+    if (aligned_ptr_ == nullptr) {
+      TF_LITE_REPORT_ERROR(error_reporter, "Failed to allocate aligned buffer");
+      buffer_ = nullptr;
+      buffer_size_bytes_ = 0;
+      return;
+    }
+    memcpy(aligned_ptr_, ptr, num_bytes);
+    buffer_ = aligned_ptr_;
+  } else {
+    buffer_ = ptr;
+  }
+#else   // defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER)
   buffer_ = ptr;
+#endif  // defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER)
+
   buffer_size_bytes_ = num_bytes;
 }
 
-MemoryAllocation::~MemoryAllocation() {}
+MemoryAllocation::~MemoryAllocation() {
+#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER)
+  if (aligned_ptr_) {
+    free(aligned_ptr_);
+  }
+#endif
+}
 
 const void* MemoryAllocation::base() const { return buffer_; }
 
diff --git a/tensorflow/lite/allocation.h b/tensorflow/lite/allocation.h
index 6840646..f007b3c 100644
--- a/tensorflow/lite/allocation.h
+++ b/tensorflow/lite/allocation.h
@@ -145,6 +145,9 @@
 
  private:
   const void* buffer_;
+#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER)
+  void* aligned_ptr_ = nullptr;
+#endif
   size_t buffer_size_bytes_ = 0;
 };
 
diff --git a/tensorflow/lite/core/BUILD b/tensorflow/lite/core/BUILD
index d18efae..4309e28 100644
--- a/tensorflow/lite/core/BUILD
+++ b/tensorflow/lite/core/BUILD
@@ -43,13 +43,14 @@
     ],
     compatible_with = get_compatible_with_portable(),
     copts = tflite_copts() + tflite_copts_warnings(),
-    visibility = [
-        "//tensorflow/lite:__subpackages__",
-    ],
+    visibility = ["//tensorflow/lite:__subpackages__"],
     deps = [
         ":cc_api_stable",
         ":signature_runner",
+        "//tensorflow/compiler/mlir/lite/core:macros",
+        "//tensorflow/compiler/mlir/lite/core:model_builder_base",
         "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util",
+        "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
         "//tensorflow/lite:allocation",
         "//tensorflow/lite:array",
         "//tensorflow/lite:external_cpu_backend_context",
@@ -120,6 +121,7 @@
         ":cc_api_stable",
         ":model_builder",
         ":signature_runner",
+        "//tensorflow/compiler/mlir/lite/core:model_builder_base",
         "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util",
         "//tensorflow/lite:allocation",
         "//tensorflow/lite:array",
@@ -135,7 +137,6 @@
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common_internal",
         "//tensorflow/lite/core/api",
-        "//tensorflow/lite/core/api:verifier",
         "//tensorflow/lite/core/async:async_signature_runner",
         "//tensorflow/lite/core/c:common",
         "//tensorflow/lite/experimental/resource",
@@ -174,7 +175,10 @@
         ":model_builder",
         ":signature_runner",
         ":subgraph",
+        "//tensorflow/compiler/mlir/lite/core:model_builder_base",
         "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util",
+        "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
+        "//tensorflow/compiler/mlir/lite/schema:schema_utils",
         "//tensorflow/lite:allocation",
         "//tensorflow/lite:array",
         "//tensorflow/lite:external_cpu_backend_context",
@@ -253,6 +257,7 @@
     deps = [
         ":cc_api_stable",
         ":signature_runner",
+        "//tensorflow/compiler/mlir/lite/core:model_builder_base",
         "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util",
         "//tensorflow/lite:allocation",
         "//tensorflow/lite:array",
@@ -268,7 +273,6 @@
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common_internal",
         "//tensorflow/lite/core/api",
-        "//tensorflow/lite/core/api:verifier",
         "//tensorflow/lite/core/async:async_signature_runner",
         "//tensorflow/lite/core/c:c_api_types",
         "//tensorflow/lite/core/c:common",
@@ -286,22 +290,14 @@
 
 cc_library(
     name = "model_builder",
-    srcs = ["model_builder.cc"],
     hdrs = ["model_builder.h"],
     compatible_with = get_compatible_with_portable(),
     copts = tflite_copts_warnings(),
     visibility = internal_visibility_allowlist(),
     deps = [
-        ":macros",
-        "//tensorflow/compiler/mlir/lite/core:macros",
-        "//tensorflow/lite:allocation",
+        "//tensorflow/compiler/mlir/lite/core:model_builder_base",
         "//tensorflow/lite:stderr_reporter",
-        "//tensorflow/lite:string",
         "//tensorflow/lite/core/api:error_reporter",
-        "//tensorflow/lite/core/api:verifier",
-        "//tensorflow/lite/schema:schema_fbs",
-        "@com_google_absl//absl/strings",
-        "@flatbuffers",
     ],
     alwayslink = 1,
 )
@@ -336,6 +332,7 @@
     deps = [
         ":framework",
         ":signature_runner",
+        "//tensorflow/compiler/mlir/lite/core:model_builder_base",
         "//tensorflow/lite:model_builder",
         "//tensorflow/lite/core/kernels:builtin_ops",
         "//tensorflow/lite/testing:util",
@@ -366,6 +363,7 @@
     ],
     tags = [
         "no_windows",  # TODO(b/194459105): the test is flaky.
+        "noasan",
         "tflite_not_portable",
         "tflite_smoke_test",
     ],
diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD
index 08ac033..ac1bb24 100644
--- a/tensorflow/lite/core/api/BUILD
+++ b/tensorflow/lite/core/api/BUILD
@@ -11,7 +11,6 @@
 filegroup(
     name = "tflite_internal_cc_3p_api_deps_src",
     srcs = [
-        ":error_reporter.cc",
         ":error_reporter.h",
         ":op_resolver.cc",
         ":op_resolver.h",
@@ -77,23 +76,33 @@
 
 cc_library(
     name = "error_reporter",
-    srcs = ["error_reporter.cc"],
-    hdrs = ["error_reporter.h"],
+    hdrs = [
+        "error_reporter.h",
+        "//tensorflow/compiler/mlir/lite/core/api:error_reporter.h",
+    ],
     compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
     visibility = [
         "//visibility:public",
     ],
-    deps = [],
+    deps = [
+        "//tensorflow/compiler/mlir/lite/core/api:error_reporter",
+    ],
 )
 
 cc_library(
     name = "verifier",
-    hdrs = ["verifier.h"],
+    hdrs = [
+        "verifier.h",
+        "//tensorflow/compiler/mlir/lite/core/api:verifier.h",
+    ],
     compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
     visibility = ["//visibility:public"],
-    deps = [":error_reporter"],
+    deps = [
+        "//tensorflow/compiler/mlir/lite/core/api:error_reporter",
+        "//tensorflow/compiler/mlir/lite/core/api:verifier",
+    ],
 )
 
 cc_library(
@@ -109,16 +118,6 @@
 )
 
 cc_test(
-    name = "error_reporter_test",
-    size = "small",
-    srcs = ["error_reporter_test.cc"],
-    deps = [
-        ":api",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
-cc_test(
     name = "op_resolver_test",
     size = "small",
     srcs = ["op_resolver_test.cc"],
diff --git a/tensorflow/lite/core/api/error_reporter.cc b/tensorflow/lite/core/api/error_reporter.cc
deleted file mode 100644
index 7070eaa..0000000
--- a/tensorflow/lite/core/api/error_reporter.cc
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/lite/core/api/error_reporter.h"
-#include <cstdarg>
-
-namespace tflite {
-
-int ErrorReporter::Report(const char* format, ...) {
-  va_list args;
-  va_start(args, format);
-  int code = Report(format, args);
-  va_end(args);
-  return code;
-}
-
-// TODO(aselle): Make the name of ReportError on context the same, so
-// we can use the ensure functions w/o a context and w/ a reporter.
-int ErrorReporter::ReportError(void*, const char* format, ...) {
-  va_list args;
-  va_start(args, format);
-  int code = Report(format, args);
-  va_end(args);
-  return code;
-}
-
-}  // namespace tflite
diff --git a/tensorflow/lite/core/api/error_reporter.h b/tensorflow/lite/core/api/error_reporter.h
index 1e0ef7d..f910604 100644
--- a/tensorflow/lite/core/api/error_reporter.h
+++ b/tensorflow/lite/core/api/error_reporter.h
@@ -15,58 +15,6 @@
 #ifndef TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_
 #define TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_
 
-#include <cstdarg>
-
-namespace tflite {
-
-/// A functor that reports error to supporting system. Invoked similar to
-/// printf.
-///
-/// Usage:
-///  ErrorReporter foo;
-///  foo.Report("test %d", 5);
-/// or
-///  va_list args;
-///  foo.Report("test %d", args); // where args is va_list
-///
-/// Subclass ErrorReporter to provide another reporting destination.
-/// For example, if you have a GUI program, you might redirect to a buffer
-/// that drives a GUI error log box.
-class ErrorReporter {
- public:
-  virtual ~ErrorReporter() = default;
-  /// Converts `args` to character equivalents according to `format` string,
-  /// constructs the error string and report it.
-  /// Returns number of characters written or zero on success, and negative
-  /// number on error.
-  virtual int Report(const char* format, va_list args) = 0;
-
-  /// Converts arguments to character equivalents according to `format` string,
-  /// constructs the error string and report it.
-  /// Returns number of characters written or zero on success, and negative
-  /// number on error.
-  int Report(const char* format, ...);
-
-  /// Equivalent to `Report` above. The additional `void*` parameter is unused.
-  /// This method is for compatibility with macros that takes `TfLiteContext`,
-  /// like TF_LITE_ENSURE and related macros.
-  int ReportError(void*, const char* format, ...);
-};
-
-}  // namespace tflite
-
-// You should not make bare calls to the error reporter, instead use the
-// TF_LITE_REPORT_ERROR macro, since this allows message strings to be
-// stripped when the binary size has to be optimized. If you are looking to
-// reduce binary size, define TF_LITE_STRIP_ERROR_STRINGS when compiling and
-// every call will be stubbed out, taking no memory.
-#ifndef TF_LITE_STRIP_ERROR_STRINGS
-#define TF_LITE_REPORT_ERROR(reporter, ...)                               \
-  do {                                                                    \
-    static_cast<::tflite::ErrorReporter*>(reporter)->Report(__VA_ARGS__); \
-  } while (false)
-#else  // TF_LITE_STRIP_ERROR_STRINGS
-#define TF_LITE_REPORT_ERROR(reporter, ...)
-#endif  // TF_LITE_STRIP_ERROR_STRINGS
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"  // IWYU pragma: export
 
 #endif  // TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_
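
error_reporter.h is now a forwarding header: the class definition moved under
compiler/mlir/lite and is re-exported via the IWYU pragma, so existing includes
and subclasses keep compiling. A minimal custom reporter, assuming the interface
is exactly the one shown in the removed lines (override the va_list overload;
the varargs Report and TF_LITE_REPORT_ERROR forward to it):

    #include <cstdarg>
    #include <cstdio>

    #include "tensorflow/lite/core/api/error_reporter.h"  // still the public path

    // Routes TFLite errors to stderr with a prefix. Subclassing is unchanged by
    // the header move; only the file that defines the class is different.
    class PrefixedReporter : public tflite::ErrorReporter {
     public:
      int Report(const char* format, va_list args) override {
        std::fprintf(stderr, "[tflite] ");
        const int n = std::vfprintf(stderr, format, args);
        std::fprintf(stderr, "\n");
        return n;
      }
    };

    // Usage:
    //   PrefixedReporter reporter;
    //   TF_LITE_REPORT_ERROR(&reporter, "bad tensor index %d", 42);
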
diff --git a/tensorflow/lite/core/api/error_reporter_test.cc b/tensorflow/lite/core/api/error_reporter_test.cc
deleted file mode 100644
index 03d6da73..0000000
--- a/tensorflow/lite/core/api/error_reporter_test.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/core/api/error_reporter.h"
-
-#include <cstdio>
-
-#include <gtest/gtest.h>
-
-namespace tflite {
-
-class MockErrorReporter : public ErrorReporter {
- public:
-  MockErrorReporter() { buffer_[0] = 0; }
-  int Report(const char* format, va_list args) override {
-    vsnprintf(buffer_, kBufferSize, format, args);
-    return 0;
-  }
-  char* GetBuffer() { return buffer_; }
-
- private:
-  static constexpr int kBufferSize = 256;
-  char buffer_[kBufferSize];
-};
-
-TEST(ErrorReporter, TestReport) {
-  MockErrorReporter mock_reporter;
-  ErrorReporter* reporter = &mock_reporter;
-  reporter->Report("Error: %d", 23);
-  EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
-}
-
-TEST(ErrorReporter, TestReportMacro) {
-  MockErrorReporter mock_reporter;
-  // Only define the reporter if it's used, to avoid warnings.
-#ifndef TF_LITE_STRIP_ERROR_STRINGS
-  ErrorReporter* reporter = &mock_reporter;
-#endif  // TFLITE_STRIP_ERROR_STRINGS
-
-  TF_LITE_REPORT_ERROR(reporter, "Error: %d", 23);
-
-#ifndef TF_LITE_STRIP_ERROR_STRINGS
-  EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
-#else   // TF_LITE_STRIP_ERROR_STRINGS
-  EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), ""));
-#endif  // TF_LITE_STRIP_ERROR_STRINGS
-}
-
-}  // namespace tflite
diff --git a/tensorflow/lite/core/api/verifier.h b/tensorflow/lite/core/api/verifier.h
index 8128ff3..dcb1d02 100644
--- a/tensorflow/lite/core/api/verifier.h
+++ b/tensorflow/lite/core/api/verifier.h
@@ -18,22 +18,6 @@
 #ifndef TENSORFLOW_LITE_CORE_API_VERIFIER_H_
 #define TENSORFLOW_LITE_CORE_API_VERIFIER_H_
 
-#include "tensorflow/lite/core/api/error_reporter.h"
-
-namespace tflite {
-
-/// Abstract interface that verifies whether a given model is legit.
-/// It facilitates the use-case to verify and build a model without loading it
-/// twice.
-/// (See also "tensorflow/lite/tools/verifier.h".)
-class TfLiteVerifier {
- public:
-  /// Returns true if the model is legit.
-  virtual bool Verify(const char* data, int length,
-                      ErrorReporter* reporter) = 0;
-  virtual ~TfLiteVerifier() {}
-};
-
-}  // namespace tflite
+#include "tensorflow/compiler/mlir/lite/core/api/verifier.h"  // IWYU pragma: export
 
 #endif  // TENSORFLOW_LITE_CORE_API_VERIFIER_H_
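
verifier.h gets the same treatment: the TfLiteVerifier interface shown in the
removed lines now lives under compiler/mlir/lite and is re-exported here. A
sketch of an extra verifier that only checks the FlatBuffer file identifier
(illustrative names, and assuming the standard "TFL3" identifier at byte
offset 4):

    #include <cstring>

    #include "tensorflow/lite/core/api/error_reporter.h"
    #include "tensorflow/lite/core/api/verifier.h"

    // Rejects buffers that are too short or that lack the TFLite identifier.
    // A real extra verifier would do more; this only illustrates the interface
    // that the VerifyAndBuild* factories invoke after tflite::VerifyModelBuffer.
    class IdentifierVerifier : public tflite::TfLiteVerifier {
     public:
      bool Verify(const char* data, int length,
                  tflite::ErrorReporter* reporter) override {
        if (length < 8 || std::memcmp(data + 4, "TFL3", 4) != 0) {
          TF_LITE_REPORT_ERROR(reporter, "Buffer is not a TFLite model");
          return false;
        }
        return true;
      }
    };
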
diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h
index 96f19f1..648b862 100644
--- a/tensorflow/lite/core/c/common.h
+++ b/tensorflow/lite/core/c/common.h
@@ -100,7 +100,9 @@
   TfLiteStatus (*Refresh)(struct TfLiteContext* context);
 } TfLiteExternalContext;
 
+// LINT.IfChange(optional_tensor)
 #define kTfLiteOptionalTensor (-1)
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/flatbuffer_export.cc:optional_tensor)
 
 /// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
 /// indices
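
The LINT.IfChange/LINT.ThenChange pair ties the sentinel value to its
counterpart in the exporter: a presubmit that sees this block change expects
the tagged block in flatbuffer_export.cc to change too. Roughly what the other
side of the pair looks like (the exporter's exact code is not part of this diff,
so the constant shown is an assumption):

    // In tensorflow/compiler/mlir/lite/flatbuffer_export.cc (illustrative only):
    // LINT.IfChange(optional_tensor)
    constexpr int kTfLiteOptionalTensor = -1;  // must stay in sync with common.h
    // LINT.ThenChange(//tensorflow/lite/core/c/common.h:optional_tensor)
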
diff --git a/tensorflow/lite/core/model_builder.cc b/tensorflow/lite/core/model_builder.cc
deleted file mode 100644
index afa8513..0000000
--- a/tensorflow/lite/core/model_builder.cc
+++ /dev/null
@@ -1,474 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/lite/core/model_builder.h"
-
-#include <stddef.h>
-#include <stdint.h>
-
-#include <algorithm>
-#include <cstring>
-#include <map>
-#include <memory>
-#include <string>
-#include <utility>
-
-#include "absl/strings/str_cat.h"
-#include "flatbuffers/base.h"  // from @flatbuffers
-#include "flatbuffers/buffer.h"  // from @flatbuffers
-#include "flatbuffers/vector.h"  // from @flatbuffers
-#include "flatbuffers/verifier.h"  // from @flatbuffers
-#include "tensorflow/compiler/mlir/lite/core/macros.h"
-#include "tensorflow/lite/allocation.h"
-#include "tensorflow/lite/core/api/error_reporter.h"
-#include "tensorflow/lite/core/api/verifier.h"
-#include "tensorflow/lite/schema/schema_generated.h"
-#include "tensorflow/lite/stderr_reporter.h"
-#include "tensorflow/lite/string_type.h"
-
-namespace tflite {
-
-namespace {
-
-// Ensure that ErrorReporter is non-null.
-ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
-  return e ? e : DefaultErrorReporter();
-}
-
-}  // namespace
-
-#ifndef TFLITE_MCU
-// Loads a model from `filename`. If `mmap_file` is true then use mmap,
-// otherwise make a copy of the model in a buffer.
-std::unique_ptr<Allocation> GetAllocationFromFile(
-    const char* filename, ErrorReporter* error_reporter) {
-  std::unique_ptr<Allocation> allocation;
-  if (MMAPAllocation::IsSupported()) {
-    allocation = std::make_unique<MMAPAllocation>(filename, error_reporter);
-  } else {
-    allocation = std::make_unique<FileCopyAllocation>(filename, error_reporter);
-  }
-  return allocation;
-}
-
-// Loads a model from `fd`. If `mmap_file` is true then use mmap,
-// otherwise make a copy of the model in a buffer.
-std::unique_ptr<Allocation> GetAllocationFromFile(
-    int fd, ErrorReporter* error_reporter) {
-  std::unique_ptr<Allocation> allocation;
-  if (MMAPAllocation::IsSupported()) {
-    allocation = std::make_unique<MMAPAllocation>(fd, error_reporter);
-  } else {
-    allocation = std::make_unique<FileCopyAllocation>(
-        absl::StrCat("/proc/self/fd/", fd).c_str(), error_reporter);
-  }
-  return allocation;
-}
-
-namespace impl {
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
-    const char* filename, ErrorReporter* error_reporter) {
-  error_reporter = ValidateErrorReporter(error_reporter);
-  std::unique_ptr<FlatBufferModel> model = BuildFromAllocation(
-      GetAllocationFromFile(filename, error_reporter), error_reporter);
-#if FLATBUFFERS_LITTLEENDIAN == 1
-  return model;
-#else
-  return ByteConvertModel(std::move(model), error_reporter);
-#endif
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
-    const char* filename, TfLiteVerifier* extra_verifier,
-    ErrorReporter* error_reporter) {
-  error_reporter = ValidateErrorReporter(error_reporter);
-  std::unique_ptr<FlatBufferModel> model = VerifyAndBuildFromAllocation(
-      GetAllocationFromFile(filename, error_reporter), extra_verifier,
-      error_reporter);
-#if FLATBUFFERS_LITTLEENDIAN == 1
-  return model;
-#else
-  return ByteConvertModel(std::move(model), error_reporter);
-#endif
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFileDescriptor(
-    int fd, ErrorReporter* error_reporter) {
-  error_reporter = ValidateErrorReporter(error_reporter);
-  std::unique_ptr<FlatBufferModel> model = BuildFromAllocation(
-      GetAllocationFromFile(fd, error_reporter), error_reporter);
-#if FLATBUFFERS_LITTLEENDIAN == 1
-  return model;
-#else
-  return ByteConvertModel(std::move(model), error_reporter);
-#endif
-}
-
-std::unique_ptr<FlatBufferModel>
-FlatBufferModel::VerifyAndBuildFromFileDescriptor(
-    int fd, TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
-  error_reporter = ValidateErrorReporter(error_reporter);
-  std::unique_ptr<FlatBufferModel> model =
-      VerifyAndBuildFromAllocation(GetAllocationFromFile(fd, error_reporter),
-                                   extra_verifier, error_reporter);
-#if FLATBUFFERS_LITTLEENDIAN == 1
-  return model;
-#else
-  return ByteConvertModel(std::move(model), error_reporter);
-#endif
-}
-
-}  // namespace impl
-
-#endif
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
-    const char* caller_owned_buffer, size_t buffer_size,
-    ErrorReporter* error_reporter) {
-  error_reporter = ValidateErrorReporter(error_reporter);
-  std::unique_ptr<Allocation> allocation(
-      new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
-  return BuildFromAllocation(std::move(allocation), error_reporter);
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
-    const char* caller_owned_buffer, size_t buffer_size,
-    TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
-  error_reporter = ValidateErrorReporter(error_reporter);
-  std::unique_ptr<Allocation> allocation(
-      new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
-  return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier,
-                                      error_reporter);
-}
-
-#if FLATBUFFERS_LITTLEENDIAN == 0
-
-void FlatBufferModel::ByteSwapSerializedModel(std::string* serialized_model,
-                                              bool from_big_endian) {
-  const uint8_t* buffer =
-      reinterpret_cast<const uint8_t*>(serialized_model->c_str());
-  const tflite::Model* input_model = tflite::GetModel(buffer);
-  ByteSwapTFLiteModel(input_model, from_big_endian);
-}
-
-void FlatBufferModel::ByteSwapBuffer(int8_t tensor_type, size_t buffer_size,
-                                     uint8_t* buffer, bool from_big_endian) {
-  switch (tensor_type) {
-    case tflite::TensorType_STRING: {
-      auto bp = reinterpret_cast<int32_t*>(buffer);
-      int num_of_strings =
-          from_big_endian ? bp[0] : flatbuffers::EndianSwap(bp[0]);
-      for (int i = 0; i < num_of_strings + 2; i++)
-        bp[i] = flatbuffers::EndianSwap(bp[i]);
-      break;
-    }
-    // 16-bit types
-    case tflite::TensorType_FLOAT16:
-    case tflite::TensorType_INT16:
-    case tflite::TensorType_UINT16: {
-      auto bp = reinterpret_cast<uint16_t*>(buffer);
-      for (int i = 0; i < buffer_size / 2; i++)
-        bp[i] = flatbuffers::EndianSwap(bp[i]);
-      break;
-    }
-    // 32-bit types
-    case tflite::TensorType_FLOAT32:
-    case tflite::TensorType_INT32:
-    case tflite::TensorType_UINT32:
-    case tflite::TensorType_COMPLEX64: {
-      auto bp = reinterpret_cast<uint32_t*>(buffer);
-      for (int i = 0; i < buffer_size / 4; i++)
-        bp[i] = flatbuffers::EndianSwap(bp[i]);
-      break;
-    }
-    // 64-bit types
-    case tflite::TensorType_INT64:
-    case tflite::TensorType_FLOAT64:
-    case tflite::TensorType_UINT64:
-    case tflite::TensorType_COMPLEX128: {
-      auto bp = reinterpret_cast<uint64_t*>(buffer);
-      for (int i = 0; i < buffer_size / 8; i++)
-        bp[i] = flatbuffers::EndianSwap(bp[i]);
-      break;
-    }
-    default:
-      break;
-  }
-}
-
-void FlatBufferModel::ByteSwapTFLiteModel(const tflite::Model* tfl_model,
-                                          bool from_big_endian) {
-  bool buffer_swapped[tfl_model->buffers()->size()] = {};
-  for (size_t subgraph_idx = 0; subgraph_idx < tfl_model->subgraphs()->size();
-       subgraph_idx++) {
-    const tflite::SubGraph* subgraph =
-        tfl_model->subgraphs()->Get(subgraph_idx);
-    for (size_t ts_idx = 0; ts_idx < subgraph->tensors()->size(); ts_idx++) {
-      const tflite::Tensor* tensor = subgraph->tensors()->Get(ts_idx);
-      if (tensor->buffer() > 0 &&
-          tensor->buffer() < tfl_model->buffers()->size() &&
-          !buffer_swapped[tensor->buffer()]) {
-        const tflite::Buffer* buffer_ =
-            (*tfl_model->buffers())[tensor->buffer()];
-        if (!buffer_ || !buffer_->data()) continue;
-        auto* buffer = buffer_->data();
-        uint8_t* buff_ = const_cast<uint8_t*>(buffer->data());
-        ByteSwapBuffer(tensor->type(), buffer->size(), buff_, from_big_endian);
-        buffer_swapped[tensor->buffer()] = true;
-      }
-    }
-  }
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::ByteConvertModel(
-    std::unique_ptr<FlatBufferModel> model, ErrorReporter* error_reporter,
-    bool from_big_endian) {
-  if (model == nullptr) return model;
-  auto tfl_model = model->GetModel();
-  if (tfl_model->subgraphs()->size() == 0) return model;
-  if (tfl_model->subgraphs()->Get(0)->tensors()->size() == 0) return model;
-  if (tfl_model->buffers()->size() < 2) return model;
-  return ByteSwapFlatBufferModel(std::move(model), error_reporter,
-                                 from_big_endian);
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::ByteSwapFlatBufferModel(
-    std::unique_ptr<FlatBufferModel> model, ErrorReporter* error_reporter,
-    bool from_big_endian) {
-  FlatBufferModel* modelp = model.release();
-  auto tflite_model = modelp->GetModel();
-  auto copied_model = std::make_unique<tflite::ModelT>();
-  tflite_model->UnPackTo(copied_model.get(), nullptr);
-  ByteSwapTFLiteModelT(copied_model.get(), from_big_endian);
-  std::unique_ptr<flatbuffers::FlatBufferBuilder> builder(
-      new flatbuffers::FlatBufferBuilder());
-  auto packed_model = tflite::Model::Pack(*builder, copied_model.get());
-  tflite::FinishModelBuffer(*builder, packed_model);
-  flatbuffers::FlatBufferBuilder* builder_ = builder.release();
-  return BuildFromBuffer(
-      reinterpret_cast<const char*>(builder_->GetBufferPointer()),
-      builder_->GetSize(), error_reporter);
-}
-
-void FlatBufferModel::ByteSwapTFLiteModelT(tflite::ModelT* tfl_modelt,
-                                           bool from_big_endian) {
-  size_t bytes_per_elem = 0;
-  bool buffer_swapped[tfl_modelt->buffers.size()] = {};
-  for (size_t subgraph_idx = 0; subgraph_idx < tfl_modelt->subgraphs.size();
-       subgraph_idx++) {
-    tflite::SubGraphT* subgraph = tfl_modelt->subgraphs.at(subgraph_idx).get();
-    for (size_t ts_idx = 0; ts_idx < subgraph->tensors.size(); ts_idx++) {
-      tflite::TensorT* tensor = subgraph->tensors[ts_idx].get();
-      if (tensor->buffer > 0 && tensor->buffer < tfl_modelt->buffers.size() &&
-          !buffer_swapped[tensor->buffer]) {
-        const auto* buffer = &(tfl_modelt->buffers[tensor->buffer].get()->data);
-        if (buffer && buffer->data()) {
-          uint8_t* buff_ = const_cast<uint8_t*>(buffer->data());
-          ByteSwapBuffer(tensor->type, buffer->size(), buff_, from_big_endian);
-          buffer_swapped[tensor->buffer] = true;
-        }
-      }
-    }
-  }
-}
-
-#endif
-
-void FlatBufferModel::ValidateModelBuffers(ErrorReporter* error_reporter) {
-  auto buffers = model_->buffers();
-  if (buffers && buffers->size() > 0) {
-    auto first_buffer = buffers->Get(0);
-    if (first_buffer && first_buffer->size() != 0) {
-      // Note the 0th entry of this array must be an empty buffer (sentinel).
-      // This is a convention so that tensors without a buffer can provide 0
-      // as their buffer.
-      TF_LITE_REPORT_ERROR(
-          error_reporter,
-          "The 0th entry of the model buffer must be an empty buffer.");
-    }
-  }
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromAllocation(
-    std::unique_ptr<Allocation> allocation, ErrorReporter* error_reporter) {
-  std::unique_ptr<FlatBufferModel> model(new FlatBufferModel(
-      std::move(allocation), ValidateErrorReporter(error_reporter)));
-  if (!model->initialized()) {
-    model.reset();
-  } else {
-    model->ValidateModelBuffers(error_reporter);
-  }
-  return model;
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromAllocation(
-    std::unique_ptr<Allocation> allocation, TfLiteVerifier* extra_verifier,
-    ErrorReporter* error_reporter) {
-  error_reporter = ValidateErrorReporter(error_reporter);
-  if (!allocation || !allocation->valid()) {
-    TF_LITE_REPORT_ERROR(error_reporter, "The model allocation is null/empty");
-    return nullptr;
-  }
-
-  {
-    // Flatbuffers can only be smaller than 2GB. The file format appends some
-    // data after the actual flabuffer. We truncate the allocation size to 2GB
-    // so that the verifier doesn't early exit on us.
-    size_t allocation_size =
-        std::min(allocation->bytes(),
-                 static_cast<size_t>(FLATBUFFERS_MAX_BUFFER_SIZE - 1));
-    flatbuffers::Verifier base_verifier(
-        reinterpret_cast<const uint8_t*>(allocation->base()), allocation_size);
-    if (!VerifyModelBuffer(base_verifier)) {
-      TF_LITE_REPORT_ERROR(error_reporter,
-                           "The model is not a valid Flatbuffer buffer");
-      return nullptr;
-    }
-
-    if (extra_verifier &&
-        !extra_verifier->Verify(static_cast<const char*>(allocation->base()),
-                                allocation_size, error_reporter)) {
-      // The verifier will have already logged an appropriate error message.
-      return nullptr;
-    }
-  }
-
-  return BuildFromAllocation(std::move(allocation), error_reporter);
-}
-
-std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
-    const tflite::Model* caller_owned_model_spec,
-    ErrorReporter* error_reporter) {
-  error_reporter = ValidateErrorReporter(error_reporter);
-
-  if (CheckBufferOutsideModel(caller_owned_model_spec)) {
-    TF_LITE_REPORT_ERROR(error_reporter,
-                         "The model contains weights not accessible from "
-                         "tflite::Model *, please use other api");
-    return nullptr;
-  }
-
-  std::unique_ptr<FlatBufferModel> model(
-      new FlatBufferModel(caller_owned_model_spec, error_reporter));
-  if (!model->initialized()) {
-    model.reset();
-  } else {
-    model->ValidateModelBuffers(error_reporter);
-  }
-  return model;
-}
-
-bool FlatBufferModel::CheckBufferOutsideModel(const tflite::Model* model) {
-  if (!model || !model->metadata()) return false;
-
-  for (int i = 0; i < model->metadata()->size(); ++i) {
-    auto metadata = model->metadata()->Get(i);
-    if (metadata->name()->str() == tflite_metadata_buffer_location) {
-      return true;
-    }
-  }
-  return false;
-}
-
-string FlatBufferModel::GetMinimumRuntime() const {
-  if (!model_ || !model_->metadata()) return "";
-
-  for (int i = 0; i < model_->metadata()->size(); ++i) {
-    auto metadata = model_->metadata()->Get(i);
-    if (metadata->name()->str() == tflite_metadata_min_runtime_version) {
-      auto buf = metadata->buffer();
-      auto* buffer = (*model_->buffers())[buf];
-      auto* array = buffer->data();
-      // Get the real length of the runtime string, since there might be
-      // trailing
-      // '\0's in the buffer.
-      for (int len = 0; len < array->size(); ++len) {
-        if (array->data()[len] == '\0') {
-          return string(reinterpret_cast<const char*>(array->data()), len);
-        }
-      }
-      // If there is no '\0' in the buffer, this indicates that the flatbuffer
-      // is malformed.
-      TF_LITE_REPORT_ERROR(
-          error_reporter_,
-          "Min_runtime_version in model metadata is malformed");
-      break;
-    }
-  }
-  return "";
-}
-
-std::map<std::string, std::string> FlatBufferModel::ReadAllMetadata() const {
-  return ReadAllMetadata(model_);
-}
-
-std::map<std::string, std::string> FlatBufferModel::ReadAllMetadata(
-    const tflite::Model* model) {
-  std::map<std::string, std::string> keys_values;
-  if (!model || !model->metadata() || !model->buffers()) return keys_values;
-
-  for (int i = 0; i < model->metadata()->size(); ++i) {
-    auto metadata = model->metadata()->Get(i);
-    auto buf = metadata->buffer();
-    if (buf >= model->buffers()->size()) continue;
-    const tflite::Buffer* buffer = (*model->buffers())[buf];
-    if (!buffer || !buffer->data()) continue;
-    const flatbuffers::Vector<uint8_t>* array = buffer->data();
-    if (!array) continue;
-    std::string val =
-        string(reinterpret_cast<const char*>(array->data()), array->size());
-    // Skip if key or value of metadata is empty.
-    if (!metadata->name() || val.empty()) continue;
-    keys_values[metadata->name()->str()] = val;
-  }
-  return keys_values;
-}
-
-bool FlatBufferModel::CheckModelIdentifier() const {
-  if (allocation_->bytes() < 7) {
-    TF_LITE_REPORT_ERROR(
-        error_reporter_,
-        "Model provided must have at least 7 bytes to hold identifier.\n");
-    return false;
-  }
-  if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
-    const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
-    TF_LITE_REPORT_ERROR(
-        error_reporter_,
-        "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
-        ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
-    return false;
-  }
-  return true;
-}
-
-FlatBufferModel::FlatBufferModel(const Model* model,
-                                 ErrorReporter* error_reporter)
-    : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
-
-FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
-                                 ErrorReporter* error_reporter)
-    : error_reporter_(ValidateErrorReporter(error_reporter)),
-      allocation_(std::move(allocation)) {
-  if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) {
-    return;
-  }
-
-  model_ = ::tflite::GetModel(allocation_->base());
-}
-
-FlatBufferModel::~FlatBufferModel() {}
-
-}  // namespace tflite
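
The 474 deleted lines are not gone: the implementation moved into the templated
FlatBufferModelBase under compiler/mlir/lite/core, which this same change adds
to the CMake source lists. One pattern worth keeping in mind when reading the
new base class is the mmap-with-fallback file loading the old file used; a
compressed sketch, assuming the Allocation classes are unchanged by the move:

    #include <memory>

    #include "tensorflow/lite/allocation.h"
    #include "tensorflow/lite/core/api/error_reporter.h"

    // Prefer a zero-copy mmap of the model file; fall back to copying the whole
    // file into a heap buffer on platforms without mmap support.
    std::unique_ptr<tflite::Allocation> OpenModel(const char* filename,
                                                  tflite::ErrorReporter* reporter) {
      if (tflite::MMAPAllocation::IsSupported()) {
        return std::make_unique<tflite::MMAPAllocation>(filename, reporter);
      }
      return std::make_unique<tflite::FileCopyAllocation>(filename, reporter);
    }
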
diff --git a/tensorflow/lite/core/model_builder.h b/tensorflow/lite/core/model_builder.h
index 3337d93..9a36c4b 100644
--- a/tensorflow/lite/core/model_builder.h
+++ b/tensorflow/lite/core/model_builder.h
@@ -22,260 +22,30 @@
 /// but should instead include "third_party/tensorflow/lite/model_builder.h".
 /// Only the TensorFlow Lite implementation itself should include this
 /// file directly.
+
 #ifndef TENSORFLOW_LITE_CORE_MODEL_BUILDER_H_
 #define TENSORFLOW_LITE_CORE_MODEL_BUILDER_H_
 
 #include <stddef.h>
 
-#include <map>
-#include <memory>
-#include <string>
-
-#include "tensorflow/lite/allocation.h"
+#include "tensorflow/compiler/mlir/lite/core/model_builder_base.h"
 #include "tensorflow/lite/core/api/error_reporter.h"
-#include "tensorflow/lite/core/api/verifier.h"
-#include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/lite/stderr_reporter.h"
 
 namespace tflite {
 
-/// An RAII object that represents a read-only tflite model, copied from disk,
-/// or mmapped. This uses flatbuffers as the serialization format.
-///
-/// NOTE: The current API requires that a FlatBufferModel instance be kept alive
-/// by the client as long as it is in use by any dependent Interpreter
-/// instances. As the FlatBufferModel instance is effectively immutable after
-/// creation, the client may safely use a single model with multiple dependent
-/// Interpreter instances, even across multiple threads (though note that each
-/// Interpreter instance is *not* thread-safe).
-///
-/// <pre><code>
-/// using namespace tflite;
-/// StderrReporter error_reporter;
-/// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite",
-///                                             &error_reporter);
-/// MyOpResolver resolver;  // You need to subclass OpResolver to provide
-///                         // implementations.
-/// InterpreterBuilder builder(*model, resolver);
-/// std::unique_ptr<Interpreter> interpreter;
-/// if(builder(&interpreter) == kTfLiteOk) {
-///   .. run model inference with interpreter
-/// }
-/// </code></pre>
-///
-/// OpResolver must be defined to provide your kernel implementations to the
-/// interpreter. This is environment specific and may consist of just the
-/// builtin ops, or some custom operators you defined to extend tflite.
 namespace impl {
 
-class FlatBufferModel {
+class FlatBufferModel : public FlatBufferModelBase<FlatBufferModel> {
  public:
-  /// Builds a model based on a file.
-  /// Caller retains ownership of `error_reporter` and must ensure its lifetime
-  /// is longer than the FlatBufferModel instance.
-  /// Returns a nullptr in case of failure.
-  static std::unique_ptr<FlatBufferModel> BuildFromFile(
-      const char* filename,
-      ErrorReporter* error_reporter = DefaultErrorReporter());
+  // Use stderr_reporter as the default error reporter.
+  static ErrorReporter* GetDefaultErrorReporter() {
+    return DefaultErrorReporter();
+  }
 
-  /// Verifies whether the content of the file is legit, then builds a model
-  /// based on the file.
-  /// The extra_verifier argument is an additional optional verifier for the
-  /// file contents. By default, we always check with tflite::VerifyModelBuffer.
-  /// If extra_verifier is supplied, the file contents is also checked against
-  /// the extra_verifier after the check against tflite::VerifyModelBuilder.
-  /// Caller retains ownership of `error_reporter` and must ensure its lifetime
-  /// is longer than the FlatBufferModel instance.
-  /// Returns a nullptr in case of failure.
-  static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile(
-      const char* filename, TfLiteVerifier* extra_verifier = nullptr,
-      ErrorReporter* error_reporter = DefaultErrorReporter());
-
-  /// Builds a model based on a file descriptor.
-  /// Caller retains ownership of `error_reporter` and must ensure its lifetime
-  /// is longer than the FlatBufferModel instance. Caller retains ownership of
-  /// `fd` and must ensure it is closed after BuildFromFile returns.
-  /// Returns a nullptr in case of failure.
-  static std::unique_ptr<FlatBufferModel> BuildFromFileDescriptor(
-      int fd,
-      ErrorReporter* error_reporter = DefaultErrorReporter());
-
-  /// Verifies whether the content of the file descriptor is legit, then builds
-  /// a model based on the file.
-  /// The extra_verifier argument is an additional optional verifier for the
-  /// file contents. By default, we always check with tflite::VerifyModelBuffer.
-  /// If extra_verifier is supplied, the file contents is also checked against
-  /// the extra_verifier after the check against tflite::VerifyModelBuilder.
-  /// Caller retains ownership of `error_reporter` and must ensure its lifetime
-  /// is longer than the FlatBufferModel instance.
-  /// Returns a nullptr in case of failure.
-  static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFileDescriptor(
-      int fd, TfLiteVerifier* extra_verifier = nullptr,
-      ErrorReporter* error_reporter = DefaultErrorReporter());
-
-  /// Builds a model based on a pre-loaded flatbuffer.
-  /// Caller retains ownership of the buffer and should keep it alive until
-  /// the returned object is destroyed. Caller also retains ownership of
-  /// `error_reporter` and must ensure its lifetime is longer than the
-  /// FlatBufferModel instance.
-  /// Returns a nullptr in case of failure.
-  /// NOTE: this does NOT validate the buffer so it should NOT be called on
-  /// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case
-  static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
-      const char* caller_owned_buffer, size_t buffer_size,
-      ErrorReporter* error_reporter = DefaultErrorReporter());
-
-  /// Verifies whether the content of the buffer is legit, then builds a model
-  /// based on the pre-loaded flatbuffer.
-  /// The extra_verifier argument is an additional optional verifier for the
-  /// buffer. By default, we always check with tflite::VerifyModelBuffer. If
-  /// extra_verifier is supplied, the buffer is checked against the
-  /// extra_verifier after the check against tflite::VerifyModelBuilder. The
-  /// caller retains ownership of the buffer and should keep it alive until the
-  /// returned object is destroyed. Caller retains ownership of `error_reporter`
-  /// and must ensure its lifetime is longer than the FlatBufferModel instance.
-  /// Returns a nullptr in case of failure.
-  static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromBuffer(
-      const char* caller_owned_buffer, size_t buffer_size,
-      TfLiteVerifier* extra_verifier = nullptr,
-      ErrorReporter* error_reporter = DefaultErrorReporter());
-
-  /// Builds a model directly from an allocation.
-  /// Ownership of the allocation is passed to the model, but the caller
-  /// retains ownership of `error_reporter` and must ensure its lifetime is
-  /// longer than the FlatBufferModel instance.
-  /// Returns a nullptr in case of failure (e.g., the allocation is invalid).
-  static std::unique_ptr<FlatBufferModel> BuildFromAllocation(
-      std::unique_ptr<Allocation> allocation,
-      ErrorReporter* error_reporter = DefaultErrorReporter());
-
-  /// Verifies whether the content of the allocation is legit, then builds a
-  /// model based on the provided allocation.
-  /// The extra_verifier argument is an additional optional verifier for the
-  /// buffer. By default, we always check with tflite::VerifyModelBuffer. If
-  /// extra_verifier is supplied, the buffer is checked against the
-  /// extra_verifier after the check against tflite::VerifyModelBuilder.
-  /// Ownership of the allocation is passed to the model, but the caller
-  /// retains ownership of `error_reporter` and must ensure its lifetime is
-  /// longer than the FlatBufferModel instance.
-  /// Returns a nullptr in case of failure.
-  static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromAllocation(
-      std::unique_ptr<Allocation> allocation,
-      TfLiteVerifier* extra_verifier = nullptr,
-      ErrorReporter* error_reporter = DefaultErrorReporter());
-
-  /// Builds a model directly from a flatbuffer pointer
-  /// Caller retains ownership of the buffer and should keep it alive until the
-  /// returned object is destroyed. Caller retains ownership of `error_reporter`
-  /// and must ensure its lifetime is longer than the FlatBufferModel instance.
-  /// Returns a nullptr in case of failure.
-  static std::unique_ptr<FlatBufferModel> BuildFromModel(
-      const tflite::Model* caller_owned_model_spec,
-      ErrorReporter* error_reporter = DefaultErrorReporter());
-
-#if FLATBUFFERS_LITTLEENDIAN == 0
-  /// Byte swap a constant buffer in place.
-  static void ByteSwapBuffer(int8_t tensor_type, size_t buffer_size,
-                             uint8_t* buffer, bool from_big_endian = true);
-
-  /// Byte swap the buffers field of a TFLite Model instance in place.
-  static void ByteSwapTFLiteModel(const tflite::Model* tfl_model,
-                                  bool from_big_endian = true);
-
-  /// Byte swap the buffers field of a TFLite ModelT instance in place.
-  static void ByteSwapTFLiteModelT(tflite::ModelT* tfl_modelt,
-                                   bool from_big_endian = true);
-
-  /// Convert the TFLite buffers field between LE and BE format in a
-  /// FlatBufferModel which is not empty and return the converted instance.
-  static std::unique_ptr<FlatBufferModel> ByteConvertModel(
-      std::unique_ptr<FlatBufferModel> model,
-      ErrorReporter* error_reporter = DefaultErrorReporter(),
-      bool from_big_endian = false);
-
-  /// Byte Swap the TFLite buffers field in a FlatBufferModel and return the
-  /// swapped instance.
-  static std::unique_ptr<FlatBufferModel> ByteSwapFlatBufferModel(
-      std::unique_ptr<FlatBufferModel> model,
-      ErrorReporter* error_reporter = DefaultErrorReporter(),
-      bool from_big_endian = false);
-
-  /// Byte Swap the serialized String of a TFLite model in place.
-  static void ByteSwapSerializedModel(std::string* serialized_model,
-                                      bool from_big_endian = true);
-#endif
-
-  // Releases memory or unmaps mmaped memory.
-  ~FlatBufferModel();
-
-  // Copying or assignment is disallowed to simplify ownership semantics.
-  FlatBufferModel(const FlatBufferModel&) = delete;
-  FlatBufferModel& operator=(const FlatBufferModel&) = delete;
-
-  bool initialized() const { return model_ != nullptr; }
-  const tflite::Model* operator->() const { return model_; }
-  const tflite::Model* GetModel() const { return model_; }
-  ErrorReporter* error_reporter() const { return error_reporter_; }
-  const Allocation* allocation() const { return allocation_.get(); }
-
-  // Returns the minimum runtime version from the flatbuffer. This runtime
-  // version encodes the minimum required interpreter version to run the
-  // flatbuffer model. If the minimum version can't be determined, an empty
-  // string will be returned.
-  // Note that the returned minimum version is a lower-bound but not a strict
-  // lower-bound; ops in the graph may not have an associated runtime version,
-  // in which case the actual required runtime might be greater than the
-  // reported minimum.
-  std::string GetMinimumRuntime() const;
-
-  // Return model metadata as a mapping of name & buffer strings.
-  // See Metadata table in TFLite schema.
-  std::map<std::string, std::string> ReadAllMetadata() const;
-
-  // Return model metadata as a mapping of name & buffer strings.
-  // See Metadata table in TFLite schema.
-  static std::map<std::string, std::string> ReadAllMetadata(
-      const ::tflite::Model* model);
-
-  // If the buffer is stored as part of the Flatbuffer or outside
-  // return false if the buffers are part of the Flatbuffer
-  static bool CheckBufferOutsideModel(const tflite::Model* model);
-
-  // Validates if the FlatBufferModel's buffer is well-formed. Specifically, it
-  // checks if the 0th entry of the model buffers is an empty buffer (sentinel).
-  // This is a convention so that tensors without a buffer can provide 0
-  // as their buffer.
-  // NOTE: The function doesn't explicitly fail for backward compatibility
-  // reasons; it just provides a warning in case of failures.
-  void ValidateModelBuffers(ErrorReporter* error_reporter);
-
-  /// Returns true if the model identifier is correct (otherwise false and
-  /// reports an error).
-  bool CheckModelIdentifier() const;
-
- private:
-  /// Loads a model from a given allocation. FlatBufferModel will take over the
-  /// ownership of `allocation`, and delete it in destructor. The ownership of
-  /// `error_reporter`remains with the caller and must have lifetime at least
-  /// as much as FlatBufferModel. This is to allow multiple models to use the
-  /// same ErrorReporter instance.
-  explicit FlatBufferModel(
-      std::unique_ptr<Allocation> allocation,
-      ErrorReporter* error_reporter = DefaultErrorReporter());
-
-  /// Loads a model from Model flatbuffer. The `model` has to remain alive and
-  /// unchanged until the end of this flatbuffermodel's lifetime.
-  FlatBufferModel(const Model* model, ErrorReporter* error_reporter);
-
-  /// Flatbuffer traverser pointer. (Model* is a pointer that is within the
-  /// allocated memory of the data allocated by allocation's internals.
-  const tflite::Model* model_ = nullptr;
-  /// The error reporter to use for model errors and subsequent errors when
-  /// the interpreter is created
-  ErrorReporter* error_reporter_;
-  /// The allocator used for holding memory of the model. Note that this will
-  /// be null if the client provides a tflite::Model directly.
-  std::unique_ptr<Allocation> allocation_;
+  // Inherit all constructors from FlatBufferModelBase since inherited factory
+  // methods refer to them.
+  using FlatBufferModelBase<FlatBufferModel>::FlatBufferModelBase;
 };
 
 }  // namespace impl
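
What remains of FlatBufferModel is a thin CRTP wrapper: the factory methods and
verification logic live in FlatBufferModelBase<Derived>, which calls back into
the derived class for environment-specific pieces such as the default error
reporter. A heavily trimmed, hypothetical illustration of that shape (the real
base class has a much larger surface):

    #include <cstdio>
    #include <memory>

    // The shared factory logic sits in the base template; the derived class only
    // supplies defaults. Names below are illustrative, not the real API.
    template <typename Derived>
    class ModelBaseSketch {
     public:
      static std::unique_ptr<Derived> Build() {
        Derived::GetDefaultErrorReporter()("building model");
        return std::make_unique<Derived>();
      }
    };

    class ModelSketch : public ModelBaseSketch<ModelSketch> {
     public:
      // The real FlatBufferModel returns DefaultErrorReporter() here.
      static void (*GetDefaultErrorReporter())(const char*) {
        return [](const char* msg) { std::fprintf(stderr, "%s\n", msg); };
      }
    };
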
diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD
index 73f192d..b84cb9a 100644
--- a/tensorflow/lite/delegates/gpu/cl/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/BUILD
@@ -299,6 +299,7 @@
         ":cl_kernel",
         ":program_cache",
         ":tensor",
+        "//tensorflow/lite/delegates/gpu/common/task:compiler_options",
         "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@com_google_absl//absl/strings",
     ],
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc
index 1cc1738..8fd9493 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc
@@ -17,6 +17,8 @@
 
 #include <string>
 
+#include "tensorflow/lite/delegates/gpu/common/task/compiler_options.h"
+
 namespace tflite {
 namespace gpu {
 namespace cl {
@@ -165,6 +167,10 @@
                                 creation_context.context, &operation_->args_,
                                 &operation_->code_));
   operation_->args_.ReleaseCPURepresentation();
+  if (creation_context.device->info_.opencl_info.IsCLVK()) {
+    operation_->compiler_options_.push_back(
+        CompilerOptions::kClFastRelaxedMath);
+  }
   RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
       operation_->code_, "main_function", operation_->compiler_options_,
       *creation_context.context, *creation_context.device, &kernel_,
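
CompilerOptions::kClFastRelaxedMath presumably surfaces as the standard
-cl-fast-relaxed-math build option when the kernel is compiled for a clvk
(OpenCL-over-Vulkan) device. In raw OpenCL terms the effect is roughly:

    #include <CL/cl.h>

    // Raw OpenCL view of appending a fast-relaxed-math option: it ends up in the
    // build-options string passed to clBuildProgram (illustrative only; the
    // delegate assembles this string internally).
    cl_int BuildWithFastRelaxedMath(cl_program program, cl_device_id device) {
      return clBuildProgram(program, /*num_devices=*/1, &device,
                            "-cl-fast-relaxed-math",
                            /*pfn_notify=*/nullptr, /*user_data=*/nullptr);
    }
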
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc
index 7a40a60..5bd407d 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc
@@ -75,6 +75,7 @@
 
   // Ordered records are to be sorted by size of corresponding tensor.
   std::vector<TensorUsageWithIndex<size_t>> ordered_records;
+  ordered_records.reserve(num_tensors);
   for (size_t i = 0; i < num_tensors; ++i) {
     ordered_records.emplace_back(&usage_records[i], i);
   }
diff --git a/tensorflow/lite/delegates/gpu/common/model.cc b/tensorflow/lite/delegates/gpu/common/model.cc
index dc68e70..a7a174f 100644
--- a/tensorflow/lite/delegates/gpu/common/model.cc
+++ b/tensorflow/lite/delegates/gpu/common/model.cc
@@ -333,10 +333,16 @@
   model->nodes_.clear();
   model->execution_plan_.clear();
   model->values_.clear();
+  model->known_graph_outputs_.clear();
   for (auto& value_def : values_) {
     model->values_.push_back({});
     if (value_def.value) {
       model->values_.back().value = std::make_unique<Value>(*value_def.value);
+      if (std::find(known_graph_outputs_.begin(), known_graph_outputs_.end(),
+                    value_def.value.get()) != known_graph_outputs_.end()) {
+        model->known_graph_outputs_.push_back(
+            model->values_.back().value.get());
+      }
     }
   }
   // Add all nodes first.
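
The clone fix above also has to translate raw pointers: known_graph_outputs_
points at Value objects owned by the source graph, so the copy must record the
corresponding freshly cloned objects instead of the originals. The generic
pattern, sketched with hypothetical types:

    #include <algorithm>
    #include <memory>
    #include <vector>

    struct Value { int id; };

    // Deep-copies `values` and rebuilds `outputs` (a subset of `values`, held by
    // raw pointer) so that it points into the copy rather than the original.
    void CloneWithRemap(const std::vector<std::unique_ptr<Value>>& values,
                        const std::vector<Value*>& outputs,
                        std::vector<std::unique_ptr<Value>>* cloned_values,
                        std::vector<Value*>* cloned_outputs) {
      cloned_values->clear();
      cloned_outputs->clear();
      for (const auto& v : values) {
        cloned_values->push_back(std::make_unique<Value>(*v));
        if (std::find(outputs.begin(), outputs.end(), v.get()) != outputs.end()) {
          cloned_outputs->push_back(cloned_values->back().get());
        }
      }
    }
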
diff --git a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc
index f80a52d..8554124 100644
--- a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc
+++ b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc
@@ -56,7 +56,7 @@
       LoadDelegateFromSharedLibrary(
           "tensorflow/lite/delegates/utils/experimental/"
           "sample_stable_delegate/"
-          "libtensorflowlite_sample_stable_delegate_for_test.so");
+          "libtensorflowlite_sample_stable_delegate.so");
   ASSERT_NE(stable_delegate_handle, nullptr);
   EXPECT_STREQ(stable_delegate_handle->delegate_abi_version,
                TFL_STABLE_DELEGATE_ABI_VERSION);
diff --git a/tensorflow/lite/examples/label_image/CMakeLists.txt b/tensorflow/lite/examples/label_image/CMakeLists.txt
index 9874801..2fcb09c 100644
--- a/tensorflow/lite/examples/label_image/CMakeLists.txt
+++ b/tensorflow/lite/examples/label_image/CMakeLists.txt
@@ -61,6 +61,11 @@
           ${TFLITE_SOURCE_DIR}/tools/delegates/external_delegate_provider.cc)
 endif()
 
+include_directories(label_image
+  PUBLIC
+  ${CMAKE_BINARY_DIR}
+)
+
 add_executable(label_image
   ${TFLITE_LABEL_IMAGE_SRCS}
 )
@@ -78,4 +83,6 @@
 )
 target_link_libraries(label_image
   tensorflow-lite
+  profiling_info_proto
+  protobuf
 )
diff --git a/tensorflow/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD
index 8ce5e0c..b6c518d 100644
--- a/tensorflow/lite/java/src/main/native/BUILD
+++ b/tensorflow/lite/java/src/main/native/BUILD
@@ -22,6 +22,7 @@
     visibility = jni_utils_visibility_allowlist(),
     deps = [
         "//tensorflow/lite:error_reporter",
+        "//tensorflow/lite/core/c:common",
         "//tensorflow/lite/java/jni",
     ],
 )
diff --git a/tensorflow/lite/java/src/main/native/jni_utils.h b/tensorflow/lite/java/src/main/native/jni_utils.h
index 1602d77..1796a38 100644
--- a/tensorflow/lite/java/src/main/native/jni_utils.h
+++ b/tensorflow/lite/java/src/main/native/jni_utils.h
@@ -21,6 +21,7 @@
 
 #include <vector>
 
+#include "tensorflow/lite/core/c/common.h"
 #include "tensorflow/lite/error_reporter.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/kernels/embedding_lookup.cc b/tensorflow/lite/kernels/embedding_lookup.cc
index 4190fd7..d927010 100644
--- a/tensorflow/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/lite/kernels/embedding_lookup.cc
@@ -104,13 +104,13 @@
     // Propagate empty tensor if input is empty
     return kTfLiteOk;
   }
-  const int row_bytes = value->bytes / row_size;
+  const int64_t row_bytes = value->bytes / row_size;
 
   char* output_raw = GetTensorData<char>(output);
   const char* value_raw = GetTensorData<char>(value);
   const int32_t* lookup_data = GetTensorData<int32_t>(lookup);
   for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
-    int idx = lookup_data[i];
+    int64_t idx = lookup_data[i];
     if (idx >= row_size || idx < 0) {
       TF_LITE_KERNEL_LOG(context,
                          "Embedding Lookup: index out of bounds. "
diff --git a/tensorflow/lite/kernels/embedding_lookup_test.cc b/tensorflow/lite/kernels/embedding_lookup_test.cc
index d13ddd4..679975f 100644
--- a/tensorflow/lite/kernels/embedding_lookup_test.cc
+++ b/tensorflow/lite/kernels/embedding_lookup_test.cc
@@ -92,6 +92,19 @@
       }
     }
   }
+
+  template <typename T>
+  void Set2DWeightMatrix(const std::function<T(int, int)>& function) {
+    TfLiteTensor* tensor = interpreter_->tensor(weight_);
+    int64_t rows = tensor->dims->data[0];
+    int64_t columns = tensor->dims->data[1];
+    T* data = GetTensorData<T>(tensor);
+    for (int64_t i = 0; i < rows; i++) {
+      for (int64_t j = 0; j < columns; j++) {
+        data[i * columns + j] = function(i, j);
+      }
+    }
+  }
 };
 
 class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
@@ -144,6 +157,27 @@
               })));
 }
 
+#if !defined(GOOGLE_UNSUPPORTED_OS_LOONIX) && defined(__LP64__)
+TEST(EmbeddingLookupOpTest, LargeTableTest) {
+  EmbeddingLookupOpModel m({1}, {256000, 9216});
+  // Chosen specifically so that the byte offset overflows INT32_MAX.
+  m.SetInput({235248});
+  m.Set2DWeightMatrix<float>(
+      [](int i, int j) -> float { return j + i / 100.; });
+
+  // This will cause a lookup at index 235248 in a buffer where every row
+  // has 9216 entries * 4 bytes per entry, which will overflow unless
+  // the Op is using a 64-bit offset for address calculation.
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  std::vector<float> exp(9216);
+
+  for (int s = 0; s < exp.size(); s++) {
+    exp[s] = static_cast<float>(s) + 2352.48f;
+  }
+  EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear(exp)));
+}
+#endif
+
 TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTestUint8) {
   HybridEmbeddingLookupOpModel m({3}, {3, 8}, TensorType_UINT8);
   m.SetInput({1, 0, 2});
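
The int64_t widening in embedding_lookup.cc above is the whole fix, and the new
LargeTableTest exercises it: row 235248 of a 256000 x 9216 float table starts at
byte offset 235248 * 9216 * 4 ≈ 8.7e9, far past INT32_MAX ≈ 2.1e9, so 32-bit
offset arithmetic wrapped around. A sketch of the safe address calculation (the
real kernel also range-checks idx, as shown in the context lines):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // 64-bit offset arithmetic for the per-row copy done by embedding lookup.
    // With 32-bit arithmetic, idx * row_bytes overflows for large tables.
    void CopyRow(const char* table, int64_t row_bytes, int32_t idx, char* out) {
      const int64_t offset = static_cast<int64_t>(idx) * row_bytes;
      std::memcpy(out, table + offset, static_cast<std::size_t>(row_bytes));
    }
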
diff --git a/tensorflow/lite/kernels/internal/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/portable_tensor_utils.h
index d37fe6e..ed59fd0 100644
--- a/tensorflow/lite/kernels/internal/portable_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/portable_tensor_utils.h
@@ -317,7 +317,7 @@
 void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
                   int16_t* output);
 
-// Same as above but the internal calcualtion is float.
+// Same as above but the internal calculation is float.
 void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
                        int16_t* output);
 
@@ -333,7 +333,7 @@
 void ApplyTanh(int32_t intger_bits, const int16_t* input, int32_t n_batch,
                int32_t n_input, int16_t* output);
 
-// Apply Tanh to a quantized vector. Tbe internal calculation is in float.
+// Apply Tanh to a quantized vector. The internal calculation is in float.
 //    - Input has 2^(integer_bits) as scale.
 //    - Output has Q0.15 as scale.
 void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
diff --git a/tensorflow/lite/kernels/internal/quantization_util_test.cc b/tensorflow/lite/kernels/internal/quantization_util_test.cc
index aec0b2b..aa9c274 100644
--- a/tensorflow/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/lite/kernels/internal/quantization_util_test.cc
@@ -160,13 +160,13 @@
 //  255       | 30.0
 //  128       | 10.0
 TEST(QuantizationUtilTest, ChooseQuantizationParams) {
-  QuantizationParams qp = ChooseQuantizationParams<uint8>(-10.0, 30.0);
+  QuantizationParams qp = ChooseQuantizationParams<uint8_t>(-10.0, 30.0);
   EXPECT_NEAR(qp.scale, 0.156863, 1e-5);
   EXPECT_EQ(qp.zero_point, 64);
 }
 
 TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMinBoundary) {
-  QuantizationParams qp = ChooseQuantizationParams<uint8>(0.0, 30.0);
+  QuantizationParams qp = ChooseQuantizationParams<uint8_t>(0.0, 30.0);
   EXPECT_NEAR(qp.scale, 0.117647, 1e-5);
   EXPECT_EQ(qp.zero_point, 0);
 }
@@ -174,23 +174,23 @@
 #if GTEST_HAS_DEATH_TEST
 TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroNotInRange) {
   // Assumption is that zero is within the range.
-  EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, 30.0), "");
+  EXPECT_DEATH(ChooseQuantizationParams<uint8_t>(10.0, 30.0), "");
 }
 
 TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangePositive) {
   // Assumption is that zero is within the range.
-  EXPECT_DEATH(ChooseQuantizationParams<uint8>(30.0, 30.0), "");
+  EXPECT_DEATH(ChooseQuantizationParams<uint8_t>(30.0, 30.0), "");
 }
 #endif  // GTEST_HAS_DEATH_TEST
 
 TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangeZero) {
-  QuantizationParams qp = ChooseQuantizationParams<uint8>(0.0, 0.0);
+  QuantizationParams qp = ChooseQuantizationParams<uint8_t>(0.0, 0.0);
   EXPECT_NEAR(qp.scale, 0.0, 1e-5);
   EXPECT_EQ(qp.zero_point, 0);
 }
 
 TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) {
-  QuantizationParams qp = ChooseQuantizationParams<uint8>(-10.0, 0.0);
+  QuantizationParams qp = ChooseQuantizationParams<uint8_t>(-10.0, 0.0);
   EXPECT_NEAR(qp.scale, 0.039216, 1e-5);
   EXPECT_EQ(qp.zero_point, 255);
 }
@@ -330,7 +330,7 @@
 
 #if GTEST_HAS_DEATH_TEST
 TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
-  EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), "");
+  EXPECT_DEATH(ChooseQuantizationParams<uint8_t>(10.0, -30.0), "");
 }
 
 TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOneExp) {
@@ -533,12 +533,12 @@
   const std::vector<double> weights = {-4,    -2,   -1,  -0.5, -0.25, -0.125, 0,
                                        0.125, 0.25, 0.5, 1,    2,     4};
   const int size = weights.size();
-  std::vector<int32> effective_scale_significand(size);
+  std::vector<int32_t> effective_scale_significand(size);
   std::vector<int> effective_scale_shift(size);
   QuantizeMultiplierArray(weights.data(), size,
                           effective_scale_significand.data(),
                           effective_scale_shift.data());
-  const std::vector<int32> expected_effective_scale_significand = {
+  const std::vector<int32_t> expected_effective_scale_significand = {
       -1073741824,  // float scale = -4
       -1073741824,  // float scale = -2
       -1073741824,  // float scale = -1
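
For reference, the expectations in the quantization tests above follow directly from the usual asymmetric-quantization formulas (scale = (rmax - rmin) / 255, zero_point = round(-rmin / scale)) and from the Q31 significand/shift decomposition checked by QuantizeMultiplierArray. Below is a minimal standalone sketch that recomputes those values; it is illustrative only, not the TFLite implementation (the real routines additionally nudge and clamp edge cases).

// Illustrative sketch only, not the TFLite implementation: recomputes the
// values the updated tests expect, assuming the standard asymmetric
// quantization formulas and a Q31 significand/shift decomposition.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // ChooseQuantizationParams<uint8_t>(-10.0, 30.0) in the test above.
  const double rmin = -10.0, rmax = 30.0;
  const double scale = (rmax - rmin) / 255.0;  // qmax - qmin = 255 for uint8_t
  const int zero_point = static_cast<int>(std::round(-rmin / scale));
  std::printf("scale=%f zero_point=%d\n", scale, zero_point);

  // QuantizeMultiplier-style decomposition of a real multiplier into a Q31
  // significand and a power-of-two shift: multiplier == q * 2^(shift - 31).
  const double multiplier = -4.0;  // first weight in the array test above
  int shift = 0;
  const double fraction = std::frexp(multiplier, &shift);  // -0.5, shift = 3
  const int32_t q = static_cast<int32_t>(std::round(fraction * (1LL << 31)));
  std::printf("significand=%d shift=%d\n", static_cast<int>(q), shift);
  return 0;
}

This prints scale=0.156863, zero_point=64 and significand=-1073741824, shift=3, matching the values asserted in the hunks above.
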
diff --git a/tensorflow/lite/kernels/parse_example/parse_example.cc b/tensorflow/lite/kernels/parse_example/parse_example.cc
index acec033..ec87aab 100644
--- a/tensorflow/lite/kernels/parse_example/parse_example.cc
+++ b/tensorflow/lite/kernels/parse_example/parse_example.cc
@@ -111,7 +111,7 @@
 bool ParseExample(StringRef serialized, Example* example) {
   DCHECK(example != nullptr);
   tf::protobuf::io::CodedInputStream stream(
-      reinterpret_cast<const uint8*>(serialized.str), serialized.len);
+      reinterpret_cast<const uint8_t*>(serialized.str), serialized.len);
   tensorflow::example::EnableAliasing(&stream);
   return ParseExample(&stream, example);
 }
diff --git a/tensorflow/lite/kernels/shim/BUILD b/tensorflow/lite/kernels/shim/BUILD
index 3244635..9aa10f1 100644
--- a/tensorflow/lite/kernels/shim/BUILD
+++ b/tensorflow/lite/kernels/shim/BUILD
@@ -163,6 +163,7 @@
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:macros",
     ] + if_mobile([
         "//tensorflow/core:portable_tensorflow_lib_lite",
     ]) + if_not_mobile([
diff --git a/tensorflow/lite/kernels/shim/test_op/BUILD b/tensorflow/lite/kernels/shim/test_op/BUILD
index e5703e9..af4cd02 100644
--- a/tensorflow/lite/kernels/shim/test_op/BUILD
+++ b/tensorflow/lite/kernels/shim/test_op/BUILD
@@ -48,7 +48,7 @@
         "//tensorflow/core/kernels:ops_testutil",
         "//tensorflow/core/platform:tstring",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -132,7 +132,7 @@
         "//tensorflow/core/framework:tensor_testutil",
         "//tensorflow/core/kernels:ops_testutil",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
diff --git a/tensorflow/lite/kernels/shim/test_op/simple_tf_op_test.cc b/tensorflow/lite/kernels/shim/test_op/simple_tf_op_test.cc
index db537b72..a37483f 100644
--- a/tensorflow/lite/kernels/shim/test_op/simple_tf_op_test.cc
+++ b/tensorflow/lite/kernels/shim/test_op/simple_tf_op_test.cc
@@ -15,6 +15,7 @@
 #include <cstdint>
 
 #include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/fake_input.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -22,7 +23,6 @@
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/platform/tstring.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tflite {
 namespace shim {
diff --git a/tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc b/tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc
index c457bcc..8e661d8 100644
--- a/tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc
+++ b/tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc
@@ -15,13 +15,13 @@
 #include <cstdint>
 
 #include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/fake_input.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tflite {
 namespace shim {
diff --git a/tensorflow/lite/kernels/shim/tf_op_shim.cc b/tensorflow/lite/kernels/shim/tf_op_shim.cc
index 7d12bc8..d71cfa7 100644
--- a/tensorflow/lite/kernels/shim/tf_op_shim.cc
+++ b/tensorflow/lite/kernels/shim/tf_op_shim.cc
@@ -21,10 +21,17 @@
 #include <vector>
 
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
+#include "tensorflow/lite/kernels/shim/op_kernel.h"
+#include "tensorflow/lite/kernels/shim/shape.h"
 #include "tensorflow/lite/kernels/shim/status_macros.h"
 #include "tensorflow/lite/kernels/shim/tensor_view.h"
 #include "tensorflow/lite/kernels/shim/tf_tensor_view.h"
diff --git a/tensorflow/lite/kernels/shim/tf_op_shim.h b/tensorflow/lite/kernels/shim/tf_op_shim.h
index 834a394..8f6442b 100644
--- a/tensorflow/lite/kernels/shim/tf_op_shim.h
+++ b/tensorflow/lite/kernels/shim/tf_op_shim.h
@@ -21,13 +21,15 @@
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_requires.h"
 #include "tensorflow/core/framework/registration/registration.h"
 #include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/kernels/shim/op_kernel.h"
 #include "tensorflow/lite/kernels/shim/shape.h"
+#include "tsl/platform/macros.h"
 
 // This file contains the TF adapter. That is, it takes an `OpKernelShim`
 // class and provides a TF kernel out of it.
@@ -51,9 +53,9 @@
  public:
   explicit TfInvokeContext(::tensorflow::OpKernelContext* context);
   // Read an input tensor
-  ConstTensorViewOr GetInput(const int idx) const;
+  ConstTensorViewOr GetInput(int idx) const;
   // Get a mutable output tensor
-  TensorViewOr GetOutput(const int idx, const Shape& shape) const;
+  TensorViewOr GetOutput(int idx, const Shape& shape) const;
   // Number of input tensors
   int NumInputs() const;
   // Number of output tensors
@@ -70,11 +72,11 @@
   explicit TfShapeInferenceContext(
       ::tensorflow::shape_inference::InferenceContext* context);
   // Read an input tensor shape
-  ShapeOr GetInputShape(const int idx) const;
+  ShapeOr GetInputShape(int idx) const;
   // Set an output tensor shape
-  absl::Status SetOutputShape(const int idx, const Shape& shape);
+  absl::Status SetOutputShape(int idx, const Shape& shape);
   // Read an input tensor during shape inference
-  ConstTensorViewOr GetInputTensor(const int idx) const;
+  ConstTensorViewOr GetInputTensor(int idx) const;
   // Read a given attribute
   absl::StatusOr<AttrValue> GetAttr(const std::string& attr_name) const;
   // Number of input tensors
diff --git a/tensorflow/lite/profiling/profile_summary_formatter_test.cc b/tensorflow/lite/profiling/profile_summary_formatter_test.cc
index 48b0697..a6ba380 100644
--- a/tensorflow/lite/profiling/profile_summary_formatter_test.cc
+++ b/tensorflow/lite/profiling/profile_summary_formatter_test.cc
@@ -14,13 +14,14 @@
 ==============================================================================*/
 #include "tensorflow/lite/profiling/profile_summary_formatter.h"
 
+#include <cstddef>
 #include <fstream>
 #include <ios>
 #include <map>
 #include <memory>
 #include <string>
+#include <tuple>
 
-#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "absl/strings/match.h"
 #include "tensorflow/core/util/stat_summarizer_options.h"
@@ -32,6 +33,127 @@
 
 namespace {
 
+// LINT.IfChange(OpProfilingStatComparator)
+bool AreOpProfilingStatEqual(const OpProfilingStat& op_profiling_stat_1,
+                             const OpProfilingStat& op_profiling_stat_2) {
+  auto proto_to_tuple = [](const OpProfilingStat& op_profiling_stat) {
+    return std::make_tuple(op_profiling_stat.first(), op_profiling_stat.last(),
+                           op_profiling_stat.avg(), op_profiling_stat.stddev(),
+                           op_profiling_stat.variance(),
+                           op_profiling_stat.min(), op_profiling_stat.max(),
+                           op_profiling_stat.sum(), op_profiling_stat.count());
+  };
+  return proto_to_tuple(op_profiling_stat_1) ==
+         proto_to_tuple(op_profiling_stat_2);
+}
+// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:OpProfilingStat)
+
+// LINT.IfChange(OpProfileDataComparator)
+bool AreOpProfileDataEqual(const OpProfileData& op_profile_data_1,
+                           const OpProfileData& op_profile_data_2) {
+  auto proto_to_tuple = [](const OpProfileData& op_profile_data) {
+    return std::make_tuple(op_profile_data.node_type(),
+                           op_profile_data.times_called(),
+                           op_profile_data.name(), op_profile_data.run_order());
+  };
+
+  return (proto_to_tuple(op_profile_data_1) ==
+          proto_to_tuple(op_profile_data_2)) &&
+         AreOpProfilingStatEqual(op_profile_data_1.inference_microseconds(),
+                                 op_profile_data_2.inference_microseconds()) &&
+         (AreOpProfilingStatEqual(op_profile_data_1.mem_kb(),
+                                  op_profile_data_2.mem_kb()));
+}
+// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:OpProfileData)
+
+// LINT.IfChange(SubGraphProfilingDataComparator)
+bool AreSubGraphProfilingDataEqual(
+    const SubGraphProfilingData& subgraph_profiling_data_1,
+    const SubGraphProfilingData& subgraph_profiling_data_2) {
+  auto proto_to_tuple =
+      [](const SubGraphProfilingData& subgraph_profiling_data) {
+        return std::make_tuple(
+            subgraph_profiling_data.subgraph_name(),
+            subgraph_profiling_data.per_op_profiles().size());
+      };
+
+  if (proto_to_tuple(subgraph_profiling_data_1) ==
+      proto_to_tuple(subgraph_profiling_data_2)) {
+    for (size_t i = 0; i < subgraph_profiling_data_1.per_op_profiles().size();
+         ++i) {
+      auto op_profile_data_1 = subgraph_profiling_data_1.per_op_profiles(i);
+      auto op_profile_data_2 = subgraph_profiling_data_2.per_op_profiles(i);
+      if (!AreOpProfileDataEqual(op_profile_data_1, op_profile_data_2)) {
+        return false;
+      }
+    }
+    return true;
+  }
+  return false;
+}
+// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:SubGraphProfilingData)
+
+// LINT.IfChange(DelegateProfilingDataComparator)
+bool AreDelegateProfilingDataEqual(
+    const DelegateProfilingData& delegate_profiling_data_1,
+    const DelegateProfilingData& delegate_profiling_data_2) {
+  auto proto_to_tuple =
+      [](const DelegateProfilingData& delegate_profiling_data) {
+        return std::make_tuple(
+            delegate_profiling_data.delegate_name(),
+            delegate_profiling_data.per_op_profiles().size());
+      };
+
+  if (proto_to_tuple(delegate_profiling_data_1) ==
+      proto_to_tuple(delegate_profiling_data_2)) {
+    for (size_t i = 0; i < delegate_profiling_data_1.per_op_profiles().size();
+         ++i) {
+      auto op_profile_data_1 = delegate_profiling_data_1.per_op_profiles(i);
+      auto op_profile_data_2 = delegate_profiling_data_2.per_op_profiles(i);
+      if (!AreOpProfileDataEqual(op_profile_data_1, op_profile_data_2)) {
+        return false;
+      }
+    }
+    return true;
+  }
+  return false;
+}
+// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:DelegateProfilingData)
+
+// LINT.IfChange(ModelProfilingDataComparator)
+bool AreModelProfilingDataEqual(
+    const ModelProfilingData& model_profiling_data_1,
+    const ModelProfilingData& model_profiling_data_2) {
+  if (model_profiling_data_1.subgraph_profiles().size() !=
+      model_profiling_data_2.subgraph_profiles().size()) {
+    return false;
+  }
+  for (size_t i = 0; i < model_profiling_data_1.subgraph_profiles().size();
+       ++i) {
+    auto subgraph_profile_1 = model_profiling_data_1.subgraph_profiles(i);
+    auto subgraph_profile_2 = model_profiling_data_2.subgraph_profiles(i);
+    if (!AreSubGraphProfilingDataEqual(subgraph_profile_1,
+                                       subgraph_profile_2)) {
+      return false;
+    }
+  }
+  if (model_profiling_data_1.delegate_profiles().size() !=
+      model_profiling_data_2.delegate_profiles().size()) {
+    return false;
+  }
+  for (size_t i = 0; i < model_profiling_data_1.delegate_profiles().size();
+       ++i) {
+    auto delegate_profile_1 = model_profiling_data_1.delegate_profiles(i);
+    auto delegate_profile_2 = model_profiling_data_2.delegate_profiles(i);
+    if (!AreDelegateProfilingDataEqual(delegate_profile_1,
+                                       delegate_profile_2)) {
+      return false;
+    }
+  }
+  return true;
+}
+// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:ModelProfilingData)
+
 TEST(SummaryWriterTest, SummaryOptionStdOut) {
   ProfileSummaryDefaultFormatter writer;
   tensorflow::StatSummarizerOptions options = writer.GetStatSummarizerOptions();
@@ -182,8 +304,9 @@
   op_profile_data_1.set_name(kernel_name_1);
   op_profile_data_1.set_run_order(1);
   op_profile_data_1.set_times_called(2);
-  EXPECT_THAT(model_profiling_data.subgraph_profiles(0).per_op_profiles(0),
-              testing::EqualsProto(op_profile_data_1));
+  EXPECT_TRUE(AreOpProfileDataEqual(
+      model_profiling_data.subgraph_profiles(0).per_op_profiles(0),
+      op_profile_data_1));
 
   OpProfileData op_profile_data_2;
   op_profile_data_2.set_node_type(op_name_2);
@@ -212,8 +335,9 @@
   op_profile_data_2.set_name(kernel_name_2);
   op_profile_data_2.set_run_order(2);
 
-  EXPECT_THAT(model_profiling_data.subgraph_profiles(0).per_op_profiles(1),
-              testing::EqualsProto(op_profile_data_2));
+  EXPECT_TRUE(AreOpProfileDataEqual(
+      model_profiling_data.subgraph_profiles(0).per_op_profiles(1),
+      op_profile_data_2));
 
   ASSERT_EQ(model_profiling_data.subgraph_profiles(1).subgraph_name(),
             "Subgraph 1");
@@ -246,8 +370,9 @@
   op_profile_data_3.set_times_called(1);
   op_profile_data_3.set_name(kernel_name_3);
   op_profile_data_3.set_run_order(3);
-  EXPECT_THAT(model_profiling_data.subgraph_profiles(1).per_op_profiles(0),
-              testing::EqualsProto(op_profile_data_3));
+  EXPECT_TRUE(AreOpProfileDataEqual(
+      model_profiling_data.subgraph_profiles(1).per_op_profiles(0),
+      op_profile_data_3));
 }
 
 TEST(SummaryWriterTest, MultiSubgraphHandleOutputForProto) {
@@ -351,10 +476,10 @@
   file.close();
 
   ASSERT_TRUE(benchmark_profiling_data.model_name().empty());
-  EXPECT_THAT(benchmark_profiling_data.init_profile(),
-              testing::EqualsProto(model_profiling_data_init));
-  EXPECT_THAT(benchmark_profiling_data.runtime_profile(),
-              testing::EqualsProto(model_profiling_data_run));
+  EXPECT_TRUE(AreModelProfilingDataEqual(
+      benchmark_profiling_data.init_profile(), model_profiling_data_init));
+  EXPECT_TRUE(AreModelProfilingDataEqual(
+      benchmark_profiling_data.runtime_profile(), model_profiling_data_run));
 }
 
 TEST(SummaryWriterTest, MultiSubgraphShortSummary) {
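
The comparators added above pack each message's scalar fields into a std::tuple so that tuple equality performs the field-by-field comparison, replacing the testing::EqualsProto matcher used previously. A minimal standalone sketch of the idiom, using a hypothetical plain struct rather than the generated proto classes:

// Sketch of the tuple-comparison idiom used by the new test helpers.
// FakeOpProfileData is a hypothetical stand-in, not a generated proto type.
#include <cstdint>
#include <string>
#include <tuple>

struct FakeOpProfileData {
  std::string node_type;
  int64_t times_called;
  std::string name;
  int64_t run_order;
};

bool AreEqual(const FakeOpProfileData& a, const FakeOpProfileData& b) {
  // Packing the fields into tuples lets operator== compare them
  // element-by-element, mirroring the comparators in this test.
  auto to_tuple = [](const FakeOpProfileData& d) {
    return std::make_tuple(d.node_type, d.times_called, d.name, d.run_order);
  };
  return to_tuple(a) == to_tuple(b);
}

int main() {
  const FakeOpProfileData a{"CONV_2D", 2, "conv", 1};
  const FakeOpProfileData b{"CONV_2D", 2, "conv", 1};
  return AreEqual(a, b) ? 0 : 1;  // identical fields -> tuples compare equal
}
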
diff --git a/tensorflow/lite/profiling/proto/profiling_info.proto b/tensorflow/lite/profiling/proto/profiling_info.proto
index 8116524..5d33571 100644
--- a/tensorflow/lite/profiling/proto/profiling_info.proto
+++ b/tensorflow/lite/profiling/proto/profiling_info.proto
@@ -25,22 +25,27 @@
   optional ModelProfilingData runtime_profile = 3;
 }
 
+// LINT.IfChange(ModelProfilingData)
 message ModelProfilingData {
   repeated SubGraphProfilingData subgraph_profiles = 1;
   repeated DelegateProfilingData delegate_profiles = 2;
 }
+// LINT.ThenChange(//tensorflow/lite/profiling/profile_summary_formatter_test.cc:ModelProfilingDataComparator)
 
+// LINT.IfChange(SubGraphProfilingData)
 message SubGraphProfilingData {
   optional string subgraph_name = 1;
   optional int32 subgraph_index = 2;
   repeated OpProfileData per_op_profiles = 3;
 }
+// LINT.ThenChange(//tensorflow/lite/profiling/profile_summary_formatter_test.cc:SubGraphProfilingDataComparator)
 
 message DelegateProfilingData {
   optional string delegate_name = 1;
   repeated OpProfileData per_op_profiles = 2;
 }
 
+// LINT.IfChange(OpProfilingStat)
 message OpProfilingStat {
   optional int64 first = 1;
   optional int64 last = 2;
@@ -52,7 +57,9 @@
   optional int64 sum = 8;
   optional int64 count = 9;
 }
+// LINT.ThenChange(//tensorflow/lite/profiling/profile_summary_formatter_test.cc:OpProfilingStatComparator)
 
+// LINT.IfChange(OpProfileData)
 message OpProfileData {
   optional string node_type = 1;
   optional OpProfilingStat inference_microseconds = 2;
@@ -61,3 +68,4 @@
   optional string name = 5;
   optional int64 run_order = 6;
 }
+// LINT.ThenChange(//tensorflow/lite/profiling/profile_summary_formatter_test.cc:OpProfileDataComparator)
diff --git a/tensorflow/lite/schema/BUILD b/tensorflow/lite/schema/BUILD
index 7bf0f18..e064789 100644
--- a/tensorflow/lite/schema/BUILD
+++ b/tensorflow/lite/schema/BUILD
@@ -48,10 +48,10 @@
         "upgrade_schema.py",
     ],
     data = [
-        "schema_v0.fbs",
-        "schema_v1.fbs",
-        "schema_v2.fbs",
-        "schema_v3.fbs",
+        "//tensorflow/compiler/mlir/lite/schema:schema_v0.fbs",
+        "//tensorflow/compiler/mlir/lite/schema:schema_v1.fbs",
+        "//tensorflow/compiler/mlir/lite/schema:schema_v2.fbs",
+        "//tensorflow/compiler/mlir/lite/schema:schema_v3.fbs",
         "@flatbuffers//:flatc",
     ],
     srcs_version = "PY3",
@@ -103,13 +103,6 @@
 
 exports_files([
     "conversion_metadata.fbs",
-    "schema.fbs",
-    "schema_v0.fbs",
-    "schema_v1.fbs",
-    "schema_v2.fbs",
-    "schema_v3.fbs",
-    "schema_v3a.fbs",
-    "schema_v3b.fbs",
 ])
 
 flatbuffer_cc_library(
diff --git a/tensorflow/lite/schema/schema_v3b.fbs b/tensorflow/lite/schema/schema_v3b.fbs
deleted file mode 100644
index 9177860..0000000
--- a/tensorflow/lite/schema/schema_v3b.fbs
+++ /dev/null
@@ -1,1242 +0,0 @@
-// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Revision History
-// Version 0: Initial version.
-// Version 1: Add subgraphs to schema.
-// Version 2: Rename operators to conform to NN API.
-// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers.
-// Version 3a: Add new builtin op code field. Has backward compatibility with
-//             version 3.
-// Version 3b: Rename fields in SignatureDef. Has backward compatibility with
-//             version 3 and 3a.
-
-namespace tflite;
-
-// This corresponds to the version.
-file_identifier "TFL3";
-// File extension of any written files.
-file_extension "tflite";
-
-// IMPORTANT: All new members of tables, enums and unions must be added at the
-// end to ensure backwards compatibility.
-
-// The type of data stored in a tensor.
-enum TensorType : byte {
-  FLOAT32 = 0,
-  FLOAT16 = 1,
-  INT32 = 2,
-  UINT8 = 3,
-  INT64 = 4,
-  STRING = 5,
-  BOOL = 6,
-  INT16 = 7,
-  COMPLEX64 = 8,
-  INT8 = 9,
-  FLOAT64 = 10,
-  COMPLEX128 = 11,
-  UINT64 = 12,
-  // Experimental: Resource and variant types are experimental and subject to
-  // change. Do not implement custom kernels using resource & variant types
-  // for now.
-  RESOURCE = 13,
-  VARIANT = 14,
-  UINT32 = 15,
-}
-
-// Custom quantization parameters for experimenting with new quantization
-// techniques.
-table CustomQuantization {
-  custom:[ubyte] (force_align: 16);
-}
-
-// Represents a specific quantization technique's parameters.
-union QuantizationDetails {
-  CustomQuantization,
-}
-
-// Parameters for converting a quantized tensor back to float.
-table QuantizationParameters {
-  // These four parameters are the asymmetric linear quantization parameters.
-  // Given a quantized value q, the corresponding float value f should be:
-  //   f = scale * (q - zero_point)
-  // For other quantization types, the QuantizationDetails below is used.
-  min:[float];  // For importing back into tensorflow.
-  max:[float];  // For importing back into tensorflow.
-  scale:[float];  // For dequantizing the tensor's values.
-  zero_point:[long];
-
-  // If this is not none, the other quantization parameters (i.e. min, max,
-  // scale, zero_point fields above) are ignored and the value of the
-  // QuantizationDetails union should be used.
-  details:QuantizationDetails;
-
-  // Specifies the dimension of the Tensor's shape that the scales and
-  // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
-  // with quantization params:
-  //   scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1
-  // will be quantized across the second dimension of t.
-  //   t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
-  //   t[:, 1, :, :] will have scale[1]=2.0, zero_point[1]=2
-  //   t[:, 2, :, :] will have scale[2]=3.0, zero_point[2]=3
-  quantized_dimension:int;
-}
-
-// Sparse tensors.
-// We use a modification of the TACO format.
-// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf
-//
-// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1),
-// potentially with a k-dimensional block (0 <= k <= n) with dims
-// (dn, ..., dn+k-1), the format needs to specify:
-//   1. In what order to traverse these dimensions. For example, to store a 2-D
-//      matrix in row major order, the traversal order would be (d0, d1),
-//      whereas to store it in column major order, the traversal order would be
-//      (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order
-//      could be (d0, d1, d2, d3).
-//   2. How each block dimension in (dn, ..., dn+k-1) maps to the original
-//      tensor dimension in (d0, ..., dn-1).
-//   3. In the traversal order defined above, the format (dense vs. sparse) and
-//      index metadata for each dimension. For a dense dimension, this is just
-//      the size of that dimension. For a sparse dimension, it's the same as
-//      the compressed index defined in the Compressed Sparse Row (CSR) format.
-//      (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html)
-
-// The storage type for a dimension. Currently we support:
-//   1. DENSE: each coordinate in this dimension is stored implicitly.
-//   2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The
-//      compression technique is the same as the one CSR uses.
-// More types like a sparse dimension with a different compression technique
-// could be added to the list in the future.
-enum DimensionType : byte {
-  DENSE = 0,
-  SPARSE_CSR = 1,
-}
-
-table Int32Vector {
-  values:[int];
-}
-
-table Uint16Vector {
-  values:[ushort] (force_align: 4);
-}
-
-table Uint8Vector {
-  values:[ubyte] (force_align: 4);
-}
-
-// Variable-typed buffer to store the index metadata for a sparse dimension.
-// The widest type is Int32 instead of UInt32 because a tensor's shape is an int32
-// vector. We don't want the per-dimensional index to overflow that range.
-union SparseIndexVector {
-  Int32Vector,
-  Uint16Vector,
-  Uint8Vector
-}
-
-table DimensionMetadata {
-  // Whether a dimension is dense or sparse.
-  format:DimensionType;
-  // Index metadata used for a dimension.
-  //   - If format is DimensionType.DENSE then we use the dense_size field to
-  //     store the size of that dimension. Each index in that dimension is
-  //     stored implicitly.
-  //   - If format is DimensionType.SPARSE_CSR then we use array_segments and
-  //     array_indices to encode that dimension. array_segments represents how
-  //     to segment the indices array, each segment corresponds to one element
-  //     in the previous dimension. array_indices represents the index of the
-  //     non-zero elements within this dimension (as those in the CSR matrix
-  //     format, where the first array is row pointers and the second array is
-  //     column indices).
-  dense_size:int;
-  array_segments:SparseIndexVector;
-  array_indices:SparseIndexVector;
-}
-
-// Parameters to encode a sparse TfLite tensor.
-table SparsityParameters {
-  // The traversal order of the dimensions defined in the `shape` field of the
-  // conceptual dense tensor. For an n-dimensional tensor with dims (d0, d1,
-  // ..., dn-1),
-  //   - if not block sparse, the traversal_order is just a permutation of (d0,
-  //     ..., dn-1). For example, a 2-D matrix stored in row-major order would
-  //     have traversal_order = (d0, d1).
-  //   - if block sparse with a k-dimensional block (0 <= k <= n), the
-  //     traversal_order has n + k elements. The first n elements are still a
-  //     permutation of (d0, ..., dn-1). The last k elements are a permutation
-  //     of (dn, ..., dn+k-1), defining how to traverse a block internally. For
-  //     example, a 2-D matrix with 2-D blocks, both stored in row-major order
-  //     would have traversal_order = (d0, d1, d2, d3).
-  traversal_order:[int];
-  // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n),
-  // stores how a block dimension in (dn, ..., dn+k-1) maps to the original
-  // tensor dimension in (d0, ..., dn).
-  // It's stored in the order of (dn, ..., dn+k-1).
-  // If not block-sparse, this field is NULL.
-  block_map:[int];
-  // In the traversal order defined above, the metadata needed for
-  // each dimension to locate the non-zero values in the original dense tensor.
-  // The size of the dim_metadata array = the size of the traversal_order array
-  // = n + k.
-  dim_metadata:[DimensionMetadata];
-}
-
-table Tensor {
-  // The tensor shape. The meaning of each entry is operator-specific but
-  // builtin ops use: [batch size, height, width, number of channels] (That's
-  // Tensorflow's NHWC).
-  shape:[int];
-  type:TensorType;
-  // An index that refers to the buffers table at the root of the model. Or,
-  // if there is no data buffer associated (i.e. intermediate results), then
-  // this is 0 (which refers to an always existent empty buffer).
-  //
-  // The data_buffer itself is an opaque container, with the assumption that the
-  // target device is little-endian. In addition, all builtin operators assume
-  // the memory is ordered such that if `shape` is [4, 3, 2], then index
-  // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
-  buffer:uint;
-  name:string;  // For debugging and importing back into tensorflow.
-  quantization:QuantizationParameters;  // Optional.
-
-  is_variable:bool = false;
-
-  // Parameters to encode a sparse tensor. See the example in
-  // tensorflow/lite/testdata/sparse_tensor.json.
-  sparsity:SparsityParameters;  // Optional.
-
-  // Encodes `shape` with unknown dimensions. Unknown dimensions are
-  // represented with -1.
-  shape_signature:[int]; // Optional.
-}
-
-// A list of builtin operators. Builtin operators are slightly faster than custom
-// ones, but not by much. Moreover, while custom operators accept an opaque
-// object containing configuration parameters, builtins have a predetermined
-// set of acceptable options.
-// LINT.IfChange
-enum BuiltinOperator : int32 {
-  ADD = 0,
-  AVERAGE_POOL_2D = 1,
-  CONCATENATION = 2,
-  CONV_2D = 3,
-  DEPTHWISE_CONV_2D = 4,
-  DEPTH_TO_SPACE = 5,
-  DEQUANTIZE = 6,
-  EMBEDDING_LOOKUP = 7,
-  FLOOR = 8,
-  FULLY_CONNECTED = 9,
-  HASHTABLE_LOOKUP = 10,
-  L2_NORMALIZATION = 11,
-  L2_POOL_2D = 12,
-  LOCAL_RESPONSE_NORMALIZATION = 13,
-  LOGISTIC = 14,
-  LSH_PROJECTION = 15,
-  LSTM = 16,
-  MAX_POOL_2D = 17,
-  MUL = 18,
-  RELU = 19,
-  // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed
-  // since different model developers use RELU1 in different ways. Never
-  // create another op called RELU1.
-  RELU_N1_TO_1 = 20,
-  RELU6 = 21,
-  RESHAPE = 22,
-  RESIZE_BILINEAR = 23,
-  RNN = 24,
-  SOFTMAX = 25,
-  SPACE_TO_DEPTH = 26,
-  SVDF = 27,
-  TANH = 28,
-  CONCAT_EMBEDDINGS = 29,
-  SKIP_GRAM = 30,
-  CALL = 31,
-  CUSTOM = 32,
-  EMBEDDING_LOOKUP_SPARSE = 33,
-  PAD = 34,
-  UNIDIRECTIONAL_SEQUENCE_RNN = 35,
-  GATHER = 36,
-  BATCH_TO_SPACE_ND = 37,
-  SPACE_TO_BATCH_ND = 38,
-  TRANSPOSE = 39,
-  MEAN = 40,
-  SUB = 41,
-  DIV = 42,
-  SQUEEZE = 43,
-  UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
-  STRIDED_SLICE = 45,
-  BIDIRECTIONAL_SEQUENCE_RNN = 46,
-  EXP = 47,
-  TOPK_V2 = 48,
-  SPLIT = 49,
-  LOG_SOFTMAX = 50,
-  // DELEGATE is a special op type for the operations which are delegated to
-  // other backends.
-  // WARNING: Experimental interface, subject to change
-  DELEGATE = 51,
-  BIDIRECTIONAL_SEQUENCE_LSTM = 52,
-  CAST = 53,
-  PRELU = 54,
-  MAXIMUM = 55,
-  ARG_MAX = 56,
-  MINIMUM = 57,
-  LESS = 58,
-  NEG = 59,
-  PADV2 = 60,
-  GREATER = 61,
-  GREATER_EQUAL = 62,
-  LESS_EQUAL = 63,
-  SELECT = 64,
-  SLICE = 65,
-  SIN = 66,
-  TRANSPOSE_CONV = 67,
-  SPARSE_TO_DENSE = 68,
-  TILE = 69,
-  EXPAND_DIMS = 70,
-  EQUAL = 71,
-  NOT_EQUAL = 72,
-  LOG = 73,
-  SUM = 74,
-  SQRT = 75,
-  RSQRT = 76,
-  SHAPE = 77,
-  POW = 78,
-  ARG_MIN = 79,
-  FAKE_QUANT = 80,
-  REDUCE_PROD = 81,
-  REDUCE_MAX = 82,
-  PACK = 83,
-  LOGICAL_OR = 84,
-  ONE_HOT = 85,
-  LOGICAL_AND = 86,
-  LOGICAL_NOT = 87,
-  UNPACK = 88,
-  REDUCE_MIN = 89,
-  FLOOR_DIV = 90,
-  REDUCE_ANY = 91,
-  SQUARE = 92,
-  ZEROS_LIKE = 93,
-  FILL = 94,
-  FLOOR_MOD = 95,
-  RANGE = 96,
-  RESIZE_NEAREST_NEIGHBOR = 97,
-  LEAKY_RELU = 98,
-  SQUARED_DIFFERENCE = 99,
-  MIRROR_PAD = 100,
-  ABS = 101,
-  SPLIT_V = 102,
-  UNIQUE = 103,
-  CEIL = 104,
-  REVERSE_V2 = 105,
-  ADD_N = 106,
-  GATHER_ND = 107,
-  COS = 108,
-  WHERE = 109,
-  RANK = 110,
-  ELU = 111,
-  REVERSE_SEQUENCE = 112,
-  MATRIX_DIAG = 113,
-  QUANTIZE = 114,
-  MATRIX_SET_DIAG = 115,
-  ROUND = 116,
-  HARD_SWISH = 117,
-  IF = 118,
-  WHILE = 119,
-  NON_MAX_SUPPRESSION_V4 = 120,
-  NON_MAX_SUPPRESSION_V5 = 121,
-  SCATTER_ND = 122,
-  SELECT_V2 = 123,
-  DENSIFY = 124,
-  SEGMENT_SUM = 125,
-  BATCH_MATMUL = 126,
-  PLACEHOLDER_FOR_GREATER_OP_CODES = 127,
-  CUMSUM = 128,
-  CALL_ONCE = 129,
-  BROADCAST_TO = 130,
-  RFFT2D = 131,
-  CONV_3D = 132,
-  IMAG=133,
-  REAL=134,
-  COMPLEX_ABS=135,
-  HASHTABLE = 136,
-  HASHTABLE_FIND = 137,
-  HASHTABLE_IMPORT = 138,
-  HASHTABLE_SIZE = 139,
-  REDUCE_ALL = 140,
-  CONV_3D_TRANSPOSE = 141,
-  VAR_HANDLE = 142,
-  READ_VARIABLE = 143,
-  ASSIGN_VARIABLE = 144,
-}
-// LINT.ThenChange(nnapi_linter/linter.proto)
-
-// Options for the builtin operators.
-union BuiltinOptions {
-  Conv2DOptions,
-  DepthwiseConv2DOptions,
-  ConcatEmbeddingsOptions,
-  LSHProjectionOptions,
-  Pool2DOptions,
-  SVDFOptions,
-  RNNOptions,
-  FullyConnectedOptions,
-  SoftmaxOptions,
-  ConcatenationOptions,
-  AddOptions,
-  L2NormOptions,
-  LocalResponseNormalizationOptions,
-  LSTMOptions,
-  ResizeBilinearOptions,
-  CallOptions,
-  ReshapeOptions,
-  SkipGramOptions,
-  SpaceToDepthOptions,
-  EmbeddingLookupSparseOptions,
-  MulOptions,
-  PadOptions,
-  GatherOptions,
-  BatchToSpaceNDOptions,
-  SpaceToBatchNDOptions,
-  TransposeOptions,
-  ReducerOptions,
-  SubOptions,
-  DivOptions,
-  SqueezeOptions,
-  SequenceRNNOptions,
-  StridedSliceOptions,
-  ExpOptions,
-  TopKV2Options,
-  SplitOptions,
-  LogSoftmaxOptions,
-  CastOptions,
-  DequantizeOptions,
-  MaximumMinimumOptions,
-  ArgMaxOptions,
-  LessOptions,
-  NegOptions,
-  PadV2Options,
-  GreaterOptions,
-  GreaterEqualOptions,
-  LessEqualOptions,
-  SelectOptions,
-  SliceOptions,
-  TransposeConvOptions,
-  SparseToDenseOptions,
-  TileOptions,
-  ExpandDimsOptions,
-  EqualOptions,
-  NotEqualOptions,
-  ShapeOptions,
-  PowOptions,
-  ArgMinOptions,
-  FakeQuantOptions,
-  PackOptions,
-  LogicalOrOptions,
-  OneHotOptions,
-  LogicalAndOptions,
-  LogicalNotOptions,
-  UnpackOptions,
-  FloorDivOptions,
-  SquareOptions,
-  ZerosLikeOptions,
-  FillOptions,
-  BidirectionalSequenceLSTMOptions,
-  BidirectionalSequenceRNNOptions,
-  UnidirectionalSequenceLSTMOptions,
-  FloorModOptions,
-  RangeOptions,
-  ResizeNearestNeighborOptions,
-  LeakyReluOptions,
-  SquaredDifferenceOptions,
-  MirrorPadOptions,
-  AbsOptions,
-  SplitVOptions,
-  UniqueOptions,
-  ReverseV2Options,
-  AddNOptions,
-  GatherNdOptions,
-  CosOptions,
-  WhereOptions,
-  RankOptions,
-  ReverseSequenceOptions,
-  MatrixDiagOptions,
-  QuantizeOptions,
-  MatrixSetDiagOptions,
-  HardSwishOptions,
-  IfOptions,
-  WhileOptions,
-  DepthToSpaceOptions,
-  NonMaxSuppressionV4Options,
-  NonMaxSuppressionV5Options,
-  ScatterNdOptions,
-  SelectV2Options,
-  DensifyOptions,
-  SegmentSumOptions,
-  BatchMatMulOptions,
-  CumsumOptions,
-  CallOnceOptions,
-  BroadcastToOptions,
-  Rfft2dOptions,
-  Conv3DOptions,
-  HashtableOptions,
-  HashtableFindOptions,
-  HashtableImportOptions,
-  HashtableSizeOptions,
-  VarHandleOptions,
-  ReadVariableOptions,
-  AssignVariableOptions,
-}
-
-enum Padding : byte { SAME, VALID }
-
-enum ActivationFunctionType : byte {
-  NONE = 0,
-  RELU = 1,
-  RELU_N1_TO_1 = 2,
-  RELU6 = 3,
-  TANH = 4,
-  SIGN_BIT = 5,
-}
-
-table Conv2DOptions {
-  padding:Padding;
-  stride_w:int;
-  stride_h:int;
-  fused_activation_function:ActivationFunctionType;
-  dilation_w_factor:int = 1;
-  dilation_h_factor:int = 1;
-}
-
-// Options for both Conv3D and Conv3DTranspose.
-table Conv3DOptions {
-  padding:Padding;
-  stride_d:int;
-  stride_w:int;
-  stride_h:int;
-  fused_activation_function:ActivationFunctionType;
-  dilation_d_factor:int = 1;
-  dilation_w_factor:int = 1;
-  dilation_h_factor:int = 1;
-}
-
-table Pool2DOptions {
-  padding:Padding;
-  stride_w:int;
-  stride_h:int;
-  filter_width:int;
-  filter_height:int;
-  fused_activation_function:ActivationFunctionType;
-}
-
-table DepthwiseConv2DOptions {
-  // Parameters for DepthwiseConv version 1 or above.
-  padding:Padding;
-  stride_w:int;
-  stride_h:int;
-  // `depth_multiplier` is redundant. It's used by CPU kernels in
-  // TensorFlow 2.0 or below, but ignored in versions above.
-  // See comments in lite/c/builtin_op_data.h for more details.
-  depth_multiplier:int;
-  fused_activation_function:ActivationFunctionType;
-  // Parameters for DepthwiseConv version 2 or above.
-  dilation_w_factor:int = 1;
-  dilation_h_factor:int = 1;
-}
-
-table ConcatEmbeddingsOptions {
-  num_channels:int;
-  num_columns_per_channel:[int];
-  embedding_dim_per_channel:[int]; // This could be inferred from parameters.
-}
-
-enum LSHProjectionType: byte {
-  UNKNOWN = 0,
-  SPARSE = 1,
-  DENSE = 2,
-}
-
-table LSHProjectionOptions {
-  type: LSHProjectionType;
-}
-
-table SVDFOptions {
-  rank:int;
-  fused_activation_function:ActivationFunctionType;
-  // For weights-only quantization, use asymmetric quantization for non
-  // constant inputs at evaluation time.
-  asymmetric_quantize_inputs:bool;
-}
-
-// An implementation of TensorFlow RNNCell.
-table RNNOptions {
-  fused_activation_function:ActivationFunctionType;
-  asymmetric_quantize_inputs:bool;
-}
-
-// An implementation of TensorFlow dynamic_rnn with RNNCell.
-table SequenceRNNOptions {
-  time_major:bool;
-  fused_activation_function:ActivationFunctionType;
-  asymmetric_quantize_inputs:bool;
-}
-
-// An implementation of TensorFlow bidirectional_dynamic_rnn with RNNCell.
-table BidirectionalSequenceRNNOptions {
-  time_major:bool;
-  fused_activation_function:ActivationFunctionType;
-  merge_outputs: bool;
-  asymmetric_quantize_inputs:bool;
-}
-
-enum FullyConnectedOptionsWeightsFormat: byte {
-  DEFAULT = 0,
-  SHUFFLED4x16INT8 = 1,
-}
-
-// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
-table FullyConnectedOptions {
-  // Parameters for FullyConnected version 1 or above.
-  fused_activation_function:ActivationFunctionType;
-
-  // Parameters for FullyConnected version 2 or above.
-  weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT;
-
-  // Parameters for FullyConnected version 5 or above.
-  // If set to true, then the number of dimensions is preserved. Furthermore,
-  // all but the last dimension of the input and output shapes will be equal.
-  keep_num_dims: bool;
-
-  // Parameters for FullyConnected version 7 or above.
-  // If set to true, then weights-only op will use asymmetric quantization for
-  // inputs.
-  asymmetric_quantize_inputs: bool;
-}
-
-table SoftmaxOptions {
-  beta: float;
-}
-
-// An implementation of TensorFlow concat.
-table ConcatenationOptions {
-  axis:int;
-  fused_activation_function:ActivationFunctionType;
-}
-
-table AddOptions {
-  fused_activation_function:ActivationFunctionType;
-  // Parameters supported by version 3.
-  pot_scale_int16:bool = true;
-}
-
-table MulOptions {
-  fused_activation_function:ActivationFunctionType;
-}
-
-table L2NormOptions {
-  // This field is currently ignored in the L2 Norm Op.
-  fused_activation_function:ActivationFunctionType;
-}
-
-table LocalResponseNormalizationOptions {
-  radius:int;
-  bias:float;
-  alpha:float;
-  beta:float;
-}
-
-enum LSTMKernelType : byte {
-  // Full LSTM kernel which supports peephole and projection.
-  FULL = 0,
-  // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell.
-  BASIC = 1,
-}
-
-// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
-table LSTMOptions {
-  // Parameters for LSTM version 1 or above.
-  fused_activation_function:ActivationFunctionType;
-  cell_clip: float; // Optional, 0.0 means no clipping
-  proj_clip: float; // Optional, 0.0 means no clipping
-
-  // Parameters for LSTM version 2 or above.
-  // Basic kernel is only supported in version 2 or above.
-  kernel_type: LSTMKernelType = FULL;
-
-  // Parameters for LSTM version 4 or above.
-  asymmetric_quantize_inputs: bool;
-}
-
-// An implementation of TensorFlow dynamic_rnn with LSTMCell.
-table UnidirectionalSequenceLSTMOptions {
-  fused_activation_function:ActivationFunctionType;
-  cell_clip: float; // Optional, 0.0 means no clipping
-  proj_clip: float; // Optional, 0.0 means no clipping
-
-  // If true then first dimension is sequence, otherwise batch.
-  time_major:bool;
-
-  // Parameter for Unidirectional Sequence LSTM version 4.
-  asymmetric_quantize_inputs:bool;
-}
-
-table BidirectionalSequenceLSTMOptions {
-  // Parameters supported by version 1:
-  fused_activation_function:ActivationFunctionType;
-  cell_clip: float; // Optional, 0.0 means no clipping
-  proj_clip: float; // Optional, 0.0 means no clipping
-
-  // If true, store the outputs of both directions into the first output.
-  merge_outputs: bool;
-
-  // Parameters supported by version 2:
-  // If true then first dimension is sequence, otherwise batch.
-  // Version 1 implementations assumed time_major to be true, so this default
-  // value should never change.
-  time_major: bool = true;
-
-  // Parameters for version 3 or above.
-  asymmetric_quantize_inputs:bool;
-}
-
-table ResizeBilinearOptions {
-  new_height: int (deprecated);
-  new_width: int (deprecated);
-  align_corners: bool;
-  half_pixel_centers: bool;
-}
-
-table ResizeNearestNeighborOptions {
-  align_corners: bool;
-  half_pixel_centers: bool;
-}
-
-// A call operation options
-table CallOptions {
-  // The subgraph index that needs to be called.
-  subgraph:uint;
-}
-
-table PadOptions {
-}
-
-table PadV2Options {
-}
-
-table ReshapeOptions {
-  new_shape:[int];
-}
-
-table SpaceToBatchNDOptions {
-}
-
-table BatchToSpaceNDOptions {
-}
-
-table SkipGramOptions {
-  ngram_size: int;
-  max_skip_size: int;
-  include_all_ngrams: bool;
-}
-
-table SpaceToDepthOptions {
-  block_size: int;
-}
-
-table DepthToSpaceOptions {
-  block_size: int;
-}
-
-table SubOptions {
-  fused_activation_function:ActivationFunctionType;
-  // Parameters supported by version 5
-  pot_scale_int16:bool = true;
-}
-
-table DivOptions {
-  fused_activation_function:ActivationFunctionType;
-}
-
-table TopKV2Options {
-}
-
-enum CombinerType : byte {
-  SUM = 0,
-  MEAN = 1,
-  SQRTN = 2,
-}
-
-table EmbeddingLookupSparseOptions {
-  combiner:CombinerType;
-}
-
-table GatherOptions {
-  axis: int;
-  // Parameters for Gather version 5 or above.
-  batch_dims: int = 0;
-}
-
-table TransposeOptions {
-}
-
-table ExpOptions {
-}
-
-table CosOptions {
-}
-
-table ReducerOptions {
-  keep_dims: bool;
-}
-
-table SqueezeOptions {
-  squeeze_dims:[int];
-}
-
-table SplitOptions {
-  num_splits: int;
-}
-
-table SplitVOptions {
-  num_splits: int;
-}
-
-table StridedSliceOptions {
-  begin_mask: int;
-  end_mask: int;
-  ellipsis_mask: int;
-  new_axis_mask: int;
-  shrink_axis_mask: int;
-}
-
-table LogSoftmaxOptions {
-}
-
-table CastOptions {
-  in_data_type: TensorType;
-  out_data_type: TensorType;
-}
-
-table DequantizeOptions {
-}
-
-table MaximumMinimumOptions {
-}
-
-table TileOptions {
-}
-
-table ArgMaxOptions {
-  output_type : TensorType;
-}
-
-table ArgMinOptions {
-  output_type : TensorType;
-}
-
-table GreaterOptions {
-}
-
-table GreaterEqualOptions {
-}
-
-table LessOptions {
-}
-
-table LessEqualOptions {
-}
-
-table NegOptions {
-}
-
-table SelectOptions {
-}
-
-table SliceOptions {
-}
-
-table TransposeConvOptions {
-  padding:Padding;
-  stride_w:int;
-  stride_h:int;
-}
-
-table ExpandDimsOptions {
-}
-
-table SparseToDenseOptions {
-  validate_indices:bool;
-}
-
-table EqualOptions {
-}
-
-table NotEqualOptions {
-}
-
-table ShapeOptions {
-  // Optional output type of the operation (int32 or int64). Defaults to int32.
-  out_type : TensorType;
-}
-
-table RankOptions {
-}
-
-table PowOptions {
-}
-
-table FakeQuantOptions {
-  // Parameters supported by version 1:
-  min:float;
-  max:float;
-  num_bits:int;
-
-  // Parameters supported by version 2:
-  narrow_range:bool;
-}
-
-table PackOptions {
-  values_count:int;
-  axis:int;
-}
-
-table LogicalOrOptions {
-}
-
-table OneHotOptions {
-  axis:int;
-}
-
-table AbsOptions {
-}
-
-
-table HardSwishOptions {
-}
-
-table LogicalAndOptions {
-}
-
-table LogicalNotOptions {
-}
-
-table UnpackOptions {
-  num:int;
-  axis:int;
-}
-
-table FloorDivOptions {
-}
-
-table SquareOptions {
-}
-
-table ZerosLikeOptions {
-}
-
-table FillOptions {
-}
-
-table FloorModOptions {
-}
-
-table RangeOptions {
-}
-
-table LeakyReluOptions {
-  alpha:float;
-}
-
-table SquaredDifferenceOptions {
-}
-
-enum MirrorPadMode : byte {
-  // Doesn't include borders.
-  REFLECT = 0,
-  // Includes borders.
-  SYMMETRIC = 1,
-}
-
-table MirrorPadOptions {
-  mode:MirrorPadMode;
-}
-
-table UniqueOptions {
-  idx_out_type:TensorType = INT32;
-}
-
-table ReverseV2Options {
-}
-
-table AddNOptions {
-}
-
-table GatherNdOptions {
-}
-
-table WhereOptions {
-}
-
-table ReverseSequenceOptions {
-  seq_dim:int;
-  batch_dim:int = 0;
-}
-
-table MatrixDiagOptions {
-}
-
-table QuantizeOptions {
-}
-
-table MatrixSetDiagOptions {
-}
-
-table IfOptions {
-  then_subgraph_index:int;
-  else_subgraph_index:int;
-}
-
-table CallOnceOptions {
-  init_subgraph_index:int;
-}
-
-table WhileOptions {
-  cond_subgraph_index:int;
-  body_subgraph_index:int;
-}
-
-table NonMaxSuppressionV4Options {
-}
-
-table NonMaxSuppressionV5Options {
-}
-
-table ScatterNdOptions {
-}
-
-table SelectV2Options {
-}
-
-table DensifyOptions {
-}
-
-table SegmentSumOptions {
-}
-
-table BatchMatMulOptions {
-  adj_x:bool;
-  adj_y:bool;
-  // Parameters for BatchMatMul version 4 or above.
-  // If set to true, then weights-only op will use asymmetric quantization for
-  // inputs.
-  asymmetric_quantize_inputs: bool;
-}
-
-table CumsumOptions {
-  exclusive:bool;
-  reverse:bool;
-}
-
-table BroadcastToOptions {
-}
-
-table Rfft2dOptions {
-}
-
-table HashtableOptions {
-  // The identity of hash tables. This identity will be used across different
-  // subgraphs in the same interpreter instance.
-  table_id:int;
-  key_dtype:TensorType;
-  value_dtype:TensorType;
-}
-
-table HashtableFindOptions {
-}
-
-table HashtableImportOptions {
-}
-
-table HashtableSizeOptions {
-}
-
-table VarHandleOptions {
-  container:string;
-  shared_name:string;
-}
-
-table ReadVariableOptions {
-}
-
-table AssignVariableOptions {
-}
-
-// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
-// builtin, or a string if the operator is custom.
-table OperatorCode {
-  // This field is for backward compatibility. This field will be used when
-  // the value of the extended builtin_code field is less than
-  // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
-  deprecated_builtin_code:byte;
-  custom_code:string;
-
-  // The version of the operator. The version needs to be bumped whenever new
-  // parameters are introduced into an op.
-  version:int = 1;
-
-  // This field is introduced for resolving op builtin code shortage problem
-  // (the original BuiltinOperator enum field was represented as a byte).
-  // This field will be used when the value of the extended builtin_code field
-  // is greater than BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
-  builtin_code:BuiltinOperator;
-}
-
-enum CustomOptionsFormat : byte {
-  FLEXBUFFERS = 0,
-}
-
-// An operator takes tensors as inputs and outputs. The type of operation being
-// performed is determined by an index into the list of valid OperatorCodes,
-// while the specifics of each operation are configured using builtin_options
-// or custom_options.
-table Operator {
-  // Index into the operator_codes array. Using an integer here avoids
-  // complicated map lookups.
-  opcode_index:uint;
-
-  // Optional inputs are indicated by -1.
-  inputs:[int];
-  outputs:[int];
-
-  builtin_options:BuiltinOptions;
-  custom_options:[ubyte];
-  custom_options_format:CustomOptionsFormat;
-
-  // A list of booleans indicating the input tensors which are being mutated by
-  // this operator (e.g. used by RNN and LSTM).
-  // For example, if the "inputs" array refers to 5 tensors and the second and
-  // fifth are mutable variables, then this list will contain
-  // [false, true, false, false, true].
-  //
-  // If the list is empty, no variable is mutated in this operator.
-  // The list either has the same length as `inputs`, or is empty.
-  mutating_variable_inputs:[bool];
-
-  // A list of indices to the subgraph's "tensors" that are internal to an Op.
-  // Internal tensors are those that do not flow in or out of the operation,
-  // but instead are part of internal computation. As such, the operation's
-  // implementation may manage its memory more efficiently. They are needed
-  // however (i.e. not just an implementation detail) since they are part of the
-  // computation, which may require relevant metadata such as quantization
-  // parameters.
-  intermediates:[int];
-}
-
-// The root type, defining a subgraph, which typically represents an entire
-// model.
-table SubGraph {
-  // A list of all tensors used in this subgraph.
-  tensors:[Tensor];
-
-  // Indices of the tensors that are inputs into this subgraph. Note this is
-  // the list of non-static tensors that feed into the subgraph for inference.
-  inputs:[int];
-
-  // Indices of the tensors that are outputs out of this subgraph. Note this is
-  // the list of output tensors that are considered the product of the
-  // subgraph's inference.
-  outputs:[int];
-
-  // All operators, in execution order.
-  operators:[Operator];
-
-  // Name of this subgraph (used for debugging).
-  name:string;
-}
-
-// Table of raw data buffers (used for constant tensors). Referenced by tensors
-// by index. The generous alignment accommodates mmap-friendly data structures.
-table Buffer {
-  data:[ubyte] (force_align: 16);
-}
-
-table Metadata {
-  // A human readable string to uniquely identify a Metadata.
-  name:string;
-  // An index to the buffers table.
-  buffer:uint;
-}
-
-// Map from an alias name of tensor to tensor index in the graph.
-// This is used in Signature def.
-table TensorMap {
-  // Represents the alias to use for this tensor.
-  name:string;
-
-  // The actual tensor index in the primary graph, that 'name' corresponds to.
-  tensor_index:uint;
-}
-
-// This corresponds to SignatureDef in Tensorflow SavedModel.
-// The SignatureDef will be part of the SavedModel provided for conversion.
-table SignatureDef {
-  // Named inputs for this signature.
-  inputs:[TensorMap];
-
-  // Named outputs for this signature.
-  outputs:[TensorMap];
-
-  // Key value which was in the Tensorflow SavedModel SignatureDef map.
-  signature_key:string;
-
-  // Model tag, deprecated.
-  deprecated_tag:string (deprecated);
-
-  // Index of the subgraph that corresponds to the exported method.
-  subgraph_index:uint;
-}
-
-table Model {
-  // Version of the schema.
-  version:uint;
-
-  // A list of all operator codes used in this model. This is
-  // kept in order because operators carry an index into this
-  // vector.
-  operator_codes:[OperatorCode];
-
-  // All the subgraphs of the model. The 0th is assumed to be the main
-  // model.
-  subgraphs:[SubGraph];
-
-  // A description of the model.
-  description:string;
-
-  // Buffers of the model.
-  // Note the 0th entry of this array must be an empty buffer (sentinel).
-  // This is a convention so that tensors without a buffer can provide 0 as
-  // their buffer.
-  buffers:[Buffer];
-
-  // Metadata about the model. Indirects into the existing buffers list.
-  // Deprecated, prefer to use metadata field.
-  metadata_buffer:[int];
-
-  // Metadata about the model.
-  metadata:[Metadata];
-
-  // Optional SignatureDefs for the model.
-  signature_defs:[SignatureDef];
-}
-
-root_type Model;
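
The comments in the schema file removed above document the dequantization rule f = scale * (q - zero_point), with quantized_dimension selecting which axis the per-channel scale and zero_point arrays index. A minimal sketch of that rule follows; DequantizePerChannel is a hypothetical helper written for illustration (assuming a channel-major layout), not a TFLite API.

// Illustrative sketch of f = scale * (q - zero_point) with per-channel
// parameters, as documented in the schema comments above.
#include <cstdint>
#include <vector>

std::vector<float> DequantizePerChannel(const std::vector<int8_t>& q,
                                        const std::vector<float>& scale,
                                        const std::vector<int64_t>& zero_point,
                                        int num_channels) {
  // Assumes q is laid out channel-major: all values of channel c are
  // contiguous and q.size() is a multiple of num_channels.
  std::vector<float> f(q.size());
  const size_t per_channel = q.size() / num_channels;
  for (int c = 0; c < num_channels; ++c) {
    for (size_t i = 0; i < per_channel; ++i) {
      const size_t idx = c * per_channel + i;
      f[idx] = scale[c] * static_cast<float>(q[idx] - zero_point[c]);
    }
  }
  return f;
}

int main() {
  // Two channels, two values each; scale/zero_point are indexed by channel.
  const std::vector<float> f = DequantizePerChannel(
      {10, 20, 30, 40}, {0.5f, 0.25f}, {0, 4}, /*num_channels=*/2);
  // f == {5.0, 10.0, 6.5, 9.0}
  return f.size() == 4 ? 0 : 1;
}
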
diff --git a/tensorflow/lite/stateful_error_reporter.h b/tensorflow/lite/stateful_error_reporter.h
index cf66934..10dc096 100644
--- a/tensorflow/lite/stateful_error_reporter.h
+++ b/tensorflow/lite/stateful_error_reporter.h
@@ -15,9 +15,10 @@
 #ifndef TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_
 #define TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_
 
+// LINT.IfChange
 #include <string>
 
-#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h"
 
 namespace tflite {
 
@@ -30,5 +31,6 @@
 };
 
 }  // namespace tflite
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/stateful_error_reporter.h)
 
 #endif  // TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_
diff --git a/tensorflow/lite/stderr_reporter.h b/tensorflow/lite/stderr_reporter.h
index 2eacb9e..fdac5d4 100644
--- a/tensorflow/lite/stderr_reporter.h
+++ b/tensorflow/lite/stderr_reporter.h
@@ -18,7 +18,6 @@
 #include <cstdarg>
 
 #include "tensorflow/lite/core/api/error_reporter.h"
-#include "tensorflow/lite/core/c/common.h"
 
 namespace tflite {
 
diff --git a/tensorflow/lite/testing/op_tests/is_finite.py b/tensorflow/lite/testing/op_tests/is_finite.py
index 2425fa9..493ea05 100644
--- a/tensorflow/lite/testing/op_tests/is_finite.py
+++ b/tensorflow/lite/testing/op_tests/is_finite.py
@@ -52,7 +52,7 @@
 
     input_values[random_index(input_values.shape)] = np.inf
     input_values[random_index(input_values.shape)] = -np.inf
-    input_values[random_index(input_values.shape)] = np.NAN
+    input_values[random_index(input_values.shape)] = np.nan
     input_values[random_index(input_values.shape)] = tf.float32.max
     input_values[random_index(input_values.shape)] = tf.float32.min
 
diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD
index 4f370bb..8c9ccf7 100644
--- a/tensorflow/lite/toco/BUILD
+++ b/tensorflow/lite/toco/BUILD
@@ -302,7 +302,6 @@
         "//tensorflow/lite/kernels/internal:strided_slice_logic",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
     ],
diff --git a/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
index 89da8a6..d6932b7 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
@@ -17,11 +17,14 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc
index cc519e4..6d2b5ca 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc
@@ -12,8 +12,9 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "absl/status/status.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc
index 66d7f64..84e84aa 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc
@@ -17,9 +17,9 @@
 #include <unordered_map>
 #include <vector>
 
-#include "absl/strings/str_cat.h"
+#include "absl/status/status.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
index 8f56bfa..b7763e1 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -17,10 +17,11 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc b/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc
index 49c380b..60dcf00 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc
@@ -17,11 +17,12 @@
 #include <unordered_map>
 #include <vector>
 
-#include "absl/strings/str_cat.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
index c3bfbf5..c98d64d 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
@@ -17,11 +17,12 @@
 #include <unordered_map>
 #include <vector>
 
-#include "absl/strings/str_cat.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
index 547e0d8..c60ddff 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
@@ -12,7 +12,9 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "absl/status/status.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
index f493d4e..c945615 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
@@ -17,11 +17,12 @@
 #include <unordered_map>
 #include <vector>
 
-#include "absl/strings/str_cat.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
index 4781f4e..71a7d92 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
@@ -14,10 +14,12 @@
 ==============================================================================*/
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/toco_types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
index 183cb53..8a33ad5 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
@@ -15,7 +15,9 @@
 #include <string>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc
index f69afe4..380cdf2 100644
--- a/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc
+++ b/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc
@@ -17,11 +17,12 @@
 #include <unordered_map>
 #include <vector>
 
-#include "absl/strings/str_cat.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/dequantize.cc b/tensorflow/lite/toco/graph_transformations/dequantize.cc
index 1aa5069..5dd4d2e 100644
--- a/tensorflow/lite/toco/graph_transformations/dequantize.cc
+++ b/tensorflow/lite/toco/graph_transformations/dequantize.cc
@@ -17,11 +17,13 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc
index 0a7af2f..cdd748a 100644
--- a/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc
+++ b/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc
@@ -17,11 +17,12 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc
index a076814..d3cfae0 100644
--- a/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc
+++ b/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc
@@ -12,10 +12,12 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc
index 22d6d94..f8d639c 100644
--- a/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc
+++ b/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc
@@ -17,7 +17,9 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
index b496f51..ed3a89a 100644
--- a/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
+++ b/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
@@ -17,10 +17,13 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/runtime/types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc
index 7d34270..64b91cc 100644
--- a/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc
+++ b/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc
@@ -17,11 +17,13 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/runtime/types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
index 926d41c..3afa9c4 100644
--- a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
+++ b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
@@ -18,11 +18,13 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/runtime/types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
index 130cb0b..fa0baf9 100644
--- a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
+++ b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
@@ -17,7 +17,9 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/runtime/types.h"
diff --git a/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
index df1d6da..ba57090 100644
--- a/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
+++ b/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
@@ -17,10 +17,11 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc
index bee6665..125e559 100644
--- a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc
+++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc
@@ -21,9 +21,12 @@
 #include <utility>
 #include <vector>
 
-#include "tensorflow/lite/toco/toco_port.h"
-#include "tensorflow/lite/toco/tooling_util.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/lite/toco/format_port.h"
+#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/model_flags.pb.h"
+#include "tensorflow/lite/toco/tooling_util.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h
index 9f93ee1..c7e2c9de 100644
--- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h
@@ -21,6 +21,8 @@
 #include <unordered_set>
 #include <vector>
 
+#include "absl/log/check.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/toco_port.h"
 
diff --git a/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc b/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc
index f5a8d16..2da6fbe 100644
--- a/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc
+++ b/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc
@@ -20,6 +20,8 @@
 #include <string>
 #include <vector>
 
+#include "absl/log/check.h"
+#include "absl/status/status.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
diff --git a/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc
index 53c12b4..6f142a4 100644
--- a/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -17,7 +17,9 @@
 #include <string>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc
index 026f51a..985e588 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc
@@ -15,10 +15,12 @@
 #include <string>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc b/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc
index 4a6dea0..437147f 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc
@@ -18,7 +18,8 @@
 #include <unordered_map>
 #include <vector>
 
-#include "tensorflow/core/platform/logging.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/graph_transformations/identify_util.h"
 #include "tensorflow/lite/toco/model.h"
diff --git a/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc
index b66f0b0..e8a5d20 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc
@@ -18,10 +18,12 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc
index 91bda7e..a980995a 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc
@@ -17,10 +17,12 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc
index 18e74ae..df0aa9f 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_lstm.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc
@@ -16,8 +16,12 @@
 #include <string>
 #include <vector>
 
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/runtime/types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
 
 namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
index 4b2c497..24299d5 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
@@ -18,8 +18,9 @@
 #include <utility>
 #include <vector>
 
-#include "absl/memory/memory.h"
-#include "absl/strings/string_view.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
 #include "tensorflow/lite/toco/model.h"
diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
index 3de0a71..aea6d93 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
@@ -18,8 +18,8 @@
 #include <utility>
 #include <vector>
 
-#include "absl/memory/memory.h"
-#include "absl/strings/string_view.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
 #include "tensorflow/lite/toco/model.h"
diff --git a/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc b/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc
index 580b680..1d1d67b 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc
@@ -17,8 +17,9 @@
 #include <string>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/lite/toco/graph_transformations/identify_prelu.cc
index 31edcb4..0f28cb1 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_prelu.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_prelu.cc
@@ -17,10 +17,12 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/runtime/types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 // This transformation rule tries to identify the PRelu structure generated by
 // Keras, and convert it to a single op.
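For reference, the PRelu structure this pass matches can be summarized numerically; a minimal sketch, assuming the common Keras lowering of PRelu into relu/neg/mul/add ops (illustrative Python only, not part of the patch):

import numpy as np

# Assumed decomposition matched by the pass: prelu(x) == relu(x) - alpha * relu(-x).
# For x > 0 this yields x; for x <= 0 it yields alpha * x, i.e. a single PRelu op.
def prelu_decomposed(x, alpha):
  return np.maximum(x, 0.0) - alpha * np.maximum(-x, 0.0)
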
diff --git a/tensorflow/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/lite/toco/graph_transformations/identify_relu1.cc
index dad425c..6f2e224 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_relu1.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_relu1.cc
@@ -17,7 +17,8 @@
 #include <unordered_map>
 #include <vector>
 
-#include "tensorflow/core/platform/logging.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/graph_transformations/identify_util.h"
 #include "tensorflow/lite/toco/model.h"
diff --git a/tensorflow/lite/toco/graph_transformations/identify_util.cc b/tensorflow/lite/toco/graph_transformations/identify_util.cc
index e860511..6ed8e33 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_util.cc
+++ b/tensorflow/lite/toco/graph_transformations/identify_util.cc
@@ -18,6 +18,7 @@
 #include <string>
 
 #include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/runtime/types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
 
 namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/identify_util.h b/tensorflow/lite/toco/graph_transformations/identify_util.h
index 1a79231..6c59b0b 100644
--- a/tensorflow/lite/toco/graph_transformations/identify_util.h
+++ b/tensorflow/lite/toco/graph_transformations/identify_util.h
@@ -17,6 +17,7 @@
 #include <string>
 
 #include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/runtime/types.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/lstm_utils.cc b/tensorflow/lite/toco/graph_transformations/lstm_utils.cc
index 7a979b7..676aa75 100644
--- a/tensorflow/lite/toco/graph_transformations/lstm_utils.cc
+++ b/tensorflow/lite/toco/graph_transformations/lstm_utils.cc
@@ -16,6 +16,9 @@
 
 #include <string>
 
+#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/tooling_util.h"
+
 namespace toco {
 
 void CreateOptionalArray(Model* model, std::string* input_array_buffer,
diff --git a/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
index 290dc7f..0726b32 100644
--- a/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
+++ b/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -17,12 +17,13 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/model_flags.pb.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
index b07815e..a292b97 100644
--- a/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
+++ b/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
@@ -18,12 +18,14 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h"
 #include "tensorflow/lite/toco/model.h"
-#include "tensorflow/lite/toco/runtime/types.h"
+#include "tensorflow/lite/toco/toco_types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
index 8538413..588a034 100644
--- a/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
+++ b/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
@@ -15,6 +15,9 @@
 #include <algorithm>
 #include <string>
 
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc b/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
index 79d8229..fffdde0 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
@@ -17,12 +17,13 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h"
 #include "tensorflow/lite/toco/model.h"
-#include "tensorflow/lite/toco/runtime/types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc
index af801c3..ef0a520 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -17,9 +17,11 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc
index 0f9197c..54b76fb 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc
@@ -17,11 +17,13 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
index 10968a9..62d8715 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -17,11 +17,13 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index ab6f407..5136bc0 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -21,11 +21,14 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/toco_types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
 
 namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc
index 87f95a9..9e5e580 100644
--- a/tensorflow/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/lite/toco/graph_transformations/quantize.cc
@@ -21,11 +21,15 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/model_flags.pb.h"
+#include "tensorflow/lite/toco/toco_types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
 
 namespace toco {
diff --git a/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
index 5e867ea..bf9334f 100644
--- a/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
+++ b/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
@@ -18,10 +18,11 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
-#include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc
index 438c7a6..fc15e8e 100644
--- a/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc
+++ b/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc
@@ -17,11 +17,12 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/model_flags.pb.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc b/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc
index fdc4d27..79e6b68 100644
--- a/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc
+++ b/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc
@@ -15,7 +15,8 @@
 #include <string>
 #include <vector>
 
-#include "tensorflow/core/platform/logging.h"
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
diff --git a/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc
index 88402f0..45de603 100644
--- a/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc
+++ b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc
@@ -16,10 +16,12 @@
 #include <string>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace toco {
 
diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.cc b/tensorflow/lite/tools/optimize/modify_model_interface.cc
index 2af2324..860fa05 100644
--- a/tensorflow/lite/tools/optimize/modify_model_interface.cc
+++ b/tensorflow/lite/tools/optimize/modify_model_interface.cc
@@ -248,7 +248,7 @@
 
 TfLiteStatus RemoveInputTensor(ModelT* model,
                                const std::vector<TensorOpTensor>& inputs,
-                               int32 original_number_tensors) {
+                               int32_t original_number_tensors) {
   // Consistency check to make sure that erase start from the end.
   int last_op_index = std::numeric_limits<int32_t>::max();
   int last_tensor_index = std::numeric_limits<int32_t>::max();
@@ -274,7 +274,7 @@
 
 TfLiteStatus RemoveOutputTensor(ModelT* model,
                                 const std::vector<TensorOpTensor>& outputs,
-                                int32 original_number_tensors) {
+                                int32_t original_number_tensors) {
   // Consistency check to make sure that erase start from the end.
   int last_op_index = std::numeric_limits<int32_t>::max();
   int last_tensor_index = std::numeric_limits<int32_t>::max();
@@ -298,7 +298,6 @@
   return kTfLiteOk;
 }
 
-
 int GetOriginalNumberOfTensors(const TensorType& input_type,
                                const TensorType& output_type, ModelT* model,
                                ErrorReporter* error_reporter) {
diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files
index ab15889..a1c155f 100644
--- a/tensorflow/opensource_only.files
+++ b/tensorflow/opensource_only.files
@@ -321,6 +321,7 @@
 tf_staging/third_party/repo.bzl:
 tf_staging/third_party/six.BUILD:
 tf_staging/third_party/snappy.BUILD:
+tf_staging/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD:
 tf_staging/third_party/sqlite.BUILD:
 tf_staging/third_party/stablehlo/BUILD:
 tf_staging/third_party/systemlibs/BUILD.tpl:
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 035daec..8aa243a 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -932,7 +932,6 @@
         "//tensorflow/core/util/tensor_bundle",  # checkpoint_reader
         "//tensorflow/dtensor/cc:dtensor_device_cc",  # DTensor
         "//tensorflow/dtensor/cc:tensor_layout",  # DTensor
-        "//tensorflow/lite/kernels/shim:tf_op_shim",  # tf_text
         "//tensorflow/lite/toco/python:toco_python_api",  # toco
         "//tensorflow/python/client:tf_session_helper",  # tf_session
         "//tensorflow/python/eager:pywrap_tfe_lib",  # pywrap_tfe_lib
diff --git a/tensorflow/python/autograph/g3doc/pyct_tutorial.ipynb b/tensorflow/python/autograph/g3doc/pyct_tutorial.ipynb
index 8b7b3e9..44c2ea6 100644
--- a/tensorflow/python/autograph/g3doc/pyct_tutorial.ipynb
+++ b/tensorflow/python/autograph/g3doc/pyct_tutorial.ipynb
@@ -164,7 +164,7 @@
       "source": [
         "### Helpful static analysis passes\n",
         "\n",
-        "The `static_analysis` module contains various helper passes for dataflow analyis.\n",
+        "The `static_analysis` module contains various helper passes for dataflow analysis.\n",
         "\n",
         "All these passes annotate the AST. These annotations can be extracted using [anno.getanno](https://github.com/tensorflow/tensorflow/blob/40802bcdb5c8a4379da2145441f51051402bd29b/tensorflow/python/autograph/pyct/anno.py#L111). Most of them rely on the `qual_names` annotations, which just simplify the way more complex identifiers like `a.b.c` are accessed.\n",
         "\n",
@@ -253,7 +253,7 @@
         "\n",
         "\n",
         "def f(a):\n",
-        "  if a \u003e 0:\n",
+        "  if a > 0:\n",
         "    return a\n",
         "  b = -a\n",
         "\n",
diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py
index 08f7342..ffacbe4 100644
--- a/tensorflow/python/autograph/pyct/origin_info.py
+++ b/tensorflow/python/autograph/pyct/origin_info.py
@@ -172,7 +172,7 @@
         hasattr(root_node.decorator_list[0], 'lineno')):
       # Typical case: functions. The line number of the first decorator
       # is more accurate than the line number of the function itself in
-      # 3.8+. In earier versions they coincide.
+      # 3.8+. In earlier versions they coincide.
       self._lineno_offset = context_lineno - root_node.decorator_list[0].lineno
     else:
       # Fall back to the line number of the root node.
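The 3.8+ behavior described in the comment above can be observed with a short standalone sketch (independent of this module; `some_decorator` is a placeholder name):

import ast
import textwrap

src = textwrap.dedent('''
    @some_decorator
    def f():
      pass
''')
fn = ast.parse(src).body[0]
# On Python 3.8+ this prints "3 2": the FunctionDef reports the `def` line while
# the decorator keeps its own line. On earlier versions the two coincide.
print(fn.lineno, fn.decorator_list[0].lineno)
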
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index 8af0f7a..5d6a872 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -570,7 +570,7 @@
       node.decorator_list = self.visit_block(node.decorator_list)
       if node.returns:
         node.returns = self._process_annotation(node.returns)
-      # Argument annotartions (includeing defaults) affect the defining context.
+      # Argument annotations (including defaults) affect the defining context.

       node = self._visit_arg_annotations(node)
 
       function_name = qual_names.QN(node.name)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
index cdeddaa..ad373a5 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
@@ -249,7 +249,7 @@
 
     inner_fn_body = fn_body[1].body[1].body
     def_of_a_in_foo = inner_fn_body[0].value
-    # Even though `a` is visible in the inner functio above, the late binding
+    # Even though `a` is visible in the inner function above, the late binding
     # makes it impossible to assume that the same value will be visible at
     # call time.
     self.assertHasDefs(def_of_a_in_foo, 0)
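The late binding referred to in the comment above is ordinary Python closure behavior; a minimal standalone illustration:

def outer():
  a = 1
  def inner():
    # `a` is read at call time, not at definition time, so the reaching
    # definition visible at the def site cannot be assumed to still hold here.
    return a
  a = 2
  return inner()

print(outer())  # prints 2
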
diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_inference.py b/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
index 5b59a5a..d5ab1f5 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
@@ -63,9 +63,10 @@
       ns: namespace
       types_ns: types namespace
       name: symbol name
+
     Returns:
       Tuple (type, static_value). The first element is the type to use for
-      inferrence. The second is the static value to use. Return None to treat it
+      inference. The second is the static value to use. Return None to treat it
       as unknown.
     """
     raise NotImplementedError('subclasses must implement')
@@ -383,7 +384,7 @@
     for t in f_types:
 
       if isinstance(t, Callable):
-        # Note: these are undocummented - may be version-specific!
+        # Note: these are undocumented - may be version-specific!
         # Callable[[x], y]: __args__ are (x, y)
         args = t.__args__
         if args:
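The undocumented `__args__` layout mentioned above is easy to check directly; a minimal sketch (behavior may differ across Python versions, as the comment warns):

from typing import Callable

t = Callable[[int, str], bool]
# Argument types come first and the return type last: prints (int, str, bool).
print(t.__args__)
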
diff --git a/tensorflow/python/autograph/pyct/transformer.py b/tensorflow/python/autograph/pyct/transformer.py
index c190097..005135e 100644
--- a/tensorflow/python/autograph/pyct/transformer.py
+++ b/tensorflow/python/autograph/pyct/transformer.py
@@ -314,13 +314,13 @@
         in nodes
       after_visit: optional callable that takes in an AST node and returns a
         tuple (new_node, new_destination). It is called after visiting each item
-        in nodes. Is used in the same was as the
-          visit_* methods: new_node will replace the node; if not None,
-            new_destination must be a list, and subsequent nodes will be placed
-            in this list instead of the list returned by visit_block.
+        in nodes. Is used in the same way as the visit_* methods: new_node will
+        replace the node; if not None, new_destination must be a list, and
+        subsequent nodes will be placed in this list instead of the list
+        returned by visit_block.
 
     Returns:
-      A list of AST node objects containing the transformed items fron nodes,
+      A list of AST node objects containing the transformed items from nodes,
       except those nodes that have been relocated using after_visit.
     """
     if nodes is None:
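A hypothetical `after_visit` callable matching the contract described in the docstring above (the relocation condition and marker attribute are invented for illustration):

relocated = []

def after_visit(node):
  # Returning (node, relocated) asks visit_block to place subsequent nodes in
  # `relocated` instead of its own result list; returning (node, None) keeps
  # the default destination.
  if getattr(node, 'relocate_rest', False):  # hypothetical marker attribute
    return node, relocated
  return node, None
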
diff --git a/tensorflow/python/autograph/pyct/transpiler.py b/tensorflow/python/autograph/pyct/transpiler.py
index 013ccc5..f7b9150 100644
--- a/tensorflow/python/autograph/pyct/transpiler.py
+++ b/tensorflow/python/autograph/pyct/transpiler.py
@@ -238,7 +238,7 @@
           result = <<transform node>>
           return result
 
-      transformer = MyTransfomer()
+      transformer = MyTransformer()
 
       result = transformer.transform(f, ...)
       # result is the output
@@ -381,7 +381,7 @@
           node = <<transform node, usually using ast.NodeTransformer classes>>
           return node
 
-      transformer = MyTransfomer()
+      transformer = MyTransformer()
 
       new_f, module, source_map = transformer.transform_function(f, ...)
       # new_f is a function with signature identical to f
@@ -430,7 +430,7 @@
     return cached_factory
 
   def transform_function(self, fn, user_context):
-    """Transforms a function. See GenericTranspiler.trasnform_function.
+    """Transforms a function. See GenericTranspiler.transform_function.
 
     This overload wraps the parent's `transform_function`, adding caching and
     facilities to instantiate the output as a Python object. It also
@@ -441,6 +441,7 @@
       fn: A function or lambda.
       user_context: An opaque object (may be None) that is forwarded to
         transform_ast, through the ctx.user attribute.
+
     Returns:
       A tuple:
         * A function or lambda with the same signature and closure as `fn`
diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc
index b2d3492..00baa13 100644
--- a/tensorflow/python/client/tf_session_wrapper.cc
+++ b/tensorflow/python/client/tf_session_wrapper.cc
@@ -334,7 +334,7 @@
 
   tf_handle(const tf_handle<T>& other) { Reset(other.obj_); }
 
-  tf_handle<T>& operator=(tf_handle<T>&& other) {
+  tf_handle<T>& operator=(tf_handle<T>&& other) noexcept {
     if (this == &other) {
       return *this;
     }
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 77c31ab..a4777a7 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -29,7 +29,7 @@
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 7, 31)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 8, 12)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 
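The relationship between the hardcoded horizon and TF_FORWARD_COMPATIBILITY_DELTA_DAYS described in the comment above, as a minimal standalone sketch (not the library's implementation):

import datetime
import os

horizon = datetime.date(2024, 8, 12)
# The optional delta from the environment is added to the compatibility date.
delta_days = int(os.environ.get('TF_FORWARD_COMPATIBILITY_DELTA_DAYS', '0'))
effective_horizon = horizon + datetime.timedelta(days=delta_days)
print(effective_horizon)
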
diff --git a/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py b/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py
index 857c6f7..b7c062b 100644
--- a/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py
@@ -88,13 +88,6 @@
           test_base.default_test_combinations(),
           combinations.combine(
               num_elements=10,
-              asserted_cardinality=1,
-              expected_error=errors.FailedPreconditionError,
-              expected_error_message=(
-                  "Input dataset was expected to contain 1 element but "
-                  "contained at least 2 elements.")) +
-          combinations.combine(
-              num_elements=10,
               asserted_cardinality=100,
               expected_error=errors.FailedPreconditionError,
               expected_error_message=(
diff --git a/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi b/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi
index e88ec56..29126c1 100644
--- a/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi
+++ b/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi
@@ -14,4 +14,3 @@
 # ==============================================================================
 
 def TF_DATA_DefaultProtocol() -> str: ...
-def TF_DATA_DisableCompressionAtRegistrationTime() -> bool: ...
diff --git a/tensorflow/python/data/experimental/service/utils_wrapper.cc b/tensorflow/python/data/experimental/service/utils_wrapper.cc
index f949829..c725ff3 100644
--- a/tensorflow/python/data/experimental/service/utils_wrapper.cc
+++ b/tensorflow/python/data/experimental/service/utils_wrapper.cc
@@ -23,8 +23,4 @@
 PYBIND11_MODULE(_pywrap_utils_exp, m) {
   m.def("TF_DATA_DefaultProtocol",
         []() -> std::string { return tensorflow::data::DefaultProtocol(); });
-
-  m.def("TF_DATA_DisableCompressionAtRegistrationTime", []() -> bool {
-    return tensorflow::data::DisableCompressionAtRegistrationTime();
-  });
 };
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 7c20b1f..88de9cd 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -213,9 +213,11 @@
     name = "concatenate_test",
     size = "medium",
     srcs = ["concatenate_test.py"],
+    shard_count = 20,
     deps = [
         ":checkpoint_test_base",
         ":test_base",
+        "//tensorflow/python/data/experimental/ops:global_shuffle_op",
         "//tensorflow/python/data/experimental/ops:random_access",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/data/ops:options",
diff --git a/tensorflow/python/data/kernel_tests/concatenate_test.py b/tensorflow/python/data/kernel_tests/concatenate_test.py
index be08594..51c7ca4 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_test.py
@@ -13,8 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for `tf.data.Dataset.concatenate()."""
+from typing import Callable, Tuple
 from absl.testing import parameterized
 import numpy as np
+from tensorflow.python.data.experimental.ops import global_shuffle_op
 from tensorflow.python.data.experimental.ops import random_access
 from tensorflow.python.data.kernel_tests import checkpoint_test_base
 from tensorflow.python.data.kernel_tests import test_base
@@ -248,5 +250,301 @@
       self.evaluate(random_access.at(concatenated, index=5))
 
 
+class GlobalShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
+  """Tests for global shuffling of tf.data datasets."""
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testShuffledOutput(self):
+    dataset1 = dataset_ops.Dataset.range(0, 5)
+    dataset2 = dataset_ops.Dataset.range(5, 17)
+
+    dataset = dataset1.concatenate(dataset2)
+
+    dataset = global_shuffle_op._global_shuffle(dataset)
+
+    output = self.getDatasetOutput(dataset, requires_initialization=True)
+    self.assertCountEqual(output, range(0, 17))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testShuffledWithBatchOutput(self):
+    """Testing with `.batch()` ensures the global shuffle map is stateless."""
+    dataset1 = dataset_ops.Dataset.range(0, 4)
+    dataset2 = dataset_ops.Dataset.range(4, 10)
+
+    dataset = dataset1.concatenate(dataset2)
+    dataset = dataset.batch(3, drop_remainder=True)
+
+    dataset = global_shuffle_op._global_shuffle(dataset)
+
+    got = self.getDatasetOutput(dataset, requires_initialization=True)
+    expected = [
+        np.array([0, 1, 2], dtype=np.int32),
+        np.array([3, 4, 5], dtype=np.int32),
+        np.array([6, 7, 8], dtype=np.int32),
+    ]
+
+    self.assertIsInstance(got, list)
+    # Converts each element to a tuple for lexicographic sorting.
+    got.sort(key=tuple)
+
+    self.assertLen(got, len(expected))
+
+    for element_got, element_expected in zip(got, expected):
+      self.assertAllEqual(element_got, element_expected)
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testNestedConcatenateShuffledOutput(self):
+    dataset1 = dataset_ops.Dataset.range(0, 3)
+    dataset2 = dataset_ops.Dataset.range(3, 6)
+    dataset3 = dataset_ops.Dataset.range(6, 9)
+
+    dataset = dataset1.concatenate(dataset2)
+    dataset = dataset.concatenate(dataset3)
+
+    dataset = global_shuffle_op._global_shuffle(dataset)
+
+    output = self.getDatasetOutput(dataset, requires_initialization=True)
+    self.assertCountEqual(output, range(0, 9))
+
+
+class ConcatenateGlobalShuffleCheckpointTest(
+    checkpoint_test_base.CheckpointTestBase, parameterized.TestCase
+):
+
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(
+              dataset_ranges=[(10, 8), (9, 5), (4, 7), (5, 8)],
+              reshuffle_each_iteration=[True, False],
+              symbolic_checkpoint=[True, False],
+          ),
+      )
+  )
+  def testConcatenate(
+      self,
+      verify_fn: Callable[..., None],
+      dataset_ranges: Tuple[int, int],
+      reshuffle_each_iteration: bool,
+      symbolic_checkpoint: bool,
+  ):
+
+    def _build_dataset():
+      first_dataset = dataset_ops.Dataset.range(dataset_ranges[0])
+      second_dataset = dataset_ops.Dataset.range(
+          dataset_ranges[0], dataset_ranges[0] + dataset_ranges[1]
+      )
+      dataset = first_dataset.concatenate(second_dataset)
+      dataset = global_shuffle_op._global_shuffle(
+          dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration
+      )
+
+      options = options_lib.Options()
+      options.experimental_optimization.apply_default_optimizations = False
+      options.experimental_symbolic_checkpoint = symbolic_checkpoint
+      return dataset.with_options(options)
+
+    verify_fn(
+        self,
+        _build_dataset,
+        num_outputs=sum(dataset_ranges),
+        assert_items_equal=reshuffle_each_iteration,
+    )
+
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(
+              dataset_ranges=[(10, 8, 11), (9, 5, 3)],
+              reshuffle_each_iteration=[True, False],
+              symbolic_checkpoint=[True, False],
+          ),
+      )
+  )
+  def testNestedConcatenate(
+      self,
+      verify_fn: Callable[..., None],
+      dataset_ranges: Tuple[int, ...],
+      reshuffle_each_iteration: bool,
+      symbolic_checkpoint: bool,
+  ):
+
+    def _build_dataset():
+      first_dataset = dataset_ops.Dataset.range(dataset_ranges[0])
+      second_dataset = dataset_ops.Dataset.range(
+          dataset_ranges[0], dataset_ranges[0] + dataset_ranges[1]
+      )
+      third_dataset = dataset_ops.Dataset.range(
+          sum(dataset_ranges[:2]), sum(dataset_ranges[:3])
+      )
+
+      dataset = first_dataset.concatenate(second_dataset)
+      dataset = dataset.concatenate(third_dataset)
+
+      dataset = global_shuffle_op._global_shuffle(
+          dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration
+      )
+
+      options = options_lib.Options()
+      options.experimental_optimization.apply_default_optimizations = False
+      options.experimental_symbolic_checkpoint = symbolic_checkpoint
+      return dataset.with_options(options)
+
+    verify_fn(
+        self,
+        _build_dataset,
+        num_outputs=sum(dataset_ranges),
+        assert_items_equal=reshuffle_each_iteration,
+    )
+
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(
+              dataset_ranges=[(3, 4, 6, 5)],
+              reshuffle_each_iteration=[True, False],
+              symbolic_checkpoint=[True, False],
+          ),
+      )
+  )
+  def testFourNestedConcatenate(
+      self,
+      verify_fn: Callable[..., None],
+      dataset_ranges: Tuple[int, ...],
+      reshuffle_each_iteration: bool,
+      symbolic_checkpoint: bool,
+  ):
+    def _build_dataset():
+      first_dataset = dataset_ops.Dataset.range(dataset_ranges[0])
+      second_dataset = dataset_ops.Dataset.range(
+          dataset_ranges[0], sum(dataset_ranges[:2])
+      )
+      third_dataset = dataset_ops.Dataset.range(
+          sum(dataset_ranges[:2]), sum(dataset_ranges[:3])
+      )
+      fourth_dataset = dataset_ops.Dataset.range(
+          sum(dataset_ranges[:3]), sum(dataset_ranges)
+      )
+
+      left = first_dataset.concatenate(second_dataset)
+      right = third_dataset.concatenate(fourth_dataset)
+
+      dataset = left.concatenate(right)
+      dataset = global_shuffle_op._global_shuffle(
+          dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration
+      )
+
+      options = options_lib.Options()
+      options.experimental_optimization.apply_default_optimizations = False
+      options.experimental_symbolic_checkpoint = symbolic_checkpoint
+      return dataset.with_options(options)
+
+    verify_fn(
+        self,
+        _build_dataset,
+        num_outputs=sum(dataset_ranges),
+        assert_items_equal=reshuffle_each_iteration,
+    )
+
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(
+              dataset_ranges=[(1, 2, 3, 4, 5, 6)],
+              reshuffle_each_iteration=[True, False],
+              symbolic_checkpoint=[True, False],
+          ),
+      )
+  )
+  def testDeepConcatenate(
+      self,
+      verify_fn: Callable[..., None],
+      dataset_ranges: Tuple[int, ...],
+      reshuffle_each_iteration: bool,
+      symbolic_checkpoint: bool,
+  ):
+    def _build_dataset():
+      prefix_sums = [0] * (len(dataset_ranges) + 1)
+      for i, value in enumerate(dataset_ranges):
+        prefix_sums[i + 1] = prefix_sums[i] + value
+
+      dataset = dataset_ops.Dataset.range(prefix_sums[0], prefix_sums[1])
+      for i in range(1, len(dataset_ranges)):
+        to_concat = dataset_ops.Dataset.range(
+            prefix_sums[i], prefix_sums[i + 1]
+        )
+        dataset = dataset.concatenate(to_concat)
+
+      dataset = global_shuffle_op._global_shuffle(
+          dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration
+      )
+
+      options = options_lib.Options()
+      options.experimental_optimization.apply_default_optimizations = False
+      options.experimental_symbolic_checkpoint = symbolic_checkpoint
+      return dataset.with_options(options)
+
+    verify_fn(
+        self,
+        _build_dataset,
+        num_outputs=sum(dataset_ranges),
+        assert_items_equal=reshuffle_each_iteration,
+    )
+
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(
+              dataset_ranges=[(1, 2, 3, 4, 5, 6)],
+              reshuffle_each_iteration=[True, False],
+              symbolic_checkpoint=[True, False],
+          ),
+      )
+  )
+  def testDeepConcatenateWithBatchAndPrefetch(
+      self,
+      verify_fn: Callable[..., None],
+      dataset_ranges: Tuple[int, ...],
+      reshuffle_each_iteration: bool,
+      symbolic_checkpoint: bool,
+  ):
+    def _build_dataset():
+      prefix_sums = [0] * (len(dataset_ranges) + 1)
+      for i, value in enumerate(dataset_ranges):
+        prefix_sums[i + 1] = prefix_sums[i] + value
+
+      dataset = dataset_ops.Dataset.range(prefix_sums[0], prefix_sums[1])
+      for i in range(1, len(dataset_ranges)):
+        to_concat = dataset_ops.Dataset.range(
+            prefix_sums[i], prefix_sums[i + 1]
+        )
+        dataset = dataset.concatenate(to_concat)
+
+      dataset = dataset.batch(2, drop_remainder=True)
+      dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
+
+      dataset = global_shuffle_op._global_shuffle(
+          dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration
+      )
+      dataset = dataset.unbatch()
+
+      options = options_lib.Options()
+      options.experimental_optimization.apply_default_optimizations = False
+      options.experimental_symbolic_checkpoint = symbolic_checkpoint
+      return dataset.with_options(options)
+
+    verify_fn(
+        self,
+        _build_dataset,
+        num_outputs=(sum(dataset_ranges) // 2) * 2,
+        assert_items_equal=reshuffle_each_iteration,
+    )
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/kernel_tests/flat_map_test.py b/tensorflow/python/data/kernel_tests/flat_map_test.py
index 00299dd..c2c5f26 100644
--- a/tensorflow/python/data/kernel_tests/flat_map_test.py
+++ b/tensorflow/python/data/kernel_tests/flat_map_test.py
@@ -15,6 +15,7 @@
 """Tests for `tf.data.Dataset.flat_map()`."""
 import random
 from typing import Callable, Optional
+import unittest
 
 from absl.testing import parameterized
 import numpy as np
@@ -466,6 +467,10 @@
     verify_fn(self, build_dataset, num_outputs=3 * 4 - num_skips)
 
 
+@unittest.skip(
+    "TODO: b/355241367 - `flat_map_dataset_op.cc` still needs to be fixed."
+    " Please use concatenate dataset op plus global shuffling instead."
+)
 class FlatMapGlobalShuffleTest(
     test_base.DatasetTestBase, parameterized.TestCase):
 
@@ -511,6 +516,10 @@
       self.getDatasetOutput(dataset, requires_initialization=True)
 
 
+@unittest.skip(
+    "TODO: b/355241367 - `flat_map_dataset_op.cc` still needs to be fixed."
+    " Please use concatenate dataset op plus global shuffling instead."
+)
 class FlatMapGlobalShuffleCheckpointTest(
     checkpoint_test_base.CheckpointTestBase, parameterized.TestCase
 ):
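
Note: the skip messages above point users at the concatenate dataset op plus
global shuffling as the replacement for globally shuffled flat_map. A minimal
sketch of that pattern, based on the concatenate checkpoint tests earlier in
this change; the module path for `global_shuffle_op` is assumed, and
`_global_shuffle` is an internal, non-public helper used only for illustration.

from tensorflow.python.data.experimental.ops import global_shuffle_op
from tensorflow.python.data.ops import dataset_ops

# Build the same element sequence that flat_map over two ranges would yield.
first = dataset_ops.Dataset.range(0, 10)
second = dataset_ops.Dataset.range(10, 18)
dataset = first.concatenate(second)

# Globally shuffle across the concatenated inputs (internal API, sketch only).
dataset = global_shuffle_op._global_shuffle(
    dataset, seed=10, reshuffle_each_iteration=True)
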
diff --git a/tensorflow/python/data/kernel_tests/map_test.py b/tensorflow/python/data/kernel_tests/map_test.py
index 2d00e6b..2e58aeb 100644
--- a/tensorflow/python/data/kernel_tests/map_test.py
+++ b/tensorflow/python/data/kernel_tests/map_test.py
@@ -262,13 +262,15 @@
             self.assertAllEqual(component[i]**2, result_component)
 
   def _parallel_map_dataset_factory(self, components, apply_map, count,
-                                    num_parallel_calls, buffer_size):
+                                    num_parallel_calls, buffer_size,
+                                    use_unbounded_threadpool=False):
 
     def _map_fn(x, y, z):
       return math_ops.square(x), math_ops.square(y), math_ops.square(z)
 
     dataset = dataset_ops.Dataset.from_tensor_slices(components)
-    dataset = apply_map(dataset, _map_fn, num_parallel_calls=num_parallel_calls)
+    dataset = apply_map(dataset, _map_fn, num_parallel_calls=num_parallel_calls,
+                        use_unbounded_threadpool=use_unbounded_threadpool)
     dataset = dataset.prefetch(buffer_size).repeat(count)
 
     self.assertEqual(
@@ -284,8 +286,10 @@
           combinations.combine(num_parallel_calls=2, buffer_size=2) +
           combinations.combine(num_parallel_calls=2, buffer_size=4) +
           combinations.combine(num_parallel_calls=8, buffer_size=8) +
-          combinations.combine(num_parallel_calls=8, buffer_size=16)))
-  def testParallelMapDataset(self, apply_map, num_parallel_calls, buffer_size):
+          combinations.combine(num_parallel_calls=8, buffer_size=16),
+          combinations.combine(use_unbounded_threadpool=[None, True, False])))
+  def testParallelMapDataset(self, apply_map, num_parallel_calls, buffer_size,
+                             use_unbounded_threadpool):
     """Test an dataset that maps a TF function across its input elements."""
 
     # The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) ->
@@ -296,7 +300,8 @@
     # Test single-threaded access to the iterator.
     get_next = self.getNext(
         self._parallel_map_dataset_factory(components, apply_map, 14,
-                                           num_parallel_calls, buffer_size))
+                                           num_parallel_calls, buffer_size,
+                                           use_unbounded_threadpool))
     for _ in range(14):
       for i in range(7):
         result = self.evaluate(get_next())
@@ -1537,6 +1542,20 @@
 
   @combinations.generate(
       combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(
+                             use_unbounded_threadpool=[True, False])))
+  def testAutotuneUseUnboundedThreadpool(self, use_unbounded_threadpool):
+    dataset = dataset_ops.Dataset.range(100)
+    dataset = dataset.map(
+        lambda x: x * 2,
+        num_parallel_calls=dataset_ops.AUTOTUNE,
+        use_unbounded_threadpool=use_unbounded_threadpool,
+        deterministic=True,
+        name="map")
+    self.assertDatasetProduces(dataset, [x * 2 for x in range(100)])
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
                          combinations.combine(num_parallel_calls=[None, 1])))
   def testName(self, num_parallel_calls):
     dataset = dataset_ops.Dataset.from_tensors(21).map(
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index b2580e3..c8d01f2 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -2157,6 +2157,7 @@
       num_parallel_calls=None,
       deterministic=None,
       synchronous=None,
+      use_unbounded_threadpool=False,
       name=None,
   ) -> "DatasetV2":
     """Maps `map_func` across the elements of this dataset.
@@ -2313,6 +2314,11 @@
         saving memory, since even setting `num_parallel_calls=1` will cause one
         batch to be buffered, while with `synchronous=True` the map
         transformation doesn't buffer anything.
+      use_unbounded_threadpool: (Optional.) By default, map functions run in a
+        limited threadpool based on the number of cores on the machine. This
+        is efficient for CPU-heavy processing, but if the map function
+        performs IO it is better to use an unbounded threadpool by setting
+        this to `True`. Defaults to `False`.
       name: (Optional.) A name for the tf.data operation.
 
     Returns:
@@ -2329,6 +2335,7 @@
         num_parallel_calls=num_parallel_calls,
         deterministic=deterministic,
         synchronous=synchronous,
+        use_unbounded_threadpool=use_unbounded_threadpool,
         name=name,
     )
     # pylint: enable=g-import-not-at-top,protected-access
@@ -4092,6 +4099,7 @@
       num_parallel_calls=None,
       deterministic=None,
       synchronous=None,
+      use_unbounded_threadpool=False,
       name=None,
   ):
     # Loaded lazily due to a circular dependency (dataset_ops -> map_op ->
@@ -4105,12 +4113,17 @@
         num_parallel_calls=num_parallel_calls,
         deterministic=deterministic,
         synchronous=synchronous,
+        use_unbounded_threadpool=use_unbounded_threadpool,
     )
     # pylint: enable=g-import-not-at-top,protected-access
 
   @deprecation.deprecated(None, "Use `tf.data.Dataset.map()`")
   def map_with_legacy_function(
-      self, map_func, num_parallel_calls=None, deterministic=None
+      self,
+      map_func,
+      num_parallel_calls=None,
+      deterministic=None,
+      use_unbounded_threadpool=False,
   ) -> "DatasetV1Adapter":
     """Maps `map_func` across the elements of this dataset.
 
@@ -4133,6 +4146,11 @@
         elements out of order to trade determinism for performance. If not
         specified, the `tf.data.Options.deterministic` option (`True` by
         default) controls the behavior.
+      use_unbounded_threadpool: (Optional.) By default, map functions run in a
+        limited threadpool based on the number of cores on the machine. This
+        is efficient for CPU-heavy processing, but if the map function
+        performs IO it is better to use an unbounded threadpool by setting
+        this to `True`. Defaults to `False`.
 
     Returns:
       Dataset: A `Dataset`.
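
For reference, a minimal usage sketch of the new `use_unbounded_threadpool`
flag on `Dataset.map`; the file pattern and parse function below are
hypothetical, only the flag itself is the one added in this change.

import tensorflow as tf

def parse(path):
  # IO-bound per-element work is the case the unbounded threadpool targets.
  return tf.io.read_file(path)

files = tf.data.Dataset.list_files("/data/*.txt")  # hypothetical input
ds = files.map(
    parse,
    num_parallel_calls=tf.data.AUTOTUNE,
    use_unbounded_threadpool=True)
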
diff --git a/tensorflow/python/data/ops/map_op.py b/tensorflow/python/data/ops/map_op.py
index 0a056ab..f301dee 100644
--- a/tensorflow/python/data/ops/map_op.py
+++ b/tensorflow/python/data/ops/map_op.py
@@ -30,6 +30,7 @@
     num_parallel_calls=None,
     deterministic=None,
     synchronous=None,
+    use_unbounded_threadpool=None,
     name=None,
 ):
   """See `Dataset.map()` for details."""
@@ -59,6 +60,7 @@
         num_parallel_calls=num_parallel_calls,
         deterministic=deterministic,
         preserve_cardinality=True,
+        use_unbounded_threadpool=use_unbounded_threadpool,
         name=name)
 
 
@@ -68,6 +70,7 @@
     num_parallel_calls=None,
     deterministic=None,
     synchronous=None,
+    use_unbounded_threadpool=None,  # pylint: disable=unused-argument
 ):
   """See `Dataset.map()` for details."""
   if num_parallel_calls is None or debug_mode.DEBUG_MODE:
@@ -92,7 +95,8 @@
             map_func,
             num_parallel_calls,
             deterministic,
-            preserve_cardinality=False))
+            preserve_cardinality=False,
+            use_unbounded_threadpool=False))
 
 
 def _map_v1_with_legacy_function(  # pylint: disable=unused-private-name
@@ -130,7 +134,8 @@
             num_parallel_calls,
             deterministic,
             preserve_cardinality=False,
-            use_legacy_function=True))
+            use_legacy_function=True,
+            use_unbounded_threadpool=False))
 
 
 class _MapDataset(dataset_ops.UnaryDataset):
@@ -189,6 +194,7 @@
                use_inter_op_parallelism=True,
                preserve_cardinality=False,
                use_legacy_function=False,
+               use_unbounded_threadpool=False,
                name=None):
     """See `Dataset.map()` for details."""
     self._input_dataset = input_dataset
@@ -207,6 +213,7 @@
     self._preserve_cardinality = preserve_cardinality
     self._num_parallel_calls = ops.convert_to_tensor(
         num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
+    self._use_unbounded_threadpool = use_unbounded_threadpool
     self._name = name
     variant_tensor = gen_dataset_ops.parallel_map_dataset_v2(
         input_dataset._variant_tensor,  # pylint: disable=protected-access
@@ -216,6 +223,7 @@
         deterministic=self._deterministic,
         use_inter_op_parallelism=self._use_inter_op_parallelism,
         preserve_cardinality=self._preserve_cardinality,
+        use_unbounded_threadpool=self._use_unbounded_threadpool,
         **self._common_args)
     super().__init__(input_dataset, variant_tensor)
 
diff --git a/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
index cd3d2ee..1f334a0 100644
--- a/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
+++ b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
@@ -19,15 +19,9 @@
 #include "Python.h"
 #include "pybind11/pybind11.h"  // from @pybind11
 #include "pybind11/stl.h"  // from @pybind11
-#include "tensorflow/c/c_api.h"
-#include "tensorflow/c/c_api_experimental.h"
-#include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/parallel_device/parallel_device.h"
-#include "tensorflow/c/safe_ptr.h"
-#include "tensorflow/python/lib/core/py_exception_registry.h"
 #include "tensorflow/python/lib/core/pybind11_lib.h"
-#include "tensorflow/python/lib/core/pybind11_status.h"
 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
 
 namespace py = pybind11;
diff --git a/tensorflow/python/flags_pybind.pyi b/tensorflow/python/flags_pybind.pyi
index b34ed2f..7c450b6 100644
--- a/tensorflow/python/flags_pybind.pyi
+++ b/tensorflow/python/flags_pybind.pyi
@@ -24,6 +24,7 @@
     enable_function_pruning_before_inlining: Flag
     enable_nested_function_shape_inference: Flag
     enable_quantized_dtypes_training: Flag
+    enable_skip_encapsulation_for_non_tpu_graphs: Flag
     enable_tf2min_ici_weight: Flag
     graph_building_optimization: Flag
     more_stack_traces: Flag
diff --git a/tensorflow/python/framework/offset_counter_helper_test.cc b/tensorflow/python/framework/offset_counter_helper_test.cc
index dcf6f7c..ef616a3 100644
--- a/tensorflow/python/framework/offset_counter_helper_test.cc
+++ b/tensorflow/python/framework/offset_counter_helper_test.cc
@@ -18,10 +18,10 @@
 #include <string>
 
 #include "absl/strings/str_format.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/python/framework/op_reg_offset.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 4cd43ae..1c81b35 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -716,16 +716,22 @@
   // These objects are efficiently handled by Numpy. We transform them into
   // Numpy arrays and handle them in the Numpy case below. Note that Tensors
   // implement the __array__ function, and will be handled in this shortcut.
-  Safe_PyObjectPtr array =
-      make_safe(PyArray_FromArrayAttr(obj, nullptr, nullptr));
-  if (array == nullptr) {
-    return nullptr;
+  // We used to call PyArray_FromArrayAttr here, but NumPy 2.0 changed its
+  // semantics such that it errors if a copy of the array is required.
+  // (Ideally no copy would be needed here, but that would be a larger change.)
+  Safe_PyObjectPtr array;
+  if (PyObject_HasAttrString(obj, "__array__")) {
+    array = make_safe(PyObject_CallMethod(obj, "__array__", nullptr));
+    if (array == nullptr) {
+      return nullptr;
+    }
+    if (!PyArray_Check(array.get())) {
+      PyErr_SetString(PyExc_ValueError,
+                      "Value returned by __array__ is not a NumPy array");
+      return nullptr;
+    }
   }
-  if (array.get() == Py_NotImplemented) {
-    // The Py_NotImplemented returned from PyArray_FromArrayAttr is not
-    // Py_INCREF'ed, so we don't want the Safe_PyObjectPtr to Py_DECREF it.
-    array.release();
-
+  if (!array) {
     // Try __array_interface__ objects (such as PIL Image).
     array = make_safe(PyArray_FromInterface(obj));
     if (array == nullptr) {
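
The py_seq_tensor.cc change above replaces PyArray_FromArrayAttr with a direct
call to the object's __array__ method, since NumPy 2.0 made the former error
out when a copy of the array is required. A rough Python equivalent of the new
fast path, for illustration only (not the actual implementation):

import numpy as np

def to_ndarray(obj):
  # Mirror the C++ path: call __array__ directly, then validate the result.
  if hasattr(obj, "__array__"):
    array = obj.__array__()
    if not isinstance(array, np.ndarray):
      raise ValueError("Value returned by __array__ is not a NumPy array")
    return array
  # Otherwise fall through to __array_interface__ handling (e.g. PIL Images).
  return None
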
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 0cb65da..76dd388 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -396,7 +396,7 @@
 
   Given a tensor `input`, this operation inserts a dimension of length 1 at the
   dimension index `axis` of `input`'s shape. The dimension index follows Python
-  indexing rules: It's zero-based, a negative index it is counted backward
+  indexing rules: It's zero-based, and a negative index is counted backward
   from the end.
 
   This operation is useful to:
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index ab508d1..d553667 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2010,8 +2010,12 @@
     # infer dtype if not explicitly provided
     if dtype is None:
       dtype_hierarchy = [
-          dtypes.int32, dtypes.int64, dtypes.float16, dtypes.bfloat16,
-          dtypes.float32, dtypes.float64
+          dtypes.int32,
+          dtypes.int64,
+          dtypes.float16,
+          dtypes.bfloat16,
+          dtypes.float32,
+          dtypes.float64,
       ]
       assert all(arg.dtype in dtype_hierarchy for arg in [start, limit, delta])
       inferred_dtype = max([arg.dtype for arg in [start, limit, delta]],
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
index 17a2776..b43fcc3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
@@ -129,11 +129,11 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "map_with_legacy_function"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 88a2e4f..e4b1c19 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -131,11 +131,11 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "map_with_legacy_function"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
index b71ecb1..999d0c0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -131,11 +131,11 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "map_with_legacy_function"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
index e4f1c85..123c3b7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
@@ -131,11 +131,11 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "map_with_legacy_function"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt
index 266a8f1..2b8fdc3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt
@@ -131,11 +131,11 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "map_with_legacy_function"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt
index aff5ee5f..77b396a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt
@@ -131,11 +131,11 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "map_with_legacy_function"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt
index 6986dab..48c8540 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt
@@ -131,11 +131,11 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "map_with_legacy_function"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 8c5d4bf..c55d7db 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -3014,7 +3014,7 @@
   }
   member_method {
     name: "ParallelMapDatasetV2"
-    argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'use_unbounded_threadpool\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'False\', \'\', \'None\'], "
   }
   member_method {
     name: "ParameterizedTruncatedNormal"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 1852f76..00f5d54 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -100,7 +100,7 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 5f1f368..9230f45 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -102,7 +102,7 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index 85eb696..e89d811 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -101,7 +101,7 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 42a293d..c936ff4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -102,7 +102,7 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
index d376170..f31da9c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
@@ -102,7 +102,7 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
index 190a21f..8faf543 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
@@ -103,7 +103,7 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
index e19f932..490b7f7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
@@ -102,7 +102,7 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt
index d6d6439..c2194ab 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt
@@ -103,7 +103,7 @@
   }
   member_method {
     name: "map"
-    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "options"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 8c5d4bf..c55d7db 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -3014,7 +3014,7 @@
   }
   member_method {
     name: "ParallelMapDatasetV2"
-    argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'use_unbounded_threadpool\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'False\', \'\', \'None\'], "
   }
   member_method {
     name: "ParameterizedTruncatedNormal"
diff --git a/tensorflow/tools/benchmark/benchmark_model_test.cc b/tensorflow/tools/benchmark/benchmark_model_test.cc
index c6e4284..1b1b443 100644
--- a/tensorflow/tools/benchmark/benchmark_model_test.cc
+++ b/tensorflow/tools/benchmark/benchmark_model_test.cc
@@ -18,6 +18,7 @@
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/array_ops.h"
 #include "tensorflow/cc/ops/math_ops.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -29,7 +30,6 @@
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/public/session.h"
 #include "tensorflow/core/util/stat_summarizer.h"
-#include "tsl/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh b/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh
index 72a228f..992aa6d 100644
--- a/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh
+++ b/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh
@@ -16,5 +16,4 @@
 set -x
 
 ARM_SKIP_TESTS="-//tensorflow/lite/... \
--//tensorflow/core/kernels/image:resize_bicubic_op_test \
 "
diff --git a/tensorflow/tools/graph_transforms/backports_test.cc b/tensorflow/tools/graph_transforms/backports_test.cc
index 80a954e..155ec29 100644
--- a/tensorflow/tools/graph_transforms/backports_test.cc
+++ b/tensorflow/tools/graph_transforms/backports_test.cc
@@ -192,7 +192,7 @@
     std::map<string, const NodeDef*> node_lookup;
     MapNamesToNodes(result, &node_lookup);
     ASSERT_EQ(1, node_lookup.count("v3_node"));
-    EXPECT_TRUE(str_util::EndsWith(node_lookup.at("v3_node")->op(), "V2"));
+    EXPECT_TRUE(absl::EndsWith(node_lookup.at("v3_node")->op(), "V2"));
   }
 }
 
diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc
index dcdc3c2..3d388cd 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc
@@ -210,10 +210,10 @@
     for (const NodeDef& node : graph_def.node()) {
       const StringPiece name(node.name());
       const int occurrence_count = folded_node_map.count(node.name());
-      if (str_util::EndsWith(name, "expect_removed")) {
+      if (absl::EndsWith(name, "expect_removed")) {
         EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name();
       }
-      if (str_util::EndsWith(name, "expect_remains")) {
+      if (absl::EndsWith(name, "expect_remains")) {
         EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name();
       }
     }
diff --git a/tensorflow/tools/proto_splitter/cc/BUILD b/tensorflow/tools/proto_splitter/cc/BUILD
index 9672bf1..da86a3a 100644
--- a/tensorflow/tools/proto_splitter/cc/BUILD
+++ b/tensorflow/tools/proto_splitter/cc/BUILD
@@ -97,9 +97,9 @@
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:cord",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:protobuf",
         "@local_tsl//tsl/platform:status_matchers",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
         "@riegeli//riegeli/base:initializer",
         "@riegeli//riegeli/bytes:cord_reader",
         "@riegeli//riegeli/bytes:fd_reader",
@@ -163,11 +163,11 @@
         "//tensorflow/tools/proto_splitter/testdata:test_message_proto_cc",
         "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:protobuf",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:statusor",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ] + if_oss([
         "//tensorflow/tools/proto_splitter:protos_impl",
     ]),
diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc
index 55aac6a..d62acc5 100644
--- a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc
+++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc
@@ -30,6 +30,7 @@
 #include "riegeli/bytes/fd_reader.h"  // from @riegeli
 #include "riegeli/bytes/string_reader.h"  // from @riegeli
 #include "riegeli/records/record_reader.h"  // from @riegeli
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/file_system_helper.h"
@@ -39,7 +40,6 @@
 #include "tensorflow/tools/proto_splitter/cc/util.h"
 #include "tensorflow/tools/proto_splitter/chunk.pb.h"
 #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/statusor.h"
diff --git a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc
index 1d98a3a..1fb19f5 100644
--- a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc
+++ b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc
@@ -23,6 +23,7 @@
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "absl/strings/cord.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
@@ -34,7 +35,6 @@
 #include "tensorflow/tools/proto_splitter/cc/test_util.h"
 #include "tensorflow/tools/proto_splitter/cc/util.h"
 #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/protobuf.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc
index b03bcc1..1712421 100644
--- a/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc
+++ b/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc
@@ -22,6 +22,7 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
@@ -33,7 +34,6 @@
 #include "tensorflow/tools/proto_splitter/cc/max_size.h"
 #include "tensorflow/tools/proto_splitter/cc/util.h"
 #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/protobuf.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/tensorflow/tools/proto_splitter/cc/util_test.cc b/tensorflow/tools/proto_splitter/cc/util_test.cc
index e318f7c..7880519 100644
--- a/tensorflow/tools/proto_splitter/cc/util_test.cc
+++ b/tensorflow/tools/proto_splitter/cc/util_test.cc
@@ -21,12 +21,12 @@
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "absl/status/status.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/tools/proto_splitter/cc/test_util.h"
 #include "tensorflow/tools/proto_splitter/chunk.pb.h"
 #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/protobuf.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/status_matchers.h"
diff --git a/tensorflow/tools/proto_splitter/merge_test.cc b/tensorflow/tools/proto_splitter/merge_test.cc
index 5f78f3f..06d40d5 100644
--- a/tensorflow/tools/proto_splitter/merge_test.cc
+++ b/tensorflow/tools/proto_splitter/merge_test.cc
@@ -23,6 +23,7 @@
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "absl/strings/str_cat.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/test.h"
@@ -31,7 +32,6 @@
 #include "tensorflow/tools/proto_splitter/cc/util.h"
 #include "tensorflow/tools/proto_splitter/chunk.pb.h"
 #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h"
-#include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/protobuf.h"
 #include "tsl/platform/statusor.h"
diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile
index 7659bd6..b6d1aca 100644
--- a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile
+++ b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile
@@ -1,5 +1,5 @@
 ################################################################################
-FROM ubuntu:22.04@sha256:19478ce7fc2ffbce89df29fea5725a8d12e57de52eb9ea570890dc5852aac1ac as builder
+FROM ubuntu:22.04@sha256:340d9b015b194dc6e2a13938944e0d016e57b9679963fdeb9ce021daac430221 as builder
 ################################################################################
 
 # Install devtoolset build dependencies
diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl
index 06def11..2959fe8 100644
--- a/tensorflow/workspace2.bzl
+++ b/tensorflow/workspace2.bzl
@@ -57,7 +57,6 @@
 load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
 load("//third_party/tensorrt:workspace.bzl", tensorrt = "repo")
 load("//third_party/triton:workspace.bzl", triton = "repo")
-load("//third_party/uv:workspace.bzl", uv = "repo")
 load("//third_party/vulkan_headers:workspace.bzl", vulkan_headers = "repo")
 
 def _initialize_third_party():
@@ -93,7 +92,6 @@
     vulkan_headers()
     tensorrt()
     triton()
-    uv()
 
     # copybara: tsl vendor
 
diff --git a/third_party/absl/nvidia_jetson.patch b/third_party/absl/nvidia_jetson.patch
new file mode 100644
index 0000000..5328c3a
--- /dev/null
+++ b/third_party/absl/nvidia_jetson.patch
@@ -0,0 +1,35 @@
+From 372124e6af36a540e74a2ec31d79d7297a831f98 Mon Sep 17 00:00:00 2001
+From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Bastien?= <frederic.bastien@gmail.com>
+Date: Thu, 1 Aug 2024 12:38:52 -0700
+Subject: [PATCH] PR #1732: Fix build on NVIDIA Jetson board. Fix #1665
+
+Imported from GitHub PR https://github.com/abseil/abseil-cpp/pull/1732
+
+Fix build on NVIDIA Jetson board. Fix #1665
+
+This patch is already used by the Spark project.
+I'm fixing this as it breaks the build of TensorFlow and JAX on Jetson boards.
+Merge 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff into 6b8ebb35c0414ef5a2b6fd4a0f59057e41beaff9
+
+Merging this change closes #1732
+
+COPYBARA_INTEGRATE_REVIEW=https://github.com/abseil/abseil-cpp/pull/1732 from nouiz:fix_neon_on_jetson 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff
+PiperOrigin-RevId: 658501520
+Change-Id: If502ede4efc8c877fb3fed227eca6dc7622dd181
+---
+ absl/base/config.h | 2 +-
+ 1 file changed, 1 insertion(+), 1 deletion(-)
+
+diff --git a/absl/base/config.h b/absl/base/config.h
+index 97c9a22a109..ab1e9860a91 100644
+--- a/absl/base/config.h
++++ b/absl/base/config.h
+@@ -926,7 +926,7 @@ static_assert(ABSL_INTERNAL_INLINE_NAMESPACE_STR[0] != 'h' ||
+ // https://llvm.org/docs/CompileCudaWithLLVM.html#detecting-clang-vs-nvcc-from-code
+ #ifdef ABSL_INTERNAL_HAVE_ARM_NEON
+ #error ABSL_INTERNAL_HAVE_ARM_NEON cannot be directly set
+-#elif defined(__ARM_NEON) && !defined(__CUDA_ARCH__)
++#elif defined(__ARM_NEON) && !(defined(__NVCC__) && defined(__CUDACC__))
+ #define ABSL_INTERNAL_HAVE_ARM_NEON 1
+ #endif
+ 
diff --git a/third_party/absl/workspace.bzl b/third_party/absl/workspace.bzl
index 06f7516..9565a82 100644
--- a/third_party/absl/workspace.bzl
+++ b/third_party/absl/workspace.bzl
@@ -44,4 +44,5 @@
         system_link_files = SYS_LINKS,
         strip_prefix = "abseil-cpp-{commit}".format(commit = ABSL_COMMIT),
         urls = tf_mirror_urls("https://github.com/abseil/abseil-cpp/archive/{commit}.tar.gz".format(commit = ABSL_COMMIT)),
+        patch_file = ["//third_party/absl:nvidia_jetson.patch"],
     )
diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl
index 0b85e59..44cdbe3 100644
--- a/third_party/gpus/cuda/BUILD.tpl
+++ b/third_party/gpus/cuda/BUILD.tpl
@@ -249,3 +249,9 @@
     # to make bazel query happy.
     name = "nvptxcompiler",
 )
+
+cc_library(
+    # This is not yet fully supported, but we need the rule
+    # to make bazel query happy.
+    name = "nvjitlink",
+)
\ No newline at end of file
diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl
index c185ca7..ff9b53b 100644
--- a/third_party/gpus/rocm_configure.bzl
+++ b/third_party/gpus/rocm_configure.bzl
@@ -205,6 +205,8 @@
     inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include")
     inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include")
     inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include")
+    if int(rocm_config.rocm_version_number) >= 60200:
+        inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/18/include")
 
     # Support hcc based off clang 10.0.0 (for ROCm 3.3)
     inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
index 6429d9b..54a3c65 100644
--- a/third_party/llvm/workspace.bzl
+++ b/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@
 
 def repo(name):
     """Imports LLVM."""
-    LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d"
-    LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd"
+    LLVM_COMMIT = "4c5ef6690040383956461828457ac27f7f912edb"
+    LLVM_SHA256 = "a30da7822f5307bc0aca8c497ffdd6369e3877186e87501e2ac1f3ec5ed1c0b7"
 
     tf_http_archive(
         name = name,
diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD
index 4b3ad84..8c73096 100644
--- a/third_party/mkl_dnn/mkldnn_v1.BUILD
+++ b/third_party/mkl_dnn/mkldnn_v1.BUILD
@@ -12,7 +12,7 @@
     "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA",
     "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP",
     "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
-    "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH",
+    "#cmakedefine ONEDNN_BUILD_GRAPH": "#define ONEDNN_BUILD_GRAPH",
     "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#define DNNL_EXPERIMENTAL_SPARSE",
     "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL",
     "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
@@ -109,6 +109,7 @@
     "-UUSE_CBLAS",
     "-DDNNL_ENABLE_MAX_CPU_ISA",
     "-DDNNL_ENABLE_ITT_TASKS",
+    "-DDNNL_ENABLE_GRAPH_DUMP",
 ] + tf_openmp_copts()
 
 _INCLUDES_LIST = [
@@ -119,6 +120,7 @@
     "src/cpu",
     "src/cpu/gemm",
     "src/cpu/x64/xbyak",
+    "src/graph",
 ]
 
 _TEXTUAL_HDRS_LIST = glob([
@@ -129,6 +131,15 @@
     "src/cpu/**/*.hpp",
     "src/cpu/jit_utils/**/*.hpp",
     "src/cpu/x64/xbyak/*.h",
+    "src/graph/interface/*.hpp",
+    "src/graph/backend/*.hpp",
+    "src/graph/backend/dnnl/*.hpp",
+    "src/graph/backend/fake/*.hpp",
+    "src/graph/backend/dnnl/passes/*.hpp",
+    "src/graph/backend/dnnl/patterns/*.hpp",
+    "src/graph/backend/dnnl/kernels/*.hpp",
+    "src/graph/utils/*.hpp",
+    "src/graph/utils/pm/*.hpp",
 ]) + [
     ":dnnl_config_h",
     ":dnnl_version_h",
@@ -160,6 +171,16 @@
             "src/cpu/**/*.cpp",
             "src/common/ittnotify/*.c",
             "src/cpu/jit_utils/**/*.cpp",
+            "src/cpu/x64/**/*.cpp",
+            "src/graph/interface/*.cpp",
+            "src/graph/backend/*.cpp",
+            "src/graph/backend/dnnl/*.cpp",
+            "src/graph/backend/fake/*.cpp",
+            "src/graph/backend/dnnl/passes/*.cpp",
+            "src/graph/backend/dnnl/patterns/*.cpp",
+            "src/graph/backend/dnnl/kernels/*.cpp",
+            "src/graph/utils/*.cpp",
+            "src/graph/utils/pm/*.cpp",
         ],
         exclude = [
             "src/cpu/aarch64/**",
diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch
index 4d99610..1711c1b 100644
--- a/third_party/shardy/temporary.patch
+++ b/third_party/shardy/temporary.patch
@@ -1,15 +1,15 @@
 diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
-index 9345d8d..6429d9b 100644
+index 4b3f0db..54a3c65 100644
 --- a/third_party/llvm/workspace.bzl
 +++ b/third_party/llvm/workspace.bzl
 @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")
  
  def repo(name):
      """Imports LLVM."""
--    LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575"
--    LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824"
-+    LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d"
-+    LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd"
+-    LLVM_COMMIT = "16dadecc05fa4986d4522c2c3a09a7628feb0fd4"
+-    LLVM_SHA256 = "e7c5195e30f75c6027f90b8196ded71a41a38c586931cfb33d63295d9eed95fd"
++    LLVM_COMMIT = "4c5ef6690040383956461828457ac27f7f912edb"
++    LLVM_SHA256 = "a30da7822f5307bc0aca8c497ffdd6369e3877186e87501e2ac1f3ec5ed1c0b7"
  
      tf_http_archive(
          name = name,
diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl
index 200ac3f..a40d827 100644
--- a/third_party/shardy/workspace.bzl
+++ b/third_party/shardy/workspace.bzl
@@ -3,8 +3,8 @@
 load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 
 def repo():
-    SHARDY_COMMIT = "df54e37427b0007e6527b62616ed1f66a68dda4a"
-    SHARDY_SHA256 = "2ebf03fd73c4578e721c539ad05b33d5fbfae6838abbb58b944e12f1eafbd9b2"
+    SHARDY_COMMIT = "76731821434117cb6d736bdd1b32b7ee4ffbcb4b"
+    SHARDY_SHA256 = "944bdbdc9e97ca95b15ac81bfee151664c06b4e6373661d25091e064a604fd2f"
 
     tf_http_archive(
         name = "shardy",
diff --git a/third_party/spirv_llvm_translator/BUILD b/third_party/spirv_llvm_translator/BUILD
new file mode 100644
index 0000000..8d626dc
--- /dev/null
+++ b/third_party/spirv_llvm_translator/BUILD
@@ -0,0 +1,7 @@
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],
+)
+
+# spirv_llvm_translator license placeholder
diff --git a/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD b/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD
new file mode 100644
index 0000000..557e2e8
--- /dev/null
+++ b/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD
@@ -0,0 +1,34 @@
+cc_library(
+    name = "spirv_llvm_translator",
+    srcs = glob([
+        "lib/SPIRV/libSPIRV/*.cpp",
+        "lib/SPIRV/libSPIRV/*.hpp",
+        "lib/SPIRV/libSPIRV/*.h",
+        "lib/SPIRV/Mangler/*.cpp",
+        "lib/SPIRV/Mangler/*.h",
+        "lib/SPIRV/*.cpp",
+        "lib/SPIRV/*.hpp",
+        "lib/SPIRV/*.h",
+    ]),
+    hdrs = glob(["include/*"]),
+    includes = [
+        "include/",
+        "lib/SPIRV/",
+        "lib/SPIRV/Mangler/",
+        "lib/SPIRV/libSPIRV/",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        "@llvm-project//llvm:Analysis",
+        "@llvm-project//llvm:BitWriter",
+        "@llvm-project//llvm:CodeGen",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Demangle",
+        "@llvm-project//llvm:IRReader",
+        "@llvm-project//llvm:Linker",
+        "@llvm-project//llvm:Passes",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:TransformUtils",
+        "@spirv_headers//:spirv_cpp_headers",
+    ],
+)
diff --git a/third_party/spirv_llvm_translator/spirv_llvm_translator.patch b/third_party/spirv_llvm_translator/spirv_llvm_translator.patch
new file mode 100644
index 0000000..fc843b1
--- /dev/null
+++ b/third_party/spirv_llvm_translator/spirv_llvm_translator.patch
@@ -0,0 +1,25 @@
+diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h
+index a828add8..924e13b4 100644
+
+The SPIR backend uses a different addrspace representation than the NVPTX
+backend. We reorder the enum values here so that XLA LLVM codegen stays simple
+(avoiding changing the addrspace based on the device backend everywhere).
+
+--- a/lib/SPIRV/SPIRVInternal.h
++++ b/lib/SPIRV/SPIRVInternal.h
+@@ -179,11 +179,12 @@ typedef SPIRVMap<Op, Op, IntBoolOpMapId> IntBoolOpMap;
+   "-v512:512:512-v1024:1024:1024"
+
+ enum SPIRAddressSpace {
+-  SPIRAS_Private,
++  SPIRAS_Generic,
+   SPIRAS_Global,
+-  SPIRAS_Constant,
++  SPIRAS_Internal,
+   SPIRAS_Local,
+-  SPIRAS_Generic,
++  SPIRAS_Constant,
++  SPIRAS_Private,
+   SPIRAS_GlobalDevice,
+   SPIRAS_GlobalHost,
+   SPIRAS_Input,
\ No newline at end of file
diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
index 6a1b6b0..8b13789 100755
--- a/third_party/stablehlo/temporary.patch
+++ b/third_party/stablehlo/temporary.patch
@@ -1,19 +1 @@
-diff --ruN a/stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir b/stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir
---- stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir
-+++ stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir
-@@ -1,4 +1,5 @@
- // RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret
-+// XFAIL: *
- // This file is generated, see build_tools/math/README.md for more information.
- module @ulp_difference_float32 {
-   func.func public @main() {
-diff --ruN a/stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir b/stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir
---- stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir
-+++ stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir
-@@ -1,4 +1,5 @@
- // RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret
-+// XFAIL: *
- // This file is generated, see build_tools/math/README.md for more information.
- module @ulp_difference_float64 {
-   func.func public @main() {
 
diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl
index f9c14a6..f1b02d2 100644
--- a/third_party/stablehlo/workspace.bzl
+++ b/third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@
 
 def repo():
     # LINT.IfChange
-    STABLEHLO_COMMIT = "8555db77763fadbd6be83df0a5532828bc419cba"
-    STABLEHLO_SHA256 = "666a88d94e0f1b36e9e5b25411521b878320c61983214859b4e419f36acbf332"
+    STABLEHLO_COMMIT = "24d1807a9a3e0df81103a0be9be7ad28ee34c85a"
+    STABLEHLO_SHA256 = "fc8b165379e6b34a7ab64ebe9334605efec9c01b9106cc78c40ed1163f35a529"
     # LINT.ThenChange(Google-internal path)
 
     tf_http_archive(
diff --git a/third_party/triton/BUILD b/third_party/triton/BUILD
index 3c41380..3b74065 100644
--- a/third_party/triton/BUILD
+++ b/third_party/triton/BUILD
@@ -1 +1,13 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
+# copybara:uncomment_begin
+# package(default_applicable_licenses = ["//tensorflow:license"])
+#
+# filegroup(
+#     name = "patch_files",
+#     srcs = glob([
+#         "xla_extensions/**",
+#         "llvm_integration/**",
+#         "temporary/**",
+#     ]),
+#     visibility = ["//third_party/triton:__subpackages__"],
+# )
+# copybara:uncomment_end
diff --git a/third_party/triton/llvm_integration/BUILD b/third_party/triton/llvm_integration/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/triton/llvm_integration/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/triton/llvm_integration/cl657620552.patch b/third_party/triton/llvm_integration/cl657620552.patch
deleted file mode 100644
index 4a1f47d..0000000
--- a/third_party/triton/llvm_integration/cl657620552.patch
+++ /dev/null
@@ -1,18 +0,0 @@
-# Do not upstream this patch. This has been already upstreamed in
-# https://github.com/triton-lang/triton/commit/de46a0ede6efe7e93c2a9ebef639e36c6177c511
-# Next integration will include it and this patch should be removed then.
-
-diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
---- a/third_party/amd/python/triton_amd.cc
-+++ b/third_party/amd/python/triton_amd.cc
-@@ -193,9 +193,7 @@ void init_triton_amd(py::module &&m) {
-             target->createMCAsmBackend(*sti, *mri, mcOptions));
-         mcStreamer.reset(target->createMCObjectStreamer(
-             triple, ctx, std::move(mab), mab->createObjectWriter(svos),
--            std::move(ce), *sti, mcOptions.MCRelaxAll,
--            mcOptions.MCIncrementalLinkerCompatible,
--            /*DWARFMustBeAtTheEnd=*/false));
-+            std::move(ce), *sti));
- 
-         std::unique_ptr<llvm::MCAsmParser> parser(
-             createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl
index 5348e66..656b9c8 100644
--- a/third_party/triton/llvm_integration/series.bzl
+++ b/third_party/triton/llvm_integration/series.bzl
@@ -8,6 +8,5 @@
 """
 
 llvm_patch_list = [
-    "//third_party/triton/llvm_integration:cl657620552.patch",
     # Add new patches just above this line
 ]
diff --git a/third_party/triton/temporary/BUILD b/third_party/triton/temporary/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/triton/temporary/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/triton/temporary/cuda11-temporary.patch b/third_party/triton/temporary/cuda11-temporary.patch
deleted file mode 100644
index a92166e..0000000
--- a/third_party/triton/temporary/cuda11-temporary.patch
+++ /dev/null
@@ -1,35 +0,0 @@
-# This temporary patch has already been included to the public list of Triton
-# patches. It is only here temporarily to be included in the openxla version,
-# but it will be removed during the next triton integration.
-
---- a/third_party/nvidia/backend/driver.c
-+++ b/third_party/nvidia/backend/driver.c
-@@ -154,6 +154,8 @@ static PyObject *loadBinary(PyObject *se
- typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
-     int *numClusters, CUfunction func, const CUlaunchConfig *config);
- 
-+#if CUDA_VERSION < 12000
-+#else
- typedef CUresult (*cuTensorMapEncodeTiled_t)(
-     CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
-     cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
-@@ -161,6 +161,7 @@ typedef CUresult (*cuTensorMapEncodeTile
-     const cuuint32_t *elementStrides, CUtensorMapInterleave interleave,
-     CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
-     CUtensorMapFloatOOBfill oobFill);
-+#endif
- 
- #define defineGetFunctionHandle(name, symbolName)                              \
-   static symbolName##_t name() {                                               \
-@@ -187,8 +187,11 @@ typedef CUresult (*cuTensorMapEncodeTile
- defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
-                         cuOccupancyMaxActiveClusters);
- 
-+#if CUDA_VERSION < 12000
-+#else
- defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
-                         cuTensorMapEncodeTiled);
-+#endif
- 
- static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
-   int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl
index f55c41e..4fa5526 100644
--- a/third_party/triton/temporary/series.bzl
+++ b/third_party/triton/temporary/series.bzl
@@ -14,7 +14,5 @@
 """
 
 temporary_patch_list = [
-    "//third_party/triton/temporary:cuda11-temporary.patch",
-    "//third_party/triton/temporary:undo_tesla_gpu.patch",
     # Add new patches just above this line
 ]
diff --git a/third_party/triton/temporary/undo_tesla_gpu.patch b/third_party/triton/temporary/undo_tesla_gpu.patch
deleted file mode 100644
index 6c2d1d1..0000000
--- a/third_party/triton/temporary/undo_tesla_gpu.patch
+++ /dev/null
@@ -1,13 +0,0 @@
-This can be removed on the next integrate as it already exists in upstream.
-diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
---- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
-+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
-@@ -21,7 +21,7 @@ namespace {
- static int getMMAVersionSafe(int computeCapability, DotOp op) {
-   // List supported mma version in order of preference.
-   SmallVector<int> versionsSupported;
--  if (computeCapability < 80) {
-+  if (computeCapability < 75) {
-     versionsSupported = {1};
-   } else if (computeCapability < 90) {
-     versionsSupported = {2};
diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl
index f321b4e..012f5b4 100644
--- a/third_party/triton/workspace.bzl
+++ b/third_party/triton/workspace.bzl
@@ -1,15 +1,15 @@
 """Provides the repository macro to import Triton."""
 
 load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-load("//third_party/triton/llvm_integration:series.bzl", "llvm_patch_list")
-load("//third_party/triton/temporary:series.bzl", "temporary_patch_list")
-load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_list")
+load("//third_party/triton:llvm_integration/series.bzl", "llvm_patch_list")
+load("//third_party/triton:temporary/series.bzl", "temporary_patch_list")
+load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_list")
 
 def repo():
     """Imports Triton."""
 
-    TRITON_COMMIT = "cl657175856"
-    TRITON_SHA256 = "316f421a7d7ead2b7e5adc2e8bb68ce1a8f7809db73dbed8abd54c35bd0c1576"
+    TRITON_COMMIT = "cl659604537"
+    TRITON_SHA256 = "4f4699ca0df9d48649efcd838d331aa6bee4ec7fce905cdeed71715f6ca7e033"
     tf_http_archive(
         name = "triton",
         sha256 = TRITON_SHA256,
diff --git a/third_party/triton/xla_extensions/BUILD b/third_party/triton/xla_extensions/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/triton/xla_extensions/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/triton/xla_extensions/series.bzl b/third_party/triton/xla_extensions/series.bzl
index 19ba85b..be33c18 100644
--- a/third_party/triton/xla_extensions/series.bzl
+++ b/third_party/triton/xla_extensions/series.bzl
@@ -7,7 +7,7 @@
 """
 
 extensions_files_patch_list = [
-    "//third_party/triton/xla_extensions:sparse_dot.patch",  # Sparsity internal patch
-    "//third_party/triton/xla_extensions:sparsity_layout.patch",  # Sparsity internal patch
+    "//third_party/triton:xla_extensions/sparse_dot.patch",  # Sparsity internal patch
+    "//third_party/triton:xla_extensions/sparsity_layout.patch",  # Sparsity internal patch
     # Add new patches just above this line
 ]
diff --git a/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/triton/xla_extensions/sparse_dot.patch
index a1c011d..d9fa4cc 100644
--- a/third_party/triton/xla_extensions/sparse_dot.patch
+++ b/third_party/triton/xla_extensions/sparse_dot.patch
@@ -273,9 +273,9 @@
 +  return op->hasTrait<OpTrait::DotLike>() || isa<ttg::SparseDotOp>(op);
 +}
 +
- // Replace the ForOp's yield with a new one with the given operands appended.
- static void appendToYield(scf::ForOp forOp, ArrayRef<Value> newOperands) {
-   // Fix up the yield op.
+ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
+                             Value insertIdx, Value extractIdx,
+                             tt::CoarseSchedule &schedule,
 @@ -248,19 +252,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
      } else {
        if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
diff --git a/third_party/triton/xla_extensions/sparsity_layout.patch b/third_party/triton/xla_extensions/sparsity_layout.patch
index b64ddbd..4daf4f2 100644
--- a/third_party/triton/xla_extensions/sparsity_layout.patch
+++ b/third_party/triton/xla_extensions/sparsity_layout.patch
@@ -2,19 +2,20 @@
 index 34fb89954..a0172e107 100644
 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
 +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
-@@ -57,8 +57,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
    addArgumentMaterialization([&](OpBuilder &builder,
                                   RankedTensorType tensorType, ValueRange inputs,
                                   Location loc) -> std::optional<Value> {
--    llvm_unreachable("Argument rematerialization should not happen in Triton "
--                     "-> TritonGPU conversion");
-+    // TODO(b/354860562): reenable or remove.
-+    // llvm_unreachable("Argument rematerialization should not happen in Triton "
-+    //                  "-> TritonGPU conversion");
++    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
++    // remaining arguments that have been converted to a new type.
++    // We use this to rewrite triton_gpu.sparse_dot in a separate pass after
++    // 'convert-triton-to-tritongpu'.
++    return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
++                                                        inputs);
+     llvm_unreachable("Argument rematerialization should not happen in Triton "
+                      "-> TritonGPU conversion");
      return std::nullopt;
-   });
- 
-@@ -67,6 +68,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+@@ -67,6 +73,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
    addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                 ValueRange inputs,
                                 Location loc) -> std::optional<Value> {
@@ -31,7 +32,7 @@
 index df3d3b042..e38c184f6 100644
 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
 +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
-@@ -2867,13 +2867,13 @@ struct CanonicalizeConvertFromConvert
+@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
      // heuristic to accommodate fused attention.
      auto srcType = op.getSrc().getType();
      auto dstType = op.getType();
diff --git a/third_party/uv/uv.BUILD b/third_party/uv/uv.BUILD
index 75a2df3..b04383a 100644
--- a/third_party/uv/uv.BUILD
+++ b/third_party/uv/uv.BUILD
@@ -11,7 +11,48 @@
 
 cc_library(
     name = "uv",
-    srcs = glob(["src/*.c"]),
+    srcs = [
+        "src/fs-poll.c",
+        "src/idna.c",
+        "src/inet.c",
+        "src/random.c",
+        "src/strscpy.c",
+        "src/threadpool.c",
+        "src/timer.c",
+        "src/uv-common.c",
+        "src/uv-data-getter-setters.c",
+        "src/version.c",
+    ] + [
+        "src/unix/async.c",
+        "src/unix/core.c",
+        "src/unix/dl.c",
+        "src/unix/fs.c",
+        "src/unix/getaddrinfo.c",
+        "src/unix/getnameinfo.c",
+        "src/unix/loop.c",
+        "src/unix/loop-watcher.c",
+        "src/unix/pipe.c",
+        "src/unix/poll.c",
+        "src/unix/process.c",
+        "src/unix/random-devurandom.c",
+        "src/unix/signal.c",
+        "src/unix/stream.c",
+        "src/unix/tcp.c",
+        "src/unix/thread.c",
+        "src/unix/tty.c",
+        "src/unix/udp.c",
+    ] + select({
+        "@platforms//os:osx": [
+            "src/unix/bsd-ifaddrs.c",
+            "src/unix/darwin.c",
+            "src/unix/darwin-proctitle.c",
+            "src/unix/fsevents.c",
+            "src/unix/kqueue.c",
+            "src/unix/proctitle.c",
+            "src/unix/random-getentropy.c",
+        ],
+    }),
+    # TODO: Add Linux, etc. as in https://github.com/libuv/libuv/blob/v1.38.0/CMakeLists.txt.
     hdrs = [
         "include/uv.h",
     ],
diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc
index 11783a8..edb4659 100644
--- a/third_party/xla/.bazelrc
+++ b/third_party/xla/.bazelrc
@@ -573,6 +573,9 @@
 build:rbe_win_clang --linkopt=/FORCE:MULTIPLE
 build:rbe_win_clang --host_linkopt=/FORCE:MULTIPLE
 
+# TODO(belitskiy): Rename `rbe_win_clang` to this, once done switching presubmits.
+build:rbe_windows_x86_cpu --config=rbe_win_clang
+
 # END TF REMOTE BUILD EXECUTION OPTIONS
 
 # TFLite build configs for generic embedded Linux
@@ -787,17 +790,19 @@
 
 # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over
 # the whole TF code base. These are usually run continuously or upon presubmit.
-# CPU PYCPP:
+# LINUX CPU PYCPP:
 test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
 test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
 test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium
 test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
-# CUDA PYCPP:
+
+# LINUX CUDA PYCPP:
 test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11
 test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11
 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium
 test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
-# ARM64 PYCPP
+
+# LINUX ARM64 PYCPP
 # In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on
 # Linux x86 so that we can use RBE. Since tests still need to run on the single
 # host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit.
@@ -830,6 +835,13 @@
 # CROSS-COMPILE MACOS X86 PYCPP
 build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test
 build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test
+# WINDOWS X86-64 CPU PYCPP
+test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test
+test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test
+test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600"
+test:windows_x86_cpu_pycpp_test_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off --build_tests_only
+test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/...
+
 # END TF TEST SUITE OPTIONS
 
 # START CROSS-COMPILE CONFIGS
diff --git a/third_party/xla/.github/workflows/bazel_query.yml b/third_party/xla/.github/workflows/bazel_query.yml
new file mode 100644
index 0000000..253218a
--- /dev/null
+++ b/third_party/xla/.github/workflows/bazel_query.yml
@@ -0,0 +1,38 @@
+# Copyright 2024 The OpenXLA Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+name: Bazel Query
+permissions:
+  contents: read
+on:
+  pull_request:
+
+env:
+  # Have `go install` place binaries in $PATH
+  GOBIN: "/usr/local/bin"
+
+jobs:
+  bazel-query:
+    runs-on: ubuntu-22.04
+    defaults:
+      run:
+        shell: bash
+    timeout-minutes: 2
+    steps:
+      - name: "Checking out repository"
+        uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1
+      - name: "Install bazelisk"
+        run: go install github.com/bazelbuild/bazelisk@24651ab # v1.20.0
+      - name: "Run bazel query //xla/..."
+        run: bazelisk query //xla/...
diff --git a/third_party/xla/.github/workflows/buildifier.yml b/third_party/xla/.github/workflows/buildifier.yml
index 5514067..797b884 100644
--- a/third_party/xla/.github/workflows/buildifier.yml
+++ b/third_party/xla/.github/workflows/buildifier.yml
@@ -13,7 +13,8 @@
 # limitations under the License.
 # ============================================================================
 name: Buildifier
-permissions: read-all
+permissions:
+  contents: read
 on:
   pull_request:
 
diff --git a/third_party/xla/.github/workflows/check_contents.yml b/third_party/xla/.github/workflows/check_contents.yml
index fd38adf..1756b36 100644
--- a/third_party/xla/.github/workflows/check_contents.yml
+++ b/third_party/xla/.github/workflows/check_contents.yml
@@ -19,7 +19,8 @@
 # files once XLA moves out of Tensorflow internally.
 # TODO(ddunleavy): Update this after METADATA files are consolidated.
 name: Check Contents
-permissions: read-all
+permissions:
+  contents: read
 on:
   pull_request:
 
diff --git a/third_party/xla/.github/workflows/clang_format.yml b/third_party/xla/.github/workflows/clang_format.yml
index 2701311..e22b67e 100644
--- a/third_party/xla/.github/workflows/clang_format.yml
+++ b/third_party/xla/.github/workflows/clang_format.yml
@@ -14,7 +14,8 @@
 # ============================================================================
 
 name: Clang Format
-permissions: read-all
+permissions:
+  contents: read
 on:
   pull_request:
 
diff --git a/third_party/xla/docs/custom_call.md b/third_party/xla/docs/custom_call.md
index eb97ad7..2471df6 100644
--- a/third_party/xla/docs/custom_call.md
+++ b/third_party/xla/docs/custom_call.md
@@ -1,4 +1,4 @@
-# XLA custom calls
+# XLA Custom Calls
 
 This document describes how to write and use XLA custom calls using XLA FFI
 library. Custom call is a mechanism to describe an external "operation" in the
@@ -23,6 +23,269 @@
 > and custom call target references or to use C-style namespacing directly in
 > the function name.
 
+## JAX + XLA Custom Calls
+
+See [JAX documentation](https://jax.readthedocs.io/en/latest/ffi.html) for
+end-to-end examples of integrating custom calls and XLA FFI with JAX.
+
+## XLA FFI Binding
+
+An XLA FFI binding is a compile-time specification of the custom call
+signature: custom call arguments, attributes and their types, and additional
+parameters passed via the execution context (e.g., a GPU stream for the GPU
+backend). An XLA FFI binding can be bound to any C++ callable (function
+pointer, lambda, etc.) with a compatible `operator()` signature. The
+constructed handler decodes the XLA FFI call frame (defined by the stable
+C API), type checks all parameters, and forwards the decoded results to the
+user-defined callback.
+
+XLA FFI binding relies heavily on template metaprogramming so that the
+constructed handler can be compiled to the most efficient machine code. Run
+time overheads are on the order of a couple of nanoseconds per custom call
+parameter.
+
+XLA FFI customization points are implemented as template specializations, and
+users can define how to decode their custom types, e.g., it is possible to
+define custom decoding for user-defined `enum class` types.
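+
+As a quick illustration, here is a minimal sketch of a binding, written in the
+same abbreviated style as the snippets below (registering the constructed
+handler with the XLA runtime is out of scope for this sketch). It binds a
+lambda that copies a rank-0 `F32` buffer argument into its result:
+
+```c++
+// Minimal sketch: assumes the abbreviated names used throughout this document
+// (`Ffi`, `BufferR0`, `F32`, `Result`, `Error`) are in scope.
+auto copy_scalar =
+    Ffi::Bind()
+        .Arg<BufferR0<F32>>()  // one rank-0 f32 argument
+        .Ret<BufferR0<F32>>()  // one rank-0 f32 result
+        .To([](BufferR0<F32> arg, Result<BufferR0<F32>> res) -> Error {
+          // Write into the result buffer provided by the XLA runtime.
+          *res->typed_data() = *arg.typed_data();
+          return Error::Success();
+        });
+```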
+
+### Returning Errors From Custom Calls
+
+Custom call implementations must return an `xla::ffi::Error` value to signal
+success or error to the XLA runtime. It is similar to `absl::Status` and has
+the same set of error codes. We do not use `absl::Status` because it does not
+have a stable ABI, and it would be unsafe to pass it between a dynamically
+loaded custom call library and XLA itself.
+
+```c++
+// Handler that always returns an error.
+auto always_error = Ffi::Bind().To(
+    []() { return Error(ErrorCode::kInternal, "Oops!"); });
+
+// Handler that always returns a success.
+auto always_success = Ffi::Bind().To(
+    []() { return Error::Success(); });
+```
+
+### Buffer Arguments And Results
+
+XLA uses destination passing style for results: custom calls (or any other XLA
+operations, for that matter) do not allocate memory for results, and instead
+write into destinations passed by the XLA runtime. XLA uses static buffer
+assignment and allocates buffers for all values based on their live ranges at
+compile time.
+
+Results passed to FFI handlers are wrapped into a `Result<T>` template, which
+has pointer-like semantics: `operator->` gives access to the underlying
+parameter.
+
+`AnyBuffer` arguments and results give access to custom call buffer parameters
+of any data type. This is useful when a custom call has a generic
+implementation that works for multiple data types and dispatches on the data
+type at run time. `AnyBuffer` gives access to the buffer data type, dimensions,
+and a pointer to the buffer itself.
+
+```mlir
+%0 = "stablehlo.custom_call"(%arg0) {
+  call_target_name = "foo",
+  api_version = 4 : i32
+} : (tensor<2x2xf32>) -> tensor<2x2xf32>
+```
+
+```c++
+// Buffers of any rank and data type.
+auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
+    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
+      void* arg_data = arg.untyped_data();
+      void* res_data = res->untyped_data();
+      return Error::Success();
+    });
+```
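+
+If the handler needs to branch on the element type at run time, a sketch along
+the following lines is possible. The data type accessor name
+(`element_type()`) and the `DataType` enum spelling below are assumptions for
+illustration and are not taken from this document:
+
+```c++
+// Sketch only: run time dispatch on the element type of an `AnyBuffer`.
+// `element_type()` and `DataType` are assumed names; see the note above.
+auto dispatching_handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
+    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
+      switch (arg.element_type()) {
+        case DataType::F32:
+          // ... handle f32 buffers ...
+          return Error::Success();
+        case DataType::F64:
+          // ... handle f64 buffers ...
+          return Error::Success();
+        default:
+          return Error(ErrorCode::kInvalidArgument, "Unsupported data type");
+      }
+    });
+```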
+
+### Constrained Buffer Arguments And Results
+
+`Buffer` allows adding constraints on the buffer data type and rank. These
+constraints are automatically checked by the handler, which returns an error to
+the XLA runtime if the run time arguments do not match the FFI handler
+signature.
+
+```c++
+// Buffers of any rank and F32 data type.
+auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
+    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
+      float* arg_data = arg.typed_data();
+      float* res_data = res->typed_data();
+      return Error::Success();
+    });
+```
+
+```c++
+// Buffers of rank 2 and F32 data type.
+auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
+    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
+      float* arg_data = arg.typed_data();
+      float* res_data = res->typed_data();
+      return Error::Success();
+    });
+```
+
+### Variadic Arguments And Results
+
+If the number of arguments and results can differ between instances of a
+custom call, they can be decoded at run time using `RemainingArgs` and
+`RemainingRets`.
+
+```c++
+auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
+    [](RemainingArgs args, RemainingRets results) -> Error {
+      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
+      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);
+
+      if (!arg.has_value()) {
+        return Error(ErrorCode::kInternal, arg.error());
+      }
+
+      if (!res.has_value()) {
+        return Error(ErrorCode::kInternal, res.error());
+      }
+
+      return Error::Success();
+    });
+```
+
+Variadic arguments and results can be declared after regular arguments and
+results; however, binding regular arguments and results after variadic ones is
+illegal.
+
+```c++
+auto handler =
+    Ffi::Bind()
+        .Arg<AnyBuffer>()
+        .RemainingArgs()
+        .Ret<AnyBuffer>()
+        .RemainingRets()
+        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
+               RemainingRets results) -> Error { return Error::Success(); });
+```
+
+### Attributes
+
+XLA FFI supports automatic decoding of `mlir::DictionaryAttr` passed as a
+`custom_call` `backend_config` into FFI handler arguments.
+
+Note: See [stablehlo RFC](https://github.com/openxla/stablehlo/blob/main/rfcs/20240312-standardize-customcallop.md)
+for details, and `stablehlo.custom_call` operation specification.
+
+```mlir
+%0 = "stablehlo.custom_call"(%arg0) {
+  call_target_name = "foo",
+  backend_config= {
+    i32 = 42 : i32,
+    str = "string"
+  },
+  api_version = 4 : i32
+} : (tensor<f32>) -> tensor<f32>
+```
+
+In this example the custom call has a single buffer argument and two
+attributes, and XLA FFI can automatically decode them and pass them to the
+user-defined callable.
+
+```c++
+auto handler = Ffi::Bind()
+  .Arg<BufferR0<F32>>()
+  .Attr<int32_t>("i32")
+  .Attr<std::string_view>("str")
+  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
+    return Error::Success();
+  });
+```
+
+### User-Defined Enum Attributes
+
+XLA FFI can automatically decode integral MLIR attributes into user-defined
+enums. The enum class must have the same underlying integral type as the MLIR
+attribute, and decoding has to be explicitly registered with XLA FFI.
+
+```mlir
+%0 = "stablehlo.custom_call"(%arg0) {
+  call_target_name = "foo",
+  backend_config= {
+    command = 0 : i32
+  },
+  api_version = 4 : i32
+} : (tensor<f32>) -> tensor<f32>
+```
+
+```c++
+enum class Command : int32_t {
+  kAdd = 0,
+  kMul = 1,
+};
+
+XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);
+
+auto handler = Ffi::Bind().Attr<Command>("command").To(
+    [](Command command) -> Error { return Error::Success(); });
+```
+
+### Binding All Custom Call Attributes
+
+It is possible to get access to all custom call attributes as a dictionary
+and lazily decode only the attributes that are needed at run time.
+
+```c++
+auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
+  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
+  return Error::Success();
+});
+```
+
+### User-Defined Struct Attributes
+
+XLA FFI can decode dictionary attributes into user-defined structs.
+
+```mlir
+%0 = "stablehlo.custom_call"(%arg0) {
+  call_target_name = "foo",
+  backend_config= {
+    range = { lo = 0 : i64, hi = 42 : i64 }
+  },
+  api_version = 4 : i32
+} : (tensor<f32>) -> tensor<f32>
+```
+
+In the example above `range` is an `mlir::DictionaryAttr` attribute, and
+instead of accessing dictionary fields by name, it can be automatically decoded
+into a C++ struct. Decoding has to be explicitly registered with the
+`XLA_FFI_REGISTER_STRUCT_ATTR_DECODING` macro (behind the scenes it defines a
+template specialization in the `::xla::ffi` namespace, thus the macro must be
+added in the global namespace).
+
+```c++
+struct Range {
+  int64_t lo;
+  int64_t hi;
+};
+
+XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
+                                             StructMember<int64_t>("hi"));
+
+auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
+  return Error::Success();
+});
+```
+
+Custom attributes can be loaded from a dictionary, just like any other
+attribute. In the example below, all custom call attributes are decoded as a
+`Dictionary`, and `range` can be accessed by name.
+
+```c++
+auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
+  ErrorOr<Range> range = attrs.get<Range>("range");
+  return Error::Success();
+});
+```
+
 ## Create a custom call on CPU
 
 You can create an HLO instruction that represents a custom call via XLA's client
diff --git a/third_party/xla/docs/determinism.md b/third_party/xla/docs/determinism.md
index d8cd934..09a1e4f 100644
--- a/third_party/xla/docs/determinism.md
+++ b/third_party/xla/docs/determinism.md
@@ -8,6 +8,10 @@
 measurements different kernels can be picked as the fastest ones in different
 compilation runs.
 
+`--xla_gpu_require_complete_aot_autotune_results` can be used to ensure that no
+autotuning happens on repeated compilations: they either reuse compatible
+results from previous runs or fail.
+
 ## Execution
 
 Programs compiled by XLA can be non-deterministic on operations like scatter,
diff --git a/third_party/xla/docs/operation_semantics.md b/third_party/xla/docs/operation_semantics.md
index 55ed575..5584997 100644
--- a/third_party/xla/docs/operation_semantics.md
+++ b/third_party/xla/docs/operation_semantics.md
@@ -1214,8 +1214,8 @@
 
 Where `Op` is one of `Add` (addition), `Sub`(subtraction), `Mul`
 (multiplication), `Div` (division), `Pow` (power), `Rem` (remainder), `Max`
-(maximum), `Min` (minimum), `LogicalAnd` (logical AND), `LogicalOr` (logical
-OR), `LogicalXor` (logical XOR), `ShiftLeft` (Left Shift),
+(maximum), `Min` (minimum), `And` (logical AND), `Or` (logical
+OR), `Xor` (logical XOR), `ShiftLeft` (Left Shift),
 `ShiftRightArithmetic` (arithmetic Right Shift), `ShiftRightLogical` (logical
 Right Shift), `Atan2` (2-argument arctangent), or `Complex` (combines real and
 imaginary parts into a complex number)
@@ -1305,12 +1305,22 @@
 
 <b>`Abs(operand)`</b> Element-wise abs `x -> |x|`.
 
+<b>`Cbrt(operand)`</b> Element-wise cubic root operation `x -> cbrt(x)`.
+
 <b>`Ceil(operand)`</b> Element-wise ceil `x -> ⌈x⌉`.
 
+<b>`Clz(operand)`</b> Element-wise count leading zeros.
+
 <b>`Cos(operand)`</b> Element-wise cosine `x -> cos(x)`.
 
+<b>`Erf(operand)`</b> Element-wise error function `x -> erf(x)` where
+
+$$\text{erf}(x) = \frac{2}{\sqrt{\pi}}\int_0^x e^{-t^2} \, dt$$.
+
 <b>`Exp(operand)`</b> Element-wise natural exponential `x -> e^x`.
 
+<b>`Expm1(operand)`</b> Element-wise natural exponential minus one `x -> e^x - 1`.
+
 <b>`Floor(operand)`</b> Element-wise floor `x -> ⌊x⌋`.
 
 <b>`Imag(operand)`</b> Element-wise imaginary part of a complex (or real)
@@ -1323,19 +1333,25 @@
 
 <b>`Log(operand)`</b> Element-wise natural logarithm `x -> ln(x)`.
 
-<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
+<b>`Log1p(operand)`</b> Element-wise shifted natural logarithm `x -> ln(1+x)`.
 
 <b>`Logistic(operand)`</b> Element-wise logistic function computation `x ->
 logistic(x)`.
 
+<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
+
+<b>`Not(operand)`</b> Element-wise logical not `x -> !(x)`.
+
 <b>`PopulationCount(operand)`</b> Computes the number of bits set in each
 element of `operand`.
 
-<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
-
 <b>`Real(operand)`</b> Element-wise real part of a complex (or real) shape.
 `x -> real(x)`. If the operand is a floating point type, returns the same value.
 
+<b>`Round(operand)`</b> Element-wise rounding, ties away from zero.
+
+<b>`RoundNearestEven(operand)`</b> Element-wise rounding, ties to nearest even.
+
 <b>`Rsqrt(operand)`</b> Element-wise reciprocal of square root operation
 `x -> 1.0 / sqrt(x)`.
 
@@ -1345,16 +1361,14 @@
 
 using the comparison operator of the element type of `operand`.
 
+<b>`Sin(operand)`</b> Element-wise sine `x -> sin(x)`.
+
 <b>`Sqrt(operand)`</b> Element-wise square root operation `x -> sqrt(x)`.
 
-<b>`Cbrt(operand)`</b> Element-wise cubic root operation `x -> cbrt(x)`.
+<b>`Tan(operand)`</b> Element-wise tangent `x -> tan(x)`.
 
 <b>`Tanh(operand)`</b> Element-wise hyperbolic tangent `x -> tanh(x)`.
 
-<b>`Round(operand)`</b> Element-wise rounding, ties away from zero.
-
-<b>`RoundNearestEven(operand)`</b> Element-wise rounding, ties to nearest even.
-
 Arguments | Type    | Semantics
 --------- | ------- | ---------------------------
 `operand` | `XlaOp` | The operand to the function
diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files
index baafd35..5759a24 100644
--- a/third_party/xla/opensource_only.files
+++ b/third_party/xla/opensource_only.files
@@ -34,6 +34,7 @@
 third_party/py/python_repo.bzl:
 third_party/python_runtime/BUILD:
 third_party/repo.bzl:
+third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD:
 third_party/stablehlo/BUILD:
 tools/toolchains/BUILD:
 tools/toolchains/clang6/BUILD:
diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch
index 4d99610..1711c1b 100644
--- a/third_party/xla/third_party/shardy/temporary.patch
+++ b/third_party/xla/third_party/shardy/temporary.patch
@@ -1,15 +1,15 @@
 diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
-index 9345d8d..6429d9b 100644
+index 4b3f0db..54a3c65 100644
 --- a/third_party/llvm/workspace.bzl
 +++ b/third_party/llvm/workspace.bzl
 @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")
  
  def repo(name):
      """Imports LLVM."""
--    LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575"
--    LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824"
-+    LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d"
-+    LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd"
+-    LLVM_COMMIT = "16dadecc05fa4986d4522c2c3a09a7628feb0fd4"
+-    LLVM_SHA256 = "e7c5195e30f75c6027f90b8196ded71a41a38c586931cfb33d63295d9eed95fd"
++    LLVM_COMMIT = "4c5ef6690040383956461828457ac27f7f912edb"
++    LLVM_SHA256 = "a30da7822f5307bc0aca8c497ffdd6369e3877186e87501e2ac1f3ec5ed1c0b7"
  
      tf_http_archive(
          name = name,
diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl
index 200ac3f..a40d827 100644
--- a/third_party/xla/third_party/shardy/workspace.bzl
+++ b/third_party/xla/third_party/shardy/workspace.bzl
@@ -3,8 +3,8 @@
 load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 
 def repo():
-    SHARDY_COMMIT = "df54e37427b0007e6527b62616ed1f66a68dda4a"
-    SHARDY_SHA256 = "2ebf03fd73c4578e721c539ad05b33d5fbfae6838abbb58b944e12f1eafbd9b2"
+    SHARDY_COMMIT = "76731821434117cb6d736bdd1b32b7ee4ffbcb4b"
+    SHARDY_SHA256 = "944bdbdc9e97ca95b15ac81bfee151664c06b4e6373661d25091e064a604fd2f"
 
     tf_http_archive(
         name = "shardy",
diff --git a/third_party/xla/third_party/spirv_llvm_translator/BUILD b/third_party/xla/third_party/spirv_llvm_translator/BUILD
new file mode 100644
index 0000000..8d626dc
--- /dev/null
+++ b/third_party/xla/third_party/spirv_llvm_translator/BUILD
@@ -0,0 +1,7 @@
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],
+)
+
+# spirv_llvm_translator license placeholder
diff --git a/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD
new file mode 100644
index 0000000..557e2e8
--- /dev/null
+++ b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD
@@ -0,0 +1,34 @@
+cc_library(
+    name = "spirv_llvm_translator",
+    srcs = glob([
+        "lib/SPIRV/libSPIRV/*.cpp",
+        "lib/SPIRV/libSPIRV/*.hpp",
+        "lib/SPIRV/libSPIRV/*.h",
+        "lib/SPIRV/Mangler/*.cpp",
+        "lib/SPIRV/Mangler/*.h",
+        "lib/SPIRV/*.cpp",
+        "lib/SPIRV/*.hpp",
+        "lib/SPIRV/*.h",
+    ]),
+    hdrs = glob(["include/*"]),
+    includes = [
+        "include/",
+        "lib/SPIRV/",
+        "lib/SPIRV/Mangler/",
+        "lib/SPIRV/libSPIRV/",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        "@llvm-project//llvm:Analysis",
+        "@llvm-project//llvm:BitWriter",
+        "@llvm-project//llvm:CodeGen",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Demangle",
+        "@llvm-project//llvm:IRReader",
+        "@llvm-project//llvm:Linker",
+        "@llvm-project//llvm:Passes",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:TransformUtils",
+        "@spirv_headers//:spirv_cpp_headers",
+    ],
+)
diff --git a/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.patch b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.patch
new file mode 100644
index 0000000..fc843b1
--- /dev/null
+++ b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.patch
@@ -0,0 +1,25 @@
+diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h
+index a828add8..924e13b4 100644
+
+The SPIR backend uses a different addrspace representation than the NVPTX
+backend. We reorder the enum values here so that XLA LLVM codegen stays simple
+(avoiding changing the addrspace based on the device backend everywhere).
+
+--- a/lib/SPIRV/SPIRVInternal.h
++++ b/lib/SPIRV/SPIRVInternal.h
+@@ -179,11 +179,12 @@ typedef SPIRVMap<Op, Op, IntBoolOpMapId> IntBoolOpMap;
+   "-v512:512:512-v1024:1024:1024"
+
+ enum SPIRAddressSpace {
+-  SPIRAS_Private,
++  SPIRAS_Generic,
+   SPIRAS_Global,
+-  SPIRAS_Constant,
++  SPIRAS_Internal,
+   SPIRAS_Local,
+-  SPIRAS_Generic,
++  SPIRAS_Constant,
++  SPIRAS_Private,
+   SPIRAS_GlobalDevice,
+   SPIRAS_GlobalHost,
+   SPIRAS_Input,
\ No newline at end of file
diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch
index 6a1b6b0..8b13789 100755
--- a/third_party/xla/third_party/stablehlo/temporary.patch
+++ b/third_party/xla/third_party/stablehlo/temporary.patch
@@ -1,19 +1 @@
-diff --ruN a/stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir b/stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir
---- stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir
-+++ stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir
-@@ -1,4 +1,5 @@
- // RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret
-+// XFAIL: *
- // This file is generated, see build_tools/math/README.md for more information.
- module @ulp_difference_float32 {
-   func.func public @main() {
-diff --ruN a/stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir b/stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir
---- stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir
-+++ stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir
-@@ -1,4 +1,5 @@
- // RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret
-+// XFAIL: *
- // This file is generated, see build_tools/math/README.md for more information.
- module @ulp_difference_float64 {
-   func.func public @main() {
 
diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl
index f9c14a6..f1b02d2 100644
--- a/third_party/xla/third_party/stablehlo/workspace.bzl
+++ b/third_party/xla/third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@
 
 def repo():
     # LINT.IfChange
-    STABLEHLO_COMMIT = "8555db77763fadbd6be83df0a5532828bc419cba"
-    STABLEHLO_SHA256 = "666a88d94e0f1b36e9e5b25411521b878320c61983214859b4e419f36acbf332"
+    STABLEHLO_COMMIT = "24d1807a9a3e0df81103a0be9be7ad28ee34c85a"
+    STABLEHLO_SHA256 = "fc8b165379e6b34a7ab64ebe9334605efec9c01b9106cc78c40ed1163f35a529"
     # LINT.ThenChange(Google-internal path)
 
     tf_http_archive(
diff --git a/third_party/xla/third_party/triton/BUILD b/third_party/xla/third_party/triton/BUILD
index 3c41380..3b74065 100644
--- a/third_party/xla/third_party/triton/BUILD
+++ b/third_party/xla/third_party/triton/BUILD
@@ -1 +1,13 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
+# copybara:uncomment_begin
+# package(default_applicable_licenses = ["//tensorflow:license"])
+#
+# filegroup(
+#     name = "patch_files",
+#     srcs = glob([
+#         "xla_extensions/**",
+#         "llvm_integration/**",
+#         "temporary/**",
+#     ]),
+#     visibility = ["//third_party/triton:__subpackages__"],
+# )
+# copybara:uncomment_end
diff --git a/third_party/xla/third_party/triton/llvm_integration/BUILD b/third_party/xla/third_party/triton/llvm_integration/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/xla/third_party/triton/llvm_integration/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/triton/llvm_integration/cl657620552.patch b/third_party/xla/third_party/triton/llvm_integration/cl657620552.patch
deleted file mode 100644
index 4a1f47d..0000000
--- a/third_party/xla/third_party/triton/llvm_integration/cl657620552.patch
+++ /dev/null
@@ -1,18 +0,0 @@
-# Do not upstream this patch. This has been already upstreamed in
-# https://github.com/triton-lang/triton/commit/de46a0ede6efe7e93c2a9ebef639e36c6177c511
-# Next integration will include it and this patch should be removed then.
-
-diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
---- a/third_party/amd/python/triton_amd.cc
-+++ b/third_party/amd/python/triton_amd.cc
-@@ -193,9 +193,7 @@ void init_triton_amd(py::module &&m) {
-             target->createMCAsmBackend(*sti, *mri, mcOptions));
-         mcStreamer.reset(target->createMCObjectStreamer(
-             triple, ctx, std::move(mab), mab->createObjectWriter(svos),
--            std::move(ce), *sti, mcOptions.MCRelaxAll,
--            mcOptions.MCIncrementalLinkerCompatible,
--            /*DWARFMustBeAtTheEnd=*/false));
-+            std::move(ce), *sti));
- 
-         std::unique_ptr<llvm::MCAsmParser> parser(
-             createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
diff --git a/third_party/xla/third_party/triton/llvm_integration/series.bzl b/third_party/xla/third_party/triton/llvm_integration/series.bzl
index 5348e66..656b9c8 100644
--- a/third_party/xla/third_party/triton/llvm_integration/series.bzl
+++ b/third_party/xla/third_party/triton/llvm_integration/series.bzl
@@ -8,6 +8,5 @@
 """
 
 llvm_patch_list = [
-    "//third_party/triton/llvm_integration:cl657620552.patch",
     # Add new patches just above this line
 ]
diff --git a/third_party/xla/third_party/triton/temporary/BUILD b/third_party/xla/third_party/triton/temporary/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/xla/third_party/triton/temporary/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/triton/temporary/cuda11-temporary.patch b/third_party/xla/third_party/triton/temporary/cuda11-temporary.patch
deleted file mode 100644
index a92166e..0000000
--- a/third_party/xla/third_party/triton/temporary/cuda11-temporary.patch
+++ /dev/null
@@ -1,35 +0,0 @@
-# This temporary patch has already been included to the public list of Triton
-# patches. It is only here temporarily to be included in the openxla version,
-# but it will be removed during the next triton integration.
-
---- a/third_party/nvidia/backend/driver.c
-+++ b/third_party/nvidia/backend/driver.c
-@@ -154,6 +154,8 @@ static PyObject *loadBinary(PyObject *se
- typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
-     int *numClusters, CUfunction func, const CUlaunchConfig *config);
- 
-+#if CUDA_VERSION < 12000
-+#else
- typedef CUresult (*cuTensorMapEncodeTiled_t)(
-     CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
-     cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
-@@ -161,6 +161,7 @@ typedef CUresult (*cuTensorMapEncodeTile
-     const cuuint32_t *elementStrides, CUtensorMapInterleave interleave,
-     CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
-     CUtensorMapFloatOOBfill oobFill);
-+#endif
- 
- #define defineGetFunctionHandle(name, symbolName)                              \
-   static symbolName##_t name() {                                               \
-@@ -187,8 +187,11 @@ typedef CUresult (*cuTensorMapEncodeTile
- defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
-                         cuOccupancyMaxActiveClusters);
- 
-+#if CUDA_VERSION < 12000
-+#else
- defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
-                         cuTensorMapEncodeTiled);
-+#endif
- 
- static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
-   int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
diff --git a/third_party/xla/third_party/triton/temporary/series.bzl b/third_party/xla/third_party/triton/temporary/series.bzl
index f55c41e..4fa5526 100644
--- a/third_party/xla/third_party/triton/temporary/series.bzl
+++ b/third_party/xla/third_party/triton/temporary/series.bzl
@@ -14,7 +14,5 @@
 """
 
 temporary_patch_list = [
-    "//third_party/triton/temporary:cuda11-temporary.patch",
-    "//third_party/triton/temporary:undo_tesla_gpu.patch",
     # Add new patches just above this line
 ]
diff --git a/third_party/xla/third_party/triton/temporary/undo_tesla_gpu.patch b/third_party/xla/third_party/triton/temporary/undo_tesla_gpu.patch
deleted file mode 100644
index 6c2d1d1..0000000
--- a/third_party/xla/third_party/triton/temporary/undo_tesla_gpu.patch
+++ /dev/null
@@ -1,13 +0,0 @@
-This can be removed on the next integrate as it already exists in upstream.
-diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
---- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
-+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
-@@ -21,7 +21,7 @@ namespace {
- static int getMMAVersionSafe(int computeCapability, DotOp op) {
-   // List supported mma version in order of preference.
-   SmallVector<int> versionsSupported;
--  if (computeCapability < 80) {
-+  if (computeCapability < 75) {
-     versionsSupported = {1};
-   } else if (computeCapability < 90) {
-     versionsSupported = {2};
diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl
index f321b4e..012f5b4 100644
--- a/third_party/xla/third_party/triton/workspace.bzl
+++ b/third_party/xla/third_party/triton/workspace.bzl
@@ -1,15 +1,15 @@
 """Provides the repository macro to import Triton."""
 
 load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-load("//third_party/triton/llvm_integration:series.bzl", "llvm_patch_list")
-load("//third_party/triton/temporary:series.bzl", "temporary_patch_list")
-load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_list")
+load("//third_party/triton:llvm_integration/series.bzl", "llvm_patch_list")
+load("//third_party/triton:temporary/series.bzl", "temporary_patch_list")
+load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_list")
 
 def repo():
     """Imports Triton."""
 
-    TRITON_COMMIT = "cl657175856"
-    TRITON_SHA256 = "316f421a7d7ead2b7e5adc2e8bb68ce1a8f7809db73dbed8abd54c35bd0c1576"
+    TRITON_COMMIT = "cl659604537"
+    TRITON_SHA256 = "4f4699ca0df9d48649efcd838d331aa6bee4ec7fce905cdeed71715f6ca7e033"
     tf_http_archive(
         name = "triton",
         sha256 = TRITON_SHA256,
diff --git a/third_party/xla/third_party/triton/xla_extensions/BUILD b/third_party/xla/third_party/triton/xla_extensions/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/xla/third_party/triton/xla_extensions/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/triton/xla_extensions/series.bzl b/third_party/xla/third_party/triton/xla_extensions/series.bzl
index 19ba85b..be33c18 100644
--- a/third_party/xla/third_party/triton/xla_extensions/series.bzl
+++ b/third_party/xla/third_party/triton/xla_extensions/series.bzl
@@ -7,7 +7,7 @@
 """
 
 extensions_files_patch_list = [
-    "//third_party/triton/xla_extensions:sparse_dot.patch",  # Sparsity internal patch
-    "//third_party/triton/xla_extensions:sparsity_layout.patch",  # Sparsity internal patch
+    "//third_party/triton:xla_extensions/sparse_dot.patch",  # Sparsity internal patch
+    "//third_party/triton:xla_extensions/sparsity_layout.patch",  # Sparsity internal patch
     # Add new patches just above this line
 ]
diff --git a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch
index a1c011d..d9fa4cc 100644
--- a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch
+++ b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch
@@ -273,9 +273,9 @@
 +  return op->hasTrait<OpTrait::DotLike>() || isa<ttg::SparseDotOp>(op);
 +}
 +
- // Replace the ForOp's yield with a new one with the given operands appended.
- static void appendToYield(scf::ForOp forOp, ArrayRef<Value> newOperands) {
-   // Fix up the yield op.
+ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
+                             Value insertIdx, Value extractIdx,
+                             tt::CoarseSchedule &schedule,
 @@ -248,19 +252,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
      } else {
        if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
diff --git a/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch b/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch
index b64ddbd..4daf4f2 100644
--- a/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch
+++ b/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch
@@ -2,19 +2,20 @@
 index 34fb89954..a0172e107 100644
 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
 +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
-@@ -57,8 +57,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
    addArgumentMaterialization([&](OpBuilder &builder,
                                   RankedTensorType tensorType, ValueRange inputs,
                                   Location loc) -> std::optional<Value> {
--    llvm_unreachable("Argument rematerialization should not happen in Triton "
--                     "-> TritonGPU conversion");
-+    // TODO(b/354860562): reenable or remove.
-+    // llvm_unreachable("Argument rematerialization should not happen in Triton "
-+    //                  "-> TritonGPU conversion");
++    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
++    // remaining arguments that have been converted to a new type.
++    // We use this to rewrite triton_gpu.sparse_dot in a separate pass after
++    // 'convert-triton-to-tritongpu'.
++    return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
++                                                        inputs);
+     llvm_unreachable("Argument rematerialization should not happen in Triton "
+                      "-> TritonGPU conversion");
      return std::nullopt;
-   });
- 
-@@ -67,6 +68,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+@@ -67,6 +73,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
    addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                 ValueRange inputs,
                                 Location loc) -> std::optional<Value> {
@@ -31,7 +32,7 @@
 index df3d3b042..e38c184f6 100644
 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
 +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
-@@ -2867,13 +2867,13 @@ struct CanonicalizeConvertFromConvert
+@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
      // heuristic to accommodate fused attention.
      auto srcType = op.getSrc().getType();
      auto dstType = op.getType();
diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc
index 11783a8..edb4659 100644
--- a/third_party/xla/third_party/tsl/.bazelrc
+++ b/third_party/xla/third_party/tsl/.bazelrc
@@ -573,6 +573,9 @@
 build:rbe_win_clang --linkopt=/FORCE:MULTIPLE
 build:rbe_win_clang --host_linkopt=/FORCE:MULTIPLE
 
+# TODO(belitskiy): Rename `rbe_win_clang` to this, once done switching presubmits.
+build:rbe_windows_x86_cpu --config=rbe_win_clang
+
 # END TF REMOTE BUILD EXECUTION OPTIONS
 
 # TFLite build configs for generic embedded Linux
@@ -787,17 +790,19 @@
 
 # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over
 # the whole TF code base. These are usually run continuously or upon presubmit.
-# CPU PYCPP:
+# LINUX CPU PYCPP:
 test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
 test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only
 test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium
 test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
-# CUDA PYCPP:
+
+# LINUX CUDA PYCPP:
 test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11
 test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11
 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium
 test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
-# ARM64 PYCPP
+
+# LINUX ARM64 PYCPP
 # In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on
 # Linux x86 so that we can use RBE. Since tests still need to run on the single
 # host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit.
@@ -830,6 +835,13 @@
 # CROSS-COMPILE MACOS X86 PYCPP
 build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test
 build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test
+# WINDOWS X86-64 CPU PYCPP
+test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test
+test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test
+test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600"
+test:windows_x86_cpu_pycpp_test_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off --build_tests_only
+test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/...
+
 # END TF TEST SUITE OPTIONS
 
 # START CROSS-COMPILE CONFIGS
diff --git a/third_party/xla/third_party/tsl/opensource_only.files b/third_party/xla/third_party/tsl/opensource_only.files
index 300ae95..1d52b26 100644
--- a/third_party/xla/third_party/tsl/opensource_only.files
+++ b/third_party/xla/third_party/tsl/opensource_only.files
@@ -93,6 +93,7 @@
 third_party/repo.bzl:
 third_party/six.BUILD:
 third_party/snappy.BUILD:
+third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD:
 third_party/systemlibs/BUILD.tpl:
 third_party/systemlibs/BUILD:
 third_party/systemlibs/absl_py.BUILD:
diff --git a/third_party/xla/third_party/tsl/third_party/absl/nvidia_jetson.patch b/third_party/xla/third_party/tsl/third_party/absl/nvidia_jetson.patch
new file mode 100644
index 0000000..5328c3a
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/absl/nvidia_jetson.patch
@@ -0,0 +1,35 @@
+From 372124e6af36a540e74a2ec31d79d7297a831f98 Mon Sep 17 00:00:00 2001
+From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Bastien?= <frederic.bastien@gmail.com>
+Date: Thu, 1 Aug 2024 12:38:52 -0700
+Subject: [PATCH] PR #1732: Fix build on NVIDIA Jetson board. Fix #1665
+
+Imported from GitHub PR https://github.com/abseil/abseil-cpp/pull/1732
+
+Fix build on NVIDIA Jetson board. Fix #1665
+
+This patch is already used by the spark project.
+I'm fixing this as this breaks the build of TensorFlow and JAX on Jetson boards.

+Merge 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff into 6b8ebb35c0414ef5a2b6fd4a0f59057e41beaff9
+
+Merging this change closes #1732
+
+COPYBARA_INTEGRATE_REVIEW=https://github.com/abseil/abseil-cpp/pull/1732 from nouiz:fix_neon_on_jetson 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff
+PiperOrigin-RevId: 658501520
+Change-Id: If502ede4efc8c877fb3fed227eca6dc7622dd181
+---
+ absl/base/config.h | 2 +-
+ 1 file changed, 1 insertion(+), 1 deletion(-)
+
+diff --git a/absl/base/config.h b/absl/base/config.h
+index 97c9a22a109..ab1e9860a91 100644
+--- a/absl/base/config.h
++++ b/absl/base/config.h
+@@ -926,7 +926,7 @@ static_assert(ABSL_INTERNAL_INLINE_NAMESPACE_STR[0] != 'h' ||
+ // https://llvm.org/docs/CompileCudaWithLLVM.html#detecting-clang-vs-nvcc-from-code
+ #ifdef ABSL_INTERNAL_HAVE_ARM_NEON
+ #error ABSL_INTERNAL_HAVE_ARM_NEON cannot be directly set
+-#elif defined(__ARM_NEON) && !defined(__CUDA_ARCH__)
++#elif defined(__ARM_NEON) && !(defined(__NVCC__) && defined(__CUDACC__))
+ #define ABSL_INTERNAL_HAVE_ARM_NEON 1
+ #endif
+ 
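For context on the guard change above: the old condition only excluded device-side passes (where __CUDA_ARCH__ is defined), while the new one keeps NEON off for any nvcc compilation, host or device, since nvcc defines both __NVCC__ and __CUDACC__. A minimal sketch of how the revised guard behaves; the helper function is just an illustration, not part of the patch.

// Compiled by clang/gcc on a Jetson (aarch64): __ARM_NEON is defined and
// __NVCC__/__CUDACC__ are not, so the NEON path is taken.
// Compiled by nvcc on the same board: both __NVCC__ and __CUDACC__ are
// defined, so NEON intrinsics stay out of the nvcc build entirely.
#if defined(__ARM_NEON) && !(defined(__NVCC__) && defined(__CUDACC__))
#include <arm_neon.h>
static inline uint32x4_t AddFourLanes(uint32x4_t a, uint32x4_t b) {
  return vaddq_u32(a, b);  // plain NEON intrinsic
}
#endif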
diff --git a/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl b/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl
index 06f7516..9565a82 100644
--- a/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl
+++ b/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl
@@ -44,4 +44,5 @@
         system_link_files = SYS_LINKS,
         strip_prefix = "abseil-cpp-{commit}".format(commit = ABSL_COMMIT),
         urls = tf_mirror_urls("https://github.com/abseil/abseil-cpp/archive/{commit}.tar.gz".format(commit = ABSL_COMMIT)),
+        patch_file = ["//third_party/absl:nvidia_jetson.patch"],
     )
diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl
index 0b85e59..44cdbe3 100644
--- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl
+++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl
@@ -249,3 +249,9 @@
     # to make bazel query happy.
     name = "nvptxcompiler",
 )
+
+cc_library(
+    # This is not yet fully supported, but we need the rule
+    # to make bazel query happy.
+    name = "nvjitlink",
+)
\ No newline at end of file
diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl
index c185ca7..ff9b53b 100644
--- a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl
+++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl
@@ -205,6 +205,8 @@
     inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include")
     inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include")
     inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include")
+    if int(rocm_config.rocm_version_number) >= 60200:
+        inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/18/include")
 
     # Support hcc based off clang 10.0.0 (for ROCm 3.3)
     inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
diff --git a/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD
index 4b3ad84..8c73096 100644
--- a/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD
+++ b/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD
@@ -12,7 +12,7 @@
     "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA",
     "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP",
     "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
-    "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH",
+    "#cmakedefine ONEDNN_BUILD_GRAPH": "#define ONEDNN_BUILD_GRAPH",
     "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#define DNNL_EXPERIMENTAL_SPARSE",
     "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL",
     "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
@@ -109,6 +109,7 @@
     "-UUSE_CBLAS",
     "-DDNNL_ENABLE_MAX_CPU_ISA",
     "-DDNNL_ENABLE_ITT_TASKS",
+    "-DDNNL_ENABLE_GRAPH_DUMP",
 ] + tf_openmp_copts()
 
 _INCLUDES_LIST = [
@@ -119,6 +120,7 @@
     "src/cpu",
     "src/cpu/gemm",
     "src/cpu/x64/xbyak",
+    "src/graph",
 ]
 
 _TEXTUAL_HDRS_LIST = glob([
@@ -129,6 +131,15 @@
     "src/cpu/**/*.hpp",
     "src/cpu/jit_utils/**/*.hpp",
     "src/cpu/x64/xbyak/*.h",
+    "src/graph/interface/*.hpp",
+    "src/graph/backend/*.hpp",
+    "src/graph/backend/dnnl/*.hpp",
+    "src/graph/backend/fake/*.hpp",
+    "src/graph/backend/dnnl/passes/*.hpp",
+    "src/graph/backend/dnnl/patterns/*.hpp",
+    "src/graph/backend/dnnl/kernels/*.hpp",
+    "src/graph/utils/*.hpp",
+    "src/graph/utils/pm/*.hpp",
 ]) + [
     ":dnnl_config_h",
     ":dnnl_version_h",
@@ -160,6 +171,16 @@
             "src/cpu/**/*.cpp",
             "src/common/ittnotify/*.c",
             "src/cpu/jit_utils/**/*.cpp",
+            "src/cpu/x64/**/*.cpp",
+            "src/graph/interface/*.cpp",
+            "src/graph/backend/*.cpp",
+            "src/graph/backend/dnnl/*.cpp",
+            "src/graph/backend/fake/*.cpp",
+            "src/graph/backend/dnnl/passes/*.cpp",
+            "src/graph/backend/dnnl/patterns/*.cpp",
+            "src/graph/backend/dnnl/kernels/*.cpp",
+            "src/graph/utils/*.cpp",
+            "src/graph/utils/pm/*.cpp",
         ],
         exclude = [
             "src/cpu/aarch64/**",
diff --git a/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/BUILD b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/BUILD
new file mode 100644
index 0000000..8d626dc
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/BUILD
@@ -0,0 +1,7 @@
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],
+)
+
+# spirv_llvm_translator license placeholder
diff --git a/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD
new file mode 100644
index 0000000..557e2e8
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD
@@ -0,0 +1,34 @@
+cc_library(
+    name = "spirv_llvm_translator",
+    srcs = glob([
+        "lib/SPIRV/libSPIRV/*.cpp",
+        "lib/SPIRV/libSPIRV/*.hpp",
+        "lib/SPIRV/libSPIRV/*.h",
+        "lib/SPIRV/Mangler/*.cpp",
+        "lib/SPIRV/Mangler/*.h",
+        "lib/SPIRV/*.cpp",
+        "lib/SPIRV/*.hpp",
+        "lib/SPIRV/*.h",
+    ]),
+    hdrs = glob(["include/*"]),
+    includes = [
+        "include/",
+        "lib/SPIRV/",
+        "lib/SPIRV/Mangler/",
+        "lib/SPIRV/libSPIRV/",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        "@llvm-project//llvm:Analysis",
+        "@llvm-project//llvm:BitWriter",
+        "@llvm-project//llvm:CodeGen",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Demangle",
+        "@llvm-project//llvm:IRReader",
+        "@llvm-project//llvm:Linker",
+        "@llvm-project//llvm:Passes",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:TransformUtils",
+        "@spirv_headers//:spirv_cpp_headers",
+    ],
+)
diff --git a/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.patch b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.patch
new file mode 100644
index 0000000..fc843b1
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.patch
@@ -0,0 +1,25 @@
+diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h
+index a828add8..924e13b4 100644
+
+The SPIR backend uses a different addrspace representation than the NVPTX backend.
+We reorder the enum values here so that XLA LLVM codegen stays simple (avoiding
+changing the addrspace based on the device backend everywhere).
+
+--- a/lib/SPIRV/SPIRVInternal.h
++++ b/lib/SPIRV/SPIRVInternal.h
+@@ -179,11 +179,12 @@ typedef SPIRVMap<Op, Op, IntBoolOpMapId> IntBoolOpMap;
+   "-v512:512:512-v1024:1024:1024"
+
+ enum SPIRAddressSpace {
+-  SPIRAS_Private,
++  SPIRAS_Generic,
+   SPIRAS_Global,
+-  SPIRAS_Constant,
++  SPIRAS_Internal,
+   SPIRAS_Local,
+-  SPIRAS_Generic,
++  SPIRAS_Constant,
++  SPIRAS_Private,
+   SPIRAS_GlobalDevice,
+   SPIRAS_GlobalHost,
+   SPIRAS_Input,
\ No newline at end of file
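The reordering above is motivated in the patch description: making SPIRAS_Generic the zero value lets XLA build generic pointers the same way it already does for NVPTX, where address space 0 is also the generic space. A purely illustrative sketch of that invariant; the helper and constant name are assumptions, not code from the patch.

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"

// After the reorder, "generic" is address space 0 for both the SPIR and NVPTX
// paths, so codegen can hardcode one value instead of branching on the
// device backend.
llvm::PointerType* MakeGenericPointer(llvm::LLVMContext& ctx) {
  constexpr unsigned kGenericAddrSpace = 0;  // SPIRAS_Generic after the patch
  return llvm::PointerType::get(ctx, kGenericAddrSpace);
}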
diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h
index f04fd0c..8d5cf79 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h
+++ b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h
@@ -80,7 +80,7 @@
 
   // Move constructor leaves src in a valid but unspecified state (same as
   // std::unordered_map).
-  FlatMap(FlatMap&& src) : rep_(std::move(src.rep_)) {}
+  FlatMap(FlatMap&& src) noexcept : rep_(std::move(src.rep_)) {}
 
   template <typename InputIter>
   FlatMap(InputIter first, InputIter last, size_t N = 1,
@@ -100,14 +100,14 @@
 
   // Move-assignment operator leaves src in a valid but unspecified state (same
   // as std::unordered_map).
-  FlatMap& operator=(FlatMap&& src) {
+  FlatMap& operator=(FlatMap&& src) noexcept {
     rep_.MoveFrom(std::move(src.rep_));
     return *this;
   }
 
   ~FlatMap() {}
 
-  void swap(FlatMap& x) { rep_.swap(x.rep_); }
+  void swap(FlatMap& x) noexcept { rep_.swap(x.rep_); }
   void clear_no_resize() { rep_.clear_no_resize(); }
   void clear() { rep_.clear(); }
   void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h
index dfc6584..d6c77e7 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h
+++ b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h
@@ -58,10 +58,11 @@
     CopyEntries(src.array_, src.end_, CopyEntry());
   }
 
-  FlatRep(FlatRep&& src)
-      // Copy rather than move src.hash_ and src.equal_.  This is necessary to
-      // leave src in a valid state -- otherwise e.g. if hash_ is an
-      // std::function, moving it would null it out.
+  FlatRep(
+      FlatRep&& src) noexcept  // Copy rather than move src.hash_ and
+                               // src.equal_.  This is necessary to leave src in
+                               // a valid state -- otherwise e.g. if hash_ is an
+                               // std::function, moving it would null it out.
       : hash_(src.hash_), equal_(src.equal_) {
     // TODO(jlebar): Init(1) still allocates some memory, so this isn't as cheap
     // as it could be.  The fundamental problem is that we need to leave src in
@@ -118,7 +119,7 @@
     MaybeResize();
   }
 
-  void swap(FlatRep& x) {
+  void swap(FlatRep& x) noexcept {
     using std::swap;
     swap(array_, x.array_);
     swap(end_, x.end_);
diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h
index ec8e9ad..b317822 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h
+++ b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h
@@ -63,7 +63,7 @@
 
   // Move constructor leaves src in a valid but unspecified state (same as
   // std::unordered_set).
-  FlatSet(FlatSet&& src) : rep_(std::move(src.rep_)) {}
+  FlatSet(FlatSet&& src) noexcept : rep_(std::move(src.rep_)) {}
 
   template <typename InputIter>
   FlatSet(InputIter first, InputIter last, size_t N = 1,
@@ -83,14 +83,14 @@
 
   // Move-assignment operator leaves src in a valid but unspecified state (same
   // as std::unordered_set).
-  FlatSet& operator=(FlatSet&& src) {
+  FlatSet& operator=(FlatSet&& src) noexcept {
     rep_.MoveFrom(std::move(src.rep_));
     return *this;
   }
 
   ~FlatSet() {}
 
-  void swap(FlatSet& x) { rep_.swap(x.rep_); }
+  void swap(FlatSet& x) noexcept { rep_.swap(x.rep_); }
   void clear_no_resize() { rep_.clear_no_resize(); }
   void clear() { rep_.clear(); }
   void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
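The noexcept additions in flatmap.h, flatrep.h, and flatset.h above matter mostly to containers of these types: std::vector relocates elements by moving during reallocation only when the move constructor is noexcept (via std::move_if_noexcept); otherwise it falls back to copying. A small self-contained illustration of that effect, using a stand-in type rather than the real FlatMap.

#include <type_traits>
#include <utility>
#include <vector>

struct Rep { std::vector<int> buckets; };

// Stand-in for FlatMap/FlatSet after this change: move operations are noexcept.
struct Flat {
  Flat() = default;
  Flat(Flat&& src) noexcept : rep_(std::move(src.rep_)) {}
  Flat& operator=(Flat&& src) noexcept {
    rep_ = std::move(src.rep_);
    return *this;
  }
  Rep rep_;
};

static_assert(std::is_nothrow_move_constructible<Flat>::value,
              "std::vector<Flat> can now move, not copy, when it grows");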
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/BUILD b/third_party/xla/third_party/tsl/tsl/lib/io/BUILD
index c103dcf..055931f 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/BUILD
@@ -263,12 +263,12 @@
     srcs = ["buffered_file_test.cc"],
     deps = [
         ":buffered_file",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:env",
         "//tsl/platform:env_impl",
         "//tsl/platform:test",
         "//tsl/platform:test_benchmark",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -443,12 +443,12 @@
     deps = [
         ":buffered_inputstream",
         ":random_inputstream",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:env",
         "//tsl/platform:env_impl",
         "//tsl/platform:test",
         "//tsl/platform:test_benchmark",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -471,7 +471,6 @@
     srcs = ["inputbuffer_test.cc"],
     deps = [
         ":inputbuffer",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:coding",
         "//tsl/platform:env",
         "//tsl/platform:env_impl",
@@ -482,6 +481,7 @@
         "//tsl/platform:strcat",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -491,10 +491,10 @@
     srcs = ["inputstream_interface_test.cc"],
     deps = [
         ":inputstream_interface",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:errors",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -504,11 +504,11 @@
     srcs = ["random_inputstream_test.cc"],
     deps = [
         ":random_inputstream",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:env",
         "//tsl/platform:env_impl",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -519,7 +519,6 @@
     deps = [
         ":record_reader",
         ":record_writer",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:env",
         "//tsl/platform:env_impl",
         "//tsl/platform:errors",
@@ -528,6 +527,7 @@
         "//tsl/platform:strcat",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
         "@zlib",
     ],
 )
@@ -539,7 +539,6 @@
     deps = [
         ":record_reader",
         ":record_writer",
-        "//tsl/lib/core:status_test_util",
         "//tsl/lib/hash:crc32c",
         "//tsl/lib/random:philox",
         "//tsl/platform:coding",
@@ -549,6 +548,7 @@
         "//tsl/platform:str_util",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -580,12 +580,12 @@
         ":zlib_compression_options",
         ":zlib_inputstream",
         ":zlib_outputbuffer",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:env",
         "//tsl/platform:env_impl",
         "//tsl/platform:errors",
         "//tsl/platform:strcat",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
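The BUILD edits above move the status_test_util dependency from //tsl/lib/core to @local_xla//xla/tsl/lib/core, matching the header moves in the test sources that follow. A minimal sketch of a migrated test, under the assumption that the TF_EXPECT_OK macro keeps its name in the relocated header:

#include "xla/tsl/lib/core/status_test_util.h"  // new home of TF_EXPECT_OK / TF_ASSERT_OK

#include "absl/status/status.h"
#include "tsl/platform/test.h"

TEST(StatusTestUtilMigration, OkStatusPasses) {
  TF_EXPECT_OK(absl::OkStatus());
}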
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc
index 6fae0b6..f9fa67d 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc
@@ -18,7 +18,7 @@
 #include <memory>
 #include <utility>
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/test.h"
 #include "tsl/platform/test_benchmark.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc
index ab1f58e..83e5776 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc
@@ -15,7 +15,7 @@
 
 #include "tsl/lib/io/buffered_inputstream.h"
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/lib/io/random_inputstream.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/test.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc
index e384604..d23f06a 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc
@@ -17,7 +17,7 @@
 
 #include <vector>
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/coding.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc
index 23d4fb0..c9c34db 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc
@@ -15,7 +15,7 @@
 
 #include "tsl/lib/io/inputstream_interface.h"
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/test.h"
 
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc
index 0b47ef2..dfa4ec8 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc
@@ -15,7 +15,7 @@
 
 #include "tsl/lib/io/random_inputstream.h"
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/test.h"
 
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc
index 67df7831..45934c9 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc
@@ -23,7 +23,7 @@
 #include <memory>
 #include <vector>
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/logging.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc
index 42adf76..51c2be6 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/lib/hash/crc32c.h"
 #include "tsl/lib/io/record_reader.h"
 #include "tsl/lib/io/record_writer.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD
index 3f42c5f..0adc5e2 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD
@@ -90,12 +90,12 @@
         ":snappy_inputbuffer",
         ":snappy_inputstream",
         ":snappy_outputbuffer",
-        "//tsl/lib/core:status_test_util",
         "//tsl/lib/io:inputbuffer",
         "//tsl/lib/io:random_inputstream",
         "//tsl/platform:env",
         "//tsl/platform:env_impl",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc
index d04d8d1..7844b89 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc
@@ -170,7 +170,7 @@
     bytes_to_read -= avail_in_;
     read_location += avail_in_;
   }
-  StringPiece data;
+  absl::string_view data;
   // Try to read enough data to fill up input_buffer_.
   absl::Status s = file_->Read(file_pos_, bytes_to_read, &data, read_location);
   if (data.data() != read_location) {
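The StringPiece -> absl::string_view swaps that start here (and continue through the snappy and cloud files below) are mechanical under the assumption that tsl::StringPiece is simply an alias for absl::string_view; the sketch below states that assumption explicitly rather than documenting the real header.

#include "absl/strings/string_view.h"

namespace tsl {
// Assumed alias (historically declared in tsl/platform/stringpiece.h); once
// call sites spell absl::string_view directly, the alias can be retired.
using StringPiece = absl::string_view;
}  // namespace tsl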
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc
index 6d19c60..e851f58 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc
@@ -40,7 +40,7 @@
   }
 }
 
-absl::Status SnappyOutputBuffer::Append(StringPiece data) {
+absl::Status SnappyOutputBuffer::Append(absl::string_view data) {
   return Write(data);
 }
 
@@ -58,7 +58,7 @@
   return Flush();
 }
 
-absl::Status SnappyOutputBuffer::Name(StringPiece* result) const {
+absl::Status SnappyOutputBuffer::Name(absl::string_view* result) const {
   return file_->Name(result);
 }
 
@@ -71,7 +71,7 @@
   return file_->Tell(position);
 }
 
-absl::Status SnappyOutputBuffer::Write(StringPiece data) {
+absl::Status SnappyOutputBuffer::Write(absl::string_view data) {
   //
   // The deflated output is accumulated in output_buffer_ and gets written to
   // file as and when needed.
@@ -121,7 +121,7 @@
   return input_buffer_capacity_ - avail_in_;
 }
 
-void SnappyOutputBuffer::AddToInputBuffer(StringPiece data) {
+void SnappyOutputBuffer::AddToInputBuffer(absl::string_view data) {
   size_t bytes_to_write = data.size();
   DCHECK_LE(bytes_to_write, AvailableInputSpace());
 
@@ -182,7 +182,7 @@
 absl::Status SnappyOutputBuffer::FlushOutputBufferToFile() {
   size_t bytes_to_write = output_buffer_capacity_ - avail_out_;
   if (bytes_to_write > 0) {
-    absl::Status s = file_->Append(StringPiece(
+    absl::Status s = file_->Append(absl::string_view(
         reinterpret_cast<char*>(output_buffer_.get()), bytes_to_write));
     if (s.ok()) {
       next_out_ = output_buffer_.get();
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h
index 4c4d664..a3bd447 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h
@@ -64,7 +64,7 @@
   //
   // The input data is buffered internally and will be written to disk at a
   // later time. To immediately write contents to file call `Flush()`.
-  absl::Status Append(StringPiece data) override;
+  absl::Status Append(absl::string_view data) override;
 
 #if defined(TF_CORD_SUPPORT)
   absl::Status Append(const absl::Cord& cord) override;
@@ -81,7 +81,7 @@
   absl::Status Close() override;
 
   // Returns the name of the underlying file.
-  absl::Status Name(StringPiece* result) const override;
+  absl::Status Name(absl::string_view* result) const override;
 
   // Deflates any cached input, writes all output to file and syncs it.
   absl::Status Sync() override;
@@ -98,7 +98,7 @@
   // to file when the buffer is full.
   //
   // To immediately write contents to file call `Flush()`.
-  absl::Status Write(StringPiece data);
+  absl::Status Write(absl::string_view data);
 
   // Compresses any cached input and writes all output to file. This must be
   // called before the destructor to avoid any data loss.
@@ -107,7 +107,7 @@
  private:
   // Appends `data` to `input_buffer_`.
   // Throws if `data.size()` > AvailableInputSpace().
-  void AddToInputBuffer(StringPiece data);
+  void AddToInputBuffer(absl::string_view data);
 
   // Appends `data` to `output_buffer_`. Flushes buffer contents to file when
   // buffer gets full.
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc
index 33f42bd..78eecf3 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc
@@ -15,7 +15,7 @@
 
 #include <memory>
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/lib/io/inputbuffer.h"
 #include "tsl/lib/io/random_inputstream.h"
 #include "tsl/lib/io/snappy/snappy_inputbuffer.h"
@@ -77,7 +77,7 @@
                              compress_output_buf_size);
 
   for (int i = 0; i < num_writes; i++) {
-    TF_RETURN_IF_ERROR(out.Write(StringPiece(data)));
+    TF_RETURN_IF_ERROR(out.Write(absl::string_view(data)));
     if (with_flush) {
       TF_RETURN_IF_ERROR(out.Flush());
     }
@@ -96,7 +96,7 @@
     std::unique_ptr<RandomAccessFile> file_reader;
     TF_RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file_reader));
 
-    StringPiece data;
+    absl::string_view data;
     size_t file_pos = 0;
     size_t bytes_to_read = 256;
     char* scratch = new char[bytes_to_read];
@@ -106,14 +106,14 @@
     while ((file_reader->Read(file_pos, bytes_to_read, &data, scratch)).ok()) {
       file_pos += data.size();
       TF_CHECK_OK(
-          corrupt_file_writer->Append(StringPiece(buffer, buffer_size)));
+          corrupt_file_writer->Append(absl::string_view(buffer, buffer_size)));
       memcpy(buffer, data.data(), data.size());
       buffer_size = data.size();
     }
 
     // Drop the last byte. File is now corrupt.
-    TF_CHECK_OK(
-        corrupt_file_writer->Append(StringPiece(buffer, buffer_size - 1)));
+    TF_CHECK_OK(corrupt_file_writer->Append(
+        absl::string_view(buffer, buffer_size - 1)));
     TF_CHECK_OK(corrupt_file_writer->Flush());
     TF_CHECK_OK(corrupt_file_writer->Close());
     delete[] scratch;
@@ -216,7 +216,7 @@
   TF_CHECK_OK(env->NewWritableFile(fname, &file_writer));
   io::SnappyOutputBuffer out(file_writer.get(), compress_input_buf_size,
                              compress_output_buf_size);
-  TF_CHECK_OK(out.Write(StringPiece(data)));
+  TF_CHECK_OK(out.Write(absl::string_view(data)));
   TF_CHECK_OK(out.Flush());
   TF_CHECK_OK(file_writer->Flush());
   TF_CHECK_OK(file_writer->Close());
@@ -296,7 +296,7 @@
 
 static bool SnappyCompressionSupported() {
   string out;
-  StringPiece in = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
+  absl::string_view in = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
   return port::Snappy_Compress(in.data(), in.size(), &out);
 }
 
diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc
index 0aa65e8..c2ff61d 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/lib/io/random_inputstream.h"
 #include "tsl/lib/io/zlib_compression_options.h"
 #include "tsl/lib/io/zlib_inputstream.h"
diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD b/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD
deleted file mode 100644
index 699965e..0000000
--- a/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD
+++ /dev/null
@@ -1,57 +0,0 @@
-load(
-    "@local_tsl//tsl/platform:rules_cc.bzl",
-    "cc_library",
-)
-load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility")
-load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup")
-
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
-
-cc_library(
-    name = "proto_serialization",
-    srcs = ["proto_serialization.cc"],
-    hdrs = ["proto_serialization.h"],
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tsl/lib/gtl:inlined_vector",
-        "//tsl/platform:hash",
-        "//tsl/platform:logging",
-        "//tsl/platform:macros",
-        "//tsl/platform:protobuf",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/strings",
-    ],
-)
-
-filegroup(
-    name = "mobile_srcs_only_runtime",
-    srcs = [
-        "proto_serialization.cc",
-        "proto_serialization.h",
-    ],
-    visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
-)
-
-filegroup(
-    name = "legacy_lib_strings_all_headers",
-    srcs = [
-        "proto_serialization.h",
-    ],
-    visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
-)
-
-filegroup(
-    name = "legacy_lib_string_headers",
-    srcs = [
-        "proto_serialization.h",
-    ],
-    visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
-)
-
-filegroup(
-    name = "legacy_lib_internal_public_string_headers",
-    srcs = [
-        "proto_serialization.h",
-    ],
-    visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
-)
diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.cc b/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.cc
deleted file mode 100644
index 139849e..0000000
--- a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.cc
+++ /dev/null
@@ -1,101 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tsl/lib/strings/proto_serialization.h"
-
-#include <cstring>
-#include <memory>
-
-#include "absl/memory/memory.h"
-#include "absl/strings/string_view.h"
-#include "tsl/lib/gtl/inlined_vector.h"
-#include "tsl/platform/hash.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/macros.h"
-
-namespace tsl {
-namespace {
-
-// Helper for deterministic serialization.
-class DeterministicSerializer {
- public:
-  explicit DeterministicSerializer(const protobuf::MessageLite& msg)
-      : DeterministicSerializer(msg, msg.ByteSizeLong()) {}
-
-  DeterministicSerializer(const protobuf::MessageLite& msg, size_t size)
-      : size_(size) {
-    char* ptr = space_;
-    if (size_ > sizeof(space_)) {
-      ptr = new char[size_];
-      alloc_.reset(ptr);
-    }
-    bool ok = SerializeToBufferDeterministic(msg, ptr, size_);
-    DCHECK(ok);
-  }
-
-  size_t size() const { return size_; }
-  const char* data() const { return alloc_ == nullptr ? space_ : alloc_.get(); }
-
- private:
-  // Avoid InlinedVector since it causes 2x slowdown in the compilation
-  // of graphs containing large tensors in debug mode.
-  static constexpr int kInlinedBufferSize = 256;
-  const size_t size_;
-  std::unique_ptr<char[]> alloc_;
-  char space_[kInlinedBufferSize];
-};
-}  // namespace
-
-bool SerializeToStringDeterministic(const protobuf::MessageLite& msg,
-                                    string* result) {
-  const size_t size = msg.ByteSizeLong();
-  DCHECK_LE(size, static_cast<size_t>(INT_MAX));
-  *result = string(size, '\0');
-  return SerializeToBufferDeterministic(msg, const_cast<char*>(result->data()),
-                                        result->size());
-}
-
-bool SerializeToBufferDeterministic(const protobuf::MessageLite& msg,
-                                    char* buffer, size_t size) {
-  DCHECK(msg.ByteSizeLong() == size && size <= static_cast<size_t>(INT_MAX));
-  protobuf::io::ArrayOutputStream array_stream(buffer, size);
-  protobuf::io::CodedOutputStream output_stream(&array_stream);
-  output_stream.SetSerializationDeterministic(true);
-  msg.SerializeWithCachedSizes(&output_stream);
-  return !output_stream.HadError() &&
-         size == static_cast<size_t>(output_stream.ByteCount());
-}
-
-bool AreSerializedProtosEqual(const protobuf::MessageLite& x,
-                              const protobuf::MessageLite& y) {
-  const size_t size = x.ByteSizeLong();
-  if (size != y.ByteSizeLong()) return false;
-  if (size == 0) return true;
-  DeterministicSerializer x_serialized(x, size);
-  DeterministicSerializer y_serialized(y, size);
-  return memcmp(x_serialized.data(), y_serialized.data(), size) == 0;
-}
-
-uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto,
-                                uint64 seed) {
-  DeterministicSerializer serialized(proto);
-  return Hash64(serialized.data(), serialized.size(), seed);
-}
-
-uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto) {
-  DeterministicSerializer serialized(proto);
-  return Hash64(serialized.data(), serialized.size());
-}
-
-}  // namespace tsl
diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.h b/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.h
deleted file mode 100644
index 96a5c55..0000000
--- a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_
-#define TENSORFLOW_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_
-
-#include "tsl/platform/protobuf.h"
-
-namespace tsl {
-
-// Wrapper around protocol buffer serialization that requests deterministic
-// serialization, in particular for Map fields, which serialize in a random
-// order by default. Returns true on success.
-// Serialization is guaranteed to be deterministic for a given binary only.
-// See the following for more details:
-// https://github.com/google/protobuf/blob/a1bb147e96b6f74db6cdf3c3fcb00492472dbbfa/src/google/protobuf/io/coded_stream.h#L834
-bool SerializeToStringDeterministic(const protobuf::MessageLite& msg,
-                                    string* result);
-
-// As above, but takes a pre-allocated buffer wrapped by result.
-// PRECONDITION: size == msg.ByteSizeLong() && size <= INT_MAX.
-bool SerializeToBufferDeterministic(const protobuf::MessageLite& msg,
-                                    char* buffer, size_t size);
-
-// Returns true if serializing x and y using
-// SerializeToBufferDeterministic() yields identical strings.
-bool AreSerializedProtosEqual(const protobuf::MessageLite& x,
-                              const protobuf::MessageLite& y);
-
-// Computes Hash64 of the output of SerializeToBufferDeterministic().
-uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto);
-uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto,
-                                uint64 seed);
-
-}  // namespace tsl
-
-#endif  // TENSORFLOW_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_
diff --git a/third_party/xla/third_party/tsl/tsl/platform/BUILD b/third_party/xla/third_party/tsl/tsl/platform/BUILD
index 1f4c491..0f61c77 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/platform/BUILD
@@ -1465,7 +1465,7 @@
         ":subprocess",
         ":test",
         ":test_main",
-        "//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -1768,7 +1768,7 @@
         ":str_util",
         ":test",
         ":test_main",
-        "//tsl/lib/core:status_test_util",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -1784,7 +1784,7 @@
         ":str_util",
         ":test",
         ":test_main",
-        "//tsl/lib/core:status_test_util",
         "@com_google_absl//absl/time",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD
index bff2db4..e92fd04 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD
@@ -231,7 +231,6 @@
     copts = tsl_copts(),
     deps = [
         ":curl_http_request",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:errors",
         "//tsl/platform:macros",
         "//tsl/platform:protobuf",
@@ -240,6 +239,7 @@
         "//tsl/platform:test",
         "//tsl/platform:types",
         "@curl",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -359,10 +359,10 @@
     deps = [
         ":expiring_lru_cache",
         ":now_seconds_env",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:env_impl",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -373,13 +373,13 @@
     deps = [
         ":now_seconds_env",
         ":ram_file_block_cache",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:blocking_counter",
         "//tsl/platform:env",
         "//tsl/platform:env_impl",
         "//tsl/platform:notification",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -390,7 +390,6 @@
     deps = [
         ":gcs_file_system",
         ":http_request_fake",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:env_impl",
         "//tsl/platform:errors",
         "//tsl/platform:str_util",
@@ -399,6 +398,7 @@
         "//tsl/platform:test_main",
         "//tsl/profiler/backends/cpu:traceme_recorder_impl",
         "//tsl/profiler/utils:time_utils_impl",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -423,11 +423,11 @@
     linkopts = if_windows(["-DEFAULTLIB:ws2_32.lib"]),
     deps = [
         ":gcs_throttle",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:env_impl",
         "//tsl/platform:str_util",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -437,13 +437,13 @@
     srcs = ["curl_http_request_test.cc"],
     deps = [
         ":curl_http_request",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:env_impl",
         "//tsl/platform:path",
         "//tsl/platform:platform_port",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
         "@com_google_absl//absl/status",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -459,7 +459,6 @@
     deps = [
         ":http_request_fake",
         ":oauth_client",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:base64",
         "//tsl/platform:env",
         "//tsl/platform:env_impl",
@@ -468,6 +467,7 @@
         "//tsl/platform:test",
         "//tsl/platform:test_main",
         "@boringssl//:crypto",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -484,11 +484,11 @@
         ":google_auth_provider",
         ":http_request_fake",
         ":oauth_client",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:env_impl",
         "//tsl/platform:path",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
 
@@ -525,8 +525,8 @@
     srcs = ["time_util_test.cc"],
     deps = [
         ":time_util",
-        "//tsl/lib/core:status_test_util",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
+        "@local_xla//xla/tsl/lib/core:status_test_util",
     ],
 )
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/auth_provider.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/auth_provider.h
index 969b8bc..4b1b292 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/auth_provider.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/auth_provider.h
@@ -31,9 +31,9 @@
   /// \brief Returns the short-term authentication bearer token.
   ///
   /// Safe for concurrent use by multiple threads.
-  virtual Status GetToken(string* t) = 0;
+  virtual absl::Status GetToken(string* t) = 0;
 
-  static Status GetToken(AuthProvider* provider, string* token) {
+  static absl::Status GetToken(AuthProvider* provider, string* token) {
     if (!provider) {
       return errors::Internal("Auth provider is required.");
     }
@@ -44,9 +44,9 @@
 /// No-op auth provider, which will only work for public objects.
 class EmptyAuthProvider : public AuthProvider {
  public:
-  Status GetToken(string* token) override {
+  absl::Status GetToken(string* token) override {
     *token = "";
-    return OkStatus();
+    return absl::OkStatus();
   }
 };
 
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.cc
index 7be3af7..7a41c8f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.cc
@@ -40,7 +40,7 @@
     : http_request_factory_(std::move(http_request_factory)),
       retry_config_(config) {}
 
-Status ComputeEngineMetadataClient::GetMetadata(
+absl::Status ComputeEngineMetadataClient::GetMetadata(
     const string& path, std::vector<char>* response_buffer) {
   const auto get_metadata_from_gce = [path, response_buffer, this]() {
     string metadata_url;
@@ -56,7 +56,7 @@
     request->AddHeader("Metadata-Flavor", "Google");
     request->SetResultBuffer(response_buffer);
     TF_RETURN_IF_ERROR(request->Send());
-    return OkStatus();
+    return absl::OkStatus();
   };
 
   return RetryingUtils::CallWithRetries(get_metadata_from_gce, retry_config_);
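The cloud-layer signatures above switch from the unqualified Status spelling to absl::Status. A short sketch of what such a method looks like after the change, using a hypothetical stand-in rather than the real class hierarchy:

#include <string>
#include <vector>

#include "absl/status/status.h"

// Hypothetical stand-in for the metadata-client / auth-provider interfaces.
class FakeMetadataClient {
 public:
  absl::Status GetMetadata(const std::string& path,
                           std::vector<char>* response_buffer) {
    if (response_buffer == nullptr) {
      return absl::InvalidArgumentError("response_buffer must be non-null");
    }
    response_buffer->assign(path.begin(), path.end());  // echo for the sketch
    return absl::OkStatus();
  }
};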
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.h
index fac94cd..1337d33 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.h
@@ -51,8 +51,8 @@
   /// To get the zone of an instance:
   ///   compute_engine_metadata_client.GetMetadata(
   ///       "instance/zone", response_buffer);
-  virtual Status GetMetadata(const string& path,
-                             std::vector<char>* response_buffer);
+  virtual absl::Status GetMetadata(const string& path,
+                                   std::vector<char>* response_buffer);
 
  private:
   std::shared_ptr<HttpRequest::Factory> http_request_factory_;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.cc
index 7720784..19f2755 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.cc
@@ -28,15 +28,15 @@
     std::shared_ptr<ComputeEngineMetadataClient> google_metadata_client)
     : google_metadata_client_(std::move(google_metadata_client)) {}
 
-Status ComputeEngineZoneProvider::GetZone(string* zone) {
+absl::Status ComputeEngineZoneProvider::GetZone(string* zone) {
   if (!cached_zone.empty()) {
     *zone = cached_zone;
-    return OkStatus();
+    return absl::OkStatus();
   }
   std::vector<char> response_buffer;
   TF_RETURN_IF_ERROR(google_metadata_client_->GetMetadata(kGceMetadataZonePath,
                                                           &response_buffer));
-  StringPiece location(&response_buffer[0], response_buffer.size());
+  absl::string_view location(&response_buffer[0], response_buffer.size());
 
   std::vector<string> elems = str_util::Split(location, "/");
   if (elems.size() == 4) {
@@ -47,7 +47,7 @@
                << string(location);
   }
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 ComputeEngineZoneProvider::~ComputeEngineZoneProvider() {}
 
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.h
index a37b43c..99ed41f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.h
@@ -27,7 +27,7 @@
       std::shared_ptr<ComputeEngineMetadataClient> google_metadata_client);
   virtual ~ComputeEngineZoneProvider();
 
-  Status GetZone(string* zone) override;
+  absl::Status GetZone(string* zone) override;
 
  private:
   std::shared_ptr<ComputeEngineMetadataClient> google_metadata_client_;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc
index c41f967..44eeab7 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc
@@ -230,8 +230,8 @@
       libcurl_->curl_easy_setopt(curl_, CURLOPT_CUSTOMREQUEST, "DELETE"));
 }
 
-Status CurlHttpRequest::SetPutFromFile(const string& body_filepath,
-                                       size_t offset) {
+absl::Status CurlHttpRequest::SetPutFromFile(const string& body_filepath,
+                                             size_t offset) {
   CheckNotSent();
   CheckMethodNotSet();
   is_method_set_ = true;
@@ -257,7 +257,7 @@
                                            reinterpret_cast<void*>(put_body_)));
   // Using the default CURLOPT_READFUNCTION, which is doing an fread() on the
   // FILE * userdata set with CURLOPT_READDATA.
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 void CurlHttpRequest::SetPutEmptyBody() {
@@ -286,7 +286,7 @@
                                            reinterpret_cast<void*>(this)));
   CHECK_CURL_OK(libcurl_->curl_easy_setopt(curl_, CURLOPT_READFUNCTION,
                                            &CurlHttpRequest::ReadCallback));
-  post_body_buffer_ = StringPiece(buffer, size);
+  post_body_buffer_ = absl::string_view(buffer, size);
 }
 
 void CurlHttpRequest::SetPostEmptyBody() {
@@ -397,8 +397,8 @@
                                        size_t nmemb, void* this_object) {
   CHECK(ptr);
   auto that = reinterpret_cast<CurlHttpRequest*>(this_object);
-  StringPiece header(reinterpret_cast<const char*>(ptr), size * nmemb);
-  StringPiece name, value;
+  absl::string_view header(reinterpret_cast<const char*>(ptr), size * nmemb);
+  absl::string_view name, value;
   // The supplied header has the form "<name>: <value>", parse it.
   if (strings::Scanner(header)
           .ScanEscapedUntil(':')
@@ -412,7 +412,7 @@
   return size * nmemb;
 }
 
-Status CurlHttpRequest::Send() {
+absl::Status CurlHttpRequest::Send() {
   CheckNotSent();
   CHECK(is_uri_set_) << "URI has not been set.";
 
@@ -457,7 +457,7 @@
   auto get_error_message = [this]() -> string {
     string error_message = strings::StrCat(
         "Error executing an HTTP request: HTTP response code ", response_code_);
-    StringPiece body = GetResponse();
+    absl::string_view body = GetResponse();
     if (!body.empty()) {
       return strings::StrCat(
           error_message, " with body '",
@@ -466,7 +466,7 @@
     return error_message;
   };
 
-  Status result;
+  absl::Status result;
   switch (response_code_) {
     // The group of response codes indicating that the request achieved
     // the expected goal.
@@ -474,7 +474,7 @@
     case 201:  // Created
     case 204:  // No Content
     case 206:  // Partial Content
-      result = OkStatus();
+      result = absl::OkStatus();
       break;
 
     case 416:  // Requested Range Not Satisfiable
@@ -485,7 +485,7 @@
       if (IsDirectResponse()) {
         direct_response_.bytes_transferred_ = 0;
       }
-      result = OkStatus();
+      result = absl::OkStatus();
       break;
 
     // INVALID_ARGUMENT indicates a problem with how the request is constructed.
@@ -556,13 +556,14 @@
   CHECK(!is_sent_) << "The request has already been sent.";
 }
 
-StringPiece CurlHttpRequest::GetResponse() const {
-  StringPiece response;
+absl::string_view CurlHttpRequest::GetResponse() const {
+  absl::string_view response;
   if (IsDirectResponse()) {
-    response = StringPiece(direct_response_.buffer_,
-                           direct_response_.bytes_transferred_);
+    response = absl::string_view(direct_response_.buffer_,
+                                 direct_response_.bytes_transferred_);
   } else {
-    response = StringPiece(response_buffer_->data(), response_buffer_->size());
+    response =
+        absl::string_view(response_buffer_->data(), response_buffer_->size());
   }
   return response;
 }
@@ -627,10 +628,10 @@
   return 0;
 }
 
-Status CurlHttpRequest::CURLcodeToStatus(CURLcode code,
-                                         const char* error_buffer) {
+absl::Status CurlHttpRequest::CURLcodeToStatus(CURLcode code,
+                                               const char* error_buffer) {
   if (code == CURLE_OK) {
-    return OkStatus();
+    return absl::OkStatus();
   }
   string error_message = strings::StrCat(
       "Error executing an HTTP request: libcurl code ", code, " meaning '",
@@ -648,7 +649,7 @@
     // a response body (e.g. GCS sends one with an error message) but we
     // pretend as though they don't, so actually ignore this error.
     if (get_response_result == CURLE_OK && response_code == 416) {
-      return OkStatus();
+      return absl::OkStatus();
     }
     return errors::FailedPrecondition(
         strings::StrCat(error_message, overflow_message));
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.h
index b5c7285..4c64758 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.h
@@ -86,7 +86,8 @@
   ///
   /// The request body will be taken from the specified file starting from
   /// the given offset.
-  Status SetPutFromFile(const string& body_filepath, size_t offset) override;
+  absl::Status SetPutFromFile(const string& body_filepath,
+                              size_t offset) override;
 
   /// Makes the request a PUT request with an empty body.
   void SetPutEmptyBody() override;
@@ -140,7 +141,7 @@
   ///
   /// If the result buffer was defined, the response will be written there.
   /// The object is not designed to be re-used after Send() is executed.
-  Status Send() override;
+  absl::Status Send() override;
 
   // Url encodes str and returns a new string.
   string EscapeString(const string& str) override;
@@ -167,18 +168,18 @@
                               curl_off_t ulnow);
   void CheckMethodNotSet() const;
   void CheckNotSent() const;
-  StringPiece GetResponse() const;
+  absl::string_view GetResponse() const;
 
   /// Helper to convert the given CURLcode and error buffer, representing the
   /// result of performing a transfer, into a Status with an error message.
-  Status CURLcodeToStatus(CURLcode code, const char* error_buffer);
+  absl::Status CURLcodeToStatus(CURLcode code, const char* error_buffer);
 
   LibCurl* libcurl_;
   Env* env_;
 
   FILE* put_body_ = nullptr;
 
-  StringPiece post_body_buffer_;
+  absl::string_view post_body_buffer_;
   size_t post_body_read_ = 0;
 
   std::vector<char>* response_buffer_ = nullptr;
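One side effect of switching `post_body_buffer_` and `GetResponse()` from `StringPiece` to `absl::string_view` is worth keeping in mind: the view is non-owning, so the backing buffer must outlive it. A tiny standalone sketch of that relationship (illustrative only, not TSL code):

```cpp
#include <iostream>
#include <vector>

#include "absl/strings/string_view.h"

int main() {
  // The vector plays the role of response_buffer_: it owns the bytes.
  std::vector<char> response_buffer = {'o', 'k'};

  // The view only points at those bytes, exactly like the string_view
  // returned by GetResponse(); it must not outlive response_buffer.
  absl::string_view response(response_buffer.data(), response_buffer.size());
  std::cout << response << "\n";  // Prints "ok".
  return 0;
}
```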
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc
index 36d7108..31cde67 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc
@@ -19,7 +19,7 @@
 #include <string>
 
 #include "absl/status/status.h"
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/mem.h"
 #include "tsl/platform/path.h"
 #include "tsl/platform/test.h"
@@ -151,8 +151,8 @@
       posted_content_ = "";
       do {
         bytes_read = read_callback_(buffer, 1, sizeof(buffer), read_data_);
-        posted_content_ =
-            strings::StrCat(posted_content_, StringPiece(buffer, bytes_read));
+        posted_content_ = strings::StrCat(
+            posted_content_, absl::string_view(buffer, bytes_read));
       } while (bytes_read > 0);
     }
     if (write_data_ || write_callback_) {
@@ -366,7 +366,7 @@
 
   http_request.SetUri("http://www.testuri.com");
   http_request.SetResultBufferDirect(scratch.data(), scratch.size());
-  const Status& status = http_request.Send();
+  const absl::Status& status = http_request.Send();
   EXPECT_EQ(error::FAILED_PRECONDITION, status.code());
   EXPECT_EQ(
       "Error executing an HTTP request: libcurl code 23 meaning "
@@ -770,7 +770,7 @@
 
   void RecordResponse(const HttpRequest* request, const string& uri,
                       HttpRequest::RequestMethod method,
-                      const Status& result) override {
+                      const absl::Status& result) override {
     has_recorded_response_ = true;
     record_response_request_ = request;
     record_response_uri_ = uri;
@@ -787,7 +787,7 @@
   string record_response_uri_ = "http://www.testuri.com";
   HttpRequest::RequestMethod record_response_method_ =
       HttpRequest::RequestMethod::kGet;
-  Status record_response_result_;
+  absl::Status record_response_result_;
 
   bool has_recorded_request_ = false;
   bool has_recorded_response_ = false;
@@ -864,7 +864,7 @@
   http_request.AddAuthBearerHeader("fake-bearer");
   http_request.SetRange(100, 199);
   http_request.SetResultBuffer(&scratch);
-  Status s = http_request.Send();
+  absl::Status s = http_request.Send();
 
   // Check interaction with stats.
   ASSERT_TRUE(stats.has_recorded_request_);
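The tests above now hold results in `absl::Status` and inspect the canonical code. A minimal sketch of that checking style using plain absl (no tsl or gtest helpers), with a hand-built status standing in for what a failed `Send()` might return:

```cpp
#include <cassert>

#include "absl/status/status.h"

int main() {
  // Stand-in for the status a failed Send() might produce (illustrative only).
  const absl::Status status =
      absl::FailedPreconditionError("Error executing an HTTP request");

  assert(!status.ok());
  assert(status.code() == absl::StatusCode::kFailedPrecondition);
  return 0;
}
```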
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache.h
index 1def81b..d3a15bc 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache.h
@@ -71,13 +71,13 @@
     return LookupLocked(key, value);
   }
 
-  typedef std::function<Status(const string&, T*)> ComputeFunc;
+  typedef std::function<absl::Status(const string&, T*)> ComputeFunc;
 
   /// Look up the entry with key `key` and copy it to `value` if found. If not
   /// found, call `compute_func`. If `compute_func` returns successfully, store
   /// a copy of the output parameter in the cache, and another copy in `value`.
-  Status LookupOrCompute(const string& key, T* value,
-                         const ComputeFunc& compute_func) {
+  absl::Status LookupOrCompute(const string& key, T* value,
+                               const ComputeFunc& compute_func) {
     if (max_age_ == 0) {
       return compute_func(key, value);
     }
@@ -88,9 +88,9 @@
     // key if this proves to be a significant performance bottleneck.
     mutex_lock lock(mu_);
     if (LookupLocked(key, value)) {
-      return OkStatus();
+      return absl::OkStatus();
     }
-    Status s = compute_func(key, value);
+    absl::Status s = compute_func(key, value);
     if (s.ok()) {
       InsertLocked(key, *value);
     }
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache_test.cc
index ce3e0fc..7225dca 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache_test.cc
@@ -17,7 +17,7 @@
 
 #include <memory>
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/cloud/now_seconds_env.h"
 #include "tsl/platform/test.h"
 
@@ -97,7 +97,7 @@
       [&num_compute_calls](const string& key, int* value) {
         *value = num_compute_calls;
         num_compute_calls++;
-        return OkStatus();
+        return absl::OkStatus();
       };
   ExpiringLRUCache<int> cache1(0, 4);
 
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/file_block_cache.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/file_block_cache.h
index e336a42..5992754 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/file_block_cache.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/file_block_cache.h
@@ -70,9 +70,9 @@
   /// cache is constructed. The returned Status should be OK as long as the
   /// read from the remote filesystem succeeded (similar to the semantics of the
   /// read(2) system call).
-  typedef std::function<Status(const string& filename, size_t offset,
-                               size_t buffer_size, char* buffer,
-                               size_t* bytes_transferred)>
+  typedef std::function<absl::Status(const string& filename, size_t offset,
+                                     size_t buffer_size, char* buffer,
+                                     size_t* bytes_transferred)>
       BlockFetcher;
 
   virtual ~FileBlockCache() {}
@@ -91,8 +91,8 @@
   ///    placed in `out`.
   /// 4) OK otherwise (i.e. the read succeeded, and at least one byte was placed
   ///    in `out`).
-  virtual Status Read(const string& filename, size_t offset, size_t n,
-                      char* buffer, size_t* bytes_transferred) = 0;
+  virtual absl::Status Read(const string& filename, size_t offset, size_t n,
+                            char* buffer, size_t* bytes_transferred) = 0;
 
   // Validate the given file signature with the existing file signature in the
   // cache. Returns true if the signature doesn't change or the file did not
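The `BlockFetcher` typedef above describes a callable with read(2)-like semantics: fill the caller's buffer and report how many bytes were actually transferred. A hedged, self-contained sketch of a callable matching that shape, reading from an in-memory string instead of a remote filesystem (illustrative only):

```cpp
#include <algorithm>
#include <cstring>
#include <functional>
#include <string>

#include "absl/status/status.h"

int main() {
  // In-memory stand-in for a remote object store (illustrative only).
  const std::string fake_remote_contents = "hello, block cache";

  // Matches the BlockFetcher signature from the header above.
  std::function<absl::Status(const std::string&, size_t, size_t, char*,
                             size_t*)>
      fetcher = [&](const std::string& filename, size_t offset,
                    size_t buffer_size, char* buffer,
                    size_t* bytes_transferred) -> absl::Status {
    if (offset > fake_remote_contents.size()) {
      *bytes_transferred = 0;
      return absl::OutOfRangeError("offset past end of " + filename);
    }
    const size_t n =
        std::min(buffer_size, fake_remote_contents.size() - offset);
    std::memcpy(buffer, fake_remote_contents.data() + offset, n);
    *bytes_transferred = n;  // May be short near EOF, like read(2).
    return absl::OkStatus();
  };

  char buf[8];
  size_t transferred = 0;
  fetcher("gs://fake/object", 0, sizeof(buf), buf, &transferred).IgnoreError();
  return 0;
}
```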
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache.cc
index 4819b49..594703f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache.cc
@@ -41,7 +41,8 @@
 const std::vector<string>& kCachedDomainNames =
     *new std::vector<string>{"www.googleapis.com", "storage.googleapis.com"};
 
-inline void print_getaddrinfo_error(const string& name, Status return_status) {
+inline void print_getaddrinfo_error(const string& name,
+                                    absl::Status return_status) {
   // Status doesn't map well to EAI type errors.
   LOG(ERROR) << "Error resolving " << name << ": " << return_status;
 }
@@ -104,13 +105,13 @@
       /* max_delay_time_us = */ 50 * 1000 * 5000,
       /* max_retries = */ 5);
 
-  const Status getaddrinfo_status = RetryingUtils::CallWithRetries(
+  const absl::Status getaddrinfo_status = RetryingUtils::CallWithRetries(
       [&name, &hints, &result]() {
         int return_code = getaddrinfo(name.c_str(), nullptr, &hints, &result);
         absl::Status return_status;
         switch (return_code) {
           case 0:
-            return_status = OkStatus();
+            return_status = absl::OkStatus();
             break;
 #ifndef _WIN32
           case EAI_ADDRFAMILY:
@@ -175,7 +176,7 @@
 #endif
         }
 
-        return Status(return_status);
+        return absl::Status(return_status);
       },
       retryConfig);
 
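`GcsDnsCache` wraps `getaddrinfo` in `RetryingUtils::CallWithRetries` with a lambda that now returns `absl::Status`. A simplified, hypothetical sketch of that retry shape; `RetrySketch` is not the tsl API and omits the backoff the real retry config provides:

```cpp
#include <functional>

#include "absl/status/status.h"

// Retry a callable returning absl::Status up to max_retries extra times,
// stopping early on success. Illustrative only.
absl::Status RetrySketch(const std::function<absl::Status()>& f,
                         int max_retries) {
  absl::Status status = f();
  for (int attempt = 0; attempt < max_retries && !status.ok(); ++attempt) {
    status = f();  // A real implementation would also back off between tries.
  }
  return status;
}
```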
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache_test.cc
index 069dcb5..a5ce088 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache_test.cc
@@ -40,8 +40,9 @@
   void SetRequestStats(HttpRequest::RequestStats* stats) override {}
   void SetDeleteRequest() override {}
 
-  Status SetPutFromFile(const string& body_filepath, size_t offset) override {
-    return OkStatus();
+  absl::Status SetPutFromFile(const string& body_filepath,
+                              size_t offset) override {
+    return absl::OkStatus();
   }
   void SetPutEmptyBody() override {}
   void SetPostFromBuffer(const char* buffer, size_t size) override {}
@@ -52,7 +53,7 @@
 
   string GetResponseHeader(const string& name) const override { return ""; }
   uint64 GetResponseCode() const override { return 0; }
-  Status Send() override { return OkStatus(); }
+  absl::Status Send() override { return absl::OkStatus(); }
   string EscapeString(const string& str) override { return ""; }
 
   void SetTimeouts(uint32 connection, uint32 inactivity,
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc
index ea65028..c1cc244 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc
@@ -163,9 +163,9 @@
 // objects.
 constexpr char kComposeAppend[] = "compose";
 
-Status GetTmpFilename(string* filename) {
+absl::Status GetTmpFilename(string* filename) {
   *filename = io::GetTempFilename("");
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 /// Appends a trailing slash if the name doesn't already have one.
@@ -199,7 +199,7 @@
   std::set<string> result;
   result.insert(paths.begin(), paths.end());
   for (const string& path : paths) {
-    StringPiece subpath = io::Dirname(path);
+    absl::string_view subpath = io::Dirname(path);
     // If `path` starts with `/`, `subpath` will be `/` and then we get into an
     // infinite loop. Same behavior happens if there is a `//` pattern in
     // `path`, so we check for that and leave the loop quicker.
@@ -211,32 +211,32 @@
   return result;
 }
 
-Status ParseJson(StringPiece json, Json::Value* result) {
+absl::Status ParseJson(absl::string_view json, Json::Value* result) {
   Json::Reader reader;
   if (!reader.parse(json.data(), json.data() + json.size(), *result)) {
     return errors::Internal("Couldn't parse JSON response from GCS.");
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status ParseJson(const std::vector<char>& json, Json::Value* result) {
-  return ParseJson(StringPiece{json.data(), json.size()}, result);
+absl::Status ParseJson(const std::vector<char>& json, Json::Value* result) {
+  return ParseJson(absl::string_view{json.data(), json.size()}, result);
 }
 
 /// Reads a JSON value with the given name from a parent JSON value.
-Status GetValue(const Json::Value& parent, const char* name,
-                Json::Value* result) {
+absl::Status GetValue(const Json::Value& parent, const char* name,
+                      Json::Value* result) {
   *result = parent.get(name, Json::Value::null);
   if (result->isNull()) {
     return errors::Internal("The field '", name,
                             "' was expected in the JSON response.");
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 /// Reads a string JSON value with the given name from a parent JSON value.
-Status GetStringValue(const Json::Value& parent, const char* name,
-                      string* result) {
+absl::Status GetStringValue(const Json::Value& parent, const char* name,
+                            string* result) {
   Json::Value result_value;
   TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value));
   if (!result_value.isString()) {
@@ -245,21 +245,21 @@
         "' in the JSON response was expected to be a string.");
   }
   *result = result_value.asString();
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 /// Reads a long JSON value with the given name from a parent JSON value.
-Status GetInt64Value(const Json::Value& parent, const char* name,
-                     int64_t* result) {
+absl::Status GetInt64Value(const Json::Value& parent, const char* name,
+                           int64_t* result) {
   Json::Value result_value;
   TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value));
   if (result_value.isNumeric()) {
     *result = result_value.asInt64();
-    return OkStatus();
+    return absl::OkStatus();
   }
   if (result_value.isString() &&
       strings::safe_strto64(result_value.asCString(), result)) {
-    return OkStatus();
+    return absl::OkStatus();
   }
   return errors::Internal(
       "The field '", name,
@@ -267,7 +267,8 @@
 }
 
 /// Reads a boolean JSON value with the given name from a parent JSON value.
-Status GetBoolValue(const Json::Value& parent, const char* name, bool* result) {
+absl::Status GetBoolValue(const Json::Value& parent, const char* name,
+                          bool* result) {
   Json::Value result_value;
   TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value));
   if (!result_value.isBool()) {
@@ -276,7 +277,7 @@
         "' in the JSON response was expected to be a boolean.");
   }
   *result = result_value.asBool();
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 /// Get GCS Retry Config by applying user overrides through env if any.
@@ -314,21 +315,21 @@
 /// A GCS-based implementation of a random access file with an LRU block cache.
 class GcsRandomAccessFile : public RandomAccessFile {
  public:
-  using ReadFn =
-      std::function<Status(const string& filename, uint64 offset, size_t n,
-                           StringPiece* result, char* scratch)>;
+  using ReadFn = std::function<absl::Status(
+      const string& filename, uint64 offset, size_t n,
+      absl::string_view* result, char* scratch)>;
 
   GcsRandomAccessFile(const string& filename, ReadFn read_fn)
       : filename_(filename), read_fn_(std::move(read_fn)) {}
 
-  Status Name(StringPiece* result) const override {
+  absl::Status Name(absl::string_view* result) const override {
     *result = filename_;
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   /// The implementation of reads with an LRU block cache. Thread safe.
-  Status Read(uint64 offset, size_t n, StringPiece* result,
-              char* scratch) const override {
+  absl::Status Read(uint64 offset, size_t n, absl::string_view* result,
+                    char* scratch) const override {
     return read_fn_(filename_, offset, n, result, scratch);
   }
 
@@ -342,9 +343,9 @@
 /// A GCS-based implementation of a random access file with a read buffer.
 class BufferedGcsRandomAccessFile : public RandomAccessFile {
  public:
-  using ReadFn =
-      std::function<Status(const string& filename, uint64 offset, size_t n,
-                           StringPiece* result, char* scratch)>;
+  using ReadFn = std::function<absl::Status(
+      const string& filename, uint64 offset, size_t n,
+      absl::string_view* result, char* scratch)>;
 
   // Initialize the reader. Provided read_fn should be thread safe.
   BufferedGcsRandomAccessFile(const string& filename, uint64 buffer_size,
@@ -355,16 +356,16 @@
         buffer_start_(0),
         buffer_end_is_past_eof_(false) {}
 
-  Status Name(StringPiece* result) const override {
+  absl::Status Name(absl::string_view* result) const override {
     *result = filename_;
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   /// The implementation of reads with a read buffer. Thread safe.

   /// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in `*result`
   /// because of EOF.
-  Status Read(uint64 offset, size_t n, StringPiece* result,
-              char* scratch) const override {
+  absl::Status Read(uint64 offset, size_t n, absl::string_view* result,
+                    char* scratch) const override {
     if (n > buffer_size_) {
       return read_fn_(filename_, offset, n, result, scratch);
     }
@@ -375,12 +376,12 @@
       if (offset < buffer_end && offset >= buffer_start_) {
         copy_size = std::min(n, static_cast<size_t>(buffer_end - offset));
         memcpy(scratch, buffer_.data() + (offset - buffer_start_), copy_size);
-        *result = StringPiece(scratch, copy_size);
+        *result = absl::string_view(scratch, copy_size);
       }
       bool consumed_buffer_to_eof =
           offset + copy_size >= buffer_end && buffer_end_is_past_eof_;
       if (copy_size < n && !consumed_buffer_to_eof) {
-        Status status = FillBuffer(offset + copy_size);
+        absl::Status status = FillBuffer(offset + copy_size);
         if (!status.ok() && !absl::IsOutOfRange(status)) {
           // Empty the buffer to avoid caching bad reads.
           buffer_.resize(0);
@@ -389,7 +390,7 @@
         size_t remaining_copy = std::min(n - copy_size, buffer_.size());
         memcpy(scratch + copy_size, buffer_.data(), remaining_copy);
         copy_size += remaining_copy;
-        *result = StringPiece(scratch, copy_size);
+        *result = absl::string_view(scratch, copy_size);
       }
       if (copy_size < n) {
         // Forget the end-of-file flag to allow for clients that poll on the
@@ -399,17 +400,17 @@
                                   " bytes from ", offset, ".");
       }
     }
-    return OkStatus();
+    return absl::OkStatus();
   }
 
  private:
-  Status FillBuffer(uint64 start) const
+  absl::Status FillBuffer(uint64 start) const
       TF_EXCLUSIVE_LOCKS_REQUIRED(buffer_mutex_) {
     buffer_start_ = start;
     buffer_.resize(buffer_size_);
-    StringPiece str_piece;
-    Status status = read_fn_(filename_, buffer_start_, buffer_size_, &str_piece,
-                             &(buffer_[0]));
+    absl::string_view str_piece;
+    absl::Status status = read_fn_(filename_, buffer_start_, buffer_size_,
+                                   &str_piece, &(buffer_[0]));
     buffer_end_is_past_eof_ = absl::IsOutOfRange(status);
     buffer_.resize(str_piece.size());
     return status;
@@ -437,28 +438,28 @@
 };
 
 // Function object declaration with params needed to create upload sessions.
-typedef std::function<Status(
+typedef std::function<absl::Status(
     uint64 start_offset, const std::string& object_to_upload,
     const std::string& bucket, uint64 file_size, const std::string& gcs_path,
     UploadSessionHandle* session_handle)>
     SessionCreator;
 
 // Function object declaration with params needed to upload objects.
-typedef std::function<Status(const std::string& session_uri,
-                             uint64 start_offset, uint64 already_uploaded,
-                             const std::string& tmp_content_filename,
-                             uint64 file_size, const std::string& file_path)>
+typedef std::function<absl::Status(
+    const std::string& session_uri, uint64 start_offset,
+    uint64 already_uploaded, const std::string& tmp_content_filename,
+    uint64 file_size, const std::string& file_path)>
     ObjectUploader;
 
 // Function object declaration with params needed to poll upload status.
-typedef std::function<Status(const string& session_uri, uint64 file_size,
-                             const std::string& gcs_path, bool* completed,
-                             uint64* uploaded)>
+typedef std::function<absl::Status(const string& session_uri, uint64 file_size,
+                                   const std::string& gcs_path, bool* completed,
+                                   uint64* uploaded)>
     StatusPoller;
 
 // Function object declaration with params needed to query an object's generation.
-typedef std::function<Status(const string& fname, const string& bucket,
-                             const string& object, int64_t* generation)>
+typedef std::function<absl::Status(const string& fname, const string& bucket,
+                                   const string& object, int64_t* generation)>
     GenerationGetter;
 
 /// \brief GCS-based implementation of a writeable file.
@@ -534,7 +535,7 @@
     std::remove(tmp_content_filename_.c_str());
   }
 
-  Status Append(StringPiece data) override {
+  absl::Status Append(absl::string_view data) override {
     TF_RETURN_IF_ERROR(CheckWritable());
     VLOG(3) << "Append: " << GetGcsPath() << " size " << data.length();
     sync_needed_ = true;
@@ -543,37 +544,38 @@
       return errors::Internal(
           "Could not append to the internal temporary file.");
     }
-    return OkStatus();
+    return absl::OkStatus();
   }
 
-  Status Close() override {
+  absl::Status Close() override {
     VLOG(3) << "Close:" << GetGcsPath();
     if (outfile_.is_open()) {
-      Status sync_status = Sync();
+      absl::Status sync_status = Sync();
       if (sync_status.ok()) {
         outfile_.close();
       }
       return sync_status;
     }
-    return OkStatus();
+    return absl::OkStatus();
   }
 
-  Status Flush() override {
+  absl::Status Flush() override {
     VLOG(3) << "Flush:" << GetGcsPath();
     return Sync();
   }
 
-  Status Name(StringPiece* result) const override {
-    return errors::Unimplemented("GCSWritableFile does not support Name()");
+  absl::Status Name(absl::string_view* result) const override {
+    *result = object_;
+    return absl::OkStatus();
   }
 
-  Status Sync() override {
+  absl::Status Sync() override {
     VLOG(3) << "Sync started:" << GetGcsPath();
     TF_RETURN_IF_ERROR(CheckWritable());
     if (!sync_needed_) {
-      return OkStatus();
+      return absl::OkStatus();
     }
-    Status status = SyncImpl();
+    absl::Status status = SyncImpl();
     VLOG(3) << "Sync finished " << GetGcsPath();
     if (status.ok()) {
       sync_needed_ = false;
@@ -581,12 +583,12 @@
     return status;
   }
 
-  Status Tell(int64_t* position) override {
+  absl::Status Tell(int64_t* position) override {
     *position = outfile_.tellp();
     if (*position == -1) {
       return errors::Internal("tellp on the internal temporary file failed");
     }
-    return OkStatus();
+    return absl::OkStatus();
   }
 
  private:
@@ -596,7 +598,7 @@
   /// In case of a failure, it resumes failed uploads as recommended by the GCS
   /// resumable API documentation. When the whole upload needs to be
   /// restarted, Sync() returns UNAVAILABLE and relies on RetryingFileSystem.
-  Status SyncImpl() {
+  absl::Status SyncImpl() {
     outfile_.flush();
     if (!outfile_.good()) {
       return errors::Internal(
@@ -620,7 +622,7 @@
                                               &session_handle));
     uint64 already_uploaded = 0;
     bool first_attempt = true;
-    const Status upload_status = RetryingUtils::CallWithRetries(
+    const absl::Status upload_status = RetryingUtils::CallWithRetries(
         [&first_attempt, &already_uploaded, &session_handle, &start_offset,
          this]() {
           if (session_handle.resumable && !first_attempt) {
@@ -637,7 +639,7 @@
               // It's unclear why UploadToSession didn't return OK in the
               // previous attempt, but GCS reports that the file is fully
               // uploaded, so succeed.
-              return OkStatus();
+              return absl::OkStatus();
             }
           }
           first_attempt = false;
@@ -661,28 +663,28 @@
     return upload_status;
   }
 
-  Status CheckWritable() const {
+  absl::Status CheckWritable() const {
     if (!outfile_.is_open()) {
       return errors::FailedPrecondition(
           "The internal temporary file is not writable.");
     }
-    return OkStatus();
+    return absl::OkStatus();
   }
 
-  Status GetCurrentFileSize(uint64* size) {
+  absl::Status GetCurrentFileSize(uint64* size) {
     const auto tellp = outfile_.tellp();
     if (tellp == static_cast<std::streampos>(-1)) {
       return errors::Internal(
           "Could not get the size of the internal temporary file.");
     }
     *size = tellp;
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   /// Initiates a new resumable upload session.
-  Status CreateNewUploadSession(uint64 start_offset,
-                                std::string object_to_upload,
-                                UploadSessionHandle* session_handle) {
+  absl::Status CreateNewUploadSession(uint64 start_offset,
+                                      std::string object_to_upload,
+                                      UploadSessionHandle* session_handle) {
     uint64 file_size;
     TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
     return session_creator_(start_offset, object_to_upload, bucket_, file_size,
@@ -691,7 +693,7 @@
 
   /// Appends the data of append_object to the original object and deletes
   /// append_object.
-  Status AppendObject(string append_object) {
+  absl::Status AppendObject(string append_object) {
     const string append_object_path = GetGcsPathWithObject(append_object);
     VLOG(3) << "AppendObject: " << append_object_path << " to " << GetGcsPath();
 
@@ -718,7 +720,7 @@
           request->SetPostFromBuffer(request_body.c_str(), request_body.size());
           TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(),
                                           " when composing to ", GetGcsPath());
-          return OkStatus();
+          return absl::OkStatus();
         },
         retry_config_));
 
@@ -734,8 +736,8 @@
   /// If the upload has already succeeded, sets 'completed' to true.
   /// Otherwise sets 'completed' to false and 'uploaded' to the currently
   /// uploaded size in bytes.
-  Status RequestUploadSessionStatus(const string& session_uri, bool* completed,
-                                    uint64* uploaded) {
+  absl::Status RequestUploadSessionStatus(const string& session_uri,
+                                          bool* completed, uint64* uploaded) {
     uint64 file_size;
     TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
     return status_poller_(session_uri, file_size, GetGcsPath(), completed,
@@ -743,11 +745,11 @@
   }
 
   /// Uploads data to object.
-  Status UploadToSession(const string& session_uri, uint64 start_offset,
-                         uint64 already_uploaded) {
+  absl::Status UploadToSession(const string& session_uri, uint64 start_offset,
+                               uint64 already_uploaded) {
     uint64 file_size;
     TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
-    Status status =
+    absl::Status status =
         object_uploader_(session_uri, start_offset, already_uploaded,
                          tmp_content_filename_, file_size, GetGcsPath());
     if (status.ok()) {
@@ -795,14 +797,14 @@
   uint64 length_;
 };
 
-bool StringPieceIdentity(StringPiece str, StringPiece* value) {
+bool StringPieceIdentity(absl::string_view str, absl::string_view* value) {
   *value = str;
   return true;
 }
 
 /// \brief Utility function to split a comma delimited list of strings to an
 /// unordered set, lowercasing all values.
-bool SplitByCommaToLowercaseSet(StringPiece list,
+bool SplitByCommaToLowercaseSet(absl::string_view list,
                                 std::unordered_set<string>* set) {
   std::vector<string> vector = absl::StrSplit(absl::AsciiStrToLower(list), ',');
   *set = std::unordered_set<string>(vector.begin(), vector.end());
@@ -897,14 +899,14 @@
   }
 
   // Get the additional header
-  StringPiece add_header_contents;
+  absl::string_view add_header_contents;
   if (GetEnvVar(kAdditionalRequestHeader, StringPieceIdentity,
                 &add_header_contents)) {
     size_t split = add_header_contents.find(':', 0);
 
-    if (split != StringPiece::npos) {
-      StringPiece header_name = add_header_contents.substr(0, split);
-      StringPiece header_value = add_header_contents.substr(split + 1);
+    if (split != absl::string_view::npos) {
+      absl::string_view header_name = add_header_contents.substr(0, split);
+      absl::string_view header_value = add_header_contents.substr(split + 1);
 
       if (!header_name.empty() && !header_value.empty()) {
         additional_header_.reset(new std::pair<const string, const string>(
@@ -968,7 +970,7 @@
   GetEnvVar(kAllowedBucketLocations, SplitByCommaToLowercaseSet,
             &allowed_locations_);
 
-  StringPiece append_mode;
+  absl::string_view append_mode;
   GetEnvVar(kAppendMode, StringPieceIdentity, &append_mode);
   if (append_mode == kComposeAppend) {
     compose_append_ = true;
@@ -1006,7 +1008,7 @@
       compose_append_(compose_append),
       additional_header_(additional_header) {}
 
-Status GcsFileSystem::NewRandomAccessFile(
+absl::Status GcsFileSystem::NewRandomAccessFile(
     const string& fname, TransactionToken* token,
     std::unique_ptr<RandomAccessFile>* result) {
   string bucket, object;
@@ -1016,7 +1018,7 @@
     result->reset(new GcsRandomAccessFile(fname, [this, bucket, object](
                                                      const string& fname,
                                                      uint64 offset, size_t n,
-                                                     StringPiece* result,
+                                                     absl::string_view* result,
                                                      char* scratch) {
       tf_shared_lock l(block_cache_lock_);
       GcsFileStat stat;
@@ -1031,37 +1033,37 @@
             << "File signature has been changed. Refreshing the cache. Path: "
             << fname;
       }
-      *result = StringPiece();
+      *result = absl::string_view();
       size_t bytes_transferred;
       TF_RETURN_IF_ERROR(file_block_cache_->Read(fname, offset, n, scratch,
                                                  &bytes_transferred));
-      *result = StringPiece(scratch, bytes_transferred);
+      *result = absl::string_view(scratch, bytes_transferred);
       if (bytes_transferred < n) {
         return errors::OutOfRange("EOF reached, ", result->size(),
                                   " bytes were read out of ", n,
                                   " bytes requested.");
       }
-      return OkStatus();
+      return absl::OkStatus();
     }));
   } else {
     result->reset(new BufferedGcsRandomAccessFile(
         fname, block_size_,
         [this, bucket, object](const string& fname, uint64 offset, size_t n,
-                               StringPiece* result, char* scratch) {
-          *result = StringPiece();
+                               absl::string_view* result, char* scratch) {
+          *result = absl::string_view();
           size_t bytes_transferred;
           TF_RETURN_IF_ERROR(
               LoadBufferFromGCS(fname, offset, n, scratch, &bytes_transferred));
-          *result = StringPiece(scratch, bytes_transferred);
+          *result = absl::string_view(scratch, bytes_transferred);
           if (bytes_transferred < n) {
             return errors::OutOfRange("EOF reached, ", result->size(),
                                       " bytes were read out of ", n,
                                       " bytes requested.");
           }
-          return OkStatus();
+          return absl::OkStatus();
         }));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 void GcsFileSystem::ResetFileBlockCache(size_t block_size_bytes,
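Both read callbacks above follow the `RandomAccessFile::Read` contract: the caller supplies `scratch`, the callee fills it, points `*result` at the bytes actually read, and returns `OutOfRange` when fewer than `n` bytes were available. A self-contained sketch of that contract against an in-memory string; `FakeRead` is hypothetical, illustrative only:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <string>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"

// Copies from `contents` into `scratch` and reports the bytes read through
// `*result`, mirroring the Read() contract used by the callbacks above.
absl::Status FakeRead(const std::string& contents, uint64_t offset, size_t n,
                      absl::string_view* result, char* scratch) {
  *result = absl::string_view();
  if (offset >= contents.size()) {
    return absl::OutOfRangeError("EOF reached, 0 bytes were read.");
  }
  const size_t bytes = std::min<size_t>(n, contents.size() - offset);
  std::memcpy(scratch, contents.data() + offset, bytes);
  *result = absl::string_view(scratch, bytes);
  if (bytes < n) {
    return absl::OutOfRangeError("EOF reached before n bytes were read.");
  }
  return absl::OkStatus();
}
```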
@@ -1092,9 +1094,10 @@
 }
 
 // A helper function to actually read the data from GCS.
-Status GcsFileSystem::LoadBufferFromGCS(const string& fname, size_t offset,
-                                        size_t n, char* buffer,
-                                        size_t* bytes_transferred) {
+absl::Status GcsFileSystem::LoadBufferFromGCS(const string& fname,
+                                              size_t offset, size_t n,
+                                              char* buffer,
+                                              size_t* bytes_transferred) {
   *bytes_transferred = 0;
 
   string bucket, object;
@@ -1148,11 +1151,11 @@
     }
   }
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 /// Initiates a new upload session.
-Status GcsFileSystem::CreateNewUploadSession(
+absl::Status GcsFileSystem::CreateNewUploadSession(
     uint64 start_offset, const std::string& object_to_upload,
     const std::string& bucket, uint64 file_size, const std::string& gcs_path,
     UploadSessionHandle* session_handle) {
@@ -1179,15 +1182,13 @@
                               gcs_path, ": 'Location' header not returned.");
     }
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::UploadToSession(const std::string& session_uri,
-                                      uint64 start_offset,
-                                      uint64 already_uploaded,
-                                      const std::string& tmp_content_filename,
-                                      uint64 file_size,
-                                      const std::string& file_path) {
+absl::Status GcsFileSystem::UploadToSession(
+    const std::string& session_uri, uint64 start_offset,
+    uint64 already_uploaded, const std::string& tmp_content_filename,
+    uint64 file_size, const std::string& file_path) {
   std::unique_ptr<HttpRequest> request;
   TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
   request->SetUri(session_uri);
@@ -1203,14 +1204,12 @@
                                              start_offset + already_uploaded));
   TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when uploading ",
                                   file_path);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::RequestUploadSessionStatus(const string& session_uri,
-                                                 uint64 file_size,
-                                                 const std::string& gcs_path,
-                                                 bool* completed,
-                                                 uint64* uploaded) {
+absl::Status GcsFileSystem::RequestUploadSessionStatus(
+    const string& session_uri, uint64 file_size, const std::string& gcs_path,
+    bool* completed, uint64* uploaded) {
   CHECK(completed != nullptr) << "RequestUploadSessionStatus() called with out "
                                  "param 'completed' == nullptr.";  // Crash ok
   CHECK(uploaded != nullptr) << "RequestUploadSessionStatus() called with out "
@@ -1221,10 +1220,10 @@
   request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
   request->AddHeader("Content-Range", strings::StrCat("bytes */", file_size));
   request->SetPutEmptyBody();
-  Status status = request->Send();
+  absl::Status status = request->Send();
   if (status.ok()) {
     *completed = true;
-    return OkStatus();
+    return absl::OkStatus();
   }
   *completed = false;
   if (request->GetResponseCode() != HTTP_CODE_RESUME_INCOMPLETE) {
@@ -1235,7 +1234,7 @@
     // This means GCS doesn't have any bytes of the file yet.
     *uploaded = 0;
   } else {
-    StringPiece range_piece(received_range);
+    absl::string_view range_piece(received_range);
     absl::ConsumePrefix(&range_piece,
                         "bytes=");  // May or may not be present.
 
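The upload-status hunk above strips an optional `bytes=` prefix from the received `Range` header and derives the uploaded byte count from the upper bound (a range of `0-10` means 11 bytes). A hedged, standalone sketch of that parsing with absl string utilities; `ParseUploadedBytes` is illustrative, not the `GcsFileSystem` API:

```cpp
#include <cstdint>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"

// Turns a header value such as "bytes=0-10" (or "0-10") into the number of
// bytes GCS reports as already uploaded: the upper bound plus one.
absl::Status ParseUploadedBytes(absl::string_view received_range,
                                uint64_t* uploaded) {
  absl::ConsumePrefix(&received_range, "bytes=");  // May or may not be present.
  std::vector<absl::string_view> parts = absl::StrSplit(received_range, '-');
  int64_t lo = 0, hi = 0;
  if (parts.size() != 2 || !absl::SimpleAtoi(parts[0], &lo) ||
      !absl::SimpleAtoi(parts[1], &hi) || lo > hi) {
    return absl::UnimplementedError(
        "Unexpected Range header format when polling an upload session.");
  }
  *uploaded = static_cast<uint64_t>(hi) + 1;  // "0-10" -> 11 bytes uploaded.
  return absl::OkStatus();
}
```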
@@ -1269,13 +1268,15 @@
     // If GCS returned "Range: 0-10", this means 11 bytes were uploaded.
     *uploaded = range_parts[1] + 1;
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::ParseGcsPathForScheme(StringPiece fname, string scheme,
-                                            bool empty_object_ok,
-                                            string* bucket, string* object) {
-  StringPiece parsed_scheme, bucketp, objectp;
+absl::Status GcsFileSystem::ParseGcsPathForScheme(absl::string_view fname,
+                                                  string scheme,
+                                                  bool empty_object_ok,
+                                                  string* bucket,
+                                                  string* object) {
+  absl::string_view parsed_scheme, bucketp, objectp;
   io::ParseURI(fname, &parsed_scheme, &bucketp, &objectp);
   if (parsed_scheme != scheme) {
     return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
@@ -1292,11 +1293,12 @@
     return errors::InvalidArgument("GCS path doesn't contain an object name: ",
                                    fname);
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::ParseGcsPath(StringPiece fname, bool empty_object_ok,
-                                   string* bucket, string* object) {
+absl::Status GcsFileSystem::ParseGcsPath(absl::string_view fname,
+                                         bool empty_object_ok, string* bucket,
+                                         string* object) {
   return ParseGcsPathForScheme(fname, "gs", empty_object_ok, bucket, object);
 }
 
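`ParseGcsPath` above delegates to `ParseGcsPathForScheme` with the `gs` scheme, requiring a bucket name and, unless `empty_object_ok`, an object name. A hypothetical, simplified sketch of that split without `io::ParseURI`; `ParseGcsPathSketch` is not the real API:

```cpp
#include <string>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"

absl::Status ParseGcsPathSketch(absl::string_view fname, bool empty_object_ok,
                                std::string* bucket, std::string* object) {
  if (!absl::ConsumePrefix(&fname, "gs://")) {
    return absl::InvalidArgumentError(
        absl::StrCat("GCS path doesn't start with 'gs://': ", fname));
  }
  const size_t slash = fname.find('/');
  absl::string_view bucket_view = fname.substr(0, slash);
  if (bucket_view.empty()) {
    return absl::InvalidArgumentError(
        absl::StrCat("GCS path doesn't contain a bucket name: ", fname));
  }
  absl::string_view object_view;
  if (slash != absl::string_view::npos) {
    object_view = fname.substr(slash + 1);
  }
  if (object_view.empty() && !empty_object_ok) {
    return absl::InvalidArgumentError(
        absl::StrCat("GCS path doesn't contain an object name: ", fname));
  }
  *bucket = std::string(bucket_view);
  *object = std::string(object_view);
  return absl::OkStatus();
}
```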
@@ -1308,9 +1310,9 @@
   // MatchingPathsCache as well.
 }
 
-Status GcsFileSystem::NewWritableFile(const string& fname,
-                                      TransactionToken* token,
-                                      std::unique_ptr<WritableFile>* result) {
+absl::Status GcsFileSystem::NewWritableFile(
+    const string& fname, TransactionToken* token,
+    std::unique_ptr<WritableFile>* result) {
   string bucket, object;
   TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
 
@@ -1344,7 +1346,7 @@
         },
         retry_config_));
     *generation = stat.generation_number;
-    return OkStatus();
+    return absl::OkStatus();
   };
 
   result->reset(new GcsWritableFile(
@@ -1352,20 +1354,20 @@
       [this, fname]() { ClearFileCaches(fname); }, retry_config_,
       compose_append_, session_creator, object_uploader, status_poller,
       generation_getter));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 // Reads the file from GCS in chunks and stores it in a tmp file,
 // which is then passed to GcsWritableFile.
-Status GcsFileSystem::NewAppendableFile(const string& fname,
-                                        TransactionToken* token,
-                                        std::unique_ptr<WritableFile>* result) {
+absl::Status GcsFileSystem::NewAppendableFile(
+    const string& fname, TransactionToken* token,
+    std::unique_ptr<WritableFile>* result) {
   std::unique_ptr<RandomAccessFile> reader;
   TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, token, &reader));
   std::unique_ptr<char[]> buffer(new char[kReadAppendableFileBufferSize]);
-  Status status;
+  absl::Status status;
   uint64 offset = 0;
-  StringPiece read_chunk;
+  absl::string_view read_chunk;
 
   // Read the file from GCS in chunks and save it to a tmp file.
   string old_content_filename;
@@ -1421,7 +1423,7 @@
         },
         retry_config_));
     *generation = stat.generation_number;
-    return OkStatus();
+    return absl::OkStatus();
   };
 
   // Create a writable file and pass the old content to it.
@@ -1432,10 +1434,10 @@
       [this, fname]() { ClearFileCaches(fname); }, retry_config_,
       compose_append_, session_creator, object_uploader, status_poller,
       generation_getter));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile(
+absl::Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile(
     const string& fname, TransactionToken* token,
     std::unique_ptr<ReadOnlyMemoryRegion>* result) {
   uint64 size;
@@ -1445,21 +1447,22 @@
   std::unique_ptr<RandomAccessFile> file;
   TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, token, &file));
 
-  StringPiece piece;
+  absl::string_view piece;
   TF_RETURN_IF_ERROR(file->Read(0, size, &piece, data.get()));
 
   result->reset(new GcsReadOnlyMemoryRegion(std::move(data), size));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::FileExists(const string& fname, TransactionToken* token) {
+absl::Status GcsFileSystem::FileExists(const string& fname,
+                                       TransactionToken* token) {
   string bucket, object;
   TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object));
   if (object.empty()) {
     bool result;
     TF_RETURN_IF_ERROR(BucketExists(bucket, &result));
     if (result) {
-      return OkStatus();
+      return absl::OkStatus();
     } else {
       return absl::NotFoundError(
           absl::StrCat("The specified bucket ", fname, " was not found."));
@@ -1468,7 +1471,7 @@
 
   // Check if the object exists.
   GcsFileStat stat;
-  const Status status = StatForObject(fname, bucket, object, &stat);
+  const absl::Status status = StatForObject(fname, bucket, object, &stat);
   if (!absl::IsNotFound(status)) {
     return status;
   }
@@ -1477,31 +1480,32 @@
   bool result;
   TF_RETURN_IF_ERROR(FolderExists(fname, &result));
   if (result) {
-    return OkStatus();
+    return absl::OkStatus();
   }
   return errors::NotFound("The specified path ", fname, " was not found.");
 }
 
-Status GcsFileSystem::ObjectExists(const string& fname, const string& bucket,
-                                   const string& object, bool* result) {
+absl::Status GcsFileSystem::ObjectExists(const string& fname,
+                                         const string& bucket,
+                                         const string& object, bool* result) {
   GcsFileStat stat;
-  const Status status = StatForObject(fname, bucket, object, &stat);
+  const absl::Status status = StatForObject(fname, bucket, object, &stat);
   switch (static_cast<int>(status.code())) {
     case static_cast<int>(error::Code::OK):
       *result = !stat.base.is_directory;
-      return OkStatus();
+      return absl::OkStatus();
     case static_cast<int>(error::Code::NOT_FOUND):
       *result = false;
-      return OkStatus();
+      return absl::OkStatus();
     default:
       return status;
   }
 }
 
-Status GcsFileSystem::UncachedStatForObject(const string& fname,
-                                            const string& bucket,
-                                            const string& object,
-                                            GcsFileStat* stat) {
+absl::Status GcsFileSystem::UncachedStatForObject(const string& fname,
+                                                  const string& bucket,
+                                                  const string& object,
+                                                  GcsFileStat* stat) {
   std::vector<char> output_buffer;
   std::unique_ptr<HttpRequest> request;
   TF_RETURN_WITH_CONTEXT_IF_ERROR(CreateHttpRequest(&request),
@@ -1542,7 +1546,7 @@
           << "; mtime_nsec: " << stat->base.mtime_nsec
           << "; updated: " << updated;
 
-  if (str_util::EndsWith(fname, "/")) {
+  if (absl::EndsWith(fname, "/")) {
     // In GCS a path can be both a directory and a file, but it is uncommon for
     // other file systems. To avoid the ambiguity, if a path ends with "/" in
     // GCS, we always regard it as a directory mark or a virtual directory.
@@ -1550,11 +1554,13 @@
   } else {
     stat->base.is_directory = false;
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
-                                    const string& object, GcsFileStat* stat) {
+absl::Status GcsFileSystem::StatForObject(const string& fname,
+                                          const string& bucket,
+                                          const string& object,
+                                          GcsFileStat* stat) {
   if (object.empty()) {
     return errors::InvalidArgument(strings::Printf(
         "'object' must be a non-empty string. (File: %s)", fname.c_str()));
@@ -1565,26 +1571,27 @@
       [this, &bucket, &object](const string& fname, GcsFileStat* stat) {
         return UncachedStatForObject(fname, bucket, object, stat);
       }));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::BucketExists(const string& bucket, bool* result) {
-  const Status status = GetBucketMetadata(bucket, nullptr);
+absl::Status GcsFileSystem::BucketExists(const string& bucket, bool* result) {
+  const absl::Status status = GetBucketMetadata(bucket, nullptr);
   switch (static_cast<absl::StatusCode>(status.code())) {
     case absl::StatusCode::kOk:
       *result = true;
-      return OkStatus();
+      return absl::OkStatus();
     case absl::StatusCode::kNotFound:
       *result = false;
-      return OkStatus();
+      return absl::OkStatus();
     default:
       return status;
   }
 }
 
-Status GcsFileSystem::CheckBucketLocationConstraint(const string& bucket) {
+absl::Status GcsFileSystem::CheckBucketLocationConstraint(
+    const string& bucket) {
   if (allowed_locations_.empty()) {
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   // Avoid calling external API's in the constructor
@@ -1597,7 +1604,7 @@
   string location;
   TF_RETURN_IF_ERROR(GetBucketLocation(bucket, &location));
   if (allowed_locations_.find(location) != allowed_locations_.end()) {
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   return errors::FailedPrecondition(strings::Printf(
@@ -1606,11 +1613,11 @@
       absl::StrJoin(allowed_locations_, ", ").c_str()));
 }
 
-Status GcsFileSystem::GetBucketLocation(const string& bucket,
-                                        string* location) {
+absl::Status GcsFileSystem::GetBucketLocation(const string& bucket,
+                                              string* location) {
   auto compute_func = [this](const string& bucket, string* location) {
     std::vector<char> result_buffer;
-    Status status = GetBucketMetadata(bucket, &result_buffer);
+    absl::Status status = GetBucketMetadata(bucket, &result_buffer);
     Json::Value result;
     TF_RETURN_IF_ERROR(ParseJson(result_buffer, &result));
     string bucket_location;
@@ -1618,17 +1625,17 @@
         GetStringValue(result, kBucketMetadataLocationKey, &bucket_location));
     // Lowercase the GCS location to be case insensitive for allowed locations.
     *location = absl::AsciiStrToLower(bucket_location);
-    return OkStatus();
+    return absl::OkStatus();
   };
 
   TF_RETURN_IF_ERROR(
       bucket_location_cache_->LookupOrCompute(bucket, location, compute_func));
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::GetBucketMetadata(const string& bucket,
-                                        std::vector<char>* result_buffer) {
+absl::Status GcsFileSystem::GetBucketMetadata(
+    const string& bucket, std::vector<char>* result_buffer) {
   std::unique_ptr<HttpRequest> request;
   TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
   request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket));
@@ -1641,7 +1648,7 @@
   return request->Send();
 }
 
-Status GcsFileSystem::FolderExists(const string& dirname, bool* result) {
+absl::Status GcsFileSystem::FolderExists(const string& dirname, bool* result) {
   StatCache::ComputeFunc compute_func = [this](const string& dirname,
                                                GcsFileStat* stat) {
     std::vector<string> children;
@@ -1650,36 +1657,36 @@
                            true /* include_self_directory_marker */));
     if (!children.empty()) {
       stat->base = DIRECTORY_STAT;
-      return OkStatus();
+      return absl::OkStatus();
     } else {
       return errors::InvalidArgument("Not a directory!");
     }
   };
   GcsFileStat stat;
-  Status s = stat_cache_->LookupOrCompute(MaybeAppendSlash(dirname), &stat,
-                                          compute_func);
+  absl::Status s = stat_cache_->LookupOrCompute(MaybeAppendSlash(dirname),
+                                                &stat, compute_func);
   if (s.ok()) {
     *result = stat.base.is_directory;
-    return OkStatus();
+    return absl::OkStatus();
   }
   if (absl::IsInvalidArgument(s)) {
     *result = false;
-    return OkStatus();
+    return absl::OkStatus();
   }
   return s;
 }
 
-Status GcsFileSystem::GetChildren(const string& dirname,
-                                  TransactionToken* token,
-                                  std::vector<string>* result) {
+absl::Status GcsFileSystem::GetChildren(const string& dirname,
+                                        TransactionToken* token,
+                                        std::vector<string>* result) {
   return GetChildrenBounded(dirname, UINT64_MAX, result,
                             false /* recursively */,
                             false /* include_self_directory_marker */);
 }
 
-Status GcsFileSystem::GetMatchingPaths(const string& pattern,
-                                       TransactionToken* token,
-                                       std::vector<string>* results) {
+absl::Status GcsFileSystem::GetMatchingPaths(const string& pattern,
+                                             TransactionToken* token,
+                                             std::vector<string>* results) {
   MatchingPathsCache::ComputeFunc compute_func =
       [this](const string& pattern, std::vector<string>* results) {
         results->clear();
@@ -1700,7 +1707,7 @@
 
         // To handle `/` in the object names, we need to remove it from `dir`
         // and then use `StrCat` to insert it back.
-        const StringPiece dir_no_slash = str_util::StripSuffix(dir, "/");
+        const absl::string_view dir_no_slash = absl::StripSuffix(dir, "/");
 
         // Match all obtained paths to the input pattern.
         for (const auto& path : files_and_folders) {
@@ -1715,18 +1722,16 @@
             results->push_back(full_path);
           }
         }
-        return OkStatus();
+        return absl::OkStatus();
       };
   TF_RETURN_IF_ERROR(
       matching_paths_cache_->LookupOrCompute(pattern, results, compute_func));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::GetChildrenBounded(const string& dirname,
-                                         uint64 max_results,
-                                         std::vector<string>* result,
-                                         bool recursive,
-                                         bool include_self_directory_marker) {
+absl::Status GcsFileSystem::GetChildrenBounded(
+    const string& dirname, uint64 max_results, std::vector<string>* result,
+    bool recursive, bool include_self_directory_marker) {
   if (!result) {
     return errors::InvalidArgument("'result' cannot be null");
   }
@@ -1786,7 +1791,7 @@
         // The names should be relative to the 'dirname'. That means the
         // 'object_prefix', which is part of 'dirname', should be removed from
         // the beginning of 'name'.
-        StringPiece relative_path(name);
+        absl::string_view relative_path(name);
         if (!absl::ConsumePrefix(&relative_path, object_prefix)) {
           return errors::Internal(strings::StrCat(
               "Unexpected response: the returned file name ", name,
@@ -1796,7 +1801,7 @@
           result->emplace_back(relative_path);
         }
         if (++retrieved_results >= max_results) {
-          return OkStatus();
+          return absl::OkStatus();
         }
       }
     }
@@ -1815,7 +1820,7 @@
               "response.");
         }
         const string& prefix_str = prefix.asString();
-        StringPiece relative_path(prefix_str);
+        absl::string_view relative_path(prefix_str);
         if (!absl::ConsumePrefix(&relative_path, object_prefix)) {
           return errors::Internal(
               "Unexpected response: the returned folder name ", prefix_str,
@@ -1823,13 +1828,13 @@
         }
         result->emplace_back(relative_path);
         if (++retrieved_results >= max_results) {
-          return OkStatus();
+          return absl::OkStatus();
         }
       }
     }
     const auto token = root.get("nextPageToken", Json::Value::null);
     if (token.isNull()) {
-      return OkStatus();
+      return absl::OkStatus();
     }
     if (!token.isString()) {
       return errors::Internal(
@@ -1839,8 +1844,8 @@
   }
 }
 
-Status GcsFileSystem::Stat(const string& fname, TransactionToken* token,
-                           FileStatistics* stat) {
+absl::Status GcsFileSystem::Stat(const string& fname, TransactionToken* token,
+                                 FileStatistics* stat) {
   if (!stat) {
     return errors::Internal("'stat' cannot be nullptr.");
   }
@@ -1851,16 +1856,16 @@
     TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket));
     if (is_bucket) {
       *stat = DIRECTORY_STAT;
-      return OkStatus();
+      return absl::OkStatus();
     }
     return errors::NotFound("The specified bucket ", fname, " was not found.");
   }
 
   GcsFileStat gcs_stat;
-  const Status status = StatForObject(fname, bucket, object, &gcs_stat);
+  const absl::Status status = StatForObject(fname, bucket, object, &gcs_stat);
   if (status.ok()) {
     *stat = gcs_stat.base;
-    return OkStatus();
+    return absl::OkStatus();
   }
   if (!absl::IsNotFound(status)) {
     return status;
@@ -1869,12 +1874,13 @@
   TF_RETURN_IF_ERROR(FolderExists(fname, &is_folder));
   if (is_folder) {
     *stat = DIRECTORY_STAT;
-    return OkStatus();
+    return absl::OkStatus();
   }
   return errors::NotFound("The specified path ", fname, " was not found.");
 }
 
-Status GcsFileSystem::DeleteFile(const string& fname, TransactionToken* token) {
+absl::Status GcsFileSystem::DeleteFile(const string& fname,
+                                       TransactionToken* token) {
   string bucket, object;
   TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
 
@@ -1887,11 +1893,11 @@
 
   TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when deleting ", fname);
   ClearFileCaches(fname);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::CreateDir(const string& dirname,
-                                TransactionToken* token) {
+absl::Status GcsFileSystem::CreateDir(const string& dirname,
+                                      TransactionToken* token) {
   string dirname_with_slash = MaybeAppendSlash(dirname);
   VLOG(3) << "CreateDir: creating directory with dirname: " << dirname
           << " and dirname_with_slash: " << dirname_with_slash;
@@ -1901,7 +1907,7 @@
   if (object.empty()) {
     bool is_bucket;
     TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket));
-    return is_bucket ? OkStatus()
+    return is_bucket ? absl::OkStatus()
                      : errors::NotFound("The specified bucket ",
                                         dirname_with_slash, " was not found.");
   }
@@ -1924,10 +1930,10 @@
 
   request->SetPostEmptyBody();
   request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
-  const Status& status = request->Send();
+  const absl::Status& status = request->Send();
   if (status.ok()) {
     VLOG(3) << "CreateDir: finished uploading directory " << dirname;
-    return OkStatus();
+    return absl::OkStatus();
   }
   if (request->GetResponseCode() != HTTP_CODE_PRECONDITION_FAILED) {
     TF_RETURN_WITH_CONTEXT_IF_ERROR(status, " when uploading ",
@@ -1940,8 +1946,8 @@
 
 // Checks that the directory is empty (i.e no objects with this prefix exist).
 // Deletes the GCS directory marker if it exists.
-Status GcsFileSystem::DeleteDir(const string& dirname,
-                                TransactionToken* token) {
+absl::Status GcsFileSystem::DeleteDir(const string& dirname,
+                                      TransactionToken* token) {
   std::vector<string> children;
   // A directory is considered empty either if there are no matching objects
   // with the corresponding name prefix or if there is exactly one matching
@@ -1958,11 +1964,12 @@
     // This is the directory marker object. Delete it.
     return DeleteFile(MaybeAppendSlash(dirname), token);
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::GetFileSize(const string& fname, TransactionToken* token,
-                                  uint64* file_size) {
+absl::Status GcsFileSystem::GetFileSize(const string& fname,
+                                        TransactionToken* token,
+                                        uint64* file_size) {
   if (!file_size) {
     return errors::Internal("'file_size' cannot be nullptr.");
   }
@@ -1974,11 +1981,11 @@
   FileStatistics stat;
   TF_RETURN_IF_ERROR(Stat(fname, token, &stat));
   *file_size = stat.length;
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GcsFileSystem::RenameFile(const string& src, const string& target,
-                                 TransactionToken* token) {
+absl::Status GcsFileSystem::RenameFile(const string& src, const string& target,
+                                       TransactionToken* token) {
   if (!IsDirectory(src, token).ok()) {
     return RenameObject(src, target);
   }
@@ -1991,11 +1998,12 @@
     TF_RETURN_IF_ERROR(
         RenameObject(JoinGcsPath(src, subpath), JoinGcsPath(target, subpath)));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 // Uses a GCS API command to copy the object and then deletes the old one.
-Status GcsFileSystem::RenameObject(const string& src, const string& target) {
+absl::Status GcsFileSystem::RenameObject(const string& src,
+                                         const string& target) {
   VLOG(3) << "RenameObject: started gs://" << src << " to " << target;
   string src_bucket, src_object, target_bucket, target_object;
   TF_RETURN_IF_ERROR(ParseGcsPath(src, false, &src_bucket, &src_object));
@@ -2040,15 +2048,15 @@
       [this, &src]() { return DeleteFile(src, nullptr); }, retry_config_);
 }
 
-Status GcsFileSystem::IsDirectory(const string& fname,
-                                  TransactionToken* token) {
+absl::Status GcsFileSystem::IsDirectory(const string& fname,
+                                        TransactionToken* token) {
   string bucket, object;
   TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object));
   if (object.empty()) {
     bool is_bucket;
     TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket));
     if (is_bucket) {
-      return OkStatus();
+      return absl::OkStatus();
     }
     return errors::NotFound("The specified bucket gs://", bucket,
                             " was not found.");
@@ -2056,7 +2064,7 @@
   bool is_folder;
   TF_RETURN_IF_ERROR(FolderExists(fname, &is_folder));
   if (is_folder) {
-    return OkStatus();
+    return absl::OkStatus();
   }
   bool is_object;
   TF_RETURN_IF_ERROR(ObjectExists(fname, bucket, object, &is_object));
@@ -2067,10 +2075,10 @@
   return errors::NotFound("The specified path ", fname, " was not found.");
 }
 
-Status GcsFileSystem::DeleteRecursively(const string& dirname,
-                                        TransactionToken* token,
-                                        int64_t* undeleted_files,
-                                        int64_t* undeleted_dirs) {
+absl::Status GcsFileSystem::DeleteRecursively(const string& dirname,
+                                              TransactionToken* token,
+                                              int64_t* undeleted_files,
+                                              int64_t* undeleted_dirs) {
   if (!undeleted_files || !undeleted_dirs) {
     return errors::Internal(
         "'undeleted_files' and 'undeleted_dirs' cannot be nullptr.");
@@ -2079,7 +2087,7 @@
   *undeleted_dirs = 0;
   if (!IsDirectory(dirname, token).ok()) {
     *undeleted_dirs = 1;
-    return Status(
+    return absl::Status(
         absl::StatusCode::kNotFound,
         strings::StrCat(dirname, " doesn't exist or not a directory."));
   }
@@ -2106,7 +2114,7 @@
       }
     }
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 // Flushes all caches for filesystem metadata and file contents. Useful for
@@ -2148,7 +2156,8 @@
 // Creates an HttpRequest and sets several parameters that are common to all
 // requests.  All code (in GcsFileSystem) that creates an HttpRequest should
 // go through this method, rather than directly using http_request_factory_.
-Status GcsFileSystem::CreateHttpRequest(std::unique_ptr<HttpRequest>* request) {
+absl::Status GcsFileSystem::CreateHttpRequest(
+    std::unique_ptr<HttpRequest>* request) {
   std::unique_ptr<HttpRequest> new_request{http_request_factory_->Create()};
   if (dns_cache_) {
     dns_cache_->AnnotateRequest(new_request.get());
@@ -2177,7 +2186,7 @@
   }
 
   *request = std::move(new_request);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 RetryingGcsFileSystem::RetryingGcsFileSystem()
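The gcs_file_system.cc hunks above all apply one mechanical pattern: unqualified Status/OkStatus() spellings become absl::Status/absl::OkStatus(), while error construction and error-class checks such as absl::IsNotFound keep working unchanged. A minimal sketch of that pattern (illustrative only; LookupSize, Probe, and kFakeSize are hypothetical names, not code from this patch):

// Minimal sketch of the absl::Status conventions the hunks above migrate to.
// LookupSize, Probe, and kFakeSize are hypothetical, not part of this patch.
#include <cstdint>
#include <string>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"

namespace example {

constexpr uint64_t kFakeSize = 42;

absl::Status LookupSize(const std::string& fname, uint64_t* size) {
  if (size == nullptr) {
    return absl::Status(absl::StatusCode::kInvalidArgument,
                        "'size' cannot be nullptr.");
  }
  if (fname.empty()) {
    return absl::Status(
        absl::StatusCode::kNotFound,
        absl::StrCat("The specified path ", fname, " was not found."));
  }
  *size = kFakeSize;
  return absl::OkStatus();  // Was the unqualified OkStatus() before.
}

absl::Status Probe(const std::string& fname) {
  uint64_t size = 0;
  const absl::Status status = LookupSize(fname, &size);
  if (absl::IsNotFound(status)) {
    return absl::OkStatus();  // Error-class helpers are unaffected by the rename.
  }
  return status;
}

}  // namespace example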
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.h
index 17725e8..f7452a4 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.h
@@ -66,7 +66,7 @@
 // Helper function to extract an environment variable and convert it into a
 // value of type T.
 template <typename T>
-bool GetEnvVar(const char* varname, bool (*convert)(StringPiece, T*),
+bool GetEnvVar(const char* varname, bool (*convert)(absl::string_view, T*),
                T* value) {
   const char* env_value = std::getenv(varname);
   if (env_value == nullptr) {
@@ -144,48 +144,54 @@
 
   TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT;
 
-  Status NewRandomAccessFile(
+  absl::Status NewRandomAccessFile(
       const string& fname, TransactionToken* token,
       std::unique_ptr<RandomAccessFile>* result) override;
 
-  Status NewWritableFile(const string& fname, TransactionToken* token,
-                         std::unique_ptr<WritableFile>* result) override;
+  absl::Status NewWritableFile(const string& fname, TransactionToken* token,
+                               std::unique_ptr<WritableFile>* result) override;
 
-  Status NewAppendableFile(const string& fname, TransactionToken* token,
-                           std::unique_ptr<WritableFile>* result) override;
+  absl::Status NewAppendableFile(
+      const string& fname, TransactionToken* token,
+      std::unique_ptr<WritableFile>* result) override;
 
-  Status NewReadOnlyMemoryRegionFromFile(
+  absl::Status NewReadOnlyMemoryRegionFromFile(
       const string& fname, TransactionToken* token,
       std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
 
-  Status FileExists(const string& fname, TransactionToken* token) override;
+  absl::Status FileExists(const string& fname,
+                          TransactionToken* token) override;
 
-  Status Stat(const string& fname, TransactionToken* token,
-              FileStatistics* stat) override;
+  absl::Status Stat(const string& fname, TransactionToken* token,
+                    FileStatistics* stat) override;
 
-  Status GetChildren(const string& dir, TransactionToken* token,
-                     std::vector<string>* result) override;
+  absl::Status GetChildren(const string& dir, TransactionToken* token,
+                           std::vector<string>* result) override;
 
-  Status GetMatchingPaths(const string& pattern, TransactionToken* token,
-                          std::vector<string>* results) override;
+  absl::Status GetMatchingPaths(const string& pattern, TransactionToken* token,
+                                std::vector<string>* results) override;
 
-  Status DeleteFile(const string& fname, TransactionToken* token) override;
+  absl::Status DeleteFile(const string& fname,
+                          TransactionToken* token) override;
 
-  Status CreateDir(const string& dirname, TransactionToken* token) override;
+  absl::Status CreateDir(const string& dirname,
+                         TransactionToken* token) override;
 
-  Status DeleteDir(const string& dirname, TransactionToken* token) override;
+  absl::Status DeleteDir(const string& dirname,
+                         TransactionToken* token) override;
 
-  Status GetFileSize(const string& fname, TransactionToken* token,
-                     uint64* file_size) override;
+  absl::Status GetFileSize(const string& fname, TransactionToken* token,
+                           uint64* file_size) override;
 
-  Status RenameFile(const string& src, const string& target,
-                    TransactionToken* token) override;
+  absl::Status RenameFile(const string& src, const string& target,
+                          TransactionToken* token) override;
 
-  Status IsDirectory(const string& fname, TransactionToken* token) override;
+  absl::Status IsDirectory(const string& fname,
+                           TransactionToken* token) override;
 
-  Status DeleteRecursively(const string& dirname, TransactionToken* token,
-                           int64_t* undeleted_files,
-                           int64_t* undeleted_dirs) override;
+  absl::Status DeleteRecursively(const string& dirname, TransactionToken* token,
+                                 int64_t* undeleted_files,
+                                 int64_t* undeleted_dirs) override;
 
   void FlushCaches(TransactionToken* token) override;
 
@@ -267,7 +273,7 @@
           write(write) {}
   };
 
-  Status CreateHttpRequest(std::unique_ptr<HttpRequest>* request);
+  absl::Status CreateHttpRequest(std::unique_ptr<HttpRequest>* request);
 
   /// \brief Sets a new AuthProvider on the GCS FileSystem.
   ///
@@ -289,37 +295,38 @@
       size_t block_size, size_t max_bytes, uint64 max_staleness);
 
   /// Loads file contents from GCS for a given filename, offset, and length.
-  virtual Status LoadBufferFromGCS(const string& fname, size_t offset, size_t n,
-                                   char* buffer, size_t* bytes_transferred);
+  virtual absl::Status LoadBufferFromGCS(const string& fname, size_t offset,
+                                         size_t n, char* buffer,
+                                         size_t* bytes_transferred);
 
   // Creates an upload session for an upcoming GCS object upload.
-  virtual Status CreateNewUploadSession(uint64 start_offset,
-                                        const std::string& object_to_upload,
-                                        const std::string& bucket,
-                                        uint64 file_size,
-                                        const std::string& gcs_path,
-                                        UploadSessionHandle* session_handle);
+  virtual absl::Status CreateNewUploadSession(
+      uint64 start_offset, const std::string& object_to_upload,
+      const std::string& bucket, uint64 file_size, const std::string& gcs_path,
+      UploadSessionHandle* session_handle);
 
   // Uploads object data to session.
-  virtual Status UploadToSession(const std::string& session_uri,
-                                 uint64 start_offset, uint64 already_uploaded,
-                                 const std::string& tmp_content_filename,
-                                 uint64 file_size,
-                                 const std::string& file_path);
+  virtual absl::Status UploadToSession(const std::string& session_uri,
+                                       uint64 start_offset,
+                                       uint64 already_uploaded,
+                                       const std::string& tmp_content_filename,
+                                       uint64 file_size,
+                                       const std::string& file_path);
 
   /// \brief Requests status of a previously initiated upload session.
   ///
   /// If the upload has already succeeded, sets 'completed' to true.
   /// Otherwise sets 'completed' to false and 'uploaded' to the currently
   /// uploaded size in bytes.
-  virtual Status RequestUploadSessionStatus(const string& session_uri,
-                                            uint64 file_size,
-                                            const std::string& gcs_path,
-                                            bool* completed, uint64* uploaded);
+  virtual absl::Status RequestUploadSessionStatus(const string& session_uri,
+                                                  uint64 file_size,
+                                                  const std::string& gcs_path,
+                                                  bool* completed,
+                                                  uint64* uploaded);
 
-  Status ParseGcsPathForScheme(StringPiece fname, string scheme,
-                               bool empty_object_ok, string* bucket,
-                               string* object);
+  absl::Status ParseGcsPathForScheme(absl::string_view fname, string scheme,
+                                     bool empty_object_ok, string* bucket,
+                                     string* object);
 
   /// \brief Splits a GCS path to a bucket and an object.
   ///
@@ -327,8 +334,9 @@
   /// "bucket-name" and "path/to/file.txt".
   /// If fname only contains the bucket and empty_object_ok = true, the returned
   /// object is empty.
-  virtual Status ParseGcsPath(StringPiece fname, bool empty_object_ok,
-                              string* bucket, string* object);
+  virtual absl::Status ParseGcsPath(absl::string_view fname,
+                                    bool empty_object_ok, string* bucket,
+                                    string* object);
 
   std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client_;
 
@@ -348,7 +356,7 @@
   /// \brief Checks if the bucket exists. Returns OK if the check succeeded.
   ///
   /// 'result' is set if the function returns OK. 'result' cannot be nullptr.
-  Status BucketExists(const string& bucket, bool* result);
+  absl::Status BucketExists(const string& bucket, bool* result);
 
   /// \brief Retrieves the GCS bucket location. Returns OK if the location was
   /// retrieved.
@@ -359,28 +367,28 @@
   /// This requires the bucket metadata permission.
   /// Repeated calls for the same bucket are cached so this function can be
   /// called frequently without causing an extra API call
-  Status GetBucketLocation(const string& bucket, string* location);
+  absl::Status GetBucketLocation(const string& bucket, string* location);
 
   /// \brief Check if the GCS buckets location is allowed with the current
   /// constraint configuration
-  Status CheckBucketLocationConstraint(const string& bucket);
+  absl::Status CheckBucketLocationConstraint(const string& bucket);
 
   /// \brief Given the input bucket `bucket`, fills `result_buffer` with the
   /// results of the metadata. Returns OK if the API call succeeds without
   /// error.
-  Status GetBucketMetadata(const string& bucket,
-                           std::vector<char>* result_buffer);
+  absl::Status GetBucketMetadata(const string& bucket,
+                                 std::vector<char>* result_buffer);
 
   /// \brief Checks if the object exists. Returns OK if the check succeeded.
   ///
   /// 'result' is set if the function returns OK. 'result' cannot be nullptr.
-  Status ObjectExists(const string& fname, const string& bucket,
-                      const string& object, bool* result);
+  absl::Status ObjectExists(const string& fname, const string& bucket,
+                            const string& object, bool* result);
 
   /// \brief Checks if the folder exists. Returns OK if the check succeeded.
   ///
   /// 'result' is set if the function returns OK. 'result' cannot be nullptr.
-  Status FolderExists(const string& dirname, bool* result);
+  absl::Status FolderExists(const string& dirname, bool* result);
 
   /// \brief Internal version of GetChildren with more knobs.
   ///
@@ -390,19 +398,19 @@
   /// If 'include_self_directory_marker' is true and there is a GCS directory
   /// marker at the path 'dir', GetChildrenBound will return an empty string
   /// as one of the children that represents this marker.
-  Status GetChildrenBounded(const string& dir, uint64 max_results,
-                            std::vector<string>* result, bool recursively,
-                            bool include_self_directory_marker);
+  absl::Status GetChildrenBounded(const string& dir, uint64 max_results,
+                                  std::vector<string>* result, bool recursively,
+                                  bool include_self_directory_marker);
 
   /// Retrieves file statistics assuming fname points to a GCS object. The data
   /// may be read from cache or from GCS directly.
-  Status StatForObject(const string& fname, const string& bucket,
-                       const string& object, GcsFileStat* stat);
+  absl::Status StatForObject(const string& fname, const string& bucket,
+                             const string& object, GcsFileStat* stat);
   /// Retrieves file statistics of file fname directly from GCS.
-  Status UncachedStatForObject(const string& fname, const string& bucket,
-                               const string& object, GcsFileStat* stat);
+  absl::Status UncachedStatForObject(const string& fname, const string& bucket,
+                                     const string& object, GcsFileStat* stat);
 
-  Status RenameObject(const string& src, const string& target);
+  absl::Status RenameObject(const string& src, const string& target);
 
   // Clear all the caches related to the file with name `filename`.
   void ClearFileCaches(const string& fname);
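The header changes above are signature rewrites: every virtual method that returned Status now returns absl::Status, and the overrides change in lockstep so that `override` still matches the base declaration. A small illustration of that constraint on a hypothetical interface (BlobStore and InMemoryBlobStore are invented for this sketch, not types from the patch):

// Base-class and override signatures must migrate together; otherwise the
// derived method no longer overrides and `override` fails to compile.
#include <string>
#include <vector>

#include "absl/status/status.h"

class BlobStore {
 public:
  virtual ~BlobStore() = default;
  virtual absl::Status GetChildren(const std::string& dir,
                                   std::vector<std::string>* result) = 0;
};

class InMemoryBlobStore : public BlobStore {
 public:
  absl::Status GetChildren(const std::string& dir,
                           std::vector<std::string>* result) override {
    result->clear();  // A real store would enumerate entries under `dir`.
    return absl::OkStatus();
  }
};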
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc
index 9221128..9d9d308 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc
@@ -17,7 +17,7 @@
 
 #include <fstream>
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/cloud/http_request_fake.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/str_util.h"
@@ -45,17 +45,17 @@
 
 class FakeAuthProvider : public AuthProvider {
  public:
-  Status GetToken(string* token) override {
+  absl::Status GetToken(string* token) override {
     *token = "fake_token";
-    return OkStatus();
+    return absl::OkStatus();
   }
 };
 
 class FakeZoneProvider : public ZoneProvider {
  public:
-  Status GetZone(string* zone) override {
+  absl::Status GetZone(string* zone) override {
     *zone = "us-east1-b";
-    return OkStatus();
+    return absl::OkStatus();
   }
 };
 
@@ -88,12 +88,12 @@
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
 
-  StringPiece filename;
+  absl::string_view filename;
   TF_EXPECT_OK(file->Name(&filename));
   EXPECT_EQ(filename, "gs://bucket/random_access.txt");
 
   char scratch[6];
-  StringPiece result;
+  absl::string_view result;
 
   // Read the first chunk.
   TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch));
@@ -135,12 +135,12 @@
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
 
-  StringPiece filename;
+  absl::string_view filename;
   TF_EXPECT_OK(file->Name(&filename));
   EXPECT_EQ(filename, "gs://bucket/random_access.txt");
 
   char scratch[6];
-  StringPiece result;
+  absl::string_view result;
 
   // Read the first chunk.
   TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch));
@@ -183,12 +183,12 @@
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
 
-  StringPiece filename;
+  absl::string_view filename;
   TF_EXPECT_OK(file->Name(&filename));
   EXPECT_EQ(filename, "gs://bucket/random_access.txt");
 
   char scratch[6];
-  StringPiece result;
+  absl::string_view result;
 
   // Read the first chunk.
   EXPECT_TRUE(
@@ -230,12 +230,12 @@
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
 
-  StringPiece filename;
+  absl::string_view filename;
   TF_EXPECT_OK(file->Name(&filename));
   EXPECT_EQ(filename, "gs://bucket/random_access.txt");
 
   char scratch[10];
-  StringPiece result;
+  absl::string_view result;
 
   // Read the first chunk.
   TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch));
@@ -271,12 +271,12 @@
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
 
-  StringPiece filename;
+  absl::string_view filename;
   TF_EXPECT_OK(file->Name(&filename));
   EXPECT_EQ(filename, "gs://bucket/random_access.txt");
 
   char scratch[5];
-  StringPiece result;
+  absl::string_view result;
 
   // Read the first chunk. Even though the backend response is out-of-range,
   // we should get a OK status since we're just reading the first 5 bytes.
@@ -323,12 +323,12 @@
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
 
-  StringPiece filename;
+  absl::string_view filename;
   TF_EXPECT_OK(file->Name(&filename));
   EXPECT_EQ(filename, "gs://bucket/random_access.txt");
 
   char scratch[5];
-  StringPiece result;
+  absl::string_view result;
 
   TF_EXPECT_OK(file->Read(1, sizeof(scratch), &result, scratch));
   EXPECT_EQ("12345", result);
@@ -365,12 +365,12 @@
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
 
-  StringPiece filename;
+  absl::string_view filename;
   TF_EXPECT_OK(file->Name(&filename));
   EXPECT_EQ(filename, "gs://bucket/random_access.txt");
 
   char scratch[10];
-  StringPiece result;
+  absl::string_view result;
 
   // Read the first chunk. Since the first read is out-of-range,
   // we don't cache the out-of-range flag and each subsequent read triggers a
@@ -413,12 +413,12 @@
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
 
-  StringPiece filename;
+  absl::string_view filename;
   TF_EXPECT_OK(file->Name(&filename));
   EXPECT_EQ(filename, "gs://bucket/random_access.txt");
 
   char scratch[10];
-  StringPiece result;
+  absl::string_view result;
 
   // Read the first chunk.
   EXPECT_TRUE(
@@ -574,7 +574,7 @@
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
 
   char small_scratch[3];
-  StringPiece result;
+  absl::string_view result;
 
   // Read the first chunk.
   TF_EXPECT_OK(file->Read(0, sizeof(small_scratch), &result, small_scratch));
@@ -629,7 +629,7 @@
       nullptr /* gcs additional header */, false /* compose append */);
 
   char scratch[100];
-  StringPiece result;
+  absl::string_view result;
   {
     // We are instantiating this in an enclosed scope to make sure after the
     // unique ptr goes out of scope, we can still access result.
@@ -716,7 +716,7 @@
       nullptr /* gcs additional header */, false /* compose append */);
 
   char scratch[100];
-  StringPiece result;
+  absl::string_view result;
   std::unique_ptr<RandomAccessFile> file;
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
@@ -766,7 +766,7 @@
       kTestTimeoutConfig, *kAllowedLocationsDefault,
       nullptr /* gcs additional header */, false /* compose append */);
   char scratch[100];
-  StringPiece result;
+  absl::string_view result;
   // There should only be two HTTP requests issued to GCS even though we iterate
   // this loop 10 times.  This shows that the underlying FileBlockCache persists
   // across file close/open boundaries.
@@ -841,7 +841,7 @@
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
 
   char scratch[5];
-  StringPiece result;
+  absl::string_view result;
 
   // First read.
   TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch));
@@ -908,7 +908,7 @@
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
 
   char scratch[6];
-  StringPiece result;
+  absl::string_view result;
 
   EXPECT_TRUE(
       errors::IsInternal(file->Read(0, sizeof(scratch), &result, scratch)));
@@ -972,7 +972,7 @@
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/path/writeable", nullptr, &rfile));
   char scratch[100];
-  StringPiece result;
+  absl::string_view result;
   TF_EXPECT_OK(rfile->Read(0, 4, &result, scratch));
   EXPECT_EQ("0123", result);
   // Open the writable file.
@@ -1107,7 +1107,7 @@
                            "Timeouts: 5 1 10\n"
                            "Header Content-Range: bytes */17\n"
                            "Put: yes\n",
-                           "", OkStatus(), nullptr, {}, 201),
+                           "", absl::OkStatus(), nullptr, {}, 201),
        new FakeHttpRequest(
            "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
            "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n"
@@ -1138,7 +1138,7 @@
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/path/writeable", nullptr, &rfile));
   char scratch[100];
-  StringPiece result;
+  absl::string_view result;
   TF_EXPECT_OK(rfile->Read(0, 4, &result, scratch));
   EXPECT_EQ("0123", result);
   // Now write to the same file. Once the write succeeds, the cached block will
@@ -1402,7 +1402,7 @@
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/path/appendable", nullptr, &rfile));
   char scratch[100];
-  StringPiece result;
+  absl::string_view result;
   TF_EXPECT_OK(rfile->Read(0, 8, &result, scratch));
   EXPECT_EQ("content1", result);
   // Closing the appendable file will flush its contents to GCS, triggering HTTP
@@ -1496,8 +1496,9 @@
   TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile(
       "gs://bucket/path/random_access.txt", nullptr, &region));
 
-  EXPECT_EQ(content, StringPiece(reinterpret_cast<const char*>(region->data()),
-                                 region->length()));
+  EXPECT_EQ(content,
+            absl::string_view(reinterpret_cast<const char*>(region->data()),
+                              region->length()));
 }
 
 TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) {
@@ -2262,7 +2263,7 @@
 
   // Do an initial read of the file to load its contents into the block cache.
   char scratch[100];
-  StringPiece result;
+  absl::string_view result;
   std::unique_ptr<RandomAccessFile> file;
   TF_EXPECT_OK(
       fs.NewRandomAccessFile("gs://bucket/path/file1.txt", nullptr, &file));
@@ -2656,7 +2657,7 @@
   // Do an initial read of the source and destination files to load their
   // contents into the block cache.
   char scratch[100];
-  StringPiece result;
+  absl::string_view result;
   std::unique_ptr<RandomAccessFile> src;
   std::unique_ptr<RandomAccessFile> dst;
   TF_EXPECT_OK(
@@ -3798,7 +3799,7 @@
       fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file));
 
   char scratch[6];
-  StringPiece result;
+  absl::string_view result;
 
   TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch));
   EXPECT_EQ("012345", result);
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_throttle_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_throttle_test.cc
index dfd8310..658629f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_throttle_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_throttle_test.cc
@@ -15,7 +15,7 @@
 
 #include "tsl/platform/cloud/gcs_throttle.h"
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/str_util.h"
 #include "tsl/platform/test.h"
 
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.cc
index f1b62fb..7f1f94d 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.cc
@@ -82,7 +82,7 @@
 }
 
 /// Returns the credentials file name from the env variable.
-Status GetEnvironmentVariableFileName(string* filename) {
+absl::Status GetEnvironmentVariableFileName(string* filename) {
   if (!filename) {
     return errors::FailedPrecondition("'filename' cannot be nullptr.");
   }
@@ -92,11 +92,11 @@
                                             " is not set or corrupt."));
   }
   *filename = result;
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 /// Returns the well known file produced by command 'gcloud auth login'.
-Status GetWellKnownFileName(string* filename) {
+absl::Status GetWellKnownFileName(string* filename) {
   if (!filename) {
     return errors::FailedPrecondition("'filename' cannot be nullptr.");
   }
@@ -118,7 +118,7 @@
         "Could not find the credentials file in the standard gcloud location.");
   }
   *filename = result;
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace
@@ -138,42 +138,42 @@
           std::move(compute_engine_metadata_client)),
       env_(env) {}
 
-Status GoogleAuthProvider::GetToken(string* t) {
+absl::Status GoogleAuthProvider::GetToken(string* t) {
   mutex_lock lock(mu_);
   const uint64 now_sec = env_->NowSeconds();
 
   if (now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) {
     *t = current_token_;
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   if (GetTokenForTesting().ok()) {
     *t = current_token_;
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   auto token_from_files_status = GetTokenFromFiles();
   if (token_from_files_status.ok()) {
     *t = current_token_;
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   char* no_gce_check_var = std::getenv(kNoGceCheck);
   bool skip_gce_check = no_gce_check_var != nullptr &&
                         absl::EqualsIgnoreCase(no_gce_check_var, "true");
-  Status token_from_gce_status;
+  absl::Status token_from_gce_status;
   if (skip_gce_check) {
     token_from_gce_status =
-        Status(absl::StatusCode::kCancelled,
-               strings::StrCat("GCE check skipped due to presence of $",
-                               kNoGceCheck, " environment variable."));
+        absl::Status(absl::StatusCode::kCancelled,
+                     strings::StrCat("GCE check skipped due to presence of $",
+                                     kNoGceCheck, " environment variable."));
   } else {
     token_from_gce_status = GetTokenFromGce();
   }
 
   if (token_from_gce_status.ok()) {
     *t = current_token_;
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   if (skip_gce_check) {
@@ -203,10 +203,10 @@
   }
   current_token_ = "";
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GoogleAuthProvider::GetTokenFromFiles() {
+absl::Status GoogleAuthProvider::GetTokenFromFiles() {
   string credentials_filename;
   if (!GetEnvironmentVariableFileName(&credentials_filename).ok() &&
       !GetWellKnownFileName(&credentials_filename).ok()) {
@@ -231,33 +231,33 @@
     return errors::FailedPrecondition(
         "Unexpected content of the JSON credentials file.");
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GoogleAuthProvider::GetTokenFromGce() {
+absl::Status GoogleAuthProvider::GetTokenFromGce() {
   std::vector<char> response_buffer;
   const uint64 request_timestamp_sec = env_->NowSeconds();
 
   TF_RETURN_IF_ERROR(compute_engine_metadata_client_->GetMetadata(
       kGceTokenPath, &response_buffer));
-  StringPiece response =
-      StringPiece(&response_buffer[0], response_buffer.size());
+  absl::string_view response =
+      absl::string_view(&response_buffer[0], response_buffer.size());
 
   TF_RETURN_IF_ERROR(oauth_client_->ParseOAuthResponse(
       response, request_timestamp_sec, &current_token_,
       &expiration_timestamp_sec_));
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GoogleAuthProvider::GetTokenForTesting() {
+absl::Status GoogleAuthProvider::GetTokenForTesting() {
   const char* token = std::getenv(kGoogleAuthTokenForTesting);
   if (!token) {
     return errors::NotFound("The env variable for testing was not set.");
   }
   expiration_timestamp_sec_ = UINT64_MAX;
   current_token_ = token;
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace tsl
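google_auth_provider.cc now spells out absl::Status when building a non-OK status with an explicit code (kCancelled for the skipped GCE check). A sketch of that construction; kExampleSkipVar and its value are hypothetical stand-ins, not the patch's kNoGceCheck:

// Constructing a non-OK absl::Status with an explicit code and message.
#include <cstdlib>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"

constexpr char kExampleSkipVar[] = "EXAMPLE_NO_GCE_CHECK";

absl::Status MaybeSkipCheck() {
  if (std::getenv(kExampleSkipVar) != nullptr) {
    return absl::Status(
        absl::StatusCode::kCancelled,
        absl::StrCat("Check skipped due to presence of $", kExampleSkipVar,
                     " environment variable."));
  }
  return absl::OkStatus();
}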
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.h
index 63b7ea6..38ab66d 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.h
@@ -40,20 +40,20 @@
   /// \brief Returns the short-term authentication bearer token.
   ///
   /// Safe for concurrent use by multiple threads.
-  Status GetToken(string* token) override;
+  absl::Status GetToken(string* token) override;
 
  private:
   /// \brief Gets the bearer token from files.
   ///
   /// Tries the file from $GOOGLE_APPLICATION_CREDENTIALS and the
   /// standard gcloud tool's location.
-  Status GetTokenFromFiles() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  absl::Status GetTokenFromFiles() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
   /// Gets the bearer token from Google Compute Engine environment.
-  Status GetTokenFromGce() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  absl::Status GetTokenFromGce() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
   /// Gets the bearer token from the system env variable, for testing purposes.
-  Status GetTokenForTesting() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  absl::Status GetTokenForTesting() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
   std::unique_ptr<OAuthClient> oauth_client_;
   std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client_;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider_test.cc
index 6f3072f..e7d6c4a 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider_test.cc
@@ -17,7 +17,7 @@
 
 #include <stdlib.h>
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/cloud/http_request_fake.h"
 #include "tsl/platform/path.h"
 #include "tsl/platform/test.h"
@@ -40,23 +40,24 @@
 
 class FakeOAuthClient : public OAuthClient {
  public:
-  Status GetTokenFromServiceAccountJson(
-      Json::Value json, StringPiece oauth_server_uri, StringPiece scope,
-      string* token, uint64* expiration_timestamp_sec) override {
-    provided_credentials_json = json;
-    *token = return_token;
-    *expiration_timestamp_sec = return_expiration_timestamp;
-    return OkStatus();
-  }
-
-  /// Retrieves a bearer token using a refresh token.
-  Status GetTokenFromRefreshTokenJson(
-      Json::Value json, StringPiece oauth_server_uri, string* token,
+  absl::Status GetTokenFromServiceAccountJson(
+      Json::Value json, absl::string_view oauth_server_uri,
+      absl::string_view scope, string* token,
       uint64* expiration_timestamp_sec) override {
     provided_credentials_json = json;
     *token = return_token;
     *expiration_timestamp_sec = return_expiration_timestamp;
-    return OkStatus();
+    return absl::OkStatus();
+  }
+
+  /// Retrieves a bearer token using a refresh token.
+  absl::Status GetTokenFromRefreshTokenJson(
+      Json::Value json, absl::string_view oauth_server_uri, string* token,
+      uint64* expiration_timestamp_sec) override {
+    provided_credentials_json = json;
+    *token = return_token;
+    *expiration_timestamp_sec = return_expiration_timestamp;
+    return absl::OkStatus();
   }
 
   string return_token;
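The tests above also pick up TF_EXPECT_OK from the relocated xla/tsl/lib/core/status_test_util.h, and the fakes' overrides now return absl::Status; the status-check macros are unaffected by the type rename. A minimal sketch of a test in that shape (ReturnsOk and the test name are hypothetical; the include paths are the ones used in the patch):

// Minimal test sketch after the status_test_util.h relocation.
#include "xla/tsl/lib/core/status_test_util.h"
#include "absl/status/status.h"
#include "tsl/platform/test.h"

namespace {

absl::Status ReturnsOk() { return absl::OkStatus(); }

TEST(StatusMacrosExample, AcceptsAbslStatus) {
  // TF_EXPECT_OK works unchanged now that the function returns absl::Status.
  TF_EXPECT_OK(ReturnsOk());
}

}  // namespace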
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request.h
index a3a3136..8102dd6 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request.h
@@ -85,7 +85,8 @@
 
     /// RecordResponse is called after the response has been received.
     virtual void RecordResponse(const HttpRequest* request, const string& uri,
-                                RequestMethod method, const Status& result) = 0;
+                                RequestMethod method,
+                                const absl::Status& result) = 0;
   };
 
   HttpRequest() {}
@@ -124,7 +125,8 @@
   ///
   /// The request body will be taken from the specified file starting from
   /// the given offset.
-  virtual Status SetPutFromFile(const string& body_filepath, size_t offset) = 0;
+  virtual absl::Status SetPutFromFile(const string& body_filepath,
+                                      size_t offset) = 0;
 
   /// Makes the request a PUT request with an empty body.
   virtual void SetPutEmptyBody() = 0;
@@ -169,7 +171,7 @@
   ///
   /// If the result buffer was defined, the response will be written there.
   /// The object is not designed to be re-used after Send() is executed.
-  virtual Status Send() = 0;
+  virtual absl::Status Send() = 0;
 
   // Url encodes str and returns a new string.
   virtual string EscapeString(const string& str) = 0;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request_fake.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request_fake.h
index ea1f487..869d2ab 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request_fake.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request_fake.h
@@ -21,7 +21,7 @@
 #include <string>
 #include <vector>
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/cloud/curl_http_request.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/macros.h"
@@ -38,12 +38,13 @@
  public:
   /// Return the response for the given request.
   FakeHttpRequest(const string& request, const string& response)
-      : FakeHttpRequest(request, response, OkStatus(), nullptr, {}, 200) {}
+      : FakeHttpRequest(request, response, absl::OkStatus(), nullptr, {}, 200) {
+  }
 
   /// Return the response with headers for the given request.
   FakeHttpRequest(const string& request, const string& response,
                   const std::map<string, string>& response_headers)
-      : FakeHttpRequest(request, response, OkStatus(), nullptr,
+      : FakeHttpRequest(request, response, absl::OkStatus(), nullptr,
                         response_headers, 200) {}
 
   /// \brief Return the response for the request and capture the POST body.
@@ -51,12 +52,12 @@
   /// Post body is not expected to be a part of the 'request' parameter.
   FakeHttpRequest(const string& request, const string& response,
                   string* captured_post_body)
-      : FakeHttpRequest(request, response, OkStatus(), captured_post_body, {},
-                        200) {}
+      : FakeHttpRequest(request, response, absl::OkStatus(), captured_post_body,
+                        {}, 200) {}
 
   /// \brief Return the response and the status for the given request.
   FakeHttpRequest(const string& request, const string& response,
-                  Status response_status, uint64 response_code)
+                  absl::Status response_status, uint64 response_code)
       : FakeHttpRequest(request, response, response_status, nullptr, {},
                         response_code) {}
 
@@ -65,7 +66,7 @@
   ///
   /// Post body is not expected to be a part of the 'request' parameter.
   FakeHttpRequest(const string& request, const string& response,
-                  Status response_status, string* captured_post_body,
+                  absl::Status response_status, string* captured_post_body,
                   const std::map<string, string>& response_headers,
                   uint64 response_code)
       : expected_request_(request),
@@ -88,20 +89,21 @@
     actual_request_ += "Auth Token: " + auth_token + "\n";
   }
   void SetDeleteRequest() override { actual_request_ += "Delete: yes\n"; }
-  Status SetPutFromFile(const string& body_filepath, size_t offset) override {
+  absl::Status SetPutFromFile(const string& body_filepath,
+                              size_t offset) override {
     std::ifstream stream(body_filepath);
     const string& content = string(std::istreambuf_iterator<char>(stream),
                                    std::istreambuf_iterator<char>())
                                 .substr(offset);
     actual_request_ += "Put body: " + content + "\n";
-    return OkStatus();
+    return absl::OkStatus();
   }
   void SetPostFromBuffer(const char* buffer, size_t size) override {
     if (captured_post_body_) {
       *captured_post_body_ = string(buffer, size);
     } else {
       actual_request_ +=
-          strings::StrCat("Post body: ", StringPiece(buffer, size), "\n");
+          strings::StrCat("Post body: ", absl::string_view(buffer, size), "\n");
     }
   }
   void SetPutEmptyBody() override { actual_request_ += "Put: yes\n"; }
@@ -123,7 +125,7 @@
   size_t GetResultBufferDirectBytesTransferred() override {
     return direct_result_bytes_transferred_;
   }
-  Status Send() override {
+  absl::Status Send() override {
     EXPECT_EQ(expected_request_, actual_request())
         << "Unexpected HTTP request.";
     if (buffer_) {
@@ -182,7 +184,7 @@
   string actual_uri_;
   string actual_request_;
   string response_;
-  Status response_status_;
+  absl::Status response_status_;
   string* captured_post_body_ = nullptr;
   std::map<string, string> response_headers_;
   uint64 response_code_ = 0;
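FakeHttpRequest keeps its delegating constructors; the only change above is that the default status they forward is now absl::OkStatus(). A sketch of that delegating pattern on a hypothetical FakeReply type, so the default status argument is visible in one place:

// Delegating constructors forwarding a default absl::OkStatus().
#include <string>
#include <utility>

#include "absl/status/status.h"

class FakeReply {
 public:
  explicit FakeReply(std::string body)
      : FakeReply(std::move(body), absl::OkStatus(), 200) {}

  FakeReply(std::string body, absl::Status status, int code)
      : body_(std::move(body)), status_(std::move(status)), code_(code) {}

  const absl::Status& status() const { return status_; }

 private:
  std::string body_;
  absl::Status status_;
  int code_;
};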
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.cc
index c983577..7480680 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.cc
@@ -49,8 +49,8 @@
 constexpr char kGrantType[] =
     "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer";
 
-Status ReadJsonValue(const Json::Value& json, const string& name,
-                     Json::Value* value) {
+absl::Status ReadJsonValue(const Json::Value& json, const string& name,
+                           Json::Value* value) {
   if (!value) {
     return errors::FailedPrecondition("'value' cannot be nullptr.");
   }
@@ -59,11 +59,11 @@
     return errors::FailedPrecondition(
         strings::StrCat("Couldn't read a JSON value '", name, "'."));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status ReadJsonString(const Json::Value& json, const string& name,
-                      string* value) {
+absl::Status ReadJsonString(const Json::Value& json, const string& name,
+                            string* value) {
   Json::Value json_value;
   TF_RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value));
   if (!json_value.isString()) {
@@ -71,11 +71,11 @@
         strings::StrCat("JSON value '", name, "' is not string."));
   }
   *value = json_value.asString();
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status ReadJsonInt(const Json::Value& json, const string& name,
-                   int64_t* value) {
+absl::Status ReadJsonInt(const Json::Value& json, const string& name,
+                         int64_t* value) {
   Json::Value json_value;
   TF_RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value));
   if (!json_value.isIntegral()) {
@@ -83,11 +83,11 @@
         strings::StrCat("JSON value '", name, "' is not integer."));
   }
   *value = json_value.asInt64();
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CreateSignature(RSA* private_key, StringPiece to_sign,
-                       string* signature) {
+absl::Status CreateSignature(RSA* private_key, absl::string_view to_sign,
+                             string* signature) {
   if (!private_key || !signature) {
     return errors::FailedPrecondition(
         "'private_key' and 'signature' cannot be nullptr.");
@@ -126,14 +126,15 @@
   if (EVP_DigestSignFinal(md_ctx.get(), sig.get(), &sig_len) != 1) {
     return errors::Internal("DigestFinal (signature compute) failed.");
   }
-  return Base64Encode(StringPiece(reinterpret_cast<char*>(sig.get()), sig_len),
-                      signature);
+  return Base64Encode(
+      absl::string_view(reinterpret_cast<char*>(sig.get()), sig_len),
+      signature);
 }
 
 /// Encodes a claim for a JSON web token (JWT) to make an OAuth request.
-Status EncodeJwtClaim(StringPiece client_email, StringPiece scope,
-                      StringPiece audience, uint64 request_timestamp_sec,
-                      string* encoded) {
+absl::Status EncodeJwtClaim(absl::string_view client_email,
+                            absl::string_view scope, absl::string_view audience,
+                            uint64 request_timestamp_sec, string* encoded) {
   // Step 1: create the JSON with the claim.
   Json::Value root;
   root["iss"] = Json::Value(client_email.data(),
@@ -155,7 +156,7 @@
 }
 
 /// Encodes a header for a JSON web token (JWT) to make an OAuth request.
-Status EncodeJwtHeader(StringPiece key_id, string* encoded) {
+absl::Status EncodeJwtHeader(absl::string_view key_id, string* encoded) {
   // Step 1: create the JSON with the header.
   Json::Value root;
   root["alg"] = kCryptoAlgorithm;
@@ -180,9 +181,9 @@
     std::unique_ptr<HttpRequest::Factory> http_request_factory, Env* env)
     : http_request_factory_(std::move(http_request_factory)), env_(env) {}
 
-Status OAuthClient::GetTokenFromServiceAccountJson(
-    Json::Value json, StringPiece oauth_server_uri, StringPiece scope,
-    string* token, uint64* expiration_timestamp_sec) {
+absl::Status OAuthClient::GetTokenFromServiceAccountJson(
+    Json::Value json, absl::string_view oauth_server_uri,
+    absl::string_view scope, string* token, uint64* expiration_timestamp_sec) {
   if (!token || !expiration_timestamp_sec) {
     return errors::FailedPrecondition(
         "'token' and 'expiration_timestamp_sec' cannot be nullptr.");
@@ -228,15 +229,15 @@
   request->SetResultBuffer(&response_buffer);
   TF_RETURN_IF_ERROR(request->Send());
 
-  StringPiece response =
-      StringPiece(response_buffer.data(), response_buffer.size());
+  absl::string_view response =
+      absl::string_view(response_buffer.data(), response_buffer.size());
   TF_RETURN_IF_ERROR(ParseOAuthResponse(response, request_timestamp_sec, token,
                                         expiration_timestamp_sec));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status OAuthClient::GetTokenFromRefreshTokenJson(
-    Json::Value json, StringPiece oauth_server_uri, string* token,
+absl::Status OAuthClient::GetTokenFromRefreshTokenJson(
+    Json::Value json, absl::string_view oauth_server_uri, string* token,
     uint64* expiration_timestamp_sec) {
   if (!token || !expiration_timestamp_sec) {
     return errors::FailedPrecondition(
@@ -260,17 +261,17 @@
   request->SetResultBuffer(&response_buffer);
   TF_RETURN_IF_ERROR(request->Send());
 
-  StringPiece response =
-      StringPiece(response_buffer.data(), response_buffer.size());
+  absl::string_view response =
+      absl::string_view(response_buffer.data(), response_buffer.size());
   TF_RETURN_IF_ERROR(ParseOAuthResponse(response, request_timestamp_sec, token,
                                         expiration_timestamp_sec));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status OAuthClient::ParseOAuthResponse(StringPiece response,
-                                       uint64 request_timestamp_sec,
-                                       string* token,
-                                       uint64* expiration_timestamp_sec) {
+absl::Status OAuthClient::ParseOAuthResponse(absl::string_view response,
+                                             uint64 request_timestamp_sec,
+                                             string* token,
+                                             uint64* expiration_timestamp_sec) {
   if (!token || !expiration_timestamp_sec) {
     return errors::FailedPrecondition(
         "'token' and 'expiration_timestamp_sec' cannot be nullptr.");
@@ -292,7 +293,7 @@
   *expiration_timestamp_sec = request_timestamp_sec + expires_in;
   TF_RETURN_IF_ERROR(ReadJsonString(root, "access_token", token));
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace tsl
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.h
index 895c2d0..19d8b4f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.h
@@ -37,20 +37,20 @@
   ///
   /// Retrieves the authentication bearer token using a JSON file
   /// with the client's private key.
-  virtual Status GetTokenFromServiceAccountJson(
-      Json::Value json, StringPiece oauth_server_uri, StringPiece scope,
-      string* token, uint64* expiration_timestamp_sec);
+  virtual absl::Status GetTokenFromServiceAccountJson(
+      Json::Value json, absl::string_view oauth_server_uri,
+      absl::string_view scope, string* token, uint64* expiration_timestamp_sec);
 
   /// Retrieves a bearer token using a refresh token.
-  virtual Status GetTokenFromRefreshTokenJson(Json::Value json,
-                                              StringPiece oauth_server_uri,
-                                              string* token,
-                                              uint64* expiration_timestamp_sec);
+  virtual absl::Status GetTokenFromRefreshTokenJson(
+      Json::Value json, absl::string_view oauth_server_uri, string* token,
+      uint64* expiration_timestamp_sec);
 
   /// Parses the JSON response with the token from an OAuth 2.0 server.
-  virtual Status ParseOAuthResponse(StringPiece response,
-                                    uint64 request_timestamp_sec, string* token,
-                                    uint64* expiration_timestamp_sec);
+  virtual absl::Status ParseOAuthResponse(absl::string_view response,
+                                          uint64 request_timestamp_sec,
+                                          string* token,
+                                          uint64* expiration_timestamp_sec);
 
  private:
   std::unique_ptr<HttpRequest::Factory> http_request_factory_;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client_test.cc
index 8979f44..dc4c116 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client_test.cc
@@ -20,7 +20,7 @@
 #include <openssl/bio.h>
 #include <openssl/evp.h>
 #include <openssl/pem.h>
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/base64.h"
 #include "tsl/platform/cloud/http_request_fake.h"
 #include "tsl/platform/env.h"
@@ -118,7 +118,7 @@
   EXPECT_EQ(13920, expiration_timestamp);
 
   // Now look at the JWT claim that was sent to the OAuth server.
-  StringPiece grant_type, assertion;
+  absl::string_view grant_type, assertion;
   ASSERT_TRUE(strings::Scanner(post_body)
                   .OneLiteral("grant_type=")
                   .RestartCapture()
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.cc
index f16ab81..57330d1 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.cc
@@ -68,12 +68,12 @@
 }
 
 /// Move the block to the front of the LRU list if it isn't already there.
-Status RamFileBlockCache::UpdateLRU(const Key& key,
-                                    const std::shared_ptr<Block>& block) {
+absl::Status RamFileBlockCache::UpdateLRU(const Key& key,
+                                          const std::shared_ptr<Block>& block) {
   mutex_lock lock(mu_);
   if (block->timestamp == 0) {
     // The block was evicted from another thread. Allow it to remain evicted.
-    return OkStatus();
+    return absl::OkStatus();
   }
   if (block->lru_iterator != lru_list_.begin()) {
     lru_list_.erase(block->lru_iterator);
@@ -95,11 +95,11 @@
 
   Trim();
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status RamFileBlockCache::MaybeFetch(const Key& key,
-                                     const std::shared_ptr<Block>& block) {
+absl::Status RamFileBlockCache::MaybeFetch(
+    const Key& key, const std::shared_ptr<Block>& block) {
   bool downloaded_block = false;
   auto reconcile_state =
       absl::MakeCleanup([this, &downloaded_block, &key, &block] {
@@ -123,7 +123,7 @@
   // Loop until either block content is successfully fetched, or our request
   // encounters an error.
   mutex_lock l(block->mu);
-  Status status = OkStatus();
+  absl::Status status = absl::OkStatus();
   while (true) {
     switch (block->state) {
       case FetchState::ERROR:
@@ -155,23 +155,24 @@
       case FetchState::FETCHING:
         block->cond_var.wait_for(l, std::chrono::seconds(60));
         if (block->state == FetchState::FINISHED) {
-          return OkStatus();
+          return absl::OkStatus();
         }
         // Re-loop in case of errors.
         break;
       case FetchState::FINISHED:
-        return OkStatus();
+        return absl::OkStatus();
     }
   }
   return errors::Internal(
       "Control flow should never reach the end of RamFileBlockCache::Fetch.");
 }
 
-Status RamFileBlockCache::Read(const string& filename, size_t offset, size_t n,
-                               char* buffer, size_t* bytes_transferred) {
+absl::Status RamFileBlockCache::Read(const string& filename, size_t offset,
+                                     size_t n, char* buffer,
+                                     size_t* bytes_transferred) {
   *bytes_transferred = 0;
   if (n == 0) {
-    return OkStatus();
+    return absl::OkStatus();
   }
   if (!IsCacheEnabled() || (n > max_bytes_)) {
     // The cache is effectively disabled, so we pass the read through to the
@@ -226,7 +227,7 @@
     }
   }
   *bytes_transferred = total_bytes_transferred;
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 bool RamFileBlockCache::ValidateAndUpdateFileSignature(const string& filename,
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.h
index 76cf7eb..627cf6f 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.h
@@ -45,9 +45,9 @@
   /// cache is constructed. The returned Status should be OK as long as the
   /// read from the remote filesystem succeeded (similar to the semantics of the
   /// read(2) system call).
-  typedef std::function<Status(const string& filename, size_t offset,
-                               size_t buffer_size, char* buffer,
-                               size_t* bytes_transferred)>
+  typedef std::function<absl::Status(const string& filename, size_t offset,
+                                     size_t buffer_size, char* buffer,
+                                     size_t* bytes_transferred)>
       BlockFetcher;
 
   RamFileBlockCache(size_t block_size, size_t max_bytes, uint64 max_staleness,
@@ -88,8 +88,8 @@
   ///    placed in `out`.
   /// 4) OK otherwise (i.e. the read succeeded, and at least one byte was placed
   ///    in `out`).
-  Status Read(const string& filename, size_t offset, size_t n, char* buffer,
-              size_t* bytes_transferred) override;
+  absl::Status Read(const string& filename, size_t offset, size_t n,
+                    char* buffer, size_t* bytes_transferred) override;
 
   // Validate the given file signature with the existing file signature in the
   // cache. Returns true if the signature doesn't change or the file doesn't
@@ -197,14 +197,14 @@
   /// Look up a Key in the block cache.
   std::shared_ptr<Block> Lookup(const Key& key) TF_LOCKS_EXCLUDED(mu_);
 
-  Status MaybeFetch(const Key& key, const std::shared_ptr<Block>& block)
+  absl::Status MaybeFetch(const Key& key, const std::shared_ptr<Block>& block)
       TF_LOCKS_EXCLUDED(mu_);
 
   /// Trim the block cache to make room for another entry.
   void Trim() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
   /// Update the LRU iterator for the block at `key`.
-  Status UpdateLRU(const Key& key, const std::shared_ptr<Block>& block)
+  absl::Status UpdateLRU(const Key& key, const std::shared_ptr<Block>& block)
       TF_LOCKS_EXCLUDED(mu_);
 
   /// Remove all blocks of a file, with mu_ already held.
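The BlockFetcher callback type above now names absl::Status as its return type; any lambda with a matching signature still converts to it. A sketch mirroring that callback shape (the fetcher body and main() are hypothetical, written only to show the conversion):

// A std::function callback returning absl::Status, as BlockFetcher now does.
#include <cstddef>
#include <cstring>
#include <functional>
#include <string>

#include "absl/status/status.h"

using BlockFetcher =
    std::function<absl::Status(const std::string& filename, size_t offset,
                               size_t buffer_size, char* buffer,
                               size_t* bytes_transferred)>;

int main() {
  BlockFetcher fetcher = [](const std::string& filename, size_t offset,
                            size_t n, char* buffer,
                            size_t* bytes_transferred) {
    std::memset(buffer, 'x', n);  // Pretend every requested byte was fetched.
    *bytes_transferred = n;
    return absl::OkStatus();
  };
  char buf[16];
  size_t transferred = 0;
  const absl::Status s = fetcher("file", 0, sizeof(buf), buf, &transferred);
  return (s.ok() && transferred == sizeof(buf)) ? 0 : 1;
}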
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache_test.cc
index 5d17d73..cc71601 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache_test.cc
@@ -17,7 +17,7 @@
 
 #include <cstring>
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/blocking_counter.h"
 #include "tsl/platform/cloud/now_seconds_env.h"
 #include "tsl/platform/env.h"
@@ -27,12 +27,12 @@
 namespace tsl {
 namespace {
 
-Status ReadCache(RamFileBlockCache* cache, const string& filename,
-                 size_t offset, size_t n, std::vector<char>* out) {
+absl::Status ReadCache(RamFileBlockCache* cache, const string& filename,
+                       size_t offset, size_t n, std::vector<char>* out) {
   out->clear();
   out->resize(n, 0);
   size_t bytes_transferred = 0;
-  Status status =
+  absl::Status status =
       cache->Read(filename, offset, n, out->data(), &bytes_transferred);
   EXPECT_LE(bytes_transferred, n);
   out->resize(bytes_transferred, n);
@@ -43,7 +43,7 @@
   auto fetcher = [](const string& filename, size_t offset, size_t n,
                     char* buffer, size_t* bytes_transferred) {
     // Do nothing.
-    return OkStatus();
+    return absl::OkStatus();
   };
   RamFileBlockCache cache1(0, 0, 0, fetcher);
   RamFileBlockCache cache2(16, 0, 0, fetcher);
@@ -63,7 +63,7 @@
     calls++;
     memset(buffer, 'x', n);
     *bytes_transferred = n;
-    return OkStatus();
+    return absl::OkStatus();
   };
   string filename = "file";
   RamFileBlockCache cache(16, 32, 0, fetcher);
@@ -99,7 +99,7 @@
     calls++;
     memset(buffer, 'x', got_n);
     *bytes_transferred = got_n;
-    return OkStatus();
+    return absl::OkStatus();
   };
   // If block_size, max_bytes, or both are zero, or want_n is larger than
   // max_bytes the cache is a pass-through.
@@ -136,7 +136,7 @@
     } else {
       *bytes_transferred = 0;
     }
-    return OkStatus();
+    return absl::OkStatus();
   };
   for (size_t block_size = 2; block_size <= 4; block_size++) {
     // Make a cache of N-byte block size (1 block) and verify that reads of
@@ -181,7 +181,7 @@
     calls.insert(offset);
     memset(buffer, 'x', n);
     *bytes_transferred = n;
-    return OkStatus();
+    return absl::OkStatus();
   };
   const uint32 block_count = 256;
   RamFileBlockCache cache(block_size, block_count * block_size, 0, fetcher);
@@ -222,7 +222,7 @@
       second_block = true;
     }
     *bytes_transferred = bytes_to_copy;
-    return OkStatus();
+    return absl::OkStatus();
   };
   RamFileBlockCache cache(block_size, block_size, 0, fetcher);
   std::vector<char> out;
@@ -233,7 +233,7 @@
   // Reading at offset file_size + 4 will read the second block (since the read
   // at file_size + 4 = 28 will be aligned to an offset of 16) but will return
   // OutOfRange because the offset is past the end of the 24-byte file.
-  Status status = ReadCache(&cache, "", file_size + 4, 4, &out);
+  absl::Status status = ReadCache(&cache, "", file_size + 4, 4, &out);
   EXPECT_EQ(status.code(), error::OUT_OF_RANGE);
   EXPECT_TRUE(second_block);
   // Reading the second full block will return 8 bytes, from a cache hit.
@@ -255,7 +255,7 @@
     EXPECT_GE(n, 1);
     memset(buffer, 'x', 1);
     *bytes_transferred = 1;
-    return OkStatus();
+    return absl::OkStatus();
   };
   RamFileBlockCache cache(block_size, 2 * block_size, 0, fetcher);
   std::vector<char> out;
@@ -264,7 +264,7 @@
   EXPECT_EQ(out.size(), 1);
   // Now read the first block; this should yield an INTERNAL error because we
   // had already cached a partial block at a later position.
-  Status status = ReadCache(&cache, "", 0, block_size, &out);
+  absl::Status status = ReadCache(&cache, "", 0, block_size, &out);
   EXPECT_EQ(status.code(), error::INTERNAL);
 }
 
@@ -282,7 +282,7 @@
     }
     memset(buffer, 'x', n);
     *bytes_transferred = n;
-    return OkStatus();
+    return absl::OkStatus();
   };
   const uint32 block_count = 2;
   RamFileBlockCache cache(block_size, block_count * block_size, 0, fetcher);
@@ -324,7 +324,7 @@
     calls++;
     memset(buffer, 'x', n);
     *bytes_transferred = n;
-    return OkStatus();
+    return absl::OkStatus();
   };
   std::vector<char> out;
   std::unique_ptr<NowSecondsEnv> env(new NowSecondsEnv);
@@ -369,7 +369,7 @@
     }
     memset(buffer, c, n);
     *bytes_transferred = n;
-    return OkStatus();
+    return absl::OkStatus();
   };
   // This cache has space for 4 blocks; we'll read from two files.
   const size_t n = 3;
@@ -426,7 +426,7 @@
     calls++;
     memset(buffer, 'x', n);
     *bytes_transferred = n;
-    return OkStatus();
+    return absl::OkStatus();
   };
   std::vector<char> out;
   // Our fake environment is initialized with the current timestamp.
@@ -493,7 +493,7 @@
     }
     memset(buffer, 'x', n);
     *bytes_transferred = n;
-    return OkStatus();
+    return absl::OkStatus();
   };
   const int block_size = 8;
   RamFileBlockCache cache(block_size, 2 * callers * block_size, 0, fetcher);
@@ -529,7 +529,7 @@
     notification.Notify();
     // Wait for other thread to issue read.
     Env::Default()->SleepForMicroseconds(100000);  // 0.1 secs
-    return OkStatus();
+    return absl::OkStatus();
   };
   RamFileBlockCache cache(block_size, block_size, 0, fetcher);
   // Fork off thread for parallel read.
@@ -554,7 +554,7 @@
     calls++;
     memset(buffer, 'x', n);
     *bytes_transferred = n;
-    return OkStatus();
+    return absl::OkStatus();
   };
   RamFileBlockCache cache(16, 32, 0, fetcher);
   std::vector<char> out;
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.cc
index fecba6b..62f8258 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.cc
@@ -34,7 +34,7 @@
 
 // Only implements one special case of RFC 3339 which is returned by
 // GCS API, e.g. 2016-04-29T23:15:24.896Z.
-Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec) {
+absl::Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec) {
   tm parsed{0};
   float seconds;
   if (sscanf(time.c_str(), "%4d-%2d-%2dT%2d:%2d:%fZ", &(parsed.tm_year),
@@ -52,7 +52,7 @@
                 static_cast<int64_t>(std::floor((seconds - int_seconds) *
                                                 kNanosecondsPerSecond));
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace tsl
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.h
index 5eb116c..4dd2d29 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.h
@@ -22,7 +22,7 @@
 
 /// Parses the timestamp in RFC 3339 format and returns it
 /// as nanoseconds since epoch.
-Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec);
+absl::Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec);
 
 }  // namespace tsl
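
A usage sketch, using the GCS-style timestamp from the comment in time_util.cc
above:

    int64_t mtime_nsec = 0;
    absl::Status status =
        ParseRfc3339Time("2016-04-29T23:15:24.896Z", &mtime_nsec);
    // On success, mtime_nsec holds the timestamp as nanoseconds since the
    // Unix epoch.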
 
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util_test.cc
index 3a96555..6b54787 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util_test.cc
@@ -15,7 +15,7 @@
 
 #include "tsl/platform/cloud/time_util.h"
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/test.h"
 
 namespace tsl {
diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/zone_provider.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/zone_provider.h
index 8c000e0..14b64ea 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/cloud/zone_provider.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/zone_provider.h
@@ -34,9 +34,9 @@
   /// Returns an empty string in the case where the zone does not match the
   /// expected format
   /// Safe for concurrent use by multiple threads.
-  virtual Status GetZone(string* zone) = 0;
+  virtual absl::Status GetZone(string* zone) = 0;
 
-  static Status GetZone(ZoneProvider* provider, string* zone) {
+  static absl::Status GetZone(ZoneProvider* provider, string* zone) {
     if (!provider) {
       return errors::Internal("Zone provider is required.");
     }
diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc b/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc
index a1f3eba..a1934f8 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc
@@ -26,7 +26,7 @@
 
 namespace tsl {
 
-string RocmRoot() {
+std::string RocmRoot() {
 #if TENSORFLOW_USE_ROCM
   if (const char* rocm_path_env = std::getenv("ROCM_PATH")) {
     VLOG(3) << "ROCM root = " << rocm_path_env;
@@ -40,6 +40,12 @@
 #endif
 }
 
-string RocdlRoot() { return io::JoinPath(RocmRoot(), "amdgcn/bitcode"); }
+std::string RocdlRoot() {
+  if (const char* device_lib_path_env = std::getenv("HIP_DEVICE_LIB_PATH")) {
+    return device_lib_path_env;
+  } else {
+    return io::JoinPath(RocmRoot(), "amdgcn/bitcode");
+  }
+}
 
 }  // namespace tsl
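
With this change the device-library directory can be overridden through the
HIP_DEVICE_LIB_PATH environment variable; a small sketch (the path is only an
example value):

    // If HIP_DEVICE_LIB_PATH is set, RocdlRoot() returns it verbatim;
    // otherwise it falls back to <RocmRoot()>/amdgcn/bitcode.
    setenv("HIP_DEVICE_LIB_PATH", "/opt/rocm/amdgcn/bitcode", /*overwrite=*/1);
    std::string bitcode_dir = tsl::RocdlRoot();
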
diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc
index 8477cdb..33792c8 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc
@@ -17,7 +17,7 @@
 
 #include <fstream>
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/str_util.h"
 #include "tsl/platform/test.h"
 
diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc b/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc
index 5d55ec3..0024168 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc
@@ -18,7 +18,7 @@
 #include <fstream>
 
 #include "absl/time/time.h"
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/str_util.h"
diff --git a/third_party/xla/third_party/tsl/tsl/platform/subprocess_test.cc b/third_party/xla/third_party/tsl/tsl/platform/subprocess_test.cc
index 1b1bbcb..1b01a8c 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/subprocess_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/subprocess_test.cc
@@ -20,7 +20,7 @@
 #include <algorithm>
 #include <string>
 
-#include "tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/path.h"
 #include "tsl/platform/strcat.h"
 #include "tsl/platform/test.h"
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h
index 478dae8..ef30366 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h
+++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h
@@ -46,9 +46,9 @@
   ProfilerLock& operator=(const ProfilerLock&) = delete;
 
   // Movable.
-  ProfilerLock(ProfilerLock&& other)
+  ProfilerLock(ProfilerLock&& other) noexcept
       : active_(std::exchange(other.active_, false)) {}
-  ProfilerLock& operator=(ProfilerLock&& other) {
+  ProfilerLock& operator=(ProfilerLock&& other) noexcept {
     active_ = std::exchange(other.active_, false);
     return *this;
   }
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h
index 75c2902..da9fe21 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h
+++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h
@@ -146,8 +146,8 @@
   }
 
   // Movable.
-  TraceMe(TraceMe&& other) { *this = std::move(other); }
-  TraceMe& operator=(TraceMe&& other) {
+  TraceMe(TraceMe&& other) noexcept { *this = std::move(other); }
+  TraceMe& operator=(TraceMe&& other) noexcept {
 #if !defined(IS_MOBILE_PLATFORM)
     if (TF_PREDICT_FALSE(other.start_time_ != kUntracedActivity)) {
       name_.Emplace(std::move(other.name_).Consume());
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD
index 203657d..39113eb 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD
@@ -363,6 +363,7 @@
         ":tpu_xplane_utils",
         ":xplane_schema",
         ":xplane_utils",
+        ":xplane_visitor",
         "//tsl/platform:test",
         "//tsl/platform:test_main",
         "//tsl/profiler/protobuf:xplane_proto_cc",
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc
index 19841f5..9274a1d 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc
@@ -17,6 +17,7 @@
 #include <optional>
 #include <vector>
 
+#include "absl/strings/string_view.h"
 #include "tsl/platform/regexp.h"
 #include "tsl/profiler/protobuf/xplane.pb.h"
 #include "tsl/profiler/utils/xplane_schema.h"
@@ -48,5 +49,11 @@
   return std::nullopt;
 }
 
+std::optional<int> GetSparseCoreId(absl::string_view plane_name) {
+  std::optional<int> core_id;
+  RE2::FullMatch(plane_name, {kSparseCorePlaneRegex}, &core_id);
+  return core_id;
+}
+
 }  // namespace profiler
 }  // namespace tsl
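
Per kSparseCorePlaneRegex (added in xplane_schema.cc below), the new helper
extracts the trailing core id from plane names of the form
"/device:TPU:<n> SparseCore <m>"; for example:

    std::optional<int> id =
        GetSparseCoreId("/device:TPU:0 SparseCore 1");  // contains 1
    std::optional<int> none =
        GetSparseCoreId("/device:TPU:0");               // std::nullopt
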
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h
index f3a150c..2fb7c67 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h
@@ -36,6 +36,10 @@
 // TensorCore plane name.
 std::optional<int> GetTensorCoreId(absl::string_view plane_name);
 
+// Get SparseCore Id from SparseCore plane name if plane name is a valid
+// SparseCore plane name.
+std::optional<int> GetSparseCoreId(absl::string_view plane_name);
+
 }  // namespace profiler
 }  // namespace tsl
 
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc
index a385c77..e5bcd73 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc
@@ -21,11 +21,13 @@
 #include "tsl/profiler/protobuf/xplane.pb.h"
 #include "tsl/profiler/utils/xplane_schema.h"
 #include "tsl/profiler/utils/xplane_utils.h"
+#include "tsl/profiler/utils/xplane_visitor.h"
 
 namespace tsl {
 namespace profiler {
 namespace {
 
+using ::testing::Optional;
 using ::testing::UnorderedElementsAre;
 
 TEST(TpuXPlaneUtilsTest, GetTensorCoreXPlanesFromXSpace) {
@@ -65,6 +67,22 @@
       GetTensorCoreId(absl::StrCat("/prefix", TpuPlaneName(0))).has_value());
 }
 
+TEST(TpuXPlaneUtilsTest, GetSparseCorePlanesFromXSpace) {
+  XSpace space;
+  XPlane* p1 = FindOrAddMutablePlaneWithName(&space, TpuPlaneName(0));
+  XPlane* p2 = FindOrAddMutablePlaneWithName(&space, TpuPlaneName(1));
+  XPlane* p3 = FindOrAddMutablePlaneWithName(
+      &space, absl::StrCat(TpuPlaneName(0), " SparseCore 0"));
+  XPlane* p4 = FindOrAddMutablePlaneWithName(
+      &space, absl::StrCat(TpuPlaneName(0), " SparseCore 1"));
+
+  EXPECT_THAT(FindTensorCorePlanes(space), UnorderedElementsAre(p1, p2));
+  EXPECT_THAT(FindPlanesWithPrefix(space, kTpuPlanePrefix),
+              UnorderedElementsAre(p1, p2, p3, p4));
+  EXPECT_THAT(GetSparseCoreId(p3->name()), Optional(0));
+  EXPECT_THAT(GetSparseCoreId(p4->name()), Optional(1));
+}
+
 }  // namespace
 }  // namespace profiler
 }  // namespace tsl
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc
index 33de2b0..2cd8aaa 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc
@@ -33,6 +33,8 @@
 const absl::string_view kTpuPlanePrefix = "/device:TPU:";
 const absl::string_view kTpuNonCorePlaneNamePrefix = "#Chip";
 const char kTpuPlaneRegex[] = {"/device:TPU:([0-9]*)$"};
+const char kSparseCorePlaneRegex[] = {
+    "/device:TPU:[0-9]+ SparseCore ([0-9]+)$"};
 // TODO(b/195582092): change it to /device:custom once all literals are
 // migrated.
 const absl::string_view kCustomPlanePrefix = "/device:CUSTOM:";
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h
index 2e693b4..edf808b 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h
@@ -41,6 +41,8 @@
 TF_CONST_INIT extern const absl::string_view kTpuPlanePrefix;
 // Regex for XPlanes that contain TensorCore planes.
 TF_CONST_INIT extern const char kTpuPlaneRegex[];
+// Regex for XPlanes that contain SparseCore planes.
+TF_CONST_INIT extern const char kSparseCorePlaneRegex[];
 // Name prefix of XPlane that contains custom device events.
 TF_CONST_INIT extern const absl::string_view kCustomPlanePrefix;
 // Name prefix of XPlane that contains TPU non-core events such as HBM, ICI etc.
diff --git a/third_party/xla/third_party/tsl/workspace2.bzl b/third_party/xla/third_party/tsl/workspace2.bzl
index 0a2993f..8b8f1de 100644
--- a/third_party/xla/third_party/tsl/workspace2.bzl
+++ b/third_party/xla/third_party/tsl/workspace2.bzl
@@ -160,13 +160,13 @@
 
     tf_http_archive(
         name = "mkl_dnn_acl_compatible",
-        build_file = "//tensorflow/third_party/mkl_dnn:mkldnn_acl.BUILD",
+        build_file = "//third_party/mkl_dnn:mkldnn_acl.BUILD",
         patch_file = [
-            "//tensorflow/third_party/mkl_dnn:onednn_acl_threadcap.patch",
-            "//tensorflow/third_party/mkl_dnn:onednn_acl_reorder.patch",
-            "//tensorflow/third_party/mkl_dnn:onednn_acl_thread_local_scheduler.patch",
-            "//tensorflow/third_party/mkl_dnn:onednn_acl_fp32_bf16_reorder.patch",
-            "//tensorflow/third_party/mkl_dnn:onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch",
+            "//third_party/mkl_dnn:onednn_acl_threadcap.patch",
+            "//third_party/mkl_dnn:onednn_acl_reorder.patch",
+            "//third_party/mkl_dnn:onednn_acl_thread_local_scheduler.patch",
+            "//third_party/mkl_dnn:onednn_acl_fp32_bf16_reorder.patch",
+            "//third_party/mkl_dnn:onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch",
         ],
         sha256 = "2f76b407ef8893cca71340f88cd800019a1f14f8ac1bbdbb89a84be1370b52e3",
         strip_prefix = "oneDNN-3.2.1",
@@ -591,6 +591,22 @@
         urls = tf_mirror_urls("https://github.com/google/glog/archive/refs/tags/v0.4.0.tar.gz"),
     )
 
+    tf_http_archive(
+        name = "spirv_headers",
+        sha256 = "11d835c60297b26532c05c3f3b581ba7a2787b5ae7399e94f72c392169216f11",
+        strip_prefix = "SPIRV-Headers-b73e168ca5e123dcf3dea8a34b19a5130f421ae1",
+        urls = tf_mirror_urls("https://github.com/KhronosGroup/SPIRV-Headers/archive/b73e168ca5e123dcf3dea8a34b19a5130f421ae1.tar.gz"),
+    )
+
+    tf_http_archive(
+        name = "spirv_llvm_translator",
+        sha256 = "d499769f4fd1e0ce9d4dbd3622ee7e3e641b5623dcdf811521e3e7c0bdb1e6c2",
+        strip_prefix = "SPIRV-LLVM-Translator-dad1f0eaab8047a4f73c50ed5f3d1694b78aae97",
+        build_file = "//third_party/spirv_llvm_translator:spirv_llvm_translator.BUILD",
+        patch_file = ["//third_party/spirv_llvm_translator:spirv_llvm_translator.patch"],
+        urls = tf_mirror_urls("https://github.com/KhronosGroup/SPIRV-LLVM-Translator/archive/dad1f0eaab8047a4f73c50ed5f3d1694b78aae97.tar.gz"),
+    )
+
 # buildifier: disable=unnamed-macro
 def workspace():
     # Check the bazel version before executing any repository rules, in case
diff --git a/third_party/xla/third_party/uv/uv.BUILD b/third_party/xla/third_party/uv/uv.BUILD
index 75a2df3..b04383a 100644
--- a/third_party/xla/third_party/uv/uv.BUILD
+++ b/third_party/xla/third_party/uv/uv.BUILD
@@ -11,7 +11,48 @@
 
 cc_library(
     name = "uv",
-    srcs = glob(["src/*.c"]),
+    srcs = [
+        "src/fs-poll.c",
+        "src/idna.c",
+        "src/inet.c",
+        "src/random.c",
+        "src/strscpy.c",
+        "src/threadpool.c",
+        "src/timer.c",
+        "src/uv-common.c",
+        "src/uv-data-getter-setters.c",
+        "src/version.c",
+    ] + [
+        "src/unix/async.c",
+        "src/unix/core.c",
+        "src/unix/dl.c",
+        "src/unix/fs.c",
+        "src/unix/getaddrinfo.c",
+        "src/unix/getnameinfo.c",
+        "src/unix/loop.c",
+        "src/unix/loop-watcher.c",
+        "src/unix/pipe.c",
+        "src/unix/poll.c",
+        "src/unix/process.c",
+        "src/unix/random-devurandom.c",
+        "src/unix/signal.c",
+        "src/unix/stream.c",
+        "src/unix/tcp.c",
+        "src/unix/thread.c",
+        "src/unix/tty.c",
+        "src/unix/udp.c",
+    ] + select({
+        "@platforms//os:osx": [
+            "src/unix/bsd-ifaddrs.c",
+            "src/unix/darwin.c",
+            "src/unix/darwin-proctitle.c",
+            "src/unix/fsevents.c",
+            "src/unix/kqueue.c",
+            "src/unix/proctitle.c",
+            "src/unix/random-getentropy.c",
+        ],
+    }),
+    # TODO: Add Linux, etc. as in https://github.com/libuv/libuv/blob/v1.38.0/CMakeLists.txt.
     hdrs = [
         "include/uv.h",
     ],
diff --git a/third_party/xla/warnings.bazelrc b/third_party/xla/warnings.bazelrc
index a5711c9..ac21913 100644
--- a/third_party/xla/warnings.bazelrc
+++ b/third_party/xla/warnings.bazelrc
@@ -93,7 +93,5 @@
 build:warnings --copt=-Wnon-virtual-dtor
 build:warnings --copt=-Wimplicit-fallthrough
 build:warnings --copt=-Wthread-safety-analysis
-build:warnings --copt=-Wno-tautological-type-limit-compare
-build:warnings --copt=-Wno-nullability-completeness
 build:warnings --copt=-Wno-builtin-macro-redefined
 build:warnings --copt=-Wno-macro-redefined
diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl
index dea8d37..da31481 100644
--- a/third_party/xla/workspace2.bzl
+++ b/third_party/xla/workspace2.bzl
@@ -16,7 +16,6 @@
 load("//third_party/shardy:workspace.bzl", shardy = "repo")
 load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
 load("//third_party/triton:workspace.bzl", triton = "repo")
-load("//third_party/uv:workspace.bzl", uv = "repo")
 
 def _initialize_third_party():
     """ Load third party repositories.  See above load() statements. """
@@ -28,7 +27,6 @@
     shardy()
     stablehlo()
     triton()
-    uv()
 
 # Define all external repositories required by TensorFlow
 def _tf_repositories():
@@ -52,9 +50,9 @@
         name = "cudnn_frontend_archive",
         build_file = "//third_party:cudnn_frontend.BUILD",
         patch_file = ["//third_party:cudnn_frontend_header_fix.patch"],
-        sha256 = "281789777ac296f5f8215a7c4bd066de8816d240eb44c760788beebf8d25a99f",
-        strip_prefix = "cudnn-frontend-1.5.1",
-        urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.1.zip"),
+        sha256 = "313f4a38a54e578ed668809697c96754497141ba62332bdcc019faaeb1e3c6f6",
+        strip_prefix = "cudnn-frontend-1.6.0",
+        urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.6.0.zip"),
     )
 
     tf_http_archive(
diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD
index 5c72564..5d4476a 100644
--- a/third_party/xla/xla/BUILD
+++ b/third_party/xla/xla/BUILD
@@ -516,6 +516,7 @@
         ":test",
         ":util",
         ":xla_data_proto_cc",
+        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
@@ -1048,7 +1049,9 @@
         ":array2d",
         ":array3d",
         ":array4d",
+        ":literal",
         ":literal_util",
+        ":shape_util",
         ":util",
         ":window_util",
         ":xla_data_proto_cc",
@@ -1056,9 +1059,11 @@
         "//xla/client:xla_builder",
         "//xla/hlo/evaluator:hlo_evaluator",
         "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_module_config",
         "//xla/service:shape_inference",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/functional:function_ref",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/lib/math:math_util",
         "@local_tsl//tsl/platform:logging",
@@ -1072,7 +1077,9 @@
         ":array2d",
         ":array3d",
         ":array4d",
+        ":error_spec",
         ":literal",
+        ":literal_util",
         ":reference_util",
         ":test",
         ":xla_data_proto_cc",
@@ -1254,10 +1261,10 @@
     deps = [
         ":autotune_results_proto_cc",
         ":autotuning_proto_cc",
+        "//xla/tsl/lib/strings:proto_serialization",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
-        "@local_tsl//tsl/lib/strings:proto_serialization",
     ],
 )
 
@@ -1312,6 +1319,33 @@
     visibility = ["//visibility:public"],
 )
 
+cc_library(
+    name = "sort_json",
+    srcs = ["sort_json.cc"],
+    hdrs = ["sort_json.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "sort_json_test",
+    srcs = ["sort_json_test.cc"],
+    deps = [
+        ":sort_json",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
 # Needed to workaround https://github.com/bazelbuild/bazel/issues/21519
 alias(
     name = "bazel_issue_21519",
diff --git a/third_party/xla/xla/array.h b/third_party/xla/xla/array.h
index 6a6f505..03c5f3b 100644
--- a/third_party/xla/xla/array.h
+++ b/third_party/xla/xla/array.h
@@ -603,12 +603,12 @@
       std::fill(data.get(), data.get() + size, init);
     }
 
-    OwnedBuffer(OwnedBuffer&& other)
+    OwnedBuffer(OwnedBuffer&& other) noexcept
         : data(std::move(other.data)), size(other.size) {
       other.size = 0;
     }
 
-    OwnedBuffer& operator=(OwnedBuffer&& other) {
+    OwnedBuffer& operator=(OwnedBuffer&& other) noexcept {
       data = std::move(other.data);
       size = other.size;
       other.size = 0;
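
A general note on the noexcept additions here and in profiler_lock.h and
traceme.h above, illustrated with a hypothetical type rather than one from
this change: std::vector only moves its elements during reallocation when the
element's move constructor is noexcept (or the type is non-copyable);
otherwise it copies them to preserve the strong exception guarantee.

    #include <utility>
    #include <vector>

    // Hypothetical payload type (not part of this change).
    struct Payload {
      std::vector<char> bytes;
      Payload() = default;
      Payload(const Payload&) = default;
      Payload(Payload&& other) noexcept : bytes(std::move(other.bytes)) {}
      Payload& operator=(Payload&& other) noexcept {
        bytes = std::move(other.bytes);
        return *this;
      }
    };

    int main() {
      std::vector<Payload> v;
      v.emplace_back();
      v.emplace_back();  // reallocation uses the noexcept move constructor
    }
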
diff --git a/third_party/xla/xla/autotune_result_wrapper.cc b/third_party/xla/xla/autotune_result_wrapper.cc
index 855c8aa..ee92f17 100644
--- a/third_party/xla/xla/autotune_result_wrapper.cc
+++ b/third_party/xla/xla/autotune_result_wrapper.cc
@@ -21,7 +21,7 @@
 #include "absl/status/status.h"
 #include "xla/autotune_results.pb.h"
 #include "xla/autotuning.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 
 namespace xla {
 
diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h
index 0aee389b..1228b3b 100644
--- a/third_party/xla/xla/backends/interpreter/executor.h
+++ b/third_party/xla/xla/backends/interpreter/executor.h
@@ -86,15 +86,6 @@
   absl::Status Init() override { return absl::OkStatus(); }
 
   int device_ordinal() const override { return device_ordinal_; };
-  absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
-      const MultiKernelLoaderSpec &spec) override {
-    return absl::UnimplementedError("Not Implemented");
-  }
-  absl::Status Launch(Stream *stream, const ThreadDim &thread_dims,
-                      const BlockDim &block_dims, const Kernel &kernel,
-                      const KernelArgs &args) override {
-    return absl::UnimplementedError("Not Implemented");
-  }
 
   DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override;
   void Deallocate(DeviceMemoryBase *mem) override;
@@ -103,7 +94,7 @@
       uint64_t size) override {
     return std::make_unique<HostMemoryAllocation>(new char[size], size, this);
   }
-  void HostMemoryDeallocate(void *mem, uint64_t size) override {
+  void HostMemoryDeallocate(void *mem) override {
     delete[] static_cast<char *>(mem);
   }
 
diff --git a/third_party/xla/xla/backends/interpreter/platform.cc b/third_party/xla/xla/backends/interpreter/platform.cc
index 8b77eb1..1a541f1 100644
--- a/third_party/xla/xla/backends/interpreter/platform.cc
+++ b/third_party/xla/xla/backends/interpreter/platform.cc
@@ -54,6 +54,13 @@
   return GetExecutor(config);
 }
 
+absl::StatusOr<StreamExecutor*> XlaInterpreterPlatform::FindExisting(
+    int ordinal) {
+  StreamExecutorConfig config;
+  config.ordinal = ordinal;
+  return executor_cache_.Get(config);
+}
+
 absl::StatusOr<StreamExecutor*> XlaInterpreterPlatform::GetExecutor(
     const StreamExecutorConfig& config) {
   return executor_cache_.GetOrCreate(
diff --git a/third_party/xla/xla/backends/interpreter/platform.h b/third_party/xla/xla/backends/interpreter/platform.h
index da3d18e..08a8d21 100644
--- a/third_party/xla/xla/backends/interpreter/platform.h
+++ b/third_party/xla/xla/backends/interpreter/platform.h
@@ -49,9 +49,13 @@
 
   absl::StatusOr<StreamExecutor*> GetExecutor(
       const StreamExecutorConfig& config) override;
+  absl::StatusOr<StreamExecutor*> FindExisting(int ordinal) override;
 
+  // Returns a device constructed with the options specified in "config" without
+  // looking in or storing to the Platform's executor cache.
+  // Ownership IS transferred to the caller.
   absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
-      const StreamExecutorConfig& config) override;
+      const StreamExecutorConfig& config);
 
  private:
   // This platform's name.
diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc
index e53b017..65f8d57 100644
--- a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc
+++ b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc
@@ -618,7 +618,7 @@
   const std::vector<RocmTracerEvent> ApiActivityInfoExchange()
       TF_EXCLUSIVE_LOCKS_REQUIRED(event_maps_mutex_);
 
-  absl::flat_hash_map<uint32_t, PerDeviceCollector> per_device_collector_;
+  absl::node_hash_map<uint32_t, PerDeviceCollector> per_device_collector_;
 };
 //==========
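
A brief note on the container swap above: unlike absl::flat_hash_map,
absl::node_hash_map stores each value in a separately allocated node, so
pointers and references to PerDeviceCollector entries stay valid across
rehashing, which matters if collector code holds long-lived pointers into the
map. A minimal sketch with a hypothetical value type:

    #include <cstdint>
    #include "absl/container/node_hash_map.h"

    struct Collector {  // hypothetical stand-in for PerDeviceCollector
      int events = 0;
    };

    void Sketch() {
      absl::node_hash_map<uint32_t, Collector> per_device;
      Collector* dev0 = &per_device[0];  // pointer into the map
      for (uint32_t i = 1; i < 1000; ++i) per_device[i];  // triggers rehashes
      dev0->events++;  // still valid: values live in stable, separate nodes
    }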
 
diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD
index 1e7f1bf..d322caf 100644
--- a/third_party/xla/xla/client/BUILD
+++ b/third_party/xla/xla/client/BUILD
@@ -53,6 +53,7 @@
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/lib/math:math_util",
         "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:status",
     ],
 )
 
@@ -71,24 +72,23 @@
     srcs = ["client.cc"],
     hdrs = ["client.h"],
     deps = [
-        ":global_data",
         ":xla_computation",
-        "//xla:debug_options_flags",
         "//xla:execution_options_util",
         "//xla:literal",
+        "//xla:shape_util",
         "//xla:status_macros",
         "//xla:types",
+        "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla:xla_proto_cc",
         "//xla/service",
         "//xla/service:hlo_proto_cc",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:protobuf",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -111,9 +111,9 @@
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:env",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:statusor",
     ],
 )
@@ -126,14 +126,18 @@
         ":client",
         ":executable_build_options",
         ":xla_computation",
+        "//xla:debug_options_flags",
         "//xla:executable_run_options",
+        "//xla:literal",
         "//xla:shape_tree",
+        "//xla:shape_util",
+        "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/service:backend",
         "//xla/service:compiler",
+        "//xla/service:computation_layout",
         "//xla/service:dump",
         "//xla/service:executable",
-        "//xla/service:hlo_proto_cc",
         "//xla/service:local_service",
         "//xla/service:maybe_owning_device_memory",
         "//xla/service:shaped_buffer",
@@ -141,9 +145,14 @@
         "//xla/service:stream_pool",
         "//xla/stream_executor",
         "//xla/stream_executor:device_memory_allocator",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -154,12 +163,18 @@
     deps = [
         ":client",
         ":xla_computation",
+        "//xla:shape_util",
         "//xla:status_macros",
         "//xla:xla_data_proto_cc",
+        "//xla:xla_proto_cc",
         "//xla/service:compile_only_service",
         "//xla/service:compiler",
+        "//xla/service:hlo_module_config",
         "//xla/stream_executor",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
         "@llvm-project//llvm:TargetParser",
     ],
@@ -174,18 +189,20 @@
     deps = [
         ":compile_only_client",
         ":local_client",
-        "//xla:status_macros",
         "//xla:types",
-        "//xla:util",
-        "//xla/service:backend",
+        "//xla/service",
         "//xla/service:compile_only_service",
         "//xla/service:local_service",
         "//xla/service:platform_util",
         "//xla/stream_executor",
         "//xla/stream_executor:device_memory_allocator",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/synchronization",
         "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -200,6 +217,7 @@
         "//xla:types",
         "//xla:util",
         "//xla:xla_data_proto_cc",
+        "@com_google_absl//absl/log:check",
     ],
 )
 
@@ -214,6 +232,7 @@
         "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/service:hlo_proto_cc",
+        "@com_google_absl//absl/status:statusor",
     ],
 )
 
@@ -233,8 +252,13 @@
         "//xla:xla_data_proto_cc",
         "//xla/hlo/evaluator:hlo_evaluator",
         "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_module_config",
         "//xla/service:hlo_proto_cc",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/hash",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
@@ -305,6 +329,7 @@
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
         "//xla/service:hlo_parser",
+        "//xla/service:hlo_proto_cc",
         "//xla/service:pattern_matcher",
         "//xla/service:pattern_matcher_gmock",
         "//xla/tests:xla_internal_test_main",
diff --git a/third_party/xla/xla/client/client.cc b/third_party/xla/xla/client/client.cc
index 6e89947..f5e174d 100644
--- a/third_party/xla/xla/client/client.cc
+++ b/third_party/xla/xla/client/client.cc
@@ -22,16 +22,22 @@
 #include <vector>
 
 #include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
 #include "xla/client/xla_computation.h"
-#include "xla/debug_options_flags.h"
 #include "xla/execution_options_util.h"
+#include "xla/layout.h"
 #include "xla/literal.h"
+#include "xla/service/hlo.pb.h"
+#include "xla/service/service.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "xla/status_macros.h"
-#include "xla/types.h"
-#include "tsl/platform/errors.h"
+#include "xla/util.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
 #include "tsl/platform/logging.h"
 #include "tsl/platform/protobuf.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 
diff --git a/third_party/xla/xla/client/client.h b/third_party/xla/xla/client/client.h
index f3eacbc..1201568 100644
--- a/third_party/xla/xla/client/client.h
+++ b/third_party/xla/xla/client/client.h
@@ -21,12 +21,15 @@
 #include <utility>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
 #include "xla/client/xla_computation.h"
+#include "xla/layout.h"
 #include "xla/literal.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/service/service.h"
+#include "xla/shape.h"
 #include "xla/types.h"
 #include "xla/xla.pb.h"
 #include "xla/xla_data.pb.h"
diff --git a/third_party/xla/xla/client/client_library.cc b/third_party/xla/xla/client/client_library.cc
index b55691b..476208d 100644
--- a/third_party/xla/xla/client/client_library.cc
+++ b/third_party/xla/xla/client/client_library.cc
@@ -20,11 +20,17 @@
 #include <set>
 #include <utility>
 
-#include "xla/service/backend.h"
+#include "absl/synchronization/mutex.h"
+#include "xla/client/compile_only_client.h"
+#include "xla/client/local_client.h"
+#include "xla/service/compile_only_service.h"
+#include "xla/service/local_service.h"
 #include "xla/service/platform_util.h"
-#include "xla/status_macros.h"
-#include "xla/util.h"
+#include "xla/service/service.h"
+#include "xla/stream_executor/platform.h"
 #include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 
diff --git a/third_party/xla/xla/client/client_library.h b/third_party/xla/xla/client/client_library.h
index db86732..0e4f3a9a 100644
--- a/third_party/xla/xla/client/client_library.h
+++ b/third_party/xla/xla/client/client_library.h
@@ -28,13 +28,16 @@
 #include <string>
 #include <vector>
 
+#include "absl/base/thread_annotations.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/status/statusor.h"
+#include "absl/synchronization/mutex.h"
 #include "xla/client/compile_only_client.h"
 #include "xla/client/local_client.h"
 #include "xla/service/compile_only_service.h"
 #include "xla/service/local_service.h"
 #include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/stream_executor.h"
 #include "xla/types.h"
 
diff --git a/third_party/xla/xla/client/compile_only_client.cc b/third_party/xla/xla/client/compile_only_client.cc
index 23c07b37..1aa6a4f 100644
--- a/third_party/xla/xla/client/compile_only_client.cc
+++ b/third_party/xla/xla/client/compile_only_client.cc
@@ -18,9 +18,19 @@
 #include <memory>
 #include <vector>
 
+#include "absl/log/check.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/TargetParser/Triple.h"
+#include "xla/service/compile_only_service.h"
+#include "xla/service/compiler.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
 #include "xla/status_macros.h"
+#include "xla/xla.pb.h"
 
 namespace xla {
 
diff --git a/third_party/xla/xla/client/compile_only_client.h b/third_party/xla/xla/client/compile_only_client.h
index 8dde8c8..2dcb977 100644
--- a/third_party/xla/xla/client/compile_only_client.h
+++ b/third_party/xla/xla/client/compile_only_client.h
@@ -20,11 +20,16 @@
 #include <vector>
 
 #include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "xla/client/client.h"
 #include "xla/client/xla_computation.h"
 #include "xla/service/compile_only_service.h"
 #include "xla/service/compiler.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
 #include "xla/stream_executor/stream_executor.h"
+#include "xla/xla.pb.h"
 #include "xla/xla_data.pb.h"
 
 namespace xla {
diff --git a/third_party/xla/xla/client/executable_build_options.cc b/third_party/xla/xla/client/executable_build_options.cc
index 46b810d..e484319 100644
--- a/third_party/xla/xla/client/executable_build_options.cc
+++ b/third_party/xla/xla/client/executable_build_options.cc
@@ -27,14 +27,13 @@
 #include "xla/debug_options_flags.h"
 #include "xla/execution_options_util.h"
 #include "xla/layout_util.h"
+#include "xla/pjrt/compile_options.pb.h"
 #include "xla/service/compilation_environments.h"
 #include "xla/service/computation_placer.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/util.h"
 #include "xla/xla.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
 
 namespace xla {
diff --git a/third_party/xla/xla/client/executable_build_options.h b/third_party/xla/xla/client/executable_build_options.h
index c849230..f1129d6 100644
--- a/third_party/xla/xla/client/executable_build_options.h
+++ b/third_party/xla/xla/client/executable_build_options.h
@@ -24,6 +24,10 @@
 
 #include "absl/algorithm/container.h"
 #include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "xla/pjrt/compile_options.pb.h"
 #include "xla/pjrt/distributed/key_value_store_interface.h"
 #include "xla/service/compilation_environments.h"
diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD
index 28f6c27..7e2cc0d 100644
--- a/third_party/xla/xla/client/lib/BUILD
+++ b/third_party/xla/xla/client/lib/BUILD
@@ -227,6 +227,7 @@
 
 xla_test(
     name = "math_test",
+    timeout = "long",
     srcs = ["math_test.cc"],
     backend_tags = {
         # Times out.
diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc
index c388dc4..05056ba 100644
--- a/third_party/xla/xla/client/local_client.cc
+++ b/third_party/xla/xla/client/local_client.cc
@@ -20,13 +20,37 @@
 #include <utility>
 #include <vector>
 
+#include "absl/log/check.h"
+#include "absl/log/log.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "xla/client/executable_build_options.h"
 #include "xla/client/xla_computation.h"
+#include "xla/debug_options_flags.h"
+#include "xla/executable_run_options.h"
+#include "xla/literal.h"
 #include "xla/service/backend.h"
+#include "xla/service/compiler.h"
+#include "xla/service/computation_layout.h"
 #include "xla/service/dump.h"
+#include "xla/service/executable.h"
+#include "xla/service/maybe_owning_device_memory.h"
 #include "xla/service/service_executable_run_options.h"
+#include "xla/service/shaped_buffer.h"
 #include "xla/service/source_map_util.h"
 #include "xla/service/stream_pool.h"
+#include "xla/shape.h"
+#include "xla/shape_tree.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 
 using xla::source_map_util::InvalidParameterArgument;
 
diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h
index 236ebe0..07c6e6e 100644
--- a/third_party/xla/xla/client/local_client.h
+++ b/third_party/xla/xla/client/local_client.h
@@ -21,21 +21,28 @@
 #include <utility>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
 #include "xla/client/client.h"
 #include "xla/client/executable_build_options.h"
 #include "xla/client/xla_computation.h"
 #include "xla/executable_run_options.h"
+#include "xla/literal.h"
+#include "xla/service/backend.h"
 #include "xla/service/compiler.h"
 #include "xla/service/executable.h"
 #include "xla/service/local_service.h"
 #include "xla/service/maybe_owning_device_memory.h"
+#include "xla/service/service_executable_run_options.h"
 #include "xla/service/shaped_buffer.h"
+#include "xla/service/stream_pool.h"
 #include "xla/shape_tree.h"
 #include "xla/stream_executor/device_memory_allocator.h"
 #include "xla/stream_executor/stream_executor.h"
 #include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 
diff --git a/third_party/xla/xla/client/padding.cc b/third_party/xla/xla/client/padding.cc
index 37abc59..daf26d5 100644
--- a/third_party/xla/xla/client/padding.cc
+++ b/third_party/xla/xla/client/padding.cc
@@ -20,9 +20,11 @@
 #include <vector>
 
 #include "absl/status/status.h"
+#include "absl/types/span.h"
 #include "xla/util.h"
 #include "tsl/lib/math/math_util.h"
 #include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
 
 namespace xla {
 
diff --git a/third_party/xla/xla/client/padding.h b/third_party/xla/xla/client/padding.h
index e715226..e717183 100644
--- a/third_party/xla/xla/client/padding.h
+++ b/third_party/xla/xla/client/padding.h
@@ -19,6 +19,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
 #include "xla/types.h"
diff --git a/third_party/xla/xla/client/sharding_builder.cc b/third_party/xla/xla/client/sharding_builder.cc
index e2324d6..7b179b8 100644
--- a/third_party/xla/xla/client/sharding_builder.cc
+++ b/third_party/xla/xla/client/sharding_builder.cc
@@ -17,6 +17,12 @@
 
 #include <vector>
 
+#include "absl/log/check.h"
+#include "xla/shape.h"
+#include "xla/shape_tree.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
 namespace xla {
 namespace sharding_builder {
 
diff --git a/third_party/xla/xla/client/sharding_builder.h b/third_party/xla/xla/client/sharding_builder.h
index 98d6512..eef395e 100644
--- a/third_party/xla/xla/client/sharding_builder.h
+++ b/third_party/xla/xla/client/sharding_builder.h
@@ -19,6 +19,7 @@
 #include <vector>
 
 #include "xla/array.h"
+#include "xla/shape.h"
 #include "xla/shape_tree.h"
 #include "xla/shape_util.h"
 #include "xla/types.h"
diff --git a/third_party/xla/xla/client/value_inference.cc b/third_party/xla/xla/client/value_inference.cc
index 1ba694a..2f0b6e2 100644
--- a/third_party/xla/xla/client/value_inference.cc
+++ b/third_party/xla/xla/client/value_inference.cc
@@ -21,11 +21,17 @@
 #include <utility>
 #include <vector>
 
+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/hash/hash.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
 #include "absl/types/span.h"
+#include "xla/client/xla_builder.h"
 #include "xla/comparison_util.h"
+#include "xla/hlo/evaluator/hlo_evaluator.h"
 #include "xla/hlo/ir/dfs_hlo_visitor.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_opcode.h"
@@ -33,6 +39,8 @@
 #include "xla/literal_util.h"
 #include "xla/primitive_util.h"
 #include "xla/service/hlo.pb.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/status_macros.h"
 #include "xla/util.h"
diff --git a/third_party/xla/xla/client/value_inference.h b/third_party/xla/xla/client/value_inference.h
index 6f1685f..84c1c99 100644
--- a/third_party/xla/xla/client/value_inference.h
+++ b/third_party/xla/xla/client/value_inference.h
@@ -19,6 +19,9 @@
 #include <utility>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/log/check.h"
+#include "absl/status/statusor.h"
+#include "absl/types/span.h"
 #include "xla/client/xla_builder.h"
 #include "xla/hlo/evaluator/hlo_evaluator.h"
 #include "xla/hlo/ir/dfs_hlo_visitor.h"
@@ -26,6 +29,7 @@
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
+#include "xla/shape_util.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
 
diff --git a/third_party/xla/xla/client/xla_builder.cc b/third_party/xla/xla/client/xla_builder.cc
index 0b679af..98e7dad 100644
--- a/third_party/xla/xla/client/xla_builder.cc
+++ b/third_party/xla/xla/client/xla_builder.cc
@@ -1450,6 +1450,42 @@
   });
 }
 
+XlaOp XlaBuilder::CompositeCall(const XlaComputation& computation,
+                                absl::Span<const XlaOp> operands,
+                                const std::string& name,
+                                std::optional<absl::string_view> attributes,
+                                std::optional<int64_t> version) {
+  return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    std::vector<const Shape*> operand_shape_ptrs;
+    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
+    absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+                      [](const Shape& shape) { return &shape; });
+    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
+                        computation.GetProgramShape());
+    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
+                                         operand_shape_ptrs,
+                                         /*to_apply=*/called_program_shape));
+    *instr.mutable_shape() = shape.ToProto();
+
+    AddCalledComputation(computation, &instr);
+    instr.set_is_composite(true);
+
+    TF_ASSIGN_OR_RETURN(
+        XlaOp instruction,
+        AddInstruction(std::move(instr), HloOpcode::kCall, operands));
+    TF_RETURN_IF_ERROR(
+        SetInstructionFrontendAttribute(instruction, "composite.name", name));
+    TF_RETURN_IF_ERROR(SetInstructionFrontendAttribute(
+        instruction, "composite.attributes",
+        attributes.has_value() ? std::string(*attributes) : "{}"));
+    TF_RETURN_IF_ERROR(SetInstructionFrontendAttribute(
+        instruction, "composite.version",
+        version.has_value() ? std::to_string(*version) : "0"));
+    return instruction;
+  });
+}
+
 XlaOp XlaBuilder::Parameter(
     int64_t parameter_number, const Shape& shape, const std::string& name,
     const std::vector<bool>& replicated_at_leaf_buffers) {
@@ -5196,6 +5232,14 @@
   return builder->Call(computation, operands);
 }
 
+XlaOp CompositeCall(XlaBuilder* builder, const XlaComputation& computation,
+                    absl::Span<const XlaOp> operands, const std::string& name,
+                    std::optional<absl::string_view> attributes,
+                    std::optional<int64_t> version) {
+  return builder->CompositeCall(computation, operands, name, attributes,
+                                version);
+}
+
 XlaOp CustomCall(
     XlaBuilder* builder, const std::string& call_target_name,
     absl::Span<const XlaOp> operands, const Shape& shape,
diff --git a/third_party/xla/xla/client/xla_builder.h b/third_party/xla/xla/client/xla_builder.h
index c1192bf..53683cf 100644
--- a/third_party/xla/xla/client/xla_builder.h
+++ b/third_party/xla/xla/client/xla_builder.h
@@ -731,6 +731,12 @@
   XlaOp Call(const XlaComputation& computation,
              absl::Span<const XlaOp> operands);
 
+  XlaOp CompositeCall(
+      const XlaComputation& computation, absl::Span<const XlaOp> operands,
+      const std::string& name,
+      std::optional<absl::string_view> attributes = std::nullopt,
+      std::optional<int64_t> version = std::nullopt);
+
   XlaOp CustomCall(
       const std::string& call_target_name, absl::Span<const XlaOp> operands,
       const Shape& shape_with_layout, const std::string& opaque,
@@ -1378,6 +1384,14 @@
                       const std::string& outfeed_config);
   friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
                     absl::Span<const XlaOp> operands);
+
+  friend XlaOp CompositeCall(XlaBuilder* builder,
+                             const XlaComputation& computation,
+                             absl::Span<const XlaOp> operands,
+                             const std::string& name,
+                             std::optional<absl::string_view> attributes,
+                             std::optional<int64_t> version);
+
   friend XlaOp CustomCall(
       XlaBuilder* builder, const std::string& call_target_name,
       absl::Span<const XlaOp> operands, const Shape& shape,
@@ -2305,6 +2319,12 @@
 XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
            absl::Span<const XlaOp> operands);
 
+// Enqueues a composite call instruction onto the computation.
+XlaOp CompositeCall(XlaBuilder* builder, const XlaComputation& computation,
+                    absl::Span<const XlaOp> operands, const std::string& name,
+                    std::optional<absl::string_view> attributes = std::nullopt,
+                    std::optional<int64_t> version = std::nullopt);
+
 // Enqueues a custom call instruction onto the computation. A custom call
 // invokes code external to XLA. The |operands| are passed to the external code,
 // and the external code is expected to produce a result of the given
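
A minimal usage sketch of the new CompositeCall free function, mirroring the
builder tests added below (the computation and names are illustrative): it
emits a kCall marked is_composite and records the composite.name,
composite.attributes, and composite.version frontend attributes, defaulting
to "{}" and "0" when the optional arguments are not supplied.

    XlaBuilder b("composite_example");
    const Shape scalar = ShapeUtil::MakeShape(F32, {});

    // Decomposition computation: a scalar add, as in the tests below.
    XlaBuilder sum_builder("add");
    Add(Parameter(&sum_builder, 0, scalar, "a"),
        Parameter(&sum_builder, 1, scalar, "b"));
    XlaComputation add_computation = sum_builder.Build().value();

    XlaOp x = Parameter(&b, 0, scalar, "x");
    XlaOp y = Parameter(&b, 1, scalar, "y");
    // Uses the default composite.attributes ("{}") and composite.version ("0").
    CompositeCall(&b, add_computation, {x, y}, /*name=*/"my.add");
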
diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/client/xla_builder_test.cc
index 9828d50..8ecf243 100644
--- a/third_party/xla/xla/client/xla_builder_test.cc
+++ b/third_party/xla/xla/client/xla_builder_test.cc
@@ -48,6 +48,7 @@
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/layout_util.h"
+#include "xla/service/hlo.pb.h"
 #include "xla/service/hlo_parser.h"
 #include "xla/service/pattern_matcher.h"
 #include "xla/service/pattern_matcher_gmock.h"
@@ -330,6 +331,176 @@
                                 m::Call(m::Constant(), m::Constant()))));
 }
 
+TEST(XlaBuilderTest, CompositeCall) {
+  XlaBuilder b(TestName());
+  const Shape shape = ShapeUtil::MakeShape(F32, {});
+  const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+  XlaBuilder bsum(TestName());
+  Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+  TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+  std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+                                 Parameter(&b, 1, shape, "arg1")};
+  CompositeCall(&b, computation, absl::MakeSpan(operands),
+                /*name=*/"foo.bar",
+                /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor<i32>}",
+                /*version=*/1);
+
+  TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
+  EXPECT_THAT(GetRoot(*module),
+              GmockMatch(m::Call(m::Parameter(), m::Parameter())));
+}
+
+TEST(XlaBuilderTest, CompositeCallFrontendAttributesStayLocal) {
+  XlaBuilder b(TestName());
+  const Shape shape = ShapeUtil::MakeShape(F32, {});
+  const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+  XlaBuilder bsum(TestName());
+  Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+  TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+  std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+                                 Parameter(&b, 1, shape, "arg1")};
+  CompositeCall(&b, computation, absl::MakeSpan(operands),
+                /*name=*/"foo.bar",
+                /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor<i32>}",
+                /*version=*/1);
+  Add(operands[0], operands[1]);
+
+  TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
+  EXPECT_TRUE(GetRoot(*module)->frontend_attributes().map().empty());
+}
+
+TEST(XlaBuilderTest, CompositeCallMissingName) {
+  XlaBuilder b(TestName());
+  const Shape shape = ShapeUtil::MakeShape(F32, {});
+  const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+  XlaBuilder bsum(TestName());
+  Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+  TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+  std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+                                 Parameter(&b, 1, shape, "arg1")};
+  CompositeCall(&b, computation, absl::MakeSpan(operands), /*name=*/"",
+                /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor<i32>}",
+                /*version=*/1);
+
+  auto statusor = BuildHloModule(b);
+  ASSERT_FALSE(statusor.ok());
+  EXPECT_THAT(statusor.status().message(),
+              HasSubstr("A composite call op must have frontend attributes "
+                        "with key composite.name whose value is non-empty"));
+}
+
+TEST(XlaBuilderTest, CompositeCallMissingAttribute) {
+  XlaBuilder b(TestName());
+  const Shape shape = ShapeUtil::MakeShape(F32, {});
+  const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+  XlaBuilder bsum(TestName());
+  Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+  TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+  std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+                                 Parameter(&b, 1, shape, "arg1")};
+  CompositeCall(&b, computation, absl::MakeSpan(operands), /*name=*/"foo.bar",
+                /*attributes=*/"", /*version=*/1);
+
+  auto statusor = BuildHloModule(b);
+  ASSERT_FALSE(statusor.ok());
+  EXPECT_THAT(
+      statusor.status().message(),
+      HasSubstr(
+          "A composite call op must have frontend attributes with key "
+          "composite.attributes whose value is default: {} or non-empty"));
+}
+
+TEST(XlaBuilderTest, CompositeCallNonNegativeVersion) {
+  XlaBuilder b(TestName());
+
+  FrontendAttributes frontend_attributes = b.frontend_attributes();
+  frontend_attributes.mutable_map()->insert({"foo", "bar"});
+  b.SetFrontendAttributes(frontend_attributes);
+
+  const Shape shape = ShapeUtil::MakeShape(F32, {});
+  const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+  XlaBuilder bsum(TestName());
+  Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+  TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+  std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+                                 Parameter(&b, 1, shape, "arg1")};
+  CompositeCall(&b, computation, absl::MakeSpan(operands),
+                /*name=*/"foo.bar",
+                /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor<i32>}",
+                /*version=*/-1);
+
+  auto statusor = BuildHloModule(b);
+  ASSERT_FALSE(statusor.ok());
+  EXPECT_THAT(statusor.status().message(),
+              HasSubstr("A composite call op must have frontend attributes "
+                        "with a composite.version whose value is a "
+                        "non-negative integer but got: -1"));
+}
+
+TEST(XlaBuilderTest, CompositeCallOptionalVersionAndAttribute) {
+  XlaBuilder b(TestName());
+  const Shape shape = ShapeUtil::MakeShape(F32, {});
+  const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+  XlaBuilder bsum(TestName());
+  Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+  TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+  std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+                                 Parameter(&b, 1, shape, "arg1")};
+  CompositeCall(&b, computation, absl::MakeSpan(operands), /*name=*/"foo.bar");
+
+  TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
+  ASSERT_THAT(GetRoot(*module),
+              GmockMatch(m::Call(m::Parameter(), m::Parameter())));
+  ASSERT_TRUE(GetRoot(*module)->frontend_attributes().map().contains(
+      "composite.attributes"));
+  EXPECT_EQ(
+      GetRoot(*module)->frontend_attributes().map().at("composite.attributes"),
+      "{}");
+  EXPECT_EQ(
+      GetRoot(*module)->frontend_attributes().map().at("composite.version"),
+      "0");
+}
+
+TEST(XlaBuilderTest, CompositeCallWithExtraFrontendAttributes) {
+  XlaBuilder b(TestName());
+
+  FrontendAttributes frontend_attributes = b.frontend_attributes();
+  frontend_attributes.mutable_map()->insert({"foo", "bar"});
+  b.SetFrontendAttributes(frontend_attributes);
+
+  const Shape shape = ShapeUtil::MakeShape(F32, {});
+  const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape});
+
+  XlaBuilder bsum(TestName());
+  Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1"));
+  TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build());
+
+  std::vector<XlaOp> operands = {Parameter(&b, 0, shape, "arg0"),
+                                 Parameter(&b, 1, shape, "arg1")};
+  CompositeCall(&b, computation, absl::MakeSpan(operands),
+                /*name=*/"foo.bar",
+                /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor<i32>}",
+                /*version=*/1);
+
+  TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
+  EXPECT_THAT(GetRoot(*module),
+              GmockMatch(m::Call(m::Parameter(), m::Parameter())));
+  ASSERT_TRUE(GetRoot(*module)->frontend_attributes().map().contains("foo"));
+  EXPECT_EQ(GetRoot(*module)->frontend_attributes().map().at("foo"), "bar");
+}
+
 TEST(XlaBuilderTest, BinopHasDegenerateBroadcast) {
   XlaBuilder b(TestName());
   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x");
diff --git a/third_party/xla/xla/client/xla_computation.cc b/third_party/xla/xla/client/xla_computation.cc
index c92de63..fc55846 100644
--- a/third_party/xla/xla/client/xla_computation.cc
+++ b/third_party/xla/xla/client/xla_computation.cc
@@ -18,6 +18,9 @@
 #include <memory>
 #include <utility>
 
+#include "absl/status/statusor.h"
+#include "xla/service/hlo.pb.h"
+#include "xla/shape.h"
 #include "xla/status_macros.h"
 #include "xla/util.h"
 
diff --git a/third_party/xla/xla/client/xla_computation.h b/third_party/xla/xla/client/xla_computation.h
index e21a92d..52a54aa 100644
--- a/third_party/xla/xla/client/xla_computation.h
+++ b/third_party/xla/xla/client/xla_computation.h
@@ -20,6 +20,7 @@
 #include <string>
 #include <utility>
 
+#include "absl/status/statusor.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/shape.h"
 #include "xla/status_macros.h"
diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc
index 79e716b..890bd25 100644
--- a/third_party/xla/xla/debug_options_flags.cc
+++ b/third_party/xla/xla/debug_options_flags.cc
@@ -237,6 +237,7 @@
   opts.set_xla_gpu_enable_llvm_module_compilation_parallelism(false);
 
   opts.set_xla_gpu_enable_libnvptxcompiler(false);
+  opts.set_xla_gpu_enable_libnvjitlink(false);
 
   opts.set_xla_gpu_enable_dot_strength_reduction(true);
 
@@ -246,7 +247,7 @@
   opts.set_xla_gpu_nccl_p2p_max_nchannels(0);
 
 #if GOOGLE_CUDA
-  opts.set_xla_gpu_mlir_emitter_level(3);
+  opts.set_xla_gpu_mlir_emitter_level(4);
 #else
   opts.set_xla_gpu_mlir_emitter_level(0);
 #endif
@@ -281,6 +282,8 @@
 
   opts.set_xla_enable_command_buffers_during_profiling(false);
 
+  opts.set_xla_gpu_cudnn_gemm_max_plans(5);
+
   return opts;
 }
 
@@ -1834,6 +1837,12 @@
       "Experimental: Enable command buffers while profiling is active. "
       "By default, enabling profiling switches from command buffers to "
       "op-by-op mode."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_cudnn_gemm_max_plans",
+      int32_setter_for(&DebugOptions::set_xla_gpu_cudnn_gemm_max_plans),
+      debug_options->xla_gpu_cudnn_gemm_max_plans(),
+      "Limit for the number of kernel configurations (plans) to use during "
+      "autotuning of cuDNN GEMM fusions."));
 }  // NOLINT(readability/fn_size)
 
 // Allocates flag_values and flag_objects; this function must not be called more
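As a usage note for the new xla_gpu_cudnn_gemm_max_plans option registered above, a short sketch of how the limit could be tightened programmatically; the value 3 is arbitrary and the surrounding compilation setup is assumed.

    // Start from the flag-derived defaults (the new default is 5) and lower
    // the number of cuDNN GEMM autotuning plans considered.
    xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
    opts.set_xla_gpu_cudnn_gemm_max_plans(3);
    // The same knob is reachable at run time via the XLA_FLAGS environment
    // variable, e.g. XLA_FLAGS=--xla_gpu_cudnn_gemm_max_plans=3.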
diff --git a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc
index 897a1e9..49a99ee 100644
--- a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc
+++ b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc
@@ -62,10 +62,8 @@
   //   PlatformUtil::GetPlatform("CUDA"));
   TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
                           PlatformUtil::GetPlatform("cpu"));
-  se::StreamExecutorConfig config;
-  config.ordinal = 0;
   TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor,
-                          platform->GetExecutor(config));
+                          platform->ExecutorForDevice(/*ordinal=*/0));
 
   // LocalDeviceState and PjRtStreamExecutorDevice describes the state of a
   // device which can do computation or transfer buffers. This could represent a
diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD
index f01ddef..2e9deaa 100644
--- a/third_party/xla/xla/ffi/BUILD
+++ b/third_party/xla/xla/ffi/BUILD
@@ -118,8 +118,10 @@
         ":api",
         ":execution_context",
         ":execution_state",
+        "//xla:executable_run_options",
         "//xla:shape_util",
         "//xla:types",
+        "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/ffi/api:c_api",
         "//xla/ffi/api:c_api_internal",
@@ -148,6 +150,7 @@
         ":execution_context",
         ":execution_state",
         ":type_id_registry",
+        "//xla:executable_run_options",
         "//xla:util",
         "//xla/ffi/api:c_api",
         "//xla/ffi/api:c_api_internal",
diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h
index 7675c3a..dccbcc6 100644
--- a/third_party/xla/xla/ffi/api/api.h
+++ b/third_party/xla/xla/ffi/api/api.h
@@ -33,7 +33,6 @@
 #include <tuple>
 #include <type_traits>
 #include <utility>
-#include <variant>
 #include <vector>
 
 // This is a header-only base C++ library that defines templates for decoding
@@ -232,16 +231,17 @@
   template <typename... Args>
   static std::string StrCat(Args... args);
 
-  static inline XLA_FFI_Error* MakeError(const XLA_FFI_Api* api,
-                                         XLA_FFI_Error_Code errc,
-                                         std::string message);
+  static XLA_FFI_Error* Success();
 
-  static inline XLA_FFI_Error* InvalidArgument(const XLA_FFI_Api* api,
-                                               std::string message);
+  static XLA_FFI_Error* MakeError(const XLA_FFI_Api* api,
+                                  XLA_FFI_Error_Code errc, std::string message);
 
-  static inline XLA_FFI_Error* CheckStructSize(const XLA_FFI_Api* api,
-                                               std::string_view struct_name,
-                                               size_t expected, size_t actual);
+  static XLA_FFI_Error* InvalidArgument(const XLA_FFI_Api* api,
+                                        std::string message);
+
+  static XLA_FFI_Error* CheckStructSize(const XLA_FFI_Api* api,
+                                        std::string_view struct_name,
+                                        size_t expected, size_t actual);
 };
 
 XLA_FFI_Error* Ffi::RegisterStaticHandler(const XLA_FFI_Api* api,
@@ -266,8 +266,11 @@
   return ss.str();
 }
 
-XLA_FFI_Error* Ffi::MakeError(const XLA_FFI_Api* api, XLA_FFI_Error_Code errc,
-                              std::string message) {
+inline XLA_FFI_Error* Ffi::Success() { return nullptr; }
+
+inline XLA_FFI_Error* Ffi::MakeError(const XLA_FFI_Api* api,
+                                     XLA_FFI_Error_Code errc,
+                                     std::string message) {
   XLA_FFI_Error_Create_Args args;
   args.struct_size = XLA_FFI_Error_Create_Args_STRUCT_SIZE;
   args.priv = nullptr;
@@ -276,15 +279,15 @@
   return api->XLA_FFI_Error_Create(&args);
 }
 
-XLA_FFI_Error* Ffi::InvalidArgument(const XLA_FFI_Api* api,
-                                    std::string message) {
+inline XLA_FFI_Error* Ffi::InvalidArgument(const XLA_FFI_Api* api,
+                                           std::string message) {
   return MakeError(api, XLA_FFI_Error_Code_INVALID_ARGUMENT,
                    std::move(message));
 }
 
-XLA_FFI_Error* Ffi::CheckStructSize(const XLA_FFI_Api* api,
-                                    std::string_view struct_name,
-                                    size_t expected, size_t actual) {
+inline XLA_FFI_Error* Ffi::CheckStructSize(const XLA_FFI_Api* api,
+                                           std::string_view struct_name,
+                                           size_t expected, size_t actual) {
   if (expected != actual) {
     return InvalidArgument(
         api, StrCat("Unexpected ", struct_name, " size: expected ", expected,
@@ -306,12 +309,13 @@
 // parameter packs. We need this to be able to pattern match FFI handler
 // signature at compile time.
 
+// A type tag for decoding optional argument.
+template <typename T>
+struct OptionalArgTag {};
+
 // A type tag to forward all remaining args as `RemainingArgs`.
 struct RemainingArgsTag {};
 
-// A type tag to forward all remaining results as `RemainingRets`.
-struct RemainingRetsTag {};
-
 // A type tag to distinguish parameters tied to results in the `Binding`
 // variadic template. In XLA FFI we use destination passing style APIs and don't
 // return anything from the handler, but instead pass a destination where the
@@ -319,6 +323,13 @@
 template <typename T>
 struct RetTag {};
 
+// A type tag for decoding optional result.
+template <typename T>
+struct OptionalRetTag {};
+
+// A type tag to forward all remaining results as `RemainingRets`.
+struct RemainingRetsTag {};
+
 // A type tag to distinguish parameters tied to the attributes in the
 // `Binding` variadic template.
 template <typename T>
@@ -357,12 +368,30 @@
 
 //----------------------------------------------------------------------------//
 
-// Checks if remaining arguments are in the parameter pack.
+template <typename T>
+struct IsOptionalArgTag : std::false_type {};
+template <typename T>
+struct IsOptionalArgTag<OptionalArgTag<T>> : std::true_type {};
+
+template <typename T>
+struct IsOptionalRetTag : std::false_type {};
+template <typename T>
+struct IsOptionalRetTag<OptionalRetTag<T>> : std::true_type {};
+
+// Checks if parameter pack has an optional argument.
+template <typename... Ts>
+using HasOptionalArgTag = std::disjunction<IsOptionalArgTag<Ts>...>;
+
+// Checks if parameter pack has remaining arguments.
 template <typename... Ts>
 using HasRemainingArgsTag =
     std::disjunction<std::is_same<RemainingArgsTag, Ts>...>;
 
-// Checks if remaining results are in the parameter pack.
+// Checks if parameter pack has an optional result.
+template <typename... Ts>
+using HasOptionalRetTag = std::disjunction<IsOptionalRetTag<Ts>...>;
+
+// Checks if parameter pack has remaining results.
 template <typename... Ts>
 using HasRemainingRetsTag =
     std::disjunction<std::is_same<RemainingRetsTag, Ts>...>;
@@ -413,11 +442,34 @@
  public:
   template <typename T>
   Binding<stage, Ts..., T> Arg() && {
+    static_assert(!internal::HasOptionalArgTag<Ts...>::value,
+                  "argument can't be passed after optional argument");
+    static_assert(!internal::HasRemainingArgsTag<Ts...>::value,
+                  "argument can't be passed after remaining arguments");
     return {std::move(*this)};
   }
 
   template <typename T>
   Binding<stage, Ts..., internal::RetTag<T>> Ret() && {
+    static_assert(!internal::HasOptionalRetTag<Ts...>::value,
+                  "result can't be passed after optional result");
+    static_assert(!internal::HasRemainingRetsTag<Ts...>::value,
+                  "result can't be passed after remaining results");
+    return {std::move(*this)};
+  }
+
+  template <typename T>
+  Binding<stage, Ts..., internal::OptionalArgTag<T>> OptionalArg() && {
+    static_assert(
+        !internal::HasRemainingArgsTag<Ts...>::value,
+        "optional argument can't be passed after remaining arguments");
+    return {std::move(*this)};
+  }
+
+  template <typename T>
+  Binding<stage, Ts..., internal::OptionalRetTag<T>> OptionalRet() && {
+    static_assert(!internal::HasRemainingRetsTag<Ts...>::value,
+                  "optional result can't be passed after remaining results");
     return {std::move(*this)};
   }
 
@@ -427,7 +479,7 @@
     return {std::move(*this)};
   }
 
-  Binding<stage, Ts..., internal::RemainingRetsTag> RemainingResults() && {
+  Binding<stage, Ts..., internal::RemainingRetsTag> RemainingRets() && {
     static_assert(!internal::HasRemainingRetsTag<Ts...>::value,
                   "remaining results can be passed just once");
     return {std::move(*this)};
@@ -900,10 +952,20 @@
   }
 };
 
-}  // namespace internal
+template <typename T>
+struct Decode<OptionalArgTag<T>> {
+  static std::optional<std::optional<T>> call(DecodingOffsets& offsets,
+                                              DecodingContext& ctx,
+                                              DiagnosticEngine& diagnostic) {
+    if (offsets.args >= ctx.call_frame->args.size) {
+      return std::optional<T>(std::nullopt);
+    }
+    return Decode<T>::call(offsets, ctx, diagnostic);
+  }
+};
 
 template <typename T>
-struct internal::Decode<internal::RetTag<T>> {
+struct Decode<RetTag<T>> {
   static std::optional<Result<T>> call(DecodingOffsets& offsets,
                                        DecodingContext& ctx,
                                        DiagnosticEngine& diagnostic) {
@@ -914,7 +976,19 @@
 };
 
 template <typename T>
-struct internal::Decode<internal::AttrTag<T>> {
+struct Decode<OptionalRetTag<T>> {
+  static std::optional<std::optional<Result<T>>> call(
+      DecodingOffsets& offsets, DecodingContext& ctx,
+      DiagnosticEngine& diagnostic) {
+    if (offsets.rets >= ctx.call_frame->rets.size) {
+      return std::optional<Result<T>>(std::nullopt);
+    }
+    return Decode<RetTag<T>>::call(offsets, ctx, diagnostic);
+  }
+};
+
+template <typename T>
+struct Decode<AttrTag<T>> {
   using R = typename AttrDecoding<T>::Type;
 
   static std::optional<R> call(DecodingOffsets& offsets, DecodingContext& ctx,
@@ -946,7 +1020,7 @@
 };
 
 template <typename T>
-struct internal::Decode<internal::CtxTag<T>> {
+struct Decode<CtxTag<T>> {
   using R = typename CtxDecoding<T>::Type;
 
   static std::optional<R> call(DecodingOffsets& offsets, DecodingContext& ctx,
@@ -956,75 +1030,17 @@
   }
 };
 
-//===----------------------------------------------------------------------===//
-// Expected
-//===----------------------------------------------------------------------===//
-
-// Forward declare.
-template <typename E>
-class Unexpected;
-
-// TODO(slebedev): Replace with `std::expected` when C++23 is available.
-template <typename T, typename E>
-class Expected {
- public:
-  constexpr Expected(T value) : data_(std::move(value)) {}  // NOLINT
-  constexpr Expected(Unexpected<E> u);                      // NOLINT
-
-  constexpr operator bool() const {  // NOLINT
-    return has_value();
-  }
-
-  constexpr T& operator*() & { return value(); }
-  constexpr const T& operator*() const& { return value(); }
-  constexpr T&& operator*() && { return std::move(value()); }
-  constexpr const T& operator*() const&& { return std::move(value()); }
-
-  constexpr T* operator->() { return &value(); }
-  constexpr const T* operator->() const { return &value(); }
-
-  constexpr bool has_value() const { return std::holds_alternative<T>(data_); }
-  constexpr bool has_error() const { return std::holds_alternative<E>(data_); }
-
-  constexpr T& value() & { return std::get<T>(data_); }
-  constexpr const T& value() const& { return std::get<T>(data_); }
-  constexpr T&& value() && { return std::get<T>(std::move(data_)); }
-  constexpr const T& value() const&& { return std::get<T>(std::move(data_)); }
-
-  constexpr E& error() & { return std::get<E>(data_); }
-  constexpr const E& error() const& { return std::get<E>(data_); }
-  constexpr E&& error() && { return std::get<E>(std::move(data_)); }
-  constexpr const E&& error() const&& { return std::get<E>(std::move(data_)); }
-
- private:
-  std::variant<T, E> data_;
-};
-
-template <typename E>
-class Unexpected {
- public:
-  constexpr Unexpected(E error) : error_(std::move(error)) {}  // NOLINT
-
- private:
-  template <typename, typename>
-  friend class Expected;
-
-  E error_;
-};
-
-Unexpected(const char*) -> Unexpected<std::string>;
-
-template <typename T, typename E>
-constexpr Expected<T, E>::Expected(Unexpected<E> u)
-    : data_(std::move(u.error_)) {}
+}  // namespace internal
 
 //===----------------------------------------------------------------------===//
 // Type-safe wrapper for accessing a variable number of arguments.
 //===----------------------------------------------------------------------===//
 
-class RemainingArgs {
+namespace internal {
+
+class RemainingArgsBase {
  public:
-  RemainingArgs(const XLA_FFI_Args* args, size_t offset)
+  RemainingArgsBase(const XLA_FFI_Args* args, size_t offset)
       : args_(args), offset_(offset) {
     assert(offset <= args_->size && "illegal remaining args offset");
   }
@@ -1032,43 +1048,26 @@
   size_t size() const { return args_->size - offset_; }
   bool empty() const { return size() == 0; }
 
-  template <typename T>
-  Expected<T, std::string> get(size_t index) const {
-    size_t idx = offset_ + index;
-    if (idx >= args_->size) {
-      return Unexpected("Index out of range.");
-    }
-
-    DiagnosticEngine diagnostic;
-    auto value_opt =
-        ArgDecoding<T>::Decode(args_->types[idx], args_->args[idx], diagnostic);
-    if (!value_opt.has_value()) {
-      return Unexpected(diagnostic.Result());
-    }
-    return *value_opt;
-  }
+ protected:
+  const XLA_FFI_Args* args() const { return args_; }
+  size_t offset() const { return offset_; }
 
  private:
-  const XLA_FFI_Args* args_;  // not owned
+  const XLA_FFI_Args* args_;
   size_t offset_;
 };
 
-template <>
-struct internal::Decode<internal::RemainingArgsTag> {
-  static std::optional<RemainingArgs> call(DecodingOffsets& offsets,
-                                           DecodingContext& ctx,
-                                           DiagnosticEngine& diagnostic) {
-    return RemainingArgs(&ctx.call_frame->args, offsets.args);
-  }
-};
+}  // namespace internal
 
 //===----------------------------------------------------------------------===//
 // Type-safe wrapper for accessing a variable number of results.
 //===----------------------------------------------------------------------===//
 
-class RemainingResults {
+namespace internal {
+
+class RemainingRetsBase {
  public:
-  RemainingResults(const XLA_FFI_Rets* rets, size_t offset)
+  RemainingRetsBase(const XLA_FFI_Rets* rets, size_t offset)
       : rets_(rets), offset_(offset) {
     assert(offset <= rets_->size && "illegal remaining rets offset");
   }
@@ -1076,43 +1075,30 @@
   size_t size() const { return rets_->size - offset_; }
   bool empty() const { return size() == 0; }
 
-  template <typename T>
-  Expected<T, std::string> get(size_t index) const {
-    size_t idx = offset_ + index;
-    if (idx >= rets_->size) {
-      return Unexpected("Index out of range.");
-    }
-
-    DiagnosticEngine diagnostic;
-    auto value_opt =
-        RetDecoding<T>::Decode(rets_->types[idx], rets_->rets[idx], diagnostic);
-    if (!value_opt.has_value()) {
-      return Unexpected(diagnostic.Result());
-    }
-    return **value_opt;
-  }
+ protected:
+  const XLA_FFI_Rets* rets() const { return rets_; }
+  size_t offset() const { return offset_; }
 
  private:
   const XLA_FFI_Rets* rets_;  // not owned
   size_t offset_;
 };
 
-template <>
-struct internal::Decode<internal::RemainingRetsTag> {
-  static std::optional<RemainingResults> call(DecodingOffsets& offsets,
-                                              DecodingContext& ctx,
-                                              DiagnosticEngine& diagnostic) {
-    return RemainingResults(&ctx.call_frame->rets, offsets.rets);
-  }
-};
+}  // namespace internal
 
 //===----------------------------------------------------------------------===//
 // Type-safe wrapper for accessing dictionary attributes.
 //===----------------------------------------------------------------------===//
 
-class Dictionary {
+namespace internal {
+
+// Forward declare dictionary attribute decoding defined below.
+template <typename T, typename... Ts>
+struct DecodeDictionaryAttr;
+
+class DictionaryBase {
  public:
-  explicit Dictionary(const XLA_FFI_Attrs* attrs) : attrs_(attrs) {}
+  explicit DictionaryBase(const XLA_FFI_Attrs* attrs) : attrs_(attrs) {}
 
   size_t size() const { return attrs_->size; }
 
@@ -1120,21 +1106,15 @@
     return Find(name) < attrs_->size;
   }
 
-  template <typename T>
-  Expected<T, std::string> get(std::string_view name) const {
-    DiagnosticEngine diagnostic;
-    auto value_opt = get<T>(name, diagnostic);
-    if (!value_opt.has_value()) {
-      return Unexpected(diagnostic.Result());
-    }
-    return *value_opt;
-  }
+ protected:
+  template <typename T, typename... Ts>
+  friend struct DecodeDictionaryAttr;
 
   template <typename T>
   std::optional<T> get(std::string_view name,
                        DiagnosticEngine& diagnostic) const {
     size_t idx = Find(name);
-    if (idx >= attrs_->size) {
+    if (XLA_FFI_PREDICT_FALSE(idx >= attrs_->size)) {
       return diagnostic.Emit("Unexpected attribute: ") << name;
     }
 
@@ -1161,15 +1141,11 @@
   const XLA_FFI_Attrs* attrs_;
 };
 
-// Decode `AttrsTag` into a generic `Dictionary` attribute.
-template <>
-struct internal::Decode<internal::AttrsTag<Dictionary>> {
-  static std::optional<Dictionary> call(DecodingOffsets& offsets,
-                                        DecodingContext& ctx,
-                                        DiagnosticEngine& diagnostic) {
-    return Dictionary(&ctx.call_frame->attrs);
-  }
-};
+}  // namespace internal
+
+//===----------------------------------------------------------------------===//
+// Decoding for aggregate attributes (decoding dictionaries into structs).
+//===----------------------------------------------------------------------===//
 
 // Decode `AttrsTag` into a type `T` relying on struct decoding defined below.
 template <typename T>
@@ -1186,6 +1162,13 @@
 // Template metaprogramming for decoding handler signature
 //===----------------------------------------------------------------------===//
 
+// Forward declare classes for decoding a variadic number of arguments and
+// results. They are defined in the `ffi.h` headers (internal and external) so
+// that internal and external FFI can use slightly different implementations
+// (`absl::StatusOr` vs `ffi::ErrorOr`).
+class RemainingArgs;
+class RemainingRets;
+
 namespace internal {
 // A helper struct to extract the type of the handler argument.
 template <typename T>
@@ -1193,23 +1176,31 @@
   using Type = T;
 };
 
+template <typename T>
+struct FnArgType<internal::OptionalArgTag<T>> {
+  using Type = std::optional<T>;
+};
+
 template <>
 struct FnArgType<internal::RemainingArgsTag> {
   using Type = RemainingArgs;
 };
 
-template <>
-struct FnArgType<internal::RemainingRetsTag> {
-  using Type = RemainingResults;
-};
-
-// Extracts the underlying type from the returned result type tag.
 template <typename T>
 struct FnArgType<internal::RetTag<T>> {
   using Type = Result<T>;
 };
 
-// Extracts the underlying type from the attribute type tag.
+template <typename T>
+struct FnArgType<internal::OptionalRetTag<T>> {
+  using Type = std::optional<Result<T>>;
+};
+
+template <>
+struct FnArgType<internal::RemainingRetsTag> {
+  using Type = RemainingRets;
+};
+
 template <typename T>
 struct FnArgType<internal::AttrTag<T>> {
   using Type = typename AttrDecoding<T>::Type;
@@ -1220,7 +1211,6 @@
   using Type = T;
 };
 
-// Extracts the underlying type from the context type tag.
 template <typename T>
 struct FnArgType<internal::CtxTag<T>> {
   using Type = typename CtxDecoding<T>::Type;
@@ -1230,20 +1220,27 @@
 // a special decoding rule defined by template specialization.
 template <typename>
 struct IsTagged : std::false_type {};
+
+template <typename T>
+struct IsTagged<OptionalArgTag<T>> : std::true_type {};
 template <typename T>
 struct IsTagged<RetTag<T>> : std::true_type {};
 template <typename T>
+struct IsTagged<OptionalRetTag<T>> : std::true_type {};
+template <typename T>
 struct IsTagged<AttrTag<T>> : std::true_type {};
 template <typename T>
 struct IsTagged<AttrsTag<T>> : std::true_type {};
 template <typename T>
 struct IsTagged<CtxTag<T>> : std::true_type {};
+
 template <>
 struct IsTagged<RemainingArgsTag> : std::true_type {};
 template <>
 struct IsTagged<RemainingRetsTag> : std::true_type {};
 
-// A template for counting regular arguments in the Ts pack.
+// A template for counting regular arguments in the Ts pack (arguments that are
+// not wrapped into a special tag).
 template <typename... Ts>
 struct NumArgs;
 
@@ -1269,9 +1266,15 @@
 
   static constexpr int64_t kNumArgs = internal::NumArgs<Ts...>::value;
 
+  static constexpr int64_t kNumOptionalArgs =
+      internal::NumTagged<internal::OptionalArgTag, Ts...>::value;
+
   static constexpr int64_t kNumRets =
       internal::NumTagged<internal::RetTag, Ts...>::value;
 
+  static constexpr int64_t kNumOptionalRets =
+      internal::NumTagged<internal::OptionalRetTag, Ts...>::value;
+
   static constexpr int64_t kNumAttrs =
       internal::NumTagged<internal::AttrTag, Ts...>::value;
 
@@ -1292,22 +1295,22 @@
  public:
   XLA_FFI_Error* Call(const XLA_FFI_CallFrame* call_frame) const override {
     // Sanity checking call frame struct size.
-    if (auto* err = CheckStructSize(call_frame->api, "XLA_FFI_CallFrame",
-                                    XLA_FFI_CallFrame_STRUCT_SIZE,
-                                    call_frame->struct_size))
+    if (XLA_FFI_Error* err = CheckStructSize(
+            call_frame->api, "XLA_FFI_CallFrame", XLA_FFI_CallFrame_STRUCT_SIZE,
+            call_frame->struct_size)) {
       return err;
+    }
 
     // Check the API versions.
-    auto api_version = call_frame->api->api_version;
+    const XLA_FFI_Api_Version& api_version = call_frame->api->api_version;
     if (api_version.major_version != XLA_FFI_API_MAJOR ||
         api_version.minor_version != XLA_FFI_API_MINOR) {
       return InvalidArgument(
           call_frame->api,
           StrCat("FFI handler's API version (", XLA_FFI_API_MAJOR, ".",
-                 XLA_FFI_API_MINOR,
-                 ") does not match the framework's API version (",
-                 api_version.major_version, ".", api_version.minor_version,
-                 ")"));
+                 XLA_FFI_API_MINOR, ") does not match the framework's API ",
+                 "version (", api_version.major_version, ".",
+                 api_version.minor_version, ")"));
     }
 
     // Check that handler is called during correct execution stage.
@@ -1321,12 +1324,21 @@
 
     // Check that the number of passed arguments matches the signature. Each
     // individual argument decoding will check the actual type.
-    if (internal::HasRemainingArgsTag<Ts...>::value) {
+    if constexpr (internal::HasRemainingArgsTag<Ts...>::value) {
       if (XLA_FFI_PREDICT_FALSE(call_frame->args.size < kNumArgs)) {
         return InvalidArgument(
             call_frame->api,
             StrCat("Wrong number of arguments: expected at least ",
-                   kNumArgs - 1, " but got ", call_frame->args.size));
+                   kNumArgs - kNumOptionalArgs - 1, " but got ",
+                   call_frame->args.size));
+      }
+    } else if constexpr (internal::HasOptionalArgTag<Ts...>::value) {
+      if (XLA_FFI_PREDICT_FALSE(call_frame->args.size < kNumArgs)) {
+        return InvalidArgument(
+            call_frame->api,
+            StrCat("Wrong number of arguments: expected at least ",
+                   kNumArgs - kNumOptionalArgs, " but got ",
+                   call_frame->args.size));
       }
     } else {
       if (XLA_FFI_PREDICT_FALSE(call_frame->args.size != kNumArgs)) {
@@ -1339,12 +1351,21 @@
 
     // Check that the number of results matches the signature. Each individual
     // result decoding will check the actual type.
-    if (internal::HasRemainingRetsTag<Ts...>::value) {
+    if constexpr (internal::HasRemainingRetsTag<Ts...>::value) {
       if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size < kNumRets)) {
         return InvalidArgument(
             call_frame->api,
-            StrCat("Wrong number of results: expected at least ", kNumRets - 1,
-                   " but got ", call_frame->rets.size));
+            StrCat("Wrong number of results: expected at least ",
+                   kNumRets - kNumOptionalRets - 1, " but got ",
+                   call_frame->rets.size));
+      }
+    } else if constexpr (internal::HasOptionalRetTag<Ts...>::value) {
+      if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size < kNumRets)) {
+        return InvalidArgument(
+            call_frame->api,
+            StrCat("Wrong number of results: expected at least ",
+                   kNumRets - kNumOptionalRets, " but got ",
+                   call_frame->rets.size));
       }
     } else {
       if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size != kNumRets)) {
@@ -1515,21 +1536,6 @@
   }
 };
 
-template <>
-struct AttrDecoding<Dictionary> {
-  using Type = Dictionary;
-  static std::optional<Dictionary> Decode(XLA_FFI_AttrType type, void* attr,
-                                          DiagnosticEngine& diagnostic) {
-    if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) {
-      return diagnostic.Emit("Wrong attribute type: expected ")
-             << XLA_FFI_AttrType_DICTIONARY << " but got " << type;
-    }
-
-    auto* attrs = reinterpret_cast<XLA_FFI_Attrs*>(attr);
-    return Dictionary(attrs);
-  }
-};
-
 //===----------------------------------------------------------------------===//
 // Automatic dictionary attributes to structs decoding.
 //===----------------------------------------------------------------------===//
@@ -1574,7 +1580,7 @@
     //
     // Consider using `static auto decoder = ...` below, and compute mapping in
     // constructor. Add benchmarks first to know what to improve!
-    Dictionary dict(attrs);
+    internal::DictionaryBase dict(attrs);
 
     std::tuple<std::optional<Ts>...> members = {
         dict.get<Ts>(names[Is], diagnostic)...};
@@ -1637,7 +1643,7 @@
 // type to decode the attribute as a scalar value and cast it to the enum type.
 #define XLA_FFI_REGISTER_ENUM_ATTR_DECODING(T)                                \
   template <>                                                                 \
-  struct ::xla::ffi::AttrDecoding<T> {                                        \
+  struct xla::ffi::AttrDecoding<T> {                                          \
     using Type = T;                                                           \
     using U = std::underlying_type_t<Type>;                                   \
     static_assert(std::is_enum<Type>::value, "Expected enum class");          \
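To make the new binding rules concrete, a sketch of a hypothetical handler built with the external `ffi.h` types introduced further down (`AnyBuffer`, `RemainingArgs`, `Error`): required arguments come first, then optional ones, then the variadic tail, and the static_asserts above reject any other ordering at compile time.

    auto handler = xla::ffi::Ffi::Bind()
                       .Arg<xla::ffi::AnyBuffer>()          // required
                       .OptionalArg<xla::ffi::AnyBuffer>()  // may be absent
                       .RemainingArgs()                     // variadic tail
                       .To([](xla::ffi::AnyBuffer in,
                              std::optional<xla::ffi::AnyBuffer> maybe_in,
                              xla::ffi::RemainingArgs rest) {
                         // Chaining .Arg() after .OptionalArg(), or
                         // .OptionalArg() after .RemainingArgs(), fails to
                         // compile with the messages above.
                         return xla::ffi::Error::Success();
                       });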
diff --git a/third_party/xla/xla/ffi/api/c_api_internal.h b/third_party/xla/xla/ffi/api/c_api_internal.h
index 3c5c2ba..da5ea32 100644
--- a/third_party/xla/xla/ffi/api/c_api_internal.h
+++ b/third_party/xla/xla/ffi/api/c_api_internal.h
@@ -74,6 +74,11 @@
 typedef void* XLA_FFI_INTERNAL_ExecutionState_Get(
     XLA_FFI_ExecutionContext* ctx);
 
+// Returns a pointer to the `Eigen::ThreadPoolDevice` passed via run options,
+// which allows FFI handlers to execute tasks in the same thread pool as XLA.
+typedef void* XLA_FFI_INTERNAL_IntraOpThreadPool_Get(
+    XLA_FFI_ExecutionContext* ctx);
+
 //===----------------------------------------------------------------------===//
 // API access
 //===----------------------------------------------------------------------===//
@@ -89,6 +94,7 @@
   _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_CalledComputation_Get);
   _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_ExecutionContext_Get);
   _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_ExecutionState_Get);
+  _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_IntraOpThreadPool_Get);
 };
 
 #undef _XLA_FFI_INTERNAL_API_STRUCT_FIELD
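The new internal entry point above is what lets the C++ FFI layer reach the intra-op thread pool; a hedged sketch of a call site, where `internal_api` and `ctx` are placeholders for the internal API struct and the execution context passed to a handler:

    // The entry point returns an opaque pointer that the C++ side treats as
    // an Eigen::ThreadPoolDevice supplied via the run options.
    void* intra_op_device =
        internal_api->XLA_FFI_INTERNAL_IntraOpThreadPool_Get(ctx);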
diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h
index c48996b..c35fdce 100644
--- a/third_party/xla/xla/ffi/api/ffi.h
+++ b/third_party/xla/xla/ffi/api/ffi.h
@@ -36,6 +36,7 @@
 #include <string_view>
 #include <type_traits>
 #include <utility>
+#include <variant>
 #include <vector>
 
 #include "xla/ffi/api/c_api.h"
@@ -75,6 +76,30 @@
   F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
 };
 
+// Create aliases in the ::xla::ffi namespace for all DataTypes, for consistency
+// with XLA, which defines PrimitiveType enums in the ::xla namespace.
+inline constexpr DataType PRED = DataType::PRED;
+inline constexpr DataType S8 = DataType::S8;
+inline constexpr DataType S16 = DataType::S16;
+inline constexpr DataType S32 = DataType::S32;
+inline constexpr DataType S64 = DataType::S64;
+inline constexpr DataType U8 = DataType::U8;
+inline constexpr DataType U16 = DataType::U16;
+inline constexpr DataType U32 = DataType::U32;
+inline constexpr DataType U64 = DataType::U64;
+inline constexpr DataType F16 = DataType::F16;
+inline constexpr DataType F32 = DataType::F32;
+inline constexpr DataType F64 = DataType::F64;
+inline constexpr DataType BF16 = DataType::BF16;
+inline constexpr DataType C64 = DataType::C64;
+inline constexpr DataType C128 = DataType::C128;
+inline constexpr DataType TOKEN = DataType::TOKEN;
+inline constexpr DataType F8E5M2 = DataType::F8E5M2;
+inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN;
+inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
+inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
+inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
+
 inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
   return os << static_cast<XLA_FFI_DataType>(dtype);
 }
@@ -149,7 +174,7 @@
 };
 
 //===----------------------------------------------------------------------===//
-// Error and ErrorOr
+// Error
 //===----------------------------------------------------------------------===//
 
 enum class ErrorCode : uint8_t {
@@ -182,19 +207,89 @@
   Error(XLA_FFI_Error_Code errc, std::string message)
       : Error(static_cast<ErrorCode>(errc), std::move(message)) {}
 
-  static Error Success() { return Error(); }
-
   bool success() const { return errc_ == ErrorCode::kOk; }
   bool failure() const { return !success(); }
 
   std::optional<ErrorCode> errc() const { return errc_; }
   const std::string& message() const { return message_; }
 
+  static Error Success() { return Error(); }
+
+  static Error Internal(std::string message) {
+    return Error(ErrorCode::kInternal, std::move(message));
+  }
+
+  static Error InvalidArgument(std::string message) {
+    return Error(ErrorCode::kInvalidArgument, std::move(message));
+  }
+
  private:
   ErrorCode errc_ = ErrorCode::kOk;
   std::string message_;
 };
 
+//===----------------------------------------------------------------------===//
+// Expected<T, E> and ErrorOr<T>
+//===----------------------------------------------------------------------===//
+
+// Forward declare.
+template <typename E>
+class Unexpected;
+
+// TODO(slebedev): Replace with `std::expected` when C++23 is available.
+template <typename T, typename E>
+class Expected {
+ public:
+  constexpr Expected(T value) : data_(std::move(value)) {}  // NOLINT
+  constexpr Expected(Unexpected<E> u);                      // NOLINT
+
+  constexpr operator bool() const {  // NOLINT
+    return has_value();
+  }
+
+  constexpr T& operator*() & { return value(); }
+  constexpr const T& operator*() const& { return value(); }
+  constexpr T&& operator*() && { return std::move(value()); }
+  constexpr const T& operator*() const&& { return std::move(value()); }
+
+  constexpr T* operator->() { return &value(); }
+  constexpr const T* operator->() const { return &value(); }
+
+  constexpr bool has_value() const { return std::holds_alternative<T>(data_); }
+  constexpr bool has_error() const { return std::holds_alternative<E>(data_); }
+
+  constexpr T& value() & { return std::get<T>(data_); }
+  constexpr const T& value() const& { return std::get<T>(data_); }
+  constexpr T&& value() && { return std::get<T>(std::move(data_)); }
+  constexpr const T& value() const&& { return std::get<T>(std::move(data_)); }
+
+  constexpr E& error() & { return std::get<E>(data_); }
+  constexpr const E& error() const& { return std::get<E>(data_); }
+  constexpr E&& error() && { return std::get<E>(std::move(data_)); }
+  constexpr const E&& error() const&& { return std::get<E>(std::move(data_)); }
+
+ private:
+  std::variant<T, E> data_;
+};
+
+template <typename E>
+class Unexpected {
+ public:
+  constexpr Unexpected(E error) : error_(std::move(error)) {}  // NOLINT
+
+ private:
+  template <typename, typename>
+  friend class Expected;
+
+  E error_;
+};
+
+Unexpected(const char*) -> Unexpected<std::string>;
+
+template <typename T, typename E>
+constexpr Expected<T, E>::Expected(Unexpected<E> u)
+    : data_(std::move(u.error_)) {}
+
 template <typename T>
 class ErrorOr : public Expected<T, Error> {
  public:
@@ -484,6 +579,42 @@
 };
 
 //===----------------------------------------------------------------------===//
+// Type-safe wrapper for accessing a variable number of arguments.
+//===----------------------------------------------------------------------===//
+
+class RemainingArgs : public internal::RemainingArgsBase {
+ public:
+  using internal::RemainingArgsBase::RemainingArgsBase;
+
+  template <typename T>
+  ErrorOr<T> get(size_t index) const {
+    size_t idx = offset() + index;
+    if (XLA_FFI_PREDICT_FALSE(idx >= args()->size)) {
+      return Unexpected(
+          Error(ErrorCode::kInvalidArgument, "Index out of range"));
+    }
+
+    DiagnosticEngine diagnostic;
+    std::optional<T> value = ArgDecoding<T>::Decode(
+        args()->types[idx], args()->args[idx], diagnostic);
+    if (XLA_FFI_PREDICT_FALSE(!value.has_value())) {
+      return Unexpected(Error::Internal(diagnostic.Result()));
+    }
+
+    return *value;
+  }
+};
+
+template <>
+struct internal::Decode<internal::RemainingArgsTag> {
+  static std::optional<RemainingArgs> call(DecodingOffsets& offsets,
+                                           DecodingContext& ctx,
+                                           DiagnosticEngine& diagnostic) {
+    return RemainingArgs(&ctx.call_frame->args, offsets.args);
+  }
+};
+
+//===----------------------------------------------------------------------===//
 // Results decoding
 //===----------------------------------------------------------------------===//
 
@@ -524,6 +655,42 @@
 };
 
 //===----------------------------------------------------------------------===//
+// Type-safe wrapper for accessing a variable number of results.
+//===----------------------------------------------------------------------===//
+
+class RemainingRets : public internal::RemainingRetsBase {
+ public:
+  using internal::RemainingRetsBase::RemainingRetsBase;
+
+  template <typename T>
+  ErrorOr<Result<T>> get(size_t index) const {
+    size_t idx = offset() + index;
+    if (XLA_FFI_PREDICT_FALSE(idx >= rets()->size)) {
+      return Unexpected(
+          Error(ErrorCode::kInvalidArgument, "Index out of range"));
+    }
+
+    DiagnosticEngine diagnostic;
+    std::optional<Result<T>> value = RetDecoding<T>::Decode(
+        rets()->types[idx], rets()->rets[idx], diagnostic);
+    if (XLA_FFI_PREDICT_FALSE(!value.has_value())) {
+      return Unexpected(Error::Internal(diagnostic.Result()));
+    }
+
+    return *value;
+  }
+};
+
+template <>
+struct internal::Decode<internal::RemainingRetsTag> {
+  static std::optional<RemainingRets> call(DecodingOffsets& offsets,
+                                           DecodingContext& ctx,
+                                           DiagnosticEngine& diagnostic) {
+    return RemainingRets(&ctx.call_frame->rets, offsets.rets);
+  }
+};
+
+//===----------------------------------------------------------------------===//
 // Attributes decoding
 //===----------------------------------------------------------------------===//
 
@@ -581,6 +748,49 @@
 };
 
 //===----------------------------------------------------------------------===//
+// Type-safe wrapper for accessing dictionary attributes.
+//===----------------------------------------------------------------------===//
+
+class Dictionary : public internal::DictionaryBase {
+ public:
+  using internal::DictionaryBase::DictionaryBase;
+
+  template <typename T>
+  ErrorOr<T> get(std::string_view name) const {
+    DiagnosticEngine diagnostic;
+    std::optional<T> value = internal::DictionaryBase::get<T>(name, diagnostic);
+    if (!value.has_value()) {
+      return Unexpected(Error::Internal(diagnostic.Result()));
+    }
+    return *value;
+  }
+};
+
+// Decode `AttrsTag` (all attributes) into a `Dictionary`.
+template <>
+struct internal::Decode<internal::AttrsTag<Dictionary>> {
+  static std::optional<Dictionary> call(DecodingOffsets& offsets,
+                                        DecodingContext& ctx,
+                                        DiagnosticEngine& diagnostic) {
+    return Dictionary(&ctx.call_frame->attrs);
+  }
+};
+
+// Decode an individual attribute into a `Dictionary`.
+template <>
+struct AttrDecoding<Dictionary> {
+  using Type = Dictionary;
+  static std::optional<Dictionary> Decode(XLA_FFI_AttrType type, void* attr,
+                                          DiagnosticEngine& diagnostic) {
+    if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) {
+      return diagnostic.Emit("Wrong attribute type: expected ")
+             << XLA_FFI_AttrType_DICTIONARY << " but got " << type;
+    }
+    return Dictionary(reinterpret_cast<XLA_FFI_Attrs*>(attr));
+  }
+};
+
+//===----------------------------------------------------------------------===//
 // Error helpers
 //===----------------------------------------------------------------------===//
 
@@ -758,6 +968,7 @@
     internal::DestroyError(api_, error);
     return std::nullopt;
   }
+  allocations_.push_back({size, args.data});
   return args.data;
 }
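A small sketch of the accessor change in this header: the `get` helpers that previously returned `Expected<T, std::string>` now return `ErrorOr<T>`, so failures propagate as structured `Error` values. Here `dict` stands for a `Dictionary` received by some handler.

    xla::ffi::ErrorOr<float> f32 = dict.get<float>("f32");
    if (!f32.has_value()) {
      // The diagnostic text is wrapped into an Error (ErrorCode::kInternal).
      return f32.error();
    }
    float value = *f32;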
 
diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc
index 8db0d46..2bbfd04 100644
--- a/third_party/xla/xla/ffi/api/ffi_test.cc
+++ b/third_party/xla/xla/ffi/api/ffi_test.cc
@@ -19,7 +19,9 @@
 #include <cstdint>
 #include <limits>
 #include <memory>
+#include <optional>
 #include <string>
+#include <string_view>
 #include <type_traits>
 #include <vector>
 
@@ -237,12 +239,11 @@
   builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2});
   auto call_frame = builder.Build();
 
-  auto handler =
-      Ffi::Bind().Arg<BufferR2<DataType::F32>>().To([&](auto buffer) {
-        EXPECT_EQ(buffer.typed_data(), storage.data());
-        EXPECT_EQ(buffer.dimensions().size(), 2);
-        return Error::Success();
-      });
+  auto handler = Ffi::Bind().Arg<BufferR2<F32>>().To([&](auto buffer) {
+    EXPECT_EQ(buffer.typed_data(), storage.data());
+    EXPECT_EQ(buffer.dimensions().size(), 2);
+    return Error::Success();
+  });
   auto status = Call(*handler, call_frame);
 
   TF_ASSERT_OK(status);
@@ -270,7 +271,7 @@
   CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
   auto call_frame = builder.Build();
 
-  auto handler = Ffi::Bind().Arg<BufferR1<DataType::F32>>().To(
+  auto handler = Ffi::Bind().Arg<BufferR1<F32>>().To(
       [](auto) { return Error::Success(); });
   auto status = Call(*handler, call_frame);
 
@@ -286,7 +287,7 @@
   builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2});
   auto call_frame = builder.Build();
 
-  auto handler = Ffi::Bind().Arg<BufferR1<DataType::F32>>().To(
+  auto handler = Ffi::Bind().Arg<BufferR1<F32>>().To(
       [](auto) { return Error::Success(); });
   auto status = Call(*handler, call_frame);
 
@@ -303,7 +304,7 @@
   builder.AddBufferArg(memory, PrimitiveType::S32, /*dims=*/{2, 2});
   auto call_frame = builder.Build();
 
-  auto handler = Ffi::Bind().Arg<BufferR2<DataType::F32>>().To(
+  auto handler = Ffi::Bind().Arg<BufferR2<F32>>().To(
       [](auto) { return Error::Success(); });
   auto status = Call(*handler, call_frame);
 
@@ -322,7 +323,7 @@
   auto fn = [&](Token tok) {
     EXPECT_EQ(tok.typed_data(), nullptr);
     EXPECT_EQ(tok.dimensions().size(), 0);
-    return ffi::Error::Success();
+    return Error::Success();
   };
 
   auto handler = Ffi::Bind().Arg<Token>().To(fn);
@@ -330,6 +331,182 @@
   TF_ASSERT_OK(status);
 }
 
+TEST(FfiTest, RemainingArgs) {
+  std::vector<float> storage(4, 0.0f);
+  se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+  CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0);
+  builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+  auto call_frame = builder.Build();
+
+  auto fn = [&](RemainingArgs args) {
+    EXPECT_EQ(args.size(), 1);
+
+    ErrorOr<AnyBuffer> arg0 = args.get<AnyBuffer>(0);
+    ErrorOr<AnyBuffer> arg1 = args.get<AnyBuffer>(1);
+
+    EXPECT_TRUE(arg0.has_value());
+    EXPECT_FALSE(arg1.has_value());
+
+    return Error::Success();
+  };
+
+  auto handler = Ffi::Bind().RemainingArgs().To(fn);
+  auto status = Call(*handler, call_frame);
+
+  TF_ASSERT_OK(status);
+}
+
+TEST(FfiTest, RemainingRets) {
+  std::vector<float> storage(4, 0.0f);
+  se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+  CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/2);
+  builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+  builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+  auto call_frame = builder.Build();
+
+  auto fn = [&](Result<AnyBuffer> ret, RemainingRets rets) {
+    EXPECT_EQ(rets.size(), 1);
+
+    ErrorOr<Result<AnyBuffer>> ret0 = rets.get<AnyBuffer>(0);
+    ErrorOr<Result<AnyBuffer>> ret1 = rets.get<AnyBuffer>(1);
+
+    EXPECT_TRUE(ret0.has_value());
+    EXPECT_FALSE(ret1.has_value());
+
+    return Error::Success();
+  };
+
+  auto handler = Ffi::Bind().Ret<AnyBuffer>().RemainingRets().To(fn);
+  auto status = Call(*handler, call_frame);
+
+  TF_ASSERT_OK(status);
+}
+
+TEST(FfiTest, OptionalArgs) {
+  std::vector<float> storage(4, 0.0f);
+  se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+  CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0);
+  builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+  auto call_frame = builder.Build();
+
+  {  // Single optional argument.
+    auto fn = [&](std::optional<AnyBuffer> arg0) {
+      EXPECT_TRUE(arg0.has_value());
+      return Error::Success();
+    };
+
+    auto handler = Ffi::Bind().OptionalArg<AnyBuffer>().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+
+  {  // Two optional arguments.
+    auto fn = [&](std::optional<AnyBuffer> arg0,
+                  std::optional<AnyBuffer> arg1) {
+      EXPECT_TRUE(arg0.has_value());
+      EXPECT_FALSE(arg1.has_value());
+      return Error::Success();
+    };
+
+    auto handler =
+        Ffi::Bind().OptionalArg<AnyBuffer>().OptionalArg<AnyBuffer>().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+
+  {  // Optional argument after a regular one.
+    auto fn = [&](AnyBuffer arg0, std::optional<AnyBuffer> arg1) {
+      EXPECT_FALSE(arg1.has_value());
+      return Error::Success();
+    };
+
+    auto handler = Ffi::Bind().Arg<AnyBuffer>().OptionalArg<AnyBuffer>().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+
+  {  // Remaining arguments after optional one.
+    auto fn = [&](std::optional<AnyBuffer> arg0, RemainingArgs args) {
+      EXPECT_TRUE(arg0.has_value());
+      EXPECT_EQ(args.size(), 0);
+      return Error::Success();
+    };
+
+    auto handler = Ffi::Bind().OptionalArg<AnyBuffer>().RemainingArgs().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+}
+
+TEST(FfiTest, OptionalRets) {
+  std::vector<float> storage(4, 0.0f);
+  se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+  CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1);
+  builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+  auto call_frame = builder.Build();
+
+  {  // Single optional result.
+    auto fn = [&](std::optional<Result<AnyBuffer>> ret0) {
+      EXPECT_TRUE(ret0.has_value());
+      return Error::Success();
+    };
+
+    auto handler = Ffi::Bind().OptionalRet<AnyBuffer>().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+
+  {  // Two optional results.
+    auto fn = [&](std::optional<Result<AnyBuffer>> ret0,
+                  std::optional<Result<AnyBuffer>> ret1) {
+      EXPECT_TRUE(ret0.has_value());
+      EXPECT_FALSE(ret1.has_value());
+      return Error::Success();
+    };
+
+    auto handler =
+        Ffi::Bind().OptionalRet<AnyBuffer>().OptionalRet<AnyBuffer>().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+
+  {  // Optional result after a regular one.
+    auto fn = [&](Result<AnyBuffer> ret0,
+                  std::optional<Result<AnyBuffer>> ret1) {
+      EXPECT_FALSE(ret1.has_value());
+      return Error::Success();
+    };
+
+    auto handler = Ffi::Bind().Ret<AnyBuffer>().OptionalRet<AnyBuffer>().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+
+  {  // Remaining results after optional one.
+    auto fn = [&](std::optional<Result<AnyBuffer>> ret0, RemainingRets rets) {
+      EXPECT_TRUE(ret0.has_value());
+      EXPECT_EQ(rets.size(), 0);
+      return Error::Success();
+    };
+
+    auto handler = Ffi::Bind().OptionalRet<AnyBuffer>().RemainingRets().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+}
+
 TEST(FfiTest, AutoBinding) {
   static constexpr char kI32[] = "i32";
 
@@ -463,6 +640,150 @@
   TF_ASSERT_OK(status);
 }
 
+TEST(FfiTest, AttrsAsDictionary) {
+  CallFrameBuilder::AttributesBuilder attrs;
+  attrs.Insert("i32", 42);
+  attrs.Insert("f32", 42.0f);
+  attrs.Insert("str", "foo");
+
+  CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
+  builder.AddAttributes(attrs.Build());
+  auto call_frame = builder.Build();
+
+  auto fn = [&](Dictionary dict) {
+    EXPECT_EQ(dict.size(), 3);
+
+    EXPECT_TRUE(dict.contains("i32"));
+    EXPECT_TRUE(dict.contains("f32"));
+    EXPECT_TRUE(dict.contains("str"));
+
+    ErrorOr<int32_t> i32 = dict.get<int32_t>("i32");
+    ErrorOr<float> f32 = dict.get<float>("f32");
+    ErrorOr<std::string_view> str = dict.get<std::string_view>("str");
+
+    EXPECT_TRUE(i32.has_value());
+    EXPECT_TRUE(f32.has_value());
+    EXPECT_TRUE(str.has_value());
+
+    if (i32.has_value()) EXPECT_EQ(*i32, 42);
+    if (f32.has_value()) EXPECT_EQ(*f32, 42.0f);
+    if (str.has_value()) EXPECT_EQ(*str, "foo");
+
+    EXPECT_FALSE(dict.contains("i64"));
+    EXPECT_FALSE(dict.get<int64_t>("i32").has_value());
+    EXPECT_FALSE(dict.get<int64_t>("i64").has_value());
+
+    return Error::Success();
+  };
+
+  auto handler = Ffi::Bind().Attrs().To(fn);
+  auto status = Call(*handler, call_frame);
+
+  TF_ASSERT_OK(status);
+}
+
+TEST(FfiTest, DictionaryAttr) {
+  CallFrameBuilder::FlatAttributesMap dict0;
+  dict0.try_emplace("i32", 42);
+
+  CallFrameBuilder::FlatAttributesMap dict1;
+  dict1.try_emplace("f32", 42.0f);
+
+  CallFrameBuilder::AttributesBuilder attrs;
+  attrs.Insert("dict0", dict0);
+  attrs.Insert("dict1", dict1);
+
+  CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
+  builder.AddAttributes(attrs.Build());
+  auto call_frame = builder.Build();
+
+  auto fn = [&](Dictionary dict0, Dictionary dict1) {
+    EXPECT_EQ(dict0.size(), 1);
+    EXPECT_EQ(dict1.size(), 1);
+
+    EXPECT_TRUE(dict0.contains("i32"));
+    EXPECT_TRUE(dict1.contains("f32"));
+
+    ErrorOr<int32_t> i32 = dict0.get<int32_t>("i32");
+    ErrorOr<float> f32 = dict1.get<float>("f32");
+
+    EXPECT_TRUE(i32.has_value());
+    EXPECT_TRUE(f32.has_value());
+
+    if (i32.has_value()) EXPECT_EQ(*i32, 42);
+    if (f32.has_value()) EXPECT_EQ(*f32, 42.0f);
+
+    return Error::Success();
+  };
+
+  auto handler =
+      Ffi::Bind().Attr<Dictionary>("dict0").Attr<Dictionary>("dict1").To(fn);
+
+  auto status = Call(*handler, call_frame);
+
+  TF_ASSERT_OK(status);
+}
+
+struct PairOfI32AndF32 {
+  int32_t i32;
+  float f32;
+};
+
+XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(PairOfI32AndF32,
+                                      StructMember<int32_t>("i32"),
+                                      StructMember<float>("f32"));
+
+TEST(FfiTest, StructAttr) {
+  CallFrameBuilder::FlatAttributesMap dict;
+  dict.try_emplace("i32", 42);
+  dict.try_emplace("f32", 42.0f);
+
+  CallFrameBuilder::AttributesBuilder attrs;
+  attrs.Insert("str", "foo");
+  attrs.Insert("i32_and_f32", dict);
+
+  CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
+  builder.AddAttributes(attrs.Build());
+  auto call_frame = builder.Build();
+
+  auto fn = [&](std::string_view str, PairOfI32AndF32 i32_and_f32) {
+    EXPECT_EQ(str, "foo");
+    EXPECT_EQ(i32_and_f32.i32, 42);
+    EXPECT_EQ(i32_and_f32.f32, 42.0f);
+    return Error::Success();
+  };
+
+  auto handler = Ffi::Bind()
+                     .Attr<std::string_view>("str")
+                     .Attr<PairOfI32AndF32>("i32_and_f32")
+                     .To(fn);
+
+  auto status = Call(*handler, call_frame);
+
+  TF_ASSERT_OK(status);
+}
+
+TEST(FfiTest, AttrsAsStruct) {
+  CallFrameBuilder::AttributesBuilder attrs;
+  attrs.Insert("i32", 42);
+  attrs.Insert("f32", 42.0f);
+
+  CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
+  builder.AddAttributes(attrs.Build());
+  auto call_frame = builder.Build();
+
+  auto fn = [&](PairOfI32AndF32 i32_and_f32) {
+    EXPECT_EQ(i32_and_f32.i32, 42);
+    EXPECT_EQ(i32_and_f32.f32, 42.0f);
+    return Error::Success();
+  };
+
+  auto handler = Ffi::Bind().Attrs<PairOfI32AndF32>().To(fn);
+  auto status = Call(*handler, call_frame);
+
+  TF_ASSERT_OK(status);
+}
+
 TEST(FfiTest, PointerAttr) {
   std::string foo = "foo";
 
@@ -641,14 +962,19 @@
 
   // A test-only memory allocator that returns a fixed memory address.
   struct TestDeviceMemoryAllocator final : public se::DeviceMemoryAllocator {
-    TestDeviceMemoryAllocator() : se::DeviceMemoryAllocator(nullptr) {}
+    size_t count;
+
+    TestDeviceMemoryAllocator()
+        : se::DeviceMemoryAllocator(nullptr), count(0) {}
 
     absl::StatusOr<se::OwningDeviceMemory> Allocate(int, uint64_t size, bool,
                                                     int64_t) final {
+      count++;
       return se::OwningDeviceMemory(se::DeviceMemoryBase(kAddr, size), 0, this);
     }
 
     absl::Status Deallocate(int, se::DeviceMemoryBase mem) final {
+      count--;
       EXPECT_EQ(mem.opaque(), kAddr);
       return absl::OkStatus();
     }
@@ -672,11 +998,25 @@
       CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build();
 
   CallOptions options;
-  options.allocator = &allocator;
+  options.backend_options = CallOptions::GpuOptions{nullptr, &allocator};
 
   auto status = Call(*handler, call_frame, options);
 
   TF_ASSERT_OK(status);
+  EXPECT_EQ(allocator.count, 0);
+}
+
+TEST(FfiTest, ScratchAllocatorUnimplemented) {
+  auto fn = [&](ScratchAllocator scratch_allocator) {
+    auto mem = scratch_allocator.Allocate(1024);
+    EXPECT_FALSE(mem.has_value());
+    return Error::Success();
+  };
+  auto handler = Ffi::Bind().Ctx<ScratchAllocator>().To(fn);
+  CallFrame call_frame =
+      CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build();
+  auto status = Call(*handler, call_frame);
+  TF_ASSERT_OK(status);
 }
 
 //===----------------------------------------------------------------------===//
@@ -747,7 +1087,7 @@
 void BM_BufferArgX1(benchmark::State& state) {
   auto call_frame = WithBufferArgs(1).Build();
 
-  auto handler = Ffi::Bind().Arg<BufferR4<DataType::F32>>().To([](auto buffer) {
+  auto handler = Ffi::Bind().Arg<BufferR4<F32>>().To([](auto buffer) {
     benchmark::DoNotOptimize(buffer);
     return Error::Success();
   });
@@ -767,10 +1107,10 @@
   auto call_frame = WithBufferArgs(4).Build();
 
   auto handler = Ffi::Bind()
-                     .Arg<BufferR4<DataType::F32>>()
-                     .Arg<BufferR4<DataType::F32>>()
-                     .Arg<BufferR4<DataType::F32>>()
-                     .Arg<BufferR4<DataType::F32>>()
+                     .Arg<BufferR4<F32>>()
+                     .Arg<BufferR4<F32>>()
+                     .Arg<BufferR4<F32>>()
+                     .Arg<BufferR4<F32>>()
                      .To([](auto b0, auto b1, auto b2, auto b3) {
                        benchmark::DoNotOptimize(b0);
                        benchmark::DoNotOptimize(b1);
@@ -794,14 +1134,14 @@
   auto call_frame = WithBufferArgs(8).Build();
 
   auto handler = Ffi::Bind()
-                     .Arg<BufferR4<DataType::F32>>()
-                     .Arg<BufferR4<DataType::F32>>()
-                     .Arg<BufferR4<DataType::F32>>()
-                     .Arg<BufferR4<DataType::F32>>()
-                     .Arg<BufferR4<DataType::F32>>()
-                     .Arg<BufferR4<DataType::F32>>()
-                     .Arg<BufferR4<DataType::F32>>()
-                     .Arg<BufferR4<DataType::F32>>()
+                     .Arg<BufferR4<F32>>()
+                     .Arg<BufferR4<F32>>()
+                     .Arg<BufferR4<F32>>()
+                     .Arg<BufferR4<F32>>()
+                     .Arg<BufferR4<F32>>()
+                     .Arg<BufferR4<F32>>()
+                     .Arg<BufferR4<F32>>()
+                     .Arg<BufferR4<F32>>()
                      .To([](auto b0, auto b1, auto b2, auto b3, auto b4,
                             auto b5, auto b6, auto b7) {
                        benchmark::DoNotOptimize(b0);
diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h
index 2ace88a..82076c5 100644
--- a/third_party/xla/xla/ffi/ffi.h
+++ b/third_party/xla/xla/ffi/ffi.h
@@ -27,6 +27,8 @@
 #include <limits>
 #include <memory>
 #include <optional>
+#include <string>
+#include <string_view>
 
 // IWYU pragma: begin_exports
 #include "xla/ffi/api/api.h"
@@ -38,6 +40,7 @@
 #include "absl/base/optimization.h"
 #include "absl/status/status.h"
 #include "absl/types/span.h"
+#include "xla/executable_run_options.h"
 #include "xla/ffi/api/c_api.h"
 #include "xla/ffi/api/c_api_internal.h"  // IWYU pragma: keep
 #include "xla/ffi/execution_context.h"
@@ -49,6 +52,7 @@
 #include "xla/stream_executor/scratch_allocator.h"
 #include "xla/stream_executor/stream.h"
 #include "xla/types.h"  // IWYU pragma: keep
+#include "xla/util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/logging.h"
 
@@ -60,6 +64,7 @@
 struct Allocator {};          // binds `se::DeviceMemoryAllocator*`
 struct ScratchAllocator {};   // binds `se::OwningScratchAllocator`
 struct CalledComputation {};  // binds `HloComputation*`
+struct IntraOpThreadPool {};  // binds `const Eigen::ThreadPoolDevice*`
 
 //===----------------------------------------------------------------------===//
 // Arguments
@@ -239,6 +244,41 @@
 };
 
 //===----------------------------------------------------------------------===//
+// Type-safe wrapper for accessing a variable number of arguments.
+//===----------------------------------------------------------------------===//
+
+class RemainingArgs : public internal::RemainingArgsBase {
+ public:
+  using internal::RemainingArgsBase::RemainingArgsBase;
+
+  template <typename T>
+  absl::StatusOr<T> get(size_t index) const {
+    size_t idx = offset() + index;
+    if (ABSL_PREDICT_FALSE(idx >= args()->size)) {
+      return InvalidArgument("Index out of range.");
+    }
+
+    DiagnosticEngine diagnostic;
+    std::optional<T> value = ArgDecoding<T>::Decode(
+        args()->types[idx], args()->args[idx], diagnostic);
+    if (ABSL_PREDICT_FALSE(!value.has_value())) {
+      return Internal("%s", diagnostic.Result());
+    }
+
+    return *value;
+  }
+};
+
+template <>
+struct internal::Decode<internal::RemainingArgsTag> {
+  static std::optional<RemainingArgs> call(DecodingOffsets& offsets,
+                                           DecodingContext& ctx,
+                                           DiagnosticEngine& diagnostic) {
+    return RemainingArgs(&ctx.call_frame->args, offsets.args);
+  }
+};
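An illustrative sketch (not part of this patch) of how a handler might consume `RemainingArgs`; the handler name and the choice of `AnyBuffer` are assumptions, mirroring the bindings exercised in the tests below.

// Sketch only: one fixed argument followed by a variadic tail.
static absl::Status VariadicHandler(AnyBuffer first, RemainingArgs rest) {
  for (size_t i = 0; i < rest.size(); ++i) {
    // get<T> returns absl::StatusOr<T>; out-of-range indices and type
    // mismatches surface as statuses rather than crashes.
    absl::StatusOr<AnyBuffer> buf = rest.get<AnyBuffer>(i);
    if (!buf.ok()) return buf.status();
  }
  return absl::OkStatus();
}
// Assumed binding, following the test style:
//   auto handler = Ffi::Bind().Arg<AnyBuffer>().RemainingArgs().To(VariadicHandler);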
+
+//===----------------------------------------------------------------------===//
 // Results decoding
 //===----------------------------------------------------------------------===//
 
@@ -272,6 +312,41 @@
 };
 
 //===----------------------------------------------------------------------===//
+// Type-safe wrapper for accessing a variable number of results.
+//===----------------------------------------------------------------------===//
+
+class RemainingRets : public internal::RemainingRetsBase {
+ public:
+  using internal::RemainingRetsBase::RemainingRetsBase;
+
+  template <typename T>
+  absl::StatusOr<Result<T>> get(size_t index) const {
+    size_t idx = offset() + index;
+    if (ABSL_PREDICT_FALSE(idx >= rets()->size)) {
+      return InvalidArgument("Index out of range.");
+    }
+
+    DiagnosticEngine diagnostic;
+    std::optional<Result<T>> value = RetDecoding<T>::Decode(
+        rets()->types[idx], rets()->rets[idx], diagnostic);
+    if (ABSL_PREDICT_FALSE(!value.has_value())) {
+      return Internal("%s", diagnostic.Result());
+    }
+
+    return *value;
+  }
+};
+
+template <>
+struct internal::Decode<internal::RemainingRetsTag> {
+  static std::optional<RemainingRets> call(DecodingOffsets& offsets,
+                                           DecodingContext& ctx,
+                                           DiagnosticEngine& diagnostic) {
+    return RemainingRets(&ctx.call_frame->rets, offsets.rets);
+  }
+};
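Likewise, a hedged sketch (names are made up) of a handler over a variable number of results, using the same `Result<AnyBuffer>` accessor the tests rely on.

// Sketch only: iterate over however many results the call frame carries.
static absl::Status InspectRets(RemainingRets rets) {
  for (size_t i = 0; i < rets.size(); ++i) {
    absl::StatusOr<Result<AnyBuffer>> ret = rets.get<AnyBuffer>(i);
    if (!ret.ok()) return ret.status();
  }
  return absl::OkStatus();
}
// Assumed binding: Ffi::Bind().RemainingRets().To(InspectRets);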
+
+//===----------------------------------------------------------------------===//
 // Attributes decoding
 //===----------------------------------------------------------------------===//
 
@@ -330,6 +405,49 @@
 };
 
 //===----------------------------------------------------------------------===//
+// Type-safe wrapper for accessing dictionary attributes.
+//===----------------------------------------------------------------------===//
+
+class Dictionary : public internal::DictionaryBase {
+ public:
+  using internal::DictionaryBase::DictionaryBase;
+
+  template <typename T>
+  absl::StatusOr<T> get(std::string_view name) const {
+    DiagnosticEngine diagnostic;
+    std::optional<T> value = internal::DictionaryBase::get<T>(name, diagnostic);
+    if (!value.has_value()) {
+      return Internal("%s", diagnostic.Result());
+    }
+    return *value;
+  }
+};
+
+// Decode `AttrsTag` (all attributes) into a `Dictionary`.
+template <>
+struct internal::Decode<internal::AttrsTag<Dictionary>> {
+  static std::optional<Dictionary> call(DecodingOffsets& offsets,
+                                        DecodingContext& ctx,
+                                        DiagnosticEngine& diagnostic) {
+    return Dictionary(&ctx.call_frame->attrs);
+  }
+};
+
+// Decode individual attribute into `Dictionary` type.
+template <>
+struct AttrDecoding<Dictionary> {
+  using Type = Dictionary;
+  static std::optional<Dictionary> Decode(XLA_FFI_AttrType type, void* attr,
+                                          DiagnosticEngine& diagnostic) {
+    if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) {
+      return diagnostic.Emit("Wrong attribute type: expected ")
+             << XLA_FFI_AttrType_DICTIONARY << " but got " << type;
+    }
+    return Dictionary(reinterpret_cast<XLA_FFI_Attrs*>(attr));
+  }
+};
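A hedged sketch of the two ways a handler can receive dictionary-shaped attributes; the attribute names are hypothetical, and the bindings mirror the `Attrs()` / `Attr<Dictionary>(...)` usage exercised in the tests.

// Sketch only: all attributes delivered as one dictionary.
static absl::Status AllAttrs(Dictionary attrs) {
  // Typed lookup returns absl::StatusOr<T>, so a missing key or a type
  // mismatch is reported as a status. "version" is a hypothetical key.
  absl::StatusOr<int32_t> version = attrs.get<int32_t>("version");
  return version.ok() ? absl::OkStatus() : version.status();
}
// Assumed bindings:
//   Ffi::Bind().Attrs().To(AllAttrs);                // the whole attribute map
//   Ffi::Bind().Attr<Dictionary>("config").To(...);  // one nested dictionary attr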
+
+//===----------------------------------------------------------------------===//
 // Context decoding
 //===----------------------------------------------------------------------===//
 
@@ -365,7 +483,7 @@
                                     DiagnosticEngine&) {
     void* device_allocator =
         api->internal_api->XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get(ctx);
-    return reinterpret_cast<se::DeviceMemoryAllocator*>(device_allocator);
+    return reinterpret_cast<Type>(device_allocator);
   }
 };
 
@@ -399,6 +517,19 @@
   }
 };
 
+template <>
+struct CtxDecoding<IntraOpThreadPool> {
+  using Type = const Eigen::ThreadPoolDevice*;
+
+  static std::optional<Type> Decode(const XLA_FFI_Api* api,
+                                    XLA_FFI_ExecutionContext* ctx,
+                                    DiagnosticEngine&) {
+    void* intra_op_thread_pool =
+        api->internal_api->XLA_FFI_INTERNAL_IntraOpThreadPool_Get(ctx);
+    return reinterpret_cast<Type>(intra_op_thread_pool);
+  }
+};
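A minimal sketch, assuming `Ctx<IntraOpThreadPool>` is bound the same way as the other context types in this header; the handler body is illustrative only.

// Sketch only: a CPU handler that receives the intra-op Eigen thread pool
// decoded by CtxDecoding<IntraOpThreadPool> above.
static absl::Status CpuHandler(const Eigen::ThreadPoolDevice* device) {
  if (device == nullptr) return absl::InternalError("no intra-op thread pool");
  int threads = device->numThreads();  // e.g. to size a parallel loop
  (void)threads;
  return absl::OkStatus();
}
// Assumed binding: Ffi::Bind().Ctx<IntraOpThreadPool>().To(CpuHandler);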
+
 //===----------------------------------------------------------------------===//
 // UserData
 //===----------------------------------------------------------------------===//
diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc
index f402ed2..c5f07eb 100644
--- a/third_party/xla/xla/ffi/ffi_api.cc
+++ b/third_party/xla/xla/ffi/ffi_api.cc
@@ -15,12 +15,14 @@
 
 #include "xla/ffi/ffi_api.h"
 
+#include <algorithm>
 #include <cstddef>
 #include <cstdint>
 #include <exception>
 #include <string>
 #include <string_view>
 #include <utility>
+#include <variant>
 #include <vector>
 
 #include "absl/base/optimization.h"
@@ -31,6 +33,7 @@
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
+#include "xla/executable_run_options.h"
 #include "xla/ffi/api/api.h"
 #include "xla/ffi/api/c_api.h"
 #include "xla/ffi/api/c_api_internal.h"  // IWYU pragma: keep
@@ -56,10 +59,19 @@
 };
 
 struct XLA_FFI_ExecutionContext {
-  int32_t device_ordinal = -1;
+  struct CpuContext {
+    const Eigen::ThreadPoolDevice* intra_op_thread_pool = nullptr;
+  };
 
-  stream_executor::Stream* stream = nullptr;
-  stream_executor::DeviceMemoryAllocator* allocator = nullptr;
+  struct GpuContext {
+    stream_executor::Stream* stream = nullptr;
+    stream_executor::DeviceMemoryAllocator* allocator = nullptr;
+  };
+
+  using BackendContext = std::variant<std::monostate, CpuContext, GpuContext>;
+
+  int32_t device_ordinal = -1;
+  BackendContext backend_context = {};
 
   const xla::HloComputation* called_computation = nullptr;
   const xla::ffi::ExecutionContext* execution_context = nullptr;
@@ -76,10 +88,27 @@
 
 static XLA_FFI_ExecutionContext CreateExecutionContext(
     const CallOptions& options) {
+  using BackendContext = XLA_FFI_ExecutionContext::BackendContext;
+
+  // Converts CallOptions to corresponding backend context.
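+  // (Each CallOptions alternative maps 1:1 onto an execution-context struct.)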
+  struct BackendVisitor {
+    BackendContext operator()(const std::monostate&) const {
+      return std::monostate{};
+    }
+
+    BackendContext operator()(const CallOptions::CpuOptions& options) const {
+      return XLA_FFI_ExecutionContext::CpuContext{options.intra_op_thread_pool};
+    }
+
+    BackendContext operator()(const CallOptions::GpuOptions& options) const {
+      return XLA_FFI_ExecutionContext::GpuContext{options.stream,
+                                                  options.allocator};
+    }
+  };
+
   return XLA_FFI_ExecutionContext{
       options.device_ordinal,
-      options.stream,
-      options.allocator,
+      std::visit(BackendVisitor{}, options.backend_options),
       options.called_computation,
       internal::ScopedExecutionContext::GetCallExecutionContext(options),
       options.execution_state,
@@ -376,12 +405,20 @@
       "XLA_FFI_Stream_Get", XLA_FFI_Stream_Get_Args_STRUCT_SIZE,
       args->struct_size));
 
-  if (args->ctx->stream == nullptr) {
+  auto* gpu = std::get_if<XLA_FFI_ExecutionContext::GpuContext>(
+      &args->ctx->backend_context);
+
+  if (ABSL_PREDICT_FALSE(gpu == nullptr)) {
     return new XLA_FFI_Error{
-        InvalidArgument("XLA FFI stream is not available")};
+        Unimplemented("XLA FFI GPU context is not available")};
   }
 
-  auto handle = args->ctx->stream->platform_specific_handle();
+  if (ABSL_PREDICT_FALSE(gpu->stream == nullptr)) {
+    return new XLA_FFI_Error{
+        Unimplemented("XLA FFI GPU stream is not available")};
+  }
+
+  auto handle = gpu->stream->platform_specific_handle();
   args->stream = handle.stream;
 
   return nullptr;
@@ -459,6 +496,22 @@
       "XLA_FFI_DeviceMemory_Allocate_Args",
       XLA_FFI_DeviceMemory_Allocate_Args_STRUCT_SIZE, args->struct_size));
 
+  auto* gpu = std::get_if<XLA_FFI_ExecutionContext::GpuContext>(
+      &args->ctx->backend_context);
+
+  // TODO(ezhulenev): Device memory allocation should be supported for all
+  // backends, not just GPU, although for CPU it doesn't make much sense, as
+  // plain `new` is sufficient.
+  if (ABSL_PREDICT_FALSE(gpu == nullptr)) {
+    return new XLA_FFI_Error{
+        InvalidArgument("XLA FFI GPU context is not available")};
+  }
+
+  if (ABSL_PREDICT_FALSE(gpu->allocator == nullptr)) {
+    return new XLA_FFI_Error{
+        Unimplemented("No device memory allocator available on this platform")};
+  }
+
   // TODO(ezhulenev): We happen to have the same alignment requirement for
   // device memory on CPU and GPU backends, but instead of hardcoding it here
   // we should query it for the platform the XLA FFI handler is registered with.
@@ -471,7 +524,7 @@
   }
 
   absl::StatusOr<stream_executor::OwningDeviceMemory> memory =
-      args->ctx->allocator->Allocate(args->ctx->device_ordinal, args->size);
+      gpu->allocator->Allocate(args->ctx->device_ordinal, args->size);
   if (!memory.ok()) {
     return new XLA_FFI_Error{std::move(memory).status()};
   }
@@ -486,7 +539,23 @@
       "XLA_FFI_DeviceMemory_Free_Args",
       XLA_FFI_DeviceMemory_Free_Args_STRUCT_SIZE, args->struct_size));
 
-  absl::Status status = args->ctx->allocator->Deallocate(
+  auto* gpu = std::get_if<XLA_FFI_ExecutionContext::GpuContext>(
+      &args->ctx->backend_context);
+
+  // TODO(ezhulenev): Device memory allocation should be supported for all
+  // backends, not just GPU, although for CPU it doesn't make much sense, as
+  // plain `new` is sufficient.
+  if (ABSL_PREDICT_FALSE(gpu == nullptr)) {
+    return new XLA_FFI_Error{
+        Unimplemented("XLA FFI GPU context is not available")};
+  }
+
+  if (ABSL_PREDICT_FALSE(gpu->allocator == nullptr)) {
+    return new XLA_FFI_Error{
+        Unimplemented("No device memory allocator available on this platform")};
+  }
+
+  absl::Status status = gpu->allocator->Deallocate(
       args->ctx->device_ordinal,
       stream_executor::DeviceMemoryBase(args->data, args->size));
   if (!status.ok()) {
@@ -509,7 +578,13 @@
 }
 
 static void* XLA_FFI_INTERNAL_Stream_Get(XLA_FFI_ExecutionContext* ctx) {
-  return ctx->stream;
+  if (auto* gpu = std::get_if<XLA_FFI_ExecutionContext::GpuContext>(
+          &ctx->backend_context)) {
+    return gpu->stream;
+  }
+
+  return new XLA_FFI_Error{
+      InvalidArgument("XLA FFI GPU context is not available")};
 }
 
 static int32_t XLA_FFI_INTERNAL_DeviceOrdinal_Get(
@@ -519,7 +594,13 @@
 
 static void* XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get(
     XLA_FFI_ExecutionContext* ctx) {
-  return ctx->allocator;
+  if (auto* gpu = std::get_if<XLA_FFI_ExecutionContext::GpuContext>(
+          &ctx->backend_context)) {
+    return gpu->allocator;
+  }
+
+  return new XLA_FFI_Error{
+      InvalidArgument("XLA FFI GPU context is not available")};
 }
 
 static void* XLA_FFI_INTERNAL_CalledComputation_Get(
@@ -537,6 +618,16 @@
   return const_cast<ffi::ExecutionState*>(ctx->execution_state);
 }
 
+void* XLA_FFI_INTERNAL_IntraOpThreadPool_Get(XLA_FFI_ExecutionContext* ctx) {
+  if (auto* cpu = std::get_if<XLA_FFI_ExecutionContext::CpuContext>(
+          &ctx->backend_context)) {
+    return const_cast<Eigen::ThreadPoolDevice*>(cpu->intra_op_thread_pool);
+  }
+
+  return new XLA_FFI_Error{
+      InvalidArgument("XLA FFI CPU context is not available")};
+}
+
 //===----------------------------------------------------------------------===//
 // XLA FFI Api access
 //===----------------------------------------------------------------------===//
@@ -551,6 +642,7 @@
     XLA_FFI_INTERNAL_CalledComputation_Get,
     XLA_FFI_INTERNAL_ExecutionContext_Get,
     XLA_FFI_INTERNAL_ExecutionState_Get,
+    XLA_FFI_INTERNAL_IntraOpThreadPool_Get,
 };
 
 static XLA_FFI_Api api = {
diff --git a/third_party/xla/xla/ffi/ffi_api.h b/third_party/xla/xla/ffi/ffi_api.h
index 7a6e5aa..f583fa3 100644
--- a/third_party/xla/xla/ffi/ffi_api.h
+++ b/third_party/xla/xla/ffi/ffi_api.h
@@ -19,10 +19,12 @@
 #include <cstdint>
 #include <string>
 #include <string_view>
+#include <variant>
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "xla/executable_run_options.h"
 #include "xla/ffi/api/api.h"
 #include "xla/ffi/api/c_api.h"
 #include "xla/ffi/api/c_api_internal.h"  // IWYU pragma: keep
@@ -47,11 +49,23 @@
 // Calling XLA FFI handlers
 //===----------------------------------------------------------------------===//
 
+// Options for calling XLA FFI handlers. Backend-specific options must be
+// constructed from `xla::ExecutableRunOptions` to give FFI handlers access to
+// XLA runtime internals.
 struct CallOptions {
-  int32_t device_ordinal = -1;
+  struct CpuOptions {
+    const Eigen::ThreadPoolDevice* intra_op_thread_pool = nullptr;
+  };
 
-  se::Stream* stream = nullptr;
-  se::DeviceMemoryAllocator* allocator = nullptr;
+  struct GpuOptions {
+    se::Stream* stream = nullptr;
+    se::DeviceMemoryAllocator* allocator = nullptr;
+  };
+
+  using BackendOptions = std::variant<std::monostate, CpuOptions, GpuOptions>;
+
+  int32_t device_ordinal = -1;
+  BackendOptions backend_options = {};
 
   const HloComputation* called_computation = nullptr;
   const ExecutionContext* execution_context = nullptr;
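As a hedged usage sketch (variable names are placeholders), callers now select the backend-specific options explicitly instead of setting `stream`/`allocator` fields directly:

// Sketch only.
xla::ffi::CallOptions options;
options.device_ordinal = 0;
// GPU-backed call:
options.backend_options = xla::ffi::CallOptions::GpuOptions{stream, allocator};
// CPU-backed call (the pool is typically taken from ExecutableRunOptions):
// options.backend_options =
//     xla::ffi::CallOptions::CpuOptions{run_options.intra_op_thread_pool()};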
diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc
index ab8d200..9fb4ff8 100644
--- a/third_party/xla/xla/ffi/ffi_test.cc
+++ b/third_party/xla/xla/ffi/ffi_test.cc
@@ -317,21 +317,21 @@
     EXPECT_TRUE(dict.contains("f32"));
     EXPECT_TRUE(dict.contains("str"));
 
-    auto i32 = dict.get<int32_t>("i32");
-    auto f32 = dict.get<float>("f32");
-    auto str = dict.get<std::string_view>("str");
+    absl::StatusOr<int32_t> i32 = dict.get<int32_t>("i32");
+    absl::StatusOr<float> f32 = dict.get<float>("f32");
+    absl::StatusOr<std::string_view> str = dict.get<std::string_view>("str");
 
-    EXPECT_TRUE(i32.has_value());
-    EXPECT_TRUE(f32.has_value());
-    EXPECT_TRUE(str.has_value());
+    EXPECT_TRUE(i32.ok());
+    EXPECT_TRUE(f32.ok());
+    EXPECT_TRUE(str.ok());
 
-    if (i32) EXPECT_EQ(*i32, 42);
-    if (f32) EXPECT_EQ(*f32, 42.0f);
-    if (str) EXPECT_EQ(*str, "foo");
+    if (i32.ok()) EXPECT_EQ(*i32, 42);
+    if (f32.ok()) EXPECT_EQ(*f32, 42.0f);
+    if (str.ok()) EXPECT_EQ(*str, "foo");
 
     EXPECT_FALSE(dict.contains("i64"));
-    EXPECT_FALSE(dict.get<int64_t>("i32").has_value());
-    EXPECT_FALSE(dict.get<int64_t>("i64").has_value());
+    EXPECT_FALSE(dict.get<int64_t>("i32").ok());
+    EXPECT_FALSE(dict.get<int64_t>("i64").ok());
 
     return absl::OkStatus();
   };
@@ -364,14 +364,14 @@
     EXPECT_TRUE(dict0.contains("i32"));
     EXPECT_TRUE(dict1.contains("f32"));
 
-    auto i32 = dict0.get<int32_t>("i32");
-    auto f32 = dict1.get<float>("f32");
+    absl::StatusOr<int32_t> i32 = dict0.get<int32_t>("i32");
+    absl::StatusOr<float> f32 = dict1.get<float>("f32");
 
-    EXPECT_TRUE(i32.has_value());
-    EXPECT_TRUE(f32.has_value());
+    EXPECT_TRUE(i32.ok());
+    EXPECT_TRUE(f32.ok());
 
-    if (i32) EXPECT_EQ(*i32, 42);
-    if (f32) EXPECT_EQ(*f32, 42.0f);
+    if (i32.ok()) EXPECT_EQ(*i32, 42);
+    if (f32.ok()) EXPECT_EQ(*f32, 42.0f);
 
     return absl::OkStatus();
   };
@@ -631,8 +631,14 @@
 
   auto fn = [&](RemainingArgs args) {
     EXPECT_EQ(args.size(), 1);
-    EXPECT_TRUE(args.get<AnyBuffer>(0).has_value());
-    EXPECT_FALSE(args.get<AnyBuffer>(1).has_value());
+
+    absl::StatusOr<AnyBuffer> arg0 = args.get<AnyBuffer>(0);
+    absl::StatusOr<AnyBuffer> arg1 = args.get<AnyBuffer>(1);
+
+    EXPECT_TRUE(arg0.ok());
+    EXPECT_THAT(arg1.status(), StatusIs(absl::StatusCode::kInvalidArgument,
+                                        HasSubstr("Index out of range")));
+
     return absl::OkStatus();
   };
 
@@ -651,19 +657,148 @@
   builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2});
   auto call_frame = builder.Build();
 
-  auto fn = [&](Result<AnyBuffer> ret, RemainingResults rets) {
+  auto fn = [&](Result<AnyBuffer> ret, RemainingRets rets) {
     EXPECT_EQ(rets.size(), 1);
-    EXPECT_TRUE(rets.get<AnyBuffer>(0).has_value());
-    EXPECT_FALSE(rets.get<AnyBuffer>(1).has_value());
+
+    absl::StatusOr<Result<AnyBuffer>> ret0 = rets.get<AnyBuffer>(0);
+    absl::StatusOr<Result<AnyBuffer>> ret1 = rets.get<AnyBuffer>(1);
+
+    EXPECT_TRUE(ret0.ok());
+    EXPECT_THAT(ret1.status(), StatusIs(absl::StatusCode::kInvalidArgument,
+                                        HasSubstr("Index out of range")));
+
     return absl::OkStatus();
   };
 
-  auto handler = Ffi::Bind().Ret<AnyBuffer>().RemainingResults().To(fn);
+  auto handler = Ffi::Bind().Ret<AnyBuffer>().RemainingRets().To(fn);
   auto status = Call(*handler, call_frame);
 
   TF_ASSERT_OK(status);
 }
 
+TEST(FfiTest, OptionalArgs) {
+  std::vector<float> storage(4, 0.0f);
+  se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+  CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0);
+  builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+  auto call_frame = builder.Build();
+
+  {  // Single optional argument.
+    auto fn = [&](std::optional<AnyBuffer> arg0) {
+      EXPECT_TRUE(arg0.has_value());
+      return absl::OkStatus();
+    };
+
+    auto handler = Ffi::Bind().OptionalArg<AnyBuffer>().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+
+  {  // Two optional arguments.
+    auto fn = [&](std::optional<AnyBuffer> arg0,
+                  std::optional<AnyBuffer> arg1) {
+      EXPECT_TRUE(arg0.has_value());
+      EXPECT_FALSE(arg1.has_value());
+      return absl::OkStatus();
+    };
+
+    auto handler =
+        Ffi::Bind().OptionalArg<AnyBuffer>().OptionalArg<AnyBuffer>().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+
+  {  // Optional argument after a regular one.
+    auto fn = [&](AnyBuffer arg0, std::optional<AnyBuffer> arg1) {
+      EXPECT_FALSE(arg1.has_value());
+      return absl::OkStatus();
+    };
+
+    auto handler = Ffi::Bind().Arg<AnyBuffer>().OptionalArg<AnyBuffer>().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+
+  {  // Remaining arguments after optional one.
+    auto fn = [&](std::optional<AnyBuffer> arg0, RemainingArgs args) {
+      EXPECT_TRUE(arg0.has_value());
+      EXPECT_EQ(args.size(), 0);
+      return absl::OkStatus();
+    };
+
+    auto handler = Ffi::Bind().OptionalArg<AnyBuffer>().RemainingArgs().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+}
+
+TEST(FfiTest, OptionalRets) {
+  std::vector<float> storage(4, 0.0f);
+  se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+  CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1);
+  builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+  auto call_frame = builder.Build();
+
+  {  // Single optional result.
+    auto fn = [&](std::optional<Result<AnyBuffer>> ret0) {
+      EXPECT_TRUE(ret0.has_value());
+      return absl::OkStatus();
+    };
+
+    auto handler = Ffi::Bind().OptionalRet<AnyBuffer>().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+
+  {  // Two optional results.
+    auto fn = [&](std::optional<Result<AnyBuffer>> ret0,
+                  std::optional<Result<AnyBuffer>> ret1) {
+      EXPECT_TRUE(ret0.has_value());
+      EXPECT_FALSE(ret1.has_value());
+      return absl::OkStatus();
+    };
+
+    auto handler =
+        Ffi::Bind().OptionalRet<AnyBuffer>().OptionalRet<AnyBuffer>().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+
+  {  // Optional result after a regular one.
+    auto fn = [&](Result<AnyBuffer> ret0,
+                  std::optional<Result<AnyBuffer>> ret1) {
+      EXPECT_FALSE(ret1.has_value());
+      return absl::OkStatus();
+    };
+
+    auto handler = Ffi::Bind().Ret<AnyBuffer>().OptionalRet<AnyBuffer>().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+
+  {  // Remaining results after optional one.
+    auto fn = [&](std::optional<Result<AnyBuffer>> ret0, RemainingRets rets) {
+      EXPECT_TRUE(ret0.has_value());
+      EXPECT_EQ(rets.size(), 0);
+      return absl::OkStatus();
+    };
+
+    auto handler = Ffi::Bind().OptionalRet<AnyBuffer>().RemainingRets().To(fn);
+    auto status = Call(*handler, call_frame);
+
+    TF_ASSERT_OK(status);
+  }
+}
+
 TEST(FfiTest, RunOptionsCtx) {
   auto call_frame = CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build();
   auto* expected = reinterpret_cast<se::Stream*>(0x01234567);
@@ -674,7 +809,7 @@
   };
 
   CallOptions options;
-  options.stream = expected;
+  options.backend_options = CallOptions::GpuOptions{expected};
 
   auto handler = Ffi::Bind().Ctx<Stream>().To(fn);
   auto status = Call(*handler, call_frame, options);
diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc
index 7610060..b91e50f 100644
--- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc
+++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc
@@ -1823,6 +1823,7 @@
       auto generate_twiddles = [](int64_t length, bool inverse) {
         std::vector<ComplexType> twiddles;
         // Need only half the twiddles.
+        twiddles.reserve(length / 2);
         for (int64_t k = 0; k < length / 2; k++) {
           twiddles.push_back(Twiddle(k, length, inverse));
         }
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD
index 30a0283..acd6fff 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD
@@ -345,6 +345,7 @@
     ],
     deps = [
         ":auto_sharding",
+        ":auto_sharding_cost_graph",
         ":auto_sharding_option",
         ":auto_sharding_strategy",
         ":auto_sharding_util",
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index 7fa0c4e..558c6ff 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -1989,6 +1989,7 @@
     request.mutable_max_cost()->set_coeff(*max_cost);
   }
   for (const auto& [edge, edge_cost] : cost_graph.edge_costs_) {
+    const auto normalized_edge_cost = Normalize(edge_cost);
     AutoShardingSolverRequest_Pair raw_edge;
     raw_edge.set_first(edge.first);
     raw_edge.set_second(edge.second);
@@ -1997,8 +1998,8 @@
     AutoShardingSolverRequest_Costs mij;
     for (NodeStrategyIdx i = 0; i < edge_cost.n_; i++) {
       for (NodeStrategyIdx j = 0; j < edge_cost.m_; j++) {
-        rij.add_costs(edge_cost(i, j).communication_cost);
-        mij.add_costs(edge_cost(i, j).memory_cost);
+        rij.add_costs(normalized_edge_cost(i, j).communication_cost);
+        mij.add_costs(normalized_edge_cost(i, j).memory_cost);
       }
     }
     request.mutable_resharding_costs()->Add(std::move(rij));
@@ -3016,7 +3017,6 @@
 
   for (size_t i = 0; i < cur->operand_count(); ++i) {
     HloInstruction* operand = cur->mutable_operand(i);
-    operand = PassThroughCustomCallMarkerOperand(operand, cur);
 
     if (!visited.contains(operand) && !IsAlwaysReplicated(operand) &&
         GetShardingStrategy(operand, strategy_map, cost_graph, s_val)
@@ -3040,9 +3040,6 @@
 
   // Propagation ends at output.
   const HloInstruction* output = instructions.back();
-  if (IsCustomCallMarker(output)) {
-    output = output->operand(0);
-  }
 
   // A debug option: whether to do all-gather after backward pass.
   // This controls the location of all-gather.
@@ -3118,8 +3115,7 @@
       while (true) {
         path.push_back(root);
         if (root->opcode() == HloOpcode::kGetTupleElement) {
-          root = PassThroughCustomCallMarkerOperand(root->mutable_operand(0),
-                                                    root);
+          root = root->mutable_operand(0);
         } else {
           break;
         }
@@ -3215,14 +3211,6 @@
             insert_all_gather.push_back(alias_map.at(to_split));
           } else {
             insert_all_gather.push_back(to_split);
-
-            if (to_split->opcode() == HloOpcode::kGetTupleElement &&
-                IsCustomCallMarker(to_split->operand(0)) &&
-                to_split->users().size() == 1 &&
-                to_split->users().front() == output) {
-              insert_all_gather.push_back(PassThroughCustomCallMarkerOperand(
-                  to_split->mutable_operand(0), to_split));
-            }
           }
         }
       } else {
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
index 8512788..9d28df3 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
@@ -18,6 +18,7 @@
 #include <algorithm>
 #include <cstddef>
 #include <cstdlib>
+#include <limits>
 #include <numeric>
 #include <string>
 #include <utility>
@@ -34,6 +35,24 @@
 namespace xla {
 namespace spmd {
 
+EdgeReshardingCostMatrix Normalize(const EdgeReshardingCostMatrix& edge_cost) {
+  double min_communication_cost = std::numeric_limits<double>::max();
+  for (int i = 0; i < edge_cost.n_; ++i) {
+    for (int j = 0; j < edge_cost.m_; ++j) {
+      min_communication_cost =
+          std::min(min_communication_cost, edge_cost(i, j).communication_cost);
+    }
+  }
+  if (min_communication_cost >= 0) return edge_cost;
+  EdgeReshardingCostMatrix normalized_edge_cost = edge_cost;
+  for (int i = 0; i < edge_cost.n_; ++i) {
+    for (int j = 0; j < edge_cost.m_; ++j) {
+      normalized_edge_cost(i, j).communication_cost -= min_communication_cost;
+    }
+  }
+  return normalized_edge_cost;
+}
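As a worked illustration of the behavior above (values made up): if an edge's communication costs are {-3, 0, 2}, the minimum is -3, so 3 is added to every entry, yielding {0, 3, 5}; memory costs are left untouched, and a matrix whose minimum communication cost is already non-negative is returned unchanged.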
+
 CostGraph::CostGraph(const StrategyGroups& strategy_groups,
                      const AssociativeDotPairs& associative_dot_pairs) {
   node_lens_.reserve(strategy_groups.size());
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
index fda06ee..3d6bac1 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
@@ -55,6 +55,10 @@
 
 using EdgeReshardingCostMatrix = Matrix<EdgeReshardingCost>;
 
+// Normalizes the edge cost matrix by a fixed constant to ensure there are no
+// negative communication costs.
+EdgeReshardingCostMatrix Normalize(const EdgeReshardingCostMatrix& edge_cost);
+
 // A graph data structure to simplify the edge cost graph. It merges nodes and
 // performs path compression.
 class CostGraph {
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
index 969fdf3..dbd161b 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
@@ -39,7 +39,6 @@
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h"
-#include "xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h"
 #include "xla/hlo/experimental/auto_sharding/cluster_environment.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -57,14 +56,8 @@
 namespace spmd {
 namespace {
 
-using DimMap = StableMap</*tensor dim*/ int, /* mesh dim*/ int>;
-using MeshDims = absl::Span<const int64_t>;
-
-struct Enumeration {
-  MeshDims mesh_dims;
-  int64_t i;
-  int64_t j;
-};
+using MeshDimSet = StableSet<int>;
+using DimMap = StableMap</*tensor dim*/ int, /*mesh dims*/ MeshDimSet>;
 
 // Contains base functionality common to both DotHandler and ConvHandler.
 class HandlerBase {
@@ -88,7 +81,6 @@
         option_(option),
         call_graph_(call_graph),
         device_mesh_(cluster_env.device_mesh_),
-        device_mesh_1d_(cluster_env.device_mesh_1d_),
         lhs_(ins->operand(0)),
         rhs_(ins->operand(1)) {}
 
@@ -102,10 +94,12 @@
   HloSharding CreateInputSpec(const HloInstruction* ins, const DimMap& dim_map,
                               const Array<int64_t>& device_mesh) const {
     if (dim_map.empty()) return HloSharding::Replicate();
-    std::vector<int64_t> tensor_dims, mesh_dims;
-    for (const auto& [tensor_dim, mesh_dim] : dim_map) {
+    std::vector<int64_t> tensor_dims;
+    std::vector<std::vector<int64_t>> mesh_dims;
+    for (const auto& [tensor_dim, mesh_dim_set] : dim_map) {
       tensor_dims.push_back(tensor_dim);
-      mesh_dims.push_back(mesh_dim);
+      mesh_dims.push_back(
+          std::vector<int64_t>(mesh_dim_set.begin(), mesh_dim_set.end()));
     }
     return Tile(ins->shape(), tensor_dims, mesh_dims, device_mesh);
   }
@@ -116,7 +110,7 @@
       const std::string& name, const DimMap& lhs_dim_map,
       const DimMap& rhs_dim_map,
       const std::optional<DimMap>& expected_output_dim_map,
-      const Array<int64_t>& device_mesh, double compute_cost = 0,
+      double compute_cost = 0,
       const std::optional<std::function<double(const HloSharding&)>>&
           communication_cost_fn = std::nullopt);
 
@@ -126,7 +120,7 @@
       const std::string& name, const DimMap& lhs_dim_map,
       const DimMap& rhs_dim_map,
       const std::optional<DimMap>& expected_output_dim_map,
-      const Array<int64_t>& device_mesh, double compute_cost = 0,
+      double compute_cost = 0,
       const std::optional<std::function<double(const HloSharding&)>>&
           communication_cost_fn = std::nullopt);
 
@@ -137,7 +131,7 @@
   virtual void AppendAllGatherWindowedEinsumStrategyForOperand(
       int operand_num, const std::string& name, const DimMap& lhs_dim_map,
       const DimMap& rhs_dim_map, const DimMap& output_dim_map,
-      const Array<int64_t>& device_mesh, double compute_cost) {}
+      double compute_cost) {}
 
   // Given an existing (allreduce) sharding candidate, generate a corresponding
   // candidate by additionally sharding (if possible) the dot/conv output, such
@@ -146,7 +140,7 @@
   virtual void AppendReduceScatterWindowedEinsumStrategy(
       const std::string& name, const DimMap& lhs_dim_map,
       const DimMap& rhs_dim_map, const DimMap& output_dim_map,
-      const Array<int64_t>& device_mesh, double compute_cost) {}
+      double compute_cost) {}
 
   std::optional<HloSharding> GetShardingFromUser(const HloSharding& lhs_spec,
                                                  const HloSharding& rhs_spec);
@@ -155,32 +149,65 @@
   // where a subset of all tensor dims is mapped to a subset of mesh dims, such
   // that each tensor dim is mapped to at most one mesh dim, and no two tensor dims
   // are mapped to the same mesh dim.
-  // TODO(b/226977360): We might need to generalize this to also allow cases
-  // where a tensor dim can be mapped to multiple mesh dims.
-  void EnumerateGeneral(std::function<void(const DimMap&)> split_func,
-                        int tensor_rank, int current_tensor_dim,
-                        const absl::flat_hash_set<int>& unassigned_mesh_dims,
-                        const DimMap& current_dim_map) {
-    if (current_tensor_dim == tensor_rank) {
+  void Enumerate(std::function<void(const DimMap&)> split_func, int tensor_rank,
+                 int current_mesh_dim_idx,
+                 const std::vector<int>& unassigned_mesh_dims,
+                 const DimMap& current_dim_map) {
+    if (current_mesh_dim_idx == unassigned_mesh_dims.size()) {
       split_func(current_dim_map);
       return;
     }
-    // current_tensor_dim is unsharded
-    EnumerateGeneral(split_func, tensor_rank, current_tensor_dim + 1,
-                     unassigned_mesh_dims, current_dim_map);
-    // current_tensor_dim is sharded across one of the remaining mesh dims
-    for (int mesh_dim : unassigned_mesh_dims) {
+    // Current mesh dim is not assigned to any tensor dim
+    Enumerate(split_func, tensor_rank, current_mesh_dim_idx + 1,
+              unassigned_mesh_dims, current_dim_map);
+
+    for (int i = 0; i < tensor_rank; ++i) {
       DimMap updated_dim_map = current_dim_map;
-      updated_dim_map[current_tensor_dim] = mesh_dim;
-      absl::flat_hash_set<int> updated_unassigned_mesh_dims =
-          unassigned_mesh_dims;
-      updated_unassigned_mesh_dims.erase(
-          updated_unassigned_mesh_dims.find(mesh_dim));
-      EnumerateGeneral(split_func, tensor_rank, current_tensor_dim + 1,
-                       updated_unassigned_mesh_dims, updated_dim_map);
+      if (!updated_dim_map[i].empty() && !option_.allow_mixed_mesh_shape) {
+        continue;
+      }
+      updated_dim_map[i].insert(unassigned_mesh_dims[current_mesh_dim_idx]);
+      Enumerate(split_func, tensor_rank, current_mesh_dim_idx + 1,
+                unassigned_mesh_dims, updated_dim_map);
     }
   }
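As a concrete illustration of the recursion above: each unassigned mesh dim is either skipped or attached to exactly one tensor dim, so with two unassigned mesh dims (say 0 and 1) and a rank-2 tensor the callback sees 3 x 3 = 9 dim maps, ranging from the empty map to {tensor 0 -> {0, 1}} and {tensor 0 -> {0}, tensor 1 -> {1}}; when allow_mixed_mesh_shape is false, maps that put two mesh dims on the same tensor dim are skipped.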
 
+  bool IsMeshDimSetNonTrivial(const MeshDimSet& mesh_dim_set) {
+    return absl::c_any_of(mesh_dim_set, [&](int mesh_dim) {
+      return device_mesh_.dim(mesh_dim) > 1;
+    });
+  }
+
+  bool IsFullyReplicatedSharding(const DimMap& dim_map,
+                                 const Array<int64_t>& device_mesh) {
+    if (dim_map.empty()) {
+      return true;
+    }
+    for (const auto& [_, mesh_dim_set] : dim_map) {
+      if (IsMeshDimSetNonTrivial(mesh_dim_set)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  bool IsFullyReplicatedStrategy(const DimMap& output_dim_map,
+                                 const DimMap& lhs_dim_map,
+                                 const DimMap& rhs_dim_map,
+                                 const Array<int64_t>& device_mesh) {
+    return IsFullyReplicatedSharding(output_dim_map, device_mesh) &&
+           IsFullyReplicatedSharding(lhs_dim_map, device_mesh) &&
+           IsFullyReplicatedSharding(rhs_dim_map, device_mesh);
+  }
+
+  bool IsFullySharded(const DimMap& dim_map, int num_mesh_dims) {
+    int num_mesh_dims_used = 0;
+    for (const auto& [_, mesh_dims] : dim_map) {
+      num_mesh_dims_used += mesh_dims.size();
+    }
+    return num_mesh_dims_used >= num_mesh_dims;
+  }
+
   // Sorts strategies in the increasing order of their memory costs. Anecdotal
   // experience suggests that such a sorted list of strategies works better.
   void SortStrategies();
@@ -197,7 +224,6 @@
   const CallGraph& call_graph_;
 
   const Array<int64_t>& device_mesh_;
-  const Array<int64_t>& device_mesh_1d_;
   const HloInstruction* lhs_;
   const HloInstruction* rhs_;
 };
@@ -234,12 +260,13 @@
   void AppendAllGatherWindowedEinsumStrategyForOperand(
       int operand_num, const std::string& name, const DimMap& lhs_dim_map,
       const DimMap& rhs_dim_map, const DimMap& output_dim_map,
-      const Array<int64_t>& device_mesh, double compute_cost) override;
+      double compute_cost) override;
 
-  void AppendReduceScatterWindowedEinsumStrategy(
-      const std::string& name, const DimMap& lhs_dim_map,
-      const DimMap& rhs_dim_map, const DimMap& output_dim_map,
-      const Array<int64_t>& device_mesh, double compute_cost) override;
+  void AppendReduceScatterWindowedEinsumStrategy(const std::string& name,
+                                                 const DimMap& lhs_dim_map,
+                                                 const DimMap& rhs_dim_map,
+                                                 const DimMap& output_dim_map,
+                                                 double compute_cost) override;
 
   absl::Status RegisterStrategies();
 
@@ -324,29 +351,28 @@
   }));
 }
 
-// Given lhs and rhs dim maps, infers a sharding for the output by relying on
-// the sharding_propagation pass. Given that this is a relatively new change
-// (as of 11/2023), we also take an optional expected output dim map as an
-// argument, to verify that sharding propagation in fact infers the sharding
-// we expect (and to crash if it doesn't).
+// Given lhs and rhs dim maps, infers a sharding for the output by relying
+// on the sharding_propagation pass. Given that this is a relatively new
+// change (as of 11/2023), we also take an optional expected output dim map
+// as an argument, to verify that sharding propagation in fact infers the
+// sharding we expect (and to crash if it doesn't).
 // TODO(b/309638633) As we build more confidence in this, we should remove
 // this expected_output_dim_map argument and fully rely on sharding
 // propagation.
 void HandlerBase::MaybeAppendInternal(
     const std::string& name, const DimMap& lhs_dim_map,
     const DimMap& rhs_dim_map,
-    const std::optional<DimMap>& expected_output_dim_map,
-    const Array<int64_t>& device_mesh, double compute_cost,
+    const std::optional<DimMap>& expected_output_dim_map, double compute_cost,
     const std::optional<std::function<double(const HloSharding&)>>&
         communication_cost_fn) {
-  HloSharding lhs_spec = CreateInputSpec(lhs_, lhs_dim_map, device_mesh);
-  HloSharding rhs_spec = CreateInputSpec(rhs_, rhs_dim_map, device_mesh);
+  HloSharding lhs_spec = CreateInputSpec(lhs_, lhs_dim_map, device_mesh_);
+  HloSharding rhs_spec = CreateInputSpec(rhs_, rhs_dim_map, device_mesh_);
   std::optional<HloSharding> output_spec =
       GetShardingFromUser(lhs_spec, rhs_spec);
   if (output_spec.has_value()) {
     if (expected_output_dim_map.has_value()) {
       HloSharding expected_output_spec =
-          CreateInputSpec(ins_, *expected_output_dim_map, device_mesh);
+          CreateInputSpec(ins_, *expected_output_dim_map, device_mesh_);
       // TODO(b/308687597) Once the bug is resolved, we ideally either want to
       // have a CHECK statement verifying that the sharding inferred by
       // sharding propagation is in fact what we expect, or we trust sharding
@@ -366,7 +392,7 @@
     }
   } else {
     CHECK(expected_output_dim_map.has_value());
-    output_spec = CreateInputSpec(ins_, *expected_output_dim_map, device_mesh);
+    output_spec = CreateInputSpec(ins_, *expected_output_dim_map, device_mesh_);
     LOG(WARNING)
         << "Sharding propagation could not infer output sharding for:\n  "
         << ins_->ToString() << "\n  LHS Spec: " << lhs_spec
@@ -384,29 +410,27 @@
 void HandlerBase::MaybeAppend(
     const std::string& name, const DimMap& lhs_dim_map,
     const DimMap& rhs_dim_map,
-    const std::optional<DimMap>& expected_output_dim_map,
-    const Array<int64_t>& device_mesh, double compute_cost,
+    const std::optional<DimMap>& expected_output_dim_map, double compute_cost,
     const std::optional<std::function<double(const HloSharding&)>>&
         communication_cost_fn) {
   MaybeAppendInternal(name, lhs_dim_map, rhs_dim_map, expected_output_dim_map,
-                      device_mesh, compute_cost, communication_cost_fn);
+                      compute_cost, communication_cost_fn);
   if (!option_.generate_windowed_einsum_strategies ||
       !expected_output_dim_map.has_value()) {
     return;
   }
   if (absl::StrContains(name, "allreduce")) {
     CHECK(communication_cost_fn.has_value());
-    AppendReduceScatterWindowedEinsumStrategy(name, lhs_dim_map, rhs_dim_map,
-                                              *expected_output_dim_map,
-                                              device_mesh, compute_cost);
+    AppendReduceScatterWindowedEinsumStrategy(
+        name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map, compute_cost);
   } else {
     CHECK(!communication_cost_fn.has_value());
     AppendAllGatherWindowedEinsumStrategyForOperand(
         0, name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map,
-        device_mesh, compute_cost);
+        compute_cost);
     AppendAllGatherWindowedEinsumStrategyForOperand(
         1, name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map,
-        device_mesh, compute_cost);
+        compute_cost);
   }
 }
 
@@ -437,14 +461,15 @@
 }
 
 void HandlerBase::SortStrategies() {
-  absl::c_sort(strategy_group_->strategies,
-               [](const ShardingStrategy& s1, const ShardingStrategy& s2) {
-                 if (s1.memory_cost == s2.memory_cost) {
-                   return s1.name < s2.name;
-                 } else {
-                   return s1.memory_cost < s2.memory_cost;
-                 }
-               });
+  absl::c_stable_sort(
+      strategy_group_->strategies,
+      [](const ShardingStrategy& s1, const ShardingStrategy& s2) {
+        if (s1.memory_cost == s2.memory_cost) {
+          return s1.name < s2.name;
+        } else {
+          return s1.memory_cost < s2.memory_cost;
+        }
+      });
 }
 
 /************** DotHandler function definitions **************/
@@ -520,6 +545,15 @@
   }
 }
 
+std::string ToString(const MeshDimSet& set) { return absl::StrJoin(set, "-"); }
+std::string ToString(const DimMap& map) {
+  std::vector<std::string> strings;
+  for (const auto& [tdim, mdims] : map) {
+    strings.push_back(absl::StrCat("[", tdim, ": ", ToString(mdims), "]"));
+  }
+  return absl::StrJoin(strings, ", ");
+}
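For example (illustrative values), a DimMap mapping tensor dim 0 to mesh dims {0, 1} and tensor dim 2 to mesh dim {3} renders as "[0: 0-1], [2: 3]".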
+
 std::string DotHandler::GenerateNameForDotSharding(const DimMap& output_dim_map,
                                                    const DimMap& lhs_dim_map) {
   std::string name;
@@ -529,12 +563,12 @@
                                            absl::string_view identifier) {
     for (size_t i = 0; i < out_dims.size(); ++i) {
       int output_batch_dim = out_dims[i];
-      int mesh_dim = -1;
+      MeshDimSet mesh_dim_set;
       auto it = dim_map.find(output_batch_dim);
-      if (it != dim_map.end() && it->second >= 0) {
-        mesh_dim = it->second;
+      if (it != dim_map.end() && !it->second.empty()) {
+        mesh_dim_set = it->second;
       }
-      absl::StrAppend(&name, identifier, mesh_dim);
+      absl::StrAppend(&name, identifier, ToString(mesh_dim_set));
     }
   };
 
@@ -554,9 +588,9 @@
   bool contraction_dim_sharded = false;
   for (size_t i = 0; i < lhs_con_dims_.size(); ++i) {
     if (auto it = lhs_dim_map.find(lhs_con_dims_[i]);
-        it != lhs_dim_map.end() && it->second >= 0) {
+        it != lhs_dim_map.end() && !it->second.empty()) {
       contraction_dim_sharded =
-          contraction_dim_sharded || (device_mesh_.dim(it->second) > 1);
+          contraction_dim_sharded || IsMeshDimSetNonTrivial(it->second);
     }
   }
 
@@ -566,34 +600,17 @@
   return name;
 }
 
-bool IsFullyReplicatedSharding(const DimMap& dim_map,
-                               const Array<int64_t>& device_mesh) {
-  if (dim_map.empty()) {
-    return true;
-  }
-  for (const auto& [_, mesh_dim] : dim_map) {
-    if (device_mesh.dim(mesh_dim) > 1) {
-      return false;
-    }
-  }
-  return true;
-}
-
-bool IsFullyReplicatedStrategy(const DimMap& output_dim_map,
-                               const DimMap& lhs_dim_map,
-                               const DimMap& rhs_dim_map,
-                               const Array<int64_t>& device_mesh) {
-  return IsFullyReplicatedSharding(output_dim_map, device_mesh) &&
-         IsFullyReplicatedSharding(lhs_dim_map, device_mesh) &&
-         IsFullyReplicatedSharding(rhs_dim_map, device_mesh);
-}
-
-bool IsFullySharded(const DimMap& dim_map, int num_mesh_dims) {
-  return dim_map.size() >= num_mesh_dims;
-}
-
 void DotHandler::GenerateDotShardingStrategiesFromOutputSharding(
     const DimMap& output_dim_map) {
+  // This early return is added to ensure parity with the older strategy
+  // generation code. Removing it will only increase the search space.
+  for (const auto& [_, mesh_dims] : output_dim_map) {
+    if (mesh_dims.size() > 1 &&
+        mesh_dims.size() != device_mesh_.num_dimensions()) {
+      return;
+    }
+  }
+
   DimMap lhs_dim_map, rhs_dim_map;
   absl::flat_hash_set<int> used_mesh_dims;
 
@@ -603,11 +620,11 @@
     int lhs_batch_dim = lhs_batch_dims_[i];
     int rhs_batch_dim = rhs_batch_dims_[i];
     auto it = output_dim_map.find(output_batch_dim);
-    if (it != output_dim_map.end() && it->second >= 0) {
-      int mesh_dim = it->second;
-      used_mesh_dims.insert(mesh_dim);
-      lhs_dim_map[lhs_batch_dim] = mesh_dim;
-      rhs_dim_map[rhs_batch_dim] = mesh_dim;
+    if (it != output_dim_map.end() && !it->second.empty()) {
+      const StableSet<int>& mesh_dim_set = it->second;
+      used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
+      lhs_dim_map[lhs_batch_dim] = mesh_dim_set;
+      rhs_dim_map[rhs_batch_dim] = mesh_dim_set;
     }
   }
 
@@ -617,10 +634,10 @@
     int lhs_space_dim = lhs_space_dims_[i];
     int output_space_dim = out_lhs_space_dims_[i];
     auto it = output_dim_map.find(output_space_dim);
-    if (it != output_dim_map.end() && it->second >= 0) {
-      int mesh_dim = it->second;
-      used_mesh_dims.insert(mesh_dim);
-      lhs_dim_map[lhs_space_dim] = mesh_dim;
+    if (it != output_dim_map.end() && !it->second.empty()) {
+      const StableSet<int>& mesh_dim_set = it->second;
+      used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
+      lhs_dim_map[lhs_space_dim] = mesh_dim_set;
     }
   }
 
@@ -629,10 +646,10 @@
     int rhs_space_dim = rhs_space_dims_[i];
     int output_space_dim = out_rhs_space_dims_[i];
     auto it = output_dim_map.find(output_space_dim);
-    if (it != output_dim_map.end() && it->second >= 0) {
-      int mesh_dim = it->second;
-      used_mesh_dims.insert(mesh_dim);
-      rhs_dim_map[rhs_space_dim] = mesh_dim;
+    if (it != output_dim_map.end() && !it->second.empty()) {
+      const MeshDimSet& mesh_dim_set = it->second;
+      used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
+      rhs_dim_map[rhs_space_dim] = mesh_dim_set;
     }
   }
 
@@ -646,7 +663,7 @@
       // generation code. Removing it will only increase the search space.
       IsFullySharded(output_dim_map, device_mesh_.num_dimensions())) {
     MaybeAppend(GenerateNameForDotSharding(output_dim_map, lhs_dim_map),
-                lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_);
+                lhs_dim_map, rhs_dim_map, output_dim_map);
   }
 
   // Generate shardings for contraction dimensions
@@ -654,10 +671,10 @@
     return;
   }
 
-  absl::flat_hash_set<int> unused_mesh_dims;
+  std::vector<int> unused_mesh_dims;
   for (size_t i = 0; i < device_mesh_.num_dimensions(); ++i) {
     if (!used_mesh_dims.contains(i) && device_mesh_.dim(i) > 1) {
-      unused_mesh_dims.insert(i);
+      unused_mesh_dims.push_back(i);
     }
   }
 
@@ -675,11 +692,11 @@
 
     DimMap lhs_dim_map_with_contractions = lhs_dim_map;
     DimMap rhs_dim_map_with_contractions = rhs_dim_map;
-    for (const auto& [reducton_dim_index, mesh_dim] : reduction_dim_map) {
+    for (const auto& [reduction_dim_index, mesh_dim_set] : reduction_dim_map) {
       lhs_dim_map_with_contractions
-          [lhs_con_dims_[reduction_dims[reducton_dim_index]]] = mesh_dim;
+          [lhs_con_dims_[reduction_dims[reduction_dim_index]]] = mesh_dim_set;
       rhs_dim_map_with_contractions
-          [rhs_con_dims_[reduction_dims[reducton_dim_index]]] = mesh_dim;
+          [rhs_con_dims_[reduction_dims[reduction_dim_index]]] = mesh_dim_set;
     }
     // Skip fully the replicated strategy here as we add that outside of
     // HandleDot in auto_sharding_strategy.
@@ -696,8 +713,10 @@
       double memory_cost =
           ByteSizeOfShapeWithSharding(ins_->shape(), output_sharding);
       double total_cost = 0;
-      for (const auto& [_, mesh_dim] : reduction_dim_map) {
-        total_cost += cluster_env_.AllReduceCost(memory_cost, mesh_dim);
+      for (const auto& [_, mesh_dim_set] : reduction_dim_map) {
+        for (int mesh_dim : mesh_dim_set) {
+          total_cost += cluster_env_.AllReduceCost(memory_cost, mesh_dim);
+        }
       }
       return total_cost;
     };
@@ -705,24 +724,24 @@
     MaybeAppend(GenerateNameForDotSharding(output_dim_map,
                                            lhs_dim_map_with_contractions),
                 lhs_dim_map_with_contractions, rhs_dim_map_with_contractions,
-                output_dim_map, device_mesh_,
+                output_dim_map,
                 /*compute_cost=*/0, communication_cost_fn);
   };
 
-  EnumerateGeneral(split_func, reduction_dims.size(),
-                   /*current_tensor_dim=*/0, unused_mesh_dims,
-                   /*current_dim_map=*/{});
+  Enumerate(split_func, reduction_dims.size(),
+            /*current_mesh_dim_idx=*/0, unused_mesh_dims,
+            /*current_dim_map=*/{});
 }
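
When a contraction (reduction) dimension is mapped to a set of mesh dimensions, the communication cost above is the sum of one all-reduce per mesh dimension in that set. A rough standalone sketch of that accumulation; the linear cost model here is a placeholder assumption, not the ClusterEnvironment::AllReduceCost implementation:

    #include <cstdio>
    #include <set>
    #include <vector>

    // Placeholder cost model: all-reduce cost scales with bytes and grows with
    // the size of the mesh dimension it runs over.
    double AllReduceCost(double bytes, int mesh_dim,
                         const std::vector<int>& mesh_shape) {
      return bytes * (mesh_shape[mesh_dim] - 1) / mesh_shape[mesh_dim];
    }

    int main() {
      const std::vector<int> mesh_shape = {2, 4};
      const double memory_cost = 1024.0;          // bytes of the sharded output
      const std::set<int> mesh_dim_set = {0, 1};  // reduction dim uses both axes

      double total_cost = 0;
      for (int mesh_dim : mesh_dim_set) {
        total_cost += AllReduceCost(memory_cost, mesh_dim, mesh_shape);
      }
      std::printf("total all-reduce cost: %.1f\n", total_cost);  // 1280 (512 + 768)
      return 0;
    }
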
 
 void DotHandler::AppendAllGatherWindowedEinsumStrategyForOperand(
     int operand_num, const std::string& name, const DimMap& lhs_dim_map,
     const DimMap& rhs_dim_map, const DimMap& output_dim_map,
-    const Array<int64_t>& device_mesh, double compute_cost) {
+    double compute_cost) {
   const HloInstruction* operand = ins_->operand(operand_num);
   const DimMap& operand_dim_map = operand_num == 0 ? lhs_dim_map : rhs_dim_map;
   absl::flat_hash_set<int64_t> used_mesh_dims;
-  for (const auto [tensor_dim, mesh_dim] : operand_dim_map) {
-    used_mesh_dims.insert(mesh_dim);
+  for (const auto& [tensor_dim, mesh_dim_set] : operand_dim_map) {
+    used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
   }
   if (used_mesh_dims.size() == device_mesh_.num_dimensions() ||
       used_mesh_dims.size() == operand->shape().rank()) {
@@ -732,16 +751,16 @@
   for (int64_t tensor_dim = 0; tensor_dim < operand->shape().rank();
        ++tensor_dim) {
     if (auto it = operand_dim_map.find(tensor_dim);
-        it != operand_dim_map.end() && device_mesh.dim(it->second) > 1) {
+        it != operand_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) {
       continue;
     }
-    for (int64_t mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions();
+    for (int mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions();
          ++mesh_dim) {
       if (used_mesh_dims.contains(mesh_dim)) {
         continue;
       }
       DimMap further_sharded_dim_map = operand_dim_map;
-      further_sharded_dim_map[tensor_dim] = mesh_dim;
+      further_sharded_dim_map[tensor_dim] = MeshDimSet{mesh_dim};
 
       auto communication_cost_fn =
           [](const HloSharding& output_sharding) -> double {
@@ -756,7 +775,7 @@
           updated_name,
           operand_num == 0 ? further_sharded_dim_map : lhs_dim_map,
           operand_num == 1 ? further_sharded_dim_map : rhs_dim_map,
-          output_dim_map, device_mesh, compute_cost, communication_cost_fn);
+          output_dim_map, compute_cost, communication_cost_fn);
     }
   }
 }
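
AppendAllGatherWindowedEinsumStrategyForOperand looks for a tensor dimension of the operand that is still unsharded and tries each unused mesh dimension on it one at a time, which is why the further-sharded entry is a single-element MeshDimSet{mesh_dim}. A toy enumeration under the same stand-in types as the earlier sketch:

    #include <cstdio>
    #include <map>
    #include <set>

    using MeshDimSet = std::set<int>;
    using DimMap = std::map<int, MeshDimSet>;

    int main() {
      const int tensor_rank = 2;
      const int num_mesh_dims = 2;
      DimMap operand_dim_map = {{0, MeshDimSet{0}}};  // dim 0 already sharded

      std::set<int> used_mesh_dims;
      for (const auto& [tensor_dim, mesh_dim_set] : operand_dim_map) {
        used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
      }
      for (int tensor_dim = 0; tensor_dim < tensor_rank; ++tensor_dim) {
        if (operand_dim_map.count(tensor_dim)) continue;  // already sharded
        for (int mesh_dim = 0; mesh_dim < num_mesh_dims; ++mesh_dim) {
          if (used_mesh_dims.count(mesh_dim)) continue;
          // Candidate: additionally shard tensor_dim on this single mesh dim.
          std::printf("further shard t%d on m%d\n", tensor_dim, mesh_dim);
        }
      }
      return 0;  // prints: further shard t1 on m1
    }
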
@@ -764,10 +783,10 @@
 void DotHandler::AppendReduceScatterWindowedEinsumStrategy(
     const std::string& name, const DimMap& lhs_dim_map,
     const DimMap& rhs_dim_map, const DimMap& output_dim_map,
-    const Array<int64_t>& device_mesh, double compute_cost) {
+    double compute_cost) {
   absl::flat_hash_set<int64_t> used_mesh_dims;
-  for (const auto [tensor_dim, mesh_dim] : output_dim_map) {
-    used_mesh_dims.insert(mesh_dim);
+  for (const auto& [tensor_dim, mesh_dim_set] : output_dim_map) {
+    used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
   }
 
   if (used_mesh_dims.size() == device_mesh_.num_dimensions() ||
@@ -778,16 +797,16 @@
   for (int64_t tensor_dim = 0; tensor_dim < ins_->shape().rank();
        ++tensor_dim) {
     if (auto it = output_dim_map.find(tensor_dim);
-        it != output_dim_map.end() && device_mesh.dim(it->second) > 1) {
+        it != output_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) {
       continue;
     }
-    for (int64_t mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions();
+    for (int mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions();
          ++mesh_dim) {
       if (used_mesh_dims.contains(mesh_dim)) {
         continue;
       }
       DimMap further_sharded_dim_map = output_dim_map;
-      further_sharded_dim_map[tensor_dim] = mesh_dim;
+      further_sharded_dim_map[tensor_dim] = MeshDimSet{mesh_dim};
 
       auto communication_cost_fn =
           [](const HloSharding& output_sharding) -> double {
@@ -799,23 +818,21 @@
           name,
           absl::StrFormat("|rs_windowed_einsum_t%dm%d", tensor_dim, mesh_dim));
       MaybeAppendInternal(updated_name, lhs_dim_map, rhs_dim_map,
-                          further_sharded_dim_map, device_mesh, compute_cost,
+                          further_sharded_dim_map, compute_cost,
                           communication_cost_fn);
     }
   }
 }
 
 absl::Status DotHandler::RegisterStrategies() {
-  absl::flat_hash_set<int> all_mesh_dims;
-  for (int i = 0; i < device_mesh_.num_dimensions(); ++i) {
-    all_mesh_dims.insert(i);
-  }
-  EnumerateGeneral(
+  std::vector<int> all_mesh_dims(device_mesh_.num_dimensions());
+  std::iota(all_mesh_dims.begin(), all_mesh_dims.end(), 0);
+  Enumerate(
       /*split_func=*/
       [&](const DimMap& output_dim_map) {
         GenerateDotShardingStrategiesFromOutputSharding(output_dim_map);
       },
-      ins_->shape().rank(), /*current_tensor_dim=*/0, all_mesh_dims,
+      ins_->shape().rank(), /*current_mesh_dim_idx=*/0, all_mesh_dims,
       /*current_dim_map=*/{});
   SortStrategies();
   return absl::OkStatus();
@@ -853,27 +870,27 @@
 
   // Propagate batch dim sharding
   auto it = output_dim_map.find(out_batch_dim_);
-  if (it != output_dim_map.end() && device_mesh_.dim(it->second) > 1) {
-    int mesh_dim = it->second;
-    lhs_dim_map[lhs_batch_dim_] = mesh_dim;
-    used_mesh_dims.insert(mesh_dim);
-    absl::StrAppend(&name, "b", mesh_dim);
+  if (it != output_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) {
+    const MeshDimSet& mesh_dim_set = it->second;
+    lhs_dim_map[lhs_batch_dim_] = mesh_dim_set;
+    used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
+    absl::StrAppend(&name, "b", ToString(mesh_dim_set));
   } else {
     absl::StrAppend(&name, "b-1");
   }
 
   // Propagate out channel dim sharding
   it = output_dim_map.find(out_out_channel_dim_);
-  if (it != output_dim_map.end() && device_mesh_.dim(it->second) > 1) {
-    int mesh_dim = it->second;
-    lhs_dim_map[rhs_out_channel_dim_] = mesh_dim;
-    used_mesh_dims.insert(mesh_dim);
-    absl::StrAppend(&name, "oc", mesh_dim);
+  if (it != output_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) {
+    const MeshDimSet& mesh_dim_set = it->second;
+    lhs_dim_map[rhs_out_channel_dim_] = mesh_dim_set;
+    used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end());
+    absl::StrAppend(&name, "oc", ToString(mesh_dim_set));
   } else {
     absl::StrAppend(&name, "oc-1");
   }
 
-  MaybeAppend(name, lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_);
+  MaybeAppend(name, lhs_dim_map, rhs_dim_map, output_dim_map);
 
   // Generate shardings for contraction dimensions
   if (used_mesh_dims.size() == device_mesh_.num_dimensions()) {
@@ -891,12 +908,12 @@
     return;
   }
 
-  for (int64_t mesh_dim : unused_mesh_dims) {
+  for (int mesh_dim : unused_mesh_dims) {
     DimMap lhs_dim_map_with_contractions = lhs_dim_map;
     DimMap rhs_dim_map_with_contractions = rhs_dim_map;
 
-    lhs_dim_map_with_contractions[lhs_in_channel_dim_] = mesh_dim;
-    rhs_dim_map_with_contractions[rhs_in_channel_dim_] = mesh_dim;
+    lhs_dim_map_with_contractions[lhs_in_channel_dim_] = MeshDimSet{mesh_dim};
+    rhs_dim_map_with_contractions[rhs_in_channel_dim_] = MeshDimSet{mesh_dim};
     absl::StrAppend(&name, "ic", mesh_dim, "@allreduce");
 
     auto communication_cost_fn = [&](const HloSharding& output_sharding) {
@@ -906,7 +923,7 @@
     };
 
     MaybeAppend(name, lhs_dim_map_with_contractions,
-                rhs_dim_map_with_contractions, output_dim_map, device_mesh_,
+                rhs_dim_map_with_contractions, output_dim_map,
                 /*compute_cost=*/0, communication_cost_fn);
   }
 }
@@ -931,15 +948,13 @@
     SplitDepthwise(false);
   }
 
-  absl::flat_hash_set<int> all_mesh_dims;
-  for (int i = 0; i < device_mesh_.num_dimensions(); ++i) {
-    all_mesh_dims.insert(i);
-  }
-  EnumerateGeneral(
+  std::vector<int> all_mesh_dims(device_mesh_.num_dimensions());
+  std::iota(all_mesh_dims.begin(), all_mesh_dims.end(), 0);
+  Enumerate(
       [&](const DimMap& output_dim_map) {
         GenerateConvolutionShardingStrategiesFromOutputSharding(output_dim_map);
       },
-      2, /*current_tensor_dim=*/0, all_mesh_dims,
+      2, /*current_mesh_dim_idx=*/0, all_mesh_dims,
       /*current_dim_map=*/{});
 
   // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies
@@ -957,38 +972,37 @@
 void ConvHandler::SplitDepthwise(bool forward) {
   std::function<void(const DimMap&)> split_func =
       [&](const DimMap& output_dim_map) {
-        int out_batch_mesh_dim = -1;
-        int out_out_channel_mesh_dim = -1;
+        MeshDimSet out_batch_mesh_dim_set;
+        MeshDimSet out_out_channel_mesh_dim_set;
         if (auto it = output_dim_map.find(out_batch_dim_);
             it != output_dim_map.end()) {
-          out_batch_mesh_dim = it->second;
+          out_batch_mesh_dim_set = it->second;
         }
         if (auto it = output_dim_map.find(out_out_channel_dim_);
             it != output_dim_map.end()) {
-          out_out_channel_mesh_dim = it->second;
+          out_out_channel_mesh_dim_set = it->second;
         }
-        if (out_batch_mesh_dim == -1 || out_out_channel_mesh_dim == -1) {
+        if (out_batch_mesh_dim_set.empty() ||
+            out_out_channel_mesh_dim_set.empty()) {
           return;
         }
 
         DimMap lhs_dim_map, rhs_dim_map;
         lhs_dim_map[lhs_batch_dim_] =
-            forward ? out_batch_mesh_dim : out_out_channel_mesh_dim;
+            forward ? out_batch_mesh_dim_set : out_out_channel_mesh_dim_set;
         lhs_dim_map[lhs_in_channel_dim_] =
-            forward ? out_out_channel_mesh_dim : out_batch_mesh_dim;
+            forward ? out_out_channel_mesh_dim_set : out_batch_mesh_dim_set;
 
-        rhs_dim_map[rhs_out_channel_dim_] = out_out_channel_mesh_dim;
+        rhs_dim_map[rhs_out_channel_dim_] = out_out_channel_mesh_dim_set;
 
-        MaybeAppend(absl::StrCat("b", out_batch_mesh_dim, "oc",
-                                 out_out_channel_mesh_dim, "|depthwise"),
-                    lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_);
+        MaybeAppend(
+            absl::StrCat("b", ToString(out_batch_mesh_dim_set), "oc",
+                         ToString(out_out_channel_mesh_dim_set), "|depthwise"),
+            lhs_dim_map, rhs_dim_map, output_dim_map);
       };
-  absl::flat_hash_set<int> all_mesh_dims;
-  for (int i = 0; i < device_mesh_.num_dimensions(); ++i) {
-    all_mesh_dims.insert(i);
-  }
-  EnumerateGeneral(split_func, 2, /*current_tensor_dim=*/0, all_mesh_dims,
-                   /*current_dim_map=*/{});
+  std::vector<int> all_mesh_dims(device_mesh_.num_dimensions());
+  std::iota(all_mesh_dims.begin(), all_mesh_dims.end(), 0);
+  Enumerate(split_func, 2, /*current_mesh_dim_idx=*/0, all_mesh_dims,
+            /*current_dim_map=*/{});
 }
 
 }  // namespace
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc
index d2f124b..56ca325 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc
@@ -82,8 +82,6 @@
       absl::StrCat("allow_recompute_heavy_op: ", allow_recompute_heavy_op));
   lines.push_back(
       absl::StrCat("allow_mixed_mesh_shape: ", allow_mixed_mesh_shape));
-  lines.push_back(
-      absl::StrCat("grad_acc_num_micro_batches: ", grad_acc_num_micro_batches));
   lines.push_back(absl::StrCat("solve_nd_sharding_iteratively: ",
                                solve_nd_sharding_iteratively));
   lines.push_back(
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h
index 51eceae..5e73af7 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h
@@ -119,11 +119,6 @@
   // If true, allow adding 1d strategies in 2d logical mesh.
   bool allow_mixed_mesh_shape = true;
 
-  // The number of micro batches if gradient accumulation is used.
-  // If this is not 1, the cost of all-reduce for gradient synchronization
-  // is divided by this number.
-  int grad_acc_num_micro_batches = 1;
-
   // If true, N-D sharding (e.g., N maybe be 2 or 3) will be solved in N
   // iterations, where one iteration chooses one tensor dimension to shard. If
   // false, solve N-D sharding directly, i.e., generating all possible sharding
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc
index f204ff4..67a6fed 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc
@@ -67,6 +67,10 @@
 // solver cannot guarantee exact numerical precision.
 constexpr double kMaxCostEpsilon = 1.0001;
 
+// Memory contributions in the Mixed ILP are converted to units in this range;
+// beware that significantly larger / smaller values can cause numerical issues.
+constexpr double kMemoryMultiplier = 1e6;
+
 bool AutoShardingSolverOutput::operator==(
     const AutoShardingSolverOutput& other) const {
   return s_val == other.s_val && cost == other.cost &&
@@ -261,7 +265,7 @@
       reduced_groups.push_back({group.prims().begin(), group.prims().end()});
     }
   }
-  solver.MakeIntVarArray(reduced_groups.size(), 0.0, MPSolver::infinity(),
+  solver.MakeNumVarArray(reduced_groups.size(), 0.0, MPSolver::infinity(),
                          absl::StrCat("group_", prim_type), &group_vars);
   for (int64_t group_idx = 0; group_idx < group_vars.size(); ++group_idx) {
     MPConstraint* constraint = solver.MakeRowConstraint(
@@ -271,7 +275,7 @@
     for (const int64_t prim_idx : reduced_groups[group_idx]) {
       for (int64_t j = 0; j < prim_vars[prim_idx].size(); ++j) {
         double memory_cost = memory_costs.at(prim_idx).costs(j);
-        memory_cost /= request.memory_budget() / 100.0;
+        memory_cost /= request.memory_budget() / kMemoryMultiplier;
         const double accumulated_coefficient =
             constraint->GetCoefficient(prim_vars[prim_idx][j]);
         constraint->SetCoefficient(prim_vars[prim_idx][j],
@@ -302,9 +306,12 @@
          time_idx <= intervals[prim_idx].second; ++time_idx) {
       if (!reduced_times.contains(time_idx)) continue;
       if (!constraints.contains(time_idx)) {
-        MPConstraint* constraint = solver.MakeRowConstraint(
-            -MPSolver::infinity(), 100.0, absl::StrCat("mem[", time_idx, "]"));
-        if (overbudget_var) constraint->SetCoefficient(overbudget_var, -100.0);
+        MPConstraint* constraint =
+            solver.MakeRowConstraint(-MPSolver::infinity(), kMemoryMultiplier,
+                                     absl::StrCat("mem[", time_idx, "]"));
+        if (overbudget_var) {
+          constraint->SetCoefficient(overbudget_var, -kMemoryMultiplier);
+        }
         constraints[time_idx] = constraint;
       }
       MPConstraint* constraint = constraints[time_idx];
@@ -314,7 +321,7 @@
       }
       for (int64_t j = 0; j < prim_vars[prim_idx].size(); ++j) {
         double memory_cost = memory_costs.at(prim_idx).costs(j);
-        memory_cost /= request.memory_budget() / 100.0;
+        memory_cost /= request.memory_budget() / kMemoryMultiplier;
         const double accumulated_coefficient =
             constraint->GetCoefficient(prim_vars[prim_idx][j]);
         constraint->SetCoefficient(prim_vars[prim_idx][j],
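
The kMemoryMultiplier change only rescales the memory constraint rows: instead of expressing each term as a percentage of the budget (bounded by 100), terms are expressed in units of 1e6. A minimal sketch of the arithmetic (standalone illustration, not the OR-Tools code path); the interpretation of why this helps tiny costs is inferred from the new HandlesGroupsWithTinyMemoryCosts test below:

    #include <cstdio>

    constexpr double kMemoryMultiplier = 1e6;

    // memory_cost /= budget / kMemoryMultiplier  is  memory_cost * kMemoryMultiplier / budget.
    double ToScaledUnits(double memory_cost, double memory_budget) {
      return memory_cost / (memory_budget / kMemoryMultiplier);
    }

    int main() {
      const double budget = 4321.0;  // the budget used in the new solver test
      // A tiny 1-byte memory cost becomes ~231 units against a per-time-step
      // bound of 1e6, rather than ~0.023 against a bound of 100, keeping small
      // coefficients comfortably above solver rounding thresholds.
      std::printf("%.3f -> %.3f\n", ToScaledUnits(1.0, budget) * 1e-4,
                  ToScaledUnits(1.0, budget));
      return 0;
    }
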
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc
index 5a237b2..05631bc 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc
@@ -441,6 +441,46 @@
   EXPECT_EQ(result, expected_result);
 }
 
+TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) {
+  AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
+  const std::vector<std::pair<int64_t, int64_t>> node_intervals =
+      {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}};
+  const std::vector<std::pair<int64_t, int64_t>> edge_intervals =
+      {{1, 2}, {2, 3}};
+  const std::vector<std::vector<int64_t>> node_groups = {{0, 1}};
+  const std::vector<std::vector<int64_t>> edge_groups = {};
+  const CostMatrix memory_costs = {{1, 1, 1, 1},  // These values are tiny and
+                                   {2, 2, 2},     // shouldn't be rounded up.
+                                   {300, 300, 300, 300, 300, 300, 300},
+                                   {4000, 4000, 4000, 4000, 4000, 4000, 4000},
+                                   {50000, 50000, 50000}};
+  const CostMatrix memory_edge_costs = {{0, 0, 0, 0,
+                                         0, 0, 0, 0,
+                                         0, 0, 0, 0,
+                                         0, 0, 0, 0},
+                                        {0, 0, 0, 0,
+                                         0, 0, 0, 0,
+                                         0, 0, 0, 0}};
+  request.clear_live();
+  request.clear_memory_costs();
+  AddIntervals(request.mutable_node_intervals(), node_intervals);
+  AddIntervals(request.mutable_edge_intervals(), edge_intervals);
+  AddGroups(request.mutable_node_groups(), node_groups);
+  AddGroups(request.mutable_edge_groups(), edge_groups);
+  AddCosts(request.mutable_memory_costs(), memory_costs);
+  AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs);
+  request.set_enable_memory_edge_costs(true);
+  request.set_memory_budget(4321);
+
+  const AutoShardingSolverResult result = CallORToolsSolver(request);
+
+  const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
+  const double objective_value = 7650.0;
+  const AutoShardingSolverOutput expected_output = {s_val, objective_value};
+  const AutoShardingSolverResult expected_result = {expected_output, false};
+  EXPECT_EQ(result, expected_result);
+}
+
 TEST(CallORToolsSolverTest, SolvesWithEquivalences) {
   const AutoShardingSolverRequest request =
       AutoShardingSolverRequestWithEquivalences();
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
index dd6d38c..27ddc79 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
@@ -824,15 +824,7 @@
               }
             };
 
-        if (IsCustomCallMarker(ins)) {
-          const HloInstruction* operand = ins->operand(0);
-          const StrategyGroup* src_strategy_group =
-              strategy_map.at(operand).get();
-          CHECK(src_strategy_group->is_tuple);
-          strategy_group = MaybeFollowInsStrategyGroup(
-              src_strategy_group, ins->shape(), instruction_id, strategy_groups,
-              cluster_env, pretrimmed_strategy_map);
-        } else if (IsSPMDFullToShardShapeCustomCall(ins)) {
+        if (IsSPMDFullToShardShapeCustomCall(ins)) {
           return absl::InternalError(
               "An SPMDFullToShardShape call found outside a manually "
               "partitioned sub-graph.");
@@ -1030,24 +1022,6 @@
     strategy_map[ins] = std::move(strategy_group);
   }  // end of for loop
 
-  // If gradient accumulation is used, adjust the cost of all-reduce for
-  // gradient synchronization.
-  if (option.grad_acc_num_micro_batches > 1) {
-    // find gradient-computation instructions
-    std::vector<const HloInstruction*> grad_insts =
-        GetGradientComputationInstructions(instructions);
-    for (const HloInstruction* inst : grad_insts) {
-      StrategyGroup* stra_vector = strategy_map[inst].get();
-      CHECK(!stra_vector->is_tuple);
-
-      for (auto& stra : stra_vector->strategies) {
-        if (absl::StrContains(stra.name, "allreduce")) {
-          stra.communication_cost /= option.grad_acc_num_micro_batches;
-        }
-      }
-    }
-  }
-
   return std::make_tuple(std::move(strategy_map), std::move(strategy_groups),
                          std::move(associative_dot_pairs));
 }
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
index c899a0c..8ab3824 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
@@ -29,6 +29,7 @@
 #include "absl/log/log.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
+#include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h"
@@ -554,7 +555,7 @@
   TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
   VLOG(10) << module->ToString();
   EXPECT_TRUE(changed);
-  auto* instruction = FindInstruction(module.get(), "p0");
+  const HloInstruction* instruction = FindInstruction(module.get(), "p0");
   ASSERT_NE(instruction, nullptr);
   EXPECT_THAT(instruction, op::Sharding("{replicated}"));
 }
@@ -687,14 +688,49 @@
   TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
   VLOG(10) << module->ToString();
   EXPECT_TRUE(changed);
-  auto* param0 = FindInstruction(module.get(), "param.0");
-  auto* param1 = FindInstruction(module.get(), "param.1");
+  const HloInstruction* param0 = FindInstruction(module.get(), "param.0");
+  const HloInstruction* param1 = FindInstruction(module.get(), "param.1");
   ASSERT_NE(param0, nullptr);
   ASSERT_NE(param1, nullptr);
   EXPECT_THAT(param0, op::Sharding("{replicated}"));
   EXPECT_THAT(param1, op::Sharding("{replicated}"));
 }
 
+TEST_F(AutoShardingTest, DotMixedMeshStrategies) {
+  constexpr absl::string_view kHloString = R"(
+HloModule module
+ENTRY %entry {
+  %param0 = f32[8192,23]{1,0} parameter(0), sharding={devices=[4,1]0,1,2,3}
+  %param1 = f32[23,23]{1,0} parameter(1)
+  %dot = f32[8192,23]{1,0} dot(%param0, %param1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+  ROOT %copy = f32[8192,23]{1,0} copy(%dot)
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+  AutoShardingOption option;
+  option.enable = true;
+  option.device_mesh_shape = {2, 2};
+  option.device_mesh_ids = {0, 1, 2, 3};
+  option.device_mesh_alpha = {1.0, 1.0};
+  option.device_mesh_beta = {0.01, 1.0};
+  option.solve_nd_sharding_iteratively = false;
+  option.preserve_shardings =
+      AutoShardingOption::PreserveShardingsType::kKeepAllShardings;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
+  VLOG(2) << module->ToString();
+  EXPECT_TRUE(changed);
+  const HloInstruction* param0 = FindInstruction(module.get(), "param0");
+  const HloInstruction* param1 = FindInstruction(module.get(), "param1");
+  const HloInstruction* dot = FindInstruction(module.get(), "dot");
+  ASSERT_NE(param0, nullptr);
+  ASSERT_NE(param1, nullptr);
+  ASSERT_NE(dot, nullptr);
+  EXPECT_THAT(param0, op::Sharding("{devices=[4,1]0,1,2,3}"));
+  EXPECT_THAT(param1, op::Sharding("{replicated}"));
+  EXPECT_THAT(dot, op::Sharding("{devices=[4,1]0,1,2,3}"));
+}
+
 TEST_F(AutoShardingTest, DotLHSTwoNonContractingDims) {
   constexpr absl::string_view kHloString = R"(
 HloModule module
@@ -717,9 +753,9 @@
   TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
   VLOG(2) << module->ToString();
   EXPECT_TRUE(changed);
-  auto* param0 = FindInstruction(module.get(), "param0");
-  auto* param1 = FindInstruction(module.get(), "param1");
-  auto* dot = FindInstruction(module.get(), "dot");
+  const HloInstruction* param0 = FindInstruction(module.get(), "param0");
+  const HloInstruction* param1 = FindInstruction(module.get(), "param1");
+  const HloInstruction* dot = FindInstruction(module.get(), "dot");
   ASSERT_NE(param0, nullptr);
   ASSERT_NE(param1, nullptr);
   ASSERT_NE(dot, nullptr);
@@ -770,9 +806,9 @@
   TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
   VLOG(2) << module->ToString();
   EXPECT_TRUE(changed);
-  auto* param0 = FindInstruction(module.get(), "param0");
-  auto* param1 = FindInstruction(module.get(), "param1");
-  auto* dot = FindInstruction(module.get(), "dot");
+  const HloInstruction* param0 = FindInstruction(module.get(), "param0");
+  const HloInstruction* param1 = FindInstruction(module.get(), "param1");
+  const HloInstruction* dot = FindInstruction(module.get(), "dot");
   ASSERT_NE(param0, nullptr);
   ASSERT_NE(param1, nullptr);
   ASSERT_NE(dot, nullptr);
@@ -2482,6 +2518,36 @@
             input_output_alias_config_after.ToString());
 }
 
+TEST(NormalizeTest, NormalizeHandlesNegativeCosts) {
+  EdgeReshardingCostMatrix edge_cost(2, 2);
+  edge_cost(0, 0).communication_cost = -100;
+  edge_cost(0, 1).communication_cost = 200;
+  edge_cost(1, 0).communication_cost = 300;
+  edge_cost(1, 1).communication_cost = 400;
+
+  const EdgeReshardingCostMatrix normalized_edge_cost = Normalize(edge_cost);
+
+  EXPECT_EQ(normalized_edge_cost(0, 0).communication_cost, 0);
+  EXPECT_EQ(normalized_edge_cost(0, 1).communication_cost, 300);
+  EXPECT_EQ(normalized_edge_cost(1, 0).communication_cost, 400);
+  EXPECT_EQ(normalized_edge_cost(1, 1).communication_cost, 500);
+}
+
+TEST(NormalizeTest, NormalizeHandlesPositiveCosts) {
+  EdgeReshardingCostMatrix edge_cost(2, 2);
+  edge_cost(0, 0).communication_cost = 100;
+  edge_cost(0, 1).communication_cost = 200;
+  edge_cost(1, 0).communication_cost = 300;
+  edge_cost(1, 1).communication_cost = 400;
+
+  const EdgeReshardingCostMatrix normalized_edge_cost = Normalize(edge_cost);
+
+  EXPECT_EQ(normalized_edge_cost(0, 0).communication_cost, 100);
+  EXPECT_EQ(normalized_edge_cost(0, 1).communication_cost, 200);
+  EXPECT_EQ(normalized_edge_cost(1, 0).communication_cost, 300);
+  EXPECT_EQ(normalized_edge_cost(1, 1).communication_cost, 400);
+}
+
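
Judging only from the expectations in these two tests, Normalize appears to shift every communication cost up by the magnitude of the most negative entry, and to leave matrices with no negative entries unchanged. A guess at that behaviour as a standalone sketch (not the actual EdgeReshardingCostMatrix code):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Hypothetical reconstruction of the normalization the tests exercise.
    std::vector<double> Normalize(std::vector<double> costs) {
      const double min_cost = *std::min_element(costs.begin(), costs.end());
      if (min_cost >= 0) return costs;  // non-negative costs pass through unchanged
      for (double& c : costs) c -= min_cost;  // shift so the minimum becomes 0
      return costs;
    }

    int main() {
      for (double c : Normalize({-100, 200, 300, 400})) std::printf("%g ", c);
      std::printf("\n");  // prints: 0 300 400 500
      return 0;
    }
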
 }  // namespace
 }  // namespace spmd
 }  // namespace xla
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
index cb8ec2c..d49f16b 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
@@ -59,16 +59,10 @@
 #include "xla/shape_tree.h"
 #include "xla/shape_util.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
 
 namespace xla {
 namespace spmd {
 
-inline const HloInstruction* PassThroughCustomCallMarkerGetSource(
-    const HloInstruction* ins);
-inline HloInstruction* PassThroughCustomCallMarkerUser(
-    HloInstruction* raw_user, const HloInstruction* inst);
-
 std::optional<HloSharding> GetInputSharding(const HloInstruction* ins,
                                             int64_t op_index,
                                             const HloSharding& output_sharding,
@@ -109,26 +103,6 @@
   return inferred_sharding;
 }
 
-// Return whether the instruction is an activation from another pipeline stage.
-bool IsActivationFromAnotherStage(const HloInstruction* ins,
-                                  const InstructionBatchDimMap& batch_dim_map) {
-  if (!(ins->opcode() == HloOpcode::kParameter &&
-        batch_dim_map.contains(GetBatchDimMapKey(ins)))) {
-    return false;
-  }
-
-  for (const HloInstruction* user : ins->users()) {
-    if (!(user->opcode() == HloOpcode::kTuple && user->users().size() == 1 &&
-          user->users().front()->IsCustomCall(kPipelineMarker) &&
-          absl::StrContains(user->users().front()->metadata().op_type(),
-                            "start"))) {
-      return false;
-    }
-  }
-
-  return true;
-}
-
 // Propagate sharding for dim-wise operations (e.g., slice, pad) which works
 // independently on each dimension.
 // The sharding can successfully propagate if the operation only happens
@@ -194,11 +168,6 @@
     if (degree_dict[inst] == 0) {
       depth_map[inst] = 0;
 
-      // Add some initial depth for activations from other pipeline stages.
-      if (IsActivationFromAnotherStage(inst, batch_dim_map)) {
-        depth_map[inst] = 20;
-      }
-
       current_frontier.push_back(inst);
       collected++;
     }
@@ -246,10 +215,6 @@
 
           if (reset) {
             depth_map[node] = 0;
-          } else if (node->opcode() == HloOpcode::kGetTupleElement &&
-                     IsCustomCallMarker(node->operand(0))) {
-            depth_map[node] =
-                depth_map.at(PassThroughCustomCallMarkerGetSource(node));
           } else {
             int64_t max_depth = depth_map.at(inst) + delta;
             for (const HloInstruction* operand : node->operands()) {
@@ -813,12 +778,6 @@
         batch_map[ins->name()] = batch_dim_of_source;
       }
     }
-
-    if (ins->IsCustomCall(kPipelineMarker) &&
-        absl::StrContains(ins->metadata().op_type(), "start")) {
-      // Reset the status after meet a new pipeline marker.
-      set_the_next_dot_conv = true;
-    }
   }
   int64_t previous_cnt = 0;
   while (true) {
@@ -968,8 +927,7 @@
   for (HloInstruction* node : boundary_set) {
     HloInstruction* cur = node;
     while (cur->operand_count() == 1) {
-      HloInstruction* operand =
-          PassThroughCustomCallMarkerOperand(cur->mutable_operand(0), cur);
+      HloInstruction* operand = cur->mutable_operand(0);
       if (replicated_set.contains(operand)) {
         path.insert(cur);
       }
@@ -1007,8 +965,7 @@
 
   // Find the add instruction for grad accumulation, skip the identity marker
   // for remat and other elementwise ops.
-  HloInstruction* add =
-      PassThroughCustomCallMarkerUser(inst->users().front(), inst);
+  HloInstruction* add = inst->users().front();
   if (add->opcode() == HloOpcode::kGetTupleElement ||
       add->opcode() == HloOpcode::kTranspose) {
     if (add->users().size() != 1) {
@@ -1025,7 +982,7 @@
     }
     CHECK_EQ(add->users().size(), 1);
     // Skip the end marker of backward computation
-    add = PassThroughCustomCallMarkerUser(add->users().front(), add);
+    add = add->users().front();
 
     // Do not partition the dot, add and parameter, so we can generate
     // all-reduce for grad accumulation.
@@ -1037,7 +994,7 @@
 
       replicated_set.erase(cur);
       for (auto x : cur->operands()) {
-        dfs_remove(PassThroughCustomCallMarkerOperand(x, cur));
+        dfs_remove(x);
       }
     };
 
@@ -1138,39 +1095,23 @@
   return sharded_dims <= 0;
 }
 
-absl::StatusOr<std::vector<int64_t>> GetTensorDimToMeshDimNoCrash(
-    int64_t tensor_shape_rank, const HloSharding& spec,
-    const Array<int64_t>& device_mesh, bool consider_reverse_device_meshes) {
-  if (spec.IsReplicated()) {
-    return std::vector<int64_t>(tensor_shape_rank, -1);
-  }
-  // Check the compatibility of tensor_shape_rank and spec
-  if (tensor_shape_rank != spec.TiledDataRank()) {
-    return absl::InvalidArgumentError(
-        "Tensor shape rank should be equal to the tiled data rank of the input "
-        "spec.");
-  }
-
+absl::StatusOr<std::vector<int64_t>> GetMeshDimPermutationOrderInShardingSpec(
+    const HloSharding& spec, const Array<int64_t>& device_mesh,
+    bool consider_reverse_device_meshes) {
   auto check_mesh =
       [&](const Array<int64_t>& mesh) -> std::optional<std::vector<int64_t>> {
     // Permute the dimensions (or axes, in NumPy terms) to find the transform
     // that makes tile_assignment == device_mesh.
     std::vector<int64_t> axes(mesh.num_dimensions());
     absl::c_iota(axes, 0);
-    bool found = false;
     do {
       Array<int64_t> transposed_mesh = Transpose(mesh, axes);
       if (std::equal(transposed_mesh.begin(), transposed_mesh.end(),
                      spec.tile_assignment().array().begin())) {
-        found = true;
-        break;
+        return axes;
       }
     } while (absl::c_next_permutation(axes));
-    if (found) {
-      return std::optional<std::vector<int64_t>>(axes);
-    } else {
-      return std::nullopt;
-    }
+    return std::nullopt;
   };
 
   // This is an expensive search, as we try all possible meshes obtained by
@@ -1178,7 +1119,6 @@
   // the somewhat rare kReverse HLO op. The hope therefore is that most calls to
   // the function that reach here will find a mapping within the first iteration
   // of the loop below.
-  bool found = false;
   std::vector<int64_t> axes(device_mesh.num_dimensions());
   size_t num_subsets =
       consider_reverse_device_meshes ? (1 << device_mesh.num_dimensions()) : 1;
@@ -1199,24 +1139,35 @@
       *device = device_mesh(original_indices);
     });
     if (auto result = check_mesh(new_mesh); result.has_value()) {
-      axes = result.value();
-      found = true;
-      break;
+      return result.value();
     }
   }
+  return absl::NotFoundError(absl::StrCat("Could not find mapping for ",
+                                          spec.ToString(), " with device mesh ",
+                                          device_mesh.ToString()));
+}
 
-  if (!found) {
-    return absl::NotFoundError(
-        absl::StrCat("Could not find mapping for ", spec.ToString(),
-                     " with device mesh ", device_mesh.ToString()));
+absl::StatusOr<std::vector<int64_t>> GetTensorDimToMeshDimNoCrash(
+    int64_t tensor_shape_rank, const HloSharding& spec,
+    const Array<int64_t>& device_mesh, bool consider_reverse_device_meshes) {
+  if (spec.IsReplicated()) {
+    return std::vector<int64_t>(tensor_shape_rank, -1);
   }
-
+  // Check the compatibility of tensor_shape_rank and spec
+  if (tensor_shape_rank != spec.TiledDataRank()) {
+    return absl::InvalidArgumentError(
+        "Tensor shape rank should be equal to the tiled data rank of the input "
+        "spec.");
+  }
   if (!TileAssignmentMatchesMesh(spec, device_mesh)) {
     return absl::InvalidArgumentError(
         "Device mesh and tile assignment need to have the same number of "
         "sharded dims.");
   }
 
+  TF_ASSIGN_OR_RETURN(std::vector<int64_t> axes,
+                      GetMeshDimPermutationOrderInShardingSpec(
+                          spec, device_mesh, consider_reverse_device_meshes));
   // Transform tile_assignment_dimensions using found transformation (axes).
   std::vector<int64_t> tensor_dim_to_device_dim(tensor_shape_rank, -1);
   int mesh_index = 0;
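
GetMeshDimPermutationOrderInShardingSpec factors the brute-force search over mesh-axis permutations (and optionally reversed meshes) out of GetTensorDimToMeshDimNoCrash. A stripped-down sketch of the permutation check on flat row-major vectors, assuming a 2x2 mesh instead of xla::Array:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Flattened 2x2 meshes in row-major order.
    using FlatMesh = std::vector<int>;

    // Transpose a 2x2 row-major mesh (swap the two axes).
    FlatMesh Transpose2x2(const FlatMesh& m) { return {m[0], m[2], m[1], m[3]}; }

    int main() {
      const FlatMesh device_mesh = {0, 1, 2, 3};
      const FlatMesh tile_assignment = {0, 2, 1, 3};  // mesh with axes swapped

      std::vector<int> axes = {0, 1};
      do {
        FlatMesh candidate = (axes == std::vector<int>{0, 1})
                                 ? device_mesh
                                 : Transpose2x2(device_mesh);
        if (candidate == tile_assignment) {
          std::printf("found permutation {%d, %d}\n", axes[0], axes[1]);
          break;  // the real code returns these axes
        }
      } while (std::next_permutation(axes.begin(), axes.end()));
      return 0;  // prints: found permutation {1, 0}
    }
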
@@ -1558,7 +1509,7 @@
 // Create a HloSharding that tiles some tensor dims on some device mesh dims.
 HloSharding Tile(const Shape& tensor_shape,
                  absl::Span<const int64_t> tensor_dims,
-                 absl::Span<const int64_t> mesh_dims,
+                 const std::vector<std::vector<int64_t>>& mesh_dims,
                  const Array<int64_t>& device_mesh) {
   CHECK_EQ(tensor_dims.size(), mesh_dims.size());
   CHECK(tensor_shape.IsArray());
@@ -1567,8 +1518,12 @@
   // Split on certain mesh dimensions
   int64_t split_prod = 1;
   for (size_t i = 0; i < tensor_dims.size(); ++i) {
-    tile_assignment_dimensions[tensor_dims[i]] = device_mesh.dim(mesh_dims[i]);
-    split_prod *= device_mesh.dim(mesh_dims[i]);
+    int64_t num_devices_for_tensor_dim = 1;
+    for (int64_t mesh_dim_idx : mesh_dims[i]) {
+      num_devices_for_tensor_dim *= device_mesh.dim(mesh_dim_idx);
+    }
+    tile_assignment_dimensions[tensor_dims[i]] = num_devices_for_tensor_dim;
+    split_prod *= num_devices_for_tensor_dim;
   }
   // Replicate on remaining mesh dimensions
   bool replicate_on_last_tile_dim = false;
@@ -1582,35 +1537,58 @@
   std::vector<int64_t> tile_assignment_devices;
   tile_assignment_devices.reserve(device_mesh.num_elements());
 
-  std::vector<int64_t> tmp_indices(device_mesh.num_dimensions(), 0);
-  std::function<void(int64_t, std::vector<int64_t>)>
+  std::function<void(int64_t, int64_t, std::vector<int64_t>)>
       generate_tile_assignment_devices;
-  generate_tile_assignment_devices = [&](int64_t tensor_dim,
+  generate_tile_assignment_devices = [&](int64_t current_tensor_dim,
+                                         int64_t current_mesh_dim_idx,
                                          std::vector<int64_t> mesh_indices) {
-    if (tensor_dim == tensor_shape.rank() - 1) {
-      AppendFlattenElements(&tile_assignment_devices, device_mesh, mesh_indices,
-                            -1, tmp_indices);
+    int64_t current_tensor_dim_index =
+        GetIndex(tensor_dims, current_tensor_dim);
+    bool proceed_to_next_tensor_dim = false;
+    if (current_tensor_dim_index >= 0) {
+      proceed_to_next_tensor_dim =
+          (current_mesh_dim_idx ==
+           mesh_dims[current_tensor_dim_index].size() - 1);
     } else {
-      int64_t next_tensor_dim = tensor_dim + 1;
-      int64_t next_mesh_dim = -1;
+      proceed_to_next_tensor_dim = true;
+    }
 
-      int64_t index = GetIndex(tensor_dims, next_tensor_dim);
-      if (index >= 0) {
-        next_mesh_dim = mesh_dims[index];
-      }
+    if (proceed_to_next_tensor_dim &&
+        current_tensor_dim == tensor_shape.rank() - 1) {
+      AppendFlattenElements(&tile_assignment_devices, device_mesh,
+                            mesh_indices);
+      return;
+    }
 
-      for (int64_t i = 0; i < tile_assignment_dimensions[next_tensor_dim];
-           ++i) {
-        if (next_mesh_dim != -1) {
-          mesh_indices[next_mesh_dim] = i;
-        }
-        generate_tile_assignment_devices(next_tensor_dim, mesh_indices);
+    int64_t next_tensor_dim, next_mesh_dim_idx = -1, next_mesh_dim = -1;
+    if (proceed_to_next_tensor_dim) {
+      next_tensor_dim = current_tensor_dim + 1;
+      next_mesh_dim_idx = -1;
+      int64_t next_tensor_dim_index = GetIndex(tensor_dims, next_tensor_dim);
+      if (next_tensor_dim_index >= 0) {
+        next_mesh_dim_idx = 0;
+        next_mesh_dim = mesh_dims[next_tensor_dim_index][0];
       }
+    } else {
+      next_tensor_dim = current_tensor_dim;
+      next_mesh_dim_idx = current_mesh_dim_idx + 1;
+      next_mesh_dim = mesh_dims[current_tensor_dim_index][next_mesh_dim_idx];
+    }
+
+    int64_t limit =
+        (next_mesh_dim_idx >= 0) ? device_mesh.dim(next_mesh_dim) : 1;
+    for (int64_t i = 0; i < limit; ++i) {
+      if (next_mesh_dim != -1) {
+        mesh_indices[next_mesh_dim] = i;
+      }
+      generate_tile_assignment_devices(next_tensor_dim, next_mesh_dim_idx,
+                                       mesh_indices);
     }
   };
 
   std::vector<int64_t> mesh_indices(device_mesh.num_dimensions(), -1);
-  generate_tile_assignment_devices(-1, mesh_indices);
+  generate_tile_assignment_devices(/*current_tensor_dim=*/-1,
+                                   /*current_mesh_dim_idx=*/-1, mesh_indices);
 
   // Make HloSharding
   Array<int64_t> tile_assignment(tile_assignment_dimensions);
@@ -1625,6 +1603,17 @@
              : HloSharding::Tile(std::move(tile_assignment));
 }
 
+HloSharding Tile(const Shape& tensor_shape,
+                 absl::Span<const int64_t> tensor_dims,
+                 absl::Span<const int64_t> mesh_dims,
+                 const Array<int64_t>& device_mesh) {
+  std::vector<std::vector<int64_t>> mesh_dims_general(mesh_dims.size());
+  for (int i = 0; i < mesh_dims.size(); ++i) {
+    mesh_dims_general[i].push_back(mesh_dims[i]);
+  }
+  return Tile(tensor_shape, tensor_dims, mesh_dims_general, device_mesh);
+}
+
 AliasMap BuildAliasMap(const HloModule* module,
                        const HloInputOutputAliasConfig& alias_config) {
   AliasMap alias_map;
@@ -1633,10 +1622,6 @@
   const auto& parameter_instructions = entry->parameter_instructions();
   const HloInstruction* output_tuple = entry->root_instruction();
 
-  if (IsCustomCallMarker(output_tuple)) {
-    output_tuple = output_tuple->operand(0);
-  }
-
   absl::flat_hash_map<int64_t, absl::flat_hash_map<int64_t, HloInstruction*>>
       parameter_index_to_operand_map;
   alias_config.ForEachAlias([&](const ShapeIndex& output_index,
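
The new Tile overload lets a single tensor dimension be sharded across a group of mesh dimensions; the per-dimension tile count is the product of the grouped mesh dimension sizes, and the old span-based overload is kept as a thin wrapper that puts each mesh dim in its own group. A hedged sketch of just that arithmetic (hypothetical standalone code, not the HloSharding construction):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      const std::vector<int64_t> mesh_shape = {2, 2, 4};   // hypothetical mesh
      const std::vector<int64_t> tensor_dims = {0};        // shard tensor dim 0
      const std::vector<std::vector<int64_t>> mesh_dims = {{0, 2}};  // over axes 0 and 2

      std::vector<int64_t> tile_assignment_dimensions = {1, 1};  // rank-2 tensor
      for (size_t i = 0; i < tensor_dims.size(); ++i) {
        int64_t num_devices_for_tensor_dim = 1;
        for (int64_t mesh_dim_idx : mesh_dims[i]) {
          num_devices_for_tensor_dim *= mesh_shape[mesh_dim_idx];
        }
        tile_assignment_dimensions[tensor_dims[i]] = num_devices_for_tensor_dim;
      }
      // Tensor dim 0 ends up split 2*4 = 8 ways; dim 1 stays unsharded.
      std::printf("[%lld, %lld]\n",
                  static_cast<long long>(tile_assignment_dimensions[0]),
                  static_cast<long long>(tile_assignment_dimensions[1]));
      return 0;
    }
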
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h
index 64749c6..f114a3d 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h
@@ -48,10 +48,7 @@
 namespace xla {
 namespace spmd {
 
-inline constexpr absl::string_view kPipelineMarker = "xla_pipeline_marker";
 inline constexpr absl::string_view kIdentityMarker = "identity";
-inline constexpr absl::string_view kPipelineMarkerStartType = "start";
-inline constexpr absl::string_view kPipelineMarkerEndType = "end";
 
 inline constexpr int64_t kAutoShardingPointerSize = 8;
 
@@ -94,9 +91,11 @@
 // Append elements of `array` to `result`. Here `indices` is a generalized
 // multi-dimensional index that can index a whole row (use -1 to indicate this).
 template <typename T>
-void AppendFlattenElements(std::vector<T>* result, const Array<T>& array,
-                           absl::Span<const int64_t> indices, int cur_depth,
-                           std::vector<int64_t> cur_indices) {
+void AppendFlattenElementsInternal(std::vector<T>* result,
+                                   const Array<T>& array,
+                                   absl::Span<const int64_t> indices,
+                                   int cur_depth,
+                                   std::vector<int64_t> cur_indices) {
   if (cur_depth == array.num_dimensions() - 1) {
     result->push_back(array(cur_indices));
   } else {
@@ -106,15 +105,25 @@
     if (index == -1) {
       for (int64_t i = 0; i < array.dim(next_depth); ++i) {
         cur_indices[next_depth] = i;
-        AppendFlattenElements(result, array, indices, next_depth, cur_indices);
+        AppendFlattenElementsInternal(result, array, indices, next_depth,
+                                      cur_indices);
       }
     } else {
       cur_indices[next_depth] = index;
-      AppendFlattenElements(result, array, indices, next_depth, cur_indices);
+      AppendFlattenElementsInternal(result, array, indices, next_depth,
+                                    cur_indices);
     }
   }
 }
 
+template <typename T>
+void AppendFlattenElements(std::vector<T>* result, const Array<T>& array,
+                           absl::Span<const int64_t> indices) {
+  std::vector<int64_t> tmp_indices(array.num_dimensions(), 0);
+  AppendFlattenElementsInternal(result, array, indices,
+                                /*cur_depth=*/-1, tmp_indices);
+}
+
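
The new AppendFlattenElements wrapper only hides the bookkeeping arguments of the recursive helper; callers pass a generalized index where -1 means "take every value along that axis". A behavioural reading of the helper on a flat row-major stand-in for xla::Array (illustration only, not code from the patch):

    #include <cstdio>
    #include <vector>

    int main() {
      // A 2x3 array stored row-major, standing in for xla::Array<int>.
      const std::vector<int> array = {10, 11, 12,
                                      20, 21, 22};
      const std::vector<int> dims = {2, 3};
      const std::vector<int> indices = {1, -1};  // row 1, all columns

      std::vector<int> result;
      for (int row = 0; row < dims[0]; ++row) {
        if (indices[0] != -1 && indices[0] != row) continue;
        for (int col = 0; col < dims[1]; ++col) {
          if (indices[1] != -1 && indices[1] != col) continue;
          result.push_back(array[row * dims[1] + col]);
        }
      }
      for (int v : result) std::printf("%d ", v);  // prints: 20 21 22
      return 0;
    }
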
 // Return the index of key in a span. -1 means not found.
 template <typename T>
 int64_t GetIndex(absl::Span<const T> v, const T& key) {
@@ -201,11 +210,6 @@
   }
 }
 
-// Return whether this instruction is a custom call marker introduced by us.
-inline bool IsCustomCallMarker(const HloInstruction* inst) {
-  return inst->IsCustomCall({kPipelineMarker, kIdentityMarker});
-}
-
 // Return whether this instruction is a TopK custom call.
 inline bool IsTopKCustomCall(const HloInstruction* inst) {
   return inst->opcode() == HloOpcode::kCustomCall &&
@@ -218,70 +222,6 @@
          inst->custom_call_target() == "PartialReduce";
 }
 
-// Pass through the custom call marker and get the source instruction
-inline const HloInstruction* PassThroughCustomCallMarkerGetSource(
-    const HloInstruction* ins) {
-  while (ins->opcode() == HloOpcode::kGetTupleElement &&
-         IsCustomCallMarker(ins->operand(0))) {
-    const HloInstruction* custom_call = ins->operand(0);
-    const HloInstruction* tuple = custom_call->operand(0);
-    while (IsCustomCallMarker(tuple)) {
-      tuple = tuple->operand(0);
-    }
-    ins = tuple->operand(ins->tuple_index());
-  }
-  return ins;
-}
-
-// Pass through the custom call marker and get the acutal operand.
-inline HloInstruction* PassThroughCustomCallMarkerOperand(
-    HloInstruction* raw_operand, const HloInstruction* inst) {
-  if (!IsCustomCallMarker(raw_operand)) {
-    return raw_operand;
-  }
-
-  CHECK_EQ(inst->opcode(), HloOpcode::kGetTupleElement);
-
-  int index = inst->tuple_index();
-  return raw_operand->mutable_operand(0)->mutable_operand(index);
-}
-
-// Return whether the tuple is only used by a custom call marker.
-inline bool IsCustomCallMarkerTuple(const HloInstruction* inst) {
-  return inst->opcode() == HloOpcode::kTuple && inst->users().size() == 1 &&
-         IsCustomCallMarker(inst->users().front());
-}
-
-// Pass through the custom call marker and get the actual user.
-inline HloInstruction* PassThroughCustomCallMarkerUser(
-    HloInstruction* raw_user, const HloInstruction* inst) {
-  if (!IsCustomCallMarkerTuple(raw_user)) {
-    return raw_user;
-  }
-
-  const HloInstruction* custom_call = raw_user->users().front();
-
-  int index = -1;
-  for (int i = 0; i < raw_user->operand_count(); i++) {
-    if (raw_user->operand(i) == inst) {
-      index = i;
-      break;
-    }
-  }
-  CHECK_NE(index, -1);
-
-  HloInstruction* ret = nullptr;
-  for (HloInstruction* user : custom_call->users()) {
-    CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement);
-    if (user->tuple_index() == index) {
-      CHECK_EQ(ret, nullptr);
-      ret = user;
-    }
-  }
-
-  return ret == nullptr ? raw_user : ret;
-}
-
 // Return the users of an instruction and its alias,
 // excluding the final output tuple.
 inline InstructionSet UsersWithAlias(const HloInstruction* inst,
@@ -289,8 +229,7 @@
                                      const HloInstruction* output) {
   InstructionSet users;
   for (HloInstruction* user : inst->users()) {
-    HloInstruction* pass_through_user =
-        PassThroughCustomCallMarkerUser(user, inst);
+    HloInstruction* pass_through_user = user;
     if (pass_through_user == output) {
       continue;
     }
@@ -300,8 +239,7 @@
   auto iter = alias_map.find(inst);
   if (iter != alias_map.end()) {
     for (HloInstruction* user : iter->second->users()) {
-      HloInstruction* pass_through_user =
-          PassThroughCustomCallMarkerUser(user, iter->second);
+      HloInstruction* pass_through_user = user;
       if (pass_through_user == output) {
         continue;
       }
@@ -356,10 +294,6 @@
                                             const xla::CallGraph& call_graph,
                                             int64_t num_devices);
 
-// Return whether the instruction is an activation from another pipeline stage.
-bool IsActivationFromAnotherStage(const HloInstruction* inst,
-                                  const InstructionBatchDimMap& batch_dim_map);
-
 // Depth analysis (breadth first search) that compute the depth of each
 // instruction. We also assign a much larger distance to heavy operators (e.g.,
 // dot, convolution).
@@ -464,41 +398,6 @@
                                          const Array<int64_t>& device_mesh,
                                          ReshardingCache* resharding_cache);
 
-/*
- * Gradient accumulation
- */
-// Find all instructions that compute gradients in gradient accumulation.
-// This is done by using the hint from pipeline_marker (gradient marker).
-inline std::vector<const HloInstruction*> GetGradientComputationInstructions(
-    const std::vector<HloInstruction*>& instructions) {
-  std::vector<const HloInstruction*> ret;
-
-  for (size_t i = 0; i < instructions.size(); ++i) {
-    const HloInstruction* ins = instructions[i];
-    if (ins->IsCustomCall(kPipelineMarker) &&
-        (absl::StrContains(ins->metadata().op_name(), "compute_grad") ||
-         absl::StrContains(ins->metadata().op_name(), "backward")) &&
-        ins->metadata().op_type() == kPipelineMarkerEndType) {
-      const HloInstruction* tuple = ins->operand(0);
-      for (size_t j = 0; j < tuple->operand_count(); ++j) {
-        const HloInstruction* add = tuple->operand(j);
-        while (add->opcode() == HloOpcode::kAdd) {
-          ret.push_back(add->operand(0));
-          ret.push_back(add->operand(1));
-
-          if (add->operand(0)->opcode() == HloOpcode::kAdd) {
-            add = add->operand(0);
-          } else {
-            add = add->operand(1);
-          }
-        }
-      }
-    }
-  }
-
-  return ret;
-}
-
 // Gets the mapping vector from dim_from to dim_to.
 // Example: GetDimensionMapping([2], 3) = [0, 1, -1]
 std::vector<int64_t> GetDimensionMapping(
@@ -545,6 +444,11 @@
 
 HloSharding Tile(const Shape& tensor_shape,
                  absl::Span<const int64_t> tensor_dims,
+                 const std::vector<std::vector<int64_t>>& mesh_dims,
+                 const Array<int64_t>& device_mesh);
+
+HloSharding Tile(const Shape& tensor_shape,
+                 absl::Span<const int64_t> tensor_dims,
                  absl::Span<const int64_t> mesh_dims,
                  const Array<int64_t>& device_mesh);
 
diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD
index 352f75c..e65c48d 100644
--- a/third_party/xla/xla/hlo/ir/BUILD
+++ b/third_party/xla/xla/hlo/ir/BUILD
@@ -25,7 +25,6 @@
         "dfs_hlo_visitor.cc",
         "dynamic_parameter_binding.cc",
         "hlo_computation.cc",
-        "hlo_frontend_attributes.cc",
         "hlo_input_output_alias_config.cc",
         "hlo_instruction.cc",
         "hlo_instructions.cc",
@@ -47,7 +46,6 @@
         "hlo_clone_context.h",
         "hlo_computation.h",
         "hlo_domain_metadata.h",
-        "hlo_frontend_attributes.h",
         "hlo_input_output_alias_config.h",
         "hlo_instruction.h",
         "hlo_instructions.h",
@@ -72,6 +70,7 @@
         "//xla:protobuf_util",
         "//xla:shape_tree",
         "//xla:shape_util",
+        "//xla:sort_json",
         "//xla:status_macros",
         "//xla:types",
         "//xla:util",
@@ -79,6 +78,7 @@
         "//xla:xla_data_proto_cc",
         "//xla:xla_proto_cc",
         "//xla/service:compilation_environments",
+        "//xla/service:computation_layout",
         "//xla/service:computation_placer_hdr",
         "//xla/service:hlo_lexer",
         "//xla/service:hlo_module_config",
@@ -100,13 +100,13 @@
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:cord",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/lib/gtl:iterator_range",
         "@local_tsl//tsl/lib/gtl:map_util",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:fingerprint",
-        "@local_tsl//tsl/platform:human_readable_json",
         "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:protobuf",
         "@local_tsl//tsl/platform:status",
diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc
index 025b1ce..4fbf057 100644
--- a/third_party/xla/xla/hlo/ir/hlo_computation.cc
+++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc
@@ -1103,6 +1103,19 @@
   return call_instruction;
 }
 
+HloInstruction* HloComputation::CreateCompositeCallInstruction(
+    absl::Span<HloInstruction* const> instructions_to_call,
+    const std::string& name, const std::string& attributes, int64_t version) {
+  HloInstruction* root = instructions_to_call.front();
+  HloInstruction* call_instruction =
+      AddInstruction(HloInstruction::CreateCompositeCall(
+                         root->shape(), root, name, attributes, version),
+                     root->name());
+  AppendInstructionsIntoCalledComputation(instructions_to_call,
+                                          call_instruction);
+  return call_instruction;
+}
+
 absl::StatusOr<HloInstruction*> HloComputation::CreateAsyncInstructions(
     HloInstruction* instruction, absl::Span<const Shape> context_shapes,
     absl::string_view async_execution_thread, bool replace,
diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h
index 956cf1a..3e73a68 100644
--- a/third_party/xla/xla/hlo/ir/hlo_computation.h
+++ b/third_party/xla/xla/hlo/ir/hlo_computation.h
@@ -17,18 +17,20 @@
 #define XLA_HLO_IR_HLO_COMPUTATION_H_
 
 #include <cstdint>
-#include <list>
 #include <memory>
 #include <optional>
 #include <string>
 #include <utility>
 #include <vector>
 
+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
 #include "absl/functional/function_ref.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/cord.h"
 #include "absl/strings/string_view.h"
@@ -42,9 +44,14 @@
 #include "xla/printer.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/service/name_uniquer.h"
+#include "xla/shape.h"
 #include "xla/shape_tree.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
+#include "tsl/lib/gtl/iterator_range.h"
+#include "tsl/platform/errors.h"
 
 namespace xla {
 
@@ -465,7 +472,7 @@
       absl::Span<HloInstruction* const> instructions_to_fuse,
       HloInstruction::FusionKind fusion_kind);
 
-  // Creates a call instruction containing the given instructions.  Instructions
+  // Creates a call instruction containing the given instructions. Instructions
   // must be in reverse topological order (root of the called computation
   // first). Replaces all uses of the original root instruction with the call
   // instruction. The original instructions are removed if they have no uses
@@ -473,6 +480,16 @@
   HloInstruction* CreateCallInstruction(
       absl::Span<HloInstruction* const> instructions_to_call);
 
+  // Creates a composite call instruction containing the given instructions.
+  // Instructions must be in reverse topological order (root of the called
+  // computation first). Replaces all uses of the original root instruction with
+  // the composite call instruction. The original instructions are removed if
+  // they have no uses after creating the composite call (this is necessarily
+  // true for at least the root).
+  HloInstruction* CreateCompositeCallInstruction(
+      absl::Span<HloInstruction* const> instructions_to_call,
+      const std::string& name, const std::string& attributes, int64_t version);
+
   // Creates an async start/done instruction pair where instruction is wrapped
   // inside an asynchronous computation. The context shapes are appended to the
   // output tuple of the asynchronous start which is backend specific. Returns
diff --git a/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.cc b/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.cc
deleted file mode 100644
index 347edce..0000000
--- a/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/hlo/ir/hlo_frontend_attributes.h"
-
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
-
-namespace xla {
-
-std::string FrontendAttributesToString(
-    const FrontendAttributes& frontend_attributes) {
-  std::vector<std::pair<std::string, std::string>> sorted_attributes(
-      frontend_attributes.map().begin(), frontend_attributes.map().end());
-  absl::c_sort(sorted_attributes);
-  // Frontend attribute is a comma-separated list of attribute="value" pairs,
-  // e.g., frontend_attributes={name="value_a",type="int32_t"}.
-  const auto formatter = [](std::string* out,
-                            const std::pair<std::string, std::string>& item) {
-    absl::StrAppend(out, item.first, "=\"", item.second, "\"");
-  };
-  return absl::StrFormat("{%s}",
-                         absl::StrJoin(sorted_attributes, ",", formatter));
-}
-
-}  // namespace xla
diff --git a/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.h b/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.h
deleted file mode 100644
index 7348691..0000000
--- a/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.h
+++ /dev/null
@@ -1,28 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_HLO_IR_HLO_FRONTEND_ATTRIBUTES_H_
-#define XLA_HLO_IR_HLO_FRONTEND_ATTRIBUTES_H_
-
-#include <string>
-
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-std::string FrontendAttributesToString(
-    const FrontendAttributes& frontend_attributes);
-}  // namespace xla
-
-#endif  // XLA_HLO_IR_HLO_FRONTEND_ATTRIBUTES_H_
diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc
index 7261134..7764750 100644
--- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc
+++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc
@@ -55,7 +55,6 @@
 #include "xla/hlo/ir/hlo_clone_context.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_domain_metadata.h"
-#include "xla/hlo/ir/hlo_frontend_attributes.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_op_metadata.h"
@@ -75,14 +74,16 @@
 #include "xla/service/name_uniquer.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/sort_json.h"
 #include "xla/status_macros.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/lib/gtl/iterator_range.h"
 #include "tsl/lib/gtl/map_util.h"
 #include "tsl/platform/errors.h"
-#include "tsl/platform/human_readable_json.h"
 #include "tsl/platform/logging.h"  // IWYU pragma: keep
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 
@@ -1159,12 +1160,43 @@
           << instruction->opcode() << proto.name();
       TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
 
-      auto call_instruction = new HloCallInstruction(
-          shape, all_operands(),
-          computation_map.at(proto.called_computation_ids()[0]));
-      call_instruction->set_output_to_operand_aliasing(
-          output_to_operand_aliasing());
-      instruction = absl::WrapUnique(call_instruction);
+      if (proto.is_composite()) {
+        TF_RET_CHECK(proto.has_frontend_attributes())
+            << "A composite call op must have frontend attributes";
+        auto map = proto.frontend_attributes().map();
+        auto name = map.find("composite.name");
+        TF_RET_CHECK(name != map.end() && !name->second.empty())
+            << "A composite call op must have frontend attributes with key "
+               "composite.name whose value is non-empty";
+
+        auto attributes = map.find("composite.attributes");
+        TF_RET_CHECK(attributes == map.end() || !attributes->second.empty())
+            << "A composite call op must have frontend attributes with key "
+               "composite.attributes whose value is default: {} or non-empty";
+
+        auto version_str = map.find("composite.version");
+        int64_t version = 0;
+        TF_RET_CHECK(
+            version_str == map.end() ||
+            (absl::SimpleAtoi(version_str->second, &version) && version >= 0))
+            << "A composite call op must have frontend attributes with a "
+               "composite.version whose value is a non-negative integer but "
+               "got: "
+            << version_str->second;
+
+        instruction = CreateCompositeCall(
+            shape, all_operands(),
+            computation_map.at(proto.called_computation_ids()[0]), name->second,
+            attributes == map.end() ? "{}" : attributes->second, version);
+        instruction->set_output_to_operand_aliasing(
+            output_to_operand_aliasing());
+      } else {
+        instruction = std::make_unique<HloCallInstruction>(
+            shape, all_operands(),
+            computation_map.at(proto.called_computation_ids()[0]));
+        instruction->set_output_to_operand_aliasing(
+            output_to_operand_aliasing());
+      }
       break;
     }
     default: {
@@ -1230,7 +1262,6 @@
     const xla::OriginalValueProto& original_value_proto =
         proto.original_value();
     auto original_value = std::make_shared<OriginalValue>(shape);
-    std::cerr << __func__ << ", shape: " << shape.ToString() << "\n";
 
     for (const auto& leaf : original_value_proto.leaves()) {
       *original_value->mutable_element(ShapeIndex(leaf.leaf_shape_index())) = {
@@ -2252,6 +2283,27 @@
   return std::make_unique<HloCallInstruction>(shape, operands, computation);
 }
 
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateCompositeCall(const Shape& shape,
+                                    HloInstruction* decomposition_root,
+                                    const std::string& name,
+                                    const std::string& attributes,
+                                    int64_t version) {
+  return std::make_unique<HloCallInstruction>(shape, decomposition_root, name,
+                                              attributes, version);
+}
+
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateCompositeCall(const Shape& shape,
+                                    absl::Span<HloInstruction* const> operands,
+                                    HloComputation* decomposition,
+                                    const std::string& name,
+                                    const std::string& attributes,
+                                    int64_t version) {
+  return std::make_unique<HloCallInstruction>(shape, operands, decomposition,
+                                              name, attributes, version);
+}
+
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     absl::string_view custom_call_target, std::string opaque,
@@ -3635,6 +3687,13 @@
   }
   if (options.print_backend_config() && !backend_config_.empty()) {
     absl::string_view config = backend_config_.GetRawString();
+    std::string sorted_config;
+    if (options.sort_backend_config()) {
+      // Use `value_or` below, because the backend config string isn't
+      // guaranteed to be a JSON string.
+      sorted_config = SortJson(config).value_or(std::string(config));
+      config = sorted_config;
+    }
     printer->Append(", backend_config=");
     // In the common case that the backend-config is valid-ish JSON, the parser
     // doesn't need it delimited by quotes, so we can print it without
@@ -3782,6 +3841,10 @@
           PrintNameInternal(printer, to_apply()->name(), options);
         });
       }
+      if (opcode() == HloOpcode::kCall && is_composite()) {
+        printer.Next(
+            [](Printer* printer) { printer->Append("is_composite=true"); });
+      }
     } else if (opcode() == HloOpcode::kCustomCall) {
       if (!called_computations().empty()) {
         printer.Next([this, &options](Printer* printer) {
@@ -3878,6 +3941,10 @@
             to_apply()->Print(printer, new_options);
           });
         }
+        if (opcode() == HloOpcode::kCall && is_composite()) {
+          printer.Next(
+              [](Printer* printer) { printer->Append("is_composite=true"); });
+        }
         break;
       default:
         if (!called_computations().empty()) {
@@ -3900,13 +3967,18 @@
       sharding().Print(printer, options.print_metadata());
     });
   }
-  if (!rare()->frontend_attributes.map().empty()) {
+  if (!frontend_attributes().map().empty()) {
     printer.Next([this](Printer* printer) {
       AppendCat(printer, "frontend_attributes=",
-                FrontendAttributesToString(rare()->frontend_attributes));
+                FrontendAttributesToString(frontend_attributes()));
     });
   }
 
+  if (opcode() != HloOpcode::kCall) {
+    CHECK(!is_composite())
+        << "Only kCall instructions should have is_composite set";
+  }
+
   if (options.print_control_dependencies() && !control_predecessors().empty()) {
     printer.Next([this, &options](Printer* printer) {
       printer->Append("control-predecessors={");
@@ -3952,6 +4024,23 @@
   return std::move(multi_string_printer).ConsumeStrings();
 }
 
+std::string FrontendAttributesToString(
+    const FrontendAttributes& frontend_attributes) {
+  std::vector<std::pair<std::string, std::string>> sorted_attributes(
+      frontend_attributes.map().begin(), frontend_attributes.map().end());
+  absl::c_sort(sorted_attributes);
+  const auto formatter = [](std::string* out,
+                            const std::pair<std::string, std::string>& item) {
+    if (LexesAsJsonDict(item.second)) {
+      absl::StrAppend(out, item.first, "=", item.second);
+    } else {
+      absl::StrAppend(out, item.first, "=\"", item.second, "\"");
+    }
+  };
+  return absl::StrFormat("{%s}",
+                         absl::StrJoin(sorted_attributes, ",", formatter));
+}
+
 std::string HloInstruction::ToShortString() const {
   return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
                 StrJoin(operands_, ", ",
@@ -3990,6 +4079,7 @@
   }
 
   *proto.mutable_frontend_attributes() = frontend_attributes();
+  proto.set_is_composite(is_composite());
 
   *proto.mutable_statistics_viz() = statistics_viz();
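
Editor's note: to connect the deserialization branch above with the new factory, a composite kCall is fully described by the three composite.* frontend attributes plus the is_composite bit that ToProto now serializes. A minimal sketch, assuming `shape`, `operands`, and `decomposition` already exist (none of these names come from this change):

    // Hypothetical setup: `shape` is a Shape, `operands` is a
    // std::vector<HloInstruction*>, `decomposition` is an HloComputation*.
    std::unique_ptr<HloInstruction> call = HloInstruction::CreateCompositeCall(
        shape, operands, decomposition,
        /*name=*/"my_namespace.my_composite",
        /*attributes=*/"{}",  // "{}" is the documented default
        /*version=*/0);
    // Round trip: ToProto() emits is_composite=true and the composite.name /
    // composite.attributes / composite.version frontend attributes, which is
    // what the kCall branch of CreateFromProto above validates.
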
 
diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h
index 6d8821c..a98f996 100644
--- a/third_party/xla/xla/hlo/ir/hlo_instruction.h
+++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h
@@ -25,6 +25,7 @@
 #include <cstdint>
 #include <functional>
 #include <iosfwd>
+#include <iterator>
 #include <map>
 #include <memory>
 #include <optional>
@@ -98,6 +99,7 @@
         print_metadata_(true),
         print_metadata_only_op_name_(false),
         print_backend_config_(true),
+        sort_backend_config_(false),
         print_infeed_outfeed_config_(true),
         compact_operands_(false),
         include_layout_in_shapes_(true),
@@ -218,6 +220,14 @@
     return *this;
   }
 
+  // If true, attempts to sort the backend config's JSON representation before
+  // printing it. If the backend config is a raw string that is not JSON, it is
+  // printed as is, without sorting.
+  HloPrintOptions& set_sort_backend_config(bool value) {
+    sort_backend_config_ = value;
+    return *this;
+  }
+
   // If true, infeed_config and outfeed_config will be printed.
   HloPrintOptions& set_print_infeed_outfeed_config(bool value) {
     print_infeed_outfeed_config_ = value;
@@ -382,6 +392,7 @@
     return print_metadata_only_op_name_;
   }
   bool print_backend_config() const { return print_backend_config_; }
+  bool sort_backend_config() const { return sort_backend_config_; }
   bool print_infeed_outfeed_config() const {
     return print_infeed_outfeed_config_;
   }
@@ -422,6 +433,7 @@
   bool print_metadata_;
   bool print_metadata_only_op_name_;
   bool print_backend_config_;
+  bool sort_backend_config_;
   bool print_infeed_outfeed_config_;
   bool compact_operands_;
   bool include_layout_in_shapes_;
@@ -1351,6 +1363,17 @@
       const Shape& shape, absl::Span<HloInstruction* const> operands,
       HloComputation* computation);
 
+  // Creates a composite call instruction that applies the given computation on
+  // the given operands. "shape" is the resultant shape.
+  static std::unique_ptr<HloInstruction> CreateCompositeCall(
+      const Shape& shape, HloInstruction* decomposition_root,
+      const std::string& name, const std::string& attributes, int64_t version);
+
+  static std::unique_ptr<HloInstruction> CreateCompositeCall(
+      const Shape& shape, absl::Span<HloInstruction* const> operands,
+      HloComputation* decomposition, const std::string& name,
+      const std::string& attributes, int64_t version);
+
   // Creates a custom call instruction that applies the given custom call target
   // to the given operands. "opaque" can be an arbitrary string with a
   // backend-specific interpretation. "shape" is the resultant shape.
@@ -2093,6 +2116,9 @@
     mutable_rare()->frontend_attributes = std::move(frontend_attributes);
   }
 
+  // Appends the given frontend attributes to the existing ones. If the
+  // existing frontend attributes are empty, they are created and set to the
+  // provided ones.
   void add_frontend_attributes(FrontendAttributes frontend_attributes) {
     if (!frontend_attributes.map().empty()) {
       mutable_rare()->frontend_attributes.mutable_map()->insert(
@@ -2100,10 +2126,25 @@
     }
   }
 
+  bool has_frontend_attributes() const {
+    return has_rare() && !rare()->frontend_attributes.map().empty();
+  }
+
   const FrontendAttributes& frontend_attributes() const {
     return rare()->frontend_attributes;
   }
 
+  void set_is_composite(bool is_composite) {
+    if (!has_rare() && !is_composite) {
+      return;
+    }
+    mutable_rare()->is_composite = is_composite;
+  }
+
+  // Returns the is_composite attribute. This attribute is only relevant for
+  // kCall instructions used as a composite op.
+  bool is_composite() const { return has_rare() && rare()->is_composite; }
+
   void add_single_statistic(Statistic statistic) {
     *mutable_rare()->statistics_viz.add_statistics() = std::move(statistic);
   }
@@ -2199,8 +2240,8 @@
   void set_metadata_preserve_layout(bool preserve_layout) {
     metadata_->set_preserve_layout(preserve_layout);
   }
-  void set_metadata_scheduling_name(const std::string& name) {
-    metadata_->set_scheduling_name(name);
+  void set_metadata_scheduling_name(absl::string_view name) {
+    metadata_->set_scheduling_name(std::string(name));
   }
   const OpMetadata& metadata() const { return *metadata_; }
 
@@ -2688,6 +2729,9 @@
     //    z' = const(20), frontend_attributes={?}
     FrontendAttributes frontend_attributes;
 
+    // Used by kCall to determine if the Call instruction is a composite.
+    bool is_composite = false;
+
     // Used to render an HLO graph when tracking the propagation desired values
     // through it.
     StatisticsViz statistics_viz;
@@ -2827,6 +2871,14 @@
 // Custom (de)stringification functions for protos that live inside
 // HloInstruction.
 std::string PaddingConfigToString(const PaddingConfig& padding);
+
+// Returns a string representation of the frontend attributes.
+// Frontend attributes are a list of attribute=<value> pairs, where each value
+// is either a "string" or a JSON-like dict surrounded by {}. Similar to the
+// custom_call backend config, this can be used to store stringified MLIR
+// dictionaries with pretty printing.
+std::string FrontendAttributesToString(
+    const FrontendAttributes& frontend_attributes);
 std::string StatisticsVizToString(const StatisticsViz& statistics_viz);
 std::string RandomAlgorithmToString(const RandomAlgorithm& algorithm);
 std::string RandomDistributionToString(const RandomDistribution& distribution);
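
Editor's note: a short usage sketch for the new sort_backend_config print option (the `instr` pointer and its backend config are hypothetical). Sorting only applies when the config parses as JSON; otherwise the raw string is printed unchanged, as noted in the setter's comment.

    // Hypothetical setup: `instr` is an HloInstruction* whose backend config
    // is a JSON string such as {"b":1,"a":2}.
    HloPrintOptions options = HloPrintOptions::Default();
    options.set_print_backend_config(true).set_sort_backend_config(true);
    // Keys come out in sorted order, e.g. backend_config={"a":2,"b":1},
    // which keeps textual HLO dumps stable across runs.
    std::string text = instr->ToString(options);
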
diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc
index d9801cf..cff0907 100644
--- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc
+++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc
@@ -1877,6 +1877,35 @@
   }
 }
 
+HloCallableInstruction::HloCallableInstruction(HloOpcode opcode,
+                                               const Shape& shape,
+                                               const std::string& name,
+                                               const std::string& attributes,
+                                               int64_t version)
+    : HloInstruction(opcode, shape) {
+  auto frontend_attributes =
+      BuildFrontendAttributesForComposite(name, attributes, version);
+  add_frontend_attributes(frontend_attributes);
+  set_is_composite(true);
+}
+
+HloCallableInstruction::HloCallableInstruction(
+    HloOpcode opcode, const Shape& shape,
+    absl::Span<HloInstruction* const> operands, HloComputation* decomposition,
+    const std::string& name, const std::string& attributes, int64_t version)
+    : HloInstruction(opcode, shape) {
+  for (auto operand : operands) {
+    AppendOperand(operand);
+  }
+  SetAndSanitizeName(HloOpcodeString(opcode));
+  AppendComputation(decomposition);
+
+  auto frontend_attributes =
+      BuildFrontendAttributesForComposite(name, attributes, version);
+  add_frontend_attributes(frontend_attributes);
+  set_is_composite(true);
+}
+
 HloCallableInstruction::~HloCallableInstruction() { ClearCalledComputations(); }
 
 HloComputation* HloCallableInstruction::called_computation() const {
@@ -1924,7 +1953,7 @@
         return u->opcode() == HloOpcode::kGetTupleElement;
       });
   if (called_computations().empty()) {
-    // New fusion instruction. It should not be a multioutput instruction.
+    // New fusion instruction. It should not be a multi-output instruction.
     CHECK(!add_output);
     auto builder = HloComputation::Builder(default_called_computation_name());
     builder.AddInstruction(instruction_to_append->Clone(/*suffix=*/""));
@@ -2552,6 +2581,47 @@
     : HloCallableInstruction(HloOpcode::kCall, shape, operands,
                              called_computation) {}
 
+HloCallInstruction::HloCallInstruction(const Shape& shape,
+                                       HloInstruction* decomposition_root,
+                                       const std::string& name,
+                                       const std::string& attributes,
+                                       int64_t version)
+    : HloCallableInstruction(HloOpcode::kCall, shape, name, attributes,
+                             version) {
+  CHECK(decomposition_root != nullptr);
+  SetAndSanitizeName(HloOpcodeString(opcode()));
+
+  FrontendAttributes frontend_attributes;
+  frontend_attributes.mutable_map()->insert({"composite.name", name});
+  frontend_attributes.mutable_map()->insert(
+      {"composite.attributes", attributes});
+  frontend_attributes.mutable_map()->insert(
+      {"composite.version", std::to_string(version)});
+
+  add_frontend_attributes(frontend_attributes);
+  set_is_composite(true);
+  set_parent(decomposition_root->parent());
+  set_metadata(decomposition_root->metadata());
+  CloneAndAppendInstructionIntoCalledComputation(decomposition_root);
+}
+
+HloCallInstruction::HloCallInstruction(
+    const Shape& shape, absl::Span<HloInstruction* const> operands,
+    HloComputation* decomposition, const std::string& name,
+    const std::string& attributes, int64_t version)
+    : HloCallableInstruction(HloOpcode::kCall, shape, operands, decomposition,
+                             name, attributes, version) {
+  FrontendAttributes frontend_attributes;
+  frontend_attributes.mutable_map()->insert({"composite.name", name});
+  frontend_attributes.mutable_map()->insert(
+      {"composite.attributes", attributes});
+  frontend_attributes.mutable_map()->insert(
+      {"composite.version", std::to_string(version)});
+
+  add_frontend_attributes(frontend_attributes);
+  set_is_composite(true);
+}
+
 HloRngInstruction::HloRngInstruction(
     const Shape& shape, RandomDistribution distribution,
     absl::Span<HloInstruction* const> parameters)
@@ -3500,13 +3570,23 @@
   AppendJoin(printer, dim_numbers.collapsed_slice_dims(), ",");
   printer->Append("}, start_index_map={");
   AppendJoin(printer, dim_numbers.start_index_map(), ",");
+  if (dim_numbers.operand_batching_dims_size()) {
+    printer->Append("}, operand_batching_dims={");
+    AppendJoin(printer, dim_numbers.operand_batching_dims(), ",");
+  }
+  if (dim_numbers.start_indices_batching_dims_size()) {
+    printer->Append("}, start_indices_batching_dims={");
+    AppendJoin(printer, dim_numbers.start_indices_batching_dims(), ",");
+  }
   AppendCat(printer, "}, index_vector_dim=", dim_numbers.index_vector_dim());
 }
 
 /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
     absl::Span<const int64_t> offset_dims,
     absl::Span<const int64_t> collapsed_slice_dims,
-    absl::Span<const int64_t> start_index_map, int64_t index_vector_dim) {
+    absl::Span<const int64_t> start_index_map, int64_t index_vector_dim,
+    absl::Span<const int64_t> operand_batching_dims,
+    absl::Span<const int64_t> start_indices_batching_dims) {
   GatherDimensionNumbers gather_dim_numbers;
   for (int64_t output_window_dim : offset_dims) {
     gather_dim_numbers.add_offset_dims(output_window_dim);
@@ -3517,6 +3597,13 @@
   for (int64_t gather_dim_to_input_dim : start_index_map) {
     gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
   }
+  for (int64_t operand_batching_dim : operand_batching_dims) {
+    gather_dim_numbers.add_operand_batching_dims(operand_batching_dim);
+  }
+  for (int64_t start_indices_batching_dim : start_indices_batching_dims) {
+    gather_dim_numbers.add_start_indices_batching_dims(
+        start_indices_batching_dim);
+  }
 
   gather_dim_numbers.set_index_vector_dim(index_vector_dim);
   return gather_dim_numbers;
@@ -3601,6 +3688,14 @@
   AppendJoin(printer, dim_numbers.inserted_window_dims(), ",");
   printer->Append("}, scatter_dims_to_operand_dims={");
   AppendJoin(printer, dim_numbers.scatter_dims_to_operand_dims(), ",");
+  if (dim_numbers.input_batching_dims_size()) {
+    printer->Append("}, input_batching_dims={");
+    AppendJoin(printer, dim_numbers.input_batching_dims(), ",");
+  }
+  if (dim_numbers.scatter_indices_batching_dims_size()) {
+    printer->Append("}, scatter_indices_batching_dims={");
+    AppendJoin(printer, dim_numbers.scatter_indices_batching_dims(), ",");
+  }
   AppendCat(printer, "}, index_vector_dim=", dim_numbers.index_vector_dim());
 }
 
@@ -3609,7 +3704,8 @@
     absl::Span<const int64_t> update_window_dims,
     absl::Span<const int64_t> inserted_window_dims,
     absl::Span<const int64_t> scatter_dims_to_operand_dims,
-    int64_t index_vector_dim) {
+    int64_t index_vector_dim, absl::Span<const int64_t> input_batching_dims,
+    absl::Span<const int64_t> scatter_indices_batching_dims) {
   ScatterDimensionNumbers scatter_dim_numbers;
   for (int64_t update_window_dim : update_window_dims) {
     scatter_dim_numbers.add_update_window_dims(update_window_dim);
@@ -3621,6 +3717,13 @@
     scatter_dim_numbers.add_scatter_dims_to_operand_dims(
         scatter_dim_to_operand_dim);
   }
+  for (int64_t input_batching_dim : input_batching_dims) {
+    scatter_dim_numbers.add_input_batching_dims(input_batching_dim);
+  }
+  for (int64_t scatter_indices_batching_dim : scatter_indices_batching_dims) {
+    scatter_dim_numbers.add_scatter_indices_batching_dims(
+        scatter_indices_batching_dim);
+  }
   scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
   return scatter_dim_numbers;
 }
diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.h b/third_party/xla/xla/hlo/ir/hlo_instructions.h
index b0e337a..c0f0324 100644
--- a/third_party/xla/xla/hlo/ir/hlo_instructions.h
+++ b/third_party/xla/xla/hlo/ir/hlo_instructions.h
@@ -19,13 +19,13 @@
 #define XLA_HLO_IR_HLO_INSTRUCTIONS_H_
 
 #include <cstdint>
-#include <list>
 #include <memory>
 #include <optional>
 #include <string>
 #include <utility>
 #include <vector>
 
+#include "absl/base/attributes.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/functional/function_ref.h"
 #include "absl/status/status.h"
@@ -38,7 +38,6 @@
 #include "xla/hlo/ir/hlo_domain_metadata.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/iterator_util.h"
 #include "xla/layout.h"
 #include "xla/literal.h"
 #include "xla/printer.h"
@@ -1343,6 +1342,15 @@
                          absl::Span<HloInstruction* const> operands,
                          absl::Span<HloComputation* const> called_computations);
 
+  HloCallableInstruction(HloOpcode opcode, const Shape& shape,
+                         const std::string& name, const std::string& attributes,
+                         int64_t version);
+
+  HloCallableInstruction(HloOpcode opcode, const Shape& shape,
+                         absl::Span<HloInstruction* const> operands,
+                         HloComputation* decomposition, const std::string& name,
+                         const std::string& attributes, int64_t version);
+
   ~HloCallableInstruction() override;
 
   // Adds a new operand to the callable instruction.
@@ -1402,6 +1410,21 @@
     output_to_operand_aliasing_ = std::move(aliasing);
   }
 
+  FrontendAttributes BuildFrontendAttributesForComposite(
+      const std::string& name,
+      std::optional<absl::string_view> attributes = std::nullopt,
+      std::optional<int64_t> version = std::nullopt) {
+    FrontendAttributes frontend_attributes;
+    frontend_attributes.mutable_map()->insert({"composite.name", name});
+    frontend_attributes.mutable_map()->insert(
+        {"composite.attributes",
+         attributes.has_value() ? std::string(*attributes) : "{}"});
+    frontend_attributes.mutable_map()->insert(
+        {"composite.version",
+         version.has_value() ? std::to_string(*version) : "0"});
+    return frontend_attributes;
+  }
+
  protected:
   // Returns the default called computation name.
   virtual std::string default_called_computation_name() const = 0;
@@ -1450,7 +1473,7 @@
   void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge);
 
   // Merges the fused instructions from instruction_to_merge into the fused
-  // instruction set of 'this' and generates multioutput fusion instructions.
+  // instruction set of 'this' and generates multi-output fusion instructions.
   // All the users of instruction_to_merge will be redirected to 'this'
   // instruction. instruction_to_merge will be removed from its parent
   // computation.
@@ -1555,6 +1578,15 @@
                      absl::Span<HloInstruction* const> operands,
                      HloComputation* called_computation);
 
+  HloCallInstruction(const Shape& shape, HloInstruction* decomposition_root,
+                     const std::string& name, const std::string& attributes,
+                     int64_t version);
+
+  HloCallInstruction(const Shape& shape,
+                     absl::Span<HloInstruction* const> operands,
+                     HloComputation* decomposition, const std::string& name,
+                     const std::string& attributes, int64_t version);
+
   static bool ClassOf(const HloInstruction* hlo) {
     return hlo->opcode() == HloOpcode::kCall;
   }
@@ -2313,7 +2345,9 @@
   static GatherDimensionNumbers MakeGatherDimNumbers(
       absl::Span<const int64_t> offset_dims,
       absl::Span<const int64_t> collapsed_slice_dims,
-      absl::Span<const int64_t> start_index_map, int64_t index_vector_dim);
+      absl::Span<const int64_t> start_index_map, int64_t index_vector_dim,
+      absl::Span<const int64_t> operand_batching_dims = {},
+      absl::Span<const int64_t> start_indices_batching_dims = {});
   // Returns the dump string of the given gather dimension numbers.
   static std::string GatherDimensionNumbersToString(
       const GatherDimensionNumbers& dim_numbers);
@@ -2378,7 +2412,9 @@
       absl::Span<const int64_t> update_window_dims,
       absl::Span<const int64_t> inserted_window_dims,
       absl::Span<const int64_t> scatter_dims_to_operand_dims,
-      int64_t index_vector_dim);
+      int64_t index_vector_dim,
+      absl::Span<const int64_t> input_batching_dims = {},
+      absl::Span<const int64_t> scatter_indices_batching_dims = {});
   // Returns the dump string of the given scatter dimension numbers.
   static std::string ScatterDimensionNumbersToString(
       const ScatterDimensionNumbers& dim_numbers);
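
Editor's note: because the new batching-dimension parameters default to empty spans, existing call sites of MakeGatherDimNumbers and MakeScatterDimNumbers keep compiling unchanged. A hedged sketch of a call that does pass them (the dimension values are made up for illustration):

    // Illustrative values only; they do not describe a real gather here.
    GatherDimensionNumbers dnums = HloGatherInstruction::MakeGatherDimNumbers(
        /*offset_dims=*/{2},
        /*collapsed_slice_dims=*/{1},
        /*start_index_map=*/{1},
        /*index_vector_dim=*/2,
        /*operand_batching_dims=*/{0},
        /*start_indices_batching_dims=*/{0});
    // GatherDimensionNumbersToString(dnums) now also prints the
    // operand_batching_dims / start_indices_batching_dims fields when set.
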
diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc
index 0711d49..cc8dda9 100644
--- a/third_party/xla/xla/hlo/ir/hlo_module.cc
+++ b/third_party/xla/xla/hlo/ir/hlo_module.cc
@@ -17,8 +17,8 @@
 
 #include <algorithm>
 #include <atomic>
+#include <cstddef>
 #include <cstdint>
-#include <functional>
 #include <iterator>
 #include <memory>
 #include <optional>
@@ -30,24 +30,36 @@
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/cord.h"
 #include "absl/strings/escaping.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
 #include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_frontend_attributes.h"
 #include "xla/hlo/ir/hlo_input_output_alias_config.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/map_util.h"
 #include "xla/printer.h"
 #include "xla/service/compilation_environments.h"
+#include "xla/service/computation_layout.h"
 #include "xla/service/computation_placer.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/service/mapped_ptr_container_sorter.h"
+#include "xla/service/name_uniquer.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/status_macros.h"
+#include "xla/util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/lib/gtl/map_util.h"
+#include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/fingerprint.h"
 #include "tsl/platform/logging.h"
@@ -405,8 +417,8 @@
                                  ? MakeComputationSorted()
                                  : MakeComputationPostOrder();
   for (const HloComputation* computation : computations) {
-    // Don't print async computations when the sytax sugar is enabled since that
-    // is redundant information.
+    // Don't print async computations when the syntax sugar is enabled since
+    // that is redundant information.
     if (options.syntax_sugar_async_ops() && computation->IsAsyncComputation() &&
         computation->CanExpandIntoSingleInstruction()) {
       continue;
@@ -848,7 +860,7 @@
                 outlined_instruction);
 
     // Mark instruction_to_outline an output if it is used outside the
-    // subcomputation or is the output of the original computation (i.e. used
+    // sub-computation or is the output of the original computation (i.e. used
     // externally).
     if (instruction_to_outline->user_count() == 0 ||
         IsUsedOutsideSubcomputation(*instruction_to_outline,
@@ -917,7 +929,7 @@
   if (computations_.empty()) {
     return {};
   }
-  // First determine all root computations by building a set of nonroot
+  // First determine all root computations by building a set of non-root
   // computations (computations which are called by an instruction in the
   // module).
   absl::flat_hash_set<HloComputation*> nonroot_computations;
diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD
index 0bae42a..8f20f63 100644
--- a/third_party/xla/xla/hlo/utils/BUILD
+++ b/third_party/xla/xla/hlo/utils/BUILD
@@ -158,6 +158,7 @@
         "//xla/service:pattern_matcher",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:string_view",
     ],
 )
diff --git a/third_party/xla/xla/hlo/utils/hlo_query.cc b/third_party/xla/xla/hlo/utils/hlo_query.cc
index 69a6fef..147f548 100644
--- a/third_party/xla/xla/hlo/utils/hlo_query.cc
+++ b/third_party/xla/xla/hlo/utils/hlo_query.cc
@@ -17,6 +17,7 @@
 
 #include <algorithm>
 #include <cstdint>
+#include <utility>
 
 #include "absl/algorithm/container.h"
 #include "absl/strings/string_view.h"
@@ -269,22 +270,46 @@
   return gte;
 }
 
-bool IsBeforeInComputation(const HloComputation* computation,
-                           absl::string_view inst1, absl::string_view inst2) {
-  int index1 = -1;
-  int index2 = -1;
+HloComputation* FindComputation(HloModule* module, absl::string_view name) {
+  auto computations = module->computations();
+  auto it = absl::c_find_if(
+      computations, [&](HloComputation* c) { return c->name() == name; });
+  if (it == computations.end()) {
+    return nullptr;
+  }
+  return *it;
+}
+
+std::pair<HloInstruction*, int> FindFirstInstruction(
+    const HloComputation* computation, absl::string_view name) {
   int current_index = 0;
-  for (auto instruction : computation->instructions()) {
-    if (instruction->name() == inst1) {
-      index1 = current_index;
-    }
-    if (instruction->name() == inst2) {
-      index2 = current_index;
+  for (auto* instruction : computation->instructions()) {
+    if (instruction->name() == name) {
+      return {instruction, current_index};
     }
     current_index++;
   }
-  current_index++;
-  return index1 < index2;
+  return {nullptr, -1};
+}
+
+std::pair<HloInstruction*, int> FindFirstInstruction(
+    const HloComputation* computation, HloOpcode opcode) {
+  int current_index = 0;
+  for (auto* instruction : computation->instructions()) {
+    if (instruction->opcode() == opcode) {
+      return {instruction, current_index};
+    }
+    current_index++;
+  }
+  return {nullptr, -1};
+}
+
+bool IsBeforeInComputation(const HloComputation* computation,
+                           absl::string_view inst1, absl::string_view inst2) {
+  return FindFirstInstruction(computation, inst1).second <
+         FindFirstInstruction(computation, inst2).second;
 }
 }  // namespace hlo_query
 }  // namespace xla
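
Editor's note: a small usage sketch of the refactored hlo_query helpers (the module pointer and instruction name are assumed to exist, not taken from this change):

    // Hypothetical setup: `module` is an existing HloModule*.
    HloComputation* entry = hlo_query::FindComputation(module, "main");
    if (entry != nullptr) {
      auto [instr, index] = hlo_query::FindFirstInstruction(entry, "out");
      if (instr != nullptr) {
        VLOG(1) << instr->name() << " is instruction #" << index << " in "
                << entry->name();
      }
    }
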
diff --git a/third_party/xla/xla/hlo/utils/hlo_query.h b/third_party/xla/xla/hlo/utils/hlo_query.h
index 950082a..ec5c0b2 100644
--- a/third_party/xla/xla/hlo/utils/hlo_query.h
+++ b/third_party/xla/xla/hlo/utils/hlo_query.h
@@ -17,6 +17,7 @@
 #define XLA_HLO_UTILS_HLO_QUERY_H_
 
 #include <cstdint>
+#include <utility>
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/string_view.h"
@@ -153,7 +154,19 @@
 HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand,
                                         int64_t index);
 
-// TODO: b/356153995 - refactor hlo_test_base
+// Gets the computation from the given module with the given name.
+HloComputation* FindComputation(HloModule* module, absl::string_view name);
+// Gets the first instruction and its index from the given computation with the
+// given instruction name. The function returns {nullptr, -1} if the instruction
+// cannot be found.
+std::pair<HloInstruction*, int> FindFirstInstruction(
+    const HloComputation* computation, absl::string_view name);
+// Gets the first instruction and its index from the given computation with the
+// given instruction opcode. The function returns {nullptr, -1} if the
+// instruction cannot be found.
+std::pair<HloInstruction*, int> FindFirstInstruction(
+    const HloComputation* computation, HloOpcode opcode);
+
 // Check that one instruction comes before another one for a given computation.
 // The function returns true if the first instruction comes before the second
 // one, and false otherwise. This is useful for partial checks on the
diff --git a/third_party/xla/xla/hlo/utils/hlo_query_test.cc b/third_party/xla/xla/hlo/utils/hlo_query_test.cc
index acefa21..e4dad10 100644
--- a/third_party/xla/xla/hlo/utils/hlo_query_test.cc
+++ b/third_party/xla/xla/hlo/utils/hlo_query_test.cc
@@ -40,6 +40,14 @@
   return counter;
 }
 
+constexpr absl::string_view kConstantAdditionHloString = R"(
+HloModule test
+ENTRY main {
+  zero = f32[] constant(0)
+  five = f32[] constant(5)
+  ROOT out = f32[] add(zero, five)
+})";
+
 TEST_F(HloQueryTest,
        GetInstructionWithOpCodeReturnsMatchingInstructionForModule) {
   constexpr absl::string_view kHloString = R"(
@@ -132,5 +140,66 @@
   EXPECT_EQ(gte2, nullptr);
 }
 
+TEST_F(HloQueryTest, FindComputationTest) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
+  EXPECT_NE(hlo_query::FindComputation(module.get(), "main"), nullptr);
+  EXPECT_EQ(hlo_query::FindComputation(module.get(), "foo"), nullptr);
+}
+
+TEST_F(HloQueryTest, FindInstructionUsingNameTest) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
+  const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
+  EXPECT_NE(hlo_query::FindFirstInstruction(main, "zero").first, nullptr);
+  EXPECT_NE(hlo_query::FindFirstInstruction(main, "five").first, nullptr);
+  EXPECT_NE(hlo_query::FindFirstInstruction(main, "out").first, nullptr);
+  EXPECT_EQ(hlo_query::FindFirstInstruction(main, "foo").first, nullptr);
+}
+
+TEST_F(HloQueryTest, FindInstructionUsingOpcodeTest) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
+  const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
+  EXPECT_NE(
+      hlo_query::FindFirstInstruction(main, StringToHloOpcode("add").value())
+          .first,
+      nullptr);
+  EXPECT_NE(hlo_query::FindFirstInstruction(
+                main, StringToHloOpcode("constant").value())
+                .first,
+            nullptr);
+  EXPECT_EQ(
+      hlo_query::FindFirstInstruction(main, StringToHloOpcode("select").value())
+          .first,
+      nullptr);
+}
+
+TEST_F(HloQueryTest, FindInstructionDoesNotExistTest) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
+  const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
+  EXPECT_NE(main, nullptr);
+  auto find_beef = hlo_query::FindFirstInstruction(main, "deadbeef");
+  auto find_nothing = hlo_query::FindFirstInstruction(main, "");
+  EXPECT_EQ(find_beef.first, nullptr);
+  EXPECT_EQ(find_beef.second, -1);
+  EXPECT_EQ(find_nothing.first, nullptr);
+  EXPECT_EQ(find_nothing.second, -1);
+}
+
+TEST_F(HloQueryTest, IsBeforeInComputationTest) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
+  const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
+  EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "zero", "five"));
+  EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "five", "out"));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
index a4e8c6a..28c9f0f 100644
--- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
+++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
@@ -2247,7 +2247,6 @@
   //   %indices = concatenate(..., %iota.1, ...)
   //   ... = gather(..., %indices)
   // is common for tf.reverse_sequence and would match this case.
-  absl::InlinedVector<const HloIotaInstruction*, 4> iotas;
   const int num_indices = index_map.size();
   std::vector<int64_t> index_parallel_in_dim(num_indices, -1);
 
@@ -2733,21 +2732,21 @@
 
   // 2. Try borrow dimensions from replicable_dims in order, and group sharding.
   if (sharding.IsTiled()) {
-    int64_t max_replicable_dimensions =
+    const int64_t reps_on_last_tile_dim =
         sharding.ReplicateOnLastTileDim()
             ? sharding.tile_assignment().dimensions().back()
             : 1;
-    max_replicable_dimensions = absl::c_accumulate(
-        replicable_dims, max_replicable_dimensions,
+
+    const int64_t max_replicable_dimensions = absl::c_accumulate(
+        replicable_dims, reps_on_last_tile_dim,
         [&](int64_t product, int64_t dim) {
           return product * sharding.tile_assignment().dim(dim);
         });
-    if (max_replicable_dimensions % num_groups == 0) {
+
+    if (max_replicable_dimensions % num_groups == 0 &&
+        num_groups % reps_on_last_tile_dim == 0) {
       auto tile_assignment = [&]() -> std::optional<TileAssignment> {
-        int dimensions_to_borrow =
-            num_groups / (sharding.ReplicateOnLastTileDim()
-                              ? sharding.tile_assignment().dimensions().back()
-                              : 1);
+        int dimensions_to_borrow = num_groups / reps_on_last_tile_dim;
         DimensionVector tile_dims(
             sharding.tile_assignment().dimensions().begin(),
             sharding.tile_assignment().dimensions().end());
diff --git a/third_party/xla/xla/literal_util.h b/third_party/xla/xla/literal_util.h
index f6e5f58..dc28277 100644
--- a/third_party/xla/xla/literal_util.h
+++ b/third_party/xla/xla/literal_util.h
@@ -533,7 +533,7 @@
 template <typename NativeT>
 /* static */ Literal LiteralUtil::MakeScalarMatrixR2(int64_t size,
                                                      NativeT scalar) {
-  Array2D<NativeT> array(size, size, 0);
+  Array2D<NativeT> array(size, size, NativeT(0));
   for (int64_t i = 0; i < size; ++i) {
     array(i, i) = scalar;
   }
@@ -542,7 +542,7 @@
 
 template <typename NativeT>
 /* static */ Literal LiteralUtil::MakeIdentityR2(int64_t size) {
-  return MakeScalarMatrixR2<NativeT>(size, 1);
+  return MakeScalarMatrixR2<NativeT>(size, NativeT(1));
 }
 
 template <typename NativeT>
@@ -550,7 +550,7 @@
                                                            NativeT scale) {
   NativeT row_factor = log10(m) + 1;
   NativeT col_factor = log10(n) + 1;
-  Array2D<NativeT> array(m, n, 0);
+  Array2D<NativeT> array(m, n, NativeT(0));
   for (int64_t i = 0; i < m; ++i) {
     for (int64_t j = 0; j < n; ++j) {
       array(i, i) = scale * (row_factor * i + col_factor * j);
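
Editor's note: the explicit NativeT(0) / NativeT(1) conversions matter for element types that do not implicitly convert from int literals (e.g. the half / bfloat16 families). A minimal sketch of the kind of call this keeps working, assuming the usual xla::half alias:

    // Builds a 4x4 identity literal with half-precision elements; the scalar
    // fill values go through NativeT(0) / NativeT(1) as shown above.
    xla::Literal eye = xla::LiteralUtil::MakeIdentityR2<xla::half>(4);
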
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc
index 0ab6c24..dfd3702 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc
@@ -87,56 +87,6 @@
   op->erase();
 }
 
-// Ensure that there aren't any implicit capture before exporting.
-static void prepareWhileOp(WhileOp whileOp) {
-  llvm::SetVector<Value> implicitInputs;
-  getUsedValuesDefinedAbove(whileOp->getRegions(), implicitInputs);
-  if (implicitInputs.empty()) return;
-  // Each captured value has to be passed as operand to the while, become then
-  // an operand to the condition region and the body region, and an extra
-  // operand to the return op in the body. It also becomes an extra result for
-  // the while operation, even if it is unused.
-  // We'll process the captured values one at a time and patch the body and
-  // condition regions as we go, but we'll accumulate the new operands and
-  // result type and recreate a new while op to replace the existing one at the
-  // end.
-  SmallVector<Type> returnedTypes(whileOp->getResultTypes().begin(),
-                                  whileOp->getResultTypes().end());
-  SmallVector<Value> operands(whileOp->getOperands().begin(),
-                              whileOp->getOperands().end());
-  Region &condRegion = whileOp.getCond();
-  Region &bodyRegion = whileOp.getBody();
-
-  for (Value input : implicitInputs) {
-    returnedTypes.push_back(input.getType());
-    operands.push_back(input);
-
-    Value condArg =
-        condRegion.front().addArgument(input.getType(), input.getLoc());
-    Value bodyArg =
-        bodyRegion.front().addArgument(input.getType(), input.getLoc());
-    for (OpOperand &operand : llvm::make_early_inc_range(input.getUses())) {
-      if (condRegion.isAncestor(operand.getOwner()->getParentRegion()))
-        operand.set(condArg);
-      else if (bodyRegion.isAncestor(operand.getOwner()->getParentRegion()))
-        operand.set(bodyArg);
-    }
-    auto returnOp = cast<mhlo::ReturnOp>(bodyRegion.front().back());
-    returnOp->insertOperands(returnOp->getNumOperands(), bodyArg);
-  }
-  OpBuilder builder(whileOp);
-  auto newWhileOp =
-      builder.create<mhlo::WhileOp>(whileOp.getLoc(), returnedTypes, operands);
-  newWhileOp.getCond().getBlocks().clear();
-  newWhileOp.getCond().takeBody(whileOp.getCond());
-  newWhileOp.getBody().getBlocks().clear();
-  newWhileOp.getBody().takeBody(whileOp.getBody());
-  for (auto zippedResults :
-       llvm::zip_first(whileOp.getResults(), newWhileOp.getResults()))
-    std::get<0>(zippedResults).replaceAllUsesWith(std::get<1>(zippedResults));
-  whileOp->erase();
-}
-
 static void prepareBroadcastInDim(BroadcastInDimOp bcast) {
   DenseIntElementsAttr dims = bcast.getBroadcastDimensions();
   // If dimensions aren't sorted, there is a transpose fused into the op, which
@@ -200,7 +150,6 @@
     mlir::SplatElementsAttr attr;
     if (matchPattern(op, m_Constant(&attr))) return prepareConstantOp(op, attr);
 
-    if (auto whileOp = dyn_cast<WhileOp>(op)) return prepareWhileOp(whileOp);
     if (auto bcastOp = dyn_cast<BroadcastInDimOp>(op))
       return prepareBroadcastInDim(bcastOp);
     // IfOp, CaseOp, WhileOp are already being handled during
diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp
index fb5728f..68ec969 100644
--- a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp
+++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp
@@ -104,7 +104,7 @@
     auto res = verifyCustomCallOpAttributes(
         op, rewriter, [&](NamedAttribute attr) -> LogicalResult {
           if (attr.getName() != "largest") return success();
-          if (cast<BoolAttr>(attr.getValue()).getValue() == false)
+          if (!cast<BoolAttr>(attr.getValue()).getValue())
             return rewriter.notifyMatchFailure(
                 op, "largest = false is not supported.");
           return success();
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir
index 6944e79..a214de5 100644
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir
+++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir
@@ -22,94 +22,6 @@
 
 // -----
 
-// CHECK-LABEL: @while_without_implicit_capture
-func.func @while_without_implicit_capture(%arg0: tensor<i64>) -> tensor<i64> {
-  // CHECK: mhlo.while
-  // CHECK-SAME: (%[[ARG1:.*]] = %arg0, %[[ARG2:.*]] = %arg0)
-  // CHECK-SAME: {mhlo.sharding = "{{\{}}{replicated},{replicated}}"}
-  %0:2 = "mhlo.while"(%arg0, %arg0) ({
-  ^bb0(%arg1: tensor<i64>, %arg2: tensor<i64>):
-    %1 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i64>, tensor<i64>) -> tensor<i1>
-    "mhlo.return"(%1) : (tensor<i1>) -> ()
-  },  {
-  ^bb0(%arg1: tensor<i64>, %arg2: tensor<i64>):
-    %2 = mhlo.add %arg1, %arg1 : tensor<i64>
-    "mhlo.return"(%2, %arg2) : (tensor<i64>, tensor<i64>) -> ()
-  }) {mhlo.sharding = "{{replicated},{replicated}}"} : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
-  func.return %0#0 : tensor<i64>
-}
-
-// -----
-
-// CHECK-LABEL: @while_with_implicit_arg_capture
-func.func @while_with_implicit_arg_capture(%arg0: tensor<i64>) -> tensor<i64> {
-  // CHECK: mhlo.while
-  // CHECK-SAME: (%[[ARG1:.*]] = %arg0, %[[ARG2:.*]] = %arg0)
-  %0 = "mhlo.while"(%arg0) ({
-  ^bb0(%arg1: tensor<i64>):
-    // CHECK: mhlo.compare
-    // CHECK-SAME: %[[ARG2]], %[[ARG1]]
-    %1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i64>, tensor<i64>) -> tensor<i1>
-    "mhlo.return"(%1) : (tensor<i1>) -> ()
-  },  {
-  ^bb0(%arg1: tensor<i64>):
-    // CHECK: %[[ADD:.*]] = mhlo.add %[[ARG1]], %[[ARG1]]
-    %2 = mhlo.add %arg1, %arg1 : tensor<i64>
-    // CHECK: mhlo.return
-    // CHECK-SAME: %[[ADD]], %[[ARG2]]
-    "mhlo.return"(%2) : (tensor<i64>) -> ()
-  }) : (tensor<i64>) -> tensor<i64>
-  func.return %0 : tensor<i64>
-}
-
-// -----
-
-// CHECK-LABEL: @while_with_implicit_capture
-// func @while_with_implicit_capture(%arg0 :  tuple<tensor<i1>, tensor<5xi32>>) -> tuple<tensor<i1>, tensor<5xi32>> {
-func.func @while_with_implicit_capture(%arg0 :  tensor<i1>, %arg1 : tensor<5xi32>) -> tuple<tensor<i1>, tensor<5xi32>> {
-  %0 = mhlo.constant dense<0> : tensor<i32>
-  %1 = mhlo.constant dense<false> : tensor<i1>
-  // Check that the iota implicit capture is made explicit
-  // CHECK: %[[IOTA:.*]] = "mhlo.iota
-  %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<5xi32>
-  // CHECK: mhlo.while{{.*}} %[[IOTA]])
-  %3:2 = "mhlo.while"(%arg0, %arg1) ({
-  ^bb0(%arg2: tensor<i1>, %arg3 : tensor<5xi32>):
-    "mhlo.return"(%arg2) : (tensor<i1>) -> ()
-  },  {
-  ^bb0(%arg2: tensor<i1>, %arg3 : tensor<5xi32>):
-    "mhlo.return"(%arg2, %2) : (tensor<i1>, tensor<5xi32>) -> ()
-  }) : (tensor<i1>, tensor<5xi32>) -> (tensor<i1>, tensor<5xi32>)
-  %4 = "mhlo.tuple"(%3#0, %3#1) : (tensor<i1>, tensor<5xi32>) -> tuple<tensor<i1>, tensor<5xi32>>
-  func.return %4 : tuple<tensor<i1>, tensor<5xi32>>
-  }
-
-// -----
-
-// Verifies that a value captured multiple times gets all of its uses updated.
-// CHECK-LABEL: @while_with_multiple_capture
-func.func @while_with_multiple_capture(%arg0: tensor<i64>) -> tensor<i64> {
-  // CHECK: mhlo.while
-  // CHECK-SAME: (%[[ARG1:.*]] = %arg0, %[[ARG2:.*]] = %arg0)
-  %0 = "mhlo.while"(%arg0) ({
-  ^bb0(%arg1: tensor<i64>):
-    // CHECK: mhlo.compare
-    // CHECK-SAME: %[[ARG2]], %[[ARG1]]
-    %1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i64>, tensor<i64>) -> tensor<i1>
-    "mhlo.return"(%1) : (tensor<i1>) -> ()
-  },  {
-  ^bb0(%arg1: tensor<i64>):
-    // CHECK: %[[ADD:.*]] = mhlo.add %[[ARG2]], %[[ARG1]]
-    %2 = mhlo.add %arg0, %arg1 : tensor<i64>
-    // CHECK: mhlo.return
-    // CHECK-SAME: %[[ADD]], %[[ARG2]]
-    "mhlo.return"(%2) : (tensor<i64>) -> ()
-  }) : (tensor<i64>) -> tensor<i64>
-  func.return %0 : tensor<i64>
-}
-
-// -----
-
 // CHECK-LABEL: @broadcast_in_dim_dimension_unsorted
 func.func @broadcast_in_dim_dimension_unsorted(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
 // Unfuse the transpose from the broadcastInDim before export.
diff --git a/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc b/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc
index 2562c06..0c8a4de 100644
--- a/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc
+++ b/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc
@@ -100,9 +100,10 @@
   return calcMultiDimIndex(b, loc, linearIndex, shapeVec);
 }
 
-SmallVector<Value> calcMultiDimIndexForFirstOperand(OpBuilder& b, Location loc,
-                                                    Value linearIndex,
-                                                    Operation* op) {
+static SmallVector<Value> calcMultiDimIndexForFirstOperand(OpBuilder& b,
+                                                           Location loc,
+                                                           Value linearIndex,
+                                                           Operation* op) {
   assert(op->getDialect()->getNamespace() == "lmhlo");
   Value operandMemref = op->getOperand(0);
   return calcMultiDimIndex(b, loc, linearIndex, operandMemref);
diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD
index b73a59f..adf30e0 100644
--- a/third_party/xla/xla/pjrt/BUILD
+++ b/third_party/xla/xla/pjrt/BUILD
@@ -311,6 +311,10 @@
     deps = [
         ":pjrt_common",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
     ],
 )
 
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc
index eee88ad..e17b04d 100644
--- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc
+++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc
@@ -77,8 +77,8 @@
 
 TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) {
   // Prepares a device memory ptr on GPU.
-  std::unique_ptr<PJRT_Buffer, ::pjrt::PJRT_BufferDeleter> buffer =
-      create_buffer().first;
+  auto [buffer, buffer_future] = create_buffer();
+  TF_CHECK_OK(buffer_future.Await());
   PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args device_buffer_ptr_args;
   device_buffer_ptr_args.struct_size =
       PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args_STRUCT_SIZE;
diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD
index f3ae2c5..0ac2c99 100644
--- a/third_party/xla/xla/pjrt/cpu/BUILD
+++ b/third_party/xla/xla/pjrt/cpu/BUILD
@@ -191,6 +191,7 @@
         "//xla/stream_executor",
         "//xla/tsl/concurrency:async_value",
         "//xla/tsl/concurrency:ref_count",
+        "//xla/tsl/lib/strings:proto_serialization",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/base:dynamic_annotations",
@@ -208,7 +209,6 @@
         "@com_google_absl//absl/types:span",
         "@eigen_archive//:eigen3",  # TODO(zhangqiaorjc): Remove if use TFRT threadpool.
         "@llvm-project//mlir:IR",
-        "@local_tsl//tsl/lib/strings:proto_serialization",
         "@local_tsl//tsl/platform:casts",
         "@local_tsl//tsl/platform:denormal",
         "@local_tsl//tsl/platform:env",
diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc
index 3b2caa6..78779c8 100644
--- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc
+++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc
@@ -103,10 +103,10 @@
 #include "xla/tsl/concurrency/async_value.h"
 #include "xla/tsl/concurrency/async_value_ref.h"
 #include "xla/tsl/concurrency/ref_count.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 #include "xla/util.h"
 #include "xla/xla.pb.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
 #include "tsl/platform/casts.h"
 #include "tsl/platform/denormal.h"
 #include "tsl/platform/env.h"
@@ -398,6 +398,11 @@
       num_threads, options.asynchronous));
 }
 
+// An upper bound on the number of threads to use for intra-op parallelism. It
+// is nearly impossible to efficiently utilize more than 256 threads for
+// compute-intensive operations that are meant to run inside the intra-op
+// threadpool.
+static const size_t kMaxIntraOpThreads = 256;
+
 static tsl::ThreadOptions GetThreadOptions() {
   tsl::ThreadOptions thread_options;
   // On Mac OS the default stack size is 512KiB, which is too small for some
@@ -415,16 +420,17 @@
     : process_index_(process_index),
       owned_devices_(std::move(devices)),
       computation_placer_(std::make_unique<ComputationPlacer>()),
+      eigen_intraop_pool_(new tsl::thread::ThreadPool(
+          tsl::Env::Default(), "XLAEigen",
+          std::min(num_threads, kMaxIntraOpThreads))),
+      eigen_intraop_device_(
+          new Eigen::ThreadPoolDevice(eigen_intraop_pool_->AsEigenThreadPool(),
+                                      eigen_intraop_pool_->NumThreads())),
       pjrt_client_thread_pool_(
           new tsl::thread::ThreadPool(tsl::Env::Default(), GetThreadOptions(),
                                       "XLATfrtCpuClient", num_threads)),
       async_work_runner_(std::make_unique<ThreadPoolAsyncWorkRunner>(
           pjrt_client_thread_pool_.get())),
-      eigen_intraop_pool_(new tsl::thread::ThreadPool(tsl::Env::Default(),
-                                                      "XLAEigen", num_threads)),
-      eigen_intraop_device_(
-          new Eigen::ThreadPoolDevice(eigen_intraop_pool_->AsEigenThreadPool(),
-                                      eigen_intraop_pool_->NumThreads())),
       last_collective_launch_event_(
           tsl::MakeAvailableAsyncValueRef<CpuEvent>()),
       transpose_cache_(1024),
@@ -463,10 +469,10 @@
     owned_memory_spaces_.push_back(std::move(memory_space));
   }
 
-  LOG(INFO) << "TfrtCpuClient created.";
+  VLOG(1) << "TfrtCpuClient created.";
 }
 
-TfrtCpuClient::~TfrtCpuClient() { LOG(INFO) << "TfrtCpuClient destroyed."; }
+TfrtCpuClient::~TfrtCpuClient() { VLOG(1) << "TfrtCpuClient destroyed."; }
 
 absl::StatusOr<PjRtDevice*> TfrtCpuClient::LookupDevice(
     xla::PjRtGlobalDeviceId global_device_id) const {
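Side note on the constructor change above: the Eigen intra-op pool is now built before the PjRt client pool, and its size is clamped to kMaxIntraOpThreads. A minimal stand-alone sketch of just the clamping behavior (IntraOpPoolSize is an invented name; the real code inlines the std::min call in the initializer list):

#include <algorithm>
#include <cstddef>
#include <iostream>

// Cap taken from the change above; larger requests are clamped.
constexpr std::size_t kMaxIntraOpThreads = 256;

// Hypothetical helper mirroring std::min(num_threads, kMaxIntraOpThreads).
std::size_t IntraOpPoolSize(std::size_t num_threads) {
  return std::min(num_threads, kMaxIntraOpThreads);
}

int main() {
  std::cout << IntraOpPoolSize(16) << "\n";   // 16  (below the cap)
  std::cout << IntraOpPoolSize(512) << "\n";  // 256 (clamped to the cap)
  return 0;
}
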
diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.h b/third_party/xla/xla/pjrt/cpu/cpu_client.h
index 54f75f3..ba4426b 100644
--- a/third_party/xla/xla/pjrt/cpu/cpu_client.h
+++ b/third_party/xla/xla/pjrt/cpu/cpu_client.h
@@ -446,14 +446,14 @@
   // Pointers to `owned_memory_spaces_`.
   std::vector<PjRtMemorySpace*> memory_spaces_;
 
-  // Thread pool for running PjRtClient tasks.
-  std::unique_ptr<tsl::thread::ThreadPool> pjrt_client_thread_pool_;
-  std::unique_ptr<AsyncWorkRunner> async_work_runner_;
-
   // TODO(zhangqiaorjc): Use tsl::compat::EigenHostContextThreadPool.
   std::unique_ptr<tsl::thread::ThreadPool> eigen_intraop_pool_;
   std::unique_ptr<Eigen::ThreadPoolDevice> eigen_intraop_device_;
 
+  // Thread pool for running PjRtClient tasks.
+  std::unique_ptr<tsl::thread::ThreadPool> pjrt_client_thread_pool_;
+  std::unique_ptr<AsyncWorkRunner> async_work_runner_;
+
   // Launching collectives are prone to deadlock when we use fixed-sized
   // threadpools since ExecuteHelper will block until all replicas reach the
   // barrier. We ensure that
@@ -589,7 +589,7 @@
     }
     memory_stats.serialized_hlo_proto = proto->SerializeAsString();
     memory_stats.PopulateBufferStatsFromAllocations(
-        cpu_executable_.get()->GetAllocations());
+        cpu_executable_->GetAllocations());
     return memory_stats;
   }
 
diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD
index 90ab11e..f9019fc 100644
--- a/third_party/xla/xla/pjrt/gpu/BUILD
+++ b/third_party/xla/xla/pjrt/gpu/BUILD
@@ -91,6 +91,7 @@
         "//xla/tsl/framework:bfc_allocator",
         "//xla/tsl/framework:device_id",
         "//xla/tsl/framework:device_id_impl",
+        "//xla/tsl/lib/strings:proto_serialization",
         "//xla/tsl/util:env_var",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:core_headers",
@@ -108,7 +109,6 @@
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/time",
         "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/lib/strings:proto_serialization",
         "@local_tsl//tsl/platform:casts",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:errors",
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
index 27d3f18..9fe8b2e 100644
--- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
+++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -82,7 +82,7 @@
 #include "xla/stream_executor/stream.h"
 #include "xla/stream_executor/stream_executor.h"
 #include "xla/tsl/framework/allocator.h"
-#include "tsl/lib/strings/proto_serialization.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 #include "tsl/platform/casts.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/fingerprint.h"
diff --git a/third_party/xla/xla/pjrt/pjrt_device_description.h b/third_party/xla/xla/pjrt/pjrt_device_description.h
index ed85269..77107fd 100644
--- a/third_party/xla/xla/pjrt/pjrt_device_description.h
+++ b/third_party/xla/xla/pjrt/pjrt_device_description.h
@@ -20,12 +20,35 @@
 #include <string_view>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "xla/pjrt/pjrt_common.h"
 
 namespace xla {
 
 using PjRtDeviceAttribute = PjRtValueType;
 
+class PjRtMemorySpaceDescription {
+ public:
+  PjRtMemorySpaceDescription(absl::string_view kind, int kind_id)
+      : kind_(kind), kind_id_(kind_id) {}
+
+  // A platform-dependent string that uniquely identifies the kind of the
+  // memory space.
+  absl::string_view kind() const { return kind_; }
+
+  // An ID that uniquely identifies the kind of the memory space among those
+  // attached to the same `PjRtClient`. The ID assigned to a kind is
+  // implementation-specific.
+  int kind_id() const { return kind_id_; }
+
+ private:
+  absl::string_view kind_;
+  int kind_id_;
+};
+
 class PjRtDeviceDescription {
  public:
   virtual ~PjRtDeviceDescription() = default;
@@ -60,6 +83,19 @@
   // reference will remain valid for the lifetime of the PjRtDevice.
   virtual const absl::flat_hash_map<std::string, PjRtDeviceAttribute>&
   Attributes() const = 0;
+
+  // Returns all memory spaces attached to this device.
+  // The memory spaces are in no particular order.
+  virtual absl::Span<const PjRtMemorySpaceDescription* const> memory_spaces()
+      const {
+    return {};
+  }
+
+  // Returns the default memory space attached to this device.
+  virtual absl::StatusOr<const PjRtMemorySpaceDescription*>
+  default_memory_space() const {
+    return absl::UnimplementedError("default_memory_space is not implemented.");
+  }
 };
 
 }  // namespace xla
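Both new hooks are optional: the base class returns an empty span and an UnimplementedError. A hypothetical caller-side helper (not part of this change) written against exactly the interface declared above, handling both defaults:

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/pjrt/pjrt_device_description.h"

// Prefer the default memory space description when the implementation
// provides one; otherwise fall back to the first listed memory space.
absl::StatusOr<const xla::PjRtMemorySpaceDescription*> PickMemorySpace(
    const xla::PjRtDeviceDescription& description) {
  if (auto default_space = description.default_memory_space();
      default_space.ok()) {
    return *default_space;
  }
  absl::Span<const xla::PjRtMemorySpaceDescription* const> spaces =
      description.memory_spaces();
  if (spaces.empty()) {
    return absl::UnimplementedError("device exposes no memory spaces");
  }
  return spaces.front();
}
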
diff --git a/third_party/xla/xla/protobuf_util.cc b/third_party/xla/xla/protobuf_util.cc
index a8d6dfa..4c6815d 100644
--- a/third_party/xla/xla/protobuf_util.cc
+++ b/third_party/xla/xla/protobuf_util.cc
@@ -49,20 +49,5 @@
   return absl::HashOf(serialized);
 }
 
-absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message,
-                                  const std::string& directory,
-                                  const std::string& file_name,
-                                  std::string* full_path) {
-  tsl::Env* env = tsl::Env::Default();
-  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
-  std::string safe_file_name = SanitizeFileName(file_name) + ".pb";
-  std::string full_path_impl;
-  if (!full_path) {
-    full_path = &full_path_impl;
-  }
-  *full_path = tsl::io::JoinPath(directory, safe_file_name);
-  return tsl::WriteBinaryProto(env, *full_path, message);
-}
-
 }  // namespace protobuf_util
 }  // namespace xla
diff --git a/third_party/xla/xla/protobuf_util.h b/third_party/xla/xla/protobuf_util.h
index 79f0077..81f7952 100644
--- a/third_party/xla/xla/protobuf_util.h
+++ b/third_party/xla/xla/protobuf_util.h
@@ -55,17 +55,6 @@
     return ProtobufHash(m);
   }
 };
-// Writes the given message in binary proto to the path formed by joining
-// 'directory/file_name.pb'. The 'directory' is recursively created if it
-// doesn't already exist, and the 'file_name' is sanitized by replacing
-// illegal characters with underscore '_'.
-//
-// If 'full_name' is not null then it is set to the name of the file the
-// protobuf was written to.
-absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message,
-                                  const std::string& directory,
-                                  const std::string& file_name,
-                                  std::string* full_path = nullptr);
 
 // Registers a function that may either expand a dirpath or forward the original
 // dirpath along as-is.
diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD
index 3bf1ebf..99c49a1 100644
--- a/third_party/xla/xla/python/BUILD
+++ b/third_party/xla/xla/python/BUILD
@@ -3,10 +3,6 @@
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
 load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library")
 load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library")
-load(
-    "@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
-    "if_cuda_is_configured",
-)
 load("//xla:pytype.default.bzl", "pytype_strict_library")
 load("//xla:strict.default.bzl", "py_strict_library", "py_strict_test")
 load(
@@ -18,7 +14,6 @@
     "//xla/tsl:tsl.bzl",
     "if_cuda_or_rocm",
     "if_google",
-    "if_oss",
     "internal_visibility",
 )
 load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable", "tsl_pybind_extension")
@@ -1052,12 +1047,7 @@
         "@local_tsl//tsl/profiler/rpc:profiler_server_impl",
         "@local_tsl//tsl/profiler/rpc/client:capture_profile",
         "@local_tsl//tsl/profiler/rpc/client:profiler_client_impl",
-    ] + select({
-        ":gpu_enabled": [
-            "//xla/backends/profiler/gpu:device_tracer",
-        ],
-        "//conditions:default": [],
-    }),
+    ],
 )
 
 cc_library(
@@ -1180,7 +1170,7 @@
         "//xla/service:hlo_proto_cc",
         "//xla/service:name_uniquer",
         "//xla/service:tuple_simplifier",
-        "@local_tsl//tsl/lib/strings:proto_serialization",
+        "//xla/tsl/lib/strings:proto_serialization",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:logging",
@@ -1194,26 +1184,6 @@
     cc_api_version = 2,
 )
 
-# TODO(phawkins): the configuration settings here are overly confusing. The right fix is to split
-# xla_extension.so so that each backend is a separate plugin, however that must wait for a clean
-# ABI separation between devices.
-config_setting(
-    name = "link_gpu_plugin",
-    define_values = {"xla_python_enable_gpu": "true"},
-)
-
-bool_flag(
-    name = "enable_gpu",
-    build_setting_default = True,
-)
-
-config_setting(
-    name = "gpu_enabled",
-    flag_values = {
-        ":enable_gpu": "True",
-    },
-)
-
 # If this flag is enabled, it sets RPATH on the xla_extension to values that are suitable for
 # finding NVIDIA's CUDA libraries when they are installed as pip packages.
 bool_flag(
@@ -1228,17 +1198,6 @@
     },
 )
 
-# We cannot nest select and if_cuda_is_configured so we introduce
-# a standalone cc_library target.
-cc_library(
-    name = "gpu_plugin_deps",
-    deps = [
-        "//xla/service:gpu_plugin",
-    ] + if_cuda_is_configured([
-        "//xla/stream_executor:cuda_platform",
-    ]),
-)
-
 cc_library(
     name = "logging",
     srcs = ["logging.cc"],
@@ -1301,13 +1260,6 @@
         "-fexceptions",
         "-fno-strict-aliasing",
     ],
-    defines = if_google(
-        [],
-        select({
-            ":gpu_enabled": ["XLA_PYTHON_ENABLE_GPU=1"],
-            "//conditions:default": [],
-        }),
-    ),
     features = ["-use_header_modules"],
     linkopts = select({
         ":use_jax_cuda_pip_rpaths": [
@@ -1407,10 +1359,7 @@
         "@local_tsl//tsl/platform/cloud:gcs_file_system",
     ] + select({
         # gloo transport only builds on linux
-        "//xla/tsl:macos": [
-            "//xla/pjrt/cpu:gloo_collectives",
-            "//xla/pjrt/cpu:gloo_kv_store",
-        ] + if_oss(["@gloo//:transport_uv"]),
+        "//xla/tsl:macos": [],
         "//xla/tsl:windows": [],
         "//conditions:default": [
             "//xla/pjrt/cpu:gloo_collectives",
@@ -1423,20 +1372,7 @@
         "//conditions:default": [
             "//xla/pjrt/cpu:mpi_collectives",
         ],
-    }) + if_google(
-        [],
-        select({
-            ":gpu_enabled": [
-                ":gpu_support",
-            ],
-            "//conditions:default": [],
-        }) + select({
-            ":link_gpu_plugin": [
-                ":gpu_plugin_deps",
-            ],
-            "//conditions:default": [],
-        }),
-    ),
+    }),
 )
 
 cc_library(
diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc
index 6821c7e..080f0fa 100644
--- a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc
+++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc
@@ -182,6 +182,46 @@
   int output_index;
 };
 
+mlir::LogicalResult VerifyElementTypeAndPerShardShapeAreEqual(
+    mlir::Operation* op, IfrtArrayType in, int in_index, IfrtArrayType out,
+    int out_index) {
+  if (in.getShape().getElementType() != out.getShape().getElementType()) {
+    return op->emitOpError()
+           << "can't alias input #" << in_index << " to output #" << out_index
+           << " with different element types: " << in << " vs " << out;
+  }
+
+  absl::StatusOr<llvm::SmallVector<int64_t>> in_per_shard_shape =
+      in.getShardingAttr().LocalShapeFromGlobalShape(in.getShape().getShape());
+  if (!in_per_shard_shape.ok()) {
+    return op->emitOpError()
+           << "unable to get per-shard shape of aliased input #" << in_index
+           << ": " << in_per_shard_shape.status().message();
+  }
+  absl::StatusOr<llvm::SmallVector<int64_t>> out_per_shard_shape =
+      out.getShardingAttr().LocalShapeFromGlobalShape(
+          out.getShape().getShape());
+  if (!out_per_shard_shape.ok()) {
+    return op->emitOpError()
+           << "unable to get per-shard shape of aliased output #" << out_index
+           << ": " << out_per_shard_shape.status().message();
+  }
+  if (in_per_shard_shape->size() != out_per_shard_shape->size()) {
+    return op->emitOpError()
+           << "can't alias input #" << in_index << " to output #" << out_index
+           << " with different per-shard shapes: " << in << " vs " << out;
+  }
+  for (const auto& [in_dim, out_dim] :
+       llvm::zip(*in_per_shard_shape, *out_per_shard_shape)) {
+    if (in_dim != out_dim) {
+      return op->emitOpError()
+             << "can't alias input #" << in_index << " to output #" << out_index
+             << " with different per-shard shapes: " << in << " vs " << out;
+    }
+  }
+  return mlir::success();
+}
+
 mlir::LogicalResult VerifyIoAlias(mlir::Operation* op, IoAlias io_alias,
                                   llvm::ArrayRef<IfrtArrayType> inputs,
                                   llvm::ArrayRef<IfrtArrayType> outputs) {
@@ -198,11 +238,12 @@
            << " outputs";
   }
   if (inputs[io_alias.input_index] != outputs[io_alias.output_index]) {
-    return op->emitOpError()
-           << "can't alias input #" << io_alias.input_index << " to output #"
-           << io_alias.output_index
-           << " with different types: " << inputs[io_alias.input_index]
-           << " vs " << outputs[io_alias.output_index];
+    // TODO(icgog): Relax this aliasing check to allow for different per-shard
+    // shapes as long as the byte size is the same. We cannot do this now
+    // because we do not have layout information.
+    return VerifyElementTypeAndPerShardShapeAreEqual(
+        op, inputs[io_alias.input_index], io_alias.input_index,
+        outputs[io_alias.output_index], io_alias.output_index);
   }
   return mlir::success();
 }
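Net effect of the relaxed verifier: two `!ifrt.array` types may now alias as long as their element types match and every dimension shards down to the same per-shard extent, as exercised by the updated verify_call.mlir test later in this change. A simplified stand-alone illustration of that arithmetic; the real check goes through ShardingParam::LocalShapeFromGlobalShape rather than this hypothetical PerShardShape helper:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical per-shard shape computation: divide each global dimension by
// the number of shards along that dimension.
std::vector<int64_t> PerShardShape(const std::vector<int64_t>& global_shape,
                                   const std::vector<int64_t>& dim_shards) {
  std::vector<int64_t> local_shape;
  for (std::size_t i = 0; i < global_shape.size(); ++i) {
    local_shape.push_back(global_shape[i] / dim_shards[i]);
  }
  return local_shape;
}

int main() {
  // tensor<2x1xi32> sharded 2x1 and tensor<1x1xi32> sharded 1x1 both yield a
  // 1x1 per-shard tile, so aliasing them is now accepted.
  std::cout << (PerShardShape({2, 1}, {2, 1}) == PerShardShape({1, 1}, {1, 1}))
            << "\n";  // prints 1
  return 0;
}
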
diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir
index edaff92..8c70318 100644
--- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir
+++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir
@@ -238,7 +238,6 @@
   }
 }
 
-
 // -----
 
 !array0 = !ifrt.array<tensor<2xi32>,
@@ -259,3 +258,21 @@
     return %arg0 : tensor<2xi32>
   }
 }
+
+// -----
+
+!array = !ifrt.array<tensor<2xi32>,
+                     #ifrt.sharding_param<2 to [0] on 2>, [0, 1]>
+module @copy_arrays_with_already_donated_array_error {
+  func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array)
+      attributes {ifrt.function} {
+    %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1]
+        {io_aliases=[array<i32: 0, 0>]} : (!array) -> !array
+    // expected-error @+1 {{'func.return' op result #1 of op at}}
+    return %0, %arg0 : !array, !array
+  }
+
+  func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+    return %arg0 : tensor<2xi32>
+  }
+}
diff --git a/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir b/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir
index 28a1dda..4fef087 100644
--- a/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir
+++ b/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir
@@ -3,7 +3,7 @@
 #device = #ifrt<devices[0,1]>
 #sharding = #ifrt.sharding_param<2x1 to [0] on 2>
 // CHECK-LABEL: @identity_axis0_sharded
-module @identity_axis0_sharded attributes {ifrt.devices = #device} {
+module @identity_axis0_sharded attributes {ifrt.num_devices = 2} {
   // CHECK-NEXT: func.func @main
   // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi32>
   // CHECK-NEXT: return %[[ARG]]
@@ -23,7 +23,7 @@
 #sharding = #ifrt.sharding_param<1x2 to [0] on 2>
 // CHECK-LABEL: @identity_axis1_sharded
 module @identity_axis1_sharded
-    attributes {ifrt.devices = #device, ifrt.entry_function = "entry_func"} {
+    attributes {ifrt.num_devices = 2, ifrt.entry_function = "entry_func"} {
   // CHECK-NEXT: func.func @entry_func
   // CHECK-SAME: %[[ARG:.*]]: tensor<2x1xi32>
   // CHECK-NEXT: return %[[ARG]]
@@ -42,7 +42,7 @@
 #device = #ifrt<devices[0,1,2,3,4,5]>
 #sharding = #ifrt.sharding_param<3x2 to [1,0] on 2x3>
 // CHECK-LABEL: @identify_both_axes_sharded
-module @identify_both_axes_sharded attributes {ifrt.devices = #device} {
+module @identify_both_axes_sharded attributes {ifrt.num_devices = 6} {
   // CHECK-NEXT: func.func @main
   // CHECK-SAME: %[[ARG:.*]]: tensor<1x1xi32>
   // CHECK-NEXT: return %[[ARG]]
@@ -60,7 +60,7 @@
 
 #device = #ifrt<devices[0,1]>
 // CHECK-LABEL: @with_func_call
-module @with_func_call attributes {ifrt.devices = #device} {
+module @with_func_call attributes {ifrt.num_devices = 2} {
   // CHECK-NEXT: func.func @main
   // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi32>
   // CHECK-SAME: tensor<1x2xi32>
@@ -94,7 +94,7 @@
 
 #device = #ifrt<devices[0,1]>
 // CHECK-LABEL: @with_nested_func_call
-module @with_nested_func_call attributes {ifrt.devices = #device} {
+module @with_nested_func_call attributes {ifrt.num_devices = 2} {
   // CHECK-NEXT: func.func @main
   // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi32>
   // CHECK-SAME: tensor<1x2xi32>
@@ -139,11 +139,10 @@
 
 // -----
 
-#device = #ifrt<devices[0,1]>
 #sharding = #ifrt.sharding_param<1x2 to [0] on 2>
 // expected-error@+1 {{cannot find entry function `main`}}
 module @missing_main_function
-    attributes {ifrt.devices = #device} {
+    attributes {ifrt.num_devices = 2} {
 }
 
 // -----
@@ -152,7 +151,7 @@
 #sharding = #ifrt.sharding_param<1x2 to [0] on 2>
 // expected-error@+1 {{cannot find entry function `entry_func`}}
 module @missing_entry_function
-    attributes {ifrt.devices = #device, ifrt.entry_function = "entry_func"} {
+    attributes {ifrt.num_devices = 2, ifrt.entry_function = "entry_func"} {
   func.func @main(
       %arg0: tensor<2x2xi32> {ifrt.sharding = #sharding,
       ifrt.devices = #device})
@@ -166,7 +165,7 @@
 
 #device = #ifrt<devices[0,1]>
 #sharding = #ifrt.sharding_param<2x1 to [0] on 2>
-module @non_divisible_global_shape attributes {ifrt.devices = #device} {
+module @non_divisible_global_shape attributes {ifrt.num_devices = 2} {
   // expected-error@+1 {{Global shape is not divisible by the number of shards in dimension 0. Global size: 3, number of shards: 2}}
   func.func @main(
       %arg0: tensor<3x2xi32> {ifrt.sharding = #sharding,
diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir
index b318756..e512b26 100644
--- a/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir
+++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir
@@ -355,11 +355,32 @@
 
 // -----
 
-func.func @io_aliases_should_have_same_type(
+!array0 = !ifrt.array<tensor<2x1xi32>,
+                      #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]>
+!array1 = !ifrt.array<tensor<1x1xi32>,
+                      #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>
+func.func @io_aliases_of_different_type_but_same_per_shard_shape(%arg0: !array0)
+    attributes {ifrt.function} {
+  %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1]
+    {io_aliases=[array<i32: 0, 0>]} : (!array0) -> !array1
+  return
+}
+
+func.func @callee(%arg0: tensor<2x1xi32>) -> tensor<1x1xi32> {
+  %0 = mhlo.constant dense<-2147483648> : tensor<i32>
+  %1 = mhlo.reduce(%arg0 init: %0) applies mhlo.maximum across dimensions = [0, 1]
+    : (tensor<2x1xi32>, tensor<i32>) -> tensor<i32>
+  %2 = mhlo.reshape %1 : (tensor<i32>) -> tensor<1x1xi32>
+  return %2 : tensor<1x1xi32>
+}
+
+// -----
+
+func.func @io_aliases_should_alias_arrays_with_same_per_shard_shape(
     %arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>,
                        [0,1]>)
     attributes {ifrt.function} {
-  // expected-error@+1 {{'ifrt.Call' op can't alias input #0 to output #0 with different types: '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}}
+  // expected-error@+1 {{'ifrt.Call' op can't alias input #0 to output #0 with different per-shard shapes: '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}}
   %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1]
     {io_aliases=[array<i32: 0, 0>]}
     : (!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>,
diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir
index d522841..14485f4 100644
--- a/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir
+++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir
@@ -215,7 +215,7 @@
     %arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>,
                        [0,1]>)
     attributes {ifrt.function} {
-  // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 to output #0 with different types: '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}}
+  // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 to output #0 with different per-shard shapes: '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}}
   %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0)
     {io_aliases=[array<i32: 0, 0>]}
     : (!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>,
diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/constants.h b/third_party/xla/xla/python/ifrt/ir/transforms/constants.h
index 98bfd12..fff0484 100644
--- a/third_party/xla/xla/python/ifrt/ir/transforms/constants.h
+++ b/third_party/xla/xla/python/ifrt/ir/transforms/constants.h
@@ -21,6 +21,8 @@
 namespace xla::ifrt {
 
 inline constexpr llvm::StringLiteral kIfrtDevicesAttrName = "ifrt.devices";
+inline constexpr llvm::StringLiteral kIfrtNumDevicesAttrName =
+    "ifrt.num_devices";
 inline constexpr llvm::StringLiteral kIfrtShardingAttrName = "ifrt.sharding";
 inline constexpr llvm::StringLiteral kIfrtEntryFunctionAttrName =
     "ifrt.entry_function";
diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc
index bdfcf78..7e34921 100644
--- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc
+++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc
@@ -16,10 +16,10 @@
 #include <memory>
 
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
@@ -63,10 +63,15 @@
 };
 
 void IfrtVerifyDonationPass::runOnOperation() {
-  mlir::ModuleOp module_op = getOperation();
+  mlir::func::FuncOp func_op = getOperation();
+  // We only need to run this pass on IFRT functions.
+  if (!func_op->hasAttr(kIfrtFunctionAttrName) &&
+      !func_op->hasAttr(kIfrtReshardFunctionAttrName)) {
+    return;
+  }
   llvm::DenseMap<mlir::Value, mlir::Operation*> donated_value_to_op;
-  mlir::WalkResult result = module_op.walk([&](mlir::Operation* op)
-                                               -> mlir::WalkResult {
+  mlir::WalkResult result = func_op.walk([&](mlir::Operation* op)
+                                             -> mlir::WalkResult {
     auto result =
         llvm::TypeSwitch<mlir::Operation*, mlir::LogicalResult>(op)
             .Case<xla::ifrt::CallOp, xla::ifrt::CallLoadedExecutableOp>(
@@ -136,6 +141,20 @@
               }
               return mlir::success();
             })
+            .Case<mlir::func::ReturnOp>([&](mlir::func::ReturnOp return_op) {
+              for (const auto& [idx, result] :
+                   llvm::enumerate(return_op.getOperands())) {
+                auto donated_it = donated_value_to_op.find(result);
+                if (donated_it != donated_value_to_op.end()) {
+                  return_op.emitOpError()
+                      << "result #" << idx << " of op at " << return_op.getLoc()
+                      << " was already donated to the op at "
+                      << donated_it->second->getLoc();
+                  return mlir::failure();
+                }
+              }
+              return mlir::success();
+            })
             .Default(mlir::success());
 
     if (mlir::failed(result)) {
@@ -151,7 +170,7 @@
 
 }  // namespace
 
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
 CreateIfrtVerifyDonationPass() {
   return std::make_unique<IfrtVerifyDonationPass>();
 }
diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td
index d299c6b..c8c8e99 100644
--- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td
+++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td
@@ -139,7 +139,8 @@
   let constructor = "CreateIfrtMergeReshardsPass()";
 }
 
-def IfrtVerifyDonationPass : Pass<"ifrt-verify-donation", "mlir::ModuleOp"> {
+def IfrtVerifyDonationPass :
+    Pass<"ifrt-verify-donation", "mlir::func::FuncOp"> {
   let summary = "Verify that `!ifrt.array` are not donated more than once.";
   let description = [{
 Verify that no `!ifrt.array` is donated more than once, and that all
diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc
index 2669dfd..07a8c8d 100644
--- a/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc
+++ b/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc
@@ -35,7 +35,6 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
-#include "xla/python/ifrt/ir/ifrt_dialect.h"
 #include "xla/python/ifrt/ir/ifrt_interfaces.h"
 #include "xla/python/ifrt/ir/transforms/constants.h"
 #include "xla/python/ifrt/ir/transforms/passes.h"
@@ -272,15 +271,15 @@
 void SpmdExpansionPass::runOnOperation() {
   mlir::ModuleOp module_op = getOperation();
   // Skip single-device case.
-  auto devices = module_op->getAttrOfType<xla::ifrt::IfrtDevicesAttr>(
-      kIfrtDevicesAttrName);
-  if (devices == nullptr) {
+  auto num_devices =
+      module_op->getAttrOfType<mlir::IntegerAttr>(kIfrtNumDevicesAttrName);
+  if (num_devices == nullptr) {
     module_op->emitOpError()
         << "`" << module_op.getName()->str() << "` requires `"
-        << kIfrtDevicesAttrName << "` attribute.";
+        << kIfrtNumDevicesAttrName << "` attribute.";
     return signalPassFailure();
   }
-  if (devices.getIds().size() == 1) {
+  if (num_devices.getInt() == 1) {
     return;
   }
 
diff --git a/third_party/xla/xla/python/ifrt/sharding.cc b/third_party/xla/xla/python/ifrt/sharding.cc
index e302535..f5ca885 100644
--- a/third_party/xla/xla/python/ifrt/sharding.cc
+++ b/third_party/xla/xla/python/ifrt/sharding.cc
@@ -50,6 +50,14 @@
 
 namespace {
 
+// Returns a canonicalized memory kind for the given devices.
+// REQUIRES: !devices.empty()
+MemoryKind CanonicalizeMemoryKindWithDevices(const MemoryKind& memory_kind,
+                                             const DeviceList& devices) {
+  CHECK(!devices.empty());
+  return CanonicalizeMemoryKind(memory_kind, devices.front());
+}
+
 // Returns if `sharding_param` indicates a fully replicated sharding.
 bool ComputeIsFullyReplicated(const ShardingParam& sharding_param) {
   return llvm::all_of(sharding_param.dim_shards(),
@@ -155,6 +163,12 @@
 
 char DeserializeShardingOptions::ID = 0;
 
+Sharding::Sharding(DeviceList devices, MemoryKind memory_kind,
+                   bool is_fully_replicated)
+    : devices_(std::move(devices)),
+      memory_kind_(memory_kind),
+      is_fully_replicated_(is_fully_replicated) {}
+
 bool Sharding::operator==(const Sharding& other) const {
   if (this == &other) {
     return true;
@@ -184,6 +198,7 @@
 
 std::unique_ptr<SingleDeviceSharding> SingleDeviceSharding::Create(
     Device* device, MemoryKind memory_kind) {
+  memory_kind = CanonicalizeMemoryKind(memory_kind, device);
   return std::unique_ptr<SingleDeviceSharding>(
       new SingleDeviceSharding(device, memory_kind));
 }
@@ -247,6 +262,7 @@
 
 std::unique_ptr<OpaqueSharding> OpaqueSharding::Create(DeviceList devices,
                                                        MemoryKind memory_kind) {
+  memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices);
   return std::unique_ptr<OpaqueSharding>(
       new OpaqueSharding(std::move(devices), memory_kind));
 }
@@ -318,6 +334,7 @@
     DeviceList devices, MemoryKind memory_kind, Shape shape,
     std::vector<Shape> shard_shapes) {
   CHECK_EQ(devices.size(), shard_shapes.size());
+  memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices);
   return std::unique_ptr<ConcreteSharding>(
       new ConcreteSharding(std::move(devices), memory_kind, std::move(shape),
                            std::move(shard_shapes)));
@@ -327,6 +344,7 @@
     DeviceList devices, MemoryKind memory_kind, DynamicShape dynamic_shape,
     std::vector<DynamicShape> shard_dynamic_shapes) {
   CHECK_EQ(devices.size(), shard_dynamic_shapes.size());
+  memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices);
   return std::unique_ptr<ConcreteSharding>(new ConcreteSharding(
       std::move(devices), memory_kind, std::move(dynamic_shape),
       std::move(shard_dynamic_shapes)));
@@ -472,6 +490,7 @@
 std::unique_ptr<ConcreteEvenSharding> ConcreteEvenSharding::Create(
     DeviceList devices, MemoryKind memory_kind, Shape shape, Shape shard_shape,
     bool is_fully_replicated) {
+  memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices);
   return std::unique_ptr<ConcreteEvenSharding>(new ConcreteEvenSharding(
       std::move(devices), memory_kind, std::move(shape), std::move(shard_shape),
       is_fully_replicated));
@@ -586,6 +605,7 @@
         "%d",
         device_count, devices.size());
   }
+  memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices);
   return std::unique_ptr<ShardingParamSharding>(new ShardingParamSharding(
       std::move(sharding_param), std::move(devices), memory_kind));
 }
@@ -595,7 +615,8 @@
                                              DeviceList devices,
                                              MemoryKind memory_kind)
     : llvm::RTTIExtends<ShardingParamSharding, Sharding>(
-          devices, memory_kind, ComputeIsFullyReplicated(sharding_param)),
+          std::move(devices), memory_kind,
+          ComputeIsFullyReplicated(sharding_param)),
       sharding_param_(sharding_param) {}
 
 absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
diff --git a/third_party/xla/xla/python/ifrt/sharding.h b/third_party/xla/xla/python/ifrt/sharding.h
index c7fbd25..91b8b8a 100644
--- a/third_party/xla/xla/python/ifrt/sharding.h
+++ b/third_party/xla/xla/python/ifrt/sharding.h
@@ -125,10 +125,8 @@
   static char ID;  // NOLINT
 
  protected:
-  Sharding(DeviceList devices, MemoryKind memory_kind, bool is_fully_replicated)
-      : devices_(devices),
-        memory_kind_(memory_kind),
-        is_fully_replicated_(is_fully_replicated) {}
+  Sharding(DeviceList devices, MemoryKind memory_kind,
+           bool is_fully_replicated);
 
   DeviceList devices_;
   MemoryKind memory_kind_;
@@ -189,6 +187,7 @@
 class OpaqueSharding : public llvm::RTTIExtends<OpaqueSharding, Sharding> {
  public:
   // Creates an opaque sharding. `Disassemble()` will fail.
+  // REQUIRES: !devices.empty()
   static std::unique_ptr<OpaqueSharding> Create(DeviceList devices,
                                                 MemoryKind memory_kind);
 
@@ -230,6 +229,7 @@
  public:
   // Creates a concrete sharding that may contain non-identical shard shapes.
   // REQUIRES: `devices`.size() == `shard_shapes`.size()
+  // REQUIRES: !devices.empty()
   static std::unique_ptr<ConcreteSharding> Create(
       DeviceList devices, MemoryKind memory_kind, Shape shape,
       std::vector<Shape> shard_shapes);
@@ -237,6 +237,7 @@
   // Creates a concrete sharding that may contain non-identical shard dynamic
   // shapes.
   // REQUIRES: `devices`.size() == `shard_dynamic_shapes`.size()
+  // REQUIRES: !devices.empty()
   static std::unique_ptr<ConcreteSharding> Create(
       DeviceList devices, MemoryKind memory_kind, DynamicShape dynamic_shape,
       std::vector<DynamicShape> shard_dynamic_shapes);
@@ -321,6 +322,7 @@
   // Creates a concrete even sharding.
   // TODO(hyeontaek): Remove the default value of `is_fully_replicated` once all
   // callers are updated to provide it explicitly.
+  // REQUIRES: !devices.empty()
   static std::unique_ptr<ConcreteEvenSharding> Create(
       DeviceList devices, MemoryKind memory_kind, Shape shape,
       Shape shard_shape, bool is_fully_replicated = false);
@@ -371,6 +373,7 @@
 class ShardingParamSharding
     : public llvm::RTTIExtends<ShardingParamSharding, Sharding> {
  public:
+  // REQUIRES: !devices.empty()
   static absl::StatusOr<std::unique_ptr<ShardingParamSharding>> Create(
       ShardingParam sharding_param, DeviceList devices, MemoryKind memory_kind);
 
diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.cc b/third_party/xla/xla/python/ifrt_proxy/client/client.cc
index 15c2935..09b4ccb 100644
--- a/third_party/xla/xla/python/ifrt_proxy/client/client.cc
+++ b/third_party/xla/xla/python/ifrt_proxy/client/client.cc
@@ -82,7 +82,6 @@
   for (const auto& d : init_response.devices()) {
     absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>
         pjrt_device_attributes;
-    AttributeMap::Map attributes;
     if (rpc_helper->version().protocol_version() <= 3) {
       for (const auto& [key, attr] : d.deprecated_attributes()) {
         TF_ASSIGN_OR_RETURN(xla::PjRtDeviceAttribute value,
diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc
index 6f79e56..ec675f4 100644
--- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc
+++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc
@@ -97,11 +97,20 @@
   return result;
 }
 
+// Returns a canonicalized memory kind for the given devices.
+// REQUIRES: !devices.empty()
+MemoryKind CanonicalizeMemoryKindWithDevices(const MemoryKind& memory_kind,
+                                             const DeviceList& devices) {
+  CHECK(!devices.empty());
+  return CanonicalizeMemoryKind(memory_kind, devices.front());
+}
+
 }  // namespace
 
 std::unique_ptr<HloSharding> HloSharding::Create(
     DeviceList devices, MemoryKind memory_kind,
     xla::HloSharding xla_hlo_sharding) {
+  memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices);
   return std::unique_ptr<HloSharding>(new HloSharding(
       std::move(devices), memory_kind, std::move(xla_hlo_sharding)));
 }
diff --git a/third_party/xla/xla/python/profiler/internal/python_hooks.h b/third_party/xla/xla/python/profiler/internal/python_hooks.h
index a9b502e..29e6b83 100644
--- a/third_party/xla/xla/python/profiler/internal/python_hooks.h
+++ b/third_party/xla/xla/python/profiler/internal/python_hooks.h
@@ -77,7 +77,7 @@
     Py_XDECREF(m_module);
   }
 
-  PythonTraceEntry(PythonTraceEntry&& other) {
+  PythonTraceEntry(PythonTraceEntry&& other) noexcept {
     start_time_ns = other.start_time_ns;
     end_time_ns = other.end_time_ns;
     co_firstlineno = other.co_firstlineno;
diff --git a/third_party/xla/xla/python/py_compile_only_client.cc b/third_party/xla/xla/python/py_compile_only_client.cc
index 9d9db9a..6f5aff6 100644
--- a/third_party/xla/xla/python/py_compile_only_client.cc
+++ b/third_party/xla/xla/python/py_compile_only_client.cc
@@ -19,12 +19,14 @@
 #include <functional>
 #include <memory>
 #include <optional>
+#include <string>
 #include <string_view>
 #include <utility>
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "llvm/Support/Casting.h"
@@ -79,6 +81,40 @@
 
 namespace {
 
+class CompileOnlyMemory
+    : public llvm::RTTIExtends<CompileOnlyMemory, ifrt::Memory> {
+ public:
+  explicit CompileOnlyMemory(
+      int id, const PjRtMemorySpaceDescription* memory_description,
+      ifrt::Device* device)
+      : id_(id),
+        kind_(memory_description->kind()),
+        debug_string_(absl::StrFormat("CompileOnlyMemory(id=%d, kind=%s)", id,
+                                      memory_description->kind())),
+        device_(device) {}
+
+  ifrt::MemoryId Id() const override { return ifrt::MemoryId(id_); }
+
+  const ifrt::MemoryKind& Kind() const override { return kind_; }
+
+  absl::string_view ToString() const override { return debug_string_; }
+  absl::string_view DebugString() const override { return debug_string_; }
+
+  absl::Span<ifrt::Device* const> Devices() const override {
+    return absl::Span<ifrt::Device* const>{&device_, 1};
+  }
+
+  static char ID;  // NOLINT
+
+ private:
+  int id_;
+  ifrt::MemoryKind kind_;
+  std::string debug_string_;
+  ifrt::Device* device_;
+};
+
+[[maybe_unused]] char CompileOnlyMemory::ID = 0;
+
 class CompileOnlyDevice
     : public llvm::RTTIExtends<CompileOnlyDevice, ifrt::Device> {
  public:
@@ -108,16 +144,31 @@
     return description_->DebugString();
   }
 
-  absl::Span<ifrt::Memory* const> Memories() const override { return {}; }
+  absl::Span<ifrt::Memory* const> Memories() const override {
+    return unowned_memories_;
+  }
   absl::StatusOr<ifrt::Memory*> DefaultMemory() const override {
+    if (default_memory_) {
+      return default_memory_;
+    }
     return Unimplemented("DefaultMemory is not supported");
   }
 
   const ifrt::AttributeMap& Attributes() const override { return attributes_; }
 
+  void AttachMemory(std::unique_ptr<ifrt::Memory> memory) {
+    unowned_memories_.push_back(memory.get());
+    owned_memories_.push_back(std::move(memory));
+  }
+
+  void SetDefaultMemory(ifrt::Memory* memory) { default_memory_ = memory; }
+
  private:
   const PjRtDeviceDescription* description_;
   ifrt::AttributeMap attributes_;
+  ifrt::Memory* default_memory_ = nullptr;
+  std::vector<ifrt::Memory*> unowned_memories_;
+  std::vector<std::unique_ptr<ifrt::Memory>> owned_memories_;
 };
 
 class InvalidIfrtCompiler final
@@ -153,10 +204,24 @@
       : topology_(std::move(topology)),
         descriptions_(topology_->DeviceDescriptions()),
         attributes_(ifrt::AttributeMap::Map()) {
+    int offset = 0;
     for (auto& description : descriptions_) {
       owned_devices_.push_back(
           std::make_unique<CompileOnlyDevice>(description.get()));
-      devices_.push_back(owned_devices_.back().get());
+      auto* device = owned_devices_.back().get();
+      devices_.push_back(device);
+      if (description->process_index() == process_index()) {
+        auto default_memory = description->default_memory_space();
+        for (auto* memory_description : description->memory_spaces()) {
+          auto memory = std::make_unique<CompileOnlyMemory>(
+              offset, memory_description, device);
+          if (default_memory.ok() && memory_description == *default_memory) {
+            device->SetDefaultMemory(memory.get());
+          }
+          device->AttachMemory(std::move(memory));
+          ++offset;
+        }
+      }
     }
   }
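The AttachMemory/SetDefaultMemory plumbing above follows a common ownership split: unique_ptrs keep the memories alive while a parallel vector of raw pointers backs the span returned by Memories(). A generic sketch of that idiom (Registry and Attach are invented names, not XLA types):

#include <iostream>
#include <memory>
#include <utility>
#include <vector>

// Owns objects via unique_ptr while exposing non-owning raw pointers that
// callers can view without taking ownership.
class Registry {
 public:
  void Attach(std::unique_ptr<int> value) {
    unowned_.push_back(value.get());
    owned_.push_back(std::move(value));
  }
  const std::vector<int*>& items() const { return unowned_; }

 private:
  std::vector<int*> unowned_;
  std::vector<std::unique_ptr<int>> owned_;
};

int main() {
  Registry registry;
  registry.Attach(std::make_unique<int>(42));
  std::cout << *registry.items().front() << "\n";  // 42
  return 0;
}
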
 
diff --git a/third_party/xla/xla/python/py_values.h b/third_party/xla/xla/python/py_values.h
index 9733a42..51bfdb9 100644
--- a/third_party/xla/xla/python/py_values.h
+++ b/third_party/xla/xla/python/py_values.h
@@ -49,8 +49,8 @@
   // dangerous due to `owning_pybuffer`.
   DevicePutResult(const DevicePutResult&) = delete;
   DevicePutResult& operator=(const DevicePutResult&) = delete;
-  DevicePutResult(DevicePutResult&&) = default;
-  DevicePutResult& operator=(DevicePutResult&&) = default;
+  DevicePutResult(DevicePutResult&&) noexcept = default;
+  DevicePutResult& operator=(DevicePutResult&&) noexcept = default;
 
   // Points to the on-device array. Not owned.
   tsl::RCReference<ifrt::Array> ifrt_array;
diff --git a/third_party/xla/xla/python/python_ref_manager.h b/third_party/xla/xla/python/python_ref_manager.h
index 815e80e..4f1d821 100644
--- a/third_party/xla/xla/python/python_ref_manager.h
+++ b/third_party/xla/xla/python/python_ref_manager.h
@@ -57,7 +57,7 @@
     ManagedPyObjects(const ManagedPyObjects& other) = delete;
     ManagedPyObjects(ManagedPyObjects&& other) = default;
     ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete;
-    ManagedPyObjects& operator=(ManagedPyObjects&& other) = default;
+    ManagedPyObjects& operator=(ManagedPyObjects&& other) noexcept = default;
 
    private:
     PythonRefManager* manager_ = nullptr;
diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc
index 68a483c..65bfb3f 100644
--- a/third_party/xla/xla/python/pytree.cc
+++ b/third_party/xla/xla/python/pytree.cc
@@ -1249,7 +1249,7 @@
         nb::cast<std::string_view>(nb::repr(node_data->first))));
   }
   node.kind = registration->kind;
-  if (node.kind == PyTreeKind::kCustom) {
+  if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) {
     node.custom = registration;
     node.node_data = node_data->second;
   } else if (node.kind == PyTreeKind::kNamedTuple) {
diff --git a/third_party/xla/xla/python/pytree_test.py b/third_party/xla/xla/python/pytree_test.py
index 4125d7a..922a4d7 100644
--- a/third_party/xla/xla/python/pytree_test.py
+++ b/third_party/xla/xla/python/pytree_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import collections
+import dataclasses
 
 from absl.testing import absltest
 
@@ -44,6 +45,15 @@
 registry.register_node(ExampleType2, ExampleType2.to_iterable, from_iterable)
 
 
+@dataclasses.dataclass
+class Custom:
+  a: int
+  b: str
+
+
+registry.register_dataclass_node(Custom, ["a"], ["b"])
+
+
 class PyTreeTest(absltest.TestCase):
 
   def roundtrip(self, example):
@@ -92,6 +102,15 @@
     y = registry.flatten((0, 0))[1]
     self.assertEqual((x.compose(y)).num_leaves, 2)
 
+  def testDataclassMakeFromNodeData(self):
+    c = Custom(1, "a")
+    c_leafs, c_tree = registry.flatten(c)
+    c_tree2 = c_tree.make_from_node_data_and_children(
+        registry, c_tree.node_data(), c_tree.children()
+    )
+    self.assertEqual(c_tree2.unflatten(c_leafs), c)
+    self.assertEqual(str(c_tree2), str(c_tree))
+
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/third_party/xla/xla/python/tools/BUILD b/third_party/xla/xla/python/tools/BUILD
index 6d57e56..cc0c5e0 100644
--- a/third_party/xla/xla/python/tools/BUILD
+++ b/third_party/xla/xla/python/tools/BUILD
@@ -86,7 +86,7 @@
         ":types",
         "@absl_py//absl/testing:absltest",
         "@absl_py//absl/testing:parameterized",
-        #internal proto upb dep
+        # copybara:uncomment "//third_party/py/google/protobuf:use_fast_cpp_protos",
         "//third_party/py/numpy",
         "//xla:xla_data_proto_py",
     ],
diff --git a/third_party/xla/xla/python/traceback.cc b/third_party/xla/xla/python/traceback.cc
index b3cdfce..19e4f94 100644
--- a/third_party/xla/xla/python/traceback.cc
+++ b/third_party/xla/xla/python/traceback.cc
@@ -99,7 +99,8 @@
   }
 }
 
-Traceback::Traceback(Traceback&& other) : frames_(std::move(other.frames_)) {
+Traceback::Traceback(Traceback&& other) noexcept
+    : frames_(std::move(other.frames_)) {
   // absl::InlinedVector does not always clear itself if moved. Since we rely on
   // its empty() method to destroy Traceback differently, we explicitly clear
   // here.
diff --git a/third_party/xla/xla/python/traceback.h b/third_party/xla/xla/python/traceback.h
index c93860b..da80362 100644
--- a/third_party/xla/xla/python/traceback.h
+++ b/third_party/xla/xla/python/traceback.h
@@ -48,7 +48,7 @@
   ~Traceback();
 
   Traceback(const Traceback&) = delete;
-  Traceback(Traceback&& other);
+  Traceback(Traceback&& other) noexcept;
   Traceback& operator=(const Traceback&) = delete;
   Traceback& operator=(Traceback&&) = delete;
 
diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc
index b9f9b6e..4feb8cb 100644
--- a/third_party/xla/xla/python/xla.cc
+++ b/third_party/xla/xla/python/xla.cc
@@ -62,19 +62,12 @@
 #include "xla/python/py_program.h"
 #include "xla/service/cpu/collectives_interface.h"
 #include "xla/tsl/python/lib/core/numpy.h"  //NOLINT
-#ifdef XLA_PYTHON_ENABLE_GPU
-#include "xla/python/gpu_support.h"
-#endif  // XLA_PYTHON_ENABLE_GPU
 
 #ifdef __linux__
 #include "gloo/transport/tcp/attr.h"
 #include "gloo/transport/tcp/device.h"
 #include "xla/pjrt/cpu/gloo_collectives.h"
 #include "xla/pjrt/cpu/gloo_kv_store.h"
-#elif __APPLE__
-#include "gloo/transport/uv/device.h"
-#include "xla/pjrt/cpu/gloo_collectives.h"
-#include "xla/pjrt/cpu/gloo_kv_store.h"
 #endif  // __linux__
 
 #if !defined(_WIN32) && !defined(PLATFORM_GOOGLE)
@@ -261,7 +254,7 @@
          std::optional<std::string> hostname,
          std::optional<std::string> interface)
           -> std::shared_ptr<xla::cpu::CollectivesInterface> {
-#if defined(__linux__)
+#ifdef __linux__
         std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr;
         if (distributed_client != nullptr) {
           kv_store = GetDistributedKeyValueStore(distributed_client,
@@ -278,26 +271,9 @@
         auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs);
         return std::make_shared<cpu::GlooCollectives>(std::move(gloo_kv_store),
                                                       std::move(tcp_device));
-#elif defined(__APPLE__)
-        std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr;
-        if (distributed_client != nullptr) {
-          kv_store = GetDistributedKeyValueStore(distributed_client,
-                                                 /*key_prefix=*/"cpu:");
-        }
-        auto gloo_kv_store = std::make_unique<cpu::GlooKeyValueStore>(kv_store);
-        auto uv_attrs = gloo::transport::uv::attr();
-        if (hostname) {
-          uv_attrs.hostname = *hostname;
-        }
-        if (interface) {
-          uv_attrs.iface = *interface;
-        }
-        auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs);
-        return std::make_shared<cpu::GlooCollectives>(std::move(gloo_kv_store),
-                                                      std::move(uv_device));
 #else   // __linux__
         throw xla::XlaRuntimeError(
-            "make_gloo_tcp_collectives only implemented for linux and macos");
+            "make_gloo_tcp_collectives only implemented for linux");
 #endif  // __linux__
       },
       nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt,
@@ -387,10 +363,6 @@
     return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name));
   });
 
-#ifdef XLA_PYTHON_ENABLE_GPU
-  RegisterGpuClientAndDefineGpuAllocatorConfig(m_nb);
-#endif  // XLA_PYTHON_ENABLE_GPU
-
   m_nb.def(
       "get_c_api_client",
       [](std::string platform_name,
diff --git a/third_party/xla/xla/python/xla_client.pyi b/third_party/xla/xla/python/xla_client.pyi
index bf5a6d9..8731080 100644
--- a/third_party/xla/xla/python/xla_client.pyi
+++ b/third_party/xla/xla/python/xla_client.pyi
@@ -222,12 +222,16 @@
   collapsed_slice_dims: list[int]
   start_index_map: list[int]
   index_vector_dim: int
+  operand_batching_dims: list[int]
+  start_indices_batching_dims: list[int]
 
 class ScatterDimensionNumbers:
   update_window_dims: list[int]
   inserted_window_dims: list[int]
   scatter_dims_to_operand_dims: list[int]
   index_vector_dim: int
+  input_batching_dims: list[int]
+  scatter_indices_batching_dims: list[int]
 
 class ReplicaGroup:
   replica_ids: list[int]
diff --git a/third_party/xla/xla/python/xla_client_test.py b/third_party/xla/xla/python/xla_client_test.py
index 65d6d7f..37484cc 100644
--- a/third_party/xla/xla/python/xla_client_test.py
+++ b/third_party/xla/xla/python/xla_client_test.py
@@ -65,8 +65,12 @@
 
 
 # pylint: disable=invalid-name
-def jax_array_convert_to_array(self):
-  return self._single_device_array_to_np_array()
+def jax_array_convert_to_array(self, dtype=None, copy=None):
+  del copy
+  out = self._single_device_array_to_np_array()
+  if dtype is not None:
+    out = out.astype(dtype)
+  return out
 
 
 def jax_array_device(self):
@@ -586,7 +590,10 @@
     def testScalarTimesVector(self, dtype):
       c = self._NewComputation()
       arg0 = np.array(3, dtype=dtype)
-      arg1 = np.array([10, 15, -2, 7], dtype=dtype)
+      if np.issubdtype(dtype, np.unsignedinteger):
+        arg1 = np.array([10, 15, 2, 7], dtype=dtype)
+      else:
+        arg1 = np.array([10, 15, -2, 7], dtype=dtype)
       p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
       p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
       ops.Mul(p0, p1)
diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc
index 7e2504c..f58a59a 100644
--- a/third_party/xla/xla/python/xla_compiler.cc
+++ b/third_party/xla/xla/python/xla_compiler.cc
@@ -80,10 +80,10 @@
 #include "xla/service/tuple_simplifier.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 #include "xla/util.h"
 #include "xla/xla.pb.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/logging.h"
diff --git a/third_party/xla/xla/reference_util.cc b/third_party/xla/xla/reference_util.cc
index 33fb500..d7461ca 100644
--- a/third_party/xla/xla/reference_util.cc
+++ b/third_party/xla/xla/reference_util.cc
@@ -25,11 +25,19 @@
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/functional/function_ref.h"
+#include "absl/types/span.h"
+#include "xla/array2d.h"
+#include "xla/array3d.h"
+#include "xla/array4d.h"
+#include "xla/client/padding.h"
 #include "xla/client/xla_builder.h"
 #include "xla/hlo/evaluator/hlo_evaluator.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/literal.h"
 #include "xla/literal_util.h"
+#include "xla/service/hlo_module_config.h"
 #include "xla/service/shape_inference.h"
+#include "xla/shape.h"
 #include "xla/window_util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/lib/math/math_util.h"
diff --git a/third_party/xla/xla/reference_util.h b/third_party/xla/xla/reference_util.h
index 9a124d6..a086fdb 100644
--- a/third_party/xla/xla/reference_util.h
+++ b/third_party/xla/xla/reference_util.h
@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "absl/functional/function_ref.h"
+#include "absl/log/check.h"
 #include "absl/types/span.h"
 #include "xla/array2d.h"
 #include "xla/array3d.h"
diff --git a/third_party/xla/xla/reference_util_test.cc b/third_party/xla/xla/reference_util_test.cc
index 320d1ca..c27e541 100644
--- a/third_party/xla/xla/reference_util_test.cc
+++ b/third_party/xla/xla/reference_util_test.cc
@@ -22,7 +22,9 @@
 #include "xla/array3d.h"
 #include "xla/array4d.h"
 #include "xla/client/padding.h"
+#include "xla/error_spec.h"
 #include "xla/literal.h"
+#include "xla/literal_util.h"
 #include "xla/test.h"
 #include "xla/tests/literal_test_util.h"
 #include "xla/xla_data.pb.h"
diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD
index 2582534..208c9d6 100644
--- a/third_party/xla/xla/service/BUILD
+++ b/third_party/xla/xla/service/BUILD
@@ -257,53 +257,6 @@
 )
 
 cc_library(
-    name = "all_reduce_splitter",
-    srcs = ["all_reduce_splitter.cc"],
-    hdrs = ["all_reduce_splitter.h"],
-    deps = [
-        ":collective_opt_utils",
-        ":hlo_module_config",
-        ":hlo_pass",
-        "//xla:shape_util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "@com_google_absl//absl/cleanup",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "all_reduce_splitter_test",
-    srcs = ["all_reduce_splitter_test.cc"],
-    deps = [
-        ":all_reduce_splitter",
-        ":hlo_module_config",
-        ":hlo_pass_pipeline",
-        "//xla/hlo/ir:hlo",
-        "//xla/service/gpu:gpu_reduce_scatter_creator",
-        "//xla/tests:filecheck",
-        "//xla/tests:hlo_test_base",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:status_matchers",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
     name = "float_support",
     srcs = ["float_support.cc"],
     hdrs = ["float_support.h"],
@@ -687,6 +640,7 @@
         "//xla:util",
         "//xla:xla_proto_cc",
         "//xla/hlo/ir:hlo",
+        "//xla/tsl/lib/strings:proto_serialization",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_set",
@@ -703,7 +657,6 @@
         "@llvm-project//mlir:Transforms",
         "@local_tsl//tsl/lib/io:zlib_compression_options",
         "@local_tsl//tsl/lib/io:zlib_outputbuffer",
-        "@local_tsl//tsl/lib/strings:proto_serialization",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:path",
@@ -1089,7 +1042,8 @@
     deps = [
         ":pattern_matcher",
         ":pattern_matcher_gmock",
-        "//xla:literal",
+        "//xla:comparison_util",
+        "//xla:literal_util",
         "//xla:protobuf_util",
         "//xla:shape_util",
         "//xla:test",
@@ -1103,6 +1057,9 @@
         "//xla/tests:xla_internal_test_main",
         "//xla/tsl/lib/core:status_test_util",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -1217,17 +1174,20 @@
     srcs = ["call_inliner_test.cc"],
     deps = [
         ":call_inliner",
-        ":hlo_pass",
+        ":hlo_parser",
         "//xla:literal",
+        "//xla:literal_util",
         "//xla:shape_util",
         "//xla:test",
-        "//xla:types",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/utils:hlo_matchers",
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",
         "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -1737,6 +1697,7 @@
         "//xla/stream_executor",
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:device_memory_allocator",
+        "//xla/tsl/lib/strings:proto_serialization",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
@@ -1744,7 +1705,6 @@
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
         "@com_google_absl//absl/types:variant",
-        "@local_tsl//tsl/lib/strings:proto_serialization",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:logging",
@@ -4549,7 +4509,9 @@
         ":hlo_parser",
         ":pattern_matcher",
         ":pattern_matcher_gmock",
-        "//xla:literal",
+        "//xla:comparison_util",
+        "//xla:literal_util",
+        "//xla:shape_tree",
         "//xla:shape_util",
         "//xla:test",
         "//xla:test_helpers",
@@ -4559,7 +4521,9 @@
         "//xla/tests:xla_internal_test_main",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:statusor",
     ],
 )
@@ -4581,10 +4545,10 @@
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",
         "//xla/tsl/lib/core:status_test_util",
+        "//xla/tsl/lib/strings:proto_serialization",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/lib/strings:proto_serialization",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:statusor",
     ],
@@ -6943,6 +6907,7 @@
         "@local_tsl//tsl/lib/gtl:map_util",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:status",
     ],
 )
 
@@ -6951,9 +6916,12 @@
     size = "small",
     srcs = ["hlo_parser_test.cc"],
     deps = [
+        ":hlo_lexer",
+        ":hlo_module_config",
         ":hlo_parser",
         ":pattern_matcher",
         ":pattern_matcher_gmock",
+        "//xla:array",
         "//xla:shape_util",
         "//xla:window_util",
         "//xla:xla_data_proto_cc",
@@ -6961,8 +6929,13 @@
         "//xla/tests:verified_hlo_module",
         "//xla/tests:xla_internal_test_main",
         "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:test",
@@ -7926,8 +7899,7 @@
     deps = [
         ":hlo_creation_utils",
         ":hlo_pass",
-        "//xla/service/cpu:onednn_convolution_rewriter",
-        "//xla/service/cpu:onednn_matmul_rewriter",
+        "//xla/service/cpu:onednn_contraction_rewriter",
     ],
 )
 
@@ -8385,4 +8357,42 @@
     ],
 )
 
+cc_library(
+    name = "add_original_value",
+    srcs = ["add_original_value.cc"],
+    hdrs = ["add_original_value.h"],
+    deps = [
+        ":hlo_pass",
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+xla_cc_test(
+    name = "add_original_value_test",
+    srcs = ["add_original_value_test.cc"],
+    deps = [
+        ":add_original_value",
+        ":pattern_matcher",
+        ":pattern_matcher_gmock",
+        "//xla:shape_util",
+        "//xla:window_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:verified_hlo_module",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
 exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"])
diff --git a/third_party/xla/xla/service/add_original_value.cc b/third_party/xla/xla/service/add_original_value.cc
new file mode 100644
index 0000000..37cab3c
--- /dev/null
+++ b/third_party/xla/xla/service/add_original_value.cc
@@ -0,0 +1,68 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/add_original_value.h"
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_original_value.h"
+#include "xla/shape_util.h"
+
+namespace xla {
+
+absl::StatusOr<bool> AddOriginalValue::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+
+  for (const auto computation : module->computations()) {
+    for (const auto instruction : computation->instructions()) {
+      auto original_value =
+          std::make_shared<OriginalValue>(instruction->shape());
+
+      if (instruction->opcode() == HloOpcode::kGetTupleElement) {
+        const auto* tuple = instruction->operand(0);
+        original_value->CopySubtreeFrom(*tuple->original_value(),
+                                        {instruction->tuple_index()}, {});
+      } else if (instruction->opcode() == HloOpcode::kTuple) {
+        for (int64_t operand_number = 0;
+             operand_number < instruction->operand_count(); ++operand_number) {
+          original_value->CopySubtreeFrom(
+              *instruction->operand(operand_number)->original_value(), {},
+              {operand_number});
+        }
+      } else {
+        for (auto& leaf : original_value->leaves()) {
+          leaf.second = {std::string(instruction->name()), leaf.first};
+        }
+      }
+      instruction->set_original_value(original_value);
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+
+}  // namespace xla
diff --git a/third_party/xla/xla/service/add_original_value.h b/third_party/xla/xla/service/add_original_value.h
new file mode 100644
index 0000000..b4fb093
--- /dev/null
+++ b/third_party/xla/xla/service/add_original_value.h
@@ -0,0 +1,38 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_ADD_ORIGINAL_VALUE_H_
+#define XLA_SERVICE_ADD_ORIGINAL_VALUE_H_
+
+#include "absl/status/statusor.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// This pass adds to each op in the HLO graph the original_value attribute,
+// which is used for HLO value tracking. See go/hlo-value-tracking for more
+// details.
+class AddOriginalValue : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "add-original-value"; }
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_ADD_ORIGINAL_VALUE_H_
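
For orientation (not part of the patch): a minimal sketch of how the new pass might be wired up, assuming the usual HloPassPipeline API; the wrapper function name and pipeline label below are illustrative only.

    // Sketch: annotate every instruction in `module` with an original_value by
    // running AddOriginalValue inside a single-pass pipeline.
    #include "absl/status/status.h"
    #include "xla/service/add_original_value.h"
    #include "xla/service/hlo_pass_pipeline.h"

    absl::Status AnnotateOriginalValues(xla::HloModule* module) {
      xla::HloPassPipeline pipeline("add-original-value");  // label is arbitrary
      pipeline.AddPass<xla::AddOriginalValue>();
      return pipeline.Run(module).status();  // StatusOr<bool> -> Status
    }
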
diff --git a/third_party/xla/xla/service/add_original_value_test.cc b/third_party/xla/xla/service/add_original_value_test.cc
new file mode 100644
index 0000000..f69ba94
--- /dev/null
+++ b/third_party/xla/xla/service/add_original_value_test.cc
@@ -0,0 +1,118 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/add_original_value.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace {
+
+using AddOriginalValueTest = HloTestBase;
+
+using ::absl::string_view;
+
+TEST_F(AddOriginalValueTest, Basic) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule test, entry_computation_layout={(s32[]{:T(256)})->u32[2]{0:T(256)}}
+
+ENTRY test {
+  Arg_0.1 = s32[] parameter(0)
+  constant.2 = s32[] constant(32)
+  shift-right-logical.3 = s32[] shift-right-logical(Arg_0.1, constant.2)
+  convert.4 = u32[] convert(shift-right-logical.3)
+  reshape.5 = u32[1]{0} reshape(convert.4)
+  convert.6 = u32[] convert(Arg_0.1)
+  reshape.7 = u32[1]{0} reshape(convert.6)
+  ROOT concatenate.8 = u32[2]{0} concatenate(reshape.5, reshape.7), dimensions={0}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  AddOriginalValue pass;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+  EXPECT_TRUE(changed);
+}
+
+TEST_F(AddOriginalValueTest, Tuple) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule test, entry_computation_layout={(f32[], f32[3]{0}, f32[2,3]{1,0})->((f32[], f32[3]{0}), f32[2,3]{1,0})}
+
+ENTRY test (v1: f32[], v2: f32[3], v3: f32[2,3]) -> ((f32[], f32[3]{0}), f32[2,3]{1,0}) {
+  v1 = f32[] parameter(0)
+  v2 = f32[3]{0} parameter(1)
+  v3 = f32[2,3]{1,0} parameter(2)
+  t1 = (f32[], f32[3]{0}) tuple(f32[] v1, f32[3]{0} v2)
+  ROOT t2 = ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple((f32[], f32[3]{0}) t1, f32[2,3]{1,0} v3)
+}
+
+)";
+
+  RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"(
+CHECK:  %[[V1:.*]] = f32[] parameter(0), original_value={{[{]}}{"[[V1]]"}
+CHECK:  %[[V2:.*]] = f32[3]{0} parameter(1), original_value={{[{]}}{"[[V2]]"}
+CHECK:  %[[TUPLE:.*]] = (f32[], f32[3]{0}) tuple(%[[V1]], %[[V2]]), original_value={({"[[V1]]"}, {"[[V2]]"})}
+CHECK:  %[[V3:.*]] = f32[2,3]{1,0} parameter(2), original_value={{[{]}}{"[[V3]]"}
+CHECK:  ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple(%[[TUPLE]], %[[V3]]), original_value={(({"v1"}, {"v2"}), {"v3"})}
+  )");
+}
+
+TEST_F(AddOriginalValueTest, GetTupleElement) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule test, entry_computation_layout={()->s32[2,3]{1,0}}
+
+ENTRY test {
+  constant = f32[3]{0} constant({1, 2, 3})
+  constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
+  tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} constant, s32[2,3]{1,0} constant.1)
+  ROOT get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) tuple), index=1
+}
+
+)";
+
+  RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"(
+CHECK:  %[[CONSTANT1:.*]] = f32[3]{0} constant({1, 2, 3}), original_value={{[{]}}{"[[CONSTANT1]]"}
+CHECK:  %[[CONSTANT2:.*]] = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }), original_value={{[{]}}{"[[CONSTANT2]]"}
+CHECK:  %[[TUPLE:.*]] = (f32[3]{0}, s32[2,3]{1,0}) tuple(%[[CONSTANT1]], %[[CONSTANT2]]), original_value={({"[[CONSTANT1]]"}, {"[[CONSTANT2]]"})}
+CHECK:  s32[2,3]{1,0} get-tuple-element(%[[TUPLE]]), index=1, original_value={{[{]}}{"[[CONSTANT2]]"}
+  )");
+}
+
+TEST_F(AddOriginalValueTest, GetTupleElementNonSymbolic) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule test, entry_computation_layout={((f32[], s32[]))->s32[]}
+
+ENTRY test {
+  p = (f32[], s32[]) parameter(0)
+  ROOT get-tuple-element = s32[] get-tuple-element(p), index=1
+}
+
+)";
+
+  RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"(
+CHECK:  %[[PARAM:.*]] = (f32[], s32[]) parameter(0), original_value={({"p" {0}{{[}]}}, {"p" {1}})}
+CHECK:  s32[] get-tuple-element(%[[PARAM]]), index=1, original_value={{[{]}}{"[[PARAM]]" {1}
+  )");
+}
+
+}  // namespace
+}  // namespace xla
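
A note on the original_value syntax exercised by the CHECK lines above: per the Run() implementation, each leaf of an instruction's shape is annotated with the producing instruction's name plus the leaf's shape index, tuple instructions concatenate their operands' annotations, and get-tuple-element copies the selected subtree. Paraphrasing the GetTupleElementNonSymbolic test (illustration only, not additional test input):

    p   = (f32[], s32[]) parameter(0)           original_value={({"p" {0}}, {"p" {1}})}
    gte = s32[] get-tuple-element(p), index=1   original_value={{"p" {1}}}
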
diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc
index 3ee1ffa..a3a7f76 100644
--- a/third_party/xla/xla/service/algebraic_simplifier.cc
+++ b/third_party/xla/xla/service/algebraic_simplifier.cc
@@ -179,6 +179,11 @@
           using NativeT = NativeTypeOf<primitive_type_constant>;
           return static_cast<double>(
               inst->literal().GetFirstElement<NativeT>());
+        } else if constexpr (primitive_util::IsIntegralType(
+                                 primitive_type_constant)) {
+          using NativeT = NativeTypeOf<primitive_type_constant>;
+          return static_cast<int64_t>(
+              inst->literal().GetFirstElement<NativeT>());
         }
         return std::nullopt;
       },
@@ -608,6 +613,11 @@
           using NativeT = NativeTypeOf<primitive_type_constant>;
           return HloInstruction::CreateConstant(
               LiteralUtil::CreateR0<NativeT>(static_cast<NativeT>(multiplier)));
+        } else if constexpr (primitive_util::IsIntegralType(
+                                 primitive_type_constant)) {
+          using NativeT = NativeTypeOf<primitive_type_constant>;
+          return HloInstruction::CreateConstant(
+              LiteralUtil::CreateR0<NativeT>(static_cast<NativeT>(multiplier)));
         }
         LOG(FATAL) << "Unsupported data type: "
                    << target->shape().element_type();
@@ -3331,12 +3341,11 @@
     } else if (opcode == HloOpcode::kBroadcast) {
       // Broadcasts of dot contracting dimensions can be reordered to reduces
       // of the corresponding contracting dimensions in the other dot operand
-      DimensionVector reduce_dims, broadcast_dim_sizes;
+      DimensionVector reduce_dims;
       const int64_t pre_broadcast_rank =
           reorder_from->mutable_operand(0)->shape().rank();
       int64_t post_broadcast_rank = reorder_from->shape().rank();
       Shape new_broadcast_shape = reorder_from->shape();
-      DimensionVector contracting_reordered;
 
       // Construct map from broadcasted shape to its original shape. Broadcast
       // dimensions are mapped to -1 since they were not present
@@ -3554,48 +3563,28 @@
           other_index = outer_dnums.lhs_batch_dimensions(i);
         }
 
-        // Once we have the inner_index, we determine whether this index
-        // corresponds to a dimension coming from the lhs or rhs of inner
-        bool from_inner_lhs = map_inner_rhs[inner_index] == -1;
+        auto add_batch_dims = [](DotDimensionNumbers& dnums, int64_t lhs_ix,
+                                 int64_t rhs_ix) {
+          dnums.add_lhs_batch_dimensions(lhs_ix);
+          dnums.add_rhs_batch_dimensions(rhs_ix);
+        };
 
-        // The map we use depends on which operand of inner this dim comes from
-        std::vector<int64_t> map;
-        if (from_inner_lhs) {
-          map = map_inner_lhs;
-        } else {
-          map = map_inner_rhs;
-        }
-
-        // Whether the mapped value goes into the lhs or rhs of the new dnums
-        // depends on whether inner was the lhs or rhs operand of outer
-        int64_t lhs_index, rhs_index;
-        if (outer_lhs_dot) {
-          lhs_index = map[inner_index];
-          rhs_index = other_index;
-        } else {
-          lhs_index = other_index;
-          rhs_index = map[inner_index];
-        }
-
-        // Finally, we have to determine which dnums to add to
-        DotDimensionNumbers* dnums;
-        if (outer_lhs_dot) {
-          if (from_inner_lhs) {
-            dnums = &ac_dnums;
-          } else {
-            dnums = &bc_dnums;
-          }
-        } else {
-          if (from_inner_lhs) {
-            dnums = &ab_dnums;
-          } else {
-            dnums = &ac_dnums;
+        for (auto& map : {map_inner_lhs, map_inner_rhs}) {
+          int64_t mapped_index = map[inner_index];
+          if (mapped_index != -1) {
+            // Whether the mapped value goes into the lhs or rhs of the new
+            // dnums depends on whether inner is the lhs or rhs operand of
+            // outer; which dnums to add to depends on this and on which map
+            // we are iterating through.
+            if (outer_lhs_dot) {
+              add_batch_dims(map == map_inner_lhs ? ac_dnums : bc_dnums,
+                             mapped_index, other_index);
+            } else {
+              add_batch_dims(map == map_inner_lhs ? ab_dnums : ac_dnums,
+                             other_index, mapped_index);
+            }
           }
         }
-
-        // Add the batch dimensions
-        dnums->add_lhs_batch_dimensions(lhs_index);
-        dnums->add_rhs_batch_dimensions(rhs_index);
       }
 
       // We now do the same thing for the contracting dimensions of outer
@@ -3614,7 +3603,14 @@
 
         // Once we have the inner_index, we determine whether this index
         // corresponds to a dimension coming from the lhs or rhs of inner
-        bool from_inner_lhs = map_inner_rhs[inner_index] == -1;
+        bool from_inner_lhs = map_inner_lhs[inner_index] != -1;
+        bool from_inner_rhs = map_inner_rhs[inner_index] != -1;
+
+        // If a dimension of inner is the result of batching and it is
+        // contracted in outer, we stop trying to reorder.
+        if (from_inner_lhs && from_inner_rhs) {
+          return absl::OkStatus();
+        }
 
         // The map we use depends on which operand of inner this dim comes from
         std::vector<int64_t> map;
@@ -3714,8 +3710,11 @@
             rhs_index = other_index;
           }
 
-          new_outer_dnums.add_lhs_batch_dimensions(lhs_index);
-          new_outer_dnums.add_rhs_batch_dimensions(rhs_index);
+          if (!absl::c_linear_search(new_outer_dnums.lhs_batch_dimensions(),
+                                     lhs_index)) {
+            new_outer_dnums.add_lhs_batch_dimensions(lhs_index);
+            new_outer_dnums.add_rhs_batch_dimensions(rhs_index);
+          }
         }
         for (int64_t i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) {
           int64_t new_inner_index, other_index;
@@ -4280,6 +4279,19 @@
 }
 }  // namespace
 
+bool AlgebraicSimplifierVisitor::IsNondecreasingSublinear(
+    const HloInstruction* hlo) {
+  switch (hlo->opcode()) {
+    case HloOpcode::kCbrt:
+    case HloOpcode::kErf:
+    case HloOpcode::kLogistic:
+    case HloOpcode::kTanh:
+      return true;
+    default:
+      return false;
+  }
+}
+
 absl::Status AlgebraicSimplifierVisitor::HandleMaximum(
     HloInstruction* maximum) {
   HloInstruction *lhs, *rhs;
@@ -4350,6 +4362,33 @@
     }
   }
 
+  // If the operands of the max are the same non-decreasing function, then we
+  // can sink it, e.g. max(tanh(x), tanh(y)) to tanh(max(x, y)).
+  // We only do this if the function asymptotically satisfies |f(x)| <= |x| to
+  // guarantee that no overflow occurs. Proof of correctness:
+  /* https://cvc5.github.io/app/
+  (set-logic ALL)
+  (declare-fun f (Float32) Float32)
+  (assert (forall ((x Float32) (y Float32))
+                  (=> (fp.lt x y) (fp.leq (f x) (f y))))) ; NonDecreasing
+  (assert (forall ((x Float32))
+                  (fp.leq (fp.abs (f x)) (fp.abs x)))) ; Sublinear
+  (assert (not (forall ((x Float32) (y Float32))
+                       (fp.eq (fp.max (f x) (f y))
+                              (f (fp.max x y)))))) ; Expect unsat
+  (check-sat)
+  */
+  if (lhs->opcode() == rhs->opcode() && IsNondecreasingSublinear(lhs)) {
+    TF_ASSIGN_OR_RETURN(
+        auto new_maximum,
+        MakeBinaryHlo(HloOpcode::kMaximum, lhs->mutable_operand(0),
+                      rhs->mutable_operand(0)));
+    VLOG(10) << "Sinking nondecreasing op through max";
+    return ReplaceWithNewInstruction(
+        maximum, HloInstruction::CreateUnary(maximum->shape(), lhs->opcode(),
+                                             new_maximum));
+  }
+
   return absl::OkStatus();
 }
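
As a reading aid for the rewrite above (my restatement of the reasoning, not part of the patch): for any non-decreasing f, both x <= max(x, y) and y <= max(x, y), so f(x) and f(y) are each bounded by f(max(x, y)); and since max(x, y) is one of x or y, f(max(x, y)) is one of f(x) or f(y). Hence

    \max(f(x), f(y)) \;=\; f(\max(x, y)).

The |f(x)| <= |x| side condition mirrors the comment above: it rules out overflow so the identity also holds in finite floating-point arithmetic, which is what the cvc5 query checks.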
 
@@ -7839,27 +7878,6 @@
     }
   }
 
-  // Replace Reduce(Broadcast(x), dims, Sum()) with Broadcast(x * prod(dims)).
-  if (HloInstruction * broadcast_arg;
-      Match(arg, m::Broadcast(m::ConstantScalar(&broadcast_arg))) &&
-      Match(function->root_instruction(),
-            m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) {
-    if (auto broadcast_value = GetConstantValue(broadcast_arg);
-        broadcast_value.has_value() &&
-        // Skip float64, where product is too accurate compared to repeated-sum.
-        broadcast_arg->shape().element_type() != PrimitiveType::F64) {
-      auto result_value = broadcast_value.value() *
-                          ShapeUtil::ElementsIn(arg->shape()) /
-                          ShapeUtil::ElementsIn(reduce_result_shape);
-      return ReplaceWithNewInstruction(
-          reduce, HloInstruction::CreateBroadcast(
-                      reduce_result_shape,
-                      reduce->AddInstruction(
-                          MakeScalarInstruction(reduce, result_value)),
-                      {}));
-    }
-  }
-
   // For Computation equal to Min, Max, And or Or, replace Reduce(Broadcast(x),
   // a, Computation()) with Computation(x, a) when x is a scalar and the
   // broadcast is reduced to a scalar.
@@ -7883,26 +7901,52 @@
     }
   }
 
-  // Replace Reduce(Broadcast(Scalar)) with Broadcast(Multiply(Scalar)) when the
-  // reduction operation is addition
+  // Replace Reduce(Broadcast(x), +, init_value) with Broadcast(Add(Multiply(x),
+  // init_value)) if all reduction dimensions were introduced by Broadcast.
   if (arg->opcode() == HloOpcode::kBroadcast &&
-      ShapeUtil::IsScalar(arg->operand(0)->shape())) {
-    if (Match(reduce->to_apply()->root_instruction(),
-              m::AddAnyOrder(m::Parameter(0), m::Parameter(1))) &&
-        IsScalarConstantZero(init_value)) {
-      int64_t reduction_dims_prod = 1;
-      for (auto i : reduce->dimensions()) {
-        reduction_dims_prod *= arg->shape().dimensions(i);
-      }
+      Match(reduce->to_apply()->root_instruction(),
+            m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) {
+    bool only_reduce_dims_from_broadcast = true;
+    int64_t common_dims_prod = 1;
+    int64_t num_common_dims = 0;
+    Shape new_broadcast_shape = arg->shape();
+    std::vector<int64_t> new_broadcast_dims;
 
+    // Now we build up the new broadcast shape and dims vector
+    for (int64_t i = 0; i < arg->shape().rank(); ++i) {
+      bool added_by_broadcast = !absl::c_linear_search(arg->dimensions(), i);
+      bool removed_by_reduce = absl::c_linear_search(reduce->dimensions(), i);
+
+      if (removed_by_reduce && !added_by_broadcast) {
+        only_reduce_dims_from_broadcast = false;
+        break;
+      } else if (removed_by_reduce && added_by_broadcast) {
+        new_broadcast_shape.DeleteDimension(i - num_common_dims);
+        common_dims_prod *= arg->shape().dimensions(i);
+        num_common_dims++;
+      } else if (!removed_by_reduce && !added_by_broadcast) {
+        new_broadcast_dims.push_back(i - num_common_dims);
+      }
+    }
+
+    if (only_reduce_dims_from_broadcast) {
+      // HloConstantFolding will later remove any unnecessary multiply and add
+      // instructions.
       HloInstruction* multiplier =
-          MakeScalarLike(arg->mutable_operand(0), reduction_dims_prod);
+          MakeScalarLike(arg->mutable_operand(0), common_dims_prod);
       TF_ASSIGN_OR_RETURN(HloInstruction * multiplied_scalar,
                           MakeBinaryHlo(HloOpcode::kMultiply,
                                         arg->mutable_operand(0), multiplier));
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * add,
+          MakeBinaryHlo(
+              HloOpcode::kAdd,
+              MakeBroadcastHlo(init_value, {}, multiplied_scalar->shape()),
+              multiplied_scalar));
+      VLOG(10) << "Converting common reduce(broadcast) dimensions to multiply";
       return ReplaceWithNewInstruction(
-          reduce, HloInstruction::CreateBroadcast(reduce->shape(),
-                                                  multiplied_scalar, {}));
+          reduce, HloInstruction::CreateBroadcast(new_broadcast_shape, add,
+                                                  new_broadcast_dims));
     }
   }
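
A concrete instance of the generalized rewrite above, using the shapes from the ReplaceReduceSumOfConstantBroadcast test updated in algebraic_simplifier_test.cc (illustration, not additional code): reducing f32[8,128] broadcast(c) over dimension 1 with init i sums 128 copies of c per output element, since dimension 1 (size 128) was introduced by the broadcast, so the rewrite produces

    \mathrm{reduce}_{\{1\}}\big(\mathrm{broadcast}_{8\times 128}(c),\, i\big)
        \;\longrightarrow\; \mathrm{broadcast}_{8}\big(c \cdot 128 + i\big),

with the multiply/add pair cleaned up later (e.g. by HloConstantFolding) when c and i are constants.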
 
diff --git a/third_party/xla/xla/service/algebraic_simplifier.h b/third_party/xla/xla/service/algebraic_simplifier.h
index bdd4f91..49e87fa 100644
--- a/third_party/xla/xla/service/algebraic_simplifier.h
+++ b/third_party/xla/xla/service/algebraic_simplifier.h
@@ -487,6 +487,10 @@
   static bool IsNonNegative(const HloInstruction* hlo,
                             const AlgebraicSimplifierOptions& options);
 
+  // Check if the given instruction computes a non-decreasing function that
+  // asymptotically satisfies |f(x)| <= |x|.
+  static bool IsNondecreasingSublinear(const HloInstruction* hlo);
+
   // Modify the layout dimensions of result_shape, so that it becomes the
   // re-shaped result of applying bitcast to the original_shape, by using
   // dim_map to re-shape layout dimensions of original_shape. Returns the
diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc
index d9e4252..60003eb 100644
--- a/third_party/xla/xla/service/algebraic_simplifier_test.cc
+++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc
@@ -6411,6 +6411,94 @@
                                 m::Dot(m::Parameter(1), m::Parameter(2)))));
 }
 
+TEST_F(AlgebraicSimplifierTest, DotLeftDotSharedBatchReorder) {
+  const char* hlo_string = R"(
+    HloModule module
+
+    ENTRY test {
+        a = f32[5,150,5] parameter(0)
+        b = f32[5,5,5] parameter(1)
+        c = f32[5,5,5] parameter(2)
+
+        inner = f32[5,150,5] dot(a,b),
+                lhs_batch_dims={0}, lhs_contracting_dims={2},
+                rhs_batch_dims={0}, rhs_contracting_dims={2}
+        ROOT outer = f32[5,150,5] dot(inner,c),
+                lhs_batch_dims={0}, lhs_contracting_dims={2},
+                rhs_batch_dims={0}, rhs_contracting_dims={2}
+      }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  AlgebraicSimplifierOptions options;
+  options.set_use_associative_reordering(true);
+  options.set_associative_reordering_threshold(1.5);
+  AlgebraicSimplifier simplifier(options);
+  EXPECT_TRUE(simplifier.Run(module.get()).value());
+  ASSERT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(m::Parameter(0),
+                                m::Dot(m::Parameter(1), m::Parameter(2)))));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotRightDotSharedBatchReorder) {
+  const char* hlo_string = R"(
+    HloModule module
+
+    ENTRY test {
+        a = f32[2,3,3] parameter(0)
+        b = f32[2,3,3] parameter(1)
+        c = f32[2,3,16] parameter(2)
+
+        inner = f32[2,3,16] dot(b,c),
+                lhs_batch_dims={0}, lhs_contracting_dims={2},
+                rhs_batch_dims={0}, rhs_contracting_dims={1}
+        ROOT outer = f32[2,3,16] dot(a,inner),
+                lhs_batch_dims={0}, lhs_contracting_dims={2},
+                rhs_batch_dims={0}, rhs_contracting_dims={1}
+      }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  AlgebraicSimplifierOptions options;
+  options.set_use_associative_reordering(true);
+  options.set_associative_reordering_threshold(1.5);
+  AlgebraicSimplifier simplifier(options);
+  EXPECT_TRUE(simplifier.Run(module.get()).value());
+  ASSERT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(m::Dot(m::Parameter(0), m::Parameter(1)),
+                                m::Parameter(2))));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotRightDotContractBatchReorder) {
+  const char* hlo_string = R"(
+    HloModule module
+
+    ENTRY test {
+      a = f32[80,38,1536] parameter(0)
+      b = f32[80,38,4] parameter(1)
+      c = f32[80,4,1536] parameter(2)
+      inner = f32[80,38,1536] dot(b, c),
+              lhs_batch_dims={0},
+              lhs_contracting_dims={2},
+              rhs_batch_dims={0},
+              rhs_contracting_dims={1}
+      ROOT outer = f32[1536,1536] dot(a, inner),
+                  lhs_contracting_dims={0,1},
+                  rhs_contracting_dims={0,1}
+      }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  AlgebraicSimplifierOptions options;
+  options.set_use_associative_reordering(true);
+  options.set_associative_reordering_threshold(1.5);
+  AlgebraicSimplifier simplifier(options);
+  EXPECT_FALSE(simplifier.Run(module.get()).value());
+}
+
 TEST_F(AlgebraicSimplifierTest, DotReverseLeftReorder) {
   const char* hlo_string = R"(
     HloModule module
@@ -10075,24 +10163,25 @@
 
 TEST_F(AlgebraicSimplifierTest, ReplaceReduceSumOfConstantBroadcast) {
   const char* kModuleStr = R"(
-HloModule ReplaceReduceSumOfConstantBroadcast
+  HloModule ReplaceReduceSumOfConstantBroadcast
 
-add_f32 {
-  p0 = f32[] parameter(0)
-  p1 = f32[] parameter(1)
-  ROOT r = f32[] add(p0, p1)
-}
+  add_f32 {
+    p0 = f32[] parameter(0)
+    p1 = f32[] parameter(1)
+    ROOT r = f32[] add(p0, p1)
+  }
 
-ENTRY main {
-  init_value = f32[] constant(0)
-  const_value = f32[] constant(1)
-  const_bcast = f32[8, 128] broadcast(f32[] const_value), dimensions={}
-  ROOT reduce = f32[8] reduce(f32[8, 128] const_bcast, f32[] init_value), dimensions={1}, to_apply=add_f32
-}
-)";
+  ENTRY main {
+    init_value = f32[] constant(0)
+    const_value = f32[] constant(1)
+    const_bcast = f32[8, 128] broadcast(f32[] const_value), dimensions={}
+    ROOT reduce = f32[8] reduce(f32[8, 128] const_bcast, f32[] init_value), dimensions={1}, to_apply=add_f32
+  }
+  )";
 
   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
+  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
+  EXPECT_TRUE(simplifier.Run(m.get()).value());
   int64_t reduce_count =
       absl::c_count_if(m->entry_computation()->instructions(),
                        HloPredicateIsOp<HloOpcode::kReduce>);
@@ -11714,6 +11803,88 @@
             HloOpcode::kParameter);
 }
 
+TEST_F(AlgebraicSimplifierTest, ReduceOfConstantBroadcastS32) {
+  const std::string hlo_string = R"(
+  HloModule test
+    add_s32 {
+      p0 = s32[] parameter(0)
+      p1 = s32[] parameter(1)
+      ROOT r = s32[] add(p0, p1)
+    }
+    ENTRY test.1 {
+      one = s32[] constant(2)
+      init = s32[] constant(10)
+      bcast = s32[1,7,7,1] broadcast(one), dimensions={}
+      ROOT out = s32[1,7,1] reduce(bcast, init), dimensions={1}, to_apply=add_s32
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  auto clone = m->Clone();
+  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
+  EXPECT_TRUE(simplifier.Run(m.get()).value());
+  std::cout << m->ToString() << std::endl;
+  int64_t reduce_count =
+      absl::c_count_if(m->entry_computation()->instructions(),
+                       HloPredicateIsOp<HloOpcode::kReduce>);
+  // Expect no Reduce operation after simplification.
+  EXPECT_EQ(0, reduce_count);
+}
+
+TEST_F(AlgebraicSimplifierTest, ReduceOfConstantBroadcastBF16) {
+  const std::string hlo_string = R"(
+  HloModule test
+    add_bf16 {
+      p0 = bf16[] parameter(0)
+      p1 = bf16[] parameter(1)
+      ROOT r = bf16[] add(p0, p1)
+    }
+    ENTRY test.1 {
+      one = bf16[] constant(2.12)
+      init = bf16[] constant(10.34)
+      bcast = bf16[1,7,7,1] broadcast(one), dimensions={}
+      ROOT out = bf16[1,7,1] reduce(bcast, init), dimensions={1}, to_apply=add_bf16
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  auto clone = m->Clone();
+  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
+  EXPECT_TRUE(simplifier.Run(m.get()).value());
+  int64_t reduce_count =
+      absl::c_count_if(m->entry_computation()->instructions(),
+                       HloPredicateIsOp<HloOpcode::kReduce>);
+  // Expect no Reduce operation after simplification.
+  EXPECT_EQ(0, reduce_count);
+}
+
+TEST_F(AlgebraicSimplifierTest, ReduceOfNonScalarBroadcast) {
+  const std::string hlo_string = R"(
+    HloModule module
+    add {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT sum = f32[] add(a, b)
+    }
+
+    ENTRY test {
+        a = f32[64,1001] parameter(0)
+        broadcast = f32[64,7,7,1001] broadcast(a), dimensions={0,3}
+        zero = f32[] constant(0)
+        ROOT reduce = f32[64,7,1001] reduce(broadcast, zero), dimensions={2},
+                  to_apply=add
+      }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
+  EXPECT_TRUE(simplifier.Run(m.get()).value());
+  HloInstruction* root = m->entry_computation()->root_instruction();
+  int64_t reduce_count =
+      absl::c_count_if(m->entry_computation()->instructions(),
+                       HloPredicateIsOp<HloOpcode::kReduce>);
+  // Expect no Reduce operation after simplification.
+  EXPECT_EQ(0, reduce_count);
+  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Multiply())));
+}
+
 TEST_F(AlgebraicSimplifierTest, RemoveConvertConstant) {
   const std::string hlo_string = R"(
     HloModule module
@@ -11773,11 +11944,31 @@
       }
     )";
   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
-  EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
+  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
+  EXPECT_TRUE(simplifier.Run(m.get()).value());
   HloInstruction* root = m->entry_computation()->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
   EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kMultiply);
 }
 
+TEST_F(AlgebraicSimplifierTest, SinkCbrtThroughMax) {
+  absl::string_view hlo_string = R"(
+    HloModule module
+
+    ENTRY test {
+        a = bf16[17,96,120] parameter(0)
+        b = bf16[17,96,120] parameter(1)
+        cbrt_a = bf16[17,96,120] cbrt(a)
+        cbrt_b = bf16[17,96,120] cbrt(b)
+        ROOT max = bf16[17,96,120] maximum(cbrt_a, cbrt_b)
+      }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
+  EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
+  HloInstruction* root = m->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root, GmockMatch(m::Cbrt(m::Maximum(m::Parameter(0), m::Parameter(1)))));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/third_party/xla/xla/service/all_reduce_splitter.cc b/third_party/xla/xla/service/all_reduce_splitter.cc
deleted file mode 100644
index ce1e0e2..0000000
--- a/third_party/xla/xla/service/all_reduce_splitter.cc
+++ /dev/null
@@ -1,436 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/all_reduce_splitter.h"
-
-#include <cstdint>
-#include <optional>
-#include <string>
-#include <variant>
-#include <vector>
-
-#include "absl/cleanup/cleanup.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/strings/substitute.h"
-#include "xla/hlo/ir/collective_device_list.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/collective_opt_utils.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/shape.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-// Structure containing the newly calculated replica groups.
-struct ARReplicaGroups {
-  // First AR's replica group.
-  std::vector<ReplicaGroup> first_ar_replica_groups;
-  // Second AR's replica group.
-  std::vector<ReplicaGroup> second_ar_replica_groups;
-};
-
-// Contains relevant data to rewrite the AR + DS into AR + DS + AR.
-struct AllReduceRewriteSpec {
-  // Determines a dimension on which DS occurs.
-  int split_dim;
-  // Determines the size of the process group.
-  int group_size;
-  // AllReduce instruction to be rewritten.
-  HloAllReduceInstruction* all_reduce;
-  // DynamicSlice following the `all_reduce` indicating logical RS.
-  HloDynamicSliceInstruction* dynamic_slice;
-  // New replica groups for an `all_reduce`.
-  ARReplicaGroups replica_groups;
-
-  std::string ToString() {
-    return absl::Substitute(
-        "{\n split_dim=$0\n group_size=$1\n all_reduce=$2\n "
-        "dynamic_slice=$3\n}\n",
-        split_dim, group_size, all_reduce->ToString(),
-        dynamic_slice->ToString());
-  }
-};
-
-// Contains the relevant metadata for debugging why rewrite is infeasible.
-struct RewriteInfeasibleReason {
-  // Instruction for which it is infeasible to do a rewrite.
-  const HloInstruction* ar;
-  // Describes a reason of infeasibility.
-  std::string message;
-};
-
-// Hashable container to hold replica groups.
-struct ReplicaGroups {
-  std::vector<ReplicaGroup> replica_groups;
-
-  template <typename H>
-  friend H AbslHashValue(H h, const ReplicaGroups& rg) {
-    return H::combine(std::move(h), rg.replica_groups.size());
-  }
-
-  friend bool operator==(const ReplicaGroups& item,
-                         const ReplicaGroups& other) {
-    if (item.replica_groups.size() != other.replica_groups.size()) {
-      return false;
-    }
-    for (int i = 0; i < item.replica_groups.size(); i++) {
-      const ReplicaGroup& item_replica_group = item.replica_groups[i];
-      const ReplicaGroup& other_replica_group = other.replica_groups[i];
-      for (int i = 0; i < item_replica_group.replica_ids_size(); i++) {
-        if (item_replica_group.replica_ids(i) !=
-            other_replica_group.replica_ids(i)) {
-          return false;
-        }
-      }
-    }
-    return true;
-  }
-};
-
-using ARReplicaGroupMap =
-    absl::flat_hash_map<ReplicaGroups,
-                        std::vector<const HloAllReduceInstruction*>>;
-
-using RewriteDecision =
-    std::variant<AllReduceRewriteSpec, RewriteInfeasibleReason>;
-
-// Returns a single dimension which is being split by `ds`. Returns
-// std::nullopt if there are more, or no dimension to be split.
-std::optional<int> GetSplitDim(const HloAllReduceInstruction& ar,
-                               const HloDynamicSliceInstruction& ds) {
-  int split_dim = -1;
-  int num_dims = 0;
-  for (int64_t dim = 0; dim < ar.shape().rank(); ++dim) {
-    if (ar.shape().dimensions(dim) != ds.shape().dimensions(dim)) {
-      num_dims++;
-      split_dim = dim;
-    }
-  }
-  if (num_dims != 1) {
-    VLOG(2) << "No support for multiple nor 0 split dims.";
-    return std::nullopt;
-  }
-  return split_dim;
-}
-
-// For input collective instruction `ar` get the process group size (# shards).
-std::optional<int> GetProcessGroupSize(const HloAllReduceInstruction& ar,
-                                       const HloDynamicSliceInstruction& ds) {
-  CHECK(ds.operand(0) == &ar) << "Irrelevant AR + DS pair.";
-  std::optional<int> split_dim = GetSplitDim(ar, ds);
-  if (!split_dim.has_value()) {
-    return std::nullopt;
-  }
-
-  return ar.shape().dimensions(*split_dim) /
-         ds.dynamic_slice_sizes()[*split_dim];
-}
-
-ARReplicaGroupMap GetReplicaGroupsMap(HloComputation& computation) {
-  ARReplicaGroupMap map;
-  hlo_query::ForEachInstructionWithOpcode(
-      computation, HloOpcode::kAllReduce,
-      [&map](const HloInstruction* instruction) {
-        const HloAllReduceInstruction* ar =
-            Cast<HloAllReduceInstruction>(instruction);
-        auto rgs = ReplicaGroups{ar->replica_groups()};
-        map[rgs].push_back(ar);
-      });
-  return map;
-}
-
-ARReplicaGroups GetNewReplicaGroups(int group_size, int num_partitions) {
-  CHECK_EQ(num_partitions % group_size, 0);
-
-  std::vector<ReplicaGroup> first_ar_rgs, second_ar_rgs;
-  int num_units = num_partitions / group_size;
-  first_ar_rgs.reserve(num_units);
-  second_ar_rgs.reserve(group_size);
-
-  // Construct first AR replica groups.
-  for (int u = 0; u < group_size * num_units; u += group_size) {
-    ReplicaGroup& group = first_ar_rgs.emplace_back();
-    for (int r = u; r < u + group_size; r++) {
-      group.add_replica_ids(r);
-    }
-  }
-
-  // Construct second AR replica groups.
-  for (int g = 0; g < group_size; g++) {
-    ReplicaGroup& group = second_ar_rgs.emplace_back();
-    for (int r = g; r < group_size * num_units; r += group_size) {
-      group.add_replica_ids(r);
-    }
-  }
-  return {
-      /*first_ar_replica_groups=*/first_ar_rgs,
-      /*second_ar_replica_groups=*/second_ar_rgs,
-  };
-}
-
-// Returns true if `spec` can be transformed into a logical reduce scatter.
-// False otherwise.
-bool IsLogicalReduceScatter(const HloModule& module,
-                            const AllReduceRewriteSpec& spec,
-                            HloComputation& computation) {
-  HloAllReduceInstruction& ar = *spec.all_reduce;
-  CHECK_EQ(ar.user_count(), 1);
-  CHECK_EQ(module.config().replica_count(), 1);
-
-  HloInstruction* first_ar =
-      computation.AddInstruction(HloInstruction::CreateAllReduce(
-          ar.shape(), ar.operands(), ar.to_apply(),
-          CollectiveDeviceList(spec.replica_groups.first_ar_replica_groups),
-          ar.constrain_layout(), hlo_query::NextChannelId(module),
-          ar.use_global_device_ids()));
-
-  HloInstruction* ds = ar.users()[0];
-  auto* old_operand = ds->mutable_operand(0);
-  if (!ds->ReplaceOperandWith(0, first_ar).ok()) {
-    return false;
-  }
-  absl::Cleanup _ = [&] {
-    CHECK_OK(ds->ReplaceOperandWith(0, old_operand));
-    CHECK_OK(computation.RemoveInstruction(first_ar));
-  };
-  return MatchReduceScatter(Cast<HloAllReduceInstruction>(first_ar),
-                            module.config().num_partitions(),
-                            module.config().replica_count(),
-                            /*allow_multiple_split_dims=*/false,
-                            /*allow_intervening_reshape=*/true)
-      .has_value();
-}
-
-// Determine whether the given `spec`'s AllReduce instruction is profitable to
-// split. Currently it employs a simple heuristic, and it checks whether there
-// exists at least one all reduce with same replica groups as any of the all
-// reduce's replica groups after the potential split.
-bool IsProfitableToSplit(const ARReplicaGroupMap& replica_map,
-                         const AllReduceRewriteSpec& spec) {
-  auto new_rgs = spec.replica_groups;
-  bool first_replica_exists =
-      replica_map.contains(ReplicaGroups{new_rgs.first_ar_replica_groups});
-  bool second_replica_exists =
-      replica_map.contains(ReplicaGroups{new_rgs.second_ar_replica_groups});
-  return first_replica_exists || second_replica_exists;
-}
-
-RewriteDecision CanRewrite(const HloModule& module,
-                           const ARReplicaGroupMap& replica_map,
-                           HloComputation& computation,
-                           HloInstruction& instruction) {
-  // We rely on SPMD partitioning enabled, thus asserting `replica_count` = 1.
-  const HloModuleConfig& config = module.config();
-  if (config.use_auto_spmd_partitioning() || !config.use_spmd_partitioning() ||
-      config.replica_count() != 1) {
-    return RewriteInfeasibleReason{
-        &instruction,
-        "Supporting only SPMD partitioning scheme.",
-    };
-  }
-
-  if (instruction.opcode() != HloOpcode::kAllReduce) {
-    return RewriteInfeasibleReason{
-        &instruction,
-        "Cannot rewrite an AllReduce, since it's not AllReduce.",
-    };
-  }
-
-  auto* ar = Cast<HloAllReduceInstruction>(&instruction);
-
-  if (!ar->use_global_device_ids()) {
-    return RewriteInfeasibleReason{
-        &instruction,
-        "Only global ids are supported currently.",
-    };
-  }
-
-  if (ar->user_count() != 1 ||
-      ar->users().front()->opcode() != HloOpcode::kDynamicSlice) {
-    return RewriteInfeasibleReason{
-        &instruction,
-        "Cannot rewrite AllReduce if it is not a logical reduce scatter.",
-    };
-  }
-
-  auto* ds = Cast<HloDynamicSliceInstruction>(ar->users().front());
-
-  if (ds->user_count() > 1) {
-    return RewriteInfeasibleReason{
-        &instruction,
-        "Exactly one user of dynamic slice is required for a rewrite.",
-    };
-  }
-
-  int num_partitions = config.num_partitions();
-
-  std::vector<ReplicaGroup> rgs = ar->replica_groups();
-  if (rgs.size() != 1 || rgs.front().replica_ids_size() != num_partitions) {
-    return RewriteInfeasibleReason{
-        &instruction,
-        absl::StrCat("Cannot determine a valid split with num_partitions: ",
-                     num_partitions),
-    };
-  }
-
-  std::optional<int> split_dim = GetSplitDim(*ar, *ds);
-  if (!split_dim.has_value()) {
-    return RewriteInfeasibleReason{
-        &instruction,
-        "Cannot get a split dim.",
-    };
-  }
-
-  std::optional<int> group_size = GetProcessGroupSize(*ar, *ds);
-  if (!group_size.has_value()) {
-    return RewriteInfeasibleReason{
-        &instruction,
-        "Cannot determine a group size.",
-    };
-  }
-
-  if (num_partitions == group_size) {
-    return RewriteInfeasibleReason{
-        &instruction,
-        "Nothing to rewrite",
-    };
-  }
-
-  if (num_partitions % *group_size != 0) {
-    return RewriteInfeasibleReason{
-        &instruction,
-        "Group size does not evenly divide the number of partitions",
-    };
-  }
-
-  auto spec = AllReduceRewriteSpec{
-      /*split_dim=*/*split_dim,
-      /*group_size=*/*group_size,
-      /*all_reduce=*/ar,
-      /*dynamic_slice=*/ds,
-      /*replica_groups=*/GetNewReplicaGroups(*group_size, num_partitions),
-  };
-
-  if (!IsLogicalReduceScatter(module, spec, computation)) {
-    return RewriteInfeasibleReason{
-        &instruction,
-        "Not a logical reduce scatter.",
-    };
-  }
-
-  if (!IsProfitableToSplit(replica_map, spec)) {
-    return RewriteInfeasibleReason{
-        &instruction,
-        "Splitting is not profitable.",
-    };
-  }
-
-  return spec;
-}
-
-absl::StatusOr<bool> SplitAllReduce(const HloModuleConfig& config,
-                                    AllReduceRewriteSpec spec,
-                                    HloComputation& computation) {
-  int64_t next_channel_id =
-      hlo_query::NextChannelId(*spec.all_reduce->GetModule());
-  VLOG(1) << "AR splitting spec: " << spec.ToString();
-  // Create first AR.
-  int num_partitions = config.num_partitions();
-  // # of shards within a replica
-  int group_size = spec.group_size;
-
-  CHECK_EQ(num_partitions % group_size, 0);
-
-  HloAllReduceInstruction& ar = *spec.all_reduce;
-  HloDynamicSliceInstruction& ds = *spec.dynamic_slice;
-
-  const auto& [first_ar_replica_groups, second_ar_replica_groups] =
-      spec.replica_groups;
-  int channel_id = next_channel_id++;
-  HloInstruction* first_ar =
-      computation.AddInstruction(HloInstruction::CreateAllReduce(
-          ar.shape(), ar.operands(), ar.to_apply(),
-          CollectiveDeviceList(first_ar_replica_groups), ar.constrain_layout(),
-          channel_id, ar.use_global_device_ids()));
-
-  // Create second AR.
-  channel_id = next_channel_id++;
-  HloInstruction* second_ar =
-      computation.AddInstruction(HloInstruction::CreateAllReduce(
-          ds.shape(), {&ds}, ar.to_apply(),
-          CollectiveDeviceList(second_ar_replica_groups), ar.constrain_layout(),
-          channel_id, ar.use_global_device_ids()));
-
-  // Rewire.
-  TF_RETURN_IF_ERROR(computation.ReplaceInstruction(&ar, first_ar));
-  if (ds.IsRoot()) {
-    computation.set_root_instruction(second_ar);
-  }
-  TF_RETURN_IF_ERROR(ds.ReplaceAllUsesWith(second_ar));
-  return true;  // changed
-}
-
-// Splits `instruction` if it finds it is feasible and profitable to do so.
-// Return true if `instruction` has been rewritten, or false otherwise.
-absl::StatusOr<bool> SplitAllReduce(const HloModule& module,
-                                    const ARReplicaGroupMap& replica_map,
-                                    HloComputation& computation,
-                                    HloInstruction& instruction) {
-  RewriteDecision spec =
-      CanRewrite(module, replica_map, computation, instruction);
-  if (std::holds_alternative<RewriteInfeasibleReason>(spec)) {
-    auto reason = std::get<RewriteInfeasibleReason>(spec);
-    VLOG(1) << "Cannot process {" << reason.ar->ToString()
-            << "} due to : " << reason.message;
-    return false;  // changed
-  }
-  return SplitAllReduce(module.config(), std::get<AllReduceRewriteSpec>(spec),
-                        computation);  // changed
-}
-
-}  // namespace
-
-absl::StatusOr<bool> AllReduceSplitter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-
-  for (auto* computation : module->computations(execution_threads)) {
-    ARReplicaGroupMap replica_map = GetReplicaGroupsMap(*computation);
-    for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool rewritten, SplitAllReduce(*module, replica_map,
-                                                         *computation, *instr));
-      changed |= rewritten;
-    }
-  }
-
-  return changed;
-}
-
-}  // namespace xla
diff --git a/third_party/xla/xla/service/all_reduce_splitter.h b/third_party/xla/xla/service/all_reduce_splitter.h
deleted file mode 100644
index ac8dec7..0000000
--- a/third_party/xla/xla/service/all_reduce_splitter.h
+++ /dev/null
@@ -1,77 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_ALL_REDUCE_SPLITTER_H_
-#define XLA_SERVICE_ALL_REDUCE_SPLITTER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// Rewrites global AR if it is in the form of AR + DS and matches existing
-// replica groups into a logical RS followed by AR.
-//
-// If the pass detects AR followed by DS, then it checks whether
-// it is profitable to break it down into a logical RS (but AR + DS still),
-// followed by an AR to keep the rewrite numerically equivalent.
-//
-// Consider following example:
-//
-// Input program:
-//   HloModule m, num_partitions=8
-//     p = partition_id()
-//     ar = bf16[32] all-reduce(x), replica_groups={{0,1,2,3,4,5,6,7}}
-//     ds = dynamic-slice(ar, pointer(partition_id)), dynamic_slice_sizes={8}
-//
-// There is a global AR performing a reduction over 8 partitions.
-// However DS is performing 8-sized slice of a 32-sized tensor which implies
-// only 4 distinct slices of a tensor, which further implies 2 replicas of each
-// calculated slice. This can be expressed as RS within the replicas followed by
-// AR across the replicas. The transformation limits collectives to the data
-// that is actually needed for the requested slice.
-//
-// Output program:
-//   HloModule m, num_partitions=8
-//     p = partition_id()
-//     ar = bf16[32] all-reduce(x), replica_groups={{0,1,2,3},{4,5,6,7}}
-//     ds = dynamic-slice(ar, pointer(partition_id)), dynamic_slice_sizes={8}
-//     ar.2 = bf16[32] all-reduce(ds), replica_groups={{0,4},{1,5},{2,6},{3,7}}
-//
-// In addition the pass does the rewrite only if it finds it profitable to do
-// so. The profitability function is simple, and just checks whether there are
-// any collectives with same replica groups. If there are then the combiner pass
-// can pick it up, and fuse it into the same NCCL call.
-//
-// While the solution is orthogonal to existing known distribution patterns, in
-// practice it is profitable for HSDP style communication pattern.
-// https://arxiv.org/pdf/2203.11014
-//
-class AllReduceSplitter : public HloModulePass {
- public:
-  absl::string_view name() const override { return "all-reduce-splitter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace xla
-
-#endif  // XLA_SERVICE_ALL_REDUCE_SPLITTER_H_
diff --git a/third_party/xla/xla/service/all_reduce_splitter_test.cc b/third_party/xla/xla/service/all_reduce_splitter_test.cc
deleted file mode 100644
index 3902a97..0000000
--- a/third_party/xla/xla/service/all_reduce_splitter_test.cc
+++ /dev/null
@@ -1,506 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/all_reduce_splitter.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <string>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/algorithm/container.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_reduce_scatter_creator.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::tsl::testing::IsOkAndHolds;
-
-class AllReduceSplitterTest : public HloTestBase {
- public:
-  absl::StatusOr<std::unique_ptr<HloModule>> PrepareModule(
-      absl::string_view hlo_module, int64_t num_replicas,
-      int64_t num_partitions) {
-    HloModuleConfig config = GetModuleConfigForTest(
-        /*replica_count=*/num_replicas,
-        /*num_partitions=*/num_partitions);
-    config.set_use_spmd_partitioning(num_partitions > 1);
-    return ParseAndReturnVerifiedModule(hlo_module, config);
-  }
-
-  size_t AllReduceCount(const HloModule &module) {
-    return CollectiveCount(module, HloOpcode::kAllReduce);
-  }
-
- private:
-  size_t CollectiveCount(const HloModule &module, HloOpcode opcode) {
-    return absl::c_count_if(
-        module.entry_computation()->instructions(),
-        [&opcode](HloInstruction *instr) { return instr->opcode() == opcode; });
-  }
-};
-
-class AllReduceSplitterFilecheckTest : public AllReduceSplitterTest {
- public:
-  absl::Status FileCheck(const std::string &hlo_text,
-                         absl::string_view pattern) {
-    TF_ASSIGN_OR_RETURN(bool matched, RunFileCheck(hlo_text, pattern));
-    if (!matched) {
-      return absl::InternalError("Filecheck failed.");
-    }
-    return absl::OkStatus();
-  }
-};
-
-TEST_F(
-    AllReduceSplitterFilecheckTest,
-    MatchBasicPatternIfDynamicSliceIsRootAndThereExistsAllReduceWithSameReplicaGroups) {  // NOLINT
-  absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
-  a = bf16[] parameter(0)
-  b = bf16[] parameter(1)
-  ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
-  p = bf16[2,4096,4096] parameter(0)
-  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
-  zero = bf16[] constant(0)
-  reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
-  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
-  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
-  pid = u32[] partition-id()
-  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(id)
-  slice_size = s32[] constant(1024)
-  offset = s32[] multiply(reshape, slice_size)
-  ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<HloModule> module,
-      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
-  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(true));
-  TF_EXPECT_OK(FileCheck(module->ToString(), R"(
-    CHECK-DAG:    %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
-    CHECK:        %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
-    CHECK-SAME:   replica_groups={[[DESIRED_RGS:.*]]}
-    CHECK-DAG:    %[[ZERO:.*]] = bf16[] constant(0)
-    CHECK-DAG:    %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[AR0]], bf16[] %[[ZERO]])
-    CHECK:        %[[AR1:.*]] = bf16[4096]{0} all-reduce(bf16[4096]{0} %[[LOCAL_REDUCE]])
-    CHECK-SAME:   replica_groups={[[DESIRED_RGS]]}
-    CHECK:        %[[DS:.*]] = bf16[1024]{0} dynamic-slice(bf16[4096]{0} %[[AR1]], s32[] %[[_:.*]])
-    CHECK-SAME:   dynamic_slice_sizes={1024}
-    CHECK-NEXT:   ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[DS]])
-    CHECK-SAME:   replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
-    )"));
-}
-
-TEST_F(
-    AllReduceSplitterTest,
-    DoesNotMatchMatchBasicPatternIfDynamicSliceIsRootAndThereIsNoAllReduceWithSameReplicaGroups) {  // NOLINT
-  absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
-  a = bf16[] parameter(0)
-  b = bf16[] parameter(1)
-  ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
-  p = bf16[2,4096,4096] parameter(0)
-  zero = bf16[] constant(0)
-  reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
-  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
-  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
-  pid = u32[] partition-id()
-  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(id)
-  slice_size = s32[] constant(1024)
-  offset = s32[] multiply(reshape, slice_size)
-  ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<HloModule> module,
-      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
-  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
-
-  EXPECT_EQ(AllReduceCount(*module), 1);
-}
-
-TEST_F(
-    AllReduceSplitterFilecheckTest,
-    MatchBasicPatternIfDynamicSliceIsNotRootAndThereExistsAllReduceWithSameReplicaGroups) {  // NOLINT
-  absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
-  a = bf16[] parameter(0)
-  b = bf16[] parameter(1)
-  ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
-  p = bf16[2,4096,4096] parameter(0)
-  zero = bf16[] constant(0)
-  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
-  reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
-  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
-  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
-  pid = u32[] partition-id()
-  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(id)
-  slice_size = s32[] constant(1024)
-  offset = s32[] multiply(reshape, slice_size)
-  dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-  broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
-  ROOT _ = tuple(broadcast, first.ar)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<HloModule> module,
-      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
-  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(true));
-  TF_EXPECT_OK(FileCheck(module->ToString(), R"(
-    CHECK-DAG:    %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
-    CHECK-DAG:    %[[ZERO:.*]] = bf16[] constant(0)
-    CHECK-DAG:    %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[P0]], bf16[] %[[ZERO]])
-    CHECK:        %[[AR0:.*]] = bf16[4096]{0} all-reduce(bf16[4096]{0} %[[LOCAL_REDUCE]])
-    CHECK-SAME:   replica_groups={[[DESIRED_RGS:.*]]}
-    CHECK:        %[[DS:.*]] = bf16[1024]{0} dynamic-slice(bf16[4096]{0} %[[AR0]], s32[] %[[_:.*]])
-    CHECK-SAME:   dynamic_slice_sizes={1024}
-    CHECK-NEXT:   %[[AR1:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[DS]])
-    CHECK-SAME:   replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
-    CHECK:        %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
-    CHECK-SAME:   replica_groups={[[DESIRED_RGS]]}
-    CHECK:        ROOT
-    CHECK-NOT:    %[[AR1]]
-    CHECK-SAME:   %[[EXISTING_AR]]
-    )"));
-}
-
-TEST_F(
-    AllReduceSplitterTest,
-    DoesNotMatchBasicPatternIfDynamicSliceIsNotRootAndThereIsNoAllReduceWithSameReplicaGroups) {  // NOLINT
-  absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
-  a = bf16[] parameter(0)
-  b = bf16[] parameter(1)
-  ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
-  p = bf16[2,4096,4096] parameter(0)
-  p.1 = bf16[2,4096,4096] parameter(1)
-  zero = bf16[] constant(0)
-  reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
-  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
-  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
-  pid = u32[] partition-id()
-  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(id)
-  slice_size = s32[] constant(1024)
-  offset = s32[] multiply(reshape, slice_size)
-  dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-  broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
-  add = bf16[2,4096,4096] add(p,p.1)
-  ROOT _ = tuple(broadcast, add)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<HloModule> module,
-      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
-  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
-  EXPECT_EQ(AllReduceCount(*module), 1);
-}
-
-TEST_F(AllReduceSplitterTest,
-       DoesNotMatchBasicPatternIfDynamicSliceIsFullySharded) {  // NOLINT
-  absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
-  a = bf16[] parameter(0)
-  b = bf16[] parameter(1)
-  ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
-  p = bf16[2,4096,4096] parameter(0)
-  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
-  zero = bf16[] constant(0)
-  reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
-  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
-  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
-  pid = u32[] partition-id()
-  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(id)
-  slice_size = s32[] constant(512)
-  offset = s32[] multiply(reshape, slice_size)
-  ROOT _ = bf16[512] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={512}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<HloModule> module,
-      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
-  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
-  EXPECT_EQ(AllReduceCount(*module), 2);
-}
-
-TEST_F(AllReduceSplitterTest,
-       DoesNotMatchBasicPatternIfItIsNotCompiledWithSPMDPartitioning) {
-  absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
-  a = bf16[] parameter(0)
-  b = bf16[] parameter(1)
-  ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
-  p = bf16[2,4096,4096] parameter(0)
-  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
-  zero = bf16[] constant(0)
-  reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
-  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
-  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
-  pid = u32[] partition-id()
-  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(id)
-  slice_size = s32[] constant(1024)
-  offset = s32[] multiply(reshape, slice_size)
-  ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-}
-)";
-  HloModuleConfig config =
-      GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/8);
-  config.set_use_spmd_partitioning(false);
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string, config));
-
-  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
-  EXPECT_THAT(AllReduceCount(*module), 2);
-}
-
-TEST_F(AllReduceSplitterTest,
-       DoesNotMatchBasicPatternIfUseGlobalDeviceIdsIsFalse) {  // NOLINT
-  absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
-  a = bf16[] parameter(0)
-  b = bf16[] parameter(1)
-  ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
-  p = bf16[2,4096,4096] parameter(0)
-  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, channel_id=1
-  zero = bf16[] constant(0)
-  reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
-  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, channel_id=2
-  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
-  pid = u32[] partition-id()
-  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(id)
-  slice_size = s32[] constant(1024)
-  offset = s32[] multiply(reshape, slice_size)
-  ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<HloModule> module,
-      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
-  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
-
-  EXPECT_EQ(AllReduceCount(*module), 2);
-}
-
-TEST_F(AllReduceSplitterTest,
-       DoesNotMatchBasicPatternIfIsNotCrossAllPartitionsAllReduce) {  // NOLINT
-  absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
-  a = bf16[] parameter(0)
-  b = bf16[] parameter(1)
-  ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
-  p = bf16[2,4096,4096] parameter(0)
-  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
-  zero = bf16[] constant(0)
-  reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
-  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
-  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
-  pid = u32[] partition-id()
-  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(id)
-  slice_size = s32[] constant(1024)
-  offset = s32[] multiply(reshape, slice_size)
-  ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<HloModule> module,
-      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
-  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
-
-  EXPECT_EQ(AllReduceCount(*module), 2);
-}
-
-TEST_F(
-    AllReduceSplitterFilecheckTest,
-    PipelineMatchesBasicPatternWithDynamicSliceAsRootAndRewritesToReduceScatter) {  // NOLINT
-  absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
-  a = bf16[] parameter(0)
-  b = bf16[] parameter(1)
-  ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
-  p = bf16[2,4096,4096] parameter(0)
-  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
-  zero = bf16[] constant(0)
-  reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
-  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
-  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
-  pid = u32[] partition-id()
-  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(id)
-  slice_size = s32[] constant(1024)
-  offset = s32[] multiply(reshape, slice_size)
-  ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<HloModule> module,
-      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
-  HloPassPipeline pipeline("all-reduce-splitter-rewrite");
-  pipeline.AddPass<AllReduceSplitter>();
-  pipeline.AddPass<ReduceScatterCreator>();
-  EXPECT_THAT(pipeline.Run(module.get()), IsOkAndHolds(true));
-  TF_EXPECT_OK(FileCheck(module->ToString(), R"(
-    CHECK-DAG:    %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
-    CHECK:        %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
-    CHECK-SAME:   replica_groups={[[DESIRED_RGS:.*]]}
-    CHECK-DAG:    %[[ZERO:.*]] = bf16[] constant(0)
-    CHECK-DAG:    %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[AR0]], bf16[] %[[ZERO]])
-    CHECK:        %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(bf16[4096]{0} %[[LOCAL_REDUCE]])
-    CHECK-SAME:   replica_groups={[[DESIRED_RGS]]}
-    CHECK-NEXT:   ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[REDUCE_SCATTER]])
-    CHECK-SAME:   replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
-    )"));
-}
-
-TEST_F(
-    AllReduceSplitterFilecheckTest,
-    PipelineMatchesBasicPatternWithDynamicSliceNotAsRootAndRewritesToReduceScatter) {  // NOLINT
-  absl::string_view hlo_string = R"(
-HloModule m
-
-sum {
-  a = bf16[] parameter(0)
-  b = bf16[] parameter(1)
-  ROOT _ = bf16[] add(a,b)
-}
-
-ENTRY main {
-  p = bf16[2,4096,4096] parameter(0)
-  zero = bf16[] constant(0)
-  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
-  reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
-  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
-  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
-  pid = u32[] partition-id()
-  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(id)
-  slice_size = s32[] constant(1024)
-  offset = s32[] multiply(reshape, slice_size)
-  dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
-  broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
-  ROOT _ = tuple(broadcast, first.ar)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<HloModule> module,
-      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
-
-  HloPassPipeline pipeline("all-reduce-splitter-rewrite");
-  pipeline.AddPass<AllReduceSplitter>();
-  pipeline.AddPass<ReduceScatterCreator>();
-  EXPECT_THAT(pipeline.Run(module.get()), IsOkAndHolds(true));
-  TF_EXPECT_OK(FileCheck(module->ToString(), R"(
-    CHECK-DAG:    %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
-    CHECK-DAG:    %[[ZERO:.*]] = bf16[] constant(0)
-    CHECK-DAG:    %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[P0]], bf16[] %[[ZERO]])
-    CHECK:        %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(bf16[4096]{0} %[[LOCAL_REDUCE]])
-    CHECK-NEXT:   %[[AR1:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[REDUCE_SCATTER]])
-    CHECK-SAME:   replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
-    CHECK:        %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
-    CHECK:        ROOT
-    CHECK-NOT:    %[[AR1]]
-    CHECK-SAME:   %[[EXISTING_AR]]
-    )"));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/call_graph.cc b/third_party/xla/xla/service/call_graph.cc
index ea16ca0..80515e1 100644
--- a/third_party/xla/xla/service/call_graph.cc
+++ b/third_party/xla/xla/service/call_graph.cc
@@ -214,8 +214,8 @@
   } else if (a == b) {
     return a;
   } else {
-    // Contexts are different and neither is kNone, ie one is kSequential and
-    // the other is kParallel.
+    // Contexts are different and neither is kNone, i.e. one is kControlFlow and
+    // the other is kEmbedded.
     return CallContext::kBoth;
   }
 }
diff --git a/third_party/xla/xla/service/call_inliner.cc b/third_party/xla/xla/service/call_inliner.cc
index 579de41..0605fbd 100644
--- a/third_party/xla/xla/service/call_inliner.cc
+++ b/third_party/xla/xla/service/call_inliner.cc
@@ -139,6 +139,16 @@
 CallInliner::Inline(HloInstruction* call) {
   TF_RET_CHECK(call->opcode() == HloOpcode::kCall)
       << "Instruction was not a call op: " << call->opcode();
+  if (call->is_composite()) {
+    // Remove the composite frontend attributes before inlining; otherwise they
+    // will appear on the inlined instructions.
+    FrontendAttributes frontend_attributes = call->frontend_attributes();
+    frontend_attributes.mutable_map()->erase("composite.name");
+    frontend_attributes.mutable_map()->erase("composite.attributes");
+    frontend_attributes.mutable_map()->erase("composite.version");
+    call->set_frontend_attributes(frontend_attributes);
+  }
+
   const auto& callees = call->called_computations();
   TF_RET_CHECK(callees.size() == 1);
   HloComputation* callee = callees[0];
diff --git a/third_party/xla/xla/service/call_inliner_test.cc b/third_party/xla/xla/service/call_inliner_test.cc
index da73fe6..ad6ee73 100644
--- a/third_party/xla/xla/service/call_inliner_test.cc
+++ b/third_party/xla/xla/service/call_inliner_test.cc
@@ -15,25 +15,25 @@
 
 #include "xla/service/call_inliner.h"
 
+#include <cstdint>
 #include <memory>
-#include <optional>
 #include <string>
-#include <utility>
-#include <vector>
 
+#include "absl/log/log.h"
+#include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/utils/hlo_matchers.h"
-#include "xla/layout_util.h"
 #include "xla/literal.h"
-#include "xla/service/hlo_pass_fix.h"
+#include "xla/literal_util.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/test.h"
 #include "xla/tests/hlo_test_base.h"
 #include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/types.h"
 #include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
 
 namespace op = xla::testing::opcode_matchers;
 
@@ -346,5 +346,36 @@
                         op::Constant(LiteralUtil::CreateR0<uint32_t>(2))));
   }
 }
+
+TEST_F(CallInlinerTest, InlineCompositeCall) {
+  const absl::string_view hlo_string = R"(
+  HloModule composite
+
+  %add (lhs: f32[]) -> f32[] {
+    %lhs = f32[] parameter(0)
+    %rhs = f32[] constant(2)
+    ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
+  }
+
+  ENTRY %main () -> f32[] {
+    %lhs = f32[] constant(42)
+    ROOT %call = f32[] call(f32[] %lhs), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+  })";
+
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  CallInliner call_inliner(/*single_call_site=*/true);
+  TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+  ASSERT_TRUE(mutated);
+
+  ASSERT_EQ(module->entry_computation()->instruction_count(), 3);
+  auto inst = module->entry_computation()->instructions().begin();
+  EXPECT_THAT(*inst, op::Constant());
+  ++inst;
+  EXPECT_THAT(*inst, op::Constant());
+  ++inst;
+  EXPECT_THAT(*inst, op::Add());
+  EXPECT_TRUE((*inst)->frontend_attributes().map().empty());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/third_party/xla/xla/service/change_op_data_type.cc b/third_party/xla/xla/service/change_op_data_type.cc
index 7c907d8..3c7875a 100644
--- a/third_party/xla/xla/service/change_op_data_type.cc
+++ b/third_party/xla/xla/service/change_op_data_type.cc
@@ -19,8 +19,7 @@
 
 #include "xla/service/hlo_creation_utils.h"
 #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
-#include "xla/service/cpu/onednn_convolution_rewriter.h"
-#include "xla/service/cpu/onednn_matmul_rewriter.h"
+#include "xla/service/cpu/onednn_contraction_rewriter.h"
 #endif  // INTEL_MKL && ENABLE_ONEDNN_V3
 
 namespace xla {
@@ -65,11 +64,11 @@
       }
 #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
       if (instr->opcode() == HloOpcode::kDot &&
-          cpu::OneDnnMatMulRewriter::ShouldRewrite(instr)) {
+          cpu::OneDnnContractionRewriter::ShouldRewriteDot(instr, true)) {
         continue;
       }
       if (instr->opcode() == HloOpcode::kConvolution &&
-          cpu::OneDnnConvolutionRewriter::ShouldRewrite(instr)) {
+          cpu::OneDnnContractionRewriter::ShouldRewriteConv(instr)) {
         continue;
       }
 #endif  // INTEL_MKL && ENABLE_ONEDNN_V3
diff --git a/third_party/xla/xla/service/collective_permute_decomposer_test.cc b/third_party/xla/xla/service/collective_permute_decomposer_test.cc
index 8de403e..eac5ab0 100644
--- a/third_party/xla/xla/service/collective_permute_decomposer_test.cc
+++ b/third_party/xla/xla/service/collective_permute_decomposer_test.cc
@@ -106,7 +106,7 @@
   EXPECT_THAT(
       recv->ToString(),
       HasSubstr(
-          "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\""));
+          "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}"));
   check_metadata(recv);
   check_not_pipelined(recv);
   HloInstruction* recv_done = FindInstruction(module.get(), "recv-done");
@@ -118,7 +118,7 @@
   EXPECT_THAT(
       send->ToString(),
       HasSubstr(
-          "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\""));
+          "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}"));
   check_metadata(send);
   check_not_pipelined(send);
   HloInstruction* send_done = FindInstruction(module.get(), "send-done");
@@ -212,7 +212,7 @@
   EXPECT_THAT(
       recv->ToString(),
       HasSubstr(
-          "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\""));
+          "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}"));
   EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\""));
   EXPECT_THAT(recv->ToString(), HasSubstr("_xla_other_attribute=\"xyz\""));
   HloInstruction* recv_done = FindInstruction(module.get(), "recv-done");
@@ -224,7 +224,7 @@
   EXPECT_THAT(
       send->ToString(),
       HasSubstr(
-          "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\""));
+          "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}"));
   EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\""));
   EXPECT_THAT(send->ToString(), HasSubstr("_xla_other_attribute=\"xyz\""));
   HloInstruction* send_done = FindInstruction(module.get(), "send-done");
@@ -290,18 +290,18 @@
   HloInstruction* recv = FindInstruction(module.get(), "recv");
   EXPECT_EQ(recv->channel_id().value(), 1);
   EXPECT_THAT(recv->ToString(),
-              HasSubstr("_xla_send_recv_source_target_pairs=\"{{3,0}}\""));
+              HasSubstr("_xla_send_recv_source_target_pairs={{3,0}}"));
   EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\""));
   HloInstruction* send = FindInstruction(module.get(), "send");
   EXPECT_THAT(send->ToString(),
-              HasSubstr("_xla_send_recv_source_target_pairs=\"{{3,0}}\""));
+              HasSubstr("_xla_send_recv_source_target_pairs={{3,0}}"));
   EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\""));
 
   HloInstruction* recv1 = FindInstruction(module.get(), "recv.1");
   EXPECT_EQ(recv1->channel_id().value(), 2);
   EXPECT_THAT(
       recv1->ToString(),
-      HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3}}\""));
+      HasSubstr("_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}"));
   EXPECT_THAT(recv1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\""));
   HloInstruction* recv_done1 = FindInstruction(module.get(), "recv-done.1");
   EXPECT_THAT(recv_done1->ToString(),
@@ -309,7 +309,7 @@
   HloInstruction* send1 = FindInstruction(module.get(), "send.1");
   EXPECT_THAT(
       send1->ToString(),
-      HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3}}\""));
+      HasSubstr("_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}"));
   EXPECT_THAT(send1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\""));
   HloInstruction* send_done1 = FindInstruction(module.get(), "send-done.1");
   EXPECT_THAT(send_done1->ToString(),
@@ -489,22 +489,22 @@
   EXPECT_EQ(recv->channel_id().value(), 1);
   EXPECT_THAT(
       recv->ToString(),
-      HasSubstr("_xla_send_recv_source_target_pairs=\"{{1,0},{2,1},{3,2}}\""));
+      HasSubstr("_xla_send_recv_source_target_pairs={{1,0},{2,1},{3,2}}"));
   EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\""));
   HloInstruction* send = FindInstruction(module.get(), "send");
   EXPECT_THAT(
       send->ToString(),
-      HasSubstr("_xla_send_recv_source_target_pairs=\"{{1,0},{2,1},{3,2}}\""));
+      HasSubstr("_xla_send_recv_source_target_pairs={{1,0},{2,1},{3,2}}"));
   EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\""));
 
   HloInstruction* recv1 = FindInstruction(module.get(), "recv.1");
   EXPECT_EQ(recv1->channel_id().value(), 2);
   EXPECT_THAT(recv1->ToString(),
-              HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,3}}\""));
+              HasSubstr("_xla_send_recv_source_target_pairs={{0,3}}"));
   EXPECT_THAT(recv1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\""));
   HloInstruction* send1 = FindInstruction(module.get(), "send.1");
   EXPECT_THAT(send1->ToString(),
-              HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,3}}\""));
+              HasSubstr("_xla_send_recv_source_target_pairs={{0,3}}"));
   EXPECT_THAT(send1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\""));
 }
 
diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc
index 859b6c9..72eeb36 100644
--- a/third_party/xla/xla/service/collective_pipeliner.cc
+++ b/third_party/xla/xla/service/collective_pipeliner.cc
@@ -445,7 +445,6 @@
                                                             ops.end());
   formatting_set.insert(source_ops.begin(), source_ops.end());
   std::vector<HloInstruction*> to_return;
-  absl::flat_hash_set<HloInstruction*> already_inserted;
   for (const HloInstruction* op : ops) {
     for (HloInstruction* operand : op->operands()) {
       if (!formatting_set.count(operand)) {
@@ -1380,7 +1379,6 @@
   // to still visit before visiting the HLO itself.
   std::vector<std::pair<const HloInstruction*, int>> stack(
       1, std::make_pair(instr, 0));
-  absl::flat_hash_set<const HloInstruction*> visited;
   while (!stack.empty()) {
     auto& current = stack.back();
     invariant_cache[std::get<0>(current)] = true;
@@ -2775,8 +2773,6 @@
         instruction, false, CollectivePipeliner::PipeliningDirection::kBackward,
         loop_analysis));
   }
-  absl::flat_hash_map<const HloInstruction*, HloInstruction*>
-      loop_cond_replacements;
   auto cond_builder =
       HloComputation::Builder(while_loop->while_condition()->name());
   HloInstruction* new_cond_param =
diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc
index 5492cbc..a924a9a 100644
--- a/third_party/xla/xla/service/collective_pipeliner_test.cc
+++ b/third_party/xla/xla/service/collective_pipeliner_test.cc
@@ -239,14 +239,14 @@
   EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
     // CHECK: HloModule
     // CHECK: %while_body
-    // CHECK:   %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{0,5},{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12}{{[}]}}"
+    // CHECK:   %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{0,5},{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12}{{[}]}}
     // CHECK:   %[[dus:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}})
     // CHECK:   %[[mul:.+]] = {{.+}} multiply({{.+}} %[[dus]], {{.+}} %[[dus]])
     // CHECK:   %[[dus2:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}})
     // CHECK:   ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus2]], {{.+}})
     // CHECK: }
     // CHECK: ENTRY %entry
-    // CHECK:   %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation="{{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}"
+    // CHECK:   %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}
     // CHECK:   %[[ds:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}})
     // CHECK:   %[[mul:.+]] = {{.+}} multiply({{.+}} %[[ds]], {{.+}} %[[ds]])
     // CHECK:   %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}})
@@ -315,14 +315,14 @@
   EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
     // CHECK: HloModule
     // CHECK: %while_body
-    // CHECK:   %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6},{0,5}{{[}]}}"
+    // CHECK:   %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6},{0,5}{{[}]}}
     // CHECK:   %[[dus:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}})
     // CHECK:   %[[mul:.+]] = {{.+}} multiply({{.+}} %[[dus]], {{.+}} %[[dus]])
     // CHECK:   %[[dus2:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}})
     // CHECK:   ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus2]], {{.+}})
     // CHECK: }
     // CHECK: ENTRY %entry
-    // CHECK:   %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation="{{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}"
+    // CHECK:   %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}
     // CHECK:   %[[ds:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}})
     // CHECK:   %[[mul:.+]] = {{.+}} multiply({{.+}} %[[ds]], {{.+}} %[[ds]])
     // CHECK:   %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}})
@@ -1507,13 +1507,13 @@
   XLA_VLOG_LINES(1, module->ToString());
   EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
   // CHECK: %while_body
-  // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12},{7,12}{{[}]}}"}
+  // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12},{7,12}{{[}]}}}
   // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp]], {{.+}})
   // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus]], {{.+}})
   // CHECK: ENTRY %entry
   // CHECK: %[[while:.+]] = {{.+}} while({{.+}})
   // CHECK: %[[gte:.+]] = {{.+}} get-tuple-element({{.+}} %[[while]]), index=1
-  // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation="{{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}"
+  // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}
   // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp2]], {{.+}})
   // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}})
   // CHECK: ROOT {{.+}} = {{.+}} get-tuple-element({{.+}} %[[tuple]]), index=1
@@ -1586,13 +1586,13 @@
   XLA_VLOG_LINES(1, module->ToString());
   EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
   // CHECK: %while_body
-  // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{7,12},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}{{[}]}}"}
+  // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{7,12},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}{{[}]}}}
   // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp]], {{.+}})
   // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus]], {{.+}})
   // CHECK: ENTRY %entry
   // CHECK: %[[while:.+]] = {{.+}} while({{.+}})
   // CHECK: %[[gte:.+]] = {{.+}} get-tuple-element({{.+}} %[[while]]), index=1
-  // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation="{{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}"
+  // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}
   // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp2]], {{.+}})
   // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}})
   // CHECK: ROOT {{.+}} = {{.+}} get-tuple-element({{.+}} %[[tuple]]), index=1
@@ -2426,7 +2426,7 @@
 
   EXPECT_EQ(recv1->channel_id(), send1->channel_id());
 
-  const char* kSourceTarget = "_xla_send_recv_source_target_pairs=\"{{3,0}}\"";
+  const char* kSourceTarget = "_xla_send_recv_source_target_pairs={{3,0}}";
   const char* kPeeledAttr = "_xla_other_attr=\"1\"";
   const char* kRotatedAttr = "_xla_other_attr=\"2\"";
   EXPECT_THAT(send1->ToString(), ::testing::HasSubstr(kSourceTarget));
diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD
index a97736d..3ed45fe 100644
--- a/third_party/xla/xla/service/cpu/BUILD
+++ b/third_party/xla/xla/service/cpu/BUILD
@@ -222,8 +222,7 @@
         ":ir_emission_utils",
         ":ir_emitter",
         ":ir_emitter2",
-        ":onednn_convolution_rewriter",
-        ":onednn_matmul_rewriter",
+        ":onednn_contraction_rewriter",
         ":onednn_ops_rewriter",
         ":parallel_task_assignment",
         ":simple_orc_jit",
@@ -589,6 +588,7 @@
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
         "@eigen_archive//:eigen3",
+        "@llvm-project//llvm:Core",
         "@llvm-project//llvm:ExecutionEngine",
         "@llvm-project//llvm:OrcJIT",
         "@llvm-project//llvm:OrcShared",
@@ -664,6 +664,28 @@
 )
 
 xla_cc_test(
+    name = "ir_emitter_test",
+    srcs = ["ir_emitter_test.cc"],
+    deps = [
+        ":ir_emitter",
+        ":ir_function",
+        ":target_machine_features_fake",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:buffer_assignment",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_ordering",
+        "//xla/service:hlo_parser",
+        "//xla/service:logical_buffer",
+        "//xla/tests:hlo_test_base",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+xla_cc_test(
     name = "ir_emitter2_test",
     srcs = ["ir_emitter2_test.cc"],
     deps = [
@@ -1850,17 +1872,19 @@
 )
 
 cc_library(
-    name = "onednn_matmul_rewriter",
-    srcs = ["onednn_matmul_rewriter.cc"],
+    name = "onednn_contraction_rewriter",
+    srcs = ["onednn_contraction_rewriter.cc"],
     hdrs = [
+        "onednn_contraction_rewriter.h",
+        "onednn_convolution.h",
         "onednn_matmul.h",
-        "onednn_matmul_rewriter.h",
         "//xla/tsl/util:onednn_util_hdrs",
     ],
     copts = tsl_copts(),
     deps = [
         ":backend_config_proto_cc",
         ":onednn_config_proto_cc",
+        ":onednn_convolution",
         ":onednn_matmul",
         ":onednn_memory_util",
         ":onednn_pattern_utils",
@@ -1876,6 +1900,7 @@
         "//xla/service:hlo_pass",
         "//xla/service:pattern_matcher",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/status:statusor",
         "@eigen_archive//:eigen3",
         "@local_tsl//tsl/platform:blocking_counter",
         "@local_tsl//tsl/platform:env",
@@ -1909,43 +1934,12 @@
 )
 
 cc_library(
-    name = "onednn_convolution_rewriter",
-    srcs = ["onednn_convolution_rewriter.cc"],
-    hdrs = ["onednn_convolution_rewriter.h"],
-    copts = tsl_copts(),
-    deps = [
-        ":backend_config_proto_cc",
-        ":onednn_config_proto_cc",
-        ":onednn_convolution",
-        ":onednn_memory_util",
-        ":onednn_util",
-        "//xla:executable_run_options",
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/evaluator:hlo_evaluator",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "//xla/service:pattern_matcher",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/status:statusor",
-        "@eigen_archive//:eigen3",
-        "@local_tsl//tsl/platform:blocking_counter",
-        "@local_tsl//tsl/platform:env",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:platform_port",
-    ] + mkl_deps(),
-)
-
-cc_library(
     name = "cpu_float_support",
     srcs = ["cpu_float_support.cc"],
     hdrs = ["cpu_float_support.h"],
     copts = tsl_copts(),
     deps = [
-        ":onednn_convolution_rewriter",
-        ":onednn_matmul_rewriter",
+        ":onednn_contraction_rewriter",
         "//xla/service:float_support",
     ],
 )
diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc
index 53be150..fc292c9 100644
--- a/third_party/xla/xla/service/cpu/cpu_compiler.cc
+++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc
@@ -205,8 +205,7 @@
 
 #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
 #include "xla/service/cpu/cpu_float_support.h"
-#include "xla/service/cpu/onednn_convolution_rewriter.h"
-#include "xla/service/cpu/onednn_matmul_rewriter.h"
+#include "xla/service/cpu/onednn_contraction_rewriter.h"
 #include "xla/service/cpu/onednn_ops_rewriter.h"
 #include "xla/service/simplify_fp_conversions.h"
 #endif
@@ -759,11 +758,10 @@
     if (debug_options.xla_allow_excess_precision()) {
       pipeline.AddPass<SimplifyFPConversions>();
     }
-    pipeline.AddPass<OneDnnConvolutionRewriter>();
-    pipeline.AddPass<OneDnnMatMulRewriter>(max_parallelism,
-                                           compile_options.thread_pool);
+    pipeline.AddPass<OneDnnContractionRewriter>(max_parallelism,
+                                                compile_options.thread_pool);
     // Run SimplifyFPConversions pass again to remove redundant Convert ops
-    // that may exist as a result of running OneDnnMatMulRewriter pass.
+    // that may exist as a result of running OneDnnContractionRewriter pass.
     if (debug_options.xla_allow_excess_precision()) {
       pipeline.AddPass<SimplifyFPConversions>();
     }
@@ -1273,11 +1271,17 @@
     cantFail((*jit)->AddModule(llvm::orc::ThreadSafeModule(
         std::move(llvm_module), std::move(llvm_context))));
 
+    auto mangle = [&](std::string_view name) {
+      llvm::SmallVector<char, 40> mangled;
+      llvm::Mangler::getNameWithPrefix(mangled, name, (*jit)->data_layout());
+      return std::string(mangled.begin(), mangled.end());
+    };
+
     // TODO(ezhulenev): We should be able to make it lazy on-demand, but today
     // we capture obj_files by reference and it leads to asan errors. Figure out
     // lifetime issues and move compilation to Thunk initialization stage.
     for (const auto& kernel : ir_emitter2.kernels()) {
-      if (auto sym = (*jit)->FindCompiledSymbol(kernel.name); !sym) {
+      if (auto s = (*jit)->FindCompiledSymbol(mangle(kernel.name)); !s) {
         return Internal("Failed to find compiled symbol for kernel %s",
                         kernel.name);
       }
@@ -1285,7 +1289,7 @@
 
     // Compile auxiliary comparator functions used by sort thunks.
     for (const auto& comparator : ir_emitter2.comparators()) {
-      if (auto sym = (*jit)->FindCompiledSymbol(comparator.name); !sym) {
+      if (auto s = (*jit)->FindCompiledSymbol(mangle(comparator.name)); !s) {
         return Internal("Failed to find compiled symbol for comparator %s",
                         comparator.name);
       }
@@ -1785,16 +1789,22 @@
     TF_ASSIGN_OR_RETURN(ThunkSequence thunks,
                         thunk_emitter.EmitEntryComputation(*module));
 
+    auto mangle = [&](std::string_view name) {
+      llvm::SmallVector<char, 40> mangled;
+      llvm::Mangler::getNameWithPrefix(mangled, name, (*jit)->data_layout());
+      return std::string(mangled.begin(), mangled.end());
+    };
+
     // Lookup all kernel functions by name in the loaded object file.
     for (const auto& kernel : ir_emitter2.kernels()) {
-      if (auto sym = (*jit)->FindCompiledSymbol(kernel.name); !sym) {
+      if (auto s = (*jit)->FindCompiledSymbol(mangle(kernel.name)); !s) {
         return Internal("Failed to find compiled symbol for kernel %s",
                         kernel.name);
       }
     }
 
     for (const auto& comparator : ir_emitter2.comparators()) {
-      if (auto sym = (*jit)->FindCompiledSymbol(comparator.name); !sym) {
+      if (auto s = (*jit)->FindCompiledSymbol(mangle(comparator.name)); !s) {
         return Internal("Failed to find compiled symbol for comparator %s",
                         comparator.name);
       }
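
For context on the two mangle lambdas above (and the matching FunctionRegistry::Mangle below): the JIT resolves compiled symbols by their mangled names, so a plain IR name must first receive the data layout's global prefix, a leading '_' on Mach-O-style layouts and nothing on ELF. A minimal standalone sketch of llvm::Mangler::getNameWithPrefix follows; the layout strings and the symbol name are chosen only for illustration.

#include <string>
#include <string_view>

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Mangler.h"
#include "llvm/Support/raw_ostream.h"

// Mirrors the mangle lambdas above: apply the data layout's global prefix.
static std::string Mangle(std::string_view name, const llvm::DataLayout& dl) {
  llvm::SmallVector<char, 40> mangled;
  llvm::Mangler::getNameWithPrefix(mangled, name, dl);
  return std::string(mangled.begin(), mangled.end());
}

int main() {
  // "m:o" selects Mach-O mangling, which prepends '_' to global names, while
  // "m:e" (ELF) leaves names untouched. Both layout strings are illustrative.
  llvm::DataLayout macho("e-m:o-i64:64-n8:16:32:64-S128");
  llvm::DataLayout elf("e-m:e-i64:64-n8:16:32:64-S128");

  llvm::outs() << Mangle("sort_comparator_0", macho) << "\n";  // _sort_comparator_0
  llvm::outs() << Mangle("sort_comparator_0", elf) << "\n";    // sort_comparator_0
  return 0;
}
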
diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc
index 816ad13..0bec735 100644
--- a/third_party/xla/xla/service/cpu/cpu_executable.cc
+++ b/third_party/xla/xla/service/cpu/cpu_executable.cc
@@ -37,7 +37,9 @@
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "unsupported/Eigen/CXX11/Tensor"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h"
+#include "llvm/IR/Mangler.h"
 #include "llvm/Support/Error.h"
 #include "xla/executable_run_options.h"
 #include "xla/hlo/ir/hlo_computation.h"
@@ -80,12 +82,18 @@
 
 FunctionRegistry::FunctionRegistry(SimpleOrcJIT* jit) : jit_(jit) {}
 
+std::string FunctionRegistry::Mangle(std::string_view name) {
+  llvm::SmallVector<char, 40> mangled;
+  llvm::Mangler::getNameWithPrefix(mangled, name, jit_->data_layout());
+  return std::string(mangled.begin(), mangled.end());
+}
+
 absl::StatusOr<FunctionRegistry::Kernel> FunctionRegistry::FindKernel(
     std::string_view name) {
   VLOG(3) << "Find host kernel with a name " << name;
 
   llvm::Expected<llvm::orc::ExecutorSymbolDef> sym =
-      jit_->FindCompiledSymbol(std::string(name));
+      jit_->FindCompiledSymbol(Mangle(name));
   if (!sym) {
     return absl::InvalidArgumentError(
         absl::StrCat("Can't resolve host kernel with a name ", name,
@@ -99,7 +107,7 @@
   VLOG(3) << "Find comparator with a name " << name;
 
   llvm::Expected<llvm::orc::ExecutorSymbolDef> sym =
-      jit_->FindCompiledSymbol(std::string(name));
+      jit_->FindCompiledSymbol(Mangle(name));
   if (!sym) {
     return absl::InvalidArgumentError(
         absl::StrCat("Can't resolve comparator with a name ", name,
@@ -175,6 +183,7 @@
       std::move(hlo_profile_index_map), std::move(assignment)));
 
   executable->jit_ = std::move(jit);
+  executable->jit_->DoneCompiling();
   executable->function_registry_ = FunctionRegistry(executable->jit_.get());
 
   TF_ASSIGN_OR_RETURN(executable->thunks_,
diff --git a/third_party/xla/xla/service/cpu/cpu_executable.h b/third_party/xla/xla/service/cpu/cpu_executable.h
index b129674..8c8883d 100644
--- a/third_party/xla/xla/service/cpu/cpu_executable.h
+++ b/third_party/xla/xla/service/cpu/cpu_executable.h
@@ -151,6 +151,8 @@
     absl::StatusOr<Comparator> FindComparator(std::string_view name) final;
 
    private:
+    std::string Mangle(std::string_view name);
+
     SimpleOrcJIT* jit_;
   };
 
diff --git a/third_party/xla/xla/service/cpu/cpu_float_support.cc b/third_party/xla/xla/service/cpu/cpu_float_support.cc
index 6914e65..c590716 100644
--- a/third_party/xla/xla/service/cpu/cpu_float_support.cc
+++ b/third_party/xla/xla/service/cpu/cpu_float_support.cc
@@ -17,8 +17,7 @@
 
 #include "xla/service/cpu/cpu_float_support.h"
 
-#include "xla/service/cpu/onednn_convolution_rewriter.h"
-#include "xla/service/cpu/onednn_matmul_rewriter.h"
+#include "xla/service/cpu/onednn_contraction_rewriter.h"
 
 namespace xla {
 namespace cpu {
@@ -28,10 +27,10 @@
     // oneDNN rewritable ops
     case HloOpcode::kDot:
       return LowPrecisionType() == BF16 &&
-             OneDnnMatMulRewriter::ShouldRewrite(&hlo);
+             OneDnnContractionRewriter::ShouldRewriteDot(&hlo, true);
     case HloOpcode::kConvolution:
       return LowPrecisionType() == BF16 &&
-             OneDnnConvolutionRewriter::ShouldRewrite(&hlo);
+             OneDnnContractionRewriter::ShouldRewriteConv(&hlo);
     // Collective ops.
     case HloOpcode::kAllGather:
     case HloOpcode::kAllReduce:
diff --git a/third_party/xla/xla/service/cpu/executable.proto b/third_party/xla/xla/service/cpu/executable.proto
index bca8a2c..d222660 100644
--- a/third_party/xla/xla/service/cpu/executable.proto
+++ b/third_party/xla/xla/service/cpu/executable.proto
@@ -17,7 +17,6 @@
 
 package xla.cpu;
 
-import "xla/service/cpu/xla_framework.proto";
 import "xla/service/hlo.proto";
 import "xla/xla.proto";
 
diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc
index d8bae80..1ee2f69 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter.cc
+++ b/third_party/xla/xla/service/cpu/ir_emitter.cc
@@ -451,6 +451,18 @@
   }
 }
 
+void IrEmitter::AttachInvariantLoadMetadataForLoad(llvm::LoadInst* load) const {
+  AttachInvariantLoadMetadataForLoad(load, hlo_module_config_);
+}
+
+/*static*/ void IrEmitter::AttachInvariantLoadMetadataForLoad(
+    llvm::LoadInst* load, const HloModuleConfig& config) {
+  if (config.debug_options().xla_llvm_enable_invariant_load_metadata()) {
+    load->setMetadata(llvm::LLVMContext::MD_invariant_load,
+                      llvm::MDNode::get(load->getContext(), /*MDs=*/{}));
+  }
+}
+
 absl::Status IrEmitter::HandleGetTupleElement(
     HloInstruction* get_tuple_element) {
   // A tuple is an array of pointers, one for each operand. Each pointer points
@@ -4073,12 +4085,8 @@
       GetBufferTableArgument(), b()->getPtrTy(), slice.index(), b());
   llvm::LoadInst* tempbuf_address_base =
       Load(b()->getPtrTy(), tempbuf_address_ptr);
-  if (hlo_module_config_.debug_options()
-          .xla_llvm_enable_invariant_load_metadata()) {
-    tempbuf_address_base->setMetadata(
-        llvm::LLVMContext::MD_invariant_load,
-        llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
-  }
+
+  AttachInvariantLoadMetadataForLoad(tempbuf_address_base);
   AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
   AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size());
 
diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h
index 0e6b1a3..d2c94a9 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter.h
+++ b/third_party/xla/xla/service/cpu/ir_emitter.h
@@ -760,8 +760,14 @@
   // result with the dereferenceable bytes required by the shape / buffer size.
   void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                             const Shape& shape);
-  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
-                                            int64_t buffer_size);
+  static void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
+                                                   int64_t buffer_size);
+
+  // Given a load instruction, annotate the load's result with the invariant
+  // load metadata.
+  void AttachInvariantLoadMetadataForLoad(llvm::LoadInst* load) const;
+  static void AttachInvariantLoadMetadataForLoad(llvm::LoadInst* load,
+                                                 const HloModuleConfig& config);
 
   // Calculate the alignment of a buffer allocated for a given shape.
   int MinimumAlignmentForShape(const Shape& shape);
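
The AttachInvariantLoadMetadataForLoad helper added to ir_emitter.{cc,h} above amounts to tagging a load with an empty !invariant.load node when xla_llvm_enable_invariant_load_metadata is set, telling LLVM the loaded value never changes and can be freely hoisted or CSE'd. A standalone sketch of the underlying LLVM API, with made-up function and value names, illustrative only:

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::LLVMContext ctx;
  llvm::Module module("invariant_load_demo", ctx);
  llvm::IRBuilder<> b(ctx);

  // ptr @load_buffer_ptr(ptr %slot): reads a pointer out of a buffer-table slot.
  auto* fn_ty =
      llvm::FunctionType::get(b.getPtrTy(), {b.getPtrTy()}, /*isVarArg=*/false);
  auto* fn = llvm::Function::Create(
      fn_ty, llvm::GlobalValue::ExternalLinkage, "load_buffer_ptr", module);
  fn->getArg(0)->setName("slot");
  b.SetInsertPoint(llvm::BasicBlock::Create(ctx, "entry", fn));

  llvm::LoadInst* load = b.CreateLoad(b.getPtrTy(), fn->getArg(0), "buffer_ptr");
  // What the helper does when the debug option is on: an empty metadata node
  // under !invariant.load marks the loaded value as never changing for the
  // lifetime of the program.
  load->setMetadata(llvm::LLVMContext::MD_invariant_load,
                    llvm::MDNode::get(ctx, /*MDs=*/{}));
  b.CreateRet(load);

  // Prints roughly: %buffer_ptr = load ptr, ptr %slot, align 8, !invariant.load !0
  module.print(llvm::outs(), /*AAW=*/nullptr);
  return 0;
}
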
diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc
index eeac5bc..2e64bea 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter2.cc
+++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc
@@ -40,6 +40,7 @@
 #include "llvm/IR/CallingConv.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
@@ -47,7 +48,7 @@
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Value.h"
-#include "llvm/Support/Casting.h"
+#include "llvm/Support/CodeGen.h"
 #include "xla/cpu_function_runtime.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -564,6 +565,10 @@
                           /*is_top_level_computation=*/true, schedule,
                           /*allow_reassociation=*/false));
 
+  // Generate unwind information so that GDB can crawl through the stack frames
+  // created by the JIT compiled code.
+  comparator_function->setUWTableKind(llvm::UWTableKind::Default);
+
   return comparators_.emplace_back(
       ComparatorInfo{comparator_function->getName().str()});
 }
@@ -710,6 +715,15 @@
   // emit metadata to allow LLVM to use that information for optimization.
   llvm_ir::SetAlignmentMetadataForLoad(data, cpu_function_runtime::MinAlign());
 
+  // All buffer pointers passed to host kernels are expected to be
+  // dereferenceable.
+  IrEmitter::AttachDereferenceableMetadataForLoad(data, ByteSizeOf(shape));
+
+  // All buffer pointers passed to host kernels are expected to be invariant
+  // over the whole program. Note that the metadata is attached only to the
+  // loads of buffer pointers, not to loads of the buffers' contents.
+  AttachInvariantLoadMetadataForLoad(data);
+
   return llvm_ir::IrArray(data, llvm_ir::ShapeToIrType(shape, module_), shape);
 }
 
@@ -791,11 +805,15 @@
     }
   }
 
-  // Create a kernel function with HostKernel API.
-  llvm::Function* function = llvm::dyn_cast<llvm::Function>(
-      module_->getOrInsertFunction(name, KernelFunctionTy(ctx)).getCallee());
+  // Create a kernel function with HostKernel API. We use external linkage
+  // because we'll be resolving this function from the XLA runtime.
+  llvm::Function* function = llvm::Function::Create(
+      KernelFunctionTy(ctx), llvm::GlobalValue::ExternalLinkage, name, module_);
   function->setCallingConv(llvm::CallingConv::C);
-  function->setDoesNotThrow();
+
+  // Generate unwind information so that GDB can crawl through the stack frames
+  // created by the JIT compiled code.
+  function->setUWTableKind(llvm::UWTableKind::Default);
 
   // Set prefer-vector-width attribute to allow LLVM to use wider vector
   // registers (by default LLVM uses at most 256-bit registers).
@@ -1027,4 +1045,17 @@
   return se::ThreadDim();
 }
 
+// This is a convenience function taken from IrEmitter; it uses the module_
+// class field. If more functions end up using module_, we should consider
+// refactoring (as we did for compute_function_ and builder_).
+int64_t IrEmitter2::ByteSizeOf(const Shape& shape) const {
+  return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
+}
+
+void IrEmitter2::AttachInvariantLoadMetadataForLoad(
+    llvm::LoadInst* instr) const {
+  nested_ir_emitter_->AttachInvariantLoadMetadataForLoad(instr,
+                                                         hlo_module_.config());
+}
+
 }  // namespace xla::cpu
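
On the setUWTableKind(llvm::UWTableKind::Default) calls added above: kernels and comparators never throw, but without unwind tables debuggers and profilers cannot walk through JIT-compiled frames, so the attribute is requested explicitly. A standalone sketch, with a made-up function name, showing the attribute being set on an externally visible function:

#include "llvm/IR/CallingConv.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::LLVMContext ctx;
  llvm::Module module("uwtable_demo", ctx);
  llvm::IRBuilder<> b(ctx);

  // A trivial kernel-like function with external linkage so a JIT could
  // resolve it by name later (mirrors the Function::Create call above).
  auto* fn_ty = llvm::FunctionType::get(b.getVoidTy(), /*isVarArg=*/false);
  auto* fn = llvm::Function::Create(
      fn_ty, llvm::GlobalValue::ExternalLinkage, "demo_kernel", module);
  fn->setCallingConv(llvm::CallingConv::C);

  // Request unwind tables; this prints as "uwtable" in the function's
  // attribute group and lets stack walkers unwind through JIT-compiled frames.
  fn->setUWTableKind(llvm::UWTableKind::Default);

  b.SetInsertPoint(llvm::BasicBlock::Create(ctx, "entry", fn));
  b.CreateRetVoid();

  module.print(llvm::outs(), nullptr);
  return 0;
}
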
diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.h b/third_party/xla/xla/service/cpu/ir_emitter2.h
index c998840..a205e91 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter2.h
+++ b/third_party/xla/xla/service/cpu/ir_emitter2.h
@@ -28,6 +28,7 @@
 #include "absl/types/span.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Value.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -240,6 +241,13 @@
 
   bool fast_min_max() const;
 
+  // Returns the size in bytes of a buffer holding the given shape.
+  int64_t ByteSizeOf(const Shape& shape) const;
+
+  // Given a load instruction, annotate the load's result with the invariant
+  // load metadata.
+  void AttachInvariantLoadMetadataForLoad(llvm::LoadInst* instr) const;
+
   const HloModule& hlo_module_;
   llvm::Module* module_;
 
diff --git a/third_party/xla/xla/service/cpu/ir_emitter2_test.cc b/third_party/xla/xla/service/cpu/ir_emitter2_test.cc
index 539facd..a70241b 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter2_test.cc
+++ b/third_party/xla/xla/service/cpu/ir_emitter2_test.cc
@@ -85,45 +85,45 @@
   ASSERT_TRUE(*RunFileCheck(llvm_ir::DumpToString(module.get()), R"(
     CHECK: define ptr @test(ptr %0) #0 {
 
-    CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 0
-    CHECK:      getelementptr inbounds %SE_HOST_KernelThreadDim, {{.*}} i32 0
-    CHECK:      getelementptr inbounds %SE_HOST_KernelThreadDim, {{.*}} i32 1
-    CHECK:      getelementptr inbounds %SE_HOST_KernelThreadDim, {{.*}} i32 2
+    CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 0
+    CHECK:      getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 0
+    CHECK:      getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 1
+    CHECK:      getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 2
     CHECK:      load i64
     CHECK:      load i64
     CHECK:      load i64
 
-    CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 1
-    CHECK:      getelementptr inbounds %SE_HOST_KernelThread, {{.*}} i32 0
-    CHECK:      getelementptr inbounds %SE_HOST_KernelThread, {{.*}} i32 1
-    CHECK:      getelementptr inbounds %SE_HOST_KernelThread, {{.*}} i32 2
+    CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 1
+    CHECK:      getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 0
+    CHECK:      getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 1
+    CHECK:      getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 2
     CHECK:      load i64
     CHECK:      load i64
     CHECK:      load i64
 
-    CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3
+    CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3
     CHECK:      load ptr
     CHECK:      getelementptr %SE_HOST_KernelArg, {{.*}} i32 0, i32 0
-    CHECK:      %[[ARG0:.+]] = load ptr, {{.*}}, !align ![[ALIGNMENT:.+]]
+    CHECK:      %[[ARG0:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0:.+]], !dereferenceable ![[DEREF_BYTES:.*]], !align ![[ALIGNMENT:.+]]
 
-    CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3
+    CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3
     CHECK:      load ptr
     CHECK:      getelementptr %SE_HOST_KernelArg, {{.*}} i32 1, i32 0
-    CHECK:      %[[ARG1:.+]] = load ptr, {{.*}}, !align ![[ALIGNMENT]]
+    CHECK:      %[[ARG1:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]]
 
-    CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3
+    CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3
     CHECK:      load ptr
     CHECK:      getelementptr %SE_HOST_KernelArg, {{.*}} i32 2, i32 0
-    CHECK:      %[[ARG2:.+]] = load ptr, {{.*}}, !align ![[ALIGNMENT]]
+    CHECK:      %[[ARG2:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]]
 
-    CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3
+    CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3
     CHECK:      load ptr
     CHECK:      getelementptr %SE_HOST_KernelArg, {{.*}} i32 3, i32 0
-    CHECK:      %[[ARG3:.+]] = load ptr, {{.*}}, !align ![[ALIGNMENT]]
+    CHECK:      %[[ARG3:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]]
 
     CHECK-NEXT: %[[PTR0:.+]] = getelementptr inbounds float, ptr %[[ARG0]]
     CHECK:      load float, ptr %[[PTR0]], align 4,
-    CHECK-SAME:                            !invariant.load ![[SCOPE0:.+]],
+    CHECK-SAME:                            !invariant.load ![[SCOPE0]],
     CHECK-SAME:                            !noalias ![[SCOPE1:.+]]
 
     CHECK-NEXT: %[[PTR1:.+]] = getelementptr inbounds float, ptr %[[ARG1]]
@@ -142,6 +142,8 @@
     CHECK:      ret ptr null
     CHECK: }
 
+    CHECK: #0 = { uwtable "frame-pointer"="all" "prefer-vector-width"="256" }
+    CHECK-DAG: ![[DEREF_BYTES]] = !{i64 32}
     CHECK-DAG: ![[ALIGNMENT]] = !{i64 16}
     CHECK-DAG: ![[SCOPE0]] = !{}
     CHECK-DAG: ![[SCOPE1]] = !{![[RES0:.+]], ![[RES1:.+]]}
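
The updated expectations also require `!dereferenceable` and `!align` nodes on the kernel-argument pointer loads. A hedged sketch of how such annotations are produced with the LLVM C++ API (illustrative helper, not the emitter's actual code; the byte counts are placeholders):

static void AnnotatePointerLoad(llvm::LoadInst* load, uint64_t deref_bytes,
                                uint64_t align_bytes) {
  llvm::LLVMContext& ctx = load->getContext();
  llvm::Type* i64 = llvm::Type::getInt64Ty(ctx);
  // !dereferenceable and !align each carry a single i64 payload.
  llvm::Metadata* deref_md = llvm::ConstantAsMetadata::get(
      llvm::ConstantInt::get(i64, deref_bytes));
  llvm::Metadata* align_md = llvm::ConstantAsMetadata::get(
      llvm::ConstantInt::get(i64, align_bytes));
  load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
                    llvm::MDNode::get(ctx, deref_md));
  load->setMetadata(llvm::LLVMContext::MD_align,
                    llvm::MDNode::get(ctx, align_md));
}
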
diff --git a/third_party/xla/xla/service/cpu/ir_emitter_test.cc b/third_party/xla/xla/service/cpu/ir_emitter_test.cc
new file mode 100644
index 0000000..7102d20
--- /dev/null
+++ b/third_party/xla/xla/service/cpu/ir_emitter_test.cc
@@ -0,0 +1,124 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/cpu/ir_emitter.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/Casting.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/service/cpu/ir_function.h"
+#include "xla/service/cpu/target_machine_features_fake.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_ordering.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/logical_buffer.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla::cpu {
+namespace {
+
+using IrEmitterTest = HloTestBase;
+
+static std::pair<llvm::Function*, llvm::BasicBlock*> CreateFunction(
+    llvm::LLVMContext& context, llvm::Module* module, llvm::IRBuilder<>* b) {
+  llvm::PointerType* ptrtype = llvm::PointerType::getUnqual(context);
+  llvm::FunctionType* ftype = llvm::FunctionType::get(ptrtype, ptrtype, false);
+
+  llvm::Function* function = llvm::dyn_cast<llvm::Function>(
+      module->getOrInsertFunction("func2", ftype).getCallee());
+
+  llvm::BasicBlock* return_block =
+      llvm::BasicBlock::Create(context, "", function);
+  b->SetInsertPoint(return_block);
+  [[maybe_unused]] llvm::ReturnInst* ret = b->CreateRet(
+      llvm::ConstantPointerNull::get(llvm::PointerType::getUnqual(context)));
+
+  return std::make_pair(function, return_block);
+}
+
+TEST_F(IrEmitterTest, ComputeFuncStack) {
+  llvm::LLVMContext context;
+  auto module = std::make_unique<llvm::Module>("test", context);
+
+  const char* hlo_text = R"(
+    HloModule m
+    ENTRY main {
+      ROOT %zero = f32[] constant(0)
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text));
+  const HloInstruction* zero = FindInstruction(hlo.get(), "zero");
+  ASSERT_NE(zero, nullptr);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<BufferAssignment> buffer_assignment,
+      BufferAssigner::Run(
+          hlo.get(), std::make_unique<DependencyHloOrdering>(hlo.get()),
+          backend().compiler()->BufferSizeBytesFunction(),
+          [](LogicalBuffer::Color) { return /*alignment=*/1; }));
+
+  TargetMachineFeaturesWithFakeAlignmentLogic target_machine(
+      [](int64_t size) { return 1; });
+
+  IrEmitter ir_emitter(nullptr, *hlo, *buffer_assignment, module.get(), {}, {},
+                       {}, &target_machine, false);
+
+  llvm::IRBuilder<>* b = ir_emitter.b();
+  ASSERT_NE(b, nullptr);
+
+  const std::pair<llvm::Function*, llvm::BasicBlock*> fb =
+      CreateFunction(context, module.get(), b);
+
+  llvm::Function* function = fb.first;
+  llvm::BasicBlock* return_block = fb.second;
+
+  ASSERT_NE(function, nullptr);
+  ASSERT_NE(return_block, nullptr);
+
+  const auto funcname = "func1";
+  const auto linkagetype = llvm::GlobalValue::LinkageTypes::ExternalLinkage;
+  const HloModuleConfig module_config;
+  ir_emitter.PushComputeFunction(funcname, linkagetype, module_config,
+                                 module.get(), 0);
+  ASSERT_EQ(ir_emitter.compute_function()->function()->getName().str(),
+            funcname);
+
+  ir_emitter.PushComputeFunction(b, module.get(), 0, function, nullptr,
+                                 return_block);
+  ASSERT_EQ(ir_emitter.compute_function()->function(), function);
+
+  ir_emitter.PopComputeFunction();
+  ASSERT_EQ(ir_emitter.compute_function()->function()->getName().str(),
+            funcname);
+
+  ir_emitter.PopComputeFunction();
+}
+
+}  // namespace
+}  // namespace xla::cpu
diff --git a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc
new file mode 100644
index 0000000..19122b3
--- /dev/null
+++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc
@@ -0,0 +1,1258 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
+
+#define EIGEN_USE_THREADS
+
+#include "xla/service/cpu/onednn_contraction_rewriter.h"
+
+#include "xla/executable_run_options.h"
+#include "xla/hlo/evaluator/hlo_evaluator.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/cpu/backend_config.pb.h"
+#include "xla/service/cpu/onednn_config.pb.h"
+#include "xla/service/cpu/onednn_convolution.h"
+#include "xla/service/cpu/onednn_matmul.h"
+#include "xla/service/cpu/onednn_memory_util.h"
+#include "xla/service/cpu/onednn_pattern_utils.h"
+#include "xla/service/cpu/onednn_util.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/status_macros.h"
+#include "xla/tsl/util/onednn_threadpool.h"
+#include "tsl/platform/logging.h"  // IWYU pragma: keep
+
+namespace xla {
+namespace cpu {
+
+namespace {
+namespace m = match;
+namespace pu = ::xla::cpu::onednn_pattern_utils_internal;
+
+inline absl::Status ValidateDotDimensionNumbers(
+    const DotDimensionNumbers& dim_numbers) {
+  // Checks some invariants that do not hold in general, but DotDecomposer
+  // should have established for us.
+  TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
+  std::vector<int64_t> batch_dim_numbers(
+      dim_numbers.lhs_batch_dimensions_size());
+  absl::c_iota(batch_dim_numbers, 0);
+  TF_RET_CHECK(
+      absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
+  TF_RET_CHECK(
+      absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
+  return absl::OkStatus();
+}
+
+// Whether the element type of instr is compatible with oneDNN kernels.
+// TODO(intel-tf): Restrict compatible types based on instruction kind.
+inline bool CompatibleElementType(const HloInstruction* instr) {
+  PrimitiveType element_type = instr->shape().element_type();
+  return element_type == BF16 || element_type == F32 || element_type == F16;
+}
+
+inline bool IsRowMajor(const Shape& shape) {
+  return LayoutUtil::IsMonotonicWithDim0Major(shape.layout());
+}
+
+template <typename Pattern>
+inline auto BitcastWithReshapeSemantics(HloInstruction** bitcast,
+                                        Pattern pattern) {
+  // TODO(intel-tf): Add a stronger condition that Bitcast does not have
+  // transpose semantics. Some HLO passes replace Transpose with Bitcast. Here
+  // the layouts are checked to be row-major since the current pass runs after
+  // layout assignment and oneDNN matmul is enabled for row-major layouts.
+  auto is_reshape = [](const HloInstruction* instr) -> bool {
+    if (!instr) return false;
+    auto input_shape = instr->operand(0)->shape();
+    auto output_shape = instr->shape();
+    bool is_same_type = ShapeUtil::SameElementType(input_shape, output_shape);
+    bool has_equal_num_elems = ShapeUtil::ElementsIn(input_shape) ==
+                               ShapeUtil::ElementsIn(output_shape);
+    bool has_rowmajor_layout =
+        IsRowMajor(input_shape) && IsRowMajor(output_shape);
+    return is_same_type && has_equal_num_elems && has_rowmajor_layout;
+  };
+  return m::Bitcast(bitcast, pattern).WithPredicate(is_reshape);
+}
+
+template <typename Pattern>
+auto ElementwiseSafeIntermediates(HloInstruction** instr,
+                                  HloInstruction** optional_bitcast,
+                                  Pattern pattern) {
+  return m::AnyOf<HloInstruction>(
+      m::Broadcast(instr, pattern.WithOneUser()),
+      m::Slice(instr, pattern.WithOneUser()),
+      m::Bitcast(instr, pattern.WithOneUser()),
+      m::Reshape(instr, pattern.WithOneUser()),
+      pu::SupportedConvert(instr, pattern.WithOneUser()),
+      pu::SupportedConvert(instr, BitcastWithReshapeSemantics(
+                                      optional_bitcast, pattern.WithOneUser())),
+      pattern);
+}
+
+inline auto OneDnnMatmulInstr(HloInstruction** instr) {
+  return m::CustomCall(instr, {"__onednn$matmul"});
+}
+
+inline auto ConvertBF16ToF32(HloInstruction** instr) {
+  return m::Convert(m::Op(instr).WithElementType(PrimitiveType::BF16))
+      .WithElementType(PrimitiveType::F32);
+}
+
+inline auto BcastConstScalar(HloInstruction** instr, double value) {
+  return m::Broadcast(instr, m::ConstantScalar(value));
+}
+
+inline auto BcastConstScalar(double value) {
+  return BcastConstScalar(nullptr, value);
+}
+
+inline auto BcastConvertConstScalar(double value) {
+  return m::Broadcast(pu::OptionalConvert(m::ConstantScalar(value)));
+}
+
+inline bool IsBatchDot(const HloInstruction& instr) {
+  if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
+    return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
+  }
+  return false;
+}
+
+auto ConstScalarNear(double value) {
+  return m::ConstantScalar().WithPredicate(
+      [expected = value](const HloInstruction* instr) {
+        // Not a very robust floating-point comparison, but good enough for our
+        // purposes.
+        std::optional<double> actual =
+            static_cast<const HloConstantInstruction*>(instr)
+                ->literal()
+                .GetAsDouble({});
+        if (!actual.has_value()) return false;
+        double epsilon;
+        switch (instr->shape().element_type()) {
+          case F16:
+            epsilon = 128 * std::numeric_limits<Eigen::half>::epsilon();
+            break;
+          case BF16:
+            epsilon = 128 * std::numeric_limits<bfloat16>::epsilon();
+            break;
+          case F32:
+            epsilon = 128 * std::numeric_limits<float>::epsilon();
+            break;
+          case F64:
+            epsilon = 128 * std::numeric_limits<double>::epsilon();
+            break;
+          default:
+            return false;
+        }
+        return abs(*actual - expected) < (abs(*actual + expected) * epsilon);
+      });
+}
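
In other words, the predicate accepts a constant a as "near" the expected value e when

|a - e| < |a + e| \cdot 128\,\varepsilon_{\text{type}}

where \varepsilon_{\text{type}} is the machine epsilon of the constant's element type and 128 is the slack factor chosen in the code above.
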
+
+bool IsScalar(const HloInstruction* instr) {
+  return ShapeUtil::IsEffectiveScalar(instr->shape());
+}
+
+std::optional<float> GetConstantValueAsFloat32(const HloInstruction* inst) {
+  if (!IsScalar(inst)) {
+    return std::nullopt;
+  }
+  switch (inst->shape().element_type()) {
+    case F16:
+      return inst->literal().GetFirstElement<half>();
+    case BF16:
+      return inst->literal().GetFirstElement<bfloat16>();
+    case F32:
+      return inst->literal().GetFirstElement<float>();
+    default:
+      return std::nullopt;
+  }
+}
+
+inline auto BcastConstScalarNear(double value) {
+  return m::Broadcast(ConstScalarNear(value));
+}
+
+// The associativity and commutativity of multiply result in various patterns
+// for an equivalent computation. This function tries to capture most of the
+// variations for a computation a * b * c. For example, patterns could be any
+// of (a * b) * c or a * (b * c), along with the variations resulting from
+// commutativity.
+template <typename PatternA, typename PatternB, typename PatternC>
+inline auto MultiplyMultiplyAnyOrder(PatternA a, PatternB b, PatternC c) {
+  return m::AnyOf<HloInstruction>(
+      m::MultiplyAnyOrder(a, m::MultiplyAnyOrder(b, c)),
+      m::MultiplyAnyOrder(b, m::MultiplyAnyOrder(a, c)),
+      m::MultiplyAnyOrder(c, m::MultiplyAnyOrder(a, b)));
+}
+
+auto GELUActivation(HloInstruction* instr, HloInstruction** src) {
+  // Attempt to match GELU_TANH activation or GELU_ERF activation
+  // (https://arxiv.org/abs/1606.08415), where:
+  // gelu_tanh(x) = x * cdf(x)
+  // cdf(x) = 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
+  //                     -------------errf_approximate------------
+  //
+  // gelu_erf(x) = x * cdf(x)
+  // cdf(x) = 0.5 * (1 + erf(x / sqrt(2)))
+  //                     --errf_exact--
+
+  HloInstruction* errf;
+
+  // The expression 0.5 * x * (1 + errf) is the common pattern for both the
+  // exact and approximate GELU activations.
+  auto common_pattern = MultiplyMultiplyAnyOrder(
+      BcastConstScalar(0.5), m::Op(src),
+      m::AddAnyOrder(BcastConstScalar(1.0), m::Op(&errf).WithOneUser()));
+
+  bool matched = Match(instr, common_pattern);
+  if (matched) {
+    // The subexpression 0.044715 * x**3 appears in GELU approximate activation.
+    // However, it is often optimized by other HLO passes into an expression of
+    // 0.044715 * x * (x * x). Since there are three consecutive multiplies,
+    // there could be a large number of patterns. We try to capture some of
+    // those:
+    //
+    //      1. (0.044715 * x) * x * x
+    //      2. 0.044715 * (x * x) * x
+    //
+    // Note each of the above could in turn have various patterns due to
+    // associativity and commutativity properties of multiply.
+    auto subexpr_pattern = m::AnyOf<HloInstruction>(
+        MultiplyMultiplyAnyOrder(
+            m::MultiplyAnyOrder(BcastConstScalarNear(0.044715),
+                                m::Op().Is(*src)),
+            m::Op().Is(*src), m::Op().Is(*src)),
+        MultiplyMultiplyAnyOrder(
+            BcastConstScalarNear(0.044715),
+            m::Multiply(m::Op().Is(*src), m::Op().Is(*src)), m::Op().Is(*src)));
+
+    auto errf_apprx_pattern =
+        m::Tanh(m::MultiplyAnyOrder(
+                    BcastConstScalarNear(sqrt(M_2_PI)),
+                    m::AddAnyOrder(m::Op().Is(*src), subexpr_pattern)
+                        .WithOneUser()))
+            .WithOneUser();
+
+    HloInstruction* erf;
+    auto errf_exact_pattern =
+        m::Op(&erf)
+            .WithOpcode(HloOpcode::kErf)
+            .WithOperand(
+                0, m::MultiplyAnyOrder(m::Op(src),
+                                       m::AnyOf<HloInstruction>(
+                                           BcastConstScalarNear(0.707106769),
+                                           BcastConstScalarNear(0.70703125),
+                                           BcastConstScalarNear(0.707182348)))
+                       .WithOneUser())
+            .WithOneUser();
+
+    if (Match(errf, errf_apprx_pattern)) {
+      // Matched Gelu-approximate pattern
+      return OneDnnFusionConfig::GELU_TANH;
+    } else if (Match(errf, errf_exact_pattern)) {
+      // Matched Gelu-exact pattern
+      return OneDnnFusionConfig::GELU_ERF;
+    }
+  }
+  return OneDnnFusionConfig::UNDEFINED;
+}
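
For reference, the two activations being matched are the standard formulas restated from the comment above:

\mathrm{gelu\_tanh}(x) = 0.5\,x\bigl(1 + \tanh(\sqrt{2/\pi}\,(x + 0.044715\,x^3))\bigr),\qquad
\mathrm{gelu\_erf}(x) = 0.5\,x\bigl(1 + \mathrm{erf}(x/\sqrt{2})\bigr)

with \sqrt{2/\pi} \approx 0.79788 and 1/\sqrt{2} \approx 0.70711, which is why the exact-GELU branch accepts broadcasted constants near 0.7071 (the three listed values look like that constant rounded at different precisions; this is an inference, not something stated in the code).
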
+
+// OneDNN matmul can fuse an add operation with automatic broadcasting along
+// the addend's dimensions that are 1s. When compatible, the Broadcast can be
+// replaced by a Bitcast, which is much cheaper. Computes the new shape for the
+// Bitcast.
+absl::StatusOr<Shape> AdjustBiasShape(const HloInstruction* broadcast_instr,
+                                      const Shape& dot_shape) {
+  if (broadcast_instr->opcode() != HloOpcode::kBroadcast) {
+    return absl::InvalidArgumentError(
+        "Hlo instruction is not a Broadcast insruction.");
+  }
+  auto bcast = Cast<HloBroadcastInstruction>(broadcast_instr);
+  Shape new_shape = bcast->shape();
+  // The Broadcast instruction has a "dimensions" parameter, along which its
+  // input's dimensions should not change. For example,
+  //      dot = f32[3,4,5,6] dot(...)
+  //      arg = f32[3,6]{1,0} parameter(0)
+  //      broad = f32[3,4,5,6]{3,2,1,0} broadcast(arg), dimensions={0,3}
+  //      add = f32[3,4,5,6]{3,2,1,0} add(dot, broad)
+  // can be replaced with the following
+  //      arg = f32[3,6]{1,0} parameter(0)
+  //      bitcast = f32[3,1,1,6]{3,2,1,0} bitcast(arg)
+  //      fused = f32[3,4,5,6]{3,2,1,0} custom-call(..., bitcast)
+  auto kept_dimensions = bcast->dimensions();
+  for (int i = 0; i < new_shape.rank(); i++) {
+    if (!absl::c_linear_search(kept_dimensions, i)) {
+      new_shape.set_dimensions(i, 1);
+    }
+  }
+
+  // If rank(new_shape) > rank(dot), extra dimensions with value = 1 can be
+  // deleted from the new_shape.
+  int64_t rank_difference = new_shape.rank() - dot_shape.rank();
+  auto new_dims = new_shape.dimensions();
+  std::vector<int64_t> dims_to_delete;
+  for (int i = 0; i < rank_difference; ++i) {
+    if (new_dims[i] == 1) {
+      dims_to_delete.push_back(i);
+    }
+  }
+  new_shape = ShapeUtil::DeleteDimensions(dims_to_delete, new_shape);
+
+  // New shape for bias should satisfy the condition:
+  //   rank(new_shape) <= rank(dot).
+  if (new_shape.rank() > dot_shape.rank()) {
+    return absl::CancelledError(
+        "Bias shape could not be adjusted for a fusion.");
+  }
+
+  return new_shape;
+};
+
+inline bool IsOperandFusible(HloInstruction* operand, HloInstruction* dot) {
+  // Check if the operand's shape is compatible with matmul for fusion.
+  // An operand is fusible if
+  //    1. rank(operand) <= rank(dot), and
+  //    2. walking backwards from the last dimension, each dimension size of
+  //       the operand is either 1 or the same as the corresponding dot
+  //       dimension.
+  auto operand_dims = operand->shape().dimensions();
+  auto dot_dims = dot->shape().dimensions();
+  if (operand_dims.size() > dot_dims.size()) return false;
+  int operand_idx = operand_dims.size() - 1;
+  int dot_idx = dot_dims.size() - 1;
+  for (; operand_idx >= 0; --operand_idx, --dot_idx) {
+    if (operand_dims[operand_idx] != 1 &&
+        operand_dims[operand_idx] != dot_dims[dot_idx])
+      return false;
+  }
+  return true;
+}
+
+template <typename Pattern>
+inline auto OptionalConvertAndBitcast(HloInstruction** optional_convert,
+                                      HloInstruction** optional_bitcast,
+                                      Pattern pattern) {
+  // Checks the presence of some intermediate operations that can be moved /
+  // folded to allow dot fusion with add.
+  // Try to match either of the following:
+  //   1. pattern-root -> bf16/f16-to-fp32 convert -> bitcast
+  //   2. pattern-root -> bf16/f16-to-fp32 convert
+  //   3. pattern-root -> bitcast
+  //   4. pattern-root
+  auto common = m::AnyOf<HloInstruction>(
+      pu::SupportedConvert(optional_convert, std::move(pattern).WithOneUser())
+          .WithElementType(PrimitiveType::F32),
+      std::move(pattern).WithOneUser());
+  return m::AnyOf<HloInstruction>(
+      BitcastWithReshapeSemantics(optional_bitcast, common), common);
+}
+
+}  // namespace
+
+bool OneDnnContractionRewriter::ShouldRewriteDot(
+    const HloInstruction* dot_instr, bool before_layout_assignment) {
+  // We currently do not rewrite instructions that have control dependencies.
+  if (dot_instr->HasControlDependencies()) return false;
+  if (!IsSupportedType(dot_instr->shape().element_type())) return false;
+  if (dot_instr->operands().size() != 2) return false;
+
+  // Currently, we rewrite when the data type is F32 or BF16. Note that we do
+  // not need to check that the contraction dim-sizes of the operands are
+  // equal; the HLO verifier already does that. We do, however, need to check
+  // that the contraction is over only 1 dimension (a.k.a. the K dimension in
+  // matrix-multiplication parlance). We also require that the batch
+  // dimensions of the operands match.
+  const Shape& lhs_shape = dot_instr->operand(0)->shape();
+  const Shape& rhs_shape = dot_instr->operand(1)->shape();
+  const Shape& output_shape = dot_instr->shape();
+  // None of the operands and result should be ZeroElementArray.
+  if (ShapeUtil::IsZeroElementArray(lhs_shape) ||
+      ShapeUtil::IsZeroElementArray(rhs_shape) ||
+      ShapeUtil::IsZeroElementArray(output_shape)) {
+    return false;
+  }
+  // OneDNN only supports rank <= kOneDnnMaxNDims and singular non-contracting
+  // dimensions. We should not rewrite if any of these conditions are violated.
+  if (lhs_shape.rank() <= 0 || lhs_shape.rank() > kOneDnnMaxNDims ||
+      rhs_shape.rank() <= 0 || rhs_shape.rank() > kOneDnnMaxNDims ||
+      output_shape.rank() > std::min({lhs_shape.rank(), rhs_shape.rank(),
+                                      static_cast<int64_t>(kOneDnnMaxNDims)})) {
+    return false;
+  }
+
+  // The layout should be row-major; the contraction dimensions capture
+  // transpose scenarios in the last two dimensions.
+  // Col-major layouts are corrected to row-major for the BatchDot operation as
+  // part of the layout-assignment pass.
+  // Skip the row-major layout check before the layout-assignment pass.
+  if (!before_layout_assignment) {
+    bool row_major = IsRowMajor(lhs_shape) && IsRowMajor(rhs_shape) &&
+                     IsRowMajor(output_shape);
+    if (!row_major) return false;
+  }
+
+  auto dot_dim_numbers = dot_instr->dot_dimension_numbers();
+  int64_t lhs_dim_k = dot_dim_numbers.lhs_contracting_dimensions(0);
+  int64_t rhs_dim_k = dot_dim_numbers.rhs_contracting_dimensions(0);
+  // Contraction is supported only in one of the last two dimensions.
+  if (lhs_dim_k < lhs_shape.rank() - 2 || rhs_dim_k < rhs_shape.rank() - 2) {
+    return false;
+  }
+
+  // OneDNN matmul has scratch allocation and copy overheads. The overheads
+  // can be amortized if there is a sufficient number of flops. We don't
+  // rewrite small cases (determined empirically).
+  // TODO(intel-tf): Relax the condition when more optimizations are
+  // implemented in oneDNN matmul.
+  auto num_flops = xla::HloCostAnalysis::GetDotFlops(lhs_shape, output_shape,
+                                                     dot_dim_numbers);
+  auto rank = output_shape.rank();
+  auto flops_threshold = (rank <= 2) ? (1 << 24) : (1 << 19);
+  return (num_flops >= flops_threshold);
+}
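
As a worked sanity check of the threshold (assuming GetDotFlops is roughly 2 M N K for a plain [M,K] x [K,N] matmul, which is the usual convention):

2 \cdot 256^3 = 2^{25} \ge 2^{24} \quad (\text{rewritten}),\qquad
2 \cdot 64^3 = 2^{19} < 2^{24} \quad (\text{left as a plain dot})

so for rank <= 2, square matmuls below roughly 200 elements per dimension are left to the default emitter.
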
+
+bool OneDnnContractionRewriter::ShouldRewriteConv(
+    const HloInstruction* conv_instr) {
+  if (conv_instr->HasControlDependencies()) return false;
+  if (!IsSupportedType(conv_instr->shape().element_type())) return false;
+  if (conv_instr->batch_group_count() != 1) return false;
+
+  // TODO(intel-tf): Remove this restriction after enabling backward weights
+  // support
+  if (conv_instr->operand(1)->opcode() == HloOpcode::kReverse) return false;
+
+  const Shape& inp_shape = conv_instr->operand(0)->shape();
+  const Shape& ker_shape = conv_instr->operand(1)->shape();
+  const Shape& out_shape = conv_instr->shape();
+  if (ShapeUtil::IsZeroElementArray(inp_shape) ||
+      ShapeUtil::IsZeroElementArray(ker_shape) ||
+      ShapeUtil::IsZeroElementArray(out_shape)) {
+    return false;
+  }
+
+  auto dims = conv_instr->window().dimensions().size();
+  if (dims >= 4 || dims <= 0) return false;
+
+  if (inp_shape.rank() != ker_shape.rank() ||
+      inp_shape.rank() != out_shape.rank()) {
+    return false;
+  }
+
+  return true;
+}
+
+class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor {
+ public:
+  // Matches patterns for possible MatMul fusions that are supported by the
+  // oneDNN library. Matched HLO instruction(s) are replaced by a custom call.
+  absl::Status HandleDot(HloInstruction* instr) override {
+    HloInstruction* dot_instr;
+    auto pattern = m::Op(&dot_instr).WithOpcode(HloOpcode::kDot);
+    if (!Match(instr, pattern)) return absl::OkStatus();
+
+    TF_RETURN_IF_ERROR(
+        ValidateDotDimensionNumbers(dot_instr->dot_dimension_numbers()));
+    if (!OneDnnContractionRewriter::ShouldRewriteDot(dot_instr)) {
+      TF_RETURN_IF_ERROR(UpcastDotToF32(dot_instr));
+      return absl::OkStatus();
+    }
+    TF_ASSIGN_OR_RETURN(dot_instr, ReconfigureDotDimensions(dot_instr));
+    auto dot_dim_numbers = dot_instr->dot_dimension_numbers();
+    const Shape& lhs_shape = dot_instr->operand(0)->shape();
+    const Shape& rhs_shape = dot_instr->operand(1)->shape();
+    const Shape& output_shape = dot_instr->shape();
+
+    int64_t lhs_dim_k = dot_dim_numbers.lhs_contracting_dimensions(0);
+    int64_t rhs_dim_k = dot_dim_numbers.rhs_contracting_dimensions(0);
+
+    HloInstruction* matmul_call =
+        dot_instr->AddInstruction(HloInstruction::CreateCustomCall(
+            output_shape,
+            {dot_instr->mutable_operand(0), dot_instr->mutable_operand(1)},
+            "__onednn$matmul"));
+    // Set additional info via config, e.g., transpose and fusion info.
+    BackendConfig backend_config;
+    OneDnnMatMulConfig* matmul_config =
+        backend_config.mutable_onednn_matmul_config();
+    bool transpose_a = (lhs_dim_k != lhs_shape.rank() - 1);
+    bool transpose_b = (rhs_dim_k != rhs_shape.rank() - 2);
+    matmul_config->set_transpose_a(transpose_a);
+    matmul_config->set_transpose_b(transpose_b);
+    TF_RETURN_IF_ERROR(matmul_call->set_backend_config(backend_config));
+    TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, matmul_call));
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleConvolution(HloInstruction* conv) override {
+    if (!OneDnnContractionRewriter::ShouldRewriteConv(conv)) {
+      return absl::OkStatus();
+    }
+
+    const Shape& conv_shape = conv->shape();
+    auto dims = conv->window().dimensions().size();
+    const ConvolutionDimensionNumbers& conv_dims =
+        conv->convolution_dimension_numbers();
+
+    BackendConfig backend_config;
+    OneDnnConvolutionConfig* conv_config =
+        backend_config.mutable_onednn_conv_config();
+
+    conv_config->set_dims(conv_shape.rank());
+    conv_config->set_feature_groups(conv->feature_group_count());
+    conv_config->mutable_input()->mutable_data()->set_batch_dim(
+        conv_dims.input_batch_dimension());
+    conv_config->mutable_kernel()->mutable_filter()->set_input_feature_dim(
+        conv_dims.kernel_input_feature_dimension());
+    conv_config->mutable_output()->mutable_data()->set_batch_dim(
+        conv_dims.output_batch_dimension());
+    conv_config->mutable_input()->mutable_data()->set_feature_dim(
+        conv_dims.input_feature_dimension());
+    conv_config->mutable_kernel()->mutable_filter()->set_output_feature_dim(
+        conv_dims.kernel_output_feature_dimension());
+    conv_config->mutable_output()->mutable_data()->set_feature_dim(
+        conv_dims.output_feature_dimension());
+
+    const Shape& output_shape = conv->shape();
+
+    for (auto it = conv->window().dimensions().begin();
+         it != conv->window().dimensions().end(); it++) {
+      if ((*it).padding_low() < 0 || (*it).padding_high() < 0 ||
+          (*it).stride() < 0 || (*it).base_dilation() != 1 ||
+          (*it).window_reversal()) {
+        return absl::OkStatus();
+      }
+      // Shift the values stored in the uint repeated fields from whole
+      // numbers (0, 1, 2, ...) to natural numbers (1, 2, 3, ...) to avoid
+      // misinterpretation of buffer values.
+      conv_config->mutable_window()->add_pad_left((*it).padding_low() + 1);
+      conv_config->mutable_window()->add_pad_right((*it).padding_high() + 1);
+      conv_config->mutable_window()->add_strides((*it).stride() + 1);
+      conv_config->mutable_window()->add_window_dilations(
+          (*it).window_dilation() + 1);
+    }
+
+    for (int i = 0; i < dims; i++) {
+      conv_config->mutable_input()->mutable_data()->add_spatial_dims(
+          conv_dims.input_spatial_dimensions()[i] + 1);
+      conv_config->mutable_kernel()->mutable_filter()->add_spatial_dims(
+          conv_dims.kernel_spatial_dimensions()[i] + 1);
+      conv_config->mutable_output()->mutable_data()->add_spatial_dims(
+          conv_dims.output_spatial_dimensions()[i] + 1);
+    }
+
+    HloInstruction* custom_call =
+        conv->AddInstruction(HloInstruction::CreateCustomCall(
+            output_shape, {conv->mutable_operand(0), conv->mutable_operand(1)},
+            "__onednn$convolution"));
+
+    TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config));
+    TF_RETURN_IF_ERROR(ReplaceInstruction(conv, custom_call));
+    return absl::OkStatus();
+  }
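
A minimal sketch of the shift-by-one convention used when populating the window fields above (the helper names are hypothetical; presumably the consumer on the custom-call side subtracts the offset back out):

// Hypothetical illustration only: window values are stored as value + 1 so
// that a zero-valued proto field cannot be mistaken for a real value of 0.
inline int64_t EncodeWindowValue(int64_t value) { return value + 1; }
inline int64_t DecodeWindowValue(int64_t stored) { return stored - 1; }
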
+
+  absl::Status HandleAdd(HloInstruction* instr) override {
+    // Try to fuse Dot (onednn-matmul) + Add. However, the HLO Add instruction
+    // might receive the addends only after additional processing (e.g.,
+    // Broadcast, Bitcast, Convert) has been applied to the raw addends. Here,
+    // the following pattern is matched.
+    //
+    // clang-format off
+    //
+    //            Dot                               addend
+    //             |                                  |
+    //             v                                  v
+    //     optional instructions            optional instructions
+    //    (e.g, Convert, Bitcast)         (e.g, Convert, Broadcast)
+    //             |                                  |
+    //             +--------------+-------------------+
+    //                            |
+    //                            v
+    //                           Add
+    //
+    // clang-format on
+
+    HloInstruction *addend_intermediate, *dot;
+    HloInstruction* optional_dot_bitcast = nullptr;
+    HloInstruction* optional_dot_convert = nullptr;
+
+    auto pattern = m::AddAnyOrder(
+        &instr,
+        OptionalConvertAndBitcast(&optional_dot_convert, &optional_dot_bitcast,
+                                  OneDnnMatmulInstr(&dot))
+            .WithOneUser(),
+        m::Op(&addend_intermediate));
+
+    if (Match(instr, pattern)) {
+      if (!IsSupportedType(dot->shape().element_type()))
+        return absl::OkStatus();
+      // TODO(intel-tf): Remove the condition below when the fusion Dot +
+      // Add(bias) + Add(e.g., residual) is enabled.
+      if (!dot->backend_config<BackendConfig>()
+               ->mutable_onednn_matmul_config()
+               ->mutable_fusions()
+               ->ops()
+               .empty() &&
+          dot->backend_config<BackendConfig>()
+                  ->mutable_onednn_matmul_config()
+                  ->mutable_fusions()
+                  ->ops(0) == OneDnnFusionConfig::BIAS) {
+        return absl::OkStatus();
+      }
+      std::vector<HloInstruction*> new_operands;
+      for (auto operand : dot->operands()) {
+        new_operands.push_back(operand);
+      }
+
+      // At this point, the addend could take one of the following forms,
+      // which the current fusion can handle:
+      //
+      //   - addend -> Convert -> Broadcast -> Add
+      //   - addend -> Broadcast -> Convert -> Add
+      //   - addend -> Convert
+      //   - addend -> Broadcast
+      //   - addend
+      //
+      // Hunt for the addend through the possible sequences above and check
+      // that it is compatible with the onednn-matmul fusion.
+      HloInstruction* addend = nullptr;
+      HloInstruction* optional_addend_broadcast = nullptr;
+      auto addend_pattern = m::AnyOf<HloInstruction>(
+          m::Broadcast(&optional_addend_broadcast,
+                       m::Convert(&addend, m::Op())),
+          m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))),
+          m::Convert(&addend, m::Op()),
+          m::Broadcast(&optional_addend_broadcast, m::Op(&addend)),
+          m::Op(&addend));
+      if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus();
+
+      if (optional_addend_broadcast && addend->shape().rank() != 1) {
+        auto new_shape =
+            AdjustBiasShape(optional_addend_broadcast, dot->shape());
+        if (new_shape.ok()) {
+          addend = addend->AddInstruction(
+              HloInstruction::CreateBitcast(new_shape.value(), addend));
+        } else {
+          VLOG(2) << new_shape.status();
+          return absl::OkStatus();
+        }
+      }
+
+      // Validate addend for fusion.
+      if (IsSupportedType(addend->shape().element_type()) &&
+          IsOperandFusible(addend, dot)) {
+        new_operands.push_back(addend);
+      } else {
+        return absl::OkStatus();
+      }
+
+      // TODO(intel-tf): Remove this restriction once oneDNN has an optimized
+      // implementation for broadcasted add across all dimensions.
+      OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED;
+      kind = (addend->shape().rank() == 1)
+                 ? (dot->backend_config<BackendConfig>()
+                            ->mutable_onednn_matmul_config()
+                            ->fusions()
+                            .ops()
+                            .empty()
+                        ? OneDnnFusionConfig::BIAS
+                        : OneDnnFusionConfig::UNDEFINED)
+                 : OneDnnFusionConfig::BINARY_ADD;
+      if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus();
+
+      auto matmul_call = Cast<HloCustomCallInstruction>(instr->AddInstruction(
+          dot->CloneWithNewOperands(dot->shape(), new_operands)));
+
+      auto backend_config = matmul_call->backend_config<BackendConfig>();
+      backend_config->mutable_onednn_matmul_config()
+          ->mutable_fusions()
+          ->add_ops(kind);
+
+      if (optional_addend_broadcast) {
+        backend_config->mutable_onednn_matmul_config()
+            ->mutable_optimization_config()
+            ->set_bias_broadcast(true);
+      }
+      TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config));
+
+      HloInstruction* new_instr;
+      // If the matched pattern has custom-call -> bitcast -> add, then we need
+      // to insert a bitcast after the new fusion to maintain the correct shape
+      // (new-custom-call -> bitcast). Optionally, this is followed by
+      // -> convert in the bf16 case to avoid a datatype mismatch.
+      if (optional_dot_bitcast != nullptr &&
+          optional_dot_bitcast->opcode() == HloOpcode::kBitcast) {
+        if (optional_dot_convert != nullptr &&
+            optional_dot_convert->opcode() == HloOpcode::kConvert) {
+          auto bitcast_call =
+              matmul_call->AddInstruction(HloInstruction::CreateBitcast(
+                  ShapeUtil::ChangeElementType(
+                      instr->shape(), matmul_call->shape().element_type()),
+                  matmul_call));
+          new_instr =
+              bitcast_call->AddInstruction(HloInstruction::CreateConvert(
+                  ShapeUtil::ChangeElementType(
+                      bitcast_call->shape(),
+                      optional_dot_convert->shape().element_type()),
+                  bitcast_call));
+        } else {
+          new_instr = matmul_call->AddInstruction(
+              HloInstruction::CreateBitcast(instr->shape(), matmul_call));
+        }
+      } else {
+        if (optional_dot_convert != nullptr &&
+            optional_dot_convert->opcode() == HloOpcode::kConvert) {
+          new_instr = matmul_call->AddInstruction(HloInstruction::CreateConvert(
+              ShapeUtil::ChangeElementType(
+                  matmul_call->shape(),
+                  optional_dot_convert->shape().element_type()),
+              matmul_call));
+        } else {
+          new_instr = matmul_call;
+        }
+      }
+      TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr));
+    }
+
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleMaximum(HloInstruction* instr) override {
+    HloInstruction* matmul_call;
+    HloInstruction* intermediate_instr = nullptr;
+    HloInstruction* optional_bitcast = nullptr;
+    // Attempt to elide maximum and fuse ReLU activation into GEMM, including
+    // when slicing or bitcasting is applied to the result.
+    if (Match(instr,
+              m::MaximumAnyOrder(ElementwiseSafeIntermediates(
+                                     &intermediate_instr, &optional_bitcast,
+                                     OneDnnMatmulInstr(&matmul_call))
+                                     .WithOneUser(),
+                                 BcastConstScalar(0)))) {
+      return FuseActivation(OneDnnFusionConfig::RELU, instr, matmul_call,
+                            intermediate_instr, optional_bitcast);
+    }
+    return absl::OkStatus();
+  }
+
+  auto ELUActivation(HloInstruction* instr, HloInstruction** src) {
+    //  Reference: tensorflow/compiler/tf2xla/kernels/elu_op.cc
+    //  const auto zero = ScalarLike(x, 0);
+    //  const auto pred = Gt(x, zero);
+    //  const auto expm1 = Expm1(x);
+    //  return Select(pred, x, expm1);
+    auto pattern = m::Select(
+        m::Gt(pu::OptionalConvert(m::Op(src)), BcastConvertConstScalar(0)),
+        m::Op(src),
+        pu::OptionalConvert(m::Expm1(pu::OptionalConvert(m::Op(src)))));
+    return Match(instr, pattern);
+  }
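
The matched select/expm1 subgraph is the standard ELU (with alpha = 1):

\mathrm{elu}(x) = \begin{cases} x, & x > 0 \\ e^{x} - 1, & x \le 0 \end{cases}
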
+
+  absl::Status HandleSelect(HloInstruction* instr) override {
+    HloInstruction* matmul_call;
+    HloInstruction* intermediate_instr = nullptr;
+    HloInstruction* optional_bitcast = nullptr;
+    HloInstruction* src;
+    // Attempt to elide ELU subgraph and fuse ELU activation into GEMM,
+    // including when slicing or bitcasting is applied to the result.
+    if (ELUActivation(instr, &src)) {
+      if (Match(src, ElementwiseSafeIntermediates(
+                         &intermediate_instr, &optional_bitcast,
+                         OneDnnMatmulInstr(&matmul_call)))) {
+        return FuseActivation(OneDnnFusionConfig::ELU, instr, matmul_call,
+                              intermediate_instr);
+      }
+    }
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleTanh(HloInstruction* instr) override {
+    HloInstruction* matmul_call;
+    HloInstruction* intermediate_instr = nullptr;
+    HloInstruction* optional_bitcast = nullptr;
+    // Attempt to elide Tanh and fuse Tanh activation into GEMM, including
+    // when slicing or bitcasting is applied to the result.
+    if (Match(instr, m::Tanh(ElementwiseSafeIntermediates(
+                                 &intermediate_instr, &optional_bitcast,
+                                 OneDnnMatmulInstr(&matmul_call))
+                                 .WithOneUser()))) {
+      return FuseActivation(OneDnnFusionConfig::TANH, instr, matmul_call,
+                            intermediate_instr);
+    }
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleClamp(HloInstruction* instr) override {
+    HloInstruction* matmul_call;
+    HloInstruction* intermediate_instr = nullptr;
+    HloInstruction* optional_bitcast = nullptr;
+    // Attempt to elide RELU6 and fuse RELU6 activation into GEMM, including
+    // when slicing or bitcasting is applied to the result.
+    if (Match(instr, m::Clamp(BcastConstScalar(0),
+                              ElementwiseSafeIntermediates(
+                                  &intermediate_instr, &optional_bitcast,
+                                  OneDnnMatmulInstr(&matmul_call))
+                                  .WithOneUser(),
+                              BcastConstScalar(6)))) {
+      return FuseActivation(OneDnnFusionConfig::RELU6, instr, matmul_call,
+                            intermediate_instr);
+    }
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleMultiply(HloInstruction* instr) override {
+    HloInstruction* matmul_call;
+    HloInstruction* intermediate_instr = nullptr;
+    HloInstruction* src;
+    auto activation = GELUActivation(instr, &src);
+    if (activation != OneDnnFusionConfig::UNDEFINED) {
+      HloInstruction* optional_bitcast = nullptr;
+      if (Match(src, ElementwiseSafeIntermediates(
+                         &intermediate_instr, &optional_bitcast,
+                         OneDnnMatmulInstr(&matmul_call)))) {
+        return FuseActivation(activation, instr, matmul_call,
+                              intermediate_instr, optional_bitcast);
+      }
+    }
+
+    HloInstruction *dot, *constant;
+    HloInstruction* optional_convert = nullptr;
+    auto pattern = m::Op(&instr)
+                       .WithOpcode(HloOpcode::kMultiply)
+                       .WithBinaryOperandsAnyOrder(
+                           m::AnyOf<HloInstruction>(
+                               pu::SupportedConvert(&optional_convert,
+                                                    OneDnnMatmulInstr(&dot))
+                                   .WithElementType(PrimitiveType::F32),
+                               OneDnnMatmulInstr(&dot))
+                               .WithOneUser(),
+                           m::Broadcast(m::Constant(&constant)));
+
+    if (Match(instr, pattern)) {
+      std::vector<HloInstruction*> new_operands;
+      auto constant_value = GetConstantValueAsFloat32(constant);
+      if (!constant_value) {
+        return absl::OkStatus();
+      }
+
+      for (auto operand : dot->operands()) {
+        new_operands.push_back(operand);
+      }
+      auto matmul_call = Cast<HloCustomCallInstruction>(instr->AddInstruction(
+          dot->CloneWithNewOperands(instr->shape(), new_operands)));
+      auto backend_config = matmul_call->backend_config<BackendConfig>();
+      backend_config->mutable_onednn_matmul_config()
+          ->mutable_fusions()
+          ->add_ops(OneDnnFusionConfig::LINEAR);
+      // Cast to int32 because the proto config has issues handling decimal
+      // types.
+      backend_config->mutable_onednn_matmul_config()
+          ->mutable_fusions()
+          ->set_alpha_typecast(
+              *(reinterpret_cast<int32_t*>(&constant_value.value())));
+      TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config));
+      HloInstruction* new_instr;
+      if (optional_convert != nullptr &&
+          optional_convert->opcode() == HloOpcode::kConvert) {
+        new_instr = matmul_call->AddInstruction(HloInstruction::CreateConvert(
+            ShapeUtil::ChangeElementType(
+                matmul_call->shape(), optional_convert->shape().element_type()),
+            matmul_call));
+      } else {
+        new_instr = matmul_call;
+      }
+
+      TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr));
+    }
+    return absl::OkStatus();
+  }
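
The LINEAR fusion above stores the scaling constant bit-for-bit in an int32 proto field. A small sketch of that round trip (illustrative helpers using std::memcpy; the code above uses reinterpret_cast to the same effect):

#include <cstdint>
#include <cstring>

// Store a float's bit pattern in an int32 field and recover it unchanged.
inline int32_t FloatBitsToInt32(float alpha) {
  int32_t bits;
  std::memcpy(&bits, &alpha, sizeof(bits));
  return bits;
}
inline float FloatFromInt32Bits(int32_t bits) {
  float alpha;
  std::memcpy(&alpha, &bits, sizeof(alpha));
  return alpha;
}
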
+
+  auto SigmoidActivation(HloInstruction* instr, HloInstruction** src) {
+    return Match(instr,
+                 m::Divide(BcastConstScalar(1.0),
+                           m::AddAnyOrder(BcastConstScalar(1.0),
+                                          m::Exp(m::Negate(m::Op(src))))));
+  }
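
The divide/add/exp/negate chain above is the logistic sigmoid:

\sigma(x) = \frac{1}{1 + e^{-x}}
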
+
+  absl::Status HandleDivide(HloInstruction* instr) override {
+    HloInstruction* matmul_call;
+    HloInstruction* intermediate_instr = nullptr;
+    HloInstruction* optional_bitcast = nullptr;
+    HloInstruction* src;
+    if (SigmoidActivation(instr, &src)) {
+      if (Match(src, ElementwiseSafeIntermediates(
+                         &intermediate_instr, &optional_bitcast,
+                         OneDnnMatmulInstr(&matmul_call))
+                         .WithOneUser())) {
+        return FuseActivation(OneDnnFusionConfig::SIGMOID, instr, matmul_call,
+                              intermediate_instr, optional_bitcast);
+      }
+    }
+    return absl::OkStatus();
+  }
+
+  absl::Status FuseActivation(OneDnnFusionConfig_FusionKind kind,
+                              HloInstruction* activation,
+                              HloInstruction* matmul,
+                              HloInstruction* intermediate_instr = nullptr,
+                              HloInstruction* optional_bitcast = nullptr) {
+    TF_ASSIGN_OR_RETURN(auto backend_config,
+                        matmul->backend_config<BackendConfig>());
+    auto* matmul_config = backend_config.mutable_onednn_matmul_config();
+    matmul_config->mutable_fusions()->add_ops(kind);
+    TF_RETURN_IF_ERROR(matmul->set_backend_config(backend_config));
+    std::unique_ptr<HloInstruction> output = matmul->Clone();
+    if (optional_bitcast != nullptr &&
+        optional_bitcast->opcode() == HloOpcode::kBitcast) {
+      HloInstruction* new_instr = nullptr;
+      if (intermediate_instr != nullptr &&
+          intermediate_instr->opcode() == HloOpcode::kConvert) {
+        auto bitcast_call =
+            matmul->AddInstruction(HloInstruction::CreateBitcast(
+                ShapeUtil::ChangeElementType(optional_bitcast->shape(),
+                                             matmul->shape().element_type()),
+                matmul));
+        new_instr = bitcast_call->AddInstruction(HloInstruction::CreateConvert(
+            ShapeUtil::ChangeElementType(
+                bitcast_call->shape(),
+                intermediate_instr->shape().element_type()),
+            bitcast_call));
+        return ReplaceInstruction(activation, new_instr);
+      }
+    } else if (intermediate_instr) {
+      output = intermediate_instr->CloneWithNewOperands(
+          intermediate_instr->shape(),
+          {matmul->parent()->AddInstruction(std::move(output))});
+    }
+
+    return ReplaceWithNewInstruction(activation, std::move(output));
+  }
+
+  // This function changes the dot instruction for supported matrix
+  // multiplication scenarios. In particular, it changes the shapes of
+  // the lhs, rhs, and result arrays.
+  //    - lhs configuration scenario
+  //      lhs:    [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim]
+  //      result: [batch_dims,feature_dim] to [batch_dims,1,feature_dim]
+  //
+  //    - rhs configuration scenario
+  //      rhs:    [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1]
+  //      result: [batch_dims,feature_dim] to [batch_dims,feature_dim, 1]
+  //
+  //    - both lhs and rhs configuration scenario
+  //      lhs:    [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim]
+  //      rhs:    [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1]
+  //      result: [batch_dims] to [batch_dims,1,1]
+  absl::StatusOr<HloInstruction*> ReconfigureDotDimensions(
+      HloInstruction* dot_instr) {
+    HloInstruction* lhs = dot_instr->mutable_operand(0);
+    HloInstruction* rhs = dot_instr->mutable_operand(1);
+    DotDimensionNumbers dim_numbers = dot_instr->dot_dimension_numbers();
+
+    auto lhs_batch_dims = dim_numbers.lhs_batch_dimensions();
+    auto lhs_contraction_dims = dim_numbers.lhs_contracting_dimensions();
+    bool is_lhs_vector = lhs->shape().rank() ==
+                         (lhs_batch_dims.size() + lhs_contraction_dims.size());
+
+    auto rhs_batch_dims = dim_numbers.rhs_batch_dimensions();
+    auto rhs_contraction_dims = dim_numbers.rhs_contracting_dimensions();
+    bool is_rhs_vector = rhs->shape().rank() ==
+                         (rhs_batch_dims.size() + rhs_contraction_dims.size());
+
+    if (!is_lhs_vector && !is_rhs_vector) return dot_instr;
+
+    std::vector<int64_t> adjusted_lhs_dims(lhs->shape().dimensions().begin(),
+                                           lhs->shape().dimensions().end());
+    std::vector<int64_t> adjusted_rhs_dims(rhs->shape().dimensions().begin(),
+                                           rhs->shape().dimensions().end());
+    std::vector<int64_t> adjusted_dot_dims(
+        dot_instr->shape().dimensions().begin(),
+        dot_instr->shape().dimensions().end());
+
+    if (is_lhs_vector) {
+      auto lhs_it = adjusted_lhs_dims.begin() + lhs_batch_dims.size();
+      adjusted_lhs_dims.insert(lhs_it, 1, 1);
+      auto result_it = adjusted_dot_dims.begin() + lhs_batch_dims.size();
+      adjusted_dot_dims.insert(result_it, 1, 1);
+      auto lhs_contraction_dim =
+          dot_instr->dot_dimension_numbers().lhs_contracting_dimensions(0);
+      dim_numbers.set_lhs_contracting_dimensions(0, lhs_contraction_dim + 1);
+      lhs = lhs->AddInstruction(HloInstruction::CreateBitcast(
+          ShapeUtil::MakeShape(lhs->shape().element_type(), adjusted_lhs_dims),
+          lhs));
+    }
+
+    if (is_rhs_vector) {
+      auto it = adjusted_rhs_dims.end();
+      adjusted_rhs_dims.insert(it, 1, 1);
+      auto result_it = adjusted_dot_dims.end();
+      adjusted_dot_dims.insert(result_it, 1, 1);
+      rhs = rhs->AddInstruction(HloInstruction::CreateBitcast(
+          ShapeUtil::MakeShape(rhs->shape().element_type(), adjusted_rhs_dims),
+          rhs));
+    }
+
+    HloInstruction* adjusted_dot =
+        dot_instr->AddInstruction(HloInstruction::CreateDot(
+            ShapeUtil::MakeShape(dot_instr->shape().element_type(),
+                                 adjusted_dot_dims),
+            lhs, rhs, dim_numbers, dot_instr->precision_config()));
+
+    HloInstruction* replacement_instr = adjusted_dot->AddInstruction(
+        HloInstruction::CreateBitcast(dot_instr->shape(), adjusted_dot));
+
+    TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, replacement_instr));
+    return adjusted_dot;
+  }
+
+  // This function upcasts BF16 dots to F32 if we are unable to rewrite them to
+  // oneDNN custom calls.
+  absl::Status UpcastDotToF32(HloInstruction* dot_instr) {
+    if (dot_instr->shape().element_type() != BF16) return absl::OkStatus();
+    std::vector<HloInstruction*> new_operands;
+    auto bf16_operands = dot_instr->operands();
+
+    std::for_each(
+        bf16_operands.begin(), bf16_operands.end(),
+        [&new_operands](HloInstruction* instr) {
+          new_operands.push_back(
+              instr->AddInstruction(HloInstruction::CreateConvert(
+                  ShapeUtil::ChangeElementType(instr->shape(), F32), instr)));
+        });
+
+    HloInstruction* f32_dot =
+        dot_instr->AddInstruction(dot_instr->CloneWithNewOperands(
+            ShapeUtil::ChangeElementType(dot_instr->shape(), F32),
+            new_operands));
+
+    HloInstruction* replacement_instr =
+        f32_dot->AddInstruction(HloInstruction::CreateConvert(
+            ShapeUtil::ChangeElementType(f32_dot->shape(), BF16), f32_dot));
+
+    TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, replacement_instr));
+    return absl::OkStatus();
+  }
+};
+
+class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor {
+ public:
+  OneDnnPostRewriteVisitor(int intra_op_parallelism,
+                           const tsl::thread::ThreadPool* compile_threadpool)
+      : intra_op_parallelism_(intra_op_parallelism > 0
+                                  ? intra_op_parallelism
+                                  : tsl::port::MaxParallelism()),
+        evaluator_(/*max_loop_iterations=*/0) {
+    if (compile_threadpool) {
+      threadpool_device_.reset(
+          new Eigen::ThreadPoolDevice(compile_threadpool->AsEigenThreadPool(),
+                                      compile_threadpool->NumThreads()));
+    } else {
+      threadpool_handle_.reset(new tsl::thread::ThreadPool(
+          tsl::Env::Default(), "XLACpuCompile", tsl::port::MaxParallelism()));
+      threadpool_device_.reset(
+          new Eigen::ThreadPoolDevice(threadpool_handle_->AsEigenThreadPool(),
+                                      threadpool_handle_->NumThreads()));
+    }
+
+#ifndef ENABLE_ONEDNN_OPENMP
+    // Set oneDNN concurrency settings (which are thread-local).
+    tsl::OneDnnThreadPool::set_onednn_max_threads(intra_op_parallelism_);
+#endif
+  }
+
+  absl::Status HandleCustomCall(HloInstruction* custom_call) override {
+    HloInstruction* matmul;
+    if (Match(custom_call, OneDnnMatmulInstr(&matmul))) {
+      return HandleCustomCallInternal<dnnl::matmul::primitive_desc>(
+          custom_call);
+    }
+
+    return DefaultAction(custom_call);
+  }
+
+  template <typename PrimDesc>
+  absl::Status HandleCustomCallInternal(HloInstruction* custom_call) {
+    auto scratch_add = AddScratch<PrimDesc>(custom_call);
+    if (scratch_add.ok()) {
+      custom_call = *scratch_add;
+    } else {
+      VLOG(2) << scratch_add.status();
+    }
+    auto weights_prepack = PrepackWeights<PrimDesc>(custom_call);
+    if (!weights_prepack.ok()) {
+      VLOG(2) << weights_prepack.status();
+    }
+    return absl::OkStatus();
+  }
+
+  template <typename>
+  absl::Status SetWeightsPrepack(HloInstruction*, bool);
+
+  template <typename>
+  absl::Status SetUserScratch(HloInstruction*, bool);
+
+  template <typename>
+  bool GetWeightsPrepack(HloInstruction*);
+
+  template <typename>
+  bool GetUserScratch(HloInstruction*);
+
+  // Adds scratch for matmul by changing the result of the custom-call to
+  // tuple(result, scratch).
+  template <typename PrimDesc>
+  absl::StatusOr<HloInstruction*> AddScratch(HloInstruction* custom_call) {
+    if (GetUserScratch<PrimDesc>(custom_call)) {
+      return custom_call;
+    }
+    TF_RETURN_IF_ERROR(SetUserScratch<PrimDesc>(custom_call, true));
+    auto prim_desc = CreateOneDnnPrimDesc<PrimDesc>(custom_call);
+    int64_t scratch_size = prim_desc->scratchpad_desc().get_size();
+    Shape scratch_shape = ShapeUtil::MakeShape(U8, {scratch_size});
+    Shape tuple_shape =
+        ShapeUtil::MakeTupleShape({custom_call->shape(), scratch_shape});
+    auto new_custom_call = custom_call->AddInstruction(
+        custom_call->CloneWithNewShape(tuple_shape));
+    HloInstruction* gte =
+        new_custom_call->AddInstruction(HloInstruction::CreateGetTupleElement(
+            custom_call->shape(), new_custom_call, 0));
+    auto status = ReplaceInstruction(custom_call, gte);
+    if (!status.ok()) {
+      TF_RETURN_IF_ERROR(SetUserScratch<PrimDesc>(custom_call, false));
+      return absl::CancelledError("Failed to add scratch buffer.");
+    }
+    return new_custom_call;
+  }
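In HLO terms, AddScratch widens the custom-call result into a tuple of the original result and a u8 scratch buffer sized from the oneDNN primitive descriptor, then routes existing users through a get-tuple-element. A sketch with made-up shapes and scratch size:

    // before:  %mm  = f32[32,64] custom-call(%x, %w),
    //                  custom_call_target="__onednn$matmul"
    // after:   %mm  = (f32[32,64], u8[16384]) custom-call(%x, %w),
    //                  custom_call_target="__onednn$matmul"
    //          %res = f32[32,64] get-tuple-element(%mm), index=0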
+
+  template <typename PrimDesc>
+  absl::StatusOr<HloInstruction*> PrepackWeights(HloInstruction* custom_call) {
+    if (GetWeightsPrepack<PrimDesc>(custom_call)) {
+      return custom_call;
+    }
+    auto weights = custom_call->operand(1);
+    auto weights_shape = weights->shape();
+    Literal weights_literal;
+    if (!(weights_shape.rank() == 2 &&
+          evaluator_.TryEvaluate(weights, &weights_literal, true))) {
+      return absl::CancelledError(
+          "Cannot prepack weights. Not constant 2D weights.");
+    }
+    auto plain_weights_md = ShapeToMemDesc(weights_shape);
+    if constexpr (std::is_same<PrimDesc, dnnl::matmul::primitive_desc>::value) {
+      TF_ASSIGN_OR_RETURN(auto backend_config,
+                          custom_call->backend_config<BackendConfig>());
+      TRANSPOSE_LAST_TWO_DIMS_IF(
+          backend_config.onednn_matmul_config().transpose_b(),
+          plain_weights_md);
+    }
+    TF_RETURN_IF_ERROR(SetWeightsPrepack<PrimDesc>(custom_call, true));
+    auto prim_desc = CreateOneDnnPrimDesc<PrimDesc>(custom_call);
+    auto packed_weights_md = prim_desc->weights_desc();
+    auto packed_weights_shape = MemDescToXlaShapeFlattened(packed_weights_md);
+    auto packed_weights_literal = Literal(packed_weights_shape);
+    ReorderWeight(plain_weights_md, weights_literal.untyped_data(),
+                  packed_weights_md, packed_weights_literal.untyped_data());
+    HloInstruction* reordered_weight = custom_call->AddInstruction(
+        HloInstruction::CreateConstant(std::move(packed_weights_literal)));
+    auto status =
+        custom_call->ReplaceOperandWithDifferentShape(1, reordered_weight);
+    if (!status.ok()) {
+      TF_RETURN_IF_ERROR(SetWeightsPrepack<PrimDesc>(custom_call, false));
+      return absl::CancelledError(
+          "Cannot replace plain weights with prepacked weights.");
+    } else {
+      return custom_call;
+    }
+  }
+
+  void ReorderWeight(const dnnl::memory::desc& src_md, void* src_buf,
+                     const dnnl::memory::desc& dst_md, void* dst_buf) {
+    auto onednn_threadpool = CreateOneDnnThreadPool(threadpool_device_.get());
+    dnnl::engine cpu_engine(dnnl::engine::kind::cpu, 0);
+    auto onednn_stream = MakeOneDnnStream(cpu_engine, onednn_threadpool.get());
+    auto src_mem = dnnl::memory(src_md, cpu_engine, src_buf);
+    auto dst_mem = dnnl::memory(dst_md, cpu_engine, dst_buf);
+    dnnl::reorder reorder_prim{src_mem, dst_mem};
+    reorder_prim.execute(onednn_stream, src_mem, dst_mem);
+    onednn_stream.wait();
+  }
+
+ private:
+  int intra_op_parallelism_;
+  HloEvaluator evaluator_;
+  std::unique_ptr<tsl::thread::ThreadPool> threadpool_handle_;
+  std::unique_ptr<Eigen::ThreadPoolDevice> threadpool_device_;
+};
+
+#define EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GETTER, PRIM_DESC, CONFIG,      \
+                                               SUB_CONFIG, FIELD)              \
+  template <>                                                                  \
+  inline bool OneDnnPostRewriteVisitor::GETTER<PRIM_DESC>(HloInstruction *     \
+                                                          custom_call) {       \
+    auto backend_config = custom_call->backend_config<BackendConfig>();        \
+    return backend_config.ok() ? backend_config->CONFIG().SUB_CONFIG().FIELD() \
+                               : false;                                        \
+  }
+
+EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GetUserScratch,
+                                       dnnl::matmul::primitive_desc,
+                                       onednn_matmul_config,
+                                       optimization_config, user_scratchpad);
+EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GetWeightsPrepack,
+                                       dnnl::matmul::primitive_desc,
+                                       onednn_matmul_config,
+                                       optimization_config, weights_prepacked);
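The GETTER macro above only stamps out thin config accessors. For example, the GetUserScratch specialization expands to roughly the following (shown for readability; it is not additional code in this change):

    template <>
    inline bool OneDnnPostRewriteVisitor::GetUserScratch<
        dnnl::matmul::primitive_desc>(HloInstruction* custom_call) {
      // Read the optimization flag from the custom call's backend config,
      // defaulting to false if the config cannot be parsed.
      auto backend_config = custom_call->backend_config<BackendConfig>();
      return backend_config.ok() ? backend_config->onednn_matmul_config()
                                       .optimization_config()
                                       .user_scratchpad()
                                 : false;
    }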
+
+#define EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SETTER, PRIM_DESC, CONFIG_TYPE, \
+                                               CONFIG, SUB_CONFIG, FIELD)      \
+  template <>                                                                  \
+  inline absl::Status OneDnnPostRewriteVisitor::SETTER<PRIM_DESC>(             \
+      HloInstruction * custom_call, bool value) {                              \
+    TF_ASSIGN_OR_RETURN(auto backend_config,                                   \
+                        custom_call->backend_config<BackendConfig>());         \
+    CONFIG_TYPE* config = backend_config.mutable_##CONFIG();                   \
+    config->mutable_##SUB_CONFIG()->set_##FIELD(value);                        \
+    return custom_call->set_backend_config(backend_config);                    \
+  }
+
+EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetWeightsPrepack,
+                                       dnnl::matmul::primitive_desc,
+                                       OneDnnMatMulConfig, onednn_matmul_config,
+                                       optimization_config, weights_prepacked);
+EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetUserScratch,
+                                       dnnl::matmul::primitive_desc,
+                                       OneDnnMatMulConfig, onednn_matmul_config,
+                                       optimization_config, user_scratchpad);
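Likewise, the SETTER macro generates small mutators over the same proto fields; SetUserScratch, for instance, becomes roughly:

    template <>
    inline absl::Status OneDnnPostRewriteVisitor::SetUserScratch<
        dnnl::matmul::primitive_desc>(HloInstruction* custom_call, bool value) {
      // Flip the flag inside the backend config and write it back to the
      // custom call.
      TF_ASSIGN_OR_RETURN(auto backend_config,
                          custom_call->backend_config<BackendConfig>());
      OneDnnMatMulConfig* config = backend_config.mutable_onednn_matmul_config();
      config->mutable_optimization_config()->set_user_scratchpad(value);
      return custom_call->set_backend_config(backend_config);
    }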
+
+absl::StatusOr<bool> OneDnnContractionRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  OneDnnContractionRewriteVisitor visitor;
+  TF_ASSIGN_OR_RETURN(auto result,
+                      visitor.RunOnModule(module, execution_threads));
+
+  OneDnnPostRewriteVisitor reorder_visitor(intra_op_parallelism_,
+                                           compile_threadpool_);
+  TF_ASSIGN_OR_RETURN(auto result2,
+                      reorder_visitor.RunOnModule(module, execution_threads));
+
+  return {result || result2};
+}
+
+}  // namespace cpu
+}  // namespace xla
+
+#endif  // INTEL_MKL && ENABLE_ONEDNN_V3
diff --git a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h
new file mode 100644
index 0000000..7864dae
--- /dev/null
+++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h
@@ -0,0 +1,63 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_CPU_ONEDNN_CONTRACTION_REWRITER_H_
+#define XLA_SERVICE_CPU_ONEDNN_CONTRACTION_REWRITER_H_
+#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
+
+#include <optional>
+
+#include "absl/algorithm/container.h"
+#include "unsupported/Eigen/CXX11/Tensor"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "tsl/platform/threadpool.h"
+
+namespace xla {
+namespace cpu {
+
+// This pass pattern-matches HLO Dot and Convolution instructions and rewrites
+// them into oneDNN custom calls.
+class OneDnnContractionRewriter : public HloModulePass {
+ public:
+  OneDnnContractionRewriter(int intra_op_parallelism,
+                            const tsl::thread::ThreadPool* compile_threadpool)
+      : intra_op_parallelism_(intra_op_parallelism),
+        compile_threadpool_(compile_threadpool) {}
+  OneDnnContractionRewriter() = default;
+  absl::string_view name() const override {
+    return "onednn-contraction-rewriter";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+  static bool ShouldRewriteDot(const HloInstruction* dot_instr,
+                               bool before_layout_assignment = false);
+  static bool ShouldRewriteConv(const HloInstruction* conv_instr);
+
+ private:
+  int intra_op_parallelism_ = 0;
+  const tsl::thread::ThreadPool* compile_threadpool_ = nullptr;
+};
+
+}  // namespace cpu
+}  // namespace xla
+
+#endif  // INTEL_MKL && ENABLE_ONEDNN_V3
+#endif  // XLA_SERVICE_CPU_ONEDNN_CONTRACTION_REWRITER_H_
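As a usage sketch (illustrative only; the wrapper function, parallelism value, and empty execution-thread set are assumptions, not part of this change), the pass can be run directly on a module, for example from a test:

    #include "absl/status/statusor.h"
    #include "xla/hlo/ir/hlo_module.h"
    #include "xla/service/cpu/onednn_contraction_rewriter.h"

    absl::StatusOr<bool> RewriteContractions(xla::HloModule* module) {
      // Passing 0 / nullptr lets the post-rewrite step fall back to
      // tsl::port::MaxParallelism() and create its own compile-time thread pool.
      xla::cpu::OneDnnContractionRewriter rewriter(/*intra_op_parallelism=*/0,
                                                   /*compile_threadpool=*/nullptr);
      // True if anything changed: contractions rewritten into oneDNN custom
      // calls and/or post-rewrite scratch/weight-prepack updates applied.
      return rewriter.Run(module, /*execution_threads=*/{});
    }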
diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc
deleted file mode 100644
index 0c65c5d..0000000
--- a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc
+++ /dev/null
@@ -1,143 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
-
-#include "xla/service/cpu/onednn_convolution_rewriter.h"
-
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/cpu/backend_config.pb.h"
-#include "xla/service/cpu/onednn_config.pb.h"
-#include "xla/service/cpu/onednn_memory_util.h"
-#include "xla/service/cpu/onednn_util.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/status_macros.h"
-
-namespace xla {
-namespace cpu {
-
-namespace {
-namespace m = match;
-}  // namespace
-
-bool OneDnnConvolutionRewriter::ShouldRewrite(const HloInstruction* conv) {
-  if (conv->HasControlDependencies()) return false;
-  if (!IsSupportedType(conv->shape().element_type())) return false;
-  if (conv->batch_group_count() != 1) return false;
-
-  if (conv->operand(1)->opcode() == HloOpcode::kReverse) return false;
-
-  const Shape& inp_shape = conv->operand(0)->shape();
-  const Shape& ker_shape = conv->operand(1)->shape();
-  const Shape& out_shape = conv->shape();
-  if (ShapeUtil::IsZeroElementArray(inp_shape) ||
-      ShapeUtil::IsZeroElementArray(ker_shape) ||
-      ShapeUtil::IsZeroElementArray(out_shape)) {
-    return false;
-  }
-
-  auto dims = conv->window().dimensions().size();
-  if (dims >= 4 || dims <= 0) return false;
-
-  if (inp_shape.rank() != ker_shape.rank() ||
-      inp_shape.rank() != out_shape.rank()) {
-    return false;
-  }
-
-  return true;
-}
-
-class OneDnnConvolutionRewriterVisitor : public DfsHloRewriteVisitor {
- public:
-  absl::Status HandleConvolution(HloInstruction* conv) override {
-    auto pattern = match::Op(&conv).WithOpcode(HloOpcode::kConvolution);
-    if (!Match(conv, pattern)) return absl::OkStatus();
-    if (!OneDnnConvolutionRewriter::ShouldRewrite(conv)) {
-      return absl::OkStatus();
-    }
-
-    const Shape& conv_shape = conv->shape();
-    auto dims = conv->window().dimensions().size();
-    const ConvolutionDimensionNumbers& conv_ddata =
-        conv->convolution_dimension_numbers();
-
-    BackendConfig backend_config;
-    OneDnnConvolutionConfig* conv_config =
-        backend_config.mutable_onednn_conv_config();
-
-    conv_config->set_dims(conv_shape.rank());
-    conv_config->set_feature_groups(conv->feature_group_count());
-    conv_config->mutable_input()->mutable_data()->set_batch_dim(
-        conv_ddata.input_batch_dimension());
-    conv_config->mutable_kernel()->mutable_filter()->set_input_feature_dim(
-        conv_ddata.kernel_input_feature_dimension());
-    conv_config->mutable_output()->mutable_data()->set_batch_dim(
-        conv_ddata.output_batch_dimension());
-    conv_config->mutable_input()->mutable_data()->set_feature_dim(
-        conv_ddata.input_feature_dimension());
-    conv_config->mutable_kernel()->mutable_filter()->set_output_feature_dim(
-        conv_ddata.kernel_output_feature_dimension());
-    conv_config->mutable_output()->mutable_data()->set_feature_dim(
-        conv_ddata.output_feature_dimension());
-
-    const Shape& output_shape = conv->shape();
-
-    for (auto it = conv->window().dimensions().begin();
-         it != conv->window().dimensions().end(); it++) {
-      if ((*it).padding_low() < 0 || (*it).padding_high() < 0 ||
-          (*it).stride() < 0) {
-        return absl::OkStatus();
-      }
-      conv_config->mutable_window()->add_pad_left((*it).padding_low() + 1);
-      conv_config->mutable_window()->add_pad_right((*it).padding_high() + 1);
-      conv_config->mutable_window()->add_strides((*it).stride() + 1);
-      conv_config->mutable_window()->add_window_dilations(
-          (*it).window_dilation() + 1);
-      if ((*it).base_dilation() != 1 || (*it).window_reversal()) {
-        return absl::OkStatus();
-      }
-    }
-
-    for (int i = 0; i < dims; i++) {
-      conv_config->mutable_input()->mutable_data()->add_spatial_dims(
-          conv_ddata.input_spatial_dimensions()[i] + 1);
-      conv_config->mutable_kernel()->mutable_filter()->add_spatial_dims(
-          conv_ddata.kernel_spatial_dimensions()[i] + 1);
-      conv_config->mutable_output()->mutable_data()->add_spatial_dims(
-          conv_ddata.output_spatial_dimensions()[i] + 1);
-    }
-
-    HloInstruction* custom_call =
-        conv->AddInstruction(HloInstruction::CreateCustomCall(
-            output_shape, {conv->mutable_operand(0), conv->mutable_operand(1)},
-            "__onednn$convolution"));
-
-    TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config));
-    TF_RETURN_IF_ERROR(ReplaceInstruction(conv, custom_call));
-    return absl::OkStatus();
-  }
-};
-
-absl::StatusOr<bool> OneDnnConvolutionRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  OneDnnConvolutionRewriterVisitor visitor;
-  return visitor.RunOnModule(module, execution_threads);
-}
-
-}  // namespace cpu
-}  // namespace xla
-
-#endif  // INTEL_MKL && ENABLE_ONEDNN_V3
diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h
deleted file mode 100644
index 2dbd3a6..0000000
--- a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h
+++ /dev/null
@@ -1,50 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_REWRITER_H_
-#define XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_REWRITER_H_
-#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
-
-#include <optional>
-
-#include "absl/algorithm/container.h"
-#include "absl/status/statusor.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace cpu {
-
-// This pass converts hlo convolution instructions into a single oneDNN
-// operation and rewrites into custom calls.
-class OneDnnConvolutionRewriter : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "onednn-convolution-rewriter";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
-  static bool ShouldRewrite(const HloInstruction* instr);
-};
-
-}  // namespace cpu
-}  // namespace xla
-
-#endif  // INTEL_MKL && ENABLE_ONEDNN_V3
-#endif  // XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_REWRITER_H_
diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc
deleted file mode 100644
index 45c6bc1..0000000
--- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc
+++ /dev/null
@@ -1,1129 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
-
-#define EIGEN_USE_THREADS
-
-#include "xla/service/cpu/onednn_matmul_rewriter.h"
-
-#include "xla/executable_run_options.h"
-#include "xla/hlo/evaluator/hlo_evaluator.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/cpu/backend_config.pb.h"
-#include "xla/service/cpu/onednn_config.pb.h"
-#include "xla/service/cpu/onednn_matmul.h"
-#include "xla/service/cpu/onednn_memory_util.h"
-#include "xla/service/cpu/onednn_pattern_utils.h"
-#include "xla/service/cpu/onednn_util.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/status_macros.h"
-#include "xla/tsl/util/onednn_threadpool.h"
-#include "tsl/platform/logging.h"  // IWYU pragma: keep
-
-namespace xla {
-namespace cpu {
-
-namespace {
-namespace m = match;
-namespace pu = ::xla::cpu::onednn_pattern_utils_internal;
-
-inline absl::Status ValidateDotDimensionNumbers(
-    const DotDimensionNumbers& dim_numbers) {
-  // Checks some invariants that do not hold in general, but DotDecomposer
-  // should have established for us.
-  TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
-  std::vector<int64_t> batch_dim_numbers(
-      dim_numbers.lhs_batch_dimensions_size());
-  absl::c_iota(batch_dim_numbers, 0);
-  TF_RET_CHECK(
-      absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
-  TF_RET_CHECK(
-      absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
-  return absl::OkStatus();
-}
-
-// Whether the element type of instr is compatible with oneDNN kernels.
-// TODO(intel-tf): Restict compatible types based on instruction kind.
-inline bool CompatibleElementType(const HloInstruction* instr) {
-  PrimitiveType element_type = instr->shape().element_type();
-  return element_type == BF16 || element_type == F32 || element_type == F16;
-}
-
-inline bool IsRowMajor(const Shape& shape) {
-  return LayoutUtil::IsMonotonicWithDim0Major(shape.layout());
-}
-
-template <typename Pattern>
-inline auto BitcastWithReshapeSemantics(HloInstruction** bitcast,
-                                        Pattern pattern) {
-  // TODO(intel-tf): Add stronger condition that Bitcast does not have transpose
-  // semantics. Some of the HLO passes replaces Transpose with Bitcast. Here
-  // the layouts are checked to be rowmajor since the current pass runs after
-  // the layout assignment and oneDNN matmul is enabled for rowmajor layouts.
-  auto is_reshape = [](const HloInstruction* instr) -> bool {
-    if (!instr) return false;
-    auto input_shape = instr->operand(0)->shape();
-    auto output_shape = instr->shape();
-    bool is_same_type = ShapeUtil::SameElementType(input_shape, output_shape);
-    bool has_equal_num_elems = ShapeUtil::ElementsIn(input_shape) ==
-                               ShapeUtil::ElementsIn(output_shape);
-    bool has_rowmajor_layout =
-        IsRowMajor(input_shape) && IsRowMajor(output_shape);
-    return is_same_type && has_equal_num_elems && has_rowmajor_layout;
-  };
-  return m::Bitcast(bitcast, pattern).WithPredicate(is_reshape);
-}
-
-template <typename Pattern>
-auto ElementwiseSafeIntermediates(HloInstruction** instr,
-                                  HloInstruction** optional_bitcast,
-                                  Pattern pattern) {
-  return m::AnyOf<HloInstruction>(
-      m::Broadcast(instr, pattern.WithOneUser()),
-      m::Slice(instr, pattern.WithOneUser()),
-      m::Bitcast(instr, pattern.WithOneUser()),
-      m::Reshape(instr, pattern.WithOneUser()),
-      pu::SupportedConvert(instr, pattern.WithOneUser()),
-      pu::SupportedConvert(instr, BitcastWithReshapeSemantics(
-                                      optional_bitcast, pattern.WithOneUser())),
-      pattern);
-}
-
-inline auto OneDnnMatmulInstr(HloInstruction** instr) {
-  return m::CustomCall(instr, {"__onednn$matmul"});
-}
-
-inline auto ConvertBF16ToF32(HloInstruction** instr) {
-  return m::Convert(m::Op(instr).WithElementType(PrimitiveType::BF16))
-      .WithElementType(PrimitiveType::F32);
-}
-
-inline auto BcastConstScalar(HloInstruction** instr, double value) {
-  return m::Broadcast(instr, m::ConstantScalar(value));
-}
-
-inline auto BcastConstScalar(double value) {
-  return BcastConstScalar(nullptr, value);
-}
-
-inline auto BcastConvertConstScalar(double value) {
-  return m::Broadcast(pu::OptionalConvert(m::ConstantScalar(value)));
-}
-
-inline bool IsBatchDot(const HloInstruction& instr) {
-  if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
-    return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
-  }
-  return false;
-}
-
-auto ConstScalarNear(double value) {
-  return m::ConstantScalar().WithPredicate(
-      [expected = value](const HloInstruction* instr) {
-        // Not a very robust floating-point comparison, but good enough for our
-        // purposes.
-        std::optional<double> actual =
-            static_cast<const HloConstantInstruction*>(instr)
-                ->literal()
-                .GetAsDouble({});
-        if (!actual.has_value()) return false;
-        double epsilon;
-        switch (instr->shape().element_type()) {
-          case F16:
-            epsilon = 128 * std::numeric_limits<Eigen::half>::epsilon();
-            break;
-          case BF16:
-            epsilon = 128 * std::numeric_limits<bfloat16>::epsilon();
-            break;
-          case F32:
-            epsilon = 128 * std::numeric_limits<float>::epsilon();
-            break;
-          case F64:
-            epsilon = 128 * std::numeric_limits<double>::epsilon();
-            break;
-          default:
-            return false;
-        }
-        return abs(*actual - expected) < (abs(*actual + expected) * epsilon);
-      });
-}
-
-bool IsScalar(const HloInstruction* instr) {
-  return ShapeUtil::IsEffectiveScalar(instr->shape());
-}
-
-std::optional<float> GetConstantValueAsFloat32(const HloInstruction* inst) {
-  if (!IsScalar(inst)) {
-    return std::nullopt;
-  }
-  switch (inst->shape().element_type()) {
-    case F16:
-      return inst->literal().GetFirstElement<half>();
-    case BF16:
-      return inst->literal().GetFirstElement<bfloat16>();
-    case F32:
-      return inst->literal().GetFirstElement<float>();
-    default:
-      return std::nullopt;
-  }
-}
-
-inline auto BcastConstScalarNear(double value) {
-  return m::Broadcast(ConstScalarNear(value));
-}
-
-// Associativity and commutativity properties of multiply results in various
-// patterns for an equivalent computation. This function tries to capture most
-// of the variations for a computation a * b * c. For example, patterns could be
-// any of (a * b) * c or a * (b * c), along with the variations resulting from
-// commutative patterns.
-template <typename PatternA, typename PatternB, typename PatternC>
-inline auto MultiplyMultiplyAnyOrder(PatternA a, PatternB b, PatternC c) {
-  return m::AnyOf<HloInstruction>(
-      m::MultiplyAnyOrder(a, m::MultiplyAnyOrder(b, c)),
-      m::MultiplyAnyOrder(b, m::MultiplyAnyOrder(a, c)),
-      m::MultiplyAnyOrder(c, m::MultiplyAnyOrder(a, b)));
-}
-
-auto GELUActivation(HloInstruction* instr, HloInstruction** src) {
-  // Attempt to match GELU_TANH activation or GELU_ERF activation
-  // (https://arxiv.org/abs/1606.08415), where:
-  // gelu_tanh(x) = x * cdf(x)
-  // cdf(x) = 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))
-  //                     -------------errf_approximate------------
-  //
-  // gelu_erf(x) = x * cdf(x)
-  // cdf(x) = 0.5 * (1 + erf(x / sqrt(2)))
-  //                     --errf_exact--
-
-  HloInstruction* errf;
-
-  // The expression 0.5 * x * (1 + errf) as common pattern for GELU exact and
-  // approximate activations.
-  auto common_pattern = MultiplyMultiplyAnyOrder(
-      BcastConstScalar(0.5), m::Op(src),
-      m::AddAnyOrder(BcastConstScalar(1.0), m::Op(&errf).WithOneUser()));
-
-  bool matched = Match(instr, common_pattern);
-  if (matched) {
-    // The subexpression 0.044715 * x**3 appears in GELU approximate activation.
-    // However, it is often optimized by other HLO passes into an expression of
-    // 0.044715 * x * (x * x). Since there are three consecutive multiplies,
-    // there could be a large number of patterns. We try to capture some of
-    // those:
-    //
-    //      1. (0.044715 * x) * x * x
-    //      2. 0.044715 * (x * x) * x
-    //
-    // Note each of the above could in turn have various patterns due to
-    // associativity and commutativity properties of multiply.
-    auto subexpr_pattern = m::AnyOf<HloInstruction>(
-        MultiplyMultiplyAnyOrder(
-            m::MultiplyAnyOrder(BcastConstScalarNear(0.044715),
-                                m::Op().Is(*src)),
-            m::Op().Is(*src), m::Op().Is(*src)),
-        MultiplyMultiplyAnyOrder(
-            BcastConstScalarNear(0.044715),
-            m::Multiply(m::Op().Is(*src), m::Op().Is(*src)), m::Op().Is(*src)));
-
-    auto errf_apprx_pattern =
-        m::Tanh(m::MultiplyAnyOrder(
-                    BcastConstScalarNear(sqrt(M_2_PI)),
-                    m::AddAnyOrder(m::Op().Is(*src), subexpr_pattern)
-                        .WithOneUser()))
-            .WithOneUser();
-
-    HloInstruction* erf;
-    auto errf_exact_pattern =
-        m::Op(&erf)
-            .WithOpcode(HloOpcode::kErf)
-            .WithOperand(
-                0, m::MultiplyAnyOrder(m::Op(src),
-                                       m::AnyOf<HloInstruction>(
-                                           BcastConstScalarNear(0.707106769),
-                                           BcastConstScalarNear(0.70703125),
-                                           BcastConstScalarNear(0.707182348)))
-                       .WithOneUser())
-            .WithOneUser();
-
-    if (Match(errf, errf_apprx_pattern)) {
-      // Matched Gelu-approximate pattern
-      return OneDnnFusionConfig::GELU_TANH;
-    } else if (Match(errf, errf_exact_pattern)) {
-      // Matched Gelu-exact pattern
-      return OneDnnFusionConfig::GELU_ERF;
-    }
-  }
-  return OneDnnFusionConfig::UNDEFINED;
-}
-
-// OneDNN matmul can fuse add operation with automatic broadcasting along the
-// addend's dimensions that are 1s. When compatible, Broadcast can be replaced
-// by Bitcast, which is much cheaper. Compute new shape for the Bitcast.
-absl::StatusOr<Shape> AdjustBiasShape(const HloInstruction* broadcast_instr,
-                                      const Shape& dot_shape) {
-  if (broadcast_instr->opcode() != HloOpcode::kBroadcast) {
-    return absl::InvalidArgumentError(
-        "Hlo instruction is not a Broadcast insruction.");
-  }
-  auto bcast = Cast<HloBroadcastInstruction>(broadcast_instr);
-  Shape new_shape = bcast->shape();
-  // Broadcast instruction has "dimensions" parameter along which its input's
-  // dimensions should not change. For example,
-  //      dot = f32[3,4,5,6] dot(...)
-  //      arg = f32[3,6]{1,0} parameter(0)
-  //      broad = f32[3,4,5,6]{3,2,1,0} broadcast(arg), dimensions={0,3}
-  //      add = f32[3,4,5,6]{3,2,1,0} add(dot, arg)
-  // can be replaced with the following
-  //      arg = f32[3,6]{1,0} parameter(0)
-  //      bitcast = f32[3,1,1,6]{3,2,1,0} bitcast(arg)
-  //      fused = f32[3,4,5,6]{3,2,1,0} custom-call((..., bitcast)
-  auto kept_dimensions = bcast->dimensions();
-  for (int i = 0; i < new_shape.rank(); i++) {
-    if (!absl::c_linear_search(kept_dimensions, i)) {
-      new_shape.set_dimensions(i, 1);
-    }
-  }
-
-  // If rank(new_shape) > rank(dot), extra dimensions with value = 1 can be
-  // deleted from the new_shape.
-  int64_t rank_difference = new_shape.rank() - dot_shape.rank();
-  auto new_dims = new_shape.dimensions();
-  std::vector<int64_t> dims_to_delete;
-  for (int i = 0; i < rank_difference; ++i) {
-    if (new_dims[i] == 1) {
-      dims_to_delete.push_back(i);
-    }
-  }
-  new_shape = ShapeUtil::DeleteDimensions(dims_to_delete, new_shape);
-
-  // New shape for bias should satisfy the condition:
-  //   rank(new_shape) <= rank(dot).
-  if (new_shape.rank() > dot_shape.rank()) {
-    return absl::CancelledError(
-        "Bias shape could not be adjusted for a fusion.");
-  }
-
-  return new_shape;
-};
-
-inline bool IsOperandFusible(HloInstruction* operand, HloInstruction* dot) {
-  // Check if the operand's shape is compatible with matmul for fusion.
-  // An operand is fusable if
-  //    1. rank(operand) <= rank(dot) and
-  //    2. Starting from the last dim in backward direction, the dimension
-  //       size of operand is either 1 or same to dot.
-  auto operand_dims = operand->shape().dimensions();
-  auto dot_dims = dot->shape().dimensions();
-  if (operand_dims.size() > dot_dims.size()) return false;
-  int operand_idx = operand_dims.size() - 1;
-  int dot_idx = dot_dims.size() - 1;
-  for (; operand_idx >= 0; --operand_idx, --dot_idx) {
-    if (operand_dims[operand_idx] != 1 &&
-        operand_dims[operand_idx] != dot_dims[dot_idx])
-      return false;
-  }
-  return true;
-}
-
-template <typename Pattern>
-inline auto OptionalConvertAndBitcast(HloInstruction** optional_convert,
-                                      HloInstruction** optional_bitcast,
-                                      Pattern pattern) {
-  // Checks the presence of some intermediate operations that can be moved /
-  // folded to allow dot fusion with add.
-  // Try to match either of the following:
-  //   1. pattern-root -> bf16/f16-to-fp32 convert -> bitcast
-  //   2. pattern-root -> bf16/f16-to-fp32 convert
-  //   3. pattern-root -> bitcast
-  //   4. pattern-root
-  auto common = m::AnyOf<HloInstruction>(
-      pu::SupportedConvert(optional_convert, std::move(pattern).WithOneUser())
-          .WithElementType(PrimitiveType::F32),
-      std::move(pattern).WithOneUser());
-  return m::AnyOf<HloInstruction>(
-      BitcastWithReshapeSemantics(optional_bitcast, common), common);
-}
-
-}  // namespace
-
-bool OneDnnMatMulRewriter::ShouldRewrite(const HloInstruction* dot_instr) {
-  // Currently, blocking control dependencies
-  if (dot_instr->HasControlDependencies()) return false;
-  if (!IsSupportedType(dot_instr->shape().element_type())) return false;
-  if (dot_instr->operands().size() != 2) return false;
-
-  // Currently, we rewrite when the data type is F32 or BF16. Note we do not
-  // need to check equality of contraction dim-size of the operands. HLO
-  // verifier already does the job. We, however, need to check if contraction
-  // is over only 1 dimension (a.k.a. K dimension in matrix-multiplication
-  // parlance). We also restrict that batch dimensions of the operands
-  // match.
-  const Shape& lhs_shape = dot_instr->operand(0)->shape();
-  const Shape& rhs_shape = dot_instr->operand(1)->shape();
-  const Shape& output_shape = dot_instr->shape();
-  // None of the operands and result should be ZeroElementArray.
-  if (ShapeUtil::IsZeroElementArray(lhs_shape) ||
-      ShapeUtil::IsZeroElementArray(rhs_shape) ||
-      ShapeUtil::IsZeroElementArray(output_shape)) {
-    return false;
-  }
-  // OneDNN only supports rank <= kOneDnnMaxNDims and singular non-contracting
-  // dimensions. We should not rewrite if any of these conditions are violated.
-  if (lhs_shape.rank() <= 0 || lhs_shape.rank() > kOneDnnMaxNDims ||
-      rhs_shape.rank() <= 0 || rhs_shape.rank() > kOneDnnMaxNDims ||
-      output_shape.rank() > std::min({lhs_shape.rank(), rhs_shape.rank(),
-                                      static_cast<int64_t>(kOneDnnMaxNDims)})) {
-    return false;
-  }
-
-  // Layout should be row-major, contraction dimensions captures transpose
-  // scenarios in last two dimensions.
-  // Col-major layouts are corrected to row-majow for BatchDot operation as
-  // part of the layout-pass.
-  if (!IsBatchDot(*dot_instr) &&
-      (!IsRowMajor(lhs_shape) || !IsRowMajor(rhs_shape) ||
-       !IsRowMajor(output_shape))) {
-    return false;
-  }
-
-  auto dot_dim_numbers = dot_instr->dot_dimension_numbers();
-  int64_t lhs_dim_k = dot_dim_numbers.lhs_contracting_dimensions(0);
-  int64_t rhs_dim_k = dot_dim_numbers.rhs_contracting_dimensions(0);
-  // Supported contraction is only in one of last two dimensions.
-  if (lhs_dim_k < lhs_shape.rank() - 2 || rhs_dim_k < rhs_shape.rank() - 2) {
-    return false;
-  }
-
-  // OneDNN matmul has scratch allocation and copy overheads. The overheads
-  // can be amortized if there is sufficient number of flops. We don't rewrite
-  // for small cases (determined empirically).
-  // TODO(intel-tf): Relax the condition when more optimizations in oneDNN
-  // matmul is achieved.
-  auto num_flops = xla::HloCostAnalysis::GetDotFlops(lhs_shape, output_shape,
-                                                     dot_dim_numbers);
-  auto rank = output_shape.rank();
-  auto flops_threshold = (rank <= 2) ? (1 << 24) : (1 << 19);
-  return (num_flops >= flops_threshold);
-}
-
-class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor {
- public:
-  // Matches patterns for possible MatMul fusions that are supported by oneDNN
-  // library. Matched HLO instruction(s) are replaced by custom call.
-  absl::Status HandleDot(HloInstruction* instr) override {
-    HloInstruction* dot_instr;
-    auto pattern = m::Op(&dot_instr).WithOpcode(HloOpcode::kDot);
-    if (!Match(instr, pattern)) return absl::OkStatus();
-
-    TF_RETURN_IF_ERROR(
-        ValidateDotDimensionNumbers(dot_instr->dot_dimension_numbers()));
-    if (!OneDnnMatMulRewriter::ShouldRewrite(dot_instr))
-      return absl::OkStatus();
-    TF_ASSIGN_OR_RETURN(dot_instr, ReconfigureDotDimensions(dot_instr));
-    auto dot_dim_numbers = dot_instr->dot_dimension_numbers();
-    const Shape& lhs_shape = dot_instr->operand(0)->shape();
-    const Shape& rhs_shape = dot_instr->operand(1)->shape();
-    const Shape& output_shape = dot_instr->shape();
-
-    int64_t lhs_dim_k = dot_dim_numbers.lhs_contracting_dimensions(0);
-    int64_t rhs_dim_k = dot_dim_numbers.rhs_contracting_dimensions(0);
-
-    HloInstruction* matmul_call =
-        dot_instr->AddInstruction(HloInstruction::CreateCustomCall(
-            output_shape,
-            {dot_instr->mutable_operand(0), dot_instr->mutable_operand(1)},
-            "__onednn$matmul"));
-    // Set additional info via config, e.g., transpose and fusion info.
-    BackendConfig backend_config;
-    OneDnnMatMulConfig* matmul_config =
-        backend_config.mutable_onednn_matmul_config();
-    bool transpose_a = (lhs_dim_k != lhs_shape.rank() - 1);
-    bool transpose_b = (rhs_dim_k != rhs_shape.rank() - 2);
-    matmul_config->set_transpose_a(transpose_a);
-    matmul_config->set_transpose_b(transpose_b);
-    TF_RETURN_IF_ERROR(matmul_call->set_backend_config(backend_config));
-    TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, matmul_call));
-    return absl::OkStatus();
-  }
-
-  absl::Status HandleAdd(HloInstruction* instr) override {
-    // Try to do a fusion for Dot(onednn-matmul) + Add. However,
-    // HLO Add instruction might receive the addends after additional
-    // processing like Broadcast, Bitcast, Convert, etc. is applied to the raw
-    // addends. Here, the following possible pattern is matched.
-    //
-    // clang-format off
-    //
-    //            Dot                               addend
-    //             |                                  |
-    //             v                                  v
-    //     optional instructions            optional instructions
-    //    (e.g, Convert, Bitcast)         (e.g, Convert, Broadcast)
-    //             |                                  |
-    //             +--------------+-------------------+
-    //                            |
-    //                            v
-    //                           Add
-    //
-    // clang-format on
-
-    HloInstruction *addend_intermediate, *dot;
-    HloInstruction* optional_dot_bitcast = nullptr;
-    HloInstruction* optional_dot_convert = nullptr;
-
-    auto pattern = m::AddAnyOrder(
-        &instr,
-        OptionalConvertAndBitcast(&optional_dot_convert, &optional_dot_bitcast,
-                                  OneDnnMatmulInstr(&dot))
-            .WithOneUser(),
-        m::Op(&addend_intermediate));
-
-    if (Match(instr, pattern)) {
-      if (!IsSupportedType(dot->shape().element_type()))
-        return absl::OkStatus();
-      // TODO(intel-tf): Remove the condition below when the fusion Dot +
-      // Add(bias) + Add(e.g., residual) is enabled.
-      if (!dot->backend_config<BackendConfig>()
-               ->mutable_onednn_matmul_config()
-               ->mutable_fusions()
-               ->ops()
-               .empty() &&
-          dot->backend_config<BackendConfig>()
-                  ->mutable_onednn_matmul_config()
-                  ->mutable_fusions()
-                  ->ops(0) == OneDnnFusionConfig::BIAS) {
-        return absl::OkStatus();
-      }
-      std::vector<HloInstruction*> new_operands;
-      for (auto operand : dot->operands()) {
-        new_operands.push_back(operand);
-      }
-
-      // At this point, the addend could have one of the following
-      // possiblities that the current fusion can handle:
-      //
-      //   - addend -> Convert -> Broadcast -> Add
-      //   - addend -> Broadcast -> Convert -> Add
-      //   - addend -> Convert
-      //   - addend -> Broadcast
-      //   - addend
-      //
-      // Hunt for addend through possible sequences above and check the addend
-      // is compatible to onednn-matmul fusion.
-      HloInstruction* addend = nullptr;
-      HloInstruction* optional_addend_broadcast = nullptr;
-      auto addend_pattern = m::AnyOf<HloInstruction>(
-          m::Broadcast(&optional_addend_broadcast,
-                       m::Convert(&addend, m::Op())),
-          m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))),
-          m::Convert(&addend, m::Op()),
-          m::Broadcast(&optional_addend_broadcast, m::Op(&addend)),
-          m::Op(&addend));
-      if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus();
-
-      if (optional_addend_broadcast && addend->shape().rank() != 1) {
-        auto new_shape =
-            AdjustBiasShape(optional_addend_broadcast, dot->shape());
-        if (new_shape.ok()) {
-          addend = addend->AddInstruction(
-              HloInstruction::CreateBitcast(new_shape.value(), addend));
-        } else {
-          VLOG(2) << new_shape.status();
-          return absl::OkStatus();
-        }
-      }
-
-      // Validate addend for fusion.
-      if (IsSupportedType(addend->shape().element_type()) &&
-          IsOperandFusible(addend, dot)) {
-        new_operands.push_back(addend);
-      } else {
-        return absl::OkStatus();
-      }
-
-      // TODO(intel-tf): Remove this restriction once oneDNN has an optimized
-      // implementation for broadcasted add across all dimensions.
-      OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED;
-      kind = (addend->shape().rank() == 1)
-                 ? (dot->backend_config<BackendConfig>()
-                            ->mutable_onednn_matmul_config()
-                            ->fusions()
-                            .ops()
-                            .empty()
-                        ? OneDnnFusionConfig::BIAS
-                        : OneDnnFusionConfig::UNDEFINED)
-                 : OneDnnFusionConfig::BINARY_ADD;
-      if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus();
-
-      auto matmul_call = Cast<HloCustomCallInstruction>(instr->AddInstruction(
-          dot->CloneWithNewOperands(dot->shape(), new_operands)));
-
-      auto backend_config = matmul_call->backend_config<BackendConfig>();
-      backend_config->mutable_onednn_matmul_config()
-          ->mutable_fusions()
-          ->add_ops(kind);
-
-      if (optional_addend_broadcast) {
-        backend_config->mutable_onednn_matmul_config()
-            ->mutable_optimization_config()
-            ->set_bias_broadcast(true);
-      }
-      TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config));
-
-      HloInstruction* new_instr;
-      // If matched pattern has custom-call -> bitcast -> add, then we need to
-      // insert bitcast after the new fusion to maintain the correct shape
-      // (new-custom-call -> bitcast). Also, this will optionally be followed
-      // by -> convert for bf16 case to avoid datatype mismatch.
-      if (optional_dot_bitcast != nullptr &&
-          optional_dot_bitcast->opcode() == HloOpcode::kBitcast) {
-        if (optional_dot_convert != nullptr &&
-            optional_dot_convert->opcode() == HloOpcode::kConvert) {
-          auto bitcast_call =
-              matmul_call->AddInstruction(HloInstruction::CreateBitcast(
-                  ShapeUtil::ChangeElementType(
-                      instr->shape(), matmul_call->shape().element_type()),
-                  matmul_call));
-          new_instr =
-              bitcast_call->AddInstruction(HloInstruction::CreateConvert(
-                  ShapeUtil::ChangeElementType(
-                      bitcast_call->shape(),
-                      optional_dot_convert->shape().element_type()),
-                  bitcast_call));
-        } else {
-          new_instr = matmul_call->AddInstruction(
-              HloInstruction::CreateBitcast(instr->shape(), matmul_call));
-        }
-      } else {
-        if (optional_dot_convert != nullptr &&
-            optional_dot_convert->opcode() == HloOpcode::kConvert) {
-          new_instr = matmul_call->AddInstruction(HloInstruction::CreateConvert(
-              ShapeUtil::ChangeElementType(
-                  matmul_call->shape(),
-                  optional_dot_convert->shape().element_type()),
-              matmul_call));
-        } else {
-          new_instr = matmul_call;
-        }
-      }
-      TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr));
-    }
-
-    return absl::OkStatus();
-  }
-
-  absl::Status HandleMaximum(HloInstruction* instr) override {
-    HloInstruction* matmul_call;
-    HloInstruction* intermediate_instr = nullptr;
-    HloInstruction* optional_bitcast = nullptr;
-    // Attempt to elide maximum and fuse ReLU activation into GEMM, including
-    // when slicing or bitcasting is applied to the result.
-    if (Match(instr,
-              m::MaximumAnyOrder(ElementwiseSafeIntermediates(
-                                     &intermediate_instr, &optional_bitcast,
-                                     OneDnnMatmulInstr(&matmul_call))
-                                     .WithOneUser(),
-                                 BcastConstScalar(0)))) {
-      return FuseActivation(OneDnnFusionConfig::RELU, instr, matmul_call,
-                            intermediate_instr, optional_bitcast);
-    }
-    return absl::OkStatus();
-  }
-
-  auto ELUActivation(HloInstruction* instr, HloInstruction** src) {
-    //  Reference: tensorflow/compiler/tf2xla/kernels/elu_op.cc
-    //  const auto zero = ScalarLike(x, 0);
-    //  const auto pred = Gt(x, zero);
-    //  const auto expm1 = Expm1(x);
-    //  return Select(pred, x, expm1);
-    auto pattern = m::Select(
-        m::Gt(pu::OptionalConvert(m::Op(src)), BcastConvertConstScalar(0)),
-        m::Op(src),
-        pu::OptionalConvert(m::Expm1(pu::OptionalConvert(m::Op(src)))));
-    return Match(instr, pattern);
-  }
-
-  absl::Status HandleSelect(HloInstruction* instr) override {
-    HloInstruction* matmul_call;
-    HloInstruction* intermediate_instr = nullptr;
-    HloInstruction* optional_bitcast = nullptr;
-    HloInstruction* src;
-    // Attempt to elide ELU subgraph and fuse ELU activation into GEMM,
-    // including when slicing or bitcasting is applied to the result.
-    if (ELUActivation(instr, &src)) {
-      if (Match(src, ElementwiseSafeIntermediates(
-                         &intermediate_instr, &optional_bitcast,
-                         OneDnnMatmulInstr(&matmul_call)))) {
-        return FuseActivation(OneDnnFusionConfig::ELU, instr, matmul_call,
-                              intermediate_instr);
-      }
-    }
-    return absl::OkStatus();
-  }
-
-  absl::Status HandleTanh(HloInstruction* instr) override {
-    HloInstruction* matmul_call;
-    HloInstruction* intermediate_instr = nullptr;
-    HloInstruction* optional_bitcast = nullptr;
-    // Attempt to elide Tanh and fuse Tanh activation into GEMM, including
-    // when slicing or bitcasting is applied to the result.
-    if (Match(instr, m::Tanh(ElementwiseSafeIntermediates(
-                                 &intermediate_instr, &optional_bitcast,
-                                 OneDnnMatmulInstr(&matmul_call))
-                                 .WithOneUser()))) {
-      return FuseActivation(OneDnnFusionConfig::TANH, instr, matmul_call,
-                            intermediate_instr);
-    }
-    return absl::OkStatus();
-  }
-
-  absl::Status HandleClamp(HloInstruction* instr) override {
-    HloInstruction* matmul_call;
-    HloInstruction* intermediate_instr = nullptr;
-    HloInstruction* optional_bitcast = nullptr;
-    // Attempt to elide RELU6 and fuse RELU6 activation into GEMM, including
-    // when slicing or bitcasting is applied to the result.
-    if (Match(instr, m::Clamp(BcastConstScalar(0),
-                              ElementwiseSafeIntermediates(
-                                  &intermediate_instr, &optional_bitcast,
-                                  OneDnnMatmulInstr(&matmul_call))
-                                  .WithOneUser(),
-                              BcastConstScalar(6)))) {
-      return FuseActivation(OneDnnFusionConfig::RELU6, instr, matmul_call,
-                            intermediate_instr);
-    }
-    return absl::OkStatus();
-  }
-
-  absl::Status HandleMultiply(HloInstruction* instr) override {
-    HloInstruction* matmul_call;
-    HloInstruction* intermediate_instr = nullptr;
-    HloInstruction* src;
-    auto activation = GELUActivation(instr, &src);
-    if (activation != OneDnnFusionConfig::UNDEFINED) {
-      HloInstruction* optional_bitcast = nullptr;
-      if (Match(src, ElementwiseSafeIntermediates(
-                         &intermediate_instr, &optional_bitcast,
-                         OneDnnMatmulInstr(&matmul_call)))) {
-        return FuseActivation(activation, instr, matmul_call,
-                              intermediate_instr, optional_bitcast);
-      }
-    }
-
-    HloInstruction *dot, *constant;
-    HloInstruction* optional_convert = nullptr;
-    auto pattern = m::Op(&instr)
-                       .WithOpcode(HloOpcode::kMultiply)
-                       .WithBinaryOperandsAnyOrder(
-                           m::AnyOf<HloInstruction>(
-                               pu::SupportedConvert(&optional_convert,
-                                                    OneDnnMatmulInstr(&dot))
-                                   .WithElementType(PrimitiveType::F32),
-                               OneDnnMatmulInstr(&dot))
-                               .WithOneUser(),
-                           m::Broadcast(m::Constant(&constant)));
-
-    if (Match(instr, pattern)) {
-      std::vector<HloInstruction*> new_operands;
-      auto constant_value = GetConstantValueAsFloat32(constant);
-      if (!constant_value) {
-        return absl::OkStatus();
-      }
-
-      for (auto operand : dot->operands()) {
-        new_operands.push_back(operand);
-      }
-      auto matmul_call = Cast<HloCustomCallInstruction>(instr->AddInstruction(
-          dot->CloneWithNewOperands(instr->shape(), new_operands)));
-      auto backend_config = matmul_call->backend_config<BackendConfig>();
-      backend_config->mutable_onednn_matmul_config()
-          ->mutable_fusions()
-          ->add_ops(OneDnnFusionConfig::LINEAR);
-      // Casting to int32 because of issues in proto config for decimal types
-      // handling.
-      backend_config->mutable_onednn_matmul_config()
-          ->mutable_fusions()
-          ->set_alpha_typecast(
-              *(reinterpret_cast<int32_t*>(&constant_value.value())));
-      TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config));
-      HloInstruction* new_instr;
-      if (optional_convert != nullptr &&
-          optional_convert->opcode() == HloOpcode::kConvert) {
-        new_instr = matmul_call->AddInstruction(HloInstruction::CreateConvert(
-            ShapeUtil::ChangeElementType(
-                matmul_call->shape(), optional_convert->shape().element_type()),
-            matmul_call));
-      } else {
-        new_instr = matmul_call;
-      }
-
-      TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr));
-    }
-    return absl::OkStatus();
-  }
-
-  auto SigmoidActivation(HloInstruction* instr, HloInstruction** src) {
-    return Match(instr,
-                 m::Divide(BcastConstScalar(1.0),
-                           m::AddAnyOrder(BcastConstScalar(1.0),
-                                          m::Exp(m::Negate(m::Op(src))))));
-  }
-
-  absl::Status HandleDivide(HloInstruction* instr) override {
-    HloInstruction* matmul_call;
-    HloInstruction* intermediate_instr = nullptr;
-    HloInstruction* optional_bitcast = nullptr;
-    HloInstruction* src;
-    if (SigmoidActivation(instr, &src)) {
-      if (Match(src, ElementwiseSafeIntermediates(
-                         &intermediate_instr, &optional_bitcast,
-                         OneDnnMatmulInstr(&matmul_call))
-                         .WithOneUser())) {
-        return FuseActivation(OneDnnFusionConfig::SIGMOID, instr, matmul_call,
-                              intermediate_instr, optional_bitcast);
-      }
-    }
-    return absl::OkStatus();
-  }
-
-  absl::Status FuseActivation(OneDnnFusionConfig_FusionKind kind,
-                              HloInstruction* activation,
-                              HloInstruction* matmul,
-                              HloInstruction* intermediate_instr = nullptr,
-                              HloInstruction* optional_bitcast = nullptr) {
-    TF_ASSIGN_OR_RETURN(auto backend_config,
-                        matmul->backend_config<BackendConfig>());
-    auto* matmul_config = backend_config.mutable_onednn_matmul_config();
-    matmul_config->mutable_fusions()->add_ops(kind);
-    TF_RETURN_IF_ERROR(matmul->set_backend_config(backend_config));
-    std::unique_ptr<HloInstruction> output = matmul->Clone();
-    if (optional_bitcast != nullptr &&
-        optional_bitcast->opcode() == HloOpcode::kBitcast) {
-      HloInstruction* new_instr = nullptr;
-      if (intermediate_instr != nullptr &&
-          intermediate_instr->opcode() == HloOpcode::kConvert) {
-        auto bitcast_call =
-            matmul->AddInstruction(HloInstruction::CreateBitcast(
-                ShapeUtil::ChangeElementType(optional_bitcast->shape(),
-                                             matmul->shape().element_type()),
-                matmul));
-        new_instr = bitcast_call->AddInstruction(HloInstruction::CreateConvert(
-            ShapeUtil::ChangeElementType(
-                bitcast_call->shape(),
-                intermediate_instr->shape().element_type()),
-            bitcast_call));
-        return ReplaceInstruction(activation, new_instr);
-      }
-    } else if (intermediate_instr) {
-      output = intermediate_instr->CloneWithNewOperands(
-          intermediate_instr->shape(),
-          {matmul->parent()->AddInstruction(std::move(output))});
-    }
-
-    return ReplaceWithNewInstruction(activation, std::move(output));
-  }
-
-  // This function changes dot instruction for supported matrix
-  // multiplication scenarios. In particular, it changes the shape
-  // of lhs, rhs and result arrays.
-  //    - lhs configuration scenario
-  //      lhs:    [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim]
-  //      result: [batch_dims,feature_dim] to [batch_dims,1,feature_dim]
-  //
-  //    - rhs configuration scenario
-  //      rhs:    [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1]
-  //      result: [batch_dims,feature_dim] to [batch_dims,feature_dim, 1]
-  //
-  //    - both lhs and rhs configuration scenario
-  //      lhs:    [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim]
-  //      rhs:    [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1]
-  //      result: [batch_dims] to [batch_dims,1,1]
-  absl::StatusOr<HloInstruction*> ReconfigureDotDimensions(
-      HloInstruction* dot_instr) {
-    HloInstruction* lhs = dot_instr->mutable_operand(0);
-    HloInstruction* rhs = dot_instr->mutable_operand(1);
-    DotDimensionNumbers dim_numbers = dot_instr->dot_dimension_numbers();
-
-    auto lhs_batch_dims = dim_numbers.lhs_batch_dimensions();
-    auto lhs_contraction_dims = dim_numbers.lhs_contracting_dimensions();
-    bool is_lhs_vector = lhs->shape().rank() ==
-                         (lhs_batch_dims.size() + lhs_contraction_dims.size());
-
-    auto rhs_batch_dims = dim_numbers.rhs_batch_dimensions();
-    auto rhs_contraction_dims = dim_numbers.rhs_contracting_dimensions();
-    bool is_rhs_vector = rhs->shape().rank() ==
-                         (rhs_batch_dims.size() + rhs_contraction_dims.size());
-
-    if (!is_lhs_vector && !is_rhs_vector) return dot_instr;
-
-    std::vector<int64_t> adjusted_lhs_dims(lhs->shape().dimensions().begin(),
-                                           lhs->shape().dimensions().end());
-    std::vector<int64_t> adjusted_rhs_dims(rhs->shape().dimensions().begin(),
-                                           rhs->shape().dimensions().end());
-    std::vector<int64_t> adjusted_dot_dims(
-        dot_instr->shape().dimensions().begin(),
-        dot_instr->shape().dimensions().end());
-
-    if (is_lhs_vector) {
-      auto lhs_it = adjusted_lhs_dims.begin() + lhs_batch_dims.size();
-      adjusted_lhs_dims.insert(lhs_it, 1, 1);
-      auto result_it = adjusted_dot_dims.begin() + lhs_batch_dims.size();
-      adjusted_dot_dims.insert(result_it, 1, 1);
-      auto lhs_contraction_dim =
-          dot_instr->dot_dimension_numbers().lhs_contracting_dimensions(0);
-      dim_numbers.set_lhs_contracting_dimensions(0, lhs_contraction_dim + 1);
-      lhs = lhs->AddInstruction(HloInstruction::CreateBitcast(
-          ShapeUtil::MakeShape(lhs->shape().element_type(), adjusted_lhs_dims),
-          lhs));
-    }
-
-    if (is_rhs_vector) {
-      auto it = adjusted_rhs_dims.end();
-      adjusted_rhs_dims.insert(it, 1, 1);
-      auto result_it = adjusted_dot_dims.end();
-      adjusted_dot_dims.insert(result_it, 1, 1);
-      rhs = rhs->AddInstruction(HloInstruction::CreateBitcast(
-          ShapeUtil::MakeShape(rhs->shape().element_type(), adjusted_rhs_dims),
-          rhs));
-    }
-
-    HloInstruction* adjusted_dot =
-        dot_instr->AddInstruction(HloInstruction::CreateDot(
-            ShapeUtil::MakeShape(dot_instr->shape().element_type(),
-                                 adjusted_dot_dims),
-            lhs, rhs, dim_numbers, dot_instr->precision_config()));
-
-    HloInstruction* replacement_instr = adjusted_dot->AddInstruction(
-        HloInstruction::CreateBitcast(dot_instr->shape(), adjusted_dot));
-
-    TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, replacement_instr));
-    return adjusted_dot;
-  }
-};
-
-class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor {
- public:
-  OneDnnPostRewriteVisitor(int intra_op_parallelism,
-                           const tsl::thread::ThreadPool* compile_threadpool)
-      : intra_op_parallelism_(intra_op_parallelism > 0
-                                  ? intra_op_parallelism
-                                  : tsl::port::MaxParallelism()),
-        evaluator_(/*max_loop_iterations=*/0) {
-    if (compile_threadpool) {
-      threadpool_device_.reset(
-          new Eigen::ThreadPoolDevice(compile_threadpool->AsEigenThreadPool(),
-                                      compile_threadpool->NumThreads()));
-    } else {
-      threadpool_handle_.reset(new tsl::thread::ThreadPool(
-          tsl::Env::Default(), "XLACpuCompile", tsl::port::MaxParallelism()));
-      threadpool_device_.reset(
-          new Eigen::ThreadPoolDevice(threadpool_handle_->AsEigenThreadPool(),
-                                      threadpool_handle_->NumThreads()));
-    }
-
-#ifndef ENABLE_ONEDNN_OPENMP
-    // Set oneDNN concurrency settings (which is thread-local)
-    tsl::OneDnnThreadPool::set_onednn_max_threads(intra_op_parallelism_);
-#endif
-  }
-
-  absl::Status HandleCustomCall(HloInstruction* custom_call) override {
-    HloInstruction* matmul;
-    if (Match(custom_call, OneDnnMatmulInstr(&matmul))) {
-      return HandleCustomCallInternal<dnnl::matmul::primitive_desc>(
-          custom_call);
-    }
-
-    return DefaultAction(custom_call);
-  }
-
-  template <typename PrimDesc>
-  absl::Status HandleCustomCallInternal(HloInstruction* custom_call) {
-    auto scratch_add = AddScratch<PrimDesc>(custom_call);
-    if (scratch_add.ok()) {
-      custom_call = *scratch_add;
-    } else {
-      VLOG(2) << scratch_add.status();
-    }
-    auto weights_prepack = PrepackWeights<PrimDesc>(custom_call);
-    if (!weights_prepack.ok()) {
-      VLOG(2) << weights_prepack.status();
-    }
-    return absl::OkStatus();
-  }
-
-  template <typename>
-  absl::Status SetWeightsPrepack(HloInstruction*, bool);
-
-  template <typename>
-  absl::Status SetUserScratch(HloInstruction*, bool);
-
-  template <typename>
-  bool GetWeightsPrepack(HloInstruction*);
-
-  template <typename>
-  bool GetUserScratch(HloInstruction*);
-
-  // Add scratch for matmul by changing the result of custom-call to
-  // tuple(result, scratch)
-  template <typename PrimDesc>
-  absl::StatusOr<HloInstruction*> AddScratch(HloInstruction* custom_call) {
-    if (GetUserScratch<PrimDesc>(custom_call)) {
-      return custom_call;
-    }
-    TF_RETURN_IF_ERROR(SetUserScratch<PrimDesc>(custom_call, true));
-    auto prim_desc = CreateOneDnnPrimDesc<PrimDesc>(custom_call);
-    int64_t scratch_size = prim_desc->scratchpad_desc().get_size();
-    Shape scratch_shape = ShapeUtil::MakeShape(U8, {scratch_size});
-    Shape tuple_shape =
-        ShapeUtil::MakeTupleShape({custom_call->shape(), scratch_shape});
-    auto new_custom_call = custom_call->AddInstruction(
-        custom_call->CloneWithNewShape(tuple_shape));
-    HloInstruction* gte =
-        new_custom_call->AddInstruction(HloInstruction::CreateGetTupleElement(
-            custom_call->shape(), new_custom_call, 0));
-    auto status = ReplaceInstruction(custom_call, gte);
-    if (!status.ok()) {
-      TF_RETURN_IF_ERROR(SetUserScratch<PrimDesc>(custom_call, false));
-      return absl::CancelledError("Adding scratch is unsuccessful.");
-    }
-    return new_custom_call;
-  }
-
-  template <typename PrimDesc>
-  absl::StatusOr<HloInstruction*> PrepackWeights(HloInstruction* custom_call) {
-    if (GetWeightsPrepack<PrimDesc>(custom_call)) {
-      return custom_call;
-    }
-    auto weights = custom_call->operand(1);
-    auto weights_shape = weights->shape();
-    Literal weights_literal;
-    if (!(weights_shape.rank() == 2 &&
-          evaluator_.TryEvaluate(weights, &weights_literal, true))) {
-      return absl::CancelledError(
-          "Cannot prepack weights. Not constant 2D weights.");
-    }
-    auto plain_weights_md = ShapeToMemDesc(weights_shape);
-    if constexpr (std::is_same<PrimDesc, dnnl::matmul::primitive_desc>::value) {
-      TF_ASSIGN_OR_RETURN(auto backend_config,
-                          custom_call->backend_config<BackendConfig>());
-      TRANSPOSE_LAST_TWO_DIMS_IF(
-          backend_config.onednn_matmul_config().transpose_b(),
-          plain_weights_md);
-    }
-    TF_RETURN_IF_ERROR(SetWeightsPrepack<PrimDesc>(custom_call, true));
-    auto prim_desc = CreateOneDnnPrimDesc<PrimDesc>(custom_call);
-    auto packed_weights_md = prim_desc->weights_desc();
-    auto packed_weights_shape = MemDescToXlaShapeFlattened(packed_weights_md);
-    auto packed_weights_literal = Literal(packed_weights_shape);
-    ReorderWeight(plain_weights_md, weights_literal.untyped_data(),
-                  packed_weights_md, packed_weights_literal.untyped_data());
-    HloInstruction* reordered_weight = custom_call->AddInstruction(
-        HloInstruction::CreateConstant(std::move(packed_weights_literal)));
-    auto status =
-        custom_call->ReplaceOperandWithDifferentShape(1, reordered_weight);
-    if (!status.ok()) {
-      TF_RETURN_IF_ERROR(SetWeightsPrepack<PrimDesc>(custom_call, false));
-      return absl::CancelledError(
-          "Cannot replace plain weights with prepacked weights.");
-    } else {
-      return custom_call;
-    }
-  }
-
-  void ReorderWeight(const dnnl::memory::desc& src_md, void* src_buf,
-                     const dnnl::memory::desc& dst_md, void* dst_buf) {
-    auto onednn_threadpool = CreateOneDnnThreadPool(threadpool_device_.get());
-    dnnl::engine cpu_engine(dnnl::engine::kind::cpu, 0);
-    auto onednn_stream = MakeOneDnnStream(cpu_engine, onednn_threadpool.get());
-    auto src_mem = dnnl::memory(src_md, cpu_engine, src_buf);
-    auto dst_mem = dnnl::memory(dst_md, cpu_engine, dst_buf);
-    dnnl::reorder reorder_prim{src_mem, dst_mem};
-    reorder_prim.execute(onednn_stream, src_mem, dst_mem);
-    onednn_stream.wait();
-  }
-
- private:
-  int intra_op_parallelism_;
-  HloEvaluator evaluator_;
-  std::unique_ptr<tsl::thread::ThreadPool> threadpool_handle_;
-  std::unique_ptr<Eigen::ThreadPoolDevice> threadpool_device_;
-};
-
-#define EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GETTER, PRIM_DESC, CONFIG,      \
-                                               SUB_CONFIG, FIELD)              \
-  template <>                                                                  \
-  inline bool OneDnnPostRewriteVisitor::GETTER<PRIM_DESC>(HloInstruction *     \
-                                                          custom_call) {       \
-    auto backend_config = custom_call->backend_config<BackendConfig>();        \
-    return backend_config.ok() ? backend_config->CONFIG().SUB_CONFIG().FIELD() \
-                               : false;                                        \
-  }
-
-EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GetUserScratch,
-                                       dnnl::matmul::primitive_desc,
-                                       onednn_matmul_config,
-                                       optimization_config, user_scratchpad);
-EMIT_GET_BACKEND_CONFIG_SPECIALIZATION(GetWeightsPrepack,
-                                       dnnl::matmul::primitive_desc,
-                                       onednn_matmul_config,
-                                       optimization_config, weights_prepacked);
-
-#define EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SETTER, PRIM_DESC, CONFIG_TYPE, \
-                                               CONFIG, SUB_CONFIG, FIELD)      \
-  template <>                                                                  \
-  inline absl::Status OneDnnPostRewriteVisitor::SETTER<PRIM_DESC>(             \
-      HloInstruction * custom_call, bool value) {                              \
-    TF_ASSIGN_OR_RETURN(auto backend_config,                                   \
-                        custom_call->backend_config<BackendConfig>());         \
-    CONFIG_TYPE* config = backend_config.mutable_##CONFIG();                   \
-    config->mutable_##SUB_CONFIG()->set_##FIELD(value);                        \
-    return custom_call->set_backend_config(backend_config);                    \
-  }
-
-EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetWeightsPrepack,
-                                       dnnl::matmul::primitive_desc,
-                                       OneDnnMatMulConfig, onednn_matmul_config,
-                                       optimization_config, weights_prepacked);
-EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetUserScratch,
-                                       dnnl::matmul::primitive_desc,
-                                       OneDnnMatMulConfig, onednn_matmul_config,
-                                       optimization_config, user_scratchpad);
-
-absl::StatusOr<bool> OneDnnMatMulRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  OneDnnMatMulRewriteVisitor visitor;
-  TF_ASSIGN_OR_RETURN(auto result,
-                      visitor.RunOnModule(module, execution_threads));
-
-  OneDnnPostRewriteVisitor reorder_visitor(intra_op_parallelism_,
-                                           compile_threadpool_);
-  TF_ASSIGN_OR_RETURN(auto result2,
-                      reorder_visitor.RunOnModule(module, execution_threads));
-
-  return {result || result2};
-}
-
-}  // namespace cpu
-}  // namespace xla
-
-#endif  // INTEL_MKL && ENABLE_ONEDNN_V3
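For context on the rewrites being moved out of this file: the deleted `ReconfigureDotDimensions` turns vector operands of a dot into unit-dimension matrices (and bumps the lhs contracting dimension by one) before emitting the oneDNN matmul. A minimal standalone sketch of just the dimension bookkeeping; the helper names below are illustrative, not part of XLA:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Insert a unit dimension so a vector operand becomes a rank-3 array that a
// batched matmul can consume, mirroring the adjustments described above.
std::vector<int64_t> AdjustLhsDims(std::vector<int64_t> dims,
                                   size_t num_batch_dims) {
  // [batch_dims, contracting_dim] -> [batch_dims, 1, contracting_dim]
  dims.insert(dims.begin() + num_batch_dims, 1);
  return dims;
}

std::vector<int64_t> AdjustRhsDims(std::vector<int64_t> dims) {
  // [batch_dims, contracting_dim] -> [batch_dims, contracting_dim, 1]
  dims.push_back(1);
  return dims;
}

int main() {
  // lhs with one batch dim (8) and a contracting dim (32).
  for (int64_t d : AdjustLhsDims({8, 32}, /*num_batch_dims=*/1)) {
    std::cout << d << ' ';  // prints: 8 1 32
  }
  std::cout << '\n';
  for (int64_t d : AdjustRhsDims({8, 32})) {
    std::cout << d << ' ';  // prints: 8 32 1
  }
  std::cout << '\n';
  return 0;
}
```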
diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h
deleted file mode 100644
index 7ad7f76..0000000
--- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_CPU_ONEDNN_MATMUL_REWRITER_H_
-#define XLA_SERVICE_CPU_ONEDNN_MATMUL_REWRITER_H_
-#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
-
-#include <optional>
-
-#include "absl/algorithm/container.h"
-#include "unsupported/Eigen/CXX11/Tensor"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "tsl/platform/threadpool.h"
-
-namespace xla {
-namespace cpu {
-
-// This pass pattern-matches HLO Dot instructions and rewrites into custom
-// calls.
-class OneDnnMatMulRewriter : public HloModulePass {
- public:
-  OneDnnMatMulRewriter(int intra_op_parallelism,
-                       const tsl::thread::ThreadPool* compile_threadpool)
-      : intra_op_parallelism_(intra_op_parallelism),
-        compile_threadpool_(compile_threadpool) {}
-  OneDnnMatMulRewriter() = default;
-  absl::string_view name() const override { return "onednn-matmul-rewriter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
-  static bool ShouldRewrite(const HloInstruction* dot_instr);
-
- private:
-  int intra_op_parallelism_;
-  const tsl::thread::ThreadPool* compile_threadpool_;
-};
-
-}  // namespace cpu
-}  // namespace xla
-
-#endif  // INTEL_MKL && ENABLE_ONEDNN_V3
-#endif  // XLA_SERVICE_CPU_ONEDNN_MATMUL_REWRITER_H_
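The deleted header declared `OneDnnMatMulRewriter` as an `HloModulePass`; the BUILD changes later in this diff point tests at `onednn_contraction_rewriter` instead. For orientation, a minimal sketch of the `HloModulePass` shape such a rewriter follows; the class name below is illustrative and not the actual replacement:

```cpp
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo_pass_interface.h"

namespace xla::cpu {

// Illustrative pass skeleton: visit the module, rewrite matched instructions,
// and report whether anything changed.
class ExampleRewriter : public HloModulePass {
 public:
  absl::string_view name() const override { return "example-rewriter"; }

  using HloPassInterface::Run;
  absl::StatusOr<bool> Run(HloModule* module,
                           const absl::flat_hash_set<absl::string_view>&
                               execution_threads) override {
    bool changed = false;
    // A real rewriter would walk computations here, e.g. with a
    // DfsHloRewriteVisitor, and replace matched instructions.
    return changed;
  }
};

}  // namespace xla::cpu
```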
diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD
index 9e5bd2e..a7e97b0 100644
--- a/third_party/xla/xla/service/cpu/runtime/BUILD
+++ b/third_party/xla/xla/service/cpu/runtime/BUILD
@@ -160,6 +160,7 @@
         "//xla/stream_executor:device_memory",
         "//xla/tsl/concurrency:async_value",
         "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc b/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc
index a0cd9f4..32a452a 100644
--- a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc
+++ b/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc
@@ -205,10 +205,6 @@
   return op_buffers_.source_shapes[index];
 }
 
-absl::Span<const Shape> CollectiveThunk::source_shapes() const {
-  return op_buffers_.source_shapes;
-}
-
 const BufferAllocation::Slice& CollectiveThunk::destination_buffer(
     int64_t index) const {
   return op_buffers_.destination_buffers[index];
@@ -223,8 +219,4 @@
   return op_buffers_.destination_shapes[index];
 }
 
-absl::Span<const Shape> CollectiveThunk::destination_shapes() const {
-  return op_buffers_.destination_shapes;
-}
-
 }  // namespace xla::cpu
diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h b/third_party/xla/xla/service/cpu/runtime/collective_thunk.h
index 5bcf16b..5ae9c98 100644
--- a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h
+++ b/third_party/xla/xla/service/cpu/runtime/collective_thunk.h
@@ -77,7 +77,6 @@
                   OpBuffers op_buffers, OpResources op_resources);
 
   const OpParams& op_params() const { return op_params_; }
-  const OpBuffers& op_buffers() const { return op_buffers_; }
 
   // Resolves operation's device memory from the buffers and buffer allocations.
   absl::StatusOr<OpDeviceMemory> GetOpDeviceMemory(const ExecuteParams& params);
@@ -109,13 +108,11 @@
   absl::Span<const BufferAllocation::Slice> source_buffers() const;
 
   const Shape& source_shape(int64_t index) const;
-  absl::Span<const Shape> source_shapes() const;
 
   const BufferAllocation::Slice& destination_buffer(int64_t index) const;
   absl::Span<const BufferAllocation::Slice> destination_buffers() const;
 
   const Shape& destination_shape(int64_t index) const;
-  absl::Span<const Shape> destination_shapes() const;
 
  private:
   OpParams op_params_;
diff --git a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc
index 1161673..8c6deca 100644
--- a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc
+++ b/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc
@@ -196,11 +196,11 @@
 
   // Forward ExecutableRunOptions to the FFI handlers via the call options.
   CustomCallExecuteParams* custom_call_params = params.custom_call_params;
-  ffi::CallOptions call_options = {custom_call_params->device_ordinal,
-                                   custom_call_params->stream,
-                                   custom_call_params->allocator,
-                                   /*called_computation=*/nullptr,
-                                   custom_call_params->ffi_execution_context};
+  ffi::CallOptions call_options = {
+      custom_call_params->device_ordinal,
+      ffi::CallOptions::CpuOptions{custom_call_params->intra_op_thread_pool},
+      /*called_computation=*/nullptr,
+      custom_call_params->ffi_execution_context};
 
   // Call the function and check execution status.
   auto status = ffi::Call(handler->bundle.execute, call_frame, call_options);
diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc
index 50b42c7..5ab801c 100644
--- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc
+++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc
@@ -200,7 +200,7 @@
 
   // TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk
   // initialization stage.
-  se::host::HostKernel* kernel = kernel_ptr_.load(std::memory_order_relaxed);
+  se::host::HostKernel* kernel = kernel_ptr_.load(std::memory_order_acquire);
 
   // Because thunks are owned by a parent CpuExecutable, we can safely assume
   // that kernel pointer will not change after we find it the first time.
@@ -209,8 +209,10 @@
                         params.function_registry->FindKernel(kernel_name_));
 
     absl::MutexLock lock(&mutex_);
-    kernel_.emplace(num_kernel_args_, kernel_fn, nullptr);
-    kernel_ptr_.store(kernel = &kernel_.value());
+    if ((kernel = kernel_ptr_.load(std::memory_order_relaxed)) == nullptr) {
+      kernel = &kernel_.emplace(num_kernel_args_, kernel_fn, nullptr);
+      kernel_ptr_.store(kernel, std::memory_order_release);
+    }
   }
 
   // Use a fast path if kernel called just once.
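The kernel-pointer caching above now follows the classic double-checked locking pattern: an acquire load on the fast path, a re-check under the mutex so only one thread constructs the kernel, and a release store to publish it. A self-contained sketch of the same pattern, generic rather than XLA-specific:

```cpp
#include <atomic>
#include <iostream>
#include <mutex>
#include <optional>

// Generic double-checked initialization: lock-free acquire load on the fast
// path; the slow path re-checks under the lock and publishes with release.
class LazyValue {
 public:
  int* Get() {
    int* value = value_ptr_.load(std::memory_order_acquire);
    if (value == nullptr) {
      std::lock_guard<std::mutex> lock(mutex_);
      // Re-check: another thread may have initialized while we waited.
      if ((value = value_ptr_.load(std::memory_order_relaxed)) == nullptr) {
        value = &value_.emplace(42);  // Construct exactly once.
        value_ptr_.store(value, std::memory_order_release);
      }
    }
    return value;
  }

 private:
  std::mutex mutex_;
  std::optional<int> value_;
  std::atomic<int*> value_ptr_{nullptr};
};

int main() {
  LazyValue lazy;
  std::cout << *lazy.Get() << "\n";  // prints 42
  return 0;
}
```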
diff --git a/third_party/xla/xla/service/cpu/runtime/sort_thunk.cc b/third_party/xla/xla/service/cpu/runtime/sort_thunk.cc
index a24a227..041bf03 100644
--- a/third_party/xla/xla/service/cpu/runtime/sort_thunk.cc
+++ b/third_party/xla/xla/service/cpu/runtime/sort_thunk.cc
@@ -471,6 +471,9 @@
       case 25:
         sort(std::integral_constant<size_t, 25>{});
         break;
+      case 29:
+        sort(std::integral_constant<size_t, 29>{});
+        break;
       default:
         return Internal("Unsupported number of sorted inputs: %d", data.size());
     }
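The switch above maps a runtime operand count to a compile-time `std::integral_constant` so the sort implementation can be instantiated for a fixed arity; the new `case 29` simply extends the supported set. A minimal sketch of that dispatch idiom:

```cpp
#include <cstddef>
#include <iostream>
#include <type_traits>

// The callable receives the size as a compile-time constant, so it can be used
// as a template argument (e.g. a fixed-arity tuple or std::array<T, kSize>).
template <typename F>
bool DispatchBySize(size_t size, F&& f) {
  switch (size) {
    case 1:
      f(std::integral_constant<size_t, 1>{});
      return true;
    case 2:
      f(std::integral_constant<size_t, 2>{});
      return true;
    case 29:
      f(std::integral_constant<size_t, 29>{});
      return true;
    default:
      return false;  // Unsupported arity, like the Internal(...) error above.
  }
}

int main() {
  DispatchBySize(29, [](auto size) {
    constexpr size_t kSize = decltype(size)::value;
    std::cout << "instantiated for " << kSize << " inputs\n";
  });
  return 0;
}
```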
diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/service/cpu/runtime/thunk.cc
index 9588b02..9228de3 100644
--- a/third_party/xla/xla/service/cpu/runtime/thunk.cc
+++ b/third_party/xla/xla/service/cpu/runtime/thunk.cc
@@ -140,18 +140,16 @@
           ? run_options->device_ordinal()
           : run_options->stream()->parent()->device_ordinal();
 
-  return CustomCallExecuteParams{device_ordinal, run_options->stream(),
-                                 run_options->allocator(),
+  return CustomCallExecuteParams{device_ordinal,
+                                 run_options->intra_op_thread_pool(),
                                  run_options->ffi_execution_context()};
 }
 
 Thunk::CustomCallExecuteParams::CustomCallExecuteParams(
-    int32_t device_ordinal, stream_executor::Stream* stream,
-    stream_executor::DeviceMemoryAllocator* allocator,
+    int32_t device_ordinal, const Eigen::ThreadPoolDevice* intra_op_thread_pool,
     const ffi::ExecutionContext* ffi_execution_context)
     : device_ordinal(device_ordinal),
-      stream(stream),
-      allocator(allocator),
+      intra_op_thread_pool(intra_op_thread_pool),
       ffi_execution_context(ffi_execution_context) {}
 
 tsl::AsyncValueRef<Thunk::ExecuteEvent> Thunk::OkExecuteEventSingleton() {
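`CustomCallExecuteParams` now forwards the intra-op `Eigen::ThreadPoolDevice` instead of a stream and allocator. Elsewhere in this diff (the deleted `OneDnnPostRewriteVisitor` constructor) the same kind of device is built from a `tsl::thread::ThreadPool`; a small sketch of that construction, assuming only the headers already used there:

```cpp
#include <iostream>

#include "unsupported/Eigen/CXX11/Tensor"
#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

int main() {
  // Build a thread pool and wrap it as an Eigen::ThreadPoolDevice, the type
  // now carried by CustomCallExecuteParams::intra_op_thread_pool.
  tsl::thread::ThreadPool pool(tsl::Env::Default(), "intra-op",
                               /*num_threads=*/4);
  Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
  std::cout << "device threads: " << device.numThreads() << "\n";
  return 0;
}
```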
diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/service/cpu/runtime/thunk.h
index 0e645f2..9141da7 100644
--- a/third_party/xla/xla/service/cpu/runtime/thunk.h
+++ b/third_party/xla/xla/service/cpu/runtime/thunk.h
@@ -208,14 +208,12 @@
         const ExecutableRunOptions* run_options);
 
     int32_t device_ordinal;
-    stream_executor::Stream* stream = nullptr;
-    stream_executor::DeviceMemoryAllocator* allocator = nullptr;
+    const Eigen::ThreadPoolDevice* intra_op_thread_pool = nullptr;
     const ffi::ExecutionContext* ffi_execution_context = nullptr;
 
    private:
     CustomCallExecuteParams(int32_t device_ordinal,
-                            stream_executor::Stream* stream,
-                            stream_executor::DeviceMemoryAllocator* allocator,
+                            const Eigen::ThreadPoolDevice* intra_op_thread_pool,
                             const ffi::ExecutionContext* ffi_execution_context);
   };
 
diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc b/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc
index 805840a..9b4c735 100644
--- a/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc
+++ b/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc
@@ -16,6 +16,7 @@
 #include "xla/service/cpu/runtime/thunk_executor.h"
 
 #include <atomic>
+#include <cstddef>
 #include <cstdint>
 #include <memory>
 #include <string>
@@ -61,7 +62,7 @@
   }
 
   // Erase redundant edges between nodes.
-  int64_t num_erased_edges = TransitiveReduction();
+  int64_t num_erased_edges = RunTransitiveReductionAndUpdatePriorities();
 
   // Check if constructed execution DAG is sequential: every node depends on the
   // completion of the previous node.
@@ -160,8 +161,14 @@
 
   // Create async execution state on heap and kick-off execution.
   auto state = std::make_unique<ExecuteState>(this, params.task_runner);
-  Execute(state.get(), params, ReadyQueue(source_.begin(), source_.end()),
-          /*lock=*/params.session.Join());
+
+  if (options_.use_priority_ready_queue) {
+    Execute(state.get(), params, PriorityReadyQueue(nodes_defs_, source_),
+            /*lock=*/params.session.Join());
+  } else {
+    Execute(state.get(), params, FifoReadyQueue(source_),
+            /*lock=*/params.session.Join());
+  }
 
   // If execution already completed (all kernels executed in the caller thread),
   // immediately return the result to avoid wasteful reference counting below.
@@ -256,13 +263,14 @@
   event.SetStateConcrete();
 }
 
+template <typename ReadyQueue>
 void ThunkExecutor::Execute(ExecuteState* state,
                             const Thunk::ExecuteParams& params,
                             ReadyQueue ready_queue,
                             Thunk::ExecuteSession::Lock lock) {
   tsl::profiler::TraceMe trace("ThunkExecutor::Execute");
 
-  DCHECK(!ready_queue.empty()) << "Ready queue must not be empty";
+  DCHECK(!ready_queue.Empty()) << "Ready queue must not be empty";
   DCHECK(lock) << "Execute session lock must be set";
 
   bool has_runner = state->runner != nullptr;
@@ -270,8 +278,8 @@
   // Threshold for splitting ready queue into separate thunk executor tasks.
   int64_t split_threshold = params.session.split_threshold();
 
-  for (int64_t i = 0; i < ready_queue.size(); ++i) {
-    NodeId id = ready_queue[i];
+  while (!ready_queue.Empty()) {
+    NodeId id = ready_queue.Pop();
     ExecuteState::Node& node = state->node(id);
 
     int64_t cnt = node.counter.load(std::memory_order_acquire);
@@ -279,9 +287,9 @@
 
     // If we have multiple ready thunks, split the ready queue and offload
     // thunks processing to the task runner.
-    int64_t num_ready_thunks = ready_queue.size() - i;
+    int64_t num_ready_thunks = ready_queue.Size();
     if (ABSL_PREDICT_FALSE(has_runner && num_ready_thunks > split_threshold)) {
-      SplitReadyQueue(state, params, /*start_index=*/i + 1, ready_queue);
+      SplitReadyQueue(state, params, ready_queue, split_threshold);
     }
 
     // Execute thunk for the given node id. If execution is aborted, we keep
@@ -307,13 +315,13 @@
       // the same execute session.
       execute_event.AndThen([&params, &node, state,
                              execute_event = execute_event.AsPtr(),
+                             ready_queue = ready_queue.CreateEmptyReadyQueue(),
                              lock = params.session.Join()]() mutable {
-        ReadyQueue ready_queue;
         state->executor->ProcessOutEdges(state, execute_event, node,
                                          ready_queue);
         // If ready queue is empty it might mean that we have completed an
         // execution and destroyed the `state`.
-        if (ABSL_PREDICT_TRUE(!ready_queue.empty())) {
+        if (ABSL_PREDICT_TRUE(!ready_queue.Empty())) {
           state->executor->Execute(state, params, std::move(ready_queue),
                                    std::move(lock));
         }
@@ -322,17 +330,17 @@
   }
 }
 
+template <typename ReadyQueue>
 inline ABSL_ATTRIBUTE_ALWAYS_INLINE void ThunkExecutor::SplitReadyQueue(
     ExecuteState* state, const Thunk::ExecuteParams& params,
-    int64_t start_index, ReadyQueue& ready_queue) {
+    ReadyQueue& ready_queue, int64_t split_threshold) {
   DCHECK(state->runner) << "TaskRunner must be set";
-  int64_t end_index = ready_queue.size();
 
   // We use recursive work splitting to push the tail of the ready queue to
   // the task runner. Recursive work splitting creates a more uniform work
   // distribution across the task runner threads and avoids a situation when
   // we have a long tail of work that is processed by a single thread.
-  while (end_index > start_index) {
+  while (ready_queue.Size() > split_threshold) {
     // Try to acquire a lock to offload ready thunks to the task runner. If
     // we can't get a lock, we will keep processing the ready queue in the
     // current thread as it means that we have enough concurrent workers
@@ -342,22 +350,16 @@
       break;
     }
 
-    // Execute [mid_index, end_index) nodes in the task runner.
-    int64_t mid_index = (start_index + end_index) / 2;
-    (*state->runner)([&params, state,
-                      ready_queue = ReadyQueue(ready_queue.begin() + mid_index,
-                                               ready_queue.begin() + end_index),
+    // Execute half of the ready queue nodes in the task runner.
+    (*state->runner)([&params, state, ready_queue = ready_queue.PopHalf(),
                       lock = std::move(task_runner_lock)]() mutable {
       state->executor->Execute(state, params, std::move(ready_queue),
                                std::move(lock));
     });
-    end_index = mid_index;
   }
-
-  // Erase ready nodes passed to the task runner.
-  ready_queue.erase(ready_queue.begin() + end_index, ready_queue.end());
 }
 
+template <typename ReadyQueue>
 void ThunkExecutor::ProcessOutEdges(
     ExecuteState* state, tsl::AsyncValuePtr<Thunk::ExecuteEvent> node_event,
     ExecuteState::Node& node, ReadyQueue& ready_queue) {
@@ -380,7 +382,7 @@
 
     int64_t cnt = out_node.counter.fetch_sub(1, std::memory_order_release);
     DCHECK_GE(cnt, 1) << "Node counter can't drop below 0";
-    if (cnt == 1) ready_queue.push_back(out_edge);
+    if (cnt == 1) ready_queue.Push(out_edge);
   }
 
   // Drop the pending sink nodes counter if the node is a sink.
@@ -431,7 +433,7 @@
   return 0;
 }
 
-int64_t ThunkExecutor::TransitiveReduction() {
+int64_t ThunkExecutor::RunTransitiveReductionAndUpdatePriorities() {
   int64_t num_erased_edges = 0;
 
   // Keep workspace for DFS traversal between iterations.
@@ -454,11 +456,11 @@
     stack.clear();
     visited.assign(nodes_defs_.size(), false);
 
-    // Initialize stack with nodes reachable via immediate out nodes. We don't
-    // need to add source node and immediate out nodes to the visited set
-    // because graph is acyclic and we don't visit them again.
+    // Initialize stack with nodes reachable via immediate out nodes. We mark
+    // immediate out nodes as visited to correctly compute node priority below.
     for (int64_t out_id : source_node.out_edges) {
       NodeDef& out_node = nodes_defs_[out_id];
+      visited[out_id] = true;
       for (int64_t start_id : out_node.out_edges) add_to_stack(start_id);
     }
 
@@ -472,6 +474,9 @@
 
       for (int64_t out_id : node.out_edges) add_to_stack(out_id);
     }
+
+    // Set node priority to the number of visited nodes in the DFS traversal.
+    source_node.priority = absl::c_count(visited, true);
   }
 
   return num_erased_edges;
@@ -495,14 +500,88 @@
     const Thunk& thunk = *thunk_sequence_[i];
     bool is_source = absl::c_find(source_, i) != source_.end();
     bool is_sink = absl::c_find(sink_, i) != sink_.end();
-    absl::StrAppendFormat(
-        &str,
-        "\n thunk #%05d: op_name=%s, dependencies=[%s], source=%v, sink=%v", i,
-        thunk.info().op_name, absl::StrJoin(in_edges[i], ", "), is_source,
-        is_sink);
+    absl::StrAppendFormat(&str,
+                          "\n thunk #%05d: op_name=%s, dependencies=[%s], "
+                          "source=%v, sink=%v, priority=%d",
+                          i, thunk.info().op_name,
+                          absl::StrJoin(in_edges[i], ", "), is_source, is_sink,
+                          nodes_defs_[i].priority);
   }
 
   return str;
 }
 
+ThunkExecutor::FifoReadyQueue::FifoReadyQueue(
+    absl::Span<const NodeId> ready_nodes)
+    : queue_(ready_nodes.begin(), ready_nodes.end()) {}
+
+void ThunkExecutor::FifoReadyQueue::Push(NodeId id) { queue_.push_back(id); }
+
+ThunkExecutor::NodeId ThunkExecutor::FifoReadyQueue::Pop() {
+  DCHECK(!Empty()) << "Queue must not be empty";
+  return queue_[head_++];
+}
+
+ThunkExecutor::FifoReadyQueue ThunkExecutor::FifoReadyQueue::PopHalf() {
+  DCHECK(!Empty()) << "Queue must not be empty";
+  auto mid = queue_.begin() + head_ + Size() / 2;
+  FifoReadyQueue popped(absl::MakeConstSpan(&*mid, queue_.end() - mid));
+  queue_.resize(mid - queue_.begin());
+  return popped;
+}
+
+size_t ThunkExecutor::FifoReadyQueue::Size() const {
+  return queue_.size() - head_;
+}
+
+bool ThunkExecutor::FifoReadyQueue::Empty() const {
+  return head_ == queue_.size();
+}
+
+ThunkExecutor::FifoReadyQueue
+ThunkExecutor::FifoReadyQueue::CreateEmptyReadyQueue() const {
+  return FifoReadyQueue(absl::Span<const NodeId>());
+}
+
+ThunkExecutor::PriorityReadyQueue::PriorityReadyQueue(
+    absl::Span<const NodeDef> nodes_defs, absl::Span<const NodeId> ready_nodes)
+    : nodes_defs_(nodes_defs),
+      queue_(ready_nodes.begin(), ready_nodes.end(), Compare{nodes_defs}) {}
+
+void ThunkExecutor::PriorityReadyQueue::Push(NodeId id) { queue_.push(id); }
+
+ThunkExecutor::NodeId ThunkExecutor::PriorityReadyQueue::Pop() {
+  DCHECK(!Empty()) << "Queue must not be empty";
+  NodeId id = queue_.top();
+  queue_.pop();
+  return id;
+}
+
+ThunkExecutor::PriorityReadyQueue ThunkExecutor::PriorityReadyQueue::PopHalf() {
+  DCHECK(!Empty()) << "Queue must not be empty";
+  int64_t keep_top_nodes = queue_.size() / 2;
+
+  // First pop nodes with highest priority from the queue.
+  PriorityReadyQueue popped(nodes_defs_, {});
+  while (keep_top_nodes-- > 0) {
+    popped.queue_.push(queue_.top());
+    queue_.pop();
+  }
+
+  // Swap popped nodes with the remaining nodes so that the caller receives the
+  // lower-priority nodes and the higher-priority nodes stay in this queue.
+  popped.queue_.swap(queue_);
+
+  return popped;
+}
+
+size_t ThunkExecutor::PriorityReadyQueue::Size() const { return queue_.size(); }
+
+bool ThunkExecutor::PriorityReadyQueue::Empty() const { return queue_.empty(); }
+
+ThunkExecutor::PriorityReadyQueue
+ThunkExecutor::PriorityReadyQueue::CreateEmptyReadyQueue() const {
+  return PriorityReadyQueue(nodes_defs_, {});
+}
+
 }  // namespace xla::cpu
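`FifoReadyQueue` above avoids erasing from the front by advancing a `head_` index, and `PopHalf` hands the back half of the pending nodes to another worker. A standalone sketch of the same idea over a plain `std::vector`, illustrative only:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// A FIFO queue that pops by advancing a head index instead of erasing, and can
// split off its back half for another worker (mirroring FifoReadyQueue).
class FifoQueue {
 public:
  void Push(int64_t id) { queue_.push_back(id); }

  int64_t Pop() { return queue_[head_++]; }

  FifoQueue PopHalf() {
    FifoQueue popped;
    size_t mid = head_ + Size() / 2;
    popped.queue_.assign(queue_.begin() + mid, queue_.end());
    queue_.resize(mid);
    return popped;
  }

  size_t Size() const { return queue_.size() - head_; }
  bool Empty() const { return head_ == queue_.size(); }

 private:
  std::vector<int64_t> queue_;
  size_t head_ = 0;
};

int main() {
  FifoQueue q;
  for (int64_t i = 1; i <= 5; ++i) q.Push(i);
  std::cout << q.Pop() << "\n";                         // 1
  FifoQueue half = q.PopHalf();                         // takes {4, 5}
  std::cout << q.Size() << " " << half.Size() << "\n";  // 2 2
  return 0;
}
```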
diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor.h b/third_party/xla/xla/service/cpu/runtime/thunk_executor.h
index 10df02c..f0df6cf 100644
--- a/third_party/xla/xla/service/cpu/runtime/thunk_executor.h
+++ b/third_party/xla/xla/service/cpu/runtime/thunk_executor.h
@@ -21,6 +21,7 @@
 #include <cstdint>
 #include <limits>
 #include <new>
+#include <queue>
 #include <string>
 #include <type_traits>
 #include <vector>
@@ -45,6 +46,10 @@
   // `execute_sequential_buffer_threshold`, we mark execution as sequential, as
   // concurrency overheads will likely dominate the overall execution time.
   size_t execute_sequential_buffer_threshold = 512;
+
+  // Use a priority ready queue to execute nodes according to their priority.
+  // By default we use a FIFO ready queue.
+  bool use_priority_ready_queue = false;
 };
 }  // namespace internal
 
@@ -72,6 +77,7 @@
   // NodeDef defines an execution order for all thunks in a sequence.
   struct NodeDef {
     NodeId id = kInvalidNodeId;
+    int64_t priority = 0;
     std::vector<NodeId> in_edges;
     std::vector<NodeId> out_edges;
   };
@@ -97,6 +103,57 @@
 
   bool is_sequential() const { return is_sequential_; }
 
+  // A ready queue that executes nodes in FIFO order.
+  class FifoReadyQueue {
+   public:
+    explicit FifoReadyQueue(absl::Span<const NodeId> ready_nodes);
+
+    void Push(NodeId id);
+
+    NodeId Pop();
+    FifoReadyQueue PopHalf();
+
+    size_t Size() const;
+    bool Empty() const;
+
+    FifoReadyQueue CreateEmptyReadyQueue() const;
+
+   private:
+    absl::InlinedVector<NodeId, 8> queue_;
+    size_t head_ = 0;
+  };
+
+  // A ready queue that executes nodes sorted by NodeDef priority.
+  class PriorityReadyQueue {
+   public:
+    PriorityReadyQueue(absl::Span<const NodeDef> nodes_defs,
+                       absl::Span<const NodeId> ready_nodes);
+
+    void Push(NodeId id);
+
+    NodeId Pop();
+    PriorityReadyQueue PopHalf();
+
+    size_t Size() const;
+    bool Empty() const;
+
+    PriorityReadyQueue CreateEmptyReadyQueue() const;
+
+   private:
+    struct Compare {
+      bool operator()(NodeId a, NodeId b) const {
+        return nodes_defs[a].priority < nodes_defs[b].priority;
+      }
+      absl::Span<const NodeDef> nodes_defs;
+    };
+
+    using InlinedPriorityQueue =
+        std::priority_queue<NodeId, absl::InlinedVector<NodeId, 8>, Compare>;
+
+    absl::Span<const NodeDef> nodes_defs_;
+    InlinedPriorityQueue queue_;
+  };
+
  private:
   // Align all atomic counters to a cache line boundary to avoid false
   // sharing between multiple worker threads.
@@ -107,8 +164,6 @@
       64;
 #endif
 
-  using ReadyQueue = absl::InlinedVector<NodeId, 8>;
-
   // A struct to keep the state of a running ThunkExecutor.
   struct ExecuteState {
     // At run time NodeDef instantiated as a Node with an atomic counter that
@@ -162,26 +217,29 @@
                                tsl::AsyncValueRef<ExecuteEvent> event);
 
   // Executes nodes in the ready queue with given thunk parameters.
+  template <typename ReadyQueue>
   void Execute(ExecuteState* state, const Thunk::ExecuteParams& params,
                ReadyQueue ready_queue, Thunk::ExecuteSession::Lock lock);
 
   // Splits the ready queue into ThunkExecutor tasks and offloads them to the
   // task runner.
+  template <typename ReadyQueue>
   void SplitReadyQueue(ExecuteState* state, const Thunk::ExecuteParams& params,
-                       int64_t start_index, ReadyQueue& ready_queue);
+                       ReadyQueue& ready_queue, int64_t split_threshold);
 
   // Processes out edges of a completed `node` and updates `ready_queue` with
   // nodes that are ready to execute. If `event` is in error state, aborts the
   // execution and records the error status to forward it to the caller.
+  template <typename ReadyQueue>
   void ProcessOutEdges(ExecuteState* state,
                        tsl::AsyncValuePtr<Thunk::ExecuteEvent> node_event,
                        ExecuteState::Node& node, ReadyQueue& ready_queue);
 
-  // Runs a transitive reduction on the NodeDef graph to remove redundant edges.
-  // Returns the number of removed edges.
+  // Runs a transitive reduction on the NodeDef graph to remove redundant edges
+  // and updates node priorities. Returns the number of removed edges.
   //
   // See: https://en.wikipedia.org/wiki/Transitive_reduction
-  int64_t TransitiveReduction();
+  int64_t RunTransitiveReductionAndUpdatePriorities();
 
   ThunkSequence thunk_sequence_;
   Options options_;
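`PriorityReadyQueue` is a `std::priority_queue` of node ids ordered by the priorities computed during transitive reduction (per the changes above, a node's priority is the number of nodes reachable from it), so nodes on deep dependency chains run first. A minimal sketch of a priority queue keyed by an external priority table, like the nested `Compare` struct:

```cpp
#include <cstdint>
#include <iostream>
#include <queue>
#include <vector>

using NodeId = int64_t;

// Comparator that looks up priorities in an external table.
struct ComparePriority {
  bool operator()(NodeId a, NodeId b) const {
    return (*priorities)[a] < (*priorities)[b];
  }
  const std::vector<int64_t>* priorities;
};

int main() {
  // priorities[i] could be, e.g., the number of nodes reachable from node i.
  std::vector<int64_t> priorities = {0, 1, 2, 3};

  std::priority_queue<NodeId, std::vector<NodeId>, ComparePriority> queue(
      ComparePriority{&priorities});
  queue.push(1);
  queue.push(3);
  queue.push(2);

  // Nodes come out highest-priority first: 3 2 1.
  while (!queue.empty()) {
    std::cout << queue.top() << " ";
    queue.pop();
  }
  std::cout << "\n";
  return 0;
}
```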
diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc b/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc
index 2bbb932..60996eb 100644
--- a/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc
+++ b/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc
@@ -26,6 +26,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/algorithm/container.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/types/span.h"
@@ -218,6 +219,123 @@
   return ThunkExecutor::Options{/*execute_sequential_buffer_threshold=*/0};
 }
 
+TEST(ThunkExecutorTest, FifoReadyQueueTest) {
+  ThunkExecutor::FifoReadyQueue queue({});
+
+  // Check basic queue properties.
+  EXPECT_TRUE(queue.Empty());
+  EXPECT_EQ(queue.Size(), 0);
+
+  queue.Push(1);
+  queue.Push(2);
+  queue.Push(3);
+
+  EXPECT_EQ(queue.Size(), 3);
+
+  EXPECT_EQ(queue.Pop(), 1);
+  EXPECT_EQ(queue.Pop(), 2);
+  EXPECT_EQ(queue.Pop(), 3);
+
+  EXPECT_TRUE(queue.Empty());
+  EXPECT_EQ(queue.Size(), 0);
+
+  // Prepare queue for PopHalf test case.
+  queue.Push(1);
+  queue.Push(2);
+  queue.Push(3);
+
+  // Pop half of the queue.
+  ThunkExecutor::FifoReadyQueue half0 = queue.PopHalf();
+  EXPECT_EQ(half0.Size(), 2);
+  EXPECT_EQ(half0.Pop(), 2);
+  EXPECT_EQ(half0.Pop(), 3);
+
+  // Check that the rest is still in the queue.
+  EXPECT_EQ(queue.Size(), 1);
+
+  // Pop the rest of the queue.
+  ThunkExecutor::FifoReadyQueue half1 = queue.PopHalf();
+  EXPECT_EQ(half1.Size(), 1);
+
+  // Check that all nodes were returned from PopHalf.
+  EXPECT_EQ(queue.Size(), 0);
+
+  // Add 5 elements to test Pop followed by PopHalf.
+  queue.Push(1);
+  queue.Push(2);
+  queue.Push(3);
+  queue.Push(4);
+  queue.Push(5);
+
+  EXPECT_EQ(queue.Pop(), 1);
+
+  // Check that PopHalf returns the last 2 nodes.
+  ThunkExecutor::FifoReadyQueue half2 = queue.PopHalf();
+  EXPECT_EQ(half2.Size(), 2);
+  EXPECT_EQ(half2.Pop(), 4);
+  EXPECT_EQ(half2.Pop(), 5);
+}
+
+TEST(ThunkExecutorTest, PriorityReadyQueueTest) {
+  std::vector<ThunkExecutor::NodeDef> nodes_defs(16);
+  for (size_t i = 0; i < nodes_defs.size(); ++i) {
+    nodes_defs[i].priority = i;
+  }
+
+  ThunkExecutor::PriorityReadyQueue queue(nodes_defs, {});
+  // Check basic queue properties.
+  EXPECT_TRUE(queue.Empty());
+  EXPECT_EQ(queue.Size(), 0);
+
+  queue.Push(1);
+  queue.Push(3);
+  queue.Push(2);
+
+  EXPECT_EQ(queue.Pop(), 3);
+  EXPECT_EQ(queue.Pop(), 2);
+  EXPECT_EQ(queue.Pop(), 1);
+
+  EXPECT_TRUE(queue.Empty());
+  EXPECT_EQ(queue.Size(), 0);
+
+  // Prepare queue for PopHalf test case.
+  queue.Push(2);
+  queue.Push(1);
+  queue.Push(3);
+
+  // Pop half of the queue.
+  ThunkExecutor::PriorityReadyQueue half0 = queue.PopHalf();
+  EXPECT_EQ(half0.Size(), 2);
+  EXPECT_EQ(half0.Pop(), 2);
+  EXPECT_EQ(half0.Pop(), 1);
+
+  // Check that the rest is still in the queue.
+  EXPECT_EQ(queue.Size(), 1);
+
+  // Pop the rest of the queue.
+  ThunkExecutor::PriorityReadyQueue half1 = queue.PopHalf();
+  EXPECT_EQ(half1.Size(), 1);
+  EXPECT_EQ(half1.Pop(), 3);
+
+  // Check that all nodes were returned from PopHalf.
+  EXPECT_EQ(queue.Size(), 0);
+
+  // Add 5 elements to test Pop followed by PopHalf.
+  queue.Push(4);
+  queue.Push(3);
+  queue.Push(5);
+  queue.Push(1);
+  queue.Push(2);
+
+  EXPECT_EQ(queue.Pop(), 5);
+
+  // Check that PopHalf returns the 2 lowest-priority nodes.
+  ThunkExecutor::PriorityReadyQueue half2 = queue.PopHalf();
+  EXPECT_EQ(half2.Size(), 2);
+  EXPECT_EQ(half2.Pop(), 2);
+  EXPECT_EQ(half2.Pop(), 1);
+}
+
 TEST(ThunkExecutorTest, DependencyOrdering) {
   BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0);
 
@@ -237,6 +355,10 @@
   EXPECT_FALSE(executor.is_sequential());
   EXPECT_THAT(executor.source(), ElementsAre(0, 1));
   EXPECT_THAT(executor.sink(), ElementsAre(2));
+
+  EXPECT_EQ(executor.node_def(0).priority, 1);
+  EXPECT_EQ(executor.node_def(1).priority, 1);
+  EXPECT_EQ(executor.node_def(2).priority, 0);
 }
 
 TEST(ThunkExecutorTest, SequentialOrdering) {
@@ -255,6 +377,10 @@
   EXPECT_TRUE(executor.is_sequential());
   EXPECT_THAT(executor.source(), ElementsAre(0));
   EXPECT_THAT(executor.sink(), ElementsAre(2));
+
+  EXPECT_EQ(executor.node_def(0).priority, 2);
+  EXPECT_EQ(executor.node_def(1).priority, 1);
+  EXPECT_EQ(executor.node_def(2).priority, 0);
 }
 
 TEST(ThunkExecutorTest, ResourceOrdering) {
@@ -278,6 +404,9 @@
   EXPECT_TRUE(executor.is_sequential());
   EXPECT_THAT(executor.source(), ElementsAre(0));
   EXPECT_THAT(executor.sink(), ElementsAre(1));
+
+  EXPECT_EQ(executor.node_def(0).priority, 1);
+  EXPECT_EQ(executor.node_def(1).priority, 0);
 }
 
 TEST(ThunkExecutorTest, TransitiveReduction) {
@@ -300,6 +429,10 @@
   EXPECT_THAT(executor.node_def(1).in_edges, ElementsAre(0));
   EXPECT_THAT(executor.node_def(1).out_edges, ElementsAre(2));
   EXPECT_THAT(executor.node_def(2).in_edges, ElementsAre(1));
+
+  EXPECT_EQ(executor.node_def(0).priority, 2);
+  EXPECT_EQ(executor.node_def(1).priority, 1);
+  EXPECT_EQ(executor.node_def(2).priority, 0);
 }
 
 TEST(ThunkExecutorTest, Execute) {
@@ -333,7 +466,7 @@
   Thunk::ExecuteParams params = {nullptr, &allocations};
   params.task_runner = &task_runner;
   params.session =
-      Thunk::ExecuteSession(/*max_workers=*/8, /*split_threshold=*/1);
+      Thunk::ExecuteSession(/*max_workers=*/8, /*split_threshold=*/0);
 
   auto execute_event = executor.Execute(params);
 
@@ -433,11 +566,11 @@
 // and optionally uses a thread pool to execute thunk executor tasks.
 class ThunkExecutorStressTest
     : public testing::TestWithParam<
-          std::tuple<int32_t, bool, bool, SharedResourceUse, bool>> {
+          std::tuple<int32_t, bool, bool, SharedResourceUse, bool, bool>> {
  public:
   void SetUp() override {
     auto& [num_thunks, use_task_runner, use_device, shared_resource_use,
-           inject_errors] = GetParam();
+           inject_errors, use_priority_ready_queue] = GetParam();
 
     use_task_runner_ = use_task_runner;
     use_device_ = use_device;
@@ -477,16 +610,21 @@
 
 TEST_P(ThunkExecutorStressTest, Execute) {
   auto [num_thunks, use_task_runner, use_device, shared_resource_use,
-        inject_errors] = GetParam();
+        inject_errors, use_priority_ready_queue] = GetParam();
 
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<GeneratedThunkSequence> g,
       GenerateThunkSequence(/*num_elements=*/1024, num_thunks,
                             shared_resource_use, inject_errors));
 
+  ThunkExecutor::Options executor_options = {
+      /*execute_sequential_buffer_threshold=*/0,
+      /*use_priority_ready_queue=*/use_priority_ready_queue,
+  };
+
   TF_ASSERT_OK_AND_ASSIGN(
       ThunkExecutor executor,
-      ThunkExecutor::Create(std::move(g->sequence), OptionsForTest()));
+      ThunkExecutor::Create(std::move(g->sequence), executor_options));
 
   BufferAllocations allocations(g->buffers);
   Thunk::ExecuteParams params = {nullptr, &allocations, nullptr, device(),
@@ -516,12 +654,95 @@
                      testing::Values(SharedResourceUse::kNo,
                                      SharedResourceUse::kAll,
                                      SharedResourceUse::kRandom),
-                     /*inject_errors=*/testing::Bool()));
+                     /*inject_errors=*/testing::Bool(),
+                     /*use_priority_ready_queue=*/testing::Bool()));
 
 //===----------------------------------------------------------------------===//
 // Performance benchmarks below
 //===----------------------------------------------------------------------===//
 
+static void BM_FifoReadyQueuePushPop(benchmark::State& state) {
+  ThunkExecutor::FifoReadyQueue queue({});
+  const size_t num_push_pop = state.range(0);
+
+  for (auto _ : state) {
+    for (int i = 0; i < num_push_pop; ++i) {
+      queue.Push(i);
+    }
+    for (int i = 0; i < num_push_pop; ++i) {
+      benchmark::DoNotOptimize(queue.Pop());
+    }
+  }
+}
+
+static void BM_FifoReadyQueuePushPopHalf(benchmark::State& state) {
+  ThunkExecutor::FifoReadyQueue queue({});
+  const size_t num_push_pop = state.range(0);
+
+  for (auto _ : state) {
+    for (int i = 0; i < num_push_pop; ++i) {
+      queue.Push(i);
+    }
+    benchmark::DoNotOptimize(queue.PopHalf());
+  }
+}
+
+static void BM_PriorityReadyQueuePushPop(benchmark::State& state) {
+  std::vector<ThunkExecutor::NodeDef> nodes_defs(16);
+  for (size_t i = 0; i < nodes_defs.size(); ++i) {
+    nodes_defs[i].priority = i;
+  }
+
+  std::default_random_engine rng;
+  absl::c_shuffle(nodes_defs, rng);
+
+  ThunkExecutor::PriorityReadyQueue queue(nodes_defs, {});
+  const size_t num_push_pop = state.range(0);
+
+  for (auto _ : state) {
+    for (int i = 0; i < num_push_pop; ++i) {
+      queue.Push(i);
+    }
+    for (int i = 0; i < num_push_pop; ++i) {
+      benchmark::DoNotOptimize(queue.Pop());
+    }
+  }
+}
+
+static void BM_PriorityReadyQueuePushPopHalf(benchmark::State& state) {
+  std::vector<ThunkExecutor::NodeDef> nodes_defs(16);
+  for (size_t i = 0; i < nodes_defs.size(); ++i) {
+    nodes_defs[i].priority = i;
+  }
+
+  std::default_random_engine rng;
+  absl::c_shuffle(nodes_defs, rng);
+
+  ThunkExecutor::PriorityReadyQueue queue(nodes_defs, {});
+  const size_t num_push_pop = state.range(0);
+
+  for (auto _ : state) {
+    for (int i = 0; i < num_push_pop; ++i) {
+      queue.Push(i);
+    }
+    benchmark::DoNotOptimize(queue.PopHalf());
+  }
+}
+
+#define BENCHMARK_READY_QUEUE(name) \
+  BENCHMARK(name)                   \
+      ->MeasureProcessCPUTime()     \
+      ->Arg(1)                      \
+      ->Arg(2)                      \
+      ->Arg(4)                      \
+      ->Arg(8)                      \
+      ->Arg(16)
+
+BENCHMARK_READY_QUEUE(BM_FifoReadyQueuePushPop);
+BENCHMARK_READY_QUEUE(BM_FifoReadyQueuePushPopHalf);
+BENCHMARK_READY_QUEUE(BM_PriorityReadyQueuePushPop);
+BENCHMARK_READY_QUEUE(BM_PriorityReadyQueuePushPopHalf);
+
 static void BM_SequentialThunkExecutor(benchmark::State& state) {
   const size_t num_thunks = state.range(0);
 
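The stress test above grows its parameter tuple by one bool so every existing combination also runs with the priority ready queue. A small self-contained sketch of the same gtest pattern, showing how an extra `testing::Bool()` in `testing::Combine` doubles the parameter space and how structured bindings unpack it; the test names below are placeholders:

```cpp
#include <tuple>

#include <gtest/gtest.h>

// Illustrative parameterized test: the tuple mirrors how an extra bool
// (here `use_priority_queue`) is threaded through GetParam().
class QueueParamTest : public testing::TestWithParam<std::tuple<int, bool>> {};

TEST_P(QueueParamTest, Runs) {
  auto [num_items, use_priority_queue] = GetParam();
  EXPECT_GE(num_items, 0);
  (void)use_priority_queue;  // Select the queue implementation in a real test.
}

INSTANTIATE_TEST_SUITE_P(
    Queues, QueueParamTest,
    testing::Combine(testing::Values(1, 16),
                     /*use_priority_queue=*/testing::Bool()));
```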
diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc
index 3c9c8d1..6b07b41 100644
--- a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc
+++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc
@@ -142,9 +142,9 @@
 
   // Forward executable run options to the FFI handlers via the call options.
   ffi::CallOptions call_options = {
-      run_options->device_ordinal(), run_options->stream(),
-      run_options->allocator(), /*called_computation=*/nullptr,
-      run_options->ffi_execution_context()};
+      run_options->device_ordinal(),
+      ffi::CallOptions::CpuOptions{run_options->intra_op_thread_pool()},
+      /*called_computation=*/nullptr, run_options->ffi_execution_context()};
 
   ffi::CallFrame call_frame = builder.Build();
   return ffi::Call(registration->bundle.execute, call_frame, call_options);
diff --git a/third_party/xla/xla/service/cpu/tests/BUILD b/third_party/xla/xla/service/cpu/tests/BUILD
index 1c8c111..be51e84 100644
--- a/third_party/xla/xla/service/cpu/tests/BUILD
+++ b/third_party/xla/xla/service/cpu/tests/BUILD
@@ -361,6 +361,7 @@
     name = "onednn_matmul_test",
     srcs = ["onednn_matmul_test.cc"],
     copts = tsl_copts(),
+    shard_count = 4,
     tags = [
         "no_oss",
         "notap",
@@ -372,7 +373,7 @@
         "//xla:test_helpers",
         "//xla/hlo/utils:hlo_matchers",
         "//xla/service:cpu_plugin",
-        "//xla/service/cpu:onednn_matmul_rewriter",
+        "//xla/service/cpu:onednn_contraction_rewriter",
         "//xla/service/cpu:onednn_util",
         "//xla/tests:filecheck",
         "//xla/tests:hlo_test_base",
@@ -393,7 +394,7 @@
         "//xla:test_helpers",
         "//xla/hlo/utils:hlo_matchers",
         "//xla/service:cpu_plugin",
-        "//xla/service/cpu:onednn_matmul_rewriter",
+        "//xla/service/cpu:onednn_contraction_rewriter",
         "//xla/service/cpu:onednn_util",
         "//xla/tests:filecheck",
         "//xla/tests:hlo_test_base",
diff --git a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc
index 50c0e8f..6bceebc 100644
--- a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc
+++ b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc
@@ -19,7 +19,7 @@
 
 #include "xla/hlo/utils/hlo_matchers.h"
 #include "xla/literal.h"
-#include "xla/service/cpu/onednn_matmul_rewriter.h"
+#include "xla/service/cpu/onednn_contraction_rewriter.h"
 #include "xla/service/cpu/onednn_util.h"
 #include "xla/shape_util.h"
 #include "xla/test.h"
diff --git a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc
index d7fb39f..57f7c09 100644
--- a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc
+++ b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc
@@ -19,7 +19,7 @@
 
 #include "xla/hlo/utils/hlo_matchers.h"
 #include "xla/literal.h"
-#include "xla/service/cpu/onednn_matmul_rewriter.h"
+#include "xla/service/cpu/onednn_contraction_rewriter.h"
 #include "xla/service/cpu/onednn_util.h"
 #include "xla/shape_util.h"
 #include "xla/test.h"
@@ -803,7 +803,9 @@
   ; CHECK:       backend_config={
   ; CHECK-DAG:     "outer_dimension_partitions":[],
   ; CHECK-DAG:     "onednn_matmul_config":{
-  ; CHECK-NOT:       "fused_ops":["LINEAR"]
+  ; CHECK-NOT:       "fusions":{
+  ; CHECK-NOT:         "ops":["LINEAR"]
+  ; CHECK-NOT:     }
   ; CHECK-DAG:   }
   ; CHECK:     }
   )");
@@ -1499,47 +1501,44 @@
   )");
 }
 
+TEST_F(MatmulTest, ColMajorBF16DotBeforeLayoutAssignment) {
+  if (!IsSupportedType(PrimitiveType::BF16)) {
+    GTEST_SKIP() << "CPU does not support BF16.";
+  }
+
+  const char* matmul_module_str = R"(
+  HloModule matmul.colmajor.test
+  ENTRY matmul.colmajor.test.bf16 {
+    arg.0 = bf16[500,500]{0,1} parameter(0)
+    arg.1 = bf16[500,500]{1,0} parameter(1)
+    transpose.0 = bf16[500,500]{0,1} transpose(arg.1), dimensions={1,0}
+    ROOT dot.0 = bf16[500,500]{1,0} dot(arg.0, arg.1), lhs_contracting_dims={1},
+      rhs_contracting_dims={0}
+  })";
+
+  EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec(1e-2, 1e-2)));
+  MatchOptimizedHlo(matmul_module_str,
+                    R"(
+  ; CHECK: (bf16[500,500]{1,0}, u8[{{.*}}]{0})
+  ; CHECK-SAME: custom_call_target="__onednn$matmul"
+  )");
+}
+
 TEST_F(MatmulTest, ConsecutiveBinaryAdd) {
   const char* matmul_module_str = R"(
   HloModule matmul.test.f32
-    region_0.22 {
-    Arg_0.23 = f32[] parameter(0)
-    Arg_1.24 = f32[] parameter(1)
-  ROOT add.25 = f32[] add(Arg_0.23, Arg_1.24)
-  }
-
-  region_1.29 {
-    Arg_0.30 = f32[] parameter(0)
-    Arg_1.31 = f32[] parameter(1)
-    ROOT add.32 = f32[] add(Arg_0.30, Arg_1.31)
-  }
-
-  ENTRY main {
-    constant.2 = f32[] constant(1e-06)
-    broadcast.3 = f32[1000000] broadcast(constant.2), dimensions={}
-    constant.7 = f32[] constant(1)
-    broadcast.8 = f32[1000000,3] broadcast(constant.7), dimensions={}
-    Arg_0.1 = f32[3] parameter(0)
-    reshape.10 = f32[1,3] reshape(Arg_0.1)
-    broadcast.11 = f32[1,3] broadcast(reshape.10), dimensions={0,1}
-    reshape.12 = f32[3] reshape(broadcast.11)
-    broadcast.13 = f32[1000000,3] broadcast(reshape.12), dimensions={1}
-    subtract.14 = f32[1000000,3] subtract(broadcast.8, broadcast.13)
-    constant.4 = f32[] constant(0)
-    broadcast.5 = f32[3,3] broadcast(constant.4), dimensions={}
-    dot.15 = f32[1000000,3] dot(subtract.14, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-    dot.16 = f32[1000000,3] dot(broadcast.3, dot.15), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={}
-    dot.17 = f32[1000000,3] dot(broadcast.3, subtract.14), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={}
-    dot.18 = f32[1000000,3] dot(dot.17, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-    add.19 = f32[1000000,3] add(dot.16, dot.18)
-    constant.9 = f32[3] constant({1, 2, 3})
-    dot.20 = f32[1000000,3] dot(broadcast.3, constant.9), lhs_contracting_dims={}, rhs_contracting_dims={}
-    add.21 = f32[1000000,3] add(add.19, dot.20)
-    constant.6 = f32[] constant(0)
-    reduce.26 = f32[3] reduce(add.21, constant.6), dimensions={0}, to_apply=region_0.22
-    reshape.27 = f32[1,3] reshape(reduce.26)
-    negate.28 = f32[1,3] negate(reshape.27)
-    ROOT reduce.33 = f32[3] reduce(negate.28, constant.6), dimensions={0}, to_apply=region_1.29
+  ENTRY matmul.test.f32 {
+    arg0.1 = f32[128,32,4,4] parameter(0)
+    arg0.2 = f32[128,32,4,4] parameter(1)
+    dot.7 = f32[128,32,4,4] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+    const.0 = f32[128,32] constant({...})
+    bcast.1 = f32[128,32,4,4] broadcast(const.0), dimensions={0,1}
+    add.0 = f32[128,32,4,4] add(dot.7,bcast.1)
+    const.1 = f32[4] constant({1,2,3,4})
+    bcast.2 = f32[128,32,4,4] broadcast(const.1), dimensions={3}
+    add.1 = f32[128,32,4,4] add(add.0, bcast.2)
+    tuple.12 = (f32[128,32,4,4]) tuple(add.1)
+    ROOT get-tuple-element.13 = f32[128,32,4,4] get-tuple-element(tuple.12), index=0
   })";
 
   EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4}));
diff --git a/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc b/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc
index 0d4317b..534b359 100644
--- a/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc
+++ b/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc
@@ -31,6 +31,7 @@
   switch (instruction->opcode()) {
     case HloOpcode::kBitcast:
     case HloOpcode::kBroadcast:
+    case HloOpcode::kCall:
     case HloOpcode::kConstant:
     case HloOpcode::kConcatenate:
     case HloOpcode::kConvert:
diff --git a/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc b/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc
index 7bb40d7..ec4d07e 100644
--- a/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc
+++ b/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc
@@ -65,10 +65,15 @@
   const char* const hlo_string = R"(
   HloModule Module
 
-  ENTRY main {
+  bcast {
     p0 = u4[] parameter(0)
     ROOT out = u4[3, 3] broadcast(p0), dimensions={}
   }
+
+  ENTRY main {
+    p0 = u4[] parameter(0)
+    ROOT out = u4[3, 3] call(p0), to_apply=bcast
+  }
   )";
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           ParseAndReturnUnverifiedModule(hlo_string));
diff --git a/third_party/xla/xla/service/dump.cc b/third_party/xla/xla/service/dump.cc
index 13b3032..3aa3a88 100644
--- a/third_party/xla/xla/service/dump.cc
+++ b/third_party/xla/xla/service/dump.cc
@@ -50,10 +50,10 @@
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/service/hlo_graph_dumper.h"
 #include "xla/service/hlo_proto_util.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 #include "xla/util.h"
 #include "tsl/lib/io/zlib_compression_options.h"
 #include "tsl/lib/io/zlib_outputbuffer.h"
-#include "tsl/lib/strings/proto_serialization.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/file_system.h"
@@ -64,6 +64,24 @@
 
 namespace xla {
 
+absl::Status CreateDirIfNeeded(const std::string& dir, tsl::Env* env) {
+  if (!env->IsDirectory(dir).ok()) {
+    absl::Status status = env->RecursivelyCreateDir(dir);
+    // Two threads can race to observe the absence of the dump directory and
+    // simultaneously try to create it, causing the "losing" thread to get a
+    // "directory already exists" error.  We can work around this by checking
+    // again whether the dir exists.
+    if (!status.ok()) {
+      status = env->IsDirectory(dir);
+      if (!status.ok()) {
+        LOG(ERROR) << "Could not create directory " << dir;
+        return status;
+      }
+    }
+  }
+  return absl::OkStatus();
+}
+
 std::string RenderGraph(absl::string_view label, const HloModule& module,
                         RenderedGraphFormat format,
                         bool show_fusion_subcomputations) {
@@ -299,17 +317,8 @@
   VLOG(1) << "Dumping " << filename << " to " << dir;
 
   tsl::Env* env = tsl::Env::Default();
-  // Two threads can race to observe the absence of the dump directory and
-  // simultaneously try to create it, causing the "losing" thread to get a
-  // "directory already exists" error.  We can work around this by checking
-  // again whether the dir exists.
-  if (!env->IsDirectory(dir).ok()) {
-    auto status = env->RecursivelyCreateDir(dir);
-    if (!status.ok() && !env->IsDirectory(dir).ok()) {
-      LOG(ERROR) << "Could not create directory " << dir
-                 << " for dumping XLA debug data: " << status;
-      return std::nullopt;
-    }
+  if (!CreateDirIfNeeded(dir, env).ok()) {
+    return std::nullopt;
   }
 
   // Make sure we are not going to dump more modules than the user has asked.
@@ -677,15 +686,7 @@
   if (dir.empty()) {
     return;
   }
-  if (!env->IsDirectory(dir).ok()) {
-    auto status = env->RecursivelyCreateDir(dir);
-    if (!status.ok()) {
-      LOG(ERROR) << "Could not create directory " << dir
-                 << " for dumping: " << status;
-      return;
-    }
-  }
-  if (!env->IsDirectory(dir).ok()) {
+  if (!CreateDirIfNeeded(dir, env).ok()) {
     return;
   }
   const std::string path = tsl::io::JoinPath(dir, filename);
@@ -884,4 +885,20 @@
   }
 }
 
+absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message,
+                                  const std::string& directory,
+                                  const std::string& file_name,
+                                  std::string* full_path) {
+  tsl::Env* env = tsl::Env::Default();
+  // Creates `directory` recursively if needed; tolerant of creation races.
+  TF_RETURN_IF_ERROR(CreateDirIfNeeded(directory, env));
+  std::string safe_file_name = SanitizeFileName(file_name) + ".pb";
+  std::string full_path_impl;
+  if (!full_path) {
+    full_path = &full_path_impl;
+  }
+  *full_path = tsl::io::JoinPath(directory, safe_file_name);
+  return tsl::WriteBinaryProto(env, *full_path, message);
+}
+
 }  // namespace xla
diff --git a/third_party/xla/xla/service/dump.h b/third_party/xla/xla/service/dump.h
index 0b1a6d2..623e729 100644
--- a/third_party/xla/xla/service/dump.h
+++ b/third_party/xla/xla/service/dump.h
@@ -43,6 +43,10 @@
 class BufferAssignment;
 class HloSnapshot;
 
+// Creates the directory if it doesn't exist (analogue of `mkdir -p`). Tolerates
+// creation races by re-checking that the directory exists on collision.
+absl::Status CreateDirIfNeeded(const std::string& dir, tsl::Env* env);
+
 // Get a timestamp which we can use as a filename prefix specific to this
 // module.
 std::string TimestampFor(const HloModule& module);
@@ -173,6 +177,18 @@
 // writing to two files, but you don't want to print twice.
 bool DumpingToStdout(const DebugOptions& opts);
 
+// Writes the given message in binary proto to the path formed by joining
+// 'directory/file_name.pb'. The 'directory' is recursively created if it
+// doesn't already exist, and the 'file_name' is sanitized by replacing
+// illegal characters with underscore '_'.
+//
+// If 'full_path' is not null then it is set to the name of the file the
+// protobuf was written to.
+absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message,
+                                  const std::string& directory,
+                                  const std::string& file_name,
+                                  std::string* full_path = nullptr);
+
 }  // namespace xla
 
 #endif  // XLA_SERVICE_DUMP_H_
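
A minimal usage sketch of the DumpProtoToDirectory helper declared above, assuming a caller that wants to dump an HloModule's proto: the wrapper name DumpModuleProtoForDebug and the choice of the module name as the file name are illustrative assumptions, while MakeHloProto is the existing helper from xla/service/hlo_proto_util.h.

#include <string>

#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/dump.h"
#include "xla/service/hlo_proto_util.h"

// Hypothetical wrapper, for illustration only: writes
// <directory>/<sanitized module name>.pb and reports where the file landed.
absl::Status DumpModuleProtoForDebug(const xla::HloModule& module,
                                     const std::string& directory) {
  std::string written_path;
  // DumpProtoToDirectory creates `directory` if needed (via CreateDirIfNeeded,
  // so concurrent callers racing to create it are tolerated) and sanitizes the
  // file name before appending ".pb".
  return xla::DumpProtoToDirectory(xla::MakeHloProto(module), directory,
                                   std::string(module.name()), &written_path);
}

The trailing output parameter may be omitted, matching the nullptr default in the declaration above.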
diff --git a/third_party/xla/xla/service/executable.cc b/third_party/xla/xla/service/executable.cc
index ed86114..aa81fce 100644
--- a/third_party/xla/xla/service/executable.cc
+++ b/third_party/xla/xla/service/executable.cc
@@ -25,7 +25,6 @@
 #include "xla/service/maybe_owning_device_memory.h"
 #include "xla/status_macros.h"
 #include "xla/stream_executor/device_description.h"
-#include "tsl/lib/strings/proto_serialization.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index 7c04293..080c7e8 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -36,7 +36,6 @@
 load(
     "//xla/tsl:tsl.bzl",
     "if_google",
-    "if_oss",
     "internal_visibility",
     "tsl_copts",
     "tsl_gpu_library",
@@ -314,7 +313,6 @@
         ":execution_stream_assignment",
         ":gpu_asm_opts_util",
         ":gpu_conv_runner",
-        ":gpu_fused_mha_runner",
         ":gpu_norm_runner",
         ":hlo_fusion_analysis",
         ":ir_emission_utils",
@@ -357,9 +355,9 @@
         "//xla/service/gpu/runtime:conditional_thunk",
         "//xla/service/gpu/runtime:convolution_thunk",
         "//xla/service/gpu/runtime:copy_thunk",
+        "//xla/service/gpu/runtime:cudnn_thunk",
         "//xla/service/gpu/runtime:custom_call_thunk",
         "//xla/service/gpu/runtime:fft_thunk",
-        "//xla/service/gpu/runtime:fused_mha_thunk",
         "//xla/service/gpu/runtime:gemm_thunk",
         "//xla/service/gpu/runtime:gpublas_lt_matmul_thunk",
         "//xla/service/gpu/runtime:infeed_thunk",
@@ -486,135 +484,6 @@
 )
 
 cc_library(
-    name = "gemm_fusion_autotuner",
-    srcs = if_cuda_is_configured(["gemm_fusion_autotuner.cc"]),
-    hdrs = if_cuda_is_configured(["gemm_fusion_autotuner.h"]),
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
-    deps = if_cuda_is_configured([
-        ":autotuner_compile_util",
-        ":autotuner_util",
-        ":backend_configs_cc",
-        ":buffer_comparator",
-        ":gemm_rewriter",
-        ":gpu_float_support",
-        ":gpu_fusible",
-        ":instruction_fusion",
-        ":ir_emission_utils",
-        ":matmul_utils",
-        ":split_k_gemm_rewriter",
-        ":stream_executor_util",
-        ":cudnn_fusion_compiler",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/synchronization",
-        "@com_google_absl//absl/time",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/types:span",
-        "@local_config_cuda//cuda:cuda_headers",
-        "//xla:autotuning_proto_cc",
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla/tools:hlo_decomposer_lib",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla:xla_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/pjrt/distributed:key_value_store_interface",
-        "//xla/service:algorithm_util",
-        "//xla/service:dump",
-        "//xla/service:executable",
-        "//xla/service:float_normalization",
-        "//xla/service:hlo_module_config",
-        "//xla/service:hlo_pass",
-        "//xla/service:shaped_buffer",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:device_memory",
-        "//xla/stream_executor",
-        "//xla/stream_executor/gpu:redzone_allocator",
-        "@local_tsl//tsl/lib/core:bits",
-        "@local_tsl//tsl/platform:blocking_counter",
-        "@local_tsl//tsl/platform:env",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:protobuf",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/profiler/lib:scoped_annotation",
-        "//xla/tsl/util/proto:proto_utils",
-        "//xla/service/gpu:hlo_traversal",
-        ":fusion_wrapper",
-        ":priority_fusion",
-        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
-        "//xla/stream_executor:stream_executor_memory_allocator",
-        "@local_tsl//tsl/platform:path",
-    ]),
-)
-
-xla_test(
-    name = "gemm_fusion_autotuner_test",
-    srcs = if_cuda_is_configured(["gemm_fusion_autotuner_test.cc"]),
-    backend_tags = {"gpu": [
-        "requires-gpu-sm80",
-    ]},
-    backends = [
-        "gpu",
-    ],
-    tags = [
-        "nomac",
-    ],
-    deps = [
-        ":autotuner_util",
-        ":backend_configs_cc",
-        ":gemm_fusion",
-        ":gemm_fusion_autotuner",
-        ":gemm_rewriter",
-        ":ir_emission_utils",
-        ":matmul_utils",
-        "//xla:autotuning_proto_cc",
-        "//xla:error_spec",
-        "//xla:xla_data_proto_cc",
-        "//xla:xla_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/pjrt/distributed:key_value_store_interface",
-        "//xla/service:call_inliner",
-        "//xla/service:dump",
-        "//xla/service:executable",
-        "//xla/service:hlo_module_config",
-        "//xla/service:hlo_pass_pipeline",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:device_description_proto_cc",
-        "//xla/tests:filecheck",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:test_utils",
-        "//xla/tests:verified_hlo_module",
-        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
-        "//xla/tools:hlo_decomposer_lib",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:env",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:path",
-        "@local_tsl//tsl/platform:platform_port",
-        "@local_tsl//tsl/platform:status_matchers",
-        "@local_tsl//tsl/platform:statusor",
-    ] + if_cuda_is_configured([
-        "@local_config_cuda//cuda:cuda_headers",
-    ]),
-)
-
-cc_library(
     name = "triton_call",
     srcs = if_gpu_is_configured(["triton_call.cc"]),
     hdrs = ["triton_call.h"],
@@ -789,6 +658,7 @@
         "//xla/service/llvm_ir:llvm_type_conversion_util",
         "//xla/service/llvm_ir:llvm_util",
         "//xla/stream_executor:device_description",
+        "//xla/tsl/lib/strings:proto_serialization",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
@@ -800,7 +670,6 @@
         "@llvm-project//llvm:Core",
         "@llvm-project//llvm:Support",
         "@llvm-project//llvm:TargetParser",
-        "@local_tsl//tsl/lib/strings:proto_serialization",
         "@local_tsl//tsl/platform:protobuf",
         "@local_tsl//tsl/platform:statusor",
     ],
@@ -817,11 +686,11 @@
         "//xla:literal_util",
         "//xla:shape_util",
         "//xla:types",
-        "//xla:util",
         "//xla/hlo/ir:backend_config",
         "//xla/service:buffer_assignment",
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:statusor",
@@ -833,6 +702,7 @@
     name = "reduction_utils",
     srcs = ["reduction_utils.cc"],
     hdrs = ["reduction_utils.h"],
+    compatible_with = get_compatible_with_portable(),
     local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
     deps = [
         ":ir_emission_utils",
@@ -914,45 +784,6 @@
 )
 
 cc_library(
-    name = "gemm_rewriter",
-    srcs = ["gemm_rewriter.cc"],
-    hdrs = ["gemm_rewriter.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        ":ir_emission_utils",
-        ":matmul_utils",
-        "//xla:literal",
-        "//xla:literal_util",
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla:types",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/evaluator:hlo_evaluator",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:algorithm_util",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "//xla/service:pattern_matcher",
-        "//xla/stream_executor:blas",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor/gpu:gpu_blas_lt",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:ml_dtypes",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/protobuf:dnn_proto_cc",
-    ],
-)
-
-cc_library(
     name = "triton_tiling_propagation",
     srcs = ["triton_tiling_propagation.cc"],
     hdrs = ["triton_tiling_propagation.h"],
@@ -1017,9 +848,9 @@
     name = "triton_fusion_analysis_test",
     srcs = ["triton_fusion_analysis_test.cc"],
     deps = [
-        ":gemm_fusion",
         ":triton_fusion_analysis",
         "//xla/hlo/ir:hlo",
+        "//xla/service/gpu/transforms:gemm_fusion",
         "//xla/stream_executor:device_description",
         "//xla/tests:hlo_test_base",
         "//xla/tests:verified_hlo_module",
@@ -1031,100 +862,6 @@
 )
 
 cc_library(
-    name = "gemm_fusion",
-    srcs = ["gemm_fusion.cc"],
-    hdrs = ["gemm_fusion.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_padding_requirements",
-        ":ir_emission_utils",
-        ":matmul_utils",
-        ":triton_fusion_analysis",
-        ":triton_tiling_propagation",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/service:instruction_fusion",
-        "//xla/service/gpu/fusions/triton:triton_support",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:tensor_float_32_utils",
-    ],
-)
-
-xla_cc_test(
-    name = "gemm_fusion_test",
-    srcs = ["gemm_fusion_test.cc"],
-    deps = [
-        ":cublas_padding_requirements",
-        ":gemm_fusion",
-        ":triton_fusion_analysis",
-        "//xla:autotuning_proto_cc",
-        "//xla:xla_data_proto_cc",
-        "//xla:xla_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:filecheck",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:verified_hlo_module",
-        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:status_matchers",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "gemv_rewriter",
-    srcs = ["gemv_rewriter.cc"],
-    hdrs = ["gemv_rewriter.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "gemv_rewriter_test",
-    srcs = ["gemv_rewriter_test.cc"],
-    deps = [
-        ":gemv_rewriter",
-        "//xla/hlo/ir:hlo",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
     name = "split_k_gemm_rewriter",
     srcs = ["split_k_gemm_rewriter.cc"],
     hdrs = ["split_k_gemm_rewriter.h"],
@@ -1185,216 +922,6 @@
 )
 
 cc_library(
-    name = "softmax_rewriter_triton",
-    srcs = ["softmax_rewriter_triton.cc"],
-    hdrs = ["softmax_rewriter_triton.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":hlo_traversal",
-        ":ir_emission_utils",
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:hlo_cost_analysis",
-        "//xla/service:hlo_pass",
-        "//xla/service:instruction_fusion",
-        "//xla/service/gpu/fusions/triton:triton_support",
-        "//xla/service/gpu/model:fusion_analysis_cache",
-        "//xla/service/gpu/model:gpu_indexing_performance_model",
-        "//xla/service/gpu/model:symbolic_tile_analysis",
-        "//xla/service/gpu/model:tiled_hlo_computation",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@llvm-project//mlir:IR",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "gemm_algorithm_picker",
-    srcs = if_gpu_is_configured(["gemm_algorithm_picker.cc"]),
-    hdrs = if_gpu_is_configured(["gemm_algorithm_picker.h"]),
-    deps = if_gpu_is_configured([
-        ":backend_configs_cc",
-        ":buffer_comparator",
-        ":cublas_cudnn",
-        ":gpu_asm_opts_util",
-        ":gpu_conv_runner",
-        ":ir_emission_utils",
-        ":matmul_utils",
-        ":stream_executor_util",
-        ":variant_visitor",
-        ":autotuner_compile_util",
-        ":autotuner_util",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/synchronization",
-        "@com_google_absl//absl/types:span",
-        "//xla:autotune_results_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_module_config",
-        "//xla/service:hlo_pass",
-        "//xla:status_macros",
-        "//xla/stream_executor",
-        "//xla/stream_executor:blas",
-        "//xla/stream_executor/gpu:gpu_blas_lt",
-        "//xla/stream_executor:device_memory",
-        "//xla/stream_executor:device_memory_allocator",
-        "//xla/stream_executor/gpu:redzone_allocator",
-        "//xla/tsl/util/proto:proto_utils",
-        "//xla:util",
-        "//xla:autotuning_proto_cc",
-        "//xla:shape_util",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/profiler/lib:scoped_annotation",
-    ]) + ["@com_google_absl//absl/status"],
-)
-
-cc_library(
-    name = "autotuner_util",
-    srcs = if_gpu_is_configured(["autotuner_util.cc"]),
-    hdrs = if_gpu_is_configured(["autotuner_util.h"]),
-    deps = if_gpu_is_configured([
-        ":gpu_asm_opts_util",
-        ":stream_executor_util",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/base",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/hash",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/synchronization",
-        "@llvm-project//llvm:Core",
-        "@llvm-project//llvm:Support",
-        "//xla:autotune_results_proto_cc",
-        "//xla:autotuning_proto_cc",
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla:types",
-        "//xla:util",
-        "//xla:xla_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:compilation_environments",
-        "//xla/stream_executor:stream_executor_memory_allocator",
-        "//xla/stream_executor",
-        "//xla/stream_executor/gpu:redzone_allocator",
-        "@local_tsl//tsl/platform:base64",
-        "@local_tsl//tsl/platform:env",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:path",
-        "@local_tsl//tsl/platform:protobuf",
-        "@local_tsl//tsl/platform:statusor",
-    ]),
-)
-
-# We need a separate target, as runtime executable cannot depend on compilation
-# pipeline.
-cc_library(
-    name = "autotuner_compile_util",
-    srcs = if_gpu_is_configured(["autotuner_compile_util.cc"]),
-    hdrs = if_gpu_is_configured(["autotuner_compile_util.h"]),
-    deps = if_gpu_is_configured([
-        ":autotuner_util",
-        ":gpu_executable_run_options",
-        ":ir_emission_utils",
-        "@com_google_absl//absl/functional:any_invocable",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/time",
-        "@com_google_absl//absl/types:span",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:compiler",
-        "//xla/service:executable",
-        "//xla/service:hlo_module_config",
-        "//xla/service:maybe_owning_device_memory",
-        "//xla/service:shaped_buffer",
-        "//xla/stream_executor",
-        "//xla/stream_executor/gpu:gpu_stream_header",
-        "//xla/stream_executor/gpu:redzone_allocator",
-        "//xla:executable_run_options",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_proto_cc",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ]) + ["@com_google_absl//absl/status"],
-)
-
-xla_test(
-    name = "autotuner_compile_util_test",
-    srcs = if_gpu_is_configured(["autotuner_compile_util_test.cc"]),
-    backends = ["gpu"],
-    deps = if_gpu_is_configured(
-        [
-            ":autotuner_compile_util",
-            ":autotuner_util",
-            "@com_google_googletest//:gtest_main",
-            "@com_google_absl//absl/strings",
-            "@com_google_absl//absl/strings:string_view",
-            "//xla/hlo/ir:hlo",
-            "//xla/service:platform_util",
-            "//xla/stream_executor:platform",
-            "//xla/tests:hlo_test_base",
-            "@local_tsl//tsl/platform:statusor",
-        ],
-        if_false = [
-            "@com_google_googletest//:gtest_main",  # b/317293391
-        ],
-    ),
-)
-
-xla_test(
-    name = "gemm_algorithm_picker_test",
-    srcs = if_gpu_is_configured(["gemm_algorithm_picker_test.cc"]),
-    backends = [
-        "gpu_v100",
-        "gpu_amd_any",
-    ],
-    deps = [
-        ":autotuner_util",
-        ":backend_configs_cc",
-        ":gemm_algorithm_picker",
-        ":gemm_rewriter",
-        ":variant_visitor",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/service:platform_util",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:platform",
-        "//xla/tests:hlo_test_base",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test",
-        "@local_tsl//tsl/platform:test_main",
-        "@local_tsl//tsl/protobuf:dnn_proto_cc",
-    ],
-)
-
-cc_library(
     name = "matmul_utils",
     srcs = ["matmul_utils.cc"],
     hdrs = ["matmul_utils.h"],
@@ -1465,239 +992,6 @@
 )
 
 cc_library(
-    name = "dot_dimension_sorter",
-    srcs = ["dot_dimension_sorter.cc"],
-    hdrs = ["dot_dimension_sorter.h"],
-    deps = [
-        "//xla:permutation_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-    ],
-)
-
-xla_test(
-    name = "dot_dimension_sorter_test",
-    srcs = ["dot_dimension_sorter_test.cc"],
-    backends = ["gpu"],
-    deps = [
-        ":dot_dimension_sorter",
-        "//xla:error_spec",
-        "//xla/hlo/ir:hlo",
-        "//xla/service/gpu/tests:gpu_codegen_test",
-        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "dot_sparsity_rewriter",
-    srcs = ["dot_sparsity_rewriter.cc"],
-    hdrs = ["dot_sparsity_rewriter.h"],
-    deps = [
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "dot_sparsity_rewriter_test",
-    srcs = ["dot_sparsity_rewriter_test.cc"],
-    deps = [
-        ":dot_sparsity_rewriter",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "gpu_async_collective_annotator",
-    srcs = ["gpu_async_collective_annotator.cc"],
-    hdrs = ["gpu_async_collective_annotator.h"],
-    deps = [
-        ":backend_configs_cc",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "gpu_async_collective_annotator_test",
-    srcs = ["gpu_async_collective_annotator_test.cc"],
-    deps = [
-        ":backend_configs_cc",
-        ":gpu_async_collective_annotator",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:test_macros_header",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/strings",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "gpu_convert_async_collectives_to_sync",
-    srcs = ["gpu_convert_async_collectives_to_sync.cc"],
-    hdrs = ["gpu_convert_async_collectives_to_sync.h"],
-    deps = [
-        ":backend_configs_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:convert_async_collectives_to_sync",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "gpu_convert_async_collectives_to_sync_test",
-    srcs = ["gpu_convert_async_collectives_to_sync_test.cc"],
-    deps = [
-        ":backend_configs_cc",
-        ":gpu_convert_async_collectives_to_sync",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "conv_algorithm_picker",
-    srcs = if_gpu_is_configured(["conv_algorithm_picker.cc"]),
-    hdrs = if_gpu_is_configured(["conv_algorithm_picker.h"]),
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
-        "TENSORFLOW_USE_ROCM=1",
-    ]),
-    deps = if_gpu_is_configured([
-        ":autotuner_compile_util",
-        ":autotuner_util",
-        ":backend_configs_cc",
-        ":buffer_comparator",
-        ":cublas_cudnn",
-        ":gpu_asm_opts_util",
-        ":gpu_autotuning_proto_cc",
-        ":gpu_conv_runner",
-        ":hlo_algorithm_denylist",
-        ":stream_executor_util",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/synchronization",
-        "@com_google_absl//absl/time",
-        "@com_google_absl//absl/types:span",
-        "@local_config_cuda//cuda:cudnn_header",
-        "//xla:autotune_results_proto_cc",
-        "//xla:autotuning_proto_cc",
-        "//xla:debug_options_flags",
-        "//xla:literal_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:executable",
-        "//xla/service:hlo_module_config",
-        "//xla/service:hlo_pass",
-        "//xla/service:slow_operation_alarm",
-        "//xla/stream_executor",
-        "//xla/stream_executor:dnn",
-        "//xla/stream_executor:numeric_options",
-        "//xla/stream_executor:scratch_allocator",
-        "//xla/stream_executor:device_memory_allocator",
-        "//xla/stream_executor:lazy_op_runner",
-        "//xla/stream_executor/cuda:cuda_platform_id",
-        "//xla/stream_executor/gpu:redzone_allocator",
-        "//xla/stream_executor/rocm:rocm_platform_id",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:numbers",
-        "//xla/tsl/util:env_var",
-        "@local_tsl//tsl/platform:statusor",
-        "//xla/tsl/util/proto:proto_utils",
-        "@local_tsl//tsl/platform:status",
-    ]),
-)
-
-xla_test(
-    name = "conv_algorithm_picker_test",
-    srcs = if_gpu_is_configured(["conv_algorithm_picker_test.cc"]),
-    backends = [
-        "gpu_v100",
-        "gpu_amd_any",
-    ],
-    tags = [
-        "noasan",
-        "nomsan",
-    ],
-    deps = [
-        ":autotuner_util",
-        ":conv_algorithm_picker",
-        ":gpu_conv_rewriter",
-        ":stream_executor_util",
-        "//xla:debug_options_flags",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/service:platform_util",
-        "//xla/service:tuple_simplifier",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:platform",
-        "//xla/tests:hlo_test_base",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
-cc_library(
     name = "gpu_conv_runner",
     srcs = ["gpu_conv_runner.cc"],
     hdrs = ["gpu_conv_runner.h"],
@@ -1751,167 +1045,6 @@
 )
 
 cc_library(
-    name = "gpu_fused_mha_runner",
-    srcs = ["gpu_fused_mha_runner.cc"],
-    hdrs = ["gpu_fused_mha_runner.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        ":stream_executor_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/stream_executor",
-        "//xla/stream_executor:dnn",
-        "//xla/stream_executor:lazy_op_runner",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@eigen_archive//:eigen3",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "gpu_conv_rewriter",
-    srcs = ["gpu_conv_rewriter.cc"],
-    hdrs = ["gpu_conv_rewriter.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        "//xla:permutation_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:window_util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "gpu_sort_rewriter",
-    srcs = if_gpu_is_configured(
-        ["gpu_sort_rewriter.cc"],
-        ["gpu_sort_rewriter_stub.cc"],
-    ),
-    hdrs = ["gpu_sort_rewriter.h"],
-    deps = [
-        ":cublas_cudnn",
-        "//xla:comparison_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/service:stable_sort_expander",
-        "//xla/service/gpu/runtime:cub_sort_thunk",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "move_copy_to_users",
-    srcs = ["move_copy_to_users.cc"],
-    hdrs = ["move_copy_to_users.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "move_copy_to_users_test",
-    srcs = ["move_copy_to_users_test.cc"],
-    deps = [
-        ":move_copy_to_users",
-        "//xla/service:layout_assignment",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:test",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
-xla_cc_test(
-    name = "gpu_conv_rewriter_test",
-    srcs = ["gpu_conv_rewriter_test.cc"],
-    deps = [
-        ":cublas_cudnn",
-        ":gpu_conv_rewriter",
-        "//xla:array4d",
-        "//xla:literal_util",
-        "//xla:protobuf_util",
-        "//xla:shape_util",
-        "//xla:test",
-        "//xla:test_helpers",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/service:shape_inference",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/strings:str_format",
-        "@local_tsl//tsl/platform:status_matchers",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test",
-    ],
-)
-
-xla_test(
-    name = "gpu_sort_rewriter_test",
-    srcs = if_cuda_is_configured(["gpu_sort_rewriter_test.cc"]),
-    backends = ["gpu"],
-    tags = ["no_oss"],
-    deps = [
-        ":cublas_cudnn",
-        ":gpu_sort_rewriter",
-        "//xla:error_spec",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test",
-    ],
-)
-
-cc_library(
     name = "cusolver_context",
     srcs = if_gpu_is_configured(["cusolver_context.cc"]),
     hdrs = if_gpu_is_configured(["cusolver_context.h"]),
@@ -1942,83 +1075,6 @@
     ]),
 )
 
-cc_library(
-    name = "cusolver_rewriter",
-    srcs = if_gpu_is_configured(["cusolver_rewriter.cc"]),
-    hdrs = if_gpu_is_configured(["cusolver_rewriter.h"]),
-    deps = if_gpu_is_configured([
-        ":cusolver_context",
-        ":ir_emission_utils",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "//xla:comparison_util",
-        "//xla:literal",
-        "//xla:literal_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/stream_executor",
-        "//xla/stream_executor:blas",
-        "//xla/stream_executor:device_memory_allocator",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-    ]),
-)
-
-cc_library(
-    name = "instruction_fusion",
-    srcs = ["instruction_fusion.cc"],
-    hdrs = ["instruction_fusion.h"],
-    deps = [
-        ":gpu_fusible",
-        "//xla:shape_util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:fusion_node_indexing_evaluation",
-        "//xla/service:fusion_queue",
-        "//xla/service:hlo_pass",
-        "//xla/service:instruction_fusion",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/meta:type_traits",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-    ],
-)
-
-xla_cc_test(
-    name = "instruction_fusion_test",
-    srcs = ["instruction_fusion_test.cc"],
-    tags = [
-        "nomsan",
-        "not_run:arm",
-    ],
-    deps = [
-        ":gpu_device_info_for_tests",
-        ":gpu_fusible",
-        ":instruction_fusion",
-        "//xla:literal_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:test_utils",
-        "//xla/tests:verified_hlo_module",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
 tf_proto_library(
     name = "fusion_process_dump_proto",
     srcs = ["fusion_process_dump.proto"],
@@ -2074,311 +1130,6 @@
 )
 
 cc_library(
-    name = "priority_fusion",
-    srcs = ["priority_fusion.cc"],
-    hdrs = ["priority_fusion.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":fusion_process_dump_proto_cc",
-        ":gpu_fusible",
-        ":hlo_fusion_analysis",
-        ":hlo_traversal",
-        "//xla:debug_options_flags",
-        "//xla:shape_util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:dump",
-        "//xla/service:fusion_queue",
-        "//xla/service:hlo_cost_analysis",
-        "//xla/service:hlo_graph_dumper",
-        "//xla/service:hlo_pass",
-        "//xla/service:instruction_fusion",
-        "//xla/service/gpu/model:fusion_analysis_cache",
-        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
-        "//xla/service/gpu/model:gpu_performance_model",
-        "//xla/service/gpu/model:gpu_performance_model_base",
-        "//xla/service/gpu/model:symbolic_tile_analysis",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/meta:type_traits",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/synchronization",
-        "@com_google_absl//absl/time",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:IR",
-        "@local_tsl//tsl/platform:blocking_counter",
-        "@local_tsl//tsl/platform:env",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:status",
-    ],
-)
-
-xla_cc_test(
-    name = "priority_fusion_test",
-    srcs = ["priority_fusion_test.cc"],
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
-    tags = ["no_pip"],
-    deps = [
-        ":backend_configs_cc",
-        ":gpu_device_info_for_tests",
-        ":gpu_fusible",
-        ":hlo_fusion_analysis",
-        ":priority_fusion",
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_cost_analysis",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:verified_hlo_module",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:status_matchers",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "multi_output_fusion",
-    srcs = ["multi_output_fusion.cc"],
-    hdrs = ["multi_output_fusion.h"],
-    deps = [
-        ":gpu_fusible",
-        "//xla:debug_options_flags",
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/ir:hlo_dfs_reachability",
-        "//xla/service:hlo_cost_analysis",
-        "//xla/service:hlo_graph_dumper",
-        "//xla/service:hlo_pass",
-        "//xla/service:instruction_fusion",
-        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
-        "//xla/service/gpu/model:gpu_performance_model",
-        "//xla/service/gpu/model:gpu_performance_model_base",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "multi_output_fusion_test",
-    srcs = ["multi_output_fusion_test.cc"],
-    tags = [
-        "nomsan",
-    ],
-    deps = [
-        ":gpu_device_info_for_tests",
-        ":gpu_fusible",
-        ":multi_output_fusion",
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_cost_analysis",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/strings",
-    ],
-)
-
-cc_library(
-    name = "rename_fusions",
-    srcs = ["rename_fusions.cc"],
-    hdrs = ["rename_fusions.h"],
-    deps = [
-        ":hlo_traversal",
-        ":ir_emission_utils",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:btree",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-    ],
-)
-
-xla_cc_test(
-    name = "rename_fusions_test",
-    srcs = ["rename_fusions_test.cc"],
-    deps = [
-        ":rename_fusions",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest",
-    ],
-)
-
-xla_cc_test(
-    name = "softmax_rewriter_triton_test",
-    srcs = ["softmax_rewriter_triton_test.cc"],
-    deps = [
-        ":backend_configs_cc",
-        ":gpu_device_info_for_tests",
-        ":softmax_rewriter_triton",
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:instruction_fusion",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/service/gpu/fusions/triton:triton_support",
-        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # build_cleaner: keep
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:status_matchers",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "gpu_sanitize_constant_names",
-    srcs = ["gpu_sanitize_constant_names.cc"],
-    hdrs = ["gpu_sanitize_constant_names.h"],
-    deps = [
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/service:name_uniquer",
-        "//xla/service/llvm_ir:buffer_assignment_util",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:logging",
-    ],
-)
-
-xla_cc_test(
-    name = "gpu_sanitize_constant_names_test",
-    srcs = ["gpu_sanitize_constant_names_test.cc"],
-    deps = [
-        ":gpu_sanitize_constant_names",
-        "//xla:literal_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test",
-    ],
-)
-
-cc_library(
-    name = "fusion_merger",
-    srcs = ["fusion_merger.cc"],
-    hdrs = ["fusion_merger.h"],
-    deps = [
-        ":gpu_fusible",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_cost_analysis",
-        "//xla/service:hlo_graph_dumper",
-        "//xla/service:hlo_pass",
-        "//xla/service:instruction_fusion",
-        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
-        "//xla/service/gpu/model:gpu_performance_model",
-        "//xla/service/gpu/model:gpu_performance_model_base",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:status",
-    ],
-)
-
-xla_cc_test(
-    name = "fusion_merger_test",
-    srcs = ["fusion_merger_test.cc"],
-    tags = [
-        "nomsan",
-    ],
-    deps = [
-        ":fusion_merger",
-        ":gpu_device_info_for_tests",
-        ":gpu_fusible",
-        "//xla:shape_util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_cost_analysis",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/types:span",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
-cc_library(
-    name = "gpu_conv_padding_legalization",
-    srcs = ["gpu_conv_padding_legalization.cc"],
-    hdrs = ["gpu_conv_padding_legalization.h"],
-    deps = [
-        ":cublas_cudnn",
-        "//xla:literal_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:window_util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "//xla/service:shape_inference",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "gpu_conv_padding_legalization_test",
-    srcs = ["gpu_conv_padding_legalization_test.cc"],
-    deps = [
-        ":cublas_cudnn",
-        ":gpu_conv_padding_legalization",
-        "//xla:shape_util",
-        "//xla:test",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
-        "@local_tsl//tsl/platform:test",
-    ],
-)
-
-cc_library(
     name = "cudnn_support_utils",
     srcs = ["cudnn_support_utils.cc"],
     hdrs = ["cudnn_support_utils.h"],
@@ -2420,183 +1171,6 @@
 )
 
 cc_library(
-    name = "cudnn_pad_for_convolutions",
-    srcs = ["cudnn_pad_for_convolutions.cc"],
-    hdrs = ["cudnn_pad_for_convolutions.h"],
-    deps = [
-        ":cublas_cudnn",
-        ":cudnn_support_utils",
-        ":stream_executor_util",
-        "//xla:literal_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/stream_executor",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/functional:bind_front",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "cudnn_pad_for_convolutions_test",
-    srcs = ["cudnn_pad_for_convolutions_test.cc"],
-    deps = [
-        ":cublas_cudnn",
-        ":cudnn_pad_for_convolutions",
-        "//xla/service:hlo_parser",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # build_cleaner: keep
-        "@com_google_googletest//:gtest",
-    ],
-)
-
-cc_library(
-    name = "cudnn_vectorize_convolutions",
-    srcs = ["cudnn_vectorize_convolutions.cc"],
-    hdrs = ["cudnn_vectorize_convolutions.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        ":cudnn_support_utils",
-        ":stream_executor_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/client:xla_builder",
-        "//xla/client:xla_computation",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_module_config",
-        "//xla/service:hlo_pass",
-        "//xla/stream_executor",
-        "//xla/stream_executor:dnn",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "cudnn_vectorize_convolutions_test",
-    srcs = ["cudnn_vectorize_convolutions_test.cc"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        ":cudnn_vectorize_convolutions",
-        "//xla:util",
-        "//xla/service:call_inliner",
-        "//xla/service:hlo_parser",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:dnn",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # build_cleaner: keep
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "cudnn_simplify_padding",
-    srcs = ["cudnn_simplify_padding.cc"],
-    hdrs = ["cudnn_simplify_padding.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        "//xla:literal",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "//xla/service:pattern_matcher",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "cudnn_simplify_padding_test",
-    srcs = ["cudnn_simplify_padding_test.cc"],
-    deps = [
-        ":cudnn_pad_for_convolutions",
-        ":cudnn_simplify_padding",
-        ":cudnn_vectorize_convolutions",
-        "//xla:literal",
-        "//xla:util",
-        "//xla/service:algebraic_simplifier",
-        "//xla/service:call_inliner",
-        "//xla/service:hlo_pass",
-        "//xla/service:hlo_pass_pipeline",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/service:reshape_mover",
-        "//xla/service:tuple_simplifier",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:dnn",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # build_cleaner: keep
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_absl//absl/functional:function_ref",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "cublas_pad_for_gemms",
-    srcs = ["cublas_pad_for_gemms.cc"],
-    hdrs = ["cublas_pad_for_gemms.h"],
-    deps = [
-        ":gemm_fusion",
-        ":ir_emission_utils",
-        "//xla:literal_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/service/gpu/fusions/triton:triton_support",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
     name = "cublas_padding_requirements",
     srcs = ["cublas_padding_requirements.cc"],
     hdrs = ["cublas_padding_requirements.h"],
@@ -2609,92 +1183,6 @@
     ],
 )
 
-xla_cc_test(
-    name = "cublas_pad_for_gemms_test",
-    srcs = ["cublas_pad_for_gemms_test.cc"],
-    tags = [
-        "nomsan",
-    ],
-    deps = [
-        ":cublas_pad_for_gemms",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # build_cleaner: keep
-        "@com_google_googletest//:gtest",
-    ],
-)
-
-cc_library(
-    name = "cudnn_fusion_compiler",
-    srcs = if_cuda_is_configured(["cudnn_fusion_compiler.cc"]),
-    hdrs = if_cuda_is_configured(["cudnn_fusion_compiler.h"]),
-    deps = if_cuda_is_configured([
-        ":backend_configs_cc",
-        ":cudnn_support_utils",
-        ":ir_emission_utils",
-        ":kernel_reuse_cache",
-        ":matmul_utils",
-        ":triton_fusion_analysis",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_config_cuda//cuda:cudnn_header",
-        "//xla:shape_util",
-        "//xla:comparison_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:hlo_pass",
-        "//xla/stream_executor:dnn",
-        "//xla/stream_executor:stream_executor_h",
-        "//xla/service:dump",
-        "//xla/stream_executor/cuda:cudnn_frontend_helpers",
-        "//xla/stream_executor/cuda:cudnn_plugin",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ]),
-)
-
-cc_library(
-    name = "cudnn_workspace_rewriter",
-    srcs = if_cuda_is_configured(["cudnn_workspace_rewriter.cc"]),
-    hdrs = if_cuda_is_configured(["cudnn_workspace_rewriter.h"]),
-    deps = if_cuda_is_configured([
-        ":backend_configs_cc",
-        ":ir_emission_utils",
-        ":gpu_fused_mha_runner",
-        ":cublas_cudnn",
-        ":stream_executor_util",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_config_cuda//cuda:cudnn_header",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/stream_executor/cuda:cudnn_frontend_helpers",
-        "//xla/stream_executor/cuda:cudnn_plugin",
-        "//xla/stream_executor:dnn",
-        "//xla/stream_executor:stream_executor_h",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-        "//xla:status_macros",
-    ]),
-)
-
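The cudnn_fusion_compiler and cudnn_workspace_rewriter rules removed here wrap srcs, hdrs, and deps in if_cuda_is_configured(...), which expands to the given list only when CUDA is configured and to an empty list otherwise. A minimal sketch of that gating pattern, using a hypothetical example_cudnn_pass target in place of the real ones:

# Sketch only: "example_cudnn_pass" is a made-up target; the if_cuda_is_configured gating is the point.
cc_library(
    name = "example_cudnn_pass",
    # When CUDA is not configured, each of these collapses to [], so the
    # target still exists but builds no cuDNN-dependent code.
    srcs = if_cuda_is_configured(["example_cudnn_pass.cc"]),
    hdrs = if_cuda_is_configured(["example_cudnn_pass.h"]),
    deps = if_cuda_is_configured([
        "//xla/stream_executor/cuda:cudnn_plugin",
        "@local_config_cuda//cuda:cudnn_header",
    ]),
)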
 tf_proto_library(
     name = "executable_proto",
     srcs = ["executable.proto"],
@@ -2749,43 +1237,6 @@
 )
 
 cc_library(
-    name = "gpu_reduce_scatter_creator",
-    srcs = ["gpu_reduce_scatter_creator.cc"],
-    hdrs = ["gpu_reduce_scatter_creator.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:collective_opt_utils",
-        "//xla/service:hlo_module_config",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-    ],
-)
-
-cc_library(
-    name = "gpu_all_gather_optimizer",
-    srcs = ["gpu_all_gather_optimizer.cc"],
-    hdrs = ["gpu_all_gather_optimizer.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:collective_ops_utils",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-    ],
-)
-
-cc_library(
     name = "gpu_float_support",
     srcs = ["gpu_float_support.cc"],
     hdrs = ["gpu_float_support.h"],
@@ -2856,256 +1307,10 @@
 )
 
 cc_library(
-    name = "command_buffer_scheduling",
-    srcs = ["command_buffer_scheduling.cc"],
-    hdrs = ["command_buffer_scheduling.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        ":hlo_fusion_analysis",
-        ":hlo_traversal",
-        ":ir_emission_utils",
-        ":variant_visitor",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/ffi:ffi_api",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "command_buffer_scheduling_test",
-    srcs = ["command_buffer_scheduling_test.cc"],
-    deps = [
-        ":command_buffer_scheduling",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_parser",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:filecheck",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:verified_hlo_module",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "custom_kernel_fusion_autotuner",
-    srcs = if_gpu_is_configured(["custom_kernel_fusion_autotuner.cc"]),
-    hdrs = if_gpu_is_configured(["custom_kernel_fusion_autotuner.h"]),
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
-    deps = if_gpu_is_configured([
-        ":autotuner_compile_util",
-        ":autotuner_util",
-        ":backend_configs_cc",
-        ":buffer_comparator",
-        ":gemm_rewriter",
-        ":gpu_float_support",
-        ":gpu_fusible",
-        ":instruction_fusion",
-        ":ir_emission_utils",
-        ":matmul_utils",
-        ":split_k_gemm_rewriter",
-        "//xla/service/gpu/kernels:custom_kernel",
-        "//xla/service/gpu/kernels:custom_kernel_fusion",
-        ":stream_executor_util",
-        ":cudnn_fusion_compiler",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/synchronization",
-        "@com_google_absl//absl/time",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/types:span",
-        "@local_config_cuda//cuda:cuda_headers",
-        "//xla:autotuning_proto_cc",
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla/tools:hlo_decomposer_lib",
-        "//xla:statusor",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla:xla_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:algorithm_util",
-        "//xla/service:dump",
-        "//xla/service:executable",
-        "//xla/service:float_normalization",
-        "//xla/service:hlo_module_config",
-        "//xla/service:hlo_pass",
-        "//xla/service:shaped_buffer",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:device_memory",
-        "//xla/stream_executor",
-        "//xla/stream_executor/gpu:redzone_allocator",
-        "@local_tsl//tsl/lib/core:bits",
-        "@local_tsl//tsl/platform:blocking_counter",
-        "@local_tsl//tsl/platform:env",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:protobuf",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/profiler/lib:scoped_annotation",
-        "//xla/tsl/util/proto:proto_utils",
-        "//xla/service/gpu:hlo_traversal",
-    ]) + [
-        "//xla/stream_executor:stream_executor_memory_allocator",
-        "@com_google_absl//absl/status",
-        "@local_tsl//tsl/platform:path",
-    ],
-)
-
-xla_test(
-    name = "custom_kernel_fusion_autotuner_test",
-    srcs = if_cuda_is_configured(["custom_kernel_fusion_autotuner_test.cc"]),
-    backends = [
-        "gpu",
-    ],
-    deps = [
-        ":autotuner_util",
-        ":custom_kernel_fusion_autotuner",
-        "//xla:xla_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass_pipeline",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:path",
-        "@local_tsl//tsl/platform:test",
-    ],
-)
-
-cc_library(
-    name = "custom_kernel_fusion_rewriter",
-    srcs = ["custom_kernel_fusion_rewriter.cc"],
-    hdrs = ["custom_kernel_fusion_rewriter.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/service/gpu/kernels:custom_fusion_library",
-        "//xla/service/gpu/kernels:custom_kernel_fusion_pattern",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "custom_kernel_fusion_rewriter_test",
-    srcs = ["custom_kernel_fusion_rewriter_test.cc"],
-    deps = [
-        ":custom_kernel_fusion_rewriter",
-        ":gpu_device_info_for_tests",
-        "//xla/hlo/ir:hlo",
-        "//xla/service/gpu/kernels:custom_kernel_fusion_pattern",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "@local_tsl//tsl/platform:test",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
-cc_library(
-    name = "dynamic_slice_fusion_rewriter",
-    srcs = ["dynamic_slice_fusion_rewriter.cc"],
-    hdrs = ["dynamic_slice_fusion_rewriter.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        ":gpu_constants",
-        ":hlo_traversal",
-        ":ir_emission_utils",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/ffi:ffi_api",
-        "//xla/ffi/api:c_api",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:custom_call_target_registry",
-        "//xla/service:hlo_pass",
-        "//xla/service/gpu/kernels:custom_fusion_library",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "dynamic_slice_fusion_rewriter_test",
-    srcs = if_cuda_is_configured(["dynamic_slice_fusion_rewriter_test.cc"]),
-    deps = [
-        ":dynamic_slice_fusion_rewriter",
-        ":gpu_device_info_for_tests",
-        "//xla:shape_util",
-        "//xla/client:xla_builder",
-        "//xla/client/lib:constants",
-        "//xla/ffi",
-        "//xla/ffi:ffi_api",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:buffer_value",
-        "//xla/service:custom_call_target_registry",
-        "//xla/service:executable",
-        "//xla/service:hlo_memory_scheduler",
-        "//xla/service:hlo_module_config",
-        "//xla/stream_executor",
-        "//xla/stream_executor/gpu:gpu_types_header",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/status",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
-cc_library(
     name = "fusion_pipeline",
     srcs = ["fusion_pipeline.cc"],
     hdrs = ["fusion_pipeline.h"],
     deps = [
-        ":fusion_merger",
-        ":horizontal_input_fusion",
-        ":horizontal_loop_fusion",
-        ":instruction_fusion",
-        ":multi_output_fusion",
-        ":priority_fusion",
-        ":variadic_op_splitter",
         "//xla:xla_proto_cc",
         "//xla/service:cpu_gpu_shape_verifier",
         "//xla/service:hlo_cost_analysis",
@@ -3116,6 +1321,13 @@
         "//xla/service:hlo_verifier",
         "//xla/service:layout_assignment",
         "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+        "//xla/service/gpu/transforms:fusion_merger",
+        "//xla/service/gpu/transforms:horizontal_input_fusion",
+        "//xla/service/gpu/transforms:horizontal_loop_fusion",
+        "//xla/service/gpu/transforms:instruction_fusion",
+        "//xla/service/gpu/transforms:multi_output_fusion",
+        "//xla/service/gpu/transforms:priority_fusion",
+        "//xla/service/gpu/transforms:variadic_op_splitter",
         "//xla/stream_executor:device_description",
         "@local_tsl//tsl/platform:env",
     ],
@@ -3126,10 +1338,6 @@
     srcs = ["prepare_hlo_for_ir_emitting_pipeline.cc"],
     hdrs = ["prepare_hlo_for_ir_emitting_pipeline.h"],
     deps = [
-        ":alias_passthrough_params",
-        ":copy_fusion",
-        ":gpu_sanitize_constant_names",
-        ":horizontal_loop_fusion",
         "//xla:xla_proto_cc",
         "//xla/hlo/ir:hlo",
         "//xla/service:copy_insertion",
@@ -3140,6 +1348,10 @@
         "//xla/service:hlo_verifier",
         "//xla/service:layout_assignment",
         "//xla/service:loop_schedule_linearizer",
+        "//xla/service/gpu/transforms:alias_passthrough_params",
+        "//xla/service/gpu/transforms:copy_fusion",
+        "//xla/service/gpu/transforms:horizontal_loop_fusion",
+        "//xla/service/gpu/transforms:sanitize_constant_names",
     ],
 )
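The pipeline rules in the hunks above now depend on the relocated passes via //xla/service/gpu/transforms labels instead of same-package ":pass" labels. Any other BUILD rule that consumed the old labels follows the same rewrite; a minimal sketch with a hypothetical consumer target:

# "example_fusion_user" is hypothetical; only the new labels in deps come from this change.
cc_library(
    name = "example_fusion_user",
    srcs = ["example_fusion_user.cc"],
    deps = [
        # Previously ":fusion_merger" and ":priority_fusion", back when the
        # passes lived directly in //xla/service/gpu.
        "//xla/service/gpu/transforms:fusion_merger",
        "//xla/service/gpu/transforms:priority_fusion",
    ],
)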
 
@@ -3153,53 +1365,20 @@
     ]),
     deps = if_gpu_is_configured([
         # go/keep-sorted start prefix_order=":,,"
-        ":algorithm_checker",
-        ":alias_passthrough_params",
-        ":all_reduce_blueconnect",
-        ":autotuner_util",
         ":buffer_sharing",
-        ":collective_permute_cycle_decomposer",
-        ":collective_permute_valid_iteration_annotator",
-        ":command_buffer_scheduling",
         ":compile_module_to_llvm_ir",
         ":conv_layout_normalization",
-        ":copy_fusion",
-        ":custom_kernel_fusion_autotuner",
-        ":custom_kernel_fusion_rewriter",
-        ":dot_dimension_sorter",
-        ":dot_operand_converter",
-        ":double_buffer_loop_unrolling",
-        ":dynamic_slice_fusion_rewriter",
         ":executable_proto_cc",
         ":execution_stream_assignment",
-        ":fusion_merger",
         ":fusion_pipeline",
-        ":fusion_wrapper",
-        ":gemm_broadcast_folding_rewriter",
-        ":gemm_fusion",
-        ":gemm_rewriter",
-        ":gemv_rewriter",
-        ":gpu_algebraic_simplifier",
-        ":gpu_all_gather_optimizer",
-        ":gpu_async_collective_annotator",
         ":gpu_constants",
-        ":gpu_conv_rewriter",
-        ":gpu_convert_async_collectives_to_sync",
         ":gpu_executable",
         ":gpu_float_support",
         ":gpu_hlo_schedule",
         ":gpu_latency_hiding_scheduler",
-        ":gpu_layout_assignment",
         ":gpu_p2p_pipeliner",
-        ":gpu_reduce_scatter_creator",
-        ":gpu_sanitize_constant_names",
-        ":gpu_scatter_expander",
         ":gpu_spmd_pipeline",
-        ":gpu_windowed_einsum_handler",
         ":hlo_fusion_stats",
-        ":horizontal_input_fusion",
-        ":horizontal_loop_fusion",
-        ":instruction_fusion",
         ":ir_emission_utils",
         ":ir_emitter",
         ":ir_emitter_context",
@@ -3207,28 +1386,10 @@
         ":kernel_reuse_cache",
         ":matmul_utils",
         ":metrics",
-        ":move_copy_to_users",
-        ":multi_output_fusion",
-        ":pipelined_p2p_rewriter",
         ":prepare_hlo_for_ir_emitting_pipeline",
-        ":priority_fusion",
-        ":reduction_degenerate_dim_remover",
-        ":reduction_dimension_grouper",
-        ":reduction_layout_normalizer",
-        ":reduction_splitter",
         ":reduction_utils",
-        ":rename_fusions",
         ":runtime_intrinsics",
-        ":scatter_slice_simplifier",
-        ":softmax_rewriter_triton",
-        ":stream_attribute_annotator",
-        ":stream_attribute_async_wrapper",
         ":stream_executor_util",
-        ":topk_specializer",
-        ":topk_splitter",
-        ":tree_reduction_rewriter",
-        ":triton_fusion_numerics_verifier",
-        ":variadic_op_splitter",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
@@ -3252,9 +1413,53 @@
         "@llvm-project//mlir:Support",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/ir:hlo_module_group",
+        "//xla/service/gpu/autotuning:autotuner_util",
+        "//xla/service/gpu/autotuning:custom_kernel_fusion_autotuner",
         "//xla/service/gpu/model:gpu_cost_model_stats_collection",
         "//xla/service/gpu/model:gpu_hlo_cost_analysis",
         "//xla/service/gpu/runtime:thunk",
+        "//xla/service/gpu/transforms:algebraic_simplifier",
+        "//xla/service/gpu/transforms:algorithm_checker",
+        "//xla/service/gpu/transforms:all_gather_optimizer",
+        "//xla/service/gpu/transforms:all_reduce_blueconnect",
+        "//xla/service/gpu/transforms:all_reduce_splitter",
+        "//xla/service/gpu/transforms:async_collective_annotator",
+        "//xla/service/gpu/transforms:collective_permute_cycle_decomposer",
+        "//xla/service/gpu/transforms:collective_permute_valid_iteration_annotator",
+        "//xla/service/gpu/transforms:command_buffer_scheduling",
+        "//xla/service/gpu/transforms:conv_rewriter",
+        "//xla/service/gpu/transforms:convert_async_collectives_to_sync",
+        "//xla/service/gpu/transforms:cudnn_custom_call_converter",
+        "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter",
+        "//xla/service/gpu/transforms:dot_dimension_sorter",
+        "//xla/service/gpu/transforms:dot_operand_converter",
+        "//xla/service/gpu/transforms:double_buffer_loop_unrolling",
+        "//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
+        "//xla/service/gpu/transforms:fusion_wrapper",
+        "//xla/service/gpu/transforms:gemm_broadcast_folding_rewriter",
+        "//xla/service/gpu/transforms:gemm_fusion",
+        "//xla/service/gpu/transforms:gemm_rewriter",
+        "//xla/service/gpu/transforms:gemv_rewriter",
+        "//xla/service/gpu/transforms:layout_assignment",
+        "//xla/service/gpu/transforms:move_copy_to_users",
+        "//xla/service/gpu/transforms:pipelined_p2p_rewriter",
+        "//xla/service/gpu/transforms:reduce_scatter_creator",
+        "//xla/service/gpu/transforms:reduction_degenerate_dim_remover",
+        "//xla/service/gpu/transforms:reduction_dimension_grouper",
+        "//xla/service/gpu/transforms:reduction_layout_normalizer",
+        "//xla/service/gpu/transforms:reduction_splitter",
+        "//xla/service/gpu/transforms:rename_fusions",
+        "//xla/service/gpu/transforms:sanitize_constant_names",
+        "//xla/service/gpu/transforms:scatter_expander",
+        "//xla/service/gpu/transforms:scatter_slice_simplifier",
+        "//xla/service/gpu/transforms:softmax_rewriter_triton",
+        "//xla/service/gpu/transforms:stream_attribute_annotator",
+        "//xla/service/gpu/transforms:stream_attribute_async_wrapper",
+        "//xla/service/gpu/transforms:topk_specializer",
+        "//xla/service/gpu/transforms:topk_splitter",
+        "//xla/service/gpu/transforms:tree_reduction_rewriter",
+        "//xla/service/gpu/transforms:triton_fusion_numerics_verifier",
+        "//xla/service/gpu/transforms:windowed_einsum_handler",
         "//xla/service/llvm_ir:llvm_util",
         "//xla/service/spmd:collective_permute_motion",
         "//xla/service:algebraic_simplifier",
@@ -3265,7 +1470,6 @@
         "//xla/service:all_reduce_folder",
         "//xla/service:all_reduce_promotion",
         "//xla/service:all_reduce_reassociate",
-        "//xla/service:all_reduce_splitter",
         "//xla/service:async_collective_creator",
         "//xla/service:batchnorm_expander",
         "//xla/service:bitcast_dtypes_expander",
@@ -3395,11 +1599,10 @@
 
 xla_test(
     name = "gpu_compiler_test",
-    srcs = if_gpu_is_configured(["gpu_compiler_test.cc"]),
+    srcs = ["gpu_compiler_test.cc"],
     backends = ["gpu"],
     data = ["gpu_compiler_test_autotune_db.textproto"],
     deps = [
-        ":autotuner_util",
         ":gpu_compiler",
         ":gpu_hlo_schedule",
         ":metrics",
@@ -3413,6 +1616,7 @@
         "//xla/service:pattern_matcher",
         "//xla/service:pattern_matcher_gmock",
         "//xla/service:xla_debug_info_manager",
+        "//xla/service/gpu/autotuning:autotuner_util",
         "//xla/stream_executor:device_description",
         "//xla/tests:filecheck",
         "//xla/tests:hlo_test_base",
@@ -3449,7 +1653,7 @@
         "//xla/service:hlo_cost_analysis",
         "//xla/service:hlo_memory_scheduler",
         "//xla/service:hlo_rematerialization",
-        "//xla/service/gpu:stream_attribute_annotator",
+        "//xla/service/gpu/transforms:stream_attribute_annotator",
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",
         "//xla/tsl/lib/core:status_test_util",
@@ -3511,34 +1715,13 @@
         "no_rocm",
     ],
     deps = [
-        ":autotuner_util",
         ":buffer_sharing",
-        ":conv_algorithm_picker",
-        ":cublas_pad_for_gemms",
         ":cublas_padding_requirements",
-        ":cudnn_fused_conv_rewriter",
-        ":cudnn_fused_mha_rewriter",
-        ":cudnn_fused_mha_transpose_fusion",
-        ":cudnn_fusion_compiler",
-        ":cudnn_norm_rewriter",
-        ":cudnn_pad_for_convolutions",
-        ":cudnn_simplify_padding",
-        ":cudnn_vectorize_convolutions",
-        ":cudnn_workspace_rewriter",
-        ":cusolver_rewriter",
-        ":dot_sparsity_rewriter",
-        ":gemm_algorithm_picker",
-        ":gemm_fusion_autotuner",
-        ":gpu_algebraic_simplifier",
         ":gpu_asm_opts_util",
         ":gpu_compiler",
-        ":gpu_conv_padding_legalization",
-        ":gpu_conv_rewriter",
-        ":gpu_sort_rewriter",
         ":ir_emission_utils",
         ":metrics",
         ":target_constants",
-        ":triangular_solve_rewriter",
         "//xla:autotune_results_proto_cc",
         "//xla:util",
         "//xla:xla_proto_cc",
@@ -3561,13 +1744,36 @@
         "//xla/service:hlo_verifier",
         "//xla/service:reshape_mover",
         "//xla/service:tuple_simplifier",
+        "//xla/service/gpu/autotuning:autotuner_util",
+        "//xla/service/gpu/autotuning:conv_algorithm_picker",
+        "//xla/service/gpu/autotuning:gemm_algorithm_picker",
+        "//xla/service/gpu/autotuning:gemm_fusion_autotuner",
         "//xla/service/gpu/llvm_gpu_backend",
+        "//xla/service/gpu/transforms:algebraic_simplifier",
+        "//xla/service/gpu/transforms:conv_padding_legalization",
+        "//xla/service/gpu/transforms:conv_rewriter",
+        "//xla/service/gpu/transforms:cublas_pad_for_gemms",
+        "//xla/service/gpu/transforms:cudnn_custom_call_compiler",
+        "//xla/service/gpu/transforms:cudnn_fused_conv_rewriter",
+        "//xla/service/gpu/transforms:cudnn_fused_mha_rewriter",
+        "//xla/service/gpu/transforms:cudnn_fused_mha_transpose_fusion",
+        "//xla/service/gpu/transforms:cudnn_fusion_compiler",
+        "//xla/service/gpu/transforms:cudnn_norm_rewriter",
+        "//xla/service/gpu/transforms:cudnn_pad_for_convolutions",
+        "//xla/service/gpu/transforms:cudnn_simplify_padding",
+        "//xla/service/gpu/transforms:cudnn_vectorize_convolutions",
+        "//xla/service/gpu/transforms:dot_sparsity_rewriter",
+        "//xla/service/gpu/transforms:gpusolver_rewriter",
+        "//xla/service/gpu/transforms:sort_rewriter",
+        "//xla/service/gpu/transforms:triangular_solve_rewriter",
         "//xla/service/llvm_ir:llvm_util",
         "//xla/stream_executor",
         "//xla/stream_executor:dnn",
         "//xla/stream_executor/cuda:cuda_asm_compiler",
         "//xla/stream_executor/cuda:cuda_diagnostics",
         "//xla/stream_executor/cuda:cuda_platform_id",
+        "//xla/stream_executor/cuda:nvjitlink",
+        "//xla/stream_executor/cuda:nvjitlink_support",
         "//xla/stream_executor/cuda:ptx_compilation_method",
         "//xla/stream_executor/cuda:ptx_compiler",
         "//xla/stream_executor/cuda:ptx_compiler_support",
@@ -3615,7 +1821,6 @@
         "gpu_a100",
     ],
     tags = [
-        "gpu",
         "no_rocm",
         "nomsan",  # Pulls in precompiled NVIDIA libraries which cause false positives in msan.
     ],
@@ -3654,7 +1859,6 @@
         "gpu",
     ],
     tags = [
-        "gpu",
         "no_rocm",
         "nomsan",  # Pulls in precompiled NVIDIA libraries which cause false positives in msan.
     ],
@@ -3666,6 +1870,7 @@
         "//xla/service:executable",
         "//xla/service:hlo_module_config",
         "//xla/stream_executor:device_description",
+        "//xla/stream_executor/cuda:nvjitlink_support",
         "//xla/stream_executor/cuda:ptx_compilation_method",
         "//xla/stream_executor/cuda:ptx_compiler_support",
         "//xla/stream_executor/cuda:ptx_linking_method",
@@ -3745,45 +1950,6 @@
 )
 
 cc_library(
-    name = "gpu_algebraic_simplifier",
-    srcs = [
-        "gpu_algebraic_simplifier.cc",
-    ],
-    hdrs = [
-        "gpu_algebraic_simplifier.h",
-    ],
-    deps = [
-        ":matmul_utils",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:algebraic_simplifier",
-        "//xla/service:hlo_pass",
-        "//xla/service/gpu/fusions/triton:triton_support",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-    ],
-)
-
-xla_cc_test(
-    name = "gpu_algebraic_simplifier_test",
-    srcs = ["gpu_algebraic_simplifier_test.cc"],
-    deps = [
-        ":gpu_algebraic_simplifier",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:algebraic_simplifier",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
     name = "amdgpu_compiler_impl",
     srcs = [
         "amdgpu_compiler.cc",
@@ -3793,20 +1959,9 @@
     ],
     tags = ["manual"],
     deps = [
-        ":autotuner_util",
-        ":conv_algorithm_picker",
-        ":cublas_pad_for_gemms",
         ":cublas_padding_requirements",
-        ":cudnn_fused_conv_rewriter",
-        ":cusolver_rewriter",
-        ":gemm_algorithm_picker",
-        ":gpu_algebraic_simplifier",
         ":gpu_compiler",
-        ":gpu_conv_padding_legalization",
-        ":gpu_conv_rewriter",
-        ":gpu_sort_rewriter",
         ":target_constants",
-        ":triangular_solve_rewriter",
         "//xla:util",
         "//xla:xla_proto_cc",
         "//xla/hlo/ir:hlo",
@@ -3823,7 +1978,18 @@
         "//xla/service:hlo_verifier",
         "//xla/service:reshape_mover",
         "//xla/service:tuple_simplifier",
+        "//xla/service/gpu/autotuning:autotuner_util",
+        "//xla/service/gpu/autotuning:conv_algorithm_picker",
+        "//xla/service/gpu/autotuning:gemm_algorithm_picker",
         "//xla/service/gpu/llvm_gpu_backend",
+        "//xla/service/gpu/transforms:algebraic_simplifier",
+        "//xla/service/gpu/transforms:conv_padding_legalization",
+        "//xla/service/gpu/transforms:conv_rewriter",
+        "//xla/service/gpu/transforms:cublas_pad_for_gemms",
+        "//xla/service/gpu/transforms:cudnn_fused_conv_rewriter",
+        "//xla/service/gpu/transforms:gpusolver_rewriter",
+        "//xla/service/gpu/transforms:sort_rewriter",
+        "//xla/service/gpu/transforms:triangular_solve_rewriter",
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:device_memory_allocator",
         "//xla/stream_executor:dnn",
@@ -3843,53 +2009,6 @@
 )
 
 cc_library(
-    name = "all_reduce_blueconnect",
-    srcs = ["all_reduce_blueconnect.cc"],
-    hdrs = ["all_reduce_blueconnect.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:collective_ops_utils",
-        "//xla/service:computation_placer_hdr",
-        "//xla/service:global_device_id",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:btree",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "all_reduce_blueconnect_test",
-    srcs = ["all_reduce_blueconnect_test.cc"],
-    deps = [
-        ":all_reduce_blueconnect",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:computation_placer_hdr",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_absl//absl/types:span",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:status_matchers",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
-cc_library(
     name = "xfeed_queue",
     hdrs = ["xfeed_queue.h"],
     deps = [
@@ -3934,109 +2053,12 @@
 )
 
 cc_library(
-    name = "gpu_layout_assignment",
-    srcs = ["gpu_layout_assignment.cc"],
-    hdrs = ["gpu_layout_assignment.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        ":matmul_utils",
-        ":reduction_utils",
-        ":stream_executor_util",
-        "//xla:shape_layout",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:window_util",
-        "//xla:xla_data_proto_cc",
-        "//xla:xla_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:computation_layout",
-        "//xla/service:host_memory_offload_annotations_hdr",
-        "//xla/service:layout_assignment",
-        "//xla/service:logical_buffer",
-        "//xla/stream_executor",
-        "//xla/stream_executor:dnn",
-        "//xla/tsl/util:env_var",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "gpu_layout_assignment_test",
-    srcs = ["gpu_layout_assignment_test.cc"],
-    deps = [
-        ":gpu_layout_assignment",
-        ":stream_executor_util",
-        "//xla:shape_layout",
-        "//xla:shape_util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:computation_layout",
-        "//xla/service:hlo_parser",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:dnn",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # build_cleaner: keep
-        "@com_google_absl//absl/types:span",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:status_matchers",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "gpu_schedule_postprocessing",
-    srcs = ["gpu_schedule_postprocessing.cc"],
-    hdrs = ["gpu_schedule_postprocessing.h"],
-    deps = [
-        ":backend_configs_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "gpu_schedule_postprocessing_test",
-    srcs = ["gpu_schedule_postprocessing_test.cc"],
-    deps = [
-        ":backend_configs_cc",
-        ":gpu_schedule_postprocessing",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_parser",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
     name = "gpu_hlo_schedule",
     srcs = ["gpu_hlo_schedule.cc"],
     hdrs = ["gpu_hlo_schedule.h"],
     deps = [
         ":backend_configs_cc",
         ":gpu_latency_hiding_scheduler",
-        ":gpu_schedule_postprocessing",
-        ":scheduling_instruction_annotator",
         "//xla:shape_util",
         "//xla:util",
         "//xla/hlo/ir:hlo",
@@ -4049,6 +2071,8 @@
         "//xla/service:p2p_schedule_preparation",
         "//xla/service:profile_guided_latency_estimator",
         "//xla/service/gpu/model:analytical_latency_estimator",
+        "//xla/service/gpu/transforms:schedule_postprocessing",
+        "//xla/service/gpu/transforms:scheduling_instruction_annotator",
         "//xla/stream_executor:device_description",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -4085,6 +2109,7 @@
         "//xla/tests:filecheck",
         "//xla/tests:hlo_test_base",
         "//xla/tests:test_utils",
+        "//xla/tests:verified_hlo_module",
         "//xla/tests:xla_internal_test_main",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log",
@@ -4141,7 +2166,6 @@
     srcs = ["gpu_spmd_pipeline.cc"],
     hdrs = ["gpu_spmd_pipeline.h"],
     deps = [
-        ":gpu_algebraic_simplifier",
         ":runtime_intrinsics",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/transforms:hlo_constant_splitter",
@@ -4160,6 +2184,7 @@
         "//xla/service:tuple_simplifier",
         "//xla/service:while_loop_constant_sinking",
         "//xla/service:while_loop_simplifier",
+        "//xla/service/gpu/transforms:algebraic_simplifier",
         "//xla/service/spmd:collective_permute_motion",
         "//xla/service/spmd:stateful_rng_spmd_partitioner",
         "//xla/service/spmd/shardy:shardy_xla_pass",
@@ -4292,6 +2317,7 @@
     name = "hlo_fusion_analysis",
     srcs = ["hlo_fusion_analysis.cc"],
     hdrs = ["hlo_fusion_analysis.h"],
+    compatible_with = get_compatible_with_portable(),
     deps = [
         ":backend_configs_cc",
         ":hlo_traversal",
@@ -4319,6 +2345,8 @@
         ":gpu_device_info_for_tests",
         ":hlo_fusion_analysis",
         ":hlo_traversal",
+        ":ir_emission_utils",
+        "//xla:protobuf_util",
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:device_description_proto_cc",
         "//xla/tests:hlo_test_base",
@@ -4426,10 +2454,13 @@
     name = "gpu_fusible",
     srcs = ["gpu_fusible.cc"],
     hdrs = ["gpu_fusible.h"],
+    compatible_with = get_compatible_with_portable(),
     deps = [
         ":backend_configs_cc",
+        ":hlo_fusion_analysis",
         ":hlo_traversal",
         ":ir_emission_utils",
+        ":launch_dimensions",
         ":reduction_utils",
         "//xla:permutation_util",
         "//xla:shape_util",
@@ -4467,265 +2498,6 @@
     ],
 )
 
-cc_library(
-    name = "cudnn_fused_conv_rewriter",
-    srcs = ["cudnn_fused_conv_rewriter.cc"],
-    hdrs = ["cudnn_fused_conv_rewriter.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        "//xla:comparison_util",
-        "//xla:debug_options_flags",
-        "//xla:literal",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "//xla/service:pattern_matcher",
-        "//xla/stream_executor",
-        "//xla/stream_executor:dnn",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:ml_dtypes",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_test(
-    name = "cudnn_fused_conv_rewriter_test",
-    srcs = ["cudnn_fused_conv_rewriter_test.cc"],
-    backend_tags = {
-        "gpu_a100": [
-            "noasan",
-            "nomsan",
-            "no_rocm",
-        ],
-    },
-    backends = [
-        "gpu_a100",
-        "gpu_amd_any",
-    ] + if_oss(["gpu_any"]),
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
-        "TENSORFLOW_USE_ROCM=1",
-    ]),
-    shard_count = 10,
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        ":cudnn_fused_conv_rewriter",
-        ":gpu_conv_rewriter",
-        ":stream_executor_util",
-        "//xla:comparison_util",
-        "//xla:error_spec",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:algebraic_simplifier",
-        "//xla/service:convert_mover",
-        "//xla/service:hlo_constant_folding",
-        "//xla/service:hlo_module_config",
-        "//xla/service:hlo_pass",
-        "//xla/service:hlo_pass_pipeline",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/service:reshape_mover",
-        "//xla/service/gpu/tests:gpu_codegen_test",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:dnn",
-        "//xla/tests:filecheck",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:verified_hlo_module",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test_main",
-    ] + if_cuda_is_configured([
-        "@local_config_cuda//cuda:cuda_headers",
-        "@local_config_cuda//cuda:cudnn_header",
-    ]) + if_rocm_is_configured([
-        "@local_config_rocm//rocm:rocm_headers",
-    ]),
-)
-
-cc_library(
-    name = "cudnn_norm_rewriter",
-    srcs = ["cudnn_norm_rewriter.cc"],
-    hdrs = ["cudnn_norm_rewriter.h"],
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        "//xla:shape_util",
-        "//xla:types",
-        "//xla:util",
-        "//xla:window_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "//xla/service:pattern_matcher",
-        "//xla/stream_executor",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/protobuf:dnn_proto_cc",
-    ] + if_cuda_is_configured([
-        "@local_config_cuda//cuda:cuda_headers",
-        "@local_config_cuda//cuda:cudnn_header",
-    ]) + if_google([
-        "@com_google_protobuf//:wrappers_cc_proto",
-    ]),
-)
-
-xla_test(
-    name = "cudnn_norm_rewriter_test",
-    srcs = ["cudnn_norm_rewriter_test.cc"],
-    backends = ["gpu"],
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
-    deps = [
-        ":cublas_cudnn",
-        ":cudnn_norm_rewriter",
-        "//xla:error_spec",
-        "//xla/service/gpu/tests:gpu_codegen_test",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:filecheck",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_googletest//:gtest_main",
-    ] + if_cuda_is_configured([
-        "@local_config_cuda//cuda:cuda_headers",
-        "@local_config_cuda//cuda:cudnn_header",
-    ]),
-)
-
-cc_library(
-    name = "cudnn_fused_mha_rewriter",
-    srcs = ["cudnn_fused_mha_rewriter.cc"],
-    hdrs = ["cudnn_fused_mha_rewriter.h"],
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        ":matmul_utils",
-        ":stream_executor_util",
-        "//xla:permutation_util",
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla:types",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla:xla_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/service:pattern_matcher",
-        "//xla/stream_executor",
-        "//xla/stream_executor:dnn",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ] + if_cuda_is_configured([
-        "@local_config_cuda//cuda:cuda_headers",
-    ]),
-)
-
-cc_library(
-    name = "cudnn_fused_mha_transpose_fusion",
-    srcs = ["cudnn_fused_mha_transpose_fusion.cc"],
-    hdrs = ["cudnn_fused_mha_transpose_fusion.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        ":matmul_utils",
-        "//xla:permutation_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/service:pattern_matcher",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_test(
-    name = "cudnn_fused_mha_rewriter_test",
-    srcs = ["cudnn_fused_mha_rewriter_test.cc"],
-    backend_tags = {"gpu": [
-        "requires-gpu-nvidia",
-        "no_rocm",
-    ]},
-    backends = [
-        "gpu",
-    ],
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        ":cudnn_fused_mha_rewriter",
-        ":cudnn_fused_mha_transpose_fusion",
-        "//xla:error_spec",
-        "//xla:test_helpers",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:algebraic_simplifier",
-        "//xla/service:computation_layout",
-        "//xla/service:hlo_cse",
-        "//xla/service:hlo_dce",
-        "//xla/service:hlo_module_config",
-        "//xla/service:hlo_parser",
-        "//xla/service:hlo_verifier",
-        "//xla/service:layout_normalization",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/service:reshape_decomposer",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:dnn",
-        "//xla/tests:hlo_test_base",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/strings",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test_main",
-    ] + if_cuda_is_configured([
-        "@local_config_cuda//cuda:cuda_headers",
-        "@local_config_cuda//cuda:cudnn_header",
-    ]),
-)
-
 xla_test(
     name = "float_support_test",
     srcs = ["float_support_test.cc"],
@@ -4763,78 +2535,15 @@
 )
 
 cc_library(
-    name = "variadic_op_splitter",
-    srcs = ["variadic_op_splitter.cc"],
-    hdrs = ["variadic_op_splitter.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "gpu_scatter_expander",
-    srcs = ["gpu_scatter_expander.cc"],
-    hdrs = ["gpu_scatter_expander.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:scatter_expander",
-        "@com_google_absl//absl/strings:string_view",
-    ],
-)
-
-xla_cc_test(
-    name = "variadic_op_splitter_test",
-    srcs = ["variadic_op_splitter_test.cc"],
-    tags = [
-        "nomsan",
-    ],
-    deps = [
-        ":variadic_op_splitter",
-        "//xla:literal_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_parser",
-        "//xla/service:pattern_matcher",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
-tf_proto_library(
-    name = "gpu_autotuning_proto",
-    srcs = ["gpu_autotuning.proto"],
-    cc_api_version = 2,
-    protodeps = [
-        ":backend_configs",
-        "//xla:xla_data_proto",
-        "//xla/service:hlo_proto",
-        "//xla:autotuning_proto",
-    ],
-)
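gpu_autotuning_proto above is a tf_proto_library, which also emits language-specific targets such as the gpu_autotuning_proto_cc C++ bindings that other rules depend on (the hlo_algorithm_denylist hunk below swaps :gpu_autotuning_proto_cc for //xla/service/gpu/autotuning:gpu_autotuning_proto_cc). A minimal sketch of that naming convention with a hypothetical proto:

# Hypothetical targets; shown only to illustrate the <name>_cc convention used by tf_proto_library.
tf_proto_library(
    name = "example_proto",
    srcs = ["example.proto"],
    cc_api_version = 2,
    protodeps = ["//xla:xla_data_proto"],
)

cc_library(
    name = "example_proto_user",
    srcs = ["example_proto_user.cc"],
    deps = [":example_proto_cc"],  # C++ bindings generated from example_proto
)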
-
-cc_library(
     name = "hlo_algorithm_denylist",
     srcs = ["hlo_algorithm_denylist.cc"],
     hdrs = ["hlo_algorithm_denylist.h"],
     deps = [
         ":backend_configs_cc",
-        ":gpu_autotuning_proto_cc",
         "//xla:autotuning_proto_cc",
         "//xla:debug_options_flags",
         "//xla/hlo/ir:backend_config",
+        "//xla/service/gpu/autotuning:gpu_autotuning_proto_cc",
         "//xla/stream_executor:dnn",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/log:check",
@@ -4861,133 +2570,6 @@
     ],
 )
 
-cc_library(
-    name = "alias_passthrough_params",
-    srcs = ["alias_passthrough_params.cc"],
-    hdrs = ["alias_passthrough_params.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-    ],
-)
-
-xla_cc_test(
-    name = "alias_passthrough_params_test",
-    srcs = ["alias_passthrough_params_test.cc"],
-    tags = [
-        "nomsan",
-    ],
-    deps = [
-        ":alias_passthrough_params",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "//xla/tsl/lib/core:status_test_util",
-        "@local_tsl//tsl/platform:test",
-    ],
-)
-
-cc_library(
-    name = "horizontal_loop_fusion",
-    srcs = ["horizontal_loop_fusion.cc"],
-    hdrs = ["horizontal_loop_fusion.h"],
-    deps = [
-        ":gpu_fusible",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "//xla/service:sub_byte_normalization",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_test(
-    name = "horizontal_loop_fusion_test",
-    srcs = ["horizontal_loop_fusion_test.cc"],
-    backends = ["gpu"],
-    deps = [
-        ":gpu_device_info_for_tests",
-        ":horizontal_loop_fusion",
-        ":instruction_fusion",
-        "//xla:error_spec",
-        "//xla:shape_util",
-        "//xla:test",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_dce",
-        "//xla/service:hlo_parser",
-        "//xla/service:hlo_pass",
-        "//xla/service:hlo_pass_pipeline",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/log",
-    ],
-)
-
-cc_library(
-    name = "horizontal_input_fusion",
-    srcs = ["horizontal_input_fusion.cc"],
-    hdrs = ["horizontal_input_fusion.h"],
-    deps = [
-        ":gpu_fusible",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_test(
-    name = "horizontal_input_fusion_test",
-    srcs = ["horizontal_input_fusion_test.cc"],
-    backends = ["gpu"],
-    deps = [
-        ":gpu_device_info_for_tests",
-        ":horizontal_input_fusion",
-        "//xla:error_spec",
-        "//xla:literal_util",
-        "//xla:shape_util",
-        "//xla:test",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/service/gpu/tests:gpu_codegen_test",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:xla_internal_test_main",
-    ],
-)
-
 xla_cc_test(
     name = "gpu_float_support_test",
     srcs = ["gpu_float_support_test.cc"],
@@ -5010,151 +2592,6 @@
 )
 
 cc_library(
-    name = "reduction_degenerate_dim_remover",
-    srcs = ["reduction_degenerate_dim_remover.cc"],
-    hdrs = ["reduction_degenerate_dim_remover.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "reduction_dimension_grouper",
-    srcs = ["reduction_dimension_grouper.cc"],
-    hdrs = ["reduction_dimension_grouper.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "reduction_splitter",
-    srcs = ["reduction_splitter.cc"],
-    hdrs = ["reduction_splitter.h"],
-    deps = [
-        ":reduction_utils",
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "reduction_splitter_test",
-    srcs = ["reduction_splitter_test.cc"],
-    deps = [
-        ":reduction_splitter",
-        "//xla:shape_util",
-        "//xla:test",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_parser",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-    ],
-)
-
-cc_library(
-    name = "reduction_layout_normalizer",
-    srcs = ["reduction_layout_normalizer.cc"],
-    hdrs = ["reduction_layout_normalizer.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "tree_reduction_rewriter",
-    srcs = ["tree_reduction_rewriter.cc"],
-    hdrs = ["tree_reduction_rewriter.h"],
-    deps = [
-        ":reduction_utils",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:collective_ops_utils",
-        "//xla/service:hlo_module_config",
-        "//xla/service:hlo_pass",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/numeric:bits",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "gemm_broadcast_folding_rewriter",
-    srcs = ["gemm_broadcast_folding_rewriter.cc"],
-    hdrs = ["gemm_broadcast_folding_rewriter.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/service:pattern_matcher",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
     name = "metrics",
     srcs = ["metrics.cc"],
     hdrs = ["metrics.h"],
@@ -5166,48 +2603,6 @@
 )
 
 cc_library(
-    name = "dot_operand_converter",
-    srcs = ["dot_operand_converter.cc"],
-    hdrs = ["dot_operand_converter.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:op_expander_pass",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-    ],
-)
-
-xla_test(
-    name = "dot_operand_converter_test",
-    srcs = if_gpu_is_configured(["dot_operand_converter_test.cc"]),
-    backends = [
-        "gpu_a100",
-        "gpu_p100",
-        "gpu_v100",
-        "gpu_amd_any",
-    ],
-    deps = if_gpu_is_configured(
-        [
-            ":dot_operand_converter",
-            "@com_google_googletest//:gtest",
-            "@com_google_absl//absl/strings",
-            "@com_google_absl//absl/strings:string_view",
-            "//xla:shape_util",
-            "//xla/hlo/ir:hlo",
-            "//xla/hlo/utils:hlo_matchers",
-            "//xla/service:pattern_matcher",
-            "//xla/tests:hlo_test_base",
-            "//xla/tests:xla_internal_test_main",
-            "@local_tsl//tsl/platform:statusor",
-        ],
-        ["@local_tsl//tsl/platform:test_main"],  # b/317293391
-    ) + ["//xla:xla_data_proto_cc"],
-)
-
-cc_library(
     name = "make_batch_pointers",
     srcs = if_gpu_is_configured(["make_batch_pointers.cc"]),
     hdrs = if_gpu_is_configured(["make_batch_pointers.h"]),
@@ -5237,25 +2632,6 @@
     ],
 )
 
-cc_library(
-    name = "triangular_solve_rewriter",
-    srcs = ["triangular_solve_rewriter.cc"],
-    hdrs = ["triangular_solve_rewriter.h"],
-    deps = [
-        ":cublas_cudnn",
-        "//xla:shape_util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
 tsl_gpu_library(
     name = "runtime_intrinsics",
     srcs = ["runtime_intrinsics.cc"],
@@ -5324,43 +2700,6 @@
 )
 
 cc_library(
-    name = "scatter_slice_simplifier",
-    srcs = ["scatter_slice_simplifier.cc"],
-    hdrs = ["scatter_slice_simplifier.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "scatter_slice_simplifier_test",
-    srcs = ["scatter_slice_simplifier_test.cc"],
-    deps = [
-        ":scatter_slice_simplifier",
-        "//xla:shape_util",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
-cc_library(
     name = "conv_layout_normalization",
     srcs = ["conv_layout_normalization.cc"],
     hdrs = ["conv_layout_normalization.h"],
@@ -5377,133 +2716,6 @@
     ],
 )
 
-cc_library(
-    name = "topk_specializer",
-    srcs = ["topk_specializer.cc"],
-    hdrs = ["topk_specializer.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/service:hlo_proto_cc",
-        "//xla/service:tuple_util",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-    ],
-)
-
-cc_library(
-    name = "topk_splitter",
-    srcs = ["topk_splitter.cc"],
-    hdrs = ["topk_splitter.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/numeric:bits",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "topk_splitter_test",
-    srcs = ["topk_splitter_test.cc"],
-    deps = [
-        ":topk_splitter",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_dce",
-        "//xla/service:pattern_matcher",
-        "//xla/service:topk_rewriter",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:verified_hlo_module",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/platform:status_matchers",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test",
-    ],
-)
-
-xla_test(
-    name = "topk_test",
-    srcs = ["topk_test.cc"],
-    backends = ["gpu"],
-    deps = [
-        ":topk_specializer",
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/service:platform_util",
-        "//xla/service:topk_rewriter",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
-cc_library(
-    name = "copy_fusion",
-    srcs = ["copy_fusion.cc"],
-    hdrs = ["copy_fusion.h"],
-    deps = [
-        ":gpu_fusible",
-        ":hlo_traversal",
-        ":ir_emission_utils",
-        ":reduction_utils",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-    ],
-)
-
-cc_library(
-    name = "algorithm_checker",
-    srcs = ["algorithm_checker.cc"],
-    hdrs = ["algorithm_checker.h"],
-    deps = [
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:algorithm_util",
-        "//xla/service:hlo_pass",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-    ],
-)
-
 xla_test(
     name = "dot_algorithm_support_test",
     srcs = if_gpu_is_configured(["dot_algorithm_support_test.cc"]),
@@ -5618,138 +2830,6 @@
     ],
 )
 
-cc_library(
-    name = "fusion_wrapper",
-    srcs = ["fusion_wrapper.cc"],
-    hdrs = ["fusion_wrapper.h"],
-    deps = [
-        ":gpu_fusible",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/platform:errors",
-    ],
-)
-
-xla_cc_test(
-    name = "fusion_wrapper_test",
-    srcs = ["fusion_wrapper_test.cc"],
-    deps = [
-        ":fusion_wrapper",
-        "//xla/tests:hlo_test_base",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
-xla_cc_test(
-    name = "copy_fusion_test",
-    srcs = ["copy_fusion_test.cc"],
-    deps = [
-        ":copy_fusion",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/strings",
-        "@com_google_googletest//:gtest",
-    ],
-)
-
-xla_cc_test(
-    name = "autotuner_util_test",
-    srcs = if_cuda_is_configured(["autotuner_util_test.cc"]),
-    data = [
-        "//xla/tools/hlo_opt:gpu_specs/a100_sxm_40.txtpb",
-        "//xla/tools/hlo_opt:gpu_specs/a100_sxm_80.txtpb",
-        "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb",
-    ],
-    deps = if_cuda_is_configured([
-        # keep sorted
-        ":autotuner_util",
-        "//xla:autotune_results_proto_cc",
-        "//xla:autotuning_proto_cc",
-        "//xla:xla_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:device_description_proto_cc",
-        "//xla/stream_executor:platform",
-        "//xla/stream_executor:platform_manager",
-        "//xla/stream_executor/host:host_platform",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:verified_hlo_module",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_absl//absl/base:log_severity",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/log:scoped_mock_log",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:env",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:path",
-        "@local_tsl//tsl/platform:protobuf",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:status_matchers",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test",
-    ]) + [
-        "//xla/tests:xla_internal_test_main",  # Keep outside GPU guard
-    ],
-)
-
-cc_library(
-    name = "double_buffer_loop_unrolling",
-    srcs = ["double_buffer_loop_unrolling.cc"],
-    hdrs = ["double_buffer_loop_unrolling.h"],
-    deps = [
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/ir:hlo_instruction_utils",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:collective_ops_utils",
-        "//xla/service:flatten_call_graph",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "double_buffer_loop_unrolling_test",
-    srcs = ["double_buffer_loop_unrolling_test.cc"],
-    deps = [
-        ":double_buffer_loop_unrolling",
-        "//xla:test",
-        "//xla:xla_data_proto_cc",
-        "//xla:xla_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:tuple_simplifier",
-        "//xla/tests:filecheck",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@local_tsl//tsl/platform:status_matchers",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
 xla_test(
     name = "determinism_test",
     srcs = if_gpu_is_configured(["determinism_test.cc"]),
@@ -5762,7 +2842,7 @@
     ]),
     deps = if_gpu_is_configured(
         [
-            ":autotuner_util",
+            "//xla/service/gpu/autotuning:autotuner_util",
             "@com_google_googletest//:gtest_main",
             "@com_google_absl//absl/strings",
             "//xla:literal",
@@ -5791,283 +2871,6 @@
 )
 
 cc_library(
-    name = "collective_permute_cycle_decomposer",
-    srcs = ["collective_permute_cycle_decomposer.cc"],
-    hdrs = ["collective_permute_cycle_decomposer.h"],
-    deps = [
-        ":backend_configs_cc",
-        "//xla:comparison_util",
-        "//xla:literal_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:collective_ops_utils",
-        "//xla/service:hlo_parser",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-    ],
-)
-
-xla_cc_test(
-    name = "collective_permute_cycle_decomposer_test",
-    srcs = ["collective_permute_cycle_decomposer_test.cc"],
-    deps = [
-        ":collective_permute_cycle_decomposer",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_parser",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
-cc_library(
-    name = "collective_permute_valid_iteration_annotator",
-    srcs = ["collective_permute_valid_iteration_annotator.cc"],
-    hdrs = ["collective_permute_valid_iteration_annotator.h"],
-    deps = [
-        "//xla:literal_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:collective_ops_utils",
-        "//xla/service:hlo_pass",
-        "//xla/service:pattern_matcher",
-        "//xla/service:while_loop_analysis",
-    ],
-)
-
-xla_cc_test(
-    name = "collective_permute_valid_iteration_annotator_test",
-    srcs = ["collective_permute_valid_iteration_annotator_test.cc"],
-    deps = [
-        ":collective_permute_valid_iteration_annotator",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:collective_ops_utils",
-        "//xla/service:hlo_pass_pipeline",
-        "//xla/service:while_loop_trip_count_annotator",
-        "//xla/tests:hlo_test_base",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
-cc_library(
-    name = "stream_attribute_annotator",
-    srcs = ["stream_attribute_annotator.cc"],
-    hdrs = ["stream_attribute_annotator.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":gpu_fusible",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:hlo_pass",
-        "//xla/service/gpu/runtime:thunk",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "stream_attribute_annotator_test",
-    srcs = ["stream_attribute_annotator_test.cc"],
-    deps = [
-        ":backend_configs_cc",
-        ":stream_attribute_annotator",
-        "//xla/hlo/ir:hlo",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "stream_attribute_async_wrapper",
-    srcs = ["stream_attribute_async_wrapper.cc"],
-    hdrs = ["stream_attribute_async_wrapper.h"],
-    deps = [
-        ":backend_configs_cc",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "//xla/service/gpu/runtime:thunk",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "stream_attribute_async_wrapper_test",
-    srcs = ["stream_attribute_async_wrapper_test.cc"],
-    deps = [
-        ":backend_configs_cc",
-        ":stream_attribute_async_wrapper",
-        "//xla/hlo/ir:hlo",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "gpu_windowed_einsum_handler",
-    srcs = ["gpu_windowed_einsum_handler.cc"],
-    hdrs = ["gpu_windowed_einsum_handler.h"],
-    deps = [
-        ":backend_configs_cc",
-        "//xla:literal_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:hlo_creation_utils",
-        "//xla/service:hlo_pass",
-        "//xla/service:pattern_matcher",
-        "//xla/service:shape_inference",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "gpu_windowed_einsum_handler_test",
-    srcs = ["gpu_windowed_einsum_handler_test.cc"],
-    deps = [
-        ":backend_configs_cc",
-        ":gpu_windowed_einsum_handler",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/tests:filecheck",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "triton_fusion_numerics_verifier",
-    srcs = if_gpu_is_configured(["triton_fusion_numerics_verifier.cc"]),
-    hdrs = if_gpu_is_configured(["triton_fusion_numerics_verifier.h"]),
-    deps = if_gpu_is_configured([
-        ":autotuner_compile_util",
-        ":autotuner_util",
-        ":backend_configs_cc",
-        ":buffer_comparator",
-        ":ir_emission_utils",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/functional:any_invocable",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:executable",
-        "//xla/service:hlo_pass",
-        "//xla/service:shaped_buffer",
-        "//xla/service:hlo_module_config",
-        "//xla/stream_executor:stream",
-        "//xla/tools:hlo_decomposer_lib",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ]),
-)
-
-xla_test(
-    name = "triton_fusion_numerics_verifier_test",
-    srcs = if_gpu_is_configured(["triton_fusion_numerics_verifier_test.cc"]),
-    backend_tags = {"gpu": [
-        "requires-gpu-sm80",
-    ]},
-    backends = ["gpu"],
-    deps = [
-        ":autotuner_compile_util",
-        ":autotuner_util",
-        ":triton_fusion_numerics_verifier",
-        "//xla:shape_util",
-        "//xla:test_helpers",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:platform_util",
-        "//xla/stream_executor:platform",
-        "//xla/tests:hlo_test_base",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
-cc_library(
-    name = "pipelined_p2p_rewriter",
-    srcs = ["pipelined_p2p_rewriter.cc"],
-    hdrs = ["pipelined_p2p_rewriter.h"],
-    deps = [
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:collective_ops_utils",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "pipelined_p2p_rewriter_test",
-    srcs = ["pipelined_p2p_rewriter_test.cc"],
-    deps = [
-        ":pipelined_p2p_rewriter",
-        "//xla/hlo/ir:hlo",
-        "//xla/tests:filecheck",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
-cc_library(
     name = "execution_stream_assignment",
     srcs = ["execution_stream_assignment.cc"],
     hdrs = ["execution_stream_assignment.h"],
@@ -6138,32 +2941,3 @@
         "@local_tsl//tsl/platform:statusor",
     ],
 )
-
-cc_library(
-    name = "scheduling_instruction_annotator",
-    srcs = ["scheduling_instruction_annotator.cc"],
-    hdrs = ["scheduling_instruction_annotator.h"],
-    deps = [
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_pass",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "scheduling_instruction_annotator_test",
-    srcs = ["scheduling_instruction_annotator_test.cc"],
-    deps = [
-        ":scheduling_instruction_annotator",
-        "//xla/hlo/ir:hlo",
-        "//xla/tests:filecheck",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
diff --git a/third_party/xla/xla/service/gpu/algorithm_checker.cc b/third_party/xla/xla/service/gpu/algorithm_checker.cc
deleted file mode 100644
index 3104293..0000000
--- a/third_party/xla/xla/service/gpu/algorithm_checker.cc
+++ /dev/null
@@ -1,117 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/algorithm_checker.h"
-
-#include <utility>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/algorithm_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-bool HasNonDefaultOperandPrecision(const PrecisionConfig& config) {
-  return absl::c_any_of(config.operand_precision(), [](int precision) {
-    return static_cast<PrecisionConfig::Precision>(precision) !=
-           PrecisionConfig::DEFAULT;
-  });
-}
-
-class AlgorithmCheckerVisitor : public ConstDfsHloVisitorWithDefault {
- public:
-  explicit AlgorithmCheckerVisitor(
-      se::GpuComputeCapability gpu_compute_capability)
-      : gpu_compute_capability_(std::move(gpu_compute_capability)) {}
-
-  absl::Status RunOnModule(
-      const HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads = {}) {
-    for (HloComputation* computation :
-         module->MakeNonfusionComputations(execution_threads)) {
-      TF_RETURN_IF_ERROR(computation->Accept(this));
-    }
-    return absl::OkStatus();
-  }
-
-  absl::Status HandleDot(const HloInstruction* hlo) override {
-    VLOG(1) << "Handling dot: " << hlo->ToString();
-    const PrecisionConfig& config = hlo->precision_config();
-
-    if (config.algorithm() != PrecisionConfig::ALG_UNSET &&
-        HasNonDefaultOperandPrecision(config)) {
-      LOG(WARNING)
-          << "There is no need to set precisions when we set the algorithm: "
-          << hlo->ToString();
-    }
-
-    if (config.algorithm() == PrecisionConfig::ALG_UNSET) {
-      return absl::OkStatus();
-    }
-
-    PrimitiveType lhs_storage_type = hlo->operand(0)->shape().element_type();
-    PrimitiveType rhs_storage_type = hlo->operand(1)->shape().element_type();
-    PrimitiveType output_storage_type = hlo->shape().element_type();
-
-    if (lhs_storage_type != rhs_storage_type) {
-      return absl::UnimplementedError(absl::StrFormat(
-          "Dot operands must have the same type when using an algorithm: %s",
-          hlo->ToString()));
-    }
-
-    return algorithm_util::IsSupportedDotAlgorithmOnGpu(
-               config.algorithm(), gpu_compute_capability_, lhs_storage_type,
-               output_storage_type)
-               ? absl::OkStatus()
-               : absl::UnimplementedError(absl::StrFormat(
-                     "Unsupported algorithm on the current device(s): %s",
-                     PrecisionConfig::Algorithm_Name(config.algorithm())));
-  }
-
-  absl::Status DefaultAction(const HloInstruction* hlo) override {
-    return absl::OkStatus();
-  }
-
- private:
-  se::GpuComputeCapability gpu_compute_capability_;
-};
-
-}  // namespace
-
-absl::StatusOr<bool> AlgorithmChecker::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  TF_RETURN_IF_ERROR(AlgorithmCheckerVisitor(gpu_compute_capability_)
-                         .RunOnModule(module, execution_threads));
-  // No change was made.
-  return false;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/algorithm_checker.h b/third_party/xla/xla/service/gpu/algorithm_checker.h
deleted file mode 100644
index f3b30c1..0000000
--- a/third_party/xla/xla/service/gpu/algorithm_checker.h
+++ /dev/null
@@ -1,54 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_ALGORITHM_CHECKER_H_
-#define XLA_SERVICE_GPU_ALGORITHM_CHECKER_H_
-
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// Checks whether the requested dot algorithms are supported, giving an early
-// and specific error if an unsupported algorithm is requested.
-//
-// Note: Maybe we can make this more generic and move it outside of GPU.
-class AlgorithmChecker : public HloModulePass {
- public:
-  explicit AlgorithmChecker(se::GpuComputeCapability gpu_compute_capability)
-      : gpu_compute_capability_(std::move(gpu_compute_capability)){};
-
-  absl::string_view name() const override { return "algorithm-checker"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  se::GpuComputeCapability gpu_compute_capability_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_ALGORITHM_CHECKER_H_
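
For illustration, a minimal sketch of invoking the pass declared above, assuming it keeps the constructor and the single-argument HloPassInterface::Run overload shown in this header; the helper name, the include path (the pre-move location used in this diff), and the compute capability are hypothetical.

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/algorithm_checker.h"
#include "xla/stream_executor/device_description.h"

namespace se = ::stream_executor;

// Hypothetical helper: fail early if any dot in `module` requests an algorithm
// the target GPU cannot execute; otherwise report "no change" (false).
absl::StatusOr<bool> CheckDotAlgorithms(xla::HloModule* module) {
  se::GpuComputeCapability cc = se::CudaComputeCapability{8, 0};  // assumed Ampere target
  return xla::gpu::AlgorithmChecker(cc).Run(module);
}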
diff --git a/third_party/xla/xla/service/gpu/alias_passthrough_params.cc b/third_party/xla/xla/service/gpu/alias_passthrough_params.cc
deleted file mode 100644
index 5dea5bc..0000000
--- a/third_party/xla/xla/service/gpu/alias_passthrough_params.cc
+++ /dev/null
@@ -1,69 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/alias_passthrough_params.h"
-
-#include <cstdint>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/shape_util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> AliasPassthroughParams::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  if (module->entry_computation()->num_parameters() == 0 ||
-      root->opcode() != HloOpcode::kTuple) {
-    return false;
-  }
-  bool changed = false;
-  absl::flat_hash_set<int64_t> used_params;
-  for (int64_t i = 0; i < root->operand_count(); ++i) {
-    if (root->operand(i)->opcode() == HloOpcode::kParameter &&
-        used_params.count(root->operand(i)->parameter_number()) == 0) {
-      VLOG(2) << "Parameter " << root->operand(i)->parameter_number()
-              << " with shape " << root->operand(i)->shape().ToString()
-              << " in module " << module->name()
-              << " is passed-through to root tuple element " << i << ": "
-              << root->shape().ToString();
-
-      if (module->input_output_alias_config().OutputHasAlias({i}) ||
-          module->input_output_alias_config().ParameterHasAlias(
-              root->operand(i)->parameter_number(), /*param_index=*/{})) {
-        VLOG(2) << "Skip setting the above pass-through alias as an alias may"
-                << " have been set up for aliasing resource update.";
-        continue;
-      }
-
-      TF_RETURN_IF_ERROR(module->input_output_alias_config().SetUpAlias(
-          /*output_index=*/{i},
-          /*param_number=*/root->operand(i)->parameter_number(),
-          /*param_index=*/{}));
-      used_params.insert(root->operand(i)->parameter_number());
-      changed = true;
-    }
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/alias_passthrough_params.h b/third_party/xla/xla/service/gpu/alias_passthrough_params.h
deleted file mode 100644
index 029068a..0000000
--- a/third_party/xla/xla/service/gpu/alias_passthrough_params.h
+++ /dev/null
@@ -1,50 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_ALIAS_PASSTHROUGH_PARAMS_H_
-#define XLA_SERVICE_GPU_ALIAS_PASSTHROUGH_PARAMS_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// This pass aliases input and output buffers that are associated with a
-// parameter that is passed through to the module root unmodified.
-//
-// This pass assumes that parameters and the root use unnested shapes, which is
-// the case for XLA:GPU.
-//
-// This pass must run prior to copy insertion.
-class AliasPassthroughParams : public HloModulePass {
- public:
-  AliasPassthroughParams() = default;
-  ~AliasPassthroughParams() override = default;
-  absl::string_view name() const override { return "alias_passthrough_params"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_ALIAS_PASSTHROUGH_PARAMS_H_
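
For illustration, the class comment above made concrete; the deleted test file below exercises the same behavior end to end, so this sketch only spells out the resulting alias table for a hypothetical module.

// For ENTRY { p0, p1, sum = add(p0, p1), ROOT root = tuple(p0, sum, p1) },
// a run such as
//   bool changed = xla::gpu::AliasPassthroughParams().Run(module.get()).value();
// leaves the module's input_output_alias_config() with:
//   output {0} aliased to parameter 0,
//   output {2} aliased to parameter 1,
//   output {1} (the computed sum) unaliased.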
diff --git a/third_party/xla/xla/service/gpu/alias_passthrough_params_test.cc b/third_party/xla/xla/service/gpu/alias_passthrough_params_test.cc
deleted file mode 100644
index 2c09daf..0000000
--- a/third_party/xla/xla/service/gpu/alias_passthrough_params_test.cc
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/alias_passthrough_params.h"
-
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-
-class AliasPassthroughParamsTest : public HloTestBase {};
-
-TEST_F(AliasPassthroughParamsTest, AliasPassThroughParams) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    p0 = f16[2048,1024] parameter(0)
-    p1 = f16[2048,1024] parameter(1)
-    sum = f16[2048,1024] add(p0, p1)
-    ROOT root = (f16[2048,1024], f16[2048,1024], f16[2048,1024]) tuple(p0, sum, p1)
-  })")
-                    .value();
-  EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).value());
-  const auto& alias_config = module->input_output_alias_config();
-  EXPECT_EQ(0, alias_config.GetAliasedParameter({0})->parameter_number);
-  EXPECT_FALSE(alias_config.OutputHasAlias({1}));
-  EXPECT_EQ(1, alias_config.GetAliasedParameter({2})->parameter_number);
-}
-
-TEST_F(AliasPassthroughParamsTest, DoNotAliasPassThroughParamsMoreThanOnce) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    p0 = f16[2048,1024] parameter(0)
-    ROOT root = (f16[2048,1024], f16[2048,1024]) tuple(p0, p0)
-  })")
-                    .value();
-  EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).value());
-  const auto& alias_config = module->input_output_alias_config();
-  EXPECT_EQ(0, alias_config.GetAliasedParameter({0})->parameter_number);
-  EXPECT_FALSE(alias_config.OutputHasAlias({1}));
-}
-
-TEST_F(AliasPassthroughParamsTest, PresetAliases) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    p0 = f16[2048,1024] parameter(0)
-    p1 = f16[2048,1024] parameter(1)
-    sum = f16[2048,1024] add(p0, p1)
-    ROOT root = (f16[2048,1024], f16[2048,1024], f16[2048,1024]) tuple(p0, sum, p1)
-  })")
-                    .value();
-
-  // Presetting an alias p0 -> sum (root tuple element 1). This could happen in
-  // the case of `alias_resource_update`.
-  auto& preset_alias = module->input_output_alias_config();
-  TF_EXPECT_OK(preset_alias.SetUpAlias(/*output_index=*/{1},
-                                       /*param_number=*/0,
-                                       /*param_index=*/{}));
-
-  EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).value());
-  const auto& alias_result = module->input_output_alias_config();
-  // Assert that an alias from p1 to root tuple element 2 is established by
-  // `AliasPassthroughParams`.
-  EXPECT_EQ(1, alias_result.GetAliasedParameter({2})->parameter_number);
-  EXPECT_FALSE(alias_result.OutputHasAlias({0}));
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/all_reduce_blueconnect.cc b/third_party/xla/xla/service/gpu/all_reduce_blueconnect.cc
deleted file mode 100644
index 2e75ffa..0000000
--- a/third_party/xla/xla/service/gpu/all_reduce_blueconnect.cc
+++ /dev/null
@@ -1,373 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/all_reduce_blueconnect.h"
-
-#include <algorithm>
-#include <cstddef>
-#include <cstdint>
-#include <iterator>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/btree_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/service/computation_placer.h"
-#include "xla/service/global_device_id.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-std::vector<HloInstruction*> GetOutputs(HloInstruction& instruction) {
-  if (!instruction.shape().IsTuple()) {
-    return {&instruction};
-  }
-
-  std::vector<HloInstruction*> outputs;
-  outputs.reserve(instruction.shape().tuple_shapes_size());
-
-  HloComputation& computation = *instruction.parent();  // never null
-  for (int i = 0; i < instruction.shape().tuple_shapes_size(); ++i) {
-    outputs.push_back(computation.AddInstruction(
-        HloInstruction::CreateGetTupleElement(&instruction, i)));
-  }
-  return outputs;
-}
-
-struct DecomposedReplicaGroups {
-  std::vector<ReplicaGroup> scatter_gather_groups;
-  std::vector<ReplicaGroup> new_all_reduce_groups;
-};
-
-// Returns the global device id for the given replica id. Returns nullopt if
-// if the replica id can refer to multiple devices, or if the pass does not
-// support the CollectiveOpGroupMode.
-std::optional<GlobalDeviceId> TryConvertingReplicaIdToDeviceId(
-    int64_t replica_id, const DeviceAssignment& device_assignment,
-    CollectiveOpGroupMode collective_group_mode) {
-  if (collective_group_mode == CollectiveOpGroupMode::kCrossReplica) {
-    if (device_assignment.computation_count() != 1) {
-      // If there are multiple partitions, the replica_id may refer to multiple
-      // devices on different partitions.
-      return std::nullopt;
-    }
-    return GlobalDeviceId{device_assignment(replica_id, /*computation_id=*/0)};
-  } else if (collective_group_mode == CollectiveOpGroupMode::kFlattenedID) {
-    int partition_count = device_assignment.computation_count();
-    int64_t actual_replica_id = replica_id / partition_count;
-    int64_t partition_id = replica_id % partition_count;
-    return GlobalDeviceId{device_assignment(actual_replica_id, partition_id)};
-  }
-
-  // kCrossPartition and kCrossReplicaAndPartition are unsupported.
-  VLOG(1) << "Skip AllReduceBlueConnect because of unsupported "
-             "CollectiveOpGroupMode "
-          << CollectiveOpGroupModeToString(collective_group_mode);
-  return std::nullopt;
-}
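// Illustrative worked example (not part of the original file): with a device
// assignment of 2 replicas x 4 partitions (partition_count = 4), kFlattenedID
// maps flattened replica_id 6 to actual_replica_id = 6 / 4 = 1 and
// partition_id = 6 % 4 = 2, i.e. the device stored at device_assignment(1, 2).
// In kCrossReplica mode with a single partition, replica_id i maps directly to
// device_assignment(i, 0).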
-
-absl::StatusOr<std::optional<DecomposedReplicaGroups>> TryDecomposeReplicaGroup(
-    const ReplicaGroup& replica_group,
-    const DeviceAssignment& device_assignment, size_t num_devices_per_host,
-    CollectiveOpGroupMode collective_group_mode) {
-  int group_size = replica_group.replica_ids_size();
-  TF_RET_CHECK(group_size > 0);
-
-  absl::btree_map<int, std::vector<int64_t>> replica_ids_by_host;
-  for (int64_t replica_id : replica_group.replica_ids()) {
-    std::optional<GlobalDeviceId> device_id = TryConvertingReplicaIdToDeviceId(
-        replica_id, device_assignment, collective_group_mode);
-    if (!device_id.has_value()) {
-      return {std::nullopt};
-    }
-    TF_RET_CHECK(*device_id >= 0);
-    // We assume that devices are ordered by host.
-    int host_id = device_id->value() / num_devices_per_host;
-    replica_ids_by_host[host_id].push_back(replica_id);
-  }
-
-  size_t num_local_devices = replica_ids_by_host.begin()->second.size();
-  bool same_num_devices_on_each_host =
-      absl::c_all_of(replica_ids_by_host, [&](const auto& entry) {
-        return entry.second.size() == num_local_devices;
-      });
-
-  if (!same_num_devices_on_each_host) {
-    return {std::nullopt};
-  }
-
-  std::vector<int64_t> sorted_replica_group;
-  sorted_replica_group.reserve(group_size);
-  for (const auto& entry : replica_ids_by_host) {
-    absl::c_copy(entry.second, std::back_inserter(sorted_replica_group));
-  }
-
-  size_t scatter_group_size = std::max(num_local_devices, size_t(2));
-  size_t num_scatter_groups = group_size / scatter_group_size;
-
-  if ((group_size % scatter_group_size != 0) || (num_scatter_groups < 2)) {
-    return {std::nullopt};
-  }
-
-  std::vector<ReplicaGroup> scatter_gather_groups(num_scatter_groups);
-  std::vector<ReplicaGroup> new_all_reduce_groups(scatter_group_size);
-
-  for (size_t i = 0; i < group_size; ++i) {
-    int64_t replica_id = sorted_replica_group[i];
-    scatter_gather_groups[i / scatter_group_size].add_replica_ids(replica_id);
-    new_all_reduce_groups[i % scatter_group_size].add_replica_ids(replica_id);
-  }
-
-  return {DecomposedReplicaGroups{std::move(scatter_gather_groups),
-                                  std::move(new_all_reduce_groups)}};
-}
-
-absl::StatusOr<std::optional<DecomposedReplicaGroups>>
-TryDecomposeReplicaGroups(const HloAllReduceInstruction& all_reduce,
-                          size_t num_devices_per_host) {
-  const DeviceAssignment& device_assignment =
-      all_reduce.GetModule()->config().static_device_assignment();
-
-  absl::Span<const ReplicaGroup> replica_groups = all_reduce.replica_groups();
-
-  ReplicaGroup all_replicas;  // only populated if replica groups not present.
-  if (replica_groups.empty()) {
-    for (int i = 0; i < device_assignment.replica_count(); ++i) {
-      all_replicas.add_replica_ids(i);
-    }
-    replica_groups = absl::MakeSpan(&all_replicas, 1);
-  }
-
-  TF_ASSIGN_OR_RETURN(
-      CollectiveOpGroupMode collective_op_group_mode,
-      GetCollectiveOpGroupMode(all_reduce.channel_id().has_value(),
-                               all_reduce.use_global_device_ids()));
-
-  std::vector<ReplicaGroup> scatter_gather_groups;
-  std::vector<ReplicaGroup> new_all_reduce_groups;
-
-  // Try to find a valid decomposition for each replica group.
-  for (const ReplicaGroup& replica_group : replica_groups) {
-    TF_ASSIGN_OR_RETURN(
-        std::optional<DecomposedReplicaGroups> decomposed_groups,
-        TryDecomposeReplicaGroup(replica_group, device_assignment,
-                                 num_devices_per_host,
-                                 collective_op_group_mode));
-
-    if (!decomposed_groups) return {std::nullopt};
-
-    int scatter_group_size =
-        decomposed_groups->scatter_gather_groups[0].replica_ids_size();
-
-    if (scatter_gather_groups.empty()) {
-      // Check that every operand is exactly divisible by scatter group sizes.
-      for (const HloInstruction* operand : all_reduce.operands()) {
-        TF_RET_CHECK(operand->shape().IsArray());
-        int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
-        if (num_elements % scatter_group_size != 0) {
-          return {std::nullopt};
-        }
-      }
-
-      scatter_gather_groups.reserve(
-          replica_groups.size() *
-          decomposed_groups->scatter_gather_groups.size());
-      new_all_reduce_groups.reserve(
-          replica_groups.size() *
-          decomposed_groups->new_all_reduce_groups.size());
-    } else if (scatter_group_size !=
-               scatter_gather_groups[0].replica_ids_size()) {
-      // Reduce-scatter would have different output shapes on different devices.
-      return {std::nullopt};
-    }
-
-    absl::c_move(decomposed_groups->scatter_gather_groups,
-                 std::back_inserter(scatter_gather_groups));
-    absl::c_move(decomposed_groups->new_all_reduce_groups,
-                 std::back_inserter(new_all_reduce_groups));
-  }
-
-  return {DecomposedReplicaGroups{std::move(scatter_gather_groups),
-                                  std::move(new_all_reduce_groups)}};
-}
-
-// Attempts to decompose all-reduces as described by the BlueConnect paper.
-//
-// If possible, the all-reduce will be transformed into:
-// 1. reduce-scatter
-// 2. all-reduce
-// 3. all-gather
-//
-// If the all-reduce replica groups have more than one device within the same
-// host, the reduce-scatter will be performed over all devices within each host.
-// Otherwise, the reduce-scatter will be performed between pairs of devices on
-// different hosts.
-//
-// When applied repeatedly, this transformation will reproduce the same pattern
-// as described in the BlueConnect paper.
-absl::StatusOr<bool> TryDecomposeAllReduce(HloAllReduceInstruction* all_reduce,
-                                           size_t num_devices_per_host) {
-  TF_RET_CHECK(all_reduce);
-  TF_RET_CHECK(!all_reduce->has_sharding());
-
-  HloComputation& computation = *all_reduce->parent();  // never null
-  PrimitiveType element_type = all_reduce->operand(0)->shape().element_type();
-
-  TF_ASSIGN_OR_RETURN(
-      std::optional<DecomposedReplicaGroups> decomposed_groups,
-      TryDecomposeReplicaGroups(*all_reduce, num_devices_per_host));
-
-  if (!decomposed_groups) return false;
-
-  // Bitcast operands to 1D to guarantee that first dimension is divisible by
-  // scatter group size (we checked num elements was divisible above).
-  std::vector<HloInstruction*> flat_operands;
-  flat_operands.reserve(all_reduce->operand_count());
-  std::vector<Shape> flat_shapes;
-  flat_shapes.reserve(all_reduce->operand_count());
-  std::vector<Shape> scattered_shapes;
-  scattered_shapes.reserve(all_reduce->operand_count());
-
-  int scatter_group_size =
-      decomposed_groups->scatter_gather_groups[0].replica_ids_size();
-
-  for (HloInstruction* operand : all_reduce->operands()) {
-    TF_RET_CHECK(operand->shape().IsArray());
-    int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
-    Shape flat_shape = ShapeUtil::MakeShape(element_type, {num_elements});
-    flat_operands.push_back(computation.AddInstruction(
-        HloInstruction::CreateBitcast(flat_shape, operand)));
-    flat_shapes.push_back(std::move(flat_shape));
-    scattered_shapes.push_back(ShapeUtil::MakeShape(
-        element_type, {num_elements / scatter_group_size}));
-  }
-
-  Shape reduce_scatter_shape = ShapeUtil::MakeMaybeTupleShape(scattered_shapes);
-
-  int64_t next_channel_id = hlo_query::NextChannelId(*computation.parent());
-  auto get_channel_id = [&]() -> std::optional<int64_t> {
-    if (all_reduce->channel_id().has_value()) {
-      return next_channel_id++;
-    }
-    return std::nullopt;
-  };
-
-  HloInstruction* reduce_scatter =
-      computation.AddInstruction(HloInstruction::CreateReduceScatter(
-          reduce_scatter_shape, flat_operands, all_reduce->to_apply(),
-          CollectiveDeviceList(decomposed_groups->scatter_gather_groups),
-          /*constrain_layout=*/false, get_channel_id(),
-          all_reduce->use_global_device_ids(),
-          /*scatter_dimension=*/0));
-
-  HloInstruction* new_all_reduce =
-      computation.AddInstruction(HloInstruction::CreateAllReduce(
-          reduce_scatter_shape, GetOutputs(*reduce_scatter),
-          all_reduce->to_apply(),
-          CollectiveDeviceList(decomposed_groups->new_all_reduce_groups),
-          /*constrain_layout=*/false, all_reduce->channel_id(),
-          all_reduce->use_global_device_ids()));
-
-  HloInstruction* all_gather =
-      computation.AddInstruction(HloInstruction::CreateAllGather(
-          ShapeUtil::MakeMaybeTupleShape(flat_shapes),
-          GetOutputs(*new_all_reduce),
-          /*all_gather_dimension=*/0,
-          CollectiveDeviceList(decomposed_groups->scatter_gather_groups),
-          /*constrain_layout=*/false, get_channel_id(),
-          all_reduce->use_global_device_ids()));
-
-  // Bitcast back to the original shapes and replace all-reduce with decomposed
-  // implementation.
-  std::vector<HloInstruction*> outputs = GetOutputs(*all_gather);
-  for (int64_t i = 0; i < outputs.size(); ++i) {
-    outputs[i] = computation.AddInstruction(HloInstruction::CreateBitcast(
-        all_reduce->operand(i)->shape(), outputs[i]));
-  }
-  HloInstruction* replacement = MaybeMakeTuple(outputs);
-
-  TF_RETURN_IF_ERROR(
-      all_reduce->CopyAllControlDepsTo(reduce_scatter, replacement));
-
-  TF_RETURN_IF_ERROR(all_reduce->DropAllControlDeps());
-  TF_RETURN_IF_ERROR(computation.ReplaceInstruction(all_reduce, replacement));
-
-  // Try to apply decomposition recursively.
-  TF_RETURN_IF_ERROR(
-      TryDecomposeAllReduce(Cast<HloAllReduceInstruction>(new_all_reduce),
-                            num_devices_per_host)
-          .status());
-  return true;
-}
-
-}  // namespace
-
-absl::StatusOr<bool> AllReduceBlueConnect::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  VLOG(1) << "Running AllReduceBlueConnect";
-
-  if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
-    VLOG(1)
-        << "Skip AllReduceBlueConnect because the module contains all-reduce "
-           "with constrained layouts";
-    return false;
-  }
-  if (!module->config().has_static_device_assignment()) {
-    VLOG(1)
-        << "Skip AllReduceBlueConnect because the module doesn't have static "
-           "device assignment";
-    return false;
-  }
-
-  std::vector<HloAllReduceInstruction*> all_reduces;
-  for (HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    for (HloInstruction* instruction : computation->instructions()) {
-      if (instruction->opcode() == HloOpcode::kAllReduce) {
-        all_reduces.push_back(Cast<HloAllReduceInstruction>(instruction));
-      }
-    }
-  }
-
-  bool changed = false;
-  for (HloAllReduceInstruction* all_reduce : all_reduces) {
-    TF_ASSIGN_OR_RETURN(
-        bool all_reduce_changed,
-        TryDecomposeAllReduce(all_reduce, num_devices_per_host_));
-    changed |= all_reduce_changed;
-  }
-
-  return changed;
-}
-
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/all_reduce_blueconnect.h b/third_party/xla/xla/service/gpu/all_reduce_blueconnect.h
deleted file mode 100644
index 8633c77..0000000
--- a/third_party/xla/xla/service/gpu/all_reduce_blueconnect.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_ALL_REDUCE_BLUECONNECT_H_
-#define XLA_SERVICE_GPU_ALL_REDUCE_BLUECONNECT_H_
-
-#include <cstddef>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// Decomposes all-reduce operations using the BlueConnect algorithm.
-//
-// Paper: "BLUECONNECT: DECOMPOSING ALL-REDUCE FOR DEEP LEARNING ON
-// HETEROGENEOUS NETWORK HIERARCHY"
-// https://mlsys.org/Conferences/2019/doc/2019/130.pdf
-//
-// This algorithm attempts to minimize the number of levels of network hierarchy
-// traversed for as much data transfer as possible. This implementation assumes
-// that host IDs are ordered corresponding to network hierarchy.
-class AllReduceBlueConnect : public HloModulePass {
- public:
-  explicit AllReduceBlueConnect(size_t num_devices_per_host)
-      : num_devices_per_host_(num_devices_per_host) {}
-
-  absl::string_view name() const override { return "all-reduce-blueconnect"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  size_t num_devices_per_host_;
-};
-
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_ALL_REDUCE_BLUECONNECT_H_
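The pass declared in the header removed above appears to be relocated rather than dropped outright, consistent with the gpu/transforms include-path reorganization visible in amdgpu_compiler.cc further down. For orientation only (not part of this patch), a minimal sketch of how a pass with this interface is registered in a pipeline via HloPassPipeline::AddPass, assuming the pre-move header path shown above; the wrapper function name is hypothetical:

// Sketch only, not part of this change. Assumes the pre-move header path
// deleted above; the pass interface is exactly as documented there.
#include <cstddef>

#include "xla/service/gpu/all_reduce_blueconnect.h"
#include "xla/service/hlo_pass_pipeline.h"

namespace xla {

// Hypothetical helper: decompose all-reduces within each host first; the
// residual all-reduce then runs across hosts.
void AddBlueConnectPass(HloPassPipeline& pipeline,
                        std::size_t num_devices_per_host) {
  pipeline.AddPass<AllReduceBlueConnect>(num_devices_per_host);
}

}  // namespace xla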
diff --git a/third_party/xla/xla/service/gpu/all_reduce_blueconnect_test.cc b/third_party/xla/xla/service/gpu/all_reduce_blueconnect_test.cc
deleted file mode 100644
index a6a66c5..0000000
--- a/third_party/xla/xla/service/gpu/all_reduce_blueconnect_test.cc
+++ /dev/null
@@ -1,414 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/all_reduce_blueconnect.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/computation_placer.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-using ::tsl::testing::IsOkAndHolds;
-namespace m = ::xla::match;
-
-using AllReduceBlueConnectTest = HloTestBase;
-
-HloPredicate MatchChannelId(std::optional<int64_t> channel_id) {
-  return [channel_id](const HloInstruction* instruction) {
-    return instruction->channel_id() == channel_id;
-  };
-}
-
-void SetModuleConfig(HloModuleConfig* module_config, size_t replica_count,
-                     size_t partition_count = 1) {
-  DeviceAssignment device_assignment(replica_count,
-                                     /*computation_count=*/partition_count);
-  device_assignment.FillIota(0);
-  module_config->set_replica_count(replica_count);
-  module_config->set_num_partitions(partition_count);
-  module_config->set_static_device_assignment(device_assignment);
-}
-
-void SetModuleConfig(HloModule& module, size_t replica_count,
-                     size_t partition_count = 1) {
-  SetModuleConfig(&module.mutable_config(), replica_count, partition_count);
-}
-
-TEST_F(AllReduceBlueConnectTest, OneStage) {
-  constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
-  p0 = f32[4,4] parameter(0)
-  ROOT crs = f32[4,4] all-reduce(p0), to_apply=add
-})";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  SetModuleConfig(*module, /*replica_count=*/8);
-
-  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
-  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
-
-  // clang-format off
-  std::vector<std::vector<int64_t>> scatter_gather_groups = {
-      {0, 1, 2, 3}, {4, 5, 6, 7}};
-  std::vector<std::vector<int64_t>> new_all_reduce_groups = {
-      {0, 4}, {1, 5}, {2, 6}, {3, 7}};
-  // clang-format on
-
-  auto bitcast = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
-  auto reduce_scatter = m::ReduceScatter(bitcast)
-                            .WithShape(F32, {4})
-                            .WithReplicaGroups(scatter_gather_groups)
-                            .WithPredicate(MatchChannelId(std::nullopt));
-  auto all_reduce = m::AllReduce(reduce_scatter)
-                        .WithShape(F32, {4})
-                        .WithReplicaGroups(new_all_reduce_groups)
-                        .WithPredicate(MatchChannelId(std::nullopt));
-  auto all_gather = m::AllGather(all_reduce)
-                        .WithShape(F32, {16})
-                        .WithReplicaGroups(scatter_gather_groups)
-                        .WithPredicate(MatchChannelId(std::nullopt));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Bitcast(all_gather).WithShape(F32, {4, 4})));
-}
-
-TEST_F(AllReduceBlueConnectTest, TwoStage) {
-  constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
-  p0 = f32[4,4] parameter(0)
-  ROOT crs = f32[4,4] all-reduce(p0), to_apply=add
-})";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  SetModuleConfig(*module, /*replica_count=*/16);
-
-  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
-  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
-
-  std::vector<std::vector<int64_t>> outer_scatter_gather_groups = {
-      {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}};
-  std::vector<std::vector<int64_t>> inner_scatter_gather_groups = {
-      {0, 4}, {8, 12}, {1, 5}, {9, 13}, {2, 6}, {10, 14}, {3, 7}, {11, 15}};
-  std::vector<std::vector<int64_t>> new_all_reduce_groups = {
-      {0, 8}, {4, 12}, {1, 9}, {5, 13}, {2, 10}, {6, 14}, {3, 11}, {7, 15}};
-
-  auto bitcast0 = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
-  auto reduce_scatter0 =
-      m::ReduceScatter(bitcast0).WithShape(F32, {4}).WithReplicaGroups(
-          outer_scatter_gather_groups);
-  auto bitcast1 = m::Bitcast(reduce_scatter0).WithShape(F32, {4});
-  auto reduce_scatter1 =
-      m::ReduceScatter(bitcast1).WithShape(F32, {2}).WithReplicaGroups(
-          inner_scatter_gather_groups);
-  auto all_reduce = m::AllReduce(reduce_scatter1)
-                        .WithShape(F32, {2})
-                        .WithReplicaGroups(new_all_reduce_groups);
-  auto all_gather0 = m::AllGather(all_reduce)
-                         .WithShape(F32, {4})
-                         .WithReplicaGroups(inner_scatter_gather_groups);
-  auto bitcast2 = m::Bitcast(all_gather0).WithShape(F32, {4});
-  auto all_gather1 =
-      m::AllGather(bitcast2).WithShape(F32, {16}).WithReplicaGroups(
-          outer_scatter_gather_groups);
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Bitcast(all_gather1).WithShape(F32, {4, 4})));
-}
-
-TEST_F(AllReduceBlueConnectTest, TwoOperands) {
-  constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
-  p0 = f32[4,4] parameter(0)
-  p1 = f32[4,4,2] parameter(1)
-  ROOT crs = (f32[4,4], f32[4,4,2]) all-reduce(p0, p1), to_apply=add
-})";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  SetModuleConfig(*module, /*replica_count=*/8);
-
-  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
-  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
-
-  // clang-format off
-  std::vector<std::vector<int64_t>> scatter_gather_groups = {
-      {0, 1, 2, 3}, {4, 5, 6, 7}};
-  std::vector<std::vector<int64_t>> new_all_reduce_groups = {
-      {0, 4}, {1, 5}, {2, 6}, {3, 7}};
-  // clang-format on
-
-  auto bitcast0 = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
-  auto bitcast1 = m::Bitcast(m::Parameter(1)).WithShape(F32, {32});
-
-  Shape expected0 = ShapeUtil::MakeTupleShape(
-      {ShapeUtil::MakeShape(F32, {4}), ShapeUtil::MakeShape(F32, {8})});
-  Shape expected1 = ShapeUtil::MakeTupleShape(
-      {ShapeUtil::MakeShape(F32, {16}), ShapeUtil::MakeShape(F32, {32})});
-  auto reduce_scatter = m::ReduceScatter(bitcast0, bitcast1)
-                            .WithShapeEqualTo(&expected0)
-                            .WithReplicaGroups(scatter_gather_groups);
-  auto all_reduce = m::AllReduce(m::GetTupleElement(reduce_scatter, 0),
-                                 m::GetTupleElement(reduce_scatter, 1))
-                        .WithShapeEqualTo(&expected0)
-                        .WithReplicaGroups(new_all_reduce_groups);
-  auto all_gather = m::AllGather(m::GetTupleElement(all_reduce, 0),
-                                 m::GetTupleElement(all_reduce, 1))
-                        .WithShapeEqualTo(&expected1)
-                        .WithReplicaGroups(scatter_gather_groups);
-  auto bitcast2 =
-      m::Bitcast(m::GetTupleElement(all_gather, 0)).WithShape(F32, {4, 4});
-  auto bitcast3 =
-      m::Bitcast(m::GetTupleElement(all_gather, 1)).WithShape(F32, {4, 4, 2});
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(bitcast2, bitcast3)));
-}
-
-TEST_F(AllReduceBlueConnectTest, MultiplePartitionsFilecheck) {
-  constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
-  p0 = f32[8,8] parameter(0)
-  ROOT crs = f32[8,8] all-reduce(p0), channel_id=1,
-    replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=add
-})";
-  HloModuleConfig module_config;
-  SetModuleConfig(&module_config, /*replica_count=*/1, /*partition_count=*/8);
-
-  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
-  // Note: When matching strings like "replica_groups={{0,1,2,3}}", FileCheck
-  // interprets the string inside the double braces as regex. So to match such
-  // strings, we use "replica_groups={{..0,1,2,3..}}", where the dots match the
-  // opening and closing braces.
-  RunAndFilecheckHloRewrite(hlo_string, std::move(pass), R"(
-  CHECK:       %p0 = f32[8,8]{1,0} parameter(0)
-  CHECK-NEXT:  [[bitcast:%[^ ]+]] = f32[64]{0} bitcast(%p0)
-  CHECK-NEXT:  [[reduce_scatter:%[^ ]+]] = f32[16]{0} reduce-scatter([[bitcast]]), channel_id=2, replica_groups={{..0,1,2,3.,.4,5,6,7..}}, use_global_device_ids=true, dimensions={0}, to_apply=%add
-  CHECK-NEXT:  [[all_reduce:%[^ ]+]] = f32[16]{0} all-reduce([[reduce_scatter]]), channel_id=1, replica_groups={{..0,4.,.1,5.,.2,6.,.3,7..}}, use_global_device_ids=true, to_apply=%add
-  CHECK-NEXT:  [[all_gather:%[^ ]+]] = f32[64]{0} all-gather([[all_reduce]]), channel_id=3, replica_groups={{..0,1,2,3.,.4,5,6,7..}}, dimensions={0}, use_global_device_ids=true
-  CHECK-NEXT:  ROOT [[output:%[^ ]+]] = f32[8,8]{1,0} bitcast([[all_gather]])
-}
-)",
-                            /*after_pass_checks=*/nullptr, &module_config);
-}
-
-TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesWithinReplicaGroup) {
-  constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
-  p0 = f32[4,4] parameter(0)
-  ROOT crs = f32[4,4] all-reduce(p0),
-    replica_groups={{0,1,2,7},{3,4,5,6}}, to_apply=add
-})";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  SetModuleConfig(*module, /*replica_count=*/8);
-
-  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
-  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesAcrossReplicaGroups) {
-  constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
-  p0 = f32[4,4] parameter(0)
-  ROOT crs = f32[4,4] all-reduce(p0),
-    replica_groups={{0,1,4,5},{2,3,6,7},{8,9,10,11},{12,13,14,15}}, to_apply=add
-})";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  SetModuleConfig(*module, /*replica_count=*/16);
-
-  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
-  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(AllReduceBlueConnectTest, OperandIndivisible) {
-  constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
-  p0 = f32[4,4] parameter(0)
-  p1 = f32[9] parameter(1)
-  ROOT crs = (f32[4,4], f32[9]) all-reduce(p0, p1), to_apply=add
-})";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  SetModuleConfig(*module, /*replica_count=*/8);
-
-  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
-  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(AllReduceBlueConnectTest, ControlDeps) {
-  constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
-  p0 = f32[4,4] parameter(0)
-  p1 = f32[4,4] parameter(1)
-  add = f32[4,4] add(p0, p1)
-  crs = f32[4,4] all-reduce(p0), to_apply=add, control-predecessors={add}
-  ROOT add1 = f32[4,4] add(crs, add), control-predecessors={crs}
-})";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  SetModuleConfig(*module, /*replica_count=*/8);
-
-  // Remember all-reduce's control succ and preds.
-  const HloInstruction* ar =
-      module->entry_computation()->root_instruction()->operand(0);
-  auto expected_preds = ar->control_predecessors();
-  auto expected_succs = ar->control_successors();
-
-  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
-  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
-
-  // clang-format off
-  std::vector<std::vector<int64_t>> scatter_gather_groups = {
-      {0, 1, 2, 3}, {4, 5, 6, 7}};
-  std::vector<std::vector<int64_t>> new_all_reduce_groups = {
-      {0, 4}, {1, 5}, {2, 6}, {3, 7}};
-  // clang-format on
-
-  const HloInstruction *matched_rs, *matched_bitcast;
-  auto bitcast = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
-  auto reduce_scatter = m::ReduceScatter(&matched_rs, bitcast)
-                            .WithShape(F32, {4})
-                            .WithReplicaGroups(scatter_gather_groups);
-  auto all_reduce = m::AllReduce(reduce_scatter)
-                        .WithShape(F32, {4})
-                        .WithReplicaGroups(new_all_reduce_groups);
-  auto all_gather = m::AllGather(all_reduce)
-                        .WithShape(F32, {16})
-                        .WithReplicaGroups(scatter_gather_groups);
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  ASSERT_THAT(root, GmockMatch(m::Add()));
-
-  EXPECT_THAT(
-      root->operand(0),
-      GmockMatch(
-          m::Bitcast(&matched_bitcast, all_gather).WithShape(F32, {4, 4})));
-
-  // Verify that control dependencies are transferred correctly.
-  EXPECT_THAT(matched_rs, GmockMatch(m::Op().WithControlDeps(
-                              absl::MakeSpan(expected_preds), {})));
-  EXPECT_THAT(matched_bitcast, GmockMatch(m::Op().WithControlDeps(
-                                   {}, absl::MakeSpan(expected_succs))));
-}
-
-TEST_F(AllReduceBlueConnectTest, ReduceScatterUnchanged) {
-  // Tests that this pass does not affect reduce-scatter. In principle, the
-  // BlueConnect algorithm could be applied to reduce-scatter, but for now it
-  // doesn't.
-  constexpr absl::string_view hlo_string = R"(
-HloModule module
-
-%add {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT add = f32[] add(lhs, rhs)
-}
-
-ENTRY %comp {
-  p0 = f32[8,4] parameter(0)
-  ROOT crs = f32[1,4] reduce-scatter(p0), dimensions={0}, to_apply=add
-})";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  SetModuleConfig(*module, /*replica_count=*/8);
-
-  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
-  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
-}
-
-}  // namespace
-}  // namespace xla
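The replica groups asserted by the deleted tests above follow directly from the BlueConnect grouping: with 8 replicas and 4 devices per host, the first stage reduce-scatters and all-gathers within each host, and the residual all-reduce connects the devices that hold the same shard on different hosts. A small standalone sketch (illustrative only, not part of the patch) that derives the {0,1,2,3},{4,5,6,7} and {0,4},{1,5},{2,6},{3,7} groups used in the OneStage test:

// Sketch only: derives the replica groups expected by the OneStage test.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const int kNumReplicas = 8;
  const int kDevicesPerHost = 4;
  std::vector<std::vector<int>> scatter_gather_groups;
  std::vector<std::vector<int>> all_reduce_groups;
  // One reduce-scatter/all-gather group per host: consecutive device ids.
  for (int first = 0; first < kNumReplicas; first += kDevicesPerHost) {
    std::vector<int> group;
    for (int i = 0; i < kDevicesPerHost; ++i) group.push_back(first + i);
    scatter_gather_groups.push_back(group);
  }
  // One residual all-reduce group per local shard index: the devices at the
  // same offset on every host.
  for (int local = 0; local < kDevicesPerHost; ++local) {
    std::vector<int> group;
    for (int host = 0; host < kNumReplicas / kDevicesPerHost; ++host) {
      group.push_back(host * kDevicesPerHost + local);
    }
    all_reduce_groups.push_back(group);
  }
  auto print = [](const char* name,
                  const std::vector<std::vector<int>>& groups) {
    std::printf("%s:", name);
    for (const auto& g : groups) {
      std::printf(" {");
      for (std::size_t i = 0; i < g.size(); ++i) {
        std::printf(i == 0 ? "%d" : ",%d", g[i]);
      }
      std::printf("}");
    }
    std::printf("\n");
  };
  print("scatter_gather_groups", scatter_gather_groups);  // {0,1,2,3} {4,5,6,7}
  print("all_reduce_groups", all_reduce_groups);          // {0,4} {1,5} {2,6} {3,7}
  return 0;
}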
diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc
index 04483ce..ae541ba 100644
--- a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc
@@ -32,21 +32,21 @@
 #include "xla/service/dot_dimension_merger.h"
 #include "xla/service/float_normalization.h"
 #include "xla/service/float_support.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/conv_algorithm_picker.h"
-#include "xla/service/gpu/cublas_pad_for_gemms.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/autotuning/conv_algorithm_picker.h"
+#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h"
 #include "xla/service/gpu/cublas_padding_requirements.h"
-#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
-#include "xla/service/gpu/cusolver_rewriter.h"
-#include "xla/service/gpu/gemm_algorithm_picker.h"
-#include "xla/service/gpu/gpu_algebraic_simplifier.h"
 #include "xla/service/gpu/gpu_compiler.h"
-#include "xla/service/gpu/gpu_conv_padding_legalization.h"
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-#include "xla/service/gpu/gpu_sort_rewriter.h"
 #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "xla/service/gpu/target_constants.h"
-#include "xla/service/gpu/triangular_solve_rewriter.h"
+#include "xla/service/gpu/transforms/algebraic_simplifier.h"
+#include "xla/service/gpu/transforms/conv_padding_legalization.h"
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h"
+#include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h"
+#include "xla/service/gpu/transforms/gpusolver_rewriter.h"
+#include "xla/service/gpu/transforms/sort_rewriter.h"
+#include "xla/service/gpu/transforms/triangular_solve_rewriter.h"
 #include "xla/service/hlo_constant_folding.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/service/hlo_pass_fix.h"
@@ -123,8 +123,8 @@
   pipeline.AddPass<FloatNormalization>(&conv_bf16_support);
 
   pipeline.AddPass<GpusolverRewriter>();
-  pipeline.AddPass<GpuConvRewriter>(gpu_version);
-  pipeline.AddPass<GpuConvPaddingLegalization>();
+  pipeline.AddPass<ConvRewriter>(gpu_version);
+  pipeline.AddPass<ConvPaddingLegalization>();
   auto rcc = std::get<se::RocmComputeCapability>(gpu_version);
   pipeline.AddPass<CudnnFusedConvRewriter>(rcc, dnn_version,
                                            GetToolkitVersion());
@@ -135,7 +135,7 @@
   pipeline.AddPass<CallInliner>();
   pipeline.AddPass<TupleSimplifier>();
 
-  // tf2xla bridge, DepthwiseConvolutionConverter and GpuConvRewriter
+  // tf2xla bridge, DepthwiseConvolutionConverter and ConvRewriter
   // introduce reshapes and transposes that can be eliminated using
   // AlgebraicSimplifier. We run algsimp to a fixed point.
   AlgebraicSimplifierOptions options =
@@ -144,7 +144,7 @@
   options.set_enable_unconditional_reduce_of_concat_replacement(false);
   pipeline.AddPass<HloPassFix<GpuAlgebraicSimplifier>>(options, gpu_version);
 
-  // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and
+  // tf2xla bridge, DepthwiseConvolutionConverter, ConvRewriter, and
   // CudnnSimplifyPadding introduce reshapes and transposes.  Run ReshapeMover
   // to a fixed point.  Include algsimp because ReshapeMover relies on it.
   [&, &pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
@@ -166,7 +166,7 @@
     pipeline.AddPass<GpuAlgebraicSimplifier>(options, gpu_version);
   }();
 
-  // GpuConvRewriter, GpuConvPaddingLegalization and
+  // ConvRewriter, ConvPaddingLegalization and
   // CudnnConvPadForTensorCores may add instructions which can be simplified
   // by constant folding.
   pipeline.AddPass<HloConstantFolding>();
@@ -240,7 +240,7 @@
 absl::Status AMDGPUCompiler::AddCustomKernelReplacementPasses(
     HloPassPipeline* pipeline, const DebugOptions& debug_options) {
   if (debug_options.xla_gpu_enable_cub_radix_sort()) {
-    pipeline->AddPass<GpuSortRewriter>();
+    pipeline->AddPass<SortRewriter>();
   }
   return absl::OkStatus();
 }
diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.h b/third_party/xla/xla/service/gpu/amdgpu_compiler.h
index 483647b..062a0ef 100644
--- a/third_party/xla/xla/service/gpu/amdgpu_compiler.h
+++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.h
@@ -20,7 +20,7 @@
 #include "absl/status/statusor.h"
 #include "llvm/IR/Module.h"
 #include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
 #include "xla/service/gpu/gpu_compiler.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/service/hlo_pass_pipeline.h"
diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuner_compile_util.cc
deleted file mode 100644
index b3e880b..0000000
--- a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc
+++ /dev/null
@@ -1,284 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/autotuner_compile_util.h"
-
-#include <cstdint>
-#include <iterator>
-#include <memory>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/time/time.h"
-#include "absl/types/span.h"
-#include "xla/executable_run_options.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/compiler.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/gpu_executable_run_options.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/maybe_owning_device_memory.h"
-#include "xla/service/service_executable_run_options.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/util.h"
-#include "xla/xla.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-std::vector<ExecutionInput> ExecutionInputsFromBuffers(
-    absl::Span<se::DeviceMemoryBase const> buffers,
-    absl::Span<Shape const> shapes) {
-  CHECK_EQ(buffers.size(), shapes.size());
-  std::vector<ExecutionInput> inputs;
-  for (int i = 0; i < buffers.size(); ++i) {
-    inputs.emplace_back(shapes.at(i));
-    // Our executable doesn't have input-output aliasing, so we can pass
-    // unowned input buffers.
-    inputs.back().SetUnownedBuffer(
-        /*index=*/{}, MaybeOwningDeviceMemory(/*unowned=*/buffers.at(i)));
-  }
-  return inputs;
-}
-
-}  // namespace
-
-AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config,
-                                           Compiler* compiler,
-                                           se::StreamExecutor& stream_executor,
-                                           se::Stream& stream,
-                                           se::DeviceMemoryAllocator& allocator,
-                                           const DebugOptions& opts)
-    : config_(config),
-      compiler_(compiler),
-      stream_executor_(stream_executor),
-      stream_(stream),
-      allocator_(allocator),
-      opts_(opts) {
-  // Avoid dumping compilation steps.
-  opts_.set_xla_enable_dumping(false);
-  opts_.set_xla_gpu_dump_autotune_results_to("");
-  opts_.set_xla_gpu_load_autotune_results_from("");
-  opts_.set_xla_gpu_dump_llvmir(false);
-  opts_.set_xla_gpu_dump_autotune_logs_to("");
-  // Avoid using another thread pool.
-  opts_.set_xla_gpu_force_compilation_parallelism(1);
-  opts_.set_xla_gpu_enable_llvm_module_compilation_parallelism(false);
-  // Avoid using GPU graphs as we don't want to measure graph construction time.
-  opts_.clear_xla_gpu_enable_command_buffer();
-  opts_.set_xla_embed_ir_in_executable(false);
-  opts_.set_xla_gpu_kernel_cache_file("");
-}
-
-absl::StatusOr<std::optional<AutotunerCompileUtil::ProfilingOutput>>
-AutotunerCompileUtil::ProfileExecutable(
-    Executable* executable, se::Stream* stream,
-    absl::Span<se::DeviceMemoryBase const> input_buffers,
-    absl::Span<Shape const> input_shapes) {
-  {
-    std::vector<ExecutionInput> execution_inputs =
-        ExecutionInputsFromBuffers(input_buffers, input_shapes);
-    // Warmup: in and out buffers are reused while probing different configs,
-    // so GPU caches should be in some comparable states during measurements.
-    absl::StatusOr<ExecutionOutput> execution_output =
-        Execute(*executable, std::move(execution_inputs));
-    if (!execution_output.ok()) {
-      // Treat register allocation error gracefully. If the compilation happens
-      // with the driver during execution then the error could surface here.
-      // It's enough to check this once here.
-      if (execution_output.status().code() ==
-          absl::StatusCode::kResourceExhausted) {
-        return {std::nullopt};
-      }
-      return execution_output.status();
-    }
-
-    TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
-  }
-  std::vector<ExecutionInput> execution_inputs =
-      ExecutionInputsFromBuffers(input_buffers, input_shapes);
-  ExecutionProfile profile;
-  // Flag that a warm-up run was executed so that GpuTimer can use the more
-  // accurate delay-kernel implementation.
-  profile.set_warmup_run_executed(true);
-  TF_ASSIGN_OR_RETURN(
-      ExecutionOutput execution_output,
-      Execute(*executable, std::move(execution_inputs), &profile));
-  return std::make_optional<ProfilingOutput>(
-      absl::Nanoseconds(profile.compute_time_ns()),
-      execution_output.Commit().ConsumeResult());
-}
-
-absl::StatusOr<std::unique_ptr<Executable>> AutotunerCompileUtil::Compile(
-    GenerateModuleFn extractor) {
-  absl::StatusOr<std::unique_ptr<HloModule>> new_hlo_module = extractor(opts_);
-  if (new_hlo_module.status().GetPayload(kUncompilableFusion).has_value()) {
-    // Incompatible value of split-k is an example of an expected failure.
-    return std::unique_ptr<Executable>();
-  } else if (!new_hlo_module.status().ok()) {
-    return new_hlo_module.status();
-  }
-
-  absl::StatusOr<std::unique_ptr<Executable>> out = compiler_->RunBackend(
-      std::move(*new_hlo_module), &stream_executor_,
-      Compiler::CompileOptions{&allocator_, /*thread_pool=*/nullptr,
-                               /*layout_canonicalization_callback=*/{},
-                               /*is_autotuning_compilation=*/true});
-  if (out.status().code() == absl::StatusCode::kResourceExhausted ||
-      out.status().code() == absl::StatusCode::kCancelled) {
-    // Being out of shared memory budget or registers is an expected failure.
-    // Cancelling upon register spilling is also an expected failure.
-    return std::unique_ptr<Executable>();
-  }
-  return out;
-}
-
-absl::StatusOr<std::unique_ptr<HloModule>> AutotunerCompileUtil::ExtractModule(
-    GenerateModuleFn extractor) {
-  return extractor(opts_);
-}
-
-/*static*/ absl::StatusOr<std::optional<AutotunerCompileUtil>>
-AutotunerCompileUtil::Create(const AutotuneConfig& config,
-                             const DebugOptions& opts) {
-  if (config.IsDeviceless()) {
-    return std::nullopt;
-  }
-  se::StreamExecutor* stream_exec = config.GetExecutor();
-  se::DeviceMemoryAllocator* allocator = config.GetAllocator();
-  TF_ASSIGN_OR_RETURN(se::Stream* const stream, config.GetStream());
-  TF_ASSIGN_OR_RETURN(Compiler * compiler,
-                      Compiler::GetForPlatform(stream_exec->GetPlatform()));
-  return AutotunerCompileUtil(config, compiler, *stream_exec, *stream,
-                              *allocator, opts);
-}
-
-absl::StatusOr<ExecutionOutput> AutotunerCompileUtil::Execute(
-    Executable& executable, std::vector<ExecutionInput> arguments,
-    ExecutionProfile* profile) {
-  // Require exclusive GPU lock to prevent other runs during autotuning.
-  GpuExecutableRunOptions gpu_opts;
-  gpu_opts.set_requires_exclusive_lock_on_gpu();
-
-  ExecutableRunOptions run_options;
-  run_options.set_device_ordinal(stream_executor_.device_ordinal());
-  run_options.set_stream(&stream_);
-  run_options.set_allocator(&allocator_);
-  run_options.set_gpu_executable_run_options(&gpu_opts);
-  run_options.set_execution_profile(profile);
-  ServiceExecutableRunOptions service_run_options(run_options);
-  TF_ASSIGN_OR_RETURN(ExecutionOutput output,
-                      executable.ExecuteAsyncOnStreamWrapper(
-                          &service_run_options, std::move(arguments)));
-
-  return std::move(output);
-}
-
-absl::StatusOr<RedzoneBuffers> RedzoneBuffers::FromInstruction(
-    const HloInstruction& instruction, const AutotuneConfig& config,
-    const DebugOptions& debug_options, BuffersToCreate buffers_to_create) {
-  RedzoneBuffers buffers;
-
-  TF_ASSIGN_OR_RETURN(auto rz_allocator, AutotunerUtil::CreateRedzoneAllocator(
-                                             config, debug_options));
-  buffers.redzone_allocator_ =
-      std::make_unique<se::RedzoneAllocator>(std::move(rz_allocator));
-
-  int64_t rng_state = 0;
-
-  TF_RETURN_IF_ERROR(
-      buffers.CreateInputs(instruction, config, debug_options, rng_state));
-
-  if (buffers_to_create == BuffersToCreate::kAllInputsAllOutputs ||
-      buffers_to_create == BuffersToCreate::kAllInputsOutputsNoScratch) {
-    TF_RETURN_IF_ERROR(buffers.CreateOutputs(instruction, config, debug_options,
-                                             buffers_to_create, rng_state));
-  }
-
-  return buffers;
-}
-
-absl::Status RedzoneBuffers::CreateInputs(const HloInstruction& instruction,
-                                          const AutotuneConfig& config,
-                                          const DebugOptions& debug_options,
-                                          int64_t& rng_state) {
-  for (const auto* operand : instruction.operands()) {
-    TF_ASSIGN_OR_RETURN(
-        se::DeviceMemoryBase buf,
-        AutotunerUtil::CreateBuffer(*redzone_allocator_, operand->shape(),
-                                    config, rng_state));
-    input_buffers_.push_back(buf);
-    input_shapes_.push_back(operand->shape());
-  }
-  return absl::OkStatus();
-}
-
-absl::Status RedzoneBuffers::CreateOutputs(const HloInstruction& instruction,
-                                           const AutotuneConfig& config,
-                                           const DebugOptions& debug_options,
-                                           BuffersToCreate buffers_to_create,
-                                           int64_t& rng_state) {
-  if (!instruction.shape().IsTuple()) {
-    TF_ASSIGN_OR_RETURN(
-        se::DeviceMemoryBase buf,
-        AutotunerUtil::CreateBuffer(*redzone_allocator_, instruction.shape(),
-                                    config, rng_state));
-    output_buffers_.push_back(buf);
-    output_shape_ = instruction.shape();
-    return absl::OkStatus();
-  }
-
-  // The output is a tuple.
-
-  auto current_shape_it = instruction.shape().tuple_shapes().begin();
-  auto end = instruction.shape().tuple_shapes().end();
-  end -= buffers_to_create == kAllInputsAllOutputs ? 0 : 1;
-
-  output_shape_ = std::distance(current_shape_it, end) == 1
-                      ? *current_shape_it
-                      : ShapeUtil::MakeTupleShape(
-                            std::vector<Shape>{current_shape_it, end});
-
-  for (; current_shape_it < end; current_shape_it++) {
-    if (current_shape_it->IsTuple()) {
-      return Unimplemented("Nested tuples are unsupported by RedzoneBuffers.");
-    }
-    TF_ASSIGN_OR_RETURN(
-        se::DeviceMemoryBase buf,
-        AutotunerUtil::CreateBuffer(*redzone_allocator_, *current_shape_it,
-                                    config, rng_state));
-    output_buffers_.push_back(buf);
-  }
-  return absl::OkStatus();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util.h b/third_party/xla/xla/service/gpu/autotuner_compile_util.h
deleted file mode 100644
index 5137fcf..0000000
--- a/third_party/xla/xla/service/gpu/autotuner_compile_util.h
+++ /dev/null
@@ -1,177 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_
-#define XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_
-
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/functional/any_invocable.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/time/time.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/compiler.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/shaped_buffer.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/util.h"
-#include "xla/xla.pb.h"
-
-namespace xla {
-namespace gpu {
-
-// Autotuning utils which require compiling fusions separately. Requires a
-// separate target, as runtime autotuning cannot perform compilation.
-class AutotunerCompileUtil {
- public:
-  // The GenerateModuleFn must generate/extract a module using the provided
-  // debug options. Typically it should set the debug options of the extracted
-  // module before it would transform it, to ensure that the transforms can use
-  // the debug options. In justified cases, it may override some of the provided
-  // debug options.
-  using GenerateModuleFn =
-      absl::AnyInvocable<absl::StatusOr<std::unique_ptr<HloModule>>(
-          const DebugOptions&)>;
-
-  // Generates a compile util for a platform associated with the `stream`.
-  //
-  // Returns an empty optional if the AutotuneConfig is deviceless, as
-  // autotuning is impossible in that case.
-  static absl::StatusOr<std::optional<AutotunerCompileUtil>> Create(
-      const AutotuneConfig& config, const DebugOptions& opts);
-
-  struct ProfilingOutput {
-    ProfilingOutput(absl::Duration duration, ScopedShapedBuffer&& buffer)
-        : duration(duration), output(std::move(buffer)) {}
-
-    absl::Duration duration;
-    ScopedShapedBuffer output;
-  };
-
-  // Profiles `executable` on `stream` with the given input buffers, running
-  // it once as a warm-up and then once under a timer.
-  //
-  // Returns `std::nullopt` on expected failure (e.g. running out of
-  // resources), bad `Status` otherwise.
-  absl::StatusOr<std::optional<ProfilingOutput>> ProfileExecutable(
-      Executable* executable, se::Stream* stream,
-      absl::Span<se::DeviceMemoryBase const> input_buffers,
-      absl::Span<Shape const> input_shapes);
-
-  // Generic method to compile a generated module from `extractor` in isolation.
-  //
-  // Returns:
-  //  - `nullptr` on *expected* failure
-  //  - `Executable` if everything goes fine.
-  //  - `Status` on *unexpected* failure.
-  absl::StatusOr<std::unique_ptr<Executable>> Compile(
-      GenerateModuleFn extractor);
-
-  // Generic method to extract an HLO using the debug options of the
-  // AutotunerCompileUtil.
-  //
-  // Typically we can use Compile directly.
-  absl::StatusOr<std::unique_ptr<HloModule>> ExtractModule(
-      GenerateModuleFn extractor);
-
- private:
-  AutotunerCompileUtil(const AutotuneConfig& config, Compiler* compiler,
-                       se::StreamExecutor& stream_executor, se::Stream& stream,
-                       se::DeviceMemoryAllocator& allocator,
-                       const DebugOptions& opts);
-
-  absl::StatusOr<ExecutionOutput> Execute(Executable& executable,
-                                          std::vector<ExecutionInput> arguments,
-                                          ExecutionProfile* profile = nullptr);
-
-  AutotuneConfig config_;
-  Compiler* compiler_;
-  se::StreamExecutor& stream_executor_;
-  se::Stream& stream_;
-  se::DeviceMemoryAllocator& allocator_;
-  DebugOptions opts_;
-};
-
-// A RedZone allocator and a collection of buffers that store the inputs and
-// outputs of an HloInstruction. These are used when running the instruction
-// for autotuning.
-class RedzoneBuffers {
- public:
-  enum BuffersToCreate {
-    // Create a buffer for all of the instruction's operands. The result shape
-    // is ignored.
-    kAllInputs = 0,
-    // Create a buffer for all of the instruction's operands and the entire
-    // result shape. If the result shape is a tuple, a separate buffer is
-    // created for each subshape.
-    kAllInputsAllOutputs = 1,
-    // Create a buffer for all of the instruction's operands and all of the
-    // subshapes of the result tuple, except for the last one. The last subshape
-    // is considered a scratch buffer and is assumed to be allocated elsewhere.
-    // If the result shape is not a tuple, this will create a buffer
-    // corresponding to the entire shape - equivalent to `kAllInputsAllOutputs`.
-    kAllInputsOutputsNoScratch = 2,
-  };
-  static absl::StatusOr<RedzoneBuffers> FromInstruction(
-      const HloInstruction& instruction, const AutotuneConfig& config,
-      const DebugOptions& debug_options, BuffersToCreate buffers_to_create);
-
-  const std::vector<se::DeviceMemoryBase>& input_buffers() const {
-    return input_buffers_;
-  }
-  const std::vector<Shape>& input_shapes() const { return input_shapes_; }
-  const std::vector<se::DeviceMemoryBase>& output_buffers() const {
-    return output_buffers_;
-  }
-  const Shape& output_shape() const { return output_shape_; }
-  se::RedzoneAllocator& RedzoneAllocator() const { return *redzone_allocator_; }
-
- private:
-  absl::Status CreateInputs(const HloInstruction& instruction,
-                            const AutotuneConfig& config,
-                            const DebugOptions& debug_options,
-                            int64_t& rng_state);
-
-  absl::Status CreateOutputs(const HloInstruction& instruction,
-                             const AutotuneConfig& config,
-                             const DebugOptions& debug_options,
-                             BuffersToCreate buffers_to_create,
-                             int64_t& rng_state);
-
-  std::unique_ptr<se::RedzoneAllocator> redzone_allocator_;
-  std::vector<se::DeviceMemoryBase> input_buffers_;
-  std::vector<Shape> input_shapes_;
-  std::vector<se::DeviceMemoryBase> output_buffers_;
-  Shape output_shape_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_
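The deleted header above documents the flow the GPU autotuners follow: Create() returns an empty optional for deviceless configs, Compile() returns a null executable on expected compilation failures, and ProfileExecutable() does a warm-up run followed by a timed run. A hedged sketch of that flow using only the signatures shown above; the helper name and its call site are hypothetical, and the config, debug options, buffers, and module-generator callback are assumed to come from the surrounding autotuner (this snippet is not part of the patch):

// Sketch only: exercises the AutotunerCompileUtil API from the header
// deleted above (pre-move path); ProfileOneCandidate is a hypothetical name.
#include <memory>
#include <optional>
#include <utility>

#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/autotuner_compile_util.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream.h"
#include "xla/xla.pb.h"
#include "tsl/platform/statusor.h"

absl::StatusOr<std::optional<absl::Duration>> ProfileOneCandidate(
    const xla::gpu::AutotuneConfig& config, const xla::DebugOptions& opts,
    xla::gpu::AutotunerCompileUtil::GenerateModuleFn extractor,
    absl::Span<stream_executor::DeviceMemoryBase const> input_buffers,
    absl::Span<xla::Shape const> input_shapes) {
  // Deviceless configs cannot autotune: Create() returns an empty optional.
  TF_ASSIGN_OR_RETURN(std::optional<xla::gpu::AutotunerCompileUtil> util,
                      xla::gpu::AutotunerCompileUtil::Create(config, opts));
  if (!util.has_value()) return std::nullopt;
  // A null executable signals an expected, non-fatal compilation failure
  // (e.g. running out of shared memory or registers).
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Executable> executable,
                      util->Compile(std::move(extractor)));
  if (executable == nullptr) return std::nullopt;
  TF_ASSIGN_OR_RETURN(stream_executor::Stream* const stream,
                      config.GetStream());
  // Warm-up run followed by a timed run; std::nullopt again marks an
  // expected failure surfacing at execution time.
  TF_ASSIGN_OR_RETURN(
      std::optional<xla::gpu::AutotunerCompileUtil::ProfilingOutput> output,
      util->ProfileExecutable(executable.get(), stream, input_buffers,
                              input_shapes));
  if (!output.has_value()) return std::nullopt;
  return output->duration;
}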
diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util_test.cc b/third_party/xla/xla/service/gpu/autotuner_compile_util_test.cc
deleted file mode 100644
index 1db5afb..0000000
--- a/third_party/xla/xla/service/gpu/autotuner_compile_util_test.cc
+++ /dev/null
@@ -1,196 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/autotuner_compile_util.h"
-
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/platform_util.h"
-#include "xla/stream_executor/platform.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-using AutotunerCompileUtilTest = HloTestBase;
-
-TEST_F(AutotunerCompileUtilTest, VerifyOutputNotATuple) {
-  constexpr absl::string_view kHlo = R"(
-HloModule hlo
-ENTRY main {
-  p0 = f32[2,2] parameter(0)
-  p1 = f32[4,4] parameter(1)
-  p2 = f32[6,6] parameter(2)
-  ROOT root = f32[1,2,3] custom-call(p0, p1, p2), custom_call_target="fake"
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHlo));
-
-  se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
-  TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
-                          PlatformUtil::GetStreamExecutors(platform));
-
-  AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
-                                 GetDebugOptionsForTest()};
-
-  auto& root = *module->entry_computation()->root_instruction();
-
-  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb,
-                          RedzoneBuffers::FromInstruction(
-                              root, autotune_config, GetDebugOptionsForTest(),
-                              RedzoneBuffers::kAllInputs));
-
-  EXPECT_EQ(rzb.input_shapes().size(), 3);
-  EXPECT_EQ(rzb.input_buffers().size(), 3);
-  EXPECT_EQ(rzb.output_buffers().size(), 0);
-  EXPECT_NE(rzb.output_shape(), root.shape());
-
-  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb2,
-                          RedzoneBuffers::FromInstruction(
-                              root, autotune_config, GetDebugOptionsForTest(),
-                              RedzoneBuffers::kAllInputsAllOutputs));
-
-  EXPECT_EQ(rzb2.input_shapes().size(), 3);
-  EXPECT_EQ(rzb2.input_buffers().size(), 3);
-  EXPECT_EQ(rzb2.output_buffers().size(), 1);
-  EXPECT_EQ(rzb2.output_shape(), root.shape());
-
-  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb3,
-                          RedzoneBuffers::FromInstruction(
-                              root, autotune_config, GetDebugOptionsForTest(),
-                              RedzoneBuffers::kAllInputsOutputsNoScratch));
-
-  EXPECT_EQ(rzb3.input_shapes().size(), 3);
-  EXPECT_EQ(rzb3.input_buffers().size(), 3);
-  EXPECT_EQ(rzb3.output_buffers().size(), 1);
-  EXPECT_EQ(rzb3.output_shape(), root.shape());
-}
-
-TEST_F(AutotunerCompileUtilTest, VerifyOutputTupleOneElement) {
-  constexpr absl::string_view kHlo = R"(
-HloModule hlo
-ENTRY main {
-  p0 = f32[2,2] parameter(0)
-  p1 = f32[4,4] parameter(1)
-  p2 = f32[6,6] parameter(2)
-  ROOT root = (f32[1,2,3]) custom-call(p0, p1, p2), custom_call_target="fake"
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHlo));
-
-  se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
-  TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
-                          PlatformUtil::GetStreamExecutors(platform));
-
-  AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
-                                 GetDebugOptionsForTest()};
-
-  auto& root = *module->entry_computation()->root_instruction();
-
-  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb,
-                          RedzoneBuffers::FromInstruction(
-                              root, autotune_config, GetDebugOptionsForTest(),
-                              RedzoneBuffers::kAllInputs));
-
-  EXPECT_EQ(rzb.input_shapes().size(), 3);
-  EXPECT_EQ(rzb.input_buffers().size(), 3);
-  EXPECT_EQ(rzb.output_buffers().size(), 0);
-  EXPECT_NE(rzb.output_shape(), root.shape());
-
-  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb2,
-                          RedzoneBuffers::FromInstruction(
-                              root, autotune_config, GetDebugOptionsForTest(),
-                              RedzoneBuffers::kAllInputsAllOutputs));
-
-  EXPECT_EQ(rzb2.input_shapes().size(), 3);
-  EXPECT_EQ(rzb2.input_buffers().size(), 3);
-  EXPECT_EQ(rzb2.output_buffers().size(), 1);
-  EXPECT_FALSE(rzb2.output_shape().IsTuple());
-  EXPECT_EQ(rzb2.output_shape(), root.shape().tuple_shapes(0));
-
-  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb3,
-                          RedzoneBuffers::FromInstruction(
-                              root, autotune_config, GetDebugOptionsForTest(),
-                              RedzoneBuffers::kAllInputsOutputsNoScratch));
-
-  EXPECT_EQ(rzb3.input_shapes().size(), 3);
-  EXPECT_EQ(rzb3.input_buffers().size(), 3);
-  EXPECT_EQ(rzb3.output_buffers().size(), 0);
-}
-
-TEST_F(AutotunerCompileUtilTest, VerifyOutputTupleTwoElements) {
-  constexpr absl::string_view kHlo = R"(
-HloModule hlo
-ENTRY main {
-  p0 = f32[2,2] parameter(0)
-  p1 = f32[4,4] parameter(1)
-  p2 = f32[6,6] parameter(2)
-  ROOT root = (f32[1,2,3], u8[1,2]) custom-call(p0, p1, p2), custom_call_target="fake"
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHlo));
-
-  se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
-  TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
-                          PlatformUtil::GetStreamExecutors(platform));
-
-  AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
-                                 GetDebugOptionsForTest()};
-
-  auto& root = *module->entry_computation()->root_instruction();
-
-  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb,
-                          RedzoneBuffers::FromInstruction(
-                              root, autotune_config, GetDebugOptionsForTest(),
-                              RedzoneBuffers::kAllInputs));
-
-  EXPECT_EQ(rzb.input_shapes().size(), 3);
-  EXPECT_EQ(rzb.input_buffers().size(), 3);
-  EXPECT_EQ(rzb.output_buffers().size(), 0);
-  EXPECT_NE(rzb.output_shape(), root.shape());
-
-  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb2,
-                          RedzoneBuffers::FromInstruction(
-                              root, autotune_config, GetDebugOptionsForTest(),
-                              RedzoneBuffers::kAllInputsAllOutputs));
-
-  EXPECT_EQ(rzb2.input_shapes().size(), 3);
-  EXPECT_EQ(rzb2.input_buffers().size(), 3);
-  EXPECT_EQ(rzb2.output_buffers().size(), 2);
-  EXPECT_TRUE(rzb2.output_shape().IsTuple());
-  EXPECT_EQ(rzb2.output_shape(), root.shape());
-
-  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb3,
-                          RedzoneBuffers::FromInstruction(
-                              root, autotune_config, GetDebugOptionsForTest(),
-                              RedzoneBuffers::kAllInputsOutputsNoScratch));
-
-  EXPECT_EQ(rzb3.input_shapes().size(), 3);
-  EXPECT_EQ(rzb3.input_buffers().size(), 3);
-  EXPECT_EQ(rzb3.output_buffers().size(), 1);
-  EXPECT_FALSE(rzb3.output_shape().IsTuple());
-  EXPECT_EQ(rzb3.output_shape(), root.shape().tuple_shapes(0));
-}
-
-}  // namespace
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/autotuner_util.cc b/third_party/xla/xla/service/gpu/autotuner_util.cc
deleted file mode 100644
index 93c946f..0000000
--- a/third_party/xla/xla/service/gpu/autotuner_util.cc
+++ /dev/null
@@ -1,547 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/autotuner_util.h"
-
-#include <algorithm>
-#include <array>
-#include <cmath>
-#include <cstdint>
-#include <limits>
-#include <optional>
-#include <string>
-#include <utility>
-#include <variant>
-
-#include "absl/base/const_init.h"
-#include "absl/base/thread_annotations.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/match.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/SHA256.h"
-#include "xla/autotune_results.pb.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_asm_opts_util.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/util.h"
-#include "tsl/platform/base64.h"
-#include "tsl/platform/env.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/path.h"
-#include "tsl/platform/protobuf.h"  // IWYU pragma: keep
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-// Bump this version whenever you change the structure of the results.
-// LINT.IfChange(version)
-constexpr int kVersion = 3;
-// LINT.ThenChange()
-
-}  // namespace
-
-using AutotuneCacheMap = absl::flat_hash_map<AutotuneCacheKey, AutotuneResult>;
-
-static absl::Mutex autotune_cache_mu(absl::kConstInit);
-static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) =
-    *new AutotuneCacheMap();
-
-absl::StatusOr<std::string> GetBase64EncodedSha256Hash(absl::string_view s) {
-  llvm::SHA256 sha256;
-  sha256.update(llvm::StringRef(s));
-  std::array<uint8_t, 32> hash = sha256.final();
-  // C++ strict aliasing rules allow reinterpret casting to (const) char*.
-  absl::string_view hash_view(reinterpret_cast<const char*>(hash.data()),
-                              hash.size());
-  std::string base64_encoded_hash;
-  TF_RETURN_IF_ERROR(tsl::Base64Encode(hash_view, &base64_encoded_hash));
-  return base64_encoded_hash;
-}
-
-namespace {
-
-// Get the path corresponding to the given key.
-absl::StatusOr<std::string> GetCacheFilePath(absl::string_view cache_dir,
-                                             const AutotuneCacheKey& key) {
-  if (cache_dir.empty()) {
-    return absl::InvalidArgumentError("autotune_cache_dir should not be empty");
-  }
-
-  TF_ASSIGN_OR_RETURN(std::string key_hash,
-                      GetBase64EncodedSha256Hash(key.ToString()));
-  return tsl::io::JoinPath(cache_dir, absl::StrCat(key_hash, ".textproto"));
-}
-
-struct ResultAndInserted {
-  // The result that ended up in the cache. This is the existing result if
-  // inserted is false, and the new result if inserted is true.
-  //
-  // We return a value, not a pointer, for thread safety reasons.
-  AutotuneResult result;
-  // Did we insert the given result into the cache?
-  bool inserted;
-};
-
-ResultAndInserted AddResultToInMemoryCache(const AutotuneCacheKey& key,
-                                           AutotuneResult result)
-    ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
-  absl::MutexLock lock(&autotune_cache_mu);
-  auto [it, inserted] = autotune_cache.emplace(key, std::move(result));
-  return {it->second, inserted};
-}
-
-absl::Status AddResultToFileBasedCacheIfEnabled(const AutotuneCacheKey& key,
-                                                AutotuneResult result,
-                                                std::string_view cache_dir)
-    ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
-  if (cache_dir.empty()) {
-    return absl::OkStatus();
-  }
-
-  TF_ASSIGN_OR_RETURN(const std::string file_path,
-                      GetCacheFilePath(cache_dir, key));
-
-  VLOG(1) << "Writing autotune result to file: " << file_path;
-
-  std::string result_str;
-  if (!tsl::protobuf::TextFormat::PrintToString(result, &result_str)) {
-    return absl::InternalError("Failed to serialize autotune result.");
-  }
-
-  // Rename trick: Write to a temporary file, then rename it to the final file.
-  // This avoids interleaved writes when multiple processes write to the same
-  // file, and also avoids readers seeing incomplete files. (This may not work
-  // on all file systems.)
-  std::string temp_file_path = tsl::io::GetTempFilename(".textproto");
-  tsl::Env* default_env = tsl::Env::Default();
-  TF_RETURN_IF_ERROR(
-      tsl::WriteStringToFile(default_env, temp_file_path, result_str));
-  return default_env->RenameFile(temp_file_path, file_path);
-}
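
The same write-then-rename pattern can be reproduced with only the standard library. The sketch below is a rough illustration, not the TSL implementation used above, and it assumes a filesystem where rename is effectively atomic within a directory.

#include <filesystem>
#include <fstream>
#include <string>
#include <system_error>

// Writes `content` to a temporary file next to `final_path`, then renames it
// into place so that readers never observe a partially written file.
bool AtomicWrite(const std::filesystem::path& final_path,
                 const std::string& content) {
  std::filesystem::path temp_path = final_path;
  temp_path += ".tmp";  // simplistic; the real code uses a proper temp filename
  {
    std::ofstream out(temp_path, std::ios::binary | std::ios::trunc);
    if (!out.write(content.data(), content.size())) return false;
  }
  std::error_code ec;
  std::filesystem::rename(temp_path, final_path, ec);
  return !ec;
}

int main() {
  return AtomicWrite("autotune_result.textproto", "scratch_bytes: 1\n") ? 0 : 1;
}
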
-
-absl::StatusOr<ResultAndInserted> AddResultToCaches(const AutotuneCacheKey& key,
-                                                    AutotuneResult result,
-                                                    std::string_view cache_dir)
-    ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
-  ResultAndInserted result_and_inserted = AddResultToInMemoryCache(key, result);
-  if (result_and_inserted.inserted) {
-    TF_RETURN_IF_ERROR(AddResultToFileBasedCacheIfEnabled(
-        key, result_and_inserted.result, cache_dir));
-  }
-  return result_and_inserted;
-}
-
-std::optional<AutotuneResult> TryToFindInInMemoryCache(
-    const AutotuneCacheKey& key) ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
-  absl::MutexLock lock(&autotune_cache_mu);
-  auto it = autotune_cache.find(key);
-  if (it == autotune_cache.end()) {
-    return std::nullopt;
-  }
-  return it->second;
-}
-
-absl::StatusOr<std::optional<AutotuneResult>>
-TryToFindInFileBasedCacheIfEnabled(const AutotuneCacheKey& key,
-                                   absl::string_view cache_dir)
-    ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
-  if (cache_dir.empty()) {
-    return std::nullopt;
-  }
-
-  TF_ASSIGN_OR_RETURN(const std::string file_path,
-                      GetCacheFilePath(cache_dir, key));
-  if (!tsl::Env::Default()->FileExists(file_path).ok()) {
-    VLOG(1) << "Autotune result file not found: " << file_path;
-    return std::nullopt;
-  }
-
-  VLOG(1) << "Autotune result file found: " << file_path;
-  std::string autotune_result_str;
-  TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), file_path,
-                                           &autotune_result_str));
-  AutotuneResult result;
-  if (!tsl::protobuf::TextFormat::ParseFromString(autotune_result_str,
-                                                  &result)) {
-    return absl::InvalidArgumentError("Failed to parse autotune result.");
-  }
-  return result;
-}
-
-// Sort the results so that they're deterministic.
-void SortAutotuneResults(AutotuneResults* results) {
-  std::sort(results->mutable_results()->pointer_begin(),
-            results->mutable_results()->pointer_end(),
-            [](const auto* a, const auto* b) {
-              return std::make_pair(absl::string_view(a->device()),
-                                    absl::string_view(a->hlo())) <
-                     std::make_pair(absl::string_view(b->device()),
-                                    absl::string_view(b->hlo()));
-            });
-}
-
-}  // namespace
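
The deterministic ordering above is just a lexicographic sort on the (device, hlo) pair. A standalone sketch with a plain struct in place of the proto type:

#include <algorithm>
#include <string>
#include <tuple>
#include <vector>

struct Entry {  // plain stand-in for AutotuneResults::Entry
  std::string device;
  std::string hlo;
};

// Lexicographic sort on the (device, hlo) pair, mirroring SortAutotuneResults.
void SortEntries(std::vector<Entry>& entries) {
  std::sort(entries.begin(), entries.end(), [](const Entry& a, const Entry& b) {
    return std::tie(a.device, a.hlo) < std::tie(b.device, b.hlo);
  });
}

int main() {
  std::vector<Entry> entries = {{"gpu-b", "hlo-1"}, {"gpu-a", "hlo-2"}};
  SortEntries(entries);  // entries[0] is now {"gpu-a", "hlo-2"}
  return entries.front().device == "gpu-a" ? 0 : 1;
}
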
-
-// Serialize `results` to string as a proto.
-absl::StatusOr<std::string> AutotuneResultsToString(
-    const AutotuneResults& results, bool as_textproto) {
-  if (as_textproto) {
-    std::string textproto;
-    if (tsl::protobuf::TextFormat::PrintToString(results, &textproto)) {
-      return textproto;
-    } else {
-      return Internal("Failed to serialize autotune results.");
-    }
-  }
-  return results.SerializeAsString();
-}
-
-namespace {
-// Serialize a single entry to `results`.
-void SerializeAutotuneEntry(AutotuneResults* results, const AutotuneCacheKey& k,
-                            const AutotuneResult* res) {
-  auto& entry = *results->add_results();
-  entry.set_device(std::string(k.GetModelStr()));
-  entry.set_hlo(std::string(k.GetHlo()));
-  *entry.mutable_result() = *res;
-}
-}  // namespace
-
-/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResults(
-    AutotuneResults* results) {
-  absl::MutexLock lock(&autotune_cache_mu);
-  for (const auto& [k, result] : autotune_cache) {
-    SerializeAutotuneEntry(results, k, &result);
-  }
-
-  results->set_version(kVersion);
-  SortAutotuneResults(results);
-
-  return absl::OkStatus();
-}
-
-/*static*/ absl::Status AutotunerUtil::LoadAutotuneResults(
-    const AutotuneResults& results) {
-  absl::MutexLock lock(&autotune_cache_mu);
-  for (const AutotuneResults::Entry& result : results.results()) {
-    if (auto [it, inserted] = autotune_cache.emplace(
-            AutotuneCacheKey(result.device(), result.hlo()), result.result());
-        !inserted) {
-      return absl::InternalError(absl::StrCat(
-          "Duplicate autotuning result for ", it->first.ToString()));
-    }
-  }
-  return absl::OkStatus();
-}
-
-/*static*/ void AutotunerUtil::ClearAutotuneResults() {
-  absl::MutexLock lock(&autotune_cache_mu);
-  autotune_cache.clear();
-}
-
-/*static*/ bool AutotunerUtil::ResultCacheIsEmpty() {
-  absl::MutexLock lock(&autotune_cache_mu);
-  return autotune_cache.empty();
-}
-
-/*static*/ absl::StatusOr<se::DeviceMemoryBase> AutotunerUtil::CreateBuffer(
-    se::RedzoneAllocator& allocator, const Shape& shape,
-    const AutotuneConfig& config, int64_t& rng_state) {
-  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
-                      allocator.AllocateBytes(ShapeUtil::ByteSizeOf(shape)));
-  if (config.should_init_buffers()) {
-    InitializeBuffer(allocator.stream(), shape.element_type(), &rng_state,
-                     buffer);
-  }
-  return buffer;
-}
-
-namespace {
-std::string ToCanonicalString(const HloInstruction* instr) {
-  auto options = HloPrintOptions::Canonical();
-  if (instr->opcode() != HloOpcode::kFusion) {
-    options.set_print_backend_config(true);
-    return instr->ToString(options);
-  }
-  options.set_print_subcomputation_mode(
-      HloPrintOptions::PrintSubcomputationMode::kOff);
-  options.set_print_infeed_outfeed_config(false);
-  options.set_print_only_essential_constants(true);
-  options.set_print_operand_shape(true);
-  options.set_print_ids(false);
-  options.set_canonicalize_computations(true);
-
-  // TODO(b/266210099): This is unsound. We should probably do the fingerprint
-  // of the HLO computation proto instead.
-  return instr->called_computations()[0]->ToString(options);
-}
-
-}  // namespace
-
-AutotuneCacheKey::AutotuneCacheKey(absl::string_view model_str,
-                                   const HloInstruction& instr)
-    : AutotuneCacheKey(model_str, ToCanonicalString(&instr)) {}
-
-/*static*/ std::string AutotuneCacheKey::DeviceDescriptionToCacheKey(
-    const se::DeviceDescription& device_description) {
-  std::string compute_capability;
-  if (auto* ccc = std::get_if<se::CudaComputeCapability>(
-          &device_description.gpu_compute_capability())) {
-    compute_capability = absl::StrCat("CUDA: ", ccc->major, ".", ccc->minor);
-  } else {
-    auto* rcc = std::get_if<se::RocmComputeCapability>(
-        &device_description.gpu_compute_capability());
-    CHECK(rcc != nullptr) << "Unknown compute capability type";
-    compute_capability = absl::StrCat("ROCM: ", rcc->gfx_version());
-  }
-
-  // The string below should include only as much information as is needed to
-  // make it a valid key. Information that should not be included is:
-  // - specs that are directly derivable from the compute capability, e.g.
-  //   shared memory size. For NVIDIA GPUs, you can see what is derivable from
-  //   the SM version here:
-  //   https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
-  // - specs that are irrelevant for autotuning. E.g. the total available memory
-  //   on a device is not relevant, because by itself, it does not affect the
-  //   performance of single kernels.
-  //
-  // See b/344573710 for some discussion.
-
-  double memory_bandwidth = device_description.memory_bandwidth() / 1e9;
-  // Round the memory bandwidth to make the final string nicer to read.
-  // This will also cause minute differences in bandwidth to yield the same
-  // cache key, but that's fine, since the difference is inconsequential.
-  memory_bandwidth = std::round(memory_bandwidth);
-
-  constexpr double kBytesPerMegabyte = 1 << 20;
-  double l2_cache_size = device_description.l2_cache_size() / kBytesPerMegabyte;
-
-  return absl::StrCat(compute_capability,
-                      ", Cores: ", device_description.core_count(),
-                      ", GPU clock: ", device_description.clock_rate_ghz(),
-                      " GHz, Memory bandwidth: ", memory_bandwidth,
-                      " GB/s, L2 cache: ", l2_cache_size, " MB");
-}
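
To illustrate the resulting key format, here is a standalone sketch that uses std::ostringstream instead of absl::StrCat and hard-coded example device values rather than a real se::DeviceDescription; the numbers roughly match the A100-SXM-40GB expectations in autotuner_util_test.cc and are only illustrative.

#include <cmath>
#include <iostream>
#include <sstream>
#include <string>

int main() {
  const std::string compute_capability = "CUDA: 8.0";
  const int core_count = 108;
  const double clock_rate_ghz = 1.41;
  const double memory_bandwidth = std::round(1555.2e9 / 1e9);  // GB/s, rounded
  const double l2_cache_size = 41943040.0 / (1 << 20);         // bytes -> MB

  std::ostringstream key;
  key << compute_capability << ", Cores: " << core_count
      << ", GPU clock: " << clock_rate_ghz << " GHz, Memory bandwidth: "
      << memory_bandwidth << " GB/s, L2 cache: " << l2_cache_size << " MB";
  // Prints: CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz,
  //         Memory bandwidth: 1555 GB/s, L2 cache: 40 MB
  std::cout << key.str() << "\n";
  return 0;
}
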
-
-namespace {
-absl::StatusOr<std::optional<AutotuneResult>> TryFindInCache(
-    const AutotuneCacheKey& key, absl::string_view cache_dir)
-    ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
-  std::optional<AutotuneResult> opt_result = TryToFindInInMemoryCache(key);
-  if (opt_result.has_value()) {
-    if (VLOG_IS_ON(2)) {
-      LOG(INFO) << "In-memory autotune cache hit: key = " << key.ToString();
-    } else if (VLOG_IS_ON(1)) {
-      LOG(INFO) << "In-memory autotune cache hit";
-    }
-    return opt_result;
-  }
-
-  TF_ASSIGN_OR_RETURN(opt_result,
-                      TryToFindInFileBasedCacheIfEnabled(key, cache_dir));
-  if (opt_result.has_value()) {
-    AddResultToInMemoryCache(key, opt_result.value());
-
-    if (VLOG_IS_ON(2)) {
-      LOG(INFO) << "File-based autotune cache hit: key = " << key.ToString();
-    } else if (VLOG_IS_ON(1)) {
-      LOG(INFO) << "File-based autotune cache hit";
-    }
-    return opt_result;
-  }
-
-  if (VLOG_IS_ON(2)) {
-    LOG(INFO) << "Autotune cache miss: key = " << key.ToString();
-  } else if (VLOG_IS_ON(1)) {
-    LOG(INFO) << "Autotune cache miss";
-  }
-  return std::nullopt;
-}
-}  // namespace
-
-/*static*/ AutotuneCacheKey AutotunerUtil::GetKey(
-    const HloInstruction* instr, const AutotuneConfig& config) {
-  return AutotuneCacheKey(config.GetModelStr(), *instr);
-}
-
-/*static*/ absl::StatusOr<bool> AutotunerUtil::IsInCache(
-    const AutotuneCacheKey& key, const AutotuneConfig& config) {
-  TF_ASSIGN_OR_RETURN(std::optional<AutotuneResult> opt_res,
-                      TryFindInCache(key, config.autotune_cache_dir()));
-  return opt_res.has_value();
-}
-
-/*static*/ absl::StatusOr<bool> AutotunerUtil::AddResult(
-    const AutotuneCacheKey& key, AutotuneResult result,
-    const AutotuneConfig& config) {
-  TF_ASSIGN_OR_RETURN(
-      ResultAndInserted result_and_inserted,
-      AddResultToCaches(key, std::move(result), config.autotune_cache_dir()));
-  return result_and_inserted.inserted;
-}
-
-/*static*/ absl::StatusOr<AutotuneResult> AutotunerUtil::Autotune(
-    const HloInstruction* instr, const AutotuneConfig& config,
-    const AutotuneNoCacheFn& autotune_fn) {
-  const AutotuneCacheKey key = GetKey(instr, config);
-  TF_ASSIGN_OR_RETURN(std::optional<AutotuneResult> opt_res,
-                      TryFindInCache(key, config.autotune_cache_dir()));
-  if (opt_res.has_value()) {
-    return opt_res.value();
-  }
-
-  // Cache miss.
-  if (config.should_require_complete_aot_autotune_results()) {
-    return NotFound(
-        "Complete XLA AOT autotuning results are required, but no AOT result "
-        "was found for key: %s",
-        key.ToString());
-  }
-
-  TF_ASSIGN_OR_RETURN(AutotuneResult autotune_result, autotune_fn());
-
-  TF_ASSIGN_OR_RETURN(ResultAndInserted result_and_inserted,
-                      AddResultToCaches(key, std::move(autotune_result),
-                                        config.autotune_cache_dir()));
-  return result_and_inserted.result;
-}
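
Stripped of the file-based cache, logging, and the AOT-required check, the control flow above is "look up, otherwise compute and insert". A standalone sketch of that pattern with ordinary standard-library types in place of the XLA ones:

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

using Key = std::string;  // stand-in for AutotuneCacheKey
using Result = int;       // stand-in for AutotuneResult

// Stand-in for the in-memory cache; the real code also consults a file-based
// cache and guards the map with a mutex.
std::unordered_map<Key, Result>& Cache() {
  static auto* cache = new std::unordered_map<Key, Result>();
  return *cache;
}

// Look the key up; on a miss, run the expensive search and keep whatever ends
// up in the cache.
Result AutotuneWithCache(const Key& key,
                         const std::function<Result()>& autotune_fn) {
  if (auto it = Cache().find(key); it != Cache().end()) return it->second;
  auto [it, inserted] = Cache().emplace(key, autotune_fn());
  return it->second;
}

int main() {
  int calls = 0;
  auto fn = [&] { ++calls; return 42; };
  AutotuneWithCache("dot_fusion", fn);
  AutotuneWithCache("dot_fusion", fn);
  std::cout << "autotune_fn ran " << calls << " time(s)\n";  // prints 1
  return 0;
}
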
-
-namespace {
-
-bool IsTextProtoPath(absl::string_view file_path) {
-  return absl::EndsWith(file_path, ".txt") ||
-         absl::EndsWith(file_path, ".textproto") ||
-         absl::EndsWith(file_path, ".prototxt") ||
-         absl::EndsWith(file_path, ".pbtxt");
-}
-
-}  // anonymous namespace
-
-/*static*/ absl::Status AutotunerUtil::LoadAutotuneResults(
-    absl::string_view data, bool as_textproto) {
-  AutotuneResults results;
-  // The cast here is necessary for MacOS builds.
-  bool parse_success =
-      as_textproto ? tsl::protobuf::TextFormat::ParseFromString(
-                         std::string(data), &results)             // NOLINT
-                   : results.ParseFromString(std::string(data));  // NOLINT
-  if (!parse_success) {
-    return absl::InvalidArgumentError(
-        "Failed to parse autotune results string.");
-  }
-  if (results.version() != kVersion) {
-    return absl::InvalidArgumentError(absl::StrFormat(
-        "Version mismatch in autotune results. Expected %d but was %d",
-        kVersion, results.version()));
-  }
-
-  TF_RETURN_IF_ERROR(LoadAutotuneResults(results));
-  return absl::OkStatus();
-}
-
-/*static*/ absl::StatusOr<std::string> AutotunerUtil::SerializeAutotuneResults(
-    bool as_textproto) {
-  AutotuneResults results;
-  TF_RETURN_IF_ERROR(SerializeAutotuneResults(&results));
-  return AutotuneResultsToString(results, as_textproto);
-}
-
-/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResultsToFile(
-    const AutotuneResults& results, absl::string_view file_path) {
-  TF_RET_CHECK(!file_path.empty());
-  TF_RET_CHECK(results.version() > 0)
-      << "Did you call SerializeAutotuneResults to get this AutotuneResults?";
-
-  std::string resolved_path;
-  if (!tsl::io::ResolveTestPrefixes(file_path, resolved_path)) {
-    return FailedPrecondition("File path can not be resolved: %s", file_path);
-  }
-
-  TF_ASSIGN_OR_RETURN(
-      std::string autotune_results_str,
-      AutotuneResultsToString(results, IsTextProtoPath(resolved_path)));
-  TF_RETURN_IF_ERROR(tsl::WriteStringToFile(tsl::Env::Default(), resolved_path,
-                                            autotune_results_str));
-  LOG(INFO) << "Autotune results serialized to file: " << resolved_path;
-
-  return absl::OkStatus();
-}
-
-/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResultsToFile(
-    absl::string_view file_path) {
-  AutotuneResults results;
-  TF_RETURN_IF_ERROR(SerializeAutotuneResults(&results));
-  return SerializeAutotuneResultsToFile(results, file_path);
-}
-
-/*static*/ absl::Status AutotunerUtil::LoadAutotuneResultsFromFile(
-    absl::string_view file_path) {
-  TF_RET_CHECK(!file_path.empty());
-
-  std::string resolved_path;
-  if (!tsl::io::ResolveTestPrefixes(file_path, resolved_path)) {
-    return FailedPrecondition("File path can not be resolved: %s", file_path);
-  }
-
-  if (!tsl::Env::Default()->FileExists(resolved_path).ok()) {
-    return FailedPrecondition("Autotune results file does not exist: %s",
-                              resolved_path);
-  }
-  std::string autotune_results_str;
-  TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), resolved_path,
-                                           &autotune_results_str));
-
-  TF_RETURN_IF_ERROR(LoadAutotuneResults(autotune_results_str,
-                                         IsTextProtoPath(resolved_path)));
-
-  LOG(INFO) << "Autotune results loaded from file: " << resolved_path;
-
-  return absl::OkStatus();
-}
-
-/*static*/ absl::StatusOr<se::RedzoneAllocator>
-AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config,
-                                      const DebugOptions& opts) {
-  TF_ASSIGN_OR_RETURN(se::Stream * stream, config.GetStream());
-  return se::RedzoneAllocator(
-      stream, config.GetAllocator(), PtxOptsFromDebugOptions(opts),
-      /*memory_limit=*/std::numeric_limits<int64_t>::max(),
-      /*redzone_size=*/config.should_check_correctness()
-          ? opts.xla_gpu_redzone_padding_bytes()
-          : 0);
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuner_util.h
deleted file mode 100644
index 4634fc2..0000000
--- a/third_party/xla/xla/service/gpu/autotuner_util.h
+++ /dev/null
@@ -1,334 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_
-#define XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_
-
-#include <algorithm>
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <string>
-#include <utility>
-#include <variant>
-
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "xla/autotune_results.pb.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/stream_executor/stream_executor_memory_allocator.h"
-#include "xla/xla.pb.h"
-
-namespace xla {
-namespace gpu {
-
-struct DeviceConfig {
-  se::StreamExecutor* stream_exec;  // never null
-
-  // If the `allocator` parameter is not null, we will use it to allocate temp
-  // memory while timing the various convolution algorithms.  If it's null,
-  // we'll use the default allocator on the StreamExecutor.
-  se::DeviceMemoryAllocator* allocator = nullptr;  // may be null
-};
-
-struct DevicelessConfig {
-  // The device description of the target device.
-  se::DeviceDescription device_description;
-};
-
-class AutotuneCacheKey {
- public:
-  AutotuneCacheKey(const se::DeviceDescription& device_description,
-                   const HloInstruction& instruction)
-      : AutotuneCacheKey(DeviceDescriptionToCacheKey(device_description),
-                         instruction.ToString()) {}
-
-  AutotuneCacheKey(absl::string_view model_str,
-                   const HloInstruction& instruction);
-
-  explicit AutotuneCacheKey(absl::string_view model_str,
-                            absl::string_view hlo_canonical)
-      : model_str_(model_str), hlo_canonical_(hlo_canonical) {}
-
-  absl::string_view GetModelStr() const { return model_str_; }
-
-  absl::string_view GetHlo() const { return hlo_canonical_; }
-
-  template <typename H>
-  friend H AbslHashValue(H h, const AutotuneCacheKey& w) {
-    return H::combine(std::move(h), w.model_str_, w.hlo_canonical_);
-  }
-
-  bool operator==(const AutotuneCacheKey& w) const {
-    return model_str_ == w.model_str_ && hlo_canonical_ == w.hlo_canonical_;
-  }
-
-  std::string ToString() const {
-    return absl::StrFormat("<key model='%s', hlo='%s'>", model_str_,
-                           hlo_canonical_);
-  }
-
-  static std::string DeviceDescriptionToCacheKey(
-      const se::DeviceDescription& device_description);
-
- private:
-  std::string model_str_;
-  std::string hlo_canonical_;
-};
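
A minimal standard-library sketch of the same idea, a two-string key that is equality-comparable and hashable; it deliberately mirrors, rather than reuses, the Abseil hashing support declared above, and the hash combiner is an arbitrary illustrative choice.

#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Standard-library analogue of the (model_str, hlo_canonical) key above.
struct CacheKey {
  std::string model_str;
  std::string hlo_canonical;
  bool operator==(const CacheKey& other) const {
    return model_str == other.model_str &&
           hlo_canonical == other.hlo_canonical;
  }
};

struct CacheKeyHash {
  std::size_t operator()(const CacheKey& k) const {
    std::size_t h1 = std::hash<std::string>{}(k.model_str);
    std::size_t h2 = std::hash<std::string>{}(k.hlo_canonical);
    return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
  }
};

int main() {
  std::unordered_map<CacheKey, int, CacheKeyHash> cache;
  cache[{"CUDA: 8.0, Cores: 108, ...", "fusion-canonical-text"}] = 1;
  return cache.count({"CUDA: 8.0, Cores: 108, ...", "fusion-canonical-text"})
             ? 0
             : 1;
}
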
-
-class AutotuneConfig {
- public:
-  bool should_init_buffers() const { return autotune_level_ >= 2; }
-  bool should_reinit_output_buffer() const { return autotune_level_ >= 3; }
-  bool should_check_correctness() const { return autotune_level_ >= 4; }
-  bool should_skip_wrong_results() const { return autotune_level_ >= 5; }
-  bool should_crash_on_check_failure() const {
-    return should_crash_on_check_failure_;
-  }
-  bool should_require_complete_aot_autotune_results() const {
-    return require_complete_aot_autotune_results_;
-  }
-  // An empty string means the file-based cache is disabled (the in-memory
-  // cache is always used).
-  const std::string& autotune_cache_dir() const { return autotune_cache_dir_; }
-
-  AutotuneConfig(const AutotuneConfig& right)
-      : config_(right.config_),
-        autotune_level_(right.autotune_level_),
-        should_crash_on_check_failure_(right.should_crash_on_check_failure_),
-        exhaustive_tiling_search_(right.exhaustive_tiling_search_),
-        require_complete_aot_autotune_results_(
-            right.require_complete_aot_autotune_results_),
-        autotune_cache_dir_(right.autotune_cache_dir_) {}
-
-  AutotuneConfig(const std::variant<DeviceConfig, DevicelessConfig>& config,
-                 const DebugOptions& debug_options)
-      : config_(config),
-        autotune_level_(debug_options.xla_gpu_autotune_level()),
-        should_crash_on_check_failure_(
-            debug_options.xla_gpu_crash_on_verification_failures()),
-        exhaustive_tiling_search_(
-            debug_options.xla_gpu_exhaustive_tiling_search()),
-        require_complete_aot_autotune_results_(
-            debug_options.xla_gpu_require_complete_aot_autotune_results()),
-        autotune_cache_dir_(
-            debug_options.xla_gpu_per_fusion_autotune_cache_dir()) {}
-
-  std::string GetModelStr() const {
-    if (auto deviceless_config = std::get_if<DevicelessConfig>(&config_)) {
-      return AutotuneCacheKey::DeviceDescriptionToCacheKey(
-          deviceless_config->device_description);
-    }
-
-    const auto& device_config = std::get<DeviceConfig>(config_);
-    return AutotuneCacheKey::DeviceDescriptionToCacheKey(
-        device_config.stream_exec->GetDeviceDescription());
-  }
-
-  se::StreamExecutor* GetExecutor() const {
-    CHECK(std::holds_alternative<DeviceConfig>(config_));
-    return std::get<DeviceConfig>(config_).stream_exec;
-  }
-
-  se::DeviceMemoryAllocator* GetAllocator() const {
-    CHECK(std::holds_alternative<DeviceConfig>(config_));
-    auto& cf = std::get<DeviceConfig>(config_);
-    if (cf.allocator != nullptr) {
-      return cf.allocator;
-    }
-    if (allocator_ == nullptr) {
-      allocator_ =
-          std::make_unique<se::StreamExecutorMemoryAllocator>(GetExecutor());
-    }
-    return allocator_.get();
-  }
-
-  absl::StatusOr<se::Stream*> GetStream() const {
-    CHECK(std::holds_alternative<DeviceConfig>(config_));
-    return GetAllocator()->GetStream(GetExecutor()->device_ordinal());
-  }
-
-  const se::GpuComputeCapability& GetGpuComputeCapability() const {
-    if (auto c = std::get_if<DeviceConfig>(&config_)) {
-      return c->stream_exec->GetDeviceDescription().gpu_compute_capability();
-    }
-    return std::get<DevicelessConfig>(config_)
-        .device_description.gpu_compute_capability();
-  }
-
-  bool IsDeviceless() const {
-    return std::holds_alternative<DevicelessConfig>(config_);
-  }
-
-  bool ExhaustiveTilingSearch() const { return exhaustive_tiling_search_; }
-
- private:
-  std::variant<DeviceConfig, DevicelessConfig> config_;
-  int32_t autotune_level_;
-  bool should_crash_on_check_failure_;
-  bool exhaustive_tiling_search_;
-  bool require_complete_aot_autotune_results_;
-  mutable std::unique_ptr<se::DeviceMemoryAllocator> allocator_;
-  std::string autotune_cache_dir_;
-};
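
The xla_gpu_autotune_level flag gates progressively stricter behavior. Below is a standalone sketch of the thresholds encoded in the accessors above; the printed labels are informal descriptions, not flag names.

#include <iostream>

// Thresholds mirror AutotuneConfig: level >= 2 initializes buffers, >= 3
// reinitializes the output buffer, >= 4 checks correctness, and >= 5
// additionally skips wrong results.
void PrintAutotuneLevel(int level) {
  std::cout << "level " << level << ": init_buffers=" << (level >= 2)
            << " reinit_output_buffer=" << (level >= 3)
            << " check_correctness=" << (level >= 4)
            << " skip_wrong_results=" << (level >= 5) << "\n";
}

int main() {
  for (int level = 0; level <= 5; ++level) PrintAutotuneLevel(level);
  return 0;
}
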
-
-using AutotuneNoCacheFn = std::function<absl::StatusOr<AutotuneResult>()>;
-
-struct AutotunerUtil {
-  // Creates a buffer for a given operation using the redzone allocator, and
-  // initializes it from the given RNG state if the config enables buffer
-  // initialization.
-  static absl::StatusOr<se::DeviceMemoryBase> CreateBuffer(
-      se::RedzoneAllocator& allocator, const Shape& shape,
-      const AutotuneConfig& config, int64_t& rng_state);
-
-  static absl::StatusOr<AutotuneResult> Autotune(
-      const HloInstruction* instr, const AutotuneConfig& config,
-      const AutotuneNoCacheFn& autotune_fn);
-
-  // Returns the same cache key that would be used inside Autotune().
-  //
-  // Normally, we don't have to use this low level method.
-  static AutotuneCacheKey GetKey(const HloInstruction* instr,
-                                 const AutotuneConfig& config);
-
-  // Checks if the key is in the autotune cache.
-  //
-  // Normally, we don't have to use this low level method.
-  static absl::StatusOr<bool> IsInCache(const AutotuneCacheKey& key,
-                                        const AutotuneConfig& config);
-
-  // Adds the result to the autotune cache.
-  //
-  // Returns true if the entry is inserted.
-  //
-  // Normally, we don't have to use this low level method.
-  static absl::StatusOr<bool> AddResult(const AutotuneCacheKey& key,
-                                        AutotuneResult result,
-                                        const AutotuneConfig& config);
-
-  // Creates a RedzoneAllocator from a given config.
-  static absl::StatusOr<se::RedzoneAllocator> CreateRedzoneAllocator(
-      const AutotuneConfig& config, const DebugOptions& opts);
-
-  // Functions to save/load XLA's autotuning results.
-  //
-  // This is used for ahead-of-time autotuning.  Specifically:
-  //
-  // When XLA calls cublas (for matmuls, aka "gemm" or "dot") or cudnn (for
-  // convolutions), it usually has to choose an "algorithm" for the particular
-  // dot/conv.  XLA queries cublas/cudnn for a list of candidate algorithms.
-  // Then it runs all of them and picks the fastest one.  This is what we call
-  // "autotuning". It happens in GemmAlgorithmPicker and GpuConvAlgorithmPicker.
-  //
-  // Autotuning is necessary to get good performance for dot/conv.  But it also
-  // has some disadvantages.
-  //
-  //  - Because it relies on timing data, it is fundamentally nondeterministic.
-  //    But even if two algorithms have similar runtimes, our choice of
-  //    algorithm may be visible to the user: Different algorithms can have
-  //    different numerics, and sometimes they can even have different bugs!
-  //
-  //  - Trying all the candidate algorithms can be slow, especially when some
-  //    of the candidates are "very bad" and run especially slowly compared to
-  //    the optimal candidate.  This slows down compilation.
-  //
-  // To address the disadvantages above, we allow users to save/restore the
-  // autotuning choices that XLA has made, using the functions below.
-  //
-  // Loading autotuning results does not erase existing autotuning choices, but
-  // in the event of a disagreement between the existing data and the new data,
-  // the new algorithm is chosen.
-  //
-  // Note that even if you call LoadAutotuneResults(), if XLA encounters a
-  // dot/conv that is *not* covered by the loaded data, it will go ahead and
-  // autotune it like normal.  In other words, the behavior of XLA should be
-  // identical with or without ahead-of-time autotuning, modulo nondeterminism.
-  //
-  // This is important if you want to be able to use the same autotuning file
-  // with different versions of XLA, because as XLA changes, exactly which
-  // dots/convs it wants to run can also change.  For example, XLA might change
-  // the conv padding heuristics it uses, and we don't want that to mean that
-  // all users of ahead-of-time autotuning are broken.
-  static absl::StatusOr<std::string> SerializeAutotuneResults(
-      bool as_textproto = false);
-
-  // Serializes autotune results into the given proto.
-  static absl::Status SerializeAutotuneResults(AutotuneResults* results);
-
-  // Loads autotune results from the given string of bytes.
-  //
-  // Warning: The results are only loaded to the in-memory cache.
-  static absl::Status LoadAutotuneResults(absl::string_view data,
-                                          bool as_textproto = false);
-
-  // Loads autotune results from the given proto.
-  //
-  // Warning: The results are only loaded to the in-memory cache.
-  static absl::Status LoadAutotuneResults(const AutotuneResults& results);
-
-  // Serializes autotune results into a file.
-  //
-  // If `file_path` ends with ".txt", ".textproto", ".prototxt" or ".pbtxt",
-  // then the textproto format is used, otherwise the binary protobuf format.
-  static absl::Status SerializeAutotuneResultsToFile(
-      absl::string_view file_path);
-
-  // As above, but if you already called SerializeAutotuneResults to get a
-  // proto.
-  static absl::Status SerializeAutotuneResultsToFile(
-      const AutotuneResults& results, absl::string_view file_path);
-
-  // Loads autotune results from a file.
-  //
-  // If `file_path` ends with ".txt", ".textproto", ".prototxt" or ".pbtxt",
-  // then the file is considered to be in the textproto format, otherwise the
-  // binary protobuf format.
-  //
-  // Warning: The results are only loaded to the in-memory cache.
-  static absl::Status LoadAutotuneResultsFromFile(absl::string_view file_path);
-
-  // Warning: This only clears the in-memory cache. If you use a file-based
-  // cache, you're responsible for clearing the cache directory yourself.
-  static void ClearAutotuneResults();
-
-  // Warning: This only checks the in-memory cache. If you use a file-based
-  // cache, you're responsible for checking whether the cache directory is
-  // empty.
-  static bool ResultCacheIsEmpty();
-};
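
As a hedged usage sketch of the save/load API declared above (assuming an XLA build environment, using this file's pre-move include path, and an illustrative output path):

#include "absl/status/status.h"
#include "tsl/platform/errors.h"
#include "xla/service/gpu/autotuner_util.h"

namespace xla::gpu {

// Persist results after a compilation that ran autotuning, then reload them
// as a later process would before compiling again. The serialization format
// follows the file extension.
absl::Status SaveAndReloadAutotuneCache() {
  TF_RETURN_IF_ERROR(AutotunerUtil::SerializeAutotuneResultsToFile(
      "/tmp/autotune_results.textproto"));
  AutotunerUtil::ClearAutotuneResults();  // pretend we are in a fresh process
  return AutotunerUtil::LoadAutotuneResultsFromFile(
      "/tmp/autotune_results.textproto");
}

}  // namespace xla::gpu
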
-
-absl::StatusOr<std::string> AutotuneResultsToString(
-    const AutotuneResults& results, bool as_textproto);
-
-// Exposed only for testing. Returns the SHA-256 hash of the input string,
-// encoded in base64.
-//
-// SHA-256 was chosen to follow industry best practices and avoid collisions.
-// Git is also transitioning to SHA-256. This is probably better than
-// tsl::Fingerprint128.
-absl::StatusOr<std::string> GetBase64EncodedSha256Hash(absl::string_view s);
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_
diff --git a/third_party/xla/xla/service/gpu/autotuner_util_test.cc b/third_party/xla/xla/service/gpu/autotuner_util_test.cc
deleted file mode 100644
index 69c1395..0000000
--- a/third_party/xla/xla/service/gpu/autotuner_util_test.cc
+++ /dev/null
@@ -1,459 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/autotuner_util.h"
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/autotune_results.pb.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/stream_executor/platform.h"
-#include "xla/stream_executor/platform_manager.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/xla.pb.h"
-#include "tsl/platform/env.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"  // IWYU pragma: keep
-#include "tsl/platform/path.h"
-#include "tsl/platform/protobuf.h"  // IWYU pragma: keep
-#include "tsl/platform/status.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::testing::ElementsAre;
-using ::testing::HasSubstr;
-using ::testing::IsEmpty;
-using ::testing::Not;
-using ::testing::TempDir;
-using ::tsl::testing::StatusIs;
-
-class AutotunerUtilTest : public HloTestBase {
- protected:
-  static constexpr absl::string_view kHloText = R"(
-HloModule t
-
-ENTRY e {
-  p0 = f16[1,16,17,3] parameter(0)
-  p1 = s8[16,17,3] parameter(1)
-  cp1 = f16[16,17,3] convert(p1)
-  ROOT _ = f16[1,16,16] dot(p0, cp1),
-    lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
-})";
-
-  static constexpr absl::string_view kResultText = R"(
-version: 3
-results {
-  device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB"
-  hlo: "{\n  tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n  tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n  tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n  tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n  tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n  tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n  ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}"
-  result {
-    run_time {
-      nanos: 31744
-    }
-    triton {
-      block_m: 32
-      block_n: 32
-      block_k: 32
-      split_k: 1
-      num_stages: 1
-      num_warps: 4
-      num_ctas: 1
-    }
-  }
-})";
-
-  void SetUp() override { AutotunerUtil::ClearAutotuneResults(); }
-
-  std::string GetUniqueTempFilePath(absl::string_view suffix) {
-    std::string filename = TempDir();
-    CHECK(tsl::Env::Default()->CreateUniqueFileName(&filename,
-                                                    std::string(suffix)));
-    return filename;
-  }
-
-  std::string ExpectToReadNonEmptyFile(absl::string_view file_path) {
-    std::string str;
-    tsl::Env* env = tsl::Env::Default();
-    TF_EXPECT_OK(tsl::ReadFileToString(env, std::string(file_path), &str));
-    EXPECT_THAT(str, Not(IsEmpty()));
-    return str;
-  }
-
-  static std::unique_ptr<stream_executor::StreamExecutor> NewStreamExecutor() {
-    stream_executor::Platform* platform =
-        stream_executor::PlatformManager::PlatformWithName("Host").value();
-    stream_executor::StreamExecutorConfig config(/*ordinal=*/0);
-    return platform->GetUncachedExecutor(config).value();
-  }
-
-  absl::Status PopulateResultCache() {
-    EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
-    TF_RETURN_IF_ERROR(AutotunerUtil::LoadAutotuneResults(kResultText, true));
-    EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
-    return absl::OkStatus();
-  }
-};
-
-TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_TextProto1) {
-  TF_EXPECT_OK(PopulateResultCache());
-  std::string kFilePath = GetUniqueTempFilePath(".txt");
-  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
-
-  std::string autotune_results_str = ExpectToReadNonEmptyFile(kFilePath);
-  AutotuneResults results;
-  EXPECT_TRUE(tsl::protobuf::TextFormat::ParseFromString(autotune_results_str,
-                                                         &results));
-  EXPECT_GT(results.results_size(), 0);
-}
-
-TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_TextProto2) {
-  TF_EXPECT_OK(PopulateResultCache());
-  std::string kFilePath = GetUniqueTempFilePath(".textproto");
-  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
-
-  std::string autotune_results_str = ExpectToReadNonEmptyFile(kFilePath);
-  AutotuneResults results;
-  EXPECT_TRUE(tsl::protobuf::TextFormat::ParseFromString(autotune_results_str,
-                                                         &results));
-}
-
-TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_Protobuf) {
-  TF_EXPECT_OK(PopulateResultCache());
-  std::string kFilePath = GetUniqueTempFilePath(".pb");
-  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
-
-  std::string autotune_results_str = ExpectToReadNonEmptyFile(kFilePath);
-  AutotuneResults results;
-  EXPECT_TRUE(results.ParseFromString(autotune_results_str));
-}
-
-TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_TextProto1) {
-  TF_EXPECT_OK(PopulateResultCache());
-  std::string kFilePath = GetUniqueTempFilePath(".txt");
-  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
-  AutotunerUtil::ClearAutotuneResults();
-  EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
-
-  TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath));
-  EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
-}
-
-TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_TextProto2) {
-  TF_EXPECT_OK(PopulateResultCache());
-  std::string kFilePath = GetUniqueTempFilePath(".textproto");
-  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
-  AutotunerUtil::ClearAutotuneResults();
-  EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
-
-  TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath));
-  EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
-}
-
-TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_Protobuf) {
-  TF_EXPECT_OK(PopulateResultCache());
-  std::string kFilePath = GetUniqueTempFilePath(".pb");
-  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
-  AutotunerUtil::ClearAutotuneResults();
-  EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
-
-  TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath));
-  EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
-}
-
-TEST_F(AutotunerUtilTest, ResultConflictsAreDetected) {
-  TF_EXPECT_OK(PopulateResultCache());
-  std::string kFilePath = GetUniqueTempFilePath(".pb");
-  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
-  EXPECT_THAT(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath),
-              StatusIs(absl::StatusCode::kInternal,
-                       HasSubstr("Duplicate autotuning result")));
-}
-
-// Tests that when complete AOT autotuning is required and there is a cache
-// miss, a `NotFound` error is raised.
-TEST_F(AutotunerUtilTest, FailIfRequireCompleteAotAutotuning) {
-  std::string kFilePath = GetUniqueTempFilePath(".txt");
-  auto hlo_module = GetOptimizedModule(kHloText);
-  TF_EXPECT_OK(hlo_module.status());
-  std::vector<HloComputation*> computations =
-      (*hlo_module)
-          ->MakeNonfusionComputations(absl::flat_hash_set<absl::string_view>());
-  EXPECT_THAT(computations, Not(IsEmpty()));
-  const HloInstruction* instruction = *computations[0]->instructions().begin();
-  std::unique_ptr<stream_executor::StreamExecutor> executor =
-      NewStreamExecutor();
-  auto options = DebugOptions();
-  options.set_xla_gpu_require_complete_aot_autotune_results(true);
-  AutotuneConfig config(DeviceConfig{executor.get()}, options);
-  EXPECT_THAT(
-      AutotunerUtil::Autotune(instruction, config,
-                              [&] { return AutotuneResult(); }),
-      StatusIs(
-          absl::StatusCode::kNotFound,
-          HasSubstr("Complete XLA AOT autotuning results are required, but "
-                    "no AOT result was found for key: <key model")));
-}
-
-// Tests that when JIT autotuning is disabled but there is no cache miss (due
-// to loaded AOT autotuning results), `Autotune` still returns an OK status.
-TEST_F(AutotunerUtilTest, OkIfJitAutotuningDisabledButAlreadyLoadedAOT) {
-  auto hlo_module = GetOptimizedModule(kHloText);
-  std::vector<HloComputation*> computations =
-      (*hlo_module)
-          ->MakeNonfusionComputations(absl::flat_hash_set<absl::string_view>());
-  EXPECT_THAT(computations, Not(IsEmpty()));
-  const HloInstruction* instruction = *computations[0]->instructions().begin();
-  std::unique_ptr<stream_executor::StreamExecutor> executor =
-      NewStreamExecutor();
-
-  {
-    // By default, JIT autotuning is OK.
-    AutotuneConfig config(DeviceConfig{executor.get()}, DebugOptions());
-    TF_EXPECT_OK(AutotunerUtil::Autotune(instruction, config, [&] {
-                   return AutotuneResult();
-                 }).status());
-  }
-
-  // Now require complete AOT autotuning results.
-  auto options = DebugOptions();
-  options.set_xla_gpu_require_complete_aot_autotune_results(true);
-
-  AutotuneConfig config(DeviceConfig{executor.get()}, options);
-  // Even though JIT autotuning is disabled, there is no cache miss when running
-  // autotuning for the same entry, so no error should be raised either.
-  TF_EXPECT_OK(AutotunerUtil::Autotune(instruction, config, [&] {
-                 return AutotuneResult();
-               }).status());
-}
-
-class FileBasedCacheTest : public AutotunerUtilTest {
- public:
-  static std::string ToString(const AutotuneResult& message) {
-    std::string textproto;
-    CHECK(tsl::protobuf::TextFormat::PrintToString(message, &textproto));
-    return textproto;
-  }
-
-  static std::vector<std::string> GetFilesInDir(
-      const absl::string_view cache_dir) {
-    std::vector<std::string> files_in_cache;
-    TF_CHECK_OK(tsl::Env::Default()->GetChildren(std::string(cache_dir),
-                                                 &files_in_cache));
-    return files_in_cache;
-  }
-
-  static std::string Read(const absl::string_view filepath) {
-    std::string file_content;
-    TF_CHECK_OK(tsl::ReadFileToString(tsl::Env::Default(),
-                                      std::string(filepath), &file_content));
-    return file_content;
-  }
-
-  static void Write(const absl::string_view filepath,
-                    const absl::string_view content) {
-    TF_CHECK_OK(tsl::WriteStringToFile(tsl::Env::Default(),
-                                       std::string(filepath), content));
-  }
-
-  std::unique_ptr<stream_executor::StreamExecutor> executor_ =
-      NewStreamExecutor();
-  std::unique_ptr<HloModule> module_ =
-      ParseAndReturnVerifiedModule(kHloText).value();
-  const HloInstruction* dot_ = hlo_query::GetFirstInstructionWithOpcode(
-      *module_->entry_computation(), HloOpcode::kDot);
-  std::string cache_dir_ = [] {
-    tsl::Env* default_env = tsl::Env::Default();
-    std::string cache_dir;
-    CHECK(default_env->LocalTempFilename(&cache_dir));
-    CHECK_OK(default_env->CreateDir(cache_dir));
-    return cache_dir;
-  }();
-  AutotuneConfig config_ = AutotuneConfig(DeviceConfig{executor_.get()}, [&] {
-    DebugOptions options;
-    options.set_xla_gpu_per_fusion_autotune_cache_dir(cache_dir_);
-    return options;
-  }());
-  AutotuneCacheKey cache_key_ = AutotunerUtil::GetKey(dot_, config_);
-  std::string cache_filename_ = [&] {
-    absl::StatusOr<std::string> key_hash =
-        GetBase64EncodedSha256Hash(cache_key_.ToString());
-    CHECK_OK(key_hash.status());
-    return absl::StrCat(key_hash.value(), ".textproto");
-  }();
-  std::string cache_file_path_ = tsl::io::JoinPath(cache_dir_, cache_filename_);
-  const AutotuneResult result1_ = [] {
-    AutotuneResult result;
-    result.set_scratch_bytes(1);
-    return result;
-  }();
-  const AutotuneResult result2_ = [] {
-    AutotuneResult result;
-    result.set_scratch_bytes(2);
-    return result;
-  }();
-};
-
-TEST_F(FileBasedCacheTest, AutotuneWritesResultToTheCacheDir) {
-  TF_ASSERT_OK_AND_ASSIGN(
-      AutotuneResult result,
-      AutotunerUtil::Autotune(dot_, config_, [&] { return result1_; }));
-  EXPECT_EQ(ToString(result), ToString(result1_));
-
-  ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_));
-  EXPECT_EQ(Read(cache_file_path_), ToString(result1_));
-}
-
-TEST_F(FileBasedCacheTest, AutotuneReadsResultFromTheCacheDir) {
-  Write(cache_file_path_, ToString(result1_));
-
-  bool cache_hit = true;
-  TF_ASSERT_OK_AND_ASSIGN(AutotuneResult result,
-                          AutotunerUtil::Autotune(dot_, config_, [&] {
-                            cache_hit = false;
-                            return result2_;
-                          }));
-
-  EXPECT_TRUE(cache_hit);
-  EXPECT_EQ(ToString(result), ToString(result1_));
-}
-
-TEST_F(FileBasedCacheTest,
-       RepeatedAutotuneCallsDontReadOrWriteTheCacheFileAgain) {
-  auto check_autotune_cache_hit = [](const HloInstruction* instr,
-                                     const AutotuneConfig& config,
-                                     const AutotuneResult& expected_result) {
-    bool cache_hit = true;
-    TF_ASSERT_OK_AND_ASSIGN(AutotuneResult result,
-                            AutotunerUtil::Autotune(instr, config, [&] {
-                              cache_hit = false;
-                              AutotuneResult new_result;
-                              new_result.set_scratch_bytes(2);
-                              return new_result;
-                            }));
-    EXPECT_TRUE(cache_hit);
-    EXPECT_EQ(ToString(result), ToString(expected_result));
-  };
-
-  Write(cache_file_path_, ToString(result1_));
-  check_autotune_cache_hit(dot_, config_, /*expected_result=*/result1_);
-
-  constexpr absl::string_view kPlaceholderContent = "placeholder content";
-  Write(cache_file_path_, kPlaceholderContent);
-  // File was not read again:
-  check_autotune_cache_hit(dot_, config_, /*expected_result=*/result1_);
-  // File was not written again:
-  EXPECT_EQ(Read(cache_file_path_), kPlaceholderContent);
-}
-
-TEST_F(FileBasedCacheTest,
-       IsInCacheReturnsTrueIfTheResultIsInTheFileBasedCache) {
-  Write(cache_file_path_, ToString(result1_));
-
-  TF_ASSERT_OK_AND_ASSIGN(bool is_in_cache,
-                          AutotunerUtil::IsInCache(cache_key_, config_));
-
-  EXPECT_TRUE(is_in_cache);
-}
-
-TEST_F(FileBasedCacheTest, IsInCacheReturnsFalseIfTheResultIsNotInEitherCache) {
-  TF_ASSERT_OK_AND_ASSIGN(bool is_in_cache,
-                          AutotunerUtil::IsInCache(cache_key_, config_));
-
-  EXPECT_FALSE(is_in_cache);
-}
-
-TEST_F(FileBasedCacheTest, AddResultAddsTheResultToTheFileBasedCache) {
-  TF_ASSERT_OK_AND_ASSIGN(
-      bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_));
-  EXPECT_TRUE(added);
-
-  ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_));
-  EXPECT_EQ(Read(cache_file_path_), ToString(result1_));
-}
-
-TEST_F(FileBasedCacheTest, RepeatedAddResultDoesNotWriteTheFileAgain) {
-  {
-    TF_ASSERT_OK_AND_ASSIGN(
-        bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_));
-    EXPECT_TRUE(added);
-  }
-  ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_));
-  EXPECT_EQ(Read(cache_file_path_), ToString(result1_));
-  constexpr absl::string_view kPlaceholderContent = "placeholder content";
-  Write(cache_file_path_, kPlaceholderContent);
-
-  {
-    TF_ASSERT_OK_AND_ASSIGN(
-        bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_));
-    EXPECT_FALSE(added);
-  }
-
-  // File was not written again:
-  EXPECT_EQ(Read(cache_file_path_), kPlaceholderContent);
-}
-
-TEST(AutotuneCacheKeyTest, DeviceDescriptionToCacheKey) {
-  auto device_description =
-      [](absl::string_view spec_file_name) -> se::DeviceDescription {
-    se::GpuTargetConfigProto proto;
-    std::string spec_string;
-    CHECK_OK(tsl::ReadFileToString(
-        tsl::Env::Default(),
-        tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "tools", "hlo_opt",
-                          "gpu_specs", spec_file_name),
-        &spec_string));
-    EXPECT_TRUE(
-        tsl::protobuf::TextFormat::ParseFromString(spec_string, &proto));
-    return se::DeviceDescription(proto.gpu_device_info());
-  };
-
-  EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey(
-                device_description("a100_sxm_40.txtpb")),
-            "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: "
-            "1555 GB/s, L2 cache: 40 MB");
-
-  EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey(
-                device_description("a100_sxm_80.txtpb")),
-            "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: "
-            "2039 GB/s, L2 cache: 40 MB");
-
-  EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey(
-                device_description("mi200.txtpb")),
-            "ROCM: gfx90a, Cores: 110, GPU clock: 1.7 GHz, Memory bandwidth: "
-            "1638 GB/s, L2 cache: 8 MB");
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD
new file mode 100644
index 0000000..aa82b86
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/BUILD
@@ -0,0 +1,542 @@
+# Description:
+#   Components that implement GPU autotuning.
+
+load(
+    "@local_tsl//tsl/platform:build_config.bzl",
+    "tf_proto_library",
+)
+load(
+    "@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
+    "if_cuda_is_configured",
+)
+load("//xla:xla.bzl", "xla_cc_test")
+load("//xla/tests:build_defs.bzl", "xla_test")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = [":friends"],
+    licenses = ["notice"],
+)
+
+package_group(
+    name = "friends",
+    includes = [
+        "//xla:friends",
+    ],
+)
+
+cc_library(
+    name = "gemm_fusion_autotuner",
+    srcs = ["gemm_fusion_autotuner.cc"],
+    hdrs = ["gemm_fusion_autotuner.h"],
+    tags = [
+        "gpu",
+        "no_rocm",
+    ],
+    deps = [
+        ":autotuner_compile_util",
+        ":autotuner_util",
+        "//xla:autotuning_proto_cc",
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/pjrt/distributed:key_value_store_interface",
+        "//xla/service:algorithm_util",
+        "//xla/service:dump",
+        "//xla/service:executable",
+        "//xla/service:float_normalization",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_pass",
+        "//xla/service:shaped_buffer",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:buffer_comparator",
+        "//xla/service/gpu:gpu_float_support",
+        "//xla/service/gpu:hlo_traversal",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu:split_k_gemm_rewriter",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+        "//xla/service/gpu/transforms:cudnn_fusion_compiler",
+        "//xla/service/gpu/transforms:fusion_wrapper",
+        "//xla/service/gpu/transforms:gemm_rewriter",
+        "//xla/service/gpu/transforms:instruction_fusion",
+        "//xla/service/gpu/transforms:priority_fusion",
+        "//xla/stream_executor",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:device_memory",
+        "//xla/stream_executor:stream_executor_memory_allocator",
+        "//xla/stream_executor/gpu:redzone_allocator",
+        "//xla/tools:hlo_decomposer_lib",
+        "//xla/tsl/util/proto:proto_utils",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/time",
+        "@com_google_absl//absl/types:span",
+        "@local_config_cuda//cuda:cuda_headers",
+        "@local_tsl//tsl/lib/core:bits",
+        "@local_tsl//tsl/platform:blocking_counter",
+        "@local_tsl//tsl/platform:env",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:path",
+        "@local_tsl//tsl/platform:protobuf",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/profiler/lib:scoped_annotation",
+    ],
+)
+
+xla_test(
+    name = "gemm_fusion_autotuner_test",
+    timeout = "long",
+    srcs = ["gemm_fusion_autotuner_test.cc"],
+    backend_tags = {"gpu": [
+        "requires-gpu-sm80",
+    ]},
+    backends = [
+        "gpu",
+    ],
+    tags = [
+        "no_rocm",
+        "nomac",
+    ],
+    deps = [
+        ":autotuner_compile_util",
+        ":autotuner_util",
+        ":gemm_fusion_autotuner",
+        "//xla:autotuning_proto_cc",
+        "//xla:error_spec",
+        "//xla:xla_data_proto_cc",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/pjrt/distributed:key_value_store_interface",
+        "//xla/service:call_inliner",
+        "//xla/service:dump",
+        "//xla/service:executable",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_pass_pipeline",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu/transforms:gemm_fusion",
+        "//xla/service/gpu/transforms:gemm_rewriter",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:device_description_proto_cc",
+        "//xla/tests:filecheck",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:test_utils",
+        "//xla/tests:verified_hlo_module",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "//xla/tools:hlo_decomposer_lib",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_googletest//:gtest",
+        "@local_config_cuda//cuda:cuda_headers",
+        "@local_tsl//tsl/platform:env",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:path",
+        "@local_tsl//tsl/platform:platform_port",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "gemm_algorithm_picker",
+    srcs = ["gemm_algorithm_picker.cc"],
+    hdrs = ["gemm_algorithm_picker.h"],
+    tags = ["gpu"],
+    deps = [
+        ":autotuner_compile_util",
+        ":autotuner_util",
+        "//xla:autotune_results_proto_cc",
+        "//xla:autotuning_proto_cc",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:buffer_comparator",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/service/gpu:variant_visitor",
+        "//xla/stream_executor",
+        "//xla/stream_executor:blas",
+        "//xla/stream_executor:device_memory",
+        "//xla/stream_executor:device_memory_allocator",
+        "//xla/stream_executor/gpu:redzone_allocator",
+        "//xla/tsl/util/proto:proto_utils",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/profiler/lib:scoped_annotation",
+    ],
+)
+
+cc_library(
+    name = "autotuner_util",
+    srcs = ["autotuner_util.cc"],
+    hdrs = ["autotuner_util.h"],
+    tags = ["gpu"],
+    deps = [
+        "//xla:autotune_results_proto_cc",
+        "//xla:autotuning_proto_cc",
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla:util",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:dump",
+        "//xla/service/gpu:gpu_asm_opts_util",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:device_memory",
+        "//xla/stream_executor:device_memory_allocator",
+        "//xla/stream_executor:stream",
+        "//xla/stream_executor:stream_executor_h",
+        "//xla/stream_executor:stream_executor_memory_allocator",
+        "//xla/stream_executor/gpu:redzone_allocator",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
+        "@llvm-project//llvm:Support",
+        "@local_tsl//tsl/platform:base64",
+        "@local_tsl//tsl/platform:env",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:path",
+        "@local_tsl//tsl/platform:protobuf",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+# We need a separate target, as the runtime executable cannot depend on the
+# compilation pipeline.
+cc_library(
+    name = "autotuner_compile_util",
+    srcs = ["autotuner_compile_util.cc"],
+    hdrs = ["autotuner_compile_util.h"],
+    tags = ["gpu"],
+    deps = [
+        ":autotuner_util",
+        "//xla:executable_run_options",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:compiler",
+        "//xla/service:executable",
+        "//xla/service:maybe_owning_device_memory",
+        "//xla/service:shaped_buffer",
+        "//xla/service/gpu:gpu_executable_run_options",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/stream_executor",
+        "//xla/stream_executor/gpu:redzone_allocator",
+        "@com_google_absl//absl/functional:any_invocable",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/time",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_test(
+    name = "autotuner_compile_util_test",
+    srcs = ["autotuner_compile_util_test.cc"],
+    backends = ["gpu"],
+    deps = [
+        ":autotuner_compile_util",
+        ":autotuner_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:platform_util",
+        "//xla/stream_executor:platform",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_test(
+    name = "gemm_algorithm_picker_test",
+    srcs = ["gemm_algorithm_picker_test.cc"],
+    backends = [
+        "gpu_v100",
+        "gpu_amd_any",
+    ],
+    deps = [
+        ":autotuner_util",
+        ":gemm_algorithm_picker",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service:platform_util",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:variant_visitor",
+        "//xla/service/gpu/transforms:gemm_rewriter",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:platform",
+        "//xla/tests:hlo_test_base",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+        "@local_tsl//tsl/protobuf:dnn_proto_cc",
+    ],
+)
+
+cc_library(
+    name = "conv_algorithm_picker",
+    srcs = ["conv_algorithm_picker.cc"],
+    hdrs = ["conv_algorithm_picker.h"],
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+    tags = ["gpu"],
+    deps = [
+        ":autotuner_compile_util",
+        ":autotuner_util",
+        ":gpu_autotuning_proto_cc",
+        "//xla:autotune_results_proto_cc",
+        "//xla:autotuning_proto_cc",
+        "//xla:debug_options_flags",
+        "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:executable",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_pass",
+        "//xla/service:slow_operation_alarm",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:gpu_asm_opts_util",
+        "//xla/service/gpu:gpu_conv_runner",
+        "//xla/service/gpu:hlo_algorithm_denylist",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/stream_executor",
+        "//xla/stream_executor:device_memory_allocator",
+        "//xla/stream_executor:dnn",
+        "//xla/stream_executor:lazy_op_runner",
+        "//xla/stream_executor:numeric_options",
+        "//xla/stream_executor:scratch_allocator",
+        "//xla/stream_executor/cuda:cuda_platform_id",
+        "//xla/stream_executor/rocm:rocm_platform_id",
+        "//xla/tsl/util:env_var",
+        "//xla/tsl/util/proto:proto_utils",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/time",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:numbers",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+    ] + if_cuda_is_configured([
+        # keep sorted
+        "//xla/service/gpu:buffer_comparator",
+        "//xla/stream_executor/gpu:redzone_allocator",
+        "@local_config_cuda//cuda:cudnn_header",
+    ]),
+)
+
+xla_test(
+    name = "conv_algorithm_picker_test",
+    srcs = ["conv_algorithm_picker_test.cc"],
+    backends = [
+        "gpu_v100",
+        "gpu_amd_any",
+    ],
+    tags = [
+        "noasan",
+        "nomsan",
+    ],
+    deps = [
+        ":autotuner_util",
+        ":conv_algorithm_picker",
+        "//xla:debug_options_flags",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service:platform_util",
+        "//xla/service:tuple_simplifier",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/service/gpu/transforms:conv_rewriter",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:platform",
+        "//xla/tests:hlo_test_base",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "custom_kernel_fusion_autotuner",
+    srcs = ["custom_kernel_fusion_autotuner.cc"],
+    hdrs = ["custom_kernel_fusion_autotuner.h"],
+    tags = ["gpu"],
+    deps = [
+        ":autotuner_compile_util",
+        ":autotuner_util",
+        "//xla:autotuning_proto_cc",
+        "//xla:status_macros",
+        "//xla:util",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:executable",
+        "//xla/service:hlo_pass",
+        "//xla/service:shaped_buffer",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:buffer_comparator",
+        "//xla/service/gpu:gpu_float_support",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/service/gpu:hlo_traversal",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu:split_k_gemm_rewriter",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/service/gpu/kernels:custom_kernel",
+        "//xla/service/gpu/kernels:custom_kernel_fusion",
+        "//xla/stream_executor",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:stream_executor_memory_allocator",
+        "//xla/stream_executor/gpu:redzone_allocator",
+        "//xla/tools:hlo_decomposer_lib",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/time",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:path",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_test(
+    name = "custom_kernel_fusion_autotuner_test",
+    srcs = ["custom_kernel_fusion_autotuner_test.cc"],
+    backends = [
+        "gpu",
+    ],
+    tags = ["no_rocm"],
+    deps = [
+        ":autotuner_util",
+        ":custom_kernel_fusion_autotuner",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass_pipeline",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:path",
+        "@local_tsl//tsl/platform:test",
+    ],
+)
+
+tf_proto_library(
+    name = "gpu_autotuning_proto",
+    srcs = ["gpu_autotuning.proto"],
+    cc_api_version = 2,
+    protodeps = [
+        "//xla/service/gpu:backend_configs",
+        "//xla:xla_data_proto",
+        "//xla/service:hlo_proto",
+        "//xla:autotuning_proto",
+    ],
+)
+
+xla_cc_test(
+    name = "autotuner_util_test",
+    srcs = ["autotuner_util_test.cc"],
+    data = [
+        "//xla/tools/hlo_opt:gpu_specs/a100_sxm_40.txtpb",
+        "//xla/tools/hlo_opt:gpu_specs/a100_sxm_80.txtpb",
+        "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb",
+    ],
+    tags = [
+        "gpu",
+        "no_rocm",
+    ],
+    deps = [
+        ":autotuner_util",
+        "//xla:autotune_results_proto_cc",
+        "//xla:autotuning_proto_cc",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:dump",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:device_description_proto_cc",
+        "//xla/stream_executor:platform",
+        "//xla/stream_executor:platform_manager",
+        "//xla/stream_executor/host:host_platform",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:verified_hlo_module",
+        "//xla/tests:xla_internal_test_main",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/base:log_severity",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/log:scoped_mock_log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:env",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:path",
+        "@local_tsl//tsl/platform:protobuf",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc
new file mode 100644
index 0000000..9922ea2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc
@@ -0,0 +1,284 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+
+#include <cstdint>
+#include <iterator>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "xla/executable_run_options.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/compiler.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/gpu_executable_run_options.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/maybe_owning_device_memory.h"
+#include "xla/service/service_executable_run_options.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/util.h"
+#include "xla/xla.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+std::vector<ExecutionInput> ExecutionInputsFromBuffers(
+    absl::Span<se::DeviceMemoryBase const> buffers,
+    absl::Span<Shape const> shapes) {
+  CHECK_EQ(buffers.size(), shapes.size());
+  std::vector<ExecutionInput> inputs;
+  for (int i = 0; i < buffers.size(); ++i) {
+    inputs.emplace_back(shapes.at(i));
+    // Our executable doesn't have input-output aliasing, so we can pass
+    // unowned input buffers.
+    inputs.back().SetUnownedBuffer(
+        /*index=*/{}, MaybeOwningDeviceMemory(/*unowned=*/buffers.at(i)));
+  }
+  return inputs;
+}
+
+}  // namespace
+
+AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config,
+                                           Compiler* compiler,
+                                           se::StreamExecutor& stream_executor,
+                                           se::Stream& stream,
+                                           se::DeviceMemoryAllocator& allocator,
+                                           const DebugOptions& opts)
+    : config_(config),
+      compiler_(compiler),
+      stream_executor_(stream_executor),
+      stream_(stream),
+      allocator_(allocator),
+      opts_(opts) {
+  // Avoid dumping compilation steps.
+  opts_.set_xla_enable_dumping(false);
+  opts_.set_xla_gpu_dump_autotune_results_to("");
+  opts_.set_xla_gpu_load_autotune_results_from("");
+  opts_.set_xla_gpu_dump_llvmir(false);
+  opts_.set_xla_gpu_dump_autotune_logs_to("");
+  // Avoid using another thread pool.
+  opts_.set_xla_gpu_force_compilation_parallelism(1);
+  opts_.set_xla_gpu_enable_llvm_module_compilation_parallelism(false);
+  // Avoid using GPU graphs as we don't want to measure graph construction time.
+  opts_.clear_xla_gpu_enable_command_buffer();
+  opts_.set_xla_embed_ir_in_executable(false);
+  opts_.set_xla_gpu_kernel_cache_file("");
+}
+
+absl::StatusOr<std::optional<AutotunerCompileUtil::ProfilingOutput>>
+AutotunerCompileUtil::ProfileExecutable(
+    Executable* executable, se::Stream* stream,
+    absl::Span<se::DeviceMemoryBase const> input_buffers,
+    absl::Span<Shape const> input_shapes) {
+  {
+    std::vector<ExecutionInput> execution_inputs =
+        ExecutionInputsFromBuffers(input_buffers, input_shapes);
+    // Warm-up run: input and output buffers are reused while probing
+    // different configs, so the GPU caches are in comparable states during
+    // the measurements.
+    absl::StatusOr<ExecutionOutput> execution_output =
+        Execute(*executable, std::move(execution_inputs));
+    if (!execution_output.ok()) {
+      // Handle register allocation errors gracefully. If compilation happens
+      // in the driver during execution, the error can surface here; checking
+      // for it once, at this point, is sufficient.
+      if (execution_output.status().code() ==
+          absl::StatusCode::kResourceExhausted) {
+        return {std::nullopt};
+      }
+      return execution_output.status();
+    }
+
+    TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
+  }
+  std::vector<ExecutionInput> execution_inputs =
+      ExecutionInputsFromBuffers(input_buffers, input_shapes);
+  ExecutionProfile profile;
+  // Flag that a warm-up run was executed so that GpuTimer can use the more
+  // accurate delay-kernel implementation.
+  profile.set_warmup_run_executed(true);
+  TF_ASSIGN_OR_RETURN(
+      ExecutionOutput execution_output,
+      Execute(*executable, std::move(execution_inputs), &profile));
+  return std::make_optional<ProfilingOutput>(
+      absl::Nanoseconds(profile.compute_time_ns()),
+      execution_output.Commit().ConsumeResult());
+}
+
+absl::StatusOr<std::unique_ptr<Executable>> AutotunerCompileUtil::Compile(
+    GenerateModuleFn extractor) {
+  absl::StatusOr<std::unique_ptr<HloModule>> new_hlo_module = extractor(opts_);
+  if (new_hlo_module.status().GetPayload(kUncompilableFusion).has_value()) {
+    // An incompatible split-k value is an example of an expected failure.
+    return std::unique_ptr<Executable>();
+  } else if (!new_hlo_module.status().ok()) {
+    return new_hlo_module.status();
+  }
+
+  absl::StatusOr<std::unique_ptr<Executable>> out = compiler_->RunBackend(
+      std::move(*new_hlo_module), &stream_executor_,
+      Compiler::CompileOptions{&allocator_, /*thread_pool=*/nullptr,
+                               /*layout_canonicalization_callback=*/{},
+                               /*is_autotuning_compilation=*/true});
+  if (out.status().code() == absl::StatusCode::kResourceExhausted ||
+      out.status().code() == absl::StatusCode::kCancelled) {
+    // Being out of shared memory budget or registers is an expected failure.
+    // Cancelling upon register spilling is also an expected failure.
+    return std::unique_ptr<Executable>();
+  }
+  return out;
+}
+
+absl::StatusOr<std::unique_ptr<HloModule>> AutotunerCompileUtil::ExtractModule(
+    GenerateModuleFn extractor) {
+  return extractor(opts_);
+}
+
+/*static*/ absl::StatusOr<std::optional<AutotunerCompileUtil>>
+AutotunerCompileUtil::Create(const AutotuneConfig& config,
+                             const DebugOptions& opts) {
+  if (config.IsDeviceless()) {
+    return std::nullopt;
+  }
+  se::StreamExecutor* stream_exec = config.GetExecutor();
+  se::DeviceMemoryAllocator* allocator = config.GetAllocator();
+  TF_ASSIGN_OR_RETURN(se::Stream* const stream, config.GetStream());
+  TF_ASSIGN_OR_RETURN(Compiler * compiler,
+                      Compiler::GetForPlatform(stream_exec->GetPlatform()));
+  return AutotunerCompileUtil(config, compiler, *stream_exec, *stream,
+                              *allocator, opts);
+}
+
+absl::StatusOr<ExecutionOutput> AutotunerCompileUtil::Execute(
+    Executable& executable, std::vector<ExecutionInput> arguments,
+    ExecutionProfile* profile) {
+  // Require exclusive GPU lock to prevent other runs during autotuning.
+  GpuExecutableRunOptions gpu_opts;
+  gpu_opts.set_requires_exclusive_lock_on_gpu();
+
+  ExecutableRunOptions run_options;
+  run_options.set_device_ordinal(stream_executor_.device_ordinal());
+  run_options.set_stream(&stream_);
+  run_options.set_allocator(&allocator_);
+  run_options.set_gpu_executable_run_options(&gpu_opts);
+  run_options.set_execution_profile(profile);
+  ServiceExecutableRunOptions service_run_options(run_options);
+  TF_ASSIGN_OR_RETURN(ExecutionOutput output,
+                      executable.ExecuteAsyncOnStreamWrapper(
+                          &service_run_options, std::move(arguments)));
+
+  return std::move(output);
+}
+
+absl::StatusOr<RedzoneBuffers> RedzoneBuffers::FromInstruction(
+    const HloInstruction& instruction, const AutotuneConfig& config,
+    const DebugOptions& debug_options, BuffersToCreate buffers_to_create) {
+  RedzoneBuffers buffers;
+
+  TF_ASSIGN_OR_RETURN(auto rz_allocator, AutotunerUtil::CreateRedzoneAllocator(
+                                             config, debug_options));
+  buffers.redzone_allocator_ =
+      std::make_unique<se::RedzoneAllocator>(std::move(rz_allocator));
+
+  int64_t rng_state = 0;
+
+  TF_RETURN_IF_ERROR(
+      buffers.CreateInputs(instruction, config, debug_options, rng_state));
+
+  if (buffers_to_create == BuffersToCreate::kAllInputsAllOutputs ||
+      buffers_to_create == BuffersToCreate::kAllInputsOutputsNoScratch) {
+    TF_RETURN_IF_ERROR(buffers.CreateOutputs(instruction, config, debug_options,
+                                             buffers_to_create, rng_state));
+  }
+
+  return buffers;
+}
+
+absl::Status RedzoneBuffers::CreateInputs(const HloInstruction& instruction,
+                                          const AutotuneConfig& config,
+                                          const DebugOptions& debug_options,
+                                          int64_t& rng_state) {
+  for (const auto* operand : instruction.operands()) {
+    TF_ASSIGN_OR_RETURN(
+        se::DeviceMemoryBase buf,
+        AutotunerUtil::CreateBuffer(*redzone_allocator_, operand->shape(),
+                                    config, rng_state));
+    input_buffers_.push_back(buf);
+    input_shapes_.push_back(operand->shape());
+  }
+  return absl::OkStatus();
+}
+
+absl::Status RedzoneBuffers::CreateOutputs(const HloInstruction& instruction,
+                                           const AutotuneConfig& config,
+                                           const DebugOptions& debug_options,
+                                           BuffersToCreate buffers_to_create,
+                                           int64_t& rng_state) {
+  if (!instruction.shape().IsTuple()) {
+    TF_ASSIGN_OR_RETURN(
+        se::DeviceMemoryBase buf,
+        AutotunerUtil::CreateBuffer(*redzone_allocator_, instruction.shape(),
+                                    config, rng_state));
+    output_buffers_.push_back(buf);
+    output_shape_ = instruction.shape();
+    return absl::OkStatus();
+  }
+
+  // The output is a tuple.
+
+  auto current_shape_it = instruction.shape().tuple_shapes().begin();
+  auto end = instruction.shape().tuple_shapes().end();
+  end -= buffers_to_create == kAllInputsAllOutputs ? 0 : 1;
+
+  output_shape_ = std::distance(current_shape_it, end) == 1
+                      ? *current_shape_it
+                      : ShapeUtil::MakeTupleShape(
+                            std::vector<Shape>{current_shape_it, end});
+
+  for (; current_shape_it < end; current_shape_it++) {
+    if (current_shape_it->IsTuple()) {
+      return Unimplemented("Nested tuples are unsupported by RedzoneBuffers.");
+    }
+    TF_ASSIGN_OR_RETURN(
+        se::DeviceMemoryBase buf,
+        AutotunerUtil::CreateBuffer(*redzone_allocator_, *current_shape_it,
+                                    config, rng_state));
+    output_buffers_.push_back(buf);
+  }
+  return absl::OkStatus();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h
new file mode 100644
index 0000000..02b1bfa
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h
@@ -0,0 +1,177 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_COMPILE_UTIL_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_COMPILE_UTIL_H_
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/functional/any_invocable.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/compiler.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/shaped_buffer.h"
+#include "xla/shape.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/util.h"
+#include "xla/xla.pb.h"
+
+namespace xla {
+namespace gpu {
+
+// Autotuning utils which require compiling fusions separately. Requires a
+// separate target, as runtime autotuning cannot perform compilation.
+class AutotunerCompileUtil {
+ public:
+  // The GenerateModuleFn must generate/extract a module using the provided
+  // debug options. Typically it should set the debug options on the extracted
+  // module before transforming it, so that the transforms can use them. In
+  // justified cases it may override some of the provided debug options.
+  using GenerateModuleFn =
+      absl::AnyInvocable<absl::StatusOr<std::unique_ptr<HloModule>>(
+          const DebugOptions&)>;
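+  //
+  // A minimal illustrative sketch (the `fusion` instruction and the
+  // `ExtractFusionModule` helper below are hypothetical, not part of this
+  // API): a GenerateModuleFn that applies the provided debug options to the
+  // extracted module before returning it:
+  //
+  //   GenerateModuleFn extractor =
+  //       [&](const DebugOptions& opts)
+  //           -> absl::StatusOr<std::unique_ptr<HloModule>> {
+  //     std::unique_ptr<HloModule> module = ExtractFusionModule(fusion);
+  //     module->mutable_config().set_debug_options(opts);
+  //     return module;
+  //   };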
+
+  // Creates a compile util for the platform associated with the stream
+  // executor of the given config.
+  //
+  // Returns an empty optional if the AutotuneConfig is deviceless, as
+  // autotuning is impossible in that case.
+  static absl::StatusOr<std::optional<AutotunerCompileUtil>> Create(
+      const AutotuneConfig& config, const DebugOptions& opts);
+
+  struct ProfilingOutput {
+    ProfilingOutput(absl::Duration duration, ScopedShapedBuffer&& buffer)
+        : duration(duration), output(std::move(buffer)) {}
+
+    absl::Duration duration;
+    ScopedShapedBuffer output;
+  };
+
+  // Performs a warm-up run and then a profiled run of `executable` on `stream`
+  // using the given input buffers and shapes.
+  //
+  // Returns the profiling output on success, `std::nullopt` on expected
+  // failure (e.g. running out of device resources), and a bad `Status` on
+  // unexpected failure.
+  absl::StatusOr<std::optional<ProfilingOutput>> ProfileExecutable(
+      Executable* executable, se::Stream* stream,
+      absl::Span<se::DeviceMemoryBase const> input_buffers,
+      absl::Span<Shape const> input_shapes);
+
+  // Generic method to compile a generated module from `extractor` in isolation.
+  //
+  // Returns:
+  //  - `nullptr` on *expected* failure,
+  //  - an `Executable` if everything goes fine,
+  //  - a bad `Status` on *unexpected* failure.
+  absl::StatusOr<std::unique_ptr<Executable>> Compile(
+      GenerateModuleFn extractor);
+
+  // Generic method to extract an HLO using the debug options of the
+  // AutotunerCompileUtil.
+  //
+  // Typically we can use Compile directly.
+  absl::StatusOr<std::unique_ptr<HloModule>> ExtractModule(
+      GenerateModuleFn extractor);
+
+ private:
+  AutotunerCompileUtil(const AutotuneConfig& config, Compiler* compiler,
+                       se::StreamExecutor& stream_executor, se::Stream& stream,
+                       se::DeviceMemoryAllocator& allocator,
+                       const DebugOptions& opts);
+
+  absl::StatusOr<ExecutionOutput> Execute(Executable& executable,
+                                          std::vector<ExecutionInput> arguments,
+                                          ExecutionProfile* profile = nullptr);
+
+  AutotuneConfig config_;
+  Compiler* compiler_;
+  se::StreamExecutor& stream_executor_;
+  se::Stream& stream_;
+  se::DeviceMemoryAllocator& allocator_;
+  DebugOptions opts_;
+};
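+
+// Illustrative usage sketch only (not a prescribed pattern; `config`,
+// `debug_opts`, and `MakeCandidateModule` are assumed to exist and are not
+// part of this header):
+//
+//   TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> util,
+//                       AutotunerCompileUtil::Create(config, debug_opts));
+//   if (util.has_value()) {
+//     TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+//                         util->Compile(MakeCandidateModule));
+//     // A null executable signals an expected compilation failure; otherwise
+//     // allocate input buffers and call util->ProfileExecutable(...).
+//   }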
+
+// A redzone allocator and a collection of buffers that store the inputs and
+// outputs of an HloInstruction. These are used when running the instruction
+// for autotuning.
+class RedzoneBuffers {
+ public:
+  enum BuffersToCreate {
+    // Create a buffer for all of the instruction's operands. The result shape
+    // is ignored.
+    kAllInputs = 0,
+    // Create a buffer for all of the instruction's operands and the entire
+    // result shape. If the result shape is a tuple, a separate buffer is
+    // created for each subshape.
+    kAllInputsAllOutputs = 1,
+    // Create a buffer for all of the instruction's operands and all of the
+    // subshapes of the result tuple, except for the last one. The last subshape
+    // is considered a scratch buffer and is assumed to be allocated elsewhere.
+    // If the result shape is not a tuple, this will create a buffer
+    // corresponding to the entire shape - equivalent to `kAllInputsAllOutputs`.
+    kAllInputsOutputsNoScratch = 2,
+  };
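+
+  // Worked example (shapes are illustrative only): for a result shape
+  // (f32[16], u8[64]), kAllInputs creates no output buffers,
+  // kAllInputsAllOutputs creates two, and kAllInputsOutputsNoScratch creates
+  // one, treating the trailing u8[64] subshape as scratch.
+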
+  static absl::StatusOr<RedzoneBuffers> FromInstruction(
+      const HloInstruction& instruction, const AutotuneConfig& config,
+      const DebugOptions& debug_options, BuffersToCreate buffers_to_create);
+
+  const std::vector<se::DeviceMemoryBase>& input_buffers() const {
+    return input_buffers_;
+  }
+  const std::vector<Shape>& input_shapes() const { return input_shapes_; }
+  const std::vector<se::DeviceMemoryBase>& output_buffers() const {
+    return output_buffers_;
+  }
+  const Shape& output_shape() const { return output_shape_; }
+  se::RedzoneAllocator& RedzoneAllocator() const { return *redzone_allocator_; }
+
+ private:
+  absl::Status CreateInputs(const HloInstruction& instruction,
+                            const AutotuneConfig& config,
+                            const DebugOptions& debug_options,
+                            int64_t& rng_state);
+
+  absl::Status CreateOutputs(const HloInstruction& instruction,
+                             const AutotuneConfig& config,
+                             const DebugOptions& debug_options,
+                             BuffersToCreate buffers_to_create,
+                             int64_t& rng_state);
+
+  std::unique_ptr<se::RedzoneAllocator> redzone_allocator_;
+  std::vector<se::DeviceMemoryBase> input_buffers_;
+  std::vector<Shape> input_shapes_;
+  std::vector<se::DeviceMemoryBase> output_buffers_;
+  Shape output_shape_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_COMPILE_UTIL_H_
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util_test.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util_test.cc
new file mode 100644
index 0000000..a8b9594
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util_test.cc
@@ -0,0 +1,196 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/platform_util.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+using AutotunerCompileUtilTest = HloTestBase;
+
+TEST_F(AutotunerCompileUtilTest, VerifyOutputNotATuple) {
+  constexpr absl::string_view kHlo = R"(
+HloModule hlo
+ENTRY main {
+  p0 = f32[2,2] parameter(0)
+  p1 = f32[4,4] parameter(1)
+  p2 = f32[6,6] parameter(2)
+  ROOT root = f32[1,2,3] custom-call(p0, p1, p2), custom_call_target="fake"
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHlo));
+
+  se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
+                          PlatformUtil::GetStreamExecutors(platform));
+
+  AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
+                                 GetDebugOptionsForTest()};
+
+  auto& root = *module->entry_computation()->root_instruction();
+
+  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb,
+                          RedzoneBuffers::FromInstruction(
+                              root, autotune_config, GetDebugOptionsForTest(),
+                              RedzoneBuffers::kAllInputs));
+
+  EXPECT_EQ(rzb.input_shapes().size(), 3);
+  EXPECT_EQ(rzb.input_buffers().size(), 3);
+  EXPECT_EQ(rzb.output_buffers().size(), 0);
+  EXPECT_NE(rzb.output_shape(), root.shape());
+
+  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb2,
+                          RedzoneBuffers::FromInstruction(
+                              root, autotune_config, GetDebugOptionsForTest(),
+                              RedzoneBuffers::kAllInputsAllOutputs));
+
+  EXPECT_EQ(rzb2.input_shapes().size(), 3);
+  EXPECT_EQ(rzb2.input_buffers().size(), 3);
+  EXPECT_EQ(rzb2.output_buffers().size(), 1);
+  EXPECT_EQ(rzb2.output_shape(), root.shape());
+
+  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb3,
+                          RedzoneBuffers::FromInstruction(
+                              root, autotune_config, GetDebugOptionsForTest(),
+                              RedzoneBuffers::kAllInputsOutputsNoScratch));
+
+  EXPECT_EQ(rzb3.input_shapes().size(), 3);
+  EXPECT_EQ(rzb3.input_buffers().size(), 3);
+  EXPECT_EQ(rzb3.output_buffers().size(), 1);
+  EXPECT_EQ(rzb3.output_shape(), root.shape());
+}
+
+TEST_F(AutotunerCompileUtilTest, VerifyOutputTupleOneElement) {
+  constexpr absl::string_view kHlo = R"(
+HloModule hlo
+ENTRY main {
+  p0 = f32[2,2] parameter(0)
+  p1 = f32[4,4] parameter(1)
+  p2 = f32[6,6] parameter(2)
+  ROOT root = (f32[1,2,3]) custom-call(p0, p1, p2), custom_call_target="fake"
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHlo));
+
+  se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
+                          PlatformUtil::GetStreamExecutors(platform));
+
+  AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
+                                 GetDebugOptionsForTest()};
+
+  auto& root = *module->entry_computation()->root_instruction();
+
+  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb,
+                          RedzoneBuffers::FromInstruction(
+                              root, autotune_config, GetDebugOptionsForTest(),
+                              RedzoneBuffers::kAllInputs));
+
+  EXPECT_EQ(rzb.input_shapes().size(), 3);
+  EXPECT_EQ(rzb.input_buffers().size(), 3);
+  EXPECT_EQ(rzb.output_buffers().size(), 0);
+  EXPECT_NE(rzb.output_shape(), root.shape());
+
+  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb2,
+                          RedzoneBuffers::FromInstruction(
+                              root, autotune_config, GetDebugOptionsForTest(),
+                              RedzoneBuffers::kAllInputsAllOutputs));
+
+  EXPECT_EQ(rzb2.input_shapes().size(), 3);
+  EXPECT_EQ(rzb2.input_buffers().size(), 3);
+  EXPECT_EQ(rzb2.output_buffers().size(), 1);
+  EXPECT_FALSE(rzb2.output_shape().IsTuple());
+  EXPECT_EQ(rzb2.output_shape(), root.shape().tuple_shapes(0));
+
+  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb3,
+                          RedzoneBuffers::FromInstruction(
+                              root, autotune_config, GetDebugOptionsForTest(),
+                              RedzoneBuffers::kAllInputsOutputsNoScratch));
+
+  EXPECT_EQ(rzb3.input_shapes().size(), 3);
+  EXPECT_EQ(rzb3.input_buffers().size(), 3);
+  EXPECT_EQ(rzb3.output_buffers().size(), 0);
+}
+
+TEST_F(AutotunerCompileUtilTest, VerifyOutputTupleTwoElements) {
+  constexpr absl::string_view kHlo = R"(
+HloModule hlo
+ENTRY main {
+  p0 = f32[2,2] parameter(0)
+  p1 = f32[4,4] parameter(1)
+  p2 = f32[6,6] parameter(2)
+  ROOT root = (f32[1,2,3], u8[1,2]) custom-call(p0, p1, p2), custom_call_target="fake"
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHlo));
+
+  se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
+                          PlatformUtil::GetStreamExecutors(platform));
+
+  AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
+                                 GetDebugOptionsForTest()};
+
+  auto& root = *module->entry_computation()->root_instruction();
+
+  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb,
+                          RedzoneBuffers::FromInstruction(
+                              root, autotune_config, GetDebugOptionsForTest(),
+                              RedzoneBuffers::kAllInputs));
+
+  EXPECT_EQ(rzb.input_shapes().size(), 3);
+  EXPECT_EQ(rzb.input_buffers().size(), 3);
+  EXPECT_EQ(rzb.output_buffers().size(), 0);
+  EXPECT_NE(rzb.output_shape(), root.shape());
+
+  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb2,
+                          RedzoneBuffers::FromInstruction(
+                              root, autotune_config, GetDebugOptionsForTest(),
+                              RedzoneBuffers::kAllInputsAllOutputs));
+
+  EXPECT_EQ(rzb2.input_shapes().size(), 3);
+  EXPECT_EQ(rzb2.input_buffers().size(), 3);
+  EXPECT_EQ(rzb2.output_buffers().size(), 2);
+  EXPECT_TRUE(rzb2.output_shape().IsTuple());
+  EXPECT_EQ(rzb2.output_shape(), root.shape());
+
+  TF_ASSERT_OK_AND_ASSIGN(RedzoneBuffers rzb3,
+                          RedzoneBuffers::FromInstruction(
+                              root, autotune_config, GetDebugOptionsForTest(),
+                              RedzoneBuffers::kAllInputsOutputsNoScratch));
+
+  EXPECT_EQ(rzb3.input_shapes().size(), 3);
+  EXPECT_EQ(rzb3.input_buffers().size(), 3);
+  EXPECT_EQ(rzb3.output_buffers().size(), 1);
+  EXPECT_FALSE(rzb3.output_shape().IsTuple());
+  EXPECT_EQ(rzb3.output_shape(), root.shape().tuple_shapes(0));
+}
+
+}  // namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc
new file mode 100644
index 0000000..79bb744
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc
@@ -0,0 +1,551 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <cstdint>
+#include <limits>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+
+#include "absl/base/const_init.h"
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/SHA256.h"
+#include "xla/autotune_results.pb.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/dump.h"
+#include "xla/service/gpu/gpu_asm_opts_util.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/util.h"
+#include "tsl/platform/base64.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/path.h"
+#include "tsl/platform/protobuf.h"  // IWYU pragma: keep
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Bump this version whenever you change the structure of the results.
+// LINT.IfChange(version)
+constexpr int kVersion = 3;
+// LINT.ThenChange()
+
+}  // namespace
+
+using AutotuneCacheMap = absl::flat_hash_map<AutotuneCacheKey, AutotuneResult>;
+
+static absl::Mutex autotune_cache_mu(absl::kConstInit);
+static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) =
+    *new AutotuneCacheMap();
+
+absl::StatusOr<std::string> GetBase64EncodedSha256Hash(absl::string_view s) {
+  llvm::SHA256 sha256;
+  sha256.update(llvm::StringRef(s));
+  std::array<uint8_t, 32> hash = sha256.final();
+  // C++ strict aliasing rules allow reinterpret casting to (const) char*.
+  absl::string_view hash_view(reinterpret_cast<const char*>(hash.data()),
+                              hash.size());
+  std::string base64_encoded_hash;
+  TF_RETURN_IF_ERROR(tsl::Base64Encode(hash_view, &base64_encoded_hash));
+  return base64_encoded_hash;
+}
+
+namespace {
+
+// Get the path corresponding to the given key.
+absl::StatusOr<std::string> GetCacheFilePath(absl::string_view cache_dir,
+                                             const AutotuneCacheKey& key) {
+  if (cache_dir.empty()) {
+    return absl::InvalidArgumentError("autotune_cache_dir should not be empty");
+  }
+
+  TF_ASSIGN_OR_RETURN(std::string key_hash,
+                      GetBase64EncodedSha256Hash(key.ToString()));
+  return tsl::io::JoinPath(cache_dir, absl::StrCat(key_hash, ".textproto"));
+}
+
+struct ResultAndInserted {
+  // The result that ended up in the cache. This is the existing result if
+  // inserted is false, and the new result if inserted is true.
+  //
+  // We return a value, not a pointer, for thread safety reasons.
+  AutotuneResult result;
+  // Did we insert the given result into the cache?
+  bool inserted;
+};
+
+ResultAndInserted AddResultToInMemoryCache(const AutotuneCacheKey& key,
+                                           AutotuneResult result)
+    ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
+  absl::MutexLock lock(&autotune_cache_mu);
+  auto [it, inserted] = autotune_cache.emplace(key, std::move(result));
+  return {it->second, inserted};
+}
+
+absl::Status AddResultToFileBasedCacheIfEnabled(const AutotuneCacheKey& key,
+                                                AutotuneResult result,
+                                                std::string_view cache_dir)
+    ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
+  if (cache_dir.empty()) {
+    return absl::OkStatus();
+  }
+
+  tsl::Env* default_env = tsl::Env::Default();
+  TF_RETURN_IF_ERROR(CreateDirIfNeeded(std::string(cache_dir), default_env));
+
+  TF_ASSIGN_OR_RETURN(const std::string file_path,
+                      GetCacheFilePath(cache_dir, key));
+
+  VLOG(1) << "Writing autotune result to file: " << file_path;
+
+  std::string result_str;
+  if (!tsl::protobuf::TextFormat::PrintToString(result, &result_str)) {
+    return absl::InternalError("Failed to serialize autotune result.");
+  }
+
+  // Rename trick: Write to a temporary file, then rename it to the final file
+  // to avoid mingled files when multiple processes are writing to the same
+  // file. Also avoids reading incomplete files. (This may not work on all file
+  // systems.)
+  std::string temp_file_path = tsl::io::GetTempFilename(".textproto");
+  TF_RETURN_IF_ERROR(
+      tsl::WriteStringToFile(default_env, temp_file_path, result_str));
+  return default_env->RenameFile(temp_file_path, file_path);
+}
+
+absl::StatusOr<ResultAndInserted> AddResultToCaches(const AutotuneCacheKey& key,
+                                                    AutotuneResult result,
+                                                    std::string_view cache_dir)
+    ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
+  ResultAndInserted result_and_inserted = AddResultToInMemoryCache(key, result);
+  if (result_and_inserted.inserted) {
+    TF_RETURN_IF_ERROR(AddResultToFileBasedCacheIfEnabled(
+        key, result_and_inserted.result, cache_dir));
+  }
+  return result_and_inserted;
+}
+
+std::optional<AutotuneResult> TryToFindInInMemoryCache(
+    const AutotuneCacheKey& key) ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
+  absl::MutexLock lock(&autotune_cache_mu);
+  auto it = autotune_cache.find(key);
+  if (it == autotune_cache.end()) {
+    return std::nullopt;
+  }
+  return it->second;
+}
+
+absl::StatusOr<std::optional<AutotuneResult>>
+TryToFindInFileBasedCacheIfEnabled(const AutotuneCacheKey& key,
+                                   absl::string_view cache_dir)
+    ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
+  if (cache_dir.empty()) {
+    return std::nullopt;
+  }
+
+  TF_ASSIGN_OR_RETURN(const std::string file_path,
+                      GetCacheFilePath(cache_dir, key));
+  if (!tsl::Env::Default()->FileExists(file_path).ok()) {
+    VLOG(1) << "Autotune result file not found: " << file_path;
+    return std::nullopt;
+  }
+
+  VLOG(1) << "Autotune result file found: " << file_path;
+  std::string autotune_result_str;
+  TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), file_path,
+                                           &autotune_result_str));
+  AutotuneResult result;
+  if (!tsl::protobuf::TextFormat::ParseFromString(autotune_result_str,
+                                                  &result)) {
+    return absl::InvalidArgumentError("Failed to parse autotune result.");
+  }
+  return result;
+}
+
+// Sort the results so that they're deterministic.
+void SortAutotuneResults(AutotuneResults* results) {
+  std::sort(results->mutable_results()->pointer_begin(),
+            results->mutable_results()->pointer_end(),
+            [](const auto* a, const auto* b) {
+              return std::make_pair(absl::string_view(a->device()),
+                                    absl::string_view(a->hlo())) <
+                     std::make_pair(absl::string_view(b->device()),
+                                    absl::string_view(b->hlo()));
+            });
+}
+
+}  // namespace
+
+// Serialize `results` to string as a proto.
+absl::StatusOr<std::string> AutotuneResultsToString(
+    const AutotuneResults& results, bool as_textproto) {
+  if (as_textproto) {
+    std::string textproto;
+    if (tsl::protobuf::TextFormat::PrintToString(results, &textproto)) {
+      return textproto;
+    } else {
+      return Internal("Failed to serialize autotune results.");
+    }
+  }
+  return results.SerializeAsString();
+}
+
+namespace {
+// Serialize a single entry to `results`.
+void SerializeAutotuneEntry(AutotuneResults* results, const AutotuneCacheKey& k,
+                            const AutotuneResult* res) {
+  auto& entry = *results->add_results();
+  entry.set_device(std::string(k.GetModelStr()));
+  entry.set_hlo(std::string(k.GetHlo()));
+  *entry.mutable_result() = *res;
+}
+}  // namespace
+
+/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResults(
+    AutotuneResults* results) {
+  absl::MutexLock lock(&autotune_cache_mu);
+  for (const auto& [k, result] : autotune_cache) {
+    SerializeAutotuneEntry(results, k, &result);
+  }
+
+  results->set_version(kVersion);
+  SortAutotuneResults(results);
+
+  return absl::OkStatus();
+}
+
+/*static*/ absl::Status AutotunerUtil::LoadAutotuneResults(
+    const AutotuneResults& results) {
+  absl::MutexLock lock(&autotune_cache_mu);
+  for (const AutotuneResults::Entry& result : results.results()) {
+    if (auto [it, inserted] = autotune_cache.emplace(
+            AutotuneCacheKey(result.device(), result.hlo()), result.result());
+        !inserted) {
+      return absl::InternalError(absl::StrCat(
+          "Duplicate autotuning result for ", it->first.ToString()));
+    }
+  }
+  return absl::OkStatus();
+}
+
+/*static*/ void AutotunerUtil::ClearAutotuneResults() {
+  absl::MutexLock lock(&autotune_cache_mu);
+  autotune_cache.clear();
+}
+
+/*static*/ bool AutotunerUtil::ResultCacheIsEmpty() {
+  absl::MutexLock lock(&autotune_cache_mu);
+  return autotune_cache.empty();
+}
+
+/* static*/ absl::StatusOr<se::DeviceMemoryBase> AutotunerUtil::CreateBuffer(
+    se::RedzoneAllocator& allocator, const Shape& shape,
+    const AutotuneConfig& config, int64_t& rng_state) {
+  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
+                      allocator.AllocateBytes(ShapeUtil::ByteSizeOf(shape)));
+  if (config.should_init_buffers()) {
+    InitializeBuffer(allocator.stream(), shape.element_type(), &rng_state,
+                     buffer);
+  }
+  return buffer;
+}
+
+namespace {
+std::string ToCanonicalString(const HloInstruction* instr) {
+  auto options = HloPrintOptions::Canonical();
+  if (instr->opcode() != HloOpcode::kFusion) {
+    options.set_print_backend_config(true);
+    options.set_sort_backend_config(true);
+    return instr->ToString(options);
+  }
+  options.set_print_subcomputation_mode(
+      HloPrintOptions::PrintSubcomputationMode::kOff);
+  options.set_print_infeed_outfeed_config(false);
+  options.set_print_only_essential_constants(true);
+  options.set_print_operand_shape(true);
+  options.set_print_ids(false);
+  options.set_canonicalize_computations(true);
+
+  // TODO(b/266210099): This is unsound. We should probably do the fingerprint
+  // of the HLO computation proto instead.
+  return instr->called_computations()[0]->ToString(options);
+}
+
+}  // namespace
+
+AutotuneCacheKey::AutotuneCacheKey(absl::string_view model_str,
+                                   const HloInstruction& instr)
+    : AutotuneCacheKey(model_str, ToCanonicalString(&instr)) {}
+
+/*static*/ std::string AutotuneCacheKey::DeviceDescriptionToCacheKey(
+    const se::DeviceDescription& device_description) {
+  std::string compute_capability;
+  if (auto* ccc = std::get_if<se::CudaComputeCapability>(
+          &device_description.gpu_compute_capability())) {
+    compute_capability = absl::StrCat("CUDA: ", ccc->major, ".", ccc->minor);
+  } else {
+    auto* rcc = std::get_if<se::RocmComputeCapability>(
+        &device_description.gpu_compute_capability());
+    CHECK(rcc != nullptr) << "Unknown compute capability type";
+    compute_capability = absl::StrCat("ROCM: ", rcc->gfx_version());
+  }
+
+  // The string below should include only as much information as is needed to
+  // make it a valid key. Information that should not be included is:
+  // - specs that are directly derivable from the compute capability, e.g.
+  //   shared memory size. For NVIDIA GPUs, you can see what is derivable from
+  //   the SM version here:
+  //   https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
+  // - specs that are irrelevant for autotuning. E.g. the total available memory
+  //   on a device is not relevant, because by itself, it does not affect the
+  //   performance of single kernels.
+  //
+  // See b/344573710 for some discussion.
+
+  double memory_bandwidth = device_description.memory_bandwidth() / 1e9;
+  // Round the memory bandwidth to make the final string nicer to read.
+  // This will also cause minute differences in bandwidth to yield the same
+  // cache key, but that's fine, since the difference is inconsequential.
+  memory_bandwidth = std::round(memory_bandwidth);
+
+  constexpr double kBytesPerMegabyte = 1 << 20;
+  double l2_cache_size = device_description.l2_cache_size() / kBytesPerMegabyte;
+
+  return absl::StrCat(compute_capability,
+                      ", Cores: ", device_description.core_count(),
+                      ", GPU clock: ", device_description.clock_rate_ghz(),
+                      " GHz, Memory bandwidth: ", memory_bandwidth,
+                      " GB/s, L2 cache: ", l2_cache_size, " MB");
+}
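+
+// Purely for illustration, a key produced by the format above could look like
+// (all numbers are hypothetical):
+//   "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s,
+//    L2 cache: 40 MB"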
+
+namespace {
+absl::StatusOr<std::optional<AutotuneResult>> TryFindInCache(
+    const AutotuneCacheKey& key, absl::string_view cache_dir)
+    ABSL_LOCKS_EXCLUDED(autotune_cache_mu) {
+  std::optional<AutotuneResult> opt_result = TryToFindInInMemoryCache(key);
+  if (opt_result.has_value()) {
+    // Check the higher verbosity level first: VLOG_IS_ON(2) implies
+    // VLOG_IS_ON(1), so the detailed message would otherwise never be logged.
+    if (VLOG_IS_ON(2)) {
+      LOG(INFO) << "In-memory autotune cache hit: key = " << key.ToString();
+    } else if (VLOG_IS_ON(1)) {
+      LOG(INFO) << "In-memory autotune cache hit";
+    }
+    return opt_result;
+  }
+
+  TF_ASSIGN_OR_RETURN(opt_result,
+                      TryToFindInFileBasedCacheIfEnabled(key, cache_dir));
+  if (opt_result.has_value()) {
+    AddResultToInMemoryCache(key, opt_result.value());
+
+    if (VLOG_IS_ON(2)) {
+      LOG(INFO) << "File-based autotune cache hit: key = " << key.ToString();
+    } else if (VLOG_IS_ON(1)) {
+      LOG(INFO) << "File-based autotune cache hit";
+    }
+    return opt_result;
+  }
+
+  if (VLOG_IS_ON(2)) {
+    LOG(INFO) << "Autotune cache miss: key = " << key.ToString();
+  } else if (VLOG_IS_ON(1)) {
+    LOG(INFO) << "Autotune cache miss";
+  }
+  return std::nullopt;
+}
+}  // namespace
+
+/*static*/ AutotuneCacheKey AutotunerUtil::GetKey(
+    const HloInstruction* instr, const AutotuneConfig& config) {
+  return AutotuneCacheKey(config.GetModelStr(), *instr);
+}
+
+/*static*/ absl::StatusOr<bool> AutotunerUtil::IsInCache(
+    const AutotuneCacheKey& key, const AutotuneConfig& config) {
+  TF_ASSIGN_OR_RETURN(std::optional<AutotuneResult> opt_res,
+                      TryFindInCache(key, config.autotune_cache_dir()));
+  return opt_res.has_value();
+}
+
+/*static*/ absl::StatusOr<bool> AutotunerUtil::AddResult(
+    const AutotuneCacheKey& key, AutotuneResult result,
+    const AutotuneConfig& config) {
+  TF_ASSIGN_OR_RETURN(
+      ResultAndInserted result_and_inserted,
+      AddResultToCaches(key, std::move(result), config.autotune_cache_dir()));
+  return result_and_inserted.inserted;
+}
+
+/*static*/ absl::StatusOr<AutotuneResult> AutotunerUtil::Autotune(
+    const HloInstruction* instr, const AutotuneConfig& config,
+    const AutotuneNoCacheFn& autotune_fn) {
+  const AutotuneCacheKey key = GetKey(instr, config);
+  TF_ASSIGN_OR_RETURN(std::optional<AutotuneResult> opt_res,
+                      TryFindInCache(key, config.autotune_cache_dir()));
+  if (opt_res.has_value()) {
+    return opt_res.value();
+  }
+
+  // Cache miss.
+  if (config.should_require_complete_aot_autotune_results()) {
+    return NotFound(
+        "Complete XLA AOT autotuning results are required, but no AOT result "
+        "was found for key: %s",
+        key.ToString());
+  }
+
+  TF_ASSIGN_OR_RETURN(AutotuneResult autotune_result, autotune_fn());
+
+  TF_ASSIGN_OR_RETURN(ResultAndInserted result_and_inserted,
+                      AddResultToCaches(key, std::move(autotune_result),
+                                        config.autotune_cache_dir()));
+  return result_and_inserted.result;
+}
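+
+// Illustrative call pattern only (`instr`, `config`, and the measurement
+// lambda below are placeholders, not part of this file):
+//
+//   TF_ASSIGN_OR_RETURN(
+//       AutotuneResult best,
+//       AutotunerUtil::Autotune(instr, config, [&]() {
+//         return MeasureCandidatesFor(instr);  // hypothetical measurement fn
+//       }));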
+
+namespace {
+
+bool IsTextProtoPath(absl::string_view file_path) {
+  return absl::EndsWith(file_path, ".txt") ||
+         absl::EndsWith(file_path, ".textproto") ||
+         absl::EndsWith(file_path, ".prototxt") ||
+         absl::EndsWith(file_path, ".pbtxt");
+}
+
+}  // anonymous namespace
+
+/*static*/ absl::Status AutotunerUtil::LoadAutotuneResults(
+    absl::string_view data, bool as_textproto) {
+  AutotuneResults results;
+  // The cast here is necessary for MacOS builds.
+  bool parse_success =
+      as_textproto ? tsl::protobuf::TextFormat::ParseFromString(
+                         std::string(data), &results)             // NOLINT
+                   : results.ParseFromString(std::string(data));  // NOLINT
+  if (!parse_success) {
+    return absl::InvalidArgumentError(
+        "Failed to parse autotune results string.");
+  }
+  if (results.version() != kVersion) {
+    return absl::InvalidArgumentError(absl::StrFormat(
+        "Version mismatch in autotune results. Expected %d but was %d",
+        kVersion, results.version()));
+  }
+
+  TF_RETURN_IF_ERROR(LoadAutotuneResults(results));
+  return absl::OkStatus();
+}
+
+/*static*/ absl::StatusOr<std::string> AutotunerUtil::SerializeAutotuneResults(
+    bool as_textproto) {
+  AutotuneResults results;
+  TF_RETURN_IF_ERROR(SerializeAutotuneResults(&results));
+  return AutotuneResultsToString(results, as_textproto);
+}
+
+/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResultsToFile(
+    const AutotuneResults& results, absl::string_view file_path) {
+  TF_RET_CHECK(!file_path.empty());
+  TF_RET_CHECK(results.version() > 0)
+      << "Did you call SerializeAutotuneResults to get this AutotuneResults?";
+
+  std::string resolved_path;
+  if (!tsl::io::ResolveTestPrefixes(file_path, resolved_path)) {
+    return FailedPrecondition("File path can not be resolved: %s", file_path);
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      std::string autotune_results_str,
+      AutotuneResultsToString(results, IsTextProtoPath(resolved_path)));
+  TF_RETURN_IF_ERROR(tsl::WriteStringToFile(tsl::Env::Default(), resolved_path,
+                                            autotune_results_str));
+  LOG(INFO) << "Autotune results serialized to file: " << resolved_path;
+
+  return absl::OkStatus();
+}
+
+/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResultsToFile(
+    absl::string_view file_path) {
+  AutotuneResults results;
+  TF_RETURN_IF_ERROR(SerializeAutotuneResults(&results));
+  return SerializeAutotuneResultsToFile(results, file_path);
+}
+
+/*static*/ absl::Status AutotunerUtil::LoadAutotuneResultsFromFile(
+    absl::string_view file_path) {
+  TF_RET_CHECK(!file_path.empty());
+
+  std::string resolved_path;
+  if (!tsl::io::ResolveTestPrefixes(file_path, resolved_path)) {
+    return FailedPrecondition("File path can not be resolved: %s", file_path);
+  }
+
+  if (!tsl::Env::Default()->FileExists(resolved_path).ok()) {
+    return FailedPrecondition("Autotune results file does not exist: %s",
+                              resolved_path);
+  }
+  std::string autotune_results_str;
+  TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), resolved_path,
+                                           &autotune_results_str));
+
+  TF_RETURN_IF_ERROR(LoadAutotuneResults(autotune_results_str,
+                                         IsTextProtoPath(resolved_path)));
+
+  LOG(INFO) << "Autotune results loaded from file: " << resolved_path;
+
+  return absl::OkStatus();
+}
+
+/*static*/ absl::StatusOr<se::RedzoneAllocator>
+AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config,
+                                      const DebugOptions& opts) {
+  TF_ASSIGN_OR_RETURN(se::Stream * stream, config.GetStream());
+  return se::RedzoneAllocator(
+      stream, config.GetAllocator(), PtxOptsFromDebugOptions(opts),
+      /*memory_limit=*/std::numeric_limits<int64_t>::max(),
+      /*redzone_size=*/config.should_check_correctness()
+          ? opts.xla_gpu_redzone_padding_bytes()
+          : 0);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h
new file mode 100644
index 0000000..6d5c321
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h
@@ -0,0 +1,334 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_UTIL_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_UTIL_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "xla/autotune_results.pb.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/shape.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream_executor.h"
+#include "xla/stream_executor/stream_executor_memory_allocator.h"
+#include "xla/xla.pb.h"
+
+namespace xla {
+namespace gpu {
+
+struct DeviceConfig {
+  se::StreamExecutor* stream_exec;  // never null
+
+  // If the `allocator` parameter is not null, we will use it to allocate temp
+  // memory while timing the various convolution algorithms.  If it's null,
+  // we'll use the default allocator on the StreamExecutor.
+  se::DeviceMemoryAllocator* allocator = nullptr;  // may be null
+};
+
+struct DevicelessConfig {
+  // The device description of the target device.
+  se::DeviceDescription device_description;
+};
+
+class AutotuneCacheKey {
+ public:
+  AutotuneCacheKey(const se::DeviceDescription& device_description,
+                   const HloInstruction& instruction)
+      : AutotuneCacheKey(DeviceDescriptionToCacheKey(device_description),
+                         instruction.ToString()) {}
+
+  AutotuneCacheKey(absl::string_view model_str,
+                   const HloInstruction& instruction);
+
+  explicit AutotuneCacheKey(absl::string_view model_str,
+                            absl::string_view hlo_canonical)
+      : model_str_(model_str), hlo_canonical_(hlo_canonical) {}
+
+  absl::string_view GetModelStr() const { return model_str_; }
+
+  absl::string_view GetHlo() const { return hlo_canonical_; }
+
+  template <typename H>
+  friend H AbslHashValue(H h, const AutotuneCacheKey& w) {
+    return H::combine(std::move(h), w.model_str_, w.hlo_canonical_);
+  }
+
+  bool operator==(const AutotuneCacheKey& w) const {
+    return model_str_ == w.model_str_ && hlo_canonical_ == w.hlo_canonical_;
+  }
+
+  std::string ToString() const {
+    return absl::StrFormat("<key model='%s', hlo='%s'>", model_str_,
+                           hlo_canonical_);
+  }
+
+  static std::string DeviceDescriptionToCacheKey(
+      const se::DeviceDescription& device_description);
+
+ private:
+  std::string model_str_;
+  std::string hlo_canonical_;
+};
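+
+// For reference, the model string produced by DeviceDescriptionToCacheKey
+// looks like the device strings used in the tests later in this change, e.g.
+// "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s,
+// L2 cache: 40 MB"; the HLO part is the canonical text of the instruction
+// being autotuned.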
+
+class AutotuneConfig {
+ public:
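+  // The autotune levels below are cumulative: for example, level 4
+  // initializes buffers, reinitializes the output buffer, and checks
+  // correctness, while level 5 additionally skips wrong results.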
+  bool should_init_buffers() const { return autotune_level_ >= 2; }
+  bool should_reinit_output_buffer() const { return autotune_level_ >= 3; }
+  bool should_check_correctness() const { return autotune_level_ >= 4; }
+  bool should_skip_wrong_results() const { return autotune_level_ >= 5; }
+  bool should_crash_on_check_failure() const {
+    return should_crash_on_check_failure_;
+  }
+  bool should_require_complete_aot_autotune_results() const {
+    return require_complete_aot_autotune_results_;
+  }
+  // An empty string means no file-based cache is used.
+  const std::string& autotune_cache_dir() const { return autotune_cache_dir_; }
+
+  AutotuneConfig(const AutotuneConfig& right)
+      : config_(right.config_),
+        autotune_level_(right.autotune_level_),
+        should_crash_on_check_failure_(right.should_crash_on_check_failure_),
+        exhaustive_tiling_search_(right.exhaustive_tiling_search_),
+        require_complete_aot_autotune_results_(
+            right.require_complete_aot_autotune_results_),
+        autotune_cache_dir_(right.autotune_cache_dir_) {}
+
+  AutotuneConfig(const std::variant<DeviceConfig, DevicelessConfig>& config,
+                 const DebugOptions& debug_options)
+      : config_(config),
+        autotune_level_(debug_options.xla_gpu_autotune_level()),
+        should_crash_on_check_failure_(
+            debug_options.xla_gpu_crash_on_verification_failures()),
+        exhaustive_tiling_search_(
+            debug_options.xla_gpu_exhaustive_tiling_search()),
+        require_complete_aot_autotune_results_(
+            debug_options.xla_gpu_require_complete_aot_autotune_results()),
+        autotune_cache_dir_(
+            debug_options.xla_gpu_per_fusion_autotune_cache_dir()) {}
+
+  std::string GetModelStr() const {
+    if (auto deviceless_config = std::get_if<DevicelessConfig>(&config_)) {
+      return AutotuneCacheKey::DeviceDescriptionToCacheKey(
+          deviceless_config->device_description);
+    }
+
+    const auto& device_config = std::get<DeviceConfig>(config_);
+    return AutotuneCacheKey::DeviceDescriptionToCacheKey(
+        device_config.stream_exec->GetDeviceDescription());
+  }
+
+  se::StreamExecutor* GetExecutor() const {
+    CHECK(std::holds_alternative<DeviceConfig>(config_));
+    return std::get<DeviceConfig>(config_).stream_exec;
+  }
+
+  se::DeviceMemoryAllocator* GetAllocator() const {
+    CHECK(std::holds_alternative<DeviceConfig>(config_));
+    auto& cf = std::get<DeviceConfig>(config_);
+    if (cf.allocator != nullptr) {
+      return cf.allocator;
+    }
+    if (allocator_ == nullptr) {
+      allocator_ =
+          std::make_unique<se::StreamExecutorMemoryAllocator>(GetExecutor());
+    }
+    return allocator_.get();
+  }
+
+  absl::StatusOr<se::Stream*> GetStream() const {
+    CHECK(std::holds_alternative<DeviceConfig>(config_));
+    return GetAllocator()->GetStream(GetExecutor()->device_ordinal());
+  }
+
+  const se::GpuComputeCapability& GetGpuComputeCapability() const {
+    if (auto c = std::get_if<DeviceConfig>(&config_)) {
+      return c->stream_exec->GetDeviceDescription().gpu_compute_capability();
+    }
+    return std::get<DevicelessConfig>(config_)
+        .device_description.gpu_compute_capability();
+  }
+
+  bool IsDeviceless() const {
+    return std::holds_alternative<DevicelessConfig>(config_);
+  }
+
+  bool ExhaustiveTilingSearch() const { return exhaustive_tiling_search_; }
+
+ private:
+  std::variant<DeviceConfig, DevicelessConfig> config_;
+  int32_t autotune_level_;
+  bool should_crash_on_check_failure_;
+  bool exhaustive_tiling_search_;
+  bool require_complete_aot_autotune_results_;
+  mutable std::unique_ptr<se::DeviceMemoryAllocator> allocator_;
+  std::string autotune_cache_dir_;
+};
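+
+// A minimal construction sketch, mirroring the tests later in this change
+// (the cache directory value here is purely illustrative):
+//
+//   DebugOptions opts;
+//   opts.set_xla_gpu_per_fusion_autotune_cache_dir("/some/cache/dir");
+//   AutotuneConfig config(DeviceConfig{stream_exec}, opts);
+//
+// A DevicelessConfig can be used instead when no GPU is attached.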
+
+using AutotuneNoCacheFn = std::function<absl::StatusOr<AutotuneResult>()>;
+
+struct AutotunerUtil {
+  // Creates a buffer for a given operation using the redzone checker,
+  // initialized based on a given rng state.
+  static absl::StatusOr<se::DeviceMemoryBase> CreateBuffer(
+      se::RedzoneAllocator& allocator, const Shape& shape,
+      const AutotuneConfig& config, int64_t& rng_state);
+
+  static absl::StatusOr<AutotuneResult> Autotune(
+      const HloInstruction* instr, const AutotuneConfig& config,
+      const AutotuneNoCacheFn& autotune_fn);
+
+  // Returns the same cache key that would be used inside Autotune().
+  //
+  // Normally, we don't have to use this low-level method.
+  static AutotuneCacheKey GetKey(const HloInstruction* instr,
+                                 const AutotuneConfig& config);
+
+  // Checks if the key is in the autotune cache.
+  //
+  // Normally, we don't have to use this low-level method.
+  static absl::StatusOr<bool> IsInCache(const AutotuneCacheKey& key,
+                                        const AutotuneConfig& config);
+
+  // Adds the result to the autotune cache.
+  //
+  // Returns true if the entry is inserted.
+  //
+  // Normally, we don't have to use this low-level method.
+  static absl::StatusOr<bool> AddResult(const AutotuneCacheKey& key,
+                                        AutotuneResult result,
+                                        const AutotuneConfig& config);
+
+  // Creates a RedzoneAllocator from a given config.
+  static absl::StatusOr<se::RedzoneAllocator> CreateRedzoneAllocator(
+      const AutotuneConfig& config, const DebugOptions& opts);
+
+  // Functions to save/load XLA's autotuning results.
+  //
+  // This is used for ahead-of-time autotuning.  Specifically:
+  //
+  // When XLA calls cublas (for matmuls, aka "gemm" or "dot") or cudnn (for
+  // convolutions), it usually has to choose an "algorithm" for the particular
+  // dot/conv.  XLA queries cublas/cudnn for a list of candidate algorithms.
+  // Then it runs all of them and picks the fastest one.  This is what we call
+  // "autotuning". It happens in GemmAlgorithmPicker and GpuConvAlgorithmPicker.
+  //
+  // Autotuning is necessary to get good performance for dot/conv.  But it also
+  // has some disadvantages.
+  //
+  //  - Because it relies on timing data, it is fundamentally nondeterministic.
+  //    But even if two algorithms have similar runtimes, our choice of
+  //    algorithm may be visible to the user: Different algorithms can have
+  //    different numerics, and sometimes they can even have different bugs!
+  //
+  //  - Trying all the candidate algorithms can be slow, especially when some
+  //    of the candidates are "very bad" and run especially slowly compared to
+  //    the optimal candidate.  This slows down compilation.
+  //
+  // To address the disadvantages above, we allow users to save/restore the
+  // autotuning choices that XLA has made, using the functions below.
+  //
+  // Loading autotuning results does not erase existing autotuning choices, but
+  // in the event of a disagreement between the existing data and the new data,
+  // the new algorithm is chosen.
+  //
+  // Note that even if you call LoadAutotuneResults(), if XLA encounters a
+  // dot/conv that is *not* covered by the loaded data, it will go ahead and
+  // autotune it like normal.  In other words, the behavior of XLA should be
+  // identical with or without ahead-of-time autotuning, modulo nondeterminism.
+  //
+  // This is important if you want to be able to use the same autotuning file
+  // with different versions of XLA, because as XLA changes, exactly which
+  // dots/convs it wants to run can also change.  For example, XLA might change
+  // the conv padding heuristics it uses, and we don't want that to mean that
+  // all users of ahead-of-time autotuning are broken.
+  static absl::StatusOr<std::string> SerializeAutotuneResults(
+      bool as_textproto = false);
+
+  // Serializes autotune results into the given proto.
+  static absl::Status SerializeAutotuneResults(AutotuneResults* results);
+
+  // Loads autotune results from the given string of bytes.
+  //
+  // Warning: The results are only loaded to the in-memory cache.
+  static absl::Status LoadAutotuneResults(absl::string_view data,
+                                          bool as_textproto = false);
+
+  // Loads autotune results from the given proto.
+  //
+  // Warning: The results are only loaded to the in-memory cache.
+  static absl::Status LoadAutotuneResults(const AutotuneResults& results);
+
+  // Serializes autotune results into a file.
+  //
+  // If `file_path` ends with ".txt" or ".textproto", then the textproto format
+  // is used, otherwise the binary protobuf format.
+  static absl::Status SerializeAutotuneResultsToFile(
+      absl::string_view file_path);
+
+  // As above, but if you already called SerializeAutotuneResults to get a
+  // proto.
+  static absl::Status SerializeAutotuneResultsToFile(
+      const AutotuneResults& results, absl::string_view file_path);
+
+  // Loads autotune results from a file.
+  //
+  // If `file_path` ends with ".txt" or ".textproto", then the file is
+  // considered to be in the textproto format, otherwise the binary protobuf
+  // format.
+  //
+  // Warning: The results are only loaded to the in-memory cache.
+  static absl::Status LoadAutotuneResultsFromFile(absl::string_view file_path);
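+
+  // Illustrative AOT round trip (a sketch; the file name is only an example):
+  //
+  //   // In a build that is allowed to autotune:
+  //   TF_RETURN_IF_ERROR(
+  //       AutotunerUtil::SerializeAutotuneResultsToFile("autotune.textproto"));
+  //
+  //   // In a later build, before compiling, to reuse those choices:
+  //   TF_RETURN_IF_ERROR(
+  //       AutotunerUtil::LoadAutotuneResultsFromFile("autotune.textproto"));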
+
+  // Warning: This only clears the in-memory cache. If you use a file-based
+  // cache, you're responsible for clearing the cache directory when you want to.
+  static void ClearAutotuneResults();
+
+  // Warning: This only checks the in-memory cache. If you use a file-based
+  // cache, you're responsible for checking whether the cache directory is
+  // empty.
+  static bool ResultCacheIsEmpty();
+};
+
+absl::StatusOr<std::string> AutotuneResultsToString(
+    const AutotuneResults& results, bool as_textproto);
+
+// Exposed only for testing. Returns the SHA-256 hash of the input string,
+// encoded in base64.
+//
+// SHA-256 was chosen to follow industry best practices and avoid collisions.
+// Git is also transitioning to SHA-256. This is probably better than
+// tsl::Fingerprint128.
+absl::StatusOr<std::string> GetBase64EncodedSha256Hash(absl::string_view s);
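+
+// As exercised in the tests, the file-based cache stores each entry under this
+// hash of AutotuneCacheKey::ToString() with a ".textproto" suffix, e.g.
+// "<base64-sha256-of-key>.textproto" inside the configured cache directory.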
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_UTIL_H_
diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc
new file mode 100644
index 0000000..974f4d4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc
@@ -0,0 +1,456 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "xla/autotune_results.pb.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/dump.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/platform_manager.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/xla.pb.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"  // IWYU pragma: keep
+#include "tsl/platform/path.h"
+#include "tsl/platform/protobuf.h"  // IWYU pragma: keep
+#include "tsl/platform/status.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+using ::testing::Not;
+using ::testing::TempDir;
+using ::tsl::testing::StatusIs;
+
+class AutotunerUtilTest : public HloTestBase {
+ protected:
+  static constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f16[1,16,17,3] parameter(0)
+  p1 = s8[16,17,3] parameter(1)
+  cp1 = f16[16,17,3] convert(p1)
+  ROOT _ = f16[1,16,16] dot(p0, cp1),
+    lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
+})";
+
+  static constexpr absl::string_view kResultText = R"(
+version: 3
+results {
+  device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB"
+  hlo: "{\n  tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n  tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n  tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n  tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n  tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n  tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n  ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}"
+  result {
+    run_time {
+      nanos: 31744
+    }
+    triton {
+      block_m: 32
+      block_n: 32
+      block_k: 32
+      split_k: 1
+      num_stages: 1
+      num_warps: 4
+      num_ctas: 1
+    }
+  }
+})";
+
+  void SetUp() override { AutotunerUtil::ClearAutotuneResults(); }
+
+  std::string GetUniqueTempFilePath(absl::string_view suffix) {
+    std::string filename = TempDir();
+    CHECK(tsl::Env::Default()->CreateUniqueFileName(&filename,
+                                                    std::string(suffix)));
+    return filename;
+  }
+
+  std::string ExpectToReadNonEmptyFile(absl::string_view file_path) {
+    std::string str;
+    tsl::Env* env = tsl::Env::Default();
+    TF_EXPECT_OK(tsl::ReadFileToString(env, std::string(file_path), &str));
+    EXPECT_THAT(str, Not(IsEmpty()));
+    return str;
+  }
+
+  static stream_executor::StreamExecutor* NewStreamExecutor() {
+    stream_executor::Platform* platform =
+        stream_executor::PlatformManager::PlatformWithName("Host").value();
+    return platform->ExecutorForDevice(/*ordinal=*/0).value();
+  }
+
+  absl::Status PopulateResultCache() {
+    EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
+    TF_RETURN_IF_ERROR(AutotunerUtil::LoadAutotuneResults(kResultText, true));
+    EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
+    return absl::OkStatus();
+  }
+};
+
+TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_TextProto1) {
+  TF_EXPECT_OK(PopulateResultCache());
+  std::string kFilePath = GetUniqueTempFilePath(".txt");
+  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+
+  std::string autotune_results_str = ExpectToReadNonEmptyFile(kFilePath);
+  AutotuneResults results;
+  EXPECT_TRUE(tsl::protobuf::TextFormat::ParseFromString(autotune_results_str,
+                                                         &results));
+  EXPECT_GT(results.results_size(), 0);
+}
+
+TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_TextProto2) {
+  TF_EXPECT_OK(PopulateResultCache());
+  std::string kFilePath = GetUniqueTempFilePath(".textproto");
+  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+
+  std::string autotune_results_str = ExpectToReadNonEmptyFile(kFilePath);
+  AutotuneResults results;
+  EXPECT_TRUE(tsl::protobuf::TextFormat::ParseFromString(autotune_results_str,
+                                                         &results));
+}
+
+TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_Protobuf) {
+  TF_EXPECT_OK(PopulateResultCache());
+  std::string kFilePath = GetUniqueTempFilePath(".pb");
+  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+
+  std::string autotune_results_str = ExpectToReadNonEmptyFile(kFilePath);
+  AutotuneResults results;
+  EXPECT_TRUE(results.ParseFromString(autotune_results_str));
+}
+
+TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_TextProto1) {
+  TF_EXPECT_OK(PopulateResultCache());
+  std::string kFilePath = GetUniqueTempFilePath(".txt");
+  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+  AutotunerUtil::ClearAutotuneResults();
+  EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
+
+  TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath));
+  EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
+}
+
+TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_TextProto2) {
+  TF_EXPECT_OK(PopulateResultCache());
+  std::string kFilePath = GetUniqueTempFilePath(".textproto");
+  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+  AutotunerUtil::ClearAutotuneResults();
+  EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
+
+  TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath));
+  EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
+}
+
+TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_Protobuf) {
+  TF_EXPECT_OK(PopulateResultCache());
+  std::string kFilePath = GetUniqueTempFilePath(".pb");
+  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+  AutotunerUtil::ClearAutotuneResults();
+  EXPECT_TRUE(AutotunerUtil::ResultCacheIsEmpty());
+
+  TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath));
+  EXPECT_FALSE(AutotunerUtil::ResultCacheIsEmpty());
+}
+
+TEST_F(AutotunerUtilTest, ResultConflictsAreDetected) {
+  TF_EXPECT_OK(PopulateResultCache());
+  std::string kFilePath = GetUniqueTempFilePath(".pb");
+  TF_EXPECT_OK(AutotunerUtil::SerializeAutotuneResultsToFile(kFilePath));
+  EXPECT_THAT(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath),
+              StatusIs(absl::StatusCode::kInternal,
+                       HasSubstr("Duplicate autotuning result")));
+}
+
+// Test that when complete AOT autotuning is required and there is a cache
+// miss, a `NotFound` error will be raised.
+TEST_F(AutotunerUtilTest, FailIfRequireCompleteAotAutotuning) {
+  std::string kFilePath = GetUniqueTempFilePath(".txt");
+  auto hlo_module = GetOptimizedModule(kHloText);
+  TF_EXPECT_OK(hlo_module.status());
+  std::vector<HloComputation*> computations =
+      (*hlo_module)
+          ->MakeNonfusionComputations(absl::flat_hash_set<absl::string_view>());
+  EXPECT_THAT(computations, Not(IsEmpty()));
+  const HloInstruction* instruction = *computations[0]->instructions().begin();
+  stream_executor::StreamExecutor* executor = NewStreamExecutor();
+  auto options = DebugOptions();
+  options.set_xla_gpu_require_complete_aot_autotune_results(true);
+  AutotuneConfig config(DeviceConfig{executor}, options);
+  EXPECT_THAT(
+      AutotunerUtil::Autotune(instruction, config,
+                              [&] { return AutotuneResult(); }),
+      StatusIs(
+          absl::StatusCode::kNotFound,
+          HasSubstr("Complete XLA AOT autotuning results are required, but "
+                    "no AOT result was found for key: <key model")));
+}
+
+// Test that when JIT autotuning is disabled but there is no cache miss thanks
+// to AOT autotuning, `Autotune` still returns an OK status.
+TEST_F(AutotunerUtilTest, OkIfJitAutotuningDisabledButAlreadyLoadedAOT) {
+  auto hlo_module = GetOptimizedModule(kHloText);
+  std::vector<HloComputation*> computations =
+      (*hlo_module)
+          ->MakeNonfusionComputations(absl::flat_hash_set<absl::string_view>());
+  EXPECT_THAT(computations, Not(IsEmpty()));
+  const HloInstruction* instruction = *computations[0]->instructions().begin();
+  stream_executor::StreamExecutor* executor = NewStreamExecutor();
+
+  {
+    // By default, JIT autotuning is OK.
+    AutotuneConfig config(DeviceConfig{executor}, DebugOptions());
+    TF_EXPECT_OK(AutotunerUtil::Autotune(instruction, config, [&] {
+                   return AutotuneResult();
+                 }).status());
+  }
+
+  // Now require complete AOT autotuning results.
+  auto options = DebugOptions();
+  options.set_xla_gpu_require_complete_aot_autotune_results(true);
+
+  AutotuneConfig config(DeviceConfig{executor}, options);
+  // Even though JIT autotuning is disabled, there is no cache miss when running
+  // autotuning for the same entry, so no error should be raised either.
+  TF_EXPECT_OK(AutotunerUtil::Autotune(instruction, config, [&] {
+                 return AutotuneResult();
+               }).status());
+}
+
+class FileBasedCacheTest : public AutotunerUtilTest {
+ public:
+  static std::string ToString(const AutotuneResult& message) {
+    std::string textproto;
+    CHECK(tsl::protobuf::TextFormat::PrintToString(message, &textproto));
+    return textproto;
+  }
+
+  static std::vector<std::string> GetFilesInDir(
+      const absl::string_view cache_dir) {
+    std::vector<std::string> files_in_cache;
+    TF_CHECK_OK(tsl::Env::Default()->GetChildren(std::string(cache_dir),
+                                                 &files_in_cache));
+    return files_in_cache;
+  }
+
+  static std::string Read(const absl::string_view filepath) {
+    std::string file_content;
+    TF_CHECK_OK(tsl::ReadFileToString(tsl::Env::Default(),
+                                      std::string(filepath), &file_content));
+    return file_content;
+  }
+
+  void Write(const absl::string_view filepath,
+             const absl::string_view content) {
+    TF_CHECK_OK(CreateDirIfNeeded(cache_dir_, tsl::Env::Default()));
+    TF_CHECK_OK(tsl::WriteStringToFile(tsl::Env::Default(),
+                                       std::string(filepath), content));
+  }
+
+  stream_executor::StreamExecutor* executor_ = NewStreamExecutor();
+  std::unique_ptr<HloModule> module_ =
+      ParseAndReturnVerifiedModule(kHloText).value();
+  const HloInstruction* dot_ = hlo_query::GetFirstInstructionWithOpcode(
+      *module_->entry_computation(), HloOpcode::kDot);
+  std::string cache_dir_ = [] {
+    tsl::Env* default_env = tsl::Env::Default();
+    std::string cache_dir;
+    CHECK(default_env->LocalTempFilename(&cache_dir));
+    return cache_dir;
+  }();
+  AutotuneConfig config_ = AutotuneConfig(DeviceConfig{executor_}, [&] {
+    DebugOptions options;
+    options.set_xla_gpu_per_fusion_autotune_cache_dir(cache_dir_);
+    return options;
+  }());
+  AutotuneCacheKey cache_key_ = AutotunerUtil::GetKey(dot_, config_);
+  std::string cache_filename_ = [&] {
+    absl::StatusOr<std::string> key_hash =
+        GetBase64EncodedSha256Hash(cache_key_.ToString());
+    CHECK_OK(key_hash.status());
+    return absl::StrCat(key_hash.value(), ".textproto");
+  }();
+  std::string cache_file_path_ = tsl::io::JoinPath(cache_dir_, cache_filename_);
+  const AutotuneResult result1_ = [] {
+    AutotuneResult result;
+    result.set_scratch_bytes(1);
+    return result;
+  }();
+  const AutotuneResult result2_ = [] {
+    AutotuneResult result;
+    result.set_scratch_bytes(2);
+    return result;
+  }();
+};
+
+TEST_F(FileBasedCacheTest, AutotuneWritesResultToTheCacheDir) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      AutotuneResult result,
+      AutotunerUtil::Autotune(dot_, config_, [&] { return result1_; }));
+  EXPECT_EQ(ToString(result), ToString(result1_));
+
+  ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_));
+  EXPECT_EQ(Read(cache_file_path_), ToString(result1_));
+}
+
+TEST_F(FileBasedCacheTest, AutotuneReadsResultFromTheCacheDir) {
+  Write(cache_file_path_, ToString(result1_));
+
+  bool cache_hit = true;
+  TF_ASSERT_OK_AND_ASSIGN(AutotuneResult result,
+                          AutotunerUtil::Autotune(dot_, config_, [&] {
+                            cache_hit = false;
+                            return result2_;
+                          }));
+
+  EXPECT_TRUE(cache_hit);
+  EXPECT_EQ(ToString(result), ToString(result1_));
+}
+
+TEST_F(FileBasedCacheTest,
+       RepeatedAutotuneCallsDontReadOrWriteTheCacheFileAgain) {
+  auto check_autotune_cache_hit = [](const HloInstruction* instr,
+                                     const AutotuneConfig& config,
+                                     const AutotuneResult& expected_result) {
+    bool cache_hit = true;
+    TF_ASSERT_OK_AND_ASSIGN(AutotuneResult result,
+                            AutotunerUtil::Autotune(instr, config, [&] {
+                              cache_hit = false;
+                              AutotuneResult new_result;
+                              new_result.set_scratch_bytes(2);
+                              return new_result;
+                            }));
+    EXPECT_TRUE(cache_hit);
+    EXPECT_EQ(ToString(result), ToString(expected_result));
+  };
+
+  Write(cache_file_path_, ToString(result1_));
+  check_autotune_cache_hit(dot_, config_, /*expected_result=*/result1_);
+
+  constexpr absl::string_view kPlaceholderContent = "placeholder content";
+  Write(cache_file_path_, kPlaceholderContent);
+  // File was not read again:
+  check_autotune_cache_hit(dot_, config_, /*expected_result=*/result1_);
+  // File was not written again:
+  EXPECT_EQ(Read(cache_file_path_), kPlaceholderContent);
+}
+
+TEST_F(FileBasedCacheTest,
+       IsInCacheReturnsTrueIfTheResultIsInTheFileBasedCache) {
+  Write(cache_file_path_, ToString(result1_));
+
+  TF_ASSERT_OK_AND_ASSIGN(bool is_in_cache,
+                          AutotunerUtil::IsInCache(cache_key_, config_));
+
+  EXPECT_TRUE(is_in_cache);
+}
+
+TEST_F(FileBasedCacheTest, IsInCacheReturnsFalseIfTheResultIsNotInEitherCache) {
+  TF_ASSERT_OK_AND_ASSIGN(bool is_in_cache,
+                          AutotunerUtil::IsInCache(cache_key_, config_));
+
+  EXPECT_FALSE(is_in_cache);
+}
+
+TEST_F(FileBasedCacheTest, AddResultAddsTheResultToTheFileBasedCache) {
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_));
+  EXPECT_TRUE(added);
+
+  ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_));
+  EXPECT_EQ(Read(cache_file_path_), ToString(result1_));
+}
+
+TEST_F(FileBasedCacheTest, RepeatedAddResultDoesNotWriteTheFileAgain) {
+  {
+    TF_ASSERT_OK_AND_ASSIGN(
+        bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_));
+    EXPECT_TRUE(added);
+  }
+  ASSERT_THAT(GetFilesInDir(cache_dir_), ElementsAre(cache_filename_));
+  EXPECT_EQ(Read(cache_file_path_), ToString(result1_));
+  constexpr absl::string_view kPlaceholderContent = "placeholder content";
+  Write(cache_file_path_, kPlaceholderContent);
+
+  {
+    TF_ASSERT_OK_AND_ASSIGN(
+        bool added, AutotunerUtil::AddResult(cache_key_, result1_, config_));
+    EXPECT_FALSE(added);
+  }
+
+  // File was not written again:
+  EXPECT_EQ(Read(cache_file_path_), kPlaceholderContent);
+}
+
+TEST(AutotuneCacheKeyTest, DeviceDescriptionToCacheKey) {
+  auto device_description =
+      [](absl::string_view spec_file_name) -> se::DeviceDescription {
+    se::GpuTargetConfigProto proto;
+    std::string spec_string;
+    CHECK_OK(tsl::ReadFileToString(
+        tsl::Env::Default(),
+        tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "tools", "hlo_opt",
+                          "gpu_specs", spec_file_name),
+        &spec_string));
+    EXPECT_TRUE(
+        tsl::protobuf::TextFormat::ParseFromString(spec_string, &proto));
+    return se::DeviceDescription(proto.gpu_device_info());
+  };
+
+  EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey(
+                device_description("a100_sxm_40.txtpb")),
+            "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: "
+            "1555 GB/s, L2 cache: 40 MB");
+
+  EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey(
+                device_description("a100_sxm_80.txtpb")),
+            "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: "
+            "2039 GB/s, L2 cache: 40 MB");
+
+  EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey(
+                device_description("mi200.txtpb")),
+            "ROCM: gfx90a, Cores: 110, GPU clock: 1.7 GHz, Memory bandwidth: "
+            "1638 GB/s, L2 cache: 8 MB");
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc
new file mode 100644
index 0000000..6173915
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc
@@ -0,0 +1,1194 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/conv_algorithm_picker.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "xla/autotuning.pb.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/autotuning/gpu_autotuning.pb.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/gpu_conv_runner.h"
+#include "xla/service/gpu/hlo_algorithm_denylist.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/slow_operation_alarm.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/cuda/cuda_platform_id.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/stream_executor/lazy_op_runner.h"
+#include "xla/stream_executor/numeric_options.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/rocm/rocm_platform_id.h"
+#include "xla/stream_executor/scratch_allocator.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/stream_executor/stream_executor.h"
+#include "xla/tsl/util/env_var.h"
+#include "xla/tsl/util/proto/proto_utils.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/numbers.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
+#include "third_party/gpus/cudnn/cudnn.h"  // IWYU pragma: keep
+#include "third_party/gpus/cudnn/cudnn_version.h"
+#if CUDNN_VERSION >= 90000
+#include "third_party/gpus/cudnn/cudnn_ops.h"
+#else
+#include "third_party/gpus/cudnn/cudnn_ops_infer.h"
+#endif  // CUDNN_VERSION >= 90000
+#include "xla/service/gpu/buffer_comparator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#endif
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using se::DeviceMemoryBase;
+using se::dnn::AlgorithmDesc;
+using std::optional;
+
+class ScratchAllocator : public se::ScratchAllocator {
+ public:
+  ScratchAllocator(int device_ordinal,
+                   se::DeviceMemoryAllocator* memory_allocator)
+      : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
+
+  int64_t GetMemoryLimitInBytes() override {
+    return ScratchAllocator::GetDefaultMemoryLimitInBytes();
+  }
+  int64_t TotalAllocatedBytes() { return total_allocated_bytes_; }
+
+  static int64_t GetDefaultMemoryLimitInBytes() {
+    int64_t value;
+    TF_CHECK_OK(tsl::ReadInt64FromEnvVar("TF_CUDNN_WORKSPACE_LIMIT_IN_MB",
+                                         1LL << 12, &value));
+    return value * (1LL << 20);
+  }
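+  // For example, setting TF_CUDNN_WORKSPACE_LIMIT_IN_MB=1024 caps scratch
+  // allocations at 1 GiB; when unset, the default above (1LL << 12, i.e.
+  // 4096 MB) applies.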
+
+  absl::StatusOr<se::DeviceMemory<uint8_t>> AllocateBytes(
+      int64_t byte_size) override;
+
+  template <typename T>
+  absl::StatusOr<se::DeviceMemory<T>> Allocate(int64_t num_elements) {
+    TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8_t> bytes,
+                        AllocateBytes(num_elements * sizeof(T)));
+    return se::DeviceMemory<T>(bytes);
+  }
+
+ private:
+  const int device_ordinal_;
+  se::DeviceMemoryAllocator* memory_allocator_;
+  std::vector<se::OwningDeviceMemory> allocated_buffers_;
+  int64_t total_allocated_bytes_ = 0;
+};
+
+absl::StatusOr<se::DeviceMemory<uint8_t>> ScratchAllocator::AllocateBytes(
+    int64_t byte_size) {
+  CHECK_GE(byte_size, 0) << "byte_size must be positive.";
+  if (byte_size > GetMemoryLimitInBytes()) {
+    return absl::ResourceExhaustedError(absl::StrFormat(
+        "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size,
+        GetMemoryLimitInBytes()));
+  }
+
+  TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
+                      memory_allocator_->Allocate(device_ordinal_, byte_size,
+                                                  /*retry_on_failure=*/false));
+  total_allocated_bytes_ += byte_size;
+
+  se::DeviceMemoryBase buffer_addr = *allocated_buffer;
+  allocated_buffers_.push_back(std::move(allocated_buffer));
+  return se::DeviceMemory<uint8_t>(buffer_addr);
+}
+
+absl::StatusOr<std::vector<GenericConvRunner>> GetAlgorithms(
+    const GpuConvConfig& config, se::Stream* stream, bool use_cudnn_frontend,
+    bool use_fallback, const se::NumericOptions& numeric_options) {
+  TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
+                      GetDNNConvKindFromCudnnConvKind(config.kind));
+
+  TF_ASSIGN_OR_RETURN(se::dnn::DataType input_type,
+                      GetDNNDataTypeFromPrimitiveType(config.input_type));
+
+  TF_ASSIGN_OR_RETURN(se::dnn::DataType output_type,
+                      GetDNNDataTypeFromPrimitiveType(config.output_type));
+
+  se::StreamExecutor* stream_exec = stream->parent();
+  std::vector<GenericConvRunner> result;
+
+  auto dnn = stream_exec->AsDnn();
+  if (dnn == nullptr) {
+    return absl::InvalidArgumentError("No DNN in stream executor.");
+  }
+  switch (kind) {
+    default:
+      return Internal("Unknown ConvolutionKind %d", kind);
+    case se::dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION: {
+      if (!config.fusion) {
+        return Internal(
+            "GpuConvConfig had fusion ConvolutionKind but no FusionConfig.");
+      }
+      std::vector<std::unique_ptr<const se::dnn::FusedConvRunner>> runners;
+      TF_RETURN_IF_ERROR(dnn->GetFusedConvolveRunners(
+          use_cudnn_frontend,
+          // This refers to the kind of convolution op inside the fusion, not
+          // the whole fused graph.
+          se::dnn::ConvolutionKind::FORWARD, input_type,
+          BiasTypeForInputType(input_type), output_type,
+          /* conv_input_scale = */ config.conv_result_scale,
+          /* side_input_scale = */ config.fusion->side_input_scale,
+          /* leakyrelu_alpha = */ config.fusion->leakyrelu_alpha, stream,
+          config.input_descriptor, config.filter_descriptor,
+          config.bias_descriptor, config.output_descriptor, config.conv_desc,
+          use_fallback, config.fusion->mode, numeric_options, &runners));
+      for (auto& runner : runners) {
+        TF_ASSIGN_OR_RETURN(
+            auto runner_cache,
+            se::dnn::LazyOpRunner<se::dnn::FusedConvOp>::FromOpRunner(
+                std::move(runner)));
+        result.emplace_back(std::move(runner_cache));
+      }
+      break;
+    }
+
+    case se::dnn::ConvolutionKind::FORWARD_GRAPH: {
+      std::vector<std::unique_ptr<const se::dnn::GraphConvRunner>> runners;
+      // This path is cuDNN-only, where the DeviceMemoryBase arguments and the
+      // allocator are unused; so, they're all provided as nullptr.
+      TF_RETURN_IF_ERROR(dnn->GetGraphConvolveRunners(
+          kind, input_type, output_type, stream, config.input_descriptor,
+          config.filter_descriptor, config.output_descriptor, config.conv_desc,
+          use_fallback, numeric_options, &runners, config.serialized_graph));
+      for (auto& runner : runners) {
+        TF_ASSIGN_OR_RETURN(
+            auto runner_cache,
+            se::dnn::LazyOpRunner<se::dnn::GraphConvOp>::FromOpRunner(
+                std::move(runner)));
+        result.emplace_back(std::move(runner_cache));
+      }
+      break;
+    }
+
+    case se::dnn::ConvolutionKind::FORWARD:
+    case se::dnn::ConvolutionKind::BACKWARD_DATA:
+    case se::dnn::ConvolutionKind::BACKWARD_FILTER: {
+      std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
+      // This path is cuDNN-only, where the DeviceMemoryBase arguments and the
+      // allocator are unused; so, they're all provided as nullptr.
+      TF_RETURN_IF_ERROR(dnn->GetConvolveRunners(
+          use_cudnn_frontend, kind, input_type, output_type, stream,
+          config.input_descriptor,
+          /* input_data = */ DeviceMemoryBase(nullptr),
+          config.filter_descriptor,
+          /* filter_data = */ DeviceMemoryBase(nullptr),
+          config.output_descriptor,
+          /* output_data = */ DeviceMemoryBase(nullptr), config.conv_desc,
+          use_fallback, nullptr, numeric_options, &runners));
+
+      for (auto& runner : runners) {
+        TF_ASSIGN_OR_RETURN(
+            auto runner_cache,
+            se::dnn::LazyOpRunner<se::dnn::ConvOp>::FromOpRunner(
+                std::move(runner)));
+        result.emplace_back(std::move(runner_cache));
+      }
+      break;
+    }
+  }
+
+  return result;
+}
+
+absl::StatusOr<std::vector<std::unique_ptr<const se::dnn::ConvRunner>>>
+GetMIOpenAlgorithms(const HloCustomCallInstruction* instr,
+                    absl::Span<se::DeviceMemoryBase> operand_buffers,
+                    absl::Span<se::DeviceMemoryBase> result_buffers,
+                    se::StreamExecutor* stream_exec,
+                    ScratchAllocator* scratch_allocator, se::Stream* stream,
+                    const se::NumericOptions& numeric_options) {
+  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
+
+  TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
+                      GetDNNConvKindFromCudnnConvKind(config.kind));
+
+  TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype,
+                      GetDNNDataTypeFromPrimitiveType(config.output_type));
+
+  TF_ASSIGN_OR_RETURN(
+      GpuConvParams params,
+      GetGpuConvParams(config, operand_buffers, result_buffers));
+
+  std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
+  auto dnn = stream_exec->AsDnn();
+  if (dnn == nullptr) {
+    return absl::InvalidArgumentError("No DNN in stream executor.");
+  }
+  TF_RETURN_IF_ERROR(dnn->GetConvolveRunners(
+      /* use_cudnn_frontend = */ false, kind, dtype, dtype, stream,
+      params.config->input_descriptor, params.input_buf,
+      params.config->filter_descriptor, params.filter_buf,
+      params.config->output_descriptor, params.output_buf,
+      params.config->conv_desc,
+      /* use_fallback = */ false, scratch_allocator, numeric_options,
+      &runners));
+
+  return runners;
+}
+
+std::string NumBytesToString(int64_t bytes) {
+  return absl::StrCat(tsl::strings::HumanReadableNumBytes(bytes), " (", bytes,
+                      "B)");
+}
+
+CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
+  se::dnn::VersionInfo version = GetDnnVersionInfoOrDefault(stream_executor);
+  CudnnVersion cudnn_version;
+  cudnn_version.set_major(version.major_version());
+  cudnn_version.set_minor(version.minor_version());
+  cudnn_version.set_patch(version.patch());
+
+  return cudnn_version;
+}
+
+ComputeCapability GetComputeCapability(se::StreamExecutor* stream_executor) {
+  ComputeCapability cc;
+  se::CudaComputeCapability se_cc =
+      stream_executor->GetDeviceDescription().cuda_compute_capability();
+  cc.set_major(se_cc.major);
+  cc.set_minor(se_cc.minor);
+  return cc;
+}
+
+void PrintPlatformInfo(const se::Stream* stream) {
+  auto* se = stream->parent();
+  const auto& desc = se->GetDeviceDescription();
+  LOG(ERROR) << "Device: " << desc.name();
+  LOG(ERROR) << "Platform: " << desc.platform_version();
+  LOG(ERROR) << "Driver: " << desc.driver_version();
+  LOG(ERROR) << "Runtime: " << desc.runtime_version();
+
+  auto dnn_version = GetDnnVersionInfo(se);
+  if (dnn_version.ok()) {
+    auto v = dnn_version.value();
+    LOG(ERROR) << "cudnn version: " << v.major_version() << "."
+               << v.minor_version() << "." << v.patch();
+  }
+}
+
+// Returns true if the redzones in `allocator`'s allocations are unmodified.
+//
+// If the redzones are modified, logs an error, sets the appropriate failure
+// bits on `result`, and returns false.
+//
+// Returns a non-OK absl::Status if an unexpected error has occurred, and the
+// stream has been poisoned.
+//
+// `name` is a user-friendly name for the set of redzones being checked, e.g.
+// "input/output" or "scratch".
+absl::StatusOr<bool> CheckRedzones(const se::RedzoneAllocator& allocator,
+                                   se::Stream* stream, absl::string_view name,
+                                   std::string_view instr_str,
+                                   AutotuneResult* result) {
+  XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones",
+                                 2);
+  using RedzoneCheckStatus = se::RedzoneAllocator::RedzoneCheckStatus;
+  TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check,
+                      allocator.CheckRedzones());
+  if (redzone_check.ok()) {
+    return true;
+  }
+
+  auto* fail = result->mutable_failure();
+  fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
+  *fail->mutable_msg() = redzone_check.RedzoneFailureMsg();
+  fail->set_buffer_address(
+      reinterpret_cast<uint64_t>(redzone_check.user_buffer_address));
+
+  LOG(ERROR) << absl::StreamFormat(
+      "Detected cudnn out-of-bounds write in conv %s buffer! This is likely a "
+      "cudnn bug. We will skip this algorithm in the future, but your GPU "
+      "state may already be corrupted, leading to incorrect results. Within "
+      "Google, no action is needed on your part. Outside of Google, please "
+      "ensure you're running the latest version of cudnn. If that doesn't fix "
+      "the problem, please file a bug with this full error message and we'll "
+      "contact nvidia.",
+      name);
+  LOG(ERROR) << redzone_check.RedzoneFailureMsg();
+  LOG(ERROR) << "HloInstruction " << instr_str;
+  PrintPlatformInfo(stream);
+  return false;
+}
+
+}  // anonymous namespace
+
+bool ShouldInitConvData(const HloModuleConfig& hlo_module_config) {
+  const int32_t conv_autotune_level =
+      hlo_module_config.debug_options().xla_gpu_autotune_level();
+  return conv_autotune_level >= 2;
+}
+
+bool ShouldCheckConv(const HloModuleConfig& hlo_module_config) {
+  const int32_t conv_autotune_level =
+      hlo_module_config.debug_options().xla_gpu_autotune_level();
+  return conv_autotune_level >= 4;
+}
+
+absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
+    const HloCustomCallInstruction* instr) {
+  return AutotunerUtil::Autotune(
+      instr, config_, [&] { return PickBestAlgorithmNoCache(instr); });
+}
+
+absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithmNoCache(
+    const HloCustomCallInstruction* instr) {
+  if (config_.IsDeviceless()) {
+    // Return an autotune result with algo id -1, which means that we autotune
+    // at runtime.
+    AutotuneResult result;
+    result.mutable_algorithm()->set_algo_id(-1);
+    return result;
+  }
+
+  se::StreamExecutor* stream_exec = config_.GetExecutor();
+  // Don't run this function concurrently on the same GPU.
+  //
+  // This is a bit of a hack and doesn't protect us against arbitrary concurrent
+  // use of a GPU, but it's sufficient to let us compile two HLO modules
+  // concurrently and then run them sequentially.
+  //
+  // Putting the lock in here rather than in PickBestAlgorithmNoCache lets us
+  // avoid ever doing duplicate work.  If we have a cache miss, only one thread
+  // will run PickBestAlgorithmImpl for a particular device.
+  absl::MutexLock lock(&GetGpuMutex(stream_exec));
+
+  // Make sure any previous activity on this executor is done. We don't want
+  // other work still running on the GPU to interfere with autotuning.
+  if (!stream_exec->SynchronizeAllActivity()) {
+    return Internal(
+        "Failed to synchronize GPU for autotuning conv instruction");
+  }
+
+  absl::StatusOr<AutotuneResult> result_or(Internal("Unknown platform."));
+  // Check which platform the StreamExecutor is on. The ROCm and CUDA
+  // implementations have diverged. Specifically, we need to make sure that
+  // redzone-allocator-related utilities are not used in the ROCm routine.
+  se::Platform::Id platform_id = stream_exec->GetPlatform()->id();
+  if (platform_id == se::rocm::kROCmPlatformId) {
+    result_or = PickBestAlgorithmNoCacheRocm(instr);
+  } else if (platform_id == se::cuda::kCudaPlatformId) {
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
+    result_or = PickBestAlgorithmNoCacheCuda(instr);
+#endif
+  }
+
+  return result_or;
+}
+
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
+
+absl::StatusOr<GpuConvAlgorithmPicker::AutotuneRuntimeArguments>
+GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction(
+    const HloCustomCallInstruction* instr, const AutotuneConfig& config,
+    const DebugOptions& debug_options) {
+  TF_ASSIGN_OR_RETURN(auto rz_buffers,
+                      RedzoneBuffers::FromInstruction(
+                          *instr, config, debug_options,
+                          RedzoneBuffers::kAllInputsOutputsNoScratch));
+
+  // Get canonical HLO.
+  std::string canonical_hlo(
+      AutotuneCacheKey(config.GetExecutor()->GetDeviceDescription(), *instr)
+          .GetHlo());
+
+  TF_ASSIGN_OR_RETURN(GpuConvConfig gpu_conv_config, GetGpuConvConfig(instr));
+
+  GpuConvAlgorithmPicker::AutotuneRuntimeArguments runtime_arguments = {
+      instr->GetModule()->config(),
+      std::move(rz_buffers),
+      std::move(gpu_conv_config),
+      {canonical_hlo}};
+
+  return runtime_arguments;
+}
+
+struct CudnnVersionRange {
+  using TupleVersion = std::tuple<int, int, int>;
+  TupleVersion begin;
+  TupleVersion end;
+
+  bool IsInRange(const CudnnVersion& other) const {
+    TupleVersion other_version{other.major(), other.minor(), other.patch()};
+    return begin <= other_version && other_version < end;
+  }
+
+  CudnnVersionRange(const CudnnVersion& begin, const CudnnVersion& end)
+      : begin(begin.major(), begin.minor(), begin.patch()),
+        end(end.major(), end.minor(), end.patch()) {}
+
+  CudnnVersionRange(const TupleVersion& begin, const TupleVersion& end)
+      : begin(begin), end(end) {}
+};
+
+struct ComputeCapabilityRange {
+  using TupleComputeCapability = std::tuple<int, int>;
+  TupleComputeCapability begin;
+  TupleComputeCapability end;
+
+  bool IsInRange(const ComputeCapability& other) const {
+    TupleComputeCapability other_cc{other.major(), other.minor()};
+    return begin <= other_cc && other_cc < end;
+  }
+};
+
+struct DisabledAlgorithm {
+  CudnnVersionRange cudnn_version_range;
+  ComputeCapabilityRange compute_capability_range;
+  int algo_id;
+};
+
+// TODO(b/343101418): Remove this once the bug is fixed in upstream cudnn and
+// once we have updated to that cudnn version.
+static const DisabledAlgorithm kDisabledAlgorithms[] = {
+    {/*.cudnn_version_range=*/{/*.begin=*/{9, 0, 0}, /*.end=*/{10, 0, 0}},
+     /*.compute_capability_range=*/{/*.begin=*/{6, 0}, /*.end=*/{8, 0}},
+     /*.algo_id=*/14}};
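+
+// Read as: algorithm 14 is skipped whenever the cuDNN version falls in
+// [9.0.0, 10.0.0) and the compute capability falls in [6.0, 8.0), per the
+// half-open ranges implemented by IsInRange() above.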
+
+// There are three tiers of errors possible here: returning a failed
+// absl::StatusOr means autotuning fails immediately; returning an
+// AutotuneResult with a failure code other than DISQUALIFIED means autotuning
+// fails if crash_on_checking_failure is set; and returning a DISQUALIFIED
+// AutotuneResult simply skips the engine/algorithm while recording a reason for
+// skipping it.
+absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(
+    GenericConvRunner* const runner,
+    std::optional<ReferenceResult>* reference_result,
+    absl::Span<const AlgorithmDesc> disabled_algos,
+    std::optional<AutotuneCacheKey> instruction_info,
+    const AutotuneRuntimeArguments& runtime_arguments) {
+  auto alg = runner->ToAlgorithmDesc();
+
+  se::StreamExecutor* stream_exec = config_.GetExecutor();
+  XLA_SCOPED_LOGGING_TIMER_LEVEL(
+      absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
+                   alg.ToString()),
+      2);
+
+  auto make_failure = [&alg](AutotuneResult::FailureKind kind,
+                             absl::string_view msg) {
+    AutotuneResult result;
+    *result.mutable_algorithm() = alg.ToProto();
+    result.mutable_failure()->set_kind(kind);
+    result.mutable_failure()->set_msg(/* *sigh* */ msg.data(), msg.size());
+    return result;
+  };
+
+  AlgorithmDesc alg_key(alg.algo_id(), alg.tensor_ops_enabled(), std::nullopt);
+
+  std::string instr_str = instruction_info.has_value()
+                              ? std::string(instruction_info->GetHlo())
+                              : "<unknown>";
+
+  for (const auto& disabled_algo : kDisabledAlgorithms) {
+    if (disabled_algo.cudnn_version_range.IsInRange(
+            GetCudnnVersion(stream_exec)) &&
+        disabled_algo.compute_capability_range.IsInRange(
+            GetComputeCapability(stream_exec)) &&
+        disabled_algo.algo_id == alg.algo_id()) {
+      LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString()
+                << " for conv " << instr_str;
+      return make_failure(AutotuneResult::DISQUALIFIED,
+                          "Disqualified for being known-buggy.");
+    }
+  }
+
+  if (absl::c_linear_search(disabled_algos, alg_key)) {
+    LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString()
+              << " for conv " << instr_str;
+    return make_failure(AutotuneResult::DISQUALIFIED,
+                        "Disqualified for being known-buggy.");
+  }
+
+  GpuConvConfig config = runtime_arguments.gpu_conv_config;
+  auto activation_mode =
+      config.fusion ? config.fusion->mode : se::dnn::ActivationMode::kNone;
+
+  // For fused convolutions with the identity function as the activation, only
+  // ALGO_IMPLICIT_PRECOMP_GEMM does the right thing. Other algorithms
+  // silently apply ReLU. See
+  // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBiasActivationForward
+  //
+  // For cuDNN Frontend, there is no way to check whether we're using a broken
+  // algorithm, so on versions where some algorithms are broken, we don't use
+  // the cuDNN Frontend for these convs at all.  As such, if we get a
+  // frontend-based runner, we can be sure it's not one of the broken
+  // algorithms we're checking for.
+  if (!alg.is_cudnn_frontend() &&
+      config.kind == CudnnConvKind::kForwardActivation &&
+      activation_mode == se::dnn::ActivationMode::kNone &&
+      alg.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
+    return make_failure(AutotuneResult::DISQUALIFIED,
+                        "Disqualified for implicit RELU.");
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      se::RedzoneAllocator scratch_allocator,
+      AutotunerUtil::CreateRedzoneAllocator(
+          config_, runtime_arguments.hlo_module_config.debug_options()));
+
+  se::dnn::ProfileResult profile_result;
+  VLOG(4) << "Trying algorithm " << alg.ToString() << " for " << instr_str;
+
+  SlowOperationAlarm alarm(absl::Seconds(1), [&] {
+    return absl::StrFormat(
+        "Trying algorithm %s for conv %s is taking a while...", alg.ToString(),
+        instr_str);
+  });
+
+  std::optional<size_t> workspace_size =
+      runner->ToAlgorithmDesc().workspace_size();
+  if (!workspace_size) {
+    return make_failure(AutotuneResult::UNKNOWN,
+                        "Internal error: missing workspace size from "
+                        "OpRunner::ToAlgorithmDesc()");
+  }
+
+  auto scratch_or = scratch_allocator.AllocateBytes(*workspace_size);
+  if (!scratch_or.ok()) {
+    return make_failure(AutotuneResult::DISQUALIFIED,
+                        absl::StrCat("Scratch allocation failed: ",
+                                     scratch_or.status().ToString()));
+  }
+  se::DeviceMemoryBase scratch_memory = scratch_or.value();
+
+  // Use assignment instead of brace-list to make GCC 4.9 happy.
+  RunConvOptions options;
+  options.runner_cache = runner;
+  // The following plan timing code is based on
+  // https://github.com/NVIDIA/cudnn-frontend/blob/60496f42fdc7a4ccc059f5934e306e728a756755/include/cudnn_frontend_find_plan.h
+  float max_time = 0;
+  float min_time = std::numeric_limits<float>::max();
+  absl::Status launch_status;
+  std::vector<se::DeviceMemoryBase> operand_buffers =
+      runtime_arguments.rz_buffers.input_buffers();
+  std::vector<se::DeviceMemoryBase> result_buffers =
+      runtime_arguments.rz_buffers.output_buffers();
+
+  TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
+
+  // Dry-run to warmup the plan.
+  launch_status = RunGpuConv(config, operand_buffers, result_buffers,
+                             scratch_memory, stream, options);
+  // Flag that a warm-up run has been executed; this allows the GpuTimer for
+  // the main measurement to safely use the delay kernel pattern, even if lazy
+  // module loading is enabled.
+  options.profile_result = &profile_result;
+  profile_result.set_warmup_run_executed(true);
+  constexpr int kMaxIter = 10;
+  // Iterate until the new measurement is within kThreshold of the current
+  // minimum.
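+  // With kThreshold = 0.05 below, profiling stops early once a measurement
+  // lands within 5% of the best time observed so far.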
+  int num_iters = 0;
+  for (; num_iters < kMaxIter && launch_status.ok(); ++num_iters) {
+    launch_status = RunGpuConv(config, operand_buffers, result_buffers,
+                               scratch_memory, stream, options);
+    if (!profile_result.is_valid()) {
+      break;
+    }
+    float old_min_time = min_time;
+    min_time = std::min(min_time, profile_result.elapsed_time_in_ms());
+    max_time = std::max(max_time, profile_result.elapsed_time_in_ms());
+
+    constexpr float kThreshold = 0.05f;
+    if (std::abs(profile_result.elapsed_time_in_ms() - old_min_time) /
+            old_min_time <
+        kThreshold) {
+      break;
+    }
+  }
+  if (!launch_status.ok()) {
+    VLOG(5) << "Launch failed: " << launch_status;
+    return make_failure(
+        AutotuneResult::DISQUALIFIED,
+        absl::StrCat("Profiling failure on cuDNN engine ", alg.ToString(), ": ",
+                     launch_status.ToString()));
+  }
+  if (!profile_result.is_valid()) {
+    VLOG(5) << "Launch succeeded but profile result is invalid.";
+    // Not DISQUALIFIED: this means something went wrong internally.
+    return make_failure(
+        AutotuneResult::UNKNOWN,
+        absl::StrCat("Launch succeeded but profile result is invalid, "
+                     "with cuDNN engine ",
+                     alg.ToString(), ": ", launch_status.ToString()));
+  }
+  VLOG(4) << "Best time: " << min_time << " ms. Worst time: " << max_time
+          << " ms. Total iterations: " << num_iters;
+  int64_t scratch_bytes_used =
+      scratch_allocator.TotalAllocatedBytesExcludingRedzones();
+
+  AutotuneResult result;
+  *result.mutable_algorithm() = alg.ToProto();
+  result.set_scratch_bytes(scratch_bytes_used);
+  *result.mutable_run_time() =
+      tsl::proto_utils::ToDurationProto(absl::Milliseconds(min_time));
+
+  if (!ShouldCheckConv(runtime_arguments.hlo_module_config)) {
+    if (!reference_result->has_value()) {
+      (*reference_result) = {
+          alg, std::vector<DeviceMemoryBase>(result_buffers.size())};
+    }
+    return result;
+  }
+
+  // Check for writes to redzones.
+  TF_ASSIGN_OR_RETURN(
+      bool input_output_allocator_redzone_clear,
+      CheckRedzones(runtime_arguments.rz_buffers.RedzoneAllocator(), stream,
+                    "input/output", instr_str, &result));
+
+  TF_ASSIGN_OR_RETURN(
+      bool scratch_allocator_redzone_clear,
+      CheckRedzones(scratch_allocator, stream, "scratch", instr_str, &result));
+
+  if (!input_output_allocator_redzone_clear ||
+      !scratch_allocator_redzone_clear) {
+    if (runtime_arguments.canonical_hlo.has_value()) {
+      std::string canonical_hlo = runtime_arguments.canonical_hlo.value();
+      std::string blas_version;
+      if (auto* blas = stream_exec->AsBlas()) {
+        (void)blas->GetVersion(&blas_version);
+      }
+
+      AlgorithmDenylist proto;
+      auto entry = proto.add_entries();
+      entry->set_hlo(canonical_hlo);
+      *entry->mutable_cc() = GetComputeCapability(stream_exec);
+      *entry->mutable_cudnn_version() = GetCudnnVersion(stream_exec);
+      entry->set_blas_version(blas_version);
+      auto algo = entry->add_algos();
+      algo->set_id(alg.algo_id());
+      algo->set_tensor_ops(alg.tensor_ops_enabled());
+
+      LOG(ERROR) << "To denylist this algorithm for this convolution, "
+                    "copy-paste the following "
+                    "proto to the denylist file pointed by XLA_FLAGS "
+                    "--xla_gpu_algorithm_denylist_path="
+                 << GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path()
+                 << " : " << proto.ShortDebugString();
+    }
+
+    // CheckRedzones has modified the result in-place to include a failure.
+    return result;
+  }
+
+  if (reference_result->has_value()) {
+    XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);
+
+    const DebugOptions& debug_options =
+        runtime_arguments.hlo_module_config.debug_options();
+    BufferComparator comparator(runtime_arguments.rz_buffers.output_shape(),
+                                debug_options.xla_gpu_autotune_gemm_rtol());
+    for (int i = 0; i < result_buffers.size(); ++i) {
+      absl::StatusOr<bool> compare_result = comparator.CompareEqual(
+          stream, (*reference_result)->buffers[i], result_buffers[i]);
+      if (!compare_result.ok()) {
+        LOG(ERROR) << "Unable to compare "
+                   << (*reference_result)->algorithm.ToString() << " against "
+                   << alg.ToString() << " for " << instr_str << ": "
+                   << compare_result.status();
+        if (compare_result.status().code() ==
+            absl::StatusCode::kResourceExhausted) {
+          // Possibly OOM. Propagate the error.
+          return compare_result.status();
+        }
+        CHECK(!debug_options.xla_gpu_crash_on_verification_failures());
+      } else if (!compare_result.value()) {
+        LOG(ERROR)
+            << "Results mismatch between different convolution algorithms. "
+               "This is likely a bug/unexpected loss of precision in cudnn.\n"
+            << instr_str << " for " << (*reference_result)->algorithm.ToString()
+            << " vs " << alg.ToString();
+        PrintPlatformInfo(stream);
+        if (instruction_info.has_value()) {
+          VLOG(2) << "Full module on failure: \n"
+                  << instruction_info->GetModelStr();
+        }
+        auto* fail = result.mutable_failure();
+        fail->set_kind(AutotuneResult::WRONG_RESULT);
+        fail->set_buffer_address(
+            reinterpret_cast<uint64_t>(result_buffers[i].opaque()));
+        *fail->mutable_reference_algorithm() =
+            (*reference_result)->algorithm.ToProto();
+      }
+    }
+  } else {
+    XLA_SCOPED_LOGGING_TIMER_LEVEL("Memcpy Reference Result", 2);
+    std::vector<DeviceMemoryBase> reference_result_buffers(
+        result_buffers.size());
+    for (int i = 0; i < result_buffers.size(); ++i) {
+      TF_ASSIGN_OR_RETURN(
+          reference_result_buffers[i],
+          runtime_arguments.rz_buffers.RedzoneAllocator().AllocateBytes(
+              result_buffers[i].size()));
+      TF_RETURN_IF_ERROR(stream->Memcpy(&reference_result_buffers[i],
+                                        result_buffers[i],
+                                        result_buffers[i].size()));
+    }
+    (*reference_result) = {alg, reference_result_buffers};
+  }
+
+  return result;
+}
+
+absl::StatusOr<AutotuneResult>
+GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
+    const HloCustomCallInstruction* instr) {
+  AutotuneCacheKey instruction_info{config_.GetModelStr(), *instr};
+  std::string instr_str(instruction_info.GetHlo());
+  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
+      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr_str));
+
+  const DebugOptions& debug_options =
+      instr->GetModule()->config().debug_options();
+  const bool crash_on_checking_failure =
+      debug_options.xla_gpu_crash_on_verification_failures();
+
+  std::string blas_version;
+  se::StreamExecutor* stream_exec = config_.GetExecutor();
+  if (auto* blas = stream_exec->AsBlas()) {
+    (void)blas->GetVersion(&blas_version);
+  }
+
+  std::vector<AlgorithmDesc> disabled_algos;
+  TF_ASSIGN_OR_RETURN(
+      AutotuneRuntimeArguments runtime_arguments,
+      AutotuneRuntimeArguments::FromInstruction(instr, config_, debug_options));
+  if (runtime_arguments.canonical_hlo.has_value()) {
+    disabled_algos = GetDisabledConvAlgorithms(
+        GetComputeCapability(stream_exec), GetCudnnVersion(stream_exec),
+        blas_version, runtime_arguments.canonical_hlo.value());
+  }
+
+  const bool cudnn_frontend_enabled =
+      debug_options.xla_gpu_enable_cudnn_frontend();
+  bool allow_tf32 = true;
+  // TODO(b/284371623): Properly set allow_tf32 even if instr==nullptr, which is
+  // the case when running an AOT compiled executable with runtime autotuning.
+  if (instr) {
+    allow_tf32 = absl::c_all_of(
+        instr->precision_config().operand_precision(),
+        [](int precision) { return precision <= PrecisionConfig::HIGH; });
+  }
+  const se::NumericOptions numeric_options{
+      RequireDeterminism(instr->GetModule()->config()), allow_tf32};
+
+  // Use the first supported algorithm as the reference result. Any algorithm
+  // would do; being chosen as the reference does not imply that the algorithm
+  // is considered correct.
+  std::optional<ReferenceResult> reference_result;
+
+  TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
+  TF_ASSIGN_OR_RETURN(
+      std::vector<GenericConvRunner> runners,
+      GetAlgorithms(runtime_arguments.gpu_conv_config, stream,
+                    cudnn_frontend_enabled,
+                    /* use_fallback = */ false, numeric_options));
+
+  std::vector<AutotuneResult> profile_results;
+  for (auto& runner_cache : runners) {
+    TF_ASSIGN_OR_RETURN(
+        auto result,
+        AutotuneOneConvRunner(&runner_cache, &reference_result, disabled_algos,
+                              instruction_info, runtime_arguments));
+    profile_results.emplace_back(std::move(result));
+  }
+
+  // If any algorithm has worked, we'll skip the fallback algorithms, since
+  // they include some very slow algorithms.
+  if (!reference_result) {
+    LOG(WARNING) << "None of the algorithms provided by cuDNN heuristics "
+                    "worked; trying fallback algorithms.";
+    if (runtime_arguments.canonical_hlo.has_value()) {
+      LOG(WARNING) << "Conv: " << runtime_arguments.canonical_hlo.value();
+    }
+
+    TF_ASSIGN_OR_RETURN(
+        std::vector<GenericConvRunner> fallback_runners,
+        GetAlgorithms(runtime_arguments.gpu_conv_config, stream,
+                      cudnn_frontend_enabled,
+                      /* use_fallback = */ true, numeric_options));
+
+    for (auto& runner_cache : fallback_runners) {
+      TF_ASSIGN_OR_RETURN(
+          auto result, AutotuneOneConvRunner(&runner_cache, &reference_result,
+                                             disabled_algos, instruction_info,
+                                             runtime_arguments));
+      profile_results.emplace_back(std::move(result));
+    }
+  }
+
+  // Log the autotuning result.
+  if (instr) {
+    AutotuningLog log;
+    {
+      ConvInstructionLog instr_log;
+      *instr_log.mutable_instruction() = instr->ToProto();
+      for (int i = 0; i < instr->operand_count(); i++) {
+        *instr_log.add_operand_shapes() = instr->operand(i)->shape().ToProto();
+        instr_log.add_operand_addresses(reinterpret_cast<uint64_t>(
+            runtime_arguments.rz_buffers.input_buffers()[i].opaque()));
+      }
+      for (se::DeviceMemoryBase result_buffer :
+           runtime_arguments.rz_buffers.output_buffers()) {
+        instr_log.add_result_addresses(
+            reinterpret_cast<uint64_t>(result_buffer.opaque()));
+      }
+      log.mutable_instr()->PackFrom(instr_log);
+    }
+    for (const auto& profile : profile_results) {
+      *log.add_results() = profile;
+    }
+    *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
+    *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
+    log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
+    log.set_blas_version(blas_version);
+    VLOG(2) << "Autotuning result: " << log.ShortDebugString();
+    // If we crash on checking failure, we are in a testing/benchmark mode.
+    if (crash_on_checking_failure) {
+      // Crash on miscompares and redzone violations if desired.
+      for (const auto& profile : profile_results) {
+        if (profile.has_failure() &&
+            profile.failure().kind() != AutotuneResult::DISQUALIFIED) {
+          LOG(FATAL) << "crash_on_checking_failure encountered errors:\n\n"
+                     << log.DebugString();  // NOLINT
+        }
+      }
+    }
+  }
+
+  TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm,
+                      PickBestResult(profile_results, instr_str,
+                                     runtime_arguments.hlo_module_config));
+  return selected_algorithm;
+}
+#endif
+
+absl::StatusOr<AutotuneResult>
+GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
+    const HloCustomCallInstruction* instr) {
+  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
+      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));
+
+  const bool allow_tf32 = absl::c_all_of(
+      instr->precision_config().operand_precision(),
+      [](int precision) { return precision <= PrecisionConfig::HIGH; });
+  const se::NumericOptions numeric_options{
+      RequireDeterminism(instr->GetModule()->config()), allow_tf32};
+
+  se::StreamExecutor* stream_exec = config_.GetExecutor();
+  const auto device_ordinal = stream_exec->device_ordinal();
+  std::vector<se::DeviceMemoryBase> operand_buffers;
+
+  // allocator either points to this->allocator_ or, if that's null, to a
+  // se::StreamExecutorMemoryAllocator for stream_exec.
+  se::DeviceMemoryAllocator* allocator = config_.GetAllocator();
+  ScratchAllocator input_output_allocator(device_ordinal, allocator);
+  TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
+  const auto initialize_buffer = [stream](DeviceMemoryBase buffer) {
+    // Although we don't have evidence this matters, zero out the buffers
+    // before autotuning.  It's conceivable that using uninitialized memory as
+    // the inputs might affect performance if e.g. the inputs contain
+    // denormals, and zeroing is cheap enough to do anyway.
+    return stream->MemZero(&buffer, buffer.size());
+  };
+
+  // Allocate space for the input, filter, and output of the convolution.  We
+  // use a ScratchAllocator for this instead of calling allocator_ directly so
+  // that our allocations don't leak.
+  for (const auto* operand : instr->operands()) {
+    TF_ASSIGN_OR_RETURN(auto buffer,
+                        input_output_allocator.AllocateBytes(
+                            ShapeUtil::ByteSizeOf(operand->shape())));
+    TF_RETURN_IF_ERROR(initialize_buffer(buffer));
+    operand_buffers.push_back(buffer);
+  }
+
+  std::vector<se::DeviceMemoryBase> result_buffers(
+      instr->shape().tuple_shapes_size());
+  if (instr->shape().IsTuple()) {
+    for (int i = 0; i < instr->shape().tuple_shapes_size(); ++i) {
+      TF_ASSIGN_OR_RETURN(
+          result_buffers[i],
+          input_output_allocator.AllocateBytes(
+              ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(i))));
+      TF_RETURN_IF_ERROR(initialize_buffer(result_buffers[i]));
+    }
+  } else {
+    TF_ASSIGN_OR_RETURN(
+        result_buffers[0],
+        input_output_allocator.AllocateBytes(
+            ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
+    TF_RETURN_IF_ERROR(initialize_buffer(result_buffers[0]));
+  }
+
+  ScratchAllocator scratch_allocator(device_ordinal, allocator);
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners,
+      GetMIOpenAlgorithms(instr, absl::MakeSpan(operand_buffers),
+                          absl::MakeSpan(result_buffers), stream_exec,
+                          &scratch_allocator, stream, numeric_options));
+
+  std::vector<AutotuneResult> profile_results;
+
+  if (runners.size() == 1) {
+    TF_ASSIGN_OR_RETURN(auto alg, runners[0]->ToAlgorithmDesc());
+    auto algorithm_proto = alg.ToProto();
+    profile_results.emplace_back();
+    auto& result = profile_results.back();
+    *result.mutable_algorithm() = algorithm_proto;
+
+    result.set_scratch_bytes(runners[0]->GetWorkspaceSize());
+
+    // TODO(awpr): if the profile result time for a singleton algorithm is
+    // needed, plumb it via OpRunner; we'll need to do this to let TF ops avoid
+    // re-profiling ROCm algorithms anyway.
+    *result.mutable_run_time() =
+        tsl::proto_utils::ToDurationProto(absl::Milliseconds(-1));
+  } else {
+    TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
+    for (auto& runner : runners) {
+      TF_ASSIGN_OR_RETURN(auto alg, runner->ToAlgorithmDesc());
+      XLA_SCOPED_LOGGING_TIMER_LEVEL(
+          absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
+                       alg.ToString()),
+          2);
+
+      se::dnn::ProfileResult profile_result;
+      VLOG(4) << "Trying algorithm " << alg.ToString() << " for "
+              << instr->ToString();
+
+      TF_ASSIGN_OR_RETURN(
+          DeviceMemoryBase scratch_memory,
+          scratch_allocator.AllocateBytes(runner->GetWorkspaceSize()));
+
+      TF_ASSIGN_OR_RETURN(auto lazy_runner,
+                          se::dnn::LazyOpRunner<se::dnn::ConvOp>::FromOpRunner(
+                              std::move(runner)));
+
+      GenericConvRunner runner_cache(std::move(lazy_runner));
+
+      // Use assignment instead of brace-list to make GCC 4.9 happy.
+      RunConvOptions options;
+      options.profile_result = &profile_result;
+      options.runner_cache = &runner_cache;
+      absl::Status launch_status =
+          RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffers,
+                     scratch_memory, stream, options);
+
+      if (!launch_status.ok()) {
+        continue;
+      }
+
+      if (!profile_result.is_valid()) {
+        continue;
+      }
+
+      profile_results.emplace_back();
+      AutotuneResult& result = profile_results.back();
+      *result.mutable_algorithm() = alg.ToProto();
+
+      int64_t scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
+      result.set_scratch_bytes(scratch_bytes_used);
+      *result.mutable_run_time() = tsl::proto_utils::ToDurationProto(
+          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+    }
+  }
+
+  TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm,
+                      PickBestResult(profile_results, instr->ToString(),
+                                     instr->GetModule()->config()));
+  return selected_algorithm;
+}
+
+absl::StatusOr<bool> GpuConvAlgorithmPicker::RunOnInstruction(
+    HloInstruction* instr) {
+  CHECK(IsCustomCallToDnnConvolution(*instr));
+
+  const bool strict = instr->parent()
+                          ->parent()
+                          ->config()
+                          .debug_options()
+                          .xla_gpu_strict_conv_algorithm_picker();
+
+  absl::StatusOr<AutotuneResult> best_algo_or =
+      PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
+  if (!best_algo_or.ok()) {
+    auto msg = absl::StrFormat(
+        "Failed to determine best cudnn convolution algorithm for:\n%s\n\n"
+        "Original error: %s",
+        instr->ToString(), best_algo_or.status().ToString());
+
+    if (strict) {
+      return Unknown(
+          "%s\n\nTo ignore this failure and try to use a fallback algorithm "
+          "(which may have suboptimal performance), use "
+          "XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please "
+          "also file a bug for the root cause of failing autotuning.",
+          msg);
+    }
+    LOG(WARNING)
+        << msg << "\n\nAs a result, convolution performance may be suboptimal.";
+    return false;
+  }
+
+  auto best_algo = std::move(best_algo_or).value();
+  VLOG(3) << "Setting cudnn conv to use algorithm "
+          << best_algo.conv().algorithm() << " and "
+          << NumBytesToString(best_algo.scratch_bytes())
+          << " of scratch memory: " << instr->ToString()
+          << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled();
+
+  // Replace instr with a new CustomCall which has the correct algorithm, and
+  // whose output shape has the appropriate amount of scratch memory.
+  HloComputation* computation = instr->parent();
+  std::vector<Shape> new_call_element_shapes;
+  // Add the shapes of the outputs of the convolution.
+  new_call_element_shapes.reserve(instr->shape().tuple_shapes_size() - 1);
+  for (int i = 0; i < instr->shape().tuple_shapes_size() - 1; ++i) {
+    new_call_element_shapes.emplace_back(instr->shape().tuple_shapes(i));
+  }
+  // The final element is the size of the workspace.
+  new_call_element_shapes.emplace_back(
+      ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes()}));
+  Shape new_call_shape = ShapeUtil::MakeTupleShape(new_call_element_shapes);
+
+  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
+                      instr->backend_config<GpuBackendConfig>());
+  CudnnConvBackendConfig& backend_config =
+      *gpu_backend_config.mutable_cudnn_conv_backend_config();
+  *backend_config.mutable_algorithm() = best_algo.algorithm();
+  backend_config.mutable_algorithm()->mutable_workspace_size()->set_value(
+      best_algo.scratch_bytes());
+
+  HloInstruction* new_call = computation->AddInstruction(
+      instr->CloneWithNewOperands(new_call_shape, instr->operands()));
+
+  // Preserve the name of the old instruction.  This is safe because we're going
+  // to remove the old one anyway, and it makes it easier to trace how our conv
+  // is transformed through all our passes.
+  new_call->SetAndSanitizeName(instr->name());
+
+  VLOG(3) << "Replacing convolution " << instr->ToString() << " with "
+          << new_call->ToString();
+
+  TF_RETURN_IF_ERROR(new_call->set_backend_config(gpu_backend_config));
+
+  std::vector<HloInstruction*> new_tuple_elements;
+  new_tuple_elements.reserve(new_call->shape().tuple_shapes_size() - 1);
+  for (int i = 0; i < new_call->shape().tuple_shapes_size() - 1; ++i) {
+    new_tuple_elements.emplace_back(
+        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+            new_call->shape().tuple_shapes(i), new_call, i)));
+  }
+  new_tuple_elements.emplace_back(computation->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint8_t>({}))));
+
+  // Repackage new_call so it has the same shape as the original call, namely
+  // (conv_result, u8[0]).
+  HloInstruction* new_tuple = computation->AddInstruction(
+      HloInstruction::CreateTuple(new_tuple_elements));
+
+  TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple));
+  return true;
+}
+
+absl::StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
+    HloComputation* computation) {
+  std::vector<HloInstruction*> convs;
+  for (HloInstruction* instr : computation->instructions()) {
+    if (IsCandidate(instr)) {
+      convs.push_back(instr);
+    }
+  }
+
+  bool changed = false;
+  for (HloInstruction* instr : convs) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr));
+    changed |= result;
+  }
+  return changed;
+}
+
+absl::StatusOr<bool> GpuConvAlgorithmPicker::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  XLA_SCOPED_LOGGING_TIMER(
+      absl::StrCat("GpuConvAlgorithmPicker for ", module->name()));
+
+  if (!IsEnabled(module)) {
+    VLOG(3) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
+               "returning early.";
+    return false;
+  }
+
+  bool changed = false;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+    changed |= result;
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h
new file mode 100644
index 0000000..173a0c6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h
@@ -0,0 +1,153 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_CONV_ALGORITHM_PICKER_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_CONV_ALGORITHM_PICKER_H_
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/autotune_results.pb.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/gpu_conv_runner.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/stream_executor/stream_executor.h"
+
+namespace xla {
+namespace gpu {
+
+// Choose the fastest algorithm for each conv.
+// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
+// each and adding explicit scratch space to the CustomCalls.
+//
+// We pick the algorithm before fusion so that we can generate better HLO. After
+// GpuConvRewriter, our convolutions are CustomCalls which return a
+// tuple (conv_result, scratch_memory), and each conv uses 0 bytes of
+// scratch:
+//
+//   customcall = (f32[...], f32[0])
+//   return gte(customcall, 0)
+//
+// The algorithm picker then chooses the best algorithm, and potentially
+// increases the scratch space.  It replaces customcall with new_tuple,
+// giving us the following:
+//
+//   new_customcall = (f32[...], f32[N])
+//   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
+//   return gte(new_tuple, 0)
+//
+// The new tuple and gte instructions can be simplified away, because
+// nobody is expected to use the scratch value.
+//
+// However, if we were to run GpuConvAlgorithmPicker after fusion, the
+// gte(customcall, 0) would probably already be absorbed into a fusion node.  We
+// can't simplify across HloComputation boundaries, so in this case we
+// wouldn't be able to simplify away the new_tuple bits.
+//
+// It supports two modes: device and deviceless.
+// In device mode, we run autotuning on the device and store autotune results.
+//
+// In deviceless mode, we pass in some information about the device and use
+// previously stored autotune results to rewrite convolutions. If the required
+// autotune result is not stored, convolution performance will be suboptimal.
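+//
+// A minimal usage sketch (assuming an existing AutotuneConfig `config`; see
+// conv_algorithm_picker_test.cc for a complete setup):
+//
+//   HloPassPipeline pipeline("conv-autotuning");
+//   pipeline.AddPass<GpuConvAlgorithmPicker>(config);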
+class GpuConvAlgorithmPicker : public HloModulePass {
+ public:
+  explicit GpuConvAlgorithmPicker(AutotuneConfig config) : config_(config) {}
+
+  absl::string_view name() const override {
+    return "gpu-conv-algorithm-picker";
+  }
+
+  static bool IsEnabled(const HloModule* module) {
+    return module->config().debug_options().xla_gpu_autotune_level() != 0;
+  }
+
+  static bool IsCandidate(const HloInstruction* instr) {
+    return IsCustomCallToDnnConvolution(*instr);
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
+  absl::StatusOr<bool> RunOnInstruction(HloInstruction* instr);
+
+  absl::StatusOr<AutotuneResult> PickBestAlgorithm(
+      const HloCustomCallInstruction* instr);
+  absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCache(
+      const HloCustomCallInstruction* instr);
+
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
+  // Simple bundle of an algorithm and its output, for comparing results across
+  // autotuned algorithms.
+  struct ReferenceResult {
+    stream_executor::dnn::AlgorithmDesc algorithm;
+    std::vector<stream_executor::DeviceMemoryBase> buffers;
+  };
+
+  // Execution environment for autotuning. Runtime autotuning requires runtime
+  // information such as input/output buffers in order to run. It can be
+  // constructed from the autotuned instruction by FromInstruction.
+  struct AutotuneRuntimeArguments {
+    const HloModuleConfig hlo_module_config;
+    RedzoneBuffers rz_buffers;
+    const GpuConvConfig gpu_conv_config;
+    std::optional<std::string> canonical_hlo;
+
+    static absl::StatusOr<AutotuneRuntimeArguments> FromInstruction(
+        const HloCustomCallInstruction* instr, const AutotuneConfig& config,
+        const DebugOptions& debug_options);
+  };
+
+  absl::StatusOr<AutotuneResult> AutotuneOneConvRunner(
+      GenericConvRunner* runner,
+      std::optional<ReferenceResult>* reference_result,
+      absl::Span<const stream_executor::dnn::AlgorithmDesc> disabled_algos,
+      std::optional<AutotuneCacheKey> instruction_info,
+      const AutotuneRuntimeArguments& runtime_arguments);
+
+  // Pick the best algorithm for CUDA platform.
+  absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCacheCuda(
+      const HloCustomCallInstruction* instr);
+#endif
+
+  absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCacheRocm(
+      const HloCustomCallInstruction* instr);
+
+ private:
+  AutotuneConfig config_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+#endif  // XLA_SERVICE_GPU_AUTOTUNING_CONV_ALGORITHM_PICKER_H_
diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc
new file mode 100644
index 0000000..9652014
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc
@@ -0,0 +1,129 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/conv_algorithm_picker.h"
+
+#include <cstdint>
+#include <variant>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/service/platform_util.h"
+#include "xla/service/tuple_simplifier.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class GpuConvAlgorithmPickerTest : public HloTestBase {
+ public:
+  GpuConvAlgorithmPickerTest() { AutotunerUtil::ClearAutotuneResults(); }
+};
+
+TEST_F(GpuConvAlgorithmPickerTest, SetAlgorithm) {
+  constexpr absl::string_view kHlo = R"(
+HloModule module
+
+ENTRY main {
+  %arg0 = f32[3,56,56,16]{2,1,0,3} parameter(0)
+  %arg1 = f32[3,3,3,64]{2,1,0,3} parameter(1)
+  ROOT %conv = f32[54,54,16,64]{1,0,3,2} convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kHlo));
+
+  se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
+                          PlatformUtil::GetStreamExecutors(platform));
+  ASSERT_GT(executors.size(), 0);
+  se::StreamExecutor* stream_exec = executors[0];
+
+  const se::GpuComputeCapability& cc = backend()
+                                           .default_stream_executor()
+                                           ->GetDeviceDescription()
+                                           .gpu_compute_capability();
+  bool changed = false;
+  TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(ConvRewriter(cc), m.get()));
+  changed = false;
+  DebugOptions opts = DefaultDebugOptionsIgnoringFlags();
+
+  AutotuneConfig cfg{DeviceConfig{stream_exec, nullptr}, opts};
+  TF_ASSERT_OK_AND_ASSIGN(changed,
+                          RunHloPass(GpuConvAlgorithmPicker(cfg), m.get()));
+  ASSERT_TRUE(changed);
+
+  AutotuneResults results;
+  TF_ASSERT_OK(AutotunerUtil::SerializeAutotuneResults(&results));
+  ASSERT_EQ(results.results_size(), 1);
+  auto& result = *results.mutable_results(0)->mutable_result();
+  int64_t old_scratch_bytes = result.scratch_bytes();
+  int64_t new_scratch_bytes = old_scratch_bytes + 1;
+  result.set_scratch_bytes(new_scratch_bytes);
+
+  AutotunerUtil::ClearAutotuneResults();
+  TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results));
+
+  // Now send the same module through GpuConvAlgorithmPicker again.  The conv
+  // should have the new scratch bytes.
+  TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo));
+  changed = false;
+  TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(ConvRewriter(cc), m.get()));
+  changed = false;
+  TF_ASSERT_OK_AND_ASSIGN(changed,
+                          RunHloPass(GpuConvAlgorithmPicker(cfg), m.get()));
+  ASSERT_TRUE(changed);
+
+  // TupleSimplifier cleans this up a bit before we pattern-match.
+  TF_ASSERT_OK(RunHloPass(TupleSimplifier(), m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  HloInstruction* conv;
+  ASSERT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::GetTupleElement(m::CustomCall(&conv))));
+  EXPECT_THAT(
+      conv->shape(),
+      GmockMatch(m::Shape().WithSubshape(
+          {1}, m::Shape().WithElementType(U8).WithDims({new_scratch_bytes}))));
+
+  // Algorithm 14 is disabled for cuDNN 9 on V100
+  TF_ASSERT_OK_AND_ASSIGN(auto dnn_version, GetDnnVersionInfo(stream_exec));
+  if (dnn_version.major_version() >= 9 && dnn_version.major_version() < 10 &&
+      std::holds_alternative<stream_executor::CudaComputeCapability>(cc) &&
+      std::get<stream_executor::CudaComputeCapability>(cc).major == 7 &&
+      std::get<stream_executor::CudaComputeCapability>(cc).minor == 0) {
+    EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+                    ->has_cudnn_conv_backend_config() &&
+                conv->backend_config<GpuBackendConfig>()
+                        ->cudnn_conv_backend_config()
+                        .algorithm()
+                        .algo_id() != 14);
+  }
+}
+
+}  // namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc
new file mode 100644
index 0000000..a920b99
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc
@@ -0,0 +1,220 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h"
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <tuple>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/kernels/custom_kernel.h"
+#include "xla/service/gpu/kernels/custom_kernel_fusion.h"
+#include "xla/service/shaped_buffer.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/stream_executor/stream_executor_memory_allocator.h"
+#include "xla/tools/hlo_decomposer.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
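+// Extracts `fusion_instruction` into a standalone HloModule and pins its
+// custom fusion backend config to `kernel_index`, so each candidate kernel
+// can be compiled and profiled in isolation.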
+absl::StatusOr<std::unique_ptr<HloModule>> ExtractFusionModule(
+    HloInstruction* fusion_instruction, int64_t kernel_index) {
+  std::unique_ptr<HloModule> hlo_module =
+      ExtractInstructionIntoNewModule(*fusion_instruction);
+
+  HloInstruction* instruction =
+      hlo_module->entry_computation()->root_instruction();
+  GpuBackendConfig gpu_config =
+      instruction->backend_config<GpuBackendConfig>().value();
+  gpu_config.mutable_fusion_backend_config()
+      ->mutable_custom_fusion_config()
+      ->set_kernel_index(kernel_index);
+  TF_RETURN_IF_ERROR(instruction->set_backend_config(gpu_config));
+
+  return hlo_module;
+}
+
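+// Compiles and profiles each candidate kernel for `fusion_instruction`,
+// returning (kernel_index, measured duration) pairs.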
+absl::StatusOr<std::vector<std::tuple<int, absl::Duration>>> ProfileKernels(
+    std::vector<CustomKernel>& kernels, HloInstruction* fusion_instruction,
+    AutotunerCompileUtil& compile_util, const AutotuneConfig& autotune_config,
+    const DebugOptions& debug_options) {
+  se::StreamExecutor* stream_exec = autotune_config.GetExecutor();
+  std::vector<std::tuple<int, absl::Duration>> results;
+  for (int i = 0; i < kernels.size(); ++i) {
+    TF_ASSIGN_OR_RETURN(absl::StatusOr<std::unique_ptr<Executable>> executable,
+                        compile_util.Compile([&](const DebugOptions& opt) {
+                          return ExtractFusionModule(fusion_instruction, i);
+                        }));
+
+    se::DeviceMemoryAllocator* allocator = autotune_config.GetAllocator();
+    std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator;
+    if (allocator == nullptr) {
+      owned_allocator =
+          std::make_unique<se::StreamExecutorMemoryAllocator>(stream_exec);
+      allocator = owned_allocator.get();
+    }
+    TF_ASSIGN_OR_RETURN(se::Stream* const stream, autotune_config.GetStream());
+
+    TF_ASSIGN_OR_RETURN(auto rz_buffers,
+                        RedzoneBuffers::FromInstruction(
+                            *fusion_instruction, autotune_config, debug_options,
+                            RedzoneBuffers::kAllInputs));
+
+    std::optional<ScopedShapedBuffer> reference_buffer;
+    std::optional<AutotunerCompileUtil::ProfilingOutput> profiling_output;
+    TF_ASSIGN_OR_RETURN(profiling_output, compile_util.ProfileExecutable(
+                                              executable->get(), stream,
+                                              rz_buffers.input_buffers(),
+                                              rz_buffers.input_shapes()));
+    results.push_back({i, profiling_output->duration});
+  }
+  return results;
+}
+
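+// Returns the kernel index with the smallest measured duration.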
+absl::StatusOr<int> FindFastestKernel(
+    const std::vector<std::tuple<int, absl::Duration>>& results) {
+  auto iter = absl::c_min_element(
+      results, [](const std::tuple<int, absl::Duration>& lhs,
+                  const std::tuple<int, absl::Duration>& rhs) {
+        return std::get<1>(lhs) < std::get<1>(rhs);
+      });
+  if (iter == results.end()) {
+    return absl::InternalError("Failed to find fastest kernel.");
+  }
+  return std::get<0>(*iter);
+}
+
+absl::Status UpdateFusionInstructionKernelIndex(
+    HloInstruction* fusion_instruction, int kernel_index) {
+  GpuBackendConfig gpu_config =
+      fusion_instruction->backend_config<GpuBackendConfig>().value();
+  gpu_config.mutable_fusion_backend_config()
+      ->mutable_custom_fusion_config()
+      ->set_kernel_index(kernel_index);
+  TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(gpu_config));
+
+  return absl::OkStatus();
+}
+
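+// Looks up the registered CustomKernelFusion named in the instruction's
+// backend config and loads its candidate kernels for the fused computation.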
+absl::StatusOr<std::vector<CustomKernel>> LoadKernels(
+    const HloInstruction* fusion_instruction,
+    const AutotuneConfig& autotune_config) {
+  auto config = fusion_instruction->backend_config<GpuBackendConfig>()
+                    ->fusion_backend_config()
+                    .custom_fusion_config();
+  auto* registry = CustomKernelFusionRegistry::Default();
+  auto* custom_kernel_fusion = registry->Lookup(config.name());
+
+  // If the custom fusion is not found, it means that some of the build targets
+  // might not be statically linked into the binary.
+  if (custom_kernel_fusion == nullptr) {
+    return absl::InternalError(
+        absl::StrCat("Custom kernel fusion ", config.name(),
+                     " not found in a default registry."));
+  }
+
+  se::StreamExecutor* stream_exec = autotune_config.GetExecutor();
+  if (!stream_exec->SynchronizeAllActivity()) {
+    return Internal("Failed to synchronize GPU for autotuning.");
+  }
+  se::DeviceDescription device_description =
+      stream_exec->GetDeviceDescription();
+
+  // Load custom kernels that can implement a fusion computation.
+  TF_ASSIGN_OR_RETURN(
+      std::vector<CustomKernel> kernels,
+      custom_kernel_fusion->LoadKernels(
+          device_description,
+          fusion_instruction->fused_instructions_computation()));
+
+  return kernels;
+}
+
+absl::StatusOr<bool> AutotuneCustomKernelFusion(
+    HloInstruction* fusion_instruction, const AutotuneConfig& autotune_config,
+    AutotunerCompileUtil& compile_util, const DebugOptions& debug_options) {
+  int previous_kernel_index =
+      fusion_instruction->backend_config<GpuBackendConfig>()
+          ->fusion_backend_config()
+          .custom_fusion_config()
+          .kernel_index();
+
+  TF_ASSIGN_OR_RETURN(std::vector<CustomKernel> kernels,
+                      LoadKernels(fusion_instruction, autotune_config));
+
+  std::vector<std::tuple<int, absl::Duration>> results;
+  TF_ASSIGN_OR_RETURN(results,
+                      ProfileKernels(kernels, fusion_instruction, compile_util,
+                                     autotune_config, debug_options));
+
+  TF_ASSIGN_OR_RETURN(int fastest_kernel_index, FindFastestKernel(results));
+
+  TF_RETURN_IF_ERROR(UpdateFusionInstructionKernelIndex(fusion_instruction,
+                                                        fastest_kernel_index));
+
+  return previous_kernel_index != fastest_kernel_index;
+}
+}  // namespace
+
+absl::StatusOr<bool> CustomKernelFusionAutotuner::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  const DebugOptions& debug_options = module->config().debug_options();
+  TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> compile_util,
+                      AutotunerCompileUtil::Create(config_, debug_options));
+  TF_RET_CHECK(compile_util.has_value());
+
+  bool hlo_changed = false;
+  for (const HloComputation* computation : module->computations()) {
+    if (computation->IsFusionComputation()) {
+      TF_ASSIGN_OR_RETURN(
+          bool instruction_changed,
+          AutotuneCustomKernelFusion(computation->FusionInstruction(), config_,
+                                     compile_util.value(), debug_options));
+      if (instruction_changed) {
+        hlo_changed = true;
+      }
+    }
+  }
+
+  return hlo_changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h
new file mode 100644
index 0000000..07aad07
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h
@@ -0,0 +1,53 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/xla.pb.h"
+
+namespace xla {
+namespace gpu {
+
+// Find best custom kernel for custom kernel fusions.
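+//
+// Typical usage mirrors the accompanying test (a sketch; `autotune_config` is
+// assumed to be constructed elsewhere):
+//
+//   HloPassPipeline pipeline("custom_kernel_fusion_autotuner");
+//   pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);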
+class CustomKernelFusionAutotuner : public HloModulePass {
+ public:
+  explicit CustomKernelFusionAutotuner(const AutotuneConfig& config)
+      : config_(config) {}
+
+  absl::string_view name() const override {
+    return "custom_kernel-fusion-autotuner";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  const AutotuneConfig config_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_AUTOTUNING_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_
diff --git a/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc
new file mode 100644
index 0000000..8defca9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla.pb.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CustomKernelFusionAutotunerTest : public HloTestBase {
+ public:
+  CustomKernelFusionAutotunerTest()
+      : HloTestBase(/*verifier_layout_sensitive=*/false,
+                    /*allow_mixed_precision_in_hlo_verifier=*/true) {}
+
+  void SetUp() override { HloTestBase::SetUp(); }
+
+  void TearDown() override { HloTestBase::TearDown(); }
+};
+
+TEST_F(CustomKernelFusionAutotunerTest,
+       CustomKernelFusionAutotunerPassSucceeds) {
+  const std::string hlo_string = R"(
+    HloModule extracted
+
+    cutlass_gemm {
+      p0 = f32[15,19]{1,0} parameter(0)
+      p1 = f32[19,17]{1,0} parameter(1)
+      ROOT r = f32[15, 17]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+
+    ENTRY region_198.14436 {
+      p.0 = f32[15,19]{1,0} parameter(0)
+      p.1 = f32[19,17]{1,0} parameter(1)
+      ROOT cutlass_gemm = f32[15,17]{1,0} fusion(p.0, p.1), kind=kCustom, calls=cutlass_gemm, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm","kernel_index":0}},"force_earliest_schedule":false}
+    }
+  )";
+  std::unique_ptr<HloModule> hlo_module =
+      ParseAndReturnVerifiedModule(hlo_string).value();
+
+  HloPassPipeline pipeline("custom_kernel_fusion_autotuner");
+  DebugOptions debug_options;
+  AutotuneConfig autotune_config =
+      AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
+                                  backend().memory_allocator()},
+                     debug_options};
+  pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);
+  ASSERT_TRUE(pipeline.Run(hlo_module.get()).ok());
+}
+
+TEST_F(CustomKernelFusionAutotunerTest,
+       CustomKernelFusionAutotunerPassUpdatesKernelIndex) {
+  const std::string hlo_string = R"(
+    HloModule extracted
+
+    cutlass_gemm {
+      p0 = f32[15,19]{1,0} parameter(0)
+      p1 = f32[19,17]{1,0} parameter(1)
+      ROOT r = f32[15, 17]{1,0} dot(p0, p1), lhs_contracting_dims={1},
+      rhs_contracting_dims={0}
+    }
+
+    ENTRY region_198.14436 {
+      p.0 = f32[15,19]{1,0} parameter(0)
+      p.1 = f32[19,17]{1,0} parameter(1)
+      ROOT cutlass_gemm = f32[15,17]{1,0} fusion(p.0, p.1), kind=kCustom,
+      calls=cutlass_gemm,
+      backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm","kernel_index":-1}},"force_earliest_schedule":false}
+    }
+  )";
+
+  HloPassPipeline pipeline("custom_kernel_fusion_autotuner");
+  DebugOptions debug_options;
+  AutotuneConfig autotune_config =
+      AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
+                                  backend().memory_allocator()},
+                     debug_options};
+  pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);
+
+  std::string expected = R"(
+    CHECK: "kernel_index":0
+  )";
+  RunAndFilecheckHloRewrite(hlo_string, std::move(pipeline), expected);
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc
new file mode 100644
index 0000000..8a870f8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc
@@ -0,0 +1,500 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/buffer_comparator.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/gpu/variant_visitor.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/blas.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/tsl/util/proto/proto_utils.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/profiler/lib/scoped_annotation.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using se::gpu::BlasLt;
+
+absl::StatusOr<BlasLt::Epilogue> AsBlasLtEpilogue(
+    GemmBackendConfig_Epilogue epilogue) {
+  switch (epilogue) {
+    case GemmBackendConfig::DEFAULT:
+      return BlasLt::Epilogue::kDefault;
+    case GemmBackendConfig::RELU:
+      return BlasLt::Epilogue::kReLU;
+    case GemmBackendConfig::GELU:
+      return BlasLt::Epilogue::kGELU;
+    case GemmBackendConfig::GELU_AUX:
+      return BlasLt::Epilogue::kGELUWithAux;
+    case GemmBackendConfig::BIAS:
+      return BlasLt::Epilogue::kBias;
+    case GemmBackendConfig::BIAS_RELU:
+      return BlasLt::Epilogue::kBiasThenReLU;
+    case GemmBackendConfig::BIAS_GELU:
+      return BlasLt::Epilogue::kBiasThenGELU;
+    case GemmBackendConfig::BIAS_GELU_AUX:
+      return BlasLt::Epilogue::kBiasThenGELUWithAux;
+    default:
+      return Internal("Unsupported Epilogue.");
+  }
+}
+
+class GemmAutotuner {
+  const AutotuneConfig& autotune_config_;
+  RedzoneBuffers rz_buffers_;
+  se::Stream* stream_ = nullptr;
+  bool deterministic_ops_ = false;
+  size_t solutions_limit_ = 0;
+  size_t num_algorithms_left_ = 0;
+
+ public:
+  explicit GemmAutotuner(const AutotuneConfig& autotune_config)
+      : autotune_config_(autotune_config) {}
+
+  size_t num_algorithms_left() const { return num_algorithms_left_; }
+
+  absl::StatusOr<AutotuneResult> operator()(const HloInstruction* gemm,
+                                            const AutotuneCacheKey& key) {
+    num_algorithms_left_ = 0;
+    if (autotune_config_.IsDeviceless()) {
+      // Return an empty result; tuning will happen at runtime.
+      return AutotuneResult{};
+    }
+    VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString();
+
+    TF_ASSIGN_OR_RETURN(stream_, autotune_config_.GetStream());
+    const DebugOptions& debug_options =
+        gemm->GetModule()->config().debug_options();
+    deterministic_ops_ = RequireDeterminism(gemm->GetModule()->config());
+    solutions_limit_ = debug_options.xla_gpu_autotune_max_solutions();
+
+    TF_ASSIGN_OR_RETURN(auto gemm_config, GemmConfig::For(gemm));
+
+    // Don't run autotuning concurrently on the same GPU.
+    absl::MutexLock gpu_lock(&GetGpuMutex(stream_->parent()));
+
+    TF_ASSIGN_OR_RETURN(rz_buffers_, RedzoneBuffers::FromInstruction(
+                                         *gemm, autotune_config_, debug_options,
+                                         RedzoneBuffers::kAllInputsAllOutputs));
+
+    return IsCublasLtMatmul(*gemm) || IsCublasLtMatmulF8(*gemm)
+               ? TuneGpuBlasLt(gemm, gemm_config)
+               : TuneGpuBlas(gemm, gemm_config);
+  }
+
+ private:
+  se::DeviceMemoryBase LhsBuffer() { return rz_buffers_.input_buffers().at(0); }
+  se::DeviceMemoryBase RhsBuffer() { return rz_buffers_.input_buffers().at(1); }
+  se::DeviceMemoryBase OutputBuffer() {
+    return rz_buffers_.output_buffers().at(0);
+  }
+
+  const Shape& GetOutputShape(const HloInstruction* gemm) {
+    return gemm->shape().IsTuple() ? gemm->shape().tuple_shapes(0)
+                                   : gemm->shape();
+  }
+
+  absl::StatusOr<AutotuneResult> TuneGpuBlasLt(const HloInstruction* gemm,
+                                               const GemmConfig& gemm_config) {
+    auto workspace_buffer =
+        rz_buffers_.output_buffers().at(gemm->shape().tuple_shapes_size() - 1);
+
+    GpuBackendConfig gpu_config =
+        gemm->backend_config<GpuBackendConfig>().value();
+    const GemmBackendConfig& backend_config = gpu_config.gemm_backend_config();
+
+    bool has_matrix_bias = gemm_config.beta != 0.;
+
+    TF_ASSIGN_OR_RETURN(
+        bool has_vector_bias,
+        gpublas_lt::EpilogueAddsVectorBias(backend_config.epilogue()));
+
+    TF_ASSIGN_OR_RETURN(
+        bool has_aux_output,
+        gpublas_lt::EpilogueHasAuxiliaryOutput(backend_config.epilogue()));
+
+    TF_ASSIGN_OR_RETURN(auto epilogue,
+                        AsBlasLtEpilogue(backend_config.epilogue()));
+
+    se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer,
+        d_scale_buffer, d_amax_buffer, bias_buffer, aux_buffer;
+
+    if (has_vector_bias) {
+      bias_buffer = rz_buffers_.input_buffers().at(has_matrix_bias ? 3 : 2);
+    }
+    if (has_aux_output) {
+      aux_buffer = rz_buffers_.output_buffers().at(1);
+    }
+
+    TF_ASSIGN_OR_RETURN(auto plan,
+                        BlasLt::GetMatmulPlan(stream_, gemm_config, epilogue));
+
+    TF_ASSIGN_OR_RETURN(
+        auto algorithms,
+        plan->GetAlgorithms(/*max_algorithm_count*/ 128,
+                            /*max_workspace_size*/ workspace_buffer.size()));
+
+    auto tuned_func = [&](const BlasLt::MatmulAlgorithm& algorithm)
+        -> absl::StatusOr<se::blas::ProfileResult> {
+      // Run a warmup iteration without the profiler active.
+      TF_RETURN_IF_ERROR(plan->ExecuteOnStream(
+          stream_, LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(),
+          bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer,
+          c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm,
+          workspace_buffer));
+      se::blas::ProfileResult profile_result;
+      profile_result.set_warmup_run_executed(true);
+      TF_RETURN_IF_ERROR(plan->ExecuteOnStream(
+          stream_, LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(),
+          bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer,
+          c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm,
+          workspace_buffer, &profile_result));
+      return std::move(profile_result);
+    };
+
+    return GetBestAlgorithm<BlasLt::MatmulAlgorithm>(
+        gemm, algorithms, gemm_config.beta, /*return_algo_index*/ true,
+        tuned_func);
+  }
+
+  absl::StatusOr<AutotuneResult> TuneGpuBlas(const HloInstruction* gemm,
+                                             const GemmConfig& gemm_config) {
+    auto workspace_buffer = rz_buffers_.output_buffers().at(1);
+
+    std::vector<se::blas::AlgorithmType> algorithms;
+    TF_ASSIGN_OR_RETURN(GemmConfig::DescriptorsTuple desc,
+                        gemm_config.GetMatrixDescriptors(
+                            LhsBuffer(), RhsBuffer(), OutputBuffer()));
+
+    auto blas = stream_->parent()->AsBlas();
+    if (blas == nullptr) {
+      return absl::InternalError("No BLAS support for stream");
+    }
+    blas->GetBlasGemmAlgorithms(stream_, desc.lhs, desc.rhs, &desc.output,
+                                &gemm_config.alpha, &gemm_config.beta,
+                                &algorithms);
+
+    auto tuned_func = [&](const se::blas::AlgorithmType& algorithm)
+        -> absl::StatusOr<se::blas::ProfileResult> {
+      // Do a warm-up run first, without a profile result. RunGemm swallows
+      // error codes when profile_result is passed, as it is in the measurement
+      // below, but not otherwise. It is, therefore, consistent to ignore the
+      // error code here.
+      static_cast<void>(RunGemm(gemm_config, LhsBuffer(), RhsBuffer(),
+                                OutputBuffer(), workspace_buffer,
+                                deterministic_ops_, stream_, algorithm));
+      se::blas::ProfileResult profile_result;
+      // Allow GpuTimer to use its delay kernel implementation to improve
+      // accuracy.
+      profile_result.set_warmup_run_executed(true);
+      // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail
+      // for all algorithms if we're targeting < sm_50. But because we pass a
+      // non-null ProfileResult, DoGemmWithAlgorithm should always return true,
+      // and the actual success status is reported in ProfileResult::is_valid.
+      TF_RETURN_IF_ERROR(RunGemm(gemm_config, LhsBuffer(), RhsBuffer(),
+                                 OutputBuffer(), workspace_buffer,
+                                 deterministic_ops_, stream_, algorithm,
+                                 &profile_result));
+      return std::move(profile_result);
+    };
+
+    return GetBestAlgorithm<se::blas::AlgorithmType>(
+        gemm, algorithms, gemm_config.beta, /*return_algo_index*/ false,
+        tuned_func);
+  }
+
+  // Returns the best-performing result. If `return_algo_index` is true, the
+  // algorithm ID in the result is the index of the fastest algorithm within
+  // `algorithms`; otherwise it is the backend's native algorithm ID.
+  template <typename AlgoT, typename TunedFunc>
+  absl::StatusOr<AutotuneResult> GetBestAlgorithm(
+      const HloInstruction* gemm, absl::Span<const AlgoT> algorithms,
+      double beta, bool return_algo_index, TunedFunc&& run_benchmark) {
+    static_assert(std::is_invocable_r_v<absl::StatusOr<se::blas::ProfileResult>,
+                                        TunedFunc, const AlgoT&>,
+                  "Tuned function has incorrect prototype!");
+
+    if (!stream_->parent()->SynchronizeAllActivity()) {
+      return Internal("Failed to synchronize GPU for autotuning.");
+    }
+    tsl::profiler::ScopedAnnotation annotation([&] {
+      return absl::StrFormat("XlaAutotunerMeasurement:#hlo_op=%s#",
+                             gemm->name());
+    });
+
+    auto& hlo_module_config = gemm->GetModule()->mutable_config();
+    const auto& output_shape = GetOutputShape(gemm);
+
+    se::DeviceMemoryBase reference_buffer;
+    if (autotune_config_.should_check_correctness()) {
+      TF_ASSIGN_OR_RETURN(reference_buffer,
+                          rz_buffers_.RedzoneAllocator().AllocateBytes(
+                              ShapeUtil::ByteSizeOf(output_shape)));
+    }
+
+    // Do not print error messages if should_skip_wrong_results() is ON.
+    BufferComparator comparator(
+        output_shape,
+        hlo_module_config.debug_options().xla_gpu_autotune_gemm_rtol(),
+        /* verbose */ !autotune_config_.should_skip_wrong_results());
+    std::vector<AutotuneResult> results;
+    results.reserve(algorithms.size());
+    std::optional<int64_t> reference_algorithm;
+
+    auto num = algorithms.size();
+    if (solutions_limit_ > 0) num = std::min(num, solutions_limit_);
+    for (size_t i = 0; i < num; i++) {
+      const AlgoT& algorithm = algorithms[i];
+      // Make sure the output buffer always has the same value if we use
+      // the bias parameter.
+      if (autotune_config_.should_reinit_output_buffer() && beta != 0) {
+        int64_t rng_state = 0;
+        InitializeBuffer(stream_, output_shape.element_type(), &rng_state,
+                         OutputBuffer());
+      }
+      TF_ASSIGN_OR_RETURN(auto profile_result, run_benchmark(algorithm));
+
+      AutotuneResult& result = results.emplace_back();
+      result.mutable_gemm()->set_algorithm(profile_result.algorithm());
+
+      if (!profile_result.is_valid()) {  // Unsupported algorithm.
+        result.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED);
+        continue;
+      }
+
+      VLOG(2) << "gemm algorithm " << profile_result.algorithm() << " took "
+              << profile_result.elapsed_time_in_ms() << "ms";
+
+      *result.mutable_run_time() = tsl::proto_utils::ToDurationProto(
+          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+
+      if (!autotune_config_.should_check_correctness()) {
+        num_algorithms_left_++;
+        continue;
+      }
+      TF_ASSIGN_OR_RETURN(
+          se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
+          rz_buffers_.RedzoneAllocator().CheckRedzones());
+
+      if (!rz_check_status.ok()) {
+        result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
+        *result.mutable_failure()->mutable_msg() =
+            rz_check_status.RedzoneFailureMsg();
+        LOG(ERROR) << "Detected out-of-bounds write in gemm buffer";
+        CHECK(!autotune_config_.should_crash_on_check_failure());
+        continue;
+      }
+
+      num_algorithms_left_++;
+      if (!reference_algorithm) {
+        TF_RETURN_IF_ERROR(stream_->Memcpy(&reference_buffer, OutputBuffer(),
+                                           OutputBuffer().size()));
+        reference_algorithm = profile_result.algorithm();
+        continue;
+      }
+      // Perform the comparison versus the reference algorithm.
+      TF_ASSIGN_OR_RETURN(
+          bool outputs_match,
+          comparator.CompareEqual(stream_, /*current=*/OutputBuffer(),
+                                  /*expected=*/reference_buffer));
+      if (!outputs_match) {
+        LOG(ERROR) << "Results mismatch between different GEMM algorithms. "
+                   << "This is likely a bug/unexpected loss of precision.";
+        CHECK(!autotune_config_.should_crash_on_check_failure());
+
+        // By default, the autotuner does NOT actually skip wrong results, but
+        // merely prints the error message above, which can be confusing. When
+        // should_skip_wrong_results() is set to true, solutions with accuracy
+        // problems are disqualified.
+        auto kind = AutotuneResult::WRONG_RESULT;
+        if (autotune_config_.should_skip_wrong_results()) {
+          kind = AutotuneResult::DISQUALIFIED;
+          num_algorithms_left_--;  // Decrement again since we disqualified it.
+        }
+        result.mutable_failure()->set_kind(kind);
+        result.mutable_failure()->mutable_reference_gemm()->set_algorithm(
+            *reference_algorithm);
+      }
+    }  // for algorithms
+
+    absl::StatusOr<AutotuneResult> best =
+        PickBestResult(results, gemm->ToString(), hlo_module_config);
+    if (best.ok()) {
+      // Note that cuBLASLt returns an opaque object as the algorithm ID, so we
+      // need to convert it to an index into the algorithms list (otherwise we
+      // could not store this ID inside a gemm_backend_config). In contrast,
+      // legacy cuBLAS returns a 32-bit integer algorithm ID which can be
+      // stored directly inside an HLO (hence return_algo_index is false in the
+      // cuBLAS case).
+      if (!return_algo_index) return best;
+      // Otherwise, map a real algorithm ID to its index among the results.
+      for (size_t i = 0; i < results.size(); ++i) {
+        if (best->gemm().algorithm() == results[i].gemm().algorithm()) {
+          best->mutable_gemm()->set_algorithm(i);
+          return best;
+        }
+      }
+      return Internal("unknown best algorithm");
+    }
+    LOG(WARNING) << "Failed to find best cuBLAS algorithm, GEMM performance "
+                    "might be suboptimal: "
+                 << best.status();
+    return AutotuneResult{};
+  }  // GetBestAlgorithm
+};  // GemmAutotuner
+
+// Autotune a single GEMM instruction: use a cached autotune result when one is
+// available, otherwise run GemmAutotuner (which returns an empty result in
+// deviceless mode, where no stream executor is available).
+absl::StatusOr<bool> RunOnInstruction(HloInstruction* gemm,
+                                      const AutotuneConfig& config,
+                                      size_t* num_algorithms_left) {
+  VLOG(3) << "Loading the autotune result of GemmThunk " << gemm->ToString();
+
+  GpuBackendConfig gpu_config =
+      gemm->backend_config<GpuBackendConfig>().value();
+  GemmBackendConfig& backend_config = *gpu_config.mutable_gemm_backend_config();
+
+  *num_algorithms_left = 0;
+  // Degenerate gemms are replaced with a memzero operation, so there is no
+  // need to autotune them.
+  if (backend_config.alpha_real() == 0.0 &&
+      backend_config.alpha_imag() == 0.0 && backend_config.beta() == 0.0) {
+    VLOG(3) << "Skip degenerate gemm instruction auto tuning";
+    return false;
+  }
+
+  AutotuneCacheKey key(config.GetModelStr(), *gemm);
+  GemmAutotuner autotuner(config);
+  TF_ASSIGN_OR_RETURN(AutotuneResult algorithm,
+                      AutotunerUtil::Autotune(
+                          gemm, config, [&] { return autotuner(gemm, key); }));
+
+  *num_algorithms_left = autotuner.num_algorithms_left();
+  auto old_algorithm = backend_config.selected_algorithm();
+  bool update_algorithm =
+      IsCublasLtMatmulF8(*gemm) ||
+      std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) {
+                                  // We only set the 'algorithm' field on
+                                  // non-Ampere architectures, as for Ampere
+                                  // it's ignored in any case.
+                                  return !cc.IsAtLeast(
+                                      se::CudaComputeCapability::AMPERE);
+                                },
+                                [](const se::RocmComputeCapability&) {
+                                  return true;  // TODO: not decided yet
+                                }},
+                 config.GetGpuComputeCapability());
+
+  if (update_algorithm) {
+    int64_t new_algorithm{};
+    if (algorithm.has_gemm()) {
+      new_algorithm = algorithm.gemm().algorithm();
+    } else {
+      // NOTE: runtime autotuning is no longer available => set to default
+      new_algorithm = se::blas::kDefaultAlgorithm;
+    }
+
+    if (new_algorithm == old_algorithm &&
+        backend_config.has_selected_algorithm()) {
+      // No update of the backend config is needed if the algorithm hasn't
+      // changed and it was already set explicitly.
+      return false;
+    }
+
+    backend_config.set_selected_algorithm(new_algorithm);
+    TF_RETURN_IF_ERROR(gemm->set_backend_config(gpu_config));
+    return true;  // We changed `gemm`
+  }
+
+  return false;  // No change to `gemm`
+}
+
+absl::StatusOr<bool> RunOnComputation(HloComputation* computation,
+                                      AutotuneConfig config,
+                                      size_t* num_algorithms_left) {
+  bool changed = false;
+
+  for (HloInstruction* instr : computation->instructions()) {
+    if (IsCublasGemm(*instr)) {
+      size_t num_left;
+      TF_ASSIGN_OR_RETURN(bool result,
+                          RunOnInstruction(instr, config, &num_left));
+      // Gathering statistics on the algorithms left after tuning (for testing)
+      *num_algorithms_left = std::max(*num_algorithms_left, num_left);
+      changed |= result;
+    }
+  }
+  return changed;
+}
+
+}  // namespace
+
+absl::StatusOr<bool> GemmAlgorithmPicker::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  XLA_SCOPED_LOGGING_TIMER(
+      absl::StrCat("GemmAlgorithmPicker for ", module->name()));
+
+  num_algorithms_left_ = 0;
+  if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
+    VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early";
+    return false;
+  }
+
+  bool changed = false;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, config_,
+                                                      &num_algorithms_left_));
+    changed |= result;
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h
new file mode 100644
index 0000000..2373583
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h
@@ -0,0 +1,70 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_GEMM_ALGORITHM_PICKER_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_GEMM_ALGORITHM_PICKER_H_
+
+#include <functional>
+#include <optional>
+#include <string_view>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/autotune_results.pb.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/shape.h"
+#include "xla/stream_executor/blas.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream_executor.h"
+
+namespace xla {
+namespace gpu {
+
+// GemmAlgorithmPicker supports two modes: device and deviceless.
+// In device mode, we run autotuning on the device and store autotune results.
+// In deviceless mode, we pass in some information related to the device and
+// use stored autotune results to rewrite Gemm instructions. If the required
+// autotune result is not stored, the algorithm is set to kRuntimeAutotuning.
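+//
+// A minimal usage sketch (illustrative only; `stream_exec` and `allocator`
+// stand for an existing stream executor and device memory allocator, and the
+// exact pipeline wiring may differ in practice):
+//
+//   DebugOptions debug_options = module->config().debug_options();
+//   AutotuneConfig config{DeviceConfig{stream_exec, allocator}, debug_options};
+//   HloPassPipeline pipeline("gemm-autotuning");
+//   pipeline.AddPass<GemmAlgorithmPicker>(config);
+//   TF_RETURN_IF_ERROR(pipeline.Run(module).status());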
+class GemmAlgorithmPicker : public HloModulePass {
+ public:
+  explicit GemmAlgorithmPicker(AutotuneConfig config) : config_(config) {}
+
+  absl::string_view name() const override { return "gemm-algorithm-picker"; }
+
+  size_t num_algorithms_left() const { return num_algorithms_left_; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  AutotuneConfig config_;
+  // The number of valid algorithms used for autotuning (from the last call),
+  // to be used for testing purposes.
+  size_t num_algorithms_left_ = 0;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_AUTOTUNING_GEMM_ALGORITHM_PICKER_H_
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc
new file mode 100644
index 0000000..f1bd187
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc
@@ -0,0 +1,301 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h"
+
+#include <cstdint>
+#include <variant>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+#include "xla/service/gpu/variant_visitor.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/service/platform_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+#include "tsl/protobuf/dnn.pb.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class GemmAlgorithmPickerTest : public HloTestBase,
+                                public ::testing::WithParamInterface<bool> {
+ public:
+  GemmAlgorithmPickerTest() { AutotunerUtil::ClearAutotuneResults(); }
+
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_enable_cublaslt(GetParam());
+    debug_options.set_xla_gpu_enable_triton_gemm(false);
+    return debug_options;
+  }
+
+  const se::DeviceDescription& device_desc() {
+    return backend().default_stream_executor()->GetDeviceDescription();
+  }
+
+  se::StreamExecutor* stream_exec() {
+    return backend().default_stream_executor();
+  }
+  const se::DeviceDescription& gpu_device_desc() {
+    return stream_exec()->GetDeviceDescription();
+  }
+  const se::GpuComputeCapability& gpu_comp() {
+    return gpu_device_desc().gpu_compute_capability();
+  }
+
+  void SetUp() override {
+    std::string_view name =
+        ::testing::UnitTest::GetInstance()->current_test_info()->name();
+    // We need special handling for BlasGetVersion test.
+    bool blas_get_version = name.rfind("BlasGetVersion") == 0;
+
+    std::visit(
+        VariantVisitor{
+            [&](const se::CudaComputeCapability& cc) {
+              if (!blas_get_version && cc.IsAtLeastAmpere()) {
+                GTEST_SKIP()
+                    << "Skipping this test for Ampere+ as it is supported "
+                       "and recommended with the Nvidia Volta+ GPUs.";
+              }
+            },
+            [&](const se::RocmComputeCapability& cc) {
+              if (blas_get_version) {
+                auto version = std::stol(device_desc().runtime_version());
+                if (version < 60200) {
+                  GTEST_SKIP()
+                      << "This API is not available on ROCM 6.1 and below.";
+                }
+              } else if (GetDebugOptionsForTest().xla_gpu_enable_cublaslt() &&
+                         !cc.has_hipblaslt()) {
+                GTEST_SKIP() << "No gpublas-lt support on this architecture!";
+              }
+            }},
+        gpu_comp());
+  }
+};
+
+TEST_P(GemmAlgorithmPickerTest, BlasGetVersion) {
+  auto* blas = backend().default_stream_executor()->AsBlas();
+  ASSERT_TRUE(blas != nullptr);
+  std::string version;
+  ASSERT_TRUE(blas->GetVersion(&version).ok());
+  VLOG(0) << "Blas version: " << version;
+  ASSERT_TRUE(!version.empty());
+}
+
+TEST_P(GemmAlgorithmPickerTest, SkipAlgorithmsWithAccuracyCheck) {
+  constexpr absl::string_view kHlo = R"(
+HloModule module
+
+ENTRY main {
+  %arg0 = f32[100,100]{1,0} parameter(0)
+  %arg1 = f32[100,100]{1,0} parameter(1)
+  ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  auto module_cfg = GetModuleConfigForTest();
+  auto debug_opts = module_cfg.debug_options();
+  size_t num_left1 = 0, num_left2 = 0;
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kHlo, module_cfg));
+
+  {
+    // Run first with default settings (autotune level = 4) and keep the number
+    // of algorithms left after autotuning.
+    TF_ASSERT_OK_AND_ASSIGN(
+        bool changed,
+        RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040),
+                   module.get()));
+
+    AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts};
+    GemmAlgorithmPicker gpicker(cfg);
+    // Note that we do not care whether the algorithm index has changed; what
+    // matters is the number of algorithms left after filtering.
+    TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get()));
+    num_left1 = gpicker.num_algorithms_left();
+    if (num_left1 < 2) {
+      GTEST_SKIP() << "Too few algorithms left after the first step";
+    }
+  }
+
+  // Clear cache before the second run!
+  AutotunerUtil::ClearAutotuneResults();
+  {
+    // Run once again, but now with autotune level 5 and an embarrassingly
+    // tight rtol, which should disqualify most of the algorithms.
+
+    // Note that there are "two sources of truth" for GemmAlgorithmPicker: the
+    // debug_options are used to initialize both the 'HloModuleConfig' and the
+    // 'AutotuneConfig'.
+    debug_opts.set_xla_gpu_autotune_gemm_rtol(1e-12);
+    debug_opts.set_xla_gpu_autotune_level(5);
+    module->mutable_config().set_debug_options(debug_opts);
+    TF_ASSERT_OK_AND_ASSIGN(
+        bool changed,
+        RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040),
+                   module.get()));
+
+    AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts};
+    GemmAlgorithmPicker gpicker(cfg);
+    TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get()));
+    num_left2 = gpicker.num_algorithms_left();
+  }
+  // Assert that we have fewer algorithms left after the second run.
+  ASSERT_TRUE(num_left1 > num_left2);
+}
+
+TEST_P(GemmAlgorithmPickerTest, SetAlgorithm) {
+  constexpr absl::string_view kHlo = R"(
+HloModule module
+
+ENTRY main {
+  %arg0 = f32[100,100]{1,0} parameter(0)
+  %arg1 = f32[100,100]{1,0} parameter(1)
+  ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  auto module_cfg = GetModuleConfigForTest();
+  TF_ASSERT_OK_AND_ASSIGN(auto m,
+                          ParseAndReturnVerifiedModule(kHlo, module_cfg));
+
+  bool changed = false;
+  TF_ASSERT_OK_AND_ASSIGN(
+      changed,
+      RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), m.get()));
+  changed = false;
+  DebugOptions opts;
+  AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, opts};
+  TF_ASSERT_OK_AND_ASSIGN(changed,
+                          RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
+  ASSERT_TRUE(changed);
+
+  AutotuneResults results;
+  TF_ASSERT_OK(AutotunerUtil::SerializeAutotuneResults(&results));
+  ASSERT_EQ(results.results_size(), 1);
+  auto& result = *results.mutable_results(0)->mutable_result();
+  int64_t old_algo_id = result.algorithm().algo_id();
+  int64_t new_algo_id = old_algo_id + 1;
+  result.mutable_gemm()->set_algorithm(new_algo_id);
+
+  AutotunerUtil::ClearAutotuneResults();
+  TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results));
+
+  // Now send the same module through GemmAlgorithmPicker again.  The dot should
+  // have the new algorithm.
+  TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg));
+  changed = false;
+  TF_ASSERT_OK_AND_ASSIGN(
+      changed,
+      RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), m.get()));
+  changed = false;
+  TF_ASSERT_OK_AND_ASSIGN(changed,
+                          RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
+  ASSERT_TRUE(changed);
+
+  SCOPED_TRACE(m->ToString());
+  HloInstruction* dot;
+  ASSERT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0)));
+
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          dot->backend_config<GpuBackendConfig>());
+  const GemmBackendConfig& config = gpu_config.gemm_backend_config();
+  EXPECT_EQ(config.selected_algorithm(), new_algo_id);
+}
+
+TEST_P(GemmAlgorithmPickerTest, GetAlgorithmWithoutDevice) {
+  constexpr absl::string_view kHlo = R"(
+HloModule module
+
+ENTRY main {
+  %arg0 = f32[100,100]{1,0} parameter(0)
+  %arg1 = f32[100,100]{1,0} parameter(1)
+  ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto m, ParseAndReturnVerifiedModule(kHlo, GetModuleConfigForTest()));
+
+  bool changed = false;
+  TF_ASSERT_OK_AND_ASSIGN(
+      changed,
+      RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), m.get()));
+  changed = false;
+
+  DebugOptions opts;
+  AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, opts};
+
+  TF_ASSERT_OK_AND_ASSIGN(changed,
+                          RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
+  ASSERT_TRUE(changed);
+
+  AutotuneResults results;
+  TF_ASSERT_OK(AutotunerUtil::SerializeAutotuneResults(&results));
+  ASSERT_EQ(results.results_size(), 1);
+  auto& result = *results.mutable_results(0)->mutable_result();
+  int64_t old_algo_id = result.algorithm().algo_id();
+  int64_t new_algo_id = old_algo_id + 1;
+  result.mutable_gemm()->set_algorithm(new_algo_id);
+
+  AutotunerUtil::ClearAutotuneResults();
+  TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results));
+
+  auto module_cfg = GetModuleConfigForTest();
+  // Now send the same module through GemmAlgorithmPicker again.  The dot should
+  // have the new algorithm.
+  TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg));
+  changed = false;
+
+  DevicelessConfig deviceless_config{gpu_device_desc()};
+  AutotuneConfig deviceless_cfg{deviceless_config, opts};
+  TF_ASSERT_OK_AND_ASSIGN(changed,
+                          RunHloPass(GemmRewriter(gpu_comp(),
+                                                  /*toolkit_version=*/12040),
+                                     m.get()));
+  changed = false;
+  TF_ASSERT_OK_AND_ASSIGN(
+      changed, RunHloPass(GemmAlgorithmPicker(deviceless_cfg), m.get()));
+  ASSERT_TRUE(changed);
+
+  SCOPED_TRACE(m->ToString());
+  HloInstruction* dot;
+
+  ASSERT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0)));
+
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          dot->backend_config<GpuBackendConfig>());
+  const GemmBackendConfig& config = gpu_config.gemm_backend_config();
+
+  EXPECT_EQ(config.selected_algorithm(), new_algo_id);
+}
+
+INSTANTIATE_TEST_SUITE_P(GemmAlgorithmPickerTestSuite, GemmAlgorithmPickerTest,
+                         ::testing::Bool());
+
+}  // namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
new file mode 100644
index 0000000..e480f8b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -0,0 +1,1294 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h"
+
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <cstdint>
+#include <iterator>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "third_party/gpus/cuda/include/cublas_v2.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/pjrt/distributed/key_value_store_interface.h"
+#include "xla/primitive_util.h"
+#include "xla/service/algorithm_util.h"
+#include "xla/service/dump.h"
+#include "xla/service/executable.h"
+#include "xla/service/float_normalization.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/buffer_comparator.h"
+#include "xla/service/gpu/gpu_float_support.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/split_k_gemm_rewriter.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h"
+#include "xla/service/gpu/transforms/fusion_wrapper.h"
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+#include "xla/service/gpu/transforms/instruction_fusion.h"
+#include "xla/service/gpu/transforms/priority_fusion.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/shaped_buffer.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/gpu/redzone_allocator.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/stream_executor/stream_executor_memory_allocator.h"
+#include "xla/tools/hlo_decomposer.h"
+#include "xla/tsl/util/proto/proto_utils.h"
+#include "xla/util.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/lib/core/bits.h"
+#include "tsl/platform/blocking_counter.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/path.h"
+#include "tsl/platform/protobuf.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/threadpool.h"
+#include "tsl/profiler/lib/scoped_annotation.h"
+
+// Log levels used in this file:
+// VLOG(1): Overview
+// VLOG(2): Autotuning progress
+// VLOG(3): Autotuning progress - more frequent
+// VLOG(4): Print all fusions
+// VLOG(5): Profiling information for every tiling
+// VLOG(10): Print fusion computations and each configuration
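+//
+// (For example, the per-tiling profiling information can typically be enabled
+// by setting the environment variable TF_CPP_VMODULE=gemm_fusion_autotuner=5;
+// this is only an illustration of the logging setup, not part of this pass's
+// interface.)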
+
+// TODO(b/317016172): Update usages of TritonGemmConfig to use newly exposed
+// parameters.
+
+namespace xla {
+namespace gpu {
+
+using Config = GemmFusionAutotunerImpl::Config;
+using TilingConfigs = GemmFusionAutotunerImpl::TilingConfigs;
+using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput;
+
+namespace {
+
+// Minimum tile size.
+constexpr int kMinTileSize = 16;
+
+// Default tiling when autotuning is disabled.
+constexpr TritonGemmConfig kDefaultGemmTiling = {32, 32, 32, 1, 1, 4};
+
+// Split-K is enabled when the estimated number of waves is lower than this
+// limit.
+constexpr int kMaxWavesForSplitK = 5;
+
+// Search space for exhaustive matmul autotuning.
+constexpr std::array<int, 6> kBlockSizes = {16, 32, 64, 128, 256, 512};
+constexpr std::array<int, 4> kNumStages = {1, 2, 3, 4};
+constexpr std::array<int, 4> kNumWarps = {2, 4, 8, 16};
+constexpr std::array<int, 5> kSplitK = {1, 2, 4, 8, 16};
+constexpr std::array<int, 5> kNumCtas = {1, 2, 4, 8, 16};
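+// If block_m, block_n and block_k were each drawn independently from
+// kBlockSizes, the full cross product of these axes would be
+// 6 * 6 * 6 * 4 * 4 * 5 * 5 = 86,400 candidate tilings; the set that is
+// actually generated may be much smaller after pruning (e.g. by the per-dot
+// tile-size limits computed in GetLimits below).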
+
+using AutoTuneCacheKeyCount = absl::flat_hash_map<AutotuneCacheKey, uint64_t>;
+
+class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor {
+ public:
+  explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config)
+      : config_(config) {}
+
+  absl::Status HandleFusion(HloInstruction* hlo) override {
+    TF_ASSIGN_OR_RETURN(auto gpu_config,
+                        hlo->backend_config<GpuBackendConfig>());
+    FusionBackendConfig& backend_config =
+        *gpu_config.mutable_fusion_backend_config();
+    if (backend_config.kind() != kTritonGemmFusionKind &&
+        backend_config.kind() != kCuDnnFusionKind) {
+      return absl::OkStatus();
+    }
+
+    VLOG(4) << "Processing " << hlo->ToString();
+    if (!backend_config.has_triton_gemm_config() &&
+        !backend_config.has_cudnn_fusion_config()) {
+      TF_ASSIGN_OR_RETURN(
+          AutotuneResult autotune_result,
+          AutotunerUtil::Autotune(
+              hlo, config_, [&]() -> absl::StatusOr<AutotuneResult> {
+                if (config_.IsDeviceless()) {
+                  return absl::InternalError(absl::StrCat(
+                      "Expect autotune result cache hit for deviceless "
+                      "compilation (HLO: ",
+                      hlo->ToString(), ")"));
+                }
+                return absl::InternalError("Expect autotune result cache hit.");
+              }));
+      VLOG(4) << "Result: " << autotune_result.ShortDebugString();
+
+      if (autotune_result.has_triton()) {
+        *backend_config.mutable_triton_gemm_config() = autotune_result.triton();
+        TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config));
+      } else if (autotune_result.has_gemm()) {
+        // Falling back to cuBLAS: Converting the fusion to a Call, so that it
+        // can be inlined back again.
+        HloComputation* const computation = hlo->parent();
+        HloInstruction* const call = computation->AddInstruction(
+            HloInstruction::CreateCall(hlo->shape(), hlo->operands(),
+                                       hlo->fused_instructions_computation()));
+        TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call));
+        hlo = call;
+      } else {
+        CHECK(autotune_result.has_algorithm());
+        backend_config.set_kind(std::string(kCuDnnFusionKind));
+        backend_config.mutable_cudnn_fusion_config()->set_plan_id(
+            autotune_result.algorithm().algo_id());
+        TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config));
+      }
+    }
+
+    if (backend_config.has_triton_gemm_config()) {
+      TF_ASSIGN_OR_RETURN(
+          const TritonGemmConfig config,
+          TritonGemmConfig::FromProto(backend_config.triton_gemm_config()));
+      if (config.split_k > 1) {
+        TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config));
+      }
+    }
+
+    MarkAsChanged();
+    return absl::OkStatus();
+  }
+
+ private:
+  AutotuneConfig config_;
+};
+
+class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault {
+ public:
+  explicit GemmConfigSetCollector(GemmFusionAutotunerImpl* impl)
+      : impl_(impl) {}
+
+  // Find configurations to tune.
+  absl::StatusOr<TilingConfigs> CollectGemmConfigSets(
+      const HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads = {}) {
+    error_out_on_cache_miss_ =
+        module->config()
+            .debug_options()
+            .xla_gpu_require_complete_aot_autotune_results();
+    gemm_config_sets_.clear();
+    for (HloComputation* computation :
+         module->MakeNonfusionComputations(execution_threads)) {
+      TF_RETURN_IF_ERROR(computation->Accept(this));
+    }
+    return std::move(gemm_config_sets_);
+  }
+
+  AutoTuneCacheKeyCount GetFusionsCount() {
+    return std::move(fusion_count_map_);
+  }
+
+  absl::Status HandleFusion(const HloInstruction* hlo) override {
+    const HloFusionInstruction* fusion = Cast<HloFusionInstruction>(hlo);
+
+    TF_ASSIGN_OR_RETURN(auto gpu_config,
+                        hlo->backend_config<GpuBackendConfig>());
+    const FusionBackendConfig& backend_config =
+        gpu_config.fusion_backend_config();
+
+    AutotuneCacheKey key = AutotunerUtil::GetKey(hlo, impl_->GetConfig());
+
+    auto [iterator, inserted] = fusion_count_map_.insert({key, 1});
+    if (!inserted) {
+      ++(iterator->second);
+    }
+
+    TF_ASSIGN_OR_RETURN(bool is_in_cache,
+                        AutotunerUtil::IsInCache(key, impl_->GetConfig()));
+    if (is_in_cache || handled_fusions_.contains(key)) {
+      return absl::OkStatus();
+    }
+
+    bool missing_config = (backend_config.kind() == kTritonGemmFusionKind &&
+                           !backend_config.has_triton_gemm_config()) ||
+                          (backend_config.kind() == kCuDnnFusionKind &&
+                           !backend_config.has_cudnn_fusion_config());
+    if (missing_config) {
+      if (error_out_on_cache_miss_) {
+        return absl::NotFoundError(absl::StrCat(
+            "Complete autotuning results are required, but no cache result "
+            "found for key: ",
+            key.ToString()));
+      }
+
+      TF_ASSIGN_OR_RETURN(std::vector<Config> configs,
+                          impl_->GenerateConfigs(*fusion));
+      gemm_config_sets_.push_back({fusion, std::move(configs)});
+    }
+
+    handled_fusions_.insert(key);
+    return absl::OkStatus();
+  }
+
+  absl::Status DefaultAction(const HloInstruction* hlo) override {
+    return absl::OkStatus();
+  }
+
+ private:
+  bool error_out_on_cache_miss_;
+  GemmFusionAutotunerImpl* impl_;
+  TilingConfigs gemm_config_sets_;
+  AutoTuneCacheKeyCount fusion_count_map_;
+  absl::flat_hash_set<AutotuneCacheKey> handled_fusions_;
+};
+
+struct TileSizeLimit {
+  int block_m = 0;
+  int block_n = 0;
+  int block_k = 0;
+};
+
+absl::StatusOr<TileSizeLimit> GetLimits(const HloDotInstruction& dot) {
+  TF_ASSIGN_OR_RETURN(int64_t non_contracting_index_lhs,
+                      NonContractingDimensionIndex(dot, /*operand_number=*/0));
+  TF_ASSIGN_OR_RETURN(int64_t non_contracting_index_rhs,
+                      NonContractingDimensionIndex(dot, /*operand_number=*/1));
+  TF_ASSIGN_OR_RETURN(int64_t contracting_index,
+                      ContractingDimensionIndex(dot, /*operand_number=*/1));
+  // This is not a sharp upper limit; the actual m value can be much smaller,
+  // depending on how much of the m dimension is physically contiguous.
+  // TODO(tdanyluk): Get the exact m value by running a TritonFusionAnalysis.
+  const int max_m = tsl::NextPowerOfTwoS64(
+      dot.operand(0)->shape().dimensions(non_contracting_index_lhs));
+  // Theoretically the same holds for n as for m, but that is not possible in
+  // practice with the current implementation.
+  const int max_n = tsl::NextPowerOfTwoS64(
+      dot.operand(1)->shape().dimensions(non_contracting_index_rhs));
+  // This is before doing the split-k transform.
+  const int max_k = tsl::NextPowerOfTwoS64(
+      dot.operand(1)->shape().dimensions(contracting_index));
+
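+  // For example, a dimension of size 100 is rounded up to a limit of 128,
+  // while a dimension of size 5 is rounded up to 8 and then clamped to
+  // kMinTileSize (16) below. (Illustrative values only.)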
+  return TileSizeLimit{
+      /*block_m=*/std::max(max_m, kMinTileSize),
+      /*block_n=*/std::max(max_n, kMinTileSize),
+      /*block_k=*/std::max(max_k, kMinTileSize),
+  };
+}
+
+int GetLogEveryN() { return VLOG_IS_ON(3) ? 100 : 1000; }
+
+absl::StatusOr<std::unique_ptr<HloModule>> TritonGemmAutotuneExtractor(
+    const TritonGemmConfig& config,
+    const se::DeviceDescription& gpu_device_info,
+    const HloFusionInstruction* fusion, DebugOptions debug_opts,
+    bool allow_filtering_kernels_spilling_registers) {
+  std::unique_ptr<HloModule> new_module =
+      ExtractInstructionIntoNewModule(*fusion);
+  if (!allow_filtering_kernels_spilling_registers) {
+    debug_opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(
+        false);
+  }
+  new_module->mutable_config().set_debug_options(debug_opts);
+
+  HloComputation* entry_computation = new_module->entry_computation();
+  HloInstruction* cloned_dot_fusion = entry_computation->root_instruction();
+
+  TF_ASSIGN_OR_RETURN(auto gpu_config,
+                      cloned_dot_fusion->backend_config<GpuBackendConfig>());
+  FusionBackendConfig& backend_config =
+      *gpu_config.mutable_fusion_backend_config();
+
+  *backend_config.mutable_triton_gemm_config() = config.ToProto();
+  TF_RETURN_IF_ERROR(cloned_dot_fusion->set_backend_config(gpu_config));
+
+  if (config.split_k > 1) {
+    TF_RETURN_IF_ERROR(MakeDotSplitKBatch(cloned_dot_fusion, config));
+    for (PrimitiveType type :
+         {BF16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
+      GpuFloatSupport float_support(gpu_device_info.cuda_compute_capability(),
+                                    type);
+      FloatNormalization float_normalization(&float_support);
+      TF_RETURN_IF_ERROR(float_normalization.Run(new_module.get()).status());
+    }
+
+    auto shape_size_function = [&](const Shape& shape) {
+      // The real pointer size is set in GpuCompiler. In HloCostAnalysis, the
+      // pointer size is used only to determine the size of tuple types. We
+      // shouldn't have any tuples in the autotuned module, so it's safe to use
+      // a constant here, instead of piping the real value.
+      constexpr int64_t kPointerSize = 8;
+      return ShapeUtil::ByteSizeOf(shape, kPointerSize);
+    };
+    PriorityFusion priority_fusion(
+        /*thread_pool=*/nullptr, gpu_device_info,
+        GpuHloCostAnalysis::Options{/*shape_size=*/shape_size_function,
+                                    /*per_second_rates=*/{},
+                                    /*count_multiple_input_accesses=*/true});
+    TF_RETURN_IF_ERROR(priority_fusion.Run(new_module.get()).status());
+
+    // If the priority fusion pass above skipped some instructions, turn them
+    // into fusions.
+    FusionWrapper fusion_wrapper;
+    TF_RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status());
+  }
+  return new_module;
+}
+
+absl::StatusOr<std::unique_ptr<HloModule>> CublasGemmAutotuneExtractor(
+    const AutotuneConfig& config, const int32_t toolkit_version,
+    const HloFusionInstruction* fusion, const DebugOptions& debug_opts) {
+  const HloComputation* fusion_computation =
+      fusion->called_computations().at(0);
+  std::unique_ptr<HloModule> new_module =
+      ExtractComputationIntoNewModule(*fusion_computation);
+  new_module->mutable_config().set_debug_options(debug_opts);
+
+  auto* dot = hlo_query::GetFirstInstructionWithOpcode(
+      *new_module->entry_computation(), HloOpcode::kDot);
+  // For this check, substitute algorithms that are not supported by cuBLAS
+  // (they don't use cuBLAS in the end anyway). This assumes that the
+  // substitute algorithm produces results that are close enough for the check
+  // in this file.
+  if (dot->precision_config().algorithm() ==
+          PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
+      dot->precision_config().algorithm() ==
+          PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6) {
+    dot->mutable_precision_config()->set_algorithm(
+        PrecisionConfig::ALG_DOT_F32_F32_F32);
+  }
+
+  for (GemmRewriterOptions::DType dtype :
+       {GemmRewriterOptions::DType::kFp8Only,
+        GemmRewriterOptions::DType::kNonFp8Only}) {
+    GemmRewriter rewriter(config.GetGpuComputeCapability(), toolkit_version,
+                          GemmRewriterOptions{dtype});
+    GpuInstructionFusion fusion_pass(
+        /*may_duplicate=*/false, config.GetExecutor()->GetDeviceDescription());
+    TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status());
+    TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status());
+  }
+  // TODO(tdanyluk): Consider running GemmAlgorithmPicker here for better cuBLAS
+  // performance. It is probably not needed on Ampere and later because cuBLAS
+  // ignores the algorithm parameter for those targets. If we run
+  // GemmAlgorithmPicker, we probably should not run this in parallel with other
+  // compilations.
+  return new_module;
+}
+
+absl::StatusOr<std::unique_ptr<HloModule>> FusionExtractor(
+    const HloFusionInstruction& fusion, const DebugOptions& debug_opts) {
+  std::unique_ptr<HloModule> module = ExtractInstructionIntoNewModule(fusion);
+  module->mutable_config().set_debug_options(debug_opts);
+  return module;
+}
+
+absl::StatusOr<std::unique_ptr<HloModule>> CuDnnFusionExtractor(
+    const HloFusionInstruction& fusion, const DebugOptions& debug_opts,
+    const int plan_id) {
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+                      FusionExtractor(fusion, debug_opts));
+
+  GpuBackendConfig gpu_config;
+  FusionBackendConfig& backend_config =
+      *gpu_config.mutable_fusion_backend_config();
+  backend_config.set_kind(std::string(kCuDnnFusionKind));
+  // Given a plan ID, the autotuner compiles just that one plan.
+  backend_config.mutable_cudnn_fusion_config()->set_plan_id(plan_id);
+  TF_RETURN_IF_ERROR(
+      module->entry_computation()->root_instruction()->set_backend_config(
+          gpu_config));
+  return module;
+}
+
+bool IsFusionKind(const HloInstruction& hlo, absl::string_view kind) {
+  auto gpu_config = hlo.backend_config<GpuBackendConfig>();
+  if (!gpu_config.ok()) {
+    return false;
+  }
+  return gpu_config->fusion_backend_config().kind() == kind;
+}
+
+int GetCuDnnPlanCount(const HloInstruction& hlo,
+                      const AutotuneConfig& autotune_config) {
+  if (auto gpu_config = hlo.backend_config<GpuBackendConfig>();
+      !gpu_config.ok() ||
+      gpu_config->fusion_backend_config().has_cudnn_fusion_config()) {
+    return {};
+  }
+  return CuDnnFusionCompiler::GetAvailablePlanCount(
+      *autotune_config.GetExecutor(), *DynCast<HloFusionInstruction>(&hlo));
+}
+
+AutotuneResult FromConfig(const Config& config) {
+  AutotuneResult res;
+  if (std::holds_alternative<GemmFusionAutotunerImpl::CuBlasConfig>(config)) {
+    res.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT);
+  } else if (std::holds_alternative<GemmFusionAutotunerImpl::CuDnnConfig>(
+                 config)) {
+    res.mutable_algorithm()->set_algo_id(
+        std::get<GemmFusionAutotunerImpl::CuDnnConfig>(config).plan_id);
+  } else if (std::holds_alternative<TritonGemmConfig>(config)) {
+    *res.mutable_triton() = std::get<TritonGemmConfig>(config).ToProto();
+  } else {
+    LOG(FATAL) << "Unsupported config type: " << config.index();
+  }
+  return res;
+}
+
+absl::Status DumpOriginalFusion(AutotunerCompileUtil& util,
+                                const HloFusionInstruction& fusion,
+                                int fusion_id) {
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+                      util.ExtractModule([&](const DebugOptions& debug_opts) {
+                        return FusionExtractor(fusion, debug_opts);
+                      }));
+  module->set_name(std::string(fusion.name()));
+  // Using the original module for its debug info and name in the first
+  // parameter. It's better to include the name of both the original module
+  // and the extracted module, to avoid name clashes.
+  std::string rendered_graph_name =
+      absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(), ".dot");
+  std::string rendered_graph = RenderGraph(rendered_graph_name, *module,
+                                           RenderedGraphFormat::kDot, true);
+  DumpToFileInDir(
+      /*module=*/*fusion.GetModule(),
+      /*file_prefix=*/"",
+      /*file_suffix=*/rendered_graph_name,
+      /*contents=*/rendered_graph);
+  DumpToFileInDirOrStdout(
+      /*module=*/*fusion.GetModule(),
+      /*file_prefix=*/"",
+      /*file_suffix=*/
+      absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(), ".txt"),
+      /*contents=*/module->ToString());
+  return absl::OkStatus();
+}
+
+absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config,
+                                 const int32_t toolkit_version,
+                                 AutotunerCompileUtil& util,
+                                 const AutotuneResult result,
+                                 const HloFusionInstruction* fusion,
+                                 int fusion_id) {
+  TritonGemmConfig triton_gemm_config;
+  if (result.has_triton()) {
+    TF_ASSIGN_OR_RETURN(triton_gemm_config,
+                        TritonGemmConfig::FromProto(result.triton()));
+  }
+  const se::DeviceDescription& device_desc =
+      autotune_config.GetExecutor()->GetDeviceDescription();
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModule> module,
+      util.ExtractModule([&](const DebugOptions& debug_opts) {
+        if (result.has_algorithm()) {
+          return CuDnnFusionExtractor(*fusion, debug_opts,
+                                      result.algorithm().algo_id());
+        } else if (result.has_triton()) {
+          return TritonGemmAutotuneExtractor(
+              triton_gemm_config, device_desc, fusion, debug_opts,
+              /*allow_filtering_kernels_spilling_registers=*/true);
+        } else if (result.has_gemm()) {
+          return CublasGemmAutotuneExtractor(autotune_config, toolkit_version,
+                                             fusion, debug_opts);
+        } else {
+          LOG(FATAL) << "Unknown result type: " << result.DebugString();
+        }
+      }));
+  module->set_name(std::string(fusion->name()));
+  // Pass the original module as the first parameter so that its debug info
+  // and name are used. Including the names of both the original and the
+  // extracted module avoids name clashes.
+  DumpToFileInDirOrStdout(
+      /*module=*/*fusion->GetModule(),
+      /*file_prefix=*/"",
+      /*file_suffix=*/
+      absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(),
+                   ".optimized.txt"),
+      /*contents=*/module->ToString());
+  return absl::OkStatus();
+}
+
+std::string Serialize(const Config& config) {
+  if (auto triton_config = std::get_if<TritonGemmConfig>(&config)) {
+    tsl::protobuf::TextFormat::Printer printer;
+    printer.SetSingleLineMode(true);
+    std::string result;
+    printer.PrintToString(triton_config->ToProto(), &result);
+    return result;
+  }
+  return GemmFusionAutotunerImpl::ToString(config);
+}
+
+}  // anonymous namespace
+
+// Methods required for sorting the configs.
+bool GemmFusionAutotunerImpl::CuBlasConfig::operator<(
+    const CuBlasConfig& other) const {
+  return false;
+}
+bool GemmFusionAutotunerImpl::CuDnnConfig::operator<(
+    const CuDnnConfig& other) const {
+  return plan_id < other.plan_id;
+}
+
+bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const {
+  return debug_options_.xla_gpu_autotune_level() > 0 &&
+         !debug_options_.xla_gpu_deterministic_ops();
+}
+
+/*static*/ std::string GemmFusionAutotunerImpl::ToString(const Config& config) {
+  if (std::holds_alternative<TritonGemmConfig>(config)) {
+    return std::get<TritonGemmConfig>(config).ToString();
+  } else if (std::holds_alternative<CuDnnConfig>(config)) {
+    return absl::StrFormat("cuDNN plan %d",
+                           std::get<CuDnnConfig>(config).plan_id);
+  } else if (std::holds_alternative<CuBlasConfig>(config)) {
+    return "reference (cublas)";
+  } else {
+    LOG(FATAL) << "Unsupported config type: " << config.index();
+  }
+}
+
+absl::StatusOr<std::vector<Config>> GemmFusionAutotunerImpl::GenerateConfigs(
+    const HloFusionInstruction& fusion) {
+  const HloDotInstruction* dot =
+      Cast<HloDotInstruction>(hlo_query::GetFirstInstructionWithOpcode(
+          *fusion.called_computations().at(0), HloOpcode::kDot));
+
+  // Add cuBLAS reference config, if available.
+  std::vector<Config> configs;
+  if (algorithm_util::IsSupportedByCublasOrCublasLt(
+          dot->precision_config().algorithm()) &&
+      !dot->sparse_operands() && IsAutotuningEnabled()) {
+    configs.push_back(CuBlasConfig{});
+  }
+
+  // Add cuDNN plans, if available.
+  bool is_hopper =
+      !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
+  bool is_cudnn_enabled =
+      debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0 && is_hopper &&
+      GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9;
+  if ((IsFusionKind(fusion, kCuDnnFusionKind) && IsAutotuningEnabled()) ||
+      (IsFusionKind(fusion, kTritonGemmFusionKind) && is_cudnn_enabled &&
+       algorithm_util::IsSupportedByCudnn(
+           dot->precision_config().algorithm()) &&
+       !dot->sparse_operands() && IsAutotuningEnabled())) {
+    const int plan_count = GetCuDnnPlanCount(fusion, config_);
+    for (int plan_id = 0; plan_id < plan_count; ++plan_id) {
+      configs.push_back(CuDnnConfig{plan_id});
+    }
+  }
+  if (IsFusionKind(fusion, kCuDnnFusionKind)) {
+    if (!IsAutotuningEnabled()) {
+      configs.push_back(CuDnnConfig{-1});
+    }
+    return configs;
+  }
+
+  // Add triton configs.
+  TF_ASSIGN_OR_RETURN(std::vector<TritonGemmConfig> triton_configs,
+                      GenerateTritonConfigs(*dot));
+  for (TritonGemmConfig& config : triton_configs) {
+    configs.push_back(std::move(config));
+  }
+  return configs;
+}
+
+absl::StatusOr<std::vector<TritonGemmConfig>>
+GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
+  // Retrieve the minimum bit-width participating in the dot. It is used to
+  // restrict the allowed values for tile_k, which avoids autotuning
+  // configurations that Triton does not support.
+  std::vector<const HloInstruction*> converts =
+      HloBfsFindAll({&dot}, [&](const HloInstruction* node) {
+        return node->opcode() == HloOpcode::kConvert;
+      });
+  int minBitWidth = primitive_util::BitWidth(dot.shape().element_type());
+  for (auto convert : converts) {
+    auto in_type = convert->operand(0)->shape().element_type();
+    auto out_type = convert->shape().element_type();
+    minBitWidth = std::min({minBitWidth, primitive_util::BitWidth(in_type),
+                            primitive_util::BitWidth(out_type)});
+  }
+
+  std::vector<TritonGemmConfig> result_configs;
+  TF_ASSIGN_OR_RETURN(TileSizeLimit limits, GetLimits(dot));
+
+  // Generate the list of configurations (once).
+  if (triton_configs_.empty()) {
+    triton_configs_ = !IsAutotuningEnabled()
+                          ? std::vector(1, kDefaultGemmTiling)
+                      : debug_options_.xla_gpu_exhaustive_tiling_search()
+                          ? GetExhaustiveTritonConfigs()
+                          : GetDefaultTritonConfigs();
+  }
+
+  // Avoid autotuning tiny fusions.
+  constexpr int kMinGemmElements = 32 * 32;
+  bool small_dot =
+      ShapeUtil::ElementsIn(dot.operand(0)->shape()) <= kMinGemmElements &&
+      ShapeUtil::ElementsIn(dot.operand(1)->shape()) <= kMinGemmElements;
+  std::vector<TritonGemmConfig> triton_configs =
+      small_dot ? std::vector(1, kDefaultGemmTiling) : triton_configs_;
+
+  // Split-K optimization enables more even utilization of a GPU in cases
+  // where tiling just the non-contracting dimensions of a GEMM does not create
+  // a sufficient number of thread block programs to occupy all available cores.
+  // Around 5 full waves completely avoid the need for split-K.
+  // n_tiles = split_k * (M * N) / (block_m * block_n)
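+  // Worked example with illustrative numbers (the deviceless default of 100
+  // cores and an assumed kMaxWavesForSplitK of 5, i.e.
+  // kSufficientNumberOfTiles = 500): a 1024x1024 output tiled with
+  // block_m = block_n = 128 yields only 64 tiles, so split_k values up to
+  // 1 << Log2Floor(500 * 128 * 128 / (1024 * 1024)) = 4 help fill the GPU.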
+  const int kCoreCount =
+      !config_.IsDeviceless()
+          ? config_.GetExecutor()->GetDeviceDescription().core_count()
+          : 100;  // some sensible default
+  const int64_t kSufficientNumberOfTiles = kMaxWavesForSplitK * kCoreCount;
+  const int64_t result_size = ShapeUtil::ElementsIn(dot.shape());
+
+  // Triton configurations are adjusted and deduplicated.
+  absl::flat_hash_set<TritonGemmConfig> added;
+  bool is_hopper =
+      !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
+  for (TritonGemmConfig& config : triton_configs) {
+    config.block_m = std::min(config.block_m, limits.block_m);
+    config.block_n = std::min(config.block_n, limits.block_n);
+    config.block_k = std::min(config.block_k, limits.block_k);
+    int max_split_k = 1;
+    if (debug_options_.xla_gpu_enable_split_k_autotuning()) {
+      int64_t ratio = kSufficientNumberOfTiles * config.block_m *
+                      config.block_n / result_size;
+      max_split_k = 1 << std::max<int>(tsl::Log2Floor64(ratio), 0);
+    }
+    config.split_k = std::min(config.split_k, max_split_k);
+
+    // TODO(b/337839570): Triton currently has a limitation where it crashes
+    // on small block_k values depending on the bit-width of the inputs to the
+    // dot. The logic below accounts for this limitation.
+    constexpr int kLdmatrixGranularity = 256;
+    config.block_k =
+        std::max(config.block_k, kLdmatrixGranularity / minBitWidth);
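+    // For example, with 8-bit inputs (minBitWidth = 8) this raises block_k to
+    // at least 256 / 8 = 32.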
+
+    // Sparse meta should have at least one element per thread.
+    // Note: only 2:4 structured sparsity is currently supported.
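+    // For example, block_m = 64 and block_k = 64 give 64 * 64 / 16 = 256
+    // metadata elements, which (with a warp size of 32) supports at most
+    // 256 / 32 = 8 warps.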
+    if (dot.sparse_operands()) {
+      if (is_hopper) {
+        config.block_m = std::max(config.block_m, 64);
+        config.num_warps = std::max(config.num_warps, 4);
+      }
+      config.block_k = std::max(
+          config.block_k,
+          2 * std::max(kMinTileSize, kLdmatrixGranularity / minBitWidth));
+      int meta_elements = config.block_m * config.block_k / 16;
+      config.num_warps =
+          std::min<int>(config.num_warps, meta_elements / WarpSize());
+    }
+
+    if (added.insert(config).second) {
+      result_configs.push_back(config);
+    }
+  }
+  return result_configs;
+}
+
+absl::StatusOr<absl::flat_hash_map<
+    const HloFusionInstruction*,
+    std::vector<GemmFusionAutotunerImpl::ExecutableCandidate>>>
+GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util,
+                                    const TilingConfigs& task) {
+  tsl::profiler::ScopedAnnotation annotation("XlaAutotunerCompilation");
+  absl::Mutex results_mu;
+  absl::flat_hash_map<const HloFusionInstruction*,
+                      std::vector<ExecutableCandidate>>
+      results;
+  if (task.empty()) {
+    return results;
+  }
+
+  const int log_every_n = GetLogEveryN();
+  int64_t config_count = 0;
+  for (const auto& [unused, configs] : task) {
+    config_count += configs.size();
+  }
+
+  std::atomic<int> done_count = 0;
+  std::atomic<int> good_count = 0;
+  auto log = [&](bool success) {
+    const int done_so_far = done_count.fetch_add(1) + 1;
+    const int good_so_far =
+        success ? good_count.fetch_add(1) + 1 : good_count.load();
+    if (done_so_far % log_every_n == 0) {
+      VLOG(2) << "Compiled " << done_so_far << " of " << config_count
+              << " configs (successful: " << good_so_far << ")";
+    }
+  };
+
+  auto compile = [&](const HloFusionInstruction* fusion, const Config& config,
+                     bool allow_filtering_kernels_spilling_registers)
+      -> absl::StatusOr<bool> {
+    std::unique_ptr<Executable> executable;
+    if (std::holds_alternative<TritonGemmConfig>(config)) {
+      TF_ASSIGN_OR_RETURN(
+          executable, compile_util.Compile([&](const DebugOptions& opts) {
+            return TritonGemmAutotuneExtractor(
+                std::get<TritonGemmConfig>(config),
+                config_.GetExecutor()->GetDeviceDescription(), fusion, opts,
+                allow_filtering_kernels_spilling_registers);
+          }));
+    } else if (std::holds_alternative<CuDnnConfig>(config)) {
+      executable =
+          compile_util
+              .Compile([&](const DebugOptions& opts) {
+                return CuDnnFusionExtractor(
+                    *fusion, opts, std::get<CuDnnConfig>(config).plan_id);
+              })
+              .value_or(nullptr);
+    } else if (std::holds_alternative<CuBlasConfig>(config)) {
+      TF_ASSIGN_OR_RETURN(executable,
+                          compile_util.Compile([&](const DebugOptions& opts) {
+                            return CublasGemmAutotuneExtractor(
+                                config_, toolkit_version_, fusion, opts);
+                          }));
+    } else {
+      LOG(FATAL) << "Unsupported config type: " << config.index();
+    }
+    if (executable != nullptr) {
+      absl::MutexLock lock(&results_mu);
+      results[fusion].push_back({config, std::move(executable)});
+      return true;
+    }
+    return false;
+  };
+
+  // If the thread pool has only one thread, then it is actually slower to
+  // offload the tasks there.
+  if (thread_pool_ && thread_pool_->NumThreads() > 1 &&
+      debug_options_.xla_gpu_force_compilation_parallelism() != 1) {
+    if (task.size() == 1) {
+      absl::string_view fusion_name = task.begin()->first->name();
+      VLOG(1) << "Compiling " << config_count << " configs for " << fusion_name
+              << " on " << thread_pool_->NumThreads() << " threads.";
+    } else {
+      VLOG(1) << "Compiling " << config_count << " configs for " << task.size()
+              << " fusions on " << thread_pool_->NumThreads() << " threads.";
+    }
+
+    tsl::BlockingCounter counter(config_count);
+    for (const auto& key_value : task) {
+      const HloFusionInstruction* fusion = key_value.first;
+      const std::vector<Config>& gemm_config_set = key_value.second;
+
+      VLOG(10) << "Compiling fusion: " << fusion->name();
+      VLOG(10) << "Dumping fusion computation: "
+               << fusion->called_computation()->ToString();
+      for (const Config& config : gemm_config_set) {
+        thread_pool_->Schedule([&, fusion] {
+          VLOG(10) << "Trying configuration forceable through: "
+                      "--xla_gpu_override_gemm_autotuner='"
+                   << Serialize(config) << "'";
+          VLOG(10) << "WARNING: you are running in multithreaded-mode, the "
+                      "last configuration printed out might not be the one "
+                      "causing issues! Use "
+                      "--xla_gpu_force_compilation_parallelism=1 to fix.";
+          absl::StatusOr<bool> has_executable =
+              compile(fusion, config, gemm_config_set.size() > 1);
+          TF_CHECK_OK(has_executable.status())
+              << "Failure occured when compiling fusion " << fusion->name()
+              << " with config '" << ToString(config)
+              << "'\nFused HLO computation:\n"
+              << fusion->fused_instructions_computation()->ToString();
+          log(has_executable.value());
+          counter.DecrementCount();
+        });
+      }
+    }
+    counter.Wait();
+  } else {
+    if (task.size() == 1) {
+      absl::string_view fusion_name = task.begin()->first->name();
+      LOG(WARNING) << "Compiling " << config_count << " configs for "
+                   << fusion_name << " on a single thread.";
+    } else {
+      LOG(WARNING) << "Compiling " << config_count << " configs for "
+                   << task.size() << " fusions on a single thread.";
+    }
+
+    for (const auto& [fusion, gemm_config_set] : task) {
+      VLOG(10) << "Compiling fusion: " << fusion->name();
+      VLOG(10) << "Dumping fusion computation: "
+               << fusion->called_computation()->ToString();
+      for (const Config& config : gemm_config_set) {
+        VLOG(10) << "Trying configuration forceable through: "
+                    "--xla_gpu_override_gemm_autotuner='"
+                 << Serialize(config) << "'";
+        TF_ASSIGN_OR_RETURN(
+            bool has_executable,
+            compile(fusion, config, gemm_config_set.size() > 1));
+        log(has_executable);
+      }
+    }
+  }
+
+  VLOG(1) << "Done compiling (successful: " << good_count.load() << ").";
+  return results;
+}
+
+absl::StatusOr<std::vector<AutotuneResult>> GemmFusionAutotunerImpl::Profile(
+    AutotunerCompileUtil& compile_util, const HloFusionInstruction& fusion,
+    absl::Span<const ExecutableCandidate> candidates) {
+  const HloComputation* fusion_computation = fusion.called_computations().at(0);
+
+  se::StreamExecutor* stream_exec = config_.GetExecutor();
+  if (!stream_exec->SynchronizeAllActivity()) {
+    return Internal("Failed to synchronize GPU for autotuning.");
+  }
+  tsl::profiler::ScopedAnnotation annotation([&] {
+    return absl::StrFormat("XlaAutotunerMeasurement:#hlo_op=%s#",
+                           fusion.name());
+  });
+  se::DeviceMemoryAllocator* allocator = config_.GetAllocator();
+  std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator;
+  if (allocator == nullptr) {
+    owned_allocator =
+        std::make_unique<se::StreamExecutorMemoryAllocator>(stream_exec);
+    allocator = owned_allocator.get();
+  }
+  TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
+
+  const HloInstruction& root = *fusion_computation->root_instruction();
+  BufferComparator comparator(root.shape(),
+                              debug_options_.xla_gpu_autotune_gemm_rtol());
+
+  TF_ASSIGN_OR_RETURN(auto rz_buffers,
+                      RedzoneBuffers::FromInstruction(
+                          *fusion_computation->FusionInstruction(), config_,
+                          debug_options_, RedzoneBuffers::kAllInputs));
+
+  const int log_every_n = GetLogEveryN();
+  std::vector<AutotuneResult> results;
+  std::optional<ScopedShapedBuffer> reference_buffer;
+  for (const ExecutableCandidate& candidate : candidates) {
+    VLOG(5) << "Trying : " << ToString(candidate.config);
+    AutotuneResult res = FromConfig(candidate.config);
+
+    std::optional<ProfilingOutput> profiling_output;
+    if (IsAutotuningEnabled()) {
+      TF_ASSIGN_OR_RETURN(
+          profiling_output,
+          compile_util.ProfileExecutable(candidate.executable.get(), stream,
+                                         rz_buffers.input_buffers(),
+                                         rz_buffers.input_shapes()));
+      if (std::holds_alternative<CuBlasConfig>(candidate.config) &&
+          config_.should_check_correctness()) {
+        reference_buffer = std::move(profiling_output->output);
+      }
+
+      int ran_so_far = results.size() + 1;
+      if (ran_so_far % log_every_n == 0) {
+        VLOG(2) << "Ran " << ran_so_far << " configs of " << candidates.size()
+                << ".";
+      }
+      if (!profiling_output) {
+        VLOG(5) << "Skipping this tiling.";
+        continue;
+      }
+
+      VLOG(5) << "Running the kernel took: " << profiling_output->duration;
+      if (profiling_output->duration >= absl::Seconds(1)) {
+        LOG(WARNING) << "Slow kernel for "
+                     << fusion.called_computations()[0]->ToString()
+                     << " took: " << profiling_output->duration << ". "
+                     << ToString(candidate.config);
+      }
+      *res.mutable_run_time() =
+          tsl::proto_utils::ToDurationProto(profiling_output->duration);
+    }
+
+    // The reference buffer is available when
+    // `config.should_check_correctness()` is set and the reference executable
+    // was compiled.
+    if (reference_buffer.has_value() &&
+        !std::holds_alternative<CuBlasConfig>(candidate.config)) {
+      TF_ASSIGN_OR_RETURN(
+          se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
+          rz_buffers.RedzoneAllocator().CheckRedzones());
+      if (!rz_check_status.ok()) {
+        LOG(ERROR) << "Red zone modified";
+        res.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
+        res.mutable_failure()->set_msg(rz_check_status.RedzoneFailureMsg());
+        CHECK(!config_.should_crash_on_check_failure());
+        continue;
+      }
+
+      TF_ASSIGN_OR_RETURN(
+          bool outputs_match,
+          comparator.CompareEqual(
+              stream, /*current=*/profiling_output->output.root_buffer(),
+              /*expected=*/reference_buffer->root_buffer()));
+      if (!outputs_match) {
+        const char kMessage[] =
+            "Results do not match the reference. This is likely a "
+            "bug/unexpected loss of precision.";
+        LOG(ERROR) << kMessage;
+        CHECK(!config_.should_crash_on_check_failure());
+        // WRONG_RESULT is not taken seriously by PickBestResult(), so
+        // use DISQUALIFIED.
+        res.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED);
+        res.mutable_failure()->set_msg(kMessage);
+      }
+    }
+    results.push_back(std::move(res));
+  }
+  VLOG(2) << "Done running.";
+  return results;
+}
+
+std::vector<TritonGemmConfig>
+GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
+  std::vector<TritonGemmConfig> configs;
+  se::CudaComputeCapability cc = GetComputeCapability();
+  bool tune_ctas =
+      debug_options_.xla_gpu_enable_triton_hopper() && cc.IsAtLeastHopper();
+
+  for (int num_stages : kNumStages) {
+    // Volta doesn't support num_stages > 2.
+    if (!cc.IsAtLeastAmpere() && num_stages > 2) {
+      break;
+    }
+    for (int tile_m : kBlockSizes) {
+      for (int tile_n : kBlockSizes) {
+        for (int tile_k : kBlockSizes) {
+          const int tile_lhs = tile_m * tile_k;
+          const int tile_rhs = tile_k * tile_n;
+          for (int num_warps : kNumWarps) {
+            // Each thread should read at least one input element.
+            if (num_warps * WarpSize() > std::min(tile_lhs, tile_rhs)) {
+              break;
+            }
+            for (int split_k : kSplitK) {
+              // Split-K autotuning may be disabled by a flag.
+              if (!debug_options_.xla_gpu_enable_split_k_autotuning() &&
+                  split_k > 1) {
+                break;
+              }
+              for (int num_ctas : kNumCtas) {
+                // Clusters are only supported on Hopper.
+                // Autotuning this parameter is enabled by a flag.
+                if (!tune_ctas && num_ctas > 1) {
+                  break;
+                }
+                if (num_ctas > num_warps) {
+                  break;
+                }
+                configs.push_back(TritonGemmConfig(tile_m, tile_n, tile_k,
+                                                   split_k, num_stages,
+                                                   num_warps, num_ctas));
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return configs;
+}
+
+std::vector<TritonGemmConfig> GemmFusionAutotunerImpl::GetDefaultTritonConfigs()
+    const {
+  using Config = TritonGemmConfig;
+  std::vector<Config> configs = {
+      Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4),
+      Config(32, 64, 64, 4, 1, 4),  Config(128, 128, 64, 4, 1, 4),
+      Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4),
+      Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4),
+      Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4),
+      Config(64, 32, 64, 1, 2, 8)};
+  if (GetComputeCapability().IsAtLeastAmpere()) {
+    absl::c_copy(
+        std::vector<Config>{
+            Config(128, 256, 32, 1, 3, 8),  Config(256, 128, 32, 1, 3, 8),
+            Config(256, 64, 32, 1, 4, 4),   Config(64, 256, 32, 1, 4, 4),
+            Config(128, 64, 32, 1, 4, 4),   Config(64, 128, 32, 1, 4, 4),
+            Config(256, 128, 128, 1, 3, 8), Config(256, 64, 128, 1, 4, 4),
+            Config(64, 256, 128, 1, 4, 4),  Config(128, 128, 128, 1, 4, 4),
+            Config(128, 64, 64, 1, 4, 4),   Config(64, 128, 64, 1, 4, 4),
+            Config(128, 32, 64, 1, 4, 4),   Config(64, 32, 64, 1, 4, 4),
+            Config(32, 128, 32, 1, 4, 4),   Config(128, 128, 32, 1, 4, 4),
+            Config(16, 16, 256, 1, 3, 4),   Config(128, 128, 64, 2, 1, 8),
+            Config(64, 64, 64, 1, 2, 4),    Config(16, 64, 256, 8, 1, 4),
+            Config(256, 256, 128, 1, 3, 8)},
+        std::back_inserter(configs));
+  }
+  if (GetComputeCapability().IsAtLeastHopper()) {
+    absl::c_copy(
+        std::vector<Config>{
+            Config(16, 32, 32, 8, 1, 2),
+            Config(16, 64, 128, 8, 1, 4),
+            Config(16, 64, 128, 16, 3, 4),
+        },
+        std::back_inserter(configs));
+  }
+  return configs;
+}
+
+absl::Status DumpAutotuningLogs(const DebugOptions& debug_opts,
+                                const AutotuningLogs& autotuning_logs) {
+  if (absl::string_view file_path = debug_opts.xla_gpu_dump_autotune_logs_to();
+      !file_path.empty()) {
+    std::string resolved_path;
+    if (!tsl::io::ResolveTestPrefixes(file_path, resolved_path)) {
+      return FailedPrecondition("File path can not be resolved: %s", file_path);
+    }
+
+    std::string textproto;
+    tsl::protobuf::TextFormat::PrintToString(autotuning_logs, &textproto);
+
+    TF_RETURN_IF_ERROR(
+        tsl::WriteStringToFile(tsl::Env::Default(), resolved_path, textproto));
+    LOG(INFO) << "Autotune logs serialized to file: " << resolved_path;
+  }
+  return absl::OkStatus();
+}
+
+absl::Status GemmFusionAutotunerImpl::Autotune(
+    AutotunerCompileUtil& compile_util, const TilingConfigs& gemm_config_sets,
+    AutoTuneCacheKeyCount fusion_count_map) {
+  TF_ASSIGN_OR_RETURN(auto executable_sets,
+                      CompileAll(compile_util, gemm_config_sets));
+
+  // Sort the candidates to make their execution order well-defined for each
+  // fusion.
+  for (auto& [unused, candidates] : executable_sets) {
+    absl::c_sort(candidates, [](const auto& a, const auto& b) {
+      return a.config < b.config;
+    });
+  }
+
+  AutotuningLogs autotuning_logs;
+  int fusion_id = 0;
+  for (const auto& [fusion, candidates] : executable_sets) {
+    TF_ASSIGN_OR_RETURN(std::vector<AutotuneResult> results,
+                        Profile(compile_util, *fusion, candidates));
+
+    // The reference config (if it exists) will be the first in the results,
+    // due to how the config variants are sorted.
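+    // (std::variant's operator< compares the alternative index first, so
+    // CuBlasConfig, the first alternative, always sorts before the cuDNN and
+    // Triton configs.)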
+    if (!debug_options_.xla_gpu_cublas_fallback() &&
+        results.front().has_gemm()) {
+      results.erase(results.begin());
+    }
+
+    const HloInstruction* root =
+        fusion->called_computations().at(0)->root_instruction();
+    TF_ASSIGN_OR_RETURN(
+        AutotuneResult best,
+        PickBestResult(results, root->ToString(), root->GetModule()->config()));
+    VLOG(2) << "Best time: "
+            << tsl::proto_utils::FromDurationProto(best.run_time());
+
+    if (debug_options_.xla_gpu_dump_autotuned_gemm_fusions()) {
+      TF_RETURN_IF_ERROR(DumpOriginalFusion(compile_util, *fusion, fusion_id));
+      TF_RETURN_IF_ERROR(DumpAutotunedFusion(
+          config_, toolkit_version_, compile_util, best, fusion, fusion_id++));
+    }
+
+    const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_);
+    TF_ASSIGN_OR_RETURN(
+        bool added, AutotunerUtil::AddResult(key, std::move(best), config_));
+    if (!added) {
+      // In the context of a model server, concurrent autotuning is expected,
+      // and inserting identical autotuning keys is accepted.
+      LOG(WARNING) << "AutotunerUtil::AddResult already existed: "
+                   << key.ToString();
+    }
+
+    if (!debug_options_.xla_gpu_dump_autotune_logs_to().empty()) {
+      auto autotuning_log = autotuning_logs.add_logs();
+      autotuning_log->set_fusion_name(std::string(fusion->name()));
+
+      for (const auto& autotune_result : results) {
+        auto log_result = autotuning_log->add_results();
+        log_result->CopyFrom(autotune_result);
+      }
+
+      if (auto fusion_key_count = fusion_count_map.find(key);
+          fusion_key_count != fusion_count_map.end()) {
+        auto fusion_key = fusion_key_count->first;
+        auto fusion_count = fusion_key_count->second;
+        autotuning_log->set_fusion_count(fusion_count);
+      }
+    }
+  }
+
+  TF_RETURN_IF_ERROR(DumpAutotuningLogs(debug_options_, autotuning_logs));
+
+  return absl::OkStatus();
+}
+
+// Trim the set of configs to what one rank has to run.
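+// For example, 10 fusion entries and 4 shards give a bucket size of 3:
+// shard 0 takes entries [0, 3), shard 1 [3, 6), shard 2 [6, 9) and shard 3
+// the remaining [9, 10).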
+static TilingConfigs TrimConfigs(const TilingConfigs& gemm_config_sets,
+                                 const int shard_index, const int shard_count) {
+  const uint64_t bucket_size =
+      (gemm_config_sets.size() + shard_count - 1) / shard_count;
+  const uint64_t start = bucket_size * shard_index;
+  const uint64_t end = std::min(start + bucket_size, gemm_config_sets.size());
+  if (start >= end) {
+    return {};
+  }
+  return TilingConfigs(gemm_config_sets.cbegin() + start,
+                       gemm_config_sets.cbegin() + end);
+}
+
+// Exchange the results with the other ranks.
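+// Each rank publishes its serialized results under a key of the form
+// "gemm_fusion_autotuning_results_<module_id>_<shard_index>" and then reads
+// the corresponding keys of all other ranks.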
+absl::Status ExchangeResults(KeyValueStoreInterface& key_value_store,
+                             const int module_id, const int shard_index,
+                             const int shard_count) {
+  AutotuneResults results;
+  TF_RETURN_IF_ERROR(AutotunerUtil::SerializeAutotuneResults(&results));
+  TF_ASSIGN_OR_RETURN(std::string results_str,
+                      AutotuneResultsToString(results, true));
+  constexpr absl::string_view kKeyPrefix = "gemm_fusion_autotuning_results";
+  TF_RETURN_IF_ERROR(key_value_store.Set(
+      absl::StrFormat("%s_%d_%d", kKeyPrefix, module_id, shard_index),
+      results_str));
+  VLOG(2) << "Rank " << shard_index << ": published results";
+  for (int i = 0; i < shard_count; ++i) {
+    if (i == shard_index) {
+      continue;
+    }
+    VLOG(2) << "Rank " << shard_index << ": waiting for results from rank " << i
+            << " / " << shard_count;
+    TF_ASSIGN_OR_RETURN(
+        std::string autotune_results_str,
+        key_value_store.Get(
+            absl::StrFormat("%s_%d_%d", kKeyPrefix, module_id, i),
+            absl::InfiniteDuration()));
+    TF_RETURN_IF_ERROR(
+        AutotunerUtil::LoadAutotuneResults(autotune_results_str, true));
+  }
+  return absl::OkStatus();
+}
+
+absl::StatusOr<bool> GemmFusionAutotuner::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  XLA_SCOPED_LOGGING_TIMER("GEMM fusion autotuner");
+
+  const DebugOptions& debug_options = module->config().debug_options();
+  GemmFusionAutotunerImpl autotuner(config_, toolkit_version_, debug_options,
+                                    thread_pool_);
+  GemmConfigSetCollector gemm_config_set_collector(&autotuner);
+  TF_ASSIGN_OR_RETURN(TilingConfigs gemm_config_sets,
+                      gemm_config_set_collector.CollectGemmConfigSets(
+                          module, execution_threads));
+  const int total_fusion_count = gemm_config_sets.size();
+
+  AutoTuneCacheKeyCount fusion_count_map =
+      gemm_config_set_collector.GetFusionsCount();
+
+  if (!autotuner.IsAutotuningEnabled()) {
+    // Pick the first option for each gemm instead of autotuning.
+    for (const auto& [fusion, tilings] : gemm_config_sets) {
+      const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_);
+      AutotuneResult res = FromConfig(tilings[0]);
+      *res.mutable_run_time() =
+          tsl::proto_utils::ToDurationProto(absl::ZeroDuration());
+      TF_RETURN_IF_ERROR(AutotunerUtil::AddResult(key, res, config_).status());
+    }
+  } else if (!debug_options.xla_gpu_override_gemm_autotuner().empty()) {
+    // TODO(gflegar): support overriding with non-Triton configs (cuBLAS, cuDNN)
+    AutotuneResult::TritonGemmKey gemm_key;
+    CHECK(tsl::protobuf::TextFormat::ParseFromString(
+        debug_options.xla_gpu_override_gemm_autotuner(), &gemm_key));
+    VLOG(1) << "Overriding GEMM autotuner with the following config: "
+            << gemm_key.DebugString();
+    for (const auto& [fusion, unused] : gemm_config_sets) {
+      const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_);
+      AutotuneResult res;
+      *res.mutable_triton() = gemm_key;
+      *res.mutable_run_time() =
+          tsl::proto_utils::ToDurationProto(absl::ZeroDuration());
+      TF_RETURN_IF_ERROR(AutotunerUtil::AddResult(key, res, config_).status());
+    }
+  } else if (!config_.IsDeviceless()) {
+    TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> opt_compile_util,
+                        AutotunerCompileUtil::Create(config_, debug_options));
+    TF_RET_CHECK(opt_compile_util.has_value());
+    std::string correctness_check_str = config_.should_check_correctness()
+                                            ? "(with correctness check)"
+                                            : "(without correctness check)";
+
+    const bool shard_autotuning = debug_options.xla_gpu_shard_autotuning() &&
+                                  key_value_store_.process_count > 1 &&
+                                  total_fusion_count > 0;
+    if (shard_autotuning) {
+      if (key_value_store_.key_value_store == nullptr) {
+        return absl::FailedPreconditionError(
+            "Sharded autotuning requested but key-value store is missing.");
+      }
+      gemm_config_sets =
+          TrimConfigs(gemm_config_sets, key_value_store_.process_index,
+                      key_value_store_.process_count);
+    }
+
+    VLOG(1) << absl::StrFormat(
+        "Shard %d / %d: autotuning %d / %d fusions for %s %s.",
+        key_value_store_.process_index + 1, key_value_store_.process_count,
+        gemm_config_sets.size(), total_fusion_count, module->name(),
+        correctness_check_str);
+    TF_RETURN_IF_ERROR(autotuner.Autotune(*opt_compile_util, gemm_config_sets,
+                                          std::move(fusion_count_map)));
+    VLOG(1) << "Done autotuning.";
+
+    if (shard_autotuning) {
+      TF_RETURN_IF_ERROR(ExchangeResults(
+          *key_value_store_.key_value_store, module->unique_id(),
+          key_value_store_.process_index, key_value_store_.process_count));
+    }
+  }
+
+  return GemmFusionAutotunerVisitor(config_).RunOnModule(module,
+                                                         execution_threads);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h
new file mode 100644
index 0000000..b12fe00
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h
@@ -0,0 +1,147 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_AUTOTUNING_GEMM_FUSION_AUTOTUNER_H_
+#define XLA_SERVICE_GPU_AUTOTUNING_GEMM_FUSION_AUTOTUNER_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/pjrt/distributed/key_value_store_interface.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/xla.pb.h"
+#include "tsl/platform/threadpool.h"
+
+namespace xla {
+namespace gpu {
+
+// Finds the best tiling configuration for each outlined Triton fusion.
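+//
+// A minimal usage sketch (mirroring the test setup; the surrounding objects
+// such as `autotune_config`, `thread_pool` and `key_value_store` are assumed
+// to exist):
+//
+//   HloPassPipeline pipeline("gemm_rewrite");
+//   pipeline.AddPass<GemmFusionAutotuner>(autotune_config, toolkit_version,
+//                                         &thread_pool, key_value_store);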
+class GemmFusionAutotuner : public HloModulePass {
+ public:
+  explicit GemmFusionAutotuner(const AutotuneConfig& config,
+                               const int32_t toolkit_version,
+                               tsl::thread::ThreadPool* thread_pool,
+                               const MultiProcessKeyValueStore& key_value_store)
+      : config_(config),
+        toolkit_version_(toolkit_version),
+        thread_pool_(thread_pool),
+        key_value_store_(key_value_store) {}
+
+  absl::string_view name() const override { return "triton-autotuner"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  const AutotuneConfig config_;
+  const int32_t toolkit_version_;
+  tsl::thread::ThreadPool* thread_pool_;
+  MultiProcessKeyValueStore key_value_store_;
+};
+
+// Autotuner implementation.
+class GemmFusionAutotunerImpl {
+ public:
+  GemmFusionAutotunerImpl(const AutotuneConfig config,
+                          const int32_t toolkit_version,
+                          const DebugOptions debug_options,
+                          tsl::thread::ThreadPool* thread_pool)
+      : config_(std::move(config)),
+        toolkit_version_(toolkit_version),
+        debug_options_(std::move(debug_options)),
+        thread_pool_(thread_pool) {}
+
+  struct CuBlasConfig {
+    bool operator<(const CuBlasConfig& other) const;
+  };
+  struct CuDnnConfig {
+    int64_t plan_id;
+    bool operator<(const CuDnnConfig& other) const;
+  };
+  using Config = std::variant<CuBlasConfig, CuDnnConfig, TritonGemmConfig>;
+  using TilingConfigs =
+      std::vector<std::pair<const HloFusionInstruction*, std::vector<Config>>>;
+
+  struct ExecutableCandidate {
+    Config config;
+    std::unique_ptr<Executable> executable;
+  };
+
+  // Generate all possible configs for a dot operation.
+  absl::StatusOr<std::vector<Config>> GenerateConfigs(
+      const HloFusionInstruction& fusion);
+  absl::StatusOr<std::vector<TritonGemmConfig>> GenerateTritonConfigs(
+      const HloDotInstruction& dot);
+
+  // Compile all executables for all fusions.
+  absl::StatusOr<absl::flat_hash_map<const HloFusionInstruction*,
+                                     std::vector<ExecutableCandidate>>>
+  CompileAll(AutotunerCompileUtil& compile_util, const TilingConfigs& task);
+
+  // Profile all executables for a fusion.
+  absl::StatusOr<std::vector<AutotuneResult>> Profile(
+      AutotunerCompileUtil& compile_util, const HloFusionInstruction& fusion,
+      absl::Span<const ExecutableCandidate> candidates);
+
+  // Autotune and save the results to the autotuning cache.
+  absl::Status Autotune(
+      AutotunerCompileUtil& compile_util, const TilingConfigs& gemm_config_sets,
+      absl::flat_hash_map<AutotuneCacheKey, uint64_t> fusion_count_map);
+
+  // Helper methods.
+  const AutotuneConfig& GetConfig() const { return config_; }
+  bool IsAutotuningEnabled() const;
+  static std::string ToString(const Config& config);
+
+ private:
+  se::CudaComputeCapability GetComputeCapability() const {
+    return std::get<se::CudaComputeCapability>(
+        config_.GetGpuComputeCapability());
+  }
+
+  std::vector<TritonGemmConfig> GetDefaultTritonConfigs() const;
+  std::vector<TritonGemmConfig> GetExhaustiveTritonConfigs() const;
+
+  const AutotuneConfig config_;
+  const int32_t toolkit_version_;
+  const DebugOptions debug_options_;
+  tsl::thread::ThreadPool* thread_pool_;
+  std::vector<TritonGemmConfig> triton_configs_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_AUTOTUNING_GEMM_FUSION_AUTOTUNER_H_
diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
new file mode 100644
index 0000000..03239a1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
@@ -0,0 +1,999 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "xla/autotuning.pb.h"
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/pjrt/distributed/key_value_store_interface.h"
+#include "xla/service/call_inliner.h"
+#include "xla/service/dump.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/transforms/gemm_fusion.h"
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_description.pb.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/test_utils.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "xla/tools/hlo_decomposer.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/cpu_info.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/path.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/threadpool.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+using HloExtractionTest = HloTestBase;
+
+TEST_F(HloExtractionTest, InstructionExtractionIsCorrect) {
+  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+HloModule module
+
+triton_gemm_dot {
+  p0 = s8[10,10] parameter(0)
+  p1 = f32[10,10] parameter(1)
+  c0 = f32[10,10] convert(p0)
+  ROOT dot.0 = f32[10,10] dot(c0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY entry {
+  p0 = s8[10,10] parameter(0)
+  p1 = f32[10,10] parameter(1)
+  s = f32[10,10] sqrt(p1)
+  d = f32[10,10] fusion(p0, p1),
+    kind=kCustom, calls=triton_gemm_dot
+  ROOT r = f32[10,10] add(d, s)
+})")
+                                                  .value();
+
+  std::unique_ptr<HloModule> extracted_module = ExtractInstructionIntoNewModule(
+      *module->entry_computation()->root_instruction()->operand(0));
+
+  // Destroy the original module to be sure that the extracted one has no
+  // dependency on it.
+  module.release();
+
+  EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+  EXPECT_EQ(extracted_module->entry_computation()->instruction_count(), 3);
+  TF_EXPECT_OK(VerifyHloModule(extracted_module.get(),
+                               /*layout_sensitive=*/true,
+                               /*allow_mixed_precision=*/false));
+}
+
+TEST_F(HloExtractionTest, ComputationExtractionIsCorrect) {
+  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+HloModule module
+
+triton_gemm_dot {
+  p0 = s8[10,10] parameter(0)
+  p1 = f32[10,10] parameter(1)
+  c0 = f32[10,10] convert(p0)
+  ROOT dot.0 = f32[10,10] dot(c0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY entry {
+  p0 = s8[10,10] parameter(0)
+  p1 = f32[10,10] parameter(1)
+  s = f32[10,10] sqrt(p1)
+  d = f32[10,10] fusion(p0, p1),
+    kind=kCustom, calls=triton_gemm_dot
+  ROOT r = f32[10,10] add(d, s)
+})")
+                                                  .value();
+
+  std::unique_ptr<HloModule> extracted_module =
+      ExtractComputationIntoNewModule(*module->entry_computation()
+                                           ->root_instruction()
+                                           ->operand(0)
+                                           ->fused_instructions_computation());
+
+  // Destroy the original module to be sure that the extracted one has no
+  // dependency on it.
+  module.release();
+
+  EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(m::Convert(m::Parameter()), m::Parameter())));
+  EXPECT_EQ(extracted_module->entry_computation()->instruction_count(), 4);
+  TF_EXPECT_OK(VerifyHloModule(extracted_module.get(),
+                               /*layout_sensitive=*/true,
+                               /*allow_mixed_precision=*/false));
+}
+
+class StatelessAutotunerTest : public HloTestBase {
+ public:
+  StatelessAutotunerTest()
+      : HloTestBase(/*verifier_layout_sensitive=*/true,
+                    /*allow_mixed_precision_in_hlo_verifier=*/false) {}
+
+  int32_t GetToolkitVersion() const { return CUDA_VERSION; }
+
+  void SetUp() override {
+    AutotunerUtil::ClearAutotuneResults();
+    HloTestBase::SetUp();
+  }
+
+  void TearDown() override {
+    AutotunerUtil::ClearAutotuneResults();
+    HloTestBase::TearDown();
+  }
+};
+
+class GemmFusionAutotunerTest : public StatelessAutotunerTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options =
+        StatelessAutotunerTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_enable_triton_gemm(true);
+    debug_options.set_xla_gpu_cublas_fallback(false);
+    debug_options.set_xla_gpu_cudnn_gemm_fusion_level(0);
+    return debug_options;
+  }
+
+  se::CudaComputeCapability GetCudaComputeCapability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .cuda_compute_capability();
+  }
+
+  void CheckTritonAutotuning(absl::string_view hlo,
+                             absl::string_view expected) {
+    HloPassPipeline pipeline("gemm_rewrite");
+    pipeline.AddPass<GemmFusion>(backend()
+                                     .default_stream_executor()
+                                     ->GetDeviceDescription()
+                                     .cuda_compute_capability());
+    tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
+                                        tsl::port::MaxParallelism());
+    DebugOptions opts;
+    MultiProcessKeyValueStore key_value_store;
+    pipeline.AddPass<GemmFusionAutotuner>(
+        AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
+                                    backend().memory_allocator()},
+                       opts},
+        GetToolkitVersion(), &thread_pool, key_value_store);
+
+    RunAndFilecheckHloRewrite(
+        hlo, std::move(pipeline), expected, [](const HloModule* m) {
+          VLOG(5) << m->ToString();
+          const HloInstruction* dot_fusion =
+              m->entry_computation()->root_instruction();
+          if (dot_fusion->opcode() == HloOpcode::kReduce) {
+            dot_fusion = dot_fusion->operand(0);
+          }
+          CHECK_EQ(dot_fusion->opcode(), HloOpcode::kFusion);
+          if (!dot_fusion->backend_config<GpuBackendConfig>()
+                   ->fusion_backend_config()
+                   .has_cudnn_fusion_config()) {
+            CHECK_GT(dot_fusion->backend_config<GpuBackendConfig>()
+                         .value()
+                         .fusion_backend_config()
+                         .triton_gemm_config()
+                         .block_m(),
+                     0);
+          }
+        });
+  }
+};
+
+class GemmFusionAutotunerTestWithMorePreciseReduction
+    : public GemmFusionAutotunerTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options =
+        GemmFusionAutotunerTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction(
+        true);
+    return debug_options;
+  }
+};
+
+absl::StatusOr<std::vector<TritonGemmConfig>> GetPossibleMatmulAutotuneConfigs(
+    const HloDotInstruction& dot,
+    const se::CudaComputeCapability& compute_capability,
+    const int32_t toolkit_version, const DebugOptions& debug_options) {
+  se::GpuDeviceInfoProto deviceless_proto;
+  auto ccc = deviceless_proto.mutable_cuda_compute_capability();
+  ccc->set_major(compute_capability.major);
+  ccc->set_minor(compute_capability.minor);
+  DevicelessConfig test_config{se::DeviceDescription{deviceless_proto}};
+  AutotuneConfig autotune_config{test_config, debug_options};
+  GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version,
+                                    debug_options, nullptr);
+  return autotuner.GenerateTritonConfigs(dot);
+}
+
+TEST_F(GemmFusionAutotunerTest, AmpereUsesMoreThanTwoStages) {
+  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f32[1024,1024] parameter(0)
+  p1 = f32[1024,1024] parameter(1)
+  ROOT r = f32[1024,1024] dot(p0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})")
+                                                  .value();
+  const se::CudaComputeCapability compute_capability{
+      se::CudaComputeCapability::AMPERE, /*minor=*/0};
+  TF_ASSERT_OK_AND_ASSIGN(
+      const std::vector<TritonGemmConfig> configs,
+      GetPossibleMatmulAutotuneConfigs(
+          *Cast<HloDotInstruction>(
+              module->entry_computation()->root_instruction()),
+          compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
+  EXPECT_TRUE(std::any_of(
+      configs.begin(), configs.end(),
+      [](const TritonGemmConfig& config) { return config.num_stages > 2; }));
+}
+
+TEST_F(GemmFusionAutotunerTest, SmallOutputCanUseLargeSplitK) {
+  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f32[1024,1024] parameter(0)
+  p1 = f32[1024,1024] parameter(1)
+  ROOT r = f32[1024,1024] dot(p0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})")
+                                                  .value();
+  const se::CudaComputeCapability compute_capability{
+      se::CudaComputeCapability::AMPERE, /*minor=*/0};
+  TF_ASSERT_OK_AND_ASSIGN(
+      const std::vector<TritonGemmConfig> configs,
+      GetPossibleMatmulAutotuneConfigs(
+          *Cast<HloDotInstruction>(
+              module->entry_computation()->root_instruction()),
+          compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
+  EXPECT_TRUE(std::any_of(
+      configs.begin(), configs.end(),
+      [](const TritonGemmConfig& config) { return config.split_k >= 4; }));
+}
+
+TEST_F(GemmFusionAutotunerTest, LargeOutputDoesNotUseLargeSplitK) {
+  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f32[20480,20480] parameter(0)
+  p1 = f32[20480,20480] parameter(1)
+  ROOT r = f32[20480,20480] dot(p0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})")
+                                                  .value();
+  const se::CudaComputeCapability compute_capability{
+      se::CudaComputeCapability::AMPERE, /*minor=*/0};
+  TF_ASSERT_OK_AND_ASSIGN(
+      const std::vector<TritonGemmConfig> configs,
+      GetPossibleMatmulAutotuneConfigs(
+          *Cast<HloDotInstruction>(
+              module->entry_computation()->root_instruction()),
+          compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
+  EXPECT_FALSE(std::any_of(
+      configs.begin(), configs.end(),
+      [](const TritonGemmConfig& config) { return config.split_k > 1; }));
+}
+
+TEST_F(GemmFusionAutotunerTest, Int8FusedGemm) {
+  const std::string hlo = R"(
+HloModule module
+
+ENTRY e {
+  x = s8[128,64] parameter(0)
+  c = f16[128,64] convert(x)
+
+  y = f16[64,6144] parameter(1)
+
+  ROOT out = f16[128,6144] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+  CheckTritonAutotuning(hlo, R"(
+// CHECK: ENTRY
+// CHECK: ROOT
+// CHECK-SAME: kCustom
+// CHECK-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/5e-3, /*arel=*/5e-3}));
+}
+
+TEST_F(GemmFusionAutotunerTest, Int8FusedGemm256) {
+  const std::string hlo = R"(
+HloModule module
+
+ENTRY e {
+  x = s8[128,256] parameter(0)
+  c = f16[128,256] convert(x)
+
+  y = f16[256,6144] parameter(1)
+
+  ROOT out = f16[128,6144] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+  CheckTritonAutotuning(hlo, R"(
+// CHECK: ENTRY
+// CHECK: ROOT
+// CHECK-SAME: kCustom
+// CHECK-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}));
+}
+
+TEST_F(GemmFusionAutotunerTest, SelectsSplitK) {
+  // Shapes with K >> M, N have to force split-K configurations.
+  const std::string kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = s8[7,8192] parameter(0)
+  p0c = f16[7,8192] convert(p0)
+  p1 = f16[8192,18] parameter(1)
+  ROOT dot.0 = f16[7,18] dot(p0c, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: reduce
+; CHECK: ENTRY
+; CHECK-NEXT: parameter
+; CHECK-NEXT: parameter
+; CHECK-NEXT: kCustom
+; CHECK-NEXT: kLoop
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1, /*arel=*/0.5}));
+}
+
+TEST_F(GemmFusionAutotunerTestWithMorePreciseReduction, SelectsSplitK) {
+  // Shapes with K >> M, N have to force split-K configurations.
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = s8[7,8192] parameter(0)
+  p0c = f16[7,8192] convert(p0)
+  p1 = f16[8192,18] parameter(1)
+  ROOT dot.0 = f16[7,18] dot(p0c, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: reduce
+; CHECK: ENTRY
+; CHECK-NEXT: parameter
+; CHECK-NEXT: parameter
+; CHECK-NEXT: kCustom
+; CHECK-NEXT: kLoop
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-3}));
+}
+
+TEST_F(GemmFusionAutotunerTest, ApplySplitKWithoutAlteringTiling) {
+  const std::string kHloText = R"(
+triton_dot {
+  p0 = f16[55,120] parameter(0)
+  p1 = f16[120,20] parameter(1)
+  ROOT dot = f16[55,20] dot(p0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY e {
+  p0 = f16[55,120]{1,0} parameter(0)
+  p1 = f16[120,20]{1,0} parameter(1)
+  ROOT _ = f16[55,20] fusion(p0, p1), kind=kCustom, calls=triton_dot,
+    backend_config={"fusion_backend_config":{kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1}}}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: f16[3,55,20]
+; CHECK: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1}
+; CHECK: f16[55,20]{1,0} {{(reduce|fusion)}}
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+// Modify block_k back to 16 once b/337839570 is fixed.
+// TODO(b/344770374): Make this test not fragile.
+TEST_F(GemmFusionAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) {
+  const std::string kHloText = R"(
+HloModule m
+
+%triton_gemm_dot {
+  %p1 = s8[4,12288]{1,0} parameter(1)
+  %p0 = s8[12288,1536]{1,0} parameter(0)
+  %convert.p0 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0)
+  %convert.p1 = f16[4,12288]{1,0} convert(s8[4,12288]{1,0} %p1)
+  %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %convert.p1, f16[12288,1536]{1,0} %convert.p0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT %convert = s8[4,1536]{1,0} convert(f16[4,1536]{1,0} %dot)
+}
+
+ENTRY %e {
+  %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0)
+  %convert = s8[4,12288]{1,0} parameter(1)
+  ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}}
+})";
+
+  auto module = ParseAndReturnVerifiedModule(kHloText).value();
+  EXPECT_THAT(
+      backend().compiler()->RunBackend(std::move(module),
+                                       backend().default_stream_executor(),
+                                       {/*device_allocator=*/nullptr,
+                                        /*thread_pool=*/nullptr,
+                                        /*layout_canonicalization_callback=*/{},
+                                        /*is_autotuning_compilation=*/true}),
+      ::testing::AnyOf(
+          tsl::testing::StatusIs(
+              tsl::error::CANCELLED,
+              absl::StrFormat(
+                  "Compilation result discarded due to register spilling")),
+          // Hopper can't spill registers since wgmma instructions are
+          // asynchronous; instead, it simply runs out of them.
+          tsl::testing::StatusIs(
+              tsl::error::RESOURCE_EXHAUSTED,
+              absl::StrFormat("Register allocation failed"))));
+}
+
+// Modify block_k back to 16 once b/337839570 is fixed.
+// TODO(b/344770374): Make this test not fragile.
+TEST_F(GemmFusionAutotunerTest,
+       DoNotFilterOutAutotuningKernelSpillingRegisters) {
+  if (GetCudaComputeCapability().IsAtLeastHopper()) {
+    GTEST_SKIP() << "Hopper and newer runs out of registers for such HLOs";
+  }
+  const std::string kHloText = R"(
+HloModule m
+
+%triton_gemm_dot {
+  %p1 = s8[4,12288]{1,0} parameter(1)
+  %p0 = s8[12288,1536]{1,0} parameter(0)
+  %convert.p0 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0)
+  %convert.p1 = f16[4,12288]{1,0} convert(s8[4,12288]{1,0} %p1)
+  %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %convert.p1, f16[12288,1536]{1,0} %convert.p0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT %convert = s8[4,1536]{1,0} convert(f16[4,1536]{1,0} %dot)
+}
+
+ENTRY %e {
+  %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0)
+  %convert = s8[4,12288]{1,0} parameter(1)
+  ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}}
+})";
+
+  auto module = ParseAndReturnVerifiedModule(kHloText).value();
+  HloModuleConfig config = module->config();
+  DebugOptions debug_options = config.debug_options();
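+  // Disable the autotuner's filtering of register-spilling kernels so that
+  // RunBackend still produces an executable for this oversized tiling.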
+  debug_options.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(
+      false);
+  config.set_debug_options(debug_options);
+  module->set_config(config);
+
+  std::unique_ptr<Executable> executable =
+      backend()
+          .compiler()
+          ->RunBackend(std::move(module), backend().default_stream_executor(),
+                       {/*device_allocator=*/nullptr,
+                        /*thread_pool=*/nullptr,
+                        /*layout_canonicalization_callback=*/{},
+                        /*is_autotuning_compilation=*/true})
+          .value();
+  EXPECT_NE(executable, nullptr);
+}
+
+// Modify block_k back to 16 once b/337839570 is fixed.
+TEST_F(GemmFusionAutotunerTest, RunAutotuningKernelNotSpillingRegisters) {
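+  // A small 16x32 tile compiles without spilling, so an autotuning
+  // compilation must yield a valid executable.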
+  const std::string kHloText = R"(
+HloModule m
+
+%triton_gemm_dot {
+  %p1 = f16[4,12288]{1,0} parameter(1)
+  %p0 = s8[12288,1536]{1,0} parameter(0)
+  %convert.10406 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0)
+  ROOT %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %p1, f16[12288,1536]{1,0} %convert.10406), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY %e {
+  %p0 = s8[12288,1536]{1,0} parameter(0)
+  %p1 = f16[4,12288]{1,0} parameter(1)
+  ROOT %triton_dot = f16[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %p0, f16[4,12288]{1,0} %p1), kind=kCustom, calls=%triton_gemm_dot,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"2","num_ctas":"1"}}}
+})";
+
+  auto module = ParseAndReturnVerifiedModule(kHloText).value();
+  std::unique_ptr<Executable> executable =
+      backend()
+          .compiler()
+          ->RunBackend(std::move(module), backend().default_stream_executor(),
+                       {/*device_allocator=*/nullptr,
+                        /*thread_pool=*/nullptr,
+                        /*layout_canonicalization_callback=*/{},
+                        /*is_autotuning_compilation=*/true})
+          .value();
+  EXPECT_NE(executable, nullptr);
+}
+
+using GemmFusionAutotunerDumpTest = GemmFusionAutotunerTest;
+
+TEST_F(GemmFusionAutotunerDumpTest, Fp8CublasltFallbackSupport) {
+  const std::string kHloText = R"(
+HloModule o
+
+gemm_fusion {
+  p0 = f8e4m3fn[64,6144]{1,0} parameter(0)
+  p1 = f8e4m3fn[64,6144]{1,0} parameter(1)
+  ROOT %dot.0 = f32[64,64]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY main {
+  p0 = f8e4m3fn[64,6144]{1,0} parameter(0)
+  p1 = f8e4m3fn[64,6144]{1,0} parameter(1)
+  ROOT %dot.0 = f32[64,64]{1,0} fusion(p0, p1), kind=kCustom, calls=gemm_fusion, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(kHloText));
+
+  DebugOptions opts;
+  AutotuneConfig autotune_config{
+      DeviceConfig{backend().default_stream_executor(),
+                   backend().memory_allocator()},
+      opts};
+  AutotuneCacheKey cache_key(autotune_config.GetModelStr(),
+                             *module->entry_computation()->root_instruction());
+
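+  // Seed the autotune cache with a result that selects the cuBLAS fallback
+  // (a plain `gemm` result rather than a Triton tiling) for this fusion.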
+  TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override,
+                          ParseTextProto<AutotuneResults>(R"pb(
+                            version: 3
+                            results {
+                              device: "..."
+                              hlo: "..."
+                              result {
+                                gemm { algorithm: -1 }
+                                run_time { nanos: 14 }
+                              }
+                            })pb"));
+  autotune_results_override.mutable_results(0)->set_device(
+      std::string(cache_key.GetModelStr()));
+  autotune_results_override.mutable_results(0)->set_hlo(
+      std::string(cache_key.GetHlo()));
+  CHECK_OK(AutotunerUtil::LoadAutotuneResults(autotune_results_override));
+
+  HloPassPipeline pipeline("gemm_autotune");
+  tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
+                                      tsl::port::MaxParallelism());
+  MultiProcessKeyValueStore key_value_store;
+  pipeline.AddPass<GemmFusionAutotuner>(autotune_config, GetToolkitVersion(),
+                                        &thread_pool, key_value_store);
+  pipeline.AddPass<CallInliner>();
+  for (GemmRewriterOptions::DType dtype :
+       {GemmRewriterOptions::DType::kFp8Only,
+        GemmRewriterOptions::DType::kNonFp8Only}) {
+    pipeline.AddPass<GemmRewriter>(autotune_config.GetGpuComputeCapability(),
+                                   GetToolkitVersion(),
+                                   GemmRewriterOptions{dtype});
+  }
+
+  TF_EXPECT_OK(HloTestBase::RunHloPass(&pipeline, module.get()));
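+  // FP8 GEMM fusions are expected to be rewritten to cuBLASLt custom calls on
+  // Hopper and newer, and to plain cuBLAS GEMM custom calls otherwise.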
+  const bool is_at_least_hopper =
+      std::holds_alternative<se::CudaComputeCapability>(
+          autotune_config.GetGpuComputeCapability()) &&
+      std::get<se::CudaComputeCapability>(
+          autotune_config.GetGpuComputeCapability())
+          .IsAtLeastHopper();
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool filecheck_matches,
+      RunFileCheck(module->ToString(), is_at_least_hopper
+                                           ? "// CHECK: __cublas$lt"
+                                           : "// CHECK: __cublas$gemm"));
+  EXPECT_TRUE(filecheck_matches);
+}
+
+TEST_F(GemmFusionAutotunerDumpTest, DumpingWorks) {
+  HloModuleConfig config;
+  DebugOptions options = GetDebugOptionsForTest();
+  options.set_xla_gpu_cublas_fallback(true);
+  options.set_xla_gpu_dump_autotuned_gemm_fusions(true);
+  std::string output_directory;
+  if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) {
+    output_directory = tsl::testing::TmpDir();
+  }
+  options.set_xla_dump_to(output_directory);
+  config.set_debug_options(options);
+  // Computation is chosen such that relatively heavy math operations before
+  // the GEMM are not worth fusing because they would get duplicated many
+  // times and slow down execution. Therefore autotuning picks cuBLAS here.
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+fusion1 {
+  p0 = f32[333,333] parameter(0)
+  s = f32[333,333] sine(p0)
+  p1 = f32[333,333] parameter(1)
+  c = f32[333,333] cosine(p1)
+  ROOT dot = f32[333,333] dot(s, c),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY e {
+  p0 = f32[333,333] parameter(0)
+  p1 = f32[333,333] parameter(1)
+  ROOT rr = f32[333,333] fusion(p0, p1), kind=kCustom, calls=fusion1,
+    backend_config={"fusion_backend_config": {kind: "__triton_gemm"}}
+})",
+                                                       config));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+                          GetOptimizedModule(std::move(module)));
+
+  std::string dump;
+  TF_EXPECT_OK(tsl::ReadFileToString(
+      tsl::Env::Default(),
+      tsl::io::JoinPath(output_directory,
+                        FilenameFor(*optimized_module, /*prefix=*/"",
+                                    /*suffix=*/"gemm_fusion_0.rr.txt")),
+      &dump));
+  EXPECT_TRUE(*RunFileCheck(dump, R"(
+CHECK: HloModule rr
+CHECK-NOT: cublas
+CHECK: __triton_gemm
+CHECK-NOT: block_m
+)"));
+
+  dump.clear();
+
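+  // The second dump contains the autotuned module, in which the Triton fusion
+  // has been replaced by a cuBLAS call.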
+  TF_EXPECT_OK(tsl::ReadFileToString(
+      tsl::Env::Default(),
+      tsl::io::JoinPath(
+          output_directory,
+          FilenameFor(*optimized_module, /*prefix=*/"",
+                      /*suffix=*/"gemm_fusion_0.rr.optimized.txt")),
+      &dump));
+  EXPECT_TRUE(*RunFileCheck(dump, R"(
+CHECK: HloModule rr
+CHECK-NOT: triton
+CHECK: cublas
+)"));
+}
+
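+// Checks that the autotuner assigns a cuDNN execution plan (a "plan_id" in
+// the backend config) to fusions marked as __cudnn$fusion.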
+TEST_F(GemmFusionAutotunerTest, AutotuneCuDnnFusion) {
+  const std::string kHlo = R"(
+fusion1 {
+  p0 = f32[3,28,32] parameter(0)
+  p1 = f32[3,28,32] parameter(1)
+  ROOT d = f32[3,32,32] dot(p0, p1),
+    lhs_batch_dims={0}, rhs_batch_dims={0},
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY e {
+  p0 = f32[3,28,32] parameter(0)
+  p1 = f32[3,28,32] parameter(1)
+  ROOT _ = f32[3,32,32] fusion(p0, p1), kind=kCustom, calls=fusion1,
+    backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
+})";
+
+  CheckTritonAutotuning(kHlo, R"(
+// CHECK: "plan_id":
+)");
+}
+
+// TODO(b/281489442): Write a testcase called
+// `SkipConfigsProducingDeviantResults` or similar.
+
+class GemmFusionAutotunerLevelTest : public StatelessAutotunerTest,
+                                     public ::testing::WithParamInterface<int> {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options =
+        StatelessAutotunerTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_autotune_level(GetParam());
+    debug_options.set_xla_gpu_cublas_fallback(false);
+    return debug_options;
+  }
+};
+
+TEST_P(GemmFusionAutotunerLevelTest, AllAutotuningLevelsWorkCorrectly) {
+  const std::string kHloText = R"(
+HloModule m
+
+ENTRY e {
+  p0 = pred[64,10] parameter(0)
+  p0c = f32[64,10] convert(p0)
+  p1 = f32[10,128] parameter(1)
+  ROOT r = f32[64,128] dot(p0c, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: kind=kCustom
+; CHECK-SAME: block_m
+      )");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_P(GemmFusionAutotunerLevelTest, Deviceless) {
+  const std::string hlo = R"(
+HloModule module
+
+ENTRY e {
+  x = s8[16,16] parameter(0)
+  c = f16[16,16] convert(x)
+  y = f16[16,16] parameter(1)
+  ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+  HloPassPipeline pipeline("gemm_rewrite_deviceless");
+  pipeline.AddPass<GemmFusion>(backend()
+                                   .default_stream_executor()
+                                   ->GetDeviceDescription()
+                                   .cuda_compute_capability());
+  tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
+                                      tsl::port::MaxParallelism());
+  DebugOptions opts;
+  MultiProcessKeyValueStore key_value_store;
+  pipeline.AddPass<GemmFusionAutotuner>(
+      AutotuneConfig{
+          DevicelessConfig{
+              backend().default_stream_executor()->GetDeviceDescription()},
+          opts},
+      GetToolkitVersion(), &thread_pool, key_value_store);
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(hlo));
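+  // At autotune level 0 the pass only assigns the default Triton config;
+  // at higher levels a deviceless run requires a pre-populated cache and is
+  // expected to fail with a cache-miss error.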
+  if (GetDebugOptionsForTest().xla_gpu_autotune_level() == 0) {
+    TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                            HloTestBase::RunHloPass(&pipeline, module.get()));
+    EXPECT_TRUE(changed);
+
+    // Check default configuration.
+    TF_ASSERT_OK_AND_ASSIGN(
+        bool filecheck_matches,
+        RunFileCheck(
+            module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
+            R"(
+// CHECK: backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"16","block_k":"32","split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"}},"force_earliest_schedule":false}
+            )"));
+    EXPECT_TRUE(filecheck_matches);
+  } else {
+    EXPECT_THAT(HloTestBase::RunHloPass(&pipeline, module.get()),
+                tsl::testing::StatusIs(
+                    tsl::error::INTERNAL,
+                    ::testing::HasSubstr(
+                        "Expect autotune result cache hit for deviceless")));
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerLevelSweep,
+                         GemmFusionAutotunerLevelTest, ::testing::Range(0, 5));
+
+class GemmFusionAutotunerExhaustiveTest : public GemmFusionAutotunerTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options =
+        GemmFusionAutotunerTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_exhaustive_tiling_search(true);
+    return debug_options;
+  }
+};
+
+TEST_F(GemmFusionAutotunerExhaustiveTest, DISABLED_CompileOnly) {
+  const std::string hlo = R"(
+HloModule module
+
+ENTRY e {
+  x = s8[16,16] parameter(0)
+  c = f16[16,16] convert(x)
+  y = f16[16,16] parameter(1)
+  ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+  CheckTritonAutotuning(hlo, R"(
+// CHECK:   %triton_gemm_out_computation (
+// CHECK:   ROOT %out.1 = f16[16,16]{1,0} dot(%c.1, %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+// CHECK:   ROOT %triton_gemm_out = f16[16,16]{1,0} fusion(%x, %y), kind=kCustom, calls=%triton_gemm_out_computation
+// CHECK-SAME: "block_m":
+)");
+}
+
+// TODO(b/337839570): Triton currently has a limitation where it crashes
+// on small block_k values depending on the bit-width of the inputs to the
+// dot. For this test case, the autotuner should skip any block_k values that
+// are <= 16 since the smallest input type has a bit-width of 8.
+TEST_F(GemmFusionAutotunerExhaustiveTest, SkipsCrashingTileKConfig) {
+  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+HloModule module
+ENTRY e {
+  x = s8[33,33]{1,0} parameter(0)
+  c = f16[33,33]{1,0} convert(x)
+  y = f16[33,33]{1,0} parameter(1)
+  ROOT out = f16[33,33]{1,0} dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)")
+                                                  .value();
+  const se::CudaComputeCapability compute_capability{
+      se::CudaComputeCapability::AMPERE, /*minor=*/0};
+  TF_ASSERT_OK_AND_ASSIGN(
+      const std::vector<TritonGemmConfig> configs,
+      GetPossibleMatmulAutotuneConfigs(
+          *Cast<HloDotInstruction>(
+              module->entry_computation()->root_instruction()),
+          compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
+  EXPECT_TRUE(std::all_of(
+      configs.begin(), configs.end(),
+      [](const TritonGemmConfig& config) { return config.block_k > 16; }));
+}
+
+class GemmFusionAutotunerDisableSplitK : public GemmFusionAutotunerTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options =
+        GemmFusionAutotunerTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_enable_split_k_autotuning(false);
+    return debug_options;
+  }
+};
+
+TEST_F(GemmFusionAutotunerDisableSplitK, SplitKIsDisabled) {
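+  // With split-K autotuning disabled, every generated config must keep
+  // split_k == 1.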
+  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f32[1024,1024] parameter(0)
+  p1 = f32[1024,1024] parameter(1)
+  ROOT r = f32[1024,1024] dot(p0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})")
+                                                  .value();
+  const se::CudaComputeCapability compute_capability{
+      se::CudaComputeCapability::AMPERE, /*minor=*/0};
+  TF_ASSERT_OK_AND_ASSIGN(
+      const std::vector<TritonGemmConfig> configs,
+      GetPossibleMatmulAutotuneConfigs(
+          *Cast<HloDotInstruction>(
+              module->entry_computation()->root_instruction()),
+          compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
+  EXPECT_TRUE(std::all_of(
+      configs.begin(), configs.end(),
+      [](const TritonGemmConfig& config) { return config.split_k == 1; }));
+}
+
+class GemmFusionAutotunerConfigTest
+    : public StatelessAutotunerTest,
+      public ::testing::WithParamInterface<bool> {};
+
+TEST_P(GemmFusionAutotunerConfigTest, SparseDotDiscardsUnsupportedTiles) {
+  const std::string kHloText = R"(
+HloModule test
+ENTRY wais {
+  lhs = f16[5,1600] parameter(0)
+  rhs = f16[3200,10] parameter(1)
+  meta = u16[5,200] parameter(2)
+  ROOT dot = f32[5,10] dot(lhs, rhs, meta),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+  const se::CudaComputeCapability compute_capability{
+      se::CudaComputeCapability::AMPERE, /*minor=*/0};
+  DebugOptions debug_options = GetDebugOptionsForTest();
+  debug_options.set_xla_gpu_exhaustive_tiling_search(GetParam());
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      const std::vector<TritonGemmConfig> configs,
+      GetPossibleMatmulAutotuneConfigs(
+          *Cast<HloDotInstruction>(
+              module->entry_computation()->root_instruction()),
+          compute_capability, GetToolkitVersion(), debug_options));
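+  // Every surviving config must supply at least one sparsity-metadata element
+  // per thread and keep block_k above the minimal tile size.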
+  for (const auto& config : configs) {
+    int metadata_size = config.block_m * config.block_k / 16;
+    EXPECT_LE(config.num_warps * WarpSize(), metadata_size);
+    EXPECT_GT(config.block_k, 16);  // kMinTileSize
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerConfigSweep,
+                         GemmFusionAutotunerConfigTest, ::testing::Bool());
+
+TEST_F(GemmFusionAutotunerTest, SplitKFloatNormalization) {
+  if (!GetCudaComputeCapability().IsAtLeastHopper()) {
+    GTEST_SKIP() << "f8 types are only supported from Hopper onwards.";
+  }
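+  // Compiles a single split_k=4 config for a GEMM fusion with an f8e5m2
+  // output and only checks that CompileAll succeeds.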
+  const se::CudaComputeCapability compute_capability =
+      GetCudaComputeCapability();
+  se::GpuDeviceInfoProto deviceless_proto;
+  auto ccc = deviceless_proto.mutable_cuda_compute_capability();
+  ccc->set_major(compute_capability.major);
+  ccc->set_minor(compute_capability.minor);
+  DeviceConfig test_config{backend().default_stream_executor(),
+                           backend().memory_allocator()};
+  AutotuneConfig autotune_config{test_config, GetDebugOptionsForTest()};
+  GemmFusionAutotunerImpl autotuner(autotune_config, GetToolkitVersion(),
+                                    GetDebugOptionsForTest(), nullptr);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto compile_util,
+      AutotunerCompileUtil::Create(autotune_config, GetDebugOptionsForTest()));
+
+  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
+HloModule module
+
+%gemm_fusion_dot_computation (parameter_0: f8e5m2[256,256], parameter_1: f8e4m3fn[128,256]) -> f8e5m2[256,128] {
+  %parameter_0 = f8e5m2[256,256]{1,0} parameter(0)
+  %parameter_1 = f8e4m3fn[128,256]{1,0} parameter(1)
+  %dot.1 = f32[256,128]{1,0} dot(f8e5m2[256,256]{1,0} %parameter_0, f8e4m3fn[128,256]{1,0} %parameter_1), lhs_contracting_dims={0}, rhs_contracting_dims={1}
+  ROOT %convert.2 = f8e5m2[256,128]{1,0} convert(f32[256,128]{1,0} %dot.1)
+}
+ENTRY entry {
+  %p0 = f8e5m2[256,256]{1,0} parameter(0)
+  %p1 = f8e4m3fn[128,256]{1,0} parameter(1)
+  ROOT r = f8e5m2[256,128]{1,0} fusion(f8e5m2[256,256]{1,0} %p0, f8e4m3fn[128,256]{1,0} %p1), kind=kCustom, calls=%gemm_fusion_dot_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
+})")
+                                                  .value();
+  GemmFusionAutotunerImpl::TilingConfigs configs;
+  configs.emplace_back(DynCast<HloFusionInstruction>(
+                           module->entry_computation()->root_instruction()),
+                       std::vector<GemmFusionAutotunerImpl::Config>{
+                           GemmFusionAutotunerImpl::Config(TritonGemmConfig(
+                               /*block_m=*/32,
+                               /*block_n=*/64,
+                               /*block_k=*/64,
+                               /*split_k=*/4,
+                               /*num_stages=*/1,
+                               /*num_warps=*/4,
+                               /*num_ctas=*/1))});
+  CHECK_OK(autotuner.CompileAll(*compile_util, configs));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_autotuning.proto b/third_party/xla/xla/service/gpu/autotuning/gpu_autotuning.proto
similarity index 100%
rename from third_party/xla/xla/service/gpu/gpu_autotuning.proto
rename to third_party/xla/xla/service/gpu/autotuning/gpu_autotuning.proto
diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.cc b/third_party/xla/xla/service/gpu/buffer_sharing.cc
index 624d324..0ffb8e3 100644
--- a/third_party/xla/xla/service/gpu/buffer_sharing.cc
+++ b/third_party/xla/xla/service/gpu/buffer_sharing.cc
@@ -79,7 +79,7 @@
   // first, i.e. before processing other outputs (that may overwrite the input).
   stream_executor::GpuDeviceInfoProto device_info;
   stream_executor::DeviceDescription device_description(device_info);
-  auto analysis = HloFusionAnalysis::Create(fusion, &device_description);
+  auto analysis = HloFusionAnalysis::Create(*user, device_description);
   bool is_reduction_emitter = analysis.GetEmitterFusionKind() ==
                               HloFusionAnalysis::EmitterFusionKind::kReduction;
   const HloInstruction* reduction_hero =
diff --git a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.cc b/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.cc
deleted file mode 100644
index 9102d75..0000000
--- a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.cc
+++ /dev/null
@@ -1,231 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/collective_permute_cycle_decomposer.h"
-
-#include <cstdint>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
-#include "xla/comparison_util.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/literal_util.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-
-namespace xla {
-
-namespace {
-using SourceTargetPair = std::pair<int64_t, int64_t>;
-using SourceTargetPairs = std::vector<SourceTargetPair>;
-enum class CycleType { kUnknown, kForward, kBackward };
-
-// Returns true if the CollectivePermute instruction has a cycle in its
-// source-target pairs and should be decomposed.
-CycleType ShouldDecomposeWithCycleType(
-    const HloCollectivePermuteInstruction& collective_permute,
-    int64_t threshold_in_bytes) {
-  if (!collective_permute.channel_id().has_value()) {
-    return CycleType::kUnknown;
-  }
-
-  if (collective_permute.operand_count() != 1) {
-    return CycleType::kUnknown;
-  }
-
-  const Shape& result_shape = collective_permute.shape();
-  // Skip the transformation if there is any context data.
-  if (result_shape.IsTuple()) {
-    return CycleType::kUnknown;
-  }
-
-  CHECK(result_shape.IsArray());
-  if (ShapeUtil::ByteSizeOf(result_shape) < threshold_in_bytes) {
-    return CycleType::kUnknown;
-  }
-
-  const SourceTargetPairs& pairs = collective_permute.source_target_pairs();
-  if (pairs.size() == 1) {
-    return CycleType::kUnknown;
-  }
-
-  return IsForwardCycle(pairs)    ? CycleType::kForward
-         : IsBackwardCycle(pairs) ? CycleType::kBackward
-                                  : CycleType::kUnknown;
-}
-
-// Constructs the frontend attributes for the two decomposed CollectivePermute
-// instructions.
-absl::Status GetFrontendAttributes(HloCollectivePermuteInstruction* cp,
-                                   CycleType cycle_type,
-                                   xla::FrontendAttributes& cp1_attr,
-                                   xla::FrontendAttributes& cp2_attr) {
-  cp1_attr = cp->frontend_attributes();
-  cp2_attr = cp->frontend_attributes();
-  auto validation_it =
-      cp->frontend_attributes().map().find(kSendRecvValidationAttr);
-  if (validation_it == cp->frontend_attributes().map().end() ||
-      validation_it->second == "invalid") {
-    return absl::OkStatus();
-  }
-
-  auto statusor_bounds = ParseReplicaGroupsOnly(validation_it->second);
-  if (!statusor_bounds.ok()) {
-    return statusor_bounds.status();
-  }
-  const std::vector<ReplicaGroup>& bounds = statusor_bounds.value();
-  if (bounds.size() < 2) {
-    return Internal("Invalid number of replica groups");
-  }
-
-  int64_t num_pairs = bounds.size();
-  // A forward cycle has its backedge at the end while a backward cycle has its
-  // backedge at the beginning.
-  auto backedge_start = cycle_type == CycleType::kBackward
-                            ? bounds.begin()
-                            : bounds.begin() + num_pairs - 1;
-  auto other_edges_start =
-      cycle_type == CycleType::kBackward ? bounds.begin() + 1 : bounds.begin();
-  std::vector<ReplicaGroup> cp1_bounds(backedge_start, backedge_start + 1);
-  std::vector<ReplicaGroup> cp2_bounds(other_edges_start,
-                                       other_edges_start + num_pairs - 1);
-  auto bounds_to_string = [](const std::vector<ReplicaGroup> groups) {
-    return "{" +
-           absl::StrJoin(groups, ",",
-                         [](std::string* out, const ReplicaGroup& value) {
-                           absl::StrAppend(out, "{", value.replica_ids(0), ",",
-                                           value.replica_ids(1), "}");
-                         }) +
-           "}";
-  };
-  std::string cp1_validation_str = bounds_to_string(cp1_bounds);
-  std::string cp2_validation_str = bounds_to_string(cp2_bounds);
-  (*cp1_attr.mutable_map())[kSendRecvValidationAttr] = cp1_validation_str;
-  (*cp2_attr.mutable_map())[kSendRecvValidationAttr] = cp2_validation_str;
-  return absl::OkStatus();
-}
-
-// Decomposes a CollectivePermute instruction with a cycle in its source-target
-// pairs into two CollectivePermute instructions.
-absl::Status DecomposeCollectivePermuteCycle(
-    HloCollectivePermuteInstruction* cp, HloComputation* computation,
-    HloModule* module, int64_t next_channel_id, CycleType cycle_type) {
-  const SourceTargetPairs& pairs = cp->source_target_pairs();
-  int64_t num_pairs = pairs.size();
-  // A forward cycle has its backedge at the end as in
-  // {{0,1},{1,2},{2,3},{3,0}} while a backward cycle has its backedge at the
-  // beginning as in {{0,3},{1,0},{2,1},{3,2}}.
-  auto backedge_start = cycle_type == CycleType::kBackward
-                            ? pairs.begin()
-                            : pairs.begin() + num_pairs - 1;
-  auto other_edges_start =
-      cycle_type == CycleType::kBackward ? pairs.begin() + 1 : pairs.begin();
-  SourceTargetPairs backedge(backedge_start, backedge_start + 1);
-  SourceTargetPairs other_edges(other_edges_start,
-                                other_edges_start + num_pairs - 1);
-  const OpMetadata& metadata = cp->metadata();
-  xla::FrontendAttributes cp1_attr, cp2_attr;
-  TF_RETURN_IF_ERROR(GetFrontendAttributes(cp, cycle_type, cp1_attr, cp2_attr));
-
-  // Create the CollectivePermute instruction for the communication represented
-  // by the backedge.
-  HloInstruction* cp1 =
-      computation->AddInstruction(HloInstruction::CreateCollectivePermute(
-          cp->shape(), cp->mutable_operand(0), backedge,
-          cp->channel_id().value()));
-  cp1->set_metadata(metadata);
-  cp1->set_frontend_attributes(cp1_attr);
-  int64_t cp1_receiver = backedge.back().second;
-
-  // Create the CollectivePermute instruction for the communication represented
-  // byt other edges.
-  HloInstruction* cp2 =
-      computation->AddInstruction(HloInstruction::CreateCollectivePermute(
-          cp->shape(), cp->mutable_operand(0), other_edges, next_channel_id));
-  cp2->set_metadata(metadata);
-  cp2->set_frontend_attributes(cp2_attr);
-
-  // Calculate the received data as follows:
-  //   partition = u32[] partition-id()
-  //   constant = u32[] constant(cp1_receiver)
-  //   compare0 = pred[] compare(partition, cp1_received), direction=EQ
-  //   compare = pred[?] broadcast(compare0), dimensions={}
-  //   recv-data = type[?] select(compare, cp1_done, cp2_done)
-  HloInstruction* partition =
-      computation->AddInstruction(HloInstruction::CreatePartitionId());
-  HloInstruction* constant = computation->AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0(U32, cp1_receiver)));
-  HloInstruction* compare0 = computation->AddInstruction(
-      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), partition,
-                                    constant, Comparison::Direction::kEq));
-  HloInstruction* compare =
-      computation->AddInstruction(HloInstruction::CreateBroadcast(
-          ShapeUtil::MakeShape(PRED, cp1->shape().dimensions()), compare0, {}));
-  HloInstruction* recv_data =
-      computation->AddInstruction(HloInstruction::CreateTernary(
-          cp1->shape(), HloOpcode::kSelect, compare, cp1, cp2));
-
-  TF_RETURN_IF_ERROR(cp->ReplaceAllUsesWith(recv_data));
-  TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(cp));
-
-  return absl::OkStatus();
-}
-}  // namespace
-
-absl::StatusOr<bool> CollectivePermuteCycleDecomposer::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  int64_t next_channel_id;
-  for (auto comp : module->computations(execution_threads)) {
-    for (auto hlo : comp->MakeInstructionPostOrder()) {
-      if (hlo->opcode() != HloOpcode::kCollectivePermute) {
-        continue;
-      }
-      auto collective_permute = Cast<HloCollectivePermuteInstruction>(hlo);
-      CycleType cycle_type = ShouldDecomposeWithCycleType(*collective_permute,
-                                                          threshold_in_bytes_);
-      if (cycle_type != CycleType::kUnknown) {
-        if (changed == false) {
-          next_channel_id = hlo_query::NextChannelId(*module);
-          changed = true;
-        }
-        TF_RETURN_IF_ERROR(DecomposeCollectivePermuteCycle(
-            collective_permute, comp, module, next_channel_id++, cycle_type));
-      }
-    }
-  }
-  return changed;
-}
-
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.h b/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.h
deleted file mode 100644
index 508a859..0000000
--- a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.h
+++ /dev/null
@@ -1,73 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_
-#define XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_
-
-#include <cstdint>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// CollectivePermuteCycleDecomposer is a pass that converts CollectivePermute
-// instructions with all participants forming either a forward cycle (such as
-// {{0,1},{1,2},{2,3},{3,0}) or a backward cycle (such as {{3,2},{2,1},{1,0},
-// {0,3}}) into two CollectivePermute instructions. We currently restrict
-// this transformation to CollectivePermute using partition mode, with one
-// input, without any context data. Here is an example.
-//
-// before transformation:
-//     start = (<rt>, <rt>) collective-permute(data),
-//       source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
-//
-// after transformation:
-//     partition-id = u32[] partition-id()
-//     constant = u32[] constant(0)
-//     compare = pred[] compare(u32[] partition-id, u32[] constant),
-//       direction=EQ
-//     pred = pred[] broadcast(pred[] compare), dimensions={}
-//     cp1 = (<rt>, <rt>) collective-permute(data), source_target_pairs={{3,0}}
-//     cp2 = (<rt>, <rt>) collective-permute(data),
-//       source_target_pairs={{0,1},{1,2},{2,3}}
-//     data = <rt> select(pred, cp1, cp2)
-//
-class CollectivePermuteCycleDecomposer : public HloModulePass {
- public:
-  explicit CollectivePermuteCycleDecomposer(int64_t threshold_in_bytes)
-      : threshold_in_bytes_(threshold_in_bytes) {}
-  absl::string_view name() const override {
-    return "collective-permute-cycle-decomposer";
-  }
-
-  using HloPassInterface::Run;
-  // Runs CollectivePermuteCycleDecomposer pass on computations in 'module'.
-  // Returns whether the 'module' was changed.
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  // Transform only if the size of the CollectivePermute data >= threshold.
-  int64_t threshold_in_bytes_;
-};
-
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_
diff --git a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer_test.cc b/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer_test.cc
deleted file mode 100644
index 19436ee..0000000
--- a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer_test.cc
+++ /dev/null
@@ -1,237 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/collective_permute_cycle_decomposer.h"
-
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-using ::testing::HasSubstr;
-using CollectivePermuteCycleDecomposerTest = HloTestBase;
-
-using ::testing::HasSubstr;
-using CollectivePermuteDecomposerTest = HloTestBase;
-
-TEST_F(CollectivePermuteDecomposerTest, DefaultChannelNotTransformed) {
-  const absl::string_view kModuleStr = R"(
-      HloModule test
-      ENTRY test_computation {
-        p = u32[] replica-id()
-        ROOT start = u32[] collective-permute(p),
-          source_target_pairs={{0,1},{1,0}}
-      }
-    )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kModuleStr)));
-  CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) {
-  const absl::string_view kModuleStr = R"(
-      HloModule test
-      ENTRY test_computation {
-        p = u32[] partition-id()
-        ROOT start = u32[] collective-permute(p), channel_id=1,
-          source_target_pairs={{0,0}}
-      }
-    )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kModuleStr)));
-  CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed) {
-  const absl::string_view kModuleStr = R"(
-      HloModule test
-      ENTRY test_computation {
-        p = u32[] partition-id()
-        ROOT start = u32[] collective-permute(p), channel_id=1,
-          source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
-      }
-    )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kModuleStr)));
-  CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/33);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) {
-  const absl::string_view kModuleStr = R"(
-      HloModule test
-      ENTRY test_computation {
-        p = u32[] partition-id()
-        ROOT start = u32[3,2] collective-permute(p), channel_id=1,
-          source_target_pairs={{0,1},{1,2},{2,3},{3,0}},
-          frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"},
-          metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
-      }
-    )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kModuleStr)));
-  CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  auto check_metadata = [](const HloInstruction* inst) {
-    EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add");
-    EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py");
-    EXPECT_EQ(inst->metadata().source_line(), 35);
-  };
-
-  HloCollectivePermuteInstruction* cp1 =
-      DynCast<HloCollectivePermuteInstruction>(
-          FindInstruction(module.get(), "collective-permute"));
-  HloCollectivePermuteInstruction* cp2 =
-      DynCast<HloCollectivePermuteInstruction>(
-          FindInstruction(module.get(), "collective-permute.1"));
-  EXPECT_NE(cp1, nullptr);
-  EXPECT_NE(cp2, nullptr);
-  EXPECT_EQ(cp1->operand(0), cp2->operand(0));
-  EXPECT_GT(cp2->channel_id().value(), cp1->channel_id().value());
-  EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{3,0}}"));
-  EXPECT_THAT(cp1->ToString(),
-              HasSubstr("_xla_send_recv_validation=\"{{3,10}}\""));
-  EXPECT_THAT(cp2->ToString(),
-              HasSubstr("source_target_pairs={{0,1},{1,2},{2,3}}"));
-  EXPECT_THAT(cp2->ToString(),
-              HasSubstr("_xla_send_recv_validation=\"{{0,7},{1,8},{2,9}}\""));
-  check_metadata(cp1);
-  check_metadata(cp2);
-}
-
-TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) {
-  const absl::string_view kModuleStr = R"(
-  HloModule test
-
-  while_cond {
-    param = (u32[], f32[2,2], f32[2,2]) parameter(0)
-    iter = u32[] get-tuple-element(param), index=0
-    max_iter = u32[] constant(3)
-    ROOT cmp = pred[] compare(iter, max_iter), direction=LT
-  }
-
-  while_body {
-    param = (u32[], f32[2,2], f32[2,2]) parameter(0)
-    iter = u32[] get-tuple-element(param), index=0
-    data = f32[2,2] get-tuple-element(param), index=1
-    weights = f32[2,2] get-tuple-element(param), index=2
-    cp = f32[2,2] collective-permute(data),
-      channel_id=1,
-      source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}},
-      frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"}
-    matmul = f32[2,2] dot(weights, cp), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-    iter_increment = u32[] constant(1)
-    next_iter = u32[] add(iter, iter_increment)
-    ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights)
-  }
-
-  ENTRY test_computation {
-    iter = u32[] constant(0)
-    data = f32[2,2] parameter(0)
-    weights = f32[2,2] parameter(1)
-    input = (u32[], f32[2,2], f32[2,2]) tuple(iter, data, weights)
-    while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body
-    ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1
-  }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kModuleStr)));
-  CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
-  EXPECT_TRUE(changed);
-  HloCollectivePermuteInstruction* cp1 =
-      DynCast<HloCollectivePermuteInstruction>(
-          FindInstruction(module.get(), "collective-permute"));
-  HloCollectivePermuteInstruction* cp2 =
-      DynCast<HloCollectivePermuteInstruction>(
-          FindInstruction(module.get(), "collective-permute.1"));
-  EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{3,0}}"));
-  EXPECT_THAT(cp1->ToString(),
-              HasSubstr("_xla_send_recv_validation=\"{{3,10}}\""));
-  EXPECT_THAT(cp2->ToString(),
-              HasSubstr("source_target_pairs={{0,1},{1,2},{2,3}}"));
-  EXPECT_THAT(cp2->ToString(),
-              HasSubstr("_xla_send_recv_validation=\"{{0,7},{1,8},{2,9}}\""));
-}
-
-TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) {
-  const absl::string_view kModuleStr = R"(
-      HloModule test
-      ENTRY test_computation {
-        p = u32[] partition-id()
-        ROOT start = u32[] collective-permute(p), channel_id=1,
-          source_target_pairs={{0,3},{1,0},{2,1},{3,2}},
-          frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"},
-          metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
-      }
-    )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kModuleStr)));
-  CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
-  EXPECT_TRUE(changed);
-  auto check_metadata = [](const HloInstruction* inst) {
-    EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add");
-    EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py");
-    EXPECT_EQ(inst->metadata().source_line(), 35);
-  };
-
-  HloCollectivePermuteInstruction* cp1 =
-      DynCast<HloCollectivePermuteInstruction>(
-          FindInstruction(module.get(), "collective-permute"));
-  HloCollectivePermuteInstruction* cp2 =
-      DynCast<HloCollectivePermuteInstruction>(
-          FindInstruction(module.get(), "collective-permute.1"));
-  EXPECT_NE(cp1, nullptr);
-  EXPECT_NE(cp2, nullptr);
-  EXPECT_EQ(cp1->operand(0), cp2->operand(0));
-  EXPECT_GT(cp2->channel_id().value(), cp1->channel_id().value());
-  EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{0,3}}"));
-  EXPECT_THAT(cp1->ToString(),
-              HasSubstr("_xla_send_recv_validation=\"{{0,7}}\""));
-  EXPECT_THAT(cp2->ToString(),
-              HasSubstr("source_target_pairs={{1,0},{2,1},{3,2}}"));
-  EXPECT_THAT(cp2->ToString(),
-              HasSubstr("_xla_send_recv_validation=\"{{1,8},{2,9},{3,10}}\""));
-  check_metadata(cp1);
-  check_metadata(cp2);
-}
-
-}  // namespace
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.cc b/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.cc
deleted file mode 100644
index b1e8812..0000000
--- a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.cc
+++ /dev/null
@@ -1,163 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/collective_permute_valid_iteration_annotator.h"
-
-#include "xla/literal_util.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/while_loop_analysis.h"
-
-namespace xla {
-
-// Finds and returns the non-constant operand in instr.
-//
-// CHECK-fails if instr doesn't have exactly one unique non-constant operand.
-static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
-  const HloInstruction* result = nullptr;
-  for (const HloInstruction* operand : instr->operands()) {
-    if (!operand->IsConstant()) {
-      if (result != nullptr) {
-        CHECK_EQ(result, operand);
-      }
-      result = operand;
-    }
-  }
-  CHECK_NE(result, nullptr);
-  return result;
-}
-
-// Finds the step (k) for while instruction, if the loop is of the form:
-//
-// while(cond) {
-//   ind_var = ind_var + k
-// }
-//
-// If this pattern is not found, it returns std::nullopt.
-std::optional<int64_t> GetStep(HloInstruction* while_inst) {
-  // Get the update operation
-  std::optional<int64_t> indvar_tuple_idx =
-      GetLoopInductionVarTupleIdx(while_inst);
-  if (!indvar_tuple_idx) {
-    return std::nullopt;
-  };
-  auto* while_body_indvar_update =
-      while_inst->while_body()->root_instruction()->mutable_operand(
-          *indvar_tuple_idx);
-  auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
-
-  HloInstruction* trip_count_increase_step_instr = nullptr;
-  if (!Match(while_body_indvar_update,
-             match::AddAnyOrder(match::Op().Is(while_body_indvar),
-                                match::Op(&trip_count_increase_step_instr)))) {
-    return std::nullopt;
-  }
-  return LiteralUtil::LiteralAsScalarInt64(
-      trip_count_increase_step_instr->literal());
-}
-
-absl::StatusOr<bool> CollectivePermuteValidIterationAnnotator::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (HloComputation* comp : module->computations(execution_threads)) {
-    for (HloInstruction* inst : comp->instructions()) {
-      if (inst->opcode() != HloOpcode::kCollectivePermute) {
-        continue;
-      }
-
-      if (inst->frontend_attributes().map().find(kSendRecvValidationAttr) !=
-          inst->frontend_attributes().map().end()) {
-        continue;
-      }
-      auto sourceTargetPairs = inst->source_target_pairs();
-      if (!IsForwardCycle(sourceTargetPairs) &&
-          !IsBackwardCycle(sourceTargetPairs)) {
-        continue;
-      }
-
-      VLOG(2) << "Collective permute with cycle: " << inst->ToString();
-
-      int64_t max_device_num = -1;
-      for (auto [source, target] : sourceTargetPairs) {
-        max_device_num = std::max(std::max(source, target), max_device_num);
-      }
-      int64_t num_devices = max_device_num + 1;
-
-      HloInstruction* whileOp = inst->parent()->WhileCallInstruction();
-      if (whileOp == nullptr) {
-        VLOG(2) << "No surrounding while op found. Ignoring " << inst->name();
-        continue;
-      }
-      if (!whileOp->frontend_attributes().map().contains(
-              "is_pipelined_while_loop"))
-        continue;
-      TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config,
-                          whileOp->backend_config<WhileLoopBackendConfig>());
-      if (!config.has_known_trip_count()) {
-        VLOG(2) << "Trip count for while loop (" << whileOp->name()
-                << "): unknown";
-        continue;
-      }
-
-      int64_t trip_count = config.known_trip_count().n();
-      std::optional<int64_t> step = GetStep(whileOp);
-      VLOG(2) << "Trip count for while loop (" << whileOp->name()
-              << "): " << trip_count;
-      if (!step) {
-        VLOG(2) << "Could not find step for while operation";
-        continue;
-      }
-      VLOG(2) << "Step for while loop (" << whileOp->name() << "): " << *step;
-      if (*step != 1) {
-        VLOG(2) << "Step is not 1. Skipping...";
-        continue;
-      }
-
-      // For each source i, the send/recv iteration instances are {i, i+offset}
-      // where offset is `number of microbatches * CR - 1`. We know that
-      // `trip_count = number_of_microbatches * CR + num_devices - 1` So, offset
-      // = number_of_microbatches * CR - 1 = trip_count - num_devices.
-      int64_t offset = trip_count - num_devices;
-
-      std::vector<std::pair<int64_t, int64_t>> sendRecvValidation(
-          sourceTargetPairs.size());
-
-      for (size_t currIdx = 0; currIdx < sourceTargetPairs.size(); currIdx++) {
-        sendRecvValidation[currIdx] = {currIdx, currIdx + offset};
-      }
-
-      if (IsBackwardCycle(sourceTargetPairs)) {
-        std::reverse(sendRecvValidation.begin(), sendRecvValidation.end());
-      }
-
-      xla::FrontendAttributes attributes;
-      std::string iteration_instances =
-          "{" +
-          absl::StrJoin(sendRecvValidation, ",",
-                        [](std::string* out, std::pair<int64_t, int64_t> item) {
-                          absl::StrAppend(out, "{", item.first, ",",
-                                          item.second, "}");
-                        }) +
-          "}";
-      (*attributes.mutable_map())[kSendRecvValidationAttr] =
-          iteration_instances;
-
-      inst->add_frontend_attributes(attributes);
-      VLOG(1) << "Adding " << kSendRecvValidationAttr << " to " << inst->name()
-              << ": " << iteration_instances;
-      changed = true;
-    }
-  }
-  return changed;
-}
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.h b/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.h
deleted file mode 100644
index e6b04c9..0000000
--- a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_
-#define XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_
-
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// This is an unsafe transformation that is triggered only if the attribute
-// `is_pipelined_while_loop` is present on a while loop.
-//
-// If a while loop is known to be a pipelined while loop, has a known trip count
-// and increments with step=1, then this pass annotates the `collective-permute`
-// operations within the while loop with valid iterations for each GPU. This is
-// only done when the source-target pairs of the `collective-permute` operation
-// form a forward or backward cycle.
-//
-// For example, if the trip count is 10 (iteration 0 to 9), with step=1, and the
-// source-target pairs of a `collective-permute` operation are
-// `{{0,1},{1,2},{2,3},{3,0}}`, then this pass would annotate such operation
-// with `_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"`. This annotation
-// means that
-//   - for GPU index 0, the valid iterations are 0,1,2,3,4,5,6.
-//   - for GPU index 1, the valid iterations are 1,2,3,4,5,6,7.
-//   - for GPU index 2, the valid iterations are 2,3,4,5,6,7,8.
-//   - for GPU index 3, the valid iterations are 3,4,5,6,7,8,9.
-//
-// The index in the list denotes the device index and the bounds {start,end} are
-// inclusive. For more examples, look at
-// `xla/service/spmd/collective_permute_valid_iteration_annotator_tests.cc`.
-class CollectivePermuteValidIterationAnnotator : public HloModulePass {
- public:
-  CollectivePermuteValidIterationAnnotator() = default;
-  absl::string_view name() const override {
-    return "collective-permute-valid-iteration-annotator";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_
diff --git a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator_test.cc b/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator_test.cc
deleted file mode 100644
index 3d1d0b4..0000000
--- a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator_test.cc
+++ /dev/null
@@ -1,174 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/collective_permute_valid_iteration_annotator.h"
-
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/service/while_loop_trip_count_annotator.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace {
-
-using CollectivePermuteValidIterationAnnotatorTest = HloTestBase;
-
-TEST_F(CollectivePermuteValidIterationAnnotatorTest, NoChange) {
-  // We expect no changes here because the while loop is not labelled with the
-  // `is_pipelined_while_loop` attribute.
-  absl::string_view hlo_string = R"(
-    HloModule test, entry_computation_layout={()->(s32[], s32[])}
-    %Body (param: (s32[], s32[])) -> (s32[], s32[]) {
-      %param = (s32[], s32[]) parameter(0)
-      %i = s32[] get-tuple-element((s32[], s32[]) %param), index=1
-      %one = s32[] constant(1)
-      %i_plus_one = s32[] add(s32[] %i, s32[] %one)
-      %permute = s32[] collective-permute(%i_plus_one), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
-      ROOT %tuple = (s32[], s32[]) tuple(s32[] %permute, s32[] %permute)
-    }
-    %Cond (param.1: (s32[], s32[])) -> pred[] {
-      %param.1 = (s32[], s32[]) parameter(0)
-      %i.1 = s32[] get-tuple-element((s32[], s32[]) %param.1), index=1
-      %trip_count = s32[] constant(10)
-      ROOT %done = pred[] compare(s32[] %i.1, s32[] %trip_count), direction=LT
-    }
-    ENTRY %test () -> (s32[], s32[]) {
-      %i_start = s32[] constant(0)
-      %p_start = s32[] constant(0)
-      %initial_tuple = (s32[], s32[]) tuple(s32[] %i_start, s32[] %p_start)
-      ROOT %while = (s32[], s32[]) while((s32[], s32[]) %initial_tuple), condition=%Cond, body=%Body
-    }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string, 1, 4));
-
-  HloPassPipeline pipeline("my-pass-pipeline");
-
-  pipeline.AddPass<WhileLoopTripCountAnnotator>();
-  pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
-  EXPECT_FALSE(changed);
-
-  HloCollectivePermuteInstruction* cp =
-      DynCastOrNull<HloCollectivePermuteInstruction>(
-          FindInstruction(module.get(), HloOpcode::kCollectivePermute));
-
-  ASSERT_NE(cp, nullptr);
-
-  auto sendRecvValidationIt =
-      cp->frontend_attributes().map().find(kSendRecvValidationAttr);
-  ASSERT_EQ(sendRecvValidationIt, cp->frontend_attributes().map().end());
-}
-
-TEST_F(CollectivePermuteValidIterationAnnotatorTest, ForwardCycle) {
-  absl::string_view hlo_string = R"(
-    HloModule test, entry_computation_layout={()->(s32[], s32[])}
-    %Body (param: (s32[], s32[])) -> (s32[], s32[]) {
-      %param = (s32[], s32[]) parameter(0)
-      %i = s32[] get-tuple-element((s32[], s32[]) %param), index=1
-      %one = s32[] constant(1)
-      %i_plus_one = s32[] add(s32[] %i, s32[] %one)
-      %permute = s32[] collective-permute(%i_plus_one), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
-      ROOT %tuple = (s32[], s32[]) tuple(s32[] %permute, s32[] %i_plus_one)
-    }
-    %Cond (param.1: (s32[], s32[])) -> pred[] {
-      %param.1 = (s32[], s32[]) parameter(0)
-      %i.1 = s32[] get-tuple-element((s32[], s32[]) %param.1), index=1
-      %trip_count = s32[] constant(10)
-      ROOT %done = pred[] compare(s32[] %i.1, s32[] %trip_count), direction=LT
-    }
-    ENTRY %test () -> (s32[], s32[]) {
-      %i_start = s32[] constant(0)
-      %p_start = s32[] constant(0)
-      %initial_tuple = (s32[], s32[]) tuple(s32[] %i_start, s32[] %p_start)
-      ROOT %while = (s32[], s32[]) while((s32[], s32[]) %initial_tuple), condition=%Cond, body=%Body, frontend_attributes={is_pipelined_while_loop="true"}
-    }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string, 1, 4));
-
-  HloPassPipeline pipeline("my-pass-pipeline");
-
-  pipeline.AddPass<WhileLoopTripCountAnnotator>();
-  pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  HloCollectivePermuteInstruction* cp =
-      DynCastOrNull<HloCollectivePermuteInstruction>(
-          FindInstruction(module.get(), HloOpcode::kCollectivePermute));
-
-  ASSERT_NE(cp, nullptr);
-
-  auto sendRecvValidationIt =
-      cp->frontend_attributes().map().find(kSendRecvValidationAttr);
-  ASSERT_NE(sendRecvValidationIt, cp->frontend_attributes().map().end());
-  std::string sendRecvValidationAttr = sendRecvValidationIt->second;
-  EXPECT_EQ(sendRecvValidationAttr, "{{0,6},{1,7},{2,8},{3,9}}");
-}
-
-TEST_F(CollectivePermuteValidIterationAnnotatorTest, BackwardCycle) {
-  absl::string_view hlo_string = R"(
-    HloModule test, entry_computation_layout={()->(s32[], s32[])}
-    %Body (param: (s32[], s32[])) -> (s32[], s32[]) {
-      %param = (s32[], s32[]) parameter(0)
-      %i = s32[] get-tuple-element((s32[], s32[]) %param), index=1
-      %one = s32[] constant(1)
-      %i_plus_one = s32[] add(s32[] %i, s32[] %one)
-      %permute = s32[] collective-permute(%i_plus_one), channel_id=1, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
-      ROOT %tuple = (s32[], s32[]) tuple(s32[] %permute, s32[] %i_plus_one)
-    }
-    %Cond (param.1: (s32[], s32[])) -> pred[] {
-      %param.1 = (s32[], s32[]) parameter(0)
-      %i.1 = s32[] get-tuple-element((s32[], s32[]) %param.1), index=1
-      %trip_count = s32[] constant(10)
-      ROOT %done = pred[] compare(s32[] %i.1, s32[] %trip_count), direction=LT
-    }
-    ENTRY %test () -> (s32[], s32[]) {
-      %i_start = s32[] constant(0)
-      %p_start = s32[] constant(0)
-      %initial_tuple = (s32[], s32[]) tuple(s32[] %i_start, s32[] %p_start)
-      ROOT %while = (s32[], s32[]) while((s32[], s32[]) %initial_tuple), condition=%Cond, body=%Body, frontend_attributes={is_pipelined_while_loop="true"}
-    }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string, 1, 4));
-
-  HloPassPipeline pipeline("my-pass-pipeline");
-
-  pipeline.AddPass<WhileLoopTripCountAnnotator>();
-  pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  HloCollectivePermuteInstruction* cp =
-      DynCastOrNull<HloCollectivePermuteInstruction>(
-          FindInstruction(module.get(), HloOpcode::kCollectivePermute));
-
-  ASSERT_NE(cp, nullptr);
-
-  auto sendRecvValidationIt =
-      cp->frontend_attributes().map().find(kSendRecvValidationAttr);
-  ASSERT_NE(sendRecvValidationIt, cp->frontend_attributes().map().end());
-  std::string sendRecvValidationAttr = sendRecvValidationIt->second;
-  EXPECT_EQ(sendRecvValidationAttr, "{{3,9},{2,8},{1,7},{0,6}}");
-}
-}  // namespace
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc
deleted file mode 100644
index d113a6b..0000000
--- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc
+++ /dev/null
@@ -1,811 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/command_buffer_scheduling.h"
-
-#include <algorithm>
-#include <cstddef>
-#include <cstdint>
-#include <iterator>
-#include <memory>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/status/status.h"
-#include "absl/strings/match.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/ffi/ffi_api.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/variant_visitor.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-
-using CommandBuffer = CommandBufferScheduling::CommandBuffer;
-using CommandBufferConfig = CommandBufferScheduling::CommandBufferConfig;
-
-// Returns true if HLO computation can be executed as a command buffer.
-static bool IsCommand(const HloComputation* computation,
-                      const CommandBufferConfig& config);
-
-//===----------------------------------------------------------------------===//
-// No-op HLO operations.
-//===----------------------------------------------------------------------===//
-
-// Some HLO operations do not have corresponding operations at run time, and
-// they can be safely wrapped into command buffers together with load-bearing
-// commands.
-
-static bool IsConstant(const HloInstruction* hlo) {
-  return hlo->opcode() == HloOpcode::kConstant;
-}
-
-static bool IsParameter(const HloInstruction* hlo) {
-  return hlo->opcode() == HloOpcode::kParameter;
-}
-
-// Returns true if the instruction is a no-op at run time and doesn't have a
-// corresponding Thunk or Command (metadata-only operation).
-static bool IsNoOp(const HloInstruction* hlo) {
-  return HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kTuple,
-                          HloOpcode::kGetTupleElement>(hlo);
-}
-
-//===----------------------------------------------------------------------===//
-// Synchronous HLO operations mapped to commands.
-//===----------------------------------------------------------------------===//
-
-// Synchronous HLO operations can be wrapped into command buffers when they
-// have corresponding commands.
-
-// This is a template to define pattern matching functions for HLO instructions
-// that do not have a corresponding class for them.
-template <HloOpcode op>
-static bool IsCommand(const HloInstruction*, const CommandBufferConfig&);
-
-// While loops can be executed inside command buffers only if their condition
-// and body regions can be executed as command buffers.
-template <>
-bool IsCommand<HloOpcode::kWhile>(const HloInstruction* hlo,
-                                  const CommandBufferConfig& config) {
-  return config.enabled_commands.contains(DebugOptions::CONDITIONALS) &&
-         IsCommand(hlo->while_body(), config) &&
-         IsCommand(hlo->while_condition(), config);
-}
-
-// A conditional can be executed inside command buffers only if all of its
-// branch regions can be executed as command buffers.
-template <>
-bool IsCommand<HloOpcode::kConditional>(const HloInstruction* hlo,
-                                        const CommandBufferConfig& config) {
-  return config.enabled_commands.contains(DebugOptions::CONDITIONALS) &&
-         absl::c_all_of(hlo->branch_computations(),
-                        [&](const HloComputation* comp) {
-                          return IsCommand(comp, config);
-                        });
-}
-
-static bool IsCommand(const HloCustomCallInstruction* hlo,
-                      const CommandBufferConfig& config) {
-  // cuBLAS GEMMs are represented in the HLO as custom call instructions.
-  if (config.enabled_commands.contains(DebugOptions::CUBLAS) &&
-      IsLegacyCublasMatmul(*hlo)) {
-    return true;
-  }
-
-  if (config.enabled_commands.contains(DebugOptions::CUBLASLT) &&
-      (IsCublasLtMatmul(*hlo) || IsCublasLtMatmulF8(*hlo))) {
-    return true;
-  }
-
-  if (config.enabled_commands.contains(DebugOptions::CUDNN) &&
-      IsCustomCallTofMHA(*hlo)) {
-    VLOG(3) << "Recording FusedMHA, target " << hlo->custom_call_target()
-            << " into command buffer.";
-    return true;
-  }
-
-  if (!config.enabled_commands.contains(DebugOptions::CUSTOM_CALL)) {
-    return false;
-  }
-
-  if (config.enabled_legacy_custom_call_targets.contains(
-          hlo->custom_call_target())) {
-    VLOG(3) << "Recording legacy custom call target "
-            << hlo->custom_call_target() << " into command buffer.";
-    return true;
-  }
-
-  // A special case for the jax-triton kernel while it is not yet ported to FFI.
-  if (hlo->custom_call_target() == "triton_kernel_call" &&
-      // TODO(b/327718087): This is an ugly hack to prevent capturing triton
-      // custom calls that might do autotuning at run time.
-      !absl::StrContains(hlo->metadata().op_name(), "Autotuner")) {
-    return true;
-  }
-
-  // Check if FFI handler is compatible with command buffers.
-  auto registration = ffi::FindHandler(hlo->custom_call_target(), "gpu");
-  return registration.ok()
-             ? ffi::IsCommandBufferCompatible(registration->traits)
-             : false;
-}
-
-static bool IsCommand(const HloInstruction* hlo,
-                      const CommandBufferConfig& config) {
-  if (auto* fusion = DynCast<HloFusionInstruction>(hlo)) {
-    auto gpu_config = fusion->backend_config<GpuBackendConfig>();
-    const FusionBackendConfig& backend_config =
-        gpu_config->fusion_backend_config();
-    if (backend_config.kind() == kCuDnnFusionKind) {
-      return config.enabled_commands.contains(DebugOptions::CUDNN);
-    }
-    const auto& custom_config = backend_config.custom_fusion_config();
-    if (custom_config.name() == "address_computation") {
-      auto fusion_analysis =
-          HloFusionAnalysis::Create(fusion, &config.device_description);
-      const HloFusionAdaptor& adaptor = fusion_analysis.fusion();
-      auto custom_call_adaptor = HloBfsFindIf(
-          adaptor.GetRoots(), adaptor,
-          [](auto node) { return node.opcode() == HloOpcode::kCustomCall; });
-      const auto* custom_call = static_cast<const HloCustomCallInstruction*>(
-          &custom_call_adaptor->instruction());
-      return IsCommand(custom_call, config);
-    }
-    if (custom_config.name() == "dynamic_address_computation") {
-      return false;
-    }
-    return config.enabled_commands.contains(DebugOptions::FUSION);
-  }
-
-  if (auto* sort = DynCast<HloSortInstruction>(hlo))
-    return config.enabled_commands.contains(DebugOptions::FUSION);
-
-  if (hlo->opcode() == HloOpcode::kPartitionId ||
-      hlo->opcode() == HloOpcode::kReplicaId) {
-    return config.enabled_commands.contains(DebugOptions::FUSION);
-  }
-
-  if (auto* custom_call = DynCast<HloCustomCallInstruction>(hlo))
-    return IsCommand(custom_call, config);
-
-  if (hlo->opcode() == HloOpcode::kWhile)
-    return IsCommand<HloOpcode::kWhile>(hlo, config);
-
-  if (hlo->opcode() == HloOpcode::kConditional)
-    return IsCommand<HloOpcode::kConditional>(hlo, config);
-
-  return false;
-}
-
-//===----------------------------------------------------------------------===//
-// Asynchronous HLO operations mapped to commands.
-//===----------------------------------------------------------------------===//
-
-// Asynchronous HLO operations can be wrapped into command buffers only when
-// both the start and done operations can be put into the same command buffer.
-// Command buffer semantics imply that when command buffer execution
-// completes, all recorded commands are also completed, which means that if the
-// done operation is not part of the same command buffer, we would change the
-// execution semantics and create an additional synchronization point.
-
-static bool IsAsyncStartCommand(const HloInstruction* hlo,
-                                const CommandBufferConfig& config) {
-  if (hlo->opcode() == HloOpcode::kAllReduceStart ||
-      hlo->opcode() == HloOpcode::kAllGatherStart) {
-    return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
-  }
-
-  if (hlo->opcode() == HloOpcode::kAsyncStart) {
-    if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) {
-      return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
-    }
-  }
-
-  return false;
-}
-
-static bool IsAsyncDoneCommand(const HloInstruction* hlo,
-                               const CommandBufferConfig& config) {
-  if (hlo->opcode() == HloOpcode::kAllReduceDone ||
-      hlo->opcode() == HloOpcode::kAllGatherDone) {
-    return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
-  }
-
-  if (hlo->opcode() == HloOpcode::kAsyncDone) {
-    if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) {
-      return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
-    }
-  }
-
-  return false;
-}
-
-// Finds the async-done HLO operation corresponding to an async-start one.
-static HloInstruction* FindAsyncDoneCommand(const HloInstruction* start) {
-  if (start->opcode() == HloOpcode::kAllReduceStart ||
-      start->opcode() == HloOpcode::kAllGatherStart) {
-    CHECK(start->users().size() == 1);  // NOLINT, checked by HLO verifier
-    return start->users().front();
-  } else if (start->opcode() == HloOpcode::kAsyncStart) {
-    return start->async_chain_done();
-  }
-
-  return nullptr;
-}
-
-//===----------------------------------------------------------------------===//
-// HLO computations mapped to command buffers.
-//===----------------------------------------------------------------------===//
-
-// Returns true if HLO computation can be executed as a command buffer.
-static bool IsCommand(const HloComputation* computation,
-                      const CommandBufferConfig& config) {
-  return absl::c_all_of(
-      computation->instructions(), [&](const HloInstruction* inst) {
-        return IsNoOp(inst) || IsConstant(inst) || IsParameter(inst) ||
-               IsCommand(inst, config) || IsAsyncStartCommand(inst, config) ||
-               IsAsyncDoneCommand(inst, config);
-      });
-}
-
-//===----------------------------------------------------------------------===//
-
-static void RemoveTrailingNoOps(HloInstructionSequence& seq) {
-  std::vector<HloInstruction*> instructions = seq.instructions();
-  for (int i = instructions.size() - 1; i >= 0; i--) {
-    if (HloInstruction* inst = instructions[i]; IsNoOp(inst)) {
-      seq.remove_instruction(inst);
-    } else {
-      break;
-    }
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// Discovering sequences of compatible Hlo instructions
-//===----------------------------------------------------------------------===//
-
-// The input is a scheduled sequence of instructions. This function collects
-// subsequences that will be extracted as command buffers.
-std::vector<HloInstructionSequence>
-CommandBufferScheduling::CollectCommandBufferSequences(
-    const HloInstructionSequence schedule, const CommandBufferConfig& config,
-    int32_t min_num_commands) {
-  std::vector<HloInstructionSequence> sequences;
-
-  HloInstructionSequence current_seq;
-  int64_t num_commands_in_current_seq = 0;
-
-  // Adds `current_seq` to `sequences` if it has enough commands in it.
-  auto collect_current_seq = [&]() {
-    if (num_commands_in_current_seq >= std::max(1, min_num_commands)) {
-      RemoveTrailingNoOps(current_seq);
-      sequences.push_back(std::move(current_seq));
-    }
-    current_seq = HloInstructionSequence();
-    num_commands_in_current_seq = 0;
-  };
-
-  auto& instructions = schedule.instructions();
-
-  // Collect the sequence of instructions that contains the async start and its
-  // corresponding done instruction. If there is another start instruction
-  // between the original start and done, we may potentially extend the sequence
-  // to include its corresponding done instruction. For example, if we call this
-  // function on async_start_a in the following sequence:
-  //
-  // async_start_a
-  // async_start_b
-  // async_done_a
-  // async_done_b
-  //
-  // the returned sequence will contain async_done_b, so that all async pairs
-  // are captured by the same command buffer.
-  auto collect_async_region = [&](const HloInstruction* start) {
-    auto get_index = [&](const HloInstruction* inst) -> size_t {
-      auto it = std::find(instructions.begin(), instructions.end(), inst);
-      return std::distance(instructions.begin(), it);
-    };
-
-    HloInstructionSequence seq;
-    size_t done_index = get_index(FindAsyncDoneCommand(start));
-    for (size_t i = get_index(start); i <= done_index; i++) {
-      HloInstruction* inst = instructions.at(i);
-      if (IsAsyncStartCommand(inst, config)) {
-        const HloInstruction* done = FindAsyncDoneCommand(inst);
-        done_index = std::max(done_index, get_index(done));
-      }
-      seq.push_back(inst);
-    }
-    return seq;
-  };
-
-  // Check that instructions are safe to be captured by a command buffer, and
-  // that we do not capture an unmatched async-done instruction.
-  auto check_async_region = [&](const HloInstructionSequence& seq) {
-    if (!absl::c_all_of(seq.instructions(), [&](HloInstruction* inst) {
-          return IsNoOp(inst) || IsCommand(inst, config) ||
-                 IsAsyncStartCommand(inst, config) ||
-                 IsAsyncDoneCommand(inst, config);
-        })) {
-      return false;
-    }
-
-    absl::flat_hash_set<HloInstruction*> done_instructions;
-    for (const HloInstruction* inst : seq.instructions()) {
-      if (IsAsyncStartCommand(inst, config)) {
-        done_instructions.insert(FindAsyncDoneCommand(inst));
-      }
-      if (IsAsyncDoneCommand(inst, config)) {
-        if (!done_instructions.contains(inst)) {
-          return false;
-        }
-      }
-    }
-    return true;
-  };
-
-  for (size_t i = 0; i < instructions.size(); i++) {
-    HloInstruction* inst = instructions.at(i);
-
-    // We add no-op instructions to the current sequence only if they act as
-    // glue between commands. We do not create command sequences consisting only
-    // of no-op instructions. The first and last instructions in a command
-    // buffer are always load-bearing commands.
-    if (IsNoOp(inst) && num_commands_in_current_seq) {
-      current_seq.push_back(inst);
-      continue;
-    }
-
-    // Synchronous commands can always be added to the instruction sequence.
-    if (IsCommand(inst, config)) {
-      num_commands_in_current_seq++;
-      current_seq.push_back(inst);
-      continue;
-    }
-
-    // We capture async commands if all instructions between start and done can
-    // be outlined into a command buffer.
-    if (IsAsyncStartCommand(inst, config)) {
-      HloInstructionSequence seq = collect_async_region(inst);
-      if (check_async_region(seq)) {
-        num_commands_in_current_seq += seq.instructions().size();
-        for (HloInstruction* inst : seq.instructions()) {
-          current_seq.push_back(inst);
-        }
-        i += seq.instructions().size() - 1;
-        continue;
-      }
-    }
-
-    // If we didn't find the next command, collect the current sequence and
-    // start a new one.
-    collect_current_seq();
-  }
-
-  // Don't forget to collect the final command sequence.
-  collect_current_seq();
-  return sequences;
-}
-
-// This function moves kParameter and kConstant instructions in a computation to
-// the beginning of the computation. This simplifies the construction of command
-// buffer computations because we don't need to deal with parameters and
-// constants that have users outside of a command buffer.
-absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront(
-    HloComputation* computation) {
-  HloInstructionSequence new_sequence;
-  HloSchedule& schedule = computation->parent()->schedule();
-  HloInstructionSequence& sequence = schedule.GetOrCreateSequence(computation);
-
-  for (HloInstruction* inst : sequence.instructions()) {
-    if (IsParameter(inst) || IsConstant(inst)) {
-      new_sequence.push_back(inst);
-
-      // Because we move the instruction to the front of the computation it
-      // can't have any control predecessors; however, silently dropping them is
-      // unsafe as we can have transitive dependencies that define the schedule
-      // order, so we forward control predecessors to all users.
-      for (HloInstruction* control_predecessor : inst->control_predecessors()) {
-        for (HloInstruction* user : inst->users()) {
-          TF_RETURN_IF_ERROR(control_predecessor->AddControlDependencyTo(user));
-        }
-      }
-      TF_RETURN_IF_ERROR(inst->DropAllControlDeps());
-    }
-  }
-
-  for (HloInstruction* inst : sequence.instructions()) {
-    if (!IsParameter(inst) && !IsConstant(inst)) {
-      new_sequence.push_back(inst);
-    }
-  }
-
-  schedule.set_sequence(computation, new_sequence);
-  return absl::OkStatus();
-}
-
-//===----------------------------------------------------------------------===//
-// Prepares command buffer from sequence of instructions
-//===----------------------------------------------------------------------===//
-
-absl::StatusOr<CommandBuffer> CommandBufferScheduling::PrepareCommandBuffer(
-    const HloInstructionSequence& seq, HloModule* module) {
-  auto builder = HloComputation::Builder("command_buffer");
-
-  absl::Span<HloInstruction* const> instructions =
-      absl::MakeSpan(seq.instructions());
-
-  // A set of instructions that will be moved into command buffer computation.
-  absl::flat_hash_set<HloInstruction*> in_command_buffer(instructions.begin(),
-                                                         instructions.end());
-
-  // The sequence might use results of instructions that are not captured by the
-  // sequence. We pass those results as parameters and map the producers of the
-  // results to their corresponding parameter instructions.
-  absl::flat_hash_map<HloInstruction*, HloParameterInstruction*> parameters;
-
-  // Mapping from command buffer instructions to their clones in the command
-  // buffer computation body.
-  absl::flat_hash_map<HloInstruction*, HloInstruction*> inst_mapping;
-
-  // Maps HLO instructions in the original computation to instructions in the
-  // command buffer: (a) a parameter corresponding to a captured value, (b) a
-  // cloned instruction corresponding to a command.
-  auto mapped_operands = [&](HloInstruction* instr) {
-    absl::InlinedVector<HloInstruction*, 4> operands;
-    for (HloInstruction* operand : instr->operands()) {
-      if (auto it = inst_mapping.find(operand); it != inst_mapping.end())
-        operands.push_back(it->second);
-    }
-    return operands;
-  };
-
-  // Create parameters in the command buffer computation for captured values.
-  for (HloInstruction* inst : instructions) {
-    for (HloInstruction* operand : inst->operands()) {
-      // We have already mapped this operand to a parameter.
-      if (parameters.contains(operand)) continue;
-
-      // Operand instruction is a part of the command buffer.
-      if (in_command_buffer.contains(operand)) continue;
-
-      // Create a new parameter for value defined outside of a command buffer.
-      int64_t parameter_id = parameters.size();
-      auto* parameter = Cast<HloParameterInstruction>(
-          builder.AddInstruction(HloInstruction::CreateParameter(
-              parameter_id, operand->shape(), "p")));
-
-      parameter->UniquifyName(module);
-      parameter->UniquifyId(module);
-      inst_mapping[operand] = parameters[operand] = parameter;
-    }
-  }
-
-  // Clone commands into the command buffer body with mapped operands.
-  for (HloInstruction* inst : seq.instructions()) {
-    HloCloneContext ctx(inst->GetModule());
-
-    // Cloned instructions should call the same computations, as the original
-    // instructions will be dead-code eliminated.
-    for (HloComputation* called_computation : inst->called_computations()) {
-      // Async computations can only be referenced by a single async chain at
-      // a time. Detach the current chain to let its copy bind to the
-      // computation.
-      if (called_computation->IsAsyncComputation()) {
-        called_computation->RemoveAsyncStart();
-      }
-      ctx.MapComputation(called_computation, called_computation);
-    }
-
-    inst_mapping[inst] = builder.AddInstruction(
-        inst->CloneWithNewOperands(inst->shape(), mapped_operands(inst), &ctx));
-    inst_mapping[inst]->UniquifyId(module);
-  }
-
-  // Convert parameters to command buffer arguments.
-  std::vector<HloInstruction*> arguments(parameters.size());
-  for (auto& [argument, parameter] : parameters) {
-    arguments[parameter->parameter_number()] = argument;
-  }
-
-  // Collect command buffer `results` (instructions replaced in the original
-  // computation) and `returned` (instructions in the command buffer).
-  std::vector<HloInstruction*> results;
-  std::vector<HloInstruction*> returned;
-
-  auto has_external_users = [&](HloInstruction* inst) {
-    return inst->IsRoot() || absl::c_any_of(inst->users(), [&](auto* user) {
-             return !in_command_buffer.contains(user);
-           });
-  };
-
-  for (HloInstruction* inst : instructions) {
-    if (has_external_users(inst)) {
-      results.push_back(inst);
-      returned.push_back(inst_mapping[inst]);
-    }
-  }
-
-  // If we return multiple results, wrap them into a tuple.
-  if (returned.size() > 1) {
-    HloInstruction* inst =
-        builder.AddInstruction(HloInstruction::CreateTuple(returned));
-    inst->UniquifyName(module);
-    inst->UniquifyId(module);
-  }
-
-  std::unique_ptr<HloComputation> comp = builder.Build();
-  comp->UniquifyName(module);
-  comp->SetUniqueId(comp->root_instruction()->unique_id());
-
-  return CommandBuffer{std::move(arguments), std::move(results),
-                       std::move(comp), std::move(inst_mapping)};
-}
-
-//===----------------------------------------------------------------------===//
-// Rewrites original computation into command buffer call
-//===----------------------------------------------------------------------===//
-
-absl::StatusOr<HloComputation*> CommandBufferScheduling::RewriteCommandBuffer(
-    HloComputation* parent, const HloInstructionSequence& seq,
-    CommandBuffer command_buffer) {
-  if (command_buffer.results.empty())
-    return absl::InternalError("command buffer results must not be empty");
-
-  // If we have more than one result, we return them as a tuple and get
-  // individual values using `get-tuple-element` instructions. Otherwise we
-  // simply return the result from the command buffer computation.
-  Shape cmd_buffer_result_shape;
-  bool has_single_result = command_buffer.results.size() == 1;
-
-  if (has_single_result) {
-    cmd_buffer_result_shape = command_buffer.results[0]->shape();
-  } else {
-    absl::InlinedVector<Shape, 4> shapes;
-    shapes.reserve(command_buffer.results.size());
-    for (auto* res : command_buffer.results) shapes.push_back(res->shape());
-    cmd_buffer_result_shape = ShapeUtil::MakeTupleShape(shapes);
-  }
-
-  HloComputation* computation =
-      parent->parent()->AddComputation(std::move(command_buffer.computation),
-                                       /*is_entry=*/false);
-
-  HloInstruction* call = parent->AddInstruction(HloInstruction::CreateCall(
-      cmd_buffer_result_shape, command_buffer.arguments, computation));
-
-  // Replace all users of the original results with the command buffer results.
-  if (has_single_result) {
-    TF_RETURN_IF_ERROR(command_buffer.results[0]->ReplaceAllUsesWith(call));
-  } else {
-    for (int i = 0; i < command_buffer.results.size(); i++) {
-      TF_RETURN_IF_ERROR(
-          command_buffer.results[i]->ReplaceAllUsesWith(parent->AddInstruction(
-              HloInstruction::CreateGetTupleElement(call, i))));
-    }
-  }
-
-  // As we are running after scheduling, we have to keep the schedule valid.
-  HloSchedule& schedule = parent->parent()->schedule();
-
-  // Update schedule to replace the last instruction with a command buffer call.
-  // Removal of the rest of the instructions in the sequence is handled by
-  // schedule update below.
-  HloInstructionSequence& sequence = schedule.GetOrCreateSequence(parent);
-  sequence.replace_instruction(seq.instructions().back(), call);
-
-  // Rebuild the original instruction sequence schedule in the newly created
-  // command buffer computation to guarantee that we'll get exactly the same
-  // buffer assignment result as if we were running without command buffers.
-  HloInstructionSequence cmd_buffer_schedule;
-  for (auto* argument : command_buffer.arguments) {
-    cmd_buffer_schedule.push_back(command_buffer.inst_mapping[argument]);
-  }
-  for (auto* inst : seq.instructions()) {
-    cmd_buffer_schedule.push_back(command_buffer.inst_mapping[inst]);
-  }
-  if (!has_single_result) {
-    cmd_buffer_schedule.push_back(computation->root_instruction());
-  }
-  schedule.set_sequence(computation, cmd_buffer_schedule);
-
-  // Forward control dependencies between original instructions to instructions
-  // in the command buffer computation.
-  auto& inst_mapping = command_buffer.inst_mapping;
-  for (HloInstruction* inst : seq.instructions()) {
-    HloInstruction* cmd_inst = inst_mapping[inst];
-
-    // Forward control dependencies to the new instruction inside command
-    // buffer. If the dependent instruction is not captured by the command
-    // buffer, forward the dependency to the command buffer call instead.
-    for (HloInstruction* predecessor : inst->control_predecessors()) {
-      if (auto it = inst_mapping.find(predecessor); it != inst_mapping.end()) {
-        // If the predecessor mapped to a parameter instruction, it means that
-        // we need to forward the control dependency to the call operation;
-        // otherwise we add a control dependency between commands in the
-        // command buffer.
-        HloInstruction* cmd_predecessor = it->second;
-        if (IsParameter(cmd_predecessor)) {
-          TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(call));
-        } else {
-          TF_RETURN_IF_ERROR(cmd_predecessor->AddControlDependencyTo(cmd_inst));
-        }
-      } else {
-        TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(call));
-      }
-    }
-
-    for (HloInstruction* successor : inst->control_successors()) {
-      if (auto it = inst_mapping.find(successor); it != inst_mapping.end()) {
-        HloInstruction* cmd_successor = it->second;
-        TF_RETURN_IF_ERROR(cmd_inst->AddControlDependencyTo(cmd_successor));
-      } else {
-        TF_RETURN_IF_ERROR(call->AddControlDependencyTo(successor));
-      }
-    }
-
-    TF_RETURN_IF_ERROR(inst->DropAllControlDeps());
-  }
-
-  // Traverse in reverse order, as the original sequence was topologically
-  // sorted and we can't remove instructions that still have users.
-  for (int32_t i = seq.instructions().size() - 1; i >= 0; i--) {
-    TF_RETURN_IF_ERROR(parent->RemoveInstruction(seq.instructions()[i]));
-  }
-
-  return computation;
-}
-
-//===----------------------------------------------------------------------===//
-
-CommandBufferScheduling::CommandBufferScheduling(
-    const se::DeviceDescription& device_description,
-    int32_t gpu_toolkit_version, int32_t gpu_driver_version)
-    : device_description_(device_description),
-      gpu_toolkit_version_(gpu_toolkit_version),
-      gpu_driver_version_(gpu_driver_version) {}
-
-absl::StatusOr<bool> CommandBufferScheduling::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  // We run command buffer scheduling after regular scheduling to guarantee
-  // that command buffers will not change execution order and buffer assignment
-  // compared to a regular execution. Some operations (e.g. async collectives)
-  // can't be captured into command buffers, and forming too-large command
-  // buffers too early can impact the scheduling of async operations.
-  if (!module->has_schedule()) return Internal("module is not scheduled");
-
-  const DebugOptions& debug_options = module->config().debug_options();
-
-  absl::flat_hash_set<DebugOptions::CommandBufferCmdType> commands;
-  for (auto cmd_type : debug_options.xla_gpu_enable_command_buffer()) {
-    commands.insert(static_cast<DebugOptions::CommandBufferCmdType>(cmd_type));
-  }
-
-  absl::flat_hash_set<std::string> legacy_custom_call_targets;
-  for (const auto& target :
-       debug_options.legacy_command_buffer_custom_call_targets()) {
-    legacy_custom_call_targets.insert(target);
-  }
-
-  CommandBufferConfig config{std::move(commands),
-                             std::move(legacy_custom_call_targets),
-                             device_description_};
-
-  // Erase command buffer cmd types that are not supported by the gpu runtime.
-  static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONALS};
-  static constexpr auto kRequireTracing = {
-      DebugOptions::CUBLAS, DebugOptions::CUBLASLT, DebugOptions::CUDNN,
-      DebugOptions::CUSTOM_CALL, DebugOptions::COLLECTIVES};
-
-  auto erase = [&](absl::Span<const DebugOptions::CommandBufferCmdType> cmds) {
-    for (auto cmd : cmds) {
-      if (config.enabled_commands.erase(cmd)) {
-        VLOG(1) << "Removed command buffer support for "
-                << DebugOptions::CommandBufferCmdType_Name(cmd)
-                << " as it's not supported with gpu toolkit version "
-                << gpu_toolkit_version_ << " and driver version "
-                << gpu_driver_version_
-                << ". This might negatively impact performance. To enable "
-                << DebugOptions::CommandBufferCmdType_Name(cmd)
-                << " support in command buffers use cuda-compat package: "
-#if defined(PLATFORM_GOOGLE)
-                << "set CUDA_COMPAT_LOAD=1 env variable.";
-#else
-                << "https://docs.nvidia.com/deploy/cuda-compatibility/.";
-#endif
-      }
-    }
-  };
-
-  // Check if CUDA/ROCM driver supports required features.
-  auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
-    if (std::min(gpu_toolkit_version_, gpu_driver_version_) < 12030) {
-      erase(kRequireTracing);       // cuStreamBeginCaptureToGraph
-      erase(kRequireConditionals);  // on-device control flow
-    }
-  };
-  auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
-    erase(kRequireConditionals);  // on-device control flow
-  };
-
-  std::visit(VariantVisitor{erase_cuda, erase_rocm},
-             device_description_.gpu_compute_capability());
-
-  auto order = module->MakeComputationPostOrder();
-  std::reverse(order.begin(), order.end());
-  absl::flat_hash_set<HloComputation*> processed_command_buffers;
-
-  for (HloComputation* comp : order) {
-    // Skip special computations that do not have a lowering to thunks.
-    if (comp->IsFusionComputation() || comp->IsAsyncComputation() ||
-        comp->IsCustomCallComputation())
-      continue;
-
-    // Skip computations that are already part of command buffers.
-    if (processed_command_buffers.contains(comp)) continue;
-
-    TF_RETURN_IF_ERROR(MoveParametersAndConstantsToFront(comp));
-
-    std::vector<HloInstructionSequence> sequences =
-        CollectCommandBufferSequences(
-            module->schedule().sequence(comp), config,
-            debug_options.xla_gpu_graph_min_graph_size());
-
-    for (const HloInstructionSequence& seq : sequences) {
-      TF_ASSIGN_OR_RETURN(CommandBuffer command_buffer,
-                          PrepareCommandBuffer(seq, comp->parent()));
-      TF_ASSIGN_OR_RETURN(
-          HloComputation * command_buffer_computation,
-          RewriteCommandBuffer(comp, seq, std::move(command_buffer)));
-
-      // All computations reachable from a command buffer computation are nested
-      // command buffers (e.g. body computations attached to a while operation).
-      for (HloComputation* called :
-           command_buffer_computation->MakeEmbeddedComputationsList()) {
-        processed_command_buffers.insert(called);
-      }
-    }
-  }
-  TF_RETURN_IF_ERROR(module->schedule().Update());
-
-  return true;
-}
-
-}  // namespace xla::gpu
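The region-collection logic in `CollectCommandBufferSequences` above guarantees that an async start and its matching done always end up in the same command buffer. A minimal sketch of that idea, using a toy `ToyInstr` type and a hypothetical `CollectAsyncRegion` helper (illustrative assumptions, not XLA types): the captured region keeps growing whenever another start shows up before the original done.

// Standalone C++ sketch of the async-region extension; not XLA code.
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

struct ToyInstr {
  std::string name;
  int done_index = -1;  // For a start op: index of its matching done op.
  bool is_start() const { return done_index >= 0; }
};

// Returns the slice of `schedule` that starts at `start_index` and ends at the
// last done instruction reachable through overlapping start ops.
std::vector<ToyInstr> CollectAsyncRegion(const std::vector<ToyInstr>& schedule,
                                         size_t start_index) {
  size_t done_index = schedule[start_index].done_index;
  std::vector<ToyInstr> region;
  for (size_t i = start_index; i <= done_index; ++i) {
    if (schedule[i].is_start()) {
      // Another start before the current done: extend the region so that its
      // done is captured as well.
      done_index =
          std::max(done_index, static_cast<size_t>(schedule[i].done_index));
    }
    region.push_back(schedule[i]);
  }
  return region;
}

int main() {
  // start_a, start_b, done_a, done_b: the region for start_a also pulls in
  // done_b, matching the example in the deleted comment.
  std::vector<ToyInstr> schedule = {
      {"start_a", 2}, {"start_b", 3}, {"done_a"}, {"done_b"}};
  return CollectAsyncRegion(schedule, 0).size() == 4 ? 0 : 1;
}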
diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/command_buffer_scheduling.h
deleted file mode 100644
index 78590a8..0000000
--- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h
+++ /dev/null
@@ -1,143 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_
-#define XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_
-
-#include <cstdint>
-#include <memory>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla::gpu {
-
-// Lift fusion instructions to command buffers.
-//
-// Before the pass:
-//   %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
-//     ...
-//   }
-//
-//   ENTRY %main (a: s32[], b: s32[]) -> s32[] {
-//     %a = s32[] parameter(0)
-//     %b = s32[] parameter(1)
-//     ROOT %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop,
-//       calls=%fused_computation
-//   }
-//
-// After the pass:
-//   %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
-//     ...
-//   }
-//
-//   %command_buffer (param_0: s32[], param_1: s32[]) -> s32[] {
-//     %param_0 = s32[] parameter(0)
-//     %param_1 = s32[] parameter(1)
-//     ROOT %fusion = s32[] fusion(s32[] %param_0, s32[] %param_1), kind=kLoop,
-//       calls=%fused_computation
-//   }
-//
-//   ENTRY %main (a: s32[], b: s32[]) -> s32[] {
-//     %a = s32[] parameter(0)
-//     %b = s32[] parameter(1)
-//     ROOT %call = s32[] call(s32[] %a, s32[] %b), to_apply=%command_buffer
-//  }
-//
-// We currently do not have a command_buffer HLO operation, so we'll start with
-// a kCall opcode with an attached HLO computation. We'll consider graduating
-// custom call to a first-class operation later.
-class CommandBufferScheduling : public HloModulePass {
- public:
-  struct CommandBufferConfig {
-    // DebugOptions control which commands are enabled. Long term, we want to
-    // remove this flag and enable all supported commands by default.
-    absl::flat_hash_set<DebugOptions::CommandBufferCmdType> enabled_commands;
-    absl::flat_hash_set<std::string> enabled_legacy_custom_call_targets;
-    const se::DeviceDescription& device_description;
-  };
-
-  CommandBufferScheduling(const se::DeviceDescription& device_description,
-                          int32_t gpu_toolkit_version,
-                          int32_t gpu_driver_version);
-
-  absl::string_view name() const override {
-    return "command-buffer-scheduling";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
-  static std::vector<HloInstructionSequence> CollectCommandBufferSequences(
-      HloInstructionSequence schedule, const CommandBufferConfig& config,
-      int32_t min_num_commands = 1);
-
-  // Moves kParameter and kConstant instructions in a computation to
-  // the beginning of the computation. This simplifies the construction of
-  // command buffer computations because we don't need to deal with parameters
-  // and constants that have users outside of a command buffer.
-  static absl::Status MoveParametersAndConstantsToFront(
-      HloComputation* computation);
-
-  struct CommandBuffer {
-    // Command buffer arguments (call instruction arguments).
-    std::vector<HloInstruction*> arguments;
-
-    // Command buffer results (call instruction result tuple).
-    std::vector<HloInstruction*> results;
-
-    // Hlo computation corresponding to a command buffer body.
-    std::unique_ptr<HloComputation> computation;
-
-    // Mapping from original instructions to their clones in the command buffer.
-    absl::flat_hash_map<HloInstruction*, HloInstruction*> inst_mapping;
-  };
-
-  // Prepares a command buffer from the instruction sequence. Values constructed
-  // by instructions outside of the sequence but used inside it are passed in as
-  // parameters. Results of instructions in the sequence are returned in a tuple
-  // (if the command buffer has a single result we don't wrap it into a tuple).
-  static absl::StatusOr<CommandBuffer> PrepareCommandBuffer(
-      const HloInstructionSequence& seq, HloModule* module);
-
-  // Rewrites the prepared command buffer computation into HLO operations in the
-  // parent computation (calls the command buffer and replaces all users).
-  static absl::StatusOr<HloComputation*> RewriteCommandBuffer(
-      HloComputation* parent, const HloInstructionSequence& seq,
-      CommandBuffer command_buffer);
-
- private:
-  se::DeviceDescription device_description_;
-  // For NVIDIA GPUs, XLA can be compiled with a CUDA version that is newer than
-  // the version supported by the driver, e.g. we can compile for CUDA 12.3 but
-  // have a 12.1 driver installed. When deciding which command buffer features
-  // we can use, we have to consider both versions.
-  int32_t gpu_toolkit_version_;
-  int32_t gpu_driver_version_;
-};
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_
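The comment on the private members above encodes a conservative rule: a command buffer feature is usable only if both the CUDA toolkit XLA was compiled against and the installed driver support it, i.e. the effective version is the minimum of the two. Below is a minimal sketch of that check, assuming the 12030 (CUDA 12.3) threshold used by the deleted command_buffer_scheduling.cc; `SupportsGraphTracing` is a hypothetical helper, not an XLA API.

// Standalone C++ sketch mirroring the erase_cuda version gate; not XLA code.
#include <algorithm>
#include <cstdint>

bool SupportsGraphTracing(int32_t gpu_toolkit_version,
                          int32_t gpu_driver_version) {
  // Graph tracing (cuStreamBeginCaptureToGraph) and on-device conditionals
  // require CUDA 12.3, encoded as 12030; both the compile-time toolkit and the
  // runtime driver must be at least that new.
  return std::min(gpu_toolkit_version, gpu_driver_version) >= 12030;
}

int main() {
  // Compiled against CUDA 12.3 (12030) but running on a 12.1 (12010) driver:
  // tracing-based commands and conditionals stay disabled.
  return SupportsGraphTracing(12030, 12010) ? 1 : 0;
}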
diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc
deleted file mode 100644
index 3a46193..0000000
--- a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc
+++ /dev/null
@@ -1,1018 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/command_buffer_scheduling.h"
-
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-class CommandBufferSchedulingTest : public HloTestBase {
- public:
-  // Use CUDA version 12.3 for testing, as it has all the features we rely on.
-  static constexpr int32_t kCudaVersion = 12030;
-
-  const se::DeviceDescription& device_desc() {
-    return backend().default_stream_executor()->GetDeviceDescription();
-  }
-
-  DebugOptions GetDebugOptionsForTest() override {
-    auto debug_options = HloTestBase::GetDebugOptionsForTest();
-    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
-    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONALS);
-    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
-    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
-    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLASLT);
-    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL);
-    debug_options.set_xla_gpu_graph_min_graph_size(2);
-    return debug_options;
-  }
-};
-
-using CommandBuffer = CommandBufferScheduling::CommandBuffer;
-
-TEST_F(CommandBufferSchedulingTest, SingleCommandBuffer) {
-  const char* hlo = R"(
-      HloModule TestModule, is_scheduled=true
-
-      %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      ENTRY %main (a: s32[], b: s32[]) -> s32[] {
-        %a = s32[] parameter(0)
-        %b = s32[] parameter(1)
-        %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
-        %fusion.1 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.1
-        ROOT %custom-call = s32[] custom-call(s32[] %fusion, s32[] %fusion.1), custom_call_target="some target"
-      })";
-
-  const char* expected = R"(
-// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> (s32[], s32[]) {
-// CHECK:   %[[P0]] = s32[] parameter(0)
-// CHECK:   %[[P1]] = s32[] parameter(1)
-// CHECK:   %fusion = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
-// CHECK:   %fusion.1 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1
-// CHECK:   ROOT %tuple = (s32[], s32[]) tuple(%fusion, %fusion.1)
-// CHECK: }
-//
-// CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
-// CHECK:   %a = s32[] parameter(0)
-// CHECK:   %b = s32[] parameter(1)
-// CHECK:   %call = (s32[], s32[]) call(%a, %b), to_apply=%command_buffer
-// CHECK:   %get-tuple-element = s32[] get-tuple-element(%call), index=0
-// CHECK:   %get-tuple-element.1 = s32[] get-tuple-element(%call), index=1
-// CHECK:   ROOT %custom-call = s32[] custom-call(%get-tuple-element, %get-tuple-element.1), custom_call_target="some target"
-// CHECK: })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, MultipleCommandBuffers) {
-  const char* hlo = R"(
-      HloModule TestModule, is_scheduled=true
-
-      %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      %fused_computation.2(param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      %fused_computation.3(param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] {
-        %a = s32[] parameter(0)
-        %b = s32[] parameter(1)
-        %c = (s32[], s32[]) parameter(2)
-        %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
-        %d = s32[] get-tuple-element((s32[], s32[]) %c), index=0
-        %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %d), kind=kLoop, calls=%fused_computation.1
-        %e = s32[] get-tuple-element((s32[], s32[]) %c), index=1
-        %custom-call = s32[] custom-call(s32[] %fusion.1, s32[] %e), custom_call_target="some target"
-        %fusion.2 = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.2
-        %fusion.3 = s32[] fusion(s32[] %custom-call, s32[] %fusion.2), kind=kLoop, calls=%fused_computation.3
-        ROOT %custom-call.1 = s32[] custom-call(s32[] %fusion.3), custom_call_target="some target"
-      })";
-
-  const char* expected = R"(
-// CHECK:  %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[], [[P2:.+]]: (s32[], s32[])) -> s32[] {
-// CHECK:    %[[P0]] = s32[] parameter(0)
-// CHECK:    %[[P1]] = s32[] parameter(1)
-// CHECK:    %[[P2]] = (s32[], s32[]) parameter(2)
-// CHECK:    %[[F0:.+]] = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
-// CHECK:    %[[V0:.+]] = s32[] get-tuple-element(%[[P2]]), index=0
-// CHECK:    ROOT {{.*}} = s32[] fusion(%[[F0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1
-// CHECK:  }
-
-// CHECK:  %command_buffer.2 ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] {
-// CHECK:    %[[P0]] = s32[] parameter(0)
-// CHECK:    %[[P1]] = s32[] parameter(1)
-// CHECK:    %[[F2:.+]] = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.2
-// CHECK:    ROOT {{.*}} = s32[] fusion(%[[P0]], %[[F2]]), kind=kLoop, calls=%fused_computation.3
-// CHECK:  }
-
-// CHECK:  ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] {
-// CHECK:    %a = s32[] parameter(0)
-// CHECK:    %b = s32[] parameter(1)
-// CHECK:    %c = (s32[], s32[]) parameter(2)
-// CHECK:    %[[CMD0:.+]] = s32[] call(%a, %b, %c), to_apply=%command_buffer
-// CHECK:    %e = s32[] get-tuple-element(%c), index=1
-// CHECK:    %[[CALL:.+]] = s32[] custom-call(%[[CMD0]], %e), custom_call_target="some target"
-// CHECK:    %[[CMD1:.+]] = s32[] call(%[[CALL]], %a), to_apply=%command_buffer.2
-// CHECK:    ROOT {{.*}} = s32[] custom-call(%[[CMD1]]), custom_call_target="some target"
-// CHECK:  })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByDone) {
-  const char* hlo = R"(
-    HloModule TestModule, is_scheduled=true
-
-    %add (p0: s32[4], p1: s32[4]) -> s32[4] {
-      %p0 = s32[4] parameter(0)
-      %p1 = s32[4] parameter(1)
-      ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
-    }
-
-    ENTRY %main (a: s32[4]) -> s32[4] {
-      %a = s32[4] parameter(0)
-      %start = s32[4]{0} all-reduce-start(s32[4]{0} %a),
-        replica_groups={{0,1}}, to_apply=%add,
-        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
-      ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start)
-    })";
-
-  const char* expected = R"(
-    CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] {
-    CHECK:   %[[P0]] = s32[4]{0} parameter(0)
-    CHECK:   %[[START:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
-    CHECK:   ROOT %[[DONE:.+]] = s32[4]{0} all-reduce-done(%[[START]])
-    CHECK: }
-
-    CHECK: ENTRY %main (a: s32[4]) -> s32[4] {
-    CHECK:   %[[A:.+]] = s32[4]{0} parameter(0)
-    CHECK:   ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
-    CHECK:     to_apply=%command_buffer
-    CHECK: })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, AllGatherStartFollowedByDone) {
-  const char* hlo = R"(
-    HloModule TestModule, is_scheduled=true
-
-    ENTRY %main (a: s32[2]) -> s32[4] {
-      %a = s32[2] parameter(0)
-
-      %start = (s32[2]{0}, s32[4]{0}) all-gather-start(%a),
-        channel_id=555, replica_groups={{0,1}}, dimensions={0},
-        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
-
-      ROOT %done = s32[4]{0} all-gather-done(%start)
-    })";
-
-  const char* expected = R"(
-    CHECK: %command_buffer ([[P0:.+]]: s32[2]) -> s32[4] {
-    CHECK:   %[[P0]] = s32[2]{0} parameter(0)
-    CHECK:   %[[START:.+]] = {{.*}} all-gather-start(%[[P0]])
-    CHECK:   ROOT %[[DONE:.+]] = s32[4]{0} all-gather-done(%[[START]])
-    CHECK: }
-
-    CHECK: ENTRY %main (a: s32[2]) -> s32[4] {
-    CHECK:   %[[A:.+]] = s32[2]{0} parameter(0)
-    CHECK:   ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
-    CHECK:     to_apply=%command_buffer
-    CHECK: })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, ReduceScatterStartFollowedByDone) {
-  const char* hlo = R"(
-    HloModule TestModule, is_scheduled=true
-
-    %add (p0: s32[], p1: s32[]) -> s32[] {
-      %p0 = s32[] parameter(0)
-      %p1 = s32[] parameter(1)
-      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-    }
-
-    ENTRY %main (a: s32[4]) -> s32[2] {
-      %a = s32[4] parameter(0)
-
-      %start = ((s32[4]{0}), s32[2]{0}) reduce-scatter-start(%a),
-        channel_id=555, replica_groups={{0,1}}, dimensions={0}, to_apply=add,
-        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
-
-      ROOT %done = s32[2]{0} reduce-scatter-done(%start)
-    })";
-
-  const char* expected = R"(
-    CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[2] {
-    CHECK:   %[[P0]] = s32[4]{0} parameter(0)
-    CHECK:   %[[START:.+]] = {{.*}} reduce-scatter-start(%[[P0]])
-    CHECK:   ROOT %[[DONE:.+]] = s32[2]{0} reduce-scatter-done(%[[START]])
-    CHECK: }
-
-    CHECK: ENTRY %main (a: s32[4]) -> s32[2] {
-    CHECK:   %[[A:.+]] = s32[4]{0} parameter(0)
-    CHECK:   ROOT %[[CALL:.+]] = s32[2]{0} call(%[[A]]),
-    CHECK:     to_apply=%command_buffer
-    CHECK: })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByBitcast) {
-  const char* hlo = R"(
-    HloModule TestModule, is_scheduled=true
-
-    %add (p0: s32[4], p1: s32[4]) -> s32[4] {
-      %p0 = s32[4] parameter(0)
-      %p1 = s32[4] parameter(1)
-      ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
-    }
-
-    ENTRY %main (a: s32[4]) -> s32[4] {
-      %a = s32[4] parameter(0)
-      %start = s32[4]{0} all-reduce-start(s32[4]{0} %a),
-        replica_groups={{0,1}}, to_apply=%add,
-        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
-      %bitcast = s32[4] bitcast(s32[4]{0} %a)
-      ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start)
-    })";
-
-  const char* expected = R"(
-    CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] {
-    CHECK:   %[[P0]] = s32[4]{0} parameter(0)
-    CHECK:   %[[START:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
-    CHECK:   %[[BITCAST:.+]] = s32[4]{0} bitcast(%[[P0]])
-    CHECK:   ROOT %[[DONE:.+]] = s32[4]{0} all-reduce-done(%[[START]])
-    CHECK: }
-
-    CHECK: ENTRY %main (a: s32[4]) -> s32[4] {
-    CHECK:   %[[A:.+]] = s32[4]{0} parameter(0)
-    CHECK:   ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
-    CHECK:     to_apply=%command_buffer
-    CHECK: })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedAllReduceStart) {
-  const char* hlo = R"(
-    HloModule TestModule, is_scheduled=true
-
-    %add (p0: s32[4], p1: s32[4]) -> s32[4] {
-      %p0 = s32[4] parameter(0)
-      %p1 = s32[4] parameter(1)
-      ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
-    }
-
-    ENTRY %main (a: s32[4]) -> s32[4] {
-      %a = s32[4] parameter(0)
-      %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
-        replica_groups={{0,1}}, to_apply=%add,
-        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
-      %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
-        replica_groups={{0,1}}, to_apply=%add,
-        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
-      %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1)
-      ROOT %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2)
-    })";
-
-  const char* expected = R"(
-    CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] {
-    CHECK:   %[[P0]] = s32[4]{0} parameter(0)
-    CHECK:   %[[START1:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
-    CHECK:   %[[START2:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
-    CHECK:   %[[DONE1:.+]] = s32[4]{0} all-reduce-done(%[[START1]])
-    CHECK:   ROOT %[[DONE2:.+]] = s32[4]{0} all-reduce-done(%[[START2]])
-    CHECK: }
-
-    CHECK: ENTRY %main (a: s32[4]) -> s32[4] {
-    CHECK:   %[[A:.+]] = s32[4]{0} parameter(0)
-    CHECK:   ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
-    CHECK:     to_apply=%command_buffer
-    CHECK: })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, DoNotCaptureUnmatchedAsyncDone) {
-  const char* hlo = R"(
-    HloModule TestModule, is_scheduled=true
-
-    %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] {
-      %p0 = s32[] parameter(0)
-      %p1 = s32[] parameter(1)
-      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-    }
-
-    %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
-      %p0 = s32[] parameter(0)
-      %p1 = s32[] parameter(1)
-      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-    }
-
-    %add (p0: s32[4], p1: s32[4]) -> s32[4] {
-      %p0 = s32[4] parameter(0)
-      %p1 = s32[4] parameter(1)
-      ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
-    }
-
-    ENTRY %main (a: s32[4], b:s32[]) -> s32[] {
-      %a = s32[4] parameter(0)
-      %b = s32[] parameter(1)
-      %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
-        replica_groups={{0,1}}, to_apply=%add,
-        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
-      %c = s32[] custom-call(), custom_call_target="target"
-      %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
-        replica_groups={{0,1}}, to_apply=%add,
-        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
-      %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1)
-      %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2)
-      %fusion = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation
-      ROOT %fusion.1 = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation.1
-    })";
-
-  const char* expected = R"(
-    CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] {
-    CHECK:   %[[P0]] = s32[] parameter(0)
-    CHECK:   %[[P1]] = s32[] parameter(1)
-    CHECK:   %fusion = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
-    CHECK:   ROOT %fusion.1 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1
-    CHECK: }
-
-    CHECK: ENTRY %main (a: s32[4], b: s32[]) -> s32[] {
-    CHECK:   %[[A:.+]] = s32[4]{0} parameter(0)
-    CHECK:   %[[B:.+]] = s32[] parameter(1)
-    CHECK:   %[[START1:.+]] = s32[4]{0} all-reduce-start(%[[A]])
-    CHECK:   %[[C:.+]] = s32[] custom-call()
-    CHECK:   %[[START2:.+]] = s32[4]{0} all-reduce-start(%[[A]])
-    CHECK:   %[[DONE1:.+]] = s32[4]{0} all-reduce-done(%[[START1]])
-    CHECK:   %[[DONE2:.+]] = s32[4]{0} all-reduce-done(%[[START2]])
-    CHECK:   %call = s32[] call(%b, %c), to_apply=%command_buffer
-    CHECK: })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, CollectCommandBufferSequence) {
-  const char* hlo = R"(
-      HloModule TestModule, is_scheduled=true
-
-      %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      %fused_computation.2(param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      %fused_computation.3(param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] {
-        %a = s32[] parameter(0)
-        %b = s32[] parameter(1)
-        %c = (s32[], s32[]) parameter(2)
-        %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
-        %d = s32[] get-tuple-element((s32[], s32[]) %c), index=0
-        %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %d), kind=kLoop, calls=%fused_computation.1
-        %e = s32[] get-tuple-element((s32[], s32[]) %c), index=1
-        %custom-call = s32[] custom-call(s32[] %fusion.1, s32[] %e), custom_call_target="some target"
-        %fusion.2 = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.2
-        ROOT %fusion.3 = s32[] fusion(s32[] %custom-call, s32[] %fusion.2), kind=kLoop, calls=%fused_computation.3
-      })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(hlo));
-
-  HloInstructionSequence seq;
-  for (HloInstruction* x : module->entry_computation()->instructions()) {
-    seq.push_back(x);
-  }
-  EXPECT_EQ(seq.size(), 10);
-
-  CommandBufferScheduling::CommandBufferConfig config{
-      {DebugOptions::FUSION}, {}, device_desc()};
-
-  std::vector<HloInstructionSequence> command_buffer_sequences =
-      CommandBufferScheduling::CollectCommandBufferSequences(seq, config);
-  EXPECT_EQ(command_buffer_sequences.size(), 2);
-
-  std::vector<HloInstruction*> seq_0 =
-      command_buffer_sequences[0].instructions();
-  EXPECT_EQ(seq_0.size(), 3);
-  EXPECT_EQ(seq_0[0]->opcode(), HloOpcode::kFusion);
-  EXPECT_EQ(seq_0[1]->opcode(), HloOpcode::kGetTupleElement);
-  EXPECT_EQ(seq_0[2]->opcode(), HloOpcode::kFusion);
-
-  std::vector<HloInstruction*> seq_1 =
-      command_buffer_sequences[1].instructions();
-  EXPECT_EQ(seq_1.size(), 2);
-  EXPECT_EQ(seq_1[0]->opcode(), HloOpcode::kFusion);
-  EXPECT_EQ(seq_1[1]->opcode(), HloOpcode::kFusion);
-}
-
-TEST_F(CommandBufferSchedulingTest, MoveParametersToFront) {
-  const char* hlo = R"(
-      HloModule TestModule, is_scheduled=true
-
-      %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      ENTRY %main (a: s32[], b: s32[], c: s32[]) -> s32[] {
-        %a = s32[] parameter(0)
-        %b = s32[] parameter(1)
-        %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
-        %c = s32[] parameter(2)
-        ROOT %fusion.1 = s32[] fusion(s32[] %a, s32[] %c), kind=kLoop, calls=%fused_computation.1
-      })";
-
-  const char* expected = R"(
-// CHECK: ENTRY %main (a: s32[], b: s32[], c: s32[]) -> s32[] {
-// CHECK:   %a = s32[] parameter(0)
-// CHECK:   %b = s32[] parameter(1)
-// CHECK:   %c = s32[] parameter(2)
-// CHECK:   %fusion = s32[] fusion(%a, %b), kind=kLoop, calls=%fused_computation
-// CHECK:   ROOT %fusion.1 = s32[] fusion(%a, %c), kind=kLoop, calls=%fused_computation.1
-// CHECK: })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(hlo));
-  TF_ASSERT_OK(CommandBufferScheduling::MoveParametersAndConstantsToFront(
-      module->entry_computation()));
-  TF_ASSERT_OK_AND_ASSIGN(
-      bool filecheck_matches,
-      RunFileCheck(
-          module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
-          expected));
-  EXPECT_TRUE(filecheck_matches);
-}
-
-TEST_F(CommandBufferSchedulingTest, PrepareCommandBuffer) {
-  const char* hlo = R"(
-      HloModule TestModule, is_scheduled=true
-
-      %fused_computation(param_0: s32[], param_1: s32[]) -> (s32[], s32[]) {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %tuple.1 = (s32[], s32[]) tuple(s32[] %p0, s32[] %p1)
-      }
-
-      %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
-        %p0 = s32[] parameter(0)
-        %p1 = s32[] parameter(1)
-        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-      }
-
-      ENTRY %main (a: s32[], b: s32[]) -> s32[] {
-        %a = s32[] parameter(0)
-        %b = s32[] custom-call(), custom_call_target="target"
-        %fusion = (s32[], s32[]) fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
-        %d = s32[] get-tuple-element((s32[], s32[]) %fusion), index=0
-        %fusion.1 = s32[] fusion(s32[] %a, s32[] %d), kind=kLoop, calls=%fused_computation.1
-        ROOT %custom-call = s32[] custom-call(s32[] %fusion.1, s32[] %d), custom_call_target="some target"
-      })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule(hlo));
-
-  EXPECT_EQ(module->entry_computation()->instruction_count(), 6);
-  std::vector<HloInstruction*> instructions;
-  HloInstructionSequence seq;
-  for (HloInstruction* inst : module->entry_computation()->instructions()) {
-    if (inst->opcode() == HloOpcode::kFusion ||
-        inst->opcode() == HloOpcode::kGetTupleElement) {
-      seq.push_back(inst);
-    }
-    instructions.push_back(inst);
-  }
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      CommandBuffer command_buffer,
-      CommandBufferScheduling::PrepareCommandBuffer(seq, module.get()));
-  HloComputation* computation = module->AddComputation(
-      std::move(command_buffer.computation), /*is_entry=*/false);
-
-  const char* expected = R"(
-// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> (s32[], s32[]) {
-// CHECK:  %[[P0]] = s32[] parameter(0)
-// CHECK:  %[[P1]] = s32[] parameter(1)
-// CHECK:  %fusion = (s32[], s32[]) fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
-// CHECK:  %[[V0:.+]] = s32[] get-tuple-element(%fusion), index=0
-// CHECK:  %fusion.1 = s32[] fusion(%[[P0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1
-// CHECK:  ROOT {{.*}} = (s32[], s32[]) tuple(%[[V0]], %fusion.1)
-// CHECK:})";
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      bool filecheck_matches,
-      RunFileCheck(computation->ToString(
-                       HloPrintOptions{}.set_print_operand_shape(false)),
-                   expected));
-  EXPECT_TRUE(filecheck_matches);
-
-  auto& arguments = command_buffer.arguments;
-  ASSERT_EQ(arguments.size(), 2);
-  EXPECT_EQ(arguments[0], instructions[0]);
-  EXPECT_EQ(arguments[1], instructions[1]);
-
-  auto& results = command_buffer.results;
-  ASSERT_EQ(results.size(), 2);
-  EXPECT_EQ(results[0], instructions[3]);
-  EXPECT_EQ(results[1], instructions[4]);
-}
-
-TEST_F(CommandBufferSchedulingTest, ForwardControlDependencies) {
-  const char* hlo = R"(
-    HloModule TestModule, is_scheduled=true
-
-    %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
-      %p0 = s32[] parameter(0)
-      %p1 = s32[] parameter(1)
-      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-    }
-
-    %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
-      %p0 = s32[] parameter(0)
-      %p1 = s32[] parameter(1)
-      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-    }
-
-    %fused_computation.2 (param_0: s32[], param_1: s32[]) -> s32[] {
-      %p0 = s32[] parameter(0)
-      %p1 = s32[] parameter(1)
-      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-    }
-
-    ENTRY %main (a: s32[], b: s32[]) -> s32[] {
-      %a = s32[] parameter(0)
-      %b = s32[] parameter(1)
-      %custom-call = s32[] custom-call(), custom_call_target="some target"
-      %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation, control-predecessors={%custom-call}
-      %fusion.1 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.1, control-predecessors={%fusion}
-      %custom-call.1 = s32[] custom-call(), custom_call_target="some target"
-      %fusion.2 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%fusion.1}
-      ROOT %custom-call.2 = s32[] custom-call(s32[] %fusion.1, s32[] %fusion.2), custom_call_target="some target"
-    })";
-
-  const char* expected = R"(
-    CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] {
-    CHECK:   %[[P0]] = s32[] parameter(0)
-    CHECK:   %[[P1]] = s32[] parameter(1)
-    CHECK:   %[[F0:.+]] = s32[] fusion(%[[P0]], %[[P1]])
-    CHECK:   ROOT {{.*}} = s32[] fusion(%[[P0]], %[[P1]]), {{.*}} control-predecessors={%[[F0]]}
-    CHECK: }
-
-    CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
-    CHECK:   %a = s32[] parameter(0)
-    CHECK:   %b = s32[] parameter(1)
-    CHECK:   %custom-call = s32[] custom-call(), custom_call_target="some target"
-    CHECK:   %call = s32[] call(%a, %b), to_apply=%command_buffer, control-predecessors={%custom-call}
-    CHECK:   %custom-call.1 = s32[] custom-call(), custom_call_target="some target"
-    CHECK:   %[[F3:.+]] = s32[] fusion(%a, %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%call}
-    CHECK:   ROOT %custom-call.2 = s32[] custom-call(%call, %[[F3]]), custom_call_target="some target"
-    CHECK: })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, ForwardControlDependenciesToParams) {
-  const char* hlo = R"(
-    HloModule TestModule, is_scheduled=true
-
-    %fused_computation.0 (p0: s32[], p1: s32[]) -> s32[] {
-      %p0 = s32[] parameter(0)
-      %p1 = s32[] parameter(1)
-      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-    }
-
-    %fused_computation.1 (p0: s32[], p1: s32[]) -> s32[] {
-      %p0 = s32[] parameter(0)
-      %p1 = s32[] parameter(1)
-      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
-    }
-
-    ENTRY %main (a: s32[], b: s32[]) -> s32[] {
-      %a = s32[] parameter(0)
-      %b = s32[] parameter(1)
-      %custom-call = s32[] custom-call(), custom_call_target="some target"
-      %fusion = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.0, control-predecessors={%custom-call}
-      ROOT %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %b), kind=kLoop, calls=%fused_computation.1
-    })";
-
-  const char* expected = R"(
-    CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
-    CHECK:   %a = s32[] parameter(0)
-    CHECK:   %b = s32[] parameter(1)
-    CHECK:   %[[CUSTOM_CALL:.+]] = s32[] custom-call(), custom_call_target="some target"
-    CHECK:   ROOT {{.*}} call(%[[CUSTOM_CALL]], %a, %b), to_apply=%command_buffer, control-predecessors={%[[CUSTOM_CALL]]}
-    CHECK: })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, WhileNotCommand) {
-  const char* hlo = R"(
-    HloModule TestModule, is_scheduled=true
-
-    %fused_computation (param_0: f32[1]) -> f32[1] {
-      %param_0 = f32[1]{0} parameter(0)
-      ROOT %copy.5 = f32[1]{0} copy(f32[1]{0} %param_0)
-    }
-
-    %fused_computation.1 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] {
-      %param_0.1 = f32[1]{0} parameter(0)
-      %param_1 = f32[1]{0} parameter(1)
-      ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1)
-    }
-
-    %fused_computation.2 (param_0.2: f32[1], param_1.1: f32[1]) -> pred[1] {
-      %param_0.2 = f32[1]{0} parameter(0)
-      %param_1.1 = f32[1]{0} parameter(1)
-      ROOT %compare.3 = pred[1]{0} compare(f32[1]{0} %param_0.2, f32[1]{0} %param_1.1), direction=LT
-    }
-
-    %fused_computation.3 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] {
-      %param_0.1 = f32[1]{0} parameter(0)
-      %param_1 = f32[1]{0} parameter(1)
-      ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1)
-    }
-
-    %body (Arg_.3: f32[1]) -> f32[1] {
-      %constant_4 = f32[1]{0} constant({1})
-      %Arg_.3 = f32[1]{0} parameter(0)
-      %custom-call = s32[] custom-call(), custom_call_target="some target"
-      %add = f32[1]{0} fusion(f32[1]{0} %Arg_.3, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.1, control-predecessors={%custom-call}
-      ROOT %wrapped_add.1 = f32[1]{0} fusion(f32[1]{0} %add, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.3, control-predecessors={%custom-call}
-    }
-
-    %cond (Arg_.11: f32[1]) -> pred[] {
-      %constant = f32[1]{0} constant({100})
-      %Arg_.11 = f32[1]{0} parameter(0)
-      %wrapped_compare.2 = pred[1]{0} fusion(f32[1]{0} %Arg_.11, f32[1]{0} %constant), kind=kLoop, calls=%fused_computation.2
-      ROOT %bitcast = pred[] bitcast(pred[1]{0} %wrapped_compare.2)
-    }
-
-    ENTRY %main.18 (Arg_0.1: f32[1]) -> f32[] {
-      %Arg_0.1 = f32[1]{0} parameter(0), sharding={replicated}
-      %wrapped_copy.4 = f32[1]{0} fusion(f32[1]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation
-      %while.16 = f32[1]{0} while(f32[1]{0} %wrapped_copy.4), condition=%cond, body=%body
-      ROOT %bitcast.1 = f32[] bitcast(f32[1]{0} %while.16)
-    })";
-
-  const char* expected = R"(
-    CHECK: %command_buffer ([[P0:.+]]: f32[1], [[P1:.+]]: f32[1]) -> f32[1] {
-    CHECK:   %[[P0]] = f32[1]{0} parameter(0)
-    CHECK:   %[[P1]] = f32[1]{0} parameter(1)
-    CHECK:   %[[ADD:.*]] = f32[1]{0} fusion(%[[P0]], %[[P1]]), kind=kLoop
-    CHECK:   ROOT {{.*}} = f32[1]{0} fusion(%[[ADD]], %[[P1]]), kind=kLoop
-    CHECK: }
-
-    CHECK: %[[BODY:[a-z_0-9.]+]] ([[P0:.+]]: f32[1]) -> f32[1] {
-    CHECK:   %[[C1:.*]] = f32[1]{0} constant({1})
-    CHECK:   %[[P0]] = f32[1]{0} parameter(0)
-    CHECK:   %[[CC:.*]] = s32[] custom-call(), custom_call_target="some target"
-    CHECK:   ROOT %call = f32[1]{0} call(%[[P0]], %[[C1]]), to_apply=%command_buffer, control-predecessors={%[[CC]]}
-    CHECK: }
-
-    CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: f32[1]) -> f32[] {
-    CHECK:   %[[ARG0]] = f32[1]{0} parameter(0)
-    CHECK:   %[[COPY:.*]] = f32[1]{0} fusion(%[[ARG0]]), kind=kLoop
-    CHECK:   %[[WHILE:.*]] = f32[1]{0} while(%[[COPY]]), condition=%[[COND:[a-z_0-9.]+]], body=%[[BODY]]
-    CHECK:   ROOT %[[BC:.+]] = f32[] bitcast(%[[WHILE]])
-    CHECK: })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, While) {
-  const char* hlo = R"(
-    HloModule TestModule, is_scheduled=true
-
-    %fused_computation (param_0: f32[1]) -> f32[1] {
-      %param_0 = f32[1]{0} parameter(0)
-      ROOT %copy.5 = f32[1]{0} copy(f32[1]{0} %param_0)
-    }
-
-    %fused_computation.1 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] {
-      %param_0.1 = f32[1]{0} parameter(0)
-      %param_1 = f32[1]{0} parameter(1)
-      ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1)
-    }
-
-    %fused_computation.2 (param_0.2: f32[1], param_1.1: f32[1]) -> pred[1] {
-      %param_0.2 = f32[1]{0} parameter(0)
-      %param_1.1 = f32[1]{0} parameter(1)
-      ROOT %compare.3 = pred[1]{0} compare(f32[1]{0} %param_0.2, f32[1]{0} %param_1.1), direction=LT
-    }
-
-    %body (Arg_.3: f32[1]) -> f32[1] {
-      %constant_4 = f32[1]{0} constant({1})
-      %Arg_.3 = f32[1]{0} parameter(0)
-      ROOT %wrapped_add.1 = f32[1]{0} fusion(f32[1]{0} %Arg_.3, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.1
-    }
-
-    %cond (Arg_.11: f32[1]) -> pred[] {
-      %constant = f32[1]{0} constant({100})
-      %Arg_.11 = f32[1]{0} parameter(0)
-      %wrapped_compare.2 = pred[1]{0} fusion(f32[1]{0} %Arg_.11, f32[1]{0} %constant), kind=kLoop, calls=%fused_computation.2
-      ROOT %bitcast = pred[] bitcast(pred[1]{0} %wrapped_compare.2)
-    }
-
-    ENTRY %main.18 (Arg_0.1: f32[1]) -> f32[] {
-      %Arg_0.1 = f32[1]{0} parameter(0), sharding={replicated}
-      %wrapped_copy.4 = f32[1]{0} fusion(f32[1]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation
-      %while.16 = f32[1]{0} while(f32[1]{0} %wrapped_copy.4), condition=%cond, body=%body
-      ROOT %bitcast.1 = f32[] bitcast(f32[1]{0} %while.16)
-    })";
-
-  const char* expected = R"(
-    CHECK: %command_buffer ([[P0:.+]]: f32[1]) -> f32[1] {
-    CHECK:   %[[P0]] = f32[1]{0} parameter(0)
-    CHECK:   %[[COPY:.*]] = f32[1]{0} fusion(%[[P0]]), kind=kLoop
-    CHECK:   ROOT {{.*}} = f32[1]{0} while(%[[COPY]]), condition=%[[COND:[a-z_0-9.]+]], body=%[[BODY:[a-z_0-9.]+]]
-    CHECK: }
-
-    CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: f32[1]) -> f32[] {
-    CHECK:   %[[ARG0]] = f32[1]{0} parameter(0)
-    CHECK:   %call = f32[1]{0} call(%[[ARG0]]), to_apply=%command_buffer
-    CHECK:   ROOT %[[BC:.+]] = f32[] bitcast(%call)
-    CHECK: })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, Conditional) {
-  const char* hlo = R"(
-    HloModule TestModule, is_scheduled=true
-
-    %fused_computation.1 (param_0.2: s32[5]) -> s32[5] {
-      %param_0.2 = s32[5]{0} parameter(0)
-      ROOT %negate.2 = s32[5]{0} negate(s32[5]{0} %param_0.2)
-    }
-
-    %region_0.7 (Arg_.8: s32[5]) -> (s32[5]) {
-      %Arg_.8 = s32[5]{0} parameter(0)
-      %wrapped_negate.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.8), kind=kLoop, calls=%fused_computation.1
-      ROOT %tuple.3 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_negate.1)
-    }
-
-    %fused_computation.2 (param_0.3: s32[5]) -> s32[5] {
-      %param_0.3 = s32[5]{0} parameter(0)
-      ROOT %not.2 = s32[5]{0} not(s32[5]{0} %param_0.3)
-    }
-
-    %region_1.10 (Arg_.11: s32[5]) -> (s32[5]) {
-      %Arg_.11 = s32[5]{0} parameter(0)
-      %wrapped_not.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.11), kind=kLoop, calls=%fused_computation.2
-      ROOT %tuple.4 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_not.1)
-    }
-
-    %fused_computation.3 (param_0.4: s32[5]) -> s32[5] {
-      %param_0.4 = s32[5]{0} parameter(0)
-      ROOT %multiply.2 = s32[5]{0} multiply(s32[5]{0} %param_0.4, s32[5]{0} %param_0.4)
-    }
-
-    %region_2.13 (Arg_.14: s32[5]) -> (s32[5]) {
-      %Arg_.14 = s32[5]{0} parameter(0)
-      %wrapped_multiply.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.14), kind=kLoop, calls=%fused_computation.3
-      ROOT %tuple.5 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_multiply.1)
-    }
-
-    %fused_computation (param_0.1: s64[]) -> s32[] {
-      %constant_1 = s32[] constant(0)
-      %param_0.1 = s64[] parameter(0)
-      %convert.2 = s32[] convert(s64[] %param_0.1)
-      %constant_0 = s32[] constant(2)
-      ROOT %clamp.2 = s32[] clamp(s32[] %constant_1, s32[] %convert.2, s32[] %constant_0)
-    }
-
-    ENTRY %main.17 (Arg_0.1: s64[], Arg_1.2: s32[5]) -> s32[5] {
-      %Arg_0.1 = s64[] parameter(0), sharding={replicated}
-      %fusion = s32[] fusion(s64[] %Arg_0.1), kind=kLoop, calls=%fused_computation
-      %Arg_1.2 = s32[5]{0} parameter(1), sharding={replicated}
-      %conditional.16.clone = (s32[5]{0}) conditional(s32[] %fusion, s32[5]{0} %Arg_1.2, s32[5]{0} %Arg_1.2, s32[5]{0} %Arg_1.2), branch_computations={%region_0.7, %region_1.10, %region_2.13}
-      ROOT %get-tuple-element = s32[5]{0} get-tuple-element((s32[5]{0}) %conditional.16.clone), index=0
-    })";
-
-  const char* expected = R"(
-    CHECK: %command_buffer ([[P0:.+]]: s64[], [[P1:.+]]: s32[5]) -> (s32[5]) {
-    CHECK:   %[[P0]] = s64[] parameter(0)
-    CHECK:   %[[P1]] = s32[5]{0} parameter(1)
-    CHECK:   %[[FUSION:.*]] = s32[] fusion(%[[P0]]), kind=kLoop
-    CHECK:   ROOT {{.*}} = (s32[5]{0}) conditional(%[[FUSION]], %[[P1]], %[[P1]], %[[P1]]), branch_computations={%[[B1:[a-z_0-9.]+]], %[[B2:[a-z_0-9.]+]], %[[B3:[a-z_0-9.]+]]}
-    CHECK: }
-
-    CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: s64[], [[ARG1:.+]]: s32[5]) -> s32[5] {
-    CHECK:   %[[ARG0]] = s64[] parameter(0)
-    CHECK:   %[[ARG1]] = s32[5]{0} parameter(1)
-    CHECK:   %call = (s32[5]{0}) call(%[[ARG0]], %[[ARG1]]), to_apply=%command_buffer
-    CHECK:   ROOT %[[GEP:.+]] = s32[5]{0} get-tuple-element(%call)
-    CHECK: })";
-
-  RunAndFilecheckHloRewrite(
-      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      expected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-TEST_F(CommandBufferSchedulingTest, CuDnnFusionGraphCaptureWorks) {
-  const std::string kHloText = R"(
-HloModule m, is_scheduled=true
-
-fusion0 {
-  p0 = f32[64,64] parameter(0)
-  p1 = f32[64,64] parameter(1)
-  ROOT d = f32[64,64] dot(p0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-fusion1 {
-  p0 = f32[64,64] parameter(0)
-  p1 = f32[64,64] parameter(1)
-  ROOT d = f32[64,64] dot(p0, p1),
-    lhs_contracting_dims={0}, rhs_contracting_dims={1}
-}
-
-fusion_a {
-  p0 = f32[64,64] parameter(0)
-  p1 = f32[64,64] parameter(1)
-  ROOT a = f32[64,64] add(p0, p1)
-}
-
-ENTRY e {
-  p0 = f32[64,64] parameter(0)
-  p1 = f32[64,64] parameter(1)
-  d0 = f32[64,64] fusion(p0, p1), kind=kCustom,
-    calls=fusion0,
-    backend_config={"fusion_backend_config": {"kind":"__cudnn$fusion"}}
-  a = f32[64,64] fusion(d0, d0), kind=kLoop, calls=fusion_a
-  ROOT d1 = f32[64,64] fusion(a, p1), kind=kCustom,
-    calls=fusion1,
-    backend_config={"fusion_backend_config": {"kind":"__cudnn$fusion"}}
-})";
-
-  const std::string kExpected = R"(
-; CHECK: ENTRY
-; CHECK-NEXT: parameter
-; CHECK-NEXT: parameter
-; CHECK-NEXT: ROOT
-; CHECK-SAME: call(
-; CHECK-SAME: to_apply=%command_buffer
-})";
-
-  RunAndFilecheckHloRewrite(
-      kHloText,
-      CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
-      kExpected, [](HloModule* module) {
-        EXPECT_TRUE(module->has_schedule());
-        TF_CHECK_OK(module->schedule().Verify());
-      });
-}
-
-}  // namespace
-
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc
deleted file mode 100644
index 40bbb7a..0000000
--- a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc
+++ /dev/null
@@ -1,1193 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/conv_algorithm_picker.h"
-
-#include <algorithm>
-#include <cmath>
-#include <cstddef>
-#include <cstdint>
-#include <limits>
-#include <memory>
-#include <optional>
-#include <string>
-#include <string_view>
-#include <tuple>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/time/time.h"
-#include "absl/types/span.h"
-#include "xla/autotuning.pb.h"
-#include "xla/debug_options_flags.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/gpu_autotuning.pb.h"
-#include "xla/service/gpu/gpu_conv_runner.h"
-#include "xla/service/gpu/hlo_algorithm_denylist.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/slow_operation_alarm.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/cuda/cuda_platform_id.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/lazy_op_runner.h"
-#include "xla/stream_executor/numeric_options.h"
-#include "xla/stream_executor/platform.h"
-#include "xla/stream_executor/rocm/rocm_platform_id.h"
-#include "xla/stream_executor/scratch_allocator.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/tsl/util/env_var.h"
-#include "xla/tsl/util/proto/proto_utils.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/numbers.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
-#include "third_party/gpus/cudnn/cudnn.h"  // IWYU pragma: keep
-#include "third_party/gpus/cudnn/cudnn_version.h"
-#if CUDNN_VERSION >= 90000
-#include "third_party/gpus/cudnn/cudnn_ops.h"
-#else
-#include "third_party/gpus/cudnn/cudnn_ops_infer.h"
-#endif  // CUDNN_VERSION >= 90000
-#include "xla/service/gpu/buffer_comparator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#endif
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using se::DeviceMemoryBase;
-using se::dnn::AlgorithmDesc;
-using std::optional;
-
-class ScratchAllocator : public se::ScratchAllocator {
- public:
-  ScratchAllocator(int device_ordinal,
-                   se::DeviceMemoryAllocator* memory_allocator)
-      : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
-
-  int64_t GetMemoryLimitInBytes() override {
-    return ScratchAllocator::GetDefaultMemoryLimitInBytes();
-  }
-  int64_t TotalAllocatedBytes() { return total_allocated_bytes_; }
-
-  static int64_t GetDefaultMemoryLimitInBytes() {
-    int64_t value;
-    TF_CHECK_OK(tsl::ReadInt64FromEnvVar("TF_CUDNN_WORKSPACE_LIMIT_IN_MB",
-                                         1LL << 12, &value));
-    return value * (1LL << 20);
-  }
-
-  absl::StatusOr<se::DeviceMemory<uint8_t>> AllocateBytes(
-      int64_t byte_size) override;
-
-  template <typename T>
-  absl::StatusOr<se::DeviceMemory<T>> Allocate(int64_t num_elements) {
-    TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8_t> bytes,
-                        AllocateBytes(num_elements * sizeof(T)));
-    return se::DeviceMemory<T>(bytes);
-  }
-
- private:
-  const int device_ordinal_;
-  se::DeviceMemoryAllocator* memory_allocator_;
-  std::vector<se::OwningDeviceMemory> allocated_buffers_;
-  int64_t total_allocated_bytes_ = 0;
-};
-
-absl::StatusOr<se::DeviceMemory<uint8_t>> ScratchAllocator::AllocateBytes(
-    int64_t byte_size) {
-  CHECK_GE(byte_size, 0) << "byte_size must be non-negative.";
-  if (byte_size > GetMemoryLimitInBytes()) {
-    return absl::ResourceExhaustedError(absl::StrFormat(
-        "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size,
-        GetMemoryLimitInBytes()));
-  }
-
-  TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
-                      memory_allocator_->Allocate(device_ordinal_, byte_size,
-                                                  /*retry_on_failure=*/false));
-  total_allocated_bytes_ += byte_size;
-
-  se::DeviceMemoryBase buffer_addr = *allocated_buffer;
-  allocated_buffers_.push_back(std::move(allocated_buffer));
-  return se::DeviceMemory<uint8_t>(buffer_addr);
-}
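// For reference, a standalone sketch of the workspace-limit arithmetic used by
// ScratchAllocator above (DefaultWorkspaceLimitBytes is an illustrative helper,
// not part of the real class): TF_CUDNN_WORKSPACE_LIMIT_IN_MB defaults to
// 1LL << 12 = 4096 MB, and the limit is converted to bytes by multiplying by
// 1LL << 20, i.e. 4 GiB by default.

#include <cassert>
#include <cstdint>

int64_t DefaultWorkspaceLimitBytes(int64_t limit_in_mb = 1LL << 12) {
  return limit_in_mb * (1LL << 20);  // MB -> bytes
}

void WorkspaceLimitSketch() {
  // Default: 4096 * 2^20 = 2^32 bytes.
  assert(DefaultWorkspaceLimitBytes() == (1LL << 32));
  // A request above the limit would be rejected with ResourceExhausted,
  // mirroring ScratchAllocator::AllocateBytes above.
  assert((5LL << 30) > DefaultWorkspaceLimitBytes());
}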
-
-absl::StatusOr<std::vector<GenericConvRunner>> GetAlgorithms(
-    const GpuConvConfig& config, se::Stream* stream, bool use_cudnn_frontend,
-    bool use_fallback, const se::NumericOptions& numeric_options) {
-  TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
-                      GetDNNConvKindFromCudnnConvKind(config.kind));
-
-  TF_ASSIGN_OR_RETURN(se::dnn::DataType input_type,
-                      GetDNNDataTypeFromPrimitiveType(config.input_type));
-
-  TF_ASSIGN_OR_RETURN(se::dnn::DataType output_type,
-                      GetDNNDataTypeFromPrimitiveType(config.output_type));
-
-  se::StreamExecutor* stream_exec = stream->parent();
-  std::vector<GenericConvRunner> result;
-
-  auto dnn = stream_exec->AsDnn();
-  if (dnn == nullptr) {
-    return absl::InvalidArgumentError("No DNN in stream executor.");
-  }
-  switch (kind) {
-    default:
-      return Internal("Unknown ConvolutionKind %d", kind);
-    case se::dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION: {
-      if (!config.fusion) {
-        return Internal(
-            "GpuConvConfig had fusion ConvolutionKind but no FusionConfig.");
-      }
-      std::vector<std::unique_ptr<const se::dnn::FusedConvRunner>> runners;
-      TF_RETURN_IF_ERROR(dnn->GetFusedConvolveRunners(
-          use_cudnn_frontend,
-          // This refers to the kind of convolution op inside the fusion, not
-          // the whole fused graph.
-          se::dnn::ConvolutionKind::FORWARD, input_type,
-          BiasTypeForInputType(input_type), output_type,
-          /* conv_input_scale = */ config.conv_result_scale,
-          /* side_input_scale = */ config.fusion->side_input_scale,
-          /* leakyrelu_alpha = */ config.fusion->leakyrelu_alpha, stream,
-          config.input_descriptor, config.filter_descriptor,
-          config.bias_descriptor, config.output_descriptor, config.conv_desc,
-          use_fallback, config.fusion->mode, numeric_options, &runners));
-      for (auto& runner : runners) {
-        TF_ASSIGN_OR_RETURN(
-            auto runner_cache,
-            se::dnn::LazyOpRunner<se::dnn::FusedConvOp>::FromOpRunner(
-                std::move(runner)));
-        result.emplace_back(std::move(runner_cache));
-      }
-      break;
-    }
-
-    case se::dnn::ConvolutionKind::FORWARD_GRAPH: {
-      std::vector<std::unique_ptr<const se::dnn::GraphConvRunner>> runners;
-      // This path is cuDNN-only; unlike GetConvolveRunners below, it takes
-      // no DeviceMemoryBase arguments and no allocator.
-      TF_RETURN_IF_ERROR(dnn->GetGraphConvolveRunners(
-          kind, input_type, output_type, stream, config.input_descriptor,
-          config.filter_descriptor, config.output_descriptor, config.conv_desc,
-          use_fallback, numeric_options, &runners, config.serialized_graph));
-      for (auto& runner : runners) {
-        TF_ASSIGN_OR_RETURN(
-            auto runner_cache,
-            se::dnn::LazyOpRunner<se::dnn::GraphConvOp>::FromOpRunner(
-                std::move(runner)));
-        result.emplace_back(std::move(runner_cache));
-      }
-      break;
-    }
-
-    case se::dnn::ConvolutionKind::FORWARD:
-    case se::dnn::ConvolutionKind::BACKWARD_DATA:
-    case se::dnn::ConvolutionKind::BACKWARD_FILTER: {
-      std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
-      // This path is cuDNN-only, where the DeviceMemoryBase arguments and the
-      // allocator are unused, so they are all provided as nullptr.
-      TF_RETURN_IF_ERROR(dnn->GetConvolveRunners(
-          use_cudnn_frontend, kind, input_type, output_type, stream,
-          config.input_descriptor,
-          /* input_data = */ DeviceMemoryBase(nullptr),
-          config.filter_descriptor,
-          /* filter_data = */ DeviceMemoryBase(nullptr),
-          config.output_descriptor,
-          /* output_data = */ DeviceMemoryBase(nullptr), config.conv_desc,
-          use_fallback, nullptr, numeric_options, &runners));
-
-      for (auto& runner : runners) {
-        TF_ASSIGN_OR_RETURN(
-            auto runner_cache,
-            se::dnn::LazyOpRunner<se::dnn::ConvOp>::FromOpRunner(
-                std::move(runner)));
-        result.emplace_back(std::move(runner_cache));
-      }
-      break;
-    }
-  }
-
-  return result;
-}
-
-absl::StatusOr<std::vector<std::unique_ptr<const se::dnn::ConvRunner>>>
-GetMIOpenAlgorithms(const HloCustomCallInstruction* instr,
-                    absl::Span<se::DeviceMemoryBase> operand_buffers,
-                    absl::Span<se::DeviceMemoryBase> result_buffers,
-                    se::StreamExecutor* stream_exec,
-                    ScratchAllocator* scratch_allocator, se::Stream* stream,
-                    const se::NumericOptions& numeric_options) {
-  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
-
-  TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
-                      GetDNNConvKindFromCudnnConvKind(config.kind));
-
-  TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype,
-                      GetDNNDataTypeFromPrimitiveType(config.output_type));
-
-  TF_ASSIGN_OR_RETURN(
-      GpuConvParams params,
-      GetGpuConvParams(config, operand_buffers, result_buffers));
-
-  std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
-  auto dnn = stream_exec->AsDnn();
-  if (dnn == nullptr) {
-    return absl::InvalidArgumentError("No DNN in stream executor.");
-  }
-  TF_RETURN_IF_ERROR(dnn->GetConvolveRunners(
-      /* use_cudnn_frontend = */ false, kind, dtype, dtype, stream,
-      params.config->input_descriptor, params.input_buf,
-      params.config->filter_descriptor, params.filter_buf,
-      params.config->output_descriptor, params.output_buf,
-      params.config->conv_desc,
-      /* use_fallback = */ false, scratch_allocator, numeric_options,
-      &runners));
-
-  return runners;
-}
-
-std::string NumBytesToString(int64_t bytes) {
-  return absl::StrCat(tsl::strings::HumanReadableNumBytes(bytes), " (", bytes,
-                      "B)");
-}
-
-CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
-  se::dnn::VersionInfo version = GetDnnVersionInfoOrDefault(stream_executor);
-  CudnnVersion cudnn_version;
-  cudnn_version.set_major(version.major_version());
-  cudnn_version.set_minor(version.minor_version());
-  cudnn_version.set_patch(version.patch());
-
-  return cudnn_version;
-}
-
-ComputeCapability GetComputeCapability(se::StreamExecutor* stream_executor) {
-  ComputeCapability cc;
-  se::CudaComputeCapability se_cc =
-      stream_executor->GetDeviceDescription().cuda_compute_capability();
-  cc.set_major(se_cc.major);
-  cc.set_minor(se_cc.minor);
-  return cc;
-}
-
-void PrintPlatformInfo(const se::Stream* stream) {
-  auto* se = stream->parent();
-  const auto& desc = se->GetDeviceDescription();
-  LOG(ERROR) << "Device: " << desc.name();
-  LOG(ERROR) << "Platform: " << desc.platform_version();
-  LOG(ERROR) << "Driver: " << desc.driver_version();
-  LOG(ERROR) << "Runtime: " << desc.runtime_version();
-
-  auto dnn_version = GetDnnVersionInfo(se);
-  if (dnn_version.ok()) {
-    auto v = dnn_version.value();
-    LOG(ERROR) << "cudnn version: " << v.major_version() << "."
-               << v.minor_version() << "." << v.patch();
-  }
-}
-
-// Returns true if the redzones in `allocator`'s allocations are unmodified.
-//
-// If the redzones are modified, logs an error, sets the appropriate failure
-// bits on `result`, and returns false.
-//
-// Returns a non-OK absl::Status if an unexpected error has occurred and the
-// stream has been poisoned.
-//
-// `name` is a user-friendly name for the set of redzones being checked, e.g.
-// "input/output" or "scratch".
-absl::StatusOr<bool> CheckRedzones(const se::RedzoneAllocator& allocator,
-                                   se::Stream* stream, absl::string_view name,
-                                   std::string_view instr_str,
-                                   AutotuneResult* result) {
-  XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones",
-                                 2);
-  using RedzoneCheckStatus = se::RedzoneAllocator::RedzoneCheckStatus;
-  TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check,
-                      allocator.CheckRedzones());
-  if (redzone_check.ok()) {
-    return true;
-  }
-
-  auto* fail = result->mutable_failure();
-  fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
-  *fail->mutable_msg() = redzone_check.RedzoneFailureMsg();
-  fail->set_buffer_address(
-      reinterpret_cast<uint64_t>(redzone_check.user_buffer_address));
-
-  LOG(ERROR) << absl::StreamFormat(
-      "Detected cudnn out-of-bounds write in conv %s buffer! This is likely a "
-      "cudnn bug. We will skip this algorithm in the future, but your GPU "
-      "state may already be corrupted, leading to incorrect results. Within "
-      "Google, no action is needed on your part. Outside of Google, please "
-      "ensure you're running the latest version of cudnn. If that doesn't fix "
-      "the problem, please file a bug with this full error message and we'll "
-      "contact nvidia.",
-      name);
-  LOG(ERROR) << redzone_check.RedzoneFailureMsg();
-  LOG(ERROR) << "HloInstruction " << instr_str;
-  PrintPlatformInfo(stream);
-  return false;
-}
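// A minimal sketch (not the file's actual caller) of consuming the three-way
// CheckRedzones contract documented above: a non-OK status means an unexpected
// error and the stream has been poisoned, so autotuning must abort; `false`
// means the redzones were clobbered and the failure was already recorded on
// the AutotuneResult, so the candidate is simply skipped; `true` means the
// buffers are clean.
absl::Status HandleRedzoneCheck(const absl::StatusOr<bool>& redzones_clean,
                                bool* keep_candidate) {
  if (!redzones_clean.ok()) {
    *keep_candidate = false;
    return redzones_clean.status();  // Unexpected error: abort autotuning.
  }
  *keep_candidate = *redzones_clean;  // Drop the candidate if clobbered.
  return absl::OkStatus();
}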
-
-}  // anonymous namespace
-
-bool ShouldInitConvData(const HloModuleConfig& hlo_module_config) {
-  const int32_t conv_autotune_level =
-      hlo_module_config.debug_options().xla_gpu_autotune_level();
-  return conv_autotune_level >= 2;
-}
-
-bool ShouldCheckConv(const HloModuleConfig& hlo_module_config) {
-  const int32_t conv_autotune_level =
-      hlo_module_config.debug_options().xla_gpu_autotune_level();
-  return conv_autotune_level >= 4;
-}
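// A small illustration of how the two thresholds above compose (FlagsForLevel
// is an illustrative helper, not part of this file): initialization of conv
// buffers kicks in at xla_gpu_autotune_level >= 2, and redzone/result checking
// at level >= 4.
struct AutotuneLevelFlags {
  bool init_conv_data;
  bool check_conv;
};

AutotuneLevelFlags FlagsForLevel(int level) {
  return {/*init_conv_data=*/level >= 2, /*check_conv=*/level >= 4};
}
// e.g. FlagsForLevel(3) yields {true, false}; FlagsForLevel(4) yields
// {true, true}.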
-
-absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
-    const HloCustomCallInstruction* instr) {
-  return AutotunerUtil::Autotune(
-      instr, config_, [&] { return PickBestAlgorithmNoCache(instr); });
-}
-
-absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithmNoCache(
-    const HloCustomCallInstruction* instr) {
-  if (config_.IsDeviceless()) {
-    // Return an autotune result with algo id -1, which means that we autotune
-    // at runtime.
-    AutotuneResult result;
-    result.mutable_algorithm()->set_algo_id(-1);
-    return result;
-  }
-
-  se::StreamExecutor* stream_exec = config_.GetExecutor();
-  // Don't run this function concurrently on the same GPU.
-  //
-  // This is a bit of a hack and doesn't protect us against arbitrary concurrent
-  // use of a GPU, but it's sufficient to let us compile two HLO modules
-  // concurrently and then run them sequentially.
-  //
-  // Putting the lock in here rather than in PickBestAlgorithm (the cached
-  // wrapper above) lets us avoid ever doing duplicate work.  If we have a
-  // cache miss, only one thread will run the uncached autotuning for a
-  // particular device.
-  absl::MutexLock lock(&GetGpuMutex(stream_exec));
-
-  // Make sure any previous activity on this executor is done. We don't want
-  // other work still running on the GPU to interfere with autotuning.
-  if (!stream_exec->SynchronizeAllActivity()) {
-    return Internal(
-        "Failed to synchronize GPU for autotuning conv instruction");
-  }
-
-  absl::StatusOr<AutotuneResult> result_or(Internal("Unknown platform."));
-  // Check which platform the StreamExecutor is on. The ROCm and CUDA
-  // implementations have diverged; specifically, we need to make sure that
-  // redzone-allocator-related utilities are not used in the ROCm routine.
-  se::Platform::Id platform_id = stream_exec->GetPlatform()->id();
-  if (platform_id == se::rocm::kROCmPlatformId) {
-    result_or = PickBestAlgorithmNoCacheRocm(instr);
-  } else if (platform_id == se::cuda::kCudaPlatformId) {
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
-    result_or = PickBestAlgorithmNoCacheCuda(instr);
-#endif
-  }
-
-  return result_or;
-}
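// The per-GPU serialization used above, reduced to a standalone sketch with
// illustrative names (MutexForDevice stands in for the real GetGpuMutex): a
// process-wide map from device ordinal to mutex lets different devices
// autotune in parallel while the expensive, GPU-touching work for any single
// device runs one thread at a time.
#include <map>
#include "absl/synchronization/mutex.h"

absl::Mutex& MutexForDevice(int device_ordinal) {
  static absl::Mutex map_mu;
  static auto* mutexes = new std::map<int, absl::Mutex>();
  absl::MutexLock map_lock(&map_mu);
  return (*mutexes)[device_ordinal];  // Map nodes are stable; safe to return.
}

void AutotuneOnDevice(int device_ordinal) {
  absl::MutexLock lock(&MutexForDevice(device_ordinal));
  // ... synchronize the device, run candidate algorithms, pick the best ...
}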
-
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
-
-absl::StatusOr<GpuConvAlgorithmPicker::AutotuneRuntimeArguments>
-GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction(
-    const HloCustomCallInstruction* instr, const AutotuneConfig& config,
-    const DebugOptions& debug_options) {
-  TF_ASSIGN_OR_RETURN(auto rz_buffers,
-                      RedzoneBuffers::FromInstruction(
-                          *instr, config, debug_options,
-                          RedzoneBuffers::kAllInputsOutputsNoScratch));
-
-  // Get canonical HLO.
-  std::string canonical_hlo(
-      AutotuneCacheKey(config.GetExecutor()->GetDeviceDescription(), *instr)
-          .GetHlo());
-
-  TF_ASSIGN_OR_RETURN(GpuConvConfig gpu_conv_config, GetGpuConvConfig(instr));
-
-  GpuConvAlgorithmPicker::AutotuneRuntimeArguments runtime_arguments = {
-      instr->GetModule()->config(),
-      std::move(rz_buffers),
-      std::move(gpu_conv_config),
-      {canonical_hlo}};
-
-  return runtime_arguments;
-}
-
-struct CudnnVersionRange {
-  using TupleVersion = std::tuple<int, int, int>;
-  TupleVersion begin;
-  TupleVersion end;
-
-  bool IsInRange(const CudnnVersion& other) const {
-    TupleVersion other_version{other.major(), other.minor(), other.patch()};
-    return begin <= other_version && other_version < end;
-  }
-
-  CudnnVersionRange(const CudnnVersion& begin, const CudnnVersion& end)
-      : begin(begin.major(), begin.minor(), begin.patch()),
-        end(end.major(), end.minor(), end.patch()) {}
-
-  CudnnVersionRange(const TupleVersion& begin, const TupleVersion& end)
-      : begin(begin), end(end) {}
-};
-
-struct ComputeCapabilityRange {
-  using TupleComputeCapability = std::tuple<int, int>;
-  TupleComputeCapability begin;
-  TupleComputeCapability end;
-
-  bool IsInRange(const ComputeCapability& other) const {
-    TupleComputeCapability other_cc{other.major(), other.minor()};
-    return begin <= other_cc && other_cc < end;
-  }
-};
-
-struct DisabledAlgorithm {
-  CudnnVersionRange cudnn_version_range;
-  ComputeCapabilityRange compute_capability_range;
-  int algo_id;
-};
-
-// TODO(b/343101418): Remove this once the bug is fixed in upstream cudnn and
-// we have updated to that cudnn version.
-static const DisabledAlgorithm kDisabledAlgorithms[] = {
-    {/*.cudnn_version_range=*/{/*.begin=*/{9, 0, 0}, /*.end=*/{10, 0, 0}},
-     /*.compute_capability_range=*/{/*.begin=*/{6, 0}, /*.end=*/{8, 0}},
-     /*.algo_id=*/14}};
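// A standalone illustration of the half-open range matching used by
// kDisabledAlgorithms (plain std::tuple stands in for the proto types):
// lexicographic comparison means [{9,0,0}, {10,0,0}) matches any cuDNN 9.x.y
// release, and [{6,0}, {8,0}) matches compute capabilities 6.0 through 7.x,
// but not 8.0.
#include <tuple>

inline bool InCudnnRange(const std::tuple<int, int, int>& v) {
  return std::make_tuple(9, 0, 0) <= v && v < std::make_tuple(10, 0, 0);
}

inline bool InComputeCapabilityRange(const std::tuple<int, int>& cc) {
  return std::make_tuple(6, 0) <= cc && cc < std::make_tuple(8, 0);
}
// A candidate is dropped when both ranges match and its algo_id is 14, which
// is exactly the check performed in AutotuneOneConvRunner below.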
-
-// There are three tiers of errors possible here: returning a failed
-// absl::StatusOr means autotuning fails immediately; returning an
-// AutotuneResult with a failure code other than DISQUALIFIED means autotuning
-// fails if crash_on_checking_failure is set; and returning a DISQUALIFIED
-// AutotuneResult simply skips the engine/algorithm while recording a reason for
-// skipping it.
-absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(
-    GenericConvRunner* const runner,
-    std::optional<ReferenceResult>* reference_result,
-    absl::Span<const AlgorithmDesc> disabled_algos,
-    std::optional<AutotuneCacheKey> instruction_info,
-    const AutotuneRuntimeArguments& runtime_arguments) {
-  auto alg = runner->ToAlgorithmDesc();
-
-  se::StreamExecutor* stream_exec = config_.GetExecutor();
-  XLA_SCOPED_LOGGING_TIMER_LEVEL(
-      absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
-                   alg.ToString()),
-      2);
-
-  auto make_failure = [&alg](AutotuneResult::FailureKind kind,
-                             absl::string_view msg) {
-    AutotuneResult result;
-    *result.mutable_algorithm() = alg.ToProto();
-    result.mutable_failure()->set_kind(kind);
-    result.mutable_failure()->set_msg(/* *sigh* */ msg.data(), msg.size());
-    return result;
-  };
-
-  AlgorithmDesc alg_key(alg.algo_id(), alg.tensor_ops_enabled(), std::nullopt);
-
-  std::string instr_str = instruction_info.has_value()
-                              ? std::string(instruction_info->GetHlo())
-                              : "<unknown>";
-
-  for (const auto& disabled_algo : kDisabledAlgorithms) {
-    if (disabled_algo.cudnn_version_range.IsInRange(
-            GetCudnnVersion(stream_exec)) &&
-        disabled_algo.compute_capability_range.IsInRange(
-            GetComputeCapability(stream_exec)) &&
-        disabled_algo.algo_id == alg.algo_id()) {
-      LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString()
-                << " for conv " << instr_str;
-      return make_failure(AutotuneResult::DISQUALIFIED,
-                          "Disqualified for being known-buggy.");
-    }
-  }
-
-  if (absl::c_linear_search(disabled_algos, alg_key)) {
-    LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString()
-              << " for conv " << instr_str;
-    return make_failure(AutotuneResult::DISQUALIFIED,
-                        "Disqualified for being known-buggy.");
-  }
-
-  GpuConvConfig config = runtime_arguments.gpu_conv_config;
-  auto activation_mode =
-      config.fusion ? config.fusion->mode : se::dnn::ActivationMode::kNone;
-
-  // For fused convolutions with the identity function as the activation, only
-  // ALGO_IMPLICIT_PRECOMP_GEMM does the right thing. Other algorithms
-  // silently do Relu. See
-  // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBiasActivationForward
-  //
-  // For cuDNN Frontend, there is no way to check whether we're using a broken
-  // algorithm, so on versions where some algorithms are broken, we don't use
-  // the cuDNN Frontend for these convs at all.  As such, if we get a
-  // frontend-based runner, we can be sure it's not one of the broken
-  // algorithms we're checking for.
-  if (!alg.is_cudnn_frontend() &&
-      config.kind == CudnnConvKind::kForwardActivation &&
-      activation_mode == se::dnn::ActivationMode::kNone &&
-      alg.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
-    return make_failure(AutotuneResult::DISQUALIFIED,
-                        "Disqualified for implicit RELU.");
-  }
-
-  TF_ASSIGN_OR_RETURN(
-      se::RedzoneAllocator scratch_allocator,
-      AutotunerUtil::CreateRedzoneAllocator(
-          config_, runtime_arguments.hlo_module_config.debug_options()));
-
-  se::dnn::ProfileResult profile_result;
-  VLOG(4) << "Trying algorithm " << alg.ToString() << " for " << instr_str;
-
-  SlowOperationAlarm alarm(absl::Seconds(1), [&] {
-    return absl::StrFormat(
-        "Trying algorithm %s for conv %s is taking a while...", alg.ToString(),
-        instr_str);
-  });
-
-  std::optional<size_t> workspace_size =
-      runner->ToAlgorithmDesc().workspace_size();
-  if (!workspace_size) {
-    return make_failure(AutotuneResult::UNKNOWN,
-                        "Internal error: missing workspace size from "
-                        "OpRunner::ToAlgorithmDesc()");
-  }
-
-  auto scratch_or = scratch_allocator.AllocateBytes(*workspace_size);
-  if (!scratch_or.ok()) {
-    return make_failure(AutotuneResult::DISQUALIFIED,
-                        absl::StrCat("Scratch allocation failed: ",
-                                     scratch_or.status().ToString()));
-  }
-  se::DeviceMemoryBase scratch_memory = scratch_or.value();
-
-  // Use assignment instead of brace-list to make GCC 4.9 happy.
-  RunConvOptions options;
-  options.runner_cache = runner;
-  // The following plan timing code is based on
-  // https://github.com/NVIDIA/cudnn-frontend/blob/60496f42fdc7a4ccc059f5934e306e728a756755/include/cudnn_frontend_find_plan.h
-  float max_time = 0;
-  float min_time = std::numeric_limits<float>::max();
-  absl::Status launch_status;
-  std::vector<se::DeviceMemoryBase> operand_buffers =
-      runtime_arguments.rz_buffers.input_buffers();
-  std::vector<se::DeviceMemoryBase> result_buffers =
-      runtime_arguments.rz_buffers.output_buffers();
-
-  TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
-
-  // Dry-run to warmup the plan.
-  launch_status = RunGpuConv(config, operand_buffers, result_buffers,
-                             scratch_memory, stream, options);
-  // Flag that a warm-up run has been executed; this allows the GpuTimer for
-  // the main measurement to safely use the delay kernel pattern, even if lazy
-  // module loading is enabled.
-  options.profile_result = &profile_result;
-  profile_result.set_warmup_run_executed(true);
-  constexpr int kMaxIter = 10;
-  // Iterate until the new measurement is within kThreshold of the current
-  // minimum.
-  int num_iters = 0;
-  for (; num_iters < kMaxIter && launch_status.ok(); ++num_iters) {
-    launch_status = RunGpuConv(config, operand_buffers, result_buffers,
-                               scratch_memory, stream, options);
-    if (!profile_result.is_valid()) {
-      break;
-    }
-    float old_min_time = min_time;
-    min_time = std::min(min_time, profile_result.elapsed_time_in_ms());
-    max_time = std::max(max_time, profile_result.elapsed_time_in_ms());
-
-    constexpr float kThreshold = 0.05f;
-    if (std::abs(profile_result.elapsed_time_in_ms() - old_min_time) /
-            old_min_time <
-        kThreshold) {
-      break;
-    }
-  }
-  if (!launch_status.ok()) {
-    VLOG(5) << "Launch failed: " << launch_status;
-    return make_failure(
-        AutotuneResult::DISQUALIFIED,
-        absl::StrCat("Profiling failure on cuDNN engine ", alg.ToString(), ": ",
-                     launch_status.ToString()));
-  }
-  if (!profile_result.is_valid()) {
-    VLOG(5) << "Launch succeeded but profile result is invalid.";
-    // Not DISQUALIFIED: this means something went wrong internally.
-    return make_failure(
-        AutotuneResult::UNKNOWN,
-        absl::StrCat("Launch succeeded but profile result is invalid, "
-                     "with cuDNN engine ",
-                     alg.ToString(), ": ", launch_status.ToString()));
-  }
-  VLOG(4) << "Best time: " << min_time << " ms. Worst time: " << max_time
-          << " ms. Total iterations: " << num_iters;
-  int64_t scratch_bytes_used =
-      scratch_allocator.TotalAllocatedBytesExcludingRedzones();
-
-  AutotuneResult result;
-  *result.mutable_algorithm() = alg.ToProto();
-  result.set_scratch_bytes(scratch_bytes_used);
-  *result.mutable_run_time() =
-      tsl::proto_utils::ToDurationProto(absl::Milliseconds(min_time));
-
-  if (!ShouldCheckConv(runtime_arguments.hlo_module_config)) {
-    if (!reference_result->has_value()) {
-      (*reference_result) = {
-          alg, std::vector<DeviceMemoryBase>(result_buffers.size())};
-    }
-    return result;
-  }
-
-  // Check for writes to redzones.
-  TF_ASSIGN_OR_RETURN(
-      bool input_output_allocator_redzone_clear,
-      CheckRedzones(runtime_arguments.rz_buffers.RedzoneAllocator(), stream,
-                    "input/output", instr_str, &result));
-
-  TF_ASSIGN_OR_RETURN(
-      bool scratch_allocator_redzone_clear,
-      CheckRedzones(scratch_allocator, stream, "scratch", instr_str, &result));
-
-  if (!input_output_allocator_redzone_clear ||
-      !scratch_allocator_redzone_clear) {
-    if (runtime_arguments.canonical_hlo.has_value()) {
-      std::string canonical_hlo = runtime_arguments.canonical_hlo.value();
-      std::string blas_version;
-      if (auto* blas = stream_exec->AsBlas()) {
-        (void)blas->GetVersion(&blas_version);
-      }
-
-      AlgorithmDenylist proto;
-      auto entry = proto.add_entries();
-      entry->set_hlo(canonical_hlo);
-      *entry->mutable_cc() = GetComputeCapability(stream_exec);
-      *entry->mutable_cudnn_version() = GetCudnnVersion(stream_exec);
-      entry->set_blas_version(blas_version);
-      auto algo = entry->add_algos();
-      algo->set_id(alg.algo_id());
-      algo->set_tensor_ops(alg.tensor_ops_enabled());
-
-      LOG(ERROR) << "To denylist this algorithm for this convolution, "
-                    "copy-paste the following "
-                    "proto to the denylist file pointed by XLA_FLAGS "
-                    "--xla_gpu_algorithm_denylist_path="
-                 << GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path()
-                 << " : " << proto.ShortDebugString();
-    }
-
-    // CheckRedzones has modified the result in-place to include a failure.
-    return result;
-  }
-
-  if (reference_result->has_value()) {
-    XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);
-
-    const DebugOptions& debug_options =
-        runtime_arguments.hlo_module_config.debug_options();
-    BufferComparator comparator(runtime_arguments.rz_buffers.output_shape(),
-                                debug_options.xla_gpu_autotune_gemm_rtol());
-    for (int i = 0; i < result_buffers.size(); ++i) {
-      absl::StatusOr<bool> compare_result = comparator.CompareEqual(
-          stream, (*reference_result)->buffers[i], result_buffers[i]);
-      if (!compare_result.ok()) {
-        LOG(ERROR) << "Unable to compare "
-                   << (*reference_result)->algorithm.ToString() << " against "
-                   << alg.ToString() << " for " << instr_str << ": "
-                   << compare_result.status();
-        if (compare_result.status().code() ==
-            absl::StatusCode::kResourceExhausted) {
-          // Possibly OOM. Propagate the error.
-          return compare_result.status();
-        }
-        CHECK(!debug_options.xla_gpu_crash_on_verification_failures());
-      } else if (!compare_result.value()) {
-        LOG(ERROR)
-            << "Results mismatch between different convolution algorithms. "
-               "This is likely a bug/unexpected loss of precision in cudnn.\n"
-            << instr_str << " for " << (*reference_result)->algorithm.ToString()
-            << " vs " << alg.ToString();
-        PrintPlatformInfo(stream);
-        if (instruction_info.has_value()) {
-          VLOG(2) << "Full module on failure: \n"
-                  << instruction_info->GetModelStr();
-        }
-        auto* fail = result.mutable_failure();
-        fail->set_kind(AutotuneResult::WRONG_RESULT);
-        fail->set_buffer_address(
-            reinterpret_cast<uint64_t>(result_buffers[i].opaque()));
-        *fail->mutable_reference_algorithm() =
-            (*reference_result)->algorithm.ToProto();
-      }
-    }
-  } else {
-    XLA_SCOPED_LOGGING_TIMER_LEVEL("Memcpy Reference Result", 2);
-    std::vector<DeviceMemoryBase> reference_result_buffers(
-        result_buffers.size());
-    for (int i = 0; i < result_buffers.size(); ++i) {
-      TF_ASSIGN_OR_RETURN(
-          reference_result_buffers[i],
-          runtime_arguments.rz_buffers.RedzoneAllocator().AllocateBytes(
-              result_buffers[i].size()));
-      TF_RETURN_IF_ERROR(stream->Memcpy(&reference_result_buffers[i],
-                                        result_buffers[i],
-                                        result_buffers[i].size()));
-    }
-    (*reference_result) = {alg, reference_result_buffers};
-  }
-
-  return result;
-}
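
The timing loop above keeps re-running the conv until a fresh measurement lands within kThreshold (5%) of the best time observed so far, or until kMaxIter (10) runs have been made. A minimal, self-contained sketch of that stopping rule, using a hypothetical ShouldStopProfiling helper that is not part of this file:

#include <algorithm>
#include <cmath>

// Returns true when profiling can stop: either the new measurement is within
// 5% of the best time seen so far, or the iteration budget is exhausted.
// `min_ms` is updated in place, mirroring the loop above.
inline bool ShouldStopProfiling(float elapsed_ms, float& min_ms, int iteration) {
  constexpr int kMaxIter = 10;
  constexpr float kThreshold = 0.05f;
  const float old_min = min_ms;
  min_ms = std::min(min_ms, elapsed_ms);
  if (iteration + 1 >= kMaxIter) return true;
  return std::abs(elapsed_ms - old_min) / old_min < kThreshold;
}

Starting from min_ms initialized to std::numeric_limits<float>::max(), the first call never stops early (the relative change is enormous); later calls stop as soon as the timings stabilize.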
-
-absl::StatusOr<AutotuneResult>
-GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
-    const HloCustomCallInstruction* instr) {
-  AutotuneCacheKey instruction_info{config_.GetModelStr(), *instr};
-  std::string instr_str(instruction_info.GetHlo());
-  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
-      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr_str));
-
-  const DebugOptions& debug_options =
-      instr->GetModule()->config().debug_options();
-  const bool crash_on_checking_failure =
-      debug_options.xla_gpu_crash_on_verification_failures();
-
-  std::string blas_version;
-  se::StreamExecutor* stream_exec = config_.GetExecutor();
-  if (auto* blas = stream_exec->AsBlas()) {
-    (void)blas->GetVersion(&blas_version);
-  }
-
-  std::vector<AlgorithmDesc> disabled_algos;
-  TF_ASSIGN_OR_RETURN(
-      AutotuneRuntimeArguments runtime_arguments,
-      AutotuneRuntimeArguments::FromInstruction(instr, config_, debug_options));
-  if (runtime_arguments.canonical_hlo.has_value()) {
-    disabled_algos = GetDisabledConvAlgorithms(
-        GetComputeCapability(stream_exec), GetCudnnVersion(stream_exec),
-        blas_version, runtime_arguments.canonical_hlo.value());
-  }
-
-  const bool cudnn_frontend_enabled =
-      debug_options.xla_gpu_enable_cudnn_frontend();
-  bool allow_tf32 = true;
-  // TODO(b/284371623): Properly set allow_tf32 even if instr==nullptr, which is
-  // the case when running an AOT compiled executable with runtime autotuning.
-  if (instr) {
-    allow_tf32 = absl::c_all_of(
-        instr->precision_config().operand_precision(),
-        [](int precision) { return precision <= PrecisionConfig::HIGH; });
-  }
-  const se::NumericOptions numeric_options{
-      RequireDeterminism(instr->GetModule()->config()), allow_tf32};
-
-  // Use the first algorithm that works as the reference. There isn't a
-  // particular reason to prefer it, as any algorithm suffices; being the
-  // reference does not mean the algorithm is considered correct.
-  std::optional<ReferenceResult> reference_result;
-
-  TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
-  TF_ASSIGN_OR_RETURN(
-      std::vector<GenericConvRunner> runners,
-      GetAlgorithms(runtime_arguments.gpu_conv_config, stream,
-                    cudnn_frontend_enabled,
-                    /* use_fallback = */ false, numeric_options));
-
-  std::vector<AutotuneResult> profile_results;
-  for (auto& runner_cache : runners) {
-    TF_ASSIGN_OR_RETURN(
-        auto result,
-        AutotuneOneConvRunner(&runner_cache, &reference_result, disabled_algos,
-                              instruction_info, runtime_arguments));
-    profile_results.emplace_back(std::move(result));
-  }
-
-  // If any algorithm has worked, we'll skip the fallback algorithms, since
-  // they include some very slow algorithms.
-  if (!reference_result) {
-    LOG(WARNING) << "None of the algorithms provided by cuDNN heuristics "
-                    "worked; trying fallback algorithms.";
-    if (runtime_arguments.canonical_hlo.has_value()) {
-      LOG(WARNING) << "Conv: " << runtime_arguments.canonical_hlo.value();
-    }
-
-    TF_ASSIGN_OR_RETURN(
-        std::vector<GenericConvRunner> fallback_runners,
-        GetAlgorithms(runtime_arguments.gpu_conv_config, stream,
-                      cudnn_frontend_enabled,
-                      /* use_fallback = */ true, numeric_options));
-
-    for (auto& runner_cache : fallback_runners) {
-      TF_ASSIGN_OR_RETURN(
-          auto result, AutotuneOneConvRunner(&runner_cache, &reference_result,
-                                             disabled_algos, instruction_info,
-                                             runtime_arguments));
-      profile_results.emplace_back(std::move(result));
-    }
-  }
-
-  // Log the autotuning result.
-  if (instr) {
-    AutotuningLog log;
-    {
-      ConvInstructionLog instr_log;
-      *instr_log.mutable_instruction() = instr->ToProto();
-      for (int i = 0; i < instr->operand_count(); i++) {
-        *instr_log.add_operand_shapes() = instr->operand(i)->shape().ToProto();
-        instr_log.add_operand_addresses(reinterpret_cast<uint64_t>(
-            runtime_arguments.rz_buffers.input_buffers()[i].opaque()));
-      }
-      for (se::DeviceMemoryBase result_buffer :
-           runtime_arguments.rz_buffers.output_buffers()) {
-        instr_log.add_result_addresses(
-            reinterpret_cast<uint64_t>(result_buffer.opaque()));
-      }
-      log.mutable_instr()->PackFrom(instr_log);
-    }
-    for (const auto& profile : profile_results) {
-      *log.add_results() = profile;
-    }
-    *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
-    *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
-    log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
-    log.set_blas_version(blas_version);
-    VLOG(2) << "Autotuning result: " << log.ShortDebugString();
-    // If we crash on checking failure, we are in a testing/benchmark mode.
-    if (crash_on_checking_failure) {
-      // Crash on miscompares and redzone violations if desired.
-      for (const auto& profile : profile_results) {
-        if (profile.has_failure() &&
-            profile.failure().kind() != AutotuneResult::DISQUALIFIED) {
-          LOG(FATAL) << "crash_on_checking_failure encountered errors:\n\n"
-                     << log.DebugString();  // NOLINT
-        }
-      }
-    }
-  }
-
-  TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm,
-                      PickBestResult(profile_results, instr_str,
-                                     runtime_arguments.hlo_module_config));
-  return selected_algorithm;
-}
-#endif
-
-absl::StatusOr<AutotuneResult>
-GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
-    const HloCustomCallInstruction* instr) {
-  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
-      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));
-
-  const bool allow_tf32 = absl::c_all_of(
-      instr->precision_config().operand_precision(),
-      [](int precision) { return precision <= PrecisionConfig::HIGH; });
-  const se::NumericOptions numeric_options{
-      RequireDeterminism(instr->GetModule()->config()), allow_tf32};
-
-  se::StreamExecutor* stream_exec = config_.GetExecutor();
-  const auto device_ordinal = stream_exec->device_ordinal();
-  std::vector<se::DeviceMemoryBase> operand_buffers;
-
-  // allocator either points to the allocator configured in config_ or, if none
-  // was configured, to a se::StreamExecutorMemoryAllocator for stream_exec.
-  se::DeviceMemoryAllocator* allocator = config_.GetAllocator();
-  ScratchAllocator input_output_allocator(device_ordinal, allocator);
-  TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
-  const auto initialize_buffer = [stream](DeviceMemoryBase buffer) {
-    // Although we don't have evidence this matters, zero out the buffers
-    // before autotuning.  It's conceivable that using uninitialized memory as
-    // the inputs might affect performance if e.g. the inputs contain
-    // denormals, and this is easy enough.
-    return stream->MemZero(&buffer, buffer.size());
-  };
-
-  // Allocate space for the input, filter, and output of the convolution.  We
-  // use a ScratchAllocator for this instead of calling the allocator directly
-  // so that our allocations don't leak.
-  for (const auto* operand : instr->operands()) {
-    TF_ASSIGN_OR_RETURN(auto buffer,
-                        input_output_allocator.AllocateBytes(
-                            ShapeUtil::ByteSizeOf(operand->shape())));
-    TF_RETURN_IF_ERROR(initialize_buffer(buffer));
-    operand_buffers.push_back(buffer);
-  }
-
-  std::vector<se::DeviceMemoryBase> result_buffers(
-      instr->shape().tuple_shapes_size());
-  if (instr->shape().IsTuple()) {
-    for (int i = 0; i < instr->shape().tuple_shapes_size(); ++i) {
-      TF_ASSIGN_OR_RETURN(
-          result_buffers[i],
-          input_output_allocator.AllocateBytes(
-              ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(i))));
-      TF_RETURN_IF_ERROR(initialize_buffer(result_buffers[i]));
-    }
-  } else {
-    TF_ASSIGN_OR_RETURN(
-        result_buffers[0],
-        input_output_allocator.AllocateBytes(
-            ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
-    TF_RETURN_IF_ERROR(initialize_buffer(result_buffers[0]));
-  }
-
-  ScratchAllocator scratch_allocator(device_ordinal, allocator);
-
-  TF_ASSIGN_OR_RETURN(
-      std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners,
-      GetMIOpenAlgorithms(instr, absl::MakeSpan(operand_buffers),
-                          absl::MakeSpan(result_buffers), stream_exec,
-                          &scratch_allocator, stream, numeric_options));
-
-  std::vector<AutotuneResult> profile_results;
-
-  if (runners.size() == 1) {
-    TF_ASSIGN_OR_RETURN(auto alg, runners[0]->ToAlgorithmDesc());
-    auto algorithm_proto = alg.ToProto();
-    profile_results.emplace_back();
-    auto& result = profile_results.back();
-    *result.mutable_algorithm() = algorithm_proto;
-
-    result.set_scratch_bytes(runners[0]->GetWorkspaceSize());
-
-    // TODO(awpr): if the profile result time for a singleton algorithm is
-    // needed, plumb it via OpRunner; we'll need to do this to let TF ops avoid
-    // re-profiling ROCm algorithms anyway.
-    *result.mutable_run_time() =
-        tsl::proto_utils::ToDurationProto(absl::Milliseconds(-1));
-  } else {
-    TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
-    for (auto& runner : runners) {
-      TF_ASSIGN_OR_RETURN(auto alg, runner->ToAlgorithmDesc());
-      XLA_SCOPED_LOGGING_TIMER_LEVEL(
-          absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
-                       alg.ToString()),
-          2);
-
-      se::dnn::ProfileResult profile_result;
-      VLOG(4) << "Trying algorithm " << alg.ToString() << " for "
-              << instr->ToString();
-
-      TF_ASSIGN_OR_RETURN(
-          DeviceMemoryBase scratch_memory,
-          scratch_allocator.AllocateBytes(runner->GetWorkspaceSize()));
-
-      TF_ASSIGN_OR_RETURN(auto lazy_runner,
-                          se::dnn::LazyOpRunner<se::dnn::ConvOp>::FromOpRunner(
-                              std::move(runner)));
-
-      GenericConvRunner runner_cache(std::move(lazy_runner));
-
-      // Use assignment instead of brace-list to make GCC 4.9 happy.
-      RunConvOptions options;
-      options.profile_result = &profile_result;
-      options.runner_cache = &runner_cache;
-      absl::Status launch_status =
-          RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffers,
-                     scratch_memory, stream, options);
-
-      if (!launch_status.ok()) {
-        continue;
-      }
-
-      if (!profile_result.is_valid()) {
-        continue;
-      }
-
-      profile_results.emplace_back();
-      AutotuneResult& result = profile_results.back();
-      *result.mutable_algorithm() = alg.ToProto();
-
-      int64_t scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
-      result.set_scratch_bytes(scratch_bytes_used);
-      *result.mutable_run_time() = tsl::proto_utils::ToDurationProto(
-          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
-    }
-  }
-
-  TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm,
-                      PickBestResult(profile_results, instr->ToString(),
-                                     instr->GetModule()->config()));
-  return selected_algorithm;
-}
-
-absl::StatusOr<bool> GpuConvAlgorithmPicker::RunOnInstruction(
-    HloInstruction* instr) {
-  CHECK(IsCustomCallToDnnConvolution(*instr));
-
-  const bool strict = instr->parent()
-                          ->parent()
-                          ->config()
-                          .debug_options()
-                          .xla_gpu_strict_conv_algorithm_picker();
-
-  absl::StatusOr<AutotuneResult> best_algo_or =
-      PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
-  if (!best_algo_or.ok()) {
-    auto msg = absl::StrFormat(
-        "Failed to determine best cudnn convolution algorithm for:\n%s\n\n"
-        "Original error: %s",
-        instr->ToString(), best_algo_or.status().ToString());
-
-    if (strict) {
-      return Unknown(
-          "%s\n\nTo ignore this failure and try to use a fallback algorithm "
-          "(which may have suboptimal performance), use "
-          "XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please "
-          "also file a bug for the root cause of failing autotuning.",
-          msg);
-    }
-    LOG(WARNING)
-        << msg << "\n\nAs a result, convolution performance may be suboptimal.";
-    return false;
-  }
-
-  auto best_algo = std::move(best_algo_or).value();
-  VLOG(3) << "Setting cudnn conv to use algorithm "
-          << best_algo.conv().algorithm() << " and "
-          << NumBytesToString(best_algo.scratch_bytes())
-          << " of scratch memory: " << instr->ToString()
-          << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled();
-
-  // Replace instr with a new CustomCall which has the correct algorithm, and
-  // whose output shape has the appropriate amount of scratch memory.
-  HloComputation* computation = instr->parent();
-  std::vector<Shape> new_call_element_shapes;
-  // Add the shapes of the outputs of the convolution.
-  new_call_element_shapes.reserve(instr->shape().tuple_shapes_size() - 1);
-  for (int i = 0; i < instr->shape().tuple_shapes_size() - 1; ++i) {
-    new_call_element_shapes.emplace_back(instr->shape().tuple_shapes(i));
-  }
-  // The final element is the size of the workspace.
-  new_call_element_shapes.emplace_back(
-      ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes()}));
-  Shape new_call_shape = ShapeUtil::MakeTupleShape(new_call_element_shapes);
-
-  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
-                      instr->backend_config<GpuBackendConfig>());
-  CudnnConvBackendConfig& backend_config =
-      *gpu_backend_config.mutable_cudnn_conv_backend_config();
-  *backend_config.mutable_algorithm() = best_algo.algorithm();
-  backend_config.mutable_algorithm()->mutable_workspace_size()->set_value(
-      best_algo.scratch_bytes());
-
-  HloInstruction* new_call = computation->AddInstruction(
-      instr->CloneWithNewOperands(new_call_shape, instr->operands()));
-
-  // Preserve the name of the old instruction.  This is safe because we're going
-  // to remove the old one anyway, and it makes it easier to trace how our conv
-  // is transformed through all our passes.
-  new_call->SetAndSanitizeName(instr->name());
-
-  VLOG(3) << "Replacing convolution " << instr->ToString() << " with "
-          << new_call->ToString();
-
-  TF_RETURN_IF_ERROR(new_call->set_backend_config(gpu_backend_config));
-
-  std::vector<HloInstruction*> new_tuple_elements;
-  new_tuple_elements.reserve(new_call->shape().tuple_shapes_size() - 1);
-  for (int i = 0; i < new_call->shape().tuple_shapes_size() - 1; ++i) {
-    new_tuple_elements.emplace_back(
-        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
-            new_call->shape().tuple_shapes(i), new_call, i)));
-  }
-  new_tuple_elements.emplace_back(computation->AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint8_t>({}))));
-
-  // Repackage new_call so it has the same shape as the original call, namely
-  // (conv_result, u8[0]).
-  HloInstruction* new_tuple = computation->AddInstruction(
-      HloInstruction::CreateTuple(new_tuple_elements));
-
-  TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple));
-  return true;
-}
-
-absl::StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
-    HloComputation* computation) {
-  std::vector<HloInstruction*> convs;
-  for (HloInstruction* instr : computation->instructions()) {
-    if (IsCandidate(instr)) {
-      convs.push_back(instr);
-    }
-  }
-
-  bool changed = false;
-  for (HloInstruction* instr : convs) {
-    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr));
-    changed |= result;
-  }
-  return changed;
-}
-
-absl::StatusOr<bool> GpuConvAlgorithmPicker::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  XLA_SCOPED_LOGGING_TIMER(
-      absl::StrCat("GpuConvAlgorithmPicker for ", module->name()));
-
-  if (!IsEnabled(module)) {
-    VLOG(3) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
-               "returning early.";
-    return false;
-  }
-
-  bool changed = false;
-  for (HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
-    changed |= result;
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker.h b/third_party/xla/xla/service/gpu/conv_algorithm_picker.h
deleted file mode 100644
index e6dea8b..0000000
--- a/third_party/xla/xla/service/gpu/conv_algorithm_picker.h
+++ /dev/null
@@ -1,159 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
-#define XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
-
-#include <optional>
-#include <string>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/autotune_results.pb.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/gpu_conv_runner.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/stream_executor.h"
-
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#endif
-
-namespace xla {
-namespace gpu {
-
-// Choose the fastest algorithm for each conv.
-// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
-// each and adding explicit scratch space to the CustomCalls.
-//
-// We pick the algorithm before fusion so that we can generate better HLO. After
-// GpuConvRewriter, our convolutions are CustomCalls which return a
-// tuple (conv_result, scratch_memory), and each conv uses 0 bytes of
-// scratch:
-//
-//   customcall = (f32[...], f32[0])
-//   return gte(customcall, 0)
-//
-// The algorithm picker then chooses the best algorithm, and potentially
-// increases the scratch space.  It replaces customcall with new_tuple,
-// giving us the following:
-//
-//   new_customcall = (f32[...], f32[N])
-//   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
-//   return gte(new_tuple, 0)
-//
-// The new tuple and gte instructions can be simplified away, because
-// nobody is expected to use the scratch value.
-//
-// However, if we were to run GpuConvAlgorithmPicker after fusion,
-// the gte(customcall, 0) would probably already be fused into a fusion node.  We
-// can't simplify across HloComputation boundaries, so in this case we
-// wouldn't be able to simplify away the new_tuple bits.
-//
-// It supports two modes: device and deviceless.
-// In device mode, we run autotuning on the device and store autotune results.
-//
-// In deviceless mode, we pass in some information related to the device and
-// use stored autotune results to rewrite convolutions. If the required autotune
-// result is not stored, then the performance of convolution will be suboptimal.
-class GpuConvAlgorithmPicker : public HloModulePass {
- public:
-  explicit GpuConvAlgorithmPicker(AutotuneConfig config) : config_(config) {}
-
-  absl::string_view name() const override {
-    return "gpu-conv-algorithm-picker";
-  }
-
-  static bool IsEnabled(const HloModule* module) {
-    return module->config().debug_options().xla_gpu_autotune_level() != 0;
-  }
-
-  static bool IsCandidate(const HloInstruction* instr) {
-    return IsCustomCallToDnnConvolution(*instr);
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
-  absl::StatusOr<bool> RunOnInstruction(HloInstruction* instr);
-
-  absl::StatusOr<AutotuneResult> PickBestAlgorithm(
-      const HloCustomCallInstruction* instr);
-  absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCache(
-      const HloCustomCallInstruction* instr);
-
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
-  // Simple bundle of an algorithm and its output, for comparing results across
-  // autotuned algorithms.
-  struct ReferenceResult {
-    stream_executor::dnn::AlgorithmDesc algorithm;
-    std::vector<stream_executor::DeviceMemoryBase> buffers;
-  };
-
-  // Execution environment for autotuning. Runtime autotuning requires runtime
-  // information such as input/output buffers in order to run. It can be
-  // constructed from the autotuned instruction by FromInstruction.
-  struct AutotuneRuntimeArguments {
-    const HloModuleConfig hlo_module_config;
-    RedzoneBuffers rz_buffers;
-    const GpuConvConfig gpu_conv_config;
-    std::optional<std::string> canonical_hlo;
-
-    static absl::StatusOr<AutotuneRuntimeArguments> FromInstruction(
-        const HloCustomCallInstruction* instr, const AutotuneConfig& config,
-        const DebugOptions& debug_options);
-  };
-
-  absl::StatusOr<AutotuneResult> AutotuneOneConvRunner(
-      GenericConvRunner* runner,
-      std::optional<ReferenceResult>* reference_result,
-      absl::Span<const stream_executor::dnn::AlgorithmDesc> disabled_algos,
-      std::optional<AutotuneCacheKey> instruction_info,
-      const AutotuneRuntimeArguments& runtime_arguments);
-
-  // Pick the best algorithm for CUDA platform.
-  absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCacheCuda(
-      const HloCustomCallInstruction* instr);
-#endif
-
-  absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCacheRocm(
-      const HloCustomCallInstruction* instr);
-
- private:
-  AutotuneConfig config_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-#endif  // XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
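
Condensed from the deleted conv_algorithm_picker_test.cc below, a typical device-mode invocation of this pass looks roughly as follows. PickConvAlgorithmsForModule is a hypothetical wrapper introduced only for illustration; the types and calls it uses are the ones that appear in the deleted sources.

#include "xla/debug_options_flags.h"
#include "xla/service/gpu/autotuner_util.h"
#include "xla/service/gpu/conv_algorithm_picker.h"

// Hypothetical helper: run GpuConvAlgorithmPicker over a module using a live
// StreamExecutor, mirroring what the deleted test does via RunHloPass.
absl::StatusOr<bool> PickConvAlgorithmsForModule(
    xla::HloModule* module, stream_executor::StreamExecutor* stream_exec) {
  xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
  xla::gpu::AutotuneConfig cfg{
      xla::gpu::DeviceConfig{stream_exec, /*allocator=*/nullptr}, opts};
  xla::gpu::GpuConvAlgorithmPicker picker(cfg);
  // Returns true if any convolution custom-call was rewritten.
  return picker.Run(module, /*execution_threads=*/{});
}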
diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc
deleted file mode 100644
index aa7c5e2..0000000
--- a/third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc
+++ /dev/null
@@ -1,129 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/conv_algorithm_picker.h"
-
-#include <cstdint>
-#include <variant>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "xla/debug_options_flags.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/service/platform_util.h"
-#include "xla/service/tuple_simplifier.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/platform.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla::gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class GpuConvAlgorithmPickerTest : public HloTestBase {
- public:
-  GpuConvAlgorithmPickerTest() { AutotunerUtil::ClearAutotuneResults(); }
-};
-
-TEST_F(GpuConvAlgorithmPickerTest, SetAlgorithm) {
-  constexpr absl::string_view kHlo = R"(
-HloModule module
-
-ENTRY main {
-  %arg0 = f32[3,56,56,16]{2,1,0,3} parameter(0)
-  %arg1 = f32[3,3,3,64]{2,1,0,3} parameter(1)
-  ROOT %conv = f32[54,54,16,64]{1,0,3,2} convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kHlo));
-
-  se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
-  TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
-                          PlatformUtil::GetStreamExecutors(platform));
-  ASSERT_GT(executors.size(), 0);
-  se::StreamExecutor* stream_exec = executors[0];
-
-  const se::GpuComputeCapability& cc = backend()
-                                           .default_stream_executor()
-                                           ->GetDeviceDescription()
-                                           .gpu_compute_capability();
-  bool changed = false;
-  TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvRewriter(cc), m.get()));
-  changed = false;
-  DebugOptions opts = DefaultDebugOptionsIgnoringFlags();
-
-  AutotuneConfig cfg{DeviceConfig{stream_exec, nullptr}, opts};
-  TF_ASSERT_OK_AND_ASSIGN(changed,
-                          RunHloPass(GpuConvAlgorithmPicker(cfg), m.get()));
-  ASSERT_TRUE(changed);
-
-  AutotuneResults results;
-  TF_ASSERT_OK(AutotunerUtil::SerializeAutotuneResults(&results));
-  ASSERT_EQ(results.results_size(), 1);
-  auto& result = *results.mutable_results(0)->mutable_result();
-  int64_t old_scratch_bytes = result.scratch_bytes();
-  int64_t new_scratch_bytes = old_scratch_bytes + 1;
-  result.set_scratch_bytes(new_scratch_bytes);
-
-  AutotunerUtil::ClearAutotuneResults();
-  TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results));
-
-  // Now send the same module through GpuConvAlgorithmPicker again.  The conv
-  // should have the new scratch bytes.
-  TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo));
-  changed = false;
-  TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvRewriter(cc), m.get()));
-  changed = false;
-  TF_ASSERT_OK_AND_ASSIGN(changed,
-                          RunHloPass(GpuConvAlgorithmPicker(cfg), m.get()));
-  ASSERT_TRUE(changed);
-
-  // TupleSimplifier cleans this up a bit before we pattern-match
-  TF_ASSERT_OK(RunHloPass(TupleSimplifier(), m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  HloInstruction* conv;
-  ASSERT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::GetTupleElement(m::CustomCall(&conv))));
-  EXPECT_THAT(
-      conv->shape(),
-      GmockMatch(m::Shape().WithSubshape(
-          {1}, m::Shape().WithElementType(U8).WithDims({new_scratch_bytes}))));
-
-  // Algorithm 14 is disabled for cuDNN 9 on V100
-  TF_ASSERT_OK_AND_ASSIGN(auto dnn_version, GetDnnVersionInfo(stream_exec));
-  if (dnn_version.major_version() >= 9 && dnn_version.major_version() < 10 &&
-      std::holds_alternative<stream_executor::CudaComputeCapability>(cc) &&
-      std::get<stream_executor::CudaComputeCapability>(cc).major == 7 &&
-      std::get<stream_executor::CudaComputeCapability>(cc).minor == 0) {
-    EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
-                    ->has_cudnn_conv_backend_config() &&
-                conv->backend_config<GpuBackendConfig>()
-                        ->cudnn_conv_backend_config()
-                        .algorithm()
-                        .algo_id() != 14);
-  }
-}
-
-}  // namespace
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/copy_fusion.cc b/third_party/xla/xla/service/gpu/copy_fusion.cc
deleted file mode 100644
index a83354c..0000000
--- a/third_party/xla/xla/service/gpu/copy_fusion.cc
+++ /dev/null
@@ -1,197 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/copy_fusion.h"
-
-#include <cstdint>
-#include <queue>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/reduction_utils.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-
-namespace xla {
-namespace gpu {
-
-bool OnlyElementwiseOpsReachableFromParams(HloComputation* fused_computation) {
-  std::queue<const HloInstruction*> q;
-  absl::flat_hash_set<const HloInstruction*> visited;
-  for (auto param : fused_computation->parameter_instructions()) {
-    q.push(param);
-    visited.insert(param);
-  }
-  while (!q.empty()) {
-    const HloInstruction* hlo = q.front();
-    q.pop();
-    for (auto user : hlo->users()) {
-      if ((!user->IsElementwiseOnOperand(user->operand_index(hlo)) ||
-           user->opcode() == HloOpcode::kCopy) &&
-          user->opcode() != HloOpcode::kBitcast &&
-          user->opcode() != HloOpcode::kTuple) {
-        return false;
-      }
-      if (visited.insert(user).second) {
-        q.push(user);
-      }
-    }
-  }
-  return true;
-}
-
-absl::StatusOr<bool> CopyFusion::DoCopyFusion(HloComputation* computation) {
-  bool changed = false;
-  std::vector<HloInstruction*> defs_before_uses =
-      computation->MakeInstructionPostOrder();
-
-  for (HloInstruction* hlo : defs_before_uses) {
-    if (hlo->opcode() != HloOpcode::kFusion) {
-      continue;
-    }
-    std::vector<HloInstruction*> copies;
-    std::vector<HloInstruction*> other_users;
-    HloComputation* fused_computation = hlo->fused_instructions_computation();
-    if (!OnlyElementwiseOpsReachableFromParams(fused_computation)) {
-      continue;
-    }
-    HloInstruction* root = fused_computation->root_instruction();
-    if (IsReductionFromOrToContiguousDimensions(*root) ||
-        root->opcode() == HloOpcode::kScatter ||
-        (hlo->IsMultiOutputFusion() &&
-         absl::c_all_of(root->operands(), [](const HloInstruction* slice) {
-           return slice->opcode() == HloOpcode::kSlice;
-         }))) {
-      continue;
-    }
-    for (auto user : hlo->users()) {
-      HloInstruction* copy_user = user;
-      // Skip get-tuple-element ops.
-      if (copy_user->opcode() == HloOpcode::kGetTupleElement &&
-          copy_user->user_count() == 1) {
-        if (IsReductionFromOrToContiguousDimensions(
-                *(root->operand(copy_user->tuple_index())))) {
-          other_users.push_back(user);
-          continue;
-        }
-        copy_user = copy_user->users()[0];
-      }
-      // Skip bitcast ops.
-      if (copy_user->opcode() == HloOpcode::kBitcast &&
-          copy_user->user_count() == 1) {
-        copy_user = copy_user->users()[0];
-      }
-      if (copy_user->opcode() == HloOpcode::kCopy &&
-          copy_user->shape() == copy_user->operand(0)->shape() &&
-          !copy_user->shape().IsTuple() &&
-          !copy_user->HasControlDependencies()) {
-        copies.push_back(copy_user);
-      } else {
-        other_users.push_back(user);
-      }
-    }
-    if (copies.empty()) {
-      continue;
-    }
-    auto fusion_adaptor = HloFusionAdaptor::ForComputation(fused_computation);
-    auto dynamic_update_slices =
-        GetOutputDefiningDynamicUpdateSlices(fusion_adaptor->GetRoots());
-    // Skip dynamic update slice fusions which might be emitted in-place.
-    if (!dynamic_update_slices.empty() &&
-        (root->opcode() != HloOpcode::kTuple ||
-         dynamic_update_slices.size() == root->shape().tuple_shapes_size())) {
-      continue;
-    }
-    changed = true;
-
-    HloInstruction::InstructionVector tuple_elements;
-    int64_t num_outputs =
-        hlo->IsMultiOutputFusion() ? root->operand_count() : int64_t{1};
-    tuple_elements.reserve(copies.size() + num_outputs);
-    if (hlo->IsMultiOutputFusion()) {
-      for (HloInstruction* operand : root->operands()) {
-        tuple_elements.push_back(operand);
-      }
-    } else {
-      tuple_elements.push_back(root);
-    }
-
-    for (auto copy : copies) {
-      HloInstruction* user = copy;
-      std::vector<HloInstruction*> operand_chain;
-      operand_chain.push_back(user);
-      while (user->operand(0) != hlo) {
-        user = user->mutable_operand(0);
-        operand_chain.push_back(user);
-      }
-      HloInstruction* clone_operand = root;
-      if (hlo->IsMultiOutputFusion()) {
-        clone_operand = root->mutable_operand(user->tuple_index());
-        CHECK_EQ(operand_chain.back()->opcode(), HloOpcode::kGetTupleElement);
-        operand_chain.pop_back();
-      }
-      for (int64_t i = operand_chain.size() - 1; i >= 0; --i) {
-        HloInstruction* user = operand_chain[i];
-        clone_operand = fused_computation->AddInstruction(
-            user->CloneWithNewOperands(user->shape(), {clone_operand}));
-      }
-      tuple_elements.push_back(clone_operand);
-    }
-
-    HloInstruction* new_root = fused_computation->AddInstruction(
-        HloInstruction::CreateTuple(tuple_elements));
-    fused_computation->set_root_instruction(new_root,
-                                            /*accept_different_shape=*/true);
-    *hlo->mutable_shape() = new_root->shape();
-
-    if (root->opcode() == HloOpcode::kTuple) {
-      TF_RETURN_IF_ERROR(fused_computation->RemoveInstruction(root));
-    } else {
-      auto get_tuple_element_root = computation->AddInstruction(
-          HloInstruction::CreateGetTupleElement(hlo, 0));
-      TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWithDifferentShape(
-          other_users, get_tuple_element_root));
-    }
-    for (int64_t i = 0; i < copies.size(); ++i) {
-      auto get_tuple_element = computation->AddInstruction(
-          HloInstruction::CreateGetTupleElement(hlo, num_outputs + i));
-      TF_RETURN_IF_ERROR(
-          computation->ReplaceInstruction(copies[i], get_tuple_element));
-    }
-  }
-  return changed;
-}
-
-absl::StatusOr<bool> CopyFusion::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  // Only for the entry computation can we be sure that the copies do not share
-  // a buffer with a parameter of the fusion they will be fused into. For
-  // example, while-loop computations have tuple parameters that need to share
-  // buffers with the output tuples, and copies inserted by the CopyInsertion
-  // pass will share a buffer with the tuple output (and thus with the tuple
-  // input as well).
-  return DoCopyFusion(module->entry_computation());
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/copy_fusion.h b/third_party/xla/xla/service/gpu/copy_fusion.h
deleted file mode 100644
index 973b671..0000000
--- a/third_party/xla/xla/service/gpu/copy_fusion.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_COPY_FUSION_H_
-#define XLA_SERVICE_GPU_COPY_FUSION_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// CopyFusion checks if a fusion is followed by multiple copies and if so, adds
-// those copies to the fusion, replacing the copies with get_tuple_elements.
-class CopyFusion : public HloModulePass {
- public:
-  CopyFusion() = default;
-
-  absl::string_view name() const override { return "copy_fusion"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  absl::StatusOr<bool> DoCopyFusion(HloComputation* computation);
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_COPY_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/copy_fusion_test.cc b/third_party/xla/xla/service/gpu/copy_fusion_test.cc
deleted file mode 100644
index d2116eb..0000000
--- a/third_party/xla/xla/service/gpu/copy_fusion_test.cc
+++ /dev/null
@@ -1,500 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/copy_fusion.h"
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/str_cat.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-
-namespace m = ::xla::match;
-
-class CopyFusionTest : public HloTestBase {
- public:
-  CopyFusion cf_;
-};
-
-const char kModulePrefix[] = R"(
-    HloModule test_module
-
-    scalar_add_computation {
-      scalar_lhs.0 = f32[] parameter(0)
-      scalar_rhs.0 = f32[] parameter(1)
-      ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0)
-    }
-    scalar_mul_computation {
-      scalar_lhs.1 = f32[] parameter(0)
-      scalar_rhs.1 = f32[] parameter(1)
-      ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
-    })";
-
-TEST_F(CopyFusionTest, CopyFusionTransposeOfBroadcastedConstantTwoCopies) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      two = f32[] constant(2.0)
-      broadcast = f32[16,32]{1,0} broadcast(two), dimensions={}
-      s.1 = f32[16,32]{1,0} sqrt(broadcast)
-      ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
-    }
-
-    ENTRY main {
-      fusion = f32[32,16]{1,0} fusion(), kind=kInput, calls=fused_computation
-      copy.1 = f32[32,16]{1,0} copy(fusion)
-      copy.2 = f32[32,16]{1,0} copy(fusion)
-      ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(copy.2, copy.1)
-    })"))
-                    .value();
-  ASSERT_TRUE(cf_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                        m::GetTupleElement())));
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Transpose(), m::Copy(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionTransposeTwoCopies) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      param_0.1 = f32[16,32]{1,0} parameter(0)
-      s.1 = f32[16,32]{1,0} sqrt(param_0.1)
-      ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
-    }
-
-    ENTRY main {
-      p = f32[16,32]{1,0} parameter(0)
-      fusion = f32[32,16]{1,0} fusion(p), kind=kInput, calls=fused_computation
-      copy.1 = f32[32,16]{1,0} copy(fusion)
-      copy.2 = f32[32,16]{1,0} copy(fusion)
-      ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(copy.2, copy.1)
-    })"))
-                    .value();
-  ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionNegateAndTwoCopies) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
-      ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
-    }
-
-    ENTRY entry {
-      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
-      copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
-      copy.2 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
-      ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, copy.2)
-    })"))
-                    .value();
-  ASSERT_TRUE(cf_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                        m::GetTupleElement())));
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Negate(), m::Copy(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionShouldNotRunWithReduce) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
-      const.1 = f32[] parameter(0)
-      ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
-    }
-
-    ENTRY entry {
-      p0 = f32[] parameter(0)
-      p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
-      fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation
-      copy.1 = f32[512]{0} copy(fusion)
-      copy.2 = f32[512]{0} copy(fusion)
-      ROOT root = (f32[512]{0}, f32[512]{0}) tuple(copy.1, copy.2)
-    })"))
-                    .value();
-  ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionShouldRunWithUncopiedReduce) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      two = f32[] constant(2.0)
-      broadcast = f32[128,512,28,28]{3,2,1,0} broadcast(two)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(broadcast, broadcast)
-      const = f32[] constant(0.0)
-      reduce = f32[512]{0} reduce(mul, const), dimensions={0,2,3}, to_apply=scalar_add_computation
-      ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[512]{0}) tuple(mul, reduce)
-    }
-
-    ENTRY entry {
-      fusion = (f32[128,512,28,28]{3,2,1,0}, f32[512]) fusion(), kind=kInput, calls=fused_computation
-      gte = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=0
-      gte.2 = f32[512]{0} get-tuple-element(fusion), index=1
-      copy.1 = f32[128,512,28,28]{3,2,1,0} copy(gte)
-      ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512]{0}) tuple(copy.1, gte.2)
-    })"))
-                    .value();
-  ASSERT_TRUE(cf_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                        m::GetTupleElement())));
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Multiply(), m::Reduce(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionShouldNotFuseForSliceMultioutputFusion) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      p1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1)
-      slice1 = f32[128,100,28,28]{3,2,1,0} slice(mul), slice={[0:128],[0:100],[0:28],[0:28]}
-      slice2 = f32[128,200,28,28]{3,2,1,0} slice(mul), slice={[0:128],[50:250],[0:28],[0:28]}
-      ROOT tuple = (f32[128,100,28,28]{3,2,1,0}, f32[128,200,28,28]{3,2,1,0}) tuple(slice1, slice2)
-    }
-
-    ENTRY entry {
-      p1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      ROOT fusion = (f32[128,100,28,28]{3,2,1,0}, f32[128,200,28,28]{3,2,1,0}) fusion(p1), kind=kInput, calls=fused_computation
-    })"))
-                    .value();
-  ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionShouldNotRunWithScatter) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      p0 = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
-      scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
-      updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
-      input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} negate(p0)
-      ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(input_tensor, scatter_indices, updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=scalar_add_computation
-}
-
-    ENTRY entry {
-      param.0 = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
-      param.1 = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
-      param.2 = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
-      fusion = f32[50,49,48,47,46]{4,3,2,1,0} fusion(param.0, param.1, param.2), kind=kInput, calls=fused_computation
-      ROOT copy = f32[50,49,48,47,46]{4,3,2,1,0} copy(fusion)
-    })"))
-                    .value();
-  ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionShouldNotRunOutsideEntryComputation) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-fused_computation.549 {
-  param_0.8511 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(0)
-  bitcast.52601 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_0.8511)
-  slice = bf16[15,1,2,48,128,1]{5,4,3,2,1,0} slice(bitcast.52601), slice={[0:15:1], [0:1:1], [0:2:1], [0:48:1], [0:128:1], [0:1:1]}
-  bitcast = bf16[15,1,2,48,128]{4,3,2,1,0} bitcast(slice)
-  ROOT broadcast = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} broadcast(bitcast), dimensions={0,1,2,3,4}
-}
-
-condition {
-  constant_6915 = s32[] constant(15)
-  param.218 = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) parameter(0)
-  get-tuple-element.3714 = s32[] get-tuple-element(param.218), index=1
-  ROOT compare.1738 = pred[] compare(get-tuple-element.3714, constant_6915), direction=LT
-}
-
-body {
-  tuple_param = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) parameter(0)
-  param_0 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(tuple_param), index=0
-  param_1 = s32[] get-tuple-element(tuple_param), index=1
-  fusion.549 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} fusion(param_0), kind=kLoop, calls=fused_computation.549
-  bitcast = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} bitcast(fusion.549)
-  copy = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(bitcast)
-  constant_one = s32[] constant(1)
-  add = s32[] add(param_1, constant_one), control-predecessors={fusion.549}
-  ROOT tuple = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) tuple(copy, add)
-}
-
-ENTRY main {
-  param_0 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(0)
-  zero = s32[] constant(0)
-  copy.0 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(param_0)
-  copy.1 = s32[] copy(zero)
-  tuple = tuple(copy.0, copy.1)
-  ROOT while = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) while(tuple), condition=condition, body=body, backend_config="{\"known_trip_count\":{\"n\":\"15\"}}"
-})"))
-                    .value();
-  ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionShouldNotRunWithDynamicUpdateSliceInplace) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      p.0 = f16[50,96,1024]{2,1,0} parameter(0)
-      p.1 = f16[1,96,1024]{2,1,0} parameter(1)
-      c.0 = s32[3]{0} constant({0, 0, 0})
-      ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
-    }
-
-    ENTRY entry {
-      p0 = f16[50,96,1024]{2,1,0} parameter(0)
-      p1 = f16[1,96,1024]{2,1,0} parameter(1)
-      fusion = f16[50,96,1024]{2,1,0} fusion(p0, p1), kind=kInput, calls=fused_computation
-      copy.1 = f16[50,96,1024]{2,1,0} copy(fusion)
-      copy.2 = f16[50,96,1024]{2,1,0} copy(fusion)
-      ROOT root = (f16[50,96,1024]{2,1,0}, f16[50,96,1024]{2,1,0}) tuple(copy.1, copy.2)
-    })"))
-                    .value();
-  ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionWithDynamicUpdateSliceNotInplace) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      one = f32[] constant(1.0)
-      zero = f32[] constant(0.0)
-      p.0 = f16[50,96,1024]{2,1,0} broadcast(one), dimensions={}
-      p.1 = f16[1,96,1024]{2,1,0} broadcast(zero), dimensions={}
-      c.0 = s32[3]{0} constant({0, 0, 0})
-      dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
-      neg = f16[50,96,1024]{2,1,0} negate(dynamic-update-slice)
-      ROOT tuple = (f16[50,96,1024]{2,1,0}, f16[50,96,1024]{2,1,0}) tuple(dynamic-update-slice, neg)
-    }
-
-    ENTRY entry {
-      fusion = (f16[50,96,1024]{2,1,0}, f16[50,96,1024]{2,1,0}) fusion(), kind=kInput, calls=fused_computation
-      gte.0 = f16[50,96,1024]{2,1,0} get-tuple-element(fusion), index=0
-      gte.1 = f16[50,96,1024]{2,1,0} get-tuple-element(fusion), index=1
-      bitcast = f16[1,50,96,1024]{3,2,1,0} bitcast(gte.0)
-      copy = f16[1,50,96,1024]{3,2,1,0} copy(bitcast)
-      ROOT root = (f16[1,50,96,1024]{3,2,1,0}, f16[50,96,1024]{2,1,0}) tuple(copy, gte.1)
-    })"))
-                    .value();
-  ASSERT_TRUE(cf_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                        m::GetTupleElement())));
-  EXPECT_THAT(
-      fusion->fused_expression_root(),
-      GmockMatch(m::Tuple(m::DynamicUpdateSlice(), m::Negate(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionTransposeAndThreeCopies) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      two = f32[] constant(2.0)
-      param_0.1 = f32[16,32]{1,0} broadcast(two), dimensions={}
-      s.1 = f32[16,32]{1,0} sqrt(param_0.1)
-      ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
-    }
-
-    ENTRY entry {
-      fusion = f32[32,16]{1,0} fusion(), kind=kInput, calls=fused_computation
-      copy.1 = f32[32,16]{1,0} copy(fusion)
-      copy.2 = f32[32,16]{1,0} copy(fusion)
-      copy.3 = f32[32,16]{1,0} copy(fusion)
-      ROOT root = (f32[32,16]{1,0}, f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(copy.1, copy.2, copy.3)
-    })"))
-                    .value();
-  ASSERT_TRUE(cf_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root,
-              GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                  m::GetTupleElement(), m::GetTupleElement())));
-  EXPECT_THAT(
-      fusion->fused_expression_root(),
-      GmockMatch(m::Tuple(m::Transpose(), m::Copy(), m::Copy(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionRunWithOnlyOneCopy) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
-      ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
-    }
-
-    ENTRY entry {
-      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
-      ROOT copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
-    })"))
-                    .value();
-  ASSERT_TRUE(cf_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::GetTupleElement(m::Fusion(&fusion))));
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Negate(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionNegateAndTwoCopiesAndTransposeCopy) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
-      ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
-    }
-
-    ENTRY entry {
-      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
-      copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
-      transpose = f32[128,512,28,28]{2,3,0,1} copy(fusion)
-      bitcast = f32[512,128,28,28]{3,2,1,0} bitcast(transpose)
-      copy.2 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
-      ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, bitcast, copy.2)
-    })"))
-                    .value();
-  ASSERT_TRUE(cf_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                        m::Bitcast(), m::GetTupleElement())));
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Negate(), m::Copy(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionRunWithOnlyOneNonTransposeCopy) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
-      ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
-    }
-
-    ENTRY entry {
-      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
-      copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
-      transpose.1 = f32[128,512,28,28]{2,3,0,1} copy(fusion)
-      bitcast.1 = f32[512,128,28,28]{3,2,1,0} bitcast(transpose.1)
-      transpose.2 = f32[128,512,28,28]{2,3,0,1} copy(fusion)
-      bitcast.2 = f32[512,128,28,28]{3,2,1,0} bitcast(transpose.2)
-      ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}) tuple(copy.1, bitcast.1, bitcast.2)
-    })"))
-                    .value();
-  ASSERT_TRUE(cf_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                        m::Bitcast(), m::Bitcast())));
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Negate(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionSkipTupleCopies) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
-      neg.1 = f32[128,512,28,28]{3,2,1,0} negate(mul)
-      neg.2 = f32[128,512,28,28]{3,2,1,0} negate(mul)
-      ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(neg.1, neg.2)
-    }
-
-    ENTRY entry {
-      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      fusion = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_computation
-      copy.1 = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) copy(fusion)
-      copy.2 = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) copy(fusion)
-      ROOT root = ((f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}),(f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0})) tuple(copy.1, copy.2)
-    })"))
-                    .value();
-  ASSERT_FALSE(cf_.Run(module.get()).value());
-}
-
-TEST_F(CopyFusionTest, CopyFusionTupleAndGetTuple) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
-      neg.1 = f32[128,512,28,28]{3,2,1,0} negate(mul)
-      neg.2 = f32[128,512,28,28]{3,2,1,0} negate(mul)
-      ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(neg.1, neg.2)
-    }
-
-    ENTRY entry {
-      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      fusion = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_computation
-      gte.1 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=0
-      gte.2 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=1
-      copy.1 = f32[128,512,28,28]{3,2,1,0} copy(gte.1)
-      copy.2 = f32[128,512,28,28]{3,2,1,0} copy(gte.2)
-      ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, copy.2)
-    })"))
-                    .value();
-  ASSERT_TRUE(cf_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                        m::GetTupleElement())));
-  EXPECT_THAT(
-      fusion->fused_expression_root(),
-      GmockMatch(m::Tuple(m::Negate(), m::Negate(), m::Copy(), m::Copy())));
-}
-
-TEST_F(CopyFusionTest, CopyFusionWithFusionReturningTupleAndOtherUser) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
-      neg.1 = f32[128,512,28,28]{3,2,1,0} negate(mul)
-      neg.2 = f32[128,512,28,28]{3,2,1,0} negate(mul)
-      ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(neg.1, neg.2)
-    }
-
-    ENTRY entry {
-      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      fusion = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_computation
-      gte.1 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=0
-      gte.2 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=1
-      copy.1 = f32[128,512,28,28]{3,2,1,0} copy(gte.1)
-      copy.2 = f32[128,512,28,28]{3,2,1,0} copy(gte.2)
-      transpose = f32[128,512,28,28]{2,3,0,1} copy(gte.1)
-      bitcast = f32[512,128,28,28]{3,2,1,0} bitcast(transpose)
-      ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, bitcast, copy.2)
-    })"))
-                    .value();
-  ASSERT_TRUE(cf_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root,
-              GmockMatch(m::Tuple(m::Copy(), m::Bitcast(),
-                                  m::GetTupleElement(m::Fusion(&fusion)))));
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Negate(), m::Negate(), m::Copy())));
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc b/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc
deleted file mode 100644
index f0da0e5..0000000
--- a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc
+++ /dev/null
@@ -1,210 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cublas_pad_for_gemms.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/fusions/triton/triton_support.h"
-#include "xla/service/gpu/gemm_fusion.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-static absl::StatusOr<bool> PadForGemm(HloDotInstruction* dot,
-                                       PrimitiveType datatype,
-                                       int pad_to_multiple_of) {
-  auto* lhs = dot->mutable_operand(0);
-  auto* rhs = dot->mutable_operand(1);
-
-  Shape lshape = lhs->shape();
-  Shape rshape = rhs->shape();
-  Shape result_shape = dot->shape();
-
-  if (lshape.element_type() != datatype || rshape.element_type() != datatype) {
-    return false;
-  }
-
-  auto pad_dim = [&](Shape& s, int dim) {
-    s.set_dimensions(dim,
-                     RoundUpTo<int64_t>(s.dimensions(dim), pad_to_multiple_of));
-  };
-
-  auto pad_matrix_dims = [&pad_dim](Shape s) {
-    // Since the dot instruction is canonicalized, the last two dimensions for
-    // each operand represent non-batch dimensions, and the others are the same
-    // for both operands and correspond to batch dimensions.
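-    // For instance (illustrative names, not from the HLO above): a canonical
-    // lhs of shape f32[B0,B1,M,K] has its M (rank-2) and K (rank-1)
-    // dimensions padded, while the batch dimensions B0 and B1 are left
-    // untouched.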
-    pad_dim(s, s.rank() - 2);
-    pad_dim(s, s.rank() - 1);
-    return s;
-  };
-
-  Shape new_lshape = pad_matrix_dims(lshape);
-  Shape new_rshape = pad_matrix_dims(rshape);
-  Shape new_result_shape = pad_matrix_dims(result_shape);
-
-  if (new_lshape == lshape && new_rshape == rshape) {
-    return false;
-  }
-
-  VLOG(3) << "old shape: " << lshape << " " << rshape << " " << result_shape;
-  VLOG(3) << "new shape: " << new_lshape << " " << new_rshape << " "
-          << new_result_shape;
-
-  auto create_padding_config = [](Shape& shape, Shape& new_shape) {
-    PaddingConfig padding_config;
-    for (int i = 0; i < shape.rank(); ++i) {
-      auto dimension = padding_config.add_dimensions();
-      dimension->set_edge_padding_high(new_shape.dimensions()[i] -
-                                       shape.dimensions()[i]);
-      dimension->set_edge_padding_low(0);
-      dimension->set_interior_padding(0);
-    }
-    return padding_config;
-  };
-
-  auto l_padding_config = create_padding_config(lshape, new_lshape);
-  auto r_padding_config = create_padding_config(rshape, new_rshape);
-
-  HloComputation* parent = dot->parent();
-
-  HloInstruction* zero_float = parent->AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::Zero(datatype)));
-  zero_float->set_metadata(dot->metadata());
-
-  HloInstruction* lpad = parent->AddInstruction(
-      HloInstruction::CreatePad(new_lshape, lhs, zero_float, l_padding_config));
-  lpad->set_metadata(dot->metadata());
-
-  HloInstruction* rpad = parent->AddInstruction(
-      HloInstruction::CreatePad(new_rshape, rhs, zero_float, r_padding_config));
-  rpad->set_metadata(dot->metadata());
-
-  HloInstruction* new_dot = parent->AddInstruction(
-      dot->CloneWithNewOperands(new_result_shape, {lpad, rpad}));
-
-  std::vector<int64_t> start_indices(result_shape.rank(), 0);
-  std::vector<int64_t> strides(result_shape.rank(), 1);
-  HloInstruction* slice = parent->AddInstruction(
-      HloInstruction::CreateSlice(result_shape, new_dot, start_indices,
-                                  result_shape.dimensions(), strides));
-  slice->set_metadata(dot->metadata());
-
-  bool is_root = dot->user_count() == 0;
-
-  TF_CHECK_OK(parent->ReplaceInstruction(dot, slice));
-
-  if (is_root) {
-    parent->set_root_instruction(slice);
-  }
-
-  return true;
-}
-
-namespace {
-
-// We need this check because PadForGemm works under the assumption that
-// the dot instruction is canonicalized.
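-// As an illustration (shapes borrowed from the tests for this pass, not
-// exhaustive): a dot with operands f16[3,5,2048,1024] and f16[3,5,1024,33708],
-// lhs_batch_dims={0,1} and rhs_batch_dims={0,1}, is canonical, because every
-// dimension except the last two of each operand is a batch dimension numbered
-// 0,1,... in order. A dot with lhs_batch_dims={0} and
-// lhs_contracting_dims={3,1} is not canonical and is skipped by this pass.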
-bool CheckCanonical(HloDotInstruction* dot) {
-  const auto& dimension_numbers = dot->dot_dimension_numbers();
-
-  if (dimension_numbers.lhs_batch_dimensions_size() + 2 !=
-          dot->operand(0)->shape().rank() ||
-      dimension_numbers.rhs_batch_dimensions_size() + 2 !=
-          dot->operand(1)->shape().rank()) {
-    VLOG(2)
-        << dot->ToString()
-        << " is not canonical: Expected all dimensions but 2 to be "
-           "batch_dimensions. Hence, this dot is not a candidate for padding.";
-    return false;
-  }
-
-  std::vector<int64_t> canonical_batch_dims(
-      dimension_numbers.lhs_batch_dimensions_size());
-  absl::c_iota(canonical_batch_dims, 0);
-  if (!absl::c_equal(dimension_numbers.lhs_batch_dimensions(),
-                     canonical_batch_dims) ||
-      !absl::c_equal(dimension_numbers.rhs_batch_dimensions(),
-                     canonical_batch_dims)) {
-    VLOG(2)
-        << dot->ToString()
-        << " is not canonical: Expected batch dimensions to be all "
-           "dimensions except for the last 2 ones. Hence, this dot is not a "
-           "candidate for padding.";
-    return false;
-  }
-
-  return true;
-}
-
-}  // namespace
-
-static std::vector<HloDotInstruction*> GetRelevantDots(
-    const se::GpuComputeCapability& gpu_compute_capability,
-    HloComputation* comp, PrimitiveType datatype) {
-  std::vector<HloDotInstruction*> gemms;
-
-  for (HloInstruction* instr : comp->instructions()) {
-    if (IsMatrixMultiplication(*instr)) {
-      HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
-      if (instr->operand(0)->shape().element_type() == datatype &&
-          CheckCanonical(dot) &&
-          !(instr->GetModule()
-                ->config()
-                .debug_options()
-                .xla_gpu_enable_triton_gemm() &&
-            legacy_triton::IsTritonSupportedInstruction(
-                *dot, gpu_compute_capability) &&
-            ShouldTritonHandleGEMM(*dot, gpu_compute_capability))) {
-        gemms.push_back(dot);
-      }
-    }
-  }
-  return gemms;
-}
-
-absl::StatusOr<bool> CublasPadForGemms::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    for (HloDotInstruction* dot :
-         GetRelevantDots(gpu_compute_capability_, comp, datatype_)) {
-      TF_ASSIGN_OR_RETURN(bool result,
-                          PadForGemm(dot, datatype_, pad_to_multiple_of_));
-      changed |= result;
-    }
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.h b/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.h
deleted file mode 100644
index 2a1f9c6..0000000
--- a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.h
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUBLAS_PAD_FOR_GEMMS_H_
-#define XLA_SERVICE_GPU_CUBLAS_PAD_FOR_GEMMS_H_
-
-#include <cstdint>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// Adds padding to dot operations to make them run faster on GPUs.
-//
-// This can be used to pad f16 dots on tensor cores, or s8 dots to multiples of
-// four.
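-//
-// A rough sketch of the rewrite, using the shapes and pad_to_multiple_of = 8
-// from this pass's F16 tests (other datatypes and multiples are handled
-// analogously):
-//
-//   dot(f16[2048,1024], f16[1024,33708])
-//     -->
-//   slice(dot(pad(lhs) = f16[2048,1024],    (unchanged, already aligned)
-//             pad(rhs) = f16[1024,33712]))
-//
-// The padded dot produces f16[2048,33712], and the slice restores the
-// original result shape f16[2048,33708]. Padding uses zeros and only touches
-// the last two (non-batch) dimensions of each operand.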
-//
-// This pass depends on the xla::DotDecomposer pass,
-// so it should run strictly after it.
-class CublasPadForGemms : public HloModulePass {
- public:
-  CublasPadForGemms(const se::GpuComputeCapability gpu_compute_capability,
-                    PrimitiveType datatype, int32_t pad_to_multiple_of)
-      : gpu_compute_capability_(gpu_compute_capability),
-        datatype_(datatype),
-        pad_to_multiple_of_(pad_to_multiple_of) {}
-
-  absl::string_view name() const override { return "cublas-pad-for-gemms"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  const se::GpuComputeCapability gpu_compute_capability_;
-  PrimitiveType datatype_;
-  int32_t pad_to_multiple_of_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_CUBLAS_PAD_FOR_GEMMS_H_
diff --git a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms_test.cc b/third_party/xla/xla/service/gpu/cublas_pad_for_gemms_test.cc
deleted file mode 100644
index d20dd94..0000000
--- a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms_test.cc
+++ /dev/null
@@ -1,306 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cublas_pad_for_gemms.h"
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace m = ::xla::match;
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class CublasGemmPadForTensorCoresTest : public HloTestBase {
- protected:
-  bool PadForF16Gemms(HloModule* module) {
-    return CublasPadForGemms(se::CudaComputeCapability(7, 0),
-                             PrimitiveType::F16, 8)
-        .Run(module)
-        .value();
-  }
-
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
-    // Some pads would not be added if we detect that Triton will handle the
-    // given dot operation.
-    debug_options.set_xla_gpu_triton_gemm_any(false);
-    return debug_options;
-  }
-};
-
-TEST_F(CublasGemmPadForTensorCoresTest, OneDotRootComputation) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    %param1 = f16[2048,1024] parameter(0)
-    %param2 = f16[1024,33708] parameter(1)
-    ROOT %dot.2309 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1,
-                f16[1024,33708]{0,1} %param2),
-                lhs_contracting_dims={1}, rhs_contracting_dims={0}
-                })")
-                    .value();
-
-  EXPECT_TRUE(PadForF16Gemms(module.get()));
-  SCOPED_TRACE(module->ToString());
-
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(
-      root,
-      GmockMatch(
-          m::Slice(m::Dot(m::Pad(m::Parameter().WithShape(F16, {2048, 1024}),
-                                 m::Constant().WithShape(F16, {}))
-                              .WithShape(F16, {2048, 1024}),
-                          m::Pad(m::Parameter().WithShape(F16, {1024, 33708}),
-                                 m::Constant().WithShape(F16, {}))
-                              .WithShape(F16, {1024, 33712}))
-                       .WithShape(F16, {2048, 33712})
-                       .WithContractingDims(/*lhs_contracting_dims=*/{1},
-                                            /*rhs_contracting_dims=*/{0}))
-              .WithShape(F16, {2048, 33708})));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, OneDotS8RootComputation) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    %param1 = s8[2047,1023] parameter(0)
-    %param2 = s8[1023,33707] parameter(1)
-    ROOT %dot.2309 = s32[2047,33707]{1,0} dot(s8[2047,1023]{1,0} %param1,
-                s8[1023,33707]{0,1} %param2),
-                lhs_contracting_dims={1}, rhs_contracting_dims={0}
-                })")
-                    .value();
-
-  EXPECT_TRUE(
-      CublasPadForGemms(se::CudaComputeCapability(7, 0), PrimitiveType::S8, 4)
-          .Run(module.get())
-          .value());
-  SCOPED_TRACE(module->ToString());
-
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(
-      root,
-      GmockMatch(
-          m::Slice(m::Dot(m::Pad(m::Parameter().WithShape(S8, {2047, 1023}),
-                                 m::Constant().WithShape(S8, {}))
-                              .WithShape(S8, {2048, 1024}),
-                          m::Pad(m::Parameter().WithShape(S8, {1023, 33707}),
-                                 m::Constant().WithShape(S8, {}))
-                              .WithShape(S8, {1024, 33708}))
-                       .WithShape(S32, {2048, 33708})
-                       .WithContractingDims(/*lhs_contracting_dims=*/{1},
-                                            /*rhs_contracting_dims=*/{0}))
-              .WithShape(S32, {2047, 33707})));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, TwoDotsComputation) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    %param1 = f16[2048, 1024] parameter(0)
-    %param2 = f16[1024, 33708] parameter(1)
-    %param3 = f16[33708, 1] parameter(2)
-    %dot1 = f16[2048, 33708]{1,0} dot(f16[2048, 1024]{1,0} %param1,
-                f16[1024, 33708]{0,1} %param2),
-                lhs_contracting_dims={1}, rhs_contracting_dims={0}
-    ROOT %dot2 = f16[2048, 1]{1,0} dot(f16[2048, 33708]{1,0} %dot1,
-                f16[33708, 1]{0,1} %param3),
-                lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  })")
-                    .value();
-
-  EXPECT_TRUE(PadForF16Gemms(module.get()));
-  SCOPED_TRACE(module->ToString());
-
-  auto* root = module->entry_computation()->root_instruction();
-  const HloInstruction* dot2 = nullptr;
-  ASSERT_THAT(
-      root,
-      GmockMatch(
-          m::Slice(
-              m::Dot(
-                  m::Pad(m::Slice(m::Dot(&dot2,
-                                         m::Pad().WithShape(F16, {2048, 1024}),
-                                         m::Pad().WithShape(F16, {1024, 33712}))
-                                      .WithContractingDims(
-                                          /*lhs_contracting_dims=*/{1},
-                                          /*rhs_contracting_dims=*/{0})
-                                      .WithShape(F16, {2048, 33712}))
-                             .WithShape(F16, {2048, 33708}),
-                         m::Constant().WithShape(F16, {}))
-                      .WithShape(F16, {2048, 33712}),
-
-                  m::Pad(m::Parameter().WithShape(F16, {33708, 1}),
-                         m::Constant().WithShape(F16, {}))
-                      .WithShape(F16, {33712, 8}))
-                  .WithShape(F16, {2048, 8})
-                  .WithContractingDims(/*lhs_contracting_dims=*/{1},
-                                       /*rhs_contracting_dims=*/{0}))
-              .WithShape(F16, {2048, 1})));
-
-  EXPECT_THAT(
-      dot2,
-      GmockMatch(m::Dot(m::Pad(m::Parameter().WithShape(F16, {2048, 1024}),
-                               m::Constant().WithShape(F16, {}))
-                            .WithShape(F16, {2048, 1024}),
-                        m::Pad(m::Parameter().WithShape(F16, {1024, 33708}),
-                               m::Constant().WithShape(F16, {}))
-                            .WithShape(F16, {1024, 33712}))
-                     .WithContractingDims(/*lhs_contracting_dims=*/{1},
-                                          /*rhs_contracting_dims=*/{0})));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, DotWithBatchDimensions) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    %param1 = f16[3, 5, 2048, 1024] parameter(0)
-    %param2 = f16[3, 5, 1024, 33708] parameter(1)
-    ROOT %dot.2309 = f16[3, 5, 2048, 33708]{3, 2, 1,0} dot(f16[3, 5, 2048, 1024]{3, 2, 1,0} %param1,
-                f16[3, 5, 1024, 33708]{2, 3, 0,1} %param2), lhs_batch_dims={0, 1}, rhs_batch_dims={0, 1}, lhs_contracting_dims={3}, rhs_contracting_dims={2}})")
-                    .value();
-
-  EXPECT_TRUE(PadForF16Gemms(module.get()));
-  SCOPED_TRACE(module->ToString());
-
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(
-      root,
-      GmockMatch(
-          m::Slice(
-              m::Dot(m::Pad(m::Parameter().WithShape(F16, {3, 5, 2048, 1024}),
-                            m::Constant().WithShape(F16, {}))
-                         .WithShape(F16, {3, 5, 2048, 1024}),
-                     m::Pad(m::Parameter().WithShape(F16, {3, 5, 1024, 33708}),
-                            m::Constant().WithShape(F16, {}))
-                         .WithShape(F16, {3, 5, 1024, 33712}))
-                  .WithShape(F16, {3, 5, 2048, 33712})
-                  .WithContractingDims(/*lhs_contracting_dims=*/{3},
-                                       /*rhs_contracting_dims=*/{2}))
-              .WithShape(F16, {3, 5, 2048, 33708})));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, NoDotComputation) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    %x = f32[] parameter(0)
-    %y = f32[] parameter(1)
-    ROOT %maximum = f32[] maximum(f32[] %x, f32[] %y)
-  })")
-                    .value();
-
-  EXPECT_FALSE(PadForF16Gemms(module.get()));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, F32DotComputation) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    %param1 = f32[2048,1024] parameter(0)
-    %param2 = f32[1024,33708] parameter(1)
-    ROOT %dot.2309 = f32[2048,33708]{1,0} dot(f32[2048,1024]{1,0} %param1,
-                f32[1024,33708]{0,1} %param2),
-                lhs_contracting_dims={1}, rhs_contracting_dims={0}})")
-                    .value();
-
-  EXPECT_FALSE(PadForF16Gemms(module.get()));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, F64DotComputation) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    %param1 = f64[2048,1024] parameter(0)
-    %param2 = f64[1024,33708] parameter(1)
-    ROOT %dot.2309 = f64[2048,33708]{1,0} dot(f64[2048,1024]{1,0} %param1,
-                f64[1024,33708]{0,1} %param2),
-                lhs_contracting_dims={1}, rhs_contracting_dims={0}})")
-                    .value();
-
-  EXPECT_FALSE(PadForF16Gemms(module.get()));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, MultiplesOf8DotComputation) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    %param1 = f16[2048,1024] parameter(0)
-    %param2 = f16[1024,33712] parameter(1)
-    ROOT %dot.2309 = f16[2048,33712]{1,0} dot(f16[2048,1024]{1,0} %param1,
-                f16[1024,33712]{0,1} %param2),
-                lhs_contracting_dims={1}, rhs_contracting_dims={0}})")
-                    .value();
-
-  EXPECT_FALSE(PadForF16Gemms(module.get()));
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, CheckSavingMetadata) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    %param1 = f16[2048,1024] parameter(0)
-    %param2 = f16[1024,33708] parameter(1)
-    ROOT %dot.2309 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1,
-                f16[1024,33708]{0,1} %param2),
-                lhs_contracting_dims={1}, rhs_contracting_dims={0},
-                metadata={op_type="MatMul" op_name="transformer_v2/Transformer/decode/embedding_shared_weights_1/presoftmax_linear/MatMul"}
-                })")
-                    .value();
-
-  SCOPED_TRACE(module->ToString());
-
-  EXPECT_TRUE(PadForF16Gemms(module.get()));
-  auto metadata = module->entry_computation()->root_instruction()->metadata();
-  EXPECT_EQ("MatMul", metadata.op_type());
-  EXPECT_EQ(
-      "transformer_v2/Transformer/decode/embedding_shared_weights_1/"
-      "presoftmax_linear/MatMul",
-      metadata.op_name());
-}
-
-TEST_F(CublasGemmPadForTensorCoresTest, NotCanonicalizedDot) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    %param1 = f16[3, 5, 2048, 1024] parameter(0)
-    %param2 = f16[3, 5, 1024, 33708] parameter(1)
-    ROOT %dot.2309 = f16[3,2048, 33708]{2, 1, 0} dot(f16[3, 5, 2048, 1024]{3, 2, 1, 0} %param1, f16[3, 5, 1024, 33708]{3, 2, 1, 0} %param2), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={3, 1}, rhs_contracting_dims={2, 1}})")
-                    .value();
-
-  EXPECT_FALSE(PadForF16Gemms(module.get()));
-}
-
-}  // anonymous namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.cc
deleted file mode 100644
index e9cb21b..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.cc
+++ /dev/null
@@ -1,1566 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
-
-#include <algorithm>
-#include <array>
-#include <cstdint>
-#include <functional>
-#include <limits>
-#include <optional>
-#include <string>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
-#include "xla/comparison_util.h"
-#include "xla/debug_options_flags.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/literal.h"
-#include "xla/primitive_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/ml_dtypes.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = match;
-
-bool IsConvCustomCall(const HloInstruction* instr) {
-  return instr->opcode() == HloOpcode::kCustomCall &&
-         (instr->custom_call_target() == kCudnnConvForwardCallTarget ||
-          instr->custom_call_target() ==
-              kCudnnConvBiasActivationForwardCallTarget);
-}
-
-bool IsConvDepthwise(const HloInstruction* instr) {
-  int64_t feature_group_count = instr->feature_group_count();
-  if (feature_group_count == 1) {
-    return false;
-  }
-
-  const HloInstruction* input = instr->operand(0);
-  int64_t input_feature_dimension =
-      instr->convolution_dimension_numbers().input_feature_dimension();
-  int64_t input_feature_count =
-      input->shape().dimensions(input_feature_dimension);
-  return input_feature_count == feature_group_count;
-}
-
-// We don't want to upgrade depthwise convolutions to ConvBiasActivation,
-// because the fused CUDNN functions are slower for some of those.
-bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) {
-  return IsConvCustomCall(instr) && !IsConvDepthwise(instr);
-}
-
-bool IsROCm(se::GpuComputeCapability cc) {
-  return std::holds_alternative<se::RocmComputeCapability>(cc);
-}
-
-// elu, relu6, and leaky-relu activations are supported in cudnn via the
-// "runtime fusion" engine, which JIT compiles C++ code.  This can be slow to
-// compile, so we guard it with a debug option.
-//
-// nvidia currently recommends that we enable this only on Ampere+, but we've
-// tested on Turing (sm75) and it seems to work fine.
-//
-// Note that as of writing, xla_gpu_use_runtime_fusion is disabled by default
-// due to apparent bugs in cudnn 8.9.0.  See debug_options_flags.cc for details.
-bool ShouldUseCudnnRuntimeFusion(const DebugOptions& debug_opts,
-                                 se::GpuComputeCapability cc) {
-  const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(&cc);
-  if (cuda_cc != nullptr)
-    return debug_opts.xla_gpu_use_runtime_fusion() && cuda_cc->IsAtLeast(7, 5);
-  else
-    return true;
-}
-
-bool IsSuitableForCudnnRuntimeFusion(HloInstruction* conv) {
-  // cudnn runtime fusion is pathologically slow on convs with side-inputs.
-  // TODO(kaixih@nvidia): remove this check when cuDNN fixes it.
-  if (conv->operands().size() > 3) {
-    return false;
-  }
-
-  // cuDNN runtime fusion kernels require 32-bit aligned data access, which
-  // means that the number of in/out channels must be divisible by 2 for fp16.
-  // (We don't currently do runtime fusion for int8.)
-  if (conv->operand(0)->shape().element_type() != F16) {
-    return false;
-  }
-  const Shape& shape = conv->operand(1)->shape();
-  int64_t num_input_features = shape.dimensions(
-      conv->convolution_dimension_numbers().kernel_input_feature_dimension());
-  int64_t num_output_features = shape.dimensions(
-      conv->convolution_dimension_numbers().kernel_output_feature_dimension());
-  if (num_input_features % 2 != 0 || num_output_features % 2 != 0) {
-    return false;
-  }
-
-  return true;
-}
-
-// Can instr be converted to type `dst_ty` without losing any precision?  For
-// our purposes, this is true if:
-//
-//  - instr already has type dst_ty, or
-//  - instr is convert<wider type>(op_with_dst_ty), or
-//  - instr is a constant which we can convert orig_ty -> dst_ty -> orig_ty and
-//    get back exactly the original value, or
-//  - instr is a broadcast, reshape, or transpose of one of the above.
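-//
-// Two illustrative (made-up) cases: an f32 constant holding 3.0 converts to
-// S8 and back to exactly 3.0, so it is losslessly convertible to S8, whereas
-// an f32 constant holding 3.5 cannot round-trip through S8 unchanged, so it
-// is not.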
-bool IsLosslesslyConvertibleTo(const HloInstruction* instr,
-                               PrimitiveType dst_ty) {
-  if (instr->shape().element_type() == dst_ty) {
-    return true;
-  }
-
-  if (Match(instr, m::Convert(m::Op().WithElementType(dst_ty)))) {
-    // Check that the convert from dst_ty to instr->element_type() doesn't lose
-    // precision.  Otherwise, this convert is not lossless.
-    return primitive_util::CastPreservesValues(dst_ty,
-                                               instr->shape().element_type());
-  }
-
-  if (instr->opcode() == HloOpcode::kConstant) {
-    if (!instr->shape().IsArray()) {
-      return false;
-    }
-    // Check if instr's literal roundtrips to dst_ty and back to its original
-    // type without modification.
-    PrimitiveType orig_ty = instr->shape().element_type();
-
-    // The only reason Convert() should fail is if we don't support converting
-    // from x to y, which indeed means it's not losslessly-convertible.
-    absl::StatusOr<Literal> converted1 = instr->literal().Convert(dst_ty);
-    if (!converted1.ok()) {
-      return false;
-    }
-    absl::StatusOr<Literal> converted2 = converted1->Convert(orig_ty);
-    if (!converted2.ok()) {
-      return false;
-    }
-
-    return instr->literal() == *converted2;
-  }
-
-  if (instr->opcode() == HloOpcode::kBroadcast ||
-      instr->opcode() == HloOpcode::kReshape ||
-      instr->opcode() == HloOpcode::kTranspose) {
-    return IsLosslesslyConvertibleTo(instr->operand(0), dst_ty);
-  }
-
-  return false;
-}
-
-// Helpers suitable for use in m::Op().WithPredicate(...).
-bool IsLosslesslyConvertibleToS8(const HloInstruction* instr) {
-  return IsLosslesslyConvertibleTo(instr, S8);
-}
-bool IsLosslesslyConvertibleToF16(const HloInstruction* instr) {
-  return IsLosslesslyConvertibleTo(instr, F16);
-}
-
-// If `conv` is a vanilla forward conv, transforms it into a
-// conv-bias-activation.  If it's already a conv-bias-activation, does nothing.
-//
-// If `conv` is anything else, returns an error.
-absl::StatusOr<HloInstruction*> EnsureIsConvBiasActivation(
-    HloInstruction* conv) {
-  CHECK_EQ(conv->opcode(), HloOpcode::kCustomCall);
-
-  if (conv->custom_call_target() == kCudnnConvBiasActivationForwardCallTarget) {
-    return conv;
-  }
-
-  if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
-    HloComputation* comp = conv->parent();
-
-    const Shape& shape = conv->shape().tuple_shapes(0);
-    int64_t num_output_features = shape.dimensions(
-        conv->convolution_dimension_numbers().output_feature_dimension());
-
-    // bias for integer convs is always f32, see
-    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
-    PrimitiveType bias_ty;
-    if (primitive_util::IsIntegralType(shape.element_type())) {
-      bias_ty = F32;
-    } else {
-      bias_ty = shape.element_type();
-    }
-    auto bias = BroadcastZeros(comp, bias_ty, {num_output_features});
-
-    absl::InlinedVector<HloInstruction*, 3> new_operands(
-        conv->operands().begin(), conv->operands().end());
-    new_operands.push_back(bias);
-
-    HloInstruction* new_conv = comp->AddInstruction(
-        conv->CloneWithNewOperands(conv->shape(), new_operands));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
-    new_conv->set_custom_call_target(kCudnnConvBiasActivationForwardCallTarget);
-    comp->parent()->SetAndUniquifyInstrName(new_conv,
-                                            "cudnn-conv-bias-activation");
-    return new_conv;
-  }
-
-  return FailedPrecondition("Unsupported conv: %s", conv->ToString());
-}
-
-// convert<cvt_type>(gte(custom-call<conv_type>(int8_x, int8_w))) ->
-// gte(custom-call<cvt_type>(int8_x, int8_w))
-absl::StatusOr<bool> FuseConvertTypeIntoConv(HloComputation* comp,
-                                             PrimitiveType conv_type,
-                                             PrimitiveType cvt_type) {
-  bool changed = false;
-  for (auto instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* conv = nullptr;
-    auto tuple_elem =
-        m::GetTupleElement(m::Op(&conv).WithPredicate(IsConvCustomCall), 0)
-            .WithElementType(conv_type);
-    auto pattern =
-        m::Convert(tuple_elem.WithOneUser()).WithElementType(cvt_type);
-    if (!Match(instr, pattern)) {
-      continue;
-    }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseConvertTypeIntoConv: ", conv->ToString());
-        })) {
-      continue;
-    }
-
-    Shape new_shape = conv->shape();
-    new_shape.mutable_tuple_shapes(0)->set_element_type(cvt_type);
-    HloInstruction* new_conv =
-        comp->AddInstruction(conv->CloneWithNewShape(new_shape));
-    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
-    TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
-                        MakeGetTupleElementHlo(new_conv, 0));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_gte));
-
-    changed = true;
-  }
-
-  return changed;
-}
-
-struct ConvConvertTypes {
-  PrimitiveType convolution_type;
-  PrimitiveType conversion_type;
-};
-
-// Remove convert around convolution by making the convolution-type
-// (custom call) to be the same as the conversion result.
-// For example: convert<float>(gte(custom-call<int32>(int8_x, int8_w))) ->
-// gte(custom-call<float>(int8_x, int8_w))
-absl::StatusOr<bool> FuseRemoveConvertInConv(HloComputation* comp) {
-  bool changed = false;
-  // Note: We are eliminating F16->F32 because it fails on internal tests.
-  std::array<ConvConvertTypes, 3> types{{
-      {S32, F32},
-      {S8, F32},
-      {F32, S8},
-  }};
-  for (auto [conv_type, cvt_type] : types) {
-    TF_ASSIGN_OR_RETURN(bool curr_change,
-                        FuseConvertTypeIntoConv(comp, conv_type, cvt_type));
-    changed |= curr_change;
-  }
-  return changed;
-}
-
-// alpha * gte(custom-call(...)) ->
-// gte(custom-call(..., backend_config={alpha})).
-absl::StatusOr<bool> FuseConvAlpha(HloComputation* comp) {
-  bool changed = false;
-  for (auto instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* conv = nullptr;
-    HloInstruction* gte = nullptr;
-    HloInstruction* alpha = nullptr;
-
-    auto pattern = m::MultiplyAnyOrder(
-        m::GetTupleElement(
-            &gte, m::Op(&conv).WithPredicate(IsNonDepthwiseConvCustomCall), 0)
-            .WithOneUse(),
-        m::Broadcast(m::ConstantEffectiveScalar(&alpha)));
-    if (!Match(instr, pattern)) {
-      continue;
-    }
-
-    // alpha is f32 except for f64 convs, where it's f64.  See
-    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
-    PrimitiveType alpha_ty = gte->shape().element_type() == F64 ? F64 : F32;
-    if (!IsLosslesslyConvertibleTo(alpha, alpha_ty)) {
-      continue;
-    }
-
-    TF_ASSIGN_OR_RETURN(auto gpu_config,
-                        conv->backend_config<GpuBackendConfig>());
-    CudnnConvBackendConfig& config =
-        *gpu_config.mutable_cudnn_conv_backend_config();
-
-    if (config.conv_result_scale() != 1) {
-      continue;
-    }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseConvAlpha: ", conv->ToString());
-        })) {
-      continue;
-    }
-
-    // StreamExecutor doesn't support the alpha parameter on non-bias-activation
-    // convs, so we have to upgrade `conv`.
-    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
-
-    TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
-    config.set_conv_result_scale(alpha_f64.GetFirstElement<double>());
-    TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
-    TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(instr, gte));
-
-    changed = true;
-  }
-  return changed;
-}
-
-// The format of the serialized graph describing a sequence of ops fused
-// into the cuDNN convolution Custom Call is
-// "UID:[output_type]conv();UID[output_type]:op_name(operand
-// UID);UID:[output_type]op_name(operand UID);..." with the convolution assumed
-// to be the first op in the graph. Operand UIDs identifying ops outside the
-// serialized graph are elided. Currently, multiplication and division by a
-// broadcast scalar, addition of a matrix bias, the application of a ReLU
-// activation and the calculation of the maximum of the absolute value are
-// supported.
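-//
-// For illustration only (the UIDs and types here are made up): a convolution
-// followed by a multiplication by a broadcast scalar and a ReLU could
-// serialize to
-//
-//   "4:[f32]conv();5:[f32]scale(4);6:[f32]relu(5);"
-//
-// where each operand UID refers back to the producing op inside the graph.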
-class GraphString {
- public:
-  GraphString() = default;
-
-  bool AppendOp(std::string op_name, HloInstruction* op,
-                std::vector<HloInstruction*> operands = {}) {
-    std::optional<int64_t> operand_uid;
-    int num_operands_in_graph = 0;
-    for (HloInstruction* operand : operands) {
-      if (OpInGraph(operand->unique_id())) {
-        num_operands_in_graph++;
-        // Ops with more than one operand in the graph are not supported.
-        if (num_operands_in_graph > 1) {
-          return false;
-        }
-        operand_uid = operand->unique_id();
-      }
-    }
-    graph_.emplace_back(OpDescriptor(
-        {op->unique_id(), op->shape().element_type(), op_name, operand_uid}));
-    return true;
-  }
-
-  void ChangeDataType(PrimitiveType type) {
-    DCHECK(!graph_.empty());
-    graph_.back().output_type = type;
-  }
-
-  std::string Graph() const {
-    std::string graph;
-    for (OpDescriptor op : graph_) {
-      graph.append(std::to_string(op.uid));
-      graph.append(":[" +
-                   primitive_util::LowercasePrimitiveTypeName(op.output_type) +
-                   "]");
-      graph.append(op.name);
-      graph.append("(");
-      if (op.operand.has_value()) {
-        graph.append(std::to_string(*op.operand));
-      }
-      graph.append(");");
-    }
-    return graph;
-  }
-
-  bool OpInGraph(int64_t uid, std::string op_name = "") const {
-    auto op_filter = [&](OpDescriptor op) -> bool {
-      if (op_name.empty()) {
-        return op.uid == uid;
-      } else {
-        return op.uid == uid && op.name == op_name;
-      }
-    };
-    return std::find_if(graph_.begin(), graph_.end(), op_filter) !=
-           graph_.end();
-  }
-
- private:
-  struct OpDescriptor {
-    int64_t uid;
-    PrimitiveType output_type;
-    std::string name;
-    std::optional<int64_t> operand;
-  };
-
-  std::vector<OpDescriptor> graph_;
-};
-
-bool IsF8Type(const HloInstruction* instr) {
-  return primitive_util::IsF8Type(instr->shape().element_type());
-}
-
-bool IsScalar(const HloInstruction* instr) {
-  return ShapeUtil::IsScalar(instr->shape());
-}
-
-std::optional<PrimitiveType> IsSaturatingCastToF8(HloInstruction* instr) {
-  HloInstruction *op, *clamp_lower, *clamp_upper;
-  if (Match(instr,
-            m::Convert(
-                &op,
-                m::Clamp(m::Broadcast(m::ConstantScalar(&clamp_lower)), m::Op(),
-                         m::Broadcast(m::ConstantScalar(&clamp_upper))))) &&
-      ((op->shape().element_type() == F8E4M3FN &&
-        clamp_lower->literal().IsAllFloat(static_cast<float>(
-            std::numeric_limits<tsl::float8_e4m3fn>::lowest())) &&
-        clamp_upper->literal().IsAllFloat(static_cast<float>(
-            std::numeric_limits<tsl::float8_e4m3fn>::max()))) ||
-       (op->shape().element_type() == F8E5M2 &&
-        clamp_lower->literal().IsAllFloat(static_cast<float>(
-            std::numeric_limits<tsl::float8_e5m2>::lowest())) &&
-        clamp_upper->literal().IsAllFloat(static_cast<float>(
-            std::numeric_limits<tsl::float8_e5m2>::max()))))) {
-    return op->shape().element_type();
-  }
-  return std::nullopt;
-}
-
-// Returns whether the HLO Computation applied by `op` calculates the largest
-// element.
-bool AppliesMaxReduce(HloInstruction* op) {
-  HloComputation* reduce_comp = op->to_apply();
-  HloInstruction* reduce_comp_root = reduce_comp->root_instruction();
-  return ShapeUtil::IsScalar(op->shape()) &&
-         ShapeUtil::IsScalar(op->operand(1)->shape()) &&
-         op->operand(1)->IsConstant() &&
-         op->operand(1)->literal().GetAsDouble({}) <= 0. &&
-         reduce_comp_root->opcode() == HloOpcode::kMaximum &&
-         reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter &&
-         reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter;
-}
-
-// Recursively captures and serializes the graph of pointwise operations
-// operating on the convolution.
-void CaptureConvGraphRecursive(HloInstruction* instr,
-                               std::vector<HloInstruction*>& operands,
-                               std::vector<HloInstruction*>& aux_outputs,
-                               GraphString& graph_string,
-                               absl::flat_hash_set<int>& visited_instrs,
-                               HloInstruction*& final_instr) {
-  // Avoid visiting the same instruction more than once.
-  if (!visited_instrs.emplace(instr->unique_id()).second) {
-    return;
-  }
-  final_instr = instr;
-
-  // Copy the current state in case fusion turns out to be unsuccessful or
-  // unfavorable.
-  GraphString init_graph_string = graph_string;
-  std::vector<HloInstruction*> init_operands = operands,
-                               init_aux_outputs = aux_outputs;
-  // The loop adds each user of `instr` that supports fusion into the
-  // cuDNN convolution Custom Call to GraphString. Most ops following the
-  // convolution describe a linear sequence that generates a single return
-  // tensor. The identification of one of these linear ops is followed by a
-  // recursive call of CaptureConvGraphRecursive to match and potentially fuse
-  // its users. The calculation of the scalar maximum of the absolute value
-  // (Amax) of a preceding op is considered a nonlinear user as it adds a
-  // return value to the convolution. The users of a nonlinear op are
-  // not considered for fusion into the Custom Call. The numbers of linear and
-  // nonlinear users of `instr` are stored in `num_linear_users` and
-  // `num_nonlinear_users`.
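-  // For example (a sketch of one pattern handled below): in
-  // conv -> scale -> relu -> amax, the scale and relu users are linear and
-  // extend the graph, while the amax reduce is a nonlinear user that only
-  // contributes an auxiliary output.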
-  int num_linear_users = 0, num_nonlinear_users = 0;
-  for (HloInstruction* user : instr->users()) {
-    HloInstruction *op, *operand0, *operand1;
-    // Add
-    if (Match(user, m::AddAnyOrder(&op, m::Op(&operand0), m::Op(&operand1)))) {
-      if (graph_string.AppendOp("add", op, {operand0, operand1})) {
-        operands.push_back(operand0 == instr ? operand1 : operand0);
-        num_linear_users++;
-        CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
-                                  visited_instrs, final_instr);
-      }
-      continue;
-    }
-    // Scale
-    if (Match(user, m::MultiplyAnyOrder(&op, m::Op(&operand0),
-                                        m::Broadcast(m::Op(&operand1)))) &&
-        ShapeUtil::IsScalar(operand1->shape())) {
-      if (graph_string.AppendOp("scale", op, {operand0, operand1})) {
-        operands.push_back(operand1);
-        num_linear_users++;
-        CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
-                                  visited_instrs, final_instr);
-      }
-      continue;
-    }
-    // Inverse Scale
-    if (Match(user, m::Divide(&op, m::Op(&operand0),
-                              m::Broadcast(m::Op(&operand1)))) &&
-        ShapeUtil::IsScalar(operand1->shape())) {
-      if (graph_string.AppendOp("invscale", op, {operand0, operand1})) {
-        operands.push_back(operand1);
-        num_linear_users++;
-        CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
-                                  visited_instrs, final_instr);
-      }
-      continue;
-    }
-    // ReLU
-    if (Match(user, m::MaximumAnyOrder(&op, m::Op(&operand0),
-                                       m::Broadcast(m::ConstantScalar(0))))) {
-      if (graph_string.AppendOp("relu", op, {operand0})) {
-        num_linear_users++;
-        CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
-                                  visited_instrs, final_instr);
-      }
-      continue;
-    }
-    // Maximum of the absolute value (Amax) following ReLU (elided Abs) -- not
-    // a linear user
-    if (Match(user, m::Reduce(&op, m::Op(&operand0), m::Op())) &&
-        graph_string.OpInGraph(operand0->unique_id(), "relu") &&
-        AppliesMaxReduce(op)) {
-      if (graph_string.AppendOp("amax", op, {operand0})) {
-        aux_outputs.emplace_back(op);
-        num_nonlinear_users++;
-      }
-      continue;
-    }
-
-    // The following patterns match the user of `user`.
-    if (!user->users().empty()) {
-      HloInstruction* users_user = user->users()[0];
-      // Convert with Clamp to FP8 types
-      std::optional<PrimitiveType> f8_type = IsSaturatingCastToF8(users_user);
-      if (f8_type.has_value()) {
-        graph_string.ChangeDataType(f8_type.value());
-        num_linear_users++;
-        CaptureConvGraphRecursive(users_user, operands, aux_outputs,
-                                  graph_string, visited_instrs, final_instr);
-        continue;
-      }
-      // Maximum of the absolute value (Amax) -- not a linear user
-      if (Match(users_user,
-                m::Reduce(&op, m::Abs(m::Op(&operand0)), m::Op())) &&
-          AppliesMaxReduce(op)) {
-        if (graph_string.AppendOp("amax", op, {operand0})) {
-          aux_outputs.emplace_back(op);
-          num_nonlinear_users++;
-        }
-        continue;
-      }
-    }
-  }
-  // Do not fuse into the cuDNN convolution Custom Call when there is more than
-  // one linear or nonlinear user, or when the number of users eligible for
-  // fusion is less than the total number of users.
-  if (num_linear_users > 1 || num_nonlinear_users > 1 ||
-      num_linear_users + num_nonlinear_users < instr->user_count()) {
-    graph_string = init_graph_string;
-    operands = init_operands;
-    aux_outputs = init_aux_outputs;
-    final_instr = instr;
-  }
-}
-
-// Captures in a GraphString the subgraph of pointwise operations operating on
-// the convolution that will be fused into the cuDNN convolution Custom Call.
-absl::StatusOr<
-    std::tuple<std::vector<HloInstruction*>, std::vector<HloInstruction*>,
-               GraphString, HloInstruction*>>
-CaptureConvGraph(HloInstruction* instr, HloInstruction* convolution,
-                 HloInstruction* wide_input, HloInstruction* wide_filter,
-                 HloInstruction* input_scale, HloInstruction* filter_scale,
-                 bool x_mult_scale, bool w_mult_scale) {
-  GraphString graph_string;
-  graph_string.AppendOp("conv", instr);
-
-  // Shift the scaling of the input and filter to the output of the convolution.
-  HloInstruction *input_scaled_conv, *filter_scaled_conv;
-  if (input_scale) {
-    TF_RETURN_IF_ERROR(convolution->ReplaceOperandWith(0, wide_input));
-    HloInstruction* bcast_input_scale = instr->AddInstruction(
-        HloInstruction::CreateBroadcast(instr->shape(), input_scale, {}));
-    input_scaled_conv = instr->AddInstruction(HloInstruction::CreateBinary(
-        instr->shape(),
-        x_mult_scale ? HloOpcode::kMultiply : HloOpcode::kDivide, instr,
-        bcast_input_scale));
-    TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(input_scaled_conv));
-  }
-  if (filter_scale) {
-    TF_RETURN_IF_ERROR(convolution->ReplaceOperandWith(1, wide_filter));
-    HloInstruction* bcast_filter_scale = instr->AddInstruction(
-        HloInstruction::CreateBroadcast(instr->shape(), filter_scale, {}));
-    filter_scaled_conv = instr->AddInstruction(HloInstruction::CreateBinary(
-        instr->shape(),
-        w_mult_scale ? HloOpcode::kMultiply : HloOpcode::kDivide,
-        input_scale ? input_scaled_conv : instr, bcast_filter_scale));
-    TF_RETURN_IF_ERROR((input_scale ? input_scaled_conv : instr)
-                           ->ReplaceAllUsesWith(filter_scaled_conv));
-  }
-
-  std::vector<HloInstruction*> operands, aux_outputs;
-  absl::flat_hash_set<int> visited_instrs;
-  HloInstruction* final_instr;
-  CaptureConvGraphRecursive(instr, operands, aux_outputs, graph_string,
-                            visited_instrs, final_instr);
-  return std::make_tuple(operands, aux_outputs, graph_string, final_instr);
-}
-
-// Matches convolutions operating on FP8 inputs and filters. For scaled FP8
-// convolutions on Hopper systems, the following sequence of steps is matched
-// and rewritten into a single ForwardGraph Custom Call:
-//
-// 1. Cast the filter and input from FP8 to a wider type such as FP16 or FP32.
-// 2. Optionally unscale the filter and input by multiplying or dividing by
-// scalars.
-// 3. Evaluate the convolution based on the unscaled filter and input.
-// 4. Apply a series of elementwise transformations, where a transformation can
-// be adding a matrix bias, applying a ReLU activation, or
-// multiplying or dividing by a broadcast scalar.
-// 5. Optionally calculate the maximum of the absolute value of the result.
-// 6. Optionally cast the output back to FP8.
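-//
-// As a rough sketch (operand names are hypothetical; the scales may also be
-// applied via division, and may be absent), the matched HLO has the form
-//
-//   x_f32 = convert(x_f8)
-//   w_f32 = convert(w_f8)
-//   x_unscaled = multiply(x_f32, broadcast(x_scale))
-//   w_unscaled = multiply(w_f32, broadcast(w_scale))
-//   conv = get-tuple-element(custom-call(x_unscaled, w_unscaled), 0)
-//   ... pointwise ops, optional Amax, optional saturating cast back to FP8 ...
-//
-// and is rewritten into a single custom call targeting
-// kCudnnConvForwardGraphCallTarget, whose backend config carries the
-// serialized graph of the fused pointwise ops.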
-absl::StatusOr<bool> F8GraphConv(HloComputation* comp,
-                                 se::CudaComputeCapability cc,
-                                 se::dnn::VersionInfo dnn_version,
-                                 int32_t toolkit_version) {
-  bool changed = false;
-
-  if (dnn_version < se::dnn::VersionInfo(8, 9, 0)) {
-    return false;
-  }
-  if (toolkit_version < 12000) {
-    return false;
-  }
-  if (!cc.IsAtLeast(se::CudaComputeCapability::HOPPER)) {
-    return false;
-  }
-  for (auto instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction *convolution, *gte, *input, *filter,
-        *input_scale = nullptr, *filter_scale = nullptr,
-        *input_scale_op = nullptr, *filter_scale_op = nullptr,
-        *wide_input = nullptr, *wide_filter = nullptr;
-
-    auto conv_operand_maybe_scaled = [](HloInstruction** operand,
-                                        HloInstruction** wide_operand,
-                                        HloInstruction** scale_op,
-                                        HloInstruction** scale) {
-      return m::AnyOf<HloInstruction>(
-          m::Op(operand).WithPredicate(IsF8Type),
-          m::Convert(wide_operand, m::Op(operand).WithPredicate(IsF8Type)),
-          m::Divide(
-              scale_op,
-              m::Convert(wide_operand, m::Op(operand).WithPredicate(IsF8Type)),
-              m::Broadcast(m::Op(scale).WithPredicate(IsScalar))),
-          m::MultiplyAnyOrder(
-              scale_op,
-              m::Convert(wide_operand, m::Op(operand).WithPredicate(IsF8Type)),
-              m::Broadcast(m::Op(scale).WithPredicate(IsScalar))));
-    };
-
-    // TODO(philipphack): Consider allowing ops between dequantization and
-    // convolution.
-    auto pattern = m::GetTupleElement(
-        &gte,
-        m::CustomCall(
-            &convolution,
-            conv_operand_maybe_scaled(&input, &wide_input, &input_scale_op,
-                                      &input_scale),
-            conv_operand_maybe_scaled(&filter, &wide_filter, &filter_scale_op,
-                                      &filter_scale))
-            .WithPredicate(IsConvCustomCall),
-        0);
-    if (Match(instr, pattern)) {
-      if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-            return absl::StrCat("F8GraphConv: ", convolution->ToString());
-          })) {
-        continue;
-      }
-
-      std::vector<HloInstruction*> operands, aux_outputs;
-      GraphString graph_string;
-      HloInstruction* final_instr;
-
-      TF_ASSIGN_OR_RETURN(
-          std::tie(operands, aux_outputs, graph_string, final_instr),
-          CaptureConvGraph(
-              instr, convolution, wide_input, wide_filter, input_scale,
-              filter_scale,
-              input_scale_op ? input_scale_op->opcode() == HloOpcode::kMultiply
-                             : false,
-              filter_scale_op
-                  ? filter_scale_op->opcode() == HloOpcode::kMultiply
-                  : false));
-      TF_ASSIGN_OR_RETURN(auto gpu_config,
-                          convolution->backend_config<GpuBackendConfig>());
-      CudnnConvBackendConfig& config =
-          *gpu_config.mutable_cudnn_conv_backend_config();
-
-      config.set_serialized_graph(graph_string.Graph());
-      operands.insert(operands.begin(), input);
-      operands.insert(operands.begin() + 1, filter);
-
-      std::vector<Shape> output_shapes;
-      output_shapes.emplace_back(ShapeUtil::ChangeElementType(
-          ShapeUtil::GetTupleElementShape(convolution->shape(), 0),
-          final_instr->shape().element_type()));
-      for (HloInstruction* aux_output : aux_outputs) {
-        output_shapes.emplace_back(aux_output->shape());
-      }
-      output_shapes.emplace_back(
-          ShapeUtil::GetTupleElementShape(convolution->shape(), 1));
-
-      HloInstruction* new_convolution =
-          comp->AddInstruction(convolution->CloneWithNewOperands(
-              ShapeUtil::MakeTupleShape(output_shapes), operands));
-
-      new_convolution->set_custom_call_target(kCudnnConvForwardGraphCallTarget);
-      TF_RETURN_IF_ERROR(new_convolution->set_backend_config(gpu_config));
-      TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
-                          MakeGetTupleElementHlo(new_convolution, 0));
-      TF_RETURN_IF_ERROR(comp->ReplaceInstruction(final_instr, new_gte));
-
-      for (int i = 0; i < aux_outputs.size(); ++i) {
-        TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
-                            MakeGetTupleElementHlo(new_convolution, i + 1));
-        TF_RETURN_IF_ERROR(comp->ReplaceInstruction(aux_outputs[i], new_gte));
-      }
-
-      changed = true;
-    }
-  }
-  return changed;
-}
-
-absl::StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
-  bool changed = false;
-  for (auto instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* conv = nullptr;
-    HloInstruction* gte = nullptr;
-    HloInstruction* addend = nullptr;
-
-    auto pattern = m::AddAnyOrder(
-        m::GetTupleElement(&gte,
-                           m::Op(&conv)
-                               .WithPredicate(IsNonDepthwiseConvCustomCall)
-                               .WithOneUse(),
-                           0)
-            .WithOneUse(),
-        m::Op(&addend));
-    if (!Match(instr, pattern)) {
-      continue;
-    }
-
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseBiasOrSideInput: ", conv->ToString());
-        })) {
-      continue;
-    }
-
-    // If it's a vanilla forward conv, upgrade it to a bias-activation conv.  We
-    // only want to do this if the fusion will succeed, but we're guaranteed
-    // that it will, because the only reason we'll bail at this point is if
-    // !can_accept_bias && !can_accept_side_input, and our shiny new
-    // bias-activation conv will be able to accept both.
-    if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
-      TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
-    }
-
-    // Can't fuse bias or side-input if the conv already has a relu (or other
-    // activation), because bias and side-input are added before the activation
-    // is applied.
-    TF_ASSIGN_OR_RETURN(auto gpu_config,
-                        conv->backend_config<GpuBackendConfig>());
-    CudnnConvBackendConfig& config =
-        *gpu_config.mutable_cudnn_conv_backend_config();
-    if (config.activation_mode() != se::dnn::kNone) {
-      continue;
-    }
-
-    // Does `conv` already have a (nonzero) bias?  Does it already have a
-    // side_input?
-    bool can_accept_bias =
-        Match(conv->operand(2), m::Broadcast(m::ConstantEffectiveScalar(0)));
-    bool can_accept_side_input = conv->operand_count() < 4;
-
-    // The addend can be fused as a bias if
-    //  - it is 1D broadcasted in the output feature dimension, and
-    //  - it is losslessly-convertible to the correct type (f32 for s8/f32/u32
-    //    convs, and conv_ty for floating-point convs)
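-    // For example (a sketch; names and shapes are illustrative), an addend of
-    // the form
-    //
-    //   bias_bcast = f32[N,H,W,C] broadcast(bias_1d), dimensions={3}
-    //
-    // qualifies as a rank-1 bias when dimension 3 is the convolution's output
-    // feature dimension and the addend is losslessly-convertible to bias_ty.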
-    PrimitiveType conv_ty = gte->shape().element_type();
-    PrimitiveType bias_ty =
-        primitive_util::IsFloatingPointType(conv_ty) ? conv_ty : F32;
-    bool addend_may_be_rank1_bias =
-        addend->opcode() == HloOpcode::kBroadcast &&
-        addend->dimensions().size() == 1 &&
-        addend->dimensions(0) ==
-            conv->convolution_dimension_numbers().output_feature_dimension() &&
-        IsLosslesslyConvertibleTo(addend, bias_ty);
-
-    bool addend_may_be_rank0_bias = addend->opcode() == HloOpcode::kBroadcast &&
-                                    addend->dimensions().empty() &&
-                                    IsLosslesslyConvertibleTo(addend, bias_ty);
-
-    absl::InlinedVector<HloInstruction*, 4> new_operands(
-        conv->operands().begin(), conv->operands().end());
-    if (can_accept_bias && addend_may_be_rank1_bias) {
-      new_operands[2] = MakeConvertToHlo(addend->mutable_operand(0), bias_ty,
-                                         &addend->operand(0)->metadata());
-    } else if (can_accept_bias && addend_may_be_rank0_bias) {
-      new_operands[2] = MakeBroadcastHlo(
-          MakeConvertToHlo(addend->mutable_operand(0), bias_ty,
-                           &addend->operand(0)->metadata()),
-          /*broadcast_dimensions=*/{},
-          /*result_shape_bounds=*/
-          {gte->shape().dimensions(conv->convolution_dimension_numbers()
-                                       .output_feature_dimension())});
-    } else if (can_accept_side_input) {
-      CHECK_EQ(new_operands.size(), 3);
-      new_operands.push_back(addend);
-      config.set_side_input_scale(1);
-    } else {
-      // Can't fuse; this op already has a bias and a side-input.
-      continue;
-    }
-
-    HloInstruction* new_conv = comp->AddInstruction(
-        conv->CloneWithNewOperands(conv->shape(), new_operands));
-    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
-    TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config));
-    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
-                        MakeGetTupleElementHlo(new_conv, 0));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
-    changed = true;
-  }
-  return changed;
-}
-
-// custom-call(..., alpha * side_input) ->
-// custom-call(..., side_input, backend_config={alpha}).
-//
-// We also have to support the more complicated case of
-//
-//   custom-call(..., reshape(side_input * alpha)) -->
-//   custom-call(..., reshape(side_input), backend_config={alpha}),
-//
-// where `reshape` can be an arbitrary chain of reshapes+transposes.  This idiom
-// is created by the ReshapeMover pass.
-absl::StatusOr<bool> FuseSideInputAlpha(HloComputation* comp) {
-  bool changed = false;
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* conv;
-    HloInstruction* side_input;
-    auto pattern = m::Op(&conv)
-                       .WithPredicate(IsConvCustomCall)
-                       .WithOperand(3, m::Op(&side_input));
-    if (!Match(instr, pattern)) {
-      continue;
-    }
-    TF_ASSIGN_OR_RETURN(auto gpu_config,
-                        conv->backend_config<GpuBackendConfig>());
-    CudnnConvBackendConfig& config =
-        *gpu_config.mutable_cudnn_conv_backend_config();
-    if (config.side_input_scale() != 1) {
-      continue;
-    }
-
-    // Given side_input, pattern match the following (working from bottom up).
-    //
-    // before_reshape = multiply(base, broadcast(alpha))
-    // side_input = chain_of_reshapes_and_transposes(before_reshape)
-    //
-    // where alpha is a scalar constant.
-    //
-    // alpha is f32 except for f64 convs, where it's f64.  See
-    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
-    HloInstruction* before_reshape = side_input;
-    while (before_reshape->opcode() == HloOpcode::kReshape ||
-           before_reshape->opcode() == HloOpcode::kTranspose) {
-      before_reshape = before_reshape->mutable_operand(0);
-    }
-
-    PrimitiveType conv_ty = conv->shape().tuple_shapes(0).element_type();
-    PrimitiveType alpha_ty = conv_ty == F64 ? F64 : F32;
-    HloInstruction* base;
-    HloInstruction* alpha;
-    if (!Match(
-            before_reshape,
-            m::MultiplyAnyOrder(
-                m::Op(&base),
-                m::Broadcast(m::ConstantEffectiveScalar(&alpha).WithPredicate(
-                    [&](const HloInstruction* instr) {
-                      return IsLosslesslyConvertibleTo(instr, alpha_ty);
-                    }))))) {
-      continue;
-    }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseSideInputAlpha: ", conv->ToString());
-        })) {
-      continue;
-    }
-
-    // Rewrite conv's operand 3 to
-    //
-    //   chain_of_reshapes_and_transposes(base),
-    //
-    // i.e. the same chain with the multiply by alpha removed, and store alpha
-    // in the conv's backend config.
-    //
-    // We're going to do something bad here: We aren't going to check that the
-    // chain of reshapes/transposes has one use, so we're potentially
-    // duplicating all these instructions (once with alpha and once without).
-    //
-    // This is justified because
-    //
-    //  - duplicating reshapes/transposes shouldn't be "that bad" -- these
-    //    instructions can usually be fused, and
-    //
-    //  - *not* fusing alpha can be catastrophic.  For s8->s8 convolutions, the
-    //    side-input must be s8.  But the product side_input * alpha is f32, so
-    //    we can only see that side-input is s8 if we fuse alpha. IOW not fusing
-    //    alpha means we'll run this s8->s8 conv as s8->f32, which is *much*
-    //    slower than some extra transposes.
-
-    // Recursively clone chain_of_reshapes_and_transposes until we get to
-    // `before_reshape`, at which point we skip the multiply(base, alpha) and
-    // just return base.
-    std::function<HloInstruction*(const HloInstruction*)> clone =
-        [&](const HloInstruction* instr) {
-          if (instr == before_reshape) {
-            return base;
-          }
-          CHECK(instr->opcode() == HloOpcode::kReshape ||
-                instr->opcode() == HloOpcode::kTranspose)
-              << "Must be reshape or transpose: " << instr->ToString();
-          return comp->AddInstruction(instr->CloneWithNewOperands(
-              instr->shape(), {clone(instr->operand(0))}));
-        };
-    absl::InlinedVector<HloInstruction*, 4> new_operands(
-        conv->operands().begin(), conv->operands().end());
-    new_operands[3] = clone(side_input);
-
-    HloInstruction* new_conv = comp->AddInstruction(
-        conv->CloneWithNewOperands(conv->shape(), new_operands));
-    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
-
-    TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
-    config.set_side_input_scale(alpha_f64.GetFirstElement<double>());
-    TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config));
-
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
-    changed = true;
-  }
-  return changed;
-}
-
-absl::StatusOr<bool> FuseElu(HloComputation* comp,
-                             se::GpuComputeCapability cc) {
-  if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
-                                   cc)) {
-    return false;
-  }
-
-  bool changed = false;
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction *gte1, *gte2, *gte3;
-    HloInstruction* conv;
-    HloInstruction* expm1;
-
-    if (!Match(instr,
-               m::Select(m::Compare(m::GetTupleElement(&gte1, m::Op()),
-                                    m::Broadcast(m::ConstantEffectiveScalar(0)))
-                             .WithComparisonDirection(ComparisonDirection::kGt)
-                             .WithOneUse(),
-                         m::GetTupleElement(
-                             &gte2,
-                             m::Op(&conv)
-                                 .WithPredicate(IsNonDepthwiseConvCustomCall)
-                                 .WithOneUse(),
-                             /*tuple_index=*/0)
-                             // TODO(jlebar): Why only fp16?
-                             .WithElementType(F16),
-                         m::Op(&expm1)
-                             .WithOpcode(HloOpcode::kExpm1)
-                             .WithOperand(0, m::GetTupleElement(&gte3, m::Op()))
-                             .WithOneUse()))) {
-      continue;
-    }
-
-    // The three GTEs should be the same, and these should be the only uses.
-    if (gte1 != gte2 || gte2 != gte3 || gte1->user_count() != 3) {
-      continue;
-    }
-
-    if (!IsSuitableForCudnnRuntimeFusion(conv)) {
-      continue;
-    }
-
-    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                        conv->backend_config<GpuBackendConfig>());
-    CudnnConvBackendConfig& config =
-        *gpu_config.mutable_cudnn_conv_backend_config();
-    if (config.activation_mode() != se::dnn::kNone) {
-      continue;
-    }
-
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseElu: ", conv->ToString());
-        })) {
-      continue;
-    }
-    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
-    config.set_activation_mode(se::dnn::kElu);
-    TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1));
-    changed = true;
-  }
-  return changed;
-}
-
-absl::StatusOr<bool> FuseRelu(HloComputation* comp) {
-  bool changed = false;
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* gte;
-    HloInstruction* conv;
-    if (!Match(instr,
-               m::MaximumAnyOrder(
-                   m::Broadcast(m::ConstantEffectiveScalar(0)),
-                   m::GetTupleElement(
-                       &gte, m::Op(&conv)
-                                 .WithPredicate(IsNonDepthwiseConvCustomCall)
-                                 .WithOneUse())
-                       .WithOneUse()))) {
-      continue;
-    }
-    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                        conv->backend_config<GpuBackendConfig>());
-    CudnnConvBackendConfig& config =
-        *gpu_config.mutable_cudnn_conv_backend_config();
-    if (config.activation_mode() != se::dnn::kNone) {
-      continue;
-    }
-
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseRelu: ", conv->ToString());
-        })) {
-      continue;
-    }
-    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
-    config.set_activation_mode(se::dnn::kRelu);
-    TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
-    changed = true;
-  }
-  return changed;
-}
-
-absl::StatusOr<bool> FuseRelu6(HloComputation* comp,
-                               se::GpuComputeCapability cc) {
-  if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
-                                   cc)) {
-    return false;
-  }
-
-  bool changed = false;
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction *gte, *conv;
-    if (!Match(
-            instr,
-            m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(0)),
-                     m::GetTupleElement(
-                         &gte, m::Op(&conv)
-                                   .WithPredicate(IsNonDepthwiseConvCustomCall)
-                                   .WithOneUse())
-                         // TODO(jlebar): Why only fp16?
-                         .WithElementType(F16)
-                         .WithOneUse(),
-                     m::Broadcast(m::ConstantEffectiveScalar(6))))) {
-      continue;
-    }
-    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                        conv->backend_config<GpuBackendConfig>());
-    CudnnConvBackendConfig& config =
-        *gpu_config.mutable_cudnn_conv_backend_config();
-    if (config.activation_mode() != se::dnn::kNone) {
-      continue;
-    }
-
-    if (!IsSuitableForCudnnRuntimeFusion(conv)) {
-      continue;
-    }
-
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseRelu6: ", conv->ToString());
-        })) {
-      continue;
-    }
-    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
-    config.set_activation_mode(se::dnn::kRelu6);
-    TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
-    changed = true;
-  }
-  return changed;
-}
-
-absl::StatusOr<bool> FuseLeakyRelu(HloComputation* comp,
-                                   se::GpuComputeCapability cc) {
-  if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
-                                   cc)) {
-    return false;
-  }
-
-  bool changed = false;
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction *gte1, *gte2, *gte3, *conv, *alpha;
-    if (!Match(instr,
-               m::Select(
-                   m::Compare(m::GetTupleElement(&gte1, m::Op()),
-                              m::Broadcast(m::ConstantEffectiveScalar(0)))
-                       .WithComparisonDirection(ComparisonDirection::kGt)
-                       .WithOneUse(),
-                   m::GetTupleElement(
-                       &gte2, m::Op(&conv)
-                                  .WithPredicate(IsNonDepthwiseConvCustomCall)
-                                  .WithOneUse())
-                       // TODO(jlebar): Why only fp16?
-                       .WithElementType(F16),
-                   m::Multiply(m::GetTupleElement(&gte3, m::Op()),
-                               m::Broadcast(m::ConstantEffectiveScalar(&alpha)))
-                       .WithOneUse()))) {
-      continue;
-    }
-
-    // The three GTEs should be the same, and these should be the only uses.
-    if (gte1 != gte2 || gte2 != gte3 || gte1->user_count() != 3) {
-      continue;
-    }
-
-    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                        conv->backend_config<GpuBackendConfig>());
-    CudnnConvBackendConfig& config =
-        *gpu_config.mutable_cudnn_conv_backend_config();
-    if (config.activation_mode() != se::dnn::kNone) {
-      continue;
-    }
-
-    if (!IsSuitableForCudnnRuntimeFusion(conv)) {
-      continue;
-    }
-
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseLeakyRelu: ", conv->ToString());
-        })) {
-      continue;
-    }
-    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
-    config.set_activation_mode(se::dnn::kLeakyRelu);
-    TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
-    config.set_leakyrelu_alpha(alpha_f64.GetFirstElement<double>());
-    TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1));
-    changed = true;
-  }
-  return changed;
-}
-
-absl::StatusOr<bool> FuseConvertToF16(HloComputation* comp) {
-  bool changed = false;
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* gte = nullptr;
-    HloInstruction* conv = nullptr;
-
-    auto f32_convertible_to_f16_pat =
-        m::Op().WithElementType(F32).WithPredicate(
-            IsLosslesslyConvertibleToF16);
-    if (!MatchAndLogIfFailed(
-            instr, "f16 conv",
-            m::Convert(
-                m::GetTupleElement(
-                    &gte,
-                    m::Op(&conv)
-                        .WithPredicate(IsConvCustomCall)
-                        .WithOperand(0, f32_convertible_to_f16_pat)
-                        .WithOperand(1, f32_convertible_to_f16_pat)
-                        .WithOperandIfPresent(2, f32_convertible_to_f16_pat)
-                        .WithOperandIfPresent(3, f32_convertible_to_f16_pat),
-                    0)
-                    .WithOneUse())
-                .WithElementType(F16),
-            VLOG_IS_ON(3),
-            m::Op().WithOperand(0, m::GetTupleElement(m::Op().WithPredicate(
-                                       IsConvCustomCall))))) {
-      continue;
-    }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseConvertToF16: ", conv->ToString());
-        })) {
-      continue;
-    }
-
-    VLOG(2) << "Matched fp16 conv: " << conv->ToString();
-
-    // In fp16 convs, all operands, including `bias`, must be fp16.  This is
-    // different from int8 convs, where the bias is fp32.  See table of
-    // supported datatypes at
-    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
-    absl::InlinedVector<HloInstruction*, 4> new_operands;
-    for (HloInstruction* operand : conv->operands()) {
-      new_operands.push_back(
-          MakeConvertToHlo(operand, F16, &operand->metadata()));
-    }
-
-    Shape new_shape = conv->shape();
-    new_shape.mutable_tuple_shapes(0)->set_element_type(F16);
-
-    HloInstruction* new_conv = comp->AddInstruction(
-        conv->CloneWithNewOperands(new_shape, new_operands));
-    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
-    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
-                        MakeGetTupleElementHlo(new_conv, 0));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
-    changed = true;
-  }
-  return changed;
-}
-
-absl::StatusOr<bool> FuseConvertToS8(HloComputation* comp,
-                                     se::GpuComputeCapability cc) {
-  if (IsROCm(cc)) return false;
-  bool changed = false;
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* gte = nullptr;
-    HloInstruction* conv = nullptr;
-
-    auto conv_pattern =
-        m::Op(&conv)
-            .WithPredicate(IsConvCustomCall)
-            .WithOperand(0, m::Op().WithPredicate(IsLosslesslyConvertibleToS8))
-            .WithOperand(1, m::Op().WithPredicate(IsLosslesslyConvertibleToS8));
-
-    PrimitiveType conv_output_ty;
-    if (MatchAndLogIfFailed(
-            instr, "s8->s8 conv",
-            m::Convert(m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(-128)),
-                                m::GetTupleElement(
-                                    &gte,
-                                    conv_pattern.WithOperandIfPresent(
-                                        3, m::Op().WithPredicate(
-                                               IsLosslesslyConvertibleToS8)),
-                                    0)
-                                    .WithOneUse(),
-                                m::Broadcast(m::ConstantEffectiveScalar(127))))
-                .WithElementType(S8),
-            VLOG_IS_ON(3),
-            m::Convert(m::Clamp(m::Op(),
-                                m::GetTupleElement(
-                                    m::Op().WithPredicate(IsConvCustomCall)),
-                                m::Op()))
-                .WithElementType(S8))) {
-      conv_output_ty = S8;
-    } else if (MatchAndLogIfFailed(
-                   instr, "s8->f32 conv",
-                   m::GetTupleElement(&gte,
-                                      conv_pattern.WithOperandIfPresent(
-                                          3, m::Op().WithElementType(F32)),
-                                      0)
-                       .WithElementType(F32),
-                   VLOG_IS_ON(3),
-                   m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall))
-                       .WithElementType(F32))) {
-      conv_output_ty = F32;
-    } else {
-      continue;
-    }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseConvertToS8: ", conv->ToString());
-        })) {
-      continue;
-    }
-
-    absl::InlinedVector<HloInstruction*, 4> new_operands(
-        conv->operands().begin(), conv->operands().end());
-    new_operands[0] =
-        MakeConvertToHlo(new_operands[0], S8, &new_operands[0]->metadata());
-    new_operands[1] =
-        MakeConvertToHlo(new_operands[1], S8, &new_operands[1]->metadata());
-    // Don't convert bias (operand 2); it's always f32 for s8 ops in cudnn.  See
-    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
-    if (new_operands.size() >= 4) {
-      // side-input always matches conv output type.  We checked in the patterns
-      // above that it's losslessly-convertible to this type.
-      new_operands[3] = MakeConvertToHlo(new_operands[3], conv_output_ty,
-                                         &new_operands[3]->metadata());
-    }
-
-    Shape new_shape = conv->shape();
-    new_shape.mutable_tuple_shapes(0)->set_element_type(conv_output_ty);
-
-    HloInstruction* new_conv = comp->AddInstruction(
-        conv->CloneWithNewOperands(new_shape, new_operands));
-    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
-    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
-                        MakeGetTupleElementHlo(new_conv, 0));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
-    changed = true;
-  }
-  return changed;
-}
-
-absl::Status CheckNoIllegalIntegerConvs(HloComputation* comp) {
-  auto is_integral_not_s8 = [](const Shape& s) {
-    return primitive_util::IsIntegralType(s.element_type()) &&
-           s.element_type() != S8;
-  };
-
-  std::vector<HloInstruction*> bad_convs;
-  for (HloInstruction* instr : comp->instructions()) {
-    if (!IsConvCustomCall(instr)) {
-      continue;
-    }
-    if (is_integral_not_s8(instr->shape().tuple_shapes(0)) ||
-        is_integral_not_s8(instr->operand(0)->shape()) ||
-        is_integral_not_s8(instr->operand(1)->shape()) ||
-        (instr->operand_count() >= 4 &&
-         is_integral_not_s8(instr->operand(3)->shape()))) {
-      bad_convs.push_back(instr);
-    }
-  }
-
-  if (bad_convs.empty()) {
-    return absl::OkStatus();
-  }
-
-  return Unimplemented(
-      R"(
-Can't lower one or more integer convolutions to idioms supported by CuDNN.
-
-CuDNN integer convolutions must have:
-
-  - s8 input and filter,
-  - f32 bias (if present),
-  - s8 or f32 output, and
-  - s8 side_input (if present) if output is s8.
-
-For each of the unsupported convs below, we weren't able to lower one of the
-operands or the output to the appropriate type.
-
-See specific HLO idioms in cudnn_fused_conv_rewriter.h, and see cudnn semantics:
-
-https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward and
-https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
-
-Unsupported convs:
-%s
-
-******* Full HLO module *******
-%s
-)",
-      absl::StrJoin(bad_convs, "\n",
-                    [](std::string* out, HloInstruction* instr) {
-                      absl::StrAppend(out, " - ", instr->ToString());
-                    }),
-      comp->parent()->ToString());
-}
-
-void VlogStats(HloModule* module) {
-  if (!VLOG_IS_ON(1)) {
-    return;
-  }
-
-  VLOG(1) << "Results of CudnnFusedConvRewriter for " << module->name();
-  absl::flat_hash_map<std::string, int> stats;
-  for (HloComputation* comp : module->MakeNonfusionComputations()) {
-    for (HloInstruction* instr : comp->instructions()) {
-      if (!Match(instr, m::Op().WithPredicate(IsConvCustomCall))) {
-        continue;
-      }
-
-      VLOG(3) << instr->ToString();
-
-      if (instr->custom_call_target() == kCudnnConvForwardCallTarget) {
-        ++stats["01 non-fused forward convs"];
-      } else if (instr->custom_call_target() ==
-                 kCudnnConvBiasActivationForwardCallTarget) {
-        ++stats["02 fused forward convs"];
-      }
-
-      PrimitiveType conv_in_ty = instr->operand(0)->shape().element_type();
-      PrimitiveType conv_out_ty = instr->shape().tuple_shapes(0).element_type();
-      if (conv_in_ty == F32) {
-        ++stats["10 f32 convs"];
-      } else if (conv_in_ty == F16) {
-        ++stats["11 f16 convs"];
-      } else if (conv_in_ty == S8) {
-        if (conv_out_ty == S8) {
-          ++stats["12 s8->s8 convs"];
-        } else if (conv_out_ty == F32) {
-          ++stats["13 s8->f32 convs"];
-        } else {
-          LOG(ERROR) << "Unexpected conv: " << instr->ToString();
-        }
-      }
-
-      if (instr->operand_count() > 2) {
-        ++stats["20 convs with bias"];
-        if (Match(instr->operand(2),
-                  m::Broadcast(m::ConstantEffectiveScalar(0)))) {
-          ++stats["21 convs with 0 bias"];
-        }
-      }
-      if (instr->operand_count() > 3) {
-        ++stats["22 convs with side-input"];
-      }
-
-      auto gpu_config = instr->backend_config<GpuBackendConfig>();
-      if (!gpu_config.ok()) {
-        LOG(ERROR) << "Couldn't parse backend config for " << instr->ToString();
-        continue;
-      }
-      const CudnnConvBackendConfig& config =
-          gpu_config->cudnn_conv_backend_config();
-      if (config.conv_result_scale() != 1) {
-        ++stats["30 convs with result scale"];
-      }
-      if (config.side_input_scale() != 0 && config.side_input_scale() != 1) {
-        ++stats["31 convs with side-input scale"];
-      }
-      ++stats[absl::StrCat(
-          "32 convs with activation mode ",
-          se::dnn::ActivationMode_Name(config.activation_mode()))];
-    }
-  }
-
-  std::vector<std::pair<std::string, int>> stats_sorted(stats.begin(),
-                                                        stats.end());
-  absl::c_sort(stats_sorted);
-  for (const auto& kv : stats_sorted) {
-    VLOG(1) << absl::StreamFormat("%4d %s", kv.second,
-                                  absl::string_view(kv.first).substr(3));
-  }
-}
-
-}  // namespace
-
-absl::StatusOr<bool> CudnnFusedConvRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool any_changed = false;
-
-  for (HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    bool changed = false;
-    // Rewrite FP8 convolutions and supported adjacent pointwise ops into a
-    // ForwardGraph Custom Call.
-    if (!IsROCm(compute_capability_)) {
-      auto cc = std::get<se::CudaComputeCapability>(compute_capability_);
-      TF_ASSIGN_OR_RETURN(
-          changed, F8GraphConv(comp, cc, dnn_version_, toolkit_version_));
-      if (changed) {
-        return changed;
-      }
-    }
-    // Fuse "inside out" starting with the operations closest to the conv.
-    TF_ASSIGN_OR_RETURN(changed, FuseRemoveConvertInConv(comp));
-    any_changed |= changed;
-
-    TF_ASSIGN_OR_RETURN(changed, FuseConvAlpha(comp));
-    any_changed |= changed;
-
-    // s8 convs' bias and side-input appear before conversion to s8.
-    //
-    // Run FuseBiasOrSideInput twice, so we get both the bias and the side
-    // input, if both are present.
-    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
-    any_changed |= changed;
-
-    // Relu might appear before or after convert-to-f16/s8, so we check in both
-    // cases.
-    TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_));
-    any_changed |= changed;
-
-    TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp));
-    any_changed |= changed;
-
-    TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp, compute_capability_));
-    any_changed |= changed;
-
-    // f16 convs' bias+side-input can appear before or after conversion to f16.
-    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
-    any_changed |= changed;
-
-    TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_));
-    any_changed |= changed;
-
-    // Check that we don't have any convs outputting integer types other than
-    // s8 - cudnn does not support these.  They should have been transformed to
-    // int8->int8 or int8->float above.
-    TF_RETURN_IF_ERROR(CheckNoIllegalIntegerConvs(comp));
-  }
-
-  VlogStats(module);
-
-  return any_changed;
-}
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.h b/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.h
deleted file mode 100644
index 906a67a..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.h
+++ /dev/null
@@ -1,135 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
-#define XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
-
-#include <cstdint>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites custom-calls targeting cudnnConvolutionForward to
-// cudnnConvolutionBiasActivationForward by fusing operations following forward
-// convolution.  This transform must run after GpuConvRewriter.
-//
-// Semantics of underlying cudnn ops:
-//
-// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#cudnnconvolutionforward
-// https://docs.nvidia.com/deeplearning/cudnn/latest/developer/misc.html#scaling-parameters
-//
-// ## Floating-point convs
-//
-// A "complete" fused floating-point conv has the form
-//
-//   max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)),
-//
-// which we fuse to
-//
-//   cudnnConvolutionBiasActivationForward(x, w, bias, side_input).
-//
-// You can leave out side_input, bias, alpha1, alpha2, and max(x, 0) and still
-// get a fused convolution.  alpha1/2 must be broadcasts of scalar constants.
-//
-// f16 convs accumulate in f32.  We represent this in HLO as an f32 convolution
-// whose inputs can be converted to f16 without loss of precision and whose
-// output is immediately converted to f16.  A fused f16 conv must follow one of
-// the following idioms.
-//
-//   1. convert_f16(conv_f32(x_f32, w_f32)) +
-//      side_input_f16 + broadcast(bias_f16)
-//
-//   2. convert_f16(conv_f32(x_f32, w_f32) +
-//                  side_input_f32 + broadcast(bias_f32))
-//
-// (These are not strictly mathematically equivalent, but cudnn doesn't tell us
-// which one it does, and we deem them "close enough".)
-//
-// The foo_f32 HLOs must all be losslessly-convertible to f16.  Some valid
-// examples:
-//
-//   - foo_f32 = convert_f32(foo_f16)
-//   - foo_f32 = an f32 constant whose values all fit within f16
-//   - foo_f32 = broadcast/transpose/reshape(one of the above)
-//
-// If you have a relu, it can appear before or after the convert_f16.
-//
-// Note that here `bias` must be losslessly-convertible to f16; this is
-// different from s8 convolutions, where the bias is f32.
-//
-// ## Integer convs
-//
-// In pure HLO, a "complete" integer conv is spelled as one of the following
-// `result`s.
-//
-//   base = alpha1_f32 * convert_f32(conv_s32(input_s32, filter_s32)) +
-//          alpha2_f32 * side_input +
-//          bias_f32
-//
-//   result_f32        = max(base, 0)
-//   result_s8_option1 = max(convert_s8(clamp(-128, base, 127)), 0)
-//   result_s8_option2 = convert_s8(clamp(-128, max(base, 0), 127))
-//
-// The foo_s32 HLOs must be losslessly-convertible to s8.  In the `result_s8`
-// case, side_input should be an f32 HLO that's losslessly-convertible to s8;
-// otherwise, it should be losslessly-convertible to f32.
-//
-// In the `result_s8` case where there's no bias, side-input, or alpha1, you can
-// skip the convert_f32 on conv.
-//
-// If you have an integer convolution that doesn't fit one of these idioms, this
-// pass returns an error -- cudnn will not be able to run it.
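-//
-// A minimal usage sketch (assuming a typical HloPassPipeline; the pass
-// arguments are illustrative):
-//
-//   HloPassPipeline pipeline("conv-fusion");
-//   pipeline.AddPass<GpuConvRewriter>(cuda_compute_capability);
-//   pipeline.AddPass<CudnnFusedConvRewriter>(cuda_compute_capability,
-//                                            dnn_version, toolkit_version);
-//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));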
-class CudnnFusedConvRewriter : public HloModulePass {
- public:
-  CudnnFusedConvRewriter(se::CudaComputeCapability cc,
-                         se::dnn::VersionInfo dnn_version,
-                         int32_t toolkit_version)
-      : compute_capability_(cc),
-        dnn_version_(dnn_version),
-        toolkit_version_(toolkit_version) {}
-  CudnnFusedConvRewriter(se::RocmComputeCapability cc,
-                         se::dnn::VersionInfo dnn_version,
-                         int32_t toolkit_version)
-      : compute_capability_(cc),
-        dnn_version_(dnn_version),
-        toolkit_version_(toolkit_version) {}
-
-  absl::string_view name() const override {
-    return "cudnn-fused-convolution-rewriter";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  const se::GpuComputeCapability compute_capability_;
-  const se::dnn::VersionInfo dnn_version_;
-  const int32_t toolkit_version_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
deleted file mode 100644
index 0a58ecf..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
+++ /dev/null
@@ -1,3170 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
-
-#include <array>
-#include <memory>
-#include <string>
-#include <string_view>
-#include <thread>  // NOLINT
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/container/flat_hash_map.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/str_replace.h"
-#include "absl/strings/string_view.h"
-#include "xla/comparison_util.h"
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "tsl/platform/statusor.h"
-
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
-#elif TENSORFLOW_USE_ROCM
-#include "rocm/rocm_config.h"
-#endif  // GOOGLE_CUDA
-
-#include "xla/service/algebraic_simplifier.h"
-#include "xla/service/convert_mover.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-#include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "xla/service/hlo_constant_folding.h"
-#include "xla/service/hlo_pass_fix.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/service/reshape_mover.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-// TODO(b/210165681): The tests in this file are fragile to HLO op names.
-
-namespace m = match;
-
-using ::testing::HasSubstr;
-using ::testing::Not;
-
-// TODO: Use constexpr vector once XLA is compiled with C++20.
-const auto* kf16f32f64 = new std::vector<std::string>({"f16", "f32", "f64"});
-const auto* kf16f32 = new std::vector<std::string>({"f16", "f32"});
-
-class CudnnFusedConvRewriterHloTest : public HloTestBase {
- public:
-  bool IsCuda() {
-    return std::holds_alternative<se::CudaComputeCapability>(
-        backend()
-            .default_stream_executor()
-            ->GetDeviceDescription()
-            .gpu_compute_capability());
-  }
-  se::CudaComputeCapability GetCudaComputeCapability() {
-    return backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .cuda_compute_capability();
-  }
-  stream_executor::dnn::VersionInfo GetDnnVersion() {
-    return GetDnnVersionInfoOrDefault(backend().default_stream_executor());
-  }
-
-  int32_t GetToolkitVersion() const {
-#if GOOGLE_CUDA
-    return CUDA_VERSION;
-#elif TENSORFLOW_USE_ROCM
-    return TF_ROCM_VERSION;
-#endif
-    return 0;
-  }
-
-  CudnnFusedConvRewriterHloTest()
-      : HloTestBase(/*verifier_layout_sensitive=*/false,
-                    /*allow_mixed_precision_in_hlo_verifier=*/false,
-                    /*instruction_can_change_layout_func=*/{}) {}
-};
-
-class CudnnFusedConvRewriterTest : public GpuCodegenTest {
- public:
-  bool IsCuda() {
-    return std::holds_alternative<se::CudaComputeCapability>(
-        backend()
-            .default_stream_executor()
-            ->GetDeviceDescription()
-            .gpu_compute_capability());
-  }
-  se::CudaComputeCapability GetCudaComputeCapability() {
-    return backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .cuda_compute_capability();
-  }
-  stream_executor::dnn::VersionInfo GetDnnVersion() {
-    return GetDnnVersionInfoOrDefault(backend().default_stream_executor());
-  }
-
-  int32_t GetToolkitVersion() const {
-#if GOOGLE_CUDA
-    return CUDA_VERSION;
-#elif TENSORFLOW_USE_ROCM
-    return TF_ROCM_VERSION;
-#endif
-    return 0;
-  }
-
- protected:
-  std::string GetOptimizedHlo(absl::string_view hlo_string) {
-    // cudnn_vectorize_convolutions transforms convolutions, making it hard to
-    // match them here in this test.  What's worse, the transforms it does
-    // depend on the GPU that's available!  So just disable the pass in this
-    // function that gets the optimized HLO.  When we actually run the module,
-    // the pass will still be enabled.
-    HloModuleConfig config = GetModuleConfigForTest();
-    DebugOptions debug_opts = config.debug_options();
-    debug_opts.add_xla_disable_hlo_passes("cudnn_vectorize_convolutions");
-    debug_opts.set_xla_gpu_use_runtime_fusion(true);
-    config.set_debug_options(debug_opts);
-
-    auto result = backend().compiler()->RunHloPasses(
-        ParseAndReturnVerifiedModule(hlo_string, config).value(),
-        backend().default_stream_executor(), backend().memory_allocator());
-    if (!result.status().ok()) {
-      TF_EXPECT_OK(result.status())
-          << "HLO compilation failed: " << result.status();
-      return "";
-    }
-    HloPrintOptions print_opts;
-    print_opts.set_print_operand_shape(false);
-    return (*result)->ToString(print_opts);
-  }
-
-  void TestMatchWithAllTypes(absl::string_view hlo_string) {
-    for (absl::string_view type : *(IsCuda() ? kf16f32f64 : kf16f32)) {
-      const std::string hlo_with_new_type =
-          absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
-      std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
-      EXPECT_THAT(optimized_hlo_string,
-                  Not(HasSubstr(kCudnnConvForwardCallTarget)))
-          << optimized_hlo_string;
-      EXPECT_THAT(optimized_hlo_string,
-                  HasSubstr(kCudnnConvBiasActivationForwardCallTarget));
-
-      TF_ASSERT_OK_AND_ASSIGN(auto module,
-                              ParseAndReturnVerifiedModule(hlo_with_new_type));
-      DebugOptions debug_opts = module->config().debug_options();
-      debug_opts.set_xla_gpu_use_runtime_fusion(true);
-      module->mutable_config().set_debug_options(debug_opts);
-      EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01}))
-          << optimized_hlo_string;
-    }
-  }
-
-  void TestClamp(absl::string_view pre_hlo_string,
-                 absl::string_view post_hlo_string) {
-    std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string);
-    EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert")));
-    EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv"));
-    EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.01}))
-        << pre_hlo_string;
-
-    absl::StatusOr<bool> filecheck_result =
-        RunFileCheck(optimized_hlo_string, post_hlo_string);
-    ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
-    EXPECT_TRUE(*filecheck_result);
-  }
-
-  void TestNotMatchWithAllTypes(absl::string_view hlo_string) {
-    for (absl::string_view type : *(IsCuda() ? kf16f32f64 : kf16f32)) {
-      const std::string hlo_with_new_type =
-          absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
-      std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
-      SCOPED_TRACE(optimized_hlo_string);
-      EXPECT_THAT(optimized_hlo_string, HasSubstr(kCudnnConvForwardCallTarget));
-      EXPECT_THAT(optimized_hlo_string,
-                  Not(HasSubstr(kCudnnConvBiasActivationForwardCallTarget)));
-    }
-  }
-
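-  // Verifies an FP8 convolution pattern. On Hopper and newer GPUs the full
-  // pipeline is run and both `custom_call_string` and
-  // `serialized_graph_string` are FileChecked against the optimized HLO; on
-  // older CUDA GPUs only the GpuConvRewriter/CudnnFusedConvRewriter rewrite is
-  // FileChecked with a reduced custom-call pattern. Does nothing on ROCm.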
-  void TestF8(std::string pre_hlo_string, std::string custom_call_string,
-              std::string serialized_graph_string) {
-    if (!IsCuda()) return;
-    if (GetCudaComputeCapability().IsAtLeast(
-            se::CudaComputeCapability::HOPPER)) {
-      // On Hopper and newer architectures, run the full compiler pipeline,
-      // test numerical correctness, and verify the HLO of the Custom Call
-      // (including operand and result layouts) as well as the serialized
-      // graph.
-      std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string);
-      EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert")));
-      EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv"));
-      EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.15, 0.15}))
-          << pre_hlo_string;
-
-      absl::StatusOr<bool> filecheck_result =
-          RunFileCheck(optimized_hlo_string, custom_call_string);
-      ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
-      EXPECT_TRUE(*filecheck_result);
-
-      filecheck_result =
-          RunFileCheck(optimized_hlo_string, serialized_graph_string);
-      ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
-      EXPECT_TRUE(*filecheck_result);
-    } else {
-      // On older architectures, disregard layout information and only verify
-      // the basic configuration of the convolution Custom Call (the number of
-      // operands, the window size, and the serialized graph attribute), based
-      // on the GpuConvRewriter and CudnnFusedConvRewriter passes.
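-      // Strip the result binding and tuple shape between "CHECK:" and
-      // "custom-call", and drop everything from ", dim_labels" onward, so the
-      // reduced pattern only checks the custom-call, its operands, and the
-      // window.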
-      std::string::size_type p0 = custom_call_string.find(':');
-      std::string::size_type p1 = custom_call_string.find("custom-call");
-      custom_call_string.erase(p0 + 1, p1 - p0 - 2);
-      p0 = custom_call_string.find(", dim_labels");
-      custom_call_string.erase(p0);
-
-      TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                              ParseAndReturnVerifiedModule(pre_hlo_string));
-      TF_ASSERT_OK_AND_ASSIGN(
-          bool changed, RunHloPass(GpuConvRewriter(GetCudaComputeCapability()),
-                                   module.get()));
-      EXPECT_TRUE(changed);
-      RunAndFilecheckHloRewrite(
-          module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
-          CudnnFusedConvRewriter(
-              se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0},
-              GetDnnVersion(), GetToolkitVersion()),
-          custom_call_string);
-      RunAndFilecheckHloRewrite(
-          module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
-          CudnnFusedConvRewriter(
-              se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0},
-              GetDnnVersion(), GetToolkitVersion()),
-          serialized_graph_string);
-    }
-  }
-
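-  // Expands the templated strings over every f8e4m3fn/f8e5m2 combination of
-  // input, filter, and output type (with clamp bounds matching the output
-  // type) and runs TestF8 on each instantiation.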
-  void TestF8Parameterized(std::string template_pre_hlo_string,
-                           std::string template_custom_call_string,
-                           std::string template_serialized_graph_string) {
-    std::array<absl::string_view, 2> types = {"f8e4m3fn", "f8e5m2"};
-    std::array<absl::string_view, 2> clamp_lower = {"-448.", "-57344."};
-    std::array<absl::string_view, 2> clamp_upper = {"448.", "57344."};
-    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
-    for (int i = 0; i < 2; ++i) {
-      replacements["<<InputType>>"] = types[i];
-      for (int j = 0; j < 2; ++j) {
-        replacements["<<FilterType>>"] = types[j];
-        for (int k = 0; k < 2; ++k) {
-          replacements["<<OutputType>>"] = types[k];
-          replacements["<<ClampLower>>"] = clamp_lower[k];
-          replacements["<<ClampUpper>>"] = clamp_upper[k];
-          TestF8(absl::StrReplaceAll(template_pre_hlo_string, replacements),
-                 absl::StrReplaceAll(template_custom_call_string, replacements),
-                 absl::StrReplaceAll(template_serialized_graph_string,
-                                     replacements));
-        }
-      }
-    }
-  }
-};
-
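-// MAYBE_SKIP_TEST(CAUSE) skips a test unconditionally on ROCm and, when CAUSE
-// is "F8", on CUDA builds older than CUDA 12 / cuDNN 8.9; otherwise it expands
-// to nothing.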
-#if GOOGLE_CUDA
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900)
-#define MAYBE_SKIP_TEST(CAUSE)                                           \
-  do {                                                                   \
-    if (absl::string_view(CAUSE) == "F8")                                \
-      GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; \
-  } while (0)
-#else
-#define MAYBE_SKIP_TEST(CAUSE)
-#endif
-#else
-#define MAYBE_SKIP_TEST(CAUSE)                                \
-  do {                                                        \
-    GTEST_SKIP() << "ROCm does not support " CAUSE " fusion"; \
-  } while (0)
-#endif
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) {
-  // max(0, conv(x, w));
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
-
-      input = TYPE[1,17,9,9] parameter(0)
-      filter = TYPE[3,3,17,32] parameter(1)
-
-      conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-      ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, DontFuseReluWithDepthwiseConv) {
-  // max(0, conv(x, w));
-  TestNotMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,17,9,9] broadcast(zero), dimensions={}
-
-      input = TYPE[1,17,9,9] parameter(0)
-      filter = TYPE[3,3,1,17] parameter(1)
-
-      conv = TYPE[1,17,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=17
-      ROOT relu = TYPE[1,17,9,9] maximum(zeros, conv)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestBias) {
-  // max(0, conv(x, w) + bias);
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,64,64] parameter(1)
-      bias = TYPE[64] parameter(2)
-
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
-      add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
-      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, Test3D) {
-  // max(0, conv(x, w) + bias);
-  std::string body = R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,5,7,64] broadcast(zero), dimensions={}
-
-      input = TYPE[1,3,5,7,64] parameter(0)
-      filter = TYPE[3,3,3,64,64] parameter(1)
-      bias = TYPE[64] parameter(2)
-
-      conv = TYPE[1,3,5,7,64] convolution(input, filter), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f, feature_group_count=1
-      broadcasted_bias = TYPE[1,3,5,7,64] broadcast(bias), dimensions={4}
-      add1 = TYPE[1,3,5,7,64] add(conv, broadcasted_bias)
-    )";
-
-  std::string relu = R"(
-      ROOT relu = TYPE[1,3,5,7,64] maximum(zeros, add1)
-    })";
-
-  std::string elu = R"(
-      cmp = pred[1,3,5,7,64] compare(add1, zeros), direction=GT
-      expm1 = TYPE[1,3,5,7,64] exponential-minus-one(add1)
-      ROOT elu = TYPE[1,3,5,7,64] select(cmp, add1, expm1)
-    })";
-
-  TestMatchWithAllTypes(body + relu);
-  if (!IsCuda()) TestMatchWithAllTypes(body + elu);
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestBiasMultiCall) {
-  // max(0, conv(x, w) + bias);
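-  // Invokes the matcher several times with different (and one repeated)
-  // spatial sizes, presumably to exercise repeated compilation of the same
-  // fused pattern.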
-  std::string code = R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,<<<format>>>,64] broadcast(zero), dimensions={}
-
-      input = TYPE[1,<<<format>>>,64] parameter(0)
-      filter = TYPE[3,3,64,64] parameter(1)
-      bias = TYPE[64] parameter(2)
-
-      conv = TYPE[1,<<<format>>>,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-      broadcasted_bias = TYPE[1,<<<format>>>,64] broadcast(bias), dimensions={3}
-      add1 = TYPE[1,<<<format>>>,64] add(conv, broadcasted_bias)
-      ROOT relu = TYPE[1,<<<format>>>,64] maximum(zeros, add1)
-    })";
-  absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
-  replacements["<<<format>>>"] = "3,3";
-  TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements));
-  replacements["<<<format>>>"] = "5,5";
-  TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements));
-  replacements["<<<format>>>"] = "3,3";
-  TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements));
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestBiasNoRelu) {
-  // conv(x, w) + bias;
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,64,64] parameter(1)
-      bias = TYPE[64] parameter(2)
-
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
-      ROOT add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, DontFuseBiasWithDepthwiseConv) {
-  // conv(x, w) + bias;
-  TestNotMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,1,64] parameter(1)
-      bias = TYPE[64] parameter(2)
-
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
-      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
-      add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
-      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestElu) {
-  // sum = conv(x, w) + bias
-  // select(compare(sum, 0, GT), sum, exponential-minus-one(sum));
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,64,64] parameter(1)
-      bias = TYPE[64] parameter(2)
-
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
-      sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
-      cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT
-      expm1 = TYPE[1,3,3,64] exponential-minus-one(sum)
-      ROOT elu = TYPE[1,3,3,64] select(cmp, sum, expm1)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, DontFuseEluWithDepthwiseConv) {
-  // sum = conv(x, w) + bias
-  // select(compare(sum, 0, GT), sum, exponential-minus-one(sum));
-  TestNotMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,1,64] parameter(1)
-      bias = TYPE[64] parameter(2)
-
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
-      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
-      sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
-      cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT
-      expm1 = TYPE[1,3,3,64] exponential-minus-one(sum)
-      ROOT elu = TYPE[1,3,3,64] select(cmp, sum, expm1)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestRelu6) {
-  if (IsCuda() && !GetCudaComputeCapability().IsAtLeast(
-                      se::CudaComputeCapability::AMPERE)) {
-    GTEST_SKIP() << "Conv-Bias-Relu6 fusion is only supported on "
-                    "Nvidia Ampere+ GPUs.";
-  }
-  // sum = conv(x, w) + bias
-  // clamp(0, sum, 6);
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-      six = TYPE[] constant(6)
-      sixes = TYPE[1,3,3,64] broadcast(six), dimensions={}
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,64,64] parameter(1)
-      bias = TYPE[64] parameter(2)
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
-      sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
-      ROOT relu6 = TYPE[1,3,3,64] clamp(zeros, sum, sixes)
-    })");
-}
-
-// At time of writing, cudnn runtime fusion cannot handle f16 convs with an odd
-// number of input/output channels.  Check that we don't try to run this conv
-// with runtime fusion (or, if we do, that it works!).
-TEST_F(CudnnFusedConvRewriterTest, TestRelu6OddChannels) {
-  if (IsCuda() && !GetCudaComputeCapability().IsAtLeast(
-                      se::CudaComputeCapability::AMPERE)) {
-    GTEST_SKIP() << "Conv-Bias-Relu6 fusion is only supported on "
-                    "Nvidia Ampere+ GPUs.";
-  }
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-    ENTRY Test {
-      zeros = TYPE[1,384,1024,32] broadcast(TYPE[] constant(0)), dimensions={}
-      sixes = TYPE[1,384,1024,32] broadcast(TYPE[] constant(6)), dimensions={}
-      input = TYPE[1,769,2049,3] parameter(0)
-      filter = TYPE[32,3,3,3] parameter(1)
-      bias = TYPE[32] parameter(2)
-      conv = TYPE[1,384,1024,32] convolution(input, filter), window={size=3x3 stride=2x2}, dim_labels=b01f_o01i->b01f
-      broadcasted_bias = TYPE[1,384,1024,32] broadcast(bias), dimensions={3}
-      sum = add(conv, broadcasted_bias)
-      ROOT relu6 = clamp(zeros, sum, sixes)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestLeakyRelu) {
-  if (IsCuda() && !GetCudaComputeCapability().IsAtLeast(
-                      se::CudaComputeCapability::AMPERE)) {
-    GTEST_SKIP()
-        << "Conv-Bias-LeakyRelu fusion is only supported on "
-           "Nvidia Ampere+ GPUs.";
-  }
-  // sum = conv(x, w) + bias
-  // select(compare(sum, 0, GT), sum, multiply(sum, alpha));
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-      alpha = TYPE[] constant(0.2)
-      alphas = TYPE[1,3,3,64] broadcast(alpha), dimensions={}
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,64,64] parameter(1)
-      bias = TYPE[64] parameter(2)
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
-      sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
-      cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT
-      mul = TYPE[1,3,3,64] multiply(sum, alphas)
-      ROOT elu = TYPE[1,3,3,64] select(cmp, sum, mul)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestSideInputOnly) {
-  // max(0, conv(x, w) + side_input);
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,64,64] parameter(1)
-      side_input = TYPE[1,3,3,64] parameter(2)
-
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-      add1 = TYPE[1,3,3,64] add(conv, side_input)
-      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, DontFuseSideInputWithDepthwiseConv) {
-  // max(0, conv(x, w) + side_input);
-  TestNotMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,1,64] parameter(1)
-      side_input = TYPE[1,3,3,64] parameter(2)
-
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
-      add1 = TYPE[1,3,3,64] add(conv, side_input)
-      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestBiasAndSideInput) {
-  // max(0, conv(x, w) + side_input + bias);
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,64,64] parameter(1)
-      side_input = TYPE[1,3,3,64] parameter(2)
-      bias = TYPE[64] parameter(3)
-
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
-      add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
-      add2 = TYPE[1,3,3,64] add(add1, side_input)
-      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestScaledConv) {
-  // max(0, 0.999994934 * conv(x, w));
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
-      alpha_conv_scalar = TYPE[] constant(0.999994934)
-
-      input = TYPE[1,17,9,9] parameter(0)
-      filter = TYPE[3,3,17,32] parameter(1)
-
-      conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-      alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
-      scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv)
-      ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, DontFuseScaledDepthwiseConv) {
-  // max(0, 0.999994934 * conv(x, w));
-  TestNotMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,17,9,9] broadcast(zero), dimensions={}
-      alpha_conv_scalar = TYPE[] constant(0.999994934)
-
-      input = TYPE[1,17,9,9] parameter(0)
-      filter = TYPE[3,3,1,17] parameter(1)
-
-      conv = TYPE[1,17,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=17
-      alpha_conv = TYPE[1,17,9,9] broadcast(alpha_conv_scalar), dimensions={}
-      scaled_conv = TYPE[1,17,9,9] multiply(conv, alpha_conv)
-      ROOT relu = TYPE[1,17,9,9] maximum(zeros, scaled_conv)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestNoCrashOnInf) {
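-  // Note: `zeros` is actually a broadcast of inf; the test only checks that
-  // compiling and running max(inf, scaled_conv) matches the reference and does
-  // not crash.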
-  EXPECT_TRUE(RunAndCompare(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = f32[] constant(inf)
-      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
-      alpha_conv_scalar = f32[] constant(0.999994934)
-
-      input = f32[1,17,9,9] parameter(0)
-      filter = f32[3,3,17,32] parameter(1)
-
-      conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-      alpha_conv = f32[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
-      scaled_conv = f32[1,32,9,9] multiply(conv, alpha_conv)
-      ROOT relu = f32[1,32,9,9] maximum(zeros, scaled_conv)
-    })",
-                            ErrorSpec{0.01}));
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvAndScaledSideInput) {
-  // max(0, conv(x, w) + 0.899994934 * side_input);
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-      alpha_side_input_scalar = TYPE[] constant(0.899994934)
-      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,64,64] parameter(1)
-      side_input = TYPE[1,3,3,64] parameter(2)
-
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
-      add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
-      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, DontFuseDepthwiseConvWithScaledSideInput) {
-  // max(0, conv(x, w) + 0.899994934 * side_input);
-  TestNotMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-      alpha_side_input_scalar = TYPE[] constant(0.899994934)
-      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,1,64] parameter(1)
-      side_input = TYPE[1,3,3,64] parameter(2)
-
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
-      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
-      add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
-      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInput) {
-  // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input);
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-      alpha_conv_scalar = TYPE[] constant(0.999994934)
-      alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
-      alpha_side_input_scalar = TYPE[] constant(0.899994934)
-      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,64,64] parameter(1)
-      side_input = TYPE[1,3,3,64] parameter(2)
-
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-      scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
-      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
-      add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input)
-      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInputWithBias) {
-  // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias);
-  TestMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = TYPE[] constant(0)
-      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
-      alpha_conv_scalar = TYPE[] constant(0.999994934)
-      alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
-      alpha_side_input_scalar = TYPE[] constant(0.899994934)
-      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
-      input = TYPE[1,3,3,64] parameter(0)
-      filter = TYPE[3,3,64,64] parameter(1)
-      side_input = TYPE[1,3,3,64] parameter(2)
-      bias = TYPE[64] parameter(3)
-
-      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-      scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
-      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
-      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
-      add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias)
-      add2 = TYPE[1,3,3,64] add(add1, scaled_side_input)
-      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestMatchMaxZeroOnly) {
-  // max(0.1, conv(x, w)) shouldn't match.
-  TestNotMatchWithAllTypes(R"(
-    HloModule Test
-
-    ENTRY Test {
-      point_one = TYPE[] constant(0.1)
-      point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={}
-
-      input = TYPE[1,17,9,9] parameter(0)
-      filter = TYPE[3,3,17,32] parameter(1)
-
-      conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-      ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv)
-    })");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) {
-  const char* kHloString = R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = f32[] constant(0)
-      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
-
-      input = f32[1,17,9,9] parameter(0)
-      filter = f32[3,3,17,32] parameter(1)
-
-      conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo" op_name="bar"}
-      ROOT relu = f32[1,32,9,9] maximum(zeros, conv)
-    })";
-
-  const std::string optimized_hlo_string =
-      backend()
-          .compiler()
-          ->RunHloPasses(
-              ParseAndReturnVerifiedModule(kHloString, GetModuleConfigForTest())
-                  .value(),
-              backend().default_stream_executor(), backend().memory_allocator())
-          .value()
-          ->ToString();
-  EXPECT_THAT(optimized_hlo_string,
-              ::testing::ContainsRegex(
-                  R"(custom-call.*metadata=\{op_type="foo" op_name="bar"\})"));
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) {
-  // The convolution below would crash if feature_group_count were not
-  // preserved.
-  const char* kHloString = R"(
-    HloModule jaxpr_computation__6.19
-
-    primitive_computation__1.4 {
-      parameter.5 = f32[] parameter(0)
-      parameter.6 = f32[] parameter(1)
-      ROOT add.7 = f32[] add(parameter.5, parameter.6)
-    }
-
-    ENTRY jaxpr_computation__7.8 {
-      parameter.11 = f32[2,64,64,53]{3,2,1,0} parameter(1)
-      parameter.10 = f32[3,3,1,53]{3,2,1,0} parameter(0)
-      convolution.12 = f32[2,64,64,53]{3,2,1,0} convolution(parameter.11, parameter.10), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=53
-      constant.13 = f32[] constant(0)
-      broadcast.14 = f32[2,64,64,53]{3,2,1,0} broadcast(constant.13), dimensions={}
-      maximum.15 = f32[2,64,64,53]{3,2,1,0} maximum(convolution.12, broadcast.14)
-      ROOT reduce.17 = f32[] reduce(maximum.15, constant.13), dimensions={0,1,2,3}, to_apply=primitive_computation__1.4
-    }
-  )";
-  EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.01}));
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvF8) {
-  MAYBE_SKIP_TEST("F8");
-  TestF8(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-       input = f8e4m3fn[1,128,6,6] parameter(0)
-       filter = f8e4m3fn[3,3,128,16] parameter(1)
-       ROOT conv_a = f8e4m3fn[1,16,6,6] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-
-    })",
-      // custom_call
-      R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
-  )",
-      // serialized_graph
-      R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f8e4m3fn]conv();"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputF8) {
-  MAYBE_SKIP_TEST("F8");
-  TestF8(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-       input = f8e4m3fn[1,128,6,6] parameter(0)
-       filter = f8e4m3fn[3,3,128,16] parameter(1)
-       input_f32 = f32[1,128,6,6] convert(input)
-       filter_f32 = f32[3,3,128,16] convert(filter)
-       z_scale = f32[] parameter(2)
-       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
-       conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-       conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
-       c1 = f32[] constant(-448.)
-       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
-       c2 = f32[] constant(448.)
-       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
-       conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
-       ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
-
-    })",
-      // custom_call
-      R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
-  )",
-      // serialized_graph
-      R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE_UID:[0-9]+]]:[f8e4m3fn]scale([[CONV_UID]]);"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvInvscaledOutputF8) {
-  MAYBE_SKIP_TEST("F8");
-  TestF8(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-       input = f8e4m3fn[1,128,6,6] parameter(0)
-       filter = f8e4m3fn[3,3,128,16] parameter(1)
-       input_f32 = f32[1,128,6,6] convert(input)
-       filter_f32 = f32[3,3,128,16] convert(filter)
-       z_scale = f32[] parameter(2)
-       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
-       conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-       conv_a_scaled = f32[1,16,6,6] divide(conv_a, z_scale_bcast)
-       c1 = f32[] constant(-448.)
-       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
-       c2 = f32[] constant(448.)
-       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
-       conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
-       ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
-
-    })",
-      // custom_call
-      R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
-  )",
-      // serialized_graph
-      R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f8e4m3fn]invscale([[CONV_UID]]);"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvScaledF8Parameterized) {
-  MAYBE_SKIP_TEST("F8");
-  TestF8Parameterized(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-       input = <<InputType>>[1,128,6,6] parameter(0)
-       filter = <<FilterType>>[3,3,128,16] parameter(1)
-       input_scale = f32[] parameter(2)
-       input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
-       filter_scale = f32[] parameter(3)
-       filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
-       input_f32 = f32[1,128,6,6] convert(input)
-       input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
-       filter_f32 = f32[3,3,128,16] convert(filter)
-       filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
-       z_scale = f32[] parameter(4)
-       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
-       conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-       conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
-       c1 = f32[] constant(<<ClampLower>>)
-       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
-       c2 = f32[] constant(<<ClampUpper>>)
-       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
-       conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
-       ROOT conv_f8 = <<OutputType>>[1,16,6,6] convert(conv_a_clamped)
-
-    })",
-      // custom_call
-      R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (<<OutputType>>[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
-      )",
-      // serialized_graph
-      R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[SCALE2_UID:[0-9]+]]:[<<OutputType>>]scale([[SCALE1_UID]]);"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvScaledBiasF8) {
-  MAYBE_SKIP_TEST("F8");
-  TestF8(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-       input = f8e4m3fn[1,128,6,6] parameter(0)
-       filter = f8e4m3fn[3,3,128,16] parameter(1)
-       input_scale = f32[] parameter(2)
-       input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
-       filter_scale = f32[] parameter(3)
-       filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
-       input_f32 = f32[1,128,6,6] convert(input)
-       input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
-       filter_f32 = f32[3,3,128,16] convert(filter)
-       filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
-       bias = f32[1,16,6,6] parameter(4)
-       z_scale = f32[] parameter(5)
-       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
-       conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-       conv_a_bias = f32[1,16,6,6] add(conv_a, bias)
-       conv_a_scaled = f32[1,16,6,6] multiply(conv_a_bias, z_scale_bcast)
-       c1 = f32[] constant(-448.)
-       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
-       c2 = f32[] constant(448.)
-       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
-       conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
-       ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
-
-    })",
-      // custom_call
-      R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]], /*index=5*/[[OPERAND5:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
-      )",
-      // serialized_graph
-      R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[ADD_UID:[0-9]+]]:[f32]add([[SCALE1_UID]]);[[SCALE2_UID:[0-9]+]]:[f8e4m3fn]scale([[ADD_UID]]);"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvScaledReluF8) {
-  MAYBE_SKIP_TEST("F8");
-  TestF8(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-       input = f8e4m3fn[1,128,6,6] parameter(0)
-       filter = f8e4m3fn[3,3,128,16] parameter(1)
-       input_f32 = f32[1,128,6,6] convert(input)
-       filter_f32 = f32[3,3,128,16] convert(filter)
-       z_scale = f32[] parameter(2)
-       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
-       c = f32[] constant(0)
-       c_bcast = f32[1,16,6,6] broadcast(c), dimensions={}
-       conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-       relu_a = f32[1,16,6,6] maximum(conv_a, c_bcast)
-       relu_a_scaled = f32[1,16,6,6] multiply(relu_a, z_scale_bcast)
-       c1 = f32[] constant(-448.)
-       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
-       c2 = f32[] constant(448.)
-       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
-       relu_a_clamped = f32[1,16,6,6] clamp(c1_bcast, relu_a_scaled, c2_bcast)
-       ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(relu_a_clamped)
-
-    })",
-      // custom_call
-      R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
-  )",
-      // serialized_graph
-      R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[RELU_UID:[0-9]+]]:[f32]relu([[CONV_UID]]);[[SCALE0_UID:[0-9]+]]:[f8e4m3fn]scale([[RELU_UID]]);"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvAmaxF8) {
-  MAYBE_SKIP_TEST("F8");
-  TestF8(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] maximum(a, b)
-    }
-
-    ENTRY Test {
-       input = f8e4m3fn[1,128,6,6] parameter(0)
-       filter = f8e4m3fn[3,3,128,16] parameter(1)
-       input_scale = f32[] parameter(2)
-       input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
-       filter_scale = f32[] parameter(3)
-       filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
-       input_f32 = f32[1,128,6,6] convert(input)
-       input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
-       filter_f32 = f32[3,3,128,16] convert(filter)
-       filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
-       conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-       z_scale = f32[] parameter(4)
-       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
-       conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
-       c1 = f32[] constant(-448.)
-       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
-       c2 = f32[] constant(448.)
-       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
-       conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
-       conv_a_clamped_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
-       abs_conv_a = f32[1,16,6,6] abs(conv_a)
-       c0 = f32[] constant(-inf)
-       amax = f32[] reduce(abs_conv_a, c0), dimensions={0,1,2,3}, to_apply=apply
-       ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f32[]) tuple(conv_a_clamped_f8, amax)
-
-    })",
-      // custom_call
-      R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, f32[], u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
-    )",
-      // serialized_graph
-      R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[SCALE2_UID:[0-9]+]]:[f8e4m3fn]scale([[SCALE1_UID]]);[[AMAX_UID:[0-9]+]]:[f32]amax([[SCALE1_UID]]);"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvReluAmaxF8) {
-  MAYBE_SKIP_TEST("F8");
-  TestF8(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] maximum(a, b)
-    }
-
-    ENTRY Test {
-       input = f8e4m3fn[1,128,6,6] parameter(0)
-       filter = f8e4m3fn[3,3,128,16] parameter(1)
-       input_scale = f32[] parameter(2)
-       input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
-       filter_scale = f32[] parameter(3)
-       filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
-       input_f32 = f32[1,128,6,6] convert(input)
-       input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
-       filter_f32 = f32[3,3,128,16] convert(filter)
-       filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
-       c = f32[] constant(0)
-       c_bcast = f32[1,16,6,6] broadcast(c), dimensions={}
-       conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-       relu_a = f32[1,16,6,6] maximum(conv_a, c_bcast)
-       z_scale = f32[] parameter(4)
-       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
-       relu_a_scaled = f32[1,16,6,6] multiply(relu_a, z_scale_bcast)
-       c1 = f32[] constant(-448.)
-       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
-       c2 = f32[] constant(448.)
-       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
-       relu_a_clamped = f32[1,16,6,6] clamp(c1_bcast, relu_a_scaled, c2_bcast)
-       relu_a_clamped_f8 = f8e4m3fn[1,16,6,6] convert(relu_a_clamped)
-       abs_relu_a = f32[1,16,6,6] abs(relu_a)
-       c0 = f32[] constant(-inf)
-       amax = f32[] reduce(abs_relu_a, c0), dimensions={0,1,2,3}, to_apply=apply
-       ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f32[]) tuple(relu_a_clamped_f8, amax)
-
-    })",
-      // custom_call
-      R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, f32[], u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
-    )",
-      // serialized_graph
-      R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[RELU_UID:[0-9]+]]:[f32]relu([[SCALE1_UID]]);[[SCALE2_UID:[0-9]+]]:[f8e4m3fn]scale([[RELU_UID]]);[[AMAX_UID:[0-9]+]]:[f32]amax([[RELU_UID]]);"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputMultipleUsersF8) {
-  MAYBE_SKIP_TEST("F8");
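-  // The conv result has two scaled users, so the scale cannot be fused; only
-  // the plain f32 conv is expected in the serialized graph below.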
-  TestF8(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-       input = f8e4m3fn[1,128,6,6] parameter(0)
-       filter = f8e4m3fn[3,3,128,16] parameter(1)
-       input_f32 = f32[1,128,6,6] convert(input)
-       filter_f32 = f32[3,3,128,16] convert(filter)
-       z_scale0 = f32[] parameter(2)
-       z_scale0_bcast = f32[1,16,6,6] broadcast(z_scale0), dimensions={}
-       z_scale1 = f32[] parameter(3)
-       z_scale1_bcast = f32[1,16,6,6] broadcast(z_scale1), dimensions={}
-       conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-       conv_a_scaled0 = f32[1,16,6,6] multiply(conv_a, z_scale0_bcast)
-       conv_a_scaled1 = f32[1,16,6,6] multiply(conv_a, z_scale1_bcast)
-       c1 = f32[] constant(-448.)
-       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
-       c2 = f32[] constant(448.)
-       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
-       conv_a_clamped0 = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled0, c2_bcast)
-       conv_a_clamped1 = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled1, c2_bcast)
-       conv_a_convert0 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped0)
-       conv_a_convert1 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped1)
-       ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f8e4m3fn[1,16,6,6]) tuple(conv_a_convert0, conv_a_convert1)
-
-    })",
-      // custom_call
-      R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f32[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
-  )",
-      // serialized_graph
-      R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputUnsupportedUserF8) {
-  MAYBE_SKIP_TEST("F8");
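-  // The cosine user cannot be fused into the conv graph, so the scale/clamp
-  // chain stays outside and only the plain f32 conv is expected below.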
-  TestF8(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-       input = f8e4m3fn[1,128,6,6] parameter(0)
-       filter = f8e4m3fn[3,3,128,16] parameter(1)
-       input_f32 = f32[1,128,6,6] convert(input)
-       filter_f32 = f32[3,3,128,16] convert(filter)
-       z_scale = f32[] parameter(2)
-       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
-       conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-       conv_a_cos = f32[1,16,6,6] cosine(conv_a)
-       conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
-       c1 = f32[] constant(-448.)
-       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
-       c2 = f32[] constant(448.)
-       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
-       conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
-       conv_a_convert = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
-       ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f32[1,16,6,6]) tuple(conv_a_convert, conv_a_cos)
-
-    })",
-      // custom_call
-      R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f32[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
-  )",
-      // serialized_graph
-      R"(
-// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8) {
-  MAYBE_SKIP_TEST("I8");
-  // max(0, clamp(conv(x, w)))); for int8_t
-  TestClamp(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = s8[] constant(0)
-      zeros = s8[1,32,9,9] broadcast(zero), dimensions={}
-
-      input = s8[1,17,9,9] parameter(0)
-      filter = s8[3,3,17,32] parameter(1)
-
-      inputs32 = s32[1,17,9,9] convert(input)
-      filters32 = s32[3,3,17,32] convert(filter)
-
-      conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-
-      lower = s32[] constant(-128)
-      lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
-      upper = s32[] constant(127)
-      uppers = s32[1,32,9,9] broadcast(upper), dimensions={}
-
-      clamp = s32[1,32,9,9] clamp(lowers, conv, uppers)
-
-      ROOT convert = s8[1,32,9,9] convert(clamp)
-    })",
-      // post_hlo
-      R"(
-// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (s8[1,9,9,32]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[fusion_2_1:%[^ ]+]], [[fusion_1_2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) {
-  MAYBE_SKIP_TEST("I8");
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = s8[1,17,9,9] parameter(0)
-      filter = s8[3,3,17,32] parameter(1)
-
-      inputs32 = s32[1,17,9,9] convert(input)
-      filters32 = s32[3,3,17,32] convert(filter)
-
-      conv = s32[1,32,9,9] convolution(inputs32, filters32),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-
-      ROOT convert = f32[1,32,9,9] convert(conv)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                             m::CustomCall({kCudnnConvForwardCallTarget}), 0)
-                             .WithShape(F32, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) {
-  MAYBE_SKIP_TEST("I8");
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
-      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
-      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
-      side_input = f32[1,32,9,9] convert(s8[1,32,9,9] parameter(3))
-
-      conv = s32[1,32,9,9] convolution(input, filter),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      conv_f32 = f32[1,32,9,9] convert(conv)
-      ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
-                                             add(add(conv_f32, bias), side_input),
-                                             f32[1,32,9,9] broadcast(f32[] constant(127))))
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify any new `convert`s that the rewrite may have added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
-                                   m::Parameter(0), m::Parameter(1),
-                                   m::Parameter(2), m::Parameter(3)),
-                     0)
-                     .WithShape(S8, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) {
-  MAYBE_SKIP_TEST("I8");
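-  // With no explicit bias in the pattern, the rewriter is expected to attach
-  // an all-zero f32 bias operand to the fused custom-call (see the matcher
-  // below) and to record a ReLU activation mode.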
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
-      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
-
-      conv = s32[1,32,9,9] convolution(input, filter),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      conv_s8 = s8[1,32,9,9] convert(clamp(s32[1,32,9,9] broadcast(s32[] constant(-128)),
-                                           conv,
-                                           s32[1,32,9,9] broadcast(s32[] constant(127))))
-      zeros = s8[1,32,9,9] broadcast(s8[] constant(0)), dimensions={}
-      ROOT root = s8[1,32,9,9] maximum(conv_s8, zeros)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify any new `convert`s that the rewrite may have added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(
-                  &conv, {kCudnnConvBiasActivationForwardCallTarget},
-                  m::Parameter(0),  //
-                  m::Parameter(1),  //
-                  m::Broadcast(
-                      m::ConstantEffectiveScalar(0).WithElementType(F32))),
-              0)
-              .WithShape(S8, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) {
-  MAYBE_SKIP_TEST("I8");
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = s8[1,17,9,9] parameter(0)
-      filter = s8[3,3,17,32] parameter(1)
-      bias = f32[32] parameter(2)
-      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
-      side_input_f32 = f32[1,32,9,9] parameter(3)
-
-      inputs32 = s32[1,17,9,9] convert(input)
-      filters32 = s32[3,3,17,32] convert(filter)
-
-      conv = s32[1,32,9,9] convolution(inputs32, filters32),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      conv_f32 = f32[1,32,9,9] convert(conv)
-      sum1 = add(conv_f32, bias_broadcast)
-      ROOT sum2 = add(sum1, side_input_f32)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify any new `convert`s that the rewrite may have added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
-                                   m::Parameter(0), m::Parameter(1),
-                                   m::Parameter(2), m::Parameter(3)),
-                     0)
-                     .WithShape(F32, {1, 32, 9, 9})));
-}
-
-// The ReshapeMover pass changes
-//   reshape(side_input) * alpha -->
-//   reshape(side_input * alpha).
-// Make sure we can pattern-match this.
-TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) {
-  MAYBE_SKIP_TEST("I8");
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
-      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
-      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
-      side_input_scale = f32[2592] broadcast(f32[] constant(0.25)), dimensions={}
-      side_input = f32[1,32,9,9] reshape(multiply(f32[2592] convert(s8[2592] parameter(3)), side_input_scale))
-
-      conv = s32[1,32,9,9] convolution(input, filter),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
-                                             add(add(f32[1,32,9,9] convert(conv), bias), side_input),
-                                             f32[1,32,9,9] broadcast(f32[] constant(127))))
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify any new `convert`s that the rewrite may have added to the graph.
-  HloPassFix<HloPassPipeline> simplify("simplify");
-  simplify.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions{});
-  simplify.AddPass<ReshapeMover>();
-  simplify.AddPass<ConvertMover>();
-  TF_ASSERT_OK(RunHloPass(&simplify, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv = nullptr;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(
-                  &conv, {kCudnnConvBiasActivationForwardCallTarget},
-                  m::Parameter(0),  //
-                  m::Parameter(1),  //
-                  m::Parameter(2),  //
-                  m::Reshape(m::Parameter(3)).WithShape(S8, {1, 32, 9, 9})),
-              0)
-              .WithShape(S8, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.conv_result_scale(), 1);
-  EXPECT_EQ(config.side_input_scale(), 0.25);
-}
-
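-// Roughly speaking, the fused cudnn custom-call computes
-//   activation(conv_result_scale * conv(input, filter) +
-//              side_input_scale * side_input + bias),
-// so a constant alpha multiplying the conv result should end up in
-// conv_result_scale rather than remaining as a separate multiply.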
-TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) {
-  MAYBE_SKIP_TEST("I8");
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = s8[1,17,9,9] parameter(0)
-      filter = s8[3,3,17,32] parameter(1)
-      inputs32 = s32[1,17,9,9] convert(input)
-      filters32 = s32[3,3,17,32] convert(filter)
-      alpha = f32[] constant(42)
-      alpha_broadcast = f32[1,32,9,9] broadcast(alpha), dimensions={}
-
-      conv = s32[1,32,9,9] convolution(inputs32, filters32),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      convert = f32[1,32,9,9] convert(conv)
-      ROOT root = multiply(convert, alpha_broadcast)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv = nullptr;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget}),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.conv_result_scale(), 42);
-}
-
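-// maximum(conv + bias, 0) should collapse into a single bias-activation
-// custom-call with activation_mode set to kRelu.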
-TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[32] parameter(2)
-      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
-      zero = f32[] constant(0)
-      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, bias_broadcast)
-      ROOT relu = maximum(sum, zeros)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
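-// Here the add feeding the relu has a second user, so only the bias is
-// fused; the maximum stays in the graph and activation_mode stays kNone.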
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
-      zeros = f32[1,32,9,9] broadcast(f32[] constant(0)), dimensions={}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, bias)
-      relu = maximum(sum, zeros)
-      not_relu = minimum(sum, zeros)
-      ROOT root = tuple(relu, not_relu)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::MaximumAnyOrder(
-              m::Broadcast(m::ConstantEffectiveScalar(0)),
-              m::GetTupleElement(
-                  m::CustomCall(
-                      &conv, {kCudnnConvBiasActivationForwardCallTarget},
-                      m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-                  0)
-                  .WithShape(F32, {1, 32, 9, 9})),
-          m::Minimum())));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
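-// select(compare(sum, 0, GT), sum, expm1(sum)) is the Elu pattern. Fusing it
-// requires the runtime-fusion debug flag and an Ampere-or-newer GPU, hence
-// the explicit compute capability 8.0 below.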
-TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f16[1,16,9,9] parameter(0)
-      filters = f16[3,3,16,32] parameter(1)
-      bias = f16[32] parameter(2)
-      bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1}
-      zero = f16[] constant(0)
-      zeros = f16[1,32,9,9] broadcast(zero), dimensions={}
-      conv = f16[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, bias_broadcast)
-      cmp = compare(sum, zeros), direction=GT
-      expm1 = exponential-minus-one(sum)
-      ROOT elu = select(cmp, sum, expm1)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  DebugOptions debug_opts = m->config().debug_options();
-  debug_opts.set_xla_gpu_use_runtime_fusion(true);
-  m->mutable_config().set_debug_options(debug_opts);
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  // Elu fusion is only enabled on Ampere+.
-  CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-              0)
-              .WithShape(F16, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.activation_mode(), se::dnn::kElu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f16[1,16,9,9] parameter(0)
-      filters = f16[3,3,16,32] parameter(1)
-      bias = f16[32] parameter(2)
-      bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1}
-      zero = f16[] constant(0)
-      zeros = f16[1,32,9,9] broadcast(zero), dimensions={}
-      conv = f16[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, bias_broadcast)
-      cmp = compare(sum, zeros), direction=GT
-      expm1 = exponential-minus-one(sum)
-      elu = select(cmp, sum, expm1)
-      not_elu = minimum(sum, zeros)
-      ROOT root = tuple(elu, not_elu)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  DebugOptions debug_opts = m->config().debug_options();
-  debug_opts.set_xla_gpu_use_runtime_fusion(true);
-  m->mutable_config().set_debug_options(debug_opts);
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  auto gte_pattern =
-      m::GetTupleElement(
-          m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                        m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-          0)
-          .WithShape(F16, {1, 32, 9, 9});
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::Select(m::Compare(gte_pattern,
-                               m::Broadcast(m::ConstantEffectiveScalar(0)))
-                        .WithComparisonDirection(ComparisonDirection::kGt),
-                    gte_pattern,
-                    m::Op()
-                        .WithPredicate(HloPredicateIsOp<HloOpcode::kExpm1>)
-                        .WithOperand(0, gte_pattern)),
-          m::Minimum())));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
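-// clamp(0, conv + bias, 6) is the Relu6 pattern; like Elu, it is only fused
-// when runtime fusion is enabled and the GPU is Ampere or newer.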
-TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) {
-  const std::string module_str = R"(
-    HloModule Test
-    ENTRY Test {
-      inputs = f16[1,18,9,9] parameter(0)
-      filters = f16[3,3,18,32] parameter(1)
-      bias = f16[32] parameter(2)
-      bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1}
-      zero = f16[] constant(0)
-      zeros = f16[1,32,9,9] broadcast(zero), dimensions={}
-      sixes = f16[1,32,9,9] broadcast(f16[] constant(6)), dimensions={}
-      conv = f16[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, bias_broadcast)
-      ROOT relu = clamp(zeros, sum, sixes)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  DebugOptions debug_opts = m->config().debug_options();
-  debug_opts.set_xla_gpu_use_runtime_fusion(true);
-  m->mutable_config().set_debug_options(debug_opts);
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  // relu6 fusion is only enabled on Ampere+.
-  CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-              0)
-              .WithShape(F16, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu6);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) {
-  const std::string module_str = R"(
-    HloModule Test
-    ENTRY Test {
-      inputs = f16[1,18,9,9] parameter(0)
-      filters = f16[3,3,18,32] parameter(1)
-      bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
-      zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={}
-      sixes = f16[1,32,9,9] broadcast(f16[] constant(6)), dimensions={}
-      conv = f16[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, bias)
-      relu = clamp(zeros, sum, sixes)
-      not_relu = minimum(sum, zeros)
-      ROOT root = tuple(relu, not_relu)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  DebugOptions debug_opts = m->config().debug_options();
-  debug_opts.set_xla_gpu_use_runtime_fusion(true);
-  m->mutable_config().set_debug_options(debug_opts);
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(0)),
-                   m::GetTupleElement(
-                       m::CustomCall(
-                           &conv, {kCudnnConvBiasActivationForwardCallTarget},
-                           m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-                       0)
-                       .WithShape(F16, {1, 32, 9, 9}),
-                   m::Broadcast(m::ConstantEffectiveScalar(6))),
-          m::Minimum())));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
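-// select(compare(sum, 0, GT), sum, sum * alpha) is the LeakyRelu pattern.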
-TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) {
-  const std::string module_str = R"(
-    HloModule Test
-    ENTRY Test {
-      inputs = f16[1,16,9,9] parameter(0)
-      filters = f16[3,3,16,32] parameter(1)
-      bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
-      zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={}
-      alphas = f16[1,32,9,9] broadcast(f16[] constant(0.2)), dimensions={}
-      conv = f16[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, bias)
-      cmp = compare(sum, zeros), direction=GT
-      mul = multiply(sum, alphas)
-      ROOT leaky_relu = select(cmp, sum, mul)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  DebugOptions debug_opts = m->config().debug_options();
-  debug_opts.set_xla_gpu_use_runtime_fusion(true);
-  m->mutable_config().set_debug_options(debug_opts);
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  // Leaky-relu fusion is only enabled on Ampere+.
-  CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-              0)
-              .WithShape(F16, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.activation_mode(), se::dnn::kLeakyRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) {
-  const std::string module_str = R"(
-    HloModule Test
-    ENTRY Test {
-      inputs = f16[1,16,9,9] parameter(0)
-      filters = f16[3,3,16,32] parameter(1)
-      bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
-      zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={}
-      alphas = f16[1,32,9,9] broadcast(f16[] constant(0.2)), dimensions={}
-      conv = f16[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, bias)
-      cmp = compare(sum, zeros), direction=GT
-      mul = multiply(sum, alphas)
-      leaky_relu = select(cmp, sum, mul)
-      not_leaky_relu = minimum(sum, zeros)
-      ROOT root = tuple(leaky_relu, not_leaky_relu)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  DebugOptions debug_opts = m->config().debug_options();
-  debug_opts.set_xla_gpu_use_runtime_fusion(true);
-  m->mutable_config().set_debug_options(debug_opts);
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  auto gte_pattern =
-      m::GetTupleElement(
-          m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                        m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-          0)
-          .WithShape(F16, {1, 32, 9, 9});
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::Select(m::Compare(gte_pattern,
-                               m::Broadcast(m::ConstantEffectiveScalar(0)))
-                        .WithComparisonDirection(ComparisonDirection::kGt)
-                        .WithOneUse(),
-                    gte_pattern,
-                    m::Multiply(gte_pattern,
-                                m::Broadcast(m::ConstantEffectiveScalar()))),
-          m::Minimum())));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
-      alpha = f32[1,32,9,9] broadcast(f32[] parameter(3)), dimensions={}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(multiply(alpha, conv), bias)
-      ROOT root = tuple(conv, sum)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv1;
-  const HloInstruction* conv2;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(m::CustomCall(&conv1), 0),
-          m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
-                         m::MultiplyAnyOrder(
-                             m::Broadcast(m::Parameter(3)),
-                             m::GetTupleElement(m::CustomCall(&conv2), 0))))));
-  EXPECT_EQ(conv1, conv2);
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv1->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.conv_result_scale(), 1);
-  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = tuple(conv, add(conv, bias))
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv1;
-  const HloInstruction* conv2;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(m::CustomCall(&conv1), 0),
-          m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
-                         m::GetTupleElement(m::CustomCall(&conv2), 0)))));
-  EXPECT_EQ(conv1, conv2);
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv1->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.conv_result_scale(), 1);
-  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
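-// A side input added after the relu cannot be folded into the custom-call,
-// since the fused op applies its activation after the side-input add; the
-// relu itself is still fused as the activation.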
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputThroughRelu) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      side_input = f32[1,32,9,9] parameter(2)
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      relu = maximum(conv, f32[1,32,9,9] broadcast(f32[] constant(0)))
-      ROOT root = add(relu, side_input)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::AddAnyOrder(
-          m::Parameter(2),
-          m::GetTupleElement(
-              m::CustomCall(&conv, m::Parameter(0), m::Parameter(1),
-                            m::Broadcast(m::ConstantEffectiveScalar(0))),
-              0))));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.conv_result_scale(), 1);
-  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasThroughRelu) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      relu = maximum(conv, f32[1,32,9,9] broadcast(f32[] constant(0)))
-      ROOT root = add(relu, bias)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::AddAnyOrder(
-                  m::Broadcast(m::Parameter(2)),
-                  m::GetTupleElement(m::CustomCall(
-                      &conv, m::Parameter(0), m::Parameter(1),
-                      m::Broadcast(m::ConstantEffectiveScalar(0)))))));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.conv_result_scale(), 1);
-  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      side_input = f32[1,32,9,9] parameter(2)
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = tuple(conv, add(conv, side_input))
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv1;
-  const HloInstruction* conv2;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(m::CustomCall(&conv1), 0),
-          m::AddAnyOrder(m::Parameter(2),
-                         m::GetTupleElement(m::CustomCall(&conv2), 0)))));
-  EXPECT_EQ(conv1, conv2);
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv1->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.conv_result_scale(), 1);
-  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] convert(f16[1,17,9,9] parameter(0))
-      filters = f32[3,3,17,32] convert(f16[3,3,17,32] parameter(1))
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = tuple(conv, f16[1,32,9,9] convert(conv))
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv1;
-  const HloInstruction* conv2;
-  ASSERT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(
-                  m::GetTupleElement(m::CustomCall(&conv1), 0),
-                  m::Convert(m::GetTupleElement(m::CustomCall(&conv2), 0)))));
-  EXPECT_EQ(conv1, conv2);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
-      filters = f32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      conv_s8 = s8[1,32,9,9] convert(clamp(
-                  f32[1,32,9,9] broadcast(f32[] constant(-128)),
-                  conv,
-                  f32[1,32,9,9] broadcast(f32[] constant(127))))
-      ROOT root = tuple(conv, conv_s8)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv1;
-  const HloInstruction* conv2;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(m::CustomCall(&conv1), 0),
-          m::Convert(m::Clamp(m::Op(),  //
-                              m::GetTupleElement(m::CustomCall(&conv2), 0),
-                              m::Op())))));
-  EXPECT_EQ(conv1, conv2);
-}
-
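-// The convert following the conv has a single user, so the rewriter can make
-// the custom-call produce f32 directly and drop the convert entirely.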
-TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) {
-  MAYBE_SKIP_TEST("I8");
-  const std::string_view module_str = R"(
-    HloModule Test
-
-    ENTRY test_entry {
-      inputs = s8[1, 17, 9, 9] parameter(0)
-      filters = s8[3, 3, 17, 32] parameter(1)
-      mult_op  = f32[1, 32, 9, 9] parameter(2)
-      conv = s32[1, 32, 9, 9] convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
-      ROOT ret = multiply(f32[1, 32, 9, 9] convert(conv), mult_op)
-    })";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-  SCOPED_TRACE(m->ToString());
-  HloInstruction* conv1 = nullptr;
-  // Checks that the convert between the conv and the multiply has been removed.
-  ASSERT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Multiply(m::GetTupleElement(m::CustomCall(&conv1)),
-                                     m::Parameter(2))));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) {
-  MAYBE_SKIP_TEST("I8");
-  const std::string_view module_str = R"(
-    HloModule Test
-
-    ENTRY test_entry {
-      inputs = s8[1, 17, 9, 9] parameter(0)
-      filters = s8[3, 3, 17, 32] parameter(1)
-      mult_op  = f32[1, 32, 9, 9] parameter(2)
-      conv = convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
-      ROOT ret = multiply(f32[1, 32, 9, 9] convert(conv), mult_op)
-    })";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-  SCOPED_TRACE(m->ToString());
-  HloInstruction* conv1 = nullptr;
-  // Checks that the convert between the conv and the multiply has been removed.
-  ASSERT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Multiply(m::GetTupleElement(m::CustomCall(&conv1)),
-                                     m::Parameter(2))));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingF32ToS8) {
-  MAYBE_SKIP_TEST("I8");
-  const std::string_view module_str = R"(
-    HloModule Test
-
-    ENTRY test_entry {
-      inputs = f32[1, 17, 9, 9] parameter(0)
-      filters = f32[3, 3, 17, 32] parameter(1)
-      mult_op  = s8[1, 32, 9, 9] parameter(2)
-      conv = convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
-      ROOT ret = multiply(s8[1, 32, 9, 9] convert(conv), mult_op)
-    })";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-  SCOPED_TRACE(m->ToString());
-  HloInstruction* conv1 = nullptr;
-  // Checks that the convert between the conv and the multiply has been removed.
-  ASSERT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Multiply(m::GetTupleElement(m::CustomCall(&conv1)),
-                                     m::Parameter(2))));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontRemoveConvertDueToMultipleUsers) {
-  const std::string_view module_str = R"(
-    HloModule Test
-
-    ENTRY test_entry {
-      inputs = f32[1, 17, 9, 9] parameter(0)
-      filters = f32[3, 3, 17, 32] parameter(1)
-      mult_op  = s8[1, 32, 9, 9] parameter(2)
-      sub_op = s8[1, 32, 9, 9] parameter(3)
-      conv = convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
-      another = subtract(s8[1, 32, 9, 9] convert(conv), sub_op)
-      ROOT ret = multiply(s8[1, 32, 9, 9] convert(conv), mult_op)
-    })";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-  SCOPED_TRACE(m->ToString());
-  HloInstruction* conv1 = nullptr;
-  // Checks that the convert is NOT removed, since the conv has multiple users.
-  ASSERT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Multiply(
-                  m::Convert(m::GetTupleElement(m::CustomCall(&conv1))),
-                  m::Parameter(2))));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[32] parameter(2)
-      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = add(conv, bias_broadcast)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
-                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-}
-
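-// There is no bias in this graph, but the bias-activation custom-call takes
-// a bias operand, so the rewriter supplies a broadcast zero bias.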
-TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      side_input = f32[1,32,9,9] parameter(2)
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = add(conv, side_input)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                            m::Parameter(0), m::Parameter(1),
-                            m::Broadcast(m::ConstantEffectiveScalar(0))
-                                .WithShape(F32, {32}),
-                            m::Parameter(2)),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      side_input = f32[1,32,9,9] parameter(2)
-      side_input_scale = f32[] constant(42)
-      side_input_scale_broadcast = f32[1,32,9,9] broadcast(side_input_scale), dimensions={}
-      side_input_product = multiply(side_input, side_input_scale_broadcast)
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = add(conv, side_input_product)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                            m::Parameter(0), m::Parameter(1),
-                            m::Broadcast(m::ConstantEffectiveScalar(0))
-                                .WithShape(F32, {32}),
-                            m::Parameter(2)),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.side_input_scale(), 42);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[32] parameter(2)
-      side_input = f32[1,32,9,9] parameter(3)
-      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, side_input)
-      ROOT sum2 = add(sum, bias_broadcast)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                            m::Parameter(0), m::Parameter(1), m::Parameter(2),
-                            m::Parameter(3)),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[1,32,9,9] broadcast(f32[] parameter(2)), dimensions={}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = add(conv, bias)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                            m::Parameter(0), m::Parameter(1),
-                            m::Broadcast(m::Parameter(2)).WithShape(F32, {32})),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-}
-
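-// The f16 operands are converted to f32 only for the convolution and the
-// result is converted straight back, so the whole computation should lower
-// to an f16 fused conv operating on the original parameters.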
-TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f16[1,17,9,9] parameter(0)
-      filters = f16[3,3,17,32] parameter(1)
-      bias = f16[32] parameter(2)
-      side_input = f16[1,32,9,9] parameter(3)
-
-      inputs_f32 = f32[1,17,9,9] convert(inputs)
-      filters_f32 = f32[3,3,17,32] convert(filters)
-      bias_f32 = f32[32] convert(bias)
-      bias_broadcast = f32[1,32,9,9] broadcast(bias_f32), dimensions={1}
-      side_input_f32 = f32[1,32,9,9] convert(side_input)
-      conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, side_input_f32)
-      sum2 = add(sum, bias_broadcast)
-      ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                            m::Parameter(0), m::Parameter(1), m::Parameter(2),
-                            m::Parameter(3)),
-              0)
-              .WithShape(F16, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-// We should be able to lower this to an f16 convolution even though the
-// f16-ness of the inputs is hidden behind broadcast/transpose/reshape.
-TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] reshape(f32[1377] convert(f16[1377] parameter(0)))
-      filters = f32[3,3,17,32] transpose(f32[17,32,3,3] convert(f16[17,32,3,3] parameter(1))), dimensions={2,3,0,1}
-      bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
-      side_input = f16[1,32,9,9] reshape(f16[2592] parameter(3))
-
-      conv_f32 = f32[1,32,9,9] convolution(inputs, filters),
-                 window={size=3x3 pad=1_1x1_1},
-                 dim_labels=bf01_01io->bf01
-      conv_f16 = f16[1,32,9,9] convert(conv_f32)
-      ROOT root = f16[1,32,9,9] add(add(conv_f16, side_input), bias)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall(
-                         &conv, {kCudnnConvBiasActivationForwardCallTarget},
-                         m::Convert(m::Reshape(m::Convert(m::Parameter(0))))
-                             .WithElementType(F16),
-                         m::Convert(m::Transpose(m::Convert(m::Parameter(1))))
-                             .WithElementType(F16),
-                         m::Parameter(2), m::Reshape(m::Parameter(3))),
-                     0)
-                     .WithShape(F16, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f16[1,17,9,9] parameter(0)
-      filters = f16[3,3,17,32] parameter(1)
-      bias = f32[32] parameter(2)
-      side_input = f16[1,32,9,9] parameter(3)
-
-      inputs_f32 = f32[1,17,9,9] convert(inputs)
-      filters_f32 = f32[3,3,17,32] convert(filters)
-      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
-      side_input_f32 = f32[1,32,9,9] convert(side_input)
-      conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, side_input_f32)
-      sum2 = add(sum, bias_broadcast)
-      ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  // fp16 convs only support fp16 biases.  Because bias is fp32, it doesn't get
-  // fused in, and we get an fp32 conv.
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Convert(m::GetTupleElement(
-                         m::CustomCall(
-                             &conv, {kCudnnConvBiasActivationForwardCallTarget},
-                             m::Convert(m::Parameter(0)).WithElementType(F32),
-                             m::Convert(m::Parameter(1)).WithElementType(F32),
-                             m::Parameter(2),
-                             m::Convert(m::Parameter(3)).WithElementType(F32)),
-                         0))
-              .WithShape(F16, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
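-// f32 constants that are losslessly representable in f16 should be converted
-// (and folded by constant folding) so the conv can still be lowered to f16;
-// the next test covers constants that are not losslessly convertible.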
-TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f16[1,2,2,2] parameter(0)
-      filters_f32 = f32[1,1,2,2] constant({{{{1, 2},{3, 4}}}})
-      bias = f16[2] parameter(1)
-      bias_f32 = f32[2] convert(bias)
-      side_input_f32 = f32[1,2,2,2] constant({{
-        {{0.5, 0.25}, {0.125, 0.0625}},
-        {{0.5, 0.25}, {0.125, 0.0625}}
-      }})
-
-      inputs_f32 = f32[1,2,2,2] convert(inputs)
-      bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
-      conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
-               window={size=1x1}, dim_labels=bf01_01io->bf01
-      sum = add(conv, side_input_f32)
-      sum2 = add(sum, bias_broadcast)
-      ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph, and fold
-  // convert back into constants.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-  HloConstantFolding constant_folding;
-  TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall(
-                         &conv, {kCudnnConvBiasActivationForwardCallTarget},
-                         m::Parameter(0), m::Constant().WithElementType(F16),
-                         m::Parameter(1), m::Constant().WithElementType(F16)),
-                     0)
-                     .WithShape(F16, {1, 2, 2, 2})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f16[1,2,2,2] parameter(0)
-      filters_f32 = f32[1,1,2,2] constant({{{{1, 2.123456789},{3, 4}}}})
-      bias = f16[2] parameter(1)
-      bias_f32 = f32[2] convert(bias)
-      side_input_f32 = f32[1,2,2,2] constant({{
-        {{0.1, 0.2}, {0.3, 0.4}},
-        {{0.5, 0.6}, {0.7, 0.8}}
-      }})
-
-      inputs_f32 = f32[1,2,2,2] convert(inputs)
-      bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
-      conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
-               window={size=1x1}, dim_labels=bf01_01io->bf01
-      sum = add(conv, side_input_f32)
-      sum2 = add(sum, bias_broadcast)
-      ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph, and fold
-  // convert back into constants.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-  HloConstantFolding constant_folding;
-  TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  // This doesn't get transformed into an f16 conv because the filters param is
-  // not losslessly expressible as f16.
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Convert(m::GetTupleElement(
-                         m::CustomCall(
-                             &conv, {kCudnnConvBiasActivationForwardCallTarget},
-                             m::Convert(m::Parameter(0)).WithElementType(F32),
-                             m::Constant().WithElementType(F32),
-                             m::Convert(m::Parameter(1)).WithElementType(F32),
-                             m::Constant().WithElementType(F32)),
-                         0)
-                         .WithShape(F32, {1, 2, 2, 2}))
-              .WithElementType(F16)));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) {
-  MAYBE_SKIP_TEST("I8");
-  const std::string module_str = R"(
-  HloModule Test
-
-  ENTRY Test {
-    input = s8[1,17,9,9] parameter(0)
-    filter = s8[3,3,17,32] parameter(1)
-    inputs32 = s32[1,17,9,9] convert(input)
-    filters32 = s32[3,3,17,32] convert(filter)
-
-    conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-
-    zero = s32[] constant(0)
-    zeros = s32[1,32,9,9] broadcast(zero), dimensions={}
-    relu = maximum(conv, zeros)
-
-    lower = s32[] constant(-128)
-    lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
-    upper = s32[] constant(127)
-    uppers = s32[1,32,9,9] broadcast(upper), dimensions={}
-
-    clamp = s32[1,32,9,9] clamp(lowers, relu, uppers)
-
-    ROOT convert = s8[1,32,9,9] convert(clamp)
-  })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                            m::Parameter(0),  //
-                            m::Parameter(1),  //
-                            m::Broadcast(m::ConstantEffectiveScalar(0))
-                                .WithShape(F32, {32})),
-              0)
-              .WithShape(S8, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          conv->backend_config<GpuBackendConfig>());
-  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
-  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) {
-  MAYBE_SKIP_TEST("F64");
-  const std::string module_str = R"(
-  HloModule Test
-
-  ENTRY Test {
-    input = f64[1,17,9,9] parameter(0)
-    filter = f64[3,3,17,32] parameter(1)
-    bias = f64[1,32,9,9] broadcast(f64[32] convert(f32[32] parameter(2))), dimensions={1}
-    conv = f64[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-    ROOT root = f64[1,32,9,9] add(conv, bias)
-  })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter{GetCudaComputeCapability()};
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
-                               GetToolkitVersion()};
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`s that may be added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
-                            m::Parameter(0),  //
-                            m::Parameter(1),  //
-                            m::Convert(m::Parameter(2)).WithShape(F64, {32})),
-              0)
-              .WithShape(F64, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8) {
-  MAYBE_SKIP_TEST("I8");
-  // clamp(max(0, conv(x, w)+bias)); for int8_t
-  TestClamp(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = f32[] constant(0)
-      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
-
-      input = s8[1,3,3,64] parameter(0)
-      filter = s8[3,3,64,64] parameter(1)
-      bias = f32[64] parameter(2)
-
-      inputs32 = s32[1,3,3,64] convert(input)
-      filters32 = s32[3,3,64,64] convert(filter)
-
-      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-
-      convfloat = f32[1,3,3,64] convert(conv)
-      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
-      add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
-      relu = f32[1,3,3,64] maximum(zeros, add1)
-
-      lower = f32[] constant(-128)
-      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
-      upper = f32[] constant(127)
-      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
-
-      clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
-
-      ROOT convert = s8[1,3,3,64] convert(clamp)
-    })",
-      // post_hlo
-      R"(
-// CHECK: [[cudnn_conv_bias_activation_7_0:%[^ ]+]] = (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
-      )");
-}
-
-// Disabled per b/190854862 or nvbugs/3326122.
-TEST_F(CudnnFusedConvRewriterTest, DISABLED_TestFusedConvInt8ToFloat) {
-  MAYBE_SKIP_TEST("I8");
-  // max(0, convert<float>(conv<int32_t>(int8_x, int8_w)) + float_bias);
-  // int8_t to float via bias.
-  TestClamp(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = f32[] constant(0)
-      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
-
-      input = s8[1,3,3,64] parameter(0)
-      filter = s8[3,3,64,64] parameter(1)
-      bias = f32[64] parameter(2)
-
-      inputs32 = s32[1,3,3,64] convert(input)
-      filters32 = s32[3,3,64,64] convert(filter)
-
-      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-
-      convfloat = f32[1,3,3,64] convert(conv)
-      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
-      add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
-      ROOT relu = f32[1,3,3,64] maximum(zeros, add1)
-    })",
-      // post_hlo
-      R"(
-      ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], bias: f32[64]) -> f32[1,3,3,64] {
-      ; CHECK:  [[custom_call_0:%[^ ]+]]{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]*}}]{0}) custom-call([[input_1:%[^ ]+]], [[copy_2:%[^ ]+]]{{(\.[0-9])?}}, [[bias_3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
-      ; CHECK-NEXT:  ROOT [[get_tuple_element_4:%[^ ]+]]{{(\.[0-9])?}} = f32[1,3,3,64]{3,2,1,0} get-tuple-element([[custom_call_0]]{{(\.[0-9])?}}), index=0
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest,
-       TestFusedConvWithScaledInt8SideInputBiasInt8ToInt8) {
-  MAYBE_SKIP_TEST("I8");
-  // clamp(max(0, alpha_conv * conv(x, w) + alpha_side *
-  // convert<int32_t>(int8_side_input) + bias)); for int8_t
-  TestClamp(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = f32[] constant(0)
-      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
-      alpha_conv_scalar = f32[] constant(0.999994934)
-      alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
-      alpha_side_input_scalar = f32[] constant(0.899994934)
-      alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
-      input = s8[1,3,3,64] parameter(0)
-      filter = s8[3,3,64,64] parameter(1)
-      side_input = s8[1,3,3,64] parameter(2)
-      bias = f32[64] parameter(3)
-
-      inputs32 = s32[1,3,3,64] convert(input)
-      filters32 = s32[3,3,64,64] convert(filter)
-      side_input_f32 = f32[1,3,3,64] convert(side_input)
-
-      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-
-      convfloat = f32[1,3,3,64] convert(conv)
-      scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
-      scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input)
-      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
-      add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
-      add2 = f32[1,3,3,64] add(add1, scaled_side_input)
-      relu = f32[1,3,3,64] maximum(zeros, add2)
-
-      lower = f32[] constant(-128)
-      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
-      upper = f32[] constant(127)
-      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
-
-      clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
-
-      ROOT convert = s8[1,3,3,64] convert(clamp)
-    })",
-      // post_hlo
-      R"(
-// CHECK: [[cudnn_conv_bias_activation_11_0:%[^ ]+]] = (s8[1,3,3,64]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]], [[side_input_4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest,
-       TestFusedConvWithScaledFloatSideInputBiasInt8ToInt8) {
-  MAYBE_SKIP_TEST("I8");
-  // From:
-  // convert<int8_t>(clamp(max(0, alpha_conv * conv(x, w) + alpha_side *
-  // float_side_input + bias)));
-  // To:
-  // convert<int8_t>(clamp(conv(int8_x, int8_w, float_alpha_side,
-  // float_side_input, float_bias)));
-  TestClamp(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = f32[] constant(0)
-      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
-      alpha_conv_scalar = f32[] constant(0.999994934)
-      alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
-      alpha_side_input_scalar = f32[] constant(0.899994934)
-      alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
-      input = s8[1,3,3,64] parameter(0)
-      filter = s8[3,3,64,64] parameter(1)
-      side_input = f32[1,3,3,64] parameter(2)
-      bias = f32[64] parameter(3)
-
-      inputs32 = s32[1,3,3,64] convert(input)
-      filters32 = s32[3,3,64,64] convert(filter)
-
-      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-
-      convfloat = f32[1,3,3,64] convert(conv)
-      scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
-      scaled_side_input = f32[1,3,3,64] multiply(side_input, alpha_side_input)
-      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
-      add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
-      add2 = f32[1,3,3,64] add(add1, scaled_side_input)
-      relu = f32[1,3,3,64] maximum(zeros, add2)
-
-      lower = f32[] constant(-128)
-      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
-      upper = f32[] constant(127)
-      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
-
-      clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
-
-      ROOT convert = s8[1,3,3,64] convert(clamp)
-    })",
-      //  post_hlo
-      R"(
-// CHECK: [[cudnn_conv_bias_activation_9_0:%[^ ]+]] = (f32[1,3,3,64]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]], [[side_input_4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest,
-       TestFusedConvWithScaledInt8SideInputBiasInt8ToFloat) {
-  MAYBE_SKIP_TEST("I8");
-  // From:
-  // clamp(max(0, alpha_conv * conv(x, w) + alpha_side *
-  // convert<float>(int8_side_input) + bias));
-  // To:
-  // clamp(conv(int8_x, int8_w, float_alpha_side,
-  // convert<float>(int8_side_input), float_bias));
-  TestClamp(
-      // pre_hlo
-      R"(
-    HloModule Test
-
-    ENTRY Test {
-      zero = f32[] constant(0)
-      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
-      alpha_conv_scalar = f32[] constant(0.999994934)
-      alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
-      alpha_side_input_scalar = f32[] constant(0.899994934)
-      alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
-
-      input = s8[1,3,3,64] parameter(0)
-      filter = s8[3,3,64,64] parameter(1)
-      side_input = s8[1,3,3,64] parameter(2)
-      bias = f32[64] parameter(3)
-
-      inputs32 = s32[1,3,3,64] convert(input)
-      filters32 = s32[3,3,64,64] convert(filter)
-      side_input_f32 = f32[1,3,3,64] convert(side_input)
-
-      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
-
-      convfloat = f32[1,3,3,64] convert(conv)
-      scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
-      scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input)
-      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
-      add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
-      add2 = f32[1,3,3,64] add(add1, scaled_side_input)
-      relu = f32[1,3,3,64] maximum(zeros, add2)
-
-      lower = f32[] constant(-128)
-      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
-      upper = f32[] constant(127)
-      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
-
-      ROOT clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
-    })",
-      // post_hlo
-      R"(
-// CHECK: [[cudnn_conv_bias_activation_9_0:%[^ ]+]] = (f32[1,3,3,64]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]], [[fusion_1_4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
-      )");
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8NoClamp) {
-  MAYBE_SKIP_TEST("I8");
-  // Check that integer convolution without clamp to int8_t is not allowed.
-  // convert<int8_t>(custom_call<int32_t>(int32_x, int32_w,
-  // cudnnConvolutionForward))
-  const std::string module_str = absl::StrFormat(R"(
-    HloModule Test
-
-    ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
-      zero = s8[] constant(0)
-      zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={}
-      input = s8[1,17,9,9]{3,2,1,0} parameter(0)
-      filter = s8[3,3,17,32]{3,2,1,0} parameter(1)
-      custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}"
-      get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0
-      convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element)
-      ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert)
-    })");
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  ASSERT_FALSE(CudnnFusedConvRewriter(GetCudaComputeCapability(),
-                                      GetDnnVersion(), GetToolkitVersion())
-                   .Run(m.get())
-                   .ok());
-}
-
-TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8NoClamp) {
-  MAYBE_SKIP_TEST("I8");
-  // Even though the bias and other terms are fused with the forward
-  // convolution, the rewrite is still not allowed if the output is not
-  // clamped/converted to int8_t:
-  // max(0, alpha_conv * conv(x, w) + alpha_side * side_input + bias); for
-  // int8_t
-
-  const std::string module_str = absl::StrFormat(R"(
-    HloModule Test
-
-    ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
-      zero = s8[] constant(0)
-      zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={}
-      input = s8[1,17,9,9]{3,2,1,0} parameter(0)
-      filter = s8[3,3,17,32]{3,2,1,0} parameter(1)
-      custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}"
-      get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0
-      convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element)
-      ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert)
-    })");
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  ASSERT_FALSE(CudnnFusedConvRewriter(GetCudaComputeCapability(),
-                                      GetDnnVersion(), GetToolkitVersion())
-                   .Run(m.get())
-                   .ok());
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc
deleted file mode 100644
index 23b238d..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc
+++ /dev/null
@@ -1,1781 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_fused_mha_rewriter.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <numeric>
-#include <optional>
-#include <queue>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/permutation_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/types.h"
-#include "xla/util.h"
-#include "xla/xla.pb.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
-#endif
-
-namespace xla {
-namespace gpu {
-namespace {
-namespace m = match;
-
-// A struct that contains all the matched nodes
-// and results from pattern matching the forward graph.
-struct MatchFwdResult {
-  HloInstruction* matched_bmm_1 = nullptr;
-  HloInstruction* matched_bmm_2 = nullptr;
-  HloInstruction* matched_bias = nullptr;
-  HloInstruction* matched_scale = nullptr;
-  HloInstruction* matched_softmax_input = nullptr;
-  HloInstruction* matched_reduce_sum = nullptr;
-
-  double matched_dropout_rate = 0.0;
-  bool need_canonicalization = false;
-  bool is_training = false;
-  // We use this to keep track of whether the bias applied to bmm1 is a causal
-  // mask; cuDNN can generate the causal mask inside the attention kernel to
-  // save I/O.
-  bool is_causal_mask = false;
-  bool has_match = false;
-  std::string matched_custom_call_name;
-};
-
-// A struct that contains all the matched nodes
-// and results from pattern matching the backward graph.
-struct MatchBwdResult {
-  HloInstruction* matched_bmm_1_grad_1 = nullptr;
-  HloInstruction* matched_bmm_1_grad_2 = nullptr;
-
-  HloInstruction* matched_bmm_2_grad_1 = nullptr;
-  HloInstruction* matched_bmm_2_grad_2 = nullptr;
-  HloInstruction* matched_dbias = nullptr;
-  // We use this to keep track of all gradient bmms that need
-  // canonicalization.
-  bool bmm_1_grad_1_need_canonicalization = false;
-  bool bmm_1_grad_2_need_canonicalization = false;
-  bool bmm_2_grad_1_need_canonicalization = false;
-  bool bmm_2_grad_2_need_canonicalization = false;
-
-  bool has_match = false;
-  std::string matched_custom_call_name;
-};
-
-template <typename Pattern>
-auto OptionalReshape(Pattern pattern) {
-  auto shared = m::SharedSubpattern(pattern);
-  return m::AnyOf<HloInstruction>(m::Reshape(shared), shared);
-}
-
-template <typename Pattern>
-auto OptionalConvert(Pattern pattern) {
-  auto shared = m::SharedSubpattern(pattern);
-  return m::AnyOf<HloInstruction>(m::Convert(shared), shared);
-}
-
-template <typename Pattern>
-auto OptionalBitcast(Pattern pattern) {
-  auto shared = m::SharedSubpattern(pattern);
-  return m::AnyOf<HloInstruction>(m::Bitcast(shared), shared);
-}
-
-template <typename Pattern>
-auto OptionalBroadcast(Pattern pattern) {
-  auto shared = m::SharedSubpattern(pattern);
-  return m::AnyOf<HloInstruction>(m::Broadcast(shared), shared);
-}
-
-bool IsBatchedMatmul(const HloInstruction* instr) {
-  if (instr->opcode() != HloOpcode::kDot) return false;
-  if (Cast<HloDotInstruction>(instr)->sparse_operands()) return false;
-  const DotDimensionNumbers& dot_dims = instr->dot_dimension_numbers();
-  bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() ||
-                      !dot_dims.rhs_batch_dimensions().empty();
-  return is_batch_dot;
-}
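-
-// For example (illustrative only): a dot whose DotDimensionNumbers carry
-// lhs_batch_dimensions={0, 1} counts as a batched matmul here, while a plain
-// rank-2 dot with empty batch dimensions on both sides does not.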
-
-// We need to check whether the current gemm shares an operand with a forward
-// fMHA call because, when we match backward gemms, the only way to be sure a
-// gemm is a backward gemm is to see whether it shares an operand with the
-// forward MHA call (i.e. the Q, K, V, or activation tensors). We can also use
-// this function to infer whether a gemm is a forward fMHA gemm. We check this
-// by doing a BFS over all operands to see whether any user is a forward fMHA
-// custom call. We continue the traversal through shape ops such as bitcast,
-// reshape and transpose until we either see a forward fMHA call or run out of
-// shape ops on the path, which means the current node will never share an
-// operand with a forward fMHA call.
-bool IsSharingOperandWithFwdMha(HloInstruction* gemm) {
-  for (int64_t i = 0; i < gemm->operands().size(); i++) {
-    std::queue<HloInstruction*> visit_list;
-    visit_list.push(gemm->mutable_operand(i));
-    while (!visit_list.empty()) {
-      HloInstruction* current_instr = visit_list.front();
-      for (auto user : current_instr->users()) {
-        switch (user->opcode()) {
-          case HloOpcode::kBitcast:
-          case HloOpcode::kReshape:
-          case HloOpcode::kTranspose: {
-            visit_list.push(user);
-            break;
-          }
-          case HloOpcode::kCustomCall: {
-            if (IsFwdCustomCallTofMHA(*user)) {
-              return true;
-            }
-          } break;
-          default:
-            break;
-        }
-      }
-      visit_list.pop();
-    }
-  }
-  return false;
-}
-// When we reach a gemm instruction, it could be one of the following 3 cases:
-//   1. one of the 2 gemms in a forward fmha call
-//   2. one of the 4 gemms in a backward fmha call
-//   3. a gemm of some other, unrelated layer
-// Case 3 can easily be ruled out by the pattern matcher.
-// However, cases 1 and 2 have very similar bmm-bmm structures.
-// We need to determine that we exactly match case 1 for forward gemms,
-// which have the following properties:
-//    - It is a batched matmul.
-//    - None of its operands is a forward fmha call, which would make it a
-//      backward gemm.
-//    - It does not directly or indirectly share an operand with any other
-//      fmha call, which would also make it a backward gemm.
-bool IsFirstFwdMatmul(HloInstruction* gemm) {
-  return IsBatchedMatmul(gemm) && !IsFwdCustomCallTofMHA(*gemm->operand(0)) &&
-         !IsFwdCustomCallTofMHA(*gemm->operand(1)) &&
-         !IsSharingOperandWithFwdMha(gemm);
-}
-
-bool IsScalar(const HloInstruction* instr) {
-  return ShapeUtil::IsEffectiveScalar(instr->shape());
-}
-
-bool IsReduceMax(const HloInstruction* instr) {
-  return instr->opcode() == HloOpcode::kReduce &&
-         instr->to_apply()->root_instruction()->opcode() == HloOpcode::kMaximum;
-}
-
-bool IsReduceSum(const HloInstruction* instr) {
-  return instr->opcode() == HloOpcode::kReduce &&
-         instr->to_apply()->root_instruction()->opcode() == HloOpcode::kAdd;
-}
-
-// Set up subpatterns for re-use.
-// Matches the softmax sub-pattern ->
-// divide(exp(Subtract(producer, reduce_max(producer))),
-// broadcast(reduce_add(exp(Subtract(...))))). There might be reshape and
-// convert nodes between reduce and Subtract.
-// TODO TJ: Make this more general to any pattern that has this structure once
-// the cudnn runner supports generic cudnnOpGraphs.
-// producer
-// |   \
-// |  reduce
-// |     |
-// |  broadcast
-// |   /
-// root
-auto GetUnfusedReduceMaxSumSoftmaxPattern(
-    HloInstruction** softmax_input = nullptr,
-    HloInstruction** softmax_reduce_sum = nullptr,
-    HloInstruction** softmax_reduce_sum_bcast = nullptr) {
-  // The reduce-max part of the softmax
-  // reduce_max and subtract will always have exactly 1 user
-  // in both training and inference
-  // softmax_input should always have exactly 2 users
-  auto unfused_softmax_max_subpattern = m::SharedSubpattern(
-      m::Subtract(
-          m::Op(),
-          m::Broadcast(OptionalConvert(
-              m::Op()
-                  .WithPredicate(IsReduceMax)
-                  .WithOneUse()
-                  .WithOperand(0, OptionalBitcast(OptionalConvert(
-                                      m::Op(softmax_input).WithNumUser(2)))))))
-          .WithOneUse());
-  // The reduce-add part of the softmax
-  // reduce_sum and reduce_sum_broadcast should have 2 users in training
-  // and 1 user in inference
-  auto unfused_softmax_sum_subpattern = m::SharedSubpattern(m::Divide(
-      OptionalBitcast(m::Exp(unfused_softmax_max_subpattern)),
-      m::Broadcast(
-          softmax_reduce_sum_bcast,
-          OptionalConvert(
-              m::Op(softmax_reduce_sum)
-                  .WithOperand(0, OptionalBitcast(OptionalConvert(
-                                      m::Exp(unfused_softmax_max_subpattern))))
-                  .WithPredicate(IsReduceSum)
-                  .WithAtMostNumUser(2)))
-          .WithAtMostNumUser(2)));
-  return unfused_softmax_sum_subpattern;
-}
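-
-// Illustrative sketch (hand-written, not taken from a real module) of the
-// unfused softmax subgraph this pattern is intended to match:
-//   max  = reduce(producer), to_apply=maximum
-//   sub  = subtract(producer, broadcast(max))
-//   exp  = exponential(sub)
-//   sum  = reduce(exp), to_apply=add
-//   root = divide(exp, broadcast(sum))
-// possibly with reshape/convert/bitcast nodes interleaved, as noted above.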
-
-std::optional<double> GetConstantValue(const HloInstruction* inst) {
-  if (!IsScalar(inst)) {
-    return std::nullopt;
-  }
-  switch (inst->shape().element_type()) {
-    case F16:
-      return static_cast<float>(inst->literal().GetFirstElement<half>());
-    case BF16:
-      return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
-    case F32:
-      return inst->literal().GetFirstElement<float>();
-    case F64:
-      return inst->literal().GetFirstElement<double>();
-    default:
-      return std::nullopt;
-  }
-}
-
-double GetDropoutRateFromHlo(HloInstruction* dropout) {
-  std::optional<double> dropout_rate_inv;
-  dropout_rate_inv = GetConstantValue(dropout);
-  if (!dropout_rate_inv.has_value()) {
-    return 0.0;
-  }
-  // In dropout, inputs are divided by (1 - rate), so we take the reciprocal
-  // of the constant in the dropout node and subtract it from 1 here to get
-  // the actual dropout rate.
-  return (1.0 - (1.0 / *dropout_rate_inv));
-}
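-
-// Worked example (illustrative numbers only): with a dropout rate of 0.2 the
-// HLO divides by (1 - 0.2), so the matched constant is 1 / 0.8 = 1.25, and
-// GetDropoutRateFromHlo recovers 1.0 - (1.0 / 1.25) = 0.2.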
-
-bool IsComputeCapabilityAndCudnnSupported(
-    stream_executor::CudaComputeCapability cc,
-    stream_executor::dnn::VersionInfo cudnn_version,
-    stream_executor::dnn::VersionInfo supported_cudnn_version) {
-  // Enforce capability minor == 0 because hardware with a non-zero minor
-  // number typically has insufficient shared memory for cuDNN FMHA.
-  if (cc.IsAtLeastAmpere() && cc.minor == 0 &&
-      cudnn_version >= supported_cudnn_version) {
-    return true;
-  }
-  VLOG(2) << absl::StrFormat(
-      "CudnnFusedMHARewriter did not run. Unsupported compute "
-      "capability(%s; major should be >= 8, minor should be 0) or cudnn version"
-      "(%s; should be >= %s)",
-      cc.ToString(), cudnn_version.ToString(),
-      supported_cudnn_version.ToString());
-  return false;
-}
-
-bool IsSupportedPrimitiveType(const HloInstruction* bmm) {
-  PrimitiveType dtype = bmm->shape().element_type();
-  return dtype == BF16 || dtype == F16;
-}
-
-std::vector<int64_t> GetDimensionVector(absl::Span<const int64_t> dimensions,
-                                        absl::Span<const int64_t> dim_nums) {
-  std::vector<int64_t> vec(dim_nums.size());
-  for (int i = 0; i < dim_nums.size(); i++) {
-    vec[i] = dimensions.at(dim_nums.at(i));
-  }
-  return vec;
-}
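-
-// For example (illustrative values): GetDimensionVector({2, 12, 512, 64},
-// {0, 1}) returns {2, 12}; it simply gathers the sizes of the requested
-// dimension numbers.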
-
-struct QKVLayout {
-  int64_t batch;
-  int64_t num_heads;
-  int64_t seqlen_q;
-  int64_t seqlen_kv;
-  int64_t hidden_dim;
-};
-
-absl::StatusOr<std::optional<QKVLayout>> GetQKVLayout(
-    HloInstruction* bmm_1, HloInstruction* bmm_2, bool need_canonicalization) {
-  // get layout from bmm1
-  const DotDimensionNumbers& bmm1_dnums = bmm_1->dot_dimension_numbers();
-  TF_ASSIGN_OR_RETURN(
-      std::vector<int64_t> bmm1_s_q_dims,
-      GetNonContractingDims(bmm_1->operand(0)->shape(),
-                            bmm1_dnums.lhs_batch_dimensions(),
-                            bmm1_dnums.lhs_contracting_dimensions()));
-
-  TF_ASSIGN_OR_RETURN(
-      std::vector<int64_t> bmm1_s_kv_dims,
-      GetNonContractingDims(bmm_1->operand(1)->shape(),
-                            bmm1_dnums.rhs_batch_dimensions(),
-                            bmm1_dnums.rhs_contracting_dimensions()));
-
-  std::vector<int64_t> bmm1_bh =
-      GetDimensionVector(bmm_1->operand(0)->shape().dimensions(),
-                         bmm1_dnums.lhs_batch_dimensions());
-
-  std::vector<int64_t> bmm1_s_q = GetDimensionVector(
-      bmm_1->operand(0)->shape().dimensions(), bmm1_s_q_dims);
-
-  std::vector<int64_t> bmm1_s_kv = GetDimensionVector(
-      bmm_1->operand(1)->shape().dimensions(), bmm1_s_kv_dims);
-
-  std::vector<int64_t> bmm1_d =
-      GetDimensionVector(bmm_1->operand(0)->shape().dimensions(),
-                         bmm1_dnums.lhs_contracting_dimensions());
-
-  TF_RET_CHECK(bmm1_bh.size() == 2);
-  TF_RET_CHECK(bmm1_s_q.size() == 1);
-  TF_RET_CHECK(bmm1_s_kv.size() == 1);
-  TF_RET_CHECK(bmm1_d.size() == 1);
-
-  // get layout from bmm2
-  const DotDimensionNumbers& bmm2_dnums = bmm_2->dot_dimension_numbers();
-  TF_ASSIGN_OR_RETURN(
-      std::vector<int64_t> bmm2_lhs_non_contracting_dims,
-      GetNonContractingDims(bmm_2->operand(0)->shape(),
-                            bmm2_dnums.lhs_batch_dimensions(),
-                            bmm2_dnums.lhs_contracting_dimensions()));
-
-  TF_ASSIGN_OR_RETURN(
-      std::vector<int64_t> bmm2_rhs_non_contracting_dims,
-      GetNonContractingDims(bmm_2->operand(1)->shape(),
-                            bmm2_dnums.rhs_batch_dimensions(),
-                            bmm2_dnums.rhs_contracting_dimensions()));
-
-  std::vector<int64_t> bmm2_bh =
-      GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
-                         bmm2_dnums.lhs_batch_dimensions());
-
-  std::vector<int64_t> bmm2_s_kv =
-      GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
-                         bmm2_dnums.lhs_contracting_dimensions());
-
-  std::vector<int64_t> bmm2_s_q =
-      need_canonicalization
-          ? GetDimensionVector(bmm_2->operand(1)->shape().dimensions(),
-                               bmm2_rhs_non_contracting_dims)
-          : GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
-                               bmm2_lhs_non_contracting_dims);
-
-  std::vector<int64_t> bmm2_d =
-      need_canonicalization
-          ? GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
-                               bmm2_lhs_non_contracting_dims)
-          : GetDimensionVector(bmm_2->operand(1)->shape().dimensions(),
-                               bmm2_rhs_non_contracting_dims);
-
-  TF_RET_CHECK(bmm2_bh.size() == 2);
-  TF_RET_CHECK(bmm2_s_q.size() == 1);
-  TF_RET_CHECK(bmm2_s_kv.size() == 1);
-  TF_RET_CHECK(bmm2_d.size() == 1);
-
-  // check if bhsd is correct between bmm1 and bmm2
-  if (bmm1_bh[0] != bmm2_bh[0] || bmm1_bh[1] != bmm2_bh[1] ||
-      bmm1_s_q[0] != bmm2_s_q[0] || bmm1_s_kv[0] != bmm2_s_kv[0] ||
-      bmm1_d[0] != bmm2_d[0]) {
-    return std::nullopt;
-  }
-
-  QKVLayout qkv_layout;
-  qkv_layout.batch = bmm1_bh[0];
-  qkv_layout.num_heads = bmm1_bh[1];
-  qkv_layout.seqlen_q = bmm1_s_q[0];
-  qkv_layout.seqlen_kv = bmm1_s_kv[0];
-  qkv_layout.hidden_dim = bmm1_d[0];
-  return qkv_layout;
-}
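-
-// Illustrative example (hypothetical shapes, not from a real module): with
-// Q = [2,12,512,64], K = [2,12,1024,64], V = [2,12,1024,64], bmm1 = dot(Q, K)
-// contracting over the hidden dim and bmm2 = dot(softmax_out, V) contracting
-// over the kv sequence dim (no canonicalization needed), the extracted layout
-// is {batch=2, num_heads=12, seqlen_q=512, seqlen_kv=1024, hidden_dim=64}.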
-
-absl::StatusOr<bool> IsFlashAttention(
-    QKVLayout qkv_layout, bool is_training,
-    stream_executor::CudaComputeCapability cc,
-    stream_executor::dnn::VersionInfo cudnn_version) {
-  int64_t s_q = qkv_layout.seqlen_q;
-  int64_t s_kv = qkv_layout.seqlen_kv;
-  int64_t hidden_dim = qkv_layout.hidden_dim;
-  // start with most relaxed constraint
-  bool is_seqlen_supported = (!is_training || (s_q % 2 == 0 && s_kv % 2 == 0));
-  bool is_hidden_dim_supported = hidden_dim <= 128 && hidden_dim % 8 == 0;
-  bool is_flash_attention = is_seqlen_supported && is_hidden_dim_supported;
-  if (!is_flash_attention) return false;
-
-  // going backwards to check compatibility
-  if ((is_training && (s_q < 64 || s_kv < 64)) &&
-      !IsComputeCapabilityAndCudnnSupported(
-          cc, cudnn_version, stream_executor::dnn::VersionInfo(9, 0, 0))) {
-    VLOG(2) << "Flash attention training with seq < 64 not supported cuDNN < "
-               "9.0.0.";
-    return false;
-  }
-
-  if ((hidden_dim != 64 && hidden_dim != 128) &&
-      !IsComputeCapabilityAndCudnnSupported(
-          cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 6))) {
-    VLOG(2) << "Flash attention head dim != 64 or 128 not supported with cuDNN "
-               "< 8.9.6.";
-    return false;
-  }
-
-  if ((is_training && s_kv % 64 != 0) &&
-      !IsComputeCapabilityAndCudnnSupported(
-          cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 5))) {
-    VLOG(2) << "Flash attention training with seq kv % 64 != 0 not supported "
-               "with cuDNN < 8.9.5.";
-    return false;
-  }
-
-  if (!IsComputeCapabilityAndCudnnSupported(
-          cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 4))) {
-    VLOG(2) << "Require cuDNN 8.9.4 to run flash attention.";
-    return false;
-  }
-  return is_flash_attention;
-}
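-
-// Worked example (illustrative numbers only): a training graph with
-// seqlen_q = seqlen_kv = 2048 and hidden_dim = 64 passes the relaxed checks
-// (even sequence lengths, hidden_dim <= 128 and divisible by 8), and since
-// hidden_dim is 64 and seqlen_kv % 64 == 0 it only needs the baseline
-// cuDNN 8.9.4 check above; a hidden_dim of 72 would additionally require
-// cuDNN >= 8.9.6.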
-
-bool IsCausalMaskPattern(HloInstruction* mask) {
-  auto causal_mask =
-      m::Select(m::Compare(m::Iota(), m::Iota()), m::Broadcast(m::Constant()),
-                m::Broadcast(m::Constant()));
-  auto causal_mask_pattern_fwd_remat =
-      m::Broadcast(OptionalBitcast(causal_mask));
-  auto causal_mask_pattern_bwd = m::Broadcast(m::Convert(OptionalBitcast(
-      m::Minimum(m::Op(), m::Broadcast(OptionalBitcast(causal_mask))))));
-  HloInstruction* param = nullptr;
-  HloInstruction* gte = nullptr;
-  auto causal_mask_pattern_fwd = m::Broadcast(
-      OptionalBitcast(m::GetTupleElement(&gte, m::Parameter(&param))));
-  auto causal_mask_pattern = m::AnyOf<HloInstruction>(
-      causal_mask_pattern_fwd_remat, causal_mask_pattern_fwd,
-      causal_mask_pattern_bwd);
-  if (Match(mask, causal_mask_pattern)) {
-    if (param != nullptr && param->parent()->IsWhileBodyComputation()) {
-      // We need to trace outside of the while-loop body to find the real mask.
-      auto while_instr = param->parent()->WhileCallInstruction();
-      auto mask_index = gte->tuple_index();
-      auto actual_mask =
-          while_instr->mutable_operand(0)->mutable_operand(mask_index);
-      auto causal_mask_pattern_fwd =
-          OptionalBitcast(m::Convert(m::MinimumAnyOrder(
-              m::Op(),
-              OptionalBitcast(m::MinimumAnyOrder(
-                  m::Op(), m::Broadcast(OptionalBitcast(causal_mask)))))));
-      return Match(actual_mask, causal_mask_pattern_fwd);
-    }
-    return true;
-  }
-  return false;
-}
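-
-// Illustrative shape of the `causal_mask` subpattern above (hypothetical; the
-// comparison direction and constant values depend on the producing
-// framework):
-//   pred = compare(row_iota, col_iota), direction=GE
-//   mask = select(pred, broadcast(zero), broadcast(large_negative))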
-
-MatchFwdResult MatchSoftmaxDropoutBmm(MatchFwdResult previous_result,
-                                      int64_t bmm2_operand_position,
-                                      HloInstruction* instr) {
-  // Matches the dropout-softmax subpattern.
-  // The softmax output is a divide.
-  // Dropout can take multiple forms; we capture three forms here based on
-  // heuristics.
-  // Form 1 -> softmax - mul - select(dropout) - BMM2
-  MatchFwdResult match_result = previous_result;
-  HloInstruction* softmax_reduce_sum;
-  HloInstruction* softmax_reduce_sum_bcast;
-  HloInstruction* bmm_2;
-  HloInstruction* softmax_input;
-  HloInstruction* dropout = nullptr;
-  auto dropout_softmax_pattern_form_1 = m::Select(
-      m::Op(),
-      OptionalConvert(m::MultiplyAnyOrder(
-          OptionalBitcast(OptionalReshape(
-              OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern(
-                  &softmax_input, &softmax_reduce_sum,
-                  &softmax_reduce_sum_bcast)))),
-          m::Broadcast(
-              OptionalConvert(m::Constant(&dropout).WithPredicate(IsScalar))))),
-      m::Op());
-
-  // Form 2 -> softmax - mul - BMM2
-  //                     /
-  //                    /
-  //                 select(dropout)
-  auto dropout_softmax_pattern_form_2 =
-      OptionalBitcast(OptionalBitcast(OptionalConvert(m::MultiplyAnyOrder(
-          OptionalReshape(OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern(
-              &softmax_input, &softmax_reduce_sum, &softmax_reduce_sum_bcast))),
-          m::Broadcast(
-              OptionalConvert(OptionalBitcast(OptionalReshape(m::Select(
-                  m::Op(),
-                  m::Broadcast(m::Constant(&dropout).WithPredicate(IsScalar)),
-                  m::Op())))))))));
-
-  // Form3 -> softmax - mul(dropout) - mul(scale) - BMM2
-  auto dropout_softmax_pattern_form_3 = m::MultiplyAnyOrder(
-      m::MultiplyAnyOrder(
-          OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern(
-              &softmax_input, &softmax_reduce_sum, &softmax_reduce_sum_bcast)),
-          m::Op()),
-      m::Broadcast(m::Constant(&dropout).WithPredicate(IsScalar)));
-
-  // Try matching BMM1 - (Scale) - (Bias) - Softmax - (Dropout) - BMM2.
-  // Dropout with a non-zero drop rate has
-  // select(divide(softmax_output, broadcast(1-dropout_rate))).
-  auto softmax_dropout_bmm2_pattern =
-      m::Op(&bmm_2)
-          .WithPredicate(IsBatchedMatmul)
-          .WithOperand(bmm2_operand_position,
-                       m::AnyOf<HloInstruction>(
-                           OptionalBitcast(OptionalConvert(
-                               GetUnfusedReduceMaxSumSoftmaxPattern(
-                                   &softmax_input, &softmax_reduce_sum,
-                                   &softmax_reduce_sum_bcast))),
-                           dropout_softmax_pattern_form_1,
-                           dropout_softmax_pattern_form_2,
-                           dropout_softmax_pattern_form_3));
-
-  if (!Match(instr, softmax_dropout_bmm2_pattern) ||
-      !IsSupportedPrimitiveType(bmm_2)) {
-    match_result.has_match = false;
-    return match_result;
-  }
-  if (softmax_reduce_sum->users()[0]->opcode() == HloOpcode::kConvert) {
-    softmax_reduce_sum = softmax_reduce_sum->users()[0];
-  }
-  match_result.is_training = softmax_reduce_sum->user_count() == 2 &&
-                             softmax_reduce_sum_bcast->user_count() == 2;
-  match_result.matched_bmm_2 = bmm_2;
-  if (dropout) {
-    match_result.matched_dropout_rate = GetDropoutRateFromHlo(dropout);
-  }
-  match_result.matched_softmax_input = softmax_input;
-  match_result.matched_reduce_sum = softmax_reduce_sum;
-  match_result.has_match = true;
-  return match_result;
-}
-
-MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result,
-                                               HloInstruction* softmax_input,
-                                               bool has_dropout) {
-  MatchFwdResult match_result = previous_result;
-  HloInstruction* bmm_1;
-  HloInstruction* bias = nullptr;
-  HloInstruction* scale = nullptr;
-  // The bmm1/scale/bias add should have 2 users if it is connected to
-  // softmax; otherwise it should have exactly 1 user.
-  auto first_bmm_pattern =
-      m::SharedSubpattern(m::Op(&bmm_1).WithPredicate(IsBatchedMatmul));
-  auto unfused_scaled_bmm_subpattern = m::MultiplyAnyOrder(
-      OptionalConvert(first_bmm_pattern.WithOneUse()),
-      OptionalConvert(
-          m::Broadcast(m::Constant(&scale).WithPredicate(IsScalar))));
-  if (Match(softmax_input,
-            OptionalConvert(OptionalBitcast(m::AnyOf<HloInstruction>(
-                first_bmm_pattern, unfused_scaled_bmm_subpattern))))) {
-    // bmm1 - (scale) - softmax
-    match_result.matched_bmm_1 = bmm_1;
-    match_result.matched_scale = scale;
-    match_result.matched_custom_call_name =
-        has_dropout ? kCudnnfMHASoftmaxDropoutCallTarget
-                    : kCudnnfMHASoftmaxCallTarget;
-    match_result.has_match = true;
-  } else if (Match(softmax_input,
-                   OptionalBitcast(m::AddAnyOrder(
-                       OptionalConvert(OptionalBitcast(m::AnyOf<HloInstruction>(
-                           unfused_scaled_bmm_subpattern.WithOneUse(),
-                           first_bmm_pattern.WithOneUse()))),
-                       m::Op(&bias))))) {
-    // bmm1 - (scale) - bias - softmax
-    match_result.matched_bmm_1 = bmm_1;
-    match_result.matched_scale = scale;
-    match_result.matched_custom_call_name =
-        has_dropout ? kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget
-                    : kCudnnfMHAScaleBiasSoftmaxCallTarget;
-    match_result.is_causal_mask |= IsCausalMaskPattern(bias);
-    if (!match_result.is_causal_mask &&
-        bias->opcode() == HloOpcode::kBroadcast) {
-      // we can take the bias before broadcast
-      auto dims = Cast<HloBroadcastInstruction>(bias)->dimensions();
-      if (dims == std::vector<int64_t>{2, 3} ||
-          dims == std::vector<int64_t>{0, 2, 3} ||
-          dims == std::vector<int64_t>{1, 2, 3}) {
-        // shapes [1, 1, s, s], [b, 1, s, s], [1, h, s, s] are supported
-        HloInstruction* bias_bc = bias->mutable_operand(0);
-        // bitcast bias_before_broadcast to be 4D
-        std::vector<int64_t> bitcast_dims(bias->shape().rank(), 1);
-        for (int dim : dims) {
-          bitcast_dims[dim] = bias->shape().dimensions()[dim];
-        }
-        bias = bias_bc->AddInstruction(HloInstruction::CreateBitcast(
-            ShapeUtil::MakeShape(bias->shape().element_type(), bitcast_dims),
-            bias_bc));
-      }
-    }
-    match_result.matched_bias = bias;
-    match_result.has_match = true;
-  } else {
-    match_result.has_match = false;
-  }
-  return match_result;
-}
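-
-// Illustrative example of the bias handling above (hypothetical shapes): a
-// bias produced by broadcast(f16[512,512] b), dimensions={2,3} into
-// f16[2,12,512,512] matches the supported dims {2, 3}, so the pre-broadcast
-// operand is bitcast to the 4D shape f16[1,1,512,512] before being recorded
-// as matched_bias.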
-
-// We will try to match all the patterns below:
-// BMM1 - Scale - bias - Softmax - Dropout - BMM2
-// BMM1 - Scale - bias - Softmax - BMM2
-// BMM1 - Softmax - Dropout - BMM2
-// BMM1 - Softmax - BMM2
-MatchFwdResult MatchFwdMHAPatternsForCanonicalization(HloInstruction* instr) {
-  // We need to match 2 general cases:
-  // 1. bmm1 --> (intermediate nodes) --> bmm2 <-- V matrix
-  // 2. V matrix --> bmm2 <-- (intermediate nodes) <-- bmm1
-  // to determine if we need to canonicalize bmm2.
-  // So we go through both of bmm2's operands and see which one matches our
-  // desired patterns; if operand 1 consumes them, then we need to canonicalize.
-  MatchFwdResult match_result;
-  for (auto bmm2_operand_pos : {0, 1}) {
-    if (bmm2_operand_pos == 1) {
-      match_result.need_canonicalization = true;
-    }
-
-    bool has_dropout = false;
-    // We first check if bmm2 is connected to a softmax or dropout.
-    // If so, we set softmax input and dropout rate to their corresponding
-    // values.
-    match_result =
-        MatchSoftmaxDropoutBmm(match_result, bmm2_operand_pos, instr);
-    if (!match_result.has_match) {
-      continue;
-    }
-    has_dropout = match_result.matched_dropout_rate > 0.0;
-    match_result = MatchBmm1UnfusedBiasSoftmaxBmm2(
-        match_result, match_result.matched_softmax_input, has_dropout);
-    if (match_result.has_match) {
-      return match_result;
-    }
-  }
-  // Didn't find any match
-  match_result.need_canonicalization = false;
-  return match_result;
-}
-
-bool IsBmm2GradGemm2(HloInstruction* instr) {
-  // Check to see if the input bmm is bmm2's gradient gemm2; it needs to have
-  // either:
-  //   1. exactly 1 user, in the dropout case, or
-  //   2. exactly 2 users, in all other cases.
-  return (instr->user_count() == 1) || (instr->user_count() == 2);
-}
-
-MatchBwdResult MatchBmm1GradGemm1(MatchBwdResult previous_result,
-                                  HloInstruction* bmm_1) {
-  MatchBwdResult match_result = previous_result;
-  match_result.has_match = false;
-  const HloInstruction* q_tensor = bmm_1->operand(0);
-  for (int64_t i = 0; i < q_tensor->user_count(); i++) {
-    HloInstruction* q_tensor_user_i = q_tensor->users()[i];
-    if (IsBatchedMatmul(q_tensor_user_i) && q_tensor_user_i != bmm_1) {
-      match_result.matched_bmm_1_grad_1 = q_tensor_user_i;
-      // Check for canonicalization.
-      if (match_result.matched_bmm_1_grad_1->operand_index(q_tensor) != 1) {
-        match_result.bmm_1_grad_1_need_canonicalization = true;
-      }
-      match_result.has_match = true;
-    }
-  }
-  return match_result;
-}
-
-MatchBwdResult MatchBmm1GradGemm2(MatchBwdResult previous_result,
-                                  HloInstruction* fwd_fmha_call) {
-  HloInstruction* bmm_1_grad_2 = nullptr;
-  MatchBwdResult match_result = previous_result;
-  match_result.has_match = false;
-  // bmm1 gradient gemm2 shares the same input d_s as bmm1 gradient gemm1.
-  // Check to see whether bmm1 grad gemm1 needs canonicalization; if not,
-  // then the shared input is its first operand.
-  int64_t d_s_index = match_result.bmm_1_grad_1_need_canonicalization ? 1 : 0;
-  HloInstruction* d_s_user_0 = match_result.matched_bmm_1_grad_1;
-
-  HloInstruction* d_s = d_s_user_0->mutable_operand(d_s_index);
-  if (d_s->opcode() == HloOpcode::kBitcast && d_s->user_count() == 1) {
-    d_s = d_s->mutable_operand(0);
-  }
-
-  auto bmm_1_grad_2_it = std::find_if(
-      d_s->users().begin(), d_s->users().end(), [&](HloInstruction* instr) {
-        return instr != match_result.matched_bmm_1_grad_1 &&
-               instr->opcode() == HloOpcode::kDot;
-      });
-  if (bmm_1_grad_2_it != d_s->users().end()) {
-    bmm_1_grad_2 = *bmm_1_grad_2_it;
-  } else {
-    return match_result;
-  }
-
-  match_result.matched_bmm_1_grad_2 = bmm_1_grad_2;
-
-  if (match_result.matched_bmm_1_grad_2->operand_index(d_s) != 0) {
-    match_result.bmm_1_grad_2_need_canonicalization = true;
-  }
-  match_result.has_match = true;
-  return match_result;
-}
-
-MatchBwdResult MatchBmm2GradGemm1(HloInstruction* fwd_fmha_call) {
-  HloInstruction* bmm_2_grad_1 = nullptr;
-  MatchBwdResult matched_result;
-  // The second GTE of the forward MHA call is the input of bmm2's gradient
-  // gemm 1; we check to see if the current gemm satisfies this condition.
-  int64_t activation_out_gte_index = 1;
-  if (fwd_fmha_call->user_count() < 2 ||
-      fwd_fmha_call->users()[activation_out_gte_index]->opcode() !=
-          HloOpcode::kGetTupleElement ||
-      fwd_fmha_call->users()[activation_out_gte_index]->user_count() > 1 ||
-      !IsBatchedMatmul(
-          fwd_fmha_call->users()[activation_out_gte_index]->users()[0])) {
-    matched_result.has_match = false;
-    return matched_result;
-  }
-  // Found fmha->GTE->gemm, assign it to bmm_2_grad_1 and check to see if it
-  // needs canonicalization.
-  bmm_2_grad_1 = fwd_fmha_call->users()[activation_out_gte_index]->users()[0];
-  matched_result.matched_bmm_2_grad_1 = bmm_2_grad_1;
-  if (bmm_2_grad_1->operand_index(
-          fwd_fmha_call->users()[activation_out_gte_index]) != 0) {
-    matched_result.bmm_2_grad_1_need_canonicalization = true;
-  }
-
-  matched_result.has_match = true;
-  return matched_result;
-}
-
-MatchBwdResult MatchBmm2GradGemm2(MatchBwdResult previous_result,
-                                  HloInstruction* fwd_fmha_call,
-                                  bool v_transposed) {
-  MatchBwdResult match_result = previous_result;
-  match_result.has_match = false;
-  // If the v tensor is transposed by the forward fmha call, then we need to
-  // take the producer of the fmha call's v operand (i.e. skip the transpose).
-  const HloInstruction* v_tensor = v_transposed
-                                       ? fwd_fmha_call->operand(2)->operand(0)
-                                       : fwd_fmha_call->operand(2);
-  for (int64_t i = 0; i < v_tensor->user_count(); i++) {
-    HloInstruction* v_tensor_user_i = v_tensor->users()[i];
-    if (IsBatchedMatmul(v_tensor_user_i) && IsBmm2GradGemm2(v_tensor_user_i)) {
-      match_result.matched_bmm_2_grad_2 = v_tensor_user_i;
-      // Check for canonicalization.
-      if (match_result.matched_bmm_2_grad_2->operand_index(v_tensor) != 1) {
-        match_result.bmm_2_grad_2_need_canonicalization = true;
-      }
-      match_result.has_match = true;
-    }
-  }
-  return match_result;
-}
-
-MatchBwdResult MatchDbias(MatchBwdResult previous_result,
-                          HloInstruction* d_intermediate,
-                          const absl::flat_hash_set<HloInstruction*> users) {
-  MatchBwdResult match_result = previous_result;
-  auto user_count = d_intermediate->user_count();
-  HloInstruction* dbias_user = nullptr;
-  HloInstruction* dbias = nullptr;
-  for (auto user : d_intermediate->users()) {
-    if (users.contains(user)) {
-      user_count -= 1;
-    } else {
-      dbias_user = user;
-    }
-  }
-  auto ConsumeExtraConvert = [](HloInstruction* instr) {
-    Match(instr->users()[0], m::Convert(&instr, m::Op()).WithOneUse());
-    return true;
-  };
-  // user_count == 1 && (reduce-> {convert} ->bitcast)
-  match_result.has_match =
-      user_count == 1 &&
-      Match(dbias_user, m::Reduce(&dbias, m::Op(), m::Op()).WithOneUse()) &&
-      dbias->shape().rank() == 3 && ConsumeExtraConvert(dbias);
-  if (match_result.has_match) {
-    // cuDNN only supports dbias for [1, h, s, s]
-    // make sure reduce dimension is on batch dim
-    auto reduce_dim = dbias->dimensions();
-    if (reduce_dim.size() == 1 && reduce_dim[0] == 0) {
-      match_result.matched_dbias = dbias;
-    } else {
-      match_result.has_match = false;
-    }
-  }
-  return match_result;
-}
-
-MatchBwdResult MatchBwdBmmSoftmaxDropoutBmm(MatchBwdResult previous_result,
-                                            HloInstruction* fwd_fmha_call) {
-  MatchBwdResult match_result = previous_result;
-  bool is_bmm1_grad1_canonicalized =
-      match_result.bmm_1_grad_1_need_canonicalization;
-  match_result.has_match = false;
-  bool has_scale = false;
-  bool has_dropout = false;
-  // Backward dropout pattern
-  // select(mask, bmm2_grad2, broadcast())
-  auto bwd_dropout_pattern_form_1 = m::SharedSubpattern(
-      OptionalBitcast(OptionalReshape(OptionalConvert(m::Select(
-          m::Op(), m::Op().WithPredicate([&](const HloInstruction* instr) {
-            return instr == match_result.matched_bmm_2_grad_2;
-          }),
-          m::Broadcast(
-              OptionalConvert(m::Constant().WithPredicate(IsScalar))))))));
-
-  // multiply(bmm2_grad2, broadcast(select(mask, broadcast(), op())))
-  auto bwd_dropout_pattern_form_2 =
-      m::SharedSubpattern(OptionalBitcast(m::MultiplyAnyOrder(
-          OptionalConvert(
-              m::Op().WithPredicate([&](const HloInstruction* instr) {
-                return instr == match_result.matched_bmm_2_grad_2;
-              })),
-          m::Broadcast(OptionalConvert(OptionalBitcast(OptionalReshape(
-              m::Select(m::Op(),
-                        m::Broadcast(OptionalConvert(
-                            m::Constant().WithPredicate(IsScalar))),
-                        m::Op()))))))));
-  auto bwd_dropout_pattern_form_3 = OptionalConvert(m::MultiplyAnyOrder(
-      m::MultiplyAnyOrder(
-          m::Op().WithPredicate([&](const HloInstruction* instr) {
-            return instr == match_result.matched_bmm_2_grad_2;
-          }),
-          m::Broadcast(m::Constant().WithPredicate(IsScalar))),
-      m::Op()));
-  auto bwd_dropout_pattern = m::AnyOf<HloInstruction>(
-      bwd_dropout_pattern_form_1, bwd_dropout_pattern_form_2,
-      bwd_dropout_pattern_form_3);
-  // Backward softmax pattern
-  HloInstruction* bwd_softmax_input = nullptr;
-  HloInstruction* exp_1;
-  HloInstruction* exp_2;
-  HloInstruction* d_softmax;
-
-  // d_softmax = exp * (dy / s_b - sum(dy * exp * 1 / s^2))
-  // there could be at most 3 users of d_softmax: bmm1grad1 bmm1grad2 and dbias
-  auto bwd_softmax_pattern = OptionalBitcast(OptionalConvert(
-      m::MultiplyAnyOrder(
-          &d_softmax,
-          m::AddAnyOrder(
-              m::Divide().WithOneUse(),
-              m::Broadcast(OptionalBitcast(OptionalConvert(
-                  m::Negate(
-                      OptionalBitcast(
-                          m::Op()
-                              .WithPredicate(IsReduceSum)
-                              .WithOneUse()
-                              .WithOperand(
-                                  0, OptionalBitcast(
-                                         m::MultiplyAnyOrder(
-                                             m::MultiplyAnyOrder(
-                                                 m::Op(&bwd_softmax_input),
-                                                 m::Broadcast())
-                                                 .WithOneUse(),
-                                             m::Exp(&exp_2, m::Op()))
-                                             .WithOneUse()))))
-                      .WithOneUse())))),
-          m::Exp(&exp_1, m::Op()))
-          .WithAtMostNumUser(3)));
-
-  // Backward scale input pattern
-  HloInstruction* bwd_scale_input = nullptr;
-  HloInstruction* bwd_scale = nullptr;
-
-  auto bwd_scale_pattern =
-      m::MultiplyAnyOrder(&bwd_scale, m::Op(&bwd_scale_input),
-                          m::Broadcast(m::Constant().WithPredicate(IsScalar)))
-          .WithNumUser(2);
-  int intermediate_input_pos = is_bmm1_grad1_canonicalized ? 1 : 0;
-
-  HloInstruction* intermediate_input =
-      match_result.matched_bmm_1_grad_1->mutable_operand(
-          intermediate_input_pos);
-
-  has_scale = Match(intermediate_input, bwd_scale_pattern);
-
-  if (has_scale) {
-    intermediate_input = bwd_scale_input;
-  }
-
-  if (!Match(intermediate_input, bwd_softmax_pattern) || exp_1 != exp_2) {
-    return match_result;
-  }
-  has_dropout = Match(bwd_softmax_input, bwd_dropout_pattern);
-  // If there is no dropout and the softmax input is not coming from bmm2
-  // gradient gemm 2, then it's not the pattern that we care about.
-  if (!has_dropout &&
-      !Match(bwd_softmax_input,
-             OptionalConvert((OptionalBitcast(
-                 m::Op().WithPredicate([&](const HloInstruction* instr) {
-                   return instr == match_result.matched_bmm_2_grad_2;
-                 })))))) {
-    return match_result;
-  }
-
-  if (has_dropout) {
-    // has bias
-    if (fwd_fmha_call->custom_call_target() ==
-        kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget)
-      match_result.matched_custom_call_name =
-          kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget;
-    // no bias
-    if (fwd_fmha_call->custom_call_target() ==
-        kCudnnfMHASoftmaxDropoutCallTarget)
-      match_result.matched_custom_call_name =
-          kCudnnfMHASoftmaxDropoutBackwardCallTarget;
-  } else {
-    // has bias
-    if (fwd_fmha_call->custom_call_target() ==
-        kCudnnfMHAScaleBiasSoftmaxCallTarget)
-      match_result.matched_custom_call_name =
-          kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget;
-    // no bias
-    if (fwd_fmha_call->custom_call_target() == kCudnnfMHASoftmaxCallTarget)
-      match_result.matched_custom_call_name =
-          kCudnnfMHASoftmaxBackwardCallTarget;
-  }
-  // try to pattern match dbias
-  HloInstruction* dS = d_softmax;
-  if (dS->users()[0]->opcode() == HloOpcode::kConvert) {
-    dS = dS->users()[0];
-  }
-  if (has_scale) {
-    // bmm1-(scale)-(bias)-softmax pattern users could be dbias or scale bwd
-    if (dS->user_count() == 1) {
-      // no dbias
-      match_result.has_match = true;
-    } else if (dS->user_count() == 2) {
-      match_result = MatchDbias(match_result, dS, {bwd_scale});
-    } else {
-      match_result.has_match = false;
-    }
-  } else {
-    // bmm1-(bias)-softmax pattern
-    // users could be dbias besides bmm1grad1 bmm1grad2
-    if (dS->user_count() == 2) {
-      match_result.has_match = true;
-    } else if (dS->user_count() == 3) {
-      match_result = MatchDbias(match_result, dS,
-                                {match_result.matched_bmm_1_grad_1,
-                                 match_result.matched_bmm_1_grad_2});
-    } else {
-      match_result.has_match = false;
-    }
-  }
-  return match_result;
-}
-// First, we look for bmm2 gradient gemm 1, which takes the activation
-// output from a forward fmha call.
-// Second, we look for bmm2 gradient gemm 2, which takes the v tensor as an
-// input. We take the v tensor from the third operand of the forward fmha
-// call. If the forward call is canonicalized, then we skip the additional
-// transpose in between.
-// Then we look for bmm1 gradient gemm1 by searching for gemms that share the
-// q tensor with the current fmha call.
-MatchBwdResult MatchBackwardBmms(HloInstruction* fwd_fmha_call,
-                                 HloInstruction* bmm_1, bool v_transposed) {
-  MatchBwdResult matched_result = MatchBmm2GradGemm1(fwd_fmha_call);
-  if (!matched_result.has_match) {
-    return matched_result;
-  }
-
-  matched_result =
-      MatchBmm2GradGemm2(matched_result, fwd_fmha_call, v_transposed);
-  if (!matched_result.has_match) {
-    return matched_result;
-  }
-
-  matched_result = MatchBmm1GradGemm1(matched_result, bmm_1);
-  if (!matched_result.has_match) {
-    return matched_result;
-  }
-
-  matched_result = MatchBmm1GradGemm2(matched_result, fwd_fmha_call);
-  if (!matched_result.has_match) {
-    return matched_result;
-  }
-  return matched_result;
-}
-// We will match the backward graphs for all forward patterns defined in
-// MatchFwdMHAPatternsForCanonicalization
-MatchBwdResult MatchBwdMHAPatternsForCanonicalization(
-    HloInstruction* fwd_fmha_call, HloInstruction* bmm_1, bool v_transposed) {
-  MatchBwdResult match_result =
-      MatchBackwardBmms(fwd_fmha_call, bmm_1, v_transposed);
-  if (!match_result.has_match) {
-    return match_result;
-  }
-  match_result = MatchBwdBmmSoftmaxDropoutBmm(match_result, fwd_fmha_call);
-  return match_result;
-}
-
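-// Checks whether the matched attention block can be lowered to a cuDNN fMHA
-// call: dropout support, supported dtypes (BF16/F16), rank-4 bmms, a
-// consistent QKV layout between bmm1 and bmm2, and a flash-attention
-// compatible configuration. It may also strip the bias from the custom-call
-// name when the bias is a causal mask.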
-absl::StatusOr<bool> IsMHABlockSupported(
-    HloInstruction* bmm_1, HloInstruction* bmm_2, bool need_canonicalization,
-    bool is_training, bool is_causal_mask, std::string& custom_call_name,
-    const DebugOptions& debug_options,
-    stream_executor::CudaComputeCapability cc,
-    stream_executor::dnn::VersionInfo cudnn_version) {
-  if (MHACallHasDropout(custom_call_name) &&
-      !debug_options.xla_gpu_fused_attention_use_cudnn_rng()) {
-    VLOG(3) << "Using CUDNN RNG for fused attention dropout is not enabled.\n";
-    return false;
-  }
-
-  // cuDNN 8.8 currently only supports BF16 and F16 data types.
-  if (!IsSupportedPrimitiveType(bmm_1) || !IsSupportedPrimitiveType(bmm_2)) {
-    if (VLOG_IS_ON(2)) {
-      VLOG(2) << "Unsupported primitive type for cuDNN MHA fusion:\n"
-              << bmm_1->ToString() << "\nOR\n"
-              << bmm_2->ToString() << "\n"
-              << "BF16 and F16 are the supported Dtypes.";
-    }
-    return false;
-  }
-
-  if (bmm_1->shape().rank() != 4 || bmm_2->shape().rank() != 4) {
-    if (VLOG_IS_ON(2)) {
-      VLOG(2) << "Unsupported bmm rank for cuDNN MHA fusion:\n"
-              << bmm_1->ToString() << "\nOR\n"
-              << bmm_2->ToString() << "\n"
-              << "Only bmm with rank 4 is supported.";
-    }
-    return false;
-  }
-
-  // Get batch size / num heads / sequence length / hidden dim from bmm1 and
-  // bmm2, and make sure they agree between the two.
-  TF_ASSIGN_OR_RETURN(std::optional<QKVLayout> qkv_layout,
-                      GetQKVLayout(bmm_1, bmm_2, need_canonicalization));
-  if (!qkv_layout.has_value()) {
-    VLOG(2) << "bmm1 and bmm2 have different qkv layout.";
-    return false;
-  }
-
-  // check if matched attention block is supported by cuDNN flash attention.
-  TF_ASSIGN_OR_RETURN(
-      bool is_flash_attention,
-      IsFlashAttention(qkv_layout.value(), is_training, cc, cudnn_version));
-  if (is_flash_attention) {
-    if (is_causal_mask) {
-      // If the bias is a causal mask, drop the bias from the custom-call
-      // target name.
-      if (custom_call_name == kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget) {
-        custom_call_name = kCudnnfMHASoftmaxDropoutCallTarget;
-      } else if (custom_call_name == kCudnnfMHAScaleBiasSoftmaxCallTarget) {
-        custom_call_name = kCudnnfMHASoftmaxCallTarget;
-      }
-    }
-  }
-  return is_flash_attention;
-}
-
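-// Canonicalizes a batched gemm for cuDNN fMHA by swapping its lhs/rhs
-// operands (and the corresponding dot dimension numbers) and permuting the
-// last two output dimensions; the original result is recovered with a
-// transpose of the new dot.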
-absl::StatusOr<HloInstruction*> CanonicalizeBatchedGemmForcuDNNFMHA(
-    HloInstruction* bmm, HloComputation* comp) {
-  if (VLOG_IS_ON(3)) {
-    VLOG(3) << "Before FMHA Dot Canonicalization: \n"
-            << comp->parent()->ToString();
-  }
-  HloInstruction* lhs_bmm = bmm->mutable_operand(0);
-  HloInstruction* rhs_bmm = bmm->mutable_operand(1);
-  const DotDimensionNumbers& dnums = bmm->dot_dimension_numbers();
-
-  int64_t rank = bmm->shape().dimensions_size();
-  std::vector<int64_t> perm(rank);
-  std::iota(perm.begin(), perm.end(), 0);
-  // Swap the non-contracting dims of BMM shape. By contract, the
-  // non-contracting dims in the output are the last two dimensions.
-  std::swap(perm[rank - 1], perm[rank - 2]);
-
-  DotDimensionNumbers new_dnums = dnums;
-  std::swap(*new_dnums.mutable_lhs_contracting_dimensions(),
-            *new_dnums.mutable_rhs_contracting_dimensions());
-  std::swap(*new_dnums.mutable_lhs_batch_dimensions(),
-            *new_dnums.mutable_rhs_batch_dimensions());
-  auto original_bmm_shape = bmm->shape();
-  HloInstruction* new_dot = comp->AddInstruction(HloInstruction::CreateDot(
-      ShapeUtil::MakeShape(original_bmm_shape.element_type(),
-                           Permute(original_bmm_shape.dimensions(), perm)),
-      /* lhs */ rhs_bmm, /* rhs */ lhs_bmm, new_dnums,
-      bmm->precision_config()));
-
-  TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
-      bmm, HloInstruction::CreateTranspose(original_bmm_shape, new_dot, perm)));
-  if (VLOG_IS_ON(2)) {
-    VLOG(2) << "After FMHA Dot Canonicalization: \n"
-            << comp->parent()->ToString();
-  }
-  return new_dot;
-}
-
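-// Ensures that the checked dimension of the given bmm operand (the
-// contracting dimension if `should_contracting_be_fastest`, otherwise the
-// non-contracting dimension) is the fastest-moving (minor-most) dimension,
-// inserting a transpose and updating the dot dimension numbers if needed.
-// Returns the (possibly transposed) operand.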
-absl::StatusOr<HloInstruction*> ChangeCheckedDimToFastest(
-    HloComputation* comp, HloInstruction* bmm, bool is_lhs,
-    bool should_contracting_be_fastest) {
-  const DotDimensionNumbers& dot_dims_bmm = bmm->dot_dimension_numbers();
-  DotDimensionNumbers new_dot_dims_bmm = dot_dims_bmm;
-  int64_t bmm_operand = is_lhs ? 0 : 1;
-  absl::Span<const int64_t> contracting_dims =
-      is_lhs ? dot_dims_bmm.lhs_contracting_dimensions()
-             : dot_dims_bmm.rhs_contracting_dimensions();
-  absl::Span<const int64_t> batch_dims =
-      is_lhs ? dot_dims_bmm.lhs_batch_dimensions()
-             : dot_dims_bmm.rhs_batch_dimensions();
-  absl::Span<const int64_t> lhs_minor_to_major_bmm =
-      bmm->operand(0)->shape().layout().minor_to_major();
-  absl::Span<const int64_t> rhs_minor_to_major_bmm =
-      bmm->operand(1)->shape().layout().minor_to_major();
-
-  absl::Span<const int64_t>& minor_to_major_to_check =
-      is_lhs ? lhs_minor_to_major_bmm : rhs_minor_to_major_bmm;
-
-  CHECK_EQ(contracting_dims.size(), 1);
-  TF_ASSIGN_OR_RETURN(std::vector<int64_t> non_contracting_dims,
-                      GetNonContractingDims(bmm->operand(bmm_operand)->shape(),
-                                            batch_dims, contracting_dims));
-  CHECK_EQ(non_contracting_dims.size(), 1);
-  HloInstruction* operand_bmm = bmm->mutable_operand(bmm_operand);
-  int64_t hidden_dim = should_contracting_be_fastest ? contracting_dims[0]
-                                                     : non_contracting_dims[0];
-  int64_t minor_dim = minor_to_major_to_check[0];
-  // If the hidden dim of the target operand is not the fastest moving
-  // dimension, make it so.
-  if (minor_dim != hidden_dim) {
-    std::vector<int64_t> perm(bmm->shape().dimensions_size());
-    std::iota(perm.begin(), perm.end(), 0);
-    std::swap(perm[hidden_dim], perm[minor_dim]);
-
-    if (is_lhs) {
-      new_dot_dims_bmm.set_lhs_contracting_dimensions(0,
-                                                      non_contracting_dims[0]);
-    } else {
-      new_dot_dims_bmm.set_rhs_contracting_dimensions(0,
-                                                      non_contracting_dims[0]);
-    }
-
-    operand_bmm = comp->AddInstruction(
-        HloInstruction::CreateTranspose(
-            ShapeUtil::MakeShapeWithDenseLayout(
-                bmm->shape().element_type(),
-                Permute(operand_bmm->shape().dimensions(), perm),
-                minor_to_major_to_check),
-            operand_bmm, perm),
-        &operand_bmm->metadata());
-    *((DynCast<HloDotInstruction>(bmm))->mutable_dot_dimension_numbers()) =
-        new_dot_dims_bmm;
-  }
-  return operand_bmm;
-}
-
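-// Replaces the matched forward attention subgraph rooted at bmm_2 with a
-// cuDNN fMHA custom call. Operand order: {Q, K, V, bias*}; output order:
-// {O, Fwd act*, workspace}. In training mode the softmax stats are emitted
-// as the forward activation for the backward pass.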
-absl::StatusOr<HloInstruction*> FuseFwdMultiHeadedAttentionBlock(
-    HloComputation* comp, HloInstruction* bmm_1, HloInstruction* bmm_2,
-    HloInstruction* bias, HloInstruction* scale, HloInstruction* reduce_sum,
-    HloInstruction* softmax_input, double dropout_rate,
-    std::string& custom_call_name, stream_executor::CudaComputeCapability cc,
-    bool is_training, bool& changed, bool& v_transposed, bool is_causal_mask) {
-  double scale_value = 1.0;
-  HloInstruction* lhs_bmm1;
-  HloInstruction* rhs_bmm1;
-  HloInstruction* rhs_bmm2;
-  DotDimensionNumbers orig_bmm1_dot_dim = bmm_1->dot_dimension_numbers();
-  DotDimensionNumbers orig_bmm2_dot_dim = bmm_2->dot_dimension_numbers();
-
-  TF_ASSIGN_OR_RETURN(rhs_bmm1, ChangeCheckedDimToFastest(
-                                    comp, bmm_1, false /*is_lhs*/,
-                                    true /*should_contracting_be_fastest*/));
-  TF_ASSIGN_OR_RETURN(lhs_bmm1, ChangeCheckedDimToFastest(
-                                    comp, bmm_1, true /*is_lhs*/,
-                                    true /*should_contracting_be_fastest*/));
-
-  TF_ASSIGN_OR_RETURN(rhs_bmm2, ChangeCheckedDimToFastest(
-                                    comp, bmm_2, false /*is_lhs*/,
-                                    false /*should_contracting_be_fastest*/));
-
-  if (rhs_bmm2 != bmm_2->mutable_operand(1)) {
-    v_transposed = true;
-  }
-
-  GpuBackendConfig gpu_config;
-  CudnnfMHABackendConfig& fmha_config =
-      *gpu_config.mutable_cudnn_fmha_backend_config();
-
-  *fmha_config.mutable_bmm1_dot_dimension_numbers() =
-      bmm_1->dot_dimension_numbers();
-  *fmha_config.mutable_bmm2_dot_dimension_numbers() =
-      bmm_2->dot_dimension_numbers();
-
-  TF_RET_CHECK((dropout_rate >= 0.0 && dropout_rate <= 1.0));
-  // Restore original DotDimensionNumbers.
-  *((DynCast<HloDotInstruction>(bmm_1))->mutable_dot_dimension_numbers()) =
-      orig_bmm1_dot_dim;
-  *((DynCast<HloDotInstruction>(bmm_2))->mutable_dot_dimension_numbers()) =
-      orig_bmm2_dot_dim;
-
-  // If a scale node is present, extract its constant value.
-  if (scale != nullptr) {
-    std::optional<double> value = GetConstantValue(scale);
-    TF_RET_CHECK(value.has_value());
-    scale_value = *value;
-  }
-
-  fmha_config.set_fmha_scale(scale_value);
-  fmha_config.set_dropout_rate(dropout_rate);
-  // Set to an arbitrary seed for now; the seed is not exposed to XLA in the
-  // HLO graph.
-  // TODO Find a way to compute original seed from dropout keys.
-  fmha_config.set_seed(42);
-
-  *fmha_config.mutable_intermediate_tensor_shape() = bmm_1->shape().ToProto();
-  {
-    auto* algorithm = fmha_config.mutable_algorithm();
-    algorithm->set_algo_id(0);  // engine id
-    algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
-    std::vector<int64_t> knob_ids = /* {0, 1} */ {17, 24};
-    std::vector<int64_t> knob_vals = {1, 0};
-    for (int i = 0; i < knob_ids.size(); ++i) {
-      (*algorithm->mutable_tuning_knobs())[knob_ids[i]] = knob_vals[i];
-    }
-    algorithm->set_is_cudnn_frontend(true);
-    algorithm->mutable_workspace_size()->set_value(0);
-  }
-  // Set the mask type: whether cuDNN should generate the causal mask inside
-  // the fused attention kernel.
-  fmha_config.set_mask_type(is_causal_mask ? CudnnfMHABackendConfig::CAUSAL
-                                           : CudnnfMHABackendConfig::NO_MASK);
-
-  const Shape& output_shape = bmm_2->shape();
-
-  Shape call_shape;
-  // Activation output is used by backward gemm.
-  HloInstruction* activation_output = nullptr;
-
-  // Output Order: {O, Fwd act*, workspace}
-  std::vector<Shape> output_shapes = {output_shape};
-  if (is_training) {
-    activation_output = bmm_2->mutable_operand(0);
-    // Sometimes the activation output is a bitcast; the actual activation is
-    // the other user of the producer of bmm_2's first operand.
-    if (activation_output->user_count() < 2 &&
-        activation_output->opcode() == HloOpcode::kBitcast) {
-      HloInstruction* producer = activation_output->mutable_operand(0);
-      TF_RET_CHECK(producer->user_count() == 2);
-      HloInstruction* bmm2_grad2_user =
-          producer->users()[0] == activation_output ? producer->users()[1]
-                                                    : producer->users()[0];
-      // might be (transpose) - bmm2_grad2
-      if (IsBatchedMatmul(bmm2_grad2_user)) {
-        activation_output = producer;
-      } else if (bmm2_grad2_user->opcode() == HloOpcode::kTranspose) {
-        activation_output = bmm2_grad2_user;
-      } else {
-        return Internal("Unexpected activation patterns");
-      }
-    }
-    // Flash attention must output the softmax stats to the backward pass.
-    TF_RET_CHECK(reduce_sum != nullptr);
-    output_shapes.push_back(
-        ShapeUtil::MakeShape(F32, reduce_sum->shape().dimensions()));
-  }
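-  // Reserved placeholder for workspace.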
-  output_shapes.push_back(ShapeUtil::MakeShape(U8, {0}));
-  call_shape = ShapeUtil::MakeTupleShape(output_shapes);
-
-  // Input Order: {Q, K, V, bias*}
-  std::vector<HloInstruction*> operands = {lhs_bmm1, rhs_bmm1, rhs_bmm2};
-  if (!is_causal_mask && bias != nullptr) {
-    HloInstruction* original_bias;
-    HloInstruction* original_broadcast;
-    // In some cases the bias is up-cast to a wider float type; we need to
-    // take the original bias node and broadcast it without the convert.
-    if (Match(bias, m::Broadcast(
-                        &original_broadcast,
-                        m::Convert(
-                            m::Op(&original_bias)
-                                .WithPredicate([](const HloInstruction* instr) {
-                                  return instr->shape().element_type() == F16 ||
-                                         instr->shape().element_type() == BF16;
-                                }))
-                            .WithPredicate([](const HloInstruction* instr) {
-                              return instr->shape().element_type() == F32 ||
-                                     instr->shape().element_type() == F64;
-                            })))) {
-      absl::Span<const int64_t> original_bcast_dims =
-          (DynCast<HloBroadcastInstruction>(original_broadcast))->dimensions();
-      // This is to deal with cases like paxml where an extra dimension of 1 is
-      // added to the left of the tensor.
-      // TODO Make this logic more generic
-      absl::Span<const int64_t> original_broadcast_shape_dims =
-          original_broadcast->shape().dimensions();
-      int64_t starting_index = original_broadcast_shape_dims.size() == 5 &&
-                                       original_broadcast_shape_dims[0] == 1
-                                   ? 1
-                                   : 0;
-      std::vector<int64_t> bcast_dimensions;
-      for (auto& dim : original_bcast_dims) {
-        bcast_dimensions.push_back(dim - starting_index);
-      }
-
-      const Shape& bcast_shape = bmm_1->shape();
-      bias = comp->AddInstruction(HloInstruction::CreateBroadcast(
-          bcast_shape, original_bias, bcast_dimensions));
-    }
-    operands.push_back(bias);
-  }
-
-  HloInstruction* fmha_call =
-      comp->AddInstruction(HloInstruction::CreateCustomCall(
-          call_shape, operands, absl::string_view(custom_call_name)));
-  TF_RETURN_IF_ERROR(fmha_call->set_backend_config(gpu_config));
-  TF_RETURN_IF_ERROR(SetFMHAInstructionName(bmm_1->GetModule(), fmha_call));
-
-  TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
-      bmm_2,
-      HloInstruction::CreateGetTupleElement(bmm_2->shape(), fmha_call, 0)));
-
-  if (activation_output) {
-    HloInstruction* activation_gte =
-        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
-            activation_output->shape(), fmha_call, 1));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstructionWithDifferentShape(
-                               activation_output, activation_gte,
-                               /*preserve_sharding=*/false,
-                               /*relay_control_dependency=*/false,
-                               /*remove_unused_operands=*/false)
-                           .status());
-  }
-
-  if (VLOG_IS_ON(2)) {
-    VLOG(2) << "After CudnnFusedMHARewriter: \n" << comp->parent()->ToString();
-  }
-  changed = true;
-  return fmha_call;
-}
-
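-// Replaces the matched backward attention gemms with a single cuDNN fMHA
-// backward custom call. Operand order: {Q, K, V, Fwd act, d_o, bias*, O*};
-// output order: {dQ, dK, dV, dbias*, workspace}.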
-absl::StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
-    HloComputation* comp, HloInstruction* bmm_1_grad_1,
-    HloInstruction* bmm_1_grad_2, HloInstruction* bmm_2_grad_1,
-    HloInstruction* bmm_2_grad_2, HloInstruction* fwd_fmha_call,
-    HloInstruction* dbias, HloInstruction* bias,
-    std::string& bwd_custom_call_name) {
-  HloInstruction* rhs_bmm1_grad_gemm1;
-  HloInstruction* lhs_bmm1_grad_gemm2;
-  HloInstruction* rhs_bmm2_grad_gemm2;
-  HloInstruction* d_output_grad;
-
-  DotDimensionNumbers orig_bmm1_grad1_config =
-      bmm_1_grad_1->dot_dimension_numbers();
-  DotDimensionNumbers orig_bmm1_grad2_config =
-      bmm_1_grad_2->dot_dimension_numbers();
-  DotDimensionNumbers orig_bmm2_grad1_config =
-      bmm_2_grad_1->dot_dimension_numbers();
-  DotDimensionNumbers orig_bmm2_grad2_config =
-      bmm_2_grad_2->dot_dimension_numbers();
-
-  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                      fwd_fmha_call->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& fwd_config =
-      gpu_config.cudnn_fmha_backend_config();
-  bool is_causal_mask =
-      fwd_config.mask_type() == CudnnfMHABackendConfig::CAUSAL;
-  CudnnfMHABackendConfig bwd_fmha_config;
-  // Q tensor
-  TF_ASSIGN_OR_RETURN(
-      rhs_bmm1_grad_gemm1,
-      ChangeCheckedDimToFastest(comp, bmm_1_grad_1, false /*is_lhs*/,
-                                false /*should_contracting_be_fastest*/));
-  // K tensor
-  TF_ASSIGN_OR_RETURN(
-      lhs_bmm1_grad_gemm2,
-      ChangeCheckedDimToFastest(comp, bmm_1_grad_2, false /*is_lhs*/,
-                                false /*should_contracting_be_fastest*/));
-
-  // Forward activation (softmax stats): tuple element 1 of the forward
-  // call's output {O, Fwd act, workspace}.
-  HloInstruction* fwd_act;
-  int64_t fwd_act_index = 1;
-  fwd_act = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
-      fwd_fmha_call->shape().tuple_shapes(fwd_act_index), fwd_fmha_call,
-      fwd_act_index));
-
-  // V tensor
-  TF_ASSIGN_OR_RETURN(
-      rhs_bmm2_grad_gemm2,
-      ChangeCheckedDimToFastest(comp, bmm_2_grad_2, false /*is_lhs*/,
-                                true /*should_contracting_be_fastest*/));
-  // d_output to bmm2_grad2.
-  // Since d_o is the input of two bmms, we set the dot dimension numbers
-  // using the constraint that the contracting dimension of the lhs of
-  // bmm_2_grad_2 must be the fastest moving dimension.
-  TF_ASSIGN_OR_RETURN(
-      d_output_grad,
-      ChangeCheckedDimToFastest(comp, bmm_2_grad_2, true /*is_lhs*/,
-                                true /*should_contracting_be_fastest*/));
-  // d_output to bmm2_grad1.
-  // The value is unused; we call this only to make sure the dot dimension
-  // numbers are set correctly.
-  TF_ASSIGN_OR_RETURN(
-      HloInstruction * bmm_2_grad_1_rhs,
-      ChangeCheckedDimToFastest(comp, bmm_2_grad_1, false /*is_lhs*/,
-                                false /*should_contracting_be_fastest*/));
-  (void)bmm_2_grad_1_rhs;
-  // Operand order: {Q, K, V, Fwd act, d_o, bias*, O*}
-  std::vector<HloInstruction*> operands = {
-      rhs_bmm1_grad_gemm1, lhs_bmm1_grad_gemm2, rhs_bmm2_grad_gemm2, fwd_act,
-      d_output_grad};
-
-  // For flash attention, add fwd output to input list
-  if (!is_causal_mask && bias) {
-    operands.push_back(bias);
-  }
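-  // Find the get-tuple-element that reads the forward output O (tuple
-  // index 0).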
-  HloInstruction* fwd_output = nullptr;
-  for (auto user : fwd_fmha_call->users()) {
-    if (user->opcode() == HloOpcode::kGetTupleElement &&
-        user->tuple_index() == 0) {
-      fwd_output = user;
-    }
-  }
-  // We should always be able to find the forward output instruction.
-  TF_RET_CHECK(fwd_output != nullptr);
-  // Check that dO and O have the same layout, as required by cuDNN.
-  TF_RET_CHECK(fwd_output->shape() == d_output_grad->shape());
-  operands.push_back(fwd_output);
-
-  *bwd_fmha_config.mutable_bmm1_grad_gemm1_dot_dimension_numbers() =
-      bmm_1_grad_1->dot_dimension_numbers();
-  *bwd_fmha_config.mutable_bmm1_grad_gemm2_dot_dimension_numbers() =
-      bmm_1_grad_2->dot_dimension_numbers();
-  *bwd_fmha_config.mutable_bmm2_grad_gemm1_dot_dimension_numbers() =
-      bmm_2_grad_1->dot_dimension_numbers();
-  *bwd_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() =
-      bmm_2_grad_2->dot_dimension_numbers();
-
-  // Restore original DotDimensionNumbers
-  *((DynCast<HloDotInstruction>(bmm_1_grad_1))
-        ->mutable_dot_dimension_numbers()) = orig_bmm1_grad1_config;
-  *((DynCast<HloDotInstruction>(bmm_1_grad_2))
-        ->mutable_dot_dimension_numbers()) = orig_bmm1_grad2_config;
-  *((DynCast<HloDotInstruction>(bmm_2_grad_1))
-        ->mutable_dot_dimension_numbers()) = orig_bmm2_grad1_config;
-  *((DynCast<HloDotInstruction>(bmm_2_grad_2))
-        ->mutable_dot_dimension_numbers()) = orig_bmm2_grad2_config;
-
-  bwd_fmha_config.set_fmha_scale(fwd_config.fmha_scale());
-  bwd_fmha_config.set_dropout_rate(fwd_config.dropout_rate());
-  // Set to an arbitrary seed for now; the seed is not exposed to XLA in the
-  // HLO graph.
-  // TODO Find a way to compute original seed from dropout keys.
-  bwd_fmha_config.set_seed(fwd_config.seed());
-  bwd_fmha_config.set_mask_type(is_causal_mask
-                                    ? CudnnfMHABackendConfig::CAUSAL
-                                    : CudnnfMHABackendConfig::NO_MASK);
-
-  *bwd_fmha_config.mutable_intermediate_tensor_shape() =
-      fwd_config.intermediate_tensor_shape();
-  {
-    auto* algorithm = bwd_fmha_config.mutable_algorithm();
-    algorithm->set_algo_id(0);  // engine id
-    algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
-    std::vector<int64_t> knob_ids = /* {0, 1} */ {17, 24};
-    std::vector<int64_t> knob_vals = {1, 0};
-    for (int i = 0; i < knob_ids.size(); ++i) {
-      (*algorithm->mutable_tuning_knobs())[knob_ids[i]] = knob_vals[i];
-    }
-    algorithm->set_is_cudnn_frontend(true);
-    algorithm->mutable_workspace_size()->set_value(0);
-  }
-
-  // Output order:
-  // {dQ(bmm_1_grad_2), dK(bmm_1_grad_1), dV(bmm_2_grad_1), dbias*, workspace}
-  std::vector<Shape> output_shapes = {
-      bmm_1_grad_2->shape(), bmm_1_grad_1->shape(), bmm_2_grad_1->shape()};
-
-  if (dbias) {
-    // The cuDNN kernel only outputs dbias in the shape [1, num_heads, seq,
-    // seq], so we prepend a dimension of 1 to the existing dbias shape.
-    std::vector<int64_t> dbias_shape_vector =
-        SpanToVector(dbias->shape().dimensions());
-    dbias_shape_vector.insert(dbias_shape_vector.begin(), 1);
-    Shape cudnn_dbias_shape =
-        ShapeUtil::MakeShape(dbias->shape().element_type(), dbias_shape_vector);
-    output_shapes.push_back(cudnn_dbias_shape);
-  }
-  // Reserved placeholder for workspace
-  output_shapes.push_back(ShapeUtil::MakeShape(U8, {0}));
-  Shape call_shape = ShapeUtil::MakeTupleShape(output_shapes);
-  HloInstruction* fmha_bwd_call =
-      comp->AddInstruction(HloInstruction::CreateCustomCall(
-          call_shape, operands, absl::string_view(bwd_custom_call_name)));
-  GpuBackendConfig bwd_gpu_config;
-  *bwd_gpu_config.mutable_cudnn_fmha_backend_config() = bwd_fmha_config;
-  TF_RETURN_IF_ERROR(fmha_bwd_call->set_backend_config(bwd_gpu_config));
-  TF_RETURN_IF_ERROR(
-      SetFMHAInstructionName(bmm_1_grad_1->GetModule(), fmha_bwd_call));
-
-  // Q gradient
-  TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
-      bmm_1_grad_2, HloInstruction::CreateGetTupleElement(bmm_1_grad_2->shape(),
-                                                          fmha_bwd_call, 0)));
-  // K gradient
-  TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
-      bmm_1_grad_1, HloInstruction::CreateGetTupleElement(bmm_1_grad_1->shape(),
-                                                          fmha_bwd_call, 1)));
-  // V gradient
-  TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
-      bmm_2_grad_1, HloInstruction::CreateGetTupleElement(bmm_2_grad_1->shape(),
-                                                          fmha_bwd_call, 2)));
-
-  if (dbias) {
-    // Reshape the fMHA dbias output to the original user's input shape.
-    // If the reshape doesn't involve physical data movement, the algebraic
-    // simplifier can change it to a no-op bitcast.
-    Shape original_shape = dbias->shape();
-    HloInstruction* dbias_user = dbias->users()[0];
-    HloInstruction* cudnn_dbias_output =
-        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
-            output_shapes[3], fmha_bwd_call, 3));
-    HloInstruction* reshape_dbias = comp->AddInstruction(
-        HloInstruction::CreateReshape(original_shape, cudnn_dbias_output));
-    TF_RETURN_IF_ERROR(dbias_user->ReplaceOperandWith(
-        dbias_user->operand_index(dbias), reshape_dbias));
-
-    TF_RETURN_IF_ERROR(
-        comp->ReplaceInstructionWithDifferentShape(dbias, cudnn_dbias_output));
-  }
-  return true;
-}
-
-absl::Status RestoreFwdGraph(
-    HloComputation* comp, HloInstruction* fwd_fmha_call, HloInstruction* bmm2,
-    HloInstruction* activation, HloInstruction* original_bmm2_producer0,
-    HloInstruction* original_bmm2_producer1,
-    std::vector<HloInstruction*>& original_activation_producers,
-    bool bmm_2_need_canonicalization) {
-  // If the backward pattern is not matched, we need to restore the original
-  // graph structure by replacing the new GTEs added by the forward FMHA call
-  // with cloned copies of the old activation and bmm2.
-  HloInstruction* output_gte = fwd_fmha_call->users()[0];
-  HloInstruction* activation_gte = fwd_fmha_call->users()[1];
-  std::string suffix = "fmha_no_match_clone";
-  HloInstruction* cloned_activation =
-      comp->AddInstruction(activation->CloneWithNewOperands(
-          activation->shape(), original_activation_producers, suffix));
-
-  // Since the old activation is detached by the forward FMHA rewrite, we
-  // need to use the newly cloned activation.
-  HloInstruction* lhs = activation == original_bmm2_producer0
-                            ? cloned_activation
-                            : original_bmm2_producer0;
-  HloInstruction* rhs = activation == original_bmm2_producer0
-                            ? original_bmm2_producer1
-                            : cloned_activation;
-  HloInstruction* cloned_bmm2 = comp->AddInstruction(
-      bmm2->CloneWithNewOperands(bmm2->shape(), {lhs, rhs}, suffix));
-  if (bmm_2_need_canonicalization) {
-    TF_RET_CHECK(output_gte->users()[0]->opcode() == HloOpcode::kTranspose);
-    TF_RETURN_IF_ERROR(
-        comp->ReplaceInstruction(output_gte->users()[0], cloned_bmm2));
-  } else {
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2));
-  }
-  TF_RETURN_IF_ERROR(
-      comp->ReplaceInstruction(activation_gte, cloned_activation));
-  return absl::OkStatus();
-}
-}  // namespace
-
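-// Walks every non-fusion computation, matches the forward (and, when
-// training, backward) attention patterns, and replaces them with cuDNN fMHA
-// custom calls.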
-absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool any_changed = false;
-  // We use this set to keep track of all already-matched attention blocks.
-  absl::flat_hash_set<HloInstruction*> matched_bmm1;
-  for (HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    const DebugOptions& debug_options =
-        comp->parent()->config().debug_options();
-    const se::dnn::VersionInfo cudnn_version =
-        GetDnnVersionInfoOrDefault(stream_executor_, cudnn_version_);
-#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
-    // CUDA needs to be >= 12.0 for cuDNN to work with all supported hardware.
-    // Some cuDNN versions work with CUDA 11, but it is impractical for us to
-    // test those combinations so just disable them.
-    return false;
-#endif
-    if (!debug_options.xla_gpu_enable_cudnn_fmha() ||
-        !IsComputeCapabilityAndCudnnSupported(
-            compute_capability_, cudnn_version,
-            stream_executor::dnn::VersionInfo(8, 9, 4))) {
-      return false;
-    }
-    for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-      bool v_transposed = false;
-      bool changed = false;
-      MatchFwdResult matched_result =
-          MatchFwdMHAPatternsForCanonicalization(instr);
-      if (!matched_result.has_match) {
-        continue;
-      }
-      // We check the validity of the bmms here, before canonicalization, so
-      // that we don't modify the graph if MHA fusion is not possible.
-      TF_ASSIGN_OR_RETURN(
-          bool is_mha_module_supported,
-          IsMHABlockSupported(
-              matched_result.matched_bmm_1, matched_result.matched_bmm_2,
-              matched_result.need_canonicalization, matched_result.is_training,
-              matched_result.is_causal_mask,
-              matched_result.matched_custom_call_name, debug_options,
-              compute_capability_, cudnn_version));
-
-      if (!is_mha_module_supported) continue;
-      // If the activation has more than one user in non-training mode, we
-      // cannot rewrite the graph, so skip processing the rest.
-      HloInstruction* activation =
-          matched_result.need_canonicalization
-              ? matched_result.matched_bmm_2->mutable_operand(1)
-              : matched_result.matched_bmm_2->mutable_operand(0);
-      if (!matched_result.is_training && activation->user_count() > 1) {
-        VLOG(2)
-            << "Activation: " << activation->ToString()
            << " cannot have more than 1 user in non-training mode. Skipping.";
-        continue;
-      }
-      HloInstruction* original_bmm2_producer0 =
-          matched_result.matched_bmm_2->mutable_operand(0);
-      HloInstruction* original_bmm2_producer1 =
-          matched_result.matched_bmm_2->mutable_operand(1);
-
-      HloInstruction* original_bmm2 = matched_result.matched_bmm_2;
-      std::vector<HloInstruction*> original_activation_producers;
-      for (HloInstruction* operand : activation->mutable_operands()) {
-        original_activation_producers.push_back(operand);
-      }
-      // We make sure no attention block is matched and replaced twice here
-      if (!matched_bmm1.insert(matched_result.matched_bmm_1).second) {
-        continue;
-      }
-      // If we need to canonicalize the bmm, we will assign the newly
-      // canonicalized bmm to bmm_2.
-      if (matched_result.need_canonicalization) {
-        TF_ASSIGN_OR_RETURN(matched_result.matched_bmm_2,
-                            CanonicalizeBatchedGemmForcuDNNFMHA(
-                                matched_result.matched_bmm_2, comp));
-      }
-
-      // Fuse the bmms and intermediate nodes into an fMHA call; the fused
-      // call will replace bmm_2.
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * fwd_fmha_call,
-          FuseFwdMultiHeadedAttentionBlock(
-              comp, matched_result.matched_bmm_1, matched_result.matched_bmm_2,
-              matched_result.matched_bias, matched_result.matched_scale,
-              matched_result.matched_reduce_sum,
-              matched_result.matched_softmax_input,
-              matched_result.matched_dropout_rate,
-              matched_result.matched_custom_call_name, compute_capability_,
-              matched_result.is_training, changed, v_transposed,
-              matched_result.is_causal_mask));
-      any_changed |= changed;
-      if (matched_result.is_training) {
-        MatchBwdResult matched_bwd_result =
-            MatchBwdMHAPatternsForCanonicalization(
-                fwd_fmha_call, matched_result.matched_bmm_1, v_transposed);
-        if (!matched_bwd_result.has_match) {
-          VLOG(2) << "Backward pattern not matching, skipping.";
-          // restore fwd graph if bwd pattern match failed
-          TF_RETURN_IF_ERROR(
-              RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
-                              original_bmm2_producer0, original_bmm2_producer1,
-                              original_activation_producers,
-                              matched_result.need_canonicalization));
-          continue;
-        }
-        if (matched_bwd_result.matched_dbias &&
-            !(compute_capability_.IsAtLeastHopper() &&
-              compute_capability_.minor == 0 &&
-              cudnn_version >= stream_executor::dnn::VersionInfo(8, 9, 6))) {
-          VLOG(2) << "Flash attention dbias requires cudnn 8.9.6 + hopper.";
-          // restore fwd graph if bwd pattern match failed
-          TF_RETURN_IF_ERROR(
-              RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
-                              original_bmm2_producer0, original_bmm2_producer1,
-                              original_activation_producers,
-                              matched_result.need_canonicalization));
-          continue;
-        }
-        // Canonicalize gemms
-        if (matched_bwd_result.bmm_1_grad_1_need_canonicalization) {
-          TF_ASSIGN_OR_RETURN(
-              matched_bwd_result.matched_bmm_1_grad_1,
-              CanonicalizeBatchedGemmForcuDNNFMHA(
-                  matched_bwd_result.matched_bmm_1_grad_1, comp));
-        }
-        if (matched_bwd_result.bmm_1_grad_2_need_canonicalization) {
-          TF_ASSIGN_OR_RETURN(
-              matched_bwd_result.matched_bmm_1_grad_2,
-              CanonicalizeBatchedGemmForcuDNNFMHA(
-                  matched_bwd_result.matched_bmm_1_grad_2, comp));
-        }
-        if (matched_bwd_result.bmm_2_grad_1_need_canonicalization) {
-          TF_ASSIGN_OR_RETURN(
-              matched_bwd_result.matched_bmm_2_grad_1,
-              CanonicalizeBatchedGemmForcuDNNFMHA(
-                  matched_bwd_result.matched_bmm_2_grad_1, comp));
-        }
-        if (matched_bwd_result.bmm_2_grad_2_need_canonicalization) {
-          TF_ASSIGN_OR_RETURN(
-              matched_bwd_result.matched_bmm_2_grad_2,
-              CanonicalizeBatchedGemmForcuDNNFMHA(
-                  matched_bwd_result.matched_bmm_2_grad_2, comp));
-        }
-
-        // Fuse the corresponding gradient graph into an fMHA backward call.
-        TF_ASSIGN_OR_RETURN(
-            changed,
-            FuseBwdMultiHeadedAttentionBlock(
-                comp, matched_bwd_result.matched_bmm_1_grad_1,
-                matched_bwd_result.matched_bmm_1_grad_2,
-                matched_bwd_result.matched_bmm_2_grad_1,
-                matched_bwd_result.matched_bmm_2_grad_2, fwd_fmha_call,
-                matched_bwd_result.matched_dbias, matched_result.matched_bias,
-                matched_bwd_result.matched_custom_call_name));
-        any_changed |= changed;
-      }
-    }
-  }
-
-  return any_changed;
-}
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.h b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.h
deleted file mode 100644
index f0aa687..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_FUSED_MHA_REWRITER_H_
-#define XLA_SERVICE_GPU_CUDNN_FUSED_MHA_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/dnn.h"
-
-namespace xla {
-namespace gpu {
-
-class CudnnFusedMHARewriter : public HloModulePass {
- public:
-  explicit CudnnFusedMHARewriter(se::CudaComputeCapability cc,
-                                 se::StreamExecutor* stream_executor)
-      : compute_capability_(cc), stream_executor_(stream_executor) {}
-
-  explicit CudnnFusedMHARewriter(se::CudaComputeCapability cc,
-                                 se::dnn::VersionInfo cudnn_version)
-      : compute_capability_(cc), cudnn_version_(cudnn_version) {}
-
-  absl::string_view name() const override {
-    return "cudnn-fused-multi-headed-attention-rewriter";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  const se::CudaComputeCapability compute_capability_;
-  se::StreamExecutor* stream_executor_ = nullptr;
-  const se::dnn::VersionInfo cudnn_version_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_CUDNN_FUSED_MHA_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc
deleted file mode 100644
index 897136b..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc
+++ /dev/null
@@ -1,3183 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_fused_mha_rewriter.h"
-
-#include <cstddef>
-#include <memory>
-#include <optional>
-#include <utility>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/algorithm/container.h"
-#include "absl/strings/string_view.h"
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/algebraic_simplifier.h"
-#include "xla/service/computation_layout.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/cudnn_fused_mha_transpose_fusion.h"
-#include "xla/service/hlo_cse.h"
-#include "xla/service/hlo_dce.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/hlo_verifier.h"
-#include "xla/service/layout_normalization.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/service/reshape_decomposer.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/test_helpers.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cudnn/cudnn.h"  // IWYU pragma: keep
-#endif
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = xla::match;
-
-class CudnnFusedMhaRewriterTestHloTest : public HloTestBase {
- public:
-  se::CudaComputeCapability GetCudaComputeCapability() {
-    // Fake a supported compute capability to run tests; we don't run any
-    // kernels in these tests, so they should be safe to run anywhere.
-    return se::CudaComputeCapability(8, 0);
-  }
-
-  se::CudaComputeCapability GetRealCudaComputeCapability() {
-    return backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .cuda_compute_capability();
-  }
-
-  se::dnn::VersionInfo GetCudnnVersion() {
-    // Fake a supported cuDNN version to run tests; we don't run any kernels
-    // in these tests, so they should be safe to run anywhere.
-    return se::dnn::VersionInfo(8, 9, 4);
-  }
-
-  CudnnFusedMhaRewriterTestHloTest()
-      : HloTestBase(/*verifier_layout_sensitive=*/false,
-                    /*allow_mixed_precision_in_hlo_verifier=*/false,
-                    /*instruction_can_change_layout_func=*/{}) {
-#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
-    skip_reason_ = "cuDNN fused MHA requires CUDA 12 or later.";
-    return;
-#endif
-  }
-
- protected:
-  size_t CountFusedAttentionCall(HloModule* module, bool is_backward = false) {
-    return absl::c_count_if(module->entry_computation()->instructions(),
-                            [&](const HloInstruction* instr) {
-                              if (is_backward) {
-                                return IsBwdCustomCallTofMHA(*instr);
-                              } else {
-                                return IsFwdCustomCallTofMHA(*instr);
-                              }
-                            });
-  }
-
-  DebugOptions GetDebugOptionsForTest() override {
-    auto debug_options = HloTestBase::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_enable_cudnn_fmha(true);
-    debug_options.set_xla_gpu_fused_attention_use_cudnn_rng(true);
-    return debug_options;
-  }
-
-  HloModuleConfig GetModuleConfig() {
-    DebugOptions debug_options = GetDebugOptionsForTest();
-    HloModuleConfig config_with_fmha;
-    config_with_fmha.set_debug_options(debug_options);
-    return config_with_fmha;
-  }
-
-  // Centralize skip checks in the constructor. Unfortunately we cannot call
-  // GTEST_SKIP from the constructor. Instead, we set (if needed) `skip_reason`,
-  // and then check it from all test fixtures.
-  // An alternative would be to use the SetUp() override, but for this to be
-  // correct we'd have to ensure that all the parents' SetUp() methods are
-  // called, which is error prone.
-  std::optional<absl::string_view> skip_reason_;
-};
-
-constexpr absl::string_view
-    hlo_BF16Bmm1SoftmaxBmm2Pattern_k_hidden_not_most_minor = R"(
-HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
-
-region_0.7 {
-  Arg_0.8 = bf16[] parameter(0)
-  Arg_1.9 = bf16[] parameter(1)
-  ROOT maximum = bf16[] maximum(Arg_0.8, Arg_1.9)
-}
-
-region_1.19 {
-  Arg_0.20 = f32[] parameter(0)
-  Arg_1.21 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.20, Arg_1.21)
-}
-
-ENTRY main.6 {
-  Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
-  Arg_0.1 = bf16[16,16,256,64]{3,2,1,0} parameter(0)
-  Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1)
-  dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
-  constant = bf16[] constant(-inf)
-  reduce.11 = bf16[16,16,256]{2,1,0} reduce(dot.0, constant), dimensions={3}, to_apply=region_0.7
-  broadcast.3 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
-  subtract.1 = bf16[16,16,256,256]{3,2,1,0} subtract(dot.0, broadcast.3)
-  exponential.1 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.1)
-  convert.1 = f32[16,16,256,256]{3,2,1,0} convert(exponential.1)
-  constant.1 = f32[] constant(0)
-  reduce.23 = f32[16,16,256]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
-  convert.2 = bf16[16,16,256]{2,1,0} convert(reduce.23)
-  broadcast.4 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
-  divide = bf16[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.4)
-  ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
-})";
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       BF16Bmm1SoftmaxBmm2Pattern_bmm1_rhs_contracting_dim_not_most_minor) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto m, ParseAndReturnVerifiedModule(
-                  hlo_BF16Bmm1SoftmaxBmm2Pattern_k_hidden_not_most_minor));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get()));
-  EXPECT_TRUE(result);
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
-                     .WithShape(BF16, {16, 16, 256, 64})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_EQ(config.bmm1_dot_dimension_numbers().rhs_contracting_dimensions()[0],
-            2);
-}
-
-constexpr absl::string_view
-    hlo_BF16Bmm1SoftmaxBmm2Pattern_q_hidden_not_most_minor = R"(
-HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
-
-region_0.7 {
-  Arg_0.8 = bf16[] parameter(0)
-  Arg_1.9 = bf16[] parameter(1)
-  ROOT maximum = bf16[] maximum(Arg_0.8, Arg_1.9)
-}
-
-region_1.19 {
-  Arg_0.20 = f32[] parameter(0)
-  Arg_1.21 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.20, Arg_1.21)
-}
-
-ENTRY main.6 {
-  Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
-  Arg_0.1 = bf16[16,16,256,64]{2,3,1,0} parameter(0)
-  Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1)
-  dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
-  constant = bf16[] constant(-inf)
-  reduce.11 = bf16[16,16,256]{2,1,0} reduce(dot.0, constant), dimensions={3}, to_apply=region_0.7
-  broadcast.3 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
-  subtract.1 = bf16[16,16,256,256]{3,2,1,0} subtract(dot.0, broadcast.3)
-  exponential.1 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.1)
-  convert.1 = f32[16,16,256,256]{3,2,1,0} convert(exponential.1)
-  constant.1 = f32[] constant(0)
-  reduce.23 = f32[16,16,256]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
-  convert.2 = bf16[16,16,256]{2,1,0} convert(reduce.23)
-  broadcast.4 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
-  divide = bf16[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.4)
-  ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
-})";
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       BF16Bmm1SoftmaxBmm2Pattern_bmm1_lhs_contracting_dim_not_most_minor) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto m, ParseAndReturnVerifiedModule(
-                  hlo_BF16Bmm1SoftmaxBmm2Pattern_q_hidden_not_most_minor));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get()));
-  EXPECT_TRUE(result);
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
-                     .WithShape(BF16, {16, 16, 256, 64})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_EQ(config.bmm1_dot_dimension_numbers().lhs_contracting_dimensions()[0],
-            2);
-  EXPECT_EQ(config.bmm1_dot_dimension_numbers().rhs_contracting_dimensions()[0],
-            2);
-}
-
-constexpr absl::string_view
-    hlo_BF16Bmm1SoftmaxBmm2Pattern_v_hidden_dim_not_most_minor = R"(
-HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
-
-region_0.7 {
-  Arg_0.8 = bf16[] parameter(0)
-  Arg_1.9 = bf16[] parameter(1)
-  ROOT maximum = bf16[] maximum(Arg_0.8, Arg_1.9)
-}
-
-region_1.19 {
-  Arg_0.20 = f32[] parameter(0)
-  Arg_1.21 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.20, Arg_1.21)
-}
-
-ENTRY main.6 {
-  Arg_2.3 = bf16[16,16,256,64]{2,3,1,0} parameter(2)
-  Arg_0.1 = bf16[16,16,256,64]{2,3,1,0} parameter(0)
-  Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1)
-  dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
-  constant = bf16[] constant(-inf)
-  reduce.11 = bf16[16,16,256]{2,1,0} reduce(dot.0, constant), dimensions={3}, to_apply=region_0.7
-  broadcast.3 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
-  subtract.1 = bf16[16,16,256,256]{3,2,1,0} subtract(dot.0, broadcast.3)
-  exponential.1 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.1)
-  convert.1 = f32[16,16,256,256]{3,2,1,0} convert(exponential.1)
-  constant.1 = f32[] constant(0)
-  reduce.23 = f32[16,16,256]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
-  convert.2 = bf16[16,16,256]{2,1,0} convert(reduce.23)
-  broadcast.4 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
-  divide = bf16[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.4)
-  ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
-})";
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       BF16Bmm1SoftmaxBmm2Pattern_bmm2_non_contracting_dim_not_most_minor) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto m, ParseAndReturnVerifiedModule(
-                  hlo_BF16Bmm1SoftmaxBmm2Pattern_v_hidden_dim_not_most_minor));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get()));
-  EXPECT_TRUE(result);
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
-                     .WithShape(BF16, {16, 16, 256, 64})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_EQ(config.bmm2_dot_dimension_numbers().lhs_contracting_dimensions()[0],
-            3);
-  EXPECT_EQ(config.bmm2_dot_dimension_numbers().rhs_contracting_dimensions()[0],
-            3);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1CombinedMaskBiasSoftmaxBmm2) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_,
-entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0},pred[16,1,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}}
-
-region_0.32.clone {
-  Arg_0.0 = f32[] parameter(0)
-  Arg_1.0 = f32[] parameter(1)
-  ROOT maximum.1 = f32[] maximum(Arg_0.0, Arg_1.0)
-}
-
-region_1.44 {
-  Arg_0.45 = f32[] parameter(0)
-  Arg_1.46 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.45, Arg_1.46)
-}
-
-ENTRY main.61 {
-  Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
-  transpose.5 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_2.3), dimensions={0,2,3,1}
-  Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
-  transpose.6 = bf16[16,16,256,64]{3,2,1,0} transpose(Arg_0.1), dimensions={0,2,1,3}
-  Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
-  transpose.7 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,3,1}
-  Arg_4.5 = pred[16,1,256,256]{3,2,1,0} parameter(4), sharding={replicated}
-  bitcast.35 = pred[16,256,256]{2,1,0} bitcast(Arg_4.5)
-  convert.49 = s32[16,256,256]{2,1,0} convert(bitcast.35)
-  constant.5 = s32[] constant(0)
-  broadcast.10 = s32[16,256,256]{2,1,0} broadcast(constant.5), dimensions={}
-  compare = pred[16,256,256]{2,1,0} compare(convert.49, broadcast.10), direction=GT
-  constant.7 = bf16[] constant(0)
-  broadcast.12 = bf16[16,256,256]{2,1,0} broadcast(constant.7), dimensions={}
-  constant.9 = bf16[] constant(-9.999e+09)
-  broadcast.13 = bf16[16,256,256]{2,1,0} broadcast(constant.9), dimensions={}
-  select = bf16[16,256,256]{2,1,0} select(compare, broadcast.12, broadcast.13)
-  convert.51 = f32[16,256,256]{2,1,0} convert(select)
-  broadcast.14 = f32[16,16,256,256]{3,2,1,0} broadcast(convert.51), dimensions={0,2,3}
-  Arg_3.4 = bf16[1,16,256,256]{3,2,1,0} parameter(3), sharding={replicated}
-  bitcast.52 = bf16[16,256,256]{2,1,0} bitcast(Arg_3.4)
-  convert.52 = f32[16,256,256]{2,1,0} convert(bitcast.52)
-  broadcast.15 = f32[16,16,256,256]{3,2,1,0} broadcast(convert.52), dimensions={1,2,3}
-  add.1 = f32[16,16,256,256]{3,2,1,0} add(broadcast.14, broadcast.15)
-  dot.2 = bf16[16,16,256,256]{3,2,1,0} dot(transpose.6, transpose.7), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-  convert.55 = f32[16,16,256,256]{3,2,1,0} convert(dot.2)
-  add.18 = f32[16,16,256,256]{3,2,1,0} add(convert.55, add.1)
-  constant.11 = f32[] constant(-inf)
-  reduce.36 = f32[16,16,256]{2,1,0} reduce(add.18, constant.11), dimensions={3}, to_apply=region_0.32.clone
-  broadcast.17 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
-  subtract.1 = f32[16,16,256,256]{3,2,1,0} subtract(add.18, broadcast.17)
-  exponential.1 = f32[16,16,256,256]{3,2,1,0} exponential(subtract.1)
-  constant.14 = f32[] constant(0)
-  reduce.48 = f32[16,16,256]{2,1,0} reduce(exponential.1, constant.14), dimensions={3}, to_apply=region_1.44
-  broadcast.18 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.48), dimensions={0,1,2}
-  divide = f32[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.18)
-  convert.68 = bf16[16,16,256,256]{3,2,1,0} convert(divide)
-  dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.5, convert.68), lhs_contracting_dims={3}, rhs_contracting_dims={3}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-  ROOT transpose.8 = bf16[16,256,16,64]{3,2,1,0} transpose(dot.1), dimensions={0,3,1,2}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Transpose(
-              m::Transpose(m::GetTupleElement(
-                  m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}),
-                  0)))
-              .WithShape(BF16, {16, 256, 16, 64})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  EXPECT_EQ(fmha->operands().size(), 4);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1UnfusedSoftmaxBmm2) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->f16[2,6,40,64]{3,2,1,0}}
-
-region_0.7 {
-  Arg_0.8 = f16[] parameter(0)
-  Arg_1.9 = f16[] parameter(1)
-  ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9)
-}
-
-region_1.19 {
-  Arg_0.20 = f32[] parameter(0)
-  Arg_1.21 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.20, Arg_1.21)
-}
-
-ENTRY main.31 {
-  Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated}
-  Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated}
-  dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-  constant = f16[] constant(-inf)
-  reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7
-  broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
-  subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3)
-  exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1)
-  convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1)
-  constant.1 = f32[] constant(0)
-  reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
-  convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23)
-  broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
-  divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4)
-  Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated}
-  ROOT dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
-                     .WithShape(F16, {2, 6, 40, 64})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_FLOAT_EQ(config.fmha_scale(), 1.0);
-  EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0);
-  EXPECT_EQ(fmha->operands().size(), 3);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       BF16Bmm1ConvertedMaskAddedAfterFirstGemmSoftmaxBmm2) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},pred[16,1,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}}
-
-region_0.27.clone {
-  Arg_0.0 = f32[] parameter(0)
-  Arg_1.0 = f32[] parameter(1)
-  ROOT maximum.1 = f32[] maximum(Arg_0.0, Arg_1.0)
-}
-
-region_1.39 {
-  Arg_0.40 = f32[] parameter(0)
-  Arg_1.41 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.40, Arg_1.41)
-}
-
-ENTRY main.56 {
-  Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
-  transpose.5 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_2.3), dimensions={0,2,3,1}
-  Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
-  transpose.6 = bf16[16,16,256,64]{3,2,1,0} transpose(Arg_0.1), dimensions={0,2,1,3}
-  Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
-  transpose.7 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,3,1}
-  dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose.6, transpose.7), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-  convert.47 = f32[16,16,256,256]{3,2,1,0} convert(dot)
-  Arg_3.4 = pred[16,1,256,256]{3,2,1,0} parameter(3), sharding={replicated}
-  bitcast.37 = pred[16,256,256]{2,1,0} bitcast(Arg_3.4)
-  convert.42 = s32[16,256,256]{2,1,0} convert(bitcast.37)
-  constant.6 = s32[] constant(0)
-  broadcast.9 = s32[16,256,256]{2,1,0} broadcast(constant.6), dimensions={}
-  compare = pred[16,256,256]{2,1,0} compare(convert.42, broadcast.9), direction=GT
-  constant.8 = bf16[] constant(0)
-  broadcast.11 = bf16[16,256,256]{2,1,0} broadcast(constant.8), dimensions={}
-  constant.10 = bf16[] constant(-9.999e+09)
-  broadcast.12 = bf16[16,256,256]{2,1,0} broadcast(constant.10), dimensions={}
-  select = bf16[16,256,256]{2,1,0} select(compare, broadcast.11, broadcast.12)
-  convert.48 = f32[16,256,256]{2,1,0} convert(select)
-  broadcast.14 = f32[16,16,256,256]{3,2,1,0} broadcast(convert.48), dimensions={0,2,3}
-  add.2 = f32[16,16,256,256]{3,2,1,0} add(convert.47, broadcast.14)
-  constant.13 = f32[] constant(-inf)
-  reduce.31 = f32[16,16,256]{2,1,0} reduce(add.2, constant.13), dimensions={3}, to_apply=region_0.27.clone
-  broadcast.16 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.31), dimensions={0,1,2}
-  subtract.1 = f32[16,16,256,256]{3,2,1,0} subtract(add.2, broadcast.16)
-  exponential.1 = f32[16,16,256,256]{3,2,1,0} exponential(subtract.1)
-  constant.14 = f32[] constant(0)
-  reduce.43 = f32[16,16,256]{2,1,0} reduce(exponential.1, constant.14), dimensions={3}, to_apply=region_1.39
-  broadcast.17 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.43), dimensions={0,1,2}
-  divide = f32[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.17)
-  convert.63 = bf16[16,16,256,256]{3,2,1,0} convert(divide)
-  dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.5, convert.63), lhs_contracting_dims={3}, rhs_contracting_dims={3}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-  ROOT transpose.8 = bf16[16,256,16,64]{3,2,1,0} transpose(dot.1), dimensions={0,3,1,2}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Transpose(
-              m::Transpose(m::GetTupleElement(
-                  m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}),
-                  0)))
-              .WithShape(BF16, {16, 256, 16, 64})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  EXPECT_EQ(fmha->operands().size(), 4);
-}
-
-// Negative test: BMM1 contracting dimension is not 64, so no rewrite should happen.
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       BF16Bmm1Bmm2Pattern_bmm1_contracting_dim_not_equal_64) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
-ENTRY main.6 {
-  Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
-  Arg_0.1 = bf16[16,16,256,32]{3,2,1,0} parameter(0)
-  Arg_1.2 = bf16[16,16,256,32]{3,2,1,0} parameter(1)
-  dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
-  ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
-}
-
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Dot(&fmha, m::Dot(m::Parameter(0), m::Parameter(1)),
-                                m::Parameter(2))
-                             .WithShape(BF16, {16, 16, 256, 64})));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       BF16Bmm1Bmm2Pattern_bmm2_rhs_non_contracting_dim_not_equal_64) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0})->bf16[16,16,256,32]{3,2,1,0}}
-ENTRY main.6 {
-  Arg_2.3 = bf16[16,16,256,32]{3,2,1,0} parameter(2)
-  Arg_0.1 = bf16[16,16,256,64]{3,2,1,0} parameter(0)
-  Arg_1.2 = bf16[16,16,256,64]{3,2,1,0} parameter(1)
-  dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
-  ROOT dot.1 = bf16[16,16,256,32]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
-}
-
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Dot(&fmha, m::Op(), m::Parameter(2))
-                             .WithShape(BF16, {16, 16, 256, 32})));
-}
-
-// Check that canonicalization does not kick in when the MHA pattern is unsupported.
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       BF16Bmm1Bmm2PatternUncanonicalized_bmm1_contracting_dim_not_equal_64) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,64,256]{3,2,1,0}}
-
-ENTRY main.6 {
-  Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
-  Arg_0.1 = bf16[16,16,256,32]{3,2,1,0} parameter(0)
-  Arg_1.2 = bf16[16,16,256,32]{3,2,1,0} parameter(1)
-  dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
-  ROOT dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(Arg_2.3, dot.0), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
-}
-
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Dot(&fmha, m::Parameter(2), m::Op())
-                             .WithShape(BF16, {16, 16, 64, 256})));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1BiasSoftmaxDropoutBmm2) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}}
-
-region_0.34 {
-  Arg_0.35 = bf16[] parameter(0)
-  Arg_1.36 = bf16[] parameter(1)
-  ROOT maximum.37 = bf16[] maximum(Arg_0.35, Arg_1.36)
-}
-
-region_1.46 {
-  Arg_0.47 = f32[] parameter(0)
-  Arg_1.48 = f32[] parameter(1)
-  ROOT add.49 = f32[] add(Arg_0.47, Arg_1.48)
-}
-
-ENTRY main.82 {
-  Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
-  copy = bf16[16,256,16,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
-  transpose.2 = bf16[16,16,64,256]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
-  Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
-  copy.1 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
-  transpose = bf16[16,16,256,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
-  Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
-  copy.2 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
-  transpose.1 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
-  dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  Arg_3.4 = bf16[1,16,256,256]{3,2,1,0} parameter(3), sharding={replicated}
-  reshape.31 = bf16[16,256,256]{2,1,0} reshape(Arg_3.4)
-  broadcast.32 = bf16[16,16,256,256]{3,2,1,0} broadcast(reshape.31), dimensions={1,2,3}
-  add.33 = bf16[16,16,256,256]{3,2,1,0} add(dot, broadcast.32)
-  constant.21 = bf16[] constant(-inf)
-  reduce.38 = bf16[16,16,256]{2,1,0} reduce(add.33, constant.21), dimensions={3}, to_apply=region_0.34
-  broadcast.42 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.38), dimensions={0,1,2}
-  subtract.43 = bf16[16,16,256,256]{3,2,1,0} subtract(add.33, broadcast.42)
-  exponential.44 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.43)
-  convert.45 = f32[16,16,256,256]{3,2,1,0} convert(exponential.44)
-  constant.9 = f32[] constant(0)
-  reduce.50 = f32[16,16,256]{2,1,0} reduce(convert.45, constant.9), dimensions={3}, to_apply=region_1.46
-  convert.1 = bf16[16,16,256]{2,1,0} convert(reduce.50)
-  broadcast.55 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.1), dimensions={0,1,2}
-  divide.56 = bf16[16,16,256,256]{3,2,1,0} divide(exponential.44, broadcast.55)
-  constant.18 = u32[1]{0} constant({255383827})
-  constant.17 = u32[1]{0} constant({267815257})
-  constant.2 = u32[1]{0} constant({0})
-  constant.19 = u32[1]{0} constant({3213575472})
-  custom-call.26 = (u32[1]{0}, u32[1]{0}) custom-call(constant.18, constant.17, constant.2, constant.19), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
-  get-tuple-element.27 = u32[1]{0} get-tuple-element(custom-call.26), index=0
-  reshape.58 = u32[] reshape(get-tuple-element.27)
-  broadcast.62 = u32[32768]{0} broadcast(reshape.58), dimensions={}
-  get-tuple-element.28 = u32[1]{0} get-tuple-element(custom-call.26), index=1
-  reshape.59 = u32[] reshape(get-tuple-element.28)
-  broadcast.63 = u32[32768]{0} broadcast(reshape.59), dimensions={}
-  iota.57 = u32[65536]{0} iota(), iota_dimension=0
-  slice.60 = u32[32768]{0} slice(iota.57), slice={[0:32768]}
-  slice.61 = u32[32768]{0} slice(iota.57), slice={[32768:65536]}
-  custom-call.64 = (u32[32768]{0}, u32[32768]{0}) custom-call(broadcast.62, broadcast.63, slice.60, slice.61), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[32768]{0}, u32[32768]{0}, u32[32768]{0}, u32[32768]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\200\000\000\000\000\000\000"
-  get-tuple-element.65 = u32[32768]{0} get-tuple-element(custom-call.64), index=0
-  get-tuple-element.66 = u32[32768]{0} get-tuple-element(custom-call.64), index=1
-  concatenate.67 = u32[65536]{0} concatenate(get-tuple-element.65, get-tuple-element.66), dimensions={0}
-  constant.15 = u32[] constant(9)
-  broadcast.3 = u32[65536]{0} broadcast(constant.15), dimensions={}
-  shift-right-logical.0 = u32[65536]{0} shift-right-logical(concatenate.67, broadcast.3)
-  constant.13 = u32[] constant(1065353216)
-  broadcast.11 = u32[65536]{0} broadcast(constant.13), dimensions={}
-  or.0 = u32[65536]{0} or(shift-right-logical.0, broadcast.11)
-  bitcast-convert.0 = f32[65536]{0} bitcast-convert(or.0)
-  constant.3 = f32[] constant(-1)
-  broadcast.17 = f32[65536]{0} broadcast(constant.3), dimensions={}
-  add.1 = f32[65536]{0} add(bitcast-convert.0, broadcast.17)
-  broadcast.18 = f32[65536]{0} broadcast(constant.9), dimensions={}
-  maximum.0 = f32[65536]{0} maximum(add.1, broadcast.18)
-  constant.7 = f32[] constant(0.9)
-  broadcast.19 = f32[65536]{0} broadcast(constant.7), dimensions={}
-  compare.0 = pred[65536]{0} compare(maximum.0, broadcast.19), direction=LT
-  constant = bf16[] constant(1.109)
-  broadcast.20 = bf16[65536]{0} broadcast(constant), dimensions={}
-  constant.4 = bf16[] constant(0)
-  broadcast.21 = bf16[65536]{0} broadcast(constant.4), dimensions={}
-  select.1 = bf16[65536]{0} select(compare.0, broadcast.20, broadcast.21)
-  reshape.19 = bf16[16,16,256]{2,1,0} reshape(select.1)
-  broadcast.9 = bf16[16,16,256,256]{3,2,1,0} broadcast(reshape.19), dimensions={0,1,3}
-  multiply.79 = bf16[16,16,256,256]{3,2,1,0} multiply(divide.56, broadcast.9)
-  dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.2, multiply.79), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  transpose.81 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
-  ROOT copy.3 = bf16[16,256,16,64]{3,2,1,0} copy(transpose.81)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Copy(m::Transpose(m::Transpose(m::GetTupleElement(
-                      m::CustomCall(
-                          &fmha, {kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget}),
-                      0))))
-              .WithShape(BF16, {16, 256, 16, 64})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_EQ(fmha->operands().size(), 4);
-  EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       BF16Bmm1ScaleBiasSoftmaxDropoutForm2Bmm2) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[32,40,60,64]{3,2,1,0},bf16[32,40,60,64]{3,2,1,0},bf16[32,40,60,64]{3,2,1,0})->bf16[32,40,60,64]{3,2,1,0}}, allow_spmd_sharding_propagation_to_output={true}
-
-region_0.29 {
-  Arg_0.30 = bf16[] parameter(0)
-  Arg_1.31 = bf16[] parameter(1)
-  ROOT maximum.32 = bf16[] maximum(Arg_0.30, Arg_1.31)
-}
-
-region_1.41 {
-  Arg_0.42 = f32[] parameter(0)
-  Arg_1.43 = f32[] parameter(1)
-  ROOT add.44 = f32[] add(Arg_0.42, Arg_1.43)
-}
-
-ENTRY main.79 {
-  Arg_2.3 = bf16[32,40,60,64]{3,2,1,0} parameter(2), sharding={replicated}
-  copy = bf16[32,40,60,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
-  transpose.2 = bf16[32,60,64,40]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
-  constant.19 = u32[1]{0} constant({2718843009})
-  constant.18 = u32[1]{0} constant({1272950319})
-  constant.2 = u32[1]{0} constant({0})
-  constant.20 = u32[1]{0} constant({2711844646})
-  custom-call.54 = (u32[1]{0}, u32[1]{0}) custom-call(constant.19, constant.18, constant.2, constant.20), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
-  get-tuple-element.55 = u32[1]{0} get-tuple-element(custom-call.54), index=0
-  reshape.58 = u32[] reshape(get-tuple-element.55)
-  broadcast.62 = u32[1536000]{0} broadcast(reshape.58), dimensions={}
-  get-tuple-element.56 = u32[1]{0} get-tuple-element(custom-call.54), index=1
-  reshape.59 = u32[] reshape(get-tuple-element.56)
-  broadcast.63 = u32[1536000]{0} broadcast(reshape.59), dimensions={}
-  iota.57 = u32[3072000]{0} iota(), iota_dimension=0
-  slice.60 = u32[1536000]{0} slice(iota.57), slice={[0:1536000]}
-  slice.61 = u32[1536000]{0} slice(iota.57), slice={[1536000:3072000]}
-  custom-call.64 = (u32[1536000]{0}, u32[1536000]{0}) custom-call(broadcast.62, broadcast.63, slice.60, slice.61), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1536000]{0}, u32[1536000]{0}, u32[1536000]{0}, u32[1536000]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000p\027\000\000\000\000\000"
-  get-tuple-element.65 = u32[1536000]{0} get-tuple-element(custom-call.64), index=0
-  get-tuple-element.66 = u32[1536000]{0} get-tuple-element(custom-call.64), index=1
-  concatenate.67 = u32[3072000]{0} concatenate(get-tuple-element.65, get-tuple-element.66), dimensions={0}
-  constant.16 = u32[] constant(9)
-  broadcast.2 = u32[3072000]{0} broadcast(constant.16), dimensions={}
-  shift-right-logical.0 = u32[3072000]{0} shift-right-logical(concatenate.67, broadcast.2)
-  constant.14 = u32[] constant(1065353216)
-  broadcast.6 = u32[3072000]{0} broadcast(constant.14), dimensions={}
-  or.0 = u32[3072000]{0} or(shift-right-logical.0, broadcast.6)
-  bitcast-convert.0 = f32[3072000]{0} bitcast-convert(or.0)
-  constant.3 = f32[] constant(-1)
-  broadcast.8 = f32[3072000]{0} broadcast(constant.3), dimensions={}
-  add.1 = f32[3072000]{0} add(bitcast-convert.0, broadcast.8)
-  constant.10 = f32[] constant(0)
-  broadcast.10 = f32[3072000]{0} broadcast(constant.10), dimensions={}
-  maximum.0 = f32[3072000]{0} maximum(add.1, broadcast.10)
-  constant.8 = f32[] constant(0.9)
-  broadcast.12 = f32[3072000]{0} broadcast(constant.8), dimensions={}
-  compare.0 = pred[3072000]{0} compare(maximum.0, broadcast.12), direction=LT
-  reshape.18 = pred[32,60,40,40]{3,2,1,0} reshape(compare.0)
-  Arg_0.1 = bf16[32,40,60,64]{3,2,1,0} parameter(0), sharding={replicated}
-  copy.1 = bf16[32,40,60,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
-  transpose = bf16[32,60,40,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
-  Arg_1.2 = bf16[32,40,60,64]{3,2,1,0} parameter(1), sharding={replicated}
-  copy.2 = bf16[32,40,60,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
-  transpose.1 = bf16[32,60,64,40]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
-  dot = bf16[32,60,40,40]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  constant.25 = bf16[] constant(1)
-  broadcast.26 = bf16[32,60,40,40]{3,2,1,0} broadcast(constant.25), dimensions={}
-  add.28 = bf16[32,60,40,40]{3,2,1,0} add(dot, broadcast.26)
-  constant.24 = bf16[] constant(-inf)
-  reduce.33 = bf16[32,60,40]{2,1,0} reduce(add.28, constant.24), dimensions={3}, to_apply=region_0.29
-  broadcast.37 = bf16[32,60,40,40]{3,2,1,0} broadcast(reduce.33), dimensions={0,1,2}
-  subtract.38 = bf16[32,60,40,40]{3,2,1,0} subtract(add.28, broadcast.37)
-  exponential.39 = bf16[32,60,40,40]{3,2,1,0} exponential(subtract.38)
-  convert.40 = f32[32,60,40,40]{3,2,1,0} convert(exponential.39)
-  reduce.45 = f32[32,60,40]{2,1,0} reduce(convert.40, constant.10), dimensions={3}, to_apply=region_1.41
-  convert.0 = bf16[32,60,40]{2,1,0} convert(reduce.45)
-  broadcast.50 = bf16[32,60,40,40]{3,2,1,0} broadcast(convert.0), dimensions={0,1,2}
-  divide.51 = bf16[32,60,40,40]{3,2,1,0} divide(exponential.39, broadcast.50)
-  constant = bf16[] constant(1.109)
-  broadcast.1 = bf16[32,60,40,40]{3,2,1,0} broadcast(constant), dimensions={}
-  multiply = bf16[32,60,40,40]{3,2,1,0} multiply(divide.51, broadcast.1)
-  constant.4 = bf16[] constant(0)
-  broadcast.5 = bf16[32,60,40,40]{3,2,1,0} broadcast(constant.4), dimensions={}
-  select.76 = bf16[32,60,40,40]{3,2,1,0} select(reshape.18, multiply, broadcast.5)
-  dot.1 = bf16[32,60,64,40]{3,2,1,0} dot(transpose.2, select.76), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  transpose.78 = bf16[32,40,60,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
-  ROOT copy.3 = bf16[32,40,60,64]{3,2,1,0} copy(transpose.78)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Copy(m::Transpose(m::Transpose(m::GetTupleElement(
-                      m::CustomCall(
-                          &fmha, {kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget}),
-                      0))))
-              .WithShape(BF16, {32, 40, 60, 64})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2);
-  EXPECT_EQ(fmha->operands().size(), 4);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16TrainingBmm1Bmm2) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0})->(bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0})}
-
-ENTRY main.17 {
-  Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
-  copy = bf16[16,256,16,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
-  transpose.2 = bf16[16,16,64,256]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
-  Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
-  copy.1 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
-  transpose = bf16[16,16,256,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
-  Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
-  copy.2 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
-  transpose.1 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
-  dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.2, dot), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  transpose.7 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
-  Arg_3.4 = bf16[16,256,16,64]{3,2,1,0} parameter(3), sharding={replicated}
-  copy.3 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_3.4), sharding={replicated}
-  transpose.4 = bf16[16,16,256,64]{3,2,1,0} transpose(copy.3), dimensions={0,2,1,3}
-  dot.2 = bf16[16,16,256,256]{3,2,1,0} dot(transpose.4, transpose.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  copy.4 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_1.2), sharding={replicated}
-  transpose.12 = bf16[16,16,256,64]{3,2,1,0} transpose(copy.4), dimensions={0,2,1,3}
-  dot.4 = bf16[16,16,256,64]{3,2,1,0} dot(dot.2, transpose.12), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  transpose.15 = bf16[16,256,16,64]{3,1,2,0} transpose(dot.4), dimensions={0,2,1,3}
-  dot.3 = bf16[16,16,256,64]{3,2,1,0} dot(dot.2, transpose), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  transpose.13 = bf16[16,256,16,64]{3,1,2,0} transpose(dot.3), dimensions={0,2,1,3}
-  copy.5 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_3.4), sharding={replicated}
-  transpose.8 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.5), dimensions={0,2,3,1}
-  dot.10 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.8, dot), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  transpose.11 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.10), dimensions={0,3,1,2}
-  tuple.16 = (bf16[16,256,16,64]{1,3,2,0}, bf16[16,256,16,64]{3,1,2,0}, bf16[16,256,16,64]{3,1,2,0}, bf16[16,256,16,64]{1,3,2,0}) tuple(transpose.7, transpose.15, transpose.13, transpose.11)
-  get-tuple-element = bf16[16,256,16,64]{1,3,2,0} get-tuple-element(tuple.16), index=0
-  copy.6 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element)
-  get-tuple-element.1 = bf16[16,256,16,64]{3,1,2,0} get-tuple-element(tuple.16), index=1
-  copy.7 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.1)
-  get-tuple-element.2 = bf16[16,256,16,64]{3,1,2,0} get-tuple-element(tuple.16), index=2
-  copy.8 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.2)
-  get-tuple-element.3 = bf16[16,256,16,64]{1,3,2,0} get-tuple-element(tuple.16), index=3
-  copy.9 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.3)
-  ROOT tuple = (bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}) tuple(copy.6, copy.7, copy.8, copy.9)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK_AND_ASSIGN(bool changed,
-                          RunHloPass(&fusedMhaRewriter, m.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16MiniT5xTest) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__lambda_, entry_computation_layout={(bf16[12,512,32,64]{3,2,1,0},bf16[12,512,2,32,64]{4,3,2,1,0},f32[12,512]{1,0},f32[12,512]{1,0})->(bf16[], bf16[12,512,32,64]{3,2,1,0}, bf16[12,512,2,32,64]{4,3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true}
-
-region_0.51 {
-  Arg_0.52 = bf16[] parameter(0)
-  Arg_1.53 = bf16[] parameter(1)
-  ROOT maximum.54 = bf16[] maximum(Arg_0.52, Arg_1.53)
-}
-
-region_1.63 {
-  Arg_0.64 = f32[] parameter(0)
-  Arg_1.65 = f32[] parameter(1)
-  ROOT add.66 = f32[] add(Arg_0.64, Arg_1.65)
-}
-
-region_3.99 {
-  Arg_0.100 = bf16[] parameter(0)
-  Arg_1.101 = bf16[] parameter(1)
-  ROOT add.102 = bf16[] add(Arg_0.100, Arg_1.101)
-}
-
-ENTRY main.129 {
-  Arg_1.2 = bf16[12,512,2,32,64]{4,3,2,1,0} parameter(1), sharding={replicated}
-  copy = bf16[12,512,2,32,64]{1,4,3,0,2} copy(Arg_1.2), sharding={replicated}
-  slice.42 = bf16[12,512,1,32,64]{1,4,3,0,2} slice(copy), slice={[0:12], [0:512], [1:2], [0:32], [0:64]}
-  reshape.44 = bf16[12,512,32,64]{1,3,2,0} reshape(slice.42)
-  transpose.5 = bf16[12,32,64,512]{3,2,1,0} transpose(reshape.44), dimensions={0,2,3,1}
-  Arg_0.1 = bf16[12,512,32,64]{3,2,1,0} parameter(0), sharding={replicated}
-  copy.1 = bf16[12,512,32,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
-  constant.5 = bf16[] constant(0.125)
-  broadcast.6 = bf16[12,512,32,64]{3,1,2,0} broadcast(constant.5), dimensions={}
-  multiply.45 = bf16[12,512,32,64]{3,1,2,0} multiply(copy.1, broadcast.6)
-  transpose = bf16[12,32,512,64]{3,2,1,0} transpose(multiply.45), dimensions={0,2,1,3}
-  copy.2 = bf16[12,512,2,32,64]{1,4,3,0,2} copy(Arg_1.2), sharding={replicated}
-  slice.41 = bf16[12,512,1,32,64]{1,4,3,0,2} slice(copy.2), slice={[0:12], [0:512], [0:1], [0:32], [0:64]}
-  reshape.43 = bf16[12,512,32,64]{1,3,2,0} reshape(slice.41)
-  transpose.1 = bf16[12,32,64,512]{3,2,1,0} transpose(reshape.43), dimensions={0,2,3,1}
-  dot = bf16[12,32,512,512]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  Arg_2.3 = f32[12,512]{1,0} parameter(2), sharding={replicated}
-  constant.14 = f32[] constant(0)
-  broadcast.19 = f32[12,512]{1,0} broadcast(constant.14), dimensions={}
-  compare.24 = pred[12,512]{1,0} compare(Arg_2.3, broadcast.19), direction=GT
-  broadcast.30 = pred[12,512,512]{2,1,0} broadcast(compare.24), dimensions={0,1}
-  Arg_3.4 = f32[12,512]{1,0} parameter(3), sharding={replicated}
-  compare.25 = pred[12,512]{1,0} compare(Arg_3.4, broadcast.19), direction=GT
-  broadcast.33 = pred[12,512,512]{2,1,0} broadcast(compare.25), dimensions={0,2}
-  and.34 = pred[12,512,512]{2,1,0} and(broadcast.30, broadcast.33)
-  convert.4 = s32[12,512,512]{2,1,0} convert(and.34)
-  constant.16 = s32[] constant(0)
-  broadcast.21 = s32[12,512,512]{2,1,0} broadcast(constant.16), dimensions={}
-  compare.0 = pred[12,512,512]{2,1,0} compare(convert.4, broadcast.21), direction=GT
-  constant.20 = bf16[] constant(0)
-  broadcast.22 = bf16[12,512,512]{2,1,0} broadcast(constant.20), dimensions={}
-  constant.11 = bf16[] constant(-9.999e+09)
-  broadcast.23 = bf16[12,512,512]{2,1,0} broadcast(constant.11), dimensions={}
-  select.0 = bf16[12,512,512]{2,1,0} select(compare.0, broadcast.22, broadcast.23)
-  broadcast.49 = bf16[12,32,512,512]{3,2,1,0} broadcast(select.0), dimensions={0,2,3}
-  add.50 = bf16[12,32,512,512]{3,2,1,0} add(dot, broadcast.49)
-  constant.22 = bf16[] constant(-inf)
-  reduce.55 = bf16[12,32,512]{2,1,0} reduce(add.50, constant.22), dimensions={3}, to_apply=region_0.51
-  broadcast.59 = bf16[12,32,512,512]{3,2,1,0} broadcast(reduce.55), dimensions={0,1,2}
-  subtract.60 = bf16[12,32,512,512]{3,2,1,0} subtract(add.50, broadcast.59)
-  exponential.61 = bf16[12,32,512,512]{3,2,1,0} exponential(subtract.60)
-  convert.62 = f32[12,32,512,512]{3,2,1,0} convert(exponential.61)
-  reduce.67 = f32[12,32,512]{2,1,0} reduce(convert.62, constant.14), dimensions={3}, to_apply=region_1.63
-  convert.5 = bf16[12,32,512]{2,1,0} convert(reduce.67)
-  broadcast.72 = bf16[12,32,512,512]{3,2,1,0} broadcast(convert.5), dimensions={0,1,2}
-  divide.73 = bf16[12,32,512,512]{3,2,1,0} divide(exponential.61, broadcast.72)
-  dot.1 = bf16[12,32,64,512]{3,2,1,0} dot(transpose.5, divide.73), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  convert.6 = f32[12,32,64,512]{3,2,1,0} convert(dot.1)
-  reduce.83 = f32[] reduce(convert.6, constant.14), dimensions={0,3,1,2}, to_apply=region_1.63
-  convert.84 = bf16[] convert(reduce.83)
-  constant.2 = bf16[] constant(0.0007935)
-  multiply.86 = bf16[] multiply(convert.84, constant.2)
-  broadcast.9 = bf16[12,32,512,64]{3,2,1,0} broadcast(constant.2), dimensions={}
-  dot.2 = bf16[12,32,512,512]{3,2,1,0} dot(broadcast.9, transpose.5), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  divide.109 = bf16[12,32,512,512]{3,2,1,0} divide(dot.2, broadcast.72)
-  constant.10 = bf16[] constant(1)
-  broadcast.24 = bf16[12,32,512]{2,1,0} broadcast(constant.10), dimensions={}
-  multiply.4 = bf16[12,32,512]{2,1,0} multiply(convert.5, convert.5)
-  divide.0 = bf16[12,32,512]{2,1,0} divide(broadcast.24, multiply.4)
-  broadcast.96 = bf16[12,32,512,512]{3,2,1,0} broadcast(divide.0), dimensions={0,1,2}
-  multiply.97 = bf16[12,32,512,512]{3,2,1,0} multiply(dot.2, broadcast.96)
-  multiply.98 = bf16[12,32,512,512]{3,2,1,0} multiply(multiply.97, exponential.61)
-  reduce.103 = bf16[12,32,512]{2,1,0} reduce(multiply.98, constant.20), dimensions={3}, to_apply=region_3.99
-  negate.0 = bf16[12,32,512]{2,1,0} negate(reduce.103)
-  broadcast.10 = bf16[12,32,512,512]{3,2,1,0} broadcast(negate.0), dimensions={0,1,2}
-  add.118 = bf16[12,32,512,512]{3,2,1,0} add(divide.109, broadcast.10)
-  multiply.119 = bf16[12,32,512,512]{3,2,1,0} multiply(add.118, exponential.61)
-  transpose.9 = bf16[12,32,512,64]{2,3,1,0} transpose(reshape.43), dimensions={0,2,1,3}
-  copy.3 = bf16[12,32,512,64]{3,2,1,0} copy(transpose.9)
-  dot.4 = bf16[12,32,512,64]{3,2,1,0} dot(multiply.119, copy.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  broadcast.12 = bf16[12,32,512,64]{3,2,1,0} broadcast(constant.5), dimensions={}
-  multiply.3 = bf16[12,32,512,64]{3,2,1,0} multiply(dot.4, broadcast.12)
-  transpose.11 = bf16[12,512,32,64]{3,1,2,0} transpose(multiply.3), dimensions={0,2,1,3}
-  broadcast.7 = bf16[12,32,64,512]{3,2,1,0} broadcast(constant.2), dimensions={}
-  dot.90 = bf16[12,32,64,512]{3,2,1,0} dot(broadcast.7, divide.73), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  transpose.91 = bf16[12,512,32,64]{1,3,2,0} transpose(dot.90), dimensions={0,3,1,2}
-  reshape.92 = bf16[12,512,1,32,64]{1,4,3,0,2} reshape(transpose.91)
-  pad.93 = bf16[12,512,2,32,64]{1,4,3,0,2} pad(reshape.92, constant.20), padding=0_0x0_0x1_0x0_0x0_0
-  dot.3 = bf16[12,32,512,64]{3,2,1,0} dot(multiply.119, transpose), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  copy.4 = bf16[12,32,512,64]{2,3,1,0} copy(dot.3)
-  transpose.121 = bf16[12,512,32,64]{1,3,2,0} transpose(copy.4), dimensions={0,2,1,3}
-  reshape.124 = bf16[12,512,1,32,64]{1,4,3,0,2} reshape(transpose.121)
-  pad.125 = bf16[12,512,2,32,64]{1,4,3,0,2} pad(reshape.124, constant.20), padding=0_0x0_0x0_1x0_0x0_0
-  add.126 = bf16[12,512,2,32,64]{1,4,3,0,2} add(pad.93, pad.125)
-  tuple.128 = (bf16[], bf16[12,512,32,64]{3,1,2,0}, bf16[12,512,2,32,64]{1,4,3,0,2}) tuple(multiply.86, transpose.11, add.126)
-  get-tuple-element = bf16[] get-tuple-element(tuple.128), index=0
-  get-tuple-element.1 = bf16[12,512,32,64]{3,1,2,0} get-tuple-element(tuple.128), index=1
-  copy.5 = bf16[12,512,32,64]{3,2,1,0} copy(get-tuple-element.1)
-  get-tuple-element.2 = bf16[12,512,2,32,64]{1,4,3,0,2} get-tuple-element(tuple.128), index=2
-  copy.6 = bf16[12,512,2,32,64]{4,3,2,1,0} copy(get-tuple-element.2)
-  ROOT tuple = (bf16[], bf16[12,512,32,64]{3,2,1,0}, bf16[12,512,2,32,64]{4,3,2,1,0}) tuple(get-tuple-element, copy.5, copy.6)
-}
-
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  AlgebraicSimplifierOptions alg_sim_options;
-  alg_sim_options.set_supports_non_canonical_dots(false);
-  alg_sim_options.set_is_layout_sensitive(true);
-  alg_sim_options.set_enable_conv_operand_swap(false);
-  AlgebraicSimplifier alge_simp{alg_sim_options};
-  ReshapeDecomposer reshape_decomposer;
-  LayoutNormalization layout_normalizer;
-  HloCSE cse{/*is_layout_sensitive=*/true};
-  TF_ASSERT_OK(RunHloPass(&reshape_decomposer, m.get()).status());
-  TF_ASSERT_OK(RunHloPass(&layout_normalizer, m.get()).status());
-  TF_ASSERT_OK(RunHloPass(&cse, m.get()).status());
-  TF_ASSERT_OK(RunHloPass(&alge_simp, m.get()).status());
-
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-
-  CudnnFusedMHATransposeFusion fmha_transpose_fusion;
-
-  HloDCE dce;
-  TF_ASSERT_OK(RunHloPass(&alge_simp, m.get()).status());
-  TF_ASSERT_OK(RunHloPass(&fmha_transpose_fusion, m.get()).status());
-
-  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
-  EXPECT_EQ(CountFusedAttentionCall(m.get()), 1);
-  EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward=*/true), 1);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       ActivationHasMoreThan1UserShouldNotLower) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule test
-
-%region_50.2457 (Arg_0.2458: bf16[], Arg_1.2459: bf16[]) -> bf16[] {
-  %Arg_0.2458 = bf16[] parameter(0)
-  %Arg_1.2459 = bf16[] parameter(1)
-  ROOT %maximum.2 = bf16[] maximum(bf16[] %Arg_0.2458, bf16[] %Arg_1.2459)
-}
-
-%region_36.2316 (Arg_0.2317: f32[], Arg_1.2318: f32[]) -> f32[] {
-  %Arg_0.2317 = f32[] parameter(0)
-  %Arg_1.2318 = f32[] parameter(1)
-  ROOT %add.342 = f32[] add(f32[] %Arg_0.2317, f32[] %Arg_1.2318)
-}
-
-ENTRY main {
-  %transpose.482 = bf16[4,5,64]{2,1,0} parameter(0)
-  %transpose.484 = bf16[4,64,5]{2,1,0} parameter(1)
-  %dot.20 = bf16[4,5,5]{2,1,0} dot(bf16[4,5,64]{2,1,0} %transpose.482, bf16[4,64,5]{2,1,0} %transpose.484), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-  %constant.2515 = bf16[] constant(0.125)
-  %broadcast.789 = bf16[4,5,5]{2,1,0} broadcast(bf16[] %constant.2515), dimensions={}
-  %multiply.267 = bf16[4,5,5]{2,1,0} multiply(bf16[4,5,5]{2,1,0} %dot.20, bf16[4,5,5]{2,1,0} %broadcast.789)
-  %constant.287 = f32[] constant(-1)
-  %broadcast.792 = bf16[4,5,5]{2,1,0} parameter(3)
-  %add.348 = bf16[4,5,5]{2,1,0} add(bf16[4,5,5]{2,1,0} %multiply.267, bf16[4,5,5]{2,1,0} %broadcast.792)
-  %constant.2510 = bf16[] constant(-inf)
-  %reduce.2550 = bf16[4,5]{1,0} reduce(bf16[4,5,5]{2,1,0} %add.348, bf16[] %constant.2510), dimensions={2}, to_apply=%region_50.2457
-  %broadcast.793 = bf16[4,5,5]{2,1,0} broadcast(bf16[4,5]{1,0} %reduce.2550), dimensions={0,1}
-  %subtract.81 = bf16[4,5,5]{2,1,0} subtract(bf16[4,5,5]{2,1,0} %add.348, bf16[4,5,5]{2,1,0} %broadcast.793)
-  %exponential.21 = bf16[4,5,5]{2,1,0} exponential(bf16[4,5,5]{2,1,0} %subtract.81)
-  %convert.180 = f32[4,5,5]{2,1,0} convert(bf16[4,5,5]{2,1,0} %exponential.21)
-  %constant.2509 = f32[] constant(0)
-  %reduce.2558 = f32[4,5]{1,0} reduce(f32[4,5,5]{2,1,0} %convert.180, f32[] %constant.2509), dimensions={2}, to_apply=%region_36.2316
-  %convert.182 = bf16[4,5]{1,0} convert(f32[4,5]{1,0} %reduce.2558)
-  %broadcast.794 = bf16[4,5,5]{2,1,0} broadcast(bf16[4,5]{1,0} %convert.182), dimensions={0,1}
-  %divide.25 = bf16[4,5,5]{2,1,0} divide(bf16[4,5,5]{2,1,0} %exponential.21, bf16[4,5,5]{2,1,0} %broadcast.794)
-  %transpose.481 = bf16[4,64,5]{2,1,0} parameter(2)
-  %dot.21 = bf16[4,64,5]{2,1,0} dot(bf16[4,64,5]{2,1,0} %transpose.481, bf16[4,5,5]{2,1,0} %divide.25), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}
-  ROOT %tuple.2668 = (bf16[4,5,5]{2,1,0}, bf16[4,64,5]{2,1,0}) tuple(bf16[4,5,5]{2,1,0} %divide.25, bf16[4,64,5]{2,1,0} %dot.21)
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  HloDCE dce;
-  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
-  HloVerifier verifier(/*layout_sensitive=*/false,
-                       /*allow_mixed_precision=*/true);
-  ASSERT_IS_OK(verifier.Run(m.get()).status());
-
-  EXPECT_EQ(CountFusedAttentionCall(m.get()), 0);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       F16InvalidTrainingBmm1ScaleBiasMaskSoftmaxBmm2ShouldNotBeLowered) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-
-region_0.21 {
-  Arg_0.22 = f16[] parameter(0)
-  Arg_1.23 = f16[] parameter(1)
-  ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
-}
-
-region_1.33 {
-  Arg_0.34 = f32[] parameter(0)
-  Arg_1.35 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.34, Arg_1.35)
-}
-
-region_2.55 {
-  Arg_0.56 = f16[] parameter(0)
-  Arg_1.57 = f16[] parameter(1)
-  ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
-}
-
-ENTRY main.82 {
-  constant.18 = pred[2,6,128,128]{3,2,1,0} constant({...})
-  Arg_0.1 = f16[2,6,128,64]{3,2,1,0} parameter(0), sharding={replicated}
-  Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
-  dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  constant.22 = f16[] constant(2)
-  broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
-  multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
-  constant.19 = f16[] constant(1)
-  broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
-  add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
-  constant.21 = f16[] constant(0)
-  broadcast.23 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.21), dimensions={}
-  select.1 = f16[2,6,128,128]{3,2,1,0} select(constant.18, add.3, broadcast.23)
-  constant.15 = f16[] constant(-inf)
-  reduce.25 = f16[2,6,128]{2,1,0} reduce(select.1, constant.15), dimensions={3}, to_apply=region_0.21
-  broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
-  subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(select.1, broadcast.17)
-  exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
-  convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
-  constant.17 = f32[] constant(0)
-  reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
-  convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
-  broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
-  divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
-  Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
-  dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
-  dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
-  broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
-  multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
-  divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
-  broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
-  multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
-  multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
-  reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
-  broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.59), dimensions={0,1,2}
-  add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
-  multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
-  select.3 = f16[2,6,128,128]{3,2,1,0} select(constant.18, multiply.8, broadcast.23)
-  multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(select.3, broadcast.24)
-  dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  HloDCE dce;
-  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  HloVerifier verifier(/*layout_sensitive=*/false,
-                       /*allow_mixed_precision=*/true);
-  ASSERT_IS_OK(verifier.Run(m.get()).status());
-
-  // The backward pattern in the graph is not a valid fmha pattern,
-  // so we expect no rewrite to happen.
-  EXPECT_EQ(CountFusedAttentionCall(m.get()), 0);
-  EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward=*/true), 0);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       F16InvalidTrainingBmm1ScaleBiasMaskSoftmaxDropoutBmm2ShouldNotLower) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-
-region_0.38 {
-  Arg_0.39 = f16[] parameter(0)
-  Arg_1.40 = f16[] parameter(1)
-  ROOT maximum.1 = f16[] maximum(Arg_0.39, Arg_1.40)
-}
-
-region_1.50 {
-  Arg_0.51 = f32[] parameter(0)
-  Arg_1.52 = f32[] parameter(1)
-  ROOT add.2 = f32[] add(Arg_0.51, Arg_1.52)
-}
-
-region_2.99 {
-  Arg_0.100 = f16[] parameter(0)
-  Arg_1.101 = f16[] parameter(1)
-  ROOT add.3 = f16[] add(Arg_0.100, Arg_1.101)
-}
-
-ENTRY main.126 {
-  constant.6 = u32[1]{0} constant({2718843009})
-  constant.8 = u32[1]{0} constant({1272950319})
-  constant.10 = u32[1]{0} constant({0})
-  constant.12 = u32[1]{0} constant({2711844646})
-  custom-call.65 = (u32[1]{0}, u32[1]{0}) custom-call(constant.6, constant.8, constant.10, constant.12), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
-  get-tuple-element.66 = u32[1]{0} get-tuple-element(custom-call.65), index=0
-  bitcast.343 = u32[] bitcast(get-tuple-element.66)
-  broadcast.27 = u32[98304]{0} broadcast(bitcast.343), dimensions={}
-  get-tuple-element.67 = u32[1]{0} get-tuple-element(custom-call.65), index=1
-  bitcast.344 = u32[] bitcast(get-tuple-element.67)
-  broadcast.28 = u32[98304]{0} broadcast(bitcast.344), dimensions={}
-  iota.68 = u32[196608]{0} iota(), iota_dimension=0
-  slice = u32[98304]{0} slice(iota.68), slice={[0:98304]}
-  slice.1 = u32[98304]{0} slice(iota.68), slice={[98304:196608]}
-  custom-call.75 = (u32[98304]{0}, u32[98304]{0}) custom-call(broadcast.27, broadcast.28, slice, slice.1), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[98304]{0}, u32[98304]{0}, u32[98304]{0}, u32[98304]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\200\001\000\000\000\000\000"
-  get-tuple-element.76 = u32[98304]{0} get-tuple-element(custom-call.75), index=0
-  get-tuple-element.77 = u32[98304]{0} get-tuple-element(custom-call.75), index=1
-  concatenate.2 = u32[196608]{0} concatenate(get-tuple-element.76, get-tuple-element.77), dimensions={0}
-  constant.56 = u32[] constant(9)
-  broadcast.63 = u32[196608]{0} broadcast(constant.56), dimensions={}
-  shift-right-logical.3 = u32[196608]{0} shift-right-logical(concatenate.2, broadcast.63)
-  constant.57 = u32[] constant(1065353216)
-  broadcast.64 = u32[196608]{0} broadcast(constant.57), dimensions={}
-  or.3 = u32[196608]{0} or(shift-right-logical.3, broadcast.64)
-  bitcast-convert.3 = f32[196608]{0} bitcast-convert(or.3)
-  constant.58 = f32[] constant(-1)
-  broadcast.65 = f32[196608]{0} broadcast(constant.58), dimensions={}
-  add.10 = f32[196608]{0} add(bitcast-convert.3, broadcast.65)
-  constant.48 = f32[] constant(0)
-  broadcast.66 = f32[196608]{0} broadcast(constant.48), dimensions={}
-  maximum.4 = f32[196608]{0} maximum(add.10, broadcast.66)
-  constant.59 = f32[] constant(0.9)
-  broadcast.67 = f32[196608]{0} broadcast(constant.59), dimensions={}
-  compare.3 = pred[196608]{0} compare(maximum.4, broadcast.67), direction=LT
-  bitcast.308 = pred[2,6,128,128]{3,2,1,0} bitcast(compare.3)
-  constant.44 = pred[2,6,128,128]{3,2,1,0} constant({...})
-  Arg_0.1 = f16[2,6,128,64]{3,2,1,0} parameter(0), sharding={replicated}
-  Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
-  dot.34 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  constant.55 = f16[] constant(2)
-  broadcast.61 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.55), dimensions={}
-  multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(dot.34, broadcast.61)
-  constant.52 = f16[] constant(1)
-  broadcast.39 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.52), dimensions={}
-  add.6 = f16[2,6,128,128]{3,2,1,0} add(multiply.8, broadcast.39)
-  constant.54 = f16[] constant(0)
-  broadcast.52 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.54), dimensions={}
-  select.1 = f16[2,6,128,128]{3,2,1,0} select(constant.44, add.6, broadcast.52)
-  constant.41 = f16[] constant(-inf)
-  reduce.42 = f16[2,6,128]{2,1,0} reduce(select.1, constant.41), dimensions={3}, to_apply=region_0.38
-  broadcast.42 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.42), dimensions={0,1,2}
-  subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(select.1, broadcast.42)
-  exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
-  convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
-  reduce.54 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.48), dimensions={3}, to_apply=region_1.50
-  convert.9 = f16[2,6,128]{2,1,0} convert(reduce.54)
-  broadcast.68 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
-  divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.68)
-  constant.60 = f16[] constant(1.1113)
-  broadcast.69 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.60), dimensions={}
-  multiply.20 = f16[2,6,128,128]{3,2,1,0} multiply(divide.5, broadcast.69)
-  select.8 = f16[2,6,128,128]{3,2,1,0} select(bitcast.308, multiply.20, broadcast.52)
-  Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
-  dot.88 = f16[2,6,128,64]{3,2,1,0} dot(select.8, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  bitcast.248 = pred[2,6,128,128]{3,2,1,0} bitcast(compare.3)
-  Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
-  dot.91 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  select.6 = f16[2,6,128,128]{3,2,1,0} select(bitcast.248, dot.91, broadcast.52)
-  multiply.17 = f16[2,6,128,128]{3,2,1,0} multiply(select.6, broadcast.69)
-  divide.4 = f16[2,6,128,128]{3,2,1,0} divide(multiply.17, broadcast.68)
-  broadcast.55 = f16[2,6,128]{2,1,0} broadcast(constant.52), dimensions={}
-  multiply.11 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
-  divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.55, multiply.11)
-  broadcast.56 = f16[2,6,128]{2,1,0} broadcast(constant.60), dimensions={}
-  multiply.12 = f16[2,6,128]{2,1,0} multiply(divide.3, broadcast.56)
-  broadcast.58 = f16[2,6,128,128]{3,2,1,0} broadcast(multiply.12), dimensions={0,1,2}
-  multiply.13 = f16[2,6,128,128]{3,2,1,0} multiply(select.6, broadcast.58)
-  multiply.14 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.13, exponential.1)
-  reduce.103 = f16[2,6,128]{2,1,0} reduce(multiply.14, constant.54), dimensions={3}, to_apply=region_2.99
-  broadcast.62 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.103), dimensions={0,1,2}
-  add.9 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.62)
-  multiply.18 = f16[2,6,128,128]{3,2,1,0} multiply(add.9, exponential.1)
-  select.7 = f16[2,6,128,128]{3,2,1,0} select(constant.44, multiply.18, broadcast.52)
-  multiply.19 = f16[2,6,128,128]{3,2,1,0} multiply(select.7, broadcast.61)
-  dot.124 = f16[2,6,128,64]{3,2,1,0} dot(multiply.19, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.19), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  dot.1 = f16[2,6,128,64]{3,2,1,0} dot(select.8, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  ROOT tuple.125 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.88, dot.124, dot, dot.1)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  HloDCE dce;
-  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  HloVerifier verifier(/*layout_sensitive=*/false,
-                       /*allow_mixed_precision=*/true);
-  ASSERT_IS_OK(verifier.Run(m.get()).status());
-
-  // The backward pattern in the graph is not a valid fmha pattern,
-  // so we expect no rewrite to happen.
-  EXPECT_EQ(CountFusedAttentionCall(m.get()), 0);
-  EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward=*/true), 0);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       F16TrainingBmm1ScaleBiasSoftmaxBmm2QTranspose) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-
-region_0.21 {
-  Arg_0.22 = f16[] parameter(0)
-  Arg_1.23 = f16[] parameter(1)
-  ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
-}
-
-region_1.33 {
-  Arg_0.34 = f32[] parameter(0)
-  Arg_1.35 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.34, Arg_1.35)
-}
-
-region_2.55 {
-  Arg_0.56 = f16[] parameter(0)
-  Arg_1.57 = f16[] parameter(1)
-  ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
-}
-
-ENTRY main.82 {
-  Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated}
-  Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
-  dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  constant.22 = f16[] constant(2)
-  broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
-  multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
-  constant.19 = f16[] constant(1)
-  broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
-  add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
-  constant.21 = f16[] constant(0)
-  constant.15 = f16[] constant(-inf)
-  reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
-  broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
-  subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17)
-  exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
-  convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
-  constant.17 = f32[] constant(0)
-  reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
-  convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
-  broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
-  divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
-  Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
-  dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
-  dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
-  broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
-  multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
-  divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
-  broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
-  multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
-  multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
-  reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
-  negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59)
-  broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
-  add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
-  multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
-  multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24)
-  dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  HloDCE dce;
-  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(
-              m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0)
-              .WithShape(F16, {2, 6, 128, 64}),
-          m::GetTupleElement(
-              m::CustomCall(&fmha,
-                            {kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
-              0)
-              .WithShape(F16, {2, 6, 128, 64}),
-          m::Transpose(
-              m::GetTupleElement(
-                  m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
-                  1))
-              .WithShape(F16, {2, 6, 64, 128}),
-          m::GetTupleElement(
-              m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), 2)
-              .WithShape(F16, {2, 6, 128, 64}))));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_EQ(fmha->operands().size(), 7);
-  EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       F16Bmm1UnfusedSoftmaxBmm2IncorrectBmm1NumUsers) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->(f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0})}
-
-region_0.7 {
-  Arg_0.8 = f16[] parameter(0)
-  Arg_1.9 = f16[] parameter(1)
-  ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9)
-}
-
-region_1.19 {
-  Arg_0.20 = f32[] parameter(0)
-  Arg_1.21 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.20, Arg_1.21)
-}
-
-ENTRY main.31 {
-  Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated}
-  Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated}
-  dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-  // extra user of bmm1
-  neg.1 = f16[2,6,40,40]{3,2,1,0} negate(dot)
-  constant = f16[] constant(-inf)
-  reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7
-  broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
-  subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3)
-  exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1)
-  convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1)
-  constant.1 = f32[] constant(0)
-  reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
-  convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23)
-  broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
-  divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4)
-  Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated}
-  dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-  ROOT tuple.81 = (f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0}) tuple(dot.1, neg.1)
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(m::Dot(), m::Negate())));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       F16Bmm1UnfusedSoftmaxBmm2IncorrectSoftmaxNumUsers) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->(f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0})}
-
-region_0.7 {
-  Arg_0.8 = f16[] parameter(0)
-  Arg_1.9 = f16[] parameter(1)
-  ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9)
-}
-
-region_1.19 {
-  Arg_0.20 = f32[] parameter(0)
-  Arg_1.21 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.20, Arg_1.21)
-}
-
-ENTRY main.31 {
-  Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated}
-  Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated}
-  dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-  constant = f16[] constant(-inf)
-  reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7
-  broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
-  subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3)
-  // extra user of the softmax subtract node
-  neg.1 = f16[2,6,40,40]{3,2,1,0} negate(subtract.1)
-  exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1)
-  convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1)
-  constant.1 = f32[] constant(0)
-  reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
-  convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23)
-  broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
-  divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4)
-  Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated}
-  dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-  ROOT tuple.81 = (f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0}) tuple(dot.1, neg.1)
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(m::Dot(), m::Negate())));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       F16TrainingBmm1ScaleBiasSoftmaxBmm2IncorrectSoftmaxBwdNumUsers) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-
-region_0.21 {
-  Arg_0.22 = f16[] parameter(0)
-  Arg_1.23 = f16[] parameter(1)
-  ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
-}
-
-region_1.33 {
-  Arg_0.34 = f32[] parameter(0)
-  Arg_1.35 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.34, Arg_1.35)
-}
-
-region_2.55 {
-  Arg_0.56 = f16[] parameter(0)
-  Arg_1.57 = f16[] parameter(1)
-  ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
-}
-
-ENTRY main.82 {
-  Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated}
-  Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
-  dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  constant.22 = f16[] constant(2)
-  broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
-  multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
-  constant.19 = f16[] constant(1)
-  broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
-  add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
-  constant.21 = f16[] constant(0)
-  constant.15 = f16[] constant(-inf)
-  reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
-  broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
-  subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17)
-  exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
-  convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
-  constant.17 = f32[] constant(0)
-  reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
-  convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
-  broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
-  divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
-  Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
-  dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
-  dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
-  // extra user of the softmax backward divide node
-  neg.1 = f16[2,6,128,128]{3,2,1,0} negate(divide.4)
-  broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
-  multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
-  divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
-  broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
-  multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
-  multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
-  reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
-  negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59)
-  broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
-  add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
-  multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
-  multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24)
-  dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1, neg.1)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  HloDCE dce;
-  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot(),
-                                  m::Negate())));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1SoftmaxBmm2IncorrectRank) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule reproducer, entry_computation_layout={(f16[1,8,16,5,128]{4,3,2,1,0}, f16[1,8,16,5,128]{4,3,2,1,0}, f16[1,8,16,5,128]{4,3,2,1,0}, f32[128,2,64]{2,1,0}, f32[2,64]{1,0}, /*index=5*/f32[128,2,64]{2,1,0}, f32[2,64]{1,0}, f32[128,2,64]{2,1,0}, f32[2,64]{1,0})->f16[8,16,2,5,64]{4,3,2,1,0}}
-
-region_0.36 {
-  Arg_0.37 = f16[] parameter(0)
-  Arg_1.38 = f16[] parameter(1)
-  ROOT maximum = f16[] maximum(Arg_0.37, Arg_1.38)
-}
-
-region_1.48 {
-  Arg_0.49 = f32[] parameter(0)
-  Arg_1.50 = f32[] parameter(1)
-  ROOT add.1 = f32[] add(Arg_0.49, Arg_1.50)
-}
-
-ENTRY main {
-  arg2.3 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(2), parameter_replication={false}
-  bitcast.31 = f16[640,128]{1,0} bitcast(arg2.3)
-  arg5.6 = f32[128,2,64]{2,1,0} parameter(5), parameter_replication={false}
-  convert.3 = f16[128,2,64]{2,1,0} convert(arg5.6)
-  bitcast.36 = f16[128,128]{1,0} bitcast(convert.3)
-  dot = f16[640,128]{1,0} dot(bitcast.31, bitcast.36), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}
-  bitcast.39 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot)
-  transpose.27 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} transpose(bitcast.39), dimensions={0,1,2,4,3,5}, frontend_attributes={grad_x="false",grad_y="false"}
-  arg6.7 = f32[2,64]{1,0} parameter(6), parameter_replication={false}
-  convert.4 = f16[2,64]{1,0} convert(arg6.7)
-  broadcast.9 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} broadcast(convert.4), dimensions={3,5}
-  add.2 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} add(transpose.27, broadcast.9)
-  bitcast.49 = f16[8,16,2,5,64]{4,3,2,1,0} bitcast(add.2)
-  arg0.1 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(0), parameter_replication={false}
-  bitcast.53 = f16[640,128]{1,0} bitcast(arg0.1)
-  arg3.4 = f32[128,2,64]{2,1,0} parameter(3), parameter_replication={false}
-  convert.5 = f16[128,2,64]{2,1,0} convert(arg3.4)
-  bitcast.58 = f16[128,128]{1,0} bitcast(convert.5)
-  dot.1 = f16[640,128]{1,0} dot(bitcast.53, bitcast.58), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}
-  bitcast.61 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot.1)
-  transpose.28 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} transpose(bitcast.61), dimensions={0,1,2,4,5,3}, frontend_attributes={grad_x="false",grad_y="false"}
-  arg4.5 = f32[2,64]{1,0} parameter(4), parameter_replication={false}
-  convert.6 = f16[2,64]{1,0} convert(arg4.5)
-  broadcast.10 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} broadcast(convert.6), dimensions={3,4}
-  add.3 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} add(transpose.28, broadcast.10)
-  constant.29 = f16[] constant(0.125)
-  broadcast.11 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} broadcast(constant.29), dimensions={}
-  multiply = f16[1,8,16,2,64,5]{5,4,3,2,1,0} multiply(add.3, broadcast.11)
-  bitcast.74 = f16[8,16,2,64,5]{4,3,2,1,0} bitcast(multiply)
-  dot.6 = f16[8,16,2,5,5]{4,3,2,1,0} dot(bitcast.49, bitcast.74), lhs_batch_dims={0,1,2}, lhs_contracting_dims={4}, rhs_batch_dims={0,1,2}, rhs_contracting_dims={3}, frontend_attributes={grad_x="false",grad_y="false"}
-  constant.35 = f16[] constant(-inf)
-  reduce.1 = f16[8,16,2,5]{3,2,1,0} reduce(dot.6, constant.35), dimensions={3}, to_apply=region_0.36
-  broadcast.12 = f16[8,16,2,5,5]{4,3,2,1,0} broadcast(reduce.1), dimensions={0,1,2,4}
-  subtract.2 = f16[8,16,2,5,5]{4,3,2,1,0} subtract(dot.6, broadcast.12)
-  exponential.2 = f16[8,16,2,5,5]{4,3,2,1,0} exponential(subtract.2)
-  convert.7 = f32[8,16,2,5,5]{4,3,2,1,0} convert(exponential.2)
-  constant.34 = f32[] constant(0)
-  reduce.3 = f32[8,16,2,5]{3,2,1,0} reduce(convert.7, constant.34), dimensions={3}, to_apply=region_1.48
-  convert.8 = f16[8,16,2,5]{3,2,1,0} convert(reduce.3)
-  broadcast.13 = f16[8,16,2,5,5]{4,3,2,1,0} broadcast(convert.8), dimensions={0,1,2,4}
-  divide.2 = f16[8,16,2,5,5]{4,3,2,1,0} divide(exponential.2, broadcast.13)
-  bitcast.98 = f16[8,16,2,5,5]{3,4,2,1,0} bitcast(divide.2)
-  arg1.2 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(1), parameter_replication={false}
-  bitcast.102 = f16[640,128]{1,0} bitcast(arg1.2)
-  arg7.8 = f32[128,2,64]{2,1,0} parameter(7), parameter_replication={false}
-  convert.9 = f16[128,2,64]{2,1,0} convert(arg7.8)
-  bitcast.107 = f16[128,128]{1,0} bitcast(convert.9)
-  dot.3 = f16[640,128]{1,0} dot(bitcast.102, bitcast.107), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}
-  bitcast.110 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot.3)
-  transpose.30 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} transpose(bitcast.110), dimensions={0,1,2,4,3,5}, frontend_attributes={grad_x="false",grad_y="false"}
-  arg8.9 = f32[2,64]{1,0} parameter(8), parameter_replication={false}
-  convert.10 = f16[2,64]{1,0} convert(arg8.9)
-  broadcast.14 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} broadcast(convert.10), dimensions={3,5}
-  add.4 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} add(transpose.30, broadcast.14)
-  bitcast.120 = f16[8,16,2,5,64]{4,3,2,1,0} bitcast(add.4)
-  ROOT dot.7 = f16[8,16,2,5,64]{4,3,2,1,0} dot(bitcast.98, bitcast.120), lhs_batch_dims={0,1,2}, lhs_contracting_dims={4}, rhs_batch_dims={0,1,2}, rhs_contracting_dims={3}, frontend_attributes={grad_x="false",grad_y="false"}
-} // main
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  const auto status_or = RunHloPass(&fusedMhaRewriter, m.get());
-  TF_ASSERT_OK(status_or.status());
-  EXPECT_FALSE(status_or.value());
-
-  HloDCE dce;
-  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(m->entry_computation()->root_instruction(), GmockMatch(m::Dot()));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       F16TrainingBmm1ScaleBiasSoftmaxBmm2NonContractingDimNotDivisibleBy64) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,100]{3,2,1,0},f16[2,6,64,100]{3,2,1,0},f16[2,6,100,64]{3,2,1,0},f16[2,6,100,64]{3,2,1,0})->(f16[2,6,100,64]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0}, f16[2,6,64,100]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-
-region_0.21 {
-  Arg_0.22 = f16[] parameter(0)
-  Arg_1.23 = f16[] parameter(1)
-  ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
-}
-
-region_1.33 {
-  Arg_0.34 = f32[] parameter(0)
-  Arg_1.35 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.34, Arg_1.35)
-}
-
-region_2.55 {
-  Arg_0.56 = f16[] parameter(0)
-  Arg_1.57 = f16[] parameter(1)
-  ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
-}
-
-ENTRY main.82 {
-  Arg_0.1 = f16[2,6,64,100]{3,2,1,0} parameter(0), sharding={replicated}
-  Arg_1.2 = f16[2,6,64,100]{3,2,1,0} parameter(1), sharding={replicated}
-  dot.17 = f16[2,6,100,100]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  constant.22 = f16[] constant(2)
-  broadcast.24 = f16[2,6,100,100]{3,2,1,0} broadcast(constant.22), dimensions={}
-  multiply.2 = f16[2,6,100,100]{3,2,1,0} multiply(dot.17, broadcast.24)
-  constant.19 = f16[] constant(1)
-  broadcast.13 = f16[2,6,100,100]{3,2,1,0} broadcast(constant.19), dimensions={}
-  add.3 = f16[2,6,100,100]{3,2,1,0} add(multiply.2, broadcast.13)
-  constant.21 = f16[] constant(0)
-  constant.15 = f16[] constant(-inf)
-  reduce.25 = f16[2,6,100]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
-  broadcast.17 = f16[2,6,100,100]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
-  subtract.1 = f16[2,6,100,100]{3,2,1,0} subtract(add.3, broadcast.17)
-  exponential.1 = f16[2,6,100,100]{3,2,1,0} exponential(subtract.1)
-  convert.5 = f32[2,6,100,100]{3,2,1,0} convert(exponential.1)
-  constant.17 = f32[] constant(0)
-  reduce.37 = f32[2,6,100]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
-  convert.9 = f16[2,6,100]{2,1,0} convert(reduce.37)
-  broadcast.26 = f16[2,6,100,100]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
-  divide.5 = f16[2,6,100,100]{3,2,1,0} divide(exponential.1, broadcast.26)
-  Arg_2.3 = f16[2,6,100,64]{3,2,1,0} parameter(2), sharding={replicated}
-  dot.46 = f16[2,6,100,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  Arg_3.4 = f16[2,6,100,64]{3,2,1,0} parameter(3), sharding={replicated}
-  dot.49 = f16[2,6,100,100]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  divide.4 = f16[2,6,100,100]{3,2,1,0} divide(dot.49, broadcast.26)
-  broadcast.20 = f16[2,6,100]{2,1,0} broadcast(constant.19), dimensions={}
-  multiply.3 = f16[2,6,100]{2,1,0} multiply(convert.9, convert.9)
-  divide.3 = f16[2,6,100]{2,1,0} divide(broadcast.20, multiply.3)
-  broadcast.21 = f16[2,6,100,100]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
-  multiply.4 = f16[2,6,100,100]{3,2,1,0} multiply(dot.49, broadcast.21)
-  multiply.5 = f16[2,6,100,100]{3,2,1,0} multiply(multiply.4, exponential.1)
-  reduce.59 = f16[2,6,100]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
-  negate.2 = f16[2,6,100]{2,1,0} negate(reduce.59)
-  broadcast.25 = f16[2,6,100,100]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
-  add.5 = f16[2,6,100,100]{3,2,1,0} add(divide.4, broadcast.25)
-  multiply.8 = f16[2,6,100,100]{3,2,1,0} multiply(add.5, exponential.1)
-  multiply.9 = f16[2,6,100,100]{3,2,1,0} multiply(multiply.8, broadcast.24)
-  dot.80 = f16[2,6,100,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  dot = f16[2,6,64,100]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  dot.1 = f16[2,6,100,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  ROOT tuple.81 = (f16[2,6,100,64]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0}, f16[2,6,64,100]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  const auto status_or = RunHloPass(&fusedMhaRewriter, m.get());
-  TF_ASSERT_OK(status_or.status());
-  EXPECT_FALSE(status_or.value());
-
-  HloDCE dce;
-  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot())));
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, F16TrainingBmm2Grad1IncorrectPattern) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-
-region_0.21 {
-  Arg_0.22 = f16[] parameter(0)
-  Arg_1.23 = f16[] parameter(1)
-  ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
-}
-
-region_1.33 {
-  Arg_0.34 = f32[] parameter(0)
-  Arg_1.35 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.34, Arg_1.35)
-}
-
-region_2.55 {
-  Arg_0.56 = f16[] parameter(0)
-  Arg_1.57 = f16[] parameter(1)
-  ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
-}
-
-ENTRY main.82 {
-  Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated}
-  Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
-  dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  constant.22 = f16[] constant(2)
-  broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
-  multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
-  constant.19 = f16[] constant(1)
-  broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
-  add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
-  constant.21 = f16[] constant(0)
-  constant.15 = f16[] constant(-inf)
-  reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
-  broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
-  subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17)
-  exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
-  convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
-  constant.17 = f32[] constant(0)
-  reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
-  convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
-  broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
-  divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
-  Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
-  dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
-  dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
-  broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
-  multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
-  divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
-  broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
-  multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
-  multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
-  reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
-  negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59)
-  broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
-  add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
-  multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
-  multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24)
-  dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  // Add another user of ds (multiply.9) here; neg.1 should not be pattern-matched as bmm2grad1.
-  neg.1 = f16[2,6,128,128]{3,2,1,0} negate(multiply.9)
-  dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1, neg.1)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  HloDCE dce;
-  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot(),
-                                  m::Negate())));
-}
-
-// flash attention
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       FlashAttentionBF16TrainingBmm1CausalMaskSoftmaxBmm2Pattern) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-region_0.32 {
-  Arg_0.33 = bf16[] parameter(0)
-  Arg_1.34 = bf16[] parameter(1)
-  ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34)
-}
-region_1.44 {
-  Arg_0.45 = f32[] parameter(0)
-  Arg_1.46 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.45, Arg_1.46)
-}
-region_2.66 {
-  Arg_0.67 = bf16[] parameter(0)
-  Arg_1.68 = bf16[] parameter(1)
-  ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68)
-}
-ENTRY main.92 {
-  Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated}
-  Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated}
-  dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  constant.17 = bf16[] constant(2)
-  broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={}
-  multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29)
-  iota.2 = s32[2048,2048]{1,0} iota(), iota_dimension=0
-  iota.5 = s32[2048,2048]{1,0} iota(), iota_dimension=1
-  compare.1 = pred[2048,2048]{1,0} compare(iota.2, iota.5), direction=LT
-  constant.6 = bf16[] constant(-2.366e+38)
-  broadcast.16 = bf16[2048,2048]{1,0} broadcast(constant.6), dimensions={}
-  constant.16 = bf16[] constant(0)
-  broadcast.17 = bf16[2048,2048]{1,0} broadcast(constant.16), dimensions={}
-  select.2 = bf16[2048,2048]{1,0} select(compare.1, broadcast.16, broadcast.17)
-  broadcast.19 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(select.2), dimensions={2,3}
-  add.3 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.2, broadcast.19)
-  constant.10 = bf16[] constant(-inf)
-  reduce.36 = bf16[2,6,2048]{2,1,0} reduce(add.3, constant.10), dimensions={3}, to_apply=region_0.32
-  broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
-  subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.3, broadcast.21)
-  exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1)
-  convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1)
-  constant.14 = f32[] constant(0)
-  reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44
-  convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48)
-  broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
-  divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32)
-  Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated}
-  dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated}
-  dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32)
-  constant.15 = bf16[] constant(1)
-  broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={}
-  multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9)
-  divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3)
-  broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
-  multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26)
-  multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1)
-  reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66
-  negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70)
-  broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
-  add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31)
-  multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1)
-  multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29)
-  dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  HloDCE dce;
-  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  const HloInstruction* fwd_fmha;
-  const HloInstruction* bwd_fmha;
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(
-              m::CustomCall(&fwd_fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
-              .WithShape(BF16, {2, 6, 2048, 128}),
-          m::GetTupleElement(
-              m::CustomCall(&bwd_fmha, {kCudnnfMHASoftmaxBackwardCallTarget}),
-              0)
-              .WithShape(BF16, {2, 6, 2048, 128}),
-          m::Transpose(
-              m::GetTupleElement(
-                  m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 1))
-              .WithShape(BF16, {2, 6, 128, 2048}),
-          m::GetTupleElement(
-              m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 2)
-              .WithShape(BF16, {2, 6, 2048, 128}))));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fwd_fmha->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_EQ(fwd_fmha->operands().size(), 3);
-  EXPECT_EQ(bwd_fmha->operands().size(), 6);
-  EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
-  EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::CAUSAL);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       FlashAttentionBF16TrainingBmm1BiasSoftmaxBmm2Pattern) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,2048]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-region_0.32 {
-  Arg_0.33 = bf16[] parameter(0)
-  Arg_1.34 = bf16[] parameter(1)
-  ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34)
-}
-region_1.44 {
-  Arg_0.45 = f32[] parameter(0)
-  Arg_1.46 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.45, Arg_1.46)
-}
-region_2.66 {
-  Arg_0.67 = bf16[] parameter(0)
-  Arg_1.68 = bf16[] parameter(1)
-  ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68)
-}
-ENTRY main.92 {
-  Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated}
-  Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated}
-  dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  constant.17 = bf16[] constant(2)
-  broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={}
-  multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29)
-  // bias
-  Arg_4.5 = bf16[2,6,2048,2048]{3,2,1,0} parameter(4), sharding={replicated}
-  add.3 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.2, Arg_4.5)
-  constant.10 = bf16[] constant(-inf)
-  constant.16 = bf16[] constant(0)
-  reduce.36 = bf16[2,6,2048]{2,1,0} reduce(add.3, constant.10), dimensions={3}, to_apply=region_0.32
-  broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
-  subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.3, broadcast.21)
-  exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1)
-  convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1)
-  constant.14 = f32[] constant(0)
-  reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44
-  convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48)
-  broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
-  divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32)
-  Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated}
-  dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated}
-  dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32)
-  constant.15 = bf16[] constant(1)
-  broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={}
-  multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9)
-  divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3)
-  broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
-  multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26)
-  multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1)
-  reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66
-  negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70)
-  broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
-  add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31)
-  multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1)
-  multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29)
-  dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  HloDCE dce;
-  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(
-              m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0)
-              .WithShape(BF16, {2, 6, 2048, 128}),
-          m::GetTupleElement(
-              m::CustomCall(&fmha,
-                            {kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
-              0)
-              .WithShape(BF16, {2, 6, 2048, 128}),
-          m::Transpose(
-              m::GetTupleElement(
-                  m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
-                  1))
-              .WithShape(BF16, {2, 6, 128, 2048}),
-          m::GetTupleElement(
-              m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), 2)
-              .WithShape(BF16, {2, 6, 2048, 128}))));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_EQ(fmha->operands().size(), 7);
-  EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
-  EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       FlashAttentionBF16TrainingBmm1SoftmaxBmm2Pattern) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
-region_0.32 {
-  Arg_0.33 = bf16[] parameter(0)
-  Arg_1.34 = bf16[] parameter(1)
-  ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34)
-}
-region_1.44 {
-  Arg_0.45 = f32[] parameter(0)
-  Arg_1.46 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.45, Arg_1.46)
-}
-region_2.66 {
-  Arg_0.67 = bf16[] parameter(0)
-  Arg_1.68 = bf16[] parameter(1)
-  ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68)
-}
-ENTRY main.92 {
-  Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated}
-  Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated}
-  dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  constant.17 = bf16[] constant(2)
-  broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={}
-  multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29)
-  constant.10 = bf16[] constant(-inf)
-  constant.16 = bf16[] constant(0)
-  reduce.36 = bf16[2,6,2048]{2,1,0} reduce(multiply.2, constant.10), dimensions={3}, to_apply=region_0.32
-  broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
-  subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(multiply.2, broadcast.21)
-  exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1)
-  convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1)
-  constant.14 = f32[] constant(0)
-  reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44
-  convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48)
-  broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
-  divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32)
-  Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated}
-  dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated}
-  dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32)
-  constant.15 = bf16[] constant(1)
-  broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={}
-  multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9)
-  divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3)
-  broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
-  multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26)
-  multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1)
-  reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66
-  negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70)
-  broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
-  add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31)
-  multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1)
-  multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29)
-  dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  HloDCE dce;
-  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(
-              m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
-              .WithShape(BF16, {2, 6, 2048, 128}),
-          m::GetTupleElement(
-              m::CustomCall(&fmha, {kCudnnfMHASoftmaxBackwardCallTarget}), 0)
-              .WithShape(BF16, {2, 6, 2048, 128}),
-          m::Transpose(
-              m::GetTupleElement(
-                  m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 1))
-              .WithShape(BF16, {2, 6, 128, 2048}),
-          m::GetTupleElement(
-              m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 2)
-              .WithShape(BF16, {2, 6, 2048, 128}))));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_EQ(fmha->operands().size(), 6);
-  EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
-  EXPECT_FLOAT_EQ(config.fmha_scale(), 2);
-  EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK);
-}
-
-// GPT3 pattern
-TEST_F(CudnnFusedMhaRewriterTestHloTest, FlashAttentionBF16TrainingGPT3_5B) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={((s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}))->(s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0})}
-add {
-  x = bf16[] parameter(0)
-  y = bf16[] parameter(1)
-  ROOT add.580 = bf16[] add(x, y)
-}
-
-region_20.962 {
-  Arg_0.963 = f32[] parameter(0)
-  Arg_1.964 = f32[] parameter(1)
-  ROOT add.579 = f32[] add(Arg_0.963, Arg_1.964)
-}
-
-region_39.1120 {
-  Arg_0.1121 = f32[] parameter(0)
-  Arg_1.1122 = f32[] parameter(1)
-  ROOT maximum.21 = f32[] maximum(Arg_0.1121, Arg_1.1122)
-}
-
-main {
-  param.3 = (s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}) parameter(0)
-  get-tuple-element.31 = s32[] get-tuple-element(param.3), index=0
-  constant.1961 = s32[] constant(1)
-  add.581 = s32[] add(get-tuple-element.31, constant.1961)
-  get-tuple-element.32 = bf16[24,32,2048,2048]{2,1,3,0} get-tuple-element(param.3), index=25
-  bitcast.187 = bf16[24,2048,32,2048]{3,2,1,0} bitcast(get-tuple-element.32)
-  constant.1977 = s32[] constant(23)
-  subtract.221 = s32[] subtract(constant.1977, get-tuple-element.31)
-  constant.1980 = s32[] constant(0)
-  compare.210 = pred[] compare(subtract.221, constant.1980), direction=LT
-  constant.1979 = s32[] constant(47)
-  subtract.222 = s32[] subtract(constant.1979, get-tuple-element.31)
-  select.372 = s32[] select(compare.210, subtract.222, subtract.221)
-  dynamic-slice.324 = bf16[1,2048,32,2048]{3,2,1,0} dynamic-slice(bitcast.187, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,2048,32,2048}
-  bitcast.756 = bf16[2048,32,2048]{2,1,0} bitcast(dynamic-slice.324)
-  convert.282 = f32[2048,32,2048]{2,1,0} convert(bitcast.756)
-  constant.1991 = bf16[] constant(1)
-  broadcast.1270 = bf16[32,2048]{1,0} broadcast(constant.1991), dimensions={}
-  get-tuple-element.33 = bf16[32,2048]{1,0} get-tuple-element(param.3), index=27
-  subtract.229 = bf16[32,2048]{1,0} subtract(broadcast.1270, get-tuple-element.33)
-  convert.285 = f32[32,2048]{1,0} convert(subtract.229)
-  broadcast.1228 = f32[2048,32,2048]{2,1,0} broadcast(convert.285), dimensions={1,2}
-  multiply.656 = f32[2048,32,2048]{2,1,0} multiply(convert.282, broadcast.1228)
-  bitcast.367 = f32[32,2048,2048]{1,0,2} bitcast(multiply.656)
-  constant.1968 = f32[] constant(0)
-  reduce.84 = f32[] reduce(bitcast.367, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
-  all-reduce.230 = f32[] all-reduce(reduce.84), channel_id=278, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
-  broadcast.1221 = f32[32,2048,4096]{2,1,0} broadcast(convert.285), dimensions={0,1}
-  reduce.85 = f32[] reduce(broadcast.1221, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
-  all-reduce.14 = f32[] all-reduce(reduce.85), channel_id=49, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=region_20.962
-  constant.2005 = f32[] constant(1)
-  maximum.24 = f32[] maximum(all-reduce.14, constant.2005)
-  divide.96 = f32[] divide(all-reduce.230, maximum.24)
-  broadcast.1223 = f32[2048,32,2048]{2,1,0} broadcast(divide.96), dimensions={}
-  subtract.219 = f32[2048,32,2048]{2,1,0} subtract(convert.282, broadcast.1223)
-  multiply.644 = f32[2048,32,2048]{2,1,0} multiply(subtract.219, broadcast.1228)
-  multiply.645 = f32[2048,32,2048]{2,1,0} multiply(multiply.644, multiply.644)
-  bitcast.271 = f32[32,2048,2048]{1,0,2} bitcast(multiply.645)
-  reduce.86 = f32[] reduce(bitcast.271, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
-  all-reduce.231 = f32[] all-reduce(reduce.86), channel_id=279, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
-  divide.99 = f32[] divide(all-reduce.231, maximum.24)
-  rsqrt.16 = f32[] rsqrt(divide.99)
-  multiply.650 = f32[] multiply(rsqrt.16, constant.1968)
-  divide.100 = f32[] divide(multiply.650, maximum.24)
-  constant.1974 = f32[] constant(2)
-  multiply.652 = f32[] multiply(divide.100, constant.1974)
-  broadcast.1227 = f32[2048,32,2048]{2,1,0} broadcast(multiply.652), dimensions={}
-  multiply.653 = f32[2048,32,2048]{2,1,0} multiply(multiply.644, broadcast.1227)
-  multiply.654 = f32[2048,32,2048]{2,1,0} multiply(multiply.653, broadcast.1228)
-  negate.56 = f32[2048,32,2048]{2,1,0} negate(multiply.654)
-  bitcast.321 = f32[32,2048,2048]{1,0,2} bitcast(negate.56)
-  reduce.87 = f32[] reduce(bitcast.321, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
-  all-reduce.232 = f32[] all-reduce(reduce.87), channel_id=280, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
-  divide.101 = f32[] divide(all-reduce.232, maximum.24)
-  broadcast.1229 = f32[32,2048]{1,0} broadcast(divide.101), dimensions={}
-  multiply.655 = f32[32,2048]{1,0} multiply(broadcast.1229, convert.285)
-  broadcast.1230 = f32[2048,32,2048]{2,1,0} broadcast(multiply.655), dimensions={1,2}
-  add.582 = f32[2048,32,2048]{2,1,0} add(multiply.654, broadcast.1230)
-  broadcast.1236 = f32[2048,32,2048]{2,1,0} broadcast(constant.1968), dimensions={}
-  compare.208 = pred[2048,32,2048]{2,1,0} compare(multiply.656, broadcast.1236), direction=GE
-  abs.22 = f32[2048,32,2048]{2,1,0} abs(multiply.656)
-  bitcast.373 = f32[32,2048,2048]{1,0,2} bitcast(abs.22)
-  constant.1989 = f32[] constant(-inf)
-  reduce.88 = f32[] reduce(bitcast.373, constant.1989), dimensions={0,1,2}, to_apply=region_39.1120
-  all-reduce.233 = f32[] all-reduce(reduce.88), channel_id=281, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_39.1120
-  broadcast.1233 = f32[2048,32,2048]{2,1,0} broadcast(all-reduce.233), dimensions={}
-  compare.207 = pred[2048,32,2048]{2,1,0} compare(abs.22, broadcast.1233), direction=EQ
-  convert.286 = f32[2048,32,2048]{2,1,0} convert(compare.207)
-  bitcast.393 = f32[32,2048,2048]{1,0,2} bitcast(convert.286)
-  reduce.89 = f32[] reduce(bitcast.393, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
-  all-reduce.234 = f32[] all-reduce(reduce.89), channel_id=282, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
-  divide.103 = f32[] divide(constant.1968, all-reduce.234)
-  broadcast.1238 = f32[2048,32,2048]{2,1,0} broadcast(divide.103), dimensions={}
-  select.370 = f32[2048,32,2048]{2,1,0} select(compare.207, broadcast.1238, broadcast.1236)
-  select.369 = f32[2048,32,2048]{2,1,0} select(compare.208, select.370, broadcast.1236)
-  constant.1976 = pred[] constant(false)
-  broadcast.1237 = pred[2048,32,2048]{2,1,0} broadcast(constant.1976), dimensions={}
-  compare.209 = pred[2048,32,2048]{2,1,0} compare(compare.208, broadcast.1237), direction=EQ
-  select.371 = f32[2048,32,2048]{2,1,0} select(compare.209, select.370, broadcast.1236)
-  negate.57 = f32[2048,32,2048]{2,1,0} negate(select.371)
-  add.583 = f32[2048,32,2048]{2,1,0} add(select.369, negate.57)
-  multiply.658 = f32[2048,32,2048]{2,1,0} multiply(add.583, broadcast.1228)
-  add.585 = f32[2048,32,2048]{2,1,0} add(add.582, multiply.658)
-  convert.287 = bf16[2048,32,2048]{2,1,0} convert(add.585)
-  get-tuple-element.34 = bf16[32,2048,2048]{1,0,2} get-tuple-element(param.3), index=1
-  bitcast.1652 = bf16[2048,32,2048]{2,1,0} bitcast(get-tuple-element.34)
-  get-tuple-element.35 = bf16[24,3,1024,16,128]{4,3,1,2,0} get-tuple-element(param.3), index=22
-  bitcast.461 = bf16[24,1024,3,16,128]{4,3,2,1,0} bitcast(get-tuple-element.35)
-  dynamic-slice.325 = bf16[1,1024,3,16,128]{4,3,2,1,0} dynamic-slice(bitcast.461, select.372, constant.1980, constant.1980, constant.1980, /*index=5*/constant.1980), dynamic_slice_sizes={1,1024,3,16,128}
-  bitcast.485 = bf16[3,1024,16,128]{3,2,0,1} bitcast(dynamic-slice.325)
-  all-gather.7 = bf16[3,4096,16,128]{3,2,0,1} all-gather(bitcast.485), channel_id=60, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1}, use_global_device_ids=true
-  bitcast.1420 = bf16[6144,4096]{0,1} bitcast(all-gather.7)
-  bitcast.500 = f32[32,2048,2048]{1,0,2} bitcast(convert.282)
-  reduce.90 = f32[32,2048]{1,0} reduce(bitcast.500, constant.1968), dimensions={2}, to_apply=region_20.962
-  all-reduce.23 = f32[32,2048]{1,0} all-reduce(reduce.90), channel_id=58, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
-  constant.1983 = f32[] constant(0.000244140625)
-  broadcast.1243 = f32[32,2048]{1,0} broadcast(constant.1983), dimensions={}
-  multiply.660 = f32[32,2048]{1,0} multiply(all-reduce.23, broadcast.1243)
-  broadcast.1242 = f32[2048,32,2048]{2,1,0} broadcast(multiply.660), dimensions={1,2}
-  subtract.224 = f32[2048,32,2048]{2,1,0} subtract(convert.282, broadcast.1242)
-  multiply.661 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, subtract.224)
-  bitcast.527 = f32[32,2048,2048]{1,0,2} bitcast(multiply.661)
-  reduce.91 = f32[32,2048]{1,0} reduce(bitcast.527, constant.1968), dimensions={2}, to_apply=region_20.962
-  all-reduce.24 = f32[32,2048]{1,0} all-reduce(reduce.91), channel_id=59, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
-  multiply.662 = f32[32,2048]{1,0} multiply(all-reduce.24, broadcast.1243)
-  constant.1990 = f32[] constant(1e-05)
-  broadcast.1264 = f32[32,2048]{1,0} broadcast(constant.1990), dimensions={}
-  add.587 = f32[32,2048]{1,0} add(multiply.662, broadcast.1264)
-  bitcast.1447 = f32[1,32,2048]{2,1,0} bitcast(add.587)
-  rsqrt.20 = f32[1,32,2048]{2,1,0} rsqrt(bitcast.1447)
-  bitcast.1892 = f32[32,2048]{1,0} bitcast(rsqrt.20)
-  broadcast.1337 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1892), dimensions={1,2}
-  multiply.754 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, broadcast.1337)
-  convert.314 = bf16[2048,32,2048]{2,1,0} convert(multiply.754)
-  get-tuple-element.36 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=20
-  dynamic-slice.326 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.36, select.372, constant.1980), dynamic_slice_sizes={1,2048}
-  broadcast.1266 = bf16[1,2048]{1,0} broadcast(constant.1991), dimensions={}
-  add.588 = bf16[1,2048]{1,0} add(dynamic-slice.326, broadcast.1266)
-  bitcast.1992 = bf16[2048]{0} bitcast(add.588)
-  broadcast.1338 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1992), dimensions={0}
-  multiply.755 = bf16[2048,32,2048]{2,1,0} multiply(convert.314, broadcast.1338)
-  get-tuple-element.37 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=19
-  dynamic-slice.327 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.37, select.372, constant.1980), dynamic_slice_sizes={1,2048}
-  bitcast.1998 = bf16[2048]{0} bitcast(dynamic-slice.327)
-  broadcast.1339 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1998), dimensions={0}
-  add.640 = bf16[2048,32,2048]{2,1,0} add(multiply.755, broadcast.1339)
-  bitcast.2003 = bf16[32,2048,2048]{1,0,2} bitcast(add.640)
-  all-gather.8 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2003), channel_id=61, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
-  bitcast.597 = bf16[4096,65536]{1,0} bitcast(all-gather.8)
-  dot.42 = bf16[6144,65536]{1,0} dot(bitcast.1420, bitcast.597), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bitcast.623 = bf16[3,16,128,32,2048]{4,3,2,1,0} bitcast(dot.42)
-  transpose.112 = bf16[3,32,16,128,2048]{4,3,2,1,0} transpose(bitcast.623), dimensions={0,3,1,2,4}
-  get-tuple-element.38 = bf16[24,3,16,128]{3,2,1,0} get-tuple-element(param.3), index=21
-  dynamic-slice.328 = bf16[1,3,16,128]{3,2,1,0} dynamic-slice(get-tuple-element.38, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,3,16,128}
-  bitcast.626 = bf16[3,16,128]{2,1,0} bitcast(dynamic-slice.328)
-  broadcast.1250 = bf16[3,32,16,128,2048]{4,3,2,1,0} broadcast(bitcast.626), dimensions={0,2,3}
-  add.591 = bf16[3,32,16,128,2048]{4,3,2,1,0} add(transpose.112, broadcast.1250)
-  slice.87 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[2:3], [0:32], [0:16], [0:128], [0:2048]}
-  bitcast.1280 = bf16[32,16,128,2048]{3,2,1,0} bitcast(slice.87)
-  slice.88 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[0:1], [0:32], [0:16], [0:128], [0:2048]}
-  constant.2007 = bf16[] constant(0.08838)
-  broadcast.1251 = bf16[1,32,16,128,2048]{4,3,2,1,0} broadcast(constant.2007), dimensions={}
-  multiply.666 = bf16[1,32,16,128,2048]{4,3,2,1,0} multiply(slice.88, broadcast.1251)
-  bitcast.1330 = bf16[32,16,128,2048]{3,2,1,0} bitcast(multiply.666)
-  transpose.113 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.1330), dimensions={0,1,3,2}
-  slice.89 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[1:2], [0:32], [0:16], [0:128], [0:2048]}
-  bitcast.647 = bf16[32,16,128,2048]{3,2,1,0} bitcast(slice.89)
-  dot.43 = bf16[32,16,2048,2048]{3,2,1,0} dot(transpose.113, bitcast.647), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  convert.291 = f32[32,16,2048,2048]{3,2,1,0} convert(dot.43)
-  get-tuple-element.39 = bf16[32,1,2048,2048]{3,2,0,1} get-tuple-element(param.3), index=26
-  bitcast.651 = bf16[1,32,2048,2048]{3,2,1,0} bitcast(get-tuple-element.39)
-  iota.38 = s32[2048,2048]{1,0} iota(), iota_dimension=0
-  iota.39 = s32[2048,2048]{1,0} iota(), iota_dimension=1
-  compare.211 = pred[2048,2048]{1,0} compare(iota.38, iota.39), direction=LT
-  constant.1987 = bf16[] constant(-2.366e+38)
-  broadcast.1252 = bf16[2048,2048]{1,0} broadcast(constant.1987), dimensions={}
-  constant.2006 = bf16[] constant(0)
-  broadcast.1253 = bf16[2048,2048]{1,0} broadcast(constant.2006), dimensions={}
-  select.373 = bf16[2048,2048]{1,0} select(compare.211, broadcast.1252, broadcast.1253)
-  broadcast.1254 = bf16[1,32,2048,2048]{3,2,1,0} broadcast(select.373), dimensions={2,3}
-  minimum.5 = bf16[1,32,2048,2048]{3,2,1,0} minimum(bitcast.651, broadcast.1254)
-  bitcast.673 = bf16[32,2048,2048]{2,1,0} bitcast(minimum.5)
-  convert.292 = f32[32,2048,2048]{2,1,0} convert(bitcast.673)
-  broadcast.1256 = f32[32,16,2048,2048]{3,2,1,0} broadcast(convert.292), dimensions={0,2,3}
-  add.593 = f32[32,16,2048,2048]{3,2,1,0} add(convert.291, broadcast.1256)
-  reduce.92 = f32[32,16,2048]{2,1,0} reduce(add.593, constant.1989), dimensions={3}, to_apply=region_39.1120
-  broadcast.1258 = f32[32,16,2048,2048]{3,2,1,0} broadcast(reduce.92), dimensions={0,1,2}
-  subtract.226 = f32[32,16,2048,2048]{3,2,1,0} subtract(add.593, broadcast.1258)
-  exponential.8 = f32[32,16,2048,2048]{3,2,1,0} exponential(subtract.226)
-  reduce.93 = f32[32,16,2048]{2,1,0} reduce(exponential.8, constant.1968), dimensions={3}, to_apply=region_20.962
-  broadcast.1309 = f32[32,16,2048,2048]{3,2,1,0} broadcast(reduce.93), dimensions={0,1,2}
-  divide.109 = f32[32,16,2048,2048]{3,2,1,0} divide(exponential.8, broadcast.1309)
-  convert.306 = bf16[32,16,2048,2048]{3,2,1,0} convert(divide.109)
-  dot.44 = bf16[32,16,128,2048]{3,2,1,0} dot(bitcast.1280, convert.306), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  transpose.116 = bf16[32,2048,16,128]{3,2,1,0} transpose(dot.44), dimensions={0,3,1,2}
-  bitcast.711 = bf16[65536,2048]{1,0} bitcast(transpose.116)
-  get-tuple-element.40 = bf16[24,1024,16,128]{3,2,1,0} get-tuple-element(param.3), index=24
-  dynamic-slice.329 = bf16[1,1024,16,128]{3,2,1,0} dynamic-slice(get-tuple-element.40, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,16,128}
-  bitcast.724 = bf16[1024,16,128]{2,1,0} bitcast(dynamic-slice.329)
-  all-gather.9 = bf16[4096,16,128]{2,1,0} all-gather(bitcast.724), channel_id=62, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true
-  bitcast.729 = bf16[2048,4096]{0,1} bitcast(all-gather.9)
-  dot.57 = bf16[65536,4096]{0,1} dot(bitcast.711, bitcast.729), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bitcast.733 = bf16[32,2048,4096]{1,0,2} bitcast(dot.57)
-  reduce-scatter = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.733), channel_id=322, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add
-  bitcast.763 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter)
-  get-tuple-element.41 = bf16[24,1024]{1,0} get-tuple-element(param.3), index=23
-  dynamic-slice.330 = bf16[1,1024]{1,0} dynamic-slice(get-tuple-element.41, select.372, constant.1980), dynamic_slice_sizes={1,1024}
-  bitcast.748 = bf16[1024]{0} bitcast(dynamic-slice.330)
-  collective-permute.1 = bf16[1024]{0} collective-permute(bitcast.748), channel_id=64, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
-  all-gather.10 = bf16[2048]{0} all-gather(collective-permute.1), channel_id=65, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={0}, use_global_device_ids=true
-  broadcast.1261 = bf16[2048,32,2048]{2,1,0} broadcast(all-gather.10), dimensions={0}
-  add.596 = bf16[2048,32,2048]{2,1,0} add(bitcast.763, broadcast.1261)
-  add.597 = bf16[2048,32,2048]{2,1,0} add(add.596, bitcast.756)
-  convert.295 = f32[2048,32,2048]{2,1,0} convert(add.597)
-  bitcast.774 = f32[32,2048,2048]{1,0,2} bitcast(convert.295)
-  reduce.94 = f32[32,2048]{1,0} reduce(bitcast.774, constant.1968), dimensions={2}, to_apply=region_20.962
-  all-reduce.26 = f32[32,2048]{1,0} all-reduce(reduce.94), channel_id=66, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
-  multiply.668 = f32[32,2048]{1,0} multiply(all-reduce.26, broadcast.1243)
-  broadcast.1263 = f32[2048,32,2048]{2,1,0} broadcast(multiply.668), dimensions={1,2}
-  subtract.228 = f32[2048,32,2048]{2,1,0} subtract(convert.295, broadcast.1263)
-  multiply.669 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, subtract.228)
-  bitcast.809 = f32[32,2048,2048]{1,0,2} bitcast(multiply.669)
-  reduce.95 = f32[32,2048]{1,0} reduce(bitcast.809, constant.1968), dimensions={2}, to_apply=region_20.962
-  all-reduce.27 = f32[32,2048]{1,0} all-reduce(reduce.95), channel_id=67, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
-  multiply.670 = f32[32,2048]{1,0} multiply(all-reduce.27, broadcast.1243)
-  add.598 = f32[32,2048]{1,0} add(multiply.670, broadcast.1264)
-  bitcast.1148 = f32[1,32,2048]{2,1,0} bitcast(add.598)
-  rsqrt.19 = f32[1,32,2048]{2,1,0} rsqrt(bitcast.1148)
-  bitcast.1602 = f32[32,2048]{1,0} bitcast(rsqrt.19)
-  broadcast.1329 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1602), dimensions={1,2}
-  multiply.750 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, broadcast.1329)
-  convert.312 = bf16[2048,32,2048]{2,1,0} convert(multiply.750)
-  get-tuple-element.42 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=18
-  dynamic-slice.331 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.42, select.372, constant.1980), dynamic_slice_sizes={1,2048}
-  add.599 = bf16[1,2048]{1,0} add(dynamic-slice.331, broadcast.1266)
-  bitcast.1609 = bf16[2048]{0} bitcast(add.599)
-  broadcast.1330 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1609), dimensions={0}
-  multiply.745 = bf16[2048,32,2048]{2,1,0} multiply(convert.312, broadcast.1330)
-  get-tuple-element.43 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=17
-  dynamic-slice.332 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.43, select.372, constant.1980), dynamic_slice_sizes={1,2048}
-  bitcast.1615 = bf16[2048]{0} bitcast(dynamic-slice.332)
-  broadcast.1331 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1615), dimensions={0}
-  add.636 = bf16[2048,32,2048]{2,1,0} add(multiply.745, broadcast.1331)
-  bitcast.1620 = bf16[32,2048,2048]{1,0,2} bitcast(add.636)
-  all-gather.12 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1620), channel_id=69, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
-  bitcast.877 = bf16[65536,4096]{0,1} bitcast(all-gather.12)
-  get-tuple-element.44 = bf16[24,1024,8192]{2,1,0} get-tuple-element(param.3), index=15
-  dynamic-slice.333 = bf16[1,1024,8192]{2,1,0} dynamic-slice(get-tuple-element.44, select.372, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,8192}
-  bitcast.890 = bf16[1024,8192]{1,0} bitcast(dynamic-slice.333)
-  all-gather.11 = bf16[4096,8192]{1,0} all-gather(bitcast.890), channel_id=68, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true
-  dot.45 = bf16[65536,8192]{1,0} dot(bitcast.877, all-gather.11), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  get-tuple-element.45 = bf16[24,8192]{1,0} get-tuple-element(param.3), index=14
-  dynamic-slice.334 = bf16[1,8192]{1,0} dynamic-slice(get-tuple-element.45, select.372, constant.1980), dynamic_slice_sizes={1,8192}
-  bitcast.906 = bf16[8192]{0} bitcast(dynamic-slice.334)
-  broadcast.1269 = bf16[65536,8192]{1,0} broadcast(bitcast.906), dimensions={1}
-  add.601 = bf16[65536,8192]{1,0} add(dot.45, broadcast.1269)
-  bitcast.997 = bf16[32,2048,8192]{2,1,0} bitcast(add.601)
-  broadcast.1333 = bf16[2048,32,2048]{2,1,0} broadcast(subtract.229), dimensions={1,2}
-  multiply.746 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1652, broadcast.1333)
-  bitcast.1739 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.746)
-  all-gather.14 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1739), channel_id=71, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
-  bitcast.934 = bf16[65536,4096]{0,1} bitcast(all-gather.14)
-  get-tuple-element.46 = bf16[24,8192,1024]{1,2,0} get-tuple-element(param.3), index=16
-  bitcast.935 = bf16[24,1024,8192]{2,1,0} bitcast(get-tuple-element.46)
-  dynamic-slice.335 = bf16[1,1024,8192]{2,1,0} dynamic-slice(bitcast.935, select.372, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,8192}
-  bitcast.947 = bf16[8192,1024]{0,1} bitcast(dynamic-slice.335)
-  all-gather.13 = bf16[8192,4096]{0,1} all-gather(bitcast.947), channel_id=70, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1}, use_global_device_ids=true
-  dot.46 = bf16[65536,8192]{1,0} dot(bitcast.934, all-gather.13), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-  bitcast.1092 = bf16[32,2048,8192]{2,1,0} bitcast(dot.46)
-  broadcast.1335 = bf16[32,2048,8192]{2,1,0} broadcast(subtract.229), dimensions={0,1}
-  multiply.703 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.1092, broadcast.1335)
-  multiply.685 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.703)
-  constant.2002 = bf16[] constant(0.5)
-  broadcast.1288 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2002), dimensions={}
-  multiply.686 = bf16[32,2048,8192]{2,1,0} multiply(multiply.685, broadcast.1288)
-  broadcast.1287 = bf16[32,2048,8192]{2,1,0} broadcast(constant.1991), dimensions={}
-  multiply.700 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, bitcast.997)
-  multiply.693 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.700)
-  constant.1998 = bf16[] constant(0.04468)
-  broadcast.1282 = bf16[32,2048,8192]{2,1,0} broadcast(constant.1998), dimensions={}
-  multiply.694 = bf16[32,2048,8192]{2,1,0} multiply(multiply.693, broadcast.1282)
-  add.605 = bf16[32,2048,8192]{2,1,0} add(bitcast.997, multiply.694)
-  constant.2010 = bf16[] constant(0.7969)
-  broadcast.1324 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2010), dimensions={}
-  multiply.695 = bf16[32,2048,8192]{2,1,0} multiply(add.605, broadcast.1324)
-  tanh.7 = bf16[32,2048,8192]{2,1,0} tanh(multiply.695)
-  subtract.231 = bf16[32,2048,8192]{2,1,0} subtract(broadcast.1287, tanh.7)
-  multiply.691 = bf16[32,2048,8192]{2,1,0} multiply(multiply.686, subtract.231)
-  multiply.737 = bf16[32,2048,8192]{2,1,0} multiply(multiply.691, tanh.7)
-  add.630 = bf16[32,2048,8192]{2,1,0} add(multiply.691, multiply.737)
-  multiply.738 = bf16[32,2048,8192]{2,1,0} multiply(add.630, broadcast.1324)
-  constant.2011 = bf16[] constant(0.03564)
-  broadcast.1326 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2011), dimensions={}
-  multiply.739 = bf16[32,2048,8192]{2,1,0} multiply(add.630, broadcast.1326)
-  constant.2012 = bf16[] constant(3)
-  broadcast.1327 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2012), dimensions={}
-  multiply.740 = bf16[32,2048,8192]{2,1,0} multiply(multiply.700, broadcast.1327)
-  multiply.741 = bf16[32,2048,8192]{2,1,0} multiply(multiply.739, multiply.740)
-  add.632 = bf16[32,2048,8192]{2,1,0} add(multiply.738, multiply.741)
-  add.637 = bf16[32,2048,8192]{2,1,0} add(tanh.7, broadcast.1287)
-  multiply.747 = bf16[32,2048,8192]{2,1,0} multiply(add.637, broadcast.1288)
-  multiply.743 = bf16[32,2048,8192]{2,1,0} multiply(multiply.703, multiply.747)
-  add.635 = bf16[32,2048,8192]{2,1,0} add(add.632, multiply.743)
-  bitcast.1629 = bf16[65536,8192]{1,0} bitcast(add.635)
-  dot.47 = bf16[65536,4096]{0,1} dot(bitcast.1629, all-gather.11), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-  bitcast.1130 = bf16[32,2048,4096]{1,0,2} bitcast(dot.47)
-  reduce-scatter.1 = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.1130), channel_id=323, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add
-  bitcast.1766 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter.1)
-  multiply.712 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1766, broadcast.1330)
-  convert.299 = f32[2048,32,2048]{2,1,0} convert(multiply.712)
-  multiply.707 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, convert.299)
-  bitcast.1135 = f32[32,2048,2048]{1,0,2} bitcast(multiply.707)
-  reduce.96 = f32[32,2048]{1,0} reduce(bitcast.1135, constant.1968), dimensions={2}, to_apply=region_20.962
-  all-reduce.29 = f32[32,2048]{1,0} all-reduce(reduce.96), channel_id=73, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
-  bitcast.1140 = f32[1,32,2048]{2,1,0} bitcast(all-reduce.29)
-  divide.105 = f32[1,32,2048]{2,1,0} divide(rsqrt.19, bitcast.1148)
-  constant.2008 = f32[] constant(-0.5)
-  broadcast.1313 = f32[1,32,2048]{2,1,0} broadcast(constant.2008), dimensions={}
-  multiply.708 = f32[1,32,2048]{2,1,0} multiply(divide.105, broadcast.1313)
-  multiply.709 = f32[1,32,2048]{2,1,0} multiply(bitcast.1140, multiply.708)
-  constant.2009 = f32[] constant(0.00048828125)
-  broadcast.1315 = f32[1,32,2048]{2,1,0} broadcast(constant.2009), dimensions={}
-  multiply.710 = f32[1,32,2048]{2,1,0} multiply(multiply.709, broadcast.1315)
-  bitcast.1235 = f32[32,2048]{1,0} bitcast(multiply.710)
-  broadcast.1296 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1235), dimensions={1,2}
-  multiply.717 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, broadcast.1296)
-  multiply.718 = f32[2048,32,2048]{2,1,0} multiply(convert.299, broadcast.1329)
-  add.617 = f32[2048,32,2048]{2,1,0} add(multiply.717, multiply.718)
-  negate.58 = f32[2048,32,2048]{2,1,0} negate(multiply.717)
-  bitcast.1189 = f32[32,2048,2048]{1,0,2} bitcast(negate.58)
-  reduce.97 = f32[32,2048]{1,0} reduce(bitcast.1189, constant.1968), dimensions={2}, to_apply=region_20.962
-  negate.59 = f32[2048,32,2048]{2,1,0} negate(multiply.718)
-  bitcast.1203 = f32[32,2048,2048]{1,0,2} bitcast(negate.59)
-  reduce.98 = f32[32,2048]{1,0} reduce(bitcast.1203, constant.1968), dimensions={2}, to_apply=region_20.962
-  add.613 = f32[32,2048]{1,0} add(reduce.97, reduce.98)
-  all-reduce.274 = f32[32,2048]{1,0} all-reduce(add.613), channel_id=335, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
-  multiply.719 = f32[32,2048]{1,0} multiply(all-reduce.274, broadcast.1243)
-  broadcast.1297 = f32[2048,32,2048]{2,1,0} broadcast(multiply.719), dimensions={1,2}
-  add.618 = f32[2048,32,2048]{2,1,0} add(add.617, broadcast.1297)
-  convert.301 = bf16[2048,32,2048]{2,1,0} convert(add.618)
-  add.619 = bf16[2048,32,2048]{2,1,0} add(bitcast.1652, convert.301)
-  add.616 = bf16[2048,32,2048]{2,1,0} add(convert.287, add.619)
-  bitcast.2063 = bf16[32,2048,2048]{1,0,2} bitcast(add.619)
-  all-gather.15 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2063), channel_id=76, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
-  bitcast.1263 = bf16[65536,4096]{0,1} bitcast(all-gather.15)
-  bitcast.1269 = bf16[4096,2048]{1,0} bitcast(all-gather.9)
-  dot.48 = bf16[65536,2048]{1,0} dot(bitcast.1263, bitcast.1269), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bitcast.1381 = bf16[32,2048,16,128]{3,2,1,0} bitcast(dot.48)
-  transpose.122 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.1381), dimensions={0,2,1,3}
-  dot.49 = bf16[32,16,2048,2048]{3,2,1,0} dot(transpose.122, bitcast.1280), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  convert.303 = f32[32,16,2048,2048]{3,2,1,0} convert(dot.49)
-  broadcast.1298 = f32[32,16,2048]{2,1,0} broadcast(constant.2005), dimensions={}
-  multiply.720 = f32[32,16,2048]{2,1,0} multiply(reduce.93, reduce.93)
-  divide.106 = f32[32,16,2048]{2,1,0} divide(broadcast.1298, multiply.720)
-  broadcast.1299 = f32[32,16,2048,2048]{3,2,1,0} broadcast(divide.106), dimensions={0,1,2}
-  multiply.721 = f32[32,16,2048,2048]{3,2,1,0} multiply(convert.303, broadcast.1299)
-  multiply.722 = f32[32,16,2048,2048]{3,2,1,0} multiply(multiply.721, exponential.8)
-  reduce.99 = f32[32,16,2048]{2,1,0} reduce(multiply.722, constant.1968), dimensions={3}, to_apply=region_20.962
-  negate.61 = f32[32,16,2048]{2,1,0} negate(reduce.99)
-  broadcast.1305 = f32[32,16,2048,2048]{3,2,1,0} broadcast(negate.61), dimensions={0,1,2}
-  divide.108 = f32[32,16,2048,2048]{3,2,1,0} divide(convert.303, broadcast.1309)
-  add.622 = f32[32,16,2048,2048]{3,2,1,0} add(broadcast.1305, divide.108)
-  multiply.724 = f32[32,16,2048,2048]{3,2,1,0} multiply(add.622, exponential.8)
-  convert.305 = bf16[32,16,2048,2048]{3,2,1,0} convert(multiply.724)
-  dot.50 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.305, transpose.113), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  bitcast.1934 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(dot.50)
-  pad.6 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1934, constant.2006), padding=1_1x0_0x0_0x0_0x0_0
-  transpose.120 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.647), dimensions={0,1,3,2}
-  dot.51 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.305, transpose.120), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  broadcast.1307 = bf16[32,16,2048,128]{3,2,1,0} broadcast(constant.2007), dimensions={}
-  multiply.725 = bf16[32,16,2048,128]{3,2,1,0} multiply(dot.51, broadcast.1307)
-  bitcast.1941 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(multiply.725)
-  pad.7 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1941, constant.2006), padding=0_2x0_0x0_0x0_0x0_0
-  add.638 = bf16[3,32,16,2048,128]{4,3,2,1,0} add(pad.6, pad.7)
-  transpose.123 = bf16[32,16,128,2048]{3,2,1,0} transpose(bitcast.1381), dimensions={0,2,3,1}
-  dot.89 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.306, transpose.123), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  bitcast.1949 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(dot.89)
-  pad.8 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1949, constant.2006), padding=2_0x0_0x0_0x0_0x0_0
-  add.639 = bf16[3,32,16,2048,128]{4,3,2,1,0} add(add.638, pad.8)
-  transpose.127 = bf16[32,2048,3,16,128]{4,3,2,1,0} transpose(add.639), dimensions={1,3,0,2,4}
-  bitcast.1416 = bf16[65536,6144]{1,0} bitcast(transpose.127)
-  dot.52 = bf16[65536,4096]{0,1} dot(bitcast.1416, bitcast.1420), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bitcast.1424 = bf16[32,2048,4096]{1,0,2} bitcast(dot.52)
-  reduce-scatter.2 = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.1424), channel_id=324, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add
-  bitcast.1851 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter.2)
-  multiply.732 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1851, broadcast.1338)
-  convert.308 = f32[2048,32,2048]{2,1,0} convert(multiply.732)
-  multiply.727 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, convert.308)
-  bitcast.1434 = f32[32,2048,2048]{1,0,2} bitcast(multiply.727)
-  reduce.100 = f32[32,2048]{1,0} reduce(bitcast.1434, constant.1968), dimensions={2}, to_apply=region_20.962
-  all-reduce.33 = f32[32,2048]{1,0} all-reduce(reduce.100), channel_id=78, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
-  bitcast.1439 = f32[1,32,2048]{2,1,0} bitcast(all-reduce.33)
-  divide.110 = f32[1,32,2048]{2,1,0} divide(rsqrt.20, bitcast.1447)
-  multiply.728 = f32[1,32,2048]{2,1,0} multiply(divide.110, broadcast.1313)
-  multiply.729 = f32[1,32,2048]{2,1,0} multiply(bitcast.1439, multiply.728)
-  multiply.730 = f32[1,32,2048]{2,1,0} multiply(multiply.729, broadcast.1315)
-  bitcast.1485 = f32[32,2048]{1,0} bitcast(multiply.730)
-  broadcast.1321 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1485), dimensions={1,2}
-  multiply.734 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, broadcast.1321)
-  multiply.735 = f32[2048,32,2048]{2,1,0} multiply(convert.308, broadcast.1337)
-  add.625 = f32[2048,32,2048]{2,1,0} add(multiply.734, multiply.735)
-  negate.62 = f32[2048,32,2048]{2,1,0} negate(multiply.734)
-  bitcast.1491 = f32[32,2048,2048]{1,0,2} bitcast(negate.62)
-  reduce.101 = f32[32,2048]{1,0} reduce(bitcast.1491, constant.1968), dimensions={2}, to_apply=region_20.962
-  negate.63 = f32[2048,32,2048]{2,1,0} negate(multiply.735)
-  bitcast.1505 = f32[32,2048,2048]{1,0,2} bitcast(negate.63)
-  reduce.102 = f32[32,2048]{1,0} reduce(bitcast.1505, constant.1968), dimensions={2}, to_apply=region_20.962
-  add.626 = f32[32,2048]{1,0} add(reduce.101, reduce.102)
-  all-reduce.275 = f32[32,2048]{1,0} all-reduce(add.626), channel_id=336, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
-  multiply.736 = f32[32,2048]{1,0} multiply(all-reduce.275, broadcast.1243)
-  broadcast.1323 = f32[2048,32,2048]{2,1,0} broadcast(multiply.736), dimensions={1,2}
-  add.628 = f32[2048,32,2048]{2,1,0} add(add.625, broadcast.1323)
-  convert.309 = bf16[2048,32,2048]{2,1,0} convert(add.628)
-  add.629 = bf16[2048,32,2048]{2,1,0} add(add.616, convert.309)
-  bitcast.1525 = bf16[32,2048,2048]{1,0,2} bitcast(add.629)
-  get-tuple-element.47 = bf16[24,8192]{1,0} get-tuple-element(param.3), index=2
-  reduce.103 = bf16[8192]{0} reduce(add.635, constant.2006), dimensions={0,1}, to_apply=add
-  all-reduce.36 = bf16[8192]{0} all-reduce(reduce.103), channel_id=81, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
-  bitcast.1583 = bf16[1,8192]{1,0} bitcast(all-reduce.36)
-  dynamic-update-slice.28 = bf16[24,8192]{1,0} dynamic-update-slice(get-tuple-element.47, bitcast.1583, select.372, constant.1980)
-  get-tuple-element.48 = bf16[24,1024,8192]{2,1,0} get-tuple-element(param.3), index=3
-  all-gather.16 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1620), channel_id=82, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
-  bitcast.1625 = bf16[4096,65536]{1,0} bitcast(all-gather.16)
-  dot.53 = bf16[4096,8192]{1,0} dot(bitcast.1625, bitcast.1629), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  reduce-scatter.3 = bf16[1024,8192]{1,0} reduce-scatter(dot.53), channel_id=325, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={0}, to_apply=add
-  bitcast.1634 = bf16[1,1024,8192]{2,1,0} bitcast(reduce-scatter.3)
-  dynamic-update-slice.29 = bf16[24,1024,8192]{2,1,0} dynamic-update-slice(get-tuple-element.48, bitcast.1634, select.372, constant.1980, constant.1980)
-  get-tuple-element.49 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=4
-  collective-permute.2 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.49), channel_id=85, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
-  all-gather.17 = bf16[24,2048]{0,1} all-gather(collective-permute.2), channel_id=86, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
-  bitcast.1649 = bf16[2048,24]{1,0} bitcast(all-gather.17)
-  reduce.104 = bf16[2048]{0} reduce(bitcast.1739, constant.2006), dimensions={0,1}, to_apply=add
-  all-reduce.38 = bf16[2048]{0} all-reduce(reduce.104), channel_id=84, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
-  bitcast.1671 = bf16[2048,1]{1,0} bitcast(all-reduce.38)
-  dynamic-update-slice.30 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1649, bitcast.1671, constant.1980, select.372)
-  constant.2013 = s32[8]{0} constant({0, 2048, 0, 2048, 1024, 3072, 1024, 3072})
-  partition-id.3 = u32[] partition-id()
-  dynamic-slice.336 = s32[1]{0} dynamic-slice(constant.2013, partition-id.3), dynamic_slice_sizes={1}
-  constant.2014 = s32[8]{0} constant({0, 2048, 0, 2048, 0, 2048, 0, 2048})
-  dynamic-slice.337 = s32[1]{0} dynamic-slice(constant.2014, partition-id.3), dynamic_slice_sizes={1}
-  subtract.232 = s32[1]{0} subtract(dynamic-slice.336, dynamic-slice.337)
-  bitcast.2087 = s32[] bitcast(subtract.232)
-  dynamic-slice.338 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.30, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
-  bitcast.1695 = bf16[24,1024]{0,1} bitcast(dynamic-slice.338)
-  collective-permute.9 = bf16[24,1024]{0,1} collective-permute(bitcast.1695), channel_id=109, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
-  get-tuple-element.50 = bf16[24,8192,1024]{1,2,0} get-tuple-element(param.3), index=5
-  bitcast.1698 = bf16[24,1024,8192]{2,1,0} bitcast(get-tuple-element.50)
-  multiply.748 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.747)
-  multiply.749 = bf16[32,2048,8192]{2,1,0} multiply(multiply.748, broadcast.1335)
-  bitcast.1735 = bf16[8192,65536]{0,1} bitcast(multiply.749)
-  all-gather.18 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1739), channel_id=87, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
-  bitcast.1743 = bf16[65536,4096]{0,1} bitcast(all-gather.18)
-  dot.54 = bf16[8192,4096]{0,1} dot(bitcast.1735, bitcast.1743), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  reduce-scatter.4 = bf16[8192,1024]{0,1} reduce-scatter(dot.54), channel_id=326, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-  bitcast.1748 = bf16[1,1024,8192]{2,1,0} bitcast(reduce-scatter.4)
-  dynamic-update-slice.31 = bf16[24,1024,8192]{2,1,0} dynamic-update-slice(bitcast.1698, bitcast.1748, select.372, constant.1980, constant.1980)
-  bitcast.1758 = bf16[24,8192,1024]{1,2,0} bitcast(dynamic-update-slice.31)
-  get-tuple-element.51 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=6
-  collective-permute.3 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.51), channel_id=90, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
-  all-gather.19 = bf16[24,2048]{0,1} all-gather(collective-permute.3), channel_id=91, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
-  bitcast.1763 = bf16[2048,24]{1,0} bitcast(all-gather.19)
-  reduce.105 = bf16[2048]{0} reduce(reduce-scatter.1, constant.2006), dimensions={0,1}, to_apply=add
-  all-reduce.40 = bf16[2048]{0} all-reduce(reduce.105), channel_id=89, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
-  bitcast.1779 = bf16[2048,1]{1,0} bitcast(all-reduce.40)
-  dynamic-update-slice.32 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1763, bitcast.1779, constant.1980, select.372)
-  dynamic-slice.339 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.32, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
-  bitcast.1794 = bf16[24,1024]{0,1} bitcast(dynamic-slice.339)
-  collective-permute.10 = bf16[24,1024]{0,1} collective-permute(bitcast.1794), channel_id=110, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
-  get-tuple-element.52 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=7
-  collective-permute.4 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.52), channel_id=93, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
-  all-gather.20 = bf16[24,2048]{0,1} all-gather(collective-permute.4), channel_id=94, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
-  bitcast.1801 = bf16[2048,24]{1,0} bitcast(all-gather.20)
-  multiply.751 = bf16[2048,32,2048]{2,1,0} multiply(convert.312, bitcast.1766)
-  bitcast.1817 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.751)
-  reduce.106 = bf16[2048]{0} reduce(bitcast.1817, constant.2006), dimensions={0,1}, to_apply=add
-  all-reduce.41 = bf16[2048]{0} all-reduce(reduce.106), channel_id=92, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
-  bitcast.1826 = bf16[2048,1]{1,0} bitcast(all-reduce.41)
-  dynamic-update-slice.33 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1801, bitcast.1826, constant.1980, select.372)
-  dynamic-slice.340 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.33, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
-  bitcast.1841 = bf16[24,1024]{0,1} bitcast(dynamic-slice.340)
-  collective-permute.11 = bf16[24,1024]{0,1} collective-permute(bitcast.1841), channel_id=111, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
-  get-tuple-element.53 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=8
-  collective-permute.5 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.53), channel_id=96, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
-  all-gather.21 = bf16[24,2048]{0,1} all-gather(collective-permute.5), channel_id=97, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
-  bitcast.1848 = bf16[2048,24]{1,0} bitcast(all-gather.21)
-  reduce.107 = bf16[2048]{0} reduce(reduce-scatter.2, constant.2006), dimensions={0,1}, to_apply=add
-  all-reduce.42 = bf16[2048]{0} all-reduce(reduce.107), channel_id=95, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
-  bitcast.1864 = bf16[2048,1]{1,0} bitcast(all-reduce.42)
-  dynamic-update-slice.34 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1848, bitcast.1864, constant.1980, select.372)
-  dynamic-slice.341 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.34, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
-  bitcast.1879 = bf16[24,1024]{0,1} bitcast(dynamic-slice.341)
-  collective-permute.12 = bf16[24,1024]{0,1} collective-permute(bitcast.1879), channel_id=112, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
-  get-tuple-element.54 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=9
-  collective-permute.6 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.54), channel_id=99, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
-  all-gather.22 = bf16[24,2048]{0,1} all-gather(collective-permute.6), channel_id=100, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
-  bitcast.1886 = bf16[2048,24]{1,0} bitcast(all-gather.22)
-  multiply.753 = bf16[2048,32,2048]{2,1,0} multiply(convert.314, bitcast.1851)
-  bitcast.1905 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.753)
-  reduce.108 = bf16[2048]{0} reduce(bitcast.1905, constant.2006), dimensions={0,1}, to_apply=add
-  all-reduce.43 = bf16[2048]{0} all-reduce(reduce.108), channel_id=98, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
-  bitcast.1914 = bf16[2048,1]{1,0} bitcast(all-reduce.43)
-  dynamic-update-slice.35 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1886, bitcast.1914, constant.1980, select.372)
-  dynamic-slice.342 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.35, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
-  bitcast.1929 = bf16[24,1024]{0,1} bitcast(dynamic-slice.342)
-  collective-permute.13 = bf16[24,1024]{0,1} collective-permute(bitcast.1929), channel_id=113, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
-  get-tuple-element.55 = bf16[24,3,16,128]{3,2,1,0} get-tuple-element(param.3), index=10
-  bitcast.1979 = bf16[3,32,2048,16,128]{4,2,3,1,0} bitcast(add.639)
-  reduce.109 = bf16[3,16,128]{2,1,0} reduce(bitcast.1979, constant.2006), dimensions={1,2}, to_apply=add
-  all-reduce.44 = bf16[3,16,128]{2,1,0} all-reduce(reduce.109), channel_id=101, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
-  bitcast.1963 = bf16[1,3,16,128]{3,2,1,0} bitcast(all-reduce.44)
-  dynamic-update-slice.36 = bf16[24,3,16,128]{3,2,1,0} dynamic-update-slice(get-tuple-element.55, bitcast.1963, select.372, constant.1980, constant.1980, /*index=5*/constant.1980)
-  get-tuple-element.56 = bf16[24,3,1024,16,128]{4,3,1,2,0} get-tuple-element(param.3), index=11
-  bitcast.1974 = bf16[24,1024,3,16,128]{4,3,2,1,0} bitcast(get-tuple-element.56)
-  transpose.130 = bf16[3,16,128,32,2048]{4,3,2,1,0} transpose(add.639), dimensions={0,2,4,1,3}
-  bitcast.1983 = bf16[6144,65536]{1,0} bitcast(transpose.130)
-  all-gather.23 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2003), channel_id=102, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
-  bitcast.2007 = bf16[65536,4096]{0,1} bitcast(all-gather.23)
-  dot.55 = bf16[6144,4096]{0,1} dot(bitcast.1983, bitcast.2007), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bitcast.2011 = bf16[3,16,128,4096]{2,1,0,3} bitcast(dot.55)
-  reduce-scatter.5 = bf16[3,16,128,1024]{2,1,0,3} reduce-scatter(bitcast.2011), channel_id=327, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={3}, to_apply=add
-  bitcast.2015 = bf16[1,1024,3,16,128]{4,3,2,1,0} bitcast(reduce-scatter.5)
-  dynamic-update-slice.37 = bf16[24,1024,3,16,128]{4,3,2,1,0} dynamic-update-slice(bitcast.1974, bitcast.2015, select.372, constant.1980, constant.1980, /*index=5*/constant.1980, constant.1980)
-  bitcast.2025 = bf16[24,3,1024,16,128]{4,3,1,2,0} bitcast(dynamic-update-slice.37)
-  get-tuple-element.57 = bf16[24,1024]{1,0} get-tuple-element(param.3), index=12
-  reduce.110 = bf16[2048]{0} reduce(bitcast.2063, constant.2006), dimensions={0,1}, to_apply=add
-  all-reduce.46 = bf16[2048]{0} all-reduce(reduce.110), channel_id=104, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
-  dynamic-slice.343 = bf16[1024]{0} dynamic-slice(all-reduce.46, bitcast.2087), dynamic_slice_sizes={1024}
-  bitcast.2046 = bf16[1,1024]{1,0} bitcast(dynamic-slice.343)
-  collective-permute.7 = bf16[1,1024]{1,0} collective-permute(bitcast.2046), channel_id=105, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
-  dynamic-update-slice.38 = bf16[24,1024]{1,0} dynamic-update-slice(get-tuple-element.57, collective-permute.7, select.372, constant.1980)
-  get-tuple-element.58 = bf16[24,1024,16,128]{3,2,1,0} get-tuple-element(param.3), index=13
-  bitcast.2066 = bf16[2048,65536]{1,0} bitcast(add.619)
-  transpose.133 = bf16[16,32,2048,128]{3,2,1,0} transpose(dot.44), dimensions={1,0,3,2}
-  bitcast.2072 = bf16[32,2048,16,128]{3,1,0,2} bitcast(transpose.133)
-  all-gather.24 = bf16[32,2048,32,128]{3,1,0,2} all-gather(bitcast.2072), channel_id=106, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
-  bitcast.2073 = bf16[32,32,2048,128]{3,2,1,0} bitcast(all-gather.24)
-  transpose.134 = bf16[32,2048,32,128]{3,2,1,0} transpose(bitcast.2073), dimensions={1,2,0,3}
-  bitcast.2077 = bf16[65536,4096]{1,0} bitcast(transpose.134)
-  dot.56 = bf16[2048,4096]{1,0} dot(bitcast.2066, bitcast.2077), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bitcast.2081 = bf16[2048,32,128]{2,1,0} bitcast(dot.56)
-  all-reduce.47 = bf16[2048,32,128]{2,1,0} all-reduce(bitcast.2081), channel_id=107, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
-  constant.2015 = s32[8]{0} constant({0, 0, 16, 16, 0, 0, 16, 16})
-  dynamic-slice.344 = s32[1]{0} dynamic-slice(constant.2015, partition-id.3), dynamic_slice_sizes={1}
-  bitcast.2095 = s32[] bitcast(dynamic-slice.344)
-  dynamic-slice.345 = bf16[1024,16,128]{2,1,0} dynamic-slice(all-reduce.47, bitcast.2087, bitcast.2095, constant.1980), dynamic_slice_sizes={1024,16,128}
-  bitcast.2102 = bf16[1,1024,16,128]{3,2,1,0} bitcast(dynamic-slice.345)
-  collective-permute.8 = bf16[1,1024,16,128]{3,2,1,0} collective-permute(bitcast.2102), channel_id=108, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
-  dynamic-update-slice.39 = bf16[24,1024,16,128]{3,2,1,0} dynamic-update-slice(get-tuple-element.58, collective-permute.8, select.372, constant.1980, constant.1980, /*index=5*/constant.1980)
-  ROOT tuple.2 = (s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}) tuple(add.581, bitcast.1525, dynamic-update-slice.28, dynamic-update-slice.29, collective-permute.9, /*index=5*/bitcast.1758, collective-permute.10, collective-permute.11, collective-permute.12, collective-permute.13, /*index=10*/dynamic-update-slice.36, bitcast.2025, dynamic-update-slice.38, dynamic-update-slice.39, get-tuple-element.45, /*index=15*/get-tuple-element.44, get-tuple-element.46, get-tuple-element.43, get-tuple-element.42, get-tuple-element.37, /*index=20*/get-tuple-element.36, get-tuple-element.38, get-tuple-element.35, get-tuple-element.41, get-tuple-element.40, /*index=25*/get-tuple-element.32, get-tuple-element.39, get-tuple-element.33)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  HloInstruction* fwd_instruction = nullptr;
-  HloInstruction* bwd_instruction = nullptr;
-  SCOPED_TRACE(m->ToString());
-  for (HloInstruction* instr :
-       m->entry_computation()->MakeInstructionPostOrder()) {
-    if (instr->opcode() == HloOpcode::kCustomCall &&
-        instr->custom_call_target() == kCudnnfMHASoftmaxCallTarget) {
-      fwd_instruction = instr;
-    }
-    if (instr->opcode() == HloOpcode::kCustomCall &&
-        instr->custom_call_target() == kCudnnfMHASoftmaxBackwardCallTarget) {
-      bwd_instruction = instr;
-    }
-  }
-  EXPECT_NE(fwd_instruction, nullptr);
-  EXPECT_NE(bwd_instruction, nullptr);
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fwd_instruction->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::CAUSAL);
-}
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest,
-       BF16TrainingBmm2CanonicalizationRestoreFwdGraph) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  const char* module_str = R"(
-HloModule pjit__unnamed_function_, entry_computation_layout={(bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,4,256,256]{3,2,1,0})->(bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={false,false,false,false}, num_partitions=4
-
-region_0.6 {
-  Arg_0.7 = bf16[] parameter(0)
-  Arg_1.8 = bf16[] parameter(1)
-  ROOT maximum.5 = bf16[] maximum(Arg_0.7, Arg_1.8)
-}
-
-region_1.10 {
-  Arg_0.11 = f32[] parameter(0)
-  Arg_1.12 = f32[] parameter(1)
-  ROOT add.14 = f32[] add(Arg_0.11, Arg_1.12)
-}
-
-add.clone {
-  x.1 = u32[] parameter(0)
-  y.1 = u32[] parameter(1)
-  ROOT add.15 = u32[] add(x.1, y.1)
-}
-
-region_2.65 {
-  Arg_0.66 = bf16[] parameter(0)
-  Arg_1.67 = bf16[] parameter(1)
-  ROOT add.16 = bf16[] add(Arg_0.66, Arg_1.67)
-}
-
-ENTRY main.164_spmd {
-  param = bf16[2,256,4,64]{3,2,1,0} parameter(2), sharding={devices=[2,1,2,1]<=[4]}
-  transpose.26 = bf16[2,4,64,256]{3,2,1,0} transpose(param), dimensions={0,2,3,1}
-  param.1 = bf16[2,256,4,64]{3,2,1,0} parameter(0), sharding={devices=[2,1,2,1]<=[4]}
-  transpose.27 = bf16[2,4,256,64]{3,2,1,0} transpose(param.1), dimensions={0,2,1,3}
-  constant.46 = bf16[] constant(0.5)
-  broadcast.126 = bf16[2,4,256,64]{3,2,1,0} broadcast(constant.46), dimensions={}
-  multiply.34 = bf16[2,4,256,64]{3,2,1,0} multiply(transpose.27, broadcast.126)
-  param.2 = bf16[2,256,4,64]{3,2,1,0} parameter(1), sharding={devices=[2,1,2,1]<=[4]}
-  transpose.29 = bf16[2,4,64,256]{3,2,1,0} transpose(param.2), dimensions={0,2,3,1}
-  dot.12 = bf16[2,4,256,256]{3,2,1,0} dot(multiply.34, transpose.29), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  param.3 = bf16[2,4,256,256]{3,2,1,0} parameter(4), sharding={devices=[2,2,1,1]<=[4]}
-  add.17 = bf16[2,4,256,256]{3,2,1,0} add(dot.12, param.3)
-  constant.47 = bf16[] constant(-inf)
-  reduce.4 = bf16[2,4,256]{2,1,0} reduce(add.17, constant.47), dimensions={3}, to_apply=region_0.6
-  broadcast.127 = bf16[2,4,256,256]{3,2,1,0} broadcast(reduce.4), dimensions={0,1,2}
-  subtract.14 = bf16[2,4,256,256]{3,2,1,0} subtract(add.17, broadcast.127)
-  exponential.2 = bf16[2,4,256,256]{3,2,1,0} exponential(subtract.14)
-  convert.46 = f32[2,4,256,256]{3,2,1,0} convert(exponential.2)
-  constant.48 = f32[] constant(0)
-  reduce.5 = f32[2,4,256]{2,1,0} reduce(convert.46, constant.48), dimensions={3}, to_apply=region_1.10
-  convert.47 = bf16[2,4,256]{2,1,0} convert(reduce.5)
-  broadcast.128 = bf16[2,4,256,256]{3,2,1,0} broadcast(convert.47), dimensions={0,1,2}
-  divide.7 = bf16[2,4,256,256]{3,2,1,0} divide(exponential.2, broadcast.128)
-  broadcast.129 = f32[4096]{0} broadcast(constant.48), dimensions={}
-  constant.50 = u32[] constant(0)
-  broadcast.131 = u32[8192]{0} broadcast(constant.50), dimensions={}
-  broadcast.133 = u32[4096]{0} broadcast(constant.50), dimensions={}
-  iota.3 = u32[8192]{0} iota(), iota_dimension=0
-  slice.14 = u32[4096]{0} slice(iota.3), slice={[0:4096]}
-  slice.15 = u32[4096]{0} slice(iota.3), slice={[4096:8192]}
-  custom-call.3 = (u32[4096]{0}, u32[4096]{0}) custom-call(broadcast.133, broadcast.133, slice.14, slice.15), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[4096]{0}, u32[4096]{0}, u32[4096]{0}, u32[4096]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\020\000\000\000\000\000\000"
-  get-tuple-element.6 = u32[4096]{0} get-tuple-element(custom-call.3), index=0
-  constant.115 = u32[1]{0} constant({0})
-  constant.52 = u32[4]{0} constant({0, 0, 1, 1})
-  partition-id = u32[] partition-id()
-  dynamic-slice.21 = u32[1]{0} dynamic-slice(constant.52, partition-id), dynamic_slice_sizes={1}
-  constant.116 = u32[1]{0} constant({1})
-  clamp.3 = u32[1]{0} clamp(constant.115, dynamic-slice.21, constant.116)
-  convert.48 = s32[1]{0} convert(clamp.3)
-  constant.117 = s32[1]{0} constant({2048})
-  multiply.35 = s32[1]{0} multiply(convert.48, constant.117)
-  bitcast.105 = s32[] bitcast(multiply.35)
-  dynamic-slice.22 = u32[2048]{0} dynamic-slice(get-tuple-element.6, bitcast.105), dynamic_slice_sizes={2048}
-  constant.58 = s32[4]{0} constant({0, 0, 1, 1})
-  dynamic-slice.23 = s32[1]{0} dynamic-slice(constant.58, partition-id), dynamic_slice_sizes={1}
-  multiply.36 = s32[1]{0} multiply(dynamic-slice.23, constant.117)
-  bitcast.108 = s32[] bitcast(multiply.36)
-  dynamic-update-slice.2 = u32[8192]{0} dynamic-update-slice(broadcast.131, dynamic-slice.22, bitcast.108)
-  get-tuple-element.7 = u32[4096]{0} get-tuple-element(custom-call.3), index=1
-  dynamic-slice.24 = u32[2048]{0} dynamic-slice(get-tuple-element.7, bitcast.105), dynamic_slice_sizes={2048}
-  constant.65 = s32[] constant(4096)
-  add.18 = s32[] add(bitcast.108, constant.65)
-  dynamic-update-slice.3 = u32[8192]{0} dynamic-update-slice(dynamic-update-slice.2, dynamic-slice.24, add.18)
-  all-reduce = u32[8192]{0} all-reduce(dynamic-update-slice.3), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=add.clone
-  constant.118 = s32[1]{0} constant({4096})
-  multiply.37 = s32[1]{0} multiply(dynamic-slice.23, constant.118)
-  bitcast.119 = s32[] bitcast(multiply.37)
-  dynamic-slice.25 = u32[4096]{0} dynamic-slice(all-reduce, bitcast.119), dynamic_slice_sizes={4096}
-  constant.69 = u32[] constant(9)
-  broadcast.134 = u32[4096]{0} broadcast(constant.69), dimensions={}
-  shift-right-logical.6 = u32[4096]{0} shift-right-logical(dynamic-slice.25, broadcast.134)
-  constant.70 = u32[] constant(1065353216)
-  broadcast.135 = u32[4096]{0} broadcast(constant.70), dimensions={}
-  or.5 = u32[4096]{0} or(shift-right-logical.6, broadcast.135)
-  bitcast-convert.5 = f32[4096]{0} bitcast-convert(or.5)
-  constant.71 = f32[] constant(-1)
-  broadcast.136 = f32[4096]{0} broadcast(constant.71), dimensions={}
-  add.19 = f32[4096]{0} add(bitcast-convert.5, broadcast.136)
-  maximum.6 = f32[4096]{0} maximum(broadcast.129, add.19)
-  constant.72 = f32[] constant(0.5)
-  broadcast.137 = f32[4096]{0} broadcast(constant.72), dimensions={}
-  compare.4 = pred[4096]{0} compare(maximum.6, broadcast.137), direction=LT
-  bitcast.135 = pred[2,8,256]{2,1,0} bitcast(compare.4)
-  convert.49 = bf16[2,8,256]{2,1,0} convert(bitcast.135)
-  constant.80 = s32[] constant(0)
-  constant.78 = s32[4]{0} constant({0, 4, 0, 4})
-  dynamic-slice.26 = s32[1]{0} dynamic-slice(constant.78, partition-id), dynamic_slice_sizes={1}
-  bitcast.181 = s32[] bitcast(dynamic-slice.26)
-  dynamic-slice.27 = bf16[2,4,256]{2,1,0} dynamic-slice(convert.49, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256}
-  broadcast.139 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.27), dimensions={0,1,3}
-  multiply.38 = bf16[2,4,256,256]{3,2,1,0} multiply(divide.7, broadcast.139)
-  constant.93 = bf16[] constant(2)
-  broadcast.141 = bf16[2,4,256,256]{3,2,1,0} broadcast(constant.93), dimensions={}
-  multiply.39 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.38, broadcast.141)
-  dot.13 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.26, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  transpose.31 = bf16[4,2,64,256]{3,2,1,0} transpose(dot.13), dimensions={1,0,2,3}
-  bitcast.154 = bf16[2,256,4,64]{1,3,0,2} bitcast(transpose.31)
-  all-gather = bf16[2,256,8,64]{1,3,0,2} all-gather(bitcast.154), channel_id=2, replica_groups={{0,1},{2,3}}, dimensions={2}, use_global_device_ids=true
-  bitcast.155 = bf16[8,2,64,256]{3,2,1,0} bitcast(all-gather)
-  transpose.32 = bf16[2,8,64,256]{3,2,1,0} transpose(bitcast.155), dimensions={1,0,2,3}
-  bitcast.157 = bf16[2,256,8,64]{1,3,2,0} bitcast(transpose.32)
-  all-gather.1 = bf16[4,256,8,64]{1,3,2,0} all-gather(bitcast.157), channel_id=3, replica_groups={{0,2},{1,3}}, dimensions={0}, use_global_device_ids=true
-  bitcast.236 = bf16[4,8,64,256]{3,2,1,0} bitcast(all-gather.1)
-  transpose.38 = bf16[4,256,8,64]{3,2,1,0} transpose(bitcast.236), dimensions={0,3,1,2}
-  param.4 = bf16[2,256,4,64]{3,2,1,0} parameter(3), sharding={devices=[2,1,2,1]<=[4]}
-  transpose.33 = bf16[2,4,256,64]{3,2,1,0} transpose(param.4), dimensions={0,2,1,3}
-  dot.14 = bf16[2,4,256,256]{3,2,1,0} dot(transpose.33, transpose.26), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  broadcast.142 = bf16[4096]{0} broadcast(constant.93), dimensions={}
-  constant.95 = bf16[] constant(0)
-  broadcast.143 = bf16[4096]{0} broadcast(constant.95), dimensions={}
-  select.4 = bf16[4096]{0} select(compare.4, broadcast.142, broadcast.143)
-  bitcast.176 = bf16[2,8,256]{2,1,0} bitcast(select.4)
-  dynamic-slice.28 = bf16[2,4,256]{2,1,0} dynamic-slice(bitcast.176, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256}
-  broadcast.145 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.28), dimensions={0,1,3}
-  multiply.40 = bf16[2,4,256,256]{3,2,1,0} multiply(dot.14, broadcast.145)
-  divide.8 = bf16[2,4,256,256]{3,2,1,0} divide(multiply.40, broadcast.128)
-  constant.106 = bf16[] constant(1)
-  broadcast.146 = bf16[2,4,256]{2,1,0} broadcast(constant.106), dimensions={}
-  multiply.41 = bf16[2,4,256]{2,1,0} multiply(convert.47, convert.47)
-  divide.9 = bf16[2,4,256]{2,1,0} divide(broadcast.146, multiply.41)
-  broadcast.147 = bf16[2,4,256,256]{3,2,1,0} broadcast(divide.9), dimensions={0,1,2}
-  multiply.42 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.40, broadcast.147)
-  multiply.43 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.42, exponential.2)
-  reduce.6 = bf16[2,4,256]{2,1,0} reduce(multiply.43, constant.95), dimensions={3}, to_apply=region_2.65
-  negate.4 = bf16[2,4,256]{2,1,0} negate(reduce.6)
-  broadcast.148 = bf16[2,4,256,256]{3,2,1,0} broadcast(negate.4), dimensions={0,1,2}
-  add.20 = bf16[2,4,256,256]{3,2,1,0} add(divide.8, broadcast.148)
-  multiply.44 = bf16[2,4,256,256]{3,2,1,0} multiply(add.20, exponential.2)
-  transpose.34 = bf16[2,4,256,64]{3,2,1,0} transpose(param.2), dimensions={0,2,1,3}
-  dot.15 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, transpose.34), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  multiply.45 = bf16[2,4,256,64]{3,2,1,0} multiply(dot.15, broadcast.126)
-  transpose.39 = bf16[2,256,4,64]{3,2,1,0} transpose(multiply.45), dimensions={0,2,1,3}
-  dot.16 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, multiply.34), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  transpose.40 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.16), dimensions={0,2,1,3}
-  transpose.36 = bf16[2,4,64,256]{3,2,1,0} transpose(param.4), dimensions={0,2,3,1}
-  dot.11 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.36, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  transpose.41 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.11), dimensions={0,3,1,2}
-  ROOT tuple.2 = (bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}) tuple(transpose.38, transpose.39, transpose.40, transpose.41)
-} // main.164_spmd
-)";
-  // Dropout bwd pattern is not supported, so fwd should not be lowered either.
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-  SCOPED_TRACE(m->ToString());
-  // check if fwd graph has been restored with cloned activation
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::Transpose(), m::Transpose(), m::Transpose(),
-          m::Transpose(m::Dot(
-              m::Op(), m::Op().WithPredicate([](const HloInstruction* instr) {
-                return instr->name() == "multiply.39.fmha_no_match_clone";
-              }))))));
-}
-
-constexpr absl::string_view hlo_head_dim_not_multiple_of_64 = R"(
-HloModule jit__reference, entry_computation_layout={(f16[4,48,1024,16]{3,2,1,0}, f16[4,48,1024,16]{3,2,1,0}, f16[4,48,1024,16]{3,2,1,0})->f16[4,48,1024,16]{3,2,1,0}}
-
-region_0.26 {
-  Arg_0.27 = f32[] parameter(0)
-  Arg_1.28 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(Arg_0.27, Arg_1.28)
-}
-
-region_1.37 {
-  Arg_0.38 = f32[] parameter(0)
-  Arg_1.39 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.38, Arg_1.39)
-}
-
-ENTRY main.49 {
-  iota.2 = s32[1024,1024]{1,0} iota(), iota_dimension=0
-  iota.3 = s32[1024,1024]{1,0} iota(), iota_dimension=1
-  compare = pred[1024,1024]{1,0} compare(iota.2, iota.3), direction=GE
-  broadcast.4 = pred[4,48,1024,1024]{3,2,1,0} broadcast(compare), dimensions={2,3}
-  Arg_0.1 = f16[4,48,1024,16]{3,2,1,0} parameter(0)
-  Arg_1.2 = f16[4,48,1024,16]{3,2,1,0} parameter(1)
-  dot.9 = f16[4,48,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  constant.4 = f16[] constant(0.5)
-  broadcast.6 = f16[4,48,1024,1024]{3,2,1,0} broadcast(constant.4), dimensions={}
-  multiply = f16[4,48,1024,1024]{3,2,1,0} multiply(dot.9, broadcast.6)
-  convert.1 = f32[4,48,1024,1024]{3,2,1,0} convert(multiply)
-  constant.7 = f32[] constant(-inf)
-  reduce.30 = f32[4,48,1024]{2,1,0} reduce(convert.1, constant.7), dimensions={3}, to_apply=region_0.26
-  broadcast.8 = f32[4,48,1024,1024]{3,2,1,0} broadcast(reduce.30), dimensions={0,1,2}
-  subtract = f32[4,48,1024,1024]{3,2,1,0} subtract(convert.1, broadcast.8)
-  exponential = f32[4,48,1024,1024]{3,2,1,0} exponential(subtract)
-  constant.6 = f32[] constant(0)
-  reduce.41 = f32[4,48,1024]{2,1,0} reduce(exponential, constant.6), dimensions={3}, to_apply=region_1.37
-  broadcast.9 = f32[4,48,1024,1024]{3,2,1,0} broadcast(reduce.41), dimensions={0,1,2}
-  divide = f32[4,48,1024,1024]{3,2,1,0} divide(exponential, broadcast.9)
-  convert.2 = f16[4,48,1024,1024]{3,2,1,0} convert(divide)
-  Arg_2.3 = f16[4,48,1024,16]{3,2,1,0} parameter(2)
-  ROOT dot.48 = f16[4,48,1024,16]{3,2,1,0} dot(convert.2, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-} // main.49
-)";
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, HeadDimNotMultipleOf64) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto m, ParseAndReturnVerifiedModule(hlo_head_dim_not_multiple_of_64,
-                                           GetModuleConfig()));
-  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
-                                         GetCudnnVersion()};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-
-  // A head dim that is not a multiple of 64 should not be lowered with cuDNN < 8.9.6.
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(m->entry_computation()->root_instruction(), GmockMatch(m::Dot()));
-
-  // should be lowered with cuDNN >= 8.9.6
-  CudnnFusedMHARewriter fusedMhaRewriterWithcuDNN8907{
-      GetCudaComputeCapability(), se::dnn::VersionInfo(8, 9, 7)};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriterWithcuDNN8907, m.get()).status());
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
-                     .WithShape(F16, {4, 48, 1024, 16})));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_EQ(config.fmha_scale(), 0.5);
-  EXPECT_EQ(config.dropout_rate(), 0.0);
-}
-
-constexpr absl::string_view hlo_BF16Bmm1BiasSoftmaxBmm2Pattern_dbias = R"(
-HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0})->(bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true,true,true,true}
-
-region_0.14 {
-  Arg_0.15 = bf16[] parameter(0)
-  Arg_1.16 = bf16[] parameter(1)
-  ROOT maximum = bf16[] maximum(Arg_0.15, Arg_1.16)
-}
-
-region_1.27 {
-  Arg_0.28 = f32[] parameter(0)
-  Arg_1.29 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0.28, Arg_1.29)
-}
-
-region_2.56 {
-  Arg_0.57 = bf16[] parameter(0)
-  Arg_1.58 = bf16[] parameter(1)
-  ROOT add.1 = bf16[] add(Arg_0.57, Arg_1.58)
-}
-
-ENTRY main.87 {
-  Arg_2.3 = bf16[2,1024,4,64]{3,2,1,0} parameter(2)
-  transpose.12 = bf16[2,4,64,1024]{3,2,1,0} transpose(Arg_2.3), dimensions={0,2,3,1}
-  Arg_0.1 = bf16[2,1024,4,64]{3,2,1,0} parameter(0)
-  transpose.13 = bf16[2,4,1024,64]{3,2,1,0} transpose(Arg_0.1), dimensions={0,2,1,3}
-  Arg_1.2 = bf16[2,1024,4,64]{3,2,1,0} parameter(1)
-  transpose.15 = bf16[2,4,64,1024]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,3,1}
-  dot = bf16[2,4,1024,1024]{3,2,1,0} dot(transpose.13, transpose.15), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  Arg_4.5 = bf16[4,1024,1024]{2,1,0} parameter(4)
-  broadcast.9 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(Arg_4.5), dimensions={1,2,3}
-  add.2 = bf16[2,4,1024,1024]{3,2,1,0} add(dot, broadcast.9)
-  constant.10 = bf16[] constant(-inf)
-  reduce.18 = bf16[2,4,1024]{2,1,0} reduce(add.2, constant.10), dimensions={3}, to_apply=region_0.14
-  broadcast.10 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(reduce.18), dimensions={0,1,2}
-  subtract = bf16[2,4,1024,1024]{3,2,1,0} subtract(add.2, broadcast.10)
-  exponential = bf16[2,4,1024,1024]{3,2,1,0} exponential(subtract)
-  convert.5 = f32[2,4,1024,1024]{3,2,1,0} convert(exponential)
-  constant.9 = f32[] constant(0)
-  reduce.31 = f32[2,4,1024]{2,1,0} reduce(convert.5, constant.9), dimensions={3}, to_apply=region_1.27
-  convert.6 = bf16[2,4,1024]{2,1,0} convert(reduce.31)
-  broadcast.11 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(convert.6), dimensions={0,1,2}
-  divide.2 = bf16[2,4,1024,1024]{3,2,1,0} divide(exponential, broadcast.11)
-  dot.1 = bf16[2,4,64,1024]{3,2,1,0} dot(transpose.12, divide.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
-  transpose.22 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.1), dimensions={0,3,1,2}
-  Arg_3.4 = bf16[2,1024,4,64]{3,2,1,0} parameter(3)
-  transpose.17 = bf16[2,4,1024,64]{3,2,1,0} transpose(Arg_3.4), dimensions={0,2,1,3}
-  dot.2 = bf16[2,4,1024,1024]{3,2,1,0} dot(transpose.17, transpose.12), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  divide.3 = bf16[2,4,1024,1024]{3,2,1,0} divide(dot.2, broadcast.11)
-  constant.0 = bf16[] constant(1)
-  broadcast.13 = bf16[2,4,1024]{2,1,0} broadcast(constant.0), dimensions={}
-  multiply.2 = bf16[2,4,1024]{2,1,0} multiply(convert.6, convert.6)
-  divide.4 = bf16[2,4,1024]{2,1,0} divide(broadcast.13, multiply.2)
-  broadcast.14 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(divide.4), dimensions={0,1,2}
-  multiply.3 = bf16[2,4,1024,1024]{3,2,1,0} multiply(dot.2, broadcast.14)
-  multiply.4 = bf16[2,4,1024,1024]{3,2,1,0} multiply(multiply.3, exponential)
-  constant.8 = bf16[] constant(0)
-  reduce.60 = bf16[2,4,1024]{2,1,0} reduce(multiply.4, constant.8), dimensions={3}, to_apply=region_2.56
-  negate.1 = bf16[2,4,1024]{2,1,0} negate(reduce.60)
-  broadcast.15 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(negate.1), dimensions={0,1,2}
-  add.3 = bf16[2,4,1024,1024]{3,2,1,0} add(divide.3, broadcast.15)
-  multiply.5 = bf16[2,4,1024,1024]{3,2,1,0} multiply(add.3, exponential)
-  transpose.18 = bf16[2,4,1024,64]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,1,3}
-  dot.4 = bf16[2,4,1024,64]{3,2,1,0} dot(multiply.5, transpose.18), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  transpose.23 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.4), dimensions={0,2,1,3}
-  dot.3 = bf16[2,4,1024,64]{3,2,1,0} dot(multiply.5, transpose.13), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  transpose.24 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.3), dimensions={0,2,1,3}
-  transpose.20 = bf16[2,4,64,1024]{3,2,1,0} transpose(Arg_3.4), dimensions={0,2,3,1}
-  dot.49 = bf16[2,4,64,1024]{3,2,1,0} dot(transpose.20, divide.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-  transpose.25 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.49), dimensions={0,3,1,2}
-  reduce.81 = bf16[4,1024,1024]{2,1,0} reduce(multiply.5, constant.8), dimensions={0}, to_apply=region_2.56
-  ROOT tuple = (bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0}) tuple(transpose.22, transpose.23, transpose.24, transpose.25, reduce.81)
-} // main.87
-)";
-
-TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1BiasSoftmaxBmm2PatternDbias) {
-  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto m,
-      ParseAndReturnVerifiedModule(hlo_BF16Bmm1BiasSoftmaxBmm2Pattern_dbias));
-  // dbias requires cuDNN 8.9.6+ and Hopper.
-  CudnnFusedMHARewriter fusedMhaRewriter{se::CudaComputeCapability(9, 0),
-                                         se::dnn::VersionInfo(8, 9, 6)};
-  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  const HloInstruction* fmha;
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::Transpose(
-              m::Transpose(m::GetTupleElement(
-                  m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}),
-                  0)))
-              .WithShape(BF16, {2, 1024, 4, 64}),
-          m::Transpose(
-              m::GetTupleElement(
-                  m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
-                  0))
-              .WithShape(BF16, {2, 1024, 4, 64}),
-          m::Transpose(
-              m::GetTupleElement(
-                  m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
-                  1))
-              .WithShape(BF16, {2, 1024, 4, 64}),
-          m::Transpose(
-              m::Transpose(m::GetTupleElement(
-                  m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
-                  2)))
-              .WithShape(BF16, {2, 1024, 4, 64}),
-          m::Reshape(
-              m::GetTupleElement(
-                  m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
-                  3))
-              .WithShape(BF16, {4, 1024, 1024}))));
-  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
-  EXPECT_EQ(fmha->operands().size(), 4);
-  EXPECT_EQ(fmha->operand(3)->shape(),
-            ShapeUtil::MakeShape(BF16, {1, 4, 1024, 1024}));
-  EXPECT_EQ(config.fmha_scale(), 1.0);
-  EXPECT_EQ(config.dropout_rate(), 0.0);
-  EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK);
-}
-}  // anonymous namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc
deleted file mode 100644
index 665cc0b..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc
+++ /dev/null
@@ -1,668 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_fused_mha_transpose_fusion.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <iterator>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/permutation_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-namespace m = match;
-
-bool IsFMHACustomCall(const HloInstruction* instr) {
-  return IsCustomCallTofMHA(*instr);
-}
-
-bool IsFwdFMHACustomCall(const HloInstruction* instr) {
-  return IsFwdCustomCallTofMHA(*instr);
-}
-
-bool IsBwdFMHACustomCall(const HloInstruction* instr) {
-  return IsBwdCustomCallTofMHA(*instr);
-}
-
-absl::StatusOr<bool> FuseArgPrologueTransposeWithcuDNNFMHA(
-    HloInstruction* fmha, int64_t operand_index, bool is_lhs,
-    bool should_contracting_be_fastest) {
-  HloInstruction* transpose_arg = fmha->mutable_operand(operand_index);
-  HloInstruction* transpose_arg_operand = transpose_arg->mutable_operand(0);
-  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                      fmha->backend_config<GpuBackendConfig>());
-  CudnnfMHABackendConfig config = gpu_config.cudnn_fmha_backend_config();
-  CudnnfMHABackendConfig& new_fmha_config =
-      *gpu_config.mutable_cudnn_fmha_backend_config();
-
-  std::vector<int64_t> inverse_perm =
-      InversePermutation(transpose_arg->dimensions());
-  DotDimensionNumbers new_bmm_dot_dims;
-  if (IsFwdCustomCallTofMHA(*fmha)) {
-    if (operand_index == 0 || operand_index == 1) {
-      new_bmm_dot_dims = config.bmm1_dot_dimension_numbers();
-    } else {
-      new_bmm_dot_dims = config.bmm2_dot_dimension_numbers();
-    }
-  } else {
-    switch (operand_index) {
-      case 0:
-        // Q
-        new_bmm_dot_dims = config.bmm1_grad_gemm1_dot_dimension_numbers();
-        break;
-      case 1:
-        // K
-        new_bmm_dot_dims = config.bmm1_grad_gemm2_dot_dimension_numbers();
-        break;
-      case 2:
-        // V
-        new_bmm_dot_dims = config.bmm2_grad_gemm2_dot_dimension_numbers();
-        break;
-      case 3:
-        // Forward activation
-        new_bmm_dot_dims = config.bmm2_grad_gemm1_dot_dimension_numbers();
-        break;
-      case 4:
-        // Output gradient
-        new_bmm_dot_dims = config.bmm2_grad_gemm2_dot_dimension_numbers();
-        break;
-      default:
-        return Internal("Invalid operand index.");
-    }
-  }
-  absl::Span<const int64_t> checked_dims;
-  std::vector<int64_t> checked_dims_vec;
-
-  // `should_contracting_be_fastest` indicates whether the contracting dim is
-  // the head dim. cuDNN requires the head dim to be the fastest-moving dim.
-  // Fwd bmm1 and bwd bmm2grad1 should set this value to true.
-  if (should_contracting_be_fastest) {
-    checked_dims = is_lhs ? new_bmm_dot_dims.lhs_contracting_dimensions()
-                          : new_bmm_dot_dims.rhs_contracting_dimensions();
-  } else {
-    absl::Span<const int64_t> batch_dims =
-        is_lhs ? new_bmm_dot_dims.lhs_batch_dimensions()
-               : new_bmm_dot_dims.rhs_batch_dimensions();
-    absl::Span<const int64_t> contracting_dims =
-        is_lhs ? new_bmm_dot_dims.lhs_contracting_dimensions()
-               : new_bmm_dot_dims.rhs_contracting_dimensions();
-
-    TF_ASSIGN_OR_RETURN(checked_dims_vec,
-                        GetNonContractingDims(transpose_arg->shape(),
-                                              batch_dims, contracting_dims));
-    checked_dims = checked_dims_vec;
-  }
-
-  int64_t checked_dims_bmm_size = checked_dims.size();
-  std::vector<int64_t> new_bmm_checked_dims(checked_dims_bmm_size);
-  for (int i = 0; i < checked_dims_bmm_size; i++) {
-    auto itr =
-        std::find(inverse_perm.begin(), inverse_perm.end(), checked_dims[i]);
-    if (itr == inverse_perm.end()) {
-      return Internal("Invalid inverse perm");
-    }
-    new_bmm_checked_dims[i] = std::distance(inverse_perm.begin(), itr);
-  }
-  // We want to make sure that making the transpose's argument an input to
-  // fmha doesn't break the cuDNN constraint that the head dim of the
-  // corresponding BMM operand is the fastest-moving dimension.
-  // One exception is the forward activation, which doesn't have this
-  // constraint since it has no head dim.
-  absl::Span<const int64_t> minor_to_major_bmm =
-      transpose_arg_operand->shape().layout().minor_to_major();
-  if ((minor_to_major_bmm[0] != new_bmm_checked_dims[0]) &&
-      !(IsBwdCustomCallTofMHA(*fmha) && operand_index == 3)) {
-    return false;
-  }
-  if (should_contracting_be_fastest) {
-    if (is_lhs) {
-      new_bmm_dot_dims.clear_lhs_contracting_dimensions();
-      *new_bmm_dot_dims.mutable_lhs_contracting_dimensions() = {
-          new_bmm_checked_dims.begin(), new_bmm_checked_dims.end()};
-    } else {
-      new_bmm_dot_dims.clear_rhs_contracting_dimensions();
-      *new_bmm_dot_dims.mutable_rhs_contracting_dimensions() = {
-          new_bmm_checked_dims.begin(), new_bmm_checked_dims.end()};
-    }
-  }
-  auto& batch_dims = is_lhs ? new_bmm_dot_dims.lhs_batch_dimensions()
-                            : new_bmm_dot_dims.rhs_batch_dimensions();
-  int64_t batch_dims_bmm_size = batch_dims.size();
-  std::vector<int64_t> new_bmm_batch_dims(batch_dims_bmm_size);
-  for (int i = 0; i < batch_dims_bmm_size; i++) {
-    auto itr =
-        std::find(inverse_perm.begin(), inverse_perm.end(), batch_dims[i]);
-    if (itr == inverse_perm.end()) {
-      return Internal("Invalid inverse perm");
-    }
-    new_bmm_batch_dims[i] = std::distance(inverse_perm.begin(), itr);
-  }
-
-  if (is_lhs) {
-    new_bmm_dot_dims.clear_lhs_batch_dimensions();
-    *new_bmm_dot_dims.mutable_lhs_batch_dimensions() = {
-        new_bmm_batch_dims.begin(), new_bmm_batch_dims.end()};
-
-  } else {
-    new_bmm_dot_dims.clear_rhs_batch_dimensions();
-    *new_bmm_dot_dims.mutable_rhs_batch_dimensions() = {
-        new_bmm_batch_dims.begin(), new_bmm_batch_dims.end()};
-  }
-
-  if (!should_contracting_be_fastest) {
-    // Given the non-contracting dimensions, we can use the same function,
-    // GetNonContractingDims, to find the new contracting dims. Simply pass the
-    // non-contracting dimensions as the second argument.
-    TF_ASSIGN_OR_RETURN(
-        std::vector<int64_t> new_bmm_contracting_dims,
-        GetNonContractingDims(transpose_arg_operand->shape(),
-                              new_bmm_batch_dims, new_bmm_checked_dims));
-    if (is_lhs) {
-      new_bmm_dot_dims.clear_lhs_contracting_dimensions();
-      *new_bmm_dot_dims.mutable_lhs_contracting_dimensions() = {
-          new_bmm_contracting_dims.begin(), new_bmm_contracting_dims.end()};
-
-    } else {
-      new_bmm_dot_dims.clear_rhs_contracting_dimensions();
-      *new_bmm_dot_dims.mutable_rhs_contracting_dimensions() = {
-          new_bmm_contracting_dims.begin(), new_bmm_contracting_dims.end()};
-    }
-  }
-  if (IsFwdCustomCallTofMHA(*fmha)) {
-    if (operand_index == 0 || operand_index == 1) {
-      // Q or K
-      *new_fmha_config.mutable_bmm1_dot_dimension_numbers() = new_bmm_dot_dims;
-    } else {
-      // V
-      *new_fmha_config.mutable_bmm2_dot_dimension_numbers() = new_bmm_dot_dims;
-    }
-  } else {
-    switch (operand_index) {
-      case 0:
-        // Q
-        *new_fmha_config.mutable_bmm1_grad_gemm1_dot_dimension_numbers() =
-            new_bmm_dot_dims;
-        break;
-      case 1:
-        // K
-        *new_fmha_config.mutable_bmm1_grad_gemm2_dot_dimension_numbers() =
-            new_bmm_dot_dims;
-        break;
-      case 2:
-        // V
-        *new_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() =
-            new_bmm_dot_dims;
-        break;
-      case 3:
-        // Forward activation
-        *new_fmha_config.mutable_bmm2_grad_gemm1_dot_dimension_numbers() =
-            new_bmm_dot_dims;
-        break;
-      case 4: {
-        // Output gradient
-        *new_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() =
-            new_bmm_dot_dims;
-        DotDimensionNumbers bmm2_grad_gemm1_dot_dims =
-            config.bmm2_grad_gemm1_dot_dimension_numbers();
-        absl::Span<const int64_t> bmm2_grad_gemm1_contracting_dims =
-            bmm2_grad_gemm1_dot_dims.rhs_contracting_dimensions();
-        CHECK_EQ(bmm2_grad_gemm1_contracting_dims.size(), 1);
-        absl::Span<const int64_t> transpose_permutation =
-            transpose_arg->dimensions();
-        auto itr = std::find(transpose_permutation.begin(),
-                             transpose_permutation.end(),
-                             bmm2_grad_gemm1_contracting_dims[0]);
-        if (itr == transpose_permutation.end()) {
-          return Internal(
-              "bmm2 gradient gemm1 contracting dimension not found.");
-        }
-        int64_t index = std::distance(transpose_permutation.begin(), itr);
-        std::vector<int64_t> new_bmm2_grad_gemm1_rhs_contracting_dims = {index};
-        // Find the new batch dimensions. This is done by passing the new
-        // contracting dimensions and the contracting dimension of the lhs of
-        // bmm2_grad_gemm2 (which is the non-contracting dimension of the rhs
-        // of bmm2_grad_gemm1) to GetNonContractingDims.
-        TF_ASSIGN_OR_RETURN(
-            std::vector<int64_t> new_bmm2_grad_gemm1_rhs_batch_dims,
-            GetNonContractingDims(
-                transpose_arg_operand->shape(),
-                new_bmm2_grad_gemm1_rhs_contracting_dims,
-                new_bmm_dot_dims.lhs_contracting_dimensions()));
-        bmm2_grad_gemm1_dot_dims.clear_rhs_contracting_dimensions();
-        bmm2_grad_gemm1_dot_dims.clear_rhs_batch_dimensions();
-        *bmm2_grad_gemm1_dot_dims.mutable_rhs_contracting_dimensions() = {
-            new_bmm2_grad_gemm1_rhs_contracting_dims.begin(),
-            new_bmm2_grad_gemm1_rhs_contracting_dims.end()};
-        *bmm2_grad_gemm1_dot_dims.mutable_rhs_batch_dimensions() = {
-            new_bmm2_grad_gemm1_rhs_batch_dims.begin(),
-            new_bmm2_grad_gemm1_rhs_batch_dims.end()};
-        *new_fmha_config.mutable_bmm2_grad_gemm1_dot_dimension_numbers() =
-            bmm2_grad_gemm1_dot_dims;
-        break;
-      }
-      default:
-        return Internal("Invalid operand index.");
-    }
-  }
-
-  TF_RETURN_IF_ERROR(fmha->set_backend_config(gpu_config));
-
-  TF_RETURN_IF_ERROR(fmha->ReplaceOperandWithDifferentShape(
-      operand_index, transpose_arg_operand));
-
-  return true;
-}
-
-/* Let's say A is transposed to B with perm {3, 0, 2, 1} as shown below:
-A[16, 256, 32, 64]
-      |
-      |
-      | Transpose with perm = {3, 0, 2, 1}
-      |
-      \/
-B[64, 16, 32, 256]
-
-The inverse perm to obtain A from B would be {1, 3, 2, 0}. That is
-B[64, 16, 32, 256]
-      |
-      |
-      | Transpose' with inv_perm = {1, 3, 2, 0}
-      |
-      \/
-A[16, 256, 32, 64]
-
-Now, let's say B is the lhs of a BatchedMatmul and the lhs_contracting
-dim is 3 (i.e. the dim of size 256). In order to now make A the lhs of the
-BatchedMatmul (thus consuming the Transpose from A->B), we need to find
-the dimension number in A that corresponds to dimension number 3 in B.
-This can be done by finding the index of dim num 3 in inv_perm. That
-would be 1. Hence, dim num 3 in B is equivalent to dim num 1 in A. Thus
-the new lhs_contracting dim, if A were to be the new lhs, would be 1.
-
-Similarly, we need to find corresponding batch dimensions as well.
-*/
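-
-// A minimal sketch (hypothetical helper, not part of this pass) of the index
-// mapping described above: a dimension number of the transposed operand B is
-// mapped back to the corresponding dimension of the transpose input A by
-// locating it in the inverse permutation, mirroring the std::find /
-// std::distance logic in FuseArgPrologueTransposeWithcuDNNFMHA. For
-// inverse_perm = {1, 3, 2, 0} and B dim 3 this returns 1 (A's 256-sized dim).
-inline int64_t MapDimThroughInversePerm(absl::Span<const int64_t> inverse_perm,
-                                        int64_t dim_in_transposed_operand) {
-  auto it = std::find(inverse_perm.begin(), inverse_perm.end(),
-                      dim_in_transposed_operand);
-  CHECK(it != inverse_perm.end()) << "Dimension not found in inverse perm.";
-  return std::distance(inverse_perm.begin(), it);
-}
-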
-absl::StatusOr<bool> FusePrologueTransposeWithcuDNNFMHA(HloComputation* comp) {
-  bool changed = false;
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction *transpose_arg0, *transpose_arg0_operand;
-    HloInstruction *transpose_arg1, *transpose_arg1_operand;
-    HloInstruction *transpose_arg2, *transpose_arg2_operand;
-    HloInstruction *transpose_arg3, *transpose_arg3_operand;
-    HloInstruction *transpose_arg4, *transpose_arg4_operand;
-
-    HloInstruction* fmha;
-
-    // Arg0 is common between forward and backward fmha calls, so we match
-    // either of these.
-    auto pattern_arg0 =
-        m::Op(&fmha)
-            .WithPredicate(IsFMHACustomCall)
-            .WithOperand(0, m::Transpose(&transpose_arg0,
-                                         m::Op(&transpose_arg0_operand)));
-    if (Match(instr, pattern_arg0)) {
-      if (VLOG_IS_ON(2)) {
-        VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 0: \n"
-                << comp->parent()->ToString();
-      }
-      if (IsFwdFMHACustomCall(fmha)) {
-        // Q tensor in forward graph is lhs with constraint on contracting dim.
-        TF_ASSIGN_OR_RETURN(changed,
-                            FuseArgPrologueTransposeWithcuDNNFMHA(
-                                fmha, 0, true /*is_lhs*/,
-                                true /*should_contracting_be_fastest*/));
-      } else {
-        // Q tensor in backward graph is rhs with constraint on non-contracting
-        // dim.
-        TF_ASSIGN_OR_RETURN(changed,
-                            FuseArgPrologueTransposeWithcuDNNFMHA(
-                                fmha, 0, false /*is_lhs*/,
-                                false /*should_contracting_be_fastest*/));
-      }
-
-      if (changed && VLOG_IS_ON(2)) {
-        VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 0: \n"
-                << comp->parent()->ToString();
-      }
-    }
-
-    // Arg1 is common between forward and backward fmha calls, so we match
-    // either of these.
-    auto pattern_arg1 =
-        m::Op(&fmha)
-            .WithPredicate(IsFMHACustomCall)
-            .WithOperand(1, m::Transpose(&transpose_arg1,
-                                         m::Op(&transpose_arg1_operand)));
-    if (Match(instr, pattern_arg1)) {
-      if (VLOG_IS_ON(2)) {
-        VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 1: \n"
-                << comp->parent()->ToString();
-      }
-      if (IsFwdFMHACustomCall(fmha)) {
-        // K tensor in forward graph is rhs with constraint on contracting dim.
-        TF_ASSIGN_OR_RETURN(changed,
-                            FuseArgPrologueTransposeWithcuDNNFMHA(
-                                fmha, 1, false /*is_lhs*/,
-                                true /*should_contracting_be_fastest*/));
-      } else {
-        // K tensor in backward graph is rhs with constraint on non-contracting
-        // dim.
-        TF_ASSIGN_OR_RETURN(changed,
-                            FuseArgPrologueTransposeWithcuDNNFMHA(
-                                fmha, 1, false /*is_lhs*/,
-                                false /*should_contracting_be_fastest*/));
-      }
-
-      if (changed && VLOG_IS_ON(2)) {
-        VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 1: \n"
-                << comp->parent()->ToString();
-      }
-    }
-
-    // Arg2 is common between forward and backward fmha calls, so we match
-    // either of these.
-    auto pattern_arg2 =
-        m::Op(&fmha)
-            .WithPredicate(IsFMHACustomCall)
-            .WithOperand(2, m::Transpose(&transpose_arg2,
-                                         m::Op(&transpose_arg2_operand)));
-    if (Match(instr, pattern_arg2)) {
-      if (VLOG_IS_ON(2)) {
-        VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 2: \n"
-                << comp->parent()->ToString();
-      }
-      if (IsFwdFMHACustomCall(fmha)) {
-        // V tensor in forward graph is rhs with constraint on non-contracting
-        // dim.
-        TF_ASSIGN_OR_RETURN(changed,
-                            FuseArgPrologueTransposeWithcuDNNFMHA(
-                                fmha, 2, false /*is_lhs*/,
-                                false /*should_contracting_be_fastest*/));
-      } else {
-        // V tensor in backward graph is rhs with constraint on contracting dim.
-        TF_ASSIGN_OR_RETURN(changed,
-                            FuseArgPrologueTransposeWithcuDNNFMHA(
-                                fmha, 2, false /*is_lhs*/,
-                                true /*should_contracting_be_fastest*/));
-      }
-
-      if (changed && VLOG_IS_ON(2)) {
-        VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 2: \n"
-                << comp->parent()->ToString();
-      }
-    }
-
-    // We only care about arg3 of backward
-    auto pattern_arg3 =
-        m::Op(&fmha)
-            .WithPredicate(IsBwdFMHACustomCall)
-            .WithOperand(3, m::Transpose(&transpose_arg3,
-                                         m::Op(&transpose_arg3_operand)));
-    if (Match(instr, pattern_arg3)) {
-      if (VLOG_IS_ON(2)) {
-        VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 3: \n"
-                << comp->parent()->ToString();
-      }
-      // Forward activation tensor in backward graph is lhs with constraint on
-      // non-contracting dim.
-      TF_ASSIGN_OR_RETURN(changed,
-                          FuseArgPrologueTransposeWithcuDNNFMHA(
-                              fmha, 3, true /*is_lhs*/,
-                              false /*should_contracting_be_fastest*/));
-
-      if (changed && VLOG_IS_ON(2)) {
-        VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 3: \n"
-                << comp->parent()->ToString();
-      }
-    }
-
-    // We only care about arg4 of backward
-    auto pattern_arg4 =
-        m::Op(&fmha)
-            .WithPredicate(IsBwdFMHACustomCall)
-            .WithOperand(4, m::Transpose(&transpose_arg4,
-                                         m::Op(&transpose_arg4_operand)));
-    if (Match(instr, pattern_arg4)) {
-      if (VLOG_IS_ON(2)) {
-        VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 4: \n"
-                << comp->parent()->ToString();
-      }
-      // D_output tensor in backward graph is lhs with constraint on
-      // contracting dim.
-      // Make sure we don't change the layout of dO in the flash attention
-      // case, as dO should have the same layout as O.
-      TF_ASSIGN_OR_RETURN(auto gpu_config,
-                          fmha->backend_config<GpuBackendConfig>());
-      if (changed && VLOG_IS_ON(2)) {
-        VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 4: \n"
-                << comp->parent()->ToString();
-      }
-    }
-  }
-  return changed;
-}
-
-/* Let's say FMHA_out is transposed to result with perm {1, 2, 0, 3} as shown
-below:
-FMHA_out[b0, b1, n, m]{}
-      |
-      |
-  Transpose with perm = {1, 2, 0, 3}
-      |
-      \/
-result[b1, n, b0, m]{1, 0, 3, 2}
-The goal is to find the minor_to_major of 'FMHA_out' such that its physical
-layout matches the physical layout of 'result', thus eliminating the need for an
-explicit transpose. cuDNN can perform an implicit transpose by knowing the
-corresponding strides (inferred from the corresponding minor_to_major).
-
-In order to find the required minor_to_major of 'FMHA_out', we first determine
-the inverse perm to obtain 'FMHA_out' from 'result'. The function
-"ShapeUtil::PermuteDimensions" generates a transposed shape such that the
-physical layout of the transposed shape is equivalent to the input shape.
-Calling this function with 'result' shape as the input shape and the inverse
-perm as the permutation will generate an output shape whose dimensions match
-'FMHA_out' dimensions but the physical layout is equivalent to 'result'. This is
-exactly what we want.
-
-The FMHA output should have exactly one gte instruction for a given tuple
-index, so we can safely fuse the transpose following that gte into the FMHA:
-
-FMHA_out = gte(FMHA, index=0)
-FMHA_out_t = transpose(FMHA_out)
-use(FMHA_out_t)
-
-after fusion:
-
-FMHA_out_t = gte(FMHA, index=0)
-use(FMHA_out_t)
-*/
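-
-// A minimal sketch (hypothetical helper, not part of this pass) of the shape
-// computation described above: given the transpose that consumes an FMHA gte,
-// build the FMHA output shape whose physical layout already matches the
-// transpose result, so the transpose can later be replaced by a bitcast.
-inline Shape ExpectedFmhaShapeForEpilogueTranspose(
-    const HloInstruction& transpose) {
-  std::vector<int64_t> inverse_perm =
-      InversePermutation(transpose.dimensions());
-  // Same dimension numbers as the original FMHA output, but with a layout
-  // physically equivalent to the transpose result.
-  return ShapeUtil::PermuteDimensions(inverse_perm, transpose.shape());
-}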
-
-absl::StatusOr<bool> FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) {
-  bool changed = false;
-
-  auto only_one_gte_with_spec_index = [](const HloInstruction* instr,
-                                         int64_t index) {
-    int count = 0;
-    for (auto user : instr->users()) {
-      if (user->opcode() == HloOpcode::kGetTupleElement &&
-          user->tuple_index() == index) {
-        count += 1;
-      }
-    }
-    return count == 1;
-  };
-
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* fmha;
-    HloInstruction* transpose;
-    HloInstruction* gte;
-    auto fwd_tuple_elem =
-        m::GetTupleElement(&gte,
-                           m::Op(&fmha).WithPredicate(IsFwdFMHACustomCall), 0)
-            .WithOneUser();
-    // Note that we don't match any specific tuple index in the matcher for
-    // the backward case.
-    auto bwd_tuple_elem =
-        m::GetTupleElement(&gte,
-                           m::Op(&fmha).WithPredicate(IsBwdFMHACustomCall))
-            .WithOneUser();
-    auto fwd_pattern = m::Transpose(&transpose, fwd_tuple_elem);
-    auto bwd_pattern = m::Transpose(&transpose, bwd_tuple_elem);
-
-    if (Match(instr, fwd_pattern)) {
-      // Check that only one gte with this index exists.
-      int64_t tuple_index = gte->tuple_index();
-      if (!only_one_gte_with_spec_index(fmha, tuple_index)) continue;
-
-      std::vector<int64_t> inverse_perm =
-          InversePermutation(transpose->dimensions());
-
-      auto expected_fmha_shape =
-          ShapeUtil::PermuteDimensions(inverse_perm, transpose->shape());
-
-      // cuDNN requires the last dimension of the output to be the fastest
-      // moving.
-      if (expected_fmha_shape.layout().minor_to_major()[0] !=
-          expected_fmha_shape.dimensions_size() - 1) {
-        VLOG(3) << "cuDNN requires the last dimension of the FMHA output to be "
-                   "the fastest moving. The last dimension is dim: "
-                << expected_fmha_shape.dimensions_size() - 1
-                << " but upon fusion with the transpose, the fmha output shape "
-                   "would have been "
-                << expected_fmha_shape.ToString(true)
-                << " and the fastest moving "
-                   "dimension would be dim: "
-                << expected_fmha_shape.layout().minor_to_major()[0];
-        continue;
-      }
-      Shape call_shape = fmha->shape();
-      *call_shape.mutable_tuple_shapes(0) = expected_fmha_shape;
-      HloInstruction* new_fmha_custom_call =
-          comp->AddInstruction(HloInstruction::CreateCustomCall(
-              call_shape, fmha->operands(),
-              absl::string_view(fmha->custom_call_target())));
-
-      TF_ASSIGN_OR_RETURN(GpuBackendConfig config,
-                          fmha->backend_config<GpuBackendConfig>());
-      TF_RETURN_IF_ERROR(new_fmha_custom_call->set_backend_config(config));
-      TF_RETURN_IF_ERROR(
-          SetFMHAInstructionName(fmha->GetModule(), new_fmha_custom_call));
-      new_fmha_custom_call->set_metadata(fmha->metadata());
-
-      auto gte = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
-          new_fmha_custom_call->shape().tuple_shapes(0), new_fmha_custom_call,
-          0));
-      TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
-          instr, HloInstruction::CreateBitcast(transpose->shape(), gte)));
-      TF_RETURN_IF_ERROR(fmha->ReplaceAllUsesWith(new_fmha_custom_call));
-
-      if (VLOG_IS_ON(2)) {
-        VLOG(2) << "After forward FuseEpilogueTransposeWithcuDNNFMHA: \n"
-                << comp->parent()->ToString();
-      }
-      changed |= true;
-    } else if (Match(instr, bwd_pattern)) {
-      // Check that only one gte with this index exists.
-      int64_t operand_tuple_idx = gte->tuple_index();
-      if (!only_one_gte_with_spec_index(fmha, operand_tuple_idx)) continue;
-
-      std::vector<int64_t> inverse_perm =
-          InversePermutation(transpose->dimensions());
-
-      auto expected_fmha_shape =
-          ShapeUtil::PermuteDimensions(inverse_perm, transpose->shape());
-
-      // cuDNN requires the last dimension of the output to be the fastest
-      // moving.
-      if (expected_fmha_shape.layout().minor_to_major()[0] !=
-          expected_fmha_shape.dimensions_size() - 1) {
-        VLOG(3) << "cuDNN requires the last dimension of the FMHA output to be "
-                   "the fastest moving. The last dimension is dim: "
-                << expected_fmha_shape.dimensions_size() - 1
-                << " but upon fusion with the transpose, the fmha output shape "
-                   "would have been "
-                << expected_fmha_shape.ToString(true)
-                << " and the fastest moving "
-                   "dimension would be dim: "
-                << expected_fmha_shape.layout().minor_to_major()[0];
-        continue;
-      }
-      Shape call_shape = fmha->shape();
-      *call_shape.mutable_tuple_shapes(operand_tuple_idx) = expected_fmha_shape;
-      HloInstruction* new_fmha_custom_call =
-          comp->AddInstruction(HloInstruction::CreateCustomCall(
-              call_shape, fmha->operands(),
-              absl::string_view(fmha->custom_call_target())));
-
-      TF_ASSIGN_OR_RETURN(GpuBackendConfig config,
-                          fmha->backend_config<GpuBackendConfig>());
-      TF_RETURN_IF_ERROR(new_fmha_custom_call->set_backend_config(config));
-      TF_RETURN_IF_ERROR(
-          SetFMHAInstructionName(fmha->GetModule(), new_fmha_custom_call));
-      new_fmha_custom_call->set_metadata(fmha->metadata());
-      TF_RETURN_IF_ERROR(fmha->ReplaceAllUsesWith(new_fmha_custom_call));
-
-      auto gte = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
-          new_fmha_custom_call->shape().tuple_shapes(operand_tuple_idx),
-          new_fmha_custom_call, operand_tuple_idx));
-      TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
-          instr, HloInstruction::CreateBitcast(transpose->shape(), gte)));
-
-      if (VLOG_IS_ON(2)) {
-        VLOG(2) << "After backward FuseEpilogueTransposeWithcuDNNFMHA: \n"
-                << comp->parent()->ToString();
-      }
-      changed |= true;
-    }
-  }
-  return changed;
-}
-}  // namespace
-
-absl::StatusOr<bool> CudnnFusedMHATransposeFusion::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool any_changed = false;
-  for (HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    bool changed = false;
-    TF_ASSIGN_OR_RETURN(changed, FusePrologueTransposeWithcuDNNFMHA(comp));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseEpilogueTransposeWithcuDNNFMHA(comp));
-    any_changed |= changed;
-  }
-
-  return any_changed;
-}
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h b/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h
deleted file mode 100644
index 94ec229..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_
-#define XLA_SERVICE_GPU_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-class CudnnFusedMHATransposeFusion : public HloModulePass {
- public:
-  CudnnFusedMHATransposeFusion() = default;
-
-  absl::string_view name() const override {
-    return "cudnn-fused-multi-headed-attention-transpose-fusion";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc
deleted file mode 100644
index 1806796..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc
+++ /dev/null
@@ -1,734 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_fusion_compiler.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "third_party/gpus/cudnn/cudnn_version.h"
-#include "xla/comparison_util.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/primitive_util.h"
-#include "xla/service/dump.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cudnn_support_utils.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/kernel_reuse_cache.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/triton_fusion_analysis.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/cuda/cuda_dnn.h"
-#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-namespace fe = cudnn_frontend;
-namespace graph = fe::graph;
-
-inline std::optional<fe::PointwiseMode_t> GetElementwiseMode(
-    const HloInstruction& instruction) {
-  const HloOpcode opcode = instruction.opcode();
-  using m = fe::PointwiseMode_t;
-  switch (opcode) {
-    case HloOpcode::kAbs:
-      return m::ABS;
-    case HloOpcode::kAdd:
-      return m::ADD;
-    case HloOpcode::kCeil:
-      return m::CEIL;
-    case HloOpcode::kCompare:
-      switch (instruction.comparison_direction()) {
-        case Comparison::Direction::kEq:
-          return m::CMP_EQ;
-        case Comparison::Direction::kNe:
-          return m::CMP_NEQ;
-        case Comparison::Direction::kGe:
-          return m::CMP_GE;
-        case Comparison::Direction::kGt:
-          return m::CMP_GT;
-        case Comparison::Direction::kLe:
-          return m::CMP_LE;
-        case Comparison::Direction::kLt:
-          return m::CMP_LT;
-      }
-      break;
-    case HloOpcode::kConvert:
-      return m::IDENTITY;
-    case HloOpcode::kCos:
-      return m::COS;
-    case HloOpcode::kDivide:
-      return m::DIV;
-    case HloOpcode::kExp:
-      return m::EXP;
-    case HloOpcode::kFloor:
-      return m::FLOOR;
-    case HloOpcode::kLog:
-      return m::LOG;
-    case HloOpcode::kMaximum:
-      return m::MAX;
-    case HloOpcode::kMinimum:
-      return m::MIN;
-    case HloOpcode::kMultiply:
-      return m::MUL;
-    case HloOpcode::kNegate:
-      return m::NEG;
-    case HloOpcode::kPower:
-      return m::POW;
-    case HloOpcode::kRsqrt:
-      return m::RSQRT;
-#if CUDNN_VERSION >= 90100
-    case HloOpcode::kSelect:
-      return m::BINARY_SELECT;
-#endif  // CUDNN_VERSION
-    case HloOpcode::kSin:
-      return m::SIN;
-    case HloOpcode::kSqrt:
-      return m::SQRT;
-    case HloOpcode::kSubtract:
-      return m::SUB;
-    case HloOpcode::kTan:
-      return m::TAN;
-    case HloOpcode::kTanh:
-      return m::TANH_FWD;
-    default:
-      return std::nullopt;
-  }
-}
-
-inline std::optional<fe::DataType_t> ToCudnnDataType(const PrimitiveType type) {
-  using t = fe::DataType_t;
-  switch (type) {
-    case PrimitiveType::F32:
-      return t::FLOAT;
-    case PrimitiveType::F16:
-      return t::HALF;
-    case PrimitiveType::BF16:
-      return t::BFLOAT16;
-    case PrimitiveType::S32:
-      return t::INT32;
-    case PrimitiveType::S8:
-      return t::INT8;
-    case PrimitiveType::PRED:
-      return t::INT8;
-    case PrimitiveType::F8E5M2:
-      return t::FP8_E5M2;
-    case PrimitiveType::F8E4M3FN:
-      return t::FP8_E4M3;
-    default:
-      return std::nullopt;
-  }
-}
-
-inline std::optional<fe::DataType_t> GetComputeDataType(
-    const PrimitiveType type) {
-  fe::DataType_t compute_dtype = fe::DataType_t::FLOAT;
-  if (primitive_util::IsIntegralType(type)) {
-#if CUDNN_VERSION >= 90100
-    compute_dtype = fe::DataType_t::INT32;
-#else
-    VLOG(3) << "Integer math requires cuDNN 9.1+.";
-    return std::nullopt;
-#endif  // CUDNN_VERSION
-  }
-  return compute_dtype;
-}
-
-int FusionLevel(const HloInstruction& hlo) {
-  return hlo.GetModule()
-      ->config()
-      .debug_options()
-      .xla_gpu_cudnn_gemm_fusion_level();
-};
-
-// Extracts dimensions and strides from HLO tensors in the format expected by
-// cuDNN.
-class GemmDimensionAdapter {
-  explicit GemmDimensionAdapter(const HloDotInstruction& dot,
-                                TritonFusionAnalysis analysis)
-      : analysis_(std::move(analysis)), dot_(dot) {};
-
- public:
-  const TritonFusionAnalysis analysis_;
-
-  static absl::StatusOr<std::optional<GemmDimensionAdapter>> Create(
-      const HloComputation& computation) {
-    const HloInstruction* maybe_dot =
-        hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot);
-    if (maybe_dot == nullptr) {
-      VLOG(3) << "Not a GEMM fusion.";
-      return std::nullopt;
-    }
-    const HloDotInstruction* dot = DynCast<HloDotInstruction>(
-        hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot));
-    if (absl::c_any_of(dot->precision_config().operand_precision(),
-                       [](int x) { return x != PrecisionConfig::DEFAULT; })) {
-      VLOG(3) << "Non-default precision is not supported.";
-      return std::nullopt;
-    }
-    TF_ASSIGN_OR_RETURN(auto analysis,
-                        TritonFusionAnalysis::Execute(computation));
-    return GemmDimensionAdapter{*dot, std::move(analysis)};
-  }
-
-  bool DimensionsAndStrides(const HloInstruction& hlo,
-                            const TritonFusionAnalysis::Scope scope,
-                            std::vector<int64_t>& dimensions,
-                            std::vector<int64_t>& strides) {
-    const DotDimensionNumbers& dims = dot_.dot_dimension_numbers();
-    // GEMM fusions require a specific canonical order of dimensions.
-    constexpr int kBatchDimensionIndex = 0;
-    constexpr int kOutputLHSNonContractingDimensionIndex = 1;
-    std::vector<int64_t> dim_indices;
-    int lhs_noncontracting_index = -1;
-    switch (scope) {
-      case TritonFusionAnalysis::Scope::LHS:
-        lhs_noncontracting_index =
-            GetNonContractingDims(dot_.operand(0)->shape(),
-                                  dims.lhs_batch_dimensions(),
-                                  dims.lhs_contracting_dimensions())
-                .value()[0];
-        dim_indices = {
-            dims.lhs_batch_dimensions().empty() ? -1
-                                                : dims.lhs_batch_dimensions(0),
-            lhs_noncontracting_index, dims.lhs_contracting_dimensions(0)};
-        break;
-      case TritonFusionAnalysis::Scope::RHS:
-        dim_indices = {dims.rhs_batch_dimensions().empty()
-                           ? -1
-                           : dims.rhs_batch_dimensions(0),
-                       dims.rhs_contracting_dimensions(0),
-                       GetNonContractingDims(dot_.operand(1)->shape(),
-                                             dims.rhs_batch_dimensions(),
-                                             dims.rhs_contracting_dimensions())
-                           .value()[0]};
-        break;
-      case TritonFusionAnalysis::Scope::OUTPUT:
-        lhs_noncontracting_index = dot_.shape().rank() - 2;
-        dim_indices = {dims.lhs_batch_dimensions().empty() ? -1 : 0,
-                       lhs_noncontracting_index, dot_.shape().rank() - 1};
-        break;
-      case TritonFusionAnalysis::Scope::META:
-        LOG(FATAL) << "Unsupported scope.";
-    }
-    dimensions.reserve(dim_indices.size());
-    strides.reserve(dim_indices.size());
-    for (const int index : dim_indices) {
-      const auto* spec = analysis_.IterSpec(scope, &hlo, index);
-      if (spec == nullptr) {
-        dimensions.push_back(1);
-        strides.push_back(strides.empty() ? 1 : strides.back());
-        continue;
-      } else {
-        if (spec->size() == 1) {
-          // The dimension is not split, nothing to do.
-        } else if (spec->size() == 2) {
-          if (FusionLevel(hlo) < 3) {
-            return false;
-          }
-          if (!dims.lhs_batch_dimensions().empty()) {
-            VLOG(8) << "Noncontracting dimension split is not compatible with "
-                       "batch dimensions.";
-            return false;
-          }
-          if (index != lhs_noncontracting_index) {
-            VLOG(8) << "Only LHS noncontracting dimension can be split.";
-            return false;
-          }
-          switch (scope) {
-            case TritonFusionAnalysis::Scope::LHS:
-              lhs_noncontracting_split_ = spec->back().count;
-              break;
-            case TritonFusionAnalysis::Scope::OUTPUT:
-              if (lhs_noncontracting_split_ != spec->back().count) {
-                VLOG(8) << "Output non-contracting dimension has to be split "
-                           "the same way as the LHS input one if it is split.";
-                return false;
-              }
-              break;
-            default:
-              VLOG(8) << "Only LHS noncontracting dimension can be split.";
-              return false;
-          }
-          // Assign the major part of the noncontracting dimension to the
-          // unused batch one.
-          CHECK_EQ(dimensions[kBatchDimensionIndex], 1);
-          dimensions[kBatchDimensionIndex] = spec->back().count;
-          strides[kBatchDimensionIndex] = spec->back().stride;
-        } else {
-          VLOG(8) << "The dimension is split multiple times.";
-          return false;
-        }
-        dimensions.push_back(spec->front().count);
-        strides.push_back(spec->front().stride);
-      }
-    }
-    if (lhs_noncontracting_split_ > 1 &&
-        scope == TritonFusionAnalysis::Scope::OUTPUT &&
-        dimensions[kBatchDimensionIndex] == 1) {
-      // LHS input noncontracting dimension is split but the corresponding
-      // output one is not. Assign part of the output one to the unused batch
-      // dimension.
-      dimensions[kBatchDimensionIndex] = lhs_noncontracting_split_;
-      dimensions[kOutputLHSNonContractingDimensionIndex] /=
-          lhs_noncontracting_split_;
-      strides[kBatchDimensionIndex] =
-          strides[kOutputLHSNonContractingDimensionIndex] *
-          dimensions[kOutputLHSNonContractingDimensionIndex];
-    }
-    return true;
-  }
-
- private:
-  int64_t lhs_noncontracting_split_ = 1;
-  const HloDotInstruction& dot_;
-};
-
-template <PrimitiveType XlaT, typename T>
-std::shared_ptr<graph::Tensor_attributes> LiteralToCudnnTensor(
-    const HloInstruction& hlo, graph::Graph& graph) {
-  using NativeT = typename primitive_util::PrimitiveTypeToNative<XlaT>::type;
-  return graph.tensor(T(hlo.literal().GetFirstElement<NativeT>()));
-}
-
-std::optional<std::shared_ptr<graph::Tensor_attributes>>
-HandleConstantHloToCudnnGraph(const HloInstruction& hlo, graph::Graph& graph) {
-  CHECK(hlo.IsConstant()) << "HLO is not a constant: " << hlo.ToShortString();
-  if (!ShapeUtil::IsScalar(hlo.shape())) {
-    VLOG(3) << "Currently only support fusing scalar in the graph";
-    return std::nullopt;
-  }
-  PrimitiveType constant_type = hlo.shape().element_type();
-  switch (constant_type) {
-    case BF16:
-      return LiteralToCudnnTensor<BF16, __nv_bfloat16>(hlo, graph);
-    case F32:
-      return LiteralToCudnnTensor<F32, float>(hlo, graph);
-    case S32:
-      return LiteralToCudnnTensor<S32, int>(hlo, graph);
-    default:
-      VLOG(3) << "Unsupported constant type: "
-              << PrimitiveType_Name(constant_type);
-      return std::nullopt;
-  }
-}
-
-std::optional<std::shared_ptr<graph::Tensor_attributes>>
-HandleClampToCudnnGraph(
-    const HloInstruction& hlo, graph::Graph& graph,
-    absl::flat_hash_map<const HloInstruction*,
-                        std::shared_ptr<graph::Tensor_attributes>>
-        hlo_to_cudnn,
-    fe::DataType_t data_type, fe::DataType_t compute_dtype) {
-  CHECK(hlo.opcode() == HloOpcode::kClamp)
-      << "HLO is not a clamp: " << hlo.ToShortString();
-  CHECK(hlo.operands().size() == 3)
-      << "Clamp requires to have 3 operands: " << hlo.ToShortString();
-  // clamp = max(lower, min(value, upper));
-  const auto min_attrs = graph::Pointwise_attributes()
-                             .set_mode(fe::PointwiseMode_t::MIN)
-                             .set_compute_data_type(compute_dtype);
-  std::shared_ptr<graph::Tensor_attributes> min_tensor = graph.pointwise(
-      hlo_to_cudnn[hlo.operand(1)], hlo_to_cudnn[hlo.operand(2)], min_attrs);
-  min_tensor->set_data_type(data_type).set_name(std::string(hlo.name()));
-  const auto max_attrs = graph::Pointwise_attributes()
-                             .set_mode(fe::PointwiseMode_t::MAX)
-                             .set_compute_data_type(compute_dtype);
-  return graph.pointwise(min_tensor, hlo_to_cudnn[hlo.operand(0)], max_attrs);
-}
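-
-// Note on the operand order used above: HLO clamp takes its operands as
-// (lower, value, upper), so operand(1) and operand(2) feed the MIN node and
-// operand(0) feeds the final MAX node.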
-
-// Traverses fusion computations and creates cuDNN graphs out of them.
-absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
-    const HloFusionInstruction& fusion) {
-  const HloComputation& computation = *fusion.fused_instructions_computation();
-  VLOG(5) << fusion.ToString();
-  VLOG(5) << computation.ToString();
-  graph::Graph graph;
-  std::vector<HloInstruction*> instructions =
-      computation.MakeInstructionPostOrder();
-  absl::flat_hash_map<const HloInstruction*,
-                      std::shared_ptr<graph::Tensor_attributes>>
-      hlo_to_cudnn;
-  TF_ASSIGN_OR_RETURN(std::optional<GemmDimensionAdapter> adapter,
-                      GemmDimensionAdapter::Create(computation));
-  if (!adapter.has_value()) {
-    return std::nullopt;
-  }
-  auto add_parameter = [&](const HloInstruction& parameter,
-                           std::vector<int64_t>& dimensions,
-                           std::vector<int64_t> strides) {
-    const std::optional<fe::DataType_t> data_type =
-        ToCudnnDataType(parameter.shape().element_type());
-    if (!data_type.has_value()) {
-      VLOG(3) << "Unsupported data type.";
-      return false;
-    }
-    hlo_to_cudnn[&parameter] = graph.tensor(
-        graph::Tensor_attributes()
-            .set_dim(dimensions)
-            .set_stride(strides)
-            .set_data_type(*data_type)
-            .set_name(std::string(parameter.name()))
-            .set_uid(se::gpu::CuDnnTensorUID(parameter.parameter_number())));
-    return true;
-  };
-  for (const TritonFusionAnalysis::Scope scope :
-       {TritonFusionAnalysis::Scope::LHS, TritonFusionAnalysis::Scope::RHS,
-        TritonFusionAnalysis::Scope::OUTPUT}) {
-    for (const HloInstruction* parameter :
-         adapter->analysis_.ScopeParameters(scope)) {
-      std::vector<int64_t> dimensions;
-      std::vector<int64_t> strides;
-      if (!adapter->DimensionsAndStrides(*parameter, scope, dimensions,
-                                         strides)) {
-        VLOG(3) << "Unsupported dimensions.";
-        return std::nullopt;
-      }
-      if (!add_parameter(*parameter, dimensions, strides)) {
-        return std::nullopt;
-      }
-    }
-  }
-
-  for (const HloInstruction* hlo : instructions) {
-    VLOG(5) << hlo->ToShortString();
-    auto operand = [&hlo_to_cudnn, &hlo](int i) {
-      return hlo_to_cudnn[hlo->operand(i)];
-    };
-    const auto data_type = ToCudnnDataType(hlo->shape().element_type());
-    if (!data_type.has_value()) {
-      VLOG(3) << "Unimplemented data type: " << hlo->shape().element_type();
-      return std::nullopt;
-    }
-    if (hlo->opcode() == HloOpcode::kParameter) {
-      CHECK(hlo_to_cudnn.contains(hlo));
-      continue;
-    } else if (hlo->opcode() == HloOpcode::kCustomCall) {
-      if (hlo->user_count() != 1 ||
-          !IsWorkspaceAllocationRoot(*hlo->users()[0])) {
-        VLOG(3) << "Custom calls are only expected to be used for workspace "
-                   "allocation.";
-        return std::nullopt;
-      }
-      continue;
-    } else if (hlo->opcode() == HloOpcode::kTuple) {
-      if (!IsWorkspaceAllocationRoot(*hlo)) {
-        VLOG(3) << "Tuples are only expected at outputs for workspace "
-                   "allocation.";
-        return std::nullopt;
-      }
-      continue;
-    } else if (FusionLevel(fusion) >= 2 &&
-               hlo->opcode() == HloOpcode::kConstant) {
-      if (const auto const_tensor = HandleConstantHloToCudnnGraph(*hlo, graph);
-          const_tensor.has_value()) {
-        hlo_to_cudnn[hlo] = const_tensor.value();
-      } else {
-        return std::nullopt;
-      }
-    } else if (hlo->opcode() == HloOpcode::kReshape ||
-               hlo->opcode() == HloOpcode::kBitcast ||
-               hlo->opcode() == HloOpcode::kTranspose ||
-               hlo->opcode() == HloOpcode::kCopy ||
-               (FusionLevel(fusion) >= 2 &&
-                hlo->opcode() == HloOpcode::kBroadcast)) {
-      // All these are accounted for separately as transformations of strides.
-      hlo_to_cudnn[hlo] = operand(0);
-    } else if (hlo->IsElementwise()) {
-      const auto compute_dtype =
-          GetComputeDataType(hlo->shape().element_type());
-      if (!compute_dtype.has_value()) {
-        return std::nullopt;
-      }
-      if (hlo->opcode() == HloOpcode::kClamp) {
-        const auto clamp =
-            HandleClampToCudnnGraph(*hlo, graph, hlo_to_cudnn,
-                                    data_type.value(), compute_dtype.value());
-        if (!clamp.has_value()) {
-          return std::nullopt;
-        }
-        hlo_to_cudnn[hlo] = clamp.value();
-      } else {
-        const auto mode = GetElementwiseMode(*hlo);
-        if (!mode.has_value()) {
-          VLOG(3) << "Unsupported elementwise operation.";
-          return std::nullopt;
-        }
-        const auto attrs = graph::Pointwise_attributes()
-                               .set_mode(mode.value())
-                               .set_compute_data_type(compute_dtype.value());
-        if (hlo->operand_count() == 1) {
-          hlo_to_cudnn[hlo] = graph.pointwise(operand(0), attrs);
-          // Sets the dimensions of unary ops whose operands are broadcasts so
-          // that cuDNN can infer the shapes of their inputs. For example, a
-          // constant has dimension [1] while cuDNN requires it to have
-          // dimension [1,1,1]. Not setting the output shape of such unary ops
-          // results in the rejection of the cuDNN graph.
-          if (hlo->operand(0)->opcode() == HloOpcode::kBroadcast) {
-            const auto scope = adapter->analysis_.QueryInstructionScope(*hlo);
-            std::vector<int64_t> dimensions;
-            std::vector<int64_t> strides;
-            if (!scope.has_value()) {
-              LOG(FATAL) << "No scope for instruction: "
-                         << hlo->ToShortString();
-            }
-            if (!adapter->DimensionsAndStrides(*hlo, scope.value(), dimensions,
-                                               strides)) {
-              VLOG(3) << "Unsupported hlo for querying dimensions: "
-                      << hlo->ToShortString();
-            } else {
-              hlo_to_cudnn[hlo]->set_dim(dimensions);
-            }
-          }
-        } else if (hlo->operand_count() == 2) {
-          hlo_to_cudnn[hlo] = graph.pointwise(operand(0), operand(1), attrs);
-        } else if (hlo->operand_count() == 3) {
-          if (hlo->opcode() != HloOpcode::kSelect) {
-            VLOG(3) << "Unexpected ternary operation: " << hlo->ToString();
-            return std::nullopt;
-          }
-          // Operand order for select differs between HLO and cuDNN.
-          hlo_to_cudnn[hlo] =
-              graph.pointwise(operand(1), operand(2), operand(0), attrs);
-        } else {
-          VLOG(3) << "Unimplemented elementwise operation.";
-          return std::nullopt;
-        }
-      }
-    } else if (hlo->opcode() == HloOpcode::kDot) {
-      const auto compute_dtype =
-          GetComputeDataType(hlo->shape().element_type());
-      if (!compute_dtype.has_value()) {
-        return std::nullopt;
-      }
-      hlo_to_cudnn[hlo] =
-          graph.matmul(operand(0), operand(1),
-                       graph::Matmul_attributes().set_compute_data_type(
-                           compute_dtype.value()));
-    } else {
-      VLOG(3) << "Unimplemented operation.";
-      return std::nullopt;
-    }
-    if (hlo_to_cudnn[hlo] == nullptr) {
-      VLOG(3) << "Creation of the operation failed.";
-      return std::nullopt;
-    }
-    hlo_to_cudnn[hlo]
-        ->set_data_type(data_type.value())
-        .set_name(std::string(hlo->name()));
-  }
-  const HloInstruction* output = instructions.back();
-  if (instructions.back()->shape().IsTuple()) {
-    output = instructions.back()->operand(0);
-  }
-  std::vector<int64_t> dimensions;
-  std::vector<int64_t> strides;
-  if (!adapter->DimensionsAndStrides(
-          *output, TritonFusionAnalysis::Scope::OUTPUT, dimensions, strides)) {
-    VLOG(3) << "Unsupported dimensions.";
-    return std::nullopt;
-  }
-  hlo_to_cudnn[output]
-      ->set_output(true)
-      .set_dim(dimensions)
-      .set_stride(strides)
-      .set_uid(se::gpu::CuDnnTensorUID(fusion.operand_count()));
-  if (!fusion.GetModule()->config().debug_options().xla_dump_to().empty()) {
-    json dump;
-    graph.serialize(dump);
-    DumpToFileInDirOrStdout(
-        /*module=*/*fusion.GetModule(),
-        /*file_prefix=*/"",
-        /*file_suffix=*/
-        absl::StrCat("cudnn_fusion_", fusion.name(), ".json"),
-        /*contents=*/dump.dump(1));
-  }
-
-  return se::gpu::CudnnGraph(std::move(graph));
-}
-
-// Creates a cuDNN graph and queries cuDNN whether it is supported.
-absl::StatusOr<se::gpu::CudnnGraph> PrepareGraph(
-    se::dnn::DnnSupport& dnn_support, const HloFusionInstruction& hlo) {
-  TF_ASSIGN_OR_RETURN(std::optional<se::gpu::CudnnGraph> graph,
-                      HloFusionToCuDnnGraph(hlo));
-  if (!graph.has_value()) {
-    return absl::InternalError("Construction of cuDNN graph failed.");
-  }
-  TF_RETURN_IF_ERROR(graph->Prepare(dnn_support));
-  return *graph;
-}
-
-absl::StatusOr<HloInstruction*> AddWorkspace(HloInstruction& fusion,
-                                             const int64_t workspace_size) {
-  HloComputation* computation = fusion.fused_instructions_computation();
-  HloInstruction* custom_call =
-      computation->AddInstruction(HloInstruction::CreateCustomCall(
-          ShapeUtil::MakeShape(S8, {workspace_size}), {},
-          kWorkspaceAllocationCustomCallTarget));
-  HloInstruction* output_tuple =
-      computation->AddInstruction(HloInstruction::CreateTuple(
-          {computation->root_instruction(), custom_call}));
-  computation->set_root_instruction(output_tuple, true);
-  HloInstruction* new_fusion = fusion.parent()->AddInstruction(
-      fusion.CloneWithNewShape(output_tuple->shape()));
-  TF_RETURN_IF_ERROR(fusion.ReplaceAllUsesWith(fusion.parent()->AddInstruction(
-      HloInstruction::CreateGetTupleElement(new_fusion, 0))));
-  TF_RETURN_IF_ERROR(fusion.parent()->RemoveInstruction(&fusion));
-  return new_fusion;
-}
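-
-// Sketch of the rewrite performed above, with illustrative shapes: a fusion
-// whose root produces f32[16] and which needs a 1024-byte workspace ends up
-// with the tuple-shaped root
-//   (f32[16], s8[1024]) tuple(original_root, workspace_custom_call),
-// and every former user of the fusion is rewired to read
-// get-tuple-element(index=0) of the cloned fusion.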
-
-class CuDnnFusionVisitor : public DfsHloRewriteVisitor {
- public:
-  explicit CuDnnFusionVisitor(se::dnn::DnnSupport& dnn_support,
-                              BinaryMap& compilation_results)
-      : dnn_support_(dnn_support), compilation_results_(compilation_results) {}
-
-  absl::Status HandleFusion(HloInstruction* hlo) override {
-    TF_ASSIGN_OR_RETURN(auto gpu_config,
-                        hlo->backend_config<GpuBackendConfig>());
-    const auto& fusion_backend_config = gpu_config.fusion_backend_config();
-    if (fusion_backend_config.kind() != kCuDnnFusionKind) {
-      return absl::OkStatus();
-    }
-    int64_t plan_id = -1;
-    if (fusion_backend_config.has_cudnn_fusion_config()) {
-      plan_id = fusion_backend_config.cudnn_fusion_config().plan_id();
-    }
-
-    VLOG(4) << "Processing " << hlo->ToString();
-    VLOG(4) << "Plan ID: " << plan_id;
-
-    auto add_workspace = [&](const int64_t workspace_size) {
-      if (workspace_size > 0) {
-        TF_ASSIGN_OR_RETURN(hlo, AddWorkspace(*hlo, workspace_size));
-        SetVisited(*hlo);
-      }
-      return absl::OkStatus();
-    };
-    const std::string fingerprint_without_workspace =
-        GetComputationFingerprint(hlo->fused_instructions_computation(), {});
-    auto workspace_size_it =
-        workspace_sizes_.find(fingerprint_without_workspace);
-    if (workspace_size_it == workspace_sizes_.cend()) {
-      TF_ASSIGN_OR_RETURN(
-          se::gpu::CudnnGraph graph,
-          PrepareGraph(dnn_support_, *DynCast<HloFusionInstruction>(hlo)));
-
-      if (plan_id >= 0) {
-        // Build single plan with given ID.
-        if (plan_id >= graph.Graph().get_execution_plan_count()) {
-          return absl::InternalError("cuDNN graph plan does not exist.");
-        }
-        TF_RETURN_IF_ERROR(graph.Build(dnn_support_, plan_id));
-      } else {
-        // When no plan_id was provided, build plans one by one until the
-        // first one that succeeds.
-        for (plan_id = 0; plan_id < graph.Graph().get_execution_plan_count();
-             ++plan_id) {
-          VLOG(7) << "Trying plan ID " << plan_id;
-          if (graph.Build(dnn_support_, plan_id).ok()) {
-            VLOG(7) << "Successfully built plan ID " << plan_id;
-            break;
-          }
-        }
-        if (plan_id == graph.Graph().get_execution_plan_count()) {
-          return absl::InternalError("No cuDNN plans can be built.");
-        }
-      }
-      const int64_t workspace_size = graph.Graph().get_workspace_size();
-      workspace_sizes_.insert(workspace_size_it,
-                              {fingerprint_without_workspace, workspace_size});
-      TF_RETURN_IF_ERROR(add_workspace(workspace_size));
-
-      std::vector<uint8_t> serialized_graph;
-      RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph));
-      // Compute a new fingerprint that includes the potential workspace so the
-      // compilation results match the fingerprint computed by the emitter.
-      compilation_results_[GetComputationFingerprint(
-          hlo->fused_instructions_computation(), {})] =
-          std::string(reinterpret_cast<char*>(serialized_graph.data()),
-                      serialized_graph.size());
-    } else {
-      VLOG(4) << "Cache hit.";
-      TF_RETURN_IF_ERROR(add_workspace(workspace_size_it->second));
-    }
-    auto cudnn_config = gpu_config.mutable_fusion_backend_config()
-                            ->mutable_cudnn_fusion_config();
-    cudnn_config->set_plan_id(plan_id);
-    TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config));
-
-    MarkAsChanged();
-    return absl::OkStatus();
-  }
-
- private:
-  se::dnn::DnnSupport& dnn_support_;
-  // <HLO computation fingerprint, serialized compiled cuDNN graph>.
-  BinaryMap& compilation_results_;
-  absl::flat_hash_map<std::string, int64_t> workspace_sizes_;
-};
-
-}  // namespace
-
-absl::StatusOr<bool> CuDnnFusionCompiler::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  XLA_SCOPED_LOGGING_TIMER("cuDNN fusion compiler");
-  return CuDnnFusionVisitor(dnn_support_, compilation_results_)
-      .RunOnModule(module, execution_threads);
-}
-
-int CuDnnFusionCompiler::GetAvailablePlanCount(
-    se::StreamExecutor& stream_exec, const HloFusionInstruction& hlo) {
-  auto graph = PrepareGraph(*stream_exec.AsDnn(), hlo);
-  if (!graph.ok()) {
-    return 0;
-  }
-  constexpr int64_t kMaxPlans = 10;
-  return std::min(graph->Graph().get_execution_plan_count(), kMaxPlans);
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.h b/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.h
deleted file mode 100644
index f34bbb1..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_FUSION_COMPILER_H_
-#define XLA_SERVICE_GPU_CUDNN_FUSION_COMPILER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/stream_executor.h"
-
-namespace xla {
-namespace gpu {
-
-// Converts HLO fusions with cuDNN backend config to cuDNN graphs,
-// compiles them using a cuDNN handle and serializes them.
-class CuDnnFusionCompiler : public HloModulePass {
- public:
-  explicit CuDnnFusionCompiler(se::StreamExecutor& stream_exec,
-                               BinaryMap& compilation_results)
-      : dnn_support_(*stream_exec.AsDnn()),
-        compilation_results_(compilation_results) {}
-
-  absl::string_view name() const override { return "cudnn-fusion-compiler"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
-  static int GetAvailablePlanCount(se::StreamExecutor& stream_exec,
-                                   const HloFusionInstruction& hlo);
-
- private:
-  se::dnn::DnnSupport& dnn_support_;
-  BinaryMap& compilation_results_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_CUDNN_FUSION_COMPILER_H_
diff --git a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc
deleted file mode 100644
index 5e78f48..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc
+++ /dev/null
@@ -1,1553 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_norm_rewriter.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <cstdlib>
-#include <functional>
-#include <iterator>
-#include <limits>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "google/protobuf/wrappers.pb.h"
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/types.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/protobuf/dnn.pb.h"
-
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"  // IWYU pragma: keep
-#include "third_party/gpus/cudnn/cudnn.h"        // IWYU pragma: keep
-#include "third_party/gpus/cudnn/cudnn_version.h"
-#endif
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-namespace m = match;
-
-// Traverses the graph upward starting at instr and returns the
-// first instruction that is not a convert, bitcast or reshape.
-const HloInstruction* SkipUnaryOps(const HloInstruction* instr) {
-  while (instr->opcode() == HloOpcode::kConvert ||
-         instr->opcode() == HloOpcode::kBitcast ||
-         instr->opcode() == HloOpcode::kReshape) {
-    instr = instr->operand(0);
-  }
-  return instr;
-}
-
-// Recursively traverses the graph downward starting at instr and stores in
-// instrs the users that are not a convert, bitcast or reshape.
-void SkipUnaryOpsTopDownRecursive(HloInstruction* instr,
-                                  std::vector<HloInstruction*>& instrs) {
-  if (instr->opcode() == HloOpcode::kConvert ||
-      instr->opcode() == HloOpcode::kBitcast ||
-      instr->opcode() == HloOpcode::kReshape) {
-    for (HloInstruction* user : instr->users()) {
-      SkipUnaryOpsTopDownRecursive(user, instrs);
-    }
-  } else {
-    instrs.emplace_back(instr);
-  }
-}
-
-// Holds auxiliary information about individual layer norm patterns rewritten
-// into a cuDNN Custom Call.
-struct NormMetadata {
-  // Transposes applied to the input and output of the forward layer norm to
-  // order the normalization and non-normalization dimensions as required by
-  // cuDNN. Nullptr if no transposes were inserted.
-  HloInstruction *x_transpose, *y_transpose;
-  // The reduction and non-reduction dimensions of the input into the forward
-  // layer norm before the potential application of transposes and adjusted for
-  // the removal of any degenerate dimensions in the input to the norm.
-  std::vector<int64_t> norm_dims_adjusted, non_norm_dims_adjusted;
-};
-
-// Map from the instruction pointer of a layer norm Custom Call to its metadata.
-using NormMetadataMap = absl::flat_hash_map<HloInstruction*, NormMetadata>;
-
-// Captures multiple HloInstruction pointers and verifies that their target
-// is identical.
-//
-// Example:
-// Pattern cos(x) / sin(x) with cos and sin intended to operate on the same
-// HloInstruction:
-//  UniqueHloInstruction x;
-//  bool m = Match(
-//      instr, m::Divide(m::Cos(m::Op().WithPredicate(x.GetCaptureOrVerifyFn())),
-//                       m::Sin(m::Op().WithPredicate(x.GetCaptureOrVerifyFn()))));
-// m is true and x.Instr() returns an HloInstruction pointer to the operand of
-// cosine and sine iff HloInstruction *instr points to a division of a cosine by
-// a sine that operate on the same instruction.
-class UniqueHloInstruction {
- public:
-  UniqueHloInstruction()
-      : is_set_(false), instr_(nullptr), capture_or_verify_() {}
-  HloInstruction* Instr() const { return instr_; }
-  void SetInstr(HloInstruction* instr) {
-    is_set_ = true;
-    instr_ = instr;
-  }
-
-  // Stores instr when invoked the first time. Otherwise, compares instr to the
-  // stored value and sets the stored value to nullptr if the comparison fails.
-  bool CaptureOrVerify(HloInstruction* instr) {
-    if (is_set_ && instr != instr_) {
-      instr_ = nullptr;
-    }
-    if (!is_set_) {
-      is_set_ = true;
-      instr_ = instr;
-    }
-    return instr_;
-  }
-
-  // Returns a std::function for capturing or verifying an instruction using
-  // WithPredicate.
-  std::function<bool(const HloInstruction*)> GetCaptureOrVerifyFn() {
-    if (!capture_or_verify_) {
-      capture_or_verify_ = [this](const HloInstruction* instr) -> bool {
-        return CaptureOrVerify(const_cast<HloInstruction*>(instr));
-      };
-    }
-    return capture_or_verify_;
-  }
-
- private:
-  bool is_set_;
-  HloInstruction* instr_;
-  std::function<bool(const HloInstruction*)> capture_or_verify_;
-};
-
-// Returns an architecture-specific constant for the calculation of an upper
-// bound for the size of the scratch space for layer norm kernels.
-absl::StatusOr<int64_t> CConstant(
-    se::CudaComputeCapability cuda_compute_capability) {
-  if (cuda_compute_capability.major == se::CudaComputeCapability::AMPERE) {
-    return 32 * 128;
-  } else if (cuda_compute_capability.major ==
-             se::CudaComputeCapability::HOPPER) {
-    return 32 * 144;
-  }
-  return xla::Internal("Norm kernels require Ampere or Hopper architecture.");
-}
-
-// Returns whether the element type of instr is compatible with layer norm
-// kernels.
-bool CompatibleElementType(const HloInstruction* instr) {
-  PrimitiveType element_type = instr->shape().element_type();
-  return element_type == BF16 || element_type == F16 || element_type == F32;
-}
-
-// Returns the dimensions associated with shape, adjusted for the removal of any
-// degenerate dimensions in shape. Specifically, for each dimension d in
-// dimensions, returns the new index of d if all dimensions of size 1 are
-// removed from shape. If d has size 1, it is not included in the returned
-// vector.
-std::vector<int64_t> AdjustedDimensions(const Shape& shape,
-                                        absl::Span<const int64_t> dimensions) {
-  absl::flat_hash_map<int64_t, int64_t> dimension_map;
-  for (int64_t dimension = 0, non_degen_dimension = 0; dimension < shape.rank();
-       ++dimension) {
-    if (shape.dimensions(dimension) > 1) {
-      dimension_map.insert({dimension, non_degen_dimension});
-      non_degen_dimension++;
-    }
-  }
-  std::vector<int64_t> adjusted_dimensions;
-  for (int64_t dimension : dimensions) {
-    auto non_degenerate_dimension = dimension_map.find(dimension);
-    if (non_degenerate_dimension != dimension_map.end()) {
-      adjusted_dimensions.emplace_back(non_degenerate_dimension->second);
-    }
-  }
-  return adjusted_dimensions;
-}
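-
-// For illustration, with a hypothetical shape (any concrete values would do):
-//
-//   Shape shape = ShapeUtil::MakeShape(F32, {1, 4, 1, 8});
-//   // Size-1 dimensions 0 and 2 are dropped; dimension 1 maps to 0 and
-//   // dimension 3 maps to 1, so this yields {0, 1}.
-//   std::vector<int64_t> adjusted = AdjustedDimensions(shape, {1, 3});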
-
-// Returns the dimensions of broadcast or reduction instructions, adjusted for
-// the removal of any degenerate dimensions in the output or input.
-std::vector<int64_t> AdjustedDimensions(const HloInstruction* instr) {
-  Shape shape;
-  if (instr->opcode() == HloOpcode::kBroadcast) {
-    shape = instr->shape();
-  } else if (instr->opcode() == HloOpcode::kReduce) {
-    shape = instr->operand(0)->shape();
-  } else {
-    return {};
-  }
-  return AdjustedDimensions(shape, instr->dimensions());
-}
-
-// Returns whether the HLO Computation applied by instr calculates the sum of
-// the elements. When provided, compares reduce_dims to the dimensions of the
-// reduction.
-bool AppliesAddReduce(const HloInstruction* instr,
-                      absl::Span<const int64_t> reduce_dims = {}) {
-  if (instr->opcode() != HloOpcode::kReduce) {
-    return false;
-  }
-
-  // Verify the dimensions of the reduction.
-  if (!reduce_dims.empty() && AdjustedDimensions(instr) != reduce_dims) {
-    return false;
-  }
-
-  HloComputation* reduce_comp = instr->to_apply();
-  HloInstruction* reduce_comp_root = reduce_comp->root_instruction();
-  return instr->operand_count() == 2 &&
-         instr->operand(1)->opcode() == HloOpcode::kConstant &&
-         ShapeUtil::IsScalar(instr->operand(1)->shape()) &&
-         instr->operand(1)->literal().GetAsDouble({}) == 0. &&
-         reduce_comp_root->opcode() == HloOpcode::kAdd &&
-         reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter &&
-         reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter;
-}
-
-// Returns whether instr multiplies the result of a reduction by one over the
-// number of reduced elements.
-bool CalculatesExpectation(const HloInstruction* instr) {
-  instr = SkipUnaryOps(instr);
-  if (instr->opcode() != HloOpcode::kMultiply) {
-    return false;
-  }
-  bool bcast_operand = instr->operand(0)->opcode() != HloOpcode::kBroadcast;
-  const HloInstruction *broadcast = instr->operand(bcast_operand),
-                       *reduce = SkipUnaryOps(instr->operand(!bcast_operand));
-  if (reduce->opcode() != HloOpcode::kReduce ||
-      broadcast->opcode() != HloOpcode::kBroadcast ||
-      broadcast->operand(0)->opcode() != HloOpcode::kConstant) {
-    return false;
-  }
-
-  float actual_r_nelems =
-      broadcast->operand(0)->literal().GetAsDouble({}).value();
-  int64_t nelems = 1;
-  for (int64_t norm_dim : reduce->dimensions()) {
-    nelems *= reduce->operand(0)->shape().dimensions()[norm_dim];
-  }
-  // The absolute value of the difference between the actual scaling factor
-  // and the reference value must not exceed a prescribed threshold.
-  float r_nelems = 1. / static_cast<float>(nelems);
-  float numerical_epsilon = std::numeric_limits<bfloat16>::epsilon();
-  return abs(actual_r_nelems - r_nelems) <
-         ((actual_r_nelems + r_nelems) * numerical_epsilon);
-}
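-
-// For example, with an illustrative normalization dimension of 768 elements:
-// the pattern is only matched if the broadcast scalar is approximately 1/768,
-// within a tolerance scaled by the bfloat16 epsilon.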
-
-// Returns whether target can be reached from instr by recursively traversing
-// the graph across converts, bitcasts and reshapes.
-bool FindTargetRecursive(
-    const HloInstruction* instr, const HloInstruction* target,
-    absl::flat_hash_set<const HloInstruction*>& visited_instrs,
-    const HloInstruction* transpose) {
-  visited_instrs.emplace(instr);
-  const absl::flat_hash_set<HloOpcode> supported_ops = {
-      HloOpcode::kConvert, HloOpcode::kBitcast, HloOpcode::kReshape};
-  if (instr == target) {
-    return true;
-  }
-  // Look for target among the users of instr.
-  for (HloInstruction* user : instr->users()) {
-    if ((supported_ops.contains(user->opcode()) || user == transpose) &&
-        !visited_instrs.contains(user)) {
-      return FindTargetRecursive(user, target, visited_instrs, transpose);
-    }
-  }
-  // Ascend the graph if target is not found and instr is a convert, bitcast
-  // or reshape.
-  if (supported_ops.contains(instr->opcode())) {
-    return FindTargetRecursive(instr->operand(0), target, visited_instrs,
-                               transpose);
-  }
-  return false;
-}
-
-bool FindTarget(const HloInstruction* custom_call, const HloInstruction* instr,
-                const HloInstruction* target,
-                const NormMetadataMap& norm_metadata) {
-  absl::flat_hash_set<const HloInstruction*> visited_instrs;
-  auto custom_call_metadata = norm_metadata.find(custom_call);
-  if (custom_call_metadata == norm_metadata.end()) {
-    return false;
-  }
-  return FindTargetRecursive(instr, target, visited_instrs,
-                             custom_call_metadata->second.x_transpose);
-}
-
-// Maps the dimension numbers in dimensions from shape original_shape to shape
-// reshaped_shape, assuming that the shapes are related through a strict
-// reshape. Returns an empty vector if a dimension mapping is not found.
-std::vector<int64_t> MapDimensions(const Shape& original_shape,
-                                   const Shape& reshaped_shape,
-                                   const absl::Span<const int64_t> dimensions) {
-  auto dimension_product =
-      [](const Shape& shape,
-         absl::Span<const int64_t> product_dimensions) -> int64_t {
-    int64_t product = 1;
-    for (int64_t product_dimension : product_dimensions) {
-      product *= shape.dimensions(product_dimension);
-    }
-    return product;
-  };
-  // Construct the dimension mapping.
-  absl::flat_hash_map<int64_t, std::vector<int64_t>> dimensions_map;
-  std::vector<int64_t> original_dimensions, reshaped_dimensions;
-  for (int64_t original_dimension = 0, reshaped_dimension = 0;
-       original_dimension < original_shape.rank(); ++original_dimension) {
-    original_dimensions.emplace_back(original_dimension);
-    while ((reshaped_dimensions.empty() ||
-            dimension_product(reshaped_shape, reshaped_dimensions) <
-                dimension_product(original_shape, original_dimensions)) &&
-           reshaped_dimension < reshaped_shape.rank()) {
-      reshaped_dimensions.emplace_back(reshaped_dimension++);
-    }
-
-    // Many-to-many dimension mappings are not supported.
-    if (original_dimensions.size() > 1 && reshaped_dimensions.size() > 1) {
-      return {};
-    }
-
-    if (dimension_product(original_shape, original_dimensions) ==
-        dimension_product(reshaped_shape, reshaped_dimensions)) {
-      std::vector<int64_t> original_dimensions_in_dimensions;
-      std::set_intersection(
-          original_dimensions.begin(), original_dimensions.end(),
-          dimensions.begin(), dimensions.end(),
-          std::back_inserter(original_dimensions_in_dimensions));
-      // The unique mapping of dimensions requires either all or none of the
-      // entries of original_dimensions to be an element of dimensions.
-      if (!original_dimensions_in_dimensions.empty() &&
-          original_dimensions_in_dimensions.size() !=
-              original_dimensions.size()) {
-        return {};
-      }
-      for (int64_t dimension : original_dimensions) {
-        dimensions_map.insert({dimension, reshaped_dimensions});
-      }
-      original_dimensions.clear();
-      reshaped_dimensions.clear();
-    }
-  }
-
-  // Map the dimension numbers to the reshaped shape.
-  std::vector<int64_t> mapped_dimensions;
-  for (int64_t dimension : dimensions) {
-    auto mapped_dimension = dimensions_map.find(dimension);
-    if (mapped_dimension == dimensions_map.end()) {
-      return {};
-    }
-    mapped_dimensions.insert(mapped_dimensions.end(),
-                             mapped_dimension->second.begin(),
-                             mapped_dimension->second.end());
-  }
-
-  // Eliminate duplicates in the mapped dimension numbers.
-  mapped_dimensions.erase(
-      std::unique(mapped_dimensions.begin(), mapped_dimensions.end()),
-      mapped_dimensions.end());
-  return mapped_dimensions;
-}
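-
-// A worked example with hypothetical shapes: across the reshape
-// f32[2,3,4] -> f32[6,4], original dimensions {0, 1} collapse into reshaped
-// dimension 0 and original dimension 2 becomes reshaped dimension 1, so
-// mapping {0, 1} yields {0} and mapping {2} yields {1}. Mapping {1} alone
-// yields {} because dimension 1 cannot be separated from dimension 0 after
-// the reshape.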
-
-// Recursively traverses the graph across converts, bitcasts and reshapes,
-// starting from instr, and returns the first addition-reduction identified.
-// Returns nullptr if no addition-reduction is found.
-HloInstruction* FindAddReduceRecursive(
-    HloInstruction* instr, const Shape& orig_instr_shape,
-    const absl::Span<const int64_t> reduce_dims,
-    absl::flat_hash_set<HloInstruction*>& visited_instrs) {
-  visited_instrs.emplace(instr);
-  const absl::flat_hash_set<HloOpcode> supported_ops = {
-      HloOpcode::kConvert, HloOpcode::kBitcast, HloOpcode::kReshape};
-  // Look for a reduction among the users of instr.
-  for (HloInstruction* user : instr->users()) {
-    if (user->opcode() == HloOpcode::kReduce) {
-      std::vector<int64_t> mapped_reduce_dims =
-          MapDimensions(orig_instr_shape, instr->shape(), reduce_dims);
-      if (!mapped_reduce_dims.empty() &&
-          AppliesAddReduce(user, mapped_reduce_dims)) {
-        return user;
-      }
-    }
-    if (supported_ops.contains(user->opcode()) &&
-        !visited_instrs.contains(user)) {
-      return FindAddReduceRecursive(user, orig_instr_shape, reduce_dims,
-                                    visited_instrs);
-    }
-  }
-  // Ascend the graph if the addition-reduction is not found and instr is a
-  // convert, bitcast or reshape.
-  if (supported_ops.contains(instr->opcode())) {
-    return FindAddReduceRecursive(instr->mutable_operand(0), orig_instr_shape,
-                                  reduce_dims, visited_instrs);
-  }
-  return nullptr;
-}
-
-HloInstruction* FindAddReduce(HloInstruction* instr,
-                              const absl::Span<const int64_t> reduce_dims) {
-  absl::flat_hash_set<HloInstruction*> visited_instrs;
-  return FindAddReduceRecursive(instr, instr->shape(), reduce_dims,
-                                visited_instrs);
-}
-
-// Type conversion from and to any of BF16, FP16 and FP32.
-template <typename Pattern>
-auto SupportedConvert(Pattern pattern) {
-  auto supported_convert = [](const HloInstruction* instr) -> bool {
-    return CompatibleElementType(instr) &&
-           CompatibleElementType(instr->operand(0));
-  };
-  return m::Convert(pattern).WithPredicate(supported_convert);
-}
-
-// Bitcast or reshape adding or removing degenerate dimensions.
-template <typename Pattern>
-auto SupportedBitcastOrReshape(Pattern pattern) {
-  auto supported_bitcast_or_reshape = [](const HloInstruction* instr) -> bool {
-    return ShapeUtil::Equal(
-        ShapeUtil::DropDegenerateDimensions(instr->shape()),
-        ShapeUtil::DropDegenerateDimensions(instr->operand(0)->shape()));
-  };
-  return m::AnyOf<HloInstruction>(
-      m::Bitcast(pattern).WithPredicate(supported_bitcast_or_reshape),
-      m::Reshape(pattern).WithPredicate(supported_bitcast_or_reshape));
-}
-
-// Matches pattern, SupportedConvert(pattern),
-// SupportedBitcastOrReshape(pattern),
-// SupportedConvert(SupportedBitcastOrReshape(pattern)) and
-// SupportedBitcastOrReshape(SupportedConvert(pattern)).
-template <typename Pattern>
-auto OptionalSupportedTransform(Pattern pattern) {
-  auto shared_subpattern = m::SharedSubpattern(pattern);
-  return m::AnyOf<HloInstruction>(
-      SupportedConvert(SupportedBitcastOrReshape(shared_subpattern)),
-      SupportedBitcastOrReshape(SupportedConvert(shared_subpattern)),
-      SupportedConvert(shared_subpattern),
-      SupportedBitcastOrReshape(shared_subpattern), shared_subpattern);
-}
-
-// Bitcast or reshape with optional supported type conversion and/or addition or
-// removal of degenerate dimensions.
-template <typename Pattern>
-auto BitcastOrReshape(Pattern pattern) {
-  return OptionalSupportedTransform(
-      m::AnyOf<HloInstruction>(m::Bitcast(pattern), m::Reshape(pattern)));
-}
-
-// Transpose with optional supported type conversion and/or addition or removal
-// of degenerate dimensions.
-template <typename Pattern>
-auto Transpose(Pattern pattern) {
-  return OptionalSupportedTransform(m::Transpose(pattern));
-}
-
-// Rsqrt with optional supported type conversion and/or addition or removal of
-// degenerate dimensions.
-template <typename Pattern>
-auto Rsqrt(HloInstruction** rsqrt, Pattern pattern) {
-  return OptionalSupportedTransform(m::Rsqrt(rsqrt, pattern));
-}
-
-// AddAnyOrder with optional supported type conversion and/or addition or
-// removal of degenerate dimensions.
-template <typename Pattern0, typename Pattern1>
-auto AddAnyOrder(Pattern0 pattern0, Pattern1 pattern1) {
-  return OptionalSupportedTransform(m::AddAnyOrder(pattern0, pattern1));
-}
-
-// Subtract with optional supported type conversion and/or addition or removal
-// of degenerate dimensions.
-template <typename Pattern0, typename Pattern1>
-auto Subtract(Pattern0 pattern0, Pattern1 pattern1) {
-  return OptionalSupportedTransform(m::Subtract(pattern0, pattern1));
-}
-
-// Capturing subtract with optional supported type conversion and/or addition or
-// removal of degenerate dimensions.
-template <typename Pattern0, typename Pattern1>
-auto Subtract(HloInstruction** subtract, Pattern0 pattern0, Pattern1 pattern1) {
-  return OptionalSupportedTransform(m::Subtract(subtract, pattern0, pattern1));
-}
-
-// Multiply with optional supported type conversion and/or addition or removal
-// of degenerate dimensions.
-template <typename Pattern0, typename Pattern1>
-auto MultiplyAnyOrder(Pattern0 pattern0, Pattern1 pattern1) {
-  return OptionalSupportedTransform(m::MultiplyAnyOrder(pattern0, pattern1));
-}
-
-// Capturing multiply with optional supported type conversion and/or addition or
-// removal of degenerate dimensions.
-template <typename Pattern0, typename Pattern1>
-auto MultiplyAnyOrder(HloInstruction** multiply, Pattern0 pattern0,
-                      Pattern1 pattern1) {
-  return OptionalSupportedTransform(
-      m::MultiplyAnyOrder(multiply, pattern0, pattern1));
-}
-
-// Multiplication of pattern by itself with optional supported type conversion
-// and/or addition or removal of degenerate dimensions.
-template <typename Pattern>
-auto Square(Pattern pattern) {
-  return MultiplyAnyOrder(pattern, pattern)
-      .WithPredicate([](const HloInstruction* instr) {
-        return instr->unique_operands().size() == 1;
-      });
-}
-
-// Multiplication of the square of pattern by pattern with optional supported
-// type conversion and/or addition or removal of degenerate dimensions. The root
-// instruction of pattern cannot be a multiplication.
-template <typename Pattern>
-auto Cube(Pattern pattern) {
-  auto unique_cube = [](const HloInstruction* instr) -> bool {
-    bool square_operand = instr->operand(0)->opcode() != HloOpcode::kMultiply;
-    return instr->operand(!square_operand)->opcode() != HloOpcode::kMultiply &&
-           instr->operand(square_operand)->operand(0) ==
-               instr->operand(!square_operand);
-  };
-  return MultiplyAnyOrder(Square(pattern), pattern).WithPredicate(unique_cube);
-}
-
-// Addition-reduction of pattern with optional supported type conversion and/or
-// addition or removal of degenerate dimensions.
-template <typename Pattern>
-auto AddReduce(Pattern pattern) {
-  return OptionalSupportedTransform(
-      m::Reduce(pattern, m::Op())
-          .WithPredicate([](const HloInstruction* instr) {
-            return AppliesAddReduce(instr);
-          }));
-}
-
-// Capturing addition-reduction of pattern with optional supported type
-// conversion and/or addition or removal of degenerate dimensions.
-template <typename Pattern>
-auto AddReduce(HloInstruction** reduction, Pattern pattern) {
-  return OptionalSupportedTransform(
-      m::Reduce(reduction, pattern, m::Op())
-          .WithPredicate([](const HloInstruction* instr) {
-            return AppliesAddReduce(instr);
-          }));
-}
-
-// Negated addition-reduction.
-template <typename Pattern>
-auto NegateAddReduce(HloInstruction** reduction, Pattern pattern) {
-  return m::AnyOf<HloInstruction>(AddReduce(reduction, m::Negate(pattern)),
-                                  m::Negate(AddReduce(reduction, pattern)));
-}
-
-// Expected value, or mean, with optional broadcast.
-template <typename Pattern>
-auto Expectation(Pattern pattern) {
-  auto shared_subpattern =
-      MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), AddReduce(pattern))
-          .WithPredicate([](const HloInstruction* instr) {
-            return CalculatesExpectation(instr);
-          });
-  return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
-                                  shared_subpattern);
-}
-
-// Expected value, or mean, with optional broadcast.
-template <typename Pattern>
-auto Expectation(UniqueHloInstruction* expectation, Pattern pattern) {
-  auto shared_subpattern = OptionalSupportedTransform(
-      m::MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), AddReduce(pattern))
-          .WithPredicate([](const HloInstruction* instr) {
-            return CalculatesExpectation(instr);
-          })
-          .WithPredicate(expectation->GetCaptureOrVerifyFn()));
-  return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
-                                  shared_subpattern);
-}
-
-// Expected value, or mean, with optional broadcast.
-template <typename Pattern>
-auto Expectation(UniqueHloInstruction* expectation, HloInstruction** reduce,
-                 Pattern pattern) {
-  auto shared_subpattern = OptionalSupportedTransform(
-      m::MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()),
-                          AddReduce(reduce, pattern))
-          .WithPredicate([](const HloInstruction* instr) {
-            return CalculatesExpectation(instr);
-          })
-          .WithPredicate(expectation->GetCaptureOrVerifyFn()));
-  return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
-                                  shared_subpattern);
-}
-
-// Variance, expressed as expectation(X^2) - expectation(X)^2 or
-// expectation((X - expectation(X))^2).
-auto Variance(UniqueHloInstruction* variance, UniqueHloInstruction* expectation,
-              UniqueHloInstruction* x) {
-  return m::AnyOf<HloInstruction>(
-      Subtract(
-          Expectation(Square(OptionalSupportedTransform(
-              m::Op().WithPredicate(x->GetCaptureOrVerifyFn())))),
-          Square(Expectation(expectation,
-                             OptionalSupportedTransform(m::Op().WithPredicate(
-                                 x->GetCaptureOrVerifyFn())))))
-          .WithPredicate(variance->GetCaptureOrVerifyFn()),
-      Expectation(
-          Square(Subtract(
-              OptionalSupportedTransform(
-                  m::Op().WithPredicate(x->GetCaptureOrVerifyFn())),
-              Expectation(expectation,
-                          OptionalSupportedTransform(m::Op().WithPredicate(
-                              x->GetCaptureOrVerifyFn()))))))
-          .WithPredicate(variance->GetCaptureOrVerifyFn()));
-}
-
-// Reciprocal of the square root of variance + epsilon with optional broadcast.
-auto NormFactor(HloInstruction** norm_factor, UniqueHloInstruction* x,
-                UniqueHloInstruction* variance,
-                UniqueHloInstruction* expectation,
-                UniqueHloInstruction* epsilon) {
-  auto shared_subpattern = m::SharedSubpattern(Rsqrt(
-      norm_factor, AddAnyOrder(Variance(variance, expectation, x),
-                               m::Broadcast(m::ConstantScalar().WithPredicate(
-                                   epsilon->GetCaptureOrVerifyFn())))));
-  return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
-                                  shared_subpattern);
-}
-
-// Any order of p0 * p1 * p2.
-template <typename P0, typename P1, typename P2>
-auto MultiplyMultiplyAnyOrder(P0 p0, P1 p1, P2 p2) {
-  return m::AnyOf<HloInstruction>(
-      MultiplyAnyOrder(p0, MultiplyAnyOrder(p1, p2)),
-      MultiplyAnyOrder(p1, MultiplyAnyOrder(p0, p2)),
-      MultiplyAnyOrder(p2, MultiplyAnyOrder(p0, p1)));
-}
-
-// Any order of p0 + p1 + p2.
-template <typename P0, typename P1, typename P2>
-auto AddAddAnyOrder(P0 p0, P1 p1, P2 p2) {
-  return m::AnyOf<HloInstruction>(AddAnyOrder(p0, AddAnyOrder(p1, p2)),
-                                  AddAnyOrder(p1, AddAnyOrder(p0, p2)),
-                                  AddAnyOrder(p2, AddAnyOrder(p0, p1)));
-}
-
-// Any order of p0 * (p1 + p2).
-template <typename P0, typename P1, typename P2>
-auto MultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2) {
-  return m::AnyOf<HloInstruction>(
-      MultiplyAnyOrder(p0, AddAnyOrder(p1, p2)),
-      AddAnyOrder(MultiplyAnyOrder(p0, p1), MultiplyAnyOrder(p0, p2)));
-}
-
-// Any order of p0 - p1 + p2.
-template <typename P0, typename P1, typename P2>
-auto SubtractAddAnyOrder(P0 p0, P1 p1, P2 p2) {
-  return m::AnyOf<HloInstruction>(AddAnyOrder(Subtract(p0, p1), p2),
-                                  AddAnyOrder(Subtract(p2, p1), p0),
-                                  Subtract(AddAnyOrder(p0, p2), p1));
-}
-
-// Any order of (p0 - p1) * p2 * p3 + p4.
-template <typename P0, typename P1, typename P2, typename P3, typename P4>
-auto SubtractMultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2, P3 p3, P4 p4) {
-  return m::AnyOf<HloInstruction>(
-      SubtractAddAnyOrder(MultiplyMultiplyAnyOrder(p0, p2, p3),
-                          MultiplyMultiplyAnyOrder(p1, p2, p3), p4),
-      AddAnyOrder(MultiplyMultiplyAnyOrder(Subtract(p0, p1), p2, p3), p4));
-}
-
-// Expectation fused into a layer norm Custom Call.
-auto FusedExpectation(UniqueHloInstruction* custom_call) {
-  auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
-      m::CustomCall({kCudnnNormCallTarget})
-          .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
-      1));
-  return m::AnyOf<HloInstruction>(shared_subpattern,
-                                  BitcastOrReshape(shared_subpattern));
-}
-
-// Expectation fused into a layer norm Custom Call.
-auto FusedExpectation(UniqueHloInstruction* fused_expectation,
-                      UniqueHloInstruction* custom_call) {
-  auto shared_subpattern = m::SharedSubpattern(
-      m::GetTupleElement(
-          m::CustomCall({kCudnnNormCallTarget})
-              .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
-          1)
-          .WithPredicate(fused_expectation->GetCaptureOrVerifyFn()));
-  return m::AnyOf<HloInstruction>(shared_subpattern,
-                                  BitcastOrReshape(shared_subpattern));
-}
-
-// Norm factor fused into a layer norm Custom Call.
-auto FusedNormFactor(UniqueHloInstruction* custom_call) {
-  auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
-      m::CustomCall({kCudnnNormCallTarget})
-          .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
-      2));
-  return m::AnyOf<HloInstruction>(shared_subpattern,
-                                  BitcastOrReshape(shared_subpattern));
-}
-
-// Norm factor fused into a layer norm Custom Call.
-auto FusedNormFactor(UniqueHloInstruction* fused_norm_factor,
-                     UniqueHloInstruction* custom_call) {
-  auto shared_subpattern = m::SharedSubpattern(
-      m::GetTupleElement(
-          m::CustomCall({kCudnnNormCallTarget})
-              .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
-          2)
-          .WithPredicate(fused_norm_factor->GetCaptureOrVerifyFn()));
-  return m::AnyOf<HloInstruction>(shared_subpattern,
-                                  BitcastOrReshape(shared_subpattern));
-}
-
-// Derivative of the norm factor w.r.t. variance + epsilon,
-// d(norm_factor)/d(variance + epsilon)
-// = d((variance + epsilon)^-1/2)/d(variance + epsilon)
-// = -1/2 * norm_factor^3.
-// Forwards custom_call to FusedNormFactor for verification.
-auto DNormFactor(UniqueHloInstruction* custom_call) {
-  return MultiplyAnyOrder(m::Broadcast(m::ConstantScalar(-0.5)),
-                          Cube(FusedNormFactor(custom_call)));
-}
-
-// Zero-centered input of the layer norm, X - expectation(X). Verifies that
-// custom_call is a forward layer norm fusing X. Forwards custom_call to
-// FusedExpectation for verification.
-auto XCenter(UniqueHloInstruction* x, UniqueHloInstruction* custom_call,
-             const NormMetadataMap& norm_metadata) {
-  auto capture_or_verify_x =
-      [x, custom_call, &norm_metadata](const HloInstruction* instr) -> bool {
-    return x->CaptureOrVerify(
-        FindTarget(custom_call->Instr(), instr->operand(0),
-                   custom_call->Instr()->operand(0), norm_metadata)
-            ? custom_call->Instr()->mutable_operand(0)
-            : nullptr);
-  };
-  return Subtract(m::Op(), m::Broadcast(FusedExpectation(custom_call)))
-      .WithPredicate(capture_or_verify_x);
-}
-
-// Zero-centered input of the layer norm, X - expectation(X). Captures X in x if
-// custom_call is a forward layer norm fusing X. Forwards custom_call to
-// FusedExpectation for comparison.
-auto XCenter(UniqueHloInstruction* x_center, UniqueHloInstruction* x,
-             UniqueHloInstruction* fused_expectation,
-             UniqueHloInstruction* custom_call,
-             const NormMetadataMap& norm_metadata) {
-  auto capture_or_verify_x =
-      [x, custom_call, &norm_metadata](const HloInstruction* instr) -> bool {
-    return x->CaptureOrVerify(
-        FindTarget(custom_call->Instr(), instr->operand(0),
-                   custom_call->Instr()->operand(0), norm_metadata)
-            ? custom_call->Instr()->mutable_operand(0)
-            : nullptr);
-  };
-  return Subtract(m::Op(), m::Broadcast(FusedExpectation(fused_expectation,
-                                                         custom_call)))
-      .WithPredicate(x_center->GetCaptureOrVerifyFn())
-      .WithPredicate(capture_or_verify_x);
-}
-
-// Addition-reduction of the product of XCenter, the broadcasted scale and DY,
-// XCenter * scale * DY. Captures the scale in scale if custom_call is a forward
-// layer norm fusing the scale. Forwards custom_call to XCenter for comparison.
-auto F0(UniqueHloInstruction* custom_call, UniqueHloInstruction* scale,
-        UniqueHloInstruction* dy, UniqueHloInstruction* x,
-        HloInstruction** reduce, const NormMetadataMap& norm_metadata) {
-  auto capture_or_verify_scale = [scale, custom_call, &norm_metadata](
-                                     const HloInstruction* instr) -> bool {
-    return scale->CaptureOrVerify(FindTarget(custom_call->Instr(), instr,
-                                             custom_call->Instr()->operand(1),
-                                             norm_metadata)
-                                      ? custom_call->Instr()->mutable_operand(1)
-                                      : nullptr);
-  };
-  return AddReduce(
-      reduce, MultiplyMultiplyAnyOrder(
-                  XCenter(x, custom_call, norm_metadata),
-                  m::Broadcast(m::Op().WithPredicate(capture_or_verify_scale)),
-                  m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
-}
-
-// Product of XCenter and the scaled and broadcasted product of F0 and
-// d(norm_factor)/d(variance + epsilon), XCenter * F0 * DNormFactor * 2 /
-// nelems. Forwards custom_call to XCenter, F0 and DNormFactor for capture or
-// verification.
-auto F1(UniqueHloInstruction* x, UniqueHloInstruction* x_center,
-        UniqueHloInstruction* fused_expectation,
-        UniqueHloInstruction* custom_call, UniqueHloInstruction* scale,
-        UniqueHloInstruction* dy, HloInstruction** reduce,
-        const NormMetadataMap& norm_metadata) {
-  auto broadcasts_two_over_nelems = [](const HloInstruction* instr) -> bool {
-    const HloInstruction* multiply = SkipUnaryOps(instr->operand(0));
-    bool bcast_operand =
-        multiply->operand(0)->opcode() != HloOpcode::kBroadcast;
-
-    // The captured scalar must be two over the number of elements in the
-    // broadcasted dimensions.
-    float actual_two_over_nelems = multiply->operand(bcast_operand)
-                                       ->operand(0)
-                                       ->literal()
-                                       .GetAsDouble({})
-                                       .value();
-    int64_t nelems = 1;
-    for (int i = 0; i < instr->shape().dimensions_size(); ++i) {
-      if (!absl::c_linear_search(instr->dimensions(), i)) {
-        nelems *= instr->shape().dimensions()[i];
-      }
-    }
-    // The absolute value of the difference between the actual scaling factor
-    // and the reference value must not exceed a prescribed threshold.
-    float two_over_nelems = 2. / static_cast<float>(nelems);
-    float numerical_epsilon = std::numeric_limits<bfloat16>::epsilon();
-    return abs(actual_two_over_nelems - two_over_nelems) <
-           ((actual_two_over_nelems + two_over_nelems) * numerical_epsilon);
-  };
-
-  return MultiplyAnyOrder(
-      XCenter(x_center, x, fused_expectation, custom_call, norm_metadata),
-      m::Broadcast(
-          MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()),
-                           MultiplyAnyOrder(DNormFactor(custom_call),
-                                            F0(custom_call, scale, dy, x,
-                                               reduce, norm_metadata))))
-          .WithPredicate(broadcasts_two_over_nelems));
-}
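For illustration, the relative-tolerance check performed by broadcasts_two_over_nelems above amounts to the following small standalone helper (a sketch, not part of the pass; the hard-coded tolerance 2^-7 stands in for the std::numeric_limits<bfloat16>::epsilon() used in the code):

    #include <cmath>
    #include <cstdint>

    // Returns true if `actual` is approximately 2 / nelems under the same
    // relative tolerance the matcher above applies.
    bool IsApproximatelyTwoOverNelems(float actual, int64_t nelems) {
      const float expected = 2.f / static_cast<float>(nelems);
      const float tolerance = 0.0078125f;  // 2^-7, the bfloat16 machine epsilon.
      return std::abs(actual - expected) < (actual + expected) * tolerance;
    }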
-
-// Product of the norm factor, scale and DY, NormFactor * scale * DY. Captures
-// the scale in scale if custom_call is a forward layer norm fusing the scale.
-// Forwards custom_call to FusedNormFactor for comparison.
-auto F2(UniqueHloInstruction* fused_norm_factor, UniqueHloInstruction* scale,
-        UniqueHloInstruction* dy, UniqueHloInstruction* custom_call,
-        const NormMetadataMap& norm_metadata) {
-  auto capture_or_verify_scale = [scale, custom_call, &norm_metadata](
-                                     const HloInstruction* instr) -> bool {
-    return scale->CaptureOrVerify(
-        FindTarget(custom_call->Instr(), instr->operand(0),
-                   custom_call->Instr()->operand(1), norm_metadata)
-            ? custom_call->Instr()->mutable_operand(1)
-            : nullptr);
-  };
-  return MultiplyAnyOrder(
-      m::Broadcast(
-          BitcastOrReshape(FusedNormFactor(fused_norm_factor, custom_call))),
-      MultiplyAnyOrder(m::Broadcast().WithPredicate(capture_or_verify_scale),
-                       m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
-}
-
-class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
- public:
-  explicit CudnnNormRewriterVisitor(
-      const se::CudaComputeCapability cuda_compute_capability)
-      : cuda_compute_capability_(cuda_compute_capability) {}
-
-  absl::Status HandleAdd(HloInstruction* instr) override {
-    TF_RETURN_IF_ERROR(MatchLayerNorm(instr));
-    TF_RETURN_IF_ERROR(MatchLayerNormGradient(instr));
-    return absl::OkStatus();
-  }
-
-  absl::Status HandleSubtract(HloInstruction* instr) override {
-    return MatchLayerNorm(instr);
-  }
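For reference, the computation matched by MatchLayerNorm below, Y = (X - expectation(X)) / sqrt(variance(X) + epsilon) * scale + bias, can be written as a minimal standalone sketch. The row-major layout, the normalization over the last dimension and the function name are assumptions for illustration only:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Layer norm of an n x d row-major matrix, normalizing each row over the
    // last dimension; the variance is computed as E[X^2] - E[X]^2 to mirror
    // the HLO pattern the pass matches.
    std::vector<float> LayerNormReference(const std::vector<float>& x,
                                          const std::vector<float>& scale,
                                          const std::vector<float>& bias,
                                          size_t n, size_t d, float epsilon) {
      std::vector<float> y(n * d);
      for (size_t i = 0; i < n; ++i) {
        float mean = 0.f, mean_of_squares = 0.f;
        for (size_t j = 0; j < d; ++j) {
          const float v = x[i * d + j];
          mean += v / d;
          mean_of_squares += v * v / d;
        }
        const float variance = mean_of_squares - mean * mean;
        const float norm_factor = 1.f / std::sqrt(variance + epsilon);
        for (size_t j = 0; j < d; ++j) {
          y[i * d + j] = (x[i * d + j] - mean) * norm_factor * scale[j] + bias[j];
        }
      }
      return y;
    }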
-
-  // Matches and rewrites layer norm patterns,
-  // Y = (X - expectation(X))/sqrt(variance(X) + epsilon) * scale + bias,
-  // into Custom Calls to cuDNN.
-  absl::Status MatchLayerNorm(HloInstruction* instr) {
-    UniqueHloInstruction x, expectation, variance, epsilon;
-    HloInstruction *scale, *bias, *reduce, *norm_factor, *broadcast_scale,
-        *broadcast_bias;
-    if (Match(
-            instr,
-            SubtractMultiplyAddAnyOrder(
-                OptionalSupportedTransform(
-                    m::Op().WithPredicate(x.GetCaptureOrVerifyFn())),
-                Expectation(&expectation, &reduce,
-                            OptionalSupportedTransform(m::Op().WithPredicate(
-                                x.GetCaptureOrVerifyFn()))),
-                NormFactor(&norm_factor, &x, &variance, &expectation, &epsilon),
-                m::Broadcast(&broadcast_scale, m::Op(&scale)),
-                m::Broadcast(&broadcast_bias, m::Op(&bias))))) {
-#if CUDNN_VERSION < 8905
-      // Layer norm kernels are available with cuDNN 8.9.5 and above.
-      VLOG(1) << "Layer norm Custom Calls require cuDNN 8.9.5.";
-      return absl::OkStatus();
-#endif  // CUDNN_VERSION < 8905
-
-      if (!instr->GetModule()
-               ->config()
-               .debug_options()
-               .xla_gpu_enable_cudnn_layer_norm()) {
-        VLOG(1) << "Layer norm Custom Calls disabled.";
-        return absl::OkStatus();
-      }
-
-      // Layer norm kernels require Ampere or Hopper architectures.
-      if (cuda_compute_capability_.major != se::CudaComputeCapability::AMPERE &&
-          cuda_compute_capability_.major != se::CudaComputeCapability::HOPPER) {
-        VLOG(1) << "Layer norm Custom Calls require Ampere or Hopper "
-                   "architectures.";
-        return absl::OkStatus();
-      }
-
-      // Verify the uniqueness of the inputs.
-      if (!x.Instr() || !expectation.Instr() || !variance.Instr() ||
-          !epsilon.Instr()) {
-        VLOG(1) << "Layer norm operands not unique.";
-        return absl::OkStatus();
-      }
-
-      // Verify the input and output layouts.
-      // TODO(philipphack): Consider supporting more general cases.
-      if (!LayoutUtil::IsMonotonicWithDim0Major(x.Instr()->shape().layout()) ||
-          !LayoutUtil::IsMonotonicWithDim0Major(scale->shape().layout()) ||
-          !LayoutUtil::IsMonotonicWithDim0Major(bias->shape().layout()) ||
-          !LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout())) {
-        VLOG(1) << "Layer norm input and/or output layouts nor supported.";
-        return absl::OkStatus();
-      }
-
-      // Verify the element types. The element types of the input and output
-      // must match, as must the shapes of scale and bias.
-      if (!CompatibleElementType(instr) || !CompatibleElementType(scale) ||
-          !CompatibleElementType(bias) ||
-          !ShapeUtil::SameElementType(instr->shape(), x.Instr()->shape()) ||
-          !ShapeUtil::Equal(scale->shape(), bias->shape())) {
-        VLOG(1) << "Layer norm input types or shapes not supported.";
-        return absl::OkStatus();
-      }
-
-      // Verify that the shapes of scale and bias are compatible with the
-      // operation. The adjusted norm dimensions are the dimensions of the
-      // reduction after removing any degenerate dimensions from the input of
-      // the reduction.
-      std::vector<int64_t> norm_dims(reduce->dimensions().begin(),
-                                     reduce->dimensions().end());
-      std::vector<int64_t> norm_dims_adjusted = AdjustedDimensions(reduce);
-      if (norm_dims_adjusted.size() !=
-          ShapeUtil::DropDegenerateDimensions(scale->shape())
-              .dimensions_size()) {
-        VLOG(1) << "Layer norm input dimensions not supported.";
-        return absl::OkStatus();
-      }
-
-      // Verify the broadcasts of scale and bias.
-      if (!ShapeUtil::EqualIgnoringElementType(
-              ShapeUtil::DropDegenerateDimensions(reduce->operand(0)->shape()),
-              ShapeUtil::DropDegenerateDimensions(broadcast_scale->shape())) ||
-          !ShapeUtil::EqualIgnoringElementType(
-              ShapeUtil::DropDegenerateDimensions(reduce->operand(0)->shape()),
-              ShapeUtil::DropDegenerateDimensions(broadcast_bias->shape())) ||
-          norm_dims_adjusted != AdjustedDimensions(broadcast_scale) ||
-          norm_dims_adjusted != AdjustedDimensions(broadcast_bias)) {
-        VLOG(1) << "Layer norm operand broadcast not supported.";
-        return absl::OkStatus();
-      }
-
-      // If necessary, transpose the input so that the dimensions not being
-      // normalized are the leading dimensions.
-      std::vector<int64_t> non_norm_dims;
-      for (int64_t x_dim = 0; x_dim < x.Instr()->shape().rank(); ++x_dim) {
-        if (std::find(norm_dims.begin(), norm_dims.end(), x_dim) ==
-            norm_dims.end()) {
-          non_norm_dims.emplace_back(x_dim);
-        }
-      }
-      std::vector<int64_t> non_norm_dims_adjusted =
-          AdjustedDimensions(x.Instr()->shape(), non_norm_dims);
-
-      std::vector<int64_t> x_transpose_order = non_norm_dims;
-      x_transpose_order.insert(x_transpose_order.end(), norm_dims.begin(),
-                               norm_dims.end());
-
-      bool apply_transpose = false;
-      for (int i = 0; i < x_transpose_order.size(); ++i) {
-        if (x_transpose_order[i] != i) {
-          apply_transpose = true;
-          break;
-        }
-      }
-
-      std::optional<HloInstruction*> x_transpose;
-      // The transpose applied to the output is the inverse of the transpose
-      // applied to the input.
-      std::vector<int64_t> y_transpose_order(x_transpose_order.size());
-      if (apply_transpose) {
-        for (int k = 0; k < x_transpose_order.size(); ++k) {
-          y_transpose_order[x_transpose_order[k]] = k;
-        }
-        TF_ASSIGN_OR_RETURN(x_transpose,
-                            MakeTransposeHlo(x.Instr(), x_transpose_order));
-      }
-
-      // Combine the dimensions not normalized into the first dimension of the
-      // input as required by cuDNN.
-      std::vector<int64_t> reshaped_dims = {1};
-      for (auto non_norm_dim : non_norm_dims) {
-        reshaped_dims[0] *= x.Instr()->shape().dimensions(non_norm_dim);
-      }
-      for (auto norm_dim : norm_dims) {
-        reshaped_dims.emplace_back(x.Instr()->shape().dimensions(norm_dim));
-      }
-      // cuDNN requires tensors to have at least four dimensions.
-      while (reshaped_dims.size() < 4) {
-        reshaped_dims.emplace_back(1);
-      }
-
-      Shape reshaped_shape = ShapeUtil::MakeShape(
-          x.Instr()->shape().element_type(), reshaped_dims);
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * x_reshape,
-          MakeReshapeHlo(reshaped_shape, x_transpose.value_or(x.Instr())));
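As a concrete, hypothetical example of the collapse above: for an input of shape f32[2,4,6,8] normalized over dimension 3, the non-normalized dimensions contribute 2 * 4 * 6 = 48, and padding to the four dimensions cuDNN requires yields f32[48,8,1,1], which matches the bitcasts checked in the tests below. A minimal sketch of the dimension arithmetic (the helper name and shapes are assumptions):

    #include <cstdint>
    #include <vector>
    #include "absl/algorithm/container.h"

    std::vector<int64_t> CudnnReshapedDims(const std::vector<int64_t>& input_dims,
                                           const std::vector<int64_t>& norm_dims) {
      std::vector<int64_t> reshaped_dims = {1};
      // Fold all non-normalized dimensions into the leading dimension.
      for (int64_t dim = 0; dim < static_cast<int64_t>(input_dims.size()); ++dim) {
        if (!absl::c_linear_search(norm_dims, dim)) {
          reshaped_dims[0] *= input_dims[dim];
        }
      }
      // Append the normalized dimensions and pad to rank 4.
      for (int64_t norm_dim : norm_dims) {
        reshaped_dims.push_back(input_dims[norm_dim]);
      }
      while (reshaped_dims.size() < 4) {
        reshaped_dims.push_back(1);
      }
      return reshaped_dims;
    }

    // CudnnReshapedDims({2, 4, 6, 8}, {3}) returns {48, 8, 1, 1}.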
-
-      // Reshape the scale and bias. The first dimension corresponds to the
-      // non-normalization dimensions of the norm input and must have size 1.
-      std::vector<int64_t> reshaped_scale_dims = reshaped_dims;
-      reshaped_scale_dims[0] = 1;
-
-      Shape scale_bias_shape = ShapeUtil::MakeShape(
-          scale->shape().element_type(), reshaped_scale_dims);
-      TF_ASSIGN_OR_RETURN(HloInstruction * scale_reshape,
-                          MakeReshapeHlo(scale_bias_shape, scale));
-      TF_ASSIGN_OR_RETURN(HloInstruction * bias_reshape,
-                          MakeReshapeHlo(scale_bias_shape, bias));
-      GpuBackendConfig gpu_backend_config;
-      CudnnNormBackendConfig& backend_config =
-          *gpu_backend_config.mutable_cudnn_norm_backend_config();
-      backend_config.set_epsilon(
-          epsilon.Instr()->literal().GetAsDouble({}).value());
-      backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_INFER);
-      auto* algorithm = backend_config.mutable_algorithm();
-      algorithm->set_algo_id(0);
-      algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
-      algorithm->set_is_cudnn_frontend(true);
-
-      // Set the workspace size to its upper bound.
-      // TODO(philipphack): Consider autotuning the norm kernels.
-      TF_ASSIGN_OR_RETURN(const int64_t c_constant,
-                          CConstant(cuda_compute_capability_));
-      const int64_t workspace_size =
-          (2 * c_constant * (4 + 256)) + (2 * reshaped_dims[0] * 4) + 64;
-      algorithm->mutable_workspace_size()->set_value(workspace_size);
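As a worked, purely illustrative instance of the bound above — taking a placeholder c_constant of 128 (the real value is returned by CConstant() and depends on the compute capability) and the leading reshaped dimension 48 from the example above:

    #include <cstdint>

    constexpr int64_t c_constant = 128;   // hypothetical placeholder
    constexpr int64_t leading_dim = 48;   // reshaped_dims[0] in the example above
    constexpr int64_t workspace_size =
        (2 * c_constant * (4 + 256)) + (2 * leading_dim * 4) + 64;  // 67008 bytes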
-
-      // The output of the Custom Call is a tuple, the second element of which
-      // describes the scratch space.
-      Shape custom_call_shape = ShapeUtil::MakeTupleShape(
-          {x_reshape->shape(), ShapeUtil::MakeShape(U8, {workspace_size})});
-
-      HloInstruction* custom_call =
-          instr->AddInstruction(HloInstruction::CreateCustomCall(
-              custom_call_shape, {x_reshape, scale_reshape, bias_reshape},
-              kCudnnNormCallTarget));
-      TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
-
-      TF_ASSIGN_OR_RETURN(HloInstruction * gte,
-                          MakeGetTupleElementHlo(custom_call, 0));
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * y_reshape,
-          MakeReshapeHlo(x_transpose.value_or(instr)->shape(), gte));
-
-      std::optional<HloInstruction*> y_transpose;
-      if (apply_transpose) {
-        TF_ASSIGN_OR_RETURN(y_transpose,
-                            MakeTransposeHlo(y_reshape, y_transpose_order));
-      }
-      TF_RETURN_IF_ERROR(
-          ReplaceInstruction(instr, y_transpose.value_or(y_reshape)));
-
-      // Store metadata for potential use in the backward graph.
-      norm_metadata_.insert(
-          {custom_call,
-           NormMetadata({x_transpose.value_or(nullptr),
-                         y_transpose.value_or(nullptr), norm_dims_adjusted,
-                         non_norm_dims_adjusted})});
-
-      VLOG(1) << "Layer norm rewritten into Custom Call.";
-
-      // The layer norm training graph separately contains the norm factor
-      // divided by the sum of variance and epsilon.
-      for (HloInstruction* user : norm_factor->users()) {
-        if (user->opcode() == HloOpcode::kDivide &&
-            user->operand_index(norm_factor) == 0) {
-          TF_ASSIGN_OR_RETURN(bool changed,
-                              MatchNormFactor(user, custom_call, variance,
-                                              expectation, epsilon));
-          if (changed) {
-            break;
-          }
-        }
-      }
-    }
-
-    return absl::OkStatus();
-  }
-
-  // The layer norm training graph separately contains the expectation as well
-  // as the norm factor and its cube, (variance + epsilon)^-1/2 and (variance +
-  // epsilon)^-3/2. When identified in the graph, these quantities are fused
-  // into the layer norm Custom Call.
-  absl::StatusOr<bool> MatchNormFactor(HloInstruction* instr,
-                                       HloInstruction* custom_call,
-                                       UniqueHloInstruction& variance,
-                                       UniqueHloInstruction& expectation,
-                                       UniqueHloInstruction& epsilon) {
-    HloInstruction* gte = custom_call->users()[0];
-    if (Match(instr,
-              m::Divide(
-                  m::Op(),
-                  AddAnyOrder(
-                      m::Op().WithPredicate(variance.GetCaptureOrVerifyFn()),
-                      m::Broadcast(m::ConstantScalar().WithPredicate(
-                          epsilon.GetCaptureOrVerifyFn())))))) {
-      // Verify the uniqueness of the operands.
-      if (!variance.Instr() || !epsilon.Instr()) {
-        VLOG(1) << "Layer norm operands not unique.";
-        return false;
-      }
-
-      // Verify the element types.
-      if (!CompatibleElementType(instr) ||
-          !CompatibleElementType(expectation.Instr())) {
-        VLOG(1) << "Layer norm input types not compatible.";
-        return false;
-      }
-
-      // Retrieve metadata of the forward layer norm.
-      auto norm_metadata = norm_metadata_.extract(custom_call);
-      if (!norm_metadata) {
-        VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call.";
-        return false;
-      }
-
-      // The shape of the expectation and norm factor return values of the
-      // Custom Call is [nelems, 1, 1, 1], where nelems is the
-      // number of elements in the expectation and norm factor shapes.
-      auto make_compatible_shape = [](Shape shape) -> Shape {
-        return ShapeUtil::MakeShape(shape.element_type(),
-                                    {ShapeUtil::ElementsIn(shape), 1, 1, 1});
-      };
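For instance, with shapes assumed purely for illustration, an expectation of shape f32[2,4,6] has 48 elements and is therefore returned by the Custom Call as f32[48,1,1,1]. A sketch of the same dimension arithmetic in plain C++ (the helper name is hypothetical):

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    std::vector<int64_t> MakeCompatibleDims(const std::vector<int64_t>& dims) {
      const int64_t nelems = std::accumulate(
          dims.begin(), dims.end(), int64_t{1}, std::multiplies<int64_t>());
      return {nelems, 1, 1, 1};  // {48, 1, 1, 1} for dims == {2, 4, 6}.
    }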
-
-      Shape expectation_shape =
-          make_compatible_shape(expectation.Instr()->shape());
-      Shape norm_factor_shape = make_compatible_shape(instr->shape());
-
-      // The augmented Custom Call additionally returns the expectation and the
-      // norm factor.
-      std::vector<Shape> tuple_shapes = custom_call->shape().tuple_shapes();
-      tuple_shapes.insert(tuple_shapes.begin() + 1,
-                          {expectation_shape, norm_factor_shape});
-
-      Shape custom_call_shape = ShapeUtil::MakeTupleShape(tuple_shapes);
-
-      HloInstruction* new_custom_call = instr->AddInstruction(
-          custom_call->CloneWithNewShape(custom_call_shape));
-
-      TF_ASSIGN_OR_RETURN(
-          GpuBackendConfig gpu_backend_config,
-          custom_call->backend_config<xla::gpu::GpuBackendConfig>());
-      CudnnNormBackendConfig& backend_config =
-          *gpu_backend_config.mutable_cudnn_norm_backend_config();
-      backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_TRAIN);
-
-      // Update the workspace size.
-      TF_ASSIGN_OR_RETURN(const int64_t c_constant,
-                          CConstant(cuda_compute_capability_));
-      const int64_t workspace_size = (2 * c_constant * (4 + 256)) + 32;
-      backend_config.mutable_algorithm()->mutable_workspace_size()->set_value(
-          workspace_size);
-      TF_RETURN_IF_ERROR(
-          new_custom_call->set_backend_config(gpu_backend_config));
-
-      auto replace_with_new_cc = [new_custom_call, this](
-                                     HloInstruction* old_instr,
-                                     int tuple_index) -> absl::Status {
-        TF_ASSIGN_OR_RETURN(
-            HloInstruction * new_gte,
-            MakeGetTupleElementHlo(new_custom_call, tuple_index));
-        HloInstruction* new_instr = new_gte;
-        if (!ShapeUtil::Equal(new_gte->shape(), old_instr->shape())) {
-          TF_ASSIGN_OR_RETURN(new_instr,
-                              MakeReshapeHlo(old_instr->shape(), new_gte));
-        }
-        if (old_instr->opcode() != HloOpcode::kDivide) {
-          // Replace the result of the layer norm or the expectation.
-          TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr));
-        } else {
-          // Replace the norm factor, (variance + epsilon)^-1/2.
-          TF_RETURN_IF_ERROR(
-              ReplaceInstruction(old_instr->mutable_operand(0), new_instr));
-          // Also replace the norm factor to the power of 3, (variance +
-          // epsilon)^-1/2 / (variance + epsilon) = ((variance +
-          // epsilon)^-1/2)^3.
-          TF_ASSIGN_OR_RETURN(
-              HloInstruction * new_multiply0,
-              MakeBinaryHlo(HloOpcode::kMultiply, new_instr, new_instr));
-          TF_ASSIGN_OR_RETURN(
-              HloInstruction * new_multiply1,
-              MakeBinaryHlo(HloOpcode::kMultiply, new_multiply0, new_instr));
-          TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_multiply1));
-        }
-        return absl::OkStatus();
-      };
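Spelled out, the identity used in the else-branch above is

    \frac{(\sigma^2 + \epsilon)^{-1/2}}{\sigma^2 + \epsilon}
      = (\sigma^2 + \epsilon)^{-3/2}
      = \left((\sigma^2 + \epsilon)^{-1/2}\right)^{3},

so dividing the fused norm factor by (variance + epsilon) is equivalent to cubing it, which the two kMultiply instructions above construct.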
-
-      // Replace the result of the original Custom Call as well as the
-      // expectation and the norm factor with the augmented Custom Call.
-      TF_RETURN_IF_ERROR(replace_with_new_cc(gte, 0));
-      TF_RETURN_IF_ERROR(replace_with_new_cc(expectation.Instr(), 1));
-      TF_RETURN_IF_ERROR(replace_with_new_cc(instr, 2));
-
-      // Update the Custom Call associated with the metadata of the forward
-      // norm.
-      norm_metadata.key() = new_custom_call;
-      norm_metadata_.insert(std::move(norm_metadata));
-
-      VLOG(1)
-          << "Expectation and norm factor fused into layer norm Custom Call.";
-    }
-
-    return true;
-  }
-
-  // Matches and rewrites the backward graph of layer norm patterns into Custom
-  // Calls to cuDNN when the associated forward graph has been rewritten into a
-  // cuDNN Custom Call. The gradients are
-  //   DX = F1 + F2 - AddReduce(F1 + F2) / nelems,
-  //   Dscale = AddReduce(DY * XCenter * NormFactor),
-  //   Dbias = AddReduce(DY),
-  // with
-  //   F0 = XCenter * scale * DY,
-  //   F1 = XCenter * F0 * DNormFactor * 2 / nelems,
-  //   F2 = NormFactor * scale * DY,
-  //   XCenter = X - expectation(X),
-  //   NormFactor = (variance(X) + epsilon)^-1/2 and
-  //   DNormFactor = -1/2 * NormFactor^3.
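For context, a brief derivation of DX under the assumption of the biased variance over the N = nelems normalized elements (the form matched in the forward graph): differentiating Y = XCenter * NormFactor * scale + bias with respect to X gives

    \mathrm{DX}_j
      = \underbrace{\mathrm{NormFactor} \cdot \mathrm{scale}_j \cdot \mathrm{DY}_j}_{F2_j}
      + \underbrace{\mathrm{XCenter}_j \cdot \frac{2}{N} \cdot \mathrm{DNormFactor}
          \sum_i \mathrm{XCenter}_i \cdot \mathrm{scale}_i \cdot \mathrm{DY}_i}_{F1_j}
      - \frac{1}{N} \sum_i \mathrm{NormFactor} \cdot \mathrm{scale}_i \cdot \mathrm{DY}_i,

and since the XCenter values sum to zero, AddReduce(F1) vanishes, so the trailing term equals AddReduce(F1 + F2) / nelems. This recovers DX = F1 + F2 - AddReduce(F1 + F2) / nelems as stated above.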
-  absl::Status MatchLayerNormGradient(HloInstruction* instr) {
-    UniqueHloInstruction fwd_custom_call, x, x_center, scale, dy,
-        fused_expectation, fused_norm_factor;
-    HloInstruction *broadcast, *scalar, *dscale, *dbias, *reduce0, *reduce1,
-        *reduce2, *reduce3;
-    if (Match(instr,
-              AddAddAnyOrder(
-                  m::Broadcast(
-                      &broadcast,
-                      MultiplyAddAnyOrder(
-                          m::Broadcast(m::ConstantScalar(&scalar)),
-                          NegateAddReduce(&reduce0,
-                                          F1(&x, &x_center, &fused_expectation,
-                                             &fwd_custom_call, &scale, &dy,
-                                             &reduce2, norm_metadata_)),
-                          NegateAddReduce(
-                              &reduce1, F2(&fused_norm_factor, &scale, &dy,
-                                           &fwd_custom_call, norm_metadata_)))),
-                  F2(&fused_norm_factor, &scale, &dy, &fwd_custom_call,
-                     norm_metadata_),
-                  F1(&x, &x_center, &fused_expectation, &fwd_custom_call,
-                     &scale, &dy, &reduce3, norm_metadata_)))) {
-      // Skip initial convert, if present.
-      if (instr->user_count() == 1 &&
-          instr->users()[0]->opcode() == HloOpcode::kConvert &&
-          CompatibleElementType(instr->users()[0])) {
-        instr = instr->users()[0];
-      }
-
-      // Verify the uniqueness of the captured Custom Call and inputs.
-      if (!fwd_custom_call.Instr() || !x.Instr() || !dy.Instr() ||
-          !x_center.Instr() || !scale.Instr() || !fused_expectation.Instr() ||
-          !fused_norm_factor.Instr()) {
-        VLOG(1) << "Layer norm gradient inputs not unique.";
-        return absl::OkStatus();
-      }
-
-      // Retrieve metadata of the forward layer norm.
-      auto norm_metadata = norm_metadata_.find(fwd_custom_call.Instr());
-      if (norm_metadata == norm_metadata_.end()) {
-        VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call.";
-        return absl::OkStatus();
-      }
-
-      // Verify the dimensions of reductions in the backward graph.
-      if (AdjustedDimensions(reduce0) !=
-              norm_metadata->second.norm_dims_adjusted ||
-          AdjustedDimensions(reduce1) !=
-              norm_metadata->second.norm_dims_adjusted ||
-          AdjustedDimensions(reduce2) !=
-              norm_metadata->second.norm_dims_adjusted ||
-          AdjustedDimensions(reduce3) !=
-              norm_metadata->second.norm_dims_adjusted) {
-        VLOG(1) << "Unexpected reductions dimensions in layer norm gradient.";
-        return absl::OkStatus();
-      }
-
-      // The captured scalar must be one over the number of elements in the
-      // broadcasted dimensions.
-      float actual_r_nelems = scalar->literal().GetAsDouble({}).value();
-      int64_t nelems = 1;
-      for (int i = 0; i < broadcast->shape().dimensions_size(); ++i) {
-        if (!absl::c_linear_search(broadcast->dimensions(), i)) {
-          nelems *= broadcast->shape().dimensions()[i];
-        }
-      }
-      // The absolute value of the difference between the actual scaling factor
-      // and the reference value must not exceed a prescribed threshold.
-      float r_nelems = 1. / static_cast<float>(nelems);
-      float numerical_epsilon = std::numeric_limits<bfloat16>::epsilon();
-      if (!(abs(actual_r_nelems - r_nelems) <
-            ((actual_r_nelems + r_nelems) * numerical_epsilon))) {
-        VLOG(1)
-            << "Layer norm backward broadcast operand outside expected range.";
-        return absl::OkStatus();
-      }
-
-      // Identify Dscale = AddReduce(DY * XCenter * norm factor) with factor0
-      // and factor1 intended to be XCenter and DY or DY and XCenter.
-      auto find_dscale =
-          [&fused_norm_factor, &norm_metadata](
-              const UniqueHloInstruction& factor0,
-              const UniqueHloInstruction& factor1) -> HloInstruction* {
-        for (HloInstruction* factor0_user : factor0.Instr()->users()) {
-          std::vector<HloInstruction*> users;
-          SkipUnaryOpsTopDownRecursive(factor0_user, users);
-          // One of the users of factor0 must be a chained multiplication by the
-          // fused norm factor and factor1.
-          for (HloInstruction* user : users) {
-            if (Match(user,
-                      MultiplyAnyOrder(
-                          m::Op(), MultiplyAnyOrder(
-                                       m::Broadcast(BitcastOrReshape(m::Op().Is(
-                                           fused_norm_factor.Instr()))),
-                                       m::Op().Is(factor1.Instr()))))) {
-              // Dscale is an addition-reduction of the product.
-              for (HloInstruction* multiply_user : user->users()) {
-                if (AppliesAddReduce(
-                        multiply_user,
-                        norm_metadata->second.non_norm_dims_adjusted)) {
-                  return multiply_user;
-                }
-              }
-            }
-          }
-        }
-        return nullptr;
-      };
-      if (!(dscale = find_dscale(x_center, dy)) &&
-          !(dscale = find_dscale(dy, x_center))) {
-        VLOG(1) << "Unable to identify Dscale in graph.";
-        return absl::OkStatus();
-      }
-
-      // Find Dbias, i.e. an addition-reduction of DY, starting from DY.
-      // Rewriting proceeds without fusing Dbias if unsuccessful.
-      dbias = FindAddReduce(dy.Instr(),
-                            norm_metadata->second.non_norm_dims_adjusted);
-
-      // Verify the input and output layouts.
-      // TODO(philipphack): Consider supporting more general cases.
-      if (!LayoutUtil::IsMonotonicWithDim0Major(dy.Instr()->shape().layout()) ||
-          !LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout()) ||
-          !LayoutUtil::IsMonotonicWithDim0Major(dscale->shape().layout()) ||
-          (dbias &&
-           !LayoutUtil::IsMonotonicWithDim0Major(dbias->shape().layout()))) {
-        VLOG(1) << "Layer norm input and/or output layouts nor supported.";
-        return absl::OkStatus();
-      }
-
-      // The types of X and DX must match.
-      if (x.Instr()->shape().element_type() != instr->shape().element_type()) {
-        VLOG(1) << "The types of X and DX must match.";
-        return absl::OkStatus();
-      }
-
-      // The types and shapes of scale, Dscale and Dbias (if present) must
-      // match.
-      if (!ShapeUtil::Equal(
-              ShapeUtil::DropDegenerateDimensions(scale.Instr()->shape()),
-              ShapeUtil::DropDegenerateDimensions(dscale->shape())) ||
-          (dbias &&
-           !ShapeUtil::Equal(
-               ShapeUtil::DropDegenerateDimensions(scale.Instr()->shape()),
-               ShapeUtil::DropDegenerateDimensions(dbias->shape())))) {
-        VLOG(1) << "Backward layer norm types not supported.";
-        return absl::OkStatus();
-      }
-
-      // Verify the element types.
-      if (!CompatibleElementType(dy.Instr())) {
-        VLOG(1) << "Backward layer norm types not supported.";
-        return absl::OkStatus();
-      }
-
-      // cuDNN requires the byte size of the element type of X to be at least
-      // that of DY and scale.
-      if (ShapeUtil::ByteSizeOfPrimitiveType(
-              x.Instr()->shape().element_type()) <
-              ShapeUtil::ByteSizeOfPrimitiveType(
-                  dy.Instr()->shape().element_type()) ||
-          ShapeUtil::ByteSizeOfPrimitiveType(
-              x.Instr()->shape().element_type()) <
-              ShapeUtil::ByteSizeOfPrimitiveType(
-                  scale.Instr()->shape().element_type())) {
-        VLOG(1) << "Backward layer norm types not supported.";
-        return absl::OkStatus();
-      }
-
-      // Transpose DY applying the stored transpose order of X from the forward
-      // graph.
-      HloInstruction* transposed_dy = dy.Instr();
-      if (norm_metadata->second.x_transpose) {
-        TF_ASSIGN_OR_RETURN(
-            transposed_dy,
-            MakeTransposeHlo(dy.Instr(),
-                             norm_metadata->second.x_transpose->dimensions()));
-      }
-      TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_dy,
-                          MakeReshapeHlo(x.Instr()->shape(), transposed_dy));
-
-      Shape dx_shape = ShapeUtil::MakeShape(instr->shape().element_type(),
-                                            x.Instr()->shape().dimensions());
-
-      Shape dscale_dbias_shape = ShapeUtil::MakeShape(
-          dscale->shape().element_type(), scale.Instr()->shape().dimensions());
-
-      GpuBackendConfig gpu_backend_config;
-      CudnnNormBackendConfig& backend_config =
-          *gpu_backend_config.mutable_cudnn_norm_backend_config();
-      backend_config.set_kind(CudnnNormBackendConfig::LAYER_BWD);
-      auto* algorithm = backend_config.mutable_algorithm();
-      algorithm->set_algo_id(0);
-      algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
-      algorithm->set_is_cudnn_frontend(true);
-
-      // Set the workspace size to its upper bound.
-      // TODO(philipphack): Consider autotuning the norm kernels.
-      TF_ASSIGN_OR_RETURN(const int64_t c_constant,
-                          CConstant(cuda_compute_capability_));
-      const int64_t workspace_size =
-          (2 * c_constant * (4 + 256)) +
-          (2 * x.Instr()->shape().dimensions(0) * 4) + 64;
-      algorithm->mutable_workspace_size()->set_value(workspace_size);
-
-      // The output of the Custom Call is a tuple. The output shape of Dscale
-      // and Dbias is that of scale.
-      Shape custom_call_shape = ShapeUtil::MakeTupleShape(
-          {dx_shape, dscale_dbias_shape, dscale_dbias_shape,
-           ShapeUtil::MakeShape(U8, {workspace_size})});
-
-      HloInstruction* custom_call =
-          instr->AddInstruction(HloInstruction::CreateCustomCall(
-              custom_call_shape,
-              {x.Instr(), scale.Instr(), reshaped_dy, fused_expectation.Instr(),
-               fused_norm_factor.Instr()},
-              kCudnnNormCallTarget));
-      TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
-
-      auto replace_with_cc = [custom_call, norm_metadata, transposed_dy, this](
-                                 HloInstruction* old_instr,
-                                 int tuple_index) -> absl::Status {
-        TF_ASSIGN_OR_RETURN(HloInstruction * gte,
-                            MakeGetTupleElementHlo(custom_call, tuple_index));
-        HloInstruction* new_instr;
-        // Transpose DX applying the stored transpose order of Y from the
-        // forward graph.
-        if (tuple_index == 0 && norm_metadata->second.y_transpose) {
-          TF_ASSIGN_OR_RETURN(new_instr,
-                              MakeReshapeHlo(transposed_dy->shape(), gte));
-          TF_ASSIGN_OR_RETURN(
-              new_instr,
-              MakeTransposeHlo(
-                  new_instr, norm_metadata->second.y_transpose->dimensions()));
-        } else {
-          TF_ASSIGN_OR_RETURN(new_instr,
-                              MakeReshapeHlo(old_instr->shape(), gte));
-        }
-        TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr));
-        return absl::OkStatus();
-      };
-
-      TF_RETURN_IF_ERROR(replace_with_cc(instr, 0));
-      TF_RETURN_IF_ERROR(replace_with_cc(dscale, 1));
-      if (dbias) {
-        TF_RETURN_IF_ERROR(replace_with_cc(dbias, 2));
-      }
-      VLOG(1) << "Gradients w.r.t. x"
-              << (dbias ? ", scale and bias" : " and scale")
-              << " rewritten into layer norm backward Custom Call.";
-    }
-
-    return absl::OkStatus();
-  }
-
- private:
-  se::CudaComputeCapability cuda_compute_capability_;
-  NormMetadataMap norm_metadata_;
-};
-
-absl::StatusOr<bool> RunOnComputation(
-    HloComputation* computation,
-    se::CudaComputeCapability cuda_compute_capability) {
-  CudnnNormRewriterVisitor visitor(cuda_compute_capability);
-  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
-  return visitor.changed();
-}
-
-}  // anonymous namespace
-
-CudnnNormRewriter::CudnnNormRewriter(
-    se::CudaComputeCapability cuda_compute_capability)
-    : cuda_compute_capability_(cuda_compute_capability) {}
-
-absl::StatusOr<bool> CudnnNormRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    TF_ASSIGN_OR_RETURN(
-        bool result, RunOnComputation(computation, cuda_compute_capability_));
-    changed |= result;
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h b/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h
deleted file mode 100644
index 7b3ef8d..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_NORM_REWRITER_H_
-#define XLA_SERVICE_GPU_CUDNN_NORM_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites norm patterns into Custom Calls to the cuDNN library. Currently, the
-// forward and backward passes of layer norm patterns are implemented.
-class CudnnNormRewriter : public HloModulePass {
- public:
-  explicit CudnnNormRewriter(se::CudaComputeCapability cuda_compute_capability);
-  absl::string_view name() const override { return "norm-rewriter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  se::CudaComputeCapability cuda_compute_capability_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_CUDNN_NORM_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc b/third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc
deleted file mode 100644
index 754563a..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc
+++ /dev/null
@@ -1,1798 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <string>
-
-#include <gtest/gtest.h>
-#include "xla/error_spec.h"
-#include "xla/stream_executor/device_description.h"
-
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cudnn/cudnn.h"  // IWYU pragma: keep
-#include "third_party/gpus/cudnn/cudnn_version.h"
-#endif
-
-#include "xla/service/gpu/tests/gpu_codegen_test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class CudnnNormRewriterTest : public GpuCodegenTest {
- public:
-  se::CudaComputeCapability GetCudaComputeCapability() {
-    return backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .cuda_compute_capability();
-  }
-
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_enable_cudnn_layer_norm(true);
-    return debug_options;
-  }
-
- protected:
-  void TestNorm(std::string hlo_text, std::string optimized_hlo) {
-    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-    MatchOptimizedHlo(hlo_text, optimized_hlo);
-  }
-};
-
-// The following tests evaluate LayerNormXDY configurations, with X the rank of
-// the input and Y the dimensions that are normalized.
-TEST_F(CudnnNormRewriterTest, LayerNorm2D1) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4] parameter(0)
-        input_square = f32[2,4] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply
-        r_nelems = f32[] constant(0.25)
-        r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2] multiply(input_square_sum, r_nelems_bcast)
-        input_sum = f32[2] reduce(input, c0),dimensions={1}, to_apply=apply
-        input_mean = f32[2] multiply(input_sum, r_nelems_bcast)
-        input_mean_square = f32[2] multiply(input_mean, input_mean)
-        variance = f32[2] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2] add(variance, epsilon_bcast)
-        norm_factor = f32[2] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0}
-        input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0}
-        input_center = f32[2,4] subtract(input, input_mean_bcast)
-        norm = f32[2,4] multiply(norm_factor_bcast, input_center)
-        scale = f32[4] parameter(1)
-        scale_bcast = f32[2,4] broadcast(scale), dimensions={1}
-        norm_scale = f32[2,4] multiply(norm, scale_bcast)
-        bias = f32[4] parameter(2)
-        bias_broadcast = f32[2,4] broadcast(bias), dimensions={1}
-        ROOT out = f32[2,4] add(norm_scale, bias_broadcast)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4]) -> f32[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0)
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK:           }
-; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT:  ROOT [[GTE_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D3) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,6,8] parameter(0)
-        input_square = f32[2,4,6,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
-        r_nelems = f32[] constant(0.125)
-        r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
-        input_sum = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
-        input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast)
-        input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
-        variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
-        norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
-        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
-        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
-        scale = f32[8] parameter(1)
-        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
-        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
-        bias = f32[8] parameter(2)
-        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
-        ROOT out = f32[2,4,6,8] add(norm_scale, bias_bcast)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[2,4,6,8] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK:           }
-; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT:  ROOT [[GTE_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D3Degenerate0) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[1,4,6,8] parameter(0)
-        input_square = f32[1,4,6,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[1,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
-        r_nelems = f32[] constant(0.125)
-        r_nelems_bcast = f32[1,4,6] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[1,4,6] multiply(input_square_sum, r_nelems_bcast)
-        input_sum = f32[1,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
-        input_mean = f32[1,4,6] multiply(input_sum, r_nelems_bcast)
-        input_mean_square = f32[1,4,6] multiply(input_mean, input_mean)
-        variance = f32[1,4,6] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[1,4,6] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[1,4,6] add(variance, epsilon_bcast)
-        norm_factor = f32[1,4,6] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[1,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
-        input_mean_bcast = f32[1,4,6,8] broadcast(input_mean), dimensions={0,1,2}
-        input_center = f32[1,4,6,8] subtract(input, input_mean_bcast)
-        norm = f32[1,4,6,8] multiply(norm_factor_bcast, input_center)
-        scale = f32[8] parameter(1)
-        scale_bcast = f32[1,4,6,8] broadcast(scale), dimensions={3}
-        norm_scale = f32[1,4,6,8] multiply(norm, scale_bcast)
-        bias = f32[8] parameter(2)
-        bias_bcast = f32[1,4,6,8] broadcast(bias), dimensions={3}
-        ROOT out = f32[1,4,6,8] add(norm_scale, bias_bcast)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[1,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[1,4,6,8] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[1,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[24,8,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[24,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK:           }
-; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[24,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT:  ROOT [[GTE_BITCAST:%[^ ]+]] = f32[1,4,6,8]{3,2,1,0} bitcast([[GTE]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,6,8] parameter(0)
-        input_square = f32[2,4,6,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,4,8] reduce(input_square, c0), dimensions={2}, to_apply=apply
-        r_nelems = f32[] constant(0.166667)
-        r_nelems_bcast = f32[2,4,8] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,4,8] multiply(input_square_sum, r_nelems_bcast)
-        reduce = f32[2,4,8] reduce(input, c0), dimensions={2}, to_apply=apply
-        input_mean = f32[2,4,8] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,4,8] multiply(input_mean, input_mean)
-        variance = f32[2,4,8] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,4,8] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,4,8] add(variance, epsilon_bcast)
-        norm_factor = f32[2,4,8] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,3}
-        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,3}
-        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
-        scale = f32[6] parameter(1)
-        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={2}
-        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
-        bias = f32[6] parameter(2)
-        bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={2}
-        ROOT out = f32[2,4,6,8] add(norm_scale, bias_broadcast)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,4,6,8] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[TRANSPOSE:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2}
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[6]{0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[6]{0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK:           }
-; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT:  ROOT [[FUSION:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D2Degenerate1) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,1,6,8] parameter(0)
-        input_square = f32[2,1,6,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,1,8] reduce(input_square, c0), dimensions={2}, to_apply=apply
-        r_nelems = f32[] constant(0.166667)
-        r_nelems_bcast = f32[2,1,8] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,1,8] multiply(input_square_sum, r_nelems_bcast)
-        reduce = f32[2,1,8] reduce(input, c0), dimensions={2}, to_apply=apply
-        input_mean = f32[2,1,8] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,1,8] multiply(input_mean, input_mean)
-        variance = f32[2,1,8] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,1,8] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,1,8] add(variance, epsilon_bcast)
-        norm_factor = f32[2,1,8] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,1,6,8] broadcast(norm_factor), dimensions={0,1,3}
-        input_mean_bcast = f32[2,1,6,8] broadcast(input_mean), dimensions={0,1,3}
-        input_center = f32[2,1,6,8] subtract(input, input_mean_bcast)
-        norm = f32[2,1,6,8] multiply(norm_factor_bcast, input_center)
-        scale = f32[6] parameter(1)
-        scale_bcast = f32[2,1,6,8] broadcast(scale), dimensions={2}
-        norm_scale = f32[2,1,6,8] multiply(norm, scale_bcast)
-        bias = f32[6] parameter(2)
-        bias_broadcast = f32[2,1,6,8] broadcast(bias), dimensions={2}
-        ROOT out = f32[2,1,6,8] add(norm_scale, bias_broadcast)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,1,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,1,6,8] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,6]{3,2,1,0} transpose([[P0]]), dimensions={1,0,3,2}
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[6]{0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[6]{0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[16,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK:           }
-; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT:  ROOT [[FUSION:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D12) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,6,8] parameter(0)
-        input_square = f32[2,4,6,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
-        r_nelems = f32[] constant(0.041667)
-        r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
-        reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
-        input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,8] multiply(input_mean, input_mean)
-        variance = f32[2,8] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
-        norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3}
-        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3}
-        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
-        scale = f32[4,6] parameter(1)
-        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2}
-        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
-        bias = f32[4,6] parameter(2)
-        bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2}
-        ROOT out = f32[2,4,6,8] add(norm_scale, bias_broadcast)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> f32[2,4,6,8] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2}
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK:           }
-; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT:  ROOT  [[FUSION:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D12Degenerate2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,1,8] parameter(0)
-        input_square = f32[2,4,1,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
-        r_nelems = f32[] constant(0.25)
-        r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
-        reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
-        input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,8] multiply(input_mean, input_mean)
-        variance = f32[2,8] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
-        norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3}
-        input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3}
-        input_center = f32[2,4,1,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,1,8] multiply(norm_factor_bcast, input_center)
-        scale = f32[4,1] parameter(1)
-        scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2}
-        norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast)
-        bias = f32[4,1] parameter(2)
-        bias_broadcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2}
-        ROOT out = f32[2,4,1,8] add(norm_scale, bias_broadcast)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> f32[2,4,1,8] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1}
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK:           }
-; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT:  ROOT [[FUSION:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,2,2,2] parameter(0)
-        input_square = f32[2,2,2,2] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,2,2] reduce(input_square, c0), dimensions={3}, to_apply=apply
-        r_nelems = f32[] constant(0.5)
-        r_nelems_bcast = f32[2,2,2] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,2,2] multiply(input_square_sum, r_nelems_bcast)
-        input_sum = f32[2,2,2] reduce(input, c0), dimensions={3}, to_apply=apply
-        input_mean = f32[2,2,2] multiply(input_sum, r_nelems_bcast)
-        input_mean_square = f32[2,2,2] multiply(input_mean, input_mean)
-        variance = f32[2,2,2] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,2,2] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,2,2] add(variance, epsilon_bcast)
-        norm_factor = f32[2,2,2] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,2,2,2] broadcast(norm_factor), dimensions={0,1,2}
-        input_mean_bcast = f32[2,2,2,2] broadcast(input_mean), dimensions={0,1,2}
-        input_center = f32[2,2,2,2] subtract(input, input_mean_bcast)
-        norm = f32[2,2,2,2] multiply(norm_factor_bcast, input_center)
-        scale = f32[2] parameter(1)
-        scale_bcast = f32[2,2,2,2] broadcast(scale), dimensions={2}
-        norm_scale = f32[2,2,2,2] multiply(norm, scale_bcast)
-        bias = f32[2] parameter(2)
-        bias_bcast = f32[2,2,2,2] broadcast(bias), dimensions={3}
-        ROOT out = f32[2,2,2,2] add(norm_scale, bias_bcast)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,2,2,2], {{.*}}: f32[2], {{.*}}: f32[2]) -> f32[2,2,2,2] {
-; CHECK-NOT:           custom_call_target="__cudnn$norm"
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNorm4D3InputOutputTypeMismatch) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f16[2,4,6,8] parameter(0)
-        input_f32 = f32[2,4,6,8] convert(input)
-        input_square = f32[2,4,6,8] multiply(input_f32, input_f32)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
-        r_nelems = f32[] constant(0.125)
-        r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
-        input_sum = f32[2,4,6] reduce(input_f32, c0), dimensions={3}, to_apply=apply
-        input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast)
-        input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
-        variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
-        norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
-        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
-        input_center = f32[2,4,6,8] subtract(input_f32, input_mean_bcast)
-        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
-        scale = f32[8] parameter(1)
-        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
-        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
-        bias = f32[8] parameter(2)
-        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
-        ROOT out = f32[2,4,6,8] add(norm_scale, bias_bcast)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f16[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[2,4,6,8] {
-; CHECK-NOT:           custom_call_target="__cudnn$norm"
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrain2D1) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4] parameter(0)
-        input_square = f32[2,4] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply
-        r_nelems = f32[] constant(0.25)
-        r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2] multiply(input_square_sum, r_nelems_bcast)
-        reduce = f32[2] reduce(input, c0), dimensions={1}, to_apply=apply
-        input_mean = f32[2] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2] multiply(input_mean, input_mean)
-        variance = f32[2] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2] add(variance, epsilon_bcast)
-        norm_factor = f32[2] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0}
-        input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0}
-        input_center = f32[2,4] subtract(input, input_mean_bcast)
-        norm = f32[2,4] multiply(norm_factor_bcast, input_center)
-        scale = f32[4] parameter(1)
-        scale_bcast = f32[2,4] broadcast(scale), dimensions={1}
-        norm_scale = f32[2,4] multiply(norm, scale_bcast)
-        bias = f32[4] parameter(2)
-        bias_broadcast = f32[2,4] broadcast(bias), dimensions={1}
-        norm_scale_bias = f32[2,4] add(norm_scale, bias_broadcast)
-        norm_factor_cube = f32[2] divide(norm_factor, variance_plus_epsilon)
-        ROOT out = (f32[2,4], f32[2], f32[2], f32[2]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4]) -> (f32[2,4], f32[2], f32[2], f32[2]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0)
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK:           }
-; CHECK-NEXT:    [[GTE0:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT:    [[GTE0_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE0]])
-; CHECK-NEXT:    [[GTE1:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
-; CHECK-NEXT:    [[GTE1_BITCAST:%[^ ]+]] = f32[2]{0} bitcast([[GTE1]])
-; CHECK-NEXT:    [[GTE2:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
-; CHECK-NEXT:    [[GTE2_BITCAST:%[^ ]+]] = f32[2]{0} bitcast([[GTE2]])
-; CHECK-NEXT:    [[FUSION:%[^ ]+]] = f32[2]{0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
-; CHECK-NEXT:  ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple([[GTE0_BITCAST]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
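The training-mode tests additionally return norm_factor_cube, written in the HLO as divide(norm_factor, variance_plus_epsilon). With v = variance + epsilon and norm_factor = rsqrt(v), that quotient equals v^(-3/2), i.e. the cube of the norm factor, which is why the expected HLO can rebuild it with a small fusion over the custom call's norm-factor output. A tiny illustrative check follows; the value of v is arbitrary and not taken from any test.

#include <cassert>
#include <cmath>

int main() {
  const float v = 0.751f;                         // variance_plus_epsilon
  const float norm_factor = 1.0f / std::sqrt(v);  // rsqrt(v)
  const float via_divide = norm_factor / v;       // the form written in the HLO
  const float via_cube = norm_factor * norm_factor * norm_factor;  // v^(-3/2)
  assert(std::fabs(via_divide - via_cube) < 1e-5f);
  return 0;
}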
-TEST_F(CudnnNormRewriterTest, LayerNormTrain4D3) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,6,8] parameter(0)
-        input_square = f32[2,4,6,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
-        r_nelems = f32[] constant(0.125)
-        r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
-        reduce = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
-        input_mean = f32[2,4,6] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
-        variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
-        norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
-        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
-        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
-        scale = f32[8] parameter(1)
-        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
-        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
-        bias = f32[8] parameter(2)
-        bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
-        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_broadcast)
-        norm_factor_cube = f32[2,4,6] divide(norm_factor, variance_plus_epsilon)
-        ROOT out = (f32[2,4,6,8], f32[2,4,6], f32[2,4,6], f32[2,4,6]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> (f32[2,4,6,8], f32[2,4,6], f32[2,4,6], f32[2,4,6]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK:           }
-; CHECK-NEXT:    [[GTE0:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT:    [[GTE0_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE0]])
-; CHECK-NEXT:    [[GTE1:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
-; CHECK-NEXT:    [[GTE1_BITCAST:%[^ ]+]] = f32[2,4,6]{2,1,0} bitcast([[GTE1]])
-; CHECK-NEXT:    [[GTE2:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
-; CHECK-NEXT:    [[GTE2_BITCAST:%[^ ]+]] = f32[2,4,6]{2,1,0} bitcast([[GTE2]])
-; CHECK-NEXT:    [[FUSION:%[^ ]+]] = f32[2,4,6]{2,1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
-; CHECK-NEXT:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6]{2,1,0}, f32[2,4,6]{2,1,0}, f32[2,4,6]{2,1,0}) tuple([[GTE0_BITCAST]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,6,8] parameter(0)
-        input_square = f32[2,4,6,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
-        r_nelems = f32[] constant(0.041667)
-        r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
-        reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
-        input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,8] multiply(input_mean, input_mean)
-        variance = f32[2,8] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
-        norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3}
-        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3}
-        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
-        scale = f32[4,6] parameter(1)
-        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2}
-        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
-        bias = f32[4,6] parameter(2)
-        bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2}
-        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_broadcast)
-        norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
-        ROOT out = (f32[2,4,6,8], f32[2,8], f32[2,8], f32[2,8]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> (f32[2,4,6,8], f32[2,8], f32[2,8], f32[2,8]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2}
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK:           }
-; CHECK-NEXT:    [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT:    [[FUSION0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
-; CHECK-NEXT:    [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
-; CHECK-NEXT:    [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]])
-; CHECK-NEXT:    [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
-; CHECK-NEXT:    [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]])
-; CHECK-NEXT:    [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
-; CHECK-NEXT:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12Degenerate2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,1,8] parameter(0)
-        input_square = f32[2,4,1,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
-        r_nelems = f32[] constant(0.25)
-        r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
-        reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
-        input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,8] multiply(input_mean, input_mean)
-        variance = f32[2,8] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
-        norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3}
-        input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3}
-        input_center = f32[2,4,1,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,1,8] multiply(norm_factor_bcast, input_center)
-        scale = f32[4,1] parameter(1)
-        scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2}
-        norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast)
-        bias = f32[4,1] parameter(2)
-        bias_broadcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2}
-        norm_scale_bias = f32[2,4,1,8] add(norm_scale, bias_broadcast)
-        norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
-        ROOT out = (f32[2,4,1,8], f32[2,8], f32[2,8], f32[2,8]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> (f32[2,4,1,8], f32[2,8], f32[2,8], f32[2,8]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1}
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK:           }
-; CHECK-NEXT:    [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
-; CHECK-NEXT:    [[FUSION0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
-; CHECK-NEXT:    [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
-; CHECK-NEXT:    [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]])
-; CHECK-NEXT:    [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
-; CHECK-NEXT:    [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]])
-; CHECK-NEXT:    [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
-; CHECK-NEXT:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward2D1) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4] parameter(0)
-        input_square = f32[2,4] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply
-        reduce = f32[2] reduce(input, c0), dimensions={1}, to_apply=apply
-        r_nelems = f32[] constant(0.25)
-        r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2] multiply(input_square_sum, r_nelems_bcast)
-        input_mean = f32[2] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2] multiply(input_mean, input_mean)
-        variance = f32[2] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2] add(variance, epsilon_bcast)
-        norm_factor = f32[2] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0}
-        input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0}
-        input_center = f32[2,4] subtract(input, input_mean_bcast)
-        norm = f32[2,4] multiply(input_center, norm_factor_bcast)
-        scale = f32[4] parameter(1)
-        scale_bcast = f32[2,4] broadcast(scale), dimensions={1}
-        norm_scale = f32[2,4] multiply(norm, scale_bcast)
-        bias = f32[4] parameter(2)
-        bias_bcast = f32[2,4] broadcast(bias), dimensions={1}
-        norm_scale_bias = f32[2,4] add(norm_scale, bias_bcast)
-        doutput = f32[2,4] parameter(3)
-        dbias = f32[4] reduce(doutput, c0), dimensions={0}, to_apply=apply
-        norm_doutput = f32[2,4] multiply(norm, doutput)
-        dscale = f32[4] reduce(norm_doutput, c0), dimensions={0}, to_apply=apply
-        scale_doutput = f32[2,4] multiply(scale_bcast, doutput)
-        input_center_scale_doutput = f32[2,4] multiply(input_center, scale_doutput)
-        f0 = f32[2] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply
-        norm_factor_cube = f32[2] divide(norm_factor, variance_plus_epsilon)
-        c1 = f32[] constant(-0.5)
-        c1_bcast = f32[2] broadcast(c1), dimensions={}
-        dnorm_factor = f32[2] multiply(norm_factor_cube, c1_bcast)
-        f0_dnorm_factor = f32[2] multiply(f0, dnorm_factor)
-        c2 = f32[] constant(0.5)
-        c2_bcast = f32[2] broadcast(c2), dimensions={}
-        f0_dnorm_factor_scaled = f32[2] multiply(f0_dnorm_factor, c2_bcast)
-        f0_dnorm_factor_scaled_bcast = f32[2,4] broadcast(f0_dnorm_factor_scaled), dimensions={0}
-        f1 = f32[2,4] multiply(input_center, f0_dnorm_factor_scaled_bcast)
-        minus_f1 = f32[2,4] negate(f1)
-        minus_f1_sum = f32[2] reduce(minus_f1, c0), dimensions={1}, to_apply=apply
-        f2 = f32[2,4] multiply(norm_factor_bcast, scale_doutput)
-        minus_f2 = f32[2,4] negate(f2)
-        minus_f2_sum = f32[2] reduce(minus_f2, c0), dimensions={1}, to_apply=apply
-        minus_f1_f2_sum = f32[2] add(minus_f1_sum, minus_f2_sum)
-        minus_f1_f2_sum_scaled = f32[2] multiply(minus_f1_f2_sum, r_nelems_bcast)
-        minus_f1_f2_sum_scaled_bcast = f32[2,4] broadcast(minus_f1_f2_sum_scaled), dimensions={0}
-        f1_f2 = f32[2,4] add(f1, f2)
-        dinput = f32[2,4] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
-        ROOT out = (f32[2,4], f32[2,4], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> (f32[2,4], f32[2,4], f32[4], f32[4]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0)
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
-; CHECK:           }
-; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG:     [[GTE0_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE0]])
-; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
-; CHECK-DAG:     [[P3_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P3]])
-; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0
-; CHECK-DAG:         "kind":"LAYER_BWD"
-; CHECK:           }
-; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG:     [[GTE3_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE3]])
-; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]])
-; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]])
-; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTE0_BITCAST]], [[GTE3_BITCAST]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
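The backward tests obtain the input gradient through the f1/f2 decomposition spelled out in the HLO above; the rewriter is expected to replace that whole subgraph with a second __cudnn$norm call of kind LAYER_BWD that reuses the mean and norm factor produced by the forward call. As a reading aid, here is a compact C++ sketch of the same arithmetic for a single normalized row; the helper name and signature are assumptions for illustration, not code from the test file.

#include <cmath>
#include <cstddef>
#include <vector>

// Input gradient of layer norm for one row of length n, following the HLO:
//   f2 = norm_factor * scale * doutput
//   f1 = -(x - mean) * norm_factor^3 / n * sum_j((x_j - mean) * scale_j * doutput_j)
//   dinput = f1 + f2 - mean(f1 + f2)
std::vector<float> ReferenceLayerNormInputGrad(const std::vector<float>& x,
                                               const std::vector<float>& scale,
                                               const std::vector<float>& doutput,
                                               float epsilon = 0.001f) {
  const std::size_t n = x.size();
  float sum = 0.0f, sum_sq = 0.0f;
  for (const float v : x) {
    sum += v;
    sum_sq += v * v;
  }
  const float mean = sum / n;
  const float variance = sum_sq / n - mean * mean;
  const float norm_factor = 1.0f / std::sqrt(variance + epsilon);
  // f0 = sum_j (x_j - mean) * scale_j * doutput_j, the reduction that feeds
  // dnorm_factor in the HLO.
  float f0 = 0.0f;
  for (std::size_t j = 0; j < n; ++j) {
    f0 += (x[j] - mean) * scale[j] * doutput[j];
  }
  std::vector<float> f1_f2(n);
  float f1_f2_mean = 0.0f;
  for (std::size_t j = 0; j < n; ++j) {
    const float f1 =
        -(x[j] - mean) * norm_factor * norm_factor * norm_factor * f0 / n;
    const float f2 = norm_factor * scale[j] * doutput[j];
    f1_f2[j] = f1 + f2;
    f1_f2_mean += f1_f2[j] / n;
  }
  std::vector<float> dinput(n);
  for (std::size_t j = 0; j < n; ++j) {
    dinput[j] = f1_f2[j] - f1_f2_mean;
  }
  return dinput;
}

The dscale and dbias outputs are plain reductions of norm * doutput and doutput over the batch dimensions, which is why they appear only as the two extra bitcast outputs of the LAYER_BWD custom call in the CHECK patterns.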
-TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D3) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,6,8] parameter(0)
-        input_square = f32[2,4,6,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
-        reduce = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
-        r_nelems = f32[] constant(0.125)
-        r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
-        input_mean = f32[2,4,6] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
-        variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
-        norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
-        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
-        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
-        scale = f32[8] parameter(1)
-        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
-        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
-        bias = f32[8] parameter(2)
-        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
-        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
-        doutput = f32[2,4,6,8] parameter(3)
-        dbias = f32[8] reduce(doutput, c0), dimensions={0,1,2}, to_apply=apply
-        norm_doutput = f32[2,4,6,8] multiply(norm, doutput)
-        dscale = f32[8] reduce(norm_doutput, c0), dimensions={0,1,2}, to_apply=apply
-        scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput)
-        input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
-        f0 = f32[2,4,6] reduce(input_center_scale_doutput, c0), dimensions={3}, to_apply=apply
-        norm_factor_cube = f32[2,4,6] divide(norm_factor, variance_plus_epsilon)
-        c1 = f32[] constant(-0.5)
-        c1_bcast = f32[2,4,6] broadcast(c1), dimensions={}
-        dnorm_factor = f32[2,4,6] multiply(norm_factor_cube, c1_bcast)
-        f0_dnorm_factor = f32[2,4,6] multiply(f0, dnorm_factor)
-        c2 = f32[] constant(0.25)
-        c2_bcast = f32[2,4,6] broadcast(c2), dimensions={}
-        f0_dnorm_factor_scaled = f32[2,4,6] multiply(f0_dnorm_factor, c2_bcast)
-        f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,1,2}
-        f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
-        minus_f1 = f32[2,4,6,8] negate(f1)
-        minus_f1_sum = f32[2,4,6] reduce(minus_f1, c0), dimensions={3}, to_apply=apply
-        f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
-        minus_f2 = f32[2,4,6,8] negate(f2)
-        minus_f2_sum = f32[2,4,6] reduce(minus_f2, c0), dimensions={3}, to_apply=apply
-        minus_f1_f2_sum = f32[2,4,6] add(minus_f1_sum, minus_f2_sum)
-        minus_f1_f2_sum_scaled = f32[2,4,6] multiply(minus_f1_f2_sum, r_nelems_bcast)
-        minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,1,2}
-        f1_f2 = f32[2,4,6,8] add(f1, f2)
-        dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
-        ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[8], f32[8]) tuple(norm_scale_bias, dinput, dscale, dbias)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[8], f32[8]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
-; CHECK:           }
-; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG:     [[GTE0_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE0]])
-; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3)
-; CHECK-DAG:     [[P3_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P3]])
-; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[1,8,1,1]{3,2,1,0}, f32[1,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0
-; CHECK-DAG:         "kind":"LAYER_BWD"
-; CHECK:           }
-; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG:     [[GTE3_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE3]])
-; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[8]{0} bitcast([[GTE4]])
-; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[8]{0} bitcast([[GTE5]])
-; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[8]{0}, f32[8]{0}) tuple([[GTE0_BITCAST]], [[GTE3_BITCAST]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,6,8] parameter(0)
-        input_square = f32[2,4,6,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,4,8] reduce(input_square, c0), dimensions={2}, to_apply=apply
-        reduce = f32[2,4,8] reduce(input, c0), dimensions={2}, to_apply=apply
-        r_nelems = f32[] constant(0.166667)
-        r_nelems_bcast = f32[2,4,8] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,4,8] multiply(input_square_sum, r_nelems_bcast)
-        input_mean = f32[2,4,8] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,4,8] multiply(input_mean, input_mean)
-        variance = f32[2,4,8] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,4,8] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,4,8] add(variance, epsilon_bcast)
-        norm_factor = f32[2,4,8] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,3}
-        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,3}
-        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
-        scale = f32[6] parameter(1)
-        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={2}
-        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
-        bias = f32[6] parameter(2)
-        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={2}
-        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
-        doutput = f32[2,4,6,8] parameter(3)
-        dbias = f32[6] reduce(doutput, c0), dimensions={0,1,3}, to_apply=apply
-        norm_doutput = f32[2,4,6,8] multiply(norm, doutput)
-        dscale = f32[6] reduce(norm_doutput, c0), dimensions={0,1,3}, to_apply=apply
-        scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput)
-        input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
-        f0 = f32[2,4,8] reduce(input_center_scale_doutput, c0), dimensions={2}, to_apply=apply
-        norm_factor_cube = f32[2,4,8] divide(norm_factor, variance_plus_epsilon)
-        c1 = f32[] constant(-0.5)
-        c1_bcast = f32[2,4,8] broadcast(c1), dimensions={}
-        dnorm_factor = f32[2,4,8] multiply(norm_factor_cube, c1_bcast)
-        f0_dnorm_factor = f32[2,4,8] multiply(f0, dnorm_factor)
-        c2 = f32[] constant(0.333333)
-        c2_bcast = f32[2,4,8] broadcast(c2), dimensions={}
-        f0_dnorm_factor_scaled = f32[2,4,8] multiply(f0_dnorm_factor, c2_bcast)
-        f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,1,3}
-        f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
-        minus_f1 = f32[2,4,6,8] negate(f1)
-        minus_f1_sum = f32[2,4,8] reduce(minus_f1, c0), dimensions={2}, to_apply=apply
-        f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
-        minus_f2 = f32[2,4,6,8] negate(f2)
-        minus_f2_sum = f32[2,4,8] reduce(minus_f2, c0), dimensions={2}, to_apply=apply
-        minus_f1_f2_sum = f32[2,4,8] add(minus_f1_sum, minus_f2_sum)
-        minus_f1_f2_sum_scaled = f32[2,4,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
-        minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,1,3}
-        f1_f2 = f32[2,4,6,8] add(f1, f2)
-        dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
-        ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) tuple(norm_scale_bias, dinput, dscale, dbias)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[TRANSPOSE0:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2}
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[6]{0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[6]{0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[64,1,1,1]{3,2,1,0}, f32[64,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
-; CHECK:           }
-; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3)
-; CHECK-NEXT:    [[TRANSPOSE1:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P3]]), dimensions={0,1,3,2}
-; CHECK-DAG:     [[P3_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]])
-; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[1,6,1,1]{3,2,1,0}, f32[1,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0
-; CHECK-DAG:         "kind":"LAYER_BWD"
-; CHECK:           }
-; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG:     [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
-; CHECK-DAG:     [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0
-; CHECK-DAG:     [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1
-; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE4]])
-; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE5]])
-; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[6]{0}, f32[6]{0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,6,8] parameter(0)
-        input_square = f32[2,4,6,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
-        reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
-        r_nelems = f32[] constant(0.041667)
-        r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
-        input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,8] multiply(input_mean, input_mean)
-        variance = f32[2,8] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
-        norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3}
-        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3}
-        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
-        scale = f32[4,6] parameter(1)
-        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2}
-        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
-        bias = f32[4,6] parameter(2)
-        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2}
-        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
-        doutput = f32[2,4,6,8] parameter(3)
-        dbias = f32[4,6] reduce(doutput, c0), dimensions={0,3}, to_apply=apply
-        norm_doutput = f32[2,4,6,8] multiply(norm, doutput)
-        dscale = f32[4,6] reduce(norm_doutput, c0), dimensions={0,3}, to_apply=apply
-        scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput)
-        input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
-        f0 = f32[2,8] reduce(input_center_scale_doutput, c0), dimensions={1,2}, to_apply=apply
-        norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
-        c1 = f32[] constant(-0.5)
-        c1_bcast = f32[2,8] broadcast(c1), dimensions={}
-        dnorm_factor = f32[2,8] multiply(norm_factor_cube, c1_bcast)
-        f0_dnorm_factor = f32[2,8] multiply(f0, dnorm_factor)
-        c2 = f32[] constant(0.083333)
-        c2_bcast = f32[2,8] broadcast(c2), dimensions={}
-        f0_dnorm_factor_scaled = f32[2,8] multiply(f0_dnorm_factor, c2_bcast)
-        f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,3}
-        f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
-        minus_f1 = f32[2,4,6,8] negate(f1)
-        minus_f1_sum = f32[2,8] reduce(minus_f1, c0), dimensions={1,2}, to_apply=apply
-        f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
-        minus_f2 = f32[2,4,6,8] negate(f2)
-        minus_f2_sum = f32[2,8] reduce(minus_f2, c0), dimensions={1,2}, to_apply=apply
-        minus_f1_f2_sum = f32[2,8] add(minus_f1_sum, minus_f2_sum)
-        minus_f1_f2_sum_scaled = f32[2,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
-        minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,3}
-        f1_f2 = f32[2,4,6,8] add(f1, f2)
-        dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
-        ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) tuple(norm_scale_bias, dinput, dscale, dbias)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[TRANSPOSE0:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2}
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
-; CHECK:           }
-; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3)
-; CHECK-NEXT:    [[TRANSPOSE1:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P3]]), dimensions={0,3,1,2}
-; CHECK-DAG:     [[P3_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE1]])
-; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[1,4,6,1]{3,2,1,0}, f32[1,4,6,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0
-; CHECK-DAG:         "kind":"LAYER_BWD"
-; CHECK:           }
-; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG:     [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
-; CHECK-DAG:     [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0
-; CHECK-DAG:     [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1
-; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE4]])
-; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE5]])
-; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4,6]{1,0}, f32[4,6]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,1,8] parameter(0)
-        input_square = f32[2,4,1,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
-        reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
-        r_nelems = f32[] constant(0.25)
-        r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
-        input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,8] multiply(input_mean, input_mean)
-        variance = f32[2,8] subtract(input_square_mean, input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
-        norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3}
-        input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3}
-        input_center = f32[2,4,1,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,1,8] multiply(input_center, norm_factor_bcast)
-        scale = f32[4,1] parameter(1)
-        scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2}
-        norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast)
-        bias = f32[4,1] parameter(2)
-        bias_bcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2}
-        norm_scale_bias = f32[2,4,1,8] add(norm_scale, bias_bcast)
-        doutput = f32[2,4,1,8] parameter(3)
-        dbias = f32[4,1] reduce(doutput, c0), dimensions={0,3}, to_apply=apply
-        norm_doutput = f32[2,4,1,8] multiply(norm, doutput)
-        dscale = f32[4,1] reduce(norm_doutput, c0), dimensions={0,3}, to_apply=apply
-        scale_doutput = f32[2,4,1,8] multiply(scale_bcast, doutput)
-        input_center_scale_doutput = f32[2,4,1,8] multiply(input_center, scale_doutput)
-        f0 = f32[2,8] reduce(input_center_scale_doutput, c0), dimensions={1,2}, to_apply=apply
-        norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
-        c1 = f32[] constant(-0.5)
-        c1_bcast = f32[2,8] broadcast(c1), dimensions={}
-        dnorm_factor = f32[2,8] multiply(norm_factor_cube, c1_bcast)
-        f0_dnorm_factor = f32[2,8] multiply(f0, dnorm_factor)
-        c2 = f32[] constant(0.5)
-        c2_bcast = f32[2,8] broadcast(c2), dimensions={}
-        f0_dnorm_factor_scaled = f32[2,8] multiply(f0_dnorm_factor, c2_bcast)
-        f0_dnorm_factor_scaled_bcast = f32[2,4,1,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,3}
-        f1 = f32[2,4,1,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
-        minus_f1 = f32[2,4,1,8] negate(f1)
-        minus_f1_sum = f32[2,8] reduce(minus_f1, c0), dimensions={1,2}, to_apply=apply
-        f2 = f32[2,4,1,8] multiply(norm_factor_bcast, scale_doutput)
-        minus_f2 = f32[2,4,1,8] negate(f2)
-        minus_f2_sum = f32[2,8] reduce(minus_f2, c0), dimensions={1,2}, to_apply=apply
-        minus_f1_f2_sum = f32[2,8] add(minus_f1_sum, minus_f2_sum)
-        minus_f1_f2_sum_scaled = f32[2,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
-        minus_f1_f2_sum_scaled_bcast = f32[2,4,1,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,3}
-        f1_f2 = f32[2,4,1,8] add(f1, f2)
-        dinput = f32[2,4,1,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
-        ROOT out = (f32[2,4,1,8], f32[2,4,1,8], f32[4,1], f32[4,1]) tuple(norm_scale_bias, dinput, dscale, dbias)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1], {{.*}}: f32[2,4,1,8]) -> (f32[2,4,1,8], f32[2,4,1,8], f32[4,1], f32[4,1]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[TRANSPOSE0:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1}
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
-; CHECK:           }
-; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(3)
-; CHECK-NEXT:    [[TRANSPOSE1:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P3]]), dimensions={2,0,3,1}
-; CHECK-DAG:     [[P3_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]])
-; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0
-; CHECK-DAG:         "kind":"LAYER_BWD"
-; CHECK:           }
-; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG:     [[FUSION0:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
-; CHECK-DAG:     [[GTEF0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=0
-; CHECK-DAG:     [[GTEF1:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=1
-; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE4]])
-; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE5]])
-; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}, f32[4,1]{1,0}, f32[4,1]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
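The HLO graphs in these tests all spell out the same layer-norm training arithmetic that the rewriter is expected to collapse into "__cudnn$norm" custom calls: a mean and a biased variance over the normalized axes (here 4 elements, hence r_nelems = 0.25), an rsqrt(variance + epsilon) normalization factor, and a scale/bias epilogue. Below is a minimal standalone C++ sketch of that forward arithmetic for one normalization group; the names and toy data are illustrative, not part of the test.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // One normalization group of 4 elements, matching r_nelems = 1/4 above.
  const std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
  const float scale = 0.5f, bias = 0.1f, epsilon = 0.001f;
  const float r_nelems = 1.0f / x.size();

  // E[x] and E[x^2], computed the way the HLO does it (sum * r_nelems).
  float sum = 0.0f, sum_sq = 0.0f;
  for (float v : x) {
    sum += v;
    sum_sq += v * v;
  }
  const float mean = sum * r_nelems;
  const float variance = sum_sq * r_nelems - mean * mean;  // E[x^2] - E[x]^2
  const float norm_factor = 1.0f / std::sqrt(variance + epsilon);  // rsqrt

  // norm_scale_bias = (x - mean) * norm_factor * scale + bias
  for (float v : x) {
    std::printf("%f\n", (v - mean) * norm_factor * scale + bias);
  }
  return 0;
}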
-// TODO(b/343124533) Reenable when fixed
-TEST_F(CudnnNormRewriterTest,
-       DISABLED_LayerNormTrainBackward4D1DoutputReshapeSplit) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,6,8] parameter(0)
-        input_square = f32[2,4,6,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,6,8] reduce(input_square, c0), dimensions={1}, to_apply=apply
-        reduce = f32[2,6,8] reduce(input, c0), dimensions={1}, to_apply=apply
-        r_nelems = f32[] constant(0.25)
-        r_nelems_bcast = f32[2,6,8] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,6,8] multiply(input_square_sum,r_nelems_bcast)
-        input_mean = f32[2,6,8] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,6,8] multiply(input_mean,input_mean)
-        variance = f32[2,6,8] subtract(input_square_mean,input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,6,8] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,6,8] add(variance, epsilon_bcast)
-        norm_factor = f32[2,6,8] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,2,3}
-        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,2,3}
-        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
-        scale = f32[4] parameter(1)
-        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1}
-        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
-        bias = f32[4] parameter(2)
-        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1}
-        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
-        doutput = f32[2,4,48] parameter(3)
-        dbias = f32[4] reduce(doutput, c0), dimensions={0,2}, to_apply=apply
-        doutput_bitcast = f32[2,4,6,8] reshape(doutput)
-        norm_doutput = f32[2,4,6,8] multiply(norm, doutput_bitcast)
-        dscale = f32[4] reduce(norm_doutput, c0), dimensions={0,2,3}, to_apply=apply
-        scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput_bitcast)
-        input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
-        f0 = f32[2,6,8] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply
-        norm_factor_cube = f32[2,6,8] divide(norm_factor, variance_plus_epsilon)
-        c1 = f32[] constant(-0.5)
-        c1_bcast = f32[2,6,8] broadcast(c1), dimensions={}
-        dnorm_factor = f32[2,6,8] multiply(norm_factor_cube, c1_bcast)
-        f0_dnorm_factor = f32[2,6,8] multiply(f0, dnorm_factor)
-        c2 = f32[] constant(0.5)
-        c2_bcast = f32[2,6,8] broadcast(c2), dimensions={}
-        f0_dnorm_factor_scaled = f32[2,6,8] multiply(f0_dnorm_factor, c2_bcast)
-        f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,2,3}
-        f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
-        minus_f1 = f32[2,4,6,8] negate(f1)
-        minus_f1_sum = f32[2,6,8] reduce(minus_f1, c0), dimensions={1}, to_apply=apply
-        f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
-        minus_f2 = f32[2,4,6,8] negate(f2)
-        minus_f2_sum = f32[2,6,8] reduce(minus_f2, c0), dimensions={1}, to_apply=apply
-        minus_f1_f2_sum = f32[2,6,8] add(minus_f1_sum, minus_f2_sum)
-        minus_f1_f2_sum_scaled = f32[2,6,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
-        minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,2,3}
-        f1_f2 = f32[2,4,6,8] add(f1, f2)
-        dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
-        ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4,48]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[TRANSPOSE0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} transpose([[P0]]), dimensions={0,2,3,1}
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
-; CHECK:           }
-; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4,48]{2,1,0} parameter(3)
-; CHECK-DAG:     [[FUSION0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} fusion([[P3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
-; CHECK-DAG:     [[FUSION0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[FUSION0]])
-; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION0_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0
-; CHECK-DAG:         "kind":"LAYER_BWD"
-; CHECK:           }
-; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG:     [[FUSION1:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
-; CHECK-DAG:     [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=0
-; CHECK-DAG:     [[GTEF2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=1
-; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]])
-; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]])
-; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTEF1]], [[GTEF2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-// TODO(b/343124533) Reenable when fixed
-TEST_F(CudnnNormRewriterTest,
-       DISABLED_LayerNormTrainBackward4D1DoutputReshapeCombine) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
-#endif
-  if (!(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::AMPERE) &&
-      !(GetCudaComputeCapability().major ==
-        se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP()
-        << "Layer norm kernels require Ampere or Hopper architectures.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] add(a,b)
-    }
-
-    ENTRY test {
-        input = f32[2,4,6,8] parameter(0)
-        input_square = f32[2,4,6,8] multiply(input, input)
-        c0 = f32[] constant(0)
-        input_square_sum = f32[2,6,8] reduce(input_square, c0), dimensions={1}, to_apply=apply
-        reduce = f32[2,6,8] reduce(input, c0), dimensions={1}, to_apply=apply
-        r_nelems = f32[] constant(0.25)
-        r_nelems_bcast = f32[2,6,8] broadcast(r_nelems), dimensions={}
-        input_square_mean = f32[2,6,8] multiply(input_square_sum,r_nelems_bcast)
-        input_mean = f32[2,6,8] multiply(reduce, r_nelems_bcast)
-        input_mean_square = f32[2,6,8] multiply(input_mean,input_mean)
-        variance = f32[2,6,8] subtract(input_square_mean,input_mean_square)
-        epsilon = f32[] constant(0.001)
-        epsilon_bcast = f32[2,6,8] broadcast(epsilon), dimensions={}
-        variance_plus_epsilon = f32[2,6,8] add(variance, epsilon_bcast)
-        norm_factor = f32[2,6,8] rsqrt(variance_plus_epsilon)
-        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,2,3}
-        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,2,3}
-        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
-        norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
-        scale = f32[4] parameter(1)
-        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1}
-        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
-        bias = f32[4] parameter(2)
-        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1}
-        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
-        doutput = f32[2,4,6,2,2,2] parameter(3)
-        dbias = f32[4] reduce(doutput, c0), dimensions={0,2,3,4,5}, to_apply=apply
-        doutput_bitcast = f32[2,4,6,8] reshape(doutput)
-        norm_doutput = f32[2,4,6,8] multiply(norm, doutput_bitcast)
-        dscale = f32[4] reduce(norm_doutput, c0), dimensions={0,2,3}, to_apply=apply
-        scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput_bitcast)
-        input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
-        f0 = f32[2,6,8] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply
-        norm_factor_cube = f32[2,6,8] divide(norm_factor, variance_plus_epsilon)
-        c1 = f32[] constant(-0.5)
-        c1_bcast = f32[2,6,8] broadcast(c1), dimensions={}
-        dnorm_factor = f32[2,6,8] multiply(norm_factor_cube, c1_bcast)
-        f0_dnorm_factor = f32[2,6,8] multiply(f0, dnorm_factor)
-        c2 = f32[] constant(0.5)
-        c2_bcast = f32[2,6,8] broadcast(c2), dimensions={}
-        f0_dnorm_factor_scaled = f32[2,6,8] multiply(f0_dnorm_factor, c2_bcast)
-        f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,2,3}
-        f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
-        minus_f1 = f32[2,4,6,8] negate(f1)
-        minus_f1_sum = f32[2,6,8] reduce(minus_f1, c0), dimensions={1}, to_apply=apply
-        f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
-        minus_f2 = f32[2,4,6,8] negate(f2)
-        minus_f2_sum = f32[2,6,8] reduce(minus_f2, c0), dimensions={1}, to_apply=apply
-        minus_f1_f2_sum = f32[2,6,8] add(minus_f1_sum, minus_f2_sum)
-        minus_f1_f2_sum_scaled = f32[2,6,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
-        minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,2,3}
-        f1_f2 = f32[2,4,6,8] add(f1, f2)
-        dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
-        ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias)
-    })";
-
-  const char* optimized_hlo = R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4,6,2,2,2]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
-; CHECK-NEXT:    [[TRANSPOSE0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} transpose([[P0]]), dimensions={0,2,3,1}
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
-; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0.001
-; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
-; CHECK:           }
-; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
-; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4,6,2,2,2]{5,4,3,2,1,0} parameter(3)
-; CHECK-DAG:     [[FUSION0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} fusion([[P3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
-; CHECK-DAG:     [[FUSION0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[FUSION0]])
-; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
-; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
-; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION0_BITCAST]], [[GTE1]], [[GTE2]]),
-; CHECK:           custom_call_target="__cudnn$norm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "epsilon":0
-; CHECK-DAG:         "kind":"LAYER_BWD"
-; CHECK:           }
-; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
-; CHECK-DAG:     [[FUSION1:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
-; CHECK-DAG:     [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=0
-; CHECK-DAG:     [[GTEF2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=1
-; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
-; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]])
-; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
-; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]])
-; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTEF1]], [[GTEF2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
-  )";
-
-  TestNorm(hlo_text, optimized_hlo);
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.cc b/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.cc
deleted file mode 100644
index ed83a62..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.cc
+++ /dev/null
@@ -1,528 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_pad_for_convolutions.h"
-
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <optional>
-#include <tuple>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/functional/bind_front.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/literal_util.h"
-#include "xla/primitive_util.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/cudnn_support_utils.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-// Creates and returns an HLO that zero-pads one or more dimensions in the given
-// instruction so that its shape is equal to the given shape.
-//
-// Padding is added to the end of each relevant dimension.
-//
-// If the instruction already has the given shape, simply returns it without an
-// intervening pad.
-static HloInstruction* PadInstruction(HloInstruction* instr,
-                                      const Shape& new_shape) {
-  HloComputation* comp = instr->parent();
-
-  const Shape& shape = instr->shape();
-  PaddingConfig pad_config = MakeNoPaddingConfig(shape.rank());
-
-  bool added_padding = false;
-  for (int64_t dim = 0; dim < shape.rank(); ++dim) {
-    if (shape.dimensions(dim) == new_shape.dimensions(dim)) {
-      continue;
-    }
-    CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim));
-    pad_config.mutable_dimensions(dim)->set_edge_padding_high(
-        new_shape.dimensions(dim) - shape.dimensions(dim));
-    added_padding = true;
-  }
-  if (!added_padding) {
-    return instr;
-  }
-
-  auto* zero = comp->AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
-  return comp->AddInstruction(
-      HloInstruction::CreatePad(new_shape, instr, zero, pad_config),
-      &instr->metadata());
-}
-
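For intuition, PadInstruction above only ever sets edge_padding_high, i.e. zeros are appended at the end of each dimension that needs to grow. A standalone sketch of that behavior on a single dimension (plain C++, no XLA types; the sizes are illustrative):

#include <cstdio>
#include <vector>

// Appends zeros so that data grows from its current size to new_size,
// mirroring edge_padding_high on one dimension.
std::vector<float> PadHigh(std::vector<float> data, size_t new_size) {
  if (data.size() >= new_size) return data;  // nothing to pad
  data.resize(new_size, 0.0f);
  return data;
}

int main() {
  std::vector<float> channels(41, 1.0f);
  // 41 -> 48: seven zeros appended at the end of the dimension.
  std::vector<float> padded = PadHigh(channels, 48);
  std::printf("padded size: %zu\n", padded.size());
  return 0;
}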
-// Modifies the given convolution to have the given input and result shapes.
-static absl::Status PadConv(HloCustomCallInstruction* conv,
-                            absl::Span<const Shape> new_input_shapes,
-                            const Shape& new_result_shape) {
-  CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0))
-      << "conv must use 0 scratch bytes, i.e. this pass must be run "
-         "before CudnnConvAlgorithmPicker.";
-  std::vector<HloInstruction*> new_operands;
-  new_operands.reserve(conv->operand_count());
-  for (int i = 0; i < conv->operand_count(); ++i) {
-    new_operands.push_back(
-        PadInstruction(conv->mutable_operand(i), new_input_shapes[i]));
-  }
-  const Shape& result_shape = conv->shape().tuple_shapes(0);
-
-  bool changed = false;
-  for (int i = 0; i < conv->operand_count(); ++i) {
-    changed |= (new_operands[i] != conv->mutable_operand(i));
-  }
-  CHECK(changed) << "We should have had to pad at least one input operand.";
-
-  auto add = [&](std::unique_ptr<HloInstruction> new_instr) {
-    return conv->parent()->AddInstruction(std::move(new_instr));
-  };
-
-  Shape new_conv_shape = ShapeUtil::MakeTupleShape(
-      {new_result_shape, ShapeUtil::MakeShape(U8, {0})});
-  auto* new_conv =
-      add(conv->CloneWithNewOperands(new_conv_shape, new_operands));
-
-  // Clone conv's name to new_conv.  This is safe because we're going to remove
-  // conv below.
-  new_conv->SetAndSanitizeName(conv->name());
-
-  VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with "
-          << new_conv->ToString();
-
-  // Slice the new conv result if necessary, keeping in mind that new_conv
-  // has tuple shape (new_result_shape, u8[0]).
-  if (!ShapeUtil::Equal(result_shape, new_result_shape)) {
-    std::vector<int64_t> start_indices(result_shape.dimensions_size(), 0);
-    std::vector<int64_t> end_indices(result_shape.dimensions().begin(),
-                                     result_shape.dimensions().end());
-    std::vector<int64_t> strides(result_shape.dimensions_size(), 1);
-
-    auto* new_conv_result = add(
-        HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0));
-    auto* empty_temp_buffer =
-        add(HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint8_t>({})));
-    auto* sliced_result = add(HloInstruction::CreateSlice(
-        result_shape, new_conv_result, start_indices, end_indices, strides));
-    new_conv =
-        add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer}));
-  }
-
-  return conv->parent()->ReplaceInstruction(conv, new_conv);
-}
-
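PadConv pads the operands up front and, when the result shape also grew, slices the convolution output back down using zero start indices, end indices equal to the original dimensions, and unit strides. This is the inverse of the PadHigh sketch above; a standalone sketch of that slice-back on one dimension (illustrative only, the real pass builds HLO slice instructions):

#include <cstdio>
#include <vector>

// Keeps elements [0, end) of one dimension, i.e. a slice with start index 0,
// end index equal to the original size, and stride 1.
std::vector<float> SliceBack(const std::vector<float>& padded, size_t end) {
  return std::vector<float>(padded.begin(), padded.begin() + end);
}

int main() {
  std::vector<float> padded_result(48, 2.0f);  // conv ran on the padded shape
  std::vector<float> result = SliceBack(padded_result, 41);  // back to 41
  std::printf("result size: %zu\n", result.size());
  return 0;
}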
-static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
-    HloComputation* comp) {
-  std::vector<HloCustomCallInstruction*> convs;
-  for (HloInstruction* instr : comp->instructions()) {
-    if (IsCustomCallToDnnConvolution(*instr)) {
-      convs.push_back(Cast<HloCustomCallInstruction>(instr));
-    }
-  }
-  return convs;
-}
-
-// This is the main function of the transform.  It runs on a given custom-call
-// node to a cuDNN convolution, calls resolve_pad_shapes to resolve the desired
-// input/output feature map shapes, and adds the necessary padding and slicing
-// nodes around them.
-//
-// resolve_pad_shapes takes conv, a custom-call instruction to a cuDNN
-// convolution that may need padding, figures out the desired padded input and
-// output tensor shapes, and stores them in new_input_shapes and
-// new_result_shape.  Note that new_input_shapes is a vector because there are
-// multiple input tensors.  In addition to a status, it returns true if padding
-// is necessary and false otherwise.
-static absl::StatusOr<bool> ResolveAndPad(
-    HloCustomCallInstruction* conv,
-    std::function<absl::StatusOr<bool>(HloCustomCallInstruction* conv,
-                                       std::vector<Shape>* new_input_shapes,
-                                       Shape* new_result_shape)>
-        resolve_pad_shapes) {
-  std::vector<Shape> new_input_shapes;
-  Shape new_result_shape;
-  TF_ASSIGN_OR_RETURN(bool result, resolve_pad_shapes(conv, &new_input_shapes,
-                                                      &new_result_shape));
-  if (result) {
-    TF_RETURN_IF_ERROR(PadConv(conv, new_input_shapes, new_result_shape));
-    return true;
-  }
-  return false;
-}
-
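ResolveAndPad separates the shape decision from the graph surgery: Run (further below) hands it TryResolvePaddedShapesForIntegerConvolution with the leading pad_to and compute-capability arguments already bound via absl::bind_front, so the remaining signature matches the std::function parameter. A small standalone sketch of that partial-application pattern; the resolver here is a stand-in, not the real one:

#include <cstdio>
#include <functional>

#include "absl/functional/bind_front.h"

// Stand-in resolver with the same "extra leading arguments" shape as
// TryResolvePaddedShapesForIntegerConvolution(pad_to, cc, conv, ...).
bool ResolveStub(int pad_to, int conv_id) {
  std::printf("resolving conv %d with pad_to=%d\n", conv_id, pad_to);
  return pad_to == 32;
}

int main() {
  // Bind the leading argument; the result is callable with just the conv id,
  // matching the callback type a ResolveAndPad-style helper would expect.
  std::function<bool(int)> resolve32 = absl::bind_front(ResolveStub, 32);
  std::function<bool(int)> resolve4 = absl::bind_front(ResolveStub, 4);
  std::printf("%d %d\n", resolve32(/*conv_id=*/0), resolve4(/*conv_id=*/1));
  return 0;
}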
-// Adds padding to cudnn convolutions to make them run faster on GPUs with
-// tensor cores.
-//
-//  - f16 convolutions are padded to have input/output channel dimensions that
-//    are multiples of 8, so that we can use tensor cores.
-//
-//  - f16 convolutions with 3 input channels and 32 or 64 output channels are
-//    padded to 4 input channels.  There's a special-cased cudnn algorithm just
-//    for this.
-//
-// Don't run this pass on GPUs without tensor cores -- it will make them slower!
-//
-// TODO(jlebar): Also pad dots.
-static absl::StatusOr<bool> TryResolvePaddedShapesForTensorCore(
-    HloCustomCallInstruction* conv, std::vector<Shape>* new_input_shapes_ptr,
-    Shape* new_result_shape_ptr) {
-  TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
-  const auto& dnums = conv->convolution_dimension_numbers();
-  auto* lhs = conv->mutable_operand(0);
-  auto* rhs = conv->mutable_operand(1);
-  const Shape& result_shape = conv->shape().tuple_shapes(0);
-
-  // Nothing to do on non-f16 convolutions.
-  if (result_shape.element_type() != PrimitiveType::F16) {
-    return false;
-  }
-
-  // When the convolution is grouped, the operand shapes must stay consistent
-  // with the group size, so we cannot pad them independently.
-  if (conv->feature_group_count() > 1 || conv->batch_group_count() > 1) {
-    VLOG(2) << "Do not pad grouped convolution.";
-    return false;
-  }
-
-  // TODO(timshen): Don't skip forward-activation convs if we find a benchmark
-  // where there's a speedup.
-  if (kind == CudnnConvKind::kForwardActivation) {
-    return false;
-  }
-
-  Shape new_lhs_shape = lhs->shape();
-  Shape new_rhs_shape = rhs->shape();
-  Shape& new_result_shape = *new_result_shape_ptr;
-  new_result_shape = conv->shape().tuple_shapes(0);
-
-  // new_{input,filter,output}_shape each point to the appropriate one of
-  // new_{lhs,rhs,result}_shape.
-  Shape* new_input_shape;
-  Shape* new_filter_shape;
-  Shape* new_output_shape;
-  std::tie(new_input_shape, new_filter_shape, new_output_shape) = [&] {
-    switch (kind) {
-      case CudnnConvKind::kForward:
-      case CudnnConvKind::kForwardActivation:
-      case CudnnConvKind::kForwardGraph:
-        return std::make_tuple(&new_lhs_shape, &new_rhs_shape,
-                               &new_result_shape);
-      case CudnnConvKind::kBackwardInput:
-        return std::make_tuple(&new_result_shape, &new_rhs_shape,
-                               &new_lhs_shape);
-      case CudnnConvKind::kBackwardFilter:
-        return std::make_tuple(&new_lhs_shape, &new_result_shape,
-                               &new_rhs_shape);
-    }
-  }();
-
-  // If there are 3 input features and 32 or 64 output features, pad the input
-  // features to 4.  Otherwise, try padding to multiples of 8 and check that
-  // this doesn't make any of the conv buffers too much larger.
-  auto input_features =
-      new_input_shape->dimensions(dnums.input_feature_dimension());
-  auto output_features =
-      new_output_shape->dimensions(dnums.output_feature_dimension());
-  if (input_features == 3 && (output_features == 32 || output_features == 64)) {
-    new_input_shape->set_dimensions(dnums.input_feature_dimension(), 4);
-    new_filter_shape->set_dimensions(dnums.kernel_input_feature_dimension(), 4);
-  } else {
-    auto pad_dim = [](Shape* s, int64_t dim) {
-      s->set_dimensions(dim, RoundUpTo<int64_t>(s->dimensions(dim), 8));
-    };
-    pad_dim(new_input_shape, dnums.input_feature_dimension());
-    pad_dim(new_filter_shape, dnums.kernel_input_feature_dimension());
-    pad_dim(new_filter_shape, dnums.kernel_output_feature_dimension());
-    pad_dim(new_output_shape, dnums.output_feature_dimension());
-
-    // We won't pad a conv if doing so increases the total number of bytes in
-    // the lhs, rhs, or result by more than this amount.
-    //
-    // TODO(jlebar): This number was tuned experimentally.  It represents a
-    // compromise on our current benchmarks; it speeds some up significantly,
-    // and doesn't slow any down.  But we can observe by changing this value
-    // that there's additional room for speedups.  Achieving those speedups
-    // without also slowing other things down will likely require a more
-    // sophisticated heuristic, possibly some form of auto-tuning.
-    static constexpr double kMaxBytesTouchedBound = 1.35;
-
-    // Check that padding wouldn't increase the total bytes read/written by this
-    // operation too much.
-    auto check_size_increase = [&](const Shape& old_shape,
-                                   const Shape& new_shape) {
-      int64_t old_bytes = ShapeUtil::ByteSizeOf(old_shape);
-      int64_t new_bytes = ShapeUtil::ByteSizeOf(new_shape);
-      if (new_bytes <= old_bytes * kMaxBytesTouchedBound) {
-        return true;
-      }
-      VLOG(3)
-          << "Not padding convolution; doing so would change input / result "
-             "shape from "
-          << ShapeUtil::HumanString(old_shape) << " to "
-          << ShapeUtil::HumanString(new_shape) << ", a size increase of "
-          << new_bytes / static_cast<double>(old_bytes) << "x > "
-          << kMaxBytesTouchedBound << "x: " << conv->ToString();
-      return false;
-    };
-
-    if (!check_size_increase(lhs->shape(), new_lhs_shape) ||
-        !check_size_increase(rhs->shape(), new_rhs_shape) ||
-        !check_size_increase(result_shape, new_result_shape)) {
-      return false;
-    }
-  }
-
-  if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) &&
-      ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) {
-    VLOG(3) << "No need to pad features of " << conv->ToString();
-    return false;
-  }
-
-  new_input_shapes_ptr->push_back(new_lhs_shape);
-  new_input_shapes_ptr->push_back(new_rhs_shape);
-  return true;
-}
-
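Condensed, the f16 decision above is: special-case 3 input channels feeding 32 or 64 output channels (pad the inputs to 4), otherwise round the feature dimensions up to multiples of 8, and back out entirely if any padded buffer would grow beyond the 1.35x bytes-touched bound. A standalone sketch of just that arithmetic (the channel counts in main are illustrative):

#include <cstdint>
#include <cstdio>

int64_t RoundUpTo(int64_t value, int64_t multiple) {
  return (value + multiple - 1) / multiple * multiple;
}

// Returns true and fills padded_in/padded_out if padding f16 features pays off
// under the same rules as TryResolvePaddedShapesForTensorCore.
bool DecideF16Padding(int64_t in_features, int64_t out_features,
                      int64_t* padded_in, int64_t* padded_out) {
  constexpr double kMaxBytesTouchedBound = 1.35;
  if (in_features == 3 && (out_features == 32 || out_features == 64)) {
    *padded_in = 4;             // special-cased cudnn algorithm
    *padded_out = out_features;
  } else {
    *padded_in = RoundUpTo(in_features, 8);    // tensor cores want multiples of 8
    *padded_out = RoundUpTo(out_features, 8);
  }
  // Bytes scale linearly with the feature dimension here, so the size check
  // reduces to a ratio of channel counts.
  const bool in_ok = *padded_in <= in_features * kMaxBytesTouchedBound;
  const bool out_ok = *padded_out <= out_features * kMaxBytesTouchedBound;
  const bool changed = *padded_in != in_features || *padded_out != out_features;
  return changed && in_ok && out_ok;
}

int main() {
  int64_t pi, po;
  std::printf("41x40 -> pad? %d (%lld,%lld)\n",
              DecideF16Padding(41, 40, &pi, &po), (long long)pi, (long long)po);
  std::printf("3x32  -> pad? %d (%lld,%lld)\n",
              DecideF16Padding(3, 32, &pi, &po), (long long)pi, (long long)po);
  return 0;
}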
-// Adds padding to cudnn integer convolutions to make input and output feature
-// maps multiples of pad_to (usually 4 or 32).
-absl::StatusOr<bool> TryResolvePaddedShapesForIntegerConvolution(
-    int pad_to, const se::CudaComputeCapability& compute_capability,
-    HloCustomCallInstruction* conv, std::vector<Shape>* new_input_shapes_ptr,
-    Shape* new_result_shape_ptr) {
-  TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
-  const Shape& input_shape = conv->operand(0)->shape();
-  const Shape& kernel_shape = conv->operand(1)->shape();
-  const Shape& result_shape = conv->shape().tuple_shapes(0);
-
-  // Integer convolution only
-  if (!primitive_util::IsIntegralType(input_shape.element_type())) {
-    return false;
-  }
-
-  // kForward and kForwardActivation only
-  if (kind != CudnnConvKind::kForward &&
-      kind != CudnnConvKind::kForwardActivation) {
-    return false;
-  }
-
-  const auto& dnums = conv->convolution_dimension_numbers();
-  std::vector<Shape>& new_input_shapes = *new_input_shapes_ptr;
-  for (auto operand : conv->operands()) {
-    new_input_shapes.push_back(operand->shape());
-  }
-  Shape& new_result_shape = *new_result_shape_ptr;
-  new_result_shape = conv->shape().tuple_shapes(0);
-
-  // The input/kernel/output might already be vectorized (i.e. cudnn layout
-  // NCHW_VECT_C).  If so, we pad the features dim so that
-  // size(features_dim) * size(vect_dim) is a multiple of pad_to.
-  std::optional<int64_t> input_vect_dim;
-  std::optional<int64_t> kernel_vect_dim;
-  std::optional<int64_t> result_vect_dim;
-  std::tie(input_vect_dim, kernel_vect_dim, result_vect_dim) =
-      FindVectorizedFeatureDims(dnums, input_shape, kernel_shape, result_shape);
-
-  int64_t input_vect_size =
-      input_vect_dim.has_value() ? input_shape.dimensions(*input_vect_dim) : 1;
-  int64_t kernel_vect_size = kernel_vect_dim.has_value()
-                                 ? kernel_shape.dimensions(*kernel_vect_dim)
-                                 : 1;
-  int64_t result_vect_size = result_vect_dim.has_value()
-                                 ? result_shape.dimensions(*result_vect_dim)
-                                 : 1;
-  if (pad_to % input_vect_size != 0 || pad_to % kernel_vect_size != 0 ||
-      pad_to % result_vect_size != 0) {
-    // If the conv is already vectorized but pad_to is not a multiple of the
-    // vector size, we choose not to pad.  This is a weird case, because the
-    // only useful vector sizes in cudnn (as of writing) are 4 and 32, and those
-    // are also the only pad_to cases.
-    return false;
-  }
-
-  // Check that cudnn supports our desired integer padding/vectorization.
-  TF_ASSIGN_OR_RETURN(bool cudnn_supports,
-                      CudnnSupportsOptimizedIntegerConvolution(
-                          compute_capability, *conv, pad_to));
-  if (!cudnn_supports) {
-    return false;
-  }
-
-  // Pad the features to multiples of pad_to.
-  {
-    auto pad_dim = [&](Shape* s, int64_t dim, int64_t cur_vect_size) {
-      CHECK_EQ(pad_to % cur_vect_size, 0);
-      s->set_dimensions(
-          dim, RoundUpTo<int64_t>(s->dimensions(dim), pad_to / cur_vect_size));
-    };
-
-    switch (kind) {
-      case CudnnConvKind::kForward:
-        CHECK_EQ(new_input_shapes.size(), 2);
-        // Input feature maps
-        pad_dim(new_input_shapes.data(), dnums.input_feature_dimension(),
-                input_vect_size);
-        // Kernel for the input feature maps
-        pad_dim(&new_input_shapes[1], dnums.kernel_input_feature_dimension(),
-                kernel_vect_size);
-        // Kernel for the output feature maps.  In the NCHW_VECT_C layout, only
-        // the kernel input feature dim is vectorized, so this has cur_vect_size 1.
-        pad_dim(&new_input_shapes[1], dnums.kernel_output_feature_dimension(),
-                /*cur_vect_size=*/1);
-        // Output feature maps
-        pad_dim(&new_result_shape, dnums.output_feature_dimension(),
-                result_vect_size);
-        break;
-      case CudnnConvKind::kForwardActivation:
-        CHECK(new_input_shapes.size() == 3 || new_input_shapes.size() == 4);
-        // Input feature maps
-        pad_dim(new_input_shapes.data(), dnums.input_feature_dimension(),
-                input_vect_size);
-        // Kernel for the input feature maps
-        pad_dim(&new_input_shapes[1], dnums.kernel_input_feature_dimension(),
-                kernel_vect_size);
-        // Kernel for the output feature maps.  In the NCHW_VECT_C layout, only
-        // the kernel input feature dim is vectorized, so this has cur_vect_size 1.
-        pad_dim(&new_input_shapes[1], dnums.kernel_output_feature_dimension(),
-                /*cur_vect_size=*/1);
-
-        // Bias.  This is a 1D vector of length output-depth, and it's unclear
-        // if we *have* to pad it.  But hey, we might as well.  cur_vect_size is
-        // 1 because NCHW_VECT_C doesn't apply here (there is no channels
-        // dimension!).
-        pad_dim(&new_input_shapes[2], /*dim=*/0, /*cur_vect_size=*/1);
-
-        if (new_input_shapes.size() == 4) {
-          // Optional side input.  Same layout as result, so gets padded the
-          // same.
-          pad_dim(&new_input_shapes[3], dnums.output_feature_dimension(),
-                  result_vect_size);
-        }
-        // Output feature maps
-        pad_dim(&new_result_shape, dnums.output_feature_dimension(),
-                result_vect_size);
-        break;
-      default:
-        CHECK(false);
-    }
-
-    // We won't pad a conv if doing so increases the total number of bytes in
-    // the lhs, rhs, or result by a factor of this much or more.
-    //
-    // Note: It's important that this bound is exclusive.  It's a performance
-    // regression to pad and increase input/output size by 2x, so we only pad
-    // strictly less than 2x.
-    //
-    // TODO(jlebar): This number was tuned experimentally, but without much
-    // experimental evidence.
-    static constexpr double kMaxBytesTouchedBound = 2;
-
-    // Check that padding wouldn't increase the total bytes read/written by this
-    // operation too much.
-    auto check_size_increase = [&](const Shape& old_shape,
-                                   const Shape& new_shape) {
-      int64_t old_bytes = ShapeUtil::ByteSizeOf(old_shape);
-      int64_t new_bytes = ShapeUtil::ByteSizeOf(new_shape);
-      if (new_bytes < old_bytes * kMaxBytesTouchedBound) {
-        return true;
-      }
-      VLOG(3)
-          << "Not padding convolution; doing so would change input / result "
-             "shape from "
-          << ShapeUtil::HumanString(old_shape) << " to "
-          << ShapeUtil::HumanString(new_shape) << ", a size increase of "
-          << new_bytes / static_cast<double>(old_bytes)
-          << "x >= " << kMaxBytesTouchedBound << "x: " << conv->ToString();
-      return false;
-    };
-
-    // Check size increase only on the input and output.  No need to check the
-    // filter, since that's determined by the input/output.  The bias (if
-    // present) is tiny (1D array of length output-depth), so padding doesn't
-    // matter.  And the side-input, if present, is the same shape as the input.
-    if (!check_size_increase(conv->operand(0)->shape(), new_input_shapes[0]) ||
-        !check_size_increase(result_shape, new_result_shape)) {
-      return false;
-    }
-  }
-
-  bool changed = false;
-  for (int64_t i = 0; i < conv->operand_count(); ++i) {
-    changed |=
-        !ShapeUtil::Equal(conv->operand(i)->shape(), new_input_shapes[i]);
-  }
-  if (!changed) {
-    VLOG(3) << "No need to pad features of " << conv->ToString();
-  }
-
-  return changed;
-}
-
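The vectorization handling above reduces to one formula: with an NCHW_VECT_C-style layout the logical channel count is size(features_dim) * size(vect_dim), so the pass rounds the features dimension up to a multiple of pad_to / cur_vect_size. A standalone sketch of that arithmetic with the pad_to values the pass actually uses (4 and 32):

#include <cstdint>
#include <cstdio>

int64_t RoundUpTo(int64_t value, int64_t multiple) {
  return (value + multiple - 1) / multiple * multiple;
}

// Pads a (possibly already vectorized) feature dimension so that
// features * vect_size becomes a multiple of pad_to.
int64_t PadFeatureDim(int64_t features, int64_t vect_size, int64_t pad_to) {
  if (pad_to % vect_size != 0) return features;  // weird case: don't pad
  return RoundUpTo(features, pad_to / vect_size);
}

int main() {
  // Unvectorized s8 conv, pad_to = 4: 41 channels -> 44 (as in the tests below).
  std::printf("%lld\n", (long long)PadFeatureDim(41, /*vect_size=*/1, 4));
  // Unvectorized s8 conv on sm75, pad_to = 32: 40 channels -> 64.
  std::printf("%lld\n", (long long)PadFeatureDim(40, /*vect_size=*/1, 32));
  // Already int8x4-vectorized (vect_size = 4), pad_to = 32: 41 -> 48,
  // i.e. 48 * 4 = 192 logical channels, a multiple of 32.
  std::printf("%lld\n", (long long)PadFeatureDim(41, /*vect_size=*/4, 32));
  return 0;
}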
-absl::StatusOr<bool> CudnnPadForConvolutions::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
-      // On Turing and later (sm75+), pad to multiples of 32 bytes if possible,
-      // because that lets us use the fast int8x32 data type.
-      bool local_changed = false;
-      if (compute_capability_.IsAtLeast(7, 5)) {
-        TF_ASSIGN_OR_RETURN(
-            local_changed,
-            ResolveAndPad(conv, absl::bind_front(
-                                    TryResolvePaddedShapesForIntegerConvolution,
-                                    32, compute_capability_)));
-      }
-      if (!local_changed) {
-        TF_ASSIGN_OR_RETURN(
-            local_changed,
-            ResolveAndPad(conv, absl::bind_front(
-                                    TryResolvePaddedShapesForIntegerConvolution,
-                                    4, compute_capability_)));
-      }
-      changed |= local_changed;
-    }
-    if (compute_capability_.IsAtLeast(se::CudaComputeCapability::VOLTA)) {
-      for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
-        TF_ASSIGN_OR_RETURN(
-            bool local_changed,
-            ResolveAndPad(conv, TryResolvePaddedShapesForTensorCore));
-        changed |= local_changed;
-      }
-    }
-  }
-  return changed;
-}
-
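Run therefore first tries the int8x32 path on sm75+, falls back to int8x4, and finally applies the tensor-core padding on Volta+. The deleted unit tests below drive it exactly as in this hedged sketch, which assumes a HloTestBase-derived test fixture (ParseAndReturnVerifiedModule comes from HloTestBase, and {7, 5} is the sm75 compute capability):

// Inside a HloTestBase-derived test, mirroring the cases below.
auto module = ParseAndReturnVerifiedModule(R"(
  HloModule M
  ENTRY E {
    input = s8[10,20,30,40] parameter(0)
    filter = s8[2,2,40,41] parameter(1)
    ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })")
                  .value();
// Returns true iff the pass changed the module (here: features padded to 64).
bool changed = CudnnPadForConvolutions({7, 5}).Run(module.get()).value();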
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.h b/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.h
deleted file mode 100644
index be7fae2..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_
-#define XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-
-namespace xla {
-namespace gpu {
-
-// This transform performs two zero-paddings for cuDNN thunking: padding for
-// tensor cores and padding for integer convolutions.  It also adds slice
-// instructions to remove unnecessary output features.
-class CudnnPadForConvolutions : public HloModulePass {
- public:
-  explicit CudnnPadForConvolutions(se::CudaComputeCapability compute_capability)
-      : compute_capability_(compute_capability) {}
-
-  absl::string_view name() const override {
-    return "cudnn_pad_for_convolutions";
-  }
-  // Runs the pass on the given module and returns whether any change was
-  // made.
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  const se::CudaComputeCapability compute_capability_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_
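For context, the pass was registered like any other HloModulePass. Below is a hedged sketch of wiring it into a pass pipeline; HloPassPipeline, its AddPass helper, the single-argument Run overload, and the "xla/service/hlo_pass_pipeline.h" header path are assumptions here, not shown by this diff.

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/cudnn_pad_for_convolutions.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/stream_executor/device_description.h"

// Sketch only: run the pass on a module for an sm80 (Ampere-class) GPU.
absl::StatusOr<bool> RunCudnnPadding(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("cudnn-padding");
  pipeline.AddPass<xla::gpu::CudnnPadForConvolutions>(
      stream_executor::CudaComputeCapability{8, 0});
  return pipeline.Run(module);
}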
diff --git a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions_test.cc b/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions_test.cc
deleted file mode 100644
index 2bae239..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions_test.cc
+++ /dev/null
@@ -1,456 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_pad_for_convolutions.h"
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = xla::match;
-
-class CudnnPadForConvolutionsTest : public HloTestBase {};
-
-TEST_F(CudnnPadForConvolutionsTest, DoNotPadF16ForwardConvWhenGrouped) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = f16[704,48,1,49]{3,2,1,0} parameter(0)
-    filter = f16[44,768,1,50]{3,2,1,0} parameter(1)
-    ROOT result = (f16[1,128,48,768]{3,2,1,0}, u8[0]{0})
-      custom-call(input, filter)
-      , window={size=1x50 pad=0_0x64_64}
-      , dim_labels=fb01_io01->01bf
-      , feature_group_count=16
-      , custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  EXPECT_FALSE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadF16ForwardConvInputChannels) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = f16[10,20,30,41] parameter(0)
-    filter = f16[2,2,41,40] parameter(1)
-    ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-
-  SCOPED_TRACE(module->ToString());
-
-  EXPECT_THAT(
-      root,
-      GmockMatch(m::CustomCall(
-          {kCudnnConvForwardCallTarget},
-          m::Pad(m::Parameter(0), m::Op()).WithShape(F16, {10, 20, 30, 48}),
-          m::Pad(m::Parameter(1), m::Op()).WithShape(F16, {2, 2, 48, 40}))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardInputConvOutputChannels) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    output = f16[10,20,30,41] parameter(0)
-    filter = f16[2,2,40,41] parameter(1)
-    ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convBackwardInput"
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(
-      root,
-      GmockMatch(m::CustomCall(
-          {kCudnnConvBackwardInputCallTarget},
-          m::Pad(m::Parameter(0), m::Op()).WithShape(F16, {10, 20, 30, 48}),
-          m::Pad(m::Parameter(1), m::Op()).WithShape(F16, {2, 2, 40, 48}))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadF16ForwardConvOutputChannels) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = f16[10,20,30,40] parameter(0)
-    filter = f16[2,2,40,41] parameter(1)
-    ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, GmockMatch(m::Tuple(
-                        m::Slice(m::GetTupleElement(m::CustomCall(
-                            {kCudnnConvForwardCallTarget}, m::Parameter(0),
-                            m::Pad(m::Parameter(1), m::Op())))),
-                        m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardInputConvInputChannels) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    output = f16[10,20,30,40] parameter(0)
-    filter = f16[2,2,41,40] parameter(1)
-    result = (f16[10,20,30,41], u8[0]) custom-call(output, filter),
-              window={size=2x2}, dim_labels=b01f_01io->b01f,
-              custom_call_target="__cudnn$convBackwardInput"
-    ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root,
-              GmockMatch(m::GetTupleElement(m::Tuple(
-                  m::Slice(m::GetTupleElement(m::CustomCall(
-                      {kCudnnConvBackwardInputCallTarget}, m::Parameter(0),
-                      m::Pad(m::Parameter(1), m::Op())))),
-                  m::Op()))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardFilterConvInputChannels) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = f16[10,20,30,41] parameter(0)
-    output = f16[10,20,30,40] parameter(1)
-    result = (f16[2,2,41,40], u8[0]) custom-call(input, output),
-              window={size=2x2}, dim_labels=b01f_01io->b01f,
-              custom_call_target="__cudnn$convBackwardFilter"
-    ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root,
-              GmockMatch(m::GetTupleElement(m::Tuple(
-                  m::Slice(m::GetTupleElement(m::CustomCall(
-                      {kCudnnConvBackwardFilterCallTarget},
-                      m::Pad(m::Parameter(0), m::Op()), m::Parameter(1)))),
-                  m::Op()))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardFilterConvOutputChannels) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = f16[10,20,30,40] parameter(0)
-    output = f16[10,20,30,41] parameter(1)
-    result = (f16[2,2,40,41], u8[0]) custom-call(input, output),
-              window={size=2x2}, dim_labels=b01f_01io->b01f,
-              custom_call_target="__cudnn$convBackwardFilter"
-    ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root,
-              GmockMatch(m::GetTupleElement(m::Tuple(
-                  m::Slice(m::GetTupleElement(m::CustomCall(
-                      {kCudnnConvBackwardFilterCallTarget}, m::Parameter(0),
-                      m::Pad(m::Parameter(1), m::Op())))),
-                  m::Op()))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadInputFeatures3To4) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = f16[10,20,30,3] parameter(0)
-    filter = f16[2,2,3,32] parameter(1)
-    ROOT result = (f16[10,20,30,32], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-
-  SCOPED_TRACE(module->ToString());
-  EXPECT_THAT(
-      root,
-      GmockMatch(m::CustomCall(
-          {kCudnnConvForwardCallTarget},
-          m::Pad(m::Parameter(0), m::Op()).WithShape(F16, {10, 20, 30, 4}),
-          m::Pad(m::Parameter(1), m::Op()).WithShape(F16, {2, 2, 4, 32}))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadIntForwardConvInputChannels) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,41] parameter(0)
-    filter = s8[2,2,41,40] parameter(1)
-    ROOT result = (f32[10,20,30,40], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-
-  SCOPED_TRACE(module->ToString());
-  EXPECT_THAT(
-      root,
-      GmockMatch(m::CustomCall(
-          {kCudnnConvForwardCallTarget},
-          m::Pad(m::Parameter(0), m::Op()).WithShape(S8, {10, 20, 30, 44}),
-          m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 44, 40}))));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadIntForwardConvOutputChannels) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,40] parameter(0)
-    filter = s8[2,2,40,41] parameter(1)
-    ROOT result = (f32[10,20,30,41], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, GmockMatch(m::Tuple(
-                        m::Slice(m::GetTupleElement(m::CustomCall(
-                            {kCudnnConvForwardCallTarget}, m::Parameter(0),
-                            m::Pad(m::Parameter(1), m::Op())))),
-                        m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadInt8To32OnSm75) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,40] parameter(0)
-    filter = s8[2,2,40,41] parameter(1)
-    ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Slice(m::GetTupleElement(m::CustomCall(
-              {kCudnnConvForwardCallTarget},
-              m::Pad(m::Parameter(0), m::Op()).WithShape(S8, {10, 20, 30, 64}),
-              m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 64, 64})))),
-          m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, NoPadInt8To32OnSm70) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,40] parameter(0)
-    filter = s8[2,2,40,41] parameter(1)
-    ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Slice(m::GetTupleElement(m::CustomCall(
-              {kCudnnConvForwardCallTarget}, m::Parameter(0),
-              m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 40, 44})))),
-          m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, NoPadInt8To32FloatOutputSm75) {
-  // This test checks that the padding pass correctly calls
-  // CudnnSupportsOptimizedIntegerConvolution(), which should reject padding
-  // this convolution to int8x32 because its output type is f32. It should
-  // instead be padded to int8x4, which does support float outputs.
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,38] parameter(0)
-    filter = s8[2,2,38,41] parameter(1)
-    ROOT result = (f32[10,20,30,41], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Slice(m::GetTupleElement(m::CustomCall(
-              {kCudnnConvForwardCallTarget},
-              m::Pad(m::Parameter(0), m::Op()).WithShape(S8, {10, 20, 30, 40}),
-              m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 40, 44})))),
-          m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, NoPadInt8UnsupportedFilterTypeOutputSm75) {
-  // This test checks that the padding pass correctly calls
-  // CudnnSupportsOptimizedIntegerConvolution(), which should reject this
-  // convolution because the kernel type is f32.
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,38] parameter(0)
-    filter = f32[2,2,38,41] parameter(1)
-    ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  EXPECT_FALSE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
-}
-
-TEST_F(CudnnPadForConvolutionsTest, NoPadToInt8x32ExcessiveBlowup) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[128,4,48,48] parameter(0)
-    filter = s8[64,4,3,3] parameter(1)
-    ROOT result = (f32[128,64,48,48], u8[0]) custom-call(input, filter),
-                  window={size=3x3}, dim_labels=bf01_io01->bf01,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  EXPECT_FALSE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadInt8x4To32) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,41,4] parameter(0)
-    filter = s8[2,2,41,4,168] parameter(1)
-    ROOT result = (s8[10,20,30,42,4], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f?_01i?o->b01f?,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Slice(m::GetTupleElement(
-                       m::CustomCall({kCudnnConvForwardCallTarget},
-                                     m::Pad(m::Parameter(0), m::Op())
-                                         .WithShape(S8, {10, 20, 30, 48, 4}),
-                                     m::Pad(m::Parameter(1), m::Op())
-                                         .WithShape(S8, {2, 2, 48, 4, 192})))
-                       .WithShape(S8, {10, 20, 30, 48, 4})),
-          m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest, PadInt8x4To32BiasActivation) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,41,4] parameter(0)
-    filter = s8[2,2,41,4,168] parameter(1)
-    bias = f32[10] parameter(2)
-    side_input = s8[10,20,30,42,4] parameter(3)
-    ROOT result = (s8[10,20,30,42,4], u8[0]) custom-call(input, filter, bias, side_input),
-                  window={size=2x2}, dim_labels=b01f?_01i?o->b01f?,
-                  custom_call_target="__cudnn$convBiasActivationForward"
-  })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Slice(
-              m::GetTupleElement(
-                  m::CustomCall(
-                      {kCudnnConvBiasActivationForwardCallTarget},
-                      m::Pad(m::Parameter(0), m::Op())
-                          .WithShape(S8, {10, 20, 30, 48, 4}),
-                      m::Pad(m::Parameter(1), m::Op())
-                          .WithShape(S8, {2, 2, 48, 4, 192}),
-                      m::Pad(m::Parameter(2), m::Op()).WithShape(F32, {32}),
-                      m::Pad(m::Parameter(3), m::Op())
-                          .WithShape(S8, {10, 20, 30, 48, 4})))
-                  .WithShape(S8, {10, 20, 30, 48, 4})),
-          m::Op())));
-}
-
-TEST_F(CudnnPadForConvolutionsTest,
-       PadIntFusedForwardConvInputAndOutputChannels) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule Test
-
-    ENTRY %Test (input: s8[1,3,3,2], filter: s8[3,3,2,5], side_input: s8[1,3,3,5], bias: s8[5]) -> f32[1,3,3,5] {
-    %input = s8[1,3,3,3]{3,2,1,0} parameter(0)
-    %filter = s8[3,3,2,5]{3,2,1,0} parameter(1)
-    %bias = s8[5]{0} parameter(3)
-    %convert = f32[5]{0} convert(s8[5]{0} %bias)
-    %side_input = f32[1,3,3,5]{3,2,1,0} parameter(2)
-    %custom-call.1 = (f32[1,3,3,5]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,3,3,3]{3,2,1,0} %input, s8[3,3,2,5]{3,2,1,0} %filter, f32[5]{0} %convert, f32[1,3,3,5]{3,2,1,0} %side_input), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"activationMode\":\"2\",\"convResultScale\":1,\"sideInputScale\":1}"
-    ROOT %get-tuple-element.1 = f32[1,3,3,5]{3,2,1,0} get-tuple-element((f32[1,3,3,5]{3,2,1,0}, u8[0]{0}) %custom-call.1), index=0
-    })")
-                    .value();
-  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, GmockMatch(m::GetTupleElement(m::Tuple(
-                        m::Slice(m::GetTupleElement(m::CustomCall(
-                            {kCudnnConvBiasActivationForwardCallTarget},
-                            m::Pad(m::Parameter(0), m::Op()),
-                            m::Pad(m::Parameter(1), m::Op()),
-                            m::Pad(m::Convert(m::Parameter(3)), m::Op()),
-                            m::Pad(m::Parameter(2), m::Op())))),
-                        m::Op()))));
-}
-
-}  // anonymous namespace
-}  // namespace gpu
-}  // namespace xla
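The shape expectations in the tests above all reduce to rounding a channel count up to the next multiple of the cuDNN vector size (4 for plain s8 convolutions, 32 for the int8x32 layout on sm75). A minimal standalone sketch of that arithmetic, using a hypothetical RoundUpTo helper rather than the pass's actual code:

#include <cstdint>
#include <iostream>

// Hypothetical helper, not part of CudnnPadForConvolutions: rounds n up to
// the next multiple of divisor (divisor > 0).
int64_t RoundUpTo(int64_t n, int64_t divisor) {
  return (n + divisor - 1) / divisor * divisor;
}

int main() {
  std::cout << RoundUpTo(41, 4) << "\n";           // 44, as in PadIntForwardConvInputChannels
  std::cout << RoundUpTo(40, 32) << "\n";          // 64, as in PadInt8To32OnSm75
  std::cout << RoundUpTo(41 * 4, 32) / 4 << "\n";  // 48, as in PadInt8x4To32 (vectorized layout)
  return 0;
}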
diff --git a/third_party/xla/xla/service/gpu/cudnn_simplify_padding.cc b/third_party/xla/xla/service/gpu/cudnn_simplify_padding.cc
deleted file mode 100644
index c8f87f7..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_simplify_padding.cc
+++ /dev/null
@@ -1,482 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_simplify_padding.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <iterator>
-#include <optional>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/literal.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-
-namespace {
-namespace m = ::xla::match;
-
-// If exactly one index of `dims` is false, returns that index.  If 0 or more
-// than one index is false, returns nullopt.
-std::optional<int64_t> FindFalseIndex(absl::Span<const bool> vals) {
-  std::optional<int64_t> missing_dim;
-  for (int i = 0; i < vals.size(); i++) {
-    if (vals[i]) {
-      continue;
-    }
-    if (missing_dim.has_value()) {
-      VLOG(2) << "Multiple dimensions are missing from conv dnums; can't "
-                 "determine which is vect_c dimension";
-      return std::nullopt;
-    }
-    missing_dim = i;
-  }
-  return missing_dim;
-}
-
-// Finds the vect_c dimension in the convolution's output.
-//
-// The vect_c dimension is the output dimension that's not mentioned in
-// `dnums`.  If there's zero or more than one such dimension, returns nullopt.
-std::optional<int64_t> FindOutputVectCDim(HloInstruction* conv) {
-  const ConvolutionDimensionNumbers& dnums =
-      conv->convolution_dimension_numbers();
-  int64_t num_dims = conv->shape().tuple_shapes(0).dimensions_size();
-  absl::InlinedVector<bool, 5> seen_dims(num_dims);
-  seen_dims[dnums.output_batch_dimension()] = true;
-  seen_dims[dnums.output_feature_dimension()] = true;
-  for (int64_t d : dnums.output_spatial_dimensions()) {
-    seen_dims[d] = true;
-  }
-  return FindFalseIndex(seen_dims);
-}
-
-// Finds the vect_c dimension in the convolution's kernel.
-std::optional<int64_t> FindKernelVectCDim(HloInstruction* conv) {
-  const ConvolutionDimensionNumbers& dnums =
-      conv->convolution_dimension_numbers();
-  int64_t num_dims = conv->operand(1)->shape().dimensions_size();
-  absl::InlinedVector<bool, 5> seen_dims(num_dims);
-  seen_dims[dnums.kernel_input_feature_dimension()] = true;
-  seen_dims[dnums.kernel_output_feature_dimension()] = true;
-  for (int64_t d : dnums.kernel_spatial_dimensions()) {
-    seen_dims[d] = true;
-  }
-  return FindFalseIndex(seen_dims);
-}
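Both helpers above identify the vect_c dimension purely by elimination: it is the single dimension index that the convolution dimension numbers never mention. A self-contained sketch of that idea on a plain NCHW_VECT_C-style output shape, with illustrative dimension numbers rather than a real ConvolutionDimensionNumbers proto:

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Returns the single unmarked index, or nullopt if there are zero or several
// candidates (mirroring FindFalseIndex above).
std::optional<int64_t> SoleUnseenIndex(const std::vector<bool>& seen) {
  std::optional<int64_t> unseen;
  for (int64_t i = 0; i < static_cast<int64_t>(seen.size()); ++i) {
    if (seen[i]) continue;
    if (unseen.has_value()) return std::nullopt;  // more than one candidate
    unseen = i;
  }
  return unseen;
}

int main() {
  // Output shape s8[N, C/32, H, W, 32]: batch=0, feature=1, spatial={2,3}.
  std::vector<bool> seen(5, false);
  seen[0] = seen[1] = seen[2] = seen[3] = true;
  if (std::optional<int64_t> vect_c = SoleUnseenIndex(seen)) {
    std::cout << "vect_c dimension: " << *vect_c << "\n";  // prints 4
  }
  return 0;
}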
-
-// Attempts to count the number of output features at the end of conv that are
-// guaranteed to be 0.
-//
-// This is the same as counting the number of values o at the end of the kernel
-// for which kernel[i,o,h,w] is 0 for all values i,h,w.
-std::optional<int64_t> NumTrailingZeroOutputFeatures(HloInstruction* conv) {
-  const ConvolutionDimensionNumbers& dnums =
-      conv->convolution_dimension_numbers();
-  int64_t feature_dim = dnums.kernel_output_feature_dimension();
-  const HloInstruction* weights = conv->operand(1);
-
-  // If the filter is reordered for an int8x32 NCHW_VECT_C convolution, find the
-  // original, un-reordered filter and check *it* for trailing zero output
-  // features.
-  auto backend_config = conv->backend_config<GpuBackendConfig>();
-  if (backend_config.ok() &&
-      backend_config->cudnn_conv_backend_config().reordered_int8_nchw_vect()) {
-    VLOG(2) << "Matched int8x32 convolution with filter reordering";
-
-    // Try to set weights to the original, un-reordered value.
-    const HloInstruction *reshape, *transpose;
-    bool matched =
-        Match(weights, m::Reshape(m::Transpose(
-                           &transpose, m::Reshape(&reshape, m::Op(&weights)))));
-
-    // Verify some properties of the reshape-transpose-reshape pattern.
-    // If these don't hold, it means that some pass (e.g. constant folding)
-    // has modified the filter, making it infeasible to get the original,
-    // un-reordered value.
-    if (!matched || feature_dim != 0 || transpose->shape().rank() != 8) {
-      VLOG(2) << "The filter output feature dimension cannot be determined, as "
-                 "the reordering sequence is modified";
-      return std::nullopt;
-    }
-
-    // Calculate the actual output feature dimension before the transpose.
-    // For example: the input filter [I, O, H, W] will get reshaped to
-    // [I/32, 8, 4, O/8, 4, 2, H, W], transposed in a way that is compatible
-    // with cuDNN INT8x32_CONFIG convolutions (see 'cudnn_support_utils.h') and
-    // reshaped again to [O, I/32, H, W, 32]. In that final shape the output
-    // features dimension is 0, but we need its index in the original shape
-    // (which is 1 in this example).
-    const auto& transpose_dimensions =
-        Cast<HloTransposeInstruction>(transpose)->dimensions();
-
-    // Calculate combined dimensions size before the first appearing output
-    // component [O/8], which appears in position 3 of the transpose.
-    int64_t preceding_size = 1;
-    for (int64_t i = transpose_dimensions.at(3) - 1; i >= 0; --i) {
-      preceding_size *= reshape->shape().dimensions(i);
-    }
-
-    // Skip dimensions in front until the output features dimension is found.
-    int64_t accumulated_size = 1;
-    for (int64_t size : weights->shape().dimensions()) {
-      if (accumulated_size < preceding_size) {
-        accumulated_size *= size;
-        ++feature_dim;
-      } else {
-        break;
-      }
-    }
-    // Sanity check; if this condition doesn't hold, something is off.
-    if (accumulated_size != preceding_size) {
-      VLOG(2) << "Something is really wrong here, I give up";
-      return std::nullopt;
-    }
-    VLOG(2) << "Computed output feature dimension: " << feature_dim;
-  }
-
-  VLOG(2) << "Computing NumTrailingZeroOutputFeatures of " << conv->ToString()
-          << "\nwith weights " << weights->ToString();
-  if (Match(weights, m::Pad(m::Op(), m::ConstantEffectiveScalar(0)))) {
-    const PaddingConfig::PaddingConfigDimension& padding_config =
-        weights->padding_config().dimensions(feature_dim);
-    // The last N output feature weights are all 0.
-    VLOG(2) << "Success: Weights is a pad; padding on output feature dim is "
-            << padding_config.edge_padding_high();
-    return padding_config.edge_padding_high();
-  } else if (const HloInstruction * pad; Match(
-                 weights, m::Reshape(m::Pad(&pad, m::Op(),
-                                            m::ConstantEffectiveScalar(0))))) {
-    // Check that the reshape merely adds a VECT_C to the kernel input features.
-    // That is, we reshape from [I,O,H,W] (in some order) to [I/k,k,O,H,W] (in
-    // the same order) for some constant k (probably 32).  Then check how much
-    // the pad adds to the O dimension.
-    std::optional<int64_t> vect_c_dim = FindKernelVectCDim(conv);
-    if (!vect_c_dim.has_value()) {
-      VLOG(2) << "fail: Can't find vect_c dimension in conv.";
-      return std::nullopt;
-    }
-    if (*vect_c_dim != dnums.kernel_input_feature_dimension() + 1) {
-      VLOG(2) << "fail: vect_c dim is in the wrong place; should be right "
-                 "after kernel input feature dims in conv.";
-      return std::nullopt;
-    }
-    absl::InlinedVector<int64_t, 5> expected_pad_dim_sizes(
-        weights->shape().dimensions().begin(),
-        weights->shape().dimensions().end());
-    expected_pad_dim_sizes[dnums.kernel_input_feature_dimension()] *=
-        weights->shape().dimensions(*vect_c_dim);
-    expected_pad_dim_sizes.erase(expected_pad_dim_sizes.begin() + *vect_c_dim);
-    if (pad->shape().dimensions() != expected_pad_dim_sizes) {
-      VLOG(2) << "fail: Reshape doesn't simply merge vect_c dimension into "
-                 "input features dim "
-              << weights->ToString() << " but expected dims "
-              << absl::StrJoin(expected_pad_dim_sizes, ",");
-      return std::nullopt;
-    }
-
-    // If the filter dnums are e.g. [I,O,H,W] then after reshape they are
-    // [I/k,k,O,H,W] and the new index of O is one greater than before the
-    // reshape (which we know only adds the I/k and k dims, which we also know
-    // are contiguous).  OTOH if the O comes before the I in the original, then
-    // the index of O doesn't change after the reshape.
-    int64_t feature_dim_before_reshape = feature_dim;
-    if (dnums.kernel_output_feature_dimension() >
-        dnums.kernel_input_feature_dimension()) {
-      feature_dim_before_reshape--;
-    }
-    const PaddingConfig::PaddingConfigDimension& padding_config =
-        pad->padding_config().dimensions(feature_dim_before_reshape);
-
-    // The last N output feature weights are all 0.
-    VLOG(2) << "Success: Weights is a reshape of a pad; padding on output "
-               "feature dim is "
-            << padding_config.edge_padding_high();
-    return padding_config.edge_padding_high();
-  } else if (Match(weights, m::Constant())) {
-    // Iterate backwards over `weights` to find the index of the first nonzero
-    // value.
-    //
-    // TODO(jlebar): This may be slow, because it iterates over potentially the
-    // whole constant and does a multi_index -> linear_index conversion for each
-    // element. If necessary we could rewrite this by using linear indices, but
-    // we'd have to be careful of the fact that literals can have arbitrary
-    // layouts, so you can't just iterate over the literal's bytes.
-    const Literal& lit = weights->literal();
-    const auto& dims = weights->shape().dimensions();
-    absl::InlinedVector<int64_t, 5> multi_index;
-    for (int64_t dim : dims) {
-      multi_index.push_back(dim - 1);
-    }
-    // This iterates through the literal with feature_dim as the most
-    // major dimension looking for the final non-zero feature.
-    auto decrement_multi_index = [&] {
-      for (int i = 0; i < multi_index.size(); ++i) {
-        if (i != feature_dim) {
-          int64_t& idx = multi_index[i];
-          --idx;
-          if (idx == -1) {
-            idx = dims[i] - 1;
-          } else {
-            return true;
-          }
-        }
-      }
-      int64_t& idx = multi_index[feature_dim];
-      --idx;
-      return idx != -1;
-    };
-    do {
-      if (!lit.IsZero(multi_index)) {
-        break;
-      }
-    } while (decrement_multi_index());
-
-    // The scan stops at the highest-index feature containing a non-zero value
-    // (or at index -1 if the constant is all zeros); the first trailing-zero
-    // feature is therefore the next index (or 0 in the all-zeros case).
-    int64_t first_trailing_zero_feature = multi_index[feature_dim] + 1;
-
-    if (first_trailing_zero_feature == 0) {
-      VLOG(2) << "Weights constant is entirely zero.";
-    } else {
-      VLOG(2) << "First nonzero index in weights constant is "
-              << absl::StrJoin(multi_index, ",");
-    }
-    int64_t ret =
-        std::max<int64_t>(0, weights->shape().dimensions(feature_dim) -
-                                 first_trailing_zero_feature);
-    VLOG(2) << "Success: weights is a constant; num zero trailing output "
-               "features is "
-            << ret;
-    return ret;
-  }
-  return std::nullopt;
-}
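Stripped of the reorder/reshape handling, the constant branch above answers a simple question: how many trailing positions along the output-feature axis are zero for every other index? A minimal sketch of that on a dense 2-D kernel[input][output], as an illustration of the idea rather than the literal-walking code above:

#include <cstdint>
#include <iostream>
#include <vector>

// Counts trailing output features o for which kernel[i][o] == 0 for every i.
int64_t TrailingZeroOutputFeatures(
    const std::vector<std::vector<int8_t>>& kernel) {
  if (kernel.empty()) return 0;
  const int64_t num_outputs = static_cast<int64_t>(kernel[0].size());
  int64_t count = 0;
  for (int64_t o = num_outputs - 1; o >= 0; --o) {
    bool all_zero = true;
    for (const auto& row : kernel) {
      if (row[o] != 0) { all_zero = false; break; }
    }
    if (!all_zero) break;
    ++count;
  }
  return count;
}

int main() {
  // Two input features, six output features; the last two columns are zero.
  std::vector<std::vector<int8_t>> kernel = {
      {1, 2, 3, 4, 0, 0},
      {5, 6, 7, 8, 0, 0},
  };
  std::cout << TrailingZeroOutputFeatures(kernel) << "\n";  // prints 2
  return 0;
}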
-
-absl::StatusOr<bool> TrySimplifyPadding(HloInstruction* instr) {
-  // Match one of the following patterns.
-  //   conv -> slice -> pad
-  //   conv -> reshape -> slice-> pad
-  //   conv -> transpose -> reshape -> slice -> pad
-  //
-  // where `pad` (the root of the pattern) is `instr`.
-  HloInstruction* conv;
-  HloInstruction* transpose = nullptr;  // optional
-  HloInstruction* reshape = nullptr;    // optional
-  HloInstruction* slice;
-  HloInstruction* pad;
-  auto conv_matcher = m::GetTupleElement(
-      m::CustomCall(&conv).WithPredicate([](const HloInstruction* instr) {
-        return instr->custom_call_target() == kCudnnConvForwardCallTarget ||
-               instr->custom_call_target() ==
-                   kCudnnConvBiasActivationForwardCallTarget;
-      }),
-      0);
-  auto pad_matcher = m::Pad(m::Op(), m::ConstantEffectiveScalar(0));
-  if (!MatchAndLogIfFailed(instr, "conv-slice-pad",
-                           m::Pad(&pad, m::Slice(&slice, conv_matcher),
-                                  m::ConstantEffectiveScalar(0)),
-                           VLOG_IS_ON(3), pad_matcher) &&
-      !MatchAndLogIfFailed(
-          instr, "conv-reshape-slice-pad",
-          m::Pad(&pad, m::Slice(&slice, m::Reshape(&reshape, conv_matcher)),
-                 m::ConstantEffectiveScalar(0)),
-          VLOG_IS_ON(3), pad_matcher) &&
-      !MatchAndLogIfFailed(
-          instr, "conv-transpose-reshape-slice-pad",
-          m::Pad(&pad,
-                 m::Slice(&slice,
-                          m::Reshape(&reshape,
-                                     m::Transpose(&transpose, conv_matcher))),
-                 m::ConstantEffectiveScalar(0)),
-          VLOG_IS_ON(3), pad_matcher)) {
-    return false;
-  }
-
-  VLOG(2) << "Found pattern to attempt to simplify:\n"
-          << "conv: " << conv->ToString()  //
-          << "\ntranspose: "
-          << (transpose != nullptr ? transpose->ToString() : "(null)")
-          << "\nreshape: "
-          << (reshape != nullptr ? reshape->ToString() : "(null)")
-          << "\nslice: " << slice->ToString()  //
-          << "\npad: " << pad->ToString();
-
-  // Now check that we can merge the slice into the pad, because the slice is
-  // slicing off elements that we know are 0 and the pad is just adding those 0s
-  // back.
-  //
-  // First, we have to check whether any of the output features at the end of
-  // the conv are known to be 0.
-  std::optional<int64_t> num_known_zero_output_features =
-      NumTrailingZeroOutputFeatures(conv);
-  if (!num_known_zero_output_features.has_value() ||
-      *num_known_zero_output_features == 0) {
-    VLOG(2) << "fail: Didn't find any known-zero output features";
-    return false;
-  }
-
-  // We now know that the last *num_known_zero_output_features output features
-  // of the conv are zero.  Check if the
-  // optional-reshape + optional-transpose + slice + pad combination is setting
-  // all of these features to 0.  If so, we can merge the slice into the pad.
-  const auto& dnums = conv->convolution_dimension_numbers();
-  int64_t output_feature_dim;
-  if (reshape == nullptr) {
-    CHECK_EQ(transpose, nullptr);
-    output_feature_dim = dnums.output_feature_dimension();
-  } else {
-    std::optional<int64_t> vect_c_dim_before_transpose =
-        FindOutputVectCDim(conv);
-    if (!vect_c_dim_before_transpose.has_value()) {
-      VLOG(2) << "Couldn't find vect_c output dim in conv.";
-      return false;
-    }
-
-    // If there's no transpose, check that the vect_c dim is immediately after
-    // the feature dim.  OTOH if there is a transpose, check that the transpose
-    // moves the vect_c dim immediately after the feature dim.
-    int64_t feature_dim_after_transpose;
-    int64_t vect_c_dim_after_transpose;
-    if (transpose == nullptr) {
-      feature_dim_after_transpose = dnums.output_feature_dimension();
-      vect_c_dim_after_transpose = *vect_c_dim_before_transpose;
-    } else {
-      const auto& transpose_dims = transpose->dimensions();
-      feature_dim_after_transpose = std::distance(
-          transpose->dimensions().begin(),
-          absl::c_find(transpose_dims, dnums.output_feature_dimension()));
-      vect_c_dim_after_transpose = std::distance(
-          transpose->dimensions().begin(),
-          absl::c_find(transpose_dims, *vect_c_dim_before_transpose));
-    }
-    if (vect_c_dim_after_transpose != feature_dim_after_transpose + 1) {
-      VLOG(2) << "fail: after transpose (if present), vect_c dim must appear "
-                 "immediately after output feature dim: Computed "
-                 "vect_c_dim_after_transpose to be "
-              << vect_c_dim_after_transpose;
-      return false;
-    }
-
-    // Now check that the reshape merges the feature + vect_c dims and
-    // doesn't do anything else.
-    absl::InlinedVector<int64_t, 5> expected_reshape_dim_sizes(
-        reshape->operand(0)->shape().dimensions().begin(),
-        reshape->operand(0)->shape().dimensions().end());
-    expected_reshape_dim_sizes[feature_dim_after_transpose] *=
-        expected_reshape_dim_sizes[vect_c_dim_after_transpose];
-    expected_reshape_dim_sizes.erase(expected_reshape_dim_sizes.begin() +
-                                     vect_c_dim_after_transpose);
-    if (reshape->shape().dimensions() != expected_reshape_dim_sizes) {
-      VLOG(2) << "fail: Reshape doesn't merge vect_c with feature dimension.";
-      return false;
-    }
-
-    output_feature_dim = feature_dim_after_transpose;
-  }
-
-  // Check that `slice` slices only the output feature dimension.
-  if (!absl::c_all_of(slice->slice_starts(), [](auto v) { return v == 0; }) ||
-      !absl::c_all_of(slice->slice_strides(), [](auto v) { return v == 1; })) {
-    VLOG(2) << "fail: Slice doesn't start at the front or has stride != 1.";
-    return false;
-  }
-
-  // We're only allowed to slice the feature dim.
-  for (int64_t dim = 0; dim < slice->slice_limits().size(); dim++) {
-    if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 ||
-        (dim != output_feature_dim &&
-         slice->slice_limits(dim) !=
-             slice->operand(0)->shape().dimensions(dim))) {
-      VLOG(2) << "fail: Slice removes something other than the features dim.";
-      return false;
-    }
-  }
-  int64_t num_sliced_from_feature_dim =
-      slice->operand(0)->shape().dimensions(output_feature_dim) -
-      slice->slice_limits(output_feature_dim);
-
-  // If we slice off more than the known-zero output features, then we need to
-  // keep the slice -- it's "doing something".
-  if (num_sliced_from_feature_dim > *num_known_zero_output_features) {
-    VLOG(2) << "fail: Slice removes " << num_sliced_from_feature_dim
-            << " features from the conv, but only "
-            << *num_known_zero_output_features
-            << " features in the conv are known to be zero.";
-    return false;
-  }
-
-  // Check if we can merge the slice into the pad.
-  if (pad->padding_config().dimensions(output_feature_dim).interior_padding() !=
-      0) {
-    VLOG(2)
-        << "fail: Can't merge slice into pad because pad adds interior padding "
-           "in feature dimension.";
-    return false;
-  }
-
-  // Okay!  If we got here, it's legal to fold the slice into the pad.  We pad
-  // less, because we know that the sliced-off elements are all 0.  Ideally, the
-  // pad becomes a nop and gets eliminated by algsimp later.
-  VLOG(1) << "Eliminating " << num_sliced_from_feature_dim
-          << " elements of padding from conv " << conv->name();
-  PaddingConfig new_padding_config = pad->padding_config();
-  PaddingConfig::PaddingConfigDimension* new_pad_feature_dim =
-      new_padding_config.mutable_dimensions(output_feature_dim);
-  // This is safe even if the new edge_padding_high is negative -- negative
-  // padding is allowed.
-  new_pad_feature_dim->set_edge_padding_high(
-      new_pad_feature_dim->edge_padding_high() - num_sliced_from_feature_dim);
-  TF_ASSIGN_OR_RETURN(HloInstruction * new_pad,
-                      MakePadHlo(slice->mutable_operand(0),
-                                 pad->mutable_operand(1), new_padding_config));
-  TF_RETURN_IF_ERROR(pad->parent()->ReplaceInstruction(pad, new_pad));
-  return true;
-}
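The rewrite at the end of TrySimplifyPadding is, in the end, simple bookkeeping: drop the slice and reduce the pad's high edge padding by the number of known-zero features the slice had removed, possibly driving it negative. A sketch of just that arithmetic on a single feature dimension, using hypothetical names that are not the pass's API:

#include <cstdint>
#include <iostream>

// Hypothetical stand-in for one PaddingConfigDimension.
struct FeatureDimPad {
  int64_t edge_padding_low = 0;
  int64_t edge_padding_high = 0;
};

// Folds "slice off num_sliced known-zero features, then pad" into a single
// pad. A negative result is legal; algsimp later turns it back into a slice.
FeatureDimPad MergeSliceIntoPad(FeatureDimPad pad, int64_t num_sliced) {
  pad.edge_padding_high -= num_sliced;
  return pad;
}

int main() {
  // As in the PaddedWeights test later in this diff: 4 known-zero features
  // sliced off and 5 added back by the pad leaves a net high padding of 1.
  FeatureDimPad pad{0, 5};
  std::cout << MergeSliceIntoPad(pad, 4).edge_padding_high << "\n";  // 1
  return 0;
}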
-
-}  // anonymous namespace
-
-absl::StatusOr<bool> CudnnSimplifyPadding::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool c, TrySimplifyPadding(instr));
-      changed |= c;
-    }
-  }
-  return changed;
-}
-
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/cudnn_simplify_padding.h b/third_party/xla/xla/service/gpu/cudnn_simplify_padding.h
deleted file mode 100644
index 5811d26..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_simplify_padding.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_
-#define XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla::gpu {
-
-// Simplifies or eliminates padding introduced by CudnnPadForConvolutions and
-// CudnnVectorizeConvolutions.
-//
-// CudnnVectorizeConvolutions will generate code that does the following.
-//  - pad input and output features to a multiple of 32 (or 4),
-//  - reshape input from [N,C,H,W] to [N,C/32,H,W,32] and reshape kernel from
-//    [I,O,H,W] to [I/32,32,O,H,W],
-//  - run the conv,
-//  - reshape output from [N,C/32,H,W,32] to [N,C,H,W], and finally
-//  - slice off the padding on the C channel.
-//
-// But if this is followed by another convolution (very common), then the slice
-// is immediately followed by another pad. This may be redundant; we know that
-// the trailing channels sliced off from the first conv are 0.
-//
-// Ideally we can eliminate the whole reshape+slice+pad+reshape sequence between
-// the two convolutions.
-//
-// Specifically, this pass tries to merge the slice at the end of the sequence
-// above into the pad from the next convolution (when we can prove that the
-// sliced-off elements are all 0). We then rely on algsimp to remove the pad if
-// it's a nop and then to merge and eliminate the remaining reshapes.
-//
-// This pass should run after CudnnVectorizeConvolutions and there should be no
-// simplification passes in between that modify the reshape-transpose-reshape
-// introduced by int8x32 convolution filter reordering.
-class CudnnSimplifyPadding : public HloModulePass {
- public:
-  CudnnSimplifyPadding() = default;
-
-  absl::string_view name() const override { return "cudnn_simplify_padding"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_
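The comment above describes where CudnnSimplifyPadding has to sit relative to CudnnVectorizeConvolutions; the test fixture in the next file runs that sequence by hand. As a sketch of that ordering, assuming only the pass constructors used elsewhere in this diff (the compute capability and cuDNN version are illustrative values, and RunCudnnPaddingPipeline is a hypothetical helper, not an existing XLA entry point):

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/algebraic_simplifier.h"
#include "xla/service/call_inliner.h"
#include "xla/service/gpu/cudnn_pad_for_convolutions.h"
#include "xla/service/gpu/cudnn_simplify_padding.h"
#include "xla/service/gpu/cudnn_vectorize_convolutions.h"
#include "xla/service/hlo_pass_fix.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/service/tuple_simplifier.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/dnn.h"

namespace xla::gpu {

// Runs the padding-related passes in the order CudnnSimplifyPadding relies on:
// pad, vectorize, inline, tuple-simplify, simplify padding, then algsimp.
absl::StatusOr<bool> RunCudnnPaddingPipeline(HloModule* module) {
  se::CudaComputeCapability cc{7, 5};           // illustrative value
  se::dnn::VersionInfo cudnn_version{8, 3, 0};  // illustrative value
  HloPassPipeline pipeline("cudnn-padding-sketch");
  pipeline.AddPass<CudnnPadForConvolutions>(cc);
  pipeline.AddPass<CudnnVectorizeConvolutions>(cc, cudnn_version);
  pipeline.AddPass<CallInliner>();
  pipeline.AddPass<TupleSimplifier>();
  pipeline.AddPass<CudnnSimplifyPadding>();
  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(AlgebraicSimplifierOptions());
  return pipeline.Run(module);
}

}  // namespace xla::gpu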
diff --git a/third_party/xla/xla/service/gpu/cudnn_simplify_padding_test.cc b/third_party/xla/xla/service/gpu/cudnn_simplify_padding_test.cc
deleted file mode 100644
index a0e527a..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_simplify_padding_test.cc
+++ /dev/null
@@ -1,771 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_simplify_padding.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/functional/function_ref.h"
-#include "absl/strings/str_cat.h"
-#include "absl/types/span.h"
-#include "xla/literal.h"
-#include "xla/service/algebraic_simplifier.h"
-#include "xla/service/call_inliner.h"
-#include "xla/service/gpu/cudnn_pad_for_convolutions.h"
-#include "xla/service/gpu/cudnn_vectorize_convolutions.h"
-#include "xla/service/hlo_pass_fix.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/service/reshape_mover.h"
-#include "xla/service/tuple_simplifier.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class CudnnSimplifyPaddingTest : public HloTestBase {
- protected:
-  // Runs the whole relevant pass pipeline starting at CudnnPadForConvolutions.
-  // This lets us test that we're matching the patterns that actually get
-  // generated by padding+vectorization.
-  absl::StatusOr<bool> RunEndToEnd(std::pair<int, int> compute_capability,
-                                   HloModule* module) {
-    se::CudaComputeCapability cc{compute_capability.first,
-                                 compute_capability.second};
-
-    TF_RETURN_IF_ERROR(
-        RunHloPass(CudnnPadForConvolutions(cc), module).status());
-
-    TF_RETURN_IF_ERROR(
-        RunHloPass(CudnnVectorizeConvolutions(
-                       cc, /*cudnn_version=*/se::dnn::VersionInfo{8, 3, 0}),
-                   module)
-            .status());
-    VLOG(1) << "after vectorizing convs:\n" << module->ToString();
-
-    TF_RETURN_IF_ERROR(RunHloPass(CallInliner(), module).status());
-    VLOG(1) << "after inliner:\n" << module->ToString();
-
-    TF_RETURN_IF_ERROR(RunHloPass(TupleSimplifier(), module).status());
-    VLOG(1) << "after tuple simplifier:\n" << module->ToString();
-
-    TF_ASSIGN_OR_RETURN(bool changed,
-                        RunHloPass(CudnnSimplifyPadding(), module));
-    VLOG(1) << "after simplify_padding:\n" << module->ToString();
-
-    {
-      // reshape-mover expects to be run alongside algsimp.
-      HloPassFix<HloPassPipeline> pipeline("reshape-mover and algsimp");
-      pipeline.AddPass<ReshapeMover>();
-      pipeline.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions());
-      TF_RETURN_IF_ERROR(RunHloPass(pipeline, module).status());
-    }
-    VLOG(1) << "after reshape mover + algsimp:\n" << module->ToString();
-
-    return changed;
-  }
-
-  absl::StatusOr<bool> RunJustThisPass(HloModule* module) {
-    TF_ASSIGN_OR_RETURN(bool changed,
-                        RunHloPass(CudnnSimplifyPadding(), module));
-    VLOG(1) << "after simplify_padding:\n" << module->ToString();
-
-    // I know the name says "just this pass", but you really want algsimp too,
-    // otherwise the resulting patterns are ugly/hard to match.
-    TF_RETURN_IF_ERROR(RunHloPass(HloPassFix<AlgebraicSimplifier>(
-                                      AlgebraicSimplifierOptions()),
-                                  module)
-                           .status());
-    return changed;
-  }
-};
-
-void ExpectOnlyPadsOneDim(int64_t dim, int64_t padding_high,
-                          const PaddingConfig& p) {
-  SCOPED_TRACE(p.DebugString());
-  for (int i = 0; i < p.dimensions_size(); ++i) {
-    SCOPED_TRACE(absl::StrCat("dimension ", i));
-    EXPECT_EQ(p.dimensions(i).edge_padding_low(), 0);
-    if (i == dim) {
-      EXPECT_EQ(p.dimensions(i).edge_padding_high(), padding_high);
-    } else {
-      EXPECT_EQ(p.dimensions(i).edge_padding_high(), 0);
-    }
-  }
-}
-
-template <typename NativeT>
-void SetConstantValue(
-    HloInstruction* instr,
-    absl::FunctionRef<NativeT(absl::Span<const int64_t>, NativeT)> value_fn) {
-  Literal new_literal = instr->literal().Clone();
-  new_literal.MutableEachCell<int8_t>(value_fn);
-  TF_EXPECT_OK(instr->parent()->ReplaceWithNewInstruction(
-      instr, HloInstruction::CreateConstant(std::move(new_literal))));
-}
-
-TEST_F(CudnnSimplifyPaddingTest, EndToEnd) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    conv1 = (s8[10,20,30,190], u8[0]) custom-call(
-        s8[10,20,30,63] parameter(0), s8[3,5,63,190] parameter(1),
-        f32[10] parameter(2), s8[10,20,30,190] parameter(3)),
-      window={size=3x5}, dim_labels=b01f_01io->b01f,
-      custom_call_target="__cudnn$convBiasActivationForward"
-    conv1_result = get-tuple-element(conv1), index=0
-    ROOT conv2 = (s8[10,20,30,29], u8[0]) custom-call(
-        conv1_result, s8[3,5,190,29] parameter(4),
-        f32[10] parameter(5), s8[10,20,30,29] parameter(6)),
-      window={size=3x5}, dim_labels=b01f_01io->b01f,
-      custom_call_target="__cudnn$convBiasActivationForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  // conv2 should be fed directly from conv1, without any intervening
-  // reshapes/pads.
-  EXPECT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Slice(m::Reshape(m::GetTupleElement(m::CustomCall(
-              {"__cudnn$convBiasActivationForward"},
-              m::GetTupleElement(
-                  m::CustomCall({"__cudnn$convBiasActivationForward"}), 0),
-              m::Op(), m::Op(), m::Op())))),
-          m::Op())));
-}
-
-TEST_F(CudnnSimplifyPaddingTest, EndToEndNCHW) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    conv1 = (s8[1,64,480,400], u8[0]) custom-call(
-        s8[1,112,480,400] parameter(0), s8[3,3,112,64] parameter(1),
-        f32[64] parameter(2)),
-      window={size=3x3}, dim_labels=bf01_01io->bf01,
-      custom_call_target="__cudnn$convBiasActivationForward"
-    conv1_result = get-tuple-element(conv1), index=0
-    convert = f32[1,64,480,400] convert(conv1_result)
-    constant = f32[] constant(0.349002093)
-    broadcast = f32[1,64,480,400] broadcast(constant)
-    ROOT multiply = f32[1,64,480,400] multiply(convert, broadcast)
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
-  // The SimplifyPadding pass itself does not do anything.
-  EXPECT_FALSE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  // The reshape introduced by CudnnVectorizeConvolutions should have been moved
-  // to the root.
-  EXPECT_THAT(root, GmockMatch(m::Reshape(m::Multiply())));
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PaddedWeights) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
-      conv = (s8[10,10,10,10], u8[0]) custom-call(
-          s8[10,10,10,10] parameter(1),
-          weights
-        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-  const HloInstruction* pad = nullptr;
-  ASSERT_THAT(root,
-              GmockMatch(m::Pad(&pad, m::GetTupleElement(m::CustomCall(), 0),
-                                m::ConstantScalar(0))));
-
-  ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
-}
-
-// This is similar to PaddedWeights, except that only 3 elements of the weights
-// are padded to 0 while we slice off 4 elements from the output features. As a
-// result, not all of the sliced elements are 0, and we can't merge the slice
-// into the pad that follows.
-TEST_F(CudnnSimplifyPaddingTest, PaddedWeightsNotPaddedEnough) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_3
-      conv = (s8[10,10,10,10], u8[0]) custom-call(
-          s8[10,10,10,10] parameter(1),
-          weights
-        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PaddedAndReshapedWeightsNCHW) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
-      weights = s8[2,32,64,3,3] reshape(weights_p)
-      conv = (s8[10,2,32,10,10], u8[0]) custom-call(
-          s8[10,2,32,10,10] parameter(1),
-          weights
-        ), window={size=3x3}, dim_labels=bf?01_i?o01->bf?01,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_result)), slice={[0:10], [0:60], [0:10], [0:10]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_5x0_0x0_0
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-  const HloInstruction* pad = nullptr;
-  ASSERT_THAT(
-      root, GmockMatch(
-                m::Pad(&pad, m::Reshape(m::GetTupleElement(m::CustomCall(), 0)),
-                       m::ConstantScalar(0))));
-
-  ExpectOnlyPadsOneDim(/*dim=*/1, /*padding_high=*/1, pad->padding_config());
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PaddedAndReshapedWeightsNHWC) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights_p = pad(s8[3,3,64,60] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
-      weights = s8[3,3,2,32,64] reshape(weights_p)
-      conv = (s8[10,10,10,2,32], u8[0]) custom-call(
-          s8[10,10,10,2,32] parameter(1),
-          weights
-        ), window={size=3x3}, dim_labels=b01f?_01i?o->b01f?,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      slice = s8[10,10,10,60] slice(s8[10,10,10,64] reshape(conv_result)), slice={[0:10], [0:10], [0:10], [0:60]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-  const HloInstruction* pad = nullptr;
-  ASSERT_THAT(
-      root, GmockMatch(
-                m::Pad(&pad, m::Reshape(m::GetTupleElement(m::CustomCall(), 0)),
-                       m::ConstantScalar(0))));
-
-  ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PaddedTransposedAndReshapedOutput) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
-      weights = s8[2,32,64,3,3] reshape(weights_p)
-      conv = (s8[10,2,10,10,32], u8[0]) custom-call(
-          s8[10,2,10,10,32] parameter(1),
-          weights
-        ), window={size=3x3}, dim_labels=bf01?_i?o01->bf01?,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      conv_transposed = s8[10,2,32,10,10] transpose(conv_result), dimensions={0,1,4,2,3}
-      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-  const HloInstruction* pad = nullptr;
-  ASSERT_THAT(
-      root,
-      GmockMatch(m::Pad(
-          &pad,
-          m::Reshape(m::Transpose(m::GetTupleElement(m::CustomCall(), 0))),
-          m::ConstantScalar(0))));
-
-  ExpectOnlyPadsOneDim(/*dim=*/1, /*padding_high=*/2, pad->padding_config());
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PaddedConstantWeight) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      conv = (s8[10,10,10,10], u8[0]) custom-call(
-          s8[10,10,10,10] parameter(0),
-          s8[3,3,10,10] constant({...})
-        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
-    }
-  )")
-                    .value();
-
-  // Set the constant's value.  (The HLO text above sets it to all 0s.)
-  {
-    HloInstruction* weights = nullptr;
-    ASSERT_THAT(module->entry_computation()->root_instruction(),
-                GmockMatch(m::Pad(m::Slice(m::GetTupleElement(m::CustomCall(
-                                      m::Op(), m::Constant(&weights)))),
-                                  m::Op())));
-    SetConstantValue<int8_t>(
-        weights, [](absl::Span<const int64_t> dims, int8_t old_val) -> int8_t {
-          if (dims[3] < 6) return 1;
-          return 0;
-        });
-  }
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-  const HloInstruction* pad = nullptr;
-  ASSERT_THAT(root,
-              GmockMatch(m::Pad(&pad, m::GetTupleElement(m::CustomCall(), 0),
-                                m::ConstantScalar(0))));
-
-  ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PaddedConstantWeightIsNotLargeEnough) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      conv = (s8[10,10,10,10], u8[0]) custom-call(
-          s8[10,10,10,10] parameter(0),
-          s8[3,3,10,10] constant({...})
-        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
-    }
-  )")
-                    .value();
-
-  // Set the constant's value.  (The HLO text above sets it to all 0s.)
-  {
-    HloInstruction* weights = nullptr;
-    ASSERT_THAT(module->entry_computation()->root_instruction(),
-                GmockMatch(m::Pad(m::Slice(m::GetTupleElement(m::CustomCall(
-                                      m::Op(), m::Constant(&weights)))),
-                                  m::Op())));
-    SetConstantValue<int8_t>(
-        weights, [](absl::Span<const int64_t> dims, int8_t old_val) -> int8_t {
-          // The sixth feature dimension (i.e. index 5) is only partially 0.
-          if (dims[3] < 5 /*|| (dims[3] == 5 && dims[2] > 1)*/) return 0;
-          return 1;
-        });
-  }
-
-  // Some of the values sliced off are not 0, so we can't merge the slice into
-  // the pad.
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, ReshapeDoesntMergeVectCDim) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
-      weights = s8[2,64,3,3,32] reshape(weights_p)
-      conv = (s8[10,2,10,10,32], u8[0]) custom-call(
-          s8[10,2,10,10,32] parameter(1),
-          weights_p
-        ), window={size=3x3}, dim_labels=bf01?_io01?->bf01?,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_result)), slice={[0:10], [0:60], [0:10], [0:10]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, TwoVectCDimsInOutput) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
-      weights = s8[2,64,3,3,32] reshape(weights_p)
-      conv = (s8[10,2,10,10,4,8], u8[0]) custom-call(
-          s8[10,2,10,10,32] parameter(1),
-          weights
-        ), window={size=3x3}, dim_labels=bf01?_io01?->bf01??,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      conv_transposed = s8[10,2,4,8,10,10] transpose(conv_result), dimensions={0,1,4,5,2,3}
-      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, TwoVectCDimsInKernel) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
-      weights = s8[2,64,3,3,4,8] reshape(weights_p)
-      conv = (s8[10,2,10,10,32], u8[0]) custom-call(
-          s8[10,2,10,10,32] parameter(1),
-          weights
-        ), window={size=3x3}, dim_labels=bf01?_io01??->bf01?,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      conv_transposed = s8[10,2,32,10,10] transpose(conv_result), dimensions={0,1,4,2,3}
-      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, SliceDoesntStartAtBeginning) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
-      conv = (s8[10,10,10,10], u8[0]) custom-call(
-          s8[10,10,10,10] parameter(1),
-          weights
-        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      slice = s8[10,9,10,6] slice(conv_result), slice={[0:10], [1:10], [0:10], [0:6]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, SliceDoesntStartAtBeginningOfFeatureDim) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
-      conv = (s8[10,10,10,10], u8[0]) custom-call(
-          s8[10,10,10,10] parameter(1),
-          weights
-        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      slice = s8[10,10,10,5] slice(conv_result), slice={[0:10], [0:10], [0:10], [1:6]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, SliceHasStride) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
-      conv = (s8[10,10,10,10], u8[0]) custom-call(
-          s8[10,10,10,10] parameter(1),
-          weights
-        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      slice = s8[10,10,10,3] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6:2]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, PadAddsInteriorPadding) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
-      conv = (s8[10,10,10,10], u8[0]) custom-call(
-          s8[10,10,10,10] parameter(1),
-          weights
-        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5_1
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, SliceMoreElementsThanPad) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule TestModule
-
-    ENTRY TestComputation {
-      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
-      conv = (s8[10,10,10,10], u8[0]) custom-call(
-          s8[10,10,10,10] parameter(1),
-          weights
-        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
-        custom_call_target="__cudnn$convForward"
-      conv_result = get-tuple-element(conv), index=0
-      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
-      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_2
-    }
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-  const HloInstruction* slice = nullptr;
-  // The pass creates a pad with negative padding; this is simplified by algsimp
-  // into a slice.
-  ASSERT_THAT(root, GmockMatch(m::Slice(
-                        &slice, m::GetTupleElement(m::CustomCall(), 0))));
-  for (int64_t i = 0; i < slice->shape().dimensions_size(); ++i) {
-    SCOPED_TRACE(i);
-    EXPECT_EQ(slice->slice_starts(i), 0);
-    EXPECT_EQ(slice->slice_strides(i), 1);
-    if (i != 3) {
-      EXPECT_EQ(slice->slice_limits(i), 10);
-    } else {
-      EXPECT_EQ(slice->slice_limits(i), 8);
-    }
-  }
-}
-
-TEST_F(CudnnSimplifyPaddingTest, NoChangeOnNonTrivialConstants) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule jit_outer
-
-ENTRY main.26 {
-  reshape.2 = f32[1,3,3,12]{3,2,1,0} parameter(0)
-  constant.1 = f32[3,3,1,12]{3,2,1,0} constant({ {
-    { /*i1=0*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=1*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
-  }, {
-    { /*i1=0*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=1*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
-  }, {
-    { /*i1=0*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=1*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } } } })
-  cudnn-conv = (f32[1,5,5,12]{3,2,1,0}, u8[0]{0}) custom-call(reshape.2, constant.1), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, feature_group_count=12, custom_call_target="__cudnn$convForward"
-  get-tuple-element = f32[1,5,5,12]{3,2,1,0} get-tuple-element(cudnn-conv), index=0
-  slice.2 = f32[1,5,1,12]{3,2,1,0} slice(get-tuple-element), slice={[0:1], [0:5], [0:1], [0:12]}
-  constant.0 = f32[] constant(0)
-  ROOT pad.1 = f32[1,5,3,12]{3,2,1,0} pad(slice.2, constant.0), padding=0_0x0_0x2_0x0_0
-}
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, NoChangeOnComplexSlices) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule jit_outer
-
-ENTRY main.26 {
-  reshape.2 = f32[1,3,3,12]{3,2,1,0} parameter(0)
-  constant.1 = f32[3,3,1,12]{3,2,1,0} constant({ {
-    { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
-  }, {
-    { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
-  }, {
-    { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } } } })
-  cudnn-conv = (f32[1,5,5,12]{3,2,1,0}, u8[0]{0}) custom-call(reshape.2, constant.1), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, feature_group_count=12, custom_call_target="__cudnn$convForward"
-  get-tuple-element = f32[1,5,5,12]{3,2,1,0} get-tuple-element(cudnn-conv), index=0
-  slice.2 = f32[1,5,5,4]{3,2,1,0} slice(get-tuple-element), slice={[0:1], [0:5], [0:5], [2:6]}
-  constant.0 = f32[] constant(0)
-  ROOT pad.1 = f32[1,5,5,12]{3,2,1,0} pad(slice.2, constant.0), padding=0_0x0_0x0_0x0_8
-}
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, ScanOrderFeatureDimLast) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule jit_outer
-
-ENTRY main.26 {
-  reshape.2 = f32[1,3,3,12]{3,2,1,0} parameter(0)
-  constant.1 = f32[3,3,1,12]{3,2,1,0} constant({ {
-    { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
-  }, {
-    { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 } },
-    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
-  }, {
-    { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
-    { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } } } })
-  cudnn-conv = (f32[1,5,5,12]{3,2,1,0}, u8[0]{0}) custom-call(reshape.2, constant.1), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, feature_group_count=12, custom_call_target="__cudnn$convForward"
-  get-tuple-element = f32[1,5,5,12]{3,2,1,0} get-tuple-element(cudnn-conv), index=0
-  slice.2 = f32[1,5,5,6]{3,2,1,0} slice(get-tuple-element), slice={[0:1], [0:5], [0:5], [0:6]}
-  constant.0 = f32[] constant(0)
-  ROOT pad.1 = f32[1,5,5,12]{3,2,1,0} pad(slice.2, constant.0), padding=0_0x0_0x0_0x0_6
-}
-  )")
-                    .value();
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, Int8FilterReorderedOutputFirst) {
-  // Test feature dimension calculation from reordering transpose (oi01)
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    conv.1 = (s8[1,63,80,80], u8[0]) custom-call(
-        s8[1,112,80,80] parameter(0), s8[63,112,3,3] parameter(1)),
-      window={size=3x3}, dim_labels=bf01_oi01->bf01,
-      custom_call_target="__cudnn$convForward"
-    gte.1 = s8[1,63,80,80] get-tuple-element(conv.1), index=0
-    const.0 = s8[] constant(0)
-    ROOT pad.1 = s8[1,64,80,80] pad(gte.1, const.0), padding=0_0x0_1x0_0x0_0
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-}
-
-TEST_F(CudnnSimplifyPaddingTest, Int8FilterReorderedOutputLast) {
-  // Test feature dimension calculation from reordering transpose (01io)
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    conv.1 = (s8[1,63,80,80], u8[0]) custom-call(
-        s8[1,112,80,80] parameter(0), s8[3,3,112,63] parameter(1)),
-      window={size=3x3}, dim_labels=bf01_01io->bf01,
-      custom_call_target="__cudnn$convForward"
-    gte.1 = s8[1,63,80,80] get-tuple-element(conv.1), index=0
-    const.0 = s8[] constant(0)
-    ROOT pad.1 = s8[1,64,80,80] pad(gte.1, const.0), padding=0_0x0_1x0_0x0_0
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-}
-
-}  // anonymous namespace
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc b/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc
deleted file mode 100644
index 3846f01..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc
+++ /dev/null
@@ -1,650 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_vectorize_convolutions.h"
-
-#include <cstdint>
-#include <optional>
-#include <string>
-#include <tuple>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/client/xla_builder.h"
-#include "xla/client/xla_computation.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/primitive_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/cudnn_support_utils.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-// Finds convolutions that this pass may be able to transform, namely int8_t
-// cudnn forward or forward-bias-activation convolutions
-//
-// cudnn as of v8.2 supports the following data type combinations for forward
-// and forward-bias-activation convolutions.  We have to make sure we only
-// vectorize to one of these supported configs.
-//
-//   in       out
-//   int8x1   int8x1
-//   int8x1   float
-//   int8x1   int32_t
-//
-//   int8x4   int8x4
-//   int8x4   float
-//
-//   int8x32  int8x32
-//
-// https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionForward
-//
-// For now we restrict ourselves to only the int8xN -> int8xN cases.  We could
-// allow the int8x4 -> float case in the future if desirable.
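-//
-// For example (illustrative), an s8 -> s8 "__cudnn$convForward" custom-call is
-// returned, while an s8 -> f32 or s8 -> s32 conv is skipped (see the
-// NoVectorizeTo4IfOutputIsF32 / NoVectorizeTo4IfOutputIsS32 unit tests).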
-static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
-    HloComputation* comp) {
-  std::vector<HloCustomCallInstruction*> convs;
-  for (HloInstruction* instr : comp->instructions()) {
-    if (instr->opcode() != HloOpcode::kCustomCall ||
-        (instr->custom_call_target() != kCudnnConvForwardCallTarget &&
-         instr->custom_call_target() !=
-             kCudnnConvBiasActivationForwardCallTarget) ||
-        instr->operand_count() < 2) {
-      continue;
-    }
-
-    PrimitiveType input_ty = instr->operand(0)->shape().element_type();
-    PrimitiveType output_ty = instr->shape().tuple_shapes(0).element_type();
-    if (input_ty == output_ty && (input_ty == S8 || input_ty == U8)) {
-      convs.push_back(Cast<HloCustomCallInstruction>(instr));
-    }
-  }
-  return convs;
-}
-
-// Converts an XlaBuilder into an HloComputation in the same module as
-// `sibling_computation`.
-//
-// Yes, we serialize/deserialize as a proto.  :)
-static absl::StatusOr<HloComputation*> BuilderToHloComputation(
-    XlaBuilder& b, XlaOp root, HloComputation* sibling_computation) {
-  TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root));
-  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape());
-  HloModuleConfig config(program_shape);
-  TF_ASSIGN_OR_RETURN(auto new_module,
-                      HloModule::CreateFromProto(comp.proto(), config));
-
-  HloModule* dest_module = sibling_computation->parent();
-  HloCloneContext context(dest_module);
-  return dest_module->DeepCloneComputation(new_module->entry_computation(),
-                                           &context);
-}
-
-// Reshapes `instr` so that it has an extra dimension of size `vect_size` right
-// after `dim`.
-static XlaOp SplitAtDim(XlaOp instr, int64_t dim, int64_t vect_size) {
-  XlaBuilder& b = *instr.builder();
-  Shape shape = b.GetShape(instr).value();
-  DimensionVector new_dims(shape.dimensions().begin(),
-                           shape.dimensions().end());
-  CHECK_EQ(new_dims[dim] % vect_size, 0);
-  new_dims[dim] /= vect_size;
-  new_dims.insert(new_dims.begin() + dim + 1, vect_size);
-  return Reshape(instr, new_dims);
-}
-
-// Reshapes `shape` so that there's an extra dimension of size `vect_size` right
-// after `dim`.
-//
-// For example given shape=s8[10, 32, 20], dim=1, vect_size=4, returns
-// s8[10, 8, 4, 20].
-static Shape SplitShapeAtDim(Shape shape, int64_t dim, int64_t vect_size) {
-  DimensionVector new_dims(shape.dimensions().begin(),
-                           shape.dimensions().end());
-  CHECK_EQ(new_dims[dim] % vect_size, 0);
-  new_dims[dim] /= vect_size;
-  new_dims.insert(new_dims.begin() + dim + 1, vect_size);
-  return ShapeUtil::MakeShape(shape.element_type(), new_dims);
-}
-
-// Transposes dimension `src` to right before `dst`.
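-//
-// For example (illustrative): MoveDim(instr, /*src=*/2, /*dst=*/5) on a shape
-// [a, b, c, d, e, f] yields [a, b, d, e, c, f].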
-static XlaOp MoveDim(XlaOp instr, int64_t src, int64_t dst) {
-  XlaBuilder& b = *instr.builder();
-  int64_t rank = b.GetShape(instr)->dimensions_size();
-
-  DimensionVector idxs(rank);
-  absl::c_iota(idxs, 0);
-  if (src < dst) {
-    idxs.insert(idxs.begin() + dst, src);
-    idxs.erase(idxs.begin() + src);
-  } else {
-    idxs.erase(idxs.begin() + src);
-    idxs.insert(idxs.begin() + dst, src);
-  }
-  return Transpose(instr, idxs);
-}
-
-// Reshapes instr so that dimension `vect_dim` has size `vect_size`, by stealing
-// elements from `dim`.
-//
-// Requires that this is possible without merging and re-splitting the two
-// dimensions.  I.e. there should be some amount of dim that we can "split off"
-// and add to vect_dim to get it to have size vect_size.
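-//
-// Worked example (illustrative; these are the input-side shapes checked by the
-// Vectorize4To32NCHW unit test): RevectorizeInstr(x, /*dim=*/1, /*vect_dim=*/4,
-// /*vect_size=*/32) on an s8[10,16,20,30,4] operand has split_factor = 32/4 = 8
-// and proceeds
-//   s8[10,16,20,30,4] -> s8[10,2,8,20,30,4]   (SplitAtDim at dim 1)
-//                     -> s8[10,2,20,30,8,4]   (MoveDim 2 -> before 5)
-//                     -> s8[10,2,20,30,32]    (Collapse dims 4 and 5)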
-static XlaOp RevectorizeInstr(XlaOp instr, int64_t dim, int64_t vect_dim,
-                              int64_t vect_size) {
-  XlaBuilder& b = *instr.builder();
-  Shape shape = b.GetShape(instr).value();
-  auto size = [&](int64_t d) { return shape.dimensions(d); };
-
-  CHECK_LE(size(vect_dim), vect_size);
-  CHECK_EQ(vect_size % size(vect_dim), 0);
-
-  int64_t split_factor = vect_size / size(vect_dim);
-  CHECK_EQ(size(dim) % split_factor, 0);
-
-  // Split dim into [C, split_factor].
-  instr = SplitAtDim(instr, dim, split_factor);
-
-  // SplitAtDim may have added a dimension before vect_dim.
-  if (vect_dim > dim) {
-    vect_dim++;
-  }
-
-  // Move the split_factor dimension to right before vect_dim.
-  instr = MoveDim(instr, dim + 1, vect_dim);
-
-  // Moving the split_factor dimension may have *removed* a dimension before
-  // vect_dim.
-  if (vect_dim > dim) {
-    vect_dim--;
-  }
-
-  // Collapse the split_factor dimension into vect_dim.
-  return Collapse(instr, {vect_dim, vect_dim + 1});
-}
-
-// Inverse of RevectorizeInstr.  Reshapes instr so that dimension `vect_dim` has
-// size `vect_size`, moving excess elements into `dim`.
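-//
-// Worked example (illustrative; these are the output-side shapes checked by
-// the Vectorize4To32NCHW unit test): UnrevectorizeInstr(y, /*dim=*/1,
-// /*vect_dim=*/4, /*orig_vect_size=*/4) on an s8[10,4,20,30,32] conv result
-// proceeds
-//   s8[10,4,20,30,32] -> s8[10,4,20,30,8,4]   (SplitAtDim at dim 4)
-//                     -> s8[10,4,8,20,30,4]   (MoveDim 4 -> before 2)
-//                     -> s8[10,32,20,30,4]    (Collapse dims 1 and 2)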
-static XlaOp UnrevectorizeInstr(XlaOp instr, int64_t dim, int64_t vect_dim,
-                                int64_t orig_vect_size) {
-  XlaBuilder& b = *instr.builder();
-  Shape shape = b.GetShape(instr).value();
-  auto size = [&](int64_t d) { return shape.dimensions(d); };
-
-  CHECK_GE(size(vect_dim), orig_vect_size);
-  CHECK_EQ(size(vect_dim) % orig_vect_size, 0);
-
-  // Split vect_dim into [C, orig_vect_size].
-  instr = SplitAtDim(instr, vect_dim, orig_vect_size);
-
-  // SplitAtDim may have added a dimension before dim.
-  if (dim > vect_dim) {
-    dim++;
-  }
-
-  // Move the `C` dimension to right after `dim`.  Take into account that
-  // SplitAtDim may have added a dimension before dim.
-  instr = MoveDim(instr, vect_dim, dim + 1);
-
-  // MoveDim may have *removed* a dimension before dim.
-  if (dim > vect_dim) {
-    dim--;
-  }
-
-  // Collapse the `C` and `dim` dimensions.
-  return Collapse(instr, {dim, dim + 1});
-}
-
-// Adds a vectorized-feature dimension to dnums right after the current feature
-// dimension.
-//
-// ConvolutionDimensionNumbers doesn't represent the vectorized-feature
-// dimension explicitly, because the whole concept of a vectorized-feature
-// dimension is specific to cudnn.  Rather, the vectorized-feature dimension is
-// implicit; it's the first dimension that *doesn't* appear in the dnums.
-//
-// This function "makes room" in dnums for the new vectorized dimension by
-// incrementing any dimensions which appear after the feature dim.  The implicit
-// vector dim is then in this "empty" spot.
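-//
-// Example (illustrative): for input dnums "bf01" (b=0, f=1, spatial={2,3}),
-// the spatial dims become {3,4} and dim 2 is left as the implicit vector dim,
-// i.e. the "bf?01" layout the unit tests match against.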
-static ConvolutionDimensionNumbers VectorizeDnums(
-    ConvolutionDimensionNumbers dnums, bool reordered_filter) {
-  int64_t input_vect_dim = dnums.input_feature_dimension();
-  if (dnums.input_batch_dimension() > input_vect_dim) {
-    dnums.set_input_batch_dimension(dnums.input_batch_dimension() + 1);
-  }
-  for (int64_t& d : *dnums.mutable_input_spatial_dimensions()) {
-    if (d > input_vect_dim) {
-      ++d;
-    }
-  }
-
-  if (!reordered_filter) {
-    int64_t kernel_vect_dim = dnums.kernel_input_feature_dimension();
-    if (dnums.kernel_output_feature_dimension() > kernel_vect_dim) {
-      dnums.set_kernel_output_feature_dimension(
-          dnums.kernel_output_feature_dimension() + 1);
-    }
-    for (int64_t& d : *dnums.mutable_kernel_spatial_dimensions()) {
-      if (d > kernel_vect_dim) {
-        ++d;
-      }
-    }
-  }
-
-  int64_t output_vect_dim = dnums.output_feature_dimension();
-  if (dnums.output_batch_dimension() > output_vect_dim) {
-    dnums.set_output_batch_dimension(dnums.output_batch_dimension() + 1);
-  }
-  for (int64_t& d : *dnums.mutable_output_spatial_dimensions()) {
-    if (d > output_vect_dim) {
-      ++d;
-    }
-  }
-
-  return dnums;
-}
-
-// Reorders the convolution's filter and bias (if present) according to
-// cudnnReorderFilterAndBias.  Also marks that the filter + bias are reordered
-// in the conv's backend-config.
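-//
-// Example (illustrative; this is the shape pattern checked by the
-// VectorizeTo32 unit test): an s8[2,2,64,128] "01io" filter is rewritten via a
-// reshape-transpose-reshape into s8[128,2,2,2,32], i.e. "oihw" plus a trailing
-// vector dim of 32, and reordered_int8_nchw_vect is set in the backend config.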
-absl::Status ReorderInt8NchwVect(HloCustomCallInstruction* conv,
-                                 XlaOp* operands) {
-  bool has_bias = conv->operand_count() > 2;
-  VLOG(1) << "Reordering filter" << (has_bias ? " and bias" : "")
-          << " (replacement for cudnnReorderFilterAndBias)";
-
-  auto builder = operands->builder();
-  ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
-
-  // Update convolution backend config.
-  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                      conv->backend_config<GpuBackendConfig>());
-  CudnnConvBackendConfig& config =
-      *gpu_config.mutable_cudnn_conv_backend_config();
-  config.set_reordered_int8_nchw_vect(true);
-  TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
-
-  // Reorder the filter.
-  TF_ASSIGN_OR_RETURN(Shape filter_shape, builder->GetShape(operands[1]));
-  TF_ASSIGN_OR_RETURN(auto reorder, CudnnInferTransposeForFilterReordering(
-                                        filter_shape, dnums));
-  XlaOp reshape = Reshape(reorder.transpose_shape, operands[1]);
-  XlaOp transpose = Transpose(reshape, reorder.permutation);
-  operands[1] = Reshape(reorder.result_shape, transpose);
-
-  // The reshape-transpose-reshape we did above makes sure the resulting filter
-  // has dimension numbers corresponding to "oihw?", so update them.
-  dnums.set_kernel_output_feature_dimension(0);
-  dnums.set_kernel_input_feature_dimension(1);
-  dnums.set_kernel_spatial_dimensions(0, 2);
-  dnums.set_kernel_spatial_dimensions(1, 3);
-  conv->set_convolution_dimension_numbers(dnums);
-
-  if (has_bias) {
-    // Reorder the bias.
-    TF_ASSIGN_OR_RETURN(Shape bias_shape, builder->GetShape(operands[2]));
-    TF_ASSIGN_OR_RETURN(reorder,
-                        CudnnInferTransposeForBiasReordering(bias_shape));
-    reshape = Reshape(reorder.transpose_shape, operands[2]);
-    transpose = Transpose(reshape, reorder.permutation);
-    operands[2] = Reshape(reorder.result_shape, transpose);
-  }
-  return absl::OkStatus();
-}
-
-// Tries to vectorize an already-vectorized convolution.
-//
-// That is, given a convolution of shape [N, C/k, H, W, k], changes it to have
-// shape [N, C/vect_size, H, W, vect_size].  Similarly changes the filter from
-// [H, W, I/k, k, O] to [H, W, I/vect_size, vect_size, O].
-//
-// (The dimensions can appear in any order; which is N/C/etc is determined by
-// the convolutions' dnums.)
-static absl::StatusOr<bool> TryRevectorizeConv(
-    const se::CudaComputeCapability& compute_capability,
-    const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
-    int vect_size) {
-  const Shape& input_shape = conv->operand(0)->shape();
-  const Shape& kernel_shape = conv->operand(1)->shape();
-  const Shape& output_shape = conv->shape().tuple_shapes(0);
-  const ConvolutionDimensionNumbers* dnums =
-      &conv->convolution_dimension_numbers();
-
-  // Find the vectorized-features dim in the input/kernel/output.
-  std::optional<int64_t> input_vect_dim;
-  std::optional<int64_t> kernel_vect_dim;
-  std::optional<int64_t> output_vect_dim;
-  std::tie(input_vect_dim, kernel_vect_dim, output_vect_dim) =
-      FindVectorizedFeatureDims(*dnums, input_shape, kernel_shape,
-                                output_shape);
-
-  if (!input_vect_dim.has_value() || !kernel_vect_dim.has_value() ||
-      !output_vect_dim.has_value()) {
-    return false;
-  }
-
-  int64_t input_feat_size =
-      input_shape.dimensions(dnums->input_feature_dimension());
-  int64_t output_feat_size =
-      output_shape.dimensions(dnums->output_feature_dimension());
-  int64_t input_vect_size = input_shape.dimensions(*input_vect_dim);
-  int64_t output_vect_size = output_shape.dimensions(*output_vect_dim);
-  if (vect_size % input_vect_size != 0 || vect_size % output_vect_size != 0 ||
-      input_feat_size % (vect_size / input_vect_size) != 0 ||
-      output_feat_size % (vect_size / output_vect_size) != 0) {
-    return false;
-  }
-
-  // If this is an integer convolution check that we only vectorize when cuDNN
-  // supports the vectorized implementation.
-  if (primitive_util::IsIntegralType(input_shape.element_type())) {
-    TF_ASSIGN_OR_RETURN(bool supported_target_vectorization,
-                        CudnnSupportsOptimizedIntegerConvolution(
-                            compute_capability, *conv, vect_size));
-    if (!supported_target_vectorization) {
-      VLOG(3) << "Skipping re-vectorization of conv to vector size: "
-              << vect_size << ": " << conv->ToString();
-      return false;
-    }
-  }
-
-  VLOG(1) << "Re-vectorizing conv channels from "
-          << input_shape.dimensions(*input_vect_dim) << " to " << vect_size
-          << ": " << conv->ToString();
-
-  // We use XlaBuilder because it's a lot easier to get these tricky
-  // reshape/transposes correct using that API.
-  XlaBuilder b(absl::StrCat(conv->name(), ".revectorized"));
-  b.SetOpMetadata(conv->metadata());
-
-  XlaOp filter = Parameter(&b, 1, conv->operand(1)->shape(), "filter");
-  absl::InlinedVector<XlaOp, 4> new_operands = {
-      RevectorizeInstr(Parameter(&b, 0, conv->operand(0)->shape(), "input"),
-                       dnums->input_feature_dimension(), *input_vect_dim,
-                       vect_size),
-      RevectorizeInstr(filter, dnums->kernel_input_feature_dimension(),
-                       *kernel_vect_dim, vect_size),
-  };
-  if (conv->operand_count() > 2) {
-    // Bias, if present.  This is passed through unmodified.
-    new_operands.push_back(Parameter(&b, 2, conv->operand(2)->shape(), "bias"));
-  }
-  if (conv->operand_count() > 3) {
-    new_operands.push_back(RevectorizeInstr(
-        Parameter(&b, 3, conv->operand(3)->shape(), "side_input"),
-        dnums->input_feature_dimension(), *input_vect_dim, vect_size));
-  }
-
-  if (conv->operand_count() > 4) {
-    return InvalidArgument(
-        "Don't understand a conv with more than 4 arguments: %s",
-        conv->ToString());
-  }
-
-  // Reorder filter and bias for the int8x32 convolutions.  This requires cudnn
-  // >= 8.3.0.
-  //
-  // TODO(jlebar): Remove this guard once JAX no longer supports cudnn 8.3.
-  const auto& debug_options = conv->GetModule()->config().debug_options();
-  bool use_reordering =
-      input_shape.element_type() == xla::S8 && vect_size == 32 &&
-      debug_options.xla_gpu_enable_cudnn_int8x32_convolution_reordering() &&
-      cudnn_version >= se::dnn::VersionInfo{8, 3, 0};
-  if (use_reordering) {
-    // Reordering helper supports vector sizes of 4 and 32, so an additional
-    // reshape-transpose-reshape is not necessary in these cases.
-    int64_t kernel_vect_size = kernel_shape.dimensions(*kernel_vect_dim);
-    if (kernel_vect_size == 4 || kernel_vect_size == 32) {
-      new_operands[1] = filter;
-    }
-    TF_RETURN_IF_ERROR(ReorderInt8NchwVect(conv, new_operands.data()));
-    dnums = &conv->convolution_dimension_numbers();
-  }
-
-  // The custom-call returns a tuple (new_output_shape, u8[0]), where the second
-  // value in the tuple represents the convolution's scratch memory.
-  DimensionVector new_output_dims(output_shape.dimensions().begin(),
-                                  output_shape.dimensions().end());
-  new_output_dims[dnums->output_feature_dimension()] /=
-      (vect_size / output_vect_size);
-  new_output_dims[*output_vect_dim] = vect_size;
-  XlaOp new_conv = CustomCallWithConvDnums(
-      &b, conv->custom_call_target(), new_operands,
-      ShapeUtil::MakeTupleShape(
-          {ShapeUtil::MakeShape(output_shape.element_type(), new_output_dims),
-           ShapeUtil::MakeShape(U8, {0})}),
-      /*operand_shapes_with_layout=*/{},
-      /*opaque=*/conv->raw_backend_config_string(), /*has_side_effect=*/false,
-      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
-      /*window=*/conv->window(),
-      /*dnums=*/*dnums);
-
-  XlaOp new_conv_result = GetTupleElement(new_conv, 0);
-  XlaOp new_conv_scratch = GetTupleElement(new_conv, 1);
-
-  XlaOp new_conv_result_unrevectorized = UnrevectorizeInstr(
-      new_conv_result, dnums->output_feature_dimension(), *output_vect_dim,
-      /*orig_vect_size=*/output_shape.dimensions(*output_vect_dim));
-
-  TF_ASSIGN_OR_RETURN(
-      HloComputation * new_conv_comp,
-      BuilderToHloComputation(
-          b, Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}),
-          conv->parent()));
-
-  // Set the name on the new conv.  This is purely cosmetic, but we attempt to
-  // preserve e.g. "cudnn-conv.42" instead of "custom-call.42".
-  auto new_conv_comp_instrs = new_conv_comp->instructions();
-  auto new_conv_it =
-      absl::c_find_if(new_conv_comp_instrs, [](HloInstruction* instr) {
-        return instr->opcode() == HloOpcode::kCustomCall;
-      });
-  if (new_conv_it != new_conv_comp_instrs.end()) {
-    new_conv_comp->parent()->SetAndUniquifyInstrName(*new_conv_it,
-                                                     conv->name());
-  }
-
-  // Replace the old conv with a call to the computation we just created.
-  VLOG(1) << "Re-vectorized conv to " << new_conv_comp->ToString();
-  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
-      conv, HloInstruction::CreateCall(conv->shape(), conv->operands(),
-                                       new_conv_comp)));
-
-  return true;
-}
-
-// Tries to vectorize a convolution.
-//
-// Given a convolution of dimensions [N, C, H, W], tries to convert it to have
-// shape [N, C/vect_size, H, W, vect_size].  Similarly, given a kernel of shape
-// [H, W, I, O], tries to convert it to [H, W, I/vect_size, vect_size, O].
-//
-// This requires that C be a multiple of vect_size.  CudnnPadForConvolutions can
-// add padding to make this true.
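-//
-// For example (illustrative; these are the shapes checked by the VectorizeTo4
-// unit test), an s8[10,20,30,40] input and s8[2,2,40,44] "01io" filter become
-// s8[10,20,30,10,4] and s8[2,2,10,4,44] operands of a conv with dnums
-// "b01f?_01i?o->b01f?".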
-static absl::StatusOr<bool> TryVectorizeConv(
-    const se::CudaComputeCapability& compute_capability,
-    const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
-    int64_t vect_size) {
-  const Shape& input_shape = conv->operand(0)->shape();
-  const Shape& output_shape = conv->shape().tuple_shapes(0);
-  const ConvolutionDimensionNumbers* dnums =
-      &conv->convolution_dimension_numbers();
-  int64_t in_channels =
-      input_shape.dimensions(dnums->input_feature_dimension());
-  int64_t out_channels =
-      output_shape.dimensions(dnums->output_feature_dimension());
-
-  if (in_channels % vect_size != 0 || out_channels % vect_size != 0) {
-    return false;
-  }
-
-  if (input_shape.dimensions_size() >
-      2 + dnums->input_spatial_dimensions_size()) {
-    // Conv already has an extra dimension, which we assume is the vectorized
-    // features dim.
-    return false;
-  }
-
-  // If this is an integer convolution check that we only vectorize when cuDNN
-  // supports the vectorized implementation.
-  if (primitive_util::IsIntegralType(input_shape.element_type())) {
-    TF_ASSIGN_OR_RETURN(bool supported_target_vectorization,
-                        CudnnSupportsOptimizedIntegerConvolution(
-                            compute_capability, *conv, vect_size));
-    if (!supported_target_vectorization) {
-      VLOG(3) << "Skipping vectorization of conv to vector size: " << vect_size
-              << ": " << conv->ToString();
-      return false;
-    }
-  }
-
-  VLOG(1) << "Vectorizing conv channels by " << vect_size << ": "
-          << conv->ToString();
-
-  // We use XlaBuilder because it's a lot easier to get these tricky
-  // reshape/transposes correct using that API.
-  XlaBuilder b(absl::StrCat(conv->name(), ".revectorized"));
-  b.SetOpMetadata(conv->metadata());
-
-  XlaOp filter = Parameter(&b, 1, conv->operand(1)->shape(), "filter");
-  absl::InlinedVector<XlaOp, 4> new_operands = {
-      SplitAtDim(Parameter(&b, 0, conv->operand(0)->shape(), "input"),
-                 dnums->input_feature_dimension(), vect_size),
-      SplitAtDim(filter, dnums->kernel_input_feature_dimension(), vect_size),
-  };
-  if (conv->operand_count() > 2) {
-    // Bias, if present.  This is passed through unmodified.
-    new_operands.push_back(Parameter(&b, 2, conv->operand(2)->shape(), "bias"));
-  }
-  if (conv->operand_count() > 3) {
-    // Handle side input, which has same shape as the output.
-    new_operands.push_back(
-        SplitAtDim(Parameter(&b, 3, conv->operand(3)->shape(), "side_input"),
-                   dnums->output_feature_dimension(), vect_size));
-  }
-  if (conv->operand_count() > 4) {
-    return InvalidArgument(
-        "Don't understand a conv with more than 4 arguments: %s",
-        conv->ToString());
-  }
-
-  // Reorder filter and bias for the int8x32 convolutions.  This requires cudnn
-  // >= 8.3.0.
-  //
-  // TODO(jlebar): Remove this guard once JAX no longer supports cudnn 8.3.
-  const auto& debug_options = conv->GetModule()->config().debug_options();
-  bool use_reordering =
-      input_shape.element_type() == xla::S8 && vect_size == 32 &&
-      debug_options.xla_gpu_enable_cudnn_int8x32_convolution_reordering() &&
-      cudnn_version >= se::dnn::VersionInfo{8, 3, 0};
-  if (use_reordering) {
-    new_operands[1] = filter;
-    TF_RETURN_IF_ERROR(ReorderInt8NchwVect(conv, new_operands.data()));
-    dnums = &conv->convolution_dimension_numbers();
-  }
-
-  // The custom-call returns a tuple (new_output_shape, u8[0]), where the second
-  // value in the tuple represents the convolution's scratch memory.
-  Shape new_output_shape = SplitShapeAtDim(
-      output_shape, dnums->output_feature_dimension(), vect_size);
-  XlaOp new_conv = CustomCallWithConvDnums(
-      &b, conv->custom_call_target(), new_operands,
-      ShapeUtil::MakeTupleShape(
-          {new_output_shape, ShapeUtil::MakeShape(U8, {0})}),
-      /*operand_shapes_with_layout=*/{},
-      /*opaque=*/conv->raw_backend_config_string(), /*has_side_effect=*/false,
-      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
-      /*window=*/conv->window(),
-      /*dnums=*/VectorizeDnums(*dnums, use_reordering));
-
-  XlaOp new_conv_result = GetTupleElement(new_conv, 0);
-  XlaOp new_conv_scratch = GetTupleElement(new_conv, 1);
-
-  // Reshape back to the original shape.
-  XlaOp conv_result_collapsed =
-      Collapse(new_conv_result, {dnums->output_feature_dimension(),
-                                 dnums->output_feature_dimension() + 1});
-
-  TF_ASSIGN_OR_RETURN(
-      HloComputation * new_conv_comp,
-      BuilderToHloComputation(
-          b, Tuple(&b, {conv_result_collapsed, new_conv_scratch}),
-          conv->parent()));
-
-  // Create a tuple and replace the old conv with it!
-  VLOG(1) << "Vectorized conv to: " << new_conv_comp->ToString();
-  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
-      conv, HloInstruction::CreateCall(conv->shape(), conv->operands(),
-                                       new_conv_comp)));
-  return true;
-}
-
-}  // namespace
-
-absl::StatusOr<bool> CudnnVectorizeConvolutions::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
-      // Try to (re)vectorize to int8x32 if this is an sm75+ GPU.  If we can't,
-      // fall back to int8x4.
-      bool local_changed = false;
-      if (compute_capability_.IsAtLeast(7, 5)) {
-        TF_ASSIGN_OR_RETURN(
-            local_changed,
-            TryRevectorizeConv(compute_capability_, cudnn_version_, conv, 32));
-        if (!local_changed) {
-          TF_ASSIGN_OR_RETURN(
-              local_changed,
-              TryVectorizeConv(compute_capability_, cudnn_version_, conv, 32));
-        }
-      }
-      if (!local_changed) {
-        TF_ASSIGN_OR_RETURN(
-            local_changed,
-            TryVectorizeConv(compute_capability_, cudnn_version_, conv, 4));
-      }
-      changed |= local_changed;
-    }
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.h b/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.h
deleted file mode 100644
index 43165f2..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.h
+++ /dev/null
@@ -1,73 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_VECTORIZE_CONVOLUTIONS_H_
-#define XLA_SERVICE_GPU_CUDNN_VECTORIZE_CONVOLUTIONS_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-
-namespace xla {
-namespace gpu {
-
-// Changes the shape of cudnn convolutions to allow faster "vectorized"
-// algorithms.
-//
-// On sm61+ will convert int8_t convolutions from
-//
-//   - [N, C, H, W] to [N, C/4, H, W, 4],
-//
-// assuming C is divisible by 4.
-//
-// On sm75+ will convert int8_t convolutions from
-//
-//   - [N, C, H, W]      to [N, C/32, H, W, 32],
-//   - [N, C/4, H, W, 4] to [N, C/32, H, W, 32], and
-//   - [N, C, H, W]      to [N,  C/4, H, W,  4] (same as sm61+),
-//
-// assuming C is divisible by 4 or 32.
-//
-// This pass will not pad the channel dim to a multiple of 4 or 32, so you
-// should run CudnnPadForConvolutions before this.
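-//
-// Illustrative usage (a sketch mirroring the setup in
-// cudnn_vectorize_convolutions_test.cc):
-//
-//   CudnnVectorizeConvolutions pass(
-//       se::CudaComputeCapability{7, 5},  // sm75+, so int8x32 is attempted.
-//       se::dnn::VersionInfo(8, 3, 0));
-//   TF_ASSIGN_OR_RETURN(bool changed, pass.Run(module));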
-class CudnnVectorizeConvolutions : public HloModulePass {
- public:
-  explicit CudnnVectorizeConvolutions(
-      se::CudaComputeCapability compute_capability,
-      se::dnn::VersionInfo cudnn_version)
-      : compute_capability_(compute_capability),
-        cudnn_version_(cudnn_version) {}
-
-  absl::string_view name() const override {
-    return "cudnn_vectorize_convolutions";
-  }
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  const se::CudaComputeCapability compute_capability_;
-  const se::dnn::VersionInfo cudnn_version_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_CUDNN_VECTORIZE_CONVOLUTIONS_H_
diff --git a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc b/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc
deleted file mode 100644
index aa15fc7..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc
+++ /dev/null
@@ -1,758 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_vectorize_convolutions.h"
-
-#include <cstdint>
-#include <utility>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/algorithm/container.h"
-#include "absl/status/statusor.h"
-#include "xla/service/call_inliner.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class CudnnVectorizeConvolutionsTest : public HloTestBase {
- protected:
-  // Runs this pass and some cleanup to make pattern-matching easier.
-  absl::StatusOr<bool> Run(std::pair<int, int> compute_capability,
-                           HloModule* module) {
-    CudnnVectorizeConvolutions pass(
-        se::CudaComputeCapability{compute_capability.first,
-                                  compute_capability.second},
-        se::dnn::VersionInfo(8, 3, 0));
-    TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(&pass, module));
-
-    CallInliner inliner;
-    TF_RETURN_IF_ERROR(RunHloPass(&inliner, module).status());
-
-    return changed;
-  }
-};
-
-TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo4) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,40] parameter(0)
-    filter = s8[2,2,40,44] parameter(1)
-    ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward",
-                  backend_config="{bar: 0}"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  ASSERT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Reshape(m::GetTupleElement(
-                         m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
-                                       m::Reshape(m::Parameter(0))
-                                           .WithShape(S8, {10, 20, 30, 10, 4}),
-                                       m::Reshape(m::Parameter(1))
-                                           .WithShape(S8, {2, 2, 10, 4, 44}))
-                             .WithConvDnums("b01f?_01i?o->b01f?"))
-                         .WithShape(S8, {10, 20, 30, 11, 4})),
-          m::Op())));
-
-  EXPECT_EQ(conv->raw_backend_config_string(), "{bar: 0}");
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4UnsupportedFilterType) {
-  // This test checks that the vectorize pass correctly calls
-  // CudnnSupportsOptimizedIntegerConvolution() which should reject this
-  // convolution because its filter type is f32.
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,40] parameter(0)
-    filter = f32[2,2,40,44] parameter(1)
-    ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward",
-                  backend_config="{bar: 0}"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo4NCHW) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,48,20,30] parameter(0)
-    filter = s8[48,44,2,2] parameter(1)
-    ROOT result = (s8[10,44,20,30], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=bf01_io01->bf01,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  ASSERT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Reshape(m::GetTupleElement(
-                         m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
-                                       m::Reshape(m::Parameter(0))
-                                           .WithShape(S8, {10, 12, 4, 20, 30}),
-                                       m::Reshape(m::Parameter(1))
-                                           .WithShape(S8, {12, 4, 44, 2, 2}))
-                             .WithConvDnums("bf?01_i?o01->bf?01"))
-                         .WithShape(S8, {10, 11, 4, 20, 30})),
-          m::Op())));
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, IncrementAllDnums) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[16,16,16,16] parameter(0)
-    filter = s8[16,16,3,3] parameter(1)
-    ROOT result = (s8[16,16,16,16], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=fb01_i01o->fb01,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  ASSERT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Reshape(m::GetTupleElement(
-                         m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
-                                       m::Reshape(m::Parameter(0))
-                                           .WithShape(S8, {4, 4, 16, 16, 16}),
-                                       m::Reshape(m::Parameter(1))
-                                           .WithShape(S8, {4, 4, 16, 3, 3}))
-                             .WithConvDnums("f?b01_i?01o->f?b01"))
-                         .WithShape(S8, {4, 4, 16, 16, 16})),
-          m::Op())));
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, FilterDnums) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[1,20,9,9] parameter(0)
-    filter = s8[3,3,20,32] parameter(1)
-    ROOT result = (s8[1,32,9,9], u8[0]) custom-call(s8[1,20,9,9] input, s8[3,3,20,32] filter),
-                  window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  ASSERT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Reshape(m::GetTupleElement(
-                         m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
-                                       m::Reshape(m::Parameter(0))
-                                           .WithShape(S8, {1, 5, 4, 9, 9}),
-                                       m::Reshape(m::Parameter(1))
-                                           .WithShape(S8, {3, 3, 5, 4, 32}))
-                             .WithConvDnums("bf?01_01i?o->bf?01"))
-                         .WithShape(S8, {1, 8, 4, 9, 9})),
-          m::Op())));
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,41] parameter(0)
-    filter = s8[2,2,41,44] parameter(1)
-    ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  CudnnVectorizeConvolutions pass(
-      /*compute_capability=*/{7, 5},
-      /*cudnn_version=*/se::dnn::VersionInfo{8, 3, 0});
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-
-  SCOPED_TRACE(module->ToString());
-  EXPECT_FALSE(changed);
-}
-
-// Don't vectorize int8_t -> int32_t into int8x4 or int8x32; this is not
-// supported in cudnn.
-TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4IfOutputIsS32) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,41] parameter(0)
-    filter = s8[2,2,41,44] parameter(1)
-    ROOT result = (s32[10,20,30,44], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  SCOPED_TRACE(module->ToString());
-  EXPECT_FALSE(changed);
-}
-
-// Don't vectorize int8_t -> float into int8x4 or int8x32.  Vectorizing to
-// int8x4 *is* allowed by cudnn, but we don't do it at the moment.
-TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4IfOutputIsF32) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,41] parameter(0)
-    filter = s8[2,2,41,44] parameter(1)
-    ROOT result = (f32[10,20,30,44], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  SCOPED_TRACE(module->ToString());
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo32) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,64] parameter(0)
-    filter = s8[2,2,64,128] parameter(1)
-    ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  ASSERT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Reshape(
-              m::GetTupleElement(
-                  m::CustomCall(
-                      &conv, {kCudnnConvForwardCallTarget},
-                      m::Reshape(m::Parameter(0))
-                          .WithShape(S8, {10, 20, 30, 2, 32}),
-                      m::Reshape(
-                          m::Transpose(
-                              m::Reshape(m::Parameter(1))
-                                  .WithShape(S8, {2, 2, 2, 8, 4, 16, 4, 2}))
-                              .WithShape(S8, {2, 2, 2, 16, 2, 8, 4, 4})
-                              .WithPredicate([](const HloInstruction* instr) {
-                                return absl::c_equal(
-                                    instr->dimensions(),
-                                    std::vector<int64_t>{2, 0, 1, 5, 7, 3, 6,
-                                                         4});
-                              }))
-                          .WithShape(S8, {128, 2, 2, 2, 32})))
-                  .WithShape(S8, {10, 20, 30, 4, 32})),
-          m::Op())));
-
-  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
-                  ->cudnn_conv_backend_config()
-                  .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, BiasAndSideInput) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,64] parameter(0)
-    filter = s8[2,2,64,128] parameter(1)
-    bias = f32[128] parameter(2)
-    side_input = s8[10,20,30,64] parameter(3)
-
-    ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter, bias, side_input),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  ASSERT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Reshape(
-              m::GetTupleElement(
-                  m::CustomCall(
-                      &conv, {kCudnnConvForwardCallTarget},
-                      m::Reshape(m::Parameter(0))
-                          .WithShape(S8, {10, 20, 30, 2, 32}),
-                      m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))))
-                          .WithShape(S8, {128, 2, 2, 2, 32}),
-                      m::Reshape(
-                          m::Transpose(m::Reshape(m::Parameter(2))
-                                           .WithShape(F32, {4, 4, 2, 4}))
-                              .WithShape(F32, {4, 2, 4, 4})
-                              .WithPredicate([](const HloInstruction* instr) {
-                                return absl::c_equal(
-                                    instr->dimensions(),
-                                    std::vector<int64_t>{0, 2, 1, 3});
-                              }))
-                          .WithShape(F32, {128}),
-                      m::Reshape(m::Parameter(3))
-                          .WithShape(S8, {10, 20, 30, 2, 32})))
-                  .WithShape(S8, {10, 20, 30, 4, 32})),
-          m::Op())));
-
-  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
-                  ->cudnn_conv_backend_config()
-                  .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, InputNHWC_OutputNCHW) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,64] parameter(0)
-    filter = s8[2,2,64,128] parameter(1)
-    bias = f32[128] parameter(2)
-    side_input = s8[10,128,20,30] parameter(3)
-
-    ROOT result = (s8[10,128,20,30], u8[0]) custom-call(input, filter, bias, side_input),
-                  window={size=2x2}, dim_labels=b01f_01io->bf01,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  ASSERT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Reshape(
-              m::GetTupleElement(
-                  m::CustomCall(
-                      &conv, {kCudnnConvForwardCallTarget},
-                      m::Reshape(m::Parameter(0))
-                          .WithShape(S8, {10, 20, 30, 2, 32}),
-                      m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))))
-                          .WithShape(S8, {128, 2, 2, 2, 32}),
-                      m::Reshape(
-                          m::Transpose(m::Reshape(m::Parameter(2))
-                                           .WithShape(F32, {4, 4, 2, 4}))
-                              .WithShape(F32, {4, 2, 4, 4})
-                              .WithPredicate([](const HloInstruction* instr) {
-                                return absl::c_equal(
-                                    instr->dimensions(),
-                                    std::vector<int64_t>{0, 2, 1, 3});
-                              }))
-                          .WithShape(F32, {128}),
-                      m::Reshape(m::Parameter(3))
-                          .WithShape(S8, {10, 4, 32, 20, 30})))
-                  .WithShape(S8, {10, 4, 32, 20, 30})),
-          m::Op())));
-
-  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
-                  ->cudnn_conv_backend_config()
-                  .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo32) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,64] parameter(0)
-    filter = s8[2,2,64,128] parameter(1)
-    ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 0}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  ASSERT_THAT(
-      root,
-      GmockMatch(m::Tuple(
-          m::Reshape(m::GetTupleElement(
-                         m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
-                                       m::Reshape(m::Parameter(0))
-                                           .WithShape(S8, {10, 20, 30, 16, 4}),
-                                       m::Reshape(m::Parameter(1))
-                                           .WithShape(S8, {2, 2, 16, 4, 128})))
-                         .WithShape(S8, {10, 20, 30, 32, 4})),
-          m::Op())));
-
-  EXPECT_FALSE(conv->backend_config<GpuBackendConfig>()
-                   ->cudnn_conv_backend_config()
-                   .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,16,4] parameter(0)
-    filter = s8[3,5,16,192,4] parameter(1)
-    bias = f32[64] parameter(2)
-    side_input = s8[10,20,30,16,4] parameter(3)
-    ROOT result = (s8[10,20,30,48,4], u8[0]) custom-call(input, filter, bias, side_input),
-                  window={size=3x5}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  auto conv_pat =
-      m::GetTupleElement(
-          m::CustomCall(
-              &conv, {kCudnnConvForwardCallTarget},
-              m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
-                                          .WithShape(S8, {10, 20, 30, 2, 8, 4}))
-                             .WithShape(S8, {10, 20, 30, 2, 8, 4}))
-                  .WithShape(S8, {10, 20, 30, 2, 32}),
-              m::Reshape(
-                  m::Transpose(m::Reshape(m::Parameter(1))
-                                   .WithShape(S8, {3, 5, 2, 8, 24, 4, 2, 4}))
-                      .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4})
-                      .WithPredicate([](const HloInstruction* instr) {
-                        return absl::c_equal(
-                            instr->dimensions(),
-                            std::vector<int64_t>{2, 0, 1, 4, 6, 3, 5, 7});
-                      }))
-                  .WithShape(S8, {192, 2, 3, 5, 32}),
-              m::Reshape(m::Transpose(m::Reshape(m::Parameter(2)))),
-              m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
-                                          .WithShape(S8, {10, 20, 30, 2, 8, 4}))
-                             .WithShape(S8, {10, 20, 30, 2, 8, 4}))
-                  .WithShape(S8, {10, 20, 30, 2, 32}))
-              .WithConvDnums("b01f?_oi01?->b01f?"))
-          .WithShape(S8, {10, 20, 30, 6, 32});
-  ASSERT_THAT(root, GmockMatch(m::Tuple(
-                        m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
-                                                    S8, {10, 20, 30, 6, 8, 4}))
-                                       .WithShape(S8, {10, 20, 30, 6, 8, 4}))
-                            .WithShape(S8, {10, 20, 30, 48, 4}),
-                        m::Op())));
-
-  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
-                  ->cudnn_conv_backend_config()
-                  .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32NCHW) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,16,20,30,4] parameter(0)
-    filter = s8[16,128,2,2,4] parameter(1)
-    bias = f32[64] parameter(2)
-    side_input = s8[10,16,20,30,4] parameter(3)
-    ROOT result = (s8[10,32,20,30,4], u8[0]) custom-call(input, filter, bias, side_input),
-                  window={size=2x2}, dim_labels=bf01_io01->bf01,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  auto conv_pat =
-      m::GetTupleElement(
-          m::CustomCall(
-              &conv, {kCudnnConvForwardCallTarget},
-              m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
-                                          .WithShape(S8, {10, 2, 8, 20, 30, 4}))
-                             .WithShape(S8, {10, 2, 20, 30, 8, 4}))
-                  .WithShape(S8, {10, 2, 20, 30, 32}),
-              m::Reshape(
-                  m::Transpose(m::Reshape(m::Parameter(1))
-                                   .WithShape(S8, {2, 8, 16, 4, 2, 2, 2, 4}))
-                      .WithShape(S8, {2, 2, 2, 16, 2, 8, 4, 4})
-                      .WithPredicate([](const HloInstruction* instr) {
-                        return absl::c_equal(
-                            instr->dimensions(),
-                            std::vector<int64_t>{0, 5, 6, 2, 4, 1, 3, 7});
-                      }))
-                  .WithShape(S8, {128, 2, 2, 2, 32}),
-              m::Reshape(m::Transpose(m::Reshape(m::Parameter(2)))),
-              m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
-                                          .WithShape(S8, {10, 2, 8, 20, 30, 4}))
-                             .WithShape(S8, {10, 2, 20, 30, 8, 4}))
-                  .WithShape(S8, {10, 2, 20, 30, 32}))
-              .WithConvDnums("bf01_oi01->bf01"))
-          .WithShape(S8, {10, 4, 20, 30, 32});
-  ASSERT_THAT(root, GmockMatch(m::Tuple(
-                        m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
-                                                    S8, {10, 4, 20, 30, 8, 4}))
-                                       .WithShape(S8, {10, 4, 8, 20, 30, 4}))
-                            .WithShape(S8, {10, 32, 20, 30, 4}),
-                        m::Op())));
-
-  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
-                  ->cudnn_conv_backend_config()
-                  .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32VectorDimFirst) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[4,10,20,30,16] parameter(0)
-    filter = s8[4,3,5,16,192] parameter(1)
-    bias = f32[64] parameter(2)
-    side_input = s8[4,10,20,30,16] parameter(3)
-    ROOT result = (s8[4,10,20,30,48], u8[0]) custom-call(input, filter, bias, side_input),
-                  window={size=3x5}, dim_labels=?b01f_?01io->?b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  auto conv_pat =
-      m::GetTupleElement(
-          m::CustomCall(
-              &conv, {kCudnnConvForwardCallTarget},
-              m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
-                                          .WithShape(S8, {4, 10, 20, 30, 2, 8}))
-                             .WithShape(S8, {8, 4, 10, 20, 30, 2}))
-                  .WithShape(S8, {32, 10, 20, 30, 2}),
-              m::Reshape(
-                  m::Transpose(m::Reshape(m::Parameter(1))
-                                   .WithShape(S8, {4, 3, 5, 2, 8, 24, 4, 2}))
-                      .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4})
-                      .WithPredicate([](const HloInstruction* instr) {
-                        return absl::c_equal(
-                            instr->dimensions(),
-                            std::vector<int64_t>{3, 1, 2, 5, 7, 4, 6, 0});
-                      }))
-                  .WithShape(S8, {192, 2, 3, 5, 32}),
-              m::Reshape(m::Transpose(m::Reshape(m::Parameter(2)))),
-              m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
-                                          .WithShape(S8, {4, 10, 20, 30, 2, 8}))
-                             .WithShape(S8, {8, 4, 10, 20, 30, 2}))
-                  .WithShape(S8, {32, 10, 20, 30, 2}))
-              .WithConvDnums("?b01f_oi01->?b01f"))
-          .WithShape(S8, {32, 10, 20, 30, 6});
-  ASSERT_THAT(root, GmockMatch(m::Tuple(
-                        m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
-                                                    S8, {8, 4, 10, 20, 30, 6}))
-                                       .WithShape(S8, {4, 10, 20, 30, 6, 8}))
-                            .WithShape(S8, {4, 10, 20, 30, 48}),
-                        m::Op())));
-
-  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
-                  ->cudnn_conv_backend_config()
-                  .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, NoVectorize4To32) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,16,4] parameter(0)
-    filter = s8[2,2,16,128,4] parameter(1)
-    bias = f32[10] parameter(2)
-    side_input = s8[10,20,30,16,4] parameter(3)
-    ROOT result = (s8[10,20,30,32,4], u8[0]) custom-call(input, filter, bias, side_input),
-                  window={size=2x2}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 0}, module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, Vectorize16To32) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,4,16] parameter(0)
-    filter = s8[3,5,4,192,16] parameter(1)
-    ROOT result = (s8[10,20,30,12,16], u8[0]) custom-call(input, filter),
-                  window={size=3x5}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  auto filter_pat =
-      m::Reshape(
-          m::Transpose(
-              m::Reshape(m::Parameter(1)).WithShape(S8, {3, 5, 2, 2, 192, 16}))
-              .WithShape(S8, {3, 5, 2, 192, 2, 16}))
-          .WithShape(S8, {3, 5, 2, 192, 32});
-  auto conv_pat =
-      m::GetTupleElement(
-          m::CustomCall(
-              &conv, {kCudnnConvForwardCallTarget},
-              m::Reshape(
-                  m::Transpose(m::Reshape(m::Parameter(0))
-                                   .WithShape(S8, {10, 20, 30, 2, 2, 16}))
-                      .WithShape(S8, {10, 20, 30, 2, 2, 16}))
-                  .WithShape(S8, {10, 20, 30, 2, 32}),
-              m::Reshape(
-                  m::Transpose(m::Reshape(filter_pat)
-                                   .WithShape(S8, {3, 5, 2, 24, 4, 2, 8, 4}))
-                      .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4}))
-                  .WithShape(S8, {192, 2, 3, 5, 32}))
-              .WithConvDnums("b01f_oi01->b01f"))
-          .WithShape(S8, {10, 20, 30, 6, 32});
-  ASSERT_THAT(root, GmockMatch(m::Tuple(
-                        m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
-                                                    S8, {10, 20, 30, 6, 2, 16}))
-                                       .WithShape(S8, {10, 20, 30, 6, 2, 16}))
-                            .WithShape(S8, {10, 20, 30, 12, 16}),
-                        m::Op())));
-  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
-                  ->cudnn_conv_backend_config()
-                  .reordered_int8_nchw_vect());
-}
-
-TEST_F(CudnnVectorizeConvolutionsTest, VectorizeMixedTo32) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    input = s8[10,20,30,8,8] parameter(0)
-    filter = s8[3,5,2,192,32] parameter(1)
-    ROOT result = (s8[10,20,30,96,2], u8[0]) custom-call(input, filter),
-                  window={size=3x5}, dim_labels=b01f_01io->b01f,
-                  custom_call_target="__cudnn$convForward"
-  })")
-                    .value();
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
-  EXPECT_TRUE(changed);
-
-  SCOPED_TRACE(module->ToString());
-  auto* root = module->entry_computation()->root_instruction();
-
-  const HloInstruction* conv = nullptr;
-  auto conv_pat =
-      m::GetTupleElement(
-          m::CustomCall(
-              &conv, {kCudnnConvForwardCallTarget},
-              m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
-                                          .WithShape(S8, {10, 20, 30, 2, 4, 8}))
-                             .WithShape(S8, {10, 20, 30, 2, 4, 8}))
-                  .WithShape(S8, {10, 20, 30, 2, 32}),
-              m::Reshape(
-                  m::Transpose(m::Reshape(m::Parameter(1))
-                                   .WithShape(S8, {3, 5, 2, 24, 4, 2, 8, 4}))
-                      .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4}))
-                  .WithShape(S8, {192, 2, 3, 5, 32}))
-              .WithConvDnums("b01f_oi01->b01f"))
-          .WithShape(S8, {10, 20, 30, 6, 32});
-  ASSERT_THAT(root, GmockMatch(m::Tuple(
-                        m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
-                                                    S8, {10, 20, 30, 6, 16, 2}))
-                                       .WithShape(S8, {10, 20, 30, 6, 16, 2}))
-                            .WithShape(S8, {10, 20, 30, 96, 2}),
-                        m::Op())));
-  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
-                  ->cudnn_conv_backend_config()
-                  .reordered_int8_nchw_vect());
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
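For readers skimming the deleted tests above: m:: is XLA's HLO pattern-matcher DSL and GmockMatch adapts it to googletest assertions. A minimal standalone sketch in the same style, with a module and shapes that are illustrative only and not taken from this commit:

// Minimal sketch of the matcher style used in the deleted tests; the module
// text, test name, and shapes here are illustrative only.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/tests/hlo_test_base.h"

namespace xla {
namespace gpu {
namespace {

namespace m = ::xla::match;

using PatternMatcherSketchTest = HloTestBase;

TEST_F(PatternMatcherSketchTest, MatchesReshapeOfParameter) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    p = s8[8,4] parameter(0)
    ROOT r = s8[32] reshape(p)
  })")
                    .value();
  const HloInstruction* root = module->entry_computation()->root_instruction();
  // The root must be a reshape of parameter 0 with the expected result shape.
  EXPECT_THAT(root,
              GmockMatch(m::Reshape(m::Parameter(0)).WithShape(S8, {32})));
}

}  // namespace
}  // namespace gpu
}  // namespace xla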
diff --git a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc
deleted file mode 100644
index 387a3f4..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc
+++ /dev/null
@@ -1,272 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cudnn_workspace_rewriter.h"
-
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/gpu_fused_mha_runner.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/cuda/cuda_dnn.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-// Creates a cuDNN graph from an fMHA HloCustomCall.
-absl::StatusOr<se::gpu::CudnnGraph> HloCustomCallToCuDnnGraph(
-    se::dnn::DnnSupport& dnn_support, HloCustomCallInstruction* custom_call) {
-  if (IsFwdCustomCallTofMHA(*custom_call)) {
-    TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind,
-                        xla::gpu::GetCudnnfMHAKind(custom_call));
-    std::optional<Shape> mask_shape, bias_shape;
-    {
-      bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax ||
-                      kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout;
-
-      if (has_bias) {
-        const HloInstruction* bias = custom_call->operand(3);
-        bias_shape = bias->shape();
-      }
-    }
-
-    TF_ASSIGN_OR_RETURN(
-        const auto gpu_config,
-        custom_call->backend_config<xla::gpu::GpuBackendConfig>());
-    const xla::gpu::CudnnfMHABackendConfig& config =
-        gpu_config.cudnn_fmha_backend_config();
-    Shape intermediate_tensor_shape(config.intermediate_tensor_shape());
-    absl::InlinedVector<Shape, 2> output_shapes = {
-        ShapeUtil::GetSubshape(custom_call->shape(), {0})};
-
-    bool has_activation =
-        xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3;
-    if (has_activation) {
-      output_shapes.push_back(
-          ShapeUtil::GetSubshape(custom_call->shape(), {1}));
-    }
-
-    Shape q_shape = custom_call->operand(0)->shape();
-    Shape k_shape = custom_call->operand(1)->shape();
-    Shape v_shape = custom_call->operand(2)->shape();
-    TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
-                        AsCudnnFmhaMaskKind(config.mask_type()));
-    GpufMHADescriptor descriptor = {kind,
-                                    config,
-                                    cudnn_mask_type,
-                                    q_shape,
-                                    k_shape,
-                                    v_shape,
-                                    intermediate_tensor_shape,
-                                    output_shapes,
-                                    config.bmm1_dot_dimension_numbers(),
-                                    config.bmm2_dot_dimension_numbers(),
-                                    mask_shape,
-                                    bias_shape};
-
-    TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config,
-                        GpufMHAConfig::For(descriptor));
-    TF_ASSIGN_OR_RETURN(
-        se::dnn::FMHAMaskKind dnn_mask_type,
-        GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type));
-    TF_ASSIGN_OR_RETURN(
-        se::gpu::CudnnGraph graph,
-        se::gpu::GetCudnnFlashAttentionOperationGraph(
-            dnn_support, fmha_config.lhs_bmm1, fmha_config.rhs_bmm1,
-            fmha_config.rhs_bmm2, fmha_config.output, fmha_config.bias,
-            fmha_config.activation, static_cast<float>(*fmha_config.fmha_scale),
-            fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0,
-            fmha_config.dropout_rate, dnn_mask_type));
-    return std::move(graph);
-  } else {
-    TF_ASSIGN_OR_RETURN(
-        auto gpu_config,
-        custom_call->backend_config<xla::gpu::GpuBackendConfig>());
-    xla::gpu::CudnnfMHABackendConfig& config =
-        *gpu_config.mutable_cudnn_fmha_backend_config();
-
-    int input_index = 0;
-    Shape bmm1_grad_gemm1_rhs_shape =
-        custom_call->operand(input_index++)->shape();
-    Shape bmm1_grad_gemm2_rhs_shape =
-        custom_call->operand(input_index++)->shape();
-    Shape bmm2_grad_gemm2_rhs_shape =
-        custom_call->operand(input_index++)->shape();
-    Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape());
-    input_index++;
-    Shape d_output_shape = custom_call->operand(input_index++)->shape();
-
-    TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind,
-                        GetCudnnfMHAKind(custom_call));
-    std::optional<Shape> mask_shape;
-
-    bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax ||
-                     kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout);
-    std::optional<Shape> bias_shape;
-    if (has_bias) {
-      bias_shape = custom_call->operand(input_index++)->shape();
-    }
-
-    std::optional<Shape> fwd_output_shape =
-        custom_call->operand(input_index++)->shape();
-    if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING ||
-        config.mask_type() ==
-            xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) {
-      // skip q_seqlen and kv_seqlen
-      input_index += 2;
-    }
-    TF_RET_CHECK(input_index == custom_call->operand_count());
-
-    int output_index = 0;
-    Shape d_bmm1_lhs_shape =
-        ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
-    Shape d_bmm1_rhs_shape =
-        ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
-    Shape d_bmm2_rhs_shape =
-        ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
-    std::optional<Shape> d_s_shape;
-    std::optional<Shape> d_bias_shape;
-    bool has_dbias = custom_call->shape().tuple_shapes().size() == 5;
-    if (has_dbias) {
-      d_bias_shape =
-          ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
-    }
-    // The last one is the workspace.
-    TF_RET_CHECK(output_index ==
-                 custom_call->shape().tuple_shapes().size() - 1);
-    TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
-                        AsCudnnFmhaMaskKind(config.mask_type()));
-
-    const bool force_deterministic =
-        RequireDeterminism(custom_call->GetModule()->config());
-    // set the correct force_deterministic attribute here
-    config.set_force_deterministic(force_deterministic);
-    TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config));
-
-    GpufMHABackwardDescriptor descriptor = {
-        kind,
-        config,
-        cudnn_mask_type,
-        bmm1_grad_gemm1_rhs_shape,
-        bmm1_grad_gemm2_rhs_shape,
-        bmm2_grad_gemm1_lhs_shape,
-        bmm2_grad_gemm2_rhs_shape,
-        d_output_shape,
-        d_bmm1_lhs_shape,
-        d_bmm1_rhs_shape,
-        d_bmm2_rhs_shape,
-        config.bmm1_grad_gemm1_dot_dimension_numbers(),
-        config.bmm1_grad_gemm2_dot_dimension_numbers(),
-        config.bmm2_grad_gemm1_dot_dimension_numbers(),
-        config.bmm2_grad_gemm2_dot_dimension_numbers(),
-        d_s_shape,
-        fwd_output_shape,
-        mask_shape,
-        d_bias_shape,
-        bias_shape,
-        force_deterministic};
-
-    TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_config,
-                        GpufMHABackwardConfig::For(descriptor));
-    TF_ASSIGN_OR_RETURN(
-        se::dnn::FMHAMaskKind dnn_mask_type,
-        GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type));
-
-    TF_ASSIGN_OR_RETURN(
-        se::gpu::CudnnGraph graph,
-        se::gpu::GetCudnnFlashAttentionBackwardOperationGraph(
-            dnn_support, fmha_config.bmm1_grad_gemm1_rhs,
-            fmha_config.bmm1_grad_gemm2_rhs, fmha_config.bmm2_grad_gemm1_lhs,
-            fmha_config.bmm2_grad_gemm2_rhs, fmha_config.d_output,
-            fmha_config.d_bmm1_lhs, fmha_config.d_bmm1_rhs,
-            fmha_config.d_bmm2_rhs, fmha_config.bias, fmha_config.dropout_rate,
-            fmha_config.seed, *fmha_config.fmha_scale,
-            fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0,
-            fmha_config.bias != std::nullopt, dnn_mask_type,
-            force_deterministic));
-    return std::move(graph);
-  }
-}
-
-class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor {
- public:
-  explicit CuDnnCustomCallVisitor(se::dnn::DnnSupport& dnn_support)
-      : dnn_support_(dnn_support) {}
-
-  absl::Status HandleCustomCall(HloInstruction* hlo) override {
-    if (!IsCustomCallTofMHA(*hlo)) {
-      // don't do anything about other cuDNN custom calls
-      return absl::OkStatus();
-    }
-    TF_ASSIGN_OR_RETURN(auto gpu_config,
-                        hlo->backend_config<GpuBackendConfig>());
-
-    TF_ASSIGN_OR_RETURN(
-        se::gpu::CudnnGraph graph,
-        HloCustomCallToCuDnnGraph(dnn_support_,
-                                  DynCast<HloCustomCallInstruction>(hlo)));
-    auto workspace = graph.Graph().get_workspace_size();
-    if (workspace != 0) {
-      // rewrite custom call to have correct workspace size
-      VLOG(4) << "Rewriting: " << hlo->ToString();
-      Shape* shape = hlo->mutable_shape();
-      shape->mutable_tuple_shapes(shape->tuple_shapes_size() - 1)
-          ->set_dimensions(0, workspace);
-      MarkAsChanged();
-    }
-    return absl::OkStatus();
-  }
-
- private:
-  se::dnn::DnnSupport& dnn_support_;
-};
-
-}  // namespace
-
-absl::StatusOr<bool> CuDnnWorkspaceRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  XLA_SCOPED_LOGGING_TIMER("cuDNN workspace rewriter");
-  return CuDnnCustomCallVisitor(dnn_support_)
-      .RunOnModule(module, execution_threads);
-}
-
-}  // namespace gpu
-}  // namespace xla
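The essence of the deleted pass is the shape mutation at the end of HandleCustomCall: the trailing u8[] tuple element of the fMHA custom call is resized to the workspace size reported by the cuDNN graph. A standalone sketch of just that mutation, using a hypothetical helper name:

// Sketch: resize the trailing u8[] workspace element of a custom call's result
// tuple to workspace_bytes, mirroring the deleted HandleCustomCall logic.
// ResizeTrailingWorkspace is an illustrative name, not part of this commit.
#include <cstdint>

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/shape.h"

namespace xla {
namespace gpu {

void ResizeTrailingWorkspace(HloInstruction* custom_call,
                             int64_t workspace_bytes) {
  Shape* shape = custom_call->mutable_shape();
  // By convention the workspace buffer is the last element of the result
  // tuple.
  shape->mutable_tuple_shapes(shape->tuple_shapes_size() - 1)
      ->set_dimensions(0, workspace_bytes);
}

}  // namespace gpu
}  // namespace xla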
diff --git a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.h b/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.h
deleted file mode 100644
index de81d6d..0000000
--- a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.h
+++ /dev/null
@@ -1,51 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUDNN_WORKSPACE_REWRITER_H_
-#define XLA_SERVICE_GPU_CUDNN_WORKSPACE_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/stream_executor.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites cuDNN fMHA custom calls to carry the correct workspace size by
-// building and serializing the cuDNN graph so it can be reused later.
-class CuDnnWorkspaceRewriter : public HloModulePass {
- public:
-  explicit CuDnnWorkspaceRewriter(se::StreamExecutor& stream_exec)
-      : dnn_support_(*stream_exec.AsDnn()) {}
-
-  absl::string_view name() const override { return "cudnn-workspace-rewriter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  se::dnn::DnnSupport& dnn_support_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_CUDNN_WORKSPACE_REWRITER_H_
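For context, the deleted pass was constructed from a StreamExecutor and scheduled like any other HLO pass; a minimal usage sketch, where the surrounding pipeline setup is assumed rather than taken from this commit:

// Sketch: scheduling the (now deleted) pass in an HLO pass pipeline.
// stream_exec is assumed to be a valid, non-null StreamExecutor.
#include "xla/service/gpu/cudnn_workspace_rewriter.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/stream_executor/stream_executor.h"

void AddCudnnWorkspaceRewriter(xla::HloPassPipeline& pipeline,
                               stream_executor::StreamExecutor* stream_exec) {
  // The pass queries cuDNN via the executor's DnnSupport to size workspaces.
  pipeline.AddPass<xla::gpu::CuDnnWorkspaceRewriter>(*stream_exec);
}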
diff --git a/third_party/xla/xla/service/gpu/cusolver_rewriter.cc b/third_party/xla/xla/service/gpu/cusolver_rewriter.cc
deleted file mode 100644
index ddfda66..0000000
--- a/third_party/xla/xla/service/gpu/cusolver_rewriter.cc
+++ /dev/null
@@ -1,201 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/cusolver_rewriter.h"
-
-#include <cstdint>
-#include <functional>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/comparison_util.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout_util.h"
-#include "xla/literal.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/cusolver_context.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/blas.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-void SetFortranLayout(Shape* shape) {
-  LayoutUtil::SetToDefaultLayout(shape);
-  int n = shape->mutable_layout()->minor_to_major_size();
-  CHECK_GE(n, 2);
-  std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
-            shape->mutable_layout()->mutable_minor_to_major()->at(1));
-}
-
-absl::StatusOr<HloInstruction*> CreateCholesky(GpuSolverContext* context,
-                                               HloInstruction* operand,
-                                               const CholeskyOptions& options,
-                                               const OpMetadata& metadata) {
-  HloComputation* computation = operand->parent();
-
-  Shape a_shape = operand->shape();
-  int ndim = a_shape.dimensions_size();
-  CHECK_GE(ndim, 2);
-  int64_t n = a_shape.dimensions(ndim - 1);
-
-  std::vector<int64_t> batch_dims(a_shape.dimensions().begin(),
-                                  a_shape.dimensions().end() - 2);
-  std::vector<int64_t> batch_dim_ids(batch_dims.size());
-  absl::c_iota(batch_dim_ids, 0);
-  int64_t batch_size = absl::c_accumulate(batch_dims, 1, std::multiplies<>{});
-
-  // Find the workspace size.
-  se::blas::UpperLower uplo = options.lower() ? se::blas::UpperLower::kLower
-                                              : se::blas::UpperLower::kUpper;
-  int64_t workspace_size;  // Number of elements of size a_shape.element_type()
-  TF_ASSIGN_OR_RETURN(
-      workspace_size,
-      context->PotrfBufferSize(a_shape.element_type(), uplo, n, n, batch_size));
-
-  // TODO(phawkins): Ideally we would relax this constraint. What we actually
-  // want is that:
-  // a) the batch dimensions are major, in no particular order.
-  // b) the two minor dimensions are in fortran (column-major) order.
-
-  SetFortranLayout(&a_shape);
-
-  // This call returns a tuple of (cholesky_result, workspace, info) where:
-  // * cholesky_result is the result of the Cholesky decomposition,
-  // * workspace is temporary scratch memory used by cuSolver.
-  // * info contains the Potrf success/failure status.
-  // Currently we have no meaningful way to report an error, so we simply
-  // discard the success/failure information. Obviously this is suboptimal.
-  Shape info_shape = ShapeUtil::MakeShape(S32, batch_dims);
-  Shape call_shape = ShapeUtil::MakeTupleShape(
-      {a_shape,
-       ShapeUtil::MakeShape(operand->shape().element_type(), {workspace_size}),
-       info_shape});
-
-  HloInstruction* custom_call =
-      computation->AddInstruction(HloInstruction::CreateCustomCall(
-          call_shape, {operand}, kCusolverCholeskyCallTarget, {a_shape}));
-  custom_call->set_metadata(metadata);
-  TF_RETURN_IF_ERROR(custom_call->set_backend_config(options));
-  HloInstruction* out = computation->AddInstruction(
-      HloInstruction::CreateGetTupleElement(a_shape, custom_call, 0));
-  HloInstruction* info = computation->AddInstruction(
-      HloInstruction::CreateGetTupleElement(info_shape, custom_call, 2));
-
-  // If info is non-zero, indicating that the Cholesky decomposition failed,
-  // return an array full of NaNs for the corresponding batch element.
-  HloInstruction* zero = computation->AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
-  HloInstruction* zeros =
-      computation->AddInstruction(HloInstruction::CreateBroadcast(
-          info_shape, zero, /*broadcast_dimensions=*/{}));
-  HloInstruction* ok = computation->AddInstruction(
-      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, batch_dims),
-                                    info, zeros, ComparisonDirection::kEq));
-  ok = computation->AddInstruction(HloInstruction::CreateBroadcast(
-      ShapeUtil::MakeShape(PRED, a_shape.dimensions()), ok,
-      /*broadcast_dimensions=*/batch_dim_ids));
-
-  TF_ASSIGN_OR_RETURN(Literal nan_literal,
-                      LiteralUtil::NanValue(a_shape.element_type()));
-  HloInstruction* nan = computation->AddInstruction(
-      HloInstruction::CreateConstant(std::move(nan_literal)));
-  HloInstruction* nans =
-      computation->AddInstruction(HloInstruction::CreateBroadcast(
-          a_shape, nan, /*broadcast_dimensions=*/{}));
-
-  HloInstruction* select =
-      computation->AddInstruction(HloInstruction::CreateTernary(
-          a_shape, HloOpcode::kSelect, ok, out, nans));
-  return select;
-}
-
-// Tries to rewrite a single Cholesky instruction into a call to cuSolver.
-absl::StatusOr<bool> RunOnInstruction(GpuSolverContext* context,
-                                      HloInstruction* instruction) {
-  if (instruction->opcode() != HloOpcode::kCholesky) {
-    return false;
-  }
-
-  TF_ASSIGN_OR_RETURN(
-      HloInstruction * custom_call,
-      CreateCholesky(context, instruction->mutable_operand(0),
-                     instruction->cholesky_options(), instruction->metadata()));
-
-  VLOG(1) << "Replacing " << instruction->ToString() << " with "
-          << custom_call->ToString();
-
-  TF_RETURN_IF_ERROR(
-      instruction->parent()->ReplaceInstruction(instruction, custom_call));
-  return true;
-}
-
-}  // namespace
-
-// Rewrites the Cholesky instructions in the given computation into calls to
-// cuSolver. Returns true if it made any changes.
-absl::StatusOr<bool> GpusolverRewriter::RunOnComputation(
-    HloComputation* computation) {
-  std::vector<HloInstruction*> cusolver_calls;
-  for (auto* hlo : computation->instructions()) {
-    if (hlo->opcode() == HloOpcode::kCholesky) {
-      cusolver_calls.push_back(hlo);
-    }
-  }
-
-  if (cusolver_calls.empty()) {
-    return false;
-  }
-
-  TF_ASSIGN_OR_RETURN(GpuSolverContext context, GpuSolverContext::Create());
-
-  bool changed = false;
-  for (HloInstruction* instruction : cusolver_calls) {
-    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(&context, instruction));
-    changed |= result;
-  }
-  return changed;
-}
-
-GpusolverRewriter::GpusolverRewriter() = default;
-
-absl::StatusOr<bool> GpusolverRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
-    changed |= result;
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
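As a small worked example of the shape bookkeeping in CreateCholesky above: for an operand of shape f32[3,4,8,8], the batch dimensions are {3,4}, batch_size is 12, and n is 8. A self-contained sketch of that computation, with illustrative names:

// Sketch: the batch-dimension bookkeeping performed by CreateCholesky.
// For dims = {3, 4, 8, 8} this yields batch_dims = {3, 4}, batch_size = 12
// and n = 8. Requires dims.size() >= 2, as CHECKed in the deleted code.
#include <cstdint>
#include <functional>
#include <vector>

#include "absl/algorithm/container.h"

struct CholeskyDims {
  std::vector<int64_t> batch_dims;
  int64_t batch_size;
  int64_t n;
};

CholeskyDims SplitCholeskyDims(const std::vector<int64_t>& dims) {
  CholeskyDims out;
  // Everything except the two minor (matrix) dimensions is a batch dimension.
  out.batch_dims.assign(dims.begin(), dims.end() - 2);
  out.batch_size =
      absl::c_accumulate(out.batch_dims, int64_t{1}, std::multiplies<>{});
  out.n = dims.back();
  return out;
}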
diff --git a/third_party/xla/xla/service/gpu/cusolver_rewriter.h b/third_party/xla/xla/service/gpu/cusolver_rewriter.h
deleted file mode 100644
index fd1d84d..0000000
--- a/third_party/xla/xla/service/gpu/cusolver_rewriter.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_
-#define XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites Cholesky calls into CustomCall HLOs that call into cuSolver.
-class GpusolverRewriter : public HloModulePass {
- public:
-  GpusolverRewriter();
-  absl::string_view name() const override { return "gpusolver-rewriter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.cc
deleted file mode 100644
index d5114bc..0000000
--- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.cc
+++ /dev/null
@@ -1,220 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/custom_kernel_fusion_autotuner.h"
-
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <tuple>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/time/time.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/kernels/custom_kernel.h"
-#include "xla/service/gpu/kernels/custom_kernel_fusion.h"
-#include "xla/service/shaped_buffer.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/stream_executor/stream_executor_memory_allocator.h"
-#include "xla/tools/hlo_decomposer.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-absl::StatusOr<std::unique_ptr<HloModule>> ExtractFusionModule(
-    HloInstruction* fusion_instruction, int64_t kernel_index) {
-  std::unique_ptr<HloModule> hlo_module =
-      ExtractInstructionIntoNewModule(*fusion_instruction);
-
-  HloInstruction* instruction =
-      hlo_module->entry_computation()->root_instruction();
-  GpuBackendConfig gpu_config =
-      instruction->backend_config<GpuBackendConfig>().value();
-  gpu_config.mutable_fusion_backend_config()
-      ->mutable_custom_fusion_config()
-      ->set_kernel_index(kernel_index);
-  TF_RETURN_IF_ERROR(instruction->set_backend_config(gpu_config));
-
-  return hlo_module;
-}
-
-absl::StatusOr<std::vector<std::tuple<int, absl::Duration>>> ProfileKernels(
-    std::vector<CustomKernel>& kernels, HloInstruction* fusion_instruction,
-    AutotunerCompileUtil& compile_util, const AutotuneConfig& autotune_config,
-    const DebugOptions& debug_options) {
-  se::StreamExecutor* stream_exec = autotune_config.GetExecutor();
-  std::vector<std::tuple<int, absl::Duration>> results;
-  for (int i = 0; i < kernels.size(); ++i) {
-    TF_ASSIGN_OR_RETURN(absl::StatusOr<std::unique_ptr<Executable>> executable,
-                        compile_util.Compile([&](const DebugOptions& opt) {
-                          return ExtractFusionModule(fusion_instruction, i);
-                        }));
-
-    se::DeviceMemoryAllocator* allocator = autotune_config.GetAllocator();
-    std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator;
-    if (allocator == nullptr) {
-      owned_allocator =
-          std::make_unique<se::StreamExecutorMemoryAllocator>(stream_exec);
-      allocator = owned_allocator.get();
-    }
-    TF_ASSIGN_OR_RETURN(se::Stream* const stream, autotune_config.GetStream());
-
-    TF_ASSIGN_OR_RETURN(auto rz_buffers,
-                        RedzoneBuffers::FromInstruction(
-                            *fusion_instruction, autotune_config, debug_options,
-                            RedzoneBuffers::kAllInputs));
-
-    std::optional<ScopedShapedBuffer> reference_buffer;
-    std::optional<AutotunerCompileUtil::ProfilingOutput> profiling_output;
-    TF_ASSIGN_OR_RETURN(profiling_output, compile_util.ProfileExecutable(
-                                              executable->get(), stream,
-                                              rz_buffers.input_buffers(),
-                                              rz_buffers.input_shapes()));
-    results.push_back({i, profiling_output->duration});
-  }
-  return results;
-}
-
-absl::StatusOr<int> FindFastestKernel(
-    const std::vector<std::tuple<int, absl::Duration>>& results) {
-  auto iter = absl::c_min_element(
-      results, [](const std::tuple<int, absl::Duration>& lhs,
-                  const std::tuple<int, absl::Duration>& rhs) {
-        return std::get<1>(lhs) < std::get<1>(rhs);
-      });
-  if (iter == results.end()) {
-    return absl::InternalError("Failed to find fastest kernel.");
-  }
-  return std::get<0>(*iter);
-}
-
-absl::Status UpdateFusionInstructionKernelIndex(
-    HloInstruction* fusion_instruction, int kernel_index) {
-  GpuBackendConfig gpu_config =
-      fusion_instruction->backend_config<GpuBackendConfig>().value();
-  gpu_config.mutable_fusion_backend_config()
-      ->mutable_custom_fusion_config()
-      ->set_kernel_index(kernel_index);
-  TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(gpu_config));
-
-  return absl::OkStatus();
-}
-
-absl::StatusOr<std::vector<CustomKernel>> LoadKernels(
-    const HloInstruction* fusion_instruction,
-    const AutotuneConfig& autotune_config) {
-  auto config = fusion_instruction->backend_config<GpuBackendConfig>()
-                    ->fusion_backend_config()
-                    .custom_fusion_config();
-  auto* registry = CustomKernelFusionRegistry::Default();
-  auto* custom_kernel_fusion = registry->Lookup(config.name());
-
-  // If the custom fusion is not found, some of the required build targets may
-  // not be statically linked into the binary.
-  if (custom_kernel_fusion == nullptr) {
-    return absl::InternalError(
-        absl::StrCat("Custom kernel fusion ", config.name(),
-                     " not found in a default registry."));
-  }
-
-  se::StreamExecutor* stream_exec = autotune_config.GetExecutor();
-  if (!stream_exec->SynchronizeAllActivity()) {
-    return Internal("Failed to synchronize GPU for autotuning.");
-  }
-  se::DeviceDescription device_description =
-      stream_exec->GetDeviceDescription();
-
-  // Load custom kernels that can implement a fusion computation.
-  TF_ASSIGN_OR_RETURN(
-      std::vector<CustomKernel> kernels,
-      custom_kernel_fusion->LoadKernels(
-          device_description,
-          fusion_instruction->fused_instructions_computation()));
-
-  return kernels;
-}
-
-absl::StatusOr<bool> AutotuneCustomKernelFusion(
-    HloInstruction* fusion_instruction, const AutotuneConfig& autotune_config,
-    AutotunerCompileUtil& compile_util, const DebugOptions& debug_options) {
-  int previous_kernel_index =
-      fusion_instruction->backend_config<GpuBackendConfig>()
-          ->fusion_backend_config()
-          .custom_fusion_config()
-          .kernel_index();
-
-  TF_ASSIGN_OR_RETURN(std::vector<CustomKernel> kernels,
-                      LoadKernels(fusion_instruction, autotune_config));
-
-  std::vector<std::tuple<int, absl::Duration>> results;
-  TF_ASSIGN_OR_RETURN(results,
-                      ProfileKernels(kernels, fusion_instruction, compile_util,
-                                     autotune_config, debug_options));
-
-  TF_ASSIGN_OR_RETURN(int fastest_kernel_index, FindFastestKernel(results));
-
-  TF_RETURN_IF_ERROR(UpdateFusionInstructionKernelIndex(fusion_instruction,
-                                                        fastest_kernel_index));
-
-  return previous_kernel_index != fastest_kernel_index;
-}
-}  // namespace
-
-absl::StatusOr<bool> CustomKernelFusionAutotuner::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  const DebugOptions& debug_options = module->config().debug_options();
-  TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> compile_util,
-                      AutotunerCompileUtil::Create(config_, debug_options));
-  TF_RET_CHECK(compile_util.has_value());
-
-  bool hlo_changed = false;
-  for (const HloComputation* computation : module->computations()) {
-    if (computation->IsFusionComputation()) {
-      TF_ASSIGN_OR_RETURN(
-          bool instruction_changed,
-          AutotuneCustomKernelFusion(computation->FusionInstruction(), config_,
-                                     compile_util.value(), debug_options));
-      if (instruction_changed) {
-        hlo_changed = true;
-      }
-    }
-  }
-
-  return hlo_changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
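The kernel selection in the deleted FindFastestKernel is a plain min-by-duration over (kernel index, duration) pairs; a self-contained sketch of the same selection, with an illustrative name and an optional instead of a status:

// Sketch: pick the kernel index with the smallest measured duration,
// mirroring the deleted FindFastestKernel but returning nullopt on empty
// input instead of an error status.
#include <optional>
#include <tuple>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/time/time.h"

std::optional<int> FastestKernelIndex(
    const std::vector<std::tuple<int, absl::Duration>>& results) {
  auto iter = absl::c_min_element(
      results, [](const std::tuple<int, absl::Duration>& lhs,
                  const std::tuple<int, absl::Duration>& rhs) {
        return std::get<1>(lhs) < std::get<1>(rhs);
      });
  if (iter == results.end()) return std::nullopt;  // No profiling results.
  return std::get<0>(*iter);
}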
diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.h b/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.h
deleted file mode 100644
index f6cd0c0..0000000
--- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_
-#define XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/xla.pb.h"
-
-namespace xla {
-namespace gpu {
-
-// Find best custom kernel for custom kernel fusions.
-class CustomKernelFusionAutotuner : public HloModulePass {
- public:
-  explicit CustomKernelFusionAutotuner(const AutotuneConfig& config)
-      : config_(config) {}
-
-  absl::string_view name() const override {
-    return "custom_kernel-fusion-autotuner";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  const AutotuneConfig config_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_
diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner_test.cc
deleted file mode 100644
index aa6c1d2..0000000
--- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner_test.cc
+++ /dev/null
@@ -1,112 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/custom_kernel_fusion_autotuner.h"
-
-#include <memory>
-#include <string>
-#include <utility>
-
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla.pb.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class CustomKernelFusionAutotunerTest : public HloTestBase {
- public:
-  CustomKernelFusionAutotunerTest()
-      : HloTestBase(/*verifier_layout_sensitive=*/false,
-                    /*allow_mixed_precision_in_hlo_verifier=*/true) {}
-
-  void SetUp() override { HloTestBase::SetUp(); }
-
-  void TearDown() override { HloTestBase::TearDown(); }
-};
-
-TEST_F(CustomKernelFusionAutotunerTest,
-       CustomKernelFusionAutotunerPassSucceeds) {
-  const std::string hlo_string = R"(
-    HloModule extracted
-
-    cutlass_gemm {
-      p0 = f32[15,19]{1,0} parameter(0)
-      p1 = f32[19,17]{1,0} parameter(1)
-      ROOT r = f32[15, 17]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-    }
-
-    ENTRY region_198.14436 {
-      p.0 = f32[15,19]{1,0} parameter(0)
-      p.1 = f32[19,17]{1,0} parameter(1)
-      ROOT cutlass_gemm = f32[15,17]{1,0} fusion(p.0, p.1), kind=kCustom, calls=cutlass_gemm, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm","kernel_index":0}},"force_earliest_schedule":false}
-    }
-  )";
-  std::unique_ptr<HloModule> hlo_module =
-      ParseAndReturnVerifiedModule(hlo_string).value();
-
-  HloPassPipeline pipeline("custom_kernel_fusion_autotuner");
-  DebugOptions debug_options;
-  AutotuneConfig autotune_config =
-      AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
-                                  backend().memory_allocator()},
-                     debug_options};
-  pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);
-  ASSERT_TRUE(pipeline.Run(hlo_module.get()).ok());
-}
-
-TEST_F(CustomKernelFusionAutotunerTest,
-       CustomKernelFusionAutotunerPassUpdatesKernelIndex) {
-  const std::string hlo_string = R"(
-    HloModule extracted
-
-    cutlass_gemm {
-      p0 = f32[15,19]{1,0} parameter(0)
-      p1 = f32[19,17]{1,0} parameter(1)
-      ROOT r = f32[15, 17]{1,0} dot(p0, p1), lhs_contracting_dims={1},
-      rhs_contracting_dims={0}
-    }
-
-    ENTRY region_198.14436 {
-      p.0 = f32[15,19]{1,0} parameter(0)
-      p.1 = f32[19,17]{1,0} parameter(1)
-      ROOT cutlass_gemm = f32[15,17]{1,0} fusion(p.0, p.1), kind=kCustom,
-      calls=cutlass_gemm,
-      backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm","kernel_index":-1}},"force_earliest_schedule":false}
-    }
-  )";
-
-  HloPassPipeline pipeline("custom_kernel_fusion_autotuner");
-  DebugOptions debug_options;
-  AutotuneConfig autotune_config =
-      AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
-                                  backend().memory_allocator()},
-                     debug_options};
-  pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);
-
-  std::string expected = R"(
-    CHECK: "kernel_index":0
-  )";
-  RunAndFilecheckHloRewrite(hlo_string, std::move(pipeline), expected);
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.cc
deleted file mode 100644
index 814ccf0..0000000
--- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.cc
+++ /dev/null
@@ -1,240 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/custom_kernel_fusion_rewriter.h"
-
-#include <cstdint>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-
-CustomKernelFusionRewriter::CustomKernelFusionRewriter(
-    const se::DeviceDescription* device,
-    const CustomKernelFusionPatternRegistry* patterns)
-    : device_(device), patterns_(patterns) {}
-
-// Returns the set of instructions that have users outside of the matched
-// pattern and therefore need a replacement applied after the new custom
-// fusion instruction is built. Only the root instruction may have external
-// users without requiring a replacement, as the fusion itself replaces it.
-// If any other instruction has external users but no replacement, returns an
-// empty optional.
-static std::optional<absl::flat_hash_set<HloInstruction*>>
-GetPatternReplacements(const CustomKernelFusionPattern::Match& match) {
-  absl::flat_hash_set<HloInstruction*> requires_replacement;
-  absl::flat_hash_set<HloInstruction*> instructions_set(
-      match.instructions().begin(), match.instructions().end());
-
-  for (HloInstruction* instr : match.instructions()) {
-    for (HloInstruction* user : instr->users()) {
-      if (instr == match.root() || instructions_set.contains(user)) continue;
-
-      if (match.HasReplacement(instr)) {
-        requires_replacement.insert(instr);
-        continue;
-      }
-
-      VLOG(3) << "Custom kernel fusion intermediate result " << instr->name()
-              << " has users outside of a matched pattern: " << user->name();
-      return std::nullopt;
-    }
-  }
-
-  return requires_replacement;
-}
-
-// Returns the instructions that have to become custom kernel fusion
-// parameters, i.e. operands of the matched pattern that are defined outside
-// of it.
-static absl::InlinedVector<HloInstruction*, 4> GetPatternCaptures(
-    const CustomKernelFusionPattern::Match& match) {
-  absl::InlinedVector<HloInstruction*, 4> captures;
-
-  absl::flat_hash_set<HloInstruction*> instructions_set(
-      match.instructions().begin(), match.instructions().end());
-
-  for (HloInstruction* instr : match.instructions()) {
-    for (HloInstruction* operand : instr->operands()) {
-      if (!instructions_set.contains(operand) &&
-          absl::c_find(captures, operand) == captures.end()) {
-        captures.emplace_back(operand);
-      }
-    }
-  }
-
-  return captures;
-}
-
-// Creates custom kernel fusion computation and moves all matched instructions
-// into it.
-static absl::StatusOr<HloComputation*> CreateFusionBody(
-    HloModule* module, const CustomKernelFusionPattern::Match& match,
-    absl::Span<HloInstruction* const> captures) {
-  HloComputation::Builder builder(match.config().name());
-
-  // A mapping from original instructions to instructions in the fusion body.
-  absl::flat_hash_map<const HloInstruction*, HloInstruction*> instr_mapping;
-
-  auto mapped_operands = [&](HloInstruction* instr) {
-    absl::InlinedVector<HloInstruction*, 4> operands;
-    for (HloInstruction* operand : instr->operands()) {
-      operands.push_back(instr_mapping.at(operand));
-    }
-    return operands;
-  };
-
-  // For every captured value create a parameter instruction in the computation
-  // body and set up instruction mapping.
-  for (const HloInstruction* capture : captures) {
-    int64_t index = instr_mapping.size();
-    instr_mapping[capture] =
-        builder.AddInstruction(HloInstruction::CreateParameter(
-            index, capture->shape(), absl::StrCat("p", index)));
-  }
-
-  // TODO(ezhulenev): Instructions in the pattern must be topologically sorted,
-  // otherwise we'll get a crash! Figure out how to do it!
-  for (HloInstruction* instr : match.instructions()) {
-    instr_mapping[instr] = builder.AddInstruction(
-        instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
-  }
-
-  HloInstruction* root = builder.last_added_instruction();
-
-  // If the custom kernel fusion requires a workspace, add a custom call that
-  // allocates the workspace and return a tuple of the "real" result and the
-  // workspace.
-  if (match.workspace_size_bytes() > 0) {
-    auto workspace_shape =
-        ShapeUtil::MakeShape(PrimitiveType::U8, {match.workspace_size_bytes()});
-    HloInstruction* workspace =
-        builder.AddInstruction(HloInstruction::CreateCustomCall(
-            workspace_shape, {}, CustomKernelFusionPattern::kWorkspace, "",
-            CustomCallApiVersion::API_VERSION_TYPED_FFI));
-    builder.AddInstruction(HloInstruction::CreateTuple({root, workspace}));
-  }
-
-  return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false);
-}
-
-static absl::StatusOr<HloInstruction*> CreateFusionInstruction(
-    HloModule* module, const CustomKernelFusionPattern::Match& match,
-    absl::Span<HloInstruction* const> captures, HloComputation* body) {
-  // We'll be replacing the root operation of a custom kernel fusion with a
-  // fusion instruction that calls the outlined fusion computation.
-  HloInstruction* root = match.root();
-  HloComputation* parent = root->parent();
-
-  // Add a fusion operation calling outlined fusion computation.
-  HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion(
-      body->root_instruction()->shape(), HloInstruction::FusionKind::kCustom,
-      captures, body));
-  module->SetAndUniquifyInstrName(fusion, match.config().name());
-
-  // Set backends config to a matched custom fusion config.
-  GpuBackendConfig gpu_config;
-  FusionBackendConfig& backend_config =
-      *gpu_config.mutable_fusion_backend_config();
-  backend_config.set_kind("__custom_fusion");
-  *backend_config.mutable_custom_fusion_config() = match.config();
-  backend_config.mutable_custom_fusion_config()->set_kernel_index(0);
-  TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config)));
-
-  // If we don't have a workspace we can return the constructed fusion
-  // instruction directly.
-  if (match.workspace_size_bytes() == 0) return fusion;
-
-  // Otherwise we have to return the tuple element corresponding to the
-  // original result.
-  return parent->AddInstruction(
-      HloInstruction::CreateGetTupleElement(fusion, 0));
-}
-
-absl::StatusOr<bool> CustomKernelFusionRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  std::vector<CustomKernelFusionPattern::Match> matches;
-
-  // Collect all potential custom fusion matches in the module.
-  for (HloComputation* computation : module->computations()) {
-    for (HloInstruction* instr : computation->instructions()) {
-      auto matched = patterns_->Match(*device_, instr);
-      matches.insert(matches.end(), matched.begin(), matched.end());
-    }
-  }
-
-  if (matches.empty()) return false;
-
-  for (const CustomKernelFusionPattern::Match& match : matches) {
-    VLOG(2) << "Matched custom kernel fusion " << match.config().name()
-            << "; root instruction: " << match.instructions().back()->name();
-
-    auto replacements = GetPatternReplacements(match);
-    if (!replacements.has_value()) continue;
-
-    auto captures = GetPatternCaptures(match);
-
-    TF_ASSIGN_OR_RETURN(HloComputation * fusion_body,
-                        CreateFusionBody(module, match, captures));
-    TF_ASSIGN_OR_RETURN(
-        HloInstruction * fusion,
-        CreateFusionInstruction(module, match, captures, fusion_body));
-
-    VLOG(2) << "Added a fusion instruction: " << fusion->name()
-            << " for custom kernel fusion " << match.config().name()
-            << " (instruction count = " << match.instructions().size() << ")";
-
-    for (HloInstruction* instr : *replacements) {
-      VLOG(2) << "Replace matched instruction: " << instr->name()
-              << " with a pattern replacement";
-
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * replacement,
-          match.BuildReplacement(instr, Cast<HloFusionInstruction>(fusion)));
-
-      TF_RETURN_IF_ERROR(
-          instr->ReplaceAllUsesWith(replacement, match.config().name()));
-
-      VLOG(2) << "Replaced instruction: " << instr->name()
-              << " with: " << replacement->name();
-    }
-
-    VLOG(2) << "Replace custom kernel fusion root instruction "
-            << match.root()->name() << " with " << fusion->name();
-    HloComputation* parent = match.root()->parent();
-    TF_RETURN_IF_ERROR(parent->ReplaceInstruction(match.root(), fusion));
-  }
-
-  return true;
-}
-
-}  // namespace xla::gpu
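As context for the rewrite above, the capture logic in GetPatternCaptures boils down to collecting, in first-use order, every operand of a matched instruction that is defined outside the matched set. A minimal standalone sketch of the same idea with plain STL containers (the `Node` type is hypothetical, standing in for HloInstruction):

    #include <algorithm>
    #include <unordered_set>
    #include <vector>

    // Hypothetical stand-in for HloInstruction; illustration only.
    struct Node {
      std::vector<Node*> operands;
    };

    // Collects operands of `matched` nodes that are not themselves matched,
    // deduplicated and kept in first-use order (these become fusion parameters).
    std::vector<Node*> GetCaptures(const std::vector<Node*>& matched) {
      std::unordered_set<Node*> matched_set(matched.begin(), matched.end());
      std::vector<Node*> captures;
      for (Node* node : matched) {
        for (Node* operand : node->operands) {
          if (matched_set.count(operand) == 0 &&
              std::find(captures.begin(), captures.end(), operand) ==
                  captures.end()) {
            captures.push_back(operand);
          }
        }
      }
      return captures;
    }
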
diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.h b/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.h
deleted file mode 100644
index cb19d91..0000000
--- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.h
+++ /dev/null
@@ -1,86 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_REWRITER_H_
-#define XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla::gpu {
-
-// Pattern matches HLO instructions to custom kernel fusions (hand-written CUDA
-// C++ kernels, e.g. custom GEMMs implemented with CUTLASS) and rewrites them
-// into fusion instructions and fusion computations.
-//
-// Example: pattern matching dot operation into CUTLASS gemm
-//
-//  ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
-//    %p0 = f16[15,19]{1,0} parameter(0)
-//    %p1 = f16[19,17]{1,0} parameter(1)
-//    ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
-//      lhs_contracting_dims={1}, rhs_contracting_dims={0}
-//  }
-//
-// After the pass:
-//
-//  %cutlass_gemm (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
-//    %p0 = f16[15,19]{1,0} parameter(0)
-//    %p1 = f16[19,17]{1,0} parameter(1)
-//    ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
-//      lhs_contracting_dims={1}, rhs_contracting_dims={0}
-//  }
-//
-//  ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
-//    %p0 = f16[15,19]{1,0} parameter(0)
-//    %p1 = f16[19,17]{1,0} parameter(1)
-//    ROOT %r = f16[15,17]{1,0} fusion(%p0, %p1), kind=kCustom,
-//      calls=%cutlass_gemm,
-//      backend_config={kind: "__custom_fusion",
-//                      custom_fusion_config: {"name":"cutlass_gemm"}}
-//  }
-//
-class CustomKernelFusionRewriter : public HloModulePass {
- public:
-  explicit CustomKernelFusionRewriter(
-      const se::DeviceDescription* device,
-      const CustomKernelFusionPatternRegistry* patterns =
-          CustomKernelFusionPatternRegistry::Default());
-
-  absl::string_view name() const override {
-    return "custom-kernel-fusion-rewriter";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  const se::DeviceDescription* device_;
-  const CustomKernelFusionPatternRegistry* patterns_;
-};
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_REWRITER_H_
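The class above is a standard HloModulePass, so adding it to a compilation pipeline is the usual AddPass call. A minimal sketch, assuming the pre-move include path shown in this diff and a device description that outlives the pipeline (the helper name is illustrative only):

    #include "xla/service/gpu/custom_kernel_fusion_rewriter.h"
    #include "xla/service/hlo_pass_pipeline.h"
    #include "xla/stream_executor/device_description.h"

    // Sketch only: uses the default pattern registry; a custom registry can be
    // passed as the rewriter's second constructor argument.
    void AddCustomKernelFusionRewriter(
        xla::HloPassPipeline& pipeline,
        const stream_executor::DeviceDescription& device_description) {
      pipeline.AddPass<xla::gpu::CustomKernelFusionRewriter>(&device_description);
    }
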
diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc
deleted file mode 100644
index f2c824c..0000000
--- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc
+++ /dev/null
@@ -1,138 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/custom_kernel_fusion_rewriter.h"
-
-#include <cstdint>
-#include <optional>
-#include <utility>
-
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
-
-namespace xla::gpu {
-
-//===----------------------------------------------------------------------===//
-// Simple pattern matchers for testing the custom kernel fusion rewriter.
-//===----------------------------------------------------------------------===//
-
-struct SimpleGemmPattern : public CustomKernelFusionPattern {
-  explicit SimpleGemmPattern(int64_t workspace = 0) : workspace(workspace) {}
-
-  std::optional<Match> TryMatch(const se::DeviceDescription& device,
-                                HloInstruction* instr) const override {
-    if (auto* dot = DynCast<HloDotInstruction>(instr)) {
-      CustomFusionConfig config;
-      config.set_name("simple_gemm");
-      return Match{config, {instr}, workspace};
-    }
-    return std::nullopt;
-  }
-
-  int64_t workspace;
-};
-
-//===----------------------------------------------------------------------===//
-
-class CustomKernelFusionRewriterTest : public HloTestBase {};
-
-TEST_F(CustomKernelFusionRewriterTest, SimpleGemm) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
-      %p0 = f16[15,19]{1,0} parameter(0)
-      %p1 = f16[19,17]{1,0} parameter(1)
-      ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
-        lhs_contracting_dims={1}, rhs_contracting_dims={0}
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK: %simple_gemm {{.*}} {
-    ; CHECK:   [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0)
-    ; CHECK:   [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1)
-    ; CHECK:   ROOT [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]),
-    ; CHECK:     lhs_contracting_dims={1}, rhs_contracting_dims={0}
-    ; CHECK: }
-
-    ; CHECK: ENTRY %main {{.*}} {
-    ; CHECK:   ROOT [[FUSION:%[^ ]+]] = f16[15,17]{1,0} fusion
-    ; CHECK:     kind=kCustom, calls=%simple_gemm,
-    ; CHECK:     backend_config={
-    ; CHECK:       "kind":"__custom_fusion",
-    ; CHECK:       "custom_fusion_config":{"name":"simple_gemm","kernel_index":0}
-    ; CHECK:     }
-    ; CHECK: }
-  )";
-
-  CustomKernelFusionPatternRegistry patterns;
-  patterns.Emplace<SimpleGemmPattern>();
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  CustomKernelFusionRewriter pass(&device, &patterns);
-  RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
-}
-
-TEST_F(CustomKernelFusionRewriterTest, SimpleGemmWithWorkspace) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
-      %p0 = f16[15,19]{1,0} parameter(0)
-      %p1 = f16[19,17]{1,0} parameter(1)
-      ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
-        lhs_contracting_dims={1}, rhs_contracting_dims={0}
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK: %simple_gemm {{.*}} {
-    ; CHECK:   [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0)
-    ; CHECK:   [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1)
-    ; CHECK:   [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]),
-    ; CHECK:     lhs_contracting_dims={1}, rhs_contracting_dims={0}
-    ; CHECK:   [[WORKSPACE:%[^ ]+]] = u8[1024]{0} custom-call(),
-    ; CHECK:     custom_call_target="__custom_kernel_fusion$workspace"
-    ; CHECK:   ROOT [[TUPLE:%[^ ]+]] = (f16[15,17]{1,0}, u8[1024]{0})
-    ; CHECK:     tuple([[DOT]], [[WORKSPACE]])
-    ; CHECK: }
-
-    ; CHECK: ENTRY %main {{.*}} {
-    ; CHECK:   [[FUSION:%[^ ]+]] = (f16[15,17]{1,0}, u8[1024]{0}) fusion
-    ; CHECK:     kind=kCustom, calls=%simple_gemm,
-    ; CHECK:     backend_config={
-    ; CHECK:       "kind":"__custom_fusion",
-    ; CHECK:       "custom_fusion_config":{"name":"simple_gemm","kernel_index":0}
-    ; CHECK:     }
-    ; CHECK:   ROOT {{.*}} get-tuple-element([[FUSION]]), index=0
-    ; CHECK: }
-  )";
-
-  CustomKernelFusionPatternRegistry patterns;
-  patterns.Emplace<SimpleGemmPattern>(1024);
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  CustomKernelFusionRewriter pass(&device, &patterns);
-  RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
-}
-
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/determinism_test.cc b/third_party/xla/xla/service/gpu/determinism_test.cc
index 93c5b15..45b9704 100644
--- a/third_party/xla/xla/service/gpu/determinism_test.cc
+++ b/third_party/xla/xla/service/gpu/determinism_test.cc
@@ -22,7 +22,7 @@
 #include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/literal.h"
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
 #include "xla/service/gpu/tests/gpu_codegen_test.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/gpu/gpu_timer.h"
diff --git a/third_party/xla/xla/service/gpu/dot_dimension_sorter.cc b/third_party/xla/xla/service/gpu/dot_dimension_sorter.cc
deleted file mode 100644
index 38920ee..0000000
--- a/third_party/xla/xla/service/gpu/dot_dimension_sorter.cc
+++ /dev/null
@@ -1,136 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dot_dimension_sorter.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout_util.h"
-#include "xla/permutation_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-// Sort contracting dimensions of a dot() instruction preserving lhs-rhs pairs.
-absl::Status SortDotDimensions(HloDotInstruction* dot) {
-  const DotDimensionNumbers& dims = dot->dot_dimension_numbers();
-  DotDimensionNumbers new_dims(dims);
-  new_dims.clear_lhs_contracting_dimensions();
-  new_dims.clear_rhs_contracting_dimensions();
-  const bool sort_by_lhs =
-      DistinctNumbersAreConsecutiveIfSorted(dims.lhs_contracting_dimensions());
-  // Sort both lhs and rhs by sort_key, using the fact that sort_key is
-  // guaranteed to contain distinct numbers that are consecutive once sorted.
-  const absl::Span<const int64_t>& sort_key =
-      sort_by_lhs ? dims.lhs_contracting_dimensions()
-                  : dims.rhs_contracting_dimensions();
-  std::vector<int64_t> permutation;
-  for (const int64_t a : sort_key) {
-    permutation.push_back(a - *absl::c_min_element(sort_key));
-  }
-  const std::vector<int64_t> sorted_lhs =
-      Permute(dims.lhs_contracting_dimensions(), permutation);
-  *new_dims.mutable_lhs_contracting_dimensions() = {sorted_lhs.begin(),
-                                                    sorted_lhs.end()};
-  const std::vector<int64_t> sorted_rhs =
-      Permute(dims.rhs_contracting_dimensions(), permutation);
-  *new_dims.mutable_rhs_contracting_dimensions() = {sorted_rhs.begin(),
-                                                    sorted_rhs.end()};
-  std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
-      dot->shape(), dot->mutable_operand(0), dot->mutable_operand(1), new_dims,
-      dot->precision_config(), {dot->sparsity().begin(), dot->sparsity().end()},
-      absl::MakeSpan(dot->operands()).subspan(HloDotInstruction::kOperands));
-  dot->SetupDerivedInstruction(new_dot.get());
-
-  VLOG(3) << "Sorted dot() dimensions:\n"
-          << "\t before: " << dot->ToString() << "\n"
-          << "\t after: " << new_dot->ToString();
-  return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
-}
-
-}  // namespace
-
-absl::StatusOr<bool> DotDimensionSorter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  std::vector<HloInstruction*> dots_to_process;
-  for (const HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    for (HloInstruction* instr : computation->instructions()) {
-      if (instr->opcode() != HloOpcode::kDot) {
-        continue;
-      }
-      // TODO(b/265688934): should non-default layouts be expected here at all?
-      if ((instr->operand(0)->shape().has_layout() &&
-           !LayoutUtil::IsMonotonicWithDim0Major(
-               instr->operand(0)->shape().layout())) ||
-          (instr->operand(1)->shape().has_layout() &&
-           !LayoutUtil::IsMonotonicWithDim0Major(
-               instr->operand(1)->shape().layout()))) {
-        continue;
-      }
-      const DotDimensionNumbers& dims = instr->dot_dimension_numbers();
-      if (dims.lhs_contracting_dimensions_size() == 0) {
-        continue;
-      }
-      const bool cons_lhs = DistinctNumbersAreConsecutiveIfSorted(
-          dims.lhs_contracting_dimensions());
-      const bool cons_rhs = DistinctNumbersAreConsecutiveIfSorted(
-          dims.rhs_contracting_dimensions());
-      const bool sorted_lhs =
-          absl::c_is_sorted(dims.lhs_contracting_dimensions());
-      const bool sorted_rhs =
-          absl::c_is_sorted(dims.rhs_contracting_dimensions());
-      // The side to be sorted has to be consecutive and not sorted yet;
-      // the other side should not get worsened.
-      // TODO(b/265688934): we may still want to change which one is sorted
-      // if this reduces the amount of transposed data.
-      if ((cons_lhs && !sorted_lhs && !cons_rhs) ||
-          (cons_rhs && !sorted_rhs && !cons_lhs) ||
-          (cons_lhs && !sorted_lhs && cons_rhs && !sorted_rhs)) {
-        dots_to_process.push_back(instr);
-      }
-    }
-  }
-  if (dots_to_process.empty()) {
-    return false;
-  }
-  for (HloInstruction* dot : dots_to_process) {
-    TF_RETURN_IF_ERROR(SortDotDimensions(Cast<HloDotInstruction>(dot)));
-  }
-  return true;
-}
-
-}  // namespace gpu
-}  // namespace xla
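The permutation built in SortDotDimensions above relies on the sort key containing distinct numbers that are consecutive once sorted, so subtracting the minimum turns each key value into a target position that can be applied to both lhs and rhs contracting dimensions. A standalone sketch of that step, using plain STL in place of xla::Permute (names are illustrative only):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Illustration only (not the real xla::Permute helper): `key` holds distinct
    // values that are consecutive once sorted, so key[i] - min(key) is the target
    // position of element i, and the same placement is applied to lhs and rhs.
    std::vector<int64_t> PermuteLikeSortedKey(const std::vector<int64_t>& key,
                                              const std::vector<int64_t>& dims) {
      const int64_t min = *std::min_element(key.begin(), key.end());
      std::vector<int64_t> out(dims.size());
      for (size_t i = 0; i < key.size(); ++i) {
        out[key[i] - min] = dims[i];
      }
      return out;
    }

    // Example from the header comment below: with key = lhs = {3,2},
    // PermuteLikeSortedKey({3,2}, {3,2}) == {2,3} and
    // PermuteLikeSortedKey({3,2}, {2,1}) == {1,2}.
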
diff --git a/third_party/xla/xla/service/gpu/dot_dimension_sorter.h b/third_party/xla/xla/service/gpu/dot_dimension_sorter.h
deleted file mode 100644
index 5eadeb1..0000000
--- a/third_party/xla/xla/service/gpu/dot_dimension_sorter.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_DOT_DIMENSION_SORTER_H_
-#define XLA_SERVICE_GPU_DOT_DIMENSION_SORTER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Sorts contracting dimensions of dot() operands when this reduces the
-// number of transposes. Example:
-// dot(p0, p1), lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}  ->
-// dot(p0, p1), lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
-// The first case gets transposes inserted by dot_decomposer, the second one
-// does not and thus is generally more efficient.
-
-// TODO(b/265688934): do the same for batch dimensions?
-
-class DotDimensionSorter : public HloModulePass {
- public:
-  absl::string_view name() const override { return "dot_dimension_sorter"; }
-
-  // Run the pass on computations in 'module'.
-  // Returns whether the 'module' was changed.
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_DOT_DIMENSION_SORTER_H_
diff --git a/third_party/xla/xla/service/gpu/dot_dimension_sorter_test.cc b/third_party/xla/xla/service/gpu/dot_dimension_sorter_test.cc
deleted file mode 100644
index fedd1ea..0000000
--- a/third_party/xla/xla/service/gpu/dot_dimension_sorter_test.cc
+++ /dev/null
@@ -1,191 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dot_dimension_sorter.h"
-
-#include <memory>
-
-#include <gtest/gtest.h>
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class WithoutDotDimensionSorterTest : public GpuCodegenTest {
- public:
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
-    // The pass is disabled here to preserve suboptimal dimension order in
-    // 1) UnsortedDimsCreateTransposes to reveal the transposes.
-    // 2) DimOrderCanBeChanged for the comparison of ordered vs unordered.
-    // The pass does not touch SortedDimsDoNotCreateTransposes anyway because
-    // the dimensions are already ordered there.
-    debug_options.add_xla_disable_hlo_passes("dot_dimension_sorter");
-    return debug_options;
-  }
-};
-
-TEST_F(WithoutDotDimensionSorterTest, UnsortedDimsCreateTransposes) {
-  const char* hlo_text = R"(
-HloModule m
-
-ENTRY e {
-  p0 = f16[1,14,9,32] parameter(0)
-  p1 = f16[12,9,32] parameter(1)
-  ROOT _ = f16[1,14,12] dot(p0, p1),
-    lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
-}
-)";
-
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK: transpose
-)");
-}
-
-TEST_F(WithoutDotDimensionSorterTest, SortedDimsDoNotCreateTransposes) {
-  const char* hlo_text = R"(
-HloModule m
-
-ENTRY e {
-  p0 = f16[1,14,9,32] parameter(0)
-  p1 = f16[12,9,32] parameter(1)
-  ROOT _ = f16[1,14,12] dot(p0, p1),
-    lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
-}
-)";
-
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-NOT: transpose
-)");
-}
-
-TEST_F(WithoutDotDimensionSorterTest, DimOrderCanBeChanged) {
-  const char* hlo_text_ref = R"(
-HloModule m
-
-ENTRY e {
-  p0 = f16[1,14,9,32] parameter(0)
-  p1 = f16[12,9,32] parameter(1)
-  ROOT _ = f16[1,14,12] dot(p0, p1),
-    lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
-}
-)";
-
-  const char* hlo_text_modified = R"(
-HloModule m
-
-ENTRY e {
-  p0 = f16[1,14,9,32] parameter(0)
-  p1 = f16[12,9,32] parameter(1)
-  ROOT _ = f16[1,14,12] dot(p0, p1),
-    lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
-}
-)";
-
-  EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_modified,
-                                      ErrorSpec{1e-5, 1e-3},
-                                      /*run_hlo_passes=*/true));
-}
-
-using DotDimensionSorterTest = GpuCodegenTest;
-
-TEST_F(DotDimensionSorterTest, SortContractingDims) {
-  const char* module_string = R"(
-HloModule m
-
-ENTRY e {
-  p0 = f16[1,144,96,32] parameter(0)
-  p1 = f16[122,96,32] parameter(1)
-  ROOT _ = f16[1,144,122] dot(p0, p1),
-    lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(module_string));
-  const auto& dims =
-      module->entry_computation()->root_instruction()->dot_dimension_numbers();
-
-  EXPECT_EQ(dims.lhs_contracting_dimensions(0), 3);
-  EXPECT_EQ(dims.lhs_contracting_dimensions(1), 2);
-
-  EXPECT_EQ(dims.rhs_contracting_dimensions(0), 2);
-  EXPECT_EQ(dims.rhs_contracting_dimensions(1), 1);
-
-  TF_ASSERT_OK_AND_ASSIGN(bool modified,
-                          DotDimensionSorter().Run(module.get()));
-  EXPECT_TRUE(modified);
-  const auto& dims2 =
-      module->entry_computation()->root_instruction()->dot_dimension_numbers();
-
-  EXPECT_EQ(dims2.lhs_contracting_dimensions(0), 2);
-  EXPECT_EQ(dims2.lhs_contracting_dimensions(1), 3);
-
-  EXPECT_EQ(dims2.rhs_contracting_dimensions(0), 1);
-  EXPECT_EQ(dims2.rhs_contracting_dimensions(1), 2);
-}
-
-TEST_F(DotDimensionSorterTest, NothingToReorder) {
-  const char* module_string = R"(
-HloModule m
-
-ENTRY e {
-  p0 = f16[1,144,96,32] parameter(0)
-  p1 = f16[122,96,32] parameter(1)
-  ROOT _ = f16[1,144,122] dot(p0, p1),
-    lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(module_string));
-
-  TF_ASSERT_OK_AND_ASSIGN(bool modified,
-                          DotDimensionSorter().Run(module.get()));
-  EXPECT_FALSE(modified);
-}
-
-TEST_F(DotDimensionSorterTest, SparseDotSortContractingDims) {
-  const char* module_string = R"(
-HloModule m
-
-ENTRY e {
-  p0 = f16[1,144,96,16] parameter(0)
-  p1 = f16[122,96,32] parameter(1)
-  meta = u16[1,144,96,2] parameter(2)
-  ROOT _ = f16[1,144,122] dot(p0, p1, meta), sparsity=L.3@2:4,
-    lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(module_string));
-  TF_ASSERT_OK_AND_ASSIGN(bool modified,
-                          DotDimensionSorter().Run(module.get()));
-  EXPECT_TRUE(modified);
-  HloDotInstruction* dot = DynCast<HloDotInstruction>(
-      module->entry_computation()->root_instruction());
-  EXPECT_TRUE(dot != nullptr && dot->sparse_operands() == 1);
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/dot_operand_converter.cc b/third_party/xla/xla/service/gpu/dot_operand_converter.cc
deleted file mode 100644
index 2a298e6..0000000
--- a/third_party/xla/xla/service/gpu/dot_operand_converter.cc
+++ /dev/null
@@ -1,74 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dot_operand_converter.h"
-
-#include "absl/status/statusor.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/shape_util.h"
-#include "tsl/platform/errors.h"
-
-namespace xla::gpu {
-
-bool DotOperandConverter::InstructionMatchesPattern(
-    HloInstruction* instruction) {
-  if (instruction->opcode() != HloOpcode::kDot) {
-    return false;
-  }
-  HloInstruction* lhs = instruction->mutable_operand(0);
-  HloInstruction* rhs = instruction->mutable_operand(1);
-
-  PrimitiveType lhs_type = lhs->shape().element_type();
-  PrimitiveType rhs_type = rhs->shape().element_type();
-
-  if (lhs_type == rhs_type) {
-    return false;
-  }
-
-  // Exclude conversions between FP8 types.
-  absl::flat_hash_set<PrimitiveType> non_converting = {F8E4M3FN, F8E5M2};
-  if (non_converting.contains(lhs_type) && non_converting.contains(rhs_type)) {
-    return false;
-  }
-
-  PrimitiveType desired_type =
-      ShapeUtil::HigherPrecisionElementType(lhs->shape(), rhs->shape());
-
-  return desired_type == lhs_type || desired_type == rhs_type;
-}
-
-absl::StatusOr<HloInstruction*> DotOperandConverter::ExpandInstruction(
-    HloInstruction* instruction) {
-  HloInstruction* lhs = instruction->mutable_operand(0);
-  HloInstruction* rhs = instruction->mutable_operand(1);
-
-  // Find the higher precision type among the two operands, and add a convert
-  // instruction to convert the lower-precision operand to that type.
-  PrimitiveType desired_type =
-      ShapeUtil::HigherPrecisionElementType(lhs->shape(), rhs->shape());
-  int operand_index = desired_type == lhs->shape().element_type() ? 1 : 0;
-  HloInstruction* inst_to_replace =
-      desired_type == lhs->shape().element_type() ? rhs : lhs;
-  auto upcast_shape = inst_to_replace->shape();
-  upcast_shape.set_element_type(desired_type);
-  auto* convert_inst = instruction->AddInstruction(
-      HloInstruction::CreateConvert(upcast_shape, inst_to_replace));
-  TF_RETURN_IF_ERROR(instruction->ReplaceOperandWithDifferentShape(
-      operand_index, convert_inst));
-  return nullptr;
-}
-
-}  // namespace xla::gpu
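The decision made by DotOperandConverter above reduces to: if the two operand element types differ, and they are not both FP8 types, convert the lower-precision operand to the higher-precision type. A toy, self-contained sketch of just that decision, using an artificial precision rank in place of ShapeUtil::HigherPrecisionElementType:

    #include <cassert>

    // Toy illustration of the converter's decision: given a precision rank for
    // each operand (higher = more precise; the ranks here are made up), report
    // which operand index, if any, needs a convert so both sides share the
    // higher-precision type. Returns -1 when the types already match.
    int OperandToConvert(int lhs_rank, int rhs_rank) {
      if (lhs_rank == rhs_rank) return -1;  // nothing to do
      return lhs_rank > rhs_rank ? 1 : 0;   // convert the less precise side
    }

    int main() {
      assert(OperandToConvert(/*s8=*/0, /*bf16=*/1) == 0);   // upcast the LHS
      assert(OperandToConvert(/*bf16=*/1, /*s8=*/0) == 1);   // upcast the RHS
      assert(OperandToConvert(/*f32=*/2, /*f32=*/2) == -1);  // same type: skip
      return 0;
    }
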
diff --git a/third_party/xla/xla/service/gpu/dot_operand_converter.h b/third_party/xla/xla/service/gpu/dot_operand_converter.h
deleted file mode 100644
index d277a24..0000000
--- a/third_party/xla/xla/service/gpu/dot_operand_converter.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_
-#define XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_
-
-#include <utility>
-
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/op_expander_pass.h"
-#include "xla/util.h"
-
-namespace xla::gpu {
-
-// Ensures both dot operands share the higher-precision element type by
-// converting the lower-precision operand.
-class DotOperandConverter : public OpExpanderPass {
- public:
-  explicit DotOperandConverter(HloPredicate extra_filter = nullptr)
-      : OpExpanderPass(std::move(extra_filter)) {}
-
-  absl::string_view name() const override { return "operand_converter"; }
-
- protected:
-  bool InstructionMatchesPattern(HloInstruction* instruction) override;
-
-  absl::StatusOr<HloInstruction*> ExpandInstruction(
-      HloInstruction* instruction) override;
-};
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_
diff --git a/third_party/xla/xla/service/gpu/dot_operand_converter_test.cc b/third_party/xla/xla/service/gpu/dot_operand_converter_test.cc
deleted file mode 100644
index 63b0017..0000000
--- a/third_party/xla/xla/service/gpu/dot_operand_converter_test.cc
+++ /dev/null
@@ -1,142 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dot_operand_converter.h"
-
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "absl/strings/substitute.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/utils/hlo_matchers.h"
-#include "xla/primitive_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-namespace op = ::xla::testing::opcode_matchers;
-
-class DotOperandConverterTest : public HloTestBase {
- public:
-  void TestConvert(bool left_less_precise, PrimitiveType lhs_type,
-                   PrimitiveType rhs_type, PrimitiveType result_type) {
-    absl::string_view module_tmpl = R"(
-  HloModule module
-
-  ENTRY main {
-    p0 = $0[2,3]{1,0} parameter(0)
-    p1 = $1[3,2]{1,0} parameter(1)
-    ROOT dot = $2[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
-                                         rhs_contracting_dims={0}
-  })";
-    auto module_string = absl::Substitute(
-        module_tmpl, primitive_util::LowercasePrimitiveTypeName(lhs_type),
-        primitive_util::LowercasePrimitiveTypeName(rhs_type),
-        primitive_util::LowercasePrimitiveTypeName(result_type));
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                            ParseAndReturnVerifiedModule(module_string));
-    TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
-                            DotOperandConverter().Run(module.get()));
-    EXPECT_TRUE(upcasted);
-    if (left_less_precise) {
-      auto original_lhs = op::Parameter(0);
-      auto upcasted_lhs =
-          AllOf(op::Convert(original_lhs),
-                op::Shape(absl::Substitute(
-                    "$0[2,3]{1,0}",
-                    primitive_util::LowercasePrimitiveTypeName(rhs_type))));
-      EXPECT_THAT(
-          module->entry_computation()->root_instruction(),
-          AllOf(op::Dot(upcasted_lhs, op::Parameter(1)),
-                op::Shape(absl::Substitute(
-                    "$0[2,2]{1,0}",
-                    primitive_util::LowercasePrimitiveTypeName(result_type)))));
-    } else {
-      auto original_rhs = op::Parameter(1);
-      auto upcasted_rhs =
-          AllOf(op::Convert(original_rhs),
-                op::Shape(absl::Substitute(
-                    "$0[3,2]{1,0}",
-                    primitive_util::LowercasePrimitiveTypeName(lhs_type))));
-      EXPECT_THAT(
-          module->entry_computation()->root_instruction(),
-          AllOf(op::Dot(op::Parameter(0), upcasted_rhs),
-                op::Shape(absl::Substitute(
-                    "$0[2,2]{1,0}",
-                    primitive_util::LowercasePrimitiveTypeName(result_type)))));
-    }
-  }
-};
-
-TEST_F(DotOperandConverterTest, ConvertsLeftAndRight) {
-  TestConvert(/*left_less_precise=*/true, S8, BF16, F32);
-  TestConvert(/*left_less_precise=*/false, BF16, S8, F32);
-}
-
-TEST_F(DotOperandConverterTest, NoConvertHappensWithSameTypes) {
-  absl::string_view module_string = R"(
-  HloModule module
-
-  ENTRY main {
-    p0 = s8[2,3]{1,0} parameter(0)
-    p1 = s8[3,2]{1,0} parameter(1)
-    ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
-                                         rhs_contracting_dims={0}
-  })";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(module_string));
-  TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
-                          DotOperandConverter().Run(module.get()));
-  EXPECT_FALSE(upcasted);
-}
-
-TEST_F(DotOperandConverterTest, NoConvertFromF8toF8) {
-  absl::string_view module_string = R"(
-  HloModule module
-
-  ENTRY main {
-    p0 = f8e4m3fn[2,3]{1,0} parameter(0)
-    p1 = f8e5m2[3,2]{1,0} parameter(1)
-    ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
-                                         rhs_contracting_dims={0}
-  })";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(module_string));
-  TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
-                          DotOperandConverter().Run(module.get()));
-  EXPECT_FALSE(upcasted);
-}
-
-TEST_F(DotOperandConverterTest, CompilerOptimizesUsingDotOperandConverter) {
-  absl::string_view module_string = R"(
-  HloModule module
-
-  ENTRY main {
-    p0 = s8[2,3]{1,0} parameter(0)
-    p1 = bf16[3,2]{1,0} parameter(1)
-    ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
-                                         rhs_contracting_dims={0}
-  })";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          GetOptimizedModule(module_string));
-}
-
-}  // namespace
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.cc b/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.cc
deleted file mode 100644
index 0f41091..0000000
--- a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.cc
+++ /dev/null
@@ -1,110 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dot_sparsity_rewriter.h"
-
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class SparseDotRewriterImpl : public DfsHloRewriteVisitor {
- public:
-  absl::Status HandleDot(HloInstruction* instr) override {
-    // Only handle sparse dots with a single RHS sparse descriptor.
-    HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
-    if (dot->sparse_operands() != 1 || dot->sparsity().front().index() != 1) {
-      return absl::OkStatus();
-    }
-
-    HloInstruction* lhs = dot->mutable_operand(0);
-    HloInstruction* rhs = dot->mutable_operand(1);
-    HloInstruction* meta = dot->mutable_operand(2);
-
-    // Swap LHS and RHS in the attributes.
-    DotDimensionNumbers dnums = dot->dot_dimension_numbers();
-    std::swap(*dnums.mutable_lhs_batch_dimensions(),
-              *dnums.mutable_rhs_batch_dimensions());
-    std::swap(*dnums.mutable_lhs_contracting_dimensions(),
-              *dnums.mutable_rhs_contracting_dimensions());
-
-    PrecisionConfig precision_config = dot->precision_config();
-    std::swap(precision_config.mutable_operand_precision()->at(0),
-              precision_config.mutable_operand_precision()->at(1));
-
-    SparsityDescriptor sparsity = dot->sparsity().front();
-    sparsity.set_index(0);
-
-    // Create new dot with LHS and RHS swapped.
-    TF_ASSIGN_OR_RETURN(
-        HloInstruction * new_dot,
-        MakeDotHlo(rhs, lhs, dnums, precision_config,
-                   dot->shape().element_type(), {std::move(sparsity)}, {meta}));
-    dot->SetupDerivedInstruction(new_dot);
-
-    // Result dimensions: <batch>, <rhs_noncontracting>, <lhs_noncontracting>
-    int batch_dims = dnums.lhs_batch_dimensions().size();
-    int new_lhs_noncontracting = rhs->shape().rank() - batch_dims -
-                                 dnums.lhs_contracting_dimensions().size();
-    int new_rhs_noncontracting = lhs->shape().rank() - batch_dims -
-                                 dnums.rhs_contracting_dimensions().size();
-
-    int rank = dot->shape().rank();
-    DimensionVector dimensions(rank);
-    for (int i = 0; i < batch_dims; ++i) {
-      dimensions[i] = i;
-    }
-    for (int i = 0; i < new_lhs_noncontracting; ++i) {
-      dimensions[i + batch_dims] = i + batch_dims + new_rhs_noncontracting;
-    }
-    for (int i = 0; i < new_rhs_noncontracting; ++i) {
-      dimensions[i + batch_dims + new_lhs_noncontracting] = i + batch_dims;
-    }
-
-    // Transpose the result.
-    TF_ASSIGN_OR_RETURN(HloInstruction * transpose,
-                        MakeTransposeHlo(new_dot, dimensions));
-    transpose->set_metadata(dot->metadata());
-    *transpose->mutable_shape()->mutable_layout() = dot->shape().layout();
-
-    return ReplaceInstruction(dot, transpose);
-  }
-};
-
-}  // namespace
-
-absl::StatusOr<bool> DotSparsityRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  return SparseDotRewriterImpl().RunOnModule(module, execution_threads);
-}
-
-}  // namespace gpu
-}  // namespace xla
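The least obvious step in DotSparsityRewriter above is the transpose permutation that restores the original output order once the operands are swapped: the swapped dot emits <batch, new_lhs_noncontracting, new_rhs_noncontracting>, and the transpose maps that back. A standalone sketch of the index computation (it reproduces the {0,1,3,2} permutation checked in the test deleted below):

    #include <cstdint>
    #include <vector>

    // Mirrors the permutation loop in the deleted SparseDotRewriterImpl: after
    // the operand swap the dot emits <batch, new_lhs_nc, new_rhs_nc>, and this
    // permutation restores the original <batch, lhs_nc, rhs_nc> output order.
    std::vector<int64_t> OutputPermutation(int batch, int new_lhs_nc,
                                           int new_rhs_nc) {
      std::vector<int64_t> dims(batch + new_lhs_nc + new_rhs_nc);
      for (int i = 0; i < batch; ++i) dims[i] = i;
      for (int i = 0; i < new_lhs_nc; ++i)
        dims[batch + i] = batch + new_rhs_nc + i;
      for (int i = 0; i < new_rhs_nc; ++i)
        dims[batch + new_lhs_nc + i] = batch + i;
      return dims;
    }

    // OutputPermutation(/*batch=*/2, /*new_lhs_nc=*/1, /*new_rhs_nc=*/1) yields
    // {0, 1, 3, 2}, the permutation expected by the SparseDotRhsToLhs test.
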
diff --git a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.h b/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.h
deleted file mode 100644
index b422197..0000000
--- a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_DOT_SPARSITY_REWRITER_H_
-#define XLA_SERVICE_GPU_DOT_SPARSITY_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Makes sure sparse dot requirements are met (the sparse operand must be the
-// LHS) by swapping the operands when necessary.
-class DotSparsityRewriter : public HloModulePass {
- public:
-  absl::string_view name() const override { return "dot_sparsity_rewriter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_DOT_SPARSITY_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter_test.cc b/third_party/xla/xla/service/gpu/dot_sparsity_rewriter_test.cc
deleted file mode 100644
index c608f8d..0000000
--- a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter_test.cc
+++ /dev/null
@@ -1,85 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dot_sparsity_rewriter.h"
-
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::testing::ElementsAre;
-
-class DotSparsityRewriterTest : public HloTestBase {
- public:
-  DotSparsityRewriterTest() : HloTestBase(/*verifier_layout_sensitive=*/true) {}
-};
-
-TEST_F(DotSparsityRewriterTest, SparseDotRhsToLhs) {
-  const char* module_string = R"(
-HloModule m
-
-ENTRY e {
-  lhs = f16[4,2,16,8,64] parameter(0)
-  rhs = f16[2,4,8,32,128] parameter(1)
-  meta = u16[2,4,8,4,128] parameter(2)
-  ROOT dot = f16[4,2,16,128] dot(lhs, rhs, meta),
-    lhs_contracting_dims={3,4}, rhs_contracting_dims={2,3},
-    lhs_batch_dims={0,1}, rhs_batch_dims={1,0}, sparsity=R.3@2:4
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(module_string));
-  TF_ASSERT_OK_AND_ASSIGN(bool modified,
-                          DotSparsityRewriter().Run(module.get()));
-  EXPECT_TRUE(modified);
-
-  const HloTransposeInstruction* transpose = DynCast<HloTransposeInstruction>(
-      module->entry_computation()->root_instruction());
-  ASSERT_TRUE(transpose != nullptr);
-  EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 1, 3, 2));
-
-  const HloDotInstruction* dot =
-      DynCast<HloDotInstruction>(transpose->operand(0));
-  ASSERT_TRUE(dot != nullptr);
-
-  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
-  EXPECT_EQ(dnums.lhs_contracting_dimensions(0), 2);
-  EXPECT_EQ(dnums.lhs_contracting_dimensions(1), 3);
-  EXPECT_EQ(dnums.rhs_contracting_dimensions(0), 3);
-  EXPECT_EQ(dnums.rhs_contracting_dimensions(1), 4);
-  EXPECT_EQ(dnums.lhs_batch_dimensions(0), 1);
-  EXPECT_EQ(dnums.lhs_batch_dimensions(1), 0);
-  EXPECT_EQ(dnums.rhs_batch_dimensions(0), 0);
-  EXPECT_EQ(dnums.rhs_batch_dimensions(1), 1);
-
-  EXPECT_EQ(dot->sparse_operands(), 1);
-  EXPECT_EQ(dot->sparsity().front().index(), 0);
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.cc b/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.cc
deleted file mode 100644
index 9cd9113..0000000
--- a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.cc
+++ /dev/null
@@ -1,569 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/double_buffer_loop_unrolling.h"
-
-#include <cstdint>
-#include <iterator>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instruction_utils.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/service/flatten_call_graph.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-void SetChannelIdForNewCollective(HloInstruction* new_instr,
-                                  const HloModule* module) {
-  // Tracks the mapping of old->new channel ids for async collectives wrapped
-  // in the form of HloAsyncInstruction; the start and done ops need to share
-  // the same unique channel id.
-  absl::flat_hash_map<int64_t, int64_t> old_to_new_channel_id_map;
-  absl::flat_hash_map<int64_t, HloComputation*> channel_id_comp_map;
-  if (new_instr->IsAsynchronous() && hlo_query::IsCollectiveCommunicationOp(
-                                         new_instr->async_wrapped_opcode())) {
-    HloInstruction* wrapped_instr =
-        DynCast<HloAsyncInstruction>(new_instr)->async_wrapped_instruction();
-    int64_t old_channel_id = *wrapped_instr->channel_id();
-    // Look up the map before inserting: operator[] would default-insert a
-    // zero id and the branch below would then never run.
-    int64_t new_channel_id;
-    auto it = old_to_new_channel_id_map.find(old_channel_id);
-    if (it == old_to_new_channel_id_map.end()) {
-      new_channel_id = hlo_query::NextChannelId(*module);
-      VLOG(2) << "Generated new channel id " << new_channel_id;
-      old_to_new_channel_id_map[old_channel_id] = new_channel_id;
-    } else {
-      new_channel_id = it->second;
-    }
-
-    VLOG(2) << "Setting channel id to " << new_channel_id;
-
-    wrapped_instr->set_channel_id(new_channel_id);
-    if (channel_id_comp_map.find(new_channel_id) == channel_id_comp_map.end()) {
-      channel_id_comp_map[new_channel_id] =
-          new_instr->async_wrapped_computation();
-    } else {
-      channel_id_comp_map[new_channel_id]->AddAsyncStart(new_instr);
-    }
-  } else if (hlo_query::IsCollectiveCommunicationOp(new_instr->opcode()) ||
-             hlo_query::IsAsyncCollectiveStartOp(new_instr)) {
-    new_instr->set_channel_id(hlo_query::NextChannelId(*module));
-  }
-}
-
-using Interval = std::pair<int64_t, int64_t>;
-
-// Parses a string of the format `{{a,b},{c,d},{e,f}...}` to a vector of pairs.
-absl::StatusOr<std::vector<Interval>> ParseVectorOfPairs(
-    absl::string_view str) {
-  TF_ASSIGN_OR_RETURN(std::vector<ReplicaGroup> replica_groups,
-                      ParseReplicaGroupsOnly(str));
-  std::vector<Interval> res;
-  res.reserve(replica_groups.size());
-  for (const ReplicaGroup& replica_group : replica_groups) {
-    TF_RET_CHECK(replica_group.replica_ids_size() == 2);
-    int64_t a = replica_group.replica_ids(0);
-    int64_t b = replica_group.replica_ids(1);
-    res.emplace_back(a, b);
-  }
-  return res;
-}
-
-// This function fixes the `_xla_send_recv_validation` attribute for peeled
-// instructions. When the loop trip count is odd, the peeled instructions are
-// moved before the loop. The collectives in these instructions correspond to
-// the first iteration of the original loop. We have to run this peeled
-// collective for all those devices that had the 0-th iteration as a valid
-// iteration.
-absl::Status SetSendRecvValidationForPeeledInstr(HloInstruction* new_instr,
-                                                 HloInstruction* old_instr) {
-  TF_RET_CHECK(
-      new_instr->opcode() == old_instr->opcode() &&
-      "cloned instruction and original instruction have different opcodes");
-  if (!HloPredicateIsOp<HloOpcode::kCollectivePermute,
-                        HloOpcode::kCollectivePermuteStart, HloOpcode::kSend,
-                        HloOpcode::kRecv>(old_instr)) {
-    return absl::OkStatus();
-  }
-
-  const auto& attribute_map = new_instr->frontend_attributes().map();
-  if (!attribute_map.contains(kSendRecvValidationAttr)) {
-    return absl::OkStatus();
-  }
-
-  VLOG(3) << "Original send-recv iterations: "
-          << attribute_map.at(kSendRecvValidationAttr);
-
-  TF_ASSIGN_OR_RETURN(
-      auto send_recv_validation_attr,
-      ParseVectorOfPairs(attribute_map.at(kSendRecvValidationAttr)));
-
-  uint64_t n_pairs = send_recv_validation_attr.size();
-  if (n_pairs == 0) {
-    return absl::OkStatus();
-  }
-  std::vector<Interval> send_recv_validation_attr_updated(n_pairs, {1, 0});
-  // Check which of the attributes have iteration number zero as valid
-  // iteration. For all those, set the peeled instruction to run.
-  for (std::uint64_t i = 0; i < send_recv_validation_attr.size(); i++) {
-    if (send_recv_validation_attr[i].first <= 0 &&
-        send_recv_validation_attr[i].second >= 0) {
-      send_recv_validation_attr_updated[i] = {0, 0};
-    }
-  }
-
-  hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute(
-      /*instr=*/new_instr, /*attr_name=*/kSendRecvValidationAttr,
-      /*intervals=*/send_recv_validation_attr_updated);
-  return absl::OkStatus();
-}
-
-// This function fixes the `_xla_send_recv_validation` attribute for the two new
-// collectives inside the loop. The calculation of the new valid iterations
-// depends on whether the loop was peeled or not.
-//
-// If the loop was not peeled, then
-//  - iteration 0 of the new loop corresponds to iterations 0,1 of the old loop.
-//  - iteration 1 of the new loop corresponds to iterations 2,3 of the old loop.
-//  - and so on...
-// If the loop was peeled, then the first iteration runs before the loop. So,
-//  - iteration 0 of the new loop corresponds to iterations 1,2 of the old loop.
-//  - iteration 1 of the new loop corresponds to iterations 3,4 of the old loop.
-//  - and so on...
-//
-// Consider the case when the loop was peeled, and the original attribute for
-// some device was {4,7}. Consider that the two new collectives are
-// `collective.1` and `collective.2` (they execute in this order inside the new
-// loop). In the old loop, iterations 4,5,6,7 were valid. In the new
-// loop,
-//  - collective.2 in iteration 1 of new loop runs 4th iteration of old loop.
-//  - collective.1 in iteration 2 of new loop runs 5th iteration of old loop.
-//  - collective.2 in iteration 2 of new loop runs 6th iteration of old loop.
-//  - collective.1 in iteration 3 of new loop runs 7th iteration of old loop.
-// So the updated attributes for that device are {1,2} for `collective.2` and
-// {2,3} for `collective.1`.
-//
-// In a similar fashion we can generalize the computation of new values based on
-// the values of the old attribute as done in the logic below.
-absl::Status SetSendRecvValidation(HloInstruction* cp1, HloInstruction* cp2,
-                                   bool is_peeled) {
-  TF_RET_CHECK(
-      cp2->opcode() == cp1->opcode() &&
-      "cloned instruction and original instruction have different opcodes");
-  if (!HloPredicateIsOp<HloOpcode::kCollectivePermute,
-                        HloOpcode::kCollectivePermuteStart, HloOpcode::kSend,
-                        HloOpcode::kRecv>(cp1)) {
-    return absl::OkStatus();
-  }
-  const auto& attribute_map = cp2->frontend_attributes().map();
-  if (!attribute_map.contains(kSendRecvValidationAttr)) {
-    return absl::OkStatus();
-  }
-  VLOG(3) << "Original send-recv iterations: "
-          << attribute_map.at(kSendRecvValidationAttr);
-
-  TF_ASSIGN_OR_RETURN(
-      auto send_recv_validation_attr,
-      ParseVectorOfPairs(attribute_map.at(kSendRecvValidationAttr)));
-
-  if (send_recv_validation_attr.size() == 0) {
-    return absl::OkStatus();
-  }
-
-  std::vector<Interval> send_recv_iterations_new_instr1,
-      send_recv_iterations_new_instr2;
-  send_recv_iterations_new_instr1.reserve(send_recv_validation_attr.size());
-  send_recv_iterations_new_instr2.reserve(send_recv_validation_attr.size());
-  for (const Interval& pair : send_recv_validation_attr) {
-    int64_t a = pair.first;
-    int64_t b = pair.second;
-    if (is_peeled) {
-      send_recv_iterations_new_instr1.emplace_back(
-          std::floor(a / 2.0), std::max(0.0, std::floor((b - 1) / 2.0)));
-      send_recv_iterations_new_instr2.emplace_back(
-          std::max(0.0, std::floor((a - 1) / 2.0)),
-          std::max(0.0, std::floor((b - 2) / 2.0)));
-    } else {
-      send_recv_iterations_new_instr1.emplace_back(std::floor((a + 1) / 2.0),
-                                                   std::floor(b / 2.0));
-      send_recv_iterations_new_instr2.emplace_back(
-          std::floor(a / 2.0), std::max(0.0, std::floor((b - 1) / 2.0)));
-    }
-  }
-
-  hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute(
-      /*instr=*/cp1, /*attr_name=*/kSendRecvValidationAttr,
-      /*intervals=*/send_recv_iterations_new_instr1);
-  hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute(
-      /*instr=*/cp2, /*attr_name=*/kSendRecvValidationAttr,
-      /*intervals=*/send_recv_iterations_new_instr2);
-
-  VLOG(3) << "Updated send-recv iterations for " << cp1->name() << " : "
-          << cp1->frontend_attributes().map().at(kSendRecvValidationAttr);
-  VLOG(3) << "Updated send-recv iterations for " << cp2->name() << " : "
-          << cp2->frontend_attributes().map().at(kSendRecvValidationAttr);
-  return absl::OkStatus();
-}
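
[Editor's note] To make the interval arithmetic above easier to follow, here is a minimal standalone sketch of the same formulas using hypothetical names (`RemapInterval`, `Interval` as a plain `std::pair`) rather than the XLA types. It reproduces the worked example from the comment: for a peeled loop, the original interval {4,7} becomes {2,3} for the first collective in the unrolled body and {1,2} for the second.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <utility>

using Interval = std::pair<int64_t, int64_t>;

// Same formulas as SetSendRecvValidation, applied to a single interval {a,b}.
// Returns the new intervals for the first and the second (cloned) collective
// inside the unrolled loop body.
std::pair<Interval, Interval> RemapInterval(Interval original, bool is_peeled) {
  const double a = static_cast<double>(original.first);
  const double b = static_cast<double>(original.second);
  auto floor64 = [](double x) { return static_cast<int64_t>(std::floor(x)); };
  Interval first, second;
  if (is_peeled) {
    first = {floor64(a / 2.0), std::max<int64_t>(0, floor64((b - 1) / 2.0))};
    second = {std::max<int64_t>(0, floor64((a - 1) / 2.0)),
              std::max<int64_t>(0, floor64((b - 2) / 2.0))};
  } else {
    first = {floor64((a + 1) / 2.0), floor64(b / 2.0)};
    second = {floor64(a / 2.0), std::max<int64_t>(0, floor64((b - 1) / 2.0))};
  }
  return {first, second};
}

int main() {
  // Worked example from the comment: peeled loop, original interval {4,7}.
  auto [cp1, cp2] = RemapInterval({4, 7}, /*is_peeled=*/true);
  std::cout << "collective.1: {" << cp1.first << "," << cp1.second << "}\n";  // {2,3}
  std::cout << "collective.2: {" << cp2.first << "," << cp2.second << "}\n";  // {1,2}
  return 0;
}
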
-
-// Handle control predecessors/successors for every old-new instruction pair.
-// For every new instruction, we find the relevant predecessor/successor
-// relationships of the old instruction and we reconstruct them by looking up
-// new (already created) predecessors/successors.
-//
-// When rewiring dependencies from the output of the original body to the input
-// of the cloned body, we skip collectives and ops in `skip_control_dep_injection`.
-absl::Status HandleControlDependencies(
-    const HloComputation* while_body,
-    const absl::flat_hash_map<HloInstruction*, HloInstruction*>& old_to_new_map,
-    HloInstruction::InstructionVector* old_loop_roots,
-    HloInstruction* input_parameter,
-    const absl::flat_hash_set<HloInstruction*>& skip_control_dep_injection) {
-  for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
-    if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
-      HloInstruction* new_instr = old_to_new_map.at(old_instr);
-      VLOG(2) << "Processing control predecessors for "
-              << new_instr->ToString();
-      std::vector<HloInstruction*> new_control_pred;
-      new_control_pred.reserve(old_instr->control_predecessors().size());
-      for (HloInstruction* pred : old_instr->control_predecessors()) {
-        if (!old_to_new_map.contains(pred)) {
-          continue;
-        }
-        new_control_pred.push_back(old_to_new_map.at(pred));
-      }
-
-      TF_RETURN_IF_ERROR(new_instr->DropAllControlDeps());
-      for (HloInstruction* new_pred : new_control_pred) {
-        TF_RETURN_IF_ERROR(new_pred->AddControlDependencyTo(new_instr));
-        VLOG(2) << "Adding " << new_pred->ToString()
-                << " to control dependency of " << new_instr->ToString();
-      }
-    }
-  }
-  for (HloInstruction* input_consumer : input_parameter->users()) {
-    for (HloInstruction* old_input : input_consumer->users()) {
-      if (old_to_new_map.find(old_input) != old_to_new_map.end()) {
-        HloInstruction* new_input = old_to_new_map.at(old_input);
-        if (skip_control_dep_injection.find(old_input) ==
-                skip_control_dep_injection.end() &&
-            !IsCollective(old_input)) {
-          for (HloInstruction* old_root : *old_loop_roots) {
-            TF_RETURN_IF_ERROR(old_root->AddControlDependencyTo(new_input));
-          }
-        }
-      }
-    }
-  }
-
-  return absl::OkStatus();
-}
-
-absl::StatusOr<bool> FullyUnroll(HloInstruction* while_instr,
-                                 HloModule* module) {
-  HloComputation* while_body = while_instr->while_body();
-  bool changed = false;
-  VLOG(2) << "Processing root " << while_body->root_instruction()->ToString();
-
-  auto loop_roots = while_body->root_instruction()->mutable_operands();
-  HloInstruction* input_parameter = while_body->parameter_instruction(0);
-  VLOG(2) << "Processing input parameter " << input_parameter->ToString();
-
-  absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_map;
-  absl::flat_hash_set<HloInstruction*> skip_control_dep_injection;
-  std::string clone_suffix = "full_unroll_clone";
-
-  TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config,
-                      while_instr->backend_config<WhileLoopBackendConfig>());
-  std::vector<HloInstruction*> ops_to_clone;
-  ops_to_clone.reserve(while_body->MakeInstructionPostOrder().size());
-
-  // Pre-loop prep.
-  HloInstruction* old_input_parameter = input_parameter;
-  HloInstruction* new_input_parameter = while_body->root_instruction();
-  absl::flat_hash_set<HloInstruction*> seen_ops;
-  for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
-    if (seen_ops.contains(old_instr)) {
-      continue;
-    }
-    ops_to_clone.push_back(old_instr);
-    seen_ops.insert(old_instr);
-  }
-
-  int n = config.known_trip_count().n();
-  while (--n) {
-    std::vector<HloInstruction*> new_ops_to_clone;
-    old_to_new_map[old_input_parameter] = new_input_parameter;
-    for (HloInstruction* old_instr : ops_to_clone) {
-      if (old_to_new_map.contains(old_instr)) {
-        continue;
-      }
-      VLOG(2) << "Cloning instruction " << old_instr->ToString();
-      std::vector<HloInstruction*> new_operands;
-      for (HloInstruction* old_operand : old_instr->mutable_operands()) {
-        new_operands.push_back(old_to_new_map[old_operand]);
-      }
-      HloInstruction* new_instr =
-          while_body->AddInstruction(old_instr->CloneWithNewOperands(
-              old_instr->shape(), new_operands, clone_suffix));
-
-      // If an elementwise instruction with a constant operand is present, we
-      // won't inject a control dependency at the end, to allow more constant
-      // folding opportunities.
-      if (old_instr->IsElementwiseBinary() && old_instr->HasConstantOperand()) {
-        skip_control_dep_injection.insert(old_instr);
-      }
-      SetChannelIdForNewCollective(new_instr, module);
-      old_to_new_map[old_instr] = new_instr;
-      new_ops_to_clone.push_back(new_instr);
-      VLOG(2) << "Added instruction " << new_instr->ToString();
-    }
-
-    while_body->set_root_instruction(
-        old_to_new_map[while_body->root_instruction()]);
-    VLOG(2) << "Replaced with new root "
-            << while_body->root_instruction()->ToString();
-
-    TF_RETURN_IF_ERROR(HandleControlDependencies(
-        while_body, old_to_new_map, &loop_roots, old_input_parameter,
-        skip_control_dep_injection));
-
-    // Inductive step update, clean/update necessary buffers to prepare them for
-    // the next unrolling iteration.
-    old_to_new_map.clear();
-    skip_control_dep_injection.clear();
-    loop_roots = while_body->root_instruction()->mutable_operands();
-    old_input_parameter = new_input_parameter;
-    new_input_parameter = while_body->root_instruction();
-    ops_to_clone = std::move(new_ops_to_clone);
-    changed = true;
-  }
-
-  WhileLoopBackendConfig new_config;
-  new_config.mutable_known_trip_count()->set_n(1);
-  TF_RETURN_IF_ERROR(while_instr->set_backend_config(new_config));
-
-  return changed;
-}
-
-absl::Status PeelInstructionsForOddTripCount(HloModule* module,
-                                             HloInstruction* while_instr) {
-  std::string suffix = "peeled_double_buffer";
-  absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_map;
-  HloComputation* while_body = while_instr->while_body();
-  HloInstruction* input_parameter = while_body->parameter_instruction(0);
-  HloInstruction* input_tuple = while_instr->mutable_operand(0);
-
-  auto old_loop_roots = while_body->root_instruction()->mutable_operands();
-  HloComputation* parent_comp = while_instr->parent();
-  old_to_new_map[input_parameter] = input_tuple;
-
-  for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
-    if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
-      continue;
-    }
-    VLOG(2) << "Peeling instruction " << old_instr->ToString();
-    std::vector<HloInstruction*> new_operands(old_instr->operand_count());
-    for (int64_t i = 0; i < old_instr->operand_count(); i++) {
-      new_operands[i] = old_to_new_map[old_instr->mutable_operand(i)];
-    }
-    HloInstruction* new_instr =
-        parent_comp->AddInstruction(old_instr->CloneWithNewOperands(
-            old_instr->shape(), new_operands, suffix));
-
-    SetChannelIdForNewCollective(new_instr, module);
-    TF_CHECK_OK(SetSendRecvValidationForPeeledInstr(new_instr, old_instr));
-    old_to_new_map[old_instr] = new_instr;
-    VLOG(2) << "Added instruction " << new_instr->ToString()
-            << " to parent computation.";
-  }
-
-  std::vector<HloInstruction*> new_roots;
-  for (HloInstruction* instr : old_loop_roots) {
-    new_roots.push_back(old_to_new_map[instr]);
-  }
-  TF_RETURN_IF_ERROR(while_instr->ReplaceOperandWith(
-      0, old_to_new_map[while_body->root_instruction()]));
-  VLOG(2) << "Replaced with new input tuple "
-          << while_instr->operand(0)->ToString();
-
-  // Handle existing control dependencies.
-  for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
-    if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
-      HloInstruction* new_instr = old_to_new_map[old_instr];
-      VLOG(2) << "Processing control predecessors for peeled instruction "
-              << new_instr->ToString();
-      std::vector<HloInstruction*> new_control_pred;
-      new_control_pred.reserve(old_instr->control_predecessors().size());
-      for (HloInstruction* pred : old_instr->control_predecessors()) {
-        new_control_pred.push_back(old_to_new_map[pred]);
-      }
-
-      TF_RETURN_IF_ERROR(new_instr->DropAllControlDeps());
-      for (HloInstruction* new_pred : new_control_pred) {
-        TF_RETURN_IF_ERROR(new_pred->AddControlDependencyTo(new_instr));
-        VLOG(2) << "Adding " << new_pred->ToString()
-                << " to control dependency of peeled instruction: "
-                << new_instr->ToString();
-      }
-    }
-  }
-  return absl::OkStatus();
-}
-
-// TODO(olechwierowicz): Extract common logic of this and `FullyUnroll` to
-// a separate function.
-absl::StatusOr<bool> DoubleBufferingUnroll(HloInstruction* while_instr,
-                                           HloModule* module) {
-  TF_ASSIGN_OR_RETURN(auto config,
-                      while_instr->backend_config<WhileLoopBackendConfig>());
-
-  CHECK(config.has_known_trip_count())
-      << "Only loops with known trip count are supported.";
-  int64_t exact_trip_count = config.known_trip_count().n();
-  VLOG(2) << "Processing while loop " << while_instr->ToString()
-          << " with trip count: " << exact_trip_count;
-
-  HloComputation* while_body = while_instr->while_body();
-
-  VLOG(2) << "Processing root " << while_body->root_instruction()->ToString();
-
-  auto old_loop_roots = while_body->root_instruction()->mutable_operands();
-  HloInstruction* input_parameter = while_body->parameter_instruction(0);
-  VLOG(2) << "Processing input parameter " << input_parameter->ToString();
-  absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_map;
-  absl::flat_hash_set<HloInstruction*> skip_control_dep_injection;
-
-  bool is_peeled = exact_trip_count % 2;
-  if (is_peeled) {
-    VLOG(2) << "Found loops with odd trip count, 1 iteration will be peeled "
-               "outside of the main body.";
-    TF_RETURN_IF_ERROR(PeelInstructionsForOddTripCount(module, while_instr));
-    exact_trip_count -= 1;
-  }
-
-  std::string suffix = "double_buffer_clone";
-  old_to_new_map[input_parameter] = while_body->root_instruction();
-  for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
-    if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
-      continue;
-    }
-    VLOG(2) << "Cloning instruction " << old_instr->ToString();
-    std::vector<HloInstruction*> new_operands;
-    for (HloInstruction* old_operand : old_instr->mutable_operands()) {
-      new_operands.push_back(old_to_new_map[old_operand]);
-    }
-    HloInstruction* new_instr =
-        while_body->AddInstruction(old_instr->CloneWithNewOperands(
-            old_instr->shape(), new_operands, suffix));
-
-    // If an elementwise instruction with a constant operand is present, we
-    // won't inject a control dependency at the end, to allow more constant
-    // folding opportunities.
-    if (old_instr->IsElementwiseBinary() && old_instr->HasConstantOperand()) {
-      skip_control_dep_injection.insert(old_instr);
-    }
-    SetChannelIdForNewCollective(new_instr, module);
-    TF_CHECK_OK(SetSendRecvValidation(old_instr, new_instr, is_peeled));
-    old_to_new_map[old_instr] = new_instr;
-    VLOG(2) << "Added instruction " << new_instr->ToString();
-  }
-
-  while_body->set_root_instruction(
-      old_to_new_map[while_body->root_instruction()]);
-  VLOG(2) << "Replaced with new root "
-          << while_body->root_instruction()->ToString();
-
-  // Handle existing control dependencies.
-  TF_RETURN_IF_ERROR(HandleControlDependencies(while_body, old_to_new_map,
-                                               &old_loop_roots, input_parameter,
-                                               skip_control_dep_injection));
-
-  WhileLoopBackendConfig new_config;
-  new_config.mutable_known_trip_count()->set_n(exact_trip_count / 2);
-  TF_RETURN_IF_ERROR(while_instr->set_backend_config(new_config));
-  return true;  // changed
-}
-
-}  // namespace
-
-absl::StatusOr<bool> DoubleBufferLoopUnrolling::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  std::vector<HloInstruction*> while_instrs;
-  for (auto comp : module->MakeNonfusionComputations()) {
-    absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
-                    HloPredicateIsOp<HloOpcode::kWhile>);
-  }
-  VLOG(2) << "Processing " << while_instrs.size() << " while loops.";
-
-  for (HloInstruction* while_instr : while_instrs) {
-    TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config,
-                        while_instr->backend_config<WhileLoopBackendConfig>());
-    if (!config.has_known_trip_count() || config.known_trip_count().n() == 1) {
-      VLOG(2) << while_instr->ToString()
-              << " doesn't have exact trip count, skipping loop unrolling "
-                 "for now";
-      continue;
-    }
-
-    if (unroll_strategy_ == UnrollStrategy::kFullUnroll) {
-      TF_ASSIGN_OR_RETURN(changed, FullyUnroll(while_instr, module));
-    } else if (unroll_strategy_ == UnrollStrategy::kDoubleBuffer) {
-      TF_ASSIGN_OR_RETURN(changed, DoubleBufferingUnroll(while_instr, module));
-    } else {
-      LOG(FATAL) << absl::StrCat("Unhandled unrolling strategy: ",
-                                 unroll_strategy_);
-    }
-  }
-
-  VLOG(2) << "LoopDoubleBufferTransformer output: " << module->ToString();
-
-  // Run necessary cleanup to ensure LoopDoubleBufferTransformer behaves
-  // correctly.
-  if (changed) {
-    // The call graph will not be flat if one of the loops that was unrolled
-    // contains any kind of call to another computation---since the call will
-    // be duplicated, thereby adding a second callsite for that computation.
-    TF_RETURN_IF_ERROR(
-        FlattenCallGraph().Run(module, execution_threads).status());
-  }
-
-  return changed;
-}
-
-}  // end namespace gpu
-}  // end namespace xla
diff --git a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.h b/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.h
deleted file mode 100644
index 120070d..0000000
--- a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_DOUBLE_BUFFER_LOOP_UNROLLING_H_
-#define XLA_SERVICE_GPU_DOUBLE_BUFFER_LOOP_UNROLLING_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// With `kDoubleBuffer` strategy:
-//   This pass performs the unrolling-by-2 loop transformation
-//   to effectively achieve double buffering between inputs and outputs
-//   of previously rolled iterations.
-//   This pass only runs on loops with known trip counts.
-//   For an even number of iterations, unrolling-by-2 is done directly.
-//   For an odd number of iterations, the first iteration of the loop is peeled
-//   outside of the while loop to make the trip count even, and the remaining
-//   iterations are then unrolled by 2.
-//   It also updates the trip count property of the loop to the correct one
-//   (n/2).
-//
-// With `kFullUnroll` strategy:
-//   This pass fully unrolls the loop, applying the same cloning strategy as
-//   `kDoubleBuffer`, but repeating it trip-count times.
-//   It updates the trip count of the while loop to 1, and relies on other
-//   passes (like `WhileLoopSimplifier`) to simplify/get rid of the while loop
-//   eventually.
-//
-// Note that this pass will flatten the call graph if any loop has been
-// unrolled.
-// TODO(olechwierowicz): Rename the loop unroller to something more generic like
-// 'DoubleBufferLoopUnrolling'.
-class DoubleBufferLoopUnrolling : public HloModulePass {
- public:
-  enum class UnrollStrategy { kDoubleBuffer, kFullUnroll };
-
-  explicit DoubleBufferLoopUnrolling(
-      UnrollStrategy unroll_strategy = UnrollStrategy::kDoubleBuffer)
-      : unroll_strategy_(unroll_strategy) {}
-  ~DoubleBufferLoopUnrolling() override = default;
-
-  absl::string_view name() const override {
-    return "loop-double-buffer-transformer";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  UnrollStrategy unroll_strategy_;
-};
-
-}  // end namespace gpu
-}  // end namespace xla
-
-#endif  // XLA_SERVICE_GPU_DOUBLE_BUFFER_LOOP_UNROLLING_H_
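
[Editor's note] As a quick illustration of the trip-count bookkeeping described in the class comment above, here is a hedged standalone sketch (hypothetical names such as `PlanUnroll` and `UnrollPlan`, plain C++ rather than the XLA pass): `kDoubleBuffer` peels one iteration when the trip count is odd and then halves what remains, while `kFullUnroll` always ends with a trip count of 1 and trip-count copies of the body.

#include <cstdint>
#include <iostream>

enum class UnrollStrategy { kDoubleBuffer, kFullUnroll };

struct UnrollPlan {
  bool peeled;             // one iteration hoisted in front of the loop
  int64_t new_trip_count;  // value written back into known_trip_count
  int64_t body_copies;     // copies of the original body inside the loop
};

// Hypothetical helper mirroring the trip-count updates performed by
// DoubleBufferingUnroll and FullyUnroll in the .cc file above.
UnrollPlan PlanUnroll(int64_t trip_count, UnrollStrategy strategy) {
  if (strategy == UnrollStrategy::kFullUnroll) {
    return {/*peeled=*/false, /*new_trip_count=*/1, /*body_copies=*/trip_count};
  }
  const bool peeled = trip_count % 2 != 0;
  const int64_t remaining = peeled ? trip_count - 1 : trip_count;
  return {peeled, remaining / 2, /*body_copies=*/2};
}

int main() {
  // Trip count 11 (odd), as in the FullUnrollOddTripCountTest and
  // UnrolledLoopOddTripCount tests below.
  const UnrollPlan db = PlanUnroll(11, UnrollStrategy::kDoubleBuffer);
  const UnrollPlan fu = PlanUnroll(11, UnrollStrategy::kFullUnroll);
  std::cout << "kDoubleBuffer: peeled=" << db.peeled
            << " trip_count=" << db.new_trip_count << "\n";  // peeled=1 trip_count=5
  std::cout << "kFullUnroll:   trip_count=" << fu.new_trip_count
            << " body_copies=" << fu.body_copies << "\n";    // trip_count=1 body_copies=11
  return 0;
}
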
diff --git a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling_test.cc b/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling_test.cc
deleted file mode 100644
index 8fed319..0000000
--- a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling_test.cc
+++ /dev/null
@@ -1,1243 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/double_buffer_loop_unrolling.h"
-
-#include <cstdint>
-#include <memory>
-#include <optional>
-
-#include "absl/container/flat_hash_set.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/tuple_simplifier.h"
-#include "xla/test.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla.pb.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using tsl::testing::IsOkAndHolds;
-
-int64_t CountInstructions(const HloComputation& computation, HloOpcode opcode) {
-  int64_t count = 0;
-  for (const auto& instruction : computation.instructions()) {
-    if (instruction->opcode() == opcode) {
-      count++;
-    }
-  }
-  return count;
-}
-
-int64_t CountInstructions(const HloModule& module, HloOpcode opcode) {
-  int64_t count = 0;
-  for (const auto& computation : module.computations()) {
-    count += CountInstructions((*computation), opcode);
-  }
-  return count;
-}
-
-class GpuLoopDoubleBufferTransformerTest : public HloTestBase {
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_enable_while_loop_double_buffering(true);
-    return debug_options;
-  }
-};
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, FullUnrollOddTripCountTest) {
-  const char* const kModuleString = R"(
-HloModule all_gather_overlapping
-condition {
-  input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=3
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
- param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
- param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
- cond = s32[] get-tuple-element(input_tuple), index=3
- c0 = f32[] constant(0)
- splat_c0 = f32[1,128] broadcast(c0), dimensions={}
- add = f32[1,128] add(splat_c0, param_0)
- all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
- all-gather-done = f32[2,128] all-gather-done(all-gather-start)
- ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[1,128] parameter(0)
- param_1 = f32[2,128] parameter(1)
- param_2 = s32[] constant(0)
- tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
- ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
-  TupleSimplifier tuple_simp;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get()));
-  EXPECT_TRUE(changed);
-  TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get()));
-  EXPECT_TRUE(changed);
-  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
-      *module->entry_computation(), HloOpcode::kWhile);
-  TF_ASSERT_OK_AND_ASSIGN(
-      WhileLoopBackendConfig config,
-      while_instruction->backend_config<WhileLoopBackendConfig>());
-  int64_t exact_trip_count = config.known_trip_count().n();
-  EXPECT_EQ(exact_trip_count, 1);
-  EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
-                              HloOpcode::kAllGatherStart),
-            11);
-  EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 11);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, FullUnrollEvenTripCountTest) {
-  const char* const kModuleString = R"(
-HloModule all_gather_overlapping
-condition {
-  input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=3
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
- param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
- param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
- cond = s32[] get-tuple-element(input_tuple), index=3
- c0 = f32[] constant(0)
- splat_c0 = f32[1,128] broadcast(c0), dimensions={}
- add = f32[1,128] add(splat_c0, param_0)
- // Start all-gather communication
- all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
- all-gather-done = f32[2,128] all-gather-done(all-gather-start)
- ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[1,128] parameter(0)
- param_1 = f32[2,128] parameter(1)
- param_2 = s32[] constant(0)
- tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
- ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
-  TupleSimplifier tuple_simp;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get()));
-  EXPECT_TRUE(changed);
-  TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
-      *module->entry_computation(), HloOpcode::kWhile);
-  TF_ASSERT_OK_AND_ASSIGN(
-      WhileLoopBackendConfig config,
-      while_instruction->backend_config<WhileLoopBackendConfig>());
-  int64_t exact_trip_count = config.known_trip_count().n();
-  EXPECT_EQ(exact_trip_count, 1);
-  EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
-                              HloOpcode::kAllGatherStart),
-            10);
-  EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 10);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, UnrolledLoopEvenTripCount) {
-  const char* const kModuleString = R"(
-HloModule all_gather_overlapping
-condition {
-  input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=3
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
- param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
- param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
- cond = s32[] get-tuple-element(input_tuple), index=3
- c0 = f32[] constant(0)
- splat_c0 = f32[1,128] broadcast(c0), dimensions={}
- add = f32[1,128] add(splat_c0, param_0)
- // Start all-gather communication
- all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
- // Intertwined with the all-gather communication, an operation happens which
- // depends on param_1, but crucially has a different output shape (which
- // excludes reusing param_1's buffer for its output).
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
- // The all-gather communication finishes
- all-gather-done = f32[2,128] all-gather-done(all-gather-start)
- ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[1,128] parameter(0)
- param_1 = f32[2,128] parameter(1)
- param_2 = s32[] constant(0)
- tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
- ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer;
-  TupleSimplifier tuple_simp;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get()));
-  EXPECT_TRUE(changed);
-  TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
-      *module->entry_computation(), HloOpcode::kWhile);
-  TF_ASSERT_OK_AND_ASSIGN(
-      WhileLoopBackendConfig config,
-      while_instruction->backend_config<WhileLoopBackendConfig>());
-  int64_t exact_trip_count = config.known_trip_count().n();
-  // We expect that after unrolling, the total trip count is half of the
-  // original count.
-  EXPECT_EQ(exact_trip_count, 5);
-  // We expect that after unrolling, there should be 2 allgather starts,
-  // both in while body.
-  EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
-                              HloOpcode::kAllGatherStart),
-            2);
-  EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 2);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, UnrolledLoopOddTripCount) {
-  const char* const kModuleString = R"(
-HloModule all_gather_overlapping
-condition {
-  input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=3
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
- param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
- param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
- cond = s32[] get-tuple-element(input_tuple), index=3
- c0 = f32[] constant(0)
- splat_c0 = f32[1,128] broadcast(c0), dimensions={}
- add = f32[1,128] add(splat_c0, param_0)
- // Start all-gather communication
- all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
- // Intertwined with the all-gather communication, an operation happens which
- // depends on param_1, but crucially has a different output shape (which
- // excludes reusing param_1's buffer for its output).
- c1_s32 = s32[] constant(1)
- c0_s32 = s32[] constant(0)
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
- // The all-gather communication finishes
- all-gather-done = f32[2,128] all-gather-done(all-gather-start)
- ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[1,128] parameter(0)
- param_1 = f32[2,128] parameter(1)
- param_2 = s32[] constant(0)
- tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
- ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer;
-  TupleSimplifier tuple_simp;
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
-
-  // We expect that for the while loop, no further copy needs to be added to the
-  // module.
-  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
-      *module->entry_computation(), HloOpcode::kWhile);
-  TF_ASSERT_OK_AND_ASSIGN(
-      WhileLoopBackendConfig config,
-      while_instruction->backend_config<WhileLoopBackendConfig>());
-  int64_t exact_trip_count = config.known_trip_count().n();
-  // We expect that after unrolling, the total trip count is half of the
-  // original count.
-  EXPECT_EQ(exact_trip_count, 5);
-
-  // We expect that after unrolling, there should be 3 allgather starts,
-  // 1 in parent computation, 2 in while body.
-  EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
-                              HloOpcode::kAllGatherStart),
-            2);
-  EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 3);
-
-  // We expect that after unrolling, the third operand of the input tuple should
-  // be the peeled allgather done.
-  EXPECT_EQ(while_instruction->operand(0)->operand(2)->opcode(),
-            HloOpcode::kAllGatherDone);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
-       UnrolledLoopNoControlDepsForConstantAdd) {
-  const char* const kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
-  input_tuple = (f32[], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- c2 = f32[] constant(2)
- add = f32[] add(c2, param_0)
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(add, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer;
-  TupleSimplifier tuple_simp;
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
-
-  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
-      *module->entry_computation(), HloOpcode::kWhile);
-  TF_ASSERT_OK_AND_ASSIGN(
-      WhileLoopBackendConfig config,
-      while_instruction->backend_config<WhileLoopBackendConfig>());
-  int64_t exact_trip_count = config.known_trip_count().n();
-  // We expect that after unrolling, the total trip count is half of the
-  // original count.
-  EXPECT_EQ(exact_trip_count, 5);
-
-  // We expect that after unrolling, there should be 4 adds
-  EXPECT_EQ(
-      CountInstructions((*while_instruction->while_body()), HloOpcode::kAdd),
-      4);
-
-  // We expect that after unrolling, the first operand of the output tuple
-  // should not have any control dependency since it's an elementwise add with a
-  // constant operand.
-  EXPECT_EQ(while_instruction->while_body()
-                ->root_instruction()
-                ->operand(0)
-                ->control_predecessors()
-                .size(),
-            0);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
-       UnrolledLoopNoControlDepsForCollective) {
-  const char* const kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
-  input_tuple = (f32[], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-ar_add {
-  Arg_1 = f32[] parameter(1)
-  Arg_0 = f32[] parameter(0)
-  ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- all-reduce-start = f32[] all-reduce-start(param_0), channel_id=8, replica_groups={{0}}, to_apply=ar_add, backend_config="{\"is_sync\":false}"
- one = s32[] constant(1)
- all-reduce-done = f32[] all-reduce-done(all-reduce-start)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(all-reduce-done, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer;
-  TupleSimplifier tuple_simp;
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
-
-  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
-      *module->entry_computation(), HloOpcode::kWhile);
-  TF_ASSERT_OK_AND_ASSIGN(
-      WhileLoopBackendConfig config,
-      while_instruction->backend_config<WhileLoopBackendConfig>());
-  int64_t exact_trip_count = config.known_trip_count().n();
-  // We expect that after unrolling, the total trip count is half of the
-  // original count.
-  EXPECT_EQ(exact_trip_count, 5);
-
-  // We expect that after unrolling, there should be 2 all-reduce-starts
-  EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
-                              HloOpcode::kAllReduceStart),
-            2);
-  absl::flat_hash_set<int64_t> channel_ids;
-  hlo_query::ForEachInstructionWithOpcode(
-      *while_instruction->while_body(), HloOpcode::kAllReduceStart,
-      [&channel_ids](HloInstruction* ar) {
-        // We expect that after unrolling, all-reduces should not have any
-        // control deps.
-        EXPECT_EQ(ar->control_predecessors().size(), 0);
-        channel_ids.insert(*(ar->channel_id()));
-      });
-  // We expect that the 2 all-reduces will have different channel ids.
-  EXPECT_EQ(channel_ids.size(), 2);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
-       FullyUnrolledLoopNoControlDepsForCollective) {
-  const char* const kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
-  input_tuple = (f32[], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-ar_add {
-  Arg_1 = f32[] parameter(1)
-  Arg_0 = f32[] parameter(0)
-  ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-
-body {
- input_tuple = (f32[], s32[]) parameter(0)
- param_0 = f32[] get-tuple-element(input_tuple), index=0
- cond = s32[] get-tuple-element(input_tuple), index=1
- all-reduce-start = f32[] all-reduce-start(param_0), channel_id=8, replica_groups={{0}}, to_apply=ar_add, backend_config="{\"is_sync\":false}"
- one = s32[] constant(1)
- all-reduce-done = f32[] all-reduce-done(all-reduce-start)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output_tuple = (f32[], s32[]) tuple(all-reduce-done, cond_plus_1)
-}
-
-ENTRY main {
- param_0 = f32[] parameter(0)
- param_2 = s32[] constant(0)
- tuple = (f32[], s32[]) tuple(param_0, param_2)
- ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
-  TupleSimplifier tuple_simp;
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
-
-  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
-      *module->entry_computation(), HloOpcode::kWhile);
-  TF_ASSERT_OK_AND_ASSIGN(
-      WhileLoopBackendConfig config,
-      while_instruction->backend_config<WhileLoopBackendConfig>());
-  int64_t exact_trip_count = config.known_trip_count().n();
-  EXPECT_EQ(exact_trip_count, 1);
-
-  // We expect that after unrolling, there should be 10 all-reduce-starts
-  EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
-                              HloOpcode::kAllReduceStart),
-            10);
-  absl::flat_hash_set<int64_t> channel_ids;
-  hlo_query::ForEachInstructionWithOpcode(
-      *while_instruction->while_body(), HloOpcode::kAllReduceStart,
-      [&channel_ids](HloInstruction* ar) {
-        // We expect that after unrolling, all-reduces should not have any
-        // control deps.
-        EXPECT_EQ(ar->control_predecessors().size(), 0);
-        channel_ids.insert(*(ar->channel_id()));
-      });
-  // We expect that all 10 all-reduces will have different channel ids.
-  EXPECT_EQ(channel_ids.size(), 10);
-}
-
-// The following 2 tests also address the regression described here:
-// https://github.com/openxla/xla/issues/6353
-TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopRemainsFlattened) {
-  const char* const kModuleString = R"(
-HloModule loop_unrolling_nested_while_loop_remains_flattened
-
-condition_nested {
-  input_tuple = (s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=0
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output = (s32[]) tuple(cond_plus_1)
-}
-
-condition {
-  input_tuple = (s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=0
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
-  input_tuple = (s32[]) parameter(0)
-  ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested
-}
-
-ENTRY main {
- param_0 = (s32[]) parameter(0)
- ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer;
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
-  absl::flat_hash_set<const HloComputation*> while_loops_callees;
-
-  hlo_query::ForEachInstructionWithOpcode(
-      *module, HloOpcode::kWhile,
-      [&while_loops_callees](HloInstruction* instr) {
-        EXPECT_TRUE(
-            while_loops_callees.insert(instr->while_condition()).second);
-        EXPECT_TRUE(while_loops_callees.insert(instr->while_body()).second);
-      });
-
-  // We expect that the nested while loop has been duplicated, along with its
-  // associated computations.
-  EXPECT_EQ(while_loops_callees.size(), 6);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
-       NestedWhileLoopRemainsFlattenedOddTripCount) {
-  const char* const kModuleString = R"(
-HloModule loop_unrolling_nested_while_loop_remains_flattened
-
-condition_nested {
-  input_tuple = (s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=0
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output = (s32[]) tuple(cond_plus_1)
-}
-
-condition {
-  input_tuple = (s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=0
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
-  input_tuple = (s32[]) parameter(0)
-  ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested
-}
-
-ENTRY main {
- param_0 = (s32[]) parameter(0)
- ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer;
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
-  absl::flat_hash_set<const HloComputation*> while_loops_callees;
-
-  hlo_query::ForEachInstructionWithOpcode(
-      *module, HloOpcode::kWhile,
-      [&while_loops_callees](HloInstruction* instr) {
-        EXPECT_TRUE(
-            while_loops_callees.insert(instr->while_condition()).second);
-        EXPECT_TRUE(while_loops_callees.insert(instr->while_body()).second);
-      });
-
-  // We expect that the nested while loop has been duplicated, along with its
-  // associated computations.
-  EXPECT_EQ(while_loops_callees.size(), 8);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
-       NestedWhileLoopRemainsFlattenedWhenFullyUnrolled) {
-  const char* const kModuleString = R"(
-HloModule loop_unrolling_nested_while_loop_remains_flattened
-
-condition_nested {
-  input_tuple = (s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=0
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output = (s32[]) tuple(cond_plus_1)
-}
-
-condition {
-  input_tuple = (s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=0
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-
-body {
-  input_tuple = (s32[]) parameter(0)
-  ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested
-}
-
-ENTRY main {
- param_0 = (s32[]) parameter(0)
- ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
-  absl::flat_hash_set<const HloComputation*> while_loops_callees;
-
-  hlo_query::ForEachInstructionWithOpcode(
-      *module, HloOpcode::kWhile,
-      [&while_loops_callees](HloInstruction* instr) {
-        EXPECT_TRUE(
-            while_loops_callees.insert(instr->while_condition()).second);
-        EXPECT_TRUE(while_loops_callees.insert(instr->while_body()).second);
-      });
-
-  hlo_query::ForEachInstructionWithOpcode(
-      *module->entry_computation(), HloOpcode::kWhile,
-      [](HloInstruction* instr) {
-        TF_ASSERT_OK_AND_ASSIGN(
-            WhileLoopBackendConfig config,
-            instr->backend_config<WhileLoopBackendConfig>());
-        int64_t exact_trip_count = config.known_trip_count().n();
-        EXPECT_EQ(exact_trip_count, 1);
-      });
-
-  // We expect that the nested while loop has been fully duplicated 10
-  // times. The one outer while loop still remains, so that's 11 while
-  // instructions in total. We check that there are 22 distinct computations,
-  // one body and one condition for each while loop.
-  EXPECT_EQ(while_loops_callees.size(), 22);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopAreUnrolled) {
-  const char* const kModuleString = R"(
-HloModule loop_unrolling_nested_are_unrolled
-condition_nested {
-  input_tuple = (s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=0
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-body_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output = (s32[]) tuple(cond_plus_1)
-}
-condition {
-  input_tuple = (s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=0
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-body {
-  input_tuple = (s32[]) parameter(0)
-  ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"11"}}
-}
-ENTRY main {
- param_0 = (s32[]) parameter(0)
- ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer;
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
-  int64_t num_whiles = 0;
-  hlo_query::ForEachInstructionWithOpcode(
-      *module, HloOpcode::kWhile, [&num_whiles](HloInstruction* instr) {
-        EXPECT_EQ(instr->backend_config<WhileLoopBackendConfig>()
-                      ->known_trip_count()
-                      .n(),
-                  5);
-        ++num_whiles;
-      });
-  // We expect the number of while loops to be 4 in total after unrolling.
-  EXPECT_EQ(num_whiles, 4);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopAreFullyUnrolled) {
-  const char* const kModuleString = R"(
-HloModule loop_unrolling_nested_are_unrolled
-condition_nested {
-  input_tuple = (s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=0
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-body_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output = (s32[]) tuple(cond_plus_1)
-}
-condition {
-  input_tuple = (s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=0
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-body {
-  input_tuple = (s32[]) parameter(0)
-  ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"11"}}
-}
-ENTRY main {
- param_0 = (s32[]) parameter(0)
- ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
-  int64_t num_whiles = 0;
-  hlo_query::ForEachInstructionWithOpcode(
-      *module, HloOpcode::kWhile, [&num_whiles](HloInstruction* instr) {
-        EXPECT_EQ(instr->backend_config<WhileLoopBackendConfig>()
-                      ->known_trip_count()
-                      .n(),
-                  1);
-        ++num_whiles;
-      });
-  EXPECT_EQ(num_whiles, 12);
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, WhileLoopWithCollectivePermute) {
-  const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
-  input_tuple = (f32[], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
-  Arg_1 = f32[] parameter(1)
-  Arg_0 = f32[] parameter(0)
-  ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
-  input_tuple = (f32[], s32[]) parameter(0)
-  param_0 = f32[] get-tuple-element(input_tuple), index=0
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}},
-                             frontend_attributes={_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"}
-  one = s32[] constant(1)
-  cond_plus_1 = s32[] add(cond, one)
-  ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
-}
-ENTRY main {
-  param_0 = f32[] parameter(0)
-  param_2 = s32[] constant(0)
-  tuple = (f32[], s32[]) tuple(param_0, param_2)
-  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-}
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-  VLOG(0) << module->ToString();
-  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
-    // CHECK: %body {{.+}} {
-    // CHECK:   %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}"}
-    // CHECK:   %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
-    // CHECK:   %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
-    // CHECK:   %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}"}
-    // CHECK:   ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
-    // CHECK: }
-    // CHECK: ENTRY %main {{.+}} {
-    // CHECK-NOT: collective-permute
-    // CHECK: }
-  )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
-       WhileLoopWithCollectivePermutePeeled) {
-  const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
-  input_tuple = (f32[], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  trip_count = s32[] constant(15)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
-  Arg_1 = f32[] parameter(1)
-  Arg_0 = f32[] parameter(0)
-  ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
-  input_tuple = (f32[], s32[]) parameter(0)
-  param_0 = f32[] get-tuple-element(input_tuple), index=0
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,0}},
-                             frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10},{4,11},{5,12},{6,13},{7,14}}"}
-  one = s32[] constant(1)
-  cond_plus_1 = s32[] add(cond, one)
-  ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
-}
-ENTRY main {
-  param_0 = f32[] parameter(0)
-  param_2 = s32[] constant(0)
-  tuple = (f32[], s32[]) tuple(param_0, param_2)
-  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"15"}}
-}
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-  VLOG(0) << module->ToString();
-
-  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
-    // CHECK: %body
-    // CHECK:   %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6},{3,6}{{[}]}}"}
-    // CHECK:   %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
-    // CHECK:   %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]])
-    // CHECK:   %[[cp2:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6}{{[}]}}"}
-    // CHECK:   ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
-    // CHECK: ENTRY %main {{.+}} {
-    // CHECK:   %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}"}
-    // CHECK:   %[[out_peeled:.+]] = {{.+}} tuple({{.+}} %[[cp_peeled]], {{.+}})
-    // CHECK:   %[[while:.+]] = {{.+}} while({{.+}} %[[out_peeled]])
-    // CHECK: }
-    )"));
-}
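Editorial note: with an odd trip count (15 here), one iteration is peeled into the entry computation before the remaining 14 are double buffered, which is why %main gains its own collective-permute. The peeled instance only ever runs original iteration 0, so a replica keeps {0,0} when its interval starts at 0 and otherwise gets the empty interval {1,0}, as in the CHECK line above. A small sketch under those assumptions (PeeledValidationInterval is an invented name, not part of this change):

// Editorial sketch: validation of the peeled collective-permute in %main.
#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> PeeledValidationInterval(int64_t lo, int64_t hi) {
  // The peeled instance corresponds to original iteration 0 only.
  return (lo <= 0 && 0 <= hi) ? std::pair<int64_t, int64_t>{0, 0}
                              : std::pair<int64_t, int64_t>{1, 0};
}
// For the input {{0,7},{1,8},...,{7,14}} only replica 0 has lo == 0, which
// yields the expected "{0,0},{1,0},{1,0},...,{1,0}" attribute above.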
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
-       WhileLoopWithCollectivePermuteBackwardCycle) {
-  const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
-  input_tuple = (f32[], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  trip_count = s32[] constant(14)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
-  Arg_1 = f32[] parameter(1)
-  Arg_0 = f32[] parameter(0)
-  ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
-  input_tuple = (f32[], s32[]) parameter(0)
-  param_0 = f32[] get-tuple-element(input_tuple), index=0
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}},
-                             frontend_attributes={_xla_send_recv_validation="{{7,13},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}}"}
-  one = s32[] constant(1)
-  cond_plus_1 = s32[] add(cond, one)
-  ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
-}
-ENTRY main {
-  param_0 = f32[] parameter(0)
-  param_2 = s32[] constant(0)
-  tuple = (f32[], s32[]) tuple(param_0, param_2)
-  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"14"}}
-}
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
-  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
-    // CHECK: %body
-    // CHECK:   %[[cp1:.+]] = f32[] collective-permute(f32[] %param_0), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{4,6},{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3}{{[}]}}"}
-    // CHECK:   %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
-    // CHECK:   %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
-    // CHECK:   %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3},{0,2}{{[}]}}"}
-    // CHECK:   ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
-    // CHECK: ENTRY %main
-    // CHECK-NOT: collective-permute
-    // CHECK: }
-  )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
-       WhileLoopWithCollectivePermuteBackwardCyclePeeled) {
-  const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
-  input_tuple = (f32[], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  trip_count = s32[] constant(15)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
-  Arg_1 = f32[] parameter(1)
-  Arg_0 = f32[] parameter(0)
-  ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
-  input_tuple = (f32[], s32[]) parameter(0)
-  param_0 = f32[] get-tuple-element(input_tuple), index=0
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}},
-                             frontend_attributes={_xla_send_recv_validation="{{7,14},{6,13},{5,12},{4,11},{3,10},{2,9},{1,8},{0,7}}"}
-  one = s32[] constant(1)
-  cond_plus_1 = s32[] add(cond, one)
-  ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
-}
-ENTRY main {
-  param_0 = f32[] parameter(0)
-  param_2 = s32[] constant(0)
-  tuple = (f32[], s32[]) tuple(param_0, param_2)
-  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"15"}}
-}
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
-  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
-    // CHECK: %body
-    // CHECK:   %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{3,6},{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3}{{[}]}}"}
-    // CHECK:   %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
-    // CHECK:   %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
-    // CHECK:   %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3},{0,2}{{[}]}}"}
-    // CHECK:   ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
-    // CHECK: }
-    // CHECK: ENTRY %main
-    // CHECK:   %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}"}
-    // CHECK:   %[[out_peeled:.+]] = {{.+}} tuple({{.+}} %[[cp_peeled]], {{.+}})
-    // CHECK:   ROOT {{.+}} = {{.+}} while({{.+}} %[[out_peeled]])
-    // CHECK: }
-  )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
-       WhileLoopWithCollectivePermuteStartDone) {
-  const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
-  input_tuple = (f32[], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
-  Arg_1 = f32[] parameter(1)
-  Arg_0 = f32[] parameter(0)
-  ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
-  input_tuple = (f32[], s32[]) parameter(0)
-  param_0 = f32[] get-tuple-element(input_tuple), index=0
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  collective-permute-start = (f32[], f32[], u32[], u32[]) collective-permute-start(param_0), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}},
-                             frontend_attributes={_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"}
-  collective-permute = f32[] collective-permute-done(collective-permute-start)
-  one = s32[] constant(1)
-  cond_plus_1 = s32[] add(cond, one)
-  ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
-}
-ENTRY main {
-  param_0 = f32[] parameter(0)
-  param_2 = s32[] constant(0)
-  tuple = (f32[], s32[]) tuple(param_0, param_2)
-  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-}
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
-  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
-    // CHECK: %body
-    // CHECK:   %[[cp_start1:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}"}
-    // CHECK:   %[[cp1:.+]] = {{.+}} collective-permute-done({{.+}} %[[cp_start1]])
-    // CHECK:   %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
-    // CHECK:   %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
-    // CHECK:   %[[cp_start2:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}"}
-    // CHECK:   %[[cp2:.+]] = {{.+}} collective-permute-done({{.+}} %[[cp_start2]])
-    // CHECK:   ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
-    // CHECK: }
-    // CHECK: ENTRY %main
-    // CHECK-NOT: collective-permute
-    // CHECK: }
-  )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, WhileLoopWithRecvDone) {
-  const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
-  input_tuple = (f32[], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
-  Arg_1 = f32[] parameter(1)
-  Arg_0 = f32[] parameter(0)
-  ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
-  input_tuple = (f32[], s32[]) parameter(0)
-  param_0 = f32[] get-tuple-element(input_tuple), index=0
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  after-all.0 = token[] after-all()
-  recv.0 = (f32[], u32[], token[]) recv(after-all.0), channel_id=1,
-        frontend_attributes={
-          _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,0}}",
-          _xla_send_recv_pipeline="0",
-          _xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"
-        }
-  recv-done.0 = (f32[], token[]) recv-done(recv.0), channel_id=1,
-        frontend_attributes={
-          _xla_send_recv_pipeline="0"
-        }
-  recv-data = f32[] get-tuple-element(recv-done.0), index=0
-  one = s32[] constant(1)
-  cond_plus_1 = s32[] add(cond, one)
-  ROOT output_tuple = (f32[], s32[]) tuple(recv-data, cond_plus_1)
-}
-ENTRY main {
-  param_0 = f32[] parameter(0)
-  param_2 = s32[] constant(0)
-  tuple = (f32[], s32[]) tuple(param_0, param_2)
-  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-}
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
-  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
-    // CHECK: %body
-    // CHECK:   %[[recv1:.+]] = {{.+}} recv({{.+}}), {{.+}},_xla_send_recv_validation="{{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}"
-    // CHECK:   %[[recv2:.+]] = {{.+}} recv({{.+}}), {{.+}},_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}"
-    // CHECK: ENTRY %main
-    // CHECK-NOT: recv
-    // CHECK: }
-  )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest, WhileLoopWithSendDone) {
-  const char* kModuleString = R"(
-HloModule loop_unrolling_no_deps
-condition {
-  input_tuple = (f32[], s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  trip_count = s32[] constant(10)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-ar_add {
-  Arg_1 = f32[] parameter(1)
-  Arg_0 = f32[] parameter(0)
-  ROOT add_ar = f32[] add(Arg_1, Arg_0)
-}
-body {
-  input_tuple = (f32[], s32[]) parameter(0)
-  param_0 = f32[] get-tuple-element(input_tuple), index=0
-  cond = s32[] get-tuple-element(input_tuple), index=1
-  after-all.0 = token[] after-all()
-  send.0 = (f32[], u32[], token[]) send(param_0, after-all.0), channel_id=1,
-        frontend_attributes={
-          _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,0}}",
-          _xla_send_recv_pipeline="0",
-          _xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"
-        }
-  send-done.0 = token[] send-done(send.0), channel_id=1,
-        frontend_attributes={
-          _xla_send_recv_pipeline="0"
-        }
-  one = s32[] constant(1)
-  cond_plus_1 = s32[] add(cond, one)
-  ROOT output_tuple = (f32[], s32[]) tuple(param_0, cond_plus_1)
-}
-ENTRY main {
-  param_0 = f32[] parameter(0)
-  param_2 = s32[] constant(0)
-  tuple = (f32[], s32[]) tuple(param_0, param_2)
-  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
-}
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
-
-  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
-    // CHECK: %body
-    // CHECK:   %[[send1:.+]] = {{.+}} send({{.+}}), {{.+}},_xla_send_recv_validation="{{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}"
-    // CHECK:   %[[send2:.+]] = {{.+}} send({{.+}}), {{.+}},_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}"
-    // CHECK: ENTRY %main
-    // CHECK-NOT: send
-    // CHECK: }
-  )"));
-}
-
-TEST_F(GpuLoopDoubleBufferTransformerTest,
-       WhileLoopWithTripCount1ShouldBeSkipped) {
-  const char* const kModuleString = R"(
-HloModule loop_unrolling_skipped
-condition_nested {
-  input_tuple = (s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=0
-  trip_count = s32[] constant(0)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-body_nested {
- input_tuple = (s32[]) parameter(0)
- cond = s32[] get-tuple-element(input_tuple), index=0
- one = s32[] constant(1)
- cond_plus_1 = s32[] add(cond, one)
- ROOT output = (s32[]) tuple(cond_plus_1)
-}
-condition {
-  input_tuple = (s32[]) parameter(0)
-  cond = s32[] get-tuple-element(input_tuple), index=0
-  trip_count = s32[] constant(0)
-  ROOT done = pred[] compare(cond, trip_count), direction=LT
-}
-body {
-  input_tuple = (s32[]) parameter(0)
-  ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"1"}}
-}
-ENTRY main {
- param_0 = (s32[]) parameter(0)
- ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"1"}}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  DoubleBufferLoopUnrolling double_buffer(
-      DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
-  // The processing of the loop should be completely skipped.
-  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(false));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.cc
deleted file mode 100644
index 0919241..0000000
--- a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.cc
+++ /dev/null
@@ -1,513 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <functional>
-#include <iterator>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/ffi/ffi_api.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/custom_call_target_registry.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/gpu_constants.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-// A dataflow path flowing from a definition to a user.
-using DefUseDataflowPath = absl::InlinedVector<HloInstruction*, 2>;
-
-// All dataflow paths flowing from a definition to all users. Each user will
-// have a separate entry in the vector.
-using DefUseDataflowPaths = absl::InlinedVector<DefUseDataflowPath, 4>;
-
-// A dataflow path flowing from a user to a definition.
-using UseDefDataflowPath = absl::InlinedVector<HloInstruction*, 4>;
-
-// All dataflow paths flowing from a user to all definitions of its operands.
-using UseDefDataflowPaths = absl::InlinedVector<HloInstruction*, 8>;
-
-using DataflowPathView = absl::Span<HloInstruction* const>;
-using DataflowPathsView = absl::Span<DataflowPathView>;
-
-using InstructionSet = absl::flat_hash_set<HloInstruction*>;
-
-bool IsNoOp(const HloInstruction* hlo) {
-  return HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kTuple,
-                          HloOpcode::kGetTupleElement>(hlo);
-}
-
-bool IsCustomCall(const HloInstruction* hlo, absl::string_view platform_name) {
-  auto* custom_call = DynCast<HloCustomCallInstruction>(hlo);
-  if (custom_call == nullptr) return false;
-
-  // TODO(vuson): properly handle token by following
-  // `LhloDialectEmitter::EmitCustomCallOp`'s `CreateOperands` logic for
-  // `LhloDialectEmitter::EmitFusionOp`'s `RewriteFusionOperand`
-  if (custom_call->shape().IsTuple() &&
-      absl::c_any_of(
-          custom_call->shape().tuple_shapes(),
-          [&](const Shape& sub_shape) { return sub_shape.IsToken(); }))
-    return false;
-
-  const std::string call_target_name = custom_call->custom_call_target();
-
-  bool is_ffi_custom_call =
-      custom_call->api_version() == CustomCallApiVersion::API_VERSION_TYPED_FFI;
-
-  void* call_target = CustomCallTargetRegistry::Global()->Lookup(
-      call_target_name, std::string(platform_name));
-
-  absl::StatusOr<ffi::HandlerRegistration> handler_registration =
-      ffi::FindHandler(call_target_name, platform_name);
-
-  // At least one implementation should be available at run time.
-  bool found_custom_call = !is_ffi_custom_call && call_target != nullptr;
-  bool found_ffi_handler = is_ffi_custom_call && handler_registration.ok();
-
-  return found_custom_call || found_ffi_handler;
-}
-
-// Returns true if the slice is 128-byte-aligned. The slice starting
-// address is determined by the product of all non-sliced dimensions and an
-// offset defined by `slice_starts` of the slice op.
-//
-// For dynamic cases, we don't have info about the start indices, so we have to
-// be conservative and only accept sliced shapes whose product of all
-// non-sliced dimensions is a multiple of `kXlaAllocatedBufferAlignBytes`.
-bool IsAlignedSlice(const HloInstruction* slice) {
-  DCHECK(slice->opcode() == HloOpcode::kSlice ||
-         slice->opcode() == HloOpcode::kDynamicSlice ||
-         slice->opcode() == HloOpcode::kDynamicUpdateSlice)
-      << "Unknown slice operation: " << slice->ToString();
-
-  if (!IsContiguousSlice(*slice)) return false;
-
-  auto [full_shape, slice_shape] = [&] {
-    if (auto* dus = DynCast<HloDynamicUpdateSliceInstruction>(slice)) {
-      return std::make_pair(dus->shape(), dus->update()->shape());
-    }
-    return std::make_pair(slice->operand(0)->shape(), slice->shape());
-  }();
-
-  auto strides = ShapeUtil::ByteStrides(slice_shape);
-  if (!strides.has_value()) return false;
-
-  for (auto dim : slice_shape.layout().minor_to_major()) {
-    if ((strides.value()[dim] % kXlaAllocatedBufferAlignBytes) == 0) {
-      return true;
-    }
-    if (slice_shape.dimensions(dim) < full_shape.dimensions(dim)) {
-      return (slice->opcode() == HloOpcode::kSlice &&
-              (((*strides)[dim] * slice->slice_starts(dim)) %
-                   kXlaAllocatedBufferAlignBytes ==
-               0));
-    }
-  }
-  return true;
-}
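Editorial note: a worked instance of the stride check above, using the f16[1,8,8]{2,1,0} slice of an f16[2,8,8] operand that appears throughout the tests for this pass, and the 128-byte alignment named in the comment. This is an illustrative trace, not part of this change:

// Editorial sketch: byte-stride walk for a f16[1,8,8]{2,1,0} slice taken from
// a f16[2,8,8] operand, assuming 128-byte buffer alignment.
#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t kAlignBytes = 128;  // value stated in the comment above
  constexpr int64_t kElemBytes = 2;     // f16
  // Byte strides of the sliced shape, walked minor-to-major (dims 2, 1, 0):
  // 2, 2*8 = 16, 2*8*8 = 128.
  const int64_t strides[] = {kElemBytes, kElemBytes * 8, kElemBytes * 8 * 8};
  const int64_t slice_dims[] = {8, 8, 1};  // same minor-to-major order
  const int64_t full_dims[] = {8, 8, 2};
  for (int i = 0; i < 3; ++i) {
    if (strides[i] % kAlignBytes == 0) {
      // Reached for the major-most dimension (stride 128), so the slice is
      // accepted, mirroring IsAlignedSlice returning true.
      std::cout << "aligned at stride " << strides[i] << " bytes\n";
      return 0;
    }
    if (slice_dims[i] < full_dims[i]) {
      // Would be reached if a sliced dimension appeared before an aligned
      // stride; the real IsAlignedSlice then only accepts a static slice
      // whose stride * slice_start is a multiple of kAlignBytes.
      break;
    }
  }
  return 0;
}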
-
-UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) {
-  UseDefDataflowPaths sliced_operand_paths;
-
-  // This set is used to avoid duplicates in the matched results. It contains
-  // the matched instructions that we have seen so far.
-  InstructionSet processed_instrs;
-
-  const auto& aliasing_pairs =
-      Cast<HloCustomCallInstruction>(instr)->output_to_operand_aliasing();
-  absl::flat_hash_set<int64_t> aliased_operands;
-  for (const auto& pair : aliasing_pairs) {
-    aliased_operands.insert(pair.second.first);
-  }
-
-  for (const auto* operand : instr->operands()) {
-    // output_to_operand_aliasing means the operand is to be materialized, which
-    // is against the whole idea of address computation fusion. Skip this
-    // operand.
-    if (aliased_operands.contains(instr->operand_index(operand))) continue;
-    UseDefDataflowPath maybe_sliced_operand_path;
-    bool slice_found = false;
-    // TODO: currently HloFindIf exits upon encountering the first node that
-    // matches. This works well if each operand only has 1 data flow (i.e. only
-    // flows through unary op). We might want to keep finding until the queue is
-    // empty: if the operand is a tuple, it might have different data flows
-    // (i.e. 1 for each element).
-    auto maybe_slice_instr =
-        HloBfsFindIf({operand}, [&](const HloInstruction* cur) {
-          // If the node is a match that has been processed, stop the traversal.
-          if (processed_instrs.contains(cur)) return true;
-
-          maybe_sliced_operand_path.push_back(const_cast<HloInstruction*>(cur));
-
-          if (IsOpcodeAnyOf<HloOpcode::kDynamicSlice, HloOpcode::kSlice>(cur)) {
-            if (IsAlignedSlice(cur)) {
-              slice_found = true;
-              return slice_found;
-            }
-          }
-
-          return !IsNoOp(cur);
-        });
-
-    if (maybe_slice_instr == std::nullopt) continue;
-
-    if (slice_found || processed_instrs.contains(maybe_slice_instr.value())) {
-      // Even in the case of stopping at a match that has been processed, we
-      // still need to add instructions encountered in the sliced operand path
-      // during the latest traversal.
-      sliced_operand_paths.insert(sliced_operand_paths.end(),
-                                  maybe_sliced_operand_path.rbegin(),
-                                  maybe_sliced_operand_path.rend());
-      processed_instrs.insert(maybe_sliced_operand_path.begin(),
-                              maybe_sliced_operand_path.end());
-    }
-  }
-
-  sliced_operand_paths.push_back(const_cast<HloInstruction*>(instr));
-  return sliced_operand_paths;
-}
-
-// Each user of `instr` that goes into a DUS will have an entry in the returned
-// vector.
-// Each entry contains the sliced paths for that user, i.e. the sequence of ops
-// following the dataflow from the user itself to the DUS (included).
-DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) {
-  DefUseDataflowPaths sliced_user_paths;
-  // This set is used to avoid duplicates in the matched results. It contains
-  // the matched instructions that we have seen so far.
-  InstructionSet processed_instrs;
-
-  auto traverse_hlo_and_collect = [&](HloInstruction* start) {
-    DefUseDataflowPath maybe_sliced_user_path;
-    bool dus_found = false;
-    auto maybe_dus_instr = HloBfsFindIf(
-        {start},
-        [&](const HloInstruction* cur) {
-          // If the node is a match that has been processed, stop the
-          // traversal.
-          if (processed_instrs.contains(cur)) return true;
-          maybe_sliced_user_path.push_back(const_cast<HloInstruction*>(cur));
-          if (const auto slice_instr =
-                  DynCast<HloDynamicUpdateSliceInstruction>(cur)) {
-            if (IsAlignedSlice(slice_instr)) {
-              dus_found = true;
-              return true;
-            }
-          }
-          return cur->user_count() > 1 || !IsNoOp(cur);
-        },
-        /*visit_operands=*/false);
-    if (maybe_dus_instr == std::nullopt) return;
-    if (dus_found || processed_instrs.contains(maybe_dus_instr.value())) {
-      // Even in the case of stopping at a match that has been processed, we
-      // still need to add instructions encountered in the sliced user path
-      // during the latest traversal.
-      processed_instrs.insert(maybe_sliced_user_path.begin(),
-                              maybe_sliced_user_path.end());
-      sliced_user_paths.push_back(std::move(maybe_sliced_user_path));
-    }
-  };
-
-  if (instr->shape().IsTuple()) {
-    for (auto* user : instr->users()) {
-      if (DynCast<HloGetTupleElementInstruction>(user)) {
-        traverse_hlo_and_collect(user);
-      }
-    }
-  } else {
-    if (instr->user_count() == 1) {
-      traverse_hlo_and_collect(instr->users().front());
-    }
-  }
-
-  return sliced_user_paths;
-}
-
-absl::InlinedVector<HloInstruction*, 4> GetPatternCaptures(
-    DataflowPathView matches) {
-  absl::InlinedVector<HloInstruction*, 4> captures;
-
-  InstructionSet matched_instrs(matches.begin(), matches.end());
-
-  for (HloInstruction* instr : matches) {
-    for (HloInstruction* operand : instr->operands()) {
-      if (!matched_instrs.contains(operand) &&
-          absl::c_find(captures, operand) == captures.end()) {
-        captures.emplace_back(operand);
-      }
-    }
-  }
-
-  return captures;
-}
-
-absl::Status CreateRootTuple(
-    HloInstruction* hero, HloComputation::Builder& builder,
-    DataflowPathsView sliced_user_paths,
-    absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
-        instr_mapping) {
-  unsigned tuple_size = hero->shape().tuple_shapes_size();
-
-  std::vector<HloInstruction*> sliced_elems(tuple_size, nullptr);
-  for (auto& sliced_user_path : sliced_user_paths) {
-    auto gte = Cast<HloGetTupleElementInstruction>(sliced_user_path.front());
-    sliced_elems[gte->tuple_index()] = sliced_user_path.back();
-  }
-
-  std::vector<HloInstruction*> elements;
-  for (size_t i = 0; i < tuple_size; ++i) {
-    if (sliced_elems[i] != nullptr) {
-      elements.push_back(instr_mapping[sliced_elems[i]]);
-      continue;
-    }
-    auto* gte = builder.AddInstruction(
-        HloInstruction::CreateGetTupleElement(instr_mapping[hero], i));
-    if (hero->shape().tuple_shapes(i).IsTuple()) {
-      instr_mapping[gte] = gte;
-      TF_RETURN_IF_ERROR(CreateRootTuple(gte, builder, {}, instr_mapping));
-      elements.push_back(builder.last_added_instruction());
-    } else {
-      elements.push_back(gte);
-    }
-  }
-  if (elements.size() > 1)
-    builder.AddInstruction(HloInstruction::CreateTuple(elements));
-
-  return absl::OkStatus();
-}
-
-absl::StatusOr<HloComputation*> CreateFusionBody(
-    HloModule* module, DataflowPathView sliced_operand_paths,
-    DataflowPathsView sliced_user_paths, DataflowPathView captures) {
-  HloComputation::Builder builder("address-computation");
-
-  // A mapping from original instructions to instructions in the fusion body.
-  absl::flat_hash_map<const HloInstruction*, HloInstruction*> instr_mapping;
-
-  auto mapped_operands = [&](HloInstruction* instr) {
-    absl::InlinedVector<HloInstruction*, 4> operands;
-    for (HloInstruction* operand : instr->operands()) {
-      operands.push_back(instr_mapping.at(operand));
-    }
-    return operands;
-  };
-
-  // For every captured value create a parameter instruction in the computation
-  // body and set up instruction mapping.
-  for (const HloInstruction* capture : captures) {
-    int64_t index = instr_mapping.size();
-    instr_mapping[capture] =
-        builder.AddInstruction(HloInstruction::CreateParameter(
-            index, capture->shape(), absl::StrCat("p", index)));
-  }
-
-  // Instructions in the pattern are already topologically sorted, as we
-  // visited them along the use-def path and then reversed the list.
-  HloInstruction* hero;
-  for (HloInstruction* instr : sliced_operand_paths) {
-    instr_mapping[instr] = builder.AddInstruction(
-        instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
-    hero = instr;
-  }
-
-  for (auto& sliced_user_path : sliced_user_paths) {
-    for (HloInstruction* instr : sliced_user_path) {
-      instr_mapping[instr] = builder.AddInstruction(
-          instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
-    }
-  }
-
-  // Create a tuple if the hero is a tuple to make sure there's a buffer
-  // assigned for each of its elements. First make sure the tuple is non-empty.
-  if (hero->shape().IsTuple() && hero->shape().tuple_shapes_size() > 0) {
-    TF_RETURN_IF_ERROR(
-        CreateRootTuple(hero, builder, sliced_user_paths, instr_mapping));
-  }
-
-  return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false);
-}
-
-absl::StatusOr<HloInstruction*> CreateFusionInstruction(
-    HloModule* module, HloInstruction* orig, DataflowPathView captures,
-    HloComputation* body, bool dynamic) {
-  HloComputation* parent = orig->parent();
-
-  // Add a fusion operation calling outlined fusion computation.
-  HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion(
-      body->root_instruction()->shape(), HloInstruction::FusionKind::kCustom,
-      captures, body));
-  module->SetAndUniquifyInstrName(fusion, "address_computation");
-
-  // We don't need to set/update output_to_operand_aliasing for the new fusion
-  // instruction because all buffers are already assigned at this point.
-
-  // Set backends config to a matched custom fusion config.
-  GpuBackendConfig gpu_config;
-  FusionBackendConfig& backend_config =
-      *gpu_config.mutable_fusion_backend_config();
-  backend_config.set_kind("__custom_fusion");
-  CustomFusionConfig config;
-  config.set_name(dynamic ? "dynamic_address_computation"
-                          : "address_computation");
-  *backend_config.mutable_custom_fusion_config() = config;
-  TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config)));
-
-  return fusion;
-}
-
-}  // namespace
-
-absl::StatusOr<bool> DynamicSliceFusionRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  absl::flat_hash_map<HloInstruction*,
-                      std::pair<UseDefDataflowPaths, DefUseDataflowPaths>>
-      matches;
-
-  // Collect all potential custom call matches in the non-fusion computations.
-  for (HloComputation* computation : module->computations()) {
-    if (computation->IsFusionComputation()) continue;
-    for (HloInstruction* instr : computation->instructions()) {
-      if (IsLegacyCublasMatmul(*instr) ||
-          (IsCustomCall(instr, platform_name_))) {
-        UseDefDataflowPaths sliced_operand_paths = GetSlicedOperandPaths(instr);
-        bool has_sliced_operand_paths = sliced_operand_paths.size() > 1;
-
-        DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr);
-        bool has_sliced_user_paths = absl::c_any_of(
-            sliced_user_paths,
-            [&](auto& sliced_user_path) { return !sliced_user_path.empty(); });
-
-        if (absl::c_any_of(sliced_user_paths, [&](auto& sliced_user_path) {
-              return DynCast<HloDynamicUpdateSliceInstruction>(
-                         sliced_user_path.back()) == nullptr;
-            })) {
-          return absl::InternalError(
-              "Expect sliced user path to end with a DUS.");
-        }
-
-        if (has_sliced_operand_paths || has_sliced_user_paths) {
-          matches[instr] = std::make_pair(std::move(sliced_operand_paths),
-                                          std::move(sliced_user_paths));
-        }
-      }
-    }
-  }
-
-  if (matches.empty()) return false;
-
-  for (auto& [hero, paths] : matches) {
-    auto& [sliced_operand_paths, sliced_user_paths] = paths;
-    std::vector<HloInstruction*> matched_instrs;
-    absl::c_copy(sliced_operand_paths, std::back_inserter(matched_instrs));
-
-    std::vector<DataflowPathView> sliced_user_paths_view;
-    for (auto& sliced_user_path : sliced_user_paths) {
-      absl::c_copy(sliced_user_path, std::back_inserter(matched_instrs));
-      DataflowPathView sliced_user_path_view{&sliced_user_path.front(),
-                                             sliced_user_path.size()};
-      sliced_user_paths_view.push_back(std::move(sliced_user_path_view));
-    }
-
-    auto captures = GetPatternCaptures(matched_instrs);
-
-    TF_ASSIGN_OR_RETURN(
-        HloComputation * fusion_body,
-        CreateFusionBody(module, sliced_operand_paths,
-                         DataflowPathsView(sliced_user_paths_view), captures));
-
-    bool has_dynamic_slices = absl::c_any_of(matched_instrs, [&](auto* instr) {
-      return DynCast<HloDynamicIndexInstruction>(instr) != nullptr;
-    });
-    TF_ASSIGN_OR_RETURN(
-        HloInstruction * fusion,
-        CreateFusionInstruction(module, hero, captures, fusion_body,
-                                has_dynamic_slices));
-
-    HloComputation* parent = hero->parent();
-    if (fusion->shape().IsTuple()) {
-      TF_RETURN_IF_ERROR(parent->ReplaceInstructionWithDifferentShape(
-          const_cast<HloInstruction*>(hero), fusion));
-      for (auto& sliced_user_path : sliced_user_paths) {
-        auto old_gte =
-            Cast<HloGetTupleElementInstruction>(sliced_user_path.front());
-        HloInstruction* gte =
-            parent->AddInstruction(HloInstruction::CreateGetTupleElement(
-                fusion, old_gte->tuple_index()));
-        TF_RETURN_IF_ERROR(
-            parent->ReplaceInstruction(sliced_user_path.back(), gte));
-      }
-    } else {
-      auto* instr_to_be_replaced = const_cast<HloInstruction*>(hero);
-      if (sliced_user_paths.empty()) {
-        // The only case where a tuple-shaped original hero op is fused into a
-        // non-tuple-shaped fusion is when only one element of the original
-        // tuple is used. In that case, we need to replace that single
-        // get-tuple-element (instead of the hero op) with the fusion
-        // instruction.
-        if (hero->shape().IsTuple()) {
-          if (hero->user_count() != 1 ||
-              !DynCast<HloGetTupleElementInstruction>(hero->users().front())) {
-            return absl::InternalError(
-                "Expect a single get-tuple-element user of the original "
-                "tuple-shaped hero op when address computation fusion does "
-                "not return a tuple");
-          }
-          instr_to_be_replaced = hero->users().front();
-        }
-      } else {
-        instr_to_be_replaced = sliced_user_paths.front().back();
-      }
-      TF_RETURN_IF_ERROR(
-          parent->ReplaceInstruction(instr_to_be_replaced, fusion));
-    }
-  }
-
-  return true;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.h b/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.h
deleted file mode 100644
index 15da284..0000000
--- a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.h
+++ /dev/null
@@ -1,91 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_DYNAMIC_SLICE_FUSION_REWRITER_H_
-#define XLA_SERVICE_GPU_DYNAMIC_SLICE_FUSION_REWRITER_H_
-
-#include <string>
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Pattern matches (slice(s) + custom call) to custom address computation
-// fusions and rewrites them into fusion instructions and fusion computations.
-//
-// Example:
-//
-//  ENTRY %main {
-//    %p0 = bf16[2,8,8]{2,1,0} parameter(0)
-//    %p1 = bf16[2,8,8]{2,1,0} parameter(1)
-//    %slice_lhs = bf16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
-//    %bitcast_lhs = bf16[8,8]{1,0} bitcast(%slice_lhs)
-//    %slice_rhs = bf16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-//    %bitcast_rhs = bf16[8,8]{1,0} bitcast(%slice_rhs)
-//    ROOT %dot = bf16[8,8]{1,0} custom-call(%bitcast_lhs, %bitcast_rhs),
-//      custom_call_target="__cublas$gemm"
-//  }
-//
-// After the pass:
-//
-//  %address_computation {
-//    %p0 = bf16[2,8,8]{2,1,0} parameter(0)
-//    %p1 = bf16[2,8,8]{2,1,0} parameter(1)
-//    %slice_lhs = bf16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
-//    %bitcast_lhs = bf16[8,8]{1,0} bitcast(%slice_lhs)
-//    %slice_rhs = bf16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-//    %bitcast_rhs = bf16[8,8]{1,0} bitcast(%slice_rhs)
-//    ROOT %dot = bf16[8,8]{1,0} custom-call(%bitcast_lhs, %bitcast_rhs),
-//      custom_call_target="__cublas$gemm"
-//  }
-//
-//  ENTRY %main {
-//    %p0 = bf16[2,8,8]{2,1,0} parameter(0)
-//    %p1 = bf16[2,8,8]{2,1,0} parameter(1)
-//    ROOT %fusion.2 = bf16[8,8]{1,0} fusion(%p0, %p1),
-//        kind=kCustom, calls=%address_computation,
-//        backend_config={"fusion_backend_config":{
-//            "kind":"__custom_fusion",
-//            "custom_fusion_config":{"name":"address_computation"}
-//        }}
-//  }
-//
-class DynamicSliceFusionRewriter : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "address-computation-fusion-rewriter";
-  }
-
-  explicit DynamicSliceFusionRewriter(std::string platform_name)
-      : platform_name_(std::move(platform_name)) {}
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  std::string platform_name_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_DYNAMIC_SLICE_FUSION_REWRITER_H_
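Editorial note: a minimal usage sketch of the pass declared above, assuming an already parsed HloModule. The wrapper name RewriteForGpu is invented for illustration; the platform name "gpu" is the one the tests below pass to the constructor. Not part of this change:

// Editorial sketch: invoking the rewriter on a module.
#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h"

absl::StatusOr<bool> RewriteForGpu(xla::HloModule* module) {
  xla::gpu::DynamicSliceFusionRewriter rewriter(/*platform_name=*/"gpu");
  // Returns true when at least one (slice(s) + custom call) pattern was
  // outlined into an address-computation fusion.
  return rewriter.Run(module, /*execution_threads=*/{});
}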
diff --git a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc
deleted file mode 100644
index a539fb5..0000000
--- a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc
+++ /dev/null
@@ -1,1755 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h"
-
-#include <cstddef>
-#include <optional>
-
-#include "absl/status/status.h"
-#include "xla/client/lib/constants.h"
-#include "xla/client/xla_builder.h"
-#include "xla/ffi/ffi.h"
-#include "xla/ffi/ffi_api.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/custom_call_target_registry.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/gpu/gpu_types.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla::gpu {
-
-class DynamicSliceFusionRewriterTest : public HloTestBase {};
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemm) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[2,8,8]{2,1,0} parameter(0)
-      %p1 = f16[2,8,8]{2,1,0} parameter(1)
-      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
-      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
-      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
-      ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmWithWorkspace) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[2,8,8]{2,1,0} parameter(0)
-      %p1 = f16[2,8,8]{2,1,0} parameter(1)
-      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
-      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
-      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
-      ROOT %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:       [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
-    ; CHECK:       [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
-    ; CHECK:       ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
-    ; CHECK:              tuple([[DOT]], [[WORKSPACE]])
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmWorkspaceIgnored) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[2,8,8]{2,1,0} parameter(0)
-      %p1 = f16[2,8,8]{2,1,0} parameter(1)
-      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
-      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
-      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
-      %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-      ROOT %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:       [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
-    ; CHECK:       [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
-    ; CHECK:       ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
-    ; CHECK:              tuple([[DOT]], [[WORKSPACE]])
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:       ROOT [[DOT_MAIN:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[FUSION]]), index=0
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNotRoot) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[2,8,8]{2,1,0} parameter(0)
-      %p1 = f16[2,8,8]{2,1,0} parameter(1)
-      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
-      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
-      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
-      %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-      ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %custom-call.1)
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:       ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]])
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandHasMultipleUsers) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[2,8,8]{2,1,0} parameter(0)
-      %p1 = f16[4,8,8]{2,1,0} parameter(1)
-      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
-      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
-      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[2:3], [0:8], [0:8]}
-      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
-      %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-      ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %bitcast.41)
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[2:3], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[P0]], [[P1]])
-    ; CHECK-DAG:     kind=kCustom, calls=%address-computation,
-    ; CHECK-DAG:     backend_config={
-    ; CHECK-DAG:       "kind":"__custom_fusion",
-    ; CHECK-DAG:       "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK-DAG:     }
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK:       ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[B0]])
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandsHaveMultipleUsers) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[2,8,8]{2,1,0} parameter(0)
-      %p1 = f16[2,8,8]{2,1,0} parameter(1)
-      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
-      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
-      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
-      %custom-call.0 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-
-      ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation{{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-    ; CHECK:     %address-computation{{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmSlicingNotParameter) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[4,8,8]{2,1,0} parameter(0)
-      %p1 = f16[2,8,8]{2,1,0} parameter(1)
-      %slice.12 = f16[2,8,8]{2,1,0} slice(%p0), slice={[0:2], [0:8], [0:8]}
-      %slice.13 = f16[1,8,8]{2,1,0} slice(%slice.12), slice={[1:2], [0:8], [0:8]}
-      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
-      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
-      %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-      ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %custom-call.1)
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[2,8,8]{2,1,0} slice([[P0]]), slice={[0:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
-    ; CHECK:       [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[S0]], [[P1]])
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:       ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]])
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNotContiguousSlice) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[2,8,8]{2,1,0} parameter(0)
-      %p1 = f16[2,8,8]{2,1,0} parameter(1)
-      %slice.13 = f16[1,4,6]{2,1,0} slice(%p0), slice={[1:2], [0:4], [0:6]}
-      %bitcast.41 = f16[4,6]{1,0} bitcast(%slice.13)
-      %slice.14 = f16[1,6,4]{2,1,0} slice(%p1), slice={[1:2], [0:6], [0:4]}
-      %bitcast.42 = f16[6,4]{1,0} bitcast(%slice.14)
-
-      ROOT %custom-call.1 = f16[4,4]{1,0} custom-call(%bitcast.41, %bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-    }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"),
-                            std::nullopt);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNonNoOpInSliceChain) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[2,8,8]{2,1,0} parameter(0)
-      %p1 = f16[2,8,8]{2,1,0} parameter(1)
-      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
-      %slice.14 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
-      %add.0 = f16[1,8,8]{2,1,0} add(%slice.13, %slice.14)
-      %bitcast.41 = f16[8,8]{1,0} bitcast(%add.0)
-      %slice.15 = f16[1,8,8]{2,1,0} slice(%p1), slice={[0:1], [0:8], [0:8]}
-      %slice.16 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-      %add.1 = f16[1,8,8]{2,1,0} add(%slice.15, %slice.16)
-      %bitcast.42 = f16[8,8]{1,0} bitcast(%add.1)
-
-      ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-    }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"),
-                            std::nullopt);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmDuplicateOperand) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main {
-      %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0)
-      %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0
-      %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1
-      %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0}
-      %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240),
-        custom_call_target="__cublas$gemm",
-        backend_config={
-          "gemm_backend_config":{
-            "alpha_real":1,
-            "beta":0,
-            "dot_dimension_numbers":{
-              "lhs_contracting_dimensions":["1"],
-              "rhs_contracting_dimensions":["0"],
-              "lhs_batch_dimensions":[],
-              "rhs_batch_dimensions":[]
-            },
-            "alpha_imag":0,
-            "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]},
-            "epilogue":"DEFAULT",
-            "lhs_stride":"20000",
-            "rhs_stride":"10000",
-            "grad_x":false,
-            "grad_y":false
-          }
-        }
-      %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0
-      %slice.26 = f32[100,100]{1,0} slice(%get-tuple-element.97), slice={[0:100], [0:100]}
-      ROOT %custom-call.17 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.26, %slice.26),
-        custom_call_target="__cublas$gemm",
-        backend_config={
-          "gemm_backend_config":{
-            "alpha_real":1,
-            "beta":0,
-            "dot_dimension_numbers":{
-              "lhs_contracting_dimensions":["1"],
-              "rhs_contracting_dimensions":["0"],
-              "lhs_batch_dimensions":[],
-              "rhs_batch_dimensions":[]
-            },
-            "alpha_imag":0,
-            "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]},
-            "epilogue":"DEFAULT",
-            "lhs_stride":"10000",
-            "rhs_stride":"10000",
-            "grad_x":false,
-            "grad_y":false
-          }
-        }
-  })";
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK:       [[P0:%[^ ]+]] = f32[200,100]{1,0} parameter(0)
-    ; CHECK:       [[S0:%[^ ]+]] = f32[100,100]{1,0} slice([[P0]]), slice={[0:100], [0:100]}
-    ; CHECK-NOT:   slice
-    ; CHECK:       [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[80000]{0}) custom-call([[S0]], [[S0]]),
-    ; CHECK:         custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[80000]{0}) fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmReverseOperandOrder) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[2,8,8]{2,1,0} parameter(1)
-      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
-      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
-      %p1 = f16[2,8,8]{2,1,0} parameter(0)
-      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
-      ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[0:1], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK-DAG:   [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[A1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]], [[A1]])
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmReverseOperandOrder2) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[2,8,8]{2,1,0} parameter(0)
-      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
-      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
-      %p1 = f16[2,8,8]{2,1,0} parameter(1)
-      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
-      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
-
-      ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[0:1], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK-DAG:   [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
-    ; CHECK-DAG:   [[A1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]], [[A1]])
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandAliasingOutput) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0)
-      %get-tuple-element.287 = f32[100,100]{1,0} get-tuple-element(%p0), index=0
-      %get-tuple-element.288 = f32[100,100]{1,0} get-tuple-element(%p0), index=1
-      %concatenate.12 = f32[200,100]{1,0} concatenate(%get-tuple-element.287, %get-tuple-element.288), dimensions={0}
-      %slice.30 = f32[100,100]{1,0} slice(%concatenate.12), slice={[16:116], [0:100]}
-      %slice.34 = f32[100,100]{1,0} slice(%concatenate.12), slice={[99:199], [0:100]}
-      ROOT %cublas-gemm.15 = (f32[100,100]{1,0}, s8[120000]{0}) custom-call(%get-tuple-element.287, %slice.30, %slice.34),
-        custom_call_target="__cublas$gemm",
-        output_to_operand_aliasing={{0}: (2, {})},
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":1,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"10000",
-          "rhs_stride":"10000",
-          "grad_x":false,
-          "grad_y":false
-        }}
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK-DAG:   [[P2:%[^ ]+]] = f32[100,100]{1,0} parameter(2)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f32[100,100]{1,0} parameter(1)
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f32[200,100]{1,0} parameter(0)
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f32[100,100]{1,0} slice([[P0]]), slice={[16:116], [0:100]}
-    ; CHECK:       [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) custom-call([[P1]], [[S1]], [[P2]]),
-    ; CHECK:         custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       [[P:%[^ ]+]] = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0)
-    ; CHECK:       [[GTE0:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=0
-    ; CHECK:       [[GTE1:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=1
-    ; CHECK:       [[CONCAT:%[^ ]+]] = f32[200,100]{1,0} concatenate([[GTE0]], [[GTE1]]), dimensions={0}
-    ; CHECK:       [[S:%[^ ]+]] = f32[100,100]{1,0} slice([[CONCAT]]), slice={[99:199], [0:100]}
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) fusion([[CONCAT]], [[GTE0]], [[S]])
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandsFromSameSlice) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[2,8,8]{2,1,0} parameter(0)
-      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
-      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
-      %bitcast.42 = f16[8,8]{0,1} bitcast(%slice.13)
-
-      ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[0:1], [0:8], [0:8]}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{0,1} bitcast([[S0]])
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK-DAG:   [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]])
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-static absl::Status Memcpy(se::Stream* stream, ffi::AnyBuffer src,
-                           ffi::AnyBuffer dst) {
-  se::DeviceMemoryBase dst_mem = dst.device_memory();
-  se::DeviceMemoryBase src_mem = src.device_memory();
-  return stream->MemcpyD2D(&dst_mem, src_mem, src_mem.size());
-}
-
-XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy,
-                       ffi::Ffi::Bind()
-                           .Ctx<ffi::Stream>()
-                           .Arg<ffi::AnyBuffer>()  // src
-                           .Arg<ffi::AnyBuffer>()  // dst
-);
-XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", "gpu",
-                         kMemcpy);
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCall) {
-  XlaBuilder b(TestName());
-  CustomCall(&b, "__xla_test$$memcpy",
-             /*operands=*/
-             {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0},
-                    {128}, {1})},
-             ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"",
-             /*has_side_effect=*/false,
-             /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
-             /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
-             /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
-  TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
-  xla::HloModuleConfig hlo_config(
-      xla::ProgramShape(computation.proto().host_program_shape()),
-      /*ignore_layouts=*/false);
-  DebugOptions debug_options = GetDebugOptionsForTest();
-  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
-  hlo_config.set_debug_options(debug_options);
-  TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
-                                        computation.proto(), hlo_config));
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK:       [[P0:%[^ ]+]] = f32[256]{0} parameter(0)
-    ; CHECK:       [[S0:%[^ ]+]] = f32[128]{0} slice([[P0]]), slice={[0:128]}
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[S0]]),
-    ; CHECK:              custom_call_target="__xla_test$$memcpy",
-    ; CHECK:              api_version=API_VERSION_TYPED_FFI
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %{{.*}} {
-    ; CHECK:       [[C0:%[^ ]+]] = f32[] constant(42)
-    ; CHECK:       [[BC:%[^ ]+]] = f32[256]{0} broadcast([[C0]])
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion([[BC]])
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
-                            expected);
-}
-
-void Callback_Void(se::gpu::GpuStreamHandle stream, void** buffers,
-                   const char* /*opaque*/, size_t /*opaque_len*/) {}
-
-XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Void, "gpu");
-
-TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCallLegacy) {
-  XlaBuilder b(TestName());
-  CustomCall(&b, "Callback_Void",
-             /*operands=*/
-             {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0},
-                    {128}, {1})},
-             ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"");
-  TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
-  xla::HloModuleConfig hlo_config(
-      xla::ProgramShape(computation.proto().host_program_shape()),
-      /*ignore_layouts=*/false);
-  DebugOptions debug_options = GetDebugOptionsForTest();
-  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
-  hlo_config.set_debug_options(debug_options);
-  TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
-                                        computation.proto(), hlo_config));
-  // TF_ASSERT_OK_AND_ASSIGN(
-  //     HloSchedule schedule,
-  //     ScheduleModule(hlo.get(), [](const BufferValue& buffer) {
-  //       return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
-  //     }));
-  // TF_CHECK_OK(hlo->set_schedule(std::move(schedule)));
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK:       [[P0:%[^ ]+]] = f32[256]{0} parameter(0)
-    ; CHECK:       [[S0:%[^ ]+]] = f32[128]{0} slice([[P0]]), slice={[0:128]}
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[S0]]),
-    ; CHECK:              custom_call_target="Callback_Void"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %{{.*}} {
-    ; CHECK:       [[C0:%[^ ]+]] = f32[] constant(42)
-    ; CHECK:       [[BC:%[^ ]+]] = f32[256]{0} broadcast([[C0]])
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion([[BC]])
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
-                            expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, TupleSliceCustomCallLegacy) {
-  XlaBuilder b(TestName());
-  CustomCall(
-      &b, "Callback_Void",
-      /*operands=*/
-      {
-          Tuple(&b,
-                {
-                    Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}),
-                          {0, 0}, {4, 8}, {1, 1}),
-                    Broadcast(ConstantR0WithType(&b, F32, 2), {256}),
-                }),
-          Tuple(&b,
-                {
-                    Broadcast(ConstantR0WithType(&b, F32, 3), {1024}),
-                    Broadcast(ConstantR0WithType(&b, F32, 4), {8}),
-                }),
-      },
-      ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"");
-  TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
-  xla::HloModuleConfig hlo_config(
-      xla::ProgramShape(computation.proto().host_program_shape()),
-      /*ignore_layouts=*/false);
-  DebugOptions debug_options = GetDebugOptionsForTest();
-  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
-  hlo_config.set_debug_options(debug_options);
-  TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
-                                        computation.proto(), hlo_config));
-  // TF_ASSERT_OK_AND_ASSIGN(
-  //     HloSchedule schedule,
-  //     ScheduleModule(hlo.get(), [](const BufferValue& buffer) {
-  //       return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
-  //     }));
-  // TF_CHECK_OK(hlo->set_schedule(std::move(schedule)));
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P0]]), slice={[0:4], [0:8]}
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f32[256]{0} parameter(1)
-    ; CHECK-DAG:   [[T0:%[^ ]+]] = (f32[4,8]{1,0}, f32[256]{0}) tuple([[S0]], [[P1]])
-    ; CHECK-DAG:   [[P2:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(2)
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[T0]], [[P2]]),
-    ; CHECK:              custom_call_target="Callback_Void"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %{{.*}} {
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion(
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
-                            expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, TupledOutputCustomCallLegacy) {
-  XlaBuilder b(TestName());
-  auto custom_call = CustomCall(
-      &b, "Callback_Void",
-      /*operands=*/
-      {
-          Tuple(&b,
-                {
-                    Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}),
-                          {0, 0}, {4, 8}, {1, 1}),
-                    Broadcast(ConstantR0WithType(&b, F32, 2), {256}),
-                }),
-          Tuple(&b,
-                {
-                    Broadcast(ConstantR0WithType(&b, F32, 3), {1024}),
-                    Broadcast(ConstantR0WithType(&b, F32, 4), {8}),
-                }),
-      },
-      ShapeUtil::MakeTupleShape({
-          ShapeUtil::MakeShape(F32, {8}),
-          ShapeUtil::MakeTupleShape({
-              ShapeUtil::MakeShape(F32, {128}),
-              ShapeUtil::MakeShape(F32, {256}),
-          }),
-          ShapeUtil::MakeShape(F32, {1024}),
-          ShapeUtil::MakeShape(F32, {4, 8}),
-      }),
-      /*opaque=*/"");
-  Tuple(&b, {GetTupleElement(GetTupleElement(custom_call, 1), 0),
-             GetTupleElement(custom_call, 2)});
-  TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
-  xla::HloModuleConfig hlo_config(
-      xla::ProgramShape(computation.proto().host_program_shape()),
-      /*ignore_layouts=*/false);
-  DebugOptions debug_options = GetDebugOptionsForTest();
-  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
-  hlo_config.set_debug_options(debug_options);
-  TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
-                                        computation.proto(), hlo_config));
-  // TF_ASSERT_OK_AND_ASSIGN(
-  //     HloSchedule schedule,
-  //     ScheduleModule(hlo.get(), [](const BufferValue& buffer) {
-  //       return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
-  //     }));
-  // TF_CHECK_OK(hlo->set_schedule(std::move(schedule)));
-
-  const char* expected = R"(
-    ; CHECK:     %address-computation {{.*}} {
-    ; CHECK-DAG:   [[P2:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(2)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f32[256]{0} parameter(1)
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P0]]), slice={[0:4], [0:8]}
-    ; CHECK-DAG:   [[T0:%[^ ]+]] = (f32[4,8]{1,0}, f32[256]{0}) tuple([[S0]], [[P1]])
-    ; CHECK:       [[CC:%[^ ]+]] = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) custom-call([[T0]], [[P2]]),
-    ; CHECK:              custom_call_target="Callback_Void"
-    ; CHECK-DAG:   [[GTE0:%[^ ]+]] = f32[8]{0} get-tuple-element([[CC]]), index=0
-    ; CHECK-DAG:   [[GTE1:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) get-tuple-element([[CC]]), index=1
-    ; CHECK-DAG:   [[GTE2:%[^ ]+]] = f32[128]{0} get-tuple-element([[GTE1]]), index=0
-    ; CHECK-DAG:   [[GTE3:%[^ ]+]] = f32[256]{0} get-tuple-element([[GTE1]]), index=1
-    ; CHECK-DAG:   [[T1:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) tuple([[GTE2]], [[GTE3]])
-    ; CHECK-DAG:   [[GTE4:%[^ ]+]] = f32[1024]{0} get-tuple-element([[CC]]), index=2
-    ; CHECK-DAG:   [[GTE5:%[^ ]+]] = f32[4,8]{1,0} get-tuple-element([[CC]]), index=3
-    ; CHECK:       ROOT {{.*}} = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) tuple([[GTE0]], [[T1]], [[GTE4]], [[GTE5]])
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %{{.*}} {
-    ; CHECK:       [[FUSION:%[^ ]+]] = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK-DAG:   [[GTE6:%[^ ]+]] = f32[1024]{0} get-tuple-element([[FUSION]]), index=2
-    ; CHECK-DAG:   [[GTE7:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) get-tuple-element([[FUSION]]), index=1
-    ; CHECK-DAG:   [[GTE8:%[^ ]+]] = f32[128]{0} get-tuple-element([[GTE7]]), index=0
-    ; CHECK:       ROOT {{.*}} = (f32[128]{0}, f32[1024]{0}) tuple([[GTE8]], [[GTE6]])
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
-                            expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, UnalignedSlice) {
-  XlaBuilder b(TestName());
-  CustomCall(
-      &b, "Callback_Void",
-      /*operands=*/
-      {Slice(Broadcast(ConstantR0WithType(&b, S32, 42), {17}), {1}, {17}, {1})},
-      ShapeUtil::MakeShape(S32, {16}), /*opaque=*/"");
-  TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
-  xla::HloModuleConfig hlo_config(
-      xla::ProgramShape(computation.proto().host_program_shape()),
-      /*ignore_layouts=*/false);
-  DebugOptions debug_options = GetDebugOptionsForTest();
-  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
-  hlo_config.set_debug_options(debug_options);
-  TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
-                                        computation.proto(), hlo_config));
-  // TF_ASSERT_OK_AND_ASSIGN(
-  //     HloSchedule schedule,
-  //     ScheduleModule(hlo.get(), [](const BufferValue& buffer) {
-  //       return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
-  //     }));
-  // TF_CHECK_OK(hlo->set_schedule(std::move(schedule)));
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
-                            std::nullopt);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemm) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY main.9 {
-      p0 = f16[2,8,8]{2,1,0} parameter(0)
-      p1 = f16[2,8,8]{2,1,0} parameter(1)
-      c1_s32 = s32[] constant(1)
-      c0_s32 = s32[] constant(0)
-      slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
-      bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
-      slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
-      bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
-
-      ROOT custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
-    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(1)
-    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(2)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmWithWorkspace) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY main.9 {
-      p0 = f16[2,8,8]{2,1,0} parameter(0)
-      p1 = f16[2,8,8]{2,1,0} parameter(1)
-      c1_s32 = s32[] constant(1)
-      c0_s32 = s32[] constant(0)
-      slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
-      bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
-      slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
-      bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
-
-      ROOT custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
-    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(1)
-    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(2)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:       [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
-    ; CHECK:       [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
-    ; CHECK:       ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
-    ; CHECK:              tuple([[DOT]], [[WORKSPACE]])
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmWorkspaceIgnored) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY main.9 {
-      p0 = f16[2,8,8]{2,1,0} parameter(0)
-      p1 = f16[2,8,8]{2,1,0} parameter(1)
-      c1_s32 = s32[] constant(1)
-      c0_s32 = s32[] constant(0)
-      slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
-      bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
-      slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
-      bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
-
-      custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-      ROOT get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(custom-call.1), index=0
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
-    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(1)
-    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(2)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:       [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
-    ; CHECK:       [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
-    ; CHECK:       ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
-    ; CHECK:              tuple([[DOT]], [[WORKSPACE]])
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:       ROOT [[DOT_MAIN:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[FUSION]]), index=0
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmNotRoot) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY main.9 {
-      p0 = f16[2,8,8]{2,1,0} parameter(0)
-      p1 = f16[2,8,8]{2,1,0} parameter(1)
-      c1_s32 = s32[] constant(1)
-      c0_s32 = s32[] constant(0)
-      slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
-      bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
-      slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
-      bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
-
-      custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-      ROOT res = f16[8,8]{1,0} add(custom-call.1, custom-call.1)
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
-    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(1)
-    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(2)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:       ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]])
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemm) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY main.9 {
-      p0 = f16[1,8,8]{2,1,0} parameter(0)
-      p1 = f16[1,8,8]{2,1,0} parameter(1)
-      p2 = f16[4,8,8]{2,1,0} parameter(2)
-      c1_s32 = s32[] constant(1)
-      c0_s32 = s32[] constant(0)
-      bitcast.41 = f16[8,8]{1,0} bitcast(p0)
-      bitcast.42 = f16[8,8]{1,0} bitcast(p1)
-
-      custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-      bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
-      ROOT dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[8,8]{1,0} parameter(1)
-    ; CHECK-DAG:   [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(2)
-    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(3)
-    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(4)
-    ; CHECK-DAG:   [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]]),
-    ; CHECK-DAG:          custom_call_target="__cublas$gemm"
-    ; CHECK-DAG:   [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[CC]])
-    ; CHECK:       ROOT {{.*}} = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f16[4,8,8]{2,1,0} fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmNotRoot) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY main.9 {
-      p0 = f16[2,8,8]{2,1,0} parameter(0)
-      p1 = f16[2,8,8]{2,1,0} parameter(1)
-      p2 = f16[4,8,8]{2,1,0} parameter(2)
-      c1_s32 = s32[] constant(1)
-      c0_s32 = s32[] constant(0)
-      slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
-      bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
-      slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
-      bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
-
-      custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-      bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
-      dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
-      ROOT res = f16[4,8,8]{2,1,0} log(dus)
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
-    ; CHECK-DAG:   [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(4)
-    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(1)
-    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(2)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK-DAG:   [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
-    ; CHECK-DAG:          custom_call_target="__cublas$gemm"
-    ; CHECK-DAG:   [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[CC]])
-    ; CHECK:       ROOT {{.*}} = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       [[FUSION:%[^ ]+]] = f16[4,8,8]{2,1,0} fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:       ROOT {{.*}} = f16[4,8,8]{2,1,0} log([[FUSION]])
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmWithWorkspace) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY main.9 {
-      p0 = f16[2,8,8]{2,1,0} parameter(0)
-      p1 = f16[2,8,8]{2,1,0} parameter(1)
-      p2 = f16[4,8,8]{2,1,0} parameter(2)
-      c1_s32 = s32[] constant(1)
-      c0_s32 = s32[] constant(0)
-      slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
-      bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
-      slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
-      bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
-
-      custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-
-    get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(custom-call.1), index=0
-    bitcast.43 = f16[1,8,8]{2,1,0} bitcast(get-tuple-element.0)
-    dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
-    get-tuple-element.1 = s8[256]{0} get-tuple-element(custom-call.1), index=1
-    ROOT tuple = (f16[4,8,8]{2,1,0}, s8[256]{0}) tuple(dus, get-tuple-element.1)
-    }
-  )";
-
-  const char* expected = R"(
-    ; CHECK:     address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
-    ; CHECK-DAG:   [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(4)
-    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(1)
-    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(2)
-    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
-    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
-    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
-    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
-    ; CHECK:       [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
-    ; CHECK:              custom_call_target="__cublas$gemm"
-    ; CHECK:       [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
-    ; CHECK:       [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[DOT]])
-    ; CHECK:       [[DUS:%[^ ]+]] = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
-    ; CHECK:       [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
-    ; CHECK:       ROOT [[TUPLE:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0})
-    ; CHECK:              tuple([[DUS]], [[WORKSPACE]])
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       [[FUSION:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:       [[DUS_MAIN:%[^ ]+]] = f16[4,8,8]{2,1,0} get-tuple-element([[FUSION]]), index=0
-    ; CHECK:       [[WORKSPACE_MAIN:%[^ ]+]] = s8[256]{0} get-tuple-element([[FUSION]]), index=1
-    ; CHECK:       ROOT {{.*}} = (f16[4,8,8]{2,1,0}, s8[256]{0})
-    ; CHECK:              tuple([[DUS_MAIN]], [[WORKSPACE_MAIN]])
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmWorkspaceIgnored) {
-  const char* hlo = R"(
-    HloModule test
-
-    ENTRY %main.9 {
-      %p0 = f16[8,8]{1,0} parameter(0)
-      %p1 = f16[8,8]{1,0} parameter(1)
-      %p2 = f16[4,8,8]{2,1,0} parameter(2)
-      %c1_s32 = s32[] constant(1)
-      %c0_s32 = s32[] constant(0)
-
-      %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%p0, %p1),
-        custom_call_target="__cublas$gemm",
-        backend_config={"gemm_backend_config":{
-          "alpha_real":1,
-          "beta":0,
-          "dot_dimension_numbers":{
-            "lhs_contracting_dimensions":["1"],
-            "rhs_contracting_dimensions":["0"],
-            "lhs_batch_dimensions":[],
-            "rhs_batch_dimensions":[]
-          },
-          "alpha_imag":0,
-          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
-          "epilogue":"DEFAULT",
-          "lhs_stride":"64",
-          "rhs_stride":"64",
-          "grad_x":false,
-          "grad_y":false
-        }}
-      %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0
-      %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0)
-      ROOT %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32)
-    })";
-
-  const char* expected = R"(
-    ; CHECK:     address-computation {{.*}} {
-    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(0)
-    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[8,8]{1,0} parameter(1)
-    ; CHECK-DAG:   [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(2)
-    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(3)
-    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(4)
-    ; CHECK-DAG:   [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[P0]], [[P1]]),
-    ; CHECK-DAG:          custom_call_target="__cublas$gemm"
-    ; CHECK-DAG:   [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
-    ; CHECK-DAG:   [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[DOT]])
-    ; CHECK-DAG:   [[DUS:%[^ ]+]] = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
-    ; CHECK-DAG:   [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
-    ; CHECK:       ROOT [[TUPLE:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0})
-    ; CHECK:              tuple([[DUS]], [[WORKSPACE]])
-    ; CHECK:     }
-
-    ; CHECK:     ENTRY %main{{.*}} {
-    ; CHECK:       [[FUSION:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion
-    ; CHECK:         kind=kCustom, calls=%address-computation,
-    ; CHECK:         backend_config={
-    ; CHECK:           "kind":"__custom_fusion",
-    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
-    ; CHECK:         }
-    ; CHECK:       ROOT [[DOT_MAIN:%[^ ]+]] = f16[4,8,8]{2,1,0} get-tuple-element([[FUSION]]), index=0
-    ; CHECK:     }
-  )";
-
-  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
-}
-
-}  // namespace xla::gpu
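// A minimal sketch (not part of the deleted test file above): the
// RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected)
// calls used in these tests roughly amount to parsing the module, running the
// pass, and FileCheck-matching the printed module against `expected`. The test
// below is illustrative only; it assumes the fixture derives from HloTestBase
// and that the rewriter leaves modules without library custom calls untouched.
TEST_F(DynamicSliceFusionRewriterTest, NoLibraryCallMeansNoChangeSketch) {
  const char* hlo = R"(
    HloModule test
    ENTRY main {
      p0 = f16[8,8]{1,0} parameter(0)
      ROOT neg = f16[8,8]{1,0} negate(p0)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  DynamicSliceFusionRewriter pass("gpu");
  // RunHloPass is an HloTestBase helper that runs the pass over the module and
  // returns whether it changed anything.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  EXPECT_FALSE(changed);  // Nothing to rewrite: no __cublas$gemm custom call.
}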
diff --git a/third_party/xla/xla/service/gpu/fusion_merger.cc b/third_party/xla/xla/service/gpu/fusion_merger.cc
deleted file mode 100644
index 3703c98..0000000
--- a/third_party/xla/xla/service/gpu/fusion_merger.cc
+++ /dev/null
@@ -1,327 +0,0 @@
-/* Copyright 2016 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/fusion_merger.h"
-
-#include <optional>
-#include <string>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/gpu/model/gpu_performance_model.h"
-#include "xla/service/gpu/model/gpu_performance_model_base.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/hlo_graph_dumper.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/status.h"
-
-namespace xla {
-namespace gpu {
-
-// For each fusion F, attempts to fuse F into *all* of F's users (does not fuse
-// if F cannot be fused into at least one of them).
-class FusionInstructionMerger {
- public:
-  explicit FusionInstructionMerger(
-      HloComputation* computation, const se::DeviceDescription& gpu_device_info,
-      HloCostAnalysis::ShapeSizeFunction shape_size_function)
-      : computation_(computation),
-        shape_size_function_(shape_size_function),
-        gpu_device_info_(gpu_device_info),
-        dump_fusion_visualization_(computation->parent()
-                                       ->config()
-                                       .debug_options()
-                                       .xla_dump_fusion_visualization()) {}
-
-  absl::Status Run();
-
-  bool changed() const { return changed_; }
-
- private:
-  FusionDecision ShouldFuse(HloInstruction* producer);
-  absl::Status FuseIntoAllUsers(HloInstruction* producer);
-
-  HloComputation* computation_;
-  HloCostAnalysis::ShapeSizeFunction shape_size_function_;
-  // Many cheap checks can prevent fusion merging, so postpone the full HLO
-  // cost analysis of the computation; it may not be needed at all.
-  std::optional<GpuHloCostAnalysis> cost_analysis_;
-  FusionInfoCache fusion_info_cache_;
-  const se::DeviceDescription& gpu_device_info_;
-  bool changed_ = false;
-  bool dump_fusion_visualization_ = false;
-
-  // Fusion instruction merge stats.
-  int total_visited_ = 0;
-  int total_merged_ = 0;
-  int num_fail_no_users_ = 0;
-  int num_fail_not_loop_fusion_ = 0;
-  int num_fail_merge_all_users_ = 0;
-  int num_fail_inefficient_fusion_emitter_ = 0;
-  int num_fail_fusion_too_large_ = 0;
-  int num_fail_uncoalesced_read_ = 0;
-  int num_fail_slower_if_fused_ = 0;
-
-  FusionInstructionMerger(const FusionInstructionMerger&) = delete;
-  FusionInstructionMerger& operator=(const FusionInstructionMerger&) = delete;
-};
-
-absl::Status FusionInstructionMerger::FuseIntoAllUsers(
-    HloInstruction* producer) {
-  // Merge the fused instructions from 'producer' into each of its users.
-  std::vector<HloInstruction*> users = producer->users();
-  for (HloInstruction* user : users) {
-    if (dump_fusion_visualization_) {
-      RegisterFusionState(
-          *computation_,
-          absl::StrCat("About to fuse |", producer->name(), "| into |",
-                       user->name(), "| inside FusionMerger"),
-          /*consumer=*/*user,
-          /*producer=*/producer);
-    }
-
-    TF_RETURN_IF_ERROR(cost_analysis_->RemoveInstruction(user));
-
-    // Wrap consumers which are not fusions first.
-    HloInstruction* consumer = user;
-    if (consumer->opcode() != HloOpcode::kFusion) {
-      consumer = computation_->AddInstruction(HloInstruction::CreateFusion(
-          user->shape(), ChooseFusionKind(*producer, *user), user));
-      TF_CHECK_OK(computation_->ReplaceInstruction(user, consumer));
-    }
-
-    consumer->MergeFusionInstruction(producer);
-    TF_RETURN_IF_ERROR(cost_analysis_->RevisitInstruction(consumer));
-    fusion_info_cache_.Invalidate(consumer);
-
-    if (dump_fusion_visualization_) {
-      RegisterFusionState(*computation_,
-                          absl::StrCat("Fused |", producer->name(), "| into |",
-                                       user->name(), "| inside FusionMerger"),
-                          *consumer);
-    }
-
-    changed_ = true;
-  }
-
-  CHECK_EQ(0, producer->user_count()) << producer->ToString();
-  TF_RETURN_IF_ERROR(computation_->RemoveInstruction(producer));
-  TF_RETURN_IF_ERROR(cost_analysis_->RemoveInstruction(producer));
-  fusion_info_cache_.Invalidate(producer);
-  VLOG(2) << "Merged fusion instruction: " << producer->name()
-          << " into users { "
-          << absl::StrJoin(users, ", ",
-                           [](std::string* out, HloInstruction* user) {
-                             absl::StrAppend(out, user->name());
-                           })
-          << " }";
-  return absl::OkStatus();
-}
-
-absl::Status FusionInstructionMerger::Run() {
-  for (HloInstruction* producer : computation_->MakeInstructionPostOrder()) {
-    if (producer->opcode() != HloOpcode::kFusion) {
-      continue;
-    }
-    FusionDecision should_fuse = ShouldFuse(producer);
-    if (should_fuse) {
-      TF_RETURN_IF_ERROR(FuseIntoAllUsers(producer));
-      ++total_merged_;
-    } else {
-      VLOG(3) << "Not fusing fusion |" << producer->name()
-              << "| with all of it's users due to: " << should_fuse.Explain();
-      if (dump_fusion_visualization_ && !producer->users().empty()) {
-        RegisterFusionState(
-            *computation_,
-            absl::StrCat(
-                "Not fusing fusion |", producer->name(),
-                "| into all of its users due to: ", should_fuse.Explain()),
-            // Just pick any consumer, since we are trying to merge into all.
-            /*consumer=*/*producer->users()[0],
-            /*producer=*/producer);
-      }
-    }
-  }
-
-  VLOG(1) << "FusionInstructionMerger EXIT"
-          << " computation: " << computation_->name()
-          << " total_visited: " << total_visited_
-          << " total_merged: " << total_merged_ << " merge failures { "
-          << " no_users: " << num_fail_no_users_
-          << " not_loop_fusion: " << num_fail_not_loop_fusion_
-          << " merge_all_users: " << num_fail_merge_all_users_
-          << " uncoalesced_read: " << num_fail_uncoalesced_read_
-          << " inefficient_fusion_emitter: "
-          << num_fail_inefficient_fusion_emitter_
-          << " slower_if_fused: " << num_fail_slower_if_fused_
-          << " fusion_too_large: " << num_fail_fusion_too_large_ << " }";
-  return absl::OkStatus();
-}
-
-bool TransposesMostData(const HloInstruction& fusion) {
-  float score = 0;
-
-  for (const HloInstruction* instr : fusion.fused_instructions()) {
-    if (IsPhysicallyTransposing(*instr)) {
-      score += 1.0 * ShapeUtil::ElementsInRecursive(instr->shape()) /
-               ShapeUtil::ElementsInRecursive(fusion.shape());
-      if (score >= 0.5) {
-        VLOG(3) << fusion.ToString() << " transpose ratio exceeds " << score;
-        return true;
-      }
-    }
-  }
-
-  return false;
-}
-
-FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) {
-  ++total_visited_;
-
-  VLOG(4) << "Considering producer " << producer->name();
-
-  // Skip 'producer' instruction if there are no users into which we can
-  // merge.
-  if (producer->users().empty()) {
-    ++num_fail_no_users_;
-    return "fusion has no users";
-  }
-
-  // Skip 'producer' instruction if it is not a loop fusion. Library fusion
-  // instructions match specific patterns, so they shouldn't be further fused.
-  // Input fusion instructions need to be rooted at a particular HLO (e.g.
-  // kReduce), so they shouldn't be further fused either.
-  if (!producer->IsLoopFusion()) {
-    ++num_fail_not_loop_fusion_;
-    return "not a loop fusion";
-  }
-
-  auto producer_hero = GetRealHeroForMultiOutputFusion(*producer);
-
-  bool has_reduction_user = false;
-  for (const HloInstruction* user : producer->users()) {
-    if (user->opcode() == HloOpcode::kBitcast) {
-      ++num_fail_merge_all_users_;
-      return "not fusing bitcast ops";
-    }
-    if (user->IsCustomFusion()) {
-      ++num_fail_merge_all_users_;
-      return "not fusing custom fusions";
-    }
-    auto consumer_hero = GetRealHeroForMultiOutputFusion(*user);
-    if (auto compatible =
-            FusionHeroesAreCompatible(producer_hero, consumer_hero);
-        !compatible) {
-      return compatible;
-    }
-    FusionDecision fusible = IsProducerConsumerFusible(*producer, *user);
-    if (!fusible) {
-      ++num_fail_merge_all_users_;
-      VLOG(9) << user->ToString();
-      return fusible;
-    }
-    if (IsInputFusibleReduction(*user)) {
-      has_reduction_user = true;
-    }
-  }
-
-  // We do not want to worsen the reduction's memory access pattern by
-  // connecting it to a producer which transposes most of the data.
-  if (has_reduction_user && TransposesMostData(*producer)) {
-    ++num_fail_uncoalesced_read_;
-    return "would read mostly uncoalesced";
-  }
-
-  for (const HloInstruction* user : producer->users()) {
-    // Skip the 'producer' fusion if merging it into at least one of its users
-    // would make the resulting fusion use too much shared memory or registers.
-    FusionDecision fits = FusionFitsInBudget(
-        *user, *producer, gpu_device_info_,
-        /*is_consumer_producer_fusion=*/true, &fusion_info_cache_);
-    if (!fits) {
-      ++num_fail_fusion_too_large_;
-      return fits;
-    }
-  }
-
-  if (!cost_analysis_) {
-    VLOG(2) << "Running full HLO cost analysis for " << computation_->name();
-    cost_analysis_.emplace(
-        GpuHloCostAnalysis::Options{shape_size_function_,
-                                    /*per_second_rates=*/{},
-                                    /*count_multiple_input_accesses=*/true},
-        gpu_device_info_);
-    TF_CHECK_OK(computation_->Accept(&cost_analysis_.value()));
-  }
-
-  for (const HloInstruction* user : producer->users()) {
-    if (cost_analysis_->ProducerConsumerMergedTooLarge(*producer, *user)) {
-      ++num_fail_inefficient_fusion_emitter_;
-      return FusionDecision{} << "if merged with " << user->name()
-                              << " will generate huge IR";
-    }
-  }
-
-  GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
-      producer, gpu_device_info_, &*cost_analysis_,
-      GpuPerformanceModelOptions::Default(), producer->users());
-  if (t.time_fused > t.time_unfused) {
-    ++num_fail_slower_if_fused_;
-    return "will execute slower if fused";
-  }
-
-  return {};
-}
-
-absl::StatusOr<bool> FusionMerger::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  VLOG(1) << "FusionMerger for module: " << module->name();
-  for (auto* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    VLOG(9) << "Before running FusionInstructionMerger for computation: "
-            << computation->name();
-    XLA_VLOG_LINES(9, computation->ToString());
-
-    FusionInstructionMerger fusion_merger(computation, gpu_device_info_,
-                                          shape_size_function_);
-    TF_RETURN_IF_ERROR(fusion_merger.Run());
-    changed |= fusion_merger.changed();
-
-    VLOG(9) << "After running FusionInstructionMerger for computation: "
-            << computation->name() << " changed: " << changed;
-    XLA_VLOG_LINES(9, computation->ToString());
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusion_merger.h b/third_party/xla/xla/service/gpu/fusion_merger.h
deleted file mode 100644
index acbc93e..0000000
--- a/third_party/xla/xla/service/gpu/fusion_merger.h
+++ /dev/null
@@ -1,85 +0,0 @@
-/* Copyright 2016 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSION_MERGER_H_
-#define XLA_SERVICE_GPU_FUSION_MERGER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// An HLO pass that attempts to merge fusion instructions to reduce memory
-// bandwidth requirements and kernel launch overhead.
-//
-// Consider the example below. On the left-hand side, op A is the producer and
-// ops B and C are its consumers. FusionMerger duplicates producer ops and fuses
-// them into all consumers. The result is depicted on the right-hand side below.
-//
-//        p                    p
-//        |                  /   \
-//        v                 /     \
-//        A            +fusion+  +fusion+
-//      /   \          |  A'  |  |  A"  |
-//     |     |         |  |   |  |  |   |
-//     v     v         |  v   |  |  v   |
-//     B     C         |  B   |  |  C   |
-//                     +------+  +------+
-//
-// Op A has been cloned twice and fused with B and C. The kernel launch overhead
-// is reduced from 3 to 2. The memory bandwidth requirements may be reduced.
-// We trade 1 read of input(A) + 1 write and 2 reads of output(A) for 2 reads of
-// input(A). In general the achievable savings in memory bandwidth depend on
-// the difference between the amount of memory read and written and on the
-// number of consumers. The FusionMerger pass takes this into account when
-// making fusion decisions.
-//
-// The pass traverses the HLO module in post-order (defs before uses).
-// Fusion instructions are merged into their users if some conditions are met:
-// * The result of merging the fusion instruction into its users would not
-//   increase bytes transferred.
-// * Producer ops are fusible with _all_ consumers. If they are not fusible with
-//   at least one consumer, they won't be fused at all.
-// * Producers are kLoop fusion ops.
-//
-// None of these restrictions are necessary for correctness. In fact, lifting
-// the latter two could be beneficial.
-
-class FusionMerger : public HloModulePass {
- public:
-  explicit FusionMerger(const se::DeviceDescription& d,
-                        HloCostAnalysis::ShapeSizeFunction f)
-      : gpu_device_info_(d), shape_size_function_(f) {}
-  absl::string_view name() const override { return "fusion_merger"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  se::DeviceDescription gpu_device_info_;
-  HloCostAnalysis::ShapeSizeFunction shape_size_function_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_FUSION_MERGER_H_
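// A minimal usage sketch (not part of the deleted header above): FusionMerger
// is constructed from a device description and a shape-size function and run
// over an HloModule via the single-argument HloPassInterface::Run overload
// exposed in the class. Names below are illustrative assumptions, not code
// from this diff; the shape-size lambda mirrors the ShapeSizeBytesFunction in
// the deleted fusion_merger_test.cc further down.
absl::StatusOr<bool> RunFusionMergerSketch(
    HloModule* module, const se::DeviceDescription& device_info) {
  HloCostAnalysis::ShapeSizeFunction shape_size = [](const Shape& shape) {
    // Assume 8-byte pointers when sizing buffers, as the deleted tests do.
    return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/8);
  };
  FusionMerger merger(device_info, shape_size);
  return merger.Run(module);  // True iff any fusion was merged into its users.
}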
diff --git a/third_party/xla/xla/service/gpu/fusion_merger_test.cc b/third_party/xla/xla/service/gpu/fusion_merger_test.cc
deleted file mode 100644
index de45a4b..0000000
--- a/third_party/xla/xla/service/gpu/fusion_merger_test.cc
+++ /dev/null
@@ -1,1162 +0,0 @@
-/* Copyright 2016 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/fusion_merger.h"
-
-#include <cstdint>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class FusionMergerTest : public HloTestBase {
-  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
-    return [&](const Shape& shape) {
-      constexpr int64_t kPointerSize = 8;
-      return ShapeUtil::ByteSizeOf(shape, kPointerSize);
-    };
-  }
-
- public:
-  FusionMerger fusion_merger_{TestGpuDeviceInfo::RTXA6000DeviceInfo(),
-                              ShapeSizeBytesFunction()};
-  FusionMergerTest() : HloTestBase() {}
-};
-
-// Tests that we can merge a fusion instruction that is below threshold.
-//
-// Computation after fusion merger pass (Fusion2 is merged into Fusion0 and
-// Fusion1):
-//                   Param
-//                 /   |   \
-//          Fusion3 Fusion0 Fusion1
-//                 \   |   /
-//                   Tuple
-//
-TEST_F(FusionMergerTest, MergeSharedFusionInstruction) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule MergeSharedFusionInstruction
-
-comp.3 {
-  constant.param_0 = f32[4]{0} parameter(0)
-  param.param_1.2 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(1)
-  get-tuple-element.6 = f32[4]{0} get-tuple-element(param.param_1.2), index=0
-  ROOT add.7 = f32[4]{0} add(constant.param_0, get-tuple-element.6)
-}
-
-comp.2 {
-  param.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
-  get-tuple-element.4 = f32[4]{0} get-tuple-element(param.param_1.1), index=1
-  get-tuple-element.5 = f32[4]{0} get-tuple-element(param.param_1.1), index=2
-  ROOT add.6 = f32[4]{0} add(get-tuple-element.4, get-tuple-element.5)
-}
-
-comp.1 {
-  add.1.param_1.1 = f32[4]{0} parameter(1)
-  constant.param_1.3 = f32[4]{0} parameter(0)
-  add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3)
-  ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3)
-}
-
-comp {
-  add.1.param_1 = f32[4]{0} parameter(1)
-  constant.param_1.1 = f32[4]{0} parameter(0)
-  multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1)
-  ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1)
-}
-
-ENTRY MergeSharedFusionInstruction.Computation0 {
-  constant = f32[4]{0} constant({1, 1, 1, 1})
-  param = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
-  fusion.3 = f32[4]{0} fusion(constant, param), kind=kLoop, calls=comp.3
-  fusion.4 = f32[4]{0} fusion(param), kind=kLoop, calls=comp.2
-  fusion.5 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp.1
-  fusion.6 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp
-  ROOT tuple = (f32[4]{0}, f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.5, fusion.6)
-})")
-                    .value();
-  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-
-  auto* root = module->entry_computation()->root_instruction();
-  EXPECT_EQ(HloOpcode::kTuple, root->opcode());
-  // Check operand 0 (not merged). Should have 4 instructions.
-  auto* operand0 = root->operand(0);
-  EXPECT_EQ(HloOpcode::kFusion, operand0->opcode());
-  EXPECT_EQ(4, operand0->fused_instruction_count());
-  // Check operand 1 (should have merged in its operand fusion instruction).
-  auto* operand1 = root->operand(1);
-  EXPECT_EQ(HloOpcode::kFusion, operand1->opcode());
-  EXPECT_EQ(7, operand1->fused_instruction_count());
-  // Check operand 2 (should have merged in its operand fusion instruction).
-  auto* operand2 = root->operand(2);
-  EXPECT_EQ(HloOpcode::kFusion, operand2->opcode());
-  EXPECT_EQ(7, operand2->fused_instruction_count());
-}
-
-TEST_F(FusionMergerTest, MoreMemoryAccessIfFused) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-f32add {
-  x = f32[] parameter(0)
-  y = f32[] parameter(1)
-  ROOT _ = f32[] add(x, y)
-}
-
-comp0 {
-  p = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(0)
-  gte0 = f32[100000000] get-tuple-element(p), index=0
-  gte1 = f32[100000000] get-tuple-element(p), index=1
-  add.9 = f32[100000000] add(gte0, gte1)
-  gte2 = f32[100000000] get-tuple-element(p), index=2
-  add.10 = f32[100000000] add(add.9, gte2)
-  gte3 = f32[100000000] get-tuple-element(p), index=3
-  add.11 = f32[100000000] add(add.10, gte3)
-  p1 = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(1)
-  gte4 = f32[100000000] get-tuple-element(p1), index=0
-  gte5 = f32[100000000] get-tuple-element(p1), index=1
-  add.12 = f32[100000000] add(gte4, gte5)
-  gte6 = f32[100000000] get-tuple-element(p1), index=2
-  add.13 = f32[100000000] add(add.12, gte6)
-  gte7 = f32[100000000] get-tuple-element(p1), index=3
-  add.14 = f32[100000000] add(add.13, gte7)
-  ROOT r = f32[100000000] add(add.14, add.11)
-}
-
-comp1 {
-  p = f32[100000000] parameter(0)
-  c0 = f32[] constant(0)
-  ROOT r = f32[] reduce(p, c0), dimensions={0}, to_apply=f32add
-}
-
-comp2 {
-  p = f32[100000000] parameter(0)
-  c0 = f32[] constant(0)
-  r = f32[] reduce(p, c0), dimensions={0}, to_apply=f32add
-  ROOT n = f32[] negate(r)
-}
-
-ENTRY m.Computation2 {
-  p0 = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(0)
-  p1 = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(1)
-  fusion.0 = f32[100000000] fusion(p0, p1), kind=kLoop, calls=comp0
-  fusion.1 = f32[] fusion(fusion.0), kind=kLoop, calls=comp1
-  fusion.2 = f32[] fusion(fusion.0), kind=kLoop, calls=comp2
-  ROOT tuple = (f32[], f32[]) tuple(fusion.1, fusion.2)
-}
-)")
-                    .value();
-  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, LessMemoryAccessIfFused) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-comp.2 {
-  state.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
-  get-tuple-element.5 = f32[4]{0} get-tuple-element(state.param_1.1), index=0
-  get-tuple-element.6 = f32[4]{0} get-tuple-element(state.param_1.1), index=1
-  add.7 = f32[4]{0} add(get-tuple-element.5, get-tuple-element.6)
-  get-tuple-element.7 = f32[4]{0} get-tuple-element(state.param_1.1), index=2
-  ROOT add.8 = f32[4]{0} add(add.7, get-tuple-element.7)
-}
-
-comp.1 {
-  add.1.param_1.1 = f32[4]{0} parameter(1)
-  constant.param_1.3 = f32[4]{0} parameter(0)
-  add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3)
-  ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3)
-}
-
-comp {
-  add.1.param_1 = f32[4]{0} parameter(1)
-  constant.param_1.1 = f32[4]{0} parameter(0)
-  multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1)
-  ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1)
-}
-
-ENTRY m.Computation2 {
-  constant = f32[4]{0} constant({1, 1, 1, 1})
-  state = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
-  fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2
-  fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1
-  fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp
-  ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4)
-})")
-                    .value();
-  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-// Check that we're willing to merge f1_computation into f2_computation, even
-// though f2 is an input fusion node.
-TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule m
-
-    f1_computation {
-      f1_p0 = f32[32]{0} parameter(0)
-      ROOT f1_root = f32[32]{0} add(f1_p0, f1_p0)
-    }
-
-    add_computation {
-      add_lhs = f32[] parameter(0)
-      add_rhs = f32[] parameter(1)
-      ROOT add_root = f32[] add(add_lhs, add_rhs)
-    }
-
-    f2_computation {
-      f2_p0 = f32[32]{0} parameter(0)
-      f2_mul = f32[32]{0} multiply(f2_p0, f2_p0)
-      f2_zero = f32[] constant(0)
-      ROOT f2_root = f32[] reduce(f2_mul, f2_zero), dimensions={0},
-             to_apply=add_computation
-    }
-
-    ENTRY entry {
-      p0 = f32[32]{0} parameter(0)
-      f1 = f32[32]{0} fusion(p0), kind=kLoop, calls=f1_computation
-      ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
-    })")
-                    .value();
-  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Fusion(m::Parameter())));
-}
-
-TEST_F(FusionMergerTest, WillMergeIntoUnfusedConsumer) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule jit_matmul.36
-
-    max (parameter.13: f32[], parameter.14: f32[]) -> f32[] {
-      parameter.13 = f32[] parameter(0)
-      parameter.14 = f32[] parameter(1)
-      ROOT maximum.15 = f32[] maximum(f32[] parameter.13, f32[] parameter.14)
-    }
-
-    add (parameter.29: f32[], parameter.30: f32[]) -> f32[] {
-      parameter.29 = f32[] parameter(0)
-      parameter.30 = f32[] parameter(1)
-      ROOT add.31 = f32[] add(f32[] parameter.29, f32[] parameter.30)
-    }
-
-    fused_computation.1 (param_1.4: f32[200,200,200], param_2.1: f32[200,200]) -> f32[200,200] {
-      param_1.4 = f32[200,200,200]{2,1,0} parameter(0)
-      param_2.1 = f32[200,200]{1,0} parameter(1)
-      broadcast.3 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_2.1), dimensions={0,2}
-      subtract.0 = f32[200,200,200]{2,1,0} subtract(f32[200,200,200]{2,1,0} param_1.4, f32[200,200,200]{2,1,0} broadcast.3)
-      exponential.0 = f32[200,200,200]{2,1,0} exponential(f32[200,200,200]{2,1,0} subtract.0)
-      constant.27 = f32[] constant(0)
-      ROOT reduce.0 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} exponential.0, f32[] constant.27), dimensions={1}, to_apply=add
-    }
-
-    fused_computation.3 (param_0.7: f32[200,200], param_1.9: f32[200,200]) -> f32[200,200,200] {
-      param_1.9 = f32[200,200]{1,0} parameter(1)
-      broadcast.10 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_1.9), dimensions={0,1}
-      param_0.7 = f32[200,200]{1,0} parameter(0)
-      broadcast.8 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_0.7), dimensions={1,2}
-      ROOT add.1 = f32[200,200,200]{2,1,0} add(f32[200,200,200]{2,1,0} broadcast.10, f32[200,200,200]{2,1,0} broadcast.8)
-    }
-
-    ENTRY entry (parameter.1: f32[200,200], parameter.2: f32[200,200]) -> f32[200,200] {
-      parameter.2 = f32[200,200]{1,0} parameter(1)
-      parameter.1 = f32[200,200]{1,0} parameter(0)
-      fusion.3 = f32[200,200,200]{2,1,0} fusion(f32[200,200]{1,0} parameter.2, f32[200,200]{1,0} parameter.1), kind=kLoop, calls=fused_computation.3
-      constant.11 = f32[] constant(-inf)
-      reduce.16 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} fusion.3, f32[] constant.11), dimensions={1}, to_apply=max
-      ROOT fusion.1 = f32[200,200]{1,0} fusion(f32[200,200,200]{2,1,0} fusion.3, f32[200,200]{1,0} reduce.16), kind=kInput, calls=fused_computation.1
-    })")
-                    .value();
-  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Fusion(m::Fusion(), m::Parameter(), m::Parameter())));
-}
-
-TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
-  // TODO(b/247762001): the case here does not represent the problem -
-  // profiling shows that it works faster if merged (even on larger dimensions).
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule m
-
-    f1_computation {
-      f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
-      add = f32[16,16,256]{0,1,2} add(f1_p0, f1_p0)
-      // Note that the copy changes the layout from {0,1,2} to {2,1,0}.
-      ROOT f1_root = f32[16,16,256]{2,1,0} copy(add)
-    }
-
-    add_computation {
-      add_lhs = f32[] parameter(0)
-      add_rhs = f32[] parameter(1)
-      ROOT add_root = f32[] add(add_lhs, add_rhs)
-    }
-
-    f2_computation {
-      f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
-      f2_zero = f32[] constant(0)
-      ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2},
-             to_apply=add_computation
-    }
-
-    ENTRY entry {
-      p0 = f32[16,16,256]{0,1,2} parameter(0)
-      f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
-      ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
-    })")
-                    .value();
-  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, WillMergeReduceNotTooUnfriendlyLayouts) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule m
-
-    f1_computation {
-      f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
-      slice1 = f32[5,16,256]{0,1,2} slice(f1_p0), slice={[0:5], [0:16], [0:256]}
-      // Here the copy changes the layout only of a part of the data.
-      f1_copy = f32[5,16,256]{2,1,0} copy(slice1)
-      slice2 = f32[11,16,256]{0,1,2} slice(f1_p0), slice={[0:11], [0:16], [0:256]}
-      bitcast = f32[11,16,256]{2,1,0} bitcast(slice2)
-      ROOT f1_root = f32[16,16,256]{2,1,0} concatenate(f1_copy, bitcast), dimensions={0}
-    }
-
-    add_computation {
-      add_lhs = f32[] parameter(0)
-      add_rhs = f32[] parameter(1)
-      ROOT add_root = f32[] add(add_lhs, add_rhs)
-    }
-
-    f2_computation {
-      f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
-      f2_zero = f32[] constant(0)
-      ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2},
-             to_apply=add_computation
-    }
-
-    ENTRY entry {
-      p0 = f32[16,16,256]{0,1,2} parameter(0)
-      f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
-      ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
-    })")
-                    .value();
-  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-// Check that we limit the number of operands to fusions we create.
-TEST_F(FusionMergerTest, AvoidsLargeFusion) {
-  constexpr int64_t kNumParams = MaxOperandsAndOutputsPerFusion() + 1;
-
-  // Compute
-  //   p0 + p1 + p2 + ... + pn,
-  // using so many parameters that they do not fit into one fusion.
-  auto module = CreateNewVerifiedModule();
-  HloComputation::Builder b(TestName());
-  Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
-
-  std::vector<HloInstruction*> entry_params;
-
-  for (int64_t i = 0; i < kNumParams; ++i) {
-    entry_params.push_back(
-        b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")));
-  }
-  auto make_fusion = [&](absl::Span<HloInstruction* const> params) {
-    // Build a fusion computation for calculating the sum of all parameters.
-    HloComputation::Builder sub_builder("subcomp");
-    HloInstruction* sum = nullptr;
-    for (int64_t i = 0; i < params.size(); ++i) {
-      auto p = sub_builder.AddInstruction(
-          HloInstruction::CreateParameter(i, shape, "p"));
-      if (sum == nullptr) {
-        sum = p;
-      } else {
-        sum = sub_builder.AddInstruction(
-            HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, p));
-      }
-    }
-    HloComputation* subcomp =
-        module->AddEmbeddedComputation(sub_builder.Build());
-    return HloInstruction::CreateFusion(
-        shape, HloInstruction::FusionKind::kLoop, params, subcomp);
-  };
-  auto fusion = b.AddInstruction(
-      make_fusion(absl::MakeSpan(entry_params)
-                      .subspan(0, MaxOperandsAndOutputsPerFusion())));
-  b.AddInstruction(make_fusion({entry_params.back(), fusion}));
-  module->AddEntryComputation(b.Build());
-  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-// TODO(b/119692968): Remove this test once fusion emitter is fixed.
-TEST_F(FusionMergerTest, WillNotMergeIfFusionEmitterIsInefficient) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-f1 {
-  Arg_0.5 = f32[200000] parameter(0)
-  slice.7 = f32[100000] slice(Arg_0.5), slice={[0:199999:2]}
-  slice.8 = f32[100000] slice(Arg_0.5), slice={[1:200000:2]}
-  add.9 = f32[100000] add(slice.7, slice.8)
-  slice.10 = f32[50000] slice(add.9), slice={[0:99999:2]}
-  slice.11 = f32[50000] slice(add.9), slice={[1:100000:2]}
-  add.12 = f32[50000] add(slice.10, slice.11)
-  slice.13 = f32[25000] slice(add.12), slice={[0:49999:2]}
-  slice.14 = f32[25000] slice(add.12), slice={[1:50000:2]}
-  add.15 = f32[25000] add(slice.13, slice.14)
-  slice.16 = f32[12500] slice(add.15), slice={[0:24999:2]}
-  slice.17 = f32[12500] slice(add.15), slice={[1:25000:2]}
-  add.18 = f32[12500] add(slice.16, slice.17)
-  slice.19 = f32[6250] slice(add.18), slice={[0:12499:2]}
-  slice.20 = f32[6250] slice(add.18), slice={[1:12500:2]}
-  add.21 = f32[6250] add(slice.19, slice.20)
-  slice.22 = f32[3125] slice(add.21), slice={[0:6249:2]}
-  slice.23 = f32[3125] slice(add.21), slice={[1:6250:2]}
-  ROOT add.24 = f32[3125] add(slice.22, slice.23)
-}
-
-f2 {
-  Arg_0 = f32[3125] parameter(0)
-  slice.25 = f32[1562] slice(Arg_0), slice={[0:3124:2]}
-  slice.26 = f32[1562] slice(Arg_0), slice={[1:3125:2]}
-  add.27 = f32[1562] add(slice.25, slice.26)
-  slice.28 = f32[781] slice(add.27), slice={[0:1561:2]}
-  slice.29 = f32[781] slice(add.27), slice={[1:1562:2]}
-  add.30 = f32[781] add(slice.28, slice.29)
-  slice.31 = f32[390] slice(add.30), slice={[0:780:2]}
-  slice.32 = f32[390] slice(add.30), slice={[1:781:2]}
-  add.33 = f32[390] add(slice.31, slice.32)
-  slice.34 = f32[195] slice(add.33), slice={[0:389:2]}
-  slice.35 = f32[195] slice(add.33), slice={[1:390:2]}
-  add.36 = f32[195] add(slice.34, slice.35)
-  slice.37 = f32[97] slice(add.36), slice={[0:194:2]}
-  slice.38 = f32[97] slice(add.36), slice={[1:195:2]}
-  add.39 = f32[97] add(slice.37, slice.38)
-  slice.40 = f32[48] slice(add.39), slice={[0:96:2]}
-  slice.41 = f32[48] slice(add.39), slice={[1:97:2]}
-  ROOT add.42 = f32[48] add(slice.40, slice.41)
-}
-
-ENTRY e {
-  p0 = f32[200000] parameter(0)
-  f1 = f32[3125] fusion(p0), kind=kLoop, calls=f1
-  ROOT r = f32[48] fusion(f1), kind=kLoop, calls=f2
-})")
-                    .value();
-  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, WillMergeSliceIntoReusingConsumer) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-f1 {
-  p01 = s8[1000000] parameter(0)
-  ROOT s0 = s8[10] slice(p01), slice={[0:10]}
-}
-
-f2 {
-  p02 = s8[10] parameter(0)
-  ROOT b0 = s8[10,1000000] broadcast(p02), dimensions={0}
-}
-
-ENTRY e {
-  p0 = s8[1000000] parameter(0)
-  f1 = s8[10] fusion(p0), kind=kLoop, calls=f1
-  ROOT r = s8[10,1000000] fusion(f1), kind=kLoop, calls=f2
-})")
-                    .value();
-  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, WillMergeExpensiveFusionsIfSavesMemory) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule m
-
-    %f_a (p: f32[]) -> f32[1024,1024,1024] {
-      %p = f32[] parameter(0)
-      %b = f32[1024,1024,1024] broadcast(%p), dimensions={}
-      ROOT %t = f32[1024,1024,1024] tanh(%b)
-    }
-
-    %f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
-      %p = f32[1024,1024,1024] parameter(0)
-      ROOT %t = f32[1024,1024,1024] tanh(%p)
-    }
-
-    %f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
-      %p = f32[1024,1024,1024] parameter(0)
-      ROOT %t = f32[1024,1024,1024] tanh(%p)
-    }
-
-    ENTRY entry {
-      p0 = f32[] parameter(0)
-      f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_a
-      f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_b
-      f3 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c
-      ROOT f4 = f32[1024,1024,1024] add(f2, f3)
-    })")
-                    .value();
-  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, WillMergeExpensiveFusionsWithSingleConsumer) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule m
-
-    %f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
-      %p = f32[1024,1024,1024] parameter(0)
-      ROOT %t = f32[1024,1024,1024] tanh(%p)
-    }
-
-    %f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
-      %p = f32[1024,1024,1024] parameter(0)
-      ROOT %t = f32[1024,1024,1024] add(%p, %p)
-    }
-
-    ENTRY entry {
-      p0 = f32[1024,1024,1024] parameter(0)
-      f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_b
-      ROOT f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c
-    })")
-                    .value();
-  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, WillNotMergeExpensiveFusionsWithReusingConsumer) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule m
-
-    %f_b {
-      %p = f32[1024,1024,1024] parameter(0)
-      %t1 = f32[1024,1024,1024] tanh(%p)
-      %t2 = f32[1024,1024,1024] tanh(%t1)
-      %t3 = f32[1024,1024,1024] tanh(%t2)
-      %t4 = f32[1024,1024,1024] tanh(%t3)
-      %t5 = f32[1024,1024,1024] tanh(%t4)
-      %t6 = f32[1024,1024,1024] tanh(%t5)
-      %t7 = f32[1024,1024,1024] tanh(%t6)
-      %t8 = f32[1024,1024,1024] tanh(%t7)
-      ROOT %t9 = f32[1024,1024,1024] tanh(%t8)
-    }
-
-    %f_c {
-      %p = f32[1024,1024,1024] parameter(0)
-      ROOT %t = f32[1024,1024,1024,2048] broadcast(%p), dimensions={0,1,2}
-    }
-
-    ENTRY entry {
-      p0 = f32[1024,1024,1024] parameter(0)
-      f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_b
-      ROOT f2 = f32[1024,1024,1024,2048] fusion(f1), kind=kLoop, calls=%f_c
-    })")
-                    .value();
-  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, NoMergeWithBitcast) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-f32add {
-  x.634 = f32[] parameter(0)
-  y.635 = f32[] parameter(1)
-  ROOT add.636 = f32[] add(x.634, y.635)
-}
-
-fused_computation.103 {
-  param_0.310 = f16[1,8,512,1536]{2,3,1,0} parameter(0)
-  param_1.420 = f32[8,512]{1,0} parameter(1)
-  bitcast.1144 = f32[1,8,512]{2,1,0} bitcast(param_1.420)
-  convert.252 = f16[1,8,512]{2,1,0} convert(bitcast.1144)
-  bitcast.1143 = f16[8,512]{1,0} bitcast(convert.252)
-  broadcast.481 = f16[1,8,512,1536]{2,3,1,0} broadcast(bitcast.1143), dimensions={1,2}
-  divide.15 = f16[1,8,512,1536]{2,3,1,0} divide(param_0.310, broadcast.481)
-  ROOT bitcast.1142 = f16[8,512,1536]{1,2,0} bitcast(divide.15)
-}
-
-fused_computation.105 {
-  param_1.426 = f16[8,1536,512]{2,1,0} parameter(1)
-  bitcast.1896 = f16[1,8,1536,512]{3,2,1,0} bitcast(param_1.426)
-  transpose.238 = f16[1,8,512,1536]{2,3,1,0} transpose(bitcast.1896), dimensions={0,1,3,2}
-  param_0.315 = f16[8,512]{1,0} parameter(0)
-  broadcast.482 = f16[1,8,512,1536]{2,3,1,0} broadcast(param_0.315), dimensions={1,2}
-  subtract.22 = f16[1,8,512,1536]{2,3,1,0} subtract(transpose.238, broadcast.482)
-  ROOT exponential.15 = f16[1,8,512,1536]{2,3,1,0} exponential(subtract.22)
-}
-
-fused_computation.104 {
-  param_0.1000 = f16[8,1536,512]{2,1,0} parameter(0)
-  convert.652 = f32[8,1536,512]{2,1,0} convert(param_0.1000)
-  constant_752 = f32[] constant(-0)
-  ROOT reduce.232 = f32[8,512]{1,0} reduce(convert.652, constant_752),
-  dimensions={1}, to_apply=f32add
-}
-
-ENTRY entry {
-  p0 = f16[8,1536,512]{2,1,0} parameter(0)
-  p1 = f16[8,512]{1,0} parameter(1)
-  fusion.105 = f16[1,8,512,1536]{2,3,1,0} fusion(p1, p0), kind=kLoop, calls=fused_computation.105
-  bitcast.1787 = f16[8,1536,512]{2,1,0} bitcast(fusion.105)
-  fusion.104 = f32[8,512]{1,0} fusion(bitcast.1787), kind=kInput, calls=fused_computation.104
-  ROOT fusion.103 = f16[8,512,1536]{1,2,0} fusion(fusion.105, fusion.104), kind=kLoop, calls=fused_computation.103
-}
-    )")
-                    .value();
-  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, CostBasedMerge) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-fused_computation.45 {
-  param_1.194 = f16[8,1536,512]{2,1,0} parameter(1)
-  bitcast.1042 = f16[1,8,512,1536]{2,3,1,0} bitcast(param_1.194)
-  param_0.135 = f16[8,512]{1,0} parameter(0)
-  broadcast.391 = f16[1,8,512,1536]{2,3,1,0} broadcast(param_0.135), dimensions={1,2}
-  subtract.6 = f16[1,8,512,1536]{2,3,1,0} subtract(bitcast.1042, broadcast.391)
-  ROOT exponential.11 = f16[1,8,512,1536]{2,3,1,0} exponential(subtract.6)
-}
-
-f32add {
-  x.634 = f32[] parameter(0)
-  y.635 = f32[] parameter(1)
-  ROOT add.636 = f32[] add(x.634, y.635)
-}
-
-fused_computation.44 {
-  param_0.869 = f16[1,8,512,1536]{2,3,1,0} parameter(0)
-  convert.221 = f32[1,8,512,1536]{2,3,1,0} convert(param_0.869)
-  transpose.212 = f32[1,8,1536,512]{3,2,1,0} transpose(convert.221), dimensions={0,1,3,2}
-  bitcast.1041 = f32[8,1536,512]{2,1,0} bitcast(transpose.212)
-  constant_429 = f32[] constant(0)
-  ROOT reduce.149 = f32[8,512]{1,0} reduce(bitcast.1041, constant_429), dimensions={1}, to_apply=f32add
-}
-
-fused_computation.43 {
-  param_0.130 = f16[1,8,512,1536]{2,3,1,0} parameter(0)
-  param_1.188 = f32[8,512]{1,0} parameter(1)
-  bitcast.1040 = f32[1,8,512]{2,1,0} bitcast(param_1.188)
-  convert.220 = f16[1,8,512]{2,1,0} convert(bitcast.1040)
-  bitcast.1039 = f16[8,512]{1,0} bitcast(convert.220)
-  broadcast.390 = f16[1,8,512,1536]{2,3,1,0} broadcast(bitcast.1039), dimensions={1,2}
-  divide.11 = f16[1,8,512,1536]{2,3,1,0} divide(param_0.130, broadcast.390)
-  ROOT bitcast.1038 = f16[8,512,1536]{1,2,0} bitcast(divide.11)
-}
-
-ENTRY entry {
-  p0 = f16[8,1536,512]{2,1,0} parameter(0)
-  p1 = f16[8,512]{1,0} parameter(1)
-  fusion.45 = f16[1,8,512,1536]{2,3,1,0} fusion(p1, p0), kind=kLoop, calls=fused_computation.45
-  fusion.44 = f32[8,512]{1,0} fusion(fusion.45), kind=kInput, calls=fused_computation.44
-  ROOT fusion.43 = f16[8,512,1536]{1,2,0} fusion(fusion.45, fusion.44), kind=kLoop, calls=fused_computation.43
-}
-    )")
-                    .value();
-  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-// The outputs of fusions 66 and 67 are heavily reused by fusion 59, so it is
-// better not to merge here.
-TEST_F(FusionMergerTest, CostBasedNoMerge) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-add_float_.56 {
-  x.57 = f32[] parameter(0)
-  y.58 = f32[] parameter(1)
-  ROOT add.59 = f32[] add(x.57, y.58)
-}
-
-fused_computation.66 {
-  constant.635 = f32[] constant(0)
-  broadcast.257 = f32[459,3]{1,0} broadcast(constant.635), dimensions={}
-  constant.641 = f32[] constant(1)
-  broadcast.256 = f32[459,3]{1,0} broadcast(constant.641), dimensions={}
-  broadcast.255 = f32[459]{0} broadcast(constant.635), dimensions={}
-  iota.28 = f32[459]{0} iota(), iota_dimension=0
-  constant.629 = f32[] constant(1.49891067)
-  broadcast.253 = f32[459]{0} broadcast(constant.629), dimensions={}
-  multiply.39 = f32[459]{0} multiply(iota.28, broadcast.253)
-  constant.633 = f32[] constant(-1)
-  broadcast.252 = f32[459]{0} broadcast(constant.633), dimensions={}
-  add.31 = f32[459]{0} add(multiply.39, broadcast.252)
-  ceil.11 = f32[459]{0} ceil(add.31)
-  constant.630 = f32[] constant(685)
-  broadcast.251 = f32[459]{0} broadcast(constant.630), dimensions={}
-  clamp.49 = f32[459]{0} clamp(broadcast.255, ceil.11, broadcast.251)
-  subtract.11 = f32[459]{0} subtract(clamp.49, multiply.39)
-  broadcast.249 = f32[459,3]{1,0} broadcast(subtract.11), dimensions={0}
-  iota.26 = f32[459,3]{1,0} iota(), iota_dimension=1
-  add.30 = f32[459,3]{1,0} add(broadcast.249, iota.26)
-  abs.3 = f32[459,3]{1,0} abs(add.30)
-  subtract.10 = f32[459,3]{1,0} subtract(broadcast.256, abs.3)
-  maximum.6 = f32[459,3]{1,0} maximum(broadcast.257, subtract.10)
-  ROOT reduce.3 = f32[459]{0} reduce(maximum.6, constant.635), dimensions={1}, to_apply=add_float_.56
-}
-
-fused_computation.67 {
-  constant.684 = f32[] constant(0)
-  broadcast.296 = f32[1130,3]{1,0} broadcast(constant.684), dimensions={}
-  constant.685 = f32[] constant(1)
-  broadcast.295 = f32[1130,3]{1,0} broadcast(constant.685), dimensions={}
-  broadcast.294 = f32[1130]{0} broadcast(constant.684), dimensions={}
-  iota.41 = f32[1130]{0} iota(), iota_dimension=0
-  constant.675 = f32[] constant(1.34513271)
-  broadcast.293 = f32[1130]{0} broadcast(constant.675), dimensions={}
-  multiply.47 = f32[1130]{0} multiply(iota.41, broadcast.293)
-  constant.677 = f32[] constant(-1)
-  broadcast.290 = f32[1130]{0} broadcast(constant.677), dimensions={}
-  add.39 = f32[1130]{0} add(multiply.47, broadcast.290)
-  ceil.15 = f32[1130]{0} ceil(add.39)
-  constant.676 = f32[] constant(1517)
-  broadcast.289 = f32[1130]{0} broadcast(constant.676), dimensions={}
-  clamp.53 = f32[1130]{0} clamp(broadcast.294, ceil.15, broadcast.289)
-  subtract.19 = f32[1130]{0} subtract(clamp.53, multiply.47)
-  broadcast.287 = f32[1130,3]{1,0} broadcast(subtract.19), dimensions={0}
-  iota.39 = f32[1130,3]{1,0} iota(), iota_dimension=1
-  add.38 = f32[1130,3]{1,0} add(broadcast.287, iota.39)
-  abs.7 = f32[1130,3]{1,0} abs(add.38)
-  subtract.18 = f32[1130,3]{1,0} subtract(broadcast.295, abs.7)
-  maximum.10 = f32[1130,3]{1,0} maximum(broadcast.296, subtract.18)
-  ROOT reduce.4 = f32[1130]{0} reduce(maximum.10, constant.684), dimensions={1}, to_apply=add_float_.56
-}
-
-fused_computation.59 {
-  constant.532 = f32[] constant(0)
-  broadcast.316 = f32[1130,3]{1,0} broadcast(constant.532), dimensions={}
-  constant.663 = f32[] constant(1)
-  broadcast.315 = f32[1130,3]{1,0} broadcast(constant.663), dimensions={}
-  broadcast.314 = f32[1130]{0} broadcast(constant.532), dimensions={}
-  iota.47 = f32[1130]{0} iota(), iota_dimension=0
-  constant.579 = f32[] constant(1.34513271)
-  broadcast.311 = f32[1130]{0} broadcast(constant.579), dimensions={}
-  multiply.51 = f32[1130]{0} multiply(iota.47, broadcast.311)
-  constant.578 = f32[] constant(-1)
-  broadcast.310 = f32[1130]{0} broadcast(constant.578), dimensions={}
-  add.43 = f32[1130]{0} add(multiply.51, broadcast.310)
-  ceil.17 = f32[1130]{0} ceil(add.43)
-  constant.576 = f32[] constant(1517)
-  broadcast.309 = f32[1130]{0} broadcast(constant.576), dimensions={}
-  clamp.55 = f32[1130]{0} clamp(broadcast.314, ceil.17, broadcast.309)
-  subtract.24 = f32[1130]{0} subtract(clamp.55, multiply.51)
-  broadcast.306 = f32[1130,3]{1,0} broadcast(subtract.24), dimensions={0}
-  iota.45 = f32[1130,3]{1,0} iota(), iota_dimension=1
-  add.42 = f32[1130,3]{1,0} add(broadcast.306, iota.45)
-  abs.9 = f32[1130,3]{1,0} abs(add.42)
-  subtract.23 = f32[1130,3]{1,0} subtract(broadcast.315, abs.9)
-  maximum.12 = f32[1130,3]{1,0} maximum(broadcast.316, subtract.23)
-  param_2.183 = f32[1130]{0} parameter(2)
-  broadcast.172 = f32[1130,3]{1,0} broadcast(param_2.183), dimensions={0}
-  divide.3 = f32[1130,3]{1,0} divide(maximum.12, broadcast.172)
-  bitcast.53 = f32[3390]{0} bitcast(divide.3)
-  broadcast.171 = f32[3390,1377]{1,0} broadcast(bitcast.53), dimensions={0}
-  broadcast.276 = f32[459,3]{1,0} broadcast(constant.532), dimensions={}
-  broadcast.275 = f32[459,3]{1,0} broadcast(constant.663), dimensions={}
-  broadcast.274 = f32[459]{0} broadcast(constant.532), dimensions={}
-  iota.35 = f32[459]{0} iota(), iota_dimension=0
-  constant.614 = f32[] constant(1.49891067)
-  broadcast.273 = f32[459]{0} broadcast(constant.614), dimensions={}
-  multiply.43 = f32[459]{0} multiply(iota.35, broadcast.273)
-  broadcast.272 = f32[459]{0} broadcast(constant.578), dimensions={}
-  add.35 = f32[459]{0} add(multiply.43, broadcast.272)
-  ceil.13 = f32[459]{0} ceil(add.35)
-  constant.611 = f32[] constant(685)
-  broadcast.269 = f32[459]{0} broadcast(constant.611), dimensions={}
-  clamp.51 = f32[459]{0} clamp(broadcast.274, ceil.13, broadcast.269)
-  subtract.15 = f32[459]{0} subtract(clamp.51, multiply.43)
-  broadcast.267 = f32[459,3]{1,0} broadcast(subtract.15), dimensions={0}
-  iota.33 = f32[459,3]{1,0} iota(), iota_dimension=1
-  add.34 = f32[459,3]{1,0} add(broadcast.267, iota.33)
-  abs.5 = f32[459,3]{1,0} abs(add.34)
-  subtract.14 = f32[459,3]{1,0} subtract(broadcast.275, abs.5)
-  maximum.8 = f32[459,3]{1,0} maximum(broadcast.276, subtract.14)
-  param_1.177 = f32[459]{0} parameter(1)
-  broadcast.170 = f32[459,3]{1,0} broadcast(param_1.177), dimensions={0}
-  divide.2 = f32[459,3]{1,0} divide(maximum.8, broadcast.170)
-  bitcast.52 = f32[1377]{0} bitcast(divide.2)
-  broadcast.169 = f32[3390,1377]{1,0} broadcast(bitcast.52), dimensions={1}
-  multiply.15 = f32[3390,1377]{1,0} multiply(broadcast.171, broadcast.169)
-  bitcast.61 = f32[1130,3,459,3]{3,2,1,0} bitcast(multiply.15)
-  transpose.68 = f32[459,1130,3,3]{2,0,3,1} transpose(bitcast.61), dimensions={2,0,3,1}
-  copy.1 = f32[459,1130,3,3]{3,2,1,0} copy(transpose.68)
-  bitcast.50 = f32[1130,459,9]{2,1,0} bitcast(copy.1)
-  broadcast.168 = f32[1130,459,6,9]{3,2,1,0} broadcast(bitcast.50), dimensions={0,1,3}
-  param_0.171 = u8[1,688,1520,6]{3,2,1,0} parameter(0)
-  bitcast.49 = u8[688,1520,1,6]{3,1,0,2} bitcast(param_0.171)
-  convert.175 = f32[688,1520,1,6]{3,1,0,2} convert(bitcast.49)
-  broadcast.167 = f32[459,1130,1]{2,1,0} broadcast(clamp.51), dimensions={0}
-  broadcast.166 = f32[459,1130,1]{2,1,0} broadcast(clamp.55), dimensions={1}
-  concatenate.3 = f32[459,1130,2]{2,1,0} concatenate(broadcast.167, broadcast.166), dimensions={2}
-  convert.174 = s32[459,1130,2]{2,1,0} convert(concatenate.3)
-  bitcast.48 = s32[518670,2]{1,0} bitcast(convert.174)
-  gather.1 = f32[518670,3,3,1,6]{2,1,4,0,3} gather(convert.175, bitcast.48), offset_dims={1,2,3,4}, collapsed_slice_dims={}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={3,3,1,6}
-  transpose.69 = f32[1,518670,6,3,3]{4,3,2,1,0} transpose(gather.1), dimensions={3,0,4,1,2}
-  bitcast.47 = f32[1130,459,6,9]{3,2,1,0} bitcast(transpose.69)
-  multiply.14 = f32[1130,459,6,9]{3,2,1,0} multiply(broadcast.168, bitcast.47)
-  reduce.2 = f32[1130,459,6]{2,1,0} reduce(multiply.14, constant.532), dimensions={3}, to_apply=add_float_.56
-  convert.173 = f16[1130,459,6]{2,1,0} convert(reduce.2)
-  bitcast.46 = f16[1,459,1130,6]{3,2,1,0} bitcast(convert.173)
-  constant.533 = f16[] constant(0)
-  pad.9 = f16[1,480,1130,6]{3,2,1,0} pad(bitcast.46, constant.533), padding=0_0x0_21x0_0x0_0
-  pad.8 = f16[1,480,1152,6]{3,2,1,0} pad(pad.9, constant.533), padding=0_0x0_0x0_22x0_0
-  constant.532f16 = f16[] constant(0)
-  ROOT pad.7 = f16[1,485,1157,6]{3,2,1,0} pad(pad.8, constant.532f16), padding=0_0x2_3x2_3x0_0
-}
-
-ENTRY e {
-  arg0.1 = u8[1,688,1520,6]{3,2,1,0} parameter(0), parameter_replication={false}
-  fusion.66 = f32[459]{0} fusion(), kind=kLoop, calls=fused_computation.66
-  fusion.67 = f32[1130]{0} fusion(), kind=kLoop, calls=fused_computation.67
-  ROOT fusion.59 = f16[1,485,1157,6]{2,1,3,0} fusion(arg0.1, fusion.66, fusion.67), kind=kLoop, calls=fused_computation.59
-}
-    )")
-                    .value();
-  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, NoMergeBecauseTooManyBasicBlockSplits) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-region_6.97 {
-  Arg_0.98 = pred[] parameter(0)
-  Arg_1.99 = pred[] parameter(1)
-  ROOT or.100 = pred[] or(Arg_0.98, Arg_1.99)
-}
-
-region_4.50 {
-  Arg_0.51 = f64[] parameter(0)
-  Arg_1.52 = f64[] parameter(1)
-  ROOT add.53 = f64[] add(Arg_0.51, Arg_1.52)
-}
-
-f2 {
-  param_0 = s64[1]{0} parameter(0)
-  constant_70 = f64[] constant(0)
-  convert.41.clone.1 = f64[1]{0} convert(param_0)
-  ROOT pad.99.clone.1 = f64[3]{0} pad(convert.41.clone.1, constant_70), padding=0_2
-}
-
-f1 {
-  param_0.361 = pred[5]{0} parameter(0)
-  broadcast.107 = pred[10,5]{1,0} broadcast(param_0.361), dimensions={1}
-  param_6.244 = pred[5]{0} parameter(6)
-  broadcast.111.clone.1 = pred[10,5]{1,0} broadcast(param_6.244), dimensions={1}
-  param_1.450 = f64[10,5]{1,0} parameter(1)
-  constant_294_clone_1 = f64[] constant(1)
-  broadcast.153.clone.1 = f64[10,5]{1,0} broadcast(constant_294_clone_1), dimensions={}
-  compare.22.clone.1 = pred[10,5]{1,0} compare(param_1.450, broadcast.153.clone.1), direction=GE
-  constant_75_clone_1 = f64[] constant(-1)
-  broadcast.109.clone.1 = f64[10,5]{1,0} broadcast(constant_75_clone_1), dimensions={}
-  add.34.clone.1 = f64[10,5]{1,0} add(param_1.450, broadcast.109.clone.1)
-  param_5.322 = f64[10,5,4]{1,0,2} parameter(5)
-  slice.45.clone.1 = f64[10,5,1]{1,0,2} slice(param_5.322), slice={[0:10], [0:5], [3:4]}
-  bitcast.94.clone.1 = f64[10,5]{1,0} bitcast(slice.45.clone.1)
-  divide.7.clone.1 = f64[10,5]{1,0} divide(add.34.clone.1, bitcast.94.clone.1)
-  add.33.clone.1 = f64[10,5]{1,0} add(divide.7.clone.1, broadcast.153.clone.1)
-  constant_70 = f64[] constant(0)
-  broadcast.157.clone.1 = f64[10,5]{1,0} broadcast(constant_70), dimensions={}
-  compare.26.clone.1 = pred[10,5]{1,0} compare(param_1.450, broadcast.157.clone.1), direction=LE
-  slice.46.clone.1 = f64[10,5,1]{1,0,2} slice(param_5.322), slice={[0:10], [0:5], [0:1]}
-  bitcast.93.clone.1 = f64[10,5]{1,0} bitcast(slice.46.clone.1)
-  divide.6.clone.1 = f64[10,5]{1,0} divide(param_1.450, bitcast.93.clone.1)
-  broadcast.295.clone.1 = f64[10,5,3]{1,0,2} broadcast(param_1.450), dimensions={0,1}
-  param_4.368 = f64[10,5,2]{1,0,2} parameter(4)
-  pad.103.clone.1 = f64[10,5,3]{1,0,2} pad(param_4.368, constant_70), padding=0_0x0_0x1_0
-  compare.121.clone.1 = pred[10,5,3]{1,0,2} compare(broadcast.295.clone.1, pad.103.clone.1), direction=GE
-  pad.102.clone.1 = f64[10,5,3]{1,0,2} pad(param_4.368, constant_294_clone_1), padding=0_0x0_0x0_1
-  compare.120.clone.1 = pred[10,5,3]{1,0,2} compare(broadcast.295.clone.1, pad.102.clone.1), direction=LT
-  and.39.clone.1 = pred[10,5,3]{1,0,2} and(compare.121.clone.1, compare.120.clone.1)
-  transpose.9 = pred[3,10,5]{2,1,0} transpose(and.39.clone.1), dimensions={2,0,1}
-  constant_296_clone_1 = pred[] constant(false)
-  reduce.91.clone.1 = pred[10,5]{1,0} reduce(transpose.9, constant_296_clone_1), dimensions={0}, to_apply=region_6.97
-  broadcast.294.clone.1 = pred[10,5,3]{1,0,2} broadcast(reduce.91.clone.1), dimensions={0,1}
-  pad.99.clone.1 = f64[3]{0} parameter(3)
-  broadcast.292.clone.1 = f64[3]{0} broadcast(constant_70), dimensions={}
-  compare.117.clone.1 = pred[3]{0} compare(pad.99.clone.1, broadcast.292.clone.1), direction=NE
-  broadcast.290.clone.1 = pred[10,5,3]{1,0,2} broadcast(compare.117.clone.1), dimensions={2}
-  select.67.clone.1 = pred[10,5,3]{1,0,2} select(broadcast.294.clone.1, and.39.clone.1, broadcast.290.clone.1)
-  convert.40.clone.1 = f64[10,5,3]{1,0,2} convert(select.67.clone.1)
-  broadcast.288.clone.1 = f64[10,5,3,3]{1,0,2,3} broadcast(convert.40.clone.1), dimensions={0,1,2}
-  param_2.361 = f64[10,5,4,3]{1,0,2,3} parameter(2)
-  slice.114.clone.1 = f64[10,5,3,3]{1,0,2,3} slice(param_2.361), slice={[0:10], [0:5], [1:4], [0:3]}
-  multiply.53.clone.1 = f64[10,5,3,3]{1,0,2,3} multiply(broadcast.288.clone.1, slice.114.clone.1)
-  transpose.10 = f64[3,3,10,5]{3,2,1,0} transpose(multiply.53.clone.1), dimensions={3,2,0,1}
-  reduce.90.clone.1 = f64[3,10,5]{2,1,0} reduce(transpose.10, constant_70), dimensions={1}, to_apply=region_4.50
-  transpose.11 = f64[10,5,3]{1,0,2} transpose(reduce.90.clone.1), dimensions={1,2,0}
-  slice.28.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.11), slice={[0:10], [0:5], [0:1]}
-  bitcast.99.clone.1 = f64[10,5]{1,0} bitcast(slice.28.clone.1)
-  slice.108.clone.1 = f64[10,5,3,3]{1,0,2,3} slice(param_2.361), slice={[0:10], [0:5], [0:3], [0:3]}
-  multiply.49.clone.1 = f64[10,5,3,3]{1,0,2,3} multiply(broadcast.288.clone.1, slice.108.clone.1)
-  transpose.12 = f64[3,3,10,5]{3,2,1,0} transpose(multiply.49.clone.1), dimensions={3,2,0,1}
-  reduce.82.clone.1 = f64[3,10,5]{2,1,0} reduce(transpose.12, constant_70), dimensions={1}, to_apply=region_4.50
-  transpose.13 = f64[10,5,3]{1,0,2} transpose(reduce.82.clone.1), dimensions={1,2,0}
-  slice.107.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.13), slice={[0:10], [0:5], [0:1]}
-  bitcast.240.clone.1 = f64[10,5]{1,0} bitcast(slice.107.clone.1)
-  subtract.27.clone.1 = f64[10,5]{1,0} subtract(bitcast.99.clone.1, bitcast.240.clone.1)
-  slice.27.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.13), slice={[0:10], [0:5], [2:3]}
-  bitcast.98.clone.1 = f64[10,5]{1,0} bitcast(slice.27.clone.1)
-  slice.26.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.11), slice={[0:10], [0:5], [2:3]}
-  bitcast.97.clone.1 = f64[10,5]{1,0} bitcast(slice.26.clone.1)
-  add.36.clone.1 = f64[10,5]{1,0} add(bitcast.97.clone.1, bitcast.98.clone.1)
-  slice.24.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.11), slice={[0:10], [0:5], [1:2]}
-  bitcast.95.clone.1 = f64[10,5]{1,0} bitcast(slice.24.clone.1)
-  slice.121.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.13), slice={[0:10], [0:5], [1:2]}
-  bitcast.274.clone.1 = f64[10,5]{1,0} bitcast(slice.121.clone.1)
-  subtract.26.clone.1 = f64[10,5]{1,0} subtract(bitcast.95.clone.1, bitcast.274.clone.1)
-  divide.21 = f64[10,5]{1,0} divide(subtract.26.clone.1, subtract.27.clone.1)
-  constant_77_clone_1 = f64[] constant(2)
-  broadcast.117.clone.1 = f64[10,5]{1,0} broadcast(constant_77_clone_1), dimensions={}
-  multiply.37.clone.1 = f64[10,5]{1,0} multiply(divide.21, broadcast.117.clone.1)
-  subtract.25.clone.1 = f64[10,5]{1,0} subtract(add.36.clone.1, multiply.37.clone.1)
-  subtract.24.clone.1 = f64[10,5]{1,0} subtract(param_1.450, bitcast.274.clone.1)
-  divide.9.clone.1 = f64[10,5]{1,0} divide(subtract.24.clone.1, subtract.26.clone.1)
-  clamp.7.clone.1 = f64[10,5]{1,0} clamp(broadcast.157.clone.1, divide.9.clone.1, broadcast.153.clone.1)
-  multiply.36.clone.1 = f64[10,5]{1,0} multiply(subtract.25.clone.1, clamp.7.clone.1)
-  subtract.23.clone.1 = f64[10,5]{1,0} subtract(bitcast.98.clone.1, multiply.36.clone.1)
-  compare.13.clone.1 = pred[10,5]{1,0} compare(subtract.23.clone.1, broadcast.157.clone.1), direction=GE
-  negate.19.clone.1 = f64[10,5]{1,0} negate(divide.21)
-  multiply.35.clone.1 = f64[10,5]{1,0} multiply(negate.19.clone.1, clamp.7.clone.1)
-  multiply.34.clone.1 = f64[10,5]{1,0} multiply(multiply.35.clone.1, broadcast.117.clone.1)
-  negate.18.clone.1 = f64[10,5]{1,0} negate(subtract.23.clone.1)
-  multiply.33.clone.1 = f64[10,5]{1,0} multiply(subtract.23.clone.1, subtract.23.clone.1)
-  subtract.22.clone.1 = f64[10,5]{1,0} subtract(divide.21, subtract.23.clone.1)
-  constant_78_clone_1 = f64[] constant(4)
-  broadcast.113.clone.1 = f64[10,5]{1,0} broadcast(constant_78_clone_1), dimensions={}
-  multiply.32.clone.1 = f64[10,5]{1,0} multiply(subtract.22.clone.1, broadcast.113.clone.1)
-  multiply.31.clone.1 = f64[10,5]{1,0} multiply(multiply.32.clone.1, multiply.35.clone.1)
-  subtract.21.clone.1 = f64[10,5]{1,0} subtract(multiply.33.clone.1, multiply.31.clone.1)
-  compare.12.clone.1 = pred[10,5]{1,0} compare(subtract.21.clone.1, broadcast.157.clone.1), direction=GT
-  constant_79_clone_1 = f64[] constant(2.2250738585072014e-308)
-  broadcast.112.clone.1 = f64[10,5]{1,0} broadcast(constant_79_clone_1), dimensions={}
-  maximum.18.clone.1 = f64[10,5]{1,0} maximum(broadcast.112.clone.1, subtract.21.clone.1)
-  sqrt.1.clone.1 = f64[10,5]{1,0} sqrt(maximum.18.clone.1)
-  select.47.clone.1 = f64[10,5]{1,0} select(compare.12.clone.1, sqrt.1.clone.1, broadcast.157.clone.1)
-  add.35.clone.1 = f64[10,5]{1,0} add(negate.18.clone.1, select.47.clone.1)
-  select.46.clone.1 = f64[10,5]{1,0} select(compare.13.clone.1, multiply.34.clone.1, add.35.clone.1)
-  subtract.20.clone.1 = f64[10,5]{1,0} subtract(negate.18.clone.1, select.47.clone.1)
-  multiply.30.clone.1 = f64[10,5]{1,0} multiply(subtract.22.clone.1, broadcast.117.clone.1)
-  select.45.clone.1 = f64[10,5]{1,0} select(compare.13.clone.1, subtract.20.clone.1, multiply.30.clone.1)
-  divide.8.clone.1 = f64[10,5]{1,0} divide(select.46.clone.1, select.45.clone.1)
-  clamp.6.clone.1 = f64[10,5]{1,0} clamp(broadcast.157.clone.1, divide.8.clone.1, broadcast.153.clone.1)
-  multiply.29.clone.1 = f64[10,5]{1,0} multiply(subtract.27.clone.1, clamp.6.clone.1)
-  add.32.clone.1 = f64[10,5]{1,0} add(multiply.29.clone.1, bitcast.240.clone.1)
-  select.44.clone.1 = f64[10,5]{1,0} select(compare.26.clone.1, divide.6.clone.1, add.32.clone.1)
-  select.43.clone.1 = f64[10,5]{1,0} select(compare.22.clone.1, add.33.clone.1, select.44.clone.1)
-  select.42.clone.1 = f64[10,5]{1,0} select(broadcast.111.clone.1, param_1.450, select.43.clone.1)
-  select.41 = f64[10,5]{1,0} select(broadcast.107, select.42.clone.1, broadcast.157.clone.1)
-  ROOT tuple.14 = (f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}) tuple(select.41, select.42.clone.1, clamp.6.clone.1, subtract.25.clone.1, bitcast.97.clone.1, multiply.37.clone.1, bitcast.98.clone.1, divide.21)
-}
-
-ENTRY e {
-  p3 = s64[1]{0} parameter(3)
-  f2 = f64[3]{0} fusion(p3), kind=kLoop, calls=f2
-
-  p0 = pred[5]{0} parameter(0)
-  p1 = f64[10,5]{1,0} parameter(1)
-  p2 = f64[10,5,4,3]{1,0,2,3} parameter(2)
-  p4 = f64[10,5,2]{1,0,2} parameter(4)
-  p5 = f64[10,5,4]{1,0,2} parameter(5)
-  p6 = pred[5]{0} parameter(6)
-  ROOT ret = (f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}) fusion(p0, p1, p2, f2, p4, p5, p6), kind=kLoop, calls=f1
-}
-  )")
-                    .value();
-  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, CommonElementwiseUsedParameter) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule m
-
-    p {
-      p0 = f32[10000000] parameter(0)
-      p1 = f32[10000000] parameter(1)
-      p2 = f32[10000000] parameter(2)
-      p3 = f32[10000000] parameter(3)
-      a0 = f32[10000000] add(p1, p2)
-      a1 = f32[10000000] add(a0, p3)
-      ROOT _ = add(p0, a1)
-    }
-
-    c1 {
-      p0 = f32[10000000] parameter(0)
-      p1 = f32[10000000] parameter(1)
-      ROOT _ = add(p0, p1)
-    }
-
-    c2 {
-      p0 = f32[10000000] parameter(0)
-      p1 = f32[10000000] parameter(1)
-      ROOT _ = multiply(p0, p1)
-    }
-
-    ENTRY entry {
-      p0 = f32[10000000] parameter(0)
-      p1 = f32[10000000] parameter(1)
-      p2 = f32[10000000] parameter(2)
-      p3 = f32[10000000] parameter(3)
-      f = f32[10000000] fusion(p0, p1, p2, p3), kind=kLoop, calls=p
-      f1 = f32[10000000] fusion(p0, f), kind=kLoop, calls=c1
-      f2 = f32[10000000] fusion(p1, f), kind=kLoop, calls=c2
-      ROOT _ = (f32[10000000], f32[10000000]) tuple(f1, f2)
-    }
-    )")
-                    .value();
-  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, IncompatibleNonTrivialHeroes) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    fused_computation {
-      param_0.1 = f32[18,16,32]{2,1,0} parameter(0)
-      param_1.1 = f32[32,16,18]{2,1,0} parameter(1)
-      s.1 = f32[18,16,32]{2,1,0} sqrt(param_0.1)
-      t.1 = f32[32,16,18]{2,1,0} transpose(s.1), dimensions={2,1,0}
-      sub.1 = f32[32,16,18]{2,1,0} subtract(t.1, param_1.1)
-      exp.1 = f32[32,16,18]{2,1,0} exponential(sub.1)
-      ROOT add.1 = f32[32,16,18]{2,1,0} add(exp.1, exp.1)
-    }
-
-    fused_computation.2 {
-      param_0.2 = f32[32,16,18]{2,1,0} parameter(0)
-      s.2 = f32[32,16,18]{2,1,0} sqrt(param_0.2)
-      ROOT t.2 = f32[32,18,16]{2,1,0} transpose(s.2), dimensions={0,2,1}
-    }
-
-    ENTRY main {
-      p = f32[18,16,32]{2,1,0} parameter(0)
-      p2 = f32[32,16,18]{2,1,0} parameter(1)
-      fusion = f32[32,16,18]{2,1,0} fusion(p, p2), kind=kLoop, calls=fused_computation
-      ROOT fusion2 = f32[32,18,16]{2,1,0} fusion(fusion), kind=kInput, calls=fused_computation.2
-    }
-    )")
-                    .value();
-  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, DoNotMergeDUSFusions) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    %fused_computation (param_0: f32[8], param_1.2: f32[], param_2.3: f32[8]) -> f32[8] {
-      %param_0 = f32[8]{0} parameter(0)
-      %param_2.3 = f32[8]{0} parameter(2)
-      %slice.2 = f32[5]{0} slice(f32[8]{0} %param_2.3), slice={[0:5]}
-      %param_1.2 = f32[] parameter(1)
-      %broadcast.2 = f32[5]{0} broadcast(f32[] %param_1.2), dimensions={}
-      %add.2 = f32[5]{0} add(f32[5]{0} %slice.2, f32[5]{0} %broadcast.2)
-      %two.1 = s32[] constant(2)
-      ROOT %dynamic-update-slice.2 = f32[8]{0} dynamic-update-slice(f32[8]{0} %param_0, f32[5]{0} %add.2, s32[] %two.1)
-    }
-
-    %fused_computation.1 (param_0.1: f32[8], param_1.4: f32[6], param_2.6: f32[]) -> f32[8] {
-      %param_0.1 = f32[8]{0} parameter(0)
-      %param_1.4 = f32[6]{0} parameter(1)
-      %param_2.6 = f32[] parameter(2)
-      %broadcast.3 = f32[6]{0} broadcast(f32[] %param_2.6), dimensions={}
-      %add.3 = f32[6]{0} add(f32[6]{0} %param_1.4, f32[6]{0} %broadcast.3)
-      %three.1 = s32[] constant(3)
-      ROOT %dynamic-update-slice.3 = f32[8]{0} dynamic-update-slice(f32[8]{0} %param_0.1, f32[6]{0} %add.3, s32[] %three.1)
-    }
-
-    ENTRY %Test (parameter: f32[8]) -> f32[8] {
-      %parameter = f32[8]{0} parameter(0)
-      %slice.1 = f32[6]{0} slice(f32[8]{0} %parameter), slice={[0:6]}
-      %one = f32[] constant(1)
-      %fusion.1 = f32[8]{0} fusion(f32[8]{0} %parameter, f32[6]{0} %slice.1, f32[] %one), kind=kLoop, calls=%fused_computation.1
-      ROOT %fusion = f32[8]{0} fusion(f32[8]{0} %fusion.1, f32[] %one, f32[8]{0} %parameter), kind=kLoop, calls=%fused_computation
-    }
-    )")
-                    .value();
-  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, MergeDUSFusionWithElementwiseFusion) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    %fused_computation {
-      %param_0 = f32[1,8]{1,0} parameter(0)
-      %bitcast = f32[8]{0} bitcast(%param_0)
-      ROOT %neg = f32[8]{0} negate(%bitcast)
-    }
-
-    %fused_computation.1 {
-      %param_0.1 = f32[8]{0} parameter(0)
-      %param_1.4 = f32[5]{0} parameter(1)
-      %three.1 = s32[] constant(3)
-      %exp = f32[5]{0} exponential(%param_1.4)
-      ROOT %dynamic-update-slice.3 = f32[8]{0} dynamic-update-slice(f32[8]{0} %param_0.1, f32[5]{0} %exp, s32[] %three.1)
-    }
-
-    ENTRY %Test {
-      %parameter = f32[5]{0} parameter(0)
-      %parameter.1 = f32[1,8]{1,0} parameter(1)
-      %fusion = f32[8]{0} fusion(f32[1,8]{1,0} %parameter.1), kind=kLoop, calls=%fused_computation
-      ROOT %fusion.1 = f32[8]{0} fusion(f32[8]{0} %fusion, f32[5]{0} %parameter), kind=kLoop, calls=%fused_computation.1
-    }
-    )")
-                    .value();
-  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
-}
-
-TEST_F(FusionMergerTest, DoNotMergeTwoReduces) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT add.13235 = f32[] add(p0, p1)
-    }
-
-    ENTRY main {
-      p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
-      c0 = f32[] constant(0)
-      r0 = f32[8,4,128]{2,1,0} reduce(p0, c0), dimensions={3}, to_apply=add
-      ROOT r1 = f32[8,4]{1,0} reduce(r0, c0), dimensions={2}, to_apply=add
-    }
-    )")
-                    .value();
-  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
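Every FusionMerger test deleted above follows the same harness: parse an HLO string into a verified module, run the fixture's fusion_merger_ pass, and assert on the returned changed-bit. A minimal sketch of that pattern, using a hypothetical trivially-unmergeable module (the fixture and the pass member come from the deleted file itself):

TEST_F(FusionMergerTest, HarnessPatternSketch) {
  // Hypothetical HLO: a single unfused instruction, so there is nothing for
  // the merger to do and Run() should report "no change".
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule m

    ENTRY e {
      p0 = f32[8] parameter(0)
      ROOT n = f32[8] negate(p0)
    }
  )").value();
  // Run() returns absl::StatusOr<bool>; .value() is the "module changed" bit.
  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
}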
diff --git a/third_party/xla/xla/service/gpu/fusion_pipeline.cc b/third_party/xla/xla/service/gpu/fusion_pipeline.cc
index 2a184c0..4fc4af0 100644
--- a/third_party/xla/xla/service/gpu/fusion_pipeline.cc
+++ b/third_party/xla/xla/service/gpu/fusion_pipeline.cc
@@ -19,14 +19,14 @@
 #include <utility>
 
 #include "xla/service/cpu_gpu_shape_verifier.h"
-#include "xla/service/gpu/fusion_merger.h"
-#include "xla/service/gpu/horizontal_input_fusion.h"
-#include "xla/service/gpu/horizontal_loop_fusion.h"
-#include "xla/service/gpu/instruction_fusion.h"
 #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/gpu/multi_output_fusion.h"
-#include "xla/service/gpu/priority_fusion.h"
-#include "xla/service/gpu/variadic_op_splitter.h"
+#include "xla/service/gpu/transforms/fusion_merger.h"
+#include "xla/service/gpu/transforms/horizontal_input_fusion.h"
+#include "xla/service/gpu/transforms/horizontal_loop_fusion.h"
+#include "xla/service/gpu/transforms/instruction_fusion.h"
+#include "xla/service/gpu/transforms/multi_output_fusion.h"
+#include "xla/service/gpu/transforms/priority_fusion.h"
+#include "xla/service/gpu/transforms/variadic_op_splitter.h"
 #include "xla/service/hlo_cost_analysis.h"
 #include "xla/service/hlo_cse.h"
 #include "xla/service/hlo_dce.h"
@@ -63,8 +63,8 @@
         shape_size_bytes_function,
         /*per_second_rates=*/{},
         /*count_multiple_input_accesses=*/true};
-    fusion.AddPass<GpuPriorityFusion>(thread_pool, gpu_device_info,
-                                      std::move(cost_analysis_options));
+    fusion.AddPass<PriorityFusion>(thread_pool, gpu_device_info,
+                                   std::move(cost_analysis_options));
   } else {
     fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false,
                                          gpu_device_info);
@@ -77,8 +77,7 @@
   fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
                          /*only_fusion_computations=*/true);
   fusion.AddPass<HloDCE>();
-  fusion.AddPass<GpuMultiOutputFusion>(gpu_device_info,
-                                       shape_size_bytes_function);
+  fusion.AddPass<MultiOutputFusion>(gpu_device_info, shape_size_bytes_function);
   fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
                          /*only_fusion_computations=*/true);
   fusion.AddPass<HloDCE>();
@@ -88,8 +87,8 @@
 HloPassPipeline HorizontalFusionPipeline(
     const se::DeviceDescription& gpu_device_info) {
   HloPassFix<HloPassPipeline> horizontal_fusion("horizontal fusion");
-  horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
-  horizontal_fusion.AddPass<GpuHorizontalInputFusion>(gpu_device_info);
+  horizontal_fusion.AddPass<HorizontalLoopFusion>();
+  horizontal_fusion.AddPass<HorizontalInputFusion>(gpu_device_info);
   horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
                                     /*only_fusion_computations=*/true);
   horizontal_fusion.AddPass<HloDCE>();
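The fusion_pipeline.cc hunks above are a mechanical migration: the fusion-pass headers move under xla/service/gpu/transforms/ and the classes drop their Gpu prefix (GpuPriorityFusion becomes PriorityFusion, GpuMultiOutputFusion becomes MultiOutputFusion, GpuHorizontalLoopFusion and GpuHorizontalInputFusion become HorizontalLoopFusion and HorizontalInputFusion). The sketch below restates the horizontal-fusion pipeline from the last hunk as a self-contained unit so the new include paths and class names appear together; the hlo_pass_fix.h and hlo_pass_pipeline.h include paths are assumed, everything else is taken from the hunks above.

#include "xla/service/gpu/transforms/horizontal_input_fusion.h"
#include "xla/service/gpu/transforms/horizontal_loop_fusion.h"
#include "xla/service/hlo_cse.h"
#include "xla/service/hlo_dce.h"
#include "xla/service/hlo_pass_fix.h"       // assumed path
#include "xla/service/hlo_pass_pipeline.h"  // assumed path
#include "xla/stream_executor/device_description.h"

namespace xla {
namespace gpu {
namespace se = ::stream_executor;

// Mirrors HorizontalFusionPipeline() after the rename: same passes, new names.
HloPassPipeline HorizontalFusionPipelineSketch(
    const se::DeviceDescription& gpu_device_info) {
  HloPassFix<HloPassPipeline> horizontal_fusion("horizontal fusion");
  horizontal_fusion.AddPass<HorizontalLoopFusion>();
  horizontal_fusion.AddPass<HorizontalInputFusion>(gpu_device_info);
  horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
                                    /*only_fusion_computations=*/true);
  horizontal_fusion.AddPass<HloDCE>();
  return horizontal_fusion;
}

}  // namespace gpu
}  // namespace xla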
diff --git a/third_party/xla/xla/service/gpu/fusion_wrapper.cc b/third_party/xla/xla/service/gpu/fusion_wrapper.cc
deleted file mode 100644
index 2cb8471..0000000
--- a/third_party/xla/xla/service/gpu/fusion_wrapper.cc
+++ /dev/null
@@ -1,152 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusion_wrapper.h"
-
-#include <functional>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "tsl/platform/errors.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> FusionWrapper::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  auto instructions = module->entry_computation()->MakeInstructionPostOrder();
-  bool changed = false;
-
-  std::function<absl::Status(HloInstruction*)> handle_instruction;
-  handle_instruction = [&](HloInstruction* instruction) -> absl::Status {
-    switch (instruction->opcode()) {
-      case HloOpcode::kConditional:
-      case HloOpcode::kWhile:
-        for (auto* computation : instruction->called_computations()) {
-          for (auto* inner_instruction :
-               computation->MakeInstructionPostOrder()) {
-            TF_RETURN_IF_ERROR(handle_instruction(inner_instruction));
-          }
-        }
-        break;
-      case HloOpcode::kAbs:
-      case HloOpcode::kAdd:
-      case HloOpcode::kAnd:
-      case HloOpcode::kAtan2:
-      case HloOpcode::kBitcastConvert:
-      case HloOpcode::kBroadcast:
-      case HloOpcode::kCeil:
-      case HloOpcode::kCbrt:
-      case HloOpcode::kClamp:
-      case HloOpcode::kClz:
-      case HloOpcode::kCompare:
-      case HloOpcode::kComplex:
-      case HloOpcode::kConcatenate:
-      case HloOpcode::kConvert:
-      case HloOpcode::kCopy:
-      case HloOpcode::kCos:
-      case HloOpcode::kDivide:
-      case HloOpcode::kDot:
-      case HloOpcode::kDynamicSlice:
-      case HloOpcode::kDynamicUpdateSlice:
-      case HloOpcode::kErf:
-      case HloOpcode::kExp:
-      case HloOpcode::kExpm1:
-      case HloOpcode::kFloor:
-      case HloOpcode::kGather:
-      case HloOpcode::kImag:
-      case HloOpcode::kIota:
-      case HloOpcode::kIsFinite:
-      case HloOpcode::kLog:
-      case HloOpcode::kLog1p:
-      case HloOpcode::kMap:
-      case HloOpcode::kMaximum:
-      case HloOpcode::kMinimum:
-      case HloOpcode::kMultiply:
-      case HloOpcode::kNegate:
-      case HloOpcode::kNot:
-      case HloOpcode::kOr:
-      case HloOpcode::kPad:
-      case HloOpcode::kPopulationCount:
-      case HloOpcode::kPower:
-      case HloOpcode::kReal:
-      case HloOpcode::kReshape:
-      case HloOpcode::kReduce:
-      case HloOpcode::kReducePrecision:
-      case HloOpcode::kReduceWindow:
-      case HloOpcode::kRemainder:
-      case HloOpcode::kReverse:
-      case HloOpcode::kRoundNearestAfz:
-      case HloOpcode::kRoundNearestEven:
-      case HloOpcode::kRsqrt:
-      case HloOpcode::kScatter:
-      case HloOpcode::kSelect:
-      case HloOpcode::kShiftLeft:
-      case HloOpcode::kShiftRightLogical:
-      case HloOpcode::kShiftRightArithmetic:
-      case HloOpcode::kSign:
-      case HloOpcode::kSin:
-      case HloOpcode::kSlice:
-      case HloOpcode::kSqrt:
-      case HloOpcode::kSubtract:
-      case HloOpcode::kStochasticConvert:
-      case HloOpcode::kTan:
-      case HloOpcode::kTanh:
-      case HloOpcode::kTranspose:
-      case HloOpcode::kXor: {
-        auto* computation = instruction->parent();
-        auto* fusion_instruction =
-            computation->AddInstruction(HloInstruction::CreateFusion(
-                instruction->shape(),
-                ChooseFusionKind(*instruction, *instruction), instruction));
-        const absl::string_view wrapped_opcode =
-            HloOpcodeString(instruction->opcode());
-        module->SetAndUniquifyInstrName(
-            fusion_instruction, absl::StrCat("wrapped_", wrapped_opcode));
-        module->SetAndUniquifyComputationName(
-            fusion_instruction->fused_instructions_computation(),
-            absl::StrCat("wrapped_", wrapped_opcode, "_computation"));
-        if (module->has_schedule()) {
-          module->schedule().replace_instruction(computation, instruction,
-                                                 fusion_instruction);
-        }
-        TF_RETURN_IF_ERROR(
-            fusion_instruction->CopyAllControlDepsFrom(instruction));
-        TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
-        TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
-        TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
-        changed = true;
-        break;
-      }
-      default:
-        break;
-    }
-    return absl::OkStatus();
-  };
-
-  for (auto* instruction : instructions) {
-    TF_RETURN_IF_ERROR(handle_instruction(instruction));
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
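Most of the deleted Run() above is the opcode allowlist; the actual rewrite is the short tail of that case. The sketch below condenses that tail into a standalone helper so the wrapping step is visible in one place. The helper name is illustrative, the schedule update is omitted for brevity, and the required includes are the same ones listed at the top of the deleted file; every call matches the deleted code above.

// Wrap `instruction` into a one-instruction fusion, as the deleted Run() did.
absl::Status WrapInstructionSketch(HloModule* module,
                                   HloInstruction* instruction) {
  HloComputation* computation = instruction->parent();
  HloInstruction* fusion_instruction =
      computation->AddInstruction(HloInstruction::CreateFusion(
          instruction->shape(), ChooseFusionKind(*instruction, *instruction),
          instruction));
  const absl::string_view wrapped_opcode =
      HloOpcodeString(instruction->opcode());
  module->SetAndUniquifyInstrName(fusion_instruction,
                                  absl::StrCat("wrapped_", wrapped_opcode));
  module->SetAndUniquifyComputationName(
      fusion_instruction->fused_instructions_computation(),
      absl::StrCat("wrapped_", wrapped_opcode, "_computation"));
  // Control dependencies and uses move to the wrapper, then the original goes.
  TF_RETURN_IF_ERROR(fusion_instruction->CopyAllControlDepsFrom(instruction));
  TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
  TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
  TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
  return absl::OkStatus();
}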
diff --git a/third_party/xla/xla/service/gpu/fusion_wrapper.h b/third_party/xla/xla/service/gpu/fusion_wrapper.h
deleted file mode 100644
index fc46692..0000000
--- a/third_party/xla/xla/service/gpu/fusion_wrapper.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSION_WRAPPER_H_
-#define XLA_SERVICE_GPU_FUSION_WRAPPER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Wraps each leftover unfused instruction in the entry computation that has
-// no LHLO equivalent into a fusion containing just that instruction.
-class FusionWrapper : public HloModulePass {
- public:
-  absl::string_view name() const override { return "fusion-wrapper"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_FUSION_WRAPPER_H_
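Per the (deleted) header comment, FusionWrapper is an ordinary HloModulePass, so callers drive it through the usual Run entry points; the `using HloPassInterface::Run;` above pulls in the single-argument overload. A minimal hedged usage sketch follows: the helper name is illustrative, the includes are the same as the header's, and after this change the pass presumably lives under xla/service/gpu/transforms/ rather than the deleted path.

// Run the wrapper over a module and report whether anything was wrapped.
absl::StatusOr<bool> WrapEntryComputationSketch(HloModule* module) {
  FusionWrapper wrapper;
  // Uses the HloPassInterface::Run(HloModule*) convenience overload.
  return wrapper.Run(module);
}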
diff --git a/third_party/xla/xla/service/gpu/fusion_wrapper_test.cc b/third_party/xla/xla/service/gpu/fusion_wrapper_test.cc
deleted file mode 100644
index 397fe75..0000000
--- a/third_party/xla/xla/service/gpu/fusion_wrapper_test.cc
+++ /dev/null
@@ -1,188 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusion_wrapper.h"
-
-#include <optional>
-
-#include <gtest/gtest.h>
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class FusionWrapperTest : public HloTestBase {};
-
-TEST_F(FusionWrapperTest, SimpleOp) {
-  RunAndFilecheckHloRewrite(R"(
-      HloModule TestModule
-
-      ENTRY TestComputation {
-        p0 = f16[30,41] parameter(0)
-        p1 = f16[30,41] parameter(1)
-        ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0}
-      })",
-                            FusionWrapper(), R"(
-// CHECK: %wrapped_concatenate_computation (param_0: f16[30,41], param_1: f16[30,41]) -> f16[60,41] {
-// CHECK:   %param_0 = f16[30,41]{1,0} parameter(0)
-// CHECK:   %param_1 = f16[30,41]{1,0} parameter(1)
-// CHECK:   ROOT %result.1 = f16[60,41]{1,0} concatenate(%param_0, %param_1), dimensions={0}
-// CHECK: }
-
-// CHECK: ENTRY %TestComputation (p0: f16[30,41], p1: f16[30,41]) -> f16[60,41] {
-// CHECK:   %p0 = f16[30,41]{1,0} parameter(0)
-// CHECK:   %p1 = f16[30,41]{1,0} parameter(1)
-// CHECK:   ROOT %wrapped_concatenate = f16[60,41]{1,0} fusion(%p0, %p1), kind=kLoop, calls=%wrapped_concatenate_computation
-// CHECK: })");
-}
-
-TEST_F(FusionWrapperTest, Scatter) {
-  RunAndFilecheckHloRewrite(R"(
-      HloModule ScatterIntoScalar
-
-      update_s32 {
-        lhs = s32[] parameter(0)
-        ROOT rhs = s32[] parameter(1)
-      }
-
-      ENTRY main {
-        parameter.1 = s32[] parameter(0)
-        parameter.2 = s32[0]{0} parameter(1)
-        parameter.3 = s32[] parameter(2)
-        ROOT scatter_ScatterIntoScalar = s32[] scatter(parameter.1, parameter.2, parameter.3),
-            update_window_dims={},
-            inserted_window_dims={},
-            scatter_dims_to_operand_dims={},
-            index_vector_dim=0,
-            to_apply=update_s32
-      })",
-                            FusionWrapper(), R"(
-// CHECK: wrapped_scatter_computation
-// CHECK:   %[[param_0:.*]] = s32[] parameter(0)
-// CHECK:   %[[param_1:.*]] = s32[0]{0} parameter(1)
-// CHECK:   %[[param_2:.*]] = s32[] parameter(2)
-// CHECK:   ROOT %{{.*}} = s32[] scatter(%[[param_0]], %[[param_1]], %[[param_2]])
-
-// CHECK: ENTRY
-// CHECK:   %[[p0:.*]] = s32[] parameter(0)
-// CHECK:   %[[p1:.*]] = s32[0]{0} parameter(1)
-// CHECK:   %[[p2:.*]] = s32[] parameter(2)
-// CHECK:   ROOT %{{.*}} = s32[] fusion(%[[p0]], %[[p1]], %[[p2]]), kind=kInput, calls=%wrapped_scatter_computation
-// CHECK: })");
-}
-
-TEST_F(FusionWrapperTest, ControlDependency) {
-  RunAndFilecheckHloRewrite(R"(
-      HloModule TestModule
-
-      fusion {
-        ROOT param = f32[] parameter(0)
-      }
-
-      ENTRY main {
-        param = f32[] parameter(0)
-        fusion = f32[] fusion(param), kind=kLoop, calls=fusion
-        constant_one = f32[] constant(1)
-        ROOT add = f32[] add(param, constant_one), control-predecessors={fusion}
-      })",
-                            FusionWrapper(), R"(
-// CHECK: ROOT %wrapped_add = f32[] fusion(%param.1, %constant_one),
-// CHECK-SAME: control-predecessors={%fusion})");
-}
-
-TEST_F(FusionWrapperTest, While) {
-  RunAndFilecheckHloRewrite(R"(
-      HloModule While
-
-      %body {
-        %parameter.5 = (f32[5]{0}) parameter(0)
-        %constant_8 = f32[] constant(0)
-        %broadcast.9 = f32[5]{0} broadcast(f32[] %constant_8), dimensions={}
-        ROOT %tuple.2 = (f32[5]{0}) tuple(f32[5]{0} %broadcast.9)
-      }
-
-      %cond {
-        %parameter.12 = (f32[5]{0}) parameter(0)
-        ROOT %constant_1 = pred[] constant(false)
-      }
-
-      ENTRY %main (parameter.1: f32[5]) -> (f32[5]) {
-        %parameter.1 = f32[5]{0} parameter(0)
-        %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1)
-        %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3)
-        ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body
-      })",
-                            FusionWrapper(), R"(
-// CHECK: %wrapped_broadcast_computation {{.*}} {
-// CHECK:  %param_0.1 = f32[] parameter(0)
-// CHECK:  ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0.1), dimensions={}
-// CHECK: }
-// CHECK: %body {{.*}} {
-// CHECK:   %parameter.5 = (f32[5]{0}) parameter(0)
-// CHECK:   %constant_8 = f32[] constant(0)
-// CHECK:   %wrapped_broadcast = f32[5]{0} fusion(%constant_8), kind=kLoop, calls=%wrapped_broadcast_computation
-// CHECK:   ROOT %tuple.2 = (f32[5]{0}) tuple(%wrapped_broadcast)
-// CHECK: }
-// CHECK: %cond {{.*}} {
-// CHECK:   %parameter.12 = (f32[5]{0}) parameter(0)
-// CHECK:   ROOT %constant_1 = pred[] constant(false)
-// CHECK: }
-// CHECK: %wrapped_copy_computation {{.*}} {
-// CHECK:   %param_0 = f32[5]{0} parameter(0)
-// CHECK:   ROOT %copy.0 = f32[5]{0} copy(%param_0)
-// CHECK: }
-// CHECK: ENTRY %main {{.*}} {
-// CHECK:   %parameter.1 = f32[5]{0} parameter(0)
-// CHECK:   %wrapped_copy = f32[5]{0} fusion(%parameter.1), kind=kLoop, calls=%wrapped_copy_computation
-// CHECK:   %tuple = (f32[5]{0}) tuple(%wrapped_copy)
-// CHECK:   ROOT %while.19 = (f32[5]{0}) while(%tuple), condition=%cond, body=%body
-// CHECK: })");
-}
-
-TEST_F(FusionWrapperTest, WhileInFusion) {
-  RunAndFilecheckHloRewrite(R"(
-      HloModule While
-
-      %body {
-        %parameter.5 = (f32[5]{0}) parameter(0)
-        %constant_8 = f32[] constant(0)
-        %broadcast.9 = f32[5]{0} broadcast(f32[] %constant_8), dimensions={}
-        ROOT %tuple.2 = (f32[5]{0}) tuple(f32[5]{0} %broadcast.9)
-      }
-
-      %cond {
-        %parameter.12 = (f32[5]{0}) parameter(0)
-        ROOT %constant_1 = pred[] constant(false)
-      }
-
-      %fusion {
-        %parameter.1 = f32[5]{0} parameter(0)
-        %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1)
-        %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3)
-        ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body
-      }
-
-      ENTRY %main (parameter.1: f32[5]) -> (f32[5]) {
-        %parameter.1 = f32[5]{0} parameter(0)
-        ROOT %fusion = (f32[5]{0}) fusion(f32[5]{0} %parameter.1), kind=kLoop, calls=%fusion
-      })",
-                            FusionWrapper(),
-                            // No change
-                            std::nullopt);
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD
index e936cdf..188eaed 100644
--- a/third_party/xla/xla/service/gpu/fusions/BUILD
+++ b/third_party/xla/xla/service/gpu/fusions/BUILD
@@ -9,49 +9,6 @@
 )
 
 cc_library(
-    name = "in_place_dynamic_update_slice",
-    srcs = ["in_place_dynamic_update_slice.cc"],
-    hdrs = ["in_place_dynamic_update_slice.h"],
-    deps = [
-        ":fusion_emitter",
-        "//xla/hlo/ir:hlo",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu:ir_emission_utils",
-        "//xla/service/gpu:ir_emitter",
-        "//xla/service/gpu:ir_emitter_context",
-        "//xla/service/gpu:launch_dimensions",
-        "//xla/service/gpu/model:indexing_analysis",
-        "//xla/service/llvm_ir:dynamic_update_slice_util",
-        "//xla/service/llvm_ir:fused_ir_emitter",
-        "//xla/service/llvm_ir:ir_array",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//llvm:ir_headers",
-        "@llvm-project//mlir:IR",
-    ],
-)
-
-xla_cc_test(
-    name = "in_place_dynamic_update_slice_test",
-    srcs = ["in_place_dynamic_update_slice_test.cc"],
-    deps = [
-        ":fusions",
-        ":in_place_dynamic_update_slice",
-        "//xla/service/gpu:gpu_device_info_for_tests",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu/model:affine_map_printer",
-        "//xla/service/gpu/model:indexing_test_utils",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_googletest//:gtest",
-        "@llvm-project//mlir:IR",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
     name = "in_place_dynamic_update_slice_mlir",
     srcs = ["in_place_dynamic_update_slice_mlir.cc"],
     hdrs = ["in_place_dynamic_update_slice_mlir.h"],
@@ -122,6 +79,7 @@
     deps = [
         ":fusion_emitter",
         "//xla:shape_util",
+        "//xla:status_macros",
         "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/ffi:attribute_map",
@@ -147,6 +105,10 @@
         "//xla/service/gpu/runtime:dynamic_slice_thunk",
         "//xla/service/gpu/runtime:gemm_thunk",
         "//xla/service/gpu/runtime:kernel_thunk",
+        "//xla/service/gpu/runtime:nccl_all_reduce_thunk",
+        "//xla/service/gpu/runtime:nccl_api",
+        "//xla/service/gpu/runtime:nccl_clique_key",
+        "//xla/service/gpu/runtime:nccl_collective_thunk",
         "//xla/service/gpu/runtime:thunk",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log",
@@ -167,6 +129,12 @@
 xla_test(
     name = "dynamic_slice_fusion_test",
     srcs = if_cuda_is_configured(["dynamic_slice_fusion_test.cc"]),
+    backend_tags = {
+        "gpu": [
+            "multi_gpu",
+            "no_oss",
+        ],
+    },
     backends = ["gpu"],
     local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
     deps = [
@@ -184,11 +152,14 @@
         "//xla/service:executable",
         "//xla/service:hlo_module_config",
         "//xla/service:hlo_proto_cc",
-        "//xla/service/gpu:dynamic_slice_fusion_rewriter",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
         "//xla/stream_executor",
         "//xla/stream_executor:device_description",
         "//xla/stream_executor/gpu:gpu_types_header",
+        "//xla/tests:filecheck",
         "//xla/tests:hlo_test_base",
+        "//xla/tests:test_utils",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/status",
         "@local_tsl//tsl/platform:errors",
@@ -244,23 +215,16 @@
     hdrs = ["fusions.h"],
     visibility = ["//xla/service/gpu:__subpackages__"],
     deps = [
-        ":concatenate",
         ":concatenate_mlir",
         ":copy",
         ":cudnn",
         ":custom",
         ":fusion_emitter",
-        ":in_place_dynamic_update_slice",
         ":in_place_dynamic_update_slice_mlir",
-        ":input_slices",
         ":input_slices_mlir",
-        ":loop",
         ":loop_mlir",
-        ":reduction",
         ":reduction_mlir",
-        ":scatter",
         ":scatter_mlir",
-        ":transpose",
         ":transpose_mlir",
         ":triton",
         "//xla:shape_util",
@@ -270,6 +234,13 @@
         "//xla/service/gpu:hlo_fusion_analysis",
         "//xla/service/gpu:hlo_traversal",
         "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu/fusions/legacy:concatenate",
+        "//xla/service/gpu/fusions/legacy:in_place_dynamic_update_slice",
+        "//xla/service/gpu/fusions/legacy:input_slices",
+        "//xla/service/gpu/fusions/legacy:loop",
+        "//xla/service/gpu/fusions/legacy:reduction",
+        "//xla/service/gpu/fusions/legacy:scatter",
+        "//xla/service/gpu/fusions/legacy:transpose",
         "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log",
@@ -291,8 +262,8 @@
         "//xla/service:gpu_plugin",
         "//xla/service/gpu:gpu_device_info_for_tests",
         "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
         "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
-        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
         "//xla/service/gpu/model:affine_map_printer",
         "//xla/stream_executor:device_description",
         "//xla/tests:filecheck",
@@ -320,53 +291,22 @@
 )
 
 cc_library(
-    name = "loop",
-    srcs = ["loop.cc"],
-    hdrs = ["loop.h"],
-    deps = [
-        ":fusion_emitter",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service/gpu:gpu_fusible",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu:hlo_traversal",
-        "//xla/service/gpu:ir_emission_utils",
-        "//xla/service/gpu:ir_emitter",
-        "//xla/service/gpu:ir_emitter_context",
-        "//xla/service/gpu:launch_dimensions",
-        "//xla/service/gpu:parallel_loop_emitter",
-        "//xla/service/gpu/model:indexing_analysis",
-        "//xla/service/llvm_ir:fused_ir_emitter",
-        "//xla/service/llvm_ir:ir_array",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/numeric:bits",
-        "@com_google_absl//absl/status",
-        "@llvm-project//llvm:ir_headers",
-        "@llvm-project//mlir:IR",
-        "@local_tsl//tsl/platform:macros",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
     name = "loop_mlir",
     srcs = ["loop_mlir.cc"],
     hdrs = ["loop_mlir.h"],
     deps = [
-        ":loop",
         "//xla:shape_util",
         "//xla:status_macros",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:gpu_fusible",
         "//xla/service/gpu:hlo_fusion_analysis",
         "//xla/service/gpu:hlo_traversal",
         "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
         "//xla/service/gpu/fusions/mlir:computation_partitioner",
         "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
         "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
-        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
         "//xla/service/gpu/model:indexing_analysis",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log",
@@ -401,17 +341,17 @@
     srcs = ["scatter_mlir.cc"],
     hdrs = ["scatter_mlir.h"],
     deps = [
-        ":loop",
         "//xla:shape_util",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
         "//xla/service:scatter_simplifier",
+        "//xla/service/gpu:gpu_fusible",
         "//xla/service/gpu:hlo_fusion_analysis",
         "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
         "//xla/service/gpu/fusions/mlir:computation_partitioner",
         "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
         "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
-        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
         "//xla/service/gpu/model:indexing_analysis",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",
@@ -456,15 +396,14 @@
         "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
-        "//xla/mlir/utils:type_util",
         "//xla/service/gpu:hlo_fusion_analysis",
         "//xla/service/gpu:ir_emission_utils",
         "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
         "//xla/service/gpu/fusions/mlir:computation_partitioner",
         "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
         "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
         "//xla/service/gpu/fusions/mlir:type_util",
-        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
         "//xla/service/gpu/model:indexing_analysis",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -497,102 +436,6 @@
     ],
 )
 
-xla_cc_test(
-    name = "loop_test",
-    srcs = ["loop_test.cc"],
-    deps = [
-        ":fusion_emitter",
-        ":fusions",
-        "//xla:status_macros",
-        "//xla/service/gpu:gpu_device_info_for_tests",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu/model:affine_map_printer",
-        "//xla/service/gpu/model:indexing_test_utils",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_googletest//:gtest",
-        "@llvm-project//mlir:IR",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "scatter",
-    srcs = ["scatter.cc"],
-    hdrs = ["scatter.h"],
-    deps = [
-        ":fusion_emitter",
-        ":loop",
-        "//xla:shape_util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu:ir_emission_utils",
-        "//xla/service/gpu:ir_emitter",
-        "//xla/service/gpu:ir_emitter_context",
-        "//xla/service/gpu:launch_dimensions",
-        "//xla/service/gpu:parallel_loop_emitter",
-        "//xla/service/gpu/model:indexing_analysis",
-        "//xla/service/llvm_ir:fused_ir_emitter",
-        "//xla/service/llvm_ir:ir_array",
-        "//xla/service/llvm_ir:llvm_util",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/types:span",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//llvm:ir_headers",
-        "@llvm-project//mlir:IR",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "scatter_test",
-    srcs = ["scatter_test.cc"],
-    deps = [
-        ":fusions",
-        ":scatter",
-        "//xla/service/gpu:gpu_device_info_for_tests",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu/model:affine_map_printer",
-        "//xla/service/gpu/model:indexing_test_utils",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_googletest//:gtest",
-        "@llvm-project//mlir:IR",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "tiling_util",
-    srcs = ["tiling_util.cc"],
-    hdrs = ["tiling_util.h"],
-    visibility = ["//xla/service/gpu:__subpackages__"],
-    deps = [
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/service/gpu:ir_emission_utils",
-        "//xla/service/gpu:target_util",
-        "//xla/service/llvm_ir:ir_array",
-        "//xla/service/llvm_ir:kernel_support_library",
-        "//xla/service/llvm_ir:llvm_loop",
-        "//xla/service/llvm_ir:llvm_util",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/types:span",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//llvm:ir_headers",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
 cc_library(
     name = "triton",
     srcs = ["triton.cc"],
@@ -694,14 +537,13 @@
         "//xla/hlo/ir:hlo",
         "//xla/service:dump",
         "//xla/service:executable",
-        "//xla/service:hlo_module_config",
         "//xla/service:pattern_matcher",
         "//xla/service:pattern_matcher_gmock",
-        "//xla/service/gpu:cudnn_fusion_compiler",
         "//xla/service/gpu:ir_emission_utils",
         "//xla/service/gpu:stream_executor_util",
         "//xla/service/gpu/runtime:thunk",
         "//xla/service/gpu/tests:gpu_codegen_test",
+        "//xla/service/gpu/transforms:cudnn_fusion_compiler",
         "//xla/stream_executor:dnn",
         "//xla/stream_executor:stream_executor_h",
         "//xla/stream_executor:stream_executor_memory_allocator",
@@ -739,87 +581,12 @@
 )
 
 cc_library(
-    name = "reduction",
-    srcs = ["reduction.cc"],
-    hdrs = ["reduction.h"],
-    deps = [
-        ":fusion_emitter",
-        ":reduction_base",
-        ":thunk_util",
-        ":tiling_util",
-        "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:buffer_assignment",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu:hlo_traversal",
-        "//xla/service/gpu:ir_emission_utils",
-        "//xla/service/gpu:ir_emitter",
-        "//xla/service/gpu:ir_emitter_context",
-        "//xla/service/gpu:kernel_arguments",
-        "//xla/service/gpu:kernel_reuse_cache",
-        "//xla/service/gpu:launch_dimensions",
-        "//xla/service/gpu:parallel_loop_emitter",
-        "//xla/service/gpu:reduction_utils",
-        "//xla/service/gpu:target_util",
-        "//xla/service/gpu/runtime:kernel_thunk",
-        "//xla/service/gpu/runtime:thunk",
-        "//xla/service/llvm_ir:fused_ir_emitter",
-        "//xla/service/llvm_ir:ir_array",
-        "//xla/service/llvm_ir:kernel_support_library",
-        "//xla/service/llvm_ir:llvm_loop",
-        "//xla/service/llvm_ir:llvm_util",
-        "//xla/service/llvm_ir:loop_emitter",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/container:node_hash_map",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//llvm:ir_headers",
-        "@llvm-project//mlir:Support",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "reduction_test",
-    srcs = ["reduction_test.cc"],
-    deps = [
-        ":fusion_emitter",
-        ":reduction",
-        "//xla/hlo/ir:hlo",
-        "//xla/service/gpu:gpu_device_info_for_tests",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu:ir_emitter_context",
-        "//xla/service/gpu/model:indexing_analysis",
-        "//xla/service/gpu/model:indexing_test_utils",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_googletest//:gtest",
-        "@llvm-project//mlir:IR",
-    ],
-)
-
-cc_library(
     name = "reduction_base",
     srcs = ["reduction_base.cc"],
     hdrs = ["reduction_base.h"],
+    visibility = ["//xla/service/gpu/fusions:__subpackages__"],
     deps = [
         ":fusion_emitter",
-        ":tiling_util",
         "//xla:shape_util",
         "//xla:union_find",
         "//xla:util",
@@ -863,11 +630,11 @@
         "//xla/service/gpu:ir_emission_utils",
         "//xla/service/gpu:launch_dimensions",
         "//xla/service/gpu:reduction_utils",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
         "//xla/service/gpu/fusions/mlir:computation_partitioner",
         "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
         "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
         "//xla/service/gpu/fusions/mlir:type_util",
-        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
         "//xla/service/gpu/model:indexing_analysis",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -895,6 +662,7 @@
         ":mlir_emitter_test_base",
         ":reduction_mlir",
         "//xla:error_spec",
+        "//xla/service/gpu:hlo_fusion_analysis",
         "//xla/service/gpu/model:indexing_analysis",
         "//xla/service/gpu/model:indexing_test_utils",
         "//xla/tests:xla_internal_test_main",
@@ -907,59 +675,12 @@
 )
 
 cc_library(
-    name = "concatenate",
-    srcs = ["concatenate.cc"],
-    hdrs = ["concatenate.h"],
-    deps = [
-        ":fusion_emitter",
-        "//xla:shape_util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu:ir_emission_utils",
-        "//xla/service/gpu:ir_emitter",
-        "//xla/service/gpu:ir_emitter_context",
-        "//xla/service/gpu:launch_dimensions",
-        "//xla/service/gpu:parallel_loop_emitter",
-        "//xla/service/gpu/model:indexing_analysis",
-        "//xla/service/llvm_ir:fused_ir_emitter",
-        "//xla/service/llvm_ir:ir_array",
-        "//xla/service/llvm_ir:loop_emitter",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/status",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//llvm:ir_headers",
-        "@llvm-project//mlir:IR",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "concatenate_test",
-    srcs = ["concatenate_test.cc"],
-    deps = [
-        ":concatenate",
-        ":fusions",
-        "//xla/service/gpu:gpu_device_info_for_tests",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu/model:affine_map_printer",
-        "//xla/service/gpu/model:indexing_test_utils",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_googletest//:gtest",
-        "@llvm-project//mlir:IR",
-    ],
-)
-
-cc_library(
     name = "concatenate_mlir",
     srcs = ["concatenate_mlir.cc"],
     hdrs = ["concatenate_mlir.h"],
     deps = [
-        ":concatenate",
-        ":loop",
         "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:gpu_fusible",
         "//xla/service/gpu:hlo_fusion_analysis",
         "//xla/service/gpu:launch_dimensions",
         "//xla/service/gpu/fusions/mlir:computation_partitioner",
@@ -995,94 +716,6 @@
 )
 
 cc_library(
-    name = "transpose",
-    srcs = ["transpose.cc"],
-    hdrs = ["transpose.h"],
-    deps = [
-        ":fusion_emitter",
-        ":tiling_util",
-        "//xla:permutation_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu:ir_emission_utils",
-        "//xla/service/gpu:ir_emitter",
-        "//xla/service/gpu:ir_emitter_context",
-        "//xla/service/gpu:launch_dimensions",
-        "//xla/service/gpu:target_util",
-        "//xla/service/gpu/model:indexing_analysis",
-        "//xla/service/llvm_ir:fused_ir_emitter",
-        "//xla/service/llvm_ir:ir_array",
-        "//xla/service/llvm_ir:llvm_util",
-        "//xla/service/llvm_ir:loop_emitter",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//llvm:ir_headers",
-        "@llvm-project//mlir:IR",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "transpose_test",
-    srcs = ["transpose_test.cc"],
-    deps = [
-        ":fusions",
-        ":transpose",
-        "//xla:status_macros",
-        "//xla/service/gpu:gpu_device_info_for_tests",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu/model:indexing_test_utils",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_googletest//:gtest",
-        "@llvm-project//mlir:IR",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
-    name = "input_slices",
-    srcs = ["input_slices.cc"],
-    hdrs = ["input_slices.h"],
-    deps = [
-        ":fusion_emitter",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:elemental_ir_emitter",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu:ir_emission_utils",
-        "//xla/service/gpu:ir_emitter",
-        "//xla/service/gpu:ir_emitter_context",
-        "//xla/service/gpu:launch_dimensions",
-        "//xla/service/gpu:parallel_loop_emitter",
-        "//xla/service/gpu/model:indexing_analysis",
-        "//xla/service/llvm_ir:fused_ir_emitter",
-        "//xla/service/llvm_ir:ir_array",
-        "//xla/service/llvm_ir:kernel_support_library",
-        "//xla/service/llvm_ir:llvm_loop",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@llvm-project//llvm:ir_headers",
-        "@llvm-project//mlir:IR",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
     name = "input_slices_mlir",
     srcs = ["input_slices_mlir.cc"],
     hdrs = ["input_slices_mlir.h"],
@@ -1093,10 +726,10 @@
         "//xla/service/gpu:hlo_fusion_analysis",
         "//xla/service/gpu:hlo_traversal",
         "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
         "//xla/service/gpu/fusions/mlir:computation_partitioner",
         "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
         "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
-        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
         "//xla/service/gpu/model:indexing_analysis",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -1124,21 +757,3 @@
         "@com_google_googletest//:gtest",
     ],
 )
-
-xla_cc_test(
-    name = "input_slices_test",
-    srcs = ["input_slices_test.cc"],
-    deps = [
-        ":fusions",
-        ":input_slices",
-        "//xla/service/gpu:gpu_device_info_for_tests",
-        "//xla/service/gpu:hlo_fusion_analysis",
-        "//xla/service/gpu/model:affine_map_printer",
-        "//xla/service/gpu/model:indexing_test_utils",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_googletest//:gtest",
-        "@llvm-project//mlir:IR",
-    ],
-)
diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate.cc b/third_party/xla/xla/service/gpu/fusions/concatenate.cc
deleted file mode 100644
index 55a45dc..0000000
--- a/third_party/xla/xla/service/gpu/fusions/concatenate.cc
+++ /dev/null
@@ -1,137 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/concatenate.h"
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/status/status.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Value.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/gpu/parallel_loop_emitter.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/service/llvm_ir/loop_emitter.h"
-#include "xla/shape.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis) {
-  const HloInstruction& concat = analysis.fusion_hero(0).instruction();
-  int64_t dim = concat.concatenate_dimension();
-  auto less = [&](const HloInstruction* lhs, const HloInstruction* rhs) {
-    return lhs->shape().dimensions(dim) < rhs->shape().dimensions(dim);
-  };
-  HloInstruction* operand = *absl::c_max_element(concat.operands(), less);
-  return operand->shape();
-}
-
-ConcatenateFusion::ConcatenateFusion(const HloFusionAnalysis& analysis)
-    : analysis_(analysis) {}
-
-std::optional<IndexingMap> ConcatenateFusion::ComputeThreadIdToOutputIndexing(
-    int64_t root_index, mlir::MLIRContext* ctx) const {
-  return std::nullopt;
-}
-
-std::optional<IndexingMap> ConcatenateFusion::ComputeThreadIdToInputIndexing(
-    int64_t root_index, int64_t hero_operand_index,
-    mlir::MLIRContext* ctx) const {
-  return GetDefaultThreadIdIndexingMap(launch_dimensions(), /*unroll_factor=*/1,
-                                       GetLargestConcatOperandShape(analysis_),
-                                       ctx);
-}
-
-absl::Status ConcatenateFusion::EmitKernel(
-    IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
-    const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
-    std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder) const {
-  GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
-  FusedIrEmitter fused_emitter(elemental_emitter);
-  for (int i = 0; i < fusion.fused_parameters().size(); i++) {
-    fused_emitter.BindGenerator(
-        *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) {
-          return inputs[i].EmitReadArrayElement(index, builder);
-        });
-  }
-
-  llvm::Type* index_type =
-      GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
-
-  const HloInstruction& concat = analysis_.fusion_hero(0).instruction();
-  int64_t concat_dim = concat.concatenate_dimension();
-  int64_t operand_offset = 0;
-
-  // Emit the slices that correspond to the operands of the concat hero.
-  for (const HloInstruction* operand : concat.operands()) {
-    llvm_ir::BodyEmitter body_emitter =
-        [&](const llvm_ir::IrArray::Index& operand_index) -> absl::Status {
-      // Bind concat to generate the current operand.
-      TF_ASSIGN_OR_RETURN(auto operand_generator,
-                          fused_emitter.GetGenerator(*operand));
-      fused_emitter.BindGenerator(concat, [&](llvm_ir::IrArray::Index) {
-        return operand_generator(operand_index);
-      });
-
-      // Create the index of the slice corresponding to the current operand.
-      llvm_ir::IrArray::Index result_index = operand_index.AddOffsetToDim(
-          llvm::ConstantInt::get(index_type, operand_offset), concat_dim,
-          builder);
-      operand_offset += operand->shape().dimensions(concat_dim);
-
-      // Generate and write out the slice for each root.
-      for (const auto& [output, root] :
-           llvm::zip_equal(outputs, analysis_.fusion_roots())) {
-        llvm_ir::IrArray::Index root_index = result_index.SourceIndexOfBitcast(
-            concat.shape(), root.shape(), builder);
-        TF_ASSIGN_OR_RETURN(auto generator,
-                            fused_emitter.GetGenerator(root.instruction()));
-        TF_ASSIGN_OR_RETURN(llvm::Value * value, generator(root_index));
-        output.EmitWriteArrayElement(root_index, value, builder);
-      }
-      return absl::OkStatus();
-    };
-
-    ParallelLoopEmitter emitter(body_emitter, operand->shape(), launch_dims,
-                                builder);
-    TF_RETURN_IF_ERROR(emitter.EmitLoop(fusion.name(), index_type));
-  }
-
-  return absl::OkStatus();
-}
-
-LaunchDimensions ConcatenateFusion::launch_dimensions() const {
-  return CalculateLaunchDimensions(GetLargestConcatOperandShape(analysis_),
-                                   analysis_.device_info());
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate.h b/third_party/xla/xla/service/gpu/fusions/concatenate.h
deleted file mode 100644
index e838b29..0000000
--- a/third_party/xla/xla/service/gpu/fusions/concatenate.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_CONCATENATE_H_
-#define XLA_SERVICE_GPU_FUSIONS_CONCATENATE_H_
-
-#include <optional>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/shape.h"
-
-namespace xla {
-namespace gpu {
-
-const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis);
-
-// Emits a kernel for the given hlo instruction where each thread produces
-// one element of each concat operand.
-class ConcatenateFusion : public KernelFusionEmitterBase {
- public:
-  explicit ConcatenateFusion(const HloFusionAnalysis& analysis);
-  LaunchDimensions launch_dimensions() const override;
-
-  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
-      int64_t root_index, mlir::MLIRContext* ctx) const override;
-
-  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
-      int64_t root_index, int64_t hero_operand_index,
-      mlir::MLIRContext* ctx) const override;
-
- protected:
-  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
-                          const HloFusionInstruction& fusion,
-                          const LaunchDimensions& launch_dims,
-                          std::vector<llvm_ir::IrArray> inputs,
-                          std::vector<llvm_ir::IrArray> outputs,
-                          llvm::IRBuilder<>* builder) const override;
-
- private:
-  const HloFusionAnalysis& analysis_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_CONCATENATE_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc
index f2cecc5..d7ebabf 100644
--- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc
@@ -35,10 +35,9 @@
 #include "mlir/IR/ValueRange.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/concatenate.h"
-#include "xla/service/gpu/fusions/loop.h"
 #include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
 #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
+#include "xla/service/gpu/gpu_fusible.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/launch_dimensions.h"
 #include "xla/service/gpu/model/indexing_analysis.h"
@@ -52,6 +51,16 @@
 using mlir::Value;
 using mlir::ValueRange;
 
+const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis) {
+  const HloInstruction& concat = analysis.fusion_hero(0).instruction();
+  int64_t dim = concat.concatenate_dimension();
+  auto less = [&](const HloInstruction* lhs, const HloInstruction* rhs) {
+    return lhs->shape().dimensions(dim) < rhs->shape().dimensions(dim);
+  };
+  HloInstruction* operand = *absl::c_max_element(concat.operands(), less);
+  return operand->shape();
+}
+
 // Computes the unroll factor that divides concat dimension of all operands.
 int ComputeUnrollFactor(const HloFusionAnalysis& analysis,
                         int unroll_factor_for_the_largest_shape) {
diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc
index 92aff94..30ca0f6 100644
--- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc
@@ -52,7 +52,7 @@
   thread_id_printer_.SetSymbolName(1, "unroll_id");
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
   MlirConcatenateFusion fusion(analysis);
 
   constexpr auto kIndexing = R"(
@@ -102,9 +102,9 @@
     }
   )";
   TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0)>
-    // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0 + 200)>
-    // CHECK-DAG: #[[MAP_3:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0 + 600)>
+    // CHECK-DAG: #[[MAP_1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0)
+    // CHECK-DAG: #[[MAP_2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0 + 200)
+    // CHECK-DAG: #[[MAP_3:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0 + 600)
 
     // CHECK-LABEL: fused_computation
     // CHECK-SAME:    %[[ARG_0:[a-zA-Z0-9]*]]: {{[^,]*}},
@@ -152,7 +152,7 @@
     }
   )";
   TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 + 64)>
+    // CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 64)
 
     // CHECK-LABEL: fused_computation
     // CHECK-DAG: %[[C_63:.*]] = arith.constant 63
@@ -254,9 +254,9 @@
     }
   )";
   TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK-DAG: affine_map<(d0, d1) -> (d1 * 128 + d0)>
-    // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0)>
-    // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0 + 640002)>
+    // CHECK-DAG: #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0)
+    // CHECK-DAG: #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0)
+    // CHECK-DAG: #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0 + 640002)
 
     // CHECK-LABEL: fused_computation
     // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_test.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_test.cc
deleted file mode 100644
index 8192b33..0000000
--- a/third_party/xla/xla/service/gpu/fusions/concatenate_test.cc
+++ /dev/null
@@ -1,120 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/concatenate.h"
-
-#include <optional>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "mlir/IR/MLIRContext.h"
-#include "xla/service/gpu/fusions/fusions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/affine_map_printer.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class ConcatenateTest : public HloTestBase {
- public:
-  void SetUp() override {
-    HloTestBase::SetUp();
-    printer_ =
-        AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
-                         {"chunk_id", "unroll_id"});
-  }
-
- protected:
-  DebugOptions GetDebugOptionsForTest() override {
-    auto opts = HloTestBase::GetDebugOptionsForTest();
-    opts.set_xla_gpu_mlir_emitter_level(0);
-    return opts;
-  }
-  AffineMapPrinter printer_;
-  mlir::MLIRContext mlir_context_;
-};
-
-TEST_F(ConcatenateTest, ThreadIndexing) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    fused_computation {
-      param0 = f32[200] parameter(0)
-      param1 = f32[400] parameter(1)
-      param2 = f32[300] parameter(2)
-      ROOT concat = f32[900] concatenate(param0, param1, param2), dimensions={0}
-    }
-    ENTRY main {
-      param0 = f32[200] parameter(0)
-      param1 = f32[400] parameter(1)
-      param2 = f32[300] parameter(2)
-      ROOT fusion = f32[900] fusion(param0, param1, param2),
-        calls=fused_computation, kind=kLoop
-    }
-  )")
-                    .value();
-
-  stream_executor::DeviceDescription device_info =
-      TestGpuDeviceInfo::RTXA6000DeviceInfo();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis_fused = AnalyzeFusion(*root, device_info);
-
-  auto emitter =
-      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
-  auto fusion = dynamic_cast<ConcatenateFusion*>(emitter.get());
-  ASSERT_NE(fusion, nullptr);
-
-  constexpr auto kIndexing = R"(
-    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] ->
-      (bl_x * 128 + th_x)
-    domain:
-    th_x in [0, 127]
-    th_y in [0, 0]
-    th_z in [0, 0]
-    bl_x in [0, 3]
-    bl_y in [0, 0]
-    bl_z in [0, 0]
-    chunk_id in [0, 0]
-    unroll_id in [0, 0]
-    bl_x * 128 + th_x in [0, 399]
-  )";
-  EXPECT_THAT(
-      fusion
-          ->ComputeThreadIdToInputIndexing(
-              /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_)
-          ->ToString(printer_),
-      MatchIndexingString(kIndexing));
-  EXPECT_THAT(
-      fusion
-          ->ComputeThreadIdToInputIndexing(
-              /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_)
-          ->ToString(printer_),
-      MatchIndexingString(kIndexing));
-  EXPECT_THAT(
-      fusion
-          ->ComputeThreadIdToInputIndexing(
-              /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_)
-          ->ToString(printer_),
-      MatchIndexingString(kIndexing));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc
index caa24b1..172c53a 100644
--- a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc
@@ -31,11 +31,10 @@
 #include "xla/primitive_util.h"
 #include "xla/service/dump.h"
 #include "xla/service/executable.h"
-#include "xla/service/gpu/cudnn_fusion_compiler.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/stream_executor_util.h"
 #include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "xla/service/hlo_module_config.h"
+#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h"
 #include "xla/service/pattern_matcher.h"
 #include "xla/service/pattern_matcher_gmock.h"
 #include "xla/stream_executor/dnn.h"
@@ -88,17 +87,49 @@
   }
 };
 
-TEST_F(CuDnnFusionTest, DumpingWorks) {
-  HloModuleConfig config;
-  DebugOptions options = GetDebugOptionsForTest();
-  std::string output_directory;
-  if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) {
-    output_directory = tsl::testing::TmpDir();
+class CuDnnFusionFileCheckTest : public CuDnnFusionTest {
+ public:
+  CuDnnFusionFileCheckTest() {
+    if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory_)) {
+      output_directory_ = tsl::testing::TmpDir();
+    }
   }
-  options.set_xla_dump_to(output_directory);
-  config.set_debug_options(options);
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
+
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions options = CuDnnFusionTest::GetDebugOptionsForTest();
+    options.set_xla_dump_to(output_directory_);
+    return options;
+  }
+
+  absl::StatusOr<bool> RunCuDnnFileCheck(absl::string_view hlo,
+                                         absl::string_view pattern) {
+    TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
+                        ParseAndReturnVerifiedModule(hlo));
+    const std::string root_name(
+        module->entry_computation()->root_instruction()->name());
+    BinaryMap dnn_compiled_graphs;
+    CuDnnFusionCompiler cudnn_compiler(*backend().default_stream_executor(),
+                                       dnn_compiled_graphs);
+    // Run FileCheck even if CuDnnFusionCompiler fails.
+    cudnn_compiler.Run(module.get()).IgnoreError();
+    std::string dump;
+    TF_RETURN_IF_ERROR(tsl::ReadFileToString(
+        tsl::Env::Default(),
+        tsl::io::JoinPath(
+            output_directory_,
+            FilenameFor(*module, /*prefix=*/"",
+                        /*suffix=*/
+                        absl::StrCat("cudnn_fusion_", root_name, ".json"))),
+        &dump));
+    return RunFileCheck(dump, pattern);
+  }
+
+ private:
+  std::string output_directory_;
+};
+
+TEST_F(CuDnnFusionFileCheckTest, F32DotGraphIsConvertedCorrectly) {
+  EXPECT_TRUE(*RunCuDnnFileCheck(R"(
 fd0 {
   p0 = f32[64,64] parameter(0)
   p1 = f32[64,64] parameter(1)
@@ -111,20 +142,7 @@
   ROOT d0 = f32[64,64] fusion(p0, p1), kind=kCustom, calls=fd0,
     backend_config={"fusion_backend_config":{"kind":"__cudnn$fusion","cudnn_fusion_config":{"plan_id":"0"}}}
 })",
-                                                       config));
-  BinaryMap dnn_compiled_graphs;
-  CuDnnFusionCompiler cudnn_compiler(*backend().default_stream_executor(),
-                                     dnn_compiled_graphs);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, cudnn_compiler.Run(module.get()));
-  EXPECT_TRUE(changed);
-  std::string dump;
-  TF_EXPECT_OK(tsl::ReadFileToString(
-      tsl::Env::Default(),
-      tsl::io::JoinPath(output_directory,
-                        FilenameFor(*module, /*prefix=*/"",
-                                    /*suffix=*/"cudnn_fusion_d0.json")),
-      &dump));
-  EXPECT_TRUE(*RunFileCheck(dump, R"(
+                                 R"(
 CHECK: "nodes": [
 CHECK:   "inputs": {
 CHECK:     "A": "p0",
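The cudnn_test.cc change above factors the repeated dump-and-FileCheck steps into a CuDnnFusionFileCheckTest fixture, so each test only supplies an HLO string and a pattern. A minimal sketch of that fixture-helper pattern follows, using plain googletest; CompileAndDump and MatchesPattern are placeholders invented for illustration, not XLA APIs.

#include <string>

#include <gtest/gtest.h>

class DumpCheckTest : public ::testing::Test {
 protected:
  // Placeholder for "compile `hlo` and dump the resulting artifact"; a real
  // test would run a compiler pass and read the dump back from disk.
  std::string CompileAndDump(const std::string& hlo) { return hlo; }

  // Placeholder for a FileCheck-style match; here just a substring search.
  bool MatchesPattern(const std::string& dump, const std::string& pattern) {
    return dump.find(pattern) != std::string::npos;
  }

  // The shared helper each test calls, analogous to RunCuDnnFileCheck above.
  bool RunDumpCheck(const std::string& hlo, const std::string& pattern) {
    return MatchesPattern(CompileAndDump(hlo), pattern);
  }
};

TEST_F(DumpCheckTest, SharedHelperKeepsTestsShort) {
  EXPECT_TRUE(RunDumpCheck("ENTRY e { ROOT c = f32[] constant(0) }", "ROOT"));
}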
diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc
index 3a95abf..31cd030 100644
--- a/third_party/xla/xla/service/gpu/fusions/custom.cc
+++ b/third_party/xla/xla/service/gpu/fusions/custom.cc
@@ -58,12 +58,16 @@
 #include "xla/service/gpu/runtime/dynamic_slice_thunk.h"
 #include "xla/service/gpu/runtime/gemm_thunk.h"
 #include "xla/service/gpu/runtime/kernel_thunk.h"
+#include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h"
+#include "xla/service/gpu/runtime/nccl_api.h"
+#include "xla/service/gpu/runtime/nccl_collective_thunk.h"
 #include "xla/service/gpu/runtime/thunk.h"
 #include "xla/service/gpu/stream_executor_util.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/service/pattern_matcher.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/status_macros.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/errors.h"
@@ -755,6 +759,119 @@
   return result;
 }
 
+absl::StatusOr<FusionEmissionResult> EmitCollective(
+    IrEmitterContext& ir_emitter_context, const HloFusionAdaptor& adaptor,
+    const HloFusionInstruction& fusion_instr, const HloInstruction* instr) {
+  if (instr->opcode() != HloOpcode::kReduceScatter) {
+    return absl::UnimplementedError(
+        "Dynamic slice fusion with collectives only works for reduce-scatter "
+        "instruction");
+  }
+
+  const BufferAssignment& buffer_assignment =
+      ir_emitter_context.buffer_assignment();
+
+  std::vector<std::optional<std::vector<DynamicSliceThunk::Offset>>>
+      offset_buffer_indices(2, std::nullopt);
+  std::vector<std::optional<Shape>> orig_shapes(2, std::nullopt);
+  std::vector<std::optional<Shape>> sliced_shapes(2, std::nullopt);
+  std::vector<std::optional<uint64_t>> offset_byte_sizes(2, std::nullopt);
+
+  std::vector<HloInstruction*> slice_instrs(2, nullptr);
+  std::vector<std::optional<BufferAllocation::Slice>> arguments;
+
+  // Collect slice information for inputs.
+  unsigned arg_idx = 0;
+  TF_ASSIGN_OR_RETURN(arguments.emplace_back(),
+                      GetOperandSlice(buffer_assignment, adaptor, fusion_instr,
+                                      *instr->operand(arg_idx), slice_instrs,
+                                      /*shape_idx=*/{}, arg_idx));
+  TF_RETURN_IF_ERROR(CollectSliceInfo(
+      buffer_assignment, fusion_instr,
+      absl::Span<HloInstruction*>(slice_instrs), offset_buffer_indices,
+      orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx++));
+
+  // Collect slice information for outputs.
+  TF_ASSIGN_OR_RETURN(
+      arguments.emplace_back(),
+      GetResultSlice(buffer_assignment, adaptor, fusion_instr, *instr,
+                     slice_instrs, /*shape_idx=*/{}, arg_idx));
+  TF_RETURN_IF_ERROR(CollectSliceInfo(
+      buffer_assignment, fusion_instr,
+      absl::Span<HloInstruction*>(slice_instrs), offset_buffer_indices,
+      orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx));
+
+  if (absl::c_all_of(slice_instrs, [&](HloInstruction* slice_instr) {
+        return slice_instr &&
+               slice_instr->opcode() != HloOpcode::kDynamicUpdateSlice;
+      })) {
+    return absl::InternalError(
+        "DynamicSliceFusion with reduce-scatter expects a dynamic-update-slice "
+        "operation.");
+  }
+
+  // Provide fake allocations for inputs and outputs.
+  std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(2);
+  unsigned fake_arg_idx = 0;
+  int64_t operand_byte_size =
+      ShapeUtil::ByteSizeOf(instr->operand(fake_arg_idx)->shape());
+  fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
+      /*index=*/fake_arg_idx, operand_byte_size, /*color=*/0);
+  BufferAllocation::Slice slice_operand(fake_allocations[fake_arg_idx].get(), 0,
+                                        operand_byte_size);
+  fake_arg_idx++;
+  TF_RET_CHECK(instr->shape().IsArray() &&
+               "The output is not expected to be a tuple.");
+  int64_t out_fake_byte_size =
+      ShapeUtil::ByteSizeOf(instr->shape());  // TODO: we don't need this
+  fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
+      /*index=*/fake_arg_idx, out_fake_byte_size, /*color=*/0);
+  BufferAllocation::Slice slice_out_fake(fake_allocations[fake_arg_idx].get(),
+                                         0, out_fake_byte_size);
+
+  // Generate the hero thunk and wrap it in a dynamic-slice thunk.
+  ThunkSequence seq;
+  auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(instr);
+  std::vector<NcclCollectiveThunk::Buffer> buffers;
+  const Shape& src_shape = instr->operand(0)->shape();
+  const Shape& dst_shape = instr->shape();
+  buffers.push_back(NcclCollectiveThunk::Buffer{
+      ShapeUtil::ElementsIn(src_shape), slice_operand, slice_out_fake,
+      src_shape.layout().memory_space(), dst_shape.layout().memory_space(),
+      nullptr, nullptr});
+
+  if (instr->opcode() == HloOpcode::kReduceScatter) {
+    int64_t replica_count = instr->GetModule()->config().replica_count();
+    int64_t partition_count = instr->GetModule()->config().num_partitions();
+    auto rs = static_cast<const HloReduceScatterInstruction*>(instr);
+    TF_RETURN_IF_ERROR(NcclReduceScatterStartThunk::CheckImplementable(
+        rs, replica_count, partition_count));
+
+    // TODO: add special handling for the degenerate case where no
+    // communication is needed; just copy.
+    auto rs_start_thunk = std::make_unique<NcclReduceScatterStartThunk>(
+        thunk_info, NcclApi::Default(), rs, buffers);
+    auto rs_done = std::make_unique<NcclCollectiveDoneThunk>(
+        /*kind=*/Thunk::kNcclReduceScatterDone,
+        /*thunk_info=*/Thunk::ThunkInfo::WithProfileAnnotation(rs),
+        /*async_events=*/rs_start_thunk->async_events(),
+        /*async_stream_kind=*/AsyncStreamKind::kCollective);
+    seq.emplace_back(std::move(rs_start_thunk));
+    seq.emplace_back(std::move(rs_done));
+  } else {
+    return absl::InternalError("Expected reduce-scatter hero instruction");
+  }
+
+  std::unique_ptr<Thunk> thunk = std::make_unique<DynamicSliceThunk>(
+      thunk_info, std::make_unique<ThunkSequence>(std::move(seq)),
+      std::move(arguments), std::move(fake_allocations),
+      std::move(offset_buffer_indices), std::move(orig_shapes),
+      std::move(sliced_shapes), std::move(offset_byte_sizes));
+  FusionEmissionResult result;
+  result.thunks.push_back(std::move(thunk));
+  return result;
+}
+
 }  // namespace
 
 absl::StatusOr<FusionEmissionResult> CustomFusion::Emit(
@@ -807,6 +924,16 @@
     IrEmitterContext& ir_emitter_context,
     const HloFusionInstruction& fusion) const {
   const HloFusionAdaptor& adaptor = analysis_.fusion();
+  // Only reduce-scatter is supported for now.
+  auto maybe_collective =
+      HloBfsFindIf(/*roots=*/adaptor.GetRoots(), /*fusion=*/adaptor,
+                   /*visit=*/[](HloInstructionAdaptor node) -> bool {
+                     return node.opcode() == HloOpcode::kReduceScatter;
+                   });
+  if (maybe_collective != std::nullopt) {
+    return EmitCollective(ir_emitter_context, adaptor, fusion,
+                          &maybe_collective->instruction());
+  }
   auto maybe_custom_call_adaptor = HloBfsFindIf(
       adaptor.GetRoots(), adaptor,
       [](auto node) { return node.opcode() == HloOpcode::kCustomCall; });
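The custom.cc hunk above makes DynamicSliceFusion::Emit first search the fusion graph for a reduce-scatter hero and only then fall back to the custom-call path. A minimal sketch of that dispatch-by-search idea is below; Node and Opcode are hypothetical simplifications of HloInstructionAdaptor and HloOpcode, not the real classes.

#include <deque>
#include <optional>
#include <unordered_set>
#include <vector>

namespace example {

enum class Opcode { kReduceScatter, kCustomCall, kOther };

struct Node {
  Opcode opcode;
  std::vector<const Node*> operands;
};

// Breadth-first search from the fusion roots; returns the first node whose
// opcode matches `target`, or std::nullopt if none is found.
std::optional<const Node*> BfsFindIf(const std::vector<const Node*>& roots,
                                     Opcode target) {
  std::deque<const Node*> queue(roots.begin(), roots.end());
  std::unordered_set<const Node*> visited(roots.begin(), roots.end());
  while (!queue.empty()) {
    const Node* node = queue.front();
    queue.pop_front();
    if (node->opcode == target) return node;
    for (const Node* operand : node->operands) {
      if (visited.insert(operand).second) queue.push_back(operand);
    }
  }
  return std::nullopt;
}

}  // namespace example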
diff --git a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
index f53dc13..f835514 100644
--- a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
@@ -16,7 +16,9 @@
 #include <cstddef>
 #include <cstdint>
 #include <functional>
+#include <string>
 #include <utility>
+#include <vector>
 
 #include "absl/status/status.h"
 #include "xla/client/lib/constants.h"
@@ -24,9 +26,11 @@
 #include "xla/error_spec.h"
 #include "xla/ffi/ffi.h"
 #include "xla/ffi/ffi_api.h"
+#include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/custom_call_target_registry.h"
-#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/shape.h"
@@ -34,6 +38,7 @@
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/gpu/gpu_types.h"
 #include "xla/stream_executor/stream.h"
+#include "xla/tests/filecheck.h"
 #include "xla/tests/hlo_test_base.h"
 #include "xla/xla.pb.h"
 #include "xla/xla_data.pb.h"
@@ -80,6 +85,36 @@
     config.set_debug_options(debug_options);
     return config;
   }
+
+  HloModuleConfig GetModuleConfigWithDeterministicOps() {
+    DebugOptions debug_options = GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_deterministic_ops(true);
+    HloModuleConfig config;
+    config.set_debug_options(debug_options);
+    return config;
+  }
+
+  std::vector<HloComputation*> GetAddressComputations(const HloModule& module) {
+    std::vector<HloComputation*> computations;
+    for (auto computation : module.computations()) {
+      if (!computation->IsFusionComputation()) {
+        continue;
+      }
+      auto backend_config = computation->FusionInstruction()
+                                ->backend_config<xla::gpu::GpuBackendConfig>();
+      if (backend_config.ok()) {
+        const FusionBackendConfig& fusion_backend_config =
+            backend_config.value().fusion_backend_config();
+        const std::string name =
+            fusion_backend_config.custom_fusion_config().name();
+        if (name == "dynamic_address_computation" ||
+            name == "address_computation") {
+          computations.push_back(computation);
+        }
+      }
+    }
+    return computations;
+  }
 };
 
 TEST_F(DynamicSliceFusionTest, CublasGemmSimple) {
@@ -237,8 +272,10 @@
         backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
   })";
 
-  EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec,
-                                      /*run_hlo_passes=*/false));
+  EXPECT_TRUE(RunAndCompareTwoModules(
+      hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(),
+      GetModuleConfigWithDeterministicOps(), error_spec,
+      /*run_hlo_passes=*/false));
 }
 
 TEST_F(DynamicSliceFusionTest, ContiguousSlice) {
@@ -1327,8 +1364,10 @@
         backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
   })";
 
-  EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec,
-                                      /*run_hlo_passes=*/false));
+  EXPECT_TRUE(RunAndCompareTwoModules(
+      hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(),
+      GetModuleConfigWithDeterministicOps(), error_spec,
+      /*run_hlo_passes=*/false));
 }
 
 TEST_F(DynamicSliceFusionTest, DynamicContiguousSlice) {
@@ -2156,8 +2195,10 @@
         backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
   })";
 
-  EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec,
-                                      /*run_hlo_passes=*/false));
+  EXPECT_TRUE(RunAndCompareTwoModules(
+      hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(),
+      GetModuleConfigWithDeterministicOps(), error_spec,
+      /*run_hlo_passes=*/false));
 }
 
 TEST_F(DynamicSliceFusionTest, CublasGemmDUSWorkspaceIgnored) {
@@ -2241,8 +2282,10 @@
         backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
   })";
 
-  EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec,
-                                      /*run_hlo_passes=*/false));
+  EXPECT_TRUE(RunAndCompareTwoModules(
+      hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(),
+      GetModuleConfigWithDeterministicOps(), error_spec,
+      /*run_hlo_passes=*/false));
 }
 
 TEST_F(DynamicSliceFusionTest, CublasGemmDUSOffsetS32NotConstant) {
@@ -2435,8 +2478,10 @@
         backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
   })";
 
-  EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec,
-                                      /*run_hlo_passes=*/false));
+  EXPECT_TRUE(RunAndCompareTwoModules(
+      hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(),
+      GetModuleConfigWithDeterministicOps(), error_spec,
+      /*run_hlo_passes=*/false));
 }
 
 TEST_F(DynamicSliceFusionTest, DynamicCustomCallSimple) {
@@ -2445,9 +2490,7 @@
       &b, "__xla_test$$memcpy",
       /*operands=*/
       {DynamicSlice(Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}), "p0"),
-                    {Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "start0"),
-                     Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), "start1")},
-                    {2, 128})},
+                    {ConstantR0(&b, 2), ConstantR0(&b, 0)}, {2, 128})},
       ShapeUtil::MakeShape(F32, {2, 128}), /*opaque=*/"",
       /*has_side_effect=*/false,
       /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
@@ -2464,7 +2507,6 @@
   hlo_config.set_debug_options(debug_options);
   TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto(
                                             computation.proto(), hlo_config));
-
   TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto(
                                             computation.proto(), hlo_config));
   DynamicSliceFusionRewriter pass(PLATFORM);
@@ -2502,11 +2544,7 @@
                     DynamicSlice(
                         Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}),
                                   "p0"),
-                        {Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}),
-                                   "start0"),
-                         Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}),
-                                   "start1")},
-                        {3, 128}),
+                        {ConstantR0(&b, 20), ConstantR0(&b, 0)}, {3, 128}),
                 }),
       },
       ShapeUtil::MakeTupleShape({
@@ -2545,6 +2583,15 @@
   DynamicSliceFusionRewriter pass(PLATFORM);
   TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get()));
   EXPECT_TRUE(changed);
+  EXPECT_TRUE(*RunFileCheck(hlo_opt->ToString(), R"(
+    // CHECK: %address-computation{{.+}} {
+    // CHECK:   {{.+}} = {{.+}} slice
+    // CHECK:   {{.+}} = {{.+}} dynamic-slice
+    // CHECK:   {{.+}} = {{.+}} custom-call
+    // CHECK: ENTRY {{.+}} {
+    // CHECK-NOT: {{.+}} = {{.+}} slice
+    // CHECK-NOT: {{.+}} = {{.+}} dynamic-slice
+  )"));
 
   EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt),
                                       error_spec, /*run_hlo_passes=*/false));
@@ -2754,6 +2801,199 @@
                                       /*run_hlo_passes=*/false));
 }
 
+TEST_F(DynamicSliceFusionTest, ReduceScatterDUSConstant) {
+  // DUS offset is a constant
+  const char* hlo_ref = R"(
+  HloModule test, replica_count=2
+
+  add.clone {
+    x.1 = f16[] parameter(0)
+    y.1 = f16[] parameter(1)
+    ROOT add.462 = f16[] add(x.1, y.1)
+  }
+
+  ENTRY %main.9 {
+    param_0 = f16[128,128]{1,0} parameter(0)
+    param_1 = f16[128,128]{1,0} parameter(1)
+    constant_20 = u32[] constant(20)
+    constant_0 = u32[] constant(0)
+    reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add.clone
+    ROOT dynamic-update-slice = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, constant_20, constant_0)
+  })";
+
+  const char* hlo_opt = R"(
+  HloModule test, replica_count=2
+
+  %add {
+    %param_0 = f16[] parameter(0)
+    %param_1 = f16[] parameter(1)
+    ROOT %add.1 = f16[] add(%param_0, %param_1)
+  }
+
+  %address-computation {
+    %p1 = f16[128,128]{1,0} parameter(1)
+    %p0 = f16[128,128]{1,0} parameter(0)
+    %reduce-scatter.1 = f16[64,128]{1,0} reduce-scatter(%p0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=%add
+    %p2 = u32[] parameter(2)
+    %p3 = u32[] parameter(3)
+    ROOT %loop_dynamic_update_slice_fusion.1 = f16[128,128]{1,0} dynamic-update-slice(%p1, %reduce-scatter.1, %p2, %p3)
+  }
+
+  ENTRY %main.9 {
+    %param_0.1 = f16[128,128]{1,0} parameter(0)
+    %param_1.1 = f16[128,128]{1,0} parameter(1)
+    %constant_20 = u32[] constant(20)
+    %constant_0 = u32[] constant(0)
+    ROOT %address_computation = f16[128,128]{1,0} fusion(%param_0.1, %param_1.1, %constant_20, %constant_0), kind=kCustom, calls=%address-computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}},"force_earliest_schedule":false}
+  })";
+
+  ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
+  EXPECT_TRUE(RunAndCompareTwoModulesReplicated(hlo_ref, hlo_opt, true, true,
+                                                error_spec));
+}
+
+TEST_F(DynamicSliceFusionTest, ReduceScatterDUSParameterOffset) {
+  // DUS offset is a parameter. This enforces a d2h copy.
+  const char* hlo_ref = R"(
+  HloModule test, replica_count=2
+
+  add.clone {
+    x.1 = f16[] parameter(0)
+    y.1 = f16[] parameter(1)
+    ROOT add.462 = f16[] add(x.1, y.1)
+  }
+
+  ENTRY %main.9 {
+    param_0 = f16[128,128]{1,0} parameter(0)
+    param_1 = f16[128,128]{1,0} parameter(1)
+    param_2 = u32[] parameter(2)
+    constant_0 = u32[] constant(0)
+    reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add.clone
+    ROOT dynamic-update-slice = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, param_2, constant_0)
+  })";
+
+  const char* hlo_opt = R"(
+  HloModule test, replica_count=2
+
+  %add {
+    %param_0 = f16[] parameter(0)
+    %param_1 = f16[] parameter(1)
+    ROOT %add.1 = f16[] add(f16[] %param_0, f16[] %param_1)
+  }
+
+  %address-computation {
+    %p1 = f16[128,128]{1,0} parameter(1)
+    %p0 = f16[128,128]{1,0} parameter(0)
+    %reduce-scatter.1 = f16[64,128]{1,0} reduce-scatter(%p0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=%add
+    %p2 = u32[] parameter(2)
+    %p3 = u32[] parameter(3)
+    ROOT %loop_dynamic_update_slice_fusion.1 = f16[128,128]{1,0} dynamic-update-slice(%p1, %reduce-scatter.1, %p2, %p3)
+  }
+
+  ENTRY %main.9 {
+    %param_0 = f16[128,128]{1,0} parameter(0)
+    %param_1 = f16[128,128]{1,0} parameter(1)
+    %param_2 = u32[] parameter(2)
+    %constant_0 = u32[] constant(0)
+    ROOT %address_computation = f16[128,128]{1,0} fusion(%param_0, %param_1, %param_2, %constant_0), kind=kCustom, calls=%address-computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}},"force_earliest_schedule":false}
+  })";
+
+  ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
+  EXPECT_TRUE(RunAndCompareTwoModulesReplicated(hlo_ref, hlo_opt, true, true,
+                                                error_spec));
+}
+
+TEST_F(DynamicSliceFusionTest, ReduceScatterDUSLoopIterationOffset) {
+  const char* hlo_ref = R"(
+  HloModule jit_scan, replica_count=2
+
+  %add {
+    %param_0 = f32[] parameter(0)
+    %param_1 = f32[] parameter(1)
+    ROOT %add.1 = f32[] add(%param_0, %param_1)
+  }
+
+  %region_0.14 {
+    %arg_tuple.15 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0)
+    %get-tuple-element.16 = s32[] get-tuple-element(%arg_tuple.15), index=0
+    %constant.21 = s32[] constant(1)
+    %add.37 = s32[] add(%get-tuple-element.16, %constant.21)
+    %get-tuple-element.20 = f32[128,128]{1,0} get-tuple-element(%arg_tuple.15), index=4
+    %get-tuple-element.18 = f32[128,128,128]{2,1,0} get-tuple-element(%arg_tuple.15), index=2
+    %reduce-scatter.1 = f32[64,128]{1,0} reduce-scatter(%get-tuple-element.20), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=%add
+    %reshape.32 = f32[1,64,128]{2,1,0} reshape(%reduce-scatter.1)
+    %constant.23 = s32[] constant(0)
+    %compare.33 = pred[] compare(%get-tuple-element.16, %constant.23), direction=LT
+    %constant.22 = s32[] constant(128)
+    %add.34 = s32[] add(%get-tuple-element.16, %constant.22)
+    %select.35 = s32[] select(%compare.33, %add.34, %get-tuple-element.16)
+    %dynamic-update-slice.36 = f32[128,128,128]{2,1,0} dynamic-update-slice(%get-tuple-element.18, %reshape.32, %select.35, %constant.23, %constant.23)
+    %get-tuple-element.19 = f32[128,128,128]{2,1,0} get-tuple-element(%arg_tuple.15), index=3
+    ROOT %tuple.38 = tuple(%add.37, %get-tuple-element.20, %dynamic-update-slice.36, %get-tuple-element.19, %get-tuple-element.20)
+  }
+
+  %region_1.39 {
+    %arg_tuple.40 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0)
+    %get-tuple-element.41 = s32[] get-tuple-element(%arg_tuple.40), index=0
+    %constant.46 = s32[] constant(128)
+    ROOT %compare.47 = pred[] compare(%get-tuple-element.41, %constant.46), direction=LT
+  }
+
+  ENTRY %main.55 {
+    %constant.4 = s32[] constant(0)
+    %Arg_1.2 = f32[128,128]{1,0} parameter(1)
+    %constant.5 = f32[] constant(0)
+    %broadcast.6 = f32[128,128,128]{2,1,0} broadcast(%constant.5), dimensions={}
+    %Arg_2.3 = f32[128,128,128]{2,1,0} parameter(2)
+    %Arg_0.1 = f32[128,128]{1,0} parameter(0)
+    %tuple.7 = tuple(%constant.4, %Arg_1.2, %broadcast.6, %Arg_2.3, %Arg_0.1)
+    %while.48 = while(%tuple.7), condition=%region_1.39, body=%region_0.14
+    %get-tuple-element.50 = f32[128,128]{1,0} get-tuple-element(%while.48), index=1
+    %get-tuple-element.51 = f32[128,128,128]{2,1,0} get-tuple-element(%while.48), index=2
+    ROOT %tuple.54 = tuple(%get-tuple-element.50, %get-tuple-element.51)
+  })";
+  DebugOptions debugoptions = GetDebugOptionsForTest();
+
+  HloModuleConfig ref_config;
+  debugoptions.set_xla_gpu_enable_dynamic_slice_fusion(false);
+  ref_config.set_debug_options(debugoptions);
+  TF_ASSERT_OK_AND_ASSIGN(auto ref_module,
+                          ParseAndReturnVerifiedModule(hlo_ref, ref_config));
+  TF_ASSERT_OK_AND_ASSIGN(auto ref_module_opt,
+                          GetOptimizedModule(std::move(ref_module)));
+
+  HloModuleConfig opt_config;
+  debugoptions.set_xla_gpu_enable_dynamic_slice_fusion(true);
+  opt_config.set_debug_options(debugoptions);
+  TF_ASSERT_OK_AND_ASSIGN(auto module_with_address_computation_flag,
+                          ParseAndReturnVerifiedModule(hlo_ref, opt_config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module_with_address_computation,
+      GetOptimizedModule(std::move(module_with_address_computation_flag)));
+
+  std::vector<HloComputation*> address_computations_opt =
+      GetAddressComputations(*module_with_address_computation);
+  std::vector<HloComputation*> address_computations_ref =
+      GetAddressComputations(*ref_module_opt);
+  EXPECT_EQ(address_computations_ref.size(), 0);
+  ASSERT_EQ(address_computations_opt.size(), 1);
+
+  // Check that the reduce-scatter lands inside the fusion in the optimized
+  // module and not outside the fusion.
+  EXPECT_TRUE(*RunFileCheck(address_computations_opt[0]->ToString(), R"(
+    // CHECK: {{.+}} = {{.*}}reduce-scatter({{.+}})
+    // CHECK: {{.+}} = {{.*}}dynamic-update-slice({{.+}})
+  )"));
+  EXPECT_TRUE(*RunFileCheck(
+      address_computations_opt[0]->FusionInstruction()->parent()->ToString(),
+      "// CHECK-NOT: {{.+}} = {{.*}}reduce-scatter"));
+
+  ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3};
+  EXPECT_TRUE(RunAndCompareTwoModulesReplicated(
+      std::move(ref_module_opt), std::move(module_with_address_computation),
+      false, true, error));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
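The GetAddressComputations helper added to dynamic_slice_fusion_test.cc above walks a module's computations and keeps the fusions whose custom fusion config names a dynamic-slice address computation. A minimal sketch of that filtering pattern follows, with a hypothetical Computation struct standing in for the real HLO classes.

#include <string>
#include <vector>

namespace example {

// Hypothetical simplification of an HLO computation: only the fields used by
// the filter are modeled.
struct Computation {
  bool is_fusion = false;
  std::string custom_fusion_name;  // Empty when no custom fusion config is set.
};

std::vector<const Computation*> FilterAddressComputations(
    const std::vector<Computation>& computations) {
  std::vector<const Computation*> result;
  for (const Computation& computation : computations) {
    if (!computation.is_fusion) continue;
    if (computation.custom_fusion_name == "dynamic_address_computation" ||
        computation.custom_fusion_name == "address_computation") {
      result.push_back(&computation);
    }
  }
  return result;
}

}  // namespace example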
diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc
index 03425b5..d662d25 100644
--- a/third_party/xla/xla/service/gpu/fusions/fusions.cc
+++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc
@@ -27,24 +27,24 @@
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/layout_util.h"
 #include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/fusions/concatenate.h"
 #include "xla/service/gpu/fusions/concatenate_mlir.h"
 #include "xla/service/gpu/fusions/copy.h"
 #include "xla/service/gpu/fusions/cudnn.h"
 #include "xla/service/gpu/fusions/custom.h"
 #include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h"
 #include "xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h"
-#include "xla/service/gpu/fusions/input_slices.h"
 #include "xla/service/gpu/fusions/input_slices_mlir.h"
-#include "xla/service/gpu/fusions/loop.h"
+#include "xla/service/gpu/fusions/legacy/concatenate.h"
+#include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h"
+#include "xla/service/gpu/fusions/legacy/input_slices.h"
+#include "xla/service/gpu/fusions/legacy/loop.h"
+#include "xla/service/gpu/fusions/legacy/reduction.h"
+#include "xla/service/gpu/fusions/legacy/scatter.h"
+#include "xla/service/gpu/fusions/legacy/transpose.h"
 #include "xla/service/gpu/fusions/loop_mlir.h"
 #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/reduction.h"
 #include "xla/service/gpu/fusions/reduction_mlir.h"
-#include "xla/service/gpu/fusions/scatter.h"
 #include "xla/service/gpu/fusions/scatter_mlir.h"
-#include "xla/service/gpu/fusions/transpose.h"
 #include "xla/service/gpu/fusions/transpose_mlir.h"
 #include "xla/service/gpu/fusions/triton.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc
deleted file mode 100644
index 464de3c..0000000
--- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc
+++ /dev/null
@@ -1,105 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h"
-
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/dynamic_update_slice_util.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-constexpr int kDUSUpdateIndex = 1;
-
-}  // namespace
-
-LaunchDimensions InPlaceDynamicUpdateSliceFusion::launch_dimensions() const {
-  const auto& update_shape = dus_ops_.front().GetOperand(1).shape();
-  return CalculateLaunchDimensions(update_shape, analysis_.device_info());
-}
-
-std::optional<IndexingMap>
-InPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing(
-    int64_t root_index, int64_t hero_operand_index,
-    mlir::MLIRContext* mlir_context) const {
-  if (hero_operand_index != kDUSUpdateIndex) {
-    return std::nullopt;
-  }
-  auto launch_dims = launch_dimensions();
-  // It is guaranteed that all DUS ops have the same output shape at this point.
-  const auto& update_shape =
-      dus_ops_.front().GetOperand(kDUSUpdateIndex).shape();
-  return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1,
-                                       update_shape, mlir_context);
-}
-
-absl::Status InPlaceDynamicUpdateSliceFusion::EmitKernel(
-    IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
-    const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
-    std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder) const {
-  // In case a dynamic slice update's output is bitcasted, we need to ensure we
-  // write to the output array using the shape and layout of the dynamic slice
-  // update. This cast is known to be safe to do iff, in the case the output of
-  // the dynamic slice update is bitcasted, that bitcast is either the fusion's
-  // output, or has a single user and is part of the fusion's tuple output.
-  // This condition should be enforced explicitly in the
-  // 'CanEmitFusedDynamicUpdateSliceInPlaceForGpu' matcher.
-  for (auto [op, output] : llvm::zip(dus_ops_, outputs)) {
-    output = output.CastToShape(op.shape(), builder);
-  }
-
-  auto* fused_computation = fusion.fused_instructions_computation();
-  GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
-  FusedIrEmitter fused_emitter(elemental_emitter);
-  for (auto [index, input] : llvm::enumerate(inputs)) {
-    auto fused_operand = fused_computation->parameter_instruction(index);
-    fused_emitter.BindGenerator(
-        *fused_operand, [input = input, builder,
-                         fused_operand](const llvm_ir::IrArray::Index& index) {
-          return input.EmitReadArrayElement(index, builder,
-                                            fused_operand->name());
-        });
-  }
-
-  std::vector<std::pair<const HloInstruction*, const llvm_ir::IrArray>>
-      dus_and_output_array;
-  dus_and_output_array.reserve(dus_ops_.size());
-
-  for (auto [op, output] : llvm::zip(dus_ops_, outputs)) {
-    dus_and_output_array.push_back(std::make_pair(&op.instruction(), output));
-  }
-
-  return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
-      fused_computation, dus_and_output_array, &fused_emitter, launch_dims,
-      builder);
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h
deleted file mode 100644
index cfac87d..0000000
--- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h
+++ /dev/null
@@ -1,98 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_
-#define XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/ir_array.h"
-
-namespace xla {
-namespace gpu {
-
-// Fusion node where the root is either:
-// 1. a dynamic-update-slice op
-// 2. a bitcast of a dynamic-update-slice op
-// 3. a tuple op returning the result of several dynamic-update-slice ops
-// 4. a tuple op returning the result of several bitcast
-//    dynamic-update-slice ops
-//
-// Additionally, all the dynamic-update-slice ops have exactly one user. The
-// fusion parameter that they update can have users (in addition to the
-// dynamic-update-slice op) that read in either
-// a. a dynamic-slice corresponding exactly to the slice of the parameter that
-//    is updated by the dynamic-update-slice op
-// b. a dynamic-slice reading in a single element anywhere in the parameter.
-//    This is only allowed if the dynamic-update-slice op updates a single
-//    element
-//
-// In both cases, the additional users must not flow into any other output
-// than the dynamic-slice-update corresponding to that particular slice of the
-// parameter.
-//
-// The assumption is that each op's input (i.e. array to update) shares the
-// same slice as its output. In this case, we have a special algorithm that
-// modifies the output in place without touching the un-updated elements. The
-// update slice is assumed to be the exact same for all the
-// dynamic-update-slice ops.
-class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase {
- public:
-  explicit InPlaceDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis)
-      : analysis_(analysis),
-        dus_ops_(
-            GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {}
-  LaunchDimensions launch_dimensions() const override;
-
-  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
-      int64_t root_index, mlir::MLIRContext* ctx) const override {
-    // The mapping cannot be statically computed in general, since the offsets
-    // are unknown.
-    return std::nullopt;
-  }
-
-  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
-      int64_t root_index, int64_t hero_operand_index,
-      mlir::MLIRContext* mlir_context) const override;
-
- protected:
-  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
-                          const HloFusionInstruction& fusion,
-                          const LaunchDimensions& launch_dims,
-                          std::vector<llvm_ir::IrArray> inputs,
-                          std::vector<llvm_ir::IrArray> outputs,
-                          llvm::IRBuilder<>* builder) const override;
-
-  const HloFusionAnalysis& analysis_;
-  std::vector<HloInstructionAdaptor> dus_ops_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_
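
The comment block in the deleted header above is the whole contract: the destination buffer aliases the operand being updated, so the kernel only needs to write the elements covered by the update, and the runtime start indices are clamped so the slice stays in bounds. A minimal C++ sketch of that idea for a 1D buffer, with hypothetical names and none of the real emitter machinery:

#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch only: the output aliases the operand, so elements outside the update
// are never read or written; that is what makes the in-place emission legal.
void InPlaceDynamicUpdateSlice1D(std::vector<float>& buffer,  // operand == output
                                 const std::vector<float>& update,
                                 int64_t start) {
  const auto dim = static_cast<int64_t>(buffer.size());
  const auto update_size = static_cast<int64_t>(update.size());
  // HLO semantics clamp the runtime offset to [0, dim - update_size].
  start = std::clamp<int64_t>(start, 0, dim - update_size);
  for (int64_t i = 0; i < update_size; ++i) {  // launch bound == update size
    buffer[start + i] = update[i];
  }
}
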
diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc
index b0da3ef..f18173b 100644
--- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc
@@ -56,7 +56,7 @@
 
   auto* root = module->entry_computation()->root_instruction();
 
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
   MlirInPlaceDynamicUpdateSliceFusion fusion(analysis);
 
   auto thread_id_update_indexing = fusion.ComputeThreadIdToInputIndexing(
@@ -100,8 +100,8 @@
     }
   )";
   TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0) -> (d0 floordiv 6)>
-    // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0) -> (d0 mod 6)>
+    // CHECK-DAG: #[[MAP_1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 6), domain: d0 in [0, 29]
+    // CHECK-DAG: #[[MAP_2:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 6), domain: d0 in [0, 29]
     // CHECK:     func.func @fused_computation
     // CHECK-SAME:  %arg0: tensor<20x30xf32>
     // CHECK-SAME:  %arg1: tensor<5x6xf32>
@@ -112,8 +112,8 @@
     // CHECK-DAG:   %[[C_15:.*]] = arith.constant 15
     // CHECK-DAG:   %[[C_0:.*]] = arith.constant 0
     // CHECK:       %[[THREAD_ID:.*]] = gpu.thread_id  x
-    // CHECK:       %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 29])
-    // CHECK:       %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 29])
+    // CHECK:       %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]])
+    // CHECK:       %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]])
     // CHECK:       %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0
     // CHECK:       %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1
     // CHECK:       %[[IDX0:.*]] = arith.index_cast %[[I0]]
@@ -151,8 +151,8 @@
     }
   )";
   TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0) -> (d0 floordiv 3)>
-    // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0) -> (d0 mod 3)>
+    // CHECK-DAG: #[[MAP_1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 3), domain: d0 in [0, 5]
+    // CHECK-DAG: #[[MAP_2:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 3), domain: d0 in [0, 5]
     // CHECK:     func.func @fused_computation
     // CHECK-SAME:  %arg0: tensor<7x8xf32>
     // CHECK-SAME:  %arg1: tensor<2x3xf32>
@@ -162,8 +162,8 @@
     // CHECK-DAG:   %[[C_5:.*]] = arith.constant 5
     // CHECK-DAG:   %[[C_0:.*]] = arith.constant 0
     // CHECK:       %[[THREAD_ID:.*]] = gpu.thread_id  x
-    // CHECK:       %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 5])
-    // CHECK:       %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 5])
+    // CHECK:       %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]])
+    // CHECK:       %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]])
     // CHECK:       %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0
     // CHECK:       %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1
     // CHECK:       %[[IDX0:.*]] = arith.index_cast %[[I0]]
diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc
deleted file mode 100644
index e48cee0..0000000
--- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc
+++ /dev/null
@@ -1,144 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h"
-
-#include <optional>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "mlir/IR/MLIRContext.h"
-#include "xla/service/gpu/fusions/fusions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/affine_map_printer.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class InPlaceDynamicUpdateSliceFusionTest : public HloTestBase {
- public:
-  void SetUp() override {
-    HloTestBase::SetUp();
-    printer_ =
-        AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
-                         {"chunk_id", "unroll_id"});
-  }
-
- protected:
-  DebugOptions GetDebugOptionsForTest() override {
-    auto opts = HloTestBase::GetDebugOptionsForTest();
-    opts.set_xla_gpu_mlir_emitter_level(0);
-    return opts;
-  }
-  AffineMapPrinter printer_;
-  mlir::MLIRContext mlir_context_;
-  stream_executor::DeviceDescription device_info_ =
-      TestGpuDeviceInfo::RTXA6000DeviceInfo();
-};
-
-TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) {
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    fused_computation {
-      in = f32[20,30] parameter(0)
-      updates = f32[5,6] parameter(1)
-      i0 = s32[] parameter(2)
-      i1 = s32[] parameter(3)
-      ROOT updated = f32[20,30] dynamic-update-slice(in, updates, i0, i1)
-    }
-    ENTRY entry {
-      in = f32[20,30] parameter(0)
-      updates = f32[5,6] parameter(1)
-      i0 = s32[] constant(2)
-      i1 = s32[] constant(3)
-      ROOT fusion = f32[20,30] fusion(in, updates, i0, i1), kind=kLoop, calls=fused_computation
-    }
-  )"));
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis_fused = AnalyzeFusion(*root, device_info_);
-
-  auto emitter =
-      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
-  auto fusion = dynamic_cast<InPlaceDynamicUpdateSliceFusion*>(emitter.get());
-  ASSERT_NE(fusion, nullptr);
-
-  auto thread_id_update_indexing = fusion->ComputeThreadIdToInputIndexing(
-      /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_);
-  EXPECT_THAT(thread_id_update_indexing->ToString(printer_),
-              MatchIndexingString(R"(
-    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
-    th_x floordiv 6, th_x mod 6)
-    domain:
-    th_x in [0, 29]
-    th_y in [0, 0]
-    th_z in [0, 0]
-    bl_x in [0, 0]
-    bl_y in [0, 0]
-    bl_z in [0, 0]
-    chunk_id in [0, 0]
-    unroll_id in [0, 0]
-  )"));
-  auto thread_id_dst_indexing = fusion->ComputeThreadIdToInputIndexing(
-      /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
-  EXPECT_THAT(thread_id_dst_indexing, ::testing::Eq(std::nullopt));
-}
-
-TEST_F(InPlaceDynamicUpdateSliceFusionTest, ProduceConsumerFusion) {
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
-    HloModule m
-
-    fused_computation.1 {
-      param_0 = bf16[1,2,5,1,2] parameter(0)
-      bitcast = bf16[1,5,1,2,2] bitcast(param_0)
-      param_1 = bf16[1,1,1,2,2] parameter(1)
-      param_2 = s32[] parameter(2)
-      param_3 = s32[] parameter(3)
-      ROOT dynamic-update-slice = bf16[1,5,1,2,2] dynamic-update-slice(bitcast, param_1, param_2, param_3, param_2, param_2, param_2)
-    }
-
-    ENTRY entry_computation {
-      param_0.2 = bf16[1,2,5,1,2] parameter(3)
-      param_1.2 = bf16[1,1,1,2,2] parameter(0)
-      param_2.2 = s32[] parameter(1)
-      param_3.2 = s32[] parameter(2)
-      fusion = bf16[1,5,1,2,2] fusion(param_0.2, param_1.2, param_2.2, param_3.2), kind=kLoop, calls=fused_computation.1
-      ROOT bitcast.1 = bf16[1,2,5,1,2] bitcast(fusion)
-    }
-  )"));
-
-  auto* root = module->entry_computation()->root_instruction();
-
-  auto analysis_fused =
-      AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info_);
-
-  auto emitter =
-      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
-
-  auto fusion = dynamic_cast<InPlaceDynamicUpdateSliceFusion*>(emitter.get());
-
-  ASSERT_NE(fusion, nullptr);
-  EXPECT_EQ(fusion->launch_dimensions().launch_bound(), 4 /* update size */);
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
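
For reference, the ThreadIndexing expectation in the deleted test above is nothing more than a row-major delinearization of 30 threads over the 5x6 update; a tiny sketch with a hypothetical helper name:

#include <cstdint>
#include <utility>

// Thread th_x in [0, 29] covers the 5x6 update: row = th_x / 6, col = th_x % 6,
// which is exactly the (th_x floordiv 6, th_x mod 6) map the test expects.
std::pair<int64_t, int64_t> UpdateIndexForThread(int64_t th_x) {
  return {th_x / 6, th_x % 6};
}
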
diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices.cc b/third_party/xla/xla/service/gpu/fusions/input_slices.cc
deleted file mode 100644
index 75ffe4b..0000000
--- a/third_party/xla/xla/service/gpu/fusions/input_slices.cc
+++ /dev/null
@@ -1,220 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/input_slices.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/types/span.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Value.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/elemental_ir_emitter.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/gpu/parallel_loop_emitter.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/service/llvm_ir/kernel_support_library.h"
-#include "xla/service/llvm_ir/llvm_loop.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-// Emits code for slices based on the below structure. An if statement with
-// a guarding condition is generated for each ROOT slice.
-//
-// Pseudo code:
-//
-// Compute values of slice input operands
-//
-// Compute guarding_cond0
-// if (guarding_cond0) {
-//   Write to output of slice0
-// }
-//
-// Compute guarding_cond1
-// if (guarding_cond1) {
-//   Write to output of slice1
-// }
-//
-absl::Status EmitElementForInputFusibleSlices(
-    ElementalIrEmitter& elemental_emitter,
-    const HloComputation* fused_computation,
-    const std::vector<llvm_ir::IrArray>& inputs,
-    const std::vector<llvm_ir::IrArray>& outputs,
-    const llvm_ir::IrArray::Index& index, llvm::IRBuilder<>* builder) {
-  VLOG(10) << "Emitting slice input fusion for "
-           << fused_computation->ToString();
-
-  HloInstruction* slice_or_tuple = fused_computation->root_instruction();
-  auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
-    if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
-      return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
-    }
-    CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple);
-    return slice_or_tuple->operands();
-  }();
-
-  // Emit input operand values of slices.
-  std::vector<llvm::Value*> input_ir_values;
-  FusedIrEmitter fused_emitter(elemental_emitter);
-  for (int i = 0; i < fused_computation->num_parameters(); i++) {
-    fused_emitter.BindGenerator(
-        *fused_computation->parameter_instruction(i),
-        [&inputs, i, builder](llvm_ir::IrArray::Index index) {
-          return inputs[i].EmitReadArrayElement(index, builder);
-        });
-  }
-  for (const HloInstruction* slice : slice_instructions) {
-    auto input_generator = *fused_emitter.GetGenerator(*slice->operand(0));
-    input_ir_values.push_back(input_generator(index).value());
-  }
-
-  // Emit for slice_instructions.
-  KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll);
-  for (int64_t i = 0; i < slice_instructions.size(); ++i) {
-    HloInstruction* slice = slice_instructions[i];
-
-    // guarding_cond := index >= start && index < limit, for each dim.
-    std::vector<llvm::Value*> index_within_ranges;
-    for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) {
-      CHECK_EQ(slice->slice_strides(dim), 1);
-      auto larger_or_equal_than_start = builder->CreateICmpSGE(
-          index.multidim()[dim],
-          index.GetConstantWithIndexType(slice->slice_starts(dim)));
-      llvm::Value* smaller_than_limit = builder->CreateICmpSLT(
-          index.multidim()[dim],
-          index.GetConstantWithIndexType(slice->slice_limits(dim)));
-      llvm::Value* within_range =
-          builder->CreateAnd(larger_or_equal_than_start, smaller_than_limit);
-      index_within_ranges.push_back(within_range);
-    }
-    llvm::Value* guarding_cond = builder->CreateAnd(index_within_ranges);
-
-    auto emit_slice_elem_func = [&] {
-      const std::vector<llvm::Value*>& src_multidim = index.multidim();
-      std::vector<llvm::Value*> dst_multidim(src_multidim.size());
-      for (size_t dim = 0; dim < src_multidim.size(); ++dim) {
-        dst_multidim[dim] = builder->CreateSub(
-            src_multidim[dim],
-            index.GetConstantWithIndexType(slice->slice_starts(dim)));
-      }
-      const llvm_ir::IrArray& src_ir_array = outputs[i];
-      llvm_ir::IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
-                                              index.GetType());
-      src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
-                                         builder);
-    };
-
-    ksl.If(absl::StrCat("slice", i), guarding_cond, emit_slice_elem_func);
-  }
-  return absl::OkStatus();
-}
-
-// Gets the input shape of the ROOT slices, which will be used as the kernel
-// launch dims. The slice input fusion requires the input shapes of the ROOT
-// slices to be the same although the (slice) output shapes can be different.
-//
-// Returns the input shape of the ROOT slices if all the input shapes of ROOT
-// slices are the same and the slices are non-strided. Otherwise, returns
-// FailedPrecondition.
-absl::StatusOr<Shape> GetConsistentInputShapeForRootSlices(
-    const HloComputation* fused_computation) {
-  const HloInstruction& root = *fused_computation->root_instruction();
-  if (root.opcode() == HloOpcode::kSlice) {
-    return root.operands()[0]->shape();
-  }
-
-  CHECK_EQ(root.opcode(), HloOpcode::kTuple);
-  const Shape& first_slice_operand_shape =
-      root.operands()[0]->operands()[0]->shape();
-  for (size_t i = 1; i < root.operands().size(); ++i) {
-    const HloInstruction* slice = root.operands()[i];
-    const Shape& operand_shape = slice->operands()[0]->shape();
-    if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
-                                             operand_shape)) {
-      return FailedPrecondition(
-          "Fused slices do not have the same input shape, fused computation = "
-          "%s.",
-          root.parent()->name());
-    }
-  }
-
-  return first_slice_operand_shape;
-}
-
-}  // namespace
-
-LaunchDimensions InputSlicesFusion::launch_dimensions() const {
-  const auto& root = analysis_.fusion_root(0).instruction();
-  const auto& shape = root.operand(0)->shape();
-  return CalculateLaunchDimensions(shape, analysis_.device_info(),
-                                   {unroll_factor_});
-}
-
-std::optional<IndexingMap> InputSlicesFusion::ComputeThreadIdToOutputIndexing(
-    int64_t output_id, mlir::MLIRContext* ctx) const {
-  // The mapping here is trivial and the same for all outputs - slice offsets
-  // are applied in the indexing from slice outputs to slice inputs.
-  auto launch_dims = launch_dimensions();
-  // The implementation requires the shapes and layouts to be the same, but we
-  // still use the requested output's shape for clarity.
-  const auto& shape = analysis_.fusion_root(output_id).shape();
-  return GetDefaultThreadIdIndexingMap(launch_dims, unroll_factor_, shape, ctx);
-}
-
-absl::Status InputSlicesFusion::EmitKernel(
-    IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
-    const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
-    std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder) const {
-  TF_ASSIGN_OR_RETURN(Shape element_shape,
-                      GetConsistentInputShapeForRootSlices(
-                          fusion.fused_instructions_computation()));
-  LaunchDimensionsConfig launch_config;
-  launch_config.unroll_factor = unroll_factor_;
-  GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
-  return ParallelLoopEmitter(
-             [&](const llvm_ir::IrArray::Index index) -> absl::Status {
-               return EmitElementForInputFusibleSlices(
-                   elemental_emitter, fusion.fused_instructions_computation(),
-                   inputs, outputs, index, builder);
-             },
-             element_shape, launch_dims, builder, launch_config)
-      .EmitLoop(
-          fusion.name(),
-          GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder));
-}
-
-}  // namespace gpu
-}  // namespace xla
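
The pseudo code near the top of the deleted input_slices.cc above is the core of the emitter: every thread visits one index of the common input shape, and each ROOT slice whose range contains that index gets a guarded write at index - start. A simplified host-side C++ sketch with hypothetical types (not the real IrArray/KernelSupportLibrary API):

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified slice description.
struct SliceSpec {
  std::vector<int64_t> starts;
  std::vector<int64_t> limits;    // exclusive upper bounds; strides must be 1
  std::vector<int64_t> out_dims;  // output shape of this slice
  float* output = nullptr;        // flat, row-major output buffer
};

// One "thread" of the guarded-write scheme: `index` ranges over the shared
// input shape; `input_values[i]` is the value of slice i's operand at `index`.
void WriteGuardedSliceElements(const std::vector<int64_t>& index,
                               const std::vector<float>& input_values,
                               const std::vector<SliceSpec>& slices) {
  for (size_t i = 0; i < slices.size(); ++i) {
    const SliceSpec& slice = slices[i];
    bool in_range = true;  // guarding_cond: start <= index < limit, per dim
    for (size_t dim = 0; dim < index.size(); ++dim) {
      in_range = in_range && index[dim] >= slice.starts[dim] &&
                 index[dim] < slice.limits[dim];
    }
    if (!in_range) continue;
    int64_t linear = 0;  // row-major linearization of (index - start)
    for (size_t dim = 0; dim < index.size(); ++dim) {
      linear = linear * slice.out_dims[dim] + (index[dim] - slice.starts[dim]);
    }
    slice.output[linear] = input_values[i];
  }
}
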
diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices.h b/third_party/xla/xla/service/gpu/fusions/input_slices.h
deleted file mode 100644
index fa80438..0000000
--- a/third_party/xla/xla/service/gpu/fusions/input_slices.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_
-#define XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/util.h"
-
-namespace xla {
-namespace gpu {
-
-// Generates code for input-fusible slices.
-//
-// Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes
-// of all ROOT slices need to be the same while their output shapes can be
-// different. On the other hand, the input ranges of the slices can be
-// overlapping. Further generalization/specialization can be added when the
-// need arises.
-class InputSlicesFusion : public KernelFusionEmitterBase {
- public:
-  explicit InputSlicesFusion(const HloFusionAnalysis& analysis)
-      : analysis_(analysis),
-        unroll_factor_(CeilOfRatio(
-            8, analysis.input_output_info().smallest_output_dtype_bits)) {}
-  LaunchDimensions launch_dimensions() const override;
-
-  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
-      int64_t output_id, mlir::MLIRContext* ctx) const override;
-
-  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
-      int64_t root_index, int64_t hero_operand_index,
-      mlir::MLIRContext* ctx) const override {
-    // TODO(b/319081342): Implement this.
-    return std::nullopt;
-  }
-
- protected:
-  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
-                          const HloFusionInstruction& fusion,
-                          const LaunchDimensions& launch_dims,
-                          std::vector<llvm_ir::IrArray> inputs,
-                          std::vector<llvm_ir::IrArray> outputs,
-                          llvm::IRBuilder<>* builder) const override;
-
- private:
-  const HloFusionAnalysis& analysis_;
-  const int unroll_factor_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_
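
The unroll factor chosen in the deleted constructor above is a ceiling division over the narrowest output element width, so byte-or-wider outputs are not unrolled at all. A small self-contained illustration (the real CeilOfRatio helper comes via the included xla/util.h):

#include <cstdint>

constexpr int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

static_assert(CeilOfRatio(8, 32) == 1, "f32 outputs: no extra unrolling");
static_assert(CeilOfRatio(8, 8) == 1, "s8/u8 outputs: still 1");
static_assert(CeilOfRatio(8, 4) == 2, "4-bit outputs: unroll by 2");
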
diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc
index 5297bf7..d2739e0 100644
--- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc
@@ -38,9 +38,9 @@
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
 #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/launch_dimensions.h"
 #include "xla/service/gpu/model/indexing_analysis.h"
diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc
index abeb57a..4aec42e 100644
--- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc
@@ -45,7 +45,7 @@
                     .value();
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
 
   auto emitter = GetEmitter(analysis);
 
diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc
deleted file mode 100644
index 689727a..0000000
--- a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc
+++ /dev/null
@@ -1,104 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/input_slices.h"
-
-#include <optional>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "mlir/IR/MLIRContext.h"
-#include "xla/service/gpu/fusions/fusions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/affine_map_printer.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class InputSlicesTest : public HloTestBase {
- public:
-  void SetUp() override {
-    HloTestBase::SetUp();
-    printer_ =
-        AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
-                         {"chunk_id", "unroll_id"});
-  }
-
- protected:
-  DebugOptions GetDebugOptionsForTest() override {
-    auto opts = HloTestBase::GetDebugOptionsForTest();
-    opts.set_xla_gpu_mlir_emitter_level(0);
-    return opts;
-  }
-  AffineMapPrinter printer_;
-  mlir::MLIRContext mlir_context_;
-};
-
-TEST_F(InputSlicesTest, ThreadIndexing) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    fused_computation {
-      %input = f32[2,3,5,7]{2,1,0,3} parameter(0)
-      slice0 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[1:3],[0:3],[2:7]}
-      slice1 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[0:2],[0:3],[2:7]}
-      ROOT tuple = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) tuple(slice0, slice1)
-    }
-
-    ENTRY entry {
-      %input = f32[2,3,5,7]{2,1,0,3} parameter(0)
-      ROOT %fusion = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) fusion(%input), kind=kLoop, calls=fused_computation
-    })")
-                    .value();
-
-  stream_executor::DeviceDescription device_info =
-      TestGpuDeviceInfo::RTXA6000DeviceInfo();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis_fused = AnalyzeFusion(*root, device_info);
-
-  auto emitter =
-      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
-  auto fusion = dynamic_cast<InputSlicesFusion*>(emitter.get());
-  ASSERT_NE(fusion, nullptr);
-
-  auto thread_id_to_output_indexing =
-      fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_);
-  EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
-              MatchIndexingString(R"(
-    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0,
-      ((bl_x * 128 + th_x) floordiv 3) mod 2,
-       (bl_x * 128 + th_x) mod 3,
-       (bl_x * 128 + th_x) floordiv 6)
-    domain:
-    th_x in [0, 127]
-    th_y in [0, 0]
-    th_z in [0, 0]
-    bl_x in [0, 1]
-    bl_y in [0, 0]
-    bl_z in [0, 0]
-    chunk_id in [0, 0]
-    unroll_id in [0, 0]
-    bl_x * 128 + th_x in [0, 29]
-  )"));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
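
For orientation, the expected map in the deleted test above is a minor-to-major delinearization of lin = bl_x * 128 + th_x into the slice output shape [1,2,3,5] following its {2,1,0,3} layout: dimension 2 (size 3) varies fastest, then dimension 1 (size 2), then dimension 0 (size 1), then dimension 3, with the constraint lin in [0, 29] keeping only the 30 elements each output actually has. A sketch of that arithmetic, assuming 128 threads per block:

#include <array>
#include <cstdint>

// Matches (0, (lin floordiv 3) mod 2, lin mod 3, lin floordiv 6) for
// lin = bl_x * 128 + th_x in [0, 29].
std::array<int64_t, 4> OutputIndex(int64_t bl_x, int64_t th_x) {
  const int64_t lin = bl_x * 128 + th_x;
  return {0, (lin / 3) % 2, lin % 3, lin / 6};
}
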
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/BUILD
new file mode 100644
index 0000000..8250669
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/BUILD
@@ -0,0 +1,157 @@
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//xla/tests:build_defs.bzl", "xla_test")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = [":friends"],
+    licenses = ["notice"],
+)
+
+package_group(
+    name = "friends",
+    includes = [
+        "//xla:friends",
+    ],
+)
+
+td_library(
+    name = "xla_gpu_td_files",
+    srcs = glob(["*.td"]),
+    includes = ["."],
+    deps = [
+        "@llvm-project//mlir:BuiltinDialectTdFiles",
+        "@llvm-project//mlir:CallInterfacesTdFiles",
+        "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
+        "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
+        "@llvm-project//mlir:OpBaseTdFiles",
+        "@llvm-project//mlir:SideEffectInterfacesTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "xla_gpu_dialect_inc_gen",
+    strip_include_prefix = ".",
+    tbl_outs = [
+        (
+            ["-gen-dialect-decls"],
+            "xla_gpu_dialect.h.inc",
+        ),
+        (
+            ["-gen-dialect-defs"],
+            "xla_gpu_dialect.cc.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "xla_gpu_dialect.td",
+    deps = [":xla_gpu_td_files"],
+)
+
+gentbl_cc_library(
+    name = "xla_gpu_ops_inc_gen",
+    strip_include_prefix = ".",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "xla_gpu_ops.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "xla_gpu_ops.cc.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "xla_gpu_ops.td",
+    deps = [":xla_gpu_td_files"],
+)
+
+gentbl_cc_library(
+    name = "xla_gpu_attrs_inc_gen",
+    strip_include_prefix = ".",
+    tbl_outs = [
+        (
+            [
+                "-gen-attrdef-decls",
+            ],
+            "xla_gpu_attrs.h.inc",
+        ),
+        (
+            [
+                "-gen-attrdef-defs",
+            ],
+            "xla_gpu_attrs.cc.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "xla_gpu_attrs.td",
+    deps = [":xla_gpu_td_files"],
+)
+
+gentbl_cc_library(
+    name = "xla_gpu_types_inc_gen",
+    strip_include_prefix = ".",
+    tbl_outs = [
+        (
+            [
+                "-gen-typedef-decls",
+                "-typedefs-dialect=xla_gpu",
+            ],
+            "xla_gpu_types.h.inc",
+        ),
+        (
+            [
+                "-gen-typedef-defs",
+                "-typedefs-dialect=xla_gpu",
+            ],
+            "xla_gpu_types.cc.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "xla_gpu_types.td",
+    deps = [":xla_gpu_td_files"],
+)
+
+cc_library(
+    name = "xla_gpu",
+    srcs = [
+        "xla_gpu_attrs.cc",
+        "xla_gpu_dialect.cc",
+        "xla_gpu_ops.cc",
+        "xla_gpu_types.cc",
+    ],
+    hdrs = [
+        "xla_gpu_ops.h",
+    ],
+    deps = [
+        ":xla_gpu_attrs_inc_gen",
+        ":xla_gpu_dialect_inc_gen",
+        ":xla_gpu_ops_inc_gen",
+        ":xla_gpu_types_inc_gen",
+        "//xla/service/gpu/model:indexing_analysis",
+        "@com_google_absl//absl/strings:str_format",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:BytecodeOpInterface",
+        "@llvm-project//mlir:CallOpInterfaces",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InferTypeOpInterface",
+        "@llvm-project//mlir:InliningUtils",
+        "@llvm-project//mlir:SideEffectInterfaces",
+        "@llvm-project//mlir:Support",
+    ],
+)
+
+xla_test(
+    name = "xla_gpu_ops_test",
+    srcs = ["xla_gpu_ops_test.cc"],
+    backends = ["gpu"],
+    deps = [
+        ":xla_gpu",
+        "//xla/service/gpu/model:indexing_analysis",
+        "//xla/service/gpu/model:indexing_test_utils",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:test",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD
new file mode 100644
index 0000000..381d5a3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD
@@ -0,0 +1,16 @@
+load("//xla:lit.bzl", "lit_test_suite")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    licenses = ["notice"],
+)
+
+lit_test_suite(
+    name = "tests",
+    srcs = glob(["*.mlir"]),
+    cfg = "//xla:lit.cfg.py",
+    tools = [
+        "//xla/service/gpu/fusions/tools:mlir_fusions_opt",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir
new file mode 100644
index 0000000..34065f9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir
@@ -0,0 +1,188 @@
+// RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s
+
+#map0 = #xla_gpu.indexing_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2),
+                                domain: s0 in [-10, 10], s1 in [0, 2]>
+func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) {
+  %0:2 = xla_gpu.apply_indexing #map0 [%s0, %s1]
+  func.return %0#0, %0#1 : index, index
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<()[s0] -> (s0 + 1, s0 mod 2),
+// CHECK-SAME:                     domain: s0 in [-10, 10]>
+
+// CHECK-LABEL: func.func @simplify_apply_indexing
+// CHECK-SAME:      %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index)
+// CHECK:         xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]]]
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2),
+                               domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3]>
+func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index,
+    %d2: index, %s0: index, %s1: index) -> (index, index, index) {
+  %0:3 = xla_gpu.apply_indexing #map0(%d0, %d1, %d2)[%s0, %s1]
+  func.return %0#0, %0#1, %0#2 : index, index, index
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0] -> (s0 + 1, s0 mod 2, d0 + d1),
+// CHECK-SAME:                     domain: d0 in [0, 1], d1 in [0, 3], s0 in [-11, 11]>
+
+// CHECK-LABEL: func.func @simplify_apply_indexing_remove_dims
+// CHECK-SAME:      %[[ARG_0:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME:      %[[ARG_1:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME:      %[[ARG_2:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME:      %[[ARG_3:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME:      %[[ARG_4:[a-zA-Z0-9_]+]]: index)
+// CHECK:       xla_gpu.apply_indexing #[[$MAP]]
+// CHECK-SAME:    (%[[ARG_0]], %[[ARG_2]])
+// CHECK-SAME:    [%[[ARG_3]]]
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0),
+                                domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]>
+func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index)
+    -> (index, index, index, index, index) {
+  %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0]
+  func.return %0#0, %0#1, %0#2, %0#3, %0#4  : index, index, index, index, index
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+
+// CHECK-LABEL: func.func @fold_indexing_map_results
+// CHECK-SAME:  %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index)
+
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
+
+// CHECK:       %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]]
+// CHECK:       return %[[NEW_RESULT]], %[[C4]], %[[ARG_1]], %[[C1]], %[[ARG_2]]
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0),
+                                domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]>
+func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) {
+  %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0]
+  func.return %0#2 : index
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 2),
+// CHECK-SAME:                     domain: d0 in [0, 2]>
+
+// CHECK-LABEL: func.func @remove_unused_results
+// CHECK-SAME:  %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index)
+
+// CHECK:       %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]])
+// CHECK:       return %[[NEW_RESULT]]
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3),
+                                domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4]>
+func.func @fold_operands(%d0: index) -> index {
+  %d1 = arith.constant 1 : index
+  %s0 = arith.constant 2 : index
+  %s1 = arith.constant 3 : index
+  %0 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0, %s1]
+  func.return %0 : index
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 3),
+// CHECK-SAME:                     domain: d0 in [0, 10]>
+
+// CHECK-LABEL: func.func @fold_operands
+// CHECK-SAME:      %[[ARG_0:.*]]: index)
+// CHECK:         xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]])
+
+// -----
+
+func.func @fold_operands_and_results(%arg0: index, %arg1: index)
+  -> (index, index) {
+  %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (0, d1),
+    domain: d0 in [0, 4], d1 in [0, 5]>(%arg0, %arg1)
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @fold_operands_and_results
+// CHECK-SAME:      %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index)
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: return %[[C0]], %[[ARG_1]] : index, index
+
+// -----
+
+func.func @fold_sequence(%arg0: index, %arg1: index) -> index {
+  %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1),
+    domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1)
+  %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 100 + 42),
+    domain: d0 in [0, 10000]>(%0)
+  func.return %1 : index
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42),
+// CHECK-SAME:                     domain: d0 in [0, 5], d1 in [0, 4]>
+// CHECK-LABEL: func.func @fold_sequence
+// CHECK-SAME:      %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT:  xla_gpu.apply_indexing #[[$MAP]]
+// CHECK-SAME:      (%[[ARG0]], %[[ARG1]])
+
+// -----
+
+func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index {
+  %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1),
+    domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1)
+  %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<()[s0] -> (s0 mod 100 + 42),
+    domain: s0 in [0, 10000]>(%0)
+  func.return %1 : index
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42),
+// CHECK-SAME:                     domain: d0 in [0, 5], d1 in [0, 4]>
+// CHECK-LABEL: func.func @fold_sequence_sym
+// CHECK-SAME:      %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT:  xla_gpu.apply_indexing #[[$MAP]]
+// CHECK-SAME:      (%[[ARG0]], %[[ARG1]])
+
+// -----
+
+func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index {
+  %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1),
+    domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1)
+  %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1),
+    domain: d0 in [0, 4], d1 in [0, 10000]>(%arg1, %0)
+  func.return %1 : index
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1),
+// CHECK-SAME:                     domain: d0 in [0, 4], d1 in [0, 5]>
+// CHECK-LABEL: func.func @fold_sequence_shared_operands
+// CHECK-SAME:      %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT:  xla_gpu.apply_indexing #[[$MAP]]
+// CHECK-SAME:      (%[[ARG1]], %[[ARG0]])
+
+// -----
+
+func.func @atomic_rmw_empty(%in: tensor<2x3xf32>, %i: index, %j: index)
+    -> (tensor<2x3xf32>) {
+  %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> {
+    ^bb0(%current : f32):
+      xla_gpu.yield %current : f32
+  }
+  return %ret : tensor<2x3xf32>
+}
+// CHECK-LABEL: func.func @atomic_rmw_empty
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x3xf32>
+// CHECK: return %[[ARG0]]
+
+
+// -----
+
+func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index)
+    -> (tensor<2x3xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> {
+    ^bb0(%current : f32):
+      xla_gpu.yield %cst : f32
+  }
+  return %ret : tensor<2x3xf32>
+}
+// CHECK-LABEL: func.func @atomic_rmw_cst
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x3xf32>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant
+// CHECK-NEXT: atomic_rmw
+// CHECK:      xla_gpu.yield %[[CST]]
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir
new file mode 100644
index 0000000..cd2e09f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir
@@ -0,0 +1,136 @@
+// RUN: mlir_fusions_opt %s -split-input-file | mlir_fusions_opt -split-input-file | FileCheck %s
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: (d0, d1, d2)[s0] -> (d0),
+// CHECK-SAME: domain:
+// CHECK-SAME: d0 in [1, 2],
+// CHECK-SAME: d1 in [5, 8],
+// CHECK-SAME: d2 in [10, 12],
+// CHECK-SAME: s0 in [0, 32],
+// CHECK-SAME: d0 mod 2 in [0, 1],
+// CHECK-SAME: d0 + s0 in [1, 10]
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0),
+                            domain:
+                            d0 in [1, 2],
+                            d1 in [5, 8],
+                            d2 in [10, 12],
+                            s0 in [0, 32],
+                            d0 mod 2 in [0, 1],
+                            d0 + s0 in [1, 10]
+                            >
+
+func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>)
+// CHECK-LABEL: @indexing_map_attr
+// CHECK: !xla_gpu.indexed_vector<64x64x32xf64, #[[$INDEX_MAP]]>
+
+// -----
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: (d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2)
+// CHECK-SAME: domain:
+// CHECK-SAME: d0 in [1, 2]
+// CHECK-SAME: d1 in [5, 8]
+// CHECK-SAME: s0 in [0, 10]
+// CHECK-SAME: s1 in [0, 5]
+// CHECK-SAME: s2 in [0, 32]
+// CHECK-SAME: d0 mod 2 in [0, 1]
+// CHECK-SAME: d0 + s0 in [1, 10]
+// CHECK-SAME: d1 + s1 + s2 in [1, 32]
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2),
+                            domain:
+                            d0 in [1, 2],
+                            d1 in [5, 8],
+                            s0 in [0, 10],
+                            s1 in [0, 5],
+                            s2 in [0, 32],
+                            d0 mod 2 in [0, 1],
+                            d0 + s0 in [1, 10],
+                            d1 + s1 + s2 in [1, 32]
+                            >
+func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>)
+// CHECK-LABEL: @more_range_vars
+// CHECK: !xla_gpu.indexed_vector<100x32xf64, #[[$INDEX_MAP]]>
+
+// -----
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: (d0)[s0] -> (d0)
+// CHECK-SAME: domain:
+// CHECK-SAME: d0 in [0, 100]
+// CHECK-SAME: s0 in [-3, -1]
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0),
+                            domain:
+                            d0 in [0, 100],
+                            s0 in [-3, -1]
+                            >
+func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>)
+// CHECK-LABEL: @indexing_map_small
+// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]>
+
+// -----
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: (d0, d1, d2)[s0] -> (d0)
+// CHECK-SAME: domain:
+// CHECK-SAME: d0 in [1, 2]
+// CHECK-SAME: d1 in [5, 8]
+// CHECK-SAME: d2 in [10, 12]
+// CHECK-SAME: s0 in [0, 32]
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0),
+                            domain:
+                            d0 in [1, 2],
+                            d1 in [5, 8],
+                            d2 in [10, 12],
+                            s0 in [0, 32]
+                            >
+func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>)
+// CHECK-LABEL: @no_constraints
+// CHECK: !xla_gpu.indexed_vector<32xf64, #[[$INDEX_MAP]]>
+
+// -----
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: ()[s0] -> (s0)
+// CHECK-SAME: domain:
+// CHECK-SAME: s0 in [3, 5]
+// CHECK-SAME: s0 mod 2 in [0, 1]
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<()[s0] -> (s0),
+                            domain:
+                            s0 in [3, 5],
+                            s0 mod 2 in [0, 1]
+                            >
+func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>)
+// CHECK-LABEL: @no_dimensions
+// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]>
+
+// -----
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: (d0) -> (d0)
+// CHECK-SAME: domain:
+// CHECK-SAME: d0 in [3, 5]
+// CHECK-SAME: d0 mod 2 in [0, 1]
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<(d0) -> (d0),
+                            domain:
+                            d0 in [3, 5],
+                            d0 mod 2 in [0, 1]
+                            >
+func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>)
+// CHECK-LABEL: @no_symbols
+// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]>
+
+// -----
+
+// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME: () -> ()
+// CHECK-SAME: >
+#map = #xla_gpu.indexing_map<() -> ()>
+func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>)
+// CHECK-LABEL: @empty
+// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]>
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/inlining.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/inlining.mlir
similarity index 100%
rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/inlining.mlir
rename to third_party/xla/xla/service/gpu/fusions/ir/tests/inlining.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir
new file mode 100644
index 0000000..999a6de
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir
@@ -0,0 +1,236 @@
+// RUN: mlir_fusions_opt  %s -split-input-file -verify-diagnostics
+
+#map0 = #xla_gpu.indexing_map<
+ (d0, d1)[s0] -> (d0, d1 + s0),
+ domain:
+ d0 in [1, 2],
+ d1 in [5, 8],
+ s0 in [0, 32]
+>
+func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) {
+  // expected-error @+1 {{operand count must match the number of dimensions and symbols in the affine map}}
+  %0:2 = xla_gpu.apply_indexing #map0 (%d0)
+  func.return %0#0, %0#1 : index, index
+}
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<
+ (d0, d1)[s0] -> (d0, d1 + s0),
+ domain:
+ d0 in [1, 2],
+ d1 in [5, 8],
+ s0 in [0, 32],
+ d0 mod 2 in [0, 1],
+ d0 + s0 in [1, 10]
+>
+func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index, index) {
+  // expected-error @+1 {{apply indexing op cannot have any constraints}}
+  %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0]
+  func.return %0#0, %0#1 : index, index
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]>
+func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>,
+    %init: f32) -> (f32) {
+  // expected-error @+1 {{mismatch in number of loop-carried values and results}}
+   %sum:2 = "xla_gpu.loop"(%init) <{
+      indexing_map_attr = #map,
+      operandSegmentSizes = array<i32: 0, 1>
+    }> ({
+    ^bb0(%i: index, %j: index, %iter: f32):
+      %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+      %add = arith.addf %iter, %t : f32
+      xla_gpu.yield %add : f32
+    }) : (f32) -> (f32, f32)
+  func.return %sum#0 : f32
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<()[s0] -> (s0, s0), domain: s0 in [0, 1024]>
+func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>,
+    %init: f32) -> (f32) {
+  // expected-error @+1 {{mismatch in number of induction variables 2 and RangeVars}}
+   %sum = "xla_gpu.loop"(%init) <{
+      indexing_map_attr = #map,
+      operandSegmentSizes = array<i32: 0, 1>
+    }> ({
+    ^bb0(%i: index, %j: index, %iter: f32):
+      %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+      %add = arith.addf %iter, %t : f32
+      xla_gpu.yield %add : f32
+    }) : (f32) -> (f32)
+  func.return %sum : f32
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]>
+func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) {
+  // expected-error @+1 {{block iter arg type = 'f32', result type = 'i32' and init operand type = 'f32' should match}}
+   %sum = "xla_gpu.loop"(%init) <{
+      indexing_map_attr = #map,
+      operandSegmentSizes = array<i32: 0, 1>
+    }> ({
+    ^bb0(%i: index, %j: index, %iter: f32):
+      %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+      %add = arith.addf %iter, %t : f32
+      xla_gpu.yield %add : f32
+    }) : (f32) -> (i32)
+  func.return %sum : i32
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]>
+func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) {
+  // expected-error @+1 {{mismatch in number of dims operands 0 and DimVars in the indexing map}}
+  %sum = xla_gpu.loop ()[%i, %j] in #map iter_args(%sum_ = %init) -> (f32) {
+    %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+    %add = arith.addf %sum_, %t : f32
+    xla_gpu.yield %add : f32
+  } {xla.range = [0 : index, 42 : index]}
+  func.return %sum : f32
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]>
+func.func @indices_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map> {
+  // expected-error @+1 {{number of indices must match number of dimensions of indexing map}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]>
+func.func @no_thread_id_in(%input: tensor<32x64xf32>, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{must have thread_id dimension in both indexing maps}}
+  %0 = xla_gpu.materialize @exp(%input) at #map() : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]>
+#map1 = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]>
+func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{must have thread_id dimension in both indexing maps}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32]>
+func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{thread_id dimension must have the same bounds in both indexing maps}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]>
+func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{constraints of indexing maps must be equal for the thread_id dimension}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0, s0), domain: d0 in [0, 32], s0 in [0, 1024]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]>
+func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{number of symbols in both indexing_maps must match}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]>
+func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{domain of symbols of indexing_maps must match}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]>
+func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{constraints of indexing maps must be equal for all symbols}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]>
+func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{constraints of indexing maps must be equal for all symbols}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]>
+func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{constraints of indexing maps must be equal for all symbols}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]>
+#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64]>
+func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{vector mapping indices must not depend on the block ID}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]>
+#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]>
+func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{constraints of indexing maps must be equal for the block_id dimension}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id, %block_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]>
+#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]>
+func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{constraints of indexing maps must be equal for the block_id dimension}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]>
+#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 4 in [0, 0]>
+func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> {
+  // expected-error @+1 {{constraints of indexing maps must be equal for the block_id dimension}}
+  %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id, %block_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1>
+}
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir
new file mode 100644
index 0000000..c4378a3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir
@@ -0,0 +1,165 @@
+// RUN: mlir_fusions_opt %s --split-input-file | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir_fusions_opt %s --split-input-file | mlir_fusions_opt --split-input-file | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir_fusions_opt %s --split-input-file --mlir-print-op-generic | mlir_fusions_opt --split-input-file | FileCheck %s
+
+func.func @shared_and_sync() -> (tensor<2xf32>, tensor<2xf32>) {
+  %shared1 = xla_gpu.allocate_shared : tensor<2xf32>
+  %shared2 = xla_gpu.allocate_shared : tensor<2xf32>
+  %sync:2 = xla_gpu.sync_threads %shared1, %shared2
+    : tensor<2xf32>, tensor<2xf32>
+  return %sync#0, %sync#1 : tensor<2xf32>, tensor<2xf32>
+}
+// CHECK-LABEL: @shared_and_sync
+// CHECK-NEXT: allocate_shared
+// CHECK-NEXT: allocate_shared
+// CHECK-NEXT: sync_threads
+// CHECK-NEXT: return
+
+// -----
+
+func.func @atomic_rmw(%in: tensor<2x3xf32>, %i: index, %j: index)
+    -> (tensor<2x3xf32>) {
+  %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> {
+    ^bb0(%current : f32):
+      %c42 = arith.constant 42.0 : f32
+      %add = arith.addf %current, %c42 : f32
+      xla_gpu.yield %add : f32
+  }
+  return %ret : tensor<2x3xf32>
+}
+// CHECK-LABEL: @atomic_rmw
+// CHECK: xla_gpu.atomic_rmw
+
+// -----
+
+func.func private @add(%a: f32, %b: f32) -> f32 {
+  %ret = arith.addf %a, %b : f32
+  return %ret : f32
+}
+
+func.func @caller(%a: f32, %b: f32) -> f32 {
+  %c = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32)
+  %d = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32)
+  %ret = arith.addf %c, %d : f32
+  return %ret : f32
+}
+// CHECK-LABEL: @caller
+// CHECK: %[[C:.*]] = xla_gpu.pure_call @add
+// CHECK: %[[D:.*]] = xla_gpu.pure_call @add
+// CHECK: arith.addf %[[C]], %[[D]]
+
+// CHECK-CSE: @caller
+// CHECK-CSE: %[[C:.*]] = xla_gpu.pure_call @add
+// CHECK-CSE: arith.addf %[[C]], %[[C]]
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<
+(d0, d1)[s0] -> (d0, d1 + s0),
+ domain:
+ d0 in [1, 2],
+ d1 in [5, 8],
+ s0 in [0, 32]
+>
+func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) {
+  %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0]
+  func.return %0#0, %0#1 : index, index
+}
+// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME:              (d0, d1)[s0] -> (d0, d1 + s0)
+// CHECK-SAME:              domain:
+// CHECK-SAME:              d0 in [1, 2]
+// CHECK-SAME:              d1 in [5, 8]
+// CHECK-SAME:              s0 in [0, 32]
+// CHECK-SAME:             >
+
+// CHECK-LABEL: @apply_indexing
+// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index, %[[s0:.*]]: index)
+// CHECK: xla_gpu.apply_indexing #[[$MAP0]]
+// CHECK-SAME:  (%[[d0]], %[[d1]])[%[[s0]]]
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<
+(d0, d1) -> (d0, d1),
+ domain:
+ d0 in [0, 2],
+ d1 in [1, 3]
+>
+func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) {
+  %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)
+  func.return %0#0, %0#1 : index, index
+}
+// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME:              (d0, d1) -> (d0, d1)
+// CHECK-SAME:              domain:
+// CHECK-SAME:              d0 in [0, 2]
+// CHECK-SAME:              d1 in [1, 3]
+// CHECK-SAME:             >
+
+// CHECK-LABEL: @apply_indexing_no_symbols
+// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index)
+// CHECK: xla_gpu.apply_indexing #[[$MAP0]]
+// CHECK-SAME:  (%[[d0]], %[[d1]])
+
+// -----
+
+#map0 = #xla_gpu.indexing_map<
+ ()[s0] -> (s0, s0),
+ domain:
+ s0 in [2, 4]
+>
+func.func @apply_indexing_no_dims(%s0: index) -> (index, index) {
+  %0:2 = xla_gpu.apply_indexing #map0 [%s0]
+  func.return %0#0, %0#1 : index, index
+}
+// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<
+// CHECK-SAME:              ()[s0] -> (s0, s0)
+// CHECK-SAME:              domain:
+// CHECK-SAME:              s0 in [2, 4]
+// CHECK-SAME:             >
+
+// CHECK-LABEL: @apply_indexing_no_dims
+// CHECK: (%[[s0:.*]]: index)
+// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]]]
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]>
+func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) {
+  %sum = xla_gpu.loop (%dim)[%i, %j] in #map iter_args(%sum_ = %init) -> (f32) {
+    %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+    %add = arith.addf %sum_, %t : f32
+    xla_gpu.yield %add : f32
+  } {xla.range = [0 : index, 42 : index]}
+  func.return %sum : f32
+}
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map
+// CHECK:      %0 = xla_gpu.loop (%{{.*}})[%[[I:.*]], %[[J:.*]]] in #[[$MAP]]
+// CHECK-SAME:     iter_args(%[[SUM_ITER:.*]] = %{{.*}}) -> (f32) {
+// CHECK:        %[[EXTRACTED:.*]] = tensor.extract %{{.*}}[%[[I]], %[[J]]]
+// CHECK:        %[[ADD:.*]] = arith.addf %{{.*}}, %[[EXTRACTED]] : f32
+// CHECK:        xla_gpu.yield %[[ADD]] : f32
+// CHECK:      } {xla.range = [0 : index, 42 : index]}
+
+// -----
+
+func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32
+
+#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]>
+#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]>
+func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> {
+  %0 = xla_gpu.materialize @exp(%input) at #map(%i, %j) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1>
+  %1 = xla_gpu.insert %0 into %output at #map1(%i, %j) : !xla_gpu.indexed_vector<32x64xf32, #map1> -> tensor<32x64xf32> into tensor<32x64xf32>
+  func.return %1 : tensor<32x64xf32>
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1)
+// CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]>
+// CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1)
+// CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]>
+// CHECK-LABEL: @materialize_and_insert
+// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @exp(%{{.*}}) at #[[$MAP]](%{{.*}}, %{{.*}})
+// CHECK: xla_gpu.insert %[[MATERIALIZED]] into %{{.*}} at #[[$MAP1]](%{{.*}}, %{{.*}})
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc
new file mode 100644
index 0000000..a3220b0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc
@@ -0,0 +1,231 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <utility>
+
+#include "absl/strings/str_format.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "xla/service/gpu/model/indexing_map.h"
+
+#define GET_ATTRDEF_LIST
+#define GET_ATTRDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc"
+
+namespace xla {
+namespace gpu {
+
+using llvm::ParseResult;
+using llvm::SmallVector;
+using mlir::AffineExpr;
+using mlir::ArrayRef;
+using mlir::AsmParser;
+using mlir::AsmPrinter;
+using mlir::failure;
+using mlir::success;
+
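+// Parses an interval of the form `[lower, upper]`, e.g. `[0, 42]`.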
+ParseResult ParseInterval(AsmParser& parser, Interval& interval) {
+  // ParseResult converts to `true` if parsing failed.
+  return failure(parser.parseLSquare() || parser.parseInteger(interval.lower) ||
+                 parser.parseComma() || parser.parseInteger(interval.upper) ||
+                 parser.parseRSquare());
+}
+
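+// Prints the dimension variables as a comma-separated list of
+// `d<i> in <interval>` entries, e.g. `d0 in [0, 32], d1 in [0, 2]` as printed
+// in the ops.mlir tests in this change.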
+void PrintDimVars(AsmPrinter& p, ArrayRef<DimVar> dim_vars) {
+  int index = 0;
+  llvm::interleaveComma(dim_vars, p, [&](const DimVar& dim_var) {
+    p << "d" << index++ << " in " << dim_var.bounds;
+  });
+}
+
+ParseResult ParseDimVars(AsmParser& parser, ArrayRef<std::string> dim_names,
+                         SmallVector<DimVar>& dim_vars) {
+  dim_vars.reserve(dim_names.size());
+  for (const auto& [index, dim_name] : llvm::enumerate(dim_names)) {
+    if (parser.parseKeyword(dim_name) || parser.parseKeyword("in") ||
+        ParseInterval(parser, dim_vars.emplace_back().bounds)) {
+      return failure();
+    }
+    if (index < dim_names.size() - 1 && parser.parseComma()) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+void PrintRangeVars(AsmPrinter& p, ArrayRef<RangeVar> range_vars) {
+  int index = 0;
+  llvm::interleaveComma(range_vars, p, [&](const RangeVar& range_var) {
+    p << "s" << index++ << " in " << range_var.range;
+  });
+}
+
+ParseResult ParseRangeVars(AsmParser& parser,
+                           ArrayRef<std::string> range_symbol_names,
+                           SmallVector<RangeVar>& range_vars) {
+  range_vars.reserve(range_symbol_names.size());
+  for (const auto& [index, range_symbol_name] :
+       llvm::enumerate(range_symbol_names)) {
+    if (parser.parseKeyword(range_symbol_name) || parser.parseKeyword("in") ||
+        ParseInterval(parser, range_vars.emplace_back().range)) {
+      return failure();
+    }
+    if (index < range_symbol_names.size() - 1 && parser.parseComma()) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+void PrintConstraints(AsmPrinter& p,
+                      ArrayRef<std::pair<AffineExpr, Interval>> constraints) {
+  llvm::interleaveComma(constraints, p, [&](const auto& constraint) {
+    p << constraint.first << " in " << constraint.second;
+  });
+}
+
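+// Parses an optional trailing constraint list, e.g.
+// `, s0 mod 2 in [0, 0], s0 + s1 in [0, 1024]` as used in the invalid.mlir
+// tests in this change. Each constraint is introduced by a comma.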
+ParseResult ParseConstraints(
+    AsmParser& parser,
+    ArrayRef<std::pair<llvm::StringRef, AffineExpr>> symbolSet,
+    SmallVector<std::pair<AffineExpr, Interval>>& constraints) {
+  // Constraints can only be present if there is at least one symbol or
+  // dimension, so every constraint in the list is preceded by a comma; keep
+  // parsing for as long as a comma is found.
+  while (succeeded(parser.parseOptionalComma())) {
+    auto& constraint = constraints.emplace_back();
+    if (parser.parseAffineExpr(symbolSet, constraint.first) ||
+        parser.parseKeyword("in") || ParseInterval(parser, constraint.second)) {
+      return failure();
+    }
+  }
+  return success();
+}
+
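+// Parses the full textual form of the attribute, e.g.
+//   #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0, s0),
+//                         domain: d0 in [0, 32], s0 in [0, 1024]>
+// as exercised by the ops.mlir and invalid.mlir tests in this change.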
+mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) {
+  mlir::AffineMap map;
+  if (parser.parseLess() || parser.parseAffineMap(map)) {
+    return {};
+  }
+
+  // Store owning strings so the StringRefs in symbolSet stay valid for the
+  // duration of ParseConstraints.
+  SmallVector<std::string> dim_strings(map.getNumDims());
+  SmallVector<std::string> symbol_strings(map.getNumSymbols());
+  SmallVector<std::pair<llvm::StringRef, AffineExpr>> symbolSet;
+  symbolSet.reserve(map.getNumDims() + map.getNumSymbols());
+  for (int i = 0; i < map.getNumDims(); ++i) {
+    dim_strings[i] = absl::StrFormat("d%d", i);
+    symbolSet.push_back(
+        {dim_strings[i], mlir::getAffineDimExpr(i, parser.getContext())});
+  }
+  for (int i = 0; i < map.getNumSymbols(); ++i) {
+    symbol_strings[i] = absl::StrFormat("s%d", i);
+    symbolSet.push_back(
+        {symbol_strings[i], mlir::getAffineSymbolExpr(i, parser.getContext())});
+  }
+  if (map.getNumDims() + map.getNumSymbols() > 0) {
+    if (parser.parseComma() || parser.parseKeyword("domain") ||
+        parser.parseColon()) {
+      return {};
+    }
+  }
+
+  SmallVector<DimVar> dim_vars;
+  if (map.getNumDims() > 0) {
+    if (ParseDimVars(parser, dim_strings, dim_vars)) {
+      return {};
+    }
+  }
+
+  SmallVector<RangeVar> range_vars;
+  if (map.getNumSymbols() > 0) {
+    if (!dim_vars.empty() && parser.parseComma()) {
+      return {};
+    }
+    if (ParseRangeVars(parser, symbol_strings, range_vars)) {
+      return {};
+    }
+  }
+
+  SmallVector<std::pair<AffineExpr, Interval>> constraints;
+  if (ParseConstraints(parser, symbolSet, constraints) ||
+      parser.parseGreater()) {
+    return {};
+  }
+  return IndexingMapAttr::get(parser.getContext(), map, dim_vars, range_vars,
+                              constraints);
+}
+
+void IndexingMapAttr::print(mlir::AsmPrinter& printer) const {
+  printer << "<";
+  printer.printStrippedAttrOrType(getMap());
+  if (getDimVars().size() + getRangeVars().size() + getConstraints().size() >
+      0) {
+    printer << ", domain: ";
+  }
+  PrintDimVars(printer, getDimVars());
+  if (!getDimVars().empty() &&
+      getRangeVars().size() + getConstraints().size() > 0) {
+    printer << ", ";
+  }
+  PrintRangeVars(printer, getRangeVars());
+  if (!getRangeVars().empty() && !getConstraints().empty()) {
+    printer << ", ";
+  }
+  PrintConstraints(printer, getConstraints());
+  printer << ">";
+}
+
+IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context,
+                                     const IndexingMap& indexing_map) {
+  llvm::SmallVector<std::pair<AffineExpr, Interval>> constraints;
+  for (auto& constraint : indexing_map.GetConstraints()) {
+    constraints.push_back({constraint.first, constraint.second});
+  }
+  return get(context, indexing_map.GetAffineMap(), indexing_map.GetDimVars(),
+             indexing_map.GetRangeVars(), constraints);
+}
+
+mlir::LogicalResult IndexingMapAttr::verify(
+    mlir::function_ref<mlir::InFlightDiagnostic()> emitError,
+    mlir::AffineMap map, ArrayRef<DimVar> dim_vars,
+    ArrayRef<RangeVar> range_vars,
+    ArrayRef<std::pair<AffineExpr, Interval>> constraints) {
+  if (map.getNumDims() != dim_vars.size()) {
+    return emitError()
+           << "dim size must match the number of dimensions in the affine map";
+  }
+  if (map.getNumSymbols() != range_vars.size()) {
+    return emitError()
+           << "range size must match the number of symbols in the affine map";
+  }
+  return mlir::success();
+}
+
+IndexingMap IndexingMapAttr::getIndexingMap() {
+  return IndexingMap(getMap(), getDimVars(), getRangeVars(), /*rt_vars=*/{},
+                     getConstraints());
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td
new file mode 100644
index 0000000..19dd24f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td
@@ -0,0 +1,64 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS
+#define XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS
+
+include "mlir/IR/AttrTypeBase.td"
+include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td"
+
+class XLAGPU_Attr<string name, list<Trait> traits = []> :
+      AttrDef<XlaGpuDialect, name, traits> {
+}
+
+def XLAGPU_AffineMapParameter :
+    AttrOrTypeParameter<"::mlir::AffineMap", ""> {
+}
+
+def XLAGPU_DimVarsParameter : ArrayRefParameter<"::xla::gpu::DimVar",
+                                                "DimVarArray"> {
+}
+
+def XLAGPU_RangeVarsParameter : ArrayRefParameter<"::xla::gpu::RangeVar",
+                                                  "RangeVarArray"> {
+}
+
+def XLAGPU_ConstraintsParameter :
+    ArrayRefParameter<"::std::pair<::mlir::AffineExpr, ::xla::gpu::Interval>",
+                      "ConstraintsArray"> {
+}
+
+def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> {
+  let summary = "An Attribute representing an indexing map.";
+  let mnemonic = "indexing_map";
+  let description = [{This attribute stores an indexing map. See
+    https://openxla.org/xla/indexing for more details.
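+
+    Example (as exercised by the ops.mlir tests in this change):
+      #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0, d1 + s0),
+        domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32]>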
+  }];
+  let parameters = (ins XLAGPU_AffineMapParameter:$map,
+                        XLAGPU_DimVarsParameter:$dim_vars,
+                        XLAGPU_RangeVarsParameter:$range_vars,
+                        XLAGPU_ConstraintsParameter:$constraints);
+  let hasCustomAssemblyFormat = 1;
+  let builders =  [
+    AttrBuilder<(ins "const ::xla::gpu::IndexingMap&":$indexing_map)>,
+  ];
+  let genVerifyDecl = 1;
+  let extraClassDeclaration = [{
+    // Returns the indexing map constructed from IndexingMapAttr.
+    xla::gpu::IndexingMap getIndexingMap();
+  }];
+}
+
+#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc
new file mode 100644
index 0000000..57d2d70
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc
@@ -0,0 +1,136 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/ADT/TypeSwitch.h"  // IWYU pragma: keep
+#include "mlir/IR/DialectImplementation.h"  // IWYU pragma: keep
+#include "mlir/IR/OpImplementation.h"  // IWYU pragma: keep
+#include "mlir/Transforms/InliningUtils.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#define GET_ATTRDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc"
+#undef GET_ATTRDEF_CLASSES
+#define GET_TYPEDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_types.cc.inc"
+#undef GET_TYPEDEF_CLASSES
+
+namespace xla {
+namespace gpu {
+namespace {
+
+struct XlaGpuInlinerInterface : public mlir::DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+  // Returns true if the given operation 'callable', that implements the
+  // 'CallableOpInterface', can be inlined into the position given by the call
+  // operation 'call', that is registered to the current dialect and implements
+  // the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the
+  // given 'callable' is set to be cloned during the inlining process, or false
+  // if the region is set to be moved in-place (i.e. no duplicates would be
+  // created).
+  bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable,
+                       bool wouldBeCloned) const final {
+    if (!wouldBeCloned) {
+      // If no duplicate would be created, 'call' is likely the only caller of
+      // 'callable'.
+      return true;
+    }
+    // Otherwise, inline only if the called function is small. We could
+    // theoretically also inline if there is no other caller in the function
+    // that contains the callee that has a call path to the callable, but that
+    // is more expensive to check.
+    auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(callable);
+    if (!func_op) {
+      return false;
+    }
+    auto region = func_op.getCallableRegion();
+    if (!region) {
+      return false;
+    }
+
+    // If callee and caller call the same third function, inline. We have no
+    // guarantee that the indices are the same, but there is a good chance they
+    // are (or if the callee gets inlined as well, there will be CSE
+    // opportunities).
+    // This is duct tape to work around the limitations of our partitioner.
+    // Ideally, the partitioner would be aware of the actual indexing and create
+    // the partitions based on it (i.e., the case where the indices are the same
+    // would never happen).
+    llvm::SmallDenseSet<llvm::StringRef> callee_calls;
+    for (auto call : region->getOps<PureCallOp>()) {
+      callee_calls.insert(call.getCallee());
+    }
+    for (auto call : call->getParentRegion()->getOps<PureCallOp>()) {
+      if (callee_calls.contains(call.getCallee())) {
+        return true;
+      }
+    }
+
+    constexpr int kMaxOperationsToInline = 8;
+    int num_ops = 0;
+    region->front().walk([&](mlir::Operation* op) { ++num_ops; });
+
+    // Otherwise, inline only if the callee is small (at most
+    // kMaxOperationsToInline operations).
+    return num_ops <= kMaxOperationsToInline;
+  }
+  // Returns true if the given operation 'op', that is registered to this
+  // dialect, can be inlined into the given region, false otherwise.
+  // 'wouldBeCloned' is set to true if the given 'op' is set to be cloned
+  // during the inlining process, or false if the operation is set to be moved
+  // in-place (i.e. no duplicates would be created). 'valueMapping' contains any
+  // remapped values from within the 'src' region. This can be used to examine
+  // what values may potentially replace the operands to 'op'.
+  bool isLegalToInline(mlir::Operation* op, mlir::Region* dest,
+                       bool wouldBeCloned,
+                       mlir::IRMapping& valueMapping) const final {
+    // We allow any op from the xla_gpu dialect to be inlined.
+    return true;
+  }
+};
+
+struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface {
+  using OpAsmDialectInterface::OpAsmDialectInterface;
+  AliasResult getAlias(mlir::Attribute attr,
+                       mlir::raw_ostream& os) const final {
+    if (llvm::isa<IndexingMapAttr>(attr)) {
+      os << "indexing_map";
+      return AliasResult::FinalAlias;
+    }
+    return AliasResult::NoAlias;
+  }
+};
+
+}  // namespace
+
+void XlaGpuDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.cc.inc"
+#undef GET_OP_LIST
+      >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc"
+      >();
+#undef GET_ATTRDEF_LIST
+  addInterfaces<XlaGpuInlinerInterface, XlaGpuOpAsmDialectInterface>();
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "xla/service/gpu/fusions/ir/xla_gpu_types.cc.inc"
+#undef GET_TYPEDEF_LIST
+      >();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td
new file mode 100644
index 0000000..9a5c539
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td
@@ -0,0 +1,33 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT
+#define XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT
+
+include "mlir/IR/DialectBase.td"
+
+def XlaGpuDialect : Dialect {
+  let name = "xla_gpu";
+
+  let description = [{
+    This dialect contains ops required for lowering HLO to LLVM.
+  }];
+
+  let cppNamespace = "::xla::gpu";
+  let useDefaultAttributePrinterParser = 1;
+  let useDefaultTypePrinterParser = 1;
+}
+
+#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc
new file mode 100644
index 0000000..bc59b7c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc
@@ -0,0 +1,841 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+
+#include <cstdint>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/TypeSwitch.h"  // IWYU pragma: keep
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"  // IWYU pragma: keep
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/DialectImplementation.h"  // IWYU pragma: keep
+#include "mlir/IR/MLIRContext.h"  // IWYU pragma: keep
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"  // IWYU pragma: keep
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/TypeUtilities.h"  // IWYU pragma: keep
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.cc.inc"
+#include "xla/service/gpu/model/indexing_map.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using llvm::ArrayRef;
+using mlir::AffineExpr;
+using mlir::AffineMap;
+using mlir::Block;
+using mlir::failure;
+using mlir::getAffineConstantExpr;
+using mlir::getAffineDimExpr;
+using mlir::getAffineSymbolExpr;
+using mlir::LogicalResult;
+using mlir::MLIRContext;
+using mlir::OpAsmParser;
+using mlir::OpAsmPrinter;
+using mlir::OpBuilder;
+using mlir::OperationState;
+using mlir::PatternRewriter;
+using mlir::RankedTensorType;
+using mlir::Region;
+using mlir::SmallVector;
+using mlir::success;
+using mlir::Type;
+using mlir::TypeRange;
+using mlir::Value;
+using mlir::ValueRange;
+
+namespace arith = mlir::arith;
+
+}  // namespace
+
+LogicalResult PureCallOp::verifySymbolUses(
+    mlir::SymbolTableCollection& symbolTable) {
+  auto callee = getCalleeAttr();
+  auto function =
+      symbolTable.lookupNearestSymbolFrom<mlir::func::FuncOp>(*this, callee);
+  if (!function) {
+    return emitError("'callee' attribute refers to an undefined function: ")
+           << callee;
+  }
+
+  int func_arg_count = function.getFunctionType().getNumInputs();
+  int arg_count = getOperands().size();
+
+  if (arg_count != func_arg_count) {
+    return emitError() << "argument count mismatch: 'operands' has "
+                       << arg_count << " arguments, but '" << callee
+                       << "' expects " << func_arg_count;
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AllocateSharedOp
+//===----------------------------------------------------------------------===//
+
+void AllocateSharedOp::getAsmResultNames(
+    llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
+  setNameFn(getResult(), "shmem");
+}
+
+//===----------------------------------------------------------------------===//
+// ApplyIndexingOp
+//===----------------------------------------------------------------------===//
+
+void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
+                            ValueRange dims, ValueRange symbols,
+                            const IndexingMap& indexing_map) {
+  SmallVector<Value, 4> operands;
+  operands.reserve(dims.size() + symbols.size());
+  operands.append(dims.begin(), dims.end());
+  operands.append(symbols.begin(), symbols.end());
+  build(builder, result, operands, indexing_map);
+}
+
+void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
+                            ValueRange operands,
+                            const IndexingMap& indexing_map) {
+  SmallVector<Type, 2> result_types(indexing_map.GetAffineMap().getNumResults(),
+                                    builder.getIndexType());
+  IndexingMapAttr indexing_map_attr =
+      IndexingMapAttr::get(builder.getContext(), indexing_map);
+  build(builder, result, result_types, operands, indexing_map_attr);
+}
+
+void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
+                            ValueRange operands, AffineMap affine_map,
+                            ArrayRef<DimVar> dim_vars,
+                            ArrayRef<RangeVar> range_vars) {
+  IndexingMap indexing_map(affine_map, dim_vars, range_vars, {});
+  build(builder, result, operands, indexing_map);
+}
+
+// Parses a comma-separated list of operands, ex: %d1, %d2.
+mlir::ParseResult parseOperands(
+    OpAsmParser& parser,
+    SmallVector<OpAsmParser::UnresolvedOperand, 4>* operands) {
+  return parser.parseCommaSeparatedList(
+      [&]() { return parser.parseOperand(operands->emplace_back()); });
+}
+
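+// Custom assembly format, e.g. (from the ops.mlir tests in this change):
+//   %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0]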
+mlir::ParseResult ApplyIndexingOp::parse(OpAsmParser& parser,
+                                         OperationState& result) {
+  mlir::Builder& builder = parser.getBuilder();
+  auto index_type = builder.getIndexType();
+
+  IndexingMapAttr indexing_map_attr;
+  if (parser.parseAttribute(indexing_map_attr, "indexing_map_attr",
+                            result.attributes)) {
+    return failure();
+  }
+
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (parseOperands(parser, &operands) || parser.parseRParen()) {
+      return failure();
+    }
+  }
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (parseOperands(parser, &operands) || parser.parseRSquare()) {
+      return failure();
+    }
+  }
+  if (parser.resolveOperands(operands, index_type, result.operands) ||
+      parser.parseOptionalAttrDict(result.attributes)) {
+    return failure();
+  }
+  auto map = indexing_map_attr.getMap();
+  result.addTypes(SmallVector<Type, 2>(map.getNumResults(), index_type));
+  return success();
+}
+
+void ApplyIndexingOp::print(OpAsmPrinter& p) {
+  AffineMap affine_map = getIndexingMapAttr().getMap();
+  p << " " << getIndexingMapAttr();
+
+  auto operands = getOperands();
+  unsigned num_dimensions = affine_map.getNumDims();
+  if (num_dimensions > 0) {
+    p << '(';
+    auto dimension_operands = operands.slice(0, num_dimensions);
+    llvm::interleaveComma(dimension_operands, p);
+    p << ')';
+  }
+
+  unsigned num_symbols = affine_map.getNumSymbols();
+  if (num_symbols > 0) {
+    p << '[';
+    auto symbol_operands = operands.slice(num_dimensions, num_symbols);
+    llvm::interleaveComma(symbol_operands, p);
+    p << ']';
+  }
+
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          /*elidedAttrs=*/{"indexing_map_attr"});
+}
+
+LogicalResult ApplyIndexingOp::verify() {
+  auto affine_map = getIndexingMapAttr().getMap();
+  unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols();
+  if (getOperands().size() != num_variables) {
+    return emitOpError(
+        "operand count must match the number of dimensions and symbols in the "
+        "affine map");
+  }
+  if (!getIndexingMapAttr().getConstraints().empty()) {
+    return emitOpError("apply indexing op cannot have any constraints");
+  }
+  return success();
+}
+
+IndexingMap ApplyIndexingOp::getIndexingMap() {
+  return getIndexingMapAttr().getIndexingMap();
+}
+
+namespace {
+
+// Simplifies the indexing map and removes unused variables.
+struct SimplifyIndexingMap : public mlir::OpRewritePattern<ApplyIndexingOp> {
+  using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
+                                PatternRewriter& rewriter) const override {
+    IndexingMap indexing_map = indexing_op.getIndexingMap();
+    bool is_simplified = indexing_map.Simplify();
+
+    // Remove unused symbols.
+    auto unused_symbols_bit_vector = indexing_map.RemoveUnusedVars();
+    bool symbols_removed = unused_symbols_bit_vector.count() != 0;
+
+    if (!is_simplified && !symbols_removed) {
+      return rewriter.notifyMatchFailure(indexing_op,
+                                         "IndexingMap stayed unchanged");
+    }
+    if (!unused_symbols_bit_vector.empty()) {
+      SmallVector<Value, 4> operands;
+      operands.reserve(unused_symbols_bit_vector.count());
+      for (int i = 0; i < unused_symbols_bit_vector.size(); ++i) {
+        if (!unused_symbols_bit_vector[i]) {
+          operands.push_back(indexing_op.getOperand(i));
+        }
+      }
+      rewriter.replaceOpWithNewOp<ApplyIndexingOp>(indexing_op, operands,
+                                                   indexing_map);
+    } else {
+      rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
+          indexing_op, indexing_op.getOperands(), indexing_map);
+    }
+    return success();
+  }
+};
+
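+// Composes chains of apply_indexing ops: when an operand is itself produced
+// by an apply_indexing, the producer's result expression is substituted into
+// this op's affine map and the producer's operands are appended as new
+// dimensions/symbols.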
+struct FoldApplyIndexingSequence
+    : public mlir::OpRewritePattern<ApplyIndexingOp> {
+  using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
+                                PatternRewriter& rewriter) const override {
+    MLIRContext* ctx = indexing_op.getContext();
+    int num_dims = indexing_op.getAffineMap().getNumDims();
+    int num_syms = indexing_op.getAffineMap().getNumSymbols();
+    mlir::DenseMap<Value, AffineExpr> operand_exprs;
+    for (auto& operand : indexing_op->getOpOperands()) {
+      int operand_number = operand.getOperandNumber();
+      operand_exprs[operand.get()] =
+          operand_number < num_dims
+              ? getAffineDimExpr(operand_number, ctx)
+              : getAffineSymbolExpr(operand_number - num_dims, ctx);
+    }
+
+    auto this_map = indexing_op.getIndexingMap();
+
+    SmallVector<Value> added_dim_args;
+    SmallVector<Value> added_sym_args;
+    auto new_dim_vars = this_map.GetDimVars();
+    auto new_sym_vars = this_map.GetRangeVars();
+
+    mlir::DenseMap<AffineExpr, AffineExpr> replacements;
+    for (auto& operand : indexing_op->getOpOperands()) {
+      if (auto producer = operand.get().getDefiningOp<ApplyIndexingOp>()) {
+        auto producer_map = producer.getIndexingMap();
+        int producer_result_id =
+            mlir::cast<mlir::OpResult>(operand.get()).getResultNumber();
+        int num_producer_dims = producer.getAffineMap().getNumDims();
+        SmallVector<AffineExpr> producer_dim_replacements;
+        SmallVector<AffineExpr> producer_sym_replacements;
+        for (auto& producer_operand : producer->getOpOperands()) {
+          int producer_operand_number = producer_operand.getOperandNumber();
+          bool is_dim = producer_operand_number < num_producer_dims;
+          auto& replacement_expr = operand_exprs[producer_operand.get()];
+          if (!replacement_expr) {
+            if (is_dim) {
+              int dim_num = producer_operand_number;
+              replacement_expr =
+                  getAffineDimExpr(num_dims + added_dim_args.size(), ctx);
+              added_dim_args.push_back(producer_operand.get());
+              new_dim_vars.push_back(producer_map.GetDimVars(dim_num));
+            } else {
+              int sym_num = producer_operand_number -
+                            producer.getAffineMap().getNumDims();
+              replacement_expr =
+                  getAffineSymbolExpr(num_syms + added_sym_args.size(), ctx);
+              added_sym_args.push_back(producer_operand.get());
+              new_sym_vars.push_back(producer_map.GetRangeVar(sym_num));
+            }
+          }
+
+          if (is_dim) {
+            producer_dim_replacements.push_back(replacement_expr);
+          } else {
+            producer_sym_replacements.push_back(replacement_expr);
+          }
+        }
+
+        replacements[operand_exprs[operand.get()]] =
+            producer.getAffineMap()
+                .getResult(producer_result_id)
+                .replaceDimsAndSymbols(producer_dim_replacements,
+                                       producer_sym_replacements);
+      }
+    }
+
+    if (replacements.empty()) {
+      return rewriter.notifyMatchFailure(indexing_op,
+                                         "No apply_indexing sequences found");
+    }
+
+    int new_num_operands = indexing_op->getNumOperands() +
+                           added_dim_args.size() + added_sym_args.size();
+    auto new_affine_map = indexing_op.getAffineMap().replace(
+        replacements, num_dims + added_dim_args.size(),
+        num_syms + added_sym_args.size());
+    IndexingMap new_indexing_map(new_affine_map, new_dim_vars, new_sym_vars,
+                                 /*rt_vars=*/{});
+    if (!new_indexing_map.Simplify()) {
+      return rewriter.notifyMatchFailure(
+          indexing_op, "Folded indexing map was not simplified");
+    }
+    SmallVector<Value> new_operands;
+    new_operands.reserve(new_num_operands);
+
+    auto begin = indexing_op.getOperands().begin();
+    new_operands.append(begin, begin + num_dims);
+    new_operands.append(added_dim_args);
+    new_operands.append(begin + num_dims, begin + num_dims + num_syms);
+    new_operands.append(added_sym_args);
+
+    rewriter.replaceOpWithNewOp<ApplyIndexingOp>(indexing_op, new_operands,
+                                                 new_indexing_map);
+    return success();
+  }
+};
+
+// Folds constants into the indexing map.
+struct FoldApplyIndexingOperands
+    : public mlir::OpRewritePattern<ApplyIndexingOp> {
+  using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
+                                PatternRewriter& rewriter) const override {
+    IndexingMap indexing_map = indexing_op.getIndexingMap();
+    AffineMap affine_map = indexing_map.GetAffineMap();
+
+    MLIRContext* ctx = affine_map.getContext();
+    unsigned num_operands = indexing_op->getNumOperands();
+    unsigned num_dims = affine_map.getNumDims();
+    unsigned num_symbols = affine_map.getNumSymbols();
+
+    SmallVector<std::optional<int64_t>> constant_values(num_operands,
+                                                        std::nullopt);
+    int num_constants = 0;
+    for (auto& operand : indexing_op->getOpOperands()) {
+      if (auto constant =
+              operand.get().getDefiningOp<arith::ConstantIndexOp>()) {
+        constant_values[operand.getOperandNumber()] = constant.value();
+        ++num_constants;
+      }
+    }
+    if (num_constants == 0) {
+      return rewriter.notifyMatchFailure(indexing_op,
+                                         "No constant operands found");
+    }
+    SmallVector<AffineExpr, 2> dim_replacements, symbol_replacements;
+    dim_replacements.reserve(num_dims);
+    symbol_replacements.reserve(num_symbols);
+
+    unsigned new_num_operands = indexing_op->getNumOperands() - num_constants;
+    SmallVector<Value, 4> new_operands;
+    new_operands.reserve(new_num_operands);
+    SmallVector<DimVar, 2> new_dim_vars;
+    new_dim_vars.reserve(num_dims);
+    SmallVector<RangeVar, 2> new_range_vars;
+    new_range_vars.reserve(num_symbols);
+
+    unsigned new_num_dims = 0;
+    unsigned new_num_symbols = 0;
+    for (auto [operand, constant_value] :
+         llvm::zip(indexing_op->getOpOperands(), constant_values)) {
+      unsigned operand_id = operand.getOperandNumber();
+      if (constant_value.has_value()) {
+        if (operand_id < num_dims) {
+          dim_replacements.push_back(
+              getAffineConstantExpr(*constant_value, ctx));
+        } else {
+          symbol_replacements.push_back(
+              getAffineConstantExpr(*constant_value, ctx));
+        }
+      } else {
+        new_operands.push_back(operand.get());
+        if (operand_id < num_dims) {
+          dim_replacements.push_back(getAffineDimExpr(new_num_dims++, ctx));
+          new_dim_vars.push_back(indexing_map.GetDimVars(operand_id));
+        } else {
+          symbol_replacements.push_back(
+              getAffineSymbolExpr(new_num_symbols++, ctx));
+          new_range_vars.push_back(
+              indexing_map.GetRangeVar(operand_id - num_dims));
+        }
+      }
+    }
+    rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
+        indexing_op, new_operands,
+        affine_map.replaceDimsAndSymbols(dim_replacements, symbol_replacements,
+                                         new_num_dims, new_num_symbols),
+        new_dim_vars, new_range_vars);
+    return success();
+  }
+};
+
+// Folds constant and dim/symbol expression results.
+struct FoldApplyIndexingResults
+    : public mlir::OpRewritePattern<ApplyIndexingOp> {
+  using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
+                                PatternRewriter& rewriter) const override {
+    mlir::Location loc = indexing_op.getLoc();
+    IndexingMap indexing_map = indexing_op.getIndexingMap();
+    if (indexing_map.IsKnownEmpty()) {
+      return rewriter.notifyMatchFailure(indexing_op,
+                                         "Domain of the indexing map is empty");
+    }
+    AffineMap* affine_map = &indexing_map.GetMutableAffineMap();
+    unsigned num_results = affine_map->getNumResults();
+    SmallVector<AffineExpr, 4> new_exprs;
+    new_exprs.reserve(num_results);
+    SmallVector<Value, 4> new_values;
+    new_values.reserve(num_results);
+    for (mlir::OpResult opresult : indexing_op->getOpResults()) {
+      if (opresult.use_empty()) {
+        new_values.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
+        continue;
+      }
+
+      unsigned id = opresult.getResultNumber();
+      AffineExpr result_expr = affine_map->getResult(id);
+      if (auto const_expr =
+              mlir::dyn_cast<mlir::AffineConstantExpr>(result_expr)) {
+        new_values.push_back(rewriter.create<arith::ConstantIndexOp>(
+            loc, const_expr.getValue()));
+        continue;
+      }
+      if (auto dim_expr = mlir::dyn_cast<mlir::AffineDimExpr>(result_expr)) {
+        new_values.push_back(indexing_op.getOperand(dim_expr.getPosition()));
+        continue;
+      }
+      if (auto symbol_expr =
+              mlir::dyn_cast<mlir::AffineSymbolExpr>(result_expr)) {
+        new_values.push_back(indexing_op.getOperand(
+            indexing_map.GetDimVarsCount() + symbol_expr.getPosition()));
+        continue;
+      }
+      new_exprs.push_back(result_expr);
+      new_values.push_back(Value{});
+    }
+    if (new_exprs.size() == num_results) {
+      return rewriter.notifyMatchFailure(
+          indexing_op, "No constant or dim/symbol expression found");
+    }
+    *affine_map =
+        AffineMap::get(affine_map->getNumDims(), affine_map->getNumSymbols(),
+                       new_exprs, affine_map->getContext());
+    auto new_indexing_op = rewriter.create<ApplyIndexingOp>(
+        loc, indexing_op.getOperands(), indexing_map);
+    for (int new_result_id = 0, new_indexing_op_result_id = 0;
+         new_result_id < new_values.size(); ++new_result_id) {
+      auto& new_value = new_values[new_result_id];
+      if (new_value) continue;
+      new_value = new_indexing_op.getResult(new_indexing_op_result_id++);
+    }
+    rewriter.replaceOp(indexing_op, new_values);
+    return success();
+  }
+};
+
+}  // namespace
+
+void ApplyIndexingOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet& results, MLIRContext* context) {
+  results.add<FoldApplyIndexingOperands, FoldApplyIndexingResults,
+              SimplifyIndexingMap, FoldApplyIndexingSequence>(context);
+}
+
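+// Folds an apply_indexing op whose results are all plain dimension or symbol
+// expressions by forwarding the corresponding operands directly.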
+mlir::LogicalResult ApplyIndexingOp::fold(
+    FoldAdaptor adaptor, llvm::SmallVectorImpl<mlir::OpFoldResult>& results) {
+  auto map = getAffineMap();
+  for (auto expr : map.getResults()) {
+    if (auto dim = mlir::dyn_cast<mlir::AffineDimExpr>(expr)) {
+      results.push_back(getOperand(dim.getPosition()));
+    } else if (auto sym = mlir::dyn_cast<mlir::AffineSymbolExpr>(expr)) {
+      results.push_back(getOperand(map.getNumDims() + sym.getPosition()));
+    } else {
+      results.clear();
+      return failure();
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+void AtomicRMWOp::getAsmResultNames(
+    llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
+  setNameFn(getResult(), "atomic_rmw");
+}
+
+void AtomicRMWOp::build(OpBuilder& builder, OperationState& result,
+                        Value tensor, ValueRange ivs) {
+  OpBuilder::InsertionGuard g(builder);
+  result.addOperands(tensor);
+  result.addOperands(ivs);
+  result.addTypes(tensor.getType());
+
+  auto tensor_type = llvm::cast<RankedTensorType>(tensor.getType());
+  Region* body = result.addRegion();
+  builder.createBlock(body);
+  body->addArgument(tensor_type.getElementType(), tensor.getLoc());
+}
+
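+// Folds away an atomic_rmw whose body only yields the current value (a no-op
+// update); the original tensor operand is returned unchanged.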
+mlir::OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
+  auto* body = getBody();
+  if (&body->front() == body->getTerminator() &&
+      body->front().getOperand(0) == body->getArgument(0)) {
+    return getOperand(0);
+  }
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// PureCallOp
+//===----------------------------------------------------------------------===//
+
+void PureCallOp::getAsmResultNames(
+    llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
+  for (auto result : getResults()) {
+    setNameFn(result, "pure_call");
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// SyncThreadsOp
+//===----------------------------------------------------------------------===//
+
+void SyncThreadsOp::getAsmResultNames(
+    llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
+  for (auto result : getResults()) {
+    setNameFn(result, "synced_tensor");
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+void LoopOp::build(OpBuilder& builder, OperationState& result,
+                   IndexingMapAttr indexing_map_attr, ValueRange dims,
+                   ValueRange inits, BodyBuilderFn bodyBuilder) {
+  OpBuilder::InsertionGuard guard(builder);
+
+  int64_t num_ivs = indexing_map_attr.getRangeVars().size();
+  result.addOperands(dims);
+  result.addOperands(inits);
+  result.addTypes(TypeRange(inits));
+  Block* body_block = builder.createBlock(result.addRegion());
+  // Add induction variables block args.
+  for (int i = 0; i < num_ivs; ++i) {
+    body_block->addArgument(builder.getIndexType(), result.location);
+  }
+  // Add iteration arguments block args.
+  for (auto init_type : TypeRange(inits)) {
+    body_block->addArguments(init_type, result.location);
+  }
+
+  mlir::OperationName opname(LoopOp::getOperationName(), builder.getContext());
+  result.addAttribute(LoopOp::getIndexingMapAttrAttrName(opname),
+                      indexing_map_attr);
+  result.addAttribute(
+      LoopOp::getOperandSegmentSizesAttrName(opname),
+      builder.getDenseI32ArrayAttr({static_cast<int32_t>(dims.size()),
+                                    static_cast<int32_t>(inits.size())}));
+  if (bodyBuilder) {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(body_block);
+    bodyBuilder(builder, result.location,
+                body_block->getArguments().take_front(num_ivs),
+                body_block->getArguments().drop_front(num_ivs));
+  }
+}
+
+void LoopOp::build(OpBuilder& builder, OperationState& result,
+                   const IndexingMap& indexing_map, ValueRange dims,
+                   ValueRange inits, BodyBuilderFn bodyBuilder) {
+  build(builder, result,
+        IndexingMapAttr::get(builder.getContext(), indexing_map), dims, inits,
+        bodyBuilder);
+}
+
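+// Custom assembly format, e.g. (from the ops.mlir tests in this change):
+//   %sum = xla_gpu.loop (%dim)[%i, %j] in #map
+//       iter_args(%sum_ = %init) -> (f32) {
+//     ...
+//     xla_gpu.yield %add : f32
+//   }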
+mlir::ParseResult LoopOp::parse(OpAsmParser& parser, OperationState& result) {
+  SmallVector<OpAsmParser::Argument, 4> region_args, ivs, iter_args;
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> dim_operands;
+
+  // Parse the dimension values.
+  OpBuilder b(parser.getContext());
+  Type index_type = b.getIndexType();
+  if (parser.parseOperandList(dim_operands, OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(dim_operands, index_type, result.operands))
+    return failure();
+  // Parse the induction variables.
+  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Square))
+    return failure();
+  for (auto iv : ivs) {
+    region_args.push_back(iv);
+    region_args.back().type = index_type;
+  }
+
+  // Parse the indexing map attribute.
+  IndexingMapAttr indexing_map_attr;
+  if (parser.parseKeyword("in") ||
+      parser.parseAttribute(indexing_map_attr, "indexing_map_attr",
+                            result.attributes)) {
+    return failure();
+  }
+
+  // Parse the arguments.
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> init_operands;
+  if (parser.parseKeyword("iter_args") ||
+      parser.parseAssignmentList(iter_args, init_operands) ||
+      parser.parseArrowTypeList(result.types) ||
+      parser.resolveOperands(init_operands, result.types, parser.getNameLoc(),
+                             result.operands))
+    return failure();
+
+  for (auto [index, iter_arg] : llvm::enumerate(iter_args)) {
+    region_args.push_back(iter_arg);
+    region_args.back().type = result.types[index];
+  }
+
+  if (region_args.size() != result.types.size() + ivs.size()) {
+    return parser.emitError(parser.getNameLoc(),
+                            "mismatch in number of induction variables + "
+                            "loop-carried values and the number of results");
+  }
+
+  // Parse the body region.
+  Region* body = result.addRegion();
+  if (parser.parseRegion(*body, region_args)) return failure();
+  LoopOp::ensureTerminator(*body, b, result.location);
+
+  // Add the operand segment sizes, then parse the optional attribute list.
+  result.addAttribute(
+      LoopOp::getOperandSegmentSizeAttr(),
+      b.getDenseI32ArrayAttr({static_cast<int32_t>(dim_operands.size()),
+                              static_cast<int32_t>(iter_args.size())}));
+  if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+
+  return success();
+}
+
+void LoopOp::print(OpAsmPrinter& p) {
+  p << " (" << getDims() << ")[" << getInductionVars() << "] in "
+    << getIndexingMapAttr() << " iter_args(";
+  llvm::interleaveComma(
+      llvm::zip(getRegionIterArgs(), getInits()), p,
+      [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); });
+  p << ") -> (" << getInits().getTypes() << ") ";
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/true);
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          /*elidedAttrs=*/{
+                              getIndexingMapAttrAttrName(),
+                              getOperandSegmentSizesAttrName(),
+                          });
+}
+
+LogicalResult LoopOp::verify() {
+  if (getInits().size() != getNumResults()) {
+    return emitOpError("mismatch in number of loop-carried values and results");
+  }
+  IndexingMap indexing_map = getIndexingMap();
+  if (indexing_map.GetRangeVarsCount() != getNumInductionVars()) {
+    return emitOpError() << "mismatch in number of induction variables "
+                         << getNumInductionVars()
+                         << " and RangeVars in the indexing map "
+                         << indexing_map.ToString();
+  }
+  if (indexing_map.GetDimVarsCount() != getDims().size()) {
+    return emitOpError() << "mismatch in number of dims operands "
+                         << getDims().size()
+                         << " and DimVars in the indexing map "
+                         << indexing_map.ToString();
+  }
+  for (auto [bb_arg, result_type, init] :
+       llvm::zip(getRegionIterArgs(), getResultTypes(), getInits())) {
+    if (bb_arg.getType() != result_type || init.getType() != result_type) {
+      return emitOpError() << "block iter arg type = " << bb_arg.getType()
+                           << ", result type = " << result_type
+                           << " and init operand type = " << init.getType()
+                           << " should match";
+    }
+  }
+  return success();
+}
+
+IndexingMap LoopOp::getIndexingMap() {
+  return getIndexingMapAttr().getIndexingMap();
+}
+
+//===----------------------------------------------------------------------===//
+// MaterializeOp
+//===----------------------------------------------------------------------===//
+
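+// Returns, for each dimension and symbol of `map`, the constraints whose
+// affine expression mentions that dimension or symbol.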
+VariableConstraints GetConstraintsForVariables(const IndexingMap& map) {
+  VariableConstraints result;
+  result.constraints_for_dims.resize(map.GetDimensionCount());
+  result.constraints_for_symbols.resize(map.GetSymbolCount());
+  for (const auto& constraint : map.GetConstraints()) {
+    constraint.first.walk([&](mlir::AffineExpr leaf) {
+      if (auto dim = mlir::dyn_cast<mlir::AffineDimExpr>(leaf)) {
+        result.constraints_for_dims[dim.getPosition()].push_back(constraint);
+      } else if (auto sym = mlir::dyn_cast<mlir::AffineSymbolExpr>(leaf)) {
+        result.constraints_for_symbols[sym.getPosition()].push_back(constraint);
+      }
+    });
+  }
+  return result;
+}
+
+LogicalResult MaterializeOp::verify() {
+  IndexingMap map_in = getMap().getIndexingMap();
+  IndexingMap map_out =
+      getResult().getType().getIndexingMapAttr().getIndexingMap();
+  if (getIndices().size() != map_in.GetDimVarsCount()) {
+    return emitOpError() << "number of indices must match number of dimensions "
+                            "of indexing map";
+  }
+
+  // The thread dimension must have the same domain (range and constraints).
+  if (map_in.GetDimVarsCount() == 0 || map_out.GetDimVarsCount() == 0) {
+    return emitOpError()
+           << "must have thread_id dimension in both indexing maps";
+  }
+  if (map_in.GetDimVars(0) != map_out.GetDimVars(0)) {
+    return emitOpError() << "thread_id dimension must have the same bounds in "
+                            "both indexing maps";
+  }
+
+  auto variable_constraints_in = GetConstraintsForVariables(map_in);
+  auto variable_constraints_out = GetConstraintsForVariables(map_out);
+  if (variable_constraints_in.constraints_for_dims[0] !=
+      variable_constraints_out.constraints_for_dims[0]) {
+    return emitOpError() << "constraints of indexing maps must be equal for "
+                         << "the thread_id dimension";
+  }
+
+  // The two maps must have the same symbols and their domains must match.
+  if (map_in.GetRangeVarsCount() != map_out.GetRangeVarsCount()) {
+    return emitOpError()
+           << "number of symbols in both indexing_maps must match";
+  }
+  for (auto const& [range_in, range_out] :
+       llvm::zip(map_in.GetRangeVars(), map_out.GetRangeVars())) {
+    if (range_in.range != range_out.range) {
+      return emitOpError() << "domain of symbols of indexing_maps must match";
+    }
+  }
+  if (variable_constraints_in.constraints_for_symbols !=
+      variable_constraints_out.constraints_for_symbols) {
+    return emitOpError()
+           << "constraints of indexing maps must be equal for all symbols";
+  }
+
+  // The vector mapping indices must not depend on the block ID.
+  if (map_out.GetDimVarsCount() > 1) {
+    for (auto expr : map_out.GetAffineMap().getResults()) {
+      if (expr.isFunctionOfDim(1)) {
+        return emitOpError() << "vector mapping indices must not depend on the "
+                             << "block ID";
+      }
+    }
+  }
+  // If there are constraints on the block ID, they must be the same in both
+  // maps.
+  if (map_in.GetDimVarsCount() > 1 && map_out.GetDimVarsCount() > 1) {
+    if (variable_constraints_in.constraints_for_dims[1] !=
+        variable_constraints_out.constraints_for_dims[1]) {
+      return emitOpError() << "constraints of indexing maps must be equal for "
+                           << "the block_id dimension";
+    }
+  } else if (map_in.GetDimVarsCount() > 1 &&
+             !variable_constraints_in.constraints_for_dims[1].empty()) {
+    return emitOpError() << "constraints of indexing maps must be equal for "
+                         << "the block_id dimension";
+  } else if (map_out.GetDimVarsCount() > 1 &&
+             !variable_constraints_out.constraints_for_dims[1].empty()) {
+    return emitOpError() << "constraints of indexing maps must be equal for "
+                         << "the block_id dimension";
+  }
+
+  return success();
+}
+
+}  // namespace gpu
+}  // namespace xla
+
+#define GET_OP_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.cc.inc"
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h
new file mode 100644
index 0000000..e3b8fd6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h
@@ -0,0 +1,56 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_
+#define XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_
+
+#include <utility>
+
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Bytecode/BytecodeOpInterface.h"  // IWYU pragma: keep
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // IWYU pragma: keep
+#include "mlir/IR/Attributes.h"  // IWYU pragma: keep
+#include "mlir/IR/BuiltinTypes.h"  // IWYU pragma: keep
+#include "mlir/IR/Dialect.h"  // IWYU pragma: keep
+#include "mlir/IR/MLIRContext.h"  // IWYU pragma: keep
+#include "mlir/IR/OpDefinition.h"  // IWYU pragma: keep
+#include "mlir/IR/OpImplementation.h"  // IWYU pragma: keep
+#include "mlir/Interfaces/CallInterfaces.h"  // IWYU pragma: keep
+#include "mlir/Interfaces/InferTypeOpInterface.h"  // IWYU pragma: keep
+#include "mlir/Interfaces/SideEffectInterfaces.h"  // IWYU pragma: keep
+#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc"
+#include "xla/service/gpu/model/indexing_map.h"  // IWYU pragma: keep
+#define GET_ATTRDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc"
+#undef GET_ATTRDEF_CLASSES
+#define GET_TYPEDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_types.h.inc"
+#undef GET_TYPEDEF_CLASSES
+#define GET_OP_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h.inc"
+#undef GET_OP_CLASSES
+
+namespace xla::gpu {
+
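+// The constraints of an indexing map, grouped by the dimensions and symbols
+// their affine expressions refer to.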
+struct VariableConstraints {
+  llvm::SmallVector<llvm::SmallVector<std::pair<mlir::AffineExpr, Interval>>>
+      constraints_for_dims;
+  llvm::SmallVector<llvm::SmallVector<std::pair<mlir::AffineExpr, Interval>>>
+      constraints_for_symbols;
+};
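+// Groups the constraints of `map` by the dimensions and symbols they mention.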
+VariableConstraints GetConstraintsForVariables(const IndexingMap& map);
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td
new file mode 100644
index 0000000..9eb246f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td
@@ -0,0 +1,367 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_OPS
+#define XLA_SERVICE_GPU_FUSIONS_MLIR_OPS
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td"
+include "xla/service/gpu/fusions/ir/xla_gpu_attrs.td"
+include "xla/service/gpu/fusions/ir/xla_gpu_types.td"
+
+class XLAGPU_Op<string mnemonic, list<Trait> traits = []> :
+      Op<XlaGpuDialect, mnemonic, traits> {
+}
+
+def XLAGPU_AllocateSharedOp : XLAGPU_Op<"allocate_shared", [
+      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Allocates a shared memory tile.";
+
+  let description = [{
+    Allocates a shared memory tensor. The tensor is shared among all threads in
+    a block.
+
+    ```mlir
+    %shared = xla_gpu.allocate_shared : tensor<32x32xf32>
+    ```
+  }];
+
+  let results = (outs AnyStaticShapeTensor:$result);
+
+  let assemblyFormat = "attr-dict `:` type($result)";
+}
+
+def XLAGPU_SyncThreadsOp : XLAGPU_Op<"sync_threads", [
+   TypesMatchWith<"result types match operand types",
+                   "operands", "results", "$_self">,
+   DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Synchronizes threads.";
+
+  let description = [{
+    Synchronizes threads, taking any number of distributed tensors and returning
+    the synchronized state.
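+
+    An illustrative use (the tensor value `%shared` is hypothetical):
+
+    ```mlir
+    %synced = xla_gpu.sync_threads %shared : tensor<32x32xf32>
+    ```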
+  }];
+
+  let arguments = (ins Variadic<AnyRankedTensor>:$operands);
+  let results = (outs Variadic<AnyRankedTensor>:$results);
+
+  let assemblyFormat = "operands attr-dict `:` type($operands)";
+}
+
+def XLAGPU_AtomicRMWOp : XLAGPU_Op<"atomic_rmw",
+    [Pure,
+     TypesMatchWith<"result type matches type of input",
+                    "input", "result", "$_self">,
+     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Atomically updates an element of a tensor.";
+
+  let description = [{
+    Reads an element from a tensor, computes the updated value for it, and
+    writes back the result.
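+
+    A sketch of the expected form (the values and the region body are
+    illustrative):
+
+    ```mlir
+    %new_tensor = xla_gpu.atomic_rmw %tensor[%i, %j] : tensor<32x32xf32> {
+      ^bb0(%current : f32):
+        %updated = arith.addf %current, %cst : f32
+        xla_gpu.yield %updated : f32
+    }
+    ```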
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input, Variadic<Index>:$indices);
+  let results = (outs AnyRankedTensor:$result);
+  // The region takes the current value in the tensor as an argument and yields
+  // the updated value.
+  let regions = (region SizedRegion<1>:$computation);
+
+  let skipDefaultBuilders = 1;
+  let builders = [OpBuilder<(ins "mlir::Value":$memref, "mlir::ValueRange":$ivs)>];
+
+  let extraClassDeclaration = [{
+    mlir::Block* getBody() { return &getComputation().front(); }
+    mlir::OpBuilder getBodyBuilder() {
+      return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
+    }
+    // The value stored in tensor[ivs].
+    mlir::Value getCurrentValue() {
+      return getRegion().getArgument(0);
+    }
+  }];
+  let hasFolder = 1;
+
+  let assemblyFormat = [{
+    $input `[` $indices `]` `:` type($input) $computation attr-dict
+  }];
+}
+
+def XLAGPU_YieldOp : XLAGPU_Op<"yield", [
+    ParentOneOf<["::xla::gpu::AtomicRMWOp", "::xla::gpu::LoopOp"]>,
+    ReturnLike, Terminator]> {
+  let summary = "Terminator for atomic_rmw ops.";
+  let arguments = (ins AnyType:$result);
+
+  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+  let assemblyFormat = "$result attr-dict `:` type($result)";
+}
+
+def XLAGPU_PureCallOp : XLAGPU_Op<"pure_call",
+    [Pure, CallOpInterface, DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Function call without side effects.";
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>);
+  let builders = [
+    OpBuilder<(ins "mlir::func::FuncOp":$callee, CArg<"mlir::ValueRange", "{}">:$operands), [{
+      $_state.addOperands(operands);
+      $_state.addAttribute("callee", mlir::SymbolRefAttr::get(callee));
+      $_state.addTypes(callee.getFunctionType().getResults());
+    }]>];
+  let assemblyFormat = [{
+    $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    operand_range getArgOperands() {
+      return getOperands();
+    }
+
+    mlir::MutableOperandRange getArgOperandsMutable() {
+      return getOperandsMutable();
+    }
+
+    mlir::CallInterfaceCallable getCallableForCallee() {
+      return (*this)->getAttrOfType<mlir::SymbolRefAttr>("callee");
+    }
+
+    void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<mlir::SymbolRefAttr>());
+    }
+  }];
+}
+
+def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce",
+    [Pure,
+     TypesMatchWith<"result type matches type of operands",
+                     "operands", "results", "$_self">]> {
+  let summary = "Performs a full warp shuffle and reduces the values";
+  let description = [{
+    This op performs a full warp shuffle and reduces the results using the given
+    function. The function is invoked with the operands from the low lanes,
+    followed by the operands from the high lanes. For example:
+
+    ```
+    shuffle_reduce @argmax(%value, %idx) : (f32, index)
+    ```
+
+    This will perform shuffles with distances 16, 8, 4, 2 and 1, invoking
+    @argmax five times. The first invocation will be
+
+    ```
+    @argmax(%value[i], %idx[i], %value[16+i], %idx[16+i])
+    ```
+  }];
+  let builders = [
+    OpBuilder<(ins "mlir::func::FuncOp":$reducer, "mlir::ValueRange":$operands, "int64_t":$max_distance), [{
+      $_state.addOperands(operands);
+      $_state.addAttribute("reducer", mlir::SymbolRefAttr::get(reducer));
+      $_state.addAttribute("max_distance",
+          mlir::IntegerAttr::get(
+              mlir::IntegerType::get(reducer.getContext(), 64),
+              max_distance));
+      $_state.addTypes(reducer.getFunctionType().getResults());
+    }]>];
+  let arguments = (ins FlatSymbolRefAttr:$reducer,
+                   Variadic<AnyType>:$operands,
+                   I64Attr:$max_distance);
+  let results = (outs Variadic<AnyType>:$results);
+
+  let assemblyFormat = [{
+    $reducer `(` $operands `)` `to` $max_distance attr-dict `:` type($operands)
+  }];
+}
+
+def XLAGPU_PredicatedInsertOp : XLAGPU_Op<"predicated_insert",
+    [Pure,
+     TypesMatchWith<"result type matches type of dest",
+                    "dest", "result", "$_self">,
+     TypesMatchWith<"value type matches element type of dest",
+                    "dest", "value",
+                    "::llvm::cast<mlir::TensorType>($_self).getElementType()">]> {
+  let summary = "Inserts a value into a tensor if a condition holds";
+  let arguments = (ins I1:$condition, AnyType:$value,
+    AnyStaticShapeTensor:$dest, Variadic<Index>:$indices);
+  let results = (outs AnyStaticShapeTensor:$result);
+
+  let assemblyFormat = [{
+     $value `into` $dest `[` $indices `]` `if` $condition attr-dict `:` type($dest)
+  }];
+}
+
+def XLAGPU_PredicatedExtractOp : XLAGPU_Op<"predicated_extract",
+    [Pure,
+     TypesMatchWith<"fallback type matches element type of src",
+                    "src", "fallback",
+                    "::llvm::cast<mlir::TensorType>($_self).getElementType()">,
+     TypesMatchWith<"result type matches element type of src",
+                    "src", "result",
+                    "::llvm::cast<mlir::TensorType>($_self).getElementType()">]> {
+  let summary = "Inserts a value into a tensor if a condition holds";
+  let arguments = (ins I1:$condition, AnyType:$fallback,
+    AnyStaticShapeTensor:$src, Variadic<Index>:$indices);
+  let results = (outs AnyType:$result);
+
+  let assemblyFormat = [{
+    $src `[` $indices `]` `if` $condition `else` $fallback attr-dict `:` type($src)
+  }];
+}
+
+def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> {
+  let summary = "Applies indexing map to a list of SSA values";
+  let description = [{
+    The `apply_indexing` operation applies an affine map to a list
+    of SSA values, yielding one SSA value per result of the map. The number of
+    dimension and symbol arguments must be equal to the respective number of
+    dimensional and symbolic inputs in the affine map. Because the affine map
+    can be multi-dimensional, `apply_indexing` may return more than one value.
+    The operands and results must all have `index` type.
+
+    Example:
+
+    ```mlir
+    #map = affine_map<(d0, d1)[s0] -> (d0 floordiv 8 + d1 floordiv 128, s0)>
+    %results:2 = xla_gpu.apply_indexing #map (%0 in [0, 10], %1 in [0, 11])[%2 in [11, 32]]
+    ```
+  }];
+  let arguments = (ins Variadic<Index>:$operands,
+                   XLAGPU_IndexingMapAttr:$indexing_map_attr);
+  let results = (outs Variadic<Index>);
+
+  let builders = [
+    OpBuilder<(ins "mlir::ValueRange":$dims, "mlir::ValueRange":$symbols,
+                   "const IndexingMap&":$indexing_map)>,
+    OpBuilder<(ins "mlir::ValueRange":$operands,
+                   "const IndexingMap&":$indexing_map)>,
+    OpBuilder<(ins "mlir::ValueRange":$operands, "mlir::AffineMap":$affine_map,
+               "llvm::ArrayRef<DimVar>":$dim_vars,
+               "llvm::ArrayRef<RangeVar>":$range_vars)>,
+  ];
+  let extraClassDeclaration = [{
+    // Returns the indexing map constructed from IndexingMapAttr.
+    xla::gpu::IndexingMap getIndexingMap();
+    // Extracts the affine map from the attribute.
+    mlir::AffineMap getAffineMap() { return getIndexingMapAttr().getMap(); }
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+}
+
+def LoopOp : XLAGPU_Op<"loop", [
+    AttrSizedOperandSegments, Pure,
+    SingleBlockImplicitTerminator<"xla::gpu::YieldOp">
+  ]> {
+  let summary = "Loop nest that iterates over all feasible values of RangeVars.";
+  let description = [{
+    Iterates over all feasible values of the RangeVars (symbols) of the given
+    indexing map for the provided dimension values, carrying the `iter_args`
+    between iterations.
+
+    ```mlir
+      #map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1),
+        domain:
+        d0 in [0, 3],
+        s0 in [0, 1024],
+        s1 in [0, 32]
+      >
+      // Initial sum set to 0.
+      %sum_0 = arith.constant 0.0 : f32
+      %dim = arith.constant 1 : index
+      // iter_args binds initial values to the loop's region arguments.
+      %sum = xla_gpu.loop (%dim)[%i, %j] in #map iter_args(%sum_iter = %sum_0) -> (f32) {
+        %t = tensor.extract %buffer[%i, %j] : tensor<1024x32xf32>
+        %sum_next = arith.addf %sum_iter, %t : f32
+        // Yield current iteration sum to next iteration %sum_iter or to %sum
+        // if final iteration.
+        xla_gpu.yield %sum_next : f32
+      }
+    ```
+  }];
+  let arguments = (ins XLAGPU_IndexingMapAttr:$indexing_map_attr,
+                       Variadic<Index>:$dims,
+                       Variadic<AnyType>:$inits);
+  let results = (outs Variadic<AnyType>);
+  let regions = (region SizedRegion<1>:$region);
+
+  let builders = [
+    OpBuilder<(ins "IndexingMapAttr":$indexing_map_attr,
+      "mlir::ValueRange":$dims, "mlir::ValueRange":$inits,
+      CArg<"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::ValueRange, mlir::ValueRange)>",
+           "nullptr">)>,
+    OpBuilder<(ins "const IndexingMap&":$indexing_map,
+      "mlir::ValueRange":$dims, "mlir::ValueRange":$inits,
+      CArg<"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::ValueRange, mlir::ValueRange)>",
+           "nullptr">)>
+  ];
+
+  let extraClassDeclaration = [{
+    using BodyBuilderFn =
+        llvm::function_ref<void(mlir::OpBuilder&, mlir::Location,
+                                mlir::ValueRange, mlir::ValueRange)>;
+
+    // Returns the indexing map constructed from IndexingMapAttr.
+    xla::gpu::IndexingMap getIndexingMap();
+    int64_t getNumInductionVars() {
+      return getBody()->getNumArguments() - getNumResults();
+    }
+    mlir::BlockArgument getInductionVar(int64_t index) {
+      return getBody()->getArgument(index);
+    }
+    mlir::Block::BlockArgListType getInductionVars() {
+      return getBody()->getArguments().take_front(getNumInductionVars());
+    }
+    mlir::Block::BlockArgListType getRegionIterArgs() {
+      return getBody()->getArguments().drop_front(getNumInductionVars());
+    }
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
+def XLAGPU_MaterializeOp : XLAGPU_Op<"materialize", [AttrSizedOperandSegments]> {
+  let summary = "Reads a tensor into registers";
+  let arguments = (ins Variadic<AnyType>:$input,
+                       Variadic<Index>:$indices,
+                       FlatSymbolRefAttr:$callee,
+                       XLAGPU_IndexingMapAttr:$map);
+  let results = (outs XLAGPU_IndexedVectorType:$result);
+  let hasVerifier = 1;
+  let assemblyFormat = [{
+    $callee `(` $input `)` `at` $map `(` $indices `)` attr-dict `:` functional-type($input, results)
+  }];
+}
+
+def XLAGPU_InsertOp : XLAGPU_Op<"insert", []> {
+  let summary = "Inserts an indexed vector into a tensor";
+  let arguments = (ins XLAGPU_IndexedVectorType:$source,
+                       Variadic<Index>:$indices,
+                       AnyRankedTensor:$dest,
+                       XLAGPU_IndexingMapAttr:$map);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    $source `into` $dest `at` $map `(` $indices `)` attr-dict `:` type($source) `->` type($dest) `into` type($result)
+  }];
+}
+
+#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_OPS
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc
new file mode 100644
index 0000000..2d9076d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+
+#include <gtest/gtest.h>
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla::gpu {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+using ::testing::Pair;
+using ::testing::UnorderedElementsAre;
+
+class XLAGPUOpsTest : public HloTestBase {
+ public:
+  mlir::MLIRContext mlir_context_;
+};
+
+TEST_F(XLAGPUOpsTest, GetConstraintsForVariables) {
+  auto map = IndexingMap(
+      ParseAffineMap("(d0, d1)[s0, s1] -> (d0+s0, d1+s1)", &mlir_context_),
+      /*dimensions=*/{{Interval{0, 5}}, {Interval{0, 2}}},
+      /*range_vars=*/{{Interval{0, 32}}, {Interval{0, 1024}}}, /*rt_vars=*/{});
+  map.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_),
+                    Interval{0, 1});
+  map.AddConstraint(ParseAffineExpr("s1 mod 4", &mlir_context_),
+                    Interval{0, 2});
+  map.AddConstraint(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3});
+  map.AddConstraint(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4});
+  map.AddConstraint(ParseAffineExpr("d1 mod 32", &mlir_context_),
+                    Interval{0, 6});
+
+  auto constraints_for_variables = GetConstraintsForVariables(map);
+  EXPECT_THAT(constraints_for_variables.constraints_for_dims[0],
+              UnorderedElementsAre());
+  EXPECT_THAT(
+      constraints_for_variables.constraints_for_dims[1],
+      UnorderedElementsAre(
+          Pair(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4}),
+          Pair(ParseAffineExpr("d1 mod 32", &mlir_context_), Interval{0, 6})));
+  EXPECT_THAT(
+      constraints_for_variables.constraints_for_symbols[0],
+      UnorderedElementsAre(
+          Pair(ParseAffineExpr("s0 mod 4", &mlir_context_), Interval{0, 1}),
+          Pair(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3})));
+  EXPECT_THAT(
+      constraints_for_variables.constraints_for_symbols[1],
+      UnorderedElementsAre(
+          Pair(ParseAffineExpr("s1 mod 4", &mlir_context_), Interval{0, 2}),
+          Pair(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}),
+          Pair(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4})));
+}
+
+TEST_F(XLAGPUOpsTest, GetConstraintsForVariablesEmpty) {
+  auto map = IndexingMap(
+      ParseAffineMap("(d0, d1)[s0, s1] -> (d0+s0, d1+s1)", &mlir_context_),
+      /*dimensions=*/{{Interval{0, 5}}, {Interval{0, 2}}},
+      /*range_vars=*/{{Interval{0, 32}}, {Interval{0, 1024}}}, /*rt_vars=*/{});
+  auto constraints_for_variables = GetConstraintsForVariables(map);
+  EXPECT_THAT(constraints_for_variables.constraints_for_dims,
+              ElementsAre(IsEmpty(), IsEmpty()));
+  EXPECT_THAT(constraints_for_variables.constraints_for_symbols,
+              ElementsAre(IsEmpty(), IsEmpty()));
+}
+
+}  // namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc
new file mode 100644
index 0000000..1c1b218
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc
@@ -0,0 +1,57 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+
+#include "mlir/IR/Attributes.h"  // IWYU pragma: keep
+#include "mlir/IR/BuiltinTypes.h"  // IWYU pragma: keep
+#include "mlir/IR/Dialect.h"  // IWYU pragma: keep
+#include "mlir/IR/OpImplementation.h"  // IWYU pragma: keep
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc"
+#include "xla/service/gpu/model/indexing_map.h"  // IWYU pragma: keep
+#define GET_ATTRDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc"
+#undef GET_ATTRDEF_CLASSES
+#define GET_TYPEDEF_LIST
+#define GET_TYPEDEF_CLASSES
+#include "xla/service/gpu/fusions/ir/xla_gpu_types.h.inc"
+
+namespace xla {
+namespace gpu {
+
+mlir::Type IndexedVectorType::parse(mlir::AsmParser& parser) {
+  mlir::SmallVector<int64_t, 4> shape;
+  mlir::Type type;
+  IndexingMapAttr indexing_map_attr;
+  if (parser.parseLess() ||
+      parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
+      parser.parseType(type) || parser.parseComma() ||
+      parser.parseAttribute(indexing_map_attr) || parser.parseGreater()) {
+    return {};
+  }
+  return IndexedVectorType::get(parser.getContext(), shape, type,
+                                indexing_map_attr);
+}
+
+void IndexedVectorType::print(mlir::AsmPrinter& printer) const {
+  printer << "<";
+  printer.printDimensionList(getShape());
+  printer << "x" << getElementType() << ", " << getIndexingMapAttr() << ">";
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td
new file mode 100644
index 0000000..5d73344
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td
@@ -0,0 +1,50 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_TYPES
+#define XLA_SERVICE_GPU_FUSIONS_MLIR_TYPES
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td"
+include "xla/service/gpu/fusions/ir/xla_gpu_attrs.td"
+
+class XLAGPU_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<XlaGpuDialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+def XLAGPU_IndexedVectorType : XLAGPU_Type<"IndexedVector", "indexed_vector",
+    [ShapedTypeInterface, ValueSemantics]> {
+  let summary = "Vector type with a specified layout";
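+  let description = [{
+    A vector value with an explicit layout: its shape and element type are
+    paired with an indexing map that describes how its elements are indexed.
+    An illustrative textual form (assuming the `xla_gpu` dialect prefix and a
+    hypothetical map alias `#map`):
+
+    ```mlir
+    !xla_gpu.indexed_vector<16x64xf32, #map>
+    ```
+  }];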
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    "mlir::Type":$elementType,
+    XLAGPU_IndexingMapAttr:$indexing_map_attr
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDeclaration = [{
+    IndexedVectorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
+                                mlir::Type elementType) const {
+        return IndexedVectorType::get(getContext(), shape.value_or(getShape()),
+                                      elementType, getIndexingMapAttr());
+    }
+
+    bool hasRank() const { return true; }
+  }];
+}
+
+#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_TYPES
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD
new file mode 100644
index 0000000..98d8ade
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD
@@ -0,0 +1,406 @@
+load("//xla:xla.bzl", "xla_cc_test")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = ["//xla/service/gpu/fusions:__pkg__"],
+    licenses = ["notice"],
+)
+
+cc_library(
+    name = "in_place_dynamic_update_slice",
+    srcs = ["in_place_dynamic_update_slice.cc"],
+    hdrs = ["in_place_dynamic_update_slice.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:ir_emitter",
+        "//xla/service/gpu:ir_emitter_context",
+        "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu/fusions:fusion_emitter",
+        "//xla/service/gpu/model:indexing_analysis",
+        "//xla/service/llvm_ir:dynamic_update_slice_util",
+        "//xla/service/llvm_ir:fused_ir_emitter",
+        "//xla/service/llvm_ir:ir_array",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:ir_headers",
+        "@llvm-project//mlir:IR",
+    ],
+)
+
+xla_cc_test(
+    name = "in_place_dynamic_update_slice_test",
+    srcs = ["in_place_dynamic_update_slice_test.cc"],
+    deps = [
+        ":in_place_dynamic_update_slice",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu/fusions",
+        "//xla/service/gpu/model:affine_map_printer",
+        "//xla/service/gpu/model:indexing_test_utils",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_googletest//:gtest",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "loop",
+    srcs = ["loop.cc"],
+    hdrs = ["loop.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:hlo_traversal",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:ir_emitter",
+        "//xla/service/gpu:ir_emitter_context",
+        "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu:parallel_loop_emitter",
+        "//xla/service/gpu/fusions:fusion_emitter",
+        "//xla/service/gpu/model:indexing_analysis",
+        "//xla/service/llvm_ir:fused_ir_emitter",
+        "//xla/service/llvm_ir:ir_array",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/numeric:bits",
+        "@com_google_absl//absl/status",
+        "@llvm-project//llvm:ir_headers",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:macros",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "loop_test",
+    srcs = ["loop_test.cc"],
+    deps = [
+        "//xla:status_macros",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu/fusions",
+        "//xla/service/gpu/fusions:fusion_emitter",
+        "//xla/service/gpu/model:affine_map_printer",
+        "//xla/service/gpu/model:indexing_test_utils",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_googletest//:gtest",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "scatter",
+    srcs = ["scatter.cc"],
+    hdrs = ["scatter.h"],
+    deps = [
+        ":loop",
+        "//xla:shape_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:ir_emitter",
+        "//xla/service/gpu:ir_emitter_context",
+        "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu:parallel_loop_emitter",
+        "//xla/service/gpu/fusions:fusion_emitter",
+        "//xla/service/gpu/model:indexing_analysis",
+        "//xla/service/llvm_ir:fused_ir_emitter",
+        "//xla/service/llvm_ir:ir_array",
+        "//xla/service/llvm_ir:llvm_util",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:ir_headers",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "scatter_test",
+    srcs = ["scatter_test.cc"],
+    deps = [
+        ":scatter",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu/fusions",
+        "//xla/service/gpu/model:affine_map_printer",
+        "//xla/service/gpu/model:indexing_test_utils",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_googletest//:gtest",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "tiling_util",
+    srcs = ["tiling_util.cc"],
+    hdrs = ["tiling_util.h"],
+    visibility = ["//xla/service/gpu:__subpackages__"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:target_util",
+        "//xla/service/gpu/model:indexing_analysis",
+        "//xla/service/llvm_ir:ir_array",
+        "//xla/service/llvm_ir:kernel_support_library",
+        "//xla/service/llvm_ir:llvm_loop",
+        "//xla/service/llvm_ir:llvm_util",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:ir_headers",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "reduction",
+    srcs = ["reduction.cc"],
+    hdrs = ["reduction.h"],
+    deps = [
+        ":tiling_util",
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:buffer_assignment",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:hlo_traversal",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:ir_emitter",
+        "//xla/service/gpu:ir_emitter_context",
+        "//xla/service/gpu:kernel_arguments",
+        "//xla/service/gpu:kernel_reuse_cache",
+        "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu:parallel_loop_emitter",
+        "//xla/service/gpu:reduction_utils",
+        "//xla/service/gpu:target_util",
+        "//xla/service/gpu/fusions:fusion_emitter",
+        "//xla/service/gpu/fusions:reduction_base",
+        "//xla/service/gpu/fusions:thunk_util",
+        "//xla/service/gpu/runtime:kernel_thunk",
+        "//xla/service/gpu/runtime:thunk",
+        "//xla/service/llvm_ir:fused_ir_emitter",
+        "//xla/service/llvm_ir:ir_array",
+        "//xla/service/llvm_ir:kernel_support_library",
+        "//xla/service/llvm_ir:llvm_loop",
+        "//xla/service/llvm_ir:llvm_util",
+        "//xla/service/llvm_ir:loop_emitter",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/container:node_hash_map",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:ir_headers",
+        "@llvm-project//mlir:Support",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "reduction_test",
+    srcs = ["reduction_test.cc"],
+    deps = [
+        ":reduction",
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:ir_emitter_context",
+        "//xla/service/gpu/fusions:fusion_emitter",
+        "//xla/service/gpu/model:indexing_analysis",
+        "//xla/service/gpu/model:indexing_test_utils",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_googletest//:gtest",
+        "@llvm-project//mlir:IR",
+    ],
+)
+
+cc_library(
+    name = "concatenate",
+    srcs = ["concatenate.cc"],
+    hdrs = ["concatenate.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:ir_emitter",
+        "//xla/service/gpu:ir_emitter_context",
+        "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu:parallel_loop_emitter",
+        "//xla/service/gpu/fusions:fusion_emitter",
+        "//xla/service/gpu/model:indexing_analysis",
+        "//xla/service/llvm_ir:fused_ir_emitter",
+        "//xla/service/llvm_ir:ir_array",
+        "//xla/service/llvm_ir:loop_emitter",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/status",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:ir_headers",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "concatenate_test",
+    srcs = ["concatenate_test.cc"],
+    deps = [
+        ":concatenate",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu/fusions",
+        "//xla/service/gpu/model:affine_map_printer",
+        "//xla/service/gpu/model:indexing_test_utils",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_googletest//:gtest",
+        "@llvm-project//mlir:IR",
+    ],
+)
+
+cc_library(
+    name = "transpose",
+    srcs = ["transpose.cc"],
+    hdrs = ["transpose.h"],
+    deps = [
+        ":tiling_util",
+        "//xla:permutation_util",
+        "//xla:shape_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:ir_emitter",
+        "//xla/service/gpu:ir_emitter_context",
+        "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu:target_util",
+        "//xla/service/gpu/fusions:fusion_emitter",
+        "//xla/service/gpu/model:indexing_analysis",
+        "//xla/service/llvm_ir:fused_ir_emitter",
+        "//xla/service/llvm_ir:ir_array",
+        "//xla/service/llvm_ir:llvm_util",
+        "//xla/service/llvm_ir:loop_emitter",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:ir_headers",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "transpose_test",
+    srcs = ["transpose_test.cc"],
+    deps = [
+        ":transpose",
+        "//xla:status_macros",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu/fusions",
+        "//xla/service/gpu/model:indexing_test_utils",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_googletest//:gtest",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "input_slices",
+    srcs = ["input_slices.cc"],
+    hdrs = ["input_slices.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:elemental_ir_emitter",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:ir_emitter",
+        "//xla/service/gpu:ir_emitter_context",
+        "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu:parallel_loop_emitter",
+        "//xla/service/gpu/fusions:fusion_emitter",
+        "//xla/service/gpu/model:indexing_analysis",
+        "//xla/service/llvm_ir:fused_ir_emitter",
+        "//xla/service/llvm_ir:ir_array",
+        "//xla/service/llvm_ir:kernel_support_library",
+        "//xla/service/llvm_ir:llvm_loop",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:ir_headers",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "input_slices_test",
+    srcs = ["input_slices_test.cc"],
+    deps = [
+        ":input_slices",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu/fusions",
+        "//xla/service/gpu/model:affine_map_printer",
+        "//xla/service/gpu/model:indexing_test_utils",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_googletest//:gtest",
+        "@llvm-project//mlir:IR",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/README.md b/third_party/xla/xla/service/gpu/fusions/legacy/README.md
new file mode 100644
index 0000000..0fa6bb9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/README.md
@@ -0,0 +1,8 @@
+# Deprecated emitters
+
+The emitters in this directory are deprecated. Please do not add any new
+features. If you believe you need to add a feature, please reach out and
+describe your use case.
+
+These emitters have more modern MLIR-based equivalents in the directory above
+this one.
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.cc b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.cc
new file mode 100644
index 0000000..8bb0e04
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.cc
@@ -0,0 +1,137 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/concatenate.h"
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/status/status.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/parallel_loop_emitter.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/service/llvm_ir/loop_emitter.h"
+#include "xla/shape.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis) {
+  const HloInstruction& concat = analysis.fusion_hero(0).instruction();
+  int64_t dim = concat.concatenate_dimension();
+  auto less = [&](const HloInstruction* lhs, const HloInstruction* rhs) {
+    return lhs->shape().dimensions(dim) < rhs->shape().dimensions(dim);
+  };
+  HloInstruction* operand = *absl::c_max_element(concat.operands(), less);
+  return operand->shape();
+}
+
+ConcatenateFusion::ConcatenateFusion(const HloFusionAnalysis& analysis)
+    : analysis_(analysis) {}
+
+std::optional<IndexingMap> ConcatenateFusion::ComputeThreadIdToOutputIndexing(
+    int64_t root_index, mlir::MLIRContext* ctx) const {
+  return std::nullopt;
+}
+
+std::optional<IndexingMap> ConcatenateFusion::ComputeThreadIdToInputIndexing(
+    int64_t root_index, int64_t hero_operand_index,
+    mlir::MLIRContext* ctx) const {
+  return GetDefaultThreadIdIndexingMap(launch_dimensions(), /*unroll_factor=*/1,
+                                       GetLargestConcatOperandShape(analysis_),
+                                       ctx);
+}
+
+absl::Status ConcatenateFusion::EmitKernel(
+    IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
+    const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
+    std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder) const {
+  GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
+  FusedIrEmitter fused_emitter(elemental_emitter);
+  for (int i = 0; i < fusion.fused_parameters().size(); i++) {
+    fused_emitter.BindGenerator(
+        *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) {
+          return inputs[i].EmitReadArrayElement(index, builder);
+        });
+  }
+
+  llvm::Type* index_type =
+      GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
+
+  const HloInstruction& concat = analysis_.fusion_hero(0).instruction();
+  int64_t concat_dim = concat.concatenate_dimension();
+  int64_t operand_offset = 0;
+
+  // Emit the slices that correspond to the operands of the concat hero.
+  for (const HloInstruction* operand : concat.operands()) {
+    llvm_ir::BodyEmitter body_emitter =
+        [&](const llvm_ir::IrArray::Index& operand_index) -> absl::Status {
+      // Bind concat to generate the current operand.
+      TF_ASSIGN_OR_RETURN(auto operand_generator,
+                          fused_emitter.GetGenerator(*operand));
+      fused_emitter.BindGenerator(concat, [&](llvm_ir::IrArray::Index) {
+        return operand_generator(operand_index);
+      });
+
+      // Create the index of the slice corresponding to the current operand.
+      llvm_ir::IrArray::Index result_index = operand_index.AddOffsetToDim(
+          llvm::ConstantInt::get(index_type, operand_offset), concat_dim,
+          builder);
+      operand_offset += operand->shape().dimensions(concat_dim);
+
+      // Generate and write out the slice for each root.
+      for (const auto& [output, root] :
+           llvm::zip_equal(outputs, analysis_.fusion_roots())) {
+        llvm_ir::IrArray::Index root_index = result_index.SourceIndexOfBitcast(
+            concat.shape(), root.shape(), builder);
+        TF_ASSIGN_OR_RETURN(auto generator,
+                            fused_emitter.GetGenerator(root.instruction()));
+        TF_ASSIGN_OR_RETURN(llvm::Value * value, generator(root_index));
+        output.EmitWriteArrayElement(root_index, value, builder);
+      }
+      return absl::OkStatus();
+    };
+
+    ParallelLoopEmitter emitter(body_emitter, operand->shape(), launch_dims,
+                                builder);
+    TF_RETURN_IF_ERROR(emitter.EmitLoop(fusion.name(), index_type));
+  }
+
+  return absl::OkStatus();
+}
+
+LaunchDimensions ConcatenateFusion::launch_dimensions() const {
+  return CalculateLaunchDimensions(GetLargestConcatOperandShape(analysis_),
+                                   analysis_.device_info());
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.h b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.h
new file mode 100644
index 0000000..be0465b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.h
@@ -0,0 +1,67 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_CONCATENATE_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_CONCATENATE_H_
+
+#include <optional>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/shape.h"
+
+namespace xla {
+namespace gpu {
+
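+// Returns the shape of the concat operand that is largest along the
+// concatenation dimension.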
+const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis);
+
+// Emits a kernel for the given hlo instruction where each thread produces
+// one element of each concat operand.
+class ConcatenateFusion : public KernelFusionEmitterBase {
+ public:
+  explicit ConcatenateFusion(const HloFusionAnalysis& analysis);
+  LaunchDimensions launch_dimensions() const override;
+
+  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+      int64_t root_index, mlir::MLIRContext* ctx) const override;
+
+  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+      int64_t root_index, int64_t hero_operand_index,
+      mlir::MLIRContext* ctx) const override;
+
+ protected:
+  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+                          const HloFusionInstruction& fusion,
+                          const LaunchDimensions& launch_dims,
+                          std::vector<llvm_ir::IrArray> inputs,
+                          std::vector<llvm_ir::IrArray> outputs,
+                          llvm::IRBuilder<>* builder) const override;
+
+ private:
+  const HloFusionAnalysis& analysis_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_LEGACY_CONCATENATE_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc
new file mode 100644
index 0000000..ee63bda
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc
@@ -0,0 +1,120 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/concatenate.h"
+
+#include <optional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/affine_map_printer.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class ConcatenateTest : public HloTestBase {
+ public:
+  void SetUp() override {
+    HloTestBase::SetUp();
+    printer_ =
+        AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
+                         {"chunk_id", "unroll_id"});
+  }
+
+ protected:
+  DebugOptions GetDebugOptionsForTest() override {
+    auto opts = HloTestBase::GetDebugOptionsForTest();
+    opts.set_xla_gpu_mlir_emitter_level(0);
+    return opts;
+  }
+  AffineMapPrinter printer_;
+  mlir::MLIRContext mlir_context_;
+};
+
+TEST_F(ConcatenateTest, ThreadIndexing) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fused_computation {
+      param0 = f32[200] parameter(0)
+      param1 = f32[400] parameter(1)
+      param2 = f32[300] parameter(2)
+      ROOT concat = f32[900] concatenate(param0, param1, param2), dimensions={0}
+    }
+    ENTRY main {
+      param0 = f32[200] parameter(0)
+      param1 = f32[400] parameter(1)
+      param2 = f32[300] parameter(2)
+      ROOT fusion = f32[900] fusion(param0, param1, param2),
+        calls=fused_computation, kind=kLoop
+    }
+  )")
+                    .value();
+
+  stream_executor::DeviceDescription device_info =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis_fused = HloFusionAnalysis::Create(*root, device_info);
+
+  auto emitter =
+      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
+  auto fusion = dynamic_cast<ConcatenateFusion*>(emitter.get());
+  ASSERT_NE(fusion, nullptr);
+
+  constexpr auto kIndexing = R"(
+    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] ->
+      (bl_x * 128 + th_x)
+    domain:
+    th_x in [0, 127]
+    th_y in [0, 0]
+    th_z in [0, 0]
+    bl_x in [0, 3]
+    bl_y in [0, 0]
+    bl_z in [0, 0]
+    chunk_id in [0, 0]
+    unroll_id in [0, 0]
+    bl_x * 128 + th_x in [0, 399]
+  )";
+  EXPECT_THAT(
+      fusion
+          ->ComputeThreadIdToInputIndexing(
+              /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_)
+          ->ToString(printer_),
+      MatchIndexingString(kIndexing));
+  EXPECT_THAT(
+      fusion
+          ->ComputeThreadIdToInputIndexing(
+              /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_)
+          ->ToString(printer_),
+      MatchIndexingString(kIndexing));
+  EXPECT_THAT(
+      fusion
+          ->ComputeThreadIdToInputIndexing(
+              /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_)
+          ->ToString(printer_),
+      MatchIndexingString(kIndexing));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
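The expected domain follows from the largest operand, f32[400], and the 128-thread block: ceil(400 / 128) = 4 blocks, so bl_x in [0, 3], with the guard bl_x * 128 + th_x in [0, 399]; the same map is expected for all three operands because each thread handles element t of every operand. A quick check of that arithmetic (illustration only):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t kLargestOperand = 400;   // f32[400]
  constexpr int64_t kThreadsPerBlock = 128;  // th_x in [0, 127]
  const int64_t blocks =
      (kLargestOperand + kThreadsPerBlock - 1) / kThreadsPerBlock;
  std::cout << "bl_x in [0, " << blocks - 1 << "]\n";          // [0, 3]
  std::cout << "linear index in [0, " << kLargestOperand - 1   // [0, 399]
            << "]\n";
  return 0;
}
```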
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.cc b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.cc
new file mode 100644
index 0000000..38a3e5b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.cc
@@ -0,0 +1,105 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h"
+
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/dynamic_update_slice_util.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+constexpr int kDUSUpdateIndex = 1;
+
+}  // namespace
+
+LaunchDimensions InPlaceDynamicUpdateSliceFusion::launch_dimensions() const {
+  const auto& update_shape =
+      dus_ops_.front().GetOperand(kDUSUpdateIndex).shape();
+  return CalculateLaunchDimensions(update_shape, analysis_.device_info());
+}
+
+std::optional<IndexingMap>
+InPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing(
+    int64_t root_index, int64_t hero_operand_index,
+    mlir::MLIRContext* mlir_context) const {
+  if (hero_operand_index != kDUSUpdateIndex) {
+    return std::nullopt;
+  }
+  auto launch_dims = launch_dimensions();
+  // It is guaranteed that all DUS ops have the same output shape at this point.
+  const auto& update_shape =
+      dus_ops_.front().GetOperand(kDUSUpdateIndex).shape();
+  return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1,
+                                       update_shape, mlir_context);
+}
+
+absl::Status InPlaceDynamicUpdateSliceFusion::EmitKernel(
+    IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
+    const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
+    std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder) const {
+  // If a dynamic-update-slice's output is bitcasted, we need to write to the
+  // output array using the shape and layout of the dynamic-update-slice
+  // itself. This cast is safe only if, whenever the output of the
+  // dynamic-update-slice is bitcasted, that bitcast is either the fusion's
+  // output or has a single user and is part of the fusion's tuple output.
+  // This condition should be enforced explicitly by the
+  // 'CanEmitFusedDynamicUpdateSliceInPlaceForGpu' matcher.
+  for (auto [op, output] : llvm::zip(dus_ops_, outputs)) {
+    output = output.CastToShape(op.shape(), builder);
+  }
+
+  auto* fused_computation = fusion.fused_instructions_computation();
+  GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
+  FusedIrEmitter fused_emitter(elemental_emitter);
+  for (auto [index, input] : llvm::enumerate(inputs)) {
+    auto fused_operand = fused_computation->parameter_instruction(index);
+    fused_emitter.BindGenerator(
+        *fused_operand, [input = input, builder,
+                         fused_operand](const llvm_ir::IrArray::Index& index) {
+          return input.EmitReadArrayElement(index, builder,
+                                            fused_operand->name());
+        });
+  }
+
+  std::vector<std::pair<const HloInstruction*, const llvm_ir::IrArray>>
+      dus_and_output_array;
+  dus_and_output_array.reserve(dus_ops_.size());
+
+  for (auto [op, output] : llvm::zip(dus_ops_, outputs)) {
+    dus_and_output_array.push_back(std::make_pair(&op.instruction(), output));
+  }
+
+  return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
+      fused_computation, dus_and_output_array, &fused_emitter, launch_dims,
+      builder);
+}
+
+}  // namespace gpu
+}  // namespace xla
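The in-place behavior the class comment promises can be pictured with a host-side sketch, assuming a dense row-major 2-D buffer and the standard HLO clamping of dynamic-update-slice start indices (`InPlaceDus2D` is a hypothetical name; the emitter generates IR with one thread per update element rather than running this loop):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

void InPlaceDus2D(std::vector<float>& buffer, int64_t rows, int64_t cols,
                  const std::vector<float>& update, int64_t u_rows,
                  int64_t u_cols, int64_t start_row, int64_t start_col) {
  // Clamp so the update window fits, as in HLO dynamic-update-slice semantics.
  start_row = std::clamp<int64_t>(start_row, 0, rows - u_rows);
  start_col = std::clamp<int64_t>(start_col, 0, cols - u_cols);
  // One GPU thread corresponds to one (r, c) update element here; elements
  // outside the update window are never touched.
  for (int64_t r = 0; r < u_rows; ++r) {
    for (int64_t c = 0; c < u_cols; ++c) {
      buffer[(start_row + r) * cols + (start_col + c)] = update[r * u_cols + c];
    }
  }
}
```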
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h
new file mode 100644
index 0000000..db12c3c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h
@@ -0,0 +1,98 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+
+namespace xla {
+namespace gpu {
+
+// Fusion node where the root is either:
+// 1. a dynamic-update-slice op
+// 2. a bitcast of a dynamic-update-slice op
+// 3. a tuple op returning the results of several dynamic-update-slice ops
+// 4. a tuple op returning the results of several bitcasted
+//    dynamic-update-slice ops
+//
+// Additionally, all the dynamic-update-slice ops have exactly one user. The
+// fusion parameter that they update can have users (in addition to the
+// dynamic-update-slice op) that read in either
+// a. a dynamic-slice corresponding exactly to the slice of the parameter that
+//    is updated by the dynamic-update-slice op
+// b. a dynamic-slice reading in a single element anywhere in the parameter.
+//    This is only allowed if the dynamic-update-slice op updates a single
+//    element.
+//
+// In both cases, the additional users must not flow into any output other
+// than the dynamic-update-slice corresponding to that particular slice of the
+// parameter.
+//
+// The assumption is that each op's input (i.e. array to update) shares the
+// same slice as its output. In this case, we have a special algorithm that
+// modifies the output in place without touching the un-updated elements. The
+// update slice is assumed to be the exact same for all the
+// dynamic-update-slice ops.
+class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase {
+ public:
+  explicit InPlaceDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis)
+      : analysis_(analysis),
+        dus_ops_(
+            GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {}
+  LaunchDimensions launch_dimensions() const override;
+
+  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+      int64_t root_index, mlir::MLIRContext* ctx) const override {
+    // The mapping cannot be statically computed in general, since the offsets
+    // are unknown.
+    return std::nullopt;
+  }
+
+  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+      int64_t root_index, int64_t hero_operand_index,
+      mlir::MLIRContext* mlir_context) const override;
+
+ protected:
+  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+                          const HloFusionInstruction& fusion,
+                          const LaunchDimensions& launch_dims,
+                          std::vector<llvm_ir::IrArray> inputs,
+                          std::vector<llvm_ir::IrArray> outputs,
+                          llvm::IRBuilder<>* builder) const override;
+
+  const HloFusionAnalysis& analysis_;
+  std::vector<HloInstructionAdaptor> dus_ops_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_LEGACY_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc
new file mode 100644
index 0000000..c4fd277
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc
@@ -0,0 +1,144 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h"
+
+#include <optional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/affine_map_printer.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class InPlaceDynamicUpdateSliceFusionTest : public HloTestBase {
+ public:
+  void SetUp() override {
+    HloTestBase::SetUp();
+    printer_ =
+        AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
+                         {"chunk_id", "unroll_id"});
+  }
+
+ protected:
+  DebugOptions GetDebugOptionsForTest() override {
+    auto opts = HloTestBase::GetDebugOptionsForTest();
+    opts.set_xla_gpu_mlir_emitter_level(0);
+    return opts;
+  }
+  AffineMapPrinter printer_;
+  mlir::MLIRContext mlir_context_;
+  stream_executor::DeviceDescription device_info_ =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+};
+
+TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fused_computation {
+      in = f32[20,30] parameter(0)
+      updates = f32[5,6] parameter(1)
+      i0 = s32[] parameter(2)
+      i1 = s32[] parameter(3)
+      ROOT updated = f32[20,30] dynamic-update-slice(in, updates, i0, i1)
+    }
+    ENTRY entry {
+      in = f32[20,30] parameter(0)
+      updates = f32[5,6] parameter(1)
+      i0 = s32[] constant(2)
+      i1 = s32[] constant(3)
+      ROOT fusion = f32[20,30] fusion(in, updates, i0, i1), kind=kLoop, calls=fused_computation
+    }
+  )"));
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis_fused = HloFusionAnalysis::Create(*root, device_info_);
+
+  auto emitter =
+      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
+  auto fusion = dynamic_cast<InPlaceDynamicUpdateSliceFusion*>(emitter.get());
+  ASSERT_NE(fusion, nullptr);
+
+  auto thread_id_update_indexing = fusion->ComputeThreadIdToInputIndexing(
+      /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_);
+  EXPECT_THAT(thread_id_update_indexing->ToString(printer_),
+              MatchIndexingString(R"(
+    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
+    th_x floordiv 6, th_x mod 6)
+    domain:
+    th_x in [0, 29]
+    th_y in [0, 0]
+    th_z in [0, 0]
+    bl_x in [0, 0]
+    bl_y in [0, 0]
+    bl_z in [0, 0]
+    chunk_id in [0, 0]
+    unroll_id in [0, 0]
+  )"));
+  auto thread_id_dst_indexing = fusion->ComputeThreadIdToInputIndexing(
+      /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
+  EXPECT_THAT(thread_id_dst_indexing, ::testing::Eq(std::nullopt));
+}
+
+TEST_F(InPlaceDynamicUpdateSliceFusionTest, ProduceConsumerFusion) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+    HloModule m
+
+    fused_computation.1 {
+      param_0 = bf16[1,2,5,1,2] parameter(0)
+      bitcast = bf16[1,5,1,2,2] bitcast(param_0)
+      param_1 = bf16[1,1,1,2,2] parameter(1)
+      param_2 = s32[] parameter(2)
+      param_3 = s32[] parameter(3)
+      ROOT dynamic-update-slice = bf16[1,5,1,2,2] dynamic-update-slice(bitcast, param_1, param_2, param_3, param_2, param_2, param_2)
+    }
+
+    ENTRY entry_computation {
+      param_0.2 = bf16[1,2,5,1,2] parameter(3)
+      param_1.2 = bf16[1,1,1,2,2] parameter(0)
+      param_2.2 = s32[] parameter(1)
+      param_3.2 = s32[] parameter(2)
+      fusion = bf16[1,5,1,2,2] fusion(param_0.2, param_1.2, param_2.2, param_3.2), kind=kLoop, calls=fused_computation.1
+      ROOT bitcast.1 = bf16[1,2,5,1,2] bitcast(fusion)
+    }
+  )"));
+
+  auto* root = module->entry_computation()->root_instruction();
+
+  auto analysis_fused =
+      HloFusionAnalysis::Create(*root->operand(0), *root, device_info_);
+
+  auto emitter =
+      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
+
+  auto fusion = dynamic_cast<InPlaceDynamicUpdateSliceFusion*>(emitter.get());
+
+  ASSERT_NE(fusion, nullptr);
+  EXPECT_EQ(fusion->launch_dimensions().launch_bound(), 4 /* update size */);
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
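In ProduceConsumerFusion, the expected launch bound of 4 is simply the number of update elements: the update operand is bf16[1,1,1,2,2], i.e. 1 * 1 * 1 * 2 * 2 = 4, and launch_dimensions() in in_place_dynamic_update_slice.cc computes the launch from the update shape, one thread per update element.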
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.cc b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.cc
new file mode 100644
index 0000000..d336f92
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.cc
@@ -0,0 +1,220 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/input_slices.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/elemental_ir_emitter.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/parallel_loop_emitter.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/service/llvm_ir/kernel_support_library.h"
+#include "xla/service/llvm_ir/llvm_loop.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Emits code for slices based on the below structure. An if statement with
+// a guarding condition is generated for each ROOT slice.
+//
+// Pseudo code:
+//
+// Compute values of slice input operands
+//
+// Compute guarding_cond0
+// if (guarding_cond0) {
+//   Write to output of slice0
+// }
+//
+// Compute guarding_cond1
+// if (guarding_cond1) {
+//   Write to output of slice1
+// }
+//
+absl::Status EmitElementForInputFusibleSlices(
+    ElementalIrEmitter& elemental_emitter,
+    const HloComputation* fused_computation,
+    const std::vector<llvm_ir::IrArray>& inputs,
+    const std::vector<llvm_ir::IrArray>& outputs,
+    const llvm_ir::IrArray::Index& index, llvm::IRBuilder<>* builder) {
+  VLOG(10) << "Emitting slice input fusion for "
+           << fused_computation->ToString();
+
+  HloInstruction* slice_or_tuple = fused_computation->root_instruction();
+  auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
+    if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
+      return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
+    }
+    CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple);
+    return slice_or_tuple->operands();
+  }();
+
+  // Emit input operand values of slices.
+  std::vector<llvm::Value*> input_ir_values;
+  FusedIrEmitter fused_emitter(elemental_emitter);
+  for (int i = 0; i < fused_computation->num_parameters(); i++) {
+    fused_emitter.BindGenerator(
+        *fused_computation->parameter_instruction(i),
+        [&inputs, i, builder](llvm_ir::IrArray::Index index) {
+          return inputs[i].EmitReadArrayElement(index, builder);
+        });
+  }
+  for (const HloInstruction* slice : slice_instructions) {
+    auto input_generator = *fused_emitter.GetGenerator(*slice->operand(0));
+    input_ir_values.push_back(input_generator(index).value());
+  }
+
+  // Emit for slice_instructions.
+  KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll);
+  for (int64_t i = 0; i < slice_instructions.size(); ++i) {
+    HloInstruction* slice = slice_instructions[i];
+
+    // guarding_cond := index >= start && index < limit, for each dim.
+    std::vector<llvm::Value*> index_within_ranges;
+    for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) {
+      CHECK_EQ(slice->slice_strides(dim), 1);
+      auto larger_or_equal_than_start = builder->CreateICmpSGE(
+          index.multidim()[dim],
+          index.GetConstantWithIndexType(slice->slice_starts(dim)));
+      llvm::Value* smaller_than_limit = builder->CreateICmpSLT(
+          index.multidim()[dim],
+          index.GetConstantWithIndexType(slice->slice_limits(dim)));
+      llvm::Value* within_range =
+          builder->CreateAnd(larger_or_equal_than_start, smaller_than_limit);
+      index_within_ranges.push_back(within_range);
+    }
+    llvm::Value* guarding_cond = builder->CreateAnd(index_within_ranges);
+
+    auto emit_slice_elem_func = [&] {
+      const std::vector<llvm::Value*>& src_multidim = index.multidim();
+      std::vector<llvm::Value*> dst_multidim(src_multidim.size());
+      for (size_t dim = 0; dim < src_multidim.size(); ++dim) {
+        dst_multidim[dim] = builder->CreateSub(
+            src_multidim[dim],
+            index.GetConstantWithIndexType(slice->slice_starts(dim)));
+      }
+      const llvm_ir::IrArray& src_ir_array = outputs[i];
+      llvm_ir::IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
+                                              index.GetType());
+      src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
+                                         builder);
+    };
+
+    ksl.If(absl::StrCat("slice", i), guarding_cond, emit_slice_elem_func);
+  }
+  return absl::OkStatus();
+}
+
+// Gets the input shape of the ROOT slices, which will be used as the kernel
+// launch dims. The slice input fusion requires the input shapes of the ROOT
+// slices to be the same although the (slice) output shapes can be different.
+//
+// Returns the input shape of the ROOT slices if all the input shapes of ROOT
+// slices are the same and the slices are non-strided. Otherwise, returns
+// FailedPrecondition.
+absl::StatusOr<Shape> GetConsistentInputShapeForRootSlices(
+    const HloComputation* fused_computation) {
+  const HloInstruction& root = *fused_computation->root_instruction();
+  if (root.opcode() == HloOpcode::kSlice) {
+    return root.operands()[0]->shape();
+  }
+
+  CHECK_EQ(root.opcode(), HloOpcode::kTuple);
+  const Shape& first_slice_operand_shape =
+      root.operands()[0]->operands()[0]->shape();
+  for (size_t i = 1; i < root.operands().size(); ++i) {
+    const HloInstruction* slice = root.operands()[i];
+    const Shape& operand_shape = slice->operands()[0]->shape();
+    if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
+                                             operand_shape)) {
+      return FailedPrecondition(
+          "Fused slices do not have the same input shape, fused computation = "
+          "%s.",
+          root.parent()->name());
+    }
+  }
+
+  return first_slice_operand_shape;
+}
+
+}  // namespace
+
+LaunchDimensions InputSlicesFusion::launch_dimensions() const {
+  const auto& root = analysis_.fusion_root(0).instruction();
+  const auto& shape = root.operand(0)->shape();
+  return CalculateLaunchDimensions(shape, analysis_.device_info(),
+                                   {unroll_factor_});
+}
+
+std::optional<IndexingMap> InputSlicesFusion::ComputeThreadIdToOutputIndexing(
+    int64_t output_id, mlir::MLIRContext* ctx) const {
+  // The mapping here is trivial and the same for all outputs - slice offsets
+  // are applied in the indexing from slice outputs to slice inputs.
+  auto launch_dims = launch_dimensions();
+  // The implementation requires the shapes and layouts to be the same, but we
+  // still use the requested output's shape for clarity.
+  const auto& shape = analysis_.fusion_root(output_id).shape();
+  return GetDefaultThreadIdIndexingMap(launch_dims, unroll_factor_, shape, ctx);
+}
+
+absl::Status InputSlicesFusion::EmitKernel(
+    IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
+    const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
+    std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder) const {
+  TF_ASSIGN_OR_RETURN(Shape element_shape,
+                      GetConsistentInputShapeForRootSlices(
+                          fusion.fused_instructions_computation()));
+  LaunchDimensionsConfig launch_config;
+  launch_config.unroll_factor = unroll_factor_;
+  GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
+  return ParallelLoopEmitter(
+             [&](const llvm_ir::IrArray::Index index) -> absl::Status {
+               return EmitElementForInputFusibleSlices(
+                   elemental_emitter, fusion.fused_instructions_computation(),
+                   inputs, outputs, index, builder);
+             },
+             element_shape, launch_dims, builder, launch_config)
+      .EmitLoop(
+          fusion.name(),
+          GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder));
+}
+
+}  // namespace gpu
+}  // namespace xla
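The guard-and-remap step that EmitElementForInputFusibleSlices generates for each ROOT slice can be sketched on the host as follows (illustration only; `SliceSpec` and `GuardedSliceIndex` are hypothetical names): the input index must lie in [start, limit) in every dimension, strides are required to be 1, and the destination index is the input index minus the slice start.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

struct SliceSpec {
  std::vector<int64_t> starts;  // slice_starts per dimension
  std::vector<int64_t> limits;  // slice_limits per dimension
};

// Returns true and fills `dst_index` if `src_index` falls inside the slice;
// mirrors guarding_cond plus the start-offset subtraction in the emitter.
bool GuardedSliceIndex(const SliceSpec& slice,
                       const std::vector<int64_t>& src_index,
                       std::vector<int64_t>* dst_index) {
  dst_index->resize(src_index.size());
  for (std::size_t dim = 0; dim < src_index.size(); ++dim) {
    if (src_index[dim] < slice.starts[dim] ||
        src_index[dim] >= slice.limits[dim]) {
      return false;  // guarding_cond is false: nothing is written.
    }
    (*dst_index)[dim] = src_index[dim] - slice.starts[dim];
  }
  return true;
}
```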
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.h b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.h
new file mode 100644
index 0000000..e653224
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.h
@@ -0,0 +1,79 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_INPUT_SLICES_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_INPUT_SLICES_H_
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/util.h"
+
+namespace xla {
+namespace gpu {
+
+// Generates code for input-fusible slices.
+//
+// Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes
+// of all ROOT slices need to be the same while their output shapes can be
+// different. On the other hand, the input ranges of the slices may overlap.
+// Further generalization/specialization can be added when the need arises.
+class InputSlicesFusion : public KernelFusionEmitterBase {
+ public:
+  explicit InputSlicesFusion(const HloFusionAnalysis& analysis)
+      : analysis_(analysis),
+        unroll_factor_(CeilOfRatio(
+            8, analysis.input_output_info().smallest_output_dtype_bits)) {}
+  LaunchDimensions launch_dimensions() const override;
+
+  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+      int64_t output_id, mlir::MLIRContext* ctx) const override;
+
+  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+      int64_t root_index, int64_t hero_operand_index,
+      mlir::MLIRContext* ctx) const override {
+    // TODO(b/319081342): Implement this.
+    return std::nullopt;
+  }
+
+ protected:
+  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+                          const HloFusionInstruction& fusion,
+                          const LaunchDimensions& launch_dims,
+                          std::vector<llvm_ir::IrArray> inputs,
+                          std::vector<llvm_ir::IrArray> outputs,
+                          llvm::IRBuilder<>* builder) const override;
+
+ private:
+  const HloFusionAnalysis& analysis_;
+  const int unroll_factor_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_LEGACY_INPUT_SLICES_H_
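The unroll factor above is CeilOfRatio(8, smallest_output_dtype_bits). Assuming CeilOfRatio is ordinary ceiling division, it evaluates as follows (illustration only):

```cpp
#include <cstdint>
#include <iostream>

int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  for (int64_t bits : {4, 8, 16, 32}) {
    std::cout << bits << "-bit output -> unroll " << CeilDiv(8, bits) << "\n";
  }
  // 4-bit outputs unroll by 2; 8-, 16- and 32-bit outputs do not unroll (1).
  return 0;
}
```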
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc
new file mode 100644
index 0000000..bb9f510
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc
@@ -0,0 +1,104 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/input_slices.h"
+
+#include <optional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/affine_map_printer.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class InputSlicesTest : public HloTestBase {
+ public:
+  void SetUp() override {
+    HloTestBase::SetUp();
+    printer_ =
+        AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
+                         {"chunk_id", "unroll_id"});
+  }
+
+ protected:
+  DebugOptions GetDebugOptionsForTest() override {
+    auto opts = HloTestBase::GetDebugOptionsForTest();
+    opts.set_xla_gpu_mlir_emitter_level(0);
+    return opts;
+  }
+  AffineMapPrinter printer_;
+  mlir::MLIRContext mlir_context_;
+};
+
+TEST_F(InputSlicesTest, ThreadIndexing) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fused_computation {
+      %input = f32[2,3,5,7]{2,1,0,3} parameter(0)
+      slice0 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[1:3],[0:3],[2:7]}
+      slice1 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[0:2],[0:3],[2:7]}
+      ROOT tuple = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) tuple(slice0, slice1)
+    }
+
+    ENTRY entry {
+      %input = f32[2,3,5,7]{2,1,0,3} parameter(0)
+      ROOT %fusion = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) fusion(%input), kind=kLoop, calls=fused_computation
+    })")
+                    .value();
+
+  stream_executor::DeviceDescription device_info =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis_fused = HloFusionAnalysis::Create(*root, device_info);
+
+  auto emitter =
+      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
+  auto fusion = dynamic_cast<InputSlicesFusion*>(emitter.get());
+  ASSERT_NE(fusion, nullptr);
+
+  auto thread_id_to_output_indexing =
+      fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_);
+  EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
+              MatchIndexingString(R"(
+    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0,
+      ((bl_x * 128 + th_x) floordiv 3) mod 2,
+       (bl_x * 128 + th_x) mod 3,
+       (bl_x * 128 + th_x) floordiv 6)
+    domain:
+    th_x in [0, 127]
+    th_y in [0, 0]
+    th_z in [0, 0]
+    bl_x in [0, 1]
+    bl_y in [0, 0]
+    bl_z in [0, 0]
+    chunk_id in [0, 0]
+    unroll_id in [0, 0]
+    bl_x * 128 + th_x in [0, 29]
+  )"));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
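Both numbers in the expected domain are recoverable from the HLO: the launch covers the slices' common input shape f32[2,3,5,7] (2 * 3 * 5 * 7 = 210 elements, so ceil(210 / 128) = 2 blocks and bl_x in [0, 1]), while each output slice has only 1 * 2 * 3 * 5 = 30 elements, hence the constraint bl_x * 128 + th_x in [0, 29]; with 32-bit outputs the unroll factor is 1, so chunk_id and unroll_id stay at 0.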
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc
new file mode 100644
index 0000000..e6ce5f1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc
@@ -0,0 +1,132 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/loop.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <optional>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/numeric/bits.h"
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/parallel_loop_emitter.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/macros.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+const Shape& GetElementShape(const HloFusionAnalysis& analysis) {
+  const Shape* shape = &analysis.fusion_root(0).shape();
+  while (shape->IsTuple()) {
+    shape = &shape->tuple_shapes(0);
+  }
+  return *shape;
+}
+
+}  // namespace
+
+LoopFusion::LoopFusion(const HloFusionAnalysis& analysis)
+    : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {}
+
+std::optional<IndexingMap> LoopFusion::ComputeThreadIdToOutputIndexing(
+    int64_t root_index, mlir::MLIRContext* ctx) const {
+  auto launch_dims = launch_dimensions();
+  return GetDefaultThreadIdIndexingMap(launch_dims, config_.unroll_factor,
+                                       GetElementShape(analysis_), ctx);
+}
+
+std::optional<IndexingMap> LoopFusion::ComputeThreadIdToInputIndexing(
+    int64_t root_index, int64_t hero_operand_index,
+    mlir::MLIRContext* ctx) const {
+  std::optional<IndexingMap> thread_id_to_output_indexing =
+      ComputeThreadIdToOutputIndexing(root_index, ctx);
+  if (!thread_id_to_output_indexing.has_value()) {
+    return std::nullopt;
+  }
+  const HloInstruction* fusion_root =
+      &analysis_.fusion_root(root_index).instruction();
+  auto output_to_input_indexing =
+      ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx);
+  IndexingMapSet output_to_input_indexing_set =
+      output_to_input_indexing.indexing_maps[hero_operand_index];
+  // Since we are computing the indexing for a non-fusion op, there is only one
+  // indexing map per operand.
+  CHECK_EQ(output_to_input_indexing_set.size(), 1);
+  IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps(
+      *thread_id_to_output_indexing, *output_to_input_indexing_set.begin());
+  thread_id_to_input_indexing_map.Simplify();
+  return thread_id_to_input_indexing_map;
+}
+
+absl::Status LoopFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
+                                    const HloFusionInstruction& fusion,
+                                    const LaunchDimensions& launch_dims,
+                                    std::vector<llvm_ir::IrArray> inputs,
+                                    std::vector<llvm_ir::IrArray> outputs,
+                                    llvm::IRBuilder<>* builder) const {
+  GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
+  FusedIrEmitter fused_emitter(elemental_emitter);
+  for (int i = 0; i < fusion.fused_parameters().size(); i++) {
+    fused_emitter.BindGenerator(
+        *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) {
+          return inputs[i].EmitReadArrayElement(index, builder);
+        });
+  }
+  TF_ASSIGN_OR_RETURN(
+      auto element_generator,
+      fused_emitter.GetGenerator(*fusion.fused_expression_root()));
+
+  llvm::Type* index_type =
+      GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
+
+  return ParallelLoopEmitter(element_generator, outputs, launch_dims, builder,
+                             config_)
+      .EmitLoop(fusion.name(), index_type);
+}
+
+LaunchDimensions LoopFusion::launch_dimensions() const {
+  return CalculateLaunchDimensions(GetElementShape(analysis_),
+                                   analysis_.device_info(), config_);
+}
+
+}  // namespace gpu
+}  // namespace xla
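A concrete instance of the composition performed in ComputeThreadIdToInputIndexing, taken from the Broadcast test in loop_test.cc later in this change: for f32[10,20,30] = broadcast(f32[20]) with dimensions={1}, the thread-to-output map sends linear = bl_x * 128 + th_x to (linear floordiv 600, (linear floordiv 30) mod 20, linear mod 30); the broadcast's output-to-input map keeps only dimension 1; composing and simplifying gives ((bl_x * 128 + th_x) floordiv 30) mod 20, which matches the expected string in that test.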
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop.h b/third_party/xla/xla/service/gpu/fusions/legacy/loop.h
new file mode 100644
index 0000000..30e5007
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop.h
@@ -0,0 +1,65 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_LOOP_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_LOOP_H_
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+
+namespace xla {
+namespace gpu {
+
+// Generic loop fusion.
+class LoopFusion : public KernelFusionEmitterBase {
+ public:
+  explicit LoopFusion(const HloFusionAnalysis& analysis);
+  LaunchDimensions launch_dimensions() const override;
+
+  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+      int64_t root_index, mlir::MLIRContext* ctx) const override;
+
+  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+      int64_t root_index, int64_t hero_operand_index,
+      mlir::MLIRContext* ctx) const override;
+
+ protected:
+  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+                          const HloFusionInstruction& fusion,
+                          const LaunchDimensions& launch_dims,
+                          std::vector<llvm_ir::IrArray> inputs,
+                          std::vector<llvm_ir::IrArray> outputs,
+                          llvm::IRBuilder<>* builder) const override;
+
+ private:
+  const HloFusionAnalysis& analysis_;
+  LaunchDimensionsConfig config_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_LEGACY_LOOP_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc
new file mode 100644
index 0000000..a05508c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc
@@ -0,0 +1,222 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <optional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/statusor.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/affine_map_printer.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class LoopTest : public HloTestBase {
+ public:
+  void SetUp() override {
+    HloTestBase::SetUp();
+
+    printer_ =
+        AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
+                         {"chunk_id", "unroll_id"});
+  }
+
+ protected:
+  stream_executor::DeviceDescription device_info_ =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  AffineMapPrinter printer_;
+  mlir::MLIRContext mlir_context_;
+};
+
+absl::StatusOr<std::unique_ptr<KernelFusionInterface>> GetFusion(
+    const HloFusionAnalysis& analysis) {
+  auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
+  auto fusion = dynamic_cast<KernelFusionInterface*>(emitter.get());
+  TF_RET_CHECK(fusion != nullptr);
+
+  // `emitter` and `fusion` point to the same object; release the original
+  // unique_ptr so ownership can be returned through the downcast pointer.
+  emitter.release();
+  return std::unique_ptr<KernelFusionInterface>{fusion};
+}
+
+TEST_F(LoopTest, ThreadIndexingUnrolled) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    neg {
+      %input = f32[100,200,300] parameter(0)
+      ROOT neg = f32[100,200,300] negate(%input)
+    }
+
+    ENTRY entry {
+      %input = f32[100,200,300] parameter(0)
+      ROOT %fusion = f32[100,200,300] fusion(%input), kind=kLoop, calls=neg
+    })")
+                    .value();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis));
+  auto thread_id_to_output_indexing =
+      loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0,
+                                                   &mlir_context_);
+
+  EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
+              MatchIndexingString(R"(
+  (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
+    (bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000,
+    ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200,
+    ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id
+  )
+  domain:
+  th_x in [0, 127]
+  th_y in [0, 0]
+  th_z in [0, 0]
+  bl_x in [0, 1007]
+  bl_y in [0, 0]
+  bl_z in [0, 0]
+  chunk_id in [0, 11]
+  unroll_id in [0, 3]
+  bl_x * 128 + chunk_id * 129024 + th_x in [0, 1499999]
+)"));
+}
+
+TEST_F(LoopTest, ThreadIndexingNotUnrolled) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    neg {
+      %input = f32[20] parameter(0)
+      ROOT neg = f32[20] negate(%input)
+    }
+
+    ENTRY entry {
+      %input = f32[20] parameter(0)
+      ROOT %fusion = f32[20] fusion(%input), kind=kLoop, calls=neg
+    })")
+                    .value();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis));
+  auto thread_id_to_output_indexing =
+      loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0,
+                                                   &mlir_context_);
+  EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
+              MatchIndexingString(R"(
+              (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x)
+              domain:
+              th_x in [0, 19]
+              th_y in [0, 0]
+              th_z in [0, 0]
+              bl_x in [0, 0]
+              bl_y in [0, 0]
+              bl_z in [0, 0]
+              chunk_id in [0, 0]
+              unroll_id in [0, 0]
+            )"));
+  auto thread_id_to_input_indexing =
+      loop_fusion->ComputeThreadIdToInputIndexing(
+          /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
+  EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_),
+              MatchIndexingString(R"(
+              (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x)
+              domain:
+              th_x in [0, 19]
+              th_y in [0, 0]
+              th_z in [0, 0]
+              bl_x in [0, 0]
+              bl_y in [0, 0]
+              bl_z in [0, 0]
+              chunk_id in [0, 0]
+              unroll_id in [0, 0]
+            )"));
+}
+
+TEST_F(LoopTest, Broadcast) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    bcast {
+      %input = f32[20] parameter(0)
+      ROOT bcast = f32[10, 20, 30] broadcast(%input), dimensions={1}
+    }
+
+    ENTRY entry {
+      %input = f32[20] parameter(0)
+      ROOT %fusion = f32[10, 20, 30] fusion(%input), kind=kLoop, calls=bcast
+    })")
+                    .value();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis));
+  auto thread_id_to_output_indexing =
+      loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0,
+                                                   &mlir_context_);
+  EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
+              MatchIndexingString(R"(
+              (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
+                (bl_x * 128 + th_x) floordiv 600,
+                ((bl_x * 128 + th_x) floordiv 30) mod 20,
+                (bl_x * 128 + th_x) mod 30)
+                domain:
+                th_x in [0, 127]
+                th_y in [0, 0]
+                th_z in [0, 0]
+                bl_x in [0, 46]
+                bl_y in [0, 0]
+                bl_z in [0, 0]
+                chunk_id in [0, 0]
+                unroll_id in [0, 0]
+                bl_x * 128 + th_x in [0, 5999]
+            )"));
+  auto thread_id_to_input_indexing =
+      loop_fusion->ComputeThreadIdToInputIndexing(
+          /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
+  EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_),
+              MatchIndexingString(R"(
+              (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] ->
+                  (((bl_x * 128 + th_x) floordiv 30) mod 20)
+                domain:
+                th_x in [0, 127]
+                th_y in [0, 0]
+                th_z in [0, 0]
+                bl_x in [0, 46]
+                bl_y in [0, 0]
+                bl_z in [0, 0]
+                chunk_id in [0, 0]
+                unroll_id in [0, 0]
+                bl_x * 128 + th_x in [0, 5999]
+            )"));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc
new file mode 100644
index 0000000..e009ea1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc
@@ -0,0 +1,1330 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/reduction.h"
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/container/node_hash_map.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Support/LLVM.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/legacy/tiling_util.h"
+#include "xla/service/gpu/fusions/reduction_base.h"
+#include "xla/service/gpu/fusions/thunk_util.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/ir_emitter_nested.h"
+#include "xla/service/gpu/kernel_arguments.h"
+#include "xla/service/gpu/kernel_reuse_cache.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/parallel_loop_emitter.h"
+#include "xla/service/gpu/reduction_utils.h"
+#include "xla/service/gpu/runtime/kernel_thunk.h"
+#include "xla/service/gpu/runtime/thunk.h"
+#include "xla/service/gpu/target_util.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/service/llvm_ir/kernel_support_library.h"
+#include "xla/service/llvm_ir/llvm_loop.h"
+#include "xla/service/llvm_ir/llvm_util.h"
+#include "xla/service/llvm_ir/loop_emitter.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using TypedPointer = std::pair<llvm::Value* const, llvm::Type* const>;
+
+// Fusion root -> array of indexes, one per reduction output.
+using ReductionOutputMap =
+    ConstHloInstructionMap<absl::Span<llvm_ir::IrArray const>>;
+
+using ExtraOutputGensMap = ConstHloInstructionMap<llvm_ir::ElementGenerator>;
+
+int GetNumOutputs(const Shape& shape) {
+  if (shape.IsTuple()) {
+    return shape.tuple_shapes_size();
+  }
+  return 1;
+}
+
+const Shape& OutputShape(const Shape& output_shape, int output_index) {
+  CHECK(output_index == 0 || output_shape.IsTuple());
+  return output_shape.IsTuple() ? output_shape.tuple_shapes(output_index)
+                                : output_shape;
+}
+
+llvm::Type* GetIndexType(const HloFusionInstruction& fusion,
+                         const Tiling& tiling, llvm::IRBuilder<>* builder) {
+  return GetIndexTypeForKernel(
+      &fusion, tiling.GetNumThreadsPerBlock() * tiling.GetNumBlocks(), builder);
+}
+
+llvm::Value* CastSharedToGlobal(llvm::IRBuilder<>* builder, llvm::Value* input,
+                                llvm::Type* element_type, llvm::Twine name) {
+  return builder->CreateAddrSpaceCast(
+      input,
+      llvm::PointerType::get(element_type,
+                             /*AddressSpace=*/0),
+      name);
+}
+
+class ReductionEmitter {
+ public:
+  ReductionEmitter(const HloFusionAnalysis& analysis,
+                   const ReductionInfo& reduction_codegen_info,
+                   IrEmitterContext& ir_emitter_context,
+                   const HloFusionInstruction& fusion,
+                   llvm::IRBuilder<>* builder)
+      : builder_(builder),
+        elemental_emitter_(ir_emitter_context, builder_),
+        analysis_(analysis),
+        reduction_codegen_info_(reduction_codegen_info),
+        ir_emitter_context_(ir_emitter_context),
+        fusion_(fusion),
+        index_ty_(GetIndexType(fusion, reduction_codegen_info.GetTiling(),
+                               elemental_emitter_.builder())) {
+    for (auto hero : analysis.fusion_heroes()) {
+      if (hero.opcode() == HloOpcode::kReduce) {
+        for (int i = 0; i < hero.instruction().operand_count() / 2; ++i) {
+          CHECK(LayoutUtil::IsMonotonicWithDim0Major(
+              hero.instruction().operand(i)->shape().layout()))
+              << "reduction-layout-normalizer must run before code generation";
+        }
+      }
+    }
+  }
+
+  absl::StatusOr<FusionEmissionResult> EmitInitializers();
+  absl::Status EmitKernel(const LaunchDimensions& launch_dims,
+                          std::vector<llvm_ir::IrArray> inputs,
+                          std::vector<llvm_ir::IrArray> outputs);
+
+ private:
+  friend class ReductionGroupEmitter;
+
+  absl::StatusOr<std::unique_ptr<Thunk>> BuildKernelThunkForFusion(
+      const LaunchDimensions& launch_dimensions,
+      absl::string_view discriminator,
+      std::function<absl::Status(std::vector<llvm_ir::IrArray>,
+                                 std::vector<llvm_ir::IrArray>)>
+          kernel_builder_fn);
+
+  absl::StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
+      const HloInstruction* fusion_root, BufferAllocation::Slice dest_slice,
+      int output_index);
+
+  absl::Status EmitIRForReduction(
+      absl::Span<const HloInstruction* const> instr_index_group,
+      FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays,
+      const Shape& input_shape);
+
+  void MaybeEmitFenceForAMDGPU();
+  void EmitSyncThreads();
+
+  int ReducedDimensionSize() const {
+    return reduction_codegen_info_.GetTiling().GetShape()[2];
+  }
+
+  llvm::IRBuilder<>* builder_;
+  GpuElementalIrEmitter elemental_emitter_;
+  const HloFusionAnalysis& analysis_;
+  const ReductionInfo& reduction_codegen_info_;
+  IrEmitterContext& ir_emitter_context_;
+  const HloFusionInstruction& fusion_;
+  llvm::Type* index_ty_;
+};
+
+class ReductionEmitter;
+
+class ReductionGroupEmitter {
+ public:
+  struct ReductionCalculationState {
+    std::optional<llvm_ir::SharedMemoryTile> shared_cache;
+    llvm::Value* initial_value;
+    llvm::AllocaInst* partial_result_address;
+    llvm::AllocaInst* input_address;
+    llvm_ir::ElementGenerator input_gen;
+  };
+
+  ReductionGroupEmitter(
+      ReductionEmitter& reduction_emitter,
+      absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
+      const ReductionOutputMap& result_ir_arrays,
+      FusedIrEmitter& fused_emitter);
+
+  const ReductionCalculationState& GetCalculationStateFor(
+      const HloInstruction* instruction, int operand_idx) const {
+    const ReductionOpState& op_state = state_.at(instruction);
+    CHECK_LT(operand_idx, op_state.size());
+    return op_state[operand_idx];
+  }
+
+  void SetCalculationStateFor(
+      const ReductionCalculationState& calculation_state,
+      const HloInstruction* instruction, int operand_idx) {
+    ReductionOpState& op_state = state_[instruction];
+    CHECK_EQ(operand_idx, op_state.size());
+    op_state.push_back(calculation_state);
+  }
+
+  void EmitReductionOutputForRowReduction(
+      const TilingKernelInfo& tiling_kernel_info,
+      const HloReduceInstruction* reduction,
+      const std::vector<const HloInstruction*>& roots) const;
+
+  void EmitReductionOutputForColumnReduction(
+      const TilingKernelInfo& tiling_kernel_info,
+      const HloReduceInstruction* reduction,
+      const std::vector<const HloInstruction*>& roots) const;
+
+  void EmitFullWarpShuffleDownLoopForReduce(
+      const HloComputation* reducer,
+      absl::Span<TypedPointer const> partial_result_addresses,
+      int threads_per_block, int num_results_per_warp) const;
+
+  void WriteReductionOutput(const TilingKernelInfo& tiling_kernel_info,
+                            const HloReduceInstruction* reduction,
+                            const std::vector<const HloInstruction*>& roots,
+                            absl::Span<TypedPointer const> values) const;
+
+  llvm_ir::IrArray::Index GetOutputIndexForReduction(
+      const TilingKernelInfo& tiling_kernel_info,
+      const HloReduceInstruction* reduction, const HloInstruction* root,
+      int output_idx) const;
+
+  void GenerateElementForReducer(const HloReduceInstruction* reduction,
+                                 const llvm_ir::IrArray::Index& index) const;
+
+  absl::Status EmitExtraOutputsForReduce(
+      const Shape& reduction_operand_shape,
+      const llvm_ir::IrArray::Index& index,
+      const ExtraOutputGensMap& extra_output_gens);
+
+ private:
+  ReductionEmitter& reduction_emitter_;
+  const ReductionOutputMap& result_ir_arrays_;
+
+  // One state per reduction operand.
+  using ReductionOpState = absl::InlinedVector<ReductionCalculationState, 2>;
+
+  // HloInstruction -> operand_idx -> cache
+  absl::flat_hash_map<const HloInstruction*, ReductionOpState> state_;
+};
+
+// Creates accumulator allocas, populates them with the initial values of the
+// reductions, and allocates the __shared__ caches.
+ReductionGroupEmitter::ReductionGroupEmitter(
+    ReductionEmitter& reduction_emitter,
+    absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
+    const ReductionOutputMap& result_ir_arrays, FusedIrEmitter& fused_emitter)
+    : reduction_emitter_(reduction_emitter),
+      result_ir_arrays_(result_ir_arrays) {
+  const ReductionInfo& reduction_info =
+      reduction_emitter_.reduction_codegen_info_;
+  VLOG(10) << "Emit prologue for reduction: "
+           << reduction_emitter_.fusion_.ToString();
+
+  auto* builder = reduction_emitter_.builder_;
+  for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) {
+    for (int op_result_idx = 0;
+         op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) {
+      Shape result_shape = OutputShape(reduce_hlo->shape(), op_result_idx);
+
+      llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(
+          result_shape.element_type(), builder->GetInsertBlock()->getModule());
+      llvm::AllocaInst* reduction_input_address =
+          llvm_ir::EmitAllocaAtFunctionEntry(
+              element_type, "reduction_input_address", builder);
+
+      llvm::AllocaInst* result_address = llvm_ir::EmitAllocaAtFunctionEntry(
+          element_type, "partial_reduction_result", builder);
+
+      const HloInstruction* init_value =
+          reduce_hlo->init_values()[op_result_idx];
+
+      // Initialize the partial result with the initial value of the reduction.
+      llvm::Value* init_ir_value = (*fused_emitter.GetGenerator(
+          *init_value))(llvm_ir::IrArray::Index(builder->getInt32Ty()))
+                                       .value();
+
+      builder->CreateStore(init_ir_value, result_address);
+      const Tiling& tiling = reduction_info.GetTiling();
+      auto shared_cache = [&]() -> std::optional<llvm_ir::SharedMemoryTile> {
+        auto* module = reduction_emitter.ir_emitter_context_.llvm_module();
+        if (reduction_info.IsRowReduction()) {
+          // Multi-row reductions do not use shared memory.
+          if (RowReductionGetRowsPerWarp(
+                  reduction_emitter_.ReducedDimensionSize()) > 1) {
+            return std::nullopt;
+          }
+          // Allocate one shared memory element per warp.
+          auto block_size = tiling.GetThreadsPerBlock();
+          CHECK_EQ(block_size[ReductionDimensions::kRowMinorReducedDimension] %
+                       WarpSize(),
+                   0);
+          return llvm_ir::AllocateSharedMemoryTile(
+              module, element_type,
+              {block_size[ReductionDimensions::kRowKeptDimension],
+               block_size[ReductionDimensions::kRowMinorReducedDimension] /
+                   WarpSize()},
+              "shared_cache");
+        }
+        const auto& num_threads = tiling.GetThreadsPerBlock();
+        int n = num_threads[ReductionDimensions::kColReducedDimension];
+        CHECK_EQ(n, num_threads[ReductionDimensions::kColMinorKeptDimension]);
+        // The "+1" is used to avoid bank conflicts.
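+        // (With an n x n tile, the threads of a warp that read one column
+        // would all hit the same shared memory bank; padding the minor
+        // dimension to n + 1 staggers those accesses across banks.)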
+        return llvm_ir::AllocateSharedMemoryTile(module, element_type,
+                                                 {n, n + 1}, "shared_cache");
+      }();
+
+      llvm_ir::ElementGenerator input_gen =
+          *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]);
+      SetCalculationStateFor({shared_cache, init_ir_value, result_address,
+                              reduction_input_address, input_gen},
+                             reduce_hlo, op_result_idx);
+    }
+  }
+}
+
+void ReductionEmitter::MaybeEmitFenceForAMDGPU() {
+  auto* module = builder_->GetInsertBlock()->getModule();
+  if (IsAMDGPU(module) &&
+      ir_emitter_context_.rocm_compute_capability().fence_before_barrier()) {
+    builder_->CreateFence(
+        llvm::AtomicOrdering::SequentiallyConsistent,
+        builder_->getContext().getOrInsertSyncScopeID("workgroup"));
+  }
+}
+
+void ReductionEmitter::EmitSyncThreads() {
+  MaybeEmitFenceForAMDGPU();
+  EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder_);
+}
+
+// Builds a thunk that calls a new or reused kernel for a fusion operation.
+//
+// The caller must specify the same launch dimensions for fusions which have
+// the same computation.
+//
+// If a given fusion is implemented using multiple kernels, then for each
+// kernel we should provide a discriminator, such as "init" and "impl".
+//
+// The builder_fn is only invoked if the kernel couldn't be reused.
+//
+// This is the typical usage pattern of this method:
+//
+// ```
+// auto builder_fn = [](std::vector<llvm_ir::IrArray> inputs,
+//                      std::vector<llvm_ir::IrArray> outputs) { ... };
+// TF_ASSIGN_OR_RETURN(
+//   auto thunk,
+//   BuildKernelThunkForFusion(..., launch_dimensions, builder_fn));
+// AddThunkToThunkSequence(std::move(thunk))
+// ```
+absl::StatusOr<std::unique_ptr<Thunk>>
+ReductionEmitter::BuildKernelThunkForFusion(
+    const LaunchDimensions& launch_dimensions, absl::string_view discriminator,
+    std::function<absl::Status(std::vector<llvm_ir::IrArray>,
+                               std::vector<llvm_ir::IrArray>)>
+        kernel_builder_fn) {
+  const HloComputation* fused_computation =
+      fusion_.fused_instructions_computation();
+  std::string suggested_kernel_name = std::string(fusion_.name());
+
+  TF_ASSIGN_OR_RETURN(auto kernel_arguments,
+                      KernelArguments::Create(
+                          ir_emitter_context_.buffer_assignment(), &fusion_));
+
+  auto [status_or_entry, cached] =
+      ir_emitter_context_.kernel_cache().GetWithStatus(
+          fused_computation, kernel_arguments.args(), discriminator,
+          [&]() -> absl::StatusOr<KernelReuseCache::Entry> {
+            llvm::Function* kernel;
+            std::vector<llvm_ir::IrArray> input_arrays;
+            std::vector<llvm_ir::IrArray> output_arrays;
+            TF_ASSIGN_OR_RETURN(
+                std::tie(kernel, input_arrays, output_arrays),
+                BuildKernelPrototype(ir_emitter_context_, suggested_kernel_name,
+                                     kernel_arguments.args(),
+                                     fusion_.operand_count(), launch_dimensions,
+                                     builder_));
+            TF_RETURN_IF_ERROR(kernel_builder_fn(input_arrays, output_arrays));
+            // Shared memory is allocated statically.
+            return {{kernel->getName().str(), launch_dimensions,
+                     /*cluster_dim=*/std::nullopt,
+                     /*shmem_bytes=*/0}};
+          });
+  TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry);
+  if (cached) {
+    VLOG(3) << "Reuse: " << suggested_kernel_name << " -> "
+            << entry->kernel_name;
+  }
+
+  return std::make_unique<KernelThunk>(
+      &fusion_, entry->kernel_name, kernel_arguments.args(), launch_dimensions,
+      entry->cluster_dim, entry->shmem_bytes);
+}
+
+absl::Status ReductionGroupEmitter::EmitExtraOutputsForReduce(
+    const Shape& reduction_operand_shape, const llvm_ir::IrArray::Index& index,
+    const ExtraOutputGensMap& extra_output_gens) {
+  if (extra_output_gens.empty()) {
+    return absl::OkStatus();
+  }
+
+  auto* builder = reduction_emitter_.builder_;
+  // Compute all extra output values before writing them. This avoids
+  // overwriting aliased input/output buffers before all reads have occurred.
+  std::vector<std::pair<const HloInstruction*, llvm::Value*>>
+      extra_output_ir_values;
+  extra_output_ir_values.reserve(extra_output_gens.size());
+
+  auto get_index = [&](const HloInstruction* instr) {
+    const Shape& s = instr->shape();
+    return ShapeUtil::EqualIgnoringElementType(reduction_operand_shape, s)
+               ? index
+               : index.SourceIndexOfBitcast(reduction_operand_shape, s,
+                                            builder);
+  };
+
+  for (const auto& [instr, generator] : extra_output_gens) {
+    TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
+                        generator(get_index(instr)));
+    extra_output_ir_values.emplace_back(instr, extra_output_ir_value);
+  }
+
+  for (const auto& [instr, generator] : extra_output_ir_values) {
+    absl::Span<llvm_ir::IrArray const> result_ir = result_ir_arrays_.at(instr);
+    CHECK_EQ(result_ir.size(), 1);
+    result_ir[0].EmitWriteArrayElement(get_index(instr), generator, builder);
+  }
+  return absl::OkStatus();
+}
+
+absl::StatusOr<std::unique_ptr<Thunk>>
+ReductionEmitter::BuildFusedInitializerThunk(const HloInstruction* fusion_root,
+                                             BufferAllocation::Slice dest_slice,
+                                             int output_index) {
+  const HloReduceInstruction* reduce =
+      DynCast<HloReduceInstruction>(fusion_root);
+  TF_RET_CHECK(reduce);
+
+  const HloInstruction* init_value = reduce->init_values()[0];
+  TF_ASSIGN_OR_RETURN(
+      std::optional<std::unique_ptr<Thunk>> constant_init_thunk,
+      BuildConstantInitializerThunk(ir_emitter_context_, fusion_root,
+                                    init_value, dest_slice));
+  if (constant_init_thunk) {
+    return *std::move(constant_init_thunk);
+  }
+
+  const Shape& dest_shape = fusion_root->shape();
+
+  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+      dest_shape, ir_emitter_context_.gpu_device_info());
+  const HloComputation* fused_computation =
+      fusion_.fused_instructions_computation();
+
+  auto builder_fn = [&](std::vector<llvm_ir::IrArray> inputs,
+                        std::vector<llvm_ir::IrArray> outputs) -> absl::Status {
+    FusedIrEmitter fused_emitter(elemental_emitter_);
+    for (int i = 0; i < fused_computation->num_parameters(); i++) {
+      fused_emitter.BindGenerator(
+          *fused_computation->parameter_instruction(i),
+          [builder = builder_,
+           input = inputs[i]](llvm_ir::IrArray::Index index) {
+            return input.EmitReadArrayElement(index, builder);
+          });
+    }
+    HloInstruction* instr = fused_computation->root_instruction();
+    if (instr->opcode() == HloOpcode::kTuple) {
+      instr = instr->mutable_operand(output_index);
+    } else {
+      CHECK_EQ(0, output_index);
+    }
+    TF_RET_CHECK(instr->shape().IsArray());
+    TF_ASSIGN_OR_RETURN(auto generator,
+                        fused_emitter.GetGenerator(*instr->operand(1)));
+    TF_RETURN_IF_ERROR(ParallelLoopEmitter(generator, {outputs[output_index]},
+                                           launch_dimensions, builder_)
+                           .EmitLoop(fusion_.name()));
+    return absl::OkStatus();
+  };
+
+  return BuildKernelThunkForFusion(launch_dimensions,
+                                   /*discriminator=*/
+                                   absl::StrCat("init_", output_index),
+                                   builder_fn);
+}
+
+// Emits a shuffle-down reduction over `partial_result_addresses` using the
+// reduction computation `reducer`, writing the results back into
+// `partial_result_addresses`.
+//
+// There are multiple partial result addresses when emitting a variadic
+// reduction: each one receives its own output value.
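+//
+// For example, with num_results_per_warp == 1 the shuffle distances below are
+// 16, 8, 4, 2, 1, so after the loop lane 0 of each warp holds the reduction of
+// all 32 lanes; with num_results_per_warp == 2 the distances are 8, 4, 2, 1,
+// and lanes 0 and 16 each hold the reduction of their half of the warp.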
+void ReductionGroupEmitter::EmitFullWarpShuffleDownLoopForReduce(
+    const HloComputation* reducer,
+    absl::Span<TypedPointer const> partial_result_addresses,
+    int threads_per_block, int num_results_per_warp) const {
+  // This only works when the block size is a multiple of 32 threads.
+  // We check this here as a mistake in the number of threads per
+  // block is very hard to detect.
+  CHECK_EQ(threads_per_block % 32, 0);
+  CHECK_EQ(WarpSize() % num_results_per_warp, 0);
+
+  auto* builder = reduction_emitter_.builder_;
+  for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) {
+    absl::InlinedVector<llvm::Value*, 2> reduction_params;
+
+    for (auto acc : partial_result_addresses) {
+      reduction_params.push_back(acc.first);
+    }
+
+    for (auto [partial_result_address, element_type] :
+         partial_result_addresses) {
+      int bit_width = llvm_ir::GetSizeInBits(element_type);
+      llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry(
+          element_type, "result_from_other_lane", builder);
+
+      reduction_params.push_back(result_from_other_lane);
+
+      // Bitcast cannot be applied to aggregate types (even packed ones), so
+      // we bitcast addresses of load/store to intN* of the same bit-width.
+      llvm::Type* shuffled_value_type = element_type->isStructTy()
+                                            ? builder->getIntNTy(bit_width)
+                                            : element_type;
+
+      llvm::Value* partial_result =
+          builder->CreateLoad(shuffled_value_type, partial_result_address,
+                              "partial_reduction_result");
+      builder->CreateStore(
+          EmitFullWarpShuffleDown(
+              partial_result, builder->getInt32(distance), builder,
+              reduction_emitter_.ir_emitter_context_.gpu_device_info()),
+          result_from_other_lane);
+    }
+
+    absl::StatusOr<std::vector<llvm::Value*>> returned_scalars =
+        CallNestedComputationWithScalarAddrs(
+            builder, reduction_emitter_.ir_emitter_context_, *reducer,
+            reduction_params);
+    TF_CHECK_OK(returned_scalars.status());
+
+    for (int i = 0; i < returned_scalars->size(); i++) {
+      builder->CreateStore(/*Val=*/returned_scalars->at(i),
+                           /*Ptr=*/partial_result_addresses[i].first);
+    }
+  }
+}
+
+llvm_ir::IrArray::Index ReductionGroupEmitter::GetOutputIndexForReduction(
+    const TilingKernelInfo& tiling_kernel_info,
+    const HloReduceInstruction* reduction, const HloInstruction* root,
+    int output_idx) const {
+  auto* builder = reduction_emitter_.builder_;
+  auto* index_ty = reduction_emitter_.index_ty_;
+
+  // 1d or 2d output index (for row/column reduction).
+  auto projected_index = [&]() -> llvm_ir::IrArray::Index {
+    const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
+    const auto& offset = tiling_kernel_info.tile_origin;
+    const auto& shape = reduction_info.GetTiling().GetXlaShape();
+    const auto& thread_ids = tiling_kernel_info.thread_id_info.thread_ids;
+    if (reduction_info.IsRowReduction()) {
+      constexpr int kDim = ReductionDimensions::kRowKeptDimension;
+      return {{builder->CreateAdd(offset[kDim], thread_ids[kDim])},
+              {shape.dimensions(kDim)},
+              index_ty};
+    }
+    auto* major_idx = offset[ReductionDimensions::kColMajorKeptDimension];
+    auto* minor_idx = builder->CreateAdd(
+        offset[ReductionDimensions::kColMinorKeptDimension],
+        thread_ids[ReductionDimensions::kColReducedDimension]);
+    return {{major_idx, minor_idx},
+            ShapeUtil::DeleteDimension(
+                ReductionDimensions::kColReducedDimension, shape),
+            index_ty};
+  }();
+
+  auto physical_shape = ShapeUtil::DeleteDimensions(
+      reduction->dimensions(), reduction->operand(output_idx)->shape());
+  auto physical_index =
+      projected_index.SourceIndexOfBitcast(physical_shape, builder);
+  return llvm_ir::IrArray::Index(physical_index.multidim(),
+                                 OutputShape(reduction->shape(), output_idx),
+                                 index_ty)
+      .SourceIndexOfBitcast(OutputShape(root->shape(), output_idx), builder);
+}
+
+void ReductionGroupEmitter::WriteReductionOutput(
+    const TilingKernelInfo& tiling_kernel_info,
+    const HloReduceInstruction* reduction,
+    const std::vector<const HloInstruction*>& roots,
+    const absl::Span<TypedPointer const> values) const {
+  auto* builder = reduction_emitter_.builder_;
+  const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
+  const HloComputation* reducer = reduction->to_apply();
+  for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) {
+    auto [output_ptr, type] = typed_ptr;
+    for (auto root : roots) {
+      llvm_ir::IrArray::Index output_index =
+          GetOutputIndexForReduction(tiling_kernel_info, reduction, root, oidx);
+
+      llvm::Value* output_address =
+          result_ir_arrays_.at(root)[oidx].EmitArrayElementAddress(
+              output_index, builder, "output_element_address");
+      if (reduction_info.IsRaceFree()) {
+        FusedIrEmitter fused_emitter(reduction_emitter_.elemental_emitter_);
+        llvm::Value* loaded = builder->CreateLoad(type, output_ptr, "output");
+        fused_emitter.BindGenerator(
+            *reduction,
+            [&](const llvm_ir::IrArray::Index& index) { return loaded; });
+        llvm_ir::ElementGenerator gen = *fused_emitter.GetGenerator(*root);
+        llvm::Value* generated = *gen(output_index);
+        builder->CreateStore(generated, output_address);
+      } else {
+        CHECK_EQ(values.size(), 1);
+        CHECK_EQ(roots.size(), 1);
+        CHECK_EQ(reduction, root)
+            << "output fusion is not allowed for racing reductions";
+        TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
+            builder, reduction_emitter_.ir_emitter_context_, *reducer,
+            output_address, output_ptr, type));
+      }
+    }
+  }
+}
+
+void ReductionGroupEmitter::EmitReductionOutputForRowReduction(
+    const TilingKernelInfo& tiling_kernel_info,
+    const HloReduceInstruction* reduction,
+    const std::vector<const HloInstruction*>& roots) const {
+  const HloComputation* reducer = reduction->to_apply();
+  const auto& thread_id_info = tiling_kernel_info.thread_id_info;
+  const auto& thread_ids = thread_id_info.thread_ids;
+  auto* thread_id_x =
+      thread_ids[ReductionDimensions::kRowMinorReducedDimension];
+  auto constant = [&](uint64_t c) -> llvm::Constant* {
+    return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c);
+  };
+
+  auto* builder = reduction_emitter_.builder_;
+  auto is_zero = [&](llvm::Value* value) {
+    return builder->CreateICmpEQ(value, constant(0));
+  };
+
+  int num_outputs = reducer->num_parameters() / 2;
+  absl::InlinedVector<TypedPointer, 2> current_outputs;
+  for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
+    const auto& state = GetCalculationStateFor(reduction, output_idx);
+    current_outputs.push_back(
+        {state.partial_result_address,
+         state.partial_result_address->getAllocatedType()});
+  }
+
+  const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
+  const Tiling& tiling = reduction_info.GetTiling();
+  int num_rows_per_warp =
+      RowReductionGetRowsPerWarp(reduction_emitter_.ReducedDimensionSize());
+  EmitFullWarpShuffleDownLoopForReduce(reducer, absl::MakeSpan(current_outputs),
+                                       tiling.GetNumThreadsPerBlock(),
+                                       num_rows_per_warp);
+
+  KernelSupportLibrary ksl(builder);
+  llvm::Value* warp_id = builder->CreateUDiv(thread_id_x, constant(WarpSize()));
+
+  auto emit_write_output = [&](llvm::Value* write_condition,
+                               const absl::Span<TypedPointer const> values) {
+    ksl.If("reduction_write_output", write_condition, [&] {
+      WriteReductionOutput(tiling_kernel_info, reduction, roots, values);
+    });
+  };
+
+  // The major kept dimension and vector dimension are not tiled, so they're
+  // always in bounds.
+  llvm::Value* is_in_bounds_y = builder->CreateICmpULT(
+      thread_ids[ReductionDimensions::kRowKeptDimension],
+      tiling_kernel_info
+          .output_tile_bounds[ReductionDimensions::kRowKeptDimension]);
+
+  ksl.If("thread_in_bounds", is_in_bounds_y, [&] {
+    if (num_rows_per_warp > 1) {
+      llvm::Value* is_writing_thread = is_zero(builder->CreateAnd(
+          thread_id_x,
+          constant(reduction_emitter_.ReducedDimensionSize() - 1)));
+      emit_write_output(is_writing_thread, current_outputs);
+      return;
+    }
+
+    ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
+      for (int oidx = 0; oidx < num_outputs; oidx++) {
+        auto& state = GetCalculationStateFor(reduction, oidx);
+        state.shared_cache->Store(
+            builder->CreateLoad(current_outputs[oidx].second,
+                                current_outputs[oidx].first),
+            {thread_id_info.thread_ids[ReductionDimensions::kRowKeptDimension],
+             warp_id},
+            builder);
+      }
+    });
+
+    // TODO(cheshire): Don't we want to sync once for everything in the
+    // output, rather than once per reduction?
+    reduction_emitter_.EmitSyncThreads();
+    ksl.If("inter_warp_reduce", is_zero(warp_id), [&] {
+      absl::InlinedVector<TypedPointer, 2> selected_values;
+      for (int oidx = 0; oidx < num_outputs; oidx++) {
+        auto& state = GetCalculationStateFor(reduction, oidx);
+        llvm::Value* block_accum_addr = state.shared_cache->Address(
+            {thread_id_info.thread_ids[ReductionDimensions::kRowKeptDimension],
+             thread_id_info.lane_id},
+            builder);
+
+        llvm::Type* element_type =
+            state.partial_result_address->getAllocatedType();
+
+        // Ensure the initial value address is in the generic address space,
+        // not in scratch.
+        llvm::Value* initial_value_addr =
+            CastSharedToGlobal(builder,
+                               llvm_ir::EmitAllocaAtFunctionEntry(
+                                   element_type, "initial_value_addr", builder),
+                               element_type, /*name=*/"");
+        builder->CreateStore(state.initial_value, initial_value_addr);
+
+        llvm::Value* warp_exists = builder->CreateICmpULT(
+            thread_id_x,
+            constant(tiling.GetThreadsPerBlock()
+                         [ReductionDimensions::kRowMinorReducedDimension] /
+                     WarpSize()));
+
+        llvm::Value* selected_value = builder->CreateSelect(
+            warp_exists, block_accum_addr, initial_value_addr);
+
+        selected_values.push_back({selected_value, element_type});
+      }
+
+      // If only one warp produces the output element, we don't need to emit
+      // an inter-warp reduce. In our tiling, DimX is the minor reduced
+      // dimension. The major reduced dimension is always emitted as a loop.
+      // TODO(b/241414088) If only one warp is present, then the inter-warp
+      // communication through shared memory and the barrier synchronization
+      // are also unnecessary and should be removed.
+      if (tiling.GetThreadsPerBlock()
+              [ReductionDimensions::kRowMinorReducedDimension] > WarpSize()) {
+        EmitFullWarpShuffleDownLoopForReduce(
+            reducer, absl::MakeSpan(selected_values),
+            tiling.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1);
+      }
+
+      emit_write_output(is_zero(thread_id_x), selected_values);
+    });
+  });
+}
+
+// Same arguments as EmitReductionOutputForRowReduction.
+void ReductionGroupEmitter::EmitReductionOutputForColumnReduction(
+    const TilingKernelInfo& tiling_kernel_info,
+    const HloReduceInstruction* reduction,
+    const std::vector<const HloInstruction*>& roots) const {
+  auto* builder = reduction_emitter_.builder_;
+  KernelSupportLibrary ksl(builder);
+  const HloComputation* reducer = reduction->to_apply();
+  const auto& thread_id_info = tiling_kernel_info.thread_id_info;
+  const auto& thread_ids = thread_id_info.thread_ids;
+
+  auto constant = [&](uint64_t c) -> llvm::Constant* {
+    return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c);
+  };
+  auto is_zero = [&](llvm::Value* value) {
+    return builder->CreateICmpEQ(value, constant(0));
+  };
+  const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
+  const Tiling& tiling = reduction_info.GetTiling();
+  int num_outputs = reducer->num_parameters() / 2;
+
+  auto* kept_index = thread_ids[ReductionDimensions::kColMinorKeptDimension];
+  auto* reduced_index = thread_ids[ReductionDimensions::kColReducedDimension];
+
+  // Store the transpose in shared memory.
+  for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
+    const auto& state = GetCalculationStateFor(reduction, output_idx);
+    auto* current_output_value =
+        builder->CreateLoad(state.partial_result_address->getAllocatedType(),
+                            state.partial_result_address);
+    state.shared_cache->Store(current_output_value, {kept_index, reduced_index},
+                              builder);
+  }
+
+  reduction_emitter_.EmitSyncThreads();
+
+  // Get transposed element from shared memory.
+  absl::InlinedVector<TypedPointer, 2> shmem_transposed_addrs;
+  for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
+    const auto& state = GetCalculationStateFor(reduction, output_idx);
+    auto* shmem_transposed_addr =
+        state.shared_cache->Address({reduced_index, kept_index}, builder);
+    shmem_transposed_addrs.push_back(
+        {shmem_transposed_addr, state.shared_cache->GetElementType()});
+  }
+
+  EmitFullWarpShuffleDownLoopForReduce(reducer,
+                                       absl::MakeSpan(shmem_transposed_addrs),
+                                       tiling.GetNumThreadsPerBlock(),
+                                       /*num_results_per_warp=*/1);
+
+  // Some warps in the block are completely outside the bounds of the
+  // tensor, so they should not write any output at all.
+  llvm::Value* has_output = builder->CreateAnd(
+      builder->CreateICmpULT(
+          reduced_index,
+          tiling_kernel_info
+              .output_tile_bounds[ReductionDimensions::kColMinorKeptDimension]),
+      builder->CreateICmpULT(
+          kept_index,
+          tiling_kernel_info
+              .output_tile_bounds[ReductionDimensions::kColReducedDimension]));
+
+  ksl.If("reduction_write_output",
+         builder->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
+           WriteReductionOutput(tiling_kernel_info, reduction, roots,
+                                shmem_transposed_addrs);
+         });
+}
+
+// Generate a single element of the tile (update the accumulator state) for a
+// given reducer.
+void ReductionGroupEmitter::GenerateElementForReducer(
+    const HloReduceInstruction* reduction,
+    const llvm_ir::IrArray::Index& index) const {
+  HloComputation* reducer = reduction->to_apply();
+  auto* builder = reduction_emitter_.builder_;
+  CHECK_EQ(reducer->num_parameters() % 2, 0);
+
+  absl::InlinedVector<llvm::Value*, 2> reduction_accumulators;
+  absl::InlinedVector<llvm::Value*, 2> reduction_input_value;
+  for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) {
+    const auto& state = GetCalculationStateFor(reduction, red_idx);
+
+    llvm::AllocaInst* input_address = state.input_address;
+    auto input_index =
+        index.SourceIndexOfBitcast(reduction->operand(0)->shape(), builder);
+    llvm::Value* const input_ir_value = *state.input_gen(input_index);
+    builder->CreateStore(input_ir_value, input_address);
+    reduction_accumulators.push_back(state.partial_result_address);
+    reduction_input_value.push_back(input_address);
+  }
+
+  absl::InlinedVector<llvm::Value*, 4> reduction_params;
+  for (llvm::Value* acc : reduction_accumulators) {
+    reduction_params.push_back(acc);
+  }
+  for (llvm::Value* value : reduction_input_value) {
+    reduction_params.push_back(value);
+  }
+
+  // Emit a call to the variadic reducer. Since it may be returning a
+  // tuple, we can't return it directly as a value. Instead, before
+  // the call, we create N (N = # arguments in the tuple) allocas, one
+  // for each returned argument, then when we make the call we pass N
+  // pointers as last parameters, the called computation writes into
+  // those pointers, and we have returned values on the stack (as well
+  // as pointers to them).
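+  //
+  // For instance (illustrative only), for a two-output variadic reducer the
+  // call below is conceptually
+  //   reducer(acc0, acc1, in0, in1, ret0, ret1);
+  // where ret0/ret1 name the freshly created allocas, whose values are then
+  // copied back into the accumulators by the loop that follows.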
+  absl::StatusOr<std::vector<llvm::Value*>> returned_scalars =
+      CallNestedComputationWithScalarAddrs(
+          builder, reduction_emitter_.ir_emitter_context_, *reducer,
+          reduction_params);
+  TF_CHECK_OK(returned_scalars.status());
+
+  for (int i = 0; i < returned_scalars->size(); i++) {
+    builder->CreateStore(returned_scalars->at(i), reduction_accumulators[i]);
+  }
+}
+
+// Emits code for the reductions in `instr_index_group`.
+absl::Status ReductionEmitter::EmitIRForReduction(
+    absl::Span<const HloInstruction* const> instr_index_group,
+    FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays,
+    const Shape& input_shape) {
+  ExtraOutputGensMap extra_output_gens;
+  absl::flat_hash_map<const HloReduceInstruction*,
+                      std::vector<const HloInstruction*>>
+      heroes_to_roots;
+  // Keep a list of deduplicated heroes separate from heroes_to_roots to make
+  // the CodeGen deterministic.
+  std::vector<const HloReduceInstruction*> heroes;
+
+  for (const HloInstruction* hlo : instr_index_group) {
+    auto& hero = FindNonTrivialHero(*hlo);
+    if (IsRealReductionHero(*hlo, hero)) {
+      auto reduction = Cast<HloReduceInstruction>(&hero);
+      if (heroes_to_roots.find(reduction) == heroes_to_roots.end()) {
+        heroes.push_back(reduction);
+      }
+      heroes_to_roots[reduction].push_back(hlo);
+    } else {
+      extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo);
+    }
+  }
+
+  CHECK(!heroes.empty()) << "Expected at least one reduce instruction.";
+  const Tiling& tiling = reduction_codegen_info_.GetTiling();
+  CHECK_EQ(tiling.GetNumThreadsPerBlock() % WarpSize(), 0);
+  ReductionGroupEmitter group_emitter(*this, heroes, result_ir_arrays,
+                                      fused_emitter);
+
+  TF_ASSIGN_OR_RETURN(
+      TilingKernelInfo tiling_kernel_info,
+      EmitTilingKernel(
+          builder_, tiling, index_ty_,
+          [&](const TilingThreadIdInfo& thread_id_info,
+              const llvm_ir::IrArray::Index& tile_index,
+              absl::Span<llvm::Value* const> tile_dimensions) {
+            auto emit_element =
+                [&](absl::Span<llvm::Value* const> index_in_tile) {
+                  auto index = tile_index.AddOffset(index_in_tile, builder_);
+
+                  // Emit code to generate the input and perform the reduction
+                  // computation for each reduction instruction.
+                  for (const HloReduceInstruction* reduce : heroes) {
+                    group_emitter.GenerateElementForReducer(reduce, index);
+                  }
+
+                  // Emit code to generate the output for the non-reduction
+                  // instructions in the fusion, if any.
+                  TF_CHECK_OK(group_emitter.EmitExtraOutputsForReduce(
+                      ShapeUtil::MakeShape(
+                          F32, reduction_codegen_info_.GetTiling().GetShape()),
+                      index, extra_output_gens));
+                };
+            EmitTile(builder_, reduction_codegen_info_.GetTiling(),
+                     thread_id_info, tile_dimensions, emit_element);
+          }));
+
+  KernelSupportLibrary ksl(builder_);
+  for (auto reduce : heroes) {
+    if (reduction_codegen_info_.IsRowReduction()) {
+      group_emitter.EmitReductionOutputForRowReduction(
+          tiling_kernel_info, reduce, heroes_to_roots[reduce]);
+    } else {
+      group_emitter.EmitReductionOutputForColumnReduction(
+          tiling_kernel_info, reduce, heroes_to_roots[reduce]);
+    }
+  }
+
+  return absl::OkStatus();
+}
+
+absl::StatusOr<FusionEmissionResult> ReductionEmitter::EmitInitializers() {
+  FusionEmissionResult result;
+  if (reduction_codegen_info_.IsRaceFree()) {
+    return result;
+  }
+  // We need to get the dest slice by traversing the slices assigned to the
+  // fusion, because instructions inside the fusion don't have a buffer
+  // assignment.
+  //
+  // The order of fusion roots is determined by its position in the result
+  // tuple. For example, in the following fused computation
+  //
+  // %fused_computation {
+  //   %a = ...
+  //   %b = ...
+  //   ROOT %root = tuple(%a, %b)
+  // }
+  //
+  // The fusion root with index = 0 is %a, and the fusion root %b has index 1.
+  // Therefore we can get the ordered slices by calling ForEachSubshape on the
+  // result shape.
+  std::vector<BufferAllocation::Slice> slices;
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+      fusion_.shape(), [&](const Shape& subshape, ShapeIndex index) {
+        if (!ShapeUtil::IsLeafIndex(fusion_.shape(), index)) {
+          return absl::OkStatus();
+        }
+
+        TF_ASSIGN_OR_RETURN(
+            BufferAllocation::Slice slice,
+            ir_emitter_context_.buffer_assignment().GetUniqueSlice(&fusion_,
+                                                                   index));
+        slices.push_back(slice);
+        return absl::OkStatus();
+      }));
+
+  absl::Span<HloInstructionAdaptor const> fusion_roots =
+      analysis_.fusion_roots();
+  for (int i = 0; i < fusion_roots.size(); ++i) {
+    const HloInstruction* fusion_root = &fusion_roots[i].instruction();
+
+    if (IsReductionFromOrToContiguousDimensions(*fusion_root)) {
+      TF_ASSIGN_OR_RETURN(
+          result.thunks.emplace_back(),
+          BuildFusedInitializerThunk(fusion_root, slices[i], i));
+    }
+  }
+  return result;
+}
+
+absl::Status ReductionEmitter::EmitKernel(
+    const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
+    std::vector<llvm_ir::IrArray> outputs) {
+  const HloComputation* fused_computation =
+      fusion_.fused_instructions_computation();
+  FusedIrEmitter fused_emitter(elemental_emitter_);
+  for (int i = 0; i < fused_computation->num_parameters(); i++) {
+    HloInstruction* fused_operand = fused_computation->parameter_instruction(i);
+    fused_emitter.BindGenerator(
+        *fused_operand, [builder = builder_, input = inputs[i],
+                         fused_operand](const llvm_ir::IrArray::Index& index) {
+          return input.EmitReadArrayElement(index, builder,
+                                            fused_operand->name());
+        });
+  }
+
+  // Get outputs.
+  ReductionOutputMap result_ir_arrays;
+
+  int ir_arrays_idx = 0;
+  for (const HloInstructionAdaptor& root : analysis_.fusion_roots()) {
+    int get_num_results = GetNumOutputs(root.shape());
+    result_ir_arrays[&root.instruction()] =
+        absl::MakeSpan(outputs).subspan(ir_arrays_idx, get_num_results);
+    ir_arrays_idx += get_num_results;
+  }
+
+  KernelSupportLibrary ksl(builder_, llvm_ir::UnrollMode::kDefaultUnroll);
+
+  // Use the raw block_id_y to select the i-th parallel reduction to run.
+  // Using block_id_y instead of block_id_x simplifies the index calculation
+  // in the reduction code generation, since block_id_y is orthogonal to
+  // the indices used within the reductions.
+  const auto& instr_index_groups =
+      reduction_codegen_info_.GetGroups().grouped_roots;
+  Shape reduce_operand_shape = reduction_codegen_info_.GetReduceOperandShape();
+
+  llvm::Value* block_id_y = gpu::EmitCallToTargetIntrinsic(
+      gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder_);
+  llvm_ir::AddRangeMetadata(0, instr_index_groups.size(),
+                            llvm::cast<llvm::Instruction>(block_id_y),
+                            builder_->GetInsertBlock()->getModule());
+  block_id_y = builder_->CreateZExtOrTrunc(block_id_y, builder_->getInt32Ty());
+  block_id_y->setName("block.id.y");
+  for (int i = 0; i < instr_index_groups.size(); ++i) {
+    TF_RETURN_IF_ERROR(ksl.IfWithStatus(
+        absl::StrCat("reduce-group-", i),
+        builder_->CreateICmpEQ(block_id_y, builder_->getInt32(i)), [&] {
+          return EmitIRForReduction(instr_index_groups[i], fused_emitter,
+                                    result_ir_arrays, reduce_operand_shape);
+        }));
+  }
+
+  return absl::OkStatus();
+}
+
+}  // namespace
+
+absl::StatusOr<FusionEmissionResult> ReductionFusion::EmitInitializers(
+    IrEmitterContext& ir_emitter_context,
+    const HloFusionInstruction& fusion) const {
+  llvm::IRBuilder<> builder(ir_emitter_context.llvm_module()->getContext());
+  return ReductionEmitter(analysis_, reduction_info_, ir_emitter_context,
+                          fusion, &builder)
+      .EmitInitializers();
+}
+
+absl::Status ReductionFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
+                                         const HloFusionInstruction& fusion,
+                                         const LaunchDimensions& launch_dims,
+                                         std::vector<llvm_ir::IrArray> inputs,
+                                         std::vector<llvm_ir::IrArray> outputs,
+                                         llvm::IRBuilder<>* builder) const {
+  return ReductionEmitter(analysis_, reduction_info_, ir_emitter_context,
+                          fusion, builder)
+      .EmitKernel(launch_dims, inputs, outputs);
+}
+
+int ReductionInfo::GetRowsPerWarp() const {
+  if (!is_row_reduction_) return 1;
+  return RowReductionGetRowsPerWarp(
+      tiling_.GetShape()[ReductionDimensions::kRowMinorReducedDimension]);
+}
+
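+// Grid layout: block_id_x enumerates the tiles of a single reduction group,
+// while block_id_y selects which independent reduction group a block works on.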
+LaunchDimensions ReductionInfo::launch_dimensions() const {
+  size_t blocks_y = groups_.grouped_roots.size();
+  return {se::BlockDim(/*x=*/tiling_.GetNumBlocks(),
+                       /*y=*/static_cast<int64_t>(blocks_y), /*z=*/1),
+          se::ThreadDim(/*x=*/tiling_.GetNumThreadsPerBlock(),
+                        /*y=*/1, /*z=*/1)};
+}
+
+ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis) {
+  auto* hero_reduction = analysis.FindHeroReduction();
+  CHECK_NE(hero_reduction, nullptr);
+  Shape input_shape = hero_reduction->operand(0)->shape();
+  ReductionDimensions reduction_dimensions =
+      GetReductionKindAndContiguousComponents(*hero_reduction);
+  auto shape = reduction_dimensions.dimensions;
+  VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction
+           << " " << shape[0] << " " << shape[1] << " " << shape[2];
+  Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions);
+
+  int64_t num_threads_y =
+      reduction_dimensions.is_row_reduction ? 1 : WarpSize();
+  int64_t rows_per_warp =
+      reduction_dimensions.is_row_reduction
+          ? RowReductionGetRowsPerWarp(
+                shape[ReductionDimensions::kRowMinorReducedDimension])
+          : 1;
+  int64_t num_threads_x = [&] {
+    if (reduction_dimensions.is_row_reduction) {
+      if (rows_per_warp > 1) {
+        return shape[ReductionDimensions::kRowMinorReducedDimension];
+      }
+      int64_t max_block_size =
+          MinThreadsXRowReduction(hero_reduction->GetModule()->config());
+      return std::min(
+          max_block_size,
+          RoundUpTo(
+              CeilOfRatio(shape[ReductionDimensions::kRowMinorReducedDimension],
+                          reduction_tiling
+                              [ReductionDimensions::kRowMinorReducedDimension]),
+              WarpSize()));
+    }
+    return WarpSize();
+  }();
+
+  // If we're limited by the size of the x dimension, add additional parallelism
+  // in the y dimension. The code generator doesn't currently support
+  // parallelizing the z dimension (major reduced dimensions). The general
+  // recommendation is to use between 128 and 512 threads, so we just go for
+  // 256. See https://forums.developer.nvidia.com/t/55529
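+  // For example (illustrative): with num_threads_x == 32 and a kept dimension
+  // of size 5, num_threads_y becomes 5 (32 * 5 = 160 threads, already a
+  // multiple of 32); with num_threads_x == 64 and a large kept dimension,
+  // num_threads_y becomes 256 / 64 = 4.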
+  constexpr int64_t kThreadsPerBlockTarget = 256;
+  if (reduction_dimensions.is_row_reduction &&
+      num_threads_x * 2 <= kThreadsPerBlockTarget) {
+    int64_t kept_size =
+        reduction_dimensions.dimensions[ReductionDimensions::kRowKeptDimension];
+    // Increase the size of the y dimension as long as there's remaining
+    // parallelism.
+    if (kept_size * num_threads_x <= kThreadsPerBlockTarget) {
+      num_threads_y = kept_size;
+      // num_threads_x is a power of two, but it may be less than 32. If dim_y
+      // is also small, we may have to increase the bound so the total number of
+      // threads is a multiple of 32.
+      while ((num_threads_x * num_threads_y) % 32) ++num_threads_y;
+    } else {
+      num_threads_y = kThreadsPerBlockTarget / num_threads_x;
+    }
+  }
+
+  int vector_size = GetVectorSize(analysis, reduction_dimensions, num_threads_x,
+                                  reduction_tiling);
+
+  absl::InlinedVector<int64_t, 4> num_threads{1, num_threads_y, num_threads_x};
+  absl::InlinedVector<int64_t, 4> tiled_shape{shape[0], shape[1],
+                                              shape[2] / vector_size};
+  absl::InlinedVector<int64_t, 4> tile_per_thread{
+      reduction_tiling[0], reduction_tiling[1],
+      std::max<int64_t>(reduction_tiling[2] / vector_size, 1)};
+  if (rows_per_warp > 1) {
+    // If we produce more than one element per thread, that means the reduced
+    // dimension is small and it can't be tiled - we already have more threads
+    // in a warp than the size of the reduced dimension. The code generator
+    // doesn't currently support tiling the kept dimension, because it just
+    // uses the thread ID as the coordinate.
+    tile_per_thread[2] = 1;
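+    // (Illustrative: for rows of size 8, assuming RowReductionGetRowsPerWarp
+    // returns WarpSize() / 8 == 4 in that case, each warp handles four rows
+    // and the reduced dimension is not tiled any further.)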
+  }
+  if (vector_size != 1) {
+    num_threads.push_back(1);  // The vector dimension is a loop.
+    tiled_shape.push_back(vector_size);
+    tile_per_thread.push_back(vector_size);
+  }
+
+  Tiling tiling(tiled_shape, tile_per_thread, num_threads,
+                /*loops_to_unroll=*/{false, false, true, false});
+  bool reduction_is_race_free = ReductionIsRaceFree(
+      hero_reduction->GetModule()->config(), reduction_dimensions);
+  return ReductionInfo(analysis, tiling, reduction_dimensions.is_row_reduction,
+                       reduction_is_race_free,
+                       GroupDisjointReductions(analysis, /*for_mlir=*/false),
+                       hero_reduction);
+}
+
+std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(
+    int64_t root_index, mlir::MLIRContext* ctx) const {
+  if (!groups_.is_reduction_root[root_index]) {
+    auto map = ComposeIndexingMaps(
+        GetIndexingMapForTiling(tiling_, ctx),
+        GetBitcastMap(tiling_.GetXlaShape(),
+                      analysis_.fusion_root(root_index).shape(), ctx));
+    AddGroupIdConstraint(map, root_index, groups_);
+    return map;
+  }
+  const auto& hero = analysis_.fusion_hero(root_index).instruction();
+
+  auto block_offsets = GetBlockOffsetsForTiling(tiling_, ctx);
+  auto thread_ids = DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx),
+                                             tiling_.GetThreadsPerBlock());
+
+  auto physical_shape =
+      ShapeUtil::DeleteDimensions(hero.dimensions(), hero.operand(0)->shape());
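+  // Dimension convention of the indexing map: dims 0-2 are the thread ids
+  // (x, y, z) and dims 3-5 are the block ids (x, y, z); only thread_x, block_x
+  // and block_y are constrained here.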
+  std::vector<DimVar> dimension_ranges{
+      {{0, tiling_.GetNumThreadsPerBlock() - 1}},
+      {},
+      {},
+      {{0, tiling_.GetNumBlocks() - 1}},
+      {{0, static_cast<int64_t>(groups_.grouped_roots.size() - 1)}},
+      {},
+  };
+
+  constexpr int kRowKept = ReductionDimensions::kRowKeptDimension;
+  constexpr int kRowMinorReduced =
+      ReductionDimensions::kRowMinorReducedDimension;
+
+  constexpr int kColMajorKept = ReductionDimensions::kColMajorKeptDimension;
+  constexpr int kColMinorKept = ReductionDimensions::kColMinorKeptDimension;
+  constexpr int kColReduced = ReductionDimensions::kColReducedDimension;
+
+  auto map = [&]() {
+    if (is_row_reduction_) {
+      IndexingMap linear_index(
+          mlir::AffineMap::get(
+              6, 0, block_offsets.getResult(kRowKept) + thread_ids[kRowKept],
+              ctx),
+          dimension_ranges, /*range_vars=*/{}, /*rt_vars=*/{});
+      int rows_per_warp = GetRowsPerWarp();
+      if (rows_per_warp > 1) {
+        linear_index.AddConstraint(
+            thread_ids[kRowMinorReduced] % (WarpSize() / rows_per_warp),
+            {0, 0});
+      } else {
+        linear_index.AddConstraint(thread_ids[kRowMinorReduced], {0, 0});
+      }
+      return ComposeIndexingMaps(
+          linear_index, GetBitcastMap(ShapeUtil::MakeShape(
+                                          PRED, {tiling_.GetShape()[kRowKept]}),
+                                      physical_shape, ctx));
+    }
+
+    mlir::SmallVector<mlir::AffineExpr> projected_dims{
+        block_offsets.getResult(kColMajorKept),
+        block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]};
+    std::vector<RangeVar> range_vars;
+    if (thread_ids.size() == 4) {
+      int vector_size = tiling_.GetThreadTileSize().back();
+      range_vars.push_back({0, vector_size - 1});
+      projected_dims.push_back(mlir::getAffineSymbolExpr(0, ctx));
+    }
+    IndexingMap projected_index(
+        mlir::AffineMap::get(6, range_vars.size(), projected_dims, ctx),
+        dimension_ranges, range_vars, /*rt_vars=*/{});
+
+    projected_index.AddConstraint(
+        mlir::getAffineDimExpr(
+            KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx) %
+            WarpSize(),
+        {0, 0});
+    if (!is_row_reduction_) {
+      projected_index.AddConstraint(
+          projected_index.GetAffineMap().getResult(1),
+          {0, tiling_.GetShape()[ReductionDimensions::kColMinorKeptDimension] -
+                  1});
+    }
+
+    return ComposeIndexingMaps(
+        projected_index,
+        GetBitcastMap(ShapeUtil::DeleteDimension(
+                          ReductionDimensions::kColReducedDimension,
+                          tiling_.GetXlaShape()),
+                      physical_shape, ctx));
+  }();
+
+  AddGroupIdConstraint(map, root_index, groups_);
+  map.Simplify();
+  return map;
+}
+
+std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToInputIndexing(
+    int64_t root_index, int64_t hero_operand_index,
+    mlir::MLIRContext* ctx) const {
+  const auto& hero = analysis_.fusion_hero(root_index).instruction();
+  if (groups_.is_reduction_root[root_index] &&
+      hero_operand_index >= hero.operand_count() / 2) {
+    // We don't have indexing for the init values.
+    return std::nullopt;
+  }
+  if (!groups_.is_reduction_root[root_index]) {
+    return ComposeIndexingMaps(
+        *ComputeThreadIdToOutputIndexing(root_index, ctx),
+        *ComputeOutputToInputIndexing(
+             &analysis_.fusion_root(root_index).instruction(), 0, ctx)
+             .indexing_maps[hero_operand_index]
+             .begin());
+  }
+
+  auto map = ComposeIndexingMaps(
+      GetIndexingMapForTiling(tiling_, ctx),
+      GetBitcastMap(tiling_.GetXlaShape(),
+                    hero.operand(hero_operand_index)->shape(), ctx));
+  AddGroupIdConstraint(map, root_index, groups_);
+  map.Simplify();
+  return map;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h
new file mode 100644
index 0000000..131b4ec
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h
@@ -0,0 +1,190 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_REDUCTION_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_REDUCTION_H_
+
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/legacy/tiling_util.h"
+#include "xla/service/gpu/fusions/reduction_base.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/shape.h"
+
+namespace xla {
+namespace gpu {
+
+class ReductionInfo {
+ public:
+  static ReductionInfo Create(const HloFusionAnalysis& analysis);
+
+  const Tiling& GetTiling() const { return tiling_; }
+  const ReductionGroups& GetGroups() const { return groups_; }
+  Shape GetReduceOperandShape() const {
+    return first_reduce_->operand(0)->shape();
+  }
+
+  bool IsRowReduction() const { return is_row_reduction_; }
+  bool IsRaceFree() const { return is_race_free_; }
+  int GetRowsPerWarp() const;
+
+  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+      int64_t root_index, mlir::MLIRContext* ctx) const;
+
+  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+      int64_t root_index, int64_t hero_operand_index,
+      mlir::MLIRContext* ctx) const;
+
+  LaunchDimensions launch_dimensions() const;
+
+ private:
+  ReductionInfo(const HloFusionAnalysis& analysis, Tiling tiling,
+                bool is_row_reduction, bool is_race_free,
+                ReductionGroups groups, const HloInstruction* first_reduce)
+      : analysis_(analysis),
+        tiling_(tiling),
+        is_row_reduction_(is_row_reduction),
+        is_race_free_(is_race_free),
+        groups_(std::move(groups)),
+        first_reduce_(first_reduce) {}
+
+  const HloFusionAnalysis& analysis_;
+  Tiling tiling_;
+  bool is_row_reduction_;
+  bool is_race_free_;
+  ReductionGroups groups_;
+  const HloInstruction* first_reduce_;
+};
+
+// Generates code for reduction to contiguous dimensions.
+//
+// Row reduction uses the following algorithm described in CUDA-like
+// pseudocode:
+//
+// ```
+//  __global__ void reduce(int num_rows, float *in, float *out) {
+//    __shared__ float[32] cache;
+//    int offset = blockDim.x * blockIdx.x + threadIdx.x;
+//    if (offset >= num_rows) return;
+//    int tile_bound = std::min(offset + kTileSizeX, num_rows);
+//    float accum = 0;
+//    for (int i=offset; i<num_rows; i+= blockDim.x) {
+//      accum += in[i];
+//    }
+//    accum = warp_reduce(accum);
+//    if (threadIdx.x % WarpSize == 0) {
+//      cache[threadIdx.x / WarpSize] = accum;
+//    }
+//    __syncthreads();
+//    if (threadIdx.x / WarpSize == 0) {
+//      bool warp_exists = threadIdx.x < (blockDim.x / WarpSize);
+//      float block_accum = warp_exists ? cache[threadIdx.x % WarpSize] : 0;
+//      block_accum = warp_reduce(block_accum);
+//      if (threadIdx.x == 0) {
+//        *out += block_accum;
+//      }
+//    }
+//  }
+// ```
+//
+// Column reduction uses the following algorithm:
+//
+// ```
+// void reduce(float** in, float* out) {
+//   __shared__ float[32][33] cache;
+//   int thread_id = GetThreadId();
+//   int block_id = GetBlockId();
+//   int tile_size = 128;
+//
+//   float accum = 0;
+//   for (int i=0; i<tile_size; i++) {
+//     accum += in[thread_id.y * tile_size + i][block_id * 32 + thread_id.x];
+//   }
+//   cache[thread_id.x][thread_id.y] = accum;
+//
+//   __syncthreads();
+//   accum = cache[thread_id.y][thread_id.x];
+//   accum = warp_reduce(accum); // Sum all the values of `accum` in the same
+//                               // warp.
+//
+//   if (thread_id.y % 32 == 0) {
+//     out[block_id * 32 + thread_id.x] = accum;
+//   }
+// }
+// ```
+//
+// Moreover, a heuristic is implemented to divide the reduce instructions
+// into groups for parallelization (see `DivideOutputInstructionsIntoGroups`
+// for details about the heuristic). Reduce instructions in the same group
+// will run sequentially, while different groups will run in parallel.
+//
+// We use the raw block_id_y to select the reduce group to execute without
+// complicating the index calculation in the code generation of the reduce
+// instructions. In other words, each block_id_y is assigned to a group, so
+// different groups can run in parallel.
+class ReductionFusion : public KernelFusionEmitterBase {
+ public:
+  explicit ReductionFusion(const HloFusionAnalysis& analysis)
+      : analysis_(analysis), reduction_info_(ReductionInfo::Create(analysis)) {}
+
+  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+      int64_t root_index, mlir::MLIRContext* ctx) const override {
+    return reduction_info_.ComputeThreadIdToOutputIndexing(root_index, ctx);
+  }
+
+  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+      int64_t root_index, int64_t hero_operand_index,
+      mlir::MLIRContext* ctx) const override {
+    return reduction_info_.ComputeThreadIdToInputIndexing(
+        root_index, hero_operand_index, ctx);
+  }
+
+  LaunchDimensions launch_dimensions() const override {
+    return reduction_info_.launch_dimensions();
+  }
+
+  const ReductionInfo& reduction_info() const { return reduction_info_; }
+
+ protected:
+  absl::StatusOr<FusionEmissionResult> EmitInitializers(
+      IrEmitterContext& ir_emitter_context,
+      const HloFusionInstruction& fusion) const override;
+
+  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+                          const HloFusionInstruction& fusion,
+                          const LaunchDimensions& launch_dims,
+                          std::vector<llvm_ir::IrArray> inputs,
+                          std::vector<llvm_ir::IrArray> outputs,
+                          llvm::IRBuilder<>* builder) const override;
+
+ private:
+  const HloFusionAnalysis& analysis_;
+  ReductionInfo reduction_info_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_LEGACY_REDUCTION_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc
new file mode 100644
index 0000000..cc0faf4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc
@@ -0,0 +1,176 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/fusions/legacy/reduction.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::SizeIs;
+
+class ReductionTest : public HloTestBase {
+ protected:
+  stream_executor::DeviceDescription device_info_ =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  mlir::MLIRContext mlir_context_;
+};
+
+TEST_F(ReductionTest, ThreadIndexingRowReduction) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    add {
+      p0 = f32[] parameter(0)
+      p1 = f32[] parameter(1)
+      ROOT add = f32[] add(p0, p1)
+    }
+
+    fusion {
+      %input = f32[100,64,512] parameter(0)
+      %c0 = f32[] constant(0)
+      ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add
+    }
+
+    ENTRY entry {
+      %input = f32[100,64,512] parameter(0)
+      ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion
+    })")
+                    .value();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+  ReductionFusion fusion(analysis);
+
+  EXPECT_THAT(
+      fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(),
+      MatchIndexingString(R"(
+        (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> (
+          d3 floordiv 8,
+          (d3 mod 8) * 8 + d0 floordiv 32,
+          (d0 mod 32) * 2 + s2 * 64 + s3
+        )
+        domain:
+        d0 in [0, 255]
+        d1 in [0, 0]
+        d2 in [0, 0]
+        d3 in [0, 799]
+        d4 in [0, 0]
+        d5 in [0, 0]
+        s0 in [0, 0]
+        s1 in [0, 0]
+        s2 in [0, 7]
+        s3 in [0, 1]
+      )"));
+  EXPECT_THAT(
+      fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(),
+      MatchIndexingString(R"(
+        (d0, d1, d2, d3, d4, d5) -> (
+          d3 floordiv 8,
+          (d3 mod 8) * 8 + d0 floordiv 32
+        )
+        domain:
+        d0 in [0, 224]
+        d1 in [0, 0]
+        d2 in [0, 0]
+        d3 in [0, 799]
+        d4 in [0, 0]
+        d5 in [0, 0]
+        d0 mod 32 in [0, 0]
+      )"));
+}
+
+TEST_F(ReductionTest, TwoGroups) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    add {
+      p0 = f32[] parameter(0)
+      p1 = f32[] parameter(1)
+      ROOT add = f32[] add(p0, p1)
+    }
+    fusion {
+      %p0 = f32[2] parameter(0)
+      %p1 = f32[2] parameter(1)
+      %c0 = f32[] constant(-inf)
+      %r0 = f32[] reduce(%p0, %c0), dimensions={0}, to_apply=add
+      %c1 = f32[] constant(inf)
+      %r1 = f32[] reduce(%p1, %c1), dimensions={0}, to_apply=add
+      ROOT %tuple = (f32[], f32[]) tuple(%r0, %r1)
+    }
+    ENTRY entry {
+      %p0 = f32[2] parameter(0)
+      %p1 = f32[2] parameter(1)
+      ROOT %fusion = (f32[], f32[]) fusion(%p0, %p1), kind=kInput, calls=fusion
+    })")
+                    .value();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+  ReductionFusion fusion(analysis);
+
+  EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots,
+              ElementsAre(ElementsAre(&analysis.fusion_root(0).instruction()),
+                          ElementsAre(&analysis.fusion_root(1).instruction())));
+}
+
+TEST_F(ReductionTest, OneGroup) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    %add {
+      %p0 = c128[] parameter(0)
+      %p1 = c128[] parameter(1)
+      ROOT %add.35 = c128[] add(c128[] %p0, c128[] %p1)
+    }
+    %fusion {
+      %p0 = c128[1,2] parameter(0)
+      %c0 = c128[] constant((0, 0))
+      %reduce = c128[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add
+      %real = f64[] real(c128[] %reduce)
+      %imag = f64[] imag(c128[] %reduce)
+      %negate = f64[] negate(f64[] %imag)
+      ROOT %tuple.29 = (f64[], f64[]) tuple(f64[] %real, f64[] %negate)
+    }
+    ENTRY entry {
+      %p0 = c128[1,2] parameter(0)
+      ROOT %fusion = (f64[], f64[]) fusion(%p0), kind=kInput, calls=fusion
+    })")
+                    .value();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+  ReductionFusion fusion(analysis);
+
+  EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots, SizeIs(2));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc
new file mode 100644
index 0000000..0798788
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc
@@ -0,0 +1,294 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/scatter.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/fusions/legacy/loop.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/ir_emitter_nested.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/parallel_loop_emitter.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/service/llvm_ir/llvm_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+ScatterFusion::ScatterFusion(const HloFusionAnalysis& analysis)
+    : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {
+  CHECK_EQ(analysis.fusion_root_count(), 1);
+  CHECK_EQ(analysis.fusion_root(0).opcode(), HloOpcode::kScatter);
+}
+
+LaunchDimensions ScatterFusion::launch_dimensions() const {
+  const auto& updates_shape =
+      analysis_.fusion_root(0).instruction().operands().back()->shape();
+  return CalculateLaunchDimensions(updates_shape, analysis_.device_info());
+}
+
+absl::Status ScatterFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
+                                       const HloFusionInstruction& fusion,
+                                       const LaunchDimensions& launch_dims,
+                                       std::vector<llvm_ir::IrArray> inputs,
+                                       std::vector<llvm_ir::IrArray> outputs,
+                                       llvm::IRBuilder<>* builder) const {
+  GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
+  // Spin up a new fused emitter for the scatter kernel and emit it.
+  FusedIrEmitter scatter_fused_emitter(elemental_emitter);
+  auto* fused_computation = fusion.fused_instructions_computation();
+  for (int i = 0; i < fused_computation->num_parameters(); i++) {
+    auto fused_operand = fused_computation->parameter_instruction(i);
+    scatter_fused_emitter.BindGenerator(
+        *fused_operand, [builder, &input = inputs[i],
+                         fused_operand](llvm_ir::IrArray::Index index) {
+          return input.EmitReadArrayElement(index, builder,
+                                            fused_operand->name());
+        });
+  }
+
+  auto* root = fused_computation->root_instruction();
+  const xla::ScatterDimensionNumbers& scatter_dims =
+      Cast<HloScatterInstruction>(root)->scatter_dimension_numbers();
+
+  std::string name = llvm_ir::IrName(root);
+  const Shape& operand_shape = root->operand(0)->shape();
+  const Shape& scatter_indices_shape = root->operand(1)->shape();
+  const Shape& updates_shape = root->operand(2)->shape();
+  const HloComputation& update_computation = *root->called_computations()[0];
+
+  TF_ASSIGN_OR_RETURN(auto scatter_indices_gen,
+                      scatter_fused_emitter.GetGenerator(*root->operand(1)));
+  TF_ASSIGN_OR_RETURN(auto updates_gen,
+                      scatter_fused_emitter.GetGenerator(*root->operand(2)));
+
+  auto loop_body_emitter =
+      [&](const llvm_ir::IrArray::Index& index) -> absl::Status {
+    std::vector<llvm::Value*> raw_window_multidim;
+    std::vector<llvm::Value*> input_scatter_multidim;
+    std::vector<int64_t> raw_window_bounds;
+
+    auto get_i64_array = [](absl::Span<const int64_t> container) {
+      return llvm::ArrayRef<int64_t>{container.data(),
+                                     static_cast<size_t>(container.size())};
+    };
+
+    llvm::ArrayRef<int64_t> update_window_dims =
+        get_i64_array(scatter_dims.update_window_dims());
+    // Partition the index into window indices and scatter indices.
+    for (int64_t i = 0, e = index.size(); i != e; ++i) {
+      // For window indices, also remember the window size; this comes in
+      // handy later.
+      if (llvm::is_contained(update_window_dims, i)) {
+        raw_window_multidim.push_back(index[i]);
+        raw_window_bounds.push_back(updates_shape.dimensions(i));
+      } else {
+        input_scatter_multidim.push_back(index[i]);
+      }
+    }
+    DCHECK_EQ(raw_window_multidim.size(),
+              scatter_dims.update_window_dims_size());
+
+    // Apply inserted_window_dims to the window dimensions.
+    int64_t raw_window_multidim_idx = 0;
+    llvm::SmallVector<llvm::Value*> input_window_multidim;
+    llvm::SmallVector<int64_t> input_window_bounds;
+    const int64_t rank = operand_shape.rank();
+    input_window_bounds.reserve(rank);
+    input_window_multidim.reserve(rank);
+
+    llvm::ArrayRef<int64_t> inserted_window_dims =
+        get_i64_array(scatter_dims.inserted_window_dims());
+    for (int64_t i = 0; i != rank; ++i) {
+      if (llvm::is_contained(inserted_window_dims, i)) {
+        input_window_bounds.push_back(1);  // Trivial dimension.
+        input_window_multidim.push_back(index.GetConstantWithIndexType(0));
+      } else {
+        input_window_bounds.push_back(
+            raw_window_bounds[raw_window_multidim_idx]);
+        input_window_multidim.push_back(
+            raw_window_multidim[raw_window_multidim_idx]);
+        ++raw_window_multidim_idx;
+      }
+    }
+    DCHECK_EQ(input_window_multidim.size(), operand_shape.rank());
+
+    // Insert a 1 dimension at the end if index_vector_dim requests one.
+    Shape scatter_indices_shape_fixed = scatter_indices_shape;
+    if (scatter_dims.index_vector_dim() == scatter_indices_shape.rank()) {
+      scatter_indices_shape_fixed.add_dimensions(1);
+      scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
+          scatter_dims.index_vector_dim());
+    }
+
+    // Now load the indices corresponding to the current window from
+    // scatter_indices.
+    std::vector<llvm::Value*> raw_scatter_index_multidim =
+        input_scatter_multidim;
+    raw_scatter_index_multidim.insert(
+        raw_scatter_index_multidim.begin() + scatter_dims.index_vector_dim(),
+        nullptr);
+
+    llvm::ArrayRef<int64_t> scatter_dims_to_operand_dims =
+        get_i64_array(scatter_dims.scatter_dims_to_operand_dims());
+    llvm::Value* is_in_bounds = builder->getTrue();
+    for (int64_t i = 0, e = scatter_dims_to_operand_dims.size(); i != e; ++i) {
+      // Our index is stored along index_vector_dim; insert it into the lookup
+      // index into scatter_indices.
+      raw_scatter_index_multidim[scatter_dims.index_vector_dim()] =
+          index.GetConstantWithIndexType(i);
+      llvm_ir::IrArray::Index raw_scatter_index_index(
+          raw_scatter_index_multidim, scatter_indices_shape_fixed,
+          index.GetType());
+
+      int64_t operand_dim = scatter_dims_to_operand_dims[i];
+      if (operand_dim > rank) {
+        return absl::OutOfRangeError(
+            "The provided scatter_dims_to_operand_dims was out of range.");
+      }
+      TF_ASSIGN_OR_RETURN(
+          llvm::Value* const loaded_scatter_index,
+          scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
+              scatter_indices_shape_fixed, scatter_indices_shape, builder)));
+      // And add the index to our window index. This yields the output index.
+      llvm::Value* casted_scatter_index = builder->CreateIntCast(
+          loaded_scatter_index, index.GetType(),
+          /*isSigned=*/ShapeUtil::ElementIsSigned(scatter_indices_shape));
+      llvm::Value* dim_offset = builder->CreateAdd(
+          input_window_multidim[operand_dim], casted_scatter_index);
+      input_window_multidim[operand_dim] = dim_offset;
+
+      // Also do the bounds check now.
+      int64_t max_index = operand_shape.dimensions(operand_dim) -
+                          input_window_bounds[operand_dim] + 1;
+      // is_in_bounds = index >= 0 && index < dim_size-window_size+1
+      //   --> index u< dim_size-window_size+1
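+      // (For example, with a hypothetical dim_size of 10 and window size of
+      // 3, valid start indices are [0, 7]; a negative index reinterpreted as
+      // unsigned is huge and also fails the u< check.)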
+      is_in_bounds = builder->CreateAnd(
+          is_in_bounds,
+          builder->CreateICmpULT(casted_scatter_index,
+                                 index.GetConstantWithIndexType(max_index)));
+    }
+
+    llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
+        is_in_bounds, "scatter.in_bounds", builder, /*emit_else=*/false);
+    llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block,
+                                   builder);
+    // All done, now just read the update element for the current index and do
+    // an atomic store to the calculated location in the output.
+    llvm_ir::IrArray::Index input_window_index(
+        input_window_multidim, outputs.back().GetShape(), index.GetType());
+    llvm::Value* output_address =
+        outputs.back().EmitArrayElementAddress(input_window_index, builder);
+    llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
+        llvm_ir::PrimitiveTypeToIrType(updates_shape.element_type(),
+                                       ir_emitter_context.llvm_module()),
+        "input_address", builder);
+    TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
+    builder->CreateStore(input_ir_value, input_address);
+
+    if (root->unique_indices()) {
+      return CallNestedComputation(
+          builder, ir_emitter_context, update_computation,
+          {output_address, input_address}, output_address);
+    }
+    return EmitAtomicOperationForNestedComputation(
+        builder, ir_emitter_context, update_computation, output_address,
+        input_address, outputs.back().GetElementLlvmType());
+  };
+
+  // Launch a kernel that reads every element in the updates tensor. We could
+  // also do one kernel per window instead if bounds checks turn out to be a
+  // bottleneck.
+  auto index_type =
+      GetIndexTypeForKernel(root, launch_dims.launch_bound(), builder);
+  return ParallelLoopEmitter(loop_body_emitter, updates_shape, launch_dims,
+                             builder)
+      .EmitLoop(name, index_type);
+}
+
+std::optional<IndexingMap> ScatterFusion::ComputeThreadIdToInputIndexing(
+    int64_t root_index, int64_t hero_operand_index,
+    mlir::MLIRContext* ctx) const {
+  const auto* scatter =
+      DynCast<HloScatterInstruction>(&analysis_.fusion_hero(0).instruction());
+  int64_t scatter_operand_count = scatter->scatter_operand_count();
+  // Scatter operands are packed in the following way:
+  // Operand IDs [0, scatter_operand_count - 1] for `scatter operands`.
+  // Operand ID  scatter_operand_count for `scatter indices`.
+  // Operand IDs [scatter_operand_count + 1, 2 * scatter_operand_count] for
+  // `scatter updates`.
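+  //
+  // For example (hypothetical case), a variadic scatter with two scatter
+  // operands has operand IDs 0 and 1 for the operands, ID 2 for the indices,
+  // and IDs 3 and 4 for the updates.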
+
+  // For scatter operands we do not know the thread ID indexing.
+  if (hero_operand_index < scatter_operand_count) {
+    return std::nullopt;
+  }
+  // Compute thread id mapping based on the first update operand.
+  Shape scatter_update_shape = scatter->scatter_updates().front()->shape();
+  IndexingMap scatter_update_map = GetDefaultThreadIdIndexingMap(
+      launch_dimensions(), config_.unroll_factor, scatter_update_shape, ctx);
+
+  // For scatter indices we project indexing for scatter updates and take the
+  // first result of the affine map only, because they coincide.
+  if (hero_operand_index == scatter_operand_count) {
+    Shape scatter_indices_shape = scatter->scatter_indices()->shape();
+    CHECK_EQ(scatter_indices_shape.rank(), 2) << scatter->ToString();
+    // Create a map from scatter update to scatter indices.
+    IndexingMap updates_to_indices_map{
+        mlir::AffineMap::get(
+            /*dimCount=*/scatter_update_shape.rank(), /*symbolCount=*/1,
+            {mlir::getAffineDimExpr(0, ctx), mlir::getAffineSymbolExpr(0, ctx)},
+            ctx),
+        DimVarsFromTensorSizes(scatter_update_shape.dimensions()),
+        RangeVarsFromTensorSizes({scatter_indices_shape.dimensions(1)}),
+        /*rt_vars=*/{}};
+    auto scatter_indices_map = scatter_update_map * updates_to_indices_map;
+    scatter_indices_map.Simplify();
+    return scatter_indices_map;
+  }
+  return scatter_update_map;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter.h b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.h
new file mode 100644
index 0000000..862d0b3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.h
@@ -0,0 +1,71 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_SCATTER_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_SCATTER_H_
+
+#include <optional>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+
+namespace xla {
+namespace gpu {
+
+// A scatter, implemented as a loop over the updates. All scatters are in-place.
+class ScatterFusion : public KernelFusionEmitterBase {
+ public:
+  explicit ScatterFusion(const HloFusionAnalysis& analysis);
+
+  LaunchDimensions launch_dimensions() const override;
+
+  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+      int64_t root_index, mlir::MLIRContext* ctx) const override {
+    // The kernel iterates over updates, whose correspondence to output
+    // elements cannot be computed statically.
+    return std::nullopt;
+  }
+
+  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+      int64_t root_index, int64_t hero_operand_index,
+      mlir::MLIRContext* ctx) const override;
+
+ protected:
+  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+                          const HloFusionInstruction& fusion,
+                          const LaunchDimensions& launch_dims,
+                          std::vector<llvm_ir::IrArray> inputs,
+                          std::vector<llvm_ir::IrArray> outputs,
+                          llvm::IRBuilder<>* builder) const override;
+
+ private:
+  const HloFusionAnalysis& analysis_;
+  LaunchDimensionsConfig config_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_LEGACY_SCATTER_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc
new file mode 100644
index 0000000..71eea76
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc
@@ -0,0 +1,224 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/scatter.h"
+
+#include <optional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/affine_map_printer.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class ScatterFusionTest : public HloTestBase {
+ public:
+  void SetUp() override {
+    HloTestBase::SetUp();
+    printer_ =
+        AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
+                         {"chunk_id", "unroll_id", "index_id"});
+  }
+  DebugOptions GetDebugOptionsForTest() override {
+    auto opts = HloTestBase::GetDebugOptionsForTest();
+    opts.set_xla_gpu_mlir_emitter_level(0);
+    return opts;
+  }
+
+ protected:
+  AffineMapPrinter printer_;
+  mlir::MLIRContext mlir_context_;
+};
+
+TEST_F(ScatterFusionTest, ScatterFusion) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    add (lhs: f32[], rhs: f32[]) -> f32[] {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT sum = f32[] add(lhs, rhs)
+    }
+
+    fused_computation {
+      %input = f32[2,9] parameter(0)
+      %indices = s32[3] parameter(1)
+      %updates = f32[3,9] parameter(2)
+      ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
+          to_apply=add,
+          update_window_dims={1},
+          inserted_window_dims={0},
+          scatter_dims_to_operand_dims={0},
+          index_vector_dim=1
+    }
+
+    ENTRY entry {
+      %input = f32[2,9] parameter(0)
+      %indices = s32[3] parameter(1)
+      %updates = f32[3,9] parameter(2)
+      ROOT %fusion = f32[2,9] fusion(%input, %indices, %updates), kind=kLoop, calls=fused_computation
+    })")
+                    .value();
+
+  stream_executor::DeviceDescription device_info =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis_fused = HloFusionAnalysis::Create(*root, device_info);
+
+  auto emitter =
+      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
+  auto scatter_fusion = dynamic_cast<ScatterFusion*>(emitter.get());
+  ASSERT_NE(scatter_fusion, nullptr);
+  EXPECT_EQ(scatter_fusion->launch_dimensions().launch_bound(),
+            3 * 9 /* updates size */);
+}
+
+TEST_F(ScatterFusionTest, ThreadIdIndexing) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    computation {
+      %p0 = f32[] parameter(0)
+      %p1 = f32[] parameter(1)
+      %p2 = f32[] parameter(2)
+      %p3 = f32[] parameter(3)
+      ROOT %tuple = (f32[], f32[]) tuple(f32[] %p2, f32[] %p3)
+    }
+    scatter {
+      %operand0 = f32[300,200] parameter(0)
+      %operand1 = f32[300,200] parameter(1)
+      %indices = s32[42,1] parameter(2)
+      %update.1 = f32[42,10,20] parameter(3)
+      %update.2 = f32[42,10,20] parameter(4)
+
+      ROOT %scatter = (f32[300,200], f32[300,200]) scatter(
+          f32[300,200] %operand0,
+          f32[300,200] %operand1,
+          s32[42,1] %indices,
+          f32[42,10,20] %update.1,
+          f32[42,10,20] %update.2
+        ),
+        update_window_dims={1,2},
+        inserted_window_dims={},
+        scatter_dims_to_operand_dims={0},
+        index_vector_dim=1,
+        to_apply=computation
+    }
+    ENTRY entry {
+      %operand0 = f32[300,200] parameter(0)
+      %operand1 = f32[300,200] parameter(1)
+      %indices = s32[42,1] parameter(2)
+      %update.1 = f32[42,10,20] parameter(3)
+      %update.2 = f32[42,10,20] parameter(4)
+      ROOT %fusion = (f32[300,200], f32[300,200]) fusion(
+        %operand0, %operand1, %indices, %update.1, %update.2),
+        kind=kLoop, calls=scatter
+    }
+  )"));
+  stream_executor::DeviceDescription device_info =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis_fused = HloFusionAnalysis::Create(*root, device_info);
+
+  auto emitter =
+      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
+  auto fusion = dynamic_cast<ScatterFusion*>(emitter.get());
+  ASSERT_NE(fusion, nullptr);
+
+  constexpr auto kUpdatesIndexing = R"(
+    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
+      (bl_x * 128 + th_x) floordiv 200,
+      ((bl_x * 128 + th_x) floordiv 20) mod 10,
+      (bl_x * 128 + th_x) mod 20
+    )
+    domain:
+    th_x in [0, 127]
+    th_y in [0, 0]
+    th_z in [0, 0]
+    bl_x in [0, 65]
+    bl_y in [0, 0]
+    bl_z in [0, 0]
+    chunk_id in [0, 0]
+    unroll_id in [0, 0]
+    bl_x * 128 + th_x in [0, 8399]
+  )";
+  EXPECT_THAT(
+      fusion
+          ->ComputeThreadIdToInputIndexing(
+              /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_)
+          ->ToString(printer_),
+      MatchIndexingString(kUpdatesIndexing));
+  EXPECT_THAT(
+      fusion
+          ->ComputeThreadIdToInputIndexing(
+              /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_)
+          ->ToString(printer_),
+      MatchIndexingString(kUpdatesIndexing));
+  EXPECT_THAT(
+      fusion
+          ->ComputeThreadIdToInputIndexing(
+              /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_)
+          ->ToString(printer_),
+      MatchIndexingString(kUpdatesIndexing));
+  EXPECT_THAT(
+      fusion
+          ->ComputeThreadIdToInputIndexing(
+              /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_)
+          ->ToString(printer_),
+      MatchIndexingString(kUpdatesIndexing));
+
+  constexpr auto kIndicesIndexing = R"(
+    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] ->
+      ((bl_x * 128 + th_x) floordiv 200, 0)
+    domain:
+    th_x in [0, 127]
+    th_y in [0, 0]
+    th_z in [0, 0]
+    bl_x in [0, 65]
+    bl_y in [0, 0]
+    bl_z in [0, 0]
+    chunk_id in [0, 0]
+    unroll_id in [0, 0]
+    index_id in [0, 0]
+    bl_x * 128 + th_x in [0, 8399]
+  )";
+  EXPECT_THAT(
+      fusion
+          ->ComputeThreadIdToInputIndexing(
+              /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_)
+          ->ToString(printer_),
+      MatchIndexingString(kIndicesIndexing));
+  EXPECT_THAT(
+      fusion
+          ->ComputeThreadIdToInputIndexing(
+              /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_)
+          ->ToString(printer_),
+      MatchIndexingString(kIndicesIndexing));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc
new file mode 100644
index 0000000..a1a7acb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc
@@ -0,0 +1,351 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/fusions/legacy/tiling_util.h"
+
+#include <cstdint>
+#include <limits>
+#include <string>
+#include <vector>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/service/gpu/target_util.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/service/llvm_ir/kernel_support_library.h"
+#include "xla/service/llvm_ir/llvm_loop.h"
+#include "xla/service/llvm_ir/llvm_util.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using mlir::AffineExpr;
+using mlir::AffineMap;
+using mlir::MLIRContext;
+
+void EmitTileRec(const TilingThreadIdInfo& thread_id_info, const Tiling& tiling,
+                 int dim, absl::InlinedVector<llvm::Value*, 4> tile_idx,
+                 absl::Span<llvm::Value* const> tile_dimensions,
+                 llvm::IRBuilder<>* b, const TileElementGenerator& emit_elem) {
+  llvm::Type* index_ty = thread_id_info.thread_id->getType();
+  auto constant = [&](int64_t val) {
+    return llvm::ConstantInt::get(index_ty, val);
+  };
+
+  auto recurse = [&] {
+    if (dim == tile_idx.size() - 1) {
+      emit_elem(tile_idx);
+    } else {
+      EmitTileRec(thread_id_info, tiling, dim + 1, tile_idx, tile_dimensions, b,
+                  emit_elem);
+    }
+  };
+
+  bool unroll = tiling.GetLoopsToUnroll()[dim];
+  KernelSupportLibrary ksl(b, unroll ? llvm_ir::UnrollMode::kFullyUnroll
+                                     : llvm_ir::UnrollMode::kDefaultUnroll);
+
+  if (tiling.GetBlockTileSize()[dim] == 1) {
+    tile_idx[dim] = constant(0);
+    recurse();
+  } else if (unroll) {
+    // TODO(jreiffers): Check if this unrolling does anything useful.
+    int64_t stride = tiling.GetThreadsPerBlock()[dim];
+    int64_t dim_size = tiling.GetThreadTileSize()[dim];
+
+    auto make_loop = [&](bool emit_bounds_checks) {
+      auto body = [&, emit_bounds_checks](llvm::Value* i) {
+        tile_idx[dim] = b->CreateAdd(i, thread_id_info.thread_ids[dim]);
+        if (emit_bounds_checks) {
+          auto* in_bounds =
+              b->CreateICmpULT(tile_idx[dim], tile_dimensions[dim]);
+          ksl.If("x_in_tile", in_bounds, recurse);
+        } else {
+          recurse();
+        }
+      };
+      return [&, body] {
+        ksl.For(absl::StrCat("loop", dim), constant(0),
+                constant(dim_size * stride), constant(stride), body);
+      };
+    };
+    if (stride > 1 && dim_size > 1) {
+      // Most tiles will be full, so we emit a single bounds check for those.
+      auto* is_full_tile = b->CreateICmpEQ(
+          constant(tiling.GetBlockTileSize()[dim]), tile_dimensions[dim]);
+      ksl.If("is_full_tile", is_full_tile, make_loop(false), make_loop(true));
+    } else {
+      make_loop(true)();
+    }
+  } else {
+    // All dimensions are strided (thread 0 processes elements 0, num_threads,
+    // 2 * num_threads, ...; thread 1 processes elements 1, num_threads + 1,
+    // 2 * num_threads + 1, and so on).
+    ksl.For(absl::StrCat("loop", dim), /*start=*/thread_id_info.thread_ids[dim],
+            /*end=*/tile_dimensions[dim],
+            /*step=*/tiling.GetThreadsPerBlock()[dim], [&](llvm::Value* i) {
+              tile_idx[dim] = i;
+              recurse();
+            });
+  }
+}
+
+}  // namespace
+
+void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling,
+              const TilingThreadIdInfo& thread_id_info,
+              absl::Span<llvm::Value* const> tile_dimensions,
+              const TileElementGenerator& emit_elem_function) {
+  absl::InlinedVector<llvm::Value*, 4> tile_idx(tiling.GetShape().size());
+  EmitTileRec(thread_id_info, tiling, 0, tile_idx, tile_dimensions, builder,
+              emit_elem_function);
+}
+
+namespace {
+
+// Emits current block id.
+llvm::Value* EmitBlockId(llvm::IRBuilder<>* builder, int32_t num_blocks,
+                         llvm::Type* index_ty) {
+  llvm::Value* block_id =
+      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, builder);
+  if (num_blocks != 0) {
+    llvm_ir::AddRangeMetadata(0, num_blocks,
+                              llvm::cast<llvm::Instruction>(block_id),
+                              builder->GetInsertBlock()->getModule());
+  }
+  auto ret = builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true);
+  ret->setName("block.id.x");
+  return ret;
+}
+
+// Emits current thread id with the given type.
+//
+// Sets the return value range to [0, threads_per_block).
+llvm::Value* EmitThreadId(llvm::IRBuilder<>* builder, int64_t threads_per_block,
+                          llvm::Type* index_ty) {
+  // Calculate (y, x) coordinates respectively in the 2D view of thread block,
+  // defined by (num_thread_y, num_thread_x) from thread_id.
+  llvm::CallInst* thread_id =
+      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, builder);
+  llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id,
+                            builder->GetInsertBlock()->getModule());
+  auto ret = builder->CreateIntCast(thread_id, index_ty, /*isSigned=*/true);
+  ret->setName("thread.id.x");
+  return ret;
+}
+
+// Emits the LLVM values for thread_id, block_id, coordinates of the current
+// tile and strides of the loops to iterate over the current tile.
+absl::StatusOr<TilingThreadIdInfo> EmitThreadIdInfo(llvm::IRBuilder<>* builder,
+                                                    const Tiling& tiling,
+                                                    llvm::Type* index_ty) {
+  auto constant = [&](uint64_t c) -> llvm::Constant* {
+    return llvm::ConstantInt::get(index_ty, c);
+  };
+  int64_t num_blocks = tiling.GetNumBlocks();
+  if (num_blocks > (int64_t)std::numeric_limits<uint32_t>::max()) {
+    return FailedPrecondition(
+        "Number of physical blocks (%d) does not fit in an i32 in tiling "
+        "scheme: %s",
+        num_blocks, tiling.ToString());
+  }
+
+  TilingThreadIdInfo info;
+  info.thread_id =
+      EmitThreadId(builder, tiling.GetNumThreadsPerBlock(), index_ty);
+  info.block_id = EmitBlockId(builder, num_blocks, index_ty);
+
+  for (auto [dim, stride] : llvm::enumerate(tiling.GetThreadStrides())) {
+    int64_t size = tiling.GetThreadsPerBlock()[dim];
+    if (size == 1) {
+      info.thread_ids.emplace_back(constant(0));
+    } else {
+      auto& dim_id = info.thread_ids.emplace_back(info.thread_id);
+      if (stride > 1) {
+        dim_id = builder->CreateUDiv(dim_id, constant(stride));
+      }
+      if (dim) {
+        dim_id = builder->CreateURem(dim_id, constant(size));
+      }
+      dim_id->setName(absl::StrCat("thread.id.", dim));
+    }
+  }
+
+  info.lane_id =
+      builder->CreateURem(info.thread_id, constant(WarpSize()), "lane_id");
+  return info;
+}
+
+AffineMap GetTilingAffineMap(llvm::ArrayRef<AffineExpr> exprs,
+                             int64_t num_symbols) {
+  return AffineMap::get(
+      /*dimCount=*/6, /*symbolCount=*/num_symbols, exprs,
+      exprs[0].getContext());
+}
+
+}  // namespace
+
+absl::StatusOr<TilingKernelInfo> EmitTilingKernel(
+    llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty,
+    const TileGenerator& tile_element_generator) {
+  absl::Span<const int64_t> dims_in_elems = tiling.GetShape();
+  const auto& block_counts = tiling.GetBlockCounts();
+  auto constant = [&](uint64_t c) -> llvm::Constant* {
+    return llvm::ConstantInt::get(index_ty, c);
+  };
+
+  TF_ASSIGN_OR_RETURN(TilingThreadIdInfo thread_id_info,
+                      EmitThreadIdInfo(builder, tiling, index_ty));
+
+  KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll);
+
+  const llvm_ir::IrArray::Index block_coords(
+      thread_id_info.block_id,
+      ShapeUtil::MakeShape(PRED /*arbitrary*/, block_counts), builder);
+
+  absl::InlinedVector<llvm::Value*, 4> tile_dimensions;
+  for (int i = 0; i < block_counts.size(); ++i) {
+    int64_t block_tile_size = tiling.GetBlockTileSize()[i];
+    if (dims_in_elems[i] % block_tile_size == 0) {
+      // The block tile size evenly divides the tiled shape -> no need to emit
+      // the bounds check.
+      tile_dimensions.push_back(constant(block_tile_size));
+    } else {
+      // Only the last tile in each dimension may not have full size.
+      llvm::Value* is_last =
+          builder->CreateICmpEQ(block_coords[i], constant(block_counts[i] - 1));
+      int64_t partial_row =
+          dims_in_elems[i] - (block_counts[i] - 1) * block_tile_size;
+      tile_dimensions.push_back(builder->CreateSelect(
+          is_last, constant(partial_row), constant(block_tile_size),
+          absl::StrCat("tile_bound.", i)));
+    }
+  }
+
+  llvm_ir::IrArray::Index tile_offset = [&] {
+    std::vector<llvm::Value*> elem_multi_index = block_coords.multidim();
+    llvm::Type* index_ty = block_coords.GetType();
+    for (int i = 0; i < block_counts.size(); ++i) {
+      elem_multi_index[i] = builder->CreateMul(
+          block_coords[i],
+          llvm::ConstantInt::get(index_ty, tiling.GetBlockTileSize()[i]),
+          absl::StrCat("tile_origin.", i));
+    }
+    return llvm_ir::IrArray::Index(elem_multi_index, tiling.GetShape(),
+                                   index_ty);
+  }();
+
+  tile_element_generator(thread_id_info, tile_offset, tile_dimensions);
+  return {{tile_dimensions, tile_offset, thread_id_info}};
+}
+
+AffineMap GetBlockOffsetsForTiling(
+    absl::Span<const int64_t> num_blocks,
+    absl::Span<const int64_t> tile_sizes_per_block, int64_t rank,
+    MLIRContext* mlir_context) {
+  auto offsets =
+      DelinearizeInBoundsIndex(getAffineDimExpr(3, mlir_context), num_blocks);
+  for (auto&& [offset, tile_size] : llvm::zip(offsets, tile_sizes_per_block)) {
+    offset = offset * tile_size;
+  }
+  return GetTilingAffineMap(offsets, rank);
+}
+
+AffineMap GetBlockOffsetsForTiling(const Tiling& tiling,
+                                   MLIRContext* mlir_context) {
+  return GetBlockOffsetsForTiling(tiling.GetBlockCounts(),
+                                  tiling.GetBlockTileSize(),
+                                  tiling.GetShape().size(), mlir_context);
+}
+
+AffineMap GetThreadOffsetsForTiling(
+    absl::Span<const int64_t> num_threads,
+    absl::Span<const int64_t> tile_sizes_per_thread, int64_t rank,
+    MLIRContext* mlir_context) {
+  auto offsets =
+      DelinearizeInBoundsIndex(getAffineDimExpr(0, mlir_context), num_threads);
+  for (int dim = 0; dim < rank; ++dim) {
+    if (tile_sizes_per_thread[dim] > 1) {
+      offsets[dim] = offsets[dim] +
+                     getAffineSymbolExpr(dim, mlir_context) * num_threads[dim];
+    }
+  }
+  return GetTilingAffineMap(offsets, rank);
+}
+
+AffineMap GetThreadOffsetsForTiling(const Tiling& tiling,
+                                    MLIRContext* mlir_context) {
+  return GetThreadOffsetsForTiling(tiling.GetThreadsPerBlock(),
+                                   tiling.GetThreadTileSize(),
+                                   tiling.GetShape().size(), mlir_context);
+}
+
+IndexingMap GetIndexingMapForTiling(const Tiling& tiling,
+                                    MLIRContext* mlir_context) {
+  return GetIndexingMapForTiling(
+      GetBlockOffsetsForTiling(tiling, mlir_context),
+      GetThreadOffsetsForTiling(tiling, mlir_context),
+      tiling.GetNumThreadsPerBlock(), tiling.GetNumBlocks(),
+      tiling.GetThreadTileSize(), tiling.GetShape());
+}
+
+IndexingMap GetIndexingMapForTiling(AffineMap block_offsets,
+                                    AffineMap thread_offsets,
+                                    int64_t threads_per_block,
+                                    int64_t num_blocks,
+                                    absl::Span<const int64_t> thread_tile_sizes,
+                                    absl::Span<const int64_t> tiled_shape) {
+  auto* mlir_context = block_offsets.getContext();
+  llvm::SmallVector<AffineExpr, 4> offsets;
+  offsets.reserve(block_offsets.getNumResults());
+  for (auto [block, thread] :
+       llvm::zip(block_offsets.getResults(), thread_offsets.getResults())) {
+    offsets.push_back(block + thread);
+  }
+  std::vector<DimVar> dimension_ranges{
+      {{0, threads_per_block - 1}}, {}, {}, {{0, num_blocks - 1}}, {}, {},
+  };
+  auto affine_map = mlir::AffineMap::get(block_offsets.getNumDims(),
+                                         block_offsets.getNumSymbols(), offsets,
+                                         mlir_context);
+  IndexingMap map{affine_map, dimension_ranges,
+                  RangeVarsFromTensorSizes(thread_tile_sizes), /*rt_vars=*/{}};
+  for (int i = 0; i < tiled_shape.size(); ++i) {
+    map.AddConstraint(affine_map.getResult(i), {0, tiled_shape[i] - 1});
+  }
+  return map;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h
new file mode 100644
index 0000000..de367e3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h
@@ -0,0 +1,215 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_TILING_UTIL_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_TILING_UTIL_H_
+
+#include <cstdint>
+#include <functional>
+#include <string>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+
+// Describes tiling used by the kernel.
+//
+// Used by reduction and transpose emitters.
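+//
+// For example, with hypothetical values (not taken from any particular
+// caller):
+//
+//   Tiling tiling(/*shape=*/{8, 1024}, /*tile_sizes=*/{1, 4},
+//                 /*num_threads=*/{1, 256});
+//
+// the block tile size is {1, 1024}, the block counts are {8, 1}, and the
+// kernel uses 8 blocks of 256 threads each.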
+class Tiling {
+ public:
+  Tiling(absl::Span<const int64_t> shape, absl::Span<const int64_t> tile_sizes,
+         absl::Span<const int64_t> num_threads,
+         // By default, don't unroll anything.
+         absl::InlinedVector<bool, 4> loops_to_unroll = {})
+      : shape_{shape.begin(), shape.end()},
+        tile_sizes_per_thread_{tile_sizes.begin(), tile_sizes.end()},
+        tile_sizes_per_block_(shape.size()),
+        num_threads_{num_threads.begin(), num_threads.end()},
+        num_blocks_(shape.size()),
+        loops_to_unroll_(loops_to_unroll) {
+    for (int64_t i = 0; i < shape.size(); ++i) {
+      tile_sizes_per_block_[i] = tile_sizes[i] * num_threads[i];
+      CHECK_NE(tile_sizes_per_block_[i], 0);
+      num_blocks_[i] = CeilOfRatio(shape[i], tile_sizes_per_block_[i]);
+      CHECK_NE(num_blocks_[i], 0);
+    }
+    if (loops_to_unroll_.empty()) loops_to_unroll_.resize(shape.size());
+  }
+  Tiling() = default;
+
+  std::string ToString() const {
+    return absl::StrJoin(
+        {absl::StrFormat("shape = {%s}", absl::StrJoin(shape_, ", ")),
+         absl::StrFormat("tile_sizes = {%s}",
+                         absl::StrJoin(tile_sizes_per_thread_, ", ")),
+         absl::StrFormat("num_threads = {%s}",
+                         absl::StrJoin(num_threads_, ", "))},
+        ", ");
+  }
+
+  // Number of elements in each dimension.
+  const absl::InlinedVector<int64_t, 4>& GetShape() const { return shape_; }
+  xla::Shape GetXlaShape(PrimitiveType element_type = F32) const {
+    return ShapeUtil::MakeShape(element_type, shape_);
+  }
+
+  const absl::InlinedVector<int64_t, 4>& GetBlockCounts() const {
+    return num_blocks_;
+  }
+
+  // Tile size for each thread.
+  //
+  // Equals the number of iterations in the loop each thread will make.
+  const absl::InlinedVector<int64_t, 4>& GetThreadTileSize() const {
+    return tile_sizes_per_thread_;
+  }
+
+  // Tile size for an entire thread block.
+  const absl::InlinedVector<int64_t, 4>& GetBlockTileSize() const {
+    return tile_sizes_per_block_;
+  }
+
+  const absl::InlinedVector<int64_t, 4>& GetThreadsPerBlock() const {
+    return num_threads_;
+  }
+
+  // Returns the strides of the thread index dimensions wrt. the linear thread
+  // id.
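+  //
+  // For example, with a hypothetical num_threads of {2, 4, 32}, the strides
+  // are {128, 32, 1}: the last thread dimension varies fastest.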
+  absl::InlinedVector<int64_t, 4> GetThreadStrides() const {
+    return *ShapeUtil::ByteStrides(ShapeUtil::MakeShape(U8, num_threads_));
+  }
+
+  int64_t GetNumThreadsPerBlock() const { return Product(num_threads_); }
+
+  int64_t GetNumBlocks() const { return Product(num_blocks_); }
+
+  const absl::InlinedVector<bool, 4>& GetLoopsToUnroll() const {
+    return loops_to_unroll_;
+  }
+
+ private:
+  // The number of elements in each dimension.
+  absl::InlinedVector<int64_t, 4> shape_;
+
+  // The number of elements for each dimension of a tile.
+  absl::InlinedVector<int64_t, 4> tile_sizes_per_thread_;
+  absl::InlinedVector<int64_t, 4> tile_sizes_per_block_;
+
+  absl::InlinedVector<int64_t, 4> num_threads_;
+  absl::InlinedVector<int64_t, 4> num_blocks_;
+
+  absl::InlinedVector<bool, 4> loops_to_unroll_;
+};
+
+struct TilingThreadIdInfo {
+  llvm::Value* thread_id;
+
+  absl::InlinedVector<llvm::Value*, 4> thread_ids;
+
+  // Lane id: `thread_id % WarpSize`
+  llvm::Value* lane_id;
+
+  // Block id.
+  llvm::Value* block_id;
+};
+
+struct TilingKernelInfo {
+  // Tiling bounds.
+  absl::InlinedVector<llvm::Value*, 4> output_tile_bounds;
+
+  // Starting tile, as calculated from block id only.
+  llvm_ir::IrArray::Index tile_origin;
+
+  // Thread meta-info.
+  TilingThreadIdInfo thread_id_info;
+};
+
+// A function to generate the code to emit the entire tile.
+//
+// tile_start_index: Absolute coordinates of the start of the tile in the
+//                   input.
+// tile_dimensions: Size of the tile.
+using TileGenerator =
+    std::function<void(const TilingThreadIdInfo& thread_id_info,
+                       const llvm_ir::IrArray::Index& tile_start_index,
+                       absl::Span<llvm::Value* const> tile_dimensions)>;
+
+// A function object to generate code to process one element in a tile.
+//
+// index_in_tile: the current coordinates within the tile. To get the global
+// coordinates, use `tile_start_index.AddOffset(index_in_tile, ...)`.
+using TileElementGenerator =
+    std::function<void(absl::Span<llvm::Value* const> index_in_tile)>;
+
+// Emits code to iterate through a tile with given tile dimensions and generate
+// elements using the callback.
+void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling,
+              const TilingThreadIdInfo& thread_id_info,
+              absl::Span<llvm::Value* const> tile_dimensions,
+              const TileElementGenerator& emit_elem_function);
+
+// Emits a kernel for the hlo instruction using the given kernel mapping
+// scheme.
+absl::StatusOr<TilingKernelInfo> EmitTilingKernel(
+    llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty,
+    const TileGenerator& tile_element_generator);
+
+// Creates an indexing map from thread and block IDs to elements of the tiled
+// shape. Uses the same convention as KernelFusionInterface: dimensions 0 to 2
+// are thread indices (currently only 0 is used), dimensions 3 to 5 are block
+// indices (currently only 3 is used).
+mlir::AffineMap GetBlockOffsetsForTiling(
+    absl::Span<const int64_t> num_blocks,
+    absl::Span<const int64_t> tile_sizes_per_block, int64_t rank,
+    mlir::MLIRContext* mlir_context);
+mlir::AffineMap GetBlockOffsetsForTiling(const Tiling& tiling,
+                                         mlir::MLIRContext* mlir_context);
+mlir::AffineMap GetThreadOffsetsForTiling(
+    absl::Span<const int64_t> num_threads,
+    absl::Span<const int64_t> tile_sizes_per_thread, int64_t rank,
+    mlir::MLIRContext* mlir_context);
+mlir::AffineMap GetThreadOffsetsForTiling(const Tiling& tiling,
+                                          mlir::MLIRContext* mlir_context);
+
+// Convenience functions for the two functions above
+// (`GetBlockOffsetsForTiling` + `GetThreadOffsetsForTiling`). Also sets up
+// the ranges of dimensions and symbols.
+IndexingMap GetIndexingMapForTiling(const Tiling& tiling,
+                                    mlir::MLIRContext* mlir_context);
+IndexingMap GetIndexingMapForTiling(mlir::AffineMap block_offsets,
+                                    mlir::AffineMap thread_offsets,
+                                    int64_t threads_per_block,
+                                    int64_t num_blocks,
+                                    absl::Span<const int64_t> thread_tile_sizes,
+                                    absl::Span<const int64_t> tiled_shape);
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_LEGACY_TILING_UTIL_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc
new file mode 100644
index 0000000..d6cbdec
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc
@@ -0,0 +1,366 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/transpose.h"
+
+#include <array>
+#include <cstdint>
+#include <optional>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "mlir/IR/AffineMap.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/permutation_util.h"
+#include "xla/service/gpu/elemental_ir_emitter.h"
+#include "xla/service/gpu/fusions/legacy/tiling_util.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/target_util.h"
+#include "xla/service/llvm_ir/fused_ir_emitter.h"
+#include "xla/service/llvm_ir/ir_array.h"
+#include "xla/service/llvm_ir/llvm_util.h"
+#include "xla/service/llvm_ir/loop_emitter.h"
+#include "xla/shape_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+Tiling ComputeTransposeTiling(const se::DeviceDescription& gpu_device_info,
+                              const TransposeDescription& tiled_transpose) {
+  constexpr int kNumRows = 4;
+  static_assert(WarpSize() % kNumRows == 0);
+
+  // 3D view over the output shape.
+  absl::InlinedVector<int64_t, 3> transposed_dims = tiled_transpose.dimensions;
+  absl::InlinedVector<int64_t, 3> permutation = tiled_transpose.permutation;
+
+  // Note: the supported permutations are their own inverses. Therefore we
+  // always use the permutation, even when we want the inverse.
+  CHECK((permutation == absl::InlinedVector<int64_t, 3>{0, 2, 1}) ||
+        (permutation == absl::InlinedVector<int64_t, 3>{2, 1, 0}));
+
+  absl::InlinedVector<int64_t, 4> input_dims{transposed_dims[permutation[0]],
+                                             transposed_dims[permutation[1]],
+                                             transposed_dims[permutation[2]]};
+
+  // We tile along the minor dimensions pre- and post-transpose.
+  absl::InlinedVector<int64_t, 4> tile_sizes{1, 1, 1};
+  tile_sizes[permutation[2]] = WarpSize() / kNumRows;
+  absl::InlinedVector<int64_t, 4> num_threads{1, 1, WarpSize()};
+  num_threads[permutation[2]] = kNumRows;
+
+  auto capability = gpu_device_info.gpu_compute_capability();
+  std::visit(
+      [&](const auto& capability) {
+        if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
+                                     stream_executor::RocmComputeCapability>) {
+          // kNumRows = 8 works well on MI300 with wavefront size 64.
+          if (capability.gfx9_mi300()) {
+            tile_sizes[permutation[2]] = gpu_device_info.threads_per_warp() / 8;
+            num_threads[permutation[2]] = 8;
+          }
+        }
+      },
+      capability);
+
+  return Tiling(input_dims, tile_sizes, num_threads);
+}
+
+void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder,
+                             IrEmitterContext& ir_emitter_context) {
+  auto* module = builder->GetInsertBlock()->getModule();
+  if (IsAMDGPU(module) &&
+      ir_emitter_context.rocm_compute_capability().fence_before_barrier()) {
+    builder->CreateFence(
+        llvm::AtomicOrdering::SequentiallyConsistent,
+        builder->getContext().getOrInsertSyncScopeID("workgroup"));
+  }
+}
+
+void EmitSyncThreads(llvm::IRBuilder<>* builder,
+                     IrEmitterContext& ir_emitter_context) {
+  MaybeEmitFenceForAMDGPU(builder, ir_emitter_context);
+  EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder);
+}
+
+llvm_ir::IrArray::Index PermuteIndex(const llvm_ir::IrArray::Index& index,
+                                     absl::Span<const int64_t> permutation) {
+  return llvm_ir::IrArray::Index{Permute(index.multidim(), permutation),
+                                 Permute(index.dims(), permutation),
+                                 index.GetType()};
+}
+
+}  // namespace
+
+TransposeFusion::TransposeFusion(const se::DeviceDescription& gpu_device_info,
+                                 const HloFusionAnalysis& analysis)
+    : analysis_(analysis),
+      tiling_(
+          ComputeTransposeTiling(gpu_device_info, analysis.tiled_transpose())) {
+  for (auto [root, hero] :
+       llvm::zip(analysis_.fusion_roots(), analysis_.fusion_heroes())) {
+    if (auto transpose = GetDescriptionForTiledTransposeEmitter(
+            root.instruction(), hero.instruction())) {
+      permutation_ = transpose->permutation;
+      break;
+    }
+  }
+}
+
+absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
+                                         const HloFusionInstruction& fusion,
+                                         const LaunchDimensions& launch_dims,
+                                         std::vector<llvm_ir::IrArray> inputs,
+                                         std::vector<llvm_ir::IrArray> outputs,
+                                         llvm::IRBuilder<>* builder) const {
+  const auto& hlo_roots = analysis_.fusion_roots();
+  GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
+  FusedIrEmitter fused_emitter(elemental_emitter);
+  for (auto [i, input] : llvm::enumerate(inputs)) {
+    HloInstruction* fused_operand = fusion.fused_parameter(i);
+    fused_emitter.BindGenerator(
+        *fused_operand, [input = input, builder,
+                         fused_operand](const llvm_ir::IrArray::Index& index) {
+          return input.EmitReadArrayElement(index, builder,
+                                            fused_operand->name());
+        });
+  }
+
+  absl::flat_hash_map<const HloInstruction*,
+                      std::vector<std::pair<int64_t, const HloInstruction*>>>
+      transposes_to_roots;
+  // Keep a list of deduplicated transpose heroes separate from
+  // transposes_to_roots to make the CodeGen deterministic.
+  std::vector<TransposeDescription> transposes;
+  transposes.reserve(hlo_roots.size());
+  std::vector<std::pair<int64_t, const HloInstruction*>> extra_outputs;
+
+  for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) {
+    const auto& hero = analysis_.fusion_hero(output_idx).instruction();
+    auto transpose_descr =
+        GetDescriptionForTiledTransposeEmitter(root.instruction(), hero);
+    if (transpose_descr.has_value()) {
+      auto iterator_inserted = transposes_to_roots.insert(std::make_pair(
+          &hero, std::vector<std::pair<int64_t, const HloInstruction*>>{
+                     {output_idx, &root.instruction()}}));
+      if (iterator_inserted.second) {
+        transposes.push_back(*transpose_descr);
+      } else {
+        iterator_inserted.first->second.push_back(
+            {output_idx, &root.instruction()});
+      }
+    } else {
+      extra_outputs.push_back({output_idx, &root.instruction()});
+    }
+  }
+
+  absl::flat_hash_map<const HloInstruction*, llvm_ir::SharedMemoryTile> tiles;
+  absl::InlinedVector<int64_t, 3> permutation;
+  for (const auto& [tile_idx, tr] : llvm::enumerate(transposes)) {
+    permutation = tr.permutation;
+    auto tile_size = tiling_.GetBlockTileSize();
+    ++tile_size.back();  // Prevent bank conflicts.
+    auto* module = ir_emitter_context.llvm_module();
+    tiles[tr.instr] = llvm_ir::AllocateSharedMemoryTile(
+        module,
+        llvm_ir::PrimitiveTypeToIrType(tr.instr->shape().element_type(),
+                                       module),
+        tile_size, absl::StrCat("tr_tile_", tile_idx));
+  }
+
+  auto tile_generator = [&](const TilingThreadIdInfo& thread_id_info,
+                            const llvm_ir::IrArray::Index& tile_start_index,
+                            absl::Span<llvm::Value* const> tile_dimensions) {
+    // Copy input parameter values to shared memory buffers:
+    // tile[thread_id_y, thread_id_x] = input[index]
+    EmitTile(builder, tiling_, thread_id_info, tile_dimensions,
+             [&](absl::Span<llvm::Value* const> index_in_tile) {
+               auto index = tile_start_index.AddOffset(index_in_tile, builder);
+               for (const auto& tr : transposes) {
+                 auto input_gen =
+                     *fused_emitter.GetGenerator(*tr.instr->operand(0));
+                 auto input_index = index.SourceIndexOfBitcast(
+                     tr.instr->operand(0)->shape(), builder);
+                 llvm::Value* value = *input_gen(input_index);
+                 tiles[tr.instr].Store(value, index_in_tile, builder);
+               }
+
+               // Compute all extra output values before writing them. This
+               // avoids overwriting aliased input/output values before all
+               // reads have occurred.
+               std::vector<std::tuple<llvm_ir::IrArray, llvm_ir::IrArray::Index,
+                                      llvm::Value*>>
+                   scheduled_writes;
+               for (const auto& [output_idx, root] : extra_outputs) {
+                 auto extra_output_index =
+                     index.SourceIndexOfBitcast(root->shape(), builder);
+                 auto output_gen = *fused_emitter.GetGenerator(*root);
+                 llvm::Value* output_value = *output_gen(extra_output_index);
+                 scheduled_writes.emplace_back(
+                     outputs[output_idx], extra_output_index, output_value);
+               }
+
+               for (const auto& [output, idx, value] : scheduled_writes) {
+                 output.EmitWriteArrayElement(idx, value, builder);
+               }
+             });
+
+    EmitSyncThreads(builder, ir_emitter_context);
+
+    auto output_tile_index = PermuteIndex(tile_start_index, permutation);
+    auto transposed_tile_dimensions = Permute(tile_dimensions, permutation);
+
+    EmitTile(
+        builder, tiling_, thread_id_info, transposed_tile_dimensions,
+        /*emit_elem_function=*/
+        [&](absl::Span<llvm::Value* const> index_in_tile) {
+          auto index = output_tile_index.AddOffset(index_in_tile, builder);
+          for (const auto& tr : transposes) {
+            llvm::Value* loaded = tiles[tr.instr].Load(
+                Permute(index_in_tile, permutation), builder);
+
+            FusedIrEmitter fused_emitter(elemental_emitter);
+            fused_emitter.BindGenerator(
+                *tr.instr,
+                [&](const llvm_ir::IrArray::Index&) { return loaded; });
+            for (int64_t i = 0;
+                 i < fusion.fused_instructions_computation()->num_parameters();
+                 ++i) {
+              llvm_ir::IrArray ir_array = inputs[i];
+              HloInstruction* fused_operand = fusion.fused_parameter(i);
+              fused_emitter.BindGenerator(
+                  *fused_operand, [=](const llvm_ir::IrArray::Index& index) {
+                    return ir_array.EmitReadArrayElement(index, builder,
+                                                         fused_operand->name());
+                  });
+            }
+
+            // Generate code for the part of the computation after the real
+            // hero. Compute all output values before writing them, to avoid
+            // overwriting aliased input/output values before all reads have
+            // occurred.
+            std::vector<std::tuple<llvm_ir::IrArray, llvm_ir::IrArray::Index,
+                                   llvm::Value*>>
+                scheduled_writes;
+            for (const auto& [output_idx, root] :
+                 transposes_to_roots[tr.instr]) {
+              TF_ASSIGN_OR_RETURN(llvm_ir::ElementGenerator gen,
+                                  fused_emitter.GetGenerator(*root));
+
+              // Both emission and the write should use the index as
+              // transformed by the computation.
+              auto untiled_index =
+                  index.SourceIndexOfBitcast(root->shape(), builder);
+              TF_ASSIGN_OR_RETURN(llvm::Value * generated, gen(untiled_index));
+              scheduled_writes.emplace_back(outputs[output_idx], untiled_index,
+                                            generated);
+            }
+            for (const auto& [output, idx, value] : scheduled_writes) {
+              output.EmitWriteArrayElement(idx, value, builder);
+            }
+          }
+          return absl::OkStatus();
+        });
+  };
+
+  llvm::Type* index_type =
+      GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
+  return EmitTilingKernel(builder, tiling_, index_type, tile_generator)
+      .status();
+}
+
+LaunchDimensions TransposeFusion::launch_dimensions() const {
+  return LaunchDimensions(tiling_.GetNumBlocks(),
+                          tiling_.GetNumThreadsPerBlock());
+}
+
+std::optional<IndexingMap> TransposeFusion::ComputeThreadIdToOutputIndexing(
+    int64_t root_index, mlir::MLIRContext* ctx) const {
+  const auto& hero = analysis_.fusion_hero(root_index);
+  if (hero.opcode() != HloOpcode::kTranspose) {
+    // The shapes of non-transpose roots are bitcast-compatible with the
+    // input shape of the transpose heroes.
+    auto map = ComposeIndexingMaps(
+        GetIndexingMapForTiling(tiling_, ctx),
+        GetBitcastMap(tiling_.GetXlaShape(),
+                      analysis_.fusion_root(root_index).shape(), ctx));
+    map.Simplify();
+    return map;
+  }
+
+  // The block offsets are permuted, but the thread offsets remain the same.
+  auto block_offset = GetBlockOffsetsForTiling(tiling_, ctx)
+                          .getSubMap(std::vector<unsigned>{permutation_.begin(),
+                                                           permutation_.end()});
+  auto thread_offset = GetThreadOffsetsForTiling(tiling_, ctx);
+  auto permuted_tiled_shape =
+      ShapeUtil::MakeShape(U8, Permute(tiling_.GetShape(), permutation_));
+
+  auto map = ComposeIndexingMaps(
+      GetIndexingMapForTiling(
+          block_offset, thread_offset, tiling_.GetNumThreadsPerBlock(),
+          tiling_.GetNumBlocks(), tiling_.GetThreadTileSize(),
+          permuted_tiled_shape.dimensions()),
+      GetBitcastMap(permuted_tiled_shape, hero.shape(), ctx));
+  map.Simplify();
+  return map;
+}
+
+std::optional<IndexingMap> TransposeFusion::ComputeThreadIdToInputIndexing(
+    int64_t root_index, int64_t hero_operand_index,
+    mlir::MLIRContext* ctx) const {
+  const auto& hero = analysis_.fusion_hero(root_index).instruction();
+  if (hero.opcode() != HloOpcode::kTranspose) {
+    auto map = ComposeIndexingMaps(
+        *ComputeThreadIdToOutputIndexing(root_index, ctx),
+        *ComputeOutputToInputIndexing(
+             &analysis_.fusion_root(root_index).instruction(), 0, ctx)
+             .indexing_maps[hero_operand_index]
+             .begin());
+    map.Simplify();
+    return map;
+  }
+
+  auto map = ComposeIndexingMaps(
+      GetIndexingMapForTiling(tiling_, ctx),
+      GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), ctx));
+  map.Simplify();
+  return map;
+}
+
+}  // namespace gpu
+}  // namespace xla
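The tile-shape arithmetic in `ComputeTransposeTiling` is easy to lose among the container plumbing, so here is a small standalone sketch (plain C++, not part of the change; `kWarpSize`, `std::array`, and the `main` driver stand in for `WarpSize()` and the absl containers) that reproduces it for the `{0, 2, 1}` test case used later in transpose_test.cc:

#include <array>
#include <cstdio>

int main() {
  constexpr long kWarpSize = 32;  // Stand-in for WarpSize() on NVIDIA GPUs.
  constexpr long kNumRows = 4;

  // Canonical 3D output shape of a {0, 2, 1} transpose of f32[100, 32, 64].
  std::array<long, 3> transposed_dims = {100, 64, 32};
  std::array<long, 3> permutation = {0, 2, 1};

  // The input dims are the output dims read through the permutation:
  // {100, 32, 64} for this case.
  std::array<long, 3> input_dims = {transposed_dims[permutation[0]],
                                    transposed_dims[permutation[1]],
                                    transposed_dims[permutation[2]]};

  // Tile along the minor dimensions pre- and post-transpose.
  std::array<long, 3> tile_sizes = {1, 1, 1};
  tile_sizes[permutation[2]] = kWarpSize / kNumRows;  // -> {1, 8, 1}
  std::array<long, 3> num_threads = {1, 1, kWarpSize};
  num_threads[permutation[2]] = kNumRows;             // -> {1, 4, 32}

  // Each block covers tile_size * num_threads elements per dimension,
  // i.e. a {1, 32, 32} block tile handled by 1 * 4 * 32 = 128 threads.
  for (int d = 0; d < 3; ++d) {
    std::printf("dim %d: size %ld, block tile %ld, threads %ld\n", d,
                input_dims[d], tile_sizes[d] * num_threads[d], num_threads[d]);
  }
  return 0;
}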
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h
new file mode 100644
index 0000000..3366130
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h
@@ -0,0 +1,91 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_TRANSPOSE_H_
+#define XLA_SERVICE_GPU_FUSIONS_LEGACY_TRANSPOSE_H_
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/legacy/tiling_util.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/llvm_ir/ir_array.h"
+
+namespace xla {
+namespace gpu {
+
+// Emits a kernel for the given HLO instruction using a tiled 0-2-1 transpose
+// algorithm to improve the memory access patterns for the input parameters
+// with a shape that is a 0-2-1 transpose of the output tensor shape. The
+// caller is responsible for making sure that it is safe to apply the shared
+// memory transpose on the input parameters.
+//
+// For the purpose of tiling, the output tensors have a logical shape of three
+// components 0-2-1 while the relevant input parameters have a logical shape
+// of three components 0-1-2 in the order major to minor. The x- and y-
+// dimensions of the tensors are tiled in square tiles with an edge length
+// `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
+// transposes one tile: each thread copies kTileSize/kNumRows elements from
+// the input to a shared memory tile, then the otherwise "regular HLO kernel"
+// reads from the shared memory instead of the original input.
+//
+// This is similar to the following CUDA algorithm in TensorFlow:
+// https://goo.gl/MStRV6.
+//
+// `kTileSize` should usually equal the warp size. We currently choose 32 for
+// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
+//
+// TODO(b/33320379): Here each block transposes 1 tile. It may be more
+// efficient to launch fewer blocks so each transposes many tiles.
+class TransposeFusion : public KernelFusionEmitterBase {
+ public:
+  explicit TransposeFusion(const se::DeviceDescription& gpu_device_info,
+                           const HloFusionAnalysis& analysis);
+  LaunchDimensions launch_dimensions() const override;
+
+  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
+      int64_t root_index, mlir::MLIRContext* ctx) const override;
+
+  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
+      int64_t root_index, int64_t hero_operand_index,
+      mlir::MLIRContext* ctx) const override;
+
+ protected:
+  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
+                          const HloFusionInstruction& fusion,
+                          const LaunchDimensions& launch_dims,
+                          std::vector<llvm_ir::IrArray> inputs,
+                          std::vector<llvm_ir::IrArray> outputs,
+                          llvm::IRBuilder<>* builder) const override;
+
+ private:
+  const HloFusionAnalysis& analysis_;
+  Tiling tiling_;
+  absl::InlinedVector<int64_t, 3> permutation_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_LEGACY_TRANSPOSE_H_
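The class comment above describes the shared-memory staging scheme in prose; the following standalone CPU sketch shows the same idea for a plain 2D transpose (plain C++, not the emitter itself, which generates LLVM IR per GPU thread; `TiledTranspose` and `kTileSize` are illustrative names). The `kTileSize + 1` column padding is the CPU analogue of the `++tile_size.back()` bank-conflict padding in transpose.cc:

#include <cstdio>
#include <vector>

constexpr int kTileSize = 32;

// Stage one kTileSize x kTileSize tile in a padded buffer, then write it back
// with rows and columns swapped. `in` is rows x cols, `out` is cols x rows.
void TiledTranspose(const std::vector<float>& in, std::vector<float>& out,
                    int rows, int cols) {
  float tile[kTileSize][kTileSize + 1];  // +1 column: bank-conflict analogue.
  for (int r0 = 0; r0 < rows; r0 += kTileSize) {
    for (int c0 = 0; c0 < cols; c0 += kTileSize) {
      // Copy the input tile: tile[r][c] = in[r0 + r][c0 + c].
      for (int r = 0; r < kTileSize && r0 + r < rows; ++r) {
        for (int c = 0; c < kTileSize && c0 + c < cols; ++c) {
          tile[r][c] = in[(r0 + r) * cols + (c0 + c)];
        }
      }
      // Write it back transposed: out[c0 + c][r0 + r] = tile[r][c].
      for (int c = 0; c < kTileSize && c0 + c < cols; ++c) {
        for (int r = 0; r < kTileSize && r0 + r < rows; ++r) {
          out[(c0 + c) * rows + (r0 + r)] = tile[r][c];
        }
      }
    }
  }
}

int main() {
  const int rows = 100, cols = 64;
  std::vector<float> in(rows * cols), out(cols * rows);
  for (int i = 0; i < rows * cols; ++i) in[i] = static_cast<float>(i);
  TiledTranspose(in, out, rows, cols);
  // out[c][r] should equal in[r][c]; spot-check one element.
  std::printf("out(5,3)=%g expect in(3,5)=%g\n", out[5 * rows + 3],
              in[3 * cols + 5]);
  return 0;
}

On a real GPU the extra column ensures that the column-wise reads during the transposed write-back land in different shared-memory banks; in this CPU sketch it is purely illustrative.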
diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc
new file mode 100644
index 0000000..f33ae04
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc
@@ -0,0 +1,346 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/legacy/transpose.h"
+
+#include <memory>
+#include <optional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/statusor.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class TransposeTest : public HloTestBase {
+ protected:
+  DebugOptions GetDebugOptionsForTest() override {
+    auto opts = HloTestBase::GetDebugOptionsForTest();
+    opts.set_xla_gpu_mlir_emitter_level(0);
+    return opts;
+  }
+  stream_executor::DeviceDescription device_info_ =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+};
+
+absl::StatusOr<std::unique_ptr<TransposeFusion>> GetTransposeFusion(
+    const HloFusionAnalysis& analysis) {
+  auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
+  auto fusion = dynamic_cast<TransposeFusion*>(emitter.get());
+  TF_RET_CHECK(fusion != nullptr);
+
+  emitter.release();
+  return std::unique_ptr<TransposeFusion>{fusion};
+}
+
+TEST_F(TransposeTest, ThreadIndexing021) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fusion {
+      %input = f32[100,32,64] parameter(0)
+      ROOT transpose = f32[100,64,32] transpose(%input), dimensions={0,2,1}
+    }
+
+    ENTRY entry {
+      %input = f32[100,32,64] parameter(0)
+      ROOT %fusion = f32[100,64,32] fusion(%input), kind=kInput, calls=fusion
+    })")
+                    .value();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
+  mlir::MLIRContext mlir_context;
+
+  EXPECT_THAT(
+      fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
+      MatchIndexingString(R"(
+        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+          d3 floordiv 2,
+          d0 floordiv 32 + s1 * 4,
+          (d3 mod 2) * 32 + d0 mod 32
+        )
+        domain:
+        d0 in [0, 127]
+        d1 in [0, 0]
+        d2 in [0, 0]
+        d3 in [0, 199]
+        d4 in [0, 0]
+        d5 in [0, 0]
+
+        s0 in [0, 0]
+        s1 in [0, 7]
+        s2 in [0, 0]
+      )"));
+  EXPECT_THAT(
+      fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
+      MatchIndexingString(R"(
+        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+          d3 floordiv 2,
+          (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32,
+          d0 mod 32
+        )
+        domain:
+        d0 in [0, 127]
+        d1 in [0, 0]
+        d2 in [0, 0]
+        d3 in [0, 199]
+        d4 in [0, 0]
+        d5 in [0, 0]
+
+        s0 in [0, 0]
+        s1 in [0, 7]
+        s2 in [0, 0]
+      )"));
+}
+
+TEST_F(TransposeTest, ThreadIndexing201) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fusion {
+      %input = f32[100,64,32] parameter(0)
+      ROOT transpose = f32[32,100,64] transpose(%input), dimensions={2,0,1}
+    }
+
+    ENTRY entry {
+      %input = f32[100,64,32] parameter(0)
+      ROOT %fusion = f32[32,100,64] fusion(%input), kind=kInput, calls=fusion
+    })")
+                    .value();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
+  mlir::MLIRContext mlir_context;
+  EXPECT_THAT(
+      fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
+      MatchIndexingString(R"(
+        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+          d3 floordiv 2,
+          (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32,
+          d0 mod 32
+        )
+        domain:
+        d0 in [0, 127]
+        d1 in [0, 0]
+        d2 in [0, 0]
+        d3 in [0, 199]
+        d4 in [0, 0]
+        d5 in [0, 0]
+
+        s0 in [0, 0]
+        s1 in [0, 7]
+        s2 in [0, 0]
+      )"));
+  EXPECT_THAT(
+      fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
+      MatchIndexingString(R"(
+        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+          d0 floordiv 32 + s1 * 4,
+          d3 floordiv 2,
+          (d3 mod 2) * 32 + d0 mod 32
+        )
+        domain:
+        d0 in [0, 127]
+        d1 in [0, 0]
+        d2 in [0, 0]
+        d3 in [0, 199]
+        d4 in [0, 0]
+        d5 in [0, 0]
+
+        s0 in [0, 0]
+        s1 in [0, 7]
+        s2 in [0, 0]
+      )"));
+}
+
+TEST_F(TransposeTest, ThreadIndexingPartialBlock) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule m
+
+    fused_computation {
+      %p0 = f64[24,2,6,4] parameter(0)
+      ROOT %t = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0}
+    }
+
+    ENTRY main {
+      %p0 = f64[24,2,6,4] parameter(0)
+      ROOT %fusion = f64[6,4,2,24] fusion(%p0), kind=kInput,
+        calls=%fused_computation
+    }
+  )")
+                    .value();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
+  mlir::MLIRContext mlir_context;
+  EXPECT_THAT(
+      fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
+      MatchIndexingString(R"(
+        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+          d0 floordiv 32 + s0 * 4,
+          d3,
+          (d0 floordiv 4) mod 8,
+          d0 mod 4
+        )
+        domain:
+        d0 in [0, 127]
+        d1 in [0, 0]
+        d2 in [0, 0]
+        d3 in [0, 1]
+        d4 in [0, 0]
+        d5 in [0, 0]
+        s0 in [0, 5]
+        s1 in [0, 0]
+        s2 in [0, 0]
+        d0 mod 32 in [0, 23]
+      )"));
+  EXPECT_THAT(
+      fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
+      MatchIndexingString(R"(
+        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+          s0,
+          d0 floordiv 32,
+          d3,
+          d0 mod 32
+        )
+        domain:
+        d0 in [0, 127]
+        d1 in [0, 0]
+        d2 in [0, 0]
+        d3 in [0, 1]
+        d4 in [0, 0]
+        d5 in [0, 0]
+        s0 in [0, 5]
+        s1 in [0, 0]
+        s2 in [0, 0]
+        d0 mod 32 in [0, 23]
+      )"));
+}
+
+TEST_F(TransposeTest, SameInputIndexingForRealHeroAndSideOutput) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fusion {
+      %input = f32[100,32,64] parameter(0)
+      %transpose = f32[100,64,32] transpose(%input), dimensions={0,2,1}
+      %bitcast = f32[100,2048] bitcast(%input)
+      ROOT %tuple = (f32[100,64,32], f32[100,2048]) tuple(%transpose, %bitcast)
+    }
+
+    ENTRY entry {
+      %input = f32[100,32,64] parameter(0)
+      ROOT %fusion = (f32[100,64,32], f32[100,2048]) fusion(%input), kind=kInput, calls=fusion
+    })")
+                    .value();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
+  mlir::MLIRContext mlir_context;
+
+  EXPECT_THAT(
+      fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
+      fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString());
+}
+
+TEST_F(TransposeTest, ThreadIndexingSideOutput) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fusion {
+      %input0 = f32[100,32,64] parameter(0)
+      %input1 = f32[100,32] parameter(1)
+      %transpose = f32[100,64,32] transpose(%input0), dimensions={0,2,1}
+      %broadcast = f32[100,32,64] broadcast(%input1), dimensions={0,1}
+      ROOT %tuple = (f32[100,64,32], f32[100,32,64]) tuple(%transpose, %broadcast)
+    }
+
+    ENTRY entry {
+      %input0 = f32[100,32,64] parameter(0)
+      %input1 = f32[100,32] parameter(1)
+      ROOT %fusion = (f32[100,64,32], f32[100,32,64]) fusion(%input0, %input1), kind=kInput, calls=fusion
+    })")
+                    .value();
+
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
+  mlir::MLIRContext mlir_context;
+  // Check that the side output `%broadcast` gets the correct input indexing,
+  // which should correspond to `%input1` with shape [100,32].
+  EXPECT_THAT(
+      fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString(),
+      MatchIndexingString(R"(
+        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+          d3 floordiv 2,
+          d0 floordiv 32 + s1 * 4
+        )
+        domain:
+        d0 in [0, 127]
+        d1 in [0, 0]
+        d2 in [0, 0]
+        d3 in [0, 199]
+        d4 in [0, 0]
+        d5 in [0, 0]
+
+        s0 in [0, 0]
+        s1 in [0, 7]
+        s2 in [0, 0]
+      )"));
+  EXPECT_THAT(
+      fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(),
+      MatchIndexingString(R"(
+        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
+          d3 floordiv 2,
+          d0 floordiv 32 + s1 * 4,
+          (d3 mod 2) * 32 + d0 mod 32
+        )
+        domain:
+        d0 in [0, 127]
+        d1 in [0, 0]
+        d2 in [0, 0]
+        d3 in [0, 199]
+        d4 in [0, 0]
+        d5 in [0, 0]
+
+        s0 in [0, 0]
+        s1 in [0, 7]
+        s2 in [0, 0]
+      )"));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
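The domain bounds asserted in `ThreadIndexing021` above follow from the tiling derived earlier: the f32[100,32,64] input with a {1,32,32} block tile gives 100 * 1 * 2 = 200 blocks (`d3` in [0, 199]), 1 * 4 * 32 = 128 threads per block (`d0` in [0, 127]), and 8 rows per thread (`s1` in [0, 7]). As a quick standalone cross-check (plain C++, not part of the test suite), the expected input map can be evaluated over that domain to confirm that it visits every input element exactly once:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Evaluate the expected input map of ThreadIndexing021,
  //   (d3 floordiv 2, d0 floordiv 32 + s1 * 4, (d3 mod 2) * 32 + d0 mod 32),
  // over its domain (the trivial d1, d2, d4, d5, s0, s2 are fixed at 0) and
  // count how often each element of the f32[100, 32, 64] input is visited.
  constexpr int64_t kDims[3] = {100, 32, 64};
  std::vector<int> visits(kDims[0] * kDims[1] * kDims[2], 0);
  for (int64_t d3 = 0; d3 < 200; ++d3) {    // 100 * 1 * 2 blocks
    for (int64_t d0 = 0; d0 < 128; ++d0) {  // 1 * 4 * 32 threads per block
      for (int64_t s1 = 0; s1 < 8; ++s1) {  // 8 rows handled per thread
        int64_t i0 = d3 / 2;
        int64_t i1 = d0 / 32 + s1 * 4;
        int64_t i2 = (d3 % 2) * 32 + d0 % 32;
        ++visits[(i0 * kDims[1] + i1) * kDims[2] + i2];
      }
    }
  }
  int64_t not_once = 0;
  for (int v : visits) not_once += (v != 1);
  std::printf("elements not visited exactly once: %lld\n",
              static_cast<long long>(not_once));  // Prints 0.
  return 0;
}

The same exercise with the expected output map shows the permuted block offsets that `ComputeThreadIdToOutputIndexing` constructs: the block coordinates are swapped while the in-block thread coordinates stay put.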
diff --git a/third_party/xla/xla/service/gpu/fusions/loop.cc b/third_party/xla/xla/service/gpu/fusions/loop.cc
deleted file mode 100644
index 522dc1d..0000000
--- a/third_party/xla/xla/service/gpu/fusions/loop.cc
+++ /dev/null
@@ -1,295 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/loop.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <optional>
-#include <tuple>
-#include <utility>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/numeric/bits.h"
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Type.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/gpu/parallel_loop_emitter.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "tsl/platform/macros.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-const Shape& GetElementShape(const HloFusionAnalysis& analysis) {
-  const Shape* shape = &analysis.fusion_root(0).shape();
-  while (shape->IsTuple()) {
-    shape = &shape->tuple_shapes(0);
-  }
-  return *shape;
-}
-
-// Computes the maximum valid unroll factor for a given instruction.
-int ComputeMaxUnrollFactor(int64_t num_elements) {
-  constexpr int kMaxUnrollFactor = 4;
-  for (int i = kMaxUnrollFactor; i > 1; i /= 2) {
-    if (num_elements % i == 0) {
-      return i;
-    }
-  }
-  return 1;
-}
-
-// Determines if we enable the row optimized codegen. When we have a fusion with
-// only pointwise operations, scalar broadcasting and row broadcasting, we can
-// trigger a kernel that vectorizes the row loads. This speeds up the kernel, in
-// particular on A100. The int is the number of inputs with rank `out_rank`. Its
-// value is only defined if row vectorization is enabled.
-std::pair<bool /*enabled*/, int> RowVectorizationEnabled(
-    const HloFusionAdaptor& fusion, int64_t out_rank) {
-  auto roots = fusion.GetRoots();
-  const auto is_row_major = [](const HloInstruction* instr) {
-    // Only tested when the inputs are row-major. So only enable that case.
-    // Maybe it would work if only the inner dimensions is contiguous.
-    return LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout());
-  };
-  bool row_vectorized = roots.size() == 1 && !roots[0].shape().IsTuple() &&
-                        is_row_major(&roots[0].instruction());
-  if (!row_vectorized) {
-    return {false, 0};
-  }
-
-  // Check that the operations in the fusion are supported.  Each
-  // supported operation (or category) must be manually vetted as XLA
-  // only unrolls and relies on LLVM to vectorize. But this is brittle.
-  // Currently tested and supported operations:
-  // Elementwise, scalar and row broadcasting.
-  //
-  // We also detect at the same time if there is a row broadcasting
-  // operation.
-  int num_big_inputs = 0;
-  bool some_row_broadcasting = false;
-  HloBfsConsumersFirstTraversal(
-      roots, fusion,
-      [&](auto node) -> TraversalResult {
-        if (!row_vectorized) {
-          return TraversalResult::kInterrupt;
-        }
-
-        if (node.instruction().IsElementwise()) {
-          return TraversalResult::kAdvance;
-        }
-
-        switch (node.opcode()) {
-          case HloOpcode::kConstant:
-            return TraversalResult::kSkip;
-          case HloOpcode::kParameter:
-            return TraversalResult::kAdvance;
-          case HloOpcode::kBroadcast: {
-            auto dims = node.instruction().dimensions();
-            if (dims.empty()) {
-              return TraversalResult::kAdvance;
-            }
-
-            if (dims.size() == 1 && dims.front() == node.shape().rank() - 1) {
-              some_row_broadcasting = true;
-              return TraversalResult::kAdvance;
-            }
-            TF_FALLTHROUGH_INTENDED;
-          }
-          default:
-            VLOG(2) << "Row vectorization not enabled due to: "
-                    << node.ToString();
-            row_vectorized = false;
-            return TraversalResult::kInterrupt;
-        }
-      });
-  if (row_vectorized) {
-    for (const HloInstruction* argument : fusion.GetParameters()) {
-      if (argument->shape().rank() == out_rank) {
-        ++num_big_inputs;
-      }
-      if (!is_row_major(argument)) {
-        row_vectorized = false;
-      }
-    };
-  }
-  // Trigger only when there is a row broadcasting.
-  return std::make_pair(row_vectorized && some_row_broadcasting,
-                        num_big_inputs);
-}
-
-}  // namespace
-
-LaunchDimensionsConfig ComputeLoopFusionConfig(
-    const HloFusionAnalysis& analysis) {
-  return ComputeLoopFusionConfig(analysis, GetElementShape(analysis));
-}
-
-LaunchDimensionsConfig ComputeLoopFusionConfig(
-    const HloFusionAnalysis& analysis, const Shape& element_shape) {
-  int unroll_factor = 1;
-  // Unrolling is good to read large inputs with small elements
-  // due to vector loads, but increases the register pressure when one
-  // thread has to produce multiple output elements.
-  // Therefore for fusions with small outputs prefer to use one thread
-  // per output element = no unroll.
-  // Call 'small' fusions that use less threads than the GPU has.
-  int64_t num_elements = ShapeUtil::ElementsIn(element_shape);
-  int64_t n_threads_max = analysis.device_info().threads_per_core_limit() *
-                          analysis.device_info().core_count();
-  if (num_elements >= n_threads_max &&
-      !MayPreventVectorization(analysis.fusion())) {
-    unroll_factor = ComputeMaxUnrollFactor(num_elements);
-  }
-  // CHECK that unroll_factor is a power-of-2, as needed by the logic below.
-  CHECK(absl::has_single_bit(static_cast<uint64_t>(unroll_factor)));
-  // Ensure a single thread writes to a byte containing multiple values by
-  // setting unroll_factor to an appropriate number. Setting unroll_factor is
-  // safe even if the new unroll_factor doesn't divide the number of elements,
-  // as the parallel loop emitter will insert a bounds check in this case to
-  // ensure the out-of-bounds element is not computed and written. Setting
-  // unroll_factor is safe even if MayPreventVectorization returns false, as
-  // the MayPreventVectorization check is an optimization, not a correctness
-  // requirement.
-  unroll_factor = std::max(
-      unroll_factor,
-      CeilOfRatio(8, analysis.input_output_info().smallest_output_dtype_bits));
-  CHECK(absl::has_single_bit(static_cast<uint64_t>(unroll_factor)));
-  VLOG(2) << "Unroll factor: " << unroll_factor;
-
-  bool row_vectorized;
-  int num_big_inputs;
-  std::tie(row_vectorized, num_big_inputs) =
-      RowVectorizationEnabled(analysis.fusion(), element_shape.rank());
-  bool few_waves = !HloAnyOf(analysis.fusion(), [&](auto instr) {
-    if (instr.opcode() == HloOpcode::kParameter ||
-        instr.opcode() == HloOpcode::kConstant ||
-        HloInstruction::IsOpElementwise(instr.opcode())) {
-      return false;
-    }
-    if (auto broadcast =
-            DynCast<HloBroadcastInstruction>(&instr.instruction())) {
-      if (broadcast->dimensions().empty() ||
-          // More than 3 big inputs cause a speed regression.
-          (row_vectorized && num_big_inputs <= 3)) {
-        return false;
-      }
-    }
-    VLOG(2) << "few_waves not enabled due to: "
-            << instr.instruction().ToString();
-    return true;
-  });
-
-  LaunchDimensionsConfig launch_config{unroll_factor, few_waves,
-                                       row_vectorized};
-  // Check that the shapes is supported.
-  if (launch_config.row_vectorized &&
-      ThreadsPerBlockRowVectorized(element_shape, analysis.device_info(),
-                                   launch_config) <= 0) {
-    VLOG(2) << "Cancelling row_vectorization as the shape isn't supported.";
-    launch_config.row_vectorized = false;
-    launch_config.few_waves = false;
-  }
-  return launch_config;
-}
-
-LoopFusion::LoopFusion(const HloFusionAnalysis& analysis)
-    : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {}
-
-std::optional<IndexingMap> LoopFusion::ComputeThreadIdToOutputIndexing(
-    int64_t root_index, mlir::MLIRContext* ctx) const {
-  auto launch_dims = launch_dimensions();
-  return GetDefaultThreadIdIndexingMap(launch_dims, config_.unroll_factor,
-                                       GetElementShape(analysis_), ctx);
-}
-
-std::optional<IndexingMap> LoopFusion::ComputeThreadIdToInputIndexing(
-    int64_t root_index, int64_t hero_operand_index,
-    mlir::MLIRContext* ctx) const {
-  std::optional<IndexingMap> thread_id_to_output_indexing =
-      ComputeThreadIdToOutputIndexing(root_index, ctx);
-  if (!thread_id_to_output_indexing.has_value()) {
-    return std::nullopt;
-  }
-  const HloInstruction* fusion_root =
-      &analysis_.fusion_root(root_index).instruction();
-  auto output_to_input_indexing =
-      ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx);
-  IndexingMapSet output_to_input_indexing_set =
-      output_to_input_indexing.indexing_maps[hero_operand_index];
-  // Since we are computing the indexing for a non-fusion op, there is only one
-  // indexing map per operand.
-  CHECK_EQ(output_to_input_indexing_set.size(), 1);
-  IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps(
-      *thread_id_to_output_indexing, *output_to_input_indexing_set.begin());
-  thread_id_to_input_indexing_map.Simplify();
-  return thread_id_to_input_indexing_map;
-}
-
-absl::Status LoopFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
-                                    const HloFusionInstruction& fusion,
-                                    const LaunchDimensions& launch_dims,
-                                    std::vector<llvm_ir::IrArray> inputs,
-                                    std::vector<llvm_ir::IrArray> outputs,
-                                    llvm::IRBuilder<>* builder) const {
-  GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
-  FusedIrEmitter fused_emitter(elemental_emitter);
-  for (int i = 0; i < fusion.fused_parameters().size(); i++) {
-    fused_emitter.BindGenerator(
-        *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) {
-          return inputs[i].EmitReadArrayElement(index, builder);
-        });
-  }
-  TF_ASSIGN_OR_RETURN(
-      auto element_generator,
-      fused_emitter.GetGenerator(*fusion.fused_expression_root()));
-
-  llvm::Type* index_type =
-      GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
-
-  return ParallelLoopEmitter(element_generator, outputs, launch_dims, builder,
-                             config_)
-      .EmitLoop(fusion.name(), index_type);
-}
-
-LaunchDimensions LoopFusion::launch_dimensions() const {
-  return CalculateLaunchDimensions(GetElementShape(analysis_),
-                                   analysis_.device_info(), config_);
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/loop.h b/third_party/xla/xla/service/gpu/fusions/loop.h
deleted file mode 100644
index 2d23c302..0000000
--- a/third_party/xla/xla/service/gpu/fusions/loop.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_LOOP_H_
-#define XLA_SERVICE_GPU_FUSIONS_LOOP_H_
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/ir_array.h"
-
-namespace xla {
-namespace gpu {
-
-// Generic loop fusion.
-class LoopFusion : public KernelFusionEmitterBase {
- public:
-  explicit LoopFusion(const HloFusionAnalysis& analysis);
-  LaunchDimensions launch_dimensions() const override;
-
-  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
-      int64_t root_index, mlir::MLIRContext* ctx) const override;
-
-  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
-      int64_t root_index, int64_t hero_operand_index,
-      mlir::MLIRContext* ctx) const override;
-
- protected:
-  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
-                          const HloFusionInstruction& fusion,
-                          const LaunchDimensions& launch_dims,
-                          std::vector<llvm_ir::IrArray> inputs,
-                          std::vector<llvm_ir::IrArray> outputs,
-                          llvm::IRBuilder<>* builder) const override;
-
- private:
-  const HloFusionAnalysis& analysis_;
-  LaunchDimensionsConfig config_;
-};
-
-LaunchDimensionsConfig ComputeLoopFusionConfig(
-    const HloFusionAnalysis& analysis);
-
-LaunchDimensionsConfig ComputeLoopFusionConfig(
-    const HloFusionAnalysis& analysis, const Shape& shape);
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_LOOP_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc
index 9db9173..4c6bdac 100644
--- a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc
@@ -35,9 +35,9 @@
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
 #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/launch_dimensions.h"
diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.h b/third_party/xla/xla/service/gpu/fusions/loop_mlir.h
index 029c67b..ecb1591 100644
--- a/third_party/xla/xla/service/gpu/fusions/loop_mlir.h
+++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.h
@@ -22,9 +22,9 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/MLIRContext.h"
 #include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/loop.h"
 #include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
 #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h"
+#include "xla/service/gpu/gpu_fusible.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/launch_dimensions.h"
 #include "xla/service/gpu/model/indexing_map.h"
diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc
index 357ef65..eda89f8 100644
--- a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc
@@ -46,7 +46,7 @@
   thread_id_printer_.SetSymbolName(1, "unroll_id");
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
   MlirLoopFusion fusion(analysis);
   auto thread_id_to_output_indexing =
       fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_);
@@ -88,7 +88,7 @@
   thread_id_printer_.SetSymbolName(1, "unroll_id");
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
 
   MlirLoopFusion fusion(analysis);
   auto thread_id_to_output_indexing =
@@ -140,7 +140,7 @@
   thread_id_printer_.SetSymbolName(1, "unroll_id");
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
 
   MlirLoopFusion fusion(analysis);
   auto thread_id_to_output_indexing =
@@ -196,10 +196,10 @@
     }
   )";
   TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1 * 1024 + d0)>
-    // CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> ((d1 * 1024 + d0) floordiv 768)>
-    // CHECK: #[[MAP2:.*]] = affine_map<(d0, d1) -> (((d1 * 1024 + d0) floordiv 48) mod 16)>
-    // CHECK: #[[MAP3:.*]] = affine_map<(d0, d1) -> ((d1 * 1024 + d0) mod 48)>
+    // CHECK-DAG: #[[MAP0:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 1024 + d0)
+    // CHECK-DAG: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 1024 + d0) floordiv 768)
+    // CHECK-DAG: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 1024 + d0) floordiv 48) mod 16)
+    // CHECK-DAG: #[[MAP3:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 1024 + d0) mod 48)
     // CHECK: func.func @fused_computation(%[[ARG0:.*]]: tensor<2x16x48xbf16>
     // CHECK: %[[UPPER_BOUND:.*]] = arith.constant 1535 : index
     // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id
diff --git a/third_party/xla/xla/service/gpu/fusions/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_test.cc
deleted file mode 100644
index 69c41ec..0000000
--- a/third_party/xla/xla/service/gpu/fusions/loop_test.cc
+++ /dev/null
@@ -1,222 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <memory>
-#include <optional>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/status/statusor.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/fusions/fusions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/affine_map_printer.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class LoopTest : public HloTestBase {
- public:
-  void SetUp() override {
-    HloTestBase::SetUp();
-
-    printer_ =
-        AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
-                         {"chunk_id", "unroll_id"});
-  }
-
- protected:
-  stream_executor::DeviceDescription device_info_ =
-      TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  AffineMapPrinter printer_;
-  mlir::MLIRContext mlir_context_;
-};
-
-absl::StatusOr<std::unique_ptr<KernelFusionInterface>> GetFusion(
-    const HloFusionAnalysis& analysis) {
-  auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
-  auto fusion = dynamic_cast<KernelFusionInterface*>(emitter.get());
-  TF_RET_CHECK(fusion != nullptr);
-
-  emitter.release();
-  return std::unique_ptr<KernelFusionInterface>{fusion};
-}
-
-TEST_F(LoopTest, ThreadIndexingUnrolled) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    neg {
-      %input = f32[100,200,300] parameter(0)
-      ROOT neg = f32[100,200,300] negate(%input)
-    }
-
-    ENTRY entry {
-      %input = f32[100,200,300] parameter(0)
-      ROOT %fusion = f32[100,200,300] fusion(%input), kind=kLoop, calls=neg
-    })")
-                    .value();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
-
-  TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis));
-  auto thread_id_to_output_indexing =
-      loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0,
-                                                   &mlir_context_);
-
-  EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
-              MatchIndexingString(R"(
-  (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
-    (bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000,
-    ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200,
-    ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id
-  )
-  domain:
-  th_x in [0, 127]
-  th_y in [0, 0]
-  th_z in [0, 0]
-  bl_x in [0, 1007]
-  bl_y in [0, 0]
-  bl_z in [0, 0]
-  chunk_id in [0, 11]
-  unroll_id in [0, 3]
-  bl_x * 128 + chunk_id * 129024 + th_x in [0, 1499999]
-)"));
-}
-
-TEST_F(LoopTest, ThreadIndexingNotUnrolled) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    neg {
-      %input = f32[20] parameter(0)
-      ROOT neg = f32[20] negate(%input)
-    }
-
-    ENTRY entry {
-      %input = f32[20] parameter(0)
-      ROOT %fusion = f32[20] fusion(%input), kind=kLoop, calls=neg
-    })")
-                    .value();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
-
-  TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis));
-  auto thread_id_to_output_indexing =
-      loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0,
-                                                   &mlir_context_);
-  EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
-              MatchIndexingString(R"(
-              (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x)
-              domain:
-              th_x in [0, 19]
-              th_y in [0, 0]
-              th_z in [0, 0]
-              bl_x in [0, 0]
-              bl_y in [0, 0]
-              bl_z in [0, 0]
-              chunk_id in [0, 0]
-              unroll_id in [0, 0]
-            )"));
-  auto thread_id_to_input_indexing =
-      loop_fusion->ComputeThreadIdToInputIndexing(
-          /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
-  EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_),
-              MatchIndexingString(R"(
-              (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x)
-              domain:
-              th_x in [0, 19]
-              th_y in [0, 0]
-              th_z in [0, 0]
-              bl_x in [0, 0]
-              bl_y in [0, 0]
-              bl_z in [0, 0]
-              chunk_id in [0, 0]
-              unroll_id in [0, 0]
-            )"));
-}
-
-TEST_F(LoopTest, Broadcast) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    bcast {
-      %input = f32[20] parameter(0)
-      ROOT bcast = f32[10, 20, 30] broadcast(%input), dimensions={1}
-    }
-
-    ENTRY entry {
-      %input = f32[20] parameter(0)
-      ROOT %fusion = f32[10, 20, 30] fusion(%input), kind=kLoop, calls=bcast
-    })")
-                    .value();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
-
-  TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis));
-  auto thread_id_to_output_indexing =
-      loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0,
-                                                   &mlir_context_);
-  EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_),
-              MatchIndexingString(R"(
-              (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
-                (bl_x * 128 + th_x) floordiv 600,
-                ((bl_x * 128 + th_x) floordiv 30) mod 20,
-                (bl_x * 128 + th_x) mod 30)
-                domain:
-                th_x in [0, 127]
-                th_y in [0, 0]
-                th_z in [0, 0]
-                bl_x in [0, 46]
-                bl_y in [0, 0]
-                bl_z in [0, 0]
-                chunk_id in [0, 0]
-                unroll_id in [0, 0]
-                bl_x * 128 + th_x in [0, 5999]
-            )"));
-  auto thread_id_to_input_indexing =
-      loop_fusion->ComputeThreadIdToInputIndexing(
-          /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
-  EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_),
-              MatchIndexingString(R"(
-              (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] ->
-                  (((bl_x * 128 + th_x) floordiv 30) mod 20)
-                domain:
-                th_x in [0, 127]
-                th_y in [0, 0]
-                th_z in [0, 0]
-                bl_x in [0, 46]
-                bl_y in [0, 0]
-                bl_z in [0, 0]
-                chunk_id in [0, 0]
-                unroll_id in [0, 0]
-                bl_x * 128 + th_x in [0, 5999]
-            )"));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD
index 5231483..08a159f 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD
+++ b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD
@@ -1,4 +1,3 @@
-load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
 load("//xla:xla.bzl", "xla_cc_test")
 
 package(
@@ -76,7 +75,7 @@
         "//xla/mlir_hlo:type_conversion",
         "//xla/service:algorithm_util",
         "//xla/service/gpu:hlo_traversal",
-        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
         "//xla/service/gpu/model:indexing_analysis",
         "//xla/service/llvm_ir:llvm_util",
         "//xla/stream_executor:device_description",
@@ -119,7 +118,7 @@
         "//xla/mlir_hlo",
         "//xla/service:hlo_parser",
         "//xla/service/gpu:launch_dimensions",
-        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
         "//xla/service/gpu/model:indexing_analysis",
         "//xla/service/llvm_ir:llvm_util",
         "//xla/stream_executor:launch_dim",
@@ -153,7 +152,6 @@
     deps = [
         ":computation_partitioner",
         ":elemental_hlo_to_mlir",
-        ":passes",
         ":type_util",
         "//xla:shape_util",
         "//xla:status_macros",
@@ -172,7 +170,8 @@
         "//xla/service/gpu:launch_dimensions",
         "//xla/service/gpu:target_util",
         "//xla/service/gpu/fusions:fusion_emitter",
-        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
+        "//xla/service/gpu/fusions/transforms:passes",
         "//xla/service/gpu/model:indexing_analysis",
         "//xla/service/gpu/runtime:kernel_thunk",
         "//xla/service/llvm_ir:llvm_util",
@@ -261,95 +260,6 @@
     ],
 )
 
-gentbl_cc_library(
-    name = "passes_inc_gen",
-    tbl_outs = [
-        (
-            [
-                "-gen-pass-decls",
-                "-name=GpuFusionTransforms",
-            ],
-            "passes.h.inc",
-        ),
-    ],
-    tblgen = "@llvm-project//mlir:mlir-tblgen",
-    td_file = "passes.td",
-    visibility = ["//visibility:private"],
-    deps = ["@llvm-project//mlir:PassBaseTdFiles"],
-)
-
-cc_library(
-    name = "passes",
-    srcs = [
-        "convert_xla_gpu_pure_call_ops.cc",
-        "erase_dead_functions.cc",
-        "expand_float_ops.cc",
-        "flatten_tensors.cc",
-        "lower_tensors.cc",
-        "lower_to_llvm.cc",
-        "lower_xla_gpu_to_scf.cc",
-        "merge_pointers_to_same_slice.cc",
-        "optimize_loops.cc",
-        "propagate_slice_indices.cc",
-        "simplify_affine.cc",
-        "simplify_arith.cc",
-        "unswitch_loops.cc",
-        "vectorize_loads_stores.cc",
-    ],
-    hdrs = ["passes.h"],
-    deps = [
-        ":passes_inc_gen",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/mlir_hlo",
-        "//xla/mlir_hlo:map_mhlo_to_scalar_op",
-        "//xla/service/gpu:ir_emission_utils",
-        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
-        "//xla/service/gpu/model:indexing_analysis",
-        "//xla/stream_executor:device_description",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/log:check",
-        "@com_google_absl//absl/strings",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:AffineDialect",
-        "@llvm-project//mlir:AffineToStandard",
-        "@llvm-project//mlir:AffineUtils",
-        "@llvm-project//mlir:ArithDialect",
-        "@llvm-project//mlir:ArithToLLVM",
-        "@llvm-project//mlir:ArithTransforms",
-        "@llvm-project//mlir:CallOpInterfaces",
-        "@llvm-project//mlir:ComplexDialect",
-        "@llvm-project//mlir:ComplexToLLVM",
-        "@llvm-project//mlir:ControlFlowToLLVM",
-        "@llvm-project//mlir:DataLayoutInterfaces",
-        "@llvm-project//mlir:DialectUtils",
-        "@llvm-project//mlir:FuncDialect",
-        "@llvm-project//mlir:FuncToLLVM",
-        "@llvm-project//mlir:GPUDialect",
-        "@llvm-project//mlir:GPUToNVVMTransforms",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:LLVMCommonConversion",
-        "@llvm-project//mlir:LLVMDialect",
-        "@llvm-project//mlir:MathDialect",
-        "@llvm-project//mlir:MathToLLVM",
-        "@llvm-project//mlir:MathTransforms",
-        "@llvm-project//mlir:NVVMDialect",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:SCFDialect",
-        "@llvm-project//mlir:SCFToControlFlow",
-        "@llvm-project//mlir:SCFUtils",
-        "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:TensorDialect",
-        "@llvm-project//mlir:TransformUtils",
-        "@llvm-project//mlir:VectorDialect",
-        "@llvm-project//mlir:VectorToLLVM",
-        "@llvm-project//mlir:VectorTransforms",
-    ],
-)
-
 cc_library(
     name = "type_util",
     srcs = ["type_util.cc"],
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/convert_xla_gpu_pure_call_ops.cc b/third_party/xla/xla/service/gpu/fusions/mlir/convert_xla_gpu_pure_call_ops.cc
deleted file mode 100644
index bb1270e..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/convert_xla_gpu_pure_call_ops.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <utility>
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-#define GEN_PASS_DEF_CONVERTPURECALLOPSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-struct RewriteCall : mlir::OpRewritePattern<PureCallOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      PureCallOp op, mlir::PatternRewriter& rewriter) const override {
-    rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
-        op, op.getResultTypes(), op.getOperands(), op->getAttrs());
-    return mlir::success();
-  }
-};
-
-class ConvertPureCallOpsPass
-    : public impl::ConvertPureCallOpsPassBase<ConvertPureCallOpsPass> {
- public:
-  void runOnOperation() override {
-    auto* ctx = &getContext();
-    mlir::RewritePatternSet patterns(ctx);
-    patterns.add<RewriteCall>(ctx);
-    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
-                                                        std::move(patterns)))) {
-      signalPassFailure();
-    }
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<::mlir::Pass> CreateConvertPureCallOpsPass() {
-  return std::make_unique<ConvertPureCallOpsPass>();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
index 59471a3..6817336 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
@@ -73,8 +73,8 @@
 #include "xla/mlir_hlo/mhlo/utils/type_conversion.h"
 #include "xla/primitive_util.h"
 #include "xla/service/algorithm_util.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/fusions/mlir/type_util.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/model/indexing_analysis.h"
@@ -284,43 +284,45 @@
       PrimitiveTypeToMlirType(instr->shape().element_type(), b);
   int concat_dim =
       Cast<HloConcatenateInstruction>(instr)->concatenate_dimension();
-  int64_t offset = 0;
-  IfOp outermost_if = nullptr;
   SmallVector<Value, 3> operand_indices = indices;
-  for (auto [index, operand] : llvm::enumerate(instr->operands())) {
-    int64_t limit = offset + operand->shape().dimensions(concat_dim);
-    auto ins = b.create<CmpIOp>(CmpIPredicate::ult, indices[concat_dim],
-                                b.create<ConstantIndexOp>(limit));
-
-    auto generate_operand = [&, index = index]() {
-      operand_indices[concat_dim] = b.create<arith::SubIOp>(
-          indices[concat_dim], b.create<ConstantIndexOp>(offset));
-      TF_ASSIGN_OR_RETURN(auto operand,
-                          operand_provider(instr, index, operand_indices));
-      b.create<YieldOp>(operand);
-      return absl::OkStatus();
-    };
-
-    if (index < instr->operand_count() - 1) {
-      auto if_op =
-          b.create<IfOp>(mlir::TypeRange{result_element_type}, ins, true, true);
-      if (outermost_if == nullptr) {
-        outermost_if = if_op;
-      } else {
-        b.create<YieldOp>(if_op.getResults());
-      }
-
-      b.setInsertionPointToStart(if_op.getBody(0));
-      TF_RETURN_IF_ERROR(generate_operand());
-      b.setInsertionPointToStart(if_op.getBody(1));
-    } else {
-      TF_RETURN_IF_ERROR(generate_operand());
-    }
-    offset = limit;
+  SmallVector<int64_t, 3> offsets{0};
+  for (auto* operand : instr->operands()) {
+    offsets.push_back(offsets.back() + operand->shape().dimensions(concat_dim));
   }
 
-  b.setInsertionPointAfter(outermost_if);
-  return outermost_if.getResults();
+  std::function<absl::StatusOr<SmallVector<Value, 1>>(int64_t, int64_t)>
+      generate_concat;
+  generate_concat = [&](int64_t begin,
+                        int64_t end) -> absl::StatusOr<SmallVector<Value, 1>> {
+    // If there's just one operand in the range, emit it.
+    if (begin == end - 1) {
+      operand_indices[concat_dim] = b.create<arith::SubIOp>(
+          indices[concat_dim], b.create<ConstantIndexOp>(offsets[begin]));
+      TF_ASSIGN_OR_RETURN(auto operand,
+                          operand_provider(instr, begin, operand_indices));
+      return operand;
+    }
+
+    int64_t mid = (begin + end) / 2;  // No risk of overflow.
+    auto if_op = b.create<IfOp>(
+        mlir::TypeRange{result_element_type},
+        b.create<CmpIOp>(CmpIPredicate::ult, indices[concat_dim],
+                         b.create<ConstantIndexOp>(offsets[mid])),
+        true, true);
+
+    b.setInsertionPointToStart(if_op.getBody(0));
+    TF_ASSIGN_OR_RETURN(auto left_val, generate_concat(begin, mid));
+    b.create<YieldOp>(left_val);
+
+    b.setInsertionPointToStart(if_op.getBody(1));
+    TF_ASSIGN_OR_RETURN(auto right_val, generate_concat(mid, end));
+    b.create<YieldOp>(right_val);
+    b.setInsertionPointAfter(if_op);
+
+    return if_op.getResults();
+  };
+
+  return generate_concat(0, instr->operand_count());
 }
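The rewritten emitter above replaces the previous linear chain of scf.if ops with a balanced tree: it precomputes prefix offsets along the concat dimension and recursively splits the operand range at the midpoint, branching on index < offsets[mid]. A host-side sketch of the same selection logic (hypothetical helper, not the MLIR emission itself):

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Given prefix offsets along the concat dimension (offsets[0] == 0,
    // offsets.back() == total size), return {operand index, index within that
    // operand}, splitting the range [begin, end) at the midpoint like the
    // emitted scf.if tree does.
    std::pair<int64_t, int64_t> LocateConcatOperand(
        const std::vector<int64_t>& offsets, int64_t index, int64_t begin,
        int64_t end) {
      if (begin == end - 1) {
        return {begin, index - offsets[begin]};  // One operand left: emit it.
      }
      int64_t mid = (begin + end) / 2;
      return index < offsets[mid]
                 ? LocateConcatOperand(offsets, index, begin, mid)
                 : LocateConcatOperand(offsets, index, mid, end);
    }

For N operands, any emitted path now goes through O(log N) nested ifs instead of the O(N) chain the old loop produced.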
 
 absl::StatusOr<SmallVector<Value, 1>> EmitDynamicSlice(
@@ -665,9 +667,10 @@
   return b.createOrFold<mlir::affine::AffineApplyOp>(expr, args);
 }
 
-SmallVector<Value, 3> ApplyIndexing(const IndexingMap& map, ValueRange dims,
+SmallVector<Value, 3> ApplyIndexing(IndexingMap map, ValueRange dims,
                                     ValueRange symbols,
                                     ImplicitLocOpBuilder& b) {
+  map.ClearConstraints();
   SmallVector<Value, 3> results;
   for (unsigned int i = 0; i < map.GetAffineMap().getNumResults(); ++i) {
     SmallVector<Value, 1> result;
@@ -1465,6 +1468,8 @@
       .Convert();
 }
 
+}  // namespace
+
 void GetLoopBoundsFromIndexingMap(ImplicitLocOpBuilder& b,
                                   const IndexingMap& indexing_map,
                                   SmallVectorImpl<Value>* lbs,
@@ -1479,8 +1484,6 @@
   }
 }
 
-}  // namespace
-
 absl::Status SubgraphToMlirFunction(
     const PartitionedComputation& computation,
     const PartitionedComputation::Subgraph& subgraph, mlir::func::FuncOp& func,
@@ -1513,20 +1516,6 @@
 
 namespace {
 
-bool IsSymbolConstrained(const IndexingMap& map, int symbol_id) {
-  for (const auto& [expr, _] : map.GetConstraints()) {
-    bool result = false;
-    expr.walk([&](mlir::AffineExpr leaf) {
-      auto sym = mlir::dyn_cast<mlir::AffineSymbolExpr>(leaf);
-      if (sym && sym.getPosition() == symbol_id) {
-        result = true;
-      }
-    });
-    if (result) return true;
-  }
-  return false;
-}
-
 ValueRange EmitLoopNestImpl(
     ImplicitLocOpBuilder& b, ValueRange dim_values, ValueRange iter_args_inits,
     const IndexingMap& indexing_map,
@@ -1623,7 +1612,7 @@
        sym_index >= 0 && cumulative_loop_size < 64; --sym_index) {
     auto& bound = indexing_map.GetSymbolBound(sym_index);
     cumulative_loop_size *= bound.GetLoopTripCount();
-    if (!IsSymbolConstrained(indexing_map, sym_index)) continue;
+    if (!indexing_map.IsSymbolConstrained(sym_index)) continue;
 
     IndexingMap peeled_map = indexing_map;
     if (bound.upper == bound.lower) continue;
@@ -1632,7 +1621,7 @@
     peeled_map.Simplify();
 
     // If the symbol is still constrained, peeling does not help.
-    if (IsSymbolConstrained(peeled_map, sym_index)) continue;
+    if (peeled_map.IsSymbolConstrained(sym_index)) continue;
 
     auto first_results = EmitLoopNestImpl(b, dim_values, iter_args_inits,
                                           peeled_map, create_body, vectorize);
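The surrounding hunk (only partially visible here) peels an iteration off a constrained symbol's range so the main loop body can be emitted without the per-iteration constraint check; IndexingMap::IsSymbolConstrained now replaces the file-local walk over constraint expressions. A rough sketch of the peeling idea on plain loops, assuming the constraint can only fail in the peeled tail iteration (the exact tail handling is not visible in this hunk):

    #include <cstdint>

    // Peel the last iteration of `for (s = lb; s <= ub; ++s)` so the hot body
    // runs unchecked; only the tail evaluates the constraint. Body/Constraint
    // are stand-ins for the emitted MLIR, not real XLA types.
    template <typename Body, typename Constraint>
    void PeeledLoop(int64_t lb, int64_t ub, Body body, Constraint constraint) {
      for (int64_t s = lb; s < ub; ++s) {
        body(s);
      }
      if (constraint(ub)) {
        body(ub);
      }
    }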
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h
index 1f52109..1a97b57 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h
+++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h
@@ -94,7 +94,7 @@
                             mlir::ImplicitLocOpBuilder& b);
 
 // Creates an `apply_indexing` op for the given map.
-llvm::SmallVector<mlir::Value, 3> ApplyIndexing(const IndexingMap& map,
+llvm::SmallVector<mlir::Value, 3> ApplyIndexing(IndexingMap map,
                                                 mlir::ValueRange dims,
                                                 mlir::ValueRange symbols,
                                                 mlir::ImplicitLocOpBuilder& b);
@@ -148,6 +148,13 @@
                                               mlir::Block& src_block,
                                               mlir::ValueRange mapped_args);
 
+// Populates `lbs`, `ubs` and `steps` with the loop bounds from `indexing_map`.
+void GetLoopBoundsFromIndexingMap(mlir::ImplicitLocOpBuilder& b,
+                                  const IndexingMap& indexing_map,
+                                  llvm::SmallVectorImpl<mlir::Value>* lbs,
+                                  llvm::SmallVectorImpl<mlir::Value>* ubs,
+                                  llvm::SmallVectorImpl<mlir::Value>* steps);
+
 }  // namespace mlir_converter
 }  // namespace gpu
 }  // namespace xla
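The body of GetLoopBoundsFromIndexingMap is not part of this diff; the declaration above only exports it from the anonymous namespace. As a rough mental model (an assumption, not the actual implementation): one loop per symbol, bounds taken from that symbol's interval, unit step, with the inclusive upper bound turned into the exclusive bound scf.for expects.

    #include <cstdint>
    #include <vector>

    struct Interval { int64_t lower, upper; };  // Stand-in for the XLA type.

    // Sketch only: the real helper materializes these values as MLIR index
    // constants through the ImplicitLocOpBuilder.
    void LoopBoundsFromSymbolRanges(const std::vector<Interval>& symbol_bounds,
                                    std::vector<int64_t>& lbs,
                                    std::vector<int64_t>& ubs,
                                    std::vector<int64_t>& steps) {
      for (const Interval& bound : symbol_bounds) {
        lbs.push_back(bound.lower);
        ubs.push_back(bound.upper + 1);  // Assumed exclusive upper bound.
        steps.push_back(1);
      }
    }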
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc
index d7bbbb0..eab1568 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc
@@ -36,8 +36,8 @@
 #include "mlir/Transforms/Passes.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/launch_dimensions.h"
 #include "xla/service/gpu/model/indexing_analysis.h"
 #include "xla/service/gpu/model/indexing_map.h"
@@ -235,10 +235,10 @@
     // CHECK:      %[[INIT:.*]] = tensor.extract %[[ARG1]][]
     // CHECK:      %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C7]]
     // CHECK-SAME:   step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]])
-    // CHECK:      %[[J0:.*]] = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 4)>
-    // CHECK-SAME:     (%[[Y]] in [0, 2])
-    // CHECK:      %[[J1:.*]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0 - 3)>
-    // CHECK-SAME:     (%[[Z]] in [0, 7])[%[[I]] in [0, 6]]
+    // CHECK:      %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 * 4), domain: d0 in [0, 2]>(%[[Y]])
+    // CHECK:      %[[J1:.*]] = xla_gpu.apply_indexing
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 - 3),
+    // CHECK-SAME:              d0 in [0, 7], s0 in [0, 6]>(%[[Z]])[%[[I]]]
     // CHECK:          %[[VAL:.*]] = tensor.extract %[[ARG0]]
     // CHECK-SAME:        [%[[X]], %[[J0]], %[[J1]]]
     // CHECK:          %[[UPD:.*]] = func.call @add_sum(%[[ACC]],
@@ -285,8 +285,8 @@
     // If symbol rescaling wasn't working we would have a
     // `s0 floordiv <base_dilation>` in the map:
     // CHECK:      %[[K:.*]] = xla_gpu.apply_indexing
-    // CHECK-SAME:   affine_map<(d0)[s0] -> (d0 * 2 + s0)>
-    // CHECK-SAME:   (%[[X]] in [0, 18])[%[[I]] in [0, 3]]
+    // CHECK-SAME:   #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0),
+    // CHECK-SAME:   d0 in [0, 18], s0 in [0, 3]>(%[[X]])[%[[I]]]
 
     // CHECK:      tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]]
   )"));
@@ -333,6 +333,79 @@
   )"));
 }
 
+TEST_F(ElementalHloToMlirTest, ConcatenateMany) {
+  TF_EXPECT_OK(Run(R"(
+    ENTRY main {
+      p0 = f32[10,1,30] parameter(0)
+      p1 = f32[10,2,30] parameter(1)
+      p2 = f32[10,3,30] parameter(2)
+      p3 = f32[10,4,30] parameter(3)
+      p4 = f32[10,5,30] parameter(4)
+      p5 = f32[10,6,30] parameter(5)
+      p6 = f32[10,7,30] parameter(6)
+      ROOT r = f32[10,28,30] concatenate(p0, p1, p2, p3, p4, p5, p6),
+          dimensions={1}
+    })",
+                   R"(
+      // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
+      // CHECK-DAG:  %[[C3:.*]] = arith.constant 3 : index
+      // CHECK-DAG:  %[[C6:.*]] = arith.constant 6 : index
+      // CHECK-DAG:  %[[C10:.*]] = arith.constant 10 : index
+      // CHECK-DAG:  %[[C15:.*]] = arith.constant 15 : index
+      // CHECK-DAG:  %[[C21:.*]] = arith.constant 21 : index
+      // CHECK:      %[[P0TO2:.*]] = arith.cmpi ult, %[[I:.*]], %[[C6]]
+      // CHECK:      %[[CONCAT:.*]] = scf.if %[[P0TO2]] -> (f32)
+      // CHECK-NEXT:   %[[P0:.*]] = arith.cmpi ult, %[[I]], %[[C1]]
+      // CHECK-NEXT:   scf.if %[[P0]]
+      // CHECK-NEXT:     tensor.extract {{.*}}[{{.*}}, %[[I]], {{.*}}] : tensor<10x1x30xf32>
+      // CHECK-NEXT:     yield
+      // CHECK-NEXT:   } else {
+      // CHECK-NEXT:     %[[P1:.*]] = arith.cmpi ult, %[[I]], %[[C3]]
+      // CHECK-NEXT:     scf.if %[[P1]]
+      // CHECK-NEXT:       %[[O:.*]] = arith.subi %[[I]], %[[C1]]
+      // CHECK-NEXT:       tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x2x30xf32>
+      // CHECK-NEXT:       yield
+      // CHECK-NEXT:     } else {
+      // CHECK-NEXT:       %[[O:.*]] = arith.subi %[[I]], %[[C3]]
+      // CHECK-NEXT:       tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x3x30xf32>
+      // CHECK-NEXT:       yield
+      // CHECK-NEXT:     }
+      // CHECK-NEXT:     yield
+      // CHECK-NEXT:   }
+      // CHECK-NEXT:   yield
+      // CHECK-NEXT: } else {
+      // CHECK-NEXT:   %[[P3TO4:.*]] = arith.cmpi ult, %[[I]], %[[C15]]
+      // CHECK-NEXT:   scf.if %[[P3TO4]]
+      // CHECK-NEXT:     %[[P3:.*]] = arith.cmpi ult, %[[I]], %[[C10]]
+      // CHECK-NEXT:     scf.if %[[P3]]
+      // CHECK-NEXT:       %[[O:.*]] = arith.subi %[[I]], %[[C6]]
+      // CHECK-NEXT:       tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x4x30xf32>
+      // CHECK-NEXT:       yield
+      // CHECK-NEXT:     } else {
+      // CHECK-NEXT:       %[[O:.*]] = arith.subi %[[I]], %[[C10]]
+      // CHECK-NEXT:       tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x5x30xf32>
+      // CHECK-NEXT:       yield
+      // CHECK-NEXT:     }
+      // CHECK-NEXT:     yield
+      // CHECK-NEXT:   } else {
+      // CHECK-NEXT:     %[[P5:.*]] = arith.cmpi ult, %[[I]], %[[C21]]
+      // CHECK-NEXT:     scf.if %[[P5]]
+      // CHECK-NEXT:       %[[O:.*]] = arith.subi %[[I]], %[[C15]]
+      // CHECK-NEXT:       tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x6x30xf32>
+      // CHECK-NEXT:       yield
+      // CHECK-NEXT:     } else {
+      // CHECK-NEXT:       %[[O:.*]] = arith.subi %[[I]], %[[C21]]
+      // CHECK-NEXT:       tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x7x30xf32>
+      // CHECK-NEXT:       yield
+      // CHECK-NEXT:     }
+      // CHECK-NEXT:     yield
+      // CHECK-NEXT:   }
+      // CHECK-NEXT:   yield
+      // CHECK-NEXT: }
+      // CHECK-NEXT: return %[[CONCAT]]
+  )"));
+}
+
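The constants in the CHECK lines above are the prefix sums of the operand sizes 1..7 along dimension 1, i.e. offsets {0, 1, 3, 6, 10, 15, 21, 28}. A small sketch reproducing them and the order in which the midpoint splits consume them (illustration only):

    #include <cstdint>
    #include <vector>

    std::vector<int64_t> ConcatOffsets() {
      std::vector<int64_t> offsets{0};
      for (int64_t size = 1; size <= 7; ++size) {
        offsets.push_back(offsets.back() + size);  // {0, 1, 3, 6, 10, 15, 21, 28}
      }
      return offsets;
    }
    // Splitting [0, 7) at mid = 3 compares against offsets[3] == 6 (%C6);
    // the left half then uses offsets[1] == 1 and offsets[2] == 3, and the
    // right half uses offsets[5] == 15, offsets[4] == 10, offsets[6] == 21,
    // matching %C1, %C3, %C15, %C10 and %C21 above.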
 TEST_F(ElementalHloToMlirTest, ConcatenateUnsigned) {
   TF_EXPECT_OK(Run(R"(
     ENTRY main {
@@ -433,7 +506,7 @@
     // CHECK-DAG:    %[[C4:.*]] = arith.constant 4
     // CHECK-DAG:    %[[C7:.*]] = arith.constant 7
     // CHECK:        %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 7])
+    // CHECK-SAME:     <(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]>(%[[X]])
     // CHECK:        %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]]
     // CHECK-DAG:        %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]]
     // CHECK-DAG:        %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]]
@@ -445,11 +518,9 @@
     // CHECK:        %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]]
     // CHECK:        %[[RET:.*]] = scf.if %[[FROM_INPUT]]
     // CHECK:          %[[IN0:.*]] = xla_gpu.apply_indexing
-    // CHECK-SAME:         <(d0) -> ((d0 - 1) floordiv 2)>
-    // CHECK-SAME:         (%[[X]] in [1, 7])
+    // CHECK-SAME:         <(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]>(%[[X]])
     // CHECK:          %[[IN1:.*]] = xla_gpu.apply_indexing
-    // CHECK-SAME:         <(d0) -> (d0 - 4)>
-    // CHECK-SAME:         (%[[Y]] in [4, 7])
+    // CHECK-SAME:         <(d0) -> (d0 - 4), domain: d0 in [4, 7]>(%[[Y]])
     // CHECK:          %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]]
     // CHECK:          scf.yield %[[VAL]]
     // CHECK:        } else {
@@ -477,7 +548,7 @@
     // CHECK-DAG:    %[[C4:.*]] = arith.constant 4
     // CHECK-DAG:    %[[C7:.*]] = arith.constant 7
     // CHECK:        %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 7])
+    // CHECK-SAME:     <(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]>(%[[X]])
     // CHECK:        %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]]
     // CHECK-DAG:        %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]]
     // CHECK-DAG:        %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]]
@@ -489,11 +560,9 @@
     // CHECK:        %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]]
     // CHECK:        %[[RET:.*]] = scf.if %[[FROM_INPUT]]
     // CHECK:          %[[IN0:.*]] = xla_gpu.apply_indexing
-    // CHECK-SAME:         <(d0) -> ((d0 - 1) floordiv 2)>
-    // CHECK-SAME:         (%[[X]] in [1, 7])
+    // CHECK-SAME:         <(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]>(%[[X]])
     // CHECK:          %[[IN1:.*]] = xla_gpu.apply_indexing
-    // CHECK-SAME:         <(d0) -> (d0 - 4)>
-    // CHECK-SAME:         (%[[Y]] in [4, 7])
+    // CHECK-SAME:         <(d0) -> (d0 - 4), domain: d0 in [4, 7]>(%[[Y]])
     // CHECK:          %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]]
     // CHECK:          scf.yield %[[VAL]]
     // CHECK:        } else {
@@ -810,11 +879,11 @@
     // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) {
     // CHECK:      %[[R3:.+]] = scf.if {{.+}} -> (f32) {
     // CHECK:      %[[XX0:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:   affine_map<(d0)[s0] -> (d0 + s0)>
-    // CHECK-SAME:   (%[[W]] in [0, 5])[%[[X]] in [0, 2]]
+    // CHECK-SAME:   #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+    // CHECK-SAME:   d0 in [0, 5], s0 in [0, 2]>(%[[W]])[%[[X]]]
     // CHECK:      %[[XX1:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:   affine_map<(d0)[s0] -> (d0 + s0)>
-    // CHECK-SAME:   (%[[H]] in [0, 7])[%[[Y]] in [0, 4]]
+    // CHECK-SAME:   #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+    // CHECK-SAME:   d0 in [0, 7], s0 in [0, 4]>(%[[H]])[%[[Y]]]
     // CHECK-DAG:    %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32>
     // CHECK-DAG:    %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32>
     // CHECK:        %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -856,11 +925,11 @@
     // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) {
     // CHECK:      %[[R3:.+]] = scf.if {{.+}} -> (f32) {
     // CHECK:        %[[XX0:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> (d0 * 2 + s0)>
-    // CHECK-SAME:     (%[[W]] in [0, 2])[%[[X]] in [0, 2]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0),
+    // CHECK-SAME:     d0 in [0, 2], s0 in [0, 2]>(%[[W]])[%[[X]]]
     // CHECK:        %[[XX1:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> (d0 * 2 + s0)>
-    // CHECK-SAME:     (%[[H]] in [0, 3])[%[[Y]] in [0, 4]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0),
+    // CHECK-SAME:     d0 in [0, 3], s0 in [0, 4]>(%[[H]])[%[[Y]]]
     // CHECK-DAG:    %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32>
     // CHECK-DAG:    %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32>
     // CHECK:        %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -903,21 +972,21 @@
     // CHECK:      %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) {
     // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) {
     // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) {
-    // CHECK-DAG:  %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[W]] in [0, 7])[%[[X]] in [0, 2]]
+    // CHECK-DAG:  %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), domain: d0 in [0, 7], s0 in [0, 2]>(%[[W]])[%[[X]]]
     // CHECK-DAG:  %[[TXGE:.+]] = arith.cmpi sge, %[[TESTX]], %[[C1]] : index
     // CHECK-DAG:  %[[TXLE:.+]] = arith.cmpi sle, %[[TESTX]], %[[C8]] : index
     // CHECK-DAG:  %[[TX:.+]] = arith.andi %[[TXGE]], %[[TXLE]] : i1
-    // CHECK-DAG:  %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[H]] in [0, 11])[%[[Y]] in [0, 4]]
+    // CHECK-DAG:  %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), domain: d0 in [0, 11], s0 in [0, 4]>(%[[H]])[%[[Y]]]
     // CHECK-DAG:  %[[TYGE:.+]] = arith.cmpi sge, %[[TESTY]], %[[C2]] : index
     // CHECK-DAG:  %[[TYLE:.+]] = arith.cmpi sle, %[[TESTY]], %[[C13]] : index
     // CHECK-DAG:  %[[TY:.+]] = arith.andi %[[TYGE]], %[[TYLE]] : i1
     // CHECK:      %[[R3:.+]] = scf.if {{.+}} -> (f32) {
     // CHECK:        %[[XX0:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> (d0 + s0 - 1)>
-    // CHECK-SAME:     (%[[W]] in [0, 7])[%[[X]] in [0, 2]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 - 1),
+    // CHECK-SAME:     d0 in [0, 7], s0 in [0, 2]>(%[[W]])[%[[X]]]
     // CHECK:        %[[XX1:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> (d0 + s0 - 2)>
-    // CHECK-SAME:     (%[[H]] in [0, 11])[%[[Y]] in [0, 4]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 - 2),
+    // CHECK-SAME:     d0 in [0, 11], s0 in [0, 4]>(%[[H]])[%[[Y]]]
     // CHECK-DAG:    %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32>
     // CHECK-DAG:    %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32>
     // CHECK:        %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -957,17 +1026,17 @@
     // CHECK:      %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) {
     // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) {
     // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) {
-    // CHECK-DAG:  %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[W]] in [0, 12])[%[[X]] in [0, 2]]
+    // CHECK-DAG:  %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) mod 2), domain: d0 in [0, 12], s0 in [0, 2]>(%[[W]])[%[[X]]]
     // CHECK-DAG:  %[[TX:.+]] = arith.cmpi eq, %[[TESTX]], %[[C0]] : index
-    // CHECK-DAG:  %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[H]] in [0, 18])[%[[Y]] in [0, 4]]
+    // CHECK-DAG:  %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) mod 2), domain: d0 in [0, 18], s0 in [0, 4]>(%[[H]])[%[[Y]]]
     // CHECK-DAG:  %[[TY:.+]] = arith.cmpi eq, %[[TESTY]], %[[C0]] : index
     // CHECK:      %[[R3:.+]] = scf.if {{.+}} -> (f32) {
     // CHECK:        %[[XX0:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> ((d0 + s0) floordiv 2)>
-    // CHECK-SAME:     (%[[W]] in [0, 12])[%[[X]] in [0, 2]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) floordiv 2),
+    // CHECK-SAME:     d0 in [0, 12], s0 in [0, 2]>(%[[W]])[%[[X]]]
     // CHECK:        %[[XX1:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> ((d0 + s0) floordiv 2)>
-    // CHECK-SAME:     (%[[H]] in [0, 18])[%[[Y]] in [0, 4]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) floordiv 2),
+    // CHECK-SAME:     d0 in [0, 18], s0 in [0, 4]>(%[[H]])[%[[Y]]]
     // CHECK-DAG:    %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32>
     // CHECK-DAG:    %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32>
     // CHECK:        %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -1009,11 +1078,11 @@
     // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) {
     // CHECK:      %[[R3:.+]] = scf.if {{.+}} -> (f32) {
     // CHECK:        %[[XX0:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> (d0 + s0 * 2)>
-    // CHECK-SAME:     (%[[W]] in [0, 3])[%[[X]] in [0, 2]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2),
+    // CHECK-SAME:     d0 in [0, 3], s0 in [0, 2]>(%[[W]])[%[[X]]]
     // CHECK:        %[[XX1:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> (d0 + s0 * 2)>
-    // CHECK-SAME:     (%[[H]] in [0, 3])[%[[Y]] in [0, 4]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2),
+    // CHECK-SAME:     d0 in [0, 3], s0 in [0, 4]>(%[[H]])[%[[Y]]]
     // CHECK-DAG:    %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32>
     // CHECK-DAG:    %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32>
     // CHECK:        %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -1055,17 +1124,14 @@
     // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) {
     // CHECK:      %[[R3:.+]] = scf.if {{.+}} -> (f32) {
     // CHECK:        %[[XX0:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> (d0 + s0)>
-    // CHECK-SAME:     (%[[W]] in [0, 5])
-    // CHECK-SAME:     [%[[X]] in [0, 2]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+    // CHECK-SAME:     d0 in [0, 5], s0 in [0, 2]>(%[[W]])[%[[X]]]
     // CHECK:        %[[XX1:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> (d0 + s0)>
-    // CHECK-SAME:     (%[[H]] in [0, 7])
-    // CHECK-SAME:     [%[[Y]] in [0, 4]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+    // CHECK-SAME:     d0 in [0, 7], s0 in [0, 4]>(%[[H]])[%[[Y]]]
     // CHECK:        %[[XX2:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> ((d0 floordiv 8) * 2 + s0)>
-    // CHECK-SAME:     (%[[O]] in [0, 15])
-    // CHECK-SAME:     [%[[I]] in [0, 1]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> ((d0 floordiv 8) * 2 + s0),
+    // CHECK-SAME:     d0 in [0, 15], s0 in [0, 1]>(%[[O]])[%[[I]]]
     // CHECK-DAG:    %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[XX2]]] : tensor<2x8x12x4xf32>
     // CHECK-DAG:    %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<2x3x5x16xf32>
     // CHECK:        %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -1109,13 +1175,11 @@
     // CHECK-NEXT: %[[R3:.+]] = scf.for %[[G:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A2]]) -> (f32) {
     // CHECK:      %[[R4:.+]] = scf.if {{.+}} -> (f32) {
     // CHECK:        %[[XX0:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> (d0 + s0)>
-    // CHECK-SAME:     (%[[W]] in [0, 5])
-    // CHECK-SAME:     [%[[X]] in [0, 2]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+    // CHECK-SAME:     d0 in [0, 5], s0 in [0, 2]>(%[[W]])[%[[X]]]
     // CHECK:        %[[XX1:.+]] = xla_gpu.apply_indexing
-    // CHECK-SAME:     affine_map<(d0)[s0] -> (d0 + s0)>
-    // CHECK-SAME:     (%[[H]] in [0, 7])
-    // CHECK-SAME:     [%[[Y]] in [0, 4]]
+    // CHECK-SAME:     #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0),
+    // CHECK-SAME:     d0 in [0, 7], s0 in [0, 4]>(%[[H]])[%[[Y]]]
     // CHECK-DAG:    %[[VL:.+]] = tensor.extract %[[LHS]][%[[G]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32>
     // CHECK-DAG:    %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32>
     // CHECK:        %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32
@@ -1581,8 +1645,8 @@
     // CHECK-SAME:     %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}
     // CHECK:        %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]]
     // CHECK:        %[[IDX:.*]] = xla_gpu.apply_indexing
-    // CHECK-SAME:       affine_map<(d0, d1) -> (d0 * 10 + d1)>
-    // CHECK-SAME:       (%[[X]] in [0, 9], %[[Y]] in [0, 9])
+    // CHECK-SAME:       #xla_gpu.indexing_map<(d0, d1) -> (d0 * 10 + d1),
+    // CHECK-SAME:       d0 in [0, 9], d1 in [0, 9]>(%[[X]], %[[Y]])
     // CHECK:        %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]]
     // CHECK:        return %[[A]], %[[B]]
   )"));
@@ -1605,8 +1669,8 @@
     // CHECK-SAME:     %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}
     // CHECK:          %[[P0_V:.*]] = xla_gpu.pure_call @main_p0
     // CHECK:          %[[IDX:.*]] =
-    // CHECK-SAME:       affine_map<(d0, d1) -> (d0 * 10 + d1)>
-    // CHECK-SAME:       (%[[X]] in [0, 9], %[[Y]] in [0, 9])
+    // CHECK-SAME:       #xla_gpu.indexing_map<(d0, d1) -> (d0 * 10 + d1),
+    // CHECK-SAME:       d0 in [0, 9], d1 in [0, 9]>(%[[X]], %[[Y]])
     // CHECK:          %[[P1_V:.*]] = xla_gpu.pure_call @main_p1
     // CHECK-SAME:       (%[[P0]], %[[P1]], %[[IDX]])
     // CHECK:          return %[[P0_V]], %[[P1_V]], %[[P1_V]], %[[P1_V]], %[[P0_V]]
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/erase_dead_functions.cc b/third_party/xla/xla/service/gpu/fusions/mlir/erase_dead_functions.cc
deleted file mode 100644
index 012201a..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/erase_dead_functions.cc
+++ /dev/null
@@ -1,86 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <queue>
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Interfaces/CallInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_ERASEDEADFUNCTIONSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-struct CallInfo {
-  PureCallOp call;
-  int count;
-};
-
-llvm::DenseSet<mlir::func::FuncOp> FindLiveFunctions(mlir::ModuleOp module) {
-  std::queue<mlir::func::FuncOp> worklist;
-  llvm::DenseSet<mlir::func::FuncOp> live_funcs;
-  module.walk([&](mlir::func::FuncOp func) {
-    if (!func.isPrivate()) {
-      worklist.push(func);
-      live_funcs.insert(func);
-    }
-  });
-
-  mlir::SymbolTableCollection symbol_table;
-  while (!worklist.empty()) {
-    auto func = worklist.front();
-    worklist.pop();
-    func.walk([&](mlir::CallOpInterface call) {
-      auto callee =
-          mlir::cast<mlir::func::FuncOp>(call.resolveCallable(&symbol_table));
-      if (live_funcs.insert(callee).second) {
-        worklist.push(callee);
-      }
-    });
-  }
-  return live_funcs;
-}
-
-class EraseDeadFunctionsPass
-    : public impl::EraseDeadFunctionsPassBase<EraseDeadFunctionsPass> {
- public:
-  void runOnOperation() override {
-    // Find live functions and erase dead ones.
-    auto live = FindLiveFunctions(getOperation());
-    getOperation().walk([&](mlir::func::FuncOp func) {
-      if (!live.contains(func)) {
-        func.erase();
-      }
-    });
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-CreateEraseDeadFunctionsPass() {
-  return std::make_unique<EraseDeadFunctionsPass>();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc b/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc
deleted file mode 100644
index 001df2c..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc
+++ /dev/null
@@ -1,675 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <algorithm>
-#include <array>
-#include <cassert>
-#include <cstdint>
-#include <memory>
-#include <utility>
-
-#include "absl/log/check.h"
-#include "llvm/ADT/APFloat.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Types.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
-#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-
-namespace ma = ::mlir::arith;
-
-using ma::SelectOp;
-using mlir::Value;
-
-#define GEN_PASS_DEF_EXPANDFLOATOPSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-// Wraps a Value to provide operator overloading for more readable expressions.
-struct Val {
-  Value value;
-  mlir::ImplicitLocOpBuilder* b;
-
-  operator Value() const { return value; }  // NOLINT
-
-  Val operator+(int64_t rhs) const { return Binop<ma::AddIOp>(rhs); }
-  Val operator+(Value rhs) const { return Binop<ma::AddIOp>(rhs); }
-  Val operator-(int64_t rhs) const { return Binop<ma::SubIOp>(rhs); }
-  Val operator-(Value rhs) const { return Binop<ma::SubIOp>(rhs); }
-  Val operator*(int64_t rhs) const { return Binop<ma::MulIOp>(rhs); }
-  Val operator*(Value rhs) const { return Binop<ma::MulIOp>(rhs); }
-  Val operator&(Value rhs) const { return Binop<ma::AndIOp>(rhs); }
-  Val operator&(int64_t rhs) const { return Binop<ma::AndIOp>(rhs); }
-  Val operator|(Value rhs) const { return Binop<ma::OrIOp>(rhs); }
-  Val operator|(int64_t rhs) const { return Binop<ma::OrIOp>(rhs); }
-  Val operator^(Value rhs) const { return Binop<ma::XOrIOp>(rhs); }
-  Val shl(Value rhs) const { return Binop<ma::ShLIOp>(rhs); }
-  Val shl(int64_t rhs) const { return Binop<ma::ShLIOp>(rhs); }
-  Val shrui(Value rhs) const { return Binop<ma::ShRUIOp>(rhs); }
-  Val shrui(int64_t rhs) const { return Binop<ma::ShRUIOp>(rhs); }
-
-  Val cmp(ma::CmpIPredicate pred, Value rhs) const {
-    return {b->create<ma::CmpIOp>(pred, value, rhs), b};
-  }
-  Val cmp(ma::CmpIPredicate pred, int64_t rhs) const {
-    return cmp(pred, MakeConstant(rhs));
-  }
-  Val operator==(Value rhs) const { return cmp(ma::CmpIPredicate::eq, rhs); }
-  Val operator==(int64_t rhs) const { return cmp(ma::CmpIPredicate::eq, rhs); }
-  Val operator!=(int64_t rhs) const { return cmp(ma::CmpIPredicate::ne, rhs); }
-
-  Val MakeConstant(int64_t c) const {
-    return {b->create<ma::ConstantIntOp>(c, value.getType()), b};
-  }
-
- private:
-  template <typename Op>
-  Val Binop(Value rhs) const {
-    return {b->create<Op>(value, rhs), b};
-  }
-
-  template <typename Op>
-  Val Binop(int64_t rhs) const {
-    return Binop<Op>(MakeConstant(rhs));
-  }
-};
-
-template <typename OpTy, ma::CmpFPredicate pred>
-struct RewriteToCmpSelect : public mlir::OpRewritePattern<OpTy> {
-  using mlir::OpRewritePattern<OpTy>::OpRewritePattern;
-
-  RewriteToCmpSelect(mlir::MLIRContext* context, bool include_f32)
-      : mlir::OpRewritePattern<OpTy>(context), include_f32(include_f32) {}
-
-  mlir::LogicalResult matchAndRewrite(
-      OpTy op, mlir::PatternRewriter& rewriter) const override {
-    if (op.getType().isF32() && !include_f32) {
-      return rewriter.notifyMatchFailure(op, "not rewriting f32 min/max");
-    }
-
-    auto lhs_is_nan = rewriter.create<ma::CmpFOp>(
-        op.getLoc(), ma::CmpFPredicate::UNE, op.getLhs(), op.getLhs());
-    auto rhs_is_not_nan = rewriter.create<ma::CmpFOp>(
-        op.getLoc(), ma::CmpFPredicate::OEQ, op.getRhs(), op.getRhs());
-
-    auto return_lhs =
-        rewriter.create<ma::CmpFOp>(op.getLoc(), pred, op.getLhs(), op.getRhs())
-            .getResult();
-
-    // logic: isNaN(lhs) || (!isNan(rhs) && return_lhs) ? lhs : rhs
-    return_lhs = rewriter.create<ma::OrIOp>(
-        op.getLoc(), lhs_is_nan,
-        rewriter.create<ma::AndIOp>(op.getLoc(), rhs_is_not_nan, return_lhs));
-
-    rewriter.replaceOpWithNewOp<SelectOp>(op, op.getResult().getType(),
-                                          return_lhs, op.getLhs(), op.getRhs());
-    return mlir::success();
-  }
-
-  bool include_f32;
-};
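The deleted RewriteToCmpSelect pattern lowers arith min/max so that a NaN on either side propagates: select(isnan(lhs) || (!isnan(rhs) && pred(lhs, rhs)), lhs, rhs). A scalar model for the maximum case, assuming the pattern is instantiated with an ordered greater-than predicate (an assumption; the instantiation site is not in this hunk):

    #include <cassert>
    #include <cmath>

    // Scalar model of the select-based maximum lowering.
    float MaximumF(float lhs, float rhs) {
      bool return_lhs = std::isnan(lhs) || (!std::isnan(rhs) && lhs > rhs);
      return return_lhs ? lhs : rhs;
    }

    int main() {
      float nan = std::nanf("");
      assert(MaximumF(1.0f, 2.0f) == 2.0f);
      assert(std::isnan(MaximumF(nan, 2.0f)));  // lhs NaN: lhs is returned.
      assert(std::isnan(MaximumF(1.0f, nan)));  // rhs NaN: falls through to rhs.
      return 0;
    }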
-
-struct RewriteErf32Pattern : public mlir::OpRewritePattern<mlir::math::ErfOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      mlir::math::ErfOp op, mlir::PatternRewriter& rewriter) const override {
-    if (!op.getType().isF32()) {
-      return rewriter.notifyMatchFailure(op, "not an f32 erf");
-    }
-
-    static const std::array<float, 5> kAlpha{
-        0.00022905065861350646f, 0.0034082910107109506f, 0.050955695062380861f,
-        0.18520832239976145f, 1.128379143519084f};
-
-    static const std::array<float, 7> kBeta{-1.1791602954361697e-7,
-                                            0.000023547966471313185f,
-                                            0.0010179625278914885f,
-                                            0.014070470171167667f,
-                                            0.11098505178285362f,
-                                            0.49746925110067538f,
-                                            1.0f};
-
-    // We clamp x to be within [-c;c] where c = erfinv(1-2^-23), outside of
-    // which x should be +/-1.
-    constexpr float kErfInvOneMinusHalfULP = 3.7439211627767994f;
-
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto c = [&](float v) -> Value {
-      return b.create<ma::ConstantFloatOp>(llvm::APFloat(v),
-                                           rewriter.getF32Type());
-    };
-
-    auto poly = [&](auto x, auto coefficients) -> Value {
-      auto r = c(coefficients[0]);
-      for (int i = 1; i < coefficients.size(); ++i) {
-        r = b.create<mlir::math::FmaOp>(r, x, c(coefficients[i]));
-      }
-      return r;
-    };
-
-    Value x = op.getOperand();
-    x = b.create<ma::MaximumFOp>(x, c(-kErfInvOneMinusHalfULP));
-    x = b.create<ma::MinimumFOp>(x, c(kErfInvOneMinusHalfULP));
-    Value x2 = b.create<ma::MulFOp>(x, x);
-
-    rewriter.replaceOpWithNewOp<ma::DivFOp>(
-        op, b.create<ma::MulFOp>(x, poly(x2, kAlpha)), poly(x2, kBeta));
-
-    return mlir::success();
-  }
-};
-
-int GetSignificandBits(mlir::FloatType ty) {
-  return llvm::APFloat::semanticsPrecision(ty.getFloatSemantics()) - 1;
-}
-
-int GetExponentBias(mlir::FloatType ty) {
-  return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics());
-}
-
-Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) {
-  auto ty = mlir::cast<mlir::FloatType>(value.getType());
-  if (mlir::LLVM::isCompatibleOuterType(ty)) {
-    value = b.create<mlir::math::AbsFOp>(value);
-    Value inf = b.create<ma::ConstantFloatOp>(
-        llvm::APFloat::getInf(ty.getFloatSemantics()), ty);
-    return b.create<ma::CmpFOp>(ma::CmpFPredicate::OEQ, value, inf);
-  }
-
-  assert(ty.getIntOrFloatBitWidth() == 8);
-  if (!ty.isFloat8E5M2()) {
-    // F8E5M2 is the only 8 bit float with infinities.
-    return b.create<ma::ConstantIntOp>(false, b.getI1Type());
-  }
-  Val bits{b.create<ma::BitcastOp>(b.getI8Type(), value), &b};
-  return (bits & 0x7F) == 0x7C;
-}
-
-Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) {
-  auto ty = value.getType();
-  if (mlir::LLVM::isCompatibleOuterType(ty)) {
-    return b.create<ma::CmpFOp>(ma::CmpFPredicate::UNO, value, value);
-  }
-
-  assert(ty.getIntOrFloatBitWidth() == 8);
-  Val bits{b.create<ma::BitcastOp>(b.getI8Type(), value), &b};
-  if (ty.isFloat8E5M2() || ty.isFloat8E4M3FN()) {
-    return (bits & 0x7F) == 0x7F;
-  }
-  return bits == 0x80;
-}
-
-Value EmitReducePrecision(Value value, int exponent_bits, int mantissa_bits,
-                          mlir::ImplicitLocOpBuilder& b) {
-  mlir::mhlo::ReducePrecisionOp::Properties properties;
-  properties.exponent_bits = b.getI32IntegerAttr(exponent_bits);
-  properties.mantissa_bits = b.getI32IntegerAttr(mantissa_bits);
-  return mlir::mhlo::MhloOpToStdScalarOp::mapOpOfType<
-      mlir::mhlo::ReducePrecisionOp>(
-      b.getLoc(), value.getType(), {value.getType()},
-      mlir::mhlo::ReducePrecisionOp::Adaptor(value, nullptr, properties), &b);
-}
-
-Value EmitF16ToF8e5m2(Value in, mlir::ImplicitLocOpBuilder& b) {
-  Val in_bits{b.create<ma::BitcastOp>(b.getI16Type(), in), &b};
-  // Use this method of checking for NaN because it's the same as what's used
-  // in the reduce precision lowering.
-  Value is_nan = (in_bits & 32767).cmp(ma::CmpIPredicate::ugt, 31744);
-
-  Value value = EmitReducePrecision(in, 5, 2, b);
-  value = b.create<ma::BitcastOp>(b.getI16Type(), value);
-  value = b.create<ma::ShRUIOp>(value,
-                                b.create<ma::ConstantIntOp>(8, b.getI16Type()));
-  value = b.create<ma::TruncIOp>(b.getI8Type(), value);
-  // When the input is NaN, just truncating can turn a NaN into an inf if the
-  // mantissa becomes 0.
-  value = b.create<ma::SelectOp>(
-      is_nan, b.create<ma::ConstantIntOp>(0x7F, value.getType()), value);
-  return b.create<ma::BitcastOp>(b.getFloat8E5M2Type(), value);
-}
-
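The select at the end of the deleted EmitF16ToF8e5m2 exists because of the hazard its comment describes: once the mantissa bits are dropped, a NaN payload can collapse to the inf bit pattern. A bit-level walk-through on plain integers (illustrative; not the MLIR lowering):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint16_t in_bits = 0x7C01;  // f16 NaN: exponent all ones, mantissa 1.
      bool is_nan = (in_bits & 0x7FFF) > 0x7C00;  // Same check as the pass.
      assert(is_nan);

      // If reducing to 2 mantissa bits rounds the payload to zero, the value
      // becomes 0x7C00 (f16 +inf); its top byte 0x7C is f8e5m2 +inf.
      uint8_t truncated = static_cast<uint8_t>(0x7C00 >> 8);
      assert(truncated == 0x7C);

      // Hence the select: NaN inputs are forced to 0x7F, an f8e5m2 NaN.
      uint8_t result = is_nan ? 0x7F : truncated;
      assert(result == 0x7F);
      return 0;
    }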
-Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
-                          mlir::ImplicitLocOpBuilder& b) {
-  using ma::CmpIPredicate;
-
-  // This is a port of ConvertImpl in
-  // https://github.com/jax-ml/ml_dtypes/blob/main/ml_dtypes/include/float8.h
-  auto from_ty = mlir::cast<mlir::FloatType>(value.getType());
-  if (to_ty == b.getFloat8E5M2Type() && from_ty == b.getF16Type()) {
-    return EmitF16ToF8e5m2(value, b);
-  }
-
-  int from_mantissa = GetSignificandBits(from_ty);
-  int from_bias = GetExponentBias(from_ty);
-  int from_min_exp =
-      llvm::APFloat::semanticsMinExponent(from_ty.getFloatSemantics());
-  int from_max_exp =
-      llvm::APFloat::semanticsMaxExponent(from_ty.getFloatSemantics());
-  auto from_int_ty = b.getIntegerType(from_ty.getIntOrFloatBitWidth());
-
-  int to_mantissa = GetSignificandBits(to_ty);
-  int to_bias = GetExponentBias(to_ty);
-  int to_min_exp =
-      llvm::APFloat::semanticsMinExponent(to_ty.getFloatSemantics());
-  int to_max_exp =
-      llvm::APFloat::semanticsMaxExponent(to_ty.getFloatSemantics());
-  auto to_int_ty = b.getIntegerType(to_ty.getIntOrFloatBitWidth());
-
-  mlir::IntegerType wide_int_ty;
-  if (from_ty.getWidth() == 8 && to_ty.getWidth() == 8) {
-    wide_int_ty = b.getI16Type();
-  } else {
-    wide_int_ty = b.getIntegerType(
-        std::max(from_int_ty.getWidth(), to_int_ty.getWidth()));
-  }
-  auto convert_int = [&](mlir::Type ty, Value v) -> Val {
-    if (v.getType() == ty) {
-      return {v, &b};
-    }
-    if (ty.getIntOrFloatBitWidth() < v.getType().getIntOrFloatBitWidth()) {
-      return {b.create<ma::TruncIOp>(ty, v), &b};
-    }
-    return {b.create<ma::ExtUIOp>(ty, v), &b};
-  };
-
-  int64_t exp_offset = to_bias - from_bias;
-  int digit_shift = to_mantissa - from_mantissa;
-
-  Val from_bits{
-      b.create<ma::BitcastOp>(
-          b.getIntegerType(value.getType().getIntOrFloatBitWidth()), value),
-      &b};
-
-  auto cst = [&](mlir::Type ty, int64_t n) -> Val {
-    return {b.create<ma::ConstantIntOp>(n, ty), &b};
-  };
-
-  // Shift bits to destination type, without sign bit.
-  Val from_sign_bit =
-      from_bits.shrui(value.getType().getIntOrFloatBitWidth() - 1) != 0;
-
-  from_bits =
-      from_bits & ((1ULL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1);
-
-  Value result_is_inf = IsInf(value, b);
-  Value input_is_nan = IsNaN(value, b);
-
-  auto cst_bits = [&](llvm::APFloat f) {
-    return cst(b.getIntegerType(llvm::APFloat::getSizeInBits(f.getSemantics())),
-               f.bitcastToAPInt().getZExtValue());
-  };
-  Value to_inf = cst_bits(llvm::APFloat::getInf(to_ty.getFloatSemantics()));
-  Value to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics()));
-  Val to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics()));
-
-  auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) {
-    assert(bits.value.getType() == roundoff.value.getType());
-    // Round to nearest even by adding a bias term.
-    // Consider a bit pattern
-    //   FFF...FLRTT...T,
-    // where bits RTT...T need to be rounded-off.  We add a bias term to the
-    // bit pattern s.t. a carry is introduced to round up only if
-    // - L is 1, R is 1, OR
-    // - L is 0, R is 1, any T is one.
-    // We do this by adding L to a bit pattern consisting of all T = 1.
-    Val rounded = (bits.shrui(roundoff) & 1) +
-                  (bits.MakeConstant(1).shl(roundoff - 1) - 1);
-    Val bias{b.create<SelectOp>(roundoff == 0, roundoff, rounded), &b};
-    return bits + bias;
-  };
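The bias term in the deleted round_bits_to_nearest_even lambda can be checked on plain integers: adding ((bits >> roundoff) & 1) + (2^(roundoff-1) - 1) before discarding the low bits gives round-to-nearest with ties to even. A standalone sketch (the mask that the callers apply afterwards is folded into the helper here):

    #include <cassert>
    #include <cstdint>

    // Round `bits` to the nearest multiple of 2^roundoff, ties to even.
    uint32_t RoundBitsToNearestEven(uint32_t bits, uint32_t roundoff) {
      if (roundoff == 0) return bits;  // Mirrors the select on roundoff == 0.
      uint32_t bias = ((bits >> roundoff) & 1) + ((1u << (roundoff - 1)) - 1);
      return (bits + bias) & ~((1u << roundoff) - 1);
    }

    int main() {
      assert(RoundBitsToNearestEven(6, 2) == 8);  // Tie: rounds to even multiple 8.
      assert(RoundBitsToNearestEven(2, 2) == 0);  // Tie: rounds to even multiple 0.
      assert(RoundBitsToNearestEven(7, 2) == 8);  // Above halfway: rounds up.
      return 0;
    }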
-
-  // Happy path: no subnormals, infinities or NaNs.
-  Value result;
-  {
-    // Round the mantissa if it is shrinking.
-    Val rounded_from_bits = convert_int(wide_int_ty, from_bits);
-    if (digit_shift < 0) {
-      rounded_from_bits = round_bits_to_nearest_even(
-                              from_bits, from_bits.MakeConstant(-digit_shift)) &
-                          ~((1ll << (-digit_shift)) - 1);
-    }
-
-    // Re-bias the exponent.
-    rounded_from_bits = rounded_from_bits + (exp_offset << from_mantissa);
-
-    // Check for overflows by aligning the significands. We always align the
-    // narrower significand to the wider significand.
-    int64_t to_highest = llvm::APFloat::getLargest(to_ty.getFloatSemantics())
-                             .bitcastToAPInt()
-                             .getZExtValue();
-    int64_t aligned_highest = to_highest;
-    if (digit_shift < 0) {
-      aligned_highest <<= -digit_shift;
-      // Shift down, all dropped bits should already be zero.
-      result = rounded_from_bits.shrui(-digit_shift);
-    } else {
-      // Shift up, inserting zeros in the newly created digits.
-      rounded_from_bits = rounded_from_bits.shl(digit_shift);
-      result = rounded_from_bits;
-    }
-    result = convert_int(to_int_ty, result);
-
-    // `From` supports larger values than `To`, we may overflow.
-    if (std::make_pair(to_max_exp, to_mantissa) <
-        std::make_pair(from_max_exp, from_mantissa)) {
-      result = b.create<SelectOp>(
-          rounded_from_bits.cmp(CmpIPredicate::ugt, aligned_highest), to_inf,
-          result);
-    }
-  }
-
-  auto i32_ty = b.getI32Type();
-  Val biased_from_exp = convert_int(i32_ty, from_bits.shrui(from_mantissa));
-
-  if (to_min_exp < from_min_exp) {
-    // `To` supports more exponents near zero which means that some subnormal
-    // values in `From` may become normal.
-
-    // Subnormals.
-    Val bits = convert_int(wide_int_ty, from_bits);
-
-    // Determine exponent in target type.
-    Value normalization_factor =
-        convert_int(i32_ty,
-                    b.create<mlir::math::CountLeadingZerosOp>(from_bits)) -
-        (from_int_ty.getWidth() - from_mantissa - 1);
-
-    Val biased_exponent = cst(i32_ty, exp_offset + 1) - normalization_factor;
-    // If the result is subnormal, adjust the subnormal bits to account for
-    // the difference in exponent bias.
-    Value subnormal_bits = bits;
-    if (exp_offset < wide_int_ty.getWidth()) {
-      subnormal_bits = bits.shl(exp_offset);
-    }
-
-    // Result is normal. Shift the mantissa to account for the number of
-    // leading zero digits, and clear the hidden bit.
-    // Insert the exponent bits.
-    Value normal_bits =
-        (bits.shl(convert_int(wide_int_ty, normalization_factor)) &
-         ~(1 << from_mantissa)) |
-        convert_int(wide_int_ty, biased_exponent).shl(from_mantissa);
-
-    Value biased_exp_sle_zero = biased_exponent.cmp(CmpIPredicate::sle, 0);
-    bits.value =
-        b.create<SelectOp>(biased_exp_sle_zero, subnormal_bits, normal_bits);
-    if (digit_shift > 0) {
-      bits = bits.shl(digit_shift);
-    } else {
-      bits = round_bits_to_nearest_even(bits, bits.MakeConstant(-digit_shift));
-      bits = bits.shrui(-digit_shift);
-    }
-    bits = convert_int(to_int_ty, bits);
-
-    result = b.create<SelectOp>(biased_from_exp == 0, bits, result);
-  } else if (to_min_exp > from_min_exp) {
-    // `To` supports fewer exponents near zero which means that some values in
-    // `From` may become subnormal.
-    Val unbiased_exp = biased_from_exp - from_bias;
-    Val biased_to_exp = unbiased_exp + to_bias;
-    // Subnormals and zero.
-    // Round and shift mantissa down.
-    Val from_has_leading_one = biased_from_exp != 0;
-    Val from_has_leading_one_i32 = convert_int(i32_ty, from_has_leading_one);
-    from_has_leading_one = convert_int(from_int_ty, from_has_leading_one);
-    Val exponent_shift_i32 =
-        (from_has_leading_one_i32 - biased_to_exp) - digit_shift;
-    // Insert the implicit leading 1 bit on the mantissa for normalized
-    // inputs.
-    Val rounded_from_bits = (from_bits & ((1ll << from_mantissa) - 1)) |
-                            from_has_leading_one.shl(from_mantissa);
-
-    // NOTE: we need to round again from the original from_bits,
-    // otherwise the lower precision bits may already be lost.  There is
-    // an edge-case where rounding to a normalized value would normally
-    // round down, but for a subnormal, we need to round up.
-    Val exponent_shift_from_ty = convert_int(from_int_ty, exponent_shift_i32);
-    Val exponent_shift_to_ty = convert_int(to_int_ty, exponent_shift_i32);
-    Val positive_bits = convert_int(
-        to_int_ty,
-        round_bits_to_nearest_even(rounded_from_bits, exponent_shift_from_ty)
-            .shrui(exponent_shift_from_ty));
-    // To avoid UB, limit rounding and shifting to the full mantissa plus
-    // leading 1.
-    positive_bits.value = b.create<SelectOp>(
-        exponent_shift_i32.cmp(CmpIPredicate::sle, from_mantissa + 1),
-        positive_bits, to_zero);
-
-    Val negative_bits = convert_int(to_int_ty, rounded_from_bits)
-                            .shl(to_zero - exponent_shift_to_ty);
-    Value bits =
-        b.create<SelectOp>(exponent_shift_i32.cmp(CmpIPredicate::sgt, 0),
-                           positive_bits, negative_bits);
-    result = b.create<SelectOp>(biased_to_exp.cmp(CmpIPredicate::sle, 0), bits,
-                                result);
-  }
-
-  // Handle the FNUZ types, which have no negative zero (their zero encoding
-  // is unsigned and the negative-zero bit pattern encodes NaN instead).
-  auto is_nuz = [](mlir::FloatType ty) {
-    return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() ||
-           ty.isFloat8E5M2FNUZ();
-  };
-
-  if (is_nuz(to_ty)) {
-    // Clear the sign bit if the result is zero (the output has no negative
-    // zero).
-    Val result_is_non_zero = Val{result, &b} != 0;
-    from_sign_bit = from_sign_bit & result_is_non_zero;
-  } else if (is_nuz(from_ty)) {
-    // Clear the sign bit if the input is NaN (it's positive but encoded as
-    // negative 0).
-    from_sign_bit = from_sign_bit ^ input_is_nan;
-  }
-
-  result = b.create<SelectOp>(result_is_inf, to_inf, result);
-  result = b.create<SelectOp>(from_bits == 0, to_zero, result);
-  result = b.create<SelectOp>(input_is_nan, to_nan, result);
-
-  Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1));
-
-  // Insert sign bit.
-  result = b.create<SelectOp>(from_sign_bit, neg_result, result);
-  result = b.create<ma::BitcastOp>(to_ty, result);
-  return result;
-}
-
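For reference, a minimal scalar sketch of the round-to-nearest-even step that the lowering above builds out of MLIR ops; the function name and the use of plain 64-bit integer arithmetic are illustrative assumptions, not part of the emitter.

#include <cstdint>

// Drops the lowest `num_dropped` bits of `bits`, rounding to the nearest
// remaining value and breaking ties towards the value whose low bit is even.
// This folds the rounding and the final right shift into one step; the
// emitter above performs them as separate ops on `Val`s.
uint64_t RoundBitsToNearestEven(uint64_t bits, int num_dropped) {
  if (num_dropped <= 0) return bits;
  uint64_t lsb_of_kept_part = (bits >> num_dropped) & 1u;
  uint64_t rounding_bias = (1ull << (num_dropped - 1)) - 1 + lsb_of_kept_part;
  return (bits + rounding_bias) >> num_dropped;
}
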
-struct RewriteTruncFPattern : public mlir::OpRewritePattern<ma::TruncFOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      ma::TruncFOp op, mlir::PatternRewriter& rewriter) const override {
-    using FloatValue = mlir::TypedValue<mlir::FloatType>;
-    auto src = mlir::cast<FloatValue>(op.getOperand());
-    auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
-    if (dst_ty.getWidth() != 8) {
-      return rewriter.notifyMatchFailure(op, "not an 8 bit truncf");
-    }
-
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    rewriter.replaceOp(op, EmitFloatConversion(src, dst_ty, b));
-    return mlir::success();
-  }
-};
-
-struct RewriteExtFPattern : public mlir::OpRewritePattern<ma::ExtFOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      ma::ExtFOp op, mlir::PatternRewriter& rewriter) const override {
-    using FloatValue = mlir::TypedValue<mlir::FloatType>;
-    auto src = mlir::cast<FloatValue>(op.getOperand());
-    auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
-    if (src.getType().getWidth() != 8) {
-      return rewriter.notifyMatchFailure(op, "not an 8 bit extf");
-    }
-
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    rewriter.replaceOp(op, EmitFloatConversion(src, dst_ty, b));
-    return mlir::success();
-  }
-};
-
-// Lowering for cmpf : f8 for float to pred conversions.
-struct RewriteF8Cst : public mlir::OpRewritePattern<ma::CmpFOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      ma::CmpFOp op, mlir::PatternRewriter& rewriter) const override {
-    using FloatValue = mlir::TypedValue<mlir::FloatType>;
-    auto lhs = mlir::cast<FloatValue>(op.getLhs());
-    auto rhs = mlir::cast<FloatValue>(op.getRhs());
-
-    if (lhs.getType().getWidth() != 8) {
-      return rewriter.notifyMatchFailure(op, "not an 8 bit cmpf");
-    }
-
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    // Skip the f32 conversion if we're comparing UNE against a constant.
-    llvm::APFloat rhs_cst(rhs.getType().getFloatSemantics());
-    if (op.getPredicate() == ma::CmpFPredicate::UNE &&
-        mlir::matchPattern(rhs, mlir::m_ConstantFloat(&rhs_cst))) {
-      Val int_value{b.create<ma::BitcastOp>(rewriter.getI8Type(), lhs), &b};
-      int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue();
-      // If we're comparing to +-0, compare the absolute values.
-      if (rhs_cst.isZero() &&
-          (lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) {
-        int_value = int_value & 0x7f;
-        constant &= 0x7f;
-      }
-      auto cst = b.create<ma::ConstantIntOp>(constant, rewriter.getI8Type());
-      rewriter.replaceOpWithNewOp<ma::CmpIOp>(op, ma::CmpIPredicate::ne,
-                                              int_value, cst);
-      return mlir::success();
-    }
-
-    auto lhs_ext = b.create<ma::ExtFOp>(b.getF32Type(), lhs);
-    auto rhs_ext = b.create<ma::ExtFOp>(b.getF32Type(), rhs);
-    rewriter.replaceOpWithNewOp<ma::CmpFOp>(op, op->getResultTypes(),
-                                            mlir::ValueRange{lhs_ext, rhs_ext},
-                                            op->getAttrs());
-    return mlir::success();
-  }
-};
-
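To illustrate why the constant case above can avoid the f32 round trip, here is a standalone sketch of the same idea on raw f8E4M3FN/f8E5M2 bit patterns. It assumes the constant is not NaN (a NaN operand then still compares unequal, which is what UNE requires); the helper name and `kAbsMask` are illustrative only.

#include <cstdint>

constexpr uint8_t kAbsMask = 0x7f;  // clears the sign bit of an f8 value

// UNE against a non-NaN constant: every finite f8 value has a unique bit
// pattern, so bit inequality matches "unordered or not equal" once +0 and -0
// are folded together by masking off the sign bit.
bool UneAgainstConstant(uint8_t lhs_bits, uint8_t cst_bits, bool cst_is_zero) {
  if (cst_is_zero) {
    lhs_bits &= kAbsMask;
    cst_bits &= kAbsMask;
  }
  return lhs_bits != cst_bits;
}
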
-struct RewriteAbsFPattern : public mlir::OpRewritePattern<mlir::math::AbsFOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      mlir::math::AbsFOp op, mlir::PatternRewriter& rewriter) const override {
-    using FloatValue = mlir::TypedValue<mlir::FloatType>;
-    auto src = mlir::cast<FloatValue>(op.getOperand());
-    // LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16.
-    // Once that's removed, remove the code for BF16 here.
-    if (src.getType().getWidth() != 8 && !src.getType().isBF16()) {
-      return rewriter.notifyMatchFailure(op, "not an f8 or bf16 absf");
-    }
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth());
-    Val value{b.create<ma::BitcastOp>(i_ty, src), &b};
-    if (src.getType().getWidth() == 8) {
-      value = value & 0x7f;
-    } else {
-      CHECK(src.getType().isBF16());
-      value = value & 0x7fff;
-    }
-    rewriter.replaceOpWithNewOp<ma::BitcastOp>(op, src.getType(), value);
-    return mlir::success();
-  }
-};
-
-template <typename Op>
-struct RewriteIToFpPattern : public mlir::OpRewritePattern<Op> {
-  using mlir::OpRewritePattern<Op>::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      Op op, mlir::PatternRewriter& rewriter) const override {
-    if (op.getType().getIntOrFloatBitWidth() != 8) {
-      return rewriter.notifyMatchFailure(op, "not an f8 itofp");
-    }
-    Value to_float =
-        rewriter.create<Op>(op.getLoc(), rewriter.getF32Type(), op.getIn());
-    rewriter.replaceOpWithNewOp<ma::TruncFOp>(op, op.getType(), to_float);
-    return mlir::success();
-  }
-};
-
-template <typename Op>
-struct RewriteFpToIPattern : public mlir::OpRewritePattern<Op> {
-  using mlir::OpRewritePattern<Op>::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      Op op, mlir::PatternRewriter& rewriter) const override {
-    if (op.getIn().getType().getIntOrFloatBitWidth() != 8) {
-      return rewriter.notifyMatchFailure(op, "not an f8 fptoi");
-    }
-    Value to_f32 = rewriter.create<ma::ExtFOp>(
-        op.getLoc(), rewriter.getF32Type(), op.getIn());
-    rewriter.replaceOpWithNewOp<Op>(op, op.getType(), to_f32);
-    return mlir::success();
-  }
-};
-
-class ExpandFloatOpsPass
-    : public impl::ExpandFloatOpsPassBase<ExpandFloatOpsPass> {
- public:
-  using ExpandFloatOpsPassBase::ExpandFloatOpsPassBase;
-  void runOnOperation() override {
-    mlir::RewritePatternSet patterns(&getContext());
-    patterns.add<RewriteToCmpSelect<ma::MinimumFOp, ma::CmpFPredicate::OLE>>(
-        &getContext(), /*include_f32=*/pre_ampere_);
-    patterns.add<RewriteToCmpSelect<ma::MaximumFOp, ma::CmpFPredicate::OGE>>(
-        &getContext(), /*include_f32=*/pre_ampere_);
-    patterns.add<RewriteTruncFPattern, RewriteExtFPattern, RewriteAbsFPattern,
-                 RewriteF8Cst, RewriteIToFpPattern<ma::SIToFPOp>,
-                 RewriteIToFpPattern<ma::UIToFPOp>,
-                 RewriteFpToIPattern<ma::FPToSIOp>,
-                 RewriteFpToIPattern<ma::FPToUIOp>>(&getContext());
-    mlir::populatePolynomialApproximateTanhPattern(patterns);
-    patterns.add<RewriteErf32Pattern>(&getContext());
-    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
-                                                        std::move(patterns)))) {
-      signalPassFailure();
-    }
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass(bool pre_ampere) {
-  return createExpandFloatOpsPass(ExpandFloatOpsPassOptions{pre_ampere});
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/flatten_tensors.cc b/third_party/xla/xla/service/gpu/fusions/mlir/flatten_tensors.cc
deleted file mode 100644
index 99a7ecb..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/flatten_tensors.cc
+++ /dev/null
@@ -1,452 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <utility>
-
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Utils/Utils.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeRange.h"
-#include "mlir/IR/Types.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/IR/Visitors.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/shape_util.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-#define GEN_PASS_DEF_FLATTENTENSORSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-using mlir::Location;
-using mlir::LogicalResult;
-using mlir::MLIRContext;
-using mlir::OpRewritePattern;
-using mlir::PatternRewriter;
-using mlir::RankedTensorType;
-using mlir::SmallVector;
-using mlir::Type;
-using mlir::TypedValue;
-using mlir::TypeRange;
-using mlir::UnrealizedConversionCastOp;
-using mlir::Value;
-using mlir::ValueRange;
-using mlir::func::FuncOp;
-using mlir::func::ReturnOp;
-using mlir::scf::ForOp;
-using mlir::scf::IfOp;
-using mlir::tensor::ExtractOp;
-using mlir::tensor::InsertOp;
-
-RankedTensorType GetFlattenedType(RankedTensorType tensor_type) {
-  return RankedTensorType::get({tensor_type.getNumElements()},
-                               tensor_type.getElementType());
-}
-
-bool HasOnlyFlatTensorsOrScalars(TypeRange types) {
-  return llvm::all_of(types, [](Type ty) {
-    auto tensor_type = mlir::dyn_cast<RankedTensorType>(ty);
-    if (!tensor_type) return true;
-    return tensor_type.getRank() < 2;
-  });
-}
-
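A hypothetical usage sketch of GetFlattenedType (the wrapper function name is illustrative): a rank-2 tensor type maps to a rank-1 type with the same element count and element type.

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

mlir::RankedTensorType FlattenedExample(mlir::MLIRContext* ctx) {
  mlir::Builder b(ctx);
  // tensor<4x8xf32> -> tensor<32xf32>.
  auto t2d = mlir::RankedTensorType::get({4, 8}, b.getF32Type());
  return GetFlattenedType(t2d);
}
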
-struct RewriteFunctionSignatures : OpRewritePattern<FuncOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(FuncOp op,
-                                PatternRewriter& rewriter) const override {
-    auto input_types = op.getFunctionType().getInputs();
-    auto result_types = op.getFunctionType().getResults();
-    if (HasOnlyFlatTensorsOrScalars(input_types) &&
-        HasOnlyFlatTensorsOrScalars(result_types)) {
-      return rewriter.notifyMatchFailure(op, "nothing to flatten");
-    }
-
-    auto loc = op.getLoc();
-    mlir::Block* entry_block = &op.getBody().front();
-    SmallVector<Type> new_result_types;
-    SmallVector<Value> new_results;
-
-    // If some results are tensors, we need to flatten them.
-    auto terminator = entry_block->getTerminator();
-    rewriter.setInsertionPoint(terminator);
-
-    for (Value result : terminator->getOperands()) {
-      auto tensor_type = mlir::dyn_cast<RankedTensorType>(result.getType());
-      if (!tensor_type) {
-        new_result_types.push_back(result.getType());
-        new_results.push_back(result);
-        continue;
-      }
-      auto new_result_type = GetFlattenedType(tensor_type);
-      new_result_types.push_back(new_result_type);
-
-      Value result_1d =
-          rewriter
-              .create<UnrealizedConversionCastOp>(loc, new_result_type, result)
-              .getResult(0);
-      new_results.push_back(result_1d);
-    }
-    rewriter.replaceOpWithNewOp<ReturnOp>(terminator, new_results);
-
-    // Cast all function arguments to the original type.
-    SmallVector<Type> new_operand_types(input_types);
-    rewriter.setInsertionPointToStart(entry_block);
-    for (auto&& [index, operand_type] : llvm::enumerate(new_operand_types)) {
-      if (auto tensor_type = mlir::dyn_cast<RankedTensorType>(operand_type)) {
-        if (tensor_type.getRank() > 1) {
-          mlir::BlockArgument func_argument = op.getArgument(index);
-          auto cast_to_orig_type = rewriter.create<UnrealizedConversionCastOp>(
-              loc, operand_type, func_argument);
-          func_argument.replaceAllUsesExcept(cast_to_orig_type.getResult(0),
-                                             cast_to_orig_type);
-          operand_type = GetFlattenedType(tensor_type);
-        }
-      }
-    }
-    // Replace the function arguments with the new types.
-    for (auto [arg, arg_type] :
-         llvm::zip(entry_block->getArguments(), new_operand_types)) {
-      arg.setType(arg_type);
-    }
-    // Update function signature.
-    op.setType(rewriter.getFunctionType(new_operand_types, new_result_types));
-    return mlir::success();
-  }
-};
-
-// Returns the linearized index if the rank is greater than 1. Otherwise,
-// returns nullptr.
-Value LinearizeIndex(TypedValue<mlir::RankedTensorType> tensor,
-                     ValueRange indices, PatternRewriter& rewriter) {
-  if (tensor.getType().getRank() < 2) {
-    return nullptr;
-  }
-  auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape());
-  if (auto encoding = tensor.getType().getEncoding()) {
-    *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector(
-        mlir::cast<mlir::DenseElementsAttr>(encoding).getValues<int64_t>()));
-  }
-  auto linear_shape =
-      ShapeUtil::MakeShape(U8, {ShapeUtil::ElementsIn(byte_shape)});
-  auto linearized_map =
-      GetBitcastMap(byte_shape, linear_shape, tensor.getContext());
-  mlir::SmallVector<Value> result;
-  rewriter.createOrFold<ApplyIndexingOp>(result, tensor.getLoc(), indices,
-                                         ValueRange{}, linearized_map);
-  return result.front();
-}
-
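Conceptually, the indexing map produced here performs a strided linearization that also honors the layout stored in the tensor's encoding. Below is a simplified host-side sketch for the plain row-major case; the names are illustrative, and the real pass delegates to GetBitcastMap and ApplyIndexingOp instead.

#include <cstdint>
#include <vector>

// Row-major linearization: the last dimension varies fastest.
int64_t LinearizeRowMajor(const std::vector<int64_t>& indices,
                          const std::vector<int64_t>& dims) {
  int64_t linear = 0;
  for (size_t i = 0; i < dims.size(); ++i) {
    linear = linear * dims[i] + indices[i];
  }
  return linear;
}
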
-struct RewriteTensorExtract : OpRewritePattern<ExtractOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp op,
-                                PatternRewriter& rewriter) const override {
-    auto tensor = op.getTensor();
-    auto tensor_type = tensor.getType();
-    auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter);
-    if (linear_index == nullptr) {
-      return rewriter.notifyMatchFailure(op, "the tensor is already flat");
-    }
-    auto tensor_1D = rewriter
-                         .create<UnrealizedConversionCastOp>(
-                             op.getLoc(), GetFlattenedType(tensor_type), tensor)
-                         .getResult(0);
-    rewriter.replaceOpWithNewOp<ExtractOp>(op, tensor_1D, linear_index);
-    return mlir::success();
-  }
-};
-
-struct RewriteTensorInsert : OpRewritePattern<InsertOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InsertOp op,
-                                PatternRewriter& rewriter) const override {
-    auto tensor = op.getDest();
-    auto tensor_type = tensor.getType();
-    auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter);
-    if (linear_index == nullptr) {
-      return rewriter.notifyMatchFailure(op, "the tensor is already flat");
-    }
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto tensor_1D = b.create<UnrealizedConversionCastOp>(
-                          GetFlattenedType(tensor_type), tensor)
-                         .getResult(0);
-    auto new_insert =
-        b.create<InsertOp>(op.getScalar(), tensor_1D, linear_index);
-    auto cast_to_orig_type = b.create<UnrealizedConversionCastOp>(
-        tensor_type, new_insert.getResult());
-    rewriter.replaceOp(op, cast_to_orig_type.getResult(0));
-    return mlir::success();
-  }
-};
-
-struct RewriteAtomicRMW : OpRewritePattern<AtomicRMWOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(AtomicRMWOp op,
-                                PatternRewriter& rewriter) const override {
-    auto tensor = op.getInput();
-    auto tensor_type = tensor.getType();
-    auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter);
-    if (linear_index == nullptr) {
-      return rewriter.notifyMatchFailure(op, "the tensor is already flat");
-    }
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto tensor_1D = b.create<UnrealizedConversionCastOp>(
-                          GetFlattenedType(tensor_type), tensor)
-                         .getResult(0);
-    auto new_atomic_rmw = b.create<AtomicRMWOp>(tensor_1D, linear_index);
-    rewriter.inlineRegionBefore(op.getRegion(),
-                                &new_atomic_rmw.getRegion().front());
-    auto cast_to_orig_type = b.create<UnrealizedConversionCastOp>(
-        tensor_type, new_atomic_rmw.getResult());
-    rewriter.replaceOp(op, cast_to_orig_type.getResult(0));
-    return mlir::success();
-  }
-};
-
-// Checks whether the value is produced by an unrealized conversion cast from
-// a 1D tensor to an ND one. Returns the 1D tensor if so, std::nullopt
-// otherwise.
-std::optional<Value> GetDelinearizedTensor(Value value) {
-  auto tensor_type = mlir::dyn_cast<RankedTensorType>(value.getType());
-  if (!tensor_type || tensor_type.getRank() < 2) {
-    return std::nullopt;
-  }
-  auto cast = value.getDefiningOp<UnrealizedConversionCastOp>();
-  if (!cast || cast->getNumResults() != 1 || cast->getNumOperands() != 1) {
-    return std::nullopt;
-  }
-  auto type_before_linearization =
-      mlir::dyn_cast<RankedTensorType>(cast->getOperand(0).getType());
-  if (!type_before_linearization || type_before_linearization.getRank() != 1) {
-    return std::nullopt;
-  }
-  return cast->getOperand(0);
-}
-
-struct RewriteForOp : public OpRewritePattern<ForOp> {
-  using OpRewritePattern<ForOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ForOp op,
-                                PatternRewriter& rewriter) const override {
-    llvm::SmallBitVector args_to_update(op.getNumResults(), false);
-    mlir::SmallVector<Value> new_init_args;
-    new_init_args.reserve(op.getNumResults());
-    for (auto [index, arg] : llvm::enumerate(op.getInitArgs())) {
-      auto type_before_linearization = GetDelinearizedTensor(arg);
-      if (!type_before_linearization.has_value()) {
-        new_init_args.push_back(arg);
-        continue;
-      }
-      new_init_args.push_back(*type_before_linearization);
-      args_to_update.set(index);
-    }
-    if (args_to_update.none()) {
-      return rewriter.notifyMatchFailure(op, "no args to update");
-    }
-    // Create new ForOp with updated init args.
-    Location loc = op.getLoc();
-    auto new_for_op =
-        rewriter.create<ForOp>(loc, op.getLowerBound(), op.getUpperBound(),
-                               op.getStep(), new_init_args);
-    new_for_op->setAttrs(op->getAttrs());
-
-    // Insert casts for the block arguments.
-    mlir::Block* new_body = new_for_op.getBody();
-    mlir::Block* old_body = op.getBody();
-    rewriter.setInsertionPoint(new_body, new_body->begin());
-    SmallVector<Value, 4> updated_block_args{new_body->getArguments().begin(),
-                                             new_body->getArguments().end()};
-    for (auto [index, arg] :
-         llvm::enumerate(new_body->getArguments().drop_front())) {
-      if (!args_to_update.test(index)) continue;
-      updated_block_args[index + 1] =
-          rewriter
-              .create<UnrealizedConversionCastOp>(
-                  loc, old_body->getArgument(index + 1).getType(), arg)
-              .getResult(0);
-    }
-
-    // Move the body of the old ForOp to the new one.
-    rewriter.mergeBlocks(old_body, new_body, updated_block_args);
-
-    // Update the terminator.
-    auto new_terminator =
-        mlir::cast<mlir::scf::YieldOp>(new_body->getTerminator());
-    rewriter.setInsertionPoint(new_terminator);
-    for (auto&& [index, yielded_value] :
-         llvm::enumerate(new_terminator.getResultsMutable())) {
-      if (!args_to_update.test(index)) continue;
-      yielded_value.assign(
-          rewriter
-              .create<UnrealizedConversionCastOp>(
-                  loc, new_init_args[index].getType(), yielded_value.get())
-              .getResult(0));
-    }
-
-    // Cast back the results.
-    rewriter.setInsertionPointAfter(new_for_op);
-    SmallVector<Value> new_results(new_for_op.getResults());
-    for (auto&& [index, result] : llvm::enumerate(new_results)) {
-      if (!args_to_update.test(index)) continue;
-      result = rewriter
-                   .create<UnrealizedConversionCastOp>(
-                       loc, op->getResult(index).getType(), result)
-                   .getResult(0);
-    }
-    rewriter.replaceOp(op, new_results);
-    return mlir::success();
-  }
-};
-
-struct RewriteIfOp : public OpRewritePattern<IfOp> {
-  using OpRewritePattern<IfOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IfOp op,
-                                PatternRewriter& rewriter) const override {
-    auto result_types = op.getResultTypes();
-    if (HasOnlyFlatTensorsOrScalars(result_types)) {
-      return rewriter.notifyMatchFailure(op, "nothing to flatten");
-    }
-    mlir::scf::YieldOp then_yield = op.thenYield();
-    SmallVector<Type> new_result_types;
-    new_result_types.reserve(then_yield.getNumOperands());
-    bool found_cast = false;
-    for (auto& result : then_yield->getOpOperands()) {
-      auto delinearized_tensor = GetDelinearizedTensor(result.get());
-      if (!delinearized_tensor.has_value()) {
-        new_result_types.push_back(result.get().getType());
-        continue;
-      }
-      new_result_types.push_back(delinearized_tensor->getType());
-      result.set(*delinearized_tensor);
-      found_cast = true;
-    }
-    if (!found_cast) {
-      return rewriter.notifyMatchFailure(op, "no cast found");
-    }
-    Location loc = op.getLoc();
-    // Update the else branch if present.
-    bool has_else_region = !op.getElseRegion().empty();
-    if (has_else_region) {
-      mlir::scf::YieldOp else_yield = op.elseYield();
-      mlir::OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(else_yield);
-      for (auto&& [result, type] :
-           llvm::zip(else_yield->getOpOperands(), new_result_types)) {
-        if (result.get().getType() == type) continue;
-        result.set(
-            rewriter.create<UnrealizedConversionCastOp>(loc, type, result.get())
-                .getResult(0));
-      }
-    }
-    // Create new IfOp and move the old op's regions to the new one.
-    auto new_if_op = rewriter.create<IfOp>(loc, new_result_types,
-                                           op.getCondition(), has_else_region);
-    rewriter.inlineRegionBefore(op.getThenRegion(),
-                                &new_if_op.getThenRegion().back());
-    rewriter.eraseBlock(&new_if_op.getThenRegion().back());
-    if (has_else_region) {
-      rewriter.inlineRegionBefore(op.getElseRegion(),
-                                  &new_if_op.getElseRegion().back());
-      rewriter.eraseBlock(&new_if_op.getElseRegion().back());
-    }
-
-    // Update the results.
-    rewriter.setInsertionPointAfter(new_if_op);
-    SmallVector<Value> new_results(new_if_op.getResults());
-    for (auto&& [index, result] : llvm::enumerate(new_results)) {
-      Type old_type = op->getResult(index).getType();
-      if (result.getType() == old_type) continue;
-      result =
-          rewriter.create<UnrealizedConversionCastOp>(loc, old_type, result)
-              .getResult(0);
-    }
-    rewriter.replaceOp(op, new_results);
-    return mlir::success();
-  }
-};
-
-class FlattenTensorsPass
-    : public impl::FlattenTensorsPassBase<FlattenTensorsPass> {
- public:
-  void runOnOperation() override {
-    mlir::ModuleOp module = getOperation();
-    MLIRContext* mlir_context = &getContext();
-    mlir::RewritePatternSet patterns(mlir_context);
-    // clang-format off
-    patterns.add<
-        RewriteAtomicRMW,
-        RewriteForOp,
-        RewriteFunctionSignatures,
-        RewriteIfOp,
-        RewriteTensorExtract,
-        RewriteTensorInsert
-    >(mlir_context);
-    // clang-format on
-    ApplyIndexingOp::getCanonicalizationPatterns(patterns, mlir_context);
-    if (mlir::failed(
-            mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
-      signalPassFailure();
-      return;
-    }
-    // Check that no unrealized_conversion_cast ops are left.
-    bool module_has_casts = module
-                                .walk([](UnrealizedConversionCastOp op) {
-                                  return mlir::WalkResult::interrupt();
-                                })
-                                .wasInterrupted();
-    if (module_has_casts) {
-      llvm::outs() << "FlattenTensorsPass failed to converge";
-      signalPassFailure();
-      return;
-    }
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass() {
-  return std::make_unique<FlattenTensorsPass>();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD
deleted file mode 100644
index d618413..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD
+++ /dev/null
@@ -1,110 +0,0 @@
-load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
-
-package(
-    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    default_visibility = [":friends"],
-    licenses = ["notice"],
-)
-
-package_group(
-    name = "friends",
-    includes = [
-        "//xla:friends",
-    ],
-)
-
-td_library(
-    name = "xla_gpu_td_files",
-    srcs = glob(["*.td"]),
-    includes = ["."],
-    deps = [
-        "@llvm-project//mlir:CallInterfacesTdFiles",
-        "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
-        "@llvm-project//mlir:OpBaseTdFiles",
-        "@llvm-project//mlir:SideEffectInterfacesTdFiles",
-    ],
-)
-
-gentbl_cc_library(
-    name = "xla_gpu_dialect_inc_gen",
-    strip_include_prefix = ".",
-    tbl_outs = [
-        (
-            ["-gen-dialect-decls"],
-            "xla_gpu_dialect.h.inc",
-        ),
-        (
-            ["-gen-dialect-defs"],
-            "xla_gpu_dialect.cc.inc",
-        ),
-    ],
-    tblgen = "@llvm-project//mlir:mlir-tblgen",
-    td_file = "xla_gpu_dialect.td",
-    deps = [":xla_gpu_td_files"],
-)
-
-gentbl_cc_library(
-    name = "xla_gpu_ops_inc_gen",
-    strip_include_prefix = ".",
-    tbl_outs = [
-        (
-            ["-gen-op-decls"],
-            "xla_gpu_ops.h.inc",
-        ),
-        (
-            ["-gen-op-defs"],
-            "xla_gpu_ops.cc.inc",
-        ),
-    ],
-    tblgen = "@llvm-project//mlir:mlir-tblgen",
-    td_file = "xla_gpu_ops.td",
-    deps = [":xla_gpu_td_files"],
-)
-
-gentbl_cc_library(
-    name = "xla_gpu_attrs_inc_gen",
-    strip_include_prefix = ".",
-    tbl_outs = [
-        (
-            ["-gen-attrdef-decls"],
-            "xla_gpu_attrs.h.inc",
-        ),
-        (
-            ["-gen-attrdef-defs"],
-            "xla_gpu_attrs.cc.inc",
-        ),
-    ],
-    tblgen = "@llvm-project//mlir:mlir-tblgen",
-    td_file = "xla_gpu_attrs.td",
-    deps = [":xla_gpu_td_files"],
-)
-
-cc_library(
-    name = "xla_gpu",
-    srcs = [
-        "xla_gpu_attrs.cc",
-        "xla_gpu_dialect.cc",
-        "xla_gpu_ops.cc",
-    ],
-    hdrs = [
-        "xla_gpu_attrs.h",
-        "xla_gpu_ops.h",
-    ],
-    deps = [
-        ":xla_gpu_attrs_inc_gen",
-        ":xla_gpu_dialect_inc_gen",
-        ":xla_gpu_ops_inc_gen",
-        "//xla/service/gpu/model:indexing_analysis",
-        "@com_google_absl//absl/strings:str_format",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:ArithDialect",
-        "@llvm-project//mlir:BytecodeOpInterface",
-        "@llvm-project//mlir:CallOpInterfaces",
-        "@llvm-project//mlir:FuncDialect",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:InferTypeOpInterface",
-        "@llvm-project//mlir:InliningUtils",
-        "@llvm-project//mlir:SideEffectInterfaces",
-        "@llvm-project//mlir:Support",
-    ],
-)
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc
deleted file mode 100644
index d38ed34..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc
+++ /dev/null
@@ -1,228 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h"
-
-#include <string>
-#include <utility>
-
-#include "absl/strings/str_format.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/LogicalResult.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Support/LLVM.h"
-#include "xla/service/gpu/model/indexing_map.h"
-
-#define GET_ATTRDEF_LIST
-#define GET_ATTRDEF_CLASSES
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h.inc"
-
-namespace xla {
-namespace gpu {
-
-using llvm::ParseResult;
-using llvm::SmallVector;
-using mlir::AffineExpr;
-using mlir::ArrayRef;
-using mlir::AsmParser;
-using mlir::AsmPrinter;
-using mlir::failure;
-using mlir::success;
-
-ParseResult ParseInterval(AsmParser& parser, Interval& interval) {
-  // ParseResult converts to `true` if parsing failed.
-  return failure(parser.parseLSquare() || parser.parseInteger(interval.lower) ||
-                 parser.parseComma() || parser.parseInteger(interval.upper) ||
-                 parser.parseRSquare());
-}
-
-void PrintDimVars(AsmPrinter& p, ArrayRef<DimVar> dim_vars) {
-  int index = 0;
-  llvm::interleaveComma(dim_vars, p, [&](const DimVar& dim_var) {
-    p << "d" << index++ << " in " << dim_var.bounds;
-  });
-}
-
-ParseResult ParseDimVars(AsmParser& parser, ArrayRef<std::string> dim_names,
-                         SmallVector<DimVar>& dim_vars) {
-  dim_vars.reserve(dim_names.size());
-  for (const auto& [index, dim_name] : llvm::enumerate(dim_names)) {
-    if (parser.parseKeyword(dim_name) || parser.parseKeyword("in") ||
-        ParseInterval(parser, dim_vars.emplace_back().bounds)) {
-      return failure();
-    }
-    if (index < dim_names.size() - 1 && parser.parseComma()) {
-      return failure();
-    }
-  }
-  return success();
-}
-
-void PrintRangeVars(AsmPrinter& p, ArrayRef<RangeVar> range_vars) {
-  int index = 0;
-  llvm::interleaveComma(range_vars, p, [&](const RangeVar& range_var) {
-    p << "s" << index++ << " in " << range_var.range;
-  });
-}
-
-ParseResult ParseRangeVars(AsmParser& parser,
-                           ArrayRef<std::string> range_symbol_names,
-                           SmallVector<RangeVar>& range_vars) {
-  range_vars.reserve(range_symbol_names.size());
-  for (const auto& [index, range_symbol_name] :
-       llvm::enumerate(range_symbol_names)) {
-    if (parser.parseKeyword(range_symbol_name) || parser.parseKeyword("in") ||
-        ParseInterval(parser, range_vars.emplace_back().range)) {
-      return failure();
-    }
-    if (index < range_symbol_names.size() - 1 && parser.parseComma()) {
-      return failure();
-    }
-  }
-  return success();
-}
-
-void PrintConstraints(AsmPrinter& p,
-                      ArrayRef<std::pair<AffineExpr, Interval>> constraints) {
-  llvm::interleaveComma(constraints, p, [&](const auto& constraint) {
-    p << constraint.first << " in " << constraint.second;
-  });
-}
-
-ParseResult ParseConstraints(
-    AsmParser& parser,
-    ArrayRef<std::pair<llvm::StringRef, AffineExpr>> symbolSet,
-    SmallVector<std::pair<AffineExpr, Interval>>& constraints) {
-  // Constraints can only exist if there is at least one symbol or dimension,
-  // so every remaining constraint is preceded by a comma.
-  while (succeeded(parser.parseOptionalComma())) {
-    auto& constraint = constraints.emplace_back();
-    if (parser.parseAffineExpr(symbolSet, constraint.first) ||
-        parser.parseKeyword("in") || ParseInterval(parser, constraint.second)) {
-      return failure();
-    }
-  }
-  return success();
-}
-
-mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) {
-  mlir::AffineMap map;
-  if (parser.parseLess() || parser.parseAffineMap(map)) {
-    return {};
-  }
-
-  // Store real strings to back up StringRef throughout ParseConstraints.
-  SmallVector<std::string> dim_strings(map.getNumDims());
-  SmallVector<std::string> symbol_strings(map.getNumSymbols());
-  SmallVector<std::pair<llvm::StringRef, AffineExpr>> symbolSet;
-  symbolSet.reserve(map.getNumDims() + map.getNumSymbols());
-  for (int i = 0; i < map.getNumDims(); ++i) {
-    dim_strings[i] = absl::StrFormat("d%d", i);
-    symbolSet.push_back(
-        {dim_strings[i], mlir::getAffineDimExpr(i, parser.getContext())});
-  }
-  for (int i = 0; i < map.getNumSymbols(); ++i) {
-    symbol_strings[i] = absl::StrFormat("s%d", i);
-    symbolSet.push_back(
-        {symbol_strings[i], mlir::getAffineSymbolExpr(i, parser.getContext())});
-  }
-  if (map.getNumDims() + map.getNumSymbols() > 0) {
-    if (parser.parseComma() || parser.parseKeyword("domain") ||
-        parser.parseColon()) {
-      return {};
-    }
-  }
-
-  SmallVector<DimVar> dim_vars;
-  if (map.getNumDims() > 0) {
-    if (ParseDimVars(parser, dim_strings, dim_vars)) {
-      return {};
-    }
-  }
-
-  SmallVector<RangeVar> range_vars;
-  if (map.getNumSymbols() > 0) {
-    if (!dim_vars.empty() && parser.parseComma()) {
-      return {};
-    }
-    if (ParseRangeVars(parser, symbol_strings, range_vars)) {
-      return {};
-    }
-  }
-
-  SmallVector<std::pair<AffineExpr, Interval>> constraints;
-  if (ParseConstraints(parser, symbolSet, constraints) ||
-      parser.parseGreater()) {
-    return {};
-  }
-  return IndexingMapAttr::get(parser.getContext(), map, dim_vars, range_vars,
-                              constraints);
-}
-
-void IndexingMapAttr::print(mlir::AsmPrinter& printer) const {
-  printer << "<";
-  printer.printStrippedAttrOrType(getMap());
-  if (getDimVars().size() + getRangeVars().size() + getConstraints().size() >
-      0) {
-    printer << ", domain: ";
-  }
-  PrintDimVars(printer, getDimVars());
-  if (!getDimVars().empty() &&
-      getRangeVars().size() + getConstraints().size() > 0) {
-    printer << ", ";
-  }
-  PrintRangeVars(printer, getRangeVars());
-  if (!getRangeVars().empty() && !getConstraints().empty()) {
-    printer << ", ";
-  }
-  PrintConstraints(printer, getConstraints());
-  printer << ">";
-}
-
-IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context,
-                                     const IndexingMap& indexing_map) {
-  llvm::SmallVector<std::pair<AffineExpr, Interval>> constraints;
-  for (auto& constraint : indexing_map.GetConstraints()) {
-    constraints.push_back({constraint.first, constraint.second});
-  }
-  return get(context, indexing_map.GetAffineMap(), indexing_map.GetDimVars(),
-             indexing_map.GetRangeVars(), constraints);
-}
-
-mlir::LogicalResult IndexingMapAttr::verify(
-    mlir::function_ref<mlir::InFlightDiagnostic()> emitError,
-    mlir::AffineMap map, ArrayRef<DimVar> dim_vars,
-    ArrayRef<RangeVar> range_vars,
-    ArrayRef<std::pair<AffineExpr, Interval>> constraints) {
-  if (map.getNumDims() != dim_vars.size()) {
-    return emitError()
-           << "dim size must match the number of dimensions in the affine map";
-  }
-  if (map.getNumSymbols() != range_vars.size()) {
-    return emitError()
-           << "range size must match the number of symbols in the affine map";
-  }
-  return mlir::success();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h
deleted file mode 100644
index fca9216..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h
+++ /dev/null
@@ -1,33 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_
-
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/Support/LLVM.h"
-#include "xla/service/gpu/model/indexing_map.h"  // IWYU pragma: keep
-
-namespace xla {
-namespace gpu {
-
-// Custom parser to parse IndexingMapAttr.
-mlir::FailureOr<mlir::Attribute> ParseIndexingMapAttr(mlir::AsmParser& parser);
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td
deleted file mode 100644
index 8c8f98c..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS
-
-include "mlir/IR/AttrTypeBase.td"
-include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td"
-
-class XLAGPU_Attr<string name, list<Trait> traits = []> :
-      AttrDef<XlaGpuDialect, name, traits> {
-}
-
-def XLAGPU_AffineMapParameter :
-    AttrOrTypeParameter<"::mlir::AffineMap", ""> {
-}
-
-def XLAGPU_DimVarsParameter : ArrayRefParameter<"::xla::gpu::DimVar",
-                                                "DimVarArray"> {
-}
-
-def XLAGPU_RangeVarsParameter : ArrayRefParameter<"::xla::gpu::RangeVar",
-                                                  "RangeVarArray"> {
-}
-
-def XLAGPU_ConstraintsParameter :
-    ArrayRefParameter<"::std::pair<::mlir::AffineExpr, ::xla::gpu::Interval>",
-                      "ConstraintsArray"> {
-}
-
-def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> {
-  let summary = "An Attribute representing an indexing map.";
-  let mnemonic = "indexing_map";
-  let description = [{This attribute stores an indexing map. See
-    https://openxla.org/xla/indexing for more details.
-  }];
-  let parameters = (ins XLAGPU_AffineMapParameter:$map,
-                        XLAGPU_DimVarsParameter:$dim_vars,
-                        XLAGPU_RangeVarsParameter:$range_vars,
-                        XLAGPU_ConstraintsParameter:$constraints);
-  let hasCustomAssemblyFormat = 1;
-  let builders =  [
-    AttrBuilder<(ins "const ::xla::gpu::IndexingMap&":$indexing_map)>,
-  ];
-  let genVerifyDecl = 1;
-}
-
-#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc
deleted file mode 100644
index 3dc60c9..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc
+++ /dev/null
@@ -1,129 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "llvm/ADT/TypeSwitch.h"
-#include "mlir/IR/DialectImplementation.h"  // IWYU pragma: keep
-#include "mlir/IR/OpImplementation.h"  // IWYU pragma: keep
-#include "mlir/Transforms/InliningUtils.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#define GET_ATTRDEF_CLASSES
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc"
-#undef GET_ATTRDEF_CLASSES
-
-namespace xla {
-namespace gpu {
-namespace {
-
-struct XlaGpuInlinerInterface : public mlir::DialectInlinerInterface {
-  using DialectInlinerInterface::DialectInlinerInterface;
-  // Returns true if the given operation 'callable', that implements the
-  // 'CallableOpInterface', can be inlined into the position given call
-  // operation 'call', that is registered to the current dialect and implements
-  // the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the
-  // given 'callable' is set to be cloned during the inlining process, or false
-  // if the region is set to be moved in-place (i.e. no duplicates would be
-  // created).
-  bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable,
-                       bool wouldBeCloned) const final {
-    if (!wouldBeCloned) {
-      // If no duplicate would be created, 'call' is likely the only caller of
-      // 'callable'.
-      return true;
-    }
-    // Otherwise, inline only if the called function is small. We could
-    // theoretically also inline if there is no other caller in the function
-    // that contains the callee that has a call path to the callable, but that
-    // is more expensive to check.
-    auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(callable);
-    if (!func_op) {
-      return false;
-    }
-    auto region = func_op.getCallableRegion();
-    if (!region) {
-      return false;
-    }
-
-    // If callee and caller call the same third function, inline. We have no
-    // guarantee that the indices are the same, but there is a good chance they
-    // are (or if the callee gets inlined as well, there will be CSE
-    // opportunities).
-    // This is duct tape to work around the limitations of our partitioner.
-    // Ideally, the partitioner would be aware of the actual indexing and create
-    // the partitions based on it (i.e., the case where the indices are the same
-    // would never happen).
-    llvm::SmallDenseSet<llvm::StringRef> callee_calls;
-    for (auto call : region->getOps<PureCallOp>()) {
-      callee_calls.insert(call.getCallee());
-    }
-    for (auto call : call->getParentRegion()->getOps<PureCallOp>()) {
-      if (callee_calls.contains(call.getCallee())) {
-        return true;
-      }
-    }
-
-    constexpr int kMaxOperationsToInline = 8;
-    int num_ops = 0;
-    region->front().walk([&](mlir::Operation* op) { ++num_ops; });
-
-    // Otherwise, only inline the function if it is small (at most
-    // kMaxOperationsToInline operations).
-    return num_ops <= kMaxOperationsToInline;
-  }
-  // Returns true if the given operation 'op', that is registered to this
-  // dialect, can be inlined into the given region, false otherwise.
-  // 'wouldBeCloned' is set to true if the given 'op' is set to be cloned
-  // during the inlining process, or false if the operation is set to be moved
-  // in-place (i.e. no duplicates would be created). 'valueMapping' contains any
-  // remapped values from within the 'src' region. This can be used to examine
-  // what values may potentially replace the operands to 'op'.
-  bool isLegalToInline(mlir::Operation* op, mlir::Region* dest,
-                       bool wouldBeCloned,
-                       mlir::IRMapping& valueMapping) const final {
-    // We allow any op from the xla_gpu dialect to be inlined.
-    return true;
-  }
-};
-
-struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface {
-  using OpAsmDialectInterface::OpAsmDialectInterface;
-  AliasResult getAlias(mlir::Attribute attr,
-                       mlir::raw_ostream& os) const final {
-    if (llvm::isa<IndexingMapAttr>(attr)) {
-      os << "indexing_map";
-      return AliasResult::FinalAlias;
-    }
-    return AliasResult::NoAlias;
-  }
-};
-
-}  // namespace
-
-void XlaGpuDialect::initialize() {
-  addOperations<
-#define GET_OP_LIST
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc"
-#undef GET_OP_LIST
-      >();
-  addAttributes<
-#define GET_ATTRDEF_LIST
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc"
-      >();
-#undef GET_ATTRDEF_LIST
-  addInterfaces<XlaGpuInlinerInterface, XlaGpuOpAsmDialectInterface>();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td
deleted file mode 100644
index 4400747..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td
+++ /dev/null
@@ -1,32 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT
-
-include "mlir/IR/DialectBase.td"
-
-def XlaGpuDialect : Dialect {
-  let name = "xla_gpu";
-
-  let description = [{
-    This dialect contains ops required for lowering HLO to LLVM.
-  }];
-
-  let cppNamespace = "::xla::gpu";
-  let useDefaultAttributePrinterParser = 1;
-}
-
-#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc
deleted file mode 100644
index dfa4d05..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc
+++ /dev/null
@@ -1,655 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-
-#include <cstdint>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/STLFunctionalExtras.h"
-#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/TypeSwitch.h"  // IWYU pragma: keep
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/LogicalResult.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Builders.h"  // IWYU pragma: keep
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/DialectImplementation.h"  // IWYU pragma: keep
-#include "mlir/IR/MLIRContext.h"  // IWYU pragma: keep
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/PatternMatch.h"  // IWYU pragma: keep
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/IR/TypeUtilities.h"  // IWYU pragma: keep
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc.inc"
-#include "xla/service/gpu/model/indexing_map.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using llvm::ArrayRef;
-using mlir::AffineExpr;
-using mlir::AffineMap;
-using mlir::failure;
-using mlir::getAffineConstantExpr;
-using mlir::getAffineDimExpr;
-using mlir::getAffineSymbolExpr;
-using mlir::LogicalResult;
-using mlir::MLIRContext;
-using mlir::OpBuilder;
-using mlir::OperationState;
-using mlir::PatternRewriter;
-using mlir::RankedTensorType;
-using mlir::Region;
-using mlir::SmallVector;
-using mlir::success;
-using mlir::Type;
-using mlir::Value;
-using mlir::ValueRange;
-
-namespace arith = mlir::arith;
-
-}  // namespace
-
-LogicalResult PureCallOp::verifySymbolUses(
-    mlir::SymbolTableCollection& symbolTable) {
-  auto callee = getCalleeAttr();
-  auto function =
-      symbolTable.lookupNearestSymbolFrom<mlir::func::FuncOp>(*this, callee);
-  if (!function) {
-    return emitError("'f' attribute refers to an undefined function: ")
-           << callee;
-  }
-
-  int func_arg_count = function.getFunctionType().getNumInputs();
-  int arg_count = getOperands().size();
-
-  if (arg_count != func_arg_count) {
-    return emitError() << "argument count mismatch: 'operands' has "
-                       << arg_count << " arguments, but '" << callee
-                       << "' expects " << func_arg_count;
-  }
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// AllocateSharedOp
-//===----------------------------------------------------------------------===//
-
-void AllocateSharedOp::getAsmResultNames(
-    llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
-  setNameFn(getResult(), "shmem");
-}
-
-//===----------------------------------------------------------------------===//
-// ApplyIndexingOp
-//===----------------------------------------------------------------------===//
-
-void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
-                            ValueRange dims, ValueRange symbols,
-                            const IndexingMap& indexing_map) {
-  SmallVector<Value, 4> operands;
-  operands.reserve(dims.size() + symbols.size());
-  operands.append(dims.begin(), dims.end());
-  operands.append(symbols.begin(), symbols.end());
-  build(builder, result, operands, indexing_map);
-}
-
-void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
-                            ValueRange operands,
-                            const IndexingMap& indexing_map) {
-  build(builder, result, operands, indexing_map.GetAffineMap(),
-        indexing_map.GetDimVars(), indexing_map.GetRangeVars());
-}
-
-void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
-                            ValueRange operands, AffineMap affine_map,
-                            ArrayRef<DimVar> dim_vars,
-                            ArrayRef<RangeVar> range_vars) {
-  SmallVector<int64_t, 4> lower_bounds, upper_bounds;
-  for (const DimVar& dim_var : dim_vars) {
-    lower_bounds.push_back(dim_var.bounds.lower);
-    upper_bounds.push_back(dim_var.bounds.upper);
-  }
-  for (const RangeVar& range_var : range_vars) {
-    lower_bounds.push_back(range_var.range.lower);
-    upper_bounds.push_back(range_var.range.upper);
-  }
-  build(builder, result, operands, affine_map, lower_bounds, upper_bounds);
-}
-
-void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
-                            ValueRange operands, AffineMap affine_map,
-                            ArrayRef<int64_t> lower_bounds,
-                            ArrayRef<int64_t> upper_bounds) {
-  SmallVector<Type, 2> result_types(affine_map.getNumResults(),
-                                    builder.getIndexType());
-  build(builder, result, result_types, operands, affine_map, lower_bounds,
-        upper_bounds);
-}
-
-// Parses a comma-separated list of entries of the form
-// `%operand in [lower_bound, upper_bound]`.
-// Adds the parsed elements to the provided containers.
-mlir::ParseResult parseOperandsWithBoundsList(
-    mlir::OpAsmParser& parser,
-    SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4>* operands,
-    SmallVector<int64_t, 4>* lower_bounds,
-    SmallVector<int64_t, 4>* upper_bounds) {
-  int64_t lower_bound, upper_bound;
-  mlir::OpAsmParser::UnresolvedOperand operand;
-  if (parser.parseCommaSeparatedList([&]() {
-        if (parser.parseOperand(operand) || parser.parseKeyword("in") ||
-            parser.parseLSquare() || parser.parseInteger(lower_bound) ||
-            parser.parseComma() || parser.parseInteger(upper_bound) ||
-            parser.parseRSquare()) {
-          return failure();
-        }
-        operands->push_back(operand);
-        lower_bounds->push_back(lower_bound);
-        upper_bounds->push_back(upper_bound);
-        return success();
-      })) {
-    return failure();
-  }
-  return success();
-}
-
-mlir::ParseResult ApplyIndexingOp::parse(mlir::OpAsmParser& parser,
-                                         OperationState& result) {
-  mlir::Builder& builder = parser.getBuilder();
-  auto index_type = builder.getIndexType();
-
-  mlir::AffineMapAttr affine_map_attr;
-  if (parser.parseAttribute(affine_map_attr, "map", result.attributes)) {
-    return failure();
-  }
-
-  SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> operands;
-  SmallVector<int64_t, 4> lower_bounds, upper_bounds;
-  if (succeeded(parser.parseOptionalLParen())) {
-    if (parseOperandsWithBoundsList(parser, &operands, &lower_bounds,
-                                    &upper_bounds) ||
-        parser.parseRParen()) {
-      return failure();
-    }
-  }
-  if (succeeded(parser.parseOptionalLSquare())) {
-    if (parseOperandsWithBoundsList(parser, &operands, &lower_bounds,
-                                    &upper_bounds) ||
-        parser.parseRSquare()) {
-      return failure();
-    }
-  }
-  if (parser.resolveOperands(operands, index_type, result.operands) ||
-      parser.parseOptionalAttrDict(result.attributes)) {
-    return failure();
-  }
-  result.addAttribute("lower_bounds",
-                      builder.getDenseI64ArrayAttr(lower_bounds));
-  result.addAttribute("upper_bounds",
-                      builder.getDenseI64ArrayAttr(upper_bounds));
-
-  auto map = affine_map_attr.getAffineMap();
-  result.addTypes(SmallVector<Type, 2>(map.getNumResults(), index_type));
-  return success();
-}
-
-void ApplyIndexingOp::print(mlir::OpAsmPrinter& p) {
-  mlir::AffineMapAttr affine_map_attr = getMapAttr();
-  AffineMap affine_map = affine_map_attr.getAffineMap();
-  p << " " << affine_map_attr;
-
-  auto lower_bounds = getLowerBounds();
-  auto upper_bounds = getUpperBounds();
-  auto operands = getOperands();
-  unsigned num_dimensions = affine_map.getNumDims();
-  if (num_dimensions > 0) {
-    p << '(';
-    for (int dim_id = 0; dim_id < num_dimensions; ++dim_id) {
-      p << operands[dim_id] << " in " << '[' << lower_bounds[dim_id] << ", "
-        << upper_bounds[dim_id] << ']';
-      if (dim_id != num_dimensions - 1) {
-        p << ", ";
-      }
-    }
-    p << ')';
-  }
-  unsigned num_symbols = affine_map.getNumSymbols();
-  if (num_symbols > 0) {
-    p << '[';
-    for (int symbol_id = 0; symbol_id < num_symbols; ++symbol_id) {
-      unsigned operand_id = num_dimensions + symbol_id;
-      p << operands[operand_id] << " in " << '[' << lower_bounds[operand_id]
-        << ", " << upper_bounds[operand_id] << ']';
-      if (symbol_id != num_symbols - 1) {
-        p << ", ";
-      }
-    }
-    p << ']';
-  }
-  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
-                              "map", "lower_bounds", "upper_bounds"});
-}
-
-LogicalResult ApplyIndexingOp::verify() {
-  auto affine_map = getMapAttr().getAffineMap();
-  unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols();
-  if (getOperands().size() != num_variables ||
-      getLowerBounds().size() != num_variables ||
-      getUpperBounds().size() != num_variables) {
-    return emitOpError(
-        "operand, lower_bounds, upper_bounds count and affine map dimension "
-        "and symbol count must match");
-  }
-  return success();
-}
-
-IndexingMap ApplyIndexingOp::getIndexingMap() {
-  auto lower_bounds = getLowerBounds();
-  auto upper_bounds = getUpperBounds();
-
-  AffineMap affine_map = getAffineMap();
-  unsigned num_dimensions = affine_map.getNumDims();
-  std::vector<DimVar> dim_vars;
-  dim_vars.reserve(num_dimensions);
-  for (unsigned id = 0; id < num_dimensions; ++id) {
-    dim_vars.push_back(DimVar{Interval{lower_bounds[id], upper_bounds[id]}});
-  }
-  unsigned num_symbols = affine_map.getNumSymbols();
-  std::vector<RangeVar> range_vars;
-  range_vars.reserve(num_symbols);
-  for (unsigned id = num_dimensions; id < num_symbols + num_dimensions; ++id) {
-    range_vars.push_back(
-        RangeVar{Interval{lower_bounds[id], upper_bounds[id]}});
-  }
-  return IndexingMap(affine_map, std::move(dim_vars), std::move(range_vars),
-                     /*rt_vars=*/{});
-}
-
-namespace {
-
-// Simplifies the indexing map and removes unused variables.
-struct SimplifyIndexingMap : public mlir::OpRewritePattern<ApplyIndexingOp> {
-  using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
-                                PatternRewriter& rewriter) const override {
-    IndexingMap indexing_map = indexing_op.getIndexingMap();
-    bool is_simplified = indexing_map.Simplify();
-
-    // Remove unused symbols.
-    auto unused_symbols_bit_vector = indexing_map.RemoveUnusedVars();
-    bool symbols_removed = unused_symbols_bit_vector.count() != 0;
-
-    if (!is_simplified && !symbols_removed) {
-      return rewriter.notifyMatchFailure(indexing_op,
-                                         "IndexingMap stayed unchanged");
-    }
-    if (!unused_symbols_bit_vector.empty()) {
-      SmallVector<Value, 4> operands;
-      operands.reserve(unused_symbols_bit_vector.count());
-      for (int i = 0; i < unused_symbols_bit_vector.size(); ++i) {
-        if (!unused_symbols_bit_vector[i]) {
-          operands.push_back(indexing_op.getOperand(i));
-        }
-      }
-      rewriter.replaceOpWithNewOp<ApplyIndexingOp>(indexing_op, operands,
-                                                   indexing_map);
-    } else {
-      rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
-          indexing_op, indexing_op.getOperands(), indexing_map);
-    }
-    return success();
-  }
-};
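
A hedged before/after sketch of this rewrite (hypothetical IR, invented value
names): an operand whose symbol is never used by the map is dropped, and the op
is re-created with the simplified map.

```mlir
// Before: s0 does not appear in #map, so %unused is a dead operand.
#map = affine_map<(d0)[s0] -> (d0 mod 8)>
%r = xla_gpu.apply_indexing #map (%i in [0, 31])[%unused in [0, 3]]

// After: the unused symbol and its operand are gone.
#simplified = affine_map<(d0) -> (d0 mod 8)>
%r = xla_gpu.apply_indexing #simplified (%i in [0, 31])
```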
-
-struct FoldApplyIndexingSequence
-    : public mlir::OpRewritePattern<ApplyIndexingOp> {
-  using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
-                                PatternRewriter& rewriter) const override {
-    MLIRContext* ctx = indexing_op.getContext();
-    int num_dims = indexing_op.getAffineMap().getNumDims();
-    int num_syms = indexing_op.getAffineMap().getNumSymbols();
-    mlir::DenseMap<Value, AffineExpr> operand_exprs;
-    for (auto& operand : indexing_op->getOpOperands()) {
-      int operand_number = operand.getOperandNumber();
-      operand_exprs[operand.get()] =
-          operand_number < num_dims
-              ? getAffineDimExpr(operand_number, ctx)
-              : getAffineSymbolExpr(operand_number - num_dims, ctx);
-    }
-
-    auto this_map = indexing_op.getIndexingMap();
-
-    SmallVector<Value> added_dim_args;
-    SmallVector<Value> added_sym_args;
-    auto new_dim_vars = this_map.GetDimVars();
-    auto new_sym_vars = this_map.GetRangeVars();
-
-    mlir::DenseMap<AffineExpr, AffineExpr> replacements;
-    for (auto& operand : indexing_op->getOpOperands()) {
-      if (auto producer = operand.get().getDefiningOp<ApplyIndexingOp>()) {
-        auto producer_map = producer.getIndexingMap();
-        int producer_result_id =
-            mlir::cast<mlir::OpResult>(operand.get()).getResultNumber();
-        int num_producer_dims = producer.getAffineMap().getNumDims();
-        SmallVector<AffineExpr> producer_dim_replacements;
-        SmallVector<AffineExpr> producer_sym_replacements;
-        for (auto& producer_operand : producer->getOpOperands()) {
-          int producer_operand_number = producer_operand.getOperandNumber();
-          bool is_dim = producer_operand_number < num_producer_dims;
-          auto& replacement_expr = operand_exprs[producer_operand.get()];
-          if (!replacement_expr) {
-            if (is_dim) {
-              int dim_num = producer_operand_number;
-              replacement_expr =
-                  getAffineDimExpr(num_dims + added_dim_args.size(), ctx);
-              added_dim_args.push_back(producer_operand.get());
-              new_dim_vars.push_back(producer_map.GetDimVars(dim_num));
-            } else {
-              int sym_num = producer_operand_number -
-                            producer.getAffineMap().getNumDims();
-              replacement_expr =
-                  getAffineSymbolExpr(num_syms + added_sym_args.size(), ctx);
-              added_sym_args.push_back(producer_operand.get());
-              new_sym_vars.push_back(producer_map.GetRangeVar(sym_num));
-            }
-          }
-
-          if (is_dim) {
-            producer_dim_replacements.push_back(replacement_expr);
-          } else {
-            producer_sym_replacements.push_back(replacement_expr);
-          }
-        }
-
-        replacements[operand_exprs[operand.get()]] =
-            producer.getAffineMap()
-                .getResult(producer_result_id)
-                .replaceDimsAndSymbols(producer_dim_replacements,
-                                       producer_sym_replacements);
-      }
-    }
-
-    if (replacements.empty()) {
-      return rewriter.notifyMatchFailure(indexing_op,
-                                         "No apply_indexing sequences found");
-    }
-
-    int new_num_operands = indexing_op->getNumOperands() +
-                           added_dim_args.size() + added_sym_args.size();
-    auto new_affine_map = indexing_op.getAffineMap().replace(
-        replacements, num_dims + added_dim_args.size(),
-        num_syms + added_sym_args.size());
-    IndexingMap new_indexing_map(new_affine_map, new_dim_vars, new_sym_vars,
-                                 /*rt_vars=*/{});
-    if (!new_indexing_map.Simplify()) {
-      return rewriter.notifyMatchFailure(
-          indexing_op, "Folded indexing map was not simplified");
-    }
-    SmallVector<Value> new_operands;
-    new_operands.reserve(new_num_operands);
-
-    auto begin = indexing_op.getOperands().begin();
-    new_operands.append(begin, begin + num_dims);
-    new_operands.append(added_dim_args);
-    new_operands.append(begin + num_dims, begin + num_dims + num_syms);
-    new_operands.append(added_sym_args);
-
-    rewriter.replaceOpWithNewOp<ApplyIndexingOp>(indexing_op, new_operands,
-                                                 new_indexing_map);
-    return success();
-  }
-};
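
A hedged sketch of the producer/consumer folding above (hypothetical IR): when
an operand is itself produced by an `apply_indexing`, the two affine maps are
composed into a single op, provided the composed map simplifies.

```mlir
// Before: %lin is produced by another apply_indexing.
#producer = affine_map<(d0, d1) -> (d0 * 8 + d1)>
#consumer = affine_map<(d0) -> (d0 floordiv 8, d0 mod 8)>
%lin   = xla_gpu.apply_indexing #producer (%i in [0, 3], %j in [0, 7])
%a, %b = xla_gpu.apply_indexing #consumer (%lin in [0, 31])

// After (together with the unused-operand cleanup above): the composed map
// simplifies to a pass-through of %i and %j.
#composed = affine_map<(d0, d1) -> (d0, d1)>
%a, %b = xla_gpu.apply_indexing #composed (%i in [0, 3], %j in [0, 7])
```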
-
-// Folds constants into the indexing map.
-struct FoldApplyIndexingOperands
-    : public mlir::OpRewritePattern<ApplyIndexingOp> {
-  using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
-                                PatternRewriter& rewriter) const override {
-    AffineMap affine_map = indexing_op.getAffineMap();
-
-    MLIRContext* ctx = affine_map.getContext();
-    unsigned num_operands = indexing_op->getNumOperands();
-    unsigned num_dims = affine_map.getNumDims();
-    unsigned num_symbols = affine_map.getNumSymbols();
-
-    SmallVector<std::optional<int64_t>> constant_values(num_operands,
-                                                        std::nullopt);
-    int num_constants = 0;
-    SmallVector<int64_t> dim_id_map(num_dims, -1);
-    SmallVector<int64_t> symbol_id_map(num_symbols, -1);
-    for (auto& operand : indexing_op->getOpOperands()) {
-      if (auto constant =
-              operand.get().getDefiningOp<arith::ConstantIndexOp>()) {
-        constant_values[operand.getOperandNumber()] = constant.value();
-        ++num_constants;
-      }
-    }
-    if (num_constants == 0) {
-      return rewriter.notifyMatchFailure(indexing_op,
-                                         "No constant operands found");
-    }
-    SmallVector<AffineExpr, 2> dim_replacements, symbol_replacements;
-    dim_replacements.reserve(num_dims);
-    symbol_replacements.reserve(num_symbols);
-
-    unsigned new_num_operands = indexing_op->getNumOperands() - num_constants;
-    SmallVector<Value, 4> new_operands;
-    new_operands.reserve(new_num_operands);
-    SmallVector<int64_t, 4> new_lbs, new_ubs;
-    new_lbs.reserve(new_num_operands);
-    new_ubs.reserve(new_num_operands);
-
-    unsigned new_num_dims = 0;
-    unsigned new_num_symbols = 0;
-    for (auto [operand, constant_value, lb, ub] : llvm::zip(
-             indexing_op->getOpOperands(), constant_values,
-             indexing_op.getLowerBounds(), indexing_op.getUpperBounds())) {
-      unsigned operand_id = operand.getOperandNumber();
-      if (constant_value.has_value()) {
-        if (operand_id < num_dims) {
-          dim_replacements.push_back(
-              getAffineConstantExpr(*constant_value, ctx));
-        } else {
-          symbol_replacements.push_back(
-              getAffineConstantExpr(*constant_value, ctx));
-        }
-      } else {
-        if (operand_id < num_dims) {
-          dim_replacements.push_back(getAffineDimExpr(new_num_dims++, ctx));
-        } else {
-          symbol_replacements.push_back(
-              getAffineSymbolExpr(new_num_symbols++, ctx));
-        }
-        new_operands.push_back(operand.get());
-        new_lbs.push_back(lb);
-        new_ubs.push_back(ub);
-      }
-    }
-    rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
-        indexing_op, new_operands,
-        affine_map.replaceDimsAndSymbols(dim_replacements, symbol_replacements,
-                                         new_num_dims, new_num_symbols),
-        new_lbs, new_ubs);
-    return success();
-  }
-};
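
A hedged before/after sketch of this pattern (hypothetical IR): a constant
operand is baked into the affine map and removed from the operand list.

```mlir
// Before: %c4 is a constant operand feeding dimension d0.
%c4 = arith.constant 4 : index
#map = affine_map<(d0, d1) -> (d0 * 32 + d1)>
%r = xla_gpu.apply_indexing #map (%c4 in [0, 7], %i in [0, 31])

// After: the constant is folded into the map and its operand is dropped.
#folded = affine_map<(d0) -> (d0 + 128)>
%r = xla_gpu.apply_indexing #folded (%i in [0, 31])
```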
-
-// Folds constant and dim/symbol expression results.
-struct FoldApplyIndexingResults
-    : public mlir::OpRewritePattern<ApplyIndexingOp> {
-  using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
-                                PatternRewriter& rewriter) const override {
-    mlir::Location loc = indexing_op.getLoc();
-    IndexingMap indexing_map = indexing_op.getIndexingMap();
-    if (indexing_map.IsKnownEmpty()) {
-      return rewriter.notifyMatchFailure(indexing_op,
-                                         "Domain of the indexing map is empty");
-    }
-    AffineMap* affine_map = &indexing_map.GetMutableAffineMap();
-    unsigned num_results = affine_map->getNumResults();
-    SmallVector<AffineExpr, 4> new_exprs;
-    new_exprs.reserve(num_results);
-    SmallVector<Value, 4> new_values;
-    new_values.reserve(num_results);
-    for (mlir::OpResult opresult : indexing_op->getOpResults()) {
-      if (opresult.use_empty()) {
-        new_values.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
-        continue;
-      }
-
-      unsigned id = opresult.getResultNumber();
-      AffineExpr result_expr = affine_map->getResult(id);
-      if (auto const_expr =
-              mlir::dyn_cast<mlir::AffineConstantExpr>(result_expr)) {
-        new_values.push_back(rewriter.create<arith::ConstantIndexOp>(
-            loc, const_expr.getValue()));
-        continue;
-      }
-      if (auto dim_expr = mlir::dyn_cast<mlir::AffineDimExpr>(result_expr)) {
-        new_values.push_back(indexing_op.getOperand(dim_expr.getPosition()));
-        continue;
-      }
-      if (auto symbol_expr =
-              mlir::dyn_cast<mlir::AffineSymbolExpr>(result_expr)) {
-        new_values.push_back(indexing_op.getOperand(
-            indexing_map.GetDimVarsCount() + symbol_expr.getPosition()));
-        continue;
-      }
-      new_exprs.push_back(result_expr);
-      new_values.push_back(Value{});
-    }
-    if (new_exprs.size() == num_results) {
-      return rewriter.notifyMatchFailure(
-          indexing_op, "No constant or dim/symbol expression found");
-    }
-    *affine_map =
-        AffineMap::get(affine_map->getNumDims(), affine_map->getNumSymbols(),
-                       new_exprs, affine_map->getContext());
-    auto new_indexing_op = rewriter.create<ApplyIndexingOp>(
-        loc, indexing_op.getOperands(), indexing_map);
-    for (int new_result_id = 0, new_indexing_op_result_id = 0;
-         new_result_id < new_values.size(); ++new_result_id) {
-      auto& new_value = new_values[new_result_id];
-      if (new_value) continue;
-      new_value = new_indexing_op.getResult(new_indexing_op_result_id++);
-    }
-    rewriter.replaceOp(indexing_op, new_values);
-    return success();
-  }
-};
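
A hedged sketch of this pattern (hypothetical IR): results that are plain
constants or plain dimension/symbol expressions are served by `arith.constant`
ops or by the matching operand, and only the remaining results stay on the op.

```mlir
// Before: result #0 is a constant and result #1 is just d0.
#map = affine_map<(d0) -> (42, d0, d0 floordiv 8)>
%a, %b, %c = xla_gpu.apply_indexing #map (%i in [0, 63])

// After: %a becomes a constant, uses of %b are replaced by %i directly, and
// only the non-trivial result is still computed by the op.
%a = arith.constant 42 : index
#rest = affine_map<(d0) -> (d0 floordiv 8)>
%c = xla_gpu.apply_indexing #rest (%i in [0, 63])
```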
-
-}  // namespace
-
-void ApplyIndexingOp::getCanonicalizationPatterns(
-    mlir::RewritePatternSet& results, MLIRContext* context) {
-  results.add<FoldApplyIndexingOperands, FoldApplyIndexingResults,
-              SimplifyIndexingMap, FoldApplyIndexingSequence>(context);
-}
-
-mlir::LogicalResult ApplyIndexingOp::fold(
-    FoldAdaptor adaptor, llvm::SmallVectorImpl<mlir::OpFoldResult>& results) {
-  auto map = getAffineMap();
-  for (auto expr : map.getResults()) {
-    if (auto dim = mlir::dyn_cast<mlir::AffineDimExpr>(expr)) {
-      results.push_back(getOperand(dim.getPosition()));
-    } else if (auto sym = mlir::dyn_cast<mlir::AffineSymbolExpr>(expr)) {
-      results.push_back(getOperand(map.getNumDims() + sym.getPosition()));
-    } else {
-      results.clear();
-      return failure();
-    }
-  }
-  return success();
-}
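
The folder above fires only when every result of the map is a bare dimension or
symbol expression, i.e. the op is a pure pass-through of its operands. A hedged
sketch (hypothetical IR):

```mlir
// Folds away completely: each result is just one of the operands, so uses of
// %a and %b end up referring to %i and %j directly.
#map = affine_map<(d0)[s0] -> (d0, s0)>
%a, %b = xla_gpu.apply_indexing #map (%i in [0, 15])[%j in [0, 3]]
```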
-
-//===----------------------------------------------------------------------===//
-// AtomicRMWOp
-//===----------------------------------------------------------------------===//
-
-void AtomicRMWOp::getAsmResultNames(
-    llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
-  setNameFn(getResult(), "atomic_rmw");
-}
-
-void AtomicRMWOp::build(OpBuilder& builder, OperationState& result,
-                        Value tensor, ValueRange ivs) {
-  OpBuilder::InsertionGuard g(builder);
-  result.addOperands(tensor);
-  result.addOperands(ivs);
-  result.addTypes(tensor.getType());
-
-  auto tensor_type = llvm::cast<RankedTensorType>(tensor.getType());
-  Region* body = result.addRegion();
-  builder.createBlock(body);
-  body->addArgument(tensor_type.getElementType(), tensor.getLoc());
-}
-
-mlir::OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
-  auto* body = getBody();
-  if (&body->front() == body->getTerminator() &&
-      body->front().getOperand(0) == body->getArgument(0)) {
-    return getOperand(0);
-  }
-  return {};
-}
-
-//===----------------------------------------------------------------------===//
-// PureCallOp
-//===----------------------------------------------------------------------===//
-
-void PureCallOp::getAsmResultNames(
-    llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
-  for (auto result : getResults()) {
-    setNameFn(result, "pure_call");
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// SyncThreadsOp
-//===----------------------------------------------------------------------===//
-
-void SyncThreadsOp::getAsmResultNames(
-    llvm::function_ref<void(mlir::Value, mlir::StringRef)> setNameFn) {
-  for (auto result : getResults()) {
-    setNameFn(result, "synced_tensor");
-  }
-}
-
-}  // namespace gpu
-}  // namespace xla
-
-#define GET_OP_CLASSES
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc"
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h
deleted file mode 100644
index f43786f..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h
+++ /dev/null
@@ -1,39 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_
-
-#include "mlir/Bytecode/BytecodeOpInterface.h"  // IWYU pragma: keep
-#include "mlir/Dialect/Func/IR/FuncOps.h"  // IWYU pragma: keep
-#include "mlir/IR/Attributes.h"  // IWYU pragma: keep
-#include "mlir/IR/BuiltinTypes.h"  // IWYU pragma: keep
-#include "mlir/IR/Dialect.h"  // IWYU pragma: keep
-#include "mlir/IR/MLIRContext.h"  // IWYU pragma: keep
-#include "mlir/IR/OpDefinition.h"  // IWYU pragma: keep
-#include "mlir/IR/OpImplementation.h"  // IWYU pragma: keep
-#include "mlir/Interfaces/CallInterfaces.h"  // IWYU pragma: keep
-#include "mlir/Interfaces/InferTypeOpInterface.h"  // IWYU pragma: keep
-#include "mlir/Interfaces/SideEffectInterfaces.h"  // IWYU pragma: keep
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h"  // IWYU pragma: keep
-
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.h.inc"
-#define GET_OP_CLASSES
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h.inc"
-#undef GET_OP_CLASSES
-#define GET_ATTRDEF_CLASSES
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h.inc"
-#undef GET_ATTRDEF_CLASSES
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td b/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td
deleted file mode 100644
index c05f843..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td
+++ /dev/null
@@ -1,275 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_OPS
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_OPS
-
-include "mlir/IR/AttrTypeBase.td"
-include "mlir/IR/OpBase.td"
-include "mlir/IR/SymbolInterfaces.td"
-include "mlir/IR/OpAsmInterface.td"
-include "mlir/Interfaces/CallInterfaces.td"
-include "mlir/Interfaces/InferTypeOpInterface.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td"
-
-class XLAGPU_Op<string mnemonic, list<Trait> traits = []> :
-      Op<XlaGpuDialect, mnemonic, traits> {
-}
-
-def XLAGPU_AllocateSharedOp : XLAGPU_Op<"allocate_shared", [
-      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
-  ]> {
-  let summary = "Allocates a shared memory tile.";
-
-  let description = [{
-    Allocates a shared memory tensor. The tensor is shared among all threads in
-    a block.
-
-    ```mlir
-    %shared = xla_gpu.allocate_shared : tensor<32x32xf32>
-    ```
-  }];
-
-  let results = (outs AnyStaticShapeTensor:$result);
-
-  let assemblyFormat = "attr-dict `:` type($result)";
-}
-
-def XLAGPU_SyncThreadsOp : XLAGPU_Op<"sync_threads", [
-   TypesMatchWith<"result types match operand types",
-                   "operands", "results", "$_self">,
-   DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
-  ]> {
-  let summary = "Synchronizes threads.";
-
-  let description = [{
-    Synchronizes threads, taking any number of distributed tensors and returning
-    the synchronized state.
-  }];
-
-  let arguments = (ins Variadic<AnyRankedTensor>:$operands);
-  let results = (outs Variadic<AnyRankedTensor>:$results);
-
-  let assemblyFormat = "operands attr-dict `:` type($operands)";
-}
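
A minimal usage sketch derived from the assembly format above (value names are
hypothetical):

```mlir
%synced = xla_gpu.sync_threads %tile : tensor<32x32xf32>
```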
-
-def XLAGPU_AtomicRMWOp : XLAGPU_Op<"atomic_rmw",
-    [Pure,
-     TypesMatchWith<"result type matches type of input",
-                    "input", "result", "$_self">,
-     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
-  ]> {
-  let summary = "Atomically updates an element of a tensor.";
-
-  let description = [{
-    Reads an element from a tensor, computes the updated value for it, and
-    writes back the result.
-  }];
-
-  let arguments = (ins AnyRankedTensor:$input, Variadic<Index>:$indices);
-  let results = (outs AnyRankedTensor:$result);
-  // The region takes the current value in the tensor as an argument and yields
-  // the updated value.
-  let regions = (region SizedRegion<1>:$computation);
-
-  let skipDefaultBuilders = 1;
-  let builders = [OpBuilder<(ins "mlir::Value":$memref, "mlir::ValueRange":$ivs)>];
-
-  let extraClassDeclaration = [{
-    mlir::Block* getBody() { return &getComputation().front(); }
-    mlir::OpBuilder getBodyBuilder() {
-      return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
-    }
-    // The value stored in tensor[ivs].
-    mlir::Value getCurrentValue() {
-      return getRegion().getArgument(0);
-    }
-  }];
-  let hasFolder = 1;
-
-  let assemblyFormat = [{
-    $input `[` $indices `]` `:` type($input) $computation attr-dict
-  }];
-}
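
A minimal usage sketch derived from the assembly format above (value names are
hypothetical); the region receives the current element value and yields the
updated one:

```mlir
%updated = xla_gpu.atomic_rmw %tensor[%i, %j] : tensor<16x16xf32> {
^bb0(%current: f32):
  %sum = arith.addf %current, %delta : f32
  xla_gpu.yield %sum : f32
}
```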
-
-def XLAGPU_YieldOp : XLAGPU_Op<"yield",
-    [HasParent<"::xla::gpu::AtomicRMWOp">, Terminator]> {
-  let summary = "Terminator for atomic_rmw ops.";
-  let arguments = (ins AnyType:$result);
-
-  let assemblyFormat = "$result attr-dict `:` type($result)";
-}
-
-def XLAGPU_PureCallOp : XLAGPU_Op<"pure_call",
-    [Pure, CallOpInterface, DeclareOpInterfaceMethods<SymbolUserOpInterface>,
-     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
-  ]> {
-  let summary = "Function call without side effects.";
-  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
-  let results = (outs Variadic<AnyType>);
-  let builders = [
-    OpBuilder<(ins "mlir::func::FuncOp":$callee, CArg<"mlir::ValueRange", "{}">:$operands), [{
-      $_state.addOperands(operands);
-      $_state.addAttribute("callee", mlir::SymbolRefAttr::get(callee));
-      $_state.addTypes(callee.getFunctionType().getResults());
-    }]>];
-  let assemblyFormat = [{
-    $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
-  }];
-
-  let extraClassDeclaration = [{
-    operand_range getArgOperands() {
-      return getOperands();
-    }
-
-    mlir::MutableOperandRange getArgOperandsMutable() {
-      return getOperandsMutable();
-    }
-
-    mlir::CallInterfaceCallable getCallableForCallee() {
-      return (*this)->getAttrOfType<mlir::SymbolRefAttr>("callee");
-    }
-
-    void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
-      (*this)->setAttr("callee", callee.get<mlir::SymbolRefAttr>());
-    }
-  }];
-}
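
A minimal usage sketch derived from the assembly format above (@add is a
hypothetical function):

```mlir
%sum = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> f32
```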
-
-def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce",
-    [Pure,
-     TypesMatchWith<"result type matches type of operands",
-                     "operands", "results", "$_self">]> {
-  let summary = "Performs a full warp shuffle and reduces the values";
-  let description = [{
-    This op performs a full warp shuffle and reduces the results using the given
-    function. The function is invoked with the operands from the low lanes,
-    followed by the operands from the high lanes. For example:
-
-    ```
-    shuffle_reduce @argmax(%value, %idx) : (f32, index)
-    ```
-
-    This will perform shuffles with distances 16, 8, 4, 2 and 1, and will
-    invoke @argmax five times. The first invocation will be
-
-    ```
-    @argmax(%value[i], %idx[i], %value[16+i], %idx[16+i])
-    ```
-  }];
-  let builders = [
-    OpBuilder<(ins "mlir::func::FuncOp":$reducer, "mlir::ValueRange":$operands, "int64_t":$max_distance), [{
-      $_state.addOperands(operands);
-      $_state.addAttribute("reducer", mlir::SymbolRefAttr::get(reducer));
-      $_state.addAttribute("max_distance",
-          mlir::IntegerAttr::get(
-              mlir::IntegerType::get(reducer.getContext(), 64),
-              max_distance));
-      $_state.addTypes(reducer.getFunctionType().getResults());
-    }]>];
-  let arguments = (ins FlatSymbolRefAttr:$reducer,
-                   Variadic<AnyType>:$operands,
-                   I64Attr:$max_distance);
-  let results = (outs Variadic<AnyType>:$results);
-
-  let assemblyFormat = [{
-    $reducer `(` $operands `)` `to` $max_distance attr-dict `:` type($operands)
-  }];
-}
-
-def XLAGPU_PredicatedInsertOp : XLAGPU_Op<"predicated_insert",
-    [Pure,
-     TypesMatchWith<"result type matches type of operands",
-                    "dest", "result", "$_self">,
-     TypesMatchWith<"value type matches element type of dest",
-                    "dest", "value",
-                    "::llvm::cast<mlir::TensorType>($_self).getElementType()">]> {
-  let summary = "Inserts a value into a tensor if a condition holds";
-  let arguments = (ins I1:$condition, AnyType:$value,
-    AnyStaticShapeTensor:$dest, Variadic<Index>:$indices);
-  let results = (outs AnyStaticShapeTensor:$result);
-
-  let assemblyFormat = [{
-     $value `into` $dest `[` $indices `]` `if` $condition attr-dict `:` type($dest)
-  }];
-}
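
A minimal usage sketch derived from the assembly format above (value names are
hypothetical); per the summary, the insert only happens when %cond holds:

```mlir
%new = xla_gpu.predicated_insert %val into %dest[%i] if %cond : tensor<32xf32>
```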
-
-def XLAGPU_PredicatedExtractOp : XLAGPU_Op<"predicated_extract",
-    [Pure,
-     TypesMatchWith<"fallback type matches element type of src",
-                    "src", "fallback",
-                    "::llvm::cast<mlir::TensorType>($_self).getElementType()">,
-     TypesMatchWith<"result type matches element type of src",
-                    "src", "result",
-                    "::llvm::cast<mlir::TensorType>($_self).getElementType()">]> {
-  let summary = "Extracts a value from a tensor if a condition holds, otherwise returns a fallback";
-  let arguments = (ins I1:$condition, AnyType:$fallback,
-    AnyStaticShapeTensor:$src, Variadic<Index>:$indices);
-  let results = (outs AnyType:$result);
-
-  let assemblyFormat = [{
-    $src `[` $indices `]` `if` $condition `else` $fallback attr-dict `:` type($src)
-  }];
-}
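
A minimal usage sketch derived from the assembly format above (value names are
hypothetical); %fallback is returned when %cond does not hold:

```mlir
%elem = xla_gpu.predicated_extract %src[%i] if %cond else %fallback : tensor<32xf32>
```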
-
-def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> {
-  let summary = "Applies indexing map to a list of SSA values";
-  let description = [{
-    The `apply_indexing` operation applies an affine map to a list of SSA
-    values, yielding one SSA value per result of the map. The number of
-    dimension and symbol arguments must be equal to the respective number of
-    dimensional and symbolic inputs in the affine map. The affine map can be
-    multi-dimensional, in which case the operation returns one value per map
-    result. The operands and results must all have `index` type.
-
-    Example:
-
-    ```mlir
-    #map = affine_map<(d0, d1)[s0] -> (d0 floordiv 8 + d1 floordiv 128, s0)>
-    %results:2 = xla_gpu.apply_indexing #map (%0 in [0, 10], %1 in [0, 11])[%2 in [11, 32]]
-    ```
-  }];
-  let arguments = (ins Variadic<Index>:$operands,
-                       AffineMapAttr:$map,
-                       DenseI64ArrayAttr:$lower_bounds,
-                       DenseI64ArrayAttr:$upper_bounds);
-  let results = (outs Variadic<Index>);
-
-  let builders = [
-    OpBuilder<(ins "mlir::ValueRange":$dims, "mlir::ValueRange":$symbols,
-                   "const IndexingMap&":$indexing_map)>,
-    OpBuilder<(ins "mlir::ValueRange":$operands,
-                   "const IndexingMap&":$indexing_map)>,
-    OpBuilder<(ins "mlir::ValueRange":$operands, "mlir::AffineMap":$affine_map,
-               "llvm::ArrayRef<DimVar>":$dim_vars,
-               "llvm::ArrayRef<RangeVar>":$range_vars)>,
-    OpBuilder<(ins "mlir::ValueRange":$operands,
-                   "mlir::AffineMap":$affine_map,
-                   "llvm::ArrayRef<int64_t>":$lower_bounds,
-                   "llvm::ArrayRef<int64_t>":$upper_bounds)>,
-  ];
-  let extraClassDeclaration = [{
-    // Returns the indexing map constructed from affine_map and the bounds.
-    xla::gpu::IndexingMap getIndexingMap();
-    // Extracts the affine map from the attribute.
-    mlir::AffineMap getAffineMap() { return getMapAttr().getAffineMap(); }
-  }];
-  let hasCustomAssemblyFormat = 1;
-  let hasVerifier = 1;
-  let hasCanonicalizer = 1;
-  let hasFolder = 1;
-}
-
-#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_OPS
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc b/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc
deleted file mode 100644
index 929ee0e..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc
+++ /dev/null
@@ -1,1095 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <cassert>
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <string>
-#include <tuple>
-#include <utility>
-
-#include "absl/algorithm/container.h"
-#include "absl/log/check.h"
-#include "absl/strings/str_cat.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/LogicalResult.h"
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeRange.h"
-#include "mlir/IR/Types.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Interfaces/DataLayoutInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-#define GEN_PASS_DECL_LOWERTENSORSPASS
-#define GEN_PASS_DEF_LOWERTENSORSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-using mlir::failure;
-using mlir::Location;
-using mlir::LogicalResult;
-using mlir::MLIRContext;
-using mlir::OpBuilder;
-using mlir::Operation;
-using mlir::success;
-using mlir::Type;
-using mlir::TypedValue;
-using mlir::TypeRange;
-using mlir::Value;
-using mlir::ValueRange;
-
-namespace arith = ::mlir::arith;
-namespace scf = ::mlir::scf;
-namespace ml = ::mlir::LLVM;
-
-Value GetDestinationBuffer(Value dest) {
-  while (dest.getDefiningOp()) {
-    int result_number = mlir::cast<mlir::OpResult>(dest).getResultNumber();
-    if (auto insert = dest.getDefiningOp<mlir::tensor::InsertOp>()) {
-      dest = insert.getDest();
-    } else if (auto scf_if = dest.getDefiningOp<scf::IfOp>()) {
-      // Pick one of the branches; both are required to yield the same buffers.
-      dest = scf_if.getThenRegion().front().getTerminator()->getOperand(
-          result_number);
-    } else if (auto scf_for = dest.getDefiningOp<scf::ForOp>()) {
-      dest = scf_for.getInitArgs()[result_number];
-    } else if (dest.getDefiningOp<mlir::UnrealizedConversionCastOp>() ||
-               dest.getDefiningOp<AllocateSharedOp>()) {
-      break;
-    } else if (auto transfer_write =
-                   dest.getDefiningOp<mlir::vector::TransferWriteOp>()) {
-      dest = transfer_write.getSource();
-    } else {
-      dest.getDefiningOp()->emitOpError("unsupported dest type");
-      return nullptr;
-    }
-  }
-  return dest;
-}
-
-template <typename Op>
-bool IsSupportedTransfer(Op op) {
-  return !absl::c_linear_search(op.getInBoundsValues(), false) &&
-         op.getVectorType().getRank() == 1 && !op.getMask() &&
-         op.getPermutationMap().isMinorIdentity();
-}
-
-struct RewriteFunctionSignatures : mlir::OpRewritePattern<mlir::func::FuncOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(
-      mlir::func::FuncOp op, mlir::PatternRewriter& rewriter) const override {
-    auto is_tensor = [](Type ty) {
-      return mlir::isa<mlir::RankedTensorType>(ty);
-    };
-    if (!llvm::any_of(op.getFunctionType().getInputs(), is_tensor)) {
-      return rewriter.notifyMatchFailure(op,
-                                         "the function has no input tensors");
-    }
-
-    bool some_tensor_result =
-        llvm::any_of(op.getFunctionType().getResults(), is_tensor);
-    bool all_tensor_results =
-        llvm::all_of(op.getFunctionType().getResults(), is_tensor);
-    if (some_tensor_result && !all_tensor_results) {
-      op->emitOpError("function has a mix of tensor and non-tensor results");
-      return failure();
-    }
-
-    TypeRange new_results = op.getFunctionType().getResults();
-    if (some_tensor_result) {
-      new_results = {};
-      auto terminator = op.getFunctionBody().front().getTerminator();
-      rewriter.setInsertionPoint(terminator);
-      rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(terminator);
-    }
-
-    llvm::SmallVector<Type> new_operands(op.getFunctionType().getInputs());
-    for (auto&& [index, operand] : llvm::enumerate(new_operands)) {
-      if (is_tensor(operand)) {
-        rewriter.setInsertionPointToStart(&op.getBody().front());
-        auto cast = rewriter.create<mlir::UnrealizedConversionCastOp>(
-            op.getLoc(), operand, op.getArgument(index));
-        op.getArgument(index).replaceAllUsesExcept(cast.getResult(0), cast);
-        operand = mlir::LLVM::LLVMPointerType::get(op.getContext());
-      }
-    }
-
-    op.setFunctionType(rewriter.getFunctionType(new_operands, new_results));
-    auto& entry = op->getRegion(0).front();
-    for (auto [arg, arg_type] : llvm::zip(entry.getArguments(), new_operands)) {
-      arg.setType(arg_type);
-    }
-
-    return success();
-  }
-};
-
-Value GetLinearIndex(TypedValue<mlir::RankedTensorType> tensor,
-                     ValueRange indices, mlir::PatternRewriter& rewriter) {
-  auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape());
-  if (auto encoding = tensor.getType().getEncoding()) {
-    *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector(
-        mlir::cast<mlir::DenseElementsAttr>(encoding).getValues<int64_t>()));
-  }
-  auto linear_shape =
-      ShapeUtil::MakeShape(U8, {ShapeUtil::ElementsIn(byte_shape)});
-  auto linearize_map =
-      GetBitcastMap(byte_shape, linear_shape, tensor.getContext());
-  mlir::SmallVector<Value> result;
-  rewriter.createOrFold<ApplyIndexingOp>(result, tensor.getLoc(), indices,
-                                         ValueRange{}, linearize_map);
-  CHECK_EQ(result.size(), 1);
-  auto index = result.front();
-  auto index_ty = rewriter.getIntegerType(
-      mlir::DataLayout::closest(rewriter.getInsertionBlock()->getParentOp())
-          .getTypeSizeInBits(index.getType()));
-  return rewriter.create<mlir::arith::IndexCastUIOp>(tensor.getLoc(), index_ty,
-                                                     index);
-}
-
-std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
-                                             mlir::ImplicitLocOpBuilder& b) {
-  Value one = b.create<mlir::arith::ConstantIntOp>(1, linear_index.getType());
-  Value is_low_nibble = b.create<mlir::arith::CmpIOp>(
-      mlir::arith::CmpIPredicate::eq, one,
-      b.create<mlir::arith::AndIOp>(linear_index, one));
-  Value i8_index = b.create<mlir::arith::ShRUIOp>(linear_index, one);
-  return {i8_index, is_low_nibble};
-}
-
-mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
-                            Value linear_index, mlir::PatternRewriter& rewriter,
-                            Type element_type = nullptr) {
-  if (!element_type) {
-    element_type = tensor.getType().getElementType();
-  }
-  auto ptr = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
-  auto tensor_ptr = rewriter
-                        .create<mlir::UnrealizedConversionCastOp>(
-                            tensor.getLoc(), ptr, tensor)
-                        .getResult(0);
-  mlir::LLVMTypeConverter converter(rewriter.getContext());
-  auto llvm_element_type = converter.convertType(element_type);
-  auto gep = rewriter.create<mlir::LLVM::GEPOp>(
-      tensor.getLoc(), ptr, llvm_element_type, tensor_ptr, linear_index);
-  gep.setInbounds(true);
-  return gep;
-}
-
-mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
-                            ValueRange indices,
-                            mlir::PatternRewriter& rewriter) {
-  return CreateGep(tensor, GetLinearIndex(tensor, indices, rewriter), rewriter);
-}
-
-struct RewriteTensorExtract : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(
-      mlir::tensor::ExtractOp op,
-      mlir::PatternRewriter& rewriter) const override {
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto linear_index =
-        GetLinearIndex(op.getTensor(), op.getIndices(), rewriter);
-    Type element_type = op.getTensor().getType().getElementType();
-    Value is_low_nibble = nullptr;
-    if (element_type == rewriter.getI4Type()) {
-      element_type = rewriter.getI8Type();
-      std::tie(linear_index, is_low_nibble) =
-          GetI4IndexAndNibble(linear_index, b);
-    }
-
-    auto gep = CreateGep(op.getTensor(), linear_index, rewriter, element_type);
-    auto load =
-        rewriter
-            .create<mlir::LLVM::LoadOp>(gep.getLoc(), gep.getElemType(), gep)
-            .getResult();
-
-    if (is_low_nibble) {
-      auto high_value = b.create<mlir::arith::ShRUIOp>(
-          load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
-      load = b.create<mlir::arith::TruncIOp>(
-          op.getType(),
-          b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
-    }
-
-    rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
-        op, op.getType(), load);
-    return success();
-  }
-};
-
-// Swaps pairs of values in the vector: [0, 1, 2, 3] -> [1, 0, 3, 2].
-Value PermutePairsInVector(Value vector, mlir::ImplicitLocOpBuilder& b) {
-  // There is a `vector.extract_strided_slice` op that would be useful here, but
-  // it actually requires the strides to be 1.
-  auto ty = mlir::cast<mlir::VectorType>(vector.getType());
-  int size = ty.getNumElements();
-  Value result = vector;
-  for (int i = 0; i < size; i += 2) {
-    auto v0 = b.create<mlir::vector::ExtractOp>(vector, i);
-    auto v1 = b.create<mlir::vector::ExtractOp>(vector, i + 1);
-    result = b.create<mlir::vector::InsertOp>(v1, result, i);
-    result = b.create<mlir::vector::InsertOp>(v0, result, i + 1);
-  }
-  return result;
-}
-
-struct RewriteTransferRead
-    : mlir::OpRewritePattern<mlir::vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(
-      mlir::vector::TransferReadOp op,
-      mlir::PatternRewriter& rewriter) const override {
-    assert(IsSupportedTransfer(op));
-
-    auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
-        op.getSource());
-
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto linear_index = GetLinearIndex(source, op.getIndices(), rewriter);
-
-    mlir::VectorType vector_type = op.getVectorType();
-    if (vector_type.getElementType().isInteger(1)) {
-      vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type());
-    }
-    mlir::Type gep_element_type = vector_type.getElementType();
-    if (op.getVectorType().getElementType().isInteger(4)) {
-      linear_index = b.create<arith::ShRUIOp>(
-          linear_index,
-          b.create<arith::ConstantIntOp>(1, linear_index.getType()));
-      gep_element_type = b.getI8Type();
-    }
-    auto gep = CreateGep(source, linear_index, rewriter, gep_element_type);
-
-    mlir::LLVMTypeConverter converter(b.getContext());
-    auto llvm_vector_type = converter.convertType(vector_type);
-    auto loaded =
-        b.create<mlir::LLVM::LoadOp>(llvm_vector_type, gep).getResult();
-
-    if (source.getType().getElementType().isInteger(1)) {
-      Value zero = b.create<mlir::arith::ConstantOp>(
-          mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0)));
-      loaded = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, loaded, zero);
-    } else if (source.getType().getElementType().isInteger(4)) {
-      // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
-      // elements.
-      loaded = PermutePairsInVector(loaded, b);
-    }
-
-    rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
-        op, op.getType(), loaded);
-    return success();
-  }
-};
-
-struct RewriteTensorInsert : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(
-      mlir::tensor::InsertOp op,
-      mlir::PatternRewriter& rewriter) const override {
-    Value dest = GetDestinationBuffer(op.getDest());
-    if (!dest) {
-      return failure();
-    }
-
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto tensor_dest = mlir::cast<TypedValue<mlir::RankedTensorType>>(dest);
-    auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter);
-    auto element_type = tensor_dest.getType().getElementType();
-    Value is_low_nibble = nullptr;
-
-    if (element_type == rewriter.getI4Type()) {
-      element_type = rewriter.getI8Type();
-      std::tie(linear_index, is_low_nibble) =
-          GetI4IndexAndNibble(linear_index, b);
-    }
-
-    auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type);
-    auto scalar_value = op.getScalar();
-
-    if (is_low_nibble) {
-      Value current_value =
-          b.create<mlir::LLVM::LoadOp>(gep.getElemType(), gep);
-      auto ty = current_value.getType();
-      scalar_value = b.create<mlir::arith::ExtUIOp>(ty, scalar_value);
-      Value low_updated = b.create<mlir::arith::OrIOp>(
-          b.create<mlir::arith::AndIOp>(
-              current_value, b.create<mlir::arith::ConstantIntOp>(0xf0, ty)),
-          scalar_value);
-      Value high_updated = b.create<mlir::arith::OrIOp>(
-          b.create<mlir::arith::AndIOp>(
-              current_value, b.create<mlir::arith::ConstantIntOp>(0x0f, ty)),
-          b.create<mlir::arith::ShLIOp>(
-              scalar_value, b.create<mlir::arith::ConstantIntOp>(4, ty)));
-      scalar_value = b.create<mlir::arith::SelectOp>(is_low_nibble, low_updated,
-                                                     high_updated);
-    }
-
-    mlir::LLVMTypeConverter converter(getContext());
-    auto llvm_type = converter.convertType(scalar_value.getType());
-    scalar_value = rewriter
-                       .create<mlir::UnrealizedConversionCastOp>(
-                           gep.getLoc(), llvm_type, scalar_value)
-                       .getResult(0);
-    rewriter.create<mlir::LLVM::StoreOp>(gep.getLoc(), scalar_value, gep);
-
-    op.replaceAllUsesWith(op.getDest());
-    op.erase();
-    return success();
-  }
-};
-
-struct RewriteTransferWrite
-    : mlir::OpRewritePattern<mlir::vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(
-      mlir::vector::TransferWriteOp op,
-      mlir::PatternRewriter& rewriter) const override {
-    assert(IsSupportedTransfer(op));
-    Value dest = GetDestinationBuffer(op.getSource());
-
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto tensor_dest = mlir::cast<TypedValue<mlir::RankedTensorType>>(dest);
-    auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter);
-    auto element_type = tensor_dest.getType().getElementType();
-
-    mlir::Value vector_value = op.getVector();
-    if (op.getVectorType().getElementType().isInteger(1)) {
-      vector_value = b.create<arith::ExtUIOp>(
-          op.getVectorType().cloneWith(std::nullopt, b.getI8Type()),
-          vector_value);
-    }
-    if (op.getVectorType().getElementType().isInteger(4)) {
-      linear_index = b.create<arith::ShRUIOp>(
-          linear_index,
-          b.create<arith::ConstantIntOp>(1, linear_index.getType()));
-      element_type = rewriter.getI8Type();
-      // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
-      // elements.
-      vector_value = PermutePairsInVector(vector_value, b);
-    }
-    auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type);
-
-    mlir::LLVMTypeConverter converter(getContext());
-    auto llvm_type = converter.convertType(vector_value.getType());
-    vector_value =
-        b.create<mlir::UnrealizedConversionCastOp>(llvm_type, vector_value)
-            .getResult(0);
-    b.create<mlir::LLVM::StoreOp>(vector_value, gep);
-
-    rewriter.replaceOp(op, mlir::ValueRange{op.getSource()});
-    return success();
-  }
-};
-
-struct RewriteCall : mlir::OpRewritePattern<mlir::func::CallOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(
-      mlir::func::CallOp op, mlir::PatternRewriter& rewriter) const override {
-    if (!llvm::any_of(op->getOperandTypes(), [](Type ty) {
-          return mlir::isa<mlir::RankedTensorType>(ty);
-        })) {
-      return rewriter.notifyMatchFailure(op, "the call has no input tensors");
-    }
-
-    for (const auto&& [index, arg] : llvm::enumerate(op.getOperands())) {
-      if (mlir::isa<mlir::RankedTensorType>(arg.getType())) {
-        op.setOperand(
-            index,
-            rewriter
-                .create<mlir::UnrealizedConversionCastOp>(
-                    op.getLoc(),
-                    mlir::LLVM::LLVMPointerType::get(op.getContext()), arg)
-                .getResult(0));
-      }
-    }
-    return success();
-  }
-};
-
-mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value,
-                                    const std::string& name_prefix,
-                                    mlir::ShapedType shaped_ty,
-                                    mlir::ModuleOp module, bool is_constant,
-                                    int addr_space,
-                                    mlir::ImplicitLocOpBuilder& b) {
-  if (auto elements = mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(value)) {
-    // The lowering to LLVM only works for 1d tensors or those with trailing
-    // unit dimensions.
-    value = elements.reshape(mlir::RankedTensorType::get(
-        {elements.getNumElements()}, elements.getElementType()));
-  }
-
-  Type element_type = shaped_ty.getElementType();
-  // Needed to support complex element type.
-  mlir::LLVMTypeConverter converter(b.getContext());
-  auto llvm_element_type = converter.convertType(element_type);
-  auto array_ty = mlir::LLVM::LLVMArrayType::get(llvm_element_type,
-                                                 shaped_ty.getNumElements());
-  std::string name;
-  int index = 0;
-  do {
-    name = absl::StrCat(name_prefix, index);
-    ++index;
-  } while (module.lookupSymbol(name));
-  b.setInsertionPointToStart(module.getBody());
-  return b.create<mlir::LLVM::GlobalOp>(
-      array_ty, is_constant,
-      /*linkage=*/mlir::LLVM::Linkage::Private, name, value, /*alignment=*/0,
-      addr_space);
-}
-
-struct RewriteAllocateShared : mlir::OpRewritePattern<AllocateSharedOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(
-      AllocateSharedOp op, mlir::PatternRewriter& rewriter) const override {
-    auto module = op->getParentOfType<mlir::ModuleOp>();
-    auto shaped_ty = mlir::cast<mlir::ShapedType>(op.getResult().getType());
-    constexpr int kGPUSharedMemoryAddrSpace = 3;
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-
-    auto global =
-        CreateGlobalOp(mlir::Attribute{}, "shared_", shaped_ty, module,
-                       /*is_constant=*/false, kGPUSharedMemoryAddrSpace, b);
-
-    rewriter.setInsertionPoint(op);
-    auto addr = rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), global);
-    rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
-        op, op.getResult().getType(),
-        rewriter
-            .create<mlir::LLVM::AddrSpaceCastOp>(
-                op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()),
-                addr)
-            .getResult());
-    return success();
-  }
-};
-
-struct RewriteNonScalarConstants
-    : mlir::OpRewritePattern<mlir::arith::ConstantOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      mlir::arith::ConstantOp op,
-      mlir::PatternRewriter& rewriter) const override {
-    if (mlir::isa<mlir::VectorType>(op.getType())) {
-      return rewriter.notifyMatchFailure(op, "the op is a vector constant");
-    }
-    auto shaped_ty = mlir::dyn_cast<mlir::ShapedType>(op.getValue().getType());
-    // We only need to rewrite non-scalar constants.
-    if (!shaped_ty || shaped_ty.getNumElements() < 2) {
-      return rewriter.notifyMatchFailure(
-          op, "the op is an effective scalar constant");
-    }
-
-    constexpr int kGPUGlobalMemoryAddrSpace = 0;
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto module = op->getParentOfType<mlir::ModuleOp>();
-    auto global =
-        CreateGlobalOp(op.getValue(), "global_cst_", shaped_ty, module,
-                       /*is_constant=*/true, kGPUGlobalMemoryAddrSpace, b);
-
-    rewriter.setInsertionPoint(op);
-    auto addr = rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), global);
-    rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
-        op, op.getResult().getType(),
-        rewriter
-            .create<mlir::LLVM::AddrSpaceCastOp>(
-                op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()),
-                addr)
-            .getResult());
-    return success();
-  }
-};
-
-struct RewriteSyncThreads : mlir::OpRewritePattern<SyncThreadsOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(
-      SyncThreadsOp op, mlir::PatternRewriter& rewriter) const override {
-    rewriter.create<mlir::gpu::BarrierOp>(op.getLoc());
-    rewriter.replaceOp(op, op.getOperands());
-    return success();
-  }
-};
-
-// TODO(jreiffers): Generalize this to support index switches with some used
-// results and upstream it as a canonicalization pattern.
-struct RemoveUnusedIndexSwitchResults
-    : mlir::OpRewritePattern<scf::IndexSwitchOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(
-      scf::IndexSwitchOp op, mlir::PatternRewriter& rewriter) const override {
-    if (op->getNumResults() == 0 || !op->use_empty()) {
-      return rewriter.notifyMatchFailure(op,
-                                         "the op has no results or has users");
-    }
-
-    auto new_op = rewriter.create<scf::IndexSwitchOp>(
-        op.getLoc(), mlir::TypeRange{}, op.getArg(), op.getCases(),
-        op.getNumCases());
-    for (int i = 0; i < op->getNumRegions(); ++i) {
-      auto& old_region = op->getRegion(i);
-      auto& new_region = new_op->getRegion(i);
-      rewriter.mergeBlocks(&old_region.getBlocks().front(),
-                           &new_region.emplaceBlock());
-      auto yield_op = new_region.getBlocks().front().getTerminator();
-      rewriter.modifyOpInPlace(yield_op, [&]() { yield_op->setOperands({}); });
-    }
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
-bool IsAtomicIntegral(Type element_type) {
-  if (!element_type.isInteger()) {
-    return false;
-  }
-  unsigned element_bitwidth = element_type.getIntOrFloatBitWidth();
-  return element_bitwidth == 32 || element_bitwidth == 64;
-}
-
-Value CreateBitcast(mlir::ImplicitLocOpBuilder& b, Value value, Type ty) {
-  if (value.getType().isIntOrFloat() && ty.isIntOrFloat()) {
-    return b.create<ml::BitcastOp>(ty, value);
-  }
-
-  mlir::LLVMTypeConverter converter(b.getContext());
-  // If either type is a complex type, we need to go through an alloca, since
-  // no direct bitcast from a struct to an int is possible.
-  Type llvm_input_ty = converter.convertType(value.getType());
-  Type llvm_result_ty = converter.convertType(ty);
-  Type ptr_ty = mlir::LLVM::LLVMPointerType::get(b.getContext());
-
-  Value llvm_value =
-      b.create<mlir::UnrealizedConversionCastOp>(llvm_input_ty, value)
-          .getResult(0);
-  Value alloca = b.create<ml::AllocaOp>(
-      ptr_ty, llvm_input_ty, b.create<ml::ConstantOp>(b.getI32Type(), 1));
-  b.create<ml::StoreOp>(llvm_value, alloca);
-  auto result = b.create<ml::LoadOp>(llvm_result_ty, alloca).getResult();
-  return b.create<mlir::UnrealizedConversionCastOp>(ty, result).getResult(0);
-}
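
As a rough host-side analogue of the alloca round trip in CreateBitcast (plain C++ for illustration only, not part of the pass), the same reinterpretation of a complex value as an integer goes through memory with memcpy, since no direct bitcast from a two-field struct to an integer exists:

#include <complex>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Reinterprets a complex<float> as a 64-bit integer by going through memory,
// the plain-C++ counterpart of the store-to-alloca / load-back sequence.
uint64_t ComplexToBits(std::complex<float> value) {
  static_assert(sizeof(value) == sizeof(uint64_t), "expect 2 x 32-bit floats");
  uint64_t bits = 0;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;
}

int main() {
  std::printf("0x%016llx\n",
              static_cast<unsigned long long>(ComplexToBits({1.0f, 2.0f})));
  return 0;
}
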
-
-class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
- public:
-  RewriteAtomicRMW(mlir::MLIRContext* context, bool is_amd,
-                   const std::string& gpu_arch)
-      : mlir::OpRewritePattern<AtomicRMWOp>(context),
-        is_amd_(is_amd),
-        gpu_arch_(gpu_arch) {}
-
-  LogicalResult matchAndRewrite(
-      AtomicRMWOp op, mlir::PatternRewriter& rewriter) const override {
-    if (failed(rewriteAsDirectAtomicRMW(op, rewriter))) {
-      rewriteAsAtomicCAS(op, rewriter);
-    }
-    rewriter.replaceOp(op, op.getInput());
-    return success();
-  }
-
- private:
-  // Returns atomic op modifier and the atomic bin op kind.
-  std::optional<std::pair<Value, ml::AtomicBinOp>> GetAtomicModifierParameters(
-      AtomicRMWOp op) const {
-    Type element_type = op.getInput().getType().getElementType();
-    auto& operations = op.getBody()->getOperations();
-    auto terminator = op.getBody()->getTerminator();
-    if (operations.size() > 2) {
-      return std::nullopt;
-    }
-    // If the body contains only the terminator, then it is an atomic store.
-    if (operations.size() == 1) {
-      // TODO(b/336367145): Support complex<f32> atomic store.
-      if (element_type.isF32() || IsAtomicIntegral(element_type)) {
-        return std::make_pair(terminator->getOperand(0), ml::AtomicBinOp::xchg);
-      }
-      return std::nullopt;
-    }
-    // Match the kind of the atomic op.
-    mlir::Operation* modifier_op = &operations.front();
-    std::optional<ml::AtomicBinOp> kind =
-        llvm::TypeSwitch<Operation*, std::optional<ml::AtomicBinOp>>(
-            modifier_op)
-            // Floating-point operations.
-            .Case([](arith::AddFOp op) { return ml::AtomicBinOp::fadd; })
-            .Case([](arith::MaximumFOp op) { return ml::AtomicBinOp::fmax; })
-            .Case([](arith::MinimumFOp op) { return ml::AtomicBinOp::fmin; })
-            // Integer operations.
-            .Case([&](arith::AddIOp op) {
-              return IsAtomicIntegral(element_type)
-                         ? std::make_optional(ml::AtomicBinOp::add)
-                         : std::nullopt;
-            })
-            .Case([&](arith::MaxUIOp op) {
-              return IsAtomicIntegral(element_type)
-                         ? std::make_optional(ml::AtomicBinOp::umax)
-                         : std::nullopt;
-            })
-            .Case([&](arith::MinUIOp op) {
-              return IsAtomicIntegral(element_type)
-                         ? std::make_optional(ml::AtomicBinOp::umin)
-                         : std::nullopt;
-            })
-            .Case([&](arith::MaxSIOp op) {
-              return IsAtomicIntegral(element_type)
-                         ? std::make_optional(ml::AtomicBinOp::max)
-                         : std::nullopt;
-            })
-            .Case([&](arith::MinSIOp op) {
-              return IsAtomicIntegral(element_type)
-                         ? std::make_optional(ml::AtomicBinOp::min)
-                         : std::nullopt;
-            })
-            .Default([](Operation* op) { return std::nullopt; });
-    if (!kind.has_value()) {
-      return std::nullopt;
-    }
-    // Find the modifier arg that does not match the argument of `atomic_rmw`
-    // body.
-    Value block_arg = op.getBody()->getArgument(0);
-    Value modifier_arg = modifier_op->getOperand(0) == block_arg
-                             ? modifier_op->getOperand(1)
-                             : modifier_op->getOperand(0);
-    return std::make_pair(modifier_arg, *kind);
-  }
-
-  // Certain computations, such as floating-point addition and integer
-  // maximization, can be implemented directly with an LLVM atomic instruction.
-  // If the computation is of this kind, emits code to do that and returns
-  // success; otherwise, returns failure.
-  LogicalResult rewriteAsDirectAtomicRMW(
-      AtomicRMWOp op, mlir::PatternRewriter& rewriter) const {
-    auto modifier_parameters = GetAtomicModifierParameters(op);
-    if (!modifier_parameters.has_value()) {
-      return failure();
-    }
-    Value modifier_arg = modifier_parameters->first;
-    Type element_type = modifier_arg.getType();
-    ml::AtomicBinOp atomic_bin_op = modifier_parameters->second;
-
-    Location loc = op.getLoc();
-    llvm::StringRef sync_scope = is_amd_ ? "agent" : "";
-    Value addr = CreateGep(op.getInput(), op.getIndices(), rewriter);
-
-    switch (atomic_bin_op) {
-      case ml::AtomicBinOp::xchg: {
-        rewriter.create<ml::StoreOp>(
-            loc, modifier_arg, addr,
-            /*alignment=*/element_type.getIntOrFloatBitWidth() / 8,
-            /*isVolatile=*/false, /*isNonTemporal=*/false,
-            ml::AtomicOrdering::unordered);
-        return success();
-      }
-      case ml::AtomicBinOp::add:
-      case ml::AtomicBinOp::max:
-      case ml::AtomicBinOp::min:
-      case ml::AtomicBinOp::umax:
-      case ml::AtomicBinOp::umin: {
-        rewriter.create<ml::AtomicRMWOp>(loc, atomic_bin_op, addr, modifier_arg,
-                                         ml::AtomicOrdering::seq_cst,
-                                         sync_scope);
-        return success();
-      }
-      case ml::AtomicBinOp::fadd: {
-        // TODO(b/336367154): Introduce an atomic_rmw op with the binOp attr.
-        return is_amd_ ? emitAMDAtomicFAdd(loc, modifier_arg, addr, sync_scope,
-                                           gpu_arch_, rewriter)
-                       : emitNVidiaAtomicFAdd(loc, modifier_arg, addr,
-                                              sync_scope, gpu_arch_, rewriter);
-      }
-      case ml::AtomicBinOp::fmax: {
-        return rewriteAtomicFMaxAsIntAtomics(loc, modifier_arg, addr,
-                                             sync_scope, rewriter);
-      }
-      default:
-        return failure();
-    }
-    return success();
-  }
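
For intuition, the "direct" cases above correspond to ordinary atomic instructions. A minimal host-side C++ sketch (illustrative only; the float fetch_add requires C++20): an atomic store covers the xchg case and fetch_add covers the add cases, while max/min and anything unsupported fall back to the CAS path further below.

#include <atomic>
#include <cstdio>

int main() {
  std::atomic<int> counter{0};
  std::atomic<float> accum{0.0f};

  counter.store(42, std::memory_order_relaxed);     // xchg / plain atomic store
  counter.fetch_add(1, std::memory_order_seq_cst);  // integer atomic add
  accum.fetch_add(1.5f, std::memory_order_seq_cst); // float atomic add (C++20)

  std::printf("%d %f\n", counter.load(), accum.load());
  return 0;
}
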
-
-  LogicalResult emitNVidiaAtomicFAdd(Location loc, Value modifier_arg,
-                                     Value addr, llvm::StringRef sync_scope,
-                                     llvm::StringRef cuda_arch,
-                                     OpBuilder& b) const {
-    se::CudaComputeCapability cuda_compute_capability(cuda_arch.str());
-    Type element_type = modifier_arg.getType();
-    // "atom.add.f64 requires sm_60 or higher."
-    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
-    bool is_supported_f16_atomic =
-        element_type.isF16() &&
-        cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::VOLTA);
-    bool is_supported_bf16_atomic =
-        element_type.isBF16() &&
-        cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER);
-    bool is_supported_f64_atomic =
-        element_type.isF64() &&
-        cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::PASCAL_);
-    if (!element_type.isF32() && !is_supported_f16_atomic &&
-        !is_supported_bf16_atomic && !is_supported_f64_atomic) {
-      return failure();
-    }
-    b.create<ml::AtomicRMWOp>(loc, ml::AtomicBinOp::fadd, addr, modifier_arg,
-                              ml::AtomicOrdering::seq_cst, sync_scope);
-    return success();
-  }
-
-  LogicalResult emitAMDAtomicFAdd(Location loc, Value modifier_arg, Value addr,
-                                  llvm::StringRef sync_scope,
-                                  llvm::StringRef gcn_arch,
-                                  OpBuilder& b) const {
-    se::RocmComputeCapability rocm_compute_capability(gcn_arch.str());
-    Type element_type = modifier_arg.getType();
-    bool is_supported_f16_atomic =
-        element_type.isF16() &&
-        rocm_compute_capability.has_fp16_atomics_support();
-    if (!element_type.isF32() && !is_supported_f16_atomic) {
-      return failure();
-    }
-    constexpr int kGlobalMemory = 1;
-    constexpr int kSharedMemory = 3;
-    auto addr_type = mlir::cast<ml::LLVMPointerType>(addr.getType());
-    // Adds to shared memory are always atomic.
-    if (addr_type.getAddressSpace() != kSharedMemory) {
-      // The compiler will only generate a global_atomic_fadd if the pointer
-      // is in the global address space (1).
-      addr = b.create<ml::AddrSpaceCastOp>(
-          loc, ml::LLVMPointerType::get(b.getContext(), kGlobalMemory), addr);
-    }
-    b.create<ml::AtomicRMWOp>(loc, ml::AtomicBinOp::fadd, addr, modifier_arg,
-                              ml::AtomicOrdering::seq_cst, sync_scope);
-    return success();
-  }
-
-  LogicalResult rewriteAtomicFMaxAsIntAtomics(Location loc, Value modifier_arg,
-                                              Value addr,
-                                              llvm::StringRef sync_scope,
-                                              OpBuilder& b) const {
-    Type element_type = modifier_arg.getType();
-    if (!element_type.isF32()) {
-      return failure();
-    }
-    // Evaluating floating-point max using integer atomics has the limitation
-    // of not propagating -NaNs. To handle this, we check if the update value
-    // is -NaN and convert it to a positive one by dropping the sign bit.
-    Value current = b.create<ml::LoadOp>(loc, element_type, addr);
-    Value current_is_nan =
-        b.create<ml::FCmpOp>(loc, ml::FCmpPredicate::uno, current, current);
-    auto is_current_nan =
-        b.create<scf::IfOp>(loc, /*resultTypes=*/TypeRange{}, current_is_nan,
-                            /*addThenBlock=*/true, /*addElseBlock=*/true);
-    auto if_current_nan_then_builder =
-        OpBuilder::atBlockEnd(is_current_nan.thenBlock(), b.getListener());
-    if_current_nan_then_builder.create<scf::YieldOp>(loc);
-
-    auto if_current_nan_else_builder =
-        OpBuilder::atBlockEnd(is_current_nan.elseBlock(), b.getListener());
-    Value is_modifier_nan = if_current_nan_else_builder.create<ml::FCmpOp>(
-        loc, ml::FCmpPredicate::uno, modifier_arg, modifier_arg);
-    auto f32_nan = mlir::APFloat::getNaN(mlir::APFloat::IEEEsingle());
-    Value nan = if_current_nan_else_builder.create<ml::ConstantOp>(
-        loc, b.getF32Type(), f32_nan);
-    Value no_negative_nan_source =
-        if_current_nan_else_builder.create<ml::SelectOp>(loc, is_modifier_nan,
-                                                         nan, modifier_arg);
-    Value current_less_than_modifier =
-        if_current_nan_else_builder.create<ml::FCmpOp>(
-            loc, ml::FCmpPredicate::ult, current, no_negative_nan_source);
-
-    // This check allows us to skip the atomic update altogether at the
-    // expense of reading the value in memory for every update. When evaluated
-    // against Waymo's benchmarks, adding the check achieved better overall
-    // performance.
-    auto if_need_update = if_current_nan_else_builder.create<scf::IfOp>(
-        loc, /*resultTypes=*/TypeRange{}, current_less_than_modifier,
-        /*withElseRegion=*/true,
-        /*addElseBlock=*/false);
-    if_current_nan_else_builder.create<scf::YieldOp>(loc);
-
-    auto then_builder =
-        OpBuilder::atBlockEnd(if_need_update.thenBlock(), b.getListener());
-    Value source_float_as_int = then_builder.create<ml::BitcastOp>(
-        loc, then_builder.getI32Type(), no_negative_nan_source);
-    Value c0 = then_builder.create<ml::ConstantOp>(loc, b.getI32Type(), 0);
-    Value is_not_negative = then_builder.create<ml::ICmpOp>(
-        loc, ml::ICmpPredicate::sge, source_float_as_int, c0);
-    then_builder.create<scf::IfOp>(
-        loc, is_not_negative,
-        [&](OpBuilder& nested_b, Location nested_loc) {
-          // atomicMax((int *)address, __float_as_int(val))
-          nested_b.create<ml::AtomicRMWOp>(
-              loc, ml::AtomicBinOp::max, addr, source_float_as_int,
-              ml::AtomicOrdering::seq_cst, sync_scope);
-          nested_b.create<scf::YieldOp>(nested_loc);
-        },
-        [&](OpBuilder& nested_b, Location nested_loc) {
-          // atomicMin((unsigned int *)address, __float_as_uint(val))
-          nested_b.create<ml::AtomicRMWOp>(
-              loc, ml::AtomicBinOp::umin, addr, source_float_as_int,
-              ml::AtomicOrdering::seq_cst, sync_scope);
-          nested_b.create<scf::YieldOp>(nested_loc);
-        });
-    then_builder.create<scf::YieldOp>(loc);
-    return success();
-  }
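
The integer atomics above work because of how IEEE-754 bit patterns order: for a non-negative update, a signed integer max is a float max; for a negative update, an unsigned integer min is a float max (a more negative float has a larger unsigned bit pattern). A minimal non-atomic C++20 sketch of that selection; the pass's NaN handling is omitted and the names are illustrative:

#include <algorithm>
#include <bit>
#include <cassert>
#include <cstdint>
#include <cstdio>

// Applies max(stored, update) to a float stored as raw bits, using only
// integer comparisons: signed max for non-negative updates, unsigned min for
// negative ones.
void MaxUpdateViaIntegers(uint32_t& stored_bits, float update) {
  const int32_t update_signed = std::bit_cast<int32_t>(update);
  if (update_signed >= 0) {
    stored_bits = std::bit_cast<uint32_t>(
        std::max(std::bit_cast<int32_t>(stored_bits), update_signed));
  } else {
    stored_bits = std::min(stored_bits, std::bit_cast<uint32_t>(update));
  }
}

int main() {
  uint32_t bits = std::bit_cast<uint32_t>(-3.0f);
  MaxUpdateViaIntegers(bits, 1.5f);   // max(-3.0, 1.5) == 1.5
  MaxUpdateViaIntegers(bits, -2.0f);  // max(1.5, -2.0) == 1.5
  assert(std::bit_cast<float>(bits) == 1.5f);
  std::printf("%f\n", std::bit_cast<float>(bits));
  return 0;
}
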
-
-  // Implements atomic binary operations using atomic compare-and-swap
-  // (atomicCAS) as follows:
-  //   1. Reads the value from the memory pointed to by output_address and
-  //     records it as old_output.
-  //   2. Uses old_output as one of the source operands to perform the binary
-  //     operation and stores the result in new_output.
-  //   3. Calls atomicCAS which implements compare-and-swap as an atomic
-  //     operation. In particular, atomicCAS reads the value from the memory
-  //     pointed to by output_address, and compares the value with old_output.
-  //     If the two values are equal, new_output is written to the same memory
-  //     location and true is returned to indicate that the atomic operation
-  //     succeeds. Otherwise, the new value read from the memory is returned. In
-  //     this case, the new value is copied to old_output, and steps 2. and 3.
-  //     are repeated until atomicCAS succeeds.
-  //
-  // On Nvidia GPUs, atomicCAS can only operate on 32-bit and 64-bit integers.
-  // If the element type of the binary operation is 32 bits or 64 bits, the
-  // integer type of the same size is used for the atomicCAS operation. On the
-  // other hand, if the element type is smaller than 32 bits, int32_t is used
-  // for the atomicCAS operation. In this case, atomicCAS reads and writes 32
-  // bit values from the memory, which is larger than the memory size required
-  // by the original atomic binary operation. We mask off the last two bits of
-  // the output_address and use the result as an address to read the 32 bit
-  // values from the memory. This can avoid out-of-bounds memory accesses if
-  // tensor buffers are 4 byte aligned and have a size of 4N, an assumption that
-  // the runtime can guarantee.
-  void rewriteAsAtomicCAS(AtomicRMWOp op,
-                          mlir::PatternRewriter& rewriter) const {
-    Location loc = op.getLoc();
-    auto input = op.getInput();
-
-    // Use 32-bit atomic type for small input types.
-    Type result_ty = op.getResult().getType().getElementType();
-    int result_size;
-    if (auto complex_ty = mlir::dyn_cast<mlir::ComplexType>(result_ty)) {
-      result_size = complex_ty.getElementType().getIntOrFloatBitWidth() * 2;
-    } else {
-      result_size = result_ty.getIntOrFloatBitWidth();
-    }
-
-    bool small_type = result_size < 32;
-    Type atomic_ty =
-        mlir::IntegerType::get(op.getContext(), small_type ? 32 : result_size);
-
-    // Calculate load address for the input.
-    Value addr = CreateGep(input, op.getIndices(), rewriter);
-    Value shift, mask;
-    if (small_type) {
-      // Update input pointer by discarding the last two bits - i.e. align to
-      // 32-bit boundary for small input types (will not result in OOB, as the
-      // input alignment is at least 32 bits).
-      Type addr_int_ty = rewriter.getI64Type();
-      Value addr_int = rewriter.create<ml::PtrToIntOp>(loc, addr_int_ty, addr);
-      Value addr_offset = rewriter.create<ml::AndOp>(
-          loc, addr_int, rewriter.create<ml::ConstantOp>(loc, addr_int_ty, 3));
-      Value index = rewriter.create<ml::MulOp>(
-          loc, addr_offset,
-          rewriter.create<ml::ConstantOp>(loc, addr_int_ty, -1));
-      addr =
-          rewriter.create<ml::GEPOp>(loc, addr.getType(), rewriter.getI8Type(),
-                                     addr, index, /*inbounds=*/true);
-
-      // Calculate the bit shift (assume little-endianness).
-      Value offset = rewriter.create<ml::TruncOp>(loc, atomic_ty, addr_offset);
-      shift = rewriter.create<ml::MulOp>(
-          loc, offset,
-          rewriter.create<ml::ConstantOp>(loc, offset.getType(), 8));
-
-      // Compose the update mask.
-      Value bits_long = rewriter.create<ml::ConstantOp>(loc, atomic_ty, -1);
-      Value bits_short = rewriter.create<ml::ZExtOp>(
-          loc, atomic_ty,
-          rewriter.create<ml::ConstantOp>(
-              loc, rewriter.getIntegerType(result_size), -1));
-      mask = rewriter.create<ml::XOrOp>(
-          loc, bits_long, rewriter.create<ml::ShlOp>(loc, bits_short, shift));
-    }
-
-    // Load initial atomic value and create the loop.
-    Value initial = rewriter.create<ml::LoadOp>(loc, atomic_ty, addr);
-    rewriter.create<scf::WhileOp>(
-        loc, TypeRange{atomic_ty}, ValueRange{initial},
-        [&](mlir::OpBuilder& builder, Location loc, ValueRange values) {
-          mlir::ImplicitLocOpBuilder b(loc, builder);
-          Value old_value = values[0];
-
-          // Convert atomic value to input value.
-          Value input_value;
-          if (small_type) {
-            Value short_value =
-                b.create<ml::TruncOp>(b.getIntegerType(result_size),
-                                      b.create<ml::LShrOp>(old_value, shift));
-            input_value = b.create<ml::BitcastOp>(result_ty, short_value);
-          } else {
-            input_value = CreateBitcast(b, old_value, result_ty);
-          }
-
-          // Perform computation on the loaded input value.
-          rewriter.mergeBlocks(&op.getComputation().front(), b.getBlock(),
-                               {input_value});
-          auto yield_op = b.getBlock()->getTerminator();
-          Value result = yield_op->getOperand(0);
-          rewriter.eraseOp(yield_op);
-
-          // Convert resulting value to atomic value.
-          Value new_value;
-          if (small_type) {
-            Value cast_value = b.create<ml::ZExtOp>(
-                atomic_ty, b.create<ml::BitcastOp>(
-                               rewriter.getIntegerType(result_size), result));
-            new_value =
-                b.create<ml::OrOp>(b.create<ml::AndOp>(old_value, mask),
-                                   b.create<ml::ShlOp>(cast_value, shift));
-          } else {
-            new_value = CreateBitcast(b, result, atomic_ty);
-          }
-
-          // Try saving the result atomically, retry if failed.
-          Value cmpxchg = b.create<ml::AtomicCmpXchgOp>(
-              loc, addr, old_value, new_value,
-              /*success_ordering=*/ml::AtomicOrdering::seq_cst,
-              /*failure_ordering=*/ml::AtomicOrdering::seq_cst);
-          Value next = b.create<ml::ExtractValueOp>(cmpxchg, 0);
-          Value ok = b.create<ml::ExtractValueOp>(cmpxchg, 1);
-          Value low_bit = b.create<ml::ConstantOp>(b.getOneAttr(b.getI1Type()));
-          Value not_ok = b.create<ml::XOrOp>(ok, low_bit);
-          b.create<scf::ConditionOp>(not_ok, ValueRange{next});
-        },
-        [&](mlir::OpBuilder& b, Location loc, ValueRange values) {
-          b.create<scf::YieldOp>(loc, values);
-        });
-  }
-
-  bool is_amd_;
-  std::string gpu_arch_;
-};
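
A self-contained C++ sketch of the CAS retry loop described above rewriteAsAtomicCAS, here for a sub-32-bit element (a 16-bit lane inside a 32-bit word); the helper names are invented for illustration and little-endian lane order is assumed:

#include <atomic>
#include <cstdint>
#include <cstdio>

// Applies `update` to the 16-bit lane `lane` of `word` with a CAS retry loop,
// following steps 1-3 above: read, compute, compare-and-swap, repeat.
uint16_t AtomicUpdate16(std::atomic<uint32_t>& word, int lane,
                        uint16_t (*update)(uint16_t)) {
  const int shift = lane * 16;  // little-endian lane order assumed
  const uint32_t mask = ~(uint32_t{0xFFFF} << shift);
  uint32_t old_word = word.load(std::memory_order_relaxed);
  while (true) {
    const uint16_t old_lane = static_cast<uint16_t>(old_word >> shift);
    const uint32_t new_word =
        (old_word & mask) | (uint32_t{update(old_lane)} << shift);
    // On failure, compare_exchange_weak refreshes old_word with the value
    // currently in memory, so the next iteration recomputes from fresh data.
    if (word.compare_exchange_weak(old_word, new_word,
                                   std::memory_order_seq_cst)) {
      return old_lane;
    }
  }
}

int main() {
  std::atomic<uint32_t> word{0};
  AtomicUpdate16(word, /*lane=*/1,
                 [](uint16_t v) -> uint16_t { return v + 7; });
  std::printf("word = 0x%08x\n", static_cast<unsigned>(word.load()));
  return 0;  // prints word = 0x00070000
}
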
-
-class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
- public:
-  explicit LowerTensorsPass(const LowerTensorsPassOptions& options)
-      : LowerTensorsPassBase(options) {}
-
-  void runOnOperation() override {
-    MLIRContext* mlir_context = &getContext();
-    mlir::RewritePatternSet tensor_patterns(mlir_context);
-    tensor_patterns.add<RewriteAtomicRMW>(mlir_context, is_amd_gpu_, gpu_arch_);
-    tensor_patterns
-        .add<RewriteAllocateShared, RewriteNonScalarConstants,
-             RewriteSyncThreads, RewriteTensorExtract, RewriteTransferRead,
-             RewriteTensorInsert, RewriteTransferWrite>(mlir_context);
-    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
-            getOperation(), std::move(tensor_patterns)))) {
-      signalPassFailure();
-      return;
-    }
-
-    mlir::RewritePatternSet function_patterns(mlir_context);
-    function_patterns.add<RewriteFunctionSignatures, RewriteCall,
-                          RemoveUnusedIndexSwitchResults>(mlir_context);
-    scf::ForOp::getCanonicalizationPatterns(function_patterns, mlir_context);
-    scf::IfOp::getCanonicalizationPatterns(function_patterns, mlir_context);
-    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
-            getOperation(), std::move(function_patterns)))) {
-      signalPassFailure();
-      return;
-    }
-
-    getOperation()->walk([this](mlir::LLVM::LoadOp load) {
-      Value addr = load.getAddr();
-      while (auto gep = addr.getDefiningOp<mlir::LLVM::GEPOp>()) {
-        addr = gep.getBase();
-      }
-      if (addr.getDefiningOp<mlir::LLVM::AddrSpaceCastOp>() ||
-          addr.getDefiningOp<mlir::LLVM::AddressOfOp>() ||
-          addr.getDefiningOp<mlir::LLVM::AllocaOp>()) {
-        // Shared memory, global constant or temporary - no need to annotate
-        // anything.
-        return;
-      }
-      if (auto base = mlir::dyn_cast<mlir::BlockArgument>(addr)) {
-        if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(
-                base.getOwner()->getParentOp())) {
-          if (func.getArgAttr(base.getArgNumber(), "xla.invariant")) {
-            load.setInvariant(true);
-          }
-          return;
-        }
-      }
-      load.emitOpError("load op address is not (a GEP of) a function argument");
-      signalPassFailure();
-    });
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<::mlir::Pass> CreateLowerTensorsPass(
-    bool is_amd_gpu, const std::string& gpu_arch) {
-  LowerTensorsPassOptions options;
-  options.is_amd_gpu_ = is_amd_gpu;
-  options.gpu_arch_ = gpu_arch;
-  return std::make_unique<LowerTensorsPass>(options);
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc b/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc
deleted file mode 100644
index 6e05ead..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc
+++ /dev/null
@@ -1,99 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <memory>
-#include <utility>
-
-#include "llvm/Support/LogicalResult.h"
-#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
-#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
-#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
-#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
-#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
-#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
-#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
-#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // IWYU pragma: keep
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"  // IWYU pragma: keep
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Interfaces/DataLayoutInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_LOWERTOLLVMPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
- public:
-  using LowerToLLVMPassBase::LowerToLLVMPassBase;
-
-  void runOnOperation() override {
-    // Populate type conversions.
-    mlir::LowerToLLVMOptions llvm_opts(&getContext(),
-                                       mlir::DataLayout(getOperation()));
-    mlir::LLVMTypeConverter type_converter(getOperation().getContext(),
-                                           llvm_opts);
-    mlir::LLVMConversionTarget target(*getOperation().getContext());
-
-    // Populate patterns.
-    mlir::RewritePatternSet patterns(&getContext());
-    mlir::populateAffineToStdConversionPatterns(patterns);
-    mlir::populateSCFToControlFlowConversionPatterns(patterns);
-    mlir::arith::populateArithExpandOpsPatterns(patterns);
-    mlir::arith::populateArithToLLVMConversionPatterns(type_converter,
-                                                       patterns);
-    mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns);
-    mlir::populateFuncToLLVMConversionPatterns(type_converter, patterns);
-    mlir::populateVectorToLLVMConversionPatterns(type_converter, patterns);
-    mlir::cf::populateControlFlowToLLVMConversionPatterns(type_converter,
-                                                          patterns);
-    mlir::populateComplexToLLVMConversionPatterns(type_converter, patterns);
-    mlir::populateMathToLLVMConversionPatterns(type_converter, patterns);
-
-    // Set up the conversion target.
-    mlir::configureGpuToNVVMConversionLegality(target);
-    target.addIllegalDialect<mlir::arith::ArithDialect, mlir::func::FuncDialect,
-                             mlir::complex::ComplexDialect,
-                             mlir::math::MathDialect>();
-    target.addLegalOp<mlir::ModuleOp>();
-
-    if (failed(
-            applyFullConversion(getOperation(), target, std::move(patterns)))) {
-      signalPassFailure();
-    }
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass() {
-  return std::make_unique<LowerToLLVMPass>();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc b/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc
deleted file mode 100644
index 9028480..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc
+++ /dev/null
@@ -1,205 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <utility>
-
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/LogicalResult.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/util.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_LOWERXLAGPUTOSCFPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-using mlir::success;
-
-struct RewritePredicatedInsert : mlir::OpRewritePattern<PredicatedInsertOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      PredicatedInsertOp op, mlir::PatternRewriter& rewriter) const override {
-    rewriter.replaceOpWithNewOp<mlir::scf::IfOp>(
-        op, op.getCondition(),
-        [&](mlir::OpBuilder& b, mlir::Location loc) {
-          b.create<mlir::scf::YieldOp>(
-              loc, b.create<mlir::tensor::InsertOp>(
-                        loc, op.getValue(), op.getDest(), op.getIndices())
-                       .getResult());
-        },
-        [&](mlir::OpBuilder& b, mlir::Location loc) {
-          b.create<mlir::scf::YieldOp>(loc, op.getDest());
-        });
-    return success();
-  }
-};
-
-struct RewritePredicatedExtract : mlir::OpRewritePattern<PredicatedExtractOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      PredicatedExtractOp op, mlir::PatternRewriter& rewriter) const override {
-    rewriter.replaceOpWithNewOp<mlir::scf::IfOp>(
-        op, op.getCondition(),
-        [&](mlir::OpBuilder& b, mlir::Location loc) {
-          b.create<mlir::scf::YieldOp>(
-              loc, b.create<mlir::tensor::ExtractOp>(loc, op.getSrc(),
-                                                     op.getIndices())
-                       .getResult());
-        },
-        [&](mlir::OpBuilder& b, mlir::Location loc) {
-          b.create<mlir::scf::YieldOp>(loc, op.getFallback());
-        });
-    return success();
-  }
-};
-
-struct RewriteShuffleReduce : mlir::OpRewritePattern<ShuffleReduceOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      ShuffleReduceOp op, mlir::PatternRewriter& rewriter) const override {
-    int max_distance =
-        mlir::cast<mlir::IntegerAttr>(op->getAttr("max_distance")).getInt();
-    // TODO(jreiffers): Do this in a verifier.
-    if (max_distance & (max_distance - 1) || max_distance >= WarpSize()) {
-      return op->emitOpError("max_distance must be a power of 2 < WarpSize()");
-    }
-
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    mlir::ValueRange values = op.getOperands();
-    for (int distance = max_distance; distance > 0; distance /= 2) {
-      namespace ml = mlir::LLVM;
-      auto shuffle_32 = [&](mlir::Value v) {
-        return b
-            .create<mlir::gpu::ShuffleOp>(v, distance, WarpSize(),
-                                          mlir::gpu::ShuffleMode::DOWN)
-            .getShuffleResult();
-      };
-
-      auto shuffle_int_or_float = [&](mlir::Value value) {
-        auto ty = value.getType();
-        int bit_width = ty.getIntOrFloatBitWidth();
-        if (bit_width == 32) {
-          return shuffle_32(value);
-        }
-        int n_shuffles = CeilOfRatio(bit_width, 32);
-        auto int_ty = b.getIntegerType(bit_width);
-        auto padded_int_ty = b.getIntegerType(n_shuffles * 32);
-        value = b.create<mlir::arith::BitcastOp>(int_ty, value);
-        value = b.create<mlir::arith::ExtUIOp>(padded_int_ty, value);
-        if (n_shuffles > 1) {
-          // Don't generate vectors if the size is 1.
-          auto vector_type = ml::getVectorType(b.getI32Type(), n_shuffles);
-          value = b.create<ml::BitcastOp>(vector_type, value);
-          mlir::Value result_vec = b.create<ml::UndefOp>(vector_type);
-          for (int i = 0; i < n_shuffles; ++i) {
-            auto idx = b.create<mlir::arith::ConstantIntOp>(i, 32);
-            result_vec = b.create<ml::InsertElementOp>(
-                result_vec,
-                shuffle_32(b.create<ml::ExtractElementOp>(value, idx)), idx);
-          }
-          value = b.create<ml::BitcastOp>(padded_int_ty, result_vec);
-        } else {
-          value = shuffle_32(value);
-        }
-        value = b.create<mlir::arith::TruncIOp>(int_ty, value);
-        value = b.create<ml::BitcastOp>(ty, value);
-        return value;
-      };
-
-      auto shuffle = [&](mlir::Value value) -> mlir::Value {
-        if (mlir::isa<mlir::ComplexType>(value.getType())) {
-          return b.create<mlir::complex::CreateOp>(
-              value.getType(),
-              shuffle_int_or_float(b.create<mlir::complex::ReOp>(value)),
-              shuffle_int_or_float(b.create<mlir::complex::ImOp>(value)));
-        }
-        if (value.getType().isUnsignedInteger()) {
-          auto ty = value.getType();
-          auto signless_ty = b.getIntegerType(ty.getIntOrFloatBitWidth());
-          value = b.create<mlir::UnrealizedConversionCastOp>(
-                       mlir::TypeRange{signless_ty}, value)
-                      .getResult(0);
-          value = shuffle_int_or_float(value);
-          value = b.create<mlir::UnrealizedConversionCastOp>(
-                       mlir::TypeRange{ty}, value)
-                      .getResult(0);
-          return value;
-        }
-        return shuffle_int_or_float(value);
-      };
-
-      llvm::SmallVector<mlir::Value> args = values;
-      for (auto value : values) {
-        args.push_back(shuffle(value));
-      }
-      values = b.create<PureCallOp>(op.getResultTypes(),
-                                    op.getReducerAttr().getAttr(), args)
-                   .getResults();
-    }
-    rewriter.replaceOp(op, values);
-    return success();
-  }
-};
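
The widening logic in RewriteShuffleReduce, which shuffles a value wider than 32 bits by splitting it into 32-bit words, can be illustrated off-device. In this sketch Shuffle32 is a stand-in for the hardware shuffle and is just the identity, so the example stays self-contained:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Stand-in for the 32-bit hardware shuffle (gpu.shuffle); identity here so
// the example runs on the host.
uint32_t Shuffle32(uint32_t v) { return v; }

// Shuffles a value wider than 32 bits by bitcasting it to 32-bit words,
// shuffling each word, and reassembling the result.
double ShuffleWide(double value) {
  uint32_t words[2];
  std::memcpy(words, &value, sizeof(value));
  for (uint32_t& w : words) w = Shuffle32(w);
  double result;
  std::memcpy(&result, words, sizeof(result));
  return result;
}

int main() {
  std::printf("%f\n", ShuffleWide(3.5));  // 3.500000 with the identity shuffle
  return 0;
}
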
-
-class LowerXlaGpuToScfPass
-    : public impl::LowerXlaGpuToScfPassBase<LowerXlaGpuToScfPass> {
- public:
-  void runOnOperation() override {
-    mlir::RewritePatternSet patterns(&getContext());
-    patterns.add<RewritePredicatedInsert, RewritePredicatedExtract,
-                 RewriteShuffleReduce>(&getContext());
-    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
-                                                        std::move(patterns)))) {
-      signalPassFailure();
-    }
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuToScfPass() {
-  return std::make_unique<LowerXlaGpuToScfPass>();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc b/third_party/xla/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc
deleted file mode 100644
index c1899d2..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc
+++ /dev/null
@@ -1,117 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <optional>
-#include <string>
-
-#include "absl/container/flat_hash_map.h"
-#include "llvm/ADT/BitVector.h"
-#include "llvm/ADT/STLExtras.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_MERGEPOINTERSTOSAMESLICEPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-class MergePointersToSameSlicePass
-    : public impl::MergePointersToSameSlicePassBase<
-          MergePointersToSameSlicePass> {
- public:
-  void runOnOperation() override;
-};
-
-struct PackedArgs {
-  llvm::BitVector args_to_erase;
-  // replacement_args[i] == i iff !args_to_erase[i].
-  llvm::SmallVector<int> replacement_args;
-
-  PackedArgs() = default;
-  explicit PackedArgs(mlir::func::FuncOp func) {
-    absl::flat_hash_map<int, std::optional<int>> slice_to_operand;
-    args_to_erase.resize(func.getNumArguments());
-    replacement_args.reserve(func.getNumArguments());
-    for (int i = 0; i < func.getNumArguments(); ++i) {
-      replacement_args.push_back(i);
-    }
-
-    for (auto [idx, operand] : llvm::enumerate(func.getArguments())) {
-      auto slice_index = func.getArgAttr(idx, "xla.slice_index");
-      if (!slice_index) {
-        continue;
-      }
-
-      auto& target_index = slice_to_operand[static_cast<int>(
-          mlir::cast<mlir::IntegerAttr>(slice_index).getInt())];
-      if (target_index) {
-        replacement_args[idx] = *target_index;
-        args_to_erase[idx] = true;
-      } else {
-        target_index = idx;
-      }
-    }
-  }
-
-  void Pack(mlir::func::FuncOp op) {
-    for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
-      if (replacement_args[idx] != idx) {
-        arg.replaceAllUsesWith(op.getArgument(replacement_args[idx]));
-      }
-    }
-    op.eraseArguments(args_to_erase);
-    for (int i = 0; i < op.getNumArguments(); ++i) {
-      if (op.getArgAttr(i, "xla.slice_index")) {
-        op.removeArgAttr(i, "xla.slice_index");
-        op.setArgAttr(i, mlir::LLVM::LLVMDialect::getNoAliasAttrName(),
-                      mlir::UnitAttr::get(op->getContext()));
-      }
-    }
-  }
-
-  void Pack(mlir::func::CallOp op) { op->eraseOperands(args_to_erase); }
-};
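
The bookkeeping in PackedArgs amounts to: keep the first argument seen for each slice index and point duplicates at it. A small standalone sketch of that mapping (illustrative function name, not the pass API):

#include <cstdio>
#include <optional>
#include <unordered_map>
#include <vector>

// For each argument, returns the index it should be replaced with: itself,
// or the first argument that shares its slice index.
std::vector<int> ComputeReplacements(
    const std::vector<std::optional<int>>& slice_indices) {
  std::vector<int> replacement(slice_indices.size());
  std::unordered_map<int, int> first_for_slice;
  for (int i = 0; i < static_cast<int>(slice_indices.size()); ++i) {
    replacement[i] = i;
    if (!slice_indices[i]) continue;
    auto [it, inserted] = first_for_slice.emplace(*slice_indices[i], i);
    if (!inserted) replacement[i] = it->second;  // duplicate slice: merge
  }
  return replacement;
}

int main() {
  // Arguments 0 and 2 both point into slice 5, so argument 2 maps to 0.
  for (int r : ComputeReplacements({5, std::nullopt, 5, 7})) {
    std::printf("%d ", r);  // prints: 0 1 0 3
  }
  std::printf("\n");
  return 0;
}
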
-
-void MergePointersToSameSlicePass::runOnOperation() {
-  mlir::func::FuncOp entry;
-
-  absl::flat_hash_map<std::string, PackedArgs> args_to_pack;
-  getOperation()->walk([&](mlir::func::FuncOp func) {
-    args_to_pack[func.getName()] = PackedArgs(func);
-  });
-  getOperation()->walk([&](mlir::func::CallOp call) {
-    args_to_pack[call.getCallee()].Pack(call);
-  });
-  getOperation()->walk([&](mlir::func::FuncOp func) {
-    args_to_pack[func.getName()].Pack(func);
-  });
-}
-
-}  // namespace
-
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-CreateMergePointersToSameSlicePass() {
-  return std::make_unique<MergePointersToSameSlicePass>();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc
index 251d3ff..524d365 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc
+++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc
@@ -83,11 +83,11 @@
 #include "xla/service/buffer_assignment.h"
 #include "xla/service/dump.h"
 #include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
 #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
 #include "xla/service/gpu/fusions/mlir/type_util.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
 #include "xla/service/gpu/ir_emitter_context.h"
 #include "xla/service/gpu/kernel_arguments.h"
 #include "xla/service/gpu/kernel_reuse_cache.h"
@@ -222,7 +222,7 @@
 absl::StatusOr<FusionEmissionResult> MlirFusionEmitterBase::Emit(
     IrEmitterContext& ir_emitter_context,
     const HloFusionInstruction& fusion) const {
-  VLOG(5) << "Fusion: " << fusion.fused_instructions_computation()->ToString();
+  VLOG(4) << "Fusion: " << fusion.fused_instructions_computation()->ToString();
   TF_ASSIGN_OR_RETURN(
       auto args,
       KernelArguments::Create(ir_emitter_context.buffer_assignment(), &fusion));
@@ -305,13 +305,14 @@
   mlir::PassManager pm(&mlir_context);
   pm.addPass(CreateEraseDeadFunctionsPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addPass(CreateLowerXlaGpuToScfPass());
+  pm.addNestedPass<mlir::func::FuncOp>(CreateLowerXlaGpuToScfPass());
   pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) {
     // CSE after inlining because inlining can introduce duplicates.
     pm.addPass(mlir::createCSEPass());
   }));
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
+  pm.addNestedPass<mlir::func::FuncOp>(CreateLowerXlaGpuLoopsToScfPass());
   pm.addPass(mlir::mhlo::createConvertToSignlessPass());
   pm.addPass(CreatePropagateSliceIndicesPass());
   // We need LICM before unswitching loops, because our loop unswitcher only
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/optimize_loops.cc b/third_party/xla/xla/service/gpu/fusions/mlir/optimize_loops.cc
deleted file mode 100644
index 6d5456f..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/optimize_loops.cc
+++ /dev/null
@@ -1,315 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <algorithm>
-#include <memory>
-#include <optional>
-#include <utility>
-
-#include "absl/algorithm/container.h"
-#include "absl/log/check.h"
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Utils/Utils.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/IR/Visitors.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_OPTIMIZELOOPSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-mlir::Value GetSource(mlir::vector::TransferReadOp op) {
-  return op.getSource();
-}
-
-bool DoIndicesDependOnInductionVar(mlir::ValueRange indices,
-                                   mlir::scf::ForOp loop) {
-  // We assume LICM ran, so we can just check if any index is defined in the
-  // loop.
-  return absl::c_any_of(indices, [&](mlir::Value v) {
-    return v.getParentRegion() == &loop.getBodyRegion();
-  });
-}
-
-bool CanReplaceInductionVar(mlir::ValueRange indices) {
-  return absl::c_all_of(indices, [&](mlir::Value v) {
-    if (mlir::isa<mlir::BlockArgument>(v)) {
-      return true;
-    }
-    auto* op = v.getDefiningOp();
-    return op &&
-           mlir::isa<mlir::arith::ConstantOp, ApplyIndexingOp,
-                     mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
-                     mlir::arith::IndexCastOp, mlir::arith::IndexCastUIOp>(
-               op) &&
-           CanReplaceInductionVar(op->getOperands());
-  });
-}
-
-llvm::SmallVector<mlir::Value> ReplaceInductionVar(
-    mlir::Value induction_var, mlir::Value replacement,
-    llvm::SmallVector<mlir::Value> indices,
-    mlir::ImplicitLocOpBuilder& builder) {
-  for (mlir::Value& index : indices) {
-    if (mlir::isa<mlir::BlockArgument>(index)) {
-      if (index == induction_var) {
-        index = replacement;
-      }
-      continue;
-    }
-
-    auto* op = index.getDefiningOp();
-    CHECK(op) << "Did CanReplaceInductionVar() fail?";
-    if (mlir::isa<mlir::arith::ConstantOp>(op)) {
-      continue;
-    }
-
-    CHECK(
-        (mlir::isa<ApplyIndexingOp, mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
-                   mlir::arith::IndexCastOp, mlir::arith::IndexCastUIOp>(op)))
-        << "Did CanReplaceInductionVar() fail?";
-    auto replaced_args = ReplaceInductionVar(induction_var, replacement,
-                                             op->getOperands(), builder);
-    index = builder
-                .create(builder.getLoc(), op->getName().getIdentifier(),
-                        replaced_args, op->getResultTypes(), op->getAttrs())
-                ->getResult(0);
-  }
-  return indices;
-}
-
-mlir::Value GetSource(mlir::tensor::ExtractOp op) { return op.getTensor(); }
-
-// TODO(jreiffers): Use a shared memory queue for pipelining instead of
-// registers.
-template <typename Op>
-struct PipelineLoad : mlir::OpRewritePattern<Op> {
-  using mlir::OpRewritePattern<Op>::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      Op op, mlir::PatternRewriter& rewriter) const override {
-    auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
-    if (!loop) {
-      return rewriter.notifyMatchFailure(op, "no loop found");
-    }
-
-    if (auto step = loop.getConstantStep();
-        !step || step->getSExtValue() != 1) {
-      return rewriter.notifyMatchFailure(op, "loop step is not 1");
-    }
-
-    llvm::APInt lb, ub;
-    if (!mlir::matchPattern(loop.getLowerBound(), mlir::m_ConstantInt(&lb)) ||
-        !mlir::matchPattern(loop.getUpperBound(), mlir::m_ConstantInt(&ub))) {
-      return rewriter.notifyMatchFailure(op, "bounds are not constants");
-    }
-    if (lb.getSExtValue() != 0) {
-      return rewriter.notifyMatchFailure(op, "lower bound is not 0");
-    }
-
-    auto source = GetSource(op);
-    if (!source.getParentRegion()->isProperAncestor(&loop.getBodyRegion())) {
-      return rewriter.notifyMatchFailure(
-          op, "source is not defined outside the loop");
-    }
-
-    if (!DoIndicesDependOnInductionVar(op.getIndices(), loop)) {
-      // We don't run LICM between iterations, so this could happen.
-      // Just hoist the load out of the loop.
-      rewriter.moveOpBefore(op, loop);
-      return mlir::success();
-    }
-
-    if (!CanReplaceInductionVar(op.getIndices())) {
-      return rewriter.notifyMatchFailure(op, "unable to replace indices");
-    }
-
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    mlir::Value zero = b.create<mlir::arith::ConstantIndexOp>(0);
-
-    b.setInsertionPoint(loop);
-    auto first_args =
-        ReplaceInductionVar(loop.getInductionVar(), zero, op.getOperands(), b);
-    auto loaded_first =
-        b.create<Op>(op->getResultTypes(), first_args, op->getAttrs());
-    auto ub_minus_one =
-        b.create<mlir::arith::ConstantIndexOp>(ub.getSExtValue() - 1);
-
-    b.setInsertionPointToStart(loop.getBody());
-
-    auto needs_load = b.create<mlir::arith::CmpIOp>(
-        mlir::arith::CmpIPredicate::ult, loop.getInductionVar(), ub_minus_one);
-    auto next_value =
-        b.create<mlir::scf::IfOp>(op->getResultTypes(), needs_load, true, true);
-    auto new_for =
-        mlir::cast<mlir::scf::ForOp>(*loop.replaceWithAdditionalYields(
-            rewriter, loaded_first->getResult(0),
-            /*replaceInitOperandUsesInLoop=*/false,
-            [&](mlir::OpBuilder&, mlir::Location,
-                llvm::ArrayRef<mlir::BlockArgument>) {
-              return llvm::SmallVector<mlir::Value>{next_value->getResult(0)};
-            }));
-    rewriter.replaceAllUsesWith(op, new_for.getRegionIterArgs().back());
-
-    b.setInsertionPointToStart(next_value.thenBlock());
-    auto yield = b.create<mlir::scf::YieldOp>(op->getResult(0));
-
-    // We add 1 via an apply_indexing op (rather than a plain arith.addi) so
-    // that folding works properly.
-    auto plus_one_map = mlir::AffineMap::get(
-        1, 0, mlir::getAffineDimExpr(0, this->getContext()) + 1);
-    b.setInsertionPoint(next_value);
-    auto induction_plus_one =
-        b.create<ApplyIndexingOp>(new_for.getInductionVar(), plus_one_map, 0,
-                                  ub.getSExtValue() - 1)
-            ->getResult(0);
-
-    // Create the new apply_indexing ops outside the if, to improve CSE.
-    rewriter.modifyOpInPlace(op, [&]() {
-      op->setOperands(ReplaceInductionVar(
-          new_for.getInductionVar(), induction_plus_one, op->getOperands(), b));
-    });
-    rewriter.moveOpBefore(op, yield);
-
-    b.setInsertionPointToStart(next_value.elseBlock());
-    b.create<mlir::scf::YieldOp>(new_for.getRegionIterArgs().back());
-    return mlir::success();
-  }
-};
-
-int GetUnrollingFactor(mlir::scf::ForOp op) {
-  // We only unroll loops with a step of 1 and a lower bound of 0. That's the
-  // only type we generate.
-  if (auto step = op.getConstantStep(); !step || step->getSExtValue() != 1) {
-    return 1;
-  }
-  llvm::APInt lb, ub;
-  if (!mlir::matchPattern(op.getLowerBound(), mlir::m_ConstantInt(&lb)) ||
-      !mlir::matchPattern(op.getUpperBound(), mlir::m_ConstantInt(&ub))) {
-    return 1;
-  }
-  if (lb.getSExtValue() != 0) {
-    return 1;
-  }
-
-  int64_t trip_count = ub.getSExtValue();
-  constexpr int kMaxSize = 400;  // Chosen empirically.
-
-  // Get a rough estimate of the size of the loop body.
-  int64_t size = 0;
-  op.getBodyRegion().walk([&](mlir::Operation* op) {
-    if (mlir::isa<mlir::func::CallOp, mlir::scf::ForOp>(op)) {
-      size += kMaxSize;
-      return;
-    }
-
-    int64_t this_size = 1;
-    if (mlir::isa<mlir::math::MathDialect>(op->getDialect())) {
-      // Integer instructions in math are ok, but many float ops lower to lots
-      // of instructions.
-      if (!op->getResultTypes().front().isIntOrIndex()) {
-        namespace mm = mlir::math;
-        // We err on the side of not unrolling, so we maintain a list of ops
-        // known to be cheap.
-        if (!mlir::isa<mm::AbsFOp, mm::CeilOp, mm::CopySignOp, mm::FloorOp,
-                       mm::FmaOp, mm::RoundEvenOp, mm::RoundOp, mm::RsqrtOp,
-                       mm::SqrtOp, mm::TruncOp>(op)) {
-          this_size = 20;  // Rough estimate.
-        }
-      }
-    }
-
-    if (!op->getResultTypes().empty()) {
-      if (auto vector_ty =
-              mlir::dyn_cast<mlir::VectorType>(op->getResultTypes().front())) {
-        this_size *= vector_ty.getNumElements();
-      }
-    }
-
-    size += this_size;
-  });
-
-  int factor = std::min(trip_count, kMaxSize / size);
-  while (factor > 1 && trip_count % factor) {
-    --factor;
-  }
-  return factor;
-}
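
The factor selection at the end of GetUnrollingFactor reads in isolation as: cap the factor by a size budget, then shrink it until it divides the trip count. A tiny sketch with made-up numbers:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Caps the unroll factor by a size budget, then shrinks it until it evenly
// divides the trip count (so no remainder loop is needed).
int ChooseUnrollFactor(int64_t trip_count, int64_t body_size,
                       int64_t max_size = 400) {
  int factor = static_cast<int>(
      std::min(trip_count, std::max<int64_t>(max_size / body_size, 1)));
  while (factor > 1 && trip_count % factor != 0) --factor;
  return factor;
}

int main() {
  // Budget 400 / body 50 = 8, but 8 does not divide 12, so we settle on 6.
  std::printf("%d\n", ChooseUnrollFactor(/*trip_count=*/12, /*body_size=*/50));
  return 0;
}
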
-
-struct UnrollLoops : mlir::OpRewritePattern<mlir::scf::ForOp> {
-  using mlir::OpRewritePattern<mlir::scf::ForOp>::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
-    if (int factor = GetUnrollingFactor(op); factor > 1) {
-      return mlir::loopUnrollByFactor(op, factor);
-    }
-    return rewriter.notifyMatchFailure(op, "loop can't be unrolled");
-  }
-};
-
-class OptimizeLoopsPass
-    : public impl::OptimizeLoopsPassBase<OptimizeLoopsPass> {
- public:
-  void runOnOperation() override {
-    // First unroll loops. If unrolling is possible, we prefer it.
-    mlir::RewritePatternSet unroll_patterns(&getContext());
-    unroll_patterns.add<UnrollLoops>(&getContext());
-    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
-            getOperation(), std::move(unroll_patterns)))) {
-      signalPassFailure();
-      return;
-    }
-
-    // Then pipeline the remaining loops.
-    mlir::RewritePatternSet patterns(&getContext());
-    patterns.add<PipelineLoad<mlir::vector::TransferReadOp>,
-                 PipelineLoad<mlir::tensor::ExtractOp>>(&getContext());
-    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
-                                                        std::move(patterns)))) {
-      signalPassFailure();
-    }
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-CreateOptimizeLoopsPass() {
-  return std::make_unique<OptimizeLoopsPass>();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/passes.h b/third_party/xla/xla/service/gpu/fusions/mlir/passes.h
deleted file mode 100644
index bb0f1d4..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/passes.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_
-
-#include <memory>
-#include <optional>
-#include <string>
-
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "xla/service/gpu/model/indexing_map.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DECL
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-// Returns the range of a given value, if it can be statically determined.
-std::optional<Interval> GetRange(mlir::Value value);
-
-// Returns the range for the induction variable, if it can be statically
-// determined.
-std::optional<Interval> GetIVRange(mlir::Value iv);
-
-std::unique_ptr<mlir::Pass> CreateEraseDeadFunctionsPass();
-std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass(bool pre_ampere);
-std::unique_ptr<mlir::Pass> CreateConvertPureCallOpsPass();
-std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass();
-std::unique_ptr<mlir::Pass> CreateLowerTensorsPass(
-    bool is_amd_gpu = false, const std::string& gpu_arch = "6.0");
-std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass();
-std::unique_ptr<mlir::Pass> CreateLowerXlaGpuToScfPass();
-std::unique_ptr<mlir::Pass> CreateMergePointersToSameSlicePass();
-std::unique_ptr<mlir::Pass> CreateOptimizeLoopsPass();
-std::unique_ptr<mlir::Pass> CreatePropagateSliceIndicesPass();
-std::unique_ptr<mlir::Pass> CreateSimplifyAffinePass();
-std::unique_ptr<mlir::Pass> CreateSimplifyArithPass();
-std::unique_ptr<mlir::Pass> CreateUnswitchLoopsPass();
-std::unique_ptr<mlir::Pass> CreateVectorizeLoadsAndStoresPass();
-
-#define GEN_PASS_REGISTRATION
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/passes.td b/third_party/xla/xla/service/gpu/fusions/mlir/passes.td
deleted file mode 100644
index 6785670..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/passes.td
+++ /dev/null
@@ -1,290 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_
-#define XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_
-
-include "mlir/Pass/PassBase.td"
-
-def PropagateSliceIndicesPass :
-   Pass<"xla-gpu-propagate-slice-indices", "mlir::ModuleOp"> {
-  let summary = "Propagates slice indices from the entry function to all callees.";
-
-  let description = [{
-      Propagates xla.slice_index attributes from the function with the xla.entry
-      attribute to all other functions.
-  }];
-
-  let dependentDialects = [
-    "mlir::func::FuncDialect"
-  ];
-
-  let constructor = "CreatePropagateSliceIndicesPass()";
-}
-
-def ConvertPureCallOpsPass
-    : Pass<"xla-gpu-convert-pure-call-ops", "mlir::func::FuncOp"> {
-  let summary = "Converts xla_gpu.pure_call to func.call";
-  let description = [{
-      We use xla_gpu.pure_call ops for calls to enable CSE and other
-      transformations (e.g. LICM). This pass rewrites our custom ops to standard
-      ops.
-  }];
-  let dependentDialects = [
-    "mlir::func::FuncDialect",
-    "xla::gpu::XlaGpuDialect"
-  ];
-  let constructor = "CreateConvertPureCallOpsPass()";
-}
-
-def FlattenTensorsPass : Pass<"xla-gpu-flatten-tensors", "mlir::ModuleOp"> {
-  let summary = "Flatten tensors.";
-
-  let description = [{
-    Linearizes all tensor loads and stores.
-  }];
-
-  let dependentDialects = [
-    "mlir::func::FuncDialect",
-    "mlir::tensor::TensorDialect",
-    "xla::gpu::XlaGpuDialect",
-  ];
-  let constructor = "CreateFlattenTensorsPass()";
-}
-
-def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> {
-  let summary = "Lowers tensors to llvm pointers and loads/stores.";
-
-  let description = [{
-      Lowers tensors to LLVM. We cannot use the memref lowerings because they
-      are not compatible with XLA's ABI.
-  }];
-
-  let dependentDialects = [
-    "mlir::LLVM::LLVMDialect",
-    "mlir::func::FuncDialect",
-    "mlir::gpu::GPUDialect",
-    "mlir::scf::SCFDialect",
-    "mlir::tensor::TensorDialect",
-    "xla::gpu::XlaGpuDialect",
-  ];
-  let options = [
-    Option<"is_amd_gpu_", "is_amd_gpu", "bool", /*default=*/"false",
-           "True if AMD GPU.">,
-    Option<"gpu_arch_", "gpu_arch", "std::string", /*default=*/"",
-           "CUDA or ROCm compute capability.">,
-  ];
-  let constructor = "CreateLowerTensorsPass()";
-}
-
-def MergePointersToSameSlicePass :
-   Pass<"xla-gpu-merge-pointers", "mlir::ModuleOp"> {
-  let summary = "Merges pointers that share slices.";
-
-  let description = [{
-      When a function has multiple pointer arguments with the same slice
-      index, this pass merges them into one.
-  }];
-
-  let dependentDialects = [
-    "mlir::func::FuncDialect"
-  ];
-
-  let constructor = "CreateMergePointersToSameSlicePass()";
-}
-
-def SimplifyArithPass : Pass<"xla-gpu-simplify-arith", "mlir::func::FuncOp"> {
-  let summary = "Simplifies arith using XLA's range-aware simplifier.";
-
-  let description = [{
-      We often emit bounds checks that are statically known to be satisfied.
-      This pass removes them.
-  }];
-
-  let dependentDialects = [
-    "mlir::arith::ArithDialect",
-    "mlir::func::FuncDialect",
-  ];
-
-  let constructor = "CreateSimplifyArithPass()";
-}
-
-def SimplifyAffinePass : Pass<"xla-gpu-simplify-affine", "mlir::ModuleOp"> {
-  let summary = "Simplifies affine.apply using XLA's range-aware simplifier.";
-
-  let description = [{
-      The standard affine canonicalizer cannot simplify all expressions, since
-      it is unaware of range information. This pass uses `xla.range` attributes
-      on arguments and ops for simplification. It also lowers floordiv and mod
-      to simpler expressions than lower-affine. This pass only works for
-      expressions for which we can prove the LHS of mod and div is nonnegative.
-  }];
-
-  let dependentDialects = [
-    "mlir::affine::AffineDialect", "mlir::func::FuncDialect",
-    "mlir::scf::SCFDialect",
-  ];
-
-  let constructor = "CreateSimplifyAffinePass()";
-}
-
-def ExpandFloatOpsPass : Pass<"xla-gpu-expand-float-ops", "mlir::ModuleOp"> {
-  let summary = "Expands float ops that are not natively supported.";
-
-  let description = [{
-     Not all float ops are natively supported, either because they don't exist
-     in hardware or because the native implementations are too inaccurate.
-
-     This pass replaces these ops with alternative implementations.
-  }];
-
-  let dependentDialects = [
-    "mlir::arith::ArithDialect", "mlir::math::MathDialect",
-    "mlir::mhlo::MhloDialect"
-  ];
-
-  let options = [
-    Option<"pre_ampere_", "pre-ampere", "bool", /*default=*/"false",
-           "Rewrite ops that are not supported on architectures before Ampere">,
-  ];
-}
-
-def LowerXlaGpuToScfPass :
-   Pass<"xla-gpu-lower-xla-gpu-to-scf", "mlir::ModuleOp"> {
-  let summary = "Lowers xla_gpu to SCF.";
-
-  let dependentDialects = [
-    "mlir::gpu::GPUDialect", "mlir::LLVM::LLVMDialect", "mlir::scf::SCFDialect",
-    "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect",
-  ];
-
-  let constructor = "CreateLowerXlaGpuToScfPass()";
-}
-
-def EraseDeadFunctionsPass : Pass<"xla-erase-dead-functions", "mlir::ModuleOp"> {
-  let summary = "Deletes unused functions";
-
-  let description = [{
-      Deletes functions that are not called.
-  }];
-
-  let dependentDialects = [
-    "mlir::func::FuncDialect",
-    "xla::gpu::XlaGpuDialect"
-  ];
-
-  let constructor = "CreateEraseDeadFunctionsPass()";
-}
-
-def LowerToLLVMPass :
-   Pass<"xla-gpu-lower-to-llvm", "mlir::ModuleOp"> {
-  let summary = "Lowers to LLVM.";
-
-  let description = [{
-    Lowers the remaining ops to LLVM.
-  }];
-
-  let dependentDialects = [
-    "mlir::func::FuncDialect",
-    "mlir::LLVM::LLVMDialect",
-    "mlir::NVVM::NVVMDialect",
-  ];
-
-  let constructor = "CreateLowerToLLVMPass()";
-}
-
-def VectorizeLoadsAndStoresPass :
-   Pass<"xla-gpu-vectorize-loads-stores", "mlir::func::FuncOp"> {
-  let summary = "Vectorizes loads and stores.";
-
-  let description = [{
-    Rewrites tensor.extract and tensor.insert ops inside loops to their vector
-    equivalents (vector.transfer_read and vector.transfer_write + vector.extract
-    and vector.insert).
-  }];
-
-  let dependentDialects = [
-    "mlir::vector::VectorDialect",
-  ];
-
-  let constructor = "CreateVectorizeLoadsAndStoresPass()";
-}
-
-def OptimizeLoopsPass :
-   Pass<"xla-gpu-optimize-loops", "mlir::func::FuncOp"> {
-  let summary = "Unrolls and pipelines loops.";
-
-  let description = [{
-    Unrolls loops with a small trip count. Pipelines loops with a large trip
-    count.
-  }];
-
-  let dependentDialects = [
-    "mlir::vector::VectorDialect",
-    "xla::gpu::XlaGpuDialect",
-  ];
-
-  let constructor = "CreateOptimizeLoopsPass()";
-}
-
-def UnswitchLoopsPass :
-   Pass<"xla-gpu-unswitch-loops", "mlir::func::FuncOp"> {
-  let summary = "Swaps scf.if and scf.for.";
-
-  let description = [{
-      Extracts `scf.if` ops with conditions that are independent of the loop
-      variable from `scf.for` by doing the following rewrite:
-
-      Before:
-
-      %cond = some_cond() : i1
-      %results = scf.for {
-        %some_val = scf.if %cond  {
-        } else {
-        }
-        scf.yield %some_val
-      }
-
-      After:
-
-      %cond = some_cond() : i1
-      %results = scf.if %cond {
-         %results = scf.for {
-            %some_val = scf.if %true  {
-            } else {
-            }
-         }
-         yield %results
-      } else {
-         %results = scf.for {
-            %some_val = scf.if %false  {
-            } else {
-            }
-         }
-         yield %results
-      }
-
-      This only triggers if there is a single `scf.if` op in the loop body (and
-      nothing else).
-  }];
-
-  let dependentDialects = [
-    "mlir::func::FuncDialect", "mlir::scf::SCFDialect"
-  ];
-
-  let constructor = "CreateUnswitchLoopsPass()";
-}
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_
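Note: each definition in the deleted .td above expands, through the GEN_PASS_DECL / GEN_PASS_DEF includes of passes.h.inc, into a generated impl::<Name>Base class that the hand-written sources derive from; the deleted propagate_slice_indices.cc below shows the real pattern. As a rough sketch with placeholder names (FooPass and CreateFooPass are hypothetical, not passes from this file), a new pass would have been wired up like this:

#include <memory>

#include "mlir/Pass/Pass.h"
#include "xla/service/gpu/fusions/mlir/passes.h"

namespace xla {
namespace gpu {

// Pulls in the tablegen-generated impl::FooPassBase for the hypothetical
// FooPass definition.
#define GEN_PASS_DEF_FOOPASS
#include "xla/service/gpu/fusions/mlir/passes.h.inc"

namespace {

class FooPass : public impl::FooPassBase<FooPass> {
 public:
  void runOnOperation() override {
    // Rewrite logic goes here; call signalPassFailure() on error.
  }
};

}  // namespace

// Matches the `constructor` field of the tablegen definition.
std::unique_ptr<mlir::Pass> CreateFooPass() {
  return std::make_unique<FooPass>();
}

}  // namespace gpu
}  // namespace xla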
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc b/third_party/xla/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc
deleted file mode 100644
index 218b432f..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_PROPAGATESLICEINDICESPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-class PropagateSliceIndicesPass
-    : public impl::PropagateSliceIndicesPassBase<PropagateSliceIndicesPass> {
- public:
-  void runOnOperation() override;
-};
-
-void PropagateSliceIndicesPass::runOnOperation() {
-  mlir::func::FuncOp entry;
-  for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
-    if (func->getAttr("xla.entry")) {
-      entry = func;
-      break;
-    }
-  }
-
-  if (!entry) {
-    getOperation()->emitOpError("No entry function found.");
-    signalPassFailure();
-    return;
-  }
-
-  for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
-    if (func.getNumArguments() == 0 || func == entry) {
-      continue;
-    }
-
-    for (int i = 0; i < func.getNumArguments(); ++i) {
-      if (mlir::isa<mlir::RankedTensorType>(func.getArgument(i).getType())) {
-        if (auto index = entry.getArgAttr(i, "xla.slice_index")) {
-          func.setArgAttr(i, "xla.slice_index", index);
-        }
-        if (auto invariant = entry.getArgAttr(i, "xla.invariant")) {
-          func.setArgAttr(i, "xla.invariant", invariant);
-        }
-      } else {
-        break;
-      }
-    }
-  }
-}
-
-}  // namespace
-
-std::unique_ptr<mlir::Pass> CreatePropagateSliceIndicesPass() {
-  return std::make_unique<PropagateSliceIndicesPass>();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc b/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc
deleted file mode 100644
index 7b23499..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc
+++ /dev/null
@@ -1,368 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <cstdint>
-#include <limits>
-#include <memory>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/base/optimization.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/LogicalResult.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/LoopUtils.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
-#include "xla/service/gpu/model/indexing_map.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using mlir::AffineBinaryOpExpr;
-using mlir::AffineConstantExpr;
-using mlir::AffineDimExpr;
-using mlir::AffineExpr;
-using mlir::AffineExprKind;
-using mlir::AffineMap;
-using mlir::AffineSymbolExpr;
-using mlir::ImplicitLocOpBuilder;
-using mlir::LogicalResult;
-using mlir::MLIRContext;
-using mlir::OpRewritePattern;
-using mlir::PatternRewriter;
-using mlir::SmallVector;
-using mlir::Value;
-using mlir::ValueRange;
-using mlir::affine::AffineApplyOp;
-
-namespace arith = mlir::arith;
-
-#define GEN_PASS_DEF_SIMPLIFYAFFINEPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-int Distance(ImplicitLocOpBuilder& builder, Value a) {
-  auto* block = builder.getInsertionBlock();
-  auto* parent = a.getParentBlock();
-  int distance = 0;
-  while (block && block != parent) {
-    ++distance;
-    block = block->getParentOp()->getBlock();
-  }
-  return distance;
-}
-
-void CollectArgs(AffineExpr expr, AffineExprKind kind,
-                 llvm::SmallVector<AffineExpr>& ret) {
-  if (auto bin_op = mlir::dyn_cast<AffineBinaryOpExpr>(expr)) {
-    if (bin_op.getKind() == kind) {
-      CollectArgs(bin_op.getLHS(), kind, ret);
-      CollectArgs(bin_op.getRHS(), kind, ret);
-      return;
-    }
-  }
-  ret.push_back(expr);
-}
-
-struct ExpressionEvaluator {
-  ExpressionEvaluator(ImplicitLocOpBuilder& builder, unsigned dim_count,
-                      ValueRange operands)
-      : builder(builder), operands(operands) {
-    for (int i = 0; i < dim_count; ++i) {
-      dim_distances.push_back(Distance(builder, operands[i]));
-    }
-    for (int i = dim_count; i < operands.size(); ++i) {
-      sym_distances.push_back(Distance(builder, operands[i]));
-    }
-  }
-
-  // Returns the distance (in basic blocks) from the insertion point to the
-  // values used in the given expression.
-  int ExprDistance(AffineExpr e, int depth = 0) {
-    if (auto dim = mlir::dyn_cast<AffineDimExpr>(e)) {
-      return dim_distances[dim.getPosition()];
-    }
-    if (auto sym = mlir::dyn_cast<AffineSymbolExpr>(e)) {
-      return sym_distances[sym.getPosition()];
-    }
-    if (auto binop = mlir::dyn_cast<AffineBinaryOpExpr>(e)) {
-      return std::min(ExprDistance(binop.getLHS(), depth + 1),
-                      ExprDistance(binop.getRHS(), depth + 1));
-    }
-    if (depth == 0) {
-      // Top-level constant. Always add these last.
-      return std::numeric_limits<int>::min();
-    }
-    // Nested constant. Ignore these for distances.
-    return std::numeric_limits<int>::max();
-  }
-
-  Value EvaluateExpression(AffineExpr expr);
-
-  template <typename Op>
-  Value EvaluateAddMul(AffineExpr expr);
-
-  ImplicitLocOpBuilder& builder;
-  ValueRange operands;
-  SmallVector<int> dim_distances;
-  SmallVector<int> sym_distances;
-};
-
-template <typename Op>
-Value ExpressionEvaluator::EvaluateAddMul(AffineExpr expr) {
-  llvm::SmallVector<AffineExpr> args;
-  CollectArgs(expr, expr.getKind(), args);
-  // Sort the args so that the ones that are closest to the insertion point
-  // are evaluated last - this improves LICM.
-  llvm::stable_sort(args, [&](AffineExpr a, AffineExpr b) {
-    int dist_a = ExprDistance(a);
-    int dist_b = ExprDistance(b);
-    return dist_a > dist_b;
-  });
-
-  Value result = nullptr;
-  for (auto arg : args) {
-    Value arg_evaluated = EvaluateExpression(arg);
-    if (result) {
-      result = builder.create<Op>(result, arg_evaluated);
-    } else {
-      result = arg_evaluated;
-    }
-  }
-
-  return result;
-}
-
-Value ExpressionEvaluator::EvaluateExpression(AffineExpr expr) {
-  if (auto bin_op = mlir::dyn_cast<AffineBinaryOpExpr>(expr)) {
-    switch (expr.getKind()) {
-      case AffineExprKind::Add:
-        return EvaluateAddMul<arith::AddIOp>(expr);
-      case AffineExprKind::Mul:
-        return EvaluateAddMul<arith::MulIOp>(expr);
-      case AffineExprKind::Mod:
-        return builder.create<arith::RemUIOp>(
-            EvaluateExpression(bin_op.getLHS()),
-            EvaluateExpression(bin_op.getRHS()));
-      case AffineExprKind::FloorDiv:
-        return builder.create<arith::DivUIOp>(
-            EvaluateExpression(bin_op.getLHS()),
-            EvaluateExpression(bin_op.getRHS()));
-      default:
-        ABSL_UNREACHABLE();
-    }
-  }
-  switch (expr.getKind()) {
-    case AffineExprKind::Constant:
-      return builder.create<arith::ConstantIndexOp>(
-          mlir::cast<AffineConstantExpr>(expr).getValue());
-    case AffineExprKind::DimId:
-      return operands[mlir::cast<AffineDimExpr>(expr).getPosition()];
-    case AffineExprKind::SymbolId:
-      return operands[dim_distances.size() +
-                      mlir::cast<AffineSymbolExpr>(expr).getPosition()];
-    default:
-      ABSL_UNREACHABLE();
-  }
-}
-
-bool IsLoweringSupported(AffineExpr expr, RangeEvaluator& range_evaluator) {
-  auto bin_op = llvm::dyn_cast<AffineBinaryOpExpr>(expr);
-  if (!bin_op) {
-    return true;
-  }
-  // Mod and div can be lowered if their LHS is >= 0 and their RHS is a
-  // constant.
-  if (expr.getKind() == AffineExprKind::Mod ||
-      expr.getKind() == AffineExprKind::FloorDiv) {
-    if (!range_evaluator.IsAlwaysPositiveOrZero(bin_op.getLHS()) ||
-        !range_evaluator.ComputeExpressionRange(bin_op.getRHS()).IsPoint()) {
-      return false;
-    }
-  }
-  if (expr.getKind() == AffineExprKind::CeilDiv) {
-    return false;
-  }
-  return IsLoweringSupported(bin_op.getLHS(), range_evaluator) &&
-         IsLoweringSupported(bin_op.getRHS(), range_evaluator);
-}
-
-struct RewriteAffineApply : OpRewritePattern<mlir::affine::AffineApplyOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mlir::affine::AffineApplyOp op,
-                                PatternRewriter& rewriter) const override {
-    AffineMap affine_map = op.getAffineMap();
-    std::vector<DimVar> dim_ranges(affine_map.getNumDims());
-    std::vector<RangeVar> symbol_ranges(affine_map.getNumSymbols());
-
-    for (int i = 0; i < affine_map.getNumInputs(); ++i) {
-      if (auto range = GetRange(op->getOperand(i))) {
-        if (i >= dim_ranges.size()) {
-          symbol_ranges[i - dim_ranges.size()] = RangeVar{*range};
-        } else {
-          dim_ranges[i] = DimVar{*range};
-        }
-      } else {
-        return rewriter.notifyMatchFailure(op, "failed to deduce range");
-      }
-    }
-
-    IndexingMap indexing_map(affine_map, std::move(dim_ranges),
-                             std::move(symbol_ranges),
-                             /*rt_vars=*/{});
-    indexing_map.Simplify();
-    auto result_expr = indexing_map.GetAffineMap().getResult(0);
-
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    RangeEvaluator range_evaluator = indexing_map.GetRangeEvaluator();
-    if (!IsLoweringSupported(result_expr, range_evaluator)) {
-      return rewriter.notifyMatchFailure(op,
-                                         "unable to lower the affine apply");
-    }
-    b.setInsertionPoint(op);
-    auto result = ExpressionEvaluator(b, indexing_map.GetDimensionCount(),
-                                      op->getOperands())
-                      .EvaluateExpression(result_expr);
-    rewriter.replaceOp(op, result);
-    return mlir::success();
-  }
-};
-
-struct RewriteApplyIndexingOp : OpRewritePattern<ApplyIndexingOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ApplyIndexingOp op,
-                                PatternRewriter& rewriter) const override {
-    auto indexing_map = op.getIndexingMap();
-    indexing_map.Simplify();
-    auto affine_map = indexing_map.GetAffineMap();
-    int64_t dim_count = indexing_map.GetDimensionCount();
-    auto operands = op->getOperands();
-
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    RangeEvaluator range_evaluator = indexing_map.GetRangeEvaluator();
-
-    b.setInsertionPoint(op);
-    SmallVector<Value, 4> results;
-    results.reserve(affine_map.getNumResults());
-    for (unsigned i = 0; i < affine_map.getNumResults(); ++i) {
-      AffineExpr result_expr = affine_map.getResult(i);
-      // If the expression cannot be lowered, we convert it to affine.apply,
-      // since it supports more expression types.
-      if (IsLoweringSupported(result_expr, range_evaluator)) {
-        results.push_back(ExpressionEvaluator(b, dim_count, operands)
-                              .EvaluateExpression(result_expr));
-      } else {
-        results.push_back(
-            b.create<AffineApplyOp>(affine_map.getSubMap({i}), operands));
-      }
-    }
-    rewriter.replaceOp(op, results);
-    return mlir::success();
-  }
-};
-
-struct SimplifyAffinePass
-    : public impl::SimplifyAffinePassBase<SimplifyAffinePass> {
- public:
-  void runOnOperation() override {
-    MLIRContext* ctx = &getContext();
-    mlir::RewritePatternSet patterns(ctx);
-    patterns.add<RewriteAffineApply, RewriteApplyIndexingOp>(ctx);
-    mlir::GreedyRewriteConfig config;
-    // There's no point simplifying more than once.
-    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
-    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
-            getOperation(), std::move(patterns), config))) {
-      signalPassFailure();
-    }
-  }
-};
-
-}  // namespace
-
-std::optional<Interval> GetRange(mlir::Value value) {
-  auto attr_to_range = [](mlir::Attribute attr) -> std::optional<Interval> {
-    if (!attr) {
-      return std::nullopt;
-    }
-    auto values = llvm::to_vector(
-        mlir::cast<mlir::ArrayAttr>(attr).getAsValueRange<mlir::IntegerAttr>());
-    return {{values[0].getSExtValue(), values[1].getSExtValue()}};
-  };
-
-  if (auto apply = value.getDefiningOp<ApplyIndexingOp>()) {
-    return apply.getIndexingMap().GetRangeEvaluator().ComputeExpressionRange(
-        apply.getIndexingMap().GetAffineMap().getResult(
-            mlir::cast<mlir::OpResult>(value).getResultNumber()));
-  } else if (auto cst = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) {
-    return {{cst.value(), cst.value()}};
-  } else if (value.getDefiningOp()) {
-    return attr_to_range(value.getDefiningOp()->getAttr("xla.range"));
-  }
-
-  auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(value);
-  if (!bbarg) {
-    return std::nullopt;
-  }
-
-  auto parent = bbarg.getParentBlock()->getParentOp();
-  if (auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(parent)) {
-    return attr_to_range(func_op.getArgAttr(bbarg.getArgNumber(), "xla.range"));
-  }
-  return GetIVRange(value);
-}
-
-std::optional<Interval> GetIVRange(mlir::Value iv) {
-  auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(iv);
-  if (!bbarg) {
-    return std::nullopt;
-  }
-  auto parent = bbarg.getParentBlock()->getParentOp();
-  if (auto for_op = mlir::dyn_cast<mlir::scf::ForOp>(parent)) {
-    llvm::APInt lb, ub;
-    if (mlir::matchPattern(for_op.getLowerBound(), mlir::m_ConstantInt(&lb)) &&
-        mlir::matchPattern(for_op.getUpperBound(), mlir::m_ConstantInt(&ub))) {
-      return {{lb.getSExtValue(), ub.getSExtValue() - 1}};
-    }
-  }
-  return std::nullopt;
-}
-
-std::unique_ptr<mlir::Pass> CreateSimplifyAffinePass() {
-  return std::make_unique<SimplifyAffinePass>();
-}
-
-}  // namespace gpu
-}  // namespace xla
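Note on the IsLoweringSupported guard above: lowering affine floordiv/mod to arith.divui/arith.remui is only sound when the left-hand side is provably nonnegative and the right-hand side is a constant. Worked example: under affine semantics, -1 floordiv 3 = -1 and -1 mod 3 = 2, whereas the unsigned ops reinterpret the two's-complement bit pattern of -1 as a very large unsigned value and produce entirely different results. With a nonnegative LHS and a positive constant RHS, the signed and unsigned results coincide, so the cheaper unsigned ops can be emitted.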
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_arith.cc b/third_party/xla/xla/service/gpu/fusions/mlir/simplify_arith.cc
deleted file mode 100644
index 77b1d7c..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_arith.cc
+++ /dev/null
@@ -1,344 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <cstdint>
-#include <limits>
-#include <memory>
-#include <optional>
-#include <utility>
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
-#include "xla/service/gpu/model/indexing_map.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-#define GEN_PASS_DEF_SIMPLIFYARITHPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-using mlir::LogicalResult;
-using mlir::OpRewritePattern;
-using mlir::PatternRewriter;
-using mlir::arith::CmpIOp;
-using mlir::arith::CmpIPredicate;
-
-Interval::ComparisonResult EvaluateCmpI(CmpIPredicate pred, Interval lhs,
-                                        Interval rhs) {
-  switch (pred) {
-    case CmpIPredicate::eq:
-      return lhs.Eq(rhs);
-    case CmpIPredicate::ne:
-      return lhs.Ne(rhs);
-    case CmpIPredicate::slt:
-    case CmpIPredicate::ult:
-      return lhs.Lt(rhs);
-    case CmpIPredicate::sle:
-    case CmpIPredicate::ule:
-      return lhs.Le(rhs);
-    case CmpIPredicate::sgt:
-    case CmpIPredicate::ugt:
-      return lhs.Gt(rhs);
-    case CmpIPredicate::sge:
-    case CmpIPredicate::uge:
-      return lhs.Ge(rhs);
-  }
-}
-
-struct RewriteCmpI : OpRewritePattern<CmpIOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(CmpIOp op,
-                                PatternRewriter& rewriter) const override {
-    auto rhs = GetRange(op.getRhs());
-    auto lhs = GetRange(op.getLhs());
-    if (!lhs || !rhs) {
-      return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
-    }
-    Interval::ComparisonResult result =
-        EvaluateCmpI(op.getPredicate(), *lhs, *rhs);
-    if (result != std::nullopt) {
-      rewriter.replaceOpWithNewOp<mlir::arith::ConstantIntOp>(
-          op, *result, rewriter.getI1Type());
-      return mlir::success();
-    }
-    return rewriter.notifyMatchFailure(op, "not a constant result");
-  }
-};
-
-struct RewriteMaxSi : OpRewritePattern<mlir::arith::MaxSIOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mlir::arith::MaxSIOp op,
-                                PatternRewriter& rewriter) const override {
-    auto lhs = GetRange(op.getLhs());
-    auto rhs = GetRange(op.getRhs());
-    if (!lhs || !rhs) {
-      return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
-    }
-    if (auto lhs_ge_rhs = lhs->Ge(*rhs); lhs_ge_rhs == true) {
-      rewriter.replaceOp(op, op.getLhs());
-    } else if (auto rhs_ge_lhs = rhs->Ge(*lhs); rhs_ge_lhs == true) {
-      rewriter.replaceOp(op, op.getRhs());
-    } else {
-      return rewriter.notifyMatchFailure(op, "not equal to lhs or rhs");
-    }
-    return mlir::success();
-  }
-};
-
-struct RewriteMinSi : OpRewritePattern<mlir::arith::MinSIOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mlir::arith::MinSIOp op,
-                                PatternRewriter& rewriter) const override {
-    auto lhs = GetRange(op.getLhs());
-    auto rhs = GetRange(op.getRhs());
-    if (!lhs || !rhs) {
-      return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
-    }
-    if (auto lhs_le_rhs = lhs->Le(*rhs); lhs_le_rhs == true) {
-      rewriter.replaceOp(op, op.getLhs());
-    } else if (auto rhs_le_lhs = rhs->Le(*lhs); rhs_le_lhs == true) {
-      rewriter.replaceOp(op, op.getRhs());
-    } else {
-      return rewriter.notifyMatchFailure(op, "not equal to lhs or rhs");
-    }
-    return mlir::success();
-  }
-};
-
-// Finds the narrowest value in a use-def chain of truncis/extuis.
-mlir::Value FindNarrowestValueInChain(mlir::Value value) {
-  if (auto ext = value.getDefiningOp<mlir::arith::ExtUIOp>()) {
-    return FindNarrowestValueInChain(ext.getOperand());
-  }
-  auto defining_op = value.getDefiningOp<mlir::arith::TruncIOp>();
-  if (defining_op) {
-    auto first_trunc = FindNarrowestValueInChain(defining_op.getOperand());
-    if (first_trunc && first_trunc.getType().getIntOrFloatBitWidth() <=
-                           defining_op.getType().getIntOrFloatBitWidth()) {
-      return first_trunc;
-    }
-    return defining_op;
-  }
-  return value;
-}
-
-// Rewrites trunc-bitwise to bitwise-trunc.
-//
-// For pred reductions, we generate code like this:
-//
-//   %1 = arith.trunci %0 : i32 to i1
-//   %2 = arith.ori %1, %x
-//   %3 = arith.extui %2 : i1 to i32
-//   %4 = gpu.shuffle %3
-//
-// By swapping the trunc with the or, we get a trunc-ext-shuffle sequence, which
-// can be rewritten to shuffle-trunc-ext. If there is another copy of the
-// pattern afterwards, we can push the truncs/exts further down.
-template <typename Op>
-struct RewriteTruncBitExt : OpRewritePattern<Op> {
-  using OpRewritePattern<Op>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(Op op,
-                                PatternRewriter& rewriter) const override {
-    mlir::Value lhs = FindNarrowestValueInChain(op.getLhs());
-    mlir::Value rhs = FindNarrowestValueInChain(op.getRhs());
-
-    if (lhs.getType() != rhs.getType()) {
-      return rewriter.notifyMatchFailure(op, "mismatched narrowest types");
-    }
-
-    auto trunci_lhs = lhs.getDefiningOp<mlir::arith::TruncIOp>();
-    auto trunci_rhs = rhs.getDefiningOp<mlir::arith::TruncIOp>();
-    if (!trunci_lhs && !trunci_rhs) {
-      return rewriter.notifyMatchFailure(
-          op, "neither narrowest value is the result of a truncation");
-    }
-
-    auto wide_type =
-        (trunci_lhs ? trunci_lhs : trunci_rhs).getOperand().getType();
-    if (trunci_rhs && trunci_rhs.getOperand().getType() != wide_type) {
-      return rewriter.notifyMatchFailure(op, "mismatched truncation types");
-    }
-
-    mlir::Value new_lhs = trunci_lhs ? trunci_lhs.getOperand()
-                                     : rewriter.create<mlir::arith::ExtUIOp>(
-                                           op.getLoc(), wide_type, lhs);
-    mlir::Value new_rhs = trunci_rhs ? trunci_rhs.getOperand()
-                                     : rewriter.create<mlir::arith::ExtUIOp>(
-                                           op.getLoc(), wide_type, rhs);
-    mlir::Value new_op = rewriter.create<Op>(op.getLoc(), new_lhs, new_rhs);
-    rewriter.replaceOpWithNewOp<mlir::arith::TruncIOp>(op, op.getType(),
-                                                       new_op);
-
-    return mlir::success();
-  }
-};
-
-// Rewrites trunc-ext-shuffle to shuffle-trunc-ext. This pattern is designed to
-// work together with RewriteTruncBitExt to optimize pred reductions.
-struct RewriteTruncExtShuffle : public OpRewritePattern<mlir::gpu::ShuffleOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mlir::gpu::ShuffleOp op,
-                                PatternRewriter& rewriter) const override {
-    auto ext = op.getOperand(0).getDefiningOp<mlir::arith::ExtUIOp>();
-    if (!ext) {
-      return rewriter.notifyMatchFailure(op, "no ext");
-    }
-    auto trunc = ext.getOperand().getDefiningOp<mlir::arith::TruncIOp>();
-    if (!trunc || trunc.getOperand().getType() != ext.getType()) {
-      return rewriter.notifyMatchFailure(op, "no trunc or type mismatch");
-    }
-    rewriter.setInsertionPointAfter(op);
-    auto new_trunc = rewriter.create<mlir::arith::TruncIOp>(
-        op.getLoc(), trunc.getType(), op.getResult(0));
-    auto new_ext = rewriter.create<mlir::arith::ExtUIOp>(
-        op.getLoc(), ext.getType(), new_trunc.getResult());
-    rewriter.modifyOpInPlace(op,
-                             [&]() { op->setOperand(0, trunc.getOperand()); });
-    rewriter.replaceAllUsesExcept(op.getResult(0), new_ext, new_trunc);
-    return mlir::success();
-  }
-};
-
-void AnnotateRanges(mlir::func::FuncOp func) {
-  func->walk([](mlir::Operation* op) {
-    if (op->getNumResults() != 1) {
-      return;
-    }
-
-    auto result = op->getResult(0);
-    if (GetRange(result).has_value()) {
-      return;
-    }
-
-    auto get_range = [](mlir::Value value) -> Interval {
-      auto range = GetRange(value);
-      if (range) {
-        return *range;
-      }
-      return {std::numeric_limits<int64_t>::min(),
-              std::numeric_limits<int64_t>::max()};
-    };
-
-    std::optional<Interval> out_range = std::nullopt;
-    if (mlir::isa<mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
-                  mlir::arith::AddIOp, mlir::arith::MulIOp>(op)) {
-      auto lhs_range = get_range(op->getOperand(0));
-      auto rhs_range = get_range(op->getOperand(1));
-      if (mlir::isa<mlir::arith::MaxSIOp>(op)) {
-        out_range = lhs_range.max(rhs_range);
-      } else if (mlir::isa<mlir::arith::MinSIOp>(op)) {
-        out_range = lhs_range.min(rhs_range);
-      } else if (mlir::isa<mlir::arith::AddIOp>(op)) {
-        out_range = lhs_range + rhs_range;
-      } else {
-        out_range = lhs_range * rhs_range;
-      }
-    }
-
-    if (out_range) {
-      mlir::OpBuilder b(op);
-      op->setAttr("xla.range",
-                  b.getIndexArrayAttr({out_range->lower, out_range->upper}));
-    }
-  });
-}
-
-// Pattern to refine the bounds of an indexing map if some of its operands are
-// bound, e.g. loop induction variables.
-struct RefineConstraints : public OpRewritePattern<ApplyIndexingOp> {
-  using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
-                                PatternRewriter& rewriter) const override {
-    // Right now, we only handle loop induction variables, but other rules might
-    // be added.
-    IndexingMap indexing_map = indexing_op.getIndexingMap();
-    int64_t dim_count = indexing_map.GetDimensionCount();
-    bool updated_bounds = false;
-    for (mlir::OpOperand& operand : indexing_op->getOpOperands()) {
-      auto range = GetIVRange(operand.get());
-      if (!range) {
-        continue;
-      }
-      auto operand_id = operand.getOperandNumber();
-      Interval& current_interval =
-          operand_id < dim_count
-              ? indexing_map.GetMutableDimensionBound(operand_id)
-              : indexing_map.GetMutableSymbolBound(operand_id - dim_count);
-      if (!range->Contains(current_interval)) {
-        current_interval = current_interval.Intersect(*range);
-        updated_bounds = true;
-      }
-    }
-    if (!updated_bounds) {
-      return rewriter.notifyMatchFailure(indexing_op, "No bounds to refine");
-    }
-    indexing_map.Simplify();
-    rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
-        indexing_op, indexing_op.getOperands(), indexing_map);
-    return mlir::success();
-  }
-};
-
-class SimplifyArithPass
-    : public impl::SimplifyArithPassBase<SimplifyArithPass> {
- public:
-  void runOnOperation() override {
-    auto ctx = &getContext();
-    auto func = getOperation();
-    mlir::RewritePatternSet patterns(ctx);
-    AnnotateRanges(func);
-    // clang-format off
-    patterns.add<
-      RefineConstraints,
-      RewriteCmpI,
-      RewriteMaxSi,
-      RewriteMinSi,
-      RewriteTruncBitExt<mlir::arith::AndIOp>,
-      RewriteTruncBitExt<mlir::arith::OrIOp>,
-      RewriteTruncExtShuffle
-    >(ctx);
-    // clang-format on
-    if (mlir::failed(
-            mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
-      signalPassFailure();
-    }
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<mlir::Pass> CreateSimplifyArithPass() {
-  return std::make_unique<SimplifyArithPass>();
-}
-
-}  // namespace gpu
-}  // namespace xla
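Note: AnnotateRanges above and GetRange in simplify_affine.cc communicate through the `xla.range` attribute, an ArrayAttr of two index values holding the inclusive [lower, upper] bound of an op result or function argument. A minimal sketch of producing such an annotation (the helper name is hypothetical; the real annotations come from the emitter and from AnnotateRanges):

#include <cstdint>

#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

// Attaches an inclusive [lower, upper] range to an op, in the format that
// GetRange() understands.
void SetXlaRange(mlir::Operation* op, int64_t lower, int64_t upper) {
  mlir::OpBuilder b(op);
  op->setAttr("xla.range", b.getIndexArrayAttr({lower, upper}));
}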
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD
deleted file mode 100644
index 69b4bd0..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD
+++ /dev/null
@@ -1,43 +0,0 @@
-load("//xla:lit.bzl", "lit_test_suite")
-load("//xla:xla.bzl", "xla_cc_binary")
-
-package(
-    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    licenses = ["notice"],
-)
-
-xla_cc_binary(
-    name = "mlir_fusions_opt",
-    srcs = ["mlir_fusions_opt.cc"],
-    deps = [
-        "//xla/mlir_hlo",
-        "//xla/service/gpu/fusions/mlir:passes",
-        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
-        "@llvm-project//mlir:AffineDialect",
-        "@llvm-project//mlir:ArithDialect",
-        "@llvm-project//mlir:ComplexDialect",
-        "@llvm-project//mlir:DLTIDialect",
-        "@llvm-project//mlir:FuncDialect",
-        "@llvm-project//mlir:FuncExtensions",
-        "@llvm-project//mlir:GPUDialect",
-        "@llvm-project//mlir:LLVMDialect",
-        "@llvm-project//mlir:MathDialect",
-        "@llvm-project//mlir:MlirOptLib",
-        "@llvm-project//mlir:NVVMDialect",
-        "@llvm-project//mlir:SCFDialect",
-        "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:TensorDialect",
-        "@llvm-project//mlir:Transforms",
-        "@llvm-project//mlir:VectorDialect",
-    ],
-)
-
-lit_test_suite(
-    name = "tests",
-    srcs = glob(["*.mlir"]),
-    cfg = "//xla:lit.cfg.py",
-    tools = [
-        ":mlir_fusions_opt",
-        "@llvm-project//llvm:FileCheck",
-    ],
-)
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir
deleted file mode 100644
index 17b0f8d..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir
+++ /dev/null
@@ -1,179 +0,0 @@
-// RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s
-
-#map0 = affine_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)>
-func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) {
-  %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [-10, 10], %s1 in [0, 2]]
-  func.return %0#0, %0#1 : index, index
-}
-// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1, s0 mod 2)>
-
-// CHECK-LABEL: func.func @simplify_apply_indexing
-// CHECK-SAME:      %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index)
-// CHECK:         xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]] in [-10, 10]]
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2)>
-func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index,
-    %d2: index, %s0: index, %s1: index) -> (index, index, index) {
-  %0:3 = xla_gpu.apply_indexing #map0
-    (%d0 in [0, 1], %d1 in [0, 2], %d2 in [0, 3])
-    [%s0 in [-11, 11], %s1 in [0, 3]]
-  func.return %0#0, %0#1, %0#2 : index, index, index
-}
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (s0 + 1, s0 mod 2, d0 + d1)>
-
-// CHECK-LABEL: func.func @simplify_apply_indexing_remove_dims
-// CHECK-SAME:      %[[ARG_0:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME:      %[[ARG_1:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME:      %[[ARG_2:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME:      %[[ARG_3:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME:      %[[ARG_4:[a-zA-Z0-9_]+]]: index)
-// CHECK:       xla_gpu.apply_indexing #[[$MAP]]
-// CHECK-SAME:    (%[[ARG_0]] in [0, 1], %[[ARG_2]] in [0, 3])
-// CHECK-SAME:    [%[[ARG_3]] in [-11, 11]]
-
-// -----
-
-#map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0)>
-func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index)
-    -> (index, index, index, index, index) {
-  %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10], %d1 in [0, 2])[%s0 in [-1, 1]]
-  func.return %0#0, %0#1, %0#2, %0#3, %0#4  : index, index, index, index, index
-}
-// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-
-// CHECK-LABEL: func.func @fold_indexing_map_results
-// CHECK-SAME:  %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index)
-
-// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
-
-// CHECK:       %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]]
-// CHECK:       return %[[NEW_RESULT]], %[[C4]], %[[ARG_1]], %[[C1]], %[[ARG_2]]
-
-// -----
-
-#map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0)>
-func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) {
-  %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10], %d1 in [0, 2])[%s0 in [-1, 1]]
-  func.return %0#2 : index
-}
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
-
-// CHECK-LABEL: func.func @remove_unused_results
-// CHECK-SAME:  %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index)
-
-// CHECK:       %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]] in [0, 2])
-// CHECK:       return %[[NEW_RESULT]]
-
-// -----
-
-#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3)>
-func.func @fold_operands(%d0: index) -> index {
-  %d1 = arith.constant 1 : index
-  %s0 = arith.constant 2 : index
-  %s1 = arith.constant 3 : index
-  %0 = xla_gpu.apply_indexing #map0 (%d0 in [0, 10], %d1 in [0, 5])
-                                    [%s0 in [-10, 10], %s1 in [0, 4]]
-  func.return %0 : index
-}
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 + 3)>
-
-// CHECK-LABEL: func.func @fold_operands
-// CHECK-SAME:      %[[ARG_0:.*]]: index)
-// CHECK:         xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]] in [0, 10])
-
-// -----
-
-func.func @fold_operands_and_results(%arg0: index, %arg1: index)
-  -> (index, index) {
-  %0:2 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (0, d1)>
-    (%arg0 in [0, 4], %arg1 in [0, 5])
-  return %0#0, %0#1 : index, index
-}
-
-// CHECK-LABEL: func.func @fold_operands_and_results
-// CHECK-SAME:      %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index)
-// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
-// CHECK-NEXT: return %[[C0]], %[[ARG_1]] : index, index
-
-// -----
-
-func.func @fold_sequence(%arg0: index, %arg1: index) -> index {
-  %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)>
-    (%arg0 in [0, 5], %arg1 in [0, 4])
-  %1 = xla_gpu.apply_indexing affine_map<(d0) -> (d0 mod 100 + 42)>
-    (%0 in [0, 10000])
-  func.return %1 : index
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 + 42)>
-// CHECK-LABEL: func.func @fold_sequence
-// CHECK-SAME:      %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
-// CHECK-NEXT:  xla_gpu.apply_indexing #[[$MAP]]
-// CHECK-SAME:      (%[[ARG0]] in [0, 5], %[[ARG1]] in [0, 4])
-
-// -----
-
-func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index {
-  %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)>
-    (%arg0 in [0, 5], %arg1 in [0, 4])
-  %1 = xla_gpu.apply_indexing affine_map<()[s0] -> (s0 mod 100 + 42)>
-    [%0 in [0, 10000]]
-  func.return %1 : index
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 + 42)>
-// CHECK-LABEL: func.func @fold_sequence_sym
-// CHECK-SAME:      %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
-// CHECK-NEXT:  xla_gpu.apply_indexing #[[$MAP]]
-// CHECK-SAME:      (%[[ARG0]] in [0, 5], %[[ARG1]] in [0, 4])
-
-// -----
-
-func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index {
-  %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)>
-    (%arg0 in [0, 5], %arg1 in [0, 4])
-  %1 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)>
-    (%arg1 in [0, 4], %0 in [0, 10000])
-  func.return %1 : index
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
-// CHECK-LABEL: func.func @fold_sequence_shared_operands
-// CHECK-SAME:      %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
-// CHECK-NEXT:  xla_gpu.apply_indexing #[[$MAP]]
-// CHECK-SAME:      (%[[ARG1]] in [0, 4], %[[ARG0]] in [0, 5])
-
-// -----
-
-func.func @atomic_rmw_empty(%in: tensor<2x3xf32>, %i: index, %j: index)
-    -> (tensor<2x3xf32>) {
-  %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> {
-    ^bb0(%current : f32):
-      xla_gpu.yield %current : f32
-  }
-  return %ret : tensor<2x3xf32>
-}
-// CHECK-LABEL: func.func @atomic_rmw_empty
-// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x3xf32>
-// CHECK: return %[[ARG0]]
-
-
-// -----
-
-func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index)
-    -> (tensor<2x3xf32>) {
-  %cst = arith.constant 0.0 : f32
-  %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> {
-    ^bb0(%current : f32):
-      xla_gpu.yield %cst : f32
-  }
-  return %ret : tensor<2x3xf32>
-}
-// CHECK-LABEL: func.func @atomic_rmw_cst
-// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x3xf32>
-// CHECK-NEXT: %[[CST:.*]] = arith.constant
-// CHECK-NEXT: atomic_rmw
-// CHECK:      xla_gpu.yield %[[CST]]
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir
deleted file mode 100644
index ee2c2ae..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir
+++ /dev/null
@@ -1,140 +0,0 @@
-// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-flatten-tensors \
-// RUN: --verify-diagnostics | FileCheck %s
-
-func.func @tensor_extract(
-    %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>,
-    %arg1: index, %arg2: index) -> f32 {
-  %v = tensor.extract %arg0[%arg1, %arg2]
-      : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>
-  func.return %v : f32
-}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1 * 2 + d0)>
-
-// CHECK-LABEL: func.func @tensor_extract(
-// CHECK-SAME:      %[[SRC:.*]]: tensor<6xf32>,
-// CHECK-SAME:      %[[I:.*]]: index, %[[J:.*]]: index) -> f32 {
-// CHECK:        %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]
-// CHECK-SAME:     in [0, 1], %[[J]] in [0, 2])
-// CHECK:        tensor.extract %[[SRC]][%[[INDEX]]] : tensor<6xf32>
-
-// -----
-
-func.func @tensor_insert(
-    %arg0: tensor<10x24xcomplex<f32>>) -> tensor<10x24xcomplex<f32>> {
-  %c1 = arith.constant 1 : index
-  %real = arith.constant 3.0 : f32
-  %imag = arith.constant 2.0 : f32
-  %complex = complex.create %real, %imag : complex<f32>
-  %out = tensor.insert %complex into %arg0[%c1, %c1] : tensor<10x24xcomplex<f32>>
-  func.return %out : tensor<10x24xcomplex<f32>>
-}
-// CHECK-LABEL: func.func @tensor_insert(
-// CHECK-SAME:      %[[TENSOR:.*]]: tensor<240xcomplex<f32>>) -> tensor<240xcomplex<f32>> {
-// CHECK:         %[[INDEX:.*]] = arith.constant 25
-// CHECK:         %[[COMPLEX:.*]] = complex.create
-// CHECK:         tensor.insert %[[COMPLEX]] into %[[TENSOR]][%[[INDEX]]]
-// CHECK-SAME:      : tensor<240xcomplex<f32>>
-
-// -----
-
-func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index)
-    -> (tensor<2x4xf32>) {
-  %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
-    ^bb0(%current : f32):
-      %c42 = arith.constant 1.0 : f32
-      %add = arith.minimumf %current, %c42 : f32
-      xla_gpu.yield %add : f32
-  }
-  return %ret : tensor<2x4xf32>
-}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
-
-// CHECK-LABEL: func.func @atomic_rmw(
-// CHECK-SAME:      %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index,
-// CHECK-SAME:      %[[J:.*]]: index) -> tensor<8xf32> {
-// CHECK:         %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]]
-// CHECK-SAME:      (%[[I]] in [0, 1], %[[J]] in [0, 3])
-// CHECK:         xla_gpu.atomic_rmw %[[TENSOR]][%[[INDEX]]] : tensor<8xf32>
-
-// -----
-
-func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>)
-    -> (tensor<32x1024xf32>, tensor<64x8x4xf32>, f32) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c32 = arith.constant 32 : index
-  %c64 = arith.constant 64 : index
-  %c0_f32 = arith.constant 0.0 : f32
-  %for:2 = scf.for %i = %c0 to %c64 step %c32 iter_args(%t0_ = %t0, %t1_ = %t1)
-    -> (tensor<32x1024xf32>, tensor<64x8x4xf32>) {
-    %update0 = tensor.insert %c0_f32 into %t0_[%c1, %i] : tensor<32x1024xf32>
-    %update1 = tensor.insert %c0_f32 into %t1_[%i, %c1, %c1] : tensor<64x8x4xf32>
-    scf.yield %update0, %update1 : tensor<32x1024xf32>, tensor<64x8x4xf32>
-  } {some_attr}
-    return %for#0, %for#1, %c0_f32 : tensor<32x1024xf32>, tensor<64x8x4xf32>, f32
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 + 1024)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 32 + 5)>
-// CHECK-LABEL: func.func @for_loop(
-// CHECK-SAME:      %[[T0:.*]]: tensor<32768xf32>,
-// CHECK-SAME:      %[[T1:.*]]: tensor<2048xf32>) -> (tensor<32768xf32>, tensor<2048xf32>, f32) {
-
-// CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:  %[[C32:.*]] = arith.constant 32 : index
-// CHECK-DAG:  %[[C64:.*]] = arith.constant 64 : index
-// CHECK-DAG:  %[[F32:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:      %[[FOR:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[C64]]
-// CHECK-SAME:     step %[[C32]]
-// CHECK-SAME:     iter_args(%[[T0_:.*]] = %[[T0]], %[[T1_:.*]] = %[[T1]])
-// CHECK:        %[[IND0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[I]] in [0, 1023])
-// CHECK:        %[[UPD0:.*]] = tensor.insert %[[F32]] into %[[T0_]][%[[IND0]]]
-// CHECK:        %[[IND1:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]] in [0, 63])
-// CHECK:        %[[UPD1:.*]] = tensor.insert %[[F32]] into %[[T1_]][%[[IND1]]]
-// CHECK:        scf.yield %[[UPD0]], %[[UPD1]] : tensor<32768xf32>, tensor<2048xf32>
-
-// -----
-
-#map = affine_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36)>
-#map1 = affine_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4)>
-#map2 = affine_map<(d0, d1) -> ((d1 * 128 + d0) mod 9)>
-func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>,
-    %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>)
-     -> tensor<4000x4x9xf32> {
-  %c0 = arith.constant 0 : index
-  %c3999 = arith.constant 3999 : index
-  %th_x = gpu.thread_id  x {xla.range = [0 : index, 127 : index]}
-  %bl_x = gpu.block_id  x {xla.range = [0 : index, 393749 : index]}
-  %0 = xla_gpu.apply_indexing #map(%th_x in [0, 127], %bl_x in [0, 393749])
-  %extracted = tensor.extract %arg1[%0, %c0] : tensor<1400x1xi32>
-  %1 = arith.index_cast %extracted : i32 to index
-  %2 = arith.cmpi ule, %1, %c3999 : index
-  %3 = scf.if %2 -> (tensor<4000x4x9xf32>) {
-    %4 = xla_gpu.apply_indexing #map1(%th_x in [0, 127], %bl_x in [0, 393749])
-    %5 = xla_gpu.apply_indexing #map2(%th_x in [0, 127], %bl_x in [0, 393749])
-    %elem = tensor.extract %arg2[%0, %c0, %4, %5] : tensor<1400x1x4x9xf32>
-    %atomic_rmw = xla_gpu.atomic_rmw %arg3[%1, %4, %5] : tensor<4000x4x9xf32> {
-    ^bb0(%arg4: f32):
-      %6 = arith.addf %arg4, %elem : f32
-      xla_gpu.yield %6 : f32
-    }
-    scf.yield %atomic_rmw : tensor<4000x4x9xf32>
-  } else {
-    scf.yield %arg3 : tensor<4000x4x9xf32>
-  }
-  return %3 : tensor<4000x4x9xf32>
-}
-// CHECK-LABEL: func.func @if_op
-// CHECK-NOT:     builtin.unrealized_conversion_cast
-// CHECK:         scf.if %{{.*}} -> (tensor<144000xf32>) {
-// CHECK-COUNT-2:   scf.yield %{{.*}} : tensor<144000xf32>
-// CHECK:         return %{{.*}} : tensor<144000xf32>
-
-// -----
-
-func.func @dangling_cast(%arg0: tensor<6xf32>, %arg1: index) -> i32 {
-  %v = tensor.extract %arg0[%arg1] : tensor<6xf32>
-  %cast = builtin.unrealized_conversion_cast %v : f32 to i32
-  func.return %cast : i32
-}
-// CHECK: FlattenTensorsPass failed to converge
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir
deleted file mode 100644
index c5cdeeb..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir
+++ /dev/null
@@ -1,136 +0,0 @@
-// RUN: mlir_fusions_opt %s -split-input-file | mlir_fusions_opt -split-input-file | FileCheck %s
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: (d0, d1, d2)[s0] -> (d0),
-// CHECK-SAME: domain:
-// CHECK-SAME: d0 in [1, 2],
-// CHECK-SAME: d1 in [5, 8],
-// CHECK-SAME: d2 in [10, 12],
-// CHECK-SAME: s0 in [0, 32],
-// CHECK-SAME: d0 mod 2 in [0, 1],
-// CHECK-SAME: d0 + s0 in [1, 10]
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0),
-                            domain:
-                            d0 in [1, 2],
-                            d1 in [5, 8],
-                            d2 in [10, 12],
-                            s0 in [0, 32],
-                            d0 mod 2 in [0, 1],
-                            d0 + s0 in [1, 10]
-                            >
-
-func.func private @indexing_map_attr(tensor<32xf64, #map>)
-// CHECK-LABEL: @indexing_map_attr
-// CHECK: tensor<32xf64, #[[$INDEX_MAP]]>
-
-// -----
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: (d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2)
-// CHECK-SAME: domain:
-// CHECK-SAME: d0 in [1, 2]
-// CHECK-SAME: d1 in [5, 8]
-// CHECK-SAME: s0 in [0, 10]
-// CHECK-SAME: s1 in [0, 5]
-// CHECK-SAME: s2 in [0, 32]
-// CHECK-SAME: d0 mod 2 in [0, 1]
-// CHECK-SAME: d0 + s0 in [1, 10]
-// CHECK-SAME: d1 + s1 + s2 in [1, 32]
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2),
-                            domain:
-                            d0 in [1, 2],
-                            d1 in [5, 8],
-                            s0 in [0, 10],
-                            s1 in [0, 5],
-                            s2 in [0, 32],
-                            d0 mod 2 in [0, 1],
-                            d0 + s0 in [1, 10],
-                            d1 + s1 + s2 in [1, 32]
-                            >
-func.func private @more_range_vars(tensor<32xf64, #map>)
-// CHECK-LABEL: @more_range_vars
-// CHECK: tensor<32xf64, #[[$INDEX_MAP]]>
-
-// -----
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: (d0)[s0] -> (d0)
-// CHECK-SAME: domain:
-// CHECK-SAME: d0 in [0, 100]
-// CHECK-SAME: s0 in [-3, -1]
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0),
-                            domain:
-                            d0 in [0, 100],
-                            s0 in [-3, -1]
-                            >
-func.func private @indexing_map_small(tensor<100xf64, #map>)
-// CHECK-LABEL: @indexing_map_small
-// CHECK: tensor<100xf64, #[[$INDEX_MAP]]>
-
-// -----
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: (d0, d1, d2)[s0] -> (d0)
-// CHECK-SAME: domain:
-// CHECK-SAME: d0 in [1, 2]
-// CHECK-SAME: d1 in [5, 8]
-// CHECK-SAME: d2 in [10, 12]
-// CHECK-SAME: s0 in [0, 32]
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0),
-                            domain:
-                            d0 in [1, 2],
-                            d1 in [5, 8],
-                            d2 in [10, 12],
-                            s0 in [0, 32]
-                            >
-func.func private @no_constraints(tensor<32xf64, #map>)
-// CHECK-LABEL: @no_constraints
-// CHECK: tensor<32xf64, #[[$INDEX_MAP]]>
-
-// -----
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: ()[s0] -> (s0)
-// CHECK-SAME: domain:
-// CHECK-SAME: s0 in [3, 5]
-// CHECK-SAME: s0 mod 2 in [0, 1]
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<()[s0] -> (s0),
-                            domain:
-                            s0 in [3, 5],
-                            s0 mod 2 in [0, 1]
-                            >
-func.func private @no_dimensions(tensor<100xf64, #map>)
-// CHECK-LABEL: @no_dimensions
-// CHECK: tensor<100xf64, #[[$INDEX_MAP]]>
-
-// -----
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: (d0) -> (d0)
-// CHECK-SAME: domain:
-// CHECK-SAME: d0 in [3, 5]
-// CHECK-SAME: d0 mod 2 in [0, 1]
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<(d0) -> (d0),
-                            domain:
-                            d0 in [3, 5],
-                            d0 mod 2 in [0, 1]
-                            >
-func.func private @no_symbols(tensor<100xf64, #map>)
-// CHECK-LABEL: @no_symbols
-// CHECK: tensor<100xf64, #[[$INDEX_MAP]]>
-
-// -----
-
-// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map<
-// CHECK-SAME: () -> ()
-// CHECK-SAME: >
-#map = #xla_gpu.indexing_map<() -> ()>
-func.func private @empty(tensor<100xf64, #map>)
-// CHECK-LABEL: @empty
-// CHECK: tensor<100xf64, #[[$INDEX_MAP]]>
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir
deleted file mode 100644
index fbef7c0..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir
+++ /dev/null
@@ -1,8 +0,0 @@
-// RUN: mlir_fusions_opt  %s -split-input-file -verify-diagnostics
-
-#map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)>
-func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) {
-  // expected-error @+1 {{operand, lower_bounds, upper_bounds count and affine map dimension and symbol count must match}}
-  %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2])
-  func.return %0#0, %0#1 : index, index
-}
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir
deleted file mode 100644
index 2125e6f..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir
+++ /dev/null
@@ -1,854 +0,0 @@
-// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=6.0" \
-// RUN: | FileCheck %s
-
-// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=7.0" \
-// RUN: | FileCheck %s --check-prefix=CHECK-VOLTA
-
-// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=8.0" \
-// RUN: | FileCheck %s --check-prefix=CHECK-AMPERE
-
-// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=9.0" \
-// RUN: | FileCheck %s --check-prefix=CHECK-HOPPER
-
-// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=true gpu_arch=gfx908:sramecc+:xnack" \
-// RUN: | FileCheck %s --check-prefix=CHECK-GFX908-MI100
-
-// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=true gpu_arch=gfx90a:sramecc+:xnack" \
-// RUN: | FileCheck %s --check-prefix=CHECK-GFX90A-MI200
-
-module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>} {
-  func.func private @add(%arg0: f32, %arg1: f32) -> f32 {
-    %sum = arith.addf %arg0, %arg1 : f32
-    func.return %sum : f32
-  }
-
-  func.func private @tensorarg(%arg0: tensor<43xf32> {xla.invariant, xla.slice_index = 0}, %arg1: index) -> f32 {
-    %v1 = arith.constant 2.0 : f32
-    %v2 = tensor.extract %arg0[%arg1] : tensor<43xf32>
-    %sum = func.call @add(%v1, %v2) : (f32, f32) -> f32
-    func.return %sum : f32
-  }
-
-  func.func @tensorcall(%arg0: tensor<43xf32> {xla.slice_index = 0}, %arg1: index) -> f32 {
-    %call = func.call @tensorarg(%arg0, %arg1) : (tensor<43xf32>, index) -> f32
-    func.return %call : f32
-  }
-
-  func.func @stores(%arg0: tensor<17xf32> {xla.slice_index = 0}, %arg1: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> {
-    %c17 = arith.constant 17 : index
-    %c23 = arith.constant 23 : index
-    %cst = arith.constant 3.0 : f32
-    %out = tensor.insert %cst into %arg1[%c17] : tensor<43xf32>
-    %out2 = tensor.insert %cst into %out[%c23] : tensor<43xf32>
-    func.return %out2 : tensor<43xf32>
-  }
-}
-
-// CHECK:        func.func private @add(%{{.*}}: f32, %{{.*}}: f32) -> f32 {
-// CHECK-NEXT:     arith.addf
-// CHECK-NEXT:     return
-
-// CHECK:        func.func private @tensorarg(%[[ARG0:.*]]: !llvm.ptr
-// CHECK-SAME:        {xla.invariant, xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index) -> f32 {
-// CHECK-DAG:       %[[C2:.*]] = arith.constant 2.000000e+00
-// CHECK-DAG:       %[[IDX:.*]] = arith.index_castui %[[ARG1]] : index to i32
-// CHECK-DAG:       %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX]]]
-// CHECK-DAG:       %[[V2:.*]] = llvm.load %[[PTR]] invariant
-// CHECK:           %[[RET:.*]] = call @add(%[[C2]], %[[V2]])
-// CHECK:           return %[[RET]]
-
-// CHECK:        func.func @tensorcall(%[[ARG0:.*]]: !llvm.ptr
-// CHECK-SAME:        {xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index)
-// CHECK:           %[[RET:.*]] = call @tensorarg(%[[ARG0]], %[[ARG1]])
-// CHECK:           return %[[RET]]
-
-// CHECK:        func.func @stores(
-// CHECK-SAME:        %[[ARG0:.*]]: !llvm.ptr {xla.slice_index = 0 : i64},
-// CHECK-SAME:        %[[ARG1:.*]]: !llvm.ptr {xla.slice_index = 1 : i64})
-// CHECK-NEXT:      %[[CST:.*]] = arith.constant 3.000000e+00 : f32
-// CHECK-NEXT:      %[[PTR1:.*]] = llvm.getelementptr inbounds %[[ARG1]][17]
-// CHECK-NEXT:      llvm.store %[[CST]], %[[PTR1]]
-// CHECK-NEXT:      %[[PTR2:.*]] = llvm.getelementptr inbounds %[[ARG1]][23]
-// CHECK-NEXT:      llvm.store %[[CST]], %[[PTR2]]
-// CHECK-NEXT:      return
-
-// -----
-
-module {
-  func.func @layout(
-      %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>,
-      %arg1: index, %arg2: index) -> f32 {
-    %v = tensor.extract %arg0[%arg1, %arg2]
-        : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>
-    func.return %v : f32
-  }
-}
-
-// CHECK:        #[[$MAP:.*]] = affine_map<(d0, d1) -> (d1 * 2 + d0)>
-// CHECK-LABEL:  @layout(
-// CHECK-SAME:      %[[ARG0:.*]]: !llvm.ptr,
-// CHECK-SAME:      %[[X:.*]]: index, %[[Y:.*]]: index
-// CHECK:        %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP]]
-// CHECK-SAME:      (%[[X]] in [0, 1], %[[Y]] in [0, 2])
-// CHECK:        %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i64
-// CHECK:        %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]]
-// CHECK:        llvm.load %[[PTR]]
-
-// -----
-
-module {
-  func.func @store_control_flow(
-     %arg0: tensor<2xf32>,
-     %arg1: index
-  ) -> tensor<2xf32> {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %cst = arith.constant 0.0 : f32
-    %cst2 = arith.constant 1.0 : f32
-
-    %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> {
-      %new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32>
-      scf.yield %new_out : tensor<2xf32>
-    }
-
-    %inbounds = arith.cmpi sle, %arg1, %c1 : index
-    %result = scf.if %inbounds -> tensor<2xf32> {
-      %if = tensor.insert %cst2 into %for[%arg1] : tensor<2xf32>
-      scf.yield %if : tensor<2xf32>
-    } else {
-      scf.yield %for : tensor<2xf32>
-    }
-    func.return %result : tensor<2xf32>
-  }
-}
-
-// CHECK:     @store_control_flow(%[[ARG0:.*]]: !llvm.ptr, %[[X:.*]]: index) {
-// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK:       scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
-// CHECK:         %[[CAST:.*]] = arith.index_castui %[[I]] : index to i64
-// CHECK:         %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[CAST]]]
-// CHECK:         llvm.store {{.*}}, %[[PTR]]
-// CHECK:       %[[INBOUNDS:.*]] = arith.cmpi
-// CHECK:       scf.if %[[INBOUNDS]] {
-// CHECK:         llvm.store
-// CHECK-NEXT:  }
-// CHECK-NEXT:  return
-
-// -----
-
-module {
-  func.func @large_tensor(
-      %arg0: tensor<1024x1024x1024x6xf32>,
-      %arg1: index) -> f32 {
-    %v = tensor.extract %arg0[%arg1, %arg1, %arg1, %arg1] : tensor<1024x1024x1024x6xf32>
-    func.return %v : f32
-  }
-}
-
-// CHECK: @large_tensor
-// CHECK: arith.index_castui {{.*}} : index to i64
-
-// -----
-
-module {
-  func.func @extract_from_constant(%arg0: tensor<2x1xf32>,
-      %arg1: index, %arg2: index) -> f32 {
-    %cst = arith.constant dense<[[1.000000e+00], [2.000000e+00]]> : tensor<2x1xf32>
-    %extracted = tensor.extract %arg0[%arg1, %arg2] : tensor<2x1xf32>
-    %extracted_0 = tensor.extract %cst[%arg1, %arg2] : tensor<2x1xf32>
-    %0 = arith.addf %extracted, %extracted_0 : f32
-    return %0 : f32
-  }
-}
-// CHECK: llvm.mlir.global private constant @global_cst_0(dense<
-// CHECK-SAME: [1.000000e+00, 2.000000e+00]> : tensor<2xf32>) {addr_space = 0 : i32} : !llvm.array<2 x f32>
-// CHECK: @extract_from_constant
-// CHECK: %[[ADDR_OF:.*]] = llvm.mlir.addressof @global_cst_0 : !llvm.ptr
-// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ADDR_OF]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> f32
-// CHECK: %[[ADD:.*]] = arith.addf %{{.*}}, %[[LOAD]] : f32
-// CHECK: return %[[ADD]] : f32
-
-// -----
-
-module {
-  func.func @vector_constant() -> vector<2xindex> {
-    %c1 = arith.constant dense<[1, 2]> : vector<2xindex>
-    func.return %c1 : vector<2xindex>
-  }
-}
-
-// vector constants should not be rewritten.
-// CHECK: @vector_constant
-// CHECK-NEXT: arith.constant
-
-// -----
-
-module {
-  func.func @complex_tensor_insert(
-      %arg0: tensor<10xcomplex<f32>>) -> tensor<10xcomplex<f32>> {
-    %c1 = arith.constant 1 : index
-    %real = arith.constant 3.0 : f32
-    %imag = arith.constant 2.0 : f32
-    %complex = complex.create %real, %imag : complex<f32>
-    %out = tensor.insert %complex into %arg0[%c1] : tensor<10xcomplex<f32>>
-    func.return %out : tensor<10xcomplex<f32>>
-  }
-}
-
-// CHECK: @complex_tensor_insert(%[[ARG0:.*]]: !llvm.ptr
-// CHECK: %[[C:.*]] = complex.create
-// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)>
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[C]] : complex<f32> to !llvm.struct<(f32, f32)>
-// CHECK: llvm.store %[[CAST]], %[[GEP]] : !llvm.struct<(f32, f32)>, !llvm.ptr
-
-// -----
-
-module {
-  func.func @complex_tensor_extract(
-      %arg0: tensor<10xcomplex<f32>>) -> complex<f32> {
-    %c1 = arith.constant 1 : index
-    %v2 = tensor.extract %arg0[%c1] : tensor<10xcomplex<f32>>
-    func.return %v2 : complex<f32>
-  }
-}
-
-// CHECK: @complex_tensor_extract(%[[ARG0:.*]]: !llvm.ptr
-// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)>
-// CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> !llvm.struct<(f32, f32)>
-// CHECK: builtin.unrealized_conversion_cast %[[LOAD]] : !llvm.struct<(f32, f32)> to complex<f32>
-
-// -----
-
-module {
-  // This example is a bit silly; in real life there wouldn't be a loop (the
-  // loop body would be executed by different threads). We're just doing it this
-  // way so control flow with shared memory is tested as well.
-  func.func @transpose_shared(%in: tensor<32x32xf32>,
-                              %out: tensor<32x32xf32>) -> tensor<32x32xf32> {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c32 = arith.constant 32 : index
-
-    %shared = xla_gpu.allocate_shared : tensor<32x32xf32>
-    %loaded_tile = scf.for %i = %c0 to %c32 step %c1
-        iter_args(%tile = %shared) -> tensor<32x32xf32> {
-      %inner_loaded_tile = scf.for %j = %c0 to %c32 step %c1
-          iter_args(%inner_tile = %tile) -> tensor<32x32xf32> {
-        %v = tensor.extract %in[%i, %j] : tensor<32x32xf32>
-        %inserted = tensor.insert %v into %inner_tile[%i, %j]
-            : tensor<32x32xf32>
-        scf.yield %inserted : tensor<32x32xf32>
-      }
-      scf.yield %inner_loaded_tile : tensor<32x32xf32>
-    }
-
-    %synced = xla_gpu.sync_threads %shared : tensor<32x32xf32>
-    %written_tile = scf.for %i = %c0 to %c32 step %c1
-        iter_args(%written = %out) -> tensor<32x32xf32> {
-      %inner_written_tile = scf.for %j = %c0 to %c32 step %c1
-          iter_args(%inner_written = %written) -> tensor<32x32xf32> {
-        %v = tensor.extract %shared[%j, %i] : tensor<32x32xf32>
-        %inserted = tensor.insert %v into %inner_written[%i, %j]
-            : tensor<32x32xf32>
-        scf.yield %inserted : tensor<32x32xf32>
-      }
-      scf.yield %inner_written_tile : tensor<32x32xf32>
-    }
-
-    return %written_tile : tensor<32x32xf32>
-  }
-}
-
-// CHECK:      llvm.mlir.global private @[[SHARED:shared_.*]]()
-// CHECK-SAME:     {addr_space = 3 : i32} : !llvm.array<1024 x f32>
-// CHECK:      @transpose_shared
-// CHECK:        %[[ADDR:.*]] = llvm.mlir.addressof @[[SHARED]] : !llvm.ptr<3>
-// CHECK:        %[[CAST:.*]] = llvm.addrspacecast %[[ADDR]]
-// CHECK-SAME:       : !llvm.ptr<3> to !llvm.ptr
-// CHECK:        scf.for
-// CHECK:          scf.for
-// CHECK:            %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]]
-// CHECK:            llvm.store {{.*}} %[[ELEM_ADDR]]
-// CHECK:        gpu.barrier
-// CHECK:        scf.for
-// CHECK:          scf.for
-// CHECK:            %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]]
-// CHECK:            llvm.load %[[ELEM_ADDR]]
-
-// -----
-
-module {
-  func.func @atomic_rmw_f32(%in: tensor<2x4xf32>, %i: index, %j: index)
-      -> (tensor<2x4xf32>) {
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
-      ^bb0(%current : f32):
-        %c42 = arith.constant 1.0 : f32
-        %add = arith.minimumf %current, %c42 : f32
-        xla_gpu.yield %add : f32
-    }
-    return %ret : tensor<2x4xf32>
-  }
-}
-
-// CHECK: @atomic_rmw_f32
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-NEXT: %[[INIT:.*]] = llvm.load %[[ADDR]]
-// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]])
-// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f32 to i32
-// CHECK-NEXT: llvm.cmpxchg %[[ADDR]], %[[VAR]], %[[RES]]
-
-// -----
-
-module {
-  func.func @atomic_rmw_f16(%in: tensor<2x4xf16>, %i: index, %j: index)
-      -> (tensor<2x4xf16>) {
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> {
-      ^bb0(%current : f16):
-        %c1 = arith.constant 1.0 : f16
-        %add = arith.addf %current, %c1 : f16
-        xla_gpu.yield %add : f16
-    }
-    return %ret : tensor<2x4xf16>
-  }
-}
-
-// CHECK: @atomic_rmw_f16
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]]
-// CHECK-NEXT: %[[OFFSET:.*]] = llvm.and %[[ADDR_INT]], %{{.*}}
-// CHECK-NEXT: %[[INDEX:.*]] = llvm.mul %[[OFFSET]], %{{.*}}
-// CHECK-NEXT: %[[BASE:.*]] = llvm.getelementptr inbounds %[[ADDR]][%[[INDEX]]]
-// CHECK: %[[INIT:.*]] = llvm.load %[[BASE]]
-// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]])
-// CHECK-NEXT: %[[VAR_SHIFT:.*]] = llvm.lshr %[[VAR]], %{{.*}}
-// CHECK-NEXT: %[[VAR_TRUNC:.*]] = llvm.trunc %[[VAR_SHIFT]]
-// CHECK-NEXT: llvm.bitcast %[[VAR_TRUNC]] : i16 to f16
-// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f16 to i16
-// CHECK-NEXT: %[[RES_WIDE:.*]] = llvm.zext %[[RES]]
-// CHECK-NEXT: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}}
-// CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}}
-// CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]]
-// CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]]
-
-// -----
-
-module {
-  func.func @atomic_rmw_overwrite(%in: tensor<2x4xf16>, %i: index, %j: index)
-      -> (tensor<2x4xf16>) {
-    %c1 = arith.constant 1.0 : f16
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> {
-      ^bb0(%current : f16):
-        xla_gpu.yield %c1 : f16
-    }
-    return %ret : tensor<2x4xf16>
-  }
-}
-// CHECK: @atomic_rmw_overwrite
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]]
-// CHECK-NEXT: %[[OFFSET:.*]] = llvm.and %[[ADDR_INT]], %{{.*}}
-// CHECK-NEXT: %[[INDEX:.*]] = llvm.mul %[[OFFSET]], %{{.*}}
-// CHECK-NEXT: %[[BASE:.*]] = llvm.getelementptr inbounds %[[ADDR]][%[[INDEX]]]
-// CHECK: %[[INIT:.*]] = llvm.load %[[BASE]]
-// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]])
-// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f16 to i16
-// CHECK-NEXT: %[[RES_WIDE:.*]] = llvm.zext %[[RES]]
-// CHECK-NEXT: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}}
-// CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}}
-// CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]]
-// CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]]
-
-// -----
-
-module {
-  func.func @shared_complex() -> tensor<10xcomplex<f32>> {
-    %shared = xla_gpu.allocate_shared : tensor<10xcomplex<f32>>
-    return %shared : tensor<10xcomplex<f32>>
-  }
-}
-
-// CHECK: llvm.mlir.global private @{{.*}}() {addr_space = 3 : i32} : !llvm.array<10 x struct<(f32, f32)>>
-// CHECK: @shared_complex
-
-// -----
-
-module {
-  func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) -> tensor<10xi4> {
-    %v = tensor.extract %arg[%i] : tensor<10xi4>
-    %r = tensor.insert %v into %arg[%j] : tensor<10xi4>
-    return %r : tensor<10xi4>
-  }
-}
-
-// CHECK: @i4_load_store
-// CHECK: llvm.getelementptr
-// CHECK-SAME: -> !llvm.ptr, i8
-// CHECK: llvm.load
-// CHECK: llvm.getelementptr
-// CHECK-SAME: -> !llvm.ptr, i8
-// CHECK: llvm.load
-// CHECK: llvm.store
-
-// -----
-
-module {
-  func.func @direct_atomic_rmw_overwrite(%in: tensor<2x4xi32>,
-    %i: index, %j: index) -> (tensor<2x4xi32>) {
-    %c2 = arith.constant 2 : i32
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
-      ^bb0(%current : i32):
-        xla_gpu.yield %c2 : i32
-    }
-    return %ret : tensor<2x4xi32>
-  }
-}
-// CHECK: @direct_atomic_rmw_overwrite
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.store %[[C2]], %[[ADDR]] atomic unordered {alignment = 4 : i64}
-
-// -----
-
-module {
-  func.func @direct_atomic_rmw_addi(%in: tensor<2x4xi32>,
-    %i: index, %j: index) -> (tensor<2x4xi32>) {
-    %c2 = arith.constant 2 : i32
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
-      ^bb0(%current : i32):
-        %min = arith.addi %current, %c2 : i32
-        xla_gpu.yield %c2 : i32
-    }
-    return %ret : tensor<2x4xi32>
-  }
-}
-// CHECK: @direct_atomic_rmw_addi
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw add %[[ADDR]], %[[C2]] seq_cst
-
-// -----
-
-module {
-  func.func @direct_atomic_rmw_maxsi(%in: tensor<2x4xi32>,
-    %i: index, %j: index) -> (tensor<2x4xi32>) {
-    %c2 = arith.constant 2 : i32
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
-      ^bb0(%current : i32):
-        %min = arith.maxsi %current, %c2 : i32
-        xla_gpu.yield %c2 : i32
-    }
-    return %ret : tensor<2x4xi32>
-  }
-}
-// CHECK: @direct_atomic_rmw_maxsi
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw max %[[ADDR]], %[[C2]] seq_cst
-
-// -----
-
-module {
-  func.func @direct_atomic_rmw_maxui(%in: tensor<2x4xi32>,
-    %i: index, %j: index) -> (tensor<2x4xi32>) {
-    %c2 = arith.constant 2 : i32
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
-      ^bb0(%current : i32):
-        %min = arith.maxui %current, %c2 : i32
-        xla_gpu.yield %c2 : i32
-    }
-    return %ret : tensor<2x4xi32>
-  }
-}
-// CHECK: @direct_atomic_rmw_maxui
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw umax %[[ADDR]], %[[C2]] seq_cst
-
-// -----
-
-module {
-  func.func @direct_atomic_rmw_minsi(%in: tensor<2x4xi32>,
-    %i: index, %j: index) -> (tensor<2x4xi32>) {
-    %c2 = arith.constant 2 : i32
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
-      ^bb0(%current : i32):
-        %min = arith.minsi %current, %c2 : i32
-        xla_gpu.yield %c2 : i32
-    }
-    return %ret : tensor<2x4xi32>
-  }
-}
-// CHECK: @direct_atomic_rmw_minsi
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw min %[[ADDR]], %[[C2]] seq_cst
-
-// -----
-
-module {
-  func.func @direct_atomic_rmw_minui(%in: tensor<2x4xi32>,
-    %i: index, %j: index) -> (tensor<2x4xi32>) {
-    %c2 = arith.constant 2 : i32
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
-      ^bb0(%current : i32):
-        %min = arith.minui %current, %c2 : i32
-        xla_gpu.yield %c2 : i32
-    }
-    return %ret : tensor<2x4xi32>
-  }
-}
-// CHECK: @direct_atomic_rmw_minui
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw umin %[[ADDR]], %[[C2]] seq_cst
-
-// -----
-
-module {
-  func.func @direct_atomic_rmw_fadd_f32(%in: tensor<2x4xf32>,
-    %i: index, %j: index) -> (tensor<2x4xf32>) {
-    %c2 = arith.constant 2.0 : f32
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
-      ^bb0(%current : f32):
-        %min = arith.addf %current, %c2 : f32
-        xla_gpu.yield %c2 : f32
-    }
-    return %ret : tensor<2x4xf32>
-  }
-}
-// CHECK-LABEL: @direct_atomic_rmw_fadd_f32
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f32
-// CHECK-VOLTA: %[[C2:.*]] = arith.constant 2
-// CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f32
-// CHECK-AMPERE: %[[C2:.*]] = arith.constant 2
-// CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f32
-// CHECK-GFX908-MI100: %[[C2:.*]] = arith.constant 2
-// CHECK-GFX908-MI100: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-GFX908-MI100: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1>
-// CHECK-GFX908-MI100: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst
-
-// CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f32
-// CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2
-// CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-GFX90A-MI200: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1>
-// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst
-
-// -----
-
-module {
-  func.func @direct_atomic_rmw_fadd_f16(%in: tensor<2x4xf16>,
-    %i: index, %j: index) -> (tensor<2x4xf16>) {
-    %c2 = arith.constant 2.0 : f16
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> {
-      ^bb0(%current : f16):
-        %min = arith.addf %current, %c2 : f16
-        xla_gpu.yield %c2 : f16
-    }
-    return %ret : tensor<2x4xf16>
-  }
-}
-// CHECK-LABEL: @direct_atomic_rmw_fadd_f16
-// CHECK-NOT: llvm.atomicrmw fadd
-
-// CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f16
-// CHECK-VOLTA: %[[C2:.*]] = arith.constant 2
-// CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f16
-// CHECK-AMPERE: %[[C2:.*]] = arith.constant 2
-// CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f16
-// CHECK-GFX908-MI100-NOT: llvm.atomicrmw fadd
-
-// CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f16
-// CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2
-// CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-GFX90A-MI200: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1>
-// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst
-
-// -----
-
-module {
-  func.func @direct_atomic_rmw_fadd_bf16(%in: tensor<2x4xbf16>,
-      %i: index, %j: index) -> (tensor<2x4xbf16>) {
-    %c2 = arith.constant 2.0 : bf16
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xbf16> {
-      ^bb0(%current : bf16):
-        %min = arith.addf %current, %c2 : bf16
-        xla_gpu.yield %c2 : bf16
-    }
-    return %ret : tensor<2x4xbf16>
-  }
-}
-// CHECK-LABEL: @direct_atomic_rmw_fadd_bf16
-// CHECK-NOT: llvm.atomicrmw fadd
-
-// CHECK-HOPPER-LABEL: @direct_atomic_rmw_fadd_bf16
-// CHECK-HOPPER: %[[C2:.*]] = arith.constant 2
-// CHECK-HOPPER: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-HOPPER: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// -----
-
-module {
-  func.func @direct_atomic_rmw_fadd_f64(%in: tensor<2x4xf64>,
-    %i: index, %j: index) -> (tensor<2x4xf64>) {
-    %c2 = arith.constant 2.0 : f64
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf64> {
-      ^bb0(%current : f64):
-        %min = arith.addf %current, %c2 : f64
-        xla_gpu.yield %c2 : f64
-    }
-    return %ret : tensor<2x4xf64>
-  }
-}
-// CHECK-LABEL: @direct_atomic_rmw_fadd_f64
-// CHECK: %[[C2:.*]] = arith.constant 2
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f64
-// CHECK-VOLTA: %[[C2:.*]] = arith.constant 2
-// CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f64
-// CHECK-AMPERE: %[[C2:.*]] = arith.constant 2
-// CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
-
-// CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f64
-// CHECK-GFX908-MI100-NOT: llvm.atomicrmw fadd
-
-// CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f64
-// CHECK-GFX90A-MI200-NOT: llvm.atomicrmw fadd
-
-// -----
-
-module {
-  func.func @direct_atomic_rmw_maximumf(%in: tensor<2x4xf32>,
-    %i: index, %j: index) -> (tensor<2x4xf32>) {
-    %c2 = arith.constant 2.0 : f32
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
-      ^bb0(%current : f32):
-        %min = arith.maximumf %current, %c2 : f32
-        xla_gpu.yield %c2 : f32
-    }
-    return %ret : tensor<2x4xf32>
-  }
-}
-// CHECK-LABEL: @direct_atomic_rmw_maximumf
-
-// CHECK: %[[MODIFIER:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: %[[ADDR:.*]] = llvm.getelementptr
-// CHECK: %[[CURRENT:.*]] = llvm.load %[[ADDR]] : !llvm.ptr -> f32
-// CHECK: %[[CURRENT_IS_NAN:.*]] = llvm.fcmp "uno" %[[CURRENT]], %[[CURRENT]] : f32
-// CHECK: scf.if %[[CURRENT_IS_NAN]] {
-// CHECK: } else {
-// CHECK:   %[[MODIFIER_IS_NAN:.*]] = llvm.fcmp "uno" %[[MODIFIER]], %[[MODIFIER]] : f32
-// CHECK:   %[[MODIFIER_OR_NAN:.*]] = llvm.select %[[MODIFIER_IS_NAN]], %[[NAN]], %[[MODIFIER]] : i1, f32
-// CHECK:   %[[VAL_13:.*]] = llvm.fcmp "ult" %[[CURRENT]], %[[MODIFIER_OR_NAN]] : f32
-// CHECK:   scf.if %[[VAL_13]] {
-// CHECK:     %[[INT_MODIFIER_OR_NAN:.*]] = llvm.bitcast %[[MODIFIER_OR_NAN]] : f32 to i32
-// CHECK:     %[[IS_POSITIVE:.*]] = llvm.icmp "sge" %[[INT_MODIFIER_OR_NAN]], %[[C0]] : i32
-// CHECK:     scf.if %[[IS_POSITIVE]] {
-// CHECK:       llvm.atomicrmw max %[[ADDR]], %[[INT_MODIFIER_OR_NAN]] seq_cst
-// CHECK:     } else {
-// CHECK:       llvm.atomicrmw umin %[[ADDR]], %[[INT_MODIFIER_OR_NAN]] seq_cst
-// CHECK:     }
-// CHECK:   }
-// CHECK: }
-// CHECK: return
-
-// -----
-
-module {
-  func.func @atomic_rmw_c32(%in: tensor<2x4xcomplex<f32>>, %i: index, %j: index)
-      -> (tensor<2x4xcomplex<f32>>) {
-    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xcomplex<f32>> {
-      ^bb0(%current : complex<f32>):
-        %a = complex.add %current, %current : complex<f32>
-        xla_gpu.yield %a : complex<f32>
-    }
-    return %ret : tensor<2x4xcomplex<f32>>
-  }
-}
-
-// CHECK-LABEL: @atomic_rmw_c32
-
-// CHECK: scf.while (%[[ITER_ARG:.*]] = %{{.*}}) : (i64) -> i64
-// CHECK: %[[TMP:.*]] = llvm.alloca
-// CHECK: llvm.store %[[ITER_ARG]], %[[TMP]]
-// CHECK: %[[LD:.*]] = llvm.load %[[TMP]] : {{.*}} -> !llvm.struct<(f32, f32)>
-// CHECK: builtin.unrealized_conversion_cast %[[LD]] : {{.*}} to complex<f32>
-
-// -----
-
-module {
-  func.func @unused_index_switch_results(%i: index) -> index {
-    %ret, %ret2 = scf.index_switch %i -> tensor<2x4xi32>, tensor<3xf32>
-    case 0 {
-      %x, %y = "dummy.op1"() : () -> (tensor<2x4xi32>, tensor<3xf32>)
-      scf.yield %x, %y : tensor<2x4xi32>, tensor<3xf32>
-    }
-    default {
-      %x, %y = "dummy.op2"() : () -> (tensor<2x4xi32>, tensor<3xf32>)
-      scf.yield %x, %y : tensor<2x4xi32>, tensor<3xf32>
-    }
-    return %i : index
-  }
-}
-
-// CHECK-LABEL: func.func @unused_index_switch_results
-// CHECK-SAME:      (%[[I:.*]]: index)
-// CHECK-NEXT:    scf.index_switch %[[I]]
-// CHECK-NEXT:    case 0 {
-// CHECK-NEXT:      "dummy.op1"
-// CHECK-NEXT:      scf.yield
-// CHECK-NEXT:    }
-// CHECK-NEXT:    default {
-// CHECK-NEXT:      "dummy.op2"
-// CHECK-NEXT:    }
-// CHECK-NEXT:    return %[[I]] : index
-
-// -----
-
-module {
-  func.func @transfer_write(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> {
-    %c16 = arith.constant 16 : index
-    %c22 = arith.constant 22 : index
-    %cst = arith.constant dense<[1.0, 2.0]> : vector<2xf32>
-    %out = vector.transfer_write %cst, %arg0[%c16] : vector<2xf32>, tensor<43xf32>
-    %out2 = vector.transfer_write %cst, %out[%c22] : vector<2xf32>, tensor<43xf32>
-    func.return %out2 : tensor<43xf32>
-  }
-}
-
-// CHECK-LABEL: @transfer_write
-// CHECK:           %[[PTR1:.*]] = llvm.getelementptr inbounds %[[BUF:.*]][16]
-// CHECK-NEXT:      llvm.store %[[CST:.*]], %[[PTR1]]
-// CHECK-NEXT:      %[[PTR2:.*]] = llvm.getelementptr inbounds %[[BUF]][22]
-// CHECK-NEXT:      llvm.store %[[CST]], %[[PTR2]]
-
-// -----
-
-module {
-  func.func @transfer_read(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> vector<2xf32> {
-    %c16 = arith.constant 16 : index
-    %c0 = arith.constant 0.0 : f32
-    %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf32>, vector<2xf32>
-    func.return %out : vector<2xf32>
-  }
-}
-
-// CHECK-LABEL: @transfer_read
-// CHECK:           %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[16]
-// CHECK-NEXT:      llvm.load %[[PTR]] : !llvm.ptr -> vector<2xf32>
-
-// -----
-
-module {
-  func.func @transfer_write_i1(%arg0: tensor<43xi1> {xla.slice_index = 1},
-                               %v1: vector<2xi1>, %v2: vector<2xi1>) -> tensor<43xi1> {
-    %c16 = arith.constant 16 : index
-    %c22 = arith.constant 22 : index
-    %out = vector.transfer_write %v1, %arg0[%c16] : vector<2xi1>, tensor<43xi1>
-    %out2 = vector.transfer_write %v2, %out[%c22] : vector<2xi1>, tensor<43xi1>
-    func.return %out2 : tensor<43xi1>
-  }
-}
-
-// CHECK-LABEL: @transfer_write_i1
-// CHECK-SAME:      (%[[ARG0:.*]]: !llvm.ptr
-// CHECK-SAME:       %[[V1:.*]]: vector<2xi1>, %[[V2:.*]]: vector<2xi1>)
-// CHECK-DAG:       %[[PTR1:.*]] = llvm.getelementptr inbounds %[[BUF:.*]][16]
-// CHECK-DAG:       %[[V1_EXT:.*]] = arith.extui %[[V1]]
-// CHECK:           llvm.store %[[V1_EXT]], %[[PTR1]]
-// CHECK-DAG:       %[[PTR2:.*]] = llvm.getelementptr inbounds %[[BUF]][22]
-// CHECK-DAG:       %[[V2_EXT:.*]] = arith.extui %[[V2]]
-// CHECK:           llvm.store %[[V2_EXT]], %[[PTR2]]
-
-// -----
-
-module {
-  func.func @transfer_read_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}) -> vector<2xi1> {
-    %c16 = arith.constant 16 : index
-    %false = arith.constant false
-    %out = vector.transfer_read %arg0[%c16], %false : tensor<43xi1>, vector<2xi1>
-    func.return %out : vector<2xi1>
-  }
-}
-
-// CHECK-LABEL: @transfer_read_i1
-// CHECK-DAG:       %[[C0:.*]] = arith.constant dense<0> : vector<2xi8>
-// CHECK-DAG:       %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[16]
-// CHECK:           %[[LOADED:.*]] = llvm.load %[[PTR]] : !llvm.ptr
-// CHECK:           %[[CAST:.*]] = arith.cmpi ne, %[[LOADED]], %[[C0]]
-// CHECK:           return %[[CAST]] : vector<2xi1>
-
-// -----
-
-module {
-  func.func @transfer_write_i4(%arg0: tensor<43xi4> {xla.slice_index = 1},
-                               %v1: vector<4xi4>) -> tensor<43xi4> {
-    %c16 = arith.constant 16 : index
-    %out = vector.transfer_write %v1, %arg0[%c16] : vector<4xi4>, tensor<43xi4>
-    func.return %out : tensor<43xi4>
-  }
-}
-
-// CHECK-LABEL: @transfer_write_i4
-// CHECK-SAME:       , %[[V1:.*]]: vector<4xi4>
-// CHECK-DAG:       %[[A0:.*]] = vector.extract %[[V1]][0]
-// CHECK-DAG:       %[[A1:.*]] = vector.extract %[[V1]][1]
-// CHECK-DAG:       %[[A2:.*]] = vector.extract %[[V1]][2]
-// CHECK-DAG:       %[[A3:.*]] = vector.extract %[[V1]][3]
-// CHECK-DAG:       vector.insert %[[A0]], {{.*}}[1]
-// CHECK-DAG:       vector.insert %[[A1]], {{.*}}[0]
-// CHECK-DAG:       vector.insert %[[A2]], {{.*}}[3]
-// CHECK-DAG:       vector.insert %[[A3]], {{.*}}[2]
-
-module {
-  func.func @transfer_read_i4(%arg0: tensor<43xi4> {xla.slice_index = 1}) -> vector<4xi4> {
-    %c16 = arith.constant 16 : index
-    %c0 = arith.constant 0 : i4
-    %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xi4>, vector<4xi4>
-    func.return %out : vector<4xi4>
-  }
-}
-
-// CHECK-LABEL: @transfer_read_i4
-// CHECK:           %[[LOADED:.*]] = llvm.load
-// CHECK-DAG:       %[[A0:.*]] = vector.extract %[[LOADED]][0]
-// CHECK-DAG:       %[[A1:.*]] = vector.extract %[[LOADED]][1]
-// CHECK-DAG:       %[[A2:.*]] = vector.extract %[[LOADED]][2]
-// CHECK-DAG:       %[[A3:.*]] = vector.extract %[[LOADED]][3]
-// CHECK-DAG:       vector.insert %[[A0]], {{.*}}[1]
-// CHECK-DAG:       vector.insert %[[A1]], {{.*}}[0]
-// CHECK-DAG:       vector.insert %[[A2]], {{.*}}[3]
-// CHECK-DAG:       vector.insert %[[A3]], {{.*}}[2]
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir
deleted file mode 100644
index 645430a..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir
+++ /dev/null
@@ -1,141 +0,0 @@
-// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-to-scf | FileCheck %s
-
-module {
-  func.func @reducer(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) {
-    return %a, %b : f32, i32
-  }
-
-  func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) {
-    %ret:2 = xla_gpu.shuffle_reduce @reducer(%a, %b) to 4 : f32, i32
-    return %ret#0, %ret#1 : f32, i32
-  }
-}
-
-// CHECK: @shuffler(%[[A:.*]]: f32, %[[B:.*]]: i32)
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4
-// CHECK-DAG: %[[C32:.*]] = arith.constant 32
-// CHECK: %[[A4H:.*]], {{.*}} = gpu.shuffle down %[[A]], %[[C4]], %[[C32]]
-// CHECK: %[[B4H:.*]], {{.*}} = gpu.shuffle down %[[B]], %[[C4]], %[[C32]]
-// CHECK: %[[AB4_0:.*]], %[[AB4_1:.*]] = xla_gpu.pure_call @reducer(%[[A]], %[[B]], %[[A4H]], %[[B4H]])
-// CHECK: %[[A2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_0]], %[[C2]], %[[C32]]
-// CHECK: %[[B2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_1]], %[[C2]], %[[C32]]
-// CHECK: %[[AB2_0:.*]], %[[AB2_1:.*]] = xla_gpu.pure_call @reducer(%[[AB4_0]], %[[AB4_1]], %[[A2H]], %[[B2H]])
-// CHECK: %[[A1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_0]], %[[C1]], %[[C32]]
-// CHECK: %[[B1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_1]], %[[C1]], %[[C32]]
-// CHECK: %[[AB1_0:.*]], %[[AB1_1:.*]] = xla_gpu.pure_call @reducer(%[[AB2_0]], %[[AB2_1]], %[[A1H]], %[[B1H]])
-// CHECK: return %[[AB1_0]], %[[AB1_1]]
-
-// -----
-
-module {
-  func.func @reducer(%a: f64, %b: f64) -> f64 {
-    return %a : f64
-  }
-
-  func.func @shuffler(%a: f64) -> f64 {
-    %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : f64
-    return %ret : f64
-  }
-}
-
-// CHECK: @shuffler(%[[A:.*]]: f64
-// CHECK: gpu.shuffle down {{.*}}, %[[C1]]
-// CHECK: gpu.shuffle down {{.*}}, %[[C1]]
-
-// -----
-
-module {
-  func.func @reducer(%a: complex<f64>, %b: complex<f64>) -> complex<f64> {
-    return %a : complex<f64>
-  }
-
-  func.func @shuffler(%a: complex<f64>) -> complex<f64> {
-    %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : complex<f64>
-    return %ret : complex<f64>
-  }
-}
-
-// CHECK: @shuffler
-// CHECK-COUNT-4: gpu.shuffle down {{.*}}, %[[C1]]
-
-// -----
-
-module {
-  func.func @reducer(%a: ui64, %b: ui64) -> ui64 {
-    return %a : ui64
-  }
-
-  func.func @shuffler(%a: ui64) -> ui64 {
-    %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : ui64
-    return %ret : ui64
-  }
-}
-
-// CHECK: @shuffler
-// CHECK: unrealized_conversion_cast
-// CHECK-COUNT-2: gpu.shuffle down {{.*}}, %[[C1]]
-
-// -----
-
-module {
-  func.func @reducer(%a: i8, %b: i8) -> i8 {
-    return %a : i8
-  }
-
-  func.func @shuffler_i8(%a: i8) -> i8 {
-    %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : i8
-    return %ret : i8
-  }
-}
-
-// CHECK: @shuffler_i8(
-// CHECK-NOT: vector
-// CHECK-COUNT-1: gpu.shuffle down {{.*}}, %[[C1]]
-
-// -----
-
-module {
-  func.func @predicated_insert(
-      %v: i32, %tensor: tensor<2xi32>, %index: index,
-      %cond: i1) -> tensor<2xi32> {
-    %ret = xla_gpu.predicated_insert %v into %tensor[%index] if %cond
-      : tensor<2xi32>
-    return %ret : tensor<2xi32>
-  }
-}
-
-// CHECK: @predicated_insert
-// CHECK-SAME: %[[V:.*]]: i32, %[[TENSOR:.*]]: tensor<2xi32>,
-// CHECK-SAME: %[[INDEX:.*]]: index, %[[COND:.*]]: i1
-// CHECK-NEXT: %[[RET:.*]] = scf.if %[[COND]]
-// CHECK-NEXT:   %[[UPD:.*]] = tensor.insert %[[V]] into %[[TENSOR]][%[[INDEX]]]
-// CHECK-NEXT:   yield %[[UPD]]
-// CHECK-NEXT: else
-// CHECK-NEXT:   yield %[[TENSOR]]
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[RET]]
-
-// -----
-
-module {
-  func.func @predicated_extract(
-      %v: i32, %tensor: tensor<2xi32>, %index: index,
-      %cond: i1) -> i32 {
-    %ret = xla_gpu.predicated_extract %tensor[%index] if %cond else %v
-      : tensor<2xi32>
-    return %ret : i32
-  }
-}
-
-// CHECK: @predicated_extract
-// CHECK-SAME: %[[V:.*]]: i32, %[[TENSOR:.*]]: tensor<2xi32>,
-// CHECK-SAME: %[[INDEX:.*]]: index, %[[COND:.*]]: i1
-// CHECK-NEXT: %[[RET:.*]] = scf.if %[[COND]]
-// CHECK-NEXT:   %[[VAL:.*]] = tensor.extract  %[[TENSOR]][%[[INDEX]]]
-// CHECK-NEXT:   yield %[[VAL]]
-// CHECK-NEXT: else
-// CHECK-NEXT:   yield %[[V]]
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[RET]]
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc b/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc
deleted file mode 100644
index 8e9fb47..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc
+++ /dev/null
@@ -1,54 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/DLTI/DLTI.h"
-#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Tools/mlir-opt/MlirOptMain.h"
-#include "mlir/Transforms/Passes.h"
-#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
-
-int main(int argc, char **argv) {
-  mlir::DialectRegistry registry;
-  registry.insert<mlir::DLTIDialect, mlir::tensor::TensorDialect,
-                  mlir::func::FuncDialect, mlir::affine::AffineDialect,
-                  mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
-                  mlir::math::MathDialect, mlir::scf::SCFDialect,
-                  mlir::mhlo::MhloDialect, mlir::LLVM::LLVMDialect,
-                  mlir::gpu::GPUDialect, mlir::mhlo::MhloDialect,
-                  mlir::vector::VectorDialect, xla::gpu::XlaGpuDialect,
-                  mlir::NVVM::NVVMDialect>();
-  mlir::func::registerAllExtensions(registry);
-  mlir::registerCanonicalizerPass();
-  mlir::registerCSEPass();
-  mlir::registerInliner();
-  xla::gpu::registerGpuFusionTransformsPasses();
-
-  return mlir::failed(
-      MlirOptMain(argc, argv, "XLA MLIR Fusion Pass Driver\n", registry));
-}
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir
deleted file mode 100644
index c7f1507..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir
+++ /dev/null
@@ -1,96 +0,0 @@
-// RUN: mlir_fusions_opt %s --split-input-file | FileCheck %s
-// Verify the printed output can be parsed.
-// RUN: mlir_fusions_opt %s --split-input-file | mlir_fusions_opt --split-input-file | FileCheck %s
-// Verify the generic form can be parsed.
-// RUN: mlir_fusions_opt %s --split-input-file --mlir-print-op-generic | mlir_fusions_opt --split-input-file | FileCheck %s
-
-func.func @shared_and_sync() -> (tensor<2xf32>, tensor<2xf32>) {
-  %shared1 = xla_gpu.allocate_shared : tensor<2xf32>
-  %shared2 = xla_gpu.allocate_shared : tensor<2xf32>
-  %sync:2 = xla_gpu.sync_threads %shared1, %shared2
-    : tensor<2xf32>, tensor<2xf32>
-  return %sync#0, %sync#1 : tensor<2xf32>, tensor<2xf32>
-}
-// CHECK-LABEL: @shared_and_sync
-// CHECK-NEXT: allocate_shared
-// CHECK-NEXT: allocate_shared
-// CHECK-NEXT: sync_threads
-// CHECK-NEXT: return
-
-// -----
-
-func.func @atomic_rmw(%in: tensor<2x3xf32>, %i: index, %j: index)
-    -> (tensor<2x3xf32>) {
-  %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> {
-    ^bb0(%current : f32):
-      %c42 = arith.constant 42.0 : f32
-      %add = arith.addf %current, %c42 : f32
-      xla_gpu.yield %add : f32
-  }
-  return %ret : tensor<2x3xf32>
-}
-// CHECK-LABEL: @atomic_rmw
-// CHECK: xla_gpu.atomic_rmw
-
-// -----
-
-func.func private @add(%a: f32, %b: f32) -> f32 {
-  %ret = arith.addf %a, %b : f32
-  return %ret : f32
-}
-
-func.func @caller(%a: f32, %b: f32) -> f32 {
-  %c = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32)
-  %d = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32)
-  %ret = arith.addf %c, %d : f32
-  return %ret : f32
-}
-// CHECK-LABEL: @caller
-// CHECK: %[[C:.*]] = xla_gpu.pure_call @add
-// CHECK: %[[D:.*]] = xla_gpu.pure_call @add
-// CHECK: arith.addf %[[C]], %[[D]]
-
-// CHECK-CSE: @caller
-// CHECK-CSE: %[[C:.*]] = xla_gpu.pure_call @add
-// CHECK-CSE: arith.addf %[[C]], %[[C]]
-
-// -----
-
-#map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)>
-func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) {
-  %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2], %d1 in [1, 3])[%s0 in [2, 4]]
-  func.return %0#0, %0#1 : index, index
-}
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)>
-
-// CHECK-LABEL: @apply_indexing
-// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index, %[[s0:.*]]: index)
-// CHECK: xla_gpu.apply_indexing #[[$MAP0]]
-// CHECK-SAME:  (%[[d0]] in [0, 2], %[[d1]] in [1, 3])[%[[s0]] in [2, 4]]
-
-// -----
-
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) {
-  %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2], %d1 in [1, 3])
-  func.return %0#0, %0#1 : index, index
-}
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @apply_indexing_no_symbols
-// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index)
-// CHECK: xla_gpu.apply_indexing #[[$MAP0]]
-// CHECK-SAME:  (%[[d0]] in [0, 2], %[[d1]] in [1, 3])
-
-// -----
-
-#map0 = affine_map<()[s0] -> (s0, s0)>
-func.func @apply_indexing_no_dims(%s0: index) -> (index, index) {
-  %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [2, 4]]
-  func.return %0#0, %0#1 : index, index
-}
-// CHECK: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0, s0)>
-
-// CHECK-LABEL: @apply_indexing_no_dims
-// CHECK: (%[[s0:.*]]: index)
-// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]] in [2, 4]]
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir
deleted file mode 100644
index 6f903f3..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir
+++ /dev/null
@@ -1,182 +0,0 @@
-// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-optimize-loops | FileCheck %s
-
-#map = affine_map<(d0) -> (d0 floordiv 8)>
-#map1 = affine_map<(d0) -> (d0 mod 8)>
-#map2 = affine_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)>
-module {
-  func.func @fully_unroll(%arg0: tensor<4x8x4096xf32>, %arg1: tensor<4096xbf16>,
-      %arg2: tensor<4x8xf32>, %arg3: tensor<4096xbf16>,
-      %arg4: tensor<4x8x4096xbf16>, %arg5: tensor<4x8xf32>,
-      %arg6: tensor<4x8x4096xf32>) -> (tensor<4x8x4096xf32>, f32)  {
-    %cst = arith.constant 1.000000e+00 : f32
-    %cst_1 = arith.constant 1.000000e+00 : bf16
-    %c2 = arith.constant 2 : index
-    %c8 = arith.constant 8 : index
-    %c32 = arith.constant 32 : index
-    %c1 = arith.constant 1 : index
-    %c0 = arith.constant 0 : index
-    %thread_id_x = gpu.thread_id  x {xla.range = [0 : index, 255 : index]}
-    %block_id_x = gpu.block_id  x {xla.range = [0 : index, 31 : index]}
-    %0 = gpu.lane_id
-    %1 = arith.cmpi eq, %0, %c0 : index
-    %2 = arith.divui %thread_id_x, %c32 : index
-    %3 = arith.cmpi ult, %thread_id_x, %c8 : index
-    %4 = xla_gpu.apply_indexing #map(%block_id_x in [0, 31])
-    %5 = xla_gpu.apply_indexing #map1(%block_id_x in [0, 31])
-    %extracted = tensor.extract %arg2[%4, %5] : tensor<4x8xf32>
-    %6 = arith.mulf %extracted, %cst : f32
-    %7 = arith.addf %6, %cst : f32
-    %8 = math.rsqrt %7 : f32
-    %9:2 = scf.for %arg7 = %c0 to %c8 step %c1 iter_args(%arg8 = %arg6, %arg9 = %cst) -> (tensor<4x8x4096xf32>, f32) {
-      %18 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]]
-      %19 = vector.transfer_read %arg1[%18], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16>
-      %20 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]]
-      %21 = vector.transfer_read %arg3[%20], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16>
-      %22 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]]
-      %23 = vector.transfer_read %arg4[%4, %5, %22], %cst_1 {in_bounds = [true]} : tensor<4x8x4096xbf16>, vector<2xbf16>
-      %24 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]]
-      %25 = vector.transfer_read %arg0[%4, %5, %24], %cst {in_bounds = [true]} : tensor<4x8x4096xf32>, vector<2xf32>
-      %26:2 = scf.for %arg10 = %c0 to %c2 step %c1 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<4x8x4096xf32>, f32) {
-        %27 = xla_gpu.apply_indexing #map2(%arg10 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]]
-        %28 = vector.extract %25[%arg10] : f32 from vector<2xf32>
-        %29 = vector.extract %23[%arg10] : bf16 from vector<2xbf16>
-        %30 = arith.extf %29 : bf16 to f32
-        %31 = vector.extract %21[%arg10] : bf16 from vector<2xbf16>
-        %32 = arith.extf %31 : bf16 to f32
-        %33 = arith.mulf %30, %32 : f32
-        %34 = arith.mulf %33, %8 : f32
-        %35 = vector.extract %19[%arg10] : bf16 from vector<2xbf16>
-        %36 = arith.extf %35 : bf16 to f32
-        %37 = arith.addf %36, %cst : f32
-        %38 = arith.mulf %34, %37 : f32
-        %39 = arith.addf %28, %38 : f32
-        %40 = arith.mulf %39, %39 : f32
-        %41 = arith.addf %arg12, %40 : f32
-        %inserted = tensor.insert %39 into %arg11[%4, %5, %27] : tensor<4x8x4096xf32>
-        scf.yield %inserted, %41 : tensor<4x8x4096xf32>, f32
-      }
-      scf.yield %26#0, %26#1 : tensor<4x8x4096xf32>, f32
-    }
-    return %9#0, %9#1 : tensor<4x8x4096xf32>, f32
-  }
-}
-
-// CHECK-LABEL: @fully_unroll
-// CHECK-NOT: scf.for
-
-// -----
-
-module {
-  func.func @unroll_by_factor(%arg0: f32) -> f32 {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c256 = arith.constant 256 : index
-    %ret = scf.for %i = %c0 to %c256 step %c1 iter_args (%v = %arg0) -> (f32) {
-      %exp = math.exp %v : f32
-      %add = arith.addf %v, %exp : f32
-      %log = math.log %add : f32
-      scf.yield %log : f32
-    }
-    return %ret : f32
-  }
-}
-
-// CHECK-LABEL: @unroll_by_factor
-// CHECK: %[[C8:.*]] = arith.constant 8 : index
-// CHECK: scf.for {{.*}} step %[[C8]]
-
-// -----
-
-module {
-  func.func @do_not_unroll(%arg0: f32) -> f32 {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c31 = arith.constant 31 : index
-    %ret = scf.for %i = %c0 to %c31 step %c1 iter_args (%v = %arg0) -> (f32) {
-      %exp = math.exp %v : f32
-      %add = arith.addf %v, %exp : f32
-      %log = math.log %add : f32
-      scf.yield %log : f32
-    }
-    return %ret : f32
-  }
-}
-
-// CHECK-LABEL: @do_not_unroll
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: scf.for {{.*}} step %[[C1]]
-
-// -----
-
-module {
-  func.func @pipeline_extract(%arg: tensor<31xf32>) -> f32 {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c31 = arith.constant 31 : index
-    %cst = arith.constant 0.0 : f32
-    %ret = scf.for %i = %c0 to %c31 step %c1 iter_args (%iter = %cst) -> (f32) {
-      %val = tensor.extract %arg[%i] : tensor<31xf32>
-      %log = math.log %val : f32
-      %add = arith.addf %log, %iter : f32
-      scf.yield %add : f32
-    }
-    return %ret : f32
-  }
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 + 1)>
-// CHECK-LABEL: @pipeline_extract
-// CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:  %[[C30:.*]] = arith.constant 30 : index
-// CHECK:      %[[VAL0:.*]] = tensor.extract %[[ARG0:.*]][%[[C0]]]
-// CHECK:      scf.for %[[I:.*]] = %[[C0]] {{.*}} iter_args(%[[ITER:.*]] = {{.*}}, %[[VAL:.*]] = %[[VAL0]])
-// CHECK-DAG:  %[[NEXT_I_EXISTS:.*]] = arith.cmpi ult, %[[I]], %[[C30]]
-// CHECK-DAG:  %[[NEXT_I:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]
-// CHECK:      %[[NEXT_VAL:.*]] = scf.if %[[NEXT_I_EXISTS]]
-// CHECK-NEXT:   tensor.extract %[[ARG0]][%[[NEXT_I]]]
-// CHECK-NEXT:   yield
-// CHECK-NEXT: else
-// CHECK-NEXT:   yield %[[VAL]]
-// CHECK:     math.log %[[VAL]]
-// CHECK:     %[[ADD:.*]] = arith.addf
-// CHECK:     yield %[[ADD]], %[[NEXT_VAL]]
-
-// -----
-
-module {
-  func.func @pipeline_transfer(%arg: tensor<34xf32>) -> vector<2xf32> {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c17 = arith.constant 17 : index
-    %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32>
-    %cst0  = arith.constant 0.0 : f32
-    %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) {
-      %base = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 2)>(%i in [0, 15])
-      %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32>
-      %log = math.log %val : vector<2xf32>
-      %add = arith.addf %log, %iter : vector<2xf32>
-      scf.yield %add : vector<2xf32>
-    }
-    return %ret : vector<2xf32>
-  }
-}
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 2)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 + 1)>
-// CHECK-LABEL: @pipeline_transfer
-// CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:  %[[C16:.*]] = arith.constant 16 : index
-// CHECK:      %[[BASE0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[C0]]
-// CHECK:      %[[VAL0:.*]] = vector.transfer_read %[[ARG0:.*]][%[[BASE0]]]
-// CHECK:      scf.for %[[I:.*]] = %[[C0]] {{.*}} iter_args(%[[ITER:.*]] = {{.*}}, %[[VAL:.*]] = %[[VAL0]])
-// CHECK-DAG:  %[[NEXT_I_EXISTS:.*]] = arith.cmpi ult, %[[I]], %[[C16]]
-// CHECK-DAG:  %[[NEXT_I:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]]
-// CHECK-DAG:  %[[NEXT_BASE:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[NEXT_I]]
-// CHECK:      %[[NEXT_VAL:.*]] = scf.if %[[NEXT_I_EXISTS]]
-// CHECK-NEXT:   vector.transfer_read %[[ARG0]][%[[NEXT_BASE]]]
-// CHECK-NEXT:   yield
-// CHECK-NEXT: else
-// CHECK-NEXT:   yield %[[VAL]]
-// CHECK:     math.log %[[VAL]]
-// CHECK:     %[[ADD:.*]] = arith.addf
-// CHECK:     yield %[[ADD]], %[[NEXT_VAL]]
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir
deleted file mode 100644
index ec1a726..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir
+++ /dev/null
@@ -1,145 +0,0 @@
-// RUN: mlir_fusions_opt --allow-unregistered-dialect %s -split-input-file -xla-gpu-simplify-affine | FileCheck %s
-
-func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c4 = arith.constant 4 : index
-  %0 = gpu.thread_id  x {xla.range = [0 : index, 127 : index]}
-  %1 = gpu.block_id  x {xla.range = [0 : index, 3071 : index]}
-  scf.for %arg3 = %c0 to %c4 step %c1 {
-    %2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))>()[%1, %0, %arg3]
-    %3 = arith.index_castui %2 : index to i64
-    %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    %5 = llvm.load %4 invariant : !llvm.ptr -> f32
-    %8 = llvm.getelementptr %arg1[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    %9 = llvm.load %8 invariant : !llvm.ptr -> f32
-    %10 = arith.cmpf oge, %5, %9 : f32
-    %11 = llvm.getelementptr %arg2[%3] : (!llvm.ptr, i64) -> !llvm.ptr, i1
-    llvm.store %10, %11 : i1, !llvm.ptr
-  }
-  return
-}
-
-// CHECK-LABEL: @op_and_for_ranges
-// CHECK-DAG: %[[C512:.*]] = arith.constant 512
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4
-// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x
-// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x
-// CHECK:     scf.for %[[I:.*]] =
-// CHECK:       %[[BLOCK_OFFSET:.*]] = arith.muli %[[BID_X]], %[[C512]]
-// CHECK:       %[[THREAD_OFFSET:.*]] = arith.muli %[[TID_X]], %[[C4]]
-// CHECK:       %[[OFFSET:.*]] = arith.addi %[[BLOCK_OFFSET]], %[[THREAD_OFFSET]]
-// CHECK:       arith.addi %[[OFFSET]], %[[I]]
-
-// -----
-
-func.func @arg_ranges(%arg0: index {xla.range = [0 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index {
-  %0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1]
-  return %0 : index
-}
-
-// CHECK-LABEL: @arg_ranges
-// CHECK-NEXT:  %[[C100:.*]] = arith.constant 100
-// CHECK-NEXT:  %[[RET:.*]] = arith.divui %{{.*}}, %[[C100]]
-// CHECK-NEXT:  return %[[RET]]
-
-// -----
-
-func.func @cant_lower(%arg0: index {xla.range = [-10 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index {
-  %0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1]
-  return %0 : index
-}
-
-// CHECK-LABEL:       @cant_lower
-// CHECK:       affine.apply
-
-// -----
-
-func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c4 = arith.constant 4 : index
-  %0 = gpu.thread_id  x
-  %1 = gpu.block_id  x
-  scf.for %i = %c0 to %c4 step %c1 {
-    %2 = xla_gpu.apply_indexing affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))>
-      [%1 in [0, 3071], %0 in [0, 127], %i in [0, 3]]
-    %3 = arith.index_castui %2 : index to i64
-    %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    %5 = llvm.load %4 invariant : !llvm.ptr -> f32
-    %8 = llvm.getelementptr %arg1[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    %9 = llvm.load %8 invariant : !llvm.ptr -> f32
-    %10 = arith.cmpf oge, %5, %9 : f32
-    %11 = llvm.getelementptr %arg2[%3] : (!llvm.ptr, i64) -> !llvm.ptr, i1
-    llvm.store %10, %11 : i1, !llvm.ptr
-  }
-  return
-}
-
-// CHECK-LABEL: @op_and_for_ranges
-// CHECK-DAG: %[[C512:.*]] = arith.constant 512
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4
-// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x
-// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x
-// CHECK:     scf.for %[[I:.*]] =
-// CHECK:       %[[BLOCK_OFFSET:.*]] = arith.muli %[[BID_X]], %[[C512]]
-// CHECK:       %[[THREAD_OFFSET:.*]] = arith.muli %[[TID_X]], %[[C4]]
-// CHECK:       %[[OFFSET:.*]] = arith.addi %[[BLOCK_OFFSET]], %[[THREAD_OFFSET]]
-// CHECK:       arith.addi %[[OFFSET]], %[[I]]
-
-// -----
-
-func.func @arg_ranges(%arg0: index, %arg1: index) -> index {
-  %0 = xla_gpu.apply_indexing
-    affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>
-    [%arg0 in [0, 42], %arg1 in [0, 1000]]
-  return %0 : index
-}
-
-// CHECK-LABEL: @arg_ranges
-// CHECK-NEXT:  %[[C100:.*]] = arith.constant 100
-// CHECK-NEXT:  %[[RET:.*]] = arith.divui %{{.*}}, %[[C100]]
-// CHECK-NEXT:  return %[[RET]]
-
-// -----
-
-func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) {
-  %0:2 = xla_gpu.apply_indexing
-    affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1)>
-    [%arg0 in [-10, 42], %arg1 in [0, 1000]]
-  return %0#0, %0#1 : index, index
-}
-
-// CHECK-LABEL: @cant_lower
-// CHECK:         affine.apply
-// CHECK-NEXT:    arith.addi
-
-// -----
-
-func.func @order_summands(%arg1: index) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c4 = arith.constant 4 : index
-  scf.for %arg2 = %c0 to %c4 step %c1 {
-    scf.for %arg3 = %c0 to %c4 step %c1 {
-      %0 = xla_gpu.apply_indexing
-        affine_map<()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10)>
-        [%arg2 in [0, 3], %arg1 in [0, 3], %arg3 in [0, 3]]
-      "dummy.op"(%0) : (index) -> ()
-    }
-  }
-  return
-}
-
-// CHECK-LABEL: @order_summands
-// CHECK-SAME:    (%[[ARG1:.*]]: index)
-// CHECK: scf.for %[[ARG2:.*]] =
-// CHECK: scf.for %[[ARG3:.*]] =
-// CHECK: arith.muli %[[ARG1]]
-// CHECK: arith.muli %[[ARG2]]
-// CHECK: arith.addi
-// CHECK: arith.addi %[[ARG1]], %[[ARG2]]
-// CHECK: arith.divui
-// CHECK: arith.addi
-// CHECK: arith.muli %[[ARG3]]
-// CHECK: arith.addi %5, %6 : index
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir
deleted file mode 100644
index ee2e0dd..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir
+++ /dev/null
@@ -1,292 +0,0 @@
-// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-simplify-arith -cse -canonicalize | FileCheck %s
-
-module {
-  func.func @unknown(%arg0: index {xla.range = [0 : index, 42 : index]}) -> i1 {
-    %c12 = arith.constant 12 : index
-    %eq = arith.cmpi eq, %arg0, %c12 : index
-    return %eq : i1
-  }
-}
-
-// CHECK: @unknown
-// CHECK: cmpi
-
-// -----
-
-module {
-  func.func @true(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
-    %c5 = arith.constant 5 : index
-    %eq = arith.cmpi sge, %arg0, %c5 : index
-    return %eq : i1
-  }
-}
-
-// CHECK: @true
-// CHECK-NEXT: constant true
-// CHECK-NEXT: return
-
-// -----
-
-module {
-  func.func @false(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
-    %c5 = arith.constant 5 : index
-    %eq = arith.cmpi slt, %arg0, %c5 : index
-    return %eq : i1
-  }
-}
-
-// CHECK: @false
-// CHECK-NEXT: constant false
-// CHECK-NEXT: return
-
-// -----
-
-module {
-  func.func @rhs_range(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
-    %c42 = arith.constant 64 : index
-    %eq = arith.cmpi slt, %c42, %arg0 : index
-    return %eq : i1
-  }
-}
-
-// CHECK: @rhs_range
-// CHECK-NEXT: constant false
-// CHECK-NEXT: return
-
-// -----
-
-module {
-  func.func @both_range(%arg0: index {xla.range = [12 : index, 42 : index]},
-                        %arg1: index {xla.range = [63 : index, 100 : index]}) -> i1 {
-    %eq = arith.cmpi slt, %arg0, %arg1 : index
-    return %eq : i1
-  }
-}
-
-// CHECK-LABEL: @both_range
-// CHECK-NEXT: constant true
-// CHECK-NEXT: return
-
-// -----
-
-module {
-  func.func @minsi_lhs(%arg0: index {xla.range = [12 : index, 42 : index]},
-                       %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
-    %min = arith.minsi %arg0, %arg1 : index
-    return %min : index
-  }
-}
-
-// CHECK-LABEL: @minsi_lhs
-// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
-// CHECK-NEXT: return %[[ARG0]]
-
-// -----
-
-module {
-  func.func @minsi_rhs(%arg0: index {xla.range = [12 : index, 42 : index]},
-                       %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
-    %min = arith.minsi %arg1, %arg0 : index
-    return %min : index
-  }
-}
-
-// CHECK-LABEL: @minsi_rhs
-// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
-// CHECK-NEXT: return %[[ARG0]]
-
-// -----
-
-module {
-  func.func @maxsi_lhs(%arg0: index {xla.range = [12 : index, 42 : index]},
-                       %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
-    %min = arith.maxsi %arg1, %arg0 : index
-    return %min : index
-  }
-}
-
-// CHECK-LABEL: @maxsi_lhs
-// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
-// CHECK-NEXT: return %[[ARG1]]
-
-// -----
-
-module {
-  func.func @maxsi_rhs(%arg0: index {xla.range = [12 : index, 42 : index]},
-                       %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
-    %min = arith.maxsi %arg0, %arg1 : index
-    return %min : index
-  }
-}
-
-// CHECK-LABEL: @maxsi_rhs
-// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
-// CHECK-NEXT: return %[[ARG1]]
-
-// -----
-
-module {
-  func.func @maxsi_add(%arg0: index {xla.range = [102 : index, 142 : index]},
-                       %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
-    %add = arith.addi %arg0, %arg1 : index
-    %min = arith.maxsi %add, %arg1 : index
-    return %min : index
-  }
-}
-
-// CHECK-LABEL: @maxsi_add
-// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
-// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG0]], %[[ARG1]]
-// CHECK-NEXT: return %[[ADD]]
-
-// -----
-
-module {
-  func.func @minsi_add(%arg0: index {xla.range = [102 : index, 142 : index]},
-                       %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
-    %add = arith.addi %arg0, %arg1 : index
-    %min = arith.minsi %add, %arg1 : index
-    return %min : index
-  }
-}
-
-// CHECK-LABEL: @minsi_add
-// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
-// CHECK-NEXT: return %[[ARG1]]
-
-// -----
-
-module {
-  func.func @pred_reduce(%in: i1) -> i1 {
-    %c1_i32 = arith.constant 1 : i32
-    %c2_i32 = arith.constant 2 : i32
-    %c4_i32 = arith.constant 4 : i32
-    %c8_i32 = arith.constant 8 : i32
-    %c16_i32 = arith.constant 16 : i32
-    %c32_i32 = arith.constant 32 : i32
-    %0 = arith.extui %in : i1 to i32
-    %shuffleResult, %valid = gpu.shuffle  down %0, %c16_i32, %c32_i32 : i32
-    %1 = arith.trunci %shuffleResult : i32 to i1
-    %2 = arith.ori %in, %1 : i1
-    %3 = arith.extui %2 : i1 to i32
-    %shuffleResult_0, %valid_1 = gpu.shuffle  down %3, %c8_i32, %c32_i32 : i32
-    %4 = arith.trunci %shuffleResult_0 : i32 to i1
-    %5 = arith.ori %2, %4 : i1
-    %6 = arith.extui %5 : i1 to i32
-    %shuffleResult_2, %valid_3 = gpu.shuffle  down %6, %c4_i32, %c32_i32 : i32
-    %7 = arith.trunci %shuffleResult_2 : i32 to i1
-    %8 = arith.ori %5, %7 : i1
-    %9 = arith.extui %8 : i1 to i32
-    %shuffleResult_4, %valid_5 = gpu.shuffle  down %9, %c2_i32, %c32_i32 : i32
-    %10 = arith.trunci %shuffleResult_4 : i32 to i1
-    %11 = arith.ori %8, %10 : i1
-    %12 = arith.extui %11 : i1 to i32
-    %shuffleResult_6, %valid_7 = gpu.shuffle  down %12, %c1_i32, %c32_i32 : i32
-    %13 = arith.trunci %shuffleResult_6 : i32 to i1
-    %14 = arith.ori %11, %13 : i1
-    return %14 : i1
-  }
-}
-
-// CHECK-LABEL: @pred_reduce
-// CHECK-SAME:     (%[[IN:.*]]: i1)
-// CHECK:       %[[IN_EXT:.*]] = arith.extui %[[IN]]
-// CHECK-NEXT:  %[[SHUFFLE0:.*]], {{.*}} = gpu.shuffle down %[[IN_EXT]]
-// CHECK-NEXT:  %[[OR0:.*]] = arith.ori %[[IN_EXT]], %[[SHUFFLE0]]
-// CHECK-NEXT:  %[[SHUFFLE1:.*]], {{.*}} = gpu.shuffle down %[[OR0]]
-// CHECK-NEXT:  %[[OR1:.*]] = arith.ori %[[OR0]], %[[SHUFFLE1]]
-// CHECK-NEXT:  %[[SHUFFLE2:.*]], {{.*}} = gpu.shuffle down %[[OR1]]
-// CHECK-NEXT:  %[[OR2:.*]] = arith.ori %[[OR1]], %[[SHUFFLE2]]
-// CHECK-NEXT:  %[[SHUFFLE3:.*]], {{.*}} = gpu.shuffle down %[[OR2]]
-// CHECK-NEXT:  %[[OR3:.*]] = arith.ori %[[OR2]], %[[SHUFFLE3]]
-// CHECK-NEXT:  %[[SHUFFLE4:.*]], {{.*}} = gpu.shuffle down %[[OR3]]
-// CHECK-NEXT:  %[[OR4:.*]] = arith.ori %[[OR3]], %[[SHUFFLE4]]
-// CHECK-NEXT:  %[[RET:.*]] = arith.trunci %[[OR4]]
-// CHECK-NEXT:  return %[[RET]]
-
-// -----
-
-module {
-  func.func @andi_no_trunc_arg(%a: i4, %b: i8) -> i4 {
-    %lhs = arith.extui %a : i4 to i8
-    %add = arith.andi %lhs, %b : i8
-    %ret = arith.trunci %add : i8 to i4
-    return %ret : i4
-  }
-}
-
-// CHECK-LABEL: @andi_no_trunc_arg
-// CHECK-NEXT: extui
-// CHECK-NEXT: andi
-// CHECK-NEXT: trunci
-// CHECK-NEXT: return
-
-// -----
-
-module {
-  func.func @ori_mismatched_narrowest(%a: i8, %b: i8) -> i8 {
-    %0 = arith.trunci %a : i8 to i4
-    %1 = arith.extui %0 : i4 to i8
-    %ret = arith.ori %b, %1 : i8
-    return %ret : i8
-  }
-}
-
-// CHECK-LABEL: @ori_mismatched_narrowest
-// CHECK-NEXT: trunci
-// CHECK-NEXT: extui
-// CHECK-NEXT: ori
-// CHECK-NEXT: return
-
-// -----
-
-func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
-  %c42_f32 = arith.constant 42.0 : f32
-  %loop = scf.for %i = %c0 to %c3 step %c1
-      iter_args(%in_ = %tensor) -> (tensor<100xf32>) {
-    %0 = xla_gpu.apply_indexing affine_map<(d0) -> (d0 mod 4)> (%i in [0, 9])
-    %updated = tensor.insert %c42_f32 into %in_[%0] : tensor<100xf32>
-    scf.yield %updated : tensor<100xf32>
-  }
-  func.return %loop : tensor<100xf32>
-}
-// CHECK-LABEL: func.func @refine_constraints
-// CHECK: %[[CST:.*]] = arith.constant 4.2
-// CHECK: scf.for
-// CHECK: tensor.insert %[[CST]]
-
-
-// -----
-
-#map = affine_map<(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000)>
-#map1 = affine_map<(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9)>
-func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>,
-    %arg1: tensor<2400000x9xf32>) -> tensor<2400000x9xf32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c4 = arith.constant 4 : index
-  %c73 = arith.constant 73 : index
-  %c42_f32 = arith.constant 42.0 : f32
-  %th_x = gpu.thread_id  x {xla.range = [0 : index, 127 : index]}
-  %bl_x = gpu.block_id  x {xla.range = [0 : index, 575 : index]}
-  %0 = scf.for %i = %c0 to %c73 step %c1 iter_args(%arg3 = %arg1)
-      -> (tensor<2400000x9xf32>) {
-    %2 = scf.for %j = %c0 to %c4 step %c1 iter_args(%arg5 = %arg3)
-        -> (tensor<2400000x9xf32>) {
-      %3 = xla_gpu.apply_indexing #map(%th_x in [0, 127], %bl_x in [0, 575])
-        [%i in [0, 73], %j in [0, 3]]
-      %4 = xla_gpu.apply_indexing #map1(%th_x in [0, 127], %bl_x in [0, 575])
-        [%j in [0, 3]]
-      %inserted = tensor.insert %c42_f32 into %arg5[%3, %4]
-        : tensor<2400000x9xf32>
-      scf.yield %inserted : tensor<2400000x9xf32>
-    }
-    scf.yield %2 : tensor<2400000x9xf32>
-  }
-  return %0 : tensor<2400000x9xf32>
-}
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> ((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768)>
-// CHECK-LABEL: func.func @refine_constraints_for_symbol
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir
deleted file mode 100644
index 1141d15..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir
+++ /dev/null
@@ -1,405 +0,0 @@
-// RUN: mlir_fusions_opt -allow-unregistered-dialect %s -split-input-file -xla-gpu-vectorize-loads-stores -canonicalize -cse | FileCheck %s
-
-#map = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
-module {
-  func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c64 = arith.constant 64 : index
-    %cst = arith.constant 0.0 : f32
-    %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
-      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
-        %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]]
-        %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
-        %added = arith.addf %iter1, %extracted : f32
-        scf.yield %added : f32
-      }
-      scf.yield %inner : f32
-    }
-    return %outer : f32
-  }
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 * 2)>
-// CHECK-LABEL: @simple_read
-// CHECK-SAME:     (%[[ARG0:.*]]: tensor
-// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:   %[[C64:.*]] = arith.constant 64 : index
-// CHECK:       scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] =
-// CHECK:         %[[BASE:.*]] = xla_gpu.apply_indexing #map(%[[I]] in [0, 63])
-// CHECK-NEXT:    %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]]
-// CHECK-NEXT:    scf.for %[[J:.*]] = %[[C0]]
-// CHECK-NEXT:      vector.extract %[[V]][%[[J]]]
-// CHECK-NEXT:      addf
-
-// -----
-
-module {
-  func.func @simple_read_2d(%arg0: tensor<64x2xf32>) -> (f32) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c64 = arith.constant 64 : index
-    %cst = arith.constant 0.0 : f32
-    %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
-      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
-        %extracted = tensor.extract %arg0[%i, %j] : tensor<64x2xf32>
-        %added = arith.addf %iter1, %extracted : f32
-        scf.yield %added : f32
-      }
-      scf.yield %inner : f32
-    }
-    return %outer : f32
-  }
-}
-
-// CHECK-LABEL: @simple_read_2d
-// CHECK-SAME:     (%[[ARG0:.*]]: tensor
-// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       scf.for %[[I:.*]] = %[[C0]]
-// CHECK-NEXT:    %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[C0]]]
-// CHECK-NEXT:    scf.for %[[J:.*]] = %[[C0]]
-// CHECK-NEXT:      vector.extract %[[V]][%[[J]]]
-
-// -----
-
-#map = affine_map<(d0)[s0] -> (d0 * 2 + s0 + 1)>
-module {
-  func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c63 = arith.constant 63 : index
-    %cst = arith.constant 0.0 : f32
-    %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
-      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
-        %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]]
-        %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
-        %added = arith.addf %iter1, %extracted : f32
-        scf.yield %added : f32
-      }
-      scf.yield %inner : f32
-    }
-    return %outer : f32
-  }
-}
-
-// CHECK-LABEL: @misaligned_indexing_map
-// CHECK-NOT: vector.transfer_read
-
-// -----
-
-#map = affine_map<(d0)[s0] -> (d0 * 3 + s0)>
-module {
-  func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c63 = arith.constant 63 : index
-    %cst = arith.constant 0.0 : f32
-    %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
-      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
-        %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]]
-        %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
-        %added = arith.addf %iter1, %extracted : f32
-        scf.yield %added : f32
-      }
-      scf.yield %inner : f32
-    }
-    return %outer : f32
-  }
-}
-
-// CHECK-LABEL: @misaligned_indexing_map_2
-// CHECK-NOT: vector.transfer_read
-
-// -----
-
-module {
-  func.func @misaligned_shape(%arg0: tensor<64x3xf32>) -> (f32) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c64 = arith.constant 64 : index
-    %cst = arith.constant 0.0 : f32
-    %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
-      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
-        %extracted = tensor.extract %arg0[%i, %j] : tensor<64x3xf32>
-        %added = arith.addf %iter1, %extracted : f32
-        scf.yield %added : f32
-      }
-      scf.yield %inner : f32
-    }
-    return %outer : f32
-  }
-}
-
-// CHECK-LABEL: @misaligned_shape
-// CHECK-NOT: vector.transfer_read
-
-// -----
-
-#map = affine_map<(d0)[s0] -> (d0 + s0 * 2)>
-module {
-  func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c63 = arith.constant 63 : index
-    %cst = arith.constant 0.0 : f32
-    %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
-      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
-        %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]]
-        %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
-        %added = arith.addf %iter1, %extracted : f32
-        scf.yield %added : f32
-      }
-      scf.yield %inner : f32
-    }
-    return %outer : f32
-  }
-}
-
-// CHECK-LABEL: @wrong_stride
-// CHECK-NOT: vector.transfer_read
-
-// -----
-
-// We could vectorize this as a float vector load of double the size, but we
-// don't currently.
-module {
-  func.func @simple_read_complex(%arg0: tensor<64x2xcomplex<f32>>, %i: index) -> (complex<f32>) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %cst = complex.constant [0.0 : f32, 0.0 : f32] : complex<f32>
-    %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %cst) -> complex<f32> {
-      %extracted = tensor.extract %arg0[%i, %j] : tensor<64x2xcomplex<f32>>
-      %added = complex.add %iter, %extracted : complex<f32>
-      scf.yield %added : complex<f32>
-    }
-    return %loop : complex<f32>
-  }
-}
-
-// CHECK-LABEL: @simple_read_complex
-// CHECK-NOT: vector.transfer_read
-
-// -----
-
-// This is vectorizable, but not currently supported.
-module {
-  func.func @layout(%arg0: tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>) -> (f32) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c64 = arith.constant 64 : index
-    %cst = arith.constant 0.0 : f32
-    %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
-      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
-        %extracted = tensor.extract %arg0[%j, %i]
-          : tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>
-        %added = arith.addf %iter1, %extracted : f32
-        scf.yield %added : f32
-      }
-      scf.yield %inner : f32
-    }
-    return %outer : f32
-  }
-}
-
-// CHECK-LABEL: @layout
-// CHECK-NOT: vector.transfer_read
-
-// -----
-
-module {
-  func.func @simple_write(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c4 = arith.constant 2 : index
-    %cst = arith.constant 0.0 : f32
-    %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
-      %inserted = tensor.insert %cst into %iter[%i, %j] : tensor<16x4xf32>
-      scf.yield %inserted : tensor<16x4xf32>
-    }
-    return %loop : tensor<16x4xf32>
-  }
-}
-
-// CHECK-LABEL: @simple_write
-// CHECK-SAME:     (%[[ARG0:.*]]: tensor{{.*}}, %[[I:.*]]: index
-// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       %[[V:.*]] = scf.for
-// CHECK-NEXT:    vector.insert
-// CHECK-NEXT:    scf.yield
-// CHECK:       %[[WRITTEN:.*]] = vector.transfer_write %[[V]], %[[ARG0]][%[[I]], %[[C0]]]
-// CHECK-NEXT:  return %[[WRITTEN]]
-
-// -----
-
-module {
-  func.func @write_with_use(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c4 = arith.constant 2 : index
-    %cst = arith.constant 0.0 : f32
-    %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
-      %inserted = tensor.insert %cst into %iter[%i, %j] : tensor<16x4xf32>
-      "dummy.op1"(%inserted) : (tensor<16x4xf32>) -> ()
-      scf.yield %inserted : tensor<16x4xf32>
-    }
-    return %loop : tensor<16x4xf32>
-  }
-}
-
-// CHECK-LABEL: @write_with_use
-// CHECK-NOT:   transfer_write
-
-// -----
-
-module {
-  func.func @write_not_to_iter_arg(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c4 = arith.constant 2 : index
-    %cst = arith.constant 0.0 : f32
-    %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
-      %inserted = tensor.insert %cst into %arg0[%i, %j] : tensor<16x4xf32>
-      scf.yield %inserted : tensor<16x4xf32>
-    }
-    return %loop : tensor<16x4xf32>
-  }
-}
-
-// CHECK-LABEL: @write_not_to_iter_arg
-// CHECK-NOT:   transfer_write
-
-// -----
-
-module {
-  func.func @write_not_yielded(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c4 = arith.constant 2 : index
-    %cst = arith.constant 0.0 : f32
-    %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
-      %inserted = tensor.insert %cst into %arg0[%i, %j] : tensor<16x4xf32>
-      scf.yield %arg0 : tensor<16x4xf32>
-    }
-    return %loop : tensor<16x4xf32>
-  }
-}
-
-// CHECK-LABEL: @write_not_yielded
-// CHECK-NOT:   transfer_write
-
-// -----
-
-#map = affine_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)>
-module {
-  func.func @multiple(%arg0: tensor<32x4096xf32>, %arg1: tensor<4096xbf16>,
-        %arg2: tensor<32xf32>, %arg3: tensor<32x4096xf32>,
-        %arg4: index) -> (tensor<32x4096xf32>, f32) {
-    %cst = arith.constant 1.000000e+00 : f32
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c8 = arith.constant 8 : index
-    %extracted1 = tensor.extract %arg2[%arg4] : tensor<32xf32>
-    %0:2 = scf.for %i = %c0 to %c8 step %c1 iter_args(%iter0 = %arg3, %iter1 = %cst) -> (tensor<32x4096xf32>, f32) {
-      %1:2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter2 = %iter0, %iter3 = %iter1) -> (tensor<32x4096xf32>, f32) {
-        %2 = xla_gpu.apply_indexing #map(%j in [0, 1], %arg4 in [0, 255])[%i in [0, 7]]
-        %extracted2 = tensor.extract %arg0[%i, %2] : tensor<32x4096xf32>
-        %extracted3 = tensor.extract %arg1[%2] : tensor<4096xbf16>
-        %3 = arith.extf %extracted3 : bf16 to f32
-        %4 = arith.addf %extracted2, %3 : f32
-        %5 = arith.addf %extracted1, %4 : f32
-        %6 = arith.addf %iter3, %5 : f32
-        %inserted = tensor.insert %5 into %iter2[%i, %2] : tensor<32x4096xf32>
-        scf.yield %inserted, %6 : tensor<32x4096xf32>, f32
-      }
-      scf.yield %1#0, %1#1 : tensor<32x4096xf32>, f32
-    }
-    return %0#0, %0#1 : tensor<32x4096xf32>, f32
-  }
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0] -> (d0 * 2 + s0 * 512)>
-// CHECK-LABEL: @multiple
-// CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index)
-// CHECK:      %[[C0:.*]] = arith.constant 0 : index
-// CHECK:      scf.for %[[I:.*]] = %[[C0]]
-// CHECK:      %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]] in [0, 255])[%[[I]] in [0, 7]]
-// CHECK:      %[[READ1:.*]] = vector.transfer_read %[[ARG1]][%[[BASE]]]
-// CHECK:      %[[READ2:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[BASE]]]
-// CHECK:      %[[INNER:.*]]:2 = scf.for %[[J:.*]] = %[[C0]] {{.*}} iter_args(%[[F:.*]] = {{.*}}, %[[V:.*]] = {{.*}}) -> (f32, vector<2xf32>)
-// CHECK-DAG:  vector.extract %[[READ1]][%[[J]]]
-// CHECK-DAG:  vector.extract %[[READ2]][%[[J]]]
-// CHECK:      extf
-// CHECK-NEXT: addf
-// CHECK-NEXT: %[[TO_INSERT:.*]] = arith.addf
-// CHECK-NEXT: %[[TO_YIELD:.*]] = arith.addf
-// CHECK-NEXT: %[[V_NEXT:.*]] = vector.insert %[[TO_INSERT]], %[[V]] [%[[J]]]
-// CHECK-NEXT: scf.yield %[[TO_YIELD]], %[[V_NEXT]]
-// CHECK:      %[[WRITTEN:.*]] = vector.transfer_write %[[INNER]]#1, %{{.*}}[%[[I]], %[[BASE]]]
-// CHECK:      scf.yield %[[WRITTEN]], %[[INNER]]#0
-
-// -----
-
-#map = affine_map<(d0)[s0] -> ((d0 * 4) mod 64 + s0)>
-module {
-  func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c63 = arith.constant 63 : index
-    %cst = arith.constant 0.0 : f32
-    %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
-      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
-        %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]]
-        %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
-        %added = arith.addf %iter1, %extracted : f32
-        scf.yield %added : f32
-      }
-      scf.yield %inner : f32
-    }
-    return %outer : f32
-  }
-}
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> ((d0 mod 16) * 4)>
-// CHECK-LABEL: @remainder_with_modulo
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: scf.for %[[I:.*]] = %[[C0]]
-// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]
-// CHECK: vector.transfer_read {{.*}}[%[[BASE]]]
-
-// -----
-
-#map = affine_map<(d0)[s0] -> ((d0 * 4) mod 65 + s0)>
-module {
-  func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c63 = arith.constant 63 : index
-    %cst = arith.constant 0.0 : f32
-    %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
-      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
-        %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]]
-        %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
-        %added = arith.addf %iter1, %extracted : f32
-        scf.yield %added : f32
-      }
-      scf.yield %inner : f32
-    }
-    return %outer : f32
-  }
-}
-
-// CHECK-LABEL: @remainder_with_modulo_misaligned
-// CHECK-NOT: vector.transfer_read
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/unswitch_loops.cc b/third_party/xla/xla/service/gpu/fusions/mlir/unswitch_loops.cc
deleted file mode 100644
index 7d963f3..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/unswitch_loops.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <utility>
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_UNSWITCHLOOPSPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-class UnswitchLoopsPass
-    : public impl::UnswitchLoopsPassBase<UnswitchLoopsPass> {
- public:
-  void runOnOperation() override;
-};
-
-struct UnswitchLoop : mlir::OpRewritePattern<mlir::scf::ForOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
-    if (op.getBody()->getOperations().size() != 2) {
-      return rewriter.notifyMatchFailure(
-          op, "loop body is not a single instruction");
-    }
-    auto if_op = mlir::dyn_cast<mlir::scf::IfOp>(op.getBody()->front());
-    if (!if_op) {
-      return rewriter.notifyMatchFailure(op, "no if found inside the loop");
-    }
-    if (mlir::matchPattern(if_op.getCondition(), mlir::m_Constant())) {
-      return rewriter.notifyMatchFailure(op, "condition is a constant");
-    }
-
-    auto true_cst = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto false_cst = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
-    rewriter.setInsertionPoint(op);
-    mlir::IRMapping mapping;
-    mapping.map(if_op.getCondition(), false_cst);
-    auto false_branch_loop = op->clone(mapping);
-    auto new_if = rewriter.create<mlir::scf::IfOp>(
-        op.getLoc(), op.getResultTypes(), if_op.getCondition(), true, true);
-    rewriter.replaceAllUsesWith(op.getResults(), new_if.getResults());
-
-    auto then_builder = new_if.getThenBodyBuilder(rewriter.getListener());
-    auto then_yield =
-        then_builder.create<mlir::scf::YieldOp>(op.getLoc(), op.getResults());
-    rewriter.moveOpBefore(op, then_yield);
-    rewriter.modifyOpInPlace(if_op, [&]() { if_op->setOperand(0, true_cst); });
-
-    auto else_builder = new_if.getElseBodyBuilder(rewriter.getListener());
-    else_builder.insert(false_branch_loop);
-    else_builder.create<mlir::scf::YieldOp>(op.getLoc(),
-                                            false_branch_loop->getResults());
-
-    return mlir::success();
-  }
-};
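
The `UnswitchLoop` pattern above clones the loop so that each copy sees the invariant `if` condition as a constant. A minimal standalone sketch of the same transformation on plain C++ control flow (the `HandleTrue`/`HandleFalse` helpers are hypothetical stand-ins for the two branches, not part of the pass):

```cpp
#include <cstdio>

// Hypothetical per-element actions standing in for the two branches of the
// loop-invariant `if` that the pattern hoists.
void HandleTrue(int i) { std::printf("true branch: %d\n", i); }
void HandleFalse(int i) { std::printf("false branch: %d\n", i); }

// Before unswitching: the invariant condition is re-checked on every iteration.
void Original(bool cond, int n) {
  for (int i = 0; i < n; ++i) {
    if (cond) {
      HandleTrue(i);
    } else {
      HandleFalse(i);
    }
  }
}

// After unswitching: the condition is checked once; each branch holds a copy
// of the loop in which the condition is a known constant, so the dead branch
// folds away (here it is folded by hand).
void Unswitched(bool cond, int n) {
  if (cond) {
    for (int i = 0; i < n; ++i) HandleTrue(i);
  } else {
    for (int i = 0; i < n; ++i) HandleFalse(i);
  }
}

int main() {
  Original(/*cond=*/true, /*n=*/2);
  Unswitched(/*cond=*/true, /*n=*/2);
  return 0;
}
```
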
-
-void UnswitchLoopsPass::runOnOperation() {
-  mlir::RewritePatternSet patterns(&getContext());
-  patterns.add<UnswitchLoop>(&getContext());
-  mlir::scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
-  mlir::scf::IfOp::getCanonicalizationPatterns(patterns, &getContext());
-  if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
-                                                      std::move(patterns)))) {
-    signalPassFailure();
-  }
-}
-
-}  // namespace
-
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-CreateUnswitchLoopsPass() {
-  return std::make_unique<UnswitchLoopsPass>();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc b/third_party/xla/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc
deleted file mode 100644
index 0007984..0000000
--- a/third_party/xla/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc
+++ /dev/null
@@ -1,359 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <cstdint>
-#include <memory>
-#include <numeric>
-#include <optional>
-#include <utility>
-
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-
-namespace xla {
-namespace gpu {
-
-#define GEN_PASS_DEF_VECTORIZELOADSANDSTORESPASS
-#include "xla/service/gpu/fusions/mlir/passes.h.inc"
-
-namespace {
-
-// Tries to find the stride of a symbol or dimension in an affine expression.
-// Returns std::nullopt if the stride could not be determined.
-//
-// Note: this function only handles plain sums of dimensions, symbols, and
-// constants; any other expression involving the target (e.g. a product)
-// makes the stride unknown.
-//
-// Example: the stride of `d0` in `(d0 + d1)` is 1.
-// Example: the stride of `d0` in `d0 * 2` is unknown (nullopt).
-std::optional<int> GetStride(mlir::AffineExpr expr,
-                             mlir::AffineExpr dim_or_sym) {
-  if (auto binop = mlir::dyn_cast_or_null<mlir::AffineBinaryOpExpr>(expr)) {
-    auto lhs_stride = GetStride(binop.getLHS(), dim_or_sym);
-    auto rhs_stride = GetStride(binop.getRHS(), dim_or_sym);
-
-    if (binop.getKind() == mlir::AffineExprKind::Add) {
-      if (lhs_stride && rhs_stride) {
-        return *lhs_stride + *rhs_stride;
-      }
-      return std::nullopt;
-    }
-    // Just return 0 if the expression doesn't occur on either side.
-    if (lhs_stride == 0 && rhs_stride == 0) {
-      return 0;
-    }
-    // Otherwise, we don't know the stride.
-    return std::nullopt;
-  }
-  return expr == dim_or_sym ? 1 : 0;
-}
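
A small self-contained illustration of the recursion in `GetStride`, using a toy expression tree in place of `mlir::AffineExpr` (all names here are hypothetical); it reproduces the two examples from the comment above:

```cpp
#include <iostream>
#include <memory>
#include <optional>
#include <utility>

// Toy expression tree standing in for mlir::AffineExpr.
struct Expr {
  enum Kind { kTarget, kOther, kConst, kAdd, kMul } kind;
  std::shared_ptr<Expr> lhs, rhs;
};

std::shared_ptr<Expr> Leaf(Expr::Kind k) {
  return std::make_shared<Expr>(Expr{k, nullptr, nullptr});
}
std::shared_ptr<Expr> Bin(Expr::Kind k, std::shared_ptr<Expr> l,
                          std::shared_ptr<Expr> r) {
  return std::make_shared<Expr>(Expr{k, std::move(l), std::move(r)});
}

// Stride of the target leaf in `e`: sums recurse and add up; any other binary
// op that touches the target makes the stride unknown (nullopt).
std::optional<int> Stride(const std::shared_ptr<Expr>& e) {
  if (e->kind == Expr::kAdd || e->kind == Expr::kMul) {
    auto l = Stride(e->lhs), r = Stride(e->rhs);
    if (e->kind == Expr::kAdd) {
      if (l && r) return *l + *r;
      return std::nullopt;
    }
    if (l == 0 && r == 0) return 0;  // Target absent on both sides.
    return std::nullopt;             // e.g. target * 2: stride unknown.
  }
  return e->kind == Expr::kTarget ? 1 : 0;
}

int main() {
  auto d0 = Leaf(Expr::kTarget), d1 = Leaf(Expr::kOther), c2 = Leaf(Expr::kConst);
  std::cout << Stride(Bin(Expr::kAdd, d0, d1)).value_or(-1) << "\n";  // 1
  std::cout << Stride(Bin(Expr::kMul, d0, c2)).has_value() << "\n";   // 0 (unknown)
  return 0;
}
```
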
-
-int64_t GetAlignmentOfRemainder(mlir::AffineExpr expr,
-                                mlir::AffineExpr dim_or_sym) {
-  if (auto binop = mlir::dyn_cast_or_null<mlir::AffineBinaryOpExpr>(expr)) {
-    auto lhs_align = GetAlignmentOfRemainder(binop.getLHS(), dim_or_sym);
-    auto rhs_align = GetAlignmentOfRemainder(binop.getRHS(), dim_or_sym);
-
-    std::optional<int64_t> rhs_cst = std::nullopt;
-    if (binop.getRHS().getKind() == mlir::AffineExprKind::Constant) {
-      rhs_cst = binop.getRHS().cast<mlir::AffineConstantExpr>().getValue();
-    }
-
-    switch (binop.getKind()) {
-      case mlir::AffineExprKind::Add:
-        if (binop.getLHS() == dim_or_sym) return rhs_align;
-        if (binop.getRHS() == dim_or_sym) return lhs_align;
-        return std::gcd(lhs_align, rhs_align);
-      case mlir::AffineExprKind::Mul:
-        return lhs_align * rhs_align;
-      case mlir::AffineExprKind::FloorDiv:
-      case mlir::AffineExprKind::CeilDiv:
-        return 1;
-      case mlir::AffineExprKind::Mod:
-        // (a * c) % (b * c) = (a % b) * c.
-        return std::gcd(lhs_align, rhs_align);
-      default:
-        llvm_unreachable("expr is none of the binary expressions");
-    }
-  }
-  if (auto cst = mlir::dyn_cast<mlir::AffineConstantExpr>(expr)) {
-    return cst.getValue();
-  }
-  return 1;
-}
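
A numeric sketch of the alignment rules in `GetAlignmentOfRemainder`, using hypothetical standalone helpers instead of the AffineExpr walk: products multiply alignments, sums keep only the gcd, and a constant's alignment is its value.

```cpp
#include <cstdint>
#include <iostream>
#include <numeric>

// Mul: the alignment of a product is the product of the alignments.
int64_t AlignmentOfProduct(int64_t lhs_align, int64_t rhs_align) {
  return lhs_align * rhs_align;
}
// Add: only the common divisor of the two summands survives.
int64_t AlignmentOfSum(int64_t lhs_align, int64_t rhs_align) {
  return std::gcd(lhs_align, rhs_align);
}

int main() {
  // Alignment of d0 * 4 once the vectorized symbol is dropped: d0 (an
  // arbitrary index) has alignment 1, the constant has alignment 4.
  int64_t mul = AlignmentOfProduct(/*d0=*/1, /*c4=*/4);  // 4
  // Alignment of d0 * 4 + 8: both summands are multiples of 4.
  int64_t sum = AlignmentOfSum(mul, /*c8=*/8);  // 4
  // A 4-wide vector access starting at such an offset is therefore aligned.
  std::cout << mul << " " << sum << "\n";
  return 0;
}
```
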
-
-// Attempts to extract the vector type for the given loop. This means:
-// - checks that the lower bound is 0
-// - checks that the step is 1
-// - checks that the upper bound is 2 or 4.
-// Returns a vector type with the given upper bound and the tensor's element
-// type.
-mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type,
-                               mlir::scf::ForOp loop) {
-  // TODO(jreiffers): Support layouts.
-  if (tensor_type.getEncoding()) {
-    return nullptr;
-  }
-  if (!mlir::VectorType::isValidElementType(tensor_type.getElementType())) {
-    return nullptr;
-  }
-  if (mlir::getConstantIntValue(loop.getStep()) != 1 ||
-      mlir::getConstantIntValue(loop.getLowerBound()) != 0) {
-    return nullptr;
-  }
-  std::optional<int> vector_size =
-      mlir::getConstantIntValue(loop.getUpperBound());
-  if (vector_size != 2 && vector_size != 4) {
-    return nullptr;  // Unsupported vector size.
-  }
-  if (tensor_type.getRank() > 1 &&
-      tensor_type.getShape().back() % *vector_size) {
-    return nullptr;  // Misaligned start indices.
-  }
-  return mlir::VectorType::get({*vector_size}, tensor_type.getElementType());
-}
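
The same loop checks as `GetVectorType`, sketched on plain integers (the function name and parameters below are illustrative, not part of the pass): the loop must run from 0 with step 1 for exactly 2 or 4 iterations, and the minor tensor dimension must be divisible by that trip count.

```cpp
#include <cstdint>
#include <iostream>
#include <optional>

// Returns the vector width for a loop with the given bounds, or nullopt if
// the loop is not vectorizable under the rules above. `minor_dim` is the
// innermost tensor dimension (only relevant for rank > 1 in the real pass).
std::optional<int64_t> VectorSizeForLoop(int64_t lower, int64_t step,
                                         int64_t upper, int64_t minor_dim) {
  if (lower != 0 || step != 1) return std::nullopt;   // Must iterate 0..upper-1.
  if (upper != 2 && upper != 4) return std::nullopt;  // Only 2- or 4-wide vectors.
  if (minor_dim % upper != 0) return std::nullopt;    // Rows would be misaligned.
  return upper;  // The element count of the resulting vector type.
}

int main() {
  std::cout << VectorSizeForLoop(0, 1, 4, 4096).value_or(0) << "\n";  // 4
  std::cout << VectorSizeForLoop(0, 1, 3, 64).value_or(0) << "\n";    // 0
  return 0;
}
```
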
-
-std::optional<llvm::SmallVector<mlir::Value>> GetVectorBaseIndices(
-    mlir::ValueRange indices, mlir::scf::ForOp loop,
-    mlir::VectorType vector_type, mlir::ImplicitLocOpBuilder& b) {
-  if (indices.empty()) {
-    return std::nullopt;
-  }
-
-  // The major dimensions' indices must all be defined outside the loop.
-  for (int i = 0; i < indices.size() - 1; ++i) {
-    if (!indices[i].getParentRegion()->isProperAncestor(
-            &loop.getBodyRegion())) {
-      return std::nullopt;
-    }
-  }
-
-  mlir::Value induction_var = loop.getInductionVar();
-  if (indices.back() == induction_var) {
-    llvm::SmallVector<mlir::Value> ret = indices;
-    ret.back() = b.create<mlir::arith::ConstantIndexOp>(0);
-    return ret;
-  }
-
-  auto apply_indexing =
-      mlir::dyn_cast_or_null<ApplyIndexingOp>(indices.back().getDefiningOp());
-  if (!apply_indexing) {
-    return std::nullopt;
-  }
-
-  // We don't generate these, but they are allowed in theory.
-  if (apply_indexing->getNumResults() != 1) {
-    return std::nullopt;
-  }
-  mlir::AffineMap map = apply_indexing.getAffineMap();
-
-  int induction_var_operand_index;
-  mlir::AffineExpr induction_var_expr = nullptr;
-  for (auto [index, operand] : llvm::enumerate(apply_indexing.getOperands())) {
-    if (operand == induction_var) {
-      if (induction_var_expr) {
-        // The induction variable should be used only once.
-        return std::nullopt;
-      }
-      induction_var_operand_index = index;
-      induction_var_expr = index < map.getNumDims()
-                               ? mlir::getAffineDimExpr(index, b.getContext())
-                               : mlir::getAffineSymbolExpr(
-                                     index - map.getNumDims(), b.getContext());
-    }
-  }
-  if (!induction_var_expr) {
-    return std::nullopt;
-  }
-
-  if (GetStride(map.getResult(0), induction_var_expr) != 1) {
-    // The indexing map is not contiguous in the vectorized dimension.
-    return std::nullopt;
-  }
-
-  if (GetAlignmentOfRemainder(map.getResult(0), induction_var_expr) %
-      vector_type.getNumElements()) {
-    return std::nullopt;
-  }
-
-  auto operands = llvm::to_vector(apply_indexing.getOperands());
-  operands[induction_var_operand_index] =
-      b.create<mlir::arith::ConstantIndexOp>(0);
-
-  llvm::SmallVector<mlir::Value> ret = indices;
-  ret.back() =
-      b.create<ApplyIndexingOp>(operands, map, apply_indexing.getLowerBounds(),
-                                apply_indexing.getUpperBounds())
-          ->getResult(0);
-  return ret;
-}
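
A worked example of the base-index substitution performed by `GetVectorBaseIndices`, using the indexing expression `i * 2 + j` from the `@simple_read` test (plain arithmetic, not MLIR): the induction variable `j` has stride 1 and the remainder `i * 2` is a multiple of the vector size, so replacing `j` with 0 yields the base of a contiguous, aligned 2-element load.

```cpp
#include <cstdint>
#include <iostream>

// idx = i * 2 + j, the map from the @simple_read test; j is the induction
// variable of the inner 2-trip loop.
int64_t Index(int64_t i, int64_t j) { return i * 2 + j; }

int main() {
  const int64_t vector_size = 2;
  for (int64_t i = 0; i < 4; ++i) {
    // Stride of j is 1 and the remainder i * 2 is a multiple of vector_size,
    // so the scalar loads at Index(i, 0) and Index(i, 1) are contiguous and
    // aligned; replacing j with 0 gives the base index of one vector load.
    int64_t base = Index(i, /*j=*/0);
    std::cout << "vector load of " << vector_size << " elements at offset "
              << base << "\n";
  }
  return 0;
}
```
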
-
-bool IsConflictFree(mlir::tensor::ExtractOp op) {
-  return op.getTensor().getParentRegion()->isProperAncestor(
-      op->getParentRegion());
-}
-
-struct VectorizeLoad : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      mlir::tensor::ExtractOp op,
-      mlir::PatternRewriter& rewriter) const override {
-    auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
-    if (!loop) {
-      return rewriter.notifyMatchFailure(op, "no loop found");
-    }
-    if (!IsConflictFree(op)) {
-      return rewriter.notifyMatchFailure(op,
-                                         "source may be written in the loop");
-    }
-
-    auto vector_type = GetVectorType(op.getTensor().getType(), loop);
-    if (!vector_type) {
-      return rewriter.notifyMatchFailure(op, "not a vectorizable loop");
-    }
-
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    b.setInsertionPoint(loop);
-    auto vector_indices =
-        GetVectorBaseIndices(op.getIndices(), loop, vector_type, b);
-    if (!vector_indices) {
-      return rewriter.notifyMatchFailure(
-          op, "the instruction does not access contiguous elements");
-    }
-
-    auto loaded_vector = b.create<mlir::vector::TransferReadOp>(
-        vector_type, op.getTensor(), *vector_indices,
-        llvm::ArrayRef<bool>{true});
-    rewriter.replaceOpWithNewOp<mlir::vector::ExtractOp>(
-        op, loaded_vector, loop.getInductionVar());
-    return mlir::success();
-  }
-};
-
-// Verifies that the insertions happening in the loop can all safely be batched
-// into a single write at the end of the loop.
-bool IsConflictFree(mlir::tensor::InsertOp op) {
-  // The insertion's only use must be the yield.
-  if (!op->hasOneUse() || !mlir::isa<mlir::scf::YieldOp>(*op->user_begin())) {
-    return false;
-  }
-  // The destination must be one of the loop's block arguments, and the
-  // destination must be the argument's only use.
-  auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(op.getDest());
-  return bbarg && bbarg.hasOneUse() &&
-         bbarg.getOwner()->getParentOp() == op->getParentOp();
-}
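
A scalar analogue of what `VectorizeStore` then does with a conflict-free insertion (a sketch, not the MLIR rewrite): values produced in the loop are staged in a small register-like buffer and written back in one batch after the loop, which is only legal because nothing in the loop reads the partially written destination.

```cpp
#include <array>
#include <cstddef>
#include <cstdio>
#include <vector>

// Fills one 4-element row of `dest` starting at `row_start`: the values are
// staged in a buffer inside the loop (like vector.insert on the iter_arg) and
// written back once after the loop (like vector.transfer_write).
void WriteRow(std::vector<float>& dest, std::size_t row_start, float value) {
  std::array<float, 4> buffer{};
  for (std::size_t j = 0; j < buffer.size(); ++j) {
    buffer[j] = value;  // Per-iteration scalar "insert".
  }
  for (std::size_t j = 0; j < buffer.size(); ++j) {
    dest[row_start + j] = buffer[j];  // Single batched write-back.
  }
}

int main() {
  std::vector<float> tensor(16, 1.0f);
  WriteRow(tensor, /*row_start=*/4, /*value=*/0.0f);
  std::printf("%g %g\n", tensor[3], tensor[4]);  // 1 0
  return 0;
}
```
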
-
-struct VectorizeStore : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult matchAndRewrite(
-      mlir::tensor::InsertOp op,
-      mlir::PatternRewriter& rewriter) const override {
-    auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
-    if (!loop) {
-      return rewriter.notifyMatchFailure(op, "no loop found");
-    }
-    if (!IsConflictFree(op)) {
-      return rewriter.notifyMatchFailure(op, "write may be read back by loop");
-    }
-    auto vector_type = GetVectorType(op.getDest().getType(), loop);
-    if (!vector_type) {
-      return rewriter.notifyMatchFailure(op, "loop is not vectorizable");
-    }
-
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    b.setInsertionPoint(loop);
-    auto vector_indices =
-        GetVectorBaseIndices(op.getIndices(), loop, vector_type, b);
-    if (!vector_indices) {
-      return rewriter.notifyMatchFailure(
-          op, "the instruction does not access contiguous elements");
-    }
-
-    auto init = b.create<mlir::arith::ConstantOp>(b.getZeroAttr(vector_type))
-                    .getResult();
-
-    auto yield_fn = [&](mlir::OpBuilder& yield_b, mlir::Location yield_loc,
-                        llvm::ArrayRef<mlir::BlockArgument> bbarg) {
-      auto induction_var =
-          mlir::cast<mlir::scf::ForOp>(bbarg.front().getOwner()->getParentOp())
-              .getInductionVar();
-      auto insert_op = yield_b.create<mlir::vector::InsertOp>(
-          yield_loc, op.getScalar(), bbarg.front(), induction_var);
-      return llvm::SmallVector<mlir::Value>{insert_op.getResult()};
-    };
-    int result_index = op->use_begin()->getOperandNumber();
-    auto new_for = *loop.replaceWithAdditionalYields(
-        rewriter, init,
-        /*replaceInitOperandUsesInLoop=*/false, yield_fn);
-
-    b.setInsertionPointAfter(new_for);
-    rewriter.replaceOp(op, op.getDest());
-
-    auto filled_vector = new_for->getResults().back();
-    auto written = b.create<mlir::vector::TransferWriteOp>(
-        filled_vector, new_for.getInits()[result_index], *vector_indices,
-        llvm::ArrayRef<bool>{true});
-    new_for->getResult(result_index).replaceAllUsesWith(written.getResult());
-
-    return mlir::success();
-  }
-};
-
-class VectorizeLoadsAndStoresPass
-    : public impl::VectorizeLoadsAndStoresPassBase<
-          VectorizeLoadsAndStoresPass> {
- public:
-  void runOnOperation() override {
-    mlir::RewritePatternSet patterns(&getContext());
-    patterns.add<VectorizeLoad, VectorizeStore>(&getContext());
-    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
-                                                        std::move(patterns)))) {
-      signalPassFailure();
-    }
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-CreateVectorizeLoadsAndStoresPass() {
-  return std::make_unique<VectorizeLoadsAndStoresPass>();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc
index 774fa8e..db4a93a 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc
+++ b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc
@@ -38,7 +38,7 @@
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/model/affine_map_printer.h"
 #include "xla/tests/filecheck.h"
@@ -82,7 +82,7 @@
   TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string));
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
 
   auto fusion_emitter = GetEmitter(analysis);
 
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h
index 0006dc5..3b0c78c 100644
--- a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h
+++ b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h
@@ -72,7 +72,7 @@
     auto& module =
         modules_.emplace_back(ParseAndReturnVerifiedModule(hlo_string).value());
     auto* root = module->entry_computation()->root_instruction();
-    analyses_.push_back(AnalyzeFusion(*root, device_info_));
+    analyses_.push_back(HloFusionAnalysis::Create(*root, device_info_));
     return GetEmitter(analyses_.back());
   }
 
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.cc b/third_party/xla/xla/service/gpu/fusions/reduction.cc
deleted file mode 100644
index 77c1e91..0000000
--- a/third_party/xla/xla/service/gpu/fusions/reduction.cc
+++ /dev/null
@@ -1,1330 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/reduction.h"
-
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <optional>
-#include <string>
-#include <tuple>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/container/node_hash_map.h"
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/GlobalVariable.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Instructions.h"
-#include "llvm/IR/Type.h"
-#include "llvm/IR/Value.h"
-#include "llvm/Support/AtomicOrdering.h"
-#include "llvm/Support/Casting.h"
-#include "mlir/Support/LLVM.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout_util.h"
-#include "xla/service/buffer_assignment.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/fusions/reduction_base.h"
-#include "xla/service/gpu/fusions/thunk_util.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/ir_emitter_nested.h"
-#include "xla/service/gpu/kernel_arguments.h"
-#include "xla/service/gpu/kernel_reuse_cache.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/parallel_loop_emitter.h"
-#include "xla/service/gpu/reduction_utils.h"
-#include "xla/service/gpu/runtime/kernel_thunk.h"
-#include "xla/service/gpu/runtime/thunk.h"
-#include "xla/service/gpu/target_util.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/service/llvm_ir/kernel_support_library.h"
-#include "xla/service/llvm_ir/llvm_loop.h"
-#include "xla/service/llvm_ir/llvm_util.h"
-#include "xla/service/llvm_ir/loop_emitter.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using TypedPointer = std::pair<llvm::Value* const, llvm::Type* const>;
-
-// Fusion root -> array of indexes, one per reduction output.
-using ReductionOutputMap =
-    ConstHloInstructionMap<absl::Span<llvm_ir::IrArray const>>;
-
-using ExtraOutputGensMap = ConstHloInstructionMap<llvm_ir::ElementGenerator>;
-
-int GetNumOutputs(const Shape& shape) {
-  if (shape.IsTuple()) {
-    return shape.tuple_shapes_size();
-  }
-  return 1;
-}
-
-const Shape& OutputShape(const Shape& output_shape, int output_index) {
-  CHECK(output_index == 0 || output_shape.IsTuple());
-  return output_shape.IsTuple() ? output_shape.tuple_shapes(output_index)
-                                : output_shape;
-}
-
-llvm::Type* GetIndexType(const HloFusionInstruction& fusion,
-                         const Tiling& tiling, llvm::IRBuilder<>* builder) {
-  return GetIndexTypeForKernel(
-      &fusion, tiling.GetNumThreadsPerBlock() * tiling.GetNumBlocks(), builder);
-}
-
-llvm::Value* CastSharedToGlobal(llvm::IRBuilder<>* builder, llvm::Value* input,
-                                llvm::Type* element_type, llvm::Twine name) {
-  return builder->CreateAddrSpaceCast(
-      input,
-      llvm::PointerType::get(element_type,
-                             /*AddressSpace=*/0),
-      name);
-}
-
-class ReductionEmitter {
- public:
-  ReductionEmitter(const HloFusionAnalysis& analysis,
-                   const ReductionInfo& reduction_codegen_info,
-                   IrEmitterContext& ir_emitter_context,
-                   const HloFusionInstruction& fusion,
-                   llvm::IRBuilder<>* builder)
-      : builder_(builder),
-        elemental_emitter_(ir_emitter_context, builder_),
-        analysis_(analysis),
-        reduction_codegen_info_(reduction_codegen_info),
-        ir_emitter_context_(ir_emitter_context),
-        fusion_(fusion),
-        index_ty_(GetIndexType(fusion, reduction_codegen_info.GetTiling(),
-                               elemental_emitter_.builder())) {
-    for (auto hero : analysis.fusion_heroes()) {
-      if (hero.opcode() == HloOpcode::kReduce) {
-        for (int i = 0; i < hero.instruction().operand_count() / 2; ++i) {
-          CHECK(LayoutUtil::IsMonotonicWithDim0Major(
-              hero.instruction().operand(i)->shape().layout()))
-              << "reduction-layout-normalizer must run before code generation";
-        }
-      }
-    }
-  }
-
-  absl::StatusOr<FusionEmissionResult> EmitInitializers();
-  absl::Status EmitKernel(const LaunchDimensions& launch_dims,
-                          std::vector<llvm_ir::IrArray> inputs,
-                          std::vector<llvm_ir::IrArray> outputs);
-
- private:
-  friend class ReductionGroupEmitter;
-
-  absl::StatusOr<std::unique_ptr<Thunk>> BuildKernelThunkForFusion(
-      const LaunchDimensions& launch_dimensions,
-      absl::string_view discriminator,
-      std::function<absl::Status(std::vector<llvm_ir::IrArray>,
-                                 std::vector<llvm_ir::IrArray>)>
-          kernel_builder_fn);
-
-  absl::StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
-      const HloInstruction* fusion_root, BufferAllocation::Slice dest_slice,
-      int output_index);
-
-  absl::Status EmitIRForReduction(
-      absl::Span<const HloInstruction* const> instr_index_group,
-      FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays,
-      const Shape& input_shape);
-
-  void MaybeEmitFenceForAMDGPU();
-  void EmitSyncThreads();
-
-  int ReducedDimensionSize() const {
-    return reduction_codegen_info_.GetTiling().GetShape()[2];
-  }
-
-  llvm::IRBuilder<>* builder_;
-  GpuElementalIrEmitter elemental_emitter_;
-  const HloFusionAnalysis& analysis_;
-  const ReductionInfo& reduction_codegen_info_;
-  IrEmitterContext& ir_emitter_context_;
-  const HloFusionInstruction& fusion_;
-  llvm::Type* index_ty_;
-};
-
-class ReductionEmitter;
-
-class ReductionGroupEmitter {
- public:
-  struct ReductionCalculationState {
-    std::optional<llvm_ir::SharedMemoryTile> shared_cache;
-    llvm::Value* initial_value;
-    llvm::AllocaInst* partial_result_address;
-    llvm::AllocaInst* input_address;
-    llvm_ir::ElementGenerator input_gen;
-  };
-
-  ReductionGroupEmitter(
-      ReductionEmitter& reduction_emitter,
-      absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
-      const ReductionOutputMap& result_ir_arrays,
-      FusedIrEmitter& fused_emitter);
-
-  const ReductionCalculationState& GetCalculationStateFor(
-      const HloInstruction* instruction, int operand_idx) const {
-    const ReductionOpState& op_state = state_.at(instruction);
-    CHECK_LT(operand_idx, op_state.size());
-    return op_state[operand_idx];
-  }
-
-  void SetCalculationStateFor(
-      const ReductionCalculationState& calculation_state,
-      const HloInstruction* instruction, int operand_idx) {
-    ReductionOpState& op_state = state_[instruction];
-    CHECK_EQ(operand_idx, op_state.size());
-    op_state.push_back(calculation_state);
-  }
-
-  void EmitReductionOutputForRowReduction(
-      const TilingKernelInfo& tiling_kernel_info,
-      const HloReduceInstruction* reduction,
-      const std::vector<const HloInstruction*>& roots) const;
-
-  void EmitReductionOutputForColumnReduction(
-      const TilingKernelInfo& tiling_kernel_info,
-      const HloReduceInstruction* reduction,
-      const std::vector<const HloInstruction*>& roots) const;
-
-  void EmitFullWarpShuffleDownLoopForReduce(
-      const HloComputation* reducer,
-      absl::Span<TypedPointer const> partial_result_addresses,
-      int threads_per_block, int num_results_per_warp) const;
-
-  void WriteReductionOutput(const TilingKernelInfo& tiling_kernel_info,
-                            const HloReduceInstruction* reduction,
-                            const std::vector<const HloInstruction*>& roots,
-                            absl::Span<TypedPointer const> values) const;
-
-  llvm_ir::IrArray::Index GetOutputIndexForReduction(
-      const TilingKernelInfo& tiling_kernel_info,
-      const HloReduceInstruction* reduction, const HloInstruction* root,
-      int output_idx) const;
-
-  void GenerateElementForReducer(const HloReduceInstruction* reduction,
-                                 const llvm_ir::IrArray::Index& index) const;
-
-  absl::Status EmitExtraOutputsForReduce(
-      const Shape& reduction_operand_shape,
-      const llvm_ir::IrArray::Index& index,
-      const ExtraOutputGensMap& extra_output_gens);
-
- private:
-  ReductionEmitter& reduction_emitter_;
-  const ReductionOutputMap& result_ir_arrays_;
-
-  // One state per reduction operand.
-  using ReductionOpState = absl::InlinedVector<ReductionCalculationState, 2>;
-
-  // HloInstruction -> operand_idx -> cache
-  absl::flat_hash_map<const HloInstruction*, ReductionOpState> state_;
-};
-
-// Creates accumulator allocas, populates them with initial values, generates
-// the __shared__ caches, and returns the populated object.
-ReductionGroupEmitter::ReductionGroupEmitter(
-    ReductionEmitter& reduction_emitter,
-    absl::Span<const HloReduceInstruction* const> reduce_instr_index_group,
-    const ReductionOutputMap& result_ir_arrays, FusedIrEmitter& fused_emitter)
-    : reduction_emitter_(reduction_emitter),
-      result_ir_arrays_(result_ir_arrays) {
-  const ReductionInfo& reduction_info =
-      reduction_emitter_.reduction_codegen_info_;
-  VLOG(10) << "Emit prologue for reduction: "
-           << reduction_emitter_.fusion_.ToString();
-
-  auto* builder = reduction_emitter_.builder_;
-  for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) {
-    for (int op_result_idx = 0;
-         op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) {
-      Shape result_shape = OutputShape(reduce_hlo->shape(), op_result_idx);
-
-      llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(
-          result_shape.element_type(), builder->GetInsertBlock()->getModule());
-      llvm::AllocaInst* reduction_input_address =
-          llvm_ir::EmitAllocaAtFunctionEntry(
-              element_type, "reduction_input_address", builder);
-
-      llvm::AllocaInst* result_address = llvm_ir::EmitAllocaAtFunctionEntry(
-          element_type, "partial_reduction_result", builder);
-
-      const HloInstruction* init_value =
-          reduce_hlo->init_values()[op_result_idx];
-
-      // Initialize the partial result with the initial value of the reduction.
-      llvm::Value* init_ir_value = (*fused_emitter.GetGenerator(
-          *init_value))(llvm_ir::IrArray::Index(builder->getInt32Ty()))
-                                       .value();
-
-      builder->CreateStore(init_ir_value, result_address);
-      const Tiling& tiling = reduction_info.GetTiling();
-      auto shared_cache = [&]() -> std::optional<llvm_ir::SharedMemoryTile> {
-        auto* module = reduction_emitter.ir_emitter_context_.llvm_module();
-        if (reduction_info.IsRowReduction()) {
-          // Multi-row reductions do not use shared memory.
-          if (RowReductionGetRowsPerWarp(
-                  reduction_emitter_.ReducedDimensionSize()) > 1) {
-            return std::nullopt;
-          }
-          // Allocate one shared memory element per warp.
-          auto block_size = tiling.GetThreadsPerBlock();
-          CHECK_EQ(block_size[ReductionDimensions::kRowMinorReducedDimension] %
-                       WarpSize(),
-                   0);
-          return llvm_ir::AllocateSharedMemoryTile(
-              module, element_type,
-              {block_size[ReductionDimensions::kRowKeptDimension],
-               block_size[ReductionDimensions::kRowMinorReducedDimension] /
-                   WarpSize()},
-              "shared_cache");
-        }
-        const auto& num_threads = tiling.GetThreadsPerBlock();
-        int n = num_threads[ReductionDimensions::kColReducedDimension];
-        CHECK_EQ(n, num_threads[ReductionDimensions::kColMinorKeptDimension]);
-        // The "+1" is used to avoid bank conflicts.
-        return llvm_ir::AllocateSharedMemoryTile(module, element_type,
-                                                 {n, n + 1}, "shared_cache");
-      }();
-
-      llvm_ir::ElementGenerator input_gen =
-          *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]);
-      SetCalculationStateFor({shared_cache, init_ir_value, result_address,
-                              reduction_input_address, input_gen},
-                             reduce_hlo, op_result_idx);
-    }
-  }
-}
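The column-reduction branch above sizes its shared cache as `{n, n + 1}` specifically to dodge shared-memory bank conflicts on the transposed read. Below is a minimal CUDA sketch of that padding trick, using a hypothetical `padded_transpose` kernel and assuming a square `width x width` input with `width` divisible by 32; none of this is XLA-generated code.

```
// Illustrative only: why a 32x32 shared tile is padded to 32x33. Without the
// extra column, reading a column of the tile makes all 32 threads of a warp
// hit the same shared-memory bank.
constexpr int kTile = 32;

__global__ void padded_transpose(const float* in, float* out, int width) {
  __shared__ float tile[kTile][kTile + 1];  // "+1" staggers rows across banks.

  int x = blockIdx.x * kTile + threadIdx.x;
  int y = blockIdx.y * kTile + threadIdx.y;
  tile[threadIdx.y][threadIdx.x] = in[y * width + x];
  __syncthreads();

  // Transposed read: each warp walks down a column of `tile`, which would
  // conflict on every access without the padding column.
  int tx = blockIdx.y * kTile + threadIdx.x;
  int ty = blockIdx.x * kTile + threadIdx.y;
  out[ty * width + tx] = tile[threadIdx.x][threadIdx.y];
}
```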
-
-void ReductionEmitter::MaybeEmitFenceForAMDGPU() {
-  auto* module = builder_->GetInsertBlock()->getModule();
-  if (IsAMDGPU(module) &&
-      ir_emitter_context_.rocm_compute_capability().fence_before_barrier()) {
-    builder_->CreateFence(
-        llvm::AtomicOrdering::SequentiallyConsistent,
-        builder_->getContext().getOrInsertSyncScopeID("workgroup"));
-  }
-}
-
-void ReductionEmitter::EmitSyncThreads() {
-  MaybeEmitFenceForAMDGPU();
-  EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder_);
-}
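Roughly, the two helpers above amount to an optional workgroup-scoped fence followed by a barrier. A hedged device-code analogue (an approximation in plain CUDA, not the emitted LLVM IR; `sync_threads_sketch` is a made-up name):

```
// Approximation of EmitSyncThreads(): an optional workgroup-scoped fence,
// then a barrier. On the AMD path the fence precedes the barrier; elsewhere
// the barrier alone is emitted.
__device__ void sync_threads_sketch(bool fence_before_barrier) {
  if (fence_before_barrier) {
    __threadfence_block();  // Stands in for the seq_cst "workgroup" fence.
  }
  __syncthreads();          // Stands in for TargetIntrinsicID::kBarrierId.
}
```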
-
-// Builds a thunk that calls a new or reused kernel for a fusion operation.
-//
-// The caller must specify the same launch dimensions for fusions which have
-// the same computation.
-//
-// If a given fusion is implemented using multiple kernels, then for each
-// kernel we should provide a discriminator, such as "init" and "impl".
-//
-// The builder_fn is only invoked if the kernel couldn't be reused.
-//
-// This is the typical usage pattern of this method:
-//
-// ```
-// auto builder_fn = [](std::vector<llvm_ir::IrArray> inputs,
-//                      std::vector<llvm_ir::IrArray> outputs) { ... };
-// TF_ASSIGN_OR_RETURN(
-//   auto thunk,
-//   BuildKernelThunkForFusion(..., launch_dimensions, builder_fn));
-// AddThunkToThunkSequence(std::move(thunk))
-// ```
-absl::StatusOr<std::unique_ptr<Thunk>>
-ReductionEmitter::BuildKernelThunkForFusion(
-    const LaunchDimensions& launch_dimensions, absl::string_view discriminator,
-    std::function<absl::Status(std::vector<llvm_ir::IrArray>,
-                               std::vector<llvm_ir::IrArray>)>
-        kernel_builder_fn) {
-  const HloComputation* fused_computation =
-      fusion_.fused_instructions_computation();
-  std::string suggested_kernel_name = std::string(fusion_.name());
-
-  TF_ASSIGN_OR_RETURN(auto kernel_arguments,
-                      KernelArguments::Create(
-                          ir_emitter_context_.buffer_assignment(), &fusion_));
-
-  auto [status_or_entry, cached] =
-      ir_emitter_context_.kernel_cache().GetWithStatus(
-          fused_computation, kernel_arguments.args(), discriminator,
-          [&]() -> absl::StatusOr<KernelReuseCache::Entry> {
-            llvm::Function* kernel;
-            std::vector<llvm_ir::IrArray> input_arrays;
-            std::vector<llvm_ir::IrArray> output_arrays;
-            TF_ASSIGN_OR_RETURN(
-                std::tie(kernel, input_arrays, output_arrays),
-                BuildKernelPrototype(ir_emitter_context_, suggested_kernel_name,
-                                     kernel_arguments.args(),
-                                     fusion_.operand_count(), launch_dimensions,
-                                     builder_));
-            TF_RETURN_IF_ERROR(kernel_builder_fn(input_arrays, output_arrays));
-            // Shared memory is allocated statically.
-            return {{kernel->getName().str(), launch_dimensions,
-                     /*cluster_dim=*/std::nullopt,
-                     /*shmem_bytes=*/0}};
-          });
-  TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry);
-  if (cached) {
-    VLOG(3) << "Reuse: " << suggested_kernel_name << " -> "
-            << entry->kernel_name;
-  }
-
-  return std::make_unique<KernelThunk>(
-      &fusion_, entry->kernel_name, kernel_arguments.args(), launch_dimensions,
-      entry->cluster_dim, entry->shmem_bytes);
-}
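The cache lookup above only runs the expensive kernel-building callback on a miss and otherwise reuses a previously generated kernel. A small host-side sketch of that pattern, with a hypothetical `ReuseCache` class standing in for XLA's `KernelReuseCache` (the real key derivation and entry contents differ):

```
#include <cstdio>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

struct Entry { std::string kernel_name; };

class ReuseCache {
 public:
  // Returns the entry and whether it was already cached. The builder callback
  // is only invoked on a miss.
  std::pair<const Entry*, bool> GetOrCreate(
      const std::string& key, const std::function<Entry()>& builder) {
    auto it = cache_.find(key);
    if (it != cache_.end()) return {&it->second, true};
    auto [inserted, unused] = cache_.emplace(key, builder());
    return {&inserted->second, false};
  }

 private:
  std::unordered_map<std::string, Entry> cache_;
};

int main() {
  ReuseCache cache;
  auto build = [] { std::puts("building kernel"); return Entry{"fusion_1"}; };
  cache.GetOrCreate("computation|args|init_0", build);  // Builds once.
  auto [entry, cached] = cache.GetOrCreate("computation|args|init_0", build);
  std::printf("reused=%d name=%s\n", cached, entry->kernel_name.c_str());
}
```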
-
-absl::Status ReductionGroupEmitter::EmitExtraOutputsForReduce(
-    const Shape& reduction_operand_shape, const llvm_ir::IrArray::Index& index,
-    const ExtraOutputGensMap& extra_output_gens) {
-  if (extra_output_gens.empty()) {
-    return absl::OkStatus();
-  }
-
-  auto* builder = reduction_emitter_.builder_;
-  // Compute all extra output values before writing them. This avoids
-  // overwriting aliased input/output buffers before all reads have occurred.
-  std::vector<std::pair<const HloInstruction*, llvm::Value*>>
-      extra_output_ir_values;
-  extra_output_ir_values.reserve(extra_output_gens.size());
-
-  auto get_index = [&](const HloInstruction* instr) {
-    const Shape& s = instr->shape();
-    return ShapeUtil::EqualIgnoringElementType(reduction_operand_shape, s)
-               ? index
-               : index.SourceIndexOfBitcast(reduction_operand_shape, s,
-                                            builder);
-  };
-
-  for (const auto& [instr, generator] : extra_output_gens) {
-    TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
-                        generator(get_index(instr)));
-    extra_output_ir_values.emplace_back(instr, extra_output_ir_value);
-  }
-
-  for (const auto& [instr, generator] : extra_output_ir_values) {
-    absl::Span<llvm_ir::IrArray const> result_ir = result_ir_arrays_.at(instr);
-    CHECK_EQ(result_ir.size(), 1);
-    result_ir[0].EmitWriteArrayElement(get_index(instr), generator, builder);
-  }
-  return absl::OkStatus();
-}
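The read-everything-then-write-everything split above matters whenever an extra output aliases one of the inputs. A tiny standalone C++ illustration of the same rule, unrelated to XLA's buffers:

```
#include <array>
#include <cstdio>

// Safe even when `out` aliases `in`: all reads happen before any write.
void scale_and_shift(const std::array<int, 3>& in, std::array<int, 3>& out) {
  std::array<int, 3> computed;
  for (int i = 0; i < 3; ++i) computed[i] = in[i] * 2 + in[(i + 1) % 3];
  for (int i = 0; i < 3; ++i) out[i] = computed[i];
}

int main() {
  std::array<int, 3> buf = {1, 2, 3};
  scale_and_shift(buf, buf);  // Aliased call stays well-defined.
  std::printf("%d %d %d\n", buf[0], buf[1], buf[2]);  // Prints "4 7 7".
}
```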
-
-absl::StatusOr<std::unique_ptr<Thunk>>
-ReductionEmitter::BuildFusedInitializerThunk(const HloInstruction* fusion_root,
-                                             BufferAllocation::Slice dest_slice,
-                                             int output_index) {
-  const HloReduceInstruction* reduce =
-      DynCast<HloReduceInstruction>(fusion_root);
-  TF_RET_CHECK(reduce);
-
-  const HloInstruction* init_value = reduce->init_values()[0];
-  TF_ASSIGN_OR_RETURN(
-      std::optional<std::unique_ptr<Thunk>> constant_init_thunk,
-      BuildConstantInitializerThunk(ir_emitter_context_, fusion_root,
-                                    init_value, dest_slice));
-  if (constant_init_thunk) {
-    return *std::move(constant_init_thunk);
-  }
-
-  const Shape& dest_shape = fusion_root->shape();
-
-  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      dest_shape, ir_emitter_context_.gpu_device_info());
-  const HloComputation* fused_computation =
-      fusion_.fused_instructions_computation();
-
-  auto builder_fn = [&](std::vector<llvm_ir::IrArray> inputs,
-                        std::vector<llvm_ir::IrArray> outputs) -> absl::Status {
-    FusedIrEmitter fused_emitter(elemental_emitter_);
-    for (int i = 0; i < fused_computation->num_parameters(); i++) {
-      fused_emitter.BindGenerator(
-          *fused_computation->parameter_instruction(i),
-          [builder = builder_,
-           input = inputs[i]](llvm_ir::IrArray::Index index) {
-            return input.EmitReadArrayElement(index, builder);
-          });
-    }
-    HloInstruction* instr = fused_computation->root_instruction();
-    if (instr->opcode() == HloOpcode::kTuple) {
-      instr = instr->mutable_operand(output_index);
-    } else {
-      CHECK_EQ(0, output_index);
-    }
-    TF_RET_CHECK(instr->shape().IsArray());
-    TF_ASSIGN_OR_RETURN(auto generator,
-                        fused_emitter.GetGenerator(*instr->operand(1)));
-    TF_RETURN_IF_ERROR(ParallelLoopEmitter(generator, {outputs[output_index]},
-                                           launch_dimensions, builder_)
-                           .EmitLoop(fusion_.name()));
-    return absl::OkStatus();
-  };
-
-  return BuildKernelThunkForFusion(launch_dimensions,
-                                   /*discriminator=*/
-                                   absl::StrCat("init_", output_index),
-                                   builder_fn);
-}
-
-// Emits a shuffle-down reduction for `partial_result_address` using the
-// reduction computation `reducer`, and writes the result back into
-// `partial_result_address`.
-//
-// Multiple `partial_result_address` inputs occur when doing a variadic
-// reduction: each one receives its own output value.
-void ReductionGroupEmitter::EmitFullWarpShuffleDownLoopForReduce(
-    const HloComputation* reducer,
-    absl::Span<TypedPointer const> partial_result_addresses,
-    int threads_per_block, int num_results_per_warp) const {
-  // This only works when the block size is a multiple of 32 threads.
-  // We check this here as a mistake in the number of threads per
-  // block is very hard to detect.
-  CHECK_EQ(threads_per_block % 32, 0);
-  CHECK_EQ(WarpSize() % num_results_per_warp, 0);
-
-  auto* builder = reduction_emitter_.builder_;
-  for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) {
-    absl::InlinedVector<llvm::Value*, 2> reduction_params;
-
-    for (auto acc : partial_result_addresses) {
-      reduction_params.push_back(acc.first);
-    }
-
-    for (auto [partial_result_address, element_type] :
-         partial_result_addresses) {
-      int bit_width = llvm_ir::GetSizeInBits(element_type);
-      llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry(
-          element_type, "result_from_other_lane", builder);
-
-      reduction_params.push_back(result_from_other_lane);
-
-      // Bitcast cannot be applied to aggregate types (even packed ones), so
-      // we bitcast addresses of load/store to intN* of the same bit-width.
-      llvm::Type* shuffled_value_type = element_type->isStructTy()
-                                            ? builder->getIntNTy(bit_width)
-                                            : element_type;
-
-      llvm::Value* partial_result =
-          builder->CreateLoad(shuffled_value_type, partial_result_address,
-                              "partial_reduction_result");
-      builder->CreateStore(
-          EmitFullWarpShuffleDown(
-              partial_result, builder->getInt32(distance), builder,
-              reduction_emitter_.ir_emitter_context_.gpu_device_info()),
-          result_from_other_lane);
-    }
-
-    absl::StatusOr<std::vector<llvm::Value*>> returned_scalars =
-        CallNestedComputationWithScalarAddrs(
-            builder, reduction_emitter_.ir_emitter_context_, *reducer,
-            reduction_params);
-    TF_CHECK_OK(returned_scalars.status());
-
-    for (int i = 0; i < returned_scalars->size(); i++) {
-      builder->CreateStore(/*Val=*/returned_scalars->at(i),
-                           /*Ptr=*/partial_result_addresses[i].first);
-    }
-  }
-}
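The loop above emits a classic warp shuffle-down (butterfly) reduction, starting at distance `16 / num_results_per_warp` so that a warp holding several independent rows never mixes them. A hedged CUDA sketch of the same pattern using plain intrinsics instead of the generated IR; `warp_reduce_sum` and `reduce_rows` are hypothetical names:

```
// kRowsPerWarp mirrors num_results_per_warp: with several rows packed into
// one warp, the butterfly starts at a smaller distance so rows stay separate.
template <int kRowsPerWarp>
__device__ float warp_reduce_sum(float value) {
  // distance = 16 / kRowsPerWarp, then 8, 4, ..., 1 — same loop shape as the
  // emitter above.
  for (int distance = 16 / kRowsPerWarp; distance >= 1; distance /= 2) {
    value += __shfl_down_sync(0xffffffffu, value, distance);
  }
  return value;  // Lane 0 of each row segment holds that row's sum.
}

__global__ void reduce_rows(const float* in, float* out) {
  float sum = warp_reduce_sum<1>(in[threadIdx.x]);
  if (threadIdx.x % 32 == 0) out[threadIdx.x / 32] = sum;  // One write per warp.
}
```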
-
-llvm_ir::IrArray::Index ReductionGroupEmitter::GetOutputIndexForReduction(
-    const TilingKernelInfo& tiling_kernel_info,
-    const HloReduceInstruction* reduction, const HloInstruction* root,
-    int output_idx) const {
-  auto* builder = reduction_emitter_.builder_;
-  auto* index_ty = reduction_emitter_.index_ty_;
-
-  // 1d or 2d output index (for row/column reduction).
-  auto projected_index = [&]() -> llvm_ir::IrArray::Index {
-    const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
-    const auto& offset = tiling_kernel_info.tile_origin;
-    const auto& shape = reduction_info.GetTiling().GetXlaShape();
-    const auto& thread_ids = tiling_kernel_info.thread_id_info.thread_ids;
-    if (reduction_info.IsRowReduction()) {
-      constexpr int kDim = ReductionDimensions::kRowKeptDimension;
-      return {{builder->CreateAdd(offset[kDim], thread_ids[kDim])},
-              {shape.dimensions(kDim)},
-              index_ty};
-    }
-    auto* major_idx = offset[ReductionDimensions::kColMajorKeptDimension];
-    auto* minor_idx = builder->CreateAdd(
-        offset[ReductionDimensions::kColMinorKeptDimension],
-        thread_ids[ReductionDimensions::kColReducedDimension]);
-    return {{major_idx, minor_idx},
-            ShapeUtil::DeleteDimension(
-                ReductionDimensions::kColReducedDimension, shape),
-            index_ty};
-  }();
-
-  auto physical_shape = ShapeUtil::DeleteDimensions(
-      reduction->dimensions(), reduction->operand(output_idx)->shape());
-  auto physical_index =
-      projected_index.SourceIndexOfBitcast(physical_shape, builder);
-  return llvm_ir::IrArray::Index(physical_index.multidim(),
-                                 OutputShape(reduction->shape(), output_idx),
-                                 index_ty)
-      .SourceIndexOfBitcast(OutputShape(root->shape(), output_idx), builder);
-}
-
-void ReductionGroupEmitter::WriteReductionOutput(
-    const TilingKernelInfo& tiling_kernel_info,
-    const HloReduceInstruction* reduction,
-    const std::vector<const HloInstruction*>& roots,
-    const absl::Span<TypedPointer const> values) const {
-  auto* builder = reduction_emitter_.builder_;
-  const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
-  const HloComputation* reducer = reduction->to_apply();
-  for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) {
-    auto [output_ptr, type] = typed_ptr;
-    for (auto root : roots) {
-      llvm_ir::IrArray::Index output_index =
-          GetOutputIndexForReduction(tiling_kernel_info, reduction, root, oidx);
-
-      llvm::Value* output_address =
-          result_ir_arrays_.at(root)[oidx].EmitArrayElementAddress(
-              output_index, builder, "output_element_address");
-      if (reduction_info.IsRaceFree()) {
-        FusedIrEmitter fused_emitter(reduction_emitter_.elemental_emitter_);
-        llvm::Value* loaded = builder->CreateLoad(type, output_ptr, "output");
-        fused_emitter.BindGenerator(
-            *reduction,
-            [&](const llvm_ir::IrArray::Index& index) { return loaded; });
-        llvm_ir::ElementGenerator gen = *fused_emitter.GetGenerator(*root);
-        llvm::Value* generated = *gen(output_index);
-        builder->CreateStore(generated, output_address);
-      } else {
-        CHECK_EQ(values.size(), 1);
-        CHECK_EQ(roots.size(), 1);
-        CHECK_EQ(reduction, root)
-            << "output fusion is not allowed for racing reductions";
-        TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
-            builder, reduction_emitter_.ir_emitter_context_, *reducer,
-            output_address, output_ptr, type));
-      }
-    }
-  }
-}
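WriteReductionOutput above picks between a plain store (race-free case) and an atomic combine emitted via the nested reducer computation. A simplified CUDA sketch of that choice; `atomicAdd` stands in for the general atomic nested-computation path and only covers addition:

```
// Race-free: one block owns each output element, so a direct store suffices.
// Otherwise, racing blocks must combine their partial results atomically.
__global__ void write_partial(float* out, float partial, bool race_free) {
  if (threadIdx.x != 0) return;  // One writer per block in this sketch.
  if (race_free) {
    out[blockIdx.x] = partial;    // Plain store.
  } else {
    atomicAdd(&out[0], partial);  // Atomic combine across blocks.
  }
}
```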
-
-void ReductionGroupEmitter::EmitReductionOutputForRowReduction(
-    const TilingKernelInfo& tiling_kernel_info,
-    const HloReduceInstruction* reduction,
-    const std::vector<const HloInstruction*>& roots) const {
-  const HloComputation* reducer = reduction->to_apply();
-  const auto& thread_id_info = tiling_kernel_info.thread_id_info;
-  const auto& thread_ids = thread_id_info.thread_ids;
-  auto* thread_id_x =
-      thread_ids[ReductionDimensions::kRowMinorReducedDimension];
-  auto constant = [&](uint64_t c) -> llvm::Constant* {
-    return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c);
-  };
-
-  auto* builder = reduction_emitter_.builder_;
-  auto is_zero = [&](llvm::Value* value) {
-    return builder->CreateICmpEQ(value, constant(0));
-  };
-
-  int num_outputs = reducer->num_parameters() / 2;
-  absl::InlinedVector<TypedPointer, 2> current_outputs;
-  for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
-    const auto& state = GetCalculationStateFor(reduction, output_idx);
-    current_outputs.push_back(
-        {state.partial_result_address,
-         state.partial_result_address->getAllocatedType()});
-  }
-
-  const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
-  const Tiling& tiling = reduction_info.GetTiling();
-  int num_rows_per_warp =
-      RowReductionGetRowsPerWarp(reduction_emitter_.ReducedDimensionSize());
-  EmitFullWarpShuffleDownLoopForReduce(reducer, absl::MakeSpan(current_outputs),
-                                       tiling.GetNumThreadsPerBlock(),
-                                       num_rows_per_warp);
-
-  KernelSupportLibrary ksl(builder);
-  llvm::Value* warp_id = builder->CreateUDiv(thread_id_x, constant(WarpSize()));
-
-  auto emit_write_output = [&](llvm::Value* write_condition,
-                               const absl::Span<TypedPointer const> values) {
-    ksl.If("reduction_write_output", write_condition, [&] {
-      WriteReductionOutput(tiling_kernel_info, reduction, roots, values);
-    });
-  };
-
-  // The major kept dimension and vector dimension are not tiled, so they're
-  // always in bounds.
-  llvm::Value* is_in_bounds_y = builder->CreateICmpULT(
-      thread_ids[ReductionDimensions::kRowKeptDimension],
-      tiling_kernel_info
-          .output_tile_bounds[ReductionDimensions::kRowKeptDimension]);
-
-  ksl.If("thread_in_bounds", is_in_bounds_y, [&] {
-    if (num_rows_per_warp > 1) {
-      llvm::Value* is_writing_thread = is_zero(builder->CreateAnd(
-          thread_id_x,
-          constant(reduction_emitter_.ReducedDimensionSize() - 1)));
-      emit_write_output(is_writing_thread, current_outputs);
-      return;
-    }
-
-    ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
-      for (int oidx = 0; oidx < num_outputs; oidx++) {
-        auto& state = GetCalculationStateFor(reduction, oidx);
-        state.shared_cache->Store(
-            builder->CreateLoad(current_outputs[oidx].second,
-                                current_outputs[oidx].first),
-            {thread_id_info.thread_ids[ReductionDimensions::kRowKeptDimension],
-             warp_id},
-            builder);
-      }
-    });
-
-    // TODO(cheshire): Don't we want to sync once for the whole output,
-    // rather than once per reduction?
-    reduction_emitter_.EmitSyncThreads();
-    ksl.If("inter_warp_reduce", is_zero(warp_id), [&] {
-      absl::InlinedVector<TypedPointer, 2> selected_values;
-      for (int oidx = 0; oidx < num_outputs; oidx++) {
-        auto& state = GetCalculationStateFor(reduction, oidx);
-        llvm::Value* block_accum_addr = state.shared_cache->Address(
-            {thread_id_info.thread_ids[ReductionDimensions::kRowKeptDimension],
-             thread_id_info.lane_id},
-            builder);
-
-        llvm::Type* element_type =
-            state.partial_result_address->getAllocatedType();
-
-        // Ensure the initial value address is in the generic address space,
-        // not in scratch.
-        llvm::Value* initial_value_addr =
-            CastSharedToGlobal(builder,
-                               llvm_ir::EmitAllocaAtFunctionEntry(
-                                   element_type, "initial_value_addr", builder),
-                               element_type, /*name=*/"");
-        builder->CreateStore(state.initial_value, initial_value_addr);
-
-        llvm::Value* warp_exists = builder->CreateICmpULT(
-            thread_id_x,
-            constant(tiling.GetThreadsPerBlock()
-                         [ReductionDimensions::kRowMinorReducedDimension] /
-                     WarpSize()));
-
-        llvm::Value* selected_value = builder->CreateSelect(
-            warp_exists, block_accum_addr, initial_value_addr);
-
-        selected_values.push_back({selected_value, element_type});
-      }
-
-      // If only one warp produces the output element, we don't need to emit
-      // an inter-warp reduce. In our tiling, DimX is the minor reduced
-      // dimension. The major reduced dimension is always emitted as a loop.
-      // TODO(b/241414088) If only one warp is present, then inter-warp
-      // communication using shared memory and synchronization using a barrier
-      // are also unnecessary and should be removed.
-      if (tiling.GetThreadsPerBlock()
-              [ReductionDimensions::kRowMinorReducedDimension] > WarpSize()) {
-        EmitFullWarpShuffleDownLoopForReduce(
-            reducer, absl::MakeSpan(selected_values),
-            tiling.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1);
-      }
-
-      emit_write_output(is_zero(thread_id_x), selected_values);
-    });
-  });
-}
-
-// Same arguments as EmitReductionOutputForRowReduction.
-void ReductionGroupEmitter::EmitReductionOutputForColumnReduction(
-    const TilingKernelInfo& tiling_kernel_info,
-    const HloReduceInstruction* reduction,
-    const std::vector<const HloInstruction*>& roots) const {
-  auto* builder = reduction_emitter_.builder_;
-  KernelSupportLibrary ksl(builder);
-  const HloComputation* reducer = reduction->to_apply();
-  const auto& thread_id_info = tiling_kernel_info.thread_id_info;
-  const auto& thread_ids = thread_id_info.thread_ids;
-
-  auto constant = [&](uint64_t c) -> llvm::Constant* {
-    return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c);
-  };
-  auto is_zero = [&](llvm::Value* value) {
-    return builder->CreateICmpEQ(value, constant(0));
-  };
-  const auto& reduction_info = reduction_emitter_.reduction_codegen_info_;
-  const Tiling& tiling = reduction_info.GetTiling();
-  int num_outputs = reducer->num_parameters() / 2;
-
-  auto* kept_index = thread_ids[ReductionDimensions::kColMinorKeptDimension];
-  auto* reduced_index = thread_ids[ReductionDimensions::kColReducedDimension];
-
-  // Store the transpose in shared memory.
-  for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
-    const auto& state = GetCalculationStateFor(reduction, output_idx);
-    auto* current_output_value =
-        builder->CreateLoad(state.partial_result_address->getAllocatedType(),
-                            state.partial_result_address);
-    state.shared_cache->Store(current_output_value, {kept_index, reduced_index},
-                              builder);
-  }
-
-  reduction_emitter_.EmitSyncThreads();
-
-  // Get transposed element from shared memory.
-  absl::InlinedVector<TypedPointer, 2> shmem_transposed_addrs;
-  for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
-    const auto& state = GetCalculationStateFor(reduction, output_idx);
-    auto* shmem_transposed_addr =
-        state.shared_cache->Address({reduced_index, kept_index}, builder);
-    shmem_transposed_addrs.push_back(
-        {shmem_transposed_addr, state.shared_cache->GetElementType()});
-  }
-
-  EmitFullWarpShuffleDownLoopForReduce(reducer,
-                                       absl::MakeSpan(shmem_transposed_addrs),
-                                       tiling.GetNumThreadsPerBlock(),
-                                       /*num_results_per_warp=*/1);
-
-  // Some warps in the block are completely outside the bounds of the
-  // tensor, so they should not write any output at all.
-  llvm::Value* has_output = builder->CreateAnd(
-      builder->CreateICmpULT(
-          reduced_index,
-          tiling_kernel_info
-              .output_tile_bounds[ReductionDimensions::kColMinorKeptDimension]),
-      builder->CreateICmpULT(
-          kept_index,
-          tiling_kernel_info
-              .output_tile_bounds[ReductionDimensions::kColReducedDimension]));
-
-  ksl.If("reduction_write_output",
-         builder->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
-           WriteReductionOutput(tiling_kernel_info, reduction, roots,
-                                shmem_transposed_addrs);
-         });
-}
-
-// Generate a single element of the tile (update the accumulator state) for a
-// given reducer.
-void ReductionGroupEmitter::GenerateElementForReducer(
-    const HloReduceInstruction* reduction,
-    const llvm_ir::IrArray::Index& index) const {
-  HloComputation* reducer = reduction->to_apply();
-  auto* builder = reduction_emitter_.builder_;
-  CHECK_EQ(reducer->num_parameters() % 2, 0);
-
-  absl::InlinedVector<llvm::Value*, 2> reduction_accumulators;
-  absl::InlinedVector<llvm::Value*, 2> reduction_input_value;
-  for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) {
-    const auto& state = GetCalculationStateFor(reduction, red_idx);
-
-    llvm::AllocaInst* input_address = state.input_address;
-    auto input_index =
-        index.SourceIndexOfBitcast(reduction->operand(0)->shape(), builder);
-    llvm::Value* const input_ir_value = *state.input_gen(input_index);
-    builder->CreateStore(input_ir_value, input_address);
-    reduction_accumulators.push_back(state.partial_result_address);
-    reduction_input_value.push_back(input_address);
-  }
-
-  absl::InlinedVector<llvm::Value*, 4> reduction_params;
-  for (llvm::Value* acc : reduction_accumulators) {
-    reduction_params.push_back(acc);
-  }
-  for (llvm::Value* value : reduction_input_value) {
-    reduction_params.push_back(value);
-  }
-
-  // Emit a call to the variadic reducer. Since it may return a tuple, we
-  // can't return the result directly as a value. Instead, before the call
-  // we create N allocas (N = number of elements in the tuple), one for each
-  // returned element; when we make the call we pass these N pointers as the
-  // last parameters, the called computation writes into them, and the
-  // returned values end up on the stack (along with pointers to them).
-  absl::StatusOr<std::vector<llvm::Value*>> returned_scalars =
-      CallNestedComputationWithScalarAddrs(
-          builder, reduction_emitter_.ir_emitter_context_, *reducer,
-          reduction_params);
-  TF_CHECK_OK(returned_scalars.status());
-
-  for (int i = 0; i < returned_scalars->size(); i++) {
-    builder->CreateStore(returned_scalars->at(i), reduction_accumulators[i]);
-  }
-}
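The calling convention described in the comment above, sketched as plain host C++ (not XLA code): the reducer receives N accumulator addresses and N input addresses, and writes its N results through out-pointers that the caller then copies back into the accumulators. The two-output `reducer` below is hypothetical:

```
#include <cstdio>

// A hypothetical two-output reducer: running (sum, max).
void reducer(const float* acc_sum, const float* acc_max,
             const float* in_sum, const float* in_max,
             float* out_sum, float* out_max) {
  *out_sum = *acc_sum + *in_sum;
  *out_max = *acc_max > *in_max ? *acc_max : *in_max;
}

int main() {
  float acc_sum = 0.0f, acc_max = -1e30f;  // Accumulator allocas.
  float in_sum = 3.0f, in_max = 3.0f;      // Current element.
  float ret_sum, ret_max;                  // Return-value allocas.
  reducer(&acc_sum, &acc_max, &in_sum, &in_max, &ret_sum, &ret_max);
  acc_sum = ret_sum;  // Store the returned scalars back into the accumulators.
  acc_max = ret_max;
  std::printf("sum=%f max=%f\n", acc_sum, acc_max);
}
```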
-
-// Emits code for reductions in the output_instructions.
-absl::Status ReductionEmitter::EmitIRForReduction(
-    absl::Span<const HloInstruction* const> instr_index_group,
-    FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays,
-    const Shape& input_shape) {
-  ExtraOutputGensMap extra_output_gens;
-  absl::flat_hash_map<const HloReduceInstruction*,
-                      std::vector<const HloInstruction*>>
-      heroes_to_roots;
-  // Keep a list of deduplicated heroes separate from heroes_to_roots to make
-  // the CodeGen deterministic.
-  std::vector<const HloReduceInstruction*> heroes;
-
-  for (const HloInstruction* hlo : instr_index_group) {
-    auto& hero = FindNonTrivialHero(*hlo);
-    if (IsRealReductionHero(*hlo, hero)) {
-      auto reduction = Cast<HloReduceInstruction>(&hero);
-      if (heroes_to_roots.find(reduction) == heroes_to_roots.end()) {
-        heroes.push_back(reduction);
-      }
-      heroes_to_roots[reduction].push_back(hlo);
-    } else {
-      extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo);
-    }
-  }
-
-  CHECK(!heroes.empty()) << "Expected at least one reduce instruction.";
-  const Tiling& tiling = reduction_codegen_info_.GetTiling();
-  CHECK_EQ(tiling.GetNumThreadsPerBlock() % WarpSize(), 0);
-  ReductionGroupEmitter group_emitter(*this, heroes, result_ir_arrays,
-                                      fused_emitter);
-
-  TF_ASSIGN_OR_RETURN(
-      TilingKernelInfo tiling_kernel_info,
-      EmitTilingKernel(
-          builder_, tiling, index_ty_,
-          [&](const TilingThreadIdInfo& thread_id_info,
-              const llvm_ir::IrArray::Index& tile_index,
-              absl::Span<llvm::Value* const> tile_dimensions) {
-            auto emit_element =
-                [&](absl::Span<llvm::Value* const> index_in_tile) {
-                  auto index = tile_index.AddOffset(index_in_tile, builder_);
-
-                  // Emit code to generate the input and perform the reduction
-                  // computation for each reduction instruction.
-                  for (const HloReduceInstruction* reduce : heroes) {
-                    group_emitter.GenerateElementForReducer(reduce, index);
-                  }
-
-                  // Emit code to generate the output for the non-reduction
-                  // instructions in the fusion, if any.
-                  TF_CHECK_OK(group_emitter.EmitExtraOutputsForReduce(
-                      ShapeUtil::MakeShape(
-                          F32, reduction_codegen_info_.GetTiling().GetShape()),
-                      index, extra_output_gens));
-                };
-            EmitTile(builder_, reduction_codegen_info_.GetTiling(),
-                     thread_id_info, tile_dimensions, emit_element);
-          }));
-
-  KernelSupportLibrary ksl(builder_);
-  for (auto reduce : heroes) {
-    if (reduction_codegen_info_.IsRowReduction()) {
-      group_emitter.EmitReductionOutputForRowReduction(
-          tiling_kernel_info, reduce, heroes_to_roots[reduce]);
-    } else {
-      group_emitter.EmitReductionOutputForColumnReduction(
-          tiling_kernel_info, reduce, heroes_to_roots[reduce]);
-    }
-  }
-
-  return absl::OkStatus();
-}
-
-absl::StatusOr<FusionEmissionResult> ReductionEmitter::EmitInitializers() {
-  FusionEmissionResult result;
-  if (reduction_codegen_info_.IsRaceFree()) {
-    return result;
-  }
-  // We need to get the dest slice by traversing the slice assigned to the
-  // fusion, because instructions inside the fusion don't have a buffer
-  // assignment.
-  //
-  // The order of the fusion roots is determined by their position in the
-  // result tuple. For example, in the following fused computation
-  //
-  // %fused_computation {
-  //   %a = ...
-  //   %b = ...
-  //   ROOT %root = tuple(%a, %b)
-  // }
-  //
-  // the fusion root with index = 0 is %a, and the fusion root %b has index 1.
-  // Therefore we can get the ordered slices by calling ForEachSubshape on the
-  // result shape.
-  std::vector<BufferAllocation::Slice> slices;
-  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
-      fusion_.shape(), [&](const Shape& subshape, ShapeIndex index) {
-        if (!ShapeUtil::IsLeafIndex(fusion_.shape(), index)) {
-          return absl::OkStatus();
-        }
-
-        TF_ASSIGN_OR_RETURN(
-            BufferAllocation::Slice slice,
-            ir_emitter_context_.buffer_assignment().GetUniqueSlice(&fusion_,
-                                                                   index));
-        slices.push_back(slice);
-        return absl::OkStatus();
-      }));
-
-  absl::Span<HloInstructionAdaptor const> fusion_roots =
-      analysis_.fusion_roots();
-  for (int i = 0; i < fusion_roots.size(); ++i) {
-    const HloInstruction* fusion_root = &fusion_roots[i].instruction();
-
-    if (IsReductionFromOrToContiguousDimensions(*fusion_root)) {
-      TF_ASSIGN_OR_RETURN(
-          result.thunks.emplace_back(),
-          BuildFusedInitializerThunk(fusion_root, slices[i], i));
-    }
-  }
-  return result;
-}
-
-absl::Status ReductionEmitter::EmitKernel(
-    const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
-    std::vector<llvm_ir::IrArray> outputs) {
-  const HloComputation* fused_computation =
-      fusion_.fused_instructions_computation();
-  FusedIrEmitter fused_emitter(elemental_emitter_);
-  for (int i = 0; i < fused_computation->num_parameters(); i++) {
-    HloInstruction* fused_operand = fused_computation->parameter_instruction(i);
-    fused_emitter.BindGenerator(
-        *fused_operand, [builder = builder_, input = inputs[i],
-                         fused_operand](const llvm_ir::IrArray::Index& index) {
-          return input.EmitReadArrayElement(index, builder,
-                                            fused_operand->name());
-        });
-  }
-
-  // Get outputs.
-  ReductionOutputMap result_ir_arrays;
-
-  int ir_arrays_idx = 0;
-  for (const HloInstructionAdaptor& root : analysis_.fusion_roots()) {
-    int get_num_results = GetNumOutputs(root.shape());
-    result_ir_arrays[&root.instruction()] =
-        absl::MakeSpan(outputs).subspan(ir_arrays_idx, get_num_results);
-    ir_arrays_idx += get_num_results;
-  }
-
-  KernelSupportLibrary ksl(builder_, llvm_ir::UnrollMode::kDefaultUnroll);
-
-  // Use the raw block_id_y to select the i-th parallel reduction to run.
-  // Using block_id_y instead of block_id_x simplifies the index calculation
-  // for reduction code generation, because block_id_y is orthogonal to
-  // the indices used within the reductions.
-  const auto& instr_index_groups =
-      reduction_codegen_info_.GetGroups().grouped_roots;
-  Shape reduce_operand_shape = reduction_codegen_info_.GetReduceOperandShape();
-
-  llvm::Value* block_id_y = gpu::EmitCallToTargetIntrinsic(
-      gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder_);
-  llvm_ir::AddRangeMetadata(0, instr_index_groups.size(),
-                            llvm::cast<llvm::Instruction>(block_id_y),
-                            builder_->GetInsertBlock()->getModule());
-  block_id_y = builder_->CreateZExtOrTrunc(block_id_y, builder_->getInt32Ty());
-  block_id_y->setName("block.id.y");
-  for (int i = 0; i < instr_index_groups.size(); ++i) {
-    TF_RETURN_IF_ERROR(ksl.IfWithStatus(
-        absl::StrCat("reduce-group-", i),
-        builder_->CreateICmpEQ(block_id_y, builder_->getInt32(i)), [&] {
-          return EmitIRForReduction(instr_index_groups[i], fused_emitter,
-                                    result_ir_arrays, reduce_operand_shape);
-        }));
-  }
-
-  return absl::OkStatus();
-}
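The `reduce-group-i` branches above gate each group's code on `block_id_y`. A hedged CUDA sketch of that dispatch shape; the per-group body is a placeholder, not a reduction:

```
// blockIdx.y selects which independent reduction group a block works on, so
// per-group indexing only ever involves blockIdx.x / threadIdx.x.
__global__ void grouped_dispatch(float* const* group_buffers, int num_groups) {
  for (int g = 0; g < num_groups; ++g) {
    if (blockIdx.y == g) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      group_buffers[g][i] += 1.0f;  // Placeholder for the group's reduction.
    }
  }
}
```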
-
-}  // namespace
-
-absl::StatusOr<FusionEmissionResult> ReductionFusion::EmitInitializers(
-    IrEmitterContext& ir_emitter_context,
-    const HloFusionInstruction& fusion) const {
-  llvm::IRBuilder<> builder(ir_emitter_context.llvm_module()->getContext());
-  return ReductionEmitter(analysis_, reduction_info_, ir_emitter_context,
-                          fusion, &builder)
-      .EmitInitializers();
-}
-
-absl::Status ReductionFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
-                                         const HloFusionInstruction& fusion,
-                                         const LaunchDimensions& launch_dims,
-                                         std::vector<llvm_ir::IrArray> inputs,
-                                         std::vector<llvm_ir::IrArray> outputs,
-                                         llvm::IRBuilder<>* builder) const {
-  return ReductionEmitter(analysis_, reduction_info_, ir_emitter_context,
-                          fusion, builder)
-      .EmitKernel(launch_dims, inputs, outputs);
-}
-
-int ReductionInfo::GetRowsPerWarp() const {
-  if (!is_row_reduction_) return 1;
-  return RowReductionGetRowsPerWarp(
-      tiling_.GetShape()[ReductionDimensions::kRowMinorReducedDimension]);
-}
-
-LaunchDimensions ReductionInfo::launch_dimensions() const {
-  size_t blocks_y = groups_.grouped_roots.size();
-  return {se::BlockDim(/*x=*/tiling_.GetNumBlocks(),
-                       /*y=*/static_cast<int64_t>(blocks_y), /*z=*/1),
-          se::ThreadDim(/*x=*/tiling_.GetNumThreadsPerBlock(),
-                        /*y=*/1, /*z=*/1)};
-}
-
-ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis) {
-  auto* hero_reduction = analysis.FindHeroReduction();
-  CHECK_NE(hero_reduction, nullptr);
-  Shape input_shape = hero_reduction->operand(0)->shape();
-  ReductionDimensions reduction_dimensions =
-      GetReductionKindAndContiguousComponents(*hero_reduction);
-  auto shape = reduction_dimensions.dimensions;
-  VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction
-           << " " << shape[0] << " " << shape[1] << " " << shape[2];
-  Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions);
-
-  int64_t num_threads_y =
-      reduction_dimensions.is_row_reduction ? 1 : WarpSize();
-  int64_t rows_per_warp =
-      reduction_dimensions.is_row_reduction
-          ? RowReductionGetRowsPerWarp(
-                shape[ReductionDimensions::kRowMinorReducedDimension])
-          : 1;
-  int64_t num_threads_x = [&] {
-    if (reduction_dimensions.is_row_reduction) {
-      if (rows_per_warp > 1) {
-        return shape[ReductionDimensions::kRowMinorReducedDimension];
-      }
-      int64_t max_block_size =
-          MinThreadsXRowReduction(hero_reduction->GetModule()->config());
-      return std::min(
-          max_block_size,
-          RoundUpTo(
-              CeilOfRatio(shape[ReductionDimensions::kRowMinorReducedDimension],
-                          reduction_tiling
-                              [ReductionDimensions::kRowMinorReducedDimension]),
-              WarpSize()));
-    }
-    return WarpSize();
-  }();
-
-  // If we're limited by the size of the x dimension, add additional parallelism
-  // in the y dimension. The code generator doesn't currently support
-  // parallelizing the z dimension (major reduced dimensions). The general
-  // recommendation is to use between 128 and 512 threads, so we just go for
-  // 256. See https://forums.developer.nvidia.com/t/55529
-  constexpr int64_t kThreadsPerBlockTarget = 256;
-  if (reduction_dimensions.is_row_reduction &&
-      num_threads_x * 2 <= kThreadsPerBlockTarget) {
-    int64_t kept_size =
-        reduction_dimensions.dimensions[ReductionDimensions::kRowKeptDimension];
-    // Increase the size of the y dimension as long as there's remaining
-    // parallelism.
-    if (kept_size * num_threads_x <= kThreadsPerBlockTarget) {
-      num_threads_y = kept_size;
-      // num_threads_x is a power of two, but it may be less than 32. If dim_y
-      // is also small, we may have to increase the bound so the total number of
-      // threads is a multiple of 32.
-      while ((num_threads_x * num_threads_y) % 32) ++num_threads_y;
-    } else {
-      num_threads_y = kThreadsPerBlockTarget / num_threads_x;
-    }
-  }
-
-  int vector_size = GetVectorSize(analysis, reduction_dimensions, num_threads_x,
-                                  reduction_tiling);
-
-  absl::InlinedVector<int64_t, 4> num_threads{1, num_threads_y, num_threads_x};
-  absl::InlinedVector<int64_t, 4> tiled_shape{shape[0], shape[1],
-                                              shape[2] / vector_size};
-  absl::InlinedVector<int64_t, 4> tile_per_thread{
-      reduction_tiling[0], reduction_tiling[1],
-      std::max<int64_t>(reduction_tiling[2] / vector_size, 1)};
-  if (rows_per_warp > 1) {
-    // If we produce more than one element per thread, that means the reduced
-    // dimension is small and it can't be tiled - we already have more threads
-    // in a warp than the size of the reduced dimension. The code generator
-    // doesn't currently support tiling the kept dimension, because it just
-    // uses the thread ID as the coordinate.
-    tile_per_thread[2] = 1;
-  }
-  if (vector_size != 1) {
-    num_threads.push_back(1);  // The vector dimension is a loop.
-    tiled_shape.push_back(vector_size);
-    tile_per_thread.push_back(vector_size);
-  }
-
-  Tiling tiling(tiled_shape, tile_per_thread, num_threads,
-                /*loops_to_unroll=*/{false, false, true, false});
-  bool reduction_is_race_free = ReductionIsRaceFree(
-      hero_reduction->GetModule()->config(), reduction_dimensions);
-  return ReductionInfo(analysis, tiling, reduction_dimensions.is_row_reduction,
-                       reduction_is_race_free,
-                       GroupDisjointReductions(analysis, /*for_mlir=*/false),
-                       hero_reduction);
-}
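For a concrete feel of the row-reduction thread-count heuristic above, here is a small host-side sketch (not XLA code) that replays the same arithmetic for an example shape, assuming `WarpSize() == 32`, a 512-thread cap from `MinThreadsXRowReduction`, and a reduced dimension large enough that the multi-row-per-warp branch does not apply:

```
#include <algorithm>
#include <cstdint>
#include <cstdio>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }
int64_t RoundUpTo(int64_t a, int64_t b) { return CeilOfRatio(a, b) * b; }

int main() {
  constexpr int64_t kWarpSize = 32, kMaxBlockSize = 512;  // Assumed values.
  constexpr int64_t kThreadsPerBlockTarget = 256;
  // Example row reduction: kept dimension 8, minor reduced dimension 4096,
  // per-thread tile of 16 along the reduced dimension.
  int64_t kept = 8, minor_reduced = 4096, tile_minor = 16;

  int64_t num_threads_x = std::min(
      kMaxBlockSize,
      RoundUpTo(CeilOfRatio(minor_reduced, tile_minor), kWarpSize));
  int64_t num_threads_y = 1;
  if (num_threads_x * 2 <= kThreadsPerBlockTarget) {
    if (kept * num_threads_x <= kThreadsPerBlockTarget) {
      num_threads_y = kept;
      // Pad y until the block size is a multiple of the warp size.
      while ((num_threads_x * num_threads_y) % 32) ++num_threads_y;
    } else {
      num_threads_y = kThreadsPerBlockTarget / num_threads_x;
    }
  }
  // Here: 4096 / 16 = 256 threads in x; y stays 1 because x alone already
  // reaches the 256-thread target.
  std::printf("threads = (%lld, %lld)\n",
              static_cast<long long>(num_threads_x),
              static_cast<long long>(num_threads_y));
}
```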
-
-std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(
-    int64_t root_index, mlir::MLIRContext* ctx) const {
-  if (!groups_.is_reduction_root[root_index]) {
-    auto map = ComposeIndexingMaps(
-        GetIndexingMapForTiling(tiling_, ctx),
-        GetBitcastMap(tiling_.GetXlaShape(),
-                      analysis_.fusion_root(root_index).shape(), ctx));
-    AddGroupIdConstraint(map, root_index, groups_);
-    return map;
-  }
-  const auto& hero = analysis_.fusion_hero(root_index).instruction();
-
-  auto block_offsets = GetBlockOffsetsForTiling(tiling_, ctx);
-  auto thread_ids = DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx),
-                                             tiling_.GetThreadsPerBlock());
-
-  auto physical_shape =
-      ShapeUtil::DeleteDimensions(hero.dimensions(), hero.operand(0)->shape());
-  std::vector<DimVar> dimension_ranges{
-      {{0, tiling_.GetNumThreadsPerBlock() - 1}},
-      {},
-      {},
-      {{0, tiling_.GetNumBlocks() - 1}},
-      {{0, static_cast<int64_t>(groups_.grouped_roots.size() - 1)}},
-      {},
-  };
-
-  constexpr int kRowKept = ReductionDimensions::kRowKeptDimension;
-  constexpr int kRowMinorReduced =
-      ReductionDimensions::kRowMinorReducedDimension;
-
-  constexpr int kColMajorKept = ReductionDimensions::kColMajorKeptDimension;
-  constexpr int kColMinorKept = ReductionDimensions::kColMinorKeptDimension;
-  constexpr int kColReduced = ReductionDimensions::kColReducedDimension;
-
-  auto map = [&]() {
-    if (is_row_reduction_) {
-      IndexingMap linear_index(
-          mlir::AffineMap::get(
-              6, 0, block_offsets.getResult(kRowKept) + thread_ids[kRowKept],
-              ctx),
-          dimension_ranges, /*range_vars=*/{}, /*rt_vars=*/{});
-      int rows_per_warp = GetRowsPerWarp();
-      if (rows_per_warp > 1) {
-        linear_index.AddConstraint(
-            thread_ids[kRowMinorReduced] % (WarpSize() / rows_per_warp),
-            {0, 0});
-      } else {
-        linear_index.AddConstraint(thread_ids[kRowMinorReduced], {0, 0});
-      }
-      return ComposeIndexingMaps(
-          linear_index, GetBitcastMap(ShapeUtil::MakeShape(
-                                          PRED, {tiling_.GetShape()[kRowKept]}),
-                                      physical_shape, ctx));
-    }
-
-    mlir::SmallVector<mlir::AffineExpr> projected_dims{
-        block_offsets.getResult(kColMajorKept),
-        block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]};
-    std::vector<RangeVar> range_vars;
-    if (thread_ids.size() == 4) {
-      int vector_size = tiling_.GetThreadTileSize().back();
-      range_vars.push_back({0, vector_size - 1});
-      projected_dims.push_back(mlir::getAffineSymbolExpr(0, ctx));
-    }
-    IndexingMap projected_index(
-        mlir::AffineMap::get(6, range_vars.size(), projected_dims, ctx),
-        dimension_ranges, range_vars, /*rt_vars=*/{});
-
-    projected_index.AddConstraint(
-        mlir::getAffineDimExpr(
-            KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx) %
-            WarpSize(),
-        {0, 0});
-    if (!is_row_reduction_) {
-      projected_index.AddConstraint(
-          projected_index.GetAffineMap().getResult(1),
-          {0, tiling_.GetShape()[ReductionDimensions::kColMinorKeptDimension] -
-                  1});
-    }
-
-    return ComposeIndexingMaps(
-        projected_index,
-        GetBitcastMap(ShapeUtil::DeleteDimension(
-                          ReductionDimensions::kColReducedDimension,
-                          tiling_.GetXlaShape()),
-                      physical_shape, ctx));
-  }();
-
-  AddGroupIdConstraint(map, root_index, groups_);
-  map.Simplify();
-  return map;
-}
-
-std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToInputIndexing(
-    int64_t root_index, int64_t hero_operand_index,
-    mlir::MLIRContext* ctx) const {
-  const auto& hero = analysis_.fusion_hero(root_index).instruction();
-  if (groups_.is_reduction_root[root_index] &&
-      hero_operand_index >= hero.operand_count() / 2) {
-    // We don't have indexing for the init values.
-    return std::nullopt;
-  }
-  if (!groups_.is_reduction_root[root_index]) {
-    return ComposeIndexingMaps(
-        *ComputeThreadIdToOutputIndexing(root_index, ctx),
-        *ComputeOutputToInputIndexing(
-             &analysis_.fusion_root(root_index).instruction(), 0, ctx)
-             .indexing_maps[hero_operand_index]
-             .begin());
-  }
-
-  auto map = ComposeIndexingMaps(
-      GetIndexingMapForTiling(tiling_, ctx),
-      GetBitcastMap(tiling_.GetXlaShape(),
-                    hero.operand(hero_operand_index)->shape(), ctx));
-  AddGroupIdConstraint(map, root_index, groups_);
-  map.Simplify();
-  return map;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.h b/third_party/xla/xla/service/gpu/fusions/reduction.h
deleted file mode 100644
index a15462f..0000000
--- a/third_party/xla/xla/service/gpu/fusions/reduction.h
+++ /dev/null
@@ -1,190 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_
-#define XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_
-
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/fusions/reduction_base.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/shape.h"
-
-namespace xla {
-namespace gpu {
-
-class ReductionInfo {
- public:
-  static ReductionInfo Create(const HloFusionAnalysis& analysis);
-
-  const Tiling& GetTiling() const { return tiling_; }
-  const ReductionGroups& GetGroups() const { return groups_; }
-  Shape GetReduceOperandShape() const {
-    return first_reduce_->operand(0)->shape();
-  }
-
-  bool IsRowReduction() const { return is_row_reduction_; }
-  bool IsRaceFree() const { return is_race_free_; }
-  int GetRowsPerWarp() const;
-
-  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
-      int64_t root_index, mlir::MLIRContext* ctx) const;
-
-  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
-      int64_t root_index, int64_t hero_operand_index,
-      mlir::MLIRContext* ctx) const;
-
-  LaunchDimensions launch_dimensions() const;
-
- private:
-  ReductionInfo(const HloFusionAnalysis& analysis, Tiling tiling,
-                bool is_row_reduction, bool is_race_free,
-                ReductionGroups groups, const HloInstruction* first_reduce)
-      : analysis_(analysis),
-        tiling_(tiling),
-        is_row_reduction_(is_row_reduction),
-        is_race_free_(is_race_free),
-        groups_(std::move(groups)),
-        first_reduce_(first_reduce) {}
-
-  const HloFusionAnalysis& analysis_;
-  Tiling tiling_;
-  bool is_row_reduction_;
-  bool is_race_free_;
-  ReductionGroups groups_;
-  const HloInstruction* first_reduce_;
-};
-
-// Generates code for reduction to contiguous dimensions.
-//
-// Row reduction uses the following algorithm described in CUDA-like
-// pseudocode:
-//
-// ```
-//  __global__ void reduce(int num_rows, float* in, float* out) {
-//    __shared__ float[32] cache;
-//    int offset = blockDim.x * blockIdx.x + threadIdx.x;
-//    if (offset >= num_rows) return;
-//    int tile_bound = std::min(offset + kTileSizeX, num_rows);
-//    float accum = 0;
-//    for (int i = offset; i < tile_bound; i += blockDim.x) {
-//      accum += in[i];
-//    }
-//    accum = warp_reduce(accum);
-//    if (threadIdx.x % WarpSize == 0) {
-//      cache[threadIdx.x / WarpSize] = accum;
-//    }
-//    __syncthreads();
-//    if (threadIdx.x / WarpSize == 0) {
-//      bool warp_exists = threadIdx.x < (blockDim.x / WarpSize);
-//      float block_accum = warp_exists ? cache[threadIdx.x % WarpSize] : 0;
-//      block_accum = warp_reduce(block_accum);
-//      if (threadIdx.x == 0) {
-//        *out += block_accum;
-//      }
-//    }
-//  }
-// ```
-//
-// Column reduction uses the following algorithm:
-//
-// ```
-// void reduce(float** in, float* out) {
-//   __shared__ float[32][33] cache;
-//   int thread_id = GetThreadId();
-//   int block_id = GetBlockId();
-//   int tile_size = 128;
-//
-//   float accum = 0;
-//   for (int i=0; i<tile_size; i++) {
-//     accum += in[thread_id.y * tile_size + i][block_id * 32 + thread_id.x];
-//   }
-//   cache[thread_id.x][thread_id.y] = accum;
-//
-//   __syncthreads();
-//   accum = cache[thread_id.y][thread_id.x];
-//   accum = warp_reduce(accum); // Sum all the values of `accum` in the same
-//                               // warp.
-//
-//   if (thread_id.x % 32 == 0) {
-//     out[block_id * 32 + thread_id.y] = accum;
-//   }
-// }
-// ```
-//
-// Moreover, a heuristic is implemented to divide the reduce instructions
-// into groups for parallelization (see `DivideOutputInstructionsIntoGroups`
-// for details about the heuristic). Reduce instructions in the same group
-// will run sequentially, while different groups will run in parallel.
-//
-// We use raw block_id_y to select the reduce groups for execution without
-// complicating the index calculation in the code generation of the reduce
-// instructions. In other words, a block_id_y is assigned to a group, and so
-// different groups can be run in parallel.
-class ReductionFusion : public KernelFusionEmitterBase {
- public:
-  explicit ReductionFusion(const HloFusionAnalysis& analysis)
-      : analysis_(analysis), reduction_info_(ReductionInfo::Create(analysis)) {}
-
-  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
-      int64_t root_index, mlir::MLIRContext* ctx) const override {
-    return reduction_info_.ComputeThreadIdToOutputIndexing(root_index, ctx);
-  }
-
-  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
-      int64_t root_index, int64_t hero_operand_index,
-      mlir::MLIRContext* ctx) const override {
-    return reduction_info_.ComputeThreadIdToInputIndexing(
-        root_index, hero_operand_index, ctx);
-  }
-
-  LaunchDimensions launch_dimensions() const override {
-    return reduction_info_.launch_dimensions();
-  }
-
-  const ReductionInfo& reduction_info() const { return reduction_info_; }
-
- protected:
-  absl::StatusOr<FusionEmissionResult> EmitInitializers(
-      IrEmitterContext& ir_emitter_context,
-      const HloFusionInstruction& fusion) const override;
-
-  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
-                          const HloFusionInstruction& fusion,
-                          const LaunchDimensions& launch_dims,
-                          std::vector<llvm_ir::IrArray> inputs,
-                          std::vector<llvm_ir::IrArray> outputs,
-                          llvm::IRBuilder<>* builder) const override;
-
- private:
-  const HloFusionAnalysis& analysis_;
-  ReductionInfo reduction_info_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc
index 7895108..e0af3b4 100644
--- a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc
+++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc
@@ -36,7 +36,6 @@
 #include "xla/hlo/utils/hlo_query.h"
 #include "xla/primitive_util.h"
 #include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
 #include "xla/service/gpu/gpu_fusible.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/hlo_traversal.h"
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc
index d58f809..075678d 100644
--- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc
@@ -47,9 +47,9 @@
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
 #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/fusions/mlir/type_util.h"
 #include "xla/service/gpu/fusions/reduction_base.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
@@ -79,79 +79,6 @@
 constexpr int kRowKept = ReductionDimensions::kRowKeptDimension;
 constexpr int kRowMinorReduced = ReductionDimensions::kRowMinorReducedDimension;
 
-LaunchDimensions MlirReductionFusion::launch_dimensions() const {
-  size_t blocks_y = groups_.grouped_roots.size();
-  return {se::BlockDim(/*x=*/Product(num_blocks_),
-                       /*y=*/static_cast<int64_t>(blocks_y), /*z=*/1),
-          se::ThreadDim(/*x=*/Product(num_threads_),
-                        /*y=*/1, /*z=*/1)};
-}
-
-MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis)
-    : analysis_(analysis) {
-  auto* hero_reduction = analysis.FindHeroReduction();
-  CHECK_NE(hero_reduction, nullptr);
-  Shape input_shape = hero_reduction->operand(0)->shape();
-  reduction_dimensions_ =
-      GetReductionKindAndContiguousComponents(*hero_reduction);
-  VLOG(10) << reduction_dimensions_;
-
-  CHECK(ReductionIsRaceFree(hero_reduction->GetModule()->config(),
-                            reduction_dimensions_))
-      << "Non-race-free reductions should have been decomposed. Did "
-         "tree_reduction_rewriter run?";
-
-  groups_ = GroupDisjointReductions(analysis, /*for_mlir=*/true);
-  first_reduce_ = hero_reduction;
-
-  const auto& groups = GetGroups();
-  int num_groups = groups.grouped_roots.size();
-  side_output_roots_.resize(num_groups);
-  reduction_heroes_.resize(num_groups);
-  reduction_roots_.resize(num_groups);
-
-  absl::flat_hash_set<const HloInstruction*> seen_heroes;
-  for (auto [root_adaptor, hero_adaptor, is_reduction, group_id] :
-       llvm::zip(analysis.fusion_roots(), analysis.fusion_heroes(),
-                 groups.is_reduction_root, groups.group_id_per_root)) {
-    const HloInstruction* root = &root_adaptor.instruction();
-    const HloInstruction* hero = &hero_adaptor.instruction();
-    if (is_reduction) {
-      if (seen_heroes.insert(hero).second) {
-        reduction_heroes_[group_id].push_back(hero);
-      }
-      reduction_roots_[group_id].push_back(root);
-    } else {
-      side_output_roots_[group_id].push_back(root);
-    }
-  }
-}
-
-IndexingMap MlirReductionFusion::GetIndexingMap(
-    llvm::ArrayRef<mlir::AffineExpr> results,
-    absl::Span<int64_t const> symbol_sizes) const {
-  auto* ctx = results.front().getContext();
-  auto num_groups = static_cast<int64_t>(reduction_heroes_.size());
-  return IndexingMap{
-      AffineMap::get(6, symbol_sizes.size(), results, ctx),
-      DimVarsFromTensorSizes(
-          {Product(num_threads_), 1, 1, Product(num_blocks_), num_groups, 1}),
-      RangeVarsFromTensorSizes(symbol_sizes),
-      /*rt_vars=*/{}};
-}
-
-IndexingMap MlirReductionFusion::GetThreadIndexingMap(
-    llvm::ArrayRef<mlir::AffineExpr> results,
-    absl::Span<std::pair<mlir::AffineExpr, Interval> const> constraints,
-    absl::Span<int64_t const> symbol_sizes) const {
-  auto affine_map = AffineMap::get(1, symbol_sizes.size(), results,
-                                   results.front().getContext());
-  return IndexingMap{affine_map,
-                     DimVarsFromTensorSizes({Product(num_threads_)}),
-                     RangeVarsFromTensorSizes(symbol_sizes),
-                     /*rt_vars=*/{}, constraints};
-}
-
 struct PerThreadOutputs {
   // The partially reduced scalars for each thread.
   HloValueMap reduction_scalars;
@@ -232,140 +159,6 @@
   SmallVector<Value> thread_and_block_ids;
 };
 
-std::vector<mlir_converter::EpilogueSpecification>
-MlirReductionFusion::GetEpilogues(const HloFusionInstruction& fusion,
-                                  MLIRContext* mlir_context) const {
-  std::vector<mlir_converter::EpilogueSpecification> epilogues;
-  epilogues.reserve(reduction_heroes_.size());
-  for (const auto& [heroes, roots] :
-       llvm::zip(reduction_heroes_, reduction_roots_)) {
-    epilogues.push_back(
-        mlir_converter::EpilogueSpecification::FromOutputIndexing(
-            analysis_, heroes, roots, *this, mlir_context));
-  }
-  // Add empty epilogues for the side outputs. This ensures their roots don't
-  // get "fused" into the tuple function.
-  for (const auto& roots : side_output_roots_) {
-    for (const auto* root : roots) {
-      epilogues.push_back(
-          mlir_converter::EpilogueSpecification::FromIdentityIndexing(
-              root, root, mlir_context));
-    }
-  }
-  return epilogues;
-}
-
-absl::Status MlirReductionFusion::EmitEntryFunction(
-    const PartitionedComputations& computations,
-    const mlir_converter::CallTargetProvider& call_targets,
-    mlir::func::FuncOp entry_function,
-    const HloFusionInstruction& fusion) const {
-  EmitterState state{*this, entry_function, fusion, computations, call_targets};
-  auto& b = state.builder;
-  b.setInsertionPointToStart(entry_function.addEntryBlock());
-  state.thread_and_block_ids = EmitThreadAndBlockIds(b);
-  if (reduction_heroes_.size() == 1) {
-    b.create<mlir::func::ReturnOp>(EmitReduction(0, state));
-    return absl::OkStatus();
-  }
-  SmallVector<int64_t> cases(reduction_heroes_.size() - 1);
-  absl::c_iota(cases, 1);  // `default` is region 0.
-  auto switch_op = b.create<mlir::scf::IndexSwitchOp>(
-      entry_function.getResultTypes(), EmitBlockId(b, 1), cases, cases.size());
-  b.create<mlir::func::ReturnOp>(switch_op.getResults());
-  for (auto [id, region] : llvm::enumerate(switch_op->getRegions())) {
-    b.setInsertionPointToStart(&region.emplaceBlock());
-    b.create<mlir::scf::YieldOp>(EmitReduction(id, state));
-  }
-  return absl::OkStatus();
-}
-
-IndexingMap MlirRowReductionFusion::ComputeReductionInputIndexing(
-    mlir::MLIRContext* ctx) const {
-  auto thread_id =
-      DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
-  auto block_id =
-      DelinearizeInBoundsIndex(mlir::getAffineDimExpr(3, ctx), num_blocks_);
-  auto major_reduced = getAffineSymbolExpr(0, ctx);
-  auto minor_reduced = getAffineSymbolExpr(1, ctx);
-  auto vector_index = getAffineSymbolExpr(2, ctx);
-
-  SmallVector<AffineExpr> indices{
-      major_reduced,
-      block_id[0] * tile_sizes_per_block_[0] + thread_id[0],
-      block_id[1] * tile_sizes_per_block_[1] +
-          (minor_reduced * num_threads_[1]) + thread_id[1],
-      vector_index,
-  };
-
-  auto map = GetIndexingMap(indices, tile_sizes_per_thread_);
-  for (auto [result, input_dim] : llvm::zip(indices, input_shape_)) {
-    map.AddConstraint(result, {0, input_dim - 1});
-  }
-  return map;
-}
-
-IndexingMap MlirMultiRowReductionFusion::ComputeReductionOutputIndexing(
-    MLIRContext* ctx) const {
-  auto thread_id =
-      DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
-  auto block_id = num_blocks_.front() == 1 ? mlir::getAffineConstantExpr(0, ctx)
-                                           : mlir::getAffineDimExpr(3, ctx);
-  IndexingMap projected_index =
-      GetIndexingMap(block_id * num_threads_[0] + thread_id[0]);
-  projected_index.AddConstraint(thread_id[1] % (WarpSize() / GetRowsPerWarp()),
-                                {0, 0});
-  // We don't need a constraint on the loop dimensions, because they are removed
-  // by GetIndexingMap (since they don't show up in the output index
-  // computation).
-  return projected_index;
-}
-
-IndexingMap MlirMultiRowReductionFusion::ComputeReductionInputIndexing(
-    mlir::MLIRContext* ctx) const {
-  auto thread_id =
-      DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
-  auto block_id = num_blocks_.front() == 1 ? mlir::getAffineConstantExpr(0, ctx)
-                                           : mlir::getAffineDimExpr(3, ctx);
-  auto major_reduced = getAffineSymbolExpr(0, ctx);
-  auto vector_index = getAffineSymbolExpr(1, ctx);
-
-  SmallVector<AffineExpr> indices{
-      major_reduced, block_id * num_threads_[0] + thread_id[0],
-      thread_id[1] * tile_sizes_per_thread_[1] + vector_index};
-
-  auto map = GetIndexingMap(indices, tile_sizes_per_thread_);
-  for (auto [result, input_dim] : llvm::zip(indices, input_shape_)) {
-    map.AddConstraint(result, {0, input_dim - 1});
-  }
-  return map;
-}
-
-IndexingMap MlirRowReductionFusion::ComputeReductionOutputIndexing(
-    MLIRContext* ctx) const {
-  auto thread_id =
-      DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
-  auto block_id =
-      DelinearizeInBoundsIndex(mlir::getAffineDimExpr(3, ctx), num_blocks_);
-  IndexingMap projected_index =
-      GetIndexingMap(block_id[0] * tile_sizes_per_block_[0] + thread_id[0]);
-  projected_index.AddConstraint(thread_id[1], {0, 0});
-  return projected_index;
-}
-
-HloValueMap MlirReductionFusion::GetInits(int group_id,
-                                          EmitterState& state) const {
-  HloValueMap result;
-  const auto& reductions = reduction_heroes_[group_id];
-  for (auto* hero : reductions) {
-    int arity = hero->operand_count() / 2;
-    result[hero] = ProvideParameterRange(state.computation, hero, arity, arity,
-                                         {}, state.call_target,
-                                         state.entry_function, state.builder);
-  }
-  return result;
-}
-
 PerThreadOutputs MlirReductionFusion::EmitterState::EmitPerThreadElements(
     int group_id, const HloValueMap& inits, const SmallVector<Value>& outputs) {
   auto tile_indexing =
@@ -558,6 +351,140 @@
       });
 }
 
+MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis)
+    : analysis_(analysis) {
+  auto* hero_reduction = analysis.FindHeroReduction();
+  CHECK_NE(hero_reduction, nullptr);
+  Shape input_shape = hero_reduction->operand(0)->shape();
+  reduction_dimensions_ =
+      GetReductionKindAndContiguousComponents(*hero_reduction);
+  VLOG(10) << reduction_dimensions_;
+
+  CHECK(ReductionIsRaceFree(hero_reduction->GetModule()->config(),
+                            reduction_dimensions_))
+      << "Non-race-free reductions should have been decomposed. Did "
+         "tree_reduction_rewriter run?";
+
+  groups_ = GroupDisjointReductions(analysis, /*for_mlir=*/true);
+  first_reduce_ = hero_reduction;
+
+  const auto& groups = GetGroups();
+  int num_groups = groups.grouped_roots.size();
+  side_output_roots_.resize(num_groups);
+  reduction_heroes_.resize(num_groups);
+  reduction_roots_.resize(num_groups);
+
+  absl::flat_hash_set<const HloInstruction*> seen_heroes;
+  for (auto [root_adaptor, hero_adaptor, is_reduction, group_id] :
+       llvm::zip(analysis.fusion_roots(), analysis.fusion_heroes(),
+                 groups.is_reduction_root, groups.group_id_per_root)) {
+    const HloInstruction* root = &root_adaptor.instruction();
+    const HloInstruction* hero = &hero_adaptor.instruction();
+    if (is_reduction) {
+      if (seen_heroes.insert(hero).second) {
+        reduction_heroes_[group_id].push_back(hero);
+      }
+      reduction_roots_[group_id].push_back(root);
+    } else {
+      side_output_roots_[group_id].push_back(root);
+    }
+  }
+}
+
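The constructor above partitions the fusion roots: each reduction root is filed under its disjoint group (with heroes deduplicated via seen_heroes), and each non-reduction root becomes a side output of that group. Below is a minimal standalone sketch of the same bookkeeping, using strings in place of HloInstruction pointers and a made-up root list rather than anything taken from this patch.

#include <cstdio>
#include <set>
#include <string>
#include <vector>

int main() {
  // Hypothetical fusion with four roots spread over two disjoint groups.
  struct Root { std::string root, hero; bool is_reduction; int group_id; };
  std::vector<Root> roots = {{"reduce.0", "reduce.0", true, 0},
                             {"log.0", "reduce.0", true, 0},   // same hero, epilogue root
                             {"exp.0", "exp.0", false, 0},     // side output
                             {"reduce.1", "reduce.1", true, 1}};

  const int num_groups = 2;
  std::vector<std::vector<std::string>> heroes(num_groups), red_roots(num_groups),
      side_outputs(num_groups);
  std::set<std::string> seen_heroes;
  for (const auto& r : roots) {
    if (r.is_reduction) {
      // A hero shared by several roots is recorded only once per group.
      if (seen_heroes.insert(r.hero).second) heroes[r.group_id].push_back(r.hero);
      red_roots[r.group_id].push_back(r.root);
    } else {
      side_outputs[r.group_id].push_back(r.root);
    }
  }
  for (int g = 0; g < num_groups; ++g) {
    std::printf("group %d: %zu hero(s), %zu reduction root(s), %zu side output(s)\n",
                g, heroes[g].size(), red_roots[g].size(), side_outputs[g].size());
  }
}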
+IndexingMap MlirReductionFusion::GetIndexingMap(
+    llvm::ArrayRef<mlir::AffineExpr> results,
+    absl::Span<int64_t const> symbol_sizes) const {
+  auto* ctx = results.front().getContext();
+  auto num_groups = static_cast<int64_t>(reduction_heroes_.size());
+  return IndexingMap{
+      AffineMap::get(6, symbol_sizes.size(), results, ctx),
+      DimVarsFromTensorSizes(
+          {Product(num_threads_), 1, 1, Product(num_blocks_), num_groups, 1}),
+      RangeVarsFromTensorSizes(symbol_sizes),
+      /*rt_vars=*/{}};
+}
+
+IndexingMap MlirReductionFusion::GetThreadIndexingMap(
+    llvm::ArrayRef<mlir::AffineExpr> results,
+    absl::Span<std::pair<mlir::AffineExpr, Interval> const> constraints,
+    absl::Span<int64_t const> symbol_sizes) const {
+  auto affine_map = AffineMap::get(1, symbol_sizes.size(), results,
+                                   results.front().getContext());
+  return IndexingMap{affine_map,
+                     DimVarsFromTensorSizes({Product(num_threads_)}),
+                     RangeVarsFromTensorSizes(symbol_sizes),
+                     /*rt_vars=*/{}, constraints};
+}
+
+LaunchDimensions MlirReductionFusion::launch_dimensions() const {
+  size_t blocks_y = groups_.grouped_roots.size();
+  return {se::BlockDim(/*x=*/Product(num_blocks_),
+                       /*y=*/static_cast<int64_t>(blocks_y), /*z=*/1),
+          se::ThreadDim(/*x=*/Product(num_threads_),
+                        /*y=*/1, /*z=*/1)};
+}
+
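GetIndexingMap and launch_dimensions above fix the convention the rest of the emitter relies on: d0..d2 are the thread ids, d3..d5 the block ids, the x block dimension spans Product(num_blocks_), the y block dimension selects the reduction group, and all threads live in the x thread dimension. The standalone sketch below uses hypothetical tile sizes; none of the numbers come from this patch.

#include <cstdint>
#include <cstdio>
#include <vector>

int64_t Product(const std::vector<int64_t>& v) {
  int64_t p = 1;
  for (int64_t x : v) p *= x;
  return p;
}

int main() {
  // Hypothetical tiling: 8 "kept" x 32 "reduced" threads, a 13 x 2 block
  // grid, and 2 disjoint reduction groups (all illustrative values).
  std::vector<int64_t> num_threads = {8, 32};
  std::vector<int64_t> num_blocks = {13, 2};
  int64_t num_groups = 2;

  // launch_dimensions(): grid = (prod(num_blocks), num_groups, 1),
  //                      block = (prod(num_threads), 1, 1).
  std::printf("grid  = (%lld, %lld, 1)\n",
              static_cast<long long>(Product(num_blocks)),
              static_cast<long long>(num_groups));
  std::printf("block = (%lld, 1, 1)\n",
              static_cast<long long>(Product(num_threads)));

  // GetIndexingMap() dimension bounds, with (d0..d5) = thread x/y/z then
  // block x/y/z.
  std::vector<int64_t> dim_sizes = {Product(num_threads), 1, 1,
                                    Product(num_blocks), num_groups, 1};
  for (size_t i = 0; i < dim_sizes.size(); ++i) {
    std::printf("d%zu in [0, %lld]\n", i,
                static_cast<long long>(dim_sizes[i] - 1));
  }
}

The group id landing in the y block dimension (d4) is what lets EmitEntryFunction below dispatch one group per block row.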
+std::vector<mlir_converter::EpilogueSpecification>
+MlirReductionFusion::GetEpilogues(const HloFusionInstruction& fusion,
+                                  MLIRContext* mlir_context) const {
+  std::vector<mlir_converter::EpilogueSpecification> epilogues;
+  epilogues.reserve(reduction_heroes_.size());
+  for (const auto& [heroes, roots] :
+       llvm::zip(reduction_heroes_, reduction_roots_)) {
+    epilogues.push_back(
+        mlir_converter::EpilogueSpecification::FromOutputIndexing(
+            analysis_, heroes, roots, *this, mlir_context));
+  }
+  // Add empty epilogues for the side outputs. This ensures their roots don't
+  // get "fused" into the tuple function.
+  for (const auto& roots : side_output_roots_) {
+    for (const auto* root : roots) {
+      epilogues.push_back(
+          mlir_converter::EpilogueSpecification::FromIdentityIndexing(
+              root, root, mlir_context));
+    }
+  }
+  return epilogues;
+}
+
+absl::Status MlirReductionFusion::EmitEntryFunction(
+    const PartitionedComputations& computations,
+    const mlir_converter::CallTargetProvider& call_targets,
+    mlir::func::FuncOp entry_function,
+    const HloFusionInstruction& fusion) const {
+  EmitterState state{*this, entry_function, fusion, computations, call_targets};
+  auto& b = state.builder;
+  b.setInsertionPointToStart(entry_function.addEntryBlock());
+  state.thread_and_block_ids = EmitThreadAndBlockIds(b);
+  if (reduction_heroes_.size() == 1) {
+    b.create<mlir::func::ReturnOp>(EmitReduction(0, state));
+    return absl::OkStatus();
+  }
+  SmallVector<int64_t> cases(reduction_heroes_.size() - 1);
+  absl::c_iota(cases, 1);  // `default` is region 0.
+  auto switch_op = b.create<mlir::scf::IndexSwitchOp>(
+      entry_function.getResultTypes(), EmitBlockId(b, 1), cases, cases.size());
+  b.create<mlir::func::ReturnOp>(switch_op.getResults());
+  for (auto [id, region] : llvm::enumerate(switch_op->getRegions())) {
+    b.setInsertionPointToStart(&region.emplaceBlock());
+    b.create<mlir::scf::YieldOp>(EmitReduction(id, state));
+  }
+  return absl::OkStatus();
+}
+
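When there is more than one reduction group, EmitEntryFunction above emits an scf.index_switch over the y block id: group 0 is handled by the default region and groups 1..N-1 by the explicit cases, while the single-group case returns directly without a switch. A tiny standalone model of that dispatch decision, with a hypothetical group count and no MLIR involved:

#include <cstdio>
#include <numeric>
#include <vector>

// Which reduction group a block works on, mirroring the index switch above:
// case values are 1..N-1, everything else (i.e. 0) falls into the default.
int GroupForBlock(int block_id_y, int num_groups) {
  std::vector<int> cases(num_groups - 1);
  std::iota(cases.begin(), cases.end(), 1);  // same idea as absl::c_iota(cases, 1)
  for (int c : cases) {
    if (block_id_y == c) return c;
  }
  return 0;  // default region
}

int main() {
  const int num_groups = 3;  // hypothetical
  for (int y = 0; y < num_groups; ++y) {
    std::printf("block_id_y=%d -> group %d\n", y, GroupForBlock(y, num_groups));
  }
}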
+HloValueMap MlirReductionFusion::GetInits(int group_id,
+                                          EmitterState& state) const {
+  HloValueMap result;
+  const auto& reductions = reduction_heroes_[group_id];
+  for (auto* hero : reductions) {
+    int arity = hero->operand_count() / 2;
+    result[hero] = ProvideParameterRange(state.computation, hero, arity, arity,
+                                         {}, state.call_target,
+                                         state.entry_function, state.builder);
+  }
+  return result;
+}
+
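GetInits leans on the HLO convention that a (possibly variadic) reduce with N inputs carries its N init values as the trailing operands, so arity = operand_count() / 2 and ProvideParameterRange fetches operands [arity, 2 * arity). A small sketch with a hypothetical operand list:

#include <cstdio>
#include <string>
#include <vector>

int main() {
  // Hypothetical variadic reduce: two inputs followed by two init values.
  std::vector<std::string> operands = {"param_0", "param_1", "init_0", "init_1"};
  int arity = static_cast<int>(operands.size()) / 2;  // == 2

  // The init values are the operand range [arity, 2 * arity).
  for (int i = arity; i < 2 * arity; ++i) {
    std::printf("init %d: %s\n", i - arity, operands[i].c_str());
  }
}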
 std::optional<IndexingMap> MlirReductionFusion::ComputeThreadIdToInputIndexing(
     int64_t root_index, int64_t hero_operand_index, MLIRContext* ctx) const {
   const auto& hero = analysis_.fusion_hero(root_index).instruction();
@@ -645,182 +572,6 @@
   return outputs;
 }
 
-MlirRowReductionFusion::MlirRowReductionFusion(
-    const HloFusionAnalysis& analysis)
-    : MlirReductionFusion(analysis) {
-  CHECK(reduction_dimensions_.is_row_reduction);
-  Vector3 shape = reduction_dimensions_.dimensions;
-  CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1);
-  constexpr int64_t kMinorReducedElementsPerThread = 16;
-
-  int64_t num_threads_kept = 1;
-  int64_t num_threads_reduced = [&] {
-    int64_t max_block_size =
-        MinThreadsXRowReduction(first_reduce_->GetModule()->config());
-    return std::min(max_block_size,
-                    RoundUpTo(CeilOfRatio(shape[kRowMinorReduced],
-                                          kMinorReducedElementsPerThread),
-                              WarpSize()));
-  }();
-
-  // If we're limited by the size of the x dimension, add additional parallelism
-  // in the y dimension. The code generator doesn't currently support
-  // parallelizing the z dimension (major reduced dimensions). The general
-  // recommendation is to use between 128 and 512 threads, so we just go for
-  // 256. See https://forums.developer.nvidia.com/t/55529
-  constexpr int64_t kThreadsPerBlockTarget = 256;
-  if (num_threads_reduced * 2 <= kThreadsPerBlockTarget) {
-    int64_t kept_size = reduction_dimensions_.dimensions[kRowKept];
-    // Increase the size of the y dimension as long as there's remaining
-    // parallelism.
-    if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) {
-      num_threads_kept = kept_size;
-    } else {
-      num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
-    }
-  }
-
-  int vector_size = GetVectorSizeForMlir(analysis, reduction_dimensions_,
-                                         num_threads_reduced);
-  num_threads_ = {num_threads_kept, num_threads_reduced};
-  // TODO(jreiffers): Get rid of `vector_size` in here.
-  input_shape_ = {shape[0], shape[1], shape[2] / vector_size, vector_size};
-  // TODO(jreiffers): Tighten ranges based on constraints when simplifying
-  // instead of using min here. For example, based on
-  //
-  //   s1 in [0, 127]
-  //   d0 floordiv 32 + s1 * 32 in [0, 63]
-  //
-  // Tighten the bound of s1 to [0, 1].
-  int minor_reduced_tile_size =
-      std::min(kMinorReducedElementsPerThread / vector_size,
-               CeilOfRatio(input_shape_[2], num_threads_[1]));
-
-  tile_sizes_per_thread_ = {shape[0], minor_reduced_tile_size, vector_size};
-  tile_sizes_per_block_ = {num_threads_kept,
-                           minor_reduced_tile_size * num_threads_reduced};
-  num_blocks_ = {CeilOfRatio(input_shape_[1], tile_sizes_per_block_[0]),
-                 CeilOfRatio(input_shape_[2], tile_sizes_per_block_[1])};
-}
-
-MlirMultiRowReductionFusion::MlirMultiRowReductionFusion(
-    const HloFusionAnalysis& analysis)
-    : MlirReductionFusion(analysis) {
-  CHECK(reduction_dimensions_.is_row_reduction);
-  Vector3 shape = reduction_dimensions_.dimensions;
-  int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]);
-  input_shape_ = {shape[0], shape[1], shape[2]};
-  CHECK_GT(rows_per_warp, 1);
-
-  auto compute_block_size = [&](int vector_size) {
-    int64_t num_threads_reduced = shape[kRowMinorReduced] / vector_size;
-
-    constexpr int64_t kThreadsPerBlockTarget = 256;
-    int64_t kept_size = reduction_dimensions_.dimensions[kRowKept];
-    int64_t num_threads_kept = 1;
-    if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) {
-      num_threads_kept = kept_size;
-    } else {
-      num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
-    }
-    num_threads_ = {num_threads_kept, num_threads_reduced};
-    tile_sizes_per_thread_ = {shape[0], vector_size};
-    num_blocks_ = {CeilOfRatio(input_shape_[kRowKept], num_threads_kept)};
-  };
-
-  // Compute the launch grid without vectorization. We use the results to
-  // compute the vectorized launch grid.
-  compute_block_size(1);
-
-  // Normally, we only consider input types for vectorization. However, in
-  // multi-row reductions, the input:output ratio is much higher, so we consider
-  // both inputs and outputs.
-  int smallest_input_or_output_bits =
-      std::min(analysis.input_output_info().smallest_input_dtype_bits,
-               analysis.input_output_info().smallest_output_dtype_bits);
-
-  // This vector size is always valid: we know that the reduced dimension is a
-  // power of 2, since otherwise RowReductionGetRowsPerWarp would have
-  // returned 1.
-  // Our codegen can't currently deal with vectorization across rows, so we
-  // limit the vector size to the size of the row. Note that this emitter
-  // essentially reverts to the loop emitter in this case, except for side
-  // outputs.
-  int vector_size = std::min(static_cast<int>(input_shape_[kRowMinorReduced]),
-                             32 / smallest_input_or_output_bits);
-
-  // We target 8 warps per block, which means there could be up to 8 blocks per
-  // SM, but we have no good way of knowing. In practice, enabling vectorization
-  // for decently sized reductions at least does not hurt.
-  if (num_blocks_.front() > analysis.device_info().core_count() &&
-      vector_size > 1) {
-    compute_block_size(vector_size);
-  }
-}
-
-int MlirMultiRowReductionFusion::GetRowsPerWarp() const {
-  return RowReductionGetRowsPerWarp(
-             input_shape_[ReductionDimensions::kRowMinorReducedDimension]) *
-         tile_sizes_per_thread_[1];
-}
-
-int MlirRowReductionFusion::GetWarpsPerRow() const {
-  return CeilOfRatio(num_threads_[1], WarpSize());
-}
-
-IndexingMap MlirRowReductionFusion::GetSharedMemoryReductionReadMap(
-    mlir::MLIRContext* ctx) const {
-  auto thread_id =
-      DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), num_threads_);
-  auto lane_id = thread_id[1] % WarpSize();
-  return GetThreadIndexingMap({thread_id[0], lane_id},
-                              {{thread_id[1], {0, GetWarpsPerRow() - 1}}});
-}
-
-IndexingMap MlirRowReductionFusion::GetSharedMemoryWriteMap(
-    mlir::MLIRContext* ctx) const {
-  auto thread_id =
-      DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), num_threads_);
-  // The reduced dimension is tiled; each warp writes one element to shared
-  // memory (from lane 0).
-  auto lane_id = thread_id[1] % WarpSize();
-  auto warp_id = thread_id[1].floorDiv(WarpSize());
-  return GetThreadIndexingMap({thread_id[0], warp_id}, {{lane_id, {0, 0}}});
-}
-
-llvm::SmallVector<mlir::Value> MlirRowReductionFusion::EmitReduction(
-    int group_id, EmitterState& state) const {
-  const auto& reductions = reduction_heroes_[group_id];
-
-  HloValueMap inits = GetInits(group_id, state);
-  auto per_thread =
-      state.EmitPerThreadElements(group_id, inits, state.FusionOutputs());
-  per_thread.reduction_scalars =
-      state.ShuffleReduce(reductions, per_thread.reduction_scalars);
-
-  if (GetWarpsPerRow() == 1) {
-    // If only a single warp works on an element, we don't need to go through
-    // shared memory.
-    return EvaluateEpilogue(per_thread.reduction_scalars,
-                            std::move(per_thread.outputs), state, group_id,
-                            /*symbol_values=*/{});
-  }
-
-  return state.ReduceViaSharedMemory(group_id, per_thread, inits);
-}
-
-llvm::SmallVector<mlir::Value> MlirMultiRowReductionFusion::EmitReduction(
-    int group_id, EmitterState& state) const {
-  HloValueMap inits = GetInits(group_id, state);
-  const auto& reductions = reduction_heroes_[group_id];
-  auto per_thread =
-      state.EmitPerThreadElements(group_id, inits, state.FusionOutputs());
-  auto reduced = state.ShuffleReduce(reductions, per_thread.reduction_scalars,
-                                     WarpSize() / 2 / GetRowsPerWarp());
-  return EvaluateEpilogue(reduced, std::move(per_thread.outputs), state,
-                          group_id, /*symbol_values=*/{});
-}
-
 MlirColumnReductionFusion::MlirColumnReductionFusion(
     const HloFusionAnalysis& analysis)
     : MlirReductionFusion(analysis) {
@@ -930,5 +681,254 @@
   return std::make_unique<MlirColumnReductionFusion>(analysis);
 }
 
+MlirRowReductionFusion::MlirRowReductionFusion(
+    const HloFusionAnalysis& analysis)
+    : MlirReductionFusion(analysis) {
+  CHECK(reduction_dimensions_.is_row_reduction);
+  Vector3 shape = reduction_dimensions_.dimensions;
+  CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1);
+  constexpr int64_t kMinorReducedElementsPerThread = 16;
+
+  int64_t num_threads_kept = 1;
+  int64_t num_threads_reduced = [&] {
+    int64_t max_block_size =
+        MinThreadsXRowReduction(first_reduce_->GetModule()->config());
+    return std::min(max_block_size,
+                    RoundUpTo(CeilOfRatio(shape[kRowMinorReduced],
+                                          kMinorReducedElementsPerThread),
+                              WarpSize()));
+  }();
+
+  // If we're limited by the size of the x dimension, add additional parallelism
+  // in the y dimension. The code generator doesn't currently support
+  // parallelizing the z dimension (major reduced dimensions). The general
+  // recommendation is to use between 128 and 512 threads, so we just go for
+  // 256. See https://forums.developer.nvidia.com/t/55529
+  constexpr int64_t kThreadsPerBlockTarget = 256;
+  if (num_threads_reduced * 2 <= kThreadsPerBlockTarget) {
+    int64_t kept_size = reduction_dimensions_.dimensions[kRowKept];
+    // Increase the size of the y dimension as long as there's remaining
+    // parallelism.
+    if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) {
+      num_threads_kept = kept_size;
+    } else {
+      num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
+    }
+  }
+
+  int vector_size = GetVectorSizeForMlir(analysis, reduction_dimensions_,
+                                         num_threads_reduced);
+  num_threads_ = {num_threads_kept, num_threads_reduced};
+  // TODO(jreiffers): Get rid of `vector_size` in here.
+  input_shape_ = {shape[0], shape[1], shape[2] / vector_size, vector_size};
+  // TODO(jreiffers): Tighten ranges based on constraints when simplifying
+  // instead of using min here. For example, based on
+  //
+  //   s1 in [0, 127]
+  //   d0 floordiv 32 + s1 * 32 in [0, 63]
+  //
+  // Tighten the bound of s1 to [0, 1].
+  int minor_reduced_tile_size =
+      std::min(kMinorReducedElementsPerThread / vector_size,
+               CeilOfRatio(input_shape_[2], num_threads_[1]));
+
+  tile_sizes_per_thread_ = {shape[0], minor_reduced_tile_size, vector_size};
+  tile_sizes_per_block_ = {num_threads_kept,
+                           minor_reduced_tile_size * num_threads_reduced};
+  num_blocks_ = {CeilOfRatio(input_shape_[1], tile_sizes_per_block_[0]),
+                 CeilOfRatio(input_shape_[2], tile_sizes_per_block_[1])};
+}
+
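To make the tiling arithmetic above concrete, here is a standalone walk-through for a hypothetical f32[8,2048] row reduction. It assumes a warp size of 32, MinThreadsXRowReduction(...) == 512 and GetVectorSizeForMlir(...) == 2; those three values are illustrative assumptions, while the formulas mirror the constructor.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }
int64_t RoundUpTo(int64_t a, int64_t b) { return CeilOfRatio(a, b) * b; }

int main() {
  // Hypothetical f32[8,2048] reduced over its minor dimension:
  // {major reduced, kept, minor reduced} = {1, 8, 2048}.
  const int64_t kKept = 8, kMinor = 2048;
  const int64_t kWarpSize = 32;                  // assumed
  const int64_t kMinThreadsXRowReduction = 512;  // assumed
  const int64_t kVectorSize = 2;                 // assumed GetVectorSizeForMlir result
  const int64_t kMinorReducedElementsPerThread = 16;
  const int64_t kThreadsPerBlockTarget = 256;

  int64_t num_threads_reduced = std::min(
      kMinThreadsXRowReduction,
      RoundUpTo(CeilOfRatio(kMinor, kMinorReducedElementsPerThread), kWarpSize));
  int64_t num_threads_kept = 1;
  if (num_threads_reduced * 2 <= kThreadsPerBlockTarget) {
    num_threads_kept = kKept * num_threads_reduced <= kThreadsPerBlockTarget
                           ? kKept
                           : kThreadsPerBlockTarget / num_threads_reduced;
  }

  int64_t minor_vectors = kMinor / kVectorSize;  // third entry of input_shape_
  int64_t minor_tile = std::min(kMinorReducedElementsPerThread / kVectorSize,
                                CeilOfRatio(minor_vectors, num_threads_reduced));
  int64_t blocks_kept = CeilOfRatio(kKept, num_threads_kept);
  int64_t blocks_reduced = CeilOfRatio(minor_vectors, minor_tile * num_threads_reduced);

  std::printf("threads = {%lld, %lld}, minor tile/thread = %lld, blocks = {%lld, %lld}\n",
              (long long)num_threads_kept, (long long)num_threads_reduced,
              (long long)minor_tile, (long long)blocks_kept, (long long)blocks_reduced);
  // Prints: threads = {2, 128}, minor tile/thread = 8, blocks = {4, 1}
}

With these inputs the block ends up with 256 threads (2 kept x 128 reduced) over 4 blocks, which lands in the 128-512 thread range the comment above targets.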
+IndexingMap MlirRowReductionFusion::ComputeReductionInputIndexing(
+    mlir::MLIRContext* ctx) const {
+  auto thread_id =
+      DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
+  auto block_id =
+      DelinearizeInBoundsIndex(mlir::getAffineDimExpr(3, ctx), num_blocks_);
+  auto major_reduced = getAffineSymbolExpr(0, ctx);
+  auto minor_reduced = getAffineSymbolExpr(1, ctx);
+  auto vector_index = getAffineSymbolExpr(2, ctx);
+
+  SmallVector<AffineExpr> indices{
+      major_reduced,
+      block_id[0] * tile_sizes_per_block_[0] + thread_id[0],
+      block_id[1] * tile_sizes_per_block_[1] +
+          (minor_reduced * num_threads_[1]) + thread_id[1],
+      vector_index,
+  };
+
+  auto map = GetIndexingMap(indices, tile_sizes_per_thread_);
+  for (auto [result, input_dim] : llvm::zip(indices, input_shape_)) {
+    map.AddConstraint(result, {0, input_dim - 1});
+  }
+  return map;
+}
+
+IndexingMap MlirRowReductionFusion::ComputeReductionOutputIndexing(
+    MLIRContext* ctx) const {
+  auto thread_id =
+      DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
+  auto block_id =
+      DelinearizeInBoundsIndex(mlir::getAffineDimExpr(3, ctx), num_blocks_);
+  IndexingMap projected_index =
+      GetIndexingMap(block_id[0] * tile_sizes_per_block_[0] + thread_id[0]);
+  projected_index.AddConstraint(thread_id[1], {0, 0});
+  return projected_index;
+}
+
+int MlirRowReductionFusion::GetWarpsPerRow() const {
+  return CeilOfRatio(num_threads_[1], WarpSize());
+}
+
+IndexingMap MlirRowReductionFusion::GetSharedMemoryReductionReadMap(
+    mlir::MLIRContext* ctx) const {
+  auto thread_id =
+      DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), num_threads_);
+  auto lane_id = thread_id[1] % WarpSize();
+  return GetThreadIndexingMap({thread_id[0], lane_id},
+                              {{thread_id[1], {0, GetWarpsPerRow() - 1}}});
+}
+
+IndexingMap MlirRowReductionFusion::GetSharedMemoryWriteMap(
+    mlir::MLIRContext* ctx) const {
+  auto thread_id =
+      DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), num_threads_);
+  // The reduced dimension is tiled; each warp writes one element to shared
+  // memory (from lane 0).
+  auto lane_id = thread_id[1] % WarpSize();
+  auto warp_id = thread_id[1].floorDiv(WarpSize());
+  return GetThreadIndexingMap({thread_id[0], warp_id}, {{lane_id, {0, 0}}});
+}
+
+llvm::SmallVector<mlir::Value> MlirRowReductionFusion::EmitReduction(
+    int group_id, EmitterState& state) const {
+  const auto& reductions = reduction_heroes_[group_id];
+
+  HloValueMap inits = GetInits(group_id, state);
+  auto per_thread =
+      state.EmitPerThreadElements(group_id, inits, state.FusionOutputs());
+  per_thread.reduction_scalars =
+      state.ShuffleReduce(reductions, per_thread.reduction_scalars);
+
+  if (GetWarpsPerRow() == 1) {
+    // If only a single warp works on an element, we don't need to go through
+    // shared memory.
+    return EvaluateEpilogue(per_thread.reduction_scalars,
+                            std::move(per_thread.outputs), state, group_id,
+                            /*symbol_values=*/{});
+  }
+
+  return state.ReduceViaSharedMemory(group_id, per_thread, inits);
+}
+
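EmitReduction above reduces in two stages: warp shuffles first, then, only when GetWarpsPerRow() > 1, lane 0 of each warp parks its partial result in shared memory and the surviving values are reduced again after the sync. The standalone model below spells out the shared-memory slots implied by GetSharedMemoryWriteMap and GetSharedMemoryReductionReadMap, assuming a warp size of 32 and the 2 x 128 thread block from the earlier walk-through (both assumptions for illustration).

#include <cstdio>

int main() {
  const int kWarpSize = 32;  // assumed
  const int num_threads_kept = 2, num_threads_reduced = 128;
  const int warps_per_row = (num_threads_reduced + kWarpSize - 1) / kWarpSize;  // 4

  // Write map: lane 0 of every warp stores its partial result at [kept, warp].
  for (int tx = 0; tx < num_threads_kept; ++tx) {
    for (int ty = 0; ty < num_threads_reduced; ++ty) {
      if (ty % kWarpSize == 0) {
        std::printf("thread (%d,%3d) writes shared[%d][%d]\n", tx, ty, tx,
                    ty / kWarpSize);
      }
    }
  }
  // Read map: only threads with ty < warps_per_row read a slot back; a final
  // warp shuffle then combines those warps_per_row values per kept row.
  for (int tx = 0; tx < num_threads_kept; ++tx) {
    for (int ty = 0; ty < warps_per_row; ++ty) {
      std::printf("thread (%d,%3d) reads  shared[%d][%d]\n", tx, ty, tx,
                  ty % kWarpSize);
    }
  }
}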
+MlirMultiRowReductionFusion::MlirMultiRowReductionFusion(
+    const HloFusionAnalysis& analysis)
+    : MlirReductionFusion(analysis) {
+  CHECK(reduction_dimensions_.is_row_reduction);
+  Vector3 shape = reduction_dimensions_.dimensions;
+  int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]);
+  input_shape_ = {shape[0], shape[1], shape[2]};
+  CHECK_GT(rows_per_warp, 1);
+
+  auto compute_block_size = [&](int vector_size) {
+    int64_t num_threads_reduced = shape[kRowMinorReduced] / vector_size;
+
+    constexpr int64_t kThreadsPerBlockTarget = 256;
+    int64_t kept_size = reduction_dimensions_.dimensions[kRowKept];
+    int64_t num_threads_kept = 1;
+    if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) {
+      num_threads_kept = kept_size;
+    } else {
+      num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
+    }
+    num_threads_ = {num_threads_kept, num_threads_reduced};
+    tile_sizes_per_thread_ = {shape[0], vector_size};
+    num_blocks_ = {CeilOfRatio(input_shape_[kRowKept], num_threads_kept)};
+  };
+
+  // Compute the launch grid without vectorization. We use the results to
+  // compute the vectorized launch grid.
+  compute_block_size(1);
+
+  // Normally, we only consider input types for vectorization. However, in
+  // multi-row reductions, the output:input ratio is much higher than usual, so
+  // we consider both inputs and outputs.
+  int smallest_input_or_output_bits =
+      std::min(analysis.input_output_info().smallest_input_dtype_bits,
+               analysis.input_output_info().smallest_output_dtype_bits);
+
+  // This vector size is always valid: we know that the reduced dimension is a
+  // power of 2, since otherwise RowReductionGetRowsPerWarp would have
+  // returned 1.
+  // Our codegen can't currently deal with vectorization across rows, so we
+  // limit the vector size to the size of the row. Note that this emitter
+  // essentially reverts to the loop emitter in this case, except for side
+  // outputs.
+  int vector_size = std::min(static_cast<int>(input_shape_[kRowMinorReduced]),
+                             32 / smallest_input_or_output_bits);
+
+  // We target 8 warps per block, which means there could be up to 8 blocks per
+  // SM, but we have no good way of knowing. In practice, enabling vectorization
+  // for decently sized reductions at least does not hurt.
+  if (num_blocks_.front() > analysis.device_info().core_count() &&
+      vector_size > 1) {
+    compute_block_size(vector_size);
+  }
+}
+
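A concrete walk-through of the multi-row constructor above, for a hypothetical f16[76800,8] reduction on a device with 132 cores; the shape and the core count are assumptions, while the 32-bit vectorization budget and the 256-thread target come from the code itself.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t kKeptSize = 76800, kRowSize = 8;  // assumed shape
  const int smallest_bits = 16;                   // f16 inputs and outputs
  const int64_t kThreadsPerBlockTarget = 256, kCoreCount = 132;  // core count assumed

  auto compute = [&](int64_t vector_size, int64_t& threads_kept,
                     int64_t& threads_reduced, int64_t& blocks) {
    threads_reduced = kRowSize / vector_size;
    threads_kept = kKeptSize * threads_reduced <= kThreadsPerBlockTarget
                       ? kKeptSize
                       : kThreadsPerBlockTarget / threads_reduced;
    blocks = CeilOfRatio(kKeptSize, threads_kept);
  };

  int64_t tk, tr, blocks;
  compute(/*vector_size=*/1, tk, tr, blocks);
  std::printf("unvectorized: threads={%lld,%lld}, blocks=%lld\n",
              (long long)tk, (long long)tr, (long long)blocks);

  int64_t vector_size = std::min<int64_t>(kRowSize, 32 / smallest_bits);  // == 2
  if (blocks > kCoreCount && vector_size > 1) {
    compute(vector_size, tk, tr, blocks);
  }
  std::printf("vectorized x%lld: threads={%lld,%lld}, blocks=%lld\n",
              (long long)vector_size, (long long)tk, (long long)tr, (long long)blocks);
  // Prints: unvectorized: threads={32,8}, blocks=2400
  //         vectorized x2: threads={64,4}, blocks=1200
}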
+IndexingMap MlirMultiRowReductionFusion::ComputeReductionInputIndexing(
+    mlir::MLIRContext* ctx) const {
+  auto thread_id =
+      DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
+  auto block_id = num_blocks_.front() == 1 ? mlir::getAffineConstantExpr(0, ctx)
+                                           : mlir::getAffineDimExpr(3, ctx);
+  auto major_reduced = getAffineSymbolExpr(0, ctx);
+  auto vector_index = getAffineSymbolExpr(1, ctx);
+
+  SmallVector<AffineExpr> indices{
+      major_reduced, block_id * num_threads_[0] + thread_id[0],
+      thread_id[1] * tile_sizes_per_thread_[1] + vector_index};
+
+  auto map = GetIndexingMap(indices, tile_sizes_per_thread_);
+  for (auto [result, input_dim] : llvm::zip(indices, input_shape_)) {
+    map.AddConstraint(result, {0, input_dim - 1});
+  }
+  return map;
+}
+
+IndexingMap MlirMultiRowReductionFusion::ComputeReductionOutputIndexing(
+    MLIRContext* ctx) const {
+  auto thread_id =
+      DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_);
+  auto block_id = num_blocks_.front() == 1 ? mlir::getAffineConstantExpr(0, ctx)
+                                           : mlir::getAffineDimExpr(3, ctx);
+  IndexingMap projected_index =
+      GetIndexingMap(block_id * num_threads_[0] + thread_id[0]);
+  projected_index.AddConstraint(thread_id[1] % (WarpSize() / GetRowsPerWarp()),
+                                {0, 0});
+  // We don't need a constraint on the loop dimensions, because they are removed
+  // by GetIndexingMap (since they don't show up in the output index
+  // computation).
+  return projected_index;
+}
+
+int MlirMultiRowReductionFusion::GetRowsPerWarp() const {
+  return RowReductionGetRowsPerWarp(
+             input_shape_[ReductionDimensions::kRowMinorReducedDimension]) *
+         tile_sizes_per_thread_[1];
+}
+
+llvm::SmallVector<mlir::Value> MlirMultiRowReductionFusion::EmitReduction(
+    int group_id, EmitterState& state) const {
+  HloValueMap inits = GetInits(group_id, state);
+  const auto& reductions = reduction_heroes_[group_id];
+  auto per_thread =
+      state.EmitPerThreadElements(group_id, inits, state.FusionOutputs());
+  auto reduced = state.ShuffleReduce(reductions, per_thread.reduction_scalars,
+                                     WarpSize() / 2 / GetRowsPerWarp());
+  return EvaluateEpilogue(reduced, std::move(per_thread.outputs), state,
+                          group_id, /*symbol_values=*/{});
+}
+
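The ShuffleReduce call above starts the butterfly at distance WarpSize() / 2 / GetRowsPerWarp(), which keeps each row's total inside its own slice of the warp. Below is a standalone model of that segmented shuffle, assuming a warp size of 32 and 4 rows per warp (8 lanes per row); the shfl_down-style clamping at the warp boundary is also an assumption about ShuffleReduce's behaviour, not something stated in this patch.

#include <cstdio>
#include <vector>

int main() {
  const int kWarpSize = 32, kRowsPerWarp = 4;         // assumed
  const int kLanesPerRow = kWarpSize / kRowsPerWarp;  // 8

  std::vector<int> v(kWarpSize);
  for (int lane = 0; lane < kWarpSize; ++lane) v[lane] = lane;  // per-lane partials

  // Butterfly: start at 32 / 2 / 4 = 4 and halve the distance each step.
  for (int delta = kWarpSize / 2 / kRowsPerWarp; delta >= 1; delta /= 2) {
    std::vector<int> next = v;
    for (int lane = 0; lane < kWarpSize; ++lane) {
      // shfl_down-style exchange; lanes reading past the warp keep their value.
      if (lane + delta < kWarpSize) next[lane] = v[lane] + v[lane + delta];
    }
    v = next;
  }

  // The first lane of each row segment holds that row's total
  // (0+1+...+7 = 28 for the first row, 8+...+15 = 92 for the second, ...).
  for (int row = 0; row < kRowsPerWarp; ++row) {
    std::printf("row %d total (lane %2d): %d\n", row, row * kLanesPerRow,
                v[row * kLanesPerRow]);
  }
}

Only the first lane of each 8-lane segment ends up with a meaningful total, which lines up with the thread_id[1] % (WarpSize() / GetRowsPerWarp()) == 0 constraint in ComputeReductionOutputIndexing above.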
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc
index 761ecb4..4798528 100644
--- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc
@@ -23,10 +23,10 @@
 #include <gtest/gtest.h>
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
-#include "absl/strings/substitute.h"
 #include "absl/types/span.h"
 #include "xla/error_spec.h"
 #include "xla/service/gpu/fusions/mlir_emitter_test_base.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/model/indexing_map.h"
 #include "xla/service/gpu/model/indexing_test_utils.h"
 #include "xla/tsl/lib/core/status_test_util.h"
@@ -54,69 +54,8 @@
   }
 };
 
-using MlirRowReductionTest = ReductionTest<MlirRowReductionFusion>;
-using MlirColumnReductionTest = ReductionTest<MlirColumnReductionFusion>;
 using MlirMultiRowReductionTest = ReductionTest<MlirMultiRowReductionFusion>;
 
-constexpr std::string_view kVariadicRowReduction = R"(
-    Add {
-      scalar_lhs.0 = f32[] parameter(0)
-      scalar_rhs.0 = f32[] parameter(1)
-      scalar_lhs.1 = f32[] parameter(2)
-      scalar_rhs.1 = f32[] parameter(3)
-      add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1)
-      add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1)
-      ROOT t = (f32[], f32[]) tuple(add.0, add.1)
-    }
-    fused_computation {
-      param_0 = f32[2, 3, 2048] parameter(0)
-      param_1 = f32[2, 3, 2048] parameter(1)
-      param_2 = f32[] parameter(2)
-      ROOT d.1 = (f32[2, 3], f32[2, 3])
-        reduce(param_0, param_1, param_2, param_2), dimensions={2}, to_apply=Add
-    }
-    ENTRY main {
-      a = f32[2, 3, 2048] parameter(0)
-      b = f32[2, 3, 2048] parameter(1)
-      c = f32[] constant(0)
-      ROOT fusion = (f32[2, 3], f32[2, 3]) fusion(a, b, c),
-        kind=kInput, calls=fused_computation
-    })";
-
-constexpr std::string_view kF64RowReduction = R"(
-    Add {
-      lhs = f64[] parameter(0)
-      rhs = f64[] parameter(1)
-      ROOT add = f64[] add(lhs, rhs)
-    }
-    fused_computation {
-      param_0 = f64[100,128] parameter(0)
-      param_1 = f64[] parameter(1)
-      ROOT reduce = f64[100] reduce(param_0, param_1), dimensions={1}, to_apply=Add
-    }
-    ENTRY main {
-      a = f64[100,128] parameter(0)
-      c = f64[] constant(0)
-      ROOT fusion = f64[100] fusion(a, c), kind=kInput, calls=fused_computation
-    })";
-
-constexpr auto kRowReductionMinorAndMajor = R"(
-    Add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-    fused_computation {
-      param_0 = f32[7,100,128] parameter(0)
-      param_1 = f32[] parameter(1)
-      ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={0,2}, to_apply=Add
-    }
-    ENTRY main {
-      a = f32[7,100,128] parameter(0)
-      c = f32[] constant(0)
-      ROOT fusion = f32[100] fusion(a, c), kind=kInput, calls=fused_computation
-    })";
-
 constexpr auto kMultiRowReductionX8 = R"(
     Add {
       lhs = f32[] parameter(0)
@@ -179,181 +118,6 @@
       ROOT fusion = pred[76800] fusion(p0), kind=kInput, calls=fusion
     })";
 
-constexpr std::string_view kRowReductionSideOutput = R"(
-    Add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-    fused_computation {
-      param_0 = f32[8,2048] parameter(0)
-      param_1 = f32[] parameter(1)
-      exp = f32[8,2048] exponential(param_0)
-      reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add
-      ROOT t = (f32[8], f32[8,2048]) tuple(reduce, exp)
-    }
-    ENTRY main {
-      a = f32[8,2048] parameter(0)
-      c = f32[] constant(0)
-      ROOT fusion = (f32[8], f32[8,2048]) fusion(a, c), kind=kInput,
-          calls=fused_computation
-    })";
-
-TEST_F(MlirRowReductionTest, VariadicRowReductionIndexing) {
-  auto fusion = GetEmitter(kVariadicRowReduction);
-  TF_EXPECT_OK(TestBijection(
-      *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_),
-      {2, 3, 2048}));
-  TF_EXPECT_OK(TestBijection(
-      *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {2, 3}));
-}
-
-TEST_F(MlirRowReductionTest, VariadicRowReductionCorrectness) {
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kVariadicRowReduction, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, RowReduceEpilogue) {
-  constexpr auto kHloString = R"(
-    HloModule Test, is_scheduled=true
-
-    Add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-    fused_computation {
-      param_0 = f32[8,2048] parameter(0)
-      param_1 = f32[] parameter(1)
-      reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add
-      ROOT log = f32[8] log(reduce)
-    }
-    ENTRY main {
-      a = f32[8,2048] parameter(0)
-      c = f32[] constant(0)
-      ROOT fusion = f32[8] fusion(a, c), kind=kInput, calls=fused_computation
-    })";
-  TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK: pure_call @Add_add
-    // CHECK: shuffle_reduce
-    // CHECK: allocate_shared
-    // CHECK: sync_threads
-    // CHECK: shuffle_reduce
-  )"));
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, RowReduceMOFEpilogue) {
-  constexpr auto kHloString = R"(
-    HloModule Test, is_scheduled=true
-
-    Add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-    Mul {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT mul = f32[] multiply(lhs, rhs)
-    }
-    fused_computation {
-      param_0 = f32[8,1024] parameter(0)
-      param_1 = f32[] parameter(1)
-      reduce1 = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add
-      reduce2 = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Mul
-      log = f32[8] log(reduce1)
-      abs = f32[8] abs(reduce1)
-      neg = f32[8] negate(reduce2)
-      ROOT tuple = (f32[8], f32[8], f32[8]) tuple(log, neg, abs)
-    }
-    ENTRY main {
-      a = f32[8,1024] parameter(0)
-      c = f32[] constant(0)
-      ROOT fusion = (f32[8], f32[8], f32[8]) fusion(a, c), kind=kInput,
-        calls=fused_computation
-    })";
-  TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK-DAG: pure_call @Add_add
-    // CHECK-DAG: shuffle_reduce @Add_add
-    // CHECK-DAG: pure_call @Mul_mul
-    // CHECK-DAG: shuffle_reduce @Mul_mul
-    // CHECK: allocate_shared
-    // CHECK: allocate_shared
-    // CHECK: sync_threads
-    // CHECK-DAG: shuffle_reduce @Add_add
-    // CHECK-DAG: shuffle_reduce @Mul_mul
-  )"));
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, RowReduceMOFGroups) {
-  constexpr auto kHloString = R"(
-    %add_f32 {
-      %x = f32[] parameter(0)
-      %y = f32[] parameter(1)
-      ROOT %add = f32[] add(%x, %y)
-    }
-
-    %fused_computation {
-      %param0 = f32[1024] parameter(0)
-      %param1 = f32[1024] parameter(1)
-      %constant0 = f32[] constant(0)
-      %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32
-      %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32
-      ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2)
-    }
-
-    ENTRY %cluster {
-      %param0 = f32[1024] parameter(0)
-      %param1 = f32[1024] parameter(1)
-      ROOT %fusion = (f32[], f32[])
-          fusion(%param0, %param1), kind=kInput, calls=%fused_computation
-    })";
-  TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK: scf.index_switch %block_id_y
-    // CHECK: case 1 {
-    // CHECK: default {
-  )"));
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, F64RowReductionIndexing) {
-  auto fusion = GetEmitter(kF64RowReduction);
-  TF_EXPECT_OK(TestBijection(
-      *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_),
-      /*shape=*/{100, 128}));
-  TF_EXPECT_OK(
-      TestBijection(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_),
-                    /*shape=*/{100}));
-}
-
-TEST_F(MlirRowReductionTest, F64RowReductionIr) {
-  // This reduction is small enough not to require shared memory.
-  TF_ASSERT_OK(EmitAndCheckIR(kF64RowReduction, R"(
-    // CHECK-NOT: allocate_shared
-  )"));
-}
-
-TEST_F(MlirRowReductionTest, F64RowReductionCorrectness) {
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kF64RowReduction, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, RowReductionMinorAndMajorIndexing) {
-  auto fusion = GetEmitter(kRowReductionMinorAndMajor);
-
-  TF_EXPECT_OK(TestBijection(
-      *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_),
-      /*shape=*/{7, 100, 128}));
-  TF_EXPECT_OK(
-      TestBijection(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_),
-                    /*shape=*/{100}));
-}
-
-TEST_F(MlirRowReductionTest, RowReductionMinorAndMajorCorrectness) {
-  EXPECT_TRUE(
-      RunAndCompareNoHloPasses(kRowReductionMinorAndMajor, ErrorSpec{1e-3}));
-}
-
 TEST_F(MlirMultiRowReductionTest, MultiRowReductionIndexing) {
   auto fusion = GetEmitter(kMultiRowReductionX8);
 
@@ -379,207 +143,6 @@
   EXPECT_TRUE(RunAndCompareNoHloPasses(kMultiRowReductionX8, ErrorSpec{1e-3}));
 }
 
-TEST_F(MlirRowReductionTest, NonPowerOfTwoRowReduction) {
-  constexpr auto kHloString = R"(
-    HloModule Test, is_scheduled=true
-
-    Add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-    fused_computation {
-      param_0 = f32[100,568] parameter(0)
-      param_1 = f32[] parameter(1)
-      ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={1}, to_apply=Add
-    }
-    ENTRY main {
-      a = f32[100,568] parameter(0)
-      c = f32[] constant(0)
-      ROOT fusion = f32[100] fusion(a, c), kind=kInput, calls=fused_computation
-    })";
-  TF_EXPECT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1)[s0] -> ((d1 mod 64) * 2 + s0 * 128 + d0)>
-    // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> ((d1 mod 64) * 2 + d0 + 512)>
-    // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-    // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-    // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-    // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-    // CHECK: %[[FULL_TILES:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]]
-    // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-    // CHECK-NOT: scf.if
-    // CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]] in [0, 1], %thread_id_x in [0, 255])[%[[I]] in [0, 3]]
-    // CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%{{.*}} = %[[FULL_TILES]])
-    // CHECK: scf.if
-    // CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]] in [0, 1], %thread_id_x in [0, 255])
-  )"));
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirMultiRowReductionTest, NonTrivialEpilogueCorrectness) {
-  constexpr auto kHloString = R"(
-    HloModule module
-    add {
-      p0 = f64[] parameter(0)
-      p1 = f64[] parameter(1)
-      ROOT add = f64[] add(p0, p1)
-    }
-    fusion {
-      %p0 = f64[4] parameter(0)
-      %p1 = f64[4] parameter(1)
-      %c0 = f64[] constant(-inf)
-      %reduce0 = f64[] reduce(p1, c0), dimensions={0}, to_apply=add
-      %bc0 = f64[4] broadcast(reduce0), dimensions={}
-      %compare0 = pred[4] compare(p1, bc0), direction=EQ
-      %c1 = f64[] constant(0)
-      %bc1 = f64[4] broadcast(c1), dimensions={}
-      %select.3.1 = f64[4] select(compare0, p0, bc1)
-      %reduce1 = f64[] reduce(select.3.1, c1), dimensions={0}, to_apply=add
-      %convert0 = f64[4] convert(compare0)
-      %reduce2 = f64[] reduce(convert0, c1), dimensions={0}, to_apply=add
-      ROOT %tuple.1 = (f64[], f64[], f64[]) tuple(%reduce1, reduce0, reduce2)
-    }
-    ENTRY main {
-      %p0 = f64[4] parameter(0)
-      %p1 = f64[4] parameter(1)
-      ROOT %fusion = (f64[], f64[], f64[]) fusion(%p0, %p1), kind=kInput,
-        calls=fusion
-    })";
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, SideOutputIndexing) {
-  auto fusion = GetEmitter(kRowReductionSideOutput);
-  TF_EXPECT_OK(TestBijection(
-      *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_),
-      {8, 2048}));
-  TF_EXPECT_OK(TestBijection(
-      *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {8}));
-  TF_EXPECT_OK(
-      TestBijection(*fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context_),
-                    {8, 2048}));  // Side output.
-}
-
-TEST_F(MlirRowReductionTest, SideOutputIr) {
-  TF_ASSERT_OK(EmitAndCheckIR(kRowReductionSideOutput, R"(
-    // CHECK: @fused_computation
-    // CHECK: scf.for
-    // CHECK: scf.for
-    // CHECK: %[[SIDE_OUTPUT:.*]] = xla_gpu.pure_call @fused_computation_exp
-    // CHECK-NEXT: tensor.insert %[[SIDE_OUTPUT]]
-  )"));
-}
-
-TEST_F(MlirRowReductionTest, SideOutputCorrectness) {
-  EXPECT_TRUE(
-      RunAndCompareNoHloPasses(kRowReductionSideOutput, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, UnsignedSideOutputCorrectness) {
-  constexpr auto kHloString = R"(
-    HloModule Test, is_scheduled=true
-
-    Add {
-      lhs = u32[] parameter(0)
-      rhs = u32[] parameter(1)
-      ROOT add = u32[] add(lhs, rhs)
-    }
-    fused_computation {
-      param_0 = u32[8,2048] parameter(0)
-      param_1 = u32[] parameter(1)
-      add = u32[8,2048] add(param_0, param_0)
-      reduce = u32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add
-      ROOT t = (u32[8], u32[8,2048]) tuple(reduce, add)
-    }
-    ENTRY main {
-      a = u32[8,2048] parameter(0)
-      c = u32[] constant(0)
-      ROOT fusion = (u32[8], u32[8,2048]) fusion(a, c), kind=kInput,
-          calls=fused_computation
-    })";
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, BroadcastSideOutputCorrectness) {
-  constexpr auto kHloString = R"(
-    %add {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT add = f32[] add(p0, p1)
-    }
-    %fusion {
-      %p0 = f32[6,6] parameter(0)
-      %c0 = f32[] constant(0)
-      %reduce = f32[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add
-      %broadcast = f32[6,6] broadcast(%reduce), dimensions={}
-      ROOT %tuple = (f32[6,6], f32[]) tuple(%broadcast, %reduce)
-    }
-    ENTRY main {
-      %p0 = f32[6,6] parameter(0)
-      ROOT %fusion = (f32[6,6], f32[]) fusion(%p0), kind=kInput, calls=%fusion
-    })";
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, VariadicMOFCorrectness) {
-  constexpr auto kHloString = R"(
-    %reducer1 {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT add = f32[] add(p0, p1)
-    }
-    %reducer2 {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      p2 = f32[] parameter(2)
-      p3 = f32[] parameter(3)
-      add0 = f32[] add(p0, p2)
-      add1 = f32[] add(p1, p3)
-      ROOT tuple = (f32[], f32[]) tuple(add0, add1)
-    }
-    %fusion {
-      %p0 = f32[6,6] parameter(0)
-      %c0 = f32[] constant(0)
-      %neg = f32[6,6] negate(%p0)
-      %reduce1 = f32[] reduce(%neg, %c0), dimensions={0,1}, to_apply=%reducer1
-      %reduce2 = (f32[], f32[]) reduce(%p0, %p0, %c0, %c0), dimensions={0,1}, to_apply=%reducer2
-      ROOT %tuple = (f32[], (f32[], f32[]), f32[6,6]) tuple(%reduce1, %reduce2, %neg)
-    }
-    ENTRY main {
-      %p0 = f32[6,6] parameter(0)
-      ROOT %fusion = (f32[], (f32[], f32[]), f32[6,6]) fusion(%p0), kind=kInput, calls=%fusion
-    })";
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, OutputLayoutCorrectness) {
-  constexpr std::string_view kHloString = R"(
-    add {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT add = f32[] add(p0, p1)
-    }
-
-    fusion {
-      %input = f32[17,19,127] parameter(0)
-      %c0 = f32[] constant(0)
-      ROOT reduce = f32[17,19]{0,1} reduce(%input, %c0), dimensions={2}, to_apply=add
-    }
-
-    ENTRY entry {
-      %input = f32[17,19,127] parameter(0)
-      ROOT %fusion = f32[17,19]{0,1} fusion(%input), kind=kInput, calls=fusion
-    })";
-
-  auto fusion = GetEmitter(kHloString);
-  TF_EXPECT_OK(TestBijection(
-      *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_),
-      {17, 19, 127}));
-  TF_EXPECT_OK(TestBijection(
-      *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {17, 19}));
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
 TEST_F(MlirMultiRowReductionTest, TwoGroups) {
   auto module = ParseAndReturnVerifiedModule(R"(
     add {
@@ -604,7 +167,7 @@
                     .value();
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
   MlirMultiRowReductionFusion fusion(analysis);
 
   EXPECT_THAT(fusion.GetGroups().grouped_roots,
@@ -635,231 +198,12 @@
                     .value();
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
 
   MlirMultiRowReductionFusion mlir_fusion(analysis);
   EXPECT_THAT(mlir_fusion.GetGroups().grouped_roots, SizeIs(1));
 }
 
-constexpr absl::string_view kColumnVectorizationTemplate = R"(
-    add {
-      b = $0[] parameter(1)
-      a = $0[] parameter(0)
-      ROOT out = $0[] add(a, b)
-    }
-    fusion {
-      %p0 = $0[192,64,1536] parameter(0)
-      %p1 = $0[] parameter(1)
-      ROOT reduce = $0[192,1536] reduce(p0, p1), dimensions={1}, to_apply=add
-    }
-    ENTRY entry {
-      %p0 = $0[192,64,1536] parameter(0)
-      %p1 = $0[] parameter(1)
-      ROOT %fusion = $0[192,1536] fusion(p0, p1), kind=kInput, calls=fusion
-    })";
-
-TEST_F(MlirColumnReductionTest, ColumnReduction) {
-  constexpr auto kHloString = R"(
-    HloModule Test, is_scheduled=true
-
-    Add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-    fused_computation {
-      param_0 = f32[13,1051,321] parameter(0)
-      param_1 = f32[] parameter(1)
-      ROOT reduce = f32[13,321] reduce(param_0, param_1), dimensions={1}, to_apply=Add
-    }
-    ENTRY main {
-      a = f32[13,1051,321] parameter(0)
-      c = f32[] constant(0)
-      ROOT fusion = f32[13,321] fusion(a, c), kind=kInput, calls=fused_computation
-    })";
-
-  auto module = ParseAndReturnVerifiedModule(kHloString).value();
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
-  MlirColumnReductionFusion fusion(analysis);
-  EXPECT_THAT(
-      fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(),
-      MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5)[s0, s1] -> (
-          d3 floordiv 11,
-          d0 floordiv 32 + s0 * 32,
-          (d3 mod 11) * 32 + d0 mod 32
-        )
-        domain:
-        d0 in [0, 1023]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 142]
-        d4 in [0, 0]
-        d5 in [0, 0]
-        s0 in [0, 32]
-        s1 in [0, 0]
-        (d3 mod 11) * 32 + d0 mod 32 in [0, 320]
-        d0 floordiv 32 + s0 * 32 in [0, 1050]
-      )"));
-  EXPECT_THAT(
-      fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(),
-      MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5)[s0] -> (
-          d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32
-        )
-        domain:
-        d0 in [0, 992]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 142]
-        d4 in [0, 0]
-        d5 in [0, 0]
-        s0 in [0, 0]
-        (d3 mod 11) * 32 + d0 floordiv 32 in [0, 320]
-        d0 mod 32 in [0, 0]
-      )"));
-  TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK: xla_gpu.pure_call @Add_add
-    // CHECK: allocate_shared
-    // CHECK: tensor.insert
-    // CHECK: sync_threads
-    // CHECK: predicated_extract
-    // CHECK: shuffle_reduce
-    // CHECK: predicated_insert
-  )"));
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirColumnReductionTest, SmallColumnReduction) {
-  constexpr auto kHloString = R"(
-    HloModule Test, is_scheduled=true
-
-    Add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-    fused_computation {
-      param_0 = f32[3,128,4] parameter(0)
-      param_1 = f32[] parameter(1)
-      ROOT reduce = f32[3,4] reduce(param_0, param_1), dimensions={1}, to_apply=Add
-    }
-    ENTRY main {
-      a = f32[3,128,4] parameter(0)
-      c = f32[] constant(0)
-      ROOT fusion = f32[3,4] fusion(a, c), kind=kInput, calls=fused_computation
-    })";
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirColumnReductionTest, MixedIndexing) {
-  constexpr auto kHloString = R"(
-    HloModule module
-    add {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT add = f32[] add(p0, p1)
-    }
-    fusion {
-      %param_0 = f32[64,128] parameter(0)
-      %constant_0 = f32[] constant(0)
-      %reduce.1 = f32[128] reduce(f32[64,128] %param_0, f32[] %constant_0), dimensions={0}, to_apply=%add
-      %neg = f32[64,128] negate(f32[64,128] %param_0)
-      %bitcast = f32[8,8,128]{2,1,0} bitcast(f32[64,128] %neg)
-      %reduce.2 = f32[128] reduce(f32[8,8,128]{2,1,0} %bitcast, f32[] %constant_0), dimensions={0,1}, to_apply=%add
-      ROOT %tuple.12 = (f32[128], f32[128]) tuple(f32[128] %reduce.1, f32[128] %reduce.2)
-    }
-    ENTRY entry {
-      %param_0 = f32[64,128] parameter(0)
-      ROOT %fusion = (f32[128], f32[128]) fusion(%param_0), kind=kInput, calls=fusion
-    })";
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirColumnReductionTest, ColumnReductionVectorizationCorrectness) {
-  constexpr auto kHloString = R"(
-    HloModule Test, is_scheduled=true
-    Add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-    fused_computation {
-      param_0 = f32[2048,16384] parameter(0)
-      param_1 = f32[] parameter(1)
-      ROOT reduce = f32[16384] reduce(param_0, param_1), dimensions={0}, to_apply=Add
-    }
-    ENTRY main {
-      a = f32[2048,16384] parameter(0)
-      c = f32[] constant(0)
-      ROOT fusion = f32[16384] fusion(a, c), kind=kInput, calls=fused_computation
-    })";
-  TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK: vector<2xf32>
-  )"));
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirColumnReductionTest, ColumnReductionVectorization_v4) {
-  constexpr auto kHloString = R"(
-    HloModule Test, is_scheduled=true
-    Add {
-      lhs = s16[] parameter(0)
-      rhs = s16[] parameter(1)
-      ROOT add = s16[] add(lhs, rhs)
-    }
-    fused_computation {
-      param_0 = s16[2048,16384] parameter(0)
-      param_1 = s16[] parameter(1)
-      ROOT reduce = s16[16384] reduce(param_0, param_1), dimensions={0}, to_apply=Add
-    }
-    ENTRY main {
-      a = s16[2048,16384] parameter(0)
-      c = s16[] constant(0)
-      ROOT fusion = s16[16384] fusion(a, c), kind=kInput, calls=fused_computation
-    })";
-  TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK: vector<4xi16>
-  )"));
-  // We don't use RunAndCompareNoHloPasses because the interpreter is too slow
-  // for this input.
-}
-
-TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v2) {
-  const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "f32");
-  auto fusion = GetEmitter(hlo_string);
-  EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing(
-                  0, 0, &mlir_context_)),
-              ElementsAre(2 /* major reduced */, 2 /* vector size */));
-}
-
-TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v4) {
-  const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "f16");
-  auto fusion = GetEmitter(hlo_string);
-  EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing(
-                  0, 0, &mlir_context_)),
-              ElementsAre(2 /* major reduced */, 4 /* vector size */));
-}
-
-TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_f64) {
-  // Verifies that we do not use the vectorized indexing for f64.
-  const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "f64");
-  auto fusion = GetEmitter(hlo_string);
-  EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing(
-                  0, 0, &mlir_context_)),
-              ElementsAre(2 /* major reduced */, 1 /* vector size */));
-}
-
-TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_Complex) {
-  // Verifies that we do not use the vectorized indexing for complex types.
-  const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "c64");
-  auto fusion = GetEmitter(hlo_string);
-  EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing(
-                  0, 0, &mlir_context_)),
-              ElementsAre(2 /* major reduced */, 1 /* vector size */));
-}
-
 TEST_F(MlirMultiRowReductionTest, VectorizedX4Indexing) {
   auto fusion = GetEmitter(kMultiRowReductionX2VectorX4);
 
@@ -883,61 +227,6 @@
       RunAndCompareNoHloPasses(kMultiRowReductionX2VectorX4, ErrorSpec{1e-3}));
 }
 
-TEST_F(MlirRowReductionTest, LargeToUnit) {
-  // Regression test for a bug where not all threads in the warp produced a
-  // valid value for the final warp shuffle.
-  constexpr auto kHloString = R"(
-    and {
-      p0 = pred[] parameter(0)
-      p1 = pred[] parameter(1)
-      ROOT and = pred[] and(p0, p1)
-    }
-
-    %fused_reduce {
-      c1 = pred[] constant(true)
-      p0 = pred[10000] broadcast(c1), dimensions={}
-      ROOT reduce = pred[] reduce(p0, c1), dimensions={0}, to_apply=and
-    }
-  )";
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
-TEST_F(MlirRowReductionTest, MOFTwoVariadic) {
-  // Regression test for a compilation crash with a MOF with two variadic
-  // reductions.
-  constexpr auto kHloString = R"(
-    add {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      p2 = f32[] parameter(2)
-      p3 = f32[] parameter(3)
-      a = f32[] add(p0, p2)
-      b = f32[] add(p1, p3)
-      ROOT out = (f32[], f32[]) tuple(a, b)
-    }
-
-    fused_reduce {
-      p0 = f32[3,2] parameter(0)
-      p1 = f32[3,2] parameter(1)
-      c0 = f32[] constant(0)
-      iota0 = f32[3,2] iota(), iota_dimension=1
-      iota1 = f32[3,2] iota(), iota_dimension=1
-      reduce0 = (f32[3], f32[3]) reduce(p0, iota0, c0, c0), dimensions={1},
-          to_apply=add
-      reduce1 = (f32[3], f32[3]) reduce(p1, iota1, c0, c0), dimensions={1},
-          to_apply=add
-      ROOT tuple = ((f32[3], f32[3]), (f32[3], f32[3])) tuple(reduce0, %reduce1)
-    }
-
-    ENTRY main {
-      p0 = f32[3,2] parameter(0)
-      p1 = f32[3,2] parameter(1)
-      ROOT fusion = ((f32[3], f32[3]), (f32[3], f32[3])) fusion(p0, p1),
-        kind=kInput, calls=fused_reduce
-    })";
-  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
-}
-
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_test.cc
deleted file mode 100644
index 81649a7..0000000
--- a/third_party/xla/xla/service/gpu/fusions/reduction_test.cc
+++ /dev/null
@@ -1,176 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/fusions/reduction.h"
-
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::testing::ElementsAre;
-using ::testing::SizeIs;
-
-class ReductionTest : public HloTestBase {
- protected:
-  stream_executor::DeviceDescription device_info_ =
-      TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  mlir::MLIRContext mlir_context_;
-};
-
-TEST_F(ReductionTest, ThreadIndexingRowReduction) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    add {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT add = f32[] add(p0, p1)
-    }
-
-    fusion {
-      %input = f32[100,64,512] parameter(0)
-      %c0 = f32[] constant(0)
-      ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add
-    }
-
-    ENTRY entry {
-      %input = f32[100,64,512] parameter(0)
-      ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion
-    })")
-                    .value();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
-  ReductionFusion fusion(analysis);
-
-  EXPECT_THAT(
-      fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(),
-      MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> (
-          d3 floordiv 8,
-          (d3 mod 8) * 8 + d0 floordiv 32,
-          (d0 mod 32) * 2 + s2 * 64 + s3
-        )
-        domain:
-        d0 in [0, 255]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 799]
-        d4 in [0, 0]
-        d5 in [0, 0]
-        s0 in [0, 0]
-        s1 in [0, 0]
-        s2 in [0, 7]
-        s3 in [0, 1]
-      )"));
-  EXPECT_THAT(
-      fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(),
-      MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5) -> (
-          d3 floordiv 8,
-          (d3 mod 8) * 8 + d0 floordiv 32
-        )
-        domain:
-        d0 in [0, 224]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 799]
-        d4 in [0, 0]
-        d5 in [0, 0]
-        d0 mod 32 in [0, 0]
-      )"));
-}
-
-TEST_F(ReductionTest, TwoGroups) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    add {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT add = f32[] add(p0, p1)
-    }
-    fusion {
-      %p0 = f32[2] parameter(0)
-      %p1 = f32[2] parameter(1)
-      %c0 = f32[] constant(-inf)
-      %r0 = f32[] reduce(%p0, %c0), dimensions={0}, to_apply=add
-      %c1 = f32[] constant(inf)
-      %r1 = f32[] reduce(%p1, %c1), dimensions={0}, to_apply=add
-      ROOT %tuple = (f32[], f32[]) tuple(%r0, %r1)
-    }
-    ENTRY entry {
-      %p0 = f32[2] parameter(0)
-      %p1 = f32[2] parameter(1)
-      ROOT %fusion = (f32[], f32[]) fusion(%p0, %p1), kind=kInput, calls=fusion
-    })")
-                    .value();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
-  ReductionFusion fusion(analysis);
-
-  EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots,
-              ElementsAre(ElementsAre(&analysis.fusion_root(0).instruction()),
-                          ElementsAre(&analysis.fusion_root(1).instruction())));
-}
-
-TEST_F(ReductionTest, OneGroup) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    %add {
-      %p0 = c128[] parameter(0)
-      %p1 = c128[] parameter(1)
-      ROOT %add.35 = c128[] add(c128[] %p0, c128[] %p1)
-    }
-    %fusion {
-      %p0 = c128[1,2] parameter(0)
-      %c0 = c128[] constant((0, 0))
-      %reduce = c128[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add
-      %real = f64[] real(c128[] %reduce)
-      %imag = f64[] imag(c128[] %reduce)
-      %negate = f64[] negate(f64[] %imag)
-      ROOT %tuple.29 = (f64[], f64[]) tuple(f64[] %real, f64[] %negate)
-    }
-    ENTRY entry {
-      %p0 = c128[1,2] parameter(0)
-      ROOT %fusion = (f64[], f64[]) fusion(%p0), kind=kInput, calls=fusion
-    })")
-                    .value();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
-  ReductionFusion fusion(analysis);
-
-  EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots, SizeIs(2));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter.cc b/third_party/xla/xla/service/gpu/fusions/scatter.cc
deleted file mode 100644
index 8f7f773..0000000
--- a/third_party/xla/xla/service/gpu/fusions/scatter.cc
+++ /dev/null
@@ -1,293 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/scatter.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <optional>
-#include <string>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/types/span.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Value.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/fusions/loop.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/ir_emitter_nested.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/gpu/parallel_loop_emitter.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/service/llvm_ir/llvm_util.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-ScatterFusion::ScatterFusion(const HloFusionAnalysis& analysis)
-    : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {
-  CHECK_EQ(analysis.fusion_root_count(), 1);
-  CHECK_EQ(analysis.fusion_root(0).opcode(), HloOpcode::kScatter);
-}
-
-LaunchDimensions ScatterFusion::launch_dimensions() const {
-  const auto& updates_shape =
-      analysis_.fusion_root(0).instruction().operands().back()->shape();
-  return CalculateLaunchDimensions(updates_shape, analysis_.device_info());
-}
-
-absl::Status ScatterFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
-                                       const HloFusionInstruction& fusion,
-                                       const LaunchDimensions& launch_dims,
-                                       std::vector<llvm_ir::IrArray> inputs,
-                                       std::vector<llvm_ir::IrArray> outputs,
-                                       llvm::IRBuilder<>* builder) const {
-  GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
-  // Spin up a new fused emitter for the scatter kernel and emit it.
-  FusedIrEmitter scatter_fused_emitter(elemental_emitter);
-  auto* fused_computation = fusion.fused_instructions_computation();
-  for (int i = 0; i < fused_computation->num_parameters(); i++) {
-    auto fused_operand = fused_computation->parameter_instruction(i);
-    scatter_fused_emitter.BindGenerator(
-        *fused_operand, [builder, &input = inputs[i],
-                         fused_operand](llvm_ir::IrArray::Index index) {
-          return input.EmitReadArrayElement(index, builder,
-                                            fused_operand->name());
-        });
-  }
-
-  auto* root = fused_computation->root_instruction();
-  const xla::ScatterDimensionNumbers& scatter_dims =
-      Cast<HloScatterInstruction>(root)->scatter_dimension_numbers();
-
-  std::string name = llvm_ir::IrName(root);
-  const Shape& operand_shape = root->operand(0)->shape();
-  const Shape& scatter_indices_shape = root->operand(1)->shape();
-  const Shape& updates_shape = root->operand(2)->shape();
-  const HloComputation& update_computation = *root->called_computations()[0];
-
-  TF_ASSIGN_OR_RETURN(auto scatter_indices_gen,
-                      scatter_fused_emitter.GetGenerator(*root->operand(1)));
-  TF_ASSIGN_OR_RETURN(auto updates_gen,
-                      scatter_fused_emitter.GetGenerator(*root->operand(2)));
-
-  auto loop_body_emitter =
-      [&](const llvm_ir::IrArray::Index& index) -> absl::Status {
-    std::vector<llvm::Value*> raw_window_multidim;
-    std::vector<llvm::Value*> input_scatter_multidim;
-    std::vector<int64_t> raw_window_bounds;
-
-    auto get_i64_array = [](absl::Span<const int64_t> container) {
-      return llvm::ArrayRef<int64_t>{container.data(),
-                                     static_cast<size_t>(container.size())};
-    };
-
-    llvm::ArrayRef<int64_t> update_window_dims =
-        get_i64_array(scatter_dims.update_window_dims());
-    // Partition the index into window indices and scatter indices.
-    for (int64_t i = 0, e = index.size(); i != e; ++i) {
-      // For window indices also remember the window size, this comes in handy
-      // later.
-      if (llvm::is_contained(update_window_dims, i)) {
-        raw_window_multidim.push_back(index[i]);
-        raw_window_bounds.push_back(updates_shape.dimensions(i));
-      } else {
-        input_scatter_multidim.push_back(index[i]);
-      }
-    }
-    DCHECK_EQ(raw_window_multidim.size(),
-              scatter_dims.update_window_dims_size());
-
-    // Apply inserted_window_dims to the window dimensions.
-    int64_t raw_window_multidim_idx = 0;
-    llvm::SmallVector<llvm::Value*> input_window_multidim;
-    llvm::SmallVector<int64_t> input_window_bounds;
-    const int64_t rank = operand_shape.rank();
-    input_window_bounds.reserve(rank);
-    input_window_multidim.reserve(rank);
-
-    llvm::ArrayRef<int64_t> inserted_window_dims =
-        get_i64_array(scatter_dims.inserted_window_dims());
-    for (int64_t i = 0; i != rank; ++i) {
-      if (llvm::is_contained(inserted_window_dims, i)) {
-        input_window_bounds.push_back(1);  // Trivial dimension.
-        input_window_multidim.push_back(index.GetConstantWithIndexType(0));
-      } else {
-        input_window_bounds.push_back(
-            raw_window_bounds[raw_window_multidim_idx]);
-        input_window_multidim.push_back(
-            raw_window_multidim[raw_window_multidim_idx]);
-        ++raw_window_multidim_idx;
-      }
-    }
-    DCHECK_EQ(input_window_multidim.size(), operand_shape.rank());
-
-    // Insert a 1 dimension at the end if index_vector_dim requests one.
-    Shape scatter_indices_shape_fixed = scatter_indices_shape;
-    if (scatter_dims.index_vector_dim() == scatter_indices_shape.rank()) {
-      scatter_indices_shape_fixed.add_dimensions(1);
-      scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
-          scatter_dims.index_vector_dim());
-    }
-
-    // Now load the indices corresponding to the current window from
-    // scatter_indices.
-    std::vector<llvm::Value*> raw_scatter_index_multidim =
-        input_scatter_multidim;
-    raw_scatter_index_multidim.insert(
-        raw_scatter_index_multidim.begin() + scatter_dims.index_vector_dim(),
-        nullptr);
-
-    llvm::ArrayRef<int64_t> scatter_dims_to_operand_dims =
-        get_i64_array(scatter_dims.scatter_dims_to_operand_dims());
-    llvm::Value* is_in_bounds = builder->getTrue();
-    for (int64_t i = 0, e = scatter_dims_to_operand_dims.size(); i != e; ++i) {
-      // Our index is stored along index_vector_dim, insert that into the lookup
-      // index into scatter_indices.
-      raw_scatter_index_multidim[scatter_dims.index_vector_dim()] =
-          index.GetConstantWithIndexType(i);
-      llvm_ir::IrArray::Index raw_scatter_index_index(
-          raw_scatter_index_multidim, scatter_indices_shape_fixed,
-          index.GetType());
-
-      int64_t operand_dim = scatter_dims_to_operand_dims[i];
-      if (operand_dim > rank) {
-        return absl::OutOfRangeError(
-            "The provided scatter_dims_to_operand_dims was out of range.");
-      }
-      TF_ASSIGN_OR_RETURN(
-          llvm::Value* const loaded_scatter_index,
-          scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
-              scatter_indices_shape_fixed, scatter_indices_shape, builder)));
-      // And add the index to our window index. This yields the output index.
-      llvm::Value* casted_scatter_index = builder->CreateIntCast(
-          loaded_scatter_index, index.GetType(),
-          /*isSigned=*/ShapeUtil::ElementIsSigned(scatter_indices_shape));
-      llvm::Value* dim_offset = builder->CreateAdd(
-          input_window_multidim[operand_dim], casted_scatter_index);
-      input_window_multidim[operand_dim] = dim_offset;
-
-      // Also do the bounds check now.
-      int64_t max_index = operand_shape.dimensions(operand_dim) -
-                          input_window_bounds[operand_dim] + 1;
-      // is_in_bounds = index >= 0 && index < dim_size-window_size+1
-      //   --> index u< dim_size-window_size+1
-      is_in_bounds = builder->CreateAnd(
-          is_in_bounds,
-          builder->CreateICmpULT(casted_scatter_index,
-                                 index.GetConstantWithIndexType(max_index)));
-    }
-
-    llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
-        is_in_bounds, "scatter.in_bounds", builder, /*emit_else=*/false);
-    llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block,
-                                   builder);
-    // All done, now just read from the calculated input from the window, and do
-    // an atomic store to the calculated location in the output.
-    llvm_ir::IrArray::Index input_window_index(
-        input_window_multidim, outputs.back().GetShape(), index.GetType());
-    llvm::Value* output_address =
-        outputs.back().EmitArrayElementAddress(input_window_index, builder);
-    llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
-        llvm_ir::PrimitiveTypeToIrType(updates_shape.element_type(),
-                                       ir_emitter_context.llvm_module()),
-        "input_address", builder);
-    TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
-    builder->CreateStore(input_ir_value, input_address);
-
-    if (root->unique_indices()) {
-      return CallNestedComputation(
-          builder, ir_emitter_context, update_computation,
-          {output_address, input_address}, output_address);
-    }
-    return EmitAtomicOperationForNestedComputation(
-        builder, ir_emitter_context, update_computation, output_address,
-        input_address, outputs.back().GetElementLlvmType());
-  };
-
-  // Launch a kernel that reads every element in the updates tensor. We could
-  // also do one kernel per window instead if bounds checks turn out to be a
-  // bottleneck.
-  auto index_type =
-      GetIndexTypeForKernel(root, launch_dims.launch_bound(), builder);
-  return ParallelLoopEmitter(loop_body_emitter, updates_shape, launch_dims,
-                             builder)
-      .EmitLoop(name, index_type);
-}
-
-std::optional<IndexingMap> ScatterFusion::ComputeThreadIdToInputIndexing(
-    int64_t root_index, int64_t hero_operand_index,
-    mlir::MLIRContext* ctx) const {
-  const auto* scatter =
-      DynCast<HloScatterInstruction>(&analysis_.fusion_hero(0).instruction());
-  int64_t scatter_operand_count = scatter->scatter_operand_count();
-  // Scatter operands a packed in the following way:
-  // Operand IDs [0, scatter_operand_count - 1] for `scatter operands`.
-  // Operand ID  scatter_operand_count for `scatter indices`.
-  // Operand IDs [scatter_operand_count + 1, 2 * scatter_operand_count] for
-  // `scatter updates`.
-
-  // For scatter operands we do not know the thread ID indexing.
-  if (hero_operand_index < scatter_operand_count) {
-    return std::nullopt;
-  }
-  // Compute thread id mapping based on the first update operand.
-  Shape scatter_update_shape = scatter->scatter_updates().front()->shape();
-  IndexingMap scatter_update_map = GetDefaultThreadIdIndexingMap(
-      launch_dimensions(), config_.unroll_factor, scatter_update_shape, ctx);
-
-  // For scatter indices we project indexing for scatter updates and take the
-  // first result of the affine map only, because they coincide.
-  if (hero_operand_index == scatter_operand_count) {
-    Shape scatter_indices_shape = scatter->scatter_indices()->shape();
-    CHECK_EQ(scatter_indices_shape.rank(), 2) << scatter->ToString();
-    // Create a map from scatter update to scatter indices.
-    IndexingMap updates_to_indices_map{
-        mlir::AffineMap::get(
-            /*dimCount=*/scatter_update_shape.rank(), /*symbolCount=*/1,
-            {mlir::getAffineDimExpr(0, ctx), mlir::getAffineSymbolExpr(0, ctx)},
-            ctx),
-        DimVarsFromTensorSizes(scatter_update_shape.dimensions()),
-        RangeVarsFromTensorSizes({scatter_indices_shape.dimensions(1)}),
-        /*rt_vars=*/{}};
-    auto scatter_indices_map = scatter_update_map * updates_to_indices_map;
-    scatter_indices_map.Simplify();
-    return scatter_indices_map;
-  }
-  return scatter_update_map;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter.h b/third_party/xla/xla/service/gpu/fusions/scatter.h
deleted file mode 100644
index dda11c0..0000000
--- a/third_party/xla/xla/service/gpu/fusions/scatter.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_SCATTER_H_
-#define XLA_SERVICE_GPU_FUSIONS_SCATTER_H_
-
-#include <optional>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/ir_array.h"
-
-namespace xla {
-namespace gpu {
-
-// A scatter, implemented as a loop over the updates. All scatters are in-place.
-class ScatterFusion : public KernelFusionEmitterBase {
- public:
-  explicit ScatterFusion(const HloFusionAnalysis& analysis);
-
-  LaunchDimensions launch_dimensions() const override;
-
-  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
-      int64_t root_index, mlir::MLIRContext* ctx) const override {
-    // The kernel iterates over updates, whose correspondence to output
-    // elements cannot be computed statically.
-    return std::nullopt;
-  }
-
-  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
-      int64_t root_index, int64_t hero_operand_index,
-      mlir::MLIRContext* ctx) const override;
-
- protected:
-  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
-                          const HloFusionInstruction& fusion,
-                          const LaunchDimensions& launch_dims,
-                          std::vector<llvm_ir::IrArray> inputs,
-                          std::vector<llvm_ir::IrArray> outputs,
-                          llvm::IRBuilder<>* builder) const override;
-
- private:
-  const HloFusionAnalysis& analysis_;
-  LaunchDimensionsConfig config_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_SCATTER_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
index a281c0e..85e1e50 100644
--- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
@@ -39,9 +39,10 @@
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/primitive_util.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
 #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/gpu_fusible.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/launch_dimensions.h"
 #include "xla/service/gpu/model/indexing_map.h"
@@ -62,7 +63,6 @@
 using mlir::Value;
 using mlir::ValueRange;
 using mlir::func::ReturnOp;
-using mlir::tensor::InsertOp;
 using mlir_converter::CallTargetProvider;
 using mlir_converter::PartitionedComputations;
 using mlir_converter::ProvideParameter;
@@ -174,7 +174,8 @@
     auto reduced_val = mlir_converter::InlineBlock(
         b, reducer.getBody().front(), {operand_elem, update_elem})[0];
 
-    return b.create<InsertOp>(reduced_val, output_tensor, indices);
+    return b.create<mlir::tensor::InsertOp>(reduced_val, output_tensor,
+                                            indices);
   }
   auto atomic_rmw = b.create<AtomicRMWOp>(output_tensor, indices);
   mlir::OpBuilder body_builder = atomic_rmw.getBodyBuilder();
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h
index de9743a..3efaa0a 100644
--- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h
+++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h
@@ -25,7 +25,6 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/loop.h"
 #include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
 #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc
index 6b8d013..2e9a11a 100644
--- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc
@@ -77,7 +77,7 @@
   thread_id_printer_.SetSymbolName(2, "index_id");
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
   MlirScatterFusion fusion(analysis);
 
   constexpr auto kUpdatesIndexing = R"(
@@ -187,8 +187,8 @@
     }
   )";
   TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 floordiv 2)>
-    // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 mod 2)>
+    // CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 2)
+    // CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 2)
 
     // CHECK-LABEL: func.func @fused_computation(
     // CHECK-SAME:    %[[OPERAND:[a-zA-Z0-9]*]]: tensor<10x5xf32>
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_test.cc b/third_party/xla/xla/service/gpu/fusions/scatter_test.cc
deleted file mode 100644
index 284d308..0000000
--- a/third_party/xla/xla/service/gpu/fusions/scatter_test.cc
+++ /dev/null
@@ -1,224 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/scatter.h"
-
-#include <optional>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "mlir/IR/MLIRContext.h"
-#include "xla/service/gpu/fusions/fusions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/affine_map_printer.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class ScatterFusionTest : public HloTestBase {
- public:
-  void SetUp() override {
-    HloTestBase::SetUp();
-    printer_ =
-        AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"},
-                         {"chunk_id", "unroll_id", "index_id"});
-  }
-  DebugOptions GetDebugOptionsForTest() override {
-    auto opts = HloTestBase::GetDebugOptionsForTest();
-    opts.set_xla_gpu_mlir_emitter_level(0);
-    return opts;
-  }
-
- protected:
-  AffineMapPrinter printer_;
-  mlir::MLIRContext mlir_context_;
-};
-
-TEST_F(ScatterFusionTest, ScatterFusion) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    add (lhs: f32[], rhs: f32[]) -> f32[] {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT sum = f32[] add(lhs, rhs)
-    }
-
-    fused_computation {
-      %input = f32[2,9] parameter(0)
-      %indices = s32[3] parameter(1)
-      %updates = f32[3,9] parameter(2)
-      ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
-          to_apply=add,
-          update_window_dims={1},
-          inserted_window_dims={0},
-          scatter_dims_to_operand_dims={0},
-          index_vector_dim=1
-    }
-
-    ENTRY entry {
-      %input = f32[2,9] parameter(0)
-      %indices = s32[3] parameter(1)
-      %updates = f32[3,9] parameter(2)
-      ROOT %fusion = f32[2,9] fusion(%input, %indices, %updates), kind=kLoop, calls=fused_computation
-    })")
-                    .value();
-
-  stream_executor::DeviceDescription device_info =
-      TestGpuDeviceInfo::RTXA6000DeviceInfo();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis_fused = AnalyzeFusion(*root, device_info);
-
-  auto emitter =
-      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
-  auto scatter_fusion = dynamic_cast<ScatterFusion*>(emitter.get());
-  ASSERT_NE(scatter_fusion, nullptr);
-  EXPECT_EQ(scatter_fusion->launch_dimensions().launch_bound(),
-            3 * 9 /* updates size */);
-}
-
-TEST_F(ScatterFusionTest, ThreadIdIndexing) {
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    computation {
-      %p0 = f32[] parameter(0)
-      %p1 = f32[] parameter(1)
-      %p2 = f32[] parameter(2)
-      %p3 = f32[] parameter(3)
-      ROOT %tuple = (f32[], f32[]) tuple(f32[] %p2, f32[] %p3)
-    }
-    scatter {
-      %operand0 = f32[300,200] parameter(0)
-      %operand1 = f32[300,200] parameter(1)
-      %indices = s32[42,1] parameter(2)
-      %update.1 = f32[42,10,20] parameter(3)
-      %update.2 = f32[42,10,20]parameter(4)
-
-      ROOT %scatter = (f32[300,200], f32[300,200]) scatter(
-          f32[300,200] %operand0,
-          f32[300,200] %operand1,
-          s32[42,1] %indices,
-          f32[42,10,20] %update.1,
-          f32[42,10,20] %update.2
-        ),
-        update_window_dims={1,2},
-        inserted_window_dims={},
-        scatter_dims_to_operand_dims={0},
-        index_vector_dim=1,
-        to_apply=computation
-    }
-    ENTRY entry {
-      %operand0 = f32[300,200] parameter(0)
-      %operand1 = f32[300,200] parameter(1)
-      %indices = s32[42,1] parameter(2)
-      %update.1 = f32[42,10,20] parameter(3)
-      %update.2 = f32[42,10,20]parameter(4)
-      ROOT %fusion = (f32[300,200], f32[300,200]) fusion(
-        %operand0, %operand1, %indices, %update.1, %update.2),
-        kind=kLoop, calls=scatter
-    }
-  )"));
-  stream_executor::DeviceDescription device_info =
-      TestGpuDeviceInfo::RTXA6000DeviceInfo();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis_fused = AnalyzeFusion(*root, device_info);
-
-  auto emitter =
-      GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
-  auto fusion = dynamic_cast<ScatterFusion*>(emitter.get());
-  ASSERT_NE(fusion, nullptr);
-
-  constexpr auto kUpdatesIndexing = R"(
-    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
-      (bl_x * 128 + th_x) floordiv 200,
-      ((bl_x * 128 + th_x) floordiv 20) mod 10,
-      (bl_x * 128 + th_x) mod 20
-    )
-    domain:
-    th_x in [0, 127]
-    th_y in [0, 0]
-    th_z in [0, 0]
-    bl_x in [0, 65]
-    bl_y in [0, 0]
-    bl_z in [0, 0]
-    chunk_id in [0, 0]
-    unroll_id in [0, 0]
-    bl_x * 128 + th_x in [0, 8399]
-  )";
-  EXPECT_THAT(
-      fusion
-          ->ComputeThreadIdToInputIndexing(
-              /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_)
-          ->ToString(printer_),
-      MatchIndexingString(kUpdatesIndexing));
-  EXPECT_THAT(
-      fusion
-          ->ComputeThreadIdToInputIndexing(
-              /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_)
-          ->ToString(printer_),
-      MatchIndexingString(kUpdatesIndexing));
-  EXPECT_THAT(
-      fusion
-          ->ComputeThreadIdToInputIndexing(
-              /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_)
-          ->ToString(printer_),
-      MatchIndexingString(kUpdatesIndexing));
-  EXPECT_THAT(
-      fusion
-          ->ComputeThreadIdToInputIndexing(
-              /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_)
-          ->ToString(printer_),
-      MatchIndexingString(kUpdatesIndexing));
-
-  constexpr auto kIndicesIndexing = R"(
-    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] ->
-      ((bl_x * 128 + th_x) floordiv 200, 0)
-    domain:
-    th_x in [0, 127]
-    th_y in [0, 0]
-    th_z in [0, 0]
-    bl_x in [0, 65]
-    bl_y in [0, 0]
-    bl_z in [0, 0]
-    chunk_id in [0, 0]
-    unroll_id in [0, 0]
-    index_id in [0, 0]
-    bl_x * 128 + th_x in [0, 8399]
-  )";
-  EXPECT_THAT(
-      fusion
-          ->ComputeThreadIdToInputIndexing(
-              /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_)
-          ->ToString(printer_),
-      MatchIndexingString(kIndicesIndexing));
-  EXPECT_THAT(
-      fusion
-          ->ComputeThreadIdToInputIndexing(
-              /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_)
-          ->ToString(printer_),
-      MatchIndexingString(kIndicesIndexing));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/tests/BUILD
new file mode 100644
index 0000000..d3e3b66
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/BUILD
@@ -0,0 +1,19 @@
+load("//xla:lit.bzl", "lit_test_suite")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    licenses = ["notice"],
+)
+
+lit_test_suite(
+    name = "tests",
+    srcs = glob(["**/*.hlo"]),
+    cfg = "//xla:lit.cfg.py",
+    default_tags = ["requires-gpu-sm80-only"],
+    tools = [
+        "//xla/service/gpu/fusions/tools:fusion_to_mlir",
+        "//xla/service/gpu/fusions/tools:mlir_fusions_opt",
+        "//xla/service/gpu/fusions/tools:test_correctness",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/mof_heterogeneous_input_shapes.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/mof_heterogeneous_input_shapes.hlo
new file mode 100644
index 0000000..1646ade
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/mof_heterogeneous_input_shapes.hlo
@@ -0,0 +1,19 @@
+// RUN: test_correctness %s --bijection_inputs=reduce.1:0 --bijection_inputs=reduce.2:0 --bijection_outputs=reduce.1 --bijection_outputs=reduce.2
+
+add {
+  p0 = f32[] parameter(0)
+  p1 = f32[] parameter(1)
+  ROOT add = f32[] add(p0, p1)
+}
+
+fusion {
+  %param_0 = f32[64,128] parameter(0)
+  %constant_0 = f32[] constant(0)
+  %reduce.1 = f32[128] reduce(param_0, constant_0), dimensions={0},
+    to_apply=%add
+  %neg = f32[64,128] negate(param_0)
+  %bitcast = f32[8,8,128] bitcast(neg)
+  %reduce.2 = f32[128] reduce(bitcast, constant_0), dimensions={0,1},
+    to_apply=%add
+  ROOT %tuple = (f32[128], f32[128]) tuple(reduce.1, reduce.2)
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo
new file mode 100644
index 0000000..e7ae070
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo
@@ -0,0 +1,22 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+fused_computation {
+  param_0 = f32[13,1051,321] parameter(0)
+  param_1 = f32[] parameter(1)
+  ROOT reduce = f32[13,321] reduce(param_0, param_1), dimensions={1}, to_apply=add
+}
+
+// CHECK: xla_gpu.pure_call @add_add
+// CHECK: allocate_shared
+// CHECK: tensor.insert
+// CHECK: sync_threads
+// CHECK: predicated_extract
+// CHECK: shuffle_reduce
+// CHECK: predicated_insert
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/small.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/small.hlo
new file mode 100644
index 0000000..958b391
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/small.hlo
@@ -0,0 +1,13 @@
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+fused_computation {
+  param_0 = f32[3,128,4] parameter(0)
+  c0 = f32[] constant(0)
+  ROOT reduce = f32[3,4] reduce(param_0, c0), dimensions={1}, to_apply=add
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo
new file mode 100644
index 0000000..a2a22363
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo
@@ -0,0 +1,17 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+  lhs = c64[] parameter(0)
+  rhs = c64[] parameter(1)
+  ROOT add = c64[] add(lhs, rhs)
+}
+
+fused_computation {
+  param_0 = c64[128,64] parameter(0)
+  c0 = c64[] constant((0, 0))
+  ROOT reduce = c64[64] reduce(param_0, c0), dimensions={0},
+    to_apply=add
+}
+
+// CHECK-NOT: vector<
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo
new file mode 100644
index 0000000..660664b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo
@@ -0,0 +1,17 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+  lhs = f64[] parameter(0)
+  rhs = f64[] parameter(1)
+  ROOT add = f64[] add(lhs, rhs)
+}
+
+fused_computation {
+  param_0 = f64[128,64] parameter(0)
+  c0 = f64[] constant(0)
+  ROOT reduce = f64[64] reduce(param_0, c0), dimensions={0},
+    to_apply=add
+}
+
+// CHECK-NOT: vector<
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo
new file mode 100644
index 0000000..a142ad4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo
@@ -0,0 +1,17 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+fused_computation {
+  param_0 = f32[2048,64] parameter(0)
+  c0 = f32[] constant(0)
+  ROOT reduce = f32[64] reduce(param_0, c0), dimensions={0},
+    to_apply=add
+}
+
+// CHECK: vector<2xf32>
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo
new file mode 100644
index 0000000..81da088
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo
@@ -0,0 +1,17 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+  lhs = s16[] parameter(0)
+  rhs = s16[] parameter(1)
+  ROOT add = s16[] add(lhs, rhs)
+}
+
+fused_computation {
+  param_0 = s16[256,128] parameter(0)
+  c0 = s16[] constant(0)
+  ROOT reduce = s16[128] reduce(param_0, c0), dimensions={0},
+    to_apply=add
+}
+
+// CHECK: vector<4xi16>
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo
new file mode 100644
index 0000000..f8a9e86
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo
@@ -0,0 +1,22 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -xla-erase-dead-functions -inline -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+fused_computation {
+  param_0 = f32[8,2048] parameter(0)
+  param_1 = f32[] parameter(1)
+  reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=add
+  ROOT log = f32[8] log(reduce)
+}
+
+// CHECK:      shuffle_reduce
+// CHECK:      allocate_shared
+// CHECK:      sync_threads
+// CHECK:      shuffle_reduce
+// CHECK-NEXT: %[[OUT:.*]] = math.log
+// CHECK:      predicated_insert %[[OUT]]
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo
new file mode 100644
index 0000000..bc84174
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo
@@ -0,0 +1,43 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce1:0 --bijection_inputs=reduce2:0 --bijection_outputs=reduce1 --bijection_outputs=reduce2
+
+add {
+  p0 = f64[] parameter(0)
+  p1 = f64[] parameter(1)
+  ROOT add = f64[] add(p0, p1)
+}
+
+// This fusion is valid, but we can't efficiently codegen it.
+fusion {
+  %p0 = f64[4] parameter(0)
+  %p1 = f64[4] parameter(1)
+  %c0 = f64[] constant(-inf)
+  %reduce0 = f64[] reduce(p1, c0), dimensions={0}, to_apply=add
+  %bc0 = f64[4] broadcast(reduce0), dimensions={}
+  %compare0 = pred[4] compare(p1, bc0), direction=EQ
+  %c1 = f64[] constant(0)
+  %bc1 = f64[4] broadcast(c1), dimensions={}
+  %select.3.1 = f64[4] select(compare0, p0, bc1)
+  %reduce1 = f64[] reduce(select.3.1, c1), dimensions={0}, to_apply=add
+  %convert0 = f64[4] convert(compare0)
+  %reduce2 = f64[] reduce(convert0, c1), dimensions={0}, to_apply=add
+  ROOT %tuple.1 = (f64[], f64[], f64[]) tuple(%reduce1, reduce0, reduce2)
+}
+
+// Every thread reads all of %p1 once, and then reads one element again.
+// CHECK:      func.func @main
+// CHECK-SAME:   , %[[P1:.*]]: tensor<4xf64> {xla.slice_index = 1 : index}
+// CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:  %[[CST0:.*]] = arith.constant 0xFFF0000000000000
+// CHECK-DAG:  %[[TID_X:.*]] = gpu.thread_id x
+
+// reduce0 in the context of reduce2 and reduce1's prologue:
+// CHECK:      scf.for %[[I:.*]] = %[[C0]]
+// CHECK-NEXT:   tensor.extract %[[P1]][%[[I]]]
+// CHECK-NEXT:   addf
+// CHECK-NEXT:   yield
+
+// reduce0 again, this time in its role as a fusion hero:
+// CHECK:     %[[EXTRACTED:.*]] = tensor.extract %[[P1]][%[[TID_X]]]
+// CHECK:     %[[ADDED:.*]] = arith.addf %[[CST0]], %[[EXTRACTED]]
+// CHECK:     shuffle_reduce @add_add(%[[ADDED]]) to 2
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/large_to_unit.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/large_to_unit.hlo
new file mode 100644
index 0000000..ee155c8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/large_to_unit.hlo
@@ -0,0 +1,15 @@
+// Regression test for a bug where not all threads in the warp produced a valid
+// value for the final warp shuffle.
+// RUN: test_correctness %s
+
+and {
+  p0 = pred[] parameter(0)
+  p1 = pred[] parameter(1)
+  ROOT and = pred[] and(p0, p1)
+}
+
+fused_reduce {
+  c1 = pred[] constant(true)
+  p0 = pred[10000] broadcast(c1), dimensions={}
+  ROOT reduce = pred[] reduce(p0, c1), dimensions={0}, to_apply=and
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo
new file mode 100644
index 0000000..102e32b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo
@@ -0,0 +1,17 @@
+// RUN: fusion_to_mlir %s | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+  p0 = f32[] parameter(0)
+  p1 = f32[] parameter(1)
+  ROOT add = f32[] add(p0, p1)
+}
+
+fusion {
+  %input = f32[17,19,127] parameter(0)
+  %c0 = f32[] constant(0)
+  // The output is physically transposed.
+  ROOT reduce = f32[17,19]{0,1} reduce(%input, %c0), dimensions={2}, to_apply=add
+}
+
+// CHECK: xla_gpu.predicated_insert {{.*}} : tensor<17x19xf32, dense<[0, 1]> : tensor<2xi64>>
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo
new file mode 100644
index 0000000..c9481f3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo
@@ -0,0 +1,20 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -inline -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+fused_computation {
+  param_0 = f32[7,100,128] parameter(0)
+  param_1 = f32[] parameter(1)
+  ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={0,2}, to_apply=add
+}
+
+// Our codegen doesn't support parallelizing the major reduction dimension. In
+// principle, this could be done via shared memory.
+// CHECK-NOT: allocate_shared
+// CHECK: shuffle_reduce
+// CHECK-NOT: allocate_shared
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo
new file mode 100644
index 0000000..315d604
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo
@@ -0,0 +1,40 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -xla-erase-dead-functions -inline -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce1:0 --bijection_inputs=reduce2:0 --bijection_outputs=reduce1 --bijection_outputs=reduce2
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+mul {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT mul = f32[] multiply(lhs, rhs)
+}
+
+fused_computation {
+  param_0 = f32[8,1024] parameter(0)
+  c0 = f32[] constant(0)
+  c1 = f32[] constant(1)
+  reduce1 = f32[8] reduce(param_0, c0), dimensions={1}, to_apply=add
+  reduce2 = f32[8] reduce(param_0, c1), dimensions={1}, to_apply=mul
+  log = f32[8] log(reduce1)
+  abs = f32[8] abs(reduce1)
+  neg = f32[8] negate(reduce2)
+  ROOT tuple = (f32[8], f32[8], f32[8]) tuple(log, neg, abs)
+}
+
+// CHECK-DAG: shuffle_reduce @add_add
+// CHECK-DAG: shuffle_reduce @mul_mul
+// CHECK: allocate_shared
+// CHECK: allocate_shared
+// CHECK: sync_threads
+// CHECK-DAG: %[[ADDED:.*]] = xla_gpu.shuffle_reduce @add_add
+// CHECK-DAG: %[[MULTIPLIED:.*]] = xla_gpu.shuffle_reduce @mul_mul
+// CHECK-DAG: %[[LOG:.*]] = math.log %[[ADDED]]
+// CHECK-DAG: %[[ABS:.*]] = math.absf %[[ADDED]]
+// CHECK-DAG: %[[NEG:.*]] = arith.negf %[[MULTIPLIED]]
+// CHECK-DAG: xla_gpu.predicated_insert %[[LOG]]
+// CHECK-DAG: xla_gpu.predicated_insert %[[ABS]]
+// CHECK-DAG: xla_gpu.predicated_insert %[[NEG]]
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_scalar_variadic.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_scalar_variadic.hlo
new file mode 100644
index 0000000..48a2033
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_scalar_variadic.hlo
@@ -0,0 +1,26 @@
+// RUN: test_correctness %s
+
+%reducer1 {
+  p0 = f32[] parameter(0)
+  p1 = f32[] parameter(1)
+  ROOT add = f32[] add(p0, p1)
+}
+
+%reducer2 {
+  p0 = f32[] parameter(0)
+  p1 = f32[] parameter(1)
+  p2 = f32[] parameter(2)
+  p3 = f32[] parameter(3)
+  add0 = f32[] add(p0, p2)
+  add1 = f32[] add(p1, p3)
+  ROOT tuple = (f32[], f32[]) tuple(add0, add1)
+}
+
+%fusion {
+  %p0 = f32[6,6] parameter(0)
+  %c0 = f32[] constant(0)
+  %neg = f32[6,6] negate(%p0)
+  %reduce1 = f32[] reduce(%neg, %c0), dimensions={0,1}, to_apply=%reducer1
+  %reduce2 = (f32[], f32[]) reduce(%p0, %p0, %c0, %c0), dimensions={0,1}, to_apply=%reducer2
+  ROOT %tuple = (f32[], (f32[], f32[]), f32[6,6]) tuple(%reduce1, %reduce2, %neg)
+}
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_variadic_variadic.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_variadic_variadic.hlo
new file mode 100644
index 0000000..6d47fc6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_variadic_variadic.hlo
@@ -0,0 +1,26 @@
+// Regression test for a compilation crash with a MOF with two variadic
+// reductions.
+// RUN: test_correctness %s
+
+add {
+  p0 = f32[] parameter(0)
+  p1 = f32[] parameter(1)
+  p2 = f32[] parameter(2)
+  p3 = f32[] parameter(3)
+  a = f32[] add(p0, p2)
+  b = f32[] add(p1, p3)
+  ROOT out = (f32[], f32[]) tuple(a, b)
+}
+
+fused_reduce {
+  p0 = f32[3,2] parameter(0)
+  p1 = f32[3,2] parameter(1)
+  c0 = f32[] constant(0)
+  iota0 = f32[3,2] iota(), iota_dimension=1
+  iota1 = f32[3,2] iota(), iota_dimension=1
+  reduce0 = (f32[3], f32[3]) reduce(p0, iota0, c0, c0), dimensions={1},
+      to_apply=add
+  reduce1 = (f32[3], f32[3]) reduce(p1, iota1, c0, c0), dimensions={1},
+      to_apply=add
+  ROOT tuple = ((f32[3], f32[3]), (f32[3], f32[3])) tuple(reduce0, reduce1)
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/non_power_of_two.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/non_power_of_two.hlo
new file mode 100644
index 0000000..30202d0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/non_power_of_two.hlo
@@ -0,0 +1,31 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -inline -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+fused_computation {
+  param_0 = f32[100,568] parameter(0)
+  param_1 = f32[] parameter(1)
+  ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={1}, to_apply=add
+}
+
+// CHECK-DAG: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0] -> ((d1 mod 64) * 2 + s0 * 128 + d0), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 3]>
+// CHECK-DAG: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 mod 64) * 2 + d0 + 512), domain: d0 in [0, 1], d1 in [0, 255]>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+
+// The full loop without bounds checks:
+// CHECK: %[[FULL_TILES:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]]
+// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-NOT: scf.if
+// CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]], %thread_id_x)[%[[I]]]
+
+// The tail loop:
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%{{.*}} = %[[FULL_TILES]])
+// CHECK: scf.if
+// CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]], %thread_id_x)
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/reduction_groups.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/reduction_groups.hlo
new file mode 100644
index 0000000..a7e6415
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/reduction_groups.hlo
@@ -0,0 +1,22 @@
+// RUN: fusion_to_mlir %s | FileCheck %s
+// RUN: test_correctness %s
+
+%add_f32 {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+%fused_computation {
+  %param0 = f32[1024] parameter(0)
+  %param1 = f32[1024] parameter(1)
+  %constant0 = f32[] constant(0)
+  %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32
+  %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32
+  ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2)
+}
+
+// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
+// CHECK: scf.index_switch %[[BLOCK_ID_Y]]
+// CHECK: case 1 {
+// CHECK: default {
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo
new file mode 100644
index 0000000..e950e3c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo
@@ -0,0 +1,24 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce --bijection_outputs=exp
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+fused_computation {
+  param_0 = f32[8,2048] parameter(0)
+  param_1 = f32[] parameter(1)
+  exp = f32[8,2048] exponential(param_0)
+  reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=add
+  ROOT t = (f32[8], f32[8,2048]) tuple(reduce, exp)
+}
+
+// CHECK:     @fused_computation
+// CHECK:     scf.for
+// CHECK:       scf.for
+// CHECK:         %[[SIDE_OUTPUT:.*]] = xla_gpu.pure_call @fused_computation_exp
+// CHECK-NEXT:    tensor.insert %[[SIDE_OUTPUT]]
+// CHECK:         scf.yield
+// CHECK:       scf.yield
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_broadcast.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_broadcast.hlo
new file mode 100644
index 0000000..0db1901
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_broadcast.hlo
@@ -0,0 +1,15 @@
+// RUN: test_correctness %s
+
+%add {
+  p0 = f32[] parameter(0)
+  p1 = f32[] parameter(1)
+  ROOT add = f32[] add(p0, p1)
+}
+
+%fusion {
+  %p0 = f32[6,6] parameter(0)
+  %c0 = f32[] constant(0)
+  %reduce = f32[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add
+  %broadcast = f32[6,6] broadcast(%reduce), dimensions={}
+  ROOT %tuple = (f32[6,6], f32[]) tuple(%broadcast, %reduce)
+}
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_unsigned.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_unsigned.hlo
new file mode 100644
index 0000000..5371b80
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_unsigned.hlo
@@ -0,0 +1,15 @@
+// RUN: test_correctness %s
+
+add {
+  lhs = u32[] parameter(0)
+  rhs = u32[] parameter(1)
+  ROOT add = u32[] add(lhs, rhs)
+}
+
+fused_computation {
+  param_0 = u32[8,2048] parameter(0)
+  param_1 = u32[] parameter(1)
+  add = u32[8,2048] add(param_0, param_0)
+  reduce = u32[8] reduce(param_0, param_1), dimensions={1}, to_apply=add
+  ROOT t = (u32[8], u32[8,2048]) tuple(reduce, add)
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/small_f64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/small_f64.hlo
new file mode 100644
index 0000000..56e3266
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/small_f64.hlo
@@ -0,0 +1,17 @@
+// RUN: fusion_to_mlir %s | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
+
+add {
+  lhs = f64[] parameter(0)
+  rhs = f64[] parameter(1)
+  ROOT add = f64[] add(lhs, rhs)
+}
+
+fused_computation {
+  param_0 = f64[100,128] parameter(0)
+  param_1 = f64[] parameter(1)
+  ROOT reduce = f64[100] reduce(param_0, param_1), dimensions={1}, to_apply=add
+}
+
+// This reduction is small enough to not require any shared memory.
+// CHECK-NOT: allocate_shared
diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo
new file mode 100644
index 0000000..b28bff4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo
@@ -0,0 +1,23 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s
+// RUN: test_correctness %s --bijection_inputs=reduce:0,1 --bijection_outputs=reduce
+
+add {
+  scalar_lhs.0 = f32[] parameter(0)
+  scalar_rhs.0 = f32[] parameter(1)
+  scalar_lhs.1 = f32[] parameter(2)
+  scalar_rhs.1 = f32[] parameter(3)
+  add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1)
+  add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1)
+  ROOT t = (f32[], f32[]) tuple(add.0, add.1)
+}
+
+fused_computation {
+  param_0 = f32[2, 3, 2048] parameter(0)
+  param_1 = f32[2, 3, 2048] parameter(1)
+  c0 = f32[] constant(0)
+  ROOT reduce = (f32[2, 3], f32[2, 3])
+    reduce(param_0, param_1, c0, c0), dimensions={2}, to_apply=add
+}
+
+// CHECK: allocate_shared
+// CHECK: allocate_shared
diff --git a/third_party/xla/xla/service/gpu/fusions/tiling_util.cc b/third_party/xla/xla/service/gpu/fusions/tiling_util.cc
deleted file mode 100644
index 9ad085f..0000000
--- a/third_party/xla/xla/service/gpu/fusions/tiling_util.cc
+++ /dev/null
@@ -1,259 +0,0 @@
-/*Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/fusions/tiling_util.h"
-
-#include <cstdint>
-#include <limits>
-#include <string>
-#include <vector>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/strings/str_cat.h"
-#include "absl/types/span.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Instructions.h"
-#include "llvm/IR/Value.h"
-#include "llvm/Support/Casting.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/target_util.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/service/llvm_ir/kernel_support_library.h"
-#include "xla/service/llvm_ir/llvm_loop.h"
-#include "xla/service/llvm_ir/llvm_util.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-void EmitTileRec(const TilingThreadIdInfo& thread_id_info, const Tiling& tiling,
-                 int dim, absl::InlinedVector<llvm::Value*, 4> tile_idx,
-                 absl::Span<llvm::Value* const> tile_dimensions,
-                 llvm::IRBuilder<>* b, const TileElementGenerator& emit_elem) {
-  llvm::Type* index_ty = thread_id_info.thread_id->getType();
-  auto constant = [&](int64_t val) {
-    return llvm::ConstantInt::get(index_ty, val);
-  };
-
-  auto recurse = [&] {
-    if (dim == tile_idx.size() - 1) {
-      emit_elem(tile_idx);
-    } else {
-      EmitTileRec(thread_id_info, tiling, dim + 1, tile_idx, tile_dimensions, b,
-                  emit_elem);
-    }
-  };
-
-  bool unroll = tiling.GetLoopsToUnroll()[dim];
-  KernelSupportLibrary ksl(b, unroll ? llvm_ir::UnrollMode::kFullyUnroll
-                                     : llvm_ir::UnrollMode::kDefaultUnroll);
-
-  if (tiling.GetBlockTileSize()[dim] == 1) {
-    tile_idx[dim] = constant(0);
-    recurse();
-  } else if (unroll) {
-    // TODO(jreiffers): Check if this unrolling does anything useful.
-    int64_t stride = tiling.GetThreadsPerBlock()[dim];
-    int64_t dim_size = tiling.GetThreadTileSize()[dim];
-
-    auto make_loop = [&](bool emit_bounds_checks) {
-      auto body = [&, emit_bounds_checks](llvm::Value* i) {
-        tile_idx[dim] = b->CreateAdd(i, thread_id_info.thread_ids[dim]);
-        if (emit_bounds_checks) {
-          auto* in_bounds =
-              b->CreateICmpULT(tile_idx[dim], tile_dimensions[dim]);
-          ksl.If("x_in_tile", in_bounds, recurse);
-        } else {
-          recurse();
-        }
-      };
-      return [&, body] {
-        ksl.For(absl::StrCat("loop", dim), constant(0),
-                constant(dim_size * stride), constant(stride), body);
-      };
-    };
-    if (stride > 1 && dim_size > 1) {
-      // Most tiles will be full, so we emit a single bounds check for those.
-      auto* is_full_tile = b->CreateICmpEQ(
-          constant(tiling.GetBlockTileSize()[dim]), tile_dimensions[dim]);
-      ksl.If("is_full_tile", is_full_tile, make_loop(false), make_loop(true));
-    } else {
-      make_loop(true)();
-    }
-  } else {
-    // All dimensions are strided (thread 0 processes elements 0, num_threads,
-    // num_threads+2, ...; thread 1 processes elements 1, num_threads + 1 and so
-    // on).
-    ksl.For(absl::StrCat("loop", dim), /*start=*/thread_id_info.thread_ids[dim],
-            /*end=*/tile_dimensions[dim],
-            /*step=*/tiling.GetThreadsPerBlock()[dim], [&](llvm::Value* i) {
-              tile_idx[dim] = i;
-              recurse();
-            });
-  }
-}
-
-}  // namespace
-
-void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling,
-              const TilingThreadIdInfo& thread_id_info,
-              absl::Span<llvm::Value* const> tile_dimensions,
-              const TileElementGenerator& emit_elem_function) {
-  absl::InlinedVector<llvm::Value*, 4> tile_idx(tiling.GetShape().size());
-  EmitTileRec(thread_id_info, tiling, 0, tile_idx, tile_dimensions, builder,
-              emit_elem_function);
-}
-
-namespace {
-
-// Emits current block id.
-llvm::Value* EmitBlockId(llvm::IRBuilder<>* builder, int32_t num_blocks,
-                         llvm::Type* index_ty) {
-  llvm::Value* block_id =
-      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, builder);
-  if (num_blocks != 0) {
-    llvm_ir::AddRangeMetadata(0, num_blocks,
-                              llvm::cast<llvm::Instruction>(block_id),
-                              builder->GetInsertBlock()->getModule());
-  }
-  auto ret = builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true);
-  ret->setName("block.id.x");
-  return ret;
-}
-
-// Emits current thread id with the given type.
-//
-// Sets the return value range to [0, threads_per_block).
-llvm::Value* EmitThreadId(llvm::IRBuilder<>* builder, int64_t threads_per_block,
-                          llvm::Type* index_ty) {
-  // Calculate (y, x) coordinates respectively in the 2D view of thread block,
-  // defined by (num_thread_y, num_thread_x) from thread_id.
-  llvm::CallInst* thread_id =
-      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, builder);
-  llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id,
-                            builder->GetInsertBlock()->getModule());
-  auto ret = builder->CreateIntCast(thread_id, index_ty, /*isSigned=*/true);
-  ret->setName("thread.id.x");
-  return ret;
-}
-
-// Emits the LLVM values for thread_id, block_id, coordinates of the current
-// tile and strides of the loops to iterate over the current tile.
-absl::StatusOr<TilingThreadIdInfo> EmitThreadIdInfo(llvm::IRBuilder<>* builder,
-                                                    const Tiling& tiling,
-                                                    llvm::Type* index_ty) {
-  auto constant = [&](uint64_t c) -> llvm::Constant* {
-    return llvm::ConstantInt::get(index_ty, c);
-  };
-  int64_t num_blocks = tiling.GetNumBlocks();
-  if (num_blocks > (int64_t)std::numeric_limits<uint32_t>::max()) {
-    return FailedPrecondition(
-        "Number of physical blocks (%d) does not fit in an i32 in tiling "
-        "scheme: %s",
-        num_blocks, tiling.ToString());
-  }
-
-  TilingThreadIdInfo info;
-  info.thread_id =
-      EmitThreadId(builder, tiling.GetNumThreadsPerBlock(), index_ty);
-  info.block_id = EmitBlockId(builder, num_blocks, index_ty);
-
-  for (auto [dim, stride] : llvm::enumerate(tiling.GetThreadStrides())) {
-    int64_t size = tiling.GetThreadsPerBlock()[dim];
-    if (size == 1) {
-      info.thread_ids.emplace_back(constant(0));
-    } else {
-      auto& dim_id = info.thread_ids.emplace_back(info.thread_id);
-      if (stride > 1) {
-        dim_id = builder->CreateUDiv(dim_id, constant(stride));
-      }
-      if (dim) {
-        dim_id = builder->CreateURem(dim_id, constant(size));
-      }
-      dim_id->setName(absl::StrCat("thread.id.", dim));
-    }
-  }
-
-  info.lane_id =
-      builder->CreateURem(info.thread_id, constant(WarpSize()), "lane_id");
-  return info;
-}
-
-}  // namespace
-
-absl::StatusOr<TilingKernelInfo> EmitTilingKernel(
-    llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty,
-    const TileGenerator& tile_element_generator) {
-  absl::Span<const int64_t> dims_in_elems = tiling.GetShape();
-  const auto& block_counts = tiling.GetBlockCounts();
-  auto constant = [&](uint64_t c) -> llvm::Constant* {
-    return llvm::ConstantInt::get(index_ty, c);
-  };
-
-  TF_ASSIGN_OR_RETURN(TilingThreadIdInfo thread_id_info,
-                      EmitThreadIdInfo(builder, tiling, index_ty));
-
-  KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll);
-
-  const llvm_ir::IrArray::Index block_coords(
-      thread_id_info.block_id,
-      ShapeUtil::MakeShape(PRED /*arbitrary*/, block_counts), builder);
-
-  absl::InlinedVector<llvm::Value*, 4> tile_dimensions;
-  for (int i = 0; i < block_counts.size(); ++i) {
-    int64_t block_tile_size = tiling.GetBlockTileSize()[i];
-    if (dims_in_elems[i] % block_tile_size == 0) {
-      // The block tile size evenly divides the tiled shape -> no need to emit
-      // the bounds check.
-      tile_dimensions.push_back(constant(block_tile_size));
-    } else {
-      // Only the last tile in each dimension may not have full size.
-      llvm::Value* is_last =
-          builder->CreateICmpEQ(block_coords[i], constant(block_counts[i] - 1));
-      int64_t partial_row =
-          dims_in_elems[i] - (block_counts[i] - 1) * block_tile_size;
-      tile_dimensions.push_back(builder->CreateSelect(
-          is_last, constant(partial_row), constant(block_tile_size),
-          absl::StrCat("tile_bound.", i)));
-    }
-  }
-
-  llvm_ir::IrArray::Index tile_offset = [&] {
-    std::vector<llvm::Value*> elem_multi_index = block_coords.multidim();
-    llvm::Type* index_ty = block_coords.GetType();
-    for (int i = 0; i < block_counts.size(); ++i) {
-      elem_multi_index[i] = builder->CreateMul(
-          block_coords[i],
-          llvm::ConstantInt::get(index_ty, tiling.GetBlockTileSize()[i]),
-          absl::StrCat("tile_origin.", i));
-    }
-    return llvm_ir::IrArray::Index(elem_multi_index, tiling.GetShape(),
-                                   index_ty);
-  }();
-
-  tile_element_generator(thread_id_info, tile_offset, tile_dimensions);
-  return {{tile_dimensions, tile_offset, thread_id_info}};
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/tiling_util.h b/third_party/xla/xla/service/gpu/fusions/tiling_util.h
deleted file mode 100644
index 66014ae..0000000
--- a/third_party/xla/xla/service/gpu/fusions/tiling_util.h
+++ /dev/null
@@ -1,183 +0,0 @@
-/*Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_TILING_UTIL_H_
-#define XLA_SERVICE_GPU_FUSIONS_TILING_UTIL_H_
-
-#include <cstdint>
-#include <functional>
-#include <string>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
-#include "absl/types/span.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Type.h"
-#include "llvm/IR/Value.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-
-// Describes tiling used by the kernel.
-//
-// Used by reduction and transpose emitters.
-class Tiling {
- public:
-  Tiling(absl::Span<const int64_t> shape, absl::Span<const int64_t> tile_sizes,
-         absl::Span<const int64_t> num_threads,
-         // By default, don't unroll anything.
-         absl::InlinedVector<bool, 4> loops_to_unroll = {})
-      : shape_{shape.begin(), shape.end()},
-        tile_sizes_per_thread_{tile_sizes.begin(), tile_sizes.end()},
-        tile_sizes_per_block_(shape.size()),
-        num_threads_{num_threads.begin(), num_threads.end()},
-        num_blocks_(shape.size()),
-        loops_to_unroll_(loops_to_unroll) {
-    for (int64_t i = 0; i < shape.size(); ++i) {
-      tile_sizes_per_block_[i] = tile_sizes[i] * num_threads[i];
-      CHECK_NE(tile_sizes_per_block_[i], 0);
-      num_blocks_[i] = CeilOfRatio(shape[i], tile_sizes_per_block_[i]);
-      CHECK_NE(num_blocks_[i], 0);
-    }
-    if (loops_to_unroll_.empty()) loops_to_unroll_.resize(shape.size());
-  }
-  Tiling() = default;
-
-  std::string ToString() const {
-    return absl::StrJoin(
-        {absl::StrFormat("shape = {%s}", absl::StrJoin(shape_, ", ")),
-         absl::StrFormat("tile_sizes = {%s}",
-                         absl::StrJoin(tile_sizes_per_thread_, ", ")),
-         absl::StrFormat("num_threads = {%s}",
-                         absl::StrJoin(num_threads_, ", "))},
-        ", ");
-  }
-
-  // Number of elements in each dimension.
-  const absl::InlinedVector<int64_t, 4>& GetShape() const { return shape_; }
-  xla::Shape GetXlaShape(PrimitiveType element_type = F32) const {
-    return ShapeUtil::MakeShape(element_type, shape_);
-  }
-
-  const absl::InlinedVector<int64_t, 4>& GetBlockCounts() const {
-    return num_blocks_;
-  }
-
-  // Tile size for each thread.
-  //
-  // Equals to the number of iterations in the loop each tile will make.
-  const absl::InlinedVector<int64_t, 4>& GetThreadTileSize() const {
-    return tile_sizes_per_thread_;
-  }
-
-  // Tile size for an entire thread block.
-  const absl::InlinedVector<int64_t, 4>& GetBlockTileSize() const {
-    return tile_sizes_per_block_;
-  }
-
-  const absl::InlinedVector<int64_t, 4>& GetThreadsPerBlock() const {
-    return num_threads_;
-  }
-
-  // Returns the strides of the thread index dimensions wrt. the linear thread
-  // id.
-  absl::InlinedVector<int64_t, 4> GetThreadStrides() const {
-    return *ShapeUtil::ByteStrides(ShapeUtil::MakeShape(U8, num_threads_));
-  }
-
-  int64_t GetNumThreadsPerBlock() const { return Product(num_threads_); }
-
-  int64_t GetNumBlocks() const { return Product(num_blocks_); }
-
-  const absl::InlinedVector<bool, 4>& GetLoopsToUnroll() const {
-    return loops_to_unroll_;
-  }
-
- private:
-  // The number of elements in each dimension.
-  absl::InlinedVector<int64_t, 4> shape_;
-
-  // The number of elements for each dimension of a tile.
-  absl::InlinedVector<int64_t, 4> tile_sizes_per_thread_;
-  absl::InlinedVector<int64_t, 4> tile_sizes_per_block_;
-
-  absl::InlinedVector<int64_t, 4> num_threads_;
-  absl::InlinedVector<int64_t, 4> num_blocks_;
-
-  absl::InlinedVector<bool, 4> loops_to_unroll_;
-};
-
-struct TilingThreadIdInfo {
-  llvm::Value* thread_id;
-
-  absl::InlinedVector<llvm::Value*, 4> thread_ids;
-
-  // Lane id: `thread_id % WarpSize`
-  llvm::Value* lane_id;
-
-  // Block id.
-  llvm::Value* block_id;
-};
-
-struct TilingKernelInfo {
-  // Tiling bounds.
-  absl::InlinedVector<llvm::Value*, 4> output_tile_bounds;
-
-  // Starting tile, as calculated from block id only.
-  llvm_ir::IrArray::Index tile_origin;
-
-  // Thread meta-info.
-  TilingThreadIdInfo thread_id_info;
-};
-
-// A function to generate the code to emit the entire tile.
-//
-// index: Absolute coordinate of the start of the tile in input.
-// tile_dimensions: Size of the tile
-using TileGenerator =
-    std::function<void(const TilingThreadIdInfo& thread_id_info,
-                       const llvm_ir::IrArray::Index& tile_start_index,
-                       absl::Span<llvm::Value* const> tile_dimensions)>;
-
-// A function object to generate code to process one element in a tile.
-//
-// index_in_tile: the current coordinates within the tile. To get the global
-// coordinates, use `tile_start_index.AddOffset(index_in_tile, ...)`.
-using TileElementGenerator =
-    std::function<void(absl::Span<llvm::Value* const> index_in_tile)>;
-
-// Emits code to iterate through a tile with given tile dimensions and generate
-// elements using the callback.
-void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling,
-              const TilingThreadIdInfo& thread_id_info,
-              absl::Span<llvm::Value* const> tile_dimensions,
-              const TileElementGenerator& emit_elem_function);
-
-// Emits a kernel for the hlo instruction using the given kernel mapping
-// scheme.
-absl::StatusOr<TilingKernelInfo> EmitTilingKernel(
-    llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty,
-    const TileGenerator& tile_element_generator);
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_TILING_UTIL_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/tools/BUILD b/third_party/xla/xla/service/gpu/fusions/tools/BUILD
new file mode 100644
index 0000000..2886ad1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tools/BUILD
@@ -0,0 +1,113 @@
+load("//xla:xla.bzl", "xla_cc_binary")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    licenses = ["notice"],
+)
+
+xla_cc_binary(
+    name = "mlir_fusions_opt",
+    srcs = ["mlir_fusions_opt.cc"],
+    visibility = ["//xla/service/gpu/fusions:__subpackages__"],
+    deps = [
+        "//xla/mlir_hlo",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
+        "//xla/service/gpu/fusions/transforms:passes",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AffineDialect",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:ComplexDialect",
+        "@llvm-project//mlir:DLTIDialect",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:FuncExtensions",
+        "@llvm-project//mlir:GPUDialect",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:MathDialect",
+        "@llvm-project//mlir:MlirOptLib",
+        "@llvm-project//mlir:NVVMDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+        "@llvm-project//mlir:VectorDialect",
+    ],
+)
+
+cc_library(
+    name = "test_lib",
+    testonly = 1,
+    srcs = ["test_lib.cc"],
+    hdrs = ["test_lib.h"],
+    deps = [
+        "//xla:status_macros",
+        "//xla/hlo/ir:hlo",
+        "//xla/mlir_hlo",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu/fusions",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
+        "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter",
+        "//xla/stream_executor:device_description",
+        "//xla/tools:hlo_module_loader",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@llvm-project//mlir:AffineDialect",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:ComplexDialect",
+        "@llvm-project//mlir:DLTIDialect",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:FuncExtensions",
+        "@llvm-project//mlir:GPUDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:MathDialect",
+        "@llvm-project//mlir:MlirOptLib",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:VectorDialect",
+    ],
+)
+
+xla_cc_binary(
+    name = "fusion_to_mlir",
+    testonly = 1,
+    srcs = ["fusion_to_mlir.cc"],
+    visibility = ["//xla/service/gpu/fusions:__subpackages__"],
+    deps = [
+        ":test_lib",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@llvm-project//llvm:Support",
+        "@local_tsl//tsl/platform:platform_port",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_binary(
+    name = "test_correctness",
+    testonly = 1,
+    srcs = ["test_correctness.cc"],
+    visibility = ["//xla/service/gpu/fusions:__subpackages__"],
+    deps = [
+        ":test_lib",
+        "//xla:debug_options_flags",
+        "//xla:error_spec",
+        "//xla:shape_util",
+        "//xla/service:gpu_plugin",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu/model:indexing_analysis",
+        "//xla/service/gpu/model:indexing_test_utils",
+        "//xla/tests:hlo_test_base",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest",
+        "@llvm-project//llvm:Support",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/tools/fusion_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/tools/fusion_to_mlir.cc
new file mode 100644
index 0000000..9fe41b6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tools/fusion_to_mlir.cc
@@ -0,0 +1,48 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string>
+
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "llvm/Support/raw_ostream.h"
+#include "xla/service/gpu/fusions/tools/test_lib.h"
+#include "tsl/platform/init_main.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+absl::Status Run(const std::string& filename) {
+  TF_ASSIGN_OR_RETURN(auto module, LoadTestModule(filename));
+  TF_ASSIGN_OR_RETURN(auto emitter_data, GetMlirFusionEmitter(*module));
+
+  auto context = GetMlirContextForTest();
+  TF_ASSIGN_OR_RETURN(auto mlir_module,
+                      emitter_data->emitter->CreateMLIRModule(
+                          context, *emitter_data->fusion, "main",
+                          /*buffer_assignment=*/nullptr));
+  llvm::outs() << *mlir_module;
+  return absl::OkStatus();
+}
+
+}  // namespace gpu
+}  // namespace xla
+
+int main(int argc, char** argv) {
+  tsl::port::InitMain(argv[0], &argc, &argv);
+  CHECK_EQ(argc, 2) << "Must specify an input file";
+  CHECK_OK(xla::gpu::Run(argv[1]));
+  return 0;
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc
new file mode 100644
index 0000000..7206cc1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc
@@ -0,0 +1,77 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "mlir/Transforms/Passes.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
+
+int main(int argc, char** argv) {
+  mlir::DialectRegistry registry;
+  registry.insert<mlir::DLTIDialect, mlir::tensor::TensorDialect,
+                  mlir::func::FuncDialect, mlir::affine::AffineDialect,
+                  mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
+                  mlir::math::MathDialect, mlir::scf::SCFDialect,
+                  mlir::mhlo::MhloDialect, mlir::LLVM::LLVMDialect,
+                  mlir::gpu::GPUDialect, mlir::vector::VectorDialect,
+                  xla::gpu::XlaGpuDialect,
+                  mlir::NVVM::NVVMDialect>();
+  mlir::func::registerAllExtensions(registry);
+  mlir::registerCanonicalizerPass();
+  mlir::registerCSEPass();
+  mlir::registerInliner();
+  xla::gpu::registerGpuFusionTransformsPasses();
+  mlir::registerPassPipeline(
+      "xla-gpu-test-to-inline",
+      "Test pipeline of passes up to inlining. No vectorization, also does not "
+      "lower xla_gpu. Intended to simplify IR in tests.",
+      [=](mlir::OpPassManager& pm, llvm::StringRef options,
+          llvm::function_ref<mlir::LogicalResult(const llvm::Twine&)>
+              errorHandler) {
+        if (!options.empty()) return mlir::failure();
+
+        pm.addNestedPass<mlir::func::FuncOp>(
+            xla::gpu::CreateSimplifyArithPass());
+        pm.addPass(xla::gpu::CreateEraseDeadFunctionsPass());
+        pm.addPass(mlir::createCSEPass());
+        pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) {
+          pm.addPass(mlir::createCSEPass());
+        }));
+        return mlir::success();
+      },
+      [](llvm::function_ref<void(const mlir::detail::PassOptions&)>) {});
+
+  return mlir::failed(
+      MlirOptMain(argc, argv, "XLA MLIR Fusion Pass Driver\n", registry));
+}
diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc
new file mode 100644
index 0000000..72529cd
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc
@@ -0,0 +1,192 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/STLExtras.h"
+#include "xla/debug_options_flags.h"
+#include "xla/error_spec.h"
+#include "xla/service/gpu/fusions/tools/test_lib.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/service/gpu/model/indexing_test_utils.h"
+#include "xla/shape.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/statusor.h"
+
+struct Flags {
+  std::string input_file = "";
+  float abs_error_bound = 1e-4;
+  float rel_error_bound = 1e-4;
+  std::vector<std::pair<std::string, std::vector<int64_t>>> bijection_inputs;
+  std::vector<std::string> bijection_outputs;
+};
+
+Flags& flags = *new Flags;
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using CorrectnessTest = HloTestBase;
+
+const Shape& GetFirstArrayShape(const Shape& shape) {
+  if (shape.IsArray()) {
+    return shape;
+  }
+  CHECK(shape.IsTuple());
+  return GetFirstArrayShape(shape.tuple_shapes(0));
+}
+
+absl::Status TestBijection(const IndexingMap& map,
+                           absl::Span<int64_t const> shape) {
+  std::vector<Interval> intervals;
+  for (int64_t size : shape) {
+    intervals.push_back({0, size - 1});
+  }
+  auto status = VerifyBijection(map, intervals);
+  if (status.ok()) return status;
+  return absl::FailedPreconditionError(
+      absl::StrCat(status.message(), " in map ", map.ToString()));
+}
+
+TEST_F(CorrectnessTest, RunAndCompare) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, LoadTestModule(flags.input_file));
+  EXPECT_TRUE(RunAndCompareNoHloPasses(
+      std::move(module),
+      ErrorSpec{flags.abs_error_bound, flags.rel_error_bound}));
+}
+
+absl::StatusOr<int64_t> GetHeroIndex(absl::string_view name,
+                                     const HloFusionAnalysis& analysis) {
+  for (auto [index, hero] : llvm::enumerate(analysis.fusion_heroes())) {
+    if (hero.name() == name) {
+      return index;
+    }
+  }
+  return absl::NotFoundError(absl::StrCat("Hero ", name, " not found"));
+}
+
+std::pair<std::string, std::vector<int64_t>> ParseHeroAndIds(
+    absl::string_view hero_and_ids) {
+  std::pair<absl::string_view, absl::string_view> hero_and_ids_pair =
+      absl::StrSplit(hero_and_ids, ':');
+  std::vector<int64_t> ids;
+  for (absl::string_view id : absl::StrSplit(hero_and_ids_pair.second, ',')) {
+    ids.push_back(std::stoi(std::string(absl::StripAsciiWhitespace(id))));
+  }
+  return {std::string(absl::StripAsciiWhitespace(hero_and_ids_pair.first)),
+          ids};
+}
+
+TEST_F(CorrectnessTest, InputIndexingIsBijection) {
+  auto context = GetMlirContextForTest();
+  TF_ASSERT_OK_AND_ASSIGN(auto module, LoadTestModule(flags.input_file));
+  TF_ASSERT_OK_AND_ASSIGN(auto emitter_data, GetMlirFusionEmitter(*module));
+  for (const auto& [hero_name, ids] : flags.bijection_inputs) {
+    TF_ASSERT_OK_AND_ASSIGN(int64_t hero_index,
+                            GetHeroIndex(hero_name, *emitter_data->analysis));
+    for (int64_t id : ids) {
+      auto indexing = emitter_data->emitter->ComputeThreadIdToInputIndexing(
+          hero_index, id, &context);
+      ASSERT_TRUE(indexing.has_value());
+      TF_ASSERT_OK(TestBijection(*indexing,
+                                 emitter_data->analysis->fusion_hero(hero_index)
+                                     .GetOperand(id)
+                                     .shape()
+                                     .dimensions()))
+          << "Expected operand " << id << " of " << hero_name << " (root index "
+          << hero_index << ") to be read exactly once.";
+    }
+  }
+}
+
+TEST_F(CorrectnessTest, OutputIndexingIsBijection) {
+  auto context = GetMlirContextForTest();
+  TF_ASSERT_OK_AND_ASSIGN(auto module, LoadTestModule(flags.input_file));
+  TF_ASSERT_OK_AND_ASSIGN(auto emitter_data, GetMlirFusionEmitter(*module));
+  for (const auto& hero_name : flags.bijection_outputs) {
+    TF_ASSERT_OK_AND_ASSIGN(int64_t hero_index,
+                            GetHeroIndex(hero_name, *emitter_data->analysis));
+    auto indexing = emitter_data->emitter->ComputeThreadIdToOutputIndexing(
+        hero_index, &context);
+    ASSERT_TRUE(indexing.has_value());
+    TF_ASSERT_OK(TestBijection(
+        *indexing, GetFirstArrayShape(
+                       emitter_data->analysis->fusion_root(hero_index).shape())
+                       .dimensions()))
+        << "Expected output of " << hero_name << " (root index " << hero_index
+        << ") to be written exactly once.";
+  }
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
+
+int main(int argc, char* argv[]) {
+  std::vector<tsl::Flag> flag_list = {
+      tsl::Flag("abs_error_bound", &flags.abs_error_bound,
+                "Absolute error bound."),
+      tsl::Flag("rel_error_bound", &flags.rel_error_bound,
+                "Relative error bound."),
+      tsl::Flag(
+          "bijection_inputs",
+          [](std::string name_and_ids) {
+            if (name_and_ids.empty()) return false;
+            flags.bijection_inputs.push_back(
+                xla::gpu::ParseHeroAndIds(name_and_ids));
+            return true;
+          },
+          "",
+          "The name of a hero followed by operand ids that should be read "
+          "exactly once, i.e. there's a bijection between a subset of threads "
+          "and the input shape. Example: 'reduction0: 0, 1'."),
+      tsl::Flag(
+          "bijection_outputs",
+          [](std::string name) {
+            if (name.empty()) return false;
+            flags.bijection_outputs.push_back(name);
+            return true;
+          },
+          "",
+          "The name of a hero whose outputs should be written exactly once, "
+          "i.e. there's a bijection between a subset of threads and the output "
+          "shape.")};
+
+  xla::AppendDebugOptionsFlags(&flag_list);
+  std::string usage = tsl::Flags::Usage(argv[0], flag_list);
+  bool parseResult = tsl::Flags::Parse(&argc, argv, flag_list);
+  if (!parseResult || argc != 2) {
+    LOG(ERROR) << "\n" << usage;
+    return 1;
+  }
+
+  flags.input_file = argv[1];
+  testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
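
For reference, this binary is what the `// RUN: test_correctness ...` lines in the .hlo tests above
execute. A representative invocation, copied from variadic_f32.hlo (lit substitutes %s with the test
file), is:

  test_correctness %s --bijection_inputs=reduce:0,1 --bijection_outputs=reduce

ParseHeroAndIds turns the `reduce:0,1` argument into the pair ("reduce", {0, 1}); the two bijection
test cases then check that the emitter's thread-to-input and thread-to-output indexing maps touch
each element of the named operands and outputs exactly once.
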
diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc
new file mode 100644
index 0000000..11b82dd
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc
@@ -0,0 +1,118 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/fusions/tools/test_lib.h"
+
+#include <memory>
+
+#include "absl/algorithm/container.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/status_macros.h"
+#include "xla/tools/hlo_module_loader.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<std::unique_ptr<HloModule>> LoadTestModule(
+    absl::string_view filename) {
+  auto module = *xla::LoadModuleFromFile(std::string(filename));
+  module->mutable_config()
+      .mutable_debug_options()
+      .set_xla_gpu_mlir_emitter_level(4);
+
+  int num_fusions = absl::c_count_if(
+      module->entry_computation()->instructions(),
+      [](const HloInstruction* instruction) {
+        return instruction->opcode() == xla::HloOpcode::kFusion;
+      });
+  TF_RET_CHECK(num_fusions <= 1) << "HLO must contain at most one fusion";
+
+  if (num_fusions == 0) {
+    // Generate a fusion from the entry computation.
+    HloComputation::Builder builder("generated_main");
+    std::vector<HloInstruction*> params;
+    for (const auto* param :
+         module->entry_computation()->parameter_instructions()) {
+      params.push_back(*builder.AddParameter(param->Clone(/*suffix=*/"")));
+    }
+    builder.AddInstruction(HloInstruction::CreateFusion(
+        module->entry_computation()->root_instruction()->shape(),
+        HloInstruction::FusionKind::kLoop /* irrelevant */, params,
+        module->entry_computation()));
+
+    auto* new_entry = module->AddComputationAndUnifyNamesAndIds(
+        builder.Build(), /*is_entry=*/false);
+    module->ReplaceEntryComputation(new_entry);
+  }
+
+  return module;
+}
+
+absl::StatusOr<std::unique_ptr<EmitterData>> GetMlirFusionEmitter(
+    const HloModule& module) {
+  auto data = std::make_unique<EmitterData>();
+  data->fusion = DynCast<HloFusionInstruction>(
+      module.entry_computation()->root_instruction());
+  TF_RET_CHECK(data->fusion != nullptr) << "Root instruction must be a fusion";
+  data->device.emplace(TestGpuDeviceInfo::RTXA6000DeviceInfo());
+  data->analysis.emplace(
+      HloFusionAnalysis::Create(*data->fusion, data->device.value()));
+  PreBufferAssignmentFusionInfo info(data->analysis.value());
+  auto emitter = GetFusionEmitter(info);
+
+  auto mlir_emitter = dynamic_cast<MlirFusionEmitterBase*>(emitter.get());
+  TF_RET_CHECK(mlir_emitter != nullptr)
+      << "Expected emitter to be an MlirFusionEmitter";
+
+  emitter.release();
+  data->emitter.reset(mlir_emitter);
+  return data;
+}
+
+mlir::MLIRContext GetMlirContextForTest() {
+  mlir::DialectRegistry registry;
+  registry.insert<mlir::DLTIDialect, mlir::tensor::TensorDialect,
+                  mlir::func::FuncDialect, mlir::affine::AffineDialect,
+                  mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
+                  mlir::math::MathDialect, mlir::scf::SCFDialect,
+                  mlir::mhlo::MhloDialect, mlir::gpu::GPUDialect,
+                  mlir::vector::VectorDialect, XlaGpuDialect>();
+  return mlir::MLIRContext(registry);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.h b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.h
new file mode 100644
index 0000000..5dfa300
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.h
@@ -0,0 +1,58 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_TOOLS_TEST_LIB_H_
+#define XLA_SERVICE_GPU_FUSIONS_TOOLS_TEST_LIB_H_
+
+#include <memory>
+#include <optional>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+
+namespace gpu {
+
+// Loads a test module from the given filename, ensuring it has a single fusion.
+// If the file contains more than one fusion, the function fails. If the file
+// contains no fusions, the function generates a fusion from the entry
+// computation.
+absl::StatusOr<std::unique_ptr<HloModule>> LoadTestModule(
+    absl::string_view filename);
+
+// Returns the MLIR fusion emitter for the given module, which should have been
+// loaded using LoadTestModule.
+struct EmitterData {
+  HloFusionInstruction* fusion;
+  std::optional<se::DeviceDescription> device;
+  std::optional<HloFusionAnalysis> analysis;
+  std::unique_ptr<MlirFusionEmitterBase> emitter;
+};
+absl::StatusOr<std::unique_ptr<EmitterData>> GetMlirFusionEmitter(
+    const HloModule& module);
+
+// Returns an MLIR context with all the dialects needed for testing.
+mlir::MLIRContext GetMlirContextForTest();
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_TOOLS_TEST_LIB_H_
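
A minimal sketch of how these helpers fit together, mirroring fusion_to_mlir.cc above (error handling
and flag parsing elided; the function name EmitFusionToMlir is illustrative, not part of the library):

  #include <string>

  #include "absl/status/status.h"
  #include "llvm/Support/raw_ostream.h"
  #include "xla/service/gpu/fusions/tools/test_lib.h"
  #include "tsl/platform/statusor.h"

  absl::Status EmitFusionToMlir(const std::string& filename) {
    // Load the HLO file; a fusion is synthesized if the module has none.
    TF_ASSIGN_OR_RETURN(auto module, xla::gpu::LoadTestModule(filename));
    // Build the emitter plus its fusion analysis and device description.
    TF_ASSIGN_OR_RETURN(auto data, xla::gpu::GetMlirFusionEmitter(*module));
    auto context = xla::gpu::GetMlirContextForTest();
    TF_ASSIGN_OR_RETURN(auto mlir_module,
                        data->emitter->CreateMLIRModule(
                            context, *data->fusion, "main",
                            /*buffer_assignment=*/nullptr));
    llvm::outs() << *mlir_module;
    return absl::OkStatus();
  }
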
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD
new file mode 100644
index 0000000..3cf1d26
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD
@@ -0,0 +1,105 @@
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = [":friends"],
+    licenses = ["notice"],
+)
+
+package_group(
+    name = "friends",
+    includes = [
+        "//xla:friends",
+    ],
+)
+
+gentbl_cc_library(
+    name = "passes_inc_gen",
+    tbl_outs = [
+        (
+            [
+                "-gen-pass-decls",
+                "-name=GpuFusionTransforms",
+            ],
+            "passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "passes.td",
+    visibility = ["//visibility:private"],
+    deps = ["@llvm-project//mlir:PassBaseTdFiles"],
+)
+
+cc_library(
+    name = "passes",
+    srcs = [
+        "convert_xla_gpu_pure_call_ops.cc",
+        "erase_dead_functions.cc",
+        "expand_float_ops.cc",
+        "flatten_tensors.cc",
+        "lower_tensors.cc",
+        "lower_to_llvm.cc",
+        "lower_xla_gpu_to_scf.cc",
+        "merge_pointers_to_same_slice.cc",
+        "optimize_loops.cc",
+        "peel_loops.cc",
+        "propagate_slice_indices.cc",
+        "simplify_affine.cc",
+        "simplify_arith.cc",
+        "unswitch_loops.cc",
+        "vectorize_loads_stores.cc",
+    ],
+    hdrs = ["passes.h"],
+    deps = [
+        ":passes_inc_gen",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/mlir_hlo",
+        "//xla/mlir_hlo:map_mhlo_to_scalar_op",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
+        "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
+        "//xla/service/gpu/model:indexing_analysis",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AffineDialect",
+        "@llvm-project//mlir:AffineToStandard",
+        "@llvm-project//mlir:AffineUtils",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:ArithToLLVM",
+        "@llvm-project//mlir:ArithTransforms",
+        "@llvm-project//mlir:CallOpInterfaces",
+        "@llvm-project//mlir:ComplexDialect",
+        "@llvm-project//mlir:ComplexToLLVM",
+        "@llvm-project//mlir:ControlFlowToLLVM",
+        "@llvm-project//mlir:DataLayoutInterfaces",
+        "@llvm-project//mlir:DialectUtils",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:FuncToLLVM",
+        "@llvm-project//mlir:GPUDialect",
+        "@llvm-project//mlir:GPUToNVVMTransforms",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LLVMCommonConversion",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:MathDialect",
+        "@llvm-project//mlir:MathToLLVM",
+        "@llvm-project//mlir:MathTransforms",
+        "@llvm-project//mlir:NVVMDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:SCFToControlFlow",
+        "@llvm-project//mlir:SCFUtils",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:VectorDialect",
+        "@llvm-project//mlir:VectorToLLVM",
+        "@llvm-project//mlir:VectorTransforms",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc
new file mode 100644
index 0000000..0c9053a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc
@@ -0,0 +1,61 @@
+/* Copyright 2024 The OpenXLA Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <utility>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+#define GEN_PASS_DEF_CONVERTPURECALLOPSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+struct RewriteCall : mlir::OpRewritePattern<PureCallOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      PureCallOp op, mlir::PatternRewriter& rewriter) const override {
+    rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
+        op, op.getResultTypes(), op.getOperands(), op->getAttrs());
+    return mlir::success();
+  }
+};
+
+class ConvertPureCallOpsPass
+    : public impl::ConvertPureCallOpsPassBase<ConvertPureCallOpsPass> {
+ public:
+  void runOnOperation() override {
+    auto* ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
+    patterns.add<RewriteCall>(ctx);
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                                        std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<::mlir::Pass> CreateConvertPureCallOpsPass() {
+  return std::make_unique<ConvertPureCallOpsPass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc b/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc
new file mode 100644
index 0000000..3918a19
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc
@@ -0,0 +1,86 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <queue>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_ERASEDEADFUNCTIONSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+struct CallInfo {
+  PureCallOp call;
+  int count;
+};
+
+llvm::DenseSet<mlir::func::FuncOp> FindLiveFunctions(mlir::ModuleOp module) {
+  std::queue<mlir::func::FuncOp> worklist;
+  llvm::DenseSet<mlir::func::FuncOp> live_funcs;
+  module.walk([&](mlir::func::FuncOp func) {
+    if (!func.isPrivate()) {
+      worklist.push(func);
+      live_funcs.insert(func);
+    }
+  });
+
+  mlir::SymbolTableCollection symbol_table;
+  while (!worklist.empty()) {
+    auto func = worklist.front();
+    worklist.pop();
+    func.walk([&](mlir::CallOpInterface call) {
+      auto callee =
+          mlir::cast<mlir::func::FuncOp>(call.resolveCallable(&symbol_table));
+      if (live_funcs.insert(callee).second) {
+        worklist.push(callee);
+      }
+    });
+  }
+  return live_funcs;
+}
+
+class EraseDeadFunctionsPass
+    : public impl::EraseDeadFunctionsPassBase<EraseDeadFunctionsPass> {
+ public:
+  void runOnOperation() override {
+    // Find live functions and erase dead ones.
+    auto live = FindLiveFunctions(getOperation());
+    getOperation().walk([&](mlir::func::FuncOp func) {
+      if (!live.contains(func)) {
+        func.erase();
+      }
+    });
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateEraseDeadFunctionsPass() {
+  return std::make_unique<EraseDeadFunctionsPass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc
new file mode 100644
index 0000000..2274576
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc
@@ -0,0 +1,675 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "absl/log/check.h"
+#include "llvm/ADT/APFloat.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+
+namespace ma = ::mlir::arith;
+
+using ma::SelectOp;
+using mlir::Value;
+
+#define GEN_PASS_DEF_EXPANDFLOATOPSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+// Wraps a Value to provide operator overloading for more readable expressions.
+struct Val {
+  Value value;
+  mlir::ImplicitLocOpBuilder* b;
+
+  operator Value() const { return value; }  // NOLINT
+
+  Val operator+(int64_t rhs) const { return Binop<ma::AddIOp>(rhs); }
+  Val operator+(Value rhs) const { return Binop<ma::AddIOp>(rhs); }
+  Val operator-(int64_t rhs) const { return Binop<ma::SubIOp>(rhs); }
+  Val operator-(Value rhs) const { return Binop<ma::SubIOp>(rhs); }
+  Val operator*(int64_t rhs) const { return Binop<ma::MulIOp>(rhs); }
+  Val operator*(Value rhs) const { return Binop<ma::MulIOp>(rhs); }
+  Val operator&(Value rhs) const { return Binop<ma::AndIOp>(rhs); }
+  Val operator&(int64_t rhs) const { return Binop<ma::AndIOp>(rhs); }
+  Val operator|(Value rhs) const { return Binop<ma::OrIOp>(rhs); }
+  Val operator|(int64_t rhs) const { return Binop<ma::OrIOp>(rhs); }
+  Val operator^(Value rhs) const { return Binop<ma::XOrIOp>(rhs); }
+  Val shl(Value rhs) const { return Binop<ma::ShLIOp>(rhs); }
+  Val shl(int64_t rhs) const { return Binop<ma::ShLIOp>(rhs); }
+  Val shrui(Value rhs) const { return Binop<ma::ShRUIOp>(rhs); }
+  Val shrui(int64_t rhs) const { return Binop<ma::ShRUIOp>(rhs); }
+
+  Val cmp(ma::CmpIPredicate pred, Value rhs) const {
+    return {b->create<ma::CmpIOp>(pred, value, rhs), b};
+  }
+  Val cmp(ma::CmpIPredicate pred, int64_t rhs) const {
+    return cmp(pred, MakeConstant(rhs));
+  }
+  Val operator==(Value rhs) const { return cmp(ma::CmpIPredicate::eq, rhs); }
+  Val operator==(int64_t rhs) const { return cmp(ma::CmpIPredicate::eq, rhs); }
+  Val operator!=(int64_t rhs) const { return cmp(ma::CmpIPredicate::ne, rhs); }
+
+  Val MakeConstant(int64_t c) const {
+    return {b->create<ma::ConstantIntOp>(c, value.getType()), b};
+  }
+
+ private:
+  template <typename Op>
+  Val Binop(Value rhs) const {
+    return {b->create<Op>(value, rhs), b};
+  }
+
+  template <typename Op>
+  Val Binop(int64_t rhs) const {
+    return Binop<Op>(MakeConstant(rhs));
+  }
+};
+
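+// Rewrites arith.minimumf/maximumf into an explicit compare + select that
+// propagates NaNs. f32 ops are only rewritten when `include_f32` is set.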
+template <typename OpTy, ma::CmpFPredicate pred>
+struct RewriteToCmpSelect : public mlir::OpRewritePattern<OpTy> {
+  using mlir::OpRewritePattern<OpTy>::OpRewritePattern;
+
+  RewriteToCmpSelect(mlir::MLIRContext* context, bool include_f32)
+      : mlir::OpRewritePattern<OpTy>(context), include_f32(include_f32) {}
+
+  mlir::LogicalResult matchAndRewrite(
+      OpTy op, mlir::PatternRewriter& rewriter) const override {
+    if (op.getType().isF32() && !include_f32) {
+      return rewriter.notifyMatchFailure(op, "not rewriting f32 min/max");
+    }
+
+    auto lhs_is_nan = rewriter.create<ma::CmpFOp>(
+        op.getLoc(), ma::CmpFPredicate::UNE, op.getLhs(), op.getLhs());
+    auto rhs_is_not_nan = rewriter.create<ma::CmpFOp>(
+        op.getLoc(), ma::CmpFPredicate::OEQ, op.getRhs(), op.getRhs());
+
+    auto return_lhs =
+        rewriter.create<ma::CmpFOp>(op.getLoc(), pred, op.getLhs(), op.getRhs())
+            .getResult();
+
+    // Logic: isNaN(lhs) || (!isNaN(rhs) && return_lhs) ? lhs : rhs
+    return_lhs = rewriter.create<ma::OrIOp>(
+        op.getLoc(), lhs_is_nan,
+        rewriter.create<ma::AndIOp>(op.getLoc(), rhs_is_not_nan, return_lhs));
+
+    rewriter.replaceOpWithNewOp<SelectOp>(op, op.getResult().getType(),
+                                          return_lhs, op.getLhs(), op.getRhs());
+    return mlir::success();
+  }
+
+  bool include_f32;
+};
+
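+// Expands f32 math.erf into a rational polynomial approximation of the
+// clamped input.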
+struct RewriteErf32Pattern : public mlir::OpRewritePattern<mlir::math::ErfOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      mlir::math::ErfOp op, mlir::PatternRewriter& rewriter) const override {
+    if (!op.getType().isF32()) {
+      return rewriter.notifyMatchFailure(op, "not an f32 erf");
+    }
+
+    static const std::array<float, 5> kAlpha{
+        0.00022905065861350646f, 0.0034082910107109506f, 0.050955695062380861f,
+        0.18520832239976145f, 1.128379143519084f};
+
+    static const std::array<float, 7> kBeta{-1.1791602954361697e-7,
+                                            0.000023547966471313185f,
+                                            0.0010179625278914885f,
+                                            0.014070470171167667f,
+                                            0.11098505178285362f,
+                                            0.49746925110067538f,
+                                            1.0f};
+
+    // We clamp x to [-c;c] where c = erfinv(1-2^-23); outside this range,
+    // erf(x) rounds to +/-1 in f32 anyway.
+    constexpr float kErfInvOneMinusHalfULP = 3.7439211627767994f;
+
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto c = [&](float v) -> Value {
+      return b.create<ma::ConstantFloatOp>(llvm::APFloat(v),
+                                           rewriter.getF32Type());
+    };
+
+    auto poly = [&](auto x, auto coefficients) -> Value {
+      auto r = c(coefficients[0]);
+      for (int i = 1; i < coefficients.size(); ++i) {
+        r = b.create<mlir::math::FmaOp>(r, x, c(coefficients[i]));
+      }
+      return r;
+    };
+
+    Value x = op.getOperand();
+    x = b.create<ma::MaximumFOp>(x, c(-kErfInvOneMinusHalfULP));
+    x = b.create<ma::MinimumFOp>(x, c(kErfInvOneMinusHalfULP));
+    Value x2 = b.create<ma::MulFOp>(x, x);
+
+    rewriter.replaceOpWithNewOp<ma::DivFOp>(
+        op, b.create<ma::MulFOp>(x, poly(x2, kAlpha)), poly(x2, kBeta));
+
+    return mlir::success();
+  }
+};
+
+int GetSignificandBits(mlir::FloatType ty) {
+  return llvm::APFloat::semanticsPrecision(ty.getFloatSemantics()) - 1;
+}
+
+int GetExponentBias(mlir::FloatType ty) {
+  return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics());
+}
+
+Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) {
+  auto ty = mlir::cast<mlir::FloatType>(value.getType());
+  if (mlir::LLVM::isCompatibleOuterType(ty)) {
+    value = b.create<mlir::math::AbsFOp>(value);
+    Value inf = b.create<ma::ConstantFloatOp>(
+        llvm::APFloat::getInf(ty.getFloatSemantics()), ty);
+    return b.create<ma::CmpFOp>(ma::CmpFPredicate::OEQ, value, inf);
+  }
+
+  assert(ty.getIntOrFloatBitWidth() == 8);
+  if (!ty.isFloat8E5M2()) {
+    // F8E5M2 is the only 8 bit float with infinities.
+    return b.create<ma::ConstantIntOp>(false, b.getI1Type());
+  }
+  Val bits{b.create<ma::BitcastOp>(b.getI8Type(), value), &b};
+  return (bits & 0x7F) == 0x7C;
+}
+
+Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) {
+  auto ty = value.getType();
+  if (mlir::LLVM::isCompatibleOuterType(ty)) {
+    return b.create<ma::CmpFOp>(ma::CmpFPredicate::UNO, value, value);
+  }
+
+  assert(ty.getIntOrFloatBitWidth() == 8);
+  Val bits{b.create<ma::BitcastOp>(b.getI8Type(), value), &b};
+  if (ty.isFloat8E5M2() || ty.isFloat8E4M3FN()) {
+    return (bits & 0x7F) == 0x7F;
+  }
+  return bits == 0x80;
+}
+
+Value EmitReducePrecision(Value value, int exponent_bits, int mantissa_bits,
+                          mlir::ImplicitLocOpBuilder& b) {
+  mlir::mhlo::ReducePrecisionOp::Properties properties;
+  properties.exponent_bits = b.getI32IntegerAttr(exponent_bits);
+  properties.mantissa_bits = b.getI32IntegerAttr(mantissa_bits);
+  return mlir::mhlo::MhloOpToStdScalarOp::mapOpOfType<
+      mlir::mhlo::ReducePrecisionOp>(
+      b.getLoc(), value.getType(), {value.getType()},
+      mlir::mhlo::ReducePrecisionOp::Adaptor(value, nullptr, properties), &b);
+}
+
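+// Fast f16 -> f8e5m2 conversion: reduce precision, then keep the upper byte
+// of the f16 bits. NaN inputs are mapped to a canonical NaN so truncation
+// cannot turn them into an inf.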
+Value EmitF16ToF8e5m2(Value in, mlir::ImplicitLocOpBuilder& b) {
+  Val in_bits{b.create<ma::BitcastOp>(b.getI16Type(), in), &b};
+  // Use this method of checking for NaN because it's the same as what's used
+  // in the reduce precision lowering.
+  Value is_nan = (in_bits & 32767).cmp(ma::CmpIPredicate::ugt, 31744);
+
+  Value value = EmitReducePrecision(in, 5, 2, b);
+  value = b.create<ma::BitcastOp>(b.getI16Type(), value);
+  value = b.create<ma::ShRUIOp>(value,
+                                b.create<ma::ConstantIntOp>(8, b.getI16Type()));
+  value = b.create<ma::TruncIOp>(b.getI8Type(), value);
+  // When the input is NaN, just truncating can turn a NaN into an inf if the
+  // mantissa becomes 0.
+  value = b.create<ma::SelectOp>(
+      is_nan, b.create<ma::ConstantIntOp>(0x7F, value.getType()), value);
+  return b.create<ma::BitcastOp>(b.getFloat8E5M2Type(), value);
+}
+
+Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
+                          mlir::ImplicitLocOpBuilder& b) {
+  using ma::CmpIPredicate;
+
+  // This is a port of ConvertImpl in
+  // https://github.com/jax-ml/ml_dtypes/blob/main/ml_dtypes/include/float8.h
+  auto from_ty = mlir::cast<mlir::FloatType>(value.getType());
+  if (to_ty == b.getFloat8E5M2Type() && from_ty == b.getF16Type()) {
+    return EmitF16ToF8e5m2(value, b);
+  }
+
+  int from_mantissa = GetSignificandBits(from_ty);
+  int from_bias = GetExponentBias(from_ty);
+  int from_min_exp =
+      llvm::APFloat::semanticsMinExponent(from_ty.getFloatSemantics());
+  int from_max_exp =
+      llvm::APFloat::semanticsMaxExponent(from_ty.getFloatSemantics());
+  auto from_int_ty = b.getIntegerType(from_ty.getIntOrFloatBitWidth());
+
+  int to_mantissa = GetSignificandBits(to_ty);
+  int to_bias = GetExponentBias(to_ty);
+  int to_min_exp =
+      llvm::APFloat::semanticsMinExponent(to_ty.getFloatSemantics());
+  int to_max_exp =
+      llvm::APFloat::semanticsMaxExponent(to_ty.getFloatSemantics());
+  auto to_int_ty = b.getIntegerType(to_ty.getIntOrFloatBitWidth());
+
+  mlir::IntegerType wide_int_ty;
+  if (from_ty.getWidth() == 8 && to_ty.getWidth() == 8) {
+    wide_int_ty = b.getI16Type();
+  } else {
+    wide_int_ty = b.getIntegerType(
+        std::max(from_int_ty.getWidth(), to_int_ty.getWidth()));
+  }
+  auto convert_int = [&](mlir::Type ty, Value v) -> Val {
+    if (v.getType() == ty) {
+      return {v, &b};
+    }
+    if (ty.getIntOrFloatBitWidth() < v.getType().getIntOrFloatBitWidth()) {
+      return {b.create<ma::TruncIOp>(ty, v), &b};
+    }
+    return {b.create<ma::ExtUIOp>(ty, v), &b};
+  };
+
+  int64_t exp_offset = to_bias - from_bias;
+  int digit_shift = to_mantissa - from_mantissa;
+
+  Val from_bits{
+      b.create<ma::BitcastOp>(
+          b.getIntegerType(value.getType().getIntOrFloatBitWidth()), value),
+      &b};
+
+  auto cst = [&](mlir::Type ty, int64_t n) -> Val {
+    return {b.create<ma::ConstantIntOp>(n, ty), &b};
+  };
+
+  // Shift bits to destination type, without sign bit.
+  Val from_sign_bit =
+      from_bits.shrui(value.getType().getIntOrFloatBitWidth() - 1) != 0;
+
+  from_bits =
+      from_bits & ((1ULL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1);
+
+  Value result_is_inf = IsInf(value, b);
+  Value input_is_nan = IsNaN(value, b);
+
+  auto cst_bits = [&](llvm::APFloat f) {
+    return cst(b.getIntegerType(llvm::APFloat::getSizeInBits(f.getSemantics())),
+               f.bitcastToAPInt().getZExtValue());
+  };
+  Value to_inf = cst_bits(llvm::APFloat::getInf(to_ty.getFloatSemantics()));
+  Value to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics()));
+  Val to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics()));
+
+  auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) {
+    assert(bits.value.getType() == roundoff.value.getType());
+    // Round to nearest even by adding a bias term.
+    // Consider a bit pattern
+    //   FFF...FLRTT...T,
+    // where bits RTT...T need to be rounded-off.  We add a bias term to the
+    // bit pattern s.t. a carry is introduced to round up only if
+    // - L is 1, R is 1, OR
+    // - L is 0, R is 1, any T is one.
+    // We do this by adding L to a bit pattern consisting of all T = 1.
+    Val rounded = (bits.shrui(roundoff) & 1) +
+                  (bits.MakeConstant(1).shl(roundoff - 1) - 1);
+    Val bias{b.create<SelectOp>(roundoff == 0, roundoff, rounded), &b};
+    return bits + bias;
+  };
+
+  // Happy path: no subnormals, infinities or NaNs.
+  Value result;
+  {
+    // Round the mantissa if it is shrinking.
+    Val rounded_from_bits = convert_int(wide_int_ty, from_bits);
+    if (digit_shift < 0) {
+      rounded_from_bits = round_bits_to_nearest_even(
+                              from_bits, from_bits.MakeConstant(-digit_shift)) &
+                          ~((1ll << (-digit_shift)) - 1);
+    }
+
+    // Re-bias the exponent.
+    rounded_from_bits = rounded_from_bits + (exp_offset << from_mantissa);
+
+    // Check for overflows by aligning the significands. We always align the
+    // narrower significand to the wider significand.
+    int64_t to_highest = llvm::APFloat::getLargest(to_ty.getFloatSemantics())
+                             .bitcastToAPInt()
+                             .getZExtValue();
+    int64_t aligned_highest = to_highest;
+    if (digit_shift < 0) {
+      aligned_highest <<= -digit_shift;
+      // Shift down, all dropped bits should already be zero.
+      result = rounded_from_bits.shrui(-digit_shift);
+    } else {
+      // Shift up, inserting zeros in the newly created digits.
+      rounded_from_bits = rounded_from_bits.shl(digit_shift);
+      result = rounded_from_bits;
+    }
+    result = convert_int(to_int_ty, result);
+
+    // `From` supports larger values than `To`, so we may overflow.
+    if (std::make_pair(to_max_exp, to_mantissa) <
+        std::make_pair(from_max_exp, from_mantissa)) {
+      result = b.create<SelectOp>(
+          rounded_from_bits.cmp(CmpIPredicate::ugt, aligned_highest), to_inf,
+          result);
+    }
+  }
+
+  auto i32_ty = b.getI32Type();
+  Val biased_from_exp = convert_int(i32_ty, from_bits.shrui(from_mantissa));
+
+  if (to_min_exp < from_min_exp) {
+    // `To` supports more exponents near zero which means that some subnormal
+    // values in `From` may become normal.
+
+    // Subnormals.
+    Val bits = convert_int(wide_int_ty, from_bits);
+
+    // Determine exponent in target type.
+    Value normalization_factor =
+        convert_int(i32_ty,
+                    b.create<mlir::math::CountLeadingZerosOp>(from_bits)) -
+        (from_int_ty.getWidth() - from_mantissa - 1);
+
+    Val biased_exponent = cst(i32_ty, exp_offset + 1) - normalization_factor;
+    // If the result is subnormal, adjust the subnormal bits to account for
+    // the difference in exponent bias.
+    Value subnormal_bits = bits;
+    if (exp_offset < wide_int_ty.getWidth()) {
+      subnormal_bits = bits.shl(exp_offset);
+    }
+
+    // Result is normal. Shift the mantissa to account for the number of
+    // leading zero digits, and clear the hidden bit.
+    // Insert the exponent bits.
+    Value normal_bits =
+        (bits.shl(convert_int(wide_int_ty, normalization_factor)) &
+         ~(1 << from_mantissa)) |
+        convert_int(wide_int_ty, biased_exponent).shl(from_mantissa);
+
+    Value biased_exp_sle_zero = biased_exponent.cmp(CmpIPredicate::sle, 0);
+    bits.value =
+        b.create<SelectOp>(biased_exp_sle_zero, subnormal_bits, normal_bits);
+    if (digit_shift > 0) {
+      bits = bits.shl(digit_shift);
+    } else {
+      bits = round_bits_to_nearest_even(bits, bits.MakeConstant(-digit_shift));
+      bits = bits.shrui(-digit_shift);
+    }
+    bits = convert_int(to_int_ty, bits);
+
+    result = b.create<SelectOp>(biased_from_exp == 0, bits, result);
+  } else if (to_min_exp > from_min_exp) {
+    // `To` supports fewer exponents near zero which means that some values in
+    // `From` may become subnormal.
+    Val unbiased_exp = biased_from_exp - from_bias;
+    Val biased_to_exp = unbiased_exp + to_bias;
+    // Subnormals and zero.
+    // Round and shift mantissa down.
+    Val from_has_leading_one = biased_from_exp != 0;
+    Val from_has_leading_one_i32 = convert_int(i32_ty, from_has_leading_one);
+    from_has_leading_one = convert_int(from_int_ty, from_has_leading_one);
+    Val exponent_shift_i32 =
+        (from_has_leading_one_i32 - biased_to_exp) - digit_shift;
+    // Insert the implicit leading 1 bit on the mantissa for normalized
+    // inputs.
+    Val rounded_from_bits = (from_bits & ((1ll << from_mantissa) - 1)) |
+                            from_has_leading_one.shl(from_mantissa);
+
+    // NOTE: we need to round again from the original from_bits,
+    // otherwise the lower precision bits may already be lost.  There is
+    // an edge-case where rounding to a normalized value would normally
+    // round down, but for a subnormal, we need to round up.
+    Val exponent_shift_from_ty = convert_int(from_int_ty, exponent_shift_i32);
+    Val exponent_shift_to_ty = convert_int(to_int_ty, exponent_shift_i32);
+    Val positive_bits = convert_int(
+        to_int_ty,
+        round_bits_to_nearest_even(rounded_from_bits, exponent_shift_from_ty)
+            .shrui(exponent_shift_from_ty));
+    // To avoid UB, limit rounding and shifting to the full mantissa plus
+    // leading 1.
+    positive_bits.value = b.create<SelectOp>(
+        exponent_shift_i32.cmp(CmpIPredicate::sle, from_mantissa + 1),
+        positive_bits, to_zero);
+
+    Val negative_bits = convert_int(to_int_ty, rounded_from_bits)
+                            .shl(to_zero - exponent_shift_to_ty);
+    Value bits =
+        b.create<SelectOp>(exponent_shift_i32.cmp(CmpIPredicate::sgt, 0),
+                           positive_bits, negative_bits);
+    result = b.create<SelectOp>(biased_to_exp.cmp(CmpIPredicate::sle, 0), bits,
+                                result);
+  }
+
+  // Handle types that have only an unsigned zero (no negative zero).
+  auto is_nuz = [](mlir::FloatType ty) {
+    return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() ||
+           ty.isFloat8E5M2FNUZ();
+  };
+
+  if (is_nuz(to_ty)) {
+    // Clear the sign bit if the result is zero (the output has no negative
+    // zero).
+    Val result_is_non_zero = Val{result, &b} != 0;
+    from_sign_bit = from_sign_bit & result_is_non_zero;
+  } else if (is_nuz(from_ty)) {
+    // Clear the sign bit if the input is NaN (it's positive but encoded as
+    // negative 0).
+    from_sign_bit = from_sign_bit ^ input_is_nan;
+  }
+
+  result = b.create<SelectOp>(result_is_inf, to_inf, result);
+  result = b.create<SelectOp>(from_bits == 0, to_zero, result);
+  result = b.create<SelectOp>(input_is_nan, to_nan, result);
+
+  Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1));
+
+  // Insert sign bit.
+  result = b.create<SelectOp>(from_sign_bit, neg_result, result);
+  result = b.create<ma::BitcastOp>(to_ty, result);
+  return result;
+}
+
+struct RewriteTruncFPattern : public mlir::OpRewritePattern<ma::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      ma::TruncFOp op, mlir::PatternRewriter& rewriter) const override {
+    using FloatValue = mlir::TypedValue<mlir::FloatType>;
+    auto src = mlir::cast<FloatValue>(op.getOperand());
+    auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
+    if (dst_ty.getWidth() != 8) {
+      return rewriter.notifyMatchFailure(op, "not an 8 bit truncf");
+    }
+
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    rewriter.replaceOp(op, EmitFloatConversion(src, dst_ty, b));
+    return mlir::success();
+  }
+};
+
+struct RewriteExtFPattern : public mlir::OpRewritePattern<ma::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      ma::ExtFOp op, mlir::PatternRewriter& rewriter) const override {
+    using FloatValue = mlir::TypedValue<mlir::FloatType>;
+    auto src = mlir::cast<FloatValue>(op.getOperand());
+    auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
+    if (src.getType().getWidth() != 8) {
+      return rewriter.notifyMatchFailure(op, "not an 8 bit extf");
+    }
+
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    rewriter.replaceOp(op, EmitFloatConversion(src, dst_ty, b));
+    return mlir::success();
+  }
+};
+
+// Lowering for cmpf on f8 operands, used for float-to-pred conversions.
+struct RewriteF8Cst : public mlir::OpRewritePattern<ma::CmpFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      ma::CmpFOp op, mlir::PatternRewriter& rewriter) const override {
+    using FloatValue = mlir::TypedValue<mlir::FloatType>;
+    auto lhs = mlir::cast<FloatValue>(op.getLhs());
+    auto rhs = mlir::cast<FloatValue>(op.getRhs());
+
+    if (lhs.getType().getWidth() != 8) {
+      return rewriter.notifyMatchFailure(op, "not an 8 bit cmpf");
+    }
+
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    // Skip the f32 conversion if we're comparing UNE against a constant.
+    llvm::APFloat rhs_cst(rhs.getType().getFloatSemantics());
+    if (op.getPredicate() == ma::CmpFPredicate::UNE &&
+        mlir::matchPattern(rhs, mlir::m_ConstantFloat(&rhs_cst))) {
+      Val int_value{b.create<ma::BitcastOp>(rewriter.getI8Type(), lhs), &b};
+      int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue();
+      // If we're comparing to +-0, compare the absolute values.
+      if (rhs_cst.isZero() &&
+          (lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) {
+        int_value = int_value & 0x7f;
+        constant &= 0x7f;
+      }
+      auto cst = b.create<ma::ConstantIntOp>(constant, rewriter.getI8Type());
+      rewriter.replaceOpWithNewOp<ma::CmpIOp>(op, ma::CmpIPredicate::ne,
+                                              int_value, cst);
+      return mlir::success();
+    }
+
+    auto lhs_ext = b.create<ma::ExtFOp>(b.getF32Type(), lhs);
+    auto rhs_ext = b.create<ma::ExtFOp>(b.getF32Type(), rhs);
+    rewriter.replaceOpWithNewOp<ma::CmpFOp>(op, op->getResultTypes(),
+                                            mlir::ValueRange{lhs_ext, rhs_ext},
+                                            op->getAttrs());
+    return mlir::success();
+  }
+};
+
+struct RewriteAbsFPattern : public mlir::OpRewritePattern<mlir::math::AbsFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      mlir::math::AbsFOp op, mlir::PatternRewriter& rewriter) const override {
+    using FloatValue = mlir::TypedValue<mlir::FloatType>;
+    auto src = mlir::cast<FloatValue>(op.getOperand());
+    // LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16.
+    // Once that's removed, remove the code for BF16 here.
+    if (src.getType().getWidth() != 8 && !src.getType().isBF16()) {
+      return rewriter.notifyMatchFailure(op, "not an f8 or bf16 absf");
+    }
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth());
+    Val value{b.create<ma::BitcastOp>(i_ty, src), &b};
+    if (src.getType().getWidth() == 8) {
+      value = value & 0x7f;
+    } else {
+      CHECK(src.getType().isBF16());
+      value = value & 0x7fff;
+    }
+    rewriter.replaceOpWithNewOp<ma::BitcastOp>(op, src.getType(), value);
+    return mlir::success();
+  }
+};
+
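+// Lowers {si,ui}tofp with an 8-bit float result by converting to f32 first
+// and then truncating to the target type.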
+template <typename Op>
+struct RewriteIToFpPattern : public mlir::OpRewritePattern<Op> {
+  using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      Op op, mlir::PatternRewriter& rewriter) const override {
+    if (op.getType().getIntOrFloatBitWidth() != 8) {
+      return rewriter.notifyMatchFailure(op, "not an f8 itofp");
+    }
+    Value to_float =
+        rewriter.create<Op>(op.getLoc(), rewriter.getF32Type(), op.getIn());
+    rewriter.replaceOpWithNewOp<ma::TruncFOp>(op, op.getType(), to_float);
+    return mlir::success();
+  }
+};
+
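+// Lowers fpto{si,ui} with an 8-bit float operand by first extending the
+// operand to f32.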
+template <typename Op>
+struct RewriteFpToIPattern : public mlir::OpRewritePattern<Op> {
+  using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      Op op, mlir::PatternRewriter& rewriter) const override {
+    if (op.getIn().getType().getIntOrFloatBitWidth() != 8) {
+      return rewriter.notifyMatchFailure(op, "not an f8 fptoi");
+    }
+    Value to_f32 = rewriter.create<ma::ExtFOp>(
+        op.getLoc(), rewriter.getF32Type(), op.getIn());
+    rewriter.replaceOpWithNewOp<Op>(op, op.getType(), to_f32);
+    return mlir::success();
+  }
+};
+
+class ExpandFloatOpsPass
+    : public impl::ExpandFloatOpsPassBase<ExpandFloatOpsPass> {
+ public:
+  using ExpandFloatOpsPassBase::ExpandFloatOpsPassBase;
+  void runOnOperation() override {
+    mlir::RewritePatternSet patterns(&getContext());
+    patterns.add<RewriteToCmpSelect<ma::MinimumFOp, ma::CmpFPredicate::OLE>>(
+        &getContext(), /*include_f32=*/pre_ampere_);
+    patterns.add<RewriteToCmpSelect<ma::MaximumFOp, ma::CmpFPredicate::OGE>>(
+        &getContext(), /*include_f32=*/pre_ampere_);
+    patterns.add<RewriteTruncFPattern, RewriteExtFPattern, RewriteAbsFPattern,
+                 RewriteF8Cst, RewriteIToFpPattern<ma::SIToFPOp>,
+                 RewriteIToFpPattern<ma::UIToFPOp>,
+                 RewriteFpToIPattern<ma::FPToSIOp>,
+                 RewriteFpToIPattern<ma::FPToUIOp>>(&getContext());
+    mlir::populatePolynomialApproximateTanhPattern(patterns);
+    patterns.add<RewriteErf32Pattern>(&getContext());
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                                        std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass(bool pre_ampere) {
+  return createExpandFloatOpsPass(ExpandFloatOpsPassOptions{pre_ampere});
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc b/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc
new file mode 100644
index 0000000..a60a003
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc
@@ -0,0 +1,452 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/shape_util.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+#define GEN_PASS_DEF_FLATTENTENSORSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+using mlir::Location;
+using mlir::LogicalResult;
+using mlir::MLIRContext;
+using mlir::OpRewritePattern;
+using mlir::PatternRewriter;
+using mlir::RankedTensorType;
+using mlir::SmallVector;
+using mlir::Type;
+using mlir::TypedValue;
+using mlir::TypeRange;
+using mlir::UnrealizedConversionCastOp;
+using mlir::Value;
+using mlir::ValueRange;
+using mlir::func::FuncOp;
+using mlir::func::ReturnOp;
+using mlir::scf::ForOp;
+using mlir::scf::IfOp;
+using mlir::tensor::ExtractOp;
+using mlir::tensor::InsertOp;
+
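+// Returns the rank-1 tensor type with the same element type and total element
+// count as `tensor_type`.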
+RankedTensorType GetFlattenedType(RankedTensorType tensor_type) {
+  return RankedTensorType::get({tensor_type.getNumElements()},
+                               tensor_type.getElementType());
+}
+
+bool HasOnlyFlatTensorsOrScalars(TypeRange types) {
+  return llvm::all_of(types, [](Type ty) {
+    auto tensor_type = mlir::dyn_cast<RankedTensorType>(ty);
+    if (!tensor_type) return true;
+    return tensor_type.getRank() < 2;
+  });
+}
+
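+// Flattens multi-dimensional tensor arguments and results of a function to
+// rank 1, bridging the old and new types with unrealized_conversion_casts.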
+struct RewriteFunctionSignatures : OpRewritePattern<FuncOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FuncOp op,
+                                PatternRewriter& rewriter) const override {
+    auto input_types = op.getFunctionType().getInputs();
+    auto result_types = op.getFunctionType().getResults();
+    if (HasOnlyFlatTensorsOrScalars(input_types) &&
+        HasOnlyFlatTensorsOrScalars(result_types)) {
+      return rewriter.notifyMatchFailure(op, "nothing to flatten");
+    }
+
+    auto loc = op.getLoc();
+    mlir::Block* entry_block = &op.getBody().front();
+    SmallVector<Type> new_result_types;
+    SmallVector<Value> new_results;
+
+    // If some results are tensors, we need to flatten them.
+    auto terminator = entry_block->getTerminator();
+    rewriter.setInsertionPoint(terminator);
+
+    for (Value result : terminator->getOperands()) {
+      auto tensor_type = mlir::dyn_cast<RankedTensorType>(result.getType());
+      if (!tensor_type) {
+        new_result_types.push_back(result.getType());
+        new_results.push_back(result);
+        continue;
+      }
+      auto new_result_type = GetFlattenedType(tensor_type);
+      new_result_types.push_back(new_result_type);
+
+      Value result_1d =
+          rewriter
+              .create<UnrealizedConversionCastOp>(loc, new_result_type, result)
+              .getResult(0);
+      new_results.push_back(result_1d);
+    }
+    rewriter.replaceOpWithNewOp<ReturnOp>(terminator, new_results);
+
+    // Cast all function arguments to the original type.
+    SmallVector<Type> new_operand_types(input_types);
+    rewriter.setInsertionPointToStart(entry_block);
+    for (auto&& [index, operand_type] : llvm::enumerate(new_operand_types)) {
+      if (auto tensor_type = mlir::dyn_cast<RankedTensorType>(operand_type)) {
+        if (tensor_type.getRank() > 1) {
+          mlir::BlockArgument func_argument = op.getArgument(index);
+          auto cast_to_orig_type = rewriter.create<UnrealizedConversionCastOp>(
+              loc, operand_type, func_argument);
+          func_argument.replaceAllUsesExcept(cast_to_orig_type.getResult(0),
+                                             cast_to_orig_type);
+          operand_type = GetFlattenedType(tensor_type);
+        }
+      }
+    }
+    // Replace the function arguments with the new types.
+    for (auto [arg, arg_type] :
+         llvm::zip(entry_block->getArguments(), new_operand_types)) {
+      arg.setType(arg_type);
+    }
+    // Update function signature.
+    op.setType(rewriter.getFunctionType(new_operand_types, new_result_types));
+    return mlir::success();
+  }
+};
+
+// Returns the linearized index, if the rank is greater than 1. Otherwise,
+// returns nullptr.
+Value LinearizeIndex(TypedValue<mlir::RankedTensorType> tensor,
+                     ValueRange indices, PatternRewriter& rewriter) {
+  if (tensor.getType().getRank() < 2) {
+    return nullptr;
+  }
+  auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape());
+  if (auto encoding = tensor.getType().getEncoding()) {
+    *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector(
+        mlir::cast<mlir::DenseElementsAttr>(encoding).getValues<int64_t>()));
+  }
+  auto linear_shape =
+      ShapeUtil::MakeShape(U8, {ShapeUtil::ElementsIn(byte_shape)});
+  auto linearized_map =
+      GetBitcastMap(byte_shape, linear_shape, tensor.getContext());
+  mlir::SmallVector<Value> result;
+  rewriter.createOrFold<ApplyIndexingOp>(result, tensor.getLoc(), indices,
+                                         ValueRange{}, linearized_map);
+  return result.front();
+}
+
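+// Rewrites tensor.extract on a multi-dimensional tensor into an extract from
+// the flattened tensor at the linearized index.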
+struct RewriteTensorExtract : OpRewritePattern<ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp op,
+                                PatternRewriter& rewriter) const override {
+    auto tensor = op.getTensor();
+    auto tensor_type = tensor.getType();
+    auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter);
+    if (linear_index == nullptr) {
+      return rewriter.notifyMatchFailure(op, "the tensor is already flat");
+    }
+    auto tensor_1D = rewriter
+                         .create<UnrealizedConversionCastOp>(
+                             op.getLoc(), GetFlattenedType(tensor_type), tensor)
+                         .getResult(0);
+    rewriter.replaceOpWithNewOp<ExtractOp>(op, tensor_1D, linear_index);
+    return mlir::success();
+  }
+};
+
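+// Rewrites tensor.insert on a multi-dimensional tensor into an insert into
+// the flattened tensor, casting the result back to the original type.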
+struct RewriteTensorInsert : OpRewritePattern<InsertOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter& rewriter) const override {
+    auto tensor = op.getDest();
+    auto tensor_type = tensor.getType();
+    auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter);
+    if (linear_index == nullptr) {
+      return rewriter.notifyMatchFailure(op, "the tensor is already flat");
+    }
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto tensor_1D = b.create<UnrealizedConversionCastOp>(
+                          GetFlattenedType(tensor_type), tensor)
+                         .getResult(0);
+    auto new_insert =
+        b.create<InsertOp>(op.getScalar(), tensor_1D, linear_index);
+    auto cast_to_orig_type = b.create<UnrealizedConversionCastOp>(
+        tensor_type, new_insert.getResult());
+    rewriter.replaceOp(op, cast_to_orig_type.getResult(0));
+    return mlir::success();
+  }
+};
+
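+// Rewrites xla_gpu.atomic_rmw on a multi-dimensional tensor to operate on the
+// flattened tensor, moving the update region to the new op.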
+struct RewriteAtomicRMW : OpRewritePattern<AtomicRMWOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AtomicRMWOp op,
+                                PatternRewriter& rewriter) const override {
+    auto tensor = op.getInput();
+    auto tensor_type = tensor.getType();
+    auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter);
+    if (linear_index == nullptr) {
+      return rewriter.notifyMatchFailure(op, "the tensor is already flat");
+    }
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto tensor_1D = b.create<UnrealizedConversionCastOp>(
+                          GetFlattenedType(tensor_type), tensor)
+                         .getResult(0);
+    auto new_atomic_rmw = b.create<AtomicRMWOp>(tensor_1D, linear_index);
+    rewriter.inlineRegionBefore(op.getRegion(),
+                                &new_atomic_rmw.getRegion().front());
+    auto cast_to_orig_type = b.create<UnrealizedConversionCastOp>(
+        tensor_type, new_atomic_rmw.getResult());
+    rewriter.replaceOp(op, cast_to_orig_type.getResult(0));
+    return mlir::success();
+  }
+};
+
+// Checks that the value is produced by an unrealized conversion cast from 1D
+// tensor to ND. Returns the 1D tensor if so.
+std::optional<Value> GetDelinearizedTensor(Value value) {
+  auto tensor_type = mlir::dyn_cast<RankedTensorType>(value.getType());
+  if (!tensor_type || tensor_type.getRank() < 2) {
+    return std::nullopt;
+  }
+  auto cast = value.getDefiningOp<UnrealizedConversionCastOp>();
+  if (!cast || cast->getNumResults() != 1 || cast->getNumOperands() != 1) {
+    return std::nullopt;
+  }
+  auto type_before_linearization =
+      mlir::dyn_cast<RankedTensorType>(cast->getOperand(0).getType());
+  if (!type_before_linearization || type_before_linearization.getRank() != 1) {
+    return std::nullopt;
+  }
+  return cast->getOperand(0);
+}
+
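+// Replaces scf.for iter_args that are casts of rank-1 tensors with the
+// underlying flat tensors, casting at the loop boundaries.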
+struct RewriteForOp : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp op,
+                                PatternRewriter& rewriter) const override {
+    llvm::SmallBitVector args_to_update(op.getNumResults(), false);
+    mlir::SmallVector<Value> new_init_args;
+    new_init_args.reserve(op.getNumResults());
+    for (auto [index, arg] : llvm::enumerate(op.getInitArgs())) {
+      auto delinearized_tensor = GetDelinearizedTensor(arg);
+      if (!delinearized_tensor.has_value()) {
+        new_init_args.push_back(arg);
+        continue;
+      }
+      new_init_args.push_back(*delinearized_tensor);
+      args_to_update.set(index);
+    }
+    if (args_to_update.none()) {
+      return rewriter.notifyMatchFailure(op, "no args to update");
+    }
+    // Create new ForOp with updated init args.
+    Location loc = op.getLoc();
+    auto new_for_op =
+        rewriter.create<ForOp>(loc, op.getLowerBound(), op.getUpperBound(),
+                               op.getStep(), new_init_args);
+    new_for_op->setAttrs(op->getAttrs());
+
+    // Insert casts for the block arguments.
+    mlir::Block* new_body = new_for_op.getBody();
+    mlir::Block* old_body = op.getBody();
+    rewriter.setInsertionPoint(new_body, new_body->begin());
+    SmallVector<Value, 4> updated_block_args{new_body->getArguments().begin(),
+                                             new_body->getArguments().end()};
+    for (auto [index, arg] :
+         llvm::enumerate(new_body->getArguments().drop_front())) {
+      if (!args_to_update.test(index)) continue;
+      updated_block_args[index + 1] =
+          rewriter
+              .create<UnrealizedConversionCastOp>(
+                  loc, old_body->getArgument(index + 1).getType(), arg)
+              .getResult(0);
+    }
+
+    // Move the body of the old ForOp to the new one.
+    rewriter.mergeBlocks(old_body, new_body, updated_block_args);
+
+    // Update the terminator.
+    auto new_terminator =
+        mlir::cast<mlir::scf::YieldOp>(new_body->getTerminator());
+    rewriter.setInsertionPoint(new_terminator);
+    for (auto&& [index, yielded_value] :
+         llvm::enumerate(new_terminator.getResultsMutable())) {
+      if (!args_to_update.test(index)) continue;
+      yielded_value.assign(
+          rewriter
+              .create<UnrealizedConversionCastOp>(
+                  loc, new_init_args[index].getType(), yielded_value.get())
+              .getResult(0));
+    }
+
+    // Cast back the results.
+    rewriter.setInsertionPointAfter(new_for_op);
+    SmallVector<Value> new_results(new_for_op.getResults());
+    for (auto&& [index, result] : llvm::enumerate(new_results)) {
+      if (!args_to_update.test(index)) continue;
+      result = rewriter
+                   .create<UnrealizedConversionCastOp>(
+                       loc, op->getResult(index).getType(), result)
+                   .getResult(0);
+    }
+    rewriter.replaceOp(op, new_results);
+    return mlir::success();
+  }
+};
+
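+// Propagates flat tensors through scf.if: yields the underlying rank-1
+// tensors and casts the results back to the original types.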
+struct RewriteIfOp : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter& rewriter) const override {
+    auto result_types = op.getResultTypes();
+    if (HasOnlyFlatTensorsOrScalars(result_types)) {
+      return rewriter.notifyMatchFailure(op, "nothing to flatten");
+    }
+    mlir::scf::YieldOp then_yield = op.thenYield();
+    SmallVector<Type> new_result_types;
+    new_result_types.reserve(then_yield.getNumOperands());
+    bool found_cast = false;
+    for (auto& result : then_yield->getOpOperands()) {
+      auto delinearized_tensor = GetDelinearizedTensor(result.get());
+      if (!delinearized_tensor.has_value()) {
+        new_result_types.push_back(result.get().getType());
+        continue;
+      }
+      new_result_types.push_back(delinearized_tensor->getType());
+      result.set(*delinearized_tensor);
+      found_cast = true;
+    }
+    if (!found_cast) {
+      return rewriter.notifyMatchFailure(op, "no cast found");
+    }
+    Location loc = op.getLoc();
+    // Update the else branch if present.
+    bool has_else_region = !op.getElseRegion().empty();
+    if (has_else_region) {
+      mlir::scf::YieldOp else_yield = op.elseYield();
+      mlir::OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(else_yield);
+      for (auto&& [result, type] :
+           llvm::zip(else_yield->getOpOperands(), new_result_types)) {
+        if (result.get().getType() == type) continue;
+        result.set(
+            rewriter.create<UnrealizedConversionCastOp>(loc, type, result.get())
+                .getResult(0));
+      }
+    }
+    // Create new IfOp and move the old op's regions to the new one.
+    auto new_if_op = rewriter.create<IfOp>(loc, new_result_types,
+                                           op.getCondition(), has_else_region);
+    rewriter.inlineRegionBefore(op.getThenRegion(),
+                                &new_if_op.getThenRegion().back());
+    rewriter.eraseBlock(&new_if_op.getThenRegion().back());
+    if (has_else_region) {
+      rewriter.inlineRegionBefore(op.getElseRegion(),
+                                  &new_if_op.getElseRegion().back());
+      rewriter.eraseBlock(&new_if_op.getElseRegion().back());
+    }
+
+    // Update the results.
+    rewriter.setInsertionPointAfter(new_if_op);
+    SmallVector<Value> new_results(new_if_op.getResults());
+    for (auto&& [index, result] : llvm::enumerate(new_results)) {
+      Type old_type = op->getResult(index).getType();
+      if (result.getType() == old_type) continue;
+      result =
+          rewriter.create<UnrealizedConversionCastOp>(loc, old_type, result)
+              .getResult(0);
+    }
+    rewriter.replaceOp(op, new_results);
+    return mlir::success();
+  }
+};
+
+class FlattenTensorsPass
+    : public impl::FlattenTensorsPassBase<FlattenTensorsPass> {
+ public:
+  void runOnOperation() override {
+    mlir::ModuleOp module = getOperation();
+    MLIRContext* mlir_context = &getContext();
+    mlir::RewritePatternSet patterns(mlir_context);
+    // clang-format off
+    patterns.add<
+        RewriteAtomicRMW,
+        RewriteForOp,
+        RewriteFunctionSignatures,
+        RewriteIfOp,
+        RewriteTensorExtract,
+        RewriteTensorInsert
+    >(mlir_context);
+    // clang-format on
+    ApplyIndexingOp::getCanonicalizationPatterns(patterns, mlir_context);
+    if (mlir::failed(
+            mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
+      signalPassFailure();
+      return;
+    }
+    // Verify that no unrealized_conversion_casts are left.
+    bool module_has_casts = module
+                                .walk([](UnrealizedConversionCastOp op) {
+                                  return mlir::WalkResult::interrupt();
+                                })
+                                .wasInterrupted();
+    if (module_has_casts) {
+      llvm::outs() << "FlattenTensorsPass failed to converge";
+      signalPassFailure();
+      return;
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass() {
+  return std::make_unique<FlattenTensorsPass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc
new file mode 100644
index 0000000..efccaaa
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc
@@ -0,0 +1,1095 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cassert>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+
+#include "absl/algorithm/container.h"
+#include "absl/log/check.h"
+#include "absl/strings/str_cat.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/model/indexing_analysis.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+#define GEN_PASS_DECL_LOWERTENSORSPASS
+#define GEN_PASS_DEF_LOWERTENSORSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+using mlir::failure;
+using mlir::Location;
+using mlir::LogicalResult;
+using mlir::MLIRContext;
+using mlir::OpBuilder;
+using mlir::Operation;
+using mlir::success;
+using mlir::Type;
+using mlir::TypedValue;
+using mlir::TypeRange;
+using mlir::Value;
+using mlir::ValueRange;
+
+namespace arith = ::mlir::arith;
+namespace scf = ::mlir::scf;
+namespace ml = ::mlir::LLVM;
+
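+// Follows the chain of ops producing `dest` (tensor.insert, scf.if/for,
+// vector.transfer_write) back to the underlying buffer. Returns nullptr if an
+// unsupported producer is encountered.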
+Value GetDestinationBuffer(Value dest) {
+  while (dest.getDefiningOp()) {
+    int result_number = mlir::cast<mlir::OpResult>(dest).getResultNumber();
+    if (auto insert = dest.getDefiningOp<mlir::tensor::InsertOp>()) {
+      dest = insert.getDest();
+    } else if (auto scf_if = dest.getDefiningOp<scf::IfOp>()) {
+      // Pick one of the branches; they're required to yield the same buffers.
+      dest = scf_if.getThenRegion().front().getTerminator()->getOperand(
+          result_number);
+    } else if (auto scf_for = dest.getDefiningOp<scf::ForOp>()) {
+      dest = scf_for.getInitArgs()[result_number];
+    } else if (dest.getDefiningOp<mlir::UnrealizedConversionCastOp>() ||
+               dest.getDefiningOp<AllocateSharedOp>()) {
+      break;
+    } else if (auto transfer_write =
+                   dest.getDefiningOp<mlir::vector::TransferWriteOp>()) {
+      dest = transfer_write.getSource();
+    } else {
+      dest.getDefiningOp()->emitOpError("unsupported dest type");
+      return nullptr;
+    }
+  }
+  return dest;
+}
+
+template <typename Op>
+bool IsSupportedTransfer(Op op) {
+  return !absl::c_linear_search(op.getInBoundsValues(), false) &&
+         op.getVectorType().getRank() == 1 && !op.getMask() &&
+         op.getPermutationMap().isMinorIdentity();
+}
+
+struct RewriteFunctionSignatures : mlir::OpRewritePattern<mlir::func::FuncOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      mlir::func::FuncOp op, mlir::PatternRewriter& rewriter) const override {
+    auto is_tensor = [](Type ty) {
+      return mlir::isa<mlir::RankedTensorType>(ty);
+    };
+    if (!llvm::any_of(op.getFunctionType().getInputs(), is_tensor)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "the function has no input tensors");
+    }
+
+    bool some_tensor_result =
+        llvm::any_of(op.getFunctionType().getResults(), is_tensor);
+    bool all_tensor_results =
+        llvm::all_of(op.getFunctionType().getResults(), is_tensor);
+    if (some_tensor_result && !all_tensor_results) {
+      op->emitOpError("function has a mix of tensor and non-tensor results");
+      return failure();
+    }
+
+    TypeRange new_results = op.getFunctionType().getResults();
+    if (some_tensor_result) {
+      new_results = {};
+      auto terminator = op.getFunctionBody().front().getTerminator();
+      rewriter.setInsertionPoint(terminator);
+      rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(terminator);
+    }
+
+    llvm::SmallVector<Type> new_operands(op.getFunctionType().getInputs());
+    for (auto&& [index, operand] : llvm::enumerate(new_operands)) {
+      if (is_tensor(operand)) {
+        rewriter.setInsertionPointToStart(&op.getBody().front());
+        auto cast = rewriter.create<mlir::UnrealizedConversionCastOp>(
+            op.getLoc(), operand, op.getArgument(index));
+        op.getArgument(index).replaceAllUsesExcept(cast.getResult(0), cast);
+        operand = mlir::LLVM::LLVMPointerType::get(op.getContext());
+      }
+    }
+
+    op.setFunctionType(rewriter.getFunctionType(new_operands, new_results));
+    auto& entry = op->getRegion(0).front();
+    for (auto [arg, arg_type] : llvm::zip(entry.getArguments(), new_operands)) {
+      arg.setType(arg_type);
+    }
+
+    return success();
+  }
+};
+
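+// Computes the linearized element index for `indices`, honoring the layout
+// stored in the tensor encoding, and casts it to the data layout's index
+// integer type.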
+Value GetLinearIndex(TypedValue<mlir::RankedTensorType> tensor,
+                     ValueRange indices, mlir::PatternRewriter& rewriter) {
+  auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape());
+  if (auto encoding = tensor.getType().getEncoding()) {
+    *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector(
+        mlir::cast<mlir::DenseElementsAttr>(encoding).getValues<int64_t>()));
+  }
+  auto linear_shape =
+      ShapeUtil::MakeShape(U8, {ShapeUtil::ElementsIn(byte_shape)});
+  auto linearize_map =
+      GetBitcastMap(byte_shape, linear_shape, tensor.getContext());
+  mlir::SmallVector<Value> result;
+  rewriter.createOrFold<ApplyIndexingOp>(result, tensor.getLoc(), indices,
+                                         ValueRange{}, linearize_map);
+  CHECK_EQ(result.size(), 1);
+  auto index = result.front();
+  auto index_ty = rewriter.getIntegerType(
+      mlir::DataLayout::closest(rewriter.getInsertionBlock()->getParentOp())
+          .getTypeSizeInBits(index.getType()));
+  return rewriter.create<mlir::arith::IndexCastUIOp>(tensor.getLoc(), index_ty,
+                                                     index);
+}
+
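+// For a linear i4 element index, returns the index of the enclosing byte and
+// a predicate that is true when the element occupies the low nibble (odd
+// indices).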
+std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
+                                             mlir::ImplicitLocOpBuilder& b) {
+  Value one = b.create<mlir::arith::ConstantIntOp>(1, linear_index.getType());
+  Value is_low_nibble = b.create<mlir::arith::CmpIOp>(
+      mlir::arith::CmpIPredicate::eq, one,
+      b.create<mlir::arith::AndIOp>(linear_index, one));
+  Value i8_index = b.create<mlir::arith::ShRUIOp>(linear_index, one);
+  return {i8_index, is_low_nibble};
+}
+
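+// Builds an inbounds LLVM GEP into the buffer behind `tensor` at
+// `linear_index`. `element_type` overrides the tensor's element type (e.g. i8
+// for packed i4 elements).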
+mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
+                            Value linear_index, mlir::PatternRewriter& rewriter,
+                            Type element_type = nullptr) {
+  if (!element_type) {
+    element_type = tensor.getType().getElementType();
+  }
+  auto ptr = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
+  auto tensor_ptr = rewriter
+                        .create<mlir::UnrealizedConversionCastOp>(
+                            tensor.getLoc(), ptr, tensor)
+                        .getResult(0);
+  mlir::LLVMTypeConverter converter(rewriter.getContext());
+  auto llvm_element_type = converter.convertType(element_type);
+  auto gep = rewriter.create<mlir::LLVM::GEPOp>(
+      tensor.getLoc(), ptr, llvm_element_type, tensor_ptr, linear_index);
+  gep.setInbounds(true);
+  return gep;
+}
+
+mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
+                            ValueRange indices,
+                            mlir::PatternRewriter& rewriter) {
+  return CreateGep(tensor, GetLinearIndex(tensor, indices, rewriter), rewriter);
+}
+
+struct RewriteTensorExtract : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      mlir::tensor::ExtractOp op,
+      mlir::PatternRewriter& rewriter) const override {
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto linear_index =
+        GetLinearIndex(op.getTensor(), op.getIndices(), rewriter);
+    Type element_type = op.getTensor().getType().getElementType();
+    Value is_low_nibble = nullptr;
+    if (element_type == rewriter.getI4Type()) {
+      element_type = rewriter.getI8Type();
+      std::tie(linear_index, is_low_nibble) =
+          GetI4IndexAndNibble(linear_index, b);
+    }
+
+    auto gep = CreateGep(op.getTensor(), linear_index, rewriter, element_type);
+    auto load =
+        rewriter
+            .create<mlir::LLVM::LoadOp>(gep.getLoc(), gep.getElemType(), gep)
+            .getResult();
+
+    if (is_low_nibble) {
+      auto high_value = b.create<mlir::arith::ShRUIOp>(
+          load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
+      load = b.create<mlir::arith::TruncIOp>(
+          op.getType(),
+          b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
+    }
+
+    rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+        op, op.getType(), load);
+    return success();
+  }
+};
+
+// Swaps pairs of values in the vector: [0, 1, 2, 3] -> [1, 0, 3, 2].
+Value PermutePairsInVector(Value vector, mlir::ImplicitLocOpBuilder& b) {
+  // There is a `vector.extract_strided_slice` op that would be useful here, but
+  // it actually requires the strides to be 1.
+  auto ty = mlir::cast<mlir::VectorType>(vector.getType());
+  int size = ty.getNumElements();
+  Value result = vector;
+  for (int i = 0; i < size; i += 2) {
+    auto v0 = b.create<mlir::vector::ExtractOp>(vector, i);
+    auto v1 = b.create<mlir::vector::ExtractOp>(vector, i + 1);
+    result = b.create<mlir::vector::InsertOp>(v1, result, i);
+    result = b.create<mlir::vector::InsertOp>(v0, result, i + 1);
+  }
+  return result;
+}
+
+struct RewriteTransferRead
+    : mlir::OpRewritePattern<mlir::vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      mlir::vector::TransferReadOp op,
+      mlir::PatternRewriter& rewriter) const override {
+    assert(IsSupportedTransfer(op));
+
+    auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
+        op.getSource());
+
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto linear_index = GetLinearIndex(source, op.getIndices(), rewriter);
+
+    mlir::VectorType vector_type = op.getVectorType();
+    if (vector_type.getElementType().isInteger(1)) {
+      vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type());
+    }
+    mlir::Type gep_element_type = vector_type.getElementType();
+    if (op.getVectorType().getElementType().isInteger(4)) {
+      linear_index = b.create<arith::ShRUIOp>(
+          linear_index,
+          b.create<arith::ConstantIntOp>(1, linear_index.getType()));
+      gep_element_type = b.getI8Type();
+    }
+    auto gep = CreateGep(source, linear_index, rewriter, gep_element_type);
+
+    mlir::LLVMTypeConverter converter(b.getContext());
+    auto llvm_vector_type = converter.convertType(vector_type);
+    auto loaded =
+        b.create<mlir::LLVM::LoadOp>(llvm_vector_type, gep).getResult();
+
+    if (source.getType().getElementType().isInteger(1)) {
+      Value zero = b.create<mlir::arith::ConstantOp>(
+          mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0)));
+      loaded = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, loaded, zero);
+    } else if (source.getType().getElementType().isInteger(4)) {
+      // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
+      // elements.
+      loaded = PermutePairsInVector(loaded, b);
+    }
+
+    rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+        op, op.getType(), loaded);
+    return success();
+  }
+};
+
+struct RewriteTensorInsert : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      mlir::tensor::InsertOp op,
+      mlir::PatternRewriter& rewriter) const override {
+    Value dest = GetDestinationBuffer(op.getDest());
+    if (!dest) {
+      return failure();
+    }
+
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto tensor_dest = mlir::cast<TypedValue<mlir::RankedTensorType>>(dest);
+    auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter);
+    auto element_type = tensor_dest.getType().getElementType();
+    Value is_low_nibble = nullptr;
+
+    if (element_type == rewriter.getI4Type()) {
+      element_type = rewriter.getI8Type();
+      std::tie(linear_index, is_low_nibble) =
+          GetI4IndexAndNibble(linear_index, b);
+    }
+
+    auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type);
+    auto scalar_value = op.getScalar();
+
+    if (is_low_nibble) {
+      Value current_value =
+          b.create<mlir::LLVM::LoadOp>(gep.getElemType(), gep);
+      auto ty = current_value.getType();
+      scalar_value = b.create<mlir::arith::ExtUIOp>(ty, scalar_value);
+      Value low_updated = b.create<mlir::arith::OrIOp>(
+          b.create<mlir::arith::AndIOp>(
+              current_value, b.create<mlir::arith::ConstantIntOp>(0xf0, ty)),
+          scalar_value);
+      Value high_updated = b.create<mlir::arith::OrIOp>(
+          b.create<mlir::arith::AndIOp>(
+              current_value, b.create<mlir::arith::ConstantIntOp>(0x0f, ty)),
+          b.create<mlir::arith::ShLIOp>(
+              scalar_value, b.create<mlir::arith::ConstantIntOp>(4, ty)));
+      scalar_value = b.create<mlir::arith::SelectOp>(is_low_nibble, low_updated,
+                                                     high_updated);
+    }
+
+    mlir::LLVMTypeConverter converter(getContext());
+    auto llvm_type = converter.convertType(scalar_value.getType());
+    scalar_value = rewriter
+                       .create<mlir::UnrealizedConversionCastOp>(
+                           gep.getLoc(), llvm_type, scalar_value)
+                       .getResult(0);
+    rewriter.create<mlir::LLVM::StoreOp>(gep.getLoc(), scalar_value, gep);
+
+    op.replaceAllUsesWith(op.getDest());
+    op.erase();
+    return success();
+  }
+};
+
+struct RewriteTransferWrite
+    : mlir::OpRewritePattern<mlir::vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      mlir::vector::TransferWriteOp op,
+      mlir::PatternRewriter& rewriter) const override {
+    assert(IsSupportedTransfer(op));
+    Value dest = GetDestinationBuffer(op.getSource());
+
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto tensor_dest = mlir::cast<TypedValue<mlir::RankedTensorType>>(dest);
+    auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter);
+    auto element_type = tensor_dest.getType().getElementType();
+
+    mlir::Value vector_value = op.getVector();
+    if (op.getVectorType().getElementType().isInteger(1)) {
+      vector_value = b.create<arith::ExtUIOp>(
+          op.getVectorType().cloneWith(std::nullopt, b.getI8Type()),
+          vector_value);
+    }
+    if (op.getVectorType().getElementType().isInteger(4)) {
+      linear_index = b.create<arith::ShRUIOp>(
+          linear_index,
+          b.create<arith::ConstantIntOp>(1, linear_index.getType()));
+      element_type = rewriter.getI8Type();
+      // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
+      // elements.
+      vector_value = PermutePairsInVector(vector_value, b);
+    }
+    auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type);
+
+    mlir::LLVMTypeConverter converter(getContext());
+    auto llvm_type = converter.convertType(vector_value.getType());
+    vector_value =
+        b.create<mlir::UnrealizedConversionCastOp>(llvm_type, vector_value)
+            .getResult(0);
+    b.create<mlir::LLVM::StoreOp>(vector_value, gep);
+
+    rewriter.replaceOp(op, mlir::ValueRange{op.getSource()});
+    return success();
+  }
+};
+
+struct RewriteCall : mlir::OpRewritePattern<mlir::func::CallOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      mlir::func::CallOp op, mlir::PatternRewriter& rewriter) const override {
+    if (!llvm::any_of(op->getOperandTypes(), [](Type ty) {
+          return mlir::isa<mlir::RankedTensorType>(ty);
+        })) {
+      return rewriter.notifyMatchFailure(op, "the call has no input tensors");
+    }
+
+    for (const auto&& [index, arg] : llvm::enumerate(op.getOperands())) {
+      if (mlir::isa<mlir::RankedTensorType>(arg.getType())) {
+        op.setOperand(
+            index,
+            rewriter
+                .create<mlir::UnrealizedConversionCastOp>(
+                    op.getLoc(),
+                    mlir::LLVM::LLVMPointerType::get(op.getContext()), arg)
+                .getResult(0));
+      }
+    }
+    return success();
+  }
+};
+
+mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value,
+                                    const std::string& name_prefix,
+                                    mlir::ShapedType shaped_ty,
+                                    mlir::ModuleOp module, bool is_constant,
+                                    int addr_space,
+                                    mlir::ImplicitLocOpBuilder& b) {
+  if (auto elements = mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(value)) {
+    // The lowering to LLVM only works for 1d tensors or those with trailing
+    // unit dimensions.
+    value = elements.reshape(mlir::RankedTensorType::get(
+        {elements.getNumElements()}, elements.getElementType()));
+  }
+
+  Type element_type = shaped_ty.getElementType();
+  // Needed to support complex element type.
+  mlir::LLVMTypeConverter converter(b.getContext());
+  auto llvm_element_type = converter.convertType(element_type);
+  auto array_ty = mlir::LLVM::LLVMArrayType::get(llvm_element_type,
+                                                 shaped_ty.getNumElements());
+  std::string name;
+  int index = 0;
+  do {
+    name = absl::StrCat(name_prefix, index);
+    ++index;
+  } while (module.lookupSymbol(name));
+  b.setInsertionPointToStart(module.getBody());
+  return b.create<mlir::LLVM::GlobalOp>(
+      array_ty, is_constant,
+      /*linkage=*/mlir::LLVM::Linkage::Private, name, value, /*alignment=*/0,
+      addr_space);
+}
+
+struct RewriteAllocateShared : mlir::OpRewritePattern<AllocateSharedOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      AllocateSharedOp op, mlir::PatternRewriter& rewriter) const override {
+    auto module = op->getParentOfType<mlir::ModuleOp>();
+    auto shaped_ty = mlir::cast<mlir::ShapedType>(op.getResult().getType());
+    constexpr int kGPUSharedMemoryAddrSpace = 3;
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+    auto global =
+        CreateGlobalOp(mlir::Attribute{}, "shared_", shaped_ty, module,
+                       /*is_constant=*/false, kGPUSharedMemoryAddrSpace, b);
+
+    rewriter.setInsertionPoint(op);
+    auto addr = rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), global);
+    rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+        op, op.getResult().getType(),
+        rewriter
+            .create<mlir::LLVM::AddrSpaceCastOp>(
+                op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()),
+                addr)
+            .getResult());
+    return success();
+  }
+};
+
+struct RewriteNonScalarConstants
+    : mlir::OpRewritePattern<mlir::arith::ConstantOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      mlir::arith::ConstantOp op,
+      mlir::PatternRewriter& rewriter) const override {
+    if (mlir::isa<mlir::VectorType>(op.getType())) {
+      return rewriter.notifyMatchFailure(op, "the op is a vector constant");
+    }
+    auto shaped_ty = mlir::dyn_cast<mlir::ShapedType>(op.getValue().getType());
+    // We only need to rewrite non-scalar constants.
+    if (!shaped_ty || shaped_ty.getNumElements() < 2) {
+      return rewriter.notifyMatchFailure(
+          op, "the op is an effective scalar constant");
+    }
+
+    constexpr int kGPUGlobalMemoryAddrSpace = 0;
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto module = op->getParentOfType<mlir::ModuleOp>();
+    auto global =
+        CreateGlobalOp(op.getValue(), "global_cst_", shaped_ty, module,
+                       /*is_constant=*/true, kGPUGlobalMemoryAddrSpace, b);
+
+    rewriter.setInsertionPoint(op);
+    auto addr = rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), global);
+    rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+        op, op.getResult().getType(),
+        rewriter
+            .create<mlir::LLVM::AddrSpaceCastOp>(
+                op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()),
+                addr)
+            .getResult());
+    return success();
+  }
+};
+
+struct RewriteSyncThreads : mlir::OpRewritePattern<SyncThreadsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      SyncThreadsOp op, mlir::PatternRewriter& rewriter) const override {
+    rewriter.create<mlir::gpu::BarrierOp>(op.getLoc());
+    rewriter.replaceOp(op, op.getOperands());
+    return success();
+  }
+};
+
+// TODO(jreiffers): Generalize this to support index switches with some used
+// results and upstream it as a canonicalization pattern.
+struct RemoveUnusedIndexSwitchResults
+    : mlir::OpRewritePattern<scf::IndexSwitchOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      scf::IndexSwitchOp op, mlir::PatternRewriter& rewriter) const override {
+    if (op->getNumResults() == 0 || !op->use_empty()) {
+      return rewriter.notifyMatchFailure(op, "the op has users");
+    }
+
+    auto new_op = rewriter.create<scf::IndexSwitchOp>(
+        op.getLoc(), mlir::TypeRange{}, op.getArg(), op.getCases(),
+        op.getNumCases());
+    for (int i = 0; i < op->getNumRegions(); ++i) {
+      auto& old_region = op->getRegion(i);
+      auto& new_region = new_op->getRegion(i);
+      rewriter.mergeBlocks(&old_region.getBlocks().front(),
+                           &new_region.emplaceBlock());
+      auto yield_op = new_region.getBlocks().front().getTerminator();
+      rewriter.modifyOpInPlace(yield_op, [&]() { yield_op->setOperands({}); });
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+bool IsAtomicIntegral(Type element_type) {
+  if (!element_type.isInteger()) {
+    return false;
+  }
+  unsigned element_bitwidth = element_type.getIntOrFloatBitWidth();
+  return element_bitwidth == 32 || element_bitwidth == 64;
+}
+
+Value CreateBitcast(mlir::ImplicitLocOpBuilder& b, Value value, Type ty) {
+  if (value.getType().isIntOrFloat() && ty.isIntOrFloat()) {
+    return b.create<ml::BitcastOp>(ty, value);
+  }
+
+  mlir::LLVMTypeConverter converter(b.getContext());
+  // If either type is a complex type, we need to go through an alloca, since
+  // no direct bitcast from a struct to an int is possible.
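+  // For example, a complex<f32> (lowered to an LLVM struct of two f32s) is
+  // stored to a stack slot and loaded back as an i64.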
+  Type llvm_input_ty = converter.convertType(value.getType());
+  Type llvm_result_ty = converter.convertType(ty);
+  Type ptr_ty = mlir::LLVM::LLVMPointerType::get(b.getContext());
+
+  Value llvm_value =
+      b.create<mlir::UnrealizedConversionCastOp>(llvm_input_ty, value)
+          .getResult(0);
+  Value alloca = b.create<ml::AllocaOp>(
+      ptr_ty, llvm_input_ty, b.create<ml::ConstantOp>(b.getI32Type(), 1));
+  b.create<ml::StoreOp>(llvm_value, alloca);
+  auto result = b.create<ml::LoadOp>(llvm_result_ty, alloca).getResult();
+  return b.create<mlir::UnrealizedConversionCastOp>(ty, result).getResult(0);
+};
+
+class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
+ public:
+  RewriteAtomicRMW(mlir::MLIRContext* context, bool is_amd,
+                   const std::string& gpu_arch)
+      : mlir::OpRewritePattern<AtomicRMWOp>(context),
+        is_amd_(is_amd),
+        gpu_arch_(gpu_arch) {}
+
+  LogicalResult matchAndRewrite(
+      AtomicRMWOp op, mlir::PatternRewriter& rewriter) const override {
+    if (failed(rewriteAsDirectAtomicRMW(op, rewriter))) {
+      rewriteAsAtomicCAS(op, rewriter);
+    }
+    rewriter.replaceOp(op, op.getInput());
+    return success();
+  }
+
+ private:
+  // Returns the atomic op modifier and the atomic bin op kind.
+  std::optional<std::pair<Value, ml::AtomicBinOp>> GetAtomicModifierParameters(
+      AtomicRMWOp op) const {
+    Type element_type = op.getInput().getType().getElementType();
+    auto& operations = op.getBody()->getOperations();
+    auto terminator = op.getBody()->getTerminator();
+    if (operations.size() > 2) {
+      return std::nullopt;
+    }
+    // If the body contains only the terminator, then it is an atomic store.
+    if (operations.size() == 1) {
+      // TODO(b/336367145): Support complex<f32> atomic store.
+      if (element_type.isF32() || IsAtomicIntegral(element_type)) {
+        return std::make_pair(terminator->getOperand(0), ml::AtomicBinOp::xchg);
+      }
+      return std::nullopt;
+    }
+    // Match the kind of the atomic op.
+    mlir::Operation* modifier_op = &operations.front();
+    std::optional<ml::AtomicBinOp> kind =
+        llvm::TypeSwitch<Operation*, std::optional<ml::AtomicBinOp>>(
+            modifier_op)
+            // Floating-point operations.
+            .Case([](arith::AddFOp op) { return ml::AtomicBinOp::fadd; })
+            .Case([](arith::MaximumFOp op) { return ml::AtomicBinOp::fmax; })
+            .Case([](arith::MinimumFOp op) { return ml::AtomicBinOp::fmin; })
+            // Integer operations.
+            .Case([&](arith::AddIOp op) {
+              return IsAtomicIntegral(element_type)
+                         ? std::make_optional(ml::AtomicBinOp::add)
+                         : std::nullopt;
+            })
+            .Case([&](arith::MaxUIOp op) {
+              return IsAtomicIntegral(element_type)
+                         ? std::make_optional(ml::AtomicBinOp::umax)
+                         : std::nullopt;
+            })
+            .Case([&](arith::MinUIOp op) {
+              return IsAtomicIntegral(element_type)
+                         ? std::make_optional(ml::AtomicBinOp::umin)
+                         : std::nullopt;
+            })
+            .Case([&](arith::MaxSIOp op) {
+              return IsAtomicIntegral(element_type)
+                         ? std::make_optional(ml::AtomicBinOp::max)
+                         : std::nullopt;
+            })
+            .Case([&](arith::MinSIOp op) {
+              return IsAtomicIntegral(element_type)
+                         ? std::make_optional(ml::AtomicBinOp::min)
+                         : std::nullopt;
+            })
+            .Default([](Operation* op) { return std::nullopt; });
+    if (!kind.has_value()) {
+      return std::nullopt;
+    }
+    // Find the modifier arg that does not match the argument of the
+    // `atomic_rmw` body.
+    Value block_arg = op.getBody()->getArgument(0);
+    Value modifier_arg = modifier_op->getOperand(0) == block_arg
+                             ? modifier_op->getOperand(1)
+                             : modifier_op->getOperand(0);
+    return std::make_pair(modifier_arg, *kind);
+  }
+
+  // Certain computations, such as floating-point addition and integer
+  // maximization, can be implemented directly with an LLVM atomic
+  // instruction. If the computation is of this kind, emits code to do that
+  // and returns success(); otherwise, returns failure().
+  LogicalResult rewriteAsDirectAtomicRMW(
+      AtomicRMWOp op, mlir::PatternRewriter& rewriter) const {
+    auto modifier_parameters = GetAtomicModifierParameters(op);
+    if (!modifier_parameters.has_value()) {
+      return failure();
+    }
+    Value modifier_arg = modifier_parameters->first;
+    Type element_type = modifier_arg.getType();
+    ml::AtomicBinOp atomic_bin_op = modifier_parameters->second;
+
+    Location loc = op.getLoc();
+    llvm::StringRef sync_scope = is_amd_ ? "agent" : "";
+    Value addr = CreateGep(op.getInput(), op.getIndices(), rewriter);
+
+    switch (atomic_bin_op) {
+      case ml::AtomicBinOp::xchg: {
+        rewriter.create<ml::StoreOp>(
+            loc, modifier_arg, addr,
+            /*alignment=*/element_type.getIntOrFloatBitWidth() / 8,
+            /*volatile*/ false, /*isNonTemporal=*/false,
+            ml::AtomicOrdering::unordered);
+        return success();
+      }
+      case ml::AtomicBinOp::add:
+      case ml::AtomicBinOp::max:
+      case ml::AtomicBinOp::min:
+      case ml::AtomicBinOp::umax:
+      case ml::AtomicBinOp::umin: {
+        rewriter.create<ml::AtomicRMWOp>(loc, atomic_bin_op, addr, modifier_arg,
+                                         ml::AtomicOrdering::seq_cst,
+                                         sync_scope);
+        return success();
+      }
+      case ml::AtomicBinOp::fadd: {
+        // TODO(b/336367154): Introduce an atomic_rmw op with the binOp attr.
+        return is_amd_ ? emitAMDAtomicFAdd(loc, modifier_arg, addr, sync_scope,
+                                           gpu_arch_, rewriter)
+                       : emitNVidiaAtomicFAdd(loc, modifier_arg, addr,
+                                              sync_scope, gpu_arch_, rewriter);
+      }
+      case ml::AtomicBinOp::fmax: {
+        return rewriteAtomicFMaxAsIntAtomics(loc, modifier_arg, addr,
+                                             sync_scope, rewriter);
+      }
+      default:
+        return failure();
+    }
+    return success();
+  }
+
+  LogicalResult emitNVidiaAtomicFAdd(Location loc, Value modifier_arg,
+                                     Value addr, llvm::StringRef sync_scope,
+                                     llvm::StringRef cuda_arch,
+                                     OpBuilder& b) const {
+    se::CudaComputeCapability cuda_compute_capability(cuda_arch.str());
+    Type element_type = modifier_arg.getType();
+    // "atom.add.f64 requires sm_60 or higher."
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
+    bool is_supported_f16_atomic =
+        element_type.isF16() &&
+        cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::VOLTA);
+    bool is_supported_bf16_atomic =
+        element_type.isBF16() &&
+        cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER);
+    bool is_supported_f64_atomic =
+        element_type.isF64() &&
+        cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::PASCAL_);
+    if (!element_type.isF32() && !is_supported_f16_atomic &&
+        !is_supported_bf16_atomic && !is_supported_f64_atomic) {
+      return failure();
+    }
+    b.create<ml::AtomicRMWOp>(loc, ml::AtomicBinOp::fadd, addr, modifier_arg,
+                              ml::AtomicOrdering::seq_cst, sync_scope);
+    return success();
+  }
+
+  LogicalResult emitAMDAtomicFAdd(Location loc, Value modifier_arg, Value addr,
+                                  llvm::StringRef sync_scope,
+                                  llvm::StringRef gcn_arch,
+                                  OpBuilder& b) const {
+    se::RocmComputeCapability rocm_compute_capability(gcn_arch.str());
+    Type element_type = modifier_arg.getType();
+    bool is_supported_f16_atomic =
+        element_type.isF16() &&
+        rocm_compute_capability.has_fp16_atomics_support();
+    if (!element_type.isF32() && !is_supported_f16_atomic) {
+      return failure();
+    }
+    constexpr int kGlobalMemory = 1;
+    constexpr int kSharedMemory = 3;
+    auto addr_type = mlir::cast<ml::LLVMPointerType>(addr.getType());
+    // Adds to shared memory are always atomic.
+    if (addr_type.getAddressSpace() != kSharedMemory) {
+      // The compiler will only generate a global_atomic_fadd if the pointer
+      // is in the global address space (1).
+      addr = b.create<ml::AddrSpaceCastOp>(
+          loc, ml::LLVMPointerType::get(b.getContext(), kGlobalMemory), addr);
+    }
+    b.create<ml::AtomicRMWOp>(loc, ml::AtomicBinOp::fadd, addr, modifier_arg,
+                              ml::AtomicOrdering::seq_cst, sync_scope);
+    return success();
+  }
+
+  LogicalResult rewriteAtomicFMaxAsIntAtomics(Location loc, Value modifier_arg,
+                                              Value addr,
+                                              llvm::StringRef sync_scope,
+                                              OpBuilder& b) const {
+    Type element_type = modifier_arg.getType();
+    if (!element_type.isF32()) {
+      return failure();
+    }
+    // Evaluating floating-point max using integer atomics has the limitation
+    // of not propagating -NaNs. To handle this, we check whether the update
+    // value is a NaN and replace it with a positive NaN.
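+    // The integer trick relies on IEEE-754 ordering: for non-negative floats
+    // the bit patterns compare like signed ints, so a signed atomic max
+    // works; for negative floats the order is reversed, so an unsigned
+    // atomic min is used instead (see the two branches below).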
+    Value current = b.create<ml::LoadOp>(loc, element_type, addr);
+    Value current_is_nan =
+        b.create<ml::FCmpOp>(loc, ml::FCmpPredicate::uno, current, current);
+    auto is_current_nan =
+        b.create<scf::IfOp>(loc, /*resultTypes=*/TypeRange{}, current_is_nan,
+                            /*addThenBlock=*/true, /*addElseBlock=*/true);
+    auto if_current_nan_then_builder =
+        OpBuilder::atBlockEnd(is_current_nan.thenBlock(), b.getListener());
+    if_current_nan_then_builder.create<scf::YieldOp>(loc);
+
+    auto if_current_nan_else_builder =
+        OpBuilder::atBlockEnd(is_current_nan.elseBlock(), b.getListener());
+    Value is_modifier_nan = if_current_nan_else_builder.create<ml::FCmpOp>(
+        loc, ml::FCmpPredicate::uno, modifier_arg, modifier_arg);
+    auto f32_nan = mlir::APFloat::getNaN(mlir::APFloat::IEEEsingle());
+    Value nan = if_current_nan_else_builder.create<ml::ConstantOp>(
+        loc, b.getF32Type(), f32_nan);
+    Value no_negative_nan_source =
+        if_current_nan_else_builder.create<ml::SelectOp>(loc, is_modifier_nan,
+                                                         nan, modifier_arg);
+    Value current_less_than_modifier =
+        if_current_nan_else_builder.create<ml::FCmpOp>(
+            loc, ml::FCmpPredicate::ult, current, no_negative_nan_source);
+
+    // This check allows us to skip the atomic update altogether at the
+    // expense of reading the value in memory for every update. Evaluated
+    // against Waymo's benchmarks, adding the check achieves better overall
+    // performance.
+    auto if_need_update = if_current_nan_else_builder.create<scf::IfOp>(
+        loc, /*resultTypes=*/TypeRange{}, current_less_than_modifier,
+        /*withElseRegion=*/true,
+        /*addElseBlock=*/false);
+    if_current_nan_else_builder.create<scf::YieldOp>(loc);
+
+    auto then_builder =
+        OpBuilder::atBlockEnd(if_need_update.thenBlock(), b.getListener());
+    Value source_float_as_int = then_builder.create<ml::BitcastOp>(
+        loc, then_builder.getI32Type(), no_negative_nan_source);
+    Value c0 = then_builder.create<ml::ConstantOp>(loc, b.getI32Type(), 0);
+    Value is_not_negative = then_builder.create<ml::ICmpOp>(
+        loc, ml::ICmpPredicate::sge, source_float_as_int, c0);
+    then_builder.create<scf::IfOp>(
+        loc, is_not_negative,
+        [&](OpBuilder& nested_b, Location nested_loc) {
+          // atomicMax((int *)address, __float_as_int(val))
+          nested_b.create<ml::AtomicRMWOp>(
+              loc, ml::AtomicBinOp::max, addr, source_float_as_int,
+              ml::AtomicOrdering::seq_cst, sync_scope);
+          nested_b.create<scf::YieldOp>(nested_loc);
+        },
+        [&](OpBuilder& nested_b, Location nested_loc) {
+          // atomicMin((unsigned int *)address, __float_as_uint(val))
+          nested_b.create<ml::AtomicRMWOp>(
+              loc, ml::AtomicBinOp::umin, addr, source_float_as_int,
+              ml::AtomicOrdering::seq_cst, sync_scope);
+          nested_b.create<scf::YieldOp>(nested_loc);
+        });
+    then_builder.create<scf::YieldOp>(loc);
+    return success();
+  }
+
+  // Implements atomic binary operations using atomic compare-and-swap
+  // (atomicCAS) as follows:
+  //   1. Reads the value from the memory pointed to by output_address and
+  //     records it as old_output.
+  //   2. Uses old_output as one of the source operands to perform the binary
+  //     operation and stores the result in new_output.
+  //   3. Calls atomicCAS which implements compare-and-swap as an atomic
+  //     operation. In particular, atomicCAS reads the value from the memory
+  //     pointed to by output_address, and compares the value with old_output.
+  //     If the two values are equal, new_output is written to the same memory
+  //     location and true is returned to indicate that the atomic operation
+  //     succeeds. Otherwise, the new value read from the memory is returned. In
+  //     this case, the new value is copied to old_output, and steps 2. and 3.
+  //     are repeated until atomicCAS succeeds.
+  //
+  // On Nvidia GPUs, atomicCAS can only operate on 32-bit and 64-bit integers.
+  // If the element type of the binary operation is 32 bits or 64 bits, the
+  // integer type of the same size is used for the atomicCAS operation. On the
+  // other hand, if the element type is smaller than 32 bits, int32_t is used
+  // for the atomicCAS operation. In this case, atomicCAS reads and writes 32
+  // bit values from the memory, which is larger than the memory size required
+  // by the original atomic binary operation. We mask off the last two bits of
+  // the output_address and use the result as an address to read the 32 bit
+  // values from the memory. This can avoid out-of-bounds memory accesses if
+  // tensor buffers are 4 byte aligned and have a size of 4N, an assumption that
+  // the runtime can guarantee.
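+  //
+  // Sketch of the loop that is emitted below (illustrative pseudocode only;
+  // masking for sub-32-bit types is omitted):
+  //
+  //   old = *addr;
+  //   do {
+  //     new = compute(old);                  // the atomic_rmw body
+  //     (current, ok) = cmpxchg(addr, old, new);
+  //     old = current;
+  //   } while (!ok);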
+  void rewriteAsAtomicCAS(AtomicRMWOp op,
+                          mlir::PatternRewriter& rewriter) const {
+    Location loc = op.getLoc();
+    auto input = op.getInput();
+
+    // Use 32-bit atomic type for small input types.
+    Type result_ty = op.getResult().getType().getElementType();
+    int result_size;
+    if (auto complex_ty = mlir::dyn_cast<mlir::ComplexType>(result_ty)) {
+      result_size = complex_ty.getElementType().getIntOrFloatBitWidth() * 2;
+    } else {
+      result_size = result_ty.getIntOrFloatBitWidth();
+    }
+
+    bool small_type = result_size < 32;
+    Type atomic_ty =
+        mlir::IntegerType::get(op.getContext(), small_type ? 32 : result_size);
+
+    // Calculate load address for the input.
+    Value addr = CreateGep(input, op.getIndices(), rewriter);
+    Value shift, mask;
+    if (small_type) {
+      // Update input pointer by discarding the last two bits - i.e. align to
+      // 32-bit boundary for small input types (will not result in OOB, as the
+      // input alignment is at least 32 bits).
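+      // Illustrative example: for a 16-bit element at byte offset 2, the
+      // address is moved back by 2 bytes, `shift` becomes 16 and the mask
+      // computed below is 0x0000ffff, so the untouched low half word is
+      // preserved while the high half word is updated.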
+      Type addr_int_ty = rewriter.getI64Type();
+      Value addr_int = rewriter.create<ml::PtrToIntOp>(loc, addr_int_ty, addr);
+      Value addr_offset = rewriter.create<ml::AndOp>(
+          loc, addr_int, rewriter.create<ml::ConstantOp>(loc, addr_int_ty, 3));
+      Value index = rewriter.create<ml::MulOp>(
+          loc, addr_offset,
+          rewriter.create<ml::ConstantOp>(loc, addr_int_ty, -1));
+      addr =
+          rewriter.create<ml::GEPOp>(loc, addr.getType(), rewriter.getI8Type(),
+                                     addr, index, /*inbounds=*/true);
+
+      // Calculate the bit shift (assume little-endianness).
+      Value offset = rewriter.create<ml::TruncOp>(loc, atomic_ty, addr_offset);
+      shift = rewriter.create<ml::MulOp>(
+          loc, offset,
+          rewriter.create<ml::ConstantOp>(loc, offset.getType(), 8));
+
+      // Compose the update mask.
+      Value bits_long = rewriter.create<ml::ConstantOp>(loc, atomic_ty, -1);
+      Value bits_short = rewriter.create<ml::ZExtOp>(
+          loc, atomic_ty,
+          rewriter.create<ml::ConstantOp>(
+              loc, rewriter.getIntegerType(result_size), -1));
+      mask = rewriter.create<ml::XOrOp>(
+          loc, bits_long, rewriter.create<ml::ShlOp>(loc, bits_short, shift));
+    }
+
+    // Load initial atomic value and create the loop.
+    Value initial = rewriter.create<ml::LoadOp>(loc, atomic_ty, addr);
+    rewriter.create<scf::WhileOp>(
+        loc, TypeRange{atomic_ty}, ValueRange{initial},
+        [&](mlir::OpBuilder& builder, Location loc, ValueRange values) {
+          mlir::ImplicitLocOpBuilder b(loc, builder);
+          Value old_value = values[0];
+
+          // Convert atomic value to input value.
+          Value input_value;
+          if (small_type) {
+            Value short_value =
+                b.create<ml::TruncOp>(b.getIntegerType(result_size),
+                                      b.create<ml::LShrOp>(old_value, shift));
+            input_value = b.create<ml::BitcastOp>(result_ty, short_value);
+          } else {
+            input_value = CreateBitcast(b, old_value, result_ty);
+          }
+
+          // Perform computation on the loaded input value.
+          rewriter.mergeBlocks(&op.getComputation().front(), b.getBlock(),
+                               {input_value});
+          auto yield_op = b.getBlock()->getTerminator();
+          Value result = yield_op->getOperand(0);
+          rewriter.eraseOp(yield_op);
+
+          // Convert resulting value to atomic value.
+          Value new_value;
+          if (small_type) {
+            Value cast_value = b.create<ml::ZExtOp>(
+                atomic_ty, b.create<ml::BitcastOp>(
+                               rewriter.getIntegerType(result_size), result));
+            new_value =
+                b.create<ml::OrOp>(b.create<ml::AndOp>(old_value, mask),
+                                   b.create<ml::ShlOp>(cast_value, shift));
+          } else {
+            new_value = CreateBitcast(b, result, atomic_ty);
+          }
+
+          // Try saving the result atomically; retry if it fails.
+          Value cmpxchg = b.create<ml::AtomicCmpXchgOp>(
+              loc, addr, old_value, new_value,
+              /*success_ordering=*/ml::AtomicOrdering::seq_cst,
+              /*failure_ordering=*/ml::AtomicOrdering::seq_cst);
+          Value next = b.create<ml::ExtractValueOp>(cmpxchg, 0);
+          Value ok = b.create<ml::ExtractValueOp>(cmpxchg, 1);
+          Value low_bit = b.create<ml::ConstantOp>(b.getOneAttr(b.getI1Type()));
+          Value not_ok = b.create<ml::XOrOp>(ok, low_bit);
+          b.create<scf::ConditionOp>(not_ok, ValueRange{next});
+        },
+        [&](mlir::OpBuilder& b, Location loc, ValueRange values) {
+          b.create<scf::YieldOp>(loc, values);
+        });
+  }
+
+  bool is_amd_;
+  std::string gpu_arch_;
+};
+
+class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
+ public:
+  explicit LowerTensorsPass(const LowerTensorsPassOptions& options)
+      : LowerTensorsPassBase(options) {}
+
+  void runOnOperation() override {
+    MLIRContext* mlir_context = &getContext();
+    mlir::RewritePatternSet tensor_patterns(mlir_context);
+    tensor_patterns.add<RewriteAtomicRMW>(mlir_context, is_amd_gpu_, gpu_arch_);
+    tensor_patterns
+        .add<RewriteAllocateShared, RewriteNonScalarConstants,
+             RewriteSyncThreads, RewriteTensorExtract, RewriteTransferRead,
+             RewriteTensorInsert, RewriteTransferWrite>(mlir_context);
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            getOperation(), std::move(tensor_patterns)))) {
+      signalPassFailure();
+      return;
+    }
+
+    mlir::RewritePatternSet function_patterns(mlir_context);
+    function_patterns.add<RewriteFunctionSignatures, RewriteCall,
+                          RemoveUnusedIndexSwitchResults>(mlir_context);
+    scf::ForOp::getCanonicalizationPatterns(function_patterns, mlir_context);
+    scf::IfOp::getCanonicalizationPatterns(function_patterns, mlir_context);
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            getOperation(), std::move(function_patterns)))) {
+      signalPassFailure();
+      return;
+    }
+
+    getOperation()->walk([this](mlir::LLVM::LoadOp load) {
+      Value addr = load.getAddr();
+      while (auto gep = addr.getDefiningOp<mlir::LLVM::GEPOp>()) {
+        addr = gep.getBase();
+      }
+      if (addr.getDefiningOp<mlir::LLVM::AddrSpaceCastOp>() ||
+          addr.getDefiningOp<mlir::LLVM::AddressOfOp>() ||
+          addr.getDefiningOp<mlir::LLVM::AllocaOp>()) {
+        // Shared memory, global constant or temporary - no need to annotate
+        // anything.
+        return;
+      }
+      if (auto base = mlir::dyn_cast<mlir::BlockArgument>(addr)) {
+        if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(
+                base.getOwner()->getParentOp())) {
+          if (func.getArgAttr(base.getArgNumber(), "xla.invariant")) {
+            load.setInvariant(true);
+          }
+          return;
+        }
+      }
+      load.emitOpError("load op address is not (a GEP of) a function argument");
+      signalPassFailure();
+    });
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<::mlir::Pass> CreateLowerTensorsPass(
+    bool is_amd_gpu, const std::string& gpu_arch) {
+  LowerTensorsPassOptions options;
+  options.is_amd_gpu_ = is_amd_gpu;
+  options.gpu_arch_ = gpu_arch;
+  return std::make_unique<LowerTensorsPass>(options);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_to_llvm.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_to_llvm.cc
new file mode 100644
index 0000000..28762d0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_to_llvm.cc
@@ -0,0 +1,99 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // IWYU pragma: keep
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"  // IWYU pragma: keep
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_LOWERTOLLVMPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
+ public:
+  using LowerToLLVMPassBase::LowerToLLVMPassBase;
+
+  void runOnOperation() override {
+    // Populate type conversions.
+    mlir::LowerToLLVMOptions llvm_opts(&getContext(),
+                                       mlir::DataLayout(getOperation()));
+    mlir::LLVMTypeConverter type_converter(getOperation().getContext(),
+                                           llvm_opts);
+    mlir::LLVMConversionTarget target(*getOperation().getContext());
+
+    // Populate patterns.
+    mlir::RewritePatternSet patterns(&getContext());
+    mlir::populateAffineToStdConversionPatterns(patterns);
+    mlir::populateSCFToControlFlowConversionPatterns(patterns);
+    mlir::arith::populateArithExpandOpsPatterns(patterns);
+    mlir::arith::populateArithToLLVMConversionPatterns(type_converter,
+                                                       patterns);
+    mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns);
+    mlir::populateFuncToLLVMConversionPatterns(type_converter, patterns);
+    mlir::populateVectorToLLVMConversionPatterns(type_converter, patterns);
+    mlir::cf::populateControlFlowToLLVMConversionPatterns(type_converter,
+                                                          patterns);
+    mlir::populateComplexToLLVMConversionPatterns(type_converter, patterns);
+    mlir::populateMathToLLVMConversionPatterns(type_converter, patterns);
+
+    // Set up the conversion target.
+    mlir::configureGpuToNVVMConversionLegality(target);
+    target.addIllegalDialect<mlir::arith::ArithDialect, mlir::func::FuncDialect,
+                             mlir::complex::ComplexDialect,
+                             mlir::math::MathDialect>();
+    target.addLegalOp<mlir::ModuleOp>();
+
+    if (failed(
+            applyFullConversion(getOperation(), target, std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass() {
+  return std::make_unique<LowerToLLVMPass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc
new file mode 100644
index 0000000..cbd64b8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc
@@ -0,0 +1,278 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <utility>
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/model/indexing_map.h"
+#include "xla/util.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+#define GEN_PASS_DEF_LOWERXLAGPUTOSCFPASS
+#define GEN_PASS_DEF_LOWERXLAGPULOOPSTOSCFPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+using mlir::ImplicitLocOpBuilder;
+using mlir::Location;
+using mlir::OpBuilder;
+using mlir::SmallVector;
+using mlir::success;
+using mlir::Value;
+using mlir::ValueRange;
+using mlir::scf::IfOp;
+
+struct RewritePredicatedInsert : mlir::OpRewritePattern<PredicatedInsertOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      PredicatedInsertOp op, mlir::PatternRewriter& rewriter) const override {
+    rewriter.replaceOpWithNewOp<mlir::scf::IfOp>(
+        op, op.getCondition(),
+        [&](mlir::OpBuilder& b, mlir::Location loc) {
+          b.create<mlir::scf::YieldOp>(
+              loc, b.create<mlir::tensor::InsertOp>(
+                        loc, op.getValue(), op.getDest(), op.getIndices())
+                       .getResult());
+        },
+        [&](mlir::OpBuilder& b, mlir::Location loc) {
+          b.create<mlir::scf::YieldOp>(loc, op.getDest());
+        });
+    return success();
+  }
+};
+
+struct RewritePredicatedExtract : mlir::OpRewritePattern<PredicatedExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      PredicatedExtractOp op, mlir::PatternRewriter& rewriter) const override {
+    rewriter.replaceOpWithNewOp<mlir::scf::IfOp>(
+        op, op.getCondition(),
+        [&](mlir::OpBuilder& b, mlir::Location loc) {
+          b.create<mlir::scf::YieldOp>(
+              loc, b.create<mlir::tensor::ExtractOp>(loc, op.getSrc(),
+                                                     op.getIndices())
+                       .getResult());
+        },
+        [&](mlir::OpBuilder& b, mlir::Location loc) {
+          b.create<mlir::scf::YieldOp>(loc, op.getFallback());
+        });
+    return success();
+  }
+};
+
+struct RewriteShuffleReduce : mlir::OpRewritePattern<ShuffleReduceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      ShuffleReduceOp op, mlir::PatternRewriter& rewriter) const override {
+    int max_distance =
+        mlir::cast<mlir::IntegerAttr>(op->getAttr("max_distance")).getInt();
+    // TODO(jreiffers): Do this in a verifier.
+    if (max_distance & (max_distance - 1) || max_distance >= WarpSize()) {
+      return op->emitOpError("max_distance must be a power of 2 < WarpSize()");
+    }
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    ValueRange values = op.getOperands();
+    for (int distance = max_distance; distance > 0; distance /= 2) {
+      namespace ml = mlir::LLVM;
+      auto shuffle_32 = [&](Value v) {
+        return b
+            .create<mlir::gpu::ShuffleOp>(v, distance, WarpSize(),
+                                          mlir::gpu::ShuffleMode::DOWN)
+            .getShuffleResult();
+      };
+
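+      // Shuffles a value of arbitrary bit width by splitting it into 32-bit
+      // lanes. Sketch of the f64 case: the value is bitcast to i64, split
+      // into a vector<2xi32>, each lane is shuffled separately, and the
+      // lanes are recombined and bitcast back to f64.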
+      auto shuffle_int_or_float = [&](Value value) {
+        auto ty = value.getType();
+        int bit_width = ty.getIntOrFloatBitWidth();
+        if (bit_width == 32) {
+          return shuffle_32(value);
+        }
+        int n_shuffles = CeilOfRatio(bit_width, 32);
+        auto int_ty = b.getIntegerType(bit_width);
+        auto padded_int_ty = b.getIntegerType(n_shuffles * 32);
+        value = b.create<mlir::arith::BitcastOp>(int_ty, value);
+        value = b.create<mlir::arith::ExtUIOp>(padded_int_ty, value);
+        if (n_shuffles > 1) {
+          // Don't generate vectors if the size is 1.
+          auto vector_type = ml::getVectorType(b.getI32Type(), n_shuffles);
+          value = b.create<ml::BitcastOp>(vector_type, value);
+          Value result_vec = b.create<ml::UndefOp>(vector_type);
+          for (int i = 0; i < n_shuffles; ++i) {
+            auto idx = b.create<mlir::arith::ConstantIntOp>(i, 32);
+            result_vec = b.create<ml::InsertElementOp>(
+                result_vec,
+                shuffle_32(b.create<ml::ExtractElementOp>(value, idx)), idx);
+          }
+          value = b.create<ml::BitcastOp>(padded_int_ty, result_vec);
+        } else {
+          value = shuffle_32(value);
+        }
+        value = b.create<mlir::arith::TruncIOp>(int_ty, value);
+        value = b.create<ml::BitcastOp>(ty, value);
+        return value;
+      };
+
+      auto shuffle = [&](Value value) -> Value {
+        if (mlir::isa<mlir::ComplexType>(value.getType())) {
+          return b.create<mlir::complex::CreateOp>(
+              value.getType(),
+              shuffle_int_or_float(b.create<mlir::complex::ReOp>(value)),
+              shuffle_int_or_float(b.create<mlir::complex::ImOp>(value)));
+        }
+        if (value.getType().isUnsignedInteger()) {
+          auto ty = value.getType();
+          auto signless_ty = b.getIntegerType(ty.getIntOrFloatBitWidth());
+          value = b.create<mlir::UnrealizedConversionCastOp>(
+                       mlir::TypeRange{signless_ty}, value)
+                      .getResult(0);
+          value = shuffle_int_or_float(value);
+          value = b.create<mlir::UnrealizedConversionCastOp>(
+                       mlir::TypeRange{ty}, value)
+                      .getResult(0);
+          return value;
+        }
+        return shuffle_int_or_float(value);
+      };
+
+      SmallVector<Value> args = values;
+      for (auto value : values) {
+        args.push_back(shuffle(value));
+      }
+      values = b.create<PureCallOp>(op.getResultTypes(),
+                                    op.getReducerAttr().getAttr(), args)
+                   .getResults();
+    }
+    rewriter.replaceOp(op, values);
+    return success();
+  }
+};
+
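+// Lowers an `xla_gpu.loop` over an indexing map to an scf loop nest, with one
+// loop per symbol of the map. The map's constraints are checked inside the
+// innermost body; out-of-bounds iterations simply forward the iter_args.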
+struct RewriteXlaGpuLoop : mlir::OpRewritePattern<LoopOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      LoopOp op, mlir::PatternRewriter& rewriter) const override {
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+
+    IndexingMap indexing_map = op.getIndexingMap();
+    SmallVector<Value, 4> lbs, ubs, steps;
+    mlir_converter::GetLoopBoundsFromIndexingMap(b, indexing_map, &lbs, &ubs,
+                                                 &steps);
+    mlir::scf::LoopNest loop_nest = mlir::scf::buildLoopNest(
+        b, loc, lbs, ubs, steps, op.getInits(),
+        [&](OpBuilder& nested_builder, Location loc, ValueRange symbol_values,
+            ValueRange iter_args) -> mlir::scf::ValueVector {
+          mlir::ImplicitLocOpBuilder nested_b(loc, nested_builder);
+          auto is_in_bounds = mlir_converter::CheckConstraints(
+              indexing_map, op.getDims(), symbol_values, nested_b);
+          auto if_op = nested_b.create<mlir::scf::IfOp>(
+              is_in_bounds,
+              [&](OpBuilder& then_builder, Location then_loc) -> void {
+                SmallVector<Value, 4> bb_args(symbol_values);
+                bb_args.append(iter_args.begin(), iter_args.end());
+
+                mlir::Block* then_block = then_builder.getInsertionBlock();
+                OpBuilder::InsertionGuard guard(rewriter);
+                rewriter.setInsertionPointToStart(then_block);
+                rewriter.mergeBlocks(op.getBody(), then_block, bb_args);
+
+                auto old_terminator = then_block->getTerminator();
+                then_builder.create<mlir::scf::YieldOp>(
+                    then_loc, old_terminator->getOperands());
+                old_terminator->erase();
+              },
+              [&](OpBuilder& else_b, Location else_loc) {
+                else_b.create<mlir::scf::YieldOp>(loc, iter_args);
+              });
+          return if_op.getResults();
+        });
+    rewriter.replaceOp(op, loop_nest.results);
+    return mlir::success();
+  }
+};
+
+class LowerXlaGpuToScfPass
+    : public impl::LowerXlaGpuToScfPassBase<LowerXlaGpuToScfPass> {
+ public:
+  void runOnOperation() override {
+    auto* ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
+    patterns.add<RewritePredicatedInsert, RewritePredicatedExtract,
+                 RewriteShuffleReduce>(ctx);
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                                        std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+class LowerXlaGpuLoopsToScfPass
+    : public impl::LowerXlaGpuLoopsToScfPassBase<LowerXlaGpuLoopsToScfPass> {
+ public:
+  void runOnOperation() override {
+    auto* ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
+    patterns.add<RewriteXlaGpuLoop>(ctx);
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                                        std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuToScfPass() {
+  return std::make_unique<LowerXlaGpuToScfPass>();
+}
+
+std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuLoopsToScfPass() {
+  return std::make_unique<LowerXlaGpuLoopsToScfPass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/merge_pointers_to_same_slice.cc b/third_party/xla/xla/service/gpu/fusions/transforms/merge_pointers_to_same_slice.cc
new file mode 100644
index 0000000..50193e3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/merge_pointers_to_same_slice.cc
@@ -0,0 +1,117 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "absl/container/flat_hash_map.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_MERGEPOINTERSTOSAMESLICEPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+class MergePointersToSameSlicePass
+    : public impl::MergePointersToSameSlicePassBase<
+          MergePointersToSameSlicePass> {
+ public:
+  void runOnOperation() override;
+};
+
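+// Computes which function arguments can be merged because they point into the
+// same buffer slice, as marked by the `xla.slice_index` argument attribute.
+// Illustrative example: if arguments 0..2 carry slice indices [0, 1, 0],
+// argument 2 is replaced by argument 0 and erased.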
+struct PackedArgs {
+  llvm::BitVector args_to_erase;
+  // replacement_args[i] == i iff !args_to_erase[i].
+  llvm::SmallVector<int> replacement_args;
+
+  PackedArgs() = default;
+  explicit PackedArgs(mlir::func::FuncOp func) {
+    absl::flat_hash_map<int, std::optional<int>> slice_to_operand;
+    args_to_erase.resize(func.getNumArguments());
+    replacement_args.reserve(func.getNumArguments());
+    for (int i = 0; i < func.getNumArguments(); ++i) {
+      replacement_args.push_back(i);
+    }
+
+    for (auto [idx, operand] : llvm::enumerate(func.getArguments())) {
+      auto slice_index = func.getArgAttr(idx, "xla.slice_index");
+      if (!slice_index) {
+        continue;
+      }
+
+      auto& target_index = slice_to_operand[static_cast<int>(
+          mlir::cast<mlir::IntegerAttr>(slice_index).getInt())];
+      if (target_index) {
+        replacement_args[idx] = *target_index;
+        args_to_erase[idx] = true;
+      } else {
+        target_index = idx;
+      }
+    }
+  }
+
+  void Pack(mlir::func::FuncOp op) {
+    for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
+      if (replacement_args[idx] != idx) {
+        arg.replaceAllUsesWith(op.getArgument(replacement_args[idx]));
+      }
+    }
+    op.eraseArguments(args_to_erase);
+    for (int i = 0; i < op.getNumArguments(); ++i) {
+      if (op.getArgAttr(i, "xla.slice_index")) {
+        op.removeArgAttr(i, "xla.slice_index");
+        op.setArgAttr(i, mlir::LLVM::LLVMDialect::getNoAliasAttrName(),
+                      mlir::UnitAttr::get(op->getContext()));
+      }
+    }
+  }
+
+  void Pack(mlir::func::CallOp op) { op->eraseOperands(args_to_erase); }
+};
+
+void MergePointersToSameSlicePass::runOnOperation() {
+  mlir::func::FuncOp entry;
+
+  absl::flat_hash_map<std::string, PackedArgs> args_to_pack;
+  getOperation()->walk([&](mlir::func::FuncOp func) {
+    args_to_pack[func.getName()] = PackedArgs(func);
+  });
+  getOperation()->walk([&](mlir::func::CallOp call) {
+    args_to_pack[call.getCallee()].Pack(call);
+  });
+  getOperation()->walk([&](mlir::func::FuncOp func) {
+    args_to_pack[func.getName()].Pack(func);
+  });
+}
+
+}  // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateMergePointersToSameSlicePass() {
+  return std::make_unique<MergePointersToSameSlicePass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc
new file mode 100644
index 0000000..e483bfe
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc
@@ -0,0 +1,326 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "absl/algorithm/container.h"
+#include "absl/log/check.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/model/indexing_map.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_OPTIMIZELOOPSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+mlir::Value GetSource(mlir::vector::TransferReadOp op) {
+  return op.getSource();
+}
+
+bool DoIndicesDependOnInductionVar(mlir::ValueRange indices,
+                                   mlir::scf::ForOp loop) {
+  // We assume LICM ran, so we can just check if any index is defined in the
+  // loop.
+  return absl::c_any_of(indices, [&](mlir::Value v) {
+    return v.getParentRegion() == &loop.getBodyRegion();
+  });
+}
+
+bool CanReplaceInductionVar(mlir::ValueRange indices) {
+  return absl::c_all_of(indices, [&](mlir::Value v) {
+    if (auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(v)) {
+      auto for_op = mlir::dyn_cast_or_null<mlir::scf::ForOp>(
+          v.getParentRegion()->getParentOp());
+      // This is a bbarg that is defined outside of the loop, so it doesn't
+      // affect pipelining.
+      if (!for_op) {
+        return true;
+      }
+      // We can only replace the induction variable, not other loop-carried
+      // values.
+      return v == for_op.getInductionVar();
+    }
+    auto* op = v.getDefiningOp();
+    return op &&
+           mlir::isa<mlir::arith::ConstantOp, ApplyIndexingOp,
+                     mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
+                     mlir::arith::IndexCastOp, mlir::arith::IndexCastUIOp>(
+               op) &&
+           CanReplaceInductionVar(op->getOperands());
+  });
+}
+
+llvm::SmallVector<mlir::Value> ReplaceInductionVar(
+    mlir::Value induction_var, mlir::Value replacement,
+    llvm::SmallVector<mlir::Value> indices,
+    mlir::ImplicitLocOpBuilder& builder) {
+  for (mlir::Value& index : indices) {
+    if (mlir::isa<mlir::BlockArgument>(index)) {
+      if (index == induction_var) {
+        index = replacement;
+      }
+      continue;
+    }
+
+    auto* op = index.getDefiningOp();
+    CHECK(op) << "Did CanReplaceInductionVar() fail?";
+    if (mlir::isa<mlir::arith::ConstantOp>(op)) {
+      continue;
+    }
+
+    CHECK(
+        (mlir::isa<ApplyIndexingOp, mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
+                   mlir::arith::IndexCastOp, mlir::arith::IndexCastUIOp>(op)))
+        << "Did CanReplaceInductionVar() fail?";
+    auto replaced_args = ReplaceInductionVar(induction_var, replacement,
+                                             op->getOperands(), builder);
+    index = builder
+                .create(builder.getLoc(), op->getName().getIdentifier(),
+                        replaced_args, op->getResultTypes(), op->getAttrs())
+                ->getResult(0);
+  }
+  return indices;
+}
+
+mlir::Value GetSource(mlir::tensor::ExtractOp op) { return op.getTensor(); }
+
+// TODO(jreiffers): Use a shared memory queue for pipelining instead of
+// registers.
+template <typename Op>
+struct PipelineLoad : mlir::OpRewritePattern<Op> {
+  using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      Op op, mlir::PatternRewriter& rewriter) const override {
+    auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
+    if (!loop) {
+      return rewriter.notifyMatchFailure(op, "no loop found");
+    }
+
+    if (auto step = loop.getConstantStep();
+        !step || step->getSExtValue() != 1) {
+      return rewriter.notifyMatchFailure(op, "loop step is not 1");
+    }
+
+    llvm::APInt lb, ub;
+    if (!mlir::matchPattern(loop.getLowerBound(), mlir::m_ConstantInt(&lb)) ||
+        !mlir::matchPattern(loop.getUpperBound(), mlir::m_ConstantInt(&ub))) {
+      return rewriter.notifyMatchFailure(op, "bounds are not constants");
+    }
+    if (lb.getSExtValue() != 0) {
+      return rewriter.notifyMatchFailure(op, "lower bound is not 0");
+    }
+
+    auto source = GetSource(op);
+    if (!source.getParentRegion()->isProperAncestor(&loop.getBodyRegion())) {
+      return rewriter.notifyMatchFailure(
+          op, "source is not defined outside the loop");
+    }
+
+    if (!DoIndicesDependOnInductionVar(op.getIndices(), loop)) {
+      // We don't run LICM between iterations, so this could happen.
+      // Just hoist the load out of the loop.
+      rewriter.moveOpBefore(op, loop);
+      return mlir::success();
+    }
+
+    if (!CanReplaceInductionVar(op.getIndices())) {
+      return rewriter.notifyMatchFailure(op, "unable to replace indices");
+    }
+
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    mlir::Value zero = b.create<mlir::arith::ConstantIndexOp>(0);
+
+    b.setInsertionPoint(loop);
+    auto first_args =
+        ReplaceInductionVar(loop.getInductionVar(), zero, op.getOperands(), b);
+    auto loaded_first =
+        b.create<Op>(op->getResultTypes(), first_args, op->getAttrs());
+    auto ub_minus_one =
+        b.create<mlir::arith::ConstantIndexOp>(ub.getSExtValue() - 1);
+
+    b.setInsertionPointToStart(loop.getBody());
+
+    auto needs_load = b.create<mlir::arith::CmpIOp>(
+        mlir::arith::CmpIPredicate::ult, loop.getInductionVar(), ub_minus_one);
+    auto next_value =
+        b.create<mlir::scf::IfOp>(op->getResultTypes(), needs_load, true, true);
+    auto new_for =
+        mlir::cast<mlir::scf::ForOp>(*loop.replaceWithAdditionalYields(
+            rewriter, loaded_first->getResult(0),
+            /*replaceInitOperandUsesInLoop=*/false,
+            [&](mlir::OpBuilder&, mlir::Location,
+                llvm::ArrayRef<mlir::BlockArgument>) {
+              return llvm::SmallVector<mlir::Value>{next_value->getResult(0)};
+            }));
+    rewriter.replaceAllUsesWith(op, new_for.getRegionIterArgs().back());
+
+    b.setInsertionPointToStart(next_value.thenBlock());
+    auto yield = b.create<mlir::scf::YieldOp>(op->getResult(0));
+
+    // We use this convoluted way to add 1 so folding works properly.
+    auto plus_one_map = mlir::AffineMap::get(
+        1, 0, mlir::getAffineDimExpr(0, this->getContext()) + 1);
+    b.setInsertionPoint(next_value);
+    IndexingMap indexing_map(plus_one_map, {DimVar{0, ub.getSExtValue() - 1}},
+                             /*range_vars=*/{}, /*rt_vars=*/{});
+    auto induction_plus_one =
+        b.create<ApplyIndexingOp>(new_for.getInductionVar(), indexing_map)
+            ->getResult(0);
+
+    // Create the new apply_indexing ops outside the if, to improve CSE.
+    rewriter.modifyOpInPlace(op, [&]() {
+      op->setOperands(ReplaceInductionVar(
+          new_for.getInductionVar(), induction_plus_one, op->getOperands(), b));
+    });
+    rewriter.moveOpBefore(op, yield);
+
+    b.setInsertionPointToStart(next_value.elseBlock());
+    b.create<mlir::scf::YieldOp>(new_for.getRegionIterArgs().back());
+    return mlir::success();
+  }
+};
+
+int GetUnrollingFactor(mlir::scf::ForOp op) {
+  // We only unroll loops with a step of 1 and a lower bound of 0. That's the
+  // only type we generate.
+  if (auto step = op.getConstantStep(); !step || step->getSExtValue() != 1) {
+    return 1;
+  }
+  llvm::APInt lb, ub;
+  if (!mlir::matchPattern(op.getLowerBound(), mlir::m_ConstantInt(&lb)) ||
+      !mlir::matchPattern(op.getUpperBound(), mlir::m_ConstantInt(&ub))) {
+    return 1;
+  }
+  if (lb.getSExtValue() != 0) {
+    return 1;
+  }
+
+  int64_t trip_count = ub.getSExtValue();
+  constexpr int kMaxSize = 400;  // Chosen empirically.
+
+  // Get a rough estimate of the size of the loop body.
+  int64_t size = 0;
+  op.getBodyRegion().walk([&](mlir::Operation* op) {
+    if (mlir::isa<mlir::func::CallOp, mlir::scf::ForOp>(op)) {
+      size += kMaxSize;
+      return;
+    }
+
+    int64_t this_size = 1;
+    if (mlir::isa<mlir::math::MathDialect>(op->getDialect())) {
+      // Integer instructions in math are ok, but many float ops lower to lots
+      // of instructions.
+      if (!op->getResultTypes().front().isIntOrIndex()) {
+        namespace mm = mlir::math;
+        // We err on the side of not unrolling, so we maintain a list of ops
+        // known to be cheap.
+        if (!mlir::isa<mm::AbsFOp, mm::CeilOp, mm::CopySignOp, mm::FloorOp,
+                       mm::FmaOp, mm::RoundEvenOp, mm::RoundOp, mm::RsqrtOp,
+                       mm::SqrtOp, mm::TruncOp>(op)) {
+          this_size = 20;  // Rough estimate.
+        }
+      }
+    }
+
+    if (!op->getResultTypes().empty()) {
+      if (auto vector_ty =
+              mlir::dyn_cast<mlir::VectorType>(op->getResultTypes().front())) {
+        this_size *= vector_ty.getNumElements();
+      }
+    }
+
+    size += this_size;
+  });
+
+  int factor = std::min(trip_count, kMaxSize / size);
+  while (factor > 1 && trip_count % factor) {
+    --factor;
+  }
+  return factor;
+}
+
+struct UnrollLoops : mlir::OpRewritePattern<mlir::scf::ForOp> {
+  using mlir::OpRewritePattern<mlir::scf::ForOp>::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
+    if (int factor = GetUnrollingFactor(op); factor > 1) {
+      return mlir::loopUnrollByFactor(op, factor);
+    }
+    return rewriter.notifyMatchFailure(op, "loop can't be unrolled");
+  }
+};
+
+class OptimizeLoopsPass
+    : public impl::OptimizeLoopsPassBase<OptimizeLoopsPass> {
+ public:
+  void runOnOperation() override {
+    // First unroll loops. If unrolling is possible, we prefer it.
+    mlir::RewritePatternSet unroll_patterns(&getContext());
+    unroll_patterns.add<UnrollLoops>(&getContext());
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            getOperation(), std::move(unroll_patterns)))) {
+      signalPassFailure();
+      return;
+    }
+
+    // Then pipeline the remaining loops.
+    mlir::RewritePatternSet patterns(&getContext());
+    patterns.add<PipelineLoad<mlir::vector::TransferReadOp>,
+                 PipelineLoad<mlir::tensor::ExtractOp>>(&getContext());
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                                        std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateOptimizeLoopsPass() {
+  return std::make_unique<OptimizeLoopsPass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
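The trip-count/size heuristic in GetUnrollingFactor above reduces to a small arithmetic helper: take the largest factor that fits the size budget, then shrink it until it divides the trip count. A self-contained sketch with made-up numbers (not the XLA code itself):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int64_t UnrollFactor(int64_t trip_count, int64_t body_size, int64_t budget) {
  if (body_size <= 0 || trip_count <= 0) return 1;
  int64_t factor = std::min(trip_count, budget / body_size);
  // Shrink until the factor divides the trip count evenly.
  while (factor > 1 && trip_count % factor != 0) --factor;
  return std::max<int64_t>(factor, 1);
}

int main() {
  // With a budget of 400, a body of size 30 and 12 iterations unrolls by 12;
  // a body of size 90 only allows a factor of 4.
  std::printf("%lld\n", static_cast<long long>(UnrollFactor(12, 30, 400)));
  std::printf("%lld\n", static_cast<long long>(UnrollFactor(12, 90, 400)));
  return 0;
}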
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.h b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h
new file mode 100644
index 0000000..e70af75
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h
@@ -0,0 +1,63 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_
+#define XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_
+
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "xla/service/gpu/model/indexing_map.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DECL
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+// Returns the range of a given value, if it can be statically determined.
+std::optional<Interval> GetRange(mlir::Value value);
+
+// Returns the range for the induction variable, if it can be statically
+// determined.
+std::optional<Interval> GetIVRange(mlir::Value iv);
+
+std::unique_ptr<mlir::Pass> CreateEraseDeadFunctionsPass();
+std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass(bool pre_ampere);
+std::unique_ptr<mlir::Pass> CreateConvertPureCallOpsPass();
+std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass();
+std::unique_ptr<mlir::Pass> CreateLowerTensorsPass(
+    bool is_amd_gpu = false, const std::string& gpu_arch = "6.0");
+std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass();
+std::unique_ptr<mlir::Pass> CreateLowerXlaGpuToScfPass();
+std::unique_ptr<mlir::Pass> CreateLowerXlaGpuLoopsToScfPass();
+std::unique_ptr<mlir::Pass> CreateMergePointersToSameSlicePass();
+std::unique_ptr<mlir::Pass> CreateOptimizeLoopsPass();
+std::unique_ptr<mlir::Pass> CreatePeelLoopsPass();
+std::unique_ptr<mlir::Pass> CreatePropagateSliceIndicesPass();
+std::unique_ptr<mlir::Pass> CreateSimplifyAffinePass();
+std::unique_ptr<mlir::Pass> CreateSimplifyArithPass();
+std::unique_ptr<mlir::Pass> CreateUnswitchLoopsPass();
+std::unique_ptr<mlir::Pass> CreateVectorizeLoadsAndStoresPass();
+
+#define GEN_PASS_REGISTRATION
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_
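For orientation only: the factory functions declared above return plain mlir::Pass instances, so they compose with a standard mlir::PassManager. The ordering below is hypothetical and chosen purely for illustration; the real pipeline is assembled elsewhere in the compiler.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "xla/service/gpu/fusions/transforms/passes.h"

mlir::LogicalResult RunIllustrativePipeline(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Module-anchored passes.
  pm.addPass(xla::gpu::CreateFlattenTensorsPass());
  pm.addPass(xla::gpu::CreateMergePointersToSameSlicePass());
  // Function-anchored passes run nested under the module.
  pm.addNestedPass<mlir::func::FuncOp>(xla::gpu::CreateSimplifyArithPass());
  pm.addNestedPass<mlir::func::FuncOp>(xla::gpu::CreateOptimizeLoopsPass());
  pm.addPass(xla::gpu::CreateLowerToLLVMPass());
  return pm.run(module);
}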
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td
new file mode 100644
index 0000000..af27b36
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td
@@ -0,0 +1,319 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_TD_
+#define XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_TD_
+
+include "mlir/Pass/PassBase.td"
+
+def PropagateSliceIndicesPass :
+   Pass<"xla-gpu-propagate-slice-indices", "mlir::ModuleOp"> {
+  let summary = "Propagates slice indices from the entry function to all callees.";
+
+  let description = [{
+      Propagates xla.slice_index attributes from the function with the xla.entry
+      attribute to all other functions.
+  }];
+
+  let dependentDialects = [
+    "mlir::func::FuncDialect"
+  ];
+
+  let constructor = "CreatePropagateSliceIndicesPass()";
+}
+
+def ConvertPureCallOpsPass
+    : Pass<"xla-gpu-convert-pure-call-ops", "mlir::func::FuncOp"> {
+  let summary = "Converts xla_gpu.pure_call to func.call";
+  let description = [{
+      We use xla_gpu.pure_call ops for calls to enable CSE and other
+      transformations (e.g. LICM). This pass rewrites our custom ops to standard
+      ops.
+  }];
+  let dependentDialects = [
+    "mlir::func::FuncDialect",
+    "xla::gpu::XlaGpuDialect"
+  ];
+  let constructor = "CreateConvertPureCallOpsPass()";
+}
+
+def FlattenTensorsPass : Pass<"xla-gpu-flatten-tensors", "mlir::ModuleOp"> {
+  let summary = "Flattens tensors.";
+
+  let description = [{
+    Linearizes all tensor loads and stores.
+  }];
+
+  let dependentDialects = [
+    "mlir::func::FuncDialect",
+    "mlir::tensor::TensorDialect",
+    "xla::gpu::XlaGpuDialect",
+  ];
+  let constructor = "CreateFlattenTensorsPass()";
+}
+
+def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> {
+  let summary = "Lowers tensors to LLVM pointers and loads/stores.";
+
+  let description = [{
+      Lowers tensors to LLVM. We cannot use the memref lowerings because they
+      are not compatible with XLA's ABI.
+  }];
+
+  let dependentDialects = [
+    "mlir::LLVM::LLVMDialect",
+    "mlir::func::FuncDialect",
+    "mlir::gpu::GPUDialect",
+    "mlir::scf::SCFDialect",
+    "mlir::tensor::TensorDialect",
+    "xla::gpu::XlaGpuDialect",
+  ];
+  let options = [
+    Option<"is_amd_gpu_", "is_amd_gpu", "bool", /*default=*/"false",
+           "True if AMD GPU.">,
+    Option<"gpu_arch_", "gpu_arch", "std::string", /*default=*/"",
+           "CUDA or ROCm compute capability.">,
+  ];
+  let constructor = "CreateLowerTensorsPass()";
+}
+
+def MergePointersToSameSlicePass :
+   Pass<"xla-gpu-merge-pointers", "mlir::ModuleOp"> {
+  let summary = "Merges pointers that share slices.";
+
+  let description = [{
+      When a function has multiple pointer arguments with the same slice index,
+      merges them.
+  }];
+
+  let dependentDialects = [
+    "mlir::func::FuncDialect"
+  ];
+
+  let constructor = "CreateMergePointersToSameSlicePass()";
+}
+
+def SimplifyArithPass : Pass<"xla-gpu-simplify-arith", "mlir::func::FuncOp"> {
+  let summary = "Simplifies arith using XLA's range-aware simplifier.";
+
+  let description = [{
+      We often emit bounds checks that are statically known to be satisfied.
+      This pass removes them.
+  }];
+
+  let dependentDialects = [
+    "mlir::arith::ArithDialect",
+    "mlir::func::FuncDialect",
+  ];
+
+  let constructor = "CreateSimplifyArithPass()";
+}
+
+def SimplifyAffinePass : Pass<"xla-gpu-simplify-affine", "mlir::ModuleOp"> {
+  let summary = "Simplifies affine.apply using XLA's range-aware simplifier.";
+
+  let description = [{
+      The standard affine canonicalizer cannot simplify all expressions, since
+      it is unaware of range information. This pass uses `xla.range` attributes
+      on arguments and ops for simplification. It also lowers floordiv and mod
+      to simpler expressions than lower-affine. This pass only works for
+      expressions for which we can prove the LHS of mod and div is nonnegative.
+  }];
+
+  let dependentDialects = [
+    "mlir::affine::AffineDialect", "mlir::func::FuncDialect",
+    "mlir::scf::SCFDialect",
+  ];
+
+  let constructor = "CreateSimplifyAffinePass()";
+}
+
+def ExpandFloatOpsPass : Pass<"xla-gpu-expand-float-ops", "mlir::ModuleOp"> {
+  let summary = "Expands float ops that are not natively supported.";
+
+  let description = [{
+     Not all float ops are natively supported, either because they don't exist
+     in hardware or they are too inaccurate.
+
+     This pass replaces these ops with alternative implementations.
+  }];
+
+  let dependentDialects = [
+    "mlir::arith::ArithDialect", "mlir::math::MathDialect",
+    "mlir::mhlo::MhloDialect"
+  ];
+
+  let options = [
+    Option<"pre_ampere_", "pre-ampere", "bool", /*default=*/"false",
+           "Rewrite ops that are not supported on architectures before Ampere">,
+  ];
+}
+
+def LowerXlaGpuToScfPass :
+   Pass<"xla-gpu-lower-xla-gpu-to-scf", "mlir::func::FuncOp"> {
+  let summary = "Lowers xla_gpu to SCF.";
+
+  let dependentDialects = [
+    "mlir::gpu::GPUDialect", "mlir::LLVM::LLVMDialect", "mlir::scf::SCFDialect",
+    "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect",
+  ];
+
+  let constructor = "CreateLowerXlaGpuToScfPass()";
+}
+
+def LowerXlaGpuLoopsToScfPass : Pass<
+    "xla-gpu-lower-xla-gpu-loops-to-scf", "mlir::func::FuncOp"> {
+  let summary = "Lowers xla_gpu.loop to SCF.";
+
+  let description = [{
+    This pass is separate from lower-xla-gpu-to-scf because
+    lower-xla-gpu-to-scf, inliner, peeling and lower-xla-gpu-loops-to-scf
+    have to run in that order.
+  }];
+
+  let dependentDialects = [
+    "mlir::scf::SCFDialect",
+    "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect",
+  ];
+
+  let constructor = "CreateLowerXlaGpuLoopsToScfPass()";
+}
+
+def EraseDeadFunctionsPass : Pass<"xla-erase-dead-functions", "mlir::ModuleOp"> {
+  let summary = "Deletes unused functions.";
+
+  let description = [{
+      Deletes functions that are not called.
+  }];
+
+  let dependentDialects = [
+    "mlir::func::FuncDialect",
+    "xla::gpu::XlaGpuDialect"
+  ];
+
+  let constructor = "CreateEraseDeadFunctionsPass()";
+}
+
+def LowerToLLVMPass :
+   Pass<"xla-gpu-lower-to-llvm", "mlir::ModuleOp"> {
+  let summary = "Lowers to LLVM.";
+
+  let description = [{
+    Lowers all remaining ops to the LLVM dialect.
+  }];
+
+  let dependentDialects = [
+    "mlir::func::FuncDialect",
+    "mlir::LLVM::LLVMDialect",
+    "mlir::NVVM::NVVMDialect",
+  ];
+
+  let constructor = "CreateLowerToLLVMPass()";
+}
+
+def VectorizeLoadsAndStoresPass :
+   Pass<"xla-gpu-vectorize-loads-stores", "mlir::func::FuncOp"> {
+  let summary = "Vectorizes loads and stores.";
+
+  let description = [{
+    Rewrites tensor.extract and tensor.insert ops inside loops to their vector
+    equivalents (vector.transfer_read and vector.transfer_write + vector.extract
+    and vector.insert).
+  }];
+
+  let dependentDialects = [
+    "mlir::vector::VectorDialect",
+  ];
+
+  let constructor = "CreateVectorizeLoadsAndStoresPass()";
+}
+
+def PeelLoopsPass : Pass<"xla-gpu-peel-loops", "mlir::func::FuncOp"> {
+  let summary = "Peels xla_gpu.loop.";
+  let description = [{
+      Attempts to split each loop dimension [0, NUM_ITERATIONS)
+      as [0, NUM_ITERATIONS - 1) and [NUM_ITERATIONS - 1, NUM_ITERATIONS)
+      if it removes a constraint.
+  }];
+  let dependentDialects = ["xla::gpu::XlaGpuDialect"];
+  let constructor = "CreatePeelLoopsPass()";
+}
+
+def OptimizeLoopsPass :
+   Pass<"xla-gpu-optimize-loops", "mlir::func::FuncOp"> {
+  let summary = "Unrolls and pipelines loops.";
+
+  let description = [{
+    Unrolls loops with a small trip count. Pipelines loops with a large trip
+    count.
+  }];
+
+  let dependentDialects = [
+    "mlir::vector::VectorDialect",
+    "xla::gpu::XlaGpuDialect",
+  ];
+
+  let constructor = "CreateOptimizeLoopsPass()";
+}
+
+def UnswitchLoopsPass :
+   Pass<"xla-gpu-unswitch-loops", "mlir::func::FuncOp"> {
+  let summary = "Swaps scf.if and scf.for.";
+
+  let description = [{
+      Extracts `scf.if` ops with conditions that are independent of the loop
+      variable from `scf.for` by doing the following rewrite:
+
+      Before:
+
+      %cond = some_cond() : i1
+      %results = scf.for {
+        %some_val = scf.if %cond  {
+        } else {
+        }
+        scf.yield %some_val
+      }
+
+      After:
+
+      %cond = some_cond() : i1
+      %results = scf.if %cond {
+         %results = scf.for {
+            %some_val = scf.if %true  {
+            } else {
+            }
+         }
+         yield %results
+      } else {
+         %results = scf.for {
+            %some_val = scf.if %false  {
+            } else {
+            }
+         }
+         yield %results
+      }
+
+      This only triggers if there is a single `scf.if` op in the loop body (and
+      nothing else).
+  }];
+
+  let dependentDialects = [
+    "mlir::func::FuncDialect", "mlir::scf::SCFDialect"
+  ];
+
+  let constructor = "CreateUnswitchLoopsPass()";
+}
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_TD_
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc
new file mode 100644
index 0000000..a8157a4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc
@@ -0,0 +1,149 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
+#include "xla/service/gpu/model/indexing_map.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+#define GEN_PASS_DEF_PEELLOOPSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+using mlir::Location;
+using mlir::OpBuilder;
+using mlir::OpRewritePattern;
+using mlir::PatternRewriter;
+using mlir::SmallVector;
+using mlir::Value;
+using mlir::ValueRange;
+
+struct PeelLoop : public OpRewritePattern<LoopOp> {
+  using OpRewritePattern<LoopOp>::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      LoopOp loop_op, PatternRewriter& rewriter) const override {
+    int64_t cumulative_loop_size = 1;
+
+    // Compute the list of indexing maps. The last element is the "peeled" or
+    // "main" loop. Everything else is a "tail" loop.
+    auto indexing_map = loop_op.getIndexingMap();
+    // TODO(b/358274367): Remove the simplify call once we have `is_simplified`
+    // field and a canonicalization pattern to simplify indexing map in
+    // xla_gpu.loop.
+    indexing_map.Simplify();
+    SmallVector<IndexingMap> indexing_maps{indexing_map};
+    for (int sym_index = indexing_map.GetSymbolCount() - 1;
+         sym_index >= 0 && cumulative_loop_size < 64; --sym_index) {
+      IndexingMap indexing_map = indexing_maps.back();
+      auto& bound = indexing_map.GetSymbolBound(sym_index);
+      cumulative_loop_size *= bound.GetLoopTripCount();
+      if (!indexing_map.IsSymbolConstrained(sym_index) ||
+          bound.upper == bound.lower) {
+        continue;
+      }
+      // Create peeled indexing map.
+      IndexingMap peeled_map = indexing_map;
+      --peeled_map.GetMutableSymbolBound(sym_index).upper;
+      peeled_map.Simplify();
+
+      // If the symbol is still constrained, peeling does not help.
+      if (peeled_map.IsSymbolConstrained(sym_index)) continue;
+
+      // Create remainder indexing map.
+      IndexingMap tail_map = indexing_map;
+      tail_map.GetMutableSymbolBound(sym_index).lower = bound.upper;
+      tail_map.Simplify();
+
+      VLOG(5) << "Peeled indexing map\n"
+              << indexing_map.ToString() << "into\n"
+              << peeled_map.ToString() << "and\n"
+              << tail_map.ToString() << "\n";
+      indexing_maps.pop_back();
+      indexing_maps.push_back(tail_map);
+      indexing_maps.push_back(peeled_map);
+    }
+
+    if (indexing_maps.size() == 1) {
+      return rewriter.notifyMatchFailure(loop_op,
+                                         "No range variables to peel.");
+    }
+
+    // Create chained loops from the list of indexing maps.
+    Location loc = loop_op.getLoc();
+    SmallVector<Value, 4> inits = loop_op.getInits();
+    for (const auto& indexing_map : llvm::reverse(indexing_maps)) {
+      if (indexing_map.IsKnownEmpty()) continue;
+      auto tail_loop = rewriter.create<LoopOp>(
+          loc, indexing_map, loop_op.getDims(), inits,
+          [&](OpBuilder& nested_b, Location nested_loc, ValueRange ivs,
+              ValueRange iter_args) {
+            OpBuilder::InsertionGuard guard(nested_b);
+            mlir::IRMapping mapping;
+            mapping.map(loop_op.getInductionVars(), ivs);
+            mapping.map(loop_op.getRegionIterArgs(), iter_args);
+            for (auto& op : loop_op.getBody()->getOperations()) {
+              nested_b.clone(op, mapping);
+            }
+          });
+      inits = tail_loop.getResults();
+    }
+    rewriter.replaceOp(loop_op, inits);
+    return mlir::success();
+  }
+};
+
+struct PeelLoopsPass : public impl::PeelLoopsPassBase<PeelLoopsPass> {
+  void runOnOperation() override {
+    auto func = getOperation();
+    mlir::MLIRContext* mlir_context = &getContext();
+    mlir::RewritePatternSet patterns(mlir_context);
+    patterns.add<PeelLoop>(mlir_context);
+    if (mlir::failed(
+            mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
+      signalPassFailure();
+      return;
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<mlir::Pass> CreatePeelLoopsPass() {
+  return std::make_unique<PeelLoopsPass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
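The core step in the pass above is the bound split: the last symbol range [lower, upper] becomes a constraint-free main range [lower, upper - 1] plus a single-iteration tail [upper, upper]. A tiny illustrative sketch of that split with inclusive bounds (not the XLA IndexingMap API):

#include <cstdint>
#include <cstdio>

struct Bound {
  int64_t lower;
  int64_t upper;  // Inclusive, as in the indexing-map bounds above.
};

int main() {
  Bound full{0, 9};                             // 10 iterations.
  Bound main_part{full.lower, full.upper - 1};  // [0, 8]: constraint dropped.
  Bound tail{full.upper, full.upper};           // [9, 9]: keeps the constraint.
  std::printf("main: [%lld, %lld]  tail: [%lld, %lld]\n",
              static_cast<long long>(main_part.lower),
              static_cast<long long>(main_part.upper),
              static_cast<long long>(tail.lower),
              static_cast<long long>(tail.upper));
  return 0;
}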
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/propagate_slice_indices.cc b/third_party/xla/xla/service/gpu/fusions/transforms/propagate_slice_indices.cc
new file mode 100644
index 0000000..31a6379
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/propagate_slice_indices.cc
@@ -0,0 +1,80 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_PROPAGATESLICEINDICESPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+class PropagateSliceIndicesPass
+    : public impl::PropagateSliceIndicesPassBase<PropagateSliceIndicesPass> {
+ public:
+  void runOnOperation() override;
+};
+
+void PropagateSliceIndicesPass::runOnOperation() {
+  mlir::func::FuncOp entry;
+  for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
+    if (func->getAttr("xla.entry")) {
+      entry = func;
+      break;
+    }
+  }
+
+  if (!entry) {
+    getOperation()->emitOpError("No entry function found.");
+    signalPassFailure();
+    return;
+  }
+
+  for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
+    if (func.getNumArguments() == 0 || func == entry) {
+      continue;
+    }
+
+    for (int i = 0; i < func.getNumArguments(); ++i) {
+      if (mlir::isa<mlir::RankedTensorType>(func.getArgument(i).getType())) {
+        if (auto index = entry.getArgAttr(i, "xla.slice_index")) {
+          func.setArgAttr(i, "xla.slice_index", index);
+        }
+        if (auto invariant = entry.getArgAttr(i, "xla.invariant")) {
+          func.setArgAttr(i, "xla.invariant", invariant);
+        }
+      } else {
+        break;
+      }
+    }
+  }
+}
+
+}  // namespace
+
+std::unique_ptr<mlir::Pass> CreatePropagateSliceIndicesPass() {
+  return std::make_unique<PropagateSliceIndicesPass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc
new file mode 100644
index 0000000..acbd9d3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc
@@ -0,0 +1,368 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/base/optimization.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
+#include "xla/service/gpu/model/indexing_map.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using mlir::AffineBinaryOpExpr;
+using mlir::AffineConstantExpr;
+using mlir::AffineDimExpr;
+using mlir::AffineExpr;
+using mlir::AffineExprKind;
+using mlir::AffineMap;
+using mlir::AffineSymbolExpr;
+using mlir::ImplicitLocOpBuilder;
+using mlir::LogicalResult;
+using mlir::MLIRContext;
+using mlir::OpRewritePattern;
+using mlir::PatternRewriter;
+using mlir::SmallVector;
+using mlir::Value;
+using mlir::ValueRange;
+using mlir::affine::AffineApplyOp;
+
+namespace arith = mlir::arith;
+
+#define GEN_PASS_DEF_SIMPLIFYAFFINEPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+int Distance(ImplicitLocOpBuilder& builder, Value a) {
+  auto* block = builder.getInsertionBlock();
+  auto* parent = a.getParentBlock();
+  int distance = 0;
+  while (block && block != parent) {
+    ++distance;
+    block = block->getParentOp()->getBlock();
+  }
+  return distance;
+}
+
+void CollectArgs(AffineExpr expr, AffineExprKind kind,
+                 llvm::SmallVector<AffineExpr>& ret) {
+  if (auto bin_op = mlir::dyn_cast<AffineBinaryOpExpr>(expr)) {
+    if (bin_op.getKind() == kind) {
+      CollectArgs(bin_op.getLHS(), kind, ret);
+      CollectArgs(bin_op.getRHS(), kind, ret);
+      return;
+    }
+  }
+  ret.push_back(expr);
+}
+
+struct ExpressionEvaluator {
+  ExpressionEvaluator(ImplicitLocOpBuilder& builder, unsigned dim_count,
+                      ValueRange operands)
+      : builder(builder), operands(operands) {
+    for (int i = 0; i < dim_count; ++i) {
+      dim_distances.push_back(Distance(builder, operands[i]));
+    }
+    for (int i = dim_count; i < operands.size(); ++i) {
+      sym_distances.push_back(Distance(builder, operands[i]));
+    }
+  }
+
+  // Returns the distance (in basic blocks) from the insertion point to the
+  // values used in the given expression.
+  int ExprDistance(AffineExpr e, int depth = 0) {
+    if (auto dim = mlir::dyn_cast<AffineDimExpr>(e)) {
+      return dim_distances[dim.getPosition()];
+    }
+    if (auto sym = mlir::dyn_cast<AffineSymbolExpr>(e)) {
+      return sym_distances[sym.getPosition()];
+    }
+    if (auto binop = mlir::dyn_cast<AffineBinaryOpExpr>(e)) {
+      return std::min(ExprDistance(binop.getLHS(), depth + 1),
+                      ExprDistance(binop.getRHS(), depth + 1));
+    }
+    if (depth == 0) {
+      // Top-level constant. Always add these last.
+      return std::numeric_limits<int>::min();
+    }
+    // Nested constant. Ignore these for distances.
+    return std::numeric_limits<int>::max();
+  }
+
+  Value EvaluateExpression(AffineExpr expr);
+
+  template <typename Op>
+  Value EvaluateAddMul(AffineExpr expr);
+
+  ImplicitLocOpBuilder& builder;
+  ValueRange operands;
+  SmallVector<int> dim_distances;
+  SmallVector<int> sym_distances;
+};
+
+template <typename Op>
+Value ExpressionEvaluator::EvaluateAddMul(AffineExpr expr) {
+  llvm::SmallVector<AffineExpr> args;
+  CollectArgs(expr, expr.getKind(), args);
+  // Sort the args so that the ones that are closest to the insertion point
+  // are evaluated last - this improves LICM.
+  llvm::stable_sort(args, [&](AffineExpr a, AffineExpr b) {
+    int dist_a = ExprDistance(a);
+    int dist_b = ExprDistance(b);
+    return dist_a > dist_b;
+  });
+
+  Value result = nullptr;
+  for (auto arg : args) {
+    Value arg_evaluated = EvaluateExpression(arg);
+    if (result) {
+      result = builder.create<Op>(result, arg_evaluated);
+    } else {
+      result = arg_evaluated;
+    }
+  }
+
+  return result;
+}
+
+Value ExpressionEvaluator::EvaluateExpression(AffineExpr expr) {
+  if (auto bin_op = mlir::dyn_cast<AffineBinaryOpExpr>(expr)) {
+    switch (expr.getKind()) {
+      case AffineExprKind::Add:
+        return EvaluateAddMul<arith::AddIOp>(expr);
+      case AffineExprKind::Mul:
+        return EvaluateAddMul<arith::MulIOp>(expr);
+      case AffineExprKind::Mod:
+        return builder.create<arith::RemUIOp>(
+            EvaluateExpression(bin_op.getLHS()),
+            EvaluateExpression(bin_op.getRHS()));
+      case AffineExprKind::FloorDiv:
+        return builder.create<arith::DivUIOp>(
+            EvaluateExpression(bin_op.getLHS()),
+            EvaluateExpression(bin_op.getRHS()));
+      default:
+        ABSL_UNREACHABLE();
+    }
+  }
+  switch (expr.getKind()) {
+    case AffineExprKind::Constant:
+      return builder.create<arith::ConstantIndexOp>(
+          mlir::cast<AffineConstantExpr>(expr).getValue());
+    case AffineExprKind::DimId:
+      return operands[mlir::cast<AffineDimExpr>(expr).getPosition()];
+    case AffineExprKind::SymbolId:
+      return operands[dim_distances.size() +
+                      mlir::cast<AffineSymbolExpr>(expr).getPosition()];
+    default:
+      ABSL_UNREACHABLE();
+  }
+}
+
+bool IsLoweringSupported(AffineExpr expr, RangeEvaluator& range_evaluator) {
+  auto bin_op = llvm::dyn_cast<AffineBinaryOpExpr>(expr);
+  if (!bin_op) {
+    return true;
+  }
+  // Mod and div can be lowered if their LHS is >= 0 and their RHS is a
+  // constant.
+  if (expr.getKind() == AffineExprKind::Mod ||
+      expr.getKind() == AffineExprKind::FloorDiv) {
+    if (!range_evaluator.IsAlwaysPositiveOrZero(bin_op.getLHS()) ||
+        !range_evaluator.ComputeExpressionRange(bin_op.getRHS()).IsPoint()) {
+      return false;
+    }
+  }
+  if (expr.getKind() == AffineExprKind::CeilDiv) {
+    return false;
+  }
+  return IsLoweringSupported(bin_op.getLHS(), range_evaluator) &&
+         IsLoweringSupported(bin_op.getRHS(), range_evaluator);
+}
+
+struct RewriteAffineApply : OpRewritePattern<mlir::affine::AffineApplyOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::affine::AffineApplyOp op,
+                                PatternRewriter& rewriter) const override {
+    AffineMap affine_map = op.getAffineMap();
+    std::vector<DimVar> dim_ranges(affine_map.getNumDims());
+    std::vector<RangeVar> symbol_ranges(affine_map.getNumSymbols());
+
+    for (int i = 0; i < affine_map.getNumInputs(); ++i) {
+      if (auto range = GetRange(op->getOperand(i))) {
+        if (i >= dim_ranges.size()) {
+          symbol_ranges[i - dim_ranges.size()] = RangeVar{*range};
+        } else {
+          dim_ranges[i] = DimVar{*range};
+        }
+      } else {
+        return rewriter.notifyMatchFailure(op, "failed to deduce range");
+      }
+    }
+
+    IndexingMap indexing_map(affine_map, std::move(dim_ranges),
+                             std::move(symbol_ranges),
+                             /*rt_vars=*/{});
+    indexing_map.Simplify();
+    auto result_expr = indexing_map.GetAffineMap().getResult(0);
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    RangeEvaluator range_evaluator = indexing_map.GetRangeEvaluator();
+    if (!IsLoweringSupported(result_expr, range_evaluator)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "unable to lower the affine apply");
+    }
+    b.setInsertionPoint(op);
+    auto result = ExpressionEvaluator(b, indexing_map.GetDimensionCount(),
+                                      op->getOperands())
+                      .EvaluateExpression(result_expr);
+    rewriter.replaceOp(op, result);
+    return mlir::success();
+  }
+};
+
+struct RewriteApplyIndexingOp : OpRewritePattern<ApplyIndexingOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ApplyIndexingOp op,
+                                PatternRewriter& rewriter) const override {
+    auto indexing_map = op.getIndexingMap();
+    indexing_map.Simplify();
+    auto affine_map = indexing_map.GetAffineMap();
+    int64_t dim_count = indexing_map.GetDimensionCount();
+    auto operands = op->getOperands();
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    RangeEvaluator range_evaluator = indexing_map.GetRangeEvaluator();
+
+    b.setInsertionPoint(op);
+    SmallVector<Value, 4> results;
+    results.reserve(affine_map.getNumResults());
+    for (unsigned i = 0; i < affine_map.getNumResults(); ++i) {
+      AffineExpr result_expr = affine_map.getResult(i);
+      // If the expression cannot be lowered, we convert it to affine.apply,
+      // since it supports more expression types.
+      if (IsLoweringSupported(result_expr, range_evaluator)) {
+        results.push_back(ExpressionEvaluator(b, dim_count, operands)
+                              .EvaluateExpression(result_expr));
+      } else {
+        results.push_back(
+            b.create<AffineApplyOp>(affine_map.getSubMap({i}), operands));
+      }
+    }
+    rewriter.replaceOp(op, results);
+    return mlir::success();
+  }
+};
+
+struct SimplifyAffinePass
+    : public impl::SimplifyAffinePassBase<SimplifyAffinePass> {
+ public:
+  void runOnOperation() override {
+    MLIRContext* ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
+    patterns.add<RewriteAffineApply, RewriteApplyIndexingOp>(ctx);
+    mlir::GreedyRewriteConfig config;
+    // There's no point simplifying more than once.
+    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            getOperation(), std::move(patterns), config))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::optional<Interval> GetRange(mlir::Value value) {
+  auto attr_to_range = [](mlir::Attribute attr) -> std::optional<Interval> {
+    if (!attr) {
+      return std::nullopt;
+    }
+    auto values = llvm::to_vector(
+        mlir::cast<mlir::ArrayAttr>(attr).getAsValueRange<mlir::IntegerAttr>());
+    return {{values[0].getSExtValue(), values[1].getSExtValue()}};
+  };
+
+  if (auto apply = value.getDefiningOp<ApplyIndexingOp>()) {
+    return apply.getIndexingMap().GetRangeEvaluator().ComputeExpressionRange(
+        apply.getIndexingMap().GetAffineMap().getResult(
+            mlir::cast<mlir::OpResult>(value).getResultNumber()));
+  } else if (auto cst = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) {
+    return {{cst.value(), cst.value()}};
+  } else if (value.getDefiningOp()) {
+    return attr_to_range(value.getDefiningOp()->getAttr("xla.range"));
+  }
+
+  auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(value);
+  if (!bbarg) {
+    return std::nullopt;
+  }
+
+  auto parent = bbarg.getParentBlock()->getParentOp();
+  if (auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(parent)) {
+    return attr_to_range(func_op.getArgAttr(bbarg.getArgNumber(), "xla.range"));
+  }
+  return GetIVRange(value);
+}
+
+std::optional<Interval> GetIVRange(mlir::Value iv) {
+  auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(iv);
+  if (!bbarg) {
+    return std::nullopt;
+  }
+  auto parent = bbarg.getParentBlock()->getParentOp();
+  if (auto for_op = mlir::dyn_cast<mlir::scf::ForOp>(parent)) {
+    llvm::APInt lb, ub;
+    if (mlir::matchPattern(for_op.getLowerBound(), mlir::m_ConstantInt(&lb)) &&
+        mlir::matchPattern(for_op.getUpperBound(), mlir::m_ConstantInt(&ub))) {
+      return {{lb.getSExtValue(), ub.getSExtValue() - 1}};
+    }
+  }
+  return std::nullopt;
+}
+
+std::unique_ptr<mlir::Pass> CreateSimplifyAffinePass() {
+  return std::make_unique<SimplifyAffinePass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
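A note on the distance-based operand sort in ExpressionEvaluator::EvaluateAddMul above: combining the values defined furthest from the insertion point first yields partial results that are invariant in the inner loops and can be hoisted. A plain C++ analogy of that effect (illustrative only, unrelated to the MLIR code):

#include <cstdio>

int main() {
  const int a = 3;  // "Far" value: defined outside both loops.
  int total = 0;
  for (int i = 0; i < 4; ++i) {
    const int b = i * 10;       // Middle-distance value.
    const int partial = a + b;  // Invariant w.r.t. the inner loop: hoisted.
    for (int j = 0; j < 4; ++j) {
      total += partial + j;     // Only the closest value is added inside.
    }
  }
  std::printf("%d\n", total);
  return 0;
}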
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc
new file mode 100644
index 0000000..bbf4831
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc
@@ -0,0 +1,344 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
+#include "xla/service/gpu/model/indexing_map.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+#define GEN_PASS_DEF_SIMPLIFYARITHPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+using mlir::LogicalResult;
+using mlir::OpRewritePattern;
+using mlir::PatternRewriter;
+using mlir::arith::CmpIOp;
+using mlir::arith::CmpIPredicate;
+
+Interval::ComparisonResult EvaluateCmpI(CmpIPredicate pred, Interval lhs,
+                                        Interval rhs) {
+  switch (pred) {
+    case CmpIPredicate::eq:
+      return lhs.Eq(rhs);
+    case CmpIPredicate::ne:
+      return lhs.Ne(rhs);
+    case CmpIPredicate::slt:
+    case CmpIPredicate::ult:
+      return lhs.Lt(rhs);
+    case CmpIPredicate::sle:
+    case CmpIPredicate::ule:
+      return lhs.Le(rhs);
+    case CmpIPredicate::sgt:
+    case CmpIPredicate::ugt:
+      return lhs.Gt(rhs);
+    case CmpIPredicate::sge:
+    case CmpIPredicate::uge:
+      return lhs.Ge(rhs);
+  }
+}
+
+struct RewriteCmpI : OpRewritePattern<CmpIOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CmpIOp op,
+                                PatternRewriter& rewriter) const override {
+    auto rhs = GetRange(op.getRhs());
+    auto lhs = GetRange(op.getLhs());
+    if (!lhs || !rhs) {
+      return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
+    }
+    Interval::ComparisonResult result =
+        EvaluateCmpI(op.getPredicate(), *lhs, *rhs);
+    if (result != std::nullopt) {
+      rewriter.replaceOpWithNewOp<mlir::arith::ConstantIntOp>(
+          op, *result, rewriter.getI1Type());
+      return mlir::success();
+    }
+    return rewriter.notifyMatchFailure(op, "not a constant result");
+  }
+};
+
+struct RewriteMaxSi : OpRewritePattern<mlir::arith::MaxSIOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::arith::MaxSIOp op,
+                                PatternRewriter& rewriter) const override {
+    auto lhs = GetRange(op.getLhs());
+    auto rhs = GetRange(op.getRhs());
+    if (!lhs || !rhs) {
+      return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
+    }
+    if (auto lhs_ge_rhs = lhs->Ge(*rhs); lhs_ge_rhs == true) {
+      rewriter.replaceOp(op, op.getLhs());
+    } else if (auto rhs_ge_lhs = rhs->Ge(*lhs); rhs_ge_lhs == true) {
+      rewriter.replaceOp(op, op.getRhs());
+    } else {
+      return rewriter.notifyMatchFailure(op, "not equal to lhs or rhs");
+    }
+    return mlir::success();
+  }
+};
+
+struct RewriteMinSi : OpRewritePattern<mlir::arith::MinSIOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::arith::MinSIOp op,
+                                PatternRewriter& rewriter) const override {
+    auto lhs = GetRange(op.getLhs());
+    auto rhs = GetRange(op.getRhs());
+    if (!lhs || !rhs) {
+      return rewriter.notifyMatchFailure(op, "failed to deduce input ranges");
+    }
+    if (auto lhs_le_rhs = lhs->Le(*rhs); lhs_le_rhs == true) {
+      rewriter.replaceOp(op, op.getLhs());
+    } else if (auto rhs_le_lhs = rhs->Le(*lhs); rhs_le_lhs == true) {
+      rewriter.replaceOp(op, op.getRhs());
+    } else {
+      return rewriter.notifyMatchFailure(op, "not equal to lhs or rhs");
+    }
+    return mlir::success();
+  }
+};
+
+// Finds the narrowest value in a use-def chain of truncis/extuis.
+mlir::Value FindNarrowestValueInChain(mlir::Value value) {
+  if (auto ext = value.getDefiningOp<mlir::arith::ExtUIOp>()) {
+    return FindNarrowestValueInChain(ext.getOperand());
+  }
+  auto defining_op = value.getDefiningOp<mlir::arith::TruncIOp>();
+  if (defining_op) {
+    auto first_trunc = FindNarrowestValueInChain(defining_op.getOperand());
+    if (first_trunc && first_trunc.getType().getIntOrFloatBitWidth() <=
+                           defining_op.getType().getIntOrFloatBitWidth()) {
+      return first_trunc;
+    }
+    return defining_op;
+  }
+  return value;
+}
+
+// Rewrites trunc-bitwise to bitwise-trunc.
+//
+// For pred reductions, we generate code like this:
+//
+//   %1 = arith.trunci %0 : i32 to i1
+//   %2 = arith.ori %1, %x
+//   %3 = arith.extui %2 : i1 to i32
+//   %4 = gpu.shuffle %3
+//
+// By swapping the trunc with the or, we get a trunc-ext-shuffle sequence, which
+// can be rewritten to shuffle-trunc-ext. If there is another copy of the
+// pattern afterwards, we can push the truncs/exts further down.
+template <typename Op>
+struct RewriteTruncBitExt : OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter& rewriter) const override {
+    mlir::Value lhs = FindNarrowestValueInChain(op.getLhs());
+    mlir::Value rhs = FindNarrowestValueInChain(op.getRhs());
+
+    if (lhs.getType() != rhs.getType()) {
+      return rewriter.notifyMatchFailure(op, "mismatched narrowest types");
+    }
+
+    auto trunci_lhs = lhs.getDefiningOp<mlir::arith::TruncIOp>();
+    auto trunci_rhs = rhs.getDefiningOp<mlir::arith::TruncIOp>();
+    if (!trunci_lhs && !trunci_rhs) {
+      return rewriter.notifyMatchFailure(
+          op, "neither narrowest value is the result of a truncation");
+    }
+
+    auto wide_type =
+        (trunci_lhs ? trunci_lhs : trunci_rhs).getOperand().getType();
+    if (trunci_rhs && trunci_rhs.getOperand().getType() != wide_type) {
+      return rewriter.notifyMatchFailure(op, "mismatched truncation types");
+    }
+
+    mlir::Value new_lhs = trunci_lhs ? trunci_lhs.getOperand()
+                                     : rewriter.create<mlir::arith::ExtUIOp>(
+                                           op.getLoc(), wide_type, lhs);
+    mlir::Value new_rhs = trunci_rhs ? trunci_rhs.getOperand()
+                                     : rewriter.create<mlir::arith::ExtUIOp>(
+                                           op.getLoc(), wide_type, rhs);
+    mlir::Value new_op = rewriter.create<Op>(op.getLoc(), new_lhs, new_rhs);
+    rewriter.replaceOpWithNewOp<mlir::arith::TruncIOp>(op, op.getType(),
+                                                       new_op);
+
+    return mlir::success();
+  }
+};
+
+// Rewrites trunc-ext-shuffle to shuffle-trunc-ext. This pattern is designed to
+// work together with RewriteTruncBitExt to optimize pred reductions.
+struct RewriteTruncExtShuffle : public OpRewritePattern<mlir::gpu::ShuffleOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::gpu::ShuffleOp op,
+                                PatternRewriter& rewriter) const override {
+    auto ext = op.getOperand(0).getDefiningOp<mlir::arith::ExtUIOp>();
+    if (!ext) {
+      return rewriter.notifyMatchFailure(op, "no ext");
+    }
+    auto trunc = ext.getOperand().getDefiningOp<mlir::arith::TruncIOp>();
+    if (!trunc || trunc.getOperand().getType() != ext.getType()) {
+      return rewriter.notifyMatchFailure(op, "no trunc or type mismatch");
+    }
+    rewriter.setInsertionPointAfter(op);
+    auto new_trunc = rewriter.create<mlir::arith::TruncIOp>(
+        op.getLoc(), trunc.getType(), op.getResult(0));
+    auto new_ext = rewriter.create<mlir::arith::ExtUIOp>(
+        op.getLoc(), ext.getType(), new_trunc.getResult());
+    rewriter.modifyOpInPlace(op,
+                             [&]() { op->setOperand(0, trunc.getOperand()); });
+    rewriter.replaceAllUsesExcept(op.getResult(0), new_ext, new_trunc);
+    return mlir::success();
+  }
+};
+
+void AnnotateRanges(mlir::func::FuncOp func) {
+  func->walk([](mlir::Operation* op) {
+    if (op->getNumResults() != 1) {
+      return;
+    }
+
+    auto result = op->getResult(0);
+    if (GetRange(result).has_value()) {
+      return;
+    }
+
+    auto get_range = [](mlir::Value value) -> Interval {
+      auto range = GetRange(value);
+      if (range) {
+        return *range;
+      }
+      return {std::numeric_limits<int64_t>::min(),
+              std::numeric_limits<int64_t>::max()};
+    };
+
+    std::optional<Interval> out_range = std::nullopt;
+    if (mlir::isa<mlir::arith::MaxSIOp, mlir::arith::MinSIOp,
+                  mlir::arith::AddIOp, mlir::arith::MulIOp>(op)) {
+      auto lhs_range = get_range(op->getOperand(0));
+      auto rhs_range = get_range(op->getOperand(1));
+      if (mlir::isa<mlir::arith::MaxSIOp>(op)) {
+        out_range = lhs_range.max(rhs_range);
+      } else if (mlir::isa<mlir::arith::MinSIOp>(op)) {
+        out_range = lhs_range.min(rhs_range);
+      } else if (mlir::isa<mlir::arith::AddIOp>(op)) {
+        out_range = lhs_range + rhs_range;
+      } else {
+        out_range = lhs_range * rhs_range;
+      }
+    }
+
+    if (out_range) {
+      mlir::OpBuilder b(op);
+      op->setAttr("xla.range",
+                  b.getIndexArrayAttr({out_range->lower, out_range->upper}));
+    }
+  });
+}
+
+// Pattern to refine the bounds of an indexing map if some of its operands
+// have known bounds, e.g. loop induction variables.
+struct RefineConstraints : public OpRewritePattern<ApplyIndexingOp> {
+  using OpRewritePattern<ApplyIndexingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
+                                PatternRewriter& rewriter) const override {
+    // Right now, we only handle loop induction variables, but other rules might
+    // be added.
+    IndexingMap indexing_map = indexing_op.getIndexingMap();
+    int64_t dim_count = indexing_map.GetDimensionCount();
+    bool updated_bounds = false;
+    for (mlir::OpOperand& operand : indexing_op->getOpOperands()) {
+      auto range = GetIVRange(operand.get());
+      if (!range) {
+        continue;
+      }
+      auto operand_id = operand.getOperandNumber();
+      Interval& current_interval =
+          operand_id < dim_count
+              ? indexing_map.GetMutableDimensionBound(operand_id)
+              : indexing_map.GetMutableSymbolBound(operand_id - dim_count);
+      if (!range->Contains(current_interval)) {
+        current_interval = current_interval.Intersect(*range);
+        updated_bounds = true;
+      }
+    }
+    if (!updated_bounds) {
+      return rewriter.notifyMatchFailure(indexing_op, "No bounds to refine");
+    }
+    indexing_map.Simplify();
+    rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
+        indexing_op, indexing_op.getOperands(), indexing_map);
+    return mlir::success();
+  }
+};
+
+class SimplifyArithPass
+    : public impl::SimplifyArithPassBase<SimplifyArithPass> {
+ public:
+  void runOnOperation() override {
+    auto ctx = &getContext();
+    auto func = getOperation();
+    mlir::RewritePatternSet patterns(ctx);
+    AnnotateRanges(func);
+    // clang-format off
+    patterns.add<
+      RefineConstraints,
+      RewriteCmpI,
+      RewriteMaxSi,
+      RewriteMinSi,
+      RewriteTruncBitExt<mlir::arith::AndIOp>,
+      RewriteTruncBitExt<mlir::arith::OrIOp>,
+      RewriteTruncExtShuffle
+    >(ctx);
+    // clang-format on
+    if (mlir::failed(
+            mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<mlir::Pass> CreateSimplifyArithPass() {
+  return std::make_unique<SimplifyArithPass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD
new file mode 100644
index 0000000..381d5a3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD
@@ -0,0 +1,16 @@
+load("//xla:lit.bzl", "lit_test_suite")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    licenses = ["notice"],
+)
+
+lit_test_suite(
+    name = "tests",
+    srcs = glob(["*.mlir"]),
+    cfg = "//xla:lit.cfg.py",
+    tools = [
+        "//xla/service/gpu/fusions/tools:mlir_fusions_opt",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/convert_xla_gpu_pure_calls.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_xla_gpu_pure_calls.mlir
similarity index 100%
rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/convert_xla_gpu_pure_calls.mlir
rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_xla_gpu_pure_calls.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir
similarity index 100%
rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir
rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir
new file mode 100644
index 0000000..21a8dc2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir
@@ -0,0 +1,146 @@
+// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-flatten-tensors \
+// RUN: --verify-diagnostics | FileCheck %s
+
+func.func @tensor_extract(
+    %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>,
+    %arg1: index, %arg2: index) -> f32 {
+  %v = tensor.extract %arg0[%arg1, %arg2]
+      : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>
+  func.return %v : f32
+}
+// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0),
+// CHECK-SAME:                    domain: d0 in [0, 1], d1 in [0, 2]>
+
+// CHECK-LABEL: func.func @tensor_extract(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<6xf32>,
+// CHECK-SAME:      %[[I:.*]]: index, %[[J:.*]]: index) -> f32 {
+// CHECK:        %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]], %[[J]])
+// CHECK:        tensor.extract %[[SRC]][%[[INDEX]]] : tensor<6xf32>
+
+// -----
+
+func.func @tensor_insert(
+    %arg0: tensor<10x24xcomplex<f32>>) -> tensor<10x24xcomplex<f32>> {
+  %c1 = arith.constant 1 : index
+  %real = arith.constant 3.0 : f32
+  %imag = arith.constant 2.0 : f32
+  %complex = complex.create %real, %imag : complex<f32>
+  %out = tensor.insert %complex into %arg0[%c1, %c1] : tensor<10x24xcomplex<f32>>
+  func.return %out : tensor<10x24xcomplex<f32>>
+}
+// CHECK-LABEL: func.func @tensor_insert(
+// CHECK-SAME:      %[[TENSOR:.*]]: tensor<240xcomplex<f32>>) -> tensor<240xcomplex<f32>> {
+// CHECK:         %[[INDEX:.*]] = arith.constant 25
+// CHECK:         %[[COMPLEX:.*]] = complex.create
+// CHECK:         tensor.insert %[[COMPLEX]] into %[[TENSOR]][%[[INDEX]]]
+// CHECK-SAME:      : tensor<240xcomplex<f32>>
+
+// -----
+
+func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index)
+    -> (tensor<2x4xf32>) {
+  %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
+    ^bb0(%current : f32):
+      %c42 = arith.constant 1.0 : f32
+      %add = arith.minimumf %current, %c42 : f32
+      xla_gpu.yield %add : f32
+  }
+  return %ret : tensor<2x4xf32>
+}
+// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 4 + d1),
+// CHECK-SAME:                    domain: d0 in [0, 1], d1 in [0, 3]>
+
+// CHECK-LABEL: func.func @atomic_rmw(
+// CHECK-SAME:      %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index,
+// CHECK-SAME:      %[[J:.*]]: index) -> tensor<8xf32> {
+// CHECK:         %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]]
+// CHECK-SAME:      (%[[I]], %[[J]])
+// CHECK:         xla_gpu.atomic_rmw %[[TENSOR]][%[[INDEX]]] : tensor<8xf32>
+
+// -----
+
+func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>)
+    -> (tensor<32x1024xf32>, tensor<64x8x4xf32>, f32) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %for:2 = scf.for %i = %c0 to %c64 step %c32 iter_args(%t0_ = %t0, %t1_ = %t1)
+    -> (tensor<32x1024xf32>, tensor<64x8x4xf32>) {
+    %update0 = tensor.insert %c0_f32 into %t0_[%c1, %i] : tensor<32x1024xf32>
+    %update1 = tensor.insert %c0_f32 into %t1_[%i, %c1, %c1] : tensor<64x8x4xf32>
+    scf.yield %update0, %update1 : tensor<32x1024xf32>, tensor<64x8x4xf32>
+  } {some_attr}
+  return %for#0, %for#1, %c0_f32 : tensor<32x1024xf32>, tensor<64x8x4xf32>, f32
+}
+
+// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1024),
+// CHECK-SAME:                     domain: d0 in [0, 1023]>
+// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 * 32 + 5),
+// CHECK-SAME:                     domain: d0 in [0, 63]>
+// CHECK-LABEL: func.func @for_loop(
+// CHECK-SAME:      %[[T0:.*]]: tensor<32768xf32>,
+// CHECK-SAME:      %[[T1:.*]]: tensor<2048xf32>) -> (tensor<32768xf32>, tensor<2048xf32>, f32) {
+
+// CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:  %[[C32:.*]] = arith.constant 32 : index
+// CHECK-DAG:  %[[C64:.*]] = arith.constant 64 : index
+// CHECK-DAG:  %[[F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:      %[[FOR:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[C64]]
+// CHECK-SAME:     step %[[C32]]
+// CHECK-SAME:     iter_args(%[[T0_:.*]] = %[[T0]], %[[T1_:.*]] = %[[T1]])
+// CHECK:        %[[IND0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[I]])
+// CHECK:        %[[UPD0:.*]] = tensor.insert %[[F32]] into %[[T0_]][%[[IND0]]]
+// CHECK:        %[[IND1:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]])
+// CHECK:        %[[UPD1:.*]] = tensor.insert %[[F32]] into %[[T1_]][%[[IND1]]]
+// CHECK:        scf.yield %[[UPD0]], %[[UPD1]] : tensor<32768xf32>, tensor<2048xf32>
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36),
+                              domain: d0 in [0, 127], d1 in [0, 393749]>
+#map1 = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4),
+                              domain: d0 in [0, 127], d1 in [0, 393749]>
+#map2 = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) mod 9),
+                              domain: d0 in [0, 127], d1 in [0, 393749]>
+func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>,
+    %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>)
+     -> tensor<4000x4x9xf32> {
+  %c0 = arith.constant 0 : index
+  %c3999 = arith.constant 3999 : index
+  %th_x = gpu.thread_id  x {xla.range = [0 : index, 127 : index]}
+  %bl_x = gpu.block_id  x {xla.range = [0 : index, 393749 : index]}
+  %0 = xla_gpu.apply_indexing #map(%th_x, %bl_x)
+  %extracted = tensor.extract %arg1[%0, %c0] : tensor<1400x1xi32>
+  %1 = arith.index_cast %extracted : i32 to index
+  %2 = arith.cmpi ule, %1, %c3999 : index
+  %3 = scf.if %2 -> (tensor<4000x4x9xf32>) {
+    %4 = xla_gpu.apply_indexing #map1(%th_x, %bl_x)
+    %5 = xla_gpu.apply_indexing #map2(%th_x, %bl_x)
+    %elem = tensor.extract %arg2[%0, %c0, %4, %5] : tensor<1400x1x4x9xf32>
+    %atomic_rmw = xla_gpu.atomic_rmw %arg3[%1, %4, %5] : tensor<4000x4x9xf32> {
+    ^bb0(%arg4: f32):
+      %6 = arith.addf %arg4, %elem : f32
+      xla_gpu.yield %6 : f32
+    }
+    scf.yield %atomic_rmw : tensor<4000x4x9xf32>
+  } else {
+    scf.yield %arg3 : tensor<4000x4x9xf32>
+  }
+  return %3 : tensor<4000x4x9xf32>
+}
+// CHECK-LABEL: func.func @if_op
+// CHECK-NOT:     builtin.unrealized_conversion_cast
+// CHECK:         scf.if %{{.*}} -> (tensor<144000xf32>) {
+// CHECK-COUNT-2:   scf.yield %{{.*}} : tensor<144000xf32>
+// CHECK:         return %{{.*}} : tensor<144000xf32>
+
+// -----
+
+func.func @dangling_cast(%arg0: tensor<6xf32>, %arg1: index) -> i32 {
+  %v = tensor.extract %arg0[%arg1] : tensor<6xf32>
+  %cast = builtin.unrealized_conversion_cast %v : f32 to i32
+  func.return %cast : i32
+}
+// CHECK: FlattenTensorsPass failed to converge
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/inlining.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir
similarity index 100%
copy from third_party/xla/xla/service/gpu/fusions/mlir/tests/inlining.mlir
copy to third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir
new file mode 100644
index 0000000..be8eb1e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir
@@ -0,0 +1,854 @@
+// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
+// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=6.0" \
+// RUN: | FileCheck %s
+
+// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
+// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=7.0" \
+// RUN: | FileCheck %s --check-prefix=CHECK-VOLTA
+
+// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
+// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=8.0" \
+// RUN: | FileCheck %s --check-prefix=CHECK-AMPERE
+
+// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
+// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=9.0" \
+// RUN: | FileCheck %s --check-prefix=CHECK-HOPPER
+
+// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
+// RUN: -xla-gpu-lower-tensors="is_amd_gpu=true gpu_arch=gfx908:sramecc+:xnack" \
+// RUN: | FileCheck %s --check-prefix=CHECK-GFX908-MI100
+
+// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
+// RUN: -xla-gpu-lower-tensors="is_amd_gpu=true gpu_arch=gfx90a:sramecc+:xnack" \
+// RUN: | FileCheck %s --check-prefix=CHECK-GFX90A-MI200
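+
+// The RUN lines above lower the same input for several targets; the
+// architecture-specific prefixes (CHECK-VOLTA, CHECK-AMPERE, CHECK-HOPPER, and
+// the two AMD prefixes) are only used where the lowering differs, e.g. for the
+// direct atomicrmw fadd cases below.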
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>} {
+  func.func private @add(%arg0: f32, %arg1: f32) -> f32 {
+    %sum = arith.addf %arg0, %arg1 : f32
+    func.return %sum : f32
+  }
+
+  func.func private @tensorarg(%arg0: tensor<43xf32> {xla.invariant, xla.slice_index = 0}, %arg1: index) -> f32 {
+    %v1 = arith.constant 2.0 : f32
+    %v2 = tensor.extract %arg0[%arg1] : tensor<43xf32>
+    %sum = func.call @add(%v1, %v2) : (f32, f32) -> f32
+    func.return %sum : f32
+  }
+
+  func.func @tensorcall(%arg0: tensor<43xf32> {xla.slice_index = 0}, %arg1: index) -> f32 {
+    %call = func.call @tensorarg(%arg0, %arg1) : (tensor<43xf32>, index) -> f32
+    func.return %call : f32
+  }
+
+  func.func @stores(%arg0: tensor<17xf32> {xla.slice_index = 0}, %arg1: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> {
+    %c17 = arith.constant 17 : index
+    %c23 = arith.constant 23 : index
+    %cst = arith.constant 3.0 : f32
+    %out = tensor.insert %cst into %arg1[%c17] : tensor<43xf32>
+    %out2 = tensor.insert %cst into %out[%c23] : tensor<43xf32>
+    func.return %out2 : tensor<43xf32>
+  }
+}
+
+// CHECK:        func.func private @add(%{{.*}}: f32, %{{.*}}: f32) -> f32 {
+// CHECK-NEXT:     arith.addf
+// CHECK-NEXT:     return
+
+// CHECK:        func.func private @tensorarg(%[[ARG0:.*]]: !llvm.ptr
+// CHECK-SAME:        {xla.invariant, xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index) -> f32 {
+// CHECK-DAG:       %[[C2:.*]] = arith.constant 2.000000e+00
+// CHECK-DAG:       %[[IDX:.*]] = arith.index_castui %[[ARG1]] : index to i32
+// CHECK-DAG:       %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX]]]
+// CHECK-DAG:       %[[V2:.*]] = llvm.load %[[PTR]] invariant
+// CHECK:           %[[RET:.*]] = call @add(%[[C2]], %[[V2]])
+// CHECK:           return %[[RET]]
+
+// CHECK:        func.func @tensorcall(%[[ARG0:.*]]: !llvm.ptr
+// CHECK-SAME:        {xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index)
+// CHECK:           %[[RET:.*]] = call @tensorarg(%[[ARG0]], %[[ARG1]])
+// CHECK:           return %[[RET]]
+
+// CHECK:        func.func @stores(
+// CHECK-SAME:        %[[ARG0:.*]]: !llvm.ptr {xla.slice_index = 0 : i64},
+// CHECK-SAME:        %[[ARG1:.*]]: !llvm.ptr {xla.slice_index = 1 : i64})
+// CHECK-NEXT:      %[[CST:.*]] = arith.constant 3.000000e+00 : f32
+// CHECK-NEXT:      %[[PTR1:.*]] = llvm.getelementptr inbounds %[[ARG1]][17]
+// CHECK-NEXT:      llvm.store %[[CST]], %[[PTR1]]
+// CHECK-NEXT:      %[[PTR2:.*]] = llvm.getelementptr inbounds %[[ARG1]][23]
+// CHECK-NEXT:      llvm.store %[[CST]], %[[PTR2]]
+// CHECK-NEXT:      return
+
+// -----
+
+module {
+  func.func @layout(
+      %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>,
+      %arg1: index, %arg2: index) -> f32 {
+    %v = tensor.extract %arg0[%arg1, %arg2]
+        : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>
+    func.return %v : f32
+  }
+}
+
+// CHECK:        #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0),
+// CHECK-SAME:                           domain: d0 in [0, 1], d1 in [0, 2]>
+// CHECK-LABEL:  @layout(
+// CHECK-SAME:      %[[ARG0:.*]]: !llvm.ptr,
+// CHECK-SAME:      %[[X:.*]]: index, %[[Y:.*]]: index
+// CHECK:        %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[X]], %[[Y]])
+// CHECK:        %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i64
+// CHECK:        %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]]
+// CHECK:        llvm.load %[[PTR]]
+
+// -----
+
+module {
+  func.func @store_control_flow(
+     %arg0: tensor<2xf32>,
+     %arg1: index
+  ) -> tensor<2xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %cst = arith.constant 0.0 : f32
+    %cst2 = arith.constant 1.0 : f32
+
+    %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> {
+      %new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32>
+      scf.yield %new_out : tensor<2xf32>
+    }
+
+    %inbounds = arith.cmpi sle, %arg1, %c1 : index
+    %result = scf.if %inbounds -> tensor<2xf32> {
+      %if = tensor.insert %cst2 into %for[%arg1] : tensor<2xf32>
+      scf.yield %if : tensor<2xf32>
+    } else {
+      scf.yield %for : tensor<2xf32>
+    }
+    func.return %result : tensor<2xf32>
+  }
+}
+
+// CHECK:     @store_control_flow(%[[ARG0:.*]]: !llvm.ptr, %[[X:.*]]: index) {
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:       scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK:         %[[CAST:.*]] = arith.index_castui %[[I]] : index to i64
+// CHECK:         %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[CAST]]]
+// CHECK:         llvm.store {{.*}}, %[[PTR]]
+// CHECK:       %[[INBOUNDS:.*]] = arith.cmpi
+// CHECK:       scf.if %[[INBOUNDS]] {
+// CHECK:         llvm.store
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return
+
+// -----
+
+module {
+  func.func @large_tensor(
+      %arg0: tensor<1024x1024x1024x6xf32>,
+      %arg1: index) -> f32 {
+    %v = tensor.extract %arg0[%arg1, %arg1, %arg1, %arg1] : tensor<1024x1024x1024x6xf32>
+    func.return %v : f32
+  }
+}
+
+// CHECK: @large_tensor
+// CHECK: arith.index_castui {{.*}} : index to i64
+
+// -----
+
+module {
+  func.func @extract_from_constant(%arg0: tensor<2x1xf32>,
+      %arg1: index, %arg2: index) -> f32 {
+    %cst = arith.constant dense<[[1.000000e+00], [2.000000e+00]]> : tensor<2x1xf32>
+    %extracted = tensor.extract %arg0[%arg1, %arg2] : tensor<2x1xf32>
+    %extracted_0 = tensor.extract %cst[%arg1, %arg2] : tensor<2x1xf32>
+    %0 = arith.addf %extracted, %extracted_0 : f32
+    return %0 : f32
+  }
+}
+// CHECK: llvm.mlir.global private constant @global_cst_0(dense<
+// CHECK-SAME: [1.000000e+00, 2.000000e+00]> : tensor<2xf32>) {addr_space = 0 : i32} : !llvm.array<2 x f32>
+// CHECK: @extract_from_constant
+// CHECK: %[[ADDR_OF:.*]] = llvm.mlir.addressof @global_cst_0 : !llvm.ptr
+// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ADDR_OF]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> f32
+// CHECK: %[[ADD:.*]] = arith.addf %{{.*}}, %[[LOAD]] : f32
+// CHECK: return %[[ADD]] : f32
+
+// -----
+
+module {
+  func.func @vector_constant() -> vector<2xindex> {
+    %c1 = arith.constant dense<[1, 2]> : vector<2xindex>
+    func.return %c1 : vector<2xindex>
+  }
+}
+
+// Vector constants should not be rewritten.
+// CHECK: @vector_constant
+// CHECK-NEXT: arith.constant
+
+// -----
+
+module {
+  func.func @complex_tensor_insert(
+      %arg0: tensor<10xcomplex<f32>>) -> tensor<10xcomplex<f32>> {
+    %c1 = arith.constant 1 : index
+    %real = arith.constant 3.0 : f32
+    %imag = arith.constant 2.0 : f32
+    %complex = complex.create %real, %imag : complex<f32>
+    %out = tensor.insert %complex into %arg0[%c1] : tensor<10xcomplex<f32>>
+    func.return %out : tensor<10xcomplex<f32>>
+  }
+}
+
+// CHECK: @complex_tensor_insert(%[[ARG0:.*]]: !llvm.ptr
+// CHECK: %[[C:.*]] = complex.create
+// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[C]] : complex<f32> to !llvm.struct<(f32, f32)>
+// CHECK: llvm.store %[[CAST]], %[[GEP]] : !llvm.struct<(f32, f32)>, !llvm.ptr
+
+// -----
+
+module {
+  func.func @complex_tensor_extract(
+      %arg0: tensor<10xcomplex<f32>>) -> complex<f32> {
+    %c1 = arith.constant 1 : index
+    %v2 = tensor.extract %arg0[%c1] : tensor<10xcomplex<f32>>
+    func.return %v2 : complex<f32>
+  }
+}
+
+// CHECK: @complex_tensor_extract(%[[ARG0:.*]]: !llvm.ptr
+// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)>
+// CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> !llvm.struct<(f32, f32)>
+// CHECK: builtin.unrealized_conversion_cast %[[LOAD]] : !llvm.struct<(f32, f32)> to complex<f32>
+
+// -----
+
+module {
+  // This example is a bit silly; in real life there wouldn't be a loop (the
+  // loop body would be executed by different threads). We only write it this
+  // way so that control flow with shared memory is tested as well.
+  func.func @transpose_shared(%in: tensor<32x32xf32>,
+                              %out: tensor<32x32xf32>) -> tensor<32x32xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+
+    %shared = xla_gpu.allocate_shared : tensor<32x32xf32>
+    %loaded_tile = scf.for %i = %c0 to %c32 step %c1
+        iter_args(%tile = %shared) -> tensor<32x32xf32> {
+      %inner_loaded_tile = scf.for %j = %c0 to %c32 step %c1
+          iter_args(%inner_tile = %tile) -> tensor<32x32xf32> {
+        %v = tensor.extract %in[%i, %j] : tensor<32x32xf32>
+        %inserted = tensor.insert %v into %inner_tile[%i, %j]
+            : tensor<32x32xf32>
+        scf.yield %inserted : tensor<32x32xf32>
+      }
+      scf.yield %inner_loaded_tile : tensor<32x32xf32>
+    }
+
+    %synced = xla_gpu.sync_threads %shared : tensor<32x32xf32>
+    %written_tile = scf.for %i = %c0 to %c32 step %c1
+        iter_args(%written = %out) -> tensor<32x32xf32> {
+      %inner_written_tile = scf.for %j = %c0 to %c32 step %c1
+          iter_args(%inner_written = %written) -> tensor<32x32xf32> {
+        %v = tensor.extract %shared[%j, %i] : tensor<32x32xf32>
+        %inserted = tensor.insert %v into %inner_written[%i, %j]
+            : tensor<32x32xf32>
+        scf.yield %inserted : tensor<32x32xf32>
+      }
+      scf.yield %inner_written_tile : tensor<32x32xf32>
+    }
+
+    return %written_tile : tensor<32x32xf32>
+  }
+}
+
+// CHECK:      llvm.mlir.global private @[[SHARED:shared_.*]]()
+// CHECK-SAME:     {addr_space = 3 : i32} : !llvm.array<1024 x f32>
+// CHECK:      @transpose_shared
+// CHECK:        %[[ADDR:.*]] = llvm.mlir.addressof @[[SHARED]] : !llvm.ptr<3>
+// CHECK:        %[[CAST:.*]] = llvm.addrspacecast %[[ADDR]]
+// CHECK-SAME:       : !llvm.ptr<3> to !llvm.ptr
+// CHECK:        scf.for
+// CHECK:          scf.for
+// CHECK:            %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]]
+// CHECK:            llvm.store {{.*}} %[[ELEM_ADDR]]
+// CHECK:        gpu.barrier
+// CHECK:        scf.for
+// CHECK:          scf.for
+// CHECK:            %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]]
+// CHECK:            llvm.load %[[ELEM_ADDR]]
+
+// -----
+
+module {
+  func.func @atomic_rmw_f32(%in: tensor<2x4xf32>, %i: index, %j: index)
+      -> (tensor<2x4xf32>) {
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
+      ^bb0(%current : f32):
+        %c42 = arith.constant 1.0 : f32
+        %add = arith.minimumf %current, %c42 : f32
+        xla_gpu.yield %add : f32
+    }
+    return %ret : tensor<2x4xf32>
+  }
+}
+
+// CHECK: @atomic_rmw_f32
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-NEXT: %[[INIT:.*]] = llvm.load %[[ADDR]]
+// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]])
+// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f32 to i32
+// CHECK-NEXT: llvm.cmpxchg %[[ADDR]], %[[VAR]], %[[RES]]
+
+// -----
+
+module {
+  func.func @atomic_rmw_f16(%in: tensor<2x4xf16>, %i: index, %j: index)
+      -> (tensor<2x4xf16>) {
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> {
+      ^bb0(%current : f16):
+        %c1 = arith.constant 1.0 : f16
+        %add = arith.addf %current, %c1 : f16
+        xla_gpu.yield %add : f16
+    }
+    return %ret : tensor<2x4xf16>
+  }
+}
+
+// CHECK: @atomic_rmw_f16
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]]
+// CHECK-NEXT: %[[OFFSET:.*]] = llvm.and %[[ADDR_INT]], %{{.*}}
+// CHECK-NEXT: %[[INDEX:.*]] = llvm.mul %[[OFFSET]], %{{.*}}
+// CHECK-NEXT: %[[BASE:.*]] = llvm.getelementptr inbounds %[[ADDR]][%[[INDEX]]]
+// CHECK: %[[INIT:.*]] = llvm.load %[[BASE]]
+// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]])
+// CHECK-NEXT: %[[VAR_SHIFT:.*]] = llvm.lshr %[[VAR]], %{{.*}}
+// CHECK-NEXT: %[[VAR_TRUNC:.*]] = llvm.trunc %[[VAR_SHIFT]]
+// CHECK-NEXT: llvm.bitcast %[[VAR_TRUNC]] : i16 to f16
+// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f16 to i16
+// CHECK-NEXT: %[[RES_WIDE:.*]] = llvm.zext %[[RES]]
+// CHECK-NEXT: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}}
+// CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}}
+// CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]]
+// CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]]
+
+// -----
+
+module {
+  func.func @atomic_rmw_overwrite(%in: tensor<2x4xf16>, %i: index, %j: index)
+      -> (tensor<2x4xf16>) {
+    %c1 = arith.constant 1.0 : f16
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> {
+      ^bb0(%current : f16):
+        xla_gpu.yield %c1 : f16
+    }
+    return %ret : tensor<2x4xf16>
+  }
+}
+// CHECK: @atomic_rmw_overwrite
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]]
+// CHECK-NEXT: %[[OFFSET:.*]] = llvm.and %[[ADDR_INT]], %{{.*}}
+// CHECK-NEXT: %[[INDEX:.*]] = llvm.mul %[[OFFSET]], %{{.*}}
+// CHECK-NEXT: %[[BASE:.*]] = llvm.getelementptr inbounds %[[ADDR]][%[[INDEX]]]
+// CHECK: %[[INIT:.*]] = llvm.load %[[BASE]]
+// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]])
+// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f16 to i16
+// CHECK-NEXT: %[[RES_WIDE:.*]] = llvm.zext %[[RES]]
+// CHECK-NEXT: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}}
+// CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}}
+// CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]]
+// CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]]
+
+// -----
+
+module {
+  func.func @shared_complex() -> tensor<10xcomplex<f32>> {
+    %shared = xla_gpu.allocate_shared : tensor<10xcomplex<f32>>
+    return %shared : tensor<10xcomplex<f32>>
+  }
+}
+
+// CHECK: llvm.mlir.global private @{{.*}}() {addr_space = 3 : i32} : !llvm.array<10 x struct<(f32, f32)>>
+// CHECK: @shared_complex
+
+// -----
+
+module {
+  func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) -> tensor<10xi4> {
+    %v = tensor.extract %arg[%i] : tensor<10xi4>
+    %r = tensor.insert %v into %arg[%j] : tensor<10xi4>
+    return %r : tensor<10xi4>
+  }
+}
+
+// CHECK: @i4_load_store
+// CHECK: llvm.getelementptr
+// CHECK-SAME: -> !llvm.ptr, i8
+// CHECK: llvm.load
+// CHECK: llvm.getelementptr
+// CHECK-SAME: -> !llvm.ptr, i8
+// CHECK: llvm.load
+// CHECK: llvm.store
+
+// -----
+
+module {
+  func.func @direct_atomic_rmw_overwrite(%in: tensor<2x4xi32>,
+    %i: index, %j: index) -> (tensor<2x4xi32>) {
+    %c2 = arith.constant 2 : i32
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
+      ^bb0(%current : i32):
+        xla_gpu.yield %c2 : i32
+    }
+    return %ret : tensor<2x4xi32>
+  }
+}
+// CHECK: @direct_atomic_rmw_overwrite
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.store %[[C2]], %[[ADDR]] atomic unordered {alignment = 4 : i64}
+
+// -----
+
+module {
+  func.func @direct_atomic_rmw_addi(%in: tensor<2x4xi32>,
+    %i: index, %j: index) -> (tensor<2x4xi32>) {
+    %c2 = arith.constant 2 : i32
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
+      ^bb0(%current : i32):
+        %min = arith.addi %current, %c2 : i32
+        xla_gpu.yield %c2 : i32
+    }
+    return %ret : tensor<2x4xi32>
+  }
+}
+// CHECK: @direct_atomic_rmw_addi
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw add %[[ADDR]], %[[C2]] seq_cst
+
+// -----
+
+module {
+  func.func @direct_atomic_rmw_maxsi(%in: tensor<2x4xi32>,
+    %i: index, %j: index) -> (tensor<2x4xi32>) {
+    %c2 = arith.constant 2 : i32
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
+      ^bb0(%current : i32):
+        %min = arith.maxsi %current, %c2 : i32
+        xla_gpu.yield %c2 : i32
+    }
+    return %ret : tensor<2x4xi32>
+  }
+}
+// CHECK: @direct_atomic_rmw_maxsi
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw max %[[ADDR]], %[[C2]] seq_cst
+
+// -----
+
+module {
+  func.func @direct_atomic_rmw_maxui(%in: tensor<2x4xi32>,
+    %i: index, %j: index) -> (tensor<2x4xi32>) {
+    %c2 = arith.constant 2 : i32
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
+      ^bb0(%current : i32):
+        %min = arith.maxui %current, %c2 : i32
+        xla_gpu.yield %c2 : i32
+    }
+    return %ret : tensor<2x4xi32>
+  }
+}
+// CHECK: @direct_atomic_rmw_maxui
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw umax %[[ADDR]], %[[C2]] seq_cst
+
+// -----
+
+module {
+  func.func @direct_atomic_rmw_minsi(%in: tensor<2x4xi32>,
+    %i: index, %j: index) -> (tensor<2x4xi32>) {
+    %c2 = arith.constant 2 : i32
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
+      ^bb0(%current : i32):
+        %min = arith.minsi %current, %c2 : i32
+        xla_gpu.yield %c2 : i32
+    }
+    return %ret : tensor<2x4xi32>
+  }
+}
+// CHECK: @direct_atomic_rmw_minsi
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw min %[[ADDR]], %[[C2]] seq_cst
+
+// -----
+
+module {
+  func.func @direct_atomic_rmw_minui(%in: tensor<2x4xi32>,
+    %i: index, %j: index) -> (tensor<2x4xi32>) {
+    %c2 = arith.constant 2 : i32
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> {
+      ^bb0(%current : i32):
+        %min = arith.minui %current, %c2 : i32
+        xla_gpu.yield %c2 : i32
+    }
+    return %ret : tensor<2x4xi32>
+  }
+}
+// CHECK: @direct_atomic_rmw_minui
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw umin %[[ADDR]], %[[C2]] seq_cst
+
+// -----
+
+module {
+  func.func @direct_atomic_rmw_fadd_f32(%in: tensor<2x4xf32>,
+    %i: index, %j: index) -> (tensor<2x4xf32>) {
+    %c2 = arith.constant 2.0 : f32
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
+      ^bb0(%current : f32):
+        %min = arith.addf %current, %c2 : f32
+        xla_gpu.yield %c2 : f32
+    }
+    return %ret : tensor<2x4xf32>
+  }
+}
+// CHECK-LABEL: @direct_atomic_rmw_fadd_f32
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f32
+// CHECK-VOLTA: %[[C2:.*]] = arith.constant 2
+// CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f32
+// CHECK-AMPERE: %[[C2:.*]] = arith.constant 2
+// CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f32
+// CHECK-GFX908-MI100: %[[C2:.*]] = arith.constant 2
+// CHECK-GFX908-MI100: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-GFX908-MI100: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1>
+// CHECK-GFX908-MI100: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst
+
+// CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f32
+// CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2
+// CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-GFX90A-MI200: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1>
+// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst
+
+// -----
+
+module {
+  func.func @direct_atomic_rmw_fadd_f16(%in: tensor<2x4xf16>,
+    %i: index, %j: index) -> (tensor<2x4xf16>) {
+    %c2 = arith.constant 2.0 : f16
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> {
+      ^bb0(%current : f16):
+        %min = arith.addf %current, %c2 : f16
+        xla_gpu.yield %c2 : f16
+    }
+    return %ret : tensor<2x4xf16>
+  }
+}
+// CHECK-LABEL: @direct_atomic_rmw_fadd_f16
+// CHECK-NOT: llvm.atomicrmw fadd
+
+// CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f16
+// CHECK-VOLTA: %[[C2:.*]] = arith.constant 2
+// CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f16
+// CHECK-AMPERE: %[[C2:.*]] = arith.constant 2
+// CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f16
+// CHECK-GFX908-MI100-NOT: llvm.atomicrmw fadd
+
+// CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f16
+// CHECK-GFX90A-MI200: %[[C2:.*]] = arith.constant 2
+// CHECK-GFX90A-MI200: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-GFX90A-MI200: %[[ADDR_CAST:.*]] = llvm.addrspacecast %[[ADDR]] : !llvm.ptr to !llvm.ptr<1>
+// CHECK-GFX90A-MI200: llvm.atomicrmw fadd %[[ADDR_CAST]], %[[C2]] syncscope("agent") seq_cst
+
+// -----
+
+module {
+  func.func @direct_atomic_rmw_fadd_bf16(%in: tensor<2x4xbf16>,
+      %i: index, %j: index) -> (tensor<2x4xbf16>) {
+    %c2 = arith.constant 2.0 : bf16
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xbf16> {
+      ^bb0(%current : bf16):
+        %min = arith.addf %current, %c2 : bf16
+        xla_gpu.yield %c2 : bf16
+    }
+    return %ret : tensor<2x4xbf16>
+  }
+}
+// CHECK-LABEL: @direct_atomic_rmw_fadd_bf16
+// CHECK-NOT: llvm.atomicrmw fadd
+
+// CHECK-HOPPER-LABEL: @direct_atomic_rmw_fadd_bf16
+// CHECK-HOPPER: %[[C2:.*]] = arith.constant 2
+// CHECK-HOPPER: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-HOPPER: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// -----
+
+module {
+  func.func @direct_atomic_rmw_fadd_f64(%in: tensor<2x4xf64>,
+    %i: index, %j: index) -> (tensor<2x4xf64>) {
+    %c2 = arith.constant 2.0 : f64
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf64> {
+      ^bb0(%current : f64):
+        %min = arith.addf %current, %c2 : f64
+        xla_gpu.yield %c2 : f64
+    }
+    return %ret : tensor<2x4xf64>
+  }
+}
+// CHECK-LABEL: @direct_atomic_rmw_fadd_f64
+// CHECK: %[[C2:.*]] = arith.constant 2
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-VOLTA-LABEL: @direct_atomic_rmw_fadd_f64
+// CHECK-VOLTA: %[[C2:.*]] = arith.constant 2
+// CHECK-VOLTA: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-VOLTA: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-AMPERE-LABEL: @direct_atomic_rmw_fadd_f64
+// CHECK-AMPERE: %[[C2:.*]] = arith.constant 2
+// CHECK-AMPERE: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK-AMPERE: llvm.atomicrmw fadd %[[ADDR]], %[[C2]] seq_cst
+
+// CHECK-GFX908-MI100-LABEL: @direct_atomic_rmw_fadd_f64
+// CHECK-GFX908-MI100-NOT: llvm.atomicrmw fadd
+
+// CHECK-GFX90A-MI200-LABEL: @direct_atomic_rmw_fadd_f64
+// CHECK-GFX90A-MI200-NOT: llvm.atomicrmw fadd
+
+// -----
+
+module {
+  func.func @direct_atomic_rmw_maximumf(%in: tensor<2x4xf32>,
+    %i: index, %j: index) -> (tensor<2x4xf32>) {
+    %c2 = arith.constant 2.0 : f32
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> {
+      ^bb0(%current : f32):
+        %min = arith.maximumf %current, %c2 : f32
+        xla_gpu.yield %c2 : f32
+    }
+    return %ret : tensor<2x4xf32>
+  }
+}
+// CHECK-LABEL: @direct_atomic_rmw_maximumf
+
+// CHECK: %[[MODIFIER:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr
+// CHECK: %[[CURRENT:.*]] = llvm.load %[[ADDR]] : !llvm.ptr -> f32
+// CHECK: %[[CURRENT_IS_NAN:.*]] = llvm.fcmp "uno" %[[CURRENT]], %[[CURRENT]] : f32
+// CHECK: scf.if %[[CURRENT_IS_NAN]] {
+// CHECK: } else {
+// CHECK:   %[[MODIFIER_IS_NAN:.*]] = llvm.fcmp "uno" %[[MODIFIER]], %[[MODIFIER]] : f32
+// CHECK:   %[[MODIFIER_OR_NAN:.*]] = llvm.select %[[MODIFIER_IS_NAN]], %[[NAN]], %[[MODIFIER]] : i1, f32
+// CHECK:   %[[VAL_13:.*]] = llvm.fcmp "ult" %[[CURRENT]], %[[MODIFIER_OR_NAN]] : f32
+// CHECK:   scf.if %[[VAL_13]] {
+// CHECK:     %[[INT_MODIFIER_OR_NAN:.*]] = llvm.bitcast %[[MODIFIER_OR_NAN]] : f32 to i32
+// CHECK:     %[[IS_POSITIVE:.*]] = llvm.icmp "sge" %[[INT_MODIFIER_OR_NAN]], %[[C0]] : i32
+// CHECK:     scf.if %[[IS_POSITIVE]] {
+// CHECK:       llvm.atomicrmw max %[[ADDR]], %[[INT_MODIFIER_OR_NAN]] seq_cst
+// CHECK:     } else {
+// CHECK:       llvm.atomicrmw umin %[[ADDR]], %[[INT_MODIFIER_OR_NAN]] seq_cst
+// CHECK:     }
+// CHECK:   }
+// CHECK: }
+// CHECK: return
+
+// -----
+
+module {
+  func.func @atomic_rmw_c32(%in: tensor<2x4xcomplex<f32>>, %i: index, %j: index)
+      -> (tensor<2x4xcomplex<f32>>) {
+    %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xcomplex<f32>> {
+      ^bb0(%current : complex<f32>):
+        %a = complex.add %current, %current : complex<f32>
+        xla_gpu.yield %a : complex<f32>
+    }
+    return %ret : tensor<2x4xcomplex<f32>>
+  }
+}
+
+// CHECK-LABEL: @atomic_rmw_c32
+
+// CHECK: scf.while (%[[ITER_ARG:.*]] = %{{.*}}) : (i64) -> i64
+// CHECK: %[[TMP:.*]] = llvm.alloca
+// CHECK: llvm.store %[[ITER_ARG]], %[[TMP]]
+// CHECK: %[[LD:.*]] = llvm.load %[[TMP]] : {{.*}} -> !llvm.struct<(f32, f32)>
+// CHECK: builtin.unrealized_conversion_cast %[[LD]] : {{.*}} to complex<f32>
+
+// -----
+
+module {
+  func.func @unused_index_switch_results(%i: index) -> index {
+    %ret, %ret2 = scf.index_switch %i -> tensor<2x4xi32>, tensor<3xf32>
+    case 0 {
+      %x, %y = "dummy.op1"() : () -> (tensor<2x4xi32>, tensor<3xf32>)
+      scf.yield %x, %y : tensor<2x4xi32>, tensor<3xf32>
+    }
+    default {
+      %x, %y = "dummy.op2"() : () -> (tensor<2x4xi32>, tensor<3xf32>)
+      scf.yield %x, %y : tensor<2x4xi32>, tensor<3xf32>
+    }
+    return %i : index
+  }
+}
+
+// CHECK-LABEL: func.func @unused_index_switch_results
+// CHECK-SAME:      (%[[I:.*]]: index)
+// CHECK-NEXT:    scf.index_switch %[[I]]
+// CHECK-NEXT:    case 0 {
+// CHECK-NEXT:      "dummy.op1"
+// CHECK-NEXT:      scf.yield
+// CHECK-NEXT:    }
+// CHECK-NEXT:    default {
+// CHECK-NEXT:      "dummy.op2"
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[I]] : index
+
+// -----
+
+module {
+  func.func @transfer_write(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> {
+    %c16 = arith.constant 16 : index
+    %c22 = arith.constant 22 : index
+    %cst = arith.constant dense<[1.0, 2.0]> : vector<2xf32>
+    %out = vector.transfer_write %cst, %arg0[%c16] : vector<2xf32>, tensor<43xf32>
+    %out2 = vector.transfer_write %cst, %out[%c22] : vector<2xf32>, tensor<43xf32>
+    func.return %out2 : tensor<43xf32>
+  }
+}
+
+// CHECK-LABEL: @transfer_write
+// CHECK:           %[[PTR1:.*]] = llvm.getelementptr inbounds %[[BUF:.*]][16]
+// CHECK-NEXT:      llvm.store %[[CST:.*]], %[[PTR1]]
+// CHECK-NEXT:      %[[PTR2:.*]] = llvm.getelementptr inbounds %[[BUF]][22]
+// CHECK-NEXT:      llvm.store %[[CST]], %[[PTR2]]
+
+// -----
+
+module {
+  func.func @transfer_read(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> vector<2xf32> {
+    %c16 = arith.constant 16 : index
+    %c0 = arith.constant 0.0 : f32
+    %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf32>, vector<2xf32>
+    func.return %out : vector<2xf32>
+  }
+}
+
+// CHECK-LABEL: @transfer_read
+// CHECK:           %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[16]
+// CHECK-NEXT:      llvm.load %[[PTR]] : !llvm.ptr -> vector<2xf32>
+
+// -----
+
+module {
+  func.func @transfer_write_i1(%arg0: tensor<43xi1> {xla.slice_index = 1},
+                               %v1: vector<2xi1>, %v2: vector<2xi1>) -> tensor<43xi1> {
+    %c16 = arith.constant 16 : index
+    %c22 = arith.constant 22 : index
+    %out = vector.transfer_write %v1, %arg0[%c16] : vector<2xi1>, tensor<43xi1>
+    %out2 = vector.transfer_write %v2, %out[%c22] : vector<2xi1>, tensor<43xi1>
+    func.return %out2 : tensor<43xi1>
+  }
+}
+
+// CHECK-LABEL: @transfer_write_i1
+// CHECK-SAME:      (%[[ARG0:.*]]: !llvm.ptr
+// CHECK-SAME:       %[[V1:.*]]: vector<2xi1>, %[[V2:.*]]: vector<2xi1>)
+// CHECK-DAG:       %[[PTR1:.*]] = llvm.getelementptr inbounds %[[BUF:.*]][16]
+// CHECK-DAG:       %[[V1_EXT:.*]] = arith.extui %[[V1]]
+// CHECK:           llvm.store %[[V1_EXT]], %[[PTR1]]
+// CHECK-DAG:       %[[PTR2:.*]] = llvm.getelementptr inbounds %[[BUF]][22]
+// CHECK-DAG:       %[[V2_EXT:.*]] = arith.extui %[[V2]]
+// CHECK:           llvm.store %[[V2_EXT]], %[[PTR2]]
+
+// -----
+
+module {
+  func.func @transfer_read_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}) -> vector<2xi1> {
+    %c16 = arith.constant 16 : index
+    %false = arith.constant false
+    %out = vector.transfer_read %arg0[%c16], %false : tensor<43xi1>, vector<2xi1>
+    func.return %out : vector<2xi1>
+  }
+}
+
+// CHECK-LABEL: @transfer_read_i1
+// CHECK-DAG:       %[[C0:.*]] = arith.constant dense<0> : vector<2xi8>
+// CHECK-DAG:       %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[16]
+// CHECK:           %[[LOADED:.*]] = llvm.load %[[PTR]] : !llvm.ptr
+// CHECK:           %[[CAST:.*]] = arith.cmpi ne, %[[LOADED]], %[[C0]]
+// CHECK:           return %[[CAST]] : vector<2xi1>
+
+// -----
+
+module {
+  func.func @transfer_write_i4(%arg0: tensor<43xi4> {xla.slice_index = 1},
+                               %v1: vector<4xi4>) -> tensor<43xi4> {
+    %c16 = arith.constant 16 : index
+    %out = vector.transfer_write %v1, %arg0[%c16] : vector<4xi4>, tensor<43xi4>
+    func.return %out : tensor<43xi4>
+  }
+}
+
+// CHECK-LABEL: @transfer_write_i4
+// CHECK-SAME:       , %[[V1:.*]]: vector<4xi4>
+// CHECK-DAG:       %[[A0:.*]] = vector.extract %[[V1]][0]
+// CHECK-DAG:       %[[A1:.*]] = vector.extract %[[V1]][1]
+// CHECK-DAG:       %[[A2:.*]] = vector.extract %[[V1]][2]
+// CHECK-DAG:       %[[A3:.*]] = vector.extract %[[V1]][3]
+// CHECK-DAG:       vector.insert %[[A0]], {{.*}}[1]
+// CHECK-DAG:       vector.insert %[[A1]], {{.*}}[0]
+// CHECK-DAG:       vector.insert %[[A2]], {{.*}}[3]
+// CHECK-DAG:       vector.insert %[[A3]], {{.*}}[2]
+
+module {
+  func.func @transfer_read_i4(%arg0: tensor<43xi4> {xla.slice_index = 1}) -> vector<4xi4> {
+    %c16 = arith.constant 16 : index
+    %c0 = arith.constant 0 : i4
+    %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xi4>, vector<4xi4>
+    func.return %out : vector<4xi4>
+  }
+}
+
+// CHECK-LABEL: @transfer_read_i4
+// CHECK:           %[[LOADED:.*]] = llvm.load
+// CHECK-DAG:       %[[A0:.*]] = vector.extract %[[LOADED]][0]
+// CHECK-DAG:       %[[A1:.*]] = vector.extract %[[LOADED]][1]
+// CHECK-DAG:       %[[A2:.*]] = vector.extract %[[LOADED]][2]
+// CHECK-DAG:       %[[A3:.*]] = vector.extract %[[LOADED]][3]
+// CHECK-DAG:       vector.insert %[[A0]], {{.*}}[1]
+// CHECK-DAG:       vector.insert %[[A1]], {{.*}}[0]
+// CHECK-DAG:       vector.insert %[[A2]], {{.*}}[3]
+// CHECK-DAG:       vector.insert %[[A3]], {{.*}}[2]
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir
new file mode 100644
index 0000000..f0de25a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-loops-to-scf | FileCheck %s
+
+#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1),
+  domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]>
+func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) {
+  %sum = xla_gpu.loop (%dim)[%i, %j] in #map iter_args(%sum_ = %init) -> (f32) {
+    %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32>
+    %add = arith.addf %sum_, %t : f32
+    xla_gpu.yield %add : f32
+  } {xla.range = [0 : index, 42 : index]}
+  func.return %sum : f32
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0 + s1),
+// CHECK-SAME:  domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]>
+
+// CHECK-LABEL: func.func @loop_op(
+// CHECK-SAME:    %[[IN:.*]]: tensor<1024x32xf32>,
+// CHECK-SAME:    %[[INIT:.*]]: f32, %[[DIM:.*]]: index) -> f32 {
+
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[C33:.*]] = arith.constant 33 : index
+// CHECK-DAG:     %[[C90:.*]] = arith.constant 90 : index
+// CHECK-DAG:     %[[C1025:.*]] = arith.constant 1025 : index
+
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C1025]] step %[[C1]]
+// CHECK-SAME:   iter_args(%[[INIT_:.*]] = %[[INIT]]) -> (f32) {
+
+// CHECK:      %[[INNER_FOR:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[C33]]
+// CHECK-SAME:   step %[[C1]] iter_args(%[[INIT__:.*]] = %[[INIT_]]) -> (f32) {
+
+// CHECK:        %[[INDEX:.*]] = xla_gpu.apply_indexing
+// CHECK-SAME:     #[[$MAP]](%[[DIM]])[%[[I]], %[[J]]]
+// CHECK:        %[[VAL1:.*]] = arith.cmpi sge, %[[INDEX]], %[[C0]] : index
+// CHECK:        %[[VAL2:.*]] = arith.cmpi sle, %[[INDEX]], %[[C90]] : index
+// CHECK:        %[[VAL3:.*]] = arith.andi %[[VAL1]], %[[VAL2]] : i1
+// CHECK:        %[[VAL4:.*]] = arith.cmpi sge, %[[DIM]], %[[C0]] : index
+// CHECK:        %[[VAL5:.*]] = arith.cmpi sle, %[[DIM]], %[[C3]] : index
+// CHECK:        %[[VAL6:.*]] = arith.andi %[[VAL4]], %[[VAL5]] : i1
+// CHECK:        %[[INBOUNDS:.*]] = arith.andi %[[VAL3]], %[[VAL6]] : i1
+// CHECK:        %[[IF_RESULT:.*]] = scf.if %[[INBOUNDS]] -> (f32) {
+// CHECK:          %[[ELEM:.*]] = tensor.extract %[[IN]][%[[I]], %[[J]]]
+// CHECK:          %[[SUM:.*]] = arith.addf %[[INIT__]], %[[ELEM]] : f32
+// CHECK:          scf.yield %[[SUM]] : f32
+// CHECK:        } else {
+// CHECK:          scf.yield %[[INIT__]] : f32
+// CHECK:        }
+// CHECK:        scf.yield %[[IF_RESULT]] : f32
+// CHECK:      }
+// CHECK:      scf.yield %[[INNER_FOR]] : f32
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir
new file mode 100644
index 0000000..2f9494a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir
@@ -0,0 +1,121 @@
+// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-to-scf --split-input-file  \
+// RUN: | FileCheck %s
+
+func.func @reducer(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) {
+  return %a, %b : f32, i32
+}
+
+func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) {
+  %ret:2 = xla_gpu.shuffle_reduce @reducer(%a, %b) to 4 : f32, i32
+  return %ret#0, %ret#1 : f32, i32
+}
+// CHECK: @shuffler(%[[A:.*]]: f32, %[[B:.*]]: i32)
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32
+// CHECK: %[[A4H:.*]], {{.*}} = gpu.shuffle down %[[A]], %[[C4]], %[[C32]]
+// CHECK: %[[B4H:.*]], {{.*}} = gpu.shuffle down %[[B]], %[[C4]], %[[C32]]
+// CHECK: %[[AB4_0:.*]], %[[AB4_1:.*]] = xla_gpu.pure_call @reducer(%[[A]], %[[B]], %[[A4H]], %[[B4H]])
+// CHECK: %[[A2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_0]], %[[C2]], %[[C32]]
+// CHECK: %[[B2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_1]], %[[C2]], %[[C32]]
+// CHECK: %[[AB2_0:.*]], %[[AB2_1:.*]] = xla_gpu.pure_call @reducer(%[[AB4_0]], %[[AB4_1]], %[[A2H]], %[[B2H]])
+// CHECK: %[[A1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_0]], %[[C1]], %[[C32]]
+// CHECK: %[[B1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_1]], %[[C1]], %[[C32]]
+// CHECK: %[[AB1_0:.*]], %[[AB1_1:.*]] = xla_gpu.pure_call @reducer(%[[AB2_0]], %[[AB2_1]], %[[A1H]], %[[B1H]])
+// CHECK: return %[[AB1_0]], %[[AB1_1]]
+
+// -----
+
+func.func @reducer(%a: f64, %b: f64) -> f64 {
+  return %a : f64
+}
+
+func.func @shuffler(%a: f64) -> f64 {
+  %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : f64
+  return %ret : f64
+}
+// CHECK: @shuffler(%[[A:.*]]: f64
+// CHECK: gpu.shuffle down {{.*}}, %[[C1]]
+// CHECK: gpu.shuffle down {{.*}}, %[[C1]]
+
+// -----
+
+func.func @reducer(%a: complex<f64>, %b: complex<f64>) -> complex<f64> {
+  return %a : complex<f64>
+}
+
+func.func @shuffler(%a: complex<f64>) -> complex<f64> {
+  %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : complex<f64>
+  return %ret : complex<f64>
+}
+// CHECK: @shuffler
+// CHECK-COUNT-4: gpu.shuffle down {{.*}}, %[[C1]]
+
+// -----
+
+func.func @reducer(%a: ui64, %b: ui64) -> ui64 {
+  return %a : ui64
+}
+
+func.func @shuffler(%a: ui64) -> ui64 {
+  %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : ui64
+  return %ret : ui64
+}
+// CHECK: @shuffler
+// CHECK: unrealized_conversion_cast
+// CHECK-COUNT-2: gpu.shuffle down {{.*}}, %[[C1]]
+
+// -----
+
+func.func @reducer(%a: i8, %b: i8) -> i8 {
+  return %a : i8
+}
+
+func.func @shuffler_i8(%a: i8) -> i8 {
+  %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : i8
+  return %ret : i8
+}
+// CHECK: @shuffler_i8(
+// CHECK-NOT: vector
+// CHECK-COUNT-1: gpu.shuffle down {{.*}}, %[[C1]]
+
+// -----
+
+func.func @predicated_insert(
+    %v: i32, %tensor: tensor<2xi32>, %index: index,
+    %cond: i1) -> tensor<2xi32> {
+  %ret = xla_gpu.predicated_insert %v into %tensor[%index] if %cond
+    : tensor<2xi32>
+  return %ret : tensor<2xi32>
+}
+// CHECK: @predicated_insert
+// CHECK-SAME: %[[V:.*]]: i32, %[[TENSOR:.*]]: tensor<2xi32>,
+// CHECK-SAME: %[[INDEX:.*]]: index, %[[COND:.*]]: i1
+// CHECK-NEXT: %[[RET:.*]] = scf.if %[[COND]]
+// CHECK-NEXT:   %[[UPD:.*]] = tensor.insert %[[V]] into %[[TENSOR]][%[[INDEX]]]
+// CHECK-NEXT:   yield %[[UPD]]
+// CHECK-NEXT: else
+// CHECK-NEXT:   yield %[[TENSOR]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RET]]
+
+// -----
+
+func.func @predicated_extract(
+    %v: i32, %tensor: tensor<2xi32>, %index: index,
+    %cond: i1) -> i32 {
+  %ret = xla_gpu.predicated_extract %tensor[%index] if %cond else %v
+    : tensor<2xi32>
+  return %ret : i32
+}
+// CHECK: @predicated_extract
+// CHECK-SAME: %[[V:.*]]: i32, %[[TENSOR:.*]]: tensor<2xi32>,
+// CHECK-SAME: %[[INDEX:.*]]: index, %[[COND:.*]]: i1
+// CHECK-NEXT: %[[RET:.*]] = scf.if %[[COND]]
+// CHECK-NEXT:   %[[VAL:.*]] = tensor.extract %[[TENSOR]][%[[INDEX]]]
+// CHECK-NEXT:   yield %[[VAL]]
+// CHECK-NEXT: else
+// CHECK-NEXT:   yield %[[V]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RET]]
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/merge_pointers_to_same_slice.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/merge_pointers_to_same_slice.mlir
similarity index 100%
rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/merge_pointers_to_same_slice.mlir
rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/merge_pointers_to_same_slice.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir
new file mode 100644
index 0000000..cb6c048
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir
@@ -0,0 +1,211 @@
+// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-optimize-loops | FileCheck %s
+
+#map = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 8),
+                              domain: d0 in [0, 31]>
+#map1 = #xla_gpu.indexing_map<(d0) -> (d0 mod 8),
+                              domain: d0 in [0, 31]>
+#map2 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512),
+                              domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]>
+module {
+  func.func @fully_unroll(%arg0: tensor<4x8x4096xf32>, %arg1: tensor<4096xbf16>,
+      %arg2: tensor<4x8xf32>, %arg3: tensor<4096xbf16>,
+      %arg4: tensor<4x8x4096xbf16>, %arg5: tensor<4x8xf32>,
+      %arg6: tensor<4x8x4096xf32>) -> (tensor<4x8x4096xf32>, f32)  {
+    %cst = arith.constant 1.000000e+00 : f32
+    %cst_1 = arith.constant 1.000000e+00 : bf16
+    %c2 = arith.constant 2 : index
+    %c8 = arith.constant 8 : index
+    %c32 = arith.constant 32 : index
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    %thread_id_x = gpu.thread_id  x {xla.range = [0 : index, 255 : index]}
+    %block_id_x = gpu.block_id  x {xla.range = [0 : index, 31 : index]}
+    %0 = gpu.lane_id
+    %1 = arith.cmpi eq, %0, %c0 : index
+    %2 = arith.divui %thread_id_x, %c32 : index
+    %3 = arith.cmpi ult, %thread_id_x, %c8 : index
+    %4 = xla_gpu.apply_indexing #map(%block_id_x)
+    %5 = xla_gpu.apply_indexing #map1(%block_id_x)
+    %extracted = tensor.extract %arg2[%4, %5] : tensor<4x8xf32>
+    %6 = arith.mulf %extracted, %cst : f32
+    %7 = arith.addf %6, %cst : f32
+    %8 = math.rsqrt %7 : f32
+    %9:2 = scf.for %arg7 = %c0 to %c8 step %c1 iter_args(%arg8 = %arg6, %arg9 = %cst) -> (tensor<4x8x4096xf32>, f32) {
+      %18 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7]
+      %19 = vector.transfer_read %arg1[%18], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16>
+      %20 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7]
+      %21 = vector.transfer_read %arg3[%20], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16>
+      %22 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7]
+      %23 = vector.transfer_read %arg4[%4, %5, %22], %cst_1 {in_bounds = [true]} : tensor<4x8x4096xbf16>, vector<2xbf16>
+      %24 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7]
+      %25 = vector.transfer_read %arg0[%4, %5, %24], %cst {in_bounds = [true]} : tensor<4x8x4096xf32>, vector<2xf32>
+      %26:2 = scf.for %arg10 = %c0 to %c2 step %c1 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<4x8x4096xf32>, f32) {
+        %27 = xla_gpu.apply_indexing #map2(%arg10, %thread_id_x)[%arg7]
+        %28 = vector.extract %25[%arg10] : f32 from vector<2xf32>
+        %29 = vector.extract %23[%arg10] : bf16 from vector<2xbf16>
+        %30 = arith.extf %29 : bf16 to f32
+        %31 = vector.extract %21[%arg10] : bf16 from vector<2xbf16>
+        %32 = arith.extf %31 : bf16 to f32
+        %33 = arith.mulf %30, %32 : f32
+        %34 = arith.mulf %33, %8 : f32
+        %35 = vector.extract %19[%arg10] : bf16 from vector<2xbf16>
+        %36 = arith.extf %35 : bf16 to f32
+        %37 = arith.addf %36, %cst : f32
+        %38 = arith.mulf %34, %37 : f32
+        %39 = arith.addf %28, %38 : f32
+        %40 = arith.mulf %39, %39 : f32
+        %41 = arith.addf %arg12, %40 : f32
+        %inserted = tensor.insert %39 into %arg11[%4, %5, %27] : tensor<4x8x4096xf32>
+        scf.yield %inserted, %41 : tensor<4x8x4096xf32>, f32
+      }
+      scf.yield %26#0, %26#1 : tensor<4x8x4096xf32>, f32
+    }
+    return %9#0, %9#1 : tensor<4x8x4096xf32>, f32
+  }
+}
+
+// CHECK-LABEL: @fully_unroll
+// CHECK-NOT: scf.for
+
+// -----
+
+module {
+  func.func @unroll_by_factor(%arg0: f32) -> f32 {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c256 = arith.constant 256 : index
+    %ret = scf.for %i = %c0 to %c256 step %c1 iter_args (%v = %arg0) -> (f32) {
+      %exp = math.exp %v : f32
+      %add = arith.addf %v, %exp : f32
+      %log = math.log %add : f32
+      scf.yield %log : f32
+    }
+    return %ret : f32
+  }
+}
+
+// CHECK-LABEL: @unroll_by_factor
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: scf.for {{.*}} step %[[C8]]
+
+// -----
+
+module {
+  func.func @do_not_unroll(%arg0: f32) -> f32 {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c31 = arith.constant 31 : index
+    %ret = scf.for %i = %c0 to %c31 step %c1 iter_args (%v = %arg0) -> (f32) {
+      %exp = math.exp %v : f32
+      %add = arith.addf %v, %exp : f32
+      %log = math.log %add : f32
+      scf.yield %log : f32
+    }
+    return %ret : f32
+  }
+}
+
+// CHECK-LABEL: @do_not_unroll
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for {{.*}} step %[[C1]]
+
+// -----
+
+module {
+  func.func @pipeline_extract(%arg: tensor<31xf32>) -> f32 {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c31 = arith.constant 31 : index
+    %cst = arith.constant 0.0 : f32
+    %ret = scf.for %i = %c0 to %c31 step %c1 iter_args (%iter = %cst) -> (f32) {
+      %val = tensor.extract %arg[%i] : tensor<31xf32>
+      %log = math.log %val : f32
+      %add = arith.addf %log, %iter : f32
+      scf.yield %add : f32
+    }
+    return %ret : f32
+  }
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1),
+// CHECK-LABEL: @pipeline_extract
+// CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:  %[[C30:.*]] = arith.constant 30 : index
+// CHECK:      %[[VAL0:.*]] = tensor.extract %[[ARG0:.*]][%[[C0]]]
+// CHECK:      scf.for %[[I:.*]] = %[[C0]] {{.*}} iter_args(%[[ITER:.*]] = {{.*}}, %[[VAL:.*]] = %[[VAL0]])
+// CHECK-DAG:  %[[NEXT_I_EXISTS:.*]] = arith.cmpi ult, %[[I]], %[[C30]]
+// CHECK-DAG:  %[[NEXT_I:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]
+// CHECK:      %[[NEXT_VAL:.*]] = scf.if %[[NEXT_I_EXISTS]]
+// CHECK-NEXT:   tensor.extract %[[ARG0]][%[[NEXT_I]]]
+// CHECK-NEXT:   yield
+// CHECK-NEXT: else
+// CHECK-NEXT:   yield %[[VAL]]
+// CHECK:     math.log %[[VAL]]
+// CHECK:     %[[ADD:.*]] = arith.addf
+// CHECK:     yield %[[ADD]], %[[NEXT_VAL]]
+
+// -----
+
+module {
+  func.func @pipeline_transfer(%arg: tensor<34xf32>) -> vector<2xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c17 = arith.constant 17 : index
+    %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32>
+    %cst0  = arith.constant 0.0 : f32
+    %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) {
+      %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 15]>(%i)
+      %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32>
+      %log = math.log %val : vector<2xf32>
+      %add = arith.addf %log, %iter : vector<2xf32>
+      scf.yield %add : vector<2xf32>
+    }
+    return %ret : vector<2xf32>
+  }
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2),
+// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1),
+// CHECK-LABEL: @pipeline_transfer
+// CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:  %[[C16:.*]] = arith.constant 16 : index
+// CHECK:      %[[BASE0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[C0]]
+// CHECK:      %[[VAL0:.*]] = vector.transfer_read %[[ARG0:.*]][%[[BASE0]]]
+// CHECK:      scf.for %[[I:.*]] = %[[C0]] {{.*}} iter_args(%[[ITER:.*]] = {{.*}}, %[[VAL:.*]] = %[[VAL0]])
+// CHECK-DAG:  %[[NEXT_I_EXISTS:.*]] = arith.cmpi ult, %[[I]], %[[C16]]
+// CHECK-DAG:  %[[NEXT_I:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]]
+// CHECK-DAG:  %[[NEXT_BASE:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[NEXT_I]]
+// CHECK:      %[[NEXT_VAL:.*]] = scf.if %[[NEXT_I_EXISTS]]
+// CHECK-NEXT:   vector.transfer_read %[[ARG0]][%[[NEXT_BASE]]]
+// CHECK-NEXT:   yield
+// CHECK-NEXT: else
+// CHECK-NEXT:   yield %[[VAL]]
+// CHECK:     math.log %[[VAL]]
+// CHECK:     %[[ADD:.*]] = arith.addf
+// CHECK:     yield %[[ADD]], %[[NEXT_VAL]]
+
+// -----
+
+module {
+  func.func @sequential_extract(%arg0: tensor<6xindex>, %arg1: tensor<22xindex>) -> (index) {
+    %c1 = arith.constant 1 : index
+    %c733 = arith.constant 733 : index
+    %c0 = arith.constant 0 : index
+    %2 = scf.for %i = %c0 to %c733 step %c1 iter_args(%x = %c1) -> (index) {
+      %extracted = tensor.extract %arg0[%i] : tensor<6xindex>
+      %extracted_1 = tensor.extract %arg1[%extracted] : tensor<22xindex>
+      scf.yield %extracted_1 : index
+    }
+    return %2 : index
+  }
+}
+
+// Once `extracted` is pipelined, it becomes an iter arg, so `extracted_1`
+// would read %arg1 at an index that is itself an iter arg. While it is
+// possible to pipeline such chained extracts in principle, we do not
+// currently do so.
+
+// CHECK-LABEL: @sequential_extract
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<6xindex>, %[[ARG1:.*]]: tensor<22xindex>)
+// CHECK: tensor.extract %[[ARG0]]
+// CHECK-NOT: tensor.extract
+// CHECK: scf.for
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir
new file mode 100644
index 0000000..2044254
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir_fusions_opt -split-input-file %s -xla-gpu-peel-loops \
+// RUN: | FileCheck %s
+
+#map = #xla_gpu.indexing_map<
+  (d0)[s0, s1] -> (s0, s1),
+  domain:
+  d0 in [0, 3],
+  s0 in [0, 7],
+  s1 in [0, 10],
+  d0 + s0 in [0, 9],
+  d0 + s1 in [0, 12]
+>
+func.func @peel_both_loops(%input: tensor<16x32xf32>,
+    %init: f32, %dim: index) -> (f32) {
+  %sum = xla_gpu.loop (%dim)[%i, %j] in #map iter_args(%sum_ = %init) -> (f32) {
+    %t = tensor.extract %input[%i, %j] : tensor<16x32xf32>
+    %add = arith.addf %sum_, %t : f32
+    xla_gpu.yield %add : f32
+  }
+  func.return %sum : f32
+}
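+// s0 = 7 can violate `d0 + s0 in [0, 9]` and s1 = 10 can violate
+// `d0 + s1 in [0, 12]`, so the last value of each symbol is peeled into its
+// own tail loop with a tightened domain.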
+// CHECK: #[[$PEELED_MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9]>
+// CHECK: #[[$TAIL_MAP0:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9]>
+// CHECK: #[[$TAIL_MAP1:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10]>
+
+// CHECK-LABEL: func.func @peel_both_loops(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<16x32xf32>,
+// CHECK-SAME:      %[[INIT:.*]]: f32, %[[DIM:.*]]: index)
+
+// CHECK:      %[[PEELED:.*]] = xla_gpu.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]]
+// CHECK-SAME:     in #[[$PEELED_MAP]] iter_args(%[[INIT_:.*]] = %[[INIT]])
+// CHECK:        tensor.extract %[[INPUT]][%[[I]], %[[J]]] : tensor<16x32xf32>
+// CHECK:        arith.addf %[[INIT_]]
+
+// CHECK:      %[[TAIL0:.*]] = xla_gpu.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]]
+// CHECK-SAME:     in #[[$TAIL_MAP0]] iter_args(%[[INIT_:.*]] = %[[PEELED]])
+// CHECK:        tensor.extract %[[INPUT]][%[[I]], %[[J]]]
+// CHECK:        arith.addf %[[INIT_]]
+
+// CHECK:      %[[TAIL1:.*]] = xla_gpu.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]]
+// CHECK-SAME:     in #[[$TAIL_MAP1]] iter_args(%[[INIT_:.*]] = %[[TAIL0]])
+// CHECK:        tensor.extract %[[INPUT]][%[[I]], %[[J]]]
+// CHECK:        arith.addf %[[INIT_]]
+
+// CHECK: return %[[TAIL1]] : f32
+
+// -----
+
+#map = #xla_gpu.indexing_map<
+  (d0)[s0] -> (s0),
+  domain:
+  d0 in [0, 3],
+  s0 in [0, 7]
+>
+func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32,
+    %dim: index) -> (f32) {
+  %sum = xla_gpu.loop (%dim)[%i] in #map iter_args(%sum_ = %init) -> (f32) {
+    %t = tensor.extract %input[%i] : tensor<16xf32>
+    %add = arith.addf %sum_, %t : f32
+    xla_gpu.yield %add : f32
+  }
+  func.return %sum : f32
+}
+// CHECK-LABEL: func.func @not_constrained_symbol
+// CHECK:     xla_gpu.loop
+// CHECK-NOT: xla_gpu.loop
+
+// -----
+
+#map = #xla_gpu.indexing_map<
+  (d0)[s0] -> (s0),
+  domain:
+  d0 in [0, 3],
+  s0 in [0, 7],
+  s0 mod 5 in [0, 1]
+>
+func.func @constraint_exists_after_peeling(%input: tensor<16xf32>, %init: f32,
+    %dim: index) -> (f32) {
+  %sum = xla_gpu.loop (%dim)[%i] in #map iter_args(%sum_ = %init) -> (f32) {
+    %t = tensor.extract %input[%i] : tensor<16xf32>
+    %add = arith.addf %sum_, %t : f32
+    xla_gpu.yield %add : f32
+  }
+  func.return %sum : f32
+}
+// CHECK-LABEL: func.func @constraint_exists_after_peeling
+// CHECK:     xla_gpu.loop
+// CHECK-NOT: xla_gpu.loop
\ No newline at end of file
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/propagate_slice_indices.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/propagate_slice_indices.mlir
similarity index 100%
rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/propagate_slice_indices.mlir
rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/propagate_slice_indices.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir
new file mode 100644
index 0000000..d51566a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir
@@ -0,0 +1,146 @@
+// RUN: mlir_fusions_opt --allow-unregistered-dialect %s -split-input-file -xla-gpu-simplify-affine | FileCheck %s
+
+func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %0 = gpu.thread_id  x {xla.range = [0 : index, 127 : index]}
+  %1 = gpu.block_id  x {xla.range = [0 : index, 3071 : index]}
+  scf.for %arg3 = %c0 to %c4 step %c1 {
+    %2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))>()[%1, %0, %arg3]
+    %3 = arith.index_castui %2 : index to i64
+    %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    %5 = llvm.load %4 invariant : !llvm.ptr -> f32
+    %8 = llvm.getelementptr %arg1[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    %9 = llvm.load %8 invariant : !llvm.ptr -> f32
+    %10 = arith.cmpf oge, %5, %9 : f32
+    %11 = llvm.getelementptr %arg2[%3] : (!llvm.ptr, i64) -> !llvm.ptr, i1
+    llvm.store %10, %11 : i1, !llvm.ptr
+  }
+  return
+}
+
+// CHECK-LABEL: @op_and_for_ranges
+// CHECK-DAG: %[[C512:.*]] = arith.constant 512
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4
+// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x
+// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x
+// CHECK:     scf.for %[[I:.*]] =
+// CHECK:       %[[BLOCK_OFFSET:.*]] = arith.muli %[[BID_X]], %[[C512]]
+// CHECK:       %[[THREAD_OFFSET:.*]] = arith.muli %[[TID_X]], %[[C4]]
+// CHECK:       %[[OFFSET:.*]] = arith.addi %[[BLOCK_OFFSET]], %[[THREAD_OFFSET]]
+// CHECK:       arith.addi %[[OFFSET]], %[[I]]
+
+// -----
+
+func.func @arg_ranges(%arg0: index {xla.range = [0 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index {
+  %0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1]
+  return %0 : index
+}
+
+// CHECK-LABEL: @arg_ranges
+// CHECK-NEXT:  %[[C100:.*]] = arith.constant 100
+// CHECK-NEXT:  %[[RET:.*]] = arith.divui %{{.*}}, %[[C100]]
+// CHECK-NEXT:  return %[[RET]]
+
+// -----
+
+func.func @cant_lower(%arg0: index {xla.range = [-10 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index {
+  %0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1]
+  return %0 : index
+}
+
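+// s0 may be negative, so the floordiv cannot be rewritten as an unsigned
+// division.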
+// CHECK-LABEL:       @cant_lower
+// CHECK:       affine.apply
+
+// -----
+
+func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %0 = gpu.thread_id  x
+  %1 = gpu.block_id  x
+  scf.for %i = %c0 to %c4 step %c1 {
+    %2 = xla_gpu.apply_indexing
+      #xla_gpu.indexing_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4)),
+      domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3]>[%1, %0, %i]
+    %3 = arith.index_castui %2 : index to i64
+    %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    %5 = llvm.load %4 invariant : !llvm.ptr -> f32
+    %8 = llvm.getelementptr %arg1[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    %9 = llvm.load %8 invariant : !llvm.ptr -> f32
+    %10 = arith.cmpf oge, %5, %9 : f32
+    %11 = llvm.getelementptr %arg2[%3] : (!llvm.ptr, i64) -> !llvm.ptr, i1
+    llvm.store %10, %11 : i1, !llvm.ptr
+  }
+  return
+}
+
+// CHECK-LABEL: @op_and_for_ranges
+// CHECK-DAG: %[[C512:.*]] = arith.constant 512
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4
+// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x
+// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x
+// CHECK:     scf.for %[[I:.*]] =
+// CHECK:       %[[BLOCK_OFFSET:.*]] = arith.muli %[[BID_X]], %[[C512]]
+// CHECK:       %[[THREAD_OFFSET:.*]] = arith.muli %[[TID_X]], %[[C4]]
+// CHECK:       %[[OFFSET:.*]] = arith.addi %[[BLOCK_OFFSET]], %[[THREAD_OFFSET]]
+// CHECK:       arith.addi %[[OFFSET]], %[[I]]
+
+// -----
+
+func.func @arg_ranges(%arg0: index, %arg1: index) -> index {
+  %0 = xla_gpu.apply_indexing
+    #xla_gpu.indexing_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100),
+    domain: s0 in [0, 42], s1 in [0, 1000]>[%arg0, %arg1]
+  return %0 : index
+}
+
+// CHECK-LABEL: @arg_ranges
+// CHECK-NEXT:  %[[C100:.*]] = arith.constant 100
+// CHECK-NEXT:  %[[RET:.*]] = arith.divui %{{.*}}, %[[C100]]
+// CHECK-NEXT:  return %[[RET]]
+
+// -----
+
+func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) {
+  %0:2 = xla_gpu.apply_indexing
+    #xla_gpu.indexing_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1),
+    domain: s0 in [-10, 42], s1 in [0, 1000]>[%arg0, %arg1]
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: @cant_lower
+// CHECK:         affine.apply
+// CHECK-NEXT:    arith.addi
+
+// -----
+
+func.func @order_summands(%arg1: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  scf.for %arg2 = %c0 to %c4 step %c1 {
+    scf.for %arg3 = %c0 to %c4 step %c1 {
+      %0 = xla_gpu.apply_indexing
+        #xla_gpu.indexing_map<()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10),
+        domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3]>[%arg2, %arg1, %arg3]
+      "dummy.op"(%0) : (index) -> ()
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: @order_summands
+// CHECK-SAME:    (%[[ARG1:.*]]: index)
+// CHECK: scf.for %[[ARG2:.*]] =
+// CHECK: scf.for %[[ARG3:.*]] =
+// CHECK: arith.muli %[[ARG1]]
+// CHECK: arith.muli %[[ARG2]]
+// CHECK: arith.addi
+// CHECK: arith.addi %[[ARG1]], %[[ARG2]]
+// CHECK: arith.divui
+// CHECK: arith.addi
+// CHECK: arith.muli %[[ARG3]]
+// CHECK: arith.addi %5, %6 : index
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir
new file mode 100644
index 0000000..09c8901
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir
@@ -0,0 +1,292 @@
+// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-simplify-arith -cse -canonicalize | FileCheck %s
+
+module {
+  func.func @unknown(%arg0: index {xla.range = [0 : index, 42 : index]}) -> i1 {
+    %c12 = arith.constant 12 : index
+    %eq = arith.cmpi eq, %arg0, %c12 : index
+    return %eq : i1
+  }
+}
+
+// CHECK: @unknown
+// CHECK: cmpi
+
+// -----
+
+module {
+  func.func @true(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
+    %c5 = arith.constant 5 : index
+    %eq = arith.cmpi sge, %arg0, %c5 : index
+    return %eq : i1
+  }
+}
+
+// CHECK: @true
+// CHECK-NEXT: constant true
+// CHECK-NEXT: return
+
+// -----
+
+module {
+  func.func @false(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
+    %c5 = arith.constant 5 : index
+    %eq = arith.cmpi slt, %arg0, %c5 : index
+    return %eq : i1
+  }
+}
+
+// CHECK: @false
+// CHECK-NEXT: constant false
+// CHECK-NEXT: return
+
+// -----
+
+module {
+  func.func @rhs_range(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 {
+    %c64 = arith.constant 64 : index
+    %eq = arith.cmpi slt, %c64, %arg0 : index
+    return %eq : i1
+  }
+}
+
+// CHECK: @rhs_range
+// CHECK-NEXT: constant false
+// CHECK-NEXT: return
+
+// -----
+
+module {
+  func.func @both_range(%arg0: index {xla.range = [12 : index, 42 : index]},
+                        %arg1: index {xla.range = [63 : index, 100 : index]}) -> i1 {
+    %eq = arith.cmpi slt, %arg0, %arg1 : index
+    return %eq : i1
+  }
+}
+
+// CHECK-LABEL: @both_range
+// CHECK-NEXT: constant true
+// CHECK-NEXT: return
+
+// -----
+
+module {
+  func.func @minsi_lhs(%arg0: index {xla.range = [12 : index, 42 : index]},
+                       %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
+    %min = arith.minsi %arg0, %arg1 : index
+    return %min : index
+  }
+}
+
+// CHECK-LABEL: @minsi_lhs
+// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
+// CHECK-NEXT: return %[[ARG0]]
+
+// -----
+
+module {
+  func.func @minsi_rhs(%arg0: index {xla.range = [12 : index, 42 : index]},
+                       %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
+    %min = arith.minsi %arg1, %arg0 : index
+    return %min : index
+  }
+}
+
+// CHECK-LABEL: @minsi_rhs
+// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
+// CHECK-NEXT: return %[[ARG0]]
+
+// -----
+
+module {
+  func.func @maxsi_lhs(%arg0: index {xla.range = [12 : index, 42 : index]},
+                       %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
+    %min = arith.maxsi %arg1, %arg0 : index
+    return %min : index
+  }
+}
+
+// CHECK-LABEL: @maxsi_lhs
+// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
+// CHECK-NEXT: return %[[ARG1]]
+
+// -----
+
+module {
+  func.func @maxsi_rhs(%arg0: index {xla.range = [12 : index, 42 : index]},
+                       %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
+    %min = arith.maxsi %arg0, %arg1 : index
+    return %min : index
+  }
+}
+
+// CHECK-LABEL: @maxsi_rhs
+// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
+// CHECK-NEXT: return %[[ARG1]]
+
+// -----
+
+module {
+  func.func @maxsi_add(%arg0: index {xla.range = [102 : index, 142 : index]},
+                       %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
+    %add = arith.addi %arg0, %arg1 : index
+    %min = arith.maxsi %add, %arg1 : index
+    return %min : index
+  }
+}
+
+// CHECK-LABEL: @maxsi_add
+// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
+// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG0]], %[[ARG1]]
+// CHECK-NEXT: return %[[ADD]]
+
+// -----
+
+module {
+  func.func @minsi_add(%arg0: index {xla.range = [102 : index, 142 : index]},
+                       %arg1: index {xla.range = [63 : index, 100 : index]}) -> index {
+    %add = arith.addi %arg0, %arg1 : index
+    %min = arith.minsi %add, %arg1 : index
+    return %min : index
+  }
+}
+
+// CHECK-LABEL: @minsi_add
+// CHECK-SAME: (%[[ARG0:.*]]: index {{.*}}, %[[ARG1:.*]]: index {{.*}})
+// CHECK-NEXT: return %[[ARG1]]
+
+// -----
+
+module {
+  func.func @pred_reduce(%in: i1) -> i1 {
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %c16_i32 = arith.constant 16 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %0 = arith.extui %in : i1 to i32
+    %shuffleResult, %valid = gpu.shuffle  down %0, %c16_i32, %c32_i32 : i32
+    %1 = arith.trunci %shuffleResult : i32 to i1
+    %2 = arith.ori %in, %1 : i1
+    %3 = arith.extui %2 : i1 to i32
+    %shuffleResult_0, %valid_1 = gpu.shuffle  down %3, %c8_i32, %c32_i32 : i32
+    %4 = arith.trunci %shuffleResult_0 : i32 to i1
+    %5 = arith.ori %2, %4 : i1
+    %6 = arith.extui %5 : i1 to i32
+    %shuffleResult_2, %valid_3 = gpu.shuffle  down %6, %c4_i32, %c32_i32 : i32
+    %7 = arith.trunci %shuffleResult_2 : i32 to i1
+    %8 = arith.ori %5, %7 : i1
+    %9 = arith.extui %8 : i1 to i32
+    %shuffleResult_4, %valid_5 = gpu.shuffle  down %9, %c2_i32, %c32_i32 : i32
+    %10 = arith.trunci %shuffleResult_4 : i32 to i1
+    %11 = arith.ori %8, %10 : i1
+    %12 = arith.extui %11 : i1 to i32
+    %shuffleResult_6, %valid_7 = gpu.shuffle  down %12, %c1_i32, %c32_i32 : i32
+    %13 = arith.trunci %shuffleResult_6 : i32 to i1
+    %14 = arith.ori %11, %13 : i1
+    return %14 : i1
+  }
+}
+
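+// The values being shuffled are always 0 or 1, so the i1 <-> i32 conversions
+// around each shuffle are redundant: the ors can operate on the i32 shuffle
+// results directly, with a single trunci at the end.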
+// CHECK-LABEL: @pred_reduce
+// CHECK-SAME:     (%[[IN:.*]]: i1)
+// CHECK:       %[[IN_EXT:.*]] = arith.extui %[[IN]]
+// CHECK-NEXT:  %[[SHUFFLE0:.*]], {{.*}} = gpu.shuffle down %[[IN_EXT]]
+// CHECK-NEXT:  %[[OR0:.*]] = arith.ori %[[IN_EXT]], %[[SHUFFLE0]]
+// CHECK-NEXT:  %[[SHUFFLE1:.*]], {{.*}} = gpu.shuffle down %[[OR0]]
+// CHECK-NEXT:  %[[OR1:.*]] = arith.ori %[[OR0]], %[[SHUFFLE1]]
+// CHECK-NEXT:  %[[SHUFFLE2:.*]], {{.*}} = gpu.shuffle down %[[OR1]]
+// CHECK-NEXT:  %[[OR2:.*]] = arith.ori %[[OR1]], %[[SHUFFLE2]]
+// CHECK-NEXT:  %[[SHUFFLE3:.*]], {{.*}} = gpu.shuffle down %[[OR2]]
+// CHECK-NEXT:  %[[OR3:.*]] = arith.ori %[[OR2]], %[[SHUFFLE3]]
+// CHECK-NEXT:  %[[SHUFFLE4:.*]], {{.*}} = gpu.shuffle down %[[OR3]]
+// CHECK-NEXT:  %[[OR4:.*]] = arith.ori %[[OR3]], %[[SHUFFLE4]]
+// CHECK-NEXT:  %[[RET:.*]] = arith.trunci %[[OR4]]
+// CHECK-NEXT:  return %[[RET]]
+
+// -----
+
+module {
+  func.func @andi_no_trunc_arg(%a: i4, %b: i8) -> i4 {
+    %lhs = arith.extui %a : i4 to i8
+    %add = arith.andi %lhs, %b : i8
+    %ret = arith.trunci %add : i8 to i4
+    return %ret : i4
+  }
+}
+
+// CHECK-LABEL: @andi_no_trunc_arg
+// CHECK-NEXT: extui
+// CHECK-NEXT: andi
+// CHECK-NEXT: trunci
+// CHECK-NEXT: return
+
+// -----
+
+module {
+  func.func @ori_mismatched_narrowest(%a: i8, %b: i8) -> i8 {
+    %0 = arith.trunci %a : i8 to i4
+    %1 = arith.extui %0 : i4 to i8
+    %ret = arith.ori %b, %1 : i8
+    return %ret : i8
+  }
+}
+
+// CHECK-LABEL: @ori_mismatched_narrowest
+// CHECK-NEXT: trunci
+// CHECK-NEXT: extui
+// CHECK-NEXT: ori
+// CHECK-NEXT: return
+
+// -----
+
+func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c42_f32 = arith.constant 42.0 : f32
+  %loop = scf.for %i = %c0 to %c3 step %c1
+      iter_args(%in_ = %tensor) -> (tensor<100xf32>) {
+    %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 4), domain: d0 in [0, 9]>(%i)
+    %updated = tensor.insert %c42_f32 into %in_[%0] : tensor<100xf32>
+    scf.yield %updated :tensor<100xf32>
+  }
+  func.return %loop : tensor<100xf32>
+}
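+// The loop bounds refine %i's range to [0, 2], so the `d0 mod 4` in the
+// indexing map simplifies away.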
+// CHECK-LABEL: func.func @refine_constraints
+// CHECK: %[[CST:.*]] = arith.constant 4.2
+// CHECK: scf.for
+// CHECK: tensor.insert %[[CST]]
+
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000),
+                             domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3]>
+#map1 = #xla_gpu.indexing_map<(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9),
+                             domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3]>
+func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>,
+    %arg1: tensor<2400000x9xf32>) -> tensor<2400000x9xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c73 = arith.constant 73 : index
+  %c42_f32 = arith.constant 42.0 : f32
+  %th_x = gpu.thread_id  x {xla.range = [0 : index, 127 : index]}
+  %bl_x = gpu.block_id  x {xla.range = [0 : index, 575 : index]}
+  %0 = scf.for %i = %c0 to %c73 step %c1 iter_args(%arg3 = %arg1)
+      -> (tensor<2400000x9xf32>) {
+    %2 = scf.for %j = %c0 to %c4 step %c1 iter_args(%arg5 = %arg3)
+        -> (tensor<2400000x9xf32>) {
+      %3 = xla_gpu.apply_indexing #map(%th_x, %bl_x)[%i, %j]
+      %4 = xla_gpu.apply_indexing #map1(%th_x, %bl_x)[%j]
+      %inserted = tensor.insert %c42_f32 into %arg5[%3, %4]
+        : tensor<2400000x9xf32>
+      scf.yield %inserted : tensor<2400000x9xf32>
+    }
+    scf.yield %2 : tensor<2400000x9xf32>
+  }
+  return %0 : tensor<2400000x9xf32>
+}
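+// The loop over %i only covers [0, 72], so after refining s0's range the
+// map's result is always below 2400000 and the `mod 2400000` is removed.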
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> ((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768),
+// CHECK-LABEL: func.func @refine_constraints_for_symbol
diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/unswitch_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/unswitch_loops.mlir
similarity index 100%
rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/unswitch_loops.mlir
rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/unswitch_loops.mlir
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir
new file mode 100644
index 0000000..16e4498
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir
@@ -0,0 +1,413 @@
+// RUN: mlir_fusions_opt -allow-unregistered-dialect %s -split-input-file -xla-gpu-vectorize-loads-stores -canonicalize -cse | FileCheck %s
+
+#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0),
+                             domain: d0 in [0, 63], s0 in [0, 1]>
+module {
+  func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c64 = arith.constant 64 : index
+    %cst = arith.constant 0.0 : f32
+    %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
+      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+        %idx = xla_gpu.apply_indexing #map(%i)[%j]
+        %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
+        %added = arith.addf %iter1, %extracted : f32
+        scf.yield %added : f32
+      }
+      scf.yield %inner : f32
+    }
+    return %outer : f32
+  }
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 63]>
+// CHECK-LABEL: @simple_read
+// CHECK-SAME:     (%[[ARG0:.*]]: tensor
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C64:.*]] = arith.constant 64 : index
+// CHECK:       scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] =
+// CHECK:         %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]])
+// CHECK-NEXT:    %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]]
+// CHECK-NEXT:    scf.for %[[J:.*]] = %[[C0]]
+// CHECK-NEXT:      vector.extract %[[V]][%[[J]]]
+// CHECK-NEXT:      addf
+
+// -----
+
+module {
+  func.func @simple_read_2d(%arg0: tensor<64x2xf32>) -> (f32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c64 = arith.constant 64 : index
+    %cst = arith.constant 0.0 : f32
+    %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
+      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+        %extracted = tensor.extract %arg0[%i, %j] : tensor<64x2xf32>
+        %added = arith.addf %iter1, %extracted : f32
+        scf.yield %added : f32
+      }
+      scf.yield %inner : f32
+    }
+    return %outer : f32
+  }
+}
+
+// CHECK-LABEL: @simple_read_2d
+// CHECK-SAME:     (%[[ARG0:.*]]: tensor
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       scf.for %[[I:.*]] = %[[C0]]
+// CHECK-NEXT:    %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[C0]]]
+// CHECK-NEXT:    scf.for %[[J:.*]] = %[[C0]]
+// CHECK-NEXT:      vector.extract %[[V]][%[[J]]]
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 + 1),
+                             domain: d0 in [0, 63], s0 in [0, 1]>
+module {
+  func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c63 = arith.constant 63 : index
+    %cst = arith.constant 0.0 : f32
+    %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
+      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+        %idx = xla_gpu.apply_indexing #map(%i)[%j]
+        %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
+        %added = arith.addf %iter1, %extracted : f32
+        scf.yield %added : f32
+      }
+      scf.yield %inner : f32
+    }
+    return %outer : f32
+  }
+}
+
+// CHECK-LABEL: @misaligned_indexing_map
+// CHECK-NOT: vector.transfer_read
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 3 + s0),
+                             domain: d0 in [0, 63], s0 in [0, 1]>
+module {
+  func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c63 = arith.constant 63 : index
+    %cst = arith.constant 0.0 : f32
+    %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
+      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+        %idx = xla_gpu.apply_indexing #map(%i)[%j]
+        %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
+        %added = arith.addf %iter1, %extracted : f32
+        scf.yield %added : f32
+      }
+      scf.yield %inner : f32
+    }
+    return %outer : f32
+  }
+}
+
+// CHECK-LABEL: @misaligned_indexing_map_2
+// CHECK-NOT: vector.transfer_read
+
+// -----
+
+module {
+  func.func @misaligned_shape(%arg0: tensor<64x3xf32>) -> (f32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c64 = arith.constant 64 : index
+    %cst = arith.constant 0.0 : f32
+    %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
+      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+        %extracted = tensor.extract %arg0[%i, %j] : tensor<64x3xf32>
+        %added = arith.addf %iter1, %extracted : f32
+        scf.yield %added : f32
+      }
+      scf.yield %inner : f32
+    }
+    return %outer : f32
+  }
+}
+
+// CHECK-LABEL: @misaligned_shape
+// CHECK-NOT: vector.transfer_read
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2),
+                             domain: d0 in [0, 63], s0 in [0, 1]>
+module {
+  func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c63 = arith.constant 63 : index
+    %cst = arith.constant 0.0 : f32
+    %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
+      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+        %idx = xla_gpu.apply_indexing #map(%i)[%j]
+        %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
+        %added = arith.addf %iter1, %extracted : f32
+        scf.yield %added : f32
+      }
+      scf.yield %inner : f32
+    }
+    return %outer : f32
+  }
+}
+
+// CHECK-LABEL: @wrong_stride
+// CHECK-NOT: vector.transfer_read
+
+// -----
+
+// We could vectorize this as an f32 vector load of twice the length (reading
+// each complex<f32> as two f32s), but we don't currently do so.
+module {
+  func.func @simple_read_complex(%arg0: tensor<64x2xcomplex<f32>>, %i: index) -> (complex<f32>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %cst = complex.constant [0.0 : f32, 0.0 : f32] : complex<f32>
+    %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %cst) -> complex<f32> {
+      %extracted = tensor.extract %arg0[%i, %j] : tensor<64x2xcomplex<f32>>
+      %added = complex.add %iter, %extracted : complex<f32>
+      scf.yield %added : complex<f32>
+    }
+    return %loop : complex<f32>
+  }
+}
+
+// CHECK-LABEL: @simple_read_complex
+// CHECK-NOT: vector.transfer_read
+
+// -----
+
+// This is vectorizable, but not currently supported.
+module {
+  func.func @layout(%arg0: tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>) -> (f32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c64 = arith.constant 64 : index
+    %cst = arith.constant 0.0 : f32
+    %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 {
+      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+        %extracted = tensor.extract %arg0[%j, %i]
+          : tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>
+        %added = arith.addf %iter1, %extracted : f32
+        scf.yield %added : f32
+      }
+      scf.yield %inner : f32
+    }
+    return %outer : f32
+  }
+}
+
+// CHECK-LABEL: @layout
+// CHECK-NOT: vector.transfer_read
+
+// -----
+
+module {
+  func.func @simple_write(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %cst = arith.constant 0.0 : f32
+    %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
+      %inserted = tensor.insert %cst into %iter[%i, %j] : tensor<16x4xf32>
+      scf.yield %inserted : tensor<16x4xf32>
+    }
+    return %loop : tensor<16x4xf32>
+  }
+}
+
+// CHECK-LABEL: @simple_write
+// CHECK-SAME:     (%[[ARG0:.*]]: tensor{{.*}}, %[[I:.*]]: index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %[[V:.*]] = scf.for
+// CHECK-NEXT:    vector.insert
+// CHECK-NEXT:    scf.yield
+// CHECK:       %[[WRITTEN:.*]] = vector.transfer_write %[[V]], %[[ARG0]][%[[I]], %[[C0]]]
+// CHECK-NEXT:  return %[[WRITTEN]]
+
+// -----
+
+module {
+  func.func @write_with_use(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %cst = arith.constant 0.0 : f32
+    %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
+      %inserted = tensor.insert %cst into %iter[%i, %j] : tensor<16x4xf32>
+      "dummy.op1"(%inserted) : (tensor<16x4xf32>) -> ()
+      scf.yield %inserted : tensor<16x4xf32>
+    }
+    return %loop : tensor<16x4xf32>
+  }
+}
+
+// CHECK-LABEL: @write_with_use
+// CHECK-NOT:   transfer_write
+
+// -----
+
+module {
+  func.func @write_not_to_iter_arg(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %cst = arith.constant 0.0 : f32
+    %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
+      %inserted = tensor.insert %cst into %arg0[%i, %j] : tensor<16x4xf32>
+      scf.yield %inserted : tensor<16x4xf32>
+    }
+    return %loop : tensor<16x4xf32>
+  }
+}
+
+// CHECK-LABEL: @write_not_to_iter_arg
+// CHECK-NOT:   transfer_write
+
+// -----
+
+module {
+  func.func @write_not_yielded(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %cst = arith.constant 0.0 : f32
+    %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> {
+      %inserted = tensor.insert %cst into %arg0[%i, %j] : tensor<16x4xf32>
+      scf.yield %arg0 : tensor<16x4xf32>
+    }
+    return %loop : tensor<16x4xf32>
+  }
+}
+
+// CHECK-LABEL: @write_not_yielded
+// CHECK-NOT:   transfer_write
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512),
+                             domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]>
+module {
+  func.func @multiple(%arg0: tensor<32x4096xf32>, %arg1: tensor<4096xbf16>,
+        %arg2: tensor<32xf32>, %arg3: tensor<32x4096xf32>,
+        %arg4: index) -> (tensor<32x4096xf32>, f32) {
+    %cst = arith.constant 1.000000e+00 : f32
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c8 = arith.constant 8 : index
+    %extracted1 = tensor.extract %arg2[%arg4] : tensor<32xf32>
+    %0:2 = scf.for %i = %c0 to %c8 step %c1 iter_args(%iter0 = %arg3, %iter1 = %cst) -> (tensor<32x4096xf32>, f32) {
+      %1:2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter2 = %iter0, %iter3 = %iter1) -> (tensor<32x4096xf32>, f32) {
+        %2 = xla_gpu.apply_indexing #map(%j, %arg4)[%i]
+        %extracted2 = tensor.extract %arg0[%i, %2] : tensor<32x4096xf32>
+        %extracted3 = tensor.extract %arg1[%2] : tensor<4096xbf16>
+        %3 = arith.extf %extracted3 : bf16 to f32
+        %4 = arith.addf %extracted2, %3 : f32
+        %5 = arith.addf %extracted1, %4 : f32
+        %6 = arith.addf %iter3, %5 : f32
+        %inserted = tensor.insert %5 into %iter2[%i, %2] : tensor<32x4096xf32>
+        scf.yield %inserted, %6 : tensor<32x4096xf32>, f32
+      }
+      scf.yield %1#0, %1#1 : tensor<32x4096xf32>, f32
+    }
+    return %0#0, %0#1 : tensor<32x4096xf32>, f32
+  }
+}
+
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 * 512),
+// CHECK-SAME:                                 domain: d0 in [0, 255], s0 in [0, 7]>
+// CHECK-LABEL: @multiple
+// CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index)
+// CHECK:      %[[C0:.*]] = arith.constant 0 : index
+// CHECK:      scf.for %[[I:.*]] = %[[C0]]
+// CHECK:      %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]])[%[[I]]]
+// CHECK:      %[[READ1:.*]] = vector.transfer_read %[[ARG1]][%[[BASE]]]
+// CHECK:      %[[READ2:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[BASE]]]
+// CHECK:      %[[INNER:.*]]:2 = scf.for %[[J:.*]] = %[[C0]] {{.*}} iter_args(%[[F:.*]] = {{.*}}, %[[V:.*]] = {{.*}}) -> (f32, vector<2xf32>)
+// CHECK-DAG:  vector.extract %[[READ1]][%[[J]]]
+// CHECK-DAG:  vector.extract %[[READ2]][%[[J]]]
+// CHECK:      extf
+// CHECK-NEXT: addf
+// CHECK-NEXT: %[[TO_INSERT:.*]] = arith.addf
+// CHECK-NEXT: %[[TO_YIELD:.*]] = arith.addf
+// CHECK-NEXT: %[[V_NEXT:.*]] = vector.insert %[[TO_INSERT]], %[[V]] [%[[J]]]
+// CHECK-NEXT: scf.yield %[[TO_YIELD]], %[[V_NEXT]]
+// CHECK:      %[[WRITTEN:.*]] = vector.transfer_write %[[INNER]]#1, %{{.*}}[%[[I]], %[[BASE]]]
+// CHECK:      scf.yield %[[WRITTEN]], %[[INNER]]#0
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 64 + s0),
+                             domain: d0 in [0, 63], s0 in [0, 1]>
+module {
+  func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c63 = arith.constant 63 : index
+    %cst = arith.constant 0.0 : f32
+    %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
+      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+        %idx = xla_gpu.apply_indexing #map(%i)[%j]
+        %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
+        %added = arith.addf %iter1, %extracted : f32
+        scf.yield %added : f32
+      }
+      scf.yield %inner : f32
+    }
+    return %outer : f32
+  }
+}
+
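+// (d0 * 4) mod 64 is always a multiple of the vector size, so the access is
+// contiguous and aligned; the base index simplifies to (d0 mod 16) * 4.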
+// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> ((d0 mod 16) * 4),
+// CHECK-LABEL: @remainder_with_modulo
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]]
+// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]
+// CHECK: vector.transfer_read {{.*}}[%[[BASE]]]
+
+// -----
+
+#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 65 + s0),
+                             domain: d0 in [0, 63], s0 in [0, 1]>
+module {
+  func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c63 = arith.constant 63 : index
+    %cst = arith.constant 0.0 : f32
+    %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 {
+      %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 {
+        %idx = xla_gpu.apply_indexing #map(%i)[%j]
+        %extracted = tensor.extract %arg0[%idx] : tensor<128xf32>
+        %added = arith.addf %iter1, %extracted : f32
+        scf.yield %added : f32
+      }
+      scf.yield %inner : f32
+    }
+    return %outer : f32
+  }
+}
+
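+// (d0 * 4) mod 65 is not guaranteed to be a multiple of the vector size, so
+// the read is not vectorized.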
+// CHECK-LABEL: @remainder_with_modulo_misaligned
+// CHECK-NOT: vector.transfer_read
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/unswitch_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/unswitch_loops.cc
new file mode 100644
index 0000000..d514a67
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/unswitch_loops.cc
@@ -0,0 +1,106 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <utility>
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_UNSWITCHLOOPSPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+class UnswitchLoopsPass
+    : public impl::UnswitchLoopsPassBase<UnswitchLoopsPass> {
+ public:
+  void runOnOperation() override;
+};
+
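+// Unswitches loops whose body is a single scf.if: the if is hoisted out of
+// the loop by cloning the loop into both branches and fixing the condition
+// to true resp. false inside each clone. Roughly (sketch):
+//
+//   scf.for %i ... { scf.if %cond { ... } }
+// becomes
+//   scf.if %cond { scf.for %i ... { /* %cond -> true */ } }
+//   else         { scf.for %i ... { /* %cond -> false */ } }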
+struct UnswitchLoop : mlir::OpRewritePattern<mlir::scf::ForOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
+    if (op.getBody()->getOperations().size() != 2) {
+      return rewriter.notifyMatchFailure(
+          op, "loop body is not a single instruction");
+    }
+    auto if_op = mlir::dyn_cast<mlir::scf::IfOp>(op.getBody()->front());
+    if (!if_op) {
+      return rewriter.notifyMatchFailure(op, "no if found inside the loop");
+    }
+    if (mlir::matchPattern(if_op.getCondition(), mlir::m_Constant())) {
+      return rewriter.notifyMatchFailure(op, "condition is a constant");
+    }
+
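+    // Clone the loop for the else branch with the condition mapped to false;
+    // the original loop is moved into the then branch with its condition
+    // forced to true.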
+    auto true_cst = rewriter.create<mlir::arith::ConstantOp>(
+        op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
+    auto false_cst = rewriter.create<mlir::arith::ConstantOp>(
+        op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
+    rewriter.setInsertionPoint(op);
+    mlir::IRMapping mapping;
+    mapping.map(if_op.getCondition(), false_cst);
+    auto false_branch_loop = op->clone(mapping);
+    auto new_if = rewriter.create<mlir::scf::IfOp>(
+        op.getLoc(), op.getResultTypes(), if_op.getCondition(), true, true);
+    rewriter.replaceAllUsesWith(op.getResults(), new_if.getResults());
+
+    auto then_builder = new_if.getThenBodyBuilder(rewriter.getListener());
+    auto then_yield =
+        then_builder.create<mlir::scf::YieldOp>(op.getLoc(), op.getResults());
+    rewriter.moveOpBefore(op, then_yield);
+    rewriter.modifyOpInPlace(if_op, [&]() { if_op->setOperand(0, true_cst); });
+
+    auto else_builder = new_if.getElseBodyBuilder(rewriter.getListener());
+    else_builder.insert(false_branch_loop);
+    else_builder.create<mlir::scf::YieldOp>(op.getLoc(),
+                                            false_branch_loop->getResults());
+
+    return mlir::success();
+  }
+};
+
+void UnswitchLoopsPass::runOnOperation() {
+  mlir::RewritePatternSet patterns(&getContext());
+  patterns.add<UnswitchLoop>(&getContext());
+  mlir::scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
+  mlir::scf::IfOp::getCanonicalizationPatterns(patterns, &getContext());
+  if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                                      std::move(patterns)))) {
+    signalPassFailure();
+  }
+}
+
+}  // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateUnswitchLoopsPass() {
+  return std::make_unique<UnswitchLoopsPass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc b/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc
new file mode 100644
index 0000000..dd5d443
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc
@@ -0,0 +1,358 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+#include <memory>
+#include <numeric>
+#include <optional>
+#include <utility>
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
+
+namespace xla {
+namespace gpu {
+
+#define GEN_PASS_DEF_VECTORIZELOADSANDSTORESPASS
+#include "xla/service/gpu/fusions/transforms/passes.h.inc"
+
+namespace {
+
+// Tries to find the stride of a symbol or dimension in an affine expression.
+// Returns std::nullopt if the stride could not be determined.
+//
+// Note: this function only attempts to handle the cases where the stride is
+// known to be 0 or 1.
+//
+// Example: the stride of `d0` in `(d0 + d1)` is 1.
+// Example: the stride of `d0` in `d0 * 2` is unknown (nullopt).
+std::optional<int> GetStride(mlir::AffineExpr expr,
+                             mlir::AffineExpr dim_or_sym) {
+  if (auto binop = mlir::dyn_cast_or_null<mlir::AffineBinaryOpExpr>(expr)) {
+    auto lhs_stride = GetStride(binop.getLHS(), dim_or_sym);
+    auto rhs_stride = GetStride(binop.getRHS(), dim_or_sym);
+
+    if (binop.getKind() == mlir::AffineExprKind::Add) {
+      if (lhs_stride && rhs_stride) {
+        return *lhs_stride + *rhs_stride;
+      }
+      return std::nullopt;
+    }
+    // Just return 0 if the expression doesn't occur on either side.
+    if (lhs_stride == 0 && rhs_stride == 0) {
+      return 0;
+    }
+    // Otherwise, we don't know the stride.
+    return std::nullopt;
+  }
+  return expr == dim_or_sym ? 1 : 0;
+}
+
+int64_t GetAlignmentOfRemainder(mlir::AffineExpr expr,
+                                mlir::AffineExpr dim_or_sym) {
+  if (auto binop = mlir::dyn_cast_or_null<mlir::AffineBinaryOpExpr>(expr)) {
+    auto lhs_align = GetAlignmentOfRemainder(binop.getLHS(), dim_or_sym);
+    auto rhs_align = GetAlignmentOfRemainder(binop.getRHS(), dim_or_sym);
+
+    std::optional<int64_t> rhs_cst = std::nullopt;
+    if (binop.getRHS().getKind() == mlir::AffineExprKind::Constant) {
+      rhs_cst = binop.getRHS().cast<mlir::AffineConstantExpr>().getValue();
+    }
+
+    switch (binop.getKind()) {
+      case mlir::AffineExprKind::Add:
+        if (binop.getLHS() == dim_or_sym) return rhs_align;
+        if (binop.getRHS() == dim_or_sym) return lhs_align;
+        return std::gcd(lhs_align, rhs_align);
+      case mlir::AffineExprKind::Mul:
+        return lhs_align * rhs_align;
+      case mlir::AffineExprKind::FloorDiv:
+      case mlir::AffineExprKind::CeilDiv:
+        return 1;
+      case mlir::AffineExprKind::Mod:
+        // (a * c) % (b * c) = (a % b) * c.
+        return std::gcd(lhs_align, rhs_align);
+      default:
+        llvm_unreachable("expr is none of the binary expressions");
+    }
+  }
+  if (auto cst = mlir::dyn_cast<mlir::AffineConstantExpr>(expr)) {
+    return cst.getValue();
+  }
+  return 1;
+}
+
+// Attempts to derive the vector type for the given loop. This requires that:
+// - the lower bound is 0,
+// - the step is 1,
+// - the upper bound is 2 or 4.
+// Returns a vector type whose length is the loop's upper bound and whose
+// element type is the tensor's element type, or nullptr if the loop is not
+// vectorizable.
+mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type,
+                               mlir::scf::ForOp loop) {
+  // TODO(jreiffers): Support layouts.
+  if (tensor_type.getEncoding()) {
+    return nullptr;
+  }
+  if (!mlir::VectorType::isValidElementType(tensor_type.getElementType())) {
+    return nullptr;
+  }
+  if (mlir::getConstantIntValue(loop.getStep()) != 1 ||
+      mlir::getConstantIntValue(loop.getLowerBound()) != 0) {
+    return nullptr;
+  }
+  std::optional<int> vector_size =
+      mlir::getConstantIntValue(loop.getUpperBound());
+  if (vector_size != 2 && vector_size != 4) {
+    return nullptr;  // Unsupported vector size.
+  }
+  if (tensor_type.getRank() > 1 &&
+      tensor_type.getShape().back() % *vector_size) {
+    return nullptr;  // Misaligned start indices.
+  }
+  return mlir::VectorType::get({*vector_size}, tensor_type.getElementType());
+}
+
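+// Rewrites `indices` to point at the first element accessed by the loop: the
+// index that depends on the induction variable is re-evaluated at iteration 0.
+// Returns std::nullopt if the access is not contiguous in the vectorized
+// dimension or its start is not aligned to the vector size.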
+std::optional<llvm::SmallVector<mlir::Value>> GetVectorBaseIndices(
+    mlir::ValueRange indices, mlir::scf::ForOp loop,
+    mlir::VectorType vector_type, mlir::ImplicitLocOpBuilder& b) {
+  if (indices.empty()) {
+    return std::nullopt;
+  }
+
+  // The major dimensions' indices must all be defined outside the loop.
+  for (int i = 0; i < indices.size() - 1; ++i) {
+    if (!indices[i].getParentRegion()->isProperAncestor(
+            &loop.getBodyRegion())) {
+      return std::nullopt;
+    }
+  }
+
+  mlir::Value induction_var = loop.getInductionVar();
+  if (indices.back() == induction_var) {
+    llvm::SmallVector<mlir::Value> ret = indices;
+    ret.back() = b.create<mlir::arith::ConstantIndexOp>(0);
+    return ret;
+  }
+
+  auto apply_indexing =
+      mlir::dyn_cast_or_null<ApplyIndexingOp>(indices.back().getDefiningOp());
+  if (!apply_indexing) {
+    return std::nullopt;
+  }
+
+  // We don't generate these, but they are allowed in theory.
+  if (apply_indexing->getNumResults() != 1) {
+    return std::nullopt;
+  }
+  mlir::AffineMap map = apply_indexing.getAffineMap();
+
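+  // Find the operand of the apply_indexing that is the loop's induction
+  // variable and build the corresponding dim/symbol expression for it.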
+  int induction_var_operand_index;
+  mlir::AffineExpr induction_var_expr = nullptr;
+  for (auto [index, operand] : llvm::enumerate(apply_indexing.getOperands())) {
+    if (operand == induction_var) {
+      if (induction_var_expr) {
+        // The induction variable should be used only once.
+        return std::nullopt;
+      }
+      induction_var_operand_index = index;
+      induction_var_expr = index < map.getNumDims()
+                               ? mlir::getAffineDimExpr(index, b.getContext())
+                               : mlir::getAffineSymbolExpr(
+                                     index - map.getNumDims(), b.getContext());
+    }
+  }
+  if (!induction_var_expr) {
+    return std::nullopt;
+  }
+
+  if (GetStride(map.getResult(0), induction_var_expr) != 1) {
+    // The indexing map is not contiguous in the vectorized dimension.
+    return std::nullopt;
+  }
+
+  if (GetAlignmentOfRemainder(map.getResult(0), induction_var_expr) %
+      vector_type.getNumElements()) {
+    return std::nullopt;
+  }
+
+  auto operands = llvm::to_vector(apply_indexing.getOperands());
+  operands[induction_var_operand_index] =
+      b.create<mlir::arith::ConstantIndexOp>(0);
+
+  llvm::SmallVector<mlir::Value> ret = indices;
+  ret.back() =
+      b.create<ApplyIndexingOp>(operands, apply_indexing.getIndexingMap())
+          ->getResult(0);
+  return ret;
+}
+
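+// A load can be vectorized only if its source tensor is defined outside the
+// loop, so no insertion inside the loop can alias the values being read.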
+bool IsConflictFree(mlir::tensor::ExtractOp op) {
+  return op.getTensor().getParentRegion()->isProperAncestor(
+      op->getParentRegion());
+}
+
+struct VectorizeLoad : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      mlir::tensor::ExtractOp op,
+      mlir::PatternRewriter& rewriter) const override {
+    auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
+    if (!loop) {
+      return rewriter.notifyMatchFailure(op, "no loop found");
+    }
+    if (!IsConflictFree(op)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "source may be written in the loop");
+    }
+
+    auto vector_type = GetVectorType(op.getTensor().getType(), loop);
+    if (!vector_type) {
+      return rewriter.notifyMatchFailure(op, "not a vectorizable loop");
+    }
+
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    b.setInsertionPoint(loop);
+    auto vector_indices =
+        GetVectorBaseIndices(op.getIndices(), loop, vector_type, b);
+    if (!vector_indices) {
+      return rewriter.notifyMatchFailure(
+          op, "the instruction does not access contiguous elements");
+    }
+
+    auto loaded_vector = b.create<mlir::vector::TransferReadOp>(
+        vector_type, op.getTensor(), *vector_indices,
+        llvm::ArrayRef<bool>{true});
+    rewriter.replaceOpWithNewOp<mlir::vector::ExtractOp>(
+        op, loaded_vector, loop.getInductionVar());
+    return mlir::success();
+  }
+};
+
+// Verifies that the insertions happening in the loop can all safely be
+// batched into a single write after the loop.
+bool IsConflictFree(mlir::tensor::InsertOp op) {
+  // The insertion's only use must be the yield.
+  if (!op->hasOneUse() || !mlir::isa<mlir::scf::YieldOp>(*op->user_begin())) {
+    return false;
+  }
+  // The destination must be one of the loop's block arguments, and the
+  // destination must be the argument's only use.
+  auto bbarg = mlir::dyn_cast<mlir::BlockArgument>(op.getDest());
+  return bbarg && bbarg.hasOneUse() &&
+         bbarg.getOwner()->getParentOp() == op->getParentOp();
+}
+
+struct VectorizeStore : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult matchAndRewrite(
+      mlir::tensor::InsertOp op,
+      mlir::PatternRewriter& rewriter) const override {
+    auto loop = mlir::dyn_cast_or_null<mlir::scf::ForOp>(op->getParentOp());
+    if (!loop) {
+      return rewriter.notifyMatchFailure(op, "no loop found");
+    }
+    if (!IsConflictFree(op)) {
+      return rewriter.notifyMatchFailure(op, "write may be read back by loop");
+    }
+    auto vector_type = GetVectorType(op.getDest().getType(), loop);
+    if (!vector_type) {
+      return rewriter.notifyMatchFailure(op, "loop is not vectorizable");
+    }
+
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    b.setInsertionPoint(loop);
+    auto vector_indices =
+        GetVectorBaseIndices(op.getIndices(), loop, vector_type, b);
+    if (!vector_indices) {
+      return rewriter.notifyMatchFailure(
+          op, "the instruction does not access contiguous elements");
+    }
+
+    auto init = b.create<mlir::arith::ConstantOp>(b.getZeroAttr(vector_type))
+                    .getResult();
+
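+    // For the new vector iter arg, yield the accumulator with the scalar that
+    // was being inserted into the tensor placed at the current induction
+    // variable's position.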
+    auto yield_fn = [&](mlir::OpBuilder& yield_b, mlir::Location yield_loc,
+                        llvm::ArrayRef<mlir::BlockArgument> bbarg) {
+      auto induction_var =
+          mlir::cast<mlir::scf::ForOp>(bbarg.front().getOwner()->getParentOp())
+              .getInductionVar();
+      auto insert_op = yield_b.create<mlir::vector::InsertOp>(
+          yield_loc, op.getScalar(), bbarg.front(), induction_var);
+      return llvm::SmallVector<mlir::Value>{insert_op.getResult()};
+    };
+    int result_index = op->use_begin()->getOperandNumber();
+    auto new_for = *loop.replaceWithAdditionalYields(
+        rewriter, init,
+        /*replaceInitOperandUsesInLoop=*/false, yield_fn);
+
+    b.setInsertionPointAfter(new_for);
+    rewriter.replaceOp(op, op.getDest());
+
+    auto filled_vector = new_for->getResults().back();
+    auto written = b.create<mlir::vector::TransferWriteOp>(
+        filled_vector, new_for.getInits()[result_index], *vector_indices,
+        llvm::ArrayRef<bool>{true});
+    new_for->getResult(result_index).replaceAllUsesWith(written.getResult());
+
+    return mlir::success();
+  }
+};
+
+class VectorizeLoadsAndStoresPass
+    : public impl::VectorizeLoadsAndStoresPassBase<
+          VectorizeLoadsAndStoresPass> {
+ public:
+  void runOnOperation() override {
+    mlir::RewritePatternSet patterns(&getContext());
+    patterns.add<VectorizeLoad, VectorizeStore>(&getContext());
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                                        std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateVectorizeLoadsAndStoresPass() {
+  return std::make_unique<VectorizeLoadsAndStoresPass>();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transpose.cc b/third_party/xla/xla/service/gpu/fusions/transpose.cc
deleted file mode 100644
index 611099d..0000000
--- a/third_party/xla/xla/service/gpu/fusions/transpose.cc
+++ /dev/null
@@ -1,366 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/transpose.h"
-
-#include <array>
-#include <cstdint>
-#include <optional>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/types/span.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Type.h"
-#include "llvm/IR/Value.h"
-#include "llvm/Support/AtomicOrdering.h"
-#include "mlir/IR/AffineMap.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/permutation_util.h"
-#include "xla/service/gpu/elemental_ir_emitter.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_analysis.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/gpu/target_util.h"
-#include "xla/service/llvm_ir/fused_ir_emitter.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/service/llvm_ir/llvm_util.h"
-#include "xla/service/llvm_ir/loop_emitter.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-Tiling ComputeTransposeTiling(const se::DeviceDescription& gpu_device_info,
-                              const TransposeDescription& tiled_transpose) {
-  constexpr int kNumRows = 4;
-  static_assert(WarpSize() % kNumRows == 0);
-
-  // 3D view over the output shape.
-  Vector3 transposed_dims = tiled_transpose.dimensions;
-  Vector3 permutation = tiled_transpose.permutation;
-
-  // Note: the supported permutations are their own inverses. Therefore we
-  // always use the permutation, even when we want the inverse.
-  CHECK((permutation == Vector3{0, 2, 1}) || (permutation == Vector3{2, 1, 0}));
-
-  absl::InlinedVector<int64_t, 4> input_dims{transposed_dims[permutation[0]],
-                                             transposed_dims[permutation[1]],
-                                             transposed_dims[permutation[2]]};
-
-  // We tile along the minor dimensions pre- and post-transpose.
-  absl::InlinedVector<int64_t, 4> tile_sizes{1, 1, 1};
-  tile_sizes[permutation[2]] = WarpSize() / kNumRows;
-  absl::InlinedVector<int64_t, 4> num_threads{1, 1, WarpSize()};
-  num_threads[permutation[2]] = kNumRows;
-
-  auto capability = gpu_device_info.gpu_compute_capability();
-  std::visit(
-      [&](const auto& capability) {
-        if constexpr (std::is_same_v<std::decay_t<decltype(capability)>,
-                                     stream_executor::RocmComputeCapability>) {
-          // kNumRows = 8 works well on MI300 with wavefront size 64.
-          if (capability.gfx9_mi300()) {
-            tile_sizes[permutation[2]] = gpu_device_info.threads_per_warp() / 8;
-            num_threads[permutation[2]] = 8;
-          }
-        }
-      },
-      capability);
-
-  return Tiling(input_dims, tile_sizes, num_threads);
-}
-
-void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder,
-                             IrEmitterContext& ir_emitter_context) {
-  auto* module = builder->GetInsertBlock()->getModule();
-  if (IsAMDGPU(module) &&
-      ir_emitter_context.rocm_compute_capability().fence_before_barrier()) {
-    builder->CreateFence(
-        llvm::AtomicOrdering::SequentiallyConsistent,
-        builder->getContext().getOrInsertSyncScopeID("workgroup"));
-  }
-}
-
-void EmitSyncThreads(llvm::IRBuilder<>* builder,
-                     IrEmitterContext& ir_emitter_context) {
-  MaybeEmitFenceForAMDGPU(builder, ir_emitter_context);
-  EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder);
-}
-
-llvm_ir::IrArray::Index PermuteIndex(const llvm_ir::IrArray::Index& index,
-                                     absl::Span<const int64_t> permutation) {
-  return llvm_ir::IrArray::Index{Permute(index.multidim(), permutation),
-                                 Permute(index.dims(), permutation),
-                                 index.GetType()};
-}
-
-}  // namespace
-
-TransposeFusion::TransposeFusion(const se::DeviceDescription& gpu_device_info,
-                                 const HloFusionAnalysis& analysis)
-    : analysis_(analysis),
-      tiling_(
-          ComputeTransposeTiling(gpu_device_info, analysis.tiled_transpose())) {
-  for (auto [root, hero] :
-       llvm::zip(analysis_.fusion_roots(), analysis_.fusion_heroes())) {
-    if (auto transpose = GetDescriptionForTiledTransposeEmitter(
-            root.instruction(), hero.instruction())) {
-      permutation_ = transpose->permutation;
-      break;
-    }
-  }
-}
-
-absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context,
-                                         const HloFusionInstruction& fusion,
-                                         const LaunchDimensions& launch_dims,
-                                         std::vector<llvm_ir::IrArray> inputs,
-                                         std::vector<llvm_ir::IrArray> outputs,
-                                         llvm::IRBuilder<>* builder) const {
-  const auto& hlo_roots = analysis_.fusion_roots();
-  GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder);
-  FusedIrEmitter fused_emitter(elemental_emitter);
-  for (auto [i, input] : llvm::enumerate(inputs)) {
-    HloInstruction* fused_operand = fusion.fused_parameter(i);
-    fused_emitter.BindGenerator(
-        *fused_operand, [input = input, builder,
-                         fused_operand](const llvm_ir::IrArray::Index& index) {
-          return input.EmitReadArrayElement(index, builder,
-                                            fused_operand->name());
-        });
-  }
-
-  absl::flat_hash_map<const HloInstruction*,
-                      std::vector<std::pair<int64_t, const HloInstruction*>>>
-      transposes_to_roots;
-  // Keep a list of deduplicated transpose heroes separate from
-  // transposes_to_roots to make the CodeGen deterministic.
-  std::vector<TransposeDescription> transposes;
-  transposes.reserve(hlo_roots.size());
-  std::vector<std::pair<int64_t, const HloInstruction*>> extra_outputs;
-
-  for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) {
-    const auto& hero = analysis_.fusion_hero(output_idx).instruction();
-    auto transpose_descr =
-        GetDescriptionForTiledTransposeEmitter(root.instruction(), hero);
-    if (transpose_descr.has_value()) {
-      auto iterator_inserted = transposes_to_roots.insert(std::make_pair(
-          &hero, std::vector<std::pair<int64_t, const HloInstruction*>>{
-                     {output_idx, &root.instruction()}}));
-      if (iterator_inserted.second) {
-        transposes.push_back(*transpose_descr);
-      } else {
-        iterator_inserted.first->second.push_back(
-            {output_idx, &root.instruction()});
-      }
-    } else {
-      extra_outputs.push_back({output_idx, &root.instruction()});
-    }
-  }
-
-  absl::flat_hash_map<const HloInstruction*, llvm_ir::SharedMemoryTile> tiles;
-  Vector3 permutation;
-  for (const auto& [tile_idx, tr] : llvm::enumerate(transposes)) {
-    permutation = tr.permutation;
-    auto tile_size = tiling_.GetBlockTileSize();
-    ++tile_size.back();  // Prevent bank conflicts.
-    auto* module = ir_emitter_context.llvm_module();
-    tiles[tr.instr] = llvm_ir::AllocateSharedMemoryTile(
-        module,
-        llvm_ir::PrimitiveTypeToIrType(tr.instr->shape().element_type(),
-                                       module),
-        tile_size, absl::StrCat("tr_tile_", tile_idx));
-  }
-
-  auto tile_generator = [&](const TilingThreadIdInfo& thread_id_info,
-                            const llvm_ir::IrArray::Index& tile_start_index,
-                            absl::Span<llvm::Value* const> tile_dimensions) {
-    // Copy input parameter values to shared memory buffers:
-    // tile[thread_id_y, thread_id_x] = input[index]
-    EmitTile(builder, tiling_, thread_id_info, tile_dimensions,
-             [&](absl::Span<llvm::Value* const> index_in_tile) {
-               auto index = tile_start_index.AddOffset(index_in_tile, builder);
-               for (const auto& tr : transposes) {
-                 auto input_gen =
-                     *fused_emitter.GetGenerator(*tr.instr->operand(0));
-                 auto input_index = index.SourceIndexOfBitcast(
-                     tr.instr->operand(0)->shape(), builder);
-                 llvm::Value* value = *input_gen(input_index);
-                 tiles[tr.instr].Store(value, index_in_tile, builder);
-               }
-
-               // Compute all extra output values before writing them. This
-               // avoids overwriting aliased input/output values before all
-               // reads occurred.
-               std::vector<std::tuple<llvm_ir::IrArray, llvm_ir::IrArray::Index,
-                                      llvm::Value*>>
-                   scheduled_writes;
-               for (const auto& [output_idx, root] : extra_outputs) {
-                 auto extra_output_index =
-                     index.SourceIndexOfBitcast(root->shape(), builder);
-                 auto output_gen = *fused_emitter.GetGenerator(*root);
-                 llvm::Value* output_value = *output_gen(extra_output_index);
-                 scheduled_writes.emplace_back(
-                     outputs[output_idx], extra_output_index, output_value);
-               }
-
-               for (const auto& [output, idx, value] : scheduled_writes) {
-                 output.EmitWriteArrayElement(idx, value, builder);
-               }
-             });
-
-    EmitSyncThreads(builder, ir_emitter_context);
-
-    auto output_tile_index = PermuteIndex(tile_start_index, permutation);
-    auto transposed_tile_dimensions = Permute(tile_dimensions, permutation);
-
-    EmitTile(
-        builder, tiling_, thread_id_info, transposed_tile_dimensions,
-        /*emit_elem_function=*/
-        [&](absl::Span<llvm::Value* const> index_in_tile) {
-          auto index = output_tile_index.AddOffset(index_in_tile, builder);
-          for (const auto& tr : transposes) {
-            llvm::Value* loaded = tiles[tr.instr].Load(
-                Permute(index_in_tile, permutation), builder);
-
-            FusedIrEmitter fused_emitter(elemental_emitter);
-            fused_emitter.BindGenerator(
-                *tr.instr,
-                [&](const llvm_ir::IrArray::Index&) { return loaded; });
-            for (int64_t i = 0;
-                 i < fusion.fused_instructions_computation()->num_parameters();
-                 ++i) {
-              llvm_ir::IrArray ir_array = inputs[i];
-              HloInstruction* fused_operand = fusion.fused_parameter(i);
-              fused_emitter.BindGenerator(
-                  *fused_operand, [=](const llvm_ir::IrArray::Index& index) {
-                    return ir_array.EmitReadArrayElement(index, builder,
-                                                         fused_operand->name());
-                  });
-            }
-
-            // Apply code generation for the code after the real hero.
-            // Compute all output values before writing them. This avoids
-            // overwriting aliased input/output values before all reads
-            // occurred.
-            std::vector<std::tuple<llvm_ir::IrArray, llvm_ir::IrArray::Index,
-                                   llvm::Value*>>
-                scheduled_writes;
-            for (const auto& [output_idx, root] :
-                 transposes_to_roots[tr.instr]) {
-              TF_ASSIGN_OR_RETURN(llvm_ir::ElementGenerator gen,
-                                  fused_emitter.GetGenerator(*root));
-
-              // Both for emission and writing it should be
-              // index-as-transformed by the computation.
-              auto untiled_index =
-                  index.SourceIndexOfBitcast(root->shape(), builder);
-              TF_ASSIGN_OR_RETURN(llvm::Value * generated, gen(untiled_index));
-              scheduled_writes.emplace_back(outputs[output_idx], untiled_index,
-                                            generated);
-            }
-            for (const auto& [output, idx, value] : scheduled_writes) {
-              output.EmitWriteArrayElement(idx, value, builder);
-            }
-          }
-          return absl::OkStatus();
-        });
-  };
-
-  llvm::Type* index_type =
-      GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
-  return EmitTilingKernel(builder, tiling_, index_type, tile_generator)
-      .status();
-}
-
-LaunchDimensions TransposeFusion::launch_dimensions() const {
-  return LaunchDimensions(tiling_.GetNumBlocks(),
-                          tiling_.GetNumThreadsPerBlock());
-}
-
-std::optional<IndexingMap> TransposeFusion::ComputeThreadIdToOutputIndexing(
-    int64_t root_index, mlir::MLIRContext* ctx) const {
-  const auto& hero = analysis_.fusion_hero(root_index);
-  if (hero.opcode() != HloOpcode::kTranspose) {
-    // The shape of non-transpose roots are bitcast compatible with the input
-    // shape of transpose heroes.
-    auto map = ComposeIndexingMaps(
-        GetIndexingMapForTiling(tiling_, ctx),
-        GetBitcastMap(tiling_.GetXlaShape(),
-                      analysis_.fusion_root(root_index).shape(), ctx));
-    map.Simplify();
-    return map;
-  }
-
-  // The block offsets are permuted, but the thread offsets remain the same.
-  auto block_offset = GetBlockOffsetsForTiling(tiling_, ctx)
-                          .getSubMap(std::vector<unsigned>{permutation_.begin(),
-                                                           permutation_.end()});
-  auto thread_offset = GetThreadOffsetsForTiling(tiling_, ctx);
-  auto permuted_tiled_shape =
-      ShapeUtil::MakeShape(U8, Permute(tiling_.GetShape(), permutation_));
-
-  auto map = ComposeIndexingMaps(
-      GetIndexingMapForTiling(
-          block_offset, thread_offset, tiling_.GetNumThreadsPerBlock(),
-          tiling_.GetNumBlocks(), tiling_.GetThreadTileSize(),
-          permuted_tiled_shape.dimensions()),
-      GetBitcastMap(permuted_tiled_shape, hero.shape(), ctx));
-  map.Simplify();
-  return map;
-}
-
-std::optional<IndexingMap> TransposeFusion::ComputeThreadIdToInputIndexing(
-    int64_t root_index, int64_t hero_operand_index,
-    mlir::MLIRContext* ctx) const {
-  const auto& hero = analysis_.fusion_hero(root_index).instruction();
-  if (hero.opcode() != HloOpcode::kTranspose) {
-    auto map = ComposeIndexingMaps(
-        *ComputeThreadIdToOutputIndexing(root_index, ctx),
-        *ComputeOutputToInputIndexing(
-             &analysis_.fusion_root(root_index).instruction(), 0, ctx)
-             .indexing_maps[hero_operand_index]
-             .begin());
-    map.Simplify();
-    return map;
-  }
-
-  auto map = ComposeIndexingMaps(
-      GetIndexingMapForTiling(tiling_, ctx),
-      GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), ctx));
-  map.Simplify();
-  return map;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transpose.h b/third_party/xla/xla/service/gpu/fusions/transpose.h
deleted file mode 100644
index 3f369a4..0000000
--- a/third_party/xla/xla/service/gpu/fusions/transpose.h
+++ /dev/null
@@ -1,91 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_
-#define XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "llvm/IR/IRBuilder.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/fusion_emitter.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/ir_emitter_context.h"
-#include "xla/service/gpu/launch_dimensions.h"
-#include "xla/service/gpu/model/indexing_map.h"
-#include "xla/service/llvm_ir/ir_array.h"
-#include "xla/util.h"
-
-namespace xla {
-namespace gpu {
-
-// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
-// algorithm to improve the memory access patterns for the input parameters
-// with a shape that is a 0-2-1 transpose of the output tensor shape. The
-// caller is responsible for making sure that it is safe to apply the shared
-// memory transpose on the input parameters.
-//
-// For the purpose of tiling, the output tensors have a logical shape of three
-// components 0-2-1 while the relevant input parameters have a logical shape
-// of three components 0-1-2 in the order major to minor. The x- and y-
-// dimensions of the tensors are tiled in square tiles with an edge length
-// `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
-// transposes one tile: each thread copies kTileSize/kNumRows elements from
-// the input to a shared memory tile, then the otherwise "regular HLO kernel"
-// reads from the shared memory instead of the original input.
-//
-// This is similar to the following CUDA algorithm in TensorFlow:
-// https://goo.gl/MStRV6.
-//
-// `kTileSize` should usually be same as warp size. We currently choose 32 for
-// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
-//
-// TODO(b/33320379): Here each block transposes 1 tile. It may be more
-// efficient to launch fewer blocks so each transposes many tiles.
-class TransposeFusion : public KernelFusionEmitterBase {
- public:
-  explicit TransposeFusion(const se::DeviceDescription& gpu_device_info,
-                           const HloFusionAnalysis& analysis);
-  LaunchDimensions launch_dimensions() const override;
-
-  std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
-      int64_t root_index, mlir::MLIRContext* ctx) const override;
-
-  std::optional<IndexingMap> ComputeThreadIdToInputIndexing(
-      int64_t root_index, int64_t hero_operand_index,
-      mlir::MLIRContext* ctx) const override;
-
- protected:
-  absl::Status EmitKernel(IrEmitterContext& ir_emitter_context,
-                          const HloFusionInstruction& fusion,
-                          const LaunchDimensions& launch_dims,
-                          std::vector<llvm_ir::IrArray> inputs,
-                          std::vector<llvm_ir::IrArray> outputs,
-                          llvm::IRBuilder<>* builder) const override;
-
- private:
-  const HloFusionAnalysis& analysis_;
-  Tiling tiling_;
-  Vector3 permutation_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc
index 8053a89..c6ce433 100644
--- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc
@@ -24,6 +24,7 @@
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/log/check.h"
+#include "absl/status/status.h"
 #include "absl/types/span.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
@@ -41,14 +42,12 @@
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/mlir/utils/type_util.h"
 #include "xla/permutation_util.h"
 #include "xla/primitive_util.h"
 #include "xla/service/gpu/fusions/fusion_emitter.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
 #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/fusions/mlir/type_util.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/ir_emission_utils.h"
@@ -72,8 +71,6 @@
 using mlir::ValueRange;
 using mlir::func::FuncOp;
 using mlir::func::ReturnOp;
-using mlir::tensor::ExtractOp;
-using mlir::tensor::InsertOp;
 using mlir_converter::ApplyIndexing;
 
 constexpr int kNumRows = 4;
@@ -99,6 +96,11 @@
       transposes_to_tile.insert(&hero.instruction());
       shmem_transpose_roots_.push_back(&root.instruction());
       int size = primitive_util::ByteWidth(hero.shape().element_type());
+      // If the last dimension stays the same, we need to make it part of the
+      // shared memory tile.
+      if (MostMinorDimensionUnchanged()) {
+        size *= input_shape_.back();
+      }
       max_element_bytes = std::max(max_element_bytes, size);
       shmem_usage += kBaseBlockSize * (kBaseBlockSize + 1) * size;
       shmem_transpose_root_indices_.push_back(index);
@@ -113,8 +115,12 @@
   auto compute_block_sizes = [this](int vector_size) {
     vector_size_ = vector_size;
     block_size_ = kBaseBlockSize * vector_size_;
-    block_sizes_ = {1, 1, block_size_};
-    block_sizes_[permutation_[2]] = block_size_;
+    if (MostMinorDimensionUnchanged()) {
+      block_sizes_ = {block_size_, block_size_, input_shape_.back()};
+    } else {
+      block_sizes_ = {1, 1, block_size_};
+      block_sizes_[permutation_[2]] = block_size_;
+    }
     block_counts_ = {CeilOfRatio(input_shape_[0], block_sizes_[0]),
                      CeilOfRatio(input_shape_[1], block_sizes_[1]),
                      CeilOfRatio(input_shape_[2], block_sizes_[2])};
@@ -137,6 +143,10 @@
         shmem_usage * elems_per_thread <= device.shared_memory_per_block();
     bool aligned_dims = (input_shape_[2] % vec_size == 0) &&
                         (input_shape_[permutation_[2]] % vec_size == 0);
+    if (MostMinorDimensionUnchanged()) {
+      aligned_dims =
+          input_shape_[0] % vec_size == 0 && input_shape_[1] % vec_size == 0;
+    }
     if (enough_work && enough_shmem && aligned_dims) {
       compute_block_sizes(vec_size);
       break;
@@ -187,7 +197,15 @@
 IndexingMap MlirTransposeFusion::GetSharedMemoryIndexing(
     bool read, mlir::MLIRContext* ctx) const {
   auto thread_offsets =
-      Permute(GetThreadOffsets(ctx), read ? Vector3{0, 1, 2} : permutation_);
+      Permute(GetThreadOffsets(ctx),
+              read ? absl::InlinedVector<int64_t, 3>{0, 1, 2} : permutation_);
+  if (MostMinorDimensionUnchanged()) {
+    return {mlir::AffineMap::get(6, 3, thread_offsets, ctx),
+            DimVarsFromTensorSizes({kNumThreadsPerBlock, 1, 1, 1, 1, 1}),
+            RangeVarsFromTensorSizes(
+                {block_size_ / kNumRows, vector_size_, input_shape_[2]}),
+            {}};
+  }
   return {mlir::AffineMap::get(6, 2, thread_offsets, ctx),
           DimVarsFromTensorSizes({kNumThreadsPerBlock, 1, 1, 1, 1, 1}),
           RangeVarsFromTensorSizes({block_size_ / kNumRows, vector_size_}),
@@ -203,7 +221,11 @@
   MLIRContext* ctx = builder.getContext();
   auto shmem_tensor_size = block_sizes_;
   // Avoid bank conflicts.
-  ++shmem_tensor_size.back();
+  if (MostMinorDimensionUnchanged()) {
+    ++shmem_tensor_size[1];
+  } else {
+    ++shmem_tensor_size.back();
+  }
 
   // Allocate shared memory.
   SmallVector<Value> inits;
@@ -237,8 +259,8 @@
           root_computation, transpose,
           /*operand_index=*/0, input_indices(transpose->operand(0)),
           call_target_provider, entry_function, builder)[0];
-      result_tensors.push_back(
-          builder.create<InsertOp>(result_scalar, output, shmem_indices));
+      result_tensors.push_back(builder.create<mlir::tensor::InsertOp>(
+          result_scalar, output, shmem_indices));
     }
 
     // Produce all side outputs and then write them.
@@ -258,7 +280,7 @@
          llvm::zip(side_outputs, side_output_indices,
                    output_tensors.take_back(side_output_roots_.size()))) {
       result_tensors.push_back(
-          builder.create<InsertOp>(value, output, indices));
+          builder.create<mlir::tensor::InsertOp>(value, output, indices));
     }
 
     return result_tensors;
@@ -306,7 +328,7 @@
         for (auto [transpose, shmem] :
              llvm::zip(shmem_transposes_, written.shmem_tensors)) {
           transpose_values[transpose].push_back(
-              builder.create<ExtractOp>(shmem, shmem_indices));
+              builder.create<mlir::tensor::ExtractOp>(shmem, shmem_indices));
         }
         llvm::SmallVector<Value> epilogue_indices = dim_values;
         absl::c_copy(symbol_values, std::back_inserter(epilogue_indices));
@@ -320,7 +342,7 @@
                        shmem_transpose_root_indices_)) {
           llvm::SmallVector<Value> indices =
               ApplyIndexing(indexing, dim_values, symbol_values, builder);
-          results[root_index] = builder.create<InsertOp>(
+          results[root_index] = builder.create<mlir::tensor::InsertOp>(
               result_scalars.at(root).front(), results[root_index], indices);
         }
         return results;
@@ -372,6 +394,10 @@
   auto vector = mlir::getAffineSymbolExpr(1, ctx);
   int loop_stride = block_size_ * kNumRows;
   auto linear_index = loop * loop_stride + thread * vector_size_ + vector;
+  if (MostMinorDimensionUnchanged()) {
+    auto minor_dim = mlir::getAffineSymbolExpr(2, ctx);
+    linear_index = linear_index * input_shape_[2] + minor_dim;
+  }
   return DelinearizeInBoundsIndex(linear_index, block_sizes_);
 }
 
@@ -380,19 +406,25 @@
                                              mlir::MLIRContext* ctx) const {
   auto raw_id = mlir::getAffineDimExpr(
       KernelFusionInterface::kIndexingMapBlockIdxDims[0], ctx);
-  auto block_ids = Permute(DelinearizeInBoundsIndex(raw_id, block_counts_),
-                           input ? Vector3{0, 1, 2} : permutation_);
+  auto block_ids =
+      Permute(DelinearizeInBoundsIndex(raw_id, block_counts_),
+              input ? absl::InlinedVector<int64_t, 3>{0, 1, 2} : permutation_);
   auto thread_offsets = GetThreadOffsets(ctx);
   llvm::SmallVector<AffineExpr, 3> offsets;
   for (auto [block_id, block_size, thread] :
        llvm::zip(block_ids, block_sizes_, thread_offsets)) {
     offsets.push_back(block_id * block_size + thread);
   }
+  auto range_var_sizes =
+      std::vector<int64_t>{block_size_ / kNumRows, vector_size_};
+  if (MostMinorDimensionUnchanged()) {
+    range_var_sizes.push_back(input_shape_[2]);
+  }
   IndexingMap result{
-      mlir::AffineMap::get(6, 2, offsets, ctx),
+      mlir::AffineMap::get(6, range_var_sizes.size(), offsets, ctx),
       DimVarsFromTensorSizes(
           {kNumThreadsPerBlock, 1, 1, Product(block_counts_), 1, 1}),
-      RangeVarsFromTensorSizes({block_size_ / kNumRows, vector_size_}),
+      RangeVarsFromTensorSizes(range_var_sizes),
       {}};
   auto normalized_shape =
       input ? ShapeUtil::MakeShape(shape.element_type(), input_shape_)
@@ -407,5 +439,9 @@
   return result;
 }
 
+bool MlirTransposeFusion::MostMinorDimensionUnchanged() const {
+  return permutation_.back() == permutation_.size() - 1;
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h
index 07d1e99..538ad2f 100644
--- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h
+++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h
@@ -19,6 +19,7 @@
 #include <optional>
 #include <vector>
 
+#include "absl/container/inlined_vector.h"
 #include "absl/status/status.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -35,7 +36,6 @@
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/launch_dimensions.h"
 #include "xla/service/gpu/model/indexing_map.h"
-#include "xla/util.h"
 
 namespace xla {
 namespace gpu {
@@ -98,9 +98,10 @@
   IndexingMap GetSharedMemoryIndexing(bool read, mlir::MLIRContext* ctx) const;
   llvm::SmallVector<mlir::AffineExpr, 4> GetThreadOffsets(
       mlir::MLIRContext* ctx) const;
+  bool MostMinorDimensionUnchanged() const;
 
   TransposeDescription transpose_;
-  Vector3 permutation_;
+  absl::InlinedVector<int64_t, 3> permutation_;
   std::vector<int64_t> input_shape_;
   std::vector<int64_t> block_sizes_;  // In input elements.
   std::vector<int64_t> block_counts_;
diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc
index eb71bb7..eec531d 100644
--- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc
@@ -44,7 +44,7 @@
   )"));
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
 
   MlirTransposeFusion fusion(analysis);
   EXPECT_THAT(
@@ -101,7 +101,7 @@
     })"));
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
   MlirTransposeFusion fusion(analysis);
 
   EXPECT_THAT(
@@ -158,7 +158,7 @@
   )"));
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
 
   MlirTransposeFusion fusion(analysis);
   EXPECT_THAT(
@@ -212,7 +212,7 @@
     })"));
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
   MlirTransposeFusion fusion(analysis);
 
   EXPECT_THAT(
@@ -295,6 +295,55 @@
   EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
 }
 
+TEST_F(MlirTransposeFusionTest, FusedTranspose102) {
+  auto kHloString = R"(
+    HloModule Transpose
+
+    %fused_computation {
+      %p0 = s8[160,170,3] parameter(0)
+      ROOT %transpose = s8[170,160,3] transpose(%p0), dimensions={1,0,2}
+    }
+    ENTRY main {
+      %param = s8[160,170,3] parameter(0)
+      ROOT %fusion = s8[170,160,3] fusion(%param), kind=kInput,
+        calls=%fused_computation
+    }
+  )";
+  TF_EXPECT_OK(EmitAndCheckIR(kHloString, R"(
+    // CHECK-LABEL: func.func @fused_computation(
+    // CHECK-SAME:   }, %[[OUT:.*]]: tensor<170x160x3xi8>
+    //
+    // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+    // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
+    // CHECK-DAG:  %[[C3:.*]] = arith.constant 3 : index
+    // CHECK-DAG:  %[[C8:.*]] = arith.constant 8 : index
+
+    // CHECK:      %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<32x33x3xi8>
+    // CHECK:      %[[SHMEM_WITH_VALS:.*]] = scf.for
+    // CHECK-SAME:     %[[C0]] to %[[C8]] step %[[C1]]
+    // CHECK-SAME:     iter_args(%[[SHMEM_:.*]] = %[[SHMEM]])
+    // CHECK:      %[[SHMEM_WITH_VALS2:.*]] = scf.for
+    // CHECK-SAME:     %[[C0]] to %[[C3]] step %[[C1]]
+    // CHECK-SAME:     iter_args(%[[SHMEM2_:.*]] = %[[SHMEM_]])
+    // CHECK:        %[[P0:.*]] = xla_gpu.pure_call @fused_computation_p0
+    // CHECK:        tensor.insert %[[P0]] into %[[SHMEM2_]]
+
+    // CHECK:      %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]]
+
+    // CHECK:      scf.for
+    // CHECK-SAME:    %[[C0]] to %[[C8]] step %[[C1]]
+    // CHECK-SAME:    iter_args(%[[OUT_:.*]] = %[[OUT]])
+    // CHECK:      scf.for
+    // CHECK-SAME:    %[[C0]] to %[[C3]] step %[[C1]]
+    // CHECK-SAME:    iter_args(%[[OUT2_:.*]] = %[[OUT_]])
+    // CHECK:       %[[EXTRACTED:.*]] = tensor.extract %[[SYNC]]
+    // CHECK:       %[[RES:.*]] = xla_gpu.pure_call @fused_computation__epilogue__transpose
+    // CHECK:       tensor.insert %[[RES]] into %[[OUT2_]]
+  )"));
+
+  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
+}
+
 TEST_F(MlirTransposeFusionTest, FusedTranspose210) {
   auto kHloString = R"(
     HloModule Transpose
@@ -578,7 +627,7 @@
                     .value();
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
 
   MlirTransposeFusion fusion(analysis);
   mlir::MLIRContext mlir_context;
@@ -608,7 +657,7 @@
                     .value();
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info_);
 
   MlirTransposeFusion fusion(analysis);
   mlir::MLIRContext mlir_context;
diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_test.cc
deleted file mode 100644
index f942469..0000000
--- a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc
+++ /dev/null
@@ -1,346 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/fusions/transpose.h"
-
-#include <memory>
-#include <optional>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/status/statusor.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/service/gpu/fusions/fusions.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/indexing_test_utils.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class TransposeTest : public HloTestBase {
- protected:
-  DebugOptions GetDebugOptionsForTest() override {
-    auto opts = HloTestBase::GetDebugOptionsForTest();
-    opts.set_xla_gpu_mlir_emitter_level(0);
-    return opts;
-  }
-  stream_executor::DeviceDescription device_info_ =
-      TestGpuDeviceInfo::RTXA6000DeviceInfo();
-};
-
-absl::StatusOr<std::unique_ptr<TransposeFusion>> GetTransposeFusion(
-    const HloFusionAnalysis& analysis) {
-  auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
-  auto fusion = dynamic_cast<TransposeFusion*>(emitter.get());
-  TF_RET_CHECK(fusion != nullptr);
-
-  emitter.release();
-  return std::unique_ptr<TransposeFusion>{fusion};
-}
-
-TEST_F(TransposeTest, ThreadIndexing021) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    fusion {
-      %input = f32[100,32,64] parameter(0)
-      ROOT transpose = f32[100,64,32] transpose(%input), dimensions={0,2,1}
-    }
-
-    ENTRY entry {
-      %input = f32[100,32,64] parameter(0)
-      ROOT %fusion = f32[100,64,32] fusion(%input), kind=kInput, calls=fusion
-    })")
-                    .value();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
-
-  TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
-  mlir::MLIRContext mlir_context;
-
-  EXPECT_THAT(
-      fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
-      MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
-          d3 floordiv 2,
-          d0 floordiv 32 + s1 * 4,
-          (d3 mod 2) * 32 + d0 mod 32
-        )
-        domain:
-        d0 in [0, 127]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 199]
-        d4 in [0, 0]
-        d5 in [0, 0]
-
-        s0 in [0, 0]
-        s1 in [0, 7]
-        s2 in [0, 0]
-      )"));
-  EXPECT_THAT(
-      fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
-      MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
-          d3 floordiv 2,
-          (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32,
-          d0 mod 32
-        )
-        domain:
-        d0 in [0, 127]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 199]
-        d4 in [0, 0]
-        d5 in [0, 0]
-
-        s0 in [0, 0]
-        s1 in [0, 7]
-        s2 in [0, 0]
-      )"));
-}
-
-TEST_F(TransposeTest, ThreadIndexing201) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    fusion {
-      %input = f32[100,64,32] parameter(0)
-      ROOT transpose = f32[32,100,64] transpose(%input), dimensions={2,0,1}
-    }
-
-    ENTRY entry {
-      %input = f32[100,64,32] parameter(0)
-      ROOT %fusion = f32[32,100,64] fusion(%input), kind=kInput, calls=fusion
-    })")
-                    .value();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
-
-  TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
-  mlir::MLIRContext mlir_context;
-  EXPECT_THAT(
-      fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
-      MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
-          d3 floordiv 2,
-          (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32,
-          d0 mod 32
-        )
-        domain:
-        d0 in [0, 127]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 199]
-        d4 in [0, 0]
-        d5 in [0, 0]
-
-        s0 in [0, 0]
-        s1 in [0, 7]
-        s2 in [0, 0]
-      )"));
-  EXPECT_THAT(
-      fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
-      MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
-          d0 floordiv 32 + s1 * 4,
-          d3 floordiv 2,
-          (d3 mod 2) * 32 + d0 mod 32
-        )
-        domain:
-        d0 in [0, 127]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 199]
-        d4 in [0, 0]
-        d5 in [0, 0]
-
-        s0 in [0, 0]
-        s1 in [0, 7]
-        s2 in [0, 0]
-      )"));
-}
-
-TEST_F(TransposeTest, ThreadIndexingPartialBlock) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule m
-
-    fused_computation {
-      %p0 = f64[24,2,6,4] parameter(0)
-      ROOT %t = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0}
-    }
-
-    ENTRY main {
-      %p0 = f64[24,2,6,4] parameter(0)
-      ROOT %fusion = f64[6,4,2,24] fusion(%p0), kind=kInput,
-        calls=%fused_computation
-    }
-  )")
-                    .value();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
-
-  TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
-  mlir::MLIRContext mlir_context;
-  EXPECT_THAT(
-      fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
-      MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
-          d0 floordiv 32 + s0 * 4,
-          d3,
-          (d0 floordiv 4) mod 8,
-          d0 mod 4
-        )
-        domain:
-        d0 in [0, 127]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 1]
-        d4 in [0, 0]
-        d5 in [0, 0]
-        s0 in [0, 5]
-        s1 in [0, 0]
-        s2 in [0, 0]
-        d0 mod 32 in [0, 23]
-      )"));
-  EXPECT_THAT(
-      fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
-      MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
-          s0,
-          d0 floordiv 32,
-          d3,
-          d0 mod 32
-        )
-        domain:
-        d0 in [0, 127]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 1]
-        d4 in [0, 0]
-        d5 in [0, 0]
-        s0 in [0, 5]
-        s1 in [0, 0]
-        s2 in [0, 0]
-        d0 mod 32 in [0, 23]
-      )"));
-}
-
-TEST_F(TransposeTest, SameInputIndexingForRealHeroAndSideOutput) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    fusion {
-      %input = f32[100,32,64] parameter(0)
-      %transpose = f32[100,64,32] transpose(%input), dimensions={0,2,1}
-      %bitcast = f32[100,2048] bitcast(%input)
-      ROOT %tuple = (f32[100,64,32], f32[100,2048]) tuple(%transpose, %bitcast)
-    }
-
-    ENTRY entry {
-      %input = f32[100,32,64] parameter(0)
-      ROOT %fusion = (f32[100,64,32], f32[100,2048]) fusion(%input), kind=kInput, calls=fusion
-    })")
-                    .value();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
-
-  TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
-  mlir::MLIRContext mlir_context;
-
-  EXPECT_THAT(
-      fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
-      fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString());
-}
-
-TEST_F(TransposeTest, ThreadIndexingSideOutput) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    fusion {
-      %input0 = f32[100,32,64] parameter(0)
-      %input1 = f32[100,32] parameter(1)
-      %transpose = f32[100,64,32] transpose(%input0), dimensions={0,2,1}
-      %broadcast = f32[100,32,64] broadcast(%input1), dimensions={0,1}
-      ROOT %tuple = (f32[100,64,32], f32[100,32,64]) tuple(%transpose, %broadcast)
-    }
-
-    ENTRY entry {
-      %input0 = f32[100,32,64] parameter(0)
-      %input1 = f32[100,32] parameter(1)
-      ROOT %fusion = (f32[100,64,32], f32[100,32,64]) fusion(%input0, %input1), kind=kInput, calls=fusion
-    })")
-                    .value();
-
-  auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info_);
-
-  TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis));
-  mlir::MLIRContext mlir_context;
-  // Check if side output `%broadcast` get the correct input indexing, which
-  // should corresponds to `%input1` with shape [100,32].
-  EXPECT_THAT(
-      fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString(),
-      MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
-          d3 floordiv 2,
-          d0 floordiv 32 + s1 * 4
-        )
-        domain:
-        d0 in [0, 127]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 199]
-        d4 in [0, 0]
-        d5 in [0, 0]
-
-        s0 in [0, 0]
-        s1 in [0, 7]
-        s2 in [0, 0]
-      )"));
-  EXPECT_THAT(
-      fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(),
-      MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
-          d3 floordiv 2,
-          d0 floordiv 32 + s1 * 4,
-          (d3 mod 2) * 32 + d0 mod 32
-        )
-        domain:
-        d0 in [0, 127]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 199]
-        d4 in [0, 0]
-        d5 in [0, 0]
-
-        s0 in [0, 0]
-        s1 in [0, 7]
-        s2 in [0, 0]
-      )"));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD
index 8c9eadf..3bedbcf 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD
+++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD
@@ -1,3 +1,4 @@
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
 load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")
 load("//xla:xla.bzl", "xla_cc_test")
@@ -32,8 +33,7 @@
     ]),
     hdrs = ["triton_fusion_emitter.h"],
     deps = [
-        ":prevent_mmav3_loop_unrolling",
-        ":sparse_extensions",
+        ":passes",
         "//xla:autotuning_proto_cc",
         "//xla:comparison_util",
         "//xla:debug_options_flags",
@@ -58,9 +58,9 @@
         "//xla/service/gpu:target_util",
         "//xla/service/gpu:triton_fusion_analysis",
         "//xla/service/gpu:triton_tiling_propagation",
+        "//xla/service/gpu/fusions/ir:xla_gpu",
         "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir",
-        "//xla/service/gpu/fusions/mlir:passes",
-        "//xla/service/gpu/fusions/mlir/ir:xla_gpu",
+        "//xla/service/gpu/fusions/transforms:passes",
         "//xla/service/gpu/llvm_gpu_backend",
         "//xla/service/gpu/model:affine_map_printer",
         "//xla/service/gpu/model:indexing_analysis",
@@ -135,11 +135,33 @@
     ]),
 )
 
+gentbl_cc_library(
+    name = "passes_inc_gen",
+    tbl_outs = [
+        (
+            [
+                "-gen-pass-decls",
+                "-name=TritonFusionTransforms",
+            ],
+            "passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "passes.td",
+    visibility = ["//visibility:private"],
+    deps = ["@llvm-project//mlir:PassBaseTdFiles"],
+)
+
 cc_library(
-    name = "sparse_extensions",
-    srcs = ["sparse_extensions.cc"],
-    hdrs = ["sparse_extensions.h"],
+    name = "passes",
+    srcs = [
+        "passes.cc",
+        "prevent_mmav3_loop_unrolling.cc",
+        "sparse_extensions.cc",
+    ],
+    hdrs = ["passes.h"],
     deps = [
+        ":passes_inc_gen",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:GPUCommonTransforms",
@@ -151,6 +173,7 @@
         "@llvm-project//mlir:NVVMDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Rewrite",
+        "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
         "@llvm-project//mlir:Transforms",
@@ -165,20 +188,6 @@
     ],
 )
 
-cc_library(
-    name = "prevent_mmav3_loop_unrolling",
-    srcs = ["prevent_mmav3_loop_unrolling.cc"],
-    hdrs = ["prevent_mmav3_loop_unrolling.h"],
-    deps = [
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:SCFDialect",
-        "@llvm-project//mlir:Support",
-        "@triton//:TritonDialects",
-    ],
-)
-
 xla_test(
     name = "triton_fusion_emitter_device_legacy_test",
     srcs = if_gpu_is_configured(["triton_fusion_emitter_device_legacy_test.cc"]),
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc
index 471a91c..46a569d 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc
@@ -24,8 +24,7 @@
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
-#include "xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h"
-#include "xla/service/gpu/fusions/triton/sparse_extensions.h"
+#include "xla/service/gpu/fusions/triton/passes.h"
 #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "xla/service/gpu/model/tiled_hlo_computation.h"
 #include "xla/service/hlo_module_config.h"
@@ -65,9 +64,9 @@
   pm.addPass(mt::createConvertTritonToTritonGPUPass(
       absl::StrFormat("cuda:%u", ccAsInt), block_level_parameters.num_warps,
       threadsPerWarp, block_level_parameters.num_ctas));
-  pm.addPass(CreateAddSparseDotEncodingPass(block_level_parameters.num_warps,
-                                            threadsPerWarp,
-                                            block_level_parameters.num_ctas));
+  pm.addPass(CreateSparseAddEncodingPass(block_level_parameters.num_warps,
+                                         threadsPerWarp,
+                                         block_level_parameters.num_ctas));
   pm.addPass(mt::gpu::createTritonGPUCoalesce());
   if (ccCuda.IsAtLeastAmpere()) {
     pm.addPass(mt::gpu::createTritonGPUF32DotTC());
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
index e31e29b..2a95ea8 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
@@ -22,7 +22,6 @@
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
-#include "xla/service/gpu/fusions/triton/sparse_extensions.h"
 #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "xla/service/gpu/matmul_utils.h"
 #include "xla/service/gpu/model/tiled_hlo_computation.h"
@@ -57,8 +56,6 @@
     mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
     const BlockLevelParameters& block_level_parameters,
     mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
-  // TODO(ROCm): Check whether value different than 0 can be used.
-  const int ccAsInt = 0;
   // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64.
   const int threadsPerWarp = 32;
   auto ccRocm = std::get<se::RocmComputeCapability>(cc);
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/passes.cc b/third_party/xla/xla/service/gpu/fusions/triton/passes.cc
new file mode 100644
index 0000000..fefb0f7
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/triton/passes.cc
@@ -0,0 +1,38 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/fusions/triton/passes.h"
+
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Visitors.h"
+
+namespace xla::gpu {
+namespace {
+
+using ::mlir::WalkResult;
+
+}  // namespace
+
+bool ContainsOp(mlir::Operation* op,
+                llvm::function_ref<bool(mlir::Operation*)> fn) {
+  return op
+      ->walk([&](mlir::Operation* nested_op) {
+        return fn(nested_op) ? WalkResult::interrupt() : WalkResult::advance();
+      })
+      .wasInterrupted();
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/passes.h b/third_party/xla/xla/service/gpu/fusions/triton/passes.h
new file mode 100644
index 0000000..39066ba
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/triton/passes.h
@@ -0,0 +1,50 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_H_
+#define XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+namespace xla::gpu {
+
+#define GEN_PASS_DECL
+#include "xla/service/gpu/fusions/triton/passes.h.inc"
+
+std::unique_ptr<mlir::Pass> CreateSparseAddEncodingPass(
+    int32_t num_warps = 4, int32_t threads_per_warp = 32, int32_t num_ctas = 1);
+std::unique_ptr<mlir::Pass> CreateSparseBlockedToMMAPass();
+std::unique_ptr<mlir::Pass> CreateSparseRemoveLayoutConversionPass();
+std::unique_ptr<mlir::Pass> CreateSparseLocalLoadToLLVMPass();
+std::unique_ptr<mlir::Pass> CreateSparseDotOpToLLVMPass();
+std::unique_ptr<mlir::Pass> CreateSparseWGMMAOpToLLVMPass();
+std::unique_ptr<mlir::Pass> CreatePreventMmaV3LoopUnrollingPass();
+
+// Returns true if `op` contains an operation in its regions that satisfies
+// `fn`.
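+//
+// Illustrative usage (the names below are hypothetical):
+//
+//   bool has_loops = ContainsOp(module_op, [](mlir::Operation* nested) {
+//     return mlir::isa<mlir::scf::ForOp>(nested);
+//   });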
+bool ContainsOp(mlir::Operation* op,
+                llvm::function_ref<bool(mlir::Operation*)> fn);
+
+#define GEN_PASS_REGISTRATION
+#include "xla/service/gpu/fusions/triton/passes.h.inc"
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/passes.td b/third_party/xla/xla/service/gpu/fusions/triton/passes.td
new file mode 100644
index 0000000..b1366d0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/fusions/triton/passes.td
@@ -0,0 +1,94 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_TD_
+#define XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_TD_
+
+include "mlir/Pass/PassBase.td"
+
+def SparseAddEncodingPass : Pass<"sparse-add-encoding", "mlir::ModuleOp"> {
+  let summary = "Adds sparse dot encoding.";
+  let options = [
+    Option<"num_warps_", "num-warps", "int32_t", /*default=*/"4",
+           "Number of warps">,
+    Option<"threads_per_warp_", "threads-per-warp", "int32_t", /*default=*/"32",
+           "Number of threads per warp">,
+    Option<"num_ctas_", "num-ctas", "int32_t", /*default=*/"1",
+           "Number of CTAs in a CGA">,
+  ];
+  let dependentDialects = [
+    "triton::gpu::TritonGPUDialect",
+  ];
+  let constructor = "CreateSparseAddEncodingPass()";
+}
+
+def SparseBlockedToMMAPass : Pass<"sparse-blocked-to-mma", "mlir::ModuleOp"> {
+  let summary = "Add convert layouts to/from MMA before and after SparseDotOp.";
+  let dependentDialects = [
+    "triton::gpu::TritonGPUDialect",
+  ];
+  let constructor = "CreateSparseBlockedToMMAPass()";
+}
+
+def SparseRemoveLayoutConversionPass
+    : Pass<"sparse-remove-layout-conversion", "mlir::ModuleOp"> {
+  let summary = "Replaces ConvertLayoutOp with sparse dot encoding";
+  let dependentDialects = [
+    "triton::gpu::TritonGPUDialect",
+  ];
+  let constructor = "CreateSparseRemoveLayoutConversionPass()";
+}
+
+def SparseLocalLoadToLLVMPass
+    : Pass<"sparse-local-load-to-llvm", "mlir::ModuleOp"> {
+  let summary = "Lowers sparse local load to LLVM";
+  let dependentDialects = [
+    "triton::gpu::TritonGPUDialect",
+    "mlir::LLVM::LLVMDialect"
+  ];
+  let constructor = "CreateSparseLocalLoadToLLVMPass()";
+}
+
+def SparseDotOpToLLVMPass : Pass<"sparse-dot-to-llvm", "mlir::ModuleOp"> {
+  let summary = "Lowers sparse dot to LLVM";
+  let constructor = "CreateSparseDotOpToLLVMPass()";
+  let dependentDialects = [
+    "triton::gpu::TritonGPUDialect",
+    "mlir::triton::nvgpu::NVGPUDialect",
+  ];
+}
+
+def SparseWGMMAOpToLLVMPass : Pass<"sparse-wgmma-to-llvm", "mlir::ModuleOp"> {
+  let summary = "Lowers sparse WGMMA to LLVM";
+  let dependentDialects = [
+    "triton::gpu::TritonGPUDialect",
+    "mlir::triton::nvgpu::NVGPUDialect",
+  ];
+  let constructor = "CreateSparseWGMMAOpToLLVMPass()";
+}
+
+def PreventMmaV3LoopUnrollingPass
+    : Pass<"prevent-mmav3-loop-unrolling", "mlir::ModuleOp"> {
+  let summary = "Prevent MMAv3 loop unrolling.";
+  let description = [{
+    This pass is a result of b/344841434:
+    PTX sometimes unrolls wgmma loops, which can cause a 1000x slowdown in
+    compilation time. Most unrolling has already been done before PTX;
+    this pragma prevents ptxas from doing more.
+  }];
+  let constructor = "CreatePreventMmaV3LoopUnrollingPass()";
+}
+
+#endif  // XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_TD_
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.cc b/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.cc
index 7cb0a55..e5b3d4e 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.cc
@@ -13,29 +13,27 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h"
-
 #include <memory>
 
-#include "llvm/Support/ErrorHandling.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
 #include "mlir/Support/LLVM.h"
+#include "xla/service/gpu/fusions/triton/passes.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 
-class PreventMmaV3LoopUnrollingPass
-    : public mlir::PassWrapper<PreventMmaV3LoopUnrollingPass,
-                               mlir::OperationPass<mlir::ModuleOp>> {
- public:
-  llvm::StringRef getArgument() const override {
-    return "prevent-mmav3-loop-unrolling";
-  }
+namespace xla::gpu {
+namespace {
 
+#define GEN_PASS_DEF_PREVENTMMAV3LOOPUNROLLINGPASS
+#include "xla/service/gpu/fusions/triton/passes.h.inc"
+
+struct PreventMmaV3LoopUnrollingPass
+    : public impl::PreventMmaV3LoopUnrollingPassBase<
+          PreventMmaV3LoopUnrollingPass> {
   // TODO(b/344841434): Remove this if NVIDIA fixes compile-time issue.
   // PTX sometimes unrolls wgmma loops that can cause a 1000x slow down in
   // compilation time. Most unrolling has already been done before PTX;
@@ -60,10 +58,10 @@
   }
 };
 
-std::unique_ptr<mlir::Pass> xla::gpu::CreatePreventMmaV3LoopUnrollingPass() {
+}  // namespace
+
+std::unique_ptr<mlir::Pass> CreatePreventMmaV3LoopUnrollingPass() {
   return std::make_unique<PreventMmaV3LoopUnrollingPass>();
 }
 
-void xla::gpu::RegisterPreventMmaV3LoopUnrollingPass() {
-  registerPass(CreatePreventMmaV3LoopUnrollingPass);
-}
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h b/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h
deleted file mode 100644
index f8e1af0..0000000
--- a/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_PREVENT_MMAV3_LOOP_UNROLLING_H_
-#define XLA_SERVICE_GPU_FUSIONS_TRITON_PREVENT_MMAV3_LOOP_UNROLLING_H_
-
-#include <memory>
-
-#include "mlir/Pass/Pass.h"
-
-namespace xla::gpu {
-
-// This pass is a result of b/344841434:
-// PTX sometimes unrolls wgmma loops that can cause a 1000x slow down in
-// compilation time. Most unrolling has already been done before PTX,
-// this pragma prevents ptxas from doing more.
-std::unique_ptr<mlir::Pass> CreatePreventMmaV3LoopUnrollingPass();
-
-void RegisterPreventMmaV3LoopUnrollingPass();
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_TRITON_PREVENT_MMAV3_LOOP_UNROLLING_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc b/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc
index bfc9d7a..5b2d331 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc
@@ -13,8 +13,6 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "xla/service/gpu/fusions/triton/sparse_extensions.h"
-
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
@@ -48,9 +46,9 @@
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/TypeID.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "xla/service/gpu/fusions/triton/passes.h"
 #include "triton/Analysis/Allocation.h"
 #include "triton/Analysis/Membar.h"
 #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
@@ -65,6 +63,13 @@
 
 using namespace mlir;  // NOLINT(build/namespaces)
 
+namespace ttn = triton::nvgpu;
+using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
+using ::mlir::triton::gpu::getShapePerCTA;
+using ::mlir::triton::gpu::getShapePerCTATile;
+using ::mlir::triton::gpu::SharedEncodingAttr;
+using ttn::OperandsAndConstraints;
+
 // The functions below are defined in AccelerateMatmul.cpp.
 namespace mlir::triton::gpu {
 SmallVector<unsigned, 3> getWarpsPerTile(
@@ -80,13 +85,22 @@
                        int64_t swizzling, uint32_t stride);
 int64_t getSwizzlingFromLayout(const triton::gpu::SharedEncodingAttr &layout,
                                uint32_t widthInByte);
-triton::nvgpu::WGMMAEltType getMmaRetType(Value);
-triton::nvgpu::WGMMAEltType getMmaOperandType(Value, bool);
+ttn::WGMMAEltType getMmaRetType(Value);
+ttn::WGMMAEltType getMmaOperandType(Value, bool);
 
+namespace xla::gpu {
 namespace {
 
+#define GEN_PASS_DEF_SPARSEADDENCODINGPASS
+#define GEN_PASS_DEF_SPARSEBLOCKEDTOMMAPASS
+#define GEN_PASS_DEF_SPARSEDOTOPTOLLVMPASS
+#define GEN_PASS_DEF_SPARSELOCALLOADTOLLVMPASS
+#define GEN_PASS_DEF_SPARSEREMOVELAYOUTCONVERSIONPASS
+#define GEN_PASS_DEF_SPARSEWGMMAOPTOLLVMPASS
+#include "xla/service/gpu/fusions/triton/passes.h.inc"
+
 // Add sparse encoding for all the arguments of a SparseDotOp.
-struct AddSparseEncoding
+struct SparseAddEncoding
     : public OpConversionPattern<triton::gpu::SparseDotOp> {
   using OpConversionPattern<triton::gpu::SparseDotOp>::OpConversionPattern;
 
@@ -179,29 +193,16 @@
   }
 };
 
-class AddSparseEncodingPass
-    : public PassWrapper<AddSparseEncodingPass, OperationPass<ModuleOp>> {
- public:
-  AddSparseEncodingPass() = default;
-  AddSparseEncodingPass(int32_t num_warps, int32_t threads_per_warp,
-                        int32_t num_ctas) {
-    num_warps_ = num_warps;
-    threads_per_warp_ = threads_per_warp;
-    num_ctas_ = num_ctas;
-  }
-  AddSparseEncodingPass(const AddSparseEncodingPass &other) {
-    num_warps_ = other.num_warps_;
-    threads_per_warp_ = other.threads_per_warp_;
-    num_ctas_ = other.num_ctas_;
-  };
-
-  StringRef getArgument() const override { return "add-sparse-encoding"; }
+struct SparseAddEncodingPass
+    : public impl::SparseAddEncodingPassBase<SparseAddEncodingPass> {
+  using impl::SparseAddEncodingPassBase<
+      SparseAddEncodingPass>::SparseAddEncodingPassBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     TritonGPUTypeConverter type_converter(context, num_warps_,
                                           threads_per_warp_, num_ctas_);
-    auto pattern = std::make_unique<AddSparseEncoding>(type_converter, context);
+    auto pattern = std::make_unique<SparseAddEncoding>(type_converter, context);
     RewritePatternSet patterns(context, std::move(pattern));
     TritonGPUConversionTarget target(*context, type_converter);
     target.addDynamicallyLegalOp<triton::gpu::SparseDotOp>(
@@ -212,18 +213,6 @@
                                       std::move(patterns))))
       return signalPassFailure();
   }
-
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSparseEncodingPass)
-
- private:
-  Option<int32_t> num_warps_{
-      *this, "num-warps", llvm::cl::desc("number of warps"), llvm::cl::init(4)};
-  Option<int32_t> threads_per_warp_{
-      *this, "threads-per-warp", llvm::cl::desc("number of threads per warp"),
-      llvm::cl::init(32)};
-  Option<int32_t> num_ctas_{*this, "num-ctas",
-                            llvm::cl::desc("number of ctas in a cga"),
-                            llvm::cl::init(1)};
 };
 
 // Add convert layouts to and from MMA before and after SparseDotOp. In MMAV3,
@@ -332,13 +321,8 @@
   int compute_capability_;
 };
 
-class SparseBlockedToMMAPass
-    : public PassWrapper<SparseBlockedToMMAPass, OperationPass<ModuleOp>> {
- public:
-  SparseBlockedToMMAPass() = default;
-
-  StringRef getArgument() const override { return "sparse-blocked-to-mma"; }
-
+struct SparseBlockedToMMAPass
+    : public impl::SparseBlockedToMMAPassBase<SparseBlockedToMMAPass> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp module = getOperation();
@@ -350,8 +334,6 @@
       return signalPassFailure();
     }
   }
-
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseBlockedToMMAPass)
 };
 
 class SparseLocalLoadToLLVM
@@ -469,16 +451,9 @@
   }
 };
 
-class SparseRemoveLayoutConversionPass
-    : public PassWrapper<SparseRemoveLayoutConversionPass,
-                         OperationPass<ModuleOp>> {
- public:
-  SparseRemoveLayoutConversionPass() = default;
-
-  StringRef getArgument() const override {
-    return "sparse-remove-layout-conversion";
-  }
-
+struct SparseRemoveLayoutConversionPass
+    : public impl::SparseRemoveLayoutConversionPassBase<
+          SparseRemoveLayoutConversionPass> {
   void runOnOperation() override {
     getOperation().walk([&](triton::gpu::ConvertLayoutOp op) {
       ImplicitLocOpBuilder builder(op.getLoc(), op);
@@ -507,35 +482,22 @@
       op.erase();
     });
   }
-
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseRemoveLayoutConversionPass)
 };
 
-class SparseLocalLoadToLLVMPass
-    : public PassWrapper<SparseLocalLoadToLLVMPass, OperationPass<ModuleOp>> {
- public:
-  SparseLocalLoadToLLVMPass() = default;
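+// Returns true if `op` is a triton::gpu::LocalLoadOp whose result encoding is
+// SparseDotMetaEncodingAttr; used with ContainsOp below to exit early when a
+// module contains no sparse local loads.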
+bool IsLocalLoadWithSparseEncoding(Operation *op) {
+  auto local_load = mlir::dyn_cast<triton::gpu::LocalLoadOp>(op);
+  if (!local_load) return false;
+  return isa<triton::gpu::SparseDotMetaEncodingAttr>(
+      local_load.getType().getEncoding());
+}
 
-  StringRef getArgument() const override { return "sparse-local-load-to-llvm"; }
-
-  void getDependentDialects(mlir::DialectRegistry &registry) const override {
-    registry.insert<LLVM::LLVMDialect, mlir::gpu::GPUDialect,
-                    arith::ArithDialect>();
-  }
-
+struct SparseLocalLoadToLLVMPass
+    : public impl::SparseLocalLoadToLLVMPassBase<SparseLocalLoadToLLVMPass> {
   void runOnOperation() override {
     // Exit early if there are no sparse ops.
-    mlir::ModuleOp mod = getOperation();
-    if (!mod.walk([](triton::gpu::LocalLoadOp op) {
-              if (isa<triton::gpu::SparseDotMetaEncodingAttr>(
-                      op.getType().getEncoding())) {
-                return WalkResult::interrupt();
-              }
-              return WalkResult::advance();
-            })
-             .wasInterrupted()) {
-      return;
-    }
+    ModuleOp mod = getOperation();
+    if (!ContainsOp(mod, IsLocalLoadWithSparseEncoding)) return;
+
     // Allocate shared memory and set barrier
     // This is also done in the TritonGPUToLLVMPass but we need to do it before
     // we write the local load op to LLVM to have barriers in the right place.
@@ -553,7 +515,7 @@
           return !isa<triton::gpu::SparseDotMetaEncodingAttr>(
               op.getType().getEncoding());
         });
-    mlir::LowerToLLVMOptions option(context);
+    LowerToLLVMOptions option(context);
     TritonGPUToLLVMTypeConverter typeConverter(context, option);
     auto pattern = std::make_unique<SparseLocalLoadToLLVM>(typeConverter);
     RewritePatternSet patterns(context, std::move(pattern));
@@ -562,15 +524,8 @@
       return signalPassFailure();
     }
   }
-
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseLocalLoadToLLVMPass)
 };
 
-using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
-using ::mlir::triton::gpu::getShapePerCTA;
-using ::mlir::triton::gpu::getShapePerCTATile;
-using ::mlir::triton::gpu::SharedEncodingAttr;
-
 using ValueTableV2 = std::map<std::pair<unsigned, unsigned>, Value>;
 
 constexpr int kContractingFactor = 2;  // implied by N:M (2:4)
@@ -818,17 +773,17 @@
   assert(hMetaPacked.size() == repM * repK);
 
   // Generate prologue.
-  triton::nvgpu::WGMMAEltType eltTypeA = getMmaOperandType(op.getA(), false);
-  triton::nvgpu::WGMMAEltType eltTypeB = getMmaOperandType(op.getB(), false);
-  triton::nvgpu::WGMMAEltType eltTypeC = getMmaRetType(op.getD());
+  ttn::WGMMAEltType eltTypeA = getMmaOperandType(op.getA(), false);
+  ttn::WGMMAEltType eltTypeB = getMmaOperandType(op.getB(), false);
+  ttn::WGMMAEltType eltTypeC = getMmaRetType(op.getD());
 
-  triton::nvgpu::WGMMALayout layoutA = transA ? triton::nvgpu::WGMMALayout::col
-                                              : triton::nvgpu::WGMMALayout::row;
-  triton::nvgpu::WGMMALayout layoutB = transB ? triton::nvgpu::WGMMALayout::row
-                                              : triton::nvgpu::WGMMALayout::col;
+  ttn::WGMMALayout layoutA =
+      transA ? ttn::WGMMALayout::col : ttn::WGMMALayout::row;
+  ttn::WGMMALayout layoutB =
+      transB ? ttn::WGMMALayout::row : ttn::WGMMALayout::col;
 
-  rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
-  rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);
+  rewriter.create<ttn::FenceAsyncSharedOp>(loc, 0);
+  rewriter.create<ttn::WGMMAFenceOp>(loc);
 
   // Generate main loop.
   for (int m = 0; m < repM; ++m) {
@@ -841,7 +796,7 @@
         Value a = loadA(m, k);
         Value b = loadB(n, k);
         Value meta = hMetaPacked[k * repM + m];
-        d = rewriter.create<triton::nvgpu::SparseWGMMAOp>(
+        d = rewriter.create<ttn::SparseWGMMAOp>(
             loc, accTy, a, meta, b, d, kWarpsInGroup * instrShape[0],
             instrShape[1], kContractingFactor * instrShape[2], eltTypeC,
             eltTypeA, eltTypeB, layoutA, layoutB);
@@ -858,8 +813,8 @@
       op.getContext(), SmallVector<Type>(fc.size(), f32_ty));
   Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);
 
-  rewriter.create<triton::nvgpu::WGMMACommitGroupOp>(loc);
-  res = rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, res, 0);
+  rewriter.create<ttn::WGMMACommitGroupOp>(loc);
+  res = rewriter.create<ttn::WGMMAWaitGroupOp>(loc, res, 0);
   rewriter.replaceOp(op, res);
 
   return success();
@@ -899,26 +854,16 @@
   }
 };
 
-class SparseDotOpToLLVMPass
-    : public PassWrapper<SparseDotOpToLLVMPass, OperationPass<ModuleOp>> {
- public:
-  SparseDotOpToLLVMPass() = default;
-
-  StringRef getArgument() const override { return "sparse-dot-to-llvm"; }
-
-  void getDependentDialects(mlir::DialectRegistry &registry) const override {
-    registry.insert<LLVM::LLVMDialect, mlir::gpu::GPUDialect,
-                    arith::ArithDialect>();
-  }
-
+struct SparseDotOpToLLVMPass
+    : public impl::SparseDotOpToLLVMPassBase<SparseDotOpToLLVMPass> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<LLVM::LLVMDialect, NVVM::NVVMDialect,
-                           arith::ArithDialect, triton::nvgpu::NVGPUDialect>();
+                           arith::ArithDialect, ttn::NVGPUDialect>();
     target.addIllegalOp<triton::gpu::SparseDotOp>();
     target.addIllegalDialect<mlir::gpu::GPUDialect>();
-    mlir::LowerToLLVMOptions option(context);
+    LowerToLLVMOptions option(context);
     TritonGPUToLLVMTypeConverter typeConverter(context, option);
     RewritePatternSet patterns(context);
     patterns.add<SparseDotOpConversion>(typeConverter);
@@ -928,13 +873,8 @@
       return signalPassFailure();
     }
   }
-
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseLocalLoadToLLVMPass)
 };
 
-namespace ttn = mlir::triton::nvgpu;
-using ttn::OperandsAndConstraints;
-
 class SparseWGMMAOpPattern : public OpRewritePattern<ttn::SparseWGMMAOp> {
  public:
   using OpRewritePattern<ttn::SparseWGMMAOp>::OpRewritePattern;
@@ -1021,13 +961,8 @@
   }
 };
 
-class SparseWGMMAOpToLLVMPass
-    : public PassWrapper<SparseWGMMAOpToLLVMPass, OperationPass<ModuleOp>> {
- public:
-  SparseWGMMAOpToLLVMPass() = default;
-
-  StringRef getArgument() const override { return "sparse-wgmma-to-llvm"; }
-
+struct SparseWGMMAOpToLLVMPass
+    : public impl::SparseWGMMAOpToLLVMPassBase<SparseWGMMAOpToLLVMPass> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     auto pattern = std::make_unique<SparseWGMMAOpPattern>(context);
@@ -1037,43 +972,38 @@
       return signalPassFailure();
     }
   }
-
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseLocalLoadToLLVMPass)
 };
 
 }  // namespace
 
-std::unique_ptr<Pass> xla::gpu::CreateAddSparseDotEncodingPass(
-    int32_t num_warps, int32_t threads_per_warp, int32_t num_ctas) {
-  return std::make_unique<AddSparseEncodingPass>(num_warps, threads_per_warp,
-                                                 num_ctas);
+std::unique_ptr<Pass> CreateSparseAddEncodingPass(int32_t num_warps,
+                                                  int32_t threads_per_warp,
+                                                  int32_t num_ctas) {
+  SparseAddEncodingPassOptions options;
+  options.num_warps_ = num_warps;
+  options.threads_per_warp_ = threads_per_warp;
+  options.num_ctas_ = num_ctas;
+  return std::make_unique<SparseAddEncodingPass>(options);
 }
 
-std::unique_ptr<mlir::Pass> xla::gpu::CreateSparseBlockedToMMAPass() {
+std::unique_ptr<Pass> CreateSparseBlockedToMMAPass() {
   return std::make_unique<SparseBlockedToMMAPass>();
 }
 
-std::unique_ptr<mlir::Pass> xla::gpu::CreateSparseRemoveLayoutConversionPass() {
+std::unique_ptr<Pass> CreateSparseRemoveLayoutConversionPass() {
   return std::make_unique<SparseRemoveLayoutConversionPass>();
 }
 
-std::unique_ptr<mlir::Pass> xla::gpu::CreateSparseLocalLoadToLLVMPass() {
+std::unique_ptr<Pass> CreateSparseLocalLoadToLLVMPass() {
   return std::make_unique<SparseLocalLoadToLLVMPass>();
 }
 
-std::unique_ptr<mlir::Pass> xla::gpu::CreateSparseDotOpToLLVMPass() {
+std::unique_ptr<Pass> CreateSparseDotOpToLLVMPass() {
   return std::make_unique<SparseDotOpToLLVMPass>();
 }
 
-std::unique_ptr<mlir::Pass> xla::gpu::CreateSparseWGMMAOpToLLVMPass() {
+std::unique_ptr<Pass> CreateSparseWGMMAOpToLLVMPass() {
   return std::make_unique<SparseWGMMAOpToLLVMPass>();
 }
 
-void xla::gpu::RegisterSparsePasses() {
-  registerPass([] { return std::make_unique<AddSparseEncodingPass>(); });
-  registerPass(CreateSparseBlockedToMMAPass);
-  registerPass(CreateSparseRemoveLayoutConversionPass);
-  registerPass(CreateSparseLocalLoadToLLVMPass);
-  registerPass(CreateSparseDotOpToLLVMPass);
-  registerPass(CreateSparseWGMMAOpToLLVMPass);
-}
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.h b/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.h
deleted file mode 100644
index 988a63d..0000000
--- a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_SPARSE_EXTENSIONS_H_
-#define XLA_SERVICE_GPU_FUSIONS_TRITON_SPARSE_EXTENSIONS_H_
-
-#include <cstdint>
-#include <memory>
-
-#include "mlir/Pass/Pass.h"
-
-namespace xla::gpu {
-
-std::unique_ptr<mlir::Pass> CreateAddSparseDotEncodingPass(
-    int32_t num_warps, int32_t threads_per_warp, int32_t num_ctas);
-std::unique_ptr<mlir::Pass> CreateSparseBlockedToMMAPass();
-std::unique_ptr<mlir::Pass> CreateSparseRemoveLayoutConversionPass();
-std::unique_ptr<mlir::Pass> CreateSparseLocalLoadToLLVMPass();
-std::unique_ptr<mlir::Pass> CreateSparseDotOpToLLVMPass();
-std::unique_ptr<mlir::Pass> CreateSparseWGMMAOpToLLVMPass();
-
-void RegisterSparsePasses();
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_FUSIONS_TRITON_SPARSE_EXTENSIONS_H_
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
index 20fa19c..e6c0374 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
@@ -110,9 +110,9 @@
 #include "xla/service/algorithm_util.h"
 #include "xla/service/dump.h"
 #include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
 #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h"
-#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"
-#include "xla/service/gpu/fusions/mlir/passes.h"
+#include "xla/service/gpu/fusions/transforms/passes.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/launch_dimensions.h"
@@ -1507,7 +1507,7 @@
             }
           }
         }
-        CHECK(to_order.insert(current).second);
+        to_order.insert(current);
         to_add.pop();
       }
     }
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h
index fe133d8..3f7c3bc 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h
+++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h
@@ -17,7 +17,6 @@
 #define XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_FUSION_EMITTER_H_
 
 #include <cstdint>
-#include <functional>
 #include <optional>
 #include <string>
 
@@ -87,19 +86,6 @@
                         mlir::triton::FuncOp fn,
                         const BlockLevelParameters& block_level_parameters);
 
-// Generate Softmax in Triton IR inside 'fn'.
-// Use execution parameters from 'block_level_parameters'.
-absl::Status EmitSoftMax(mlir::OpBuilder b, absl::string_view libdevice_path,
-                         const se::DeviceDescription& device_info,
-                         const HloFusionInstruction* fusion,
-                         mlir::triton::FuncOp fn,
-                         const BlockLevelParameters& block_level_parameters);
-
-using TritonIrEmitter = std::function<absl::Status(
-    mlir::OpBuilder, absl::string_view, const se::DeviceDescription&,
-    const HloFusionInstruction*, mlir::triton::FuncOp,
-    const BlockLevelParameters&)>;
-
 // Load the MLIR dialects required for Triton IR generation.
 void LoadMlirDialectsForTriton(mlir::MLIRContext& mlir_context);
 
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc
index c162c34..720a8b4 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc
@@ -1549,6 +1549,35 @@
       kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
 }
 
+TEST_F(TritonGemmTest, MultiplePathsToSameOperandWorks) {
+  const std::string kHloText = R"(
+triton_computation {
+  p0 = bf16[8192,512]{1,0} parameter(0)
+  p1 = bf16[512,512]{1,0} parameter(1)
+  dot = bf16[8192,512]{1,0} dot(bf16[8192,512]{1,0} p0, bf16[512,512]{1,0} p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  p2 = bf16[8192,512]{1,0} parameter(2)
+  multiply.1 = bf16[8192,512]{1,0} multiply(bf16[8192,512]{1,0} dot, bf16[8192,512]{1,0} p2)
+  ROOT multiply.2 = bf16[8192,512]{1,0} multiply(bf16[8192,512]{1,0} multiply.1, bf16[8192,512]{1,0} p2)
+}
+
+ENTRY e {
+  p0 = bf16[8192,512]{1,0} parameter(0)
+  p1 = bf16[512,512]{1,0} parameter(1)
+  p2 = bf16[8192,512]{1,0} parameter(2)
+  ROOT fusion = bf16[8192,512]{1,0} fusion(p0,p1,p2), kind=kCustom, calls=triton_computation,
+  backend_config={"fusion_backend_config":
+      {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"64","block_n":"256","block_k":"32","split_k":"1","num_stages":"4","num_warps":"4","num_ctas":"1"}}}
+})";
+
+  TF_ASSERT_OK(
+      CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_computation", R"(
+        CHECK:      tt.dot
+        CHECK-SAME: tensor<64x32xbf16> * tensor<32x256xbf16> -> tensor<64x256xf32>
+        CHECK:      arith.mulf
+        CHECK:      arith.mulf
+    )"));
+}
+
 class TritonGemmDynamicSliceClampingTest
     : public TritonTest,
       public ::testing::WithParamInterface<int> {};
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc
index 8f07bba..d59b6cc 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc
@@ -214,7 +214,7 @@
   TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText,
                                           FromOutputTileSizes({1, 127}),
                                           "triton_softmax_computation", R"(
-CHECK:        #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)>
+CHECK:        #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127)
 CHECK:        tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
 CHECK:            %[[PID:.*]] = tt.get_program_id x : i32
 CHECK:            arith.index_castui %[[PID]] : i32 to index
@@ -272,7 +272,7 @@
   TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText,
                                           FromOutputTileSizes({1, 127}),
                                           "triton_softmax_computation", R"(
-CHECK:        #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)>
+CHECK:        #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127)
 CHECK-LABEL:  tt.func @triton_fn(
 CHECK-SAME:                      %[[P0:[A-Za-z0-9_]*]]: !tt.ptr<f32>
 CHECK-SAME:                      %[[P1:[A-Za-z0-9_]*]]: !tt.ptr<f32>
@@ -339,7 +339,7 @@
   TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText,
                                           FromOutputTileSizes({1, 1, 127}),
                                           "triton_softmax_computation", R"(
-CHECK:        #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)>
+CHECK:        #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127)
 CHECK:        tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
 CHECK-DAG:        %[[PID:.*]] = tt.get_program_id x : i32
 CHECK-DAG:        %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index
@@ -517,7 +517,7 @@
   TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText,
                                           FromOutputTileSizes({1, 1, 16}),
                                           "triton_softmax_computation", R"(
-// CHECK:         #[[MAP:.*]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK:         #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 16)
 // CHECK-LABEL:   tt.func @triton_fn(
 // CHECK-SAME:                       %[[P0:[A-Za-z0-9_]*]]: !tt.ptr<f32>
 // CHECK-SAME:                       %[[P1:[A-Za-z0-9_]*]]: !tt.ptr<f32>
@@ -674,7 +674,7 @@
   TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText,
                                           FromOutputTileSizes({1, 127}),
                                           "triton_softmax_computation", R"(
-CHECK:        #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)>
+CHECK:        #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127)
 CHECK:        tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
 CHECK:            %[[PID:.*]] = tt.get_program_id x : i32
 CHECK:            arith.index_castui %[[PID]] : i32 to index
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc
index f946cc4..7ae8f5e 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc
@@ -52,13 +52,6 @@
 class MixedTypeTest : public GpuCodegenTest,
                       public ::testing::WithParamInterface<MixTypeParams> {
  public:
-  se::CudaComputeCapability GetCudaComputeCapability() {
-    return backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .cuda_compute_capability();
-  }
-
   DebugOptions GetDebugOptionsForTest() override {
     DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
     // We are testing Triton, remove cuBLAS fallback for these tests.
@@ -803,13 +796,6 @@
     debug_options.clear_xla_disable_hlo_passes();
     return debug_options;
   }
-
-  se::CudaComputeCapability GetCudaComputeCapability() {
-    return backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .cuda_compute_capability();
-  }
 };
 
 TEST_P(TritonSoftmaxTest, CanFuseAndEmitExactSoftmax) {
diff --git a/third_party/xla/xla/service/gpu/fusions/triton_test.cc b/third_party/xla/xla/service/gpu/fusions/triton_test.cc
index c2cfabf..1738d6f 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton_test.cc
@@ -64,7 +64,7 @@
       TestGpuDeviceInfo::RTXA6000DeviceInfo();
 
   auto* root = module->entry_computation()->root_instruction();
-  HloFusionAnalysis analysis = AnalyzeFusion(*root, device_info);
+  HloFusionAnalysis analysis = HloFusionAnalysis::Create(*root, device_info);
 
   std::unique_ptr<FusionInterface> emitter =
       GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
@@ -100,7 +100,7 @@
       TestGpuDeviceInfo::RTXA6000DeviceInfo();
 
   auto* root = module->entry_computation()->root_instruction();
-  HloFusionAnalysis analysis = AnalyzeFusion(*root, device_info);
+  HloFusionAnalysis analysis = HloFusionAnalysis::Create(*root, device_info);
 
   std::unique_ptr<FusionInterface> emitter =
       GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc
deleted file mode 100644
index a2de14c..0000000
--- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc
+++ /dev/null
@@ -1,500 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_algorithm_picker.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/types/span.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/buffer_comparator.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/gpu/variant_visitor.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/blas.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/tsl/util/proto/proto_utils.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/profiler/lib/scoped_annotation.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using se::gpu::BlasLt;
-
-absl::StatusOr<BlasLt::Epilogue> AsBlasLtEpilogue(
-    GemmBackendConfig_Epilogue epilogue) {
-  switch (epilogue) {
-    case GemmBackendConfig::DEFAULT:
-      return BlasLt::Epilogue::kDefault;
-    case GemmBackendConfig::RELU:
-      return BlasLt::Epilogue::kReLU;
-    case GemmBackendConfig::GELU:
-      return BlasLt::Epilogue::kGELU;
-    case GemmBackendConfig::GELU_AUX:
-      return BlasLt::Epilogue::kGELUWithAux;
-    case GemmBackendConfig::BIAS:
-      return BlasLt::Epilogue::kBias;
-    case GemmBackendConfig::BIAS_RELU:
-      return BlasLt::Epilogue::kBiasThenReLU;
-    case GemmBackendConfig::BIAS_GELU:
-      return BlasLt::Epilogue::kBiasThenGELU;
-    case GemmBackendConfig::BIAS_GELU_AUX:
-      return BlasLt::Epilogue::kBiasThenGELUWithAux;
-    default:
-      return Internal("Unsupported Epilogue.");
-  }
-}
-
-class GemmAutotuner {
-  const AutotuneConfig& autotune_config_;
-  RedzoneBuffers rz_buffers_;
-  se::Stream* stream_ = nullptr;
-  bool deterministic_ops_ = false;
-  size_t solutions_limit_ = 0;
-  size_t num_algorithms_left_ = 0;
-
- public:
-  explicit GemmAutotuner(const AutotuneConfig& autotune_config)
-      : autotune_config_(autotune_config) {}
-
-  size_t num_algorithms_left() const { return num_algorithms_left_; }
-
-  absl::StatusOr<AutotuneResult> operator()(const HloInstruction* gemm,
-                                            const AutotuneCacheKey& key) {
-    num_algorithms_left_ = 0;
-    if (autotune_config_.IsDeviceless()) {
-      // Return empty result, will tune at runtime.
-      return AutotuneResult{};
-    }
-    VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString();
-
-    TF_ASSIGN_OR_RETURN(stream_, autotune_config_.GetStream());
-    const DebugOptions& debug_options =
-        gemm->GetModule()->config().debug_options();
-    deterministic_ops_ = RequireDeterminism(gemm->GetModule()->config());
-    solutions_limit_ = debug_options.xla_gpu_autotune_max_solutions();
-
-    TF_ASSIGN_OR_RETURN(auto gemm_config, GemmConfig::For(gemm));
-
-    // Don't run autotuning concurrently on the same GPU.
-    absl::MutexLock gpu_lock(&GetGpuMutex(stream_->parent()));
-
-    TF_ASSIGN_OR_RETURN(rz_buffers_, RedzoneBuffers::FromInstruction(
-                                         *gemm, autotune_config_, debug_options,
-                                         RedzoneBuffers::kAllInputsAllOutputs));
-
-    return IsCublasLtMatmul(*gemm) || IsCublasLtMatmulF8(*gemm)
-               ? TuneGpuBlasLt(gemm, gemm_config)
-               : TuneGpuBlas(gemm, gemm_config);
-  }
-
- private:
-  se::DeviceMemoryBase LhsBuffer() { return rz_buffers_.input_buffers().at(0); }
-  se::DeviceMemoryBase RhsBuffer() { return rz_buffers_.input_buffers().at(1); }
-  se::DeviceMemoryBase OutputBuffer() {
-    return rz_buffers_.output_buffers().at(0);
-  }
-
-  const Shape& GetOutputShape(const HloInstruction* gemm) {
-    return gemm->shape().IsTuple() ? gemm->shape().tuple_shapes(0)
-                                   : gemm->shape();
-  }
-
-  absl::StatusOr<AutotuneResult> TuneGpuBlasLt(const HloInstruction* gemm,
-                                               const GemmConfig& gemm_config) {
-    auto workspace_buffer =
-        rz_buffers_.output_buffers().at(gemm->shape().tuple_shapes_size() - 1);
-
-    GpuBackendConfig gpu_config =
-        gemm->backend_config<GpuBackendConfig>().value();
-    const GemmBackendConfig& backend_config = gpu_config.gemm_backend_config();
-
-    bool has_matrix_bias = gemm_config.beta != 0.;
-
-    TF_ASSIGN_OR_RETURN(
-        bool has_vector_bias,
-        gpublas_lt::EpilogueAddsVectorBias(backend_config.epilogue()));
-
-    TF_ASSIGN_OR_RETURN(
-        bool has_aux_output,
-        gpublas_lt::EpilogueHasAuxiliaryOutput(backend_config.epilogue()));
-
-    TF_ASSIGN_OR_RETURN(auto epilogue,
-                        AsBlasLtEpilogue(backend_config.epilogue()));
-
-    se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer,
-        d_scale_buffer, d_amax_buffer, bias_buffer, aux_buffer;
-
-    if (has_vector_bias) {
-      bias_buffer = rz_buffers_.input_buffers().at(has_matrix_bias ? 3 : 2);
-    }
-    if (has_aux_output) {
-      aux_buffer = rz_buffers_.output_buffers().at(1);
-    }
-
-    TF_ASSIGN_OR_RETURN(auto plan,
-                        BlasLt::GetMatmulPlan(stream_, gemm_config, epilogue));
-
-    TF_ASSIGN_OR_RETURN(
-        auto algorithms,
-        plan->GetAlgorithms(/*max_algorithm_count*/ 128,
-                            /*max_workspace_size*/ workspace_buffer.size()));
-
-    auto tuned_func = [&](const BlasLt::MatmulAlgorithm& algorithm)
-        -> absl::StatusOr<se::blas::ProfileResult> {
-      // Run a warmup iteration without the profiler active.
-      TF_RETURN_IF_ERROR(plan->ExecuteOnStream(
-          stream_, LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(),
-          bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer,
-          c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm,
-          workspace_buffer));
-      se::blas::ProfileResult profile_result;
-      profile_result.set_warmup_run_executed(true);
-      TF_RETURN_IF_ERROR(plan->ExecuteOnStream(
-          stream_, LhsBuffer(), RhsBuffer(), OutputBuffer(), OutputBuffer(),
-          bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer,
-          c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm,
-          workspace_buffer, &profile_result));
-      return std::move(profile_result);
-    };
-
-    return GetBestAlgorithm<BlasLt::MatmulAlgorithm>(
-        gemm, algorithms, gemm_config.beta, /*return_algo_index*/ true,
-        tuned_func);
-  }
-
-  absl::StatusOr<AutotuneResult> TuneGpuBlas(const HloInstruction* gemm,
-                                             const GemmConfig& gemm_config) {
-    auto workspace_buffer = rz_buffers_.output_buffers().at(1);
-
-    std::vector<se::blas::AlgorithmType> algorithms;
-    TF_ASSIGN_OR_RETURN(GemmConfig::DescriptorsTuple desc,
-                        gemm_config.GetMatrixDescriptors(
-                            LhsBuffer(), RhsBuffer(), OutputBuffer()));
-
-    auto blas = stream_->parent()->AsBlas();
-    if (blas == nullptr) {
-      return absl::InternalError("No BLAS support for stream");
-    }
-    blas->GetBlasGemmAlgorithms(stream_, desc.lhs, desc.rhs, &desc.output,
-                                &gemm_config.alpha, &gemm_config.beta,
-                                &algorithms);
-
-    auto tuned_func = [&](const se::blas::AlgorithmType& algorithm)
-        -> absl::StatusOr<se::blas::ProfileResult> {
-      // Do a warm-up run first, without a profile result. RunGemm swallows
-      // error codes when profile_result is passed, as it is in the measurement
-      // below, but not otherwise. It is, therefore, consistent to ignore the
-      // error code here.
-      static_cast<void>(RunGemm(gemm_config, LhsBuffer(), RhsBuffer(),
-                                OutputBuffer(), workspace_buffer,
-                                deterministic_ops_, stream_, algorithm));
-      se::blas::ProfileResult profile_result;
-      // Allow GpuTimer to use its delay kernel implementation to improve
-      // accuracy.
-      profile_result.set_warmup_run_executed(true);
-      // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail
-      // for all algorithms if we're targeting < sm_50. But because we pass a
-      // non-null ProfileResult, DoGemmWithAlgorithm should always return true,
-      // and the actual success-ness is returned in ProfileResult::is_valid.
-      TF_RETURN_IF_ERROR(RunGemm(gemm_config, LhsBuffer(), RhsBuffer(),
-                                 OutputBuffer(), workspace_buffer,
-                                 deterministic_ops_, stream_, algorithm,
-                                 &profile_result));
-      return std::move(profile_result);
-    };
-
-    return GetBestAlgorithm<se::blas::AlgorithmType>(
-        gemm, algorithms, gemm_config.beta, /*return_algo_index*/ false,
-        tuned_func);
-  }
-
-  // Returns the index (into `algorithms`) of the fastest algorithm.
-  template <typename AlgoT, typename TunedFunc>
-  absl::StatusOr<AutotuneResult> GetBestAlgorithm(
-      const HloInstruction* gemm, absl::Span<const AlgoT> algorithms,
-      double beta, bool return_algo_index, TunedFunc&& run_benchmark) {
-    static_assert(std::is_invocable_r_v<absl::StatusOr<se::blas::ProfileResult>,
-                                        TunedFunc, const AlgoT&>,
-                  "Tuned function has incorrect prototype!");
-
-    if (!stream_->parent()->SynchronizeAllActivity()) {
-      return Internal("Failed to synchronize GPU for autotuning.");
-    }
-    tsl::profiler::ScopedAnnotation annotation([&] {
-      return absl::StrFormat("XlaAutotunerMeasurement:#hlo_op=%s#",
-                             gemm->name());
-    });
-
-    auto& hlo_module_config = gemm->GetModule()->mutable_config();
-    const auto& output_shape = GetOutputShape(gemm);
-
-    se::DeviceMemoryBase reference_buffer;
-    if (autotune_config_.should_check_correctness()) {
-      TF_ASSIGN_OR_RETURN(reference_buffer,
-                          rz_buffers_.RedzoneAllocator().AllocateBytes(
-                              ShapeUtil::ByteSizeOf(output_shape)));
-    }
-
-    // Do not print error messages if should_skip_wrong_results() is ON.
-    BufferComparator comparator(
-        output_shape,
-        hlo_module_config.debug_options().xla_gpu_autotune_gemm_rtol(),
-        /* verbose */ !autotune_config_.should_skip_wrong_results());
-    std::vector<AutotuneResult> results;
-    results.reserve(algorithms.size());
-    std::optional<int64_t> reference_algorithm;
-
-    auto num = algorithms.size();
-    if (solutions_limit_ > 0) num = std::min(num, solutions_limit_);
-    for (size_t i = 0; i < num; i++) {
-      const AlgoT& algorithm = algorithms[i];
-      // Make sure the output buffer always has the same value if we use
-      // the bias parameter.
-      if (autotune_config_.should_reinit_output_buffer() && beta != 0) {
-        int64_t rng_state = 0;
-        InitializeBuffer(stream_, output_shape.element_type(), &rng_state,
-                         OutputBuffer());
-      }
-      TF_ASSIGN_OR_RETURN(auto profile_result, run_benchmark(algorithm));
-
-      AutotuneResult& result = results.emplace_back();
-      result.mutable_gemm()->set_algorithm(profile_result.algorithm());
-
-      if (!profile_result.is_valid()) {  // Unsupported algorithm.
-        result.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED);
-        continue;
-      }
-
-      VLOG(2) << "gemm algorithm " << profile_result.algorithm() << " took "
-              << profile_result.elapsed_time_in_ms() << "ms";
-
-      *result.mutable_run_time() = tsl::proto_utils::ToDurationProto(
-          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
-
-      if (!autotune_config_.should_check_correctness()) {
-        num_algorithms_left_++;
-        continue;
-      }
-      TF_ASSIGN_OR_RETURN(
-          se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
-          rz_buffers_.RedzoneAllocator().CheckRedzones());
-
-      if (!rz_check_status.ok()) {
-        result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
-        *result.mutable_failure()->mutable_msg() =
-            rz_check_status.RedzoneFailureMsg();
-        LOG(ERROR) << "Detected out-of-bounds write in gemm buffer";
-        CHECK(!autotune_config_.should_crash_on_check_failure());
-        continue;
-      }
-
-      num_algorithms_left_++;
-      if (!reference_algorithm) {
-        TF_RETURN_IF_ERROR(stream_->Memcpy(&reference_buffer, OutputBuffer(),
-                                           OutputBuffer().size()));
-        reference_algorithm = profile_result.algorithm();
-        continue;
-      }
-      // Perform the comparison versus the reference algorithm.
-      TF_ASSIGN_OR_RETURN(
-          bool outputs_match,
-          comparator.CompareEqual(stream_, /*current=*/OutputBuffer(),
-                                  /*expected=*/reference_buffer));
-      if (!outputs_match) {
-        LOG(ERROR) << "Results mismatch between different GEMM algorithms. "
-                   << "This is likely a bug/unexpected loss of precision.";
-        CHECK(!autotune_config_.should_crash_on_check_failure());
-
-        // By default, autotuner does NOT really skip wrong results, but
-        // merely prints out the above error message: this may lead to a
-        // great confusion. When should_skip_wrong_results() is set to true,
-        // solutions with accuracy problems will be disqualified.
-        auto kind = AutotuneResult::WRONG_RESULT;
-        if (autotune_config_.should_skip_wrong_results()) {
-          kind = AutotuneResult::DISQUALIFIED;
-          num_algorithms_left_--;  // Decrement again since we disqualified it.
-        }
-        result.mutable_failure()->set_kind(kind);
-        result.mutable_failure()->mutable_reference_gemm()->set_algorithm(
-            *reference_algorithm);
-      }
-    }  // for algorithms
-
-    absl::StatusOr<AutotuneResult> best =
-        PickBestResult(results, gemm->ToString(), hlo_module_config);
-    if (best.ok()) {
-      // Note that, cublas-lt returns an opaque object as an algorithm ID,
-      // therefore we need to convert it to the index from the algorithms list
-      // (otherwise, we cannot store this ID inside a gemm_backend_config).
-      // In contrast, legacy cublas returns a 32-bit integer algorithm ID which
-      // can be readily stored inside an HLO (hence return_algo_index is false
-      // for cublas case).
-      if (!return_algo_index) return best;
-      // Otherwise, map a real algorithm ID to its index among the results.
-      for (size_t i = 0; i < results.size(); ++i) {
-        if (best->gemm().algorithm() == results[i].gemm().algorithm()) {
-          best->mutable_gemm()->set_algorithm(i);
-          return best;
-        }
-      }
-      return Internal("unknown best algorithm");
-    }
-    LOG(WARNING) << "Failed to find best cuBLAS algorithm, GEMM performance "
-                    "might be suboptimal: "
-                 << best.status();
-    return AutotuneResult{};
-  }  // GetBestAlgorithm
-};  // GemmAutotuner
-
-// Do Gemm Autotune without stream executor. Use results from autotune cache
-// only.
-absl::StatusOr<bool> RunOnInstruction(HloInstruction* gemm,
-                                      const AutotuneConfig& config,
-                                      size_t* num_algorithms_left) {
-  VLOG(3) << "Loading the autotune result of GemmThunk " << gemm->ToString();
-
-  GpuBackendConfig gpu_config =
-      gemm->backend_config<GpuBackendConfig>().value();
-  GemmBackendConfig& backend_config = *gpu_config.mutable_gemm_backend_config();
-
-  *num_algorithms_left = 0;
-  // Degenerate gemms replaced with memzero operation, no need to auto tune it.
-  if (backend_config.alpha_real() == 0.0 &&
-      backend_config.alpha_imag() == 0.0 && backend_config.beta() == 0.0) {
-    VLOG(3) << "Skip degenerate gemm instruction auto tuning";
-    return false;
-  }
-
-  AutotuneCacheKey key(config.GetModelStr(), *gemm);
-  GemmAutotuner autotuner(config);
-  TF_ASSIGN_OR_RETURN(AutotuneResult algorithm,
-                      AutotunerUtil::Autotune(
-                          gemm, config, [&] { return autotuner(gemm, key); }));
-
-  *num_algorithms_left = autotuner.num_algorithms_left();
-  auto old_algorithm = backend_config.selected_algorithm();
-  bool update_algorithm =
-      IsCublasLtMatmulF8(*gemm) ||
-      std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) {
-                                  // We only set the 'algorithm' field on
-                                  // non-Ampere architectures, as for Ampere
-                                  // it's ignored in any case.
-                                  return !cc.IsAtLeast(
-                                      se::CudaComputeCapability::AMPERE);
-                                },
-                                [](const se::RocmComputeCapability&) {
-                                  return true;  // TODO: not decided yet
-                                }},
-                 config.GetGpuComputeCapability());
-
-  if (update_algorithm) {
-    int64_t new_algorithm{};
-    if (algorithm.has_gemm()) {
-      new_algorithm = algorithm.gemm().algorithm();
-    } else {
-      // NOTE: runtime autotuning is no longer available => set to default
-      new_algorithm = se::blas::kDefaultAlgorithm;
-    }
-
-    if (new_algorithm == old_algorithm &&
-        backend_config.has_selected_algorithm()) {
-      // We don't need to update the backend config if
-      // the algorithm hasn't changed unless previously
-      // the algorithm wasn't set explicitly.
-      return false;
-    }
-
-    backend_config.set_selected_algorithm(new_algorithm);
-    TF_RETURN_IF_ERROR(gemm->set_backend_config(gpu_config));
-    return true;  // We changed `gemm`
-  }
-
-  return false;  // No change to `gemm`
-}
-
-absl::StatusOr<bool> RunOnComputation(HloComputation* computation,
-                                      AutotuneConfig config,
-                                      size_t* num_algorithms_left) {
-  bool changed = false;
-
-  for (HloInstruction* instr : computation->instructions()) {
-    if (IsCublasGemm(*instr)) {
-      size_t num_left;
-      TF_ASSIGN_OR_RETURN(bool result,
-                          RunOnInstruction(instr, config, &num_left));
-      // Gathering statistics on the algorithms left after tuning (for testing)
-      *num_algorithms_left = std::max(*num_algorithms_left, num_left);
-      changed |= result;
-    }
-  }
-  return changed;
-}
-
-}  // namespace
-
-absl::StatusOr<bool> GemmAlgorithmPicker::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  XLA_SCOPED_LOGGING_TIMER(
-      absl::StrCat("GemmAlgorithmPicker for ", module->name()));
-
-  num_algorithms_left_ = 0;
-  if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
-    VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early";
-    return false;
-  }
-
-  bool changed = false;
-  for (HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, config_,
-                                                      &num_algorithms_left_));
-    changed |= result;
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h
deleted file mode 100644
index be2686d..0000000
--- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h
+++ /dev/null
@@ -1,70 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_
-#define XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_
-
-#include <functional>
-#include <optional>
-#include <string_view>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/autotune_results.pb.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/blas.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream_executor.h"
-
-namespace xla {
-namespace gpu {
-
-// GemmAlgorithmPicker supports two modes: device and deviceless.
-// In device mode, we run autotuning on the device and store autotune results.
-// In deviceless mode, we pass in some information related to the device and
-// use stored autotune results to rewrite Gemm instructions. If the required
-// autotune result is not stored, then algorithm is set to kRuntimeAutotuning.
-class GemmAlgorithmPicker : public HloModulePass {
- public:
-  explicit GemmAlgorithmPicker(AutotuneConfig config) : config_(config) {}
-
-  absl::string_view name() const override { return "gemm-algorithm-picker"; }
-
-  size_t num_algorithms_left() const { return num_algorithms_left_; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  AutotuneConfig config_;
-  // The number of valid algorithms used for autotuning (from the last call),
-  // to be used for testing purposes.
-  size_t num_algorithms_left_ = 0;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_
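For orientation, a minimal usage sketch of the pass declared in the deleted header above, mirroring the deleted test that follows. The function name and the way `stream_exec` is obtained are assumptions for illustration, not part of the patch; device mode is shown (a DeviceConfig with a live StreamExecutor and no custom allocator).

  // Assumes the includes listed in the deleted gemm_algorithm_picker.h above.
  absl::StatusOr<bool> PickGemmAlgorithms(
      xla::HloModule* module, stream_executor::StreamExecutor* stream_exec) {
    const xla::DebugOptions& debug_opts = module->config().debug_options();
    xla::gpu::AutotuneConfig cfg{
        xla::gpu::DeviceConfig{stream_exec, /*allocator=*/nullptr}, debug_opts};
    xla::gpu::GemmAlgorithmPicker picker(cfg);
    // True iff at least one GEMM custom call had its algorithm rewritten.
    return picker.Run(module, /*execution_threads=*/{});
  }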
diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc
deleted file mode 100644
index 8049aa1..0000000
--- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc
+++ /dev/null
@@ -1,301 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_algorithm_picker.h"
-
-#include <cstdint>
-#include <variant>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/gemm_rewriter.h"
-#include "xla/service/gpu/variant_visitor.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/service/platform_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/platform.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-#include "tsl/protobuf/dnn.pb.h"
-
-namespace xla::gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class GemmAlgorithmPickerTest : public HloTestBase,
-                                public ::testing::WithParamInterface<bool> {
- public:
-  GemmAlgorithmPickerTest() { AutotunerUtil::ClearAutotuneResults(); }
-
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_enable_cublaslt(GetParam());
-    debug_options.set_xla_gpu_enable_triton_gemm(false);
-    return debug_options;
-  }
-
-  const se::DeviceDescription& device_desc() {
-    return backend().default_stream_executor()->GetDeviceDescription();
-  }
-
-  se::StreamExecutor* stream_exec() {
-    return backend().default_stream_executor();
-  }
-  const se::DeviceDescription& gpu_device_desc() {
-    return stream_exec()->GetDeviceDescription();
-  }
-  const se::GpuComputeCapability& gpu_comp() {
-    return gpu_device_desc().gpu_compute_capability();
-  }
-
-  void SetUp() override {
-    std::string_view name =
-        ::testing::UnitTest::GetInstance()->current_test_info()->name();
-    // We need special handling for BlasGetVersion test.
-    bool blas_get_version = name.rfind("BlasGetVersion") == 0;
-
-    std::visit(
-        VariantVisitor{
-            [&](const se::CudaComputeCapability& cc) {
-              if (!blas_get_version && cc.IsAtLeastAmpere()) {
-                GTEST_SKIP()
-                    << "Skipping this test for Ampere+ as it is supported "
-                       "and recommended with the Nvidia Volta+ GPUs.";
-              }
-            },
-            [&](const se::RocmComputeCapability& cc) {
-              if (blas_get_version) {
-                auto version = std::stol(device_desc().runtime_version());
-                if (version < 60200) {
-                  GTEST_SKIP()
-                      << "This API is not available on ROCM 6.1 and below.";
-                }
-              } else if (GetDebugOptionsForTest().xla_gpu_enable_cublaslt() &&
-                         !cc.has_hipblaslt()) {
-                GTEST_SKIP() << "No gpublas-lt support on this architecture!";
-              }
-            }},
-        gpu_comp());
-  }
-};
-
-TEST_P(GemmAlgorithmPickerTest, BlasGetVersion) {
-  auto* blas = backend().default_stream_executor()->AsBlas();
-  ASSERT_TRUE(blas != nullptr);
-  std::string version;
-  ASSERT_TRUE(blas->GetVersion(&version).ok());
-  VLOG(0) << "Blas version: " << version;
-  ASSERT_TRUE(!version.empty());
-}
-
-TEST_P(GemmAlgorithmPickerTest, SkipAlgorithmsWithAccuracyCheck) {
-  constexpr absl::string_view kHlo = R"(
-HloModule module
-
-ENTRY main {
-  %arg0 = f32[100,100]{1,0} parameter(0)
-  %arg1 = f32[100,100]{1,0} parameter(1)
-  ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-
-  auto module_cfg = GetModuleConfigForTest();
-  auto debug_opts = module_cfg.debug_options();
-  size_t num_left1 = 0, num_left2 = 0;
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kHlo, module_cfg));
-
-  {
-    // Run first with default settings (autotune level = 4) and record the
-    // number of algorithms left after autotuning.
-    TF_ASSERT_OK_AND_ASSIGN(
-        bool changed,
-        RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040),
-                   module.get()));
-
-    AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts};
-    GemmAlgorithmPicker gpicker(cfg);
-    // Note that we do not care whether the algorithm index has changed:
-    // what matters is the number of algorithms left after filtering.
-    TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get()));
-    num_left1 = gpicker.num_algorithms_left();
-    if (num_left1 < 2) {
-      GTEST_SKIP() << "Too few algorithms left after the first step";
-    }
-  }
-
-  // Clear cache before the second run!
-  AutotunerUtil::ClearAutotuneResults();
-  {
-    // Run once again, but now with autotune level 5 and an embarrassingly
-    // tight rtol, which should disqualify most of the algorithms.
-
-    // Note that, we have "two sources of truth" for GemmAlgorithmPicker: i.e.,
-    // debug_options are used to initialize both 'HloModuleConfig' and also
-    // 'AutotuneConfig'.
-    debug_opts.set_xla_gpu_autotune_gemm_rtol(1e-12);
-    debug_opts.set_xla_gpu_autotune_level(5);
-    module->mutable_config().set_debug_options(debug_opts);
-    TF_ASSERT_OK_AND_ASSIGN(
-        bool changed,
-        RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040),
-                   module.get()));
-
-    AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts};
-    GemmAlgorithmPicker gpicker(cfg);
-    TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get()));
-    num_left2 = gpicker.num_algorithms_left();
-  }
-  // Assert that we have fewer algorithms left after the second run.
-  ASSERT_TRUE(num_left1 > num_left2);
-}
-
-TEST_P(GemmAlgorithmPickerTest, SetAlgorithm) {
-  constexpr absl::string_view kHlo = R"(
-HloModule module
-
-ENTRY main {
-  %arg0 = f32[100,100]{1,0} parameter(0)
-  %arg1 = f32[100,100]{1,0} parameter(1)
-  ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-
-  auto module_cfg = GetModuleConfigForTest();
-  TF_ASSERT_OK_AND_ASSIGN(auto m,
-                          ParseAndReturnVerifiedModule(kHlo, module_cfg));
-
-  bool changed = false;
-  TF_ASSERT_OK_AND_ASSIGN(
-      changed,
-      RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), m.get()));
-  changed = false;
-  DebugOptions opts;
-  AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, opts};
-  TF_ASSERT_OK_AND_ASSIGN(changed,
-                          RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
-  ASSERT_TRUE(changed);
-
-  AutotuneResults results;
-  TF_ASSERT_OK(AutotunerUtil::SerializeAutotuneResults(&results));
-  ASSERT_EQ(results.results_size(), 1);
-  auto& result = *results.mutable_results(0)->mutable_result();
-  int64_t old_algo_id = result.algorithm().algo_id();
-  int64_t new_algo_id = old_algo_id + 1;
-  result.mutable_gemm()->set_algorithm(new_algo_id);
-
-  AutotunerUtil::ClearAutotuneResults();
-  TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results));
-
-  // Now send the same module through GemmAlgorithmPicker again.  The dot should
-  // have the new algorithm.
-  TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg));
-  changed = false;
-  TF_ASSERT_OK_AND_ASSIGN(
-      changed,
-      RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), m.get()));
-  changed = false;
-  TF_ASSERT_OK_AND_ASSIGN(changed,
-                          RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
-  ASSERT_TRUE(changed);
-
-  SCOPED_TRACE(m->ToString());
-  HloInstruction* dot;
-  ASSERT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0)));
-
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          dot->backend_config<GpuBackendConfig>());
-  const GemmBackendConfig& config = gpu_config.gemm_backend_config();
-  EXPECT_EQ(config.selected_algorithm(), new_algo_id);
-}
-
-TEST_P(GemmAlgorithmPickerTest, GetAlgorithmWithoutDevice) {
-  constexpr absl::string_view kHlo = R"(
-HloModule module
-
-ENTRY main {
-  %arg0 = f32[100,100]{1,0} parameter(0)
-  %arg1 = f32[100,100]{1,0} parameter(1)
-  ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto m, ParseAndReturnVerifiedModule(kHlo, GetModuleConfigForTest()));
-
-  bool changed = false;
-  TF_ASSERT_OK_AND_ASSIGN(
-      changed,
-      RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), m.get()));
-  changed = false;
-
-  DebugOptions opts;
-  AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, opts};
-
-  TF_ASSERT_OK_AND_ASSIGN(changed,
-                          RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
-  ASSERT_TRUE(changed);
-
-  AutotuneResults results;
-  TF_ASSERT_OK(AutotunerUtil::SerializeAutotuneResults(&results));
-  ASSERT_EQ(results.results_size(), 1);
-  auto& result = *results.mutable_results(0)->mutable_result();
-  int64_t old_algo_id = result.algorithm().algo_id();
-  int64_t new_algo_id = old_algo_id + 1;
-  result.mutable_gemm()->set_algorithm(new_algo_id);
-
-  AutotunerUtil::ClearAutotuneResults();
-  TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results));
-
-  auto module_cfg = GetModuleConfigForTest();
-  // Now send the same module through GemmAlgorithmPicker again.  The dot should
-  // have the new algorithm.
-  TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg));
-  changed = false;
-
-  DevicelessConfig deviceless_config{gpu_device_desc()};
-  AutotuneConfig deviceless_cfg{deviceless_config, opts};
-  TF_ASSERT_OK_AND_ASSIGN(changed,
-                          RunHloPass(GemmRewriter(gpu_comp(),
-                                                  /*toolkit_version=*/12040),
-                                     m.get()));
-  changed = false;
-  TF_ASSERT_OK_AND_ASSIGN(
-      changed, RunHloPass(GemmAlgorithmPicker(deviceless_cfg), m.get()));
-  ASSERT_TRUE(changed);
-
-  SCOPED_TRACE(m->ToString());
-  HloInstruction* dot;
-
-  ASSERT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0)));
-
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          dot->backend_config<GpuBackendConfig>());
-  const GemmBackendConfig& config = gpu_config.gemm_backend_config();
-
-  EXPECT_EQ(config.selected_algorithm(), new_algo_id);
-}
-
-INSTANTIATE_TEST_SUITE_P(GemmAlgorithmPickerTestSuite, GemmAlgorithmPickerTest,
-                         ::testing::Bool());
-
-}  // namespace
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.cc b/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.cc
deleted file mode 100644
index a6cbbf11..0000000
--- a/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.cc
+++ /dev/null
@@ -1,124 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_broadcast_folding_rewriter.h"
-
-#include <cstdint>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/pattern_matcher.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace m = match;
-
-class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor {
- public:
-  absl::Status HandleCustomCall(HloInstruction *instr) override {
-    HloInstruction *existing_gemm;
-    HloInstruction *bcast;
-    if (Match(instr, m::CustomCall(&existing_gemm,
-                                   {kGemmCallTarget, kCublasLtMatmulCallTarget})
-                         .WithOperand(0, m::Broadcast(&bcast, m::Op()))) ||
-        (Match(instr, m::CustomCall(&existing_gemm, {kGemmCallTarget,
-                                                     kCublasLtMatmulCallTarget})
-                          .WithOperand(1, m::Broadcast(&bcast, m::Op()))))) {
-      TF_ASSIGN_OR_RETURN(auto gpu_config,
-                          existing_gemm->backend_config<GpuBackendConfig>());
-      GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
-      DotDimensionNumbers *dim_nums = config.mutable_dot_dimension_numbers();
-      int bcast_operand_index = instr->operand_index(bcast);
-      int num_bcast_dims = (bcast->shape().dimensions_size() -
-                            bcast->operand(0)->shape().dimensions_size());
-      int num_batch_dims = dim_nums->lhs_batch_dimensions_size();
-
-      const tsl::protobuf::RepeatedField<int64_t> &batch_dimensions =
-          (bcast_operand_index == 1) ? dim_nums->rhs_batch_dimensions()
-                                     : dim_nums->lhs_batch_dimensions();
-      // This optimization is only valid if the set of broadcasted dimensions
-      // is exactly the set of batch dimensions. First, check that all newly
-      // broadcast dimensions have been inserted on the left, i.e. all new
-      // dimensions must be in [0, num_bcast_dims) or, equivalently, all
-      // original dimensions are >= num_bcast_dims.
-      for (int64_t bcast_dim : bcast->dimensions()) {
-        if (bcast_dim < num_bcast_dims) {
-          return absl::OkStatus();
-        }
-        // bcast_dim should not be in batch_dimensions.
-        if (absl::c_linear_search(batch_dimensions, bcast_dim)) {
-          return absl::OkStatus();
-        }
-      }
-
-      // Then check that all batch dimensions are being broadcast, and that
-      // there is at least one batch dimension.
-      CHECK_GT(num_bcast_dims, 0);
-      if (num_bcast_dims != num_batch_dims) {
-        return absl::OkStatus();
-      }
-
-      if (bcast_operand_index == 1) {
-        CHECK_EQ(dim_nums->rhs_contracting_dimensions_size(), 1);
-        dim_nums->set_rhs_contracting_dimensions(
-            0, dim_nums->rhs_contracting_dimensions(0) - num_batch_dims);
-        dim_nums->clear_rhs_batch_dimensions();
-      } else {
-        CHECK_EQ(dim_nums->lhs_contracting_dimensions_size(), 1);
-        dim_nums->set_lhs_contracting_dimensions(
-            0, dim_nums->lhs_contracting_dimensions(0) - num_batch_dims);
-        dim_nums->clear_lhs_batch_dimensions();
-      }
-      TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWithDifferentShape(
-          bcast_operand_index, bcast->mutable_operand(0)));
-      TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config));
-      MarkAsChanged();
-    }
-    return absl::OkStatus();
-  }
-};
-
-static absl::StatusOr<bool> RunOnComputation(HloComputation *computation) {
-  GemmBroadcastFoldingVisitor visitor;
-  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
-  return visitor.changed();
-}
-
-absl::StatusOr<bool> GemmBroadcastFoldingRewriter::Run(
-    HloModule *module,
-    const absl::flat_hash_set<absl::string_view> &execution_threads) {
-  bool changed = false;
-  for (HloComputation *computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
-    changed |= result;
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.h b/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.h
deleted file mode 100644
index bac14bc..0000000
--- a/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.h
+++ /dev/null
@@ -1,51 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_GEMM_BROADCAST_FOLDING_REWRITER_H_
-#define XLA_SERVICE_GPU_GEMM_BROADCAST_FOLDING_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// cuBLAS GEMM has support for strided batched calls, where the stride is used
-// to determine the offset between the batches.
-//
-// This allows (kCustomCall:gemm A kBroadcast(B)) or
-//             (kCustomCall:gemm kBroadcast(A) B)
-// to be rewritten as (kCustomCall:gemm A B) with a zero stride for the
-// broadcasted operand if the broadcast operates on all the batch dimensions.
-//
-// This pattern matches the above case and removes the unnecessary broadcast.
-class GemmBroadcastFoldingRewriter : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "cublas-gemm-broadcast-folding-rewriter";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GEMM_BROADCAST_FOLDING_REWRITER_H_
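For concreteness, a worked example of the folding described above (the shapes are illustrative, not taken from this patch): with A = f32[8,64,32] and B = f32[32,16] broadcast to f32[8,32,16] by inserting batch dimension 0, the pattern (kCustomCall:gemm A kBroadcast(B)) is rewritten to (kCustomCall:gemm A B). The RHS batch dimension is dropped from the dot dimension numbers and the RHS contracting dimension index shifts down by the number of batch dimensions (here from 1 to 0), so the strided-batched cuBLAS call reads B for every batch with a zero stride.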
diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.cc b/third_party/xla/xla/service/gpu/gemm_fusion.cc
deleted file mode 100644
index 4e37a4b..0000000
--- a/third_party/xla/xla/service/gpu/gemm_fusion.cc
+++ /dev/null
@@ -1,815 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_fusion.h"
-
-#include <array>
-#include <cstddef>
-#include <cstdint>
-#include <optional>
-#include <queue>
-#include <string>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_padding_requirements.h"
-#include "xla/service/gpu/fusions/triton/triton_support.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/triton_fusion_analysis.h"
-#include "xla/service/gpu/triton_tiling_propagation.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-using triton_fusion::CombineDotRequirements;
-using triton_fusion::DimensionOrder;
-using triton_fusion::DimOrderMap;
-using triton_fusion::DimOrdersAndReqs;
-using triton_fusion::DimOrdersAndReqsOrError;
-using triton_fusion::DotProperties;
-using triton_fusion::DotRequirements;
-using triton_fusion::DotRequirementsOrError;
-using triton_fusion::FusionContext;
-using triton_fusion::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible;
-using triton_fusion::TransformDirection;
-
-// This represents a directed graph.
-class AdjacencyList {
- public:
-  using NodeId = int64_t;
-
-  NodeId AddNode() {
-    adj_.emplace_back();
-    return adj_.size() - 1;
-  }
-
-  const std::vector<NodeId>& GetOutNeighbors(NodeId node_id) const {
-    return adj_.at(node_id);
-  }
-
-  void ReserveSpaceForOutNeighbors(NodeId node_id, size_t count) {
-    adj_.at(node_id).reserve(count);
-  }
-
-  void AddArc(NodeId from, NodeId to) { adj_.at(from).push_back(to); }
-
-  // Currently the Root node is the node which was added first.
-  NodeId GetRoot() const {
-    CHECK(!adj_.empty());
-    return 0;
-  }
-
- private:
-  // Adjacency list: A vector of out-neighbors for each node.
-  std::vector<std::vector<NodeId>> adj_;
-};
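// Illustrative sketch (not part of the original file): how the adjacency list
// above is typically populated while planning a fusion. Node 0 is the root
// (the hero dot); arcs point from users to operands.
//
//   AdjacencyList graph;
//   AdjacencyList::NodeId dot = graph.AddNode();  // GetRoot() == dot == 0
//   AdjacencyList::NodeId lhs = graph.AddNode();
//   AdjacencyList::NodeId rhs = graph.AddNode();
//   graph.AddArc(dot, lhs);
//   graph.AddArc(dot, rhs);
//   // graph.GetOutNeighbors(dot) now contains {lhs, rhs}.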
-
-struct HloAndDimOrder {
-  const HloInstruction* original_hlo = nullptr;
-  DimensionOrder dim_order;
-};
-
-struct HloAndIterSpec {
-  const HloInstruction* original_hlo;
-  TensorIterationSpec iter_spec;
-
-  auto ToTuple() const { return std::make_tuple(original_hlo, iter_spec); }
-  bool operator==(const HloAndIterSpec& other) const {
-    return ToTuple() == other.ToTuple();
-  }
-  template <typename H>
-  friend H AbslHashValue(H h, const HloAndIterSpec& key) {
-    return H::combine(std::move(h), key.ToTuple());
-  }
-};
-
-struct NodeFusionPlan {
-  const HloInstruction* original_hlo = nullptr;
-  bool should_fuse = false;
-};
-
-struct FusionPlan {
-  // The graph describing the structure of the fusion that we build - nodes
-  // corresponding to the instructions and arcs pointing from users to operands.
-  AdjacencyList graph;
-  // The fusion plan for each node.
-  absl::flat_hash_map<AdjacencyList::NodeId, NodeFusionPlan> map;
-};
-
-struct FusionPlanAndRequirements {
-  FusionPlan fusion_plan;
-  DotRequirements requirements;
-};
-
-struct HlosAndRequirements {
-  // The original HLO (which is outside the fusion computation).
-  const HloInstruction* original_hlo = nullptr;
-  // The fused HLO inside the new fusion computation, built by the builder.
-  //
-  // This can have the same opcode as `original_hlo` or it can be a parameter if
-  // the original HLO can't be fused.
-  const HloInstruction* fused_hlo = nullptr;
-  // The requirements imposed by the fused operations.
-  //
-  // If we fuse further operations they may have to conform to these
-  // requirements.
-  DotRequirements requirements;
-};
-
-// Clones the hero kDot operation into the fusion.
-HloInstruction& FuseDot(const HloDotInstruction& dot,
-                        const HloInstruction& fused_lhs,
-                        const HloInstruction& fused_rhs,
-                        std::optional<const HloInstruction*> fused_meta,
-                        HloComputation::Builder& builder  // append
-) {
-  VLOG(3) << "Fusing " << dot.ToString();
-
-  std::vector<HloInstruction*> hlo_new_operands = {
-      const_cast<HloInstruction*>(&fused_lhs),
-      const_cast<HloInstruction*>(&fused_rhs)};
-  if (fused_meta.has_value()) {
-    hlo_new_operands.push_back(const_cast<HloInstruction*>(fused_meta.value()));
-  }
-  return *builder.AddInstruction(
-      dot.CloneWithNewOperands(dot.shape(), hlo_new_operands));
-}
-
-// Tells how many new parameters a fusion gains by fusing the operation as
-// an input.
-int64_t NumAddedParameters(const HloInstruction& hlo) {
-  // Non-scalar constant is equivalent to a parameter: one input, one output.
-  if (hlo.opcode() == HloOpcode::kParameter ||
-      (hlo.opcode() == HloOpcode::kConstant &&
-       !ShapeUtil::IsScalar(hlo.shape()))) {
-    return 0;
-  }
-  // All other instructions add all of their own inputs and remove their own
-  // single output.
-  return hlo.operand_count() - 1;
-}
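// Worked example (illustrative, not part of the original file): fusing a
// binary add(x, y) as an input brings in its two operands and removes its own
// single output from the fusion boundary, a net gain of 2 - 1 = 1 parameter;
// fusing a kParameter or a non-scalar kConstant gains 0, since it is itself
// one input producing one output.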
-
-// Just a helper to reduce "unwrapping" code where we use this.
-std::optional<DimOrdersAndReqs> GetOperandDimOrdersAndCombinedReqs(
-    const HloInstruction& hlo, const DimensionOrder& dim_order,
-    const DotProperties& properties,
-    const se::GpuComputeCapability& gpu_version,
-    const DotRequirements& requirements) {
-  DimOrdersAndReqsOrError dim_orders_and_new_reqs =
-      GetPropagatedDimOrdersAndRequirements(
-          hlo, dim_order, TransformDirection::kOutputToInput, properties);
-  if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
-    return std::nullopt;
-  }
-  DotRequirementsOrError combined_reqs = CombineDotRequirements(
-      requirements,
-      std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
-  if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
-    return std::nullopt;
-  }
-  return DimOrdersAndReqs{
-      std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).dim_orders,
-      std::get<DotRequirements>(combined_reqs)};
-}
-
-// Just a helper to reduce "unwrapping" code where we use this.
-std::optional<DimOrdersAndReqs> GetOperandDimOrdersAndCombinedReqsIfProfitable(
-    const HloInstruction& hlo, const DimensionOrder& dim_order,
-    const DotProperties& properties,
-    const se::GpuComputeCapability& gpu_version,
-    const DotRequirements& requirements) {
-  DimOrdersAndReqsOrError dim_orders_and_new_reqs =
-      GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
-          hlo, TransformDirection::kOutputToInput,
-          /*src_operand_index=*/std::nullopt, dim_order, gpu_version,
-          properties);
-  if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
-    return std::nullopt;
-  }
-  DotRequirementsOrError combined_reqs = CombineDotRequirements(
-      requirements,
-      std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
-  if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
-    return std::nullopt;
-  }
-  return DimOrdersAndReqs{
-      std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).dim_orders,
-      std::get<DotRequirements>(combined_reqs)};
-}
-
-// Just a helper to reduce "unwrapping" code where we use this.
-std::optional<DimOrdersAndReqs> GetUserDimOrdersAndCombinedReqsIfProfitable(
-    const HloInstruction& hlo, const DimensionOrder& hlo_dim_order,
-    const HloInstruction& user, const DotProperties& properties,
-    const se::GpuComputeCapability& gpu_version,
-    const DotRequirements& requirements) {
-  DimOrdersAndReqsOrError dim_orders_and_new_reqs =
-      GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
-          user, TransformDirection::kInputToOutput, user.operand_index(&hlo),
-          hlo_dim_order, gpu_version, properties);
-  if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
-    return std::nullopt;
-  }
-  DotRequirementsOrError combined_reqs = CombineDotRequirements(
-      requirements,
-      std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
-  if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
-    return std::nullopt;
-  }
-  return DimOrdersAndReqs{
-      std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).dim_orders,
-      std::get<DotRequirements>(combined_reqs)};
-}
-
-// Builds the fusion map and the requirements which can later be used to
-// actually fuse that subgraph.
-FusionPlanAndRequirements BuildFusionPlanTowardOperands(
-    const HloInstruction& root_hlo, const DimensionOrder& root_dim_order,
-    const std::optional<int>& max_params,
-    const se::GpuComputeCapability& gpu_version,
-    const DotProperties& properties,
-    const DotRequirements& requirements_so_far) {
-  CHECK(!max_params.has_value() || max_params.value() >= 1);
-
-  // The graph describing the structure of the fusion that we build - nodes
-  // corresponding to the instructions and arcs pointing from users to operands.
-  // We can build and modify this graph easily without the need to create
-  // HloInstructions at this point.
-  AdjacencyList graph;
-  // Stores the original HLO and the dimension order for each node. This is a
-  // temporary map which is used when processing the nodes in this function.
-  absl::flat_hash_map<AdjacencyList::NodeId, HloAndDimOrder>
-      hlo_and_dim_order_map;
-  // Stores the information needed to build the fused HLO for each node (what
-  // was the original HLO and whether we should fuse it or create a parameter).
-  // This is one of the outputs of this function.
-  absl::flat_hash_map<AdjacencyList::NodeId, NodeFusionPlan> fusion_plan_map;
-  // Allows reusing nodes when multiple instructions iterate over the same HLO
-  // using the same iteration spec. In that case we don't duplicate the
-  // instruction in the fusion.
-  absl::flat_hash_map<HloAndIterSpec, AdjacencyList::NodeId> node_reuse_map;
-  // The requirements imposed by the fusion choices made in this function,
-  // combined with the existing requirements. This is one of the outputs of this
-  // function.
-  DotRequirements combined_reqs = requirements_so_far;
-
-  auto get_or_create_fusion_node =
-      [&](const HloInstruction& hlo, const DimensionOrder& dim_order,
-          bool* is_new_node = nullptr) -> AdjacencyList::NodeId {
-    HloAndIterSpec reuse_key = {&hlo, dim_order.ToTensorIterationSpec()};
-    if (auto it = node_reuse_map.find(reuse_key); it != node_reuse_map.end()) {
-      if (is_new_node != nullptr) {
-        *is_new_node = false;
-      }
-      return it->second;
-    }
-    AdjacencyList::NodeId node_id = graph.AddNode();
-    CHECK(hlo_and_dim_order_map.insert({node_id, {&hlo, dim_order}}).second);
-    CHECK(node_reuse_map.insert({reuse_key, node_id}).second);
-    if (is_new_node != nullptr) {
-      *is_new_node = true;
-    }
-    return node_id;
-  };
-  AdjacencyList::NodeId root =
-      get_or_create_fusion_node(root_hlo, root_dim_order);
-
-  // Nodes at the fusion edge that can either get fused too or become parameters
-  // of the fusion. Used to track the number of parameters.
-  absl::flat_hash_set<AdjacencyList::NodeId> inputs({root});
-  std::queue<AdjacencyList::NodeId> queue({root});
-  int64_t num_requeued = 0;
-  // BFS
-  while (queue.size() > num_requeued) {
-    AdjacencyList::NodeId node_id = queue.front();
-    queue.pop();
-    const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at(node_id);
-    const HloInstruction& original_hlo = *hlo_and_dim_order.original_hlo;
-    const DimensionOrder& dim_order = hlo_and_dim_order.dim_order;
-
-    // Watch the total number of fusion parameters.
-    if (max_params.has_value() &&
-        inputs.size() + NumAddedParameters(original_hlo) > max_params.value()) {
-      // Re-queue: the number of parameters may go down when other instructions
-      // are processed.
-      queue.push(node_id);
-      // Prevent infinite loops.
-      ++num_requeued;
-      continue;
-    }
-    num_requeued = 0;
-    if (original_hlo.opcode() == HloOpcode::kParameter) {
-      CHECK(fusion_plan_map
-                .insert({node_id, {&original_hlo, /*should_fuse=*/false}})
-                .second);
-      continue;
-    }
-    auto opt_result = GetOperandDimOrdersAndCombinedReqsIfProfitable(
-        original_hlo, dim_order, properties, gpu_version, combined_reqs);
-    if (!opt_result.has_value()) {
-      CHECK(fusion_plan_map
-                .insert({node_id, {&original_hlo, /*should_fuse=*/false}})
-                .second);
-      continue;
-    }
-    const DimOrderMap operand_dim_orders = std::move(opt_result->dim_orders);
-    combined_reqs = std::move(opt_result->requirements);
-    inputs.erase(node_id);
-    graph.ReserveSpaceForOutNeighbors(node_id, original_hlo.operand_count());
-    for (int64_t i = 0; i < original_hlo.operand_count(); ++i) {
-      const HloInstruction& operand = *original_hlo.operand(i);
-      const DimensionOrder& operand_dim_order = operand_dim_orders.at(&operand);
-      bool is_new_node = false;
-      AdjacencyList::NodeId operand_node_id =
-          get_or_create_fusion_node(operand, operand_dim_order, &is_new_node);
-      graph.AddArc(node_id, operand_node_id);
-      if (is_new_node) {
-        VLOG(6) << "Enqueueing " << operand.ToString() << ":"
-                << operand_dim_order.ToString();
-        inputs.insert(operand_node_id);
-        queue.push(operand_node_id);
-      }
-    }
-    CHECK(
-        fusion_plan_map.insert({node_id, {&original_hlo, /*should_fuse=*/true}})
-            .second);
-  }
-  // Handle the remaining requeued items.
-  while (!queue.empty()) {
-    AdjacencyList::NodeId node_id = queue.front();
-    queue.pop();
-
-    const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at(node_id);
-    CHECK(fusion_plan_map
-              .insert({node_id,
-                       {hlo_and_dim_order.original_hlo, /*should_fuse=*/false}})
-              .second);
-  }
-  return {{std::move(graph), std::move(fusion_plan_map)},
-          std::move(combined_reqs)};
-}
-
-// Builds the HLO instructions for the fusion represented by `fusion_plan`,
-// starting from `node_id`.
-HloInstruction& BuildFusionTowardOperandsImpl(
-    AdjacencyList::NodeId node_id, const FusionPlan& fusion_plan,
-    absl::flat_hash_map<AdjacencyList::NodeId, HloInstruction*>&
-        fused_hlo_map,                           // read/append
-    HloComputation::Builder& builder,            // append
-    std::vector<HloInstruction*>& fusion_params  // append
-) {
-  if (auto it = fused_hlo_map.find(node_id); it != fused_hlo_map.end()) {
-    return *it->second;
-  }
-
-  const NodeFusionPlan& node_fusion_plan = fusion_plan.map.at(node_id);
-  const bool should_fuse = node_fusion_plan.should_fuse;
-  const HloInstruction& original_hlo = *node_fusion_plan.original_hlo;
-
-  HloInstruction* fused_hlo = nullptr;
-  if (should_fuse) {
-    HloInstruction::InstructionVector new_operands;
-    for (AdjacencyList::NodeId operand_id :
-         fusion_plan.graph.GetOutNeighbors(node_id)) {
-      new_operands.push_back(&BuildFusionTowardOperandsImpl(
-          operand_id, fusion_plan, fused_hlo_map, builder, fusion_params));
-    }
-    fused_hlo = builder.AddInstruction(
-        original_hlo.CloneWithNewOperands(original_hlo.shape(), new_operands));
-  } else {
-    fusion_params.push_back(const_cast<HloInstruction*>(&original_hlo));
-    fused_hlo = builder.AddInstruction(HloInstruction::CreateParameter(
-        fusion_params.size() - 1, original_hlo.shape(),
-        absl::StrCat("parameter_", fusion_params.size() - 1)));
-  }
-
-  CHECK(fused_hlo_map.insert({node_id, fused_hlo}).second);
-  return *fused_hlo;
-}
-
-// Builds the HLO instructions for the fusion represented by `fusion_plan`.
-HloInstruction& BuildFusionTowardOperands(
-    const FusionPlan& fusion_plan,
-    HloComputation::Builder& builder,            // append
-    std::vector<HloInstruction*>& fusion_params  // append
-) {
-  absl::flat_hash_map<AdjacencyList::NodeId, HloInstruction*> fused_hlo_map;
-  return BuildFusionTowardOperandsImpl(fusion_plan.graph.GetRoot(), fusion_plan,
-                                       fused_hlo_map, builder, fusion_params);
-}
-
-// Grows the fusion toward the operands.
-//
-// This always succeeds.
-//
-// If it's not possible to fuse something, it fuses a parameter instead.
-//
-// The fusion can grow until it has `max_params` params, and it can only grow
-// with operations for which the DimOrder propagation works and which don't
-// impose requirements contradicting the existing requirements.
-//
-// The return value contains the HLOs corresponding to `root_hlo` and the
-// requirements corresponding to the whole fusion so far.
-HlosAndRequirements FuseTowardOperands(
-    const HloInstruction& root_hlo, const DimensionOrder& root_dim_order,
-    const std::optional<int>& max_params,
-    const se::GpuComputeCapability& gpu_version,
-    const DotProperties& properties, const DotRequirements& requirements_so_far,
-    HloComputation::Builder& builder,            // append
-    std::vector<HloInstruction*>& fusion_params  // append
-) {
-  FusionPlanAndRequirements fusion_plan_and_reqs =
-      BuildFusionPlanTowardOperands(root_hlo, root_dim_order, max_params,
-                                    gpu_version, properties,
-                                    requirements_so_far);
-  HloInstruction& fused_hlo_or_param = BuildFusionTowardOperands(
-      fusion_plan_and_reqs.fusion_plan, builder, fusion_params);
-  return HlosAndRequirements{&root_hlo, &fused_hlo_or_param,
-                             fusion_plan_and_reqs.requirements};
-}
-
-// Grows the fusion toward the given dot operand.
-//
-// This always succeeds.
-//
-// If it's not possible to fuse something, it fuses a parameter instead.
-//
-// The fusion can grow until it has `max_params` params, and it can only grow
-// with operations for which the DimOrder propagation works and which don't
-// impose requirements contradicting the existing requirements.
-//
-// The return value contains the HLOs corresponding to the given dot operand and
-// the requirements corresponding to the whole fusion so far.
-absl::StatusOr<HlosAndRequirements> FuseDotOperand(
-    const HloInstruction& dot, int operand_index,
-    const se::GpuComputeCapability& gpu_version,
-    HloComputation::Builder& builder,            // append
-    std::vector<HloInstruction*>& fusion_params  // append
-) {
-  // Direct dot inputs have well defined dimension orders.
-  TF_ASSIGN_OR_RETURN(const FusionContext context,
-                      FusionContext::FromDotOperand(dot, operand_index));
-  const HloInstruction& operand = *dot.operand(operand_index);
-  return FuseTowardOperands(operand, context.dim_orders().at(&operand),
-                            TritonFusionAnalysis::kMaxParameterPerDotOperand,
-                            gpu_version, context.dot_properties(),
-                            context.requirements(), builder, fusion_params);
-}
-
-// Grows the fusion toward the users.
-//
-// This always succeeds.
-//
-// The fusion can grow as long as the DimOrder propagation works and the users
-// don't impose requirements contradicting the existing requirements.
-//
-// The return value contains the HLOs corresponding to the "lowest" fused user
-// or `hlo` if no users can be fused.
-//
-// It also grows the fusion upward, toward the "other" operands of the users,
-// but currently only in special cases, such as binary elementwise operation
-// with broadcast of scalar constant.
-HlosAndRequirements FuseTowardUsers(
-    const HloInstruction& hlo, const HloInstruction& fused_hlo,
-    const DimensionOrder& hlo_dim_order,
-    const se::GpuComputeCapability& gpu_version,
-    const DotProperties& properties, const DotRequirements& requirements,
-    HloComputation::Builder& builder,            // append
-    std::vector<HloInstruction*>& fusion_params  // append
-) {
-  const HlosAndRequirements existing_hlos_and_requirements = {&hlo, &fused_hlo,
-                                                              requirements};
-  if (hlo.user_count() != 1) {
-    return existing_hlos_and_requirements;
-  }
-  const HloInstruction& user = *hlo.users()[0];
-  if (!legacy_triton::IsDistributiveOverAddition(user)) {
-    return existing_hlos_and_requirements;
-  }
-
-  // Get the dim orders for the user.
-  auto opt_user_result = GetUserDimOrdersAndCombinedReqsIfProfitable(
-      hlo, hlo_dim_order, user, properties, gpu_version, requirements);
-  if (!opt_user_result.has_value()) {
-    return existing_hlos_and_requirements;
-  }
-  DimensionOrder user_dim_order = opt_user_result->dim_orders.at(&user);
-  DotRequirements combined_requirements = opt_user_result->requirements;
-
-  HloInstruction::InstructionVector new_operands;
-  if (user.operand_count() == 1) {
-    new_operands.push_back(const_cast<HloInstruction*>(&fused_hlo));
-  } else {
-    // Get the dim orders for the operands of the user.
-    // We shouldn't do a profitability check here; we made that decision in
-    // GetUserDimOrdersAndCombinedReqsIfProfitable.
-    auto opt_operand_result = GetOperandDimOrdersAndCombinedReqs(
-        user, user_dim_order, properties, gpu_version, combined_requirements);
-    // This shouldn't fail, because currently we only encounter this when we
-    // have just propagated down the DimOrders on a binary elementwise
-    // operation (user). In that case propagating up the DimOrders should always
-    // work.
-    if (!opt_operand_result.has_value()) {
-      return existing_hlos_and_requirements;
-    }
-    DimOrderMap operand_dim_orders = opt_operand_result->dim_orders;
-    combined_requirements = opt_operand_result->requirements;
-
-    // Fuse the other operands of the user.
-    for (int i = 0; i < user.operand_count(); ++i) {
-      const HloInstruction& operand = *user.operand(i);
-      if (&operand == &hlo) {
-        new_operands.push_back(const_cast<HloInstruction*>(&fused_hlo));
-      } else {
-        HlosAndRequirements hlos_and_requirements = FuseTowardOperands(
-            operand, operand_dim_orders.at(&operand),
-            /*max_params=*/std::nullopt, gpu_version, properties,
-            combined_requirements, builder, fusion_params);
-        new_operands.push_back(
-            const_cast<HloInstruction*>(hlos_and_requirements.fused_hlo));
-        combined_requirements = hlos_and_requirements.requirements;
-      }
-    }
-  }
-
-  const HloInstruction& fused_user = *builder.AddInstruction(
-      user.CloneWithNewOperands(user.shape(), new_operands));
-  return FuseTowardUsers(user, fused_user, user_dim_order, gpu_version,
-                         properties, combined_requirements, builder,
-                         fusion_params);
-}
-
-// Grows the fusion toward the users of the dot.
-//
-// This always succeeds.
-//
-// The fusion can grow as long as the DimOrder propagation works and the users
-// don't impose requirements contradicting the existing requirements.
-//
-// The return value contains the HLOs corresponding to the "lowest" fused user
-// or `dot` if no users can be fused.
-//
-// It also grows the fusion towards the "other" operands of the users, but
-// currently only in special cases, such as binary elementwise operation with
-// broadcast of scalar constant.
-HlosAndRequirements FuseDotOutput(
-    const HloInstruction& dot, const HloInstruction& fused_dot,
-    const se::GpuComputeCapability& gpu_version,
-    const DotRequirements& requirements,
-    HloComputation::Builder& builder,            // append
-    std::vector<HloInstruction*>& fusion_params  // append
-) {
-  const auto context =
-      FusionContext::FromDotOutput(dot, /*split_k=*/1, requirements);
-  return FuseTowardUsers(dot, fused_dot, context.dim_orders().at(&dot),
-                         gpu_version, context.dot_properties(),
-                         context.requirements(), builder, fusion_params);
-}
-
-// Fuses the dot and the operations around it that are compatible and
-// profitable to fuse into a new fusion computation constructed using the
-// builder. fusion_inputs gets populated with the non-fused instructions that
-// become operands of the call to this fusion. fusion_output_ptr (if not
-// nullptr) gets assigned the original instruction that has to be replaced by
-// the call to the fusion.
-absl::StatusOr<FusionDecision> CreateDotFusion(
-    const HloDotInstruction& dot, const se::GpuComputeCapability gpu_version,
-    HloComputation::Builder& builder,
-    std::vector<HloInstruction*>& fusion_inputs,
-    HloInstruction** fusion_output_ptr) {
-  VLOG(5) << dot.ToString();
-  if (CodegenDecision is_supported =
-          legacy_triton::IsTritonSupportedInstruction(dot, gpu_version);
-      !is_supported) {
-    VLOG(3) << is_supported.Explain();
-    return is_supported;
-  }
-
-  // Verify sparse dot constraints.
-  if (dot.sparse_operands()) {
-    const SparsityDescriptor& descriptor = dot.sparsity().front();
-    if (dot.sparse_operands() != 1 || descriptor.index() != 0) {
-      return InvalidArgument("Sparsity is only supported on left operand");
-    }
-    if (descriptor.type() != SparsityType::SPARSITY_STRUCTURED_N_M ||
-        descriptor.n() != 2 || descriptor.m() != 4) {
-      return InvalidArgument("Only 2:4 structured sparsity is supported");
-    }
-    // DotDimensionSorter pass makes sure the sparse dimension is minor.
-    CHECK_EQ(descriptor.dimension(), dot.operand(0)->shape().rank() - 1);
-  }
-
-  TF_ASSIGN_OR_RETURN(HlosAndRequirements lhs_hlos_and_reqs,
-                      FuseDotOperand(dot, /*operand_index=*/0, gpu_version,
-                                     builder, fusion_inputs));
-  TF_ASSIGN_OR_RETURN(HlosAndRequirements rhs_hlos_and_reqs,
-                      FuseDotOperand(dot, /*operand_index=*/1, gpu_version,
-                                     builder, fusion_inputs));
-  std::optional<const HloInstruction*> meta_hlo;
-  if (dot.sparse_operands()) {
-    TF_ASSIGN_OR_RETURN(HlosAndRequirements meta_hlos_and_reqs,
-                        FuseDotOperand(dot, /*operand_index=*/2, gpu_version,
-                                       builder, fusion_inputs));
-    meta_hlo.emplace(meta_hlos_and_reqs.fused_hlo);
-  }
-  HloInstruction& fused_dot =
-      FuseDot(dot, *lhs_hlos_and_reqs.fused_hlo, *rhs_hlos_and_reqs.fused_hlo,
-              meta_hlo, builder);
-  // For now the RHS doesn't support splits, so it also doesn't impose any
-  // requirements.
-  HlosAndRequirements fused_output_and_reqs =
-      FuseDotOutput(dot, fused_dot, gpu_version, lhs_hlos_and_reqs.requirements,
-                    builder, fusion_inputs);
-
-  if (fusion_output_ptr != nullptr) {
-    *fusion_output_ptr =
-        const_cast<HloInstruction*>(fused_output_and_reqs.original_hlo);
-  }
-
-  const PrecisionConfig::Algorithm algorithm =
-      dot.precision_config().algorithm();
-  if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 ||
-      algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
-      dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() ||
-      dot.sparse_operands()) {
-    return FusionDecision{};
-  }
-
-  bool is_pure_matmul = true;
-  (void)builder.ForEachInstruction([&](const HloInstruction* fused_hlo) {
-    static constexpr std::array<HloOpcode, 4> kPureOpcodes = {
-        HloOpcode::kBitcast, HloOpcode::kDot, HloOpcode::kParameter,
-        HloOpcode::kReshape};
-    if (absl::c_find(kPureOpcodes, fused_hlo->opcode()) == kPureOpcodes.end()) {
-      is_pure_matmul = false;
-      // Stop iterating.
-      return absl::CancelledError();
-    }
-    return absl::OkStatus();
-  });
-  if (!is_pure_matmul) {
-    return FusionDecision{};
-  }
-
-  return "No profitable operations to fuse.";
-}
-
-// Extracts the parts of the HLO graph that contain dot() operations which can
-// target the Triton GEMM emitter into separate fused computations.
-class GemmFusionVisitor : public DfsHloRewriteVisitor {
- public:
-  explicit GemmFusionVisitor(const se::GpuComputeCapability& gpu_version)
-      : gpu_version_(gpu_version) {}
-  // Checks whether a dot() should target the Triton GEMM emitter; if so,
-  // fuses all its compatible inputs and outputs into a new computation and
-  // replaces the original dot() with a call to that computation.
-  absl::Status HandleDot(HloInstruction* dot) override {
-    CHECK_EQ(dot->opcode(), HloOpcode::kDot);
-
-    int64_t gemm_rewrite_size_threshold =
-        dot->GetModule()
-            ->config()
-            .debug_options()
-            .xla_gpu_gemm_rewrite_size_threshold();
-    TF_ASSIGN_OR_RETURN(bool is_matmul_tiny,
-                        IsMatrixMultiplicationTooSmallForRewriting(
-                            *dot, gemm_rewrite_size_threshold));
-    if (is_matmul_tiny && IsDotSupportedByClassicalEmitters(*dot)) {
-      return absl::OkStatus();
-    }
-
-    std::string fusion_name = absl::StrCat("gemm_fusion_", dot->name());
-    HloComputation::Builder builder(absl::StrCat(fusion_name, "_computation"));
-    std::vector<HloInstruction*> fusion_inputs;
-    HloInstruction* fusion_output = nullptr;
-    TF_ASSIGN_OR_RETURN(
-        const FusionDecision should_fuse,
-        CreateDotFusion(*Cast<HloDotInstruction>(dot), gpu_version_, builder,
-                        fusion_inputs, &fusion_output));
-    if (builder.last_added_instruction() == nullptr) {
-      return absl::OkStatus();
-    }
-    // If a GEMM requiring padding for cuBLAS is encountered here, it is
-    // because ShouldTritonHandleGEMM() accepted it earlier and padding was
-    // skipped. Accept it, ignoring profitability checks.
-    // TODO(rocm): check ROCM padding requirements.
-    if (std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
-      if (!CublasRequiresPadding(
-              *Cast<HloDotInstruction>(dot),
-              std::get<se::CudaComputeCapability>(gpu_version_)) &&
-          !should_fuse) {
-        return absl::OkStatus();
-      }
-    }
-
-    HloComputation* computation =
-        dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(),
-                                                            /*is_entry=*/false);
-    HloInstruction* dot_fusion =
-        dot->parent()->AddInstruction(HloInstruction::CreateFusion(
-            computation->root_instruction()->shape(),
-            HloInstruction::FusionKind::kCustom, fusion_inputs, computation));
-    // Copy the metadata of the `dot` to the newly created `fusion` op. This
-    // is convenient for handling metadata in the subsequent split-k rewriting.
-    dot_fusion->set_metadata(dot->metadata());
-    dot_fusion->GetModule()->SetAndUniquifyInstrName(dot_fusion, fusion_name);
-
-    TF_ASSIGN_OR_RETURN(auto gpu_config,
-                        dot_fusion->backend_config<GpuBackendConfig>());
-    FusionBackendConfig& backend_config =
-        *gpu_config.mutable_fusion_backend_config();
-    backend_config.set_kind(std::string(kTritonGemmFusionKind));
-    TF_RETURN_IF_ERROR(dot_fusion->set_backend_config(gpu_config));
-
-    if (fusion_output->IsRoot()) {
-      fusion_output->parent()->set_root_instruction(dot_fusion);
-      TF_RETURN_IF_ERROR(
-          fusion_output->parent()->RemoveInstructionAndUnusedOperands(
-              fusion_output));
-      MarkAsChanged();
-    } else {
-      TF_RETURN_IF_ERROR(ReplaceInstruction(fusion_output, dot_fusion));
-    }
-    XLA_VLOG_LINES(5, computation->ToString(HloPrintOptions::ShortParsable()));
-    return absl::OkStatus();
-  }
-
- private:
-  se::GpuComputeCapability gpu_version_;
-};
-
-absl::StatusOr<bool> RunOnComputation(
-    HloComputation* computation, const se::GpuComputeCapability& gpu_version) {
-  GemmFusionVisitor visitor(gpu_version);
-  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
-  return visitor.changed();
-}
-
-
-}  // namespace
-
-bool ShouldTritonHandleGEMM(HloDotInstruction& dot,
-                            const se::GpuComputeCapability& gpu_version) {
-  std::vector<HloInstruction*> fusion_inputs;
-  HloComputation::Builder builder("disposable");
-  return CreateDotFusion(dot, gpu_version, builder, fusion_inputs,
-                         /*fusion_output_ptr=*/nullptr)
-      ->CanFuse();
-}
-
-absl::StatusOr<bool> GemmFusion::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  TF_RETURN_IF_ERROR(
-      EnsureTritonSupportsComputeCapability(compute_capability_));
-
-  bool changed = false;
-  for (HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    TF_ASSIGN_OR_RETURN(bool result,
-                        RunOnComputation(computation, compute_capability_));
-    changed |= result;
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.h b/third_party/xla/xla/service/gpu/gemm_fusion.h
deleted file mode 100644
index c858b43..0000000
--- a/third_party/xla/xla/service/gpu/gemm_fusion.h
+++ /dev/null
@@ -1,57 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_GEMM_FUSION_H_
-#define XLA_SERVICE_GPU_GEMM_FUSION_H_
-
-// This file contains the code for fusing dots and other operations into Triton
-// GEMM fusions.
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// Returns true for GEMMs that are better handled by Triton.
-bool ShouldTritonHandleGEMM(HloDotInstruction&,
-                            const se::GpuComputeCapability&);
-
-// Rewrites compatible dot() calls into custom calls with fused computations
-// that target the Triton-based matmul emitter.
-class GemmFusion : public HloModulePass {
- public:
-  explicit GemmFusion(const se::GpuComputeCapability& compute_capability)
-      : compute_capability_(compute_capability) {}
-  absl::string_view name() const override { return "triton-gemm-rewriter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  se::GpuComputeCapability compute_capability_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GEMM_FUSION_H_
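For orientation, a minimal sketch of how the pass declared above could be run on its own; the function name and the choice of an Ampere-class compute capability are hypothetical, not part of the patch.

  namespace se = stream_executor;

  absl::StatusOr<bool> FuseGemmsForTriton(xla::HloModule* module) {
    // Hypothetical target: compute capability 8.0 (Ampere).
    se::GpuComputeCapability cc{se::CudaComputeCapability{8, 0}};
    xla::gpu::GemmFusion pass(cc);
    // True iff at least one dot() was wrapped into a Triton GEMM fusion.
    return pass.Run(module, /*execution_threads=*/{});
  }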
diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc
deleted file mode 100644
index 5a5a0f3..0000000
--- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc
+++ /dev/null
@@ -1,1280 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_fusion_autotuner.h"
-
-#include <algorithm>
-#include <array>
-#include <atomic>
-#include <cstdint>
-#include <iterator>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/time/time.h"
-#include "absl/types/span.h"
-#include "third_party/gpus/cuda/include/cublas_v2.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_clone_context.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/pjrt/distributed/key_value_store_interface.h"
-#include "xla/primitive_util.h"
-#include "xla/service/algorithm_util.h"
-#include "xla/service/dump.h"
-#include "xla/service/executable.h"
-#include "xla/service/float_normalization.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/buffer_comparator.h"
-#include "xla/service/gpu/cudnn_fusion_compiler.h"
-#include "xla/service/gpu/fusion_wrapper.h"
-#include "xla/service/gpu/gemm_rewriter.h"
-#include "xla/service/gpu/gpu_float_support.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/instruction_fusion.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/gpu/priority_fusion.h"
-#include "xla/service/gpu/split_k_gemm_rewriter.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/shaped_buffer.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/gpu/redzone_allocator.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/stream_executor/stream_executor_memory_allocator.h"
-#include "xla/tools/hlo_decomposer.h"
-#include "xla/tsl/util/proto/proto_utils.h"
-#include "xla/util.h"
-#include "xla/xla.pb.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/lib/core/bits.h"
-#include "tsl/platform/blocking_counter.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/path.h"
-#include "tsl/platform/protobuf.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/threadpool.h"
-#include "tsl/profiler/lib/scoped_annotation.h"
-
-// Log levels used in this file:
-// VLOG(1): Overview
-// VLOG(2): Autotuning progress
-// VLOG(3): Autotuning progress - more frequent
-// VLOG(4): Print all fusions
-// VLOG(5): Profiling information for every tiling
-// VLOG(10): Print fusion computations and each configuration
-
-// TODO(b/317016172): Update usages of TritonGemmConfig to use newly exposed
-// parameters.
-
-namespace xla {
-namespace gpu {
-
-using Config = GemmFusionAutotunerImpl::Config;
-using TilingConfigs = GemmFusionAutotunerImpl::TilingConfigs;
-using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput;
-
-namespace {
-
-// Minimum tile size.
-constexpr int kMinTileSize = 16;
-
-// Default tiling when autotuning is disabled.
-constexpr TritonGemmConfig kDefaultGemmTiling = {32, 32, 32, 1, 1, 4};
-
-// Split-K is enabled when the estimated number of waves is lower than the limit.
-constexpr int kMaxWavesForSplitK = 5;
-
-// Search space for exhaustive matmul autotuning.
-constexpr std::array<int, 6> kBlockSizes = {16, 32, 64, 128, 256, 512};
-constexpr std::array<int, 4> kNumStages = {1, 2, 3, 4};
-constexpr std::array<int, 4> kNumWarps = {2, 4, 8, 16};
-constexpr std::array<int, 5> kSplitK = {1, 2, 4, 8, 16};
-constexpr std::array<int, 5> kNumCtas = {1, 2, 4, 8, 16};
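// Editorial note: following the constructor-argument order used in
// GetExhaustiveTritonConfigs() below, kDefaultGemmTiling above reads as
// block_m = 32, block_n = 32, block_k = 32, split_k = 1, num_stages = 1,
// num_warps = 4.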
-
-using AutoTuneCacheKeyCount = absl::flat_hash_map<AutotuneCacheKey, uint64_t>;
-
-class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor {
- public:
-  explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config)
-      : config_(config) {}
-
-  absl::Status HandleFusion(HloInstruction* hlo) override {
-    TF_ASSIGN_OR_RETURN(auto gpu_config,
-                        hlo->backend_config<GpuBackendConfig>());
-    FusionBackendConfig& backend_config =
-        *gpu_config.mutable_fusion_backend_config();
-    if (backend_config.kind() != kTritonGemmFusionKind &&
-        backend_config.kind() != kCuDnnFusionKind) {
-      return absl::OkStatus();
-    }
-
-    VLOG(4) << "Processing " << hlo->ToString();
-    if (!backend_config.has_triton_gemm_config() &&
-        !backend_config.has_cudnn_fusion_config()) {
-      TF_ASSIGN_OR_RETURN(
-          AutotuneResult autotune_result,
-          AutotunerUtil::Autotune(
-              hlo, config_, [&]() -> absl::StatusOr<AutotuneResult> {
-                if (config_.IsDeviceless()) {
-                  return absl::InternalError(absl::StrCat(
-                      "Expect autotune result cache hit for deviceless "
-                      "compilation (HLO: ",
-                      hlo->ToString(), ")"));
-                }
-                return absl::InternalError("Expect autotune result cache hit.");
-              }));
-      VLOG(4) << "Result: " << autotune_result.ShortDebugString();
-
-      if (autotune_result.has_triton()) {
-        *backend_config.mutable_triton_gemm_config() = autotune_result.triton();
-        TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config));
-      } else if (autotune_result.has_gemm()) {
-        // Falling back to cuBLAS: Converting the fusion to a Call, so that it
-        // can be inlined back again.
-        HloComputation* const computation = hlo->parent();
-        HloInstruction* const call = computation->AddInstruction(
-            HloInstruction::CreateCall(hlo->shape(), hlo->operands(),
-                                       hlo->fused_instructions_computation()));
-        TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call));
-        hlo = call;
-      } else {
-        CHECK(autotune_result.has_algorithm());
-        backend_config.set_kind(std::string(kCuDnnFusionKind));
-        backend_config.mutable_cudnn_fusion_config()->set_plan_id(
-            autotune_result.algorithm().algo_id());
-        TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config));
-      }
-    }
-
-    if (backend_config.has_triton_gemm_config()) {
-      TF_ASSIGN_OR_RETURN(
-          const TritonGemmConfig config,
-          TritonGemmConfig::FromProto(backend_config.triton_gemm_config()));
-      if (config.split_k > 1) {
-        TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config));
-      }
-    }
-
-    MarkAsChanged();
-    return absl::OkStatus();
-  }
-
- private:
-  AutotuneConfig config_;
-};
-
-class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault {
- public:
-  explicit GemmConfigSetCollector(GemmFusionAutotunerImpl* impl)
-      : impl_(impl) {}
-
-  // Find configurations to tune.
-  absl::StatusOr<TilingConfigs> CollectGemmConfigSets(
-      const HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads = {}) {
-    error_out_on_cache_miss_ =
-        module->config()
-            .debug_options()
-            .xla_gpu_require_complete_aot_autotune_results();
-    gemm_config_sets_.clear();
-    for (HloComputation* computation :
-         module->MakeNonfusionComputations(execution_threads)) {
-      TF_RETURN_IF_ERROR(computation->Accept(this));
-    }
-    return std::move(gemm_config_sets_);
-  }
-
-  AutoTuneCacheKeyCount GetFusionsCount() {
-    return std::move(fusion_count_map_);
-  }
-
-  absl::Status HandleFusion(const HloInstruction* hlo) override {
-    const HloFusionInstruction* fusion = Cast<HloFusionInstruction>(hlo);
-
-    TF_ASSIGN_OR_RETURN(auto gpu_config,
-                        hlo->backend_config<GpuBackendConfig>());
-    const FusionBackendConfig& backend_config =
-        gpu_config.fusion_backend_config();
-
-    AutotuneCacheKey key = AutotunerUtil::GetKey(hlo, impl_->GetConfig());
-
-    auto [iterator, inserted] = fusion_count_map_.insert({key, 1});
-    if (!inserted) {
-      ++(iterator->second);
-    }
-
-    TF_ASSIGN_OR_RETURN(bool is_in_cache,
-                        AutotunerUtil::IsInCache(key, impl_->GetConfig()));
-    if (is_in_cache || handled_fusions_.contains(key)) {
-      return absl::OkStatus();
-    }
-
-    bool missing_config = (backend_config.kind() == kTritonGemmFusionKind &&
-                           !backend_config.has_triton_gemm_config()) ||
-                          (backend_config.kind() == kCuDnnFusionKind &&
-                           !backend_config.has_cudnn_fusion_config());
-    if (missing_config) {
-      if (error_out_on_cache_miss_) {
-        return absl::NotFoundError(absl::StrCat(
-            "Complete autotuning results are required, but no cache result "
-            "found for key: ",
-            key.ToString()));
-      }
-
-      TF_ASSIGN_OR_RETURN(std::vector<Config> configs,
-                          impl_->GenerateConfigs(*fusion));
-      gemm_config_sets_.push_back({fusion, std::move(configs)});
-    }
-
-    handled_fusions_.insert(key);
-    return absl::OkStatus();
-  }
-
-  absl::Status DefaultAction(const HloInstruction* hlo) override {
-    return absl::OkStatus();
-  }
-
- private:
-  bool error_out_on_cache_miss_;
-  GemmFusionAutotunerImpl* impl_;
-  TilingConfigs gemm_config_sets_;
-  AutoTuneCacheKeyCount fusion_count_map_;
-  absl::flat_hash_set<AutotuneCacheKey> handled_fusions_;
-};
-
-struct TileSizeLimit {
-  int block_m = 0;
-  int block_n = 0;
-  int block_k = 0;
-};
-
-absl::StatusOr<TileSizeLimit> GetLimits(const HloDotInstruction& dot) {
-  TF_ASSIGN_OR_RETURN(int64_t non_contracting_index_lhs,
-                      NonContractingDimensionIndex(dot, /*operand_number=*/0));
-  TF_ASSIGN_OR_RETURN(int64_t non_contracting_index_rhs,
-                      NonContractingDimensionIndex(dot, /*operand_number=*/1));
-  TF_ASSIGN_OR_RETURN(int64_t contracting_index,
-                      ContractingDimensionIndex(dot, /*operand_number=*/1));
-  // This is not a sharp upper limit; the actual m value can be much smaller,
-  // depending on how much of the m dimension is physically contiguous.
-  // TODO(tdanyluk): Get the exact m value by running a TritonFusionAnalysis.
-  const int max_m = tsl::NextPowerOfTwoS64(
-      dot.operand(0)->shape().dimensions(non_contracting_index_lhs));
-  // Theoretically the same holds for n as for m, but that is not possible in
-  // practice with the current implementation.
-  const int max_n = tsl::NextPowerOfTwoS64(
-      dot.operand(1)->shape().dimensions(non_contracting_index_rhs));
-  // This is before doing the split-k transform.
-  const int max_k = tsl::NextPowerOfTwoS64(
-      dot.operand(1)->shape().dimensions(contracting_index));
-
-  return TileSizeLimit{
-      /*block_m=*/std::max(max_m, kMinTileSize),
-      /*block_n=*/std::max(max_n, kMinTileSize),
-      /*block_k=*/std::max(max_k, kMinTileSize),
-  };
-}
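// Editorial worked example with hypothetical shapes: for a dot with
// LHS [300, 100] and RHS [100, 1000] contracting over the 100-sized dimension,
// the limits above are block_m = NextPowerOfTwo(300) = 512,
// block_n = NextPowerOfTwo(1000) = 1024 and block_k = NextPowerOfTwo(100) = 128,
// each clamped from below by kMinTileSize = 16.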
-
-int GetLogEveryN() { return VLOG_IS_ON(3) ? 100 : 1000; }
-
-absl::StatusOr<std::unique_ptr<HloModule>> TritonGemmAutotuneExtractor(
-    const TritonGemmConfig& config,
-    const se::DeviceDescription& gpu_device_info,
-    const HloFusionInstruction* fusion, DebugOptions debug_opts,
-    bool allow_filtering_kernels_spilling_registers) {
-  std::unique_ptr<HloModule> new_module =
-      ExtractInstructionIntoNewModule(*fusion);
-  if (!allow_filtering_kernels_spilling_registers) {
-    debug_opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(
-        false);
-  }
-  new_module->mutable_config().set_debug_options(debug_opts);
-
-  HloComputation* entry_computation = new_module->entry_computation();
-  HloInstruction* cloned_dot_fusion = entry_computation->root_instruction();
-
-  TF_ASSIGN_OR_RETURN(auto gpu_config,
-                      cloned_dot_fusion->backend_config<GpuBackendConfig>());
-  FusionBackendConfig& backend_config =
-      *gpu_config.mutable_fusion_backend_config();
-
-  *backend_config.mutable_triton_gemm_config() = config.ToProto();
-  TF_RETURN_IF_ERROR(cloned_dot_fusion->set_backend_config(gpu_config));
-
-  if (config.split_k > 1) {
-    TF_RETURN_IF_ERROR(MakeDotSplitKBatch(cloned_dot_fusion, config));
-    GpuFloatSupport bf16_support(gpu_device_info.cuda_compute_capability(),
-                                 BF16);
-    FloatNormalization float_normalization(&bf16_support);
-    TF_RETURN_IF_ERROR(float_normalization.Run(new_module.get()).status());
-
-    auto shape_size_function = [&](const Shape& shape) {
-      // The real pointer size is set in GpuCompiler. In HloCostAnalysis, the
-      // pointer size is used only to determine the size of tuple types. We
-      // shouldn't have any tuples in the autotuned module, so it's safe to use
-      // a constant here, instead of piping the real value.
-      constexpr int64_t kPointerSize = 8;
-      return ShapeUtil::ByteSizeOf(shape, kPointerSize);
-    };
-    GpuPriorityFusion priority_fusion(
-        /*thread_pool=*/nullptr, gpu_device_info,
-        GpuHloCostAnalysis::Options{/*shape_size=*/shape_size_function,
-                                    /*per_second_rates=*/{},
-                                    /*count_multiple_input_accesses=*/true});
-    TF_RETURN_IF_ERROR(priority_fusion.Run(new_module.get()).status());
-
-    // If the priority fusion pass above skipped some instructions, turn them
-    // into fusions.
-    FusionWrapper fusion_wrapper;
-    TF_RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status());
-  }
-  return new_module;
-}
-
-absl::StatusOr<std::unique_ptr<HloModule>> CublasGemmAutotuneExtractor(
-    const AutotuneConfig& config, const int32_t toolkit_version,
-    const HloFusionInstruction* fusion, const DebugOptions& debug_opts) {
-  const HloComputation* fusion_computation =
-      fusion->called_computations().at(0);
-  std::unique_ptr<HloModule> new_module =
-      ExtractComputationIntoNewModule(*fusion_computation);
-  new_module->mutable_config().set_debug_options(debug_opts);
-
-  auto* dot = hlo_query::GetFirstInstructionWithOpcode(
-      *new_module->entry_computation(), HloOpcode::kDot);
-  // Substitute algorithms that are not supported by cuBLAS for the check, but
-  // don't use cuBLAS in the end. This assumes that the substituting algorithm
-  // produces results that are close enough for the check in this file.
-  if (dot->precision_config().algorithm() ==
-          PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
-      dot->precision_config().algorithm() ==
-          PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6) {
-    dot->mutable_precision_config()->set_algorithm(
-        PrecisionConfig::ALG_DOT_F32_F32_F32);
-  }
-
-  for (bool fp8 : {true, false}) {
-    GemmRewriter rewriter(config.GetGpuComputeCapability(), toolkit_version,
-                          fp8);
-    GpuInstructionFusion fusion_pass(
-        /*may_duplicate=*/false, config.GetExecutor()->GetDeviceDescription());
-    TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status());
-    TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status());
-  }
-  // TODO(tdanyluk): Consider running GemmAlgorithmPicker here for better cuBLAS
-  // performance. It is probably not needed on Ampere and later because cuBLAS
-  // ignores the algorithm parameter for those targets. If we run
-  // GemmAlgorithmPicker, we probably should not run this in parallel with other
-  // compilations.
-  return new_module;
-}
-
-absl::StatusOr<std::unique_ptr<HloModule>> FusionExtractor(
-    const HloFusionInstruction& fusion, const DebugOptions& debug_opts) {
-  std::unique_ptr<HloModule> module = ExtractInstructionIntoNewModule(fusion);
-  module->mutable_config().set_debug_options(debug_opts);
-  return module;
-}
-
-absl::StatusOr<std::unique_ptr<HloModule>> CuDnnFusionExtractor(
-    const HloFusionInstruction& fusion, const DebugOptions& debug_opts,
-    const int plan_id) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
-                      FusionExtractor(fusion, debug_opts));
-
-  GpuBackendConfig gpu_config;
-  FusionBackendConfig& backend_config =
-      *gpu_config.mutable_fusion_backend_config();
-  backend_config.set_kind(std::string(kCuDnnFusionKind));
-  // When a plan ID is provided, the autotuner compiles just that one plan.
-  backend_config.mutable_cudnn_fusion_config()->set_plan_id(plan_id);
-  TF_RETURN_IF_ERROR(
-      module->entry_computation()->root_instruction()->set_backend_config(
-          gpu_config));
-  return module;
-}
-
-bool IsFusionKind(const HloInstruction& hlo, absl::string_view kind) {
-  auto gpu_config = hlo.backend_config<GpuBackendConfig>();
-  if (!gpu_config.ok()) {
-    return false;
-  }
-  return gpu_config->fusion_backend_config().kind() == kind;
-}
-
-int GetCuDnnPlanCount(const HloInstruction& hlo,
-                      const AutotuneConfig& autotune_config) {
-  if (auto gpu_config = hlo.backend_config<GpuBackendConfig>();
-      !gpu_config.ok() ||
-      gpu_config->fusion_backend_config().has_cudnn_fusion_config()) {
-    return {};
-  }
-  return CuDnnFusionCompiler::GetAvailablePlanCount(
-      *autotune_config.GetExecutor(), *DynCast<HloFusionInstruction>(&hlo));
-}
-
-AutotuneResult FromConfig(const Config& config) {
-  AutotuneResult res;
-  if (std::holds_alternative<GemmFusionAutotunerImpl::CuBlasConfig>(config)) {
-    res.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT);
-  } else if (std::holds_alternative<GemmFusionAutotunerImpl::CuDnnConfig>(
-                 config)) {
-    res.mutable_algorithm()->set_algo_id(
-        std::get<GemmFusionAutotunerImpl::CuDnnConfig>(config).plan_id);
-  } else if (std::holds_alternative<TritonGemmConfig>(config)) {
-    *res.mutable_triton() = std::get<TritonGemmConfig>(config).ToProto();
-  } else {
-    LOG(FATAL) << "Unsupported config type: " << config.index();
-  }
-  return res;
-}
-
-absl::Status DumpOriginalFusion(AutotunerCompileUtil& util,
-                                const HloFusionInstruction& fusion,
-                                int fusion_id) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
-                      util.ExtractModule([&](const DebugOptions& debug_opts) {
-                        return FusionExtractor(fusion, debug_opts);
-                      }));
-  module->set_name(std::string(fusion.name()));
-  // The first parameter is the original module, used for its debug info and
-  // name. Including the names of both the original and the extracted module
-  // in the dump file name avoids name clashes.
-  DumpToFileInDirOrStdout(
-      /*module=*/*fusion.GetModule(),
-      /*file_prefix=*/"",
-      /*file_suffix=*/
-      absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(), ".txt"),
-      /*contents=*/module->ToString());
-  return absl::OkStatus();
-}
-
-absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config,
-                                 const int32_t toolkit_version,
-                                 AutotunerCompileUtil& util,
-                                 const AutotuneResult result,
-                                 const HloFusionInstruction* fusion,
-                                 int fusion_id) {
-  TritonGemmConfig triton_gemm_config;
-  if (result.has_triton()) {
-    TF_ASSIGN_OR_RETURN(triton_gemm_config,
-                        TritonGemmConfig::FromProto(result.triton()));
-  }
-  const se::DeviceDescription& device_desc =
-      autotune_config.GetExecutor()->GetDeviceDescription();
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<HloModule> module,
-      util.ExtractModule([&](const DebugOptions& debug_opts) {
-        if (result.has_algorithm()) {
-          return CuDnnFusionExtractor(*fusion, debug_opts,
-                                      result.algorithm().algo_id());
-        } else if (result.has_triton()) {
-          return TritonGemmAutotuneExtractor(
-              triton_gemm_config, device_desc, fusion, debug_opts,
-              /*allow_filtering_kernels_spilling_registers=*/true);
-        } else if (result.has_gemm()) {
-          return CublasGemmAutotuneExtractor(autotune_config, toolkit_version,
-                                             fusion, debug_opts);
-        } else {
-          LOG(FATAL) << "Unknown result type: " << result.DebugString();
-        }
-      }));
-  module->set_name(std::string(fusion->name()));
-  // The first parameter is the original module, used for its debug info and
-  // name. Including the names of both the original and the extracted module
-  // in the dump file name avoids name clashes.
-  DumpToFileInDirOrStdout(
-      /*module=*/*fusion->GetModule(),
-      /*file_prefix=*/"",
-      /*file_suffix=*/
-      absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(),
-                   ".optimized.txt"),
-      /*contents=*/module->ToString());
-  return absl::OkStatus();
-}
-
-std::string Serialize(const Config& config) {
-  if (auto triton_config = std::get_if<TritonGemmConfig>(&config)) {
-    tsl::protobuf::TextFormat::Printer printer;
-    printer.SetSingleLineMode(true);
-    std::string result;
-    printer.PrintToString(triton_config->ToProto(), &result);
-    return result;
-  }
-  return GemmFusionAutotunerImpl::ToString(config);
-}
-
-}  // anonymous namespace
-
-// Methods required for sorting the configs.
-bool GemmFusionAutotunerImpl::CuBlasConfig::operator<(
-    const CuBlasConfig& other) const {
-  return false;
-}
-bool GemmFusionAutotunerImpl::CuDnnConfig::operator<(
-    const CuDnnConfig& other) const {
-  return plan_id < other.plan_id;
-}
-
-bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const {
-  return debug_options_.xla_gpu_autotune_level() > 0 &&
-         !debug_options_.xla_gpu_deterministic_ops();
-}
-
-/*static*/ std::string GemmFusionAutotunerImpl::ToString(const Config& config) {
-  if (std::holds_alternative<TritonGemmConfig>(config)) {
-    return std::get<TritonGemmConfig>(config).ToString();
-  } else if (std::holds_alternative<CuDnnConfig>(config)) {
-    return absl::StrFormat("cuDNN plan %d",
-                           std::get<CuDnnConfig>(config).plan_id);
-  } else if (std::holds_alternative<CuBlasConfig>(config)) {
-    return "reference (cublas)";
-  } else {
-    LOG(FATAL) << "Unsupported config type: " << config.index();
-  }
-}
-
-absl::StatusOr<std::vector<Config>> GemmFusionAutotunerImpl::GenerateConfigs(
-    const HloFusionInstruction& fusion) {
-  const HloDotInstruction* dot =
-      Cast<HloDotInstruction>(hlo_query::GetFirstInstructionWithOpcode(
-          *fusion.called_computations().at(0), HloOpcode::kDot));
-
-  // Add cuBLAS reference config, if available.
-  std::vector<Config> configs;
-  if (algorithm_util::IsSupportedByCublasOrCublasLt(
-          dot->precision_config().algorithm()) &&
-      !dot->sparse_operands() && IsAutotuningEnabled()) {
-    configs.push_back(CuBlasConfig{});
-  }
-
-  // Add cuDNN plans, if available.
-  bool is_hopper =
-      !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
-  bool is_cudnn_enabled =
-      debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0 && is_hopper &&
-      GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9;
-  if ((IsFusionKind(fusion, kCuDnnFusionKind) && IsAutotuningEnabled()) ||
-      (IsFusionKind(fusion, kTritonGemmFusionKind) && is_cudnn_enabled &&
-       algorithm_util::IsSupportedByCudnn(
-           dot->precision_config().algorithm()) &&
-       !dot->sparse_operands() && IsAutotuningEnabled())) {
-    const int plan_count = GetCuDnnPlanCount(fusion, config_);
-    for (int plan_id = 0; plan_id < plan_count; ++plan_id) {
-      configs.push_back(CuDnnConfig{plan_id});
-    }
-  }
-  if (IsFusionKind(fusion, kCuDnnFusionKind)) {
-    if (!IsAutotuningEnabled()) {
-      configs.push_back(CuDnnConfig{-1});
-    }
-    return configs;
-  }
-
-  // Add triton configs.
-  TF_ASSIGN_OR_RETURN(std::vector<TritonGemmConfig> triton_configs,
-                      GenerateTritonConfigs(*dot));
-  for (TritonGemmConfig& config : triton_configs) {
-    configs.push_back(std::move(config));
-  }
-  return configs;
-}
-
-absl::StatusOr<std::vector<TritonGemmConfig>>
-GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
-  // Retrieve the minimum bit-width participating in the dot. This is needed
-  // to avoid autotuning configurations that are not supported by Triton. This
-  // is used to restrict the values for tile_k.
-  std::vector<const HloInstruction*> converts =
-      HloBfsFindAll({&dot}, [&](const HloInstruction* node) {
-        return node->opcode() == HloOpcode::kConvert;
-      });
-  int minBitWidth = primitive_util::BitWidth(dot.shape().element_type());
-  for (auto convert : converts) {
-    auto in_type = convert->operand(0)->shape().element_type();
-    auto out_type = convert->shape().element_type();
-    minBitWidth = std::min({minBitWidth, primitive_util::BitWidth(in_type),
-                            primitive_util::BitWidth(out_type)});
-  }
-
-  std::vector<TritonGemmConfig> result_configs;
-  TF_ASSIGN_OR_RETURN(TileSizeLimit limits, GetLimits(dot));
-
-  // Generate the list of configurations (once).
-  if (triton_configs_.empty()) {
-    triton_configs_ = !IsAutotuningEnabled()
-                          ? std::vector(1, kDefaultGemmTiling)
-                      : debug_options_.xla_gpu_exhaustive_tiling_search()
-                          ? GetExhaustiveTritonConfigs()
-                          : GetDefaultTritonConfigs();
-  }
-
-  // Avoid autotuning tiny fusions.
-  constexpr int kMinGemmElements = 32 * 32;
-  bool small_dot =
-      ShapeUtil::ElementsIn(dot.operand(0)->shape()) <= kMinGemmElements &&
-      ShapeUtil::ElementsIn(dot.operand(1)->shape()) <= kMinGemmElements;
-  std::vector<TritonGemmConfig> triton_configs =
-      small_dot ? std::vector(1, kDefaultGemmTiling) : triton_configs_;
-
-  // Split-K optimization enables more even utilization of a GPU in cases
-  // where tiling just the non-contracting dimensions of a GEMM does not create
-  // a sufficient number of thread block programs to occupy all available cores.
-  // Around 5 full waves completely avoid the need for split-K.
-  // n_tiles = split_k * (M * N) / (block_m * block_n)
-  const int kCoreCount =
-      !config_.IsDeviceless()
-          ? config_.GetExecutor()->GetDeviceDescription().core_count()
-          : 100;  // some sensible default
-  const int64_t kSufficientNumberOfTiles = kMaxWavesForSplitK * kCoreCount;
-  const int64_t result_size = ShapeUtil::ElementsIn(dot.shape());
-
-  // Triton configurations are adjusted and deduplicated.
-  absl::flat_hash_set<TritonGemmConfig> added;
-  bool is_hopper =
-      !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
-  for (TritonGemmConfig& config : triton_configs) {
-    config.block_m = std::min(config.block_m, limits.block_m);
-    config.block_n = std::min(config.block_n, limits.block_n);
-    config.block_k = std::min(config.block_k, limits.block_k);
-    int max_split_k = 1;
-    if (debug_options_.xla_gpu_enable_split_k_autotuning()) {
-      int64_t ratio = kSufficientNumberOfTiles * config.block_m *
-                      config.block_n / result_size;
-      max_split_k = 1 << std::max<int>(tsl::Log2Floor64(ratio), 0);
-    }
-    config.split_k = std::min(config.split_k, max_split_k);
-
-    // TODO(b/337839570): Triton currently has a limitation where it crashes
-    // on small block_k values depending on the bit-width of the inputs to the
-    // dot. The logic below accounts for this limitation.
-    constexpr int kLdmatrixGranularity = 256;
-    config.block_k =
-        std::max(config.block_k, kLdmatrixGranularity / minBitWidth);
-
-    // Sparse metadata should have at least one element per thread.
-    // Note: only 2:4 structured sparsity is currently supported.
-    if (dot.sparse_operands()) {
-      if (is_hopper) {
-        config.block_m = std::max(config.block_m, 64);
-        config.num_warps = std::max(config.num_warps, 4);
-      }
-      config.block_k = std::max(
-          config.block_k,
-          2 * std::max(kMinTileSize, kLdmatrixGranularity / minBitWidth));
-      int meta_elements = config.block_m * config.block_k / 16;
-      config.num_warps =
-          std::min<int>(config.num_warps, meta_elements / WarpSize());
-    }
-
-    if (added.insert(config).second) {
-      result_configs.push_back(config);
-    }
-  }
-  return result_configs;
-}
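// Editorial worked example with hypothetical values: with the deviceless
// default of 100 cores, kSufficientNumberOfTiles = 5 * 100 = 500. For a
// 256x256 dot output and a 64x64 output tile, the ratio above is
// 500 * 64 * 64 / 65536 = 31, so max_split_k = 1 << Log2Floor(31) = 16.
// Independently, for 8-bit dot inputs the block_k floor is
// kLdmatrixGranularity / 8 = 32.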
-
-absl::StatusOr<absl::flat_hash_map<
-    const HloFusionInstruction*,
-    std::vector<GemmFusionAutotunerImpl::ExecutableCandidate>>>
-GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util,
-                                    const TilingConfigs& task) {
-  tsl::profiler::ScopedAnnotation annotation("XlaAutotunerCompilation");
-  absl::Mutex results_mu;
-  absl::flat_hash_map<const HloFusionInstruction*,
-                      std::vector<ExecutableCandidate>>
-      results;
-  if (task.empty()) {
-    return results;
-  }
-
-  const int log_every_n = GetLogEveryN();
-  int64_t config_count = 0;
-  for (const auto& [unused, configs] : task) {
-    config_count += configs.size();
-  }
-
-  std::atomic<int> done_count = 0;
-  std::atomic<int> good_count = 0;
-  auto log = [&](bool success) {
-    const int done_so_far = done_count.fetch_add(1) + 1;
-    const int good_so_far =
-        success ? good_count.fetch_add(1) + 1 : good_count.load();
-    if (done_so_far % log_every_n == 0) {
-      VLOG(2) << "Compiled " << done_so_far << " of " << config_count
-              << " configs (successful: " << good_so_far << ")";
-    }
-  };
-
-  auto compile = [&](const HloFusionInstruction* fusion, const Config& config,
-                     bool allow_filtering_kernels_spilling_registers)
-      -> absl::StatusOr<bool> {
-    std::unique_ptr<Executable> executable;
-    if (std::holds_alternative<TritonGemmConfig>(config)) {
-      TF_ASSIGN_OR_RETURN(
-          executable, compile_util.Compile([&](const DebugOptions& opts) {
-            return TritonGemmAutotuneExtractor(
-                std::get<TritonGemmConfig>(config),
-                config_.GetExecutor()->GetDeviceDescription(), fusion, opts,
-                allow_filtering_kernels_spilling_registers);
-          }));
-    } else if (std::holds_alternative<CuDnnConfig>(config)) {
-      executable =
-          compile_util
-              .Compile([&](const DebugOptions& opts) {
-                return CuDnnFusionExtractor(
-                    *fusion, opts, std::get<CuDnnConfig>(config).plan_id);
-              })
-              .value_or(nullptr);
-    } else if (std::holds_alternative<CuBlasConfig>(config)) {
-      TF_ASSIGN_OR_RETURN(executable,
-                          compile_util.Compile([&](const DebugOptions& opts) {
-                            return CublasGemmAutotuneExtractor(
-                                config_, toolkit_version_, fusion, opts);
-                          }));
-    } else {
-      LOG(FATAL) << "Unsupported config type: " << config.index();
-    }
-    if (executable != nullptr) {
-      absl::MutexLock lock(&results_mu);
-      results[fusion].push_back({config, std::move(executable)});
-      return true;
-    }
-    return false;
-  };
-
-  // If the thread pool has only one thread, then it is actually slower to
-  // offload the tasks there.
-  if (thread_pool_ && thread_pool_->NumThreads() > 1 &&
-      debug_options_.xla_gpu_force_compilation_parallelism() != 1) {
-    if (task.size() == 1) {
-      absl::string_view fusion_name = task.begin()->first->name();
-      VLOG(1) << "Compiling " << config_count << " configs for " << fusion_name
-              << " on " << thread_pool_->NumThreads() << " threads.";
-    } else {
-      VLOG(1) << "Compiling " << config_count << " configs for " << task.size()
-              << " fusions on " << thread_pool_->NumThreads() << " threads.";
-    }
-
-    tsl::BlockingCounter counter(config_count);
-    for (const auto& key_value : task) {
-      const HloFusionInstruction* fusion = key_value.first;
-      const std::vector<Config>& gemm_config_set = key_value.second;
-
-      VLOG(10) << "Compiling fusion: " << fusion->name();
-      VLOG(10) << "Dumping fusion computation: "
-               << fusion->called_computation()->ToString();
-      for (const Config& config : gemm_config_set) {
-        thread_pool_->Schedule([&, fusion] {
-          VLOG(10) << "Trying configuration forceable through: "
-                      "--xla_gpu_override_gemm_autotuner='"
-                   << Serialize(config) << "'";
-          VLOG(10) << "WARNING: you are running in multithreaded-mode, the "
-                      "last configuration printed out might not be the one "
-                      "causing issues! Use "
-                      "--xla_gpu_force_compilation_parallelism=1 to fix.";
-          absl::StatusOr<bool> has_executable =
-              compile(fusion, config, gemm_config_set.size() > 1);
-          TF_CHECK_OK(has_executable.status())
-              << "Failure occured when compiling fusion " << fusion->name()
-              << " with config '" << ToString(config)
-              << "'\nFused HLO computation:\n"
-              << fusion->fused_instructions_computation()->ToString();
-          log(has_executable.value());
-          counter.DecrementCount();
-        });
-      }
-    }
-    counter.Wait();
-  } else {
-    if (task.size() == 1) {
-      absl::string_view fusion_name = task.begin()->first->name();
-      LOG(WARNING) << "Compiling " << config_count << " configs for "
-                   << fusion_name << " on a single thread.";
-    } else {
-      LOG(WARNING) << "Compiling " << config_count << " configs for "
-                   << task.size() << " fusions on a single thread.";
-    }
-
-    for (const auto& [fusion, gemm_config_set] : task) {
-      VLOG(10) << "Compiling fusion: " << fusion->name();
-      VLOG(10) << "Dumping fusion computation: "
-               << fusion->called_computation()->ToString();
-      for (const Config& config : gemm_config_set) {
-        VLOG(10) << "Trying configuration forceable through: "
-                    "--xla_gpu_override_gemm_autotuner='"
-                 << Serialize(config) << "'";
-        TF_ASSIGN_OR_RETURN(
-            bool has_executable,
-            compile(fusion, config, gemm_config_set.size() > 1));
-        log(has_executable);
-      }
-    }
-  }
-
-  VLOG(1) << "Done compiling (successful: " << good_count.load() << ").";
-  return results;
-}
-
-absl::StatusOr<std::vector<AutotuneResult>> GemmFusionAutotunerImpl::Profile(
-    AutotunerCompileUtil& compile_util, const HloFusionInstruction& fusion,
-    absl::Span<const ExecutableCandidate> candidates) {
-  const HloComputation* fusion_computation = fusion.called_computations().at(0);
-
-  se::StreamExecutor* stream_exec = config_.GetExecutor();
-  if (!stream_exec->SynchronizeAllActivity()) {
-    return Internal("Failed to synchronize GPU for autotuning.");
-  }
-  tsl::profiler::ScopedAnnotation annotation([&] {
-    return absl::StrFormat("XlaAutotunerMeasurement:#hlo_op=%s#",
-                           fusion.name());
-  });
-  se::DeviceMemoryAllocator* allocator = config_.GetAllocator();
-  std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator;
-  if (allocator == nullptr) {
-    owned_allocator =
-        std::make_unique<se::StreamExecutorMemoryAllocator>(stream_exec);
-    allocator = owned_allocator.get();
-  }
-  TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
-
-  const HloInstruction& root = *fusion_computation->root_instruction();
-  BufferComparator comparator(root.shape(),
-                              debug_options_.xla_gpu_autotune_gemm_rtol());
-
-  TF_ASSIGN_OR_RETURN(auto rz_buffers,
-                      RedzoneBuffers::FromInstruction(
-                          *fusion_computation->FusionInstruction(), config_,
-                          debug_options_, RedzoneBuffers::kAllInputs));
-
-  const int log_every_n = GetLogEveryN();
-  std::vector<AutotuneResult> results;
-  std::optional<ScopedShapedBuffer> reference_buffer;
-  for (const ExecutableCandidate& candidate : candidates) {
-    VLOG(5) << "Trying : " << ToString(candidate.config);
-    AutotuneResult res = FromConfig(candidate.config);
-
-    std::optional<ProfilingOutput> profiling_output;
-    if (IsAutotuningEnabled()) {
-      TF_ASSIGN_OR_RETURN(
-          profiling_output,
-          compile_util.ProfileExecutable(candidate.executable.get(), stream,
-                                         rz_buffers.input_buffers(),
-                                         rz_buffers.input_shapes()));
-      if (std::holds_alternative<CuBlasConfig>(candidate.config) &&
-          config_.should_check_correctness()) {
-        reference_buffer = std::move(profiling_output->output);
-      }
-
-      int ran_so_far = results.size() + 1;
-      if (ran_so_far % log_every_n == 0) {
-        VLOG(2) << "Ran " << ran_so_far << " configs of " << candidates.size()
-                << ".";
-      }
-      if (!profiling_output) {
-        VLOG(5) << "Skipping this tiling.";
-        continue;
-      }
-
-      VLOG(5) << "Running the kernel took: " << profiling_output->duration;
-      if (profiling_output->duration >= absl::Seconds(1)) {
-        LOG(WARNING) << "Slow kernel for "
-                     << fusion.called_computations()[0]->ToString()
-                     << " took: " << profiling_output->duration << ". "
-                     << ToString(candidate.config);
-      }
-      *res.mutable_run_time() =
-          tsl::proto_utils::ToDurationProto(profiling_output->duration);
-    }
-
-    // The reference buffer is available when `config_.should_check_correctness()`
-    // is set and the reference executable was compiled.
-    if (reference_buffer.has_value() &&
-        !std::holds_alternative<CuBlasConfig>(candidate.config)) {
-      TF_ASSIGN_OR_RETURN(
-          se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
-          rz_buffers.RedzoneAllocator().CheckRedzones());
-      if (!rz_check_status.ok()) {
-        LOG(ERROR) << "Red zone modified";
-        res.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
-        res.mutable_failure()->set_msg(rz_check_status.RedzoneFailureMsg());
-        CHECK(!config_.should_crash_on_check_failure());
-        continue;
-      }
-
-      TF_ASSIGN_OR_RETURN(
-          bool outputs_match,
-          comparator.CompareEqual(
-              stream, /*current=*/profiling_output->output.root_buffer(),
-              /*expected=*/reference_buffer->root_buffer()));
-      if (!outputs_match) {
-        const char kMessage[] =
-            "Results do not match the reference. This is likely a "
-            "bug/unexpected loss of precision.";
-        LOG(ERROR) << kMessage;
-        CHECK(!config_.should_crash_on_check_failure());
-        // WRONG_RESULT is not taken seriously by PickBestResult(), so
-        // use DISQUALIFIED.
-        res.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED);
-        res.mutable_failure()->set_msg(kMessage);
-      }
-    }
-    results.push_back(std::move(res));
-  }
-  VLOG(2) << "Done running.";
-  return results;
-}
-
-std::vector<TritonGemmConfig>
-GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
-  std::vector<TritonGemmConfig> configs;
-  se::CudaComputeCapability cc = GetComputeCapability();
-  bool tune_ctas =
-      debug_options_.xla_gpu_enable_triton_hopper() && cc.IsAtLeastHopper();
-
-  for (int num_stages : kNumStages) {
-    // Volta doesn't support num_stages > 2.
-    if (!cc.IsAtLeastAmpere() && num_stages > 2) {
-      break;
-    }
-    for (int tile_m : kBlockSizes) {
-      for (int tile_n : kBlockSizes) {
-        for (int tile_k : kBlockSizes) {
-          const int tile_lhs = tile_m * tile_k;
-          const int tile_rhs = tile_k * tile_n;
-          for (int num_warps : kNumWarps) {
-            // Each thread should read at least one input element.
-            if (num_warps * WarpSize() > std::min(tile_lhs, tile_rhs)) {
-              break;
-            }
-            for (int split_k : kSplitK) {
-              // Split-K autotuning may be disabled by a flag.
-              if (!debug_options_.xla_gpu_enable_split_k_autotuning() &&
-                  split_k > 1) {
-                break;
-              }
-              for (int num_ctas : kNumCtas) {
-                // Clusters are only supported on Hopper.
-                // Autotuning this parameter is enabled by a flag.
-                if (!tune_ctas && num_ctas > 1) {
-                  break;
-                }
-                if (num_ctas > num_warps) {
-                  break;
-                }
-                configs.push_back(TritonGemmConfig(tile_m, tile_n, tile_k,
-                                                   split_k, num_stages,
-                                                   num_warps, num_ctas));
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-  return configs;
-}
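// Editorial note: before the early breaks above prune it, the exhaustive
// search space is |kBlockSizes|^3 * |kNumStages| * |kNumWarps| * |kSplitK| *
// |kNumCtas| = 216 * 4 * 4 * 5 * 5 = 86,400 candidate tilings; the pre-Ampere
// stage limit, the warp/tile ratio check, and the split-K and CTA flags cut
// this down substantially in practice.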
-
-std::vector<TritonGemmConfig> GemmFusionAutotunerImpl::GetDefaultTritonConfigs()
-    const {
-  using Config = TritonGemmConfig;
-  std::vector<Config> configs = {
-      Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4),
-      Config(32, 64, 64, 4, 1, 4),  Config(128, 128, 64, 4, 1, 4),
-      Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4),
-      Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4),
-      Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4),
-      Config(64, 32, 64, 1, 2, 8)};
-  if (GetComputeCapability().IsAtLeastAmpere()) {
-    absl::c_copy(
-        std::vector<Config>{
-            Config(128, 256, 32, 1, 3, 8),  Config(256, 128, 32, 1, 3, 8),
-            Config(256, 64, 32, 1, 4, 4),   Config(64, 256, 32, 1, 4, 4),
-            Config(128, 64, 32, 1, 4, 4),   Config(64, 128, 32, 1, 4, 4),
-            Config(256, 128, 128, 1, 3, 8), Config(256, 64, 128, 1, 4, 4),
-            Config(64, 256, 128, 1, 4, 4),  Config(128, 128, 128, 1, 4, 4),
-            Config(128, 64, 64, 1, 4, 4),   Config(64, 128, 64, 1, 4, 4),
-            Config(128, 32, 64, 1, 4, 4),   Config(64, 32, 64, 1, 4, 4),
-            Config(32, 128, 32, 1, 4, 4),   Config(128, 128, 32, 1, 4, 4),
-            Config(16, 16, 256, 1, 3, 4),   Config(128, 128, 64, 2, 1, 8),
-            Config(64, 64, 64, 1, 2, 4),    Config(16, 64, 256, 8, 1, 4),
-            Config(256, 256, 128, 1, 3, 8)},
-        std::back_inserter(configs));
-  }
-  if (GetComputeCapability().IsAtLeastHopper()) {
-    absl::c_copy(
-        std::vector<Config>{
-            Config(16, 32, 32, 8, 1, 2),
-            Config(16, 64, 128, 8, 1, 4),
-            Config(16, 64, 128, 16, 3, 4),
-        },
-        std::back_inserter(configs));
-  }
-  return configs;
-}
-
-absl::Status DumpAutotuningLogs(const DebugOptions& debug_opts,
-                                const AutotuningLogs& autotuning_logs) {
-  if (absl::string_view file_path = debug_opts.xla_gpu_dump_autotune_logs_to();
-      !file_path.empty()) {
-    std::string resolved_path;
-    if (!tsl::io::ResolveTestPrefixes(file_path, resolved_path)) {
-      return FailedPrecondition("File path can not be resolved: %s", file_path);
-    }
-
-    std::string textproto;
-    tsl::protobuf::TextFormat::PrintToString(autotuning_logs, &textproto);
-
-    TF_RETURN_IF_ERROR(
-        tsl::WriteStringToFile(tsl::Env::Default(), resolved_path, textproto));
-    LOG(INFO) << "Autotune logs serialized to file: " << resolved_path;
-  }
-  return absl::OkStatus();
-}
-
-absl::Status GemmFusionAutotunerImpl::Autotune(
-    AutotunerCompileUtil& compile_util, const TilingConfigs& gemm_config_sets,
-    AutoTuneCacheKeyCount fusion_count_map) {
-  TF_ASSIGN_OR_RETURN(auto executable_sets,
-                      CompileAll(compile_util, gemm_config_sets));
-
-  // Sort the candidates to make their execution order well-defined for each
-  // fusion.
-  for (auto& [unused, candidates] : executable_sets) {
-    absl::c_sort(candidates, [](const auto& a, const auto& b) {
-      return a.config < b.config;
-    });
-  }
-
-  AutotuningLogs autotuning_logs;
-  int fusion_id = 0;
-  for (const auto& [fusion, candidates] : executable_sets) {
-    TF_ASSIGN_OR_RETURN(std::vector<AutotuneResult> results,
-                        Profile(compile_util, *fusion, candidates));
-
-    // The reference config (if it exists) will be the first in the results,
-    // due to how the sorting of the variants works.
-    if (!debug_options_.xla_gpu_cublas_fallback() &&
-        results.front().has_gemm()) {
-      results.erase(results.begin());
-    }
-
-    const HloInstruction* root =
-        fusion->called_computations().at(0)->root_instruction();
-    TF_ASSIGN_OR_RETURN(
-        AutotuneResult best,
-        PickBestResult(results, root->ToString(), root->GetModule()->config()));
-    VLOG(2) << "Best time: "
-            << tsl::proto_utils::FromDurationProto(best.run_time());
-
-    if (debug_options_.xla_gpu_dump_autotuned_gemm_fusions()) {
-      TF_RETURN_IF_ERROR(DumpOriginalFusion(compile_util, *fusion, fusion_id));
-      TF_RETURN_IF_ERROR(DumpAutotunedFusion(
-          config_, toolkit_version_, compile_util, best, fusion, fusion_id++));
-    }
-
-    const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_);
-    TF_ASSIGN_OR_RETURN(
-        bool added, AutotunerUtil::AddResult(key, std::move(best), config_));
-    if (!added) {
-      // In the context of a model server, concurrent autotuning is expected,
-      // and the insertion of identical autotuning keys is accepted.
-      LOG(WARNING) << "AutotunerUtil::AddResult: result already existed for "
-                   << key.ToString();
-    }
-
-    if (!debug_options_.xla_gpu_dump_autotune_logs_to().empty()) {
-      auto autotuning_log = autotuning_logs.add_logs();
-      autotuning_log->set_fusion_name(std::string(fusion->name()));
-
-      for (const auto& autotune_result : results) {
-        auto log_result = autotuning_log->add_results();
-        log_result->CopyFrom(autotune_result);
-      }
-
-      if (auto fusion_key_count = fusion_count_map.find(key);
-          fusion_key_count != fusion_count_map.end()) {
-        auto fusion_key = fusion_key_count->first;
-        auto fusion_count = fusion_key_count->second;
-        autotuning_log->set_fusion_count(fusion_count);
-      }
-    }
-  }
-
-  TF_RETURN_IF_ERROR(DumpAutotuningLogs(debug_options_, autotuning_logs));
-
-  return absl::OkStatus();
-}
-
-// Trim the set of configs to what one rank has to run.
-static TilingConfigs TrimConfigs(const TilingConfigs& gemm_config_sets,
-                                 const int shard_index, const int shard_count) {
-  const uint64_t bucket_size =
-      (gemm_config_sets.size() + shard_count - 1) / shard_count;
-  const uint64_t start = bucket_size * shard_index;
-  const uint64_t end = std::min(start + bucket_size, gemm_config_sets.size());
-  if (start >= end) {
-    return {};
-  }
-  return TilingConfigs(gemm_config_sets.cbegin() + start,
-                       gemm_config_sets.cbegin() + end);
-}
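// Editorial worked example with hypothetical counts: for 10 fusions and
// shard_count = 4, bucket_size = (10 + 3) / 4 = 3, so shards 0..3 receive the
// config ranges [0, 3), [3, 6), [6, 9) and [9, 10) respectively.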
-
-// Exchange the results with the other ranks.
-absl::Status ExchangeResults(KeyValueStoreInterface& key_value_store,
-                             const int module_id, const int shard_index,
-                             const int shard_count) {
-  AutotuneResults results;
-  TF_RETURN_IF_ERROR(AutotunerUtil::SerializeAutotuneResults(&results));
-  TF_ASSIGN_OR_RETURN(std::string results_str,
-                      AutotuneResultsToString(results, true));
-  constexpr absl::string_view kKeyPrefix = "gemm_fusion_autotuning_results";
-  TF_RETURN_IF_ERROR(key_value_store.Set(
-      absl::StrFormat("%s_%d_%d", kKeyPrefix, module_id, shard_index),
-      results_str));
-  VLOG(2) << "Rank " << shard_index << ": published results";
-  for (int i = 0; i < shard_count; ++i) {
-    if (i == shard_index) {
-      continue;
-    }
-    VLOG(2) << "Rank " << shard_index << ": waiting for results from rank " << i
-            << " / " << shard_count;
-    TF_ASSIGN_OR_RETURN(
-        std::string autotune_results_str,
-        key_value_store.Get(
-            absl::StrFormat("%s_%d_%d", kKeyPrefix, module_id, i),
-            absl::InfiniteDuration()));
-    TF_RETURN_IF_ERROR(
-        AutotunerUtil::LoadAutotuneResults(autotune_results_str, true));
-  }
-  return absl::OkStatus();
-}
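// Editorial note with hypothetical ids: for module_id = 7 and shard_count = 4,
// rank 2 publishes its results under "gemm_fusion_autotuning_results_7_2" and
// then blocks on the "..._7_0", "..._7_1" and "..._7_3" keys before loading
// them into this process's autotune cache.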
-
-absl::StatusOr<bool> GemmFusionAutotuner::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  XLA_SCOPED_LOGGING_TIMER("GEMM fusion autotuner");
-
-  const DebugOptions& debug_options = module->config().debug_options();
-  GemmFusionAutotunerImpl autotuner(config_, toolkit_version_, debug_options,
-                                    thread_pool_);
-  GemmConfigSetCollector gemm_config_set_collector(&autotuner);
-  TF_ASSIGN_OR_RETURN(TilingConfigs gemm_config_sets,
-                      gemm_config_set_collector.CollectGemmConfigSets(
-                          module, execution_threads));
-  const int total_fusion_count = gemm_config_sets.size();
-
-  AutoTuneCacheKeyCount fusion_count_map =
-      gemm_config_set_collector.GetFusionsCount();
-
-  if (!autotuner.IsAutotuningEnabled()) {
-    // Pick the first option for each gemm instead of autotuning.
-    for (const auto& [fusion, tilings] : gemm_config_sets) {
-      const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_);
-      AutotuneResult res = FromConfig(tilings[0]);
-      *res.mutable_run_time() =
-          tsl::proto_utils::ToDurationProto(absl::ZeroDuration());
-      TF_RETURN_IF_ERROR(AutotunerUtil::AddResult(key, res, config_).status());
-    }
-  } else if (!debug_options.xla_gpu_override_gemm_autotuner().empty()) {
-    // TODO(gflegar): support overriding with non-Triton configs (cuBLAS, cuDNN)
-    AutotuneResult::TritonGemmKey gemm_key;
-    CHECK(tsl::protobuf::TextFormat::ParseFromString(
-        debug_options.xla_gpu_override_gemm_autotuner(), &gemm_key));
-    VLOG(1) << "Overriding GEMM autotuner with the following config: "
-            << gemm_key.DebugString();
-    for (const auto& [fusion, unused] : gemm_config_sets) {
-      const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_);
-      AutotuneResult res;
-      *res.mutable_triton() = gemm_key;
-      *res.mutable_run_time() =
-          tsl::proto_utils::ToDurationProto(absl::ZeroDuration());
-      TF_RETURN_IF_ERROR(AutotunerUtil::AddResult(key, res, config_).status());
-    }
-  } else if (!config_.IsDeviceless()) {
-    TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> opt_compile_util,
-                        AutotunerCompileUtil::Create(config_, debug_options));
-    TF_RET_CHECK(opt_compile_util.has_value());
-    std::string correctness_check_str = config_.should_check_correctness()
-                                            ? "(with correctness check)"
-                                            : "(without correctness check)";
-
-    const bool shard_autotuning = debug_options.xla_gpu_shard_autotuning() &&
-                                  key_value_store_.process_count > 1 &&
-                                  total_fusion_count > 0;
-    if (shard_autotuning) {
-      if (key_value_store_.key_value_store == nullptr) {
-        return absl::FailedPreconditionError(
-            "Sharded autotuning requested but key-value store is missing.");
-      }
-      gemm_config_sets =
-          TrimConfigs(gemm_config_sets, key_value_store_.process_index,
-                      key_value_store_.process_count);
-    }
-
-    VLOG(1) << absl::StrFormat(
-        "Shard %d / %d: autotuning %d / %d fusions for %s %s.",
-        key_value_store_.process_index + 1, key_value_store_.process_count,
-        gemm_config_sets.size(), total_fusion_count, module->name(),
-        correctness_check_str);
-    TF_RETURN_IF_ERROR(autotuner.Autotune(*opt_compile_util, gemm_config_sets,
-                                          std::move(fusion_count_map)));
-    VLOG(1) << "Done autotuning.";
-
-    if (shard_autotuning) {
-      TF_RETURN_IF_ERROR(ExchangeResults(
-          *key_value_store_.key_value_store, module->unique_id(),
-          key_value_store_.process_index, key_value_store_.process_count));
-    }
-  }
-
-  return GemmFusionAutotunerVisitor(config_).RunOnModule(module,
-                                                         execution_threads);
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.h b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.h
deleted file mode 100644
index 2815792..0000000
--- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.h
+++ /dev/null
@@ -1,147 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_GEMM_FUSION_AUTOTUNER_H_
-#define XLA_SERVICE_GPU_GEMM_FUSION_AUTOTUNER_H_
-
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/pjrt/distributed/key_value_store_interface.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/xla.pb.h"
-#include "tsl/platform/threadpool.h"
-
-namespace xla {
-namespace gpu {
-
-// Finds the best tiling configuration for each outlined Triton fusion.
-class GemmFusionAutotuner : public HloModulePass {
- public:
-  explicit GemmFusionAutotuner(const AutotuneConfig& config,
-                               const int32_t toolkit_version,
-                               tsl::thread::ThreadPool* thread_pool,
-                               const MultiProcessKeyValueStore& key_value_store)
-      : config_(config),
-        toolkit_version_(toolkit_version),
-        thread_pool_(thread_pool),
-        key_value_store_(key_value_store) {}
-
-  absl::string_view name() const override { return "triton-autotuner"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  const AutotuneConfig config_;
-  const int32_t toolkit_version_;
-  tsl::thread::ThreadPool* thread_pool_;
-  MultiProcessKeyValueStore key_value_store_;
-};
-
-// Autotuner implementation.
-class GemmFusionAutotunerImpl {
- public:
-  GemmFusionAutotunerImpl(const AutotuneConfig config,
-                          const int32_t toolkit_version,
-                          const DebugOptions debug_options,
-                          tsl::thread::ThreadPool* thread_pool)
-      : config_(std::move(config)),
-        toolkit_version_(toolkit_version),
-        debug_options_(std::move(debug_options)),
-        thread_pool_(thread_pool) {}
-
-  struct CuBlasConfig {
-    bool operator<(const CuBlasConfig& other) const;
-  };
-  struct CuDnnConfig {
-    int64_t plan_id;
-    bool operator<(const CuDnnConfig& other) const;
-  };
-  using Config = std::variant<CuBlasConfig, CuDnnConfig, TritonGemmConfig>;
-  using TilingConfigs =
-      std::vector<std::pair<const HloFusionInstruction*, std::vector<Config>>>;
-
-  struct ExecutableCandidate {
-    Config config;
-    std::unique_ptr<Executable> executable;
-  };
-
-  // Generate all possible configs for a dot operation.
-  absl::StatusOr<std::vector<Config>> GenerateConfigs(
-      const HloFusionInstruction& fusion);
-  absl::StatusOr<std::vector<TritonGemmConfig>> GenerateTritonConfigs(
-      const HloDotInstruction& dot);
-
-  // Compile all executables for all fusions.
-  absl::StatusOr<absl::flat_hash_map<const HloFusionInstruction*,
-                                     std::vector<ExecutableCandidate>>>
-  CompileAll(AutotunerCompileUtil& compile_util, const TilingConfigs& task);
-
-  // Profile all executables for a fusion.
-  absl::StatusOr<std::vector<AutotuneResult>> Profile(
-      AutotunerCompileUtil& compile_util, const HloFusionInstruction& fusion,
-      absl::Span<const ExecutableCandidate> candidates);
-
-  // Autotune and save the results to the autotuning cache.
-  absl::Status Autotune(
-      AutotunerCompileUtil& compile_util, const TilingConfigs& gemm_config_sets,
-      absl::flat_hash_map<AutotuneCacheKey, uint64_t> fusion_count_map);
-
-  // Helper methods.
-  const AutotuneConfig& GetConfig() const { return config_; }
-  bool IsAutotuningEnabled() const;
-  static std::string ToString(const Config& config);
-
- private:
-  se::CudaComputeCapability GetComputeCapability() const {
-    return std::get<se::CudaComputeCapability>(
-        config_.GetGpuComputeCapability());
-  }
-
-  std::vector<TritonGemmConfig> GetDefaultTritonConfigs() const;
-  std::vector<TritonGemmConfig> GetExhaustiveTritonConfigs() const;
-
-  const AutotuneConfig config_;
-  const int32_t toolkit_version_;
-  const DebugOptions debug_options_;
-  tsl::thread::ThreadPool* thread_pool_;
-  std::vector<TritonGemmConfig> triton_configs_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GEMM_FUSION_AUTOTUNER_H_
diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc
deleted file mode 100644
index b0d8ba6..0000000
--- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc
+++ /dev/null
@@ -1,946 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/gemm_fusion_autotuner.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "xla/autotuning.pb.h"
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/pjrt/distributed/key_value_store_interface.h"
-#include "xla/service/call_inliner.h"
-#include "xla/service/dump.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/gemm_fusion.h"
-#include "xla/service/gpu/gemm_rewriter.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_description.pb.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_utils.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "xla/tools/hlo_decomposer.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/xla.pb.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/cpu_info.h"
-#include "tsl/platform/env.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/path.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/threadpool.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-using HloExtractionTest = HloTestBase;
-
-TEST_F(HloExtractionTest, InstructionExtractionIsCorrect) {
-  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-HloModule module
-
-triton_gemm_dot {
-  p0 = s8[10,10] parameter(0)
-  p1 = f32[10,10] parameter(1)
-  c0 = f32[10,10] convert(p0)
-  ROOT dot.0 = f32[10,10] dot(c0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY entry {
-  p0 = s8[10,10] parameter(0)
-  p1 = f32[10,10] parameter(1)
-  s = f32[10,10] sqrt(p1)
-  d = f32[10,10] fusion(p0, p1),
-    kind=kCustom, calls=triton_gemm_dot
-  ROOT r = f32[10,10] add(d, s)
-})")
-                                                  .value();
-
-  std::unique_ptr<HloModule> extracted_module = ExtractInstructionIntoNewModule(
-      *module->entry_computation()->root_instruction()->operand(0));
-
-  // Destroy the original module to be sure that the extracted one has no
-  // dependency on it.
-  module.release();
-
-  EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
-              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-  EXPECT_EQ(extracted_module->entry_computation()->instruction_count(), 3);
-  TF_EXPECT_OK(VerifyHloModule(extracted_module.get(),
-                               /*layout_sensitive=*/true,
-                               /*allow_mixed_precision=*/false));
-}
-
-TEST_F(HloExtractionTest, ComputationExtractionIsCorrect) {
-  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-HloModule module
-
-triton_gemm_dot {
-  p0 = s8[10,10] parameter(0)
-  p1 = f32[10,10] parameter(1)
-  c0 = f32[10,10] convert(p0)
-  ROOT dot.0 = f32[10,10] dot(c0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY entry {
-  p0 = s8[10,10] parameter(0)
-  p1 = f32[10,10] parameter(1)
-  s = f32[10,10] sqrt(p1)
-  d = f32[10,10] fusion(p0, p1),
-    kind=kCustom, calls=triton_gemm_dot
-  ROOT r = f32[10,10] add(d, s)
-})")
-                                                  .value();
-
-  std::unique_ptr<HloModule> extracted_module =
-      ExtractComputationIntoNewModule(*module->entry_computation()
-                                           ->root_instruction()
-                                           ->operand(0)
-                                           ->fused_instructions_computation());
-
-  // Destroy the original module to be sure that the extracted one has no
-  // dependency on it.
-  module.release();
-
-  EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
-              GmockMatch(m::Dot(m::Convert(m::Parameter()), m::Parameter())));
-  EXPECT_EQ(extracted_module->entry_computation()->instruction_count(), 4);
-  TF_EXPECT_OK(VerifyHloModule(extracted_module.get(),
-                               /*layout_sensitive=*/true,
-                               /*allow_mixed_precision=*/false));
-}
-
-class StatelessAutotunerTest : public HloTestBase {
- public:
-  StatelessAutotunerTest()
-      : HloTestBase(/*verifier_layout_sensitive=*/true,
-                    /*allow_mixed_precision_in_hlo_verifier=*/false) {}
-
-  int32_t GetToolkitVersion() const { return CUDA_VERSION; }
-
-  void SetUp() override {
-    AutotunerUtil::ClearAutotuneResults();
-    HloTestBase::SetUp();
-  }
-
-  void TearDown() override {
-    AutotunerUtil::ClearAutotuneResults();
-    HloTestBase::TearDown();
-  }
-};
-
-class GemmFusionAutotunerTest : public StatelessAutotunerTest {
- public:
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options =
-        StatelessAutotunerTest::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_enable_triton_gemm(true);
-    debug_options.set_xla_gpu_cublas_fallback(false);
-    debug_options.set_xla_gpu_cudnn_gemm_fusion_level(0);
-    return debug_options;
-  }
-
-  se::CudaComputeCapability GetCudaComputeCapability() {
-    return backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .cuda_compute_capability();
-  }
-
-  void CheckTritonAutotuning(absl::string_view hlo,
-                             absl::string_view expected) {
-    HloPassPipeline pipeline("gemm_rewrite");
-    pipeline.AddPass<GemmFusion>(backend()
-                                     .default_stream_executor()
-                                     ->GetDeviceDescription()
-                                     .cuda_compute_capability());
-    tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
-                                        tsl::port::MaxParallelism());
-    DebugOptions opts;
-    MultiProcessKeyValueStore key_value_store;
-    pipeline.AddPass<GemmFusionAutotuner>(
-        AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
-                                    backend().memory_allocator()},
-                       opts},
-        GetToolkitVersion(), &thread_pool, key_value_store);
-
-    RunAndFilecheckHloRewrite(
-        hlo, std::move(pipeline), expected, [](const HloModule* m) {
-          VLOG(5) << m->ToString();
-          const HloInstruction* dot_fusion =
-              m->entry_computation()->root_instruction();
-          if (dot_fusion->opcode() == HloOpcode::kReduce) {
-            dot_fusion = dot_fusion->operand(0);
-          }
-          CHECK_EQ(dot_fusion->opcode(), HloOpcode::kFusion);
-          if (!dot_fusion->backend_config<GpuBackendConfig>()
-                   ->fusion_backend_config()
-                   .has_cudnn_fusion_config()) {
-            CHECK_GT(dot_fusion->backend_config<GpuBackendConfig>()
-                         .value()
-                         .fusion_backend_config()
-                         .triton_gemm_config()
-                         .block_m(),
-                     0);
-          }
-        });
-  }
-};
-
-class GemmFusionAutotunerTestWithMorePreciseReduction
-    : public GemmFusionAutotunerTest {
- public:
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options =
-        GemmFusionAutotunerTest::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction(
-        true);
-    return debug_options;
-  }
-};
-
-absl::StatusOr<std::vector<TritonGemmConfig>> GetPossibleMatmulAutotuneConfigs(
-    const HloDotInstruction& dot,
-    const se::CudaComputeCapability& compute_capability,
-    const int32_t toolkit_version, const DebugOptions& debug_options) {
-  se::GpuDeviceInfoProto deviceless_proto;
-  auto ccc = deviceless_proto.mutable_cuda_compute_capability();
-  ccc->set_major(compute_capability.major);
-  ccc->set_minor(compute_capability.minor);
-  DevicelessConfig test_config{se::DeviceDescription{deviceless_proto}};
-  AutotuneConfig autotune_config{test_config, debug_options};
-  GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version,
-                                    debug_options, nullptr);
-  return autotuner.GenerateTritonConfigs(dot);
-}
-
-TEST_F(GemmFusionAutotunerTest, AmpereUsesMoreThanTwoStages) {
-  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = f32[1024,1024] parameter(0)
-  p1 = f32[1024,1024] parameter(1)
-  ROOT r = f32[1024,1024] dot(p0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})")
-                                                  .value();
-  const se::CudaComputeCapability compute_capability{
-      se::CudaComputeCapability::AMPERE, /*minor=*/0};
-  TF_ASSERT_OK_AND_ASSIGN(
-      const std::vector<TritonGemmConfig> configs,
-      GetPossibleMatmulAutotuneConfigs(
-          *Cast<HloDotInstruction>(
-              module->entry_computation()->root_instruction()),
-          compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
-  EXPECT_TRUE(std::any_of(
-      configs.begin(), configs.end(),
-      [](const TritonGemmConfig& config) { return config.num_stages > 2; }));
-}
-
-TEST_F(GemmFusionAutotunerTest, SmallOutputCanUseLargeSplitK) {
-  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = f32[1024,1024] parameter(0)
-  p1 = f32[1024,1024] parameter(1)
-  ROOT r = f32[1024,1024] dot(p0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})")
-                                                  .value();
-  const se::CudaComputeCapability compute_capability{
-      se::CudaComputeCapability::AMPERE, /*minor=*/0};
-  TF_ASSERT_OK_AND_ASSIGN(
-      const std::vector<TritonGemmConfig> configs,
-      GetPossibleMatmulAutotuneConfigs(
-          *Cast<HloDotInstruction>(
-              module->entry_computation()->root_instruction()),
-          compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
-  EXPECT_TRUE(std::any_of(
-      configs.begin(), configs.end(),
-      [](const TritonGemmConfig& config) { return config.split_k >= 4; }));
-}
-
-TEST_F(GemmFusionAutotunerTest, LargeOutputDoesNotUseLargeSplitK) {
-  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = f32[20480,20480] parameter(0)
-  p1 = f32[20480,20480] parameter(1)
-  ROOT r = f32[20480,20480] dot(p0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})")
-                                                  .value();
-  const se::CudaComputeCapability compute_capability{
-      se::CudaComputeCapability::AMPERE, /*minor=*/0};
-  TF_ASSERT_OK_AND_ASSIGN(
-      const std::vector<TritonGemmConfig> configs,
-      GetPossibleMatmulAutotuneConfigs(
-          *Cast<HloDotInstruction>(
-              module->entry_computation()->root_instruction()),
-          compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
-  EXPECT_FALSE(std::any_of(
-      configs.begin(), configs.end(),
-      [](const TritonGemmConfig& config) { return config.split_k > 1; }));
-}
-
-TEST_F(GemmFusionAutotunerTest, Int8FusedGemm) {
-  const std::string hlo = R"(
-HloModule module
-
-ENTRY e {
-  x = s8[128,64] parameter(0)
-  c = f16[128,64] convert(x)
-
-  y = f16[64,6144] parameter(1)
-
-  ROOT out = f16[128,6144] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-  CheckTritonAutotuning(hlo, R"(
-// CHECK: ENTRY
-// CHECK: ROOT
-// CHECK-SAME: kCustom
-// CHECK-SAME: block_m
-)");
-
-  EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/5e-3, /*arel=*/5e-3}));
-}
-
-TEST_F(GemmFusionAutotunerTest, Int8FusedGemm256) {
-  const std::string hlo = R"(
-HloModule module
-
-ENTRY e {
-  x = s8[128,256] parameter(0)
-  c = f16[128,256] convert(x)
-
-  y = f16[256,6144] parameter(1)
-
-  ROOT out = f16[128,6144] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
-  CheckTritonAutotuning(hlo, R"(
-// CHECK: ENTRY
-// CHECK: ROOT
-// CHECK-SAME: kCustom
-// CHECK-SAME: block_m
-)");
-
-  EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2}));
-}
-
-TEST_F(GemmFusionAutotunerTest, SelectsSplitK) {
-  // Shapes with K >> M, N have to force split-K configurations.
-  const std::string kHloText = R"(
-HloModule t
-
-ENTRY e {
-  p0 = s8[7,8192] parameter(0)
-  p0c = f16[7,8192] convert(p0)
-  p1 = f16[8192,18] parameter(1)
-  ROOT dot.0 = f16[7,18] dot(p0c, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-
-  MatchOptimizedHlo(kHloText, R"(
-; CHECK: reduce
-; CHECK: ENTRY
-; CHECK-NEXT: parameter
-; CHECK-NEXT: parameter
-; CHECK-NEXT: kCustom
-; CHECK-NEXT: kLoop
-)");
-
-  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1, /*arel=*/0.5}));
-}
-
-TEST_F(GemmFusionAutotunerTestWithMorePreciseReduction, SelectsSplitK) {
-  // Shapes with K >> M, N have to force split-K configurations.
-  constexpr absl::string_view kHloText = R"(
-HloModule t
-
-ENTRY e {
-  p0 = s8[7,8192] parameter(0)
-  p0c = f16[7,8192] convert(p0)
-  p1 = f16[8192,18] parameter(1)
-  ROOT dot.0 = f16[7,18] dot(p0c, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-
-  MatchOptimizedHlo(kHloText, R"(
-; CHECK: reduce
-; CHECK: ENTRY
-; CHECK-NEXT: parameter
-; CHECK-NEXT: parameter
-; CHECK-NEXT: kCustom
-; CHECK-NEXT: kLoop
-)");
-
-  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-3}));
-}
-
-TEST_F(GemmFusionAutotunerTest, ApplySplitKWithoutAlteringTiling) {
-  const std::string kHloText = R"(
-triton_dot {
-  p0 = f16[55,120] parameter(0)
-  p1 = f16[120,20] parameter(1)
-  ROOT dot = f16[55,20] dot(p0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY e {
-  p0 = f16[55,120]{1,0} parameter(0)
-  p1 = f16[120,20]{1,0} parameter(1)
-  ROOT _ = f16[55,20] fusion(p0, p1), kind=kCustom, calls=triton_dot,
-    backend_config={"fusion_backend_config":{kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1}}}
-})";
-
-  MatchOptimizedHlo(kHloText, R"(
-; CHECK: f16[3,55,20]
-; CHECK: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1}
-; CHECK: f16[55,20]{1,0} {{(reduce|fusion)}}
-)");
-
-  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
-}
-
-// Modify block_k back to 16 once b/337839570 is fixed.
-// TODO(b/344770374): Make this test not fragile.
-TEST_F(GemmFusionAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) {
-  const std::string kHloText = R"(
-HloModule m
-
-%triton_gemm_dot {
-  %p1 = s8[4,12288]{1,0} parameter(1)
-  %p0 = s8[12288,1536]{1,0} parameter(0)
-  %convert.p0 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0)
-  %convert.p1 = f16[4,12288]{1,0} convert(s8[4,12288]{1,0} %p1)
-  %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %convert.p1, f16[12288,1536]{1,0} %convert.p0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT %convert = s8[4,1536]{1,0} convert(f16[4,1536]{1,0} %dot)
-}
-
-ENTRY %e {
-  %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0)
-  %convert = s8[4,12288]{1,0} parameter(1)
-  ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot,
-    backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}}
-})";
-
-  auto module = ParseAndReturnVerifiedModule(kHloText).value();
-  EXPECT_THAT(
-      backend().compiler()->RunBackend(std::move(module),
-                                       backend().default_stream_executor(),
-                                       {/*device_allocator=*/nullptr,
-                                        /*thread_pool=*/nullptr,
-                                        /*layout_canonicalization_callback=*/{},
-                                        /*is_autotuning_compilation=*/true}),
-      ::testing::AnyOf(
-          tsl::testing::StatusIs(
-              tsl::error::CANCELLED,
-              absl::StrFormat(
-                  "Compilation result discarded due to register spilling")),
-          // Hopper can't spill registers since wgmma instructions are
-          // asynchronous; instead it just runs out of them.
-          tsl::testing::StatusIs(
-              tsl::error::RESOURCE_EXHAUSTED,
-              absl::StrFormat("Register allocation failed"))));
-}
-
-// Modify block_k back to 16 once b/337839570 is fixed.
-// TODO(b/344770374): Make this test not fragile.
-TEST_F(GemmFusionAutotunerTest,
-       DoNotFilterOutAutotuningKernelSpillingRegisters) {
-  if (GetCudaComputeCapability().IsAtLeastHopper()) {
-    GTEST_SKIP() << "Hopper and newer run out of registers for such HLOs";
-  }
-  const std::string kHloText = R"(
-HloModule m
-
-%triton_gemm_dot {
-  %p1 = s8[4,12288]{1,0} parameter(1)
-  %p0 = s8[12288,1536]{1,0} parameter(0)
-  %convert.p0 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0)
-  %convert.p1 = f16[4,12288]{1,0} convert(s8[4,12288]{1,0} %p1)
-  %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %convert.p1, f16[12288,1536]{1,0} %convert.p0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT %convert = s8[4,1536]{1,0} convert(f16[4,1536]{1,0} %dot)
-}
-
-ENTRY %e {
-  %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0)
-  %convert = s8[4,12288]{1,0} parameter(1)
-  ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot,
-    backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}}
-})";
-
-  auto module = ParseAndReturnVerifiedModule(kHloText).value();
-  HloModuleConfig config = module->config();
-  DebugOptions debug_options = config.debug_options();
-  debug_options.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(
-      false);
-  config.set_debug_options(debug_options);
-  module->set_config(config);
-
-  std::unique_ptr<Executable> executable =
-      backend()
-          .compiler()
-          ->RunBackend(std::move(module), backend().default_stream_executor(),
-                       {/*device_allocator=*/nullptr,
-                        /*thread_pool=*/nullptr,
-                        /*layout_canonicalization_callback=*/{},
-                        /*is_autotuning_compilation=*/true})
-          .value();
-  EXPECT_NE(executable, nullptr);
-}
-
-// Modify block_k back to 16 once b/337839570 is fixed.
-TEST_F(GemmFusionAutotunerTest, RunAutotuningKernelNotSpillingRegisters) {
-  const std::string kHloText = R"(
-HloModule m
-
-%triton_gemm_dot {
-  %p1 = f16[4,12288]{1,0} parameter(1)
-  %p0 = s8[12288,1536]{1,0} parameter(0)
-  %convert.10406 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0)
-  ROOT %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %p1, f16[12288,1536]{1,0} %convert.10406), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY %e {
-  %p0 = s8[12288,1536]{1,0} parameter(0)
-  %p1 = f16[4,12288]{1,0} parameter(1)
-  ROOT %triton_dot = f16[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %p0, f16[4,12288]{1,0} %p1), kind=kCustom, calls=%triton_gemm_dot,
-    backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"2","num_ctas":"1"}}}
-})";
-
-  auto module = ParseAndReturnVerifiedModule(kHloText).value();
-  std::unique_ptr<Executable> executable =
-      backend()
-          .compiler()
-          ->RunBackend(std::move(module), backend().default_stream_executor(),
-                       {/*device_allocator=*/nullptr,
-                        /*thread_pool=*/nullptr,
-                        /*layout_canonicalization_callback=*/{},
-                        /*is_autotuning_compilation=*/true})
-          .value();
-  EXPECT_NE(executable, nullptr);
-}
-
-using GemmFusionAutotunerDumpTest = GemmFusionAutotunerTest;
-
-TEST_F(GemmFusionAutotunerDumpTest, Fp8CublasltFallbackSupport) {
-  const std::string kHloText = R"(
-HloModule o
-
-gemm_fusion {
-  p0 = f8e4m3fn[64,6144]{1,0} parameter(0)
-  p1 = f8e4m3fn[64,6144]{1,0} parameter(1)
-  ROOT %dot.0 = f32[64,64]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-}
-
-ENTRY main {
-  p0 = f8e4m3fn[64,6144]{1,0} parameter(0)
-  p1 = f8e4m3fn[64,6144]{1,0} parameter(1)
-  ROOT %dot.0 = f32[64,64]{1,0} fusion(p0, p1), kind=kCustom, calls=gemm_fusion, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(kHloText));
-
-  DebugOptions opts;
-  AutotuneConfig autotune_config{
-      DeviceConfig{backend().default_stream_executor(),
-                   backend().memory_allocator()},
-      opts};
-  AutotuneCacheKey cache_key(autotune_config.GetModelStr(),
-                             *module->entry_computation()->root_instruction());
-
-  TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override,
-                          ParseTextProto<AutotuneResults>(R"pb(
-                            version: 3
-                            results {
-                              device: "..."
-                              hlo: "..."
-                              result {
-                                gemm { algorithm: -1 }
-                                run_time { nanos: 14 }
-                              }
-                            })pb"));
-  autotune_results_override.mutable_results(0)->set_device(
-      std::string(cache_key.GetModelStr()));
-  autotune_results_override.mutable_results(0)->set_hlo(
-      std::string(cache_key.GetHlo()));
-  CHECK_OK(AutotunerUtil::LoadAutotuneResults(autotune_results_override));
-
-  HloPassPipeline pipeline("gemm_autotune");
-  tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
-                                      tsl::port::MaxParallelism());
-  MultiProcessKeyValueStore key_value_store;
-  pipeline.AddPass<GemmFusionAutotuner>(autotune_config, GetToolkitVersion(),
-                                        &thread_pool, key_value_store);
-  pipeline.AddPass<CallInliner>();
-  for (bool fp8_rewrite : {true, false}) {
-    pipeline.AddPass<GemmRewriter>(autotune_config.GetGpuComputeCapability(),
-                                   GetToolkitVersion(), fp8_rewrite);
-  }
-
-  TF_EXPECT_OK(HloTestBase::RunHloPass(&pipeline, module.get()));
-  const bool is_at_least_hopper =
-      std::holds_alternative<se::CudaComputeCapability>(
-          autotune_config.GetGpuComputeCapability()) &&
-      std::get<se::CudaComputeCapability>(
-          autotune_config.GetGpuComputeCapability())
-          .IsAtLeastHopper();
-  TF_ASSERT_OK_AND_ASSIGN(
-      bool filecheck_matches,
-      RunFileCheck(module->ToString(), is_at_least_hopper
-                                           ? "// CHECK: __cublas$lt"
-                                           : "// CHECK: __cublas$gemm"));
-  EXPECT_TRUE(filecheck_matches);
-}
-
-TEST_F(GemmFusionAutotunerDumpTest, DumpingWorks) {
-  HloModuleConfig config;
-  DebugOptions options = GetDebugOptionsForTest();
-  options.set_xla_gpu_cublas_fallback(true);
-  options.set_xla_gpu_dump_autotuned_gemm_fusions(true);
-  std::string output_directory;
-  if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) {
-    output_directory = tsl::testing::TmpDir();
-  }
-  options.set_xla_dump_to(output_directory);
-  config.set_debug_options(options);
-  // Computation is chosen such that relatively heavy math operations before the
-  // GEMM are not worth fusing because they would get duplicated many times and
-  // slow down execution. Therefore autotuning picks cuBLAS here.
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-fusion1 {
-  p0 = f32[333,333] parameter(0)
-  s = f32[333,333] sine(p0)
-  p1 = f32[333,333] parameter(1)
-  c = f32[333,333] cosine(p1)
-  ROOT dot = f32[333,333] dot(s, c),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY e {
-  p0 = f32[333,333] parameter(0)
-  p1 = f32[333,333] parameter(1)
-  ROOT rr = f32[333,333] fusion(p0, p1), kind=kCustom, calls=fusion1,
-    backend_config={"fusion_backend_config": {kind: "__triton_gemm"}}
-})",
-                                                       config));
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
-                          GetOptimizedModule(std::move(module)));
-
-  std::string dump;
-  TF_EXPECT_OK(tsl::ReadFileToString(
-      tsl::Env::Default(),
-      tsl::io::JoinPath(output_directory,
-                        FilenameFor(*optimized_module, /*prefix=*/"",
-                                    /*suffix=*/"gemm_fusion_0.rr.txt")),
-      &dump));
-  EXPECT_TRUE(*RunFileCheck(dump, R"(
-CHECK: HloModule rr
-CHECK-NOT: cublas
-CHECK: __triton_gemm
-CHECK-NOT: block_m
-)"));
-
-  dump.clear();
-
-  TF_EXPECT_OK(tsl::ReadFileToString(
-      tsl::Env::Default(),
-      tsl::io::JoinPath(
-          output_directory,
-          FilenameFor(*optimized_module, /*prefix=*/"",
-                      /*suffix=*/"gemm_fusion_0.rr.optimized.txt")),
-      &dump));
-  EXPECT_TRUE(*RunFileCheck(dump, R"(
-CHECK: HloModule rr
-CHECK-NOT: triton
-CHECK: cublas
-)"));
-}
-
-TEST_F(GemmFusionAutotunerTest, AutotuneCuDnnFusion) {
-  const std::string kHlo = R"(
-fusion1 {
-  p0 = f32[3,28,32] parameter(0)
-  p1 = f32[3,28,32] parameter(1)
-  ROOT d = f32[3,32,32] dot(p0, p1),
-    lhs_batch_dims={0}, rhs_batch_dims={0},
-    lhs_contracting_dims={1}, rhs_contracting_dims={1}
-}
-
-ENTRY e {
-  p0 = f32[3,28,32] parameter(0)
-  p1 = f32[3,28,32] parameter(1)
-  ROOT _ = f32[3,32,32] fusion(p0, p1), kind=kCustom, calls=fusion1,
-    backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
-})";
-
-  CheckTritonAutotuning(kHlo, R"(
-// CHECK: "plan_id":
-)");
-}
-
-// TODO(b/281489442): Write a test case called
-// `SkipConfigsProducingDeviantResults` or similar.
-
-class GemmFusionAutotunerLevelTest : public StatelessAutotunerTest,
-                                     public ::testing::WithParamInterface<int> {
- public:
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options =
-        StatelessAutotunerTest::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_autotune_level(GetParam());
-    debug_options.set_xla_gpu_cublas_fallback(false);
-    return debug_options;
-  }
-};
-
-TEST_P(GemmFusionAutotunerLevelTest, AllAutotuningLevelsWorkCorrectly) {
-  const std::string kHloText = R"(
-HloModule m
-
-ENTRY e {
-  p0 = pred[64,10] parameter(0)
-  p0c = f32[64,10] convert(p0)
-  p1 = f32[10,128] parameter(1)
-  ROOT r = f32[64,128] dot(p0c, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})";
-
-  MatchOptimizedHlo(kHloText, R"(
-; CHECK: kind=kCustom
-; CHECK-SAME: block_m
-      )");
-
-  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
-}
-
-TEST_P(GemmFusionAutotunerLevelTest, Deviceless) {
-  const std::string hlo = R"(
-HloModule module
-
-ENTRY e {
-  x = s8[16,16] parameter(0)
-  c = f16[16,16] convert(x)
-  y = f16[16,16] parameter(1)
-  ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
-  HloPassPipeline pipeline("gemm_rewrite_deviceless");
-  pipeline.AddPass<GemmFusion>(backend()
-                                   .default_stream_executor()
-                                   ->GetDeviceDescription()
-                                   .cuda_compute_capability());
-  tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
-                                      tsl::port::MaxParallelism());
-  DebugOptions opts;
-  MultiProcessKeyValueStore key_value_store;
-  pipeline.AddPass<GemmFusionAutotuner>(
-      AutotuneConfig{
-          DevicelessConfig{
-              backend().default_stream_executor()->GetDeviceDescription()},
-          opts},
-      GetToolkitVersion(), &thread_pool, key_value_store);
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(hlo));
-  if (GetDebugOptionsForTest().xla_gpu_autotune_level() == 0) {
-    TF_ASSERT_OK_AND_ASSIGN(bool changed,
-                            HloTestBase::RunHloPass(&pipeline, module.get()));
-    EXPECT_TRUE(changed);
-
-    // Check default configuration.
-    TF_ASSERT_OK_AND_ASSIGN(
-        bool filecheck_matches,
-        RunFileCheck(
-            module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
-            R"(
-// CHECK: backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"16","block_k":"32","split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"}},"force_earliest_schedule":false}
-            )"));
-    EXPECT_TRUE(filecheck_matches);
-  } else {
-    EXPECT_THAT(HloTestBase::RunHloPass(&pipeline, module.get()),
-                tsl::testing::StatusIs(
-                    tsl::error::INTERNAL,
-                    ::testing::HasSubstr(
-                        "Expect autotune result cache hit for deviceless")));
-  }
-}
-
-INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerLevelSweep,
-                         GemmFusionAutotunerLevelTest, ::testing::Range(0, 5));
-
-class GemmFusionAutotunerExhaustiveTest : public GemmFusionAutotunerTest {
- public:
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options =
-        GemmFusionAutotunerTest::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_exhaustive_tiling_search(true);
-    return debug_options;
-  }
-};
-
-TEST_F(GemmFusionAutotunerExhaustiveTest, DISABLED_CompileOnly) {
-  const std::string hlo = R"(
-HloModule module
-
-ENTRY e {
-  x = s8[16,16] parameter(0)
-  c = f16[16,16] convert(x)
-  y = f16[16,16] parameter(1)
-  ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
-  CheckTritonAutotuning(hlo, R"(
-// CHECK:   %triton_gemm_out_computation (
-// CHECK:   ROOT %out.1 = f16[16,16]{1,0} dot(%c.1, %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-// CHECK:   ROOT %triton_gemm_out = f16[16,16]{1,0} fusion(%x, %y), kind=kCustom, calls=%triton_gemm_out_computation
-// CHECK-SAME: "block_m":
-)");
-}
-
-// TODO(b/337839570): Triton currently has a limitation where it crashes
-// on small block_k values depending on the bit-width of the inputs to the
-// dot. For this test case, it should skip any block_k values that are <= 16
-// since the smallest type has a bit-width of 8.
-TEST_F(GemmFusionAutotunerExhaustiveTest, SkipsCrashingTileKConfig) {
-  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-HloModule module
-ENTRY e {
-  x = s8[33,33]{1,0} parameter(0)
-  c = f16[33,33]{1,0} convert(x)
-  y = f16[33,33]{1,0} parameter(1)
-  ROOT out = f16[33,33]{1,0} dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)")
-                                                  .value();
-  const se::CudaComputeCapability compute_capability{
-      se::CudaComputeCapability::AMPERE, /*minor=*/0};
-  TF_ASSERT_OK_AND_ASSIGN(
-      const std::vector<TritonGemmConfig> configs,
-      GetPossibleMatmulAutotuneConfigs(
-          *Cast<HloDotInstruction>(
-              module->entry_computation()->root_instruction()),
-          compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
-  EXPECT_TRUE(std::all_of(
-      configs.begin(), configs.end(),
-      [](const TritonGemmConfig& config) { return config.block_k > 16; }));
-}
-
-class GemmFusionAutotunerDisableSplitK : public GemmFusionAutotunerTest {
- public:
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options =
-        GemmFusionAutotunerTest::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_enable_split_k_autotuning(false);
-    return debug_options;
-  }
-};
-
-TEST_F(GemmFusionAutotunerDisableSplitK, SplitKIsDisabled) {
-  std::unique_ptr<VerifiedHloModule> module = ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = f32[1024,1024] parameter(0)
-  p1 = f32[1024,1024] parameter(1)
-  ROOT r = f32[1024,1024] dot(p0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})")
-                                                  .value();
-  const se::CudaComputeCapability compute_capability{
-      se::CudaComputeCapability::AMPERE, /*minor=*/0};
-  TF_ASSERT_OK_AND_ASSIGN(
-      const std::vector<TritonGemmConfig> configs,
-      GetPossibleMatmulAutotuneConfigs(
-          *Cast<HloDotInstruction>(
-              module->entry_computation()->root_instruction()),
-          compute_capability, GetToolkitVersion(), GetDebugOptionsForTest()));
-  EXPECT_TRUE(std::all_of(
-      configs.begin(), configs.end(),
-      [](const TritonGemmConfig& config) { return config.split_k == 1; }));
-}
-
-class GemmFusionAutotunerConfigTest
-    : public StatelessAutotunerTest,
-      public ::testing::WithParamInterface<bool> {};
-
-TEST_P(GemmFusionAutotunerConfigTest, SparseDotDiscardsUnsupportedTiles) {
-  const std::string kHloText = R"(
-HloModule test
-ENTRY wais {
-  lhs = f16[5,1600] parameter(0)
-  rhs = f16[3200,10] parameter(1)
-  meta = u16[5,200] parameter(2)
-  ROOT dot = f32[5,10] dot(lhs, rhs, meta),
-      lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
-  const se::CudaComputeCapability compute_capability{
-      se::CudaComputeCapability::AMPERE, /*minor=*/0};
-  DebugOptions debug_options = GetDebugOptionsForTest();
-  debug_options.set_xla_gpu_exhaustive_tiling_search(GetParam());
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      const std::vector<TritonGemmConfig> configs,
-      GetPossibleMatmulAutotuneConfigs(
-          *Cast<HloDotInstruction>(
-              module->entry_computation()->root_instruction()),
-          compute_capability, GetToolkitVersion(), debug_options));
-  for (const auto& config : configs) {
-    int metadata_size = config.block_m * config.block_k / 16;
-    EXPECT_LE(config.num_warps * WarpSize(), metadata_size);
-    EXPECT_GT(config.block_k, 16);  // kMinTileSize
-  }
-}
-
-INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerConfigSweep,
-                         GemmFusionAutotunerConfigTest, ::testing::Bool());
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_test.cc
deleted file mode 100644
index 44430f5..0000000
--- a/third_party/xla/xla/service/gpu/gemm_fusion_test.cc
+++ /dev/null
@@ -1,1334 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_fusion.h"
-
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/cublas_padding_requirements.h"
-#include "xla/service/gpu/triton_fusion_analysis.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "xla/xla.pb.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::testing::ElementsAre;
-using ::testing::FieldsAre;
-
-namespace m = ::xla::match;
-
-class GemmFusionTest : public HloTestBase {
- public:
-  GemmFusionTest()
-      : HloTestBase(/*verifier_layout_sensitive=*/true,
-                    /*allow_mixed_precision_in_hlo_verifier=*/false) {}
-
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_triton_gemm_any(false);
-    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
-    return debug_options;
-  }
-
-  se::GpuComputeCapability gpu_version_{
-      se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}};
-
-  void MatchHloModule(HloModule& module, absl::string_view pattern) {
-    TF_ASSERT_OK_AND_ASSIGN(bool filecheck_result,
-                            RunFileCheck(module.ToString(), pattern));
-    EXPECT_TRUE(filecheck_result);
-  }
-};
-
-TEST_F(GemmFusionTest, TransposeSubdimensionGroup) {
-  // This HLO is artificial because unnecessary reshapes get optimized
-  // out during compilation. It tests the ability of GemmFusion
-  // to handle transposes of groups of subdimensions.
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
-  p0 = f32[32,3] parameter(0)
-  t1 = f32[3,32] transpose(p0), dimensions={1,0}
-  r1 = f32[3,8,4] reshape(t1)
-  r0 = f32[3,32] reshape(r1)
-  p1 = f16[32,7] parameter(1)
-  c1 = f32[32,7] convert(p1)
-  ROOT d = f32[3,7] dot(r0, c1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})")
-                    .value();
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionTest, UnsupportedTransposeIsNotFused) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = f16[1,512,8,1024]{3,1,0,2} parameter(0)
-  c = f16[1,512,8,1024]{3,2,1,0} copy(p0)
-  b = f16[4096,1024]{1,0} bitcast(c)
-  p1 = f16[128,1024]{1,0} parameter(1)
-  ROOT d = f16[4096,128]{1,0} dot(b, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})")
-                    .value();
-  EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, BitcastChain) {
-  // This HLO is artificial because unnecessary reshapes get optimized
-  // out during compilation. It tests the ability of GemmFusion
-  // to handle various kinds of bitcasts.
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
-  p0 = s8[60,5] parameter(0)
-  r0 = s8[3,20,5] reshape(p0)
-  c0 = f16[3,20,5] convert(r0)
-  p1 = f16[3,200] parameter(1)
-  r12 = f16[600] reshape(p1)
-  r11 = f16[30,20] reshape(r12)
-  r1 = f16[3,10,20] reshape(r11)
-  ROOT d = f16[3,5,10] dot(c0, r1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={2},
-    lhs_batch_dims={0}, rhs_batch_dims={0}
-})")
-                    .value();
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionTest, SplitDimensionTwice) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = s8[4,2,32,4,2] parameter(0)
-  r1 = s8[8,32,8] reshape(p0)
-  t1 = s8[32,8,8] transpose(r1), dimensions={1,0,2}
-  r0 = s8[32,64] reshape(t1)
-  p1 = s8[32,32] parameter(1)
-  c0 = f16[32,32] convert(p1)
-  ROOT d = f16[64,32] dot(r0, c0),
-    lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})")
-                    .value();
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionTest, DoNotTriggerOnUnsupportedOutputConversions) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = f16[128,256] parameter(0)
-  p1 = f16[256,512] parameter(1)
-  r = f16[128,512] dot(p0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT c = u8[128,512] convert(r)
-})"));
-  EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, FuseDotWithTrivialNoncontractingDim) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
-  p0 = s8[60,5] parameter(0)
-  r0 = s8[3,20,5] reshape(p0)
-  c0 = f16[3,20,5] convert(r0)
-  p1 = f16[3,1,20] parameter(1)
-  ROOT d = f16[3,5,1] dot(c0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={2},
-    lhs_batch_dims={0}, rhs_batch_dims={0}
-})")
-                    .value();
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionTest, HandleDotIfCublasRequiresPadding) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
-  p0 = f16[5,3] parameter(0)
-  p1 = f16[5,7] parameter(1)
-  ROOT d = f16[3,7] dot(p0, p1),
-    lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-
-  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
-  EXPECT_TRUE(CublasRequiresPadding(
-      *xla::Cast<HloDotInstruction>(
-          module->entry_computation()->root_instruction()),
-      cc));
-  EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, FuseSliceOfParameterWithOtherUsers) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = f32[97,121] parameter(0)
-  s0 = f32[7,101] slice(p0), slice={[3:10], [10:111]}
-  p1 = f32[101,16] parameter(1)
-  d = f32[16,7] dot(p1, s0),
-    lhs_contracting_dims={0}, rhs_contracting_dims={1}
-  s1 = f32[3,33] slice(p0), slice={[10:13], [20:53]}
-  ROOT t = tuple(d, s1)
-})"));
-
-  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
-  EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, DoNotFuseSliceOfMixedDimensions) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = bf16[768,64] parameter(0)
-  s0 = bf16[768,32] slice(p0), slice={[0:768], [0:32]}
-  b0 = bf16[256,3,32] reshape(s0)
-  b1 = bf16[256,96] reshape(b0)
-  p1 = bf16[256,96] parameter(1)
-  ROOT d = bf16[96,96] dot(b1, p1),
-    lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-
-  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
-  EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, DoNotFuseSlicesOfNonMajorFragments) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = f32[2,2,256,256] parameter(0)
-  s0 = f32[1,1,256,256] slice(p0),
-    slice={[0:1], [0:1], [0:256], [0:256]}
-  r0 = f32[256,256] reshape(s0)
-  p1 = f16[2,2,256,256] parameter(1)
-  s1 = f16[1,1,256,256] slice(p1),
-    slice={[0:1], [0:1], [0:256], [0:256]}
-  r1 = f16[256,256] reshape(s1)
-  ROOT d = f32[256,256] dot(r0, r1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
-
-  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
-  EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, DynamicSliceIsFused) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  dot_lhs = f32[2,18] parameter(0)
-  dynamic_slice_input = f32[2,64,2] parameter(1)
-  start_index0 = s32[] parameter(2)
-  start_index1_2 = s32[] constant(0)
-  dynamic_slice = f32[1,64,2] dynamic-slice(dynamic_slice_input, start_index0, start_index1_2, start_index1_2),
-                  dynamic_slice_sizes={1,64,2}
-  reshape = f32[64,2] reshape(dynamic_slice)
-  ROOT dot = f16[18,64] dot(dot_lhs, reshape),
-             lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})"));
-
-  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
-                             se::CudaComputeCapability::AMPERE, 0})
-                  .Run(module.get())
-                  .value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
-                                    m::Parameter(), m::Constant()))));
-}
-
-TEST_F(GemmFusionTest, DynamicSlicesAreFusedEvenIfTheyShareIndices) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = f32[2,64,2] parameter(0)
-  p1 = s32[] parameter(1)
-  p2 = s32[] parameter(2)
-  p3 = s32[] parameter(3)
-  ds0 = f32[1,64,2] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,64,2}
-  a = f32[64,2] reshape(ds0)
-  ds1 = f32[1,64,2] dynamic-slice(p0, p3, p2, p1), dynamic_slice_sizes={1,64,2}
-  b = f32[64,2] reshape(ds1)
-  ROOT d = f16[64,64] dot(a, b),
-           lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
-  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
-                             se::CudaComputeCapability::AMPERE, 0})
-                  .Run(module.get())
-                  .value());
-  // TODO(b/339810582): Don't duplicate scalar parameters to dot fusions,
-  // because they are never tiled differently.
-  // TODO(b/339814210): Don't count scalar parameters towards dot fusion
-  // parameter limit.
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(),
-                            m::Parameter(), m::Parameter(), m::Parameter(),
-                            m::Parameter(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionTest, DoNotFuseDynamicSliceOfNonMajorFragments) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  dot_lhs = f32[2,4]{1,0} parameter(0)
-  dynamic_slice_input = f32[4,5,2]{2,1,0} parameter(1)
-  c0 = s32[] constant(0)
-  c2 = s32[] constant(2)
-  dynamic_slice = f32[4,1,2]{2,1,0} dynamic-slice(dynamic_slice_input, c0, c2, c0),
-                  dynamic_slice_sizes={4,1,2}
-  reshape = f32[4,2]{1,0} reshape(dynamic_slice)
-  ROOT dot = f32[4,4]{1,0} dot(dot_lhs, reshape),
-             lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})"));
-  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
-  // FusionDecision "Unsupported dynamic slice on non-major-most dimension."
-  EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
-}
-
-TEST_F(GemmFusionTest, CanFuseDynamicSliceOfContractingDimIfItIsMajor) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  dot_lhs = f32[2,4]{1,0} parameter(0)
-  dynamic_slice_input = f32[5,5]{1,0} parameter(1)
-  start_index0 = s32[] constant(2)
-  start_index1 = s32[] constant(0)
-  dynamic_slice = f32[2,5]{1,0} dynamic-slice(dynamic_slice_input, start_index0, start_index1),
-                  dynamic_slice_sizes={2,5}
-  ROOT d = f32[4,5]{1,0} dot(dot_lhs, dynamic_slice),
-           lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
-                             se::CudaComputeCapability::AMPERE, 0})
-                  .Run(module.get())
-                  .value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
-                                    m::Constant(), m::Constant()))));
-}
-
-TEST_F(GemmFusionTest, SliceToDegenerateIsSkipped) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p = f32[3] parameter(0)
-  s = f32[1] slice(p), slice={[2:3]}
-  r = f32[] reshape(s)
-  b = f32[3,3] broadcast(r), dimensions={}
-  ROOT d = f32[3,3] dot(b, b),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)"));
-  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
-
-  ASSERT_TRUE(GemmFusion(cc).Run(module.get()).value());
-
-  // Slice is not fused.
-  MatchHloModule(*module, R"(
-; CHECK-NOT: slice
-; CHECK: ENTRY
-; CHECK: slice
-)");
-}
-
-TEST_F(GemmFusionTest, MultipleUsesAreHandled) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  c = f32[] constant(1)
-  b = f32[6,8] broadcast(c), dimensions={}
-  p0 = f32[6,8] parameter(0)
-  a1 = f32[6,8] add(p0, b)
-  e = f32[6,8] exponential(a1)
-  a2 = f32[6,8] add(e, b)
-  d = f32[6,8] divide(b, a2)
-  p2 = f16[8,6] parameter(1)
-  cv = f32[8,6] convert(p2)
-  ROOT r = f32[6,6] dot(d, cv),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
-  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
-  EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionTest, BinaryElementwiseOfBroadcastIsFused) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p2 = f32[3072] parameter(2)
-  b = f32[8192,3072] broadcast(p2), dimensions={1}
-  p0 = f16[8192,3072] parameter(0)
-  p0c = f32[8192,3072] convert(p0)
-  a = f32[8192,3072] add(p0c, b)
-  p1 = f32[3072,768] parameter(1)
-  ROOT r = f32[8192,768] dot(a, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
-  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
-  EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionTest, BinaryElementwiseOfUnsupportedBroadcastIsNotFused) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p2 = f32[768] parameter(2)
-  b = f32[8192,768,4] broadcast(p2), dimensions={1}
-  s = f32[8192,3072] bitcast(b)
-  p0 = f16[8192,3072] parameter(0)
-  p0c = f32[8192,3072] convert(p0)
-  a = f32[8192,3072] add(p0c, s)
-  p1 = f32[3072,768] parameter(1)
-  ROOT r = f32[8192,768] dot(a, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
-  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
-  EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
-}
-
-class GemmFusionLevel2Test : public GemmFusionTest {
- public:
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GemmFusionTest::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_triton_fusion_level(2);
-    return debug_options;
-  }
-};
-
-TEST_F(GemmFusionTest, ConcatenationDivisibleBy64IsFused) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = bf16[8192,1]{1,0} parameter(0)
-  p1 = bf16[2752,8192]{1,0} parameter(1)
-  p2 = bf16[2752,8192]{1,0} parameter(2)
-  concat = bf16[5504,8192]{1,0} concatenate(p1, p2), dimensions={0}
-  bitcast = bf16[8192,5504]{0,1} bitcast(concat)
-  ROOT r = f32[1,5504]{1,0} dot(p0, bitcast),
-    lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
-  EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionLevel2Test, ReshapeToScalarIsHandled) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = s8[5,3] parameter(0)
-  c = f16[5,3] convert(p0)
-  p1 = f16[1] parameter(1)
-  r = f16[] reshape(p1)
-  b = f16[5,7] broadcast(r)
-  ROOT d = f16[3,7] dot(c, b),
-    lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionLevel2Test, DoNotFuseIncompatibleDimensionSplits) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p1 = s8[5,7,2,3]{3,2,1,0} parameter(1)
-  t1 = s8[7,5,2,3]{3,2,1,0} transpose(p1), dimensions={1,0,2,3}
-  r1 = s8[7,30]{1,0} reshape(t1)
-  cvt = f16[7,30]{1,0} convert(r1)
-  p2 = f16[2,7,5,3]{3,2,1,0} parameter(2)
-  t2 = f16[7,2,5,3]{3,2,1,0} transpose(p2), dimensions={1,0,2,3}
-  r2 = f16[7,30]{1,0} reshape(t2)
-  a = f16[7,30]{1,0} add(cvt, r2)
-  p0 = f16[7,79]{1,0} parameter(0)
-  ROOT dot = f16[30,79]{1,0} dot(a, p0),
-    lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Fusion(m::Transpose(), m::Parameter(), m::Parameter())));
-}
-
-TEST_F(GemmFusionLevel2Test, DoNotFuseTooManyParameters) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  tmp_0 = f32[] constant(1)
-  tmp_1 = f32[3,49]{1,0} broadcast(tmp_0), dimensions={}
-  tmp_2 = f32[3,49]{1,0} parameter(6)
-  tmp_3 = f32[] constant(0)
-  tmp_4 = f32[3,49]{1,0} broadcast(tmp_3), dimensions={}
-  tmp_5 = pred[3,49]{1,0} compare(tmp_2, tmp_4), direction=GT
-  tmp_6 = f32[3,49]{1,0} convert(tmp_5)
-  tmp_7 = f32[3,49]{1,0} subtract(tmp_1, tmp_6)
-  tmp_8 = s32[] parameter(13)
-  tmp_9 = f32[] convert(tmp_8)
-  tmp_10 = f32[] maximum(tmp_9, tmp_0)
-  tmp_11 = f32[] divide(tmp_3, tmp_10)
-  tmp_12 = f32[3,49]{1,0} broadcast(tmp_11), dimensions={}
-  tmp_13 = pred[3,49]{1,0} parameter(7)
-  tmp_14 = pred[3,49]{1,0} parameter(10)
-  tmp_15 = pred[3,49]{1,0} and(tmp_13, tmp_14)
-  tmp_16 = f32[3,49]{1,0} convert(tmp_15)
-  tmp_17 = f32[3,49]{1,0} multiply(tmp_12, tmp_16)
-  tmp_18 = f32[3,49]{1,0} negate(tmp_17)
-  tmp_19 = f32[3,49]{1,0} multiply(tmp_7, tmp_18)
-  tmp_20 = f32[3,49]{1,0} parameter(19)
-  tmp_21 = f32[3,49]{1,0} subtract(tmp_1, tmp_20)
-  tmp_22 = f32[3,49]{1,0} divide(tmp_19, tmp_21)
-  tmp_23 = f32[3,49]{1,0} negate(tmp_22)
-  tmp_24 = f32[3,49]{1,0} negate(tmp_6)
-  tmp_25 = f32[3,49]{1,0} multiply(tmp_24, tmp_17)
-  tmp_26 = f32[3,49]{1,0} divide(tmp_25, tmp_20)
-  tmp_27 = f32[3,49]{1,0} add(tmp_23, tmp_26)
-  tmp_28 = f32[3,49]{1,0} parameter(18)
-  tmp_29 = f32[3,49]{1,0} multiply(tmp_27, tmp_28)
-  tmp_30 = f32[3,49]{1,0} parameter(17)
-  tmp_31 = f32[3,49]{1,0} multiply(tmp_29, tmp_30)
-  tmp_32 = f32[3,49]{1,0} parameter(16)
-  tmp_33 = f32[3,49]{1,0} multiply(tmp_31, tmp_32)
-  tmp_34 = f32[3,49]{1,0} parameter(15)
-  tmp_35 = f32[3,49]{1,0} add(tmp_33, tmp_34)
-  tmp_36 = f32[3,49]{1,0} parameter(14)
-  tmp_37 = f32[3,49]{1,0} add(tmp_35, tmp_36)
-  tmp_38 = f32[1,1]{1,0} constant({ {0} })
-  tmp_39 = f32[1,1]{1,0} broadcast(tmp_38), dimensions={0,1}
-  tmp_40 = f32[] reshape(tmp_39)
-  tmp_41 = f32[3,32]{1,0} broadcast(tmp_40), dimensions={}
-  tmp_42 = u32[48]{0} parameter(11)
-  tmp_43 = u32[48]{0} parameter(5)
-  tmp_44 = u32[96]{0} concatenate(tmp_42, tmp_43), dimensions={0}
-  tmp_45 = u32[3,32]{1,0} reshape(tmp_44)
-  tmp_46 = u32[96]{0} reshape(tmp_45)
-  tmp_47 = u32[] constant(1)
-  tmp_48 = u32[3,32]{1,0} broadcast(tmp_47), dimensions={}
-  tmp_49 = u32[96]{0} reshape(tmp_48)
-  tmp_50 = u32[96]{0} shift-right-logical(tmp_46, tmp_49)
-  tmp_51 = u32[3,32]{1,0} reshape(tmp_50)
-  tmp_52 = u32[3,32]{1,0} or(tmp_51, tmp_48)
-  tmp_53 = f32[3,32]{1,0} bitcast-convert(tmp_52)
-  tmp_54 = f32[3,32]{1,0} broadcast(tmp_0), dimensions={}
-  tmp_55 = f32[3,32]{1,0} subtract(tmp_53, tmp_54)
-  tmp_56 = f32[1,1]{1,0} constant({ {1} })
-  tmp_57 = f32[1,1]{1,0} broadcast(tmp_56), dimensions={0,1}
-  tmp_58 = f32[] reshape(tmp_57)
-  tmp_59 = f32[3,32]{1,0} broadcast(tmp_58), dimensions={}
-  tmp_60 = f32[3,32]{1,0} multiply(tmp_55, tmp_59)
-  tmp_61 = f32[3,32]{1,0} add(tmp_60, tmp_41)
-  tmp_62 = f32[3,32]{1,0} maximum(tmp_41, tmp_61)
-  tmp_63 = f32[3,32]{1,0} broadcast(tmp_3), dimensions={}
-  tmp_64 = pred[3,32]{1,0} compare(tmp_62, tmp_63), direction=LT
-  tmp_65 = f32[3,32]{1,0} convert(tmp_64)
-  tmp_66 = f32[3,49]{1,0} parameter(9)
-  tmp_67 = f32[49]{0} parameter(4)
-  tmp_68 = f32[3,49]{1,0} broadcast(tmp_67), dimensions={1}
-  tmp_69 = f32[3,49]{1,0} add(tmp_66, tmp_68)
-  tmp_70 = f32[1,49]{1,0} parameter(12)
-  tmp_71 = f32[1,49]{1,0} broadcast(tmp_0), dimensions={}
-  tmp_72 = f32[1,49]{1,0} divide(tmp_70, tmp_71)
-  tmp_73 = f32[1,49]{1,0} broadcast(tmp_72), dimensions={0,1}
-  tmp_74 = f32[49]{0} reshape(tmp_73)
-  tmp_75 = f32[3,49]{1,0} broadcast(tmp_74), dimensions={1}
-  tmp_76 = f32[3,49]{1,0} subtract(tmp_69, tmp_75)
-  tmp_77 = f32[1,49]{1,0} parameter(3)
-  tmp_78 = f32[1,49]{1,0} parameter(8)
-  tmp_79 = f32[1,49]{1,0} divide(tmp_78, tmp_71)
-  tmp_80 = f32[1,49]{1,0} multiply(tmp_72, tmp_72)
-  tmp_81 = f32[1,49]{1,0} subtract(tmp_79, tmp_80)
-  tmp_82 = f32[1,49]{1,0} add(tmp_81, tmp_71)
-  tmp_83 = f32[1,49]{1,0} rsqrt(tmp_82)
-  tmp_84 = f32[1,49]{1,0} multiply(tmp_77, tmp_83)
-  tmp_85 = f32[1,49]{1,0} broadcast(tmp_84), dimensions={0,1}
-  tmp_86 = f32[49]{0} reshape(tmp_85)
-  tmp_87 = f32[3,49]{1,0} broadcast(tmp_86), dimensions={1}
-  tmp_88 = f32[3,49]{1,0} multiply(tmp_76, tmp_87)
-  tmp_89 = f32[1,49]{1,0} parameter(2)
-  tmp_90 = f32[1,49]{1,0} broadcast(tmp_89), dimensions={0,1}
-  tmp_91 = f32[49]{0} reshape(tmp_90)
-  tmp_92 = f32[3,49]{1,0} broadcast(tmp_91), dimensions={1}
-  tmp_93 = f32[3,49]{1,0} add(tmp_88, tmp_92)
-  tmp_94 = f32[49,32]{1,0} parameter(1)
-  tmp_95 = f32[3,32]{1,0} dot(tmp_93, tmp_94), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  tmp_96 = f32[32]{0} parameter(0)
-  tmp_97 = f32[3,32]{1,0} broadcast(tmp_96), dimensions={1}
-  tmp_98 = f32[3,32]{1,0} add(tmp_95, tmp_97)
-  tmp_99 = f32[3,32]{1,0} multiply(tmp_65, tmp_98)
-  tmp_100 = f32[3,32]{1,0} divide(tmp_99, tmp_63)
-  tmp_101 = f32[3,32]{1,0} maximum(tmp_100, tmp_63)
-  ROOT tmp_102 = f32[49,32]{1,0} dot(tmp_37, tmp_101), lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-  EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
-            HloOpcode::kFusion);
-  EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(),
-            HloInstruction::FusionKind::kCustom);
-  EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(),
-            TritonFusionAnalysis::kMaxParameterPerDotOperand * 2);
-}
-
-TEST_F(GemmFusionLevel2Test,
-       DoNotFuseTooManyParametersWhenAnInstructionWouldAddMultipleParameters) {
-  static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4,
-                "We have to update this test.");
-  // If we fuse the select, it adds 2 additional parameters at once (not 3,
-  // because the select instruction itself is removed from the parameters).
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  a = f32[3,49]{1,0} parameter(0)
-  b = f32[3,49]{1,0} parameter(1)
-  c = pred[3,49]{1,0} parameter(2)
-  d = f32[3,49]{1,0} parameter(3)
-  e = f32[3,49]{1,0} parameter(4)
-  add0 = f32[3,49]{1,0} add(a, b)
-  select = f32[3,49]{1,0} select(c, d, e)
-  add1 = f32[3,49]{1,0} add(add0, select)
-  f = f32[3,32]{1,0} parameter(5)
-  ROOT tmp_102 = f32[49,32]{1,0} dot(add1, f), lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})"));
-
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-  EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
-            HloOpcode::kFusion);
-  EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(),
-            HloInstruction::FusionKind::kCustom);
-  EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(),
-            TritonFusionAnalysis::kMaxParameterPerDotOperand + 1);
-}
-
-TEST_F(GemmFusionLevel2Test, DoNotFuseTooManyParametersForConcat) {
-  static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4,
-                "We have to update this test.");
-  // Fusing the concat shouldn't exceed the allowed parameter limit.
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  a = f32[3,3]{1,0} parameter(0)
-  b = f32[3,3]{1,0} parameter(1)
-  c = f32[3,3]{1,0} parameter(2)
-  d = f32[3,3]{1,0} parameter(3)
-  e = f32[3,3]{1,0} parameter(4)
-  f = f16[3,3]{1,0} parameter(5)
-  concat = f32[15,3]{1,0} concatenate(a, b, c, d, e), dimensions={0}
-  convert = f32[3,3]{1,0} convert(f)
-  ROOT dot = f32[15,3]{1,0} dot(concat, convert), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-  EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
-            HloOpcode::kFusion);
-  EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(),
-            HloInstruction::FusionKind::kCustom);
-  EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(),
-            TritonFusionAnalysis::kMaxParameterPerDotOperand + 1);
-}
-
-TEST_F(GemmFusionLevel2Test,
-       InstructionsReachableFromMultipleOperandsAreHandledCorrectly) {
-  static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4,
-                "We have to update this test.");
-  // There was a bug where dead code was generated into some fusions in a
-  // specific edge case. When some instructions were reachable through both the
-  // LHS and the RHS operands, the BFS (breadth-first search) through the LHS
-  // operand "marked" one operation as non-fusible because it would exceed the
-  // limit on fusion parameters per operand. But the BFS through the RHS operand
-  // went through that node and fused some more operands. So the resulting
-  // fusion was not connected and caused errors. This test case checks that such
-  // configurations generate a correct HLO now.
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  a = f32[2,4]{1,0} parameter(0)
-  b = f32[2,4]{1,0} parameter(1)
-  c = f32[2,4]{1,0} parameter(2)
-  d = f32[2,4]{1,0} parameter(3)
-  e = f32[2,4]{1,0} parameter(4)
-  add0 = f32[2,4]{1,0} add(a, b)
-  add1 = f32[2,4]{1,0} add(add0, c)
-  add2 = f32[2,4]{1,0} add(add1, d)
-  add3 = f32[2,4]{1,0} add(add2, e)
-  ROOT r = f32[2,2]{1,0} dot(add3, add0),
-           lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-  // ~VerifiedHloModule() will verify the module.
-}
-
-TEST_F(GemmFusionLevel2Test, EachScopeIsFusedToASeparateSubgraph) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  a = f32[2,4]{1,0} parameter(0)
-  b = f32[2,4]{1,0} parameter(1)
-  add = f32[2,4]{1,0} add(a, b)
-  ROOT r = f32[2,2]{1,0} dot(add, add),
-           lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-
-  MatchHloModule(*module, R"(
-CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
-CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1)
-CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]])
-CHECK-DAG: %[[P2:.*]] = f32[2,4]{1,0} parameter(2)
-CHECK-DAG: %[[P3:.*]] = f32[2,4]{1,0} parameter(3)
-CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P2]], f32[2,4]{1,0} %[[P3]])
-CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]])
-CHECK: ENTRY
-CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
-CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1)
-CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0}
-CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]]),
-CHECK-SAME: kind=kCustom
-CHECK-SAME: __triton_gemm
-})");
-}
-
-// The two inputs of the add operation are the same and they are iterated the
-// same way, so the same parameter node is reused for them.
-// The reuse happens per "operand fusion", so the adds in the LHS and RHS
-// scopes still use different nodes.
-TEST_F(GemmFusionLevel2Test, ParamNodesAreReusedIfTheyHaveTheSameIterSpec) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  a = f32[2,4]{1,0} parameter(0)
-  add = f32[2,4]{1,0} add(a, a)
-  ROOT r = f32[2,2]{1,0} dot(add, add),
-           lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-
-  MatchHloModule(*module, R"(
-CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
-CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P0]])
-CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1)
-CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P1]])
-CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]])
-CHECK: ENTRY
-CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
-CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0}
-CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P0]])
-CHECK-SAME: kind=kCustom
-CHECK-SAME: __triton_gemm
-})");
-}
-
-// NEGATE has the same iteration spec at both usages, so the node is reused
-// (implying that P0 is also reused).
-TEST_F(GemmFusionLevel2Test, NonParamNodesAreReusedIfTheyHaveTheSameIterSpec) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  a = f32[4,4]{1,0} parameter(0)
-  b = f32[4,4]{1,0} parameter(1)
-  negate = f32[4,4]{1,0} negate(a)
-  sine = f32[4,4]{1,0} sine(negate)
-  add = f32[4,4]{1,0} add(negate, sine)
-  ROOT r = f32[4,4]{1,0} dot(add, b),
-           lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-
-  MatchHloModule(*module, R"(
-CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
-CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
-CHECK-DAG: %[[NEGATE:.*]] = f32[4,4]{1,0} negate(f32[4,4]{1,0} %[[P0]])
-CHECK-DAG: %[[SINE:.*]] = f32[4,4]{1,0} sine(f32[4,4]{1,0} %[[NEGATE]])
-CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(f32[4,4]{1,0} %[[NEGATE]], f32[4,4]{1,0} %[[SINE]])
-CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(f32[4,4]{1,0} %[[ADD]], f32[4,4]{1,0} %[[P1]])
-CHECK: ENTRY
-CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
-CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
-CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0}
-CHECK-SAME: fusion(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P1]])
-CHECK-SAME: kind=kCustom
-CHECK-SAME: __triton_gemm
-})");
-}
-
-// The direct read of the input and the transposed read of the input have
-// different iteration specs, so we don't reuse the node.
-TEST_F(GemmFusionLevel2Test, NodesAreNotReusedIfTheyHaveDifferentIterSpecs) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  a = f32[4,4]{1,0} parameter(0)
-  b = f32[4,4]{1,0} parameter(1)
-  tr_a = f32[4,4]{1,0} transpose(a), dimensions={1,0}
-  add = f32[4,4]{1,0} add(a, tr_a)
-  ROOT r = f32[4,4]{1,0} dot(add, b),
-           lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-
-  MatchHloModule(*module, R"(
-CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
-CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
-CHECK-DAG: %[[P2:.*]] = f32[4,4]{1,0} parameter(2)
-CHECK-DAG: %[[TRANSPOSE:.*]] = f32[4,4]{1,0} transpose(f32[4,4]{1,0} %[[P1]])
-CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[TRANSPOSE]])
-CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(f32[4,4]{1,0} %[[ADD]], f32[4,4]{1,0} %[[P2]])
-CHECK: ENTRY
-CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
-CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
-CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0}
-CHECK-SAME: fusion(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P1]])
-CHECK-SAME: kind=kCustom
-CHECK-SAME: __triton_gemm
-})");
-}
-
-TEST_F(GemmFusionLevel2Test, OperationsAddingMoreParametersGetMultipleTries) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-e {
-  p0 = f32[2,2] parameter(0)
-  c0 = f32[] constant(12345)
-  b0 = f32[2,2] broadcast(c0), dimensions={}
-  m0 = f32[2,2] multiply(p0, b0)
-  c1 = f32[] constant(34567)
-  b1 = f32[2,2] broadcast(c1), dimensions={}
-  a0 = f32[2,2] add(m0, b1)
-  b3 = f32[2,2,2] broadcast(a0), dimensions={0,1}
-  p2 = f32[2,2,2] parameter(2)
-  m2 = f32[2,2,2] multiply(p2, b3)
-  p1 = f32[2]{0} parameter(1)
-  c2 = f32[] constant(5678)
-  b2 = f32[2] broadcast(c2), dimensions={}
-  a1 = f32[2]{0} add(p1, b2)
-  b4 = f32[2,2,2] broadcast(a1), dimensions={2}
-  m1 = f32[2,2,2] multiply(m2, b4)
-  b = f32[4,2] bitcast(m1)
-  p3 = f16[2,2] parameter(3)
-  p3c = f32[2,2] convert(p3)
-  ROOT r = f32[4,2] dot(b, p3c),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
-
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
-                                    m::Parameter(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionLevel2Test, GemmFusionBailsOutPreAmpere) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = f32[2,53] parameter(0)
-  p0e = f32[2,53] exponential(p0)
-  p1 = s16[53,2] parameter(1)
-  p1c = f32[53,2] convert(p1)
-  ROOT dot = f32[2,2] dot(p0e, p1c),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
-  EXPECT_THAT(
-      GemmFusion(se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0})
-          .Run(module.get()),
-      tsl::testing::StatusIs(
-          absl::StatusCode::kFailedPrecondition,
-          ::testing::HasSubstr("Triton support is only enabled for Ampere GPUs "
-                               "(compute capability 8.0) and up, but got")));
-}
-
-TEST_F(GemmFusionLevel2Test, GemmFusionSucceedsOnNonCudaGpu) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = f32[2,53] parameter(0)
-  p0e = f32[2,53] exponential(p0)
-  p1 = s16[53,2] parameter(1)
-  p1c = f32[53,2] convert(p1)
-  ROOT dot = f32[2,2] dot(p0e, p1c),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
-  EXPECT_TRUE(GemmFusion(se::RocmComputeCapability{}).Run(module.get()).ok());
-}
-
-TEST_F(GemmFusionLevel2Test, ParameterUsedElementwiseTwiceIsFused) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-HloModule t
-
-ENTRY e {
-  p0 = f32[2,35] parameter(0)
-  p0n = f32[2,35] negate(p0)
-  p0e = f32[2,35] exponential(p0)
-  a = f32[2,35] add(p0e, p0n)
-  p1 = f16[35,2] parameter(1)
-  p1c = f32[35,2] convert(p1)
-  ROOT dot = f32[2,2] dot(a, p1c),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
-  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
-                             se::CudaComputeCapability::AMPERE, 0})
-                  .Run(module.get())
-                  .value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch((m::Fusion(m::Parameter(), m::Parameter()))));
-  TF_ASSERT_OK_AND_ASSIGN(
-      const auto analysis,
-      TritonFusionAnalysis::Execute(*module->entry_computation()
-                                         ->root_instruction()
-                                         ->called_computations()[0]));
-  EXPECT_EQ(analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).size(),
-            1);
-  EXPECT_EQ(analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).size(),
-            1);
-}
-
-TEST_F(GemmFusionLevel2Test,
-       ParameterUsedNonElementwiseTwiceIsFusedOnBothPaths) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-HloModule t
-
-ENTRY e {
-  p0 = f32[4,4] parameter(0)
-  p0t = f32[4,4] transpose(p0), dimensions={1,0}
-  a = f32[4,4] add(p0, p0t)
-  p1 = f16[4,5] parameter(1)
-  p1c = f32[4,5] convert(p1)
-  ROOT dot = f32[4,5] dot(a, p1c),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
-  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
-                             se::CudaComputeCapability::AMPERE, 0})
-                  .Run(module.get())
-                  .value());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionLevel2Test,
-       ComputationParameterWithMultipleUsersIsNotTrivialToFuse) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = f32[400,400] parameter(0)
-
-  c0 = f16[400,400] convert(p0)
-  p1 = f16[400,400] parameter(1)
-  dot0 = f16[400,400] dot(c0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
-  c1 = f16[400,400] convert(p0)
-  p2 = f16[400,400] parameter(2)
-  dot1 = f16[400,400] dot(c1, p2),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
-  ROOT a = f16[400,400] add(dot0, dot1)
-})"));
-  EXPECT_FALSE(GemmFusion(se::CudaComputeCapability{
-                              se::CudaComputeCapability::AMPERE, 0})
-                   .Run(module.get())
-                   .value());
-}
-
-TEST_F(GemmFusionLevel2Test, NarrowingConversionIsAlwaysBetterToFuse) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-ENTRY e {
-  p0 = s8[512,512] parameter(0)
-  c0 = f16[512,512] convert(p0)
-  p1 = f16[512,512] parameter(1)
-  dot0 = f16[512,512] dot(c0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
-  n = f16[512,512] negate(c0)
-  ROOT a = f16[512,512] add(dot0, n)
-})"));
-  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
-                             se::CudaComputeCapability::AMPERE, 0})
-                  .Run(module.get())
-                  .value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch((m::Add(m::Fusion(m::Parameter(), m::Parameter()),
-                                 m::Negate()))));
-}
-
-TEST_F(GemmFusionLevel2Test, NestedSlicingIsAnalyzedCorrectly) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-triton_gemm_d_computation {
-  p0 = f32[6,24]{1,0} parameter(0)
-  slice1 = f32[5,20]{1,0} slice(p0), slice={[1:6], [3:23]}
-  n1 = f32[5,20]{1,0} negate(slice1)
-  slice2 = f32[3,7]{1,0} slice(n1), slice={[1:4], [13:20]}
-  p1 = f32[7,37]{1,0} parameter(1)
-  ROOT d = f32[3,37]{1,0} dot(slice2, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-ENTRY e {
-  p0 = f32[7,37]{1,0} parameter(0)
-  p1 = f32[6,24]{1,0} parameter(1)
-  ROOT triton_gemm_d = f32[3,37]{1,0} fusion(p1, p0), kind=kCustom,
-    calls=triton_gemm_d_computation
-})"));
-  const HloComputation* computation =
-      module->entry_computation()->root_instruction()->called_computations()[0];
-  TF_ASSERT_OK_AND_ASSIGN(const auto analysis,
-                          TritonFusionAnalysis::Execute(*computation));
-  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS,
-                                 computation->parameter_instruction(0), 0),
-              ElementsAre(FieldsAre(/*stride=*/24, /*count=*/6,
-                                    /*slice_start=*/2, /*sliced_count=*/3,
-                                    /*subfragments=*/ElementsAre(3))));
-  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS,
-                                 computation->parameter_instruction(0), 1),
-              ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24,
-                                    /*slice_start=*/16, /*sliced_count=*/7,
-                                    /*subfragments=*/ElementsAre(7))));
-}
-
-TEST_F(GemmFusionLevel2Test, FusedConcatenationIsAnalyzedCorrectly) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-e {
-  p0 = s8[153,1536] parameter(0)
-  p1 = s8[153,128] parameter(1)
-  p2 = s8[153,256] parameter(2)
-  cat = s8[153,1920] concatenate(p0, p1, p2), dimensions={1}
-  cvt = bf16[153,1920] convert(cat)
-  p3 = bf16[16,153] parameter(3)
-  ROOT d = bf16[16,1920] dot(p3, cvt),
-    lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})"));
-  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
-                             se::CudaComputeCapability::AMPERE, 0})
-                  .Run(module.get())
-                  .value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
-                                    m::Parameter(), m::Parameter()))));
-  const HloComputation* computation =
-      module->entry_computation()->root_instruction()->called_computations()[0];
-  TF_ASSERT_OK_AND_ASSIGN(const auto analysis,
-                          TritonFusionAnalysis::Execute(*computation));
-
-  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
-                                 computation->parameter_instruction(1), 0),
-              ElementsAre(FieldsAre(/*stride=*/1536, /*count=*/153,
-                                    /*slice_start=*/0, /*sliced_count=*/153,
-                                    /*subfragments=*/ElementsAre(153))));
-  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
-                                 computation->parameter_instruction(1), 1),
-              ElementsAre(FieldsAre(/*stride=*/1, /*count=*/1536,
-                                    /*slice_start=*/0, /*sliced_count=*/1536,
-                                    /*subfragments=*/ElementsAre(1536))));
-
-  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
-                                 computation->parameter_instruction(2), 0),
-              ElementsAre(FieldsAre(/*stride=*/128, /*count=*/153,
-                                    /*slice_start=*/0, /*sliced_count=*/153,
-                                    /*subfragments=*/ElementsAre(153))));
-  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
-                                 computation->parameter_instruction(2), 1),
-              ElementsAre(FieldsAre(/*stride=*/1, /*count=*/128,
-                                    /*slice_start=*/-1536, /*sliced_count=*/128,
-                                    /*subfragments=*/ElementsAre(128))));
-
-  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
-                                 computation->parameter_instruction(3), 0),
-              ElementsAre(FieldsAre(/*stride=*/256, /*count=*/153,
-                                    /*slice_start=*/0, /*sliced_count=*/153,
-                                    /*subfragments=*/ElementsAre(153))));
-  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
-                                 computation->parameter_instruction(3), 1),
-              ElementsAre(FieldsAre(/*stride=*/1, /*count=*/256,
-                                    /*slice_start=*/-1536 - 128,
-                                    /*sliced_count=*/256,
-                                    /*subfragments=*/ElementsAre(256))));
-}
-
-TEST_F(GemmFusionLevel2Test, IndivisibleConcatenationIsNotFused) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-e {
-  p0 = s8[124,1024] parameter(0)
-  p1 = s8[124,1001] parameter(1)
-  cat = s8[124,2025] concatenate(p0, p1), dimensions={1}
-  cvt = f16[124,2025] convert(cat)
-  p2 = f16[123,124] parameter(2)
-  ROOT d = f16[2025,123] dot(cvt, p2),
-    lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})"));
-  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
-                             se::CudaComputeCapability::AMPERE, 0})
-                  .Run(module.get())
-                  .value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch((m::Fusion(m::Concatenate(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionLevel2Test, ConcatenationOfContractingIsNotFused) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-e {
-  p0 = s8[124,1024] parameter(0)
-  p1 = s8[124,1024] parameter(1)
-  cat = s8[124,2048] concatenate(p0, p1), dimensions={1}
-  cvt = f16[124,2048] convert(cat)
-  p2 = f16[123,2048] parameter(2)
-  ROOT d = f16[124,123] dot(cvt, p2),
-    lhs_contracting_dims={1}, rhs_contracting_dims={1}
-})"));
-  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
-                             se::CudaComputeCapability::AMPERE, 0})
-                  .Run(module.get())
-                  .value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch((m::Fusion(m::Concatenate(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionLevel2Test, ConcatenationOfBatchIsNotFused) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-e {
-  p0 = s8[124,1024,50] parameter(0)
-  p1 = s8[124,1024,50] parameter(1)
-  cat = s8[124,2048,50] concatenate(p0, p1), dimensions={1}
-  cvt = f16[124,2048,50] convert(cat)
-  p2 = f16[123,2048,50] parameter(2)
-  ROOT d = f16[2048,124,123] dot(cvt, p2),
-    lhs_batch_dims={1}, rhs_batch_dims={1},
-    lhs_contracting_dims={2}, rhs_contracting_dims={2}
-})"));
-  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
-                             se::CudaComputeCapability::AMPERE, 0})
-                  .Run(module.get())
-                  .value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch((m::Fusion(m::Concatenate(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionLevel2Test,
-       DifferentConcatenationOfSameParametersIsFusedViaNodeDuplication) {
-  // This means that the same input is passed to the fusion multiple times and
-  // is read differently for each use.
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-e {
-  p0 = s8[128,2] parameter(0)
-  p1 = s8[128,2] parameter(1)
-  cat0 = s8[256,2] concatenate(p0, p1), dimensions={0}
-  cvt0 = f16[256,2] convert(cat0)
-  cat1 = s8[256,2] concatenate(p1, p0), dimensions={0}
-  n1 = s8[256,2] negate(cat1)
-  cvt1 = f16[256,2] convert(n1)
-  a = f16[256,2] add(cvt1, cvt0)
-  p2 = f16[2,18] parameter(2)
-  ROOT d = f16[18,256] dot(p2, a),
-    lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})"));
-
-  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
-                             se::CudaComputeCapability::AMPERE, 0})
-                  .Run(module.get())
-                  .value());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(),
-                            m::Parameter(), m::Parameter()))));
-}
-
-TEST_F(GemmFusionTest, CopiesDotMetadataToFusionOp) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
-  p0 = f16[2,18] parameter(0)
-  p1 = f16[256,2] parameter(1)
-  ROOT d = f16[18,256] dot(p0, p1),
-    lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="foo"}
-})")
-                    .value();
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-  EXPECT_EQ(
-      module->entry_computation()->root_instruction()->metadata().op_name(),
-      "foo");
-}
-
-// A test fixture class for testing the threshold for small matrices.
-class SmallDotGemmFusionTest : public GemmFusionTest {
- public:
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GemmFusionTest::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100);
-    return debug_options;
-  }
-};
-
-TEST_F(SmallDotGemmFusionTest, SkipSmallMatrixMultiplicationRewrite) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
-  p0 = f16[2,10] parameter(0)
-  p1 = f16[10,2] parameter(1)
-  ROOT d = f16[10,10] dot(p0, p1),
-    lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})")
-                    .value();
-
-  EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value());
-
-  MatchHloModule(*module, R"(
-; CHECK-LABEL: ENTRY %e ({{.*}}: f16[2,10], {{.*}}: f16[10,2]) -> f16[10,10] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,10]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f16[10,2]{1,0} parameter(1)
-; CHECK:      ROOT {{.*}} = f16[10,10]{1,0} dot(f16[2,10]{1,0} [[P0]], f16[10,2]{1,0} [[P1]])
-})");
-}
-
-TEST_F(SmallDotGemmFusionTest, LargeMatrixMultiplicationIsRewritten) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-ENTRY e {
-  p0 = f16[2,18] parameter(0)
-  p1 = f16[50,2] parameter(1)
-  ROOT d = f16[18,50] dot(p0, p1),
-    lhs_contracting_dims={0}, rhs_contracting_dims={1}
-})")
-                    .value();
-
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-
-  MatchHloModule(*module, R"(
-; CHECK-LABEL: ENTRY %e ({{.*}}: f16[2,18], {{.*}}: f16[50,2]) -> f16[18,50] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,18]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f16[50,2]{1,0} parameter(1)
-; CHECK:      ROOT {{.*}} = f16[18,50]{1,0}
-; CHECK:        fusion(f16[2,18]{1,0} [[P0]], f16[50,2]{1,0} [[P1]]),
-; CHECK:        kind=kCustom
-; CHECK:        __triton_gemm
-})");
-}
-
-class SparseDotTest : public GemmFusionTest {};
-
-TEST_F(SparseDotTest, DotWithSparseLhsOperandIsRewritten) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test
-ENTRY main {
-  lhs = f16[2,16] parameter(0)
-  rhs = f16[32,2] parameter(1)
-  meta = u16[2,2] parameter(2)
-  ROOT dot = f32[2,2] dot(lhs, rhs, meta),
-      lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4
-})")
-                    .value();
-  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
-
-  MatchHloModule(*module, R"(
-; CHECK-LABEL: ENTRY %main ({{.*}}: f16[2,16], {{.*}}: f16[32,2], {{.*}}: u16[2,2]) -> f32[2,2] {
-; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,16]{1,0} parameter(0)
-; CHECK-NEXT: [[P1:%[^ ]+]] = f16[32,2]{1,0} parameter(1)
-; CHECK-NEXT: [[META:%[^ ]+]] = u16[2,2]{1,0} parameter(2)
-; CHECK:      ROOT {{.*}} = f32[2,2]{1,0}
-; CHECK-SAME:   fusion(f16[2,16]{1,0} [[P0]], f16[32,2]{1,0} [[P1]], u16[2,2]{1,0} [[META]]),
-; CHECK-SAME:   kind=kCustom
-; CHECK-SAME:   __triton_gemm
-})");
-}
-
-TEST_F(SparseDotTest, DotWithSparseRhsOperandIsNotSupported) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test
-ENTRY main {
-  lhs = f16[2,32] parameter(0)
-  rhs = f16[16,2] parameter(1)
-  meta = u16[2,2] parameter(2)
-  ROOT dot = f32[2,2] dot(lhs, rhs, meta),
-      lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=R.0@2:4
-})")
-                    .value();
-  auto result = GemmFusion(gpu_version_).Run(module.get());
-  EXPECT_FALSE(result.ok());
-}
-
-TEST_F(SparseDotTest, UnsupportedSparsityType) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test
-ENTRY main {
-  lhs = f16[2,8] parameter(0)
-  rhs = f16[32,2] parameter(1)
-  meta = u16[2,1] parameter(2)
-  ROOT dot = f32[2,2] dot(lhs, rhs, meta),
-      lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@1:4
-})")
-                    .value();
-  auto result = GemmFusion(gpu_version_).Run(module.get());
-  EXPECT_FALSE(result.ok());
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/gemm_rewriter.cc
deleted file mode 100644
index e3dd0cf..0000000
--- a/third_party/xla/xla/service/gpu/gemm_rewriter.cc
+++ /dev/null
@@ -1,2429 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemm_rewriter.h"
-
-#include <algorithm>
-#include <array>
-#include <cmath>
-#include <cstddef>
-#include <cstdint>
-#include <initializer_list>
-#include <limits>
-#include <memory>
-#include <optional>
-#include <string>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/evaluator/hlo_evaluator.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/literal.h"
-#include "xla/literal_util.h"
-#include "xla/primitive_util.h"
-#include "xla/service/algorithm_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/blas.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/gpu/gpu_blas_lt.h"
-#include "xla/types.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/ml_dtypes.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/protobuf/dnn.pb.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = match;
-
-// Give this instruction a more useful name than "custom-call.42".
-absl::Status SetName(HloModule *module, HloInstruction *gemm) {
-  if (IsCublasLtMatmul(*gemm)) {
-    module->SetAndUniquifyInstrName(gemm, "cublas-lt-matmul");
-    return absl::OkStatus();
-  }
-
-  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                      gemm->backend_config<GpuBackendConfig>());
-  const GemmBackendConfig &config = gpu_config.gemm_backend_config();
-  const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers();
-  bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() ||
-                      !dot_dims.rhs_batch_dimensions().empty();
-
-  module->SetAndUniquifyInstrName(
-      gemm, is_batch_dot ? "cublas-batch-gemm" : "cublas-gemm");
-  return absl::OkStatus();
-}
-
-// Returns whether a given PrimitiveType is supported by cuBLASLt Epilogue
-// Fusion. A table of supported data types can be found in the cuBLASLt
-// documentation: https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul.
-// Note that `Ctype` also describes the output type of the GEMM. Rows with
-// `Non-default epilogue not supported` entries in the last column indicate data
-// types not compatible with Epilogue Fusion.
-bool SupportsEpilogueFusion(PrimitiveType type) {
-  switch (type) {
-    case F8E4M3FN:
-    case F8E5M2:
-    case F16:
-    case BF16:
-    case F32:
-    case F64:
-      return true;
-    default:
-      return false;
-  }
-}
-
-bool IsF8Type(const HloInstruction *instr) {
-  return primitive_util::IsF8Type(instr->shape().element_type());
-}
-
-// Returns a new shape with non-batch dimensions padded to multiples of 16, as
-// required by cuBLASLt FP8 gemms.
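-//
-// A minimal worked example (illustrative, not from the original source): for
-// a shape f32[3,17] with no batch dimensions, both dimensions round up to the
-// next multiple of 16, giving f32[16,32]; if dimension 0 were a batch
-// dimension, only dimension 1 would be padded, giving f32[3,32].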
-Shape PadShapeToMultipleOf16(const Shape old_shape,
-                             const absl::Span<const int64_t> batch_dims) {
-  Shape padded_shape = old_shape;
-  for (int i = 0; i < old_shape.rank(); ++i) {
-    if (!absl::c_linear_search(batch_dims, i)) {
-      int64_t padded_dimension =
-          RoundUpTo<int64_t>(old_shape.dimensions(i), 16);
-      padded_shape.set_dimensions(i, padded_dimension);
-    }
-  }
-  return padded_shape;
-}
-
-// Pad the dimensions of the operands to the target shape.
-HloInstruction *PadOperandToTargetShape(const Shape &target,
-                                        HloInstruction *x) {
-  if (ShapeUtil::Equal(target, x->shape()) ||
-      !ShapeUtil::SameElementType(x->shape(), target)) {
-    return x;
-  }
-
-  PaddingConfig padding_config;
-  for (int i = 0; i < x->shape().rank(); ++i) {
-    auto dimension = padding_config.add_dimensions();
-    dimension->set_edge_padding_low(0);
-    dimension->set_edge_padding_high(target.dimensions(i) -
-                                     x->shape().dimensions(i));
-    dimension->set_interior_padding(0);
-  }
-
-  HloInstruction *zero = x->AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::Zero(x->shape().element_type())));
-  return x->AddInstruction(
-      HloInstruction::CreatePad(target, x, zero, padding_config));
-}
-
-// Pad the non-batch dimensions of the operands to multiples of 16 as required
-// by cuBLASLt FP8 gemms.
-HloInstruction *PadOperandToMultipleOf16(absl::Span<const int64_t> batch_dims,
-                                         HloInstruction *x) {
-  Shape padded_shape = PadShapeToMultipleOf16(x->shape(), batch_dims);
-  return PadOperandToTargetShape(padded_shape, x);
-}
-
-// Calculates the reciprocal of scalar when invert is true and converts to FP32.
-absl::StatusOr<HloInstruction *> InvertAndConvertScalar(HloInstruction *scalar,
-                                                        bool invert) {
-  DCHECK(ShapeUtil::IsScalar(scalar->shape()));
-
-  if (invert) {
-    Literal one_literal = LiteralUtil::One(scalar->shape().element_type());
-    HloInstruction *one = scalar->parent()->AddInstruction(
-        HloInstruction::CreateConstant(one_literal.Clone()));
-    TF_ASSIGN_OR_RETURN(scalar, MakeBinaryHlo(HloOpcode::kDivide, one, scalar,
-                                              &scalar->metadata()));
-  }
-  if (scalar->shape().element_type() != F32) {
-    scalar = MakeConvertToHlo(scalar, F32, &scalar->metadata());
-  }
-
-  return scalar;
-}
-
-// A path of instructions by traversing downwards through users, as (op,
-// operand_index) pairs. operand_index is the index to get to the previous
-// element in the path. I.e.,
-// path[i].first->operand(path[i].second) == path[i-1].first
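-//
-// A small sketch (illustrative, not from the original source): for the chain
-//   a = f8e4m3fn[16,16] parameter(0)
-//   b = f16[16,16] convert(a)
-//   c = f16[16,16] multiply(b, scale_broadcast)
-// the path from the FP8 instruction 'a' down to 'c' would be
-//   {{a, -1}, {b, 0}, {c, 0}}
-// because c->operand(0) == b and b->operand(0) == a.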
-using InstrPath = std::vector<std::pair<HloInstruction *, int>>;
-
-// From 'instr', recursively traverses operands until an FP8 instruction is
-// encountered. Only unary ops and a few types of non-unary ops are traversed.
-// If an FP8 instruction is found, returns the path from the FP8 instruction to
-// 'instr'. Returns nullopt when no FP8 instruction is reached.
-//
-// The intent is, given 'instr' is the operand of a dot, to find a sequence of
-// instructions that can potentially be fused into a cuBLAS LT FP8 gemm.
-std::optional<InstrPath> FindF8SubgraphRecursive(
-    HloInstruction *instr, absl::flat_hash_set<int> &visited_instrs) {
-  // Avoid visiting the same instruction more than once.
-  if (!visited_instrs.emplace(instr->unique_id()).second) {
-    return std::nullopt;
-  }
-  if (IsF8Type(instr)) {
-    // The initial operand index is meaningless. Arbitrarily use -1.
-    return InstrPath{{instr, -1}};
-  }
-  if (instr->operand_count() == 1 || instr->opcode() == HloOpcode::kDivide ||
-      instr->opcode() == HloOpcode::kDynamicSlice ||
-      instr->opcode() == HloOpcode::kPad) {
-    std::optional<InstrPath> subgraph =
-        FindF8SubgraphRecursive(instr->mutable_operand(0), visited_instrs);
-    if (subgraph) {
-      subgraph->emplace_back(std::make_pair(instr, 0));
-    }
-    return subgraph;
-  } else if (instr->opcode() == HloOpcode::kMultiply ||
-             instr->opcode() == HloOpcode::kSelect) {
-    for (int k = 0; k < 2; ++k) {
-      // Iterate over operands 0 and 1 for multiply and operands 1 and 2 for
-      // select.
-      int operand_idx = k + (instr->opcode() == HloOpcode::kSelect);
-      std::optional<InstrPath> subgraph = FindF8SubgraphRecursive(
-          instr->mutable_operand(operand_idx), visited_instrs);
-      if (subgraph) {
-        subgraph->emplace_back(std::make_pair(instr, operand_idx));
-        return subgraph;
-      }
-    }
-  }
-  return std::nullopt;
-}
-
-// Contains information on a parameter (either the LHS or RHS) for a
-// gemm that can be potentially pattern-matched into an FP8 cublasLT gemm.
-struct MatchedFp8Param {
-  // The FP8 input to the gemm.
-  HloInstruction *fp8_input = nullptr;
-  // If nonnull, the scale for 'x' (the FP8 input).
-  HloInstruction *scale = nullptr;
-  // Whether the scale, if present, multiplies or divides 'x'.
-  bool mult_scale = false;
-  // A list of instructions from x to the dot instruction commutative with
-  // dequantization. Such instructions can be moved before the FP8 gemm.
-  InstrPath commutative_ops;
-};
-
-// Given an operand of a dot, `instr`, returns a MatchedFp8Param if this operand
-// allows rewriting the dot into an FP8 cublasLT custom call, optionally with
-// scaling. In particular, returns a MatchedFp8Param if either 'instr' is FP8
-// or there is a path from an FP8 instruction 'fp8_input' to 'instr'
-// consisting of the following.
-// 1. A convert to a wider type.
-// 2. Optionally, a multiplication/division by a scalar, representing the scale.
-//    If present, the scalar scale is returned as 'scale' and 'mult_scale'
-//    is set to true or false depending on whether there is a multiplication or
-//    a division.
-// 3. A possibly-empty set of ops commutative with steps (1) and (2), meaning
-//    they can be safely moved before step (1). Such ops are returned in
-//    'commutative_ops'.
-// Steps (1) and (2) together are a dequantization, and can be fused into a
-// cublas LT matmul. Step (3) can be moved before the cublas LT matmul.
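-//
-// A minimal sketch of a matchable operand (illustrative, not from the
-// original source; names and shapes are made up):
-//   x       = f8e4m3fn[32,32] parameter(0)
-//   x_scale = f32[] parameter(1)
-//   x32     = f32[32,32] convert(x)                          // step (1)
-//   sb      = f32[32,32] broadcast(x_scale), dimensions={}
-//   xdeq    = f32[32,32] multiply(x32, sb)                   // step (2)
-//   xt      = f32[32,32] transpose(xdeq), dimensions={1,0}   // step (3)
-// Here 'fp8_input' is x, 'scale' is x_scale, 'mult_scale' is true, and
-// 'commutative_ops' contains the transpose.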
-std::optional<MatchedFp8Param> MatchFp8Param(HloInstruction *instr) {
-  absl::flat_hash_set<int> visited_instrs;
-  std::optional<InstrPath> maybe_subgraph =
-      FindF8SubgraphRecursive(instr, visited_instrs);
-  if (!maybe_subgraph) {
-    return std::nullopt;
-  }
-  InstrPath &subgraph = maybe_subgraph.value();
-
-  MatchedFp8Param param;
-
-  // Directly operating on an FP8 operand.
-  if (subgraph.size() == 1) {
-    CHECK(IsF8Type(subgraph[0].first));
-    param.fp8_input = subgraph[0].first;
-    return param;
-  }
-
-  int num_dequant_ops;
-  // When not operating directly on an FP8 operand, the second and
-  // third instructions in the subgraph can describe a dequantization, i.e. a
-  // convert instruction followed by a multiply/divide instruction.
-  if (subgraph.size() > 2 &&
-      Match(subgraph[2].first,
-            m::MultiplyAnyOrder(m::Convert(m::Op(&param.fp8_input)),
-                                m::Broadcast(m::Op(&param.scale))))) {
-    param.mult_scale = true;
-    num_dequant_ops = 2;
-  } else if (subgraph.size() > 2 &&
-             Match(subgraph[2].first,
-                   m::Divide(m::Convert(m::Op(&param.fp8_input)),
-                             m::Broadcast(m::Op(&param.scale))))) {
-    param.mult_scale = false;
-    num_dequant_ops = 2;
-  } else if (subgraph.size() > 1 &&
-             Match(subgraph[1].first, m::Convert(m::Op(&param.fp8_input)))) {
-    // We have a convert from FP8 without a scale in this case.
-    param.scale = nullptr;
-    num_dequant_ops = 1;
-  } else {
-    VLOG(1) << "Possible intended FP8 GEMM operating on "
-            << instr->ToShortString() << " not rewritten into FP8 Custom Call.";
-    return std::nullopt;
-  }
-
-  auto preserves_element_type = [](const HloInstruction *instr) -> bool {
-    return ShapeUtil::SameElementType(instr->shape(),
-                                      instr->operand(0)->shape());
-  };
-  auto use_spmd_partitioning = [](const HloInstruction *instr) -> bool {
-    return instr->GetModule()->config().use_spmd_partitioning();
-  };
-
-  // Skip the initial FP8 instruction and the dequantization instructions.
-  int start = 1 + num_dequant_ops;
-  for (int i = start; i < subgraph.size(); ++i) {
-    // The remaining instructions must be commutative with dequantization.
-    // Bitcast, broadcast, copy, dynamic-slice, pad, reshape, select, slice,
-    // transpose, all-gather, all-to-all and collective-permute instructions are
-    // supported. Specifically, the all-gather, all-to-all and
-    // collective-permute operations are permitted only in SPMD cases since the
-    // optimization cannot be guaranteed to be applied to all replicas in the
-    // MPMD scenario.
-    if (!Match(
-            subgraph[i].first,
-            m::AnyOf<HloInstruction>(
-                m::Bitcast().WithPredicate(preserves_element_type),
-                m::Broadcast(), m::Copy(), m::DynamicSlice(), m::Pad(),
-                m::Reshape(), m::Select(), m::Slice(), m::Transpose(),
-                m::AllGather().WithPredicate(use_spmd_partitioning),
-                m::AllToAll().WithPredicate(use_spmd_partitioning),
-                m::CollectivePermute().WithPredicate(use_spmd_partitioning)))) {
-      VLOG(1) << "Possible intended FP8 GEMM operating on "
-              << instr->ToShortString()
-              << " not rewritten into FP8 Custom Call.";
-      return std::nullopt;
-    }
-    // One of the operands of select must be zero for the op to be commutative
-    // with dequantization.
-    if (Match(subgraph[i].first, m::Select()) &&
-        !Match(subgraph[i].first->operand(subgraph[i].second == 2 ? 1 : 2),
-               m::Broadcast(m::ConstantScalar(0)))) {
-      VLOG(1) << "Possible intended FP8 GEMM operating on "
-              << instr->ToShortString()
-              << " not rewritten into FP8 Custom Call. Select requires a zero "
-                 "operand to be exchanged with dequantization.";
-      return std::nullopt;
-    }
-  }
-
-  param.commutative_ops = {subgraph.begin() + start, subgraph.end()};
-  return param;
-}
-
-// Transposes a matrix by swapping the contracting and non-contracting
-// dimension. There must be only one contracting and only one non-contracting
-// dimension. Keeps the layout the same.
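-//
-// A worked example (illustrative, not from the original source): with
-// batch_dims = {0} and contracting_dim = 2, the permutation for f32[4,8,16]
-// is {0, 2, 1}, producing f32[4,16,8] while keeping the original layout.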
-HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim,
-                                absl::Span<const int64_t> batch_dims) {
-  // Identify the dimensional order which describes a transpose of the
-  // contracting and non-contracting dimensions of the GEMM.
-  std::vector<int64_t> permutation(instr->shape().dimensions_size(), -1);
-  // Discard the batch dimensions.
-  for (int64_t batch_dim : batch_dims) {
-    permutation[batch_dim] = batch_dim;
-  }
-  // Identify the non-contracting dimension.
-  int non_contracting_dim;
-  for (int i = 0; i < instr->shape().dimensions_size(); ++i) {
-    if (permutation[i] == -1 && contracting_dim != i) {
-      non_contracting_dim = i;
-    }
-  }
-  permutation[non_contracting_dim] = contracting_dim;
-  permutation[contracting_dim] = non_contracting_dim;
-
-  Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape());
-  *new_shape.mutable_layout() = instr->shape().layout();
-  return instr->AddInstruction(
-      HloInstruction::CreateTranspose(new_shape, instr, permutation));
-}
-
-// If the bias is a sequence of ops that depend only on broadcasts of
-// constants, materialize the bias if it's small.
-//
-// Normally the constant-folding pass would materialize the bias if it is
-// calculated entirely from constants. But if the bias is a broadcast of a
-// constant, constant-folding won't expand the broadcast, on the theory that
-// folding broadcasts of constants causes us to consume more memory and can
-// actually make things slower (because any op which reads the constant has
-// to read more memory).
-//
-// OTOH in our case, we don't want to run an op that just broadcasts a
-// constant so we can fuse it into this gemm. That would defeat the whole
-// purpose of this fusion, which is to launch fewer kernels.  So if we can,
-// we expand out this constant ourselves.
-//
-// TODO(b/192499646): Even better would be to use cublasLT to fuse the
-// broadcasted bias, if it supports that fusion efficiently.
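-//
-// A minimal sketch (illustrative, not from the original source): a bias like
-//   c = f32[16] constant({...})
-//   b = f32[16,16] broadcast(c), dimensions={1}
-// is evaluated here into a single f32[16,16] constant of 1 KiB, well under
-// the limit below; broadcasts of scalar constants are deliberately left alone.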
-HloInstruction *MaybeConstantFoldBias(HloInstruction *bias) {
-  // This limit was not chosen carefully.
-  constexpr int kMaxMaterializeBiasBytes = 8 * 1024 * 1024;
-
-  // Don't fold broadcasts of scalars -- algsimp will just collapse it again.
-  auto is_nonscalar = [](const HloInstruction *instr) {
-    return !ShapeUtil::IsEffectiveScalar(instr->shape());
-  };
-
-  // For now, only fold broadcast(constant) or
-  // reshape/transpose/bitcast(broadcast(constant)). This lets us avoid the
-  // complexity in the constant-folding pass about what is and isn't legal to
-  // fold.
-  auto broadcast_of_nonscalar =
-      m::Broadcast(m::Constant().WithPredicate(is_nonscalar));
-
-  if (ShapeUtil::ByteSizeOf(bias->shape()) <= kMaxMaterializeBiasBytes &&
-      (Match(bias, broadcast_of_nonscalar) ||
-       Match(bias, m::Reshape(broadcast_of_nonscalar)) ||
-       Match(bias, m::Transpose(broadcast_of_nonscalar)) ||
-       Match(bias, m::Bitcast(broadcast_of_nonscalar)))) {
-    HloEvaluator evaluator(/*max_loop_iterations=*/0);
-    Literal result;
-    if (evaluator.TryEvaluate(
-            bias, &result,
-            /*recursively_evaluate_nonconstant_operands=*/true)) {
-      return bias->parent()->AddInstruction(
-          HloInstruction::CreateConstant(std::move(result)));
-    }
-  }
-
-  return bias;
-}
-
-auto Gemm(HloInstruction **instr) {
-  return m::CustomCall(instr, {kGemmCallTarget});
-}
-
-auto CublasLtMatmul(HloInstruction **instr) {
-  return m::CustomCall(instr, {kCublasLtMatmulCallTarget});
-}
-
-auto CublasLtMatmulF8(HloInstruction **instr) {
-  return m::CustomCall(instr, {kCublasLtMatmulF8CallTarget});
-}
-
-auto CublasLtMatmulMaybeF8(HloInstruction **instr) {
-  return m::CustomCall(
-      instr, {kCublasLtMatmulCallTarget, kCublasLtMatmulF8CallTarget});
-}
-
-auto GemmOrCublasLtMatmul(HloInstruction **instr) {
-  return m::CustomCall(instr, {kGemmCallTarget, kCublasLtMatmulCallTarget});
-}
-
-auto GemmOrCublasLtMatmulMaybeF8(HloInstruction **instr) {
-  return m::CustomCall(instr, {kGemmCallTarget, kCublasLtMatmulCallTarget,
-                               kCublasLtMatmulF8CallTarget});
-}
-
-auto BcastConstScalar(HloInstruction **instr, double value) {
-  return m::Broadcast(instr, m::ConstantScalar(value));
-}
-
-auto BcastConstScalar(double value) { return BcastConstScalar(nullptr, value); }
-
-auto BcastConstScalarNear(double value) {
-  return m::Broadcast(m::ConstantScalar().WithPredicate(
-      [expected = value](const HloInstruction *instr) {
-        // Not a very robust floating-point comparison, but good enough for our
-        // purposes.
-        std::optional<double> actual =
-            xla::Cast<const HloConstantInstruction>(instr)
-                ->literal()
-                .GetAsDouble({});
-        if (!actual.has_value()) return false;
-        double epsilon;
-        switch (instr->shape().element_type()) {
-          case F16:
-            epsilon = 128 * std::numeric_limits<Eigen::half>::epsilon();
-            break;
-          case BF16:
-            epsilon = 128 * std::numeric_limits<bfloat16>::epsilon();
-            break;
-          case F32:
-            epsilon = 128 * std::numeric_limits<float>::epsilon();
-            break;
-          case F64:
-            epsilon = 128 * std::numeric_limits<double>::epsilon();
-            break;
-          default:
-            return false;
-        }
-        return std::abs(*actual - expected) <
-               (std::abs(*actual + expected) * epsilon);
-      }));
-}
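-// E.g. for F32 constants the relative tolerance above is 128 * 2^-23, i.e.
-// roughly 1.5e-5 of |actual + expected|.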
-
-template <typename Pattern>
-auto OptionalSlice(HloInstruction **optional_slice, Pattern pattern) {
-  return m::AnyOf<HloInstruction>(m::Slice(optional_slice, pattern),
-                                  std::move(pattern));
-}
-
-template <typename Pattern>
-auto OptionalConvert(HloInstruction **optional_convert, Pattern pattern) {
-  return m::AnyOf<HloInstruction>(m::Convert(optional_convert, pattern),
-                                  std::move(pattern));
-}
-
-template <typename Pattern>
-auto OptionalBitcast(HloInstruction **optional_bitcast, Pattern pattern) {
-  return m::AnyOf<HloInstruction>(m::Bitcast(optional_bitcast, pattern),
-                                  std::move(pattern));
-}
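-// For example (illustrative), OptionalSlice(&slice, Gemm(&gemm)) matches both
-//   slice(custom-call(...))  // binding both `slice` and `gemm`
-// and a bare custom-call(...), in which case `slice` keeps its initial value
-// (typically nullptr).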
-
-// The rewriting proceeds in a bottom-up way:
-//
-// (kDot A B) is rewritten into a (kCustomCall:gemm A B)
-//
-// (kMultiply (kCustomCall:gemm A B) C) folds C (provided it is a constant
-// scalar) into the alpha parameter of the custom call.
-//
-// (kAdd (kCustomCall:gemm A B) C) is rewritten into (kCustomCall:gemm A B C),
-// where the "beta" parameter is set to 1 (provided it was zero before,
-// and provided C has no other users).
-// We then guide the buffer assignment to alias the buffer of the custom call
-// and C.
-//
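-// For example (a sketch with made-up shapes):
-//   d = f32[16,32] dot(a, b), lhs_contracting_dims={1},
-//       rhs_contracting_dims={0}
-// becomes
-//   d = f32[16,32] custom-call(a, b), custom_call_target=kGemmCallTarget (or
-//       a cublasLt target), with alpha = 1 and beta = 0 in its backend config.
-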
-// For scaled FP8 GEMMs on Hopper systems, the following steps mentioned in
-// RFC #22 (https://github.com/openxla/xla/discussions/22) are elided and
-// rewritten into a Custom Call:
-//
-// 1. Cast each input from FP8 to a wider type such as FP16 or FP32.
-// 2. Unscale each input by multiplying each input by the corresponding input
-// scale.
-// 3. Evaluate the matrix multiplication on the scaled inputs.
-// 4. Compute the maximum of the absolute values in the result of the GEMM
-// (DAmax).
-// 5. Scale the output by dividing the output by the output scale.
-// 6. Cast the output back to FP8. Since saturation should be done on
-// overflow, this is represented by a Clamp instruction followed by a Convert
-// instruction.
-
-// Steps 1 through 3 can be elided independently of the remainder. Steps 5 and
-// 6 are elided only if steps 1 through 3 were successfully transformed. Step
-// 4 requires steps 5 and 6, i.e. the computation of DAmax can be elided only
-// when the output of the GEMM is requested in FP8 format.
-class GemmRewriterVisitor : public DfsHloRewriteVisitor {
- public:
-  explicit GemmRewriterVisitor(const se::GpuComputeCapability &gpu_version,
-                               const int32_t toolkit_version,
-                               const bool f8_rewrite)
-      : gpu_version_(gpu_version),
-        toolkit_version_(toolkit_version),
-        f8_rewrite_(f8_rewrite) {}
-
-  absl::Status HandleDot(HloInstruction *instr) override {
-    if (!IsMatrixMultiplication(*instr) &&
-        !IsMatrixVectorMultiplication(*instr)) {
-      return absl::OkStatus();
-    }
-    // Sparse dot is not supported.
-    if (Cast<HloDotInstruction>(instr)->sparse_operands()) {
-      return absl::OkStatus();
-    }
-
-    int64_t gemm_rewrite_size_threshold =
-        instr->GetModule()
-            ->config()
-            .debug_options()
-            .xla_gpu_gemm_rewrite_size_threshold();
-    TF_ASSIGN_OR_RETURN(bool is_matmul_tiny,
-                        IsMatrixMultiplicationTooSmallForRewriting(
-                            *instr, gemm_rewrite_size_threshold));
-    if (is_matmul_tiny && IsDotSupportedByClassicalEmitters(*instr)) {
-      return absl::OkStatus();
-    }
-
-    CHECK(!instr->IsRank2Transpose());
-    if (instr->operand(0)->IsRank2Transpose() ||
-        instr->operand(1)->IsRank2Transpose()) {
-      return absl::OkStatus();
-    }
-    // Create a GemmBackendConfig based on the instruction.
-    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
-                        instr->backend_config<GpuBackendConfig>());
-    GemmBackendConfig &gemm_backend_config =
-        *gpu_backend_config.mutable_gemm_backend_config();
-    gemm_backend_config.set_alpha_real(1.0);
-    gemm_backend_config.set_alpha_imag(0.0);
-    gemm_backend_config.set_beta(0.0);
-    *gemm_backend_config.mutable_dot_dimension_numbers() =
-        instr->dot_dimension_numbers();
-    *gemm_backend_config.mutable_precision_config() = instr->precision_config();
-
-    HloInstruction *lhs = instr->mutable_operand(0);
-    HloInstruction *rhs = instr->mutable_operand(1);
-    auto attributes = instr->frontend_attributes().map();
-    gemm_backend_config.set_grad_x(attributes["grad_x"] == "true");
-    gemm_backend_config.set_grad_y(attributes["grad_y"] == "true");
-
-    int64_t lhs_batch_dims_size =
-        instr->dot_dimension_numbers().lhs_batch_dimensions_size();
-    bool is_lhs_vector =
-        lhs->shape().dimensions_size() == lhs_batch_dims_size + 1;
-    bool is_rhs_vector =
-        rhs->shape().dimensions_size() == lhs_batch_dims_size + 1;
-    int64_t lhs_stride =
-        is_lhs_vector ? lhs->shape().dimensions(lhs_batch_dims_size)
-                      : lhs->shape().dimensions(lhs_batch_dims_size) *
-                            lhs->shape().dimensions(lhs_batch_dims_size + 1);
-    int64_t rhs_stride =
-        is_rhs_vector ? rhs->shape().dimensions(lhs_batch_dims_size)
-                      : rhs->shape().dimensions(lhs_batch_dims_size) *
-                            rhs->shape().dimensions(lhs_batch_dims_size + 1);
-
-    gemm_backend_config.set_lhs_stride(lhs_stride);
-    gemm_backend_config.set_rhs_stride(rhs_stride);
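-    // Example (illustrative): for lhs = f32[8,32,64] and rhs = f32[8,64,16]
-    // with one batch dimension, lhs_stride = 32 * 64 = 2048 and
-    // rhs_stride = 64 * 16 = 1024.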
-
-    if (f8_rewrite_) {
-      // Rewrite FP8 GEMMs into a type-specific cublasLT Custom Call.
-      TF_ASSIGN_OR_RETURN(
-          bool supported_by_cublaslt,
-          GemmIsSupportedByCublasLt(*instr, gemm_backend_config));
-      std::optional<MatchedFp8Param> a, b;
-      if (supported_by_cublaslt && instr->opcode() == HloOpcode::kDot &&
-          (a = MatchFp8Param(
-               const_cast<HloInstruction *>(instr->operand(0)))) &&
-          (b = MatchFp8Param(
-               const_cast<HloInstruction *>(instr->operand(1))))) {
-        if (IsRocm(gpu_version_) && toolkit_version_ < 60200 &&
-            instr->shape().element_type() != F16 &&
-            instr->shape().element_type() != F32) {
-          TF_ASSIGN_OR_RETURN(instr,
-                              TurnF8DotWithUnsupportedOutputTypeIntoF32(instr));
-        }
-        TF_ASSIGN_OR_RETURN(bool created_call,
-                            CreateF8CustomCall(instr, gpu_backend_config,
-                                               a.value(), b.value()));
-        if (created_call) {
-          return absl::OkStatus();
-        }
-      }
-      if (IsF8Type(instr->operand(0))) {
-        // FP8 rewriter couldn't rewrite dot with FP8 inputs into cublasLt
-        // custom call, so turn into an FP16 dot which may be rewritten as an
-        // FP16 Triton, cublas or cublasLt call.
-        TF_ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr));
-      }
-    } else {
-      // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call.
-      TF_ASSIGN_OR_RETURN(
-          absl::string_view gemm_custom_call_target,
-          GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config));
-      const Shape &output_shape = instr->shape();
-      HloInstruction *gemm_call =
-          instr->AddInstruction(HloInstruction::CreateCustomCall(
-              output_shape,
-              {instr->mutable_operand(0), instr->mutable_operand(1)},
-              gemm_custom_call_target));
-      TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config));
-      TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call));
-    }
-    return absl::OkStatus();
-  }
-
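-  // Folds a broadcast constant-scalar multiplier into the alpha of an
-  // existing GEMM custom call, e.g. (illustrative) multiplying a gemm that
-  // has alpha = 1 by broadcast(constant(0.5)) yields alpha = 0.5, provided
-  // the GEMM is not S32, beta == 0, and the GEMM has a single user. Also
-  // handles FP8 output scaling and the approximate-GELU pattern below.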
-  absl::Status HandleMultiply(HloInstruction *instr) override {
-    HloInstruction *alpha, *existing_gemm;
-    if (Match(instr,
-              m::MultiplyAnyOrder(
-                  GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(),
-                  m::Broadcast(m::ConstantScalar(&alpha)).WithOneUser()))) {
-      TF_ASSIGN_OR_RETURN(auto gpu_config,
-                          existing_gemm->backend_config<GpuBackendConfig>());
-      GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
-      // Do not fuse alpha into an S32 GEMM, as cuBLAS only supports fixed
-      // values for alpha/beta with this data type.
-      if (existing_gemm->shape().element_type() == S32) {
-        return absl::OkStatus();
-      }
-
-      if (config.beta() == 0.0 && existing_gemm->user_count() == 1) {
-        complex128 prev_alpha = {config.alpha_real(), config.alpha_imag()};
-        complex128 new_alpha =
-            *alpha->literal().GetAsComplex128({}) * prev_alpha;
-        config.set_alpha_real(new_alpha.real());
-        config.set_alpha_imag(new_alpha.imag());
-        TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config));
-        return ReplaceInstruction(instr, existing_gemm);
-      }
-    }
-
-    HloInstruction *d_scale;
-    if (Match(instr, m::MultiplyAnyOrder(
-                         CublasLtMatmulF8(&existing_gemm).WithOneUser(),
-                         m::Broadcast(m::Op(&d_scale)).WithOneUser()))) {
-      return F8ScaleD(instr, existing_gemm, d_scale);
-    }
-
-    // Attempt to match approximate GELU activation
-    // (https://arxiv.org/abs/1606.08415), where:
-    // approx_gelu(x) = x * cdf(x)
-    // cdf(x) = 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
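-    // In HLO the cdf is matched (with a tolerance on the constants) roughly
-    // as:
-    //   multiply(0.5, add(1, tanh(multiply(0.79788456,
-    //       add(x, multiply(0.044715, multiply(x, multiply(x, x))))))))
-    // where x is the (possibly sliced or bitcast) GEMM result.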
-    HloInstruction *cdf, *slice_or_bitcast = nullptr;
-    if (Match(instr, m::MultiplyAnyOrder(
-                         m::AnyOf<HloInstruction>(
-                             m::Slice(&slice_or_bitcast,
-                                      CublasLtMatmulMaybeF8(&existing_gemm)),
-                             m::Bitcast(&slice_or_bitcast,
-                                        CublasLtMatmulMaybeF8(&existing_gemm)),
-                             CublasLtMatmulMaybeF8(&existing_gemm)),
-                         m::Op(&cdf).WithOneUser())) &&
-        Match(cdf,
-              m::MultiplyAnyOrder(
-                  BcastConstScalar(0.5),
-                  m::AddAnyOrder(
-                      BcastConstScalar(1.0),
-                      m::Tanh(
-                          m::MultiplyAnyOrder(
-                              BcastConstScalarNear(sqrt(M_2_PI)),
-                              m::AddAnyOrder(
-                                  m::Op().Is(slice_or_bitcast ? slice_or_bitcast
-                                                              : existing_gemm),
-                                  m::MultiplyAnyOrder(
-                                      BcastConstScalarNear(0.044715),
-                                      m::MultiplyAnyOrder(
-                                          m::Op().Is(slice_or_bitcast
-                                                         ? slice_or_bitcast
-                                                         : existing_gemm),
-                                          m::MultiplyAnyOrder(
-                                              m::Op().Is(slice_or_bitcast
-                                                             ? slice_or_bitcast
-                                                             : existing_gemm),
-                                              m::Op().Is(slice_or_bitcast
-                                                             ? slice_or_bitcast
-                                                             : existing_gemm))
-                                              .WithOneUser())
-                                          .WithOneUser())
-                                      .WithOneUser())
-                                  .WithOneUser())
-                              .WithOneUser())
-                          .WithOneUser())))) {
-      return FuseGeluActivation(instr, existing_gemm, slice_or_bitcast);
-    }
-    return absl::OkStatus();
-  }
-
-  // Fuse the scaling of an FP8 GEMM into the Custom Call.
-  absl::Status HandleDivide(HloInstruction *instr) override {
-    HloInstruction *existing_gemm, *d_scale;
-    if (Match(instr, m::Divide(CublasLtMatmulF8(&existing_gemm).WithOneUser(),
-                               m::Broadcast(m::Op(&d_scale)).WithOneUser()))) {
-      return F8ScaleD(instr, existing_gemm, d_scale);
-    }
-    return absl::OkStatus();
-  }
-
-  absl::Status HandleAdd(HloInstruction *instr) override {
-    HloInstruction *bias, *existing_gemm = nullptr;
-    HloInstruction *optional_slice = nullptr;
-    HloInstruction *optional_convert = nullptr;
-    HloInstruction *optional_bitcast = nullptr;
-    // Attempt to elide broadcast and fuse addition of a vector bias into
-    // GEMM, including when slicing is applied to the result.
-    if (Match(instr,
-              m::AddAnyOrder(
-                  OptionalBitcast(
-                      &optional_bitcast,
-                      OptionalSlice(
-                          &optional_slice,
-                          CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())
-                          .WithOneUser())
-                      .WithOneUser(),
-                  m::Broadcast(&bias,
-                               OptionalConvert(&optional_convert, m::Op()))))) {
-      TF_ASSIGN_OR_RETURN(
-          bool was_fused,
-          FuseVectorBiasAdd(instr, bias, existing_gemm, optional_slice,
-                            optional_convert, optional_bitcast));
-
-      if (was_fused) {
-        return absl::OkStatus();
-      }
-    }
-    // Attempt to elide broadcast and fuse addition of a vector bias into
-    // *batched* GEMM as a matrix bias addition using FuseMatrixBiasAdd.
-    // add(bitcast(gemm(a, b)), broadcast(bias)) ->
-    //   bitcast(add(gemm(a, b), bitcast(broadcast(bias)))) ->
-    //   bitcast(gemm(a, b, bitcast(broadcast(bias)))) (FuseMatrixBiasAdd)
-    //
-    if (Match(
-            instr,
-            m::AddAnyOrder(
-                m::Bitcast(CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())
-                    .WithOneUser(),
-                m::Broadcast(&bias, m::Op()).WithOneUser()))) {
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * new_add,
-          MakeBinaryHlo(HloOpcode::kAdd, existing_gemm,
-                        MakeBitcastHlo(bias, existing_gemm->shape())));
-      TF_RETURN_IF_ERROR(
-          ReplaceInstruction(instr, MakeBitcastHlo(new_add, instr->shape())));
-
-      // Continue below.
-      instr = new_add;
-    }
-
-    // Do not fuse broadcast unless we can fuse its input, as it will cause
-    // broadcast materialization.
-    auto is_not_broadcast = [](const HloInstruction *instr) {
-      return instr->opcode() != HloOpcode::kBroadcast;
-    };
-
-    // add(bitcast(gemm(a, b)), bias) ->
-    //   bitcast(add(gemm(a, b), bitcast(bias))) ->
-    //   bitcast(gemm(a, b, bitcast(bias))) (later down in this function).
-    //
-    // We see this idiom in models that contain batch-dots, where we cast
-    // between a rank-2 shape for non-batch dots and a higher-rank shape for
-    // batch-dots.
-    //
-    // The last stage of the transform may fail (because of any of the checks in
-    // FuseMatrixBiasAdd), but if so that's okay -- we'll have done a useless
-    // transformation, but it doesn't hurt anything.
-    if (Match(instr,
-              m::AddAnyOrder(
-                  m::Bitcast(
-                      GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())
-                      .WithOneUser(),
-                  m::Op(&bias).WithPredicate(is_not_broadcast)))) {
-      HloInstruction *new_bitcast =
-          MakeBitcastHlo(bias, existing_gemm->shape(), &bias->metadata());
-      TF_ASSIGN_OR_RETURN(HloInstruction * new_add,
-                          MakeBinaryHlo(HloOpcode::kAdd, existing_gemm,
-                                        new_bitcast, &bias->metadata()));
-      TF_RETURN_IF_ERROR(
-          ReplaceInstruction(instr, MakeBitcastHlo(new_add, instr->shape())));
-
-      // Continue below transforming new_add.
-      instr = new_add;
-    }
-
-    // Attempt to fuse matrix bias into gemm with optional convert
-    // add(convert(gemm(a, b)), c) -> gemm(a, b, c)
-    // add(gemm(a, b), c) -> gemm(a, b, c)
-    if (Match(instr,
-              m::AddAnyOrder(
-                  m::AnyOf<HloInstruction>(
-                      GemmOrCublasLtMatmul(&existing_gemm).WithOneUser(),
-                      m::Convert(
-                          GemmOrCublasLtMatmul(&existing_gemm).WithOneUser())
-                          .WithOneUser()),
-                  m::Op(&bias).WithPredicate(is_not_broadcast)))) {
-      TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
-                          existing_gemm->backend_config<GpuBackendConfig>());
-      const GemmBackendConfig &gemm_backend_config =
-          gpu_backend_config.gemm_backend_config();
-      // Check whether this type combination is supported here.
-      TF_ASSIGN_OR_RETURN(
-          bool types_are_supported,
-          IsLegacyCublasMatmul(*existing_gemm)
-              ? TypesAreSupportedByLegacyCublas(*existing_gemm,
-                                                gemm_backend_config, instr)
-              : TypesAreSupportedByCublasLt(*existing_gemm, gemm_backend_config,
-                                            instr));
-
-      // For mixed-type GEMMs, only fuse the add if it has no consumers, i.e.
-      // it is the ROOT instruction or feeds only the ROOT tuple:
-      //   ROOT add
-      //   ROOT tuple(add)
-      bool has_no_consumer =
-          instr->shape().element_type() ==
-              existing_gemm->shape().element_type() ||
-          instr->user_count() == 0 ||
-          (instr->user_count() == 1 &&
-           instr->users()[0]->opcode() == HloOpcode::kTuple &&
-           instr->users()[0]->user_count() == 0);
-
-      if (types_are_supported && has_no_consumer) {
-        return FuseMatrixBiasAdd(instr, bias, existing_gemm);
-      }
-    }
-
-    HloInstruction *optional_bitcast_matrix = nullptr;
-    HloInstruction *optional_slice_matrix = nullptr;
-    if (Match(instr,
-              m::AddAnyOrder(
-                  OptionalBitcast(
-                      &optional_bitcast_matrix,
-                      OptionalSlice(&optional_slice_matrix,
-                                    GemmOrCublasLtMatmulMaybeF8(&existing_gemm)
-                                        .WithOneUser()))
-                      .WithOneUser(),
-                  m::Op(&bias).WithPredicate(is_not_broadcast)))) {
-      // The matrix bias must not be FP8, see
-      // https://docs.nvidia.com/cuda/cublas/index.html.
-      if (!IsF8Type(bias)) {
-        return FuseMatrixBiasAdd(instr, bias, existing_gemm,
-                                 optional_bitcast_matrix,
-                                 optional_slice_matrix);
-      }
-    }
-
-    return absl::OkStatus();
-  }
-
-  absl::Status HandleMaximum(HloInstruction *instr) override {
-    HloInstruction *existing_gemm, *zeros;
-    HloInstruction *optional_slice_or_bitcast = nullptr;
-    // Attempt to elide maximum and fuse ReLU activation into GEMM, including
-    // when slicing or bitcasting is applied to the result.
-    if (Match(instr,
-              m::MaximumAnyOrder(
-                  m::AnyOf<HloInstruction>(
-                      m::Slice(
-                          &optional_slice_or_bitcast,
-                          CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()),
-                      m::Bitcast(
-                          &optional_slice_or_bitcast,
-                          CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()),
-                      CublasLtMatmulMaybeF8(&existing_gemm))
-                      .WithOneUser(),
-                  m::Broadcast(&zeros, m::ConstantScalar(0))))) {
-      TF_RETURN_IF_ERROR(FuseReluActivation(instr, zeros, existing_gemm,
-                                            optional_slice_or_bitcast));
-    }
-    return absl::OkStatus();
-  }
-
-  absl::Status HandleConvert(HloInstruction *instr) override {
-    HloInstruction *clamp_lower, *clamp_upper, *existing_gemm,
-        *d_scale = nullptr, *binary = nullptr;
-    // Attempt to elide the scaling and conversion of the result of an FP8
-    // GEMM, including the optional calculation of the maximum of the absolute
-    // values before scaling, and adapt the Custom Call.
-    if (Match(instr,
-              m::Convert(
-                  m::Clamp(
-                      m::Broadcast(m::ConstantScalar(&clamp_lower)),
-                      m::AnyOf<HloInstruction>(
-                          CublasLtMatmulF8(&existing_gemm),
-                          m::Divide(&binary, CublasLtMatmulF8(&existing_gemm),
-                                    m::Broadcast(m::Op(&d_scale))),
-                          m::MultiplyAnyOrder(&binary,
-                                              CublasLtMatmulF8(&existing_gemm),
-                                              m::Broadcast(m::Op(&d_scale)))),
-                      m::Broadcast(m::ConstantScalar(&clamp_upper)))
-                      .WithOneUser()))) {
-      return F8ConvertD(
-          instr, existing_gemm, d_scale, clamp_lower, clamp_upper,
-          /*mult_scale=*/(binary && binary->opcode() == HloOpcode::kMultiply));
-    }
-    return absl::OkStatus();
-  }
-
-  static bool IsCuda(const se::GpuComputeCapability &gpu_version) {
-    return std::holds_alternative<se::CudaComputeCapability>(gpu_version);
-  }
-
-  static absl::StatusOr<se::CudaComputeCapability> GetCudaComputeCapability(
-      const se::GpuComputeCapability &gpu_version) {
-    auto *cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version);
-    if (cuda_cc == nullptr) {
-      return absl::InvalidArgumentError("Compute Capability is not CUDA.");
-    }
-    return *cuda_cc;
-  }
-
-  static bool IsRocm(const se::GpuComputeCapability &gpu_version) {
-    return std::holds_alternative<se::RocmComputeCapability>(gpu_version);
-  }
-
-  static absl::StatusOr<se::RocmComputeCapability> GetRocmComputeCapability(
-      const se::GpuComputeCapability &gpu_version) {
-    auto rocm_cc = std::get_if<se::RocmComputeCapability>(&gpu_version);
-    if (rocm_cc == nullptr) {
-      return absl::InvalidArgumentError("Compute Capability is not ROCm.");
-    }
-    return *rocm_cc;
-  }
-
-  absl::StatusOr<bool> CreateF8CustomCall(HloInstruction *instr,
-                                          GpuBackendConfig &gpu_backend_config,
-                                          MatchedFp8Param a,
-                                          MatchedFp8Param b) {
-    GemmBackendConfig &gemm_backend_config =
-        *gpu_backend_config.mutable_gemm_backend_config();
-    if (IsCuda(gpu_version_)) {
-      TF_ASSIGN_OR_RETURN(auto cuda_compute_capability,
-                          GetCudaComputeCapability(gpu_version_));
-      // FP8 GEMM kernels are only available on Ada, Hopper, and later
-      // architectures.
-      if (!cuda_compute_capability.IsAtLeast(8, 9)) {
-        VLOG(1) << "FP8 Custom Calls require Ada, Hopper, or later "
-                   "architectures. Got: "
-                << cuda_compute_capability.ToString()
-                << " and toolkit version: " << toolkit_version_;
-        return false;
-      }
-      // FP8 GEMM kernels are only available with CUDA 12.0 and above
-      if (toolkit_version_ < 12000) {
-        VLOG(1) << "FP8 Custom Calls require CUDA 12.0 or newer.";
-        return false;
-      }
-    }
-
-    if (IsRocm(gpu_version_)) {
-      TF_ASSIGN_OR_RETURN(auto rocm_compute_capability,
-                          GetRocmComputeCapability(gpu_version_));
-      if (!rocm_compute_capability.has_fp8_support()) {
-        VLOG(1) << "FP8 Custom Calls require MI300, or later architectures.";
-        return false;
-      }
-      if (toolkit_version_ < 60000) {
-        // FP8 GEMM kernels are only available with ROCm 6.0 and above
-        VLOG(1) << "FP8 Custom Calls require ROCm 6.0 or newer.";
-        return false;
-      }
-    }
-
-    PrimitiveType a_type = a.fp8_input->shape().element_type();
-    PrimitiveType b_type = b.fp8_input->shape().element_type();
-
-    // cuBLASLt FP8 GEMM kernels require one of the two operands to be in
-    // F8E4M3FN format.
-    if (IsCuda(gpu_version_)) {
-      if (a_type == F8E5M2 && b_type == F8E5M2) {
-        VLOG(1)
-            << "Failed to rewrite " << instr->ToShortString()
-            << " into FP8 Custom Call. The element type of one of the operands "
-               "must be F8E4M3FN.";
-        return false;
-      }
-      if ((a_type != F8E5M2 && a_type != F8E4M3FN) ||
-          (b_type != F8E5M2 && b_type != F8E4M3FN)) {
-        VLOG(1) << "Failed to rewrite " << instr->ToShortString()
-                << " into FP8 Custom Call. The input types must be F8E5M2 or "
-                   "F8E4M3FN, but got "
-                << PrimitiveType_Name(a_type) << " and "
-                << PrimitiveType_Name(b_type);
-        return false;
-      }
-    }
-
-    if (IsRocm(gpu_version_)) {
-      if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
-        VLOG(1)
-            << "Failed to rewrite " << instr->ToShortString()
-            << " into FP8 Custom Call. The element type of one of the operands "
-               "must be F8E4M3FNUZ.";
-        return false;
-      }
-      if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
-          (b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
-        VLOG(1)
-            << "Failed to rewrite " << instr->ToShortString()
-            << " into FP8 Custom Call. The input types must be F8E5M2FNUZ or "
-               "F8E4M3FNUZ, but got "
-            << PrimitiveType_Name(a_type) << " and "
-            << PrimitiveType_Name(b_type);
-        return false;
-      }
-    }
-
-    absl::Span<const int64_t> batch_dims =
-        gemm_backend_config.dot_dimension_numbers().rhs_batch_dimensions();
-
-    // cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32
-    // format. Set the factors to one when no scaling factors were captured.
-    Literal one_literal = LiteralUtil::One(F32);
-    HloInstruction *one = instr->AddInstruction(
-        HloInstruction::CreateConstant(one_literal.Clone()));
-    std::array<bool, 2> mult_scale{a.mult_scale, b.mult_scale};
-    std::array<HloInstruction *, 2> scales{a.scale, b.scale}, inv_scales,
-        scales_f32;
-    for (int i = 0; i < scales.size(); ++i) {
-      if (scales[i]) {
-        if (!ShapeUtil::IsScalar(scales[i]->shape())) {
-          VLOG(1) << "Failed to rewrite " << instr->ToShortString()
-                  << " into FP8 Custom Call. The scaling factors must be "
-                     "scalars.";
-          return false;
-        }
-        if (!mult_scale[i]) {
-          inv_scales[i] = instr->AddInstruction(HloInstruction::CreateBinary(
-              scales[i]->shape(), HloOpcode::kDivide, one, scales[i]));
-        }
-        scales_f32[i] = mult_scale[i] ? scales[i] : inv_scales[i];
-        if (scales_f32[i]->shape().element_type() != F32) {
-          scales_f32[i] = instr->AddInstruction(HloInstruction::CreateConvert(
-              ShapeUtil::MakeScalarShape(F32), scales_f32[i]));
-        }
-      } else {
-        scales_f32[i] = one;
-      }
-    }
-
-    PrimitiveType d_type = instr->shape().element_type();
-    bool supported_d_type = (d_type == BF16 || d_type == F16 || d_type == F32);
-    if (IsCuda(gpu_version_) && (d_type == F8E4M3FN || d_type == F8E5M2)) {
-      supported_d_type = true;
-    }
-    if (IsRocm(gpu_version_) && toolkit_version_ >= 60200 &&
-        (d_type == F8E4M3FNUZ || d_type == F8E5M2FNUZ)) {
-      supported_d_type = true;
-    }
-    if (!supported_d_type) {
-      VLOG(1) << "Failed to rewrite " << instr->ToShortString()
-              << " into FP8 Custom Call. Output element type must be "
-              << (IsCuda(gpu_version_) ? "F8E4M3FN, F8E5M2, BF16, F16 or F32. "
-                  : toolkit_version_ >= 60200
-                      ? "F8E4M3FNUZ, F8E5M2FNUZ, BF16, F16 or F32. "
-                      : "BF16, F16 or F32. ")
-              << "Actual element type is " << PrimitiveType_Name(d_type);
-      return false;
-    }
-
-    // Each operand must have exactly one contracting and one non-contracting
-    // dimension.
-    absl::Span<const int64_t> a_contracting_dims =
-        gemm_backend_config.dot_dimension_numbers()
-            .lhs_contracting_dimensions();
-    absl::Span<const int64_t> b_contracting_dims =
-        gemm_backend_config.dot_dimension_numbers()
-            .rhs_contracting_dimensions();
-    if (a_contracting_dims.size() != 1 || b_contracting_dims.size() != 1) {
-      VLOG(1) << "Failed to rewrite " << instr->ToShortString()
-              << " into FP8 Custom Call. A and B must have one contracting "
-                 "dimension.";
-      return false;
-    }
-    if ((a.commutative_ops.empty() ? a.fp8_input
-                                   : a.commutative_ops.back().first)
-                    ->shape()
-                    .dimensions_size() -
-                batch_dims.size() !=
-            2 ||
-        (b.commutative_ops.empty() ? b.fp8_input
-                                   : b.commutative_ops.back().first)
-                    ->shape()
-                    .dimensions_size() -
-                batch_dims.size() !=
-            2) {
-      VLOG(1) << "Failed to rewrite " << instr->ToShortString()
-              << "into FP8 Custom Call. A and B must have one non-contracting "
-                 "dimension.";
-      return false;
-    }
-
-    // Sequentially apply the collected unary, dynamic-slice, pad and select ops
-    // to the unconverted and unscaled operands.
-    auto shift_ops = [&instr](HloInstruction *&x, InstrPath &x_ops) -> void {
-      for (std::pair<HloInstruction *, int> op : x_ops) {
-        std::vector<HloInstruction *> operands = {x};
-        // Insert the additional operands of dynamic-slice ops.
-        if (op.first->opcode() == HloOpcode::kDynamicSlice) {
-          for (int i = 1; i < op.first->operand_count(); ++i) {
-            operands.emplace_back(op.first->mutable_operand(i));
-          }
-        }
-        // Convert the second operand of pad ops.
-        if (op.first->opcode() == HloOpcode::kPad) {
-          HloInstruction *convert =
-              instr->AddInstruction(HloInstruction::CreateConvert(
-                  ShapeUtil::ChangeElementType(op.first->operand(1)->shape(),
-                                               x->shape().element_type()),
-                  op.first->mutable_operand(1)));
-          operands.emplace_back(convert);
-        }
-        // Convert and insert the additional operands of select ops.
-        if (op.first->opcode() == HloOpcode::kSelect) {
-          // The first operand is the predicate.
-          operands.emplace(operands.begin(), op.first->mutable_operand(0));
-          // Convert the remaining operand.
-          int operand_idx = op.second == 2 ? 1 : 2;
-          HloInstruction *convert =
-              instr->AddInstruction(HloInstruction::CreateConvert(
-                  ShapeUtil::ChangeElementType(
-                      op.first->operand(operand_idx)->shape(),
-                      x->shape().element_type()),
-                  op.first->mutable_operand(operand_idx)));
-          operands.emplace(operands.begin() + operand_idx, convert);
-        }
-        x = instr->AddInstruction(op.first->CloneWithNewOperands(
-            ShapeUtil::MakeShapeWithDenseLayout(
-                x->shape().element_type(), op.first->shape().dimensions(),
-                op.first->shape().layout().minor_to_major()),
-            operands));
-      }
-      return;
-    };
-    shift_ops(a.fp8_input, a.commutative_ops);
-    shift_ops(b.fp8_input, b.commutative_ops);
-
-    TF_ASSIGN_OR_RETURN(bool a_is_col_major,
-                        MatrixIsColumnMajor(*instr, gemm_backend_config, "a"));
-    TF_ASSIGN_OR_RETURN(bool b_is_col_major,
-                        MatrixIsColumnMajor(*instr, gemm_backend_config, "b"));
-
-    DotDimensionNumbers *dim_nums =
-        gemm_backend_config.mutable_dot_dimension_numbers();
-    int batch_dim_offset = batch_dims.size();
-
-    // cuBLASLt FP8 GEMM kernels currently require the first operand, i.e. A, to
-    // be row-major. If A is column-major, swap the contracting and
-    // non-contracting dimension and transpose the matrix to effectively make it
-    // column-major.
-    // TODO(philipphack): Remove once cuBLASLt supports A being column-major
-    if (a_is_col_major) {
-      CHECK(a_contracting_dims[0] == batch_dim_offset ||
-            a_contracting_dims[0] == batch_dim_offset + 1);
-      if (a_contracting_dims[0] == batch_dim_offset) {
-        dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset + 1);
-      } else {
-        dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset);
-      }
-      a.fp8_input =
-          TransposeMatrix(a.fp8_input, a_contracting_dims[0], batch_dims);
-    }
-
-    // Similarly, cuBLASLt requires the second operand to be column-major, so
-    // make it column-major if it is currently row-major.
-    if (!b_is_col_major) {
-      CHECK(b_contracting_dims[0] == batch_dim_offset ||
-            b_contracting_dims[0] == batch_dim_offset + 1);
-      if (b_contracting_dims[0] == batch_dim_offset) {
-        dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset + 1);
-      } else {
-        dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset);
-      }
-      b.fp8_input =
-          TransposeMatrix(b.fp8_input, b_contracting_dims[0], batch_dims);
-    }
-
-    a.fp8_input = PadOperandToMultipleOf16(batch_dims, a.fp8_input);
-    b.fp8_input = PadOperandToMultipleOf16(batch_dims, b.fp8_input);
-    Shape new_output_shape = PadShapeToMultipleOf16(instr->shape(), batch_dims);
-
-    std::vector<HloInstruction *> operands_list = {
-        a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1], one, one};
-
-    HloInstruction *new_custom_call =
-        instr->AddInstruction(HloInstruction::CreateCustomCall(
-            ShapeUtil::MakeShapeWithDenseLayout(
-                instr->shape().element_type(), new_output_shape.dimensions(),
-                instr->shape().layout().minor_to_major()),
-            operands_list, kCublasLtMatmulF8CallTarget));
-    TF_RETURN_IF_ERROR(new_custom_call->set_backend_config(gpu_backend_config));
-    TF_RETURN_IF_ERROR(SetName(instr->GetModule(), new_custom_call));
-
-    // Slice the result of the GEMM if the operands were padded.
-    HloInstruction *slice = nullptr;
-    if (new_output_shape.dimensions() != instr->shape().dimensions()) {
-      std::vector<int64_t> start_indices(instr->shape().rank(), 0);
-      std::vector<int64_t> strides(instr->shape().rank(), 1);
-      slice = instr->AddInstruction(HloInstruction::CreateSlice(
-          instr->shape(), new_custom_call, start_indices,
-          instr->shape().dimensions(), strides));
-    }
-
-    TF_RETURN_IF_ERROR(
-        ReplaceInstruction(instr, slice ? slice : new_custom_call));
-    VLOG(1) << instr->ToString() << " rewritten into FP8 Custom Call.";
-    return true;
-  }
-
-  absl::Status F8ScaleD(HloInstruction *instr, HloInstruction *existing_gemm,
-                        HloInstruction *d_scale) {
-    if (!ShapeUtil::IsScalar(d_scale->shape())) {
-      return absl::OkStatus();
-    }
-
-    // When the output of an FP8 GEMM is scaled but not type converted to FP8,
-    // cublasLT requires the scaling factor to be forwarded to the Custom Call
-    // as a_scale (chosen here) or b_scale. The scaling factor is fused here
-    // when no input scaling factors were fused during the creation of the
-    // Custom Call. When the maximum of the absolute value of the output of an
-    // FP8 GEMM is calculated and the output is scaled and type converted to
-    // FP8, the scaling of the output is fused in F8ConvertD.
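-    // For example (illustrative): divide(fp8_gemm(a, b), broadcast(s)), with
-    // the a_scale operand still equal to one, is rewritten so that a_scale
-    // becomes 1 / s; for multiply(fp8_gemm(a, b), broadcast(s)), a_scale
-    // becomes s.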
-    if (!existing_gemm->operand(2)->IsConstant() ||
-        existing_gemm->operand(2)->literal().GetAsDouble({}) != 1.) {
-      return absl::OkStatus();
-    }
-
-    // The application of the scaling of the output to the input (see previous
-    // comment) is not valid for epilogues other than ReLU or when a matrix bias
-    // has been fused.
-    TF_ASSIGN_OR_RETURN(auto gpu_backend_config,
-                        existing_gemm->backend_config<GpuBackendConfig>());
-    const GemmBackendConfig &config = gpu_backend_config.gemm_backend_config();
-    if ((config.epilogue() != GemmBackendConfig::DEFAULT &&
-         config.epilogue() != GemmBackendConfig::RELU) ||
-        config.beta() != 0.) {
-      return absl::OkStatus();
-    }
-
-    // If necessary, invert the scaling factor of D and convert to F32.
-    TF_ASSIGN_OR_RETURN(
-        d_scale,
-        InvertAndConvertScalar(d_scale, instr->opcode() == HloOpcode::kDivide));
-
-    TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(2, d_scale));
-    TF_RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm));
-
-    VLOG(1) << "Scaling of FP8 GEMM fused into Custom Call.";
-    return absl::OkStatus();
-  }
-
-  absl::Status F8ConvertD(HloInstruction *instr, HloInstruction *existing_gemm,
-                          HloInstruction *d_scale, HloInstruction *clamp_lower,
-                          HloInstruction *clamp_upper,
-                          bool mult_scale = false) {
-    // Verify the data types and the operands of clamp.
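-    // The clamp bounds must equal the finite range of the target FP8 type:
-    // +/-448 for F8E4M3FN and +/-57344 for F8E5M2.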
-    if (instr->shape().element_type() == F8E4M3FN) {
-      if (!clamp_lower->literal().IsAllFloat(static_cast<float>(
-              std::numeric_limits<tsl::float8_e4m3fn>::lowest())) ||
-          !clamp_upper->literal().IsAllFloat(static_cast<float>(
-              std::numeric_limits<tsl::float8_e4m3fn>::max()))) {
-        return absl::OkStatus();
-      }
-    } else if (instr->shape().element_type() == F8E5M2) {
-      if (!clamp_lower->literal().IsAllFloat(static_cast<float>(
-              std::numeric_limits<tsl::float8_e5m2>::lowest())) ||
-          !clamp_upper->literal().IsAllFloat(static_cast<float>(
-              std::numeric_limits<tsl::float8_e5m2>::max()))) {
-        return absl::OkStatus();
-      }
-    } else {
-      return absl::OkStatus();
-    }
-
-    if (d_scale && !ShapeUtil::IsScalar(d_scale->shape())) {
-      return absl::OkStatus();
-    }
-
-    // The possible second user of the GEMM must be the calculation of the
-    // maximum of the absolute value of the result of the GEMM. Since it is
-    // unknown in what form this operation will be used, it is identified in a
-    // top-down approach by inspecting the users of the GEMM.
-    const std::vector<HloInstruction *> gemm_users = existing_gemm->users();
-    HloInstruction *reduce_damax = nullptr;
-    if (gemm_users.size() == 2) {
-      // In the presence of a ReLU activation, the abs instruction is elided
-      // since abs(ReLU(x)) = ReLU(x).
-      TF_ASSIGN_OR_RETURN(auto gpu_config,
-                          existing_gemm->backend_config<GpuBackendConfig>());
-      const GemmBackendConfig &config = gpu_config.gemm_backend_config();
-      for (int i = 0; i < gemm_users.size(); ++i) {
-        HloInstruction *maybe_reduce = nullptr;
-        if (gemm_users[i]->opcode() == HloOpcode::kAbs) {
-          if (gemm_users[i]->users().size() != 1) continue;
-          maybe_reduce = gemm_users[i]->users()[0];
-        } else {
-          // If there is no Abs instruction, a ReLU epilogue is required to
-          // ensure all values are nonnegative.
-          if (config.epilogue() != GemmBackendConfig::BIAS_RELU &&
-              config.epilogue() != GemmBackendConfig::RELU)
-            continue;
-          maybe_reduce = gemm_users[i];
-        }
-
-        if (maybe_reduce->opcode() == HloOpcode::kReduce &&
-            maybe_reduce->operands().size() == 2 &&
-            maybe_reduce->operand(1)->opcode() == HloOpcode::kConstant &&
-            ShapeUtil::IsScalar(maybe_reduce->operand(1)->shape())) {
-          HloInstruction *reduce = maybe_reduce;
-          HloComputation *reduce_comp = reduce->to_apply();
-          HloInstruction *reduce_comp_root = reduce_comp->root_instruction();
-          if (reduce->operand(1)->literal().GetAsDouble({}) <= 0. &&
-              reduce_comp_root->opcode() == HloOpcode::kMaximum &&
-              reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter &&
-              reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter) {
-            reduce_damax = reduce;
-          }
-        }
-      }
-      if (!reduce_damax) {
-        return absl::OkStatus();
-      }
-    } else if (gemm_users.size() > 2) {
-      return absl::OkStatus();
-    }
-
-    TF_ASSIGN_OR_RETURN(auto gpu_backend_config,
-                        existing_gemm->backend_config<GpuBackendConfig>());
-    const GemmBackendConfig &gemm_backend_config =
-        gpu_backend_config.gemm_backend_config();
-
-    if (gemm_backend_config.beta() != 0.0) {
-      if (existing_gemm->operand(2)->shape().element_type() != BF16 &&
-          existing_gemm->operand(2)->shape().element_type() != F16) {
-        VLOG(1) << "The scaling and conversion of the result of "
-                << existing_gemm->ToShortString()
-                << " is not fused into the FP8 Custom Call because it "
-                   "conflicts with the existing fusion of the addition of a "
-                   "matrix bias with element type other than BF16 or F16.";
-        return absl::OkStatus();
-      } else {
-        // Turn off the output to operand aliasing, since the fp8 output and
-        // bf16/fp16 bias have different sizes.
-        xla::Cast<HloCustomCallInstruction>(existing_gemm)
-            ->set_output_to_operand_aliasing({});
-      }
-    }
-
-    // If necessary, invert the scaling factor of D and convert to F32.
-    if (d_scale) {
-      TF_ASSIGN_OR_RETURN(d_scale,
-                          InvertAndConvertScalar(d_scale, !mult_scale));
-      TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(
-          gemm_backend_config.beta() == 0.0 ? 5 : 6, d_scale));
-    }
-
-    // If present, elide the calculation of the maximum of the absolute values
-    // of the result of the GEMM.
-    if (reduce_damax) {
-      return F8AddDAmax(instr, existing_gemm, reduce_damax);
-    }
-
-    std::unique_ptr<HloInstruction> new_gemm =
-        existing_gemm->CloneWithNewShape(instr->shape());
-
-    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(new_gemm)));
-
-    VLOG(1) << "Conversion" << (reduce_damax ? " and amax calculation" : "")
-            << " fused into FP8 GEMM.";
-    return absl::OkStatus();
-  }
-
-  // Adds a scalar DAmax return value to an FP8 GEMM.
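-  // Sketch of the resulting HLO (illustrative):
-  //   gemm_and_damax = (f8[...], f32[]) custom-call(...)  // damax_output set
-  //   d     = get-tuple-element(gemm_and_damax), index=0
-  //   damax = get-tuple-element(gemm_and_damax), index=1
-  // and convert(damax) replaces the original reduce.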
-  absl::Status F8AddDAmax(HloInstruction *instr, HloInstruction *existing_gemm,
-                          HloInstruction *reduce_damax) {
-    // Change the output shape of the Custom Call to tuple(D, DAmax).
-    Shape damax_shape = ShapeUtil::MakeScalarShape(F32);
-    Shape tuple_shape =
-        ShapeUtil::MakeTupleShape({instr->shape(), damax_shape});
-    HloInstruction *gemm_and_damax =
-        instr->AddInstruction(existing_gemm->CloneWithNewShape(tuple_shape));
-
-    TF_ASSIGN_OR_RETURN(auto gpu_config,
-                        gemm_and_damax->backend_config<GpuBackendConfig>());
-    GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
-    config.set_damax_output(true);
-    TF_RETURN_IF_ERROR(gemm_and_damax->set_backend_config(gpu_config));
-
-    // Obtain D and DAmax separately from the output tuple.
-    HloInstruction *d =
-        instr->AddInstruction(HloInstruction::CreateGetTupleElement(
-            instr->shape(), gemm_and_damax, 0));
-    HloInstruction *damax = instr->AddInstruction(
-        HloInstruction::CreateGetTupleElement(damax_shape, gemm_and_damax, 1));
-
-    // Convert DAmax from FP32 to the requested type and elide reduce.
-    HloInstruction *damax_converted = instr->AddInstruction(
-        HloInstruction::CreateConvert(reduce_damax->shape(), damax));
-    TF_RETURN_IF_ERROR(ReplaceInstruction(reduce_damax, damax_converted));
-    TF_RETURN_IF_ERROR(ReplaceInstruction(instr, d));
-
-    return absl::OkStatus();
-  }
-
-  // Fuses a matrix bias into a cuBLAS call. 'instr' should be an Add
-  // instruction in the following form:
-  //   Add(OptionalBitcast(OptionalSlice(gemm)), bias)
-  // where 'gemm' is expected to be a cuBLAS custom_call. Slice is introduced
-  // when the inputs of the gemm are possibly padded. Bitcast is introduced to
-  // handle high rank input.
-  absl::Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias,
-                                 const HloInstruction *gemm,
-                                 HloInstruction *bitcast = nullptr,
-                                 HloInstruction *slice = nullptr) {
-    TF_RET_CHECK(Shape::Equal().IgnoreElementType()(bias->shape(),
-                                                    bitcast ? bitcast->shape()
-                                                    : slice ? slice->shape()
-                                                            : gemm->shape()));
-
-    // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only
-    // supports fixed values for alpha/beta.
-    if (gemm->shape().element_type() == S32) {
-      return absl::OkStatus();
-    }
-
-    // To ensure correctness, only slices that chop off the ends of dimensions
-    // are supported.
-    if (slice) {
-      int slice_op_dim = slice->operand(0)->shape().rank();
-      if (slice->slice_starts() != std::vector<int64_t>(slice_op_dim, 0) ||
-          slice->slice_strides() != std::vector<int64_t>(slice_op_dim, 1)) {
-        return absl::OkStatus();
-      }
-    }
-    // A legacy cublas gemm overwrites the bias matrix, so fusion is only
-    // possible if the bias has no other users and can safely be overwritten.
-    // CublasLt gemm can operate out-of-place.
-    bool can_overwrite_bias = [bias]() {
-      if (bias->user_count() > 1) {
-        // There is another user of the data, do not overwrite it.
-        return false;
-      }
-
-      if (bias->opcode() != HloOpcode::kParameter) {
-        // Not a parameter; can overwrite.
-        return true;
-      }
-
-      // The bias is a parameter of the computation; check if it is aliased.
-      if (!bias->parent()->IsEntryComputation()) {
-        // Only the HloModule has input/output aliasing, since this is not the
-        // entry computation, there are no guarantees about aliasing; do not
-        // overwrite.
-        return false;
-      }
-      const auto &in_out_alias_config =
-          bias->GetModule()->input_output_alias_config();
-      // If the parameter is aliased, we can overwrite it.
-      // TODO(victorstone): The assumption when calling ParameterHasAlias is
-      // that bias is not a tuple. This is why we pass {} as the argument for
-      // param_index.
-      return in_out_alias_config.ParameterHasAlias(bias->parameter_number(),
-                                                   /*param_index=*/{});
-    }();
-    bool want_to_fuse_bias = IsCublasLtMatmulF8(*gemm) ||
-                             IsCublasLtMatmul(*gemm) || can_overwrite_bias;
-
-    auto gpu_config = gemm->backend_config<GpuBackendConfig>().value();
-    GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
-    // It is possible to fuse into a cublasLt matmul that already has a vector
-    // bias, but no other epilogue will commute with the matrix bias add.
-    bool supported_epilogue =
-        ((config.epilogue() == GemmBackendConfig::DEFAULT) ||
-         (config.epilogue() == GemmBackendConfig::BIAS));
-
-    if ((config.beta() != 0) || !want_to_fuse_bias ||
-        (gemm->user_count() != 1) || !supported_epilogue) {
-      return absl::OkStatus();
-    }
-
-    config.set_beta(1.0);
-
-    std::vector<HloInstruction *> operands(gemm->operands().begin(),
-                                           gemm->operands().end());
-    HloInstruction *maybe_constant_folded_bias = MaybeConstantFoldBias(bias);
-    if (bitcast) {
-      maybe_constant_folded_bias =
-          instr->AddInstruction(HloInstruction::CreateBitcast(
-              slice->shape(), maybe_constant_folded_bias));
-    }
-
-    maybe_constant_folded_bias =
-        PadOperandToTargetShape(gemm->shape(), maybe_constant_folded_bias);
-
-    operands.insert(operands.begin() + 2, maybe_constant_folded_bias);
-
-    std::unique_ptr<HloInstruction> fused_op =
-        gemm->CloneWithNewOperands(gemm->shape(), operands);
-    // For mixed-type GEMMs, set the output element type to the bias type.
-    fused_op->mutable_shape()->set_element_type(bias->shape().element_type());
-    TF_RETURN_IF_ERROR(fused_op->set_backend_config(gpu_config));
-
-    // Choose whether the bias must alias the output. Legacy cublas GEMMs must
-    // operate in place and alias the bias with the output, whereas with
-    // cublasLt we can choose.
-    //
-    // Operating in place is always safe; copy-insertion will insert copies if
-    // necessary.  But (we assume) copying is slower than operating
-    // out-of-place, so for cublasLt (where we have the choice), we try to
-    // operate in place if we think a copy won't be necessary.
-    //
-    // We assume that parameters are always read-only and therefore we'd need to
-    // copy if we were going to operate in place. (This is not quite true; the
-    // param could have input/output aliasing.)  We also assume that if there
-    // are other uses of the bias, we might need to copy.  (Again, not quite
-    // true if those uses all come before this operation.  But copy-insertion
-    // runs before scheduling, so it can't know and has to conservatively insert
-    // copies.)
-    if (IsLegacyCublasMatmul(*fused_op) || can_overwrite_bias) {
-      xla::Cast<HloCustomCallInstruction>(fused_op.get())
-          ->set_output_to_operand_aliasing({{{}, {2, {}}}});
-    }
-    TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get()));
-    if (slice) {
-      fused_op = slice->CloneWithNewOperands(
-          slice->shape(),
-          {slice->parent()->AddInstruction(std::move(fused_op))});
-    }
-
-    if (bitcast) {
-      fused_op = bitcast->CloneWithNewOperands(
-          bitcast->shape(),
-          {bitcast->parent()->AddInstruction(std::move(fused_op))});
-    }
-
-    return ReplaceWithNewInstruction(instr, std::move(fused_op));
-  }
-
-  // Fuses a vector bias into a cuBLAS call. 'instr' should be an Add
-  // instruction in the following form:
-  //   Add(OptionalBitcast(OptionalSlice(gemm)), Broadcast(OptionalConvert()))
-  // where 'gemm' is expected to be a cuBLAS custom_call. The optional
-  // convert is only used for F8 matmuls as cublasLt has specific constraints
-  // on the vector bias type for such matmuls. The optional bitcast is
-  // necessary to handle high rank input cases.
-  absl::StatusOr<bool> FuseVectorBiasAdd(HloInstruction *instr,
-                                         HloInstruction *broadcast,
-                                         HloInstruction *gemm,
-                                         HloInstruction *slice = nullptr,
-                                         HloInstruction *convert = nullptr,
-                                         HloInstruction *bitcast = nullptr) {
-    if (!bitcast) {
-      TF_RET_CHECK(ShapeUtil::Compatible(
-          broadcast->shape(), (slice ? slice->shape() : gemm->shape())));
-    }
-    // Verify that the data type is supported by Epilogue Fusion.
-    if (!SupportsEpilogueFusion(gemm->shape().element_type())) {
-      return false;
-    }
-
-    HloInstruction *bias = broadcast->mutable_operand(0);
-
-    TF_ASSIGN_OR_RETURN(auto gpu_config,
-                        gemm->backend_config<GpuBackendConfig>());
-    GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
-    // # output column dims == # non-contracting rhs operand dims.
-    const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers();
-    size_t num_col_dims = gemm->operand(1)->shape().rank() -
-                          dot_dims.rhs_batch_dimensions_size() -
-                          dot_dims.rhs_contracting_dimensions_size();
-
-    if ((gemm->user_count() != 1) ||
-        (config.epilogue() != GemmBackendConfig::DEFAULT) ||
-        (bias->shape().rank() != num_col_dims)) {
-      return false;
-    }
-    // We require the bias vector to have been broadcast along the most major
-    // dimensions, i.e. its most minor physical dimensions must align with the
-    // most minor physical dimensions of the gemm output.
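-    // Example (illustrative): for a gemm output of shape f32[16,32]{1,0}, a
-    // bias f32[32] broadcast with dimensions={1} satisfies this, whereas an
-    // f32[16] bias broadcast with dimensions={0} does not.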
-    absl::Span<const int64_t> broadcast_dims = broadcast->dimensions();
-    for (size_t i = 0; i < num_col_dims; ++i) {
-      int64_t dim =
-          (bitcast ? bitcast : gemm)->shape().layout().minor_to_major(i);
-
-      // Find the corresponding dimension from the bias vector.
-      auto it = absl::c_find(broadcast_dims, dim);
-
-      if (it == broadcast_dims.end()) {
-        return false;
-      }
-
-      int64_t vector_dim = it - broadcast_dims.begin();
-      if (bias->shape().layout().minor_to_major(i) != vector_dim) {
-        return false;
-      }
-    }
-
-    std::vector<HloInstruction *> operands(gemm->operands().begin(),
-                                           gemm->operands().end());
-    // When (non-trivial) matrix and vector bias co-exist for FP8 matmul, just
-    // fuse matrix bias.
-    if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget &&
-        config.beta() != 0.0) {
-      return true;
-    }
-
-    if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget &&
-        bias->shape().element_type() == F32) {
-      if (convert == nullptr) {
-        return false;
-      }
-
-      HloInstruction *bias_f16_or_bf16 = convert->mutable_operand(0);
-      auto compatible_bias_type = [](const PrimitiveType bias_type,
-                                     const PrimitiveType output_type) {
-        if (bias_type == BF16) {
-          return output_type == F8E4M3FN || output_type == F8E5M2 ||
-                 output_type == F32 || output_type == BF16;
-        } else if (bias_type == F16) {
-          return output_type == F16 || output_type == F8E4M3FN ||
-                 output_type == F8E5M2;
-        }
-        return false;
-      };
-
-      // cuBLAS LT does not support FP32 biases on matmuls with FP8 inputs,
-      // even if the matmul output is FP32. We do not unconditionally convert
-      // the bias to a supported precision (F16 or BF16) because this lowers
-      // precision. Instead, we only fuse the bias if the bias itself is a
-      // convert from F16 or BF16, fusing the input of the convert instruction
-      // to the matmul.
-      if (compatible_bias_type(bias_f16_or_bf16->shape().element_type(),
-                               gemm->shape().element_type())) {
-        bias = bias_f16_or_bf16;
-      } else {
-        VLOG(1) << "Epilogue fusion of FP32 vector bias into FP8 GEMM is "
-                   "currently not supported. See the cublasLT support matrix.";
-        return false;
-      }
-    }
-
-    // In the case of high rank input for FP8, it is necessary to consider
-    // potential padding for the bias.
-    if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget && bitcast) {
-      bias = PadOperandToMultipleOf16(
-          config.dot_dimension_numbers().rhs_batch_dimensions(), bias);
-    }
-    // Replace add(gemm, broadcast) with fused new_gemm.
-    operands.push_back(bias);
-    config.set_epilogue(GemmBackendConfig::BIAS);
-    std::unique_ptr<HloInstruction> result =
-        gemm->CloneWithNewOperands(gemm->shape(), operands);
-    TF_RETURN_IF_ERROR(result->set_backend_config(gpu_config));
-    TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get()));
-    if (slice) {
-      result = slice->CloneWithNewOperands(
-          slice->shape(), {slice->parent()->AddInstruction(std::move(result))});
-    }
-
-    if (bitcast) {
-      result = bitcast->CloneWithNewOperands(
-          bitcast->shape(),
-          {bitcast->parent()->AddInstruction(std::move(result))});
-    }
-    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(result)));
-    return true;
-  }
-
-  absl::Status FuseReluActivation(HloInstruction *instr,
-                                  HloInstruction *broadcast,
-                                  HloInstruction *gemm,
-                                  HloInstruction *slice_or_bitcast = nullptr) {
-    TF_RET_CHECK(ShapeUtil::Compatible(
-        broadcast->shape(),
-        (slice_or_bitcast ? slice_or_bitcast->shape() : gemm->shape())));
-
-    if (!SupportsEpilogueFusion(gemm->shape().element_type())) {
-      return absl::OkStatus();
-    }
-
-    if (gemm->user_count() != 1) {
-      return absl::OkStatus();
-    }
-
-    TF_ASSIGN_OR_RETURN(auto gpu_config,
-                        gemm->backend_config<GpuBackendConfig>());
-    GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
-    if (config.epilogue() == GemmBackendConfig::DEFAULT) {
-      config.set_epilogue(GemmBackendConfig::RELU);
-    } else if (config.epilogue() == GemmBackendConfig::BIAS) {
-      config.set_epilogue(GemmBackendConfig::BIAS_RELU);
-    } else {
-      return absl::OkStatus();
-    }
-
-    std::unique_ptr<HloInstruction> result = gemm->Clone();
-    TF_RETURN_IF_ERROR(result->set_backend_config(gpu_config));
-    TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get()));
-
-    if (slice_or_bitcast) {
-      result = slice_or_bitcast->CloneWithNewOperands(
-          slice_or_bitcast->shape(),
-          {slice_or_bitcast->parent()->AddInstruction(std::move(result))});
-    }
-
-    return ReplaceWithNewInstruction(instr, std::move(result));
-  }
-
-  absl::Status FuseGeluActivation(HloInstruction *multiply,
-                                  HloInstruction *gemm,
-                                  HloInstruction *slice_or_bitcast = nullptr) {
-    if (!SupportsEpilogueFusion(gemm->shape().element_type())) {
-      return absl::OkStatus();
-    }
-    // For CUDA versions less than 12.3.2, cuBLAS LT returns
-    // CUBLAS_STATUS_NOT_SUPPORTED in some cases when fusing gelu into an FP8
-    // matmul. We cannot check the patch version, so disable this fusion with
-    // CUDA versions less than 12.4.
-    if (IsCuda(gpu_version_) && toolkit_version_ < 12040 &&
-        IsCublasLtMatmulF8(*gemm)) {
-      return absl::OkStatus();
-    }
-
-    // There are four users of the gemm output within the GELU calculation
-    // itself; any additional user means the raw gemm output is also needed
-    // elsewhere, so we request the epilogue variant that returns it as an
-    // auxiliary output.
-    bool has_aux = gemm->user_count() > 4;
-
-    TF_ASSIGN_OR_RETURN(auto gpu_config,
-                        gemm->backend_config<GpuBackendConfig>());
-    GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
-
-    if (config.epilogue() == GemmBackendConfig::DEFAULT) {
-      config.set_epilogue(has_aux ? GemmBackendConfig::GELU_AUX
-                                  : GemmBackendConfig::GELU);
-    } else if (config.epilogue() == GemmBackendConfig::BIAS) {
-      config.set_epilogue(has_aux ? GemmBackendConfig::BIAS_GELU_AUX
-                                  : GemmBackendConfig::BIAS_GELU);
-    } else {
-      return absl::OkStatus();
-    }
-
-    std::unique_ptr<HloInstruction> output = gemm->CloneWithNewShape(
-        has_aux ? ShapeUtil::MakeTupleShape({gemm->shape(), gemm->shape()})
-                : gemm->shape());
-    TF_RETURN_IF_ERROR(output->set_backend_config(gpu_config));
-    TF_RETURN_IF_ERROR(SetName(multiply->GetModule(), output.get()));
-
-    if (slice_or_bitcast) {
-      output = slice_or_bitcast->CloneWithNewOperands(
-          slice_or_bitcast->shape(),
-          {gemm->parent()->AddInstruction(std::move(output))});
-    }
-
-    if (has_aux) {
-      HloInstruction *tuple_output =
-          gemm->parent()->AddInstruction(std::move(output));
-      TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
-          gemm, HloInstruction::CreateGetTupleElement(tuple_output, 1)));
-      output = HloInstruction::CreateGetTupleElement(tuple_output, 0);
-    }
-
-    return ReplaceWithNewInstruction(multiply, std::move(output));
-  }
-
- private:
-  se::GpuComputeCapability gpu_version_;
-  int32_t toolkit_version_;
-  bool f8_rewrite_;
-
-  // Choose cublas or cublasLt for the target of the custom call that instr will
-  // be rewritten into.
-  absl::StatusOr<absl::string_view> GetNonFp8GemmCustomCallTarget(
-      const HloInstruction &instr,
-      const GemmBackendConfig &gemm_backend_config) const {
-    if (!instr.GetModule()
-             ->config()
-             .debug_options()
-             .xla_gpu_enable_cublaslt()) {
-      // cublasLt is not enabled.
-      return absl::string_view(kGemmCallTarget);
-    }
-
-    // cublasLt is enabled; check whether the other internal conditions are met.
-    const HloInstruction *lhs = instr.operand(0);
-    const HloInstruction *rhs = instr.operand(1);
-    if (lhs->shape().element_type() == S8 ||
-        rhs->shape().element_type() == S8) {
-      // TODO(b/241446501) The XLA usage of cublasLt does not yet handle
-      // int8 matmuls. Fall back to legacy cublas.
-      return absl::string_view(kGemmCallTarget);
-    }
-
-    // All internal conditions are met; check whether we meet the requirements
-    // of cublasLt.
-    TF_ASSIGN_OR_RETURN(bool gemm_is_supported_by_cublas_lt,
-                        GemmIsSupportedByCublasLt(instr, gemm_backend_config));
-    if (gemm_is_supported_by_cublas_lt) {
-      return absl::string_view(kCublasLtMatmulCallTarget);
-    }
-
-    // This case is not supported by cublasLt; fall back to legacy cublas.
-    return absl::string_view(kGemmCallTarget);
-  }
-
-  absl::StatusOr<bool> TypesAreSupportedByLegacyCublas(
-      const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config,
-      const HloInstruction *bias = nullptr) const {
-    // Figure out the Atype/Btype.
-    const PrimitiveType a_dtype = instr.operand(0)->shape().element_type();
-    const PrimitiveType b_dtype = instr.operand(1)->shape().element_type();
-    const PrimitiveType output_type =
-        bias ? bias->shape().element_type() : instr.shape().element_type();
-    const std::array<PrimitiveType, 8> supported_type = {
-        PrimitiveType::S8,  PrimitiveType::F16, PrimitiveType::BF16,
-        PrimitiveType::F32, PrimitiveType::S32, PrimitiveType::F64,
-        PrimitiveType::C64, PrimitiveType::C128};
-    // legacy cublas has a defined set of combinations of types that it
-    // supports. Figure out the computeType and scaleType.
-    if (!absl::c_linear_search(supported_type, output_type)) return false;
-    TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype,
-                        se::gpu::AsBlasDataType(output_type));
-    // TODO(tdanyluk): Investigate why we don't use the actual precision (and
-    // algorithm) here, and why we use the default instead.
-    TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type,
-                        se::gpu::GetBlasComputationType(
-                            PrecisionConfig::ALG_UNSET, a_dtype, output_type,
-                            stream_executor::blas::kDefaultComputePrecision));
-    se::blas::DataType scale_type =
-        se::gpu::GetScaleType(output_dtype, compute_type);
-
-    using se::blas::ComputationType;
-    using se::blas::DataType;
-    // This matrix of supported types is taken directly from cublas
-    // documentation.
-    // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex
-    const std::array<
-        std::tuple<ComputationType, DataType /*scale_type*/,
-                   PrimitiveType /*a_dtype*/, PrimitiveType /*b_dtype*/,
-                   DataType /*output_dtype*/>,
-        17>
-        supported_type_combinations = {{
-            {ComputationType::kF16, DataType::kHalf, PrimitiveType::F16,
-             PrimitiveType::F16, DataType::kHalf},
-
-            {ComputationType::kI32, DataType::kInt32, PrimitiveType::S8,
-             PrimitiveType::S8, DataType::kInt32},
-
-            {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
-             PrimitiveType::BF16, DataType::kBF16},
-            {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
-             PrimitiveType::F16, DataType::kHalf},
-            {ComputationType::kF32, DataType::kFloat, PrimitiveType::S8,
-             PrimitiveType::S8, DataType::kFloat},
-            {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
-             PrimitiveType::BF16, DataType::kFloat},
-            {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
-             PrimitiveType::F16, DataType::kFloat},
-            {ComputationType::kF32, DataType::kFloat, PrimitiveType::F32,
-             PrimitiveType::F32, DataType::kFloat},
-
-            // There would be an entry here for A/BType complex int8, but we do
-            // not support that type.
-            {ComputationType::kF32, DataType::kComplexFloat, PrimitiveType::C64,
-             PrimitiveType::C64, DataType::kComplexFloat},
-
-            {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F32,
-             PrimitiveType::F32, DataType::kFloat},
-            {ComputationType::kF16AsF32, DataType::kComplexFloat,
-             PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
-
-            {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::F32,
-             PrimitiveType::F32, DataType::kFloat},
-            {ComputationType::kBF16AsF32, DataType::kComplexFloat,
-             PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
-
-            {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32,
-             PrimitiveType::F32, DataType::kFloat},
-            {ComputationType::kTF32AsF32, DataType::kComplexFloat,
-             PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
-
-            {ComputationType::kF64, DataType::kDouble, PrimitiveType::F64,
-             PrimitiveType::F64, DataType::kDouble},
-            {ComputationType::kF64, DataType::kComplexDouble,
-             PrimitiveType::C128, PrimitiveType::C128,
-             DataType::kComplexDouble},
-        }};
-
-    return absl::c_linear_search(
-        supported_type_combinations,
-        std::make_tuple(compute_type, scale_type, a_dtype, b_dtype,
-                        output_dtype));
-  }
-
-  absl::StatusOr<bool> TypesAreSupportedByCublasLt(
-      const HloInstruction &instr, const GemmBackendConfig &backend_config,
-      const HloInstruction *bias = nullptr) const {
-    // Figure out the Atype/Btype.
-    const PrimitiveType a_dtype = instr.operand(0)->shape().element_type();
-    const PrimitiveType b_dtype = instr.operand(1)->shape().element_type();
-    const PrimitiveType output_type =
-        bias ? bias->shape().element_type() : instr.shape().element_type();
-    const std::array<PrimitiveType, 12> supported_type = {
-        PrimitiveType::F8E5M2FNUZ, PrimitiveType::F8E4M3FNUZ,
-        PrimitiveType::F8E5M2,     PrimitiveType::F8E4M3FN,
-        PrimitiveType::S8,         PrimitiveType::F16,
-        PrimitiveType::BF16,       PrimitiveType::F32,
-        PrimitiveType::S32,        PrimitiveType::F64,
-        PrimitiveType::C64,        PrimitiveType::C128};
-    if (!absl::c_linear_search(supported_type, output_type)) return false;
-    // cublasLt has a defined set of combinations of types that it supports.
-    // Figure out the computeType and scaleType.
-    TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype,
-                        se::gpu::AsBlasDataType(output_type));
-    const int max_precision = *absl::c_max_element(
-        backend_config.precision_config().operand_precision());
-    const PrecisionConfig::Algorithm algorithm =
-        backend_config.precision_config().algorithm();
-    if (!algorithm_util::IsSupportedByCublasOrCublasLt(algorithm)) return false;
-
-    TF_ASSIGN_OR_RETURN(
-        const se::blas::ComputationType compute_type,
-        se::gpu::GetBlasComputationType(
-            algorithm, a_dtype, instr.shape().element_type(), max_precision));
-    se::blas::DataType scale_type =
-        se::gpu::GetScaleType(output_dtype, compute_type);
-
-    using se::blas::ComputationType;
-    using se::blas::DataType;
-    using TypeCombinations = std::initializer_list<std::tuple<
-        ComputationType, DataType /*scale_type*/, PrimitiveType /*a_dtype*/,
-        PrimitiveType /*b_dtype*/, DataType /*output_dtype*/>>;
-    // This matrix of supported types is taken directly from cublasLt
-    // documentation.
-    // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul
-    const TypeCombinations supported_cublas_type_combinations = {
-        // FP8 types:
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
-         PrimitiveType::F8E4M3FN, DataType::kBF16},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
-         PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
-         PrimitiveType::F8E4M3FN, DataType::kHalf},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
-         PrimitiveType::F8E4M3FN, DataType::kFloat},
-
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
-         PrimitiveType::F8E5M2, DataType::kBF16},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
-         PrimitiveType::F8E5M2, DataType::kF8E4M3FN},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
-         PrimitiveType::F8E5M2, DataType::kF8E5M2},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
-         PrimitiveType::F8E5M2, DataType::kHalf},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
-         PrimitiveType::F8E5M2, DataType::kFloat},
-
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
-         PrimitiveType::F8E4M3FN, DataType::kBF16},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
-         PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
-         PrimitiveType::F8E4M3FN, DataType::kF8E5M2},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
-         PrimitiveType::F8E4M3FN, DataType::kHalf},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
-         PrimitiveType::F8E4M3FN, DataType::kFloat},
-        // There would be an entry here for A/BType complex int8, but we do
-        // not support that type.
-        {ComputationType::kF32, DataType::kComplexFloat, PrimitiveType::C64,
-         PrimitiveType::C64, DataType::kComplexFloat},
-
-        {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F32,
-         PrimitiveType::F32, DataType::kFloat},
-        {ComputationType::kF16AsF32, DataType::kComplexFloat,
-         PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
-        // The next 4 may be supported by hipblaslt, but they are not
-        // covered by any unit tests
-        {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::F32,
-         PrimitiveType::F32, DataType::kFloat},
-        {ComputationType::kBF16AsF32, DataType::kComplexFloat,
-         PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
-
-        {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32,
-         PrimitiveType::F32, DataType::kFloat},
-        {ComputationType::kTF32AsF32, DataType::kComplexFloat,
-         PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
-
-        {ComputationType::kF64, DataType::kDouble, PrimitiveType::F64,
-         PrimitiveType::F64, DataType::kDouble},
-        {ComputationType::kF64, DataType::kComplexDouble, PrimitiveType::C128,
-         PrimitiveType::C128, DataType::kComplexDouble},
-    };
-    if (IsCuda(gpu_version_) &&
-        absl::c_linear_search(supported_cublas_type_combinations,
-                              std::tuple{compute_type, scale_type, a_dtype,
-                                         b_dtype, output_dtype})) {
-      return true;
-    }
-    const TypeCombinations supported_hipblas_type_combinations = {
-        // FP8 types:
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
-
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E5M2FNUZ, DataType::kBF16},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E5M2FNUZ, DataType::kF8E4M3FNUZ},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E5M2FNUZ, DataType::kF8E5M2FNUZ},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E5M2FNUZ, DataType::kHalf},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E5M2FNUZ, DataType::kFloat},
-
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kF8E5M2FNUZ},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
-    };
-    if (IsRocm(gpu_version_) &&
-        absl::c_linear_search(supported_hipblas_type_combinations,
-                              std::tuple{compute_type, scale_type, a_dtype,
-                                         b_dtype, output_dtype})) {
-      return true;
-    }
-    const TypeCombinations supported_type_combinations = {
-        // Other data types:
-        {ComputationType::kF16, DataType::kHalf, PrimitiveType::F16,
-         PrimitiveType::F16, DataType::kHalf},
-
-        {ComputationType::kI32, DataType::kInt32, PrimitiveType::S8,
-         PrimitiveType::S8, DataType::kInt32},
-        {ComputationType::kI32, DataType::kFloat, PrimitiveType::S8,
-         PrimitiveType::S8, DataType::kInt8},
-
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
-         PrimitiveType::BF16, DataType::kBF16},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
-         PrimitiveType::F16, DataType::kHalf},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::S8,
-         PrimitiveType::S8, DataType::kFloat},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
-         PrimitiveType::BF16, DataType::kFloat},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
-         PrimitiveType::F16, DataType::kFloat},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F32,
-         PrimitiveType::F32, DataType::kFloat},
-    };
-
-    return absl::c_linear_search(
-        supported_type_combinations,
-        std::make_tuple(compute_type, scale_type, a_dtype, b_dtype,
-                        output_dtype));
-  }
-
-  absl::StatusOr<bool> MatrixIsColumnMajor(
-      const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config,
-      const std::string matrix_name = "output") const {
-    const HloInstruction *lhs = instr.operand(0);
-    const HloInstruction *rhs = instr.operand(1);
-
-    const DotDimensionNumbers &dot_dims =
-        gemm_backend_config.dot_dimension_numbers();
-    // We use ALG_UNSET and kDefaultComputePrecision because we don't care about
-    // the precision, just the layout, since we're just checking if the matrix
-    // is column-major.
-    TF_ASSIGN_OR_RETURN(
-        GemmConfig gemm_config,
-        GemmConfig::For(
-            lhs->shape(), dot_dims.lhs_batch_dimensions(),
-            dot_dims.lhs_contracting_dimensions(), rhs->shape(),
-            dot_dims.rhs_batch_dimensions(),
-            dot_dims.rhs_contracting_dimensions(),
-            /*output_shape=*/instr.shape(), gemm_backend_config.alpha_real(),
-            gemm_backend_config.alpha_imag(), gemm_backend_config.beta(),
-            /*precision_algorithm=*/PrecisionConfig::ALG_UNSET,
-            /*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision,
-            gemm_backend_config.grad_x(), gemm_backend_config.grad_y()));
-
-    if (matrix_name == "lhs" || matrix_name == "a") {
-      return gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor;
-    } else if (matrix_name == "rhs" || matrix_name == "b") {
-      return gemm_config.rhs_layout.order == MatrixLayout::Order::kColumnMajor;
-    } else if (matrix_name == "output" || matrix_name == "d") {
-      return gemm_config.output_layout.order ==
-             MatrixLayout::Order::kColumnMajor;
-    } else {
-      return Internal("Invalid matrix name.");
-    }
-  }
-
-  absl::StatusOr<bool> GemmIsSupportedByCublasLt(
-      const HloInstruction &instr,
-      const GemmBackendConfig &gemm_backend_config) const {
-    const HloInstruction *lhs = instr.operand(0);
-    const HloInstruction *rhs = instr.operand(1);
-    const Shape &output_shape = instr.shape();
-
-    TF_ASSIGN_OR_RETURN(
-        bool types_are_supported_by_cublas_lt,
-        TypesAreSupportedByCublasLt(instr, gemm_backend_config));
-    if (!types_are_supported_by_cublas_lt) {
-      return false;
-    }
-
-    // The cublasLt API has two currently known limitations:
-    // 1. Batch count must be <2^16.
-    constexpr int64_t kMaxBatchCount = 65535;
-    // We get the batch dimension size from lhs here, but we could just as well
-    // use rhs; they are guaranteed to be the same (TODO: verify).
-    const auto &batch_dimensions =
-        gemm_backend_config.dot_dimension_numbers().lhs_batch_dimensions();
-    int batch_count = (batch_dimensions.empty() ? 0 : 1);
-    // All batch dimensions get flattened into a single batch dimension.
-    for (auto batch_dimension : batch_dimensions) {
-      batch_count *= lhs->shape().dimensions(batch_dimension);
-    }
-    if (batch_count > kMaxBatchCount) {
-      // This is not supported by cublasLt.
-      return false;
-    }
-
-    TF_ASSIGN_OR_RETURN(bool output_is_column_major,
-                        MatrixIsColumnMajor(instr, gemm_backend_config));
-
-    if (auto isrocm = std::get_if<se::RocmComputeCapability>(&gpu_version_);
-        isrocm) {
-      if (!isrocm->has_hipblaslt()) {
-        return false;
-      }
-    }
-
-    // 2. cublasLt does not support rhs col dimension size > 4194240 for
-    // C64.
-    constexpr int kMaxDimensionSize{4194240};
-    if (output_shape.element_type() != C64) {
-      // The output type is not C64, so this restriction does not apply.
-      return true;
-    }
-
-    if (std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
-      auto cuda_compute_capability_ =
-          std::get<se::CudaComputeCapability>(gpu_version_);
-      if (cuda_compute_capability_.IsAtLeast(
-              se::CudaComputeCapability::AMPERE)) {
-        // cuBlasLt has an implementation for complex data with compute type
-        // 32F_FAST_32TF that uses tensor cores and that is free from the
-        // restriction. This implementation requires at least the Ampere
-        // architecture, where TF32 was introduced.
-        return true;
-      }
-    }
-    // Get the rhs non-contracting dimensions as they will eventually appear
-    // at the cublasLt level.
-    std::vector<int64_t> rhs_non_contracting_dims;
-    const DotDimensionNumbers &dot_dims =
-        gemm_backend_config.dot_dimension_numbers();
-
-    if (!output_is_column_major) {
-      // cublasLt's matmul output is column major by default. This gemm requires
-      // the output to be in row major. Later we will swap lhs & rhs (and
-      // transpose each operand) of this gemm. Since we care about the rhs at
-      // the cublasLt level, this swap means that we care about the lhs right
-      // here.
-      TF_ASSIGN_OR_RETURN(
-          rhs_non_contracting_dims,
-          GetNonContractingDims(lhs->shape(), dot_dims.lhs_batch_dimensions(),
-                                dot_dims.lhs_contracting_dimensions()));
-    } else {
-      TF_ASSIGN_OR_RETURN(
-          rhs_non_contracting_dims,
-          GetNonContractingDims(rhs->shape(), dot_dims.rhs_batch_dimensions(),
-                                dot_dims.rhs_contracting_dimensions()));
-    }
-
-    const auto lhs_non_contracting_dimension_size = absl::c_accumulate(
-        rhs_non_contracting_dims, 1, [&](int64_t size, int64_t dim) {
-          return size * lhs->shape().dimensions(dim);
-        });
-
-    // Check that the size of the non-contracting dimension is not too large.
-    return lhs_non_contracting_dimension_size <= kMaxDimensionSize;
-  }
-
-  // Turns an F8 dot with an unsupported output type into an F8 dot with F32
-  // output, then converts the F32 output back to the original output type.
-  absl::StatusOr<HloInstruction *> TurnF8DotWithUnsupportedOutputTypeIntoF32(
-      HloInstruction *instr) {
-    Shape output_f32_shape = instr->shape();
-    output_f32_shape.set_element_type(F32);
-    HloInstruction *f32_dot =
-        instr->AddInstruction(instr->CloneWithNewShape(output_f32_shape));
-    HloInstruction *convert = instr->AddInstruction(
-        HloInstruction::CreateConvert(instr->shape(), f32_dot));
-    TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert));
-    return f32_dot;
-  }
-
-  // Turns an F8 dot into an F16 dot, converting operands to F16 and
-  // converting the output back to F8.
-  absl::StatusOr<HloInstruction *> TurnF8DotIntoF16Dot(HloInstruction *instr) {
-    DCHECK(IsF8Type(instr->operand(0)));
-    DCHECK(IsF8Type(instr->operand(1)));
-
-    // Convert operands to F16
-    for (int i = 0; i < 2; ++i) {
-      Shape operand_f16_shape = instr->operand(i)->shape();
-      operand_f16_shape.set_element_type(F16);
-      HloInstruction *convert =
-          instr->AddInstruction(HloInstruction::CreateConvert(
-              operand_f16_shape, instr->mutable_operand(i)));
-      TF_RETURN_IF_ERROR(instr->ReplaceOperandWith(i, convert));
-    }
-
-    // If output is F8, change output to F16 and then convert it back to F8
-    if (IsF8Type(instr)) {
-      Shape output_f16_shape = instr->shape();
-      output_f16_shape.set_element_type(F16);
-      HloInstruction *f16_dot =
-          instr->AddInstruction(instr->CloneWithNewShape(output_f16_shape));
-      HloInstruction *convert_to_f8 = instr->AddInstruction(
-          HloInstruction::CreateConvert(instr->shape(), f16_dot));
-      TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert_to_f8));
-      return f16_dot;
-    } else {
-      return instr;
-    }
-  }
-};
-
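// A minimal HLO-level sketch of the vector-bias + ReLU epilogue fusion
// implemented by the visitor above. Shapes, value names, and the printed
// backend_config are hypothetical, and the custom-call target string is
// assumed to be the one referred to by kCublasLtMatmulCallTarget.
const char* kEpilogueFusionBefore = R"(
  dot = f32[64,32] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  bias = f32[32] parameter(2)
  bcast = f32[64,32] broadcast(bias), dimensions={1}
  add = f32[64,32] add(dot, bcast)
  ROOT relu = f32[64,32] maximum(add, zero_bcast)
)";
const char* kEpilogueFusionAfter = R"(
  ROOT gemm = f32[64,32] custom-call(a, b, bias),
      custom_call_target="__cublas$lt$matmul",
      backend_config={"gemm_backend_config":{"epilogue":"BIAS_RELU"}}
)";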
-// Rewriter that adds a workspace to legacy cuBLAS custom calls. We run it
-// separately after gemm rewriter, so that we can do pattern matching without
-// having to match output tuples.
-class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor {
- public:
-  explicit GemmWorkspaceRewriteVisitor(
-      const se::GpuComputeCapability &gpu_version)
-      : gpu_version_(gpu_version) {}
-
-  absl::Status HandleCustomCall(HloInstruction *instr) override {
-    bool has_aux_output = false;
-    if (instr->custom_call_target() == kCublasLtMatmulCallTarget ||
-        instr->custom_call_target() == kCublasLtMatmulF8CallTarget) {
-      TF_ASSIGN_OR_RETURN(const auto gpu_config,
-                          instr->backend_config<xla::gpu::GpuBackendConfig>());
-      const xla::gpu::GemmBackendConfig &config =
-          gpu_config.gemm_backend_config();
-      xla::gpu::GemmBackendConfig_Epilogue epilogue = config.epilogue();
-      TF_ASSIGN_OR_RETURN(
-          has_aux_output,
-          xla::gpu::gpublas_lt::EpilogueHasAuxiliaryOutput(epilogue));
-
-      if (!((instr->shape().IsTuple() &&
-             instr->shape().tuple_shapes_size() ==
-                 has_aux_output + config.damax_output() + 1) ||
-            instr->shape().IsArray())) {
-        return absl::OkStatus();
-      }
-    } else if (instr->custom_call_target() != kGemmCallTarget ||
-               !instr->shape().IsArray()) {
-      return absl::OkStatus();
-    }
-
-    auto *cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version_);
-
-    // Pass a user-managed workspace to legacy cuBLAS operations; otherwise,
-    // cuBLAS uses its own internal pool, which would compete with the XLA
-    // allocator for device memory.
-    int64_t workspace = cuda_cc == nullptr ? GemmConfig::kDefaultWorkspace
-                        : cuda_cc->IsAtLeastHopper()
-                            ? GemmConfig::kHopperWorkspace
-                            : GemmConfig::kDefaultWorkspace;
-
-    // We do not know the workspace size required by cuBLAS, but we can guess
-    // that in the worst case cuBLAS will transpose all operands into a tiled
-    // layout optimal for the tensor cores, so it does not make sense to
-    // allocate a larger workspace.
-    //
-    // TODO(ezhulenev): This is not based on any measurement, just common
-    // sense; we should tweak it to find the minimal workspace size.
-    if (instr->custom_call_target() == kGemmCallTarget) {
-      int64_t operands_byte_size = 0;
-      for (auto &operand : instr->operands()) {
-        operands_byte_size += ShapeUtil::ByteSizeOf(operand->shape());
-      }
-      workspace = std::min(workspace, operands_byte_size);
-    }
-
-    // Append workspace buffer to instruction outputs.
-    std::vector<Shape> output_shapes = instr->shape().IsArray()
-                                           ? std::vector<Shape>{instr->shape()}
-                                           : instr->shape().tuple_shapes();
-    output_shapes.emplace_back(ShapeUtil::MakeShape(S8, {workspace}));
-    Shape output_shape = ShapeUtil::MakeTupleShape(output_shapes);
-
-    // Clone custom call with a new shape.
-    HloInstruction *new_call = instr->AddInstruction(
-        instr->CloneWithNewOperands(output_shape, instr->operands()));
-
-    // Update operand aliasing if it was a fused gemm with aliased output.
-    auto *custom_call = xla::Cast<HloCustomCallInstruction>(new_call);
-    if (!custom_call->output_to_operand_aliasing().empty()) {
-      custom_call->set_output_to_operand_aliasing({{{0}, {2, {}}}});
-    }
-
-    if (instr->shape().IsTuple()) {
-      for (auto user : instr->users()) {
-        auto user_get_tuple =
-            dynamic_cast<HloGetTupleElementInstruction *>(user);
-        TF_RET_CHECK(user_get_tuple);
-        HloInstruction *get_output =
-            instr->AddInstruction(HloInstruction::CreateGetTupleElement(
-                new_call, user_get_tuple->tuple_index()));
-        TF_RETURN_IF_ERROR(ReplaceInstruction(user_get_tuple, get_output));
-      }
-      return absl::OkStatus();
-    } else {
-      HloInstruction *get_output = instr->AddInstruction(
-          HloInstruction::CreateGetTupleElement(new_call, 0));
-      return ReplaceInstruction(instr, get_output);
-    }
-  }
-
- private:
-  se::GpuComputeCapability gpu_version_;
-};
-
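// A minimal sketch of the shape change made by GemmWorkspaceRewriteVisitor:
// an s8[] workspace buffer is appended as an extra output and the original
// array result is recovered with a get-tuple-element. The byte count below is
// a made-up placeholder, not the real GemmConfig workspace constant.
const char* kWorkspaceRewriteBefore = R"(
  ROOT gemm = f32[64,16] custom-call(a, b), custom_call_target="__cublas$gemm"
)";
const char* kWorkspaceRewriteAfter = R"(
  gemm = (f32[64,16], s8[4194304]) custom-call(a, b),
      custom_call_target="__cublas$gemm"
  ROOT out = f32[64,16] get-tuple-element(gemm), index=0
)";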
-absl::StatusOr<bool> RunOnComputation(HloComputation *computation,
-                                      se::GpuComputeCapability gpu_version,
-                                      int32_t toolkit_version,
-                                      bool f8_rewrite) {
-  GemmRewriterVisitor visitor(gpu_version, toolkit_version, f8_rewrite);
-  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
-  GemmWorkspaceRewriteVisitor workspace_visitor(gpu_version);
-  TF_RETURN_IF_ERROR(computation->Accept(&workspace_visitor));
-  return visitor.changed();
-}
-
-}  // anonymous namespace
-
-GemmRewriter::GemmRewriter(se::GpuComputeCapability gpu_version,
-                           int32_t toolkit_version, bool f8_rewrite)
-    : gpu_version_(gpu_version),
-      toolkit_version_(toolkit_version),
-      f8_rewrite_(f8_rewrite) {}
-
-absl::StatusOr<bool> GemmRewriter::Run(
-    HloModule *module,
-    const absl::flat_hash_set<absl::string_view> &execution_threads) {
-  bool changed = false;
-  for (HloComputation *computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    TF_ASSIGN_OR_RETURN(bool result,
-                        RunOnComputation(computation, gpu_version_,
-                                         toolkit_version_, f8_rewrite_));
-    changed |= result;
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
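For reference, a minimal usage sketch of the GemmRewriter interface defined above, written in the style of the test files later in this diff. It is not a test taken from this change: the HLO module, the Ampere capability, the toolkit version 12040 (CUDA 12.4), and the expectation that the dot gets rewritten are all illustrative assumptions.

#include <gtest/gtest.h>
#include "xla/service/gpu/gemm_rewriter.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/statusor.h"

namespace xla::gpu {
namespace {

class GemmRewriterUsageSketch : public HloTestBase {};

TEST_F(GemmRewriterUsageSketch, RewritesPlainDot) {
  const char* hlo = R"(
  HloModule m

  ENTRY e {
    p0 = f32[64,32] parameter(0)
    p1 = f32[32,16] parameter(1)
    ROOT d = f32[64,16] dot(p0, p1),
      lhs_contracting_dims={1}, rhs_contracting_dims={0}
  })";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  // The toolkit version is the CUDA version as an integer (12040 == 12.4),
  // matching the comparison against 12040 in FuseGeluActivation above.
  se::CudaComputeCapability ampere(8, 0);
  GemmRewriter rewriter(ampere, /*toolkit_version=*/12040,
                        /*f8_rewrite=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&rewriter, module.get()));
  EXPECT_TRUE(changed);  // The dot should now be a cuBLAS custom call.
}

}  // namespace
}  // namespace xla::gpu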
diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.h b/third_party/xla/xla/service/gpu/gemm_rewriter.h
deleted file mode 100644
index 161a29a..0000000
--- a/third_party/xla/xla/service/gpu/gemm_rewriter.h
+++ /dev/null
@@ -1,70 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_GEMM_REWRITER_H_
-#define XLA_SERVICE_GPU_GEMM_REWRITER_H_
-
-#include <cstdint>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// cuBLAS GEMM in the most general form can run the following operation:
-//
-// (kAdd
-//    (kMultiply (kDot A B) alpha)
-//    (kMultiply C beta))
-//
-// where A, B, C are matrices or vectors and `alpha` and `beta` are host
-// constants. In matrix-vector multiplication, one operand must be a matrix and
-// the other must be a vector. The additional requirement is that C has no other
-// users (otherwise, it does not make sense to fuse it inside the custom call).
-//
-// Both multiplication and addition can be avoided (equivalent to setting
-// `alpha` to one and `beta` to zero).
-//
-// This pass pattern-matches the most general form of this instruction
-// (we assume transposes are already folded) and rewrites it into a custom call
-// whose three operands are (A, B, C) and whose backend config stores `alpha`
-// and `beta`.
-class GemmRewriter : public HloModulePass {
- public:
-  // When f8_rewrite is true, only FP8 GEMMs are rewritten. Otherwise, non-FP8
-  // GEMMs are rewritten.
-  GemmRewriter(se::GpuComputeCapability gpu_version, int32_t toolkit_version,
-               bool f8_rewrite = false);
-  absl::string_view name() const override { return "cublas-gemm-rewriter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  se::GpuComputeCapability gpu_version_;
-  int32_t toolkit_version_;
-  bool f8_rewrite_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GEMM_REWRITER_H_
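To make the header comment above concrete, here is an illustrative before/after HLO sketch of the general alpha/beta pattern and its rewritten custom call. The operand names, shapes, constants, and printed backend_config fields are hypothetical, and the target string is assumed to be the legacy kGemmCallTarget.

// Before: dot scaled by alpha, plus C scaled by beta (broadcasted constants
// elided for brevity).
const char* kGeneralGemmPatternBefore = R"(
  dot = f32[64,16] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  scaled_dot = f32[64,16] multiply(dot, alpha_bcast)
  scaled_c = f32[64,16] multiply(c, beta_bcast)
  ROOT sum = f32[64,16] add(scaled_dot, scaled_c)
)";

// After: one custom call with (A, B, C) as operands and alpha/beta stored in
// the backend config.
const char* kGeneralGemmPatternAfter = R"(
  ROOT gemm = f32[64,16] custom-call(a, b, c),
      custom_call_target="__cublas$gemm",
      backend_config={"gemm_backend_config":{"alpha_real":2.0,"beta":1.0}}
)";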
diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter.cc b/third_party/xla/xla/service/gpu/gemv_rewriter.cc
deleted file mode 100644
index 21e5f47..0000000
--- a/third_party/xla/xla/service/gpu/gemv_rewriter.cc
+++ /dev/null
@@ -1,183 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemv_rewriter.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/layout.h"
-#include "xla/layout_util.h"
-#include "xla/shape.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-// Construct a new layout by adding a new minor-most dimension to the input
-// layout. For example, {3, 2, 1, 0} is extended to {4, 3, 2, 1, 0}.
-// We expect that the input layout is normalized by LayoutNormalizer, so that
-// the input layout has a descending ordering.
-absl::StatusOr<Layout> GetLayoutWithNewMinorMostDimension(
-    const Layout& layout) {
-  // Check that the layout is normalized.
-  if (!LayoutUtil::IsMonotonicWithDim0Major(layout)) {
-    return absl::InvalidArgumentError("Layout is not normalized.");
-  }
-  return LayoutUtil::MakeDescendingLayout(layout.minor_to_major_size() + 1);
-}
-
-class GemvRewriterVisitor : public DfsHloRewriteVisitor {
- public:
-  absl::Status HandleDot(HloInstruction* instr) override {
-    HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
-    const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers();
-    HloInstruction* lhs = dot->mutable_operand(0);
-    HloInstruction* rhs = dot->mutable_operand(1);
-
-    // This pass relies on the dot decomposer, which ensures that all non-batch
-    // dimensions are merged into one.
-    bool lhs_has_non_contracting_dim =
-        lhs->shape().rank() ==
-        dim_numbers.lhs_batch_dimensions_size() +
-            dim_numbers.lhs_contracting_dimensions_size() + 1;
-    bool rhs_has_non_contracting_dim =
-        rhs->shape().rank() ==
-        dim_numbers.rhs_batch_dimensions_size() +
-            dim_numbers.rhs_contracting_dimensions_size() + 1;
-
-    // Skip matrix-matrix multiplication.
-    if (lhs_has_non_contracting_dim && rhs_has_non_contracting_dim) {
-      return absl::OkStatus();
-    }
-
-    // Skip vector-vector multiplication.
-    if (!lhs_has_non_contracting_dim && !rhs_has_non_contracting_dim) {
-      return absl::OkStatus();
-    }
-
-    if (dot->shape().is_dynamic()) {
-      return absl::OkStatus();
-    }
-
-    changed_ = true;
-
-    HloComputation* computation = dot->parent();
-    HloInstruction* new_lhs = lhs;
-    if (!lhs_has_non_contracting_dim) {
-      const Shape& lhs_shape = lhs->shape();
-      absl::Span<const int64_t> lhs_dimensions = lhs_shape.dimensions();
-      std::vector<int64_t> new_lhs_dimensions(lhs_dimensions.begin(),
-                                              lhs_dimensions.end());
-      new_lhs_dimensions.push_back(1);
-      Shape new_lhs_shape(
-          lhs_shape.element_type(), new_lhs_dimensions,
-          absl::InlinedVector<bool, 4>(new_lhs_dimensions.size(), false),
-          /*tuple_shapes=*/{});
-      TF_ASSIGN_OR_RETURN(
-          *new_lhs_shape.mutable_layout(),
-          GetLayoutWithNewMinorMostDimension(lhs_shape.layout()));
-      new_lhs = computation->AddInstruction(
-          HloInstruction::CreateBitcast(new_lhs_shape, lhs));
-    }
-
-    HloInstruction* new_rhs = rhs;
-    if (!rhs_has_non_contracting_dim) {
-      const Shape& rhs_shape = rhs->shape();
-      absl::Span<const int64_t> rhs_dimensions = rhs_shape.dimensions();
-      std::vector<int64_t> new_rhs_dimensions(rhs_dimensions.begin(),
-                                              rhs_dimensions.end());
-      new_rhs_dimensions.push_back(1);
-      Shape new_rhs_shape(
-          rhs_shape.element_type(), new_rhs_dimensions,
-          absl::InlinedVector<bool, 4>(new_rhs_dimensions.size(), false),
-          /*tuple_shapes=*/{});
-      TF_ASSIGN_OR_RETURN(
-          *new_rhs_shape.mutable_layout(),
-          GetLayoutWithNewMinorMostDimension(rhs_shape.layout()));
-      new_rhs = computation->AddInstruction(
-          HloInstruction::CreateBitcast(new_rhs_shape, rhs));
-    }
-
-    std::vector<int64_t> new_out_dimensions;
-    new_out_dimensions.reserve(dot->shape().dimensions().size() + 1);
-    for (int64_t dim_size : dot->shape().dimensions()) {
-      new_out_dimensions.push_back(dim_size);
-    }
-    if (!lhs_has_non_contracting_dim) {
-      // Insert the trivial dimension before the non-contracting dimension from
-      // rhs.
-      int non_contracting_dim_size = new_out_dimensions.back();
-      new_out_dimensions[new_out_dimensions.size() - 1] = 1;
-      new_out_dimensions.push_back(non_contracting_dim_size);
-    } else {
-      new_out_dimensions.push_back(1);
-    }
-
-    Shape new_out_shape(
-        dot->shape().element_type(), new_out_dimensions,
-        absl::InlinedVector<bool, 4>(new_out_dimensions.size(), false),
-        /*tuple_shapes=*/{});
-    TF_ASSIGN_OR_RETURN(
-        *new_out_shape.mutable_layout(),
-        GetLayoutWithNewMinorMostDimension(dot->shape().layout()));
-
-    HloInstruction* new_dot =
-        computation->AddInstruction(HloInstruction::CreateDot(
-            new_out_shape, new_lhs, new_rhs, dot->dot_dimension_numbers(),
-            dot->precision_config()));
-    HloInstruction* bitcast = computation->AddInstruction(
-        HloInstruction::CreateBitcast(dot->shape(), new_dot));
-    return computation->ReplaceInstruction(dot, bitcast);
-  }
-
-  bool changed() const { return changed_; }
-
- private:
-  bool changed_ = false;
-};
-
-}  // namespace
-
-absl::StatusOr<bool> GemvRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  GemvRewriterVisitor gemv_rewriter;
-  for (HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    TF_RETURN_IF_ERROR(computation->Accept(&gemv_rewriter));
-  }
-  return gemv_rewriter.changed();
-}
-
-}  // namespace gpu
-}  // namespace xla
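As a small aside on the layout helper above: it only ever produces descending ("normalized") layouts, so its contract can be sketched with LayoutUtil directly. The helper itself sits in an anonymous namespace, so this sketch only mirrors its behavior; the ranks used are arbitrary.

#include <cstdint>

#include "xla/layout.h"
#include "xla/layout_util.h"

namespace xla::gpu {

// Mirrors GetLayoutWithNewMinorMostDimension: a normalized rank-3 layout
// {2,1,0} maps to the rank-4 descending layout {3,2,1,0}; a non-monotonic
// layout such as {0,1} would instead be rejected with
// "Layout is not normalized.".
inline Layout DescendingLayoutWithExtraMinorDim(int64_t rank) {
  return LayoutUtil::MakeDescendingLayout(rank + 1);
}

}  // namespace xla::gpu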
diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter.h b/third_party/xla/xla/service/gpu/gemv_rewriter.h
deleted file mode 100644
index a041138..0000000
--- a/third_party/xla/xla/service/gpu/gemv_rewriter.h
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_GEMV_REWRITER_H_
-#define XLA_SERVICE_GPU_GEMV_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrite a matrix-vector or a vector-matrix multiplication into a
-// matrix-matrix multiplication with a trivial dimension. For example,
-// [m x n] @ [n] is rewritten to [m x n] @ [n x 1], and [n] @ [m x n] is
-// rewritten to [n x 1] @ [m x n].
-class GemvRewriter : public HloModulePass {
- public:
-  absl::string_view name() const override { return "gemv-rewriter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GEMV_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc b/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc
deleted file mode 100644
index 2a8b810..0000000
--- a/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc
+++ /dev/null
@@ -1,149 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gemv_rewriter.h"
-
-#include <memory>
-#include <optional>
-
-#include <gtest/gtest.h>
-#include "absl/status/statusor.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-class GemvRewriterTest : public HloTestBase {};
-
-TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationToGemm) {
-  const char* hlo = R"(
-  HloModule m
-
-  ENTRY e {
-    p0 = f32[32,7] parameter(0)
-    p1 = f32[7] parameter(1)
-    ROOT d = f32[32] dot(p0, p1),
-      lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  })";
-
-  const char* expected = R"()
-// CHECK:  %[[P0:.*]] = f32[32,7]{1,0} parameter(0)
-// CHECK:  %[[P1:.*]] = f32[7]{0} parameter(1)
-// CHECK:  %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P1]])
-// CHECK:  %[[DOT:.*]] = f32[32,1]{1,0} dot(%[[P0]], %[[BITCAST]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-// CHECK:  ROOT %[[ROOT:.*]] = f32[32]{0} bitcast(%[[DOT]])
-})";
-
-  RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected);
-}
-
-TEST_F(GemvRewriterTest, RewriteVectorMatrixMultiplicationToGemm) {
-  const char* hlo = R"(
-  HloModule m
-
-  ENTRY e {
-    p0 = f32[7] parameter(0)
-    p1 = f32[7,32] parameter(1)
-    ROOT d = f32[32] dot(p0, p1),
-      lhs_contracting_dims={0}, rhs_contracting_dims={0}
-  })";
-
-  const char* expected = R"()
-// CHECK:  %[[P0:.*]] = f32[7]{0} parameter(0)
-// CHECK:  %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P0]])
-// CHECK:  %[[P1:.*]] = f32[7,32]{1,0} parameter(1)
-// CHECK:  %[[DOT:.*]] = f32[1,32]{1,0} dot(%[[BITCAST]], %[[P1]]), lhs_contracting_dims={0}, rhs_contracting_dims={0}
-// CHECK:  ROOT %[[ROOT:.*]].1 = f32[32]{0} bitcast(%[[DOT]])
-})";
-
-  RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected);
-}
-
-TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationWithBatch) {
-  const char* hlo = R"(
-  HloModule m
-
-  ENTRY e {
-    p0 = f32[2,5,32,7] parameter(0)
-    p1 = f32[2,5,7] parameter(1)
-    ROOT d = f32[2,5,32] dot(p0, p1),
-      lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
-      lhs_contracting_dims={3}, rhs_contracting_dims={2}
-  })";
-
-  const char* expected = R"()
-// CHECK:  %[[P0:.*]] = f32[2,5,32,7]{3,2,1,0} parameter(0)
-// CHECK:  %[[P1:.*]] = f32[2,5,7]{2,1,0} parameter(1)
-// CHECK:  %[[BITCAST:.*]] = f32[2,5,7,1]{3,2,1,0} bitcast(%[[P1]])
-// CHECK:  %[[DOT:.*]] = f32[2,5,32,1]{3,2,1,0} dot(%[[P0]], %[[BITCAST]]),
-// CHECK-SAME: lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
-// CHECK:  ROOT %[[ROOT:.*]] = f32[2,5,32]{2,1,0} bitcast(%[[DOT]])
-})";
-
-  RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected);
-}
-
-TEST_F(GemvRewriterTest, DotNotRewriteVectorVectorMultiplication) {
-  const char* hlo = R"(
-  HloModule m
-
-  ENTRY e {
-    p0 = f32[7] parameter(0)
-    p1 = f32[7] parameter(1)
-    ROOT d = f32[] dot(p0, p1),
-      lhs_contracting_dims={0}, rhs_contracting_dims={0}
-  })";
-
-  RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt);
-}
-
-TEST_F(GemvRewriterTest, DotNotRewriteMatrixMatrixMultiplication) {
-  const char* hlo = R"(
-  HloModule m
-
-  ENTRY e {
-    p0 = f32[5,7] parameter(0)
-    p1 = f32[7,32] parameter(1)
-    ROOT d = f32[5,32] dot(p0, p1),
-      lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  })";
-
-  RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt);
-}
-
-TEST_F(GemvRewriterTest, DoNotRewriteDotsWithNonNormalizedLayout) {
-  const char* hlo = R"(
-  HloModule m
-
-  ENTRY e {
-    p0 = f32[5,32,7]{2,1,0} parameter(0)
-    p1 = f32[5,7]{0,1} parameter(1)
-    ROOT d = f32[5,32]{0,1} dot(p0, p1),
-      lhs_batch_dims={0}, rhs_batch_dims={0},
-      lhs_contracting_dims={2}, rhs_contracting_dims={1}
-  })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo));
-  GemvRewriter rewriter;
-  absl::StatusOr<bool> result = this->RunHloPass(&rewriter, module.get());
-  EXPECT_FALSE(result.ok());
-  EXPECT_EQ(result.status().message(), "Layout is not normalized.");
-}
-
-}  // namespace
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc b/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc
deleted file mode 100644
index 21e8d6c..0000000
--- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_algebraic_simplifier.h"
-
-#include "absl/log/check.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/fusions/triton/triton_support.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla::gpu {
-
-bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce(
-    const HloInstruction* hlo) {
-  if (!options_.enable_dot_strength_reduction()) {
-    return false;
-  }
-
-  const HloDotInstruction* dot = DynCast<HloDotInstruction>(hlo);
-  if (dot == nullptr) {
-    return false;
-  }
-
-  const HloInstruction* lhs = dot->operand(0);
-  const HloInstruction* rhs = dot->operand(1);
-  DotDimensionNumbers dnums = dot->dot_dimension_numbers();
-  bool lhs_is_vector = (dnums.lhs_batch_dimensions_size() +
-                            dnums.lhs_contracting_dimensions_size() ==
-                        lhs->shape().rank());
-  bool rhs_is_vector = (dnums.rhs_batch_dimensions_size() +
-                            dnums.rhs_contracting_dimensions_size() ==
-                        rhs->shape().rank());
-  // Strength-reduce vector-vector dots since they are not supported by
-  // GemmFusion.
-  if (lhs_is_vector && rhs_is_vector) {
-    return true;
-  }
-
-  absl::StatusOr<bool> is_too_small =
-      IsMatrixMultiplicationTooSmallForRewriting(*hlo, /*threshold=*/10000000);
-  CHECK_OK(is_too_small.status());
-  if (is_too_small.value()) {
-    return true;
-  }
-
-  // If GemmFusion cannot handle this dot, we should strength-reduce it so that
-  // it can be handled by the fusion pipeline.
-  return !legacy_triton::CanTritonHandleGEMM(*dot, compute_capability_);
-}
-
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h b/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h
deleted file mode 100644
index 8553596..0000000
--- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_
-#define XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_
-
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/algebraic_simplifier.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-
-namespace xla::gpu {
-
-class GpuAlgebraicSimplifierVisitor : public AlgebraicSimplifierVisitor {
- public:
-  explicit GpuAlgebraicSimplifierVisitor(
-      const AlgebraicSimplifierOptions& options,
-      se::GpuComputeCapability compute_capability,
-      AlgebraicSimplifier* simplifier)
-      : AlgebraicSimplifierVisitor(options, simplifier),
-        compute_capability_(std::move(compute_capability)) {}
-
-  bool ShouldStrengthReduceDotToReduce(const HloInstruction* hlo) override;
-
- private:
-  se::GpuComputeCapability compute_capability_;
-};
-
-class GpuAlgebraicSimplifier : public AlgebraicSimplifier {
- public:
-  explicit GpuAlgebraicSimplifier(const AlgebraicSimplifierOptions& options,
-                                  se::GpuComputeCapability compute_capability)
-      : AlgebraicSimplifier(options),
-        compute_capability_(std::move(compute_capability)) {}
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(HloModule* module,
-                           const absl::flat_hash_set<absl::string_view>&
-                               execution_threads) override {
-    XLA_VLOG_LINES(
-        2, "GpuAlgebraicSimplifier::Run(), before:\n" + module->ToString());
-    bool changed = false;
-    GpuAlgebraicSimplifierVisitor visitor(options_, compute_capability_, this);
-    for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
-      if (visitor.Run(comp, options_, this)) {
-        changed = true;
-      }
-    }
-    XLA_VLOG_LINES(
-        2, "GpuAlgebraicSimplifier::Run(), after:\n" + module->ToString());
-    return changed;
-  }
-
- private:
-  se::GpuComputeCapability compute_capability_;
-};
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_
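A hypothetical wiring sketch showing how the GPU-specific simplifier above would be dropped into a pass pipeline in place of the base AlgebraicSimplifier; the function name, pipeline parameter, and option choices are illustrative, not taken from this change.

#include <utility>

#include "xla/service/algebraic_simplifier.h"
#include "xla/service/gpu/gpu_algebraic_simplifier.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {

void AddGpuAlgebraicSimplifierSketch(
    HloPassPipeline& pipeline, se::GpuComputeCapability compute_capability) {
  AlgebraicSimplifierOptions options;
  options.set_enable_dot_strength_reduction(true);
  // The GPU subclass only strength-reduces dots that GemmFusion cannot (or
  // should not) handle, per ShouldStrengthReduceDotToReduce above.
  pipeline.AddPass<GpuAlgebraicSimplifier>(options,
                                           std::move(compute_capability));
}

}  // namespace xla::gpu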
diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc b/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc
deleted file mode 100644
index 135ddb12..0000000
--- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc
+++ /dev/null
@@ -1,141 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_algebraic_simplifier.h"
-
-#include <string>
-
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/algebraic_simplifier.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-class GpuAlgebraicSimplifierTest : public HloTestBase {};
-
-TEST_F(GpuAlgebraicSimplifierTest, VectorVectorDotShouldBeStrengthReduced) {
-  const std::string& hlo_string = R"(
-HloModule m
-
-ENTRY entry {
-  p0 = f32[32, 500] parameter(0)
-  p1 = f32[32, 500] parameter(1)
-  ROOT dot = f32[32] dot(p0, p1), lhs_batch_dims={0},
-    lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  const HloInstruction* dot = module->entry_computation()->root_instruction();
-  AlgebraicSimplifierOptions options;
-  options.set_enable_dot_strength_reduction(true);
-  se::CudaComputeCapability ampere(8, 0);
-  GpuAlgebraicSimplifier simplifier(options, ampere);
-  GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
-  EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
-}
-
-TEST_F(GpuAlgebraicSimplifierTest, MatrixVectorDotShouldNotBeStrengthReduced) {
-  const std::string& hlo_string = R"(
-HloModule m
-
-ENTRY entry {
-  p0 = f32[32, 5000, 7000] parameter(0)
-  p1 = f32[32, 5000] parameter(1)
-  ROOT dot = f32[32,7000] dot(p0, p1), lhs_batch_dims={0},
-    lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1},
-    algorithm=dot_bf16_bf16_f32_x6
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  const HloInstruction* dot = module->entry_computation()->root_instruction();
-  AlgebraicSimplifierOptions options;
-  options.set_enable_dot_strength_reduction(true);
-  se::CudaComputeCapability ampere(8, 0);
-  GpuAlgebraicSimplifier simplifier(options, ampere);
-  GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
-  EXPECT_FALSE(visitor.ShouldStrengthReduceDotToReduce(dot));
-}
-
-TEST_F(GpuAlgebraicSimplifierTest,
-       DotWithTypeUnsupportedByGemmFusionShouldBeStrengthReduced) {
-  const std::string& hlo_string = R"(
-HloModule m
-
-ENTRY entry {
-  p0 = c64[32, 5000, 7000] parameter(0)
-  p1 = c64[32, 5000] parameter(1)
-  ROOT dot = c64[32,7000] dot(p0, p1), lhs_batch_dims={0},
-    lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  const HloInstruction* dot = module->entry_computation()->root_instruction();
-  AlgebraicSimplifierOptions options;
-  options.set_enable_dot_strength_reduction(true);
-  se::CudaComputeCapability ampere(8, 0);
-  GpuAlgebraicSimplifier simplifier(options, ampere);
-  GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
-  EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
-}
-
-TEST_F(GpuAlgebraicSimplifierTest, SmallDotShouldBeStrengthReduced) {
-  const std::string& hlo_string = R"(
-HloModule m
-
-ENTRY entry {
-  p0 = f32[32, 50, 70] parameter(0)
-  p1 = f32[32, 50] parameter(1)
-  ROOT dot = f32[32,70] dot(p0, p1), lhs_batch_dims={0},
-    lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1},
-    algorithm=dot_bf16_bf16_f32_x6
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  const HloInstruction* dot = module->entry_computation()->root_instruction();
-  AlgebraicSimplifierOptions options;
-  options.set_enable_dot_strength_reduction(true);
-  se::CudaComputeCapability ampere(8, 0);
-  GpuAlgebraicSimplifier simplifier(options, ampere);
-  GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
-  EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
-}
-
-TEST_F(GpuAlgebraicSimplifierTest, SmallDotShouldBeStrengthReduced2) {
-  const std::string& hlo_string = R"(
-HloModule m
-
-ENTRY entry {
-  p0 = f32[2000, 3000] parameter(0)
-  p1 = f32[2000] parameter(1)
-  ROOT dot = f32[3000] dot(p0, p1), lhs_contracting_dims={0},
-    rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_f32_x6
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  const HloInstruction* dot = module->entry_computation()->root_instruction();
-  AlgebraicSimplifierOptions options;
-  options.set_enable_dot_strength_reduction(true);
-  se::CudaComputeCapability ampere(8, 0);
-  GpuAlgebraicSimplifier simplifier(options, ampere);
-  GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
-  EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
-}
-
-}  // namespace
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc b/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc
deleted file mode 100644
index fe2d2d1..0000000
--- a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc
+++ /dev/null
@@ -1,109 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_all_gather_optimizer.h"
-
-#include <cstdint>
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/shape_util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> AllGatherOptimizer::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    for (HloInstruction* instruction :
-         computation->MakeInstructionPostOrder()) {
-      if (!HloOpcodeIsBinaryCommutative(instruction->opcode())) {
-        continue;
-      }
-
-      HloInstruction* left_op = instruction->mutable_operand(0);
-      HloInstruction* right_op = instruction->mutable_operand(1);
-
-      if (right_op->opcode() != HloOpcode::kAllGather ||
-          left_op->opcode() != HloOpcode::kAllGather) {
-        VLOG(2) << "Binary op's operands are not all-gather deduced types.";
-        continue;
-      }
-
-      auto* left_all_gather = Cast<HloAllGatherInstruction>(left_op);
-      auto* right_all_gather = Cast<HloAllGatherInstruction>(right_op);
-
-      if (right_all_gather->constrain_layout() !=
-              left_all_gather->constrain_layout() ||
-          right_all_gather->use_global_device_ids() !=
-              left_all_gather->use_global_device_ids() ||
-          !ReplicaGroupsEqual(right_all_gather->replica_groups(),
-                              left_all_gather->replica_groups())) {
-        VLOG(2) << "The right and left all-gather ops are not compatible "
-                   "to merge. ";
-        continue;
-      }
-
-      if (!ShapeUtil::Equal(left_all_gather->operand(0)->shape(),
-                            right_all_gather->operand(0)->shape())) {
-        VLOG(2) << "all-gather operands have different shapes";
-        continue;
-      }
-
-      if (right_all_gather->user_count() != 1 ||
-          left_all_gather->user_count() != 1) {
-        VLOG(2) << "all-gather user_count > 1 ";
-        continue;
-      }
-      auto index_in_full_shape =
-          computation->AddInstruction(HloInstruction::CreateBinary(
-              right_all_gather->operand(0)->shape(), instruction->opcode(),
-              left_all_gather->mutable_operand(0),
-              right_all_gather->mutable_operand(0)));
-
-      int64_t all_gather_dimension =
-          Cast<HloAllGatherInstruction>(right_all_gather)
-              ->all_gather_dimension();
-
-      auto combined = HloInstruction::CreateAllGather(
-          left_all_gather->shape(), {index_in_full_shape}, all_gather_dimension,
-          left_all_gather->device_list(),
-          /*constrain_layout=*/false, left_all_gather->channel_id(),
-          Cast<HloAllGatherInstruction>(left_all_gather)
-              ->use_global_device_ids());
-
-      TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
-          instruction, std::move(combined)));
-      changed = true;
-    }
-  }
-
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.h b/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.h
deleted file mode 100644
index e28e422..0000000
--- a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_
-#define XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Transforms binary_op(all-gather(reduce_scatter(a)),
-// all-gather(reduce_scatter(b))) to allgather(binary_op(reduce_scatter(a),
-// reduce_scatter(b)))
-
-class AllGatherOptimizer : public HloModulePass {
- public:
-  AllGatherOptimizer() = default;
-  absl::string_view name() const override { return "all-gather-optimizer"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_
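The pass deleted above likewise reappears under xla/service/gpu/transforms/all_gather_optimizer.h (see the gpu_compiler.cc includes later in this patch). As its header comment states, it rewrites binary_op(all-gather(x), all-gather(y)) into all-gather(binary_op(x, y)) when the two all-gathers are compatible and single-use, so the binary op runs on the per-replica shards (for example, the outputs of matching reduce-scatters). A minimal HLO sketch with two replicas; shapes and names are illustrative only.

Before (each operand is gathered to the full shape, then combined):

    a   = f32[8] parameter(0)
    b   = f32[8] parameter(1)
    ag0 = f32[16] all-gather(a), replica_groups={}, dimensions={0}
    ag1 = f32[16] all-gather(b), replica_groups={}, dimensions={0}
    ROOT add = f32[16] add(ag0, ag1)

After (the add runs on the 8-element shards and only one all-gather remains):

    sum = f32[8] add(a, b)
    ROOT ag = f32[16] all-gather(sum), replica_groups={}, dimensions={0}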
diff --git a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.cc b/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.cc
deleted file mode 100644
index c2f6c04..0000000
--- a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.cc
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_async_collective_annotator.h"
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> GpuAsyncCollectiveAnnotator::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    for (HloInstruction* instruction : computation->instructions()) {
-      if (!hlo_query::IsAsyncCollectiveStartOp(instruction)) {
-        continue;
-      }
-      CollectiveBackendConfig config;
-      config.set_is_sync(!is_collective_async_(instruction));
-      TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                          instruction->backend_config<GpuBackendConfig>());
-      *gpu_config.mutable_collective_backend_config() = config;
-      TF_RETURN_IF_ERROR(instruction->set_backend_config(gpu_config));
-      changed = true;
-    }
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.h b/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.h
deleted file mode 100644
index 4000fbc..0000000
--- a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_ASYNC_COLLECTIVE_ANNOTATOR_H_
-#define XLA_SERVICE_GPU_GPU_ASYNC_COLLECTIVE_ANNOTATOR_H_
-
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/util.h"
-
-namespace xla {
-namespace gpu {
-
-// Annotate async collectives with CollectiveBackendConfig.
-class GpuAsyncCollectiveAnnotator : public HloModulePass {
- public:
-  explicit GpuAsyncCollectiveAnnotator(HloPredicate is_collective_async)
-      : is_collective_async_(std::move(is_collective_async)) {}
-  absl::string_view name() const override {
-    return "gpu-async-collective-annotator";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  HloPredicate is_collective_async_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GPU_ASYNC_COLLECTIVE_ANNOTATOR_H_
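As with the other passes in this patch, the annotator moves to xla/service/gpu/transforms/async_collective_annotator.h and is registered under the shorter name AsyncCollectiveAnnotator in the gpu_compiler.cc hunk below. A minimal usage sketch, assuming an existing HloPassPipeline named `pipeline`: the predicate selects which collective start ops stay asynchronous, and every other collective start is annotated with collective_backend_config.is_sync = true (arbitrary predicates, such as the element-type lambda in the deleted test below, work the same way).

    // Hypothetical registration: keep only all-reduce asynchronous and
    // annotate the remaining collective starts as synchronous.
    pipeline.AddPass<AsyncCollectiveAnnotator>(
        HloPredicateIsOp<HloOpcode::kAllReduceStart>);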
diff --git a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator_test.cc b/third_party/xla/xla/service/gpu/gpu_async_collective_annotator_test.cc
deleted file mode 100644
index f874a7e..0000000
--- a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator_test.cc
+++ /dev/null
@@ -1,183 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_async_collective_annotator.h"
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_macros.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-constexpr absl::string_view kHloString = R"(
-  HloModule ModuleWithAsync
-
-  addf32 {
-    p0 = f32[] parameter(0)
-    p1 = f32[] parameter(1)
-    ROOT add = f32[] add(p0, p1)
-  }
-
-  addf16 {
-    p0 = f16[] parameter(0)
-    p1 = f16[] parameter(1)
-    ROOT add = f16[] add(p0, p1)
-  }
-
-  reduce_scatterf32 {
-    p0 = f32[2] parameter(0)
-    ROOT result = f32[1] reduce-scatter(p0), replica_groups={},
-                      dimensions={0}, to_apply=addf32
-  }
-
-  ENTRY entry {
-    pf32 = f32[1] parameter(0)
-    pf16 = f16[1] parameter(1)
-
-    arf32-start = f32[1] all-reduce-start(pf32), to_apply=addf32
-    arf32-done = f32[1] all-reduce-done(arf32-start)
-
-    arf16-start = f16[1] all-reduce-start(pf16), to_apply=addf16
-    arf16-done = f16[1] all-reduce-done(arf16-start)
-
-    agf32-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}
-    agf32-done = f32[2] all-gather-done(agf32-start)
-
-    agf16-start = (f16[1], f16[2]) all-gather-start(pf16), dimensions={0}
-    agf16-done = f16[2] all-gather-done(agf16-start)
-
-    cpf32-start = (f32[1], f32[1], u32[], u32[]) collective-permute-start(pf32),
-                    source_target_pairs={{0,1}, {1,0}}
-    cpf32-done = f32[1] collective-permute-done(cpf32-start)
-
-    cpf16-start = (f16[1], f16[1], u32[], u32[]) collective-permute-start(pf16),
-                    source_target_pairs={{0,1}, {1,0}}
-    cpf16-done = f16[1] collective-permute-done(cpf16-start)
-
-    rsf32-start = ((f32[2]), f32[1]) async-start(agf32-done), calls=reduce_scatterf32
-    rsf32-done = f32[1] async-done(rsf32-start), calls=reduce_scatterf32
-
-    ROOT tuple = (f32[1], f16[1], f32[2], f16[2], f32[1], f16[1], f32[1])
-                tuple(arf32-done, arf16-done, agf32-done, agf16-done, cpf32-done,
-                      cpf16-done, rsf32-done)
-  }
-)";
-
-struct TestCase {
-  std::string test_name;
-  HloPredicate is_async_predicate;
-  absl::flat_hash_set<absl::string_view> expected_async;
-  absl::flat_hash_set<absl::string_view> expected_sync;
-};
-
-class GpuAsyncCollectiveAnnotatorTest
-    : public HloTestBase,
-      public ::testing::WithParamInterface<TestCase> {};
-
-XLA_TEST_P(GpuAsyncCollectiveAnnotatorTest, Test) {
-  const TestCase& test_case = GetParam();
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<HloModule> module,
-      ParseAndReturnVerifiedModule(kHloString, /*replica_count=*/2));
-  TF_ASSERT_OK_AND_ASSIGN(
-      bool changed, GpuAsyncCollectiveAnnotator(test_case.is_async_predicate)
-                        .Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  // Assert that all async collectives are annotated with the backend config.
-  for (const HloInstruction* hlo :
-       module->entry_computation()->instructions()) {
-    if (!hlo_query::IsAsyncCollectiveStartOp(hlo)) {
-      continue;
-    }
-    auto gpu_config = hlo->backend_config<GpuBackendConfig>();
-    ASSERT_TRUE(gpu_config.ok());
-
-    const CollectiveBackendConfig& backend_config =
-        gpu_config.value().collective_backend_config();
-    if (test_case.expected_async.contains(hlo->name())) {
-      EXPECT_FALSE(backend_config.is_sync());
-    }
-
-    if (test_case.expected_sync.contains(hlo->name())) {
-      EXPECT_TRUE(backend_config.is_sync());
-    }
-  }
-}
-
-std::vector<TestCase> TestCases() {
-  HloPredicate is_f16 = [](const HloInstruction* hlo) {
-    return hlo->operand(0)->shape().element_type() == PrimitiveType::F16;
-  };
-
-  return {
-      {"all_async",
-       HloPredicateTrue, /*expected_async=*/
-       {"arf32-start", "arf16-start", "agf32-start", "agf16-start",
-        "cpf32-start", "cpf16-start", "rsf32-start"},
-       /*expected_sync=*/{}},
-      {"all_sync",
-       HloPredicateFalse,
-       /*expected_async=*/{},
-       /*expected_sync=*/
-       {"arf32-start", "arf16-start", "agf32-start", "agf16-start",
-        "cpf32-start", "cpf16-start", "rsf32-start"}},
-      {"ar_async",
-       HloPredicateIsOp<HloOpcode::kAllReduceStart>,
-       /*expected_async=*/
-       {"arf32-start", "arf16-start"},
-       /*expected_sync=*/
-       {"agf32-start", "agf16-start", "cpf32-start", "cpf16-start",
-        "rsf32-start"}},
-      {"cp_async",
-       HloPredicateIsOp<HloOpcode::kCollectivePermuteStart>,
-       /*expected_async=*/
-       {"cpf32-start", "cpf16-start"},
-       /*expected_sync=*/
-       {"arf32-start", "arf16-start", "agf32-start", "agf16-start",
-        "rsf32-start"}},
-      {"f16_async",
-       is_f16,
-       /*expected_async=*/{"arf16-start", "agf16-start", "cpf16-start"},
-       /*expected_sync=*/
-       {"arf32-start", "agf32-start", "cpf32-start", "rsf32-start"}},
-  };
-}
-
-std::string TestCaseName(const ::testing::TestParamInfo<TestCase>& test_case) {
-  return test_case.param.test_name;
-}
-
-INSTANTIATE_TEST_SUITE_P(GpuAsyncCollectiveAnnotatorTest,
-                         GpuAsyncCollectiveAnnotatorTest,
-                         ::testing::ValuesIn(TestCases()), TestCaseName);
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc
index 96b537e..b477a88 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -74,7 +74,6 @@
 #include "xla/service/all_reduce_folder.h"
 #include "xla/service/all_reduce_promotion.h"
 #include "xla/service/all_reduce_reassociate.h"
-#include "xla/service/all_reduce_splitter.h"
 #include "xla/service/async_collective_creator.h"
 #include "xla/service/batchnorm_expander.h"
 #include "xla/service/bitcast_dtypes_expander.h"
@@ -110,43 +109,18 @@
 #include "xla/service/float_support.h"
 #include "xla/service/gather_expander.h"
 #include "xla/service/gather_simplifier.h"
-#include "xla/service/gpu/algorithm_checker.h"
-#include "xla/service/gpu/all_reduce_blueconnect.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/collective_permute_cycle_decomposer.h"
-#include "xla/service/gpu/collective_permute_valid_iteration_annotator.h"
-#include "xla/service/gpu/command_buffer_scheduling.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h"
 #include "xla/service/gpu/compile_module_to_llvm_ir.h"
 #include "xla/service/gpu/conv_layout_normalization.h"
-#include "xla/service/gpu/custom_kernel_fusion_autotuner.h"
-#include "xla/service/gpu/custom_kernel_fusion_rewriter.h"
-#include "xla/service/gpu/dot_dimension_sorter.h"
-#include "xla/service/gpu/dot_operand_converter.h"
-#include "xla/service/gpu/double_buffer_loop_unrolling.h"
-#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h"
 #include "xla/service/gpu/execution_stream_assignment.h"
 #include "xla/service/gpu/fusion_pipeline.h"
-#include "xla/service/gpu/fusion_wrapper.h"
-#include "xla/service/gpu/gemm_broadcast_folding_rewriter.h"
-#include "xla/service/gpu/gemm_fusion.h"
-#include "xla/service/gpu/gemm_rewriter.h"
-#include "xla/service/gpu/gemv_rewriter.h"
-#include "xla/service/gpu/gpu_algebraic_simplifier.h"
-#include "xla/service/gpu/gpu_all_gather_optimizer.h"
-#include "xla/service/gpu/gpu_async_collective_annotator.h"
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-#include "xla/service/gpu/gpu_convert_async_collectives_to_sync.h"
 #include "xla/service/gpu/gpu_executable.h"
 #include "xla/service/gpu/gpu_float_support.h"
 #include "xla/service/gpu/gpu_hlo_schedule.h"
 #include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
-#include "xla/service/gpu/gpu_layout_assignment.h"
 #include "xla/service/gpu/gpu_p2p_pipeliner.h"
-#include "xla/service/gpu/gpu_reduce_scatter_creator.h"
-#include "xla/service/gpu/gpu_sanitize_constant_names.h"
-#include "xla/service/gpu/gpu_scatter_expander.h"
 #include "xla/service/gpu/gpu_spmd_pipeline.h"
-#include "xla/service/gpu/gpu_windowed_einsum_handler.h"
 #include "xla/service/gpu/hlo_fusion_stats.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/ir_emitter_context.h"
@@ -156,26 +130,53 @@
 #include "xla/service/gpu/metrics.h"
 #include "xla/service/gpu/model/gpu_cost_model_stats_collection.h"
 #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/gpu/move_copy_to_users.h"
-#include "xla/service/gpu/pipelined_p2p_rewriter.h"
 #include "xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h"
-#include "xla/service/gpu/reduction_degenerate_dim_remover.h"
-#include "xla/service/gpu/reduction_dimension_grouper.h"
-#include "xla/service/gpu/reduction_layout_normalizer.h"
-#include "xla/service/gpu/reduction_splitter.h"
 #include "xla/service/gpu/reduction_utils.h"
-#include "xla/service/gpu/rename_fusions.h"
 #include "xla/service/gpu/runtime/thunk.h"
 #include "xla/service/gpu/runtime_intrinsics.h"
-#include "xla/service/gpu/scatter_slice_simplifier.h"
-#include "xla/service/gpu/softmax_rewriter_triton.h"
-#include "xla/service/gpu/stream_attribute_annotator.h"
-#include "xla/service/gpu/stream_attribute_async_wrapper.h"
 #include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/gpu/topk_specializer.h"
-#include "xla/service/gpu/topk_splitter.h"
-#include "xla/service/gpu/tree_reduction_rewriter.h"
-#include "xla/service/gpu/triton_fusion_numerics_verifier.h"
+#include "xla/service/gpu/transforms/algebraic_simplifier.h"
+#include "xla/service/gpu/transforms/algorithm_checker.h"
+#include "xla/service/gpu/transforms/all_gather_optimizer.h"
+#include "xla/service/gpu/transforms/all_reduce_blueconnect.h"
+#include "xla/service/gpu/transforms/all_reduce_splitter.h"
+#include "xla/service/gpu/transforms/async_collective_annotator.h"
+#include "xla/service/gpu/transforms/collective_permute_cycle_decomposer.h"
+#include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h"
+#include "xla/service/gpu/transforms/command_buffer_scheduling.h"
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+#include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h"
+#include "xla/service/gpu/transforms/cudnn_custom_call_converter.h"
+#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h"
+#include "xla/service/gpu/transforms/dot_dimension_sorter.h"
+#include "xla/service/gpu/transforms/dot_operand_converter.h"
+#include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h"
+#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
+#include "xla/service/gpu/transforms/fusion_wrapper.h"
+#include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h"
+#include "xla/service/gpu/transforms/gemm_fusion.h"
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+#include "xla/service/gpu/transforms/gemv_rewriter.h"
+#include "xla/service/gpu/transforms/layout_assignment.h"
+#include "xla/service/gpu/transforms/move_copy_to_users.h"
+#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h"
+#include "xla/service/gpu/transforms/reduce_scatter_creator.h"
+#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h"
+#include "xla/service/gpu/transforms/reduction_dimension_grouper.h"
+#include "xla/service/gpu/transforms/reduction_layout_normalizer.h"
+#include "xla/service/gpu/transforms/reduction_splitter.h"
+#include "xla/service/gpu/transforms/rename_fusions.h"
+#include "xla/service/gpu/transforms/sanitize_constant_names.h"
+#include "xla/service/gpu/transforms/scatter_expander.h"
+#include "xla/service/gpu/transforms/scatter_slice_simplifier.h"
+#include "xla/service/gpu/transforms/softmax_rewriter_triton.h"
+#include "xla/service/gpu/transforms/stream_attribute_annotator.h"
+#include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h"
+#include "xla/service/gpu/transforms/topk_specializer.h"
+#include "xla/service/gpu/transforms/topk_splitter.h"
+#include "xla/service/gpu/transforms/tree_reduction_rewriter.h"
+#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h"
+#include "xla/service/gpu/transforms/windowed_einsum_handler.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/service/hlo_computation_deduplicator.h"
 #include "xla/service/hlo_constant_folding.h"
@@ -494,7 +495,7 @@
   AlgebraicSimplifierOptions layout_insensitive_algsimp_opts =
       opts_from_compiler;
   layout_insensitive_algsimp_opts.set_conv_is_lowerable_callback(
-      GpuConvRewriter::ConvIsLowerable);
+      ConvRewriter::ConvIsLowerable);
   layout_insensitive_algsimp_opts.set_enable_dot_strength_reduction(
       hlo_module_config.debug_options()
           .xla_gpu_enable_dot_strength_reduction());
@@ -526,6 +527,7 @@
   HloPassPipeline pre_spmd_pipeline("pre-spmd-partitioner");
   // Run some IR cleanup passes before running the SPMD partitioning
   // passes.
+  pre_spmd_pipeline.AddPass<CuDnnCustomCallConverter>();
   pre_spmd_pipeline.AddPass<ConvertMemoryPlacementToInternalAnnotations>();
   pre_spmd_pipeline.AddPass<CallInliner>();
   pre_spmd_pipeline.AddPass<ZeroSizedHloElimination>();
@@ -627,7 +629,7 @@
   HloPassPipeline pipeline("optimization");
   AddHloVerifier(&pipeline);
   if (debug_options.xla_gpu_multi_streamed_windowed_einsum()) {
-    pipeline.AddPass<GpuWindowedEinsumHandler>();
+    pipeline.AddPass<WindowedEinsumHandler>();
   }
   pipeline.AddPass<TopKSplitter>();
   pipeline.AddPass<TopkSpecializer>();
@@ -1121,7 +1123,7 @@
         return false;
     }
   };
-  pipeline.AddPass<GpuAsyncCollectiveAnnotator>(convert_to_async);
+  pipeline.AddPass<AsyncCollectiveAnnotator>(convert_to_async);
 
   return pipeline.Run(hlo_module).status();
 }
@@ -1376,6 +1378,7 @@
     // heuristic, so we can mix and match various Gemm implementations based
     // on projected (measured) performance.
     if (debug_options.xla_gpu_enable_custom_fusions()) {
+      pipeline.AddPass<SimplifyFPConversions>();
       pipeline.AddPass<CustomKernelFusionRewriter>(
           &gpu_target_config.device_description);
       pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);
@@ -1396,10 +1399,12 @@
       pipeline.AddPass<GemmFusion>(gpu_version);
     }
 
-    pipeline.AddPass<GemmRewriter>(gpu_version, GetToolkitVersion(),
-                                   /*f8_rewrite=*/true);
-    pipeline.AddPass<GemmRewriter>(gpu_version, GetToolkitVersion(),
-                                   /*f8_rewrite=*/false);
+    pipeline.AddPass<GemmRewriter>(
+        gpu_version, GetToolkitVersion(),
+        GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+    pipeline.AddPass<GemmRewriter>(
+        gpu_version, GetToolkitVersion(),
+        GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only});
 
     // Rewrite GEMMs with broadcasted inputs as strided GEMMs.
     pipeline.AddPass<GemmBroadcastFoldingRewriter>();
@@ -1442,7 +1447,7 @@
     bool ignore_small_reduce_dims =
         !debug_options.xla_gpu_enable_priority_fusion();
     pipeline.AddPass<HloPassFix<ReductionSplitter>>(ignore_small_reduce_dims);
-    pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>(gpu_version);
+    pipeline.AddPass<HloPassFix<TreeReductionRewriter>>(gpu_version);
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
 
@@ -1467,10 +1472,12 @@
   pipeline.AddPass<CallInliner>();
   // TODO(tdanyluk): Apply CublasPadForGemms to the cuBLAS GEMMs generated
   // here for possibly better cuBLAS performance.
-  pipeline.AddPass<GemmRewriter>(gpu_version, GetToolkitVersion(),
-                                 /*f8_rewrite=*/true);
-  pipeline.AddPass<GemmRewriter>(gpu_version, GetToolkitVersion(),
-                                 /*f8_rewrite=*/false);
+  pipeline.AddPass<GemmRewriter>(
+      gpu_version, GetToolkitVersion(),
+      GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+  pipeline.AddPass<GemmRewriter>(
+      gpu_version, GetToolkitVersion(),
+      GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only});
   // Rewrite GEMMs with broadcasted inputs as strided GEMMs.
   pipeline.AddPass<GemmBroadcastFoldingRewriter>();
 
@@ -2147,8 +2154,8 @@
   }};
   BinaryMap dnn_compiled_graphs;
   if (stream_exec) {
-    TF_RETURN_IF_ERROR(RunCudnnFusionCompilerPass(module.get(), stream_exec,
-                                                  &dnn_compiled_graphs));
+    TF_RETURN_IF_ERROR(RunCudnnCompilerPasses(module.get(), stream_exec,
+                                              &dnn_compiled_graphs));
   }
 
   const DebugOptions& debug_opts = module->config().debug_options();
@@ -2481,7 +2488,7 @@
     pipeline.AddPass<CommandBufferScheduling>(
         gpu_device_info, toolkit_version,
         driver_version.value_or(toolkit_version));
-    pipeline.AddPass<GpuSanitizeConstantNames>();
+    pipeline.AddPass<SanitizeConstantNames>();
   }
 
   AddHloVerifier(&main_pipeline,
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h
index 27a434f..456e675 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.h
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.h
@@ -31,7 +31,7 @@
 #include "xla/service/buffer_assignment.h"
 #include "xla/service/compiler.h"
 #include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
 #include "xla/service/gpu/buffer_sharing.h"
 #include "xla/service/gpu/compile_module_to_llvm_ir.h"
 #include "xla/service/gpu/executable.pb.h"
@@ -171,10 +171,10 @@
     return absl::OkStatus();
   }
 
-  // Runs cuDNN fusion compiler pass.
-  virtual absl::Status RunCudnnFusionCompilerPass(
-      HloModule* module, se::StreamExecutor* stream_exec,
-      BinaryMap* dnn_compiled_graphs) {
+  // Runs cuDNN fusion and custom call compiler passes.
+  virtual absl::Status RunCudnnCompilerPasses(HloModule* module,
+                                              se::StreamExecutor* stream_exec,
+                                              BinaryMap* dnn_compiled_graphs) {
     return absl::OkStatus();
   }
 
@@ -235,7 +235,8 @@
   absl::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
 
   virtual absl::StatusOr<std::vector<uint8_t>> LinkModules(
-      se::GpuComputeCapability cc, se::StreamExecutor* stream_exec,
+      se::GpuComputeCapability gpu_compute_capability,
+      se::StreamExecutor* stream_exec,
       std::vector<std::vector<uint8_t>> modules,
       const DebugOptions& debug_options) {
     return Unimplemented("LinkModules is not implemented.");
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc
index b74d77c..93057c5 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc
@@ -36,7 +36,7 @@
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/primitive_util.h"
 #include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
 #include "xla/service/gpu/gpu_hlo_schedule.h"
 #include "xla/service/gpu/metrics.h"
 #include "xla/service/hlo_module_config.h"
@@ -391,7 +391,6 @@
 
 TEST_F(GpuCompilerTest,
        GemmFusionIsNoOpWhenGemmFusionAutotunerFallsBackToCublas) {
-  GTEST_SKIP() << "TODO(b/354864068): Test fails in OSS stack on A100-80.";
   auto cc = backend()
                 .default_stream_executor()
                 ->GetDeviceDescription()
@@ -645,12 +644,12 @@
 CHECK:       %[[RESULT_RECV:.*]] = recv(%[[AFTER_ALL]])
 CHECK-SAME:    channel_id=[[CHANNEL_ID]]
 CHECK-SAME:    frontend_attributes={_xla_send_recv_pipeline="0",
-CHECK-SAME{LITERAL}:                _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"},
+CHECK-SAME{LITERAL}:                _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}},
 CHECK-SAME:                         control-predecessors={%[[CUSTOM_CALL]]}
 CHECK:       %[[RESULT_SEND:.*]] = send(%[[SOME_SEND_ARG:.*]], %[[AFTER_ALL]])
 CHECK-SAME:    channel_id=1
 CHECK-SAME:    frontend_attributes={_xla_send_recv_pipeline="0",
-CHECK-SAME{LITERAL}:                _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"},
+CHECK-SAME{LITERAL}:                _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}},
 CHECK-SAME:                         control-predecessors={%[[RESULT_RECV]]}
 CHECK:       ROOT
 // We actually expect both RESULT_RECV and RESULT_SEND to match on this line.
@@ -664,11 +663,11 @@
 CHECK:       %[[ENTRY_RECV:.*]] = recv(%[[ENTRY_AFTER_ALL]])
 CHECK-SAME:    channel_id=[[CHANNEL_ID]]
 CHECK-SAME:    frontend_attributes={_xla_send_recv_pipeline="0",
-CHECK-SAME{LITERAL}:                _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"}
+CHECK-SAME{LITERAL}:                _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}}
 CHECK:       %[[ENTRY_SEND:.*]] = send(%[[SOME_SEND_ARG:.*]], %[[ENTRY_AFTER_ALL]])
 CHECK-SAME:    channel_id=1
 CHECK-SAME:    frontend_attributes={_xla_send_recv_pipeline="0",
-CHECK-SAME{LITERAL}:                _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"},
+CHECK-SAME{LITERAL}:                _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}},
 CHECK-SAME:                         control-predecessors={%[[ENTRY_RECV]]}
 CHECK:       %[[WHILE_INIT:.*]] = tuple
 // Check here that the send argument is likewise passed to the while loop, as
@@ -861,10 +860,10 @@
     })";
 
   const char* kExpected = R"(
-  // CHECK: {{.+}} = send({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs="{{[{]}}{3,0}}",_xla_send_recv_validation="{{[{]}}{3,9}}"}
-  // CHECK: {{.+}} = send({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs="{{[{]}}{0,1},{1,2},{2,3}}",_xla_send_recv_validation="{{[{]}}{0,6},{1,7},{2,8}}"}
-  // CHECK: {{.+}} = recv({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs="{{[{]}}{3,0}}",_xla_send_recv_validation="{{[{]}}{3,9}}"}
-  // CHECK: {{.+}} = recv({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs="{{[{]}}{0,1},{1,2},{2,3}}",_xla_send_recv_validation="{{[{]}}{0,6},{1,7},{2,8}}"}
+  // CHECK: {{.+}} = send({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs={{[{]}}{3,0}},_xla_send_recv_validation={{[{]}}{3,9}}}
+  // CHECK: {{.+}} = send({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs={{[{]}}{0,1},{1,2},{2,3}},_xla_send_recv_validation={{[{]}}{0,6},{1,7},{2,8}}}
+  // CHECK: {{.+}} = recv({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs={{[{]}}{3,0}},_xla_send_recv_validation={{[{]}}{3,9}}}
+  // CHECK: {{.+}} = recv({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs={{[{]}}{0,1},{1,2},{2,3}},_xla_send_recv_validation={{[{]}}{0,6},{1,7},{2,8}}}
   )";
 
   DebugOptions debug_options;
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto b/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
index 3549c95..51caadb 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
+++ b/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto
@@ -13,7 +13,7 @@
 }
 results {
   device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB"
-  hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false,\"damax_output\":false},\"force_earliest_schedule\":false}"
+  hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":1,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[\"0\"],\"lhs_contracting_dimensions\":[\"2\"],\"rhs_batch_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"1048576\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"1048576\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
   result {
     run_time {
       nanos: 1
@@ -37,7 +37,7 @@
 }
 results {
   device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 40 MB"
-  hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false,\"damax_output\":false},\"force_earliest_schedule\":false}"
+  hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":1,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[\"0\"],\"lhs_contracting_dimensions\":[\"2\"],\"rhs_batch_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"1048576\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"1048576\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
   result {
     run_time {
       nanos: 1
@@ -61,7 +61,7 @@
 }
 results {
   device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB"
-  hlo: "(bf16[128,1024,1024]{2,1,0}, s8[33554432]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false,\"damax_output\":false},\"force_earliest_schedule\":false}"
+  hlo: "(bf16[128,1024,1024]{2,1,0}, s8[33554432]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":1,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[\"0\"],\"lhs_contracting_dimensions\":[\"2\"],\"rhs_batch_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"1048576\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"1048576\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
   result {
     gemm {
       algorithm: -1
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.cc b/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.cc
deleted file mode 100644
index 0b55f7d..0000000
--- a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.cc
+++ /dev/null
@@ -1,461 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_conv_padding_legalization.h"
-
-#include <algorithm>
-#include <cstddef>
-#include <cstdint>
-#include <cstdlib>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/shape_inference.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/window_util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
-  CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
-        conv.custom_call_target() ==
-            kCudnnConvBiasActivationForwardCallTarget ||
-        conv.custom_call_target() == kCudnnConvForwardGraphCallTarget);
-  return window_util::HasSymmetricPadding(conv.window()) &&
-         !window_util::HasNegativePadding(conv.window()) &&
-         !window_util::HasDilation(conv.window());
-}
-
-// If the (positive and negative) padding on the input operand of a convolution
-// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and
-// dilation), returns kPad and/or kSlice instructions that explicitly apply the
-// padding; otherwise returns the original input operand. When there is both
-// positive padding (including dilation) and negative padding, we insert both
-// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved
-// into a kPad or kSlice op.
-HloInstruction* MaybePaddedAndSlicedInput(
-    Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums,
-    HloInstruction* input) {
-  HloComputation* computation = input->parent();
-  if (!window_util::HasSymmetricPadding(*conv_window) ||
-      window_util::HasBaseDilation(*conv_window)) {
-    // If padding is uneven or has dilation, we insert a kPad instruction that
-    // applies positive padding and dilation.
-    //
-    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
-    // moving all the padding into an explicit pad op, we should keep as much
-    // padding inside of cudnn as possible, on the assumption that padding
-    // within cudnn is basically free, whereas a kPad's cost increases as the
-    // amount of padding increases.
-    PaddingConfig padding_config =
-        MakeNoPaddingConfig(input->shape().dimensions_size());
-    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
-      int64_t dim = conv_dnums.input_spatial_dimensions(i);
-      if (conv_window->dimensions(i).padding_low() > 0) {
-        padding_config.mutable_dimensions(dim)->set_edge_padding_low(
-            conv_window->dimensions(i).padding_low());
-        conv_window->mutable_dimensions(i)->set_padding_low(0);
-      }
-      if (conv_window->dimensions(i).padding_high() > 0) {
-        padding_config.mutable_dimensions(dim)->set_edge_padding_high(
-            conv_window->dimensions(i).padding_high());
-        conv_window->mutable_dimensions(i)->set_padding_high(0);
-      }
-      if (conv_window->dimensions(i).base_dilation() != 1) {
-        padding_config.mutable_dimensions(dim)->set_interior_padding(
-            conv_window->dimensions(i).base_dilation() - 1);
-        conv_window->mutable_dimensions(i)->set_base_dilation(1);
-      }
-    }
-    PrimitiveType element_type = input->shape().element_type();
-    HloInstruction* padding = computation->AddInstruction(
-        HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
-    input =
-        MakePadHlo(input, padding, padding_config, &input->metadata()).value();
-  }
-
-  if (window_util::HasNegativePadding(*conv_window)) {
-    // If the window has negative padding, insert a kSlice that explicitly
-    // applies negative padding.
-    //
-    // For each dimension, initialize the start index to 0 and the limit index
-    // to the size of that dimension.
-    std::vector<int64_t> start_indices(input->shape().dimensions_size(), 0);
-    std::vector<int64_t> limit_indices(input->shape().dimensions().begin(),
-                                       input->shape().dimensions().end());
-    std::vector<int64_t> strides(input->shape().dimensions_size(), 1);
-    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
-      int64_t dim = conv_dnums.input_spatial_dimensions(i);
-      // If dimension "dim" has negative padding, increase the start index or
-      // decrement the limit index by the amount of negative padding.
-      if (conv_window->dimensions(i).padding_low() < 0) {
-        start_indices[dim] += -conv_window->dimensions(i).padding_low();
-        conv_window->mutable_dimensions(i)->set_padding_low(0);
-      }
-      if (conv_window->dimensions(i).padding_high() < 0) {
-        limit_indices[dim] -= -conv_window->dimensions(i).padding_high();
-        conv_window->mutable_dimensions(i)->set_padding_high(0);
-      }
-    }
-
-    input = MakeSliceHlo(input, start_indices, limit_indices, strides).value();
-  }
-
-  return input;
-}
-
-// If the padding on the kernel operand of a convolution can't be folded into a
-// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
-// explicitly applies the padding; otherwise returns the original kernel
-// operand.
-HloInstruction* MaybePaddedKernel(const Window& conv_window,
-                                  const ConvolutionDimensionNumbers& conv_dnums,
-                                  HloInstruction* kernel) {
-  if (!window_util::HasWindowDilation(conv_window)) {
-    return kernel;
-  }
-
-  // Compute the shape and padding config of the pad to be inserted.
-  PaddingConfig padding_config;
-  padding_config.mutable_dimensions()->Reserve(
-      kernel->shape().dimensions_size());
-  for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
-    padding_config.add_dimensions();
-  }
-  for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
-    int64_t dim = conv_dnums.kernel_spatial_dimensions(i);
-    padding_config.mutable_dimensions(dim)->set_interior_padding(
-        conv_window.dimensions(i).window_dilation() - 1);
-  }
-
-  HloComputation* computation = kernel->parent();
-  PrimitiveType element_type = kernel->shape().element_type();
-  HloInstruction* padding = computation->AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
-  return MakePadHlo(kernel, padding, padding_config, &kernel->metadata())
-      .value();
-}
-}  // namespace
-
-bool GpuConvPaddingLegalization::CanonicalizeForwardConvolution(
-    HloInstruction* conv) {
-  if (IsForwardConvolutionCanonical(*conv)) {
-    return false;
-  }
-
-  // Insert slices and/or pads between the convolution and its input and/or
-  // kernel operand.
-  Window new_conv_window = conv->window();
-  HloInstruction* new_input = MaybePaddedAndSlicedInput(
-      &new_conv_window, conv->convolution_dimension_numbers(),
-      conv->mutable_operand(0));
-  HloInstruction* new_kernel =
-      MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(),
-                        conv->mutable_operand(1));
-
-  // Remove the window dilation from convolution's window field. These paddings
-  // are made explicit with the pads inserted by MaybePaddedKernel().
-  for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
-    WindowDimension* dim = new_conv_window.mutable_dimensions(i);
-
-    // The size of the kernel may have changed so update the Window to match.
-    dim->set_size(new_kernel->shape().dimensions(
-        conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
-    dim->set_window_dilation(1);
-  }
-
-  // The conv CustomCall returns a tuple (conv_result, scratch_buffer).  Extract
-  // out the shape of conv_result.
-  VLOG(1) << "Canonicalizing forward conv";
-  std::vector<HloInstruction*> operands(conv->operands().begin(),
-                                        conv->operands().end());
-  operands[0] = new_input;
-  operands[1] = new_kernel;
-  auto new_conv = conv->parent()->AddInstruction(
-      conv->CloneWithNewOperands(conv->shape(), operands));
-  new_conv->set_window(new_conv_window);
-  VLOG(1) << "Replacing:\n  " << conv->ToString() << "\nwith:\n  "
-          << new_conv->ToString();
-  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
-  return true;
-}
-
-namespace {
-void IncreasePaddingLowBy(int64_t delta, WindowDimension* window_dim) {
-  window_dim->set_padding_low(window_dim->padding_low() + delta);
-}
-
-void IncreasePaddingHighBy(int64_t delta, WindowDimension* window_dim) {
-  window_dim->set_padding_high(window_dim->padding_high() + delta);
-}
-}  // namespace
-
-bool GpuConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
-    HloInstruction* backward_conv) {
-  CHECK_EQ(backward_conv->custom_call_target(),
-           kCudnnConvBackwardFilterCallTarget);
-  if (window_util::HasSymmetricPadding(backward_conv->window())) {
-    return false;
-  }
-
-  // A backward filter convolution with uneven padding can be canonicalized to
-  // one with even padding by padding the activations (input) beforehand. For
-  // example,
-  //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
-  // is equivalent to
-  //   ABCD0 = Pad(ABCD, padding_high=1)
-  //   BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
-  // We choose the lesser of padding_low and padding_high as the new padding.
-  HloInstruction* input = backward_conv->mutable_operand(0);
-  Window new_backward_conv_window = backward_conv->window();
-  // input_padding_config is the config of the kPad to be inserted.
-  PaddingConfig input_padding_config =
-      MakeNoPaddingConfig(input->shape().rank());
-  ConvolutionDimensionNumbers backward_conv_dnums =
-      backward_conv->convolution_dimension_numbers();
-  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
-    int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
-    int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
-    if (padding_low < 0 || padding_high < 0) {
-      // TODO(b/32744257): The following canonicalization wouldn't remove
-      // negative padding in a backward convolution, and would therefore cause
-      // cuDNN convolution (which doesn't support negative padding) to fail.
-      return false;
-    }
-    // Compute the new, even padding for the backward conv operation.
-    int64_t new_conv_padding = std::min(padding_low, padding_high);
-    int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
-    input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
-        padding_low - new_conv_padding);
-    input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
-        padding_high - new_conv_padding);
-
-    // Since we move some padding from the backward convolution to the kPad, we
-    // need to accordingly reduce the padding amount of the backward convolution
-    // and its inner forward convolution.
-    auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
-    new_dim->set_padding_low(new_conv_padding);
-    new_dim->set_padding_high(new_conv_padding);
-  }
-
-  // Create a new backward convolution replacing the old one.
-  HloComputation* computation = backward_conv->parent();
-  HloInstruction* output = backward_conv->mutable_operand(1);
-  HloInstruction* padding =
-      computation->AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::Zero(input->shape().element_type())));
-  HloInstruction* padded_input =
-      MakePadHlo(input, padding, input_padding_config).value();
-
-  // The shape of the backward_conv CustomCall is a tuple (conv_result,
-  // scratch_buffer).  Extract out the shape of conv_result.
-  HloInstruction* new_backward_conv =
-      computation->AddInstruction(backward_conv->CloneWithNewOperands(
-          backward_conv->shape(), {padded_input, output}));
-  new_backward_conv->set_window(new_backward_conv_window);
-
-  VLOG(1) << "Canonicalizing backward filter conv";
-  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
-          << new_backward_conv->ToString();
-
-  TF_CHECK_OK(
-      computation->ReplaceInstruction(backward_conv, new_backward_conv));
-  return true;
-}
-
-bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution(
-    HloInstruction* backward_conv) {
-  if (window_util::HasSymmetricPadding(backward_conv->window())) {
-    return false;
-  }
-
-  Window new_backward_conv_window = backward_conv->window();
-  ConvolutionDimensionNumbers backward_conv_dnums =
-      backward_conv->convolution_dimension_numbers();
-
-  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
-  // Get the shape of conv_result.
-  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
-
-  Shape new_backward_conv_shape = backward_conv_shape;
-  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
-    int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
-    int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
-    if (padding_low < 0 || padding_high < 0) {
-      // TODO(b/32744257): The following canonicalization wouldn't remove
-      // negative padding in a backward convolution, and would therefore cause
-      // cuDNN convolution (which doesn't support negative padding) to fail.
-      return false;
-    }
-    // If the backward convolution has uneven padding on the activations, we
-    // move some padding on the larger end to "internal" padding, so that the
-    // backward convolution produces larger activations which get sliced later.
-    //
-    // For example, suppose we have a non-canonical HLO
-    //   [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
-    // where the amount of low padding is larger, we can canonicalize it to
-    //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
-    //   [A] = Slice([B A])
-    if (padding_low > padding_high) {
-      IncreasePaddingLowBy(padding_high - padding_low,
-                           new_backward_conv_window.mutable_dimensions(i));
-    } else if (padding_low < padding_high) {
-      IncreasePaddingHighBy(padding_low - padding_high,
-                            new_backward_conv_window.mutable_dimensions(i));
-    }
-    // Decreasing the padding by X *increases* the size of our output by X.
-    // Note that we have swapped input spatial dimensions with output spatial
-    // dimensions to be compatible with the cuDNN API, so
-    // input_spatial_dimensions(i) gives the i-th spatial dimension of the
-    // output.
-    int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
-    new_backward_conv_shape.set_dimensions(
-        dim, new_backward_conv_shape.dimensions(dim) +
-                 std::abs(padding_low - padding_high));
-  }
-
-  // Create a new backward convolution replacing the old one.
-  HloComputation* computation = backward_conv->parent();
-  HloInstruction* output = backward_conv->mutable_operand(0);
-  HloInstruction* filter = backward_conv->mutable_operand(1);
-
-  HloInstruction* new_backward_conv_call =
-      computation->AddInstruction(backward_conv->CloneWithNewOperands(
-          ShapeUtil::MakeTupleShape(
-              {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
-          {output, filter}));
-  new_backward_conv_call->set_window(new_backward_conv_window);
-
-  // The CustomCall created above returns a tuple (conv_result, scratch_memory).
-  // Extract out the two elements.
-  HloInstruction* new_backward_conv =
-      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
-          new_backward_conv_shape, new_backward_conv_call, 0));
-  HloInstruction* new_backward_conv_scratch =
-      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
-          new_backward_conv_call->shape().tuple_shapes(1),
-          new_backward_conv_call, 1));
-
-  // Slice the new backward convolution.
-  //
-  // Initialize start_indices and limit_indices as no slicing.
-  std::vector<int64_t> start_indices(
-      new_backward_conv->shape().dimensions_size(), 0LL);
-  std::vector<int64_t> limit_indices(
-      new_backward_conv->shape().dimensions().begin(),
-      new_backward_conv->shape().dimensions().end());
-  std::vector<int64_t> strides(new_backward_conv->shape().dimensions_size(),
-                               1LL);
-  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
-    int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
-    int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
-    // Note that we have swapped input spatial dimensions with output spatial
-    // dimensions to be compatible with the cuDNN API, so
-    // input_spatial_dimensions(i) gives the i-th spatial dimension of the
-    // output.
-    int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
-    if (padding_low > padding_high) {
-      // If the amount of low padding (of the old backward convolution) is
-      // larger, we internally pad the low end of the activations and slice
-      // internal padding out here.
-      start_indices[dim] += padding_low - padding_high;
-    } else if (padding_low < padding_high) {
-      // If the amount of high padding is larger, we slice out the internal
-      // padding on the high end.
-      limit_indices[dim] -= padding_high - padding_low;
-    }
-  }
-
-  // Replace the old backward convolution with the slice.
-  Shape slice_shape =
-      ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
-                                      limit_indices, strides)
-          .value();
-  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
-      << ShapeUtil::HumanString(slice_shape) << " vs "
-      << ShapeUtil::HumanString(backward_conv_shape);
-
-  HloInstruction* slice = computation->AddInstruction(
-      HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
-                                  start_indices, limit_indices, strides));
-  HloInstruction* new_tuple = computation->AddInstruction(
-      HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));
-
-  VLOG(1) << "Canonicalizing backward input conv";
-  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
-          << new_tuple->ToString();
-
-  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
-  return true;
-}
-
-absl::StatusOr<bool> GpuConvPaddingLegalization::RunOnComputation(
-    HloComputation* computation) {
-  bool changed = false;
-  std::vector<HloCustomCallInstruction*> convs;
-  for (auto* instr : computation->instructions()) {
-    if (IsCustomCallToDnnConvolution(*instr)) {
-      convs.push_back(Cast<HloCustomCallInstruction>(instr));
-    }
-  }
-  for (HloCustomCallInstruction* instruction : convs) {
-    TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction));
-    changed |= [&] {
-      switch (kind) {
-        case CudnnConvKind::kForward:
-        case CudnnConvKind::kForwardActivation:
-        case CudnnConvKind::kForwardGraph:
-          return CanonicalizeForwardConvolution(instruction);
-        case CudnnConvKind::kBackwardInput:
-          return CanonicalizeBackwardInputConvolution(instruction);
-        case CudnnConvKind::kBackwardFilter:
-          return CanonicalizeBackwardFilterConvolution(instruction);
-      }
-    }();
-  }
-  return changed;
-}
-
-absl::StatusOr<bool> GpuConvPaddingLegalization::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
-    changed |= result;
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
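
The uneven-padding handling deleted above is easy to check in isolation. Below is a minimal standalone sketch (hypothetical helper names, not part of the deleted pass) of the split performed by CanonicalizeBackwardFilterConvolution: the even part min(padding_low, padding_high) stays on the convolution, and the remainder becomes an explicit edge pad on the activations.

#include <algorithm>
#include <cstdint>
#include <iostream>

struct PaddingSplit {
  int64_t conv_padding;    // even padding kept on the backward convolution
  int64_t extra_pad_low;   // edge_padding_low of the inserted kPad
  int64_t extra_pad_high;  // edge_padding_high of the inserted kPad
};

PaddingSplit SplitUnevenPadding(int64_t padding_low, int64_t padding_high) {
  int64_t even = std::min(padding_low, padding_high);
  return {even, padding_low - even, padding_high - even};
}

int main() {
  // Mirrors the example in the deleted comment: padding_low=1, padding_high=2
  // becomes an explicit high pad of 1 plus an even conv padding of 1.
  PaddingSplit s = SplitUnevenPadding(/*padding_low=*/1, /*padding_high=*/2);
  std::cout << s.conv_padding << " " << s.extra_pad_low << " "
            << s.extra_pad_high << "\n";  // prints: 1 0 1
}
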
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.h b/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.h
deleted file mode 100644
index 32e0238..0000000
--- a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_
-#define XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// An HLO pass that canonicalizes convolution instructions for GPU codegen. It
-// inserts Pad instructions before Convolution instructions with uncanonicalized
-// padding, so that they can be lowered to Cudnn/Miopen convolution.
-class GpuConvPaddingLegalization : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "gpu-conv-padding-legalization";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
-  // Returns true if any changes are made to the parent computation.
-  bool CanonicalizeForwardConvolution(HloInstruction* conv);
-  bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv);
-  bool CanonicalizeBackwardInputConvolution(HloInstruction* backward_conv);
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization_test.cc b/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization_test.cc
deleted file mode 100644
index edaf9d0..0000000
--- a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization_test.cc
+++ /dev/null
@@ -1,96 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_conv_padding_legalization.h"
-
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/test.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-using GpuConvPaddingLegalizationTest = HloTestBase;
-
-TEST_F(GpuConvPaddingLegalizationTest, BackwardInputConvolve) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule convolution_module
-ENTRY %convolution (operand f64[2,2,2,3]{3,2,1,0}) -> (f64[2,2,4,4]{3,2,1,0}, u8[0]) {
-  %operand = f64[2,2,2,3]{3,2,1,0} parameter(0)
-  %kernel = f64[2,3,2,3]{3,2,1,0} constant(
-  {
-    { /*i0=0*/
-    { /*i1=0*/
-      { 0.29629629629629628, 0.30246913580246915, 0.30864197530864196 },
-      { 0.31481481481481483, 0.32098765432098764, 0.3271604938271605 }
-    },
-    { /*i1=1*/
-      { 0.25925925925925924, 0.26543209876543211, 0.27160493827160492 },
-      { 0.27777777777777779, 0.2839506172839506, 0.29012345679012347 }
-    },
-    { /*i1=2*/
-      { 0.22222222222222221, 0.22839506172839505, 0.23456790123456789 },
-      { 0.24074074074074073, 0.24691358024691357, 0.25308641975308643 }
-    }
-    },
-    { /*i0=1*/
-    { /*i1=0*/
-      { 0.18518518518518517, 0.19135802469135801, 0.19753086419753085 },
-      { 0.20370370370370369, 0.20987654320987653, 0.21604938271604937 }
-    },
-    { /*i1=1*/
-      { 0.14814814814814814, 0.15432098765432098, 0.16049382716049382 },
-      { 0.16666666666666666, 0.1728395061728395, 0.17901234567901234 }
-    },
-    { /*i1=2*/
-      { 0.1111111111111111, 0.11728395061728394, 0.12345679012345678 },
-      { 0.12962962962962962, 0.13580246913580246, 0.1419753086419753 }
-    }
-    }
-  })
-  %reverse = f64[2,3,2,3]{3,2,1,0} reverse(%kernel), dimensions={0,1}
-  ROOT %custom-call = (f64[2,2,4,4]{3,2,1,0}, u8[0]{0}) custom-call(f64[2,2,2,3]{3,2,1,0} %operand, f64[2,3,2,3]{3,2,1,0} %reverse), window={size=2x3 stride=2x2 pad=0_0x0_1}, dim_labels=bf01_01io->b01f, custom_call_target="__cudnn$convBackwardInput", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
-}
-                                               )")
-                    .value();
-  ASSERT_TRUE(GpuConvPaddingLegalization().Run(module.get()).value());
-  auto root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, GmockMatch(m::Tuple(
-                        m::Slice(m::GetTupleElement(
-                            m::CustomCall({kCudnnConvBackwardInputCallTarget},
-                                          m::Op(), m::Reverse(m::Constant())),
-                            0)),
-                        m::GetTupleElement())));
-  auto slice = root->operand(0);
-  Shape expected_slice_shape = ShapeUtil::MakeShape(F64, {2, 2, 4, 4});
-  EXPECT_TRUE(ShapeUtil::Equal(slice->shape(), expected_slice_shape));
-  auto conv = slice->operand(0);
-  Shape expected_conv_shape = ShapeUtil::MakeShape(F64, {2, 2, 4, 5});
-  EXPECT_TRUE(ShapeUtil::Equal(conv->shape(), expected_conv_shape));
-}
-
-}  // anonymous namespace
-}  // namespace gpu
-}  // namespace xla
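
The expected shapes in the deleted test above follow directly from the bookkeeping in CanonicalizeBackwardInputConvolution. The following standalone sketch (hypothetical names, using only the arithmetic shown in the deleted pass) handles one spatial dimension: the uneven part of the padding becomes extra output elements that are sliced away afterwards.

#include <cstdint>
#include <cstdlib>
#include <iostream>

struct DimPlan {
  int64_t new_dim_size;  // conv output size after canonicalization
  int64_t slice_start;   // start index of the final slice
  int64_t slice_limit;   // limit index of the final slice
};

DimPlan PlanDim(int64_t old_dim_size, int64_t padding_low,
                int64_t padding_high) {
  int64_t grow = std::abs(padding_low - padding_high);
  DimPlan p{old_dim_size + grow, 0, old_dim_size + grow};
  if (padding_low > padding_high) {
    p.slice_start += padding_low - padding_high;  // slice off the low end
  } else if (padding_low < padding_high) {
    p.slice_limit -= padding_high - padding_low;  // slice off the high end
  }
  return p;
}

int main() {
  // As in the test: a spatial dim of size 4 with pad=0_1 grows to 5, and the
  // extra high element is sliced away, giving back size 4.
  DimPlan p = PlanDim(/*old_dim_size=*/4, /*padding_low=*/0, /*padding_high=*/1);
  std::cout << p.new_dim_size << " [" << p.slice_start << ", " << p.slice_limit
            << ")\n";  // prints: 5 [0, 4)
}
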
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc b/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc
deleted file mode 100644
index cb5b186..0000000
--- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc
+++ /dev/null
@@ -1,869 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-
-#include <cstdint>
-#include <cstdlib>
-#include <memory>
-#include <numeric>
-#include <optional>
-#include <string>
-#include <string_view>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_replace.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/permutation_util.h"
-#include "xla/primitive_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "xla/window_util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-absl::Status CheckTypes(HloInstruction* conv,
-                        const se::GpuComputeCapability cc) {
-  auto valid_shape = [conv, &cc](const Shape& shape) -> absl::Status {
-    PrimitiveType type = shape.element_type();
-    if (!primitive_util::IsFloatingPointType(type) &&
-        !primitive_util::IsIntegralType(type)) {
-      // Among integral types, only S8 is supported. But CudnnFusedConvRewriter
-      // may rewrite convolutions of wider types into S8 convolutions, so allow
-      // all integral convolutions here.
-      return Unimplemented(
-          "Convolutions must have floating-point or integral operands/outputs, "
-          "but got convolution with type %s: %s",
-          primitive_util::LowercasePrimitiveTypeName(type), conv->ToString());
-    }
-    if (primitive_util::IsF8Type(type)) {
-      if (type != F8E4M3FN && type != F8E5M2) {
-        return Unimplemented(
-            "The only FP8 types supported in convolutions are f8e5m2 and "
-            "f8e4m3, "
-            "but got convolution with FP8 type %s: %s",
-            primitive_util::LowercasePrimitiveTypeName(type), conv->ToString());
-      }
-      if (!std::holds_alternative<se::CudaComputeCapability>(cc)) {
-        return Unimplemented(
-            "FP8 convolutions are only supported on CUDA GPUs, but got "
-            "FP8 convolution on ROCm GPU: %s",
-            conv->ToString());
-      } else if (!std::get<se::CudaComputeCapability>(cc).IsAtLeastHopper()) {
-        return Unimplemented(
-            "FP8 convolutions are only supported on CUDA GPUs with compute "
-            "capability at least 9.0, but got "
-            "FP8 convolution on GPU with compute capability %s: %s",
-            std::get<se::CudaComputeCapability>(cc).ToString(),
-            conv->ToString());
-      }
-    }
-    return absl::OkStatus();
-  };
-
-  TF_RETURN_IF_ERROR(valid_shape(conv->shape()));
-  TF_RETURN_IF_ERROR(valid_shape(conv->operand(0)->shape()));
-  TF_RETURN_IF_ERROR(valid_shape(conv->operand(1)->shape()));
-  return absl::OkStatus();
-}
-
-using ConvolutionMatch = std::optional<
-    std::tuple<Window, ConvolutionDimensionNumbers, HloInstruction*>>;
-
-// Determines whether this conv2d is really a conv1d that was reshaped to 2D.
-bool MaybeConv1dToConv2d(HloInstruction* conv) {
-  if (conv->window().dimensions().size() != 2) {
-    return false;
-  }
-  if (conv->operand(1)->opcode() != HloOpcode::kReshape) {
-    return false;
-  }
-  auto filter = conv->operand(1);
-  std::optional<ShapeUtil::ShapeEqualityDescriptor> reshape_degenerate =
-      filter->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
-  if (reshape_degenerate.has_value() &&
-      reshape_degenerate->deleted_dimensions.empty() &&
-      reshape_degenerate->inserted_dimensions.size() == 1) {
-    const auto& dnums = conv->convolution_dimension_numbers();
-    for (auto dim : dnums.kernel_spatial_dimensions()) {
-      if (dim == reshape_degenerate->inserted_dimensions[0]) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
-  const ConvolutionDimensionNumbers& dnums =
-      conv->convolution_dimension_numbers();
-  if (dnums.input_spatial_dimensions_size() > 3) {
-    return false;
-  }
-
-  // CuDNN does not accept zero-element arguments
-  if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) ||
-      ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) {
-    return false;
-  }
-
-  // CuDNN can perform either cross correlation (no reversal),
-  // or convolution (all dimensions reversed).
-  if (dnums.input_spatial_dimensions_size() == 2
-          ? !window_util::AllOrNoneReversed(conv->window())
-          : window_util::HasWindowReversal(conv->window())) {
-    return false;
-  }
-  return true;
-}
-
-// Try to match a backward filter pattern that contains "conv".
-// Precondition: "conv" is a kConvolution.
-ConvolutionMatch MatchBackwardFilter(HloInstruction* conv) {
-  VLOG(2) << "Trying to match convolution backward filter.";
-
-  if (conv->feature_group_count() > 1) {
-    VLOG(1) << conv->ToString()
-            << " is a forward convolution. All grouped backward filters are "
-               "mapped to batch grouped convolutions in tf2xla bridge. Hence "
-               "backward filter "
-               "convolutions cannot have feature groups greater than 1 at this "
-               "point. No need to fold to backward filter.";
-    return std::nullopt;
-  }
-
-  // Step 1: match the instruction pattern without considering the paddings and
-  // dimension numbers just yet. We may need some generic pattern matcher
-  // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
-  //
-  // Backward filter convolution is implemented in XLA as the forward
-  // convolution of padded activations and dilated gradients. Padding on
-  // activations and dilation on gradients are specified in the "window" field
-  // of the forward convolution.
-  //
-  //        activations  gradients
-  //              \         /
-  //               v       v
-  //              Convolution
-  //                 conv
-  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
-
-  // Step 2: match paddings and dimension numbers of the forward convolution.
-  const ConvolutionDimensionNumbers& conv_dnums =
-      conv->convolution_dimension_numbers();
-  auto input_batch_dim = conv_dnums.input_batch_dimension();
-  auto input_feature_dim = conv_dnums.input_feature_dimension();
-  auto input_spatial_dims = conv_dnums.input_spatial_dimensions();
-  auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension();
-  auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension();
-  auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions();
-  auto output_batch_dim = conv_dnums.output_batch_dimension();
-  auto output_feature_dim = conv_dnums.output_feature_dimension();
-  auto output_spatial_dims = conv_dnums.output_spatial_dimensions();
-  for (const WindowDimension& window_dim : conv->window().dimensions()) {
-    if (window_dim.stride() != 1) {
-      VLOG(1) << "Forward convolution's window "
-              << conv->window().ShortDebugString()
-              << " should have stride of 1.";
-      return std::nullopt;
-    }
-    if (window_dim.base_dilation() != 1) {
-      VLOG(1) << "Forward convolution's window "
-              << conv->window().ShortDebugString()
-              << " should have no base (LHS) dilation.";
-      return std::nullopt;
-    }
-    if (window_dim.padding_low() < 0) {
-      VLOG(1) << "Padding low should be non-negative.";
-      return std::nullopt;
-    }
-    if (window_dim.window_reversal()) {
-      VLOG(1) << "Window reversal field not supported";
-      return std::nullopt;
-    }
-    // Padding high will be checked in Step 3.
-  }
-  // Mathematically, there is no difference between convolution forward vs
-  // backward filter. A backward filter:
-  //   [N, O, H+h-1, W+w-1] x [N, C, H, W] -> [O, C, h, w]
-  // Can be treated as a forward convolution with `N` treated as the new
-  // contracting (feature) dimension, `O` treated as the new batch dimension,
-  // and `C` treated as the new output feature dimension. The only difference is
-  // layouts and performance.
-  //
-  // Since there is no way to precisely tell whether we want a forward conv or
-  // a backward filter conv, we have to rely on heuristics. Empirically, forward
-  // convolutions have very small kernel dimensions, while in the backward pass
-  // the "kernel dimensions" are large. If the kernel dimensions are smaller
-  // than the output dimensions, treat it as a forward conv; otherwise proceed
-  // with a backward filter conv. Conv1d is a special case: because a 1D filter
-  // is always reshaped to a 2D filter, both the forward and the backward
-  // convolution have one small kernel dimension, so we handle it separately.
-  int small_kernel_dimension_num = 0;
-  for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
-    if (conv->operand(1)->shape().dimensions(kernel_spatial_dims[i]) <=
-        conv->shape().dimensions(output_spatial_dims[i])) {
-      small_kernel_dimension_num += 1;
-    }
-  }
-  if ((kernel_spatial_dims.empty() || small_kernel_dimension_num > 1 ||
-       (!MaybeConv1dToConv2d(conv) && small_kernel_dimension_num == 1)) &&
-      !window_util::HasWindowDilation(conv->window())) {
-    VLOG(1) << conv->ToString()
-            << " is a regular forward convolution. No need "
-               "to fold it to a backward filter convolution....";
-    return std::nullopt;
-  }
-
-  // Step 3: fuse the matched HLOs into a backward convolution instruction.
-  //
-  // Compute the window of the backward convolution.
-  Window backward_conv_window;
-  for (int i = 0; i < input_spatial_dims.size(); ++i) {
-    WindowDimension* dim = backward_conv_window.add_dimensions();
-    // The window size of the backward convolution equals the output size of the
-    // forward convolution.
-    int64_t filter_size = conv->shape().dimensions(output_spatial_dims[i]);
-    dim->set_size(filter_size);
-    // The window stride equals the window dilation of the forward convolution.
-    dim->set_stride(conv->window().dimensions(i).window_dilation());
-    // The window's low padding is the same as the low padding of the
-    // activations.
-    dim->set_padding_low(conv->window().dimensions(i).padding_low());
-    dim->set_base_dilation(1);
-    dim->set_window_dilation(1);
-
-    int64_t input_size =
-        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
-    int64_t output_size = conv->window().dimensions(i).size();
-    // Compute the range of the amount of valid high padding. We first compute
-    // min_padding_high, the amount of padding on the right/bottom to ensure the
-    // last patch ends at the border, i.e.,
-    //
-    //   input_size + dim->padding_low() + min_padding_high
-    //     = (output_size - 1) * stride + filter_size
-    //
-    // Because convolution ignores trailing incomplete windows, any amount of
-    // padding high from min_padding_high to min_padding_high+stride-1
-    // (max_padding_high) has the same effect.
-    int64_t padded_input_size = filter_size + (output_size - 1) * dim->stride();
-    int64_t min_padding_high =
-        padded_input_size - input_size - dim->padding_low();
-    int64_t max_padding_high = min_padding_high + dim->stride() - 1;
-    CHECK_GE(dim->padding_low(), 0);
-    // In practice, since cuDNN convolution only supports even padding, we make
-    // the amount of high padding the same as the amount of low padding as long
-    // as it is between min_padding_high and max_padding_high. If it is not in
-    // that range, we pick the one that's closest to dim->padding_low() and let
-    // GpuConvPaddingLegalization canonicalize the resultant backward
-    // convolution later. Picking the closest one minimizes the cost of the kPad
-    // instruction to be inserted by GpuConvPaddingLegalization.
-    if (dim->padding_low() >= min_padding_high &&
-        dim->padding_low() <= max_padding_high) {
-      dim->set_padding_high(dim->padding_low());
-    } else {
-      if (dim->padding_low() < min_padding_high) {
-        dim->set_padding_high(min_padding_high);
-      } else {
-        dim->set_padding_high(max_padding_high);
-      }
-    }
-    if (dim->padding_high() < 0) {
-      LOG(WARNING)
-          << "Fusing this pattern to backward filter convolution would cause "
-             "negative padding ("
-          << dim->padding_high()
-          << ") on right/bottom of the weight gradients, which is not "
-             "supported by GpuConvPaddingLegalization (b/32744257). "
-             "Falling back to "
-             "unfused convolution for instruction: "
-          << conv->ToString();
-      return std::nullopt;
-    }
-  }
-
-  // Restore the dimension numbers of the backward convolution from the forward
-  // convolution. The two activation dimensions are reversed (batch and
-  // feature).
-  ConvolutionDimensionNumbers backward_conv_dnums;
-  backward_conv_dnums.set_input_batch_dimension(input_feature_dim);
-  backward_conv_dnums.set_input_feature_dimension(input_batch_dim);
-  for (int i = 0; i < input_spatial_dims.size(); ++i) {
-    backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]);
-  }
-  backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim);
-  backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim);
-  for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
-    backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]);
-  }
-  // The dimension numbering of the output of the forward convolution (before
-  // transposition) is the same as that of the activations (according to the
-  // semantics of kConvolution). The batch dimension of the activations should
-  // be treated as the input feature dimension, and the feature dimension should
-  // be treated as the output feature.
-  backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim);
-  backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim);
-  for (int i = 0; i < output_spatial_dims.size(); ++i) {
-    backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
-  }
-
-  HloInstruction* lhs = conv->mutable_operand(0);
-  return std::make_tuple(backward_conv_window, backward_conv_dnums, lhs);
-}
-
-// Try to match a backward input pattern that contains "conv".
-// Precondition: "conv" is a kConvolution.
-ConvolutionMatch MatchBackwardInput(HloInstruction* conv) {
-  VLOG(2) << "Trying to match convolution backward input.";
-
-  // TODO(timshen) Theoretically cuDNN supports grouped convolutions also
-  // for the backward input convolution, but based on cuDNN's current state
-  // there is not much performance improvement when using the
-  // cudnn backward input API for grouped conv.
-  // This needs to be re-evaluated for future cuDNN versions.
-  // Note that we already have the necessary code down below, the only thing to
-  // enable it is to remove the following early return.
-  if (conv->feature_group_count() > 1) {
-    return std::nullopt;
-  }
-
-  // Match instruction pattern.
-  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
-  HloInstruction* reverse_filter = conv->mutable_operand(1);
-  ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
-
-  // Match BackwardInput for a depthwise convolution and treat it as a forward
-  // convolution. The output feature dimension and the input feature dimension
-  // have been swapped in the bridge, so to get the actual input features we
-  // need to query the output feature dimension.
-  auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension();
-  auto kernel_out_features =
-      reverse_filter->shape().dimensions(kernel_out_feature_dim);
-
-  // For a depthwise convolution, the input feature count must equal the
-  // feature_group_count. We can leverage this property to match a depthwise
-  // convolution and treat it as a forward conv.
-  if (conv->feature_group_count() > 1 &&
-      kernel_out_features == conv->feature_group_count()) {
-    return std::nullopt;
-  }
-
-  // We pattern-match to a backwards input conv if:
-  //
-  //  - all spatial dims of the filter are reversed
-  //
-  // OR
-  //
-  //  - filter is 1x1 or a constant AND
-  //  - conv has base dilation (otherwise this is just a regular forward conv).
-  //
-  // The final criterion above is just for canonicalization; cudnn seems to run
-  // just as fast if we canonicalize 1x1/constant filters without base dilation
-  // to forward or backward convs.  We canonicalize to forward conv because (a)
-  // it's more natural (constant filters usually show up when doing inference,
-  // and having backwards convolutions in inference graphs would be weird), and
-  // (b) cudnn has special fusions for forward conv plus bias and activation,
-  // and we want to pattern-match to that after running this pass.
-  bool is_reversed_filter =
-      reverse_filter->opcode() == HloOpcode::kReverse &&
-      absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
-                             reverse_filter->dimensions());
-  // For a conv1d that was reshaped to a conv2d, the filter-reverse pattern is
-  // reshape(reverse(filter)). It may seem that we can reuse the conv2d backward
-  // input pattern matcher, but after the algsimp pass this pattern becomes
-  // reverse(reshape(filter)) and fails to match. So matching a conv1d backward
-  // input needs different processing logic.
-  bool is_reversed_conv1d_filter =
-      MaybeConv1dToConv2d(conv) &&
-      reverse_filter->operand(0)->opcode() == HloOpcode::kReverse;
-  bool is_1x1_filter =
-      absl::c_all_of(conv->window().dimensions(),
-                     [](const WindowDimension& d) { return d.size() == 1; });
-  if (!is_reversed_filter && !is_reversed_conv1d_filter &&
-      !(window_util::HasBaseDilation(conv->window()) &&
-        (reverse_filter->IsConstant() || is_1x1_filter))) {
-    VLOG(1) << "Can't match to backwards convolution. Either filter is not "
-               "kReverse, or it's not a base-dilated conv with a 1x1 or "
-               "constant filter.";
-    return std::nullopt;
-  }
-
-  // Match padding and dilation of the forward convolution.
-  for (const WindowDimension& window_dim : conv->window().dimensions()) {
-    if (window_dim.stride() != 1) {
-      VLOG(1) << "Forward convolution's window "
-              << conv->window().ShortDebugString()
-              << " should have stride of 1.";
-      return std::nullopt;
-    }
-    if (window_dim.window_dilation() != 1) {
-      VLOG(1) << "Forward convolution's window "
-              << conv->window().ShortDebugString()
-              << " should have no window dilation.";
-      return std::nullopt;
-    }
-    if (window_dim.window_reversal()) {
-      VLOG(1) << "Window reversal field not supported";
-      return std::nullopt;
-    }
-  }
-
-  const auto& input_spatial_dims = dnums.input_spatial_dimensions();
-  const auto& output_spatial_dims = dnums.output_spatial_dimensions();
-  CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size());
-  CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size());
-
-  const Window& old_window = conv->window();
-  Window new_window = old_window;
-  for (size_t i = 0; i < input_spatial_dims.size(); ++i) {
-    // Restore backward convolution's padding config from the matched pattern.
-    // See the comment in tensorflow/core/kernels/conv_grad_ops.h for how we
-    // convert backward input convolution to a variant of forward convolution.
-    //
-    // The stride of the backward convolution
-    // = the base dilation factor of the forward convolution
-    auto dim = new_window.mutable_dimensions(i);
-    dim->set_stride(old_window.dimensions(i).base_dilation());
-    dim->set_base_dilation(1);
-
-    // The low padding = kernel_size - 1 - low padding on the gradients
-    // Make sure the low padding is not negative.
-    auto kernel_size = old_window.dimensions(i).size();
-    auto backward_padding_low =
-        kernel_size - 1 - old_window.dimensions(i).padding_low();
-    if (backward_padding_low < 0) {
-      LOG(WARNING)
-          << "The low padding of the backward convolution would be negative ("
-          << backward_padding_low
-          << "), which isn't supported by GpuConvPaddingLegalization "
-             "for now (b/32744257).";
-      return std::nullopt;
-    }
-    dim->set_padding_low(backward_padding_low);
-
-    // Compute the range of the amount of padding on the right/bottom of the
-    // activations. XLA's convolution requires all patches to be within the
-    // padded base. This gives us flexibility to choose the amount of high
-    // padding from a set of values without changing the result of the backward
-    // convolution. The minimum amount (min_padding_high) makes the last patch
-    // end at the border. The maximum amount (max_padding_high) equals
-    // min_padding_high+stride-1 -- max_padding_high+1 would cause the output
-    // size to change.
-    auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]);
-    auto output_size =
-        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
-    auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
-    auto total_pad_size = padded_input_size - unpadded_input_size;
-    auto min_padding_high = total_pad_size - backward_padding_low;
-    auto max_padding_high = min_padding_high + dim->stride() - 1;
-
-    if (backward_padding_low >= min_padding_high &&
-        backward_padding_low <= max_padding_high) {
-      // In the best case (most likely), if backward_padding_low is in the range
-      // of the amounts of valid high padding, we choose backward_padding_low
-      // because cuDNN supports even padding only.
-      dim->set_padding_high(backward_padding_low);
-    } else {
-      // Otherwise, we choose the amount that's closest to backward_padding_low,
-      // and GpuConvPaddingLegalization will later insert kSlice
-      // instructions to enforce even padding.
-      //
-      // For example, consider the backward convolution pattern
-      //
-      //   ab     xy
-      //   | pad  | reverse
-      //  .a.b    yx
-      //     \   /
-      //      ABC
-      //
-      // The amount of low padding on activations (in backward convolution) is
-      //   backward_padding_low = kernel_size - 1 - forward_padding_low
-      //                        = 2 - 1 - 1 = 0
-      //
-      // The amount of padding high must be between 1 and 2, in order to make
-      // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in
-      // the range of [1,2], so we pick the closest valid amount of padding
-      // high, which is 1 in this case. Therefore, we fuse the above pattern to
-      //
-      //   ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1)
-      if (backward_padding_low < min_padding_high) {
-        dim->set_padding_high(min_padding_high);
-      } else {
-        dim->set_padding_high(max_padding_high);
-      }
-    }
-    // GpuConvPaddingLegalization doesn't handle backward input
-    // convolution with negative padding for now. So fall back to unfused
-    // convolution in case of negative padding. For example,
-    //   ABCD = Conv(abc, reverse(xy), padding_high=2)
-    // could be fused to
-    //   ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
-    // with positive padding low but negative padding high.
-    if (dim->padding_high() < 0) {
-      LOG(WARNING) << "Fusing this pattern to backward convolution would cause "
-                      "negative padding ("
-                   << dim->padding_high()
-                   << ") on right/bottom of the activations, which is not "
-                      "supported by GpuConvPaddingLegalization (b/32744257). "
-                      "Falling back to unfused convolution for instruction: "
-                   << conv->ToString();
-      return std::nullopt;
-    }
-  }
-
-  // OK, it's a match! Switch the input feature dimension with the output
-  // feature dimension. Also switch the output with the input. This is the way
-  // cuDNN expects it to be.
-  auto conv_dnums = conv->convolution_dimension_numbers();
-  dnums.set_kernel_input_feature_dimension(
-      conv_dnums.kernel_output_feature_dimension());
-  dnums.set_kernel_output_feature_dimension(
-      conv_dnums.kernel_input_feature_dimension());
-  for (int i = 0; i < input_spatial_dims.size(); ++i) {
-    dnums.set_input_spatial_dimensions(i,
-                                       conv_dnums.output_spatial_dimensions(i));
-    dnums.set_output_spatial_dimensions(i,
-                                        conv_dnums.input_spatial_dimensions(i));
-  }
-  dnums.set_input_feature_dimension(conv_dnums.output_feature_dimension());
-  dnums.set_input_batch_dimension(conv_dnums.output_batch_dimension());
-  dnums.set_output_feature_dimension(conv_dnums.input_feature_dimension());
-  dnums.set_output_batch_dimension(conv_dnums.input_batch_dimension());
-
-  // If we matched against a constant, we need to add a reverse op that can be
-  // subsumed by the cuDNN call. algebraic-simplifier will later remove any
-  // unnecessary reverses.
-  if (reverse_filter->opcode() != HloOpcode::kReverse &&
-      reverse_filter->IsConstant()) {
-    // Create a double-reverse, which is a nop.
-    HloComputation* c = conv->parent();
-    reverse_filter = c->AddInstruction(
-        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
-                                      dnums.kernel_spatial_dimensions()));
-    reverse_filter = c->AddInstruction(
-        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
-                                      dnums.kernel_spatial_dimensions()));
-    TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_num=*/1, reverse_filter));
-  }
-
-  // Calculate the 'rhs' that goes into the backward input convolution.
-  HloInstruction* rhs = reverse_filter;
-  // One reverse is subsumed by the cuDNN call.
-  if (rhs->opcode() == HloOpcode::kReverse) {
-    rhs = rhs->mutable_operand(0);
-  } else if (is_reversed_conv1d_filter) {
-    auto src = rhs->mutable_operand(0)->mutable_operand(0);
-    rhs = conv->parent()->AddInstruction(
-        HloInstruction::CreateReshape(rhs->shape(), src));
-  }
-  if (conv->feature_group_count() == 1) {
-    return std::make_tuple(new_window, dnums, rhs);
-  }
-
-  // Handle grouped convolutions. Because we swapped the input feature dimension
-  // with the output feature dimension, we need to also reshape the kernel so
-  // that the 'feature_group_count' parameter still makes sense. The
-  // 'feature_group_count' parameter essentially specifies how often the
-  // 'kernel_input_feature_dimension' is repeated. So when we swap these
-  // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
-  // 'feature_group_count' and multiply the new
-  // 'kernel_output_feature_dimension' by 'feature_group_count'.
-  int64_t input_feature_dimension = dnums.kernel_input_feature_dimension();
-  int64_t output_feature_dimension = dnums.kernel_output_feature_dimension();
-  // The following code assumes that input_feature_dimension and
-  // output_feature_dimension are adjacent.
-  if (std::abs(input_feature_dimension - output_feature_dimension) != 1) {
-    return std::nullopt;
-  }
-
-  int64_t input_features = rhs->shape().dimensions(input_feature_dimension);
-  int64_t output_features = rhs->shape().dimensions(output_feature_dimension);
-
-  // Reshape [H, W, ..., in_depth, out_depth / G] -> [H, W, ..., G, in_depth/G,
-  // out_depth / G]
-  std::vector<int64_t> reshape_dims = SpanToVector(rhs->shape().dimensions());
-  auto num_groups = conv->feature_group_count();
-  CHECK_EQ(input_features % num_groups, 0)
-      << "Input feature count should be an exact multiple of feature group "
-         "count";
-  reshape_dims[input_feature_dimension] =
-      reshape_dims[input_feature_dimension] / num_groups;
-  reshape_dims.insert(reshape_dims.begin() + input_feature_dimension,
-                      num_groups);
-
-  HloComputation* c = conv->parent();
-  rhs = c->AddInstruction(HloInstruction::CreateReshape(
-      ShapeUtil::MakeShape(rhs->shape().element_type(), reshape_dims), rhs));
-
-  // Transpose [H, W, ..., G, in_depth/G, out_depth / G] -> [H, W, ...,
-  // in_depth/G, G, out_depth / G]
-  std::vector<int64_t> transpose_dims(rhs->shape().dimensions_size());
-  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
-  transpose_dims.erase(transpose_dims.begin() + input_feature_dimension);
-  transpose_dims.insert(transpose_dims.begin() + output_feature_dimension,
-                        input_feature_dimension);
-  std::vector<int64_t> transpose_reshape_dims =
-      SpanToVector(rhs->shape().dimensions());
-  transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
-                               input_feature_dimension);
-  transpose_reshape_dims.insert(
-      transpose_reshape_dims.begin() + output_feature_dimension, num_groups);
-  rhs = c->AddInstruction(HloInstruction::CreateTranspose(
-      ShapeUtil::MakeShape(rhs->shape().element_type(), transpose_reshape_dims),
-      rhs, transpose_dims));
-
-  // Reshape [H, W, ..., in_depth/G, G, out_depth / G] -> [H, W, ...,
-  // in_depth/G, out_depth]
-  Shape new_shape = rhs->shape();
-  new_shape.DeleteDimension(output_feature_dimension);
-  new_shape.set_dimensions(output_feature_dimension,
-                           output_features * num_groups);
-  rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
-  return std::make_tuple(new_window, dnums, rhs);
-}
-
-HloInstruction* CreateGpuConv(absl::string_view call_target, const Shape& shape,
-                              HloInstruction* lhs, HloInstruction* rhs,
-                              const Window& window,
-                              const ConvolutionDimensionNumbers& dnums,
-                              int64_t feature_group_count,
-                              const PrecisionConfig& precision_config,
-                              const OpMetadata& metadata) {
-  HloComputation* computation = lhs->parent();
-
-  // This call returns a tuple of (conv_result, scratch_memory), where
-  // conv_result is the actual result of the convolution, and scratch_memory is
-  // temporary memory used by cudnn.
-  //
-  // At the moment, we don't know how much scratch memory this conv is going to
-  // use, so we put u8[0] in this place.  Later on another pass will choose
-  // which conv algorithm to use, and at that point we'll modify the shape of
-  // this second tuple element.
-  Shape call_shape =
-      ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
-
-  HloInstruction* custom_call = computation->AddInstruction(
-      HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
-  custom_call->set_window(window);
-  custom_call->set_convolution_dimension_numbers(dnums);
-  custom_call->set_feature_group_count(feature_group_count);
-  *custom_call->mutable_precision_config() = precision_config;
-  custom_call->set_metadata(metadata);
-
-  // Give the customcall a user-friendly name.
-  std::optional<std::string> name;
-  if (call_target == kCudnnConvForwardCallTarget) {
-    name = "cudnn-conv";
-  } else if (call_target == kCudnnConvBackwardInputCallTarget) {
-    name = "cudnn-conv-bw-input";
-  } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
-    name = "cudnn-conv-bw-filter";
-  } else if (call_target == kCudnnConvBiasActivationForwardCallTarget) {
-    name = "cudnn-conv-bias-activation";
-  }
-  if (name.has_value()) {
-    computation->parent()->SetAndUniquifyInstrName(custom_call, *name);
-  }
-
-  return custom_call;
-}
-
-HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution(
-    HloInstruction* conv) {
-  CHECK_EQ(conv->feature_group_count(), 1);
-  int64_t num_groups = conv->batch_group_count();
-  auto dim_numbers = conv->convolution_dimension_numbers();
-  auto lhs = conv->mutable_operand(0);
-  auto rhs = conv->mutable_operand(1);
-
-  int64_t input_batch_dimension = dim_numbers.input_batch_dimension();
-
-  Shape output_shape = conv->shape();
-  int64_t input_feature_dimension = dim_numbers.input_feature_dimension();
-  int64_t input_feature = lhs->shape().dimensions(input_feature_dimension);
-
-  HloComputation* computation = lhs->parent();
-  auto add = [&](std::unique_ptr<HloInstruction> inst) {
-    return computation->AddInstruction(std::move(inst));
-  };
-  // Reshape batch_dim N -> [G, N/G]
-  std::vector<int64_t> reshape_dims = SpanToVector(lhs->shape().dimensions());
-  reshape_dims[input_batch_dimension] =
-      reshape_dims[input_batch_dimension] / num_groups;
-  reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups);
-  lhs = add(HloInstruction::CreateReshape(
-      ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), lhs));
-
-  // Transpose G to the axis before C, e.g. [G, N/G, H, W, C] -> [N/G, H, W,
-  // G, C]
-  std::vector<int64_t> transpose_dims(lhs->shape().dimensions_size());
-  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
-  transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
-  transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
-                        input_batch_dimension);
-  std::vector<int64_t> transpose_reshape_dims =
-      ComposePermutations(lhs->shape().dimensions(), transpose_dims);
-  lhs = add(HloInstruction::CreateTranspose(
-      ShapeUtil::MakeShape(lhs->shape().element_type(), transpose_reshape_dims),
-      lhs, transpose_dims));
-
-  // Merge [G,C] -> [C*G]
-  Shape new_shape = lhs->shape();
-  new_shape.DeleteDimension(input_feature_dimension);
-  new_shape.set_dimensions(input_feature_dimension, input_feature * num_groups);
-  lhs = add(HloInstruction::CreateReshape(new_shape, lhs));
-
-  std::vector<HloInstruction*> new_operands = {lhs, rhs};
-  auto new_conv = conv->CloneWithNewOperands(output_shape, new_operands);
-  new_conv->set_feature_group_count(num_groups);
-  new_conv->set_batch_group_count(1);
-  new_conv->set_convolution_dimension_numbers(dim_numbers);
-  return computation->AddInstruction(std::move(new_conv));
-}
-
-CudnnConvBackendConfig GetDefaultBackendConfig() {
-  CudnnConvBackendConfig config;
-  config.set_conv_result_scale(1);
-  return config;
-}
-
-// Helper function to create a custom_call instruction to replace the given
-// conv instruction
-static absl::StatusOr<HloInstruction*> CreateCustomCallHelper(
-    HloInstruction* conv, const se::GpuComputeCapability& cc) {
-  TF_RETURN_IF_ERROR(CheckTypes(conv, cc));
-  if (ConvolutionMatch m = MatchBackwardInput(conv)) {
-    auto& [window, dnums, rhs] = *m;
-    return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
-                         conv->mutable_operand(0), rhs, window, dnums,
-                         conv->feature_group_count(), conv->precision_config(),
-                         conv->metadata());
-  }
-
-  if (ConvolutionMatch m = MatchBackwardFilter(conv)) {
-    auto& [window, dnums, lhs] = *m;
-    return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs,
-                         conv->mutable_operand(1), window, dnums,
-                         conv->batch_group_count(), conv->precision_config(),
-                         conv->metadata());
-  }
-
-  // If all else fails, try a forward convolution.
-  if (CanImplementAsGpuForwardConv(conv)) {
-    if (conv->batch_group_count() > 1) {
-      conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv);
-    }
-
-    return CreateGpuConv(kCudnnConvForwardCallTarget, conv->shape(),
-                         conv->mutable_operand(0), conv->mutable_operand(1),
-                         conv->window(), conv->convolution_dimension_numbers(),
-                         conv->feature_group_count(), conv->precision_config(),
-                         conv->metadata());
-  }
-
-  return nullptr;
-}
-
-// Tries to rewrite a single convolution into a call to cudnn/miopen.
-absl::StatusOr<bool> RunOnInstruction(HloInstruction* conv,
-                                      const se::GpuComputeCapability& cc) {
-  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
-
-  TF_ASSIGN_OR_RETURN(HloInstruction * custom_call,
-                      CreateCustomCallHelper(conv, cc));
-  if (custom_call == nullptr) {
-    return false;
-  }
-
-  GpuBackendConfig gpu_backend_config;
-  *gpu_backend_config.mutable_cudnn_conv_backend_config() =
-      GetDefaultBackendConfig();
-  TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
-
-  VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
-          << custom_call->ToString();
-
-  // The CustomCall returns a tuple (conv_result, scratch_memory).  Extract
-  // out the conv result and replace `conv` with it.
-  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
-      conv,
-      HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
-  return true;
-}
-
-// Rewrites the convolutions in the given computation into calls to
-// cudnn/miopen.
-// Returns true if it made any changes.
-absl::StatusOr<bool> RunOnComputation(HloComputation* computation,
-                                      const se::GpuComputeCapability& cc) {
-  std::vector<HloInstruction*> convs;
-  for (auto* hlo : computation->instructions()) {
-    if (hlo->opcode() == HloOpcode::kConvolution) {
-      convs.push_back(hlo);
-    }
-  }
-
-  bool changed = false;
-  for (HloInstruction* conv : convs) {
-    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv, cc));
-    changed |= result;
-  }
-  return changed;
-}
-}  // namespace
-
-absl::StatusOr<bool> GpuConvRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), before:\n" + module->ToString());
-  bool changed = false;
-  for (HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    TF_ASSIGN_OR_RETURN(bool result,
-                        RunOnComputation(computation, compute_capability_));
-    changed |= result;
-  }
-  XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), after:\n" + module->ToString());
-  return changed;
-}
-
-/*static*/ bool GpuConvRewriter::ConvIsLowerable(HloInstruction* conv) {
-  return CanImplementAsGpuForwardConv(conv) || MatchBackwardFilter(conv) ||
-         MatchBackwardInput(conv);
-}
-
-}  // namespace gpu
-}  // namespace xla
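For readers tracking this removal: the core rewrite shape the deleted pass produced is a custom call returning a (conv_result, scratch_memory) tuple, with the original convolution replaced by get-tuple-element 0 of that call. Below is a minimal sketch of just that replacement step, using only the XLA APIs already visible in the deleted code; the function name is a hypothetical stand-in, not part of the removed file.

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "tsl/platform/errors.h"

// Sketch: replace `conv` with element 0 of `custom_call`'s (result, scratch)
// tuple; returns false when no cudnn/miopen lowering was found.
absl::StatusOr<bool> ReplaceConvWithCustomCallResult(
    xla::HloInstruction* conv, xla::HloInstruction* custom_call) {
  if (custom_call == nullptr) {
    return false;
  }
  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
      conv, xla::HloInstruction::CreateGetTupleElement(conv->shape(),
                                                       custom_call, 0)));
  return true;
}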
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.h b/third_party/xla/xla/service/gpu/gpu_conv_rewriter.h
deleted file mode 100644
index 74b860f..0000000
--- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_
-#define XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites plain convolutions, backwards-filter convolutions, and
-// backwards-input convolutions into CustomCall HLOs that call into
-// Cudnn/Miopen.
-//
-// This pass does not fuse other ops into the convolution. Instead, specific
-// patterns of ops will be matched and fused into the custom call in
-// CudnnFusedConvRewriter.
-class GpuConvRewriter : public HloModulePass {
- public:
-  explicit GpuConvRewriter(const se::GpuComputeCapability& compute_capability)
-      : compute_capability_(compute_capability) {}
-
-  absl::string_view name() const override { return "gpu-conv-rewriter"; }
-
-  static bool ConvIsLowerable(HloInstruction* conv);
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  se::GpuComputeCapability compute_capability_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_
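As a usage reference while this header goes away, here is a minimal sketch of how the pass was driven, mirroring the deleted test fixture's RunPass; the wrapper function and its `stream_exec` parameter are illustrative, not API introduced by this change.

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/gpu_conv_rewriter.h"  // the header removed here
#include "xla/stream_executor/stream_executor.h"

// Sketch: run the rewriter with the compute capability of a given executor.
// Returns true iff at least one convolution was rewritten.
absl::StatusOr<bool> RunGpuConvRewriterSketch(
    xla::HloModule* module, stream_executor::StreamExecutor* stream_exec) {
  const auto& cc =
      stream_exec->GetDeviceDescription().gpu_compute_capability();
  return xla::gpu::GpuConvRewriter(cc).Run(module);
}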
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc
deleted file mode 100644
index f83bae8..0000000
--- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc
+++ /dev/null
@@ -1,812 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-
-#include <optional>
-#include <string>
-
-#include "absl/log/check.h"
-#include "absl/strings/str_format.h"
-#include "xla/array4d.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/literal_util.h"
-#include "xla/protobuf_util.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/service/shape_inference.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/test.h"
-#include "xla/test_helpers.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class GpuConvRewriterTest : public HloTestBase {
- public:
-  GpuConvRewriterTest()
-      : HloTestBase(/*verifier_layout_sensitive=*/true,
-                    /*allow_mixed_precision_in_hlo_verifier=*/false) {
-    for (int i = 0; i < 2; ++i) {
-      WindowDimension* window_dim = default_conv_window_.add_dimensions();
-      window_dim->set_size(1);
-      window_dim->set_stride(1);
-      window_dim->set_padding_low(0);
-      window_dim->set_padding_high(0);
-      window_dim->set_window_dilation(1);
-      window_dim->set_base_dilation(1);
-    }
-    // TF data shapes are in NHWC order by default, and the filter shape is in
-    // HWIO order by default. For backward filter convolution, we need to swap
-    // the batch and feature dimensions in the activations, and treat the batch
-    // dimension in the gradients as the input feature dimension in the filter.
-    //
-    // TODO(jingyue): Add more tests on NCHW input order, which TF also
-    // supports.
-    tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3);
-    tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0);
-    tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1);
-    tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2);
-    tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0);
-    tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension(
-        3);
-    tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1);
-    tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2);
-    tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0);
-    tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1);
-    tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2);
-    tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3);
-
-    tf_default_dnums_for_backward_input_.set_input_batch_dimension(0);
-    tf_default_dnums_for_backward_input_.set_output_batch_dimension(0);
-    tf_default_dnums_for_backward_input_.set_input_feature_dimension(3);
-    tf_default_dnums_for_backward_input_.set_output_feature_dimension(3);
-    tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1);
-    tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1);
-    tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2);
-    tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2);
-    tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3);
-    tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2);
-    tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0);
-    tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1);
-  }
-
- protected:
-  const se::GpuComputeCapability& GetComputeCapability() {
-    return backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .gpu_compute_capability();
-  }
-
-  bool RunPass(HloModule* module) {
-    return GpuConvRewriter(GetComputeCapability()).Run(module).value();
-  }
-
-  // A convolution window with stride 1 and zero padding. The size fields are
-  // not set.
-  Window default_conv_window_;
-  ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_;
-  ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
-};
-
-TEST_F(GpuConvRewriterTest, BackwardFilterConvolve) {
-  HloComputation::Builder builder(TestName());
-  HloInstruction* activations =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
-  HloInstruction* gradients =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients"));
-  Window conv_window = default_conv_window_;
-  conv_window.mutable_dimensions(1)->set_size(2);
-  conv_window.mutable_dimensions(1)->set_window_dilation(2);
-  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeInference::InferConvolveShape(
-          activations->shape(), gradients->shape(), /*feature_group_count=*/1,
-          /*batch_group_count=*/1, conv_window,
-          tf_default_dnums_for_backward_filter_,
-          /*preferred_element_type=*/std::nullopt)
-          .value(),
-      activations, gradients, /*feature_group_count=*/1,
-      /*batch_group_count=*/1, conv_window,
-      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
-
-  OpMetadata metadata;
-  metadata.set_op_name("foo");
-  conv->set_metadata(metadata);
-
-  auto module = CreateNewVerifiedModule();
-  HloComputation* entry_computation =
-      module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(RunPass(module.get()));
-  ASSERT_THAT(entry_computation->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
-
-  // Check that metadata was preserved.
-  const auto& md_after_opt =
-      entry_computation->root_instruction()->operand(0)->metadata();
-  EXPECT_TRUE(protobuf_util::ProtobufEquals(md_after_opt, metadata))
-      << md_after_opt.DebugString() << " vs " << metadata.DebugString();
-}
-
-TEST_F(GpuConvRewriterTest,
-       BackwardFilterConvolveEquivalentToForwardConvolution) {
-  HloComputation::Builder builder(TestName());
-  HloInstruction* activations =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
-  HloInstruction* gradients =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients"));
-  Window conv_window = default_conv_window_;
-  conv_window.mutable_dimensions(1)->set_size(3);
-  builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeInference::InferConvolveShape(
-          activations->shape(), gradients->shape(), /*feature_group_count=*/1,
-          /*batch_group_count=*/1, conv_window,
-          tf_default_dnums_for_backward_filter_,
-          /*preferred_element_type=*/std::nullopt)
-          .value(),
-      activations, gradients, /*feature_group_count=*/1,
-      /*batch_group_count=*/1, conv_window,
-      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
-
-  auto module = CreateNewVerifiedModule();
-  HloComputation* entry_computation =
-      module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(RunPass(module.get()));
-  EXPECT_THAT(entry_computation->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
-}
-
-// Extracted from block35 training.
-TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) {
-  auto builder = HloComputation::Builder(TestName());
-  HloInstruction* activations =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
-  HloInstruction* gradients =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
-
-  Window conv_window = default_conv_window_;
-  for (int i = 0; i < 2; ++i) {
-    conv_window.mutable_dimensions(i)->set_size(35);
-    conv_window.mutable_dimensions(i)->set_padding_low(1);
-    conv_window.mutable_dimensions(i)->set_padding_high(1);
-  }
-  builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
-      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
-      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
-
-  auto module = CreateNewVerifiedModule();
-  HloComputation* entry_computation =
-      module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(RunPass(module.get()));
-  EXPECT_THAT(entry_computation->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
-}
-
-// Extracted from inception v3 training.
-TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) {
-  auto builder = HloComputation::Builder(TestName());
-  HloInstruction* activations =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations"));
-  HloInstruction* gradients =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients"));
-
-  Window conv_window = default_conv_window_;
-  for (int i = 0; i < 2; ++i) {
-    conv_window.mutable_dimensions(i)->set_size(4);
-    conv_window.mutable_dimensions(i)->set_padding_high(-1);
-    conv_window.mutable_dimensions(i)->set_window_dilation(2);
-  }
-  builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
-      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
-      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
-
-  auto module = CreateNewVerifiedModule();
-  HloComputation* entry_computation =
-      module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(RunPass(module.get()));
-  EXPECT_THAT(entry_computation->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
-}
-
-TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
-  auto builder = HloComputation::Builder(TestName());
-  HloInstruction* activations =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
-  HloInstruction* gradients =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
-
-  Window conv_window = default_conv_window_;
-  for (int i = 0; i < 2; ++i) {
-    conv_window.mutable_dimensions(i)->set_size(35);
-    // Uneven padding: padding_low=0, padding_high=1
-    conv_window.mutable_dimensions(i)->set_padding_high(1);
-  }
-  builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
-      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
-      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
-
-  auto module = CreateNewVerifiedModule();
-  HloComputation* entry_computation =
-      module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(RunPass(module.get()));
-  EXPECT_THAT(entry_computation->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
-}
-
-TEST_F(GpuConvRewriterTest, BackwardInputConvolveEvenPadding) {
-  auto builder = HloComputation::Builder(TestName());
-  HloInstruction* output =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output"));
-  HloInstruction* kernel =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel"));
-  HloInstruction* reverse_kernel = builder.AddInstruction(
-      HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3}));
-
-  Window conv_window = default_conv_window_;
-  for (int i = 0; i < 2; ++i) {
-    conv_window.mutable_dimensions(i)->set_size(7);
-    conv_window.mutable_dimensions(i)->set_padding_low(3);
-    conv_window.mutable_dimensions(i)->set_padding_high(3);
-  }
-  ConvolutionDimensionNumbers conv_dnums;
-  conv_dnums.set_input_batch_dimension(0);
-  conv_dnums.set_output_batch_dimension(0);
-  conv_dnums.set_input_feature_dimension(1);
-  conv_dnums.set_output_feature_dimension(1);
-  conv_dnums.add_input_spatial_dimensions(2);
-  conv_dnums.add_output_spatial_dimensions(2);
-  conv_dnums.add_input_spatial_dimensions(3);
-  conv_dnums.add_output_spatial_dimensions(3);
-  conv_dnums.set_kernel_input_feature_dimension(0);
-  conv_dnums.set_kernel_output_feature_dimension(1);
-  conv_dnums.add_kernel_spatial_dimensions(2);
-  conv_dnums.add_kernel_spatial_dimensions(3);
-
-  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
-      /*rhs=*/reverse_kernel, /*feature_group_count=*/1,
-      /*batch_group_count=*/1, conv_window, conv_dnums,
-      DefaultPrecisionConfig(2)));
-  // Verify the convolution's shape is consistent with ShapeInference.
-  CHECK(ShapeUtil::Compatible(
-      conv->shape(),
-      ShapeInference::InferConvolveShape(
-          output->shape(), reverse_kernel->shape(),
-          /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
-          conv_dnums, /*preferred_element_type=*/std::nullopt)
-          .value()));
-
-  auto module = CreateNewVerifiedModule();
-  HloComputation* entry_computation =
-      module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(RunPass(module.get()));
-
-  ASSERT_THAT(entry_computation->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
-  const HloInstruction* custom_call =
-      entry_computation->root_instruction()->operand(0);
-  for (int i = 0; i < 2; ++i) {
-    const WindowDimension& window_dim = custom_call->window().dimensions(i);
-    // Low padding of the backward input convolution
-    //   = kernel_size - 1 - low padding on gradients.
-    EXPECT_EQ(3, window_dim.padding_low());
-    EXPECT_EQ(3, window_dim.padding_high());
-    EXPECT_EQ(1, window_dim.stride());
-    EXPECT_EQ(1, window_dim.base_dilation());
-  }
-}
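A quick arithmetic check of the padding relation asserted above, using this test's numbers (7x7 kernel, forward low padding 3); this is just the comment's formula spelled out, not code from the removed test.

constexpr int kKernelSize = 7;
constexpr int kGradientPaddingLow = 3;
// Low padding of the backward input convolution
//   = kernel_size - 1 - low padding on the gradients.
constexpr int kBackwardPaddingLow = kKernelSize - 1 - kGradientPaddingLow;
static_assert(kBackwardPaddingLow == 3, "matches EXPECT_EQ(3, padding_low())");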
-
-// Convolve([abc], [x], base_dilation=2)
-//   = Convolve([abc], Reverse([x]), base_dilation=2)
-//   = BackwardInputConvolve([abc], [x], stride=2)
-TEST_F(GpuConvRewriterTest, BackwardInputConvolve1x1Filter) {
-  auto builder = HloComputation::Builder(TestName());
-  // NHWC dimension order.
-  HloInstruction* output =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
-  // HWOI dimension order.
-  HloInstruction* kernel =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
-
-  Window conv_window = default_conv_window_;
-  conv_window.mutable_dimensions(1)->set_base_dilation(2);
-
-  builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeInference::InferConvolveShape(
-          output->shape(), kernel->shape(),
-          /*feature_group_count=*/1,
-          /*batch_group_count=*/1, conv_window,
-          tf_default_dnums_for_backward_input_,
-          /*preferred_element_type=*/std::nullopt)
-          .value(),
-      /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
-      /*batch_group_count=*/1, conv_window,
-      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
-
-  auto module = CreateNewVerifiedModule();
-  HloComputation* entry_computation =
-      module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(RunPass(module.get()));
-  EXPECT_THAT(entry_computation->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
-}
-
-// BackwardInputConvolve([abc], [x], stride=1) is equivalent to
-// ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input
-// convolution.
-TEST_F(GpuConvRewriterTest,
-       BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) {
-  auto builder = HloComputation::Builder(TestName());
-  // NHWC dimension order.
-  HloInstruction* output =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
-  // HWOI dimension order.
-  HloInstruction* kernel =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
-
-  builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeInference::InferConvolveShape(
-          output->shape(), kernel->shape(), /*feature_group_count=*/1,
-          /*batch_group_count=*/1, default_conv_window_,
-          tf_default_dnums_for_backward_input_,
-          /*preferred_element_type=*/std::nullopt)
-          .value(),
-      /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
-      /*batch_group_count=*/1, default_conv_window_,
-      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
-
-  auto module = CreateNewVerifiedModule();
-  HloComputation* entry_computation =
-      module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(RunPass(module.get()));
-  EXPECT_THAT(entry_computation->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
-}
-
-// Extracted from Inception V3 training.
-//
-//                                  filter(HWIO)
-//                                  3x3x192x320
-//                                      |
-//                                      v
-//      gradients(NHWC)              reverse
-//        20x4x4x320               3x3x192x320
-//                    \            /
-//                     \          /
-//  conv (NHWC) with padding (low=2,high=3,interior=1)
-//                     20x10x10x192
-//
-// Gradients are padded unevenly.
-TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) {
-  auto builder = HloComputation::Builder(TestName());
-  HloInstruction* output =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
-  HloInstruction* kernel =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
-  HloInstruction* reverse_kernel = builder.AddInstruction(
-      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
-
-  Window conv_window = default_conv_window_;
-  for (int i = 0; i < 2; ++i) {
-    conv_window.mutable_dimensions(i)->set_size(3);
-    conv_window.mutable_dimensions(i)->set_padding_low(2);
-    conv_window.mutable_dimensions(i)->set_padding_high(3);
-    // Interior padding = 1.
-    conv_window.mutable_dimensions(i)->set_base_dilation(2);
-  }
-  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
-      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
-      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
-  // Verify the convolution's shape is consistent with ShapeInference.
-  CHECK(ShapeUtil::Compatible(
-      conv->shape(), ShapeInference::InferConvolveShape(
-                         output->shape(), reverse_kernel->shape(),
-                         /*feature_group_count=*/1, /*batch_group_count=*/1,
-                         conv_window, tf_default_dnums_for_backward_input_,
-                         /*preferred_element_type=*/std::nullopt)
-                         .value()));
-
-  auto module = CreateNewVerifiedModule();
-  HloComputation* entry_computation =
-      module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(RunPass(module.get()));
-  ASSERT_THAT(entry_computation->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
-  const HloInstruction* custom_call =
-      entry_computation->root_instruction()->operand(0);
-  for (int i = 0; i < 2; ++i) {
-    const WindowDimension& window_dim = custom_call->window().dimensions(i);
-    EXPECT_EQ(0, window_dim.padding_low());
-    EXPECT_EQ(0, window_dim.padding_high());
-    EXPECT_EQ(2, window_dim.stride());
-    EXPECT_EQ(1, window_dim.base_dilation());
-  }
-}
-
-// Similar to BackwardInputConvolveUnevenPaddingOnGradients, but the low
-// padding of the gradients exceeds kernel_size - 1. Therefore, this pattern
-// cannot be fused.
-TEST_F(GpuConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
-  auto builder = HloComputation::Builder(TestName());
-  HloInstruction* output =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
-  HloInstruction* kernel =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
-  HloInstruction* reverse_kernel = builder.AddInstruction(
-      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
-
-  Window conv_window = default_conv_window_;
-  for (int i = 0; i < 2; ++i) {
-    conv_window.mutable_dimensions(i)->set_size(3);
-    conv_window.mutable_dimensions(i)->set_padding_low(3);
-    conv_window.mutable_dimensions(i)->set_padding_high(2);
-    conv_window.mutable_dimensions(i)->set_base_dilation(2);
-  }
-  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
-      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
-      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
-  // Verify the convolution's shape is consistent with ShapeInference.
-  CHECK(ShapeUtil::Compatible(
-      conv->shape(), ShapeInference::InferConvolveShape(
-                         output->shape(), reverse_kernel->shape(),
-                         /*feature_group_count=*/1, /*batch_group_count=*/1,
-                         conv_window, tf_default_dnums_for_backward_input_,
-                         /*preferred_element_type=*/std::nullopt)
-                         .value()));
-
-  auto module = CreateNewVerifiedModule();
-  HloComputation* entry_computation =
-      module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(RunPass(module.get()));
-  EXPECT_THAT(entry_computation->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
-}
-
-// Extracted from Resnet-50.
-//
-// For simplicity, we focus on the column dimension and ignore other dimensions.
-// We use [?] to represent the shape instead of the content.
-//
-// Suppose operator FC does
-//   [4] = conv([14], [3], stride=2, padding_high=1)  // Padding::kSame
-//
-// BC = BackwardInput(FC) does:
-//   [14] = conv([7], reverse([3]),
-//               padding_low=2, padding_high=1, base_dilation=2)
-//
-// We should fuse BC even though padding on activations is uneven, because
-// GpuConvPaddingLegalization will canonicalize the fusion HLO.
-TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) {
-  auto builder = HloComputation::Builder(TestName());
-  // The gradients are in NCHW layout.
-  HloInstruction* output =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output"));
-  // The kernel is in HWIO layout.
-  HloInstruction* kernel =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel"));
-  HloInstruction* reverse_kernel = builder.AddInstruction(
-      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
-
-  Window conv_window = default_conv_window_;
-  WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
-  forward_conv_col_dim->set_size(3);
-  forward_conv_col_dim->set_padding_low(2);
-  forward_conv_col_dim->set_padding_high(1);
-  forward_conv_col_dim->set_base_dilation(2);
-  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
-      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
-      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
-  // Verify the convolution's shape is consistent with ShapeInference.
-  CHECK(ShapeUtil::Compatible(
-      conv->shape(), ShapeInference::InferConvolveShape(
-                         output->shape(), reverse_kernel->shape(),
-                         /*feature_group_count=*/1, /*batch_group_count=*/1,
-                         conv_window, tf_default_dnums_for_backward_input_,
-                         /*preferred_element_type=*/std::nullopt)
-                         .value()));
-
-  auto module = CreateNewVerifiedModule();
-  const HloComputation* entry_computation =
-      module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(RunPass(module.get()));
-  ASSERT_THAT(entry_computation->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
-  const WindowDimension& backward_conv_col_dim =
-      entry_computation->root_instruction()->operand(0)->window().dimensions(1);
-  EXPECT_EQ(0, backward_conv_col_dim.padding_low());
-  EXPECT_EQ(1, backward_conv_col_dim.padding_high());
-}
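A worked size check for the BC shape described in the comment above this test (assumed arithmetic only; none of these constants appear in the removed file): dilating the 7-element gradients by 2, padding by (2, 1), and convolving with a 3-wide reversed kernel reproduces the 14-element activation extent.

constexpr int kGradExtent = 7, kBaseDilation = 2, kKernelExtent = 3;
constexpr int kPadLow = 2, kPadHigh = 1;
constexpr int kDilatedExtent = (kGradExtent - 1) * kBaseDilation + 1;  // 13
constexpr int kPaddedExtent = kDilatedExtent + kPadLow + kPadHigh;     // 16
constexpr int kOutputExtent = kPaddedExtent - (kKernelExtent - 1);     // 14
static_assert(kOutputExtent == 14, "matches the [14] activations above");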
-
-// For simplicity, we focus on the column dimension and ignore other dimensions.
-// We use [?] to represent the shape instead of the content.
-//
-// Suppose operator FC does
-//   [3] = conv([4], [2], padding_low=1, padding_high=-1)
-//
-// BC = BackwardInput(FC) does:
-//   [4] = conv([3], reverse([2]), padding_high=2)
-//
-// We currently don't fuse BC because GpuConvPaddingLegalization
-// doesn't support negative padding on the gradients of backward convolution
-// (b/32744257).
-TEST_F(GpuConvRewriterTest,
-       BackwardInputConvolveNegativePaddingHighOnActivations) {
-  auto builder = HloComputation::Builder(TestName());
-  // The gradients are in NCHW layout.
-  HloInstruction* output =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
-  // The kernel is in HWIO layout.
-  HloInstruction* kernel =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel"));
-  HloInstruction* reverse_kernel = builder.AddInstruction(
-      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
-
-  Window conv_window = default_conv_window_;
-  WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
-  forward_conv_col_dim->set_size(2);
-  forward_conv_col_dim->set_padding_high(2);
-  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
-      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
-      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
-  // Verify the convolution's shape is consistent with ShapeInference.
-  CHECK(ShapeUtil::Compatible(
-      conv->shape(), ShapeInference::InferConvolveShape(
-                         output->shape(), reverse_kernel->shape(),
-                         /*feature_group_count=*/1, /*batch_group_count=*/1,
-                         conv_window, tf_default_dnums_for_backward_input_,
-                         /*preferred_element_type=*/std::nullopt)
-                         .value()));
-
-  auto module = CreateNewVerifiedModule();
-  HloComputation* entry_computation =
-      module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(RunPass(module.get()));
-  EXPECT_THAT(entry_computation->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
-}
-
-// Check that we will materialize a reversed version of a constant in order to
-// pattern-match a backwards input convolution.
-TEST_F(GpuConvRewriterTest, BackwardInputConvolveConstantFilter) {
-  Array4D<float> constant_arr(4, 4, 2, 2);
-  constant_arr.FillIota(0);
-  std::string constant_str =
-      LiteralUtil::CreateR4FromArray4D(constant_arr).ToStringWithoutShape();
-
-  const std::string module_str = absl::StrFormat(R"(
-    HloModule test
-
-    ENTRY entry_computation {
-      param0 = f32[128,2,16,16]{3,2,1,0} parameter(0)
-      constant = f32[4,4,2,2]{3,2,1,0} constant(%s)
-      ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant),
-          window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2},
-          dim_labels=bf01_01oi->bf01, feature_group_count=1
-    })",
-                                                 constant_str);
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  EXPECT_TRUE(RunPass(m.get()));
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvBackwardInputCallTarget},
-                                m::Parameter(), m::Reverse(m::Constant())),
-                  0)));
-}
-
-TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternMatch) {
-  // All filter dimensions are larger than the corresponding output dimensions.
-  // This must be a backward filter convolution.
-  const std::string module_str = absl::StrFormat(R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = f32[8,120,256,256] parameter(0)
-      filter = f32[8,120,256,256] parameter(1)
-
-      ROOT conv = f32[120,120,3,3] convolution(input, filter), window={size=256x256 pad=1_1x1_1}, dim_labels=fb01_io01->fb01
-    })");
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  EXPECT_TRUE(RunPass(m.get()));
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvBackwardFilterCallTarget},
-                                m::Parameter(0), m::Parameter(1)),
-                  0)));
-}
-
-TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternNoMatch) {
-  // At least one filter dimension is smaller than the corresponding output
-  // dimension. This must be a forward convolution.
-  const std::string module_str = absl::StrFormat(R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = f32[8,128,2,32] parameter(0)
-      filter = f32[3,3,128,128] parameter(1)
-
-      ROOT conv = f32[8,128,2,32] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
-    })");
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  EXPECT_TRUE(RunPass(m.get()));
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvForwardCallTarget}, m::Parameter(0),
-                                m::Parameter(1)),
-                  0)));
-}
-
-TEST_F(GpuConvRewriterTest, TestConv1dBackwardFilterPatternMatch) {
-  // There exists one kernel dimension equal to the corresponding output
-  // dimension; regard the conv as a backward filter convolution if it is 1D.
-  const std::string module_str = absl::StrFormat(R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = f32[8,256,128] parameter(0)
-      filter = f32[8,254,128] parameter(1)
-      reshape.1 = f32[8,1,256,128] reshape(input)
-      reshape.2 = f32[8,1,254,128] reshape(filter)
-      ROOT conv = f32[1,3,128,128] convolution(reshape.1, reshape.2), window={size=1x254}, dim_labels=f01b_i01o->01bf
-    })");
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  EXPECT_TRUE(RunPass(m.get()));
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvBackwardFilterCallTarget},
-                                m::Reshape(), m::Reshape()),
-                  0)));
-}
-
-TEST_F(GpuConvRewriterTest, TestConv1dBackwardInputPatternMatch) {
-  // For conv1d backward input, the filter may be reversed first and then
-  // reshaped.
-  const std::string module_str = absl::StrFormat(R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = f32[8,254,128] parameter(0)
-      filter = f32[3,128,128] parameter(1)
-      reverse = f32[3,128,128] reverse(filter), dimensions={0}
-      reshape.1 = f32[8,1,254,128] reshape(input)
-      reshape.2 = f32[1,3,128,128] reshape(reverse)
-      ROOT conv = f32[8,1,256,128] convolution(reshape.1, reshape.2), window={size=1x3 pad=0_0x2_2}, dim_labels=b01f_01oi->b01f
-    })");
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  EXPECT_TRUE(RunPass(m.get()));
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                  m::CustomCall({kCudnnConvBackwardInputCallTarget},
-                                m::Reshape(), m::Reshape()),
-                  0)));
-}
-
-TEST_F(GpuConvRewriterTest, TestInvalidTypes) {
-  const std::string module_str = absl::StrFormat(R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = TYPE[1,17,9,9] parameter(0)
-      filter = TYPE[3,3,17,32] parameter(1)
-      ROOT conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-    })");
-
-  // Test complex types
-  for (std::string_view type : {"c64", "c128"}) {
-    const std::string module_with_type =
-        absl::StrReplaceAll(module_str, {{"TYPE", type}});
-    TF_ASSERT_OK_AND_ASSIGN(auto m,
-                            ParseAndReturnVerifiedModule(module_with_type));
-
-    absl::Status s =
-        GpuConvRewriter(GetComputeCapability()).Run(m.get()).status();
-    EXPECT_THAT(
-        s, tsl::testing::StatusIs(
-               absl::StatusCode::kUnimplemented,
-               ::testing::HasSubstr("Convolutions must have floating-point or "
-                                    "integral operands/outputs")));
-  }
-
-  // Test FP8 type on unsupported GPUs
-  std::string module_with_type =
-      absl::StrReplaceAll(module_str, {{"TYPE", "f8e4m3fn"}});
-  TF_ASSERT_OK_AND_ASSIGN(auto m,
-                          ParseAndReturnVerifiedModule(module_with_type));
-  absl::Status s = GpuConvRewriter(se::CudaComputeCapability::Ampere())
-                       .Run(m.get())
-                       .status();
-  EXPECT_THAT(s, tsl::testing::StatusIs(
-                     absl::StatusCode::kUnimplemented,
-                     ::testing::HasSubstr(
-                         "FP8 convolutions are only supported on CUDA "
-                         "GPUs with compute capability at least 9.0")));
-  s = GpuConvRewriter(se::RocmComputeCapability{"gfx942"})
-          .Run(m.get())
-          .status();
-  EXPECT_THAT(s, tsl::testing::StatusIs(
-                     absl::StatusCode::kUnimplemented,
-                     ::testing::HasSubstr(
-                         "FP8 convolutions are only supported on CUDA GPUs")));
-
-  // Test unsupported FP8 type
-  module_with_type = absl::StrReplaceAll(module_str, {{"TYPE", "f8e4m3fnuz"}});
-  TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(module_with_type));
-  s = GpuConvRewriter(GetComputeCapability()).Run(m.get()).status();
-  EXPECT_THAT(s,
-              tsl::testing::StatusIs(
-                  absl::StatusCode::kUnimplemented,
-                  ::testing::HasSubstr("The only FP8 types supported in "
-                                       "convolutions are f8e5m2 and f8e4m3")));
-}
-
-}  // anonymous namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc b/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc
deleted file mode 100644
index b8c87e2..0000000
--- a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_convert_async_collectives_to_sync.h"
-
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/status/status.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-absl::Status GpuConvertAsyncCollectivesToSync::ConvertAsyncInstructionsToSync(
-    HloComputation* computation,
-    absl::Span<const std::pair<HloInstruction*, HloInstruction*>> async_pairs)
-    const {
-  absl::flat_hash_map<HloInstruction*, HloInstruction*> replaced_ops;
-  CollectiveBackendConfig sync_config;
-  sync_config.set_is_sync(true);
-  for (auto& [async_start, async_done] : async_pairs) {
-    // Tag the async start with is_sync = true.
-    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                        async_start->backend_config<GpuBackendConfig>());
-    *gpu_config.mutable_collective_backend_config() = sync_config;
-    TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config));
-    replaced_ops[async_start] = nullptr;
-    replaced_ops[async_done] = async_start;
-  }
-
-  // Update schedule.
-  HloModule* module = computation->parent();
-  const HloInstructionSequence& sequence =
-      module->schedule().sequence(computation);
-  std::vector<HloInstruction*> new_sequence;
-  new_sequence.reserve(sequence.size());
-  for (HloInstruction* instr : sequence.instructions()) {
-    auto it = replaced_ops.find(instr);
-    // If it's neither a start nor a done op, add it to the new schedule.
-    if (it == replaced_ops.end()) {
-      new_sequence.push_back(instr);
-      continue;
-    }
-
-    // If it's a start op, do not add it to the schedule yet.
-    if (it->second == nullptr) {
-      continue;
-    }
-
-    // It's a done op: add the start first and then the done.
-    new_sequence.push_back(it->second);
-    new_sequence.push_back(instr);
-  }
-  module->schedule().set_sequence(computation, new_sequence);
-  return absl::OkStatus();
-}
-
-}  // namespace gpu
-}  // namespace xla
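A condensed sketch of the schedule rewrite performed by the deleted ConvertAsyncInstructionsToSync above: converted starts are deferred and re-emitted immediately before their matching done, so each pair runs back-to-back. Types are simplified; `replaced_ops` maps a start to nullptr and a done to its start, exactly as in the removed code.

#include <vector>
#include "absl/container/flat_hash_map.h"
#include "xla/hlo/ir/hlo_instruction.h"

// Sketch of rebuilding the instruction sequence for one computation.
std::vector<xla::HloInstruction*> ResequenceSyncPairsSketch(
    const std::vector<xla::HloInstruction*>& sequence,
    const absl::flat_hash_map<xla::HloInstruction*, xla::HloInstruction*>&
        replaced_ops) {
  std::vector<xla::HloInstruction*> out;
  out.reserve(sequence.size());
  for (xla::HloInstruction* instr : sequence) {
    auto it = replaced_ops.find(instr);
    if (it == replaced_ops.end()) {
      out.push_back(instr);       // untouched op: keep its slot
    } else if (it->second != nullptr) {
      out.push_back(it->second);  // emit the start right before ...
      out.push_back(instr);       // ... its matching done
    }                             // else: a start, deferred until its done
  }
  return out;
}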
diff --git a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.h b/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.h
deleted file mode 100644
index ea56f7a..0000000
--- a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_
-#define XLA_SERVICE_GPU_GPU_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_
-
-#include <utility>
-
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/convert_async_collectives_to_sync.h"
-
-namespace xla {
-namespace gpu {
-
-class GpuConvertAsyncCollectivesToSync : public ConvertAsyncCollectivesToSync {
- public:
-  using ConvertAsyncCollectivesToSync::ConvertAsyncCollectivesToSync;
-  absl::string_view name() const override {
-    return "gpu-convert-async-collectives-to-sync";
-  }
-
-  absl::Status ConvertAsyncInstructionsToSync(
-      HloComputation* computation,
-      absl::Span<const std::pair<HloInstruction*, HloInstruction*>> async_pairs)
-      const override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GPU_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc b/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc
deleted file mode 100644
index 03f18bd..0000000
--- a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc
+++ /dev/null
@@ -1,347 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_convert_async_collectives_to_sync.h"
-
-#include <string_view>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::testing::IsFalse;
-using ::testing::IsTrue;
-
-// Note: The pass only processes modules that are already scheduled. If the test
-// does not work as expected, make sure to check if "is_scheduled=true" is added
-// to the HLO module string.
-class GpuConvertAsyncCollectivesToSyncTest : public HloTestBase {
- public:
-  absl::Status RunPass(HloModule *module, bool expect_change,
-                       HloPredicate is_nop = {}) {
-    TF_ASSIGN_OR_RETURN(bool changed,
-                        GpuConvertAsyncCollectivesToSync{is_nop}.Run(module));
-    EXPECT_EQ(changed, expect_change);
-    return absl::OkStatus();
-  }
-
-  // Returns true if the instruction with the given name is synchronous.
-  bool IsSync(HloModule *module, std::string_view name) {
-    const HloInstruction *inst = FindInstruction(module, name);
-    if (inst == nullptr) {
-      return false;
-    }
-    auto backend_config = inst->backend_config<GpuBackendConfig>()
-                              .value()
-                              .collective_backend_config();
-    return backend_config.is_sync();
-  }
-
-  HloPredicate is_nop_simple_ =
-      HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kGetTupleElement,
-                       HloOpcode::kParameter>;
-};
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduce) {
-  const absl::string_view hlo_string = R"(
-      HloModule test, is_scheduled=true
-
-      apply_op {
-        x = u32[] parameter(0)
-        y = u32[] parameter(1)
-        ROOT apply_op = u32[] add(x, y)
-      }
-
-      ENTRY test_computation {
-        id = u32[] replica-id()
-        start = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
-        ROOT done = u32[] all-reduce-done(start)
-      }
-    )";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
-  EXPECT_THAT(IsSync(module.get(), "start"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduceWithNop) {
-  const absl::string_view hlo_string = R"(
-      HloModule test, is_scheduled=true
-
-      apply_op {
-        x = u32[] parameter(0)
-        y = u32[] parameter(1)
-        ROOT apply_op = u32[] add(x, y)
-      }
-
-      ENTRY test_computation {
-        id = u32[] replica-id()
-        start = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3, replica_groups={{0,1}, {2,3}}
-        id2 = f32[] bitcast(id)
-        ROOT done = u32[] all-reduce-done(start)
-      }
-    )";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true, is_nop_simple_));
-  EXPECT_THAT(IsSync(module.get(), "start"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleCollectiveBroadcast) {
-  const absl::string_view hlo_string = R"(
-  HloModule test, is_scheduled=true
-
-  collective_broadcast {
-    p0 = u32[8] parameter(0)
-    ROOT result = u32[8] collective-broadcast(p0), replica_groups={{0,1}, {2,3}}
-  }
-
-  ENTRY main {
-    data = u32[8] parameter(0)
-    cb-start = ((u32[8]{0}), u32[8]{0}) async-start(u32[8]{0} %data), calls=collective_broadcast
-    ROOT %ars = u32[8]{0} async-done(((u32[8]{0}), u32[8]{0}) %cb-start), calls=collective_broadcast
-  }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
-  EXPECT_THAT(IsSync(module.get(), "cb-start"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduceWithNonNop) {
-  const absl::string_view hlo_string = R"(
-      HloModule test, is_scheduled=true
-
-      apply_op {
-        x = u32[] parameter(0)
-        y = u32[] parameter(1)
-        ROOT apply_op = u32[] add(x, y)
-      }
-
-      ENTRY test_computation {
-        id = u32[] replica-id()
-        start = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
-        id2 = u32[] add(id, id)
-        ROOT done = u32[] all-reduce-done(start)
-      }
-    )";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/false));
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllGather) {
-  const absl::string_view hlo_string = R"(
-  HloModule test, is_scheduled=true
-  ENTRY test_computation {
-    a1 = u32[1, 2] parameter(0)
-    ags = (u32[1, 2], u32[2, 2]) all-gather-start(a1), dimensions={0}, channel_id=3
-    ROOT allgather = u32[2,2] all-gather-done(ags)
-  })";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
-  EXPECT_THAT(IsSync(module.get(), "ags"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleCollectivePermute) {
-  const absl::string_view hlo_string = R"(
-  HloModule test, is_scheduled=true
-
-  ENTRY test_computation {
-    p = u32[2] parameter(0)
-    start = (u32[2], u32[2], u32[], u32[]) collective-permute-start(p), source_target_pairs={{0,1}, {1,0}}
-    ROOT done = u32[2] collective-permute-done(start)
-  })";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
-  EXPECT_THAT(IsSync(module.get(), "start"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleReduceScatter) {
-  const absl::string_view hlo_string = R"(
-  HloModule test, is_scheduled=true
-
-  add {
-    lhs = u32[] parameter(0)
-    rhs = u32[] parameter(1)
-    ROOT add = u32[] add(lhs, rhs)
-  }
-
-  reduce_scatter {
-    p0 = u32[8] parameter(0)
-    ROOT result = u32[4] reduce-scatter(p0), replica_groups={{0,3}, {1,2}},
-                      dimensions={0}, to_apply=add
-  }
-
-  ENTRY main {
-    data = u32[8] parameter(0)
-    rs-start = ((u32[8]{0}), u32[4]{0}) async-start(u32[8]{0} %data), calls=reduce_scatter
-    ROOT %ars = u32[4]{0} async-done(((u32[8]{0}), u32[4]{0}) %rs-start), calls=reduce_scatter
-  }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
-  EXPECT_THAT(IsSync(module.get(), "rs-start"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllToAll) {
-  const absl::string_view hlo_string = R"(
-  HloModule test, is_scheduled=true
-
-  all_to_all {
-    p0 = u32[2] parameter(0)
-    ROOT result = u32[2] all-to-all(p0), dimensions={0}, replica_groups={{0,1},{2,3}}
-  }
-
-  ENTRY test_computation {
-    a1 = u32[2] parameter(0)
-    a2a-start = ((u32[2]), u32[2]) async-start(u32[2] a1), calls=all_to_all
-    ROOT a2s = u32[2] async-done(a2a-start), calls=all_to_all
-  }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
-  EXPECT_THAT(IsSync(module.get(), "a2a-start"), IsTrue());
-}
-
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, ControlDeps) {
-  const absl::string_view hlo_string = R"(
-      HloModule test, is_scheduled=true
-
-      apply_op {
-        x = u32[] parameter(0)
-        y = u32[] parameter(1)
-        ROOT apply_op = u32[] add(x, y)
-      }
-
-      ENTRY test_computation {
-        id = u32[] replica-id()
-        start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
-        done1 = u32[] all-reduce-done(start1)
-        start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4, control-predecessors={done1}
-        done2 = u32[] all-reduce-done(start2)
-        ROOT x = u32[] add(done1, done2)
-      }
-    )";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
-  EXPECT_THAT(IsSync(module.get(), "start1"), IsTrue());
-  EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
-}
-
-// Test multiple in-flight collectives that are ordered in a streaming fashion:
-// i.e., ends are in start order (FIFO).
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, MultipleInFlightStreaming) {
-  const absl::string_view hlo_string = R"(
-      HloModule test, is_scheduled=true
-
-      apply_op {
-        x = u32[] parameter(0)
-        y = u32[] parameter(1)
-        ROOT apply_op = u32[] add(x, y)
-      }
-
-      ENTRY test_computation {
-        id = u32[] replica-id()
-        start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
-        start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4
-        done1 = u32[] all-reduce-done(start1)
-        done2 = u32[] all-reduce-done(start2)
-        ROOT x = u32[] add(done1, done2)
-      }
-    )";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
-  EXPECT_THAT(IsSync(module.get(), "start1"), IsTrue());
-  EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
-}
-
-// Test multiple in-flight collectives that are nested: {s0,{s1,e1},e0}
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, MultipleInFlightNested) {
-  const absl::string_view hlo_string = R"(
-      HloModule test, is_scheduled=true
-
-      apply_op {
-        x = u32[] parameter(0)
-        y = u32[] parameter(1)
-        ROOT apply_op = u32[] add(x, y)
-      }
-
-      ENTRY test_computation {
-        id = u32[] replica-id()
-        start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
-        start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4
-        done2 = u32[] all-reduce-done(start2)
-        done1 = u32[] all-reduce-done(start1)
-        ROOT x = u32[] add(done1, done2)
-      }
-    )";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
-  EXPECT_THAT(IsSync(module.get(), "start1"), IsTrue());
-  EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
-}
-
-// Test multiple in-flight collectives that are nested: {s0,{s1,e1},e0} where
-// inner pair can be converted but not outer.
-TEST_F(GpuConvertAsyncCollectivesToSyncTest, MultipleInFlightNestedPartial) {
-  const absl::string_view hlo_string = R"(
-      HloModule test, is_scheduled=true
-
-      apply_op {
-        x = u32[] parameter(0)
-        y = u32[] parameter(1)
-        ROOT apply_op = u32[] add(x, y)
-      }
-
-      ENTRY test_computation {
-        id = u32[] replica-id()
-        start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
-        start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4
-        done2 = u32[] all-reduce-done(start2)
-        id2 = u32[] add(done2, done2)
-        done1 = u32[] all-reduce-done(start1)
-        ROOT x = u32[] add(done1, done2)
-      }
-    )";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
-  EXPECT_THAT(IsSync(module.get(), "start1"), IsFalse());
-  EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc
index bf97747..25fe651 100644
--- a/third_party/xla/xla/service/gpu/gpu_executable.cc
+++ b/third_party/xla/xla/service/gpu/gpu_executable.cc
@@ -841,9 +841,14 @@
     TF_ASSIGN_OR_RETURN(globals, ResolveConstantGlobals(run_options->stream()));
   }
 
-  auto device_ordinal = executor->device_ordinal();
+  // Use the `device_ordinal` from the `run_options` if it is provided. This is
+  // the ordinal of the logical device (e.g., a virtual GPU). If it is not
+  // provided, the logical and physical device ordinals are the same.
+  const int device_ordinal = run_options->device_ordinal() != -1
+                                 ? run_options->device_ordinal()
+                                 : executor->device_ordinal();
   ExecutionOutput result(/*on_device_shape=*/output_shape_, memory_allocator,
-                         device_ordinal);
+                         device_ordinal, executor->device_ordinal());
 
   TF_ASSIGN_OR_RETURN(
       BufferAllocations buffer_allocations,
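
Read in isolation, the ordinal selection introduced above can be sketched as a tiny helper (illustrative only, not part of the patch; -1 is taken to mean that the caller did not supply a logical ordinal):

// Hypothetical restatement of the selection rule in the hunk above.
// Prefer the logical (e.g. virtual GPU) ordinal when the caller provided one;
// otherwise the logical and physical ordinals coincide.
int ResolveDeviceOrdinal(int run_options_ordinal, int physical_ordinal) {
  return run_options_ordinal != -1 ? run_options_ordinal : physical_ordinal;
}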
diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc
deleted file mode 100644
index 566c006..0000000
--- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc
+++ /dev/null
@@ -1,719 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_fused_mha_runner.h"
-
-#include <cstdint>
-#include <optional>
-#include <string>
-
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "Eigen/Core"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/lazy_op_runner.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-using se::DeviceMemory;
-using se::DeviceMemoryBase;
-using se::dnn::DataType;
-using se::dnn::MatmulTensorDescriptor;
-using se::dnn::TensorDescriptor;
-
-template <typename ElementType, typename BiasType, typename OutputType>
-absl::Status RunFusedMHA(GpufMHAParams params, se::Stream *stream,
-                         RunFusedMHAOptions options,
-                         DeviceMemory<ElementType> lhs_bmm1_buffer,
-                         DeviceMemory<ElementType> rhs_bmm1_buffer,
-                         DeviceMemory<ElementType> rhs_bmm2_buffer,
-                         DeviceMemory<OutputType> output_buffer,
-                         DeviceMemoryBase bias_buffer,
-                         DeviceMemoryBase scratch_memory,
-                         DeviceMemoryBase activation_output,
-                         DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) {
-  se::dnn::LazyOpRunner<se::dnn::FusedMHAOp> *lazy_runner =
-      options.runner_cache->AsFusedMHARunner();
-  std::optional<se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>> local_runner;
-  if (!lazy_runner) {
-    local_runner.emplace(params.config->algorithm);
-    lazy_runner = &*local_runner;
-  }
-  std::optional<double> dropout_rate;
-  if (params.config->dropout_rate) {
-    dropout_rate = *params.config->dropout_rate;
-  }
-
-  std::optional<int64_t> seed;
-  if (params.config->seed) {
-    seed = *params.config->seed;
-  }
-
-  TF_ASSIGN_OR_RETURN(se::dnn::FusedMHAOp::Config config,
-                      params.config->AsDnnFusedMHAOpConfig());
-  TF_ASSIGN_OR_RETURN(auto *runner,
-                      lazy_runner->GetOrCreateRunner(config, stream));
-  return (*runner)(stream, options.profile_result, scratch_memory,
-                   lhs_bmm1_buffer, rhs_bmm1_buffer, rhs_bmm2_buffer,
-                   output_buffer, bias_buffer, activation_output, seqlen_q,
-                   seqlen_k);
-}
-
-template <typename ElementType, typename BiasType, typename OutputType>
-absl::Status RunGpuFMHAImpl(const GpufMHAParams &params, se::Stream *stream,
-                            se::DeviceMemoryBase scratch_memory,
-                            RunFusedMHAOptions options) {
-  auto lhs_bmm1_buffer = se::DeviceMemory<ElementType>(params.lhs_bmm1_buffer);
-  auto rhs_bmm1_buffer = se::DeviceMemory<ElementType>(params.rhs_bmm1_buffer);
-  auto rhs_bmm2_buffer = se::DeviceMemory<ElementType>(params.rhs_bmm2_buffer);
-  auto output_buffer = se::DeviceMemory<OutputType>(params.output_buffer);
-  auto activation_buffer =
-      params.activation_buffer.has_value()
-          ? se::DeviceMemory<OutputType>(*params.activation_buffer)
-          : se::DeviceMemoryBase();
-  auto bias_buffer = params.bias_buffer.has_value()
-                         ? se::DeviceMemory<BiasType>(*params.bias_buffer)
-                         : se::DeviceMemoryBase();
-  auto seqlen_q_buffer =
-      params.seqlen_q_buffer.has_value()
-          ? se::DeviceMemory<BiasType>(*params.seqlen_q_buffer)
-          : se::DeviceMemoryBase();
-  auto seqlen_k_buffer =
-      params.seqlen_k_buffer.has_value()
-          ? se::DeviceMemory<BiasType>(*params.seqlen_k_buffer)
-          : se::DeviceMemoryBase();
-  se::dnn::AlgorithmDesc algorithm = params.config->algorithm;
-  if (options.runner_cache) {
-    algorithm = options.runner_cache->ToAlgorithmDesc();
-  }
-
-  absl::Status run_status = absl::OkStatus();
-  switch (params.config->kind) {
-    case CudnnfMHAKind::kSoftmaxDropout:
-    case CudnnfMHAKind::kSoftmax:
-    case CudnnfMHAKind::kScaleBiasSoftmax:
-    case CudnnfMHAKind::kScaleBiasSoftmaxDropout:
-      run_status = RunFusedMHA<ElementType, BiasType, OutputType>(
-          params, stream, options, lhs_bmm1_buffer, rhs_bmm1_buffer,
-          rhs_bmm2_buffer, output_buffer, bias_buffer, scratch_memory,
-          activation_buffer, seqlen_q_buffer, seqlen_k_buffer);
-      break;
-    default:
-      return Internal("Invalid cuDNN fMHA kind");
-  }
-
-  if (!run_status.ok()) {
-    return run_status;
-  }
-
-  if (!stream->ok()) {
-    return Internal("Unable to launch FMHA with type %s and algorithm %s",
-                    CudnnfMHAKindToString(params.config->kind),
-                    algorithm.ToString());
-  }
-
-  return absl::OkStatus();
-}
-
-template <typename ElementType, typename OutputType>
-absl::Status RunFusedMHABackward(
-    GpufMHABackwardParams params, se::Stream *stream,
-    RunFusedMHABackwardOptions options,
-    DeviceMemory<ElementType> bmm1_grad_gemm1_rhs_buffer,
-    DeviceMemory<ElementType> bmm1_grad_gemm2_rhs_buffer,
-    DeviceMemory<ElementType> bmm2_grad_gemm1_lhs_buffer,
-    DeviceMemory<ElementType> bmm2_grad_gemm2_rhs_buffer,
-    DeviceMemory<ElementType> d_output_buffer,
-    DeviceMemory<OutputType> d_bmm1_lhs_buffer,
-    DeviceMemory<OutputType> d_bmm1_rhs_buffer,
-    DeviceMemory<OutputType> d_bmm2_rhs_buffer, DeviceMemoryBase d_s_buffer,
-    DeviceMemoryBase d_bias_buffer, DeviceMemoryBase fwd_output_buffer,
-    DeviceMemoryBase bias_buffer, DeviceMemoryBase scratch_memory,
-    DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) {
-  se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp> *lazy_runner =
-      options.runner_cache->AsFusedMHABackwardRunner();
-  std::optional<se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>>
-      local_runner;
-  if (!lazy_runner) {
-    local_runner.emplace(params.config->algorithm);
-    lazy_runner = &*local_runner;
-  }
-  std::optional<double> dropout_rate;
-  if (params.config->dropout_rate) {
-    dropout_rate = *params.config->dropout_rate;
-  }
-
-  std::optional<int64_t> seed;
-  if (params.config->seed) {
-    seed = *params.config->seed;
-  }
-
-  TF_ASSIGN_OR_RETURN(se::dnn::FusedMHABackwardOp::Config config,
-                      params.config->AsDnnFusedMHABackwardOpConfig());
-  TF_ASSIGN_OR_RETURN(auto *runner,
-                      lazy_runner->GetOrCreateRunner(config, stream));
-  // TODO: pass in real softmax_sum, dQ_accum, fwd_output
-  return (*runner)(stream, options.profile_result, scratch_memory,
-                   bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer,
-                   bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer,
-                   d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer,
-                   d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer,
-                   fwd_output_buffer, bias_buffer, seqlen_q, seqlen_k);
-  return absl::OkStatus();
-}
-
-template <typename ElementType, typename BiasType, typename OutputType>
-absl::Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams &params,
-                                    se::Stream *stream,
-                                    se::DeviceMemoryBase scratch_memory,
-                                    RunFusedMHABackwardOptions options) {
-  auto bmm1_grad_gemm1_rhs_buffer =
-      se::DeviceMemory<ElementType>(params.bmm1_grad_gemm1_rhs_buffer);
-  auto bmm1_grad_gemm2_rhs_buffer =
-      se::DeviceMemory<ElementType>(params.bmm1_grad_gemm2_rhs_buffer);
-  auto bmm2_grad_gemm1_lhs_buffer =
-      se::DeviceMemory<ElementType>(params.bmm2_grad_gemm1_lhs_buffer);
-  auto bmm2_grad_gemm2_rhs_buffer =
-      se::DeviceMemory<ElementType>(params.bmm2_grad_gemm2_rhs_buffer);
-  auto d_output_buffer = se::DeviceMemory<ElementType>(params.d_output_buffer);
-  auto d_bmm1_lhs_buffer =
-      se::DeviceMemory<OutputType>(params.d_bmm1_lhs_buffer);
-  auto d_bmm1_rhs_buffer =
-      se::DeviceMemory<OutputType>(params.d_bmm1_rhs_buffer);
-  auto d_bmm2_rhs_buffer =
-      se::DeviceMemory<OutputType>(params.d_bmm2_rhs_buffer);
-
-  // optional buffers
-  auto d_s_buffer = params.d_s_buffer.has_value()
-                        ? se::DeviceMemory<OutputType>(*params.d_s_buffer)
-                        : se::DeviceMemoryBase();
-
-  auto d_bias_buffer = params.d_bias_buffer.has_value()
-                           ? se::DeviceMemory<OutputType>(*params.d_bias_buffer)
-                           : se::DeviceMemoryBase();
-
-  auto fwd_output_buffer =
-      params.fwd_output_buffer.has_value()
-          ? se::DeviceMemory<ElementType>(*params.fwd_output_buffer)
-          : se::DeviceMemoryBase();
-
-  auto bias_buffer = params.bias_buffer.has_value()
-                         ? se::DeviceMemory<BiasType>(*params.bias_buffer)
-                         : se::DeviceMemoryBase();
-
-  auto seqlen_q_buffer =
-      params.seqlen_q_buffer.has_value()
-          ? se::DeviceMemory<BiasType>(*params.seqlen_q_buffer)
-          : se::DeviceMemoryBase();
-
-  auto seqlen_k_buffer =
-      params.seqlen_k_buffer.has_value()
-          ? se::DeviceMemory<BiasType>(*params.seqlen_k_buffer)
-          : se::DeviceMemoryBase();
-
-  se::dnn::AlgorithmDesc algorithm = params.config->algorithm;
-  if (options.runner_cache) {
-    algorithm = options.runner_cache->ToAlgorithmDesc();
-  }
-
-  absl::Status run_status = absl::OkStatus();
-  switch (params.config->kind) {
-    case CudnnfMHAKind::kBackwardSoftmaxDropout:
-    case CudnnfMHAKind::kBackwardSoftmax:
-    case CudnnfMHAKind::kBackwardScaleBiasSoftmax:
-    case CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout:
-      run_status = RunFusedMHABackward<ElementType, OutputType>(
-          params, stream, options, bmm1_grad_gemm1_rhs_buffer,
-          bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer,
-          bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer,
-          d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer,
-          fwd_output_buffer, bias_buffer, scratch_memory, seqlen_q_buffer,
-          seqlen_k_buffer);
-      break;
-    default:
-      return Internal("Invalid cuDNN fMHA kind");
-  }
-
-  if (!run_status.ok()) {
-    return run_status;
-  }
-
-  if (!stream->ok()) {
-    return Internal("Unable to launch FMHA with type %s and algorithm %s",
-                    CudnnfMHAKindToString(params.config->kind),
-                    algorithm.ToString());
-  }
-
-  return run_status;
-}
-}  // namespace
-
-/*static*/ absl::StatusOr<GpufMHAConfig> GpufMHAConfig::For(
-    const GpufMHADescriptor &desc) {
-  // Get shapes from desc.
-  const Shape &lhs_bmm1_shape = desc.lhs_bmm1_shape;
-  const Shape &rhs_bmm1_shape = desc.rhs_bmm1_shape;
-  const Shape &rhs_bmm2_shape = desc.rhs_bmm2_shape;
-  const Shape &intermediate_lhs_bmm2_shape = desc.intermediate_lhs_bmm2_shape;
-  const Shape &output_shape = desc.output_shapes[0];
-
-  // Get DNN dtype from primitive types
-  TF_ASSIGN_OR_RETURN(
-      DataType lhs_bmm1_type,
-      GetDNNDataTypeFromPrimitiveType(lhs_bmm1_shape.element_type()));
-  TF_ASSIGN_OR_RETURN(
-      DataType rhs_bmm1_type,
-      GetDNNDataTypeFromPrimitiveType(rhs_bmm1_shape.element_type()));
-
-  TF_ASSIGN_OR_RETURN(
-      DataType rhs_bmm2_type,
-      GetDNNDataTypeFromPrimitiveType(rhs_bmm2_shape.element_type()));
-  TF_ASSIGN_OR_RETURN(DataType lhs_bmm2_type,
-                      GetDNNDataTypeFromPrimitiveType(
-                          intermediate_lhs_bmm2_shape.element_type()));
-  TF_ASSIGN_OR_RETURN(DataType output_type, GetDNNDataTypeFromPrimitiveType(
-                                                output_shape.element_type()));
-  GpufMHAConfig config;
-  config.input_type = lhs_bmm1_shape.element_type();
-  config.output_type = output_shape.element_type();
-
-  // Get MatmulTensorDescriptors for BMM1
-  config.lhs_bmm1 =
-      MatmulTensorDescriptor::For(lhs_bmm1_type, lhs_bmm1_shape.dimensions(),
-                                  desc.lhs_bmm1_shape.layout().minor_to_major(),
-                                  desc.bmm1_dnums.lhs_batch_dimensions(),
-                                  desc.bmm1_dnums.lhs_contracting_dimensions());
-  config.rhs_bmm1 =
-      MatmulTensorDescriptor::For(rhs_bmm1_type, rhs_bmm1_shape.dimensions(),
-                                  desc.rhs_bmm1_shape.layout().minor_to_major(),
-                                  desc.bmm1_dnums.rhs_batch_dimensions(),
-                                  desc.bmm1_dnums.rhs_contracting_dimensions());
-
-  // Get MatmulTensorDescriptors for BMM2
-  config.rhs_bmm2 =
-      MatmulTensorDescriptor::For(rhs_bmm2_type, rhs_bmm2_shape.dimensions(),
-                                  desc.rhs_bmm2_shape.layout().minor_to_major(),
-                                  desc.bmm2_dnums.rhs_batch_dimensions(),
-                                  desc.bmm2_dnums.rhs_contracting_dimensions());
-
-  config.intermediate_lhs_bmm2 = MatmulTensorDescriptor::For(
-      lhs_bmm2_type, intermediate_lhs_bmm2_shape.dimensions(),
-      desc.intermediate_lhs_bmm2_shape.layout().minor_to_major(),
-      desc.bmm2_dnums.lhs_batch_dimensions(),
-      desc.bmm2_dnums.lhs_contracting_dimensions());
-
-  config.output = TensorDescriptor::For(output_type, output_shape.dimensions(),
-                                        output_shape.layout().minor_to_major());
-
-  if (desc.output_shapes.size() > 1) {
-    const Shape &activation_shape = desc.output_shapes.back();
-    // Generally, activation should have the same type as the output, but set
-    // it explicitly just to be safe.
-    TF_ASSIGN_OR_RETURN(
-        DataType activation_type,
-        GetDNNDataTypeFromPrimitiveType(activation_shape.element_type()));
-    config.activation =
-        TensorDescriptor::For(activation_type, activation_shape.dimensions(),
-                              activation_shape.layout().minor_to_major());
-  }
-
-  if (desc.mask_shape) {
-    const Shape &mask_shape = *desc.mask_shape;
-    TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType(
-                                                mask_shape.element_type()));
-    config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(),
-                                        mask_shape.layout().minor_to_major());
-  }
-
-  if (desc.bias_shape) {
-    const Shape &bias_shape = *desc.bias_shape;
-    TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType(
-                                                bias_shape.element_type()));
-    config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(),
-                                        bias_shape.layout().minor_to_major());
-  }
-  config.kind = desc.kind;
-  config.mask_type = desc.mask_type;
-  const CudnnfMHABackendConfig &backend_config = desc.backend_config;
-  config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm());
-  config.fmha_scale.emplace(backend_config.fmha_scale());
-  config.dropout_rate.emplace(backend_config.dropout_rate());
-  config.seed.emplace(backend_config.seed());
-  return config;
-}
-
-absl::StatusOr<se::dnn::FusedMHAOp::Config>
-GpufMHAConfig::AsDnnFusedMHAOpConfig() const {
-  double scale = 1.0;
-  if (fmha_scale.has_value()) {
-    scale = *fmha_scale;
-  }
-  TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type,
-                      GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type));
-
-  return se::dnn::FusedMHAOp::Config{
-      scale,    lhs_bmm1, rhs_bmm1,   rhs_bmm2,     intermediate_lhs_bmm2,
-      output,   bias,     activation, dropout_rate, seed,
-      mask_type};
-}
-
-/*static*/ absl::StatusOr<GpufMHABackwardConfig> GpufMHABackwardConfig::For(
-    const GpufMHABackwardDescriptor &desc) {
-  // Get shapes from desc.
-
-  const Shape &bmm1_grad_gemm1_rhs_shape = desc.bmm1_grad_gemm1_rhs_shape;
-  const Shape &bmm1_grad_gemm2_rhs_shape = desc.bmm1_grad_gemm2_rhs_shape;
-  const Shape &bmm2_grad_gemm1_lhs_shape = desc.bmm2_grad_gemm1_lhs_shape;
-  const Shape &bmm2_grad_gemm2_rhs_shape = desc.bmm2_grad_gemm2_rhs_shape;
-  const Shape &d_output_shape = desc.d_output_shape;
-  const Shape &d_bmm1_lhs_shape = desc.d_bmm1_lhs_shape;
-  const Shape &d_bmm1_rhs_shape = desc.d_bmm1_rhs_shape;
-  const Shape &d_bmm2_rhs_shape = desc.d_bmm2_rhs_shape;
-  // Get DNN dtype from primitive types
-  TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm1_rhs_type,
-                      GetDNNDataTypeFromPrimitiveType(
-                          bmm1_grad_gemm1_rhs_shape.element_type()));
-
-  TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm2_rhs_type,
-                      GetDNNDataTypeFromPrimitiveType(
-                          bmm1_grad_gemm2_rhs_shape.element_type()));
-
-  TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm1_lhs_type,
-                      GetDNNDataTypeFromPrimitiveType(
-                          bmm2_grad_gemm1_lhs_shape.element_type()));
-
-  TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm2_rhs_type,
-                      GetDNNDataTypeFromPrimitiveType(
-                          bmm2_grad_gemm2_rhs_shape.element_type()));
-
-  TF_ASSIGN_OR_RETURN(
-      DataType d_output_type,
-      GetDNNDataTypeFromPrimitiveType(d_output_shape.element_type()));
-
-  TF_ASSIGN_OR_RETURN(
-      DataType d_bmm1_lhs_type,
-      GetDNNDataTypeFromPrimitiveType(d_bmm1_lhs_shape.element_type()));
-
-  TF_ASSIGN_OR_RETURN(
-      DataType d_bmm1_rhs_type,
-      GetDNNDataTypeFromPrimitiveType(d_bmm1_rhs_shape.element_type()));
-
-  TF_ASSIGN_OR_RETURN(
-      DataType d_bmm2_rhs_type,
-      GetDNNDataTypeFromPrimitiveType(d_bmm2_rhs_shape.element_type()));
-
-  GpufMHABackwardConfig config;
-  config.input_type = bmm1_grad_gemm1_rhs_shape.element_type();
-  config.output_type = d_bmm1_lhs_shape.element_type();
-
-  // Get MatmulTensorDescriptors for lhs of BMM1 grad GEMM 1
-  config.bmm1_grad_gemm1_rhs = MatmulTensorDescriptor::For(
-      bmm1_grad_gemm1_rhs_type, bmm1_grad_gemm1_rhs_shape.dimensions(),
-      desc.bmm1_grad_gemm1_rhs_shape.layout().minor_to_major(),
-      desc.bmm1_grad_gemm1_dnums.rhs_batch_dimensions(),
-      desc.bmm1_grad_gemm1_dnums.rhs_contracting_dimensions());
-
-  // Get MatmulTensorDescriptors for rhs of BMM1 grad GEMM 2
-  config.bmm1_grad_gemm2_rhs = MatmulTensorDescriptor::For(
-      bmm1_grad_gemm2_rhs_type, bmm1_grad_gemm2_rhs_shape.dimensions(),
-      desc.bmm1_grad_gemm2_rhs_shape.layout().minor_to_major(),
-      desc.bmm1_grad_gemm2_dnums.rhs_batch_dimensions(),
-      desc.bmm1_grad_gemm2_dnums.rhs_contracting_dimensions());
-
-  // Get MatmulTensorDescriptors for BMM2 grad GEMM 1
-  config.bmm2_grad_gemm1_lhs = MatmulTensorDescriptor::For(
-      bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(),
-      desc.bmm2_grad_gemm1_lhs_shape.layout().minor_to_major(),
-      desc.bmm2_grad_gemm1_dnums.lhs_batch_dimensions(),
-      desc.bmm2_grad_gemm1_dnums.lhs_contracting_dimensions());
-
-  config.d_output = MatmulTensorDescriptor::For(
-      d_output_type, d_output_shape.dimensions(),
-      desc.d_output_shape.layout().minor_to_major(),
-      desc.bmm2_grad_gemm1_dnums.rhs_batch_dimensions(),
-      desc.bmm2_grad_gemm1_dnums.rhs_contracting_dimensions());
-
-  // Get MatmulTensorDescriptors for BMM2 grad GEMM 2
-  config.bmm2_grad_gemm2_rhs = MatmulTensorDescriptor::For(
-      bmm2_grad_gemm2_rhs_type, bmm2_grad_gemm2_rhs_shape.dimensions(),
-      desc.bmm2_grad_gemm2_rhs_shape.layout().minor_to_major(),
-      desc.bmm2_grad_gemm2_dnums.rhs_batch_dimensions(),
-      desc.bmm2_grad_gemm2_dnums
-          .rhs_contracting_dimensions());  // FMHA TODO: transpose here?
-
-  config.d_bmm1_lhs =
-      TensorDescriptor::For(d_bmm1_lhs_type, d_bmm1_lhs_shape.dimensions(),
-                            d_bmm1_lhs_shape.layout().minor_to_major());
-  config.d_bmm1_rhs =
-      TensorDescriptor::For(d_bmm1_rhs_type, d_bmm1_rhs_shape.dimensions(),
-                            d_bmm1_rhs_shape.layout().minor_to_major());
-  config.d_bmm2_rhs =
-      TensorDescriptor::For(d_bmm2_rhs_type, d_bmm2_rhs_shape.dimensions(),
-                            d_bmm2_rhs_shape.layout().minor_to_major());
-  config.d_s = TensorDescriptor::For(
-      bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(),
-      bmm2_grad_gemm1_lhs_shape.layout().minor_to_major());
-
-  if (desc.d_bias_shape) {
-    const Shape &d_bias_shape = *desc.d_bias_shape;
-    // Get DNN dtype from primitive types
-    TF_ASSIGN_OR_RETURN(DataType d_bias_type, GetDNNDataTypeFromPrimitiveType(
-                                                  d_bias_shape.element_type()));
-    config.d_bias =
-        TensorDescriptor::For(d_bias_type, d_bias_shape.dimensions(),
-                              d_bias_shape.layout().minor_to_major());
-  }
-
-  if (desc.mask_shape) {
-    const Shape &mask_shape = *desc.mask_shape;
-    TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType(
-                                                mask_shape.element_type()));
-    config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(),
-                                        mask_shape.layout().minor_to_major());
-  }
-  if (desc.fwd_output_shape) {
-    const Shape &fwd_output_shape = *desc.fwd_output_shape;
-    TF_ASSIGN_OR_RETURN(
-        DataType fwd_output_type,
-        GetDNNDataTypeFromPrimitiveType(fwd_output_shape.element_type()));
-    config.fwd_output =
-        TensorDescriptor::For(fwd_output_type, fwd_output_shape.dimensions(),
-                              fwd_output_shape.layout().minor_to_major());
-  }
-
-  if (desc.bias_shape) {
-    const Shape &bias_shape = *desc.bias_shape;
-    TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType(
-                                                bias_shape.element_type()));
-    config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(),
-                                        bias_shape.layout().minor_to_major());
-  }
-
-  config.kind = desc.kind;
-  config.mask_type = desc.mask_type;
-  config.force_deterministic = desc.force_deterministic;
-  const CudnnfMHABackendConfig &backend_config = desc.backend_config;
-  config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm());
-  config.fmha_scale.emplace(backend_config.fmha_scale());
-  config.dropout_rate.emplace(backend_config.dropout_rate());
-  config.seed.emplace(backend_config.seed());
-  return config;
-}
-
-absl::StatusOr<se::dnn::FusedMHABackwardOp::Config>
-GpufMHABackwardConfig::AsDnnFusedMHABackwardOpConfig() const {
-  double scale = 1.0;
-  if (fmha_scale.has_value()) {
-    scale = *fmha_scale;
-  }
-  TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type,
-                      GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type));
-
-  return se::dnn::FusedMHABackwardOp::Config{scale,
-                                             bmm1_grad_gemm1_rhs,
-                                             bmm1_grad_gemm2_rhs,
-                                             bmm2_grad_gemm1_lhs,
-                                             bmm2_grad_gemm2_rhs,
-                                             d_output,
-                                             d_bmm1_lhs,
-                                             d_bmm1_rhs,
-                                             d_bmm2_rhs,
-                                             d_s,
-                                             d_bias,
-                                             fwd_output,
-                                             bias,
-                                             dropout_rate,
-                                             seed,
-                                             mask_type,
-                                             force_deterministic};
-}
-
-/*static*/ absl::StatusOr<GpufMHAParams> GpufMHAParams::For(
-    const GpufMHAConfig &config, se::DeviceMemoryBase lhs_bmm1_buffer,
-    se::DeviceMemoryBase rhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm2_buffer,
-    se::DeviceMemoryBase output_buffer,
-    std::optional<se::DeviceMemoryBase> bias_buffer,
-    std::optional<se::DeviceMemoryBase> activation_buffer,
-    std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
-    std::optional<se::DeviceMemoryBase> seqlen_k_buffer) {
-  GpufMHAParams params;
-  params.config = &config;
-  params.lhs_bmm1_buffer = lhs_bmm1_buffer;
-  params.rhs_bmm1_buffer = rhs_bmm1_buffer;
-  params.rhs_bmm2_buffer = rhs_bmm2_buffer;
-  params.output_buffer = output_buffer;
-  params.activation_buffer = activation_buffer;
-  params.bias_buffer = bias_buffer;
-  params.seqlen_q_buffer = seqlen_q_buffer;
-  params.seqlen_k_buffer = seqlen_k_buffer;
-  return params;
-}
-
-/*static*/ absl::StatusOr<GpufMHABackwardParams> GpufMHABackwardParams::For(
-    const GpufMHABackwardConfig &config,
-    se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer,
-    se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer,
-    se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer,
-    se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer,
-    se::DeviceMemoryBase d_output_buffer,
-    se::DeviceMemoryBase d_bmm1_lhs_buffer,
-    se::DeviceMemoryBase d_bmm1_rhs_buffer,
-    se::DeviceMemoryBase d_bmm2_rhs_buffer,
-    std::optional<se::DeviceMemoryBase> d_s_buffer,
-    std::optional<se::DeviceMemoryBase> d_bias_buffer,
-    std::optional<se::DeviceMemoryBase> fwd_output_buffer,
-    std::optional<se::DeviceMemoryBase> bias_buffer,
-    std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
-    std::optional<se::DeviceMemoryBase> seqlen_k_buffer) {
-  GpufMHABackwardParams params;
-  params.config = &config;
-  params.bmm1_grad_gemm1_rhs_buffer = bmm1_grad_gemm1_rhs_buffer;
-  params.bmm1_grad_gemm2_rhs_buffer = bmm1_grad_gemm2_rhs_buffer;
-  params.bmm2_grad_gemm1_lhs_buffer = bmm2_grad_gemm1_lhs_buffer;
-  params.bmm2_grad_gemm2_rhs_buffer = bmm2_grad_gemm2_rhs_buffer;
-  params.d_output_buffer = d_output_buffer;
-  params.d_bmm1_lhs_buffer = d_bmm1_lhs_buffer;
-  params.d_bmm1_rhs_buffer = d_bmm1_rhs_buffer;
-  params.d_bmm2_rhs_buffer = d_bmm2_rhs_buffer;
-  params.d_s_buffer = d_s_buffer;
-  params.d_bias_buffer = d_bias_buffer;
-  params.fwd_output_buffer = fwd_output_buffer;
-  params.bias_buffer = bias_buffer;
-  params.seqlen_q_buffer = seqlen_q_buffer;
-  params.seqlen_k_buffer = seqlen_k_buffer;
-  return params;
-}
-
-absl::Status RunGpuFMHA(const GpufMHAConfig &fmha_config,
-                        se::DeviceMemoryBase lhs_bmm1_buffer,
-                        se::DeviceMemoryBase rhs_bmm1_buffer,
-                        se::DeviceMemoryBase rhs_bmm2_buffer,
-                        se::DeviceMemoryBase output_buffer,
-                        se::DeviceMemoryBase scratch_buffer,
-                        std::optional<se::DeviceMemoryBase> bias_buffer,
-                        std::optional<se::DeviceMemoryBase> activation_buffer,
-                        std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
-                        std::optional<se::DeviceMemoryBase> seqlen_k_buffer,
-                        se::Stream *stream, RunFusedMHAOptions options) {
-  TF_ASSIGN_OR_RETURN(
-      GpufMHAParams params,
-      GpufMHAParams::For(fmha_config, lhs_bmm1_buffer, rhs_bmm1_buffer,
-                         rhs_bmm2_buffer, output_buffer, bias_buffer,
-                         activation_buffer, seqlen_q_buffer, seqlen_k_buffer));
-  PrimitiveType input_primitive_type = fmha_config.input_type;
-  switch (input_primitive_type) {
-    case F16:
-      return RunGpuFMHAImpl<Eigen::half, Eigen::half, Eigen::half>(
-          params, stream, scratch_buffer, options);
-    case BF16:
-      return RunGpuFMHAImpl<Eigen::bfloat16, Eigen::bfloat16, Eigen::bfloat16>(
-          params, stream, scratch_buffer, options);
-    default:
-      return absl::UnimplementedError(absl::StrFormat(
-          "Unimplemented fused MHA with %s", ToString(fmha_config)));
-  }
-  return absl::OkStatus();
-}
-
-absl::Status RunGpuFMHABackward(
-    const GpufMHABackwardConfig &fmha_config,
-    se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer,
-    se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer,
-    se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer,
-    se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer,
-    se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer,
-    se::DeviceMemoryBase d_bmm1_lhs_buffer,
-    se::DeviceMemoryBase d_bmm1_rhs_buffer,
-    se::DeviceMemoryBase d_bmm2_rhs_buffer,
-    std::optional<se::DeviceMemoryBase> d_s_buffer,
-    std::optional<se::DeviceMemoryBase> d_bias_buffer,
-    std::optional<se::DeviceMemoryBase> fwd_output_buffer,
-    std::optional<se::DeviceMemoryBase> bias_buffer,
-    std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
-    std::optional<se::DeviceMemoryBase> seqlen_k_buffer, se::Stream *stream,
-    RunFusedMHABackwardOptions options) {
-  TF_ASSIGN_OR_RETURN(
-      GpufMHABackwardParams params,
-      GpufMHABackwardParams::For(
-          fmha_config, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer,
-          bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer,
-          d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer,
-          d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, fwd_output_buffer,
-          bias_buffer, seqlen_q_buffer, seqlen_k_buffer));
-  PrimitiveType input_primitive_type = fmha_config.input_type;
-  switch (input_primitive_type) {
-    case F16:
-      return RunGpuFMHABackwardImpl<Eigen::half, Eigen::half, Eigen::half>(
-          params, stream, scratch_buffer, options);
-    case BF16:
-      return RunGpuFMHABackwardImpl<Eigen::bfloat16, Eigen::bfloat16,
-                                    Eigen::bfloat16>(params, stream,
-                                                     scratch_buffer, options);
-    default:
-      return Unimplemented("Unimplemented fused MHA backward");
-  }
-  return absl::OkStatus();
-}
-
-std::string ToString(const GpufMHAConfig &config) {
-  std::string result = "GpufMHAConfig:\n";
-  absl::StrAppend(&result,
-                  "input_type: ", PrimitiveType_Name(config.input_type), ", ");
-  absl::StrAppend(
-      &result, "output_type: ", PrimitiveType_Name(config.output_type), ", ");
-  absl::StrAppend(&result, "Kind: ", CudnnfMHAKindToString(config.kind), ", ");
-  if (config.fmha_scale) {
-    absl::StrAppend(&result, "fmha_scale: ", *config.fmha_scale, ", ");
-  }
-  if (config.dropout_rate) {
-    absl::StrAppend(&result, "dropout_rate: ", *config.dropout_rate, ", ");
-  }
-  if (config.seed) {
-    absl::StrAppend(&result, "seed: ", *config.seed, ", ");
-  }
-  absl::StrAppend(&result, "Algorithm Desc: ", config.algorithm.ToString(),
-                  "\n");
-  absl::StrAppend(&result, "lhs_bmm1: ", config.lhs_bmm1.ToString(), "\n");
-  absl::StrAppend(&result, "rhs_bmm1: ", config.rhs_bmm1.ToString(), "\n");
-  absl::StrAppend(&result, "rhs_bmm2: ", config.rhs_bmm2.ToString(), "\n");
-  absl::StrAppend(&result, "intermediate_lhs_bmm2: ",
-                  config.intermediate_lhs_bmm2.ToString(), "\n");
-  absl::StrAppend(&result, "output: ", config.output.ToString(), "\n");
-
-  if (config.mask) {
-    absl::StrAppend(&result, "mask: ", (*config.mask).ToString(), "\n");
-  }
-
-  if (config.bias) {
-    absl::StrAppend(&result, "bias: ", (*config.bias).ToString(), "\n");
-  }
-
-  return result;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h
deleted file mode 100644
index d0621cb..0000000
--- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h
+++ /dev/null
@@ -1,431 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_
-#define XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_
-
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <variant>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/stream_executor/lazy_op_runner.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-
-inline absl::StatusOr<xla::gpu::CudnnfMHAMaskKind> AsCudnnFmhaMaskKind(
-    xla::gpu::CudnnfMHABackendConfig_MaskType mask_type) {
-  switch (mask_type) {
-    case xla::gpu::CudnnfMHABackendConfig::NO_MASK:
-      return xla::gpu::CudnnfMHAMaskKind::kNoMask;
-    case xla::gpu::CudnnfMHABackendConfig::PADDING:
-      return xla::gpu::CudnnfMHAMaskKind::kPadding;
-    case xla::gpu::CudnnfMHABackendConfig::CAUSAL:
-      return xla::gpu::CudnnfMHAMaskKind::kCausal;
-    case xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL:
-      return xla::gpu::CudnnfMHAMaskKind::kPaddingCausal;
-    case xla::gpu::CudnnfMHABackendConfig::ALIBI:
-      return xla::gpu::CudnnfMHAMaskKind::kAlibi;
-    default:
-      return xla::Internal("Unknown fmha mask kind.");
-  }
-}
-
-// This is an interim structure to hold the parameters to construct a
-// GpufMHAConfig.
-// Struct to describe properties of a FMHA without being tied to specific
-// IR. Will be used to help build FMHA thunks from either XLA HLO or
-// LHLO GPU dialect in MLIR.
-struct GpufMHADescriptor {
-  CudnnfMHAKind kind;
-  CudnnfMHABackendConfig backend_config;
-  CudnnfMHAMaskKind mask_type;
-  Shape lhs_bmm1_shape;
-  Shape rhs_bmm1_shape;
-  Shape rhs_bmm2_shape;
-  Shape intermediate_lhs_bmm2_shape;
-  // This will contain both output shape and activation shape
-  absl::InlinedVector<Shape, 2> output_shapes;
-  DotDimensionNumbers bmm1_dnums;
-  DotDimensionNumbers bmm2_dnums;
-
-  std::optional<Shape> mask_shape;
-  std::optional<Shape> bias_shape;
-};
-
-struct GpufMHABackwardDescriptor {
-  CudnnfMHAKind kind;
-  CudnnfMHABackendConfig backend_config;
-  CudnnfMHAMaskKind mask_type;
-  Shape bmm1_grad_gemm1_rhs_shape;
-  Shape bmm1_grad_gemm2_rhs_shape;
-  Shape bmm2_grad_gemm1_lhs_shape;
-  Shape bmm2_grad_gemm2_rhs_shape;
-  Shape d_output_shape;
-  Shape d_bmm1_lhs_shape;
-  Shape d_bmm1_rhs_shape;
-  Shape d_bmm2_rhs_shape;
-  DotDimensionNumbers bmm1_grad_gemm1_dnums;
-  DotDimensionNumbers bmm1_grad_gemm2_dnums;
-  DotDimensionNumbers bmm2_grad_gemm1_dnums;
-  DotDimensionNumbers bmm2_grad_gemm2_dnums;
-
-  std::optional<Shape> d_s_shape;
-  std::optional<Shape> fwd_output_shape;
-  std::optional<Shape> mask_shape;
-  std::optional<Shape> d_bias_shape;
-  std::optional<Shape> bias_shape;
-  bool force_deterministic;
-};
-
-// Structure to describe static properties of a GPU fused Multi-Headed
-// Attention.
-struct GpufMHAConfig {
-  static absl::StatusOr<GpufMHAConfig> For(const GpufMHADescriptor& fmha_desc);
-
-  absl::StatusOr<se::dnn::FusedMHAOp::Config> AsDnnFusedMHAOpConfig() const;
-
-  PrimitiveType
-      input_type;  // Capture the primitive type of one of the inputs of BMM1
-  PrimitiveType output_type;
-  CudnnfMHAKind kind;
-  std::optional<double> fmha_scale;
-  std::optional<double> dropout_rate;
-  std::optional<int64_t> seed;
-
-  se::dnn::AlgorithmDesc algorithm;
-  CudnnfMHAMaskKind mask_type;
-  // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len]
-  // mask -> [batch_size, 1, q_seq_len, kv_seq_len]
-  se::dnn::MatmulTensorDescriptor lhs_bmm1;
-  se::dnn::MatmulTensorDescriptor rhs_bmm1;
-  se::dnn::MatmulTensorDescriptor rhs_bmm2;
-  se::dnn::MatmulTensorDescriptor intermediate_lhs_bmm2;
-  se::dnn::TensorDescriptor output;
-
-  std::optional<se::dnn::TensorDescriptor> activation;
-  std::optional<se::dnn::TensorDescriptor> mask;
-  std::optional<se::dnn::TensorDescriptor> bias;
-};
-
-// Structure to describe static properties of a GPU fused Multi-Headed
-// Attention backward.
-struct GpufMHABackwardConfig {
-  static absl::StatusOr<GpufMHABackwardConfig> For(
-      const GpufMHABackwardDescriptor& fmha_desc);
-
-  absl::StatusOr<se::dnn::FusedMHABackwardOp::Config>
-  AsDnnFusedMHABackwardOpConfig() const;
-
-  PrimitiveType
-      input_type;  // Capture the primitive type of one of the inputs of BMM1
-  PrimitiveType output_type;
-  CudnnfMHAKind kind;
-  std::optional<double> fmha_scale;
-  std::optional<double> dropout_rate;
-  std::optional<int64_t> seed;
-
-  se::dnn::AlgorithmDesc algorithm;
-  CudnnfMHAMaskKind mask_type;
-  // mask -> [batch_size, 1, q_seq_len, kv_seq_len]
-  // d_bias -> [1, num_heads, q_seq_len, kv_seq_len]
-  se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs;
-  se::dnn::MatmulTensorDescriptor bmm1_grad_gemm2_rhs;
-  se::dnn::MatmulTensorDescriptor bmm2_grad_gemm1_lhs;
-  se::dnn::MatmulTensorDescriptor bmm2_grad_gemm2_rhs;
-  se::dnn::MatmulTensorDescriptor d_output;
-  se::dnn::TensorDescriptor d_bmm1_lhs;
-  se::dnn::TensorDescriptor d_bmm1_rhs;
-  se::dnn::TensorDescriptor d_bmm2_rhs;
-  std::optional<se::dnn::TensorDescriptor> d_s;
-  std::optional<se::dnn::TensorDescriptor> mask;
-  std::optional<se::dnn::TensorDescriptor> d_bias;
-  std::optional<se::dnn::TensorDescriptor> fwd_output;
-  std::optional<se::dnn::TensorDescriptor> bias;
-  bool force_deterministic;
-};
-
-// Implementation struct exposed for debugging and log analysis.
-struct GpufMHAParams {
-  static absl::StatusOr<GpufMHAParams> For(
-      const GpufMHAConfig& config, se::DeviceMemoryBase lhs_bmm1_buffer,
-      se::DeviceMemoryBase rhs_bmm1_buffer,
-      se::DeviceMemoryBase rhs_bmm2_buffer, se::DeviceMemoryBase output_buffer,
-      std::optional<se::DeviceMemoryBase> bias_buffer,
-      std::optional<se::DeviceMemoryBase> activation_buffer,
-      std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
-      std::optional<se::DeviceMemoryBase> seqlen_k_buffer);
-
-  const GpufMHAConfig* config;  // Not owned
-  se::DeviceMemoryBase lhs_bmm1_buffer;
-  se::DeviceMemoryBase rhs_bmm1_buffer;
-  se::DeviceMemoryBase rhs_bmm2_buffer;
-  se::DeviceMemoryBase output_buffer;
-  std::optional<se::DeviceMemoryBase> activation_buffer;
-  std::optional<se::DeviceMemoryBase> bias_buffer;
-  std::optional<se::DeviceMemoryBase> seqlen_q_buffer;
-  std::optional<se::DeviceMemoryBase> seqlen_k_buffer;
-};
-
-struct GpufMHABackwardParams {
-  static absl::StatusOr<GpufMHABackwardParams> For(
-      const GpufMHABackwardConfig& config,
-      se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer,
-      se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer,
-      se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer,
-      se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer,
-      se::DeviceMemoryBase d_output_buffer,
-      se::DeviceMemoryBase d_bmm1_lhs_buffer,
-      se::DeviceMemoryBase d_bmm1_rhs_buffer,
-      se::DeviceMemoryBase d_bmm2_rhs_buffer,
-      std::optional<se::DeviceMemoryBase> d_s_buffer,
-      std::optional<se::DeviceMemoryBase> d_bias_buffer,
-      std::optional<se::DeviceMemoryBase> fwd_output_buffer,
-      std::optional<se::DeviceMemoryBase> bias_buffer,
-      std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
-      std::optional<se::DeviceMemoryBase> seqlen_k_buffer);
-
-  const GpufMHABackwardConfig* config;  // Not owned
-  se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer;
-  se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer;
-  se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer;
-  se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer;
-  se::DeviceMemoryBase d_output_buffer;
-  se::DeviceMemoryBase d_bmm1_lhs_buffer;
-  se::DeviceMemoryBase d_bmm1_rhs_buffer;
-  se::DeviceMemoryBase d_bmm2_rhs_buffer;
-  std::optional<se::DeviceMemoryBase> d_s_buffer;
-  std::optional<se::DeviceMemoryBase> d_bias_buffer;
-  std::optional<se::DeviceMemoryBase> fwd_output_buffer;
-  std::optional<se::DeviceMemoryBase> bias_buffer;
-  std::optional<se::DeviceMemoryBase> seqlen_q_buffer;
-  std::optional<se::DeviceMemoryBase> seqlen_k_buffer;
-};
-
-class FusedMultiHeadedAttentionRunner {
- public:
-  using Repr =
-      std::variant<std::monostate,  // To allow XXX default ctor
-                   std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>>>;
-
-  FusedMultiHeadedAttentionRunner() = default;
-
-  explicit FusedMultiHeadedAttentionRunner(
-      std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>> runner)
-      : repr_(std::move(runner)) {}
-
-  explicit FusedMultiHeadedAttentionRunner(Repr runner)
-      : repr_(std::move(runner)) {}
-
-  explicit FusedMultiHeadedAttentionRunner(const GpufMHAConfig& config)
-      : FusedMultiHeadedAttentionRunner(CreateRunner(config)) {
-    if (std::holds_alternative<std::monostate>(repr_)) {
-      CHECK(false) << "Cannot construct FusedMultiHeadedAttentionRunner with "
-                      "std::monostate";
-    }
-  }
-
-  se::dnn::AlgorithmDesc ToAlgorithmDesc() const {
-    return std::visit(ToAlgorithmDescVisitor{}, repr_);
-  }
-
-  se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>* AsFusedMHARunner() {
-    CHECK(std::holds_alternative<
-          std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>>>(repr_));
-    return std::get<
-               std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>>>(
-               repr_)
-        .get();
-  }
-
- private:
-  //  The CreateRunner function is defined as static because it
-  //  doesn't need access to any non-static member variables of the
-  //  FusedMultiHeadedAttentionRunner class. Defining it static makes it easy to
-  //  use and makes it clear that it is a utility function that doesn't rely on
-  //  the state of any specific instance of the class.
-  static Repr CreateRunner(const GpufMHAConfig& config) {
-    switch (config.kind) {
-      case CudnnfMHAKind::kSoftmaxDropout:
-      case CudnnfMHAKind::kSoftmax:
-      case CudnnfMHAKind::kScaleBiasSoftmax:
-      case CudnnfMHAKind::kScaleBiasSoftmaxDropout:
-        return std::make_unique<se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>>(
-            config.algorithm);
-      default:
-        LOG(FATAL) << "Internal error: unsupported CUDNN MHA kind in "
-                      "FusedMultiHeadedAttentionRunner";
-    }
-  }
-
-  struct ToAlgorithmDescVisitor {
-    template <typename RunnerPtr>
-    se::dnn::AlgorithmDesc operator()(const RunnerPtr& runner) {
-      return runner->ToAlgorithmDesc();
-    }
-
-    se::dnn::AlgorithmDesc operator()(const std::monostate&) {
-      CHECK(false) << "Internal error: uninitialized runner in ToAlgorithmDesc";
-    }
-  };
-
-  Repr repr_;
-};
-
-class FusedMultiHeadedAttentionBackwardRunner {
- public:
-  using Repr = std::variant<
-      std::monostate,  // To allow XXX default ctor
-      std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>>>;
-
-  FusedMultiHeadedAttentionBackwardRunner() = default;
-
-  explicit FusedMultiHeadedAttentionBackwardRunner(
-      std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>>
-          runner)
-      : repr_(std::move(runner)) {}
-
-  explicit FusedMultiHeadedAttentionBackwardRunner(Repr runner)
-      : repr_(std::move(runner)) {}
-
-  explicit FusedMultiHeadedAttentionBackwardRunner(
-      const GpufMHABackwardConfig& config)
-      : FusedMultiHeadedAttentionBackwardRunner(CreateRunner(config)) {
-    if (std::holds_alternative<std::monostate>(repr_)) {
-      CHECK(false)
-          << "Cannot construct FusedMultiHeadedAttentionBackwardRunner with "
-             "std::monostate";
-    }
-  }
-
-  se::dnn::AlgorithmDesc ToAlgorithmDesc() const {
-    return std::visit(ToAlgorithmDescVisitor{}, repr_);
-  }
-
-  se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>*
-  AsFusedMHABackwardRunner() {
-    CHECK(std::holds_alternative<
-          std::unique_ptr<se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>>>(
-        repr_));
-    return std::get<std::unique_ptr<
-        se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>>>(repr_)
-        .get();
-  }
-
- private:
-  //  The CreateRunner function is defined as static because it
-  //  doesn't need access to any non-static member variables of the
-  //  FusedMultiHeadedAttentionBackwardRunner class. Defining it static makes it
-  //  easy to use and makes it clear that it is a utility function that doesn't
-  //  rely on the state of any specific instance of the class.
-  static Repr CreateRunner(const GpufMHABackwardConfig& config) {
-    switch (config.kind) {
-      case CudnnfMHAKind::kBackwardSoftmaxDropout:
-      case CudnnfMHAKind::kBackwardSoftmax:
-      case CudnnfMHAKind::kBackwardScaleBiasSoftmax:
-      case CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout:
-        return std::make_unique<
-            se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>>(
-            config.algorithm);
-      default:
-        LOG(FATAL) << "Internal error: unsupported CUDNN MHA kind in "
-                      "FusedMultiHeadedAttentionBackwardRunner";
-    }
-  }
-
-  struct ToAlgorithmDescVisitor {
-    template <typename RunnerPtr>
-    se::dnn::AlgorithmDesc operator()(const RunnerPtr& runner) {
-      return runner->ToAlgorithmDesc();
-    }
-
-    se::dnn::AlgorithmDesc operator()(const std::monostate&) {
-      CHECK(false) << "Internal error: uninitialized runner in ToAlgorithmDesc";
-    }
-  };
-
-  Repr repr_;
-};
-
-struct RunFusedMHAOptions {
-  // Nullable output-parameter pointer for profiling results.
-  // Profile results remain unused for now, since cuDNN FMHA currently has
-  // only one algorithm.
-  se::dnn::ProfileResult* profile_result = nullptr;
-
-  // Use this runner cache (and its configured algorithm), instead of the one
-  // from the instruction.
-  FusedMultiHeadedAttentionRunner* runner_cache;
-};
-
-struct RunFusedMHABackwardOptions {
-  // Nullable output-parameter pointer for profiling results.
-  // Profile results remain unused for now, since cuDNN FMHA currently has
-  // only one algorithm.
-  se::dnn::ProfileResult* profile_result = nullptr;
-
-  // Use this runner cache (and its configured algorithm), instead of the one
-  // from the instruction.
-  FusedMultiHeadedAttentionBackwardRunner* runner_cache;
-};
-
-absl::Status RunGpuFMHA(const GpufMHAConfig& fmha_config,
-                        se::DeviceMemoryBase lhs_bmm1_buffer,
-                        se::DeviceMemoryBase rhs_bmm1_buffer,
-                        se::DeviceMemoryBase rhs_bmm2_buffer,
-                        se::DeviceMemoryBase output_buffer,
-                        se::DeviceMemoryBase scratch_buffer,
-                        std::optional<se::DeviceMemoryBase> bias_buffer,
-                        std::optional<se::DeviceMemoryBase> activation_buffer,
-                        std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
-                        std::optional<se::DeviceMemoryBase> seqlen_k_buffer,
-                        se::Stream* stream, RunFusedMHAOptions = {});
-
-absl::Status RunGpuFMHABackward(
-    const GpufMHABackwardConfig& fmha_config,
-    se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer,
-    se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer,
-    se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer,
-    se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer,
-    se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer,
-    se::DeviceMemoryBase d_bmm1_lhs_buffer,
-    se::DeviceMemoryBase d_bmm1_rhs_buffer,
-    se::DeviceMemoryBase d_bmm2_rhs_buffer,
-    std::optional<se::DeviceMemoryBase> d_s_buffer,
-    std::optional<se::DeviceMemoryBase> d_bias_buffer,
-    std::optional<se::DeviceMemoryBase> fwd_output_buffer,
-    std::optional<se::DeviceMemoryBase> bias_buffer,
-    std::optional<se::DeviceMemoryBase> seqlen_q_buffer,
-    std::optional<se::DeviceMemoryBase> seqlen_k_buffer, se::Stream* stream,
-    RunFusedMHABackwardOptions = {});
-
-std::string ToString(const GpufMHAConfig& config);
-
-}  // namespace gpu
-}  // namespace xla
-#endif  // XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc
index c40be16..e8710fe 100644
--- a/third_party/xla/xla/service/gpu/gpu_fusible.cc
+++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc
@@ -29,11 +29,13 @@
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/synchronization/mutex.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/permutation_util.h"
 #include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/reduction_utils.h"
@@ -57,6 +59,103 @@
                         });
 }
 
+const Shape& GetElementShape(const HloFusionAnalysis& analysis) {
+  const Shape* shape = &analysis.fusion_root(0).shape();
+  while (shape->IsTuple()) {
+    shape = &shape->tuple_shapes(0);
+  }
+  return *shape;
+}
+
+// Computes the maximum valid unroll factor for the given number of elements.
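+// For example, 4096 elements -> 4, 6 elements -> 2, and 7 elements -> 1 (no
+// unrolling).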
+int ComputeMaxUnrollFactor(int64_t num_elements) {
+  constexpr int kMaxUnrollFactor = 4;
+  for (int i = kMaxUnrollFactor; i > 1; i /= 2) {
+    if (num_elements % i == 0) {
+      return i;
+    }
+  }
+  return 1;
+}
+
+// Determines whether to enable row-optimized codegen. When a fusion contains
+// only pointwise operations, scalar broadcasts, and row broadcasts, we can
+// emit a kernel that vectorizes the row loads. This speeds up the kernel, in
+// particular on A100. The returned int is the number of inputs with rank
+// `out_rank`; its value is only meaningful when row vectorization is enabled.
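+// For example, a single-root fusion computing add(p0, broadcast(row)) with
+// row-major layouts qualifies (an elementwise op plus a row broadcast), while
+// a fusion containing a column broadcast (dimensions={0} on a rank-2 output)
+// does not.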
+std::pair<bool /*enabled*/, int> RowVectorizationEnabled(
+    const HloFusionAdaptor& fusion, int64_t out_rank) {
+  auto roots = fusion.GetRoots();
+  const auto is_row_major = [](const HloInstruction* instr) {
+    // Only tested when the inputs are row-major, so only enable that case.
+    // It might also work if only the innermost dimension is contiguous.
+    return LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout());
+  };
+  bool row_vectorized = roots.size() == 1 && !roots[0].shape().IsTuple() &&
+                        is_row_major(&roots[0].instruction());
+  if (!row_vectorized) {
+    return {false, 0};
+  }
+
+  // Check that the operations in the fusion are supported. Each supported
+  // operation (or category) must be vetted manually, because XLA only unrolls
+  // and relies on LLVM to vectorize, which is brittle. Currently tested and
+  // supported operations: elementwise ops, scalar broadcasts, and row
+  // broadcasts.
+  //
+  // The same traversal also records whether a row-broadcast operation is
+  // present.
+  int num_big_inputs = 0;
+  bool some_row_broadcasting = false;
+  HloBfsConsumersFirstTraversal(
+      roots, fusion, [&](auto node) -> TraversalResult {
+        if (!row_vectorized) {
+          return TraversalResult::kInterrupt;
+        }
+
+        if (node.instruction().IsElementwise()) {
+          return TraversalResult::kAdvance;
+        }
+
+        switch (node.opcode()) {
+          case HloOpcode::kConstant:
+            return TraversalResult::kSkip;
+          case HloOpcode::kParameter:
+            return TraversalResult::kAdvance;
+          case HloOpcode::kBroadcast: {
+            auto dims = node.instruction().dimensions();
+            if (dims.empty()) {
+              return TraversalResult::kAdvance;
+            }
+
+            if (dims.size() == 1 && dims.front() == node.shape().rank() - 1) {
+              some_row_broadcasting = true;
+              return TraversalResult::kAdvance;
+            }
+            TF_FALLTHROUGH_INTENDED;
+          }
+          default:
+            VLOG(2) << "Row vectorization not enabled due to: "
+                    << node.ToString();
+            row_vectorized = false;
+            return TraversalResult::kInterrupt;
+        }
+      });
+  if (row_vectorized) {
+    for (const HloInstruction* argument : fusion.GetParameters()) {
+      if (argument->shape().rank() == out_rank) {
+        ++num_big_inputs;
+      }
+      if (!is_row_major(argument)) {
+        row_vectorized = false;
+      }
+    }
+  }
+  // Trigger only when there is a row broadcast.
+  return std::make_pair(row_vectorized && some_row_broadcasting,
+                        num_big_inputs);
+}
+
 }  // namespace
 
 bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
@@ -612,11 +711,16 @@
       // from potential x-tiling).
       return 4 * 32 * 33 * primitive_size * num_variadic;
     }
-  } else if (GetDescriptionForTiledTransposeEmitter(instr, instr).has_value()) {
+  } else if (auto tr = GetDescriptionForTiledTransposeEmitter(instr, instr)) {
     // Tile size for transposition.
     int64_t primitive_size =
         ShapeUtil::ByteSizeOfPrimitiveType(instr.shape().element_type());
-    return 32 * 33 * primitive_size;
+    int64_t bytes_required = 32 * 33 * primitive_size;
+    // If the last dimension is not changed, it becomes part of the tile.
+    if (tr->permutation.back() == tr->permutation.size() - 1) {
+      bytes_required *= tr->dimensions.back();
+    }
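+    // For example, an f32 transpose that keeps a trailing dimension of size 2
+    // needs 32 * 33 * 2 * 4 bytes of shared memory.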
+    return bytes_required;
   }
   // Other fused expressions for now don't need the shared memory budget.
   return 0;
@@ -1022,5 +1126,78 @@
   return result;
 }
 
+LaunchDimensionsConfig ComputeLoopFusionConfig(
+    const HloFusionAnalysis& analysis) {
+  return ComputeLoopFusionConfig(analysis, GetElementShape(analysis));
+}
+
+LaunchDimensionsConfig ComputeLoopFusionConfig(
+    const HloFusionAnalysis& analysis, const Shape& element_shape) {
+  int unroll_factor = 1;
+  // Unrolling helps read large inputs with small elements thanks to vector
+  // loads, but it increases register pressure when one thread has to produce
+  // multiple output elements. Therefore, for fusions with small outputs,
+  // prefer one thread per output element, i.e. no unrolling. A fusion counts
+  // as 'small' if it uses fewer threads than the GPU provides.
+  int64_t num_elements = ShapeUtil::ElementsIn(element_shape);
+  int64_t n_threads_max = analysis.device_info().threads_per_core_limit() *
+                          analysis.device_info().core_count();
+  if (num_elements >= n_threads_max &&
+      !MayPreventVectorization(analysis.fusion())) {
+    unroll_factor = ComputeMaxUnrollFactor(num_elements);
+  }
+  // CHECK that unroll_factor is a power-of-2, as needed by the logic below.
+  CHECK(absl::has_single_bit(static_cast<uint64_t>(unroll_factor)));
+  // Ensure a single thread writes to a byte containing multiple values by
+  // setting unroll_factor to an appropriate number. Setting unroll_factor is
+  // safe even if the new unroll_factor doesn't divide the number of elements,
+  // as the parallel loop emitter will insert a bounds check in this case to
+  // ensure the out-of-bounds element is not computed and written. Setting
+  // unroll_factor is safe even if MayPreventVectorization returns false, as
+  // the MayPreventVectorization check is an optimization, not a correctness
+  // requirement.
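+  // For example, a 4-bit output type gives CeilOfRatio(8, 4) == 2, so the
+  // unroll factor becomes at least 2 and a single thread owns each packed
+  // byte; for 8-bit and wider types the factor is unchanged.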
+  unroll_factor = std::max(
+      unroll_factor,
+      CeilOfRatio(8, analysis.input_output_info().smallest_output_dtype_bits));
+  CHECK(absl::has_single_bit(static_cast<uint64_t>(unroll_factor)));
+  VLOG(2) << "Unroll factor: " << unroll_factor;
+
+  bool row_vectorized;
+  int num_big_inputs;
+  std::tie(row_vectorized, num_big_inputs) =
+      RowVectorizationEnabled(analysis.fusion(), element_shape.rank());
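+  // few_waves stays enabled only if every op in the fusion is a parameter, a
+  // constant, an elementwise op, a scalar broadcast, or (when row vectorization
+  // is enabled with at most 3 big inputs) any other broadcast.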
+  bool few_waves = !HloAnyOf(analysis.fusion(), [&](auto instr) {
+    if (instr.opcode() == HloOpcode::kParameter ||
+        instr.opcode() == HloOpcode::kConstant ||
+        HloInstruction::IsOpElementwise(instr.opcode())) {
+      return false;
+    }
+    if (auto broadcast =
+            DynCast<HloBroadcastInstruction>(&instr.instruction())) {
+      if (broadcast->dimensions().empty() ||
+          // More than 3 big inputs cause a speed regression.
+          (row_vectorized && num_big_inputs <= 3)) {
+        return false;
+      }
+    }
+    VLOG(2) << "few_waves not enabled due to: "
+            << instr.instruction().ToString();
+    return true;
+  });
+
+  LaunchDimensionsConfig launch_config{unroll_factor, few_waves,
+                                       row_vectorized};
+  // Check that the shape is supported.
+  if (launch_config.row_vectorized &&
+      ThreadsPerBlockRowVectorized(element_shape, analysis.device_info(),
+                                   launch_config) <= 0) {
+    VLOG(2) << "Cancelling row_vectorization as the shape isn't supported.";
+    launch_config.row_vectorized = false;
+    launch_config.few_waves = false;
+  }
+  return launch_config;
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.h b/third_party/xla/xla/service/gpu/gpu_fusible.h
index 185c440..0dadbfa 100644
--- a/third_party/xla/xla/service/gpu/gpu_fusible.h
+++ b/third_party/xla/xla/service/gpu/gpu_fusible.h
@@ -27,12 +27,14 @@
 #include "absl/synchronization/mutex.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/launch_dimensions.h"
 #include "xla/service/instruction_fusion.h"
 #include "xla/stream_executor/device_description.h"
 
 // TODO(b/112957171): Extract logic to determine fusibility of HLO ops from
-// GpuInstructionFusion, FusionMerger, and GpuMultiOutputFusion.
+// GpuInstructionFusion, FusionMerger, and MultiOutputFusion.
 
 namespace xla {
 namespace gpu {
@@ -226,6 +228,12 @@
 // instructions it contains.
 bool MayPreventVectorization(const HloFusionAdaptor& fusion);
 
+LaunchDimensionsConfig ComputeLoopFusionConfig(
+    const HloFusionAnalysis& analysis);
+
+LaunchDimensionsConfig ComputeLoopFusionConfig(
+    const HloFusionAnalysis& analysis, const Shape& shape);
+
 }  // namespace gpu
 }  // namespace xla
 
diff --git a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc
index 874b9da..aa3713d 100644
--- a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc
@@ -1731,5 +1731,23 @@
                                    module->entry_computation()));
 }
 
+TEST_F(GpuFusibleTest, GetSharedMemoryUsage) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    wrapped_transpose {
+      p0 = f32[128,1024,2]{2,1,0} parameter(0)
+      ROOT transpose = f32[1024,128,2]{2,1,0} transpose(p0), dimensions={1,0,2}
+    }
+    ENTRY main {
+      p = f32[128,1024,2] parameter(0)
+      ROOT res = f32[1024,128,2]{2,1,0} fusion(p), kind=kInput, calls=wrapped_transpose
+    })"))
+                    .value();
+  auto& debug_options = module->mutable_config().mutable_debug_options();
+  debug_options.set_xla_gpu_mlir_emitter_level(3);
+  FusionInfoCache cache;
+  auto fusion = module->entry_computation()->root_instruction();
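+  // Expected budget: a 32x33 tile of f32 (4 bytes each), multiplied by the
+  // unchanged trailing dimension of size 2.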
+  EXPECT_EQ(cache.GetSharedMemoryUsage(*fusion), 32 * 33 * 2 * 4);
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
index 2504b43..a6a68b4 100644
--- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
+++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
@@ -46,9 +46,9 @@
 #include "xla/service/collective_ops_utils.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
-#include "xla/service/gpu/gpu_schedule_postprocessing.h"
 #include "xla/service/gpu/model/analytical_latency_estimator.h"
-#include "xla/service/gpu/scheduling_instruction_annotator.h"
+#include "xla/service/gpu/transforms/schedule_postprocessing.h"
+#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h"
 #include "xla/service/hlo_memory_scheduler.h"
 #include "xla/service/hlo_pass_pipeline.h"
 #include "xla/service/latency_hiding_scheduler.h"
@@ -74,6 +74,9 @@
     case HloOpcode::kAllReduceStart:
     case HloOpcode::kCollectivePermuteStart:
       return !IsSyncCollective(&instr);
+    case HloOpcode::kAsyncStart:
+      // Start async ops as early as possible to allow more concurrency.
+      return true;
     case HloOpcode::kCustomCall:
       return static_cast<const HloCustomCallInstruction&>(instr)
                  .custom_call_schedule() ==
@@ -95,6 +98,10 @@
     case HloOpcode::kAllReduceDone:
     case HloOpcode::kCollectivePermuteDone:
       return ShouldScheduleAsEarlyAsPossible(*instr.operand(0));
+    case HloOpcode::kAsyncDone:
+      // Schedule as many other ops as possible before blocking on the
+      // completion of async ops.
+      return true;
     case HloOpcode::kCustomCall:
       return static_cast<const HloCustomCallInstruction&>(instr)
                  .custom_call_schedule() == CustomCallSchedule::SCHEDULE_LATEST;
@@ -513,8 +520,8 @@
 
   TF_RETURN_IF_ERROR(pipeline.Run(module).status());
 
-  HloPassPipeline postprocessing_pipeline("gpu-schedule-postprocessing");
-  postprocessing_pipeline.AddPass<GpuSchedulePostprocessing>();
+  HloPassPipeline postprocessing_pipeline("schedule-postprocessing");
+  postprocessing_pipeline.AddPass<SchedulePostprocessing>();
   TF_RETURN_IF_ERROR(postprocessing_pipeline.Run(module).status());
 
   return ScheduleMetadata{memory_limit};
diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc
index 0304f35..3582d40 100644
--- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -46,6 +46,7 @@
 #include "xla/tests/filecheck.h"
 #include "xla/tests/hlo_test_base.h"
 #include "xla/tests/test_utils.h"
+#include "xla/tests/verified_hlo_module.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/statusor.h"
@@ -54,6 +55,7 @@
 namespace xla {
 namespace gpu {
 
+using ::testing::ElementsAre;
 using ::testing::HasSubstr;
 using ::tsl::testing::StatusIs;
 
@@ -1480,5 +1482,48 @@
   }
 }
 
+TEST_F(GpuHloScheduleTest, AsyncOps) {
+  const char* hlo_text = R"(
+  HloModule m
+
+  op1 {
+    p0 = f32[2,2] parameter(0)
+    ROOT add = f32[2,2] add(p0, p0)
+  }
+
+  op2 {
+    p0 = f32[2,2] parameter(0)
+    ROOT add = f32[2,2] add(p0, p0)
+  }
+
+  ENTRY main {
+    p0 = f32[2,2] parameter(0)
+    // The `async-start` blocks should be moved up, and the `async-done` blocks
+    // should be moved down.
+    acc1_start = ((f32[2,2]), f32[2,2], s32[]) fusion-start(p0),
+        kind=kLoop, calls=op1
+    acc1_done = f32[2,2] fusion-done(acc1_start)
+    acc2_start = ((f32[2,2]), f32[2,2], s32[]) fusion-start(p0),
+        kind=kLoop, calls=op2
+    acc2_done = f32[2,2] fusion-done(acc2_start)
+    ROOT done = f32[2,2] add(acc1_done, acc2_done)
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<xla::VerifiedHloModule> module,
+      ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig{}));
+  SequentialHloOrdering order = BuildHloOrdering(module.get());
+
+  std::vector<HloOpcode> opcodes;
+  for (HloInstruction* instruction :
+       order.SequentialOrder(*module->entry_computation())->instructions()) {
+    opcodes.push_back(instruction->opcode());
+  }
+  EXPECT_THAT(opcodes,
+              ElementsAre(HloOpcode::kParameter, HloOpcode::kAsyncStart,
+                          HloOpcode::kAsyncStart, HloOpcode::kAsyncDone,
+                          HloOpcode::kAsyncDone, HloOpcode::kAdd));
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_layout_assignment.cc b/third_party/xla/xla/service/gpu/gpu_layout_assignment.cc
deleted file mode 100644
index 008dbae..0000000
--- a/third_party/xla/xla/service/gpu/gpu_layout_assignment.cc
+++ /dev/null
@@ -1,596 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_layout_assignment.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <initializer_list>
-#include <memory>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout.h"
-#include "xla/layout_util.h"
-#include "xla/primitive_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/reduction_utils.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/host_memory_offload_annotations.h"
-#include "xla/service/logical_buffer.h"
-#include "xla/shape.h"
-#include "xla/shape_layout.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/tsl/util/env_var.h"
-#include "xla/util.h"
-#include "xla/window_util.h"
-#include "xla/xla.pb.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-using se::dnn::DataLayout;
-using se::dnn::FilterLayout;
-
-// Returns (input, filter, output) layouts.
-static std::tuple<DataLayout, FilterLayout, DataLayout>
-HeuristicLayoutAssignment(const HloInstruction* instr,
-                          const se::GpuComputeCapability& gpu_version,
-                          const se::dnn::VersionInfo& dnn_version) {
-  // DataLayout and FilterLayout uses weird enum names. Translations:
-  //   N <=> Batch or Output
-  //   C <=> Depth or Input
-  //   H <=> Y
-  //   W <=> X
-  //
-  // Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
-  //
-  // If you have trouble keeping these straight, consider that all that matters
-  // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)?
-
-  constexpr auto kAllNCHW =
-      std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX,
-                      DataLayout::kBatchDepthYX);
-  // kBatchDepthYX4 has the same layout as kBatchDepthYX32; they're both VECT_C
-  // layouts as far as cudnn is concerned.
-  constexpr auto kAllNCHW_VECT_C =
-      std::make_tuple(DataLayout::kBatchDepthYX4, FilterLayout::kOutputInputYX4,
-                      DataLayout::kBatchDepthYX4);
-  constexpr auto kAllNHWC =
-      std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
-                      DataLayout::kBatchYXDepth);
-
-  // Integer convolution must use NHWC or NCHW_VECT_C.
-  //
-  // TODO(jlebar): Do non-VECT_C int8_t convs still require NHWC with new
-  // versions of cudnn?
-  const ConvolutionDimensionNumbers& dnums =
-      instr->convolution_dimension_numbers();
-  Shape input_shape = instr->operand(0)->shape();
-  PrimitiveType input_ty = instr->operand(0)->shape().element_type();
-  if (primitive_util::IsIntegralType(input_ty)) {
-    if (input_ty == S8 && dnums.input_spatial_dimensions_size() == 2 &&
-        input_shape.dimensions_size() == 5) {
-      VLOG(2) << "Using NCHW_VECT_C for int8_t conv " << instr->ToString();
-      return kAllNCHW_VECT_C;
-    }
-    VLOG(2) << "Using NHWC for int8_t conv " << instr->ToString();
-    return kAllNHWC;
-  }
-
-  if (primitive_util::IsF8Type(input_ty)) {
-    VLOG(2) << "Using NHWC for FP8 conv " << instr->ToString();
-    return kAllNHWC;
-  }
-
-  const DebugOptions& debug_options =
-      instr->GetModule()->config().debug_options();
-
-  if (debug_options.xla_gpu_force_conv_nchw()) {
-    VLOG(2) << "Overriding layout to NCHW for " << instr->ToString();
-    return kAllNCHW;
-  }
-
-  if (debug_options.xla_gpu_force_conv_nhwc()) {
-    VLOG(2) << "Overriding layout to NHWC for " << instr->ToString();
-    return kAllNHWC;
-  }
-
-  const auto* rocm_compute_capability =
-      std::get_if<se::RocmComputeCapability>(&gpu_version);
-  if (rocm_compute_capability && input_ty == F16) return kAllNHWC;
-
-  // If we're not Volta or not fp16/bfloat16, or not conv2D, the decision is
-  // easy: Use NCHW.
-  const bool isFloat16 = (input_ty == F16) || (input_ty == BF16);
-  if (std::holds_alternative<se::CudaComputeCapability>(gpu_version)) {
-    // If we're not Volta or not fp16/bfloat16, or not conv2D, the decision is
-    // easy: Use NCHW.
-    const auto* cuda_compute_capability =
-        std::get_if<se::CudaComputeCapability>(&gpu_version);
-    bool is_volta =
-        cuda_compute_capability &&
-        cuda_compute_capability->IsAtLeast(se::CudaComputeCapability::VOLTA);
-    if (!isFloat16 || !is_volta ||
-        instr->shape().tuple_shapes(0).dimensions_size() != 4) {
-      return kAllNCHW;
-    }
-
-    // Empirically we've found with Volta and cudnn <= 7.3 that backward-input
-    // convs with stride are significantly faster with NCHW layouts.
-    //
-    // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW),
-    // which on paper gives good performance. However, there are two
-    // observations:
-    // * a mixed layout combination is more cuDNN-bug prone, based on empirical
-    //   evidence.
-    // * we've also observed that for mixed layouts, cuDNN transposes data back
-    //   and forth from a different layout combination. If we end up with
-    //   transposes anyway, we prefer to have them in XLA, as they can be fused.
-    if (std::make_tuple(dnn_version.major_version(),
-                        dnn_version.minor_version()) <= std::make_tuple(7, 3) &&
-        instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
-        window_util::HasStride(instr->window())) {
-      return kAllNCHW;
-    }
-  } else if (std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
-    bool is_enabled = false;
-    TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_USE_ROCM_NHWC",
-                                        /*default_val=*/false, &is_enabled));
-    auto rocm_compute_capability =
-        std::get<se::RocmComputeCapability>(gpu_version);
-    if (!isFloat16 || (!rocm_compute_capability.has_nhwc_layout_support()) ||
-        instr->shape().tuple_shapes(0).dimensions_size() != 4 || !is_enabled) {
-      return kAllNCHW;
-    }
-  }
-
-  VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();
-
-  // For other Volta f16 convolutions, use NHWC.
-  return kAllNHWC;
-}
-
-// Adds layout constraints on the cudnn custom-call instruction. The layout
-// constraints are represented in terms of minor_to_major fields of both
-// operands and the output shape. Depending on the underlying algorithm, one of
-// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
-absl::Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
-    HloCustomCallInstruction* instr, LayoutConstraints* constraints) {
-  Shape lhs_shape = instr->operand(0)->shape();
-  Shape rhs_shape = instr->operand(1)->shape();
-  Shape result_shape = instr->shape().tuple_shapes(0);
-
-  Shape* input_shape;
-  Shape* filter_shape;
-  Shape* output_shape;
-
-  TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr));
-  switch (kind) {
-    case CudnnConvKind::kForward:
-    case CudnnConvKind::kForwardActivation:
-    case CudnnConvKind::kForwardGraph:
-      input_shape = &lhs_shape;
-      filter_shape = &rhs_shape;
-      output_shape = &result_shape;
-      break;
-    case CudnnConvKind::kBackwardInput:
-      input_shape = &result_shape;
-      filter_shape = &rhs_shape;
-      output_shape = &lhs_shape;
-      break;
-    case CudnnConvKind::kBackwardFilter:
-      input_shape = &lhs_shape;
-      filter_shape = &result_shape;
-      output_shape = &rhs_shape;
-      break;
-  }
-
-  {
-    DataLayout input;
-    FilterLayout filter;
-    DataLayout output;
-    std::tie(input, filter, output) =
-        HeuristicLayoutAssignment(instr, gpu_version_, dnn_version_);
-
-    TF_ASSIGN_OR_RETURN(
-        std::tie(*input_shape->mutable_layout(),
-                 *filter_shape->mutable_layout(),
-                 *output_shape->mutable_layout()),
-        StreamExecutorConvLayoutsToXlaLayouts(
-            instr->convolution_dimension_numbers(), input, filter, output));
-  }
-
-  // The custom call returns a tuple of (actual_result, scratch_buffer);
-  // call_result_buf is the logical buffer for actual_result, the thing that
-  // contains the result of the conv call.
-  TF_ASSIGN_OR_RETURN(
-      const LogicalBuffer* call_result_buf,
-      points_to_analysis_->GetBufferDefinedAt(instr, /*index=*/{0}));
-
-  // Set layouts of the instructions' shapes.
-  TF_RETURN_IF_ERROR(SetOperandLayout(lhs_shape, instr, 0));
-  TF_RETURN_IF_ERROR(SetOperandLayout(rhs_shape, instr, 1));
-  TF_RETURN_IF_ERROR(SetBufferLayout(result_shape.layout(), *call_result_buf));
-  // For fused convolutions, instr->operand(2), if exists, is the bias buffer.
-  // There is no need to assign layout to it, as it has only one dimension.
-  // instr->operand(3), if exists, is the side input buffer.
-  if (kind == CudnnConvKind::kForwardActivation &&
-      instr->operand_count() == 4) {
-    // The side input layout must match the output layout.
-    TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, 3));
-  }
-
-  // For graph convolutions, align the layouts of the non-scalar inputs to any
-  // pointwise ops with the output layout.
-  if (kind == CudnnConvKind::kForwardGraph) {
-    for (int k = 2; k < instr->operand_count(); ++k) {
-      if (!ShapeUtil::IsScalar(instr->operand(k)->shape())) {
-        TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, k));
-      }
-    }
-  }
-
-  if (instr->operand_count() > 2 && kind != CudnnConvKind::kForwardActivation &&
-      kind != CudnnConvKind::kForwardGraph) {
-    return Internal(
-        "Invalid convolution. Conv has a side input, but kind is not fused "
-        "conv forward or graph conv foward: %s",
-        instr->ToString());
-  }
-
-  return absl::OkStatus();
-}
-
-namespace {
-
-// Imposes the default layout with first two dimensions swapped on input
-// `shape`.
-void SetFortranLayout(Shape* shape) {
-  LayoutUtil::SetToDefaultLayout(shape);
-  int n = shape->mutable_layout()->minor_to_major_size();
-  CHECK_GE(n, 2);
-  std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
-            shape->mutable_layout()->mutable_minor_to_major()->at(1));
-}
-
-bool DotCanSupportShapeWithLayout(const HloInstruction* dot,
-                                  const Shape& shape) {
-  const DotDimensionNumbers& dot_dims = dot->dot_dimension_numbers();
-  // If we are able to construct a `MatrixLayout` then the dot can support
-  // this layout.
-  return MatrixLayout::For(shape, dot_dims.lhs_batch_dimensions().size(),
-                           dot->operand(0)->shape().rank() -
-                               dot_dims.lhs_contracting_dimensions().size() -
-                               dot_dims.lhs_batch_dimensions().size(),
-                           dot_dims.rhs_batch_dimensions().size(),
-                           dot->operand(1)->shape().rank() -
-                               dot_dims.rhs_contracting_dimensions().size() -
-                               dot_dims.rhs_batch_dimensions().size())
-      .ok();
-}
-
-}  // namespace
-
-absl::Status GpuLayoutAssignment::AddBackendConstraints(
-    LayoutConstraints* constraints) {
-  // Add convolution constraints in reverse postorder that the earliest
-  // convolution layout propagates first. This reduces the likelihood of fusion
-  // nodes with copies.
-  auto post_order = constraints->computation()->MakeInstructionPostOrder();
-  for (auto iterator = post_order.rbegin(); iterator != post_order.rend();
-       ++iterator) {
-    HloInstruction* instruction = *iterator;
-    if (IsCustomCallToDnnConvolution(*instruction)) {
-      TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall(
-          Cast<HloCustomCallInstruction>(instruction), constraints));
-    }
-
-    CHECK(!IsCublasGemm(*instruction))
-        << "Gemm rewriting should run after layout assignment";
-
-    if (instruction->opcode() == HloOpcode::kDot) {
-      const Shape& output_shape = instruction->shape();
-      const Shape& lhs_shape = instruction->operand(0)->shape();
-      const Shape& rhs_shape = instruction->operand(1)->shape();
-      const DotDimensionNumbers& dot_dims =
-          instruction->dot_dimension_numbers();
-
-      // Matmuls require the batch dimensions to be in consecutive physical
-      // dimensions and likewise for the contracting and non-contracting
-      // dimensions. Additionally, no batch dimension can be in the most
-      // minor physical dimension for inputs or the output.
-      absl::Span<const int64_t> lhs_batch_dims =
-          dot_dims.lhs_batch_dimensions();
-      absl::Span<const int64_t> lhs_contracting_dims =
-          dot_dims.lhs_contracting_dimensions();
-      TF_ASSIGN_OR_RETURN(std::vector<int64_t> lhs_non_contracting_dims,
-                          GetNonContractingDims(lhs_shape, lhs_batch_dims,
-                                                lhs_contracting_dims));
-
-      absl::Span<const int64_t> rhs_batch_dims =
-          dot_dims.rhs_batch_dimensions();
-      absl::Span<const int64_t> rhs_contracting_dims =
-          dot_dims.rhs_contracting_dimensions();
-      TF_ASSIGN_OR_RETURN(std::vector<int64_t> rhs_non_contracting_dims,
-                          GetNonContractingDims(rhs_shape, rhs_batch_dims,
-                                                rhs_contracting_dims));
-
-      const DebugOptions& debug_options =
-          instruction->GetModule()->config().debug_options();
-
-      bool is_bf16_to_bf16 =
-          (output_shape.element_type() == PrimitiveType::BF16 &&
-           lhs_shape.element_type() == PrimitiveType::BF16 &&
-           rhs_shape.element_type() == PrimitiveType::BF16);
-      bool is_s8_to_s32 = (output_shape.element_type() == PrimitiveType::S32 &&
-                           lhs_shape.element_type() == PrimitiveType::S8 &&
-                           rhs_shape.element_type() == PrimitiveType::S8 &&
-                           output_shape.dimensions_size() == 2 &&
-                           lhs_shape.dimensions_size() == 2 &&
-                           rhs_shape.dimensions_size() == 2);
-
-      if (is_s8_to_s32 ||
-          (is_bf16_to_bf16 &&
-           debug_options.xla_gpu_ensure_minor_dot_contraction_dims())) {
-        TF_RETURN_IF_ERROR(SetOperandMajorToMinorLayout(
-            instruction, /*operand=*/0,
-            /*dim_groups=*/
-            {lhs_batch_dims, lhs_non_contracting_dims, lhs_contracting_dims}));
-        TF_RETURN_IF_ERROR(SetOperandMajorToMinorLayout(
-            instruction, /*operand=*/1,
-            /*dim_groups=*/
-            {rhs_batch_dims, rhs_non_contracting_dims, rhs_contracting_dims}));
-        TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints));
-      } else {
-        if (!lhs_batch_dims.empty() || lhs_contracting_dims.size() > 1 ||
-            lhs_non_contracting_dims.size() > 1) {
-          TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 0, lhs_batch_dims,
-                                                 lhs_contracting_dims,
-                                                 lhs_non_contracting_dims));
-        }
-        if (!rhs_batch_dims.empty() || rhs_non_contracting_dims.size() > 1 ||
-            rhs_contracting_dims.size() > 1) {
-          TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 1, rhs_batch_dims,
-                                                 rhs_contracting_dims,
-                                                 rhs_non_contracting_dims));
-        }
-        // If we have at least one batch dimension or there is more than one
-        // non-contracting dimension on lhs or rhs, we need to set a layout for
-        // the dot output.
-        if (!lhs_batch_dims.empty() || lhs_non_contracting_dims.size() > 1 ||
-            rhs_non_contracting_dims.size() > 1) {
-          TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints));
-        }
-      }
-    } else if (instruction->opcode() == HloOpcode::kTranspose) {
-      const HloInstruction* operand = instruction->operand(0);
-      if ((operand->opcode() != HloOpcode::kDot) ||
-          (operand->user_count() > 1)) {
-        continue;
-      }
-
-      // If possible, set layout of the dot operation such that the output of
-      // the transpose (as a bitcast) has the default layout.
-      Shape shape = operand->shape();
-      *shape.mutable_layout() =
-          LayoutUtil::MakeLayoutFromMajorToMinor(instruction->dimensions());
-
-      if (DotCanSupportShapeWithLayout(operand, shape)) {
-        TF_RETURN_IF_ERROR(
-            SetOperandLayout(shape, instruction, /*operand_no=*/0));
-      }
-    } else if (instruction->opcode() == HloOpcode::kFft) {
-      // cuFFT requires a dim0 major layout.
-      Shape op0_shape = instruction->operand(0)->shape();
-      LayoutUtil::SetToDefaultLayout(&op0_shape);
-      Shape output_shape = instruction->shape();
-      LayoutUtil::SetToDefaultLayout(&output_shape);
-      TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
-      TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
-    } else if (instruction->opcode() == HloOpcode::kSort &&
-               instruction->operand(0)->shape().rank() > 1) {
-      // Make sure that all the operands and the output(s) have the same layout.
-      Shape keys_shape = instruction->operand(0)->shape();
-      Layout keys_layout =
-          LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank());
-      for (int64_t i = 0; i < instruction->operand_count(); ++i) {
-        Shape shape = instruction->operand(i)->shape();
-        *shape.mutable_layout() = keys_layout;
-        TF_RETURN_IF_ERROR(SetOperandLayout(shape, instruction, i));
-        const LogicalBuffer* output_buffer;
-        if (instruction->shape().IsArray()) {
-          TF_ASSIGN_OR_RETURN(
-              output_buffer,
-              points_to_analysis_->GetBufferDefinedAt(instruction, {}));
-        } else {
-          TF_ASSIGN_OR_RETURN(
-              output_buffer,
-              points_to_analysis_->GetBufferDefinedAt(instruction, {i}));
-        }
-        TF_RETURN_IF_ERROR(SetBufferLayout(keys_layout, *output_buffer));
-      }
-    } else if (instruction->opcode() == HloOpcode::kTriangularSolve) {
-      // TODO(phawkins): Ideally we would relax this constraint. What we
-      // actually want is that:
-      // a) the batch dimensions are major, in no particular order.
-      // b) the two minor dimensions are in fortran (column-major) order,
-      // although for the 'a' argument we could potentially accept row-major
-      // order and fold the transpose into the operator.
-      Shape op0_shape = instruction->operand(0)->shape();
-      Shape op1_shape = instruction->operand(1)->shape();
-      Shape output_shape = instruction->shape();
-      SetFortranLayout(&op0_shape);
-      SetFortranLayout(&op1_shape);
-      SetFortranLayout(&output_shape);
-      TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
-      TF_RETURN_IF_ERROR(SetOperandLayout(op1_shape, instruction, 1));
-      TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
-    } else if (instruction->opcode() == HloOpcode::kReduceScatter) {
-      // XLA:GPU can only support reduce-scatter where the scatter dimension
-      // is the most major dimension in the layout.
-      auto ars = Cast<HloReduceScatterInstruction>(instruction);
-      TF_RETURN_IF_ERROR(SetInstructionLayout(
-          ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()),
-          ars));
-    } else if (instruction->opcode() == HloOpcode::kAllGather) {
-      // XLA:GPU can only support all-gathers where the gather dimension is the
-      // most major dimension in the layout.
-      auto ag = Cast<HloAllGatherInstruction>(instruction);
-      TF_RETURN_IF_ERROR(SetInstructionLayout(
-          ShapeUtil::MoveDimToMajor(ag->shape(), ag->all_gather_dimension()),
-          ag));
-    } else if (instruction->opcode() == HloOpcode::kAllToAll &&
-               instruction->shape().IsArray()) {
-      // XLA:GPU can only support all-to-all with split dimensions where the
-      // split dimension is the most major dimension in the layout.
-      auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
-      TF_RETURN_IF_ERROR(SetInstructionLayout(
-          ShapeUtil::MoveDimToMajor(all_to_all->shape(),
-                                    *all_to_all->split_dimension()),
-          all_to_all));
-    } else if (instruction->opcode() == HloOpcode::kSend) {
-      Shape s = instruction->operand(0)->shape();
-      LayoutUtil::SetToDefaultLayout(&s);
-      TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction->operand(0)));
-      TF_RETURN_IF_ERROR(
-          SetArrayOperandLayout(s.layout(), instruction->operand(0), 0));
-    } else if (instruction->opcode() == HloOpcode::kRecv) {
-      Shape s = instruction->shape();
-      ShapeUtil::ForEachMutableSubshape(
-          &s, [&](Shape* subshape, const ShapeIndex& index) {
-            LayoutUtil::SetToDefaultLayout(subshape);
-          });
-      TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction));
-    }
-  }
-  return absl::OkStatus();
-}
-
-absl::Status GpuLayoutAssignment::SetDotOperandLayout(
-    const HloInstruction* instruction, int64_t operand,
-    absl::Span<const int64_t> batch_dims, absl::Span<const int64_t> row_dims,
-    absl::Span<const int64_t> col_dims) {
-  Shape shape = instruction->operand(operand)->shape();
-
-  // First, try to use the existing layout, if present.
-  if (shape.has_layout() &&
-      MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
-    // Re-set the operand layout, so it becomes mandatory.
-    return SetOperandLayout(shape, instruction, operand);
-
-  // Next, try the default layout (for the sake of everybody's sanity).
-  LayoutUtil::SetToDefaultLayout(&shape);
-  if (MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
-    return SetOperandLayout(shape, instruction, operand);
-
-  // Otherwise, fallback to forcing (batch, rows, cols) layout.
-  return SetOperandMajorToMinorLayout(
-      instruction, operand,
-      /*dim_groups=*/{batch_dims, row_dims, col_dims});
-}
-
-absl::Status GpuLayoutAssignment::SetOperandMajorToMinorLayout(
-    const HloInstruction* instruction, int64_t operand,
-    std::initializer_list<absl::Span<const int64_t>> dim_groups) {
-  size_t size = 0;
-  for (auto group : dim_groups) size += group.size();
-  std::vector<int64_t> major_to_minor;
-  major_to_minor.reserve(size);
-  for (const auto& group : dim_groups) {
-    major_to_minor.insert(major_to_minor.end(), group.begin(), group.end());
-  }
-
-  Shape shape = instruction->operand(operand)->shape();
-  *shape.mutable_layout() =
-      LayoutUtil::MakeLayoutFromMajorToMinor(major_to_minor);
-  return SetOperandLayout(shape, instruction, operand);
-}
-
-absl::Status GpuLayoutAssignment::SetDotLayout(
-    const HloInstruction* instruction, LayoutConstraints* constraints) {
-  // If a user has requested a layout that we can support, use that.
-  for (const HloInstruction* user : instruction->users()) {
-    for (int64_t i = 0; i < user->operand_count(); ++i) {
-      if (user->operand(i) != instruction) {
-        continue;
-      }
-
-      const ShapeLayout* constraint = constraints->OperandLayout(user, i);
-      if ((constraint != nullptr) &&
-          DotCanSupportShapeWithLayout(instruction, constraint->shape())) {
-        return SetInstructionLayout(constraint->shape(), instruction);
-      }
-    }
-  }
-
-  // Otherwise, use the default layout.
-  return SetInstructionLayout(
-      LayoutUtil::GetWithDefaultLayout(instruction->shape()), instruction);
-}
-
-bool GpuLayoutAssignment::PropagateReductionLayoutToOperand(
-    const HloInstruction* user) {
-  // We try to propagate a layout to make the reduction a row reduction. But
-  // propagating the layout is only beneficial if the reduction emitter would be
-  // used for the row reduction.
-  int64_t reduction_size = 1;
-  for (int64_t reduction_dim : user->dimensions()) {
-    reduction_size *= user->operand(0)->shape().dimensions(reduction_dim);
-  }
-  int64_t kept_dimension_size = ShapeUtil::ElementsIn(user->shape());
-  return IsUnnestedReductionFasterThanElemental(
-      {/*is_row_reduction=*/true, {1, kept_dimension_size, reduction_size}});
-}
-
-bool GpuLayoutAssignment::InstructionCanChangeLayoutInstance(
-    const HloInstruction* instruction) {
-  // The host offloading custom calls will be eventually removed
-  // by the offloader, so we need to make sure that the calls do not change
-  // the layout and thus cause layout mismatches after the removal.
-  const HloCustomCallInstruction* custom_call =
-      DynCast<HloCustomCallInstruction>(instruction);
-  if (custom_call != nullptr &&
-      (custom_call->custom_call_target() ==
-           host_memory_offload_annotations::kMoveToHostCustomCallTarget ||
-       custom_call->custom_call_target() ==
-           host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) {
-    return false;
-  }
-
-  return LayoutAssignment::InstructionCanChangeLayoutInstance(instruction);
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_layout_assignment.h b/third_party/xla/xla/service/gpu/gpu_layout_assignment.h
deleted file mode 100644
index 70741fe..0000000
--- a/third_party/xla/xla/service/gpu/gpu_layout_assignment.h
+++ /dev/null
@@ -1,81 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_
-#define XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_
-
-#include <cstdint>
-#include <initializer_list>
-
-#include "absl/status/status.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/computation_layout.h"
-#include "xla/service/layout_assignment.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-
-namespace xla {
-namespace gpu {
-
-// GPU-specific layout assignment pass which preassigns layouts to satisfy
-// layout constraints for operands and results of library calls.
-class GpuLayoutAssignment : public LayoutAssignment {
- public:
-  explicit GpuLayoutAssignment(
-      ComputationLayout* entry_computation_layout,
-      const se::GpuComputeCapability& gpu_version,
-      const se::dnn::VersionInfo& dnn_version,
-      ChannelLayoutConstraints* channel_constraints = nullptr)
-      : LayoutAssignment(entry_computation_layout, channel_constraints),
-        gpu_version_(gpu_version),
-        dnn_version_(dnn_version) {}
-  ~GpuLayoutAssignment() override = default;
-
- protected:
-  absl::Status AddBackendConstraints(LayoutConstraints* constraints) override;
-
- private:
-  absl::Status AddBackendConstraintsToDnnConvCustomCall(
-      HloCustomCallInstruction* instr, LayoutConstraints* constraints);
-
-  // dim_groups are ordered from major to minor dimensions.
-  absl::Status SetOperandMajorToMinorLayout(
-      const HloInstruction* instruction, int64_t operand,
-      std::initializer_list<absl::Span<const int64_t>> dim_groups);
-
-  absl::Status SetDotOperandLayout(const HloInstruction* instruction,
-                                   int64_t operand,
-                                   absl::Span<const int64_t> batch_dims,
-                                   absl::Span<const int64_t> row_dims,
-                                   absl::Span<const int64_t> col_dims);
-
-  absl::Status SetDotLayout(const HloInstruction* instruction,
-                            LayoutConstraints* constraints);
-
-  bool PropagateReductionLayoutToOperand(const HloInstruction* user) override;
-
-  bool InstructionCanChangeLayoutInstance(
-      const HloInstruction* instruction) override;
-
-  const se::GpuComputeCapability gpu_version_;
-  const se::dnn::VersionInfo dnn_version_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_layout_assignment_test.cc b/third_party/xla/xla/service/gpu/gpu_layout_assignment_test.cc
deleted file mode 100644
index 81f9e00..0000000
--- a/third_party/xla/xla/service/gpu/gpu_layout_assignment_test.cc
+++ /dev/null
@@ -1,677 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_layout_assignment.h"
-
-#include <cstdint>
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout.h"
-#include "xla/layout_util.h"
-#include "xla/service/computation_layout.h"
-#include "xla/service/gpu/stream_executor_util.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_layout.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/dnn.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-using ::tsl::testing::IsOkAndHolds;
-
-class LayoutAssignmentTest : public HloTestBase {
- public:
-  se::CudaComputeCapability GetCudaComputeCapability() {
-    return backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .cuda_compute_capability();
-  }
-
-  se::GpuComputeCapability GetGpuComputeCapability() {
-    return backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .gpu_compute_capability();
-  }
-
-  se::dnn::VersionInfo GetDnnVersion() {
-    // GpuLayoutAssignment has a special case heuristic for cudnn <= 7.3, but
-    // none of the tests trigger this heuristic.
-    return GetDnnVersionInfoOrDefault(backend().default_stream_executor(),
-                                      se::dnn::VersionInfo{8, 3, 0});
-  }
-};
-
-TEST_F(LayoutAssignmentTest, Elementwise) {
-  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
-  Shape ashape_in_row_major(ashape);
-  Shape ashape_in_col_major(ashape);
-  *ashape_in_row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
-  *ashape_in_col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
-
-  // Enumerate all possible combinations of layouts.
-  for (const Shape& lhs_shape_with_layout :
-       {ashape_in_row_major, ashape_in_col_major}) {
-    for (const Shape& rhs_shape_with_layout :
-         {ashape_in_row_major, ashape_in_col_major}) {
-      for (const Shape& result_shape_with_layout :
-           {ashape_in_row_major, ashape_in_col_major}) {
-        // GpuLayoutAssignment should assign the same layout to "add" and its
-        // two operands.
-        auto builder = HloComputation::Builder(TestName());
-        auto x = builder.AddInstruction(
-            HloInstruction::CreateParameter(0, ashape, "x"));
-        auto y = builder.AddInstruction(
-            HloInstruction::CreateParameter(1, ashape, "y"));
-        auto add = builder.AddInstruction(
-            HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y));
-        auto module = CreateNewVerifiedModule();
-        HloComputation* computation =
-            module->AddEntryComputation(builder.Build(add));
-
-        ComputationLayout computation_layout(
-            computation->ComputeProgramShape());
-        *computation_layout.mutable_parameter_layout(0) =
-            ShapeLayout(lhs_shape_with_layout);
-        *computation_layout.mutable_parameter_layout(1) =
-            ShapeLayout(rhs_shape_with_layout);
-        *computation_layout.mutable_result_layout() =
-            ShapeLayout(result_shape_with_layout);
-
-        GpuLayoutAssignment layout_assignment(
-            &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-        EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-
-        for (const HloInstruction* operand : add->operands()) {
-          EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(),
-                                        operand->shape().layout()));
-        }
-      }
-    }
-  }
-}
-
-TEST_F(LayoutAssignmentTest, DotLayoutUnchangedIfValid) {
-  const char* hlo_text = R"(
-  HloModule DotLayout
-  ENTRY dot {
-    p0 = f32[5,2,3]{1,2,0} parameter(0)
-    p1 = f32[5,3,4]{1,2,0} parameter(1)
-    ROOT dot.1330.10585 = f32[5,2,4]{2,1,0} dot(p0, p1),
-      lhs_batch_dims={0}, lhs_contracting_dims={2},
-      rhs_batch_dims={0}, rhs_contracting_dims={1}
-  })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-
-  ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape(),
-      /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 2, 3}, {1, 2, 0}),
-                                m::Op().WithShape(F32, {5, 3, 4}, {1, 2, 0}))
-                             .WithShape(F32, {5, 2, 4}, {2, 1, 0})));
-}
-
-TEST_F(LayoutAssignmentTest, DotLayoutSetToDefaultIfDefaultValid) {
-  const char* hlo_text = R"(
-  HloModule DotLayout
-  ENTRY dot {
-    p0 = f32[5,3,2] parameter(0)
-    p1 = f32[5,4,3]{0,1,2} parameter(1)
-    ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
-      lhs_batch_dims={0}, lhs_contracting_dims={1},
-      rhs_batch_dims={0}, rhs_contracting_dims={2}
-  })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-
-  ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape(),
-      /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 3, 2}, {2, 1, 0}),
-                                m::Op().WithShape(F32, {5, 4, 3}, {2, 1, 0}))
-                             .WithShape(F32, {5, 2, 4}, {2, 1, 0})));
-}
-
-TEST_F(LayoutAssignmentTest, DotOperandLayoutSetToBatchRowsColsOtherwise) {
-  const char* hlo_text = R"(
-  HloModule DotLayout
-  ENTRY dot {
-    p0 = f32[2,3,5]{2,1,0} parameter(0)
-    p1 = f32[3,4,5] parameter(1)
-    ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
-      lhs_batch_dims={2}, lhs_contracting_dims={1},
-      rhs_batch_dims={2}, rhs_contracting_dims={0}
-  })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-
-  ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape(),
-      /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Dot(m::Op().WithShape(F32, {2, 3, 5}, {0, 1, 2}),
-                                m::Op().WithShape(F32, {3, 4, 5}, {1, 0, 2}))));
-}
-
-TEST_F(LayoutAssignmentTest, DotOperandInconsistentDimLayouts) {
-  const char* hlo_text = R"(
-  HloModule DotLayout
-  ENTRY dot {
-    p0 = f32[5,6,2,3] parameter(0)
-    p1 = f32[6,5,3,4] parameter(1)
-    ROOT dot.1330.10585 = f32[5,6,2,4] dot(p0, p1),
-      lhs_batch_dims={0,1}, lhs_contracting_dims={3},
-      rhs_batch_dims={1,0}, rhs_contracting_dims={2}
-  })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-
-  ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape(),
-      /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 6, 2, 3}, {3, 2, 1, 0}),
-                        m::Op().WithShape(F32, {6, 5, 3, 4}, {3, 2, 0, 1}))));
-}
-
-TEST_F(LayoutAssignmentTest, TransposedDotLayout) {
-  const char* hlo_text = R"(
-  HloModule DotLayout
-  ENTRY dot {
-    p0 = f32[5,2,3] parameter(0)
-    p1 = f32[5,3,4,6] parameter(1)
-    dot = f32[5,2,4,6] dot(p0, p1),
-      lhs_batch_dims={0}, lhs_contracting_dims={2},
-      rhs_batch_dims={0}, rhs_contracting_dims={1}
-    ROOT out = f32[2,5,4,6] transpose(dot), dimensions={1,0,2,3}
-  })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-
-  ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape(),
-      /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Transpose(
-                     m::Dot(m::Op().WithShape(F32, {5, 2, 3}, {2, 1, 0}),
-                            m::Op().WithShape(F32, {5, 3, 4, 6}, {3, 2, 1, 0}))
-                         .WithShape(F32, {5, 2, 4, 6}, {3, 2, 0, 1}))
-                     .WithShape(F32, {2, 5, 4, 6}, {3, 2, 1, 0})));
-}
-
-TEST_F(LayoutAssignmentTest, TransposedDotOfDotLayout) {
-  const char* hlo_text = R"(
-  HloModule DotLayout
-  ENTRY dot {
-    p0 = f32[8,50] parameter(0)
-    p1 = f32[2,8,4,4] parameter(1)
-    p2 = f32[4,38] parameter(2)
-    dot.1 = f32[50,2,4,4]{3,2,1,0} dot(p0, p1),
-      lhs_contracting_dims={0}, rhs_contracting_dims={1}
-    dot.2 = f32[50,2,4,38]{3,2,1,0} dot(dot.1, p2),
-      lhs_contracting_dims={2}, rhs_contracting_dims={0}
-    ROOT out = f32[2,50,38,4]{2,3,0,1} transpose(dot.2), dimensions={1,0,3,2}
-  })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-
-  ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape(),
-      /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-  // The transpose layout is not supported by dot.2. Also, we need a copy
-  // between dot.1 and dot.2, because the needed operand layout for the lhs of
-  // dot.1 cannot be used as layout for dot.1
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Transpose(
-              m::Dot(m::Copy(m::Dot(m::Op().WithShape(F32, {8, 50}, {1, 0}),
-                                    m::Op().WithShape(F32, {2, 8, 4, 4},
-                                                      {3, 2, 0, 1}))
-                                 .WithShape(F32, {50, 2, 4, 4}, {3, 2, 1, 0}))
-                         .WithShape(F32, {50, 2, 4, 4}, {3, 1, 0, 2}),
-                     m::Op().WithShape(F32, {4, 38}, {1, 0}))
-                  .WithShape(F32, {50, 2, 4, 38}, {3, 2, 1, 0}))
-              .WithShape(F32, {2, 50, 38, 4}, {2, 3, 0, 1})));
-}
-
-TEST_F(LayoutAssignmentTest, DotLayoutS8) {
-  const char* hlo_text = R"(
-  HloModule DotLayout
-  ENTRY int8_t {
-    p0 = s8[32,64] parameter(0)
-    p1 = s8[64,96] parameter(1)
-    ROOT out = s32[32,96] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-
-  ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape(),
-      /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Dot(m::Op().WithShape(S8, {32, 64}, {1, 0}),
-                                m::Op().WithShape(S8, {64, 96}, {0, 1}))));
-}
-
-TEST_F(LayoutAssignmentTest, SortLayout) {
-  const char* hlo_text = R"(
-  HloModule SortLayout
-
-  compare {
-    p.0.lhs = f32[] parameter(0)
-    p.0.rhs = f32[] parameter(1)
-    p.1.lhs = f32[] parameter(2)
-    p.1.rhs = f32[] parameter(3)
-    ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
-  }
-
-  ENTRY sort {
-    keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}})
-    values = f32[2,3]{1,0} parameter(0)
-    transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0}
-    ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose),
-      dimensions={1}, to_apply=compare
-  })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-
-  ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape(),
-      /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Sort(m::Op().WithShape(F32, {3, 2}, {1, 0}),
-                                 m::Op().WithShape(F32, {3, 2}, {1, 0}))));
-}
-
-TEST_F(LayoutAssignmentTest, FftLayout) {
-  const char* hlo_text = R"(
-  HloModule Fft_module
-
-  ENTRY Fft {
-    input = c64[8,32]{0,1} parameter(0)
-    fft = c64[8,32] fft(input), fft_type=FFT, fft_length={32}
-    ROOT transpose = c64[32,8] transpose(fft), dimensions={1,0}
-  })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-
-  ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape(),
-      /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Copy(
-                  m::Transpose(m::Fft(m::Op().WithShape(C64, {8, 32}, {1, 0}))
-                                   .WithShape(C64, {8, 32}, {1, 0})))));
-}
-
-TEST_F(LayoutAssignmentTest, CustomCallConstrainedAlias) {
-  const char* module_str = R"(
-HloModule TestModule
-
-ENTRY entry {
-  Arg_0 = f32[2,5,5]{2,1,0} parameter(0)
-  Arg_1 = f32[2,5,5]{2,1,0} parameter(1)
-  Arg_2 = f32[2,5,5]{2,1,0} parameter(2)
-  dot.0 = f32[2,5,5]{2,1,0} dot(Arg_1, Arg_2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}
-  custom-call.0 = (f32[2,5,5]{1,2,0}, s8[16]{0}, s8[16]{0}) custom-call(Arg_0, dot.0), custom_call_target="dummy_call", operand_layout_constraints={f32[2,5,5]{1,2,0}, f32[2,5,5]{1,2,0}}, output_to_operand_aliasing={{0}: (1, {})}
-  ROOT get-tuple-element.0 = f32[2,5,5]{1,2,0} get-tuple-element(custom-call.0), index=0
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
-                          ParseAndReturnVerifiedModule(module_str));
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
-
-  const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0");
-  auto expect_layout = [](const Shape& shape,
-                          absl::Span<const int64_t> minor_to_major) {
-    const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
-    EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
-        << "Expected layout " << expected << ", actual " << shape.layout();
-  };
-  expect_layout(ShapeUtil::GetSubshape(call_0->shape(), {0}), {1, 2, 0});
-  expect_layout(call_0->operand(0)->shape(), {1, 2, 0});
-  expect_layout(call_0->operand(1)->shape(), {1, 2, 0});
-}
-
-TEST_F(LayoutAssignmentTest, MoveToHostCustomCallConstrained) {
-  const char* module_str = R"(
-HloModule TestModule
-
-ENTRY entry {
-  Arg_0 = f32[2,5,5]{2,1,0} parameter(0)
-  custom-call.0 = f32[2,5,5] custom-call(Arg_0), custom_call_target="MoveToHost"
-  ROOT custom-call.1 = f32[2,5,5]{2, 1, 0} custom-call(custom-call.0), custom_call_target="fixed_call", operand_layout_constraints={f32[2,5,5]{1,2,0}}
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
-                          ParseAndReturnVerifiedModule(module_str));
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
-
-  const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0");
-  const Layout input_layout = call_0->operand(0)->shape().layout();
-  const Layout output_layout = call_0->shape().layout();
-  EXPECT_TRUE(LayoutUtil::Equal(input_layout, output_layout))
-      << "Expected the same input/output layouts.  Input: " << input_layout
-      << ". Output: " << output_layout;
-}
-
-TEST_F(LayoutAssignmentTest, MoveToDeviceCustomCallConstrained) {
-  const char* module_str = R"(
-HloModule TestModule
-
-ENTRY entry {
-  Arg_0 = f32[2,5,5]{2,1,0} parameter(0)
-  custom-call.0 = f32[2,5,5] custom-call(Arg_0), custom_call_target="MoveToDevice"
-  ROOT custom-call.1 = f32[2,5,5]{2, 1, 0} custom-call(custom-call.0), custom_call_target="fixed_call", operand_layout_constraints={f32[2,5,5]{1,2,0}}
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
-                          ParseAndReturnVerifiedModule(module_str));
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
-
-  const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0");
-  const Layout input_layout = call_0->operand(0)->shape().layout();
-  const Layout output_layout = call_0->shape().layout();
-  EXPECT_TRUE(LayoutUtil::Equal(input_layout, output_layout))
-      << "Expected the same input/output layouts.  Input: " << input_layout
-      << ". Output: " << output_layout;
-}
-
-TEST_F(LayoutAssignmentTest, ConvCuDNNF8) {
-  if (!GetCudaComputeCapability().IsAtLeast(
-          se::CudaComputeCapability::HOPPER)) {
-    GTEST_SKIP() << "FP8 convolutions require HOPPER or newer architecture.";
-  }
-
-  const char* hlo = R"(
-
-  HloModule jit_conv_general_dilated
-
-  ENTRY main.4 {
-    Arg_0 = f8e4m3fn[1,64,64,16]{3,2,1,0} parameter(0)
-    Arg_1 = f8e4m3fn[3,3,16,32]{3,2,1,0} parameter(1)
-    ROOT conv = f8e4m3fn[1,64,64,32]{3,2,1,0} convolution(Arg_0, Arg_1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
-  }
-)";
-
-  MatchOptimizedHlo(hlo, R"(
-  // CHECK: [[P0:%[^ ]+]] = f8e4m3fn[1,64,64,16]{3,2,1,0} parameter(0)
-  // CHECK: [[P1:%[^ ]+]] = f8e4m3fn[3,3,16,32]{3,2,1,0} parameter(1)
-  // CHECK-NEXT: [[P2:%[^ ]+]] = f8e4m3fn[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2}
-  // CHECK-NEXT: [[CONV:%[^ ]+]] = (f8e4m3fn[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
-  )");
-}
-
-TEST_F(LayoutAssignmentTest, ConvCuDNNBF16) {
-  if (!GetCudaComputeCapability().IsAtLeast(
-          se::CudaComputeCapability::AMPERE)) {
-    GTEST_SKIP() << "Conv with Bfloat16 uses NHWC layout for "
-                    "architectures with Tensor Cores.";
-  }
-
-  const char* hlo = R"(
-
-  HloModule jit_conv_general_dilated
-
-  ENTRY main.4 {
-    Arg_0.1 = bf16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
-    Arg_1.2 = bf16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
-    ROOT convolution.3 = bf16[1,64,64,32]{3,2,1,0} convolution(Arg_0.1, Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 64, 64, 16) rhs_shape=(3, 3, 16, 32) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.8/dist-packages/flax/linen/linear.py" source_line=438}
-  }
-)";
-
-  MatchOptimizedHlo(hlo, R"(
-  // CHECK: [[P0:%[^ ]+]] = bf16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
-  // CHECK: [[P1:%[^ ]+]] = bf16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
-  // CHECK-NEXT: [[P2:%[^ ]+]] = bf16[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2}
-  // CHECK-NEXT: %cudnn-conv.1 = (bf16[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward"
-  )");
-}
-
-TEST_F(LayoutAssignmentTest, ConvCuDNNFP16) {
-  if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
-    GTEST_SKIP() << "Conv with FP16 uses NHWC layout for "
-                    "architectures with Tensor Cores.";
-  }
-
-  const char* hlo = R"(
-
-  HloModule jit_conv_general_dilated
-
-  ENTRY main.4 {
-    Arg_0.1 = f16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
-    Arg_1.2 = f16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
-    ROOT convolution.3 = f16[1,64,64,32]{3,2,1,0} convolution(Arg_0.1, Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
-  }
-)";
-
-  MatchOptimizedHlo(hlo, R"(
-  // CHECK: [[P0:%[^ ]+]] = f16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
-  // CHECK: [[P1:%[^ ]+]] = f16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
-  // CHECK-NEXT: [[P2:%[^ ]+]] = f16[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2}
-  // CHECK-NEXT: %cudnn-conv.1 = (f16[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward"
-  )");
-}
-
-TEST_F(LayoutAssignmentTest, ReduceOperandLayout) {
-  const char* module_str = R"(
-scalar_add_computation {
-  scalar_lhs = c64[] parameter(0)
-  scalar_rhs = c64[] parameter(1)
-  ROOT add.1 = c64[] add(scalar_lhs, scalar_rhs)
-}
-
-ENTRY main {
-  param_0 = c64[512,64,1024,32,128]{4,3,2,1,0} parameter(0)
-  negate = c64[512,64,1024,32,128]{4,3,2,1,0} negate(param_0)
-  constant_7 = c64[] constant((0, 0))
-  ROOT reduce.2 = c64[512,1024,128]{2,1,0} reduce(negate, constant_7), dimensions={1,3}, to_apply=scalar_add_computation
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
-                          ParseAndReturnVerifiedModule(module_str));
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
-  auto reduce = m->entry_computation()->root_instruction();
-  EXPECT_EQ(reduce->operand(0)->shape().layout().minor_to_major(),
-            LayoutUtil::MakeLayout({3, 1, 4, 2, 0}).minor_to_major());
-}
-
-TEST_F(LayoutAssignmentTest, ReduceOperandLayoutDivisorOfWarpSize) {
-  // Same as ReduceOperandLayout, but with a small reduction dimension that
-  // is a divisor of the warp size.
-  const char* module_str = R"(
-scalar_add_computation {
-  scalar_lhs = c64[] parameter(0)
-  scalar_rhs = c64[] parameter(1)
-  ROOT add.1 = c64[] add(scalar_lhs, scalar_rhs)
-}
-
-ENTRY main {
-  param_0 = c64[512,16,1024,128]{3,2,1,0} parameter(0)
-  negate = c64[512,16,1024,128]{3,2,1,0} negate(param_0)
-  constant_7 = c64[] constant((0, 0))
-  ROOT reduce.2 = c64[512,1024,128]{2,1,0} reduce(negate, constant_7), dimensions={1}, to_apply=scalar_add_computation
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
-                          ParseAndReturnVerifiedModule(module_str));
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
-
-  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
-  auto reduce = m->entry_computation()->root_instruction();
-  EXPECT_EQ(reduce->operand(0)->shape().layout().minor_to_major(),
-            LayoutUtil::MakeLayout({1, 3, 2, 0}).minor_to_major());
-}
-
-TEST_F(LayoutAssignmentTest, SendRcvLayout) {
-  const char* hlo = R"(
-HloModule Module
-
-condition  {
-    p = (f32[100,100], (f32[100,100], u32[], token[])) parameter(0)
-    ROOT lt = pred[] constant(1)
-}
-
-body {
-    p = (f32[100,100], (f32[100,100], u32[], token[])) parameter(0)
-
-    t1 = f32[100,100] get-tuple-element(p), index=0
-    t = (f32[100,100], u32[], token[]) get-tuple-element(p), index=1
-    sdone = token[] send-done(t), channel_id=3, frontend_attributes={
-      _xla_send_recv_pipeline="0"
-    }
-    tk = token[] after-all()
-
-
-    rcvd = (f32[100,100]{0,1}, u32[], token[]) recv(tk), channel_id=2
-    zz = (f32[100,100]{0,1}, token[]) recv-done(rcvd), channel_id=2
-
-    rcvd_d = get-tuple-element(zz), index=0
-
-    snd = (f32[100,100]{0,1}, u32[], token[]) send(t1, tk), channel_id=3, frontend_attributes={
-      _xla_send_recv_pipeline="0"
-    }
-    a = add(t1, t1)
-
-    b = add(rcvd_d, a)
-
-    ROOT tup =  tuple(b, snd)
-}
-
-ENTRY %main {
-    p0 = f32[100,100] parameter(0)
-    tk = token[] after-all()
-    snd = (f32[100,100]{0,1}, u32[], token[]) send(p0, tk), channel_id=1, frontend_attributes={
-      _xla_send_recv_pipeline="0"
-    }
-    t = tuple(p0, snd)
-    ROOT loop = while(t), condition=condition, body=body
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
-                          ParseAndReturnVerifiedModule(hlo));
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-
-  RunAndFilecheckHloRewrite(
-      hlo,
-      GpuLayoutAssignment{&computation_layout, GetGpuComputeCapability(),
-                          GetDnnVersion()},
-      R"(
-// CHECK: (f32[100,100]{1,0}, u32[], token[]) recv
-// CHECK:  (f32[100,100]{1,0}, token[]) recv-done
-// CHECK:  (f32[100,100]{1,0}, u32[], token[]) send
-                                )");
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc
index 928011c..215609c 100644
--- a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc
@@ -31,7 +31,7 @@
 #include "xla/layout.h"
 #include "xla/service/buffer_value.h"
 #include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/stream_attribute_annotator.h"
+#include "xla/service/gpu/transforms/stream_attribute_annotator.h"
 #include "xla/service/hlo_cost_analysis.h"
 #include "xla/service/hlo_memory_scheduler.h"
 #include "xla/service/hlo_rematerialization.h"
diff --git a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc
index 1dcab0c..de3adb3 100644
--- a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc
@@ -146,7 +146,7 @@
   EXPECT_EQ(send1->channel_id(), send2->channel_id());
 
   const char* kPeeledAttr = "_xla_send_recv_validation=\"invalid\"";
-  const char* kRotatedAttr = "_xla_send_recv_validation=\"{{0,6}}\"";
+  const char* kRotatedAttr = "_xla_send_recv_validation={{0,6}}";
   EXPECT_THAT(send1->ToString(), ::testing::HasSubstr(kPeeledAttr));
   EXPECT_THAT(recv1->ToString(), ::testing::HasSubstr(kPeeledAttr));
   EXPECT_THAT(send2->ToString(), ::testing::HasSubstr(kRotatedAttr));
diff --git a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.cc b/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.cc
deleted file mode 100644
index 7f1f800..0000000
--- a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.cc
+++ /dev/null
@@ -1,130 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_reduce_scatter_creator.h"
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/collective_opt_utils.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/shape.h"
-#include "xla/status_macros.h"
-#include "tsl/platform/errors.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> ReduceScatterCreator::Run(
-    HloModule *module,
-    const absl::flat_hash_set<absl::string_view> &execution_threads) {
-  const HloModuleConfig &config = module->config();
-  int64_t next_channel_id = hlo_query::NextChannelId(*module);
-
-  bool changed = false;
-  for (HloComputation *computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    for (HloInstruction *instruction :
-         computation->MakeInstructionPostOrder()) {
-      if (instruction->opcode() != HloOpcode::kAllReduce) {
-        continue;
-      }
-      auto *ar = Cast<HloAllReduceInstruction>(instruction);
-      auto ar_spec = MatchReduceScatter(ar, config.num_partitions(),
-                                        config.replica_count(),
-                                        /*allow_multiple_split_dims=*/false,
-                                        /*allow_intervening_reshape=*/true);
-      if (!ar_spec) {
-        VLOG(2) << "Cannot match reduce-scatter " << ar->ToString();
-        continue;
-      }
-
-      HloInstruction *ds = ar_spec->dynamic_slice;
-
-      // Convert to reduce-scatter. The output shape of the reduce-scatter
-      // will be the same as the input shape, except the split dim size is
-      // that of the result of the dynamic slice.
-      const int64_t split_dim = ar_spec->split_dim;
-      Shape scatter_shape = ar->shape();
-      const int64_t split_dim_size = scatter_shape.dimensions(split_dim);
-      HloInstruction *rs_input = ar->mutable_operand(0);
-      const int64_t scatter_dim_size = split_dim_size / ar_spec->group_size;
-      TF_RET_CHECK(scatter_dim_size * ar_spec->group_size <= split_dim_size);
-      if (split_dim_size % ar_spec->group_size != 0) {
-        // The dynamic-slice does not evenly split the scatter dim. In that
-        // case, create a reduce-scatter with the relevant slice of the
-        // all-reduce input.
-        scatter_shape.set_dimensions(split_dim,
-                                     scatter_dim_size * ar_spec->group_size);
-        rs_input = computation->AddInstruction(HloInstruction::CreateSlice(
-            scatter_shape, rs_input,
-            std::vector<int64_t>(scatter_shape.rank(), 0),
-            scatter_shape.dimensions(),
-            std::vector<int64_t>(scatter_shape.rank(), 1)));
-      }
-      scatter_shape.set_dimensions(split_dim, scatter_dim_size);
-
-      std::optional<int64_t> channel_id;
-      if (ar->channel_id()) {
-        // We cannot reuse the channel_id on all-reduce for reduce-scatter.
-        channel_id = next_channel_id++;
-      }
-
-      HloInstruction *ars =
-          computation->AddInstruction(HloInstruction::CreateReduceScatter(
-              scatter_shape, {rs_input}, ar->to_apply(), ar->device_list(),
-              ar->constrain_layout(), channel_id, ar->use_global_device_ids(),
-              ar_spec->split_dim));
-
-      // If there was an intervening reshape, reshape the non-split dimensions
-      // to match that existing reshape. Basically we can just reshape the ars
-      // result to the dynamic slice shape.
-      HloInstruction *result = ars;
-      HloInstruction *reshape = nullptr;
-      if (ds->operand(0) != ar) {
-        reshape = ds->mutable_operand(0);
-        result = computation->AddInstruction(
-            HloInstruction::CreateReshape(ds->shape(), result));
-      }
-
-      // Note that RemoveInstructionAndUnusedOperands may not always remove the
-      // all-reduce operand of the dynamic-slice, so remove all the dead
-      // instructions manually.
-      TF_RETURN_IF_ERROR(ds->ReplaceAllUsesWith(result));
-      TF_RETURN_IF_ERROR(computation->RemoveInstruction(ds));
-      if (reshape) {
-        TF_RETURN_IF_ERROR(computation->RemoveInstruction(reshape));
-      }
-      TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(ar));
-      changed = true;
-    }
-  }
-
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
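A minimal usage sketch of the pass deleted above, following the HloTestBase pattern used by the tests in this change; the HLO constant `kAllReduceThenDynamicSliceHlo` and the surrounding test wiring are illustrative assumptions rather than part of this diff:

  // Sketch: run ReduceScatterCreator on a module whose entry computation ends
  // in dynamic-slice(all-reduce(...)) along a single split dimension.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(kAllReduceThenDynamicSliceHlo));
  ReduceScatterCreator pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
  EXPECT_TRUE(changed);
  // The dynamic-slice has been replaced; the root is now a reduce-scatter.
  EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
            HloOpcode::kReduceScatter);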
diff --git a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.h b/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.h
deleted file mode 100644
index fcecb46..0000000
--- a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_
-#define XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Transforms dynamic-slice(all-reduce) to a reduce-scatter.
-class ReduceScatterCreator : public HloModulePass {
- public:
-  ReduceScatterCreator() = default;
-  absl::string_view name() const override { return "reduce-scatter-creator"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.cc b/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.cc
deleted file mode 100644
index 771e8cb..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_sanitize_constant_names.h"
-
-#include <string>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/llvm_ir/buffer_assignment_util.h"
-#include "xla/service/name_uniquer.h"
-#include "tsl/platform/logging.h"
-
-namespace xla {
-
-namespace gpu {
-
-absl::StatusOr<bool> GpuSanitizeConstantNames::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-
-  NameUniquer instr_name_uniquer(/*separator=*/"_");
-  // Collect the names used for the non-constant HLO instructions.
-  for (HloComputation* computation : module->computations(execution_threads)) {
-    for (HloInstruction* instr : computation->instructions()) {
-      if (instr->opcode() == HloOpcode::kConstant) {
-        continue;
-      }
-
-      // Record the non-constant HLO instruction name in the uniquer, and keep
-      // the original instruction name unchanged.
-      instr_name_uniquer.GetUniqueName(instr->name());
-    }
-  }
-
-  // Sanitize the names for the constant HLO instructions and make them unique.
-  // This is not merged into the above loop because we don't want this pass to
-  // change the names of non-constant instructions; that is, if a constant HLO
-  // conflicts with a non-constant HLO, we rename the constant HLO even when
-  // the non-constant HLO comes later in the HLO module.
-  for (HloComputation* computation : module->computations(execution_threads)) {
-    for (HloInstruction* instr : computation->instructions()) {
-      if (instr->opcode() != HloOpcode::kConstant) {
-        continue;
-      }
-      std::string sanitized_name = llvm_ir::SanitizeConstantName(*instr);
-      instr->SetAndSanitizeName(sanitized_name);
-      instr->UniquifyName(&instr_name_uniquer);
-      // Register this new name with the module's instruction_name_uniquer to
-      // avoid name collisions that might happen in the future.
-      module->instruction_name_uniquer().GetUniqueName(instr->name());
-      changed = true;
-    }
-  }
-
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
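For orientation, a small sketch of the invariant the pass deleted above establishes; the module setup here is an assumption (any module whose constant names contain '.' or '-' will do), and the deleted tests further below exercise the same behavior on concrete names:

  // Sketch: after GpuSanitizeConstantNames runs, no constant name contains
  // '.' or '-', so every constant name is a valid PTX symbol.
  EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value());
  for (const HloInstruction* instr :
       module->entry_computation()->instructions()) {
    if (instr->opcode() == HloOpcode::kConstant) {
      EXPECT_EQ(instr->name().find_first_of(".-"), absl::string_view::npos);
    }
  }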
diff --git a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.h b/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.h
deleted file mode 100644
index 08701a4..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.h
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_
-#define XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Sanitizes HLO instruction names for the GPU backend. Currently, it only
-// replaces . and - in the HLO constant instruction names with _ to please the
-// LLVM PTX backend.
-class GpuSanitizeConstantNames : public HloModulePass {
- public:
-  absl::string_view name() const override { return "sanitize-constant-names"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names_test.cc b/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names_test.cc
deleted file mode 100644
index 17f45dc..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names_test.cc
+++ /dev/null
@@ -1,111 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_sanitize_constant_names.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/literal_util.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-using SanitizeConstantNamesTest = HloTestBase;
-
-TEST_F(SanitizeConstantNamesTest, InstructionNameWithHyphenSanitized) {
-  const char *const kHloString = R"(
-    HloModule HyphenInInstructionName
-      ENTRY kernelEntry {
-        ROOT equal-to = s32[2]{0} constant({42, 73})
-    })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value());
-  HloInstruction *root = module->entry_computation()->root_instruction();
-  EXPECT_EQ(root->name(), "equal_to");
-}
-
-TEST_F(SanitizeConstantNamesTest, InstructionNameWithDotSanitized) {
-  const char *const kHloString = R"(
-    HloModule HyphenInInstructionName
-      ENTRY kernelEntry {
-        ROOT equal.to = s32[2]{0} constant({42, 73})
-    })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value());
-  HloInstruction *root = module->entry_computation()->root_instruction();
-  EXPECT_EQ(root->name(), "equal_to");
-}
-
-TEST_F(SanitizeConstantNamesTest, NewInstructionNameRegisteredWithModule) {
-  const char *const kHloString = R"(
-    HloModule HyphenInInstructionName
-      ENTRY kernelEntry {
-        ROOT equal.to = s32[2]{0} constant({42, 73})
-    })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value());
-  HloInstruction *root = module->entry_computation()->root_instruction();
-  EXPECT_EQ(root->name(), "equal_to");
-
-  auto constant_instr =
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1));
-  constant_instr->SetAndSanitizeName("equal_to");
-  module->entry_computation()->AddInstruction(std::move(constant_instr));
-
-  EXPECT_THAT(FindInstruction(module.get(), "equal_to.1"),
-              GmockMatch(m::Constant()));
-}
-
-TEST_F(SanitizeConstantNamesTest, BufferSanitizedNameCollisionResolved) {
-  const char *const kHloString = R"(
-    HloModule BufferSanitizedName
-      ENTRY kernelEntry {
-      equal.to = s32[2]{0} constant({42, 73})
-      equal-to = s32[2]{0} constant({67, 3})
-      ROOT equal_to = s32[2]{0} add(equal.to, equal-to)
-    })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value());
-  EXPECT_THAT(FindInstruction(module.get(), "equal_to_1"),
-              GmockMatch(m::Constant()));
-  EXPECT_THAT(FindInstruction(module.get(), "equal_to_2"),
-              GmockMatch(m::Constant()));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_scatter_expander.cc b/third_party/xla/xla/service/gpu/gpu_scatter_expander.cc
deleted file mode 100644
index b03b340..0000000
--- a/third_party/xla/xla/service/gpu/gpu_scatter_expander.cc
+++ /dev/null
@@ -1,33 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_scatter_expander.h"
-
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/primitive_util.h"
-
-namespace xla {
-
-bool GpuScatterExpander::InstructionMatchesPattern(HloInstruction* inst) {
-  // TODO(b/129698548): Scattering elements larger than 64 bits is not
-  // supported by XLA:GPU.
-  // TODO(b/227486631): Variadic scatter is not yet supported by GPU.
-  return inst->opcode() == HloOpcode::kScatter &&
-         (inst->shape().IsTuple() ||
-          primitive_util::BitWidth(inst->shape().element_type()) > 64);
-}
-
-}  // namespace xla
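A hedged sketch of what the predicate above means in practice; the `module` here is an assumed HloModule containing a scatter, and the rewrite itself comes from the inherited ScatterExpander logic:

  // Sketch: only tuple-shaped (variadic) or >64-bit-element scatters are
  // expanded; ordinary f32/s32 scatters are left for the GPU scatter emitter.
  GpuScatterExpander expander;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, expander.Run(module.get()));
  // `changed` is true only if the module contained such a scatter, e.g. one
  // producing c128 (128-bit complex) results.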
diff --git a/third_party/xla/xla/service/gpu/gpu_scatter_expander.h b/third_party/xla/xla/service/gpu/gpu_scatter_expander.h
deleted file mode 100644
index 100350c..0000000
--- a/third_party/xla/xla/service/gpu/gpu_scatter_expander.h
+++ /dev/null
@@ -1,40 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_
-#define XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_
-
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/scatter_expander.h"
-
-namespace xla {
-
-// Legalizes scatters on the GPU.
-class GpuScatterExpander : public ScatterExpander {
- public:
-  // Although we pass kEliminateAllScatters, we override this behavior in
-  // InstructionMatchesPattern and select only some scatters to expand.
-  GpuScatterExpander() : ScatterExpander(kEliminateAllScatters) {}
-
-  absl::string_view name() const override { return "gpu_scatter_expander"; }
-
- protected:
-  bool InstructionMatchesPattern(HloInstruction* inst) override;
-};
-
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc
deleted file mode 100644
index a0af798..0000000
--- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc
+++ /dev/null
@@ -1,165 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_schedule_postprocessing.h"
-
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-// Maps a computation to a boolean that indicates whether the computation may
-// invoke custom-calls directly or indirectly, which can eventually trigger gpu
-// synchronization.
-using CustomCallInComputation =
-    absl::flat_hash_map<const HloComputation*, bool>;
-
-// Returns whether the hlo may invoke custom-calls which may trigger gpu
-// synchronization. Currently, we only check for custom-calls, because they are
-// the only operations that can be parallel with asynchronous collectives
-// operations in an hlo-schedule and may trigger gpu synchronization.
-bool MayInvokeCustomCall(
-    const HloInstruction* hlo,
-    const CustomCallInComputation& custom_call_in_computation) {
-  if (hlo->opcode() == HloOpcode::kCustomCall) {
-    return true;
-  }
-
-  return absl::c_any_of(
-      hlo->called_computations(), [&](const HloComputation* callee) {
-        return custom_call_in_computation.find(callee)->second;
-      });
-}
-
-// Returns true if this is an asynchronous collective start operation, excluding
-// P2P operations.
-absl::StatusOr<bool> IsRelevantAsynchronousStart(const HloInstruction* hlo) {
-  if (!hlo_query::IsAsyncCollectiveStartOp(hlo,
-                                           /*include_send_recv=*/false)) {
-    return false;
-  }
-  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                      hlo->backend_config<GpuBackendConfig>());
-  const CollectiveBackendConfig& collective_backend_config =
-      gpu_config.collective_backend_config();
-  return !collective_backend_config.is_sync();
-}
-
-// Returns true if this is a collective done operation, excluding P2P
-// operations.
-absl::StatusOr<bool> IsRelevantAsynchronousDone(const HloInstruction* hlo) {
-  return hlo_query::IsAsyncCollectiveDoneOp(hlo,
-                                            /*include_send_recv=*/false);
-}
-
-// For a given computation, finds all the asynchronous collective operations
-// that aren't parallel with custom-calls and sets their no_parallel_custom_call
-// attribute to true. Also records whether the given computation may invoke
-// custom-calls.
-absl::StatusOr<bool> ProcessComputation(
-    const HloSchedule& schedule, HloComputation* computation,
-    CustomCallInComputation& custom_call_in_computation) {
-  bool changed = false;
-  bool has_custom_call = false;
-  absl::flat_hash_set<HloInstruction*> async_starts;
-  const HloInstructionSequence& sequence = schedule.sequence(computation);
-
-  // Visit instructions in the sequence. Collect relevant asynchronous
-  // collective start ops. When we see a relevant asynchronous collective done
-  // op, remove the corresponding start op from the collection and set its
-  // attribute no_parallel_custom_call to true. When we see a custom-call, clear
-  // the start ops from the collection and keep their attribute
-  // no_parallel_custom_call as false.
-  const std::vector<HloInstruction*>& all_instructions =
-      sequence.instructions();
-  for (HloInstruction* hlo : all_instructions) {
-    if (MayInvokeCustomCall(hlo, custom_call_in_computation)) {
-      async_starts.clear();
-      has_custom_call = true;
-      continue;
-    }
-    TF_ASSIGN_OR_RETURN(bool is_async_start, IsRelevantAsynchronousStart(hlo));
-    if (is_async_start) {
-      async_starts.insert(hlo);
-      continue;
-    }
-
-    TF_ASSIGN_OR_RETURN(bool is_async_done, IsRelevantAsynchronousDone(hlo));
-    if (is_async_done) {
-      HloInstruction* async_start = hlo->mutable_operand(0);
-      if (async_starts.contains(async_start)) {
-        changed = true;
-        TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                            async_start->backend_config<GpuBackendConfig>());
-        CollectiveBackendConfig& collective_backend_config =
-            *gpu_config.mutable_collective_backend_config();
-        collective_backend_config.set_no_parallel_custom_call(true);
-        TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config));
-        async_starts.erase(async_start);
-      }
-    }
-  }
-
-  custom_call_in_computation[computation] = has_custom_call;
-  return changed;
-}
-
-}  // anonymous namespace
-
-absl::StatusOr<bool> GpuSchedulePostprocessing::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  if (!module->has_schedule()) return false;
-  HloSchedule& schedule = module->schedule();
-  bool changed = false;
-  CustomCallInComputation custom_call_in_computation;
-
-  // We visit computations in the order of callees to callers, as information is
-  // propagated from callees to callers.
-  std::vector<HloComputation*> all_computations =
-      module->MakeComputationPostOrder(execution_threads);
-  for (auto iter = all_computations.begin(); iter != all_computations.end();
-       ++iter) {
-    HloComputation* computation = *iter;
-    if (computation->IsFusionComputation()) {
-      custom_call_in_computation[computation] = false;
-      continue;
-    }
-
-    TF_ASSIGN_OR_RETURN(
-        bool result,
-        ProcessComputation(schedule, computation, custom_call_in_computation));
-    changed |= result;
-  }
-
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
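A short sketch of how the refined attribute could be consumed downstream; `async_start` is an assumed HloInstruction* pointing at a collective-start op, and the accessors match the ones used by the pass deleted above:

  // Sketch: read the flag back from the collective-start op's backend config.
  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
                          async_start->backend_config<GpuBackendConfig>());
  if (gpu_config.collective_backend_config().no_parallel_custom_call()) {
    // No custom-call can run in parallel with this collective, so a runtime
    // may skip the rendezvous of participating threads.
  }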
diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h
deleted file mode 100644
index d8eda81..0000000
--- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_
-#define XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Amends a schedule result with the needed information to support a runtime
-// implementation. Currently, this pass refines the no_parallel_custom_call
-// attribute of asynchronous collective operations to support runtime
-// optimizations such as skipping the rendezvous of all participating
-// threads for NCCL collective operations. In particular, it sets the attribute
-// value for Collective-start operations with is_sync=false; it also keeps the
-// attribute value untouched for the operations with is_sync=true and for P2P
-// operations, assuming the runtime won't use those values.
-//
-class GpuSchedulePostprocessing : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "gpu-schedule-postprocessing";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc
deleted file mode 100644
index 9d4956b..0000000
--- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc
+++ /dev/null
@@ -1,163 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_schedule_postprocessing.h"
-
-#include <memory>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using GpuSchedulePostprocessingTest = HloTestBase;
-
-TEST_F(GpuSchedulePostprocessingTest, SynchronousOpsNotChanged) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule module, is_scheduled=true
-
-  ENTRY entry {
-    pf32 = f32[1] parameter(0)
-
-    all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}}
-    ROOT all-gather-done = f32[2] all-gather-done(all-gather-start)
-  }
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kHloString)));
-  GpuSchedulePostprocessing pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(GpuSchedulePostprocessingTest, P2POpsNotChanged) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule module, is_scheduled=true
-
-  ENTRY main {
-    f0 = f32[] constant(0.0)
-    init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
-
-    after-all = token[] after-all()
-    recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=2,
-      frontend_attributes={
-      _xla_send_recv_source_target_pairs="{{0,1}, {1,2}}"
-    }
-    recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=2
-    ROOT recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0
-  }
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kHloString)));
-  GpuSchedulePostprocessing pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsChanged) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule module, is_scheduled=true
-
-  ENTRY entry {
-    pf32 = f32[1] parameter(0)
-    pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call"
-    all-gather-start = (f32[1], f32[2]) all-gather-start(pf32.2), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
-    ROOT all-gather-done = f32[2] all-gather-done(all-gather-start)
-  }
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kHloString)));
-  GpuSchedulePostprocessing pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          start->backend_config<GpuBackendConfig>());
-  const CollectiveBackendConfig& collective_backend_config =
-      gpu_config.collective_backend_config();
-  EXPECT_TRUE(collective_backend_config.no_parallel_custom_call());
-}
-
-TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule module, is_scheduled=true
-
-  ENTRY entry {
-    pf32 = f32[1] parameter(0)
-    all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
-    pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call"
-    all-gather-done = f32[2] all-gather-done(all-gather-start)
-    ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done)
-  }
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kHloString)));
-  GpuSchedulePostprocessing pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-  EXPECT_FALSE(changed);
-
-  HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          start->backend_config<GpuBackendConfig>());
-  const CollectiveBackendConfig& collective_backend_config =
-      gpu_config.collective_backend_config();
-  EXPECT_FALSE(collective_backend_config.no_parallel_custom_call());
-}
-
-TEST_F(GpuSchedulePostprocessingTest,
-       AsynchronousOpsWithParallelNestedCustomcall) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule module, is_scheduled=true
-  foo {
-    v = f32[1] parameter(0)
-    ROOT ret = f32[1] custom-call(v), custom_call_target="my_custom_call"
-  }
-
-  ENTRY entry {
-    pf32 = f32[1] parameter(0)
-    all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
-    pf32.2 = f32[1] call(f32[1] pf32), to_apply=foo
-    all-gather-done = f32[2] all-gather-done(all-gather-start)
-    ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done)
-  }
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnUnverifiedModule((kHloString)));
-  GpuSchedulePostprocessing pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-  EXPECT_FALSE(changed);
-
-  HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          start->backend_config<GpuBackendConfig>());
-  const CollectiveBackendConfig& collective_backend_config =
-      gpu_config.collective_backend_config();
-  EXPECT_FALSE(collective_backend_config.no_parallel_custom_call());
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc b/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc
deleted file mode 100644
index 217387c..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc
+++ /dev/null
@@ -1,343 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_sort_rewriter.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <memory>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/comparison_util.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/runtime/cub_sort_thunk.h"
-#include "xla/service/stable_sort_expander.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-// Analyze sort comparer function.
-struct SortComputationAnalysis {
-  int key_operand;  // 0 or 1
-  bool descending;
-};
-
-std::pair<int64_t, int64_t> ParametersFromCmpOperands(
-    const HloCompareInstruction* cmp_op) {
-  if (cmp_op == nullptr) {
-    return std::pair<int64_t, int64_t>(-1, -1);
-  }
-  const HloParameterInstruction* param0 =
-      DynCast<HloParameterInstruction>(cmp_op->operand(0));
-  const HloParameterInstruction* param1 =
-      DynCast<HloParameterInstruction>(cmp_op->operand(1));
-  return (param0 && param1) ? std::make_pair(param0->parameter_number(),
-                                             param1->parameter_number())
-                            : std::pair<int64_t, int64_t>(-1, -1);
-}
-
-// Returns sort info on compatible compare instructions. The instruction may
-// belong to a computation that has 2 or 4 parameters. If this is the root
-// instruction of a computation with 4 parameters, the analysis only succeeds
-// when 2 of the parameters are ignored.
-std::optional<SortComputationAnalysis> AnalyzeCompareOp(
-    const HloInstruction* maybe_compare_op) {
-  // Root instruction must be a comparison with a valid direction.
-  const HloCompareInstruction* compare =
-      DynCast<HloCompareInstruction>(maybe_compare_op);
-  if (compare == nullptr || compare->direction() == ComparisonDirection::kEq ||
-      compare->direction() == ComparisonDirection::kNe) {
-    return std::nullopt;
-  }
-
-  // Compare should operate on the function parameters for a single tensor.
-  auto [index0, index1] = ParametersFromCmpOperands(compare);
-  if (index0 == -1 || index1 == -1) {
-    return std::nullopt;
-  }
-
-  // When sorting a pair of tensors, the parameters should be adjacent.
-  int first_index = std::min(index0, index1);
-  if (first_index % 2 != 0 || std::max(index0, index1) != first_index + 1) {
-    return std::nullopt;
-  }
-
-  // Return the tensor index and the sort direction.
-  bool descending = compare->direction() == ComparisonDirection::kGt ||
-                    compare->direction() == ComparisonDirection::kGe;
-  bool reverse = first_index != index0;
-  return SortComputationAnalysis{first_index / 2, descending != reverse};
-}
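// Worked example (a sketch added for clarity, not present in the original
// file): for a comparator whose root is
//   ROOT gt = pred[] compare(p.1.lhs, p.1.rhs), direction=GT
// the compared parameters are 2 and 3, so first_index = 2 (key operand 1),
// descending = true (kGt), reverse = false, and the analysis returns
// SortComputationAnalysis{/*key_operand=*/1, /*descending=*/true}.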
-
-// Detects a sort with these properties:
-// - Has two operands -- one is an iota op
-// - Has a comparison computation that takes 4 inputs and compares them
-// hierarchically, so that the iota inputs are the final tie-breaker.
-//
-// The above is equivalent to a stable sort where the iota operand is completely
-// ignored. That simpler comparator is the one detected in AnalyzeCompareOp, but
-// that's insufficient, because the StableSortExpander pass expands it into the
-// more complex version detected below.
-std::optional<SortComputationAnalysis> AnalyzeComplexSortComputation(
-    const HloSortInstruction& sort_op) {
-  auto computation = sort_op.called_computations().front();
-  if (computation->num_parameters() != 4) {
-    return std::nullopt;
-  }
-
-  int64_t iota_operand_index =
-      StableSortExpander::IotaOperandIndexForStableSort(sort_op);
-  if (iota_operand_index < 0) {
-    return std::nullopt;
-  }
-
-  auto root = computation->root_instruction();
-  if (root->opcode() != HloOpcode::kSelect) {
-    return std::nullopt;
-  }
-
-  // Check that the middle operand of the select compares the iota input.
-  auto iota_cmp = DynCast<HloCompareInstruction>(root->operand(1));
-  auto [iotap0, iotap1] = ParametersFromCmpOperands(iota_cmp);
-  if (iota_cmp == nullptr ||
-      iota_cmp->direction() != ComparisonDirection::kLt ||
-      iotap0 != iota_operand_index * 2 ||
-      iotap1 != iota_operand_index * 2 + 1) {
-    return std::nullopt;
-  }
-
-  // Check that the first operand of the select is an EQ comparison of the
-  // values (non-iota) input.
-  auto eq_cmp = DynCast<HloCompareInstruction>(root->operand(0));
-  if (eq_cmp == nullptr || eq_cmp->direction() != ComparisonDirection::kEq) {
-    return std::nullopt;
-  }
-
-  // EQ comparison case 1: direct comparison of parameters
-  auto [p0, p1] = ParametersFromCmpOperands(eq_cmp);
-  if (p0 < 0 || p1 < 0) {
-    // EQ comparison case 2: comparison of comparisons. This is what
-    // the StableSortExpander pass currently generates.
-    auto cmp = DynCast<HloCompareInstruction>(eq_cmp->operand(0));
-    auto cmp_reverse = DynCast<HloCompareInstruction>(eq_cmp->operand(1));
-    auto [a, b] = ParametersFromCmpOperands(cmp);
-    auto [p, q] = ParametersFromCmpOperands(cmp_reverse);
-    if (cmp == nullptr || cmp_reverse == nullptr || a < 0 || b < 0 || a != q ||
-        b != p || cmp->direction() != cmp_reverse->direction() ||
-        cmp->direction() == Comparison::Direction::kEq ||
-        cmp->direction() == Comparison::Direction::kNe) {
-      return std::nullopt;
-    }
-  }
-
-  // At this point only the last operand of the select needs to be verified.
-  return AnalyzeCompareOp(root->operand(2));
-}
-
-std::optional<SortComputationAnalysis> AnalyzeSortOp(
-    const HloSortInstruction& sort_op) {
-  auto computation = sort_op.called_computations().front();
-
-  // First, check if the computation is a simple compare op on the operands.
-  auto result = AnalyzeCompareOp(computation->root_instruction());
-  if (!result.has_value()) {
-    // If the above fails, check if the sort instruction and comparer are more
-    // complex, like what is produced by the StableSortExpander pass.
-    result = AnalyzeComplexSortComputation(sort_op);
-  }
-  return result;
-}
-
-// Create runner for CUB sort operation.
-absl::StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateRunner(
-    HloSortInstruction* sort_op, const SortComputationAnalysis& sort_config) {
-  int value_index = 1 - sort_config.key_operand;
-  return CubSortRunnerInterface::Create(
-      sort_op->operand(sort_config.key_operand)->shape().element_type(),
-      sort_op->operand_count() == 2
-          ? std::optional(sort_op->operand(value_index)->shape().element_type())
-          : std::nullopt);
-}
-
-// Verify that the sort operation (shape, comparator, and operand types) is
-// supported by CUB.
-bool IsCubCompatibleSort(HloSortInstruction* sort_op) {
-  VLOG(1) << "Sort instruction: " << sort_op->name();
-  if (sort_op->operand_count() != 1 && sort_op->operand_count() != 2) {
-    VLOG(2) << "Unsupported operand count: " << sort_op->operand_count();
-    return false;
-  }
-
-  const Shape& operand_shape = sort_op->operand(0)->shape();
-  if (sort_op->sort_dimension() != operand_shape.rank() - 1) {
-    VLOG(2) << "Sort dimension should be the minor one";
-    return false;
-  }
-  if (Product(operand_shape.dimensions()) <
-      GpuSortRewriter::SortSizeThreshold()) {
-    VLOG(2) << "Tensor shape size is too small to see an improvement";
-    return false;
-  }
-
-  auto sort_config = AnalyzeSortOp(*sort_op);
-  if (!sort_config.has_value()) {
-    VLOG(2) << "Only simple compare computations are supported";
-    return false;
-  }
-  if (!CreateRunner(sort_op, *sort_config).ok()) {
-    VLOG(2) << "Unsupported operand types (no compiled CUB kernels)";
-    return false;
-  }
-  VLOG(2) << "Sort operation is compatible";
-  return true;
-}
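For a concrete feel for the size gate above, this standalone snippet (illustrative only; the threshold value matches the default in the header file later in this diff) shows which shapes clear the default threshold:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t kThreshold = 16385;  // default sort_size_threshold_
  // f32[128,128] has 16384 elements and stays on the XLA sort path;
  // f32[129,128] crosses the threshold and is eligible for the CUB rewrite.
  std::printf("%d %d\n", int64_t{128} * 128 >= kThreshold,
              int64_t{129} * 128 >= kThreshold);
  return 0;
}
```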
-
-// Restore the result shape after sorting a pair of tensors.
-// The trailing argument is the scratch buffer which should be discarded.
-HloInstruction* UnpackResultPair(HloSortInstruction* sort_op,
-                                 HloInstruction* custom_call, bool swap) {
-  HloComputation* parent = sort_op->parent();
-  HloInstruction* gte0 =
-      parent->AddInstruction(HloInstruction::CreateGetTupleElement(
-          sort_op->operand(0)->shape(), custom_call, swap ? 1 : 0));
-  HloInstruction* gte1 =
-      parent->AddInstruction(HloInstruction::CreateGetTupleElement(
-          sort_op->operand(1)->shape(), custom_call, swap ? 0 : 1));
-  return parent->AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
-}
-
-}  // namespace
-
-// Rewrites a single sort instruction with a custom call.
-absl::StatusOr<bool> GpuSortRewriter::RunOnInstruction(
-    HloSortInstruction* sort_op) {
-  // Get the sort tensor index and direction.
-  SortComputationAnalysis sort_config = AnalyzeSortOp(*sort_op).value();
-
-  // Get scratch size requirements from CUB.
-  const Shape& operand_shape = sort_op->operand(0)->shape();
-  int64_t batch_size = Product(operand_shape.dimensions()) /
-                       operand_shape.dimensions(sort_op->sort_dimension());
-
-  TF_ASSIGN_OR_RETURN(auto runner, CreateRunner(sort_op, sort_config));
-  TF_ASSIGN_OR_RETURN(
-      int64_t scratch_size,
-      runner->GetScratchSize(Product(operand_shape.dimensions()), batch_size));
-
-  // Align and increase scratch size to fit the offsets.
-  if (batch_size > 1) {
-    scratch_size += sizeof(int) - scratch_size % sizeof(int);
-    scratch_size += (batch_size + 1) * sizeof(int);
-  }
-
-  // Values are only present if sorting a pair of tensors.
-  HloInstruction* keys = sort_op->mutable_operand(0);
-  HloInstruction* values = nullptr;
-  if (sort_op->operand_count() == 2) {
-    values = sort_op->mutable_operand(1);
-    if (sort_config.key_operand == 1) {
-      std::swap(keys, values);
-    }
-  }
-
-  // Build the resulting shape for the custom call.
-  std::vector<Shape> shapes{keys->shape()};
-  std::vector<HloInstruction*> operands{keys};
-  if (values != nullptr) {
-    shapes.push_back(values->shape());
-    operands.push_back(values);
-  }
-  shapes.push_back(ShapeUtil::MakeShape(U8, {scratch_size}));
-  Shape call_shape = ShapeUtil::MakeTupleShape(absl::MakeSpan(shapes));
-
-  // Build the custom call instruction.
-  HloInstruction* custom_call =
-      sort_op->parent()->AddInstruction(HloInstruction::CreateCustomCall(
-          call_shape, absl::MakeSpan(operands), kCubDeviceRadixSortTarget));
-
-  xla::SortOptions backend_config;
-  backend_config.set_descending(sort_config.descending);
-  TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config));
-
-  // Build the replacement instruction.
-  HloInstruction* replacement;
-  if (sort_op->operand_count() == 1) {
-    replacement =
-        sort_op->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
-            sort_op->shape(), custom_call, 0));
-  } else {
-    replacement = UnpackResultPair(sort_op, custom_call,
-                                   /*swap=*/sort_config.key_operand == 1);
-  }
-
-  // Replace sort operation with custom call followed by GTE.
-  TF_RETURN_IF_ERROR(
-      sort_op->parent()->ReplaceInstruction(sort_op, replacement));
-  return true;
-}
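The scratch-buffer padding in RunOnInstruction above is plain integer arithmetic; the following standalone sketch (assuming sizeof(int) == 4; names are illustrative and not part of the XLA sources) reproduces it with concrete numbers:

```cpp
#include <cstdint>
#include <cstdio>

// Same padding arithmetic as in RunOnInstruction above.
int64_t PadScratch(int64_t scratch_size, int64_t batch_size) {
  constexpr int64_t kOffsetSize = sizeof(int);  // per-batch offsets are ints
  if (batch_size > 1) {
    scratch_size += kOffsetSize - scratch_size % kOffsetSize;  // align
    scratch_size += (batch_size + 1) * kOffsetSize;            // offsets array
  }
  return scratch_size;
}

int main() {
  // E.g. 1001 bytes reported by CUB and 10 batches:
  // 1001 -> 1004 (alignment) -> 1048 (plus 11 offsets of 4 bytes each).
  std::printf("%lld\n", static_cast<long long>(PadScratch(1001, 10)));
  return 0;
}
```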
-
-// Rewrites the sorts in the given computation into calls to CUB.
-absl::StatusOr<bool> GpuSortRewriter::RunOnComputation(
-    HloComputation* computation) {
-  std::vector<HloSortInstruction*> sort_ops;
-  for (auto* inst : computation->instructions()) {
-    HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
-    if (sort != nullptr && IsCubCompatibleSort(sort)) {
-      sort_ops.push_back(sort);
-    }
-  }
-  bool changed = false;
-  for (auto* sort : sort_ops) {
-    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(sort));
-    changed |= result;
-  }
-  return changed;
-}
-
-// Replace compatible sort operations with custom calls.
-absl::StatusOr<bool> GpuSortRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  XLA_VLOG_LINES(2, "GpuSortRewriter::Run(), before:\n" + module->ToString());
-  bool changed = false;
-  for (HloComputation* computation :
-       module->MakeNonfusionComputations(execution_threads)) {
-    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
-    changed |= result;
-  }
-  XLA_VLOG_LINES(2, "GpuSortRewriter::Run(), after:\n" + module->ToString());
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h b/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h
deleted file mode 100644
index 51dba3c..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_
-#define XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites sort operations into CustomCall HLOs that call into CUB.
-// Only a subset of shapes is supported - either a single tensor with a simple
-// compare function or a pair of tensors where keys are unsigned integers.
-
-class GpuSortRewriter : public HloModulePass {
- public:
-  absl::string_view name() const override { return "gpu-sort-rewriter"; }
-
-  // CUB radix sort is slower than XLA sort on small shapes, so do not rewrite
-  // tensors with sizes below this limit.
-  static int SortSizeThreshold() { return sort_size_threshold_; }
-  static void SetSortSizeThresholdForTestingOnly(int threshold) {
-    // We need to be able to reduce the threshold for testing, so that the tests
-    // can run and compare against the reference interpreter, which is quite
-    // slow.
-    sort_size_threshold_ = threshold;
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  absl::StatusOr<bool> RunOnInstruction(HloSortInstruction* sort_op);
-  absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
-
-  static inline int sort_size_threshold_ = 16385;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_stub.cc b/third_party/xla/xla/service/gpu/gpu_sort_rewriter_stub.cc
deleted file mode 100644
index abacbc1..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_stub.cc
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_sort_rewriter.h"
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> GpuSortRewriter::RunOnInstruction(
-    HloSortInstruction* sort_op) {
-  return false;
-}
-
-absl::StatusOr<bool> GpuSortRewriter::RunOnComputation(
-    HloComputation* computation) {
-  return false;
-}
-
-absl::StatusOr<bool> GpuSortRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  return false;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc b/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc
deleted file mode 100644
index 69cdb92..0000000
--- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc
+++ /dev/null
@@ -1,453 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_sort_rewriter.h"
-
-#include <utility>
-
-#include <gtest/gtest.h>
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class GpuSortRewriterTest : public HloTestBase {
- public:
-  void SetUp() override {
-    HloTestBase::SetUp();
-    GpuSortRewriter::SetSortSizeThresholdForTestingOnly(1000);
-  }
-
-  bool RunModuleAndPass(HloModule* module) {
-    auto cloned = module->Clone();
-    bool changed = GpuSortRewriter().Run(module).value();
-    if (changed) {
-      // Here we run an end-to-end test to make sure that GpuSortRewriter does
-      // not introduce an incorrect rewrite. To do this, we need to clone the
-      // original module because the interpreter cannot process the already
-      // optimized module.
-      EXPECT_TRUE(RunAndCompare(std::move(cloned), ErrorSpec{0, 0}));
-    }
-    return changed;
-  }
-
-  void ExpectDirection(const HloInstruction* instruction, bool descending) {
-    auto config = instruction->backend_config<xla::SortOptions>();
-    EXPECT_EQ(config->descending(), descending);
-  }
-};
-
-// Basic sort: ascending.
-TEST_F(GpuSortRewriterTest, SortKeysLessThan) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
-  %input = f32[1000] parameter(0)
-  ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_TRUE(RunModuleAndPass(module.get()));
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
-  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
-                  /*descending=*/false);
-}
-
-// Basic sort: descending.
-TEST_F(GpuSortRewriterTest, SortKeysGreaterThan) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %gt = pred[] compare(%lhs, %rhs), direction=GT
-}
-
-ENTRY %main {
-  %input = f32[1000] parameter(0)
-  ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_TRUE(RunModuleAndPass(module.get()));
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
-  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
-                  /*descending=*/true);
-}
-
-// Comparer swaps the parameter order -> direction is reversed.
-TEST_F(GpuSortRewriterTest, SortKeysGreaterThanSwapped) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs = f32[] parameter(1)
-  %rhs = f32[] parameter(0)
-  ROOT %gt = pred[] compare(%lhs, %rhs), direction=GT
-}
-
-ENTRY %main {
-  %input = f32[1000] parameter(0)
-  ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_TRUE(RunModuleAndPass(module.get()));
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
-  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
-                  /*descending=*/false);
-}
-
-// Sort a pair of tensors, keys go first.
-TEST_F(GpuSortRewriterTest, SortPairs) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs_key = u32[] parameter(0)
-  %rhs_key = u32[] parameter(1)
-  %lhs_value = f32[] parameter(2)
-  %rhs_value = f32[] parameter(3)
-  ROOT %lt = pred[] compare(%lhs_key, %rhs_key), direction=LT
-}
-
-ENTRY %main {
-  %input_keys = u32[1000] parameter(0)
-  %input_values = f32[1000] parameter(1)
-  ROOT %sort = (u32[1000], f32[1000]) sort(%input_keys, %input_values),
-      dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_TRUE(RunModuleAndPass(module.get()));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
-                                  m::GetTupleElement(m::CustomCall(), 1))));
-}
-
-// Sort a pair of tensors, keys go last.
-TEST_F(GpuSortRewriterTest, SortPairsSwapped) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs_value = f32[] parameter(0)
-  %rhs_value = f32[] parameter(1)
-  %lhs_key = u32[] parameter(2)
-  %rhs_key = u32[] parameter(3)
-  ROOT %lt = pred[] compare(%lhs_key, %rhs_key), direction=LT
-}
-
-ENTRY %main {
-  %input_values = f32[1000] parameter(0)
-  %input_keys = u32[1000] parameter(1)
-  ROOT %sort = (f32[1000], u32[1000]) sort(%input_values, %input_keys),
-      dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_TRUE(RunModuleAndPass(module.get()));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 1),
-                                  m::GetTupleElement(m::CustomCall(), 0))));
-}
-
-// CUB sort doesn't support more than two tensors.
-TEST_F(GpuSortRewriterTest, NoRewriteManyTensors) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  %unused1 = f64[] parameter(2)
-  %unused2 = f64[] parameter(3)
-  %unused3 = u64[] parameter(4)
-  %unused4 = u64[] parameter(5)
-  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
-  %input1 = f32[1000] parameter(0)
-  %input2 = f64[1000] parameter(1)
-  %input3 = u64[1000] parameter(2)
-  ROOT %sort = (f32[1000], f64[1000], u64[1000]) sort(%input1, %input2, %input3),
-      dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_FALSE(RunModuleAndPass(module.get()));
-}
-
-// The sort dimension must be the most minor (last) dimension.
-TEST_F(GpuSortRewriterTest, NoRewriteNonMinorSortDimension) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
-  %input = f32[1000,4] parameter(0)
-  ROOT %sort = f32[1000,4] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_FALSE(RunModuleAndPass(module.get()));
-}
-
-// Kernels are compiled for a subset of types.
-TEST_F(GpuSortRewriterTest, NoRewriteUnsupportedType) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs = pred[] parameter(0)
-  %rhs = pred[] parameter(1)
-  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
-  %input = pred[1000] parameter(0)
-  ROOT %sort = pred[1000] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_FALSE(RunModuleAndPass(module.get()));
-}
-
-// Comparer must be a simple function.
-TEST_F(GpuSortRewriterTest, NoRewriteComplexComparer) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs = f32[] parameter(0)
-  %lhs_scaled = f32[] multiply(%lhs, f32[] constant(2))
-  %rhs = f32[] parameter(1)
-  ROOT %lt = pred[] compare(%lhs_scaled, %rhs), direction=LT
-}
-
-ENTRY %main {
-  %input = f32[1000] parameter(0)
-  ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_FALSE(RunModuleAndPass(module.get()));
-}
-
-// Comparer must use adjacent input values.
-TEST_F(GpuSortRewriterTest, NoRewriteMixedKeysValues) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs_key = u32[] parameter(0)
-  %rhs_key = u32[] parameter(1)
-  %lhs_value = u32[] parameter(2)
-  %rhs_value = u32[] parameter(3)
-  ROOT %mixed = pred[] compare(%rhs_key, %lhs_value), direction=LT
-}
-
-ENTRY %main {
-  %input_keys = u32[1000] parameter(0)
-  %input_values = u32[1000] parameter(1)
-  ROOT %sort = (u32[1000], u32[1000]) sort(%input_keys, %input_values),
-      dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_FALSE(RunModuleAndPass(module.get()));
-}
-
-// Small shapes do not see improvement from CUB sort.
-TEST_F(GpuSortRewriterTest, NoRewriteSmallSize) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
-  %input = f32[100] parameter(0)
-  ROOT %sort = f32[100] sort(%input), dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_FALSE(RunModuleAndPass(module.get()));
-}
-
-// Basic sort: with batch dimension.
-TEST_F(GpuSortRewriterTest, SortWithBatchDim) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
-  %input = f32[10,100] parameter(0)
-  ROOT %sort = f32[10,100] sort(%input), dimensions={1}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_TRUE(RunModuleAndPass(module.get()));
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
-  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
-                  /*descending=*/false);
-}
-
-// Basic sort: with multiple batch dimensions.
-TEST_F(GpuSortRewriterTest, SortWithMultipleBatchDims) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
-}
-
-ENTRY %main {
-  %input = f32[10,10,10] parameter(0)
-  ROOT %sort = f32[10,10,10] sort(%input), dimensions={2}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_TRUE(RunModuleAndPass(module.get()));
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
-  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
-                  /*descending=*/false);
-}
-
-// Sort a pair of tensors (values, indices generated by iota) with a complex
-// compare.
-TEST_F(GpuSortRewriterTest, SortPairsIotaComparerSimple) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs = u16[] parameter(0)
-  %rhs = u16[] parameter(1)
-  %lhs_index = s32[] parameter(2)
-  %rhs_index = s32[] parameter(3)
-
-  cmp_indices = pred[] compare(%lhs_index, %rhs_index), direction=LT
-  cmp_lr = pred[] compare(%lhs, %rhs), direction=GT
-  cmp_eq = pred[] compare(%lhs, %rhs), direction=EQ
-
-  ROOT %lt = pred[] select(cmp_eq, cmp_indices, cmp_lr)
-}
-
-ENTRY %main {
-  %inputs = u16[1000] parameter(0)
-  %iota = s32[1000] iota(), iota_dimension=0
-  ROOT %sort = (u16[1000], s32[1000]) sort(%inputs, %iota),
-      dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_TRUE(RunModuleAndPass(module.get()));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
-                                  m::GetTupleElement(m::CustomCall(), 1))));
-}
-
-// Sort a pair of tensors (values, indices generated by iota) with a complex
-// compare computation that matches the output of the StableSortExpander pass.
-TEST_F(GpuSortRewriterTest, SortPairsIotaComparerLikeStableSortExpander) {
-  constexpr char kHlo[] = R"(
-HloModule TestModule
-
-%compare {
-  %lhs = u16[] parameter(0)
-  %rhs = u16[] parameter(1)
-  %lhs_index = s32[] parameter(2)
-  %rhs_index = s32[] parameter(3)
-
-  cmp_indices = pred[] compare(%lhs_index, %rhs_index), direction=LT
-  cmp_lr = pred[] compare(%lhs, %rhs), direction=GT
-  cmp_rl = pred[] compare(%rhs, %lhs), direction=GT
-  cmp_eq = pred[] compare(cmp_lr, cmp_rl), direction=EQ
-
-  ROOT %lt = pred[] select(cmp_eq, cmp_indices, cmp_lr)
-}
-
-ENTRY %main {
-  %inputs = u16[1000] parameter(0)
-  %iota = s32[1000] iota(), iota_dimension=0
-  ROOT %sort = (u16[1000], s32[1000]) sort(%inputs, %iota),
-      dimensions={0}, to_apply=%compare
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
-  EXPECT_TRUE(RunModuleAndPass(module.get()));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
-                                  m::GetTupleElement(m::CustomCall(), 1))));
-}
-
-TEST_F(GpuSortRewriterTest, SortSizeThresholdIsSet) {
-  EXPECT_EQ(GpuSortRewriter::SortSizeThreshold(), 1000);
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc
index d84797d..06e1e6f 100644
--- a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc
+++ b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc
@@ -27,7 +27,7 @@
 #include "xla/service/algebraic_simplifier.h"
 #include "xla/service/conditional_simplifier.h"
 #include "xla/service/gather_expander.h"
-#include "xla/service/gpu/gpu_algebraic_simplifier.h"
+#include "xla/service/gpu/transforms/algebraic_simplifier.h"
 #include "xla/service/hlo_constant_folding.h"
 #include "xla/service/hlo_dce.h"
 #include "xla/service/hlo_module_config.h"
diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc
deleted file mode 100644
index 8f5e261..0000000
--- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc
+++ /dev/null
@@ -1,1150 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_windowed_einsum_handler.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/shape_inference.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-namespace m = match;
-
-// Enables the creation of FP8 GEMM Custom Calls for all-gather and
-// reduce-scatter windowed einsums in gemm_rewriter.cc by moving the scalings
-// and type conversions of FP8 operands into the bodies of their while loops,
-// i.e. rewrites
-//
-//   inputs --> dequant --> while loop {dynamic-slice/collective-permute/dot}
-//
-// into
-//
-//   inputs --> while loop {dequant --> dynamic-slice/collective-permute/dot}.
-absl::Status ShiftDequantizationF8(const HloComputation* comp,
-                                   const std::array<HloInstruction*, 2>& gte) {
-  HloInstruction* while_instr = comp->WhileCallInstruction();
-  if (!while_instr) {
-    return absl::OkStatus();
-  }
-
-  // Identify the scalings and type conversions applied to the inputs of the
-  // while loop.
-  HloInstruction* param_tuple = while_instr->mutable_operand(0);
-  std::array<HloInstruction*, 2> binaries, operands, scales;
-  for (int k = 0; k < 2; ++k) {
-    if (!Match(param_tuple->mutable_operand(k),
-               m::AnyOf<HloInstruction>(
-                   m::Divide(&binaries[k], m::Convert(m::Op(&operands[k])),
-                             m::Broadcast(m::Op(&scales[k]))),
-                   m::MultiplyAnyOrder(&binaries[k],
-                                       m::Convert(m::Op(&operands[k])),
-                                       m::Broadcast(m::Op(&scales[k])))))) {
-      VLOG(5) << "Unable to identify FP8 dequantization pattern.";
-      return absl::OkStatus();
-    }
-  }
-
-  // For the dot to be rewritten by gemm_rewriter.cc into an FP8 GEMM, at most
-  // one of the inputs can be F8E5M2.
-  std::array<PrimitiveType, 2> operand_types{
-      operands[0]->shape().element_type(), operands[1]->shape().element_type()};
-  if (!((operand_types[0] == F8E4M3FN && operand_types[1] == F8E4M3FN) ||
-        (operand_types[0] == F8E4M3FN && operand_types[1] == F8E5M2) ||
-        (operand_types[0] == F8E5M2 && operand_types[1] == F8E4M3FN))) {
-    VLOG(5) << "Unsupported types.";
-    return absl::OkStatus();
-  }
-
-  // The dequantized types must be BF16, FP16 or FP32.
-  for (int k = 0; k < 2; ++k) {
-    if (binaries[k]->shape().element_type() != BF16 &&
-        binaries[k]->shape().element_type() != F16 &&
-        binaries[k]->shape().element_type() != F32) {
-      VLOG(5) << "Unsupported types.";
-      return absl::OkStatus();
-    }
-  }
-
-  // The FP8 scaling operands must be scalars.
-  if (!ShapeUtil::IsScalar(scales[0]->shape()) ||
-      !ShapeUtil::IsScalar(scales[1]->shape())) {
-    VLOG(5) << "Scaling factors must be scalars.";
-    return absl::OkStatus();
-  }
-
-  // Identify the dot and collective-permute or dynamic-slice instructions in
-  // the all-gather or reduce-scatter patterns in while's body.
-  HloComputation* while_body = while_instr->while_body();
-  HloComputation* while_condition = while_instr->while_condition();
-  HloInstruction* while_root = while_body->root_instruction();
-  std::array<HloInstruction*, 2> dots, dyn_slices{nullptr, nullptr},
-      coll_perms{nullptr, nullptr};
-  if (Match(
-          while_root,
-          m::Tuple(m::CollectivePermute(
-                       &coll_perms[1], m::CollectivePermute(
-                                           &coll_perms[0], m::Op().Is(gte[0]))),
-                   m::Op().Is(gte[1]),
-                   m::DynamicUpdateSlice(
-                       m::DynamicUpdateSlice().WithOperand(
-                           1, m::Dot(&dots[0], m::Op().Is(gte[0]),
-                                     m::Op().Is(gte[1]))),
-                       m::Dot(&dots[1], m::Op(), m::Op().Is(gte[1])), m::Op(),
-                       m::Op(), m::Op()),
-                   m::Op(), m::Op()))) {
-    VLOG(5) << "Identified all-gather windowed einsum pattern.";
-  } else if (Match(
-                 while_root,
-                 m::Tuple(m::Op().Is(gte[0]), m::Op().Is(gte[1]),
-                          m::AddAnyOrder(
-                              m::Dot(&dots[0], m::DynamicSlice(&dyn_slices[0]),
-                                     m::Op().Is(gte[1])),
-                              m::Op()),
-                          m::CollectivePermute(m::AddAnyOrder(
-                              m::Dot(&dots[1], m::DynamicSlice(&dyn_slices[1]),
-                                     m::Op().Is(gte[1])),
-                              m::Op())),
-                          m::Op()))) {
-    VLOG(5) << "Identified reduce-scatter windowed einsum pattern.";
-  } else {
-    VLOG(5) << "Unable to identify valid windowed einsum pattern.";
-    return absl::OkStatus();
-  }
-
-  // Replace the dequantized dot operands in the parameter tuple used by while
-  // with FP8 operands.
-  for (int k = 0; k < 2; ++k) {
-    TF_RETURN_IF_ERROR(
-        param_tuple->ReplaceOperandWithDifferentShape(k, operands[k]));
-    ShapeUtil::UpdateTupleShape(operands[k]->shape(), k,
-                                param_tuple->mutable_shape());
-    param_tuple->AppendOperand(scales[k]);
-    ShapeUtil::AppendShapeToTuple(scales[k]->shape(),
-                                  param_tuple->mutable_shape());
-  }
-
-  // Update the parameter tuples of while's body and condition computations.
-  for (HloComputation* while_comp : {while_body, while_condition}) {
-    while_comp->ReplaceParameter(
-        0, HloInstruction::CreateParameter(
-               0, param_tuple->shape(),
-               while_comp->parameter_instruction(0)->name()));
-  }
-
-  // In the while body, replace the existing get-tuple-element instructions
-  // retrieving the BF16/FP16/FP32 dot operands with new get-tuple-element
-  // instructions retrieving the FP8 dot operands from the input tuple.
-  HloInstruction* body_param = while_body->parameter_instruction(0);
-  for (int k = 0; k < 2; ++k) {
-    TF_ASSIGN_OR_RETURN(HloInstruction * operand_f8,
-                        MakeGetTupleElementHlo(body_param, k));
-
-    if (while_root->operand(k) == gte[k]) {
-      TF_RETURN_IF_ERROR(
-          while_root->ReplaceOperandWithDifferentShape(k, operand_f8));
-      ShapeUtil::UpdateTupleShape(operand_f8->shape(), k,
-                                  while_root->mutable_shape());
-    }
-
-    TF_ASSIGN_OR_RETURN(
-        HloInstruction * operand_scale,
-        MakeGetTupleElementHlo(
-            body_param, body_param->shape().tuple_shapes_size() - 2 + k));
-
-    // Also add the scaling factor to the output tuple of the while body.
-    while_root->AppendOperand(operand_scale);
-    ShapeUtil::AppendShapeToTuple(operand_scale->shape(),
-                                  while_root->mutable_shape());
-
-    // Dequantize the operands of the dots and dynamic-slices.
-    HloInstruction* operand_f32 =
-        MakeConvertToHlo(operand_f8, gte[k]->shape().element_type());
-    HloInstruction* broadcast_scale =
-        MakeBroadcastHlo(operand_scale, {}, operand_f32->shape());
-    TF_ASSIGN_OR_RETURN(
-        HloInstruction * operand_scaled,
-        MakeBinaryHlo(binaries[k]->opcode(), operand_f32, broadcast_scale));
-
-    // Replace the original get-tuple-element instructions accessing the
-    // operands of the dots and dynamic-slices with the dequantized FP8
-    // operands. The order of dequantization and dynamic-slices will be
-    // exchanged in gemm_rewriter.cc.
-    for (int l = 0; l < 2; ++l) {
-      if (dots[l]->operand(k) == gte[k]) {
-        TF_RETURN_IF_ERROR(dots[l]->ReplaceOperandWith(k, operand_scaled));
-      }
-      if (dyn_slices[l] && dyn_slices[l]->operand(0) == gte[k]) {
-        TF_RETURN_IF_ERROR(
-            dyn_slices[l]->ReplaceOperandWith(0, operand_scaled));
-      }
-    }
-
-    // In the all-gather case, coll_perms[0] has two users, coll_perms[1] and
-    // dots[1], which prevents it from being exchanged with dequantization in
-    // gemm_rewriter.cc. Instead, directly insert the dequantization before
-    // dots[1] here.
-    if (coll_perms[0] && coll_perms[0]->operand(0) == gte[k]) {
-      std::array<HloInstruction*, 2> coll_perms_f8{nullptr, nullptr};
-      // Change the type of both collective-permutes to FP8.
-      coll_perms_f8[0] =
-          while_body->AddInstruction(coll_perms[0]->CloneWithNewOperands(
-              operand_f8->shape(), {operand_f8}));
-      coll_perms_f8[1] =
-          while_body->AddInstruction(coll_perms[1]->CloneWithNewOperands(
-              coll_perms_f8[0]->shape(), {coll_perms_f8[0]}));
-
-      // Insert the dequantization between coll_perms[0] and dots[1].
-      HloInstruction* coll_perm0_f32 =
-          MakeConvertToHlo(coll_perms_f8[0], gte[k]->shape().element_type());
-      TF_ASSIGN_OR_RETURN(HloInstruction * x_scaled,
-                          MakeBinaryHlo(binaries[k]->opcode(), coll_perm0_f32,
-                                        broadcast_scale));
-      TF_RETURN_IF_ERROR(dots[1]->ReplaceOperandWith(0, x_scaled));
-
-      // Update the output tuple.
-      TF_RETURN_IF_ERROR(
-          while_root->ReplaceOperandWithDifferentShape(0, coll_perms_f8[1]));
-      ShapeUtil::UpdateTupleShape(coll_perms_f8[1]->shape(), 0,
-                                  while_root->mutable_shape());
-    }
-  }
-
-  // Update the shape of the while call in the parent computation.
-  TF_RETURN_IF_ERROR(
-      while_instr->ReplaceAllUsesWithDifferentShape(while_instr->AddInstruction(
-          while_instr->CloneWithNewShape(while_root->shape()))));
-  TF_RETURN_IF_ERROR(while_instr->parent()->RemoveInstruction(while_instr));
-
-  if (coll_perms[0]) {
-    TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[1]));
-    TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[0]));
-  }
-  TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gte[0]));
-  TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gte[1]));
-
-  VLOG(5) << "FP8 dequantization moved into while loop.";
-  return absl::OkStatus();
-}
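As a side note on the operand-type check in the middle of the function above, the constraint that at most one dot input may be F8E5M2 boils down to the following predicate (standalone sketch with a stand-in enum, not the XLA PrimitiveType; illustrative only):

```cpp
#include <cstdio>

// Stand-in for the two FP8 primitive types referenced above.
enum class F8Type { kE4M3FN, kE5M2 };

// At most one of the two dot inputs may be E5M2 for the FP8 GEMM rewrite.
bool SupportedF8OperandPair(F8Type lhs, F8Type rhs) {
  return !(lhs == F8Type::kE5M2 && rhs == F8Type::kE5M2);
}

int main() {
  std::printf("%d %d %d\n",
              SupportedF8OperandPair(F8Type::kE4M3FN, F8Type::kE4M3FN),  // 1
              SupportedF8OperandPair(F8Type::kE4M3FN, F8Type::kE5M2),    // 1
              SupportedF8OperandPair(F8Type::kE5M2, F8Type::kE5M2));     // 0
  return 0;
}
```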
-
-int64_t NumberOfInstructionsInComp(const HloComputation* comp, HloOpcode op) {
-  int64_t total_count = 0;
-  for (const HloInstruction* inst : comp->instructions()) {
-    if (inst->opcode() == op) {
-      ++total_count;
-    }
-  }
-  return total_count;
-}
-
-absl::Status UpdateDotAndConsumerConfig(HloInstruction* dot,
-                                        int64_t stream_id) {
-  auto dot_gpu_config = dot->backend_config<gpu::GpuBackendConfig>();
-  HloInstruction* updater = dot->users()[0];
-  auto updater_gpu_config = updater->backend_config<gpu::GpuBackendConfig>();
-  dot_gpu_config->set_operation_queue_id(stream_id);
-  updater_gpu_config->mutable_wait_on_operation_queues()->Add(stream_id);
-
-  TF_RETURN_IF_ERROR(dot->set_backend_config(dot_gpu_config.value()));
-  TF_RETURN_IF_ERROR(updater->set_backend_config(updater_gpu_config.value()));
-  return absl::OkStatus();
-}
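The stream bookkeeping above can be summarized with a toy model: the producer dot is tagged with a compute-stream id and its consumer records that id in a wait list. The struct below is a stand-in for GpuBackendConfig and purely illustrative.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

struct ToyConfig {
  int64_t operation_queue_id = 0;
  std::vector<int64_t> wait_on_operation_queues;
};

void UpdateDotAndConsumer(ToyConfig& dot, ToyConfig& consumer,
                          int64_t stream_id) {
  dot.operation_queue_id = stream_id;                      // run dot off-stream
  consumer.wait_on_operation_queues.push_back(stream_id);  // consumer syncs
}

int main() {
  ToyConfig dot, updater;
  UpdateDotAndConsumer(dot, updater, /*stream_id=*/3);
  std::printf("dot queue=%lld, consumer waits on %lld\n",
              static_cast<long long>(dot.operation_queue_id),
              static_cast<long long>(updater.wait_on_operation_queues[0]));
  return 0;
}
```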
-
-absl::Status SetForceDelayForInstruction(HloInstruction* instr,
-                                         bool force_delay) {
-  auto gpu_config = instr->backend_config<gpu::GpuBackendConfig>();
-
-  gpu_config->set_force_earliest_schedule(force_delay);
-
-  TF_RETURN_IF_ERROR(instr->set_backend_config(gpu_config.value()));
-  return absl::OkStatus();
-}
-
-absl::StatusOr<bool> HandleRsWindowedEinsumLoop(HloComputation* comp,
-                                                int64_t stream_id) {
-  bool changed = false;
-  // If we have an einsum loop with only one dot, either the loop is not
-  // unrolled or only one partition is available. It's a no-op in either case.
-  if (NumberOfInstructionsInComp(comp, HloOpcode::kDot) <= 1) {
-    return changed;
-  }
-  for (auto inst : comp->MakeInstructionPostOrder()) {
-    HloInstruction* matched_dot;
-    std::array<HloInstruction*, 2> gte;
-    // The dot we'd like to parallelize is consuming the second loop input
-    // as RHS.
-    if (Match(inst,
-              m::Dot(&matched_dot,
-                     m::DynamicSlice().WithOperand(
-                         0, m::GetTupleElement(&gte[0], m::Parameter(), 0)),
-                     m::GetTupleElement(&gte[1], m::Parameter(), 1)))) {
-      // If present, move the dequantization of FP8 operands of the dot into the
-      // while loop to allow gemm_rewriter.cc to rewrite into an FP8 Custom
-      // Call.
-      TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp, gte));
-
-      // Dispatch the dot to additional compute stream.
-      TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id));
-      ++stream_id;
-      changed = true;
-    }
-
-    // We need to force the first collective-permute to always be scheduled
-    // at the beginning of the loop.
-    HloInstruction* matched_cp;
-    if (Match(inst, m::CollectivePermute(
-                        &matched_cp, m::GetTupleElement(m::Parameter(), 2)))) {
-      TF_RETURN_IF_ERROR(
-          SetForceDelayForInstruction(matched_cp, /*force_delay=*/true));
-      changed = true;
-    }
-  }
-  return changed;
-}
-
-absl::StatusOr<bool> HandleAgWindowedEinsumLoop(HloComputation* comp,
-                                                int64_t stream_id) {
-  bool changed = false;
-  // If we have an einsum loop with only one dot, either the loop is not
-  // unrolled or only one partition is available. It's a no-op in either case.
-  if (NumberOfInstructionsInComp(comp, HloOpcode::kDot) <= 1) {
-    return changed;
-  }
-  for (auto inst : comp->MakeInstructionPostOrder()) {
-    HloInstruction* matched_dot;
-    std::array<HloInstruction*, 2> gte;
-    // The dot we'd like to parallelize is consuming the second loop input
-    // as RHS and first loop input as LHS.
-    if (Match(inst, m::Dot(&matched_dot,
-                           m::GetTupleElement(&gte[0], m::Parameter(), 0),
-                           m::GetTupleElement(&gte[1], m::Parameter(), 1)))) {
-      // If present, move the dequantization of FP8 operands of the dot into the
-      // while loop to allow gemm_rewriter.cc to rewrite into an FP8 Custom
-      // Call.
-      TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp, gte));
-
-      // Dispatch the dot to additional compute stream.
-      TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id));
-      ++stream_id;
-      TF_RETURN_IF_ERROR(
-          SetForceDelayForInstruction(matched_dot, /*force_delay=*/true));
-      changed = true;
-    }
-
-    // We need to force the first collective-permute to always be scheduled
-    // at the beginning of the loop.
-    HloInstruction* matched_cp;
-    if (Match(inst, m::CollectivePermute(
-                        &matched_cp, m::GetTupleElement(m::Parameter(), 0)))) {
-      TF_RETURN_IF_ERROR(
-          SetForceDelayForInstruction(matched_cp, /*force_delay=*/true));
-      changed = true;
-    }
-  }
-  return changed;
-}
-
-static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) {
-  const HloInstruction* loop_tuple = while_loop->operand(0);
-  const Shape& tuple_shape = loop_tuple->shape();
-  CHECK(tuple_shape.IsTuple());
-  return tuple_shape.tuple_shapes_size();
-}
-
-absl::Status ProcessWindowedEinsumLoopForActivationCaching(
-    GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop,
-    HloInstruction* ag_with_shared_operand) {
-  HloInstruction* loop = ag_loop.loop;
-  // Transform the while body to cache the allgathered result in the
-  // output buffer to be consumed by the dot
-  HloComputation* while_body = loop->while_body();
-  HloInstruction* input_gte;
-  for (HloInstruction* gte : while_body->parameter_instruction(0)->users()) {
-    if (gte->tuple_index() == 0) {
-      input_gte = gte;
-    }
-  }
-  // Get the output operand of the full buffer.
-  HloInstruction* root = while_body->root_instruction();
-  // Change loop body to include the new input and output element.
-  HloInstruction* input_tuple = while_body->parameter_instruction(0);
-  const Shape& input_shape = input_tuple->shape();
-  // The full buffer that we will use to cache the accumulated activation
-  // is the last operand in the output tuple.
-  int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop);
-  std::vector<Shape> new_input_shapes(input_shape.tuple_shapes().begin(),
-                                      input_shape.tuple_shapes().end());
-  new_input_shapes.push_back(ag_with_shared_operand->shape());
-  // Update body input shape
-  Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes);
-  *input_tuple->mutable_shape() = new_input_shape;
-  HloInstruction* full_buffer_output_gte =
-      while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
-          ag_with_shared_operand->shape(), input_tuple,
-          full_cache_buffer_index));
-
-  // Update condition input shape
-  HloComputation* cond_comp = loop->while_condition();
-  HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0);
-  *cond_input_tuple->mutable_shape() = new_input_shape;
-
-  // Update input to the while instruction in parent computation
-  HloInstruction* original_while_input = loop->mutable_operand(0);
-  HloComputation* parent_comp = loop->parent();
-  std::vector<HloInstruction*> new_operands(
-      original_while_input->operands().begin(),
-      original_while_input->operands().end());
-  new_operands.push_back(
-      parent_comp->AddInstruction(HloInstruction::CreateBroadcast(
-          ag_with_shared_operand->shape(),
-          parent_comp->AddInstruction(HloInstruction::CreateConstant(
-              LiteralUtil::Zero(new_input_shapes[0].element_type()))),
-          {})));
-  HloInstruction* new_while_input =
-      parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands));
-  TF_RETURN_IF_ERROR(
-      loop->ReplaceOperandWithDifferentShape(0, new_while_input));
-  TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape(
-      original_while_input, new_while_input));
-  *loop->mutable_shape() = new_input_shape;
-
-  HloInstruction* new_full_buffer_output = nullptr;
-  // Find the DUS in the loop body and re-use the slice indices
-  // This should just be a constant(0)
-  HloInstruction* dus_boundary_constant;
-  // The slice we need this time is the output of the first
-  // collective-permute
-  HloInstruction* first_cp_output;
-  for (HloInstruction* gte_user : input_gte->users()) {
-    if (gte_user->opcode() == HloOpcode::kCollectivePermute) {
-      first_cp_output = gte_user;
-      break;
-    }
-  }
-  for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) {
-    HloInstruction* slice_indices;
-    // If we have a DUS(PARAM,DS) pattern, we need to update the output
-    // buffer with the first slice.
-    if (Match(inst,
-              m::DynamicUpdateSlice(
-                  m::GetTupleElement(m::Parameter()), m::Op(),
-                  m::Constant(&dus_boundary_constant),
-                  m::Reshape(m::DynamicSlice(&slice_indices, m::Op(), m::Op())),
-                  m::Op()))) {
-      slice_indices = while_body->AddInstruction(HloInstruction::CreateReshape(
-          dus_boundary_constant->shape(), slice_indices));
-      VLOG(5) << "Created slice op for first slice: "
-              << slice_indices->ToString();
-      full_buffer_output_gte =
-          while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
-              full_buffer_output_gte->shape(), full_buffer_output_gte,
-              input_gte,
-              {dus_boundary_constant, slice_indices, dus_boundary_constant}));
-    }
-    // If we have a DUS(DUS,DS) pattern, then the einsum loop is unrolled and
-    // we need to update the output buffer again with the second slice. Since
-    // the second slice will have different indices, we need to re-capture
-    // slice_indices.
-    if (Match(inst,
-              m::DynamicUpdateSlice(
-                  m::DynamicUpdateSlice(), m::Op(), m::Constant(),
-                  m::Reshape(m::DynamicSlice(&slice_indices, m::Op(), m::Op())),
-                  m::Op()))) {
-      slice_indices = while_body->AddInstruction(HloInstruction::CreateReshape(
-          dus_boundary_constant->shape(), slice_indices));
-      VLOG(5) << "Created slice op for second slice: "
-              << slice_indices->ToString();
-      new_full_buffer_output =
-          while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
-              full_buffer_output_gte->shape(), full_buffer_output_gte,
-              first_cp_output,
-              {dus_boundary_constant, slice_indices, dus_boundary_constant}));
-    }
-
-    // If we have a Dot(DS(parameter_index1)), then operands are sharded along
-    // the contracting dim. Slice indices will be the contracting dim's slices.
-    HloInstruction* slice_index;
-    HloInstruction* ds_index_constant;
-    HloInstruction* remainder;
-    HloInstruction* ds_param;
-    // There will be two dynamic-slices for unrolled loops; match each one to
-    // get the slice index which will be used to write the corresponding
-    // received shard into the cached activation buffer. For unrolled loops, we
-    // need to write to the final buffer twice per iteration, so we need to
-    // match the correct slice index based on each DS.
-    if (Match(inst, m::Dot(m::Op(), m::DynamicSlice(&ds_param))) &&
-        Match(ds_param->operand(0), m::GetTupleElement(m::Parameter(), 1))) {
-      for (int64_t ds_op_i = 1; ds_op_i < ds_param->operands().size();
-           ds_op_i++) {
-        if (!Match(
-                ds_param->mutable_operand(ds_op_i),
-                m::Reshape(&slice_index, m::DynamicSlice(m::Constant(),
-                                                         m::Op(&remainder)))) &&
-            !Match(ds_param->mutable_operand(ds_op_i),
-                   m::Constant(&ds_index_constant))) {
-          return absl::OkStatus();
-        }
-      }
-      // First DS has slice index calculated based on loop iterator
-      // Remainder(add(gte, partition_id))
-      if (Match(remainder,
-                m::Remainder(m::Add(m::GetTupleElement(), m::Op()), m::Op()))) {
-        full_buffer_output_gte =
-            while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
-                full_buffer_output_gte->shape(), full_buffer_output_gte,
-                input_gte,
-                {ds_index_constant, ds_index_constant, slice_index}));
-      }
-      // Second DS has slice index calculated based on loop iterator+1 hence
-      // Remainder(add(add(gte, 1), partition_id))
-      if (Match(remainder,
-                m::Remainder(
-                    m::Add(m::Add(m::GetTupleElement(), m::Op()), m::Op()),
-                    m::Op()))) {
-        new_full_buffer_output =
-            while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
-                full_buffer_output_gte->shape(), full_buffer_output_gte,
-                first_cp_output,
-                {ds_index_constant, ds_index_constant, slice_index}));
-      }
-    }
-  }
-  std::vector<HloInstruction*> original_operands(root->operands().begin(),
-                                                 root->operands().end());
-  original_operands.push_back(new_full_buffer_output);
-  HloInstruction* new_output_tuple = while_body->AddInstruction(
-      HloInstruction::CreateTuple(original_operands));
-  TF_RETURN_IF_ERROR(
-      while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple));
-  return absl::OkStatus();
-}
-
-bool HasReplicaGroups(const HloInstruction* inst) {
-  return inst->replica_groups().size() > 0;
-}
-
-bool ShouldAddToChain(const HloInstruction* inst) {
-  switch (inst->opcode()) {
-    case HloOpcode::kTranspose:
-    case HloOpcode::kReshape:
-    case HloOpcode::kCopy:
-      return inst->user_count() == 1;
-    default:
-      return false;
-  }
-}
-
-struct MatchedGemmA2aResult {
-  HloInstruction* producer_gemm;
-  HloInstruction* lhs;
-  HloInstruction* rhs;
-  HloInstruction* a2a_replacement = nullptr;
-  bool matched = false;
-};
-
-class WindowedEinsumVisitor : public DfsHloRewriteVisitor {
- public:
-  explicit WindowedEinsumVisitor(
-      std::vector<GpuWindowedEinsumHandler::WindowedEinsumAgLoops>&
-          all_ag_loops)
-      : all_ag_loops_(all_ag_loops) {}
-  absl::StatusOr<bool> MatchA2aGemmWithIntermediateReshapes(
-      HloInstruction* dot, HloInstruction** lhs, HloInstruction** rhs) {
-    if (Match(dot, m::Dot(m::AllToAll(lhs).WithOneUse().WithPredicate(
-                              HasReplicaGroups),
-                          m::Op(rhs))) &&
-        !DynCast<HloAllToAllInstruction>((*lhs))->constrain_layout() &&
-        !(*lhs)->shape().IsTuple()) {
-      return true;
-    }
-    std::vector<HloInstruction*> allowed_intermediate_ops(
-        {dot->mutable_operand(0)});
-
-    HloAllToAllInstruction* matched_a2a = nullptr;
-    // We keep pushing operands onto the chain until a condition fails or we
-    // have found the a2a.
-    while (true) {
-      HloInstruction* curr = allowed_intermediate_ops.back();
-      if (ShouldAddToChain(curr)) {
-        allowed_intermediate_ops.insert(allowed_intermediate_ops.end(),
-                                        std::begin(curr->operands()),
-                                        std::end(curr->operands()));
-      } else if (curr->opcode() == HloOpcode::kAllToAll &&
-                 curr->user_count() == 1) {
-        matched_a2a = DynCast<HloAllToAllInstruction>(curr);
-        allowed_intermediate_ops.pop_back();
-        break;
-      } else {
-        return false;
-      }
-    }
-    CHECK(matched_a2a != nullptr);
-    if (matched_a2a->constrain_layout() || matched_a2a->shape().IsTuple() ||
-        !HasReplicaGroups(matched_a2a) || !matched_a2a->split_dimension()) {
-      return false;
-    }
-    // We need to create a new a2a that is a direct producer of the dot and
-    // use it to replace the original a2a. A new reshape will be added to the
-    // original a2a's input. We first need to determine the new split dimension
-    // after all the reshape ops.
-    int64_t split_dimension = *matched_a2a->split_dimension();
-    for (int64_t i = allowed_intermediate_ops.size() - 1; i >= 0; i--) {
-      HloInstruction* current_op = allowed_intermediate_ops[i];
-      if (current_op->opcode() == HloOpcode::kReshape) {
-        std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
-            ShapeUtil::DimensionsUnmodifiedByReshape(
-                current_op->operand(0)->shape(), current_op->shape());
-        auto it = absl::c_find_if(
-            unmodified_dims,
-            [&split_dimension](std::pair<int64_t, int64_t>& dim_pair) {
-              return dim_pair.first == split_dimension;
-            });
-        // The split dimension of the a2a has been modified, so we cannot
-        // easily deduce the new split dim; skip decomposition.
-        if (it == unmodified_dims.end()) {
-          VLOG(5) << "Split dimension of: " << matched_a2a->ToShortString()
-                  << " has been modified by reshapes. Skipping it for "
-                     "decomposition.";
-          return false;
-        }
-        // Assign the new split dim.
-        split_dimension = it->second;
-      } else if (current_op->opcode() == HloOpcode::kTranspose) {
-        const auto& transpose_dims = current_op->dimensions();
-        for (int64_t j = 0; j < transpose_dims.size(); j++) {
-          if ((int64_t)transpose_dims[j] == split_dimension) {
-            split_dimension = j;
-            break;
-          }
-        }
-      }
-    }
-    TF_RETURN_IF_ERROR(allowed_intermediate_ops.back()->ReplaceOperandWith(
-        0, matched_a2a->mutable_operand(0)));
-    HloInstruction* new_a2a =
-        matched_a2a->parent()->AddInstruction(HloInstruction::CreateAllToAll(
-            allowed_intermediate_ops.front()->shape(),
-            {allowed_intermediate_ops.front()}, matched_a2a->replica_groups(),
-            false, hlo_query::NextChannelId(*matched_a2a->GetModule()),
-            split_dimension));
-
-    TF_RETURN_IF_ERROR(dot->ReplaceOperandWith(0, new_a2a));
-    TF_RETURN_IF_ERROR(
-        matched_a2a->parent()->RemoveInstructionAndUnusedOperands(matched_a2a));
-    MarkAsChanged();
-    *lhs = new_a2a;
-    *rhs = dot->mutable_operand(1);
-    return true;
-  }
-
-  absl::Status HandleDot(HloInstruction* dot) override {
-    CHECK_EQ(dot->opcode(), HloOpcode::kDot);
-    HloComputation* comp = dot->parent();
-    // Rewrites an all-gather+dot pattern that shares the same operand
-    // with a windowed einsum loop so that the dot consumes the output of
-    // the loop, and removes the all-gather.
-    // Now that we have processed all loops, we can check whether there is
-    // any all-gather+dot pattern that we can optimize. We want to transform:
-    //                       input
-    //                       /    |
-    //                      /     |
-    //                     AG    windowed loop
-    //                     /
-    //                    /
-    //                   dot
-    // to:
-    //                       input
-    //                       |
-    //                       |
-    //                     windowed loop
-    //                       |
-    //                       |
-    //                      dot
-    // The windowed einsum loop will also be rewritten to output the full input
-    // to be consumed by the dot. This is advantageous since the chained dot can
-    // fully utilize all the resources on the GPU while comm is hidden by the
-    // first collective matmul loop.
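-    // Illustrative sketch of the rewrite (hypothetical operand names): the
-    // dot stops reading a separate all-gather and instead reads the full
-    // activation cached by the loop:
-    //   before: dot(all-gather(input), rhs)
-    //   after:  dot(get-tuple-element(windowed_loop, cache_index), rhs)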
-    for (GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop :
-         all_ag_loops_) {
-      HloInstruction* loop = ag_loop.loop;
-      HloInstruction* ag_operand = nullptr;
-
-      if (Match(dot, m::Dot(m::AllGather(&ag_operand), m::Op())) ||
-          Match(dot, m::Dot(m::Op(), m::AllGather(&ag_operand)))) {
-        HloInstruction* windowed_lhs =
-            loop->mutable_operand(0)->mutable_operand(0);
-        HloInstruction* ag_with_shared_operand = nullptr;
-        if (ag_operand && ag_operand->mutable_operand(0) == windowed_lhs) {
-          ag_with_shared_operand = ag_operand;
-        }
-
-        if (!ag_with_shared_operand) {
-          continue;
-        }
-
-        VLOG(5) << "Found all-gather that shares the same operand with a "
-                   "windowed einsum loop : "
-                << loop->ToString();
-
-        if (!ag_loop.consumed) {
-          TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching(
-              ag_loop, ag_with_shared_operand));
-          ag_loop.consumed = true;
-        }
-        int64_t cache_output_index = dot->operand_index(ag_with_shared_operand);
-        HloComputation* comp = dot->parent();
-        HloInstruction* new_gte =
-            comp->AddInstruction(HloInstruction::CreateGetTupleElement(
-                loop, GetAgActivationCacheIndex(loop) - 1));
-        TF_RETURN_IF_ERROR(
-            dot->ReplaceOperandWith(cache_output_index, new_gte));
-        TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand));
-      }
-    }
-    // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms
-    // to minimize communication overhead. To do this, the original inputs are
-    // sliced into replica_group-sized chunks, and each chunk performs its own
-    // all-to-all+gemm.
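-    // Illustrative sketch (hypothetical sizes): with replica groups of size 4
-    // and a contracting dimension of 8192, each of the 4 steps slices a
-    // 2048-wide chunk of the a2a input and of rhs along their contracting
-    // dimensions, runs a smaller all-to-all + dot, and accumulates into the
-    // output buffer:
-    //   out = dot(a2a(input[..., 0:2048]), rhs[0:2048, ...])
-    //       + dot(a2a(input[..., 2048:4096]), rhs[2048:4096, ...]) + ...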
-    HloInstruction* lhs;
-    HloInstruction* rhs;
-    std::vector<xla::ReplicaGroup> replica_groups;
-    TF_ASSIGN_OR_RETURN(bool matched,
-                        MatchA2aGemmWithIntermediateReshapes(dot, &lhs, &rhs));
-    if (matched) {
-      replica_groups = lhs->replica_groups();
-      // We split the a2a+gemm along the contracting dimension into multiple
-      // a2a+gemms and perform partial dots; the partial results are
-      // accumulated into the final output buffer.
-      int64_t group_size = replica_groups[0].replica_ids_size();
-      if (absl::c_find_if(replica_groups, [&](ReplicaGroup& group) {
-            return group.replica_ids_size() != group_size;
-          }) != replica_groups.end()) {
-        VLOG(5) << "All-to-all split groups don't have the same number of "
-                   "replicas.";
-        return absl::OkStatus();
-      }
-
-      // Get the dimensions to slice for lhs and rhs; we slice on the
-      // contracting dimensions to compute partial results.
-      const DotDimensionNumbers& original_dot_dnums =
-          dot->dot_dimension_numbers();
-      const PrecisionConfig& original_precision = dot->precision_config();
-      const auto& lhs_contracting_dims =
-          dot->dot_dimension_numbers().lhs_contracting_dimensions();
-      const auto& rhs_contracting_dims =
-          dot->dot_dimension_numbers().rhs_contracting_dimensions();
-
-      if (lhs_contracting_dims.size() != 1 ||
-          rhs_contracting_dims.size() != 1) {
-        VLOG(5) << "Contracting dimensions have multiple elements, all-to-all "
-                   "sharding will be skipped.";
-        return absl::OkStatus();
-      }
-      int64_t lhs_contracting_dim = lhs_contracting_dims[0];
-      int64_t rhs_contracting_dim = rhs_contracting_dims[0];
-      HloAllToAllInstruction* a2a = DynCast<HloAllToAllInstruction>(lhs);
-      int64_t contracting_dim_value =
-          rhs->shape().dimensions()[rhs_contracting_dim];
-
-      // Each split is sliced out of the input buffer; we need to determine the
-      // slice sizes and increments.
-      std::vector<int64_t> lhs_slice_sizes(a2a->shape().rank(), 0);
-      std::vector<int64_t> lhs_slice_increments(a2a->shape().rank(), 1);
-      std::vector<int64_t> lhs_slice_max_range(
-          a2a->shape().dimensions().begin(), a2a->shape().dimensions().end());
-
-      std::vector<int64_t> rhs_slice_sizes(rhs->shape().rank(), 0);
-      std::vector<int64_t> rhs_slice_increments(rhs->shape().rank(), 1);
-      std::vector<int64_t> rhs_slice_max_range(
-          rhs->shape().dimensions().begin(), rhs->shape().dimensions().end());
-
-      // Create a zero-valued buffer to hold output.
-      HloInstruction* output_buffer =
-          comp->AddInstruction(HloInstruction::CreateBroadcast(
-              dot->shape(),
-              comp->AddInstruction(HloInstruction::CreateConstant(
-                  LiteralUtil::Zero(dot->shape().element_type()))),
-              {}));
-      HloInstruction* a2a_operand = a2a->mutable_operand(0);
-      if (contracting_dim_value % group_size) {
-        VLOG(5) << absl::StrFormat(
-            "Contracting dimension %d needs to be divisible by group_size %d",
-            contracting_dim_value, group_size);
-        return absl::OkStatus();
-      }
-      int64_t size_per_split = contracting_dim_value / group_size;
-
-      // Each split is sliced out of the input buffer; we need to determine the
-      // slice sizes and increments.
-      lhs_slice_max_range[lhs_contracting_dim] = size_per_split;
-      rhs_slice_max_range[rhs_contracting_dim] = size_per_split;
-
-      Shape lhs_slice_shape = a2a->shape();
-      Shape rhs_slice_shape = rhs->shape();
-
-      lhs_slice_shape.set_dimensions(lhs_contracting_dim, size_per_split);
-      rhs_slice_shape.set_dimensions(rhs_contracting_dim, size_per_split);
-
-      HloInstruction* lhs_slice;
-      HloInstruction* rhs_slice;
-
-      HloInstruction* partial_result = output_buffer;
-
-      Shape partial_all_to_all_shape = lhs_slice_shape;
-
-      TF_ASSIGN_OR_RETURN(
-          Shape partial_dot_shape,
-          ShapeInference::InferDotOpShape(
-              partial_all_to_all_shape, rhs_slice_shape, original_dot_dnums,
-              /*preferred_element_type=*/std::nullopt));
-      int64_t stream_id = hlo_query::NextChannelId(*a2a->GetModule());
-      for (int64_t i = 0; i < group_size; ++i) {
-        lhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
-            lhs_slice_shape, a2a_operand, lhs_slice_sizes, lhs_slice_max_range,
-            lhs_slice_increments));
-        a2a->SetupDerivedInstruction(lhs_slice);
-        lhs_slice_sizes[lhs_contracting_dim] =
-            lhs_slice_max_range[lhs_contracting_dim];
-        lhs_slice_max_range[lhs_contracting_dim] += size_per_split;
-
-        rhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
-            rhs_slice_shape, rhs, rhs_slice_sizes, rhs_slice_max_range,
-            rhs_slice_increments));
-        a2a->SetupDerivedInstruction(rhs_slice);
-        rhs_slice_sizes[rhs_contracting_dim] =
-            rhs_slice_max_range[rhs_contracting_dim];
-        rhs_slice_max_range[rhs_contracting_dim] += size_per_split;
-
-        HloInstruction* partial_all_to_all =
-            comp->AddInstruction(HloInstruction::CreateAllToAll(
-                partial_all_to_all_shape, {lhs_slice}, a2a->device_list(),
-                false, hlo_query::NextChannelId(*a2a->GetModule()),
-                a2a->split_dimension()));
-        a2a->SetupDerivedInstruction(partial_all_to_all);
-
-        HloInstruction* partial_dot =
-            comp->AddInstruction(HloInstruction::CreateDot(
-                partial_dot_shape, partial_all_to_all, rhs_slice,
-                original_dot_dnums, original_precision));
-        partial_result = comp->AddInstruction(
-            HloInstruction::CreateBinary(partial_dot->shape(), HloOpcode::kAdd,
-                                         partial_dot, partial_result));
-        a2a->SetupDerivedInstruction(partial_result);
-        TF_RETURN_IF_ERROR(
-            UpdateDotAndConsumerConfig(partial_dot, stream_id++));
-      }
-      TF_RETURN_IF_ERROR(ReplaceInstruction(dot, partial_result));
-    }
-    return absl::OkStatus();
-  }
-
-  absl::StatusOr<MatchedGemmA2aResult> MatchGemmA2aWithIntermediateReshapes(
-      HloInstruction* inst) {
-    MatchedGemmA2aResult result;
-    HloAllToAllInstruction* a2a = DynCast<HloAllToAllInstruction>(inst);
-    if (!HasReplicaGroups(a2a) || a2a->constrain_layout() ||
-        a2a->shape().IsTuple()) {
-      return result;
-    }
-    if (Match(a2a, m::AllToAll(m::Dot(&result.producer_gemm, m::Op(&result.lhs),
-                                      m::Op(&result.rhs))
-                                   .WithOneUse()))) {
-      result.matched = true;
-      return result;
-    }
-    std::vector<HloInstruction*> allowed_intermediate_ops(
-        {a2a->mutable_operand(0)});
-
-    HloInstruction* matched_dot = nullptr;
-    // We keep pushing operands until a condition is unmet or we have found
-    // the producer dot.
-    while (true) {
-      HloInstruction* curr = allowed_intermediate_ops.back();
-      if (ShouldAddToChain(curr)) {
-        allowed_intermediate_ops.insert(allowed_intermediate_ops.end(),
-                                        std::begin(curr->operands()),
-                                        std::end(curr->operands()));
-      } else if (curr->opcode() == HloOpcode::kDot && curr->user_count() == 1) {
-        matched_dot = curr;
-        allowed_intermediate_ops.pop_back();
-        break;
-      } else {
-        return result;
-      }
-    }
-    CHECK(matched_dot != nullptr);
-    // We need to create a new a2a that is a direct consumer of the dot and
-    // replace the original a2a with it; the intermediate ops are rewired to
-    // consume the new a2a's output. We first need to determine the new split
-    // dimension after all the reshape/transpose ops.
-    int64_t split_dimension = *a2a->split_dimension();
-    for (int64_t i = 0; i < allowed_intermediate_ops.size(); i++) {
-      HloInstruction* current_op = allowed_intermediate_ops[i];
-      if (current_op->opcode() == HloOpcode::kReshape) {
-        std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
-            ShapeUtil::DimensionsUnmodifiedByReshape(
-                current_op->operand(0)->shape(), current_op->shape());
-        auto it = absl::c_find_if(
-            unmodified_dims,
-            [&split_dimension](std::pair<int64_t, int64_t>& dim_pair) {
-              return dim_pair.second == split_dimension;
-            });
-        // The a2a's split dimension has been modified; we cannot easily
-        // deduce the new split dim, so skip decomposition.
-        if (it == unmodified_dims.end()) {
-          VLOG(5) << "Split dimension of: " << a2a->ToShortString()
-                  << " has been modified by reshapes. Skip process it for "
-                     "decomposition.";
-          return result;
-        }
-        // Assign the new split dim.
-        split_dimension = it->first;
-      } else if (current_op->opcode() == HloOpcode::kTranspose) {
-        const auto& transpose_dims = current_op->dimensions();
-        split_dimension = transpose_dims[split_dimension];
-      }
-    }
-    result.a2a_replacement =
-        matched_dot->parent()->AddInstruction(HloInstruction::CreateAllToAll(
-            matched_dot->shape(), {matched_dot}, a2a->replica_groups(), false,
-            hlo_query::NextChannelId(*matched_dot->GetModule()),
-            split_dimension));
-    TF_RETURN_IF_ERROR(allowed_intermediate_ops.back()->ReplaceOperandWith(
-        0, result.a2a_replacement));
-    inst->SetupDerivedInstruction(result.a2a_replacement);
-
-    TF_RETURN_IF_ERROR(
-        ReplaceInstruction(inst, allowed_intermediate_ops.front()));
-    result.lhs = matched_dot->mutable_operand(0);
-    result.rhs = matched_dot->mutable_operand(1);
-    result.producer_gemm = matched_dot;
-    result.matched = true;
-    return result;
-  }
-
-  // Rewrites a gemm+all-to-all into multiple independent partial gemm+a2as
-  // to minimize communication overhead. To do this, the original inputs are
-  // sliced into replica_group-sized chunks, and each chunk performs its own
-  // gemm+all-to-all.
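-  // Illustrative sketch (hypothetical sizes): with replica groups of size 4
-  // and a contracting dimension of 8192, each step computes a partial dot on
-  // 2048-wide slices of lhs and rhs, runs an all-to-all on that partial
-  // result, and adds it into the accumulated output:
-  //   out = a2a(dot(lhs[..., 0:2048], rhs[0:2048, ...]))
-  //       + a2a(dot(lhs[..., 2048:4096], rhs[2048:4096, ...])) + ...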
-  absl::Status HandleAllToAll(HloInstruction* inst) override {
-    CHECK_EQ(inst->opcode(), HloOpcode::kAllToAll);
-    HloComputation* comp = inst->parent();
-    // Rewrites a gemm+all-to-all into multiple independent partial gemm+a2as
-    // to minimize communication overhead.
-    std::vector<xla::ReplicaGroup> replica_groups;
-    TF_ASSIGN_OR_RETURN(MatchedGemmA2aResult matched_result,
-                        MatchGemmA2aWithIntermediateReshapes(inst));
-    if (matched_result.matched) {
-      HloInstruction* a2a = inst;
-      if (matched_result.a2a_replacement) {
-        a2a = matched_result.a2a_replacement;
-      }
-      replica_groups = a2a->replica_groups();
-      // Similar to a2a+gemm, we split along the contracting dimensions
-      // and aggregate the result at each step.
-      int64_t group_size = replica_groups[0].replica_ids_size();
-
-      if (absl::c_find_if(replica_groups, [&](ReplicaGroup& group) {
-            return group.replica_ids_size() != group_size;
-          }) != replica_groups.end()) {
-        VLOG(5) << "All-to-all split groups don't have the same number of "
-                   "replicas.";
-        return absl::OkStatus();
-      }
-
-      // Get the dimensions to slice for lhs and rhs; we slice on the
-      // contracting dimensions to compute partial results.
-      const DotDimensionNumbers& original_dot_dnums =
-          matched_result.producer_gemm->dot_dimension_numbers();
-      const PrecisionConfig& original_precision =
-          matched_result.producer_gemm->precision_config();
-      const auto& lhs_contracting_dims =
-          matched_result.producer_gemm->dot_dimension_numbers()
-              .lhs_contracting_dimensions();
-      const auto& rhs_contracting_dims =
-          matched_result.producer_gemm->dot_dimension_numbers()
-              .rhs_contracting_dimensions();
-
-      if (lhs_contracting_dims.size() != 1 ||
-          rhs_contracting_dims.size() != 1) {
-        VLOG(5) << "Contracting dimensions have multiple elements, all-to-all "
-                   "sharding will be skipped.";
-        return absl::OkStatus();
-      }
-      int64_t lhs_contracting_dim = lhs_contracting_dims[0];
-      int64_t rhs_contracting_dim = rhs_contracting_dims[0];
-      HloAllToAllInstruction* all_to_all = DynCast<HloAllToAllInstruction>(a2a);
-      int64_t contracting_dim_value =
-          matched_result.rhs->shape().dimensions()[rhs_contracting_dim];
-      // Each split is sliced out of the input buffer; we need to determine the
-      // slice sizes and increments.
-      std::vector<int64_t> lhs_slice_sizes(matched_result.lhs->shape().rank(),
-                                           0);
-      std::vector<int64_t> lhs_slice_increments(
-          matched_result.lhs->shape().rank(), 1);
-      std::vector<int64_t> lhs_slice_max_range(
-          matched_result.lhs->shape().dimensions().begin(),
-          matched_result.lhs->shape().dimensions().end());
-
-      std::vector<int64_t> rhs_slice_sizes(matched_result.rhs->shape().rank(),
-                                           0);
-      std::vector<int64_t> rhs_slice_increments(
-          matched_result.rhs->shape().rank(), 1);
-      std::vector<int64_t> rhs_slice_max_range(
-          matched_result.rhs->shape().dimensions().begin(),
-          matched_result.rhs->shape().dimensions().end());
-
-      // Create a zero-valued buffer to hold output.
-      HloInstruction* output_buffer =
-          comp->AddInstruction(HloInstruction::CreateBroadcast(
-              all_to_all->shape(),
-              comp->AddInstruction(HloInstruction::CreateConstant(
-                  LiteralUtil::Zero(all_to_all->shape().element_type()))),
-              {}));
-      if (contracting_dim_value % group_size) {
-        VLOG(5) << absl::StrFormat(
-            "Contracting dimension %d needs to be divisible by group_size %d",
-            contracting_dim_value, group_size);
-        return absl::OkStatus();
-      }
-
-      int64_t size_per_split = contracting_dim_value / group_size;
-      // Each split is sliced out of the input buffer; we need to determine the
-      // slice sizes and increments.
-      lhs_slice_max_range[lhs_contracting_dim] = size_per_split;
-      rhs_slice_max_range[rhs_contracting_dim] = size_per_split;
-
-      Shape lhs_slice_shape = matched_result.lhs->shape();
-      Shape rhs_slice_shape = matched_result.rhs->shape();
-
-      lhs_slice_shape.set_dimensions(lhs_contracting_dim, size_per_split);
-      rhs_slice_shape.set_dimensions(rhs_contracting_dim, size_per_split);
-
-      HloInstruction* lhs_slice;
-      HloInstruction* rhs_slice;
-
-      HloInstruction* partial_result = output_buffer;
-      Shape partial_all_to_all_shape = all_to_all->shape();
-
-      TF_ASSIGN_OR_RETURN(
-          Shape partial_dot_shape,
-          ShapeInference::InferDotOpShape(
-              lhs_slice_shape, rhs_slice_shape, original_dot_dnums,
-              /*preferred_element_type=*/std::nullopt));
-      int64_t stream_id = hlo_query::NextChannelId(*all_to_all->GetModule());
-      for (int64_t i = 0; i < group_size; ++i) {
-        lhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
-            lhs_slice_shape, matched_result.lhs, lhs_slice_sizes,
-            lhs_slice_max_range, lhs_slice_increments));
-        all_to_all->SetupDerivedInstruction(lhs_slice);
-        lhs_slice_sizes[lhs_contracting_dim] =
-            lhs_slice_max_range[lhs_contracting_dim];
-        lhs_slice_max_range[lhs_contracting_dim] += size_per_split;
-
-        rhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
-            rhs_slice_shape, matched_result.rhs, rhs_slice_sizes,
-            rhs_slice_max_range, rhs_slice_increments));
-
-        all_to_all->SetupDerivedInstruction(rhs_slice);
-        rhs_slice_sizes[rhs_contracting_dim] =
-            rhs_slice_max_range[rhs_contracting_dim];
-        rhs_slice_max_range[rhs_contracting_dim] += size_per_split;
-
-        HloInstruction* partial_dot = comp->AddInstruction(
-            HloInstruction::CreateDot(partial_dot_shape, lhs_slice, rhs_slice,
-                                      original_dot_dnums, original_precision));
-
-        HloInstruction* partial_all_to_all =
-            comp->AddInstruction(HloInstruction::CreateAllToAll(
-                partial_all_to_all_shape, {partial_dot},
-                all_to_all->device_list(), false,
-                hlo_query::NextChannelId(*all_to_all->GetModule()),
-                all_to_all->split_dimension()));
-        all_to_all->SetupDerivedInstruction(partial_all_to_all);
-        partial_result = comp->AddInstruction(HloInstruction::CreateBinary(
-            partial_all_to_all_shape, HloOpcode::kAdd, partial_all_to_all,
-            partial_result));
-        all_to_all->SetupDerivedInstruction(partial_result);
-        TF_RETURN_IF_ERROR(
-            UpdateDotAndConsumerConfig(partial_dot, stream_id++));
-      }
-      TF_RETURN_IF_ERROR(ReplaceInstruction(all_to_all, partial_result));
-    }
-
-    return absl::OkStatus();
-  }
-
- private:
-  std::vector<GpuWindowedEinsumHandler::WindowedEinsumAgLoops>& all_ag_loops_;
-};
-
-}  // namespace
-
-absl::StatusOr<bool> GpuWindowedEinsumHandler::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  XLA_VLOG_LINES(
-      5, "GpuWindowedEinsumHandler::Run(), before:\n" + module->ToString());
-  bool changed = false;
-  int64_t stream_id = hlo_query::NextChannelId(*module);
-
-  for (HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    if (comp->name().find(kWindowedEinsumRsLoopName) == 0) {
-      VLOG(5) << "Processing computation: " << comp->name();
-      TF_ASSIGN_OR_RETURN(bool comp_result,
-                          HandleRsWindowedEinsumLoop(comp, stream_id));
-      changed |= comp_result;
-    } else if (comp->name().find(kWindowedEinsumAgLoopName) == 0) {
-      VLOG(5) << "Processing computation: " << comp->name();
-      TF_ASSIGN_OR_RETURN(bool comp_result,
-                          HandleAgWindowedEinsumLoop(comp, stream_id));
-      all_ag_loops_.push_back(
-          WindowedEinsumAgLoops(comp->WhileCallInstruction()));
-      changed |= comp_result;
-    }
-  }
-  for (HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    WindowedEinsumVisitor visitor(all_ag_loops_);
-    TF_RETURN_IF_ERROR(comp->Accept(&visitor));
-    changed |= visitor.changed();
-  }
-
-  XLA_VLOG_LINES(
-      5, "GpuWindowedEinsumHandler::Run(), after:\n" + module->ToString());
-  return changed;
-}
-
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h
deleted file mode 100644
index b511920..0000000
--- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_
-#define XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_
-
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla::gpu {
-
-// This pass targets the windowed einsum optimization in the SPMD
-// pipeline, where all-gather+gemm or gemm+reduce-scatter is rewritten
-// into sharded loops to overlap sharded gemms with communication. On
-// GPU, this pass optimizes those loops by annotating independent gemms
-// with stream ids in the backend config; by running them on different
-// streams, we can practically achieve overlap between the gemms too.
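-//
-// Minimal usage sketch (assumed pipeline wiring, not taken from this file):
-//
-//   HloPassPipeline pipeline("windowed-einsum");
-//   pipeline.AddPass<GpuWindowedEinsumHandler>();
-//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));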
-class GpuWindowedEinsumHandler : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "gpu-windowed-einsum-handler";
-  }
-
-  struct WindowedEinsumAgLoops {
-    explicit WindowedEinsumAgLoops(HloInstruction* loop) : loop(loop) {}
-    HloInstruction* loop;
-    bool consumed = false;
-  };
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
-  constexpr static const char* kWindowedEinsumRsLoopName =
-      "windowed_dot_general_body_rs";
-  constexpr static const char* kWindowedEinsumAgLoopName =
-      "windowed_dot_general_body_ag";
-
- private:
-  std::vector<WindowedEinsumAgLoops> all_ag_loops_;
-};
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc
deleted file mode 100644
index 6f23319..0000000
--- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc
+++ /dev/null
@@ -1,918 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_windowed_einsum_handler.h"
-
-#include <memory>
-#include <string>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-using GpuWindowedEinsumHanlderTest = HloTestBase;
-
-HloInstruction* FindInstructionByName(HloComputation* comp, std::string name) {
-  for (auto inst : comp->instructions()) {
-    if (inst->name() == name) {
-      return inst;
-    }
-  }
-  return nullptr;
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, AgLoopsHaveStreamIds) {
-  constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,512,24576]{2,1,0}, bf16[24576,24576]{1,0})->bf16[2048,24576]{1,0}}, num_partitions=4
-
-windowed_dot_general_body_ag.1 {
-  param = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) parameter(0)
-  get-tuple-element = bf16[512,24576]{1,0} get-tuple-element(param), index=0
-  collective-permute = bf16[512,24576]{1,0} collective-permute(get-tuple-element), channel_id=2, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
-  get-tuple-element.1 = bf16[24576,24576]{1,0} get-tuple-element(param), index=1
-  get-tuple-element.2 = bf16[2048,24576]{1,0} get-tuple-element(param), index=2
-  dot.2 = bf16[512,24576]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
-  constant.1 = s32[4]{0} constant({0, 512, 1024, 1536})
-  get-tuple-element.4 = u32[] get-tuple-element(param), index=4
-  partition-id = u32[] partition-id()
-  add = u32[] add(get-tuple-element.4, partition-id)
-  constant = u32[] constant(4)
-  remainder = u32[] remainder(add, constant)
-  dynamic-slice = s32[1]{0} dynamic-slice(constant.1, remainder), dynamic_slice_sizes={1}
-  reshape.4 = s32[] reshape(dynamic-slice)
-  constant.2 = s32[] constant(0)
-  dynamic-update-slice = bf16[2048,24576]{1,0} dynamic-update-slice(get-tuple-element.2, dot.2, reshape.4, constant.2), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
-  dot.3 = bf16[512,24576]{1,0} dot(collective-permute, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  constant.3 = u32[] constant(1)
-  add.1 = u32[] add(get-tuple-element.4, constant.3)
-  add.2 = u32[] add(add.1, partition-id)
-  remainder.1 = u32[] remainder(add.2, constant)
-  dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.1, remainder.1), dynamic_slice_sizes={1}
-  reshape.5 = s32[] reshape(dynamic-slice.1)
-  dynamic-update-slice.1 = bf16[2048,24576]{1,0} dynamic-update-slice(dynamic-update-slice, dot.3, reshape.5, constant.2)
-  get-tuple-element.3 = bf16[2048,24576]{1,0} get-tuple-element(param), index=3
-  add.3 = u32[] add(add.1, constant.3)
-  ROOT tuple = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) tuple(collective-permute, get-tuple-element.1, dynamic-update-slice.1, get-tuple-element.3, add.3)
-} // windowed_dot_general_body_ag.1
-
-windowed_dot_general_cond_ag {
-  param.1 = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) parameter(0)
-  get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
-  constant.8 = u32[] constant(4)
-  ROOT compare = pred[] compare(get-tuple-element.5, constant.8), direction=LT
-}
-
-ENTRY test_main {
-  param.4 = bf16[1,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
-  reshape.8 = bf16[512,24576]{1,0} reshape(param.4)
-  param.5 = bf16[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
-  constant.18 = bf16[] constant(0)
-  broadcast = bf16[2048,24576]{1,0} broadcast(constant.18), dimensions={}
-  constant.20 = u32[] constant(0)
-  tuple.2 = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) tuple(reshape.8, param.5, broadcast, broadcast, constant.20)
-  while = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag.1
-  ROOT get-tuple-element.13 = bf16[2048,24576]{1,0} get-tuple-element(while), index=2
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  GpuWindowedEinsumHandler gpu_handler;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  HloInstruction* ag_loop =
-      module->entry_computation()->root_instruction()->mutable_operand(0);
-  HloComputation* ag_loop_body = ag_loop->while_body();
-  HloInstruction* inst = FindInstructionByName(ag_loop_body, "dot.2");
-  EXPECT_GT(inst->backend_config<GpuBackendConfig>()->operation_queue_id(), 0);
-  EXPECT_TRUE(
-      inst->backend_config<GpuBackendConfig>()->force_earliest_schedule());
-
-  HloInstruction* cp1 =
-      FindInstructionByName(ag_loop_body, "collective-permute");
-  EXPECT_TRUE(
-      cp1->backend_config<GpuBackendConfig>()->force_earliest_schedule());
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, RsLoopsHaveStreamIds) {
-  constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[2048,24576]{1,0})->bf16[512,24576]{1,0}}, num_partitions=4
-
-windowed_dot_general_body_rs_clone.1 {
-  param.2 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) parameter(0)
-  get-tuple-element.6 = bf16[2048,24576]{1,0} get-tuple-element(param.2), index=0
-  get-tuple-element.7 = bf16[24576,24576]{1,0} get-tuple-element(param.2), index=1
-  get-tuple-element.9 = bf16[512,24576]{1,0} get-tuple-element(param.2), index=2
-  collective-permute.1 = bf16[512,24576]{1,0} collective-permute(get-tuple-element.9), channel_id=4, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
-  constant.10 = s32[4]{0} constant({0, 512, 1024, 1536})
-  get-tuple-element.11 = u32[] get-tuple-element(param.2), index=4
-  constant.12 = u32[] constant(2)
-  add.8 = u32[] add(get-tuple-element.11, constant.12)
-  constant.13 = u32[] constant(1)
-  add.9 = u32[] add(add.8, constant.13)
-  partition-id.3 = u32[] partition-id()
-  add.10 = u32[] add(add.9, partition-id.3)
-  constant.9 = u32[] constant(4)
-  remainder.3 = u32[] remainder(add.10, constant.9)
-  dynamic-slice.4 = s32[1]{0} dynamic-slice(constant.10, remainder.3), dynamic_slice_sizes={1}
-  reshape.7 = s32[] reshape(dynamic-slice.4)
-  constant.11 = s32[] constant(0)
-  dynamic-slice.5 = bf16[512,24576]{1,0} dynamic-slice(get-tuple-element.6, reshape.7, constant.11), dynamic_slice_sizes={512,24576}
-  dot.7 = bf16[512,24576]{1,0} dot(dynamic-slice.5, get-tuple-element.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
-  add.11 = bf16[512,24576]{1,0} add(collective-permute.1, dot.7), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
-  get-tuple-element.10 = bf16[512,24576]{1,0} get-tuple-element(param.2), index=3
-  add.6 = u32[] add(get-tuple-element.11, partition-id.3)
-  remainder.2 = u32[] remainder(add.6, constant.9)
-  dynamic-slice.2 = s32[1]{0} dynamic-slice(constant.10, remainder.2), dynamic_slice_sizes={1}
-  reshape.6 = s32[] reshape(dynamic-slice.2)
-  dynamic-slice.3 = bf16[512,24576]{1,0} dynamic-slice(get-tuple-element.6, reshape.6, constant.11), dynamic_slice_sizes={512,24576}
-  dot.5 = bf16[512,24576]{1,0} dot(dynamic-slice.3, get-tuple-element.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
-  add.7 = bf16[512,24576]{1,0} add(get-tuple-element.10, dot.5), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
-  collective-permute.2 = bf16[512,24576]{1,0} collective-permute(add.7), channel_id=5, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}
-  ROOT tuple.1 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) tuple(get-tuple-element.6, get-tuple-element.7, add.11, collective-permute.2, add.8)
-}
-
-windowed_dot_general_cond_rs {
-  param.3 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) parameter(0)
-  get-tuple-element.12 = u32[] get-tuple-element(param.3), index=4
-  constant.17 = u32[] constant(4)
-  ROOT compare.1 = pred[] compare(get-tuple-element.12, constant.17), direction=LT
-}
-
-ENTRY main.9_spmd {
-  param.6 = bf16[24576,24576]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
-  param.7 = bf16[512,24576]{1,0} parameter(1)
-  param.8 = bf16[2048,24576]{1,0} parameter(2)
-  constant.20 = u32[] constant(0)
-  tuple.3 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) tuple(param.8, param.6, param.7, param.7, constant.20)
-  while.1 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) while(tuple.3), condition=windowed_dot_general_cond_rs, body=windowed_dot_general_body_rs_clone.1
-  ROOT get-tuple-element.14 = bf16[512,24576]{1,0} get-tuple-element(while.1), index=2
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  GpuWindowedEinsumHandler gpu_handler;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  HloInstruction* rs_loop =
-      module->entry_computation()->root_instruction()->mutable_operand(0);
-  HloComputation* rs_loop_body = rs_loop->while_body();
-  HloInstruction* inst = FindInstructionByName(rs_loop_body, "dot.7");
-  EXPECT_TRUE(inst->backend_config<GpuBackendConfig>()->operation_queue_id() >
-              0);
-
-  HloInstruction* cp1 =
-      FindInstructionByName(rs_loop_body, "collective-permute.1");
-  EXPECT_TRUE(
-      cp1->backend_config<GpuBackendConfig>()->force_earliest_schedule());
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, AgLoopsMultipleConsumersAreChained) {
-  constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[24576,24576]{1,0})->bf16[2,2048,24576]{2,1,0}}, num_partitions=4
-
-windowed_dot_general_body_ag {
-  param.1 = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) parameter(0)
-  get-tuple-element.1 = bf16[2,512,24576]{2,1,0} get-tuple-element(param.1), index=0
-  collective-permute = bf16[2,512,24576]{2,1,0} collective-permute(get-tuple-element.1), channel_id=2, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
-  collective-permute.1 = bf16[2,512,24576]{2,1,0} collective-permute(collective-permute), channel_id=3, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
-  get-tuple-element.2 = bf16[24576,24576]{1,0} get-tuple-element(param.1), index=1
-  get-tuple-element.3 = bf16[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=2
-  dot = bf16[2,512,24576]{2,1,0} dot(get-tuple-element.1, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  constant.2 = s32[] constant(0)
-  constant.3 = s32[4]{0} constant({0, 512, 1024, 1536})
-  get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
-  partition-id = u32[] partition-id()
-  add = u32[] add(get-tuple-element.5, partition-id)
-  constant.1 = u32[] constant(4)
-  remainder = u32[] remainder(add, constant.1)
-  dynamic-slice = s32[1]{0} dynamic-slice(constant.3, remainder), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(dynamic-slice)
-  dynamic-update-slice = bf16[2,2048,24576]{2,1,0} dynamic-update-slice(get-tuple-element.3, dot, constant.2, reshape, constant.2)
-  dot.1 = bf16[2,512,24576]{2,1,0} dot(collective-permute, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  constant.5 = u32[] constant(1)
-  add.1 = u32[] add(get-tuple-element.5, constant.5)
-  add.2 = u32[] add(add.1, partition-id)
-  remainder.1 = u32[] remainder(add.2, constant.1)
-  dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.3, remainder.1), dynamic_slice_sizes={1}
-  reshape.1 = s32[] reshape(dynamic-slice.1)
-  dynamic-update-slice.1 = bf16[2,2048,24576]{2,1,0} dynamic-update-slice(dynamic-update-slice, dot.1, constant.2, reshape.1, constant.2)
-  get-tuple-element.4 = bf16[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=3
-  add.3 = u32[] add(add.1, constant.5)
-  ROOT tuple = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) tuple(collective-permute.1, get-tuple-element.2, dynamic-update-slice.1, get-tuple-element.4, add.3)
-} // windowed_dot_general_body_ag
-
-windowed_dot_general_cond_ag {
-  param = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) parameter(0)
-  get-tuple-element = u32[] get-tuple-element(param), index=4
-  constant = u32[] constant(4)
-  ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT
-}
-
-ENTRY main.12_spmd {
-  param.4 = bf16[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
-  param.5 = bf16[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
-  constant.22 = bf16[] constant(0)
-  broadcast = bf16[2,2048,24576]{2,1,0} broadcast(constant.22), dimensions={}
-  constant.24 = u32[] constant(0)
-  tuple.2 = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24)
-  while = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
-  get-tuple-element.13 = bf16[2,2048,24576]{2,1,0} get-tuple-element(while), index=2
-  copy.1 = bf16[2,2048,24576]{2,1,0} copy(get-tuple-element.13)
-  all-gather = bf16[2,2048,24576]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true
-  param.6 = bf16[24576,24576]{1,0} parameter(2), sharding={devices=[1,4]<=[4]}
-  ROOT dot.7 = bf16[2,2048,24576]{2,1,0} dot(all-gather, param.6), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  GpuWindowedEinsumHandler gpu_handler;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
-  EXPECT_TRUE(changed);
-  HloInstruction* ag_loop =
-      FindInstructionByName(module->entry_computation(), "while");
-  HloInstruction* inst =
-      FindInstructionByName(module->entry_computation(), "dot.7");
-  // dot.7 should now consume the output of the windowed einsum while loop.
-  EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement);
-  EXPECT_EQ(inst->operand(0)->tuple_index(), 5);
-  EXPECT_EQ(inst->operand(0)->operand(0), ag_loop);
-
-  // The while loop's root should now have a chain of dynamic-update-slices (DUS).
-  HloInstruction* ag_while_root = ag_loop->while_body()->root_instruction();
-  EXPECT_THAT(ag_while_root,
-              GmockMatch(m::Tuple(
-                  m::Op(), m::Op(), m::Op(), m::Op(), m::Op(),
-                  m::DynamicUpdateSlice(
-                      m::DynamicUpdateSlice(
-                          m::GetTupleElement(m::Parameter())
-                              .WithPredicate([](const HloInstruction* instr) {
-                                return instr->tuple_index() == 5;
-                              }),
-                          m::Op(), m::Op(), m::Op(), m::Op()),
-                      m::Op(), m::Op(), m::Op(), m::Op()))));
-}
-TEST_F(GpuWindowedEinsumHanlderTest, A2aGemmHaveStreamIds) {
-  constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,4,2048,8192]{3,2,1,0})->bf16[1,4,2048,32768]{3,2,1,0}}, num_partitions=8
-
-ENTRY main.9_spmd {
-  param0 = bf16[1,8192,32768]{2,1,0} parameter(0)
-  param1 = bf16[1,4,2048,8192]{3,2,1,0} parameter(1)
-  all-to-all = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(param1), channel_id=4, replica_groups={{0,1,2,3},{4,5,6,7}}, dimensions={1}
-  ROOT dot.12 = bf16[1,4,2048,32768]{3,2,1,0} dot(all-to-all, param0), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-}
-)";
-
-  const char* kExpected = R"(
-CHECK: ENTRY
-CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} parameter(1)
-
-CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [6144:8192]}
-CHECK: %[[A2A0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE0]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3},{4,5,6,7}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0)
-CHECK-DAG: %[[SLICE4:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [6144:8192], [0:32768]}
-CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A0:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false}
-
-CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [4096:6144]}
-CHECK: %[[A2A1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE1]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3},{4,5,6,7}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[SLICE5:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [4096:6144], [0:32768]}
-CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A1:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false}
-
-CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [2048:4096]}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE2]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3},{4,5,6,7}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[SLICE6:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [2048:4096], [0:32768]}
-CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A2:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false}
-
-CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:2048]}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE3]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3},{4,5,6,7}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[SLICE7:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:2048], [0:32768]}
-CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A3:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"5","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
-CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,32768]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
-CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT0:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[BROADCAST:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["5"],"force_earliest_schedule":false}
-CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT1:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD0:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["6"],"force_earliest_schedule":false}
-CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT2:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD1:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["7"],"force_earliest_schedule":false}
-
-CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT3:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD2:.*]])
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  GpuWindowedEinsumHandler gpu_handler;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
-  TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
-                          RunFileCheck(module->ToString(), kExpected));
-  EXPECT_TRUE(filecheck_matched);
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, GemmA2aHaveStreamIds) {
-  constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,4,2048,32768]{3,2,1,0})->bf16[1,4,2048,8192]{3,2,1,0}}, num_partitions=4
-
-ENTRY main.9_spmd {
-  param.9 = bf16[1,8192,32768]{2,1,0} parameter(0)
-  param.10 = bf16[1,4,2048,32768]{3,2,1,0} parameter(1)
-  dot.12 = bf16[1,4,2048,8192]{3,2,1,0} dot(param.10, param.9), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}
-  ROOT all-to-all = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(dot.12), channel_id=4, replica_groups={{0,1,2,3}}, dimensions={1}
-}
-)";
-
-  const char* kExpected = R"(
-CHECK: ENTRY
-CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} parameter(1)
-
-CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [24576:32768]}
-CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0)
-CHECK-DAG: %[[SLICE4:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [24576:32768]}
-CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE0:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT0:.*]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-
-CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [16384:24576]}
-CHECK-DAG: %[[SLICE5:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [16384:24576]}
-CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE1:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT1:.*]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-
-CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [8192:16384]}
-CHECK-DAG: %[[SLICE6:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [8192:16384]}
-CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE2:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT2:.*]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-
-CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:8192]}
-CHECK-DAG: %[[SLICE7:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [0:8192]}
-CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE3:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"5","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT3:.*]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
-CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,8192]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
-CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A0:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[BROADCAST:.*]])
-CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A1:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD0:.*]])
-CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A2:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD1:.*]])
-
-CHECK: ROOT {{.*}} = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A3:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD2:.*]])
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  GpuWindowedEinsumHandler gpu_handler;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
-  TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
-                          RunFileCheck(module->ToString(), kExpected));
-  EXPECT_TRUE(filecheck_matched);
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, A2aTransposeLoopsHaveStreamIds) {
-  constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0})->bf16[1,4,2048,32768]{3,2,1,0}}, num_partitions=4
-
-ENTRY main.9_spmd {
-  param.9 = bf16[1,8192,32768]{2,1,0} parameter(0)
-  param.10 = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} parameter(1)
-  all-to-all = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} all-to-all(param.10), channel_id=4, replica_groups={{0,1,2,3}}, dimensions={3}
-  transpose.15 = bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} transpose(all-to-all), dimensions={0,3,1,2,4,5}
-  reshape.2170 = bf16[1,4,8192,1,2048]{4,3,2,1,0} reshape(transpose.15)
-  reshape.2173 = bf16[4,8192,1,2048]{3,2,1,0} reshape(reshape.2170)
-  transpose.16 = bf16[1,4,2048,8192]{2,0,3,1} transpose(reshape.2173), dimensions={2,0,3,1}
-  copy.53 = bf16[1,4,2048,8192]{3,2,1,0} copy(transpose.16)
-  ROOT dot.12 = bf16[1,4,2048,32768]{3,2,1,0} dot(copy.53, param.9), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-}
-)";
-
-  const char* kExpected = R"(
-CHECK: ENTRY
-CHECK-DAG: %[[P1:.*]] = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} parameter(1)
-CHECK-DAG: %[[TRANSPOSE0:.*]] = bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} transpose(bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} %[[P1:.*]]), dimensions={0,3,1,2,4,5}
-CHECK-DAG: %[[RESHAPE0:.*]] = bf16[1,4,8192,1,2048]{4,3,2,1,0} reshape(bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} %[[TRANSPOSE0:.*]])
-CHECK-DAG: %[[RESHAPE1:.*]] = bf16[4,8192,1,2048]{3,2,1,0} reshape(bf16[1,4,8192,1,2048]{4,3,2,1,0} %[[RESHAPE0:.*]])
-CHECK-DAG: %[[TRANSPOSE1:.*]] = bf16[1,4,2048,8192]{2,0,3,1} transpose(bf16[4,8192,1,2048]{3,2,1,0} %[[RESHAPE1:.*]]), dimensions={2,0,3,1}
-CHECK-DAG: %[[COPY:.*]] = bf16[1,4,2048,8192]{3,2,1,0} copy(bf16[1,4,2048,8192]{2,0,3,1} %[[TRANSPOSE1:.*]])
-
-CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [6144:8192]}
-CHECK: %[[A2A0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE0]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0)
-CHECK-DAG: %[[SLICE4:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [6144:8192], [0:32768]}
-CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A0:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"9","wait_on_operation_queues":[],"force_earliest_schedule":false}
-
-CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [4096:6144]}
-CHECK: %[[A2A1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE1]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[SLICE5:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [4096:6144], [0:32768]}
-CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A1:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false}
-
-CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [2048:4096]}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE2]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[SLICE6:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [2048:4096], [0:32768]}
-CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A2:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false}
-
-CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [0:2048]}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE3]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[SLICE7:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:2048], [0:32768]}
-CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A3:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
-CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,32768]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
-CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT0:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[BROADCAST:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["6"],"force_earliest_schedule":false}
-CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT1:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD0:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["7"],"force_earliest_schedule":false}
-CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT2:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD1:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["8"],"force_earliest_schedule":false}
-
-CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT3:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD2:.*]])
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  GpuWindowedEinsumHandler gpu_handler;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
-  EXPECT_TRUE(changed);
-  TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
-                          RunFileCheck(module->ToString(), kExpected));
-  EXPECT_TRUE(filecheck_matched);
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, GemmA2aTransposeLoopsHaveStreamIds) {
-  constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,4,2048,32768]{3,2,1,0}, bf16[1,32768,8192]{2,1,0})->bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0}}, num_partitions=4
-
-ENTRY main.9_spmd {
-  param.9 = bf16[1,4,2048,32768]{3,2,1,0} parameter(0)
-  param.10 = bf16[1,32768,8192]{2,1,0} parameter(1)
-  dot.13 = bf16[1,4,2048,8192]{3,2,1,0} dot(param.9, param.10), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-  copy.55 = bf16[1,4,2048,8192]{3,2,1,0} copy(dot.13)
-  transpose.17 = bf16[4,1,2048,8192]{3,2,0,1} transpose(copy.55), dimensions={1,0,2,3}
-  copy.56 = bf16[4,1,2048,8192]{3,2,1,0} copy(transpose.17)
-  reshape.2216 = bf16[1,4,1,2048,8192]{4,3,2,1,0} reshape(copy.56)
-  reshape.2219 = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(reshape.2216)
-  ROOT all-to-all.1 = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} all-to-all(reshape.2219), channel_id=7, replica_groups={{0,1,2,3}}, dimensions={1}
-}
-)";
-
-  const char* kExpected = R"(
-CHECK: ENTRY
-CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} parameter(0)
-
-CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [24576:32768]}
-CHECK-DAG: %[[P0:.*]] = bf16[1,32768,8192]{2,1,0} parameter(1)
-CHECK-DAG: %[[SLICE4:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [24576:32768], [0:8192]}
-CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE0:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"12","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT0:.*]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-
-CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [16384:24576]}
-CHECK-DAG: %[[SLICE5:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [16384:24576], [0:8192]}
-CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE1:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"11","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT1:.*]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-
-CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [8192:16384]}
-CHECK-DAG: %[[SLICE6:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [8192:16384], [0:8192]}
-CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE2:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"10","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT2:.*]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-
-CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:8192]}
-CHECK-DAG: %[[SLICE7:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [0:8192]}
-CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE3:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"9","wait_on_operation_queues":[],"force_earliest_schedule":false}
-CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT3:.*]]),
-CHECK: replica_groups={
-CHECK:     {0,1,2,3}
-CHECK: }
-CHECK: dimensions={1}
-CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
-CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,8192]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
-CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A0:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[BROADCAST:.*]])
-CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A1:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD0:.*]])
-CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A2:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD1:.*]])
-CHECK-DAG: %[[ADD3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A3:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD2:.*]])
-
-CHECK-DAG: %[[COPY:.*]] = bf16[1,4,2048,8192]{3,2,1,0} copy(bf16[1,4,2048,8192]{3,2,1,0} %[[ADD3:.*]])
-CHECK-DAG: %[[TRANSPOSE0:.*]] = bf16[4,1,2048,8192]{3,2,0,1} transpose(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), dimensions={1,0,2,3}
-CHECK-DAG: %[[COPY1:.*]] = bf16[4,1,2048,8192]{3,2,1,0} copy(bf16[4,1,2048,8192]{3,2,0,1} %[[TRANSPOSE0:.*]])
-CHECK-DAG: %[[RESHAPE0:.*]] = bf16[1,4,1,2048,8192]{4,3,2,1,0} reshape(bf16[4,1,2048,8192]{3,2,1,0} %[[COPY1:.*]])
-
-CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,2048,8192]{4,3,2,1,0} %[[RESHAPE0:.*]])
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  GpuWindowedEinsumHandler gpu_handler;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
-  EXPECT_TRUE(changed);
-  TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
-                          RunFileCheck(module->ToString(), kExpected));
-  EXPECT_TRUE(filecheck_matched);
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, AllGatherF8) {
-  constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4
-
-windowed_dot_general_body_ag {
-  param.1 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0)
-  get-tuple-element.1 = f32[2,512,24576]{2,1,0} get-tuple-element(param.1), index=0
-  collective-permute = f32[2,512,24576]{2,1,0} collective-permute(get-tuple-element.1), channel_id=4, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
-  collective-permute.1 = f32[2,512,24576]{2,1,0} collective-permute(collective-permute), channel_id=5, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
-  get-tuple-element.2 = f32[24576,24576]{1,0} get-tuple-element(param.1), index=1
-  get-tuple-element.3 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=2
-  dot = f32[2,512,24576]{2,1,0} dot(get-tuple-element.1, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  constant.12 = s32[] constant(0)
-  constant.13 = s32[4]{0} constant({0, 512, 1024, 1536})
-  get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
-  partition-id = u32[] partition-id()
-  add = u32[] add(get-tuple-element.5, partition-id)
-  constant.11 = u32[] constant(4)
-  remainder = u32[] remainder(add, constant.11)
-  dynamic-slice = s32[1]{0} dynamic-slice(constant.13, remainder), dynamic_slice_sizes={1}
-  reshape = s32[] reshape(dynamic-slice)
-  dynamic-update-slice = f32[2,2048,24576]{2,1,0} dynamic-update-slice(get-tuple-element.3, dot, constant.12, reshape, constant.12)
-  dot.1 = f32[2,512,24576]{2,1,0} dot(collective-permute, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  constant.15 = u32[] constant(1)
-  add.1 = u32[] add(get-tuple-element.5, constant.15)
-  add.2 = u32[] add(add.1, partition-id)
-  remainder.1 = u32[] remainder(add.2, constant.11)
-  dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.13, remainder.1), dynamic_slice_sizes={1}
-  reshape.1 = s32[] reshape(dynamic-slice.1)
-  dynamic-update-slice.1 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(dynamic-update-slice, dot.1, constant.12, reshape.1, constant.12)
-  get-tuple-element.4 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=3
-  add.3 = u32[] add(add.1, constant.15)
-  ROOT tuple = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(collective-permute.1, get-tuple-element.2, dynamic-update-slice.1, get-tuple-element.4, add.3)
-} // windowed_dot_general_body_ag
-
-windowed_dot_general_cond_ag {
-  param = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0)
-  get-tuple-element = u32[] get-tuple-element(param), index=4
-  constant.10 = u32[] constant(4)
-  ROOT compare = pred[] compare(get-tuple-element, constant.10), direction=LT
-}
-
-ENTRY test_main {
-  param.4 = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
-  reshape.8 = f8e4m3fn[2,512,24576]{2,1,0} reshape(param.4)
-  param.5 = f8e4m3fn[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
-  constant.18 = f32[] constant(0)
-  broadcast = f32[2,2048,24576]{2,1,0} broadcast(constant.18), dimensions={}
-  constant.20 = u32[] constant(0)
-  scale_lhs = f32[] parameter(2)
-  scale_lhs_bcast = f32[2,512,24576]{2,1,0} broadcast(scale_lhs), dimensions={}
-  lhs_bf32 = f32[2,512,24576]{2,1,0} convert(reshape.8)  
-  lhs_scaled = f32[2,512,24576]{2,1,0} multiply(lhs_bf32, scale_lhs_bcast)
-  scale_rhs = f32[] parameter(3)
-  scale_rhs_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs), dimensions={}
-  rhs_bf32 = f32[24576,24576]{1,0} convert(param.5)  
-  rhs_scaled = f32[24576,24576]{1,0} multiply(rhs_bf32, scale_rhs_bcast)
-  tuple.2 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_scaled, broadcast, broadcast, constant.20)
-  while = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
-  ROOT get-tuple-element.13 = f32[2,2048,24576]{2,1,0} get-tuple-element(while), index=2
-}
-)";
-
-  RunAndFilecheckHloRewrite(kHloString, GpuWindowedEinsumHandler(),
-                            R"(
-; CHECK-LABEL: windowed_dot_general_body_ag
-; CHECK-NEXT:    [[P0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0)
-; CHECK-NEXT:    [[GTE0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=0
-; CHECK-NEXT:    [[CP0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[GTE0]]), channel_id=4
-; CHECK-NEXT:    [[CP1:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[CP0]]), channel_id=5
-; CHECK-NEXT:    [[GTE1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[P0]]), index=1
-; CHECK-NEXT:    [[GTE2:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=2
-; CHECK-NEXT:    [[CONVERT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[GTE0]])
-; CHECK-NEXT:    [[GTE3:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=5
-; CHECK-NEXT:    [[BCAST0:%[^ ]+]] = f32[2,512,24576]{2,1,0} broadcast([[GTE3]]), dimensions={}
-; CHECK-NEXT:    [[MUL0:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT0]], [[BCAST0]])
-; CHECK-NEXT:    [[CONVERT1:%[^ ]+]] = f32[24576,24576]{1,0} convert([[GTE1]])
-; CHECK-NEXT:    [[GTE4:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=6
-; CHECK-NEXT:    [[BCAST1:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[GTE4]]), dimensions={}
-; CHECK-NEXT:    [[MUL1:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[CONVERT1]], [[BCAST1]])
-; CHECK-NEXT:    [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL0]], [[MUL1]]),
-; CHECK-DAG:       lhs_contracting_dims={2},
-; CHECK-DAG:       rhs_contracting_dims={0},
-; CHECK-DAG:       backend_config={
-; CHECK-DAG:         "operation_queue_id":"[[OPQUEUEID:[0-9]+]]",
-; CHECK-DAG:         "wait_on_operation_queues":[],
-; CHECK-DAG:         "force_earliest_schedule":true}
-; CHECK-NEXT:    [[C0:%[^ ]+]] = s32[] constant(0)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = s32[4]{0} constant({0, 512, 1024, 1536})
-; CHECK-NEXT:    [[GTE5:%[^ ]+]] = u32[] get-tuple-element([[P0]]), index=4
-; CHECK-NEXT:    [[PID:%[^ ]+]] = u32[] partition-id()
-; CHECK-NEXT:    [[ADD0:%[^ ]+]] = u32[] add([[GTE5]], [[PID]])
-; CHECK-NEXT:    [[C2:%[^ ]+]] = u32[] constant(4)
-; CHECK-NEXT:    [[REM0:%[^ ]+]] = u32[] remainder([[ADD0]], [[C2]])
-; CHECK-NEXT:    [[DSLICE0:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM0]]), dynamic_slice_sizes={1}
-; CHECK-NEXT:    [[RESHAPE0:%[^ ]+]] = s32[] reshape([[DSLICE0]])
-; CHECK-NEXT:    [[DUPDATESLICE0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[GTE2]], [[DOT0]], [[C0]], [[RESHAPE0]], [[C0]]),
-; CHECK-DAG:       backend_config={
-; CHECK-DAG:         "operation_queue_id":"0",
-; CHECK-DAG:         "wait_on_operation_queues":["[[OPQUEUEID]]"],
-; CHECK-DAG:         "force_earliest_schedule":false}
-; CHECK-NEXT:    [[CONVERT2:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[CP0]])
-; CHECK-NEXT:    [[MUL2:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT2]], [[BCAST0]])
-; CHECK-NEXT:    [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL2]], [[MUL1]]),
-; CHECK-DAG:       lhs_contracting_dims={2},
-; CHECK-DAG:       rhs_contracting_dims={0}
-; CHECK-NEXT:    [[C3:%[^ ]+]] = u32[] constant(1)
-; CHECK-NEXT:    [[ADD1:%[^ ]+]] = u32[] add([[GTE5]], [[C3]])
-; CHECK-NEXT:    [[ADD2:%[^ ]+]] = u32[] add([[ADD1]], [[PID]])
-; CHECK-NEXT:    [[REM1:%[^ ]+]] = u32[] remainder([[ADD2]], [[C2]])
-; CHECK-NEXT:    [[DSLICE1:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM1]]), dynamic_slice_sizes={1}
-; CHECK-NEXT:    [[RESHAPE1:%[^ ]+]] = s32[] reshape([[DSLICE1]])
-; CHECK-NEXT:    [[DUPDATESLICE1:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[DUPDATESLICE0]], [[DOT1]], [[C0]], [[RESHAPE1]], [[C0]])
-; CHECK-NEXT:    [[GTE6:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=3
-; CHECK-NEXT:    [[ADD3:%[^ ]+]] = u32[] add([[ADD1]], [[C3]])
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[CP1]], [[GTE1]], [[DUPDATESLICE1]], [[GTE6]], [[ADD3]], /*index=5*/[[GTE3]], [[GTE4]])
-)");
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest, ReduceScatterF8) {
-  constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f8e4m3fn[2,2048,24576]{2,1,0}, f32[], f32[])->f32[2,512,24576]{2,1,0}}, num_partitions=4
-
-windowed_dot_general_body_rs {
-  param.3 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) parameter(0)
-  get-tuple-element.7 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.3), index=0
-  get-tuple-element.8 = f32[24576,24576]{1,0} get-tuple-element(param.3), index=1
-  get-tuple-element.9 = f32[2,512,24576]{2,1,0} get-tuple-element(param.3), index=2
-  collective-permute.2 = f32[2,512,24576]{2,1,0} collective-permute(get-tuple-element.9), channel_id=9, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}
-  constant.23 = s32[] constant(0)
-  constant.24 = s32[4]{0} constant({0, 512, 1024, 1536})
-  get-tuple-element.11 = u32[] get-tuple-element(param.3), index=4
-  constant.26 = u32[] constant(2)
-  add.8 = u32[] add(get-tuple-element.11, constant.26)
-  constant.27 = u32[] constant(1)
-  add.9 = u32[] add(add.8, constant.27)
-  partition-id.3 = u32[] partition-id()
-  add.10 = u32[] add(add.9, partition-id.3)
-  constant.22 = u32[] constant(4)
-  remainder.3 = u32[] remainder(add.10, constant.22)
-  dynamic-slice.4 = s32[1]{0} dynamic-slice(constant.24, remainder.3), dynamic_slice_sizes={1}
-  reshape.3 = s32[] reshape(dynamic-slice.4)
-  dynamic-slice.5 = f32[2,512,24576]{2,1,0} dynamic-slice(get-tuple-element.7, constant.23, reshape.3, constant.23), dynamic_slice_sizes={2,512,24576}
-  dot.3 = f32[2,512,24576]{2,1,0} dot(dynamic-slice.5, get-tuple-element.8), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  add.11 = f32[2,512,24576]{2,1,0} add(collective-permute.2, dot.3)
-  get-tuple-element.10 = f32[2,512,24576]{2,1,0} get-tuple-element(param.3), index=3
-  add.6 = u32[] add(get-tuple-element.11, partition-id.3)
-  remainder.2 = u32[] remainder(add.6, constant.22)
-  dynamic-slice.2 = s32[1]{0} dynamic-slice(constant.24, remainder.2), dynamic_slice_sizes={1}
-  reshape.2 = s32[] reshape(dynamic-slice.2)
-  dynamic-slice.3 = f32[2,512,24576]{2,1,0} dynamic-slice(get-tuple-element.7, constant.23, reshape.2, constant.23), dynamic_slice_sizes={2,512,24576}
-  dot.2 = f32[2,512,24576]{2,1,0} dot(dynamic-slice.3, get-tuple-element.8), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  add.7 = f32[2,512,24576]{2,1,0} add(get-tuple-element.10, dot.2)
-  collective-permute.3 = f32[2,512,24576]{2,1,0} collective-permute(add.7), channel_id=10, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}
-  ROOT tuple.1 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) tuple(get-tuple-element.7, get-tuple-element.8, add.11, collective-permute.3, add.8)
-} // windowed_dot_general_body_rs
-
-windowed_dot_general_cond_rs {
-  param.2 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) parameter(0)
-  get-tuple-element.6 = u32[] get-tuple-element(param.2), index=4
-  constant.21 = u32[] constant(4)
-  ROOT compare.1 = pred[] compare(get-tuple-element.6, constant.21), direction=LT
-}
-
-ENTRY main.9_spmd {
-  param.6 = f8e4m3fn[24576,24576]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
-  param.7 = f32[2,512,24576]{2,1,0} parameter(1)
-  param.8 = f8e4m3fn[2,2048,24576]{2,1,0} parameter(2)
-  constant.20 = u32[] constant(0)
-  scale_lhs = f32[] parameter(3)
-  scale_lhs_bcast = f32[2,2048,24576]{2,1,0} broadcast(scale_lhs), dimensions={}
-  lhs_bf16 = f32[2,2048,24576]{2,1,0} convert(param.8)  
-  lhs_scaled = f32[2,2048,24576]{2,1,0} multiply(lhs_bf16, scale_lhs_bcast)
-  scale_rhs = f32[] parameter(4)
-  scale_rhs_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs), dimensions={}
-  rhs_bf16 = f32[24576,24576]{1,0} convert(param.6)  
-  rhs_scaled = f32[24576,24576]{1,0} multiply(rhs_bf16, scale_rhs_bcast)
-  tuple.3 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_scaled, param.7, param.7, constant.20)
-  while.1 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) while(tuple.3), condition=windowed_dot_general_cond_rs, body=windowed_dot_general_body_rs
-  ROOT get-tuple-element.14 = f32[2,512,24576]{2,1,0} get-tuple-element(while.1), index=2
-}
-)";
-
-  RunAndFilecheckHloRewrite(kHloString, GpuWindowedEinsumHandler(),
-                            R"(
-; CHECK-LABEL: windowed_dot_general_body_rs
-; CHECK-NEXT:    [[P0:%[^ ]+]] = (f8e4m3fn[2,2048,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0)
-; CHECK-NEXT:    [[GTE0:%[^ ]+]] = f8e4m3fn[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=0
-; CHECK-NEXT:    [[GTE1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[P0]]), index=1
-; CHECK-NEXT:    [[GTE2:%[^ ]+]] = f32[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=2
-; CHECK-NEXT:    [[CP0:%[^ ]+]] = f32[2,512,24576]{2,1,0} collective-permute([[GTE2]]), channel_id=9
-; CHECK-NEXT:    [[CONVERT0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} convert([[GTE0]])
-; CHECK-NEXT:    [[GTE3:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=5
-; CHECK-NEXT:    [[BCAST0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} broadcast([[GTE3]]), dimensions={}
-; CHECK-NEXT:    [[MUL0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} multiply([[CONVERT0]], [[BCAST0]])
-; CHECK-NEXT:    [[C0:%[^ ]+]] = s32[] constant(0)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = s32[4]{0} constant({0, 512, 1024, 1536})
-; CHECK-NEXT:    [[GTE4:%[^ ]+]] = u32[] get-tuple-element([[P0]]), index=4
-; CHECK-NEXT:    [[C2:%[^ ]+]] = u32[] constant(2)
-; CHECK-NEXT:    [[ADD0:%[^ ]+]] = u32[] add([[GTE4]], [[C2]])
-; CHECK-NEXT:    [[C3:%[^ ]+]] = u32[] constant(1)
-; CHECK-NEXT:    [[ADD1:%[^ ]+]] = u32[] add([[ADD0]], [[C3]])
-; CHECK-NEXT:    [[PID:%[^ ]+]] = u32[] partition-id()
-; CHECK-NEXT:    [[ADD2:%[^ ]+]] = u32[] add([[ADD1]], [[PID]])
-; CHECK-NEXT:    [[C4:%[^ ]+]] = u32[] constant(4)
-; CHECK-NEXT:    [[REM0:%[^ ]+]] = u32[] remainder([[ADD2]], [[C4]])
-; CHECK-NEXT:    [[DSLICE0:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM0]]), dynamic_slice_sizes={1}
-; CHECK-NEXT:    [[RESHAPE0:%[^ ]+]] = s32[] reshape([[DSLICE0]])
-; CHECK-NEXT:    [[DSLICE1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dynamic-slice([[MUL0]], [[C0]], [[RESHAPE0]], [[C0]]), dynamic_slice_sizes={2,512,24576}
-; CHECK-NEXT:    [[CONVERT1:%[^ ]+]] = f32[24576,24576]{1,0} convert([[GTE1]])
-; CHECK-NEXT:    [[GTE5:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=6
-; CHECK-NEXT:    [[BCAST1:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[GTE5]]), dimensions={}
-; CHECK-NEXT:    [[MUL1:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[CONVERT1]], [[BCAST1]])
-; CHECK-NEXT:    [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[DSLICE1]], [[MUL1]]),
-; CHECK-DAG:       lhs_contracting_dims={2},
-; CHECK-DAG:       rhs_contracting_dims={0},
-; CHECK-DAG:       backend_config={
-; CHECK-DAG:         "operation_queue_id":"[[OPQUEUEID:[0-9]+]]",
-; CHECK-DAG:         "wait_on_operation_queues":[],
-; CHECK-DAG:         "force_earliest_schedule":false}
-; CHECK-NEXT:    [[ADD3:%[^ ]+]] = f32[2,512,24576]{2,1,0} add([[CP0]], [[DOT0]]),
-; CHECK-DAG:       backend_config={"
-; CHECK-DAG:         operation_queue_id":"0",
-; CHECK-DAG:         "wait_on_operation_queues":["[[OPQUEUEID]]"],
-; CHECK-DAG:         "force_earliest_schedule":false}
-; CHECK-NEXT:    [[GTE6:[^ ]+]] = f32[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=3
-; CHECK-NEXT:    [[ADD4:%[^ ]+]] = u32[] add([[GTE4]], [[PID]])
-; CHECK-NEXT:    [[REM1:%[^ ]+]] = u32[] remainder([[ADD4]], [[C4]])
-; CHECK-NEXT:    [[DSLICE2:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM1]]), dynamic_slice_sizes={1}
-; CHECK-NEXT:    [[RESHAPE1:%[^ ]+]] = s32[] reshape([[DSLICE2]])
-; CHECK-NEXT:    [[DSLICE3:%[^ ]+]] = f32[2,512,24576]{2,1,0} dynamic-slice([[MUL0]], [[C0]], [[RESHAPE1]], [[C0]]), dynamic_slice_sizes={2,512,24576}
-; CHECK-NEXT:    [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[DSLICE3]], [[MUL1]]),
-; CHECK-DAG:       lhs_contracting_dims={2},
-; CHECK-DAG:       rhs_contracting_dims={0}
-; CHECK-NEXT:    [[ADD5:%[^ ]+]] = f32[2,512,24576]{2,1,0} add([[GTE6]], [[DOT1]])
-; CHECK-NEXT:    [[CP1:[^ ]+]] = f32[2,512,24576]{2,1,0} collective-permute([[ADD5]]), channel_id=10
-; CHECK-NEXT:  ROOT [[OUT:[^ ]+]] = (f8e4m3fn[2,2048,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[GTE0]], [[GTE1]], [[ADD3]], [[CP1]], [[ADD0]], /*index=5*/[[GTE3]], [[GTE5]])
-)");
-}
-
-TEST_F(GpuWindowedEinsumHanlderTest,
-       AgLoopsMultipleConsumersAreChainedWithShardedContratingDim) {
-  constexpr absl::string_view kHloString = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8
-
-windowed_dot_general_body_ag {
-  param.195 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0)
-  get-tuple-element.588 = bf16[16,2048,512]{2,1,0} get-tuple-element(param.195), index=0
-  collective-permute.194 = bf16[16,2048,512]{2,1,0} collective-permute(get-tuple-element.588), channel_id=446, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}
-  collective-permute.195 = bf16[16,2048,512]{2,1,0} collective-permute(collective-permute.194), channel_id=447, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}
-  get-tuple-element.589 = bf16[4096,6288]{1,0} get-tuple-element(param.195), index=1
-  get-tuple-element.590 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=2
-  constant.11432 = s32[8]{0} constant({0, 512, 1024, 1536, 2048, 2560, 3072, 3584})
-  get-tuple-element.592 = u32[] get-tuple-element(param.195), index=4
-  partition-id.194 = u32[] partition-id()
-  add.4309 = u32[] add(get-tuple-element.592, partition-id.194)
-  constant.11431 = u32[] constant(8)
-  remainder.194 = u32[] remainder(add.4309, constant.11431)
-  dynamic-slice.388 = s32[1]{0} dynamic-slice(constant.11432, remainder.194), dynamic_slice_sizes={1}
-  reshape.12959 = s32[] reshape(dynamic-slice.388)
-  constant.11433 = s32[] constant(0)
-  dynamic-slice.389 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12959, constant.11433), dynamic_slice_sizes={512,6288}
-  dot.244 = bf16[16,2048,6288]{2,1,0} dot(get-tuple-element.588, dynamic-slice.389), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  add.4310 = bf16[16,2048,6288]{2,1,0} add(get-tuple-element.590, dot.244)
-  constant.11434 = u32[] constant(1)
-  add.4312 = u32[] add(get-tuple-element.592, constant.11434)
-  add.4313 = u32[] add(add.4312, partition-id.194)
-  remainder.195 = u32[] remainder(add.4313, constant.11431)
-  dynamic-slice.390 = s32[1]{0} dynamic-slice(constant.11432, remainder.195), dynamic_slice_sizes={1}
-  reshape.12960 = s32[] reshape(dynamic-slice.390)
-  dynamic-slice.391 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12960, constant.11433), dynamic_slice_sizes={512,6288}
-  dot.245 = bf16[16,2048,6288]{2,1,0} dot(collective-permute.194, dynamic-slice.391), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  add.4314 = bf16[16,2048,6288]{2,1,0} add(add.4310, dot.245)
-  get-tuple-element.591 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=3
-  add.4315 = u32[] add(add.4312, constant.11434)
-  ROOT tuple.98 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(collective-permute.195, get-tuple-element.589, add.4314, get-tuple-element.591, add.4315)
-} // windowed_dot_general_body_ag
-
-windowed_dot_general_cond_ag {
-  param = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0)
-  get-tuple-element = u32[] get-tuple-element(param), index=4
-  constant = u32[] constant(4)
-  ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT
-}
-
-ENTRY main.12_spmd {
-  param.4 = bf16[16,2048,512]{2,1,0} parameter(0)
-  param.5 = bf16[4096,6288]{1,0} parameter(1)
-  constant.22 = bf16[] constant(0)
-  broadcast = bf16[16,2048,6288]{2,1,0} broadcast(constant.22), dimensions={}
-  constant.24 = u32[] constant(0)
-  tuple.2 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24)
-  while = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
-  get-tuple-element.13 = bf16[16,2048,6288]{2,1,0} get-tuple-element(while), index=2
-  all-gather = bf16[16,2048,4096]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, use_global_device_ids=true
-  param.6 = bf16[16,2048,6288]{2,1,0} parameter(2)
-  ROOT dot.7 = bf16[4096,6288]{1,0} dot(all-gather, param.6), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  GpuWindowedEinsumHandler gpu_handler;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  HloInstruction* ag_loop =
-      FindInstructionByName(module->entry_computation(), "while");
-  HloInstruction* inst =
-      FindInstructionByName(module->entry_computation(), "dot.7");
-  // dot.7 should now consume output of the windowed einsum while loop.
-  EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement);
-  EXPECT_EQ(inst->operand(0)->tuple_index(), 5);
-  EXPECT_EQ(inst->operand(0)->operand(0), ag_loop);
-}
-}  // namespace
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc b/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc
index 8b1efb2..dcca8f3 100644
--- a/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc
+++ b/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc
@@ -25,8 +25,8 @@
 #include "absl/strings/str_cat.h"
 #include "xla/debug_options_flags.h"
 #include "xla/hlo/ir/backend_config.h"
+#include "xla/service/gpu/autotuning/gpu_autotuning.pb.h"
 #include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/gpu_autotuning.pb.h"
 #include "xla/stream_executor/dnn.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/protobuf.h"
diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
index 345fd8c..e527a09 100644
--- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
+++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
@@ -171,15 +171,40 @@
 
 // static
 HloFusionAnalysis HloFusionAnalysis::Create(
-    const HloFusionInstruction* fusion,
-    const se::DeviceDescription* device_info) {
-  CHECK(device_info != nullptr);
-  FusionBackendConfig backend_config =
-      fusion->has_backend_config()
-          ? fusion->backend_config<GpuBackendConfig>()->fusion_backend_config()
-          : FusionBackendConfig::default_instance();
-  return Create(std::move(backend_config),
-                HloFusionAdaptor::ForInstruction(fusion), device_info);
+    const HloInstruction& instruction,
+    const se::DeviceDescription& device_info) {
+  absl::StatusOr<GpuBackendConfig> gpu_backend_config =
+      instruction.backend_config<GpuBackendConfig>();
+
+  FusionBackendConfig fusion_backend_config =
+      gpu_backend_config.ok() ? gpu_backend_config->fusion_backend_config()
+                              : FusionBackendConfig::default_instance();
+  return Create(std::move(fusion_backend_config),
+                HloFusionAdaptor::ForInstruction(&instruction), &device_info);
+}
+
+// static
+HloFusionAnalysis HloFusionAnalysis::Create(
+    const HloInstruction& producer, const HloInstruction& consumer,
+    const se::DeviceDescription& device_info) {
+  absl::StatusOr<GpuBackendConfig> gpu_backend_config;
+
+  if (consumer.has_backend_config()) {
+    gpu_backend_config = consumer.backend_config<GpuBackendConfig>();
+  }
+
+  if (!gpu_backend_config.ok() && producer.has_backend_config()) {
+    gpu_backend_config = producer.backend_config<GpuBackendConfig>();
+  }
+
+  FusionBackendConfig fusion_backend_config =
+      gpu_backend_config.ok() ? gpu_backend_config->fusion_backend_config()
+                              : FusionBackendConfig::default_instance();
+
+  return HloFusionAnalysis::Create(
+      std::move(fusion_backend_config),
+      HloFusionAdaptor::ForProducerConsumer(&producer, &consumer),
+      &device_info);
 }
 
 // Returns true if the fusion has consistent transpose heroes.
@@ -264,7 +289,7 @@
   }
 
   // We expect that the last dimension is swapped with a different dimension.
-  if (HasConsistentTransposeHeros() && tiled_transpose_->permutation[2] != 2) {
+  if (HasConsistentTransposeHeros()) {
     return EmitterFusionKind::kTranspose;
   }
 
@@ -305,24 +330,5 @@
   LOG(FATAL) << "Did not find a hero reduction";
 }
 
-HloFusionAnalysis AnalyzeProducerConsumerFusion(
-    const HloInstruction& producer, const HloInstruction& consumer,
-    const se::DeviceDescription& device_info) {
-  return HloFusionAnalysis::Create(
-      consumer.has_backend_config()
-          ? consumer.backend_config<GpuBackendConfig>()->fusion_backend_config()
-          : producer.backend_config<GpuBackendConfig>()
-                ->fusion_backend_config(),
-      HloFusionAdaptor::ForProducerConsumer(&producer, &consumer),
-      &device_info);
-}
-
-HloFusionAnalysis AnalyzeFusion(const HloInstruction& consumer,
-                                const se::DeviceDescription& device_info) {
-  return HloFusionAnalysis::Create(
-      consumer.backend_config<GpuBackendConfig>()->fusion_backend_config(),
-      HloFusionAdaptor::ForInstruction(&consumer), &device_info);
-}
-
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h
index 146224b..c1b7e5b 100644
--- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h
+++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h
@@ -58,8 +58,17 @@
   static HloFusionAnalysis Create(FusionBackendConfig backend_config,
                                   std::unique_ptr<HloFusionAdaptor> fusion,
                                   const se::DeviceDescription* device_info);
-  static HloFusionAnalysis Create(const HloFusionInstruction* fusion,
-                                  const se::DeviceDescription* device_info);
+
+  // Creates a HloFusionAnalysis that analyzes just the given instruction as a
+  // standalone fusion.
+  static HloFusionAnalysis Create(const HloInstruction& instruction,
+                                  const se::DeviceDescription& device_info);
+
+  // Creates a HloFusionAnalysis that analyzes a hypothetical fusion of producer
+  // into consumer.
+  static HloFusionAnalysis Create(const HloInstruction& producer,
+                                  const HloInstruction& consumer,
+                                  const se::DeviceDescription& device_info);
 
   const HloFusionAdaptor& fusion() const { return *fusion_; }
 
@@ -131,17 +140,6 @@
   InputOutputInfo input_output_info_;
 };
 
-// Creates a HloFusionAnalysis that analyzes a hypothetical fusion of producer
-// into consumer.
-HloFusionAnalysis AnalyzeProducerConsumerFusion(
-    const HloInstruction& producer, const HloInstruction& consumer,
-    const se::DeviceDescription& device_info);
-
-// Creates a HloFusionAnalysis that analyzes just consumer as a standalone
-// fusion.
-HloFusionAnalysis AnalyzeFusion(const HloInstruction& consumer,
-                                const se::DeviceDescription& device_info);
-
 }  // namespace gpu
 }  // namespace xla
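
For context, a minimal call-site sketch of the two new Create overloads declared above; it is illustrative only and not part of the patch. FusingChangesEmitterKind, module, and device_info are placeholder names, and the sketch assumes the entry root has at least one operand. The updated tests below exercise the same calls with TestGpuDeviceInfo::RTXA6000DeviceInfo().

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {

// Returns true if fusing the root's first operand into the root would change
// the emitter kind chosen for it (e.g. kLoop -> kReduction).
bool FusingChangesEmitterKind(const HloModule& module,
                              const se::DeviceDescription& device_info) {
  const HloInstruction* root = module.entry_computation()->root_instruction();

  // Standalone analysis of a single instruction; replaces the removed
  // free function AnalyzeFusion(*root, device_info).
  HloFusionAnalysis consumer_only =
      HloFusionAnalysis::Create(*root, device_info);

  // Hypothetical fusion of the producer into the consumer; replaces the
  // removed AnalyzeProducerConsumerFusion(*root->operand(0), *root, ...).
  HloFusionAnalysis fused =
      HloFusionAnalysis::Create(*root->operand(0), *root, device_info);

  return fused.GetEmitterFusionKind() != consumer_only.GetEmitterFusionKind();
}

}  // namespace xla::gpu

Per the hlo_fusion_analysis.cc change above, the producer/consumer overload prefers the consumer's fusion_backend_config, falls back to the producer's, and otherwise uses the default instance, so no backend config needs to be passed explicitly at the call site.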
 
diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc
index 04c5819..7328bc6 100644
--- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc
+++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc
@@ -15,9 +15,11 @@
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 
 #include <gtest/gtest.h>
+#include "xla/protobuf_util.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
 #include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/device_description.pb.h"
 #include "xla/tests/hlo_test_base.h"
@@ -48,12 +50,12 @@
   auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info);
   EXPECT_EQ(analysis.GetEmitterFusionKind(),
             HloFusionAnalysis::EmitterFusionKind::kLoop);
 
   auto analysis_fused =
-      AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
+      HloFusionAnalysis::Create(*root->operand(0), *root, device_info);
   EXPECT_EQ(analysis_fused.GetEmitterFusionKind(),
             HloFusionAnalysis::EmitterFusionKind::kReduction);
 }
@@ -155,7 +157,7 @@
   auto* root = module->entry_computation()->root_instruction();
 
   auto analysis =
-      AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
+      HloFusionAnalysis::Create(*root->operand(0), *root, device_info);
   EXPECT_EQ(analysis.GetEmitterFusionKind(),
             HloFusionAnalysis::EmitterFusionKind::kReduction);
 }
@@ -186,7 +188,7 @@
 
   auto* root = module->entry_computation()->root_instruction();
   auto analysis =
-      AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
+      HloFusionAnalysis::Create(*root->operand(0), *root, device_info);
   EXPECT_EQ(analysis.GetEmitterFusionKind(),
             HloFusionAnalysis::EmitterFusionKind::kReduction);
 }
@@ -223,7 +225,7 @@
 
   auto* root = module->entry_computation()->root_instruction();
   auto analysis =
-      AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
+      HloFusionAnalysis::Create(*root->operand(0), *root, device_info);
   EXPECT_EQ(analysis.GetEmitterFusionKind(),
             HloFusionAnalysis::EmitterFusionKind::kReduction);
 }
@@ -255,7 +257,7 @@
   auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info);
   EXPECT_EQ(analysis.GetEmitterFusionKind(),
             HloFusionAnalysis::EmitterFusionKind::kReduction);
 }
@@ -287,7 +289,7 @@
   auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
 
   auto* root = module->entry_computation()->root_instruction();
-  auto analysis = AnalyzeFusion(*root, device_info);
+  auto analysis = HloFusionAnalysis::Create(*root, device_info);
   // We expect to fallback to the loop emitter, because the two reductions are
   // not compatible as they reduce over different dimensions.
   EXPECT_EQ(analysis.GetEmitterFusionKind(),
@@ -319,7 +321,7 @@
 
   auto* root = module->entry_computation()->root_instruction();
   auto analysis_fused =
-      AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
+      HloFusionAnalysis::Create(*root->operand(0), *root, device_info);
   EXPECT_EQ(analysis_fused.GetEmitterFusionKind(),
             HloFusionAnalysis::EmitterFusionKind::kReduction);
 }
@@ -352,5 +354,90 @@
             HloFusionAnalysis::EmitterFusionKind::kConcatenate);
 }
 
+TEST_F(HloFusionAnalysisTest, ExtractValidGpuBackendConfig) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fused_computation.1 {
+      %x = s32[64] parameter(0)
+      %y = s32[64] parameter(1)
+      ROOT %root = s32[64] add(%x, %y)
+    }
+
+    fused_computation.2 {
+      %x = s32[64] parameter(0)
+      %y = s32[64] parameter(1)
+      ROOT %root = s32[64] add(%x, %y)
+    }
+
+    ENTRY entry {
+      %x = s32[64] parameter(0)
+      %y = s32[64] parameter(1)
+      %fusion.1 = s32[64] fusion(%x, %y), kind=kLoop, calls=fused_computation.1, backend_config={"fusion_backend_config": {kind: "__triton"}}
+      ROOT %fusion.2 = s32[64] fusion(%fusion.1, %y), kind=kLoop, calls=fused_computation.2
+    })"));
+
+  auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  auto* consumer = module->entry_computation()->root_instruction();
+  auto* producer = consumer->operand(0);
+
+  auto producer_analysis = HloFusionAnalysis::Create(*producer, device_info);
+  EXPECT_EQ(producer_analysis.fusion_backend_config().kind(),
+            kTritonFusionKind);
+
+  auto producer_consumer_analysis =
+      HloFusionAnalysis::Create(*producer, *consumer, device_info);
+  EXPECT_EQ(producer_consumer_analysis.fusion_backend_config().kind(),
+            kTritonFusionKind);
+}
+
+TEST_F(HloFusionAnalysisTest,
+       InvalidGpuBackendConfig_SingleInstruction_Ignored) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    ENTRY entry {
+      %x = s32[64,64,64] parameter(0)
+      %y = s32[64,64,64] parameter(1)
+      ROOT %root = s32[64,128,64] concatenate(x, y), dimensions={1}, backend_config={"outer_dimension_partitions": ["1"]}
+    })"));
+
+  auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  auto* root = module->entry_computation()->root_instruction();
+  auto analysis = HloFusionAnalysis::Create(*root, device_info);
+
+  EXPECT_TRUE(
+      protobuf_util::ProtobufEquals(analysis.fusion_backend_config(),
+                                    FusionBackendConfig::default_instance()));
+}
+
+TEST_F(HloFusionAnalysisTest,
+       InvalidGpuBackendConfig_ProducerConsumer_Ignored) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fused_computation {
+      %x = s32[64] parameter(0)
+      %y = s32[64] parameter(1)
+      ROOT %root = s32[64] add(%x, %y)
+    }
+
+    ENTRY entry {
+      %x = s32[64] parameter(0)
+      %y = s32[64] parameter(1)
+      %fusion = s32[64] fusion(%x, %y), kind=kLoop, calls=fused_computation, backend_config={"invalid_field": "some_value"}
+      ROOT %root = s32[128] concatenate(fusion, y), dimensions={0}, backend_config={"invalid_field": "some_value"}
+    })"));
+
+  auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  auto* consumer = module->entry_computation()->root_instruction();
+  auto* producer = consumer->operand(0);
+  auto analysis = HloFusionAnalysis::Create(*producer, *consumer, device_info);
+
+  EXPECT_TRUE(
+      protobuf_util::ProtobufEquals(analysis.fusion_backend_config(),
+                                    FusionBackendConfig::default_instance()));
+}
+
 }  // namespace
 }  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/horizontal_input_fusion.cc b/third_party/xla/xla/service/gpu/horizontal_input_fusion.cc
deleted file mode 100644
index c693856..0000000
--- a/third_party/xla/xla/service/gpu/horizontal_input_fusion.cc
+++ /dev/null
@@ -1,192 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/horizontal_input_fusion.h"
-
-#include <algorithm>
-#include <cstddef>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-// Gets the representative input shape of the multi-output fusion.
-Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
-  // Get the HLO that determines the emitter used for lowering.
-  const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
-  if (real_hero->operands().empty()) {
-    // Simply return an empty shape if the representative node has no input
-    // operands.
-    return Shape();
-  } else {
-    return real_hero->operand(0)->shape();
-  }
-}
-
-class HorizontalInputFusionImpl {
- public:
-  explicit HorizontalInputFusionImpl(HloComputation* computation,
-                                     const se::DeviceDescription& d)
-      : computation_(computation), device_info_(d) {}
-
-  ~HorizontalInputFusionImpl() = default;
-
-  absl::StatusOr<bool> Run();
-
- private:
-  HloComputation* computation_;
-  const se::DeviceDescription& device_info_;
-};  // HorizontalInputFusionImpl
-
-// Compares one-by-one the dimensions of `shape_a` and `shape_b` from left to
-// right.
-bool CompareShapeDimsFromLeftToRight(const Shape& shape_a,
-                                     const Shape& shape_b) {
-  if (shape_a.rank() != shape_b.rank()) {
-    return shape_a.rank() < shape_b.rank();
-  }
-  auto dims_a = shape_a.dimensions();
-  auto dims_b = shape_b.dimensions();
-  for (size_t i = 0; i < dims_a.size(); ++i) {
-    if (dims_a[i] != dims_b[i]) {
-      return dims_a[i] < dims_b[i];
-    }
-  }
-  return true;
-}
-
-std::vector<HloInstruction*> FindAndSortFusionCandidates(
-    HloInstruction* consumer) {
-  absl::flat_hash_set<HloInstruction*> fusion_instr_set;
-  std::vector<HloInstruction*> fusion_instrs;
-  for (HloInstruction* opnd : consumer->operands()) {
-    HloInstruction* predecessor = opnd->LatestNonGteAncestor();
-    // Find out the input fusion instructions whose only consumer is `consumer`.
-    // This guarantees that fusing these candidates will never create cycles, as
-    // there is no back edge.
-    if (IsInputFusibleReduction(*predecessor) &&
-        IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) {
-      if (fusion_instr_set.insert(predecessor).second) {
-        fusion_instrs.push_back(predecessor);
-      }
-    }
-  }
-
-  std::sort(fusion_instrs.begin(), fusion_instrs.end(),
-            [&](const HloInstruction* a, const HloInstruction* b) {
-              Shape shape_a = GetInputShapeForMultiOutputFusion(*a);
-              Shape shape_b = GetInputShapeForMultiOutputFusion(*b);
-              if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) {
-                // Sort shapes according to dimensions, so that the same input
-                // shapes will be placed adjacent each other.
-                return CompareShapeDimsFromLeftToRight(shape_a, shape_b);
-              }
-              // Sort `fusion_instrs` according to instruction counts, because
-              // we'd like to fuse together computations of similar sizes.
-              return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
-            });
-
-  return fusion_instrs;
-}
-
-absl::StatusOr<bool> HorizontalInputFusionImpl::Run() {
-  bool changed = false;
-  XLA_VLOG_LINES(3, computation_->ToString());
-
-  // Using def-to-use order is sound since we do not modify users.
-  std::vector<HloInstruction*> def_to_use_order =
-      computation_->MakeInstructionPostOrder();
-  for (HloInstruction* consumer : def_to_use_order) {
-    auto candidates = FindAndSortFusionCandidates(consumer);
-    if (candidates.size() <= 1) {
-      continue;
-    }
-
-    // Convert candidates into fusions if needed.
-    for (size_t j = 0; j < candidates.size(); ++j) {
-      if (candidates[j]->opcode() != HloOpcode::kFusion) {
-        TF_ASSIGN_OR_RETURN(
-            HloInstruction * fusion_instr,
-            MakeFusionInstruction(candidates[j],
-                                  HloInstruction::FusionKind::kInput));
-        candidates[j] = fusion_instr;
-        changed = true;
-      }
-    }
-
-    size_t fusion_anchor_id = 0;
-    for (size_t j = 1; j < candidates.size(); ++j) {
-      HloInstruction* fusion_anchor = candidates[fusion_anchor_id];
-      HloInstruction* fused = candidates[j];
-      if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
-          FusionFitsInBudget(*fusion_anchor, *fused, device_info_)) {
-        VLOG(3) << "Fuse " << fused->ToString() << " into "
-                << fusion_anchor->ToString();
-        fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
-        changed = true;
-      } else {
-        // Update the `fusion_anchor_id` since `fused` is either not
-        // compatible or not beneficial to be fused with current fusion anchor.
-        VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
-        fusion_anchor_id = j;
-      }
-    }
-  }
-
-  return changed;
-}
-
-}  // namespace
-
-absl::StatusOr<bool> GpuHorizontalInputFusion::RunOnComputation(
-    HloComputation* computation) {
-  HorizontalInputFusionImpl horizontal_fusion_impl(computation, device_info_);
-  return horizontal_fusion_impl.Run();
-}
-
-absl::StatusOr<bool> GpuHorizontalInputFusion::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  VLOG(2) << "Run horizontal input fusion.";
-  for (HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    TF_ASSIGN_OR_RETURN(changed, RunOnComputation(comp));
-  }
-
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/horizontal_input_fusion.h b/third_party/xla/xla/service/gpu/horizontal_input_fusion.h
deleted file mode 100644
index 370ce7b..0000000
--- a/third_party/xla/xla/service/gpu/horizontal_input_fusion.h
+++ /dev/null
@@ -1,65 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_
-#define XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// This optimization pass horizontally fuses kInput fusions to both reduce the
-// kernel launch overhead and increase parallelism degree. See
-// GpuHorizontalFusion for general description and motivation about horizontal
-// fusion. GpuHorizontalFusion deals with kLoop fusions while this pass deals
-// with kInput fusions.
-//
-// Following GpuHorizontalFusion, a simple yet effective heuristic is used
-// to search the fusion candidates while avoiding creating cycles. That is,
-// we simply search for fusion candidates by looking for instructions whose
-// outputs are all consumed by the same instruction. This catches the typical
-// target cases; often, the candidate instructions are just consumed by the
-// ROOT tuple of the entry computation.
-class GpuHorizontalInputFusion : public HloModulePass {
- public:
-  explicit GpuHorizontalInputFusion(const se::DeviceDescription& d)
-      : device_info_(d) {}
-
-  absl::string_view name() const override {
-    return "gpu_horizontal_input_fusion";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  absl::StatusOr<bool> RunOnComputation(HloComputation*);
-
-  const se::DeviceDescription& device_info_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/horizontal_input_fusion_test.cc b/third_party/xla/xla/service/gpu/horizontal_input_fusion_test.cc
deleted file mode 100644
index 2d458f9..0000000
--- a/third_party/xla/xla/service/gpu/horizontal_input_fusion_test.cc
+++ /dev/null
@@ -1,270 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/horizontal_input_fusion.h"
-
-#include <cstdint>
-#include <utility>
-#include <vector>
-
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class HorizontalInputFusionTest : public GpuCodegenTest {
- public:
-  se::DeviceDescription device_description_{
-      TestGpuDeviceInfo::RTXA6000DeviceInfo()};
-  GpuHorizontalInputFusion horizontal_input_fusion_{device_description_};
-};
-
-TEST_F(HorizontalInputFusionTest, BasicTest) {
-  auto module = ParseAndReturnVerifiedModule(R"(
- HloModule BasicTest
-
-  %add_f16 {
-    %x = f16[] parameter(0)
-    %y = f16[] parameter(1)
-    ROOT %add = f16[] add(%x, %y)
-  }
-
- fused_computation.1 {
-   arg.1 = f16[1024]{0} parameter(0)
-   constant0 = f16[] constant(0)
-   ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
- }
-
- fused_computation.2 {
-   arg.1 = f16[1024]{0} parameter(0)
-   constant0 = f16[] constant(0)
-   ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
- }
-
- ENTRY entry_computation {
-   arg.1 = f16[1024]{0} parameter(0)
-   arg.2 = f16[1024]{0} parameter(1)
-   fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1
-   fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2
-   ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2)
- }
-)")
-                    .value();
-
-  EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
-
-  const HloInstruction* entry_root =
-      module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(entry_root,
-              GmockMatch(m::Tuple((m::GetTupleElement(m::Fusion(&fusion))),
-                                  (m::GetTupleElement(m::Fusion())))));
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
-}
-
-TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
-  auto module = CreateNewVerifiedModule();
-
-  HloComputation* reduce_computation;
-  {
-    auto embedded_builder = HloComputation::Builder("add");
-    auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
-        0, ShapeUtil::MakeShape(F32, {}), "lhs"));
-    auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
-        1, ShapeUtil::MakeShape(F32, {}), "rhs"));
-    embedded_builder.AddInstruction(
-        HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
-    reduce_computation =
-        module->AddEmbeddedComputation(embedded_builder.Build());
-  }
-
-  HloComputation::Builder builder(TestName());
-  std::vector<HloInstruction*> var_outs;
-  auto input_shape = ShapeUtil::MakeShape(F32, {1024, 1024});
-  auto output_shape = ShapeUtil::MakeShape(F32, {1024});
-  for (int64_t i = 0; i < 130; ++i) {
-    // %fused_computation.3 (param_0: f32[1024,1024], param_1: f32[]) ->
-    // f32[1024] {
-    //  %param_0 = f32[1024,1024]{1,0} parameter(0)
-    //  %param_1 = f32[] parameter(1)
-    //  %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %param_1),
-    //  dimensions={}
-    //  %multiply = f32[1024,1024]{1,0}
-    //      multiply(f32[1024,1024]{1,0} %param_0, f32[1024,1024]{1,0}
-    //      %broadcast)
-    //  %constant0 = f32[] constant(0)
-    //  ROOT %reduce = f32[1024]{0}
-    //      reduce(f32[1024,1024]{1,0} %multiply, f32[] %constant0),
-    //          dimensions={1}, to_apply=%add
-    // }
-    HloInstruction* param_var_in = builder.AddInstruction(
-        HloInstruction::CreateParameter(i * 2 + 0, input_shape, "var.in"));
-    HloInstruction* param_alpha =
-        builder.AddInstruction(HloInstruction::CreateParameter(
-            i * 2 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
-    auto alpha_broadcasted = builder.AddInstruction(
-        HloInstruction::CreateBroadcast(input_shape, param_alpha, {}));
-    auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
-        input_shape, HloOpcode::kMultiply, param_var_in, alpha_broadcasted));
-    HloInstruction* const0 = builder.AddInstruction(
-        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
-    auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
-        output_shape, mul, const0, {1}, reduce_computation));
-    var_outs.push_back(reduce);
-  }
-  builder.AddInstruction(HloInstruction::CreateTuple(var_outs));
-  module->AddEntryComputation(builder.Build());
-
-  // Verify that horizontal fusion is kicked in. Check that there are multiple
-  // `reduce` instructions fused into the same fusion.
-  if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() < 4) {
-    // 6 is just an arbitrarily picked number, as we don't know exactly how
-    // large the created fusion will be due to the `FusionFitsInBudget` constraint.
-    CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)",
-                       /*match_optimized_ir=*/false);
-  } else {
-    // Verify that we produced a multi-output reduction with independent groups.
-    CompileAndVerifyIr(module->Clone(), R"(CHECK: switch {{.*}} label {{.*}} [
-                                           CHECK-NEXT: label)",
-                       /*match_optimized_ir=*/false);
-  }
-
-  // Testing with the entire gpu optimization pipeline.
-  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5}));
-}
-
-TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) {
-  // This tests the below pattern. One known issue is that gtes (to fusions) can
-  // be removed after their producer fusions are merged. In the below case, gte2
-  // and gte6 will be gone if Fusion2 is fused into Fusion1.
-  //
-  // Fusion1   Fusion2
-  //  |   |    |     |
-  //  |  gte1 gte2   |
-  //  |   |    |     |
-  //  |   Fusion3    |
-  //  |    |   |     |
-  // gte3 gte4 gte5 gte6
-  //  \  |     |    /
-  //  =====ROOT=====
-  //
-  auto module = ParseAndReturnVerifiedModule(R"(
- HloModule MultiOutputFusionTest
-
-  %add_f16 {
-    %x = f16[] parameter(0)
-    %y = f16[] parameter(1)
-    ROOT %add = f16[] add(%x, %y)
-  }
-
- fused_computation.1 {
-   arg.1 = f16[1024]{0} parameter(0)
-   constant0 = f16[] constant(0)
-   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
-   add.0 = f16[1024] add(arg.1, arg.1)
-   ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
- }
-
- fused_computation.2 {
-   arg.1 = f16[1024]{0} parameter(0)
-   constant0 = f16[] constant(0)
-   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
-   add.0 = f16[1024] add(arg.1, arg.1)
-   ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
- }
-
- fused_computation.3 {
-   arg.0 = f16[1024]{0} parameter(0)
-   arg.1 = f16[1024]{0} parameter(1)
-   add.0 = f16[1024] add(arg.0, arg.1)
-   mul.0 = f16[1024] multiply(arg.0, arg.1)
-   ROOT tuple.1 = (f16[1024], f16[1024]) tuple(add.0, mul.0)
- }
-
- ENTRY entry_computation {
-   arg.1 = f16[1024]{0} parameter(0)
-   arg.2 = f16[1024]{0} parameter(1)
-   fusion.1 = (f16[],f16[1024]) fusion(arg.1), kind=kInput, calls=fused_computation.1
-   fusion.2 = (f16[],f16[1024]) fusion(arg.2), kind=kInput, calls=fused_computation.2
-   gte.3 = f16[] get-tuple-element(fusion.1), index=0
-   gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=1
-   gte.2 = f16[1024]{0} get-tuple-element(fusion.2), index=1
-   gte.6 = f16[] get-tuple-element(fusion.2), index=0
-   fusion.3 = (f16[1024],f16[1024]) fusion(gte.1, gte.2),
-       kind=kLoop, calls=fused_computation.3
-   gte.4 = f16[1024] get-tuple-element(fusion.3), index=0
-   gte.5 = f16[1024]{0} get-tuple-element(fusion.3), index=1
-   ROOT tuple.1 = (f16[], f16[1024], f16[1024]{0}, f16[])
-       tuple(gte.3, gte.4, gte.5, gte.6)
- }
-)")
-                    .value();
-
-  EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
-}
-
-TEST_F(HorizontalInputFusionTest, NonfusionInstrs) {
-  auto module = ParseAndReturnVerifiedModule(R"(
- HloModule NonfusionInstrs
-
- %add_f16 {
-   %x = f16[] parameter(0)
-   %y = f16[] parameter(1)
-   ROOT %add = f16[] add(%x, %y)
- }
-
- ENTRY entry_computation {
-   arg.0 = f16[1024]{0} parameter(0)
-   arg.1 = f16[1024]{0} parameter(1)
-   constant0 = f16[] constant(0)
-   reduce.0 = f16[] reduce(arg.0, constant0), dimensions={0}, to_apply=%add_f16
-   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
-   ROOT tuple.0 = (f16[], f16[]) tuple(reduce.0, reduce.1)
- }
-)")
-                    .value();
-
-  EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
-
-  const HloInstruction* entry_root =
-      module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(entry_root,
-              GmockMatch(m::Tuple((m::GetTupleElement(m::Fusion(&fusion))),
-                                  (m::GetTupleElement(m::Fusion())))));
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
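
Outside the GpuCodegenTest harness used in the deleted test above, the pass can also be exercised with a small stand-alone driver. The sketch below is illustrative only: it assumes the HLO parser entry point, the RTXA6000 test device info, and the pass class remain available after the move, and it elides error handling beyond status propagation.

  // Sketch: parse an HLO module and run horizontal input fusion directly.
  // (The include for the pass header is omitted; its new location is not
  // shown in this hunk.)
  #include <memory>
  #include "absl/status/statusor.h"
  #include "absl/strings/string_view.h"
  #include "xla/service/gpu/gpu_device_info_for_tests.h"
  #include "xla/service/hlo_parser.h"
  #include "tsl/platform/statusor.h"

  absl::StatusOr<bool> FuseSiblingInputFusions(absl::string_view hlo_text) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloModule> module,
                        xla::ParseAndReturnUnverifiedModule(hlo_text));
    const stream_executor::DeviceDescription device_info =
        xla::gpu::TestGpuDeviceInfo::RTXA6000DeviceInfo();
    // Returns true iff at least one horizontal kInput fusion was created.
    return xla::gpu::GpuHorizontalInputFusion(device_info).Run(module.get());
  }
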
diff --git a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/horizontal_loop_fusion.cc
deleted file mode 100644
index 80c46cb..0000000
--- a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.cc
+++ /dev/null
@@ -1,744 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/horizontal_loop_fusion.h"
-
-#include <algorithm>
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/service/sub_byte_normalization.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-PrimitiveType GetUniqueOutputTypeOfFusible(const HloInstruction& fusible) {
-  auto outputs = GetOutputsOfFusible(fusible);
-  CHECK(!outputs.empty());
-  PrimitiveType first_output_type = outputs[0]->shape().element_type();
-  for (size_t i = 1; i < outputs.size(); ++i) {
-    PrimitiveType cur_output_type = outputs[i]->shape().element_type();
-    CHECK(first_output_type == cur_output_type)
-        << "Output types are expected to be unique, but see "
-        << PrimitiveType_Name(first_output_type) << " and "
-        << PrimitiveType_Name(cur_output_type);
-  }
-
-  return first_output_type;
-}
-
-class HorizontalLoopFusionImpl {
- public:
-  explicit HorizontalLoopFusionImpl(HloComputation* computation,
-                                    absl::string_view prefix)
-      : computation_(computation), prefix_(prefix) {}
-
-  ~HorizontalLoopFusionImpl() = default;
-
-  absl::StatusOr<bool> Run();
-
- private:
-  absl::Status Fuse(absl::Span<HloInstruction*> fused_fusion_instrs,
-                    bool sliced_input_fusion,
-                    std::vector<HloInstruction*>& to_fuse_candidates);
-
-  // If `sliced_input_fusion` is true, horizontally fuses `fused_fusion_instrs`
-  // into a kInput computation; otherwise fuses `fused_fusion_instrs` into a
-  // kLoop computation.
-  //
-  // It is required that each of `fused_fusion_instrs` is a kLoop fusion. Also,
-  // we require their numbers of outputs to be the same, so that each output
-  // will be fused/concatenated with the same number of outputs from other fused
-  // fusion instrs. Then, all the fused outputs still have the same shapes for
-  // kernel generation.
-  //
-  // Returns the fused computation in `uniq_computation` and the operands that
-  // are used by `uniq_computation`.
-  absl::Status CreateFusedComputation(
-      absl::Span<HloInstruction*> fused_fusion_instrs,
-      std::unique_ptr<HloComputation>* uniq_computation,
-      std::vector<HloInstruction*>* bound_operands, bool sliced_input_fusion);
-
-  // Horizontally fuses the operands of the consumer instruction.
-  // `sliced_input_fusion` controls whether a kInput or a kLoop fused
-  // instruction is created. `to_fuse_candidates` is the stack of instructions
-  // whose operands we want to try to fuse horizontally; when we create a new
-  // fusion instruction, we push it onto the stack in the hope of further
-  // fusing its operands.
-  absl::StatusOr<bool> FuseConsumerOperands(
-      HloInstruction* consumer, bool sliced_input_fusion,
-      std::vector<HloInstruction*>& to_fuse_candidates);
-
-  // FusionCandidates collects profitable candidates for a given consumer
-  // instruction. GetNextSpanOfFusions() can then be iteratively invoked to
-  // acquire the next set of fusion candidates based on some heuristics.
-  class FusionCandidates {
-   public:
-    explicit FusionCandidates(HloInstruction* consumer,
-                              bool sliced_input_fusion)
-        : fusible_instrs_(),
-          pos_(0),
-          sliced_input_fusion_(sliced_input_fusion) {
-      Initialize(consumer);
-    }
-
-    // Gets a span of fusions to be fused.
-    absl::Span<HloInstruction*> GetNextSpanOfFusions();
-
-   private:
-    void Initialize(HloInstruction*);
-
-    std::vector<HloInstruction*> fusible_instrs_;
-    // `pos_` points to the start position of the next span.
-    size_t pos_;
-    // The `sliced_input_fusion_` flag controls whether we fuse into a
-    // kLoop (false) or a kInput (true) kernel.
-    bool sliced_input_fusion_;
-  };
-
-  HloComputation* computation_;
-  std::string prefix_;
-};  // HorizontalLoopFusionImpl
-
-bool IsFusibleCandidate(const HloInstruction& instr) {
-  // For now, we do not support fusing instruction with control flow.
-  if (!instr.control_successors().empty() ||
-      !instr.control_predecessors().empty()) {
-    return false;
-  }
-
-  if (IsNestableVariadicReduction(instr)) {
-    return false;
-  }
-
-  // Require no further check for element-wise instructions.
-  if (instr.IsElementwise() && instr.operand_count() > 0) {
-    return true;
-  }
-
-  // Exclude fusions other than kLoop.
-  if (!instr.IsLoopFusion()) {
-    return false;
-  }
-
-  // Fusions that have multiple output types are not supported, because the
-  // concatenate (inserted for horizontal fusion) requires the same type
-  // for all of its operands.
-  auto outputs = GetOutputsOfFusible(instr);
-  CHECK(!outputs.empty());
-  const HloInstruction* first_output = outputs[0];
-  for (size_t i = 1; i < outputs.size(); ++i) {
-    if (first_output->shape().element_type() !=
-        outputs[i]->shape().element_type()) {
-      return false;
-    }
-  }
-
-  return true;
-}
-
-// Returns whether `instr` is a profitable candidate to be horizontally fused.
-// Since the primary benefit of horizontal fusion comes from reducing the
-// kernel launch overhead, we want to exclude the instructions with
-// insignificant kernel launch overhead. In other words, we exclude instructions
-// if their computation latencies are longer than launch latencies. We estimate
-// the computation latency of a given instruction by its shapes and the
-// instruction count in its fused computation. We roughly observe that if a
-// fusion instruction has shapes smaller than `kShapeThreshold` and has fewer
-// instructions than `kInstrCountThreshold`, it is launch-latency-bound and
-// profitable by horizontal fusion.
-bool IsProfitableFusionCandidate(const HloInstruction& instr,
-                                 bool sliced_input_fusion) {
-  // For kLoop fused kernel, each GPU thread will process 1 or more elements
-  // from each horizontal fused operands, while for kInput fused kernel, each
-  // GPU thread can only process 1 element. From experience, we enable larger
-  // tensor size threshold for kLoop fusion.
-  const int64_t kShapeThreshold =
-      sliced_input_fusion ? 128 * 2048 : 8192 * 8192;
-  const int64_t kInstrCountThreshold = sliced_input_fusion ? 30 : 128;
-  const HloInstruction* root = (instr.opcode() == HloOpcode::kFusion)
-                                   ? instr.fused_expression_root()
-                                   : &instr;
-
-  // Too large shapes are not easily profitable.
-  if (root->opcode() == HloOpcode::kTuple) {
-    // Since all output shapes are the same, use the first shape as the
-    // representative.
-    Shape shape = root->operand(0)->shape();
-    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
-      VLOG(2) << "Profitable check failed due to element count with "
-                 "sliced_input_fusion="
-              << sliced_input_fusion;
-      return false;
-    }
-  } else {
-    Shape shape = root->shape();
-    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
-      VLOG(2) << "Profiltable check failed due to element size with "
-                 "sliced_input_fusion="
-              << sliced_input_fusion;
-      return false;
-    }
-  }
-
-  // Having too many instructions is not easily profitable.
-  if (instr.opcode() == HloOpcode::kFusion &&
-      instr.fused_instruction_count() > kInstrCountThreshold) {
-    return false;
-  }
-
-  return true;
-}
-
-// Returns whether `instr` has only row-major layouts.
-// The horizontal fusion excludes computations with non-row-major layouts,
-// because fusing computations with different layouts can result in uncoalesced
-// memory accesses and cause great performance overhead.
-bool HasOnlyRowMajorLayout(const HloInstruction& instr) {
-  if (instr.opcode() != HloOpcode::kFusion) {
-    return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout());
-  }
-
-  auto fused_instrs = instr.fused_instructions_computation()->instructions();
-  for (HloInstruction* i : fused_instrs) {
-    if (!LayoutUtil::IsDenseArray(i->shape())) {
-      continue;
-    }
-    if (!LayoutUtil::IsMonotonicWithDim0Major(i->shape().layout())) {
-      return false;
-    }
-  }
-  return true;
-}
-
-// Returns whether any operand of `instr` is a parameter instruction that
-// is shared with `fusion_instrs`.
-bool AnyOpndIsParamSharedAmongFusions(
-    const HloInstruction* instr,
-    const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
-  return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
-    return opnd->opcode() == HloOpcode::kParameter &&
-           absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
-             return user != instr && fusion_instrs.contains(user);
-           });
-  });
-}
-
-void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
-    HloInstruction* consumer) {
-  // First, find out all potential target candidates. We will filter out
-  // unsupported/non-profitable cases below.
-  absl::flat_hash_set<HloInstruction*> fusible_candidates;
-  std::vector<HloInstruction*> ordered_fusible_candidates;
-  for (HloInstruction* opnd : consumer->operands()) {
-    HloInstruction* predecessor = opnd->LatestNonGteAncestor();
-    // We support kLoop fusion and element-wise HLOs now. We may extend the
-    // support list if needs arise.
-    if (IsFusibleCandidate(*predecessor)) {
-      if (fusible_candidates.insert(predecessor).second) {
-        // Add unseen fusion to ordered list.
-        ordered_fusible_candidates.push_back(predecessor);
-      }
-    }
-  }
-
-  for (HloInstruction* instr : ordered_fusible_candidates) {
-    if (!IsConsumerTheOnlyNonRootUser(*instr, *consumer)) {
-      VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
-              << " rejects maybe illegal instr " << instr->ToString()
-              << "; including it may create cycles in HLO.";
-      continue;
-    } else if (!IsProfitableFusionCandidate(*instr, sliced_input_fusion_)) {
-      VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
-              << " rejects may-not-be profitable fusion instr"
-              << instr->ToString();
-      continue;
-    } else if (!HasOnlyRowMajorLayout(*instr)) {
-      VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
-              << " rejects non-row-major fusion instr " << instr->ToString();
-      continue;
-    } else if (AnyOpndIsParamSharedAmongFusions(instr, fusible_candidates)) {
-      // Don't fuse fusions whose operands are parameter instructions that are
-      // shared among fusions because we cannot i/o alias the produced
-      // horizontal fusion due to the concat insertion.
-      VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
-              << " rejects the fusion instr because it shares parameter with"
-              << " other fusion candidates, instr: " << instr->ToString();
-      continue;
-    } else {
-      VLOG(2) << "Find a fusion candidate " << instr->ToString();
-      // Encapsulate it into a fusion computation for unified representation
-      // for later processing.
-      fusible_instrs_.push_back(instr);
-    }
-  }
-
-  // Sort `fusible_instrs_` according to output type, number of outputs,
-  // instruction count, and output tensor element count. For sliced input
-  // fusion, we only fuse instructions with the same number/type of outputs
-  // and whose computations have the same instruction count. For kLoop fusion,
-  // we require the fused instructions to have the same number/type of outputs
-  // and also the same output shape. We sort here so that the fusion candidates
-  // populate a contiguous span.
-  std::stable_sort(
-      fusible_instrs_.begin(), fusible_instrs_.end(),
-      [&](const HloInstruction* a, const HloInstruction* b) {
-        if (GetUniqueOutputTypeOfFusible(*a) !=
-            GetUniqueOutputTypeOfFusible(*b)) {
-          return GetUniqueOutputTypeOfFusible(*a) <
-                 GetUniqueOutputTypeOfFusible(*b);
-        } else if (GetOutputSizeOfFusible(*a) != GetOutputSizeOfFusible(*b)) {
-          return GetOutputSizeOfFusible(*a) < GetOutputSizeOfFusible(*b);
-        } else if (GetInstrCountOfFusible(*a) != GetInstrCountOfFusible(*b)) {
-          return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
-        } else {
-          return ShapeUtil::ElementsIn(GetOutputsOfFusible(*a)[0]->shape()) <
-                 ShapeUtil::ElementsIn(GetOutputsOfFusible(*b)[0]->shape());
-        }
-      });
-}
-
-// Gets the next span of fusion instructions to be fused.
-absl::Span<HloInstruction*>
-HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
-  if (pos_ >= fusible_instrs_.size()) {
-    return absl::Span<HloInstruction*>();
-  }
-
-  // Fusing too many computations at a time may not be easily profitable and
-  // may increase compile time due to large kernels, so we set a limit.
-  // Profiling showed that a large fused horizontal kernel can have worse E2E
-  // performance even though the pure GPU kernel time is shorter; understanding
-  // why this E2E regression happens for large horizontal fused kernels is
-  // still a TODO. The maximum fusion batch size below is an empirical value
-  // based on the fused instruction count of the operand.
-  const auto kMaxFusionBatchSize = [&]() -> int64_t {
-    if (sliced_input_fusion_) {
-      return 32;
-    } else {
-      if (fusible_instrs_[pos_]->opcode() == HloOpcode::kFusion) {
-        return 32;
-      } else {
-        return 64;
-      }
-    }
-  }();
-
-  size_t left = pos_;
-  size_t right = pos_ + 1;
-  size_t first_output_size = GetOutputSizeOfFusible(*fusible_instrs_[left]);
-  PrimitiveType first_output_type =
-      GetUniqueOutputTypeOfFusible(*fusible_instrs_[left]);
-  // CUDA has a parameter size limit of ~4k bytes.
-  constexpr int64_t kMaxCudaParamSize = 4000;
-  size_t accum_io_size = 0;
-  size_t accum_num_outputs = 0;
-  for (; right < fusible_instrs_.size(); ++right) {
-    PrimitiveType cur_output_type =
-        GetUniqueOutputTypeOfFusible(*fusible_instrs_[right]);
-    if (first_output_type != cur_output_type) {
-      // Cannot fuse computations that have multiple output types.
-      break;
-    }
-    if (first_output_size != GetOutputSizeOfFusible(*fusible_instrs_[right])) {
-      // Cannot fuse computations that have different numbers of outputs.
-      break;
-    }
-    if (GetInstrCountOfFusible(*fusible_instrs_[left]) !=
-        GetInstrCountOfFusible(*fusible_instrs_[right])) {
-      // Do not fuse computations of different instruction counts as it may
-      // introduce control divergence. This is a very simple heuristic to avoid
-      // fusing computations with too much discrepancy and we may improve it
-      // when the needs arise.
-      break;
-    }
-    if (!sliced_input_fusion_ &&
-        !ShapeUtil::EqualIgnoringElementType(
-            GetOutputsOfFusible(*fusible_instrs_[left])[0]->shape(),
-            GetOutputsOfFusible(*fusible_instrs_[right])[0]->shape())) {
-      // This is for fusing into a kLoop-type kernel, so we require that each
-      // fusion operand has the same shape.
-      break;
-    }
-    size_t num_outputs = GetOutputSizeOfFusible(*fusible_instrs_[right]);
-    accum_num_outputs += num_outputs;
-    if (accum_num_outputs >= kMaxFusionBatchSize) {
-      // Hit max fusion batch size.
-      break;
-    }
-    accum_io_size += fusible_instrs_.at(right)->operand_count() + num_outputs;
-    if (accum_io_size * 8 >= kMaxCudaParamSize) {
-      break;
-    }
-  }
-  VLOG(2) << "horizontal fuse get instruction span with " << (right - left)
-          << " instructions for sliced_input_fusion=" << sliced_input_fusion_
-          << " fusion";
-  pos_ = right;
-  return absl::MakeSpan(fusible_instrs_).subspan(left, right - left);
-}
-
-absl::StatusOr<bool> HorizontalLoopFusionImpl::FuseConsumerOperands(
-    HloInstruction* consumer, bool sliced_input_fusion,
-    std::vector<HloInstruction*>& to_fuse_candidates) {
-  bool changed = false;
-  FusionCandidates loop_fusion_candidates(consumer, sliced_input_fusion);
-  while (true) {
-    auto fusibles = loop_fusion_candidates.GetNextSpanOfFusions();
-    if (fusibles.empty()) {
-      break;
-    } else if (fusibles.size() == 1) {
-      // Skip; there is just one fused_instr.
-      continue;
-    }
-
-    changed = true;
-    // Convert fusibles into fusion instructions to simplify the implementation
-    // of `Fuse()`.
-    std::vector<HloInstruction*> fusion_instrs;
-    for (HloInstruction* instr : fusibles) {
-      if (instr->opcode() == HloOpcode::kFusion) {
-        fusion_instrs.push_back(instr);
-      } else {
-        TF_ASSIGN_OR_RETURN(
-            HloInstruction * fusion_instr,
-            MakeFusionInstruction(instr, HloInstruction::FusionKind::kLoop));
-        fusion_instrs.push_back(fusion_instr);
-      }
-    }
-
-    TF_RETURN_IF_ERROR(Fuse(absl::MakeSpan(fusion_instrs), sliced_input_fusion,
-                            to_fuse_candidates));
-  }
-  return changed;
-}
-
-absl::Status HorizontalLoopFusionImpl::CreateFusedComputation(
-    absl::Span<HloInstruction*> fused_fusion_instrs,
-    std::unique_ptr<HloComputation>* uniq_computation,
-    std::vector<HloInstruction*>* bound_operands, bool sliced_input_fusion) {
-  // First, build a computation with only params.
-  HloComputation::Builder b(prefix_ + "horizontally_fused_computation");
-  size_t fused_comp_param_id = 0;
-  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
-    auto old_params = fused_fusion_instrs[i]->fused_parameters();
-    for (size_t j = 0; j < old_params.size(); ++j) {
-      HloInstruction* bound_opnd = fused_fusion_instrs[i]->mutable_operand(j);
-      // in a form of param_i_j
-      b.AddInstruction(HloInstruction::CreateParameter(
-          fused_comp_param_id++, bound_opnd->shape(),
-          absl::StrCat("param_", i, "_", j)));
-      bound_operands->push_back(bound_opnd);
-    }
-  }
-  // Always create a dummy tuple instruction to serve as the root of the
-  // computation, as the existence of a root instruction is required by the
-  // HloComputation. The real root instruction will replace it below.
-  HloInstruction* dummy_root = b.AddInstruction(
-      HloInstruction::CreateTuple(std::vector<HloInstruction*>{}));
-  *uniq_computation = b.Build(dummy_root);
-  HloComputation* comp = uniq_computation->get();
-
-  // Preparing clone_map, which maps old operand to new operand.
-  absl::flat_hash_map<const HloInstruction*, HloInstruction*> clone_map;
-  size_t new_param_id = 0;
-  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
-    auto old_params = fused_fusion_instrs[i]->fused_parameters();
-    for (size_t j = 0; j < old_params.size(); ++j) {
-      HloInstruction* old_param = old_params[j];
-      HloInstruction* new_param = comp->parameter_instruction(new_param_id++);
-      clone_map.insert({old_param, new_param});
-    }
-  }
-
-  // Clone every fused computation.
-  const OpMetadata* metadata = nullptr;
-  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
-    auto def_to_use_order = fused_fusion_instrs[i]
-                                ->fused_instructions_computation()
-                                ->MakeInstructionPostOrder();
-    for (HloInstruction* old_instr : def_to_use_order) {
-      if (old_instr->opcode() == HloOpcode::kParameter ||
-          (sliced_input_fusion && old_instr->opcode() == HloOpcode::kTuple &&
-           old_instr == fused_fusion_instrs[i]->fused_expression_root())) {
-        // Parameters have been created, and we don't need tuples from
-        // multi-output fusions, as we will directly reference the tuple
-        // operands instead by using GetOutputsOfFusible().
-        continue;
-      }
-      std::vector<HloInstruction*> new_opnds;
-      const auto& old_opnds = old_instr->operands();
-      new_opnds.reserve(old_opnds.size());
-      for (HloInstruction* old_opnd : old_opnds) {
-        CHECK(clone_map.find(old_opnd) != clone_map.end());
-        new_opnds.push_back(clone_map[old_opnd]);
-      }
-      HloInstruction* new_instr = comp->AddInstruction(
-          old_instr->CloneWithNewOperands(old_instr->shape(), new_opnds));
-      clone_map.insert({old_instr, new_instr});
-      // Get the metadata from the last fused instruction.
-      metadata = &old_instr->metadata();
-    }
-  }
-
-  // Since we require each fusion to have the same number of outputs, we can
-  // simply use the first fusion as the representative for output size.
-  size_t fused_instr_output_size =
-      GetOutputSizeOfFusible(*fused_fusion_instrs[0]);
-
-  if (sliced_input_fusion) {
-    // Fusing into kInput fusion
-    std::vector<HloInstruction*> concated_outputs;
-    for (size_t i = 0; i < fused_instr_output_size; ++i) {
-      std::vector<HloInstruction*> instr_outputs(fused_fusion_instrs.size());
-      for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
-        const HloInstruction* old_output =
-            GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
-        HloInstruction* new_output = clone_map[old_output];
-        if (new_output->shape().dimensions_size() == 1) {
-          instr_outputs[j] = new_output;
-        } else {
-          Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout(
-              new_output->shape().element_type(),
-              {ShapeUtil::ElementsIn(new_output->shape())},
-              /*minor_to_major=*/std::vector<int64_t>(1, 0));
-          TF_ASSIGN_OR_RETURN(instr_outputs[j],
-                              MakeReshapeHlo(new_shape, new_output));
-        }
-      }
-      TF_ASSIGN_OR_RETURN(HloInstruction * concated_output,
-                          MakeConcatHlo(instr_outputs, 0));
-      concated_outputs.push_back(concated_output);
-    }
-
-    // Make slices of outputs.
-    std::vector<HloInstruction*> output_slices(concated_outputs.size() *
-                                               fused_fusion_instrs.size());
-    for (size_t i = 0; i < concated_outputs.size(); ++i) {
-      HloInstruction* concated_output = concated_outputs[i];
-      int64_t slice_start = 0;
-      // Create a slice per fused computation.
-      for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
-        const HloInstruction* old_output =
-            GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
-        Shape shape = old_output->shape();
-        int64_t slice_limit = slice_start + ShapeUtil::ElementsIn(shape);
-        TF_ASSIGN_OR_RETURN(
-            output_slices[concated_outputs.size() * j + i],
-            MakeSliceHlo(concated_output, {slice_start}, {slice_limit},
-                         /*strides=*/{1}));
-        slice_start = slice_limit;
-      }
-    }
-
-    // Make a tuple of output_slices.
-    HloInstruction* tuple = comp->AddInstruction(
-        HloInstruction::CreateTuple(output_slices), metadata);
-    comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
-    TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));
-
-  } else {
-    // Fusing into kLoop fusion
-    std::vector<HloInstruction*> tuple_operands(fused_instr_output_size *
-                                                fused_fusion_instrs.size());
-    // If fusing into a kLoop fusion, the new fusion root is a tuple of the
-    // fused fusion computations' roots.
-    for (size_t i = 0; i < fused_instr_output_size; ++i) {
-      for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
-        const HloInstruction* old_output =
-            GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
-        HloInstruction* new_output = clone_map[old_output];
-        tuple_operands[fused_instr_output_size * j + i] = new_output;
-      }
-    }
-    // Make a tuple instruction of fused instruction outputs as
-    // the root of fused computation.
-    HloInstruction* tuple =
-        comp->AddInstruction(HloInstruction::CreateTuple(tuple_operands));
-    comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
-    TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));
-  }
-
-  return absl::OkStatus();
-}
-
-absl::Status HorizontalLoopFusionImpl::Fuse(
-    absl::Span<HloInstruction*> fused_fusion_instrs, bool sliced_input_fusion,
-    std::vector<HloInstruction*>& to_fuse_candidates) {
-  // Fuse fused_fusion_instrs and replace them with the new fused computation.
-  std::unique_ptr<HloComputation> uniq_computation;
-  std::vector<HloInstruction*> bound_operands;
-
-  TF_RETURN_IF_ERROR(CreateFusedComputation(fused_fusion_instrs,
-                                            &uniq_computation, &bound_operands,
-                                            sliced_input_fusion));
-
-  HloComputation* fused_comp = computation_->parent()->AddEmbeddedComputation(
-      std::move(uniq_computation));
-  HloInstruction* hori_fusion_instr = computation_->AddInstruction(
-      HloInstruction::CreateFusion(fused_comp->root_instruction()->shape(),
-                                   sliced_input_fusion
-                                       ? HloInstruction::FusionKind::kInput
-                                       : HloInstruction::FusionKind::kLoop,
-                                   bound_operands, fused_comp, prefix_),
-      &fused_comp->root_instruction()->metadata());
-  fused_comp->SetFusionInstruction(hori_fusion_instr);
-
-  // We push the newly fused instruction onto the fusion candidate stack,
-  // because its operands may now themselves be candidates for horizontal
-  // fusion.
-  to_fuse_candidates.push_back(hori_fusion_instr);
-
-  // Insert bitcasts and replace corresponding users. Note that we do not insert
-  // the bitcasts in the fused computation as it does not fit into the slice
-  // input fusion pattern. However, inserting bitcasts outside the fused
-  // computation creates no performance cost.
-  size_t total_output_id = 0;
-  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
-    std::vector<HloInstruction*> bitcasts_or_gte;
-    HloInstruction* fused_instr = fused_fusion_instrs[i];
-    size_t num_outputs = GetOutputSizeOfFusible(*fused_instr);
-    for (size_t j = 0; j < num_outputs; ++j) {
-      const HloInstruction* output = GetOutputsOfFusible(*fused_instr)[j];
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * gep,
-          MakeGetTupleElementHlo(hori_fusion_instr, total_output_id++));
-      // This pass runs late, so useless bitcasts won't be cleaned up.
-      if (output->shape().dimensions_size() == 1) {
-        bitcasts_or_gte.push_back(gep);
-      } else {
-        bitcasts_or_gte.push_back(computation_->AddInstruction(
-            HloInstruction::CreateBitcast(output->shape(), gep)));
-      }
-    }
-    HloInstruction* bitcast_or_tuple =
-        (bitcasts_or_gte.size() == 1)
-            ? bitcasts_or_gte.at(0)
-            : computation_->AddInstruction(
-                  HloInstruction::CreateTuple(bitcasts_or_gte));
-    HloComputation* old_computation =
-        fused_instr->fused_instructions_computation();
-    HloModule* module = old_computation->parent();
-    TF_RETURN_IF_ERROR(
-        computation_->ReplaceInstruction(fused_instr, bitcast_or_tuple));
-    TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(old_computation));
-  }
-
-  TF_RETURN_IF_ERROR(Cast<HloFusionInstruction>(hori_fusion_instr)
-                         ->DeduplicateFusionOperands());
-
-  VLOG(1) << "Fused " << fused_fusion_instrs.size()
-          << " instructions into: " << hori_fusion_instr->ToString();
-  return absl::OkStatus();
-}
-
-absl::StatusOr<bool> HorizontalLoopFusionImpl::Run() {
-  bool changed = false;
-  XLA_VLOG_LINES(3, computation_->ToString());
-
-  // Traverse from use to def. Bitcasts are placed after h-fusions to resolve
-  // shape mismatch but bitcasts could prevent future h-fusion from happening.
-  // So, a bottom-up, use-to-def order should be more favorable. It also helps
-  // to save compiler iterations to reach the fixed point.
-  std::vector<HloInstruction*> to_fuse_candidates =
-      computation_->MakeInstructionPostOrder();
-
-  while (!to_fuse_candidates.empty()) {
-    HloInstruction* consumer = to_fuse_candidates.back();
-    to_fuse_candidates.pop_back();
-
-    // The consumer may be an operand of a previously fused instruction, so it
-    // may no longer be valid; skip this instruction in that case.
-    if (consumer->IsDead()) {
-      continue;
-    }
-
-    // We first try to fuse operands that have the same shape into a kLoop
-    // fusion instruction.
-    TF_ASSIGN_OR_RETURN(
-        bool loop_fusion_changed,
-        FuseConsumerOperands(consumer, false, to_fuse_candidates));
-
-    // For the remaining operands with different shapes, we then try to fuse
-    // them into a kInput fusion instruction.
-    TF_ASSIGN_OR_RETURN(
-        bool sliced_input_fusion_changed,
-        FuseConsumerOperands(consumer, true, to_fuse_candidates));
-
-    changed = changed || loop_fusion_changed || sliced_input_fusion_changed;
-  }
-  return changed;
-}
-
-}  // namespace
-
-absl::StatusOr<bool> GpuHorizontalLoopFusion::RunOnComputation(
-    HloComputation* computation) {
-  HorizontalLoopFusionImpl horizontal_fusion_impl(computation, prefix_);
-  return horizontal_fusion_impl.Run();
-}
-
-absl::StatusOr<bool> GpuHorizontalLoopFusion::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  VLOG(2) << "Run horizontal fusion.";
-
-  // Running on the entry computation is sufficient.
-  TF_ASSIGN_OR_RETURN(bool changed,
-                      RunOnComputation(module->entry_computation()));
-
-  if (changed) {
-    // Correctly set element_size_in_bits for any added sub-byte slice and
-    // concatenate instructions.
-    TF_ASSIGN_OR_RETURN(
-        [[maybe_unused]] bool unused,
-        SubByteNormalization{SubByteNormalization::SET_ELEMENT_SIZE}.Run(
-            module));
-  }
-
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
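
The batching logic in GetNextSpanOfFusions above caps each candidate span by an output-count limit and by a rough CUDA kernel-parameter budget (accum_io_size * 8 checked against ~4000 bytes). The arithmetic is restated below under the assumption that every operand or output contributes one 8-byte pointer to the kernel parameters; the helper is hypothetical and only mirrors the check in the deleted code.

  // Sketch of the kernel-parameter budget check used by the pass above.
  #include <cstddef>

  bool FitsCudaParamBudget(std::size_t fusions_in_batch,
                           std::size_t operands_per_fusion,
                           std::size_t outputs_per_fusion) {
    constexpr std::size_t kMaxCudaParamSize = 4000;  // ~4 KB, as in the pass.
    const std::size_t io_count =
        fusions_in_batch * (operands_per_fusion + outputs_per_fusion);
    return io_count * 8 < kMaxCudaParamSize;  // 8 bytes per pointer argument.
  }

For example, a batch of 32 fusions, each with 2 operands and 1 output, needs 32 * 3 * 8 = 768 bytes of parameters, comfortably within the budget.
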
diff --git a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.h b/third_party/xla/xla/service/gpu/horizontal_loop_fusion.h
deleted file mode 100644
index 5daed03..0000000
--- a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.h
+++ /dev/null
@@ -1,148 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
-#define XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
-
-#include <string>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// This optimization pass horizontally fuses computations for reducing kernel
-// launch overhead while increasing kernel launch dims on GPU. The initial
-// motivation for this horizontal fusion comes from the observation that the
-// training optimizer phase (e.g., AdamOptimizer, L2Loss, etc.) typically
-// has many small kernels as a result of applying the same formula to many
-// training parameters (or variables in TensorFlow). Fusing these small
-// kernels therefore provides a performance gain.
-//
-// Theoretically speaking, we may implement a cycle detection algorithm to make
-// sure no cycles are created after fusion. However, cycle detection check is
-// somewhat cumbersome; also, we observe that naive horizontal fusion of
-// arbitrary kernels may not be profitable due to control divergence and
-// possible increase of memory bandwidth pressure due to uncoalesced memory
-// accesses (note that horizontal fusion does not change the amount of memory
-// read+written at all). In practice, a simple yet effective heuristic is used
-// to avoid these issues while addressing the known beneficial cases. That is,
-// we simply search for fusion candidates by looking for instructions whose
-// outputs are all consumed by the same instruction. This catches the cases in
-// the training optimizer phase, as the candidate instructions are typically
-// consumed only by the ROOT tuple of the entry computation.
-//
-// The following illustrates the mechanism of the horizontal fusion. Before
-// fusion, there are two trivial kernels in the illustrating example. One has
-// only a Mul op, while the other consists of only an Add op. Since they are
-// only consumed by the same (ROOT) tuple instruction, horizontal fusion is
-// triggered.
-//
-// i0 i1   i2 i3
-//  | |     | |
-//  v v     v v
-//  Mul     Add
-//   |       |
-//   v       v
-//  (ROOT) tuple
-//
-// We fuse into one of two possible patterns, depending on whether all the
-// fused operations have the same shape or not.
-//
-// case 1: if Mul's and Add's output shape and type are the same, then we fuse
-// them into the pattern below:
-// i0 i1   i2 i3
-//  | |     | |
-//  v v     v v
-//  Mul     Add
-//   |       |
-//   v       v
-//  (ROOT) tuple
-// The fused kernel will be of kLoop type, and the GPU code is emitted through
-// the LoopFusion class.
-//
-// case 2: if Mul's and Add's output shapes are different, then we fuse them
-// into the pattern below, which adds extra indexing:
-// i0 i1   i2 i3       +++ (Slice) Input Fusion
-//  | |     | |          +
-//  v v     v v          +
-//  Mul     Add          +
-//   |       |           +
-//   v       v           +
-// Reshape0  Reshape1    +
-//   |       |           +
-//   v       v           +
-//  Concatenate          +
-//   |       |           +
-//   v       v           +
-//  Slice0  Slice1     +++
-//   |       |
-//   v       v
-// Reshape2  Reshape3
-//   |       |
-//   v       v
-//  (ROOT) tuple
-//
-// The fused kernel will be of kInput type, and the GPU code is emitted through
-// the InputSlicesFusion class.
-//
-// In theory, the pattern in case 1 could also be fused into the case2 target
-// graph, but we prefer to fuse into kLoop type, because the codegen for it does
-// not have the slicing range check cost introduced by case 2 pattern.
-//
-// Note that the fusion style by case 2 provides an important advantage that
-// kernels of different shapes can be horizontally fused. The first pair of
-// reshapes (i.e., Reshape0 and Reshape1) reshape the dims to 1 dimension, so
-// that the outputs of the fused kernels can (always) be concatenated. The
-// second pair of reshapes (Reshape2 and Reshape3) restore the original shapes
-// to the output tensors.
-//
-// No extra copies are introduced by the horizontal fusion. Besides Reshape2
-// and Reshape3, the other instructions are fused into an input fusion; the
-// output dims of the concatenate will be used as the kernel launch dims.
-// Instruction bitcasts can be used for Reshape2 and Reshape3 as long as the
-// outputs of Mul and Add are row-major.
-//
-// Note that reshapes are added only if the tensor isn't already a vector.
-class GpuHorizontalLoopFusion : public HloModulePass {
- public:
-  GpuHorizontalLoopFusion() = default;
-  explicit GpuHorizontalLoopFusion(absl::string_view prefix)
-      : prefix_(prefix) {}
-
-  absl::string_view name() const override {
-    return "gpu_horizontal_loop_fusion";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  absl::StatusOr<bool> RunOnComputation(HloComputation*);
-  std::string prefix_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
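
To make the case-2 diagram above concrete, a sliced-input fused computation for a Mul/Add pair of shapes f32[4,1024] and f32[321,5] (as in the deleted tests) might look roughly like the HLO below. The snippet is illustrative only: the instruction and computation names are not what the pass emits verbatim, and the shape-restoring reshapes/bitcasts live outside the fusion, as explained above. The reshapes flatten to 4*1024 = 4096 and 321*5 = 1605 elements, so the concatenate holds 5701 elements.

  // Illustrative HLO for the case-2 (sliced input) fusion result.
  constexpr absl::string_view kSlicedInputFusionSketch = R"(
    horizontally_fused_computation {
      p0 = f32[4,1024] parameter(0)
      p1 = f32[4,1024] parameter(1)
      p2 = f32[321,5] parameter(2)
      p3 = f32[321,5] parameter(3)
      mul = f32[4,1024] multiply(p0, p1)
      add = f32[321,5] add(p2, p3)
      reshape0 = f32[4096] reshape(mul)
      reshape1 = f32[1605] reshape(add)
      concat = f32[5701] concatenate(reshape0, reshape1), dimensions={0}
      slice0 = f32[4096] slice(concat), slice={[0:4096]}
      slice1 = f32[1605] slice(concat), slice={[4096:5701]}
      ROOT tuple = (f32[4096], f32[1605]) tuple(slice0, slice1)
    })";
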
diff --git a/third_party/xla/xla/service/gpu/horizontal_loop_fusion_test.cc b/third_party/xla/xla/service/gpu/horizontal_loop_fusion_test.cc
deleted file mode 100644
index 4045183..0000000
--- a/third_party/xla/xla/service/gpu/horizontal_loop_fusion_test.cc
+++ /dev/null
@@ -1,851 +0,0 @@
-/* Copyright 2016 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/horizontal_loop_fusion.h"
-
-#include <cstdint>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/log/log.h"
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/instruction_fusion.h"
-#include "xla/service/hlo_dce.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/hlo_pass_fix.h"
-#include "xla/service/hlo_pass_pipeline.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/test.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class HorizontalLoopFusionTest : public HloTestBase {
- public:
-  static bool IsFusion(const HloInstruction* instr) {
-    return instr->opcode() == HloOpcode::kFusion;
-  }
-};
-
-TEST_F(HorizontalLoopFusionTest, BasicTest) {
-  auto module = ParseAndReturnVerifiedModule(R"(
- HloModule BasicTest
-
- fused_computation.1 {
-   arg.1 = f16[1024]{0} parameter(0)
-   arg.2 = f16[1024]{0} parameter(1)
-   ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
- }
-
- fused_computation.2 {
-   arg.1 = f16[123]{0} parameter(0)
-   arg.2 = f16[123]{0} parameter(1)
-   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
- }
-
- ENTRY entry_computation {
-   arg.1 = f16[1024]{0} parameter(0)
-   arg.2 = f16[1024]{0} parameter(1)
-   arg.3 = f16[123]{0} parameter(2)
-   arg.4 = f16[123]{0} parameter(3)
-   fusion.1 = f16[1024]{0}
-       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
-   fusion.2 = f16[123]{0}
-       fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
-   ROOT tuple.1 = (f16[1024]{0}, f16[123]{0})
-       tuple(fusion.1, fusion.2)
- }
-)")
-                    .value();
-
-  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
-  TF_ASSERT_OK(verifier().Run(module.get()).status());
-  EXPECT_FALSE(HloDCE().Run(module.get()).value());
-
-  const HloInstruction* entry_root =
-      module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(entry_root,
-              GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                  m::GetTupleElement(m::Fusion()))));
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Slice(m::Concatenate()),
-                                  m::Slice(m::Concatenate()))));
-}
-
-// Horizontal fusion should not be triggered as fusion will create cycles.
-TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) {
-  auto module = ParseAndReturnVerifiedModule(R"(
- HloModule NegativeTestForCycle
-
- fused_computation.1 {
-   arg.1 = f16[123]{0} parameter(0)
-   arg.2 = f16[123]{0} parameter(1)
-   ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
- }
-
- fused_computation.2 {
-   arg.1 = f16[123]{0} parameter(0)
-   arg.2 = f16[123]{0} parameter(1)
-   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
- }
-
- ENTRY entry_computation {
-   arg.1 = f16[123]{0} parameter(0)
-   arg.2 = f16[123]{0} parameter(1)
-   arg.3 = f16[123]{0} parameter(2)
-   arg.4 = f16[123]{0} parameter(3)
-   // fusion.1 and fusion.2 will not be horizontally fused as it will create
-   // a cycle through fusion.1 -> add.2 -> fusion.2
-   fusion.1 = f16[123]{0}
-       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
-   add.2 = f16[123]{0} add(fusion.1, arg.4)
-   fusion.2 = f16[123]{0}
-       fusion(add.2, arg.3), kind=kLoop, calls=fused_computation.2
-   ROOT tuple.1 = (f16[123]{0}, f16[123]{0}, f16[123]{0})
-       tuple(fusion.1, fusion.2, add.2)
- }
-)")
-                    .value();
-
-  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value());
-}
-
-TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) {
-  auto module = ParseAndReturnVerifiedModule(R"(
- HloModule NegativeTestForIncompatibleTypes
-
- fused_computation.1 {
-   arg.1 = f16[1024]{0} parameter(0)
-   arg.2 = f16[1024]{0} parameter(1)
-   ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
- }
-
- fused_computation.2 {
-   arg.1 = s32[123]{0} parameter(0)
-   arg.2 = s32[123]{0} parameter(1)
-   ROOT add.1 = s32[123]{0} add(arg.1, arg.2)
- }
-
- ENTRY entry_computation {
-   arg.1 = f16[1024]{0} parameter(0)
-   arg.2 = f16[1024]{0} parameter(1)
-   arg.3 = s32[123]{0} parameter(2)
-   arg.4 = s32[123]{0} parameter(3)
-   // fusion.1 and fusion.2 will not be horizontally fused because their output
-   // types are different.
-   fusion.1 = f16[1024]{0}
-       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
-   fusion.2 = s32[123]{0}
-       fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
-   ROOT tuple.1 = (f16[1024]{0}, s32[123]{0})
-       tuple(fusion.1, fusion.2)
- }
-)")
-                    .value();
-
-  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value());
-}
-
-TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) {
-  auto module = ParseAndReturnVerifiedModule(R"(
- HloModule FusingIntoKLoopAndKInputTogether
-
- fused_computation.1 {
-   arg.1 = f16[129, 2048]{1, 0} parameter(0)
-   arg.2 = f16[129, 2048]{1, 0} parameter(1)
-   ROOT mul.1 = f16[129,2048]{1, 0} multiply(arg.1, arg.2)
- }
-
- fused_computation.2 {
-   arg.1 = f16[129, 2048]{1, 0} parameter(0)
-   arg.2 = f16[129, 2048]{1, 0} parameter(1)
-   ROOT mul.1 = f16[129,2048]{1, 0} multiply(arg.1, arg.2)
- }
-
- fused_computation.3 {
-   arg.1 = f16[130, 2048]{1, 0} parameter(0)
-   arg.2 = f16[130, 2048]{1, 0} parameter(1)
-   ROOT mul.1 = f16[130,2048]{1, 0} multiply(arg.1, arg.2)
- }
-
- fused_computation.4 {
-   arg.1 = f16[130, 2048]{1, 0} parameter(0)
-   arg.2 = f16[130, 2048]{1, 0} parameter(1)
-   ROOT mul.1 = f16[130,2048]{1, 0} multiply(arg.1, arg.2)
- }
-
- fused_computation.5 {
-   arg.1 = f16[123]{0} parameter(0)
-   arg.2 = f16[123]{0} parameter(1)
-   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
- }
-
- fused_computation.6 {
-   arg.1 = f16[128]{0} parameter(0)
-   arg.2 = f16[128]{0} parameter(1)
-   ROOT add.1 = f16[128]{0} add(arg.1, arg.2)
- }
-
- ENTRY entry_computation {
-   arg.1 = f16[129, 2048]{1, 0} parameter(0)
-   arg.2 = f16[129, 2048]{1, 0} parameter(1)
-   arg.3 = f16[129, 2048]{1, 0} parameter(2)
-   arg.4 = f16[129, 2048]{1, 0} parameter(3)
-   arg.5 = f16[130, 2048]{1, 0} parameter(4)
-   arg.6 = f16[130, 2048]{1, 0} parameter(5)
-   arg.7 = f16[130, 2048]{1, 0} parameter(6)
-   arg.8 = f16[130, 2048]{1, 0} parameter(7)
-   arg.9 = f16[123]{0} parameter(8)
-   arg.10 = f16[123]{0} parameter(9)
-   arg.11 = f16[128]{0} parameter(10)
-   arg.12 = f16[128]{0} parameter(11)
-
-   // fusion.1 and fusion.2 will be fused into kLoop fusion
-   // fusion.3 and fusion.4 will be fused into another kLoop fusion
-   // fusion.5 and fusion.6 will be fused into kInput fusion
-
-   fusion.1 = f16[129,2048]{1, 0}
-      fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
-
-   fusion.2 = f16[129,2048]{1, 0}
-      fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
-
-   fusion.3 = f16[130,2048]{1, 0}
-      fusion(arg.5, arg.6), kind=kLoop, calls=fused_computation.3
-
-   fusion.4 = f16[130,2048]{1, 0}
-      fusion(arg.7, arg.8), kind=kLoop, calls=fused_computation.4
-
-   fusion.5 = f16[123]{0}
-      fusion(arg.9, arg.10), kind=kLoop, calls=fused_computation.5
-
-   fusion.6 = f16[128]{0}
-      fusion(arg.11, arg.12), kind=kLoop, calls=fused_computation.6
-
-   ROOT tuple.1 = (f16[129,2048]{1, 0}, f16[129,2048]{1, 0},
-                   f16[130,2048]{1, 0}, f16[130,2048]{1, 0},
-                   f16[123]{0}, f16[128]{0})
-      tuple(fusion.1, fusion.2, fusion.3, fusion.4, fusion.5, fusion.6)
- }
-)")
-                    .value();
-
-  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
-
-  int input_fusion_count = 0;
-  int loop_fusion_count = 0;
-  for (auto inst : module->entry_computation()->MakeInstructionPostOrder()) {
-    if (inst->opcode() == HloOpcode::kFusion) {
-      input_fusion_count +=
-          (inst->fusion_kind() == HloInstruction::FusionKind::kInput) ? 1 : 0;
-      loop_fusion_count +=
-          (inst->fusion_kind() == HloInstruction::FusionKind::kLoop) ? 1 : 0;
-    }
-  }
-  EXPECT_EQ(input_fusion_count, 1);
-  EXPECT_EQ(loop_fusion_count, 2);
-}
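// Editorial note, not part of the original test: the counts checked above are
// consistent with how the pass groups candidates. fusion.1/fusion.2 and
// fusion.3/fusion.4 each share an output shape and can be merged as plain
// kLoop fusions, while fusion.5 (f16[123]) and fusion.6 (f16[128]) differ in
// shape and are merged through the concatenate/slice form (see
// HorizontalLoopFusionAfterVerticalFusion below), which is emitted as a
// kInput fusion. Hence two loop fusions and one input fusion.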
-
-TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) {
-  auto module = ParseAndReturnVerifiedModule(R"(
- HloModule MergeSharedFusionInstruction
-
- ENTRY MergeSharedFusionInstruction.Computation0 {
-  param.1.1   = f32[4,1024]{1,0} parameter(0)
-  param.1.2   = f32[4,1024]{1,0} parameter(1)
-  param.1.3   = f32[4,1024]{1,0} parameter(2)
-  param.2.1   = f32[321,5]{1,0} parameter(3)
-  param.2.2   = f32[321,5]{1,0} parameter(4)
-  param.2.3   = f32[321,5]{1,0} parameter(5)
-  const.1     = f32[] constant(3)
-  const.2     = f32[] constant(3)
-  broadcast.1 = f32[4,1024]{1,0} broadcast(const.1), dimensions={}
-  broadcast.2 = f32[321,5]{1,0} broadcast(const.2), dimensions={}
-  mul.1.1     = f32[4,1024]{1,0} multiply(param.1.1, param.1.2)
-  mul.1.2     = f32[4,1024]{1,0} multiply(param.1.3, broadcast.1)
-  add.1       = f32[4,1024]{1,0} add(mul.1.1, mul.1.2)
-  mul.2.1     = f32[321,5]{1,0} multiply(param.2.1, param.2.2)
-  mul.2.2     = f32[321,5]{1,0} multiply(param.2.3, broadcast.2)
-  add.2       = f32[321,5]{1,0} add(mul.2.1, mul.2.2)
-  ROOT tuple = (f32[4,1024]{1,0}, f32[321,5]{1,0}) tuple(add.1, add.2)
-})")
-                    .value();
-
-  HloPassPipeline fusion("fusion");
-  const se::DeviceDescription device_info =
-      TestGpuDeviceInfo::RTXA6000DeviceInfo();
-  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/false,
-                                                 device_info);
-  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/true,
-                                                 device_info);
-  EXPECT_TRUE(fusion.Run(module.get()).value());
-  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
-  TF_ASSERT_OK(verifier().Run(module.get()).status());
-
-  VLOG(2) << "Dump after horizontal fusion:";
-  VLOG(2) << module->ToString();
-
-  const HloInstruction* entry_root =
-      module->entry_computation()->root_instruction();
-  const HloInstruction* fusion_instr = nullptr;
-  // Check that we add bitcast when needed.
-  ASSERT_THAT(entry_root,
-              GmockMatch(m::Tuple(
-                  m::Bitcast(m::GetTupleElement(m::Fusion(&fusion_instr))),
-                  m::Bitcast(m::GetTupleElement(m::Fusion())))));
-  ASSERT_TRUE(fusion_instr->IsMultiOutputFusion());
-  EXPECT_THAT(fusion_instr->fused_expression_root(),
-              GmockMatch(m::Tuple(
-                  m::Slice(m::Concatenate(m::Reshape(), m::Reshape())),
-                  m::Slice(m::Concatenate(m::Reshape(), m::Reshape())))));
-
-  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
-}
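// Editorial note, not part of the original test: the Bitcast matchers above
// reflect that the horizontally fused kernel produces its outputs in a
// flattened form (reshape, concatenate, then slice, as checked on the fused
// root), so bitcasts are needed to restore the original f32[4,1024] and
// f32[321,5] shapes at the entry root.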
-
-TEST_F(HorizontalLoopFusionTest, GradientDescentOptimizerLike) {
-  HloComputation::Builder builder(TestName());
-
-  std::vector<HloInstruction*> var_outs;
-  for (int64_t i = 0; i < 128; ++i) {
-    // For shapes {1, 1024}, {2, 1024}, ..., {128, 1024}
-    Shape shape = ShapeUtil::MakeShape(F32, {i + 1, 1024});
-    HloInstruction* param_var_in = builder.AddInstruction(
-        HloInstruction::CreateParameter(i * 3 + 0, shape, "var.in"));
-    HloInstruction* param_alpha =
-        builder.AddInstruction(HloInstruction::CreateParameter(
-            i * 3 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
-    HloInstruction* param_delta = builder.AddInstruction(
-        HloInstruction::CreateParameter(i * 3 + 2, shape, "delta"));
-    HloInstruction* alpha_broadcasted = builder.AddInstruction(
-        HloInstruction::CreateBroadcast(shape, param_alpha, {}));
-    HloInstruction* alpha_delta =
-        builder.AddInstruction(HloInstruction::CreateBinary(
-            shape, HloOpcode::kMultiply, alpha_broadcasted, param_delta));
-    HloInstruction* var_out =
-        builder.AddInstruction(HloInstruction::CreateBinary(
-            shape, HloOpcode::kSubtract, param_var_in, alpha_delta));
-    var_outs.push_back(var_out);
-  }
-  builder.AddInstruction(HloInstruction::CreateTuple(var_outs));
-
-  auto module = CreateNewVerifiedModule();
-  module->AddEntryComputation(builder.Build());
-
-  // Testing with the entire gpu optimization pipeline.
-  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0, 0}));
-}
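// Editorial sketch, not part of the original test: the per-element update that
// the builder loop above constructs is plain gradient descent. A scalar
// reference, assuming float inputs:
//
//   float SgdUpdate(float var_in, float alpha, float delta) {
//     return var_in - alpha * delta;  // var.out = var.in - broadcast(alpha) * delta
//   }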
-
-TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) {
-  auto module = ParseAndReturnVerifiedModule(R"(
- HloModule HeterogeneousMultiOutputFusions
-
- fused_computation.1 {
-   arg.1 = f16[1024]{0} parameter(0)
-   arg.2 = f16[1024]{0} parameter(1)
-   arg.3 = f16[1024]{0} parameter(2)
-   arg.4 = f16[1024]{0} parameter(3)
-   mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
-   mul.2 = f16[1024]{0} multiply(arg.3, arg.4)
-   add.1 = f16[1024]{0} add(mul.1, mul.2)
-   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}) tuple(add.1, mul.1)
- }
-
- fused_computation.2 {
-   arg.1 = f16[123]{0} parameter(0)
-   arg.2 = f16[123]{0} parameter(1)
-   arg.3 = f16[123]{0} parameter(2)
-   arg.4 = f16[123]{0} parameter(3)
-   add.1 = f16[123]{0} add(arg.1, arg.2)
-   add.2 = f16[123]{0} add(arg.3, arg.4)
-   mul.1 = f16[123]{0} multiply(add.1, add.2)
-   ROOT tuple.1 = (f16[123]{0}, f16[123]{0}) tuple(mul.1, add.1)
- }
-
- ENTRY entry_computation {
-   arg.1 = f16[1024]{0} parameter(0)
-   arg.2 = f16[1024]{0} parameter(1)
-   arg.3 = f16[1024]{0} parameter(2)
-   arg.4 = f16[1024]{0} parameter(3)
-   arg.5 = f16[123]{0} parameter(4)
-   arg.6 = f16[123]{0} parameter(5)
-   arg.7 = f16[123]{0} parameter(6)
-   arg.8 = f16[123]{0} parameter(7)
-   fusion.1 = (f16[1024]{0}, f16[1024]{0})
-       fusion(arg.1, arg.2, arg.3, arg.4),
-           kind=kLoop, calls=fused_computation.1
-   fusion.2 = (f16[123]{0}, f16[123]{0})
-       fusion(arg.5, arg.6, arg.7, arg.8),
-           kind=kLoop, calls=fused_computation.2
-   gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=0
-   gte.2 = f16[1024]{0} get-tuple-element(fusion.1), index=1
-   gte.3 = f16[123]{0} get-tuple-element(fusion.2), index=0
-   gte.4 = f16[123]{0} get-tuple-element(fusion.2), index=1
-   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}, f16[123]{0}, f16[123]{0})
-       tuple(gte.1, gte.2, gte.3, gte.4)
- }
-)")
-                    .value();
-
-  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
-  TF_ASSERT_OK(verifier().Run(module.get()).status());
-  EXPECT_FALSE(HloDCE().Run(module.get()).value());
-
-  VLOG(2) << "Dump after horizontal fusion:";
-  VLOG(2) << module->ToString();
-
-  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
-}
-
-TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
-  HloComputation::Builder builder(TestName());
-
-  std::vector<HloInstruction*> all_outputs;
-  for (int64_t i = 0; i < 48; ++i) {
-    Shape shape = ShapeUtil::MakeShape(F32, {2, 1024 + i});
-    // ms <- grad**2 * (1 - rho) + ms * rho
-    HloInstruction* grad = builder.AddInstruction(
-        HloInstruction::CreateParameter(i * 9 + 0, shape, "grad"));
-    HloInstruction* ms = builder.AddInstruction(
-        HloInstruction::CreateParameter(i * 9 + 1, shape, "ms"));
-    HloInstruction* rho =
-        builder.AddInstruction(HloInstruction::CreateParameter(
-            i * 9 + 2, ShapeUtil::MakeShape(F32, {}), "rho"));
-    HloInstruction* one_minus_rho =
-        builder.AddInstruction(HloInstruction::CreateParameter(
-            i * 9 + 3, ShapeUtil::MakeShape(F32, {}), "one_minus_rho"));
-    HloInstruction* rho_broadcasted =
-        builder.AddInstruction(HloInstruction::CreateBroadcast(shape, rho, {}));
-    HloInstruction* one_minus_rho_broadcasted = builder.AddInstruction(
-        HloInstruction::CreateBroadcast(shape, one_minus_rho, {}));
-    HloInstruction* grad_squared = builder.AddInstruction(
-        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad, grad));
-    HloInstruction* ms_1st_term = builder.AddInstruction(
-        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad_squared,
-                                     one_minus_rho_broadcasted));
-    HloInstruction* ms_2nd_term =
-        builder.AddInstruction(HloInstruction::CreateBinary(
-            shape, HloOpcode::kMultiply, ms, rho_broadcasted));
-    HloInstruction* ms_out =
-        builder.AddInstruction(HloInstruction::CreateBinary(
-            shape, HloOpcode::kAdd, ms_1st_term, ms_2nd_term));
-
-    // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
-    HloInstruction* momentum = builder.AddInstruction(
-        HloInstruction::CreateParameter(i * 9 + 4, shape, "momemtum"));
-    HloInstruction* mom = builder.AddInstruction(
-        HloInstruction::CreateParameter(i * 9 + 5, shape, "mom"));
-    HloInstruction* lr = builder.AddInstruction(HloInstruction::CreateParameter(
-        i * 9 + 6, ShapeUtil::MakeShape(F32, {}), "lr"));
-    HloInstruction* epsilon =
-        builder.AddInstruction(HloInstruction::CreateParameter(
-            i * 9 + 7, ShapeUtil::MakeShape(F32, {}), "epsilon"));
-    HloInstruction* lr_broadcasted =
-        builder.AddInstruction(HloInstruction::CreateBroadcast(shape, lr, {}));
-    HloInstruction* epsilon_broadcasted = builder.AddInstruction(
-        HloInstruction::CreateBroadcast(shape, epsilon, {}));
-    HloInstruction* mom_1st_term =
-        builder.AddInstruction(HloInstruction::CreateBinary(
-            shape, HloOpcode::kMultiply, momentum, mom));
-    HloInstruction* ms_eps =
-        builder.AddInstruction(HloInstruction::CreateBinary(
-            shape, HloOpcode::kAdd, ms_out, epsilon_broadcasted));
-    HloInstruction* ms_eps_rsq = builder.AddInstruction(
-        HloInstruction::CreateUnary(shape, HloOpcode::kRsqrt, ms_eps));
-    HloInstruction* grad_ms_eps_rsq =
-        builder.AddInstruction(HloInstruction::CreateBinary(
-            shape, HloOpcode::kMultiply, grad, ms_eps_rsq));
-    HloInstruction* mom_2nd_term =
-        builder.AddInstruction(HloInstruction::CreateBinary(
-            shape, HloOpcode::kMultiply, lr_broadcasted, grad_ms_eps_rsq));
-    HloInstruction* mom_out =
-        builder.AddInstruction(HloInstruction::CreateBinary(
-            shape, HloOpcode::kAdd, mom_1st_term, mom_2nd_term));
-
-    // var <- var - mom
-    HloInstruction* var = builder.AddInstruction(
-        HloInstruction::CreateParameter(i * 9 + 8, shape, "var"));
-    HloInstruction* var_out =
-        builder.AddInstruction(HloInstruction::CreateBinary(
-            shape, HloOpcode::kSubtract, var, mom_out));
-
-    all_outputs.push_back(ms_out);
-    all_outputs.push_back(mom_out);
-    all_outputs.push_back(var_out);
-  }
-  builder.AddInstruction(HloInstruction::CreateTuple(all_outputs));
-
-  auto module = CreateNewVerifiedModule();
-  module->AddEntryComputation(builder.Build());
-
-  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
-}
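// Editorial sketch, not part of the original test: scalar form of the RMSProp
// update that the builder loop above constructs elementwise, assuming float
// inputs (std::sqrt from <cmath>):
//
//   struct RmsPropState { float ms, mom, var; };
//
//   RmsPropState RmsPropUpdate(RmsPropState s, float grad, float rho,
//                              float one_minus_rho, float momentum, float lr,
//                              float epsilon) {
//     s.ms = grad * grad * one_minus_rho + s.ms * rho;
//     s.mom = momentum * s.mom + lr * grad / std::sqrt(s.ms + epsilon);
//     s.var = s.var - s.mom;
//     return s;
//   }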
-
-TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule NegativeTestForDynamicUpdateSlice
-
-  fusion.1 {
-    p.0 = f16[5,9,10]{2,1,0} parameter(0)
-    p.1 = s32[] parameter(1)
-    p.2 = f16[1,9,10]{2,1,0} parameter(2)
-    c.0 = s32[] constant(0)
-    ROOT %dynamic-update-slice =
-        f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
-  }
-
-  fusion.2 {
-    p.0 = f16[5,9,10]{2,1,0} parameter(0)
-    p.1 = s32[] parameter(1)
-    p.2 = f16[1,9,10]{2,1,0} parameter(2)
-    c.0 = s32[] constant(0)
-    ROOT %dynamic-update-slice =
-        f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
-  }
-
-  ENTRY entry {
-    p.00 = f16[5,9,10]{2,1,0} parameter(0)
-    p.01 = f16[5,9,10]{2,1,0} parameter(1)
-    p.10 = s32[] parameter(2)
-    p.11 = s32[] parameter(3)
-    p.20 = f16[1,9,10]{2,1,0} parameter(4)
-    p.21 = f16[1,9,10]{2,1,0} parameter(5)
-
-    f1 = f16[5,9,10] fusion(p.00, p.10, p.20), kind=kLoop, calls=fusion.1
-    f2 = f16[5,9,10] fusion(p.01, p.11, p.21), kind=kLoop, calls=fusion.2
-    ROOT tuple = (f16[5,9,10],f16[5,9,10]) tuple(f1, f2)
-  })")
-                    .value();
-
-  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
-  TF_ASSERT_OK(verifier().Run(module.get()).status());
-  EXPECT_FALSE(HloDCE().Run(module.get()).value());
-
-  VLOG(2) << "Dump after horizontal fusion:";
-  VLOG(2) << module->ToString();
-
-  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
-}
-
-TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) {
-  auto module = ParseAndReturnVerifiedModule(R"(
- HloModule BasicTest
-
- fused_computation.1 {
-   arg.1 = f16[123]{0} parameter(0)
-   arg.2 = f16[123]{0} parameter(1)
-   ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
- }
-
- fused_computation.2 {
-   arg.1 = f16[123]{0} parameter(0)
-   arg.2 = f16[123]{0} parameter(1)
-   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
- }
-
- ENTRY entry_computation {
-   arg.1 = f16[123]{0} parameter(0)
-   // arg.2 is shared by fusion.1 and fusion.2
-   arg.2 = f16[123]{0} parameter(1)
-   arg.3 = f16[123]{0} parameter(2)
-   fusion.1 = f16[123]{0}
-       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
-   fusion.2 = f16[123]{0}
-       fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2
-   ROOT tuple.1 = (f16[123]{0}, f16[123]{0})
-       tuple(fusion.1, fusion.2)
- }
-)")
-                    .value();
-
-  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value());
-}
-
-TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) {
-  auto module = ParseAndReturnVerifiedModule(R"(
- HloModule NonfusionInstrs
-
- fused_computation.0 {
-   arg.0 = f16[] parameter(0)
-   arg.1 = f16[123]{0} parameter(1)
-   broadcast.0 = f16[123]{0} broadcast(arg.0), dimensions={}
-   ROOT mul.1 = f16[123]{0} multiply(broadcast.0, arg.1)
- }
-
- fused_computation.1 {
-   arg.0 = f16[] parameter(0)
-   arg.1 = f16[456]{0} parameter(1)
-   broadcast.0 = f16[456]{0} broadcast(arg.0), dimensions={}
-   ROOT add.1 = f16[456]{0} add(broadcast.0, arg.1)
- }
-
- ENTRY entry_computation {
-   arg.0 = f16[] parameter(0)
-   arg.1 = f16[] parameter(1)
-   arg.2 = f16[123]{0} parameter(2)
-   arg.3 = f16[456]{0} parameter(3)
-   // Test fusion of non-fusion instructions. sqrt.0 and sqrt.1 are to be
-   // fused.
-   sqrt.0 = f16[] sqrt(arg.0)
-   sqrt.1 = f16[] sqrt(arg.1)
-   // fusion.0 and fusion.1 are to be fused.
-   fusion.0 = f16[123]{0}
-       fusion(sqrt.0, arg.2), kind=kLoop, calls=fused_computation.0
-   fusion.1 = f16[456]{0}
-       fusion(sqrt.1, arg.3), kind=kLoop, calls=fused_computation.1
-   ROOT tuple.1 = (f16[123]{0}, f16[456]{0}) tuple(fusion.0, fusion.1)
- }
-)")
-                    .value();
-
-  HloPassFix<HloPassPipeline> iterative_h_fusion("iterative_h_fusion");
-  iterative_h_fusion.AddPass<GpuHorizontalLoopFusion>();
-  iterative_h_fusion.AddPass<HloDCE>();
-  EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value());
-
-  // Verify that fusion.0 and fusion.1 are fused.
-  const HloInstruction* entry_root =
-      module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(entry_root,
-              GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                  m::GetTupleElement(m::Fusion()))));
-  EXPECT_TRUE(fusion->IsMultiOutputFusion());
-
-  // Verify that the total number of fusion instructions is 2 so that we
-  // know sqrt.0 and sqrt.1 are fused.
-  EXPECT_EQ(
-      absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
-      2);
-}
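// Editorial note, not part of the original test: HloPassFix reruns the wrapped
// pipeline until it stops making changes. The iteration matters here because
// horizontal fusion groups producers that feed a common consumer: sqrt.0 and
// sqrt.1 only share a consumer after an earlier round has merged fusion.0 and
// fusion.1, so they get fused in a later round.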
-
-TEST_F(HorizontalLoopFusionTest, TraversalOrder) {
-  auto module = ParseAndReturnVerifiedModule(R"(
- HloModule cluster
-
- %fused_computation (param_0: f32[256,256], param_1: f32[], param_2: f32[])
-     -> f32[256,256] {
-   %param_0 = f32[256,256]{1,0} parameter(0)
-   %param_1 = f32[] parameter(1)
-   %param_2 = f32[] parameter(2)
-   %multiply.0 = f32[] multiply(f32[] %param_1, f32[] %param_2)
-   %broadcast.0 = f32[256,256]{1,0} broadcast(f32[] %multiply.0), dimensions={}
-   ROOT %multiply.1 = f32[256,256]{1,0}
-       multiply(f32[256,256]{1,0} %param_0, f32[256,256]{1,0} %broadcast.0)
- }
-
- %fused_computation.1 (param_0: f32[256,256], param_1: f32[], param_2: f32[])
-     -> f32[256,256] {
-   %param_0 = f32[256,256]{1,0} parameter(0)
-   %param_1 = f32[] parameter(1)
-   %param_2 = f32[] parameter(2)
-   %multiply.0 = f32[] multiply(f32[] %param_1, f32[] %param_2)
-   %broadcast.0 = f32[256,256]{1,0} broadcast(f32[] %multiply.0), dimensions={}
-   ROOT %multiply.1 = f32[256,256]{1,0}
-       multiply(f32[256,256]{1,0} %param_0, f32[256,256]{1,0} %broadcast.0)
- }
-
- ENTRY %entry_computation (arg0: f32[256,256], arg1: f32[256,256], arg2: f32[],
-                           arg3: f32[], arg4: f32[], arg5: f32[])
-                               -> (f32[256,256], f32[256,256]) {
-   %arg0 = f32[256,256]{1,0} parameter(0), parameter_replication={false}
-   %arg1 = f32[256,256]{1,0} parameter(1), parameter_replication={false}
-   %arg2 = f32[] parameter(2), parameter_replication={false}
-   %arg3 = f32[] parameter(3), parameter_replication={false}
-   %arg4 = f32[] parameter(4), parameter_replication={false}
-   %arg5 = f32[] parameter(5), parameter_replication={false}
-   %sqrt = f32[] sqrt(f32[] %arg2)
-   %sqrt.1 = f32[] sqrt(f32[] %arg3)
-   %fusion = f32[256,256]{1,0}
-       fusion(f32[256,256]{1,0} %arg0, f32[] %sqrt, f32[] %sqrt.1),
-       kind=kLoop, calls=%fused_computation
-   %sqrt.2 = f32[] sqrt(f32[] %arg4)
-   %sqrt.3 = f32[] sqrt(f32[] %arg5)
-   %fusion.1 = f32[256,256]{1,0}
-       fusion(f32[256,256]{1,0} %arg1, f32[] %sqrt.2, f32[] %sqrt.3),
-       kind=kLoop, calls=%fused_computation.1
-   ROOT %tuple.163 = (f32[256,256]{1,0}, f32[256,256]{1,0})
-       tuple(f32[256,256]{1,0} %fusion.1, f32[256,256]{1,0} %fusion)
- }
-)")
-                    .value();
-
-  HloPassFix<HloPassPipeline> iterative_h_fusion("iterative_h_fusion");
-  iterative_h_fusion.AddPass<GpuHorizontalLoopFusion>();
-  EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value());
-
-  // Verify that the total number of fusion instructions is 2 so that we
-  // know all the sqrt instructions are fused into a kernel. Note that if we
-// traverse from def-to-use (i.e., top-down) instead of use-to-def, we
-  // will end up having 3 fusions instead of 2.
-  EXPECT_EQ(
-      absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
-      2);
-}
-
-// Simplified reproducer for Google bug b/242287055.
-// Things that happened:
-//  - horizontal loop fusion joined addition a0 and multiplication m0
-//  - the resulting fusion had 4 inputs: (gte1, gte0, gte1, gte0)
-//  - buffer assignment aliased outputs of this fusion with its inputs
-//  - some threads simultaneously did the addition, others the multiplication
-//  - as a result some inputs were overwritten before being read
-// The conditional operation is meaningless (its branches are equivalent); it
-// exists only to confuse the buffer assignment.
-TEST_F(HorizontalLoopFusionTest, NoBufferAliasingOfDuplicateParameter) {
-  const char* hlo_text = R"(
-HloModule m
-
-branch_a {
-  p0 = s32[] parameter(0)
-  c0 = s32[] constant(1)
-  c1 = s32[] constant(2)
-  b0 = s32[4096] broadcast(c0), dimensions={}
-  b1 = s32[4096] broadcast(c1), dimensions={}
-  ROOT r = (s32[4096], s32[4096]) tuple(b0, b1)
-}
-
-branch_b {
-  p0 = s32[] parameter(0)
-  c0 = s32[] constant(1)
-  c1 = s32[] constant(2)
-  b0 = s32[4096] broadcast(c0), dimensions={}
-  b1 = s32[4096] broadcast(c1), dimensions={}
-  ROOT r = (s32[4096], s32[4096]) tuple(b0, b1)
-}
-
-ENTRY e {
-  p0 = s32[] parameter(0)
-  c0 = s32[] constant(0)
-  cond = (s32[4096], s32[4096]) conditional(p0, c0, c0), branch_computations={branch_a, branch_b}
-  p1 = s32[4096] parameter(1)
-  gte0 = s32[4096] get-tuple-element(cond), index=0
-  gte1 = s32[4096] get-tuple-element(cond), index=1
-  a0 = s32[4096] add(gte1, gte0)
-  m0 = s32[4096] multiply(gte1, gte0)
-  ROOT r = (s32[4096], s32[4096]) tuple(m0, a0)
-}
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt));
-}
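// Editorial note, not part of the original test: concretely, the fused kernel
// described above computes (add(x, y), multiply(x, y)) with the duplicated
// inputs (x, y) = (gte1, gte0). If buffer assignment aliases either output
// with x or y, threads writing the add result can overwrite elements that
// other threads still need to read for the multiply, which is the miscompare
// RunAndCompare would catch.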
-
-TEST_F(HorizontalLoopFusionTest, CopyInsertionFusionControlFlow) {
-  const char* hlo_text = R"(
-HloModule cluster
-
-ENTRY main {
-  cst = f32[1]{0} constant({0})
-  cp1 = f32[1]{0} copy(cst)
-  cp2 = f32[1]{0} copy(cst)
-  cp3 = f32[1]{0} copy(cst)
-  cp4 = f32[1]{0} copy(cst), control-predecessors={cp1}
-  ROOT tuple_out = (f32[1]{0}, f32[1]{0}, f32[1]{0}, f32[1]{0}) tuple(cp1, cp2, cp3, cp4)
-}
-)";
-
-  auto module = ParseAndReturnUnverifiedModule(hlo_text).value();
-  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value());
-
-  VLOG(2) << module->ToString();
-
-  // Verify that the total number of fusion instructions is 1.
-  EXPECT_EQ(
-      absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
-      1);
-
-  const HloInstruction* entry_root =
-      module->entry_computation()->root_instruction();
-  // Check that only the copies without control dependencies (cp2 and cp3) are
-  // fused; cp1 and cp4, which carry the control edge, stay unfused.
-  EXPECT_THAT(entry_root,
-              GmockMatch(m::Tuple(m::Copy(), m::GetTupleElement(m::Fusion()),
-                                  m::GetTupleElement(m::Fusion()), m::Copy())));
-}
-
-TEST_F(HorizontalLoopFusionTest, DoNotMergeVariadicReductions) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule m
-
-  fused_computation.94 {
-    tmp_0 = f32[] parameter(0)
-    tmp_1 = f32[] parameter(1)
-    tmp_2 = pred[] compare(tmp_0, tmp_1), direction=GE
-    tmp_3 = f32[] select(tmp_2, tmp_0, tmp_1)
-    tmp_4 = pred[] compare(tmp_0, tmp_1), direction=EQ
-    tmp_5 = s32[] parameter(2)
-    tmp_6 = s32[] parameter(3)
-    tmp_7 = s32[] minimum(tmp_5, tmp_6)
-    tmp_8 = s32[] select(tmp_2, tmp_5, tmp_6)
-    tmp_9 = s32[] select(tmp_4, tmp_7, tmp_8)
-    ROOT tmp_10 = (f32[], s32[]) tuple(tmp_3, tmp_9)
-  }
-
-  minmax_func.1536 {
-    tmp_0 = f32[] parameter(0)
-    tmp_1 = f32[] parameter(2)
-    tmp_2 = s32[] parameter(1)
-    tmp_3 = s32[] parameter(3)
-    ROOT tmp_4 = (f32[], s32[]) fusion(tmp_0, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=fused_computation.94
-  }
-
-  fused_computation {
-    tmp_0 = f32[554112,10]{1,0} parameter(0)
-    tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
-    tmp_2 = f32[] constant(-inf)
-    tmp_3 = s32[] constant(0)
-    ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
-  }
-
-  fused_computation2 {
-    tmp_0 = f32[554112,10]{1,0} parameter(0)
-    tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
-    tmp_2 = f32[] constant(inf)
-    tmp_3 = s32[] constant(1)
-    ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
-  }
-
-  ENTRY e {
-    tmp_0 = f32[554112,10]{1,0} parameter(0)
-    tmp_1 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_0), kind=kLoop, calls=fused_computation
-    tmp_2 = s32[554112]{0} get-tuple-element(tmp_1), index=1
-    tmp_3 = f32[554112,10]{1,0} parameter(1)
-    tmp_4 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_3), kind=kLoop, calls=fused_computation2
-    tmp_5 = s32[554112]{0} get-tuple-element(tmp_4), index=1
-    ROOT tmp_6 = s32[554112]{0} add(tmp_2, tmp_5)
-  })")
-                    .value();
-
-  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value());
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/instruction_fusion.cc b/third_party/xla/xla/service/gpu/instruction_fusion.cc
deleted file mode 100644
index 8751d44..0000000
--- a/third_party/xla/xla/service/gpu/instruction_fusion.cc
+++ /dev/null
@@ -1,187 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/instruction_fusion.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/meta/type_traits.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/fusion_node_indexing_evaluation.h"
-#include "xla/service/fusion_queue.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/shape.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-bool ElementIsF32OrF16(const Shape& shape) {
-  PrimitiveType type = shape.element_type();
-  return type == F32 || type == F16;
-}
-
-class EmptyFusionQueue : public FusionQueue {
- public:
-  std::pair<HloInstruction*, std::vector<int64_t>>
-  DequeueNextInstructionAndOperandsToFuseInOrder() override {
-    return {nullptr, {}};
-  }
-  void RemoveInstruction(HloInstruction* instruction) override {}
-  const std::vector<bool>* FusionConfiguration() override { return nullptr; }
-};
-
-}  // namespace
-
-absl::StatusOr<bool> GpuInstructionFusion::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  fusion_node_evaluations_.clear();
-  auto fusible_computations =
-      GetFusibleComputations(*module, execution_threads);
-  fusible_computations_ = {fusible_computations.begin(),
-                           fusible_computations.end()};
-  return InstructionFusion::Run(module, execution_threads);
-}
-
-/*static*/ bool GpuInstructionFusion::IsExpensive(
-    const HloInstruction& instruction) {
-  // Some floating-point math ops are cheap on the GPU.
-  switch (instruction.opcode()) {
-    case HloOpcode::kDivide:
-    case HloOpcode::kSqrt:
-    case HloOpcode::kRsqrt:
-    case HloOpcode::kExp:
-      if (ElementIsF32OrF16(instruction.shape())) {
-        return false;
-      }
-      break;
-    default:
-      break;
-  }
-  return InstructionFusion::IsExpensive(instruction);
-}
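// Editorial example, not part of the original file: under this policy an
// f32/f16 divide, sqrt, rsqrt, or exp is treated as cheap and may be
// duplicated into several consumers, while the same ops on other element types
// (e.g. an s32 divide) fall through to InstructionFusion::IsExpensive and stay
// expensive. The FloatingPointDivIsCheap and IntegerDivIsNotCheap tests in the
// deleted instruction_fusion_test.cc below exercise exactly this difference.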
-
-FusionDecision GpuInstructionFusion::ShouldFuseInexpensiveChecks(
-    HloInstruction* consumer, int64_t operand_index) {
-  HloInstruction* producer = consumer->mutable_operand(operand_index);
-
-  // Output fusions are not currently supported on GPUs.
-  if (producer->opcode() == HloOpcode::kFusion) {
-    return "the producer is a fusion";
-  }
-
-  if (consumer->IsCustomFusion()) {
-    return "the consumer is a custom fusion";
-  }
-
-  // Cost condition: do not fuse an expensive producer into a consumer that
-  // reuses the producer's elements.
-  if (is_expensive(*producer) &&
-      ReusesOperandElements(consumer, operand_index)) {
-    return "the producer is expensive, and the consumer reuses inputs";
-  }
-
-  // Do not fuse a physically transposing producer into an input-fusible
-  // reduction: the resulting kernel would suffer from uncoalesced reads.
-  if (IsInputFusibleReduction(*consumer) &&
-      IsPhysicallyTransposing(*producer)) {
-    return "fusing the producer would break read coalescing";
-  }
-
-  RETURN_IF_NOT_FUSIBLE(IsProducerConsumerFusible(*producer, *consumer));
-
-  if (CreatesHeavyComputation(*producer, *consumer)) {
-    return "the fusion would create a heavy computation";
-  }
-
-  return InstructionFusion::ShouldFuse(consumer, operand_index);
-}
-
-FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
-                                                int64_t operand_index) {
-  RETURN_IF_NOT_FUSIBLE(ShouldFuseInexpensiveChecks(consumer, operand_index));
-
-  auto producer = consumer->operand(operand_index);
-
-  // The following checks are potentially expensive.
-  RETURN_IF_NOT_FUSIBLE(
-      FusionFitsInBudget(*consumer, *producer, device_info_,
-                         /*is_consumer_producer_fusion=*/true));
-
-  if (consumer->opcode() != HloOpcode::kFusion) {
-    return {};
-  }
-
-  // Also check that our emitter can handle the fusion node. We currently can
-  // have exponential time/memory requirements for emitting certain fusion
-  // kernels, in which case we don't want to fuse.
-  // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
-  if (fusion_node_evaluations_.find(consumer) ==
-      fusion_node_evaluations_.end()) {
-    // We have no cached results for this fusion node yet. This can happen when
-    // we run the InstructionFusion pass more than once. We can only cache the
-    // results within one run.
-    fusion_node_evaluations_.emplace(consumer,
-                                     FusionNodeIndexingEvaluation(consumer));
-  }
-  if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) {
-    return "the fusion would result in an overly large code duplication";
-  }
-  return {};
-}
-
-HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
-    const HloInstruction* producer, const HloInstruction* consumer) {
-  return ChooseFusionKind(*producer, *consumer);
-}
-
-HloInstruction* GpuInstructionFusion::FuseInstruction(
-    HloInstruction* fusion_instruction, HloInstruction* producer) {
-  auto evaluation = fusion_node_evaluations_.find(fusion_instruction);
-  if (evaluation == fusion_node_evaluations_.end()) {
-    evaluation = fusion_node_evaluations_
-                     .emplace(fusion_instruction,
-                              FusionNodeIndexingEvaluation(fusion_instruction))
-                     .first;
-  }
-  auto indexing_users = evaluation->second.RemoveFusionOperand(producer);
-  HloInstruction* new_producer =
-      InstructionFusion::FuseInstruction(fusion_instruction, producer);
-  evaluation->second.UpdateEvaluationCache(new_producer, indexing_users);
-  return new_producer;
-}
-
-std::unique_ptr<FusionQueue> GpuInstructionFusion::GetFusionQueue(
-    HloComputation* computation) {
-  if (fusible_computations_.contains(computation)) {
-    return InstructionFusion::GetFusionQueue(computation);
-  }
-  return std::make_unique<EmptyFusionQueue>();
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/instruction_fusion.h b/third_party/xla/xla/service/gpu/instruction_fusion.h
deleted file mode 100644
index 29eb032..0000000
--- a/third_party/xla/xla/service/gpu/instruction_fusion.h
+++ /dev/null
@@ -1,82 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_
-#define XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_
-
-#include <stdint.h>
-
-#include <memory>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/fusion_node_indexing_evaluation.h"
-#include "xla/service/fusion_queue.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-class GpuInstructionFusion : public InstructionFusion {
- public:
-  GpuInstructionFusion(bool may_duplicate, const se::DeviceDescription& d)
-      : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate),
-        device_info_(d) {}
-
-  static bool IsExpensive(const HloInstruction& instruction);
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- protected:
-  std::unique_ptr<FusionQueue> GetFusionQueue(
-      HloComputation* computation) override;
-  FusionDecision ShouldFuse(HloInstruction* consumer,
-                            int64_t operand_index) override;
-
-  HloInstruction::FusionKind ChooseKind(
-      const HloInstruction* producer, const HloInstruction* consumer) override;
-
- private:
-  // Called by ShouldFuse() to run the computationally inexpensive checks on
-  // whether the operand should be fused into 'consumer'.
-  FusionDecision ShouldFuseInexpensiveChecks(HloInstruction* consumer,
-                                             int64_t operand_index);
-
-  HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
-                                  HloInstruction* producer) override;
-
-  absl::flat_hash_set<const HloComputation*> fusible_computations_;
-
-  // Keeps track of the number of times each instruction inside a fusion node
-  // is indexed with different index vectors.
-  absl::flat_hash_map<const HloInstruction*, FusionNodeIndexingEvaluation>
-      fusion_node_evaluations_;
-
-  se::DeviceDescription device_info_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_
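// Editorial sketch, not part of the original header: typical usage, mirroring
// the HorizontalLoopFusionAfterVerticalFusion test in the deleted test file,
// runs the pass twice in a pipeline, first without and then with duplication.
// `device_info` is assumed to be an se::DeviceDescription obtained elsewhere:
//
//   xla::HloPassPipeline pipeline("fusion");
//   pipeline.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/false,
//                                                    device_info);
//   pipeline.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/true,
//                                                    device_info);
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));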
diff --git a/third_party/xla/xla/service/gpu/instruction_fusion_test.cc b/third_party/xla/xla/service/gpu/instruction_fusion_test.cc
deleted file mode 100644
index fa96edf..0000000
--- a/third_party/xla/xla/service/gpu/instruction_fusion_test.cc
+++ /dev/null
@@ -1,1006 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/instruction_fusion.h"
-
-#include <cstdint>
-#include <memory>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/literal_util.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_utils.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace m = ::xla::match;
-
-namespace xla {
-namespace gpu {
-
-class InstructionFusionTest : public HloTestBase {
- public:
-  GpuInstructionFusion duplicating_instruction_fusion_{
-      /*may_duplicate=*/true, TestGpuDeviceInfo::RTXA6000DeviceInfo()};
-};
-
-TEST_F(InstructionFusionTest, NoFusionIntoCustomFusionConsumer) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
-                          ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-c {
-  p0 = bf16[3000,53]{1,0} parameter(0)
-  p1 = bf16[22,53]{1,0} parameter(1)
-  d = bf16[3000,22]{1,0} dot(p0, p1),
-    lhs_contracting_dims={1}, rhs_contracting_dims={1}
-  r = bf16[1,1,3000,22]{3,2,1,0} reshape(d)
-  ROOT c = bf16[1,1,3000,22]{2,1,3,0} copy(r)
-}
-
-ENTRY e {
-  p1 = bf16[3000,53]{1,0} parameter(1)
-  p0 = bf16[22,53]{1,0} parameter(0)
-  cp0 = bf16[22,53]{1,0} convert(p0)
-  ROOT f = bf16[1,1,3000,22]{2,1,3,0} fusion(p1, cp0), kind=kCustom, calls=c
-})"));
-
-  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest,
-       CostlyProducerAndOperandElementReusingConsumerNotFused) {
-  HloComputation::Builder builder(TestName());
-  HloInstruction* const0 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
-  HloInstruction* log1 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(F32, {}), HloOpcode::kLog, const0));
-  HloInstruction* broadcast2 =
-      builder.AddInstruction(HloInstruction::CreateBroadcast(
-          ShapeUtil::MakeShape(F32, {1}), log1, {}));
-
-  auto module = CreateNewVerifiedModule();
-  auto computation = module->AddEntryComputation(builder.Build());
-  EXPECT_EQ(broadcast2, computation->root_instruction());
-  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-  EXPECT_EQ(broadcast2, computation->root_instruction());
-}
-
-TEST_F(InstructionFusionTest,
-       NonCostlyProducerAndOperandElementReusingConsumerFused) {
-  HloComputation::Builder builder(TestName());
-  HloInstruction* const0 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
-  HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0));
-  HloInstruction* broadcast2 =
-      builder.AddInstruction(HloInstruction::CreateBroadcast(
-          ShapeUtil::MakeShape(S32, {1}), negate1, {}));
-
-  auto module = CreateNewVerifiedModule();
-  auto computation = module->AddEntryComputation(builder.Build());
-  EXPECT_EQ(broadcast2, computation->root_instruction());
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Fusion()));
-}
-
-TEST_F(InstructionFusionTest,
-       CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) {
-  HloComputation::Builder builder(TestName());
-  HloInstruction* const0 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
-  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(F32, {}), HloOpcode::kExp, const0));
-  HloInstruction* reshape2 = builder.AddInstruction(
-      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), exp1));
-
-  auto module = CreateNewVerifiedModule();
-  auto computation = module->AddEntryComputation(builder.Build());
-  EXPECT_EQ(reshape2, computation->root_instruction());
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Fusion()));
-}
-
-TEST_F(InstructionFusionTest,
-       CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) {
-  HloComputation::Builder builder(TestName());
-  HloInstruction* const0 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
-  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(F32, {}), HloOpcode::kExp, const0));
-  HloInstruction* transpose2 = builder.AddInstruction(
-      HloInstruction::CreateTranspose(ShapeUtil::MakeShape(F32, {}), exp1, {}));
-
-  auto module = CreateNewVerifiedModule();
-  auto computation = module->AddEntryComputation(builder.Build());
-  EXPECT_EQ(transpose2, computation->root_instruction());
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Fusion()));
-}
-
-TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotFused) {
-  HloComputation::Builder builder(TestName());
-  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
-      0, ShapeUtil::MakeShape(F32, {1, 1}), "0"));
-  auto dot1 = builder.AddInstruction(
-      CreateCanonicalDot(ShapeUtil::MakeShape(F32, {1, 1}), param0, param0));
-  auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
-      ShapeUtil::MakeShape(F32, {1, 1, 1}), dot1));
-  auto log = builder.AddInstruction(HloInstruction::CreateUnary(
-      reshape2->shape(), xla::HloOpcode::kLog, reshape2));
-
-  auto module = CreateNewVerifiedModule();
-  auto computation = module->AddEntryComputation(builder.Build());
-  EXPECT_EQ(log, computation->root_instruction());
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
-  HloComputation::Builder builder(TestName());
-  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
-      0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
-  auto dot1 = builder.AddInstruction(
-      CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
-  auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
-      ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
-
-  auto module = CreateNewVerifiedModule();
-  auto computation = module->AddEntryComputation(builder.Build());
-  EXPECT_EQ(transpose2, computation->root_instruction());
-  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-// Tests that broadcasts are fused into a fusion with a reduce root.
-TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-
-    ENTRY BroadcastIntoReduce {
-      constant = f32[] constant(1)
-      broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), dimensions={}
-      constant.1 = f32[] constant(0)
-      ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3},
-                                                         to_apply=add
-    })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  ASSERT_THAT(root, GmockMatch(m::Fusion()));
-  EXPECT_THAT(
-      root->fused_expression_root(),
-      GmockMatch(m::Reduce(m::Broadcast(m::Constant()), m::Constant())));
-}
-
-TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduce) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-
-    ENTRY entry {
-      p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
-      copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
-      constant.1 = f32[] constant(0)
-      ROOT reduce = f32[16] reduce(copy, constant.1), dimensions={0,1,2}, to_apply=add
-    })")
-                    .value();
-
-  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-
-    fused_reduce {
-      p0.1 = f32[16,16,16,16]{0,1,2,3} parameter(0)
-      mul = f32[16,16,16,16]{0,1,2,3} multiply(p0.1, p0.1)
-      c0.1 = f32[] constant(0)
-      ROOT root = f32[] reduce(mul, c0.1), dimensions={0,1,2,3}, to_apply=add
-    }
-
-    ENTRY entry {
-      p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
-      copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
-      fusion = f32[] fusion(copy), kind=kInput, calls=fused_reduce
-      ROOT root = (f32[]) tuple(fusion)
-    })")
-                    .value();
-
-  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, DoNotRepeatLargeReduceWindow) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-
-    ENTRY entry {
-      p0 = s32[512,512,2] parameter(0)
-      p1 = f32[1,1,512,512] parameter(1)
-      constant_1 = f32[] constant(1)
-      reduce-window.1 = reduce-window(p1, constant_1),
-        window={size=1x1x9x9}, to_apply=add
-      ROOT ret = gather(reduce-window.1, p0), offset_dims={0,1,2,3},
-        collapsed_slice_dims={}, start_index_map={1,2},
-        index_vector_dim=2, slice_sizes={1,1,1,1}
-    })")
-                    .value();
-
-  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-    ENTRY entry {
-      p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
-      copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
-      ROOT add = f32[16,16,16,16]{0,1,2,3} add(copy, copy)
-    })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  ASSERT_THAT(root, GmockMatch(m::Fusion()));
-  EXPECT_THAT(root->fused_expression_root(),
-              GmockMatch(m::Add(m::Copy(), m::Copy())));
-}
-
-TEST_F(InstructionFusionTest, BitcastIntoAdd) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    ENTRY BroadcastIntoAdd {
-      p0 = f32[4,1,1]{2,1,0} parameter(0)
-      p1 = f32[4,1]{1,0} parameter(1)
-      bitcast = f32[4,1]{1,0} bitcast(p0)
-      ROOT add = f32[4,1] add(bitcast, p1)
-    })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  ASSERT_THAT(root, GmockMatch(m::Fusion()));
-  EXPECT_THAT(root->fused_expression_root(),
-              GmockMatch(m::Add(m::Bitcast(m::Parameter()), m::Parameter())));
-}
-
-TEST_F(InstructionFusionTest, AddIntoBitcast) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    ENTRY BroadcastIntoAdd {
-      p0 = f32[4,1]{1,0} parameter(0)
-      p1 = f32[4,1]{1,0} parameter(1)
-      add = f32[4,1] add(p0, p1)
-      ROOT bitcast = f32[4,1,1] bitcast(add)
-    })")
-                    .value();
-
-  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, ConvertIntoBitcastBothConsumedByTuple) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test
-
-  ENTRY main {
-    param_0 = f32[2048,16000]{1,0} parameter(0)
-    convert = bf16[2048,16000]{1,0} convert(param_0)
-    bitcast = bf16[16000,1,2048]{2,1,0} bitcast(convert)
-    ROOT tuple.143 = (bf16[16000,1,2048]{2,1,0}, bf16[2048,16000]{1,0}) tuple(bitcast, convert)
-  })")
-                    .value();
-  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, DontFuseGTE) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test_module
-  ENTRY DontFuseGTE {
-    p0 = (f32[10], f32[10]) parameter(0)
-    gte0 = f32[10] get-tuple-element(p0), index=0
-    gte1 = f32[10] get-tuple-element(p0), index=1
-    ROOT add = f32[10] add(gte0, gte1)
-  })")
-                    .value();
-
-  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-// Compute sum(p1/p0), where p0 and p1 have type f32, twice.  Check that the
-// division is duplicated and fused into both reduces.
-TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test_module
-  Add {
-    lhs = f32[] parameter(0)
-    rhs = f32[] parameter(1)
-    ROOT add = f32[] add(lhs, rhs)
-  }
-  ENTRY TestComputation {
-    zero = f32[] constant(0)
-    p0 = f32[100] parameter(0)
-    p1 = f32[100] parameter(1)
-    recip = f32[100] divide(p1, p0)
-    sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
-    sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
-    ROOT root = (f32[], f32[]) tuple(sum1, sum2)
-  })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Fusion(), m::Fusion())))
-      << module->ToString();
-}
-
-// Compute sum(p1/p0), where p0 and p1 have type s32, twice.  Check that the
-// division is *not* duplicated and fused into both reduces, because we say
-// that integer division is not cheap.
-TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test_module
-  Add {
-    lhs = s32[] parameter(0)
-    rhs = s32[] parameter(1)
-    ROOT add = s32[] add(lhs, rhs)
-  }
-  ENTRY TestComputation {
-    zero = s32[] constant(0)
-    p0 = s32[100] parameter(0)
-    p1 = s32[100] parameter(1)
-    recip = s32[100] divide(p1, p0)
-    sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add
-    sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add
-    ROOT mul = (s32[], s32[]) tuple(sum1, sum2)
-  })")
-                    .value();
-
-  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value())
-      << module->ToString();
-}
-
-TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test_module
-  ENTRY NoOutputFusion {
-    alpha = f32[] constant(3)
-    broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
-    p0 = f32[4,3]{1,0} parameter(0)
-    p1 = f32[3,4]{1,0} parameter(1)
-    dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-    d = f32[4,4]{1,0} multiply(dot, dot)
-    ROOT mul = f32[4,4] multiply(d, broadcast)
-  })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  ASSERT_THAT(root, GmockMatch(m::Fusion()));
-  EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
-  EXPECT_THAT(
-      root->fused_expression_root(),
-      GmockMatch(m::Multiply(m::Multiply(m::Parameter(), m::Parameter()),
-                             m::Broadcast(m::Constant()))));
-}
-
-// Counts the HLO ops with a given op code in the specified module.
-static int Count(const HloModule& module, HloOpcode op) {
-  int count = 0;
-  for (const auto* computation : module.computations()) {
-    for (const auto* instruction : computation->instructions()) {
-      if (instruction->opcode() == op) {
-        ++count;
-      }
-    }
-  }
-  return count;
-}
-
-TEST_F(InstructionFusionTest, MultiOutputFusion) {
-  // sub --> add --> tuple
-  //  \---------------/
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-    ENTRY OutputFusion {
-     p0 = f32[4,3]{1,0} parameter(0)
-     p1 = f32[4,3]{1,0} parameter(1)
-     p2 = f32[4,3]{1,0} parameter(2)
-     sub = f32[4,3]{1,0} subtract(p0, p2)
-     add = f32[4,3]{1,0} add(sub, p1)
-     ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add)
-    })")
-                    .value();
-
-  // Multi-output fusion is disabled here and performed in the
-  // GpuMultiOutputFusion pass instead.
-  ASSERT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-}
-
-TEST_F(InstructionFusionTest, FuseScalarConstant) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test_module
-
-  ENTRY FuseScalarConstant {
-    p0 = f32[] parameter(0)
-    c0 = f32[] constant(1)
-    add1 = f32[] add(p0, c0)
-    b0 = f32[2]{0} broadcast(add1), dimensions={}
-    c1 = f32[2]{0} constant({1, 2})
-    ROOT add2 = f32[2]{0} add(b0, c1)
-  })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  ASSERT_THAT(root, GmockMatch(m::Fusion()));
-  EXPECT_THAT(
-      root->fused_expression_root(),
-      GmockMatch(m::Add(m::Broadcast(m::Add(m::Parameter(), m::Constant())),
-                        m::Parameter())));
-}
-
-// Check that we limit the number of operands to fusions we create.
-TEST_F(InstructionFusionTest, AvoidsLargeFusion) {
-  constexpr int64_t kNumParams = 200;
-  ASSERT_GT(kNumParams, MaxOperandsAndOutputsPerFusion());
-
-  // Compute p0 + p1 + ... + pN.
-  HloComputation::Builder b(TestName());
-  Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
-  auto param0 =
-      b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
-  auto sum = param0;
-  for (int64_t i = 1; i < kNumParams; ++i) {
-    auto param =
-        b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p"));
-    sum = b.AddInstruction(
-        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param));
-  }
-  auto module = CreateNewVerifiedModule();
-  auto computation = module->AddEntryComputation(b.Build());
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  for (const HloInstruction* instr : computation->instructions()) {
-    EXPECT_LE(instr->operand_count(), MaxOperandsAndOutputsPerFusion())
-        << instr->ToString();
-  }
-}
-
-TEST_F(InstructionFusionTest, FuseIntoScatter) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      lhs = s32[] parameter(0)
-      rhs = s32[] parameter(1)
-      ROOT add = s32[] add(lhs, rhs)
-    }
-
-    ENTRY FuseIntoScatter {
-      p0 = s32[3,3] parameter(0)
-      p1 = s32[2] parameter(1)
-      indices = s32[2] add(p1, p1)
-      p2 = s32[2,3] parameter(2)
-      updates = s32[2,3] add(p2, p2)
-      scatter = s32[3,3] scatter(p0, indices, updates),
-          to_apply=add,
-          update_window_dims={1},
-          inserted_window_dims={0},
-          scatter_dims_to_operand_dims={0},
-          index_vector_dim=1
-      ROOT add = s32[3,3] add(scatter, scatter)
-    })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::Add(m::Fusion(&fusion), m::Fusion())));
-  EXPECT_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kInput);
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
-}
-
-TEST_F(InstructionFusionTest, DontFuseIntoFirstOperandOfScatter) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      lhs = s32[] parameter(0)
-      rhs = s32[] parameter(1)
-      ROOT add = s32[] add(lhs, rhs)
-    }
-
-    ENTRY FuseIntoScatter {
-      p0 = s32[3,3] parameter(0)
-      operand = s32[3,3] add(p0, p0)
-      p1 = s32[2] parameter(1)
-      indices = s32[2] add(p1, p1)
-      p2 = s32[2,3] parameter(2)
-      updates = s32[2,3] add(p2, p2)
-      scatter = s32[3,3] scatter(operand, indices, updates),
-          to_apply=add,
-          update_window_dims={1},
-          inserted_window_dims={0},
-          scatter_dims_to_operand_dims={0},
-          index_vector_dim=1
-      ROOT add = s32[3,3] add(scatter, scatter)
-    })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::Add(m::Fusion(&fusion), m::Fusion())));
-  EXPECT_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kInput);
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
-}
-
-TEST_F(InstructionFusionTest, ScatterOpShouldNotFuseWithSharedOperand) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test_module
-
-  add {
-    lhs = f32[] parameter(0)
-    rhs = f32[] parameter(1)
-    ROOT add = f32[] add(lhs, rhs)
-  }
-
-  ENTRY Test {
-    parameter.0 = f32[8,8] parameter(0)
-    parameter.1 = s32[7] parameter(1)
-    indices = s32[7] add(parameter.1, parameter.1)
-    slice = f32[7,8] slice(parameter.0), slice={[0:7],[0:8]}
-    ROOT scatter = f32[8,8] scatter(parameter.0, indices, slice),
-        to_apply=add,
-        update_window_dims={1},
-        inserted_window_dims={0},
-        scatter_dims_to_operand_dims={0},
-        index_vector_dim=1
-  })")
-                    .value();
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-  // Verify that we don't fuse scatter and slice together since
-  // scatter modifies the input buffer in-place, which is also used
-  // as slice's input, and we don't know where the scatter indices point to.
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(
-      root, GmockMatch(m::Fusion(m::Parameter(), m::Slice(), m::Parameter())));
-}
-
-TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-
-    ENTRY BroadcastIntoReduce {
-      constant = f32[16] constant({0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15})
-      broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), dimensions={0}
-      constant.1 = f32[] constant(0)
-      ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3},
-                                                         to_apply=add
-    })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-  // The f32[16] constant should not be fused into the reduce, but the f32[]
-  // constant should be.
-  auto* root = module->entry_computation()->root_instruction();
-  ASSERT_THAT(root, GmockMatch(m::Fusion()));
-  EXPECT_THAT(
-      root->fused_instructions_computation()->root_instruction(),
-      GmockMatch(m::Reduce(m::Broadcast(m::Parameter()), m::Constant())));
-}
-
-TEST_F(InstructionFusionTest, FuseReverse) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    ENTRY Reverse {
-      p0 = f32[50,96,1024]{2,1,0} parameter(0)
-      add = f32[50,96,1024]{2,1,0} add(p0, p0)
-      ROOT reverse = f32[50,96,1024] reverse(add), dimensions={0}
-    })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  ASSERT_THAT(root, GmockMatch(m::Fusion()));
-  EXPECT_THAT(root->fused_expression_root(),
-              GmockMatch(m::Reverse(m::Add(m::Parameter(), m::Parameter()))));
-}
-
-TEST_F(InstructionFusionTest, GpuIsExpensiveF32) {
-  auto m = CreateNewVerifiedModule();
-  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
-  HloComputation::Builder builder(TestName());
-  HloInstruction* param0 = builder.AddInstruction(
-      HloInstruction::CreateParameter(0, r0f32, "param0"));
-
-  HloInstruction* one = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
-  HloInstruction* div = builder.AddInstruction(
-      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
-  HloInstruction* rem = builder.AddInstruction(
-      HloInstruction::CreateBinary(r0f32, HloOpcode::kRemainder, param0, one));
-  HloInstruction* sqrt = builder.AddInstruction(
-      HloInstruction::CreateUnary(r0f32, HloOpcode::kSqrt, param0));
-  HloInstruction* rsqrt = builder.AddInstruction(
-      HloInstruction::CreateUnary(r0f32, HloOpcode::kRsqrt, param0));
-  HloInstruction* exp = builder.AddInstruction(
-      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
-
-  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*div));
-  EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*rem));
-  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*sqrt));
-  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*rsqrt));
-  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*exp));
-}
-
-TEST_F(InstructionFusionTest, GpuIsExpensiveF64) {
-  auto m = CreateNewVerifiedModule();
-  Shape r0f64 = ShapeUtil::MakeShape(F64, {});
-  HloComputation::Builder builder(TestName());
-  HloInstruction* param0 = builder.AddInstruction(
-      HloInstruction::CreateParameter(0, r0f64, "param0"));
-
-  HloInstruction* one = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
-  HloInstruction* div = builder.AddInstruction(
-      HloInstruction::CreateBinary(r0f64, HloOpcode::kDivide, param0, one));
-  HloInstruction* rem = builder.AddInstruction(
-      HloInstruction::CreateBinary(r0f64, HloOpcode::kRemainder, param0, one));
-  HloInstruction* sqrt = builder.AddInstruction(
-      HloInstruction::CreateUnary(r0f64, HloOpcode::kSqrt, param0));
-  HloInstruction* rsqrt = builder.AddInstruction(
-      HloInstruction::CreateUnary(r0f64, HloOpcode::kRsqrt, param0));
-  HloInstruction* exp = builder.AddInstruction(
-      HloInstruction::CreateUnary(r0f64, HloOpcode::kExp, param0));
-
-  EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*div));
-  EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*rem));
-  EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*sqrt));
-  EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*rsqrt));
-  EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*exp));
-}
-
-TEST_F(InstructionFusionTest, GpuIsExpensiveS32) {
-  auto m = CreateNewVerifiedModule();
-  Shape r0s32 = ShapeUtil::MakeShape(S32, {});
-  HloComputation::Builder builder(TestName());
-  HloInstruction* param0 = builder.AddInstruction(
-      HloInstruction::CreateParameter(0, r0s32, "param0"));
-
-  HloInstruction* one = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
-  HloInstruction* div = builder.AddInstruction(
-      HloInstruction::CreateBinary(r0s32, HloOpcode::kDivide, param0, one));
-  HloInstruction* rem = builder.AddInstruction(
-      HloInstruction::CreateBinary(r0s32, HloOpcode::kRemainder, param0, one));
-
-  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*div));
-  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*rem));
-}
-
-TEST_F(InstructionFusionTest, GpuIsExpensiveBroadcastS32) {
-  auto m = CreateNewVerifiedModule();
-  Shape r1s32 = ShapeUtil::MakeShape(S32, {10});
-  HloComputation::Builder builder(TestName());
-  HloInstruction* param0 = builder.AddInstruction(
-      HloInstruction::CreateParameter(0, r1s32, "param0"));
-
-  HloInstruction* one = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
-  HloInstruction* one_broad =
-      builder.AddInstruction(HloInstruction::CreateBroadcast(r1s32, one, {}));
-
-  HloInstruction* div = builder.AddInstruction(HloInstruction::CreateBinary(
-      r1s32, HloOpcode::kDivide, param0, one_broad));
-  HloInstruction* rem = builder.AddInstruction(HloInstruction::CreateBinary(
-      r1s32, HloOpcode::kRemainder, param0, one_broad));
-
-  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*div));
-  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*rem));
-}
-
-TEST_F(InstructionFusionTest, FloatingPointExpIsCheap) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test_module
-  Add {
-    lhs = f32[] parameter(0)
-    rhs = f32[] parameter(1)
-    ROOT add = f32[] add(lhs, rhs)
-  }
-  ENTRY TestComputation {
-    zero = f32[] constant(0)
-    p0 = f32[100] parameter(0)
-    recip = f32[100] exponential(p0)
-    sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
-    sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
-    ROOT root = (f32[], f32[]) tuple(sum1, sum2)
-  })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Fusion(), m::Fusion())))
-      << module->ToString();
-}
-
-TEST_F(InstructionFusionTest, SmallReducedDimensionIsNotLoweredToLoop) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      lhs = s32[] parameter(0)
-      rhs = s32[] parameter(1)
-      ROOT add = s32[] add(lhs, rhs)
-    }
-
-    ENTRY FuseSmallReduction {
-      p0 = s32[1048576,4] parameter(0)
-      p1 = s32[1048576,4] parameter(1)
-      sum = s32[1048576,4] add(p0, p1)
-      init = s32[] constant(0)
-      ROOT reduce = s32[1048576] reduce(sum, init), dimensions={1}, to_apply=add
-    })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  ASSERT_THAT(root, GmockMatch(m::Fusion()));
-  EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kInput);
-}
-
-TEST_F(InstructionFusionTest, IotaIntoVariadicReduction) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule m
-
-  f {
-    tmp_0 = f32[] parameter(0)
-    tmp_1 = f32[] parameter(1)
-    tmp_2 = pred[] compare(tmp_0, tmp_1), direction=GE
-    tmp_3 = f32[] select(tmp_2, tmp_0, tmp_1)
-    tmp_4 = pred[] compare(tmp_0, tmp_1), direction=EQ
-    tmp_5 = s32[] parameter(2)
-    tmp_6 = s32[] parameter(3)
-    tmp_7 = s32[] minimum(tmp_5, tmp_6)
-    tmp_8 = s32[] select(tmp_2, tmp_5, tmp_6)
-    tmp_9 = s32[] select(tmp_4, tmp_7, tmp_8)
-    ROOT tmp_10 = (f32[], s32[]) tuple(tmp_3, tmp_9)
-  }
-
-  minmax {
-    tmp_0 = f32[] parameter(0)
-    tmp_1 = f32[] parameter(2)
-    tmp_2 = s32[] parameter(1)
-    tmp_3 = s32[] parameter(3)
-    ROOT tmp_4 = (f32[], s32[]) fusion(tmp_0, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=f
-  }
-
-  ENTRY e {
-    tmp_0 = f32[554112,10]{1,0} parameter(0)
-    tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
-    tmp_2 = f32[] constant(-inf)
-    tmp_3 = s32[] constant(0)
-    ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax
-  })")
-                    .value();
-
-  EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false,
-                                   TestGpuDeviceInfo::RTXA6000DeviceInfo())
-                  .Run(module.get())
-                  .value());
-  ASSERT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Fusion(m::Parameter())));
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction()->fused_expression_root(),
-      GmockMatch(
-          m::Reduce(m::Parameter(), m::Iota(), m::Constant(), m::Constant())));
-}
-
-TEST_F(InstructionFusionTest, InputReductionFusion) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-    add.clone.13 {
-      x.27 = f32[] parameter(0)
-      y.27 = f32[] parameter(1)
-      ROOT add.1036 = f32[] add(x.27, y.27)
-    }
-    add.clone.14 {
-      x.28 = f32[] parameter(0)
-      y.28 = f32[] parameter(1)
-      ROOT add.1037 = f32[] add(x.28, y.28)
-    }
-    add {
-      x = bf16[] parameter(0)
-      convert.448 = f32[] convert(x)
-      y = bf16[] parameter(1)
-      convert.449 = f32[] convert(y)
-      add.597 = f32[] add(convert.448, convert.449)
-      ROOT convert.450 = bf16[] convert(add.597)
-    }
-    ENTRY FuseSmallReduction {
-      param_2.7 = bf16[8,16,64,2048]{3,2,1,0} parameter(2)
-      convert.1395 = f32[8,16,64,2048]{3,2,1,0} convert(param_2.7)
-      param_0.85 = bf16[8,16,64,2048]{3,2,1,0} parameter(0)
-      convert.1393 = f32[8,16,64,2048]{3,2,1,0} convert(param_0.85)
-      multiply.1652 = f32[8,16,64,2048]{3,2,1,0} multiply(convert.1395, convert.1393)
-      convert.1392 = bf16[8,16,64,2048]{3,2,1,0} convert(multiply.1652)
-      bitcast.15934 = bf16[128,64,2048]{2,1,0} bitcast(convert.1392)
-      convert.1391 = f32[128,64,2048]{2,1,0} convert(bitcast.15934)
-      param_1.15 = bf16[] parameter(1)
-      convert.1394 = f32[] convert(param_1.15)
-      reduce.462 = f32[128,64]{1,0} reduce(convert.1391, convert.1394), dimensions={2}, to_apply=add.clone.13
-      reduce.121 = f32[64]{0} reduce(reduce.462, convert.1394), dimensions={0}, to_apply=add.clone.14
-      ROOT convert.890 = bf16[64]{0} convert(reduce.121)
-    })")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* fused_convert_fusion =
-      module->entry_computation()->root_instruction();
-
-  ASSERT_THAT(fused_convert_fusion, GmockMatch(m::Fusion()));
-  SCOPED_TRACE(module->ToString());
-  EXPECT_EQ(fused_convert_fusion->fusion_kind(),
-            HloInstruction::FusionKind::kInput);
-}
-
-TEST_F(InstructionFusionTest, DotStrengthReductionFusion) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test_module
-
-scalar_add_computation {
-  scalar_rhs = f32[] parameter(1)
-  scalar_lhs = f32[] parameter(0)
-  ROOT add.1 = f32[] add(scalar_lhs, scalar_rhs)
-}
-
-ENTRY main {
-  param_1.3 = f16[16,64,96,6,2,16]{5,4,3,2,1,0} parameter(1)
-  param_0.6 = f16[16,64,96,1,2,16]{5,4,3,2,1,0} parameter(0)
-  bitcast.26 = f16[16,64,96,2,16]{4,3,2,1,0} bitcast(param_0.6)
-  broadcast.4 = f16[16,64,96,6,2,16]{5,4,3,2,1,0} broadcast(bitcast.26), dimensions={0,1,2,4,5}
-  multiply.4 = f16[16,64,96,6,2,16]{5,4,3,2,1,0} multiply(broadcast.4, param_1.3)
-  convert.8 = f32[16,64,96,6,2,16]{5,4,3,2,1,0} convert(multiply.4)
-  constant_2 = f32[] constant(0)
-  reduce.3 = f32[16,64,96,6,2]{3,4,2,1,0} reduce(convert.8, constant_2), dimensions={5}, to_apply=scalar_add_computation
-  bitcast.25 = f32[16,64,96,2,6]{4,3,2,1,0} bitcast(reduce.3)
-  convert.7 = f16[16,64,96,2,6]{4,3,2,1,0} convert(bitcast.25)
-  ROOT bitcast.24 = f16[16,64,96,2,1,6]{5,4,3,2,1,0} bitcast(convert.7)
-})")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  const HloInstruction* fused_convert_fusion =
-      module->entry_computation()->root_instruction()->operand(0);
-
-  ASSERT_THAT(fused_convert_fusion, GmockMatch(m::Fusion()));
-  SCOPED_TRACE(module->ToString());
-  EXPECT_EQ(fused_convert_fusion->fusion_kind(),
-            HloInstruction::FusionKind::kInput);
-  EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1);
-}
-
-TEST_F(InstructionFusionTest, ReductionFusionOtherUnaryElementwiseOpsAreFused) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test_module
-
-scalar_add_computation {
-  scalar_rhs = f32[] parameter(1)
-  scalar_lhs = f32[] parameter(0)
-  ROOT add.1 = f32[] add(scalar_lhs, scalar_rhs)
-}
-
-ENTRY main {
-  param_0 = f16[64,96,6,16]{3,2,1,0} parameter(0)
-  constant_2 = f32[] constant(0)
-  reduce.3 = f32[64,6,16]{2,1,0} reduce(param_0, constant_2), dimensions={1}, to_apply=scalar_add_computation
-  negate = f32[64,6,16]{2,1,0} negate(reduce.3)
-  ROOT sine = f16[64,6,16]{2,1,0} sine(negate)
-})")
-                    .value();
-
-  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
-
-  HloInstruction* fused_convert_fusion =
-      module->entry_computation()->root_instruction();
-
-  ASSERT_THAT(fused_convert_fusion, GmockMatch(m::Fusion()));
-  SCOPED_TRACE(module->ToString());
-  EXPECT_EQ(fused_convert_fusion->fusion_kind(),
-            HloInstruction::FusionKind::kInput);
-  EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1);
-}
-
-TEST_F(InstructionFusionTest, DoNotFuseInsideReducer) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test_module
-
-scalar_add_computation {
-  scalar_rhs = f32[] parameter(1)
-  scalar_lhs = f32[] parameter(0)
-  add.1 = f32[] add(scalar_lhs, scalar_rhs)
-  ROOT add.2 = f32[] add(add.1, scalar_rhs)
-}
-
-ENTRY main {
-  param_0 = f16[64,96] parameter(0)
-  constant_2 = f32[] constant(0)
-  ROOT reduce = f32[64] reduce(param_0, constant_2), dimensions={1}, to_apply=scalar_add_computation
-})")
-                    .value();
-
-  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc
index e2882a7..ee1140d 100644
--- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc
+++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc
@@ -15,6 +15,7 @@
 
 #include "xla/service/gpu/ir_emission_utils.h"
 
+#include <algorithm>
 #include <cstdint>
 #include <functional>
 #include <optional>
@@ -60,9 +61,9 @@
 #include "xla/shape_util.h"
 #include "xla/status_macros.h"
 #include "xla/stream_executor/device_description.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
 #include "tsl/platform/protobuf.h"
 #include "tsl/platform/statusor.h"
 
@@ -81,6 +82,13 @@
   return shape.rank() == batch_dimensions_size + 1;
 }
 
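+// The 102 transpose detection below is only performed when the MLIR emitters
+// (xla_gpu_mlir_emitter_level >= 3) are enabled.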
+bool IsMlirTransposeEmitterEnabled(const HloInstruction& hlo) {
+  return hlo.GetModule()
+             ->config()
+             .debug_options()
+             .xla_gpu_mlir_emitter_level() >= 3;
+}
+
 }  // namespace
 
 bool IsMatrixMultiplication(const HloInstruction& dot) {
@@ -538,32 +546,53 @@
     return std::nullopt;
   }
 
-  if (std::optional<Vector3> tr = ShapeUtil::GetNormalizedTransposeShape(
-          instr.operand(0)->shape(), instr.shape(), Vector3{0, 2, 1})) {
+  if (std::optional<absl::InlinedVector<int64_t, 3>> tr =
+          ShapeUtil::GetNormalizedTransposeShape(
+              instr.operand(0)->shape(), instr.shape(),
+              absl::InlinedVector<int64_t, 3>{0, 2, 1})) {
     if ((tr->at(1) >= kMinDimensionToTransposeTiled &&
          tr->at(2) >= kMinDimensionToTransposeTiled) ||
         (tr->at(1) >= kMinDimensionToTransposeTiled2 &&
          tr->at(2) >= kMinDimensionToTransposeTiled2 &&
          tr->at(1) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) {
-      return TransposeDescription{&instr, *tr,
-                                  /*permutation=*/Vector3{0, 2, 1}};
+      return TransposeDescription{
+          &instr, *tr,
+          /*permutation=*/absl::InlinedVector<int64_t, 3>{0, 2, 1}};
     }
   }
-  if (std::optional<Vector3> tr = ShapeUtil::GetNormalizedTransposeShape(
-          instr.operand(0)->shape(), instr.shape(), Vector3{2, 1, 0})) {
+  if (std::optional<absl::InlinedVector<int64_t, 3>> tr =
+          ShapeUtil::GetNormalizedTransposeShape(
+              instr.operand(0)->shape(), instr.shape(),
+              absl::InlinedVector<int64_t, 3>{2, 1, 0})) {
     if ((tr->at(0) >= kMinDimensionToTransposeTiled &&
          tr->at(2) >= kMinDimensionToTransposeTiled) ||
         (tr->at(0) >= kMinDimensionToTransposeTiled2 &&
          tr->at(2) >= kMinDimensionToTransposeTiled2 &&
          tr->at(0) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) {
-      return TransposeDescription{&instr, *tr,
-                                  /*permutation=*/Vector3{2, 1, 0}};
+      return TransposeDescription{
+          &instr, *tr,
+          /*permutation=*/absl::InlinedVector<int64_t, 3>{2, 1, 0}};
+    }
+  }
+  if (IsMlirTransposeEmitterEnabled(instr)) {
+    if (std::optional<absl::InlinedVector<int64_t, 3>> tr =
+            ShapeUtil::GetNormalizedTransposeShape(
+                instr.operand(0)->shape(), instr.shape(),
+                absl::InlinedVector<int64_t, 3>{1, 0, 2})) {
+      auto byte_width = primitive_util::ByteWidth(instr.shape().element_type());
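+      // Only use the 102 path when the most minor dimension fits into the
+      // shared-memory tile (kMaxBytesInMostMinorDimension) and the swapped
+      // dimensions are large enough for tiling to pay off.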
+      if (byte_width * tr->at(2) <= kMaxBytesInMostMinorDimension &&
+          byte_width * tr->at(2) * std::min(tr->at(0), tr->at(1)) >=
+              kMinDimensionToTransposeTiled) {
+        return TransposeDescription{
+            &instr, *tr,
+            /*permutation=*/absl::InlinedVector<int64_t, 3>{1, 0, 2}};
+      }
     }
   }
   return std::nullopt;
 }
 
-// Find 021 or 210 transpose in logical + physical transposition.
+// Find a 021, 210, or 102 transpose in logical + physical transposition.
 static std::optional<TransposeDescription> FindTiledLogicalTranspose(
     const HloInstruction& instr) {
   if (instr.opcode() != HloOpcode::kTranspose) {
@@ -571,28 +600,47 @@
   }
 
   // TODO(cheshire): avoid code duplication.
-  if (std::optional<Vector3> tr = ShapeUtil::GetNormalizedLogicalTransposeShape(
-          instr.operand(0)->shape(), instr.shape(), instr.dimensions(),
-          Vector3{0, 2, 1})) {
+  if (std::optional<absl::InlinedVector<int64_t, 3>> tr =
+          ShapeUtil::GetNormalizedLogicalTransposeShape(
+              instr.operand(0)->shape(), instr.shape(), instr.dimensions(),
+              absl::InlinedVector<int64_t, 3>{0, 2, 1})) {
     if ((tr->at(1) >= kMinDimensionToTransposeTiled &&
          tr->at(2) >= kMinDimensionToTransposeTiled) ||
         (tr->at(1) >= kMinDimensionToTransposeTiled2 &&
          tr->at(2) >= kMinDimensionToTransposeTiled2 &&
          tr->at(1) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) {
-      return TransposeDescription{&instr, *tr,
-                                  /*permutation=*/Vector3{0, 2, 1}};
+      return TransposeDescription{
+          &instr, *tr,
+          /*permutation=*/absl::InlinedVector<int64_t, 3>{0, 2, 1}};
     }
   }
-  if (std::optional<Vector3> tr = ShapeUtil::GetNormalizedLogicalTransposeShape(
-          instr.operand(0)->shape(), instr.shape(), instr.dimensions(),
-          Vector3{2, 1, 0})) {
+  if (std::optional<absl::InlinedVector<int64_t, 3>> tr =
+          ShapeUtil::GetNormalizedLogicalTransposeShape(
+              instr.operand(0)->shape(), instr.shape(), instr.dimensions(),
+              absl::InlinedVector<int64_t, 3>{2, 1, 0})) {
     if ((tr->at(0) >= kMinDimensionToTransposeTiled &&
          tr->at(2) >= kMinDimensionToTransposeTiled) ||
         (tr->at(0) >= kMinDimensionToTransposeTiled2 &&
          tr->at(2) >= kMinDimensionToTransposeTiled2 &&
          tr->at(0) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) {
-      return TransposeDescription{&instr, *tr,
-                                  /*permutation=*/Vector3{2, 1, 0}};
+      return TransposeDescription{
+          &instr, *tr,
+          /*permutation=*/absl::InlinedVector<int64_t, 3>{2, 1, 0}};
+    }
+  }
+  if (IsMlirTransposeEmitterEnabled(instr)) {
+    if (std::optional<absl::InlinedVector<int64_t, 3>> tr =
+            ShapeUtil::GetNormalizedLogicalTransposeShape(
+                instr.operand(0)->shape(), instr.shape(), instr.dimensions(),
+                absl::InlinedVector<int64_t, 3>{1, 0, 2})) {
+      auto byte_width = primitive_util::ByteWidth(instr.shape().element_type());
+      if (byte_width * tr->at(2) <= kMaxBytesInMostMinorDimension &&
+          byte_width * tr->at(2) * std::min(tr->at(0), tr->at(1)) >=
+              kMinDimensionToTransposeTiled) {
+        return TransposeDescription{
+            &instr, *tr,
+            /*permutation=*/absl::InlinedVector<int64_t, 3>{1, 0, 2}};
+      }
     }
   }
   return std::nullopt;
diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h
index 9316ba9..3dcf0bc 100644
--- a/third_party/xla/xla/service/gpu/ir_emission_utils.h
+++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h
@@ -23,6 +23,7 @@
 #include <variant>
 #include <vector>
 
+#include "absl/container/inlined_vector.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
@@ -54,6 +55,10 @@
 // efficient.
 inline constexpr int64_t kMinDimensionToTransposeTiled2 = 8;
 inline constexpr int64_t kMinTotalDimensionsToTransposeTiled = 64 * 128;
+// As the amount of shared memory is limited, we need to make sure that we
+// don't detect 102 transposes that would require too many bytes in the most
+// minor dimension.
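+// For example, an f32 102 transpose with a most minor dimension of 2 needs
+// 4 * 2 = 8 bytes and passes this limit, while an s8 transpose with a most
+// minor dimension of 9 needs 9 bytes and is rejected.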
+inline constexpr int64_t kMaxBytesInMostMinorDimension = 8;
 
 // Matrix multiplication before the rewrite.
 bool IsMatrixMultiplication(const HloInstruction& dot);
@@ -160,16 +165,18 @@
   const HloInstruction* instr;
 
   // Normalized transpose dimensions.
-  Vector3 dimensions;
+  absl::InlinedVector<int64_t, 3> dimensions;
 
   // Permutations of normalized transpose dimensions.
-  Vector3 permutation;
+  absl::InlinedVector<int64_t, 3> permutation;
 
-  TransposeDescription(Vector3 dimensions, Vector3 permutation)
+  TransposeDescription(absl::InlinedVector<int64_t, 3> dimensions,
+                       absl::InlinedVector<int64_t, 3> permutation)
       : TransposeDescription(/*instr=*/nullptr, dimensions, permutation) {}
 
-  TransposeDescription(const HloInstruction* instr, Vector3 dimensions,
-                       Vector3 permutation)
+  TransposeDescription(const HloInstruction* instr,
+                       absl::InlinedVector<int64_t, 3> dimensions,
+                       absl::InlinedVector<int64_t, 3> permutation)
       : instr(instr), dimensions(dimensions), permutation(permutation) {}
 
   // Transpose instruction input shape.
diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc
index 0703f86..2d94235 100644
--- a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc
+++ b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc
@@ -20,6 +20,7 @@
 #include <string>
 #include <vector>
 
+#include "absl/container/inlined_vector.h"
 #include "absl/strings/str_cat.h"
 #include "xla/hlo/ir/backend_config.h"
 #include "xla/literal.h"
@@ -30,7 +31,6 @@
 #include "xla/shape_util.h"
 #include "xla/tests/hlo_test_base.h"
 #include "xla/types.h"
-#include "xla/util.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
@@ -41,6 +41,7 @@
 using ::tsl::testing::IsOkAndHolds;
 
 using IrEmissionUtilsTest = HloTestBase;
+using InlinedVector = absl::InlinedVector<int64_t, 3>;
 
 TEST_F(IrEmissionUtilsTest, FindTiledLogicalTranspose) {
   const char* hlo = R"(
@@ -59,8 +60,94 @@
   auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr);
   EXPECT_TRUE(result.has_value());
   EXPECT_EQ(result->instr, tr);
-  EXPECT_EQ(result->dimensions, Vector3({1, 64, 1536}));
-  EXPECT_EQ(result->permutation, Vector3({0, 2, 1}));
+  EXPECT_EQ(result->dimensions, InlinedVector({1, 64, 1536}));
+  EXPECT_EQ(result->permutation, InlinedVector({0, 2, 1}));
+}
+
+TEST_F(IrEmissionUtilsTest, FindTiledLogical102Transpose) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY entry {
+  p = f32[32,48,2]{2,1,0} parameter(0)
+  ROOT t = f32[48,32,2]{2,1,0} transpose(p), dimensions={1,0,2}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo));
+  auto& debug_options = module->mutable_config().mutable_debug_options();
+  debug_options.set_xla_gpu_mlir_emitter_level(3);
+
+  HloInstruction* tr = module->entry_computation()->root_instruction();
+
+  auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr);
+  EXPECT_TRUE(result.has_value());
+  EXPECT_EQ(result->instr, tr);
+  EXPECT_EQ(result->dimensions, InlinedVector({48, 32, 2}));
+  EXPECT_EQ(result->permutation, InlinedVector({1, 0, 2}));
+}
+
+TEST_F(IrEmissionUtilsTest, FindTiledLogical102TransposeTooMuchMemoryRequired) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY entry {
+  p = s8[32,48,9]{2,1,0} parameter(0)
+  ROOT t = s8[48,32,9]{2,1,0} transpose(p), dimensions={1,0,2}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo));
+  auto& debug_options = module->mutable_config().mutable_debug_options();
+  debug_options.set_xla_gpu_mlir_emitter_level(3);
+
+  HloInstruction* tr = module->entry_computation()->root_instruction();
+
+  auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr);
+  EXPECT_FALSE(result.has_value());
+}
+
+TEST_F(IrEmissionUtilsTest, FindTiled102Transpose) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY entry {
+  p = s16[32,48,4]{2,1,0} parameter(0)
+  ROOT t = s16[32,48,4]{2,0,1} copy(p)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo));
+  auto& debug_options = module->mutable_config().mutable_debug_options();
+  debug_options.set_xla_gpu_mlir_emitter_level(3);
+
+  HloInstruction* tr = module->entry_computation()->root_instruction();
+
+  auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr);
+  EXPECT_TRUE(result.has_value());
+  EXPECT_EQ(result->instr, tr);
+  EXPECT_EQ(result->dimensions, InlinedVector({48, 32, 4}));
+  EXPECT_EQ(result->permutation, InlinedVector({1, 0, 2}));
+}
+
+TEST_F(IrEmissionUtilsTest, FindTiled102TransposeTooMuchMemoryRequired) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY entry {
+  p = s8[32,48,9]{2,1,0} parameter(0)
+  ROOT t = s8[32,48,9]{2,0,1} copy(p)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo));
+  auto& debug_options = module->mutable_config().mutable_debug_options();
+  debug_options.set_xla_gpu_mlir_emitter_level(3);
+
+  HloInstruction* tr = module->entry_computation()->root_instruction();
+
+  auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr);
+  EXPECT_FALSE(result.has_value());
 }
 
 TEST_F(IrEmissionUtilsTest, FindAnyTiledTranspose) {
@@ -79,8 +166,8 @@
   auto result = GetDescriptionForTiledTransposeEmitter(*r, *r);
   EXPECT_TRUE(result.has_value());
   EXPECT_EQ(result->instr, r);
-  EXPECT_EQ(result->dimensions, Vector3({64, 48, 32}));
-  EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+  EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32}));
+  EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
 }
 
 TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithIntermediateUnaryOp) {
@@ -100,8 +187,8 @@
   auto result = GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0));
   EXPECT_TRUE(result.has_value());
   EXPECT_EQ(result->instr, r->operand(0));
-  EXPECT_EQ(result->dimensions, Vector3({64, 48, 32}));
-  EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+  EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32}));
+  EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
 }
 
 TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithIntermediateUnaryOpS8) {
@@ -258,8 +345,8 @@
   auto result = GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0));
   EXPECT_TRUE(result.has_value());
   EXPECT_EQ(result->instr, r->operand(0));
-  EXPECT_EQ(result->dimensions, Vector3({64, 48, 32}));
-  EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+  EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32}));
+  EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
 }
 
 TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithTwoIntermediateBinaryOps) {
@@ -289,8 +376,8 @@
       GetDescriptionForTiledTransposeEmitter(*r, FindNonTrivialHero(*r));
   EXPECT_TRUE(result.has_value());
   EXPECT_EQ(result->instr, r->operand(0)->operand(0));
-  EXPECT_EQ(result->dimensions, Vector3({64, 48, 32}));
-  EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+  EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32}));
+  EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
 }
 
 TEST_F(IrEmissionUtilsTest,
@@ -475,8 +562,8 @@
       GetDescriptionForTiledTransposeEmitter(*copy, FindNonTrivialHero(*copy));
   EXPECT_TRUE(result.has_value());
   EXPECT_EQ(result->instr, copy);
-  EXPECT_EQ(result->dimensions, Vector3({8, 12, 1100}));
-  EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+  EXPECT_EQ(result->dimensions, InlinedVector({8, 12, 1100}));
+  EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
 }
 
 TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOneSwapDimIsSmall) {
@@ -502,8 +589,8 @@
       GetDescriptionForTiledTransposeEmitter(*tr, FindNonTrivialHero(*tr));
   EXPECT_TRUE(result.has_value());
   EXPECT_EQ(result->instr, tr);
-  EXPECT_EQ(result->dimensions, Vector3({8, 12, 1100}));
-  EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+  EXPECT_EQ(result->dimensions, InlinedVector({8, 12, 1100}));
+  EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
 }
 
 TEST_F(IrEmissionUtilsTest, FindTiledTransposeOtherSwapDimIsSmall) {
@@ -529,8 +616,8 @@
       GetDescriptionForTiledTransposeEmitter(*copy, FindNonTrivialHero(*copy));
   EXPECT_TRUE(result.has_value());
   EXPECT_EQ(result->instr, copy);
-  EXPECT_EQ(result->dimensions, Vector3({1100, 12, 8}));
-  EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+  EXPECT_EQ(result->dimensions, InlinedVector({1100, 12, 8}));
+  EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
 }
 
 TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOtherSwapDimIsSmall) {
@@ -556,8 +643,8 @@
       GetDescriptionForTiledTransposeEmitter(*tr, FindNonTrivialHero(*tr));
   EXPECT_TRUE(result.has_value());
   EXPECT_EQ(result->instr, tr);
-  EXPECT_EQ(result->dimensions, Vector3({1100, 12, 8}));
-  EXPECT_EQ(result->permutation, Vector3({2, 1, 0}));
+  EXPECT_EQ(result->dimensions, InlinedVector({1100, 12, 8}));
+  EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
 }
 
 TEST_F(IrEmissionUtilsTest, IsContiguousSlice) {
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
index a964f6b..b73225a 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
+++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
@@ -26,7 +26,6 @@
 #include <string>
 #include <tuple>
 #include <utility>
-#include <variant>
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
@@ -99,7 +98,6 @@
 #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h"
 #include "xla/service/gpu/gpu_asm_opts_util.h"
 #include "xla/service/gpu/gpu_conv_runner.h"
-#include "xla/service/gpu/gpu_fused_mha_runner.h"
 #include "xla/service/gpu/gpu_norm_runner.h"
 #include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/ir_emission_utils.h"
@@ -122,7 +120,6 @@
 #include "xla/service/gpu/runtime/copy_thunk.h"
 #include "xla/service/gpu/runtime/custom_call_thunk.h"
 #include "xla/service/gpu/runtime/fft_thunk.h"
-#include "xla/service/gpu/runtime/fused_mha_thunk.h"
 #include "xla/service/gpu/runtime/gemm_thunk.h"
 #include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h"
 #include "xla/service/gpu/runtime/infeed_thunk.h"
@@ -173,6 +170,7 @@
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "xla/service/gpu/runtime/cholesky_thunk.h"
 #include "xla/service/gpu/runtime/cub_sort_thunk.h"
+#include "xla/service/gpu/runtime/cudnn_thunk.h"
 #include "xla/service/gpu/runtime/triangular_solve_thunk.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
@@ -955,221 +953,17 @@
   return absl::OkStatus();
 }
 
-absl::Status IrEmitterUnnested::EmitFusedMHAThunk(
+absl::Status IrEmitterUnnested::EmitCuDnnThunk(
     const HloCustomCallInstruction* instr) {
-  const HloInstruction* lhs_bmm1 = instr->operand(0);
-  const HloInstruction* rhs_bmm1 = instr->operand(1);
-  const HloInstruction* rhs_bmm2 = instr->operand(2);
-
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_bmm1_slice,
-                      GetAllocationSliceForHlo(lhs_bmm1));
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm1_slice,
-                      GetAllocationSliceForHlo(rhs_bmm1));
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm2_slice,
-                      GetAllocationSliceForHlo(rhs_bmm2));
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
-                      GetAllocationSliceForHlo(instr, {0}));
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice,
-                      GetAllocationSliceForHlo(
-                          instr, {instr->shape().tuple_shapes_size() - 1}));
-  BufferAllocation::Slice activation_slice;
-  bool has_activation = xla::ShapeUtil::TupleElementCount(instr->shape()) == 3;
-  if (has_activation) {
-    TF_ASSIGN_OR_RETURN(activation_slice, GetAllocationSliceForHlo(instr, {1}));
-  }
-
-  TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind,
-                      xla::gpu::GetCudnnfMHAKind(instr));
-  BufferAllocation::Slice mask_slice, bias_slice;
-  BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice;
-  std::optional<Shape> mask_shape, bias_shape;
-  {
-    bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax ||
-                    kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout;
-
-    if (has_bias) {
-      const HloInstruction* bias = instr->operand(3);
-      TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSliceForHlo(bias));
-      bias_shape = bias->shape();
-    }
-    int64_t seqlen_qk_operand_index = 3 + has_bias;
-    bool has_seqlen_qk = seqlen_qk_operand_index == instr->operand_count() - 2;
-    if (has_seqlen_qk) {
-      const HloInstruction* seqlen_q = instr->operand(seqlen_qk_operand_index);
-      TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q));
-      const HloInstruction* seqlen_k =
-          instr->operand(seqlen_qk_operand_index + 1);
-      TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k));
-    }
-  }
-
-  TF_ASSIGN_OR_RETURN(const auto gpu_config,
-                      instr->backend_config<xla::gpu::GpuBackendConfig>());
-  const xla::gpu::CudnnfMHABackendConfig& config =
-      gpu_config.cudnn_fmha_backend_config();
-  Shape intermediate_tensor_shape(config.intermediate_tensor_shape());
-  absl::InlinedVector<Shape, 2> output_shapes = {
-      ShapeUtil::GetSubshape(instr->shape(), {0})};
-  if (has_activation) {
-    output_shapes.push_back(ShapeUtil::GetSubshape(instr->shape(), {1}));
-  }
-  TF_ASSIGN_OR_RETURN(const auto mask_type,
-                      AsCudnnFmhaMaskKind(config.mask_type()));
-  GpufMHADescriptor descriptor = {kind,
-                                  config,
-                                  mask_type,
-                                  lhs_bmm1->shape(),
-                                  rhs_bmm1->shape(),
-                                  rhs_bmm2->shape(),
-                                  intermediate_tensor_shape,
-                                  output_shapes,
-                                  config.bmm1_dot_dimension_numbers(),
-                                  config.bmm2_dot_dimension_numbers(),
-                                  mask_shape,
-                                  bias_shape};
-
-  TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config,
-                      GpufMHAConfig::For(descriptor));
-  AddThunkToThunkSequence(std::make_unique<FusedMHAThunk>(
-      Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(fmha_config),
-      lhs_bmm1_slice, rhs_bmm1_slice, rhs_bmm2_slice, output_slice,
-      scratch_slice, mask_slice, bias_slice, activation_slice, seqlen_q_slice,
-      seqlen_k_slice));
-  return absl::OkStatus();
-}
-
-absl::Status IrEmitterUnnested::EmitFusedMHABackwardThunk(
-    const HloCustomCallInstruction* instr) {
-  TF_ASSIGN_OR_RETURN(const auto gpu_config,
-                      instr->backend_config<xla::gpu::GpuBackendConfig>());
-  const xla::gpu::CudnnfMHABackendConfig& config =
-      gpu_config.cudnn_fmha_backend_config();
-
-  int input_index = 0;
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice,
-                      GetAllocationSliceForHlo(instr->operand(input_index)));
-  Shape bmm1_grad_gemm1_rhs_shape = instr->operand(input_index++)->shape();
-
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice,
-                      GetAllocationSliceForHlo(instr->operand(input_index)));
-  Shape bmm1_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape();
-
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice,
-                      GetAllocationSliceForHlo(instr->operand(input_index)));
-  Shape bmm2_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape();
-
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice,
-                      GetAllocationSliceForHlo(instr->operand(input_index)));
-  Shape bmm2_grad_gemm1_lhs_shape;
-
-  Shape intermediate_tensor_shape(config.intermediate_tensor_shape());
-  bmm2_grad_gemm1_lhs_shape = intermediate_tensor_shape;
-  input_index++;
-
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_output_slice,
-                      GetAllocationSliceForHlo(instr->operand(input_index)));
-  Shape d_output_shape = instr->operand(input_index++)->shape();
-
-  TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, GetCudnnfMHAKind(instr));
-  BufferAllocation::Slice mask_slice;
-  std::optional<Shape> mask_shape;
-
-  bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax ||
-                   kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout);
-  BufferAllocation::Slice bias_slice;
-  std::optional<Shape> bias_shape;
-  if (has_bias) {
-    TF_ASSIGN_OR_RETURN(bias_slice,
-                        GetAllocationSliceForHlo(instr->operand(input_index)));
-    bias_shape = instr->operand(input_index++)->shape();
-  }
-
-  BufferAllocation::Slice fwd_output_slice;
-  std::optional<Shape> fwd_output_shape;
-
-  TF_ASSIGN_OR_RETURN(fwd_output_slice,
-                      GetAllocationSliceForHlo(instr->operand(input_index)));
-  fwd_output_shape = instr->operand(input_index++)->shape();
-
-  BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice;
-  bool has_seqlen_qk = input_index == instr->operand_count() - 2;
-  if (has_seqlen_qk) {
-    const HloInstruction* seqlen_q = instr->operand(input_index);
-    TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q));
-    const HloInstruction* seqlen_k = instr->operand(input_index + 1);
-    TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k));
-    input_index += 2;
-  }
-  TF_RET_CHECK(input_index == instr->operand_count());
-
-  int output_index = 0;
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_lhs_slice,
-                      GetAllocationSliceForHlo(instr, {output_index}));
-  Shape d_bmm1_lhs_shape =
-      ShapeUtil::GetSubshape(instr->shape(), {output_index++});
-
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_rhs_slice,
-                      GetAllocationSliceForHlo(instr, {output_index}));
-  Shape d_bmm1_rhs_shape =
-      ShapeUtil::GetSubshape(instr->shape(), {output_index++});
-
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm2_rhs_slice,
-                      GetAllocationSliceForHlo(instr, {output_index}));
-  Shape d_bmm2_rhs_shape =
-      ShapeUtil::GetSubshape(instr->shape(), {output_index++});
-
-  BufferAllocation::Slice d_s_slice;
-  std::optional<Shape> d_s_shape;
-
-  bool has_dbias = instr->shape().tuple_shapes().size() == 5;
-  BufferAllocation::Slice d_bias_slice;
-  std::optional<Shape> d_bias_shape;
-  if (has_dbias) {
-    TF_ASSIGN_OR_RETURN(d_bias_slice,
-                        GetAllocationSliceForHlo(instr, {output_index}));
-    d_bias_shape = ShapeUtil::GetSubshape(instr->shape(), {output_index++});
-  }
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice,
-                      GetAllocationSliceForHlo(instr, {output_index++}));
-  TF_RET_CHECK(output_index == instr->shape().tuple_shapes().size());
-  TF_ASSIGN_OR_RETURN(const auto mask_type,
-                      AsCudnnFmhaMaskKind(config.mask_type()));
-  bool force_deterministic = config.force_deterministic();
-  GpufMHABackwardDescriptor descriptor = {
-      kind,
-      config,
-      mask_type,
-      bmm1_grad_gemm1_rhs_shape,
-      bmm1_grad_gemm2_rhs_shape,
-      bmm2_grad_gemm1_lhs_shape,
-      bmm2_grad_gemm2_rhs_shape,
-      d_output_shape,
-      d_bmm1_lhs_shape,
-      d_bmm1_rhs_shape,
-      d_bmm2_rhs_shape,
-      config.bmm1_grad_gemm1_dot_dimension_numbers(),
-      config.bmm1_grad_gemm2_dot_dimension_numbers(),
-      config.bmm2_grad_gemm1_dot_dimension_numbers(),
-      config.bmm2_grad_gemm2_dot_dimension_numbers(),
-      d_s_shape,
-      fwd_output_shape,
-      mask_shape,
-      d_bias_shape,
-      bias_shape,
-      force_deterministic};
-
-  TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_backward_config,
-                      GpufMHABackwardConfig::For(descriptor));
-
-  AddThunkToThunkSequence(std::make_unique<FusedMHABackwardThunk>(
-      Thunk::ThunkInfo::WithProfileAnnotation(instr),
-      std::move(fmha_backward_config), bmm1_grad_gemm1_rhs_slice,
-      bmm1_grad_gemm2_rhs_slice, bmm2_grad_gemm1_lhs_slice,
-      bmm2_grad_gemm2_rhs_slice, d_output_slice, scratch_slice,
-      d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, d_s_slice,
-      mask_slice, d_bias_slice, fwd_output_slice, bias_slice, seqlen_q_slice,
-      seqlen_k_slice));
-
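+  // Gather the kernel arguments from the buffer assignment, fingerprint the
+  // instruction together with its backend config, and emit a single
+  // CuDnnThunk keyed by that fingerprint.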
+  TF_ASSIGN_OR_RETURN(
+      auto kernel_arguments,
+      KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr,
+                              instr->operands()));
+  TF_ASSIGN_OR_RETURN(const std::string fingerprint,
+                      FingerprintWithBackendConfig<GpuBackendConfig>(*instr));
+  AddThunkToThunkSequence(std::make_unique<CuDnnThunk>(
+      fingerprint, Thunk::ThunkInfo::WithProfileAnnotation(instr),
+      kernel_arguments.args()));
   return absl::OkStatus();
 }
 
@@ -1698,7 +1492,7 @@
   const se::DeviceDescription& device_info =
       ir_emitter_context_->gpu_device_info();
   const HloFusionAnalysis fusion_analysis =
-      HloFusionAnalysis::Create(instr, &device_info);
+      HloFusionAnalysis::Create(*instr, device_info);
 
   std::unique_ptr<FusionInterface> emitter = GetFusionEmitter(HloFusionInfo(
       fusion_analysis, instr, &ir_emitter_context_->buffer_assignment()));
@@ -2921,11 +2715,8 @@
       if (IsCustomCallToDnnNorm(*instr)) {
         return EmitNormThunk(custom_call);
       }
-      if (IsFwdCustomCallTofMHA(*instr)) {
-        return EmitFusedMHAThunk(custom_call);
-      }
-      if (IsBwdCustomCallTofMHA(*instr)) {
-        return EmitFusedMHABackwardThunk(custom_call);
+      if (IsCustomCallTofMHA(*instr)) {
+        return EmitCuDnnThunk(custom_call);
       }
 #endif  // GOOGLE_CUDA
       if (IsCustomCallToTopK(*instr)) {
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h
index f97f106..d19dd5d 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h
+++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h
@@ -147,8 +147,7 @@
   absl::Status EmitConvolutionReorderThunk(
       const HloCustomCallInstruction* instr);
   absl::Status EmitNormThunk(const HloCustomCallInstruction* instr);
-  absl::Status EmitFusedMHAThunk(const HloCustomCallInstruction* instr);
-  absl::Status EmitFusedMHABackwardThunk(const HloCustomCallInstruction* instr);
+  absl::Status EmitCuDnnThunk(const HloCustomCallInstruction* instr);
 #endif  // GOOGLE_CUDA
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   absl::Status EmitCubDeviceRadixSort(const HloCustomCallInstruction* instr);
diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD
index 8c9b851..a1ce0cf 100644
--- a/third_party/xla/xla/service/gpu/kernels/BUILD
+++ b/third_party/xla/xla/service/gpu/kernels/BUILD
@@ -95,6 +95,7 @@
         "//xla/stream_executor:device_description",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:statusor",
@@ -110,6 +111,7 @@
     disabled_backends = ["gpu_h100"],
     tags = ["no_rocm"],
     deps = [
+        ":custom_kernel",
         ":custom_kernel_fusion_pattern",
         ":cutlass_gemm_custom_kernel",
         ":cutlass_gemm_fusion",
@@ -119,9 +121,11 @@
         "//xla:error_spec",
         "//xla:literal_util",
         "//xla:types",
-        "//xla/service/gpu:custom_kernel_fusion_rewriter",
+        "//xla:xla_data_proto_cc",
         "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter",
         "//xla/tests:hlo_test_base",
+        "@com_google_googletest//:gtest_main",
         "@local_tsl//tsl/platform:test",
         "@local_tsl//tsl/platform:test_main",
     ],
@@ -355,6 +359,10 @@
     deps = [
         ":cutlass_gemm_kernel_bf16xbf16_to_bf16",
         ":cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80",
+        ":cutlass_gemm_kernel_bf16xbf16_to_f32",
+        ":cutlass_gemm_kernel_bf16xf32_to_f32",
+        ":cutlass_gemm_kernel_bf16xs8_to_f32",
+        ":cutlass_gemm_kernel_f32xbf16_to_f32",
         ":cutlass_gemm_kernel_f32xf32_to_f32",
     ] + if_cuda_newer_than(
         "12_0",
@@ -437,6 +445,68 @@
     ]),
 )
 
+cuda_library(
+    name = "cutlass_gemm_kernel_bf16xbf16_to_f32",
+    srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc"]),
+    copts = [
+        "-mllvm",
+        "-unroll-threshold=100000",
+    ] + if_windows(
+        [],
+        ["-Wno-unknown-attributes"],
+    ),
+    deps = if_cuda_is_configured([
+        ":cutlass_gemm_adaptor",
+        "@cutlass_archive//:cutlass",
+        "@local_config_cuda//cuda:cuda_headers",
+    ]),
+)
+
+cuda_library(
+    name = "cutlass_gemm_kernel_bf16xf32_to_f32",
+    srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc"]),
+    copts = [
+        "-mllvm",
+        "-unroll-threshold=100000",
+    ] + if_windows(
+        [],
+        ["-Wno-unknown-attributes"],
+    ),
+    deps = if_cuda_is_configured([
+        ":cutlass_gemm_adaptor",
+        "@cutlass_archive//:cutlass",
+        "@local_config_cuda//cuda:cuda_headers",
+    ]),
+)
+
+cuda_library(
+    name = "cutlass_gemm_kernel_f32xbf16_to_f32",
+    srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc"]),
+    copts = [
+        "-mllvm",
+        "-unroll-threshold=100000",
+    ] + if_windows(
+        [],
+        ["-Wno-unknown-attributes"],
+    ),
+    deps = if_cuda_is_configured([
+        ":cutlass_gemm_adaptor",
+        "@cutlass_archive//:cutlass",
+        "@local_config_cuda//cuda:cuda_headers",
+    ]),
+)
+
+cuda_library(
+    name = "cutlass_gemm_kernel_bf16xs8_to_f32",
+    srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc"]),
+    copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"],
+    deps = if_cuda_is_configured([
+        ":cutlass_gemm_adaptor",
+        "@cutlass_archive//:cutlass",
+        "@local_config_cuda//cuda:cuda_headers",
+    ]),
+)
+
 #===--------------------------------------------------------------------------------------------===#
 # CUTLASS Gemm kernel libraries
 #===--------------------------------------------------------------------------------------------===#
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h
index 37fb0ad..963b80c 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h
@@ -46,12 +46,27 @@
 
 enum class Arch { kDefault, kSm80, kSm90 };
 
+// Keep in sync with cutlass::gemm::GemmUniversalMode.
+enum class GemmMode { kGemm, kGemmSplitKParallel, kBatched, kArray, kInvalid };
+
 template <Arch arch>
 struct Bf16xBf16ToBf16 {};
 
 template <Arch arch>
 struct F32xF32ToF32 {};
 
+template <Arch arch>
+struct Bf16xBf16ToF32 {};
+
+template <Arch arch>
+struct Bf16xF32ToF32 {};
+
+template <Arch arch>
+struct F32xBf16ToF32 {};
+
+template <Arch arch>
+struct Bf16xS8ToF32 {};
+
 // A tag to specialize CUTLASS kernel adaptors for loading kernels from shared
 // libraries using dlopen.
 struct DlOpenedKernel {};
@@ -132,6 +147,12 @@
 // Type-erased CUTLASS gemm arguments structure that has all of the details
 // required for packing CUTLASS kernel parameters.
 struct Arguments {
+  GemmMode mode;
+
+  // Number of batches when mode is `kBatched`.
+  // Number of k-slices when mode is `kGemmSplitKParallel`.
+  int32_t batch_count;
+
   int32_t m;
   int32_t n;
   int32_t k;
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h
index 53a6ac6..1478dc8 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h
@@ -19,13 +19,17 @@
 #include <cstddef>
 #include <cstdint>
 #include <memory>
+#include <optional>
 
+#include "third_party/gpus/cuda/include/vector_types.h"
 #include "cute/layout.hpp"
 #include "cutlass/cutlass.h"
+#include "cutlass/device_kernel.h"
 #include "cutlass/gemm/device/gemm_universal_adapter.h"
 #include "cutlass/gemm/gemm.h"
 #include "cutlass/gemm/gemm_enumerated_types.h"
 #include "cutlass/gemm_coord.h"
+#include "cutlass/kernel_hardware_info.h"
 #include "cutlass/layout/matrix.h"
 #include "cutlass/util/packed_stride.hpp"
 #include "xla/service/gpu/kernels/cutlass_gemm.h"
@@ -137,6 +141,21 @@
          cutlass::Status::kSuccess;
 }
 
+inline cutlass::gemm::GemmUniversalMode ToGemmUniversalMode(GemmMode mode) {
+  switch (mode) {
+    case GemmMode::kGemm:
+      return cutlass::gemm::GemmUniversalMode::kGemm;
+    case GemmMode::kGemmSplitKParallel:
+      return cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel;
+    case GemmMode::kBatched:
+      return cutlass::gemm::GemmUniversalMode::kBatched;
+    case GemmMode::kArray:
+      return cutlass::gemm::GemmUniversalMode::kArray;
+    case GemmMode::kInvalid:
+      return cutlass::gemm::GemmUniversalMode::kInvalid;
+  }
+}
+
 // Converts type-erased gemm arguments to the underlying CUTLASS operation
 // arguments.
 template <typename Tag>
@@ -148,7 +167,7 @@
   auto ldb = LdB<typename Traits<Tag>::Operation>(problem_size);
   auto ldc = LdC<typename Traits<Tag>::Operation>(problem_size);
 
-  auto mode = cutlass::gemm::GemmUniversalMode::kGemm;
+  cutlass::gemm::GemmUniversalMode mode = ToGemmUniversalMode(args.mode);
 
   // TODO(ezhulenev): We hardcode parameters for `LinearCombination`
   // epilogue, however `Gemm` template can be compiled with arbitrary
@@ -160,7 +179,7 @@
 
   return typename Traits<Tag>::Arguments(      // CUTLASS Operation arguments
       mode, problem_size,                      //
-      1,                                       // batch
+      args.batch_count,                        // batch or k-split slices
       {alpha, beta},                           // epilogue
       args.lhs, args.rhs, args.out, args.out,  // pointers
       0, 0, 0, 0,                              // batch strides
@@ -237,7 +256,9 @@
   // TODO(ezhulenev): Pass device id and sm_count in arguments.
   cutlass::KernelHardwareInfo hw_info{/*device_id=*/0, /*sm_count=*/128};
 
-  auto mode = cutlass::gemm::GemmUniversalMode::kGemm;
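+  // The cast below relies on GemmMode mirroring
+  // cutlass::gemm::GemmUniversalMode (see cutlass_gemm.h).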
+  cutlass::gemm::GemmUniversalMode mode =
+      static_cast<cutlass::gemm::GemmUniversalMode>(
+          static_cast<int>(args.mode));
   typename Kernel::ProblemShape problem_shape = {args.m, args.n, args.k,
                                                  /*batch=*/1};
 
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc
index 81b2dbd..a97fe04 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc
@@ -90,8 +90,8 @@
 }
 
 template <typename Tag>
-KernelArgsPacking ArgsPacking(int32_t m, int32_t n, int32_t k,
-                              const ArgsIndices& indices,
+KernelArgsPacking ArgsPacking(GemmMode mode, int32_t batch_count, int32_t m,
+                              int32_t n, int32_t k, const ArgsIndices& indices,
                               const DynamicSliceIndices& slices,
                               int32_t device_sms, Adaptor<Tag> adaptor) {
   using Packed = absl::StatusOr<std::unique_ptr<se::KernelArgsPackedArrayBase>>;
@@ -111,7 +111,7 @@
   return [=](const se::Kernel& kernel, const se::KernelArgs& args) -> Packed {
     auto* mem_args = se::Cast<se::KernelArgsDeviceMemoryArray>(&args);
 
-    Arguments arguments = {m, n, k};
+    Arguments arguments = {mode, batch_count, m, n, k};
     arguments.lhs = const_cast<void*>(mem_args->device_memory_ptr(indices.lhs));
     arguments.rhs = const_cast<void*>(mem_args->device_memory_ptr(indices.rhs));
     arguments.out = const_cast<void*>(mem_args->device_memory_ptr(indices.out));
@@ -176,7 +176,8 @@
 //===----------------------------------------------------------------------===//
 
 template <typename Tag>
-static CustomKernel Load(std::string name, int32_t m, int32_t n, int32_t k,
+static CustomKernel Load(std::string name, GemmMode mode, int32_t batch_count,
+                         int32_t m, int32_t n, int32_t k,
                          const ArgsIndices& indices,
                          const DynamicSliceIndices& slices,
                          const se::DeviceDescription& device,
@@ -188,8 +189,8 @@
   auto thread_dim = As<se::ThreadDim>(adaptor.ThreadDim());
   auto shared_memory_bytes = adaptor.SharedMemoryBytes();
 
-  auto packing =
-      ArgsPacking<Tag>(m, n, k, indices, slices, device.core_count(), adaptor);
+  auto packing = ArgsPacking<Tag>(mode, batch_count, m, n, k, indices, slices,
+                                  device.core_count(), adaptor);
 
   se::MultiKernelLoaderSpec spec(/*arity=*/2, std::move(packing));
   spec.AddInProcessSymbol(kernel.symbol(), name);
@@ -204,33 +205,83 @@
 }
 
 absl::StatusOr<std::vector<CustomKernel>> GetCutlassGemmKernels(
-    std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k,
+    std::string name, PrimitiveType dot_type, PrimitiveType lhs_type,
+    PrimitiveType rhs_type, int32_t m, int32_t n, int32_t k,
     const ArgsIndices& indices, const DynamicSliceIndices& slices,
     const se::DeviceDescription& device) {
   auto& cuda_cc =
       std::get<se::CudaComputeCapability>(device.gpu_compute_capability());
 
-  switch (dtype) {
-    case PrimitiveType::F32:
-      return {{Load<F32xF32ToF32<Default>>(std::move(name), m, n, k, indices,
-                                           slices, device)}};
-    case PrimitiveType::BF16:
+  if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::F32 &&
+      rhs_type == PrimitiveType::F32) {
+    return {{Load<F32xF32ToF32<Default>>(std::move(name), GemmMode::kGemm,
+                                         /*batch_count=*/1, m, n, k, indices,
+                                         slices, device)}};
+  }
+
+  if (dot_type == PrimitiveType::BF16 && lhs_type == PrimitiveType::BF16 &&
+      rhs_type == PrimitiveType::BF16) {
 #if CUDA_VERSION >= 12000
       if (cuda_cc.IsAtLeastHopper()) {
-        return {{Load<Bf16xBf16ToBf16<Sm90>>(std::move(name), m, n, k, indices,
-                                             slices, device)}};
+        return {{Load<Bf16xBf16ToBf16<Sm90>>(std::move(name), GemmMode::kGemm,
+                                             /*batch_count=*/1, m, n, k,
+                                             indices, slices, device)}};
       }
 #endif
       if (cuda_cc.IsAtLeastAmpere()) {
-        return {{Load<Bf16xBf16ToBf16<Sm80>>(std::move(name), m, n, k, indices,
-                                             slices, device)}};
+        return {{Load<Bf16xBf16ToBf16<Default>>(
+            std::move(name), GemmMode::kGemm, /*batch_count=*/1, m, n, k,
+            indices, slices, device)}};
       }
-      return {{Load<Bf16xBf16ToBf16<Default>>(std::move(name), m, n, k, indices,
-                                              slices, device)}};
-
-    default:
-      return absl::InvalidArgumentError("Unsupported CUTLASS gemm data type");
+      return {{Load<Bf16xBf16ToBf16<Default>>(std::move(name), GemmMode::kGemm,
+                                              /*batch_count=*/1, m, n, k,
+                                              indices, slices, device)}};
   }
+
+  if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::BF16 &&
+      rhs_type == PrimitiveType::BF16) {
+    return {{Load<Bf16xBf16ToF32<Default>>(std::move(name), GemmMode::kGemm,
+                                           /*batch_count=*/1, m, n, k, indices,
+                                           slices, device)}};
+  }
+
+  if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::BF16 &&
+      rhs_type == PrimitiveType::F32) {
+    return {{Load<Bf16xF32ToF32<Default>>(name, GemmMode::kGemm,
+                                          /*batch_count=*/1, m, n, k, indices,
+                                          slices, device),
+             Load<Bf16xF32ToF32<Default>>(name, GemmMode::kGemmSplitKParallel,
+                                          /*batch_count=*/16, m, n, k, indices,
+                                          slices, device)}};
+  }
+
+  if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::F32 &&
+      rhs_type == PrimitiveType::BF16) {
+    return {{Load<F32xBf16ToF32<Default>>(name, GemmMode::kGemm,
+                                          /*batch_count=*/1, m, n, k, indices,
+                                          slices, device),
+             Load<F32xBf16ToF32<Default>>(name, GemmMode::kGemmSplitKParallel,
+                                          /*batch_count=*/16, m, n, k, indices,
+                                          slices, device)}};
+  }
+
+  if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::BF16 &&
+      rhs_type == PrimitiveType::S8) {
+    return {{
+        Load<Bf16xS8ToF32<Default>>(name, GemmMode::kGemm,
+                                    /*batch_count=*/1, m, n, k, indices, slices,
+                                    device),
+        Load<Bf16xS8ToF32<Default>>(name, GemmMode::kGemmSplitKParallel,
+                                    /*batch_count=*/16, m, n, k, indices,
+                                    slices, device),
+    }};
+  }
+
+  std::string kernel_name = PrimitiveType_Name(lhs_type) + "x" +
+                            PrimitiveType_Name(rhs_type) + "To" +
+                            PrimitiveType_Name(dot_type);
+  return absl::InvalidArgumentError(absl::StrCat(
+      "Unsupported CUTLASS gemm data type for kernel: ", kernel_name));
 }
 
 absl::StatusOr<CustomKernel> LoadCutlassGemmKernel(
@@ -250,8 +301,9 @@
         "Failed to load CUTLASS kernel from a shared library: ", library_path));
   }
 
-  return Load<DlOpenedKernel>(std::move(name), m, n, k, indices, slices, device,
-                              *adaptor, *kernel);
+  return Load<DlOpenedKernel>(std::move(name), GemmMode::kGemm,
+                              /*batch_count=*/1, m, n, k, indices, slices,
+                              device, *adaptor, *kernel);
 }
 
 }  // namespace xla::gpu::kernel::gemm_universal
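A hedged usage sketch of the widened GetCutlassGemmKernels entry point above, written as if inside the xla::gpu::kernel::gemm_universal namespace; `device` is assumed to be an se::DeviceDescription obtained from the stream executor, and the problem size is illustrative.

absl::StatusOr<std::vector<CustomKernel>> kernels = GetCutlassGemmKernels(
    "cutlass_gemm", /*dot_type=*/PrimitiveType::F32,
    /*lhs_type=*/PrimitiveType::BF16, /*rhs_type=*/PrimitiveType::F32,
    /*m=*/256, /*n=*/128, /*k=*/4096,
    /*indices=*/{0, 1, 2}, /*slices=*/{}, device);
// For this type mix the dispatch above returns two candidates: a plain kGemm
// kernel and a kGemmSplitKParallel kernel with 16 K-slices; the caller (e.g.
// the custom fusion autotuner) picks one of them via "kernel_index".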
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h
index 37531ef..04b0925 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h
@@ -30,7 +30,8 @@
 
 // Returns pre-compiled custom kernels for a given data type and problem size.
 absl::StatusOr<std::vector<CustomKernel>> GetCutlassGemmKernels(
-    std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k,
+    std::string name, PrimitiveType dot_type, PrimitiveType lhs_type,
+    PrimitiveType rhs_type, int32_t m, int32_t n, int32_t k,
     const ArgsIndices& indices, const DynamicSliceIndices& slices,
     const se::DeviceDescription& device);
 
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc
index 22edaea..124569e 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc
@@ -54,7 +54,8 @@
 
   TF_ASSERT_OK_AND_ASSIGN(
       auto custom_kernels,
-      GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::BF16, m, n, k,
+      GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::BF16,
+                            PrimitiveType::BF16, PrimitiveType::BF16, m, n, k,
                             /*indices=*/{0, 1, 2}, /*slices=*/{}, device));
   const auto& custom_kernel = custom_kernels[0];
 
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc
index 8e231ee..d95241b 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc
@@ -26,7 +26,8 @@
 namespace xla::gpu::kernel::gemm_universal {
 
 absl::StatusOr<std::vector<CustomKernel>> GetCutlassGemmKernels(
-    std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k,
+    std::string name, PrimitiveType dot_type, PrimitiveType lhs_type,
+    PrimitiveType rhs_type, int32_t m, int32_t n, int32_t k,
     const ArgsIndices& indices, const DynamicSliceIndices& slices,
     const se::DeviceDescription& device) {
   return absl::InternalError("XLA compiled without CUDA support");
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc
index 4c0a586..7cdc950 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc
@@ -44,7 +44,8 @@
   // Load [4, 4] x [4, 4] gemm kernel written in CUDA C++ with CUTLASS.
   TF_ASSERT_OK_AND_ASSIGN(
       auto custom_kernels,
-      GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::F32, 4, 4, 4,
+      GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::F32,
+                            PrimitiveType::F32, PrimitiveType::F32, 4, 4, 4,
                             /*indices=*/{0, 1, 2}, /*slices=*/{},
                             executor->GetDeviceDescription()));
   auto custom_kernel = custom_kernels[0];
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc
index a392801..946e1f8 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc
@@ -15,6 +15,7 @@
 
 #include "xla/service/gpu/kernels/cutlass_gemm_fusion.h"
 
+#include <algorithm>
 #include <cstddef>
 #include <cstdint>
 #include <optional>
@@ -23,6 +24,7 @@
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
@@ -136,14 +138,22 @@
   return absl::InternalError("unsupported operands type");
 }
 
-// Returns matched GEMM with one of the operands upcasted to the accumulator
-// data type with an HLO convert instruction.
+// Returns matched GEMM with one or both of the operands upcasted to the
+// accumulator data type with an HLO convert instruction.
 static absl::StatusOr<GemmWithUpcast> MatchGemmWithUpcast(
     HloDotInstruction* dot) {
   TF_RETURN_IF_ERROR(MatchRowMajorGemm(dot));
 
   GemmWithUpcast match(dot);
 
+  // C <- convert(A) * convert(B)
+  if (Match(const_cast<HloInstruction*>(dot->operand(0)),
+            m::Convert(&match.lhs_upcast, m::Op())) &&
+      Match(const_cast<HloInstruction*>(dot->operand(1)),
+            m::Convert(&match.rhs_upcast, m::Op()))) {
+    return match;
+  }
+
   // C <- convert(A) * B
   if (Match(const_cast<HloInstruction*>(dot->operand(0)),
             m::Convert(&match.lhs_upcast, m::Op()))) {
@@ -254,16 +264,19 @@
   if (!dot) return std::nullopt;
 
   auto matched = MatchGemmWithUpcast(dot);
-  if (!matched.ok()) return std::nullopt;
 
-  // Only one operand can be upcasted.
-  DCHECK(matched->lhs_upcast == nullptr || matched->rhs_upcast == nullptr);
+  if (!matched.ok()) return std::nullopt;
 
   CustomFusionConfig config;
   config.set_name("cutlass_gemm_with_upcast");
 
-  return matched->lhs_upcast ? Match{config, {matched->lhs_upcast, instr}}
-                             : Match{config, {matched->rhs_upcast, instr}};
+  if (matched->lhs_upcast != nullptr && matched->rhs_upcast == nullptr) {
+    return Match{config, {matched->lhs_upcast, instr}};
+  } else if (matched->rhs_upcast != nullptr && matched->lhs_upcast == nullptr) {
+    return Match{config, {matched->rhs_upcast, instr}};
+  } else {
+    return Match{config, {matched->lhs_upcast, matched->rhs_upcast, instr}};
+  }
 }
 
 //===----------------------------------------------------------------------===//
@@ -283,7 +296,7 @@
 
     TF_RETURN_IF_ERROR(MatchSimpleGemm(dot, {PrimitiveType::F32}));
 
-    auto dtype = dot->shape().element_type();
+    PrimitiveType dot_type = dot->shape().element_type();
 
     auto* lhs = Cast<HloParameterInstruction>(dot->operand(0));
     auto* rhs = Cast<HloParameterInstruction>(dot->operand(1));
@@ -293,15 +306,19 @@
         lhs->parameter_number(), rhs->parameter_number(),
         computation->num_parameters()};
 
-    auto& lhs_shape = lhs->shape();
-    auto& rhs_shape = rhs->shape();
+    const Shape& lhs_shape = lhs->shape();
+    const Shape& rhs_shape = rhs->shape();
 
     size_t m = lhs_shape.dimensions(0);
     size_t k = lhs_shape.dimensions(1);
     size_t n = rhs_shape.dimensions(1);
 
-    return kernel::gemm_universal::GetCutlassGemmKernels(
-        "cutlass_gemm", dtype, m, n, k, indices, /*slices=*/{}, device);
+    PrimitiveType lhs_type = lhs->shape().element_type();
+    PrimitiveType rhs_type = rhs->shape().element_type();
+
+    return GetCutlassGemmKernels("cutlass_gemm", dot_type, lhs_type, rhs_type,
+                                 m, n, k, indices,
+                                 /*slices=*/{}, device);
   }
 };
 
@@ -313,23 +330,44 @@
     auto* dot = DynCast<HloDotInstruction>(computation->root_instruction());
     if (dot == nullptr) {
       return absl::InternalError(
-          "cutlass_gemm requires ROOT operation to be a dot");
+          "cutlass_gemm_with_upcast requires ROOT operation to be a dot");
     }
 
-    TF_ASSIGN_OR_RETURN(auto matched, MatchGemmWithUpcast(dot));
+    TF_ASSIGN_OR_RETURN(GemmWithUpcast matched, MatchGemmWithUpcast(dot));
 
-    // We only support upcasting of rhs operand.
-    if (matched.lhs_upcast != nullptr)
-      return absl::InternalError("only rhs upcasting is implemented");
+    const HloParameterInstruction* lhs;
+    const HloParameterInstruction* rhs;
 
-    auto dot_dtype = dot->shape().element_type();
-    auto upcast_dtype = matched.rhs_upcast->shape().element_type();
+    if (matched.lhs_upcast == nullptr && matched.rhs_upcast != nullptr) {
+      lhs = Cast<HloParameterInstruction>(matched.dot->operand(0));
+      rhs = Cast<HloParameterInstruction>(matched.rhs_upcast->operand(0));
+    } else if (matched.lhs_upcast != nullptr && matched.rhs_upcast == nullptr) {
+      lhs = Cast<HloParameterInstruction>(matched.lhs_upcast->operand(0));
+      rhs = Cast<HloParameterInstruction>(matched.dot->operand(1));
+    } else {
+      lhs = Cast<HloParameterInstruction>(matched.lhs_upcast->operand(0));
+      rhs = Cast<HloParameterInstruction>(matched.rhs_upcast->operand(0));
+    }
 
-    // We only support BF16 <- BF16 x S8 upcasted gemm.
-    if (dot_dtype != PrimitiveType::BF16 || upcast_dtype != PrimitiveType::S8)
-      return absl::InternalError("unsupported upcasting pattern");
+    const Shape& lhs_shape = lhs->shape();
+    const Shape& rhs_shape = rhs->shape();
 
-    return absl::UnimplementedError("requires CUTLASS 3.3.0");
+    size_t m = lhs_shape.dimensions(0);
+    size_t k = lhs_shape.dimensions(1);
+    size_t n = rhs_shape.dimensions(1);
+
+    PrimitiveType dot_type = dot->shape().element_type();
+    PrimitiveType lhs_type = lhs_shape.element_type();
+    PrimitiveType rhs_type = rhs_shape.element_type();
+
+    // Mapping from fusion arguments to gemm kernel arguments.
+    kernel::gemm_universal::ArgsIndices args_indices = {
+        lhs->parameter_number(), rhs->parameter_number(),
+        computation->num_parameters()};
+
+    return GetCutlassGemmKernels("cutlass_gemm_with_upcast", dot_type, lhs_type,
+                                 rhs_type, m, n, k, args_indices, /*slices=*/{},
+                                 device);
   }
 };
 
@@ -353,7 +391,7 @@
         MatchSimpleGemm(Cast<HloDotInstruction>(matched.dot),
                         {PrimitiveType::F32, PrimitiveType::BF16}));
 
-    auto dtype = matched.dot->shape().element_type();
+    auto dot_type = matched.dot->shape().element_type();
 
     auto* lhs = Cast<HloParameterInstruction>(matched.dot->operand(0));
     auto* rhs = Cast<HloParameterInstruction>(matched.dot->operand(1));
@@ -370,21 +408,25 @@
     kernel::gemm_universal::DynamicSliceIndices slices;
     slices.out = offset->parameter_number();
 
-    auto& lhs_shape = lhs->shape();
-    auto& rhs_shape = rhs->shape();
+    const Shape& lhs_shape = lhs->shape();
+    const Shape& rhs_shape = rhs->shape();
 
     size_t m = lhs_shape.dimensions(0);
     size_t k = lhs_shape.dimensions(1);
     size_t n = rhs_shape.dimensions(1);
 
-    return kernel::gemm_universal::GetCutlassGemmKernels(
-        "cutlass_gemm_with_dynamic_update_slice", dtype, m, n, k, args_indices,
-        slices, device);
+    PrimitiveType lhs_type = lhs->shape().element_type();
+    PrimitiveType rhs_type = rhs->shape().element_type();
+
+    return GetCutlassGemmKernels("cutlass_gemm_with_dynamic_update_slice",
+                                 dot_type, lhs_type, rhs_type, m, n, k,
+                                 args_indices, slices, device);
   }
 };
 
 }  // namespace xla::gpu
 
+XLA_REGISTER_CUSTOM_FUSION_PATTERN(::xla::gpu::CutlassGemmWithUpcastPattern);
 XLA_REGISTER_CUSTOM_FUSION_PATTERN(
     ::xla::gpu::CutlassGemmWithDynamicUpdateSlicePattern);
 
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc
index a6488e9..768feaf 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc
@@ -17,18 +17,21 @@
 
 #include <cstdint>
 #include <utility>
+#include <vector>
 
+#include <gtest/gtest.h>
 #include "xla/array.h"
 #include "xla/array2d.h"
 #include "xla/array3d.h"
 #include "xla/error_spec.h"
 #include "xla/literal_util.h"
-#include "xla/service/gpu/custom_kernel_fusion_rewriter.h"
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
 #include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
 #include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h"
+#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h"
 #include "xla/tests/hlo_test_base.h"
 #include "xla/types.h"
+#include "xla/xla_data.pb.h"
 #include "tsl/platform/test.h"
 
 namespace xla::gpu {
@@ -41,10 +44,12 @@
         ->GetDeviceDescription()
         .shared_memory_per_block_optin();
   }
-  int CutlassGemmKernelSharedMemorySize(PrimitiveType dtype, int m, int n,
+  int CutlassGemmKernelSharedMemorySize(PrimitiveType dot_type,
+                                        PrimitiveType lhs_type,
+                                        PrimitiveType rhs_type, int m, int n,
                                         int k) {
     return kernel::gemm_universal::GetCutlassGemmKernels(
-               "cutlass_gemm", dtype, m, n, k,
+               "cutlass_gemm", dot_type, lhs_type, rhs_type, m, n, k,
                /*indices=*/{0, 1, 2}, /*slices=*/{},
                backend().default_stream_executor()->GetDeviceDescription())
         ->at(0)
@@ -134,6 +139,48 @@
   RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
 }
 
+TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastOfBothOperands) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main (p0: bf16[15,19], p1: bf16[19,17]) -> f32[15,17] {
+      %p0 = bf16[15,19]{1,0} parameter(0)
+      %c1 = f32[15,19]{1,0} convert(%p0)
+      %p1 = bf16[19,17]{1,0} parameter(1)
+      %c2 = f32[19,17]{1,0} convert(%p1)
+      ROOT %r = f32[15,17]{1,0} dot(%c1, %c2),
+        lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK: %cutlass_gemm_with_upcast {{.*}} {
+    ; CHECK-DAG: [[P0:%[^ ]+]] = bf16[15,19]{1,0} parameter
+    ; CHECK:     [[C1:%[^ ]+]] = f32[15,19]{1,0} convert([[P0]])
+    ; CHECK-DAG: [[P1:%[^ ]+]] = bf16[19,17]{1,0} parameter
+    ; CHECK:     [[C2:%[^ ]+]] = f32[19,17]{1,0} convert([[P1]])
+    ; CHECK:     ROOT [[DOT:%[^ ]+]] = f32[15,17]{1,0} dot([[C1]], [[C2]]),
+    ; CHECK:       lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    ; CHECK: }
+
+    ; CHECK: ENTRY %main {{.*}} {
+    ; CHECK:   ROOT [[FUSION:%[^ ]+]] = f32[15,17]{1,0} fusion
+    ; CHECK:     kind=kCustom, calls=%cutlass_gemm_with_upcast,
+    ; CHECK:     backend_config={
+    ; CHECK:       "kind":"__custom_fusion",
+    ; CHECK:       "custom_fusion_config":{"name":"cutlass_gemm_with_upcast","kernel_index":0}
+    ; CHECK:     }
+    ; CHECK: }
+  )";
+
+  CustomKernelFusionPatternRegistry patterns;
+  patterns.Emplace<CutlassGemmWithUpcastPattern>();
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  CustomKernelFusionRewriter pass(&device, &patterns);
+  RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
+}
+
 TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSlice) {
   const char* hlo = R"(
     HloModule test
@@ -329,9 +376,7 @@
                                       error_spec, /*run_hlo_passes=*/false));
 }
 
-TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastKernel) {
-  GTEST_SKIP() << "Requires CUTLASS 3.3.0+";
-
+TEST_F(CutlassFusionTest, GemmWithLeftHandSideUpcastKernel) {
   ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
 
   const char* hlo_text_cublas = R"(
@@ -339,12 +384,12 @@
 
   ENTRY e {
     p0 = bf16[16,32]{1,0} parameter(0)
-    p1 = s8[32,8]{1,0} parameter(1)
-    c1 = bf16[32,8]{1,0} convert(p1)
-    gemm = (bf16[16,8]{1,0}, s8[0]{0}) custom-call(p0, c1),
+    c0 = f32[16,32]{1,0} convert(p0)
+    p1 = f32[32,8]{1,0} parameter(1)
+    gemm = (f32[16,8]{1,0}, s8[0]{0}) custom-call(c0, p1),
       custom_call_target="__cublas$gemm",
       backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}}
-    ROOT get-tuple-element = bf16[16,8]{1,0} get-tuple-element(gemm), index=0
+    ROOT get-tuple-element = f32[16,8]{1,0} get-tuple-element(gemm), index=0
   })";
 
   const char* hlo_text_custom_fusion = R"(
@@ -352,16 +397,97 @@
 
   cutlass_gemm_with_upcast {
     p0 = bf16[16,32]{1,0} parameter(0)
+    c0 = f32[16,32]{1,0} convert(p0)
+    p1 = f32[32,8]{1,0} parameter(1)
+    ROOT dot = f32[16,8]{1,0} dot(c0, p1),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  }
+
+  ENTRY e {
+    p0 = bf16[16,32]{1,0} parameter(0)
+    p1 = f32[32,8]{1,0} parameter(1)
+    ROOT _ = f32[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast,
+      backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm_with_upcast", "kernel_index":0}}}
+  })";
+
+  EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion,
+                                      error_spec, /*run_hlo_passes=*/false));
+}
+
+TEST_F(CutlassFusionTest, GemmWithRightHandSideUpcastKernel) {
+  ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
+
+  const char* hlo_text_cublas = R"(
+  HloModule cublas
+
+  ENTRY e {
+    p0 = f32[16,32]{1,0} parameter(0)
+    p1 = bf16[32,8]{1,0} parameter(1)
+    c1 = f32[32,8]{1,0} convert(p1)
+    gemm = (f32[16,8]{1,0}, s8[0]{0}) custom-call(p0, c1),
+      custom_call_target="__cublas$gemm",
+      backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}}
+    ROOT get-tuple-element = f32[16,8]{1,0} get-tuple-element(gemm), index=0
+  })";
+
+  const char* hlo_text_custom_fusion = R"(
+  HloModule cutlass
+
+  cutlass_gemm_with_upcast {
+    p0 = f32[16,32]{1,0} parameter(0)
+    p1 = bf16[32,8]{1,0} parameter(1)
+    c1 = f32[32,8]{1,0} convert(p1)
+    ROOT dot = f32[16,8]{1,0} dot(p0, c1),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  }
+
+  ENTRY e {
+    p0 = f32[16,32]{1,0} parameter(0)
+    p1 = bf16[32,8]{1,0} parameter(1)
+    ROOT _ = f32[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast,
+      backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm_with_upcast", "kernel_index":0}}}
+  })";
+
+  EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion,
+                                      error_spec, /*run_hlo_passes=*/false));
+}
+
+TEST_F(CutlassFusionTest, GemmWithLeftHandAndRightHandSideUpcastKernel) {
+  ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
+
+  const char* hlo_text_cublas = R"(
+  HloModule cublas
+
+  ENTRY e {
+    p0 = bf16[16,32]{1,0} parameter(0)
+    c0 = f32[16,32]{1,0} convert(p0)
     p1 = s8[32,8]{1,0} parameter(1)
-    c1 = bf16[32,8]{1,0} convert(p1)
-    ROOT dot = bf16[16,8]{1,0} dot(p0, c1),
+    c1 = f32[32,8]{1,0} convert(p1)
+    gemm = (f32[16,8]{1,0}, s8[0]{0}) custom-call(c0, c1),
+      custom_call_target="__cublas$gemm",
+      backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}}
+    ROOT get-tuple-element = f32[16,8]{1,0} get-tuple-element(gemm), index=0
+  })";
+
+  const char* hlo_text_custom_fusion = R"(
+  HloModule cutlass
+
+  cutlass_gemm_with_upcast {
+    p0 = bf16[16,32]{1,0} parameter(0)
+    c0 = f32[16,32]{1,0} convert(p0)
+    p1 = s8[32,8]{1,0} parameter(1)
+    c1 = f32[32,8]{1,0} convert(p1)
+    ROOT dot = f32[16,8]{1,0} dot(c0, c1),
       lhs_contracting_dims={1}, rhs_contracting_dims={0}
   }
 
   ENTRY e {
     p0 = bf16[16,32]{1,0} parameter(0)
     p1 = s8[32,8]{1,0} parameter(1)
-    ROOT _ = bf16[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast,
+    ROOT _ = f32[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast,
       backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm_with_upcast", "kernel_index":0}}}
   })";
 
@@ -371,7 +497,7 @@
 
 TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceKernel) {
   if (GpuSharedMemorySize() <
-      CutlassGemmKernelSharedMemorySize(BF16, 8, 8, 8)) {
+      CutlassGemmKernelSharedMemorySize(BF16, BF16, BF16, 8, 8, 8)) {
     GTEST_SKIP_("The GPU does not have sufficient shared memory");
   }
 
@@ -445,7 +571,7 @@
 TEST_F(CutlassFusionTest,
        RowMajorGemmWithDynamicUpdateSliceKernelWithoutBitcast) {
   if (GpuSharedMemorySize() <
-      CutlassGemmKernelSharedMemorySize(BF16, 8, 8, 8)) {
+      CutlassGemmKernelSharedMemorySize(BF16, BF16, BF16, 8, 8, 8)) {
     GTEST_SKIP_("The GPU does not have sufficient shared memory");
   }
 
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc
new file mode 100644
index 0000000..ec08008
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc
@@ -0,0 +1,51 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/arch/mma.h"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/kernel/default_gemm_universal.h"
+#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h"
+
+namespace xla::gpu::kernel::gemm_universal {
+
+namespace {
+
+using ElementA = cutlass::bfloat16_t;
+using ElementB = cutlass::bfloat16_t;
+using ElementOutput = float;
+using ElementAccumulator = float;
+
+}  // namespace
+
+using GemmOperation = cutlass::gemm::device::GemmUniversal<
+    ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::RowMajor,
+    ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
+    cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
+    cutlass::gemm::GemmShape<128, 32, 8>, cutlass::gemm::GemmShape<64, 32, 8>,
+    cutlass::gemm::GemmShape<1, 1, 1>,
+    cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
+    2,  // stages
+    1,  // A alignment
+    1,  // B alignment
+    cutlass::arch::OpMultiplyAdd>;
+
+XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(Bf16xBf16ToF32<Arch::kDefault>,
+                                   GemmOperation);
+template struct Adaptor<Bf16xBf16ToF32<Arch::kDefault>>;
+template struct DeviceKernel<Bf16xBf16ToF32<Arch::kDefault>>;
+
+}  // namespace xla::gpu::kernel::gemm_universal
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc
new file mode 100644
index 0000000..e117b1a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc
@@ -0,0 +1,51 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/arch/mma.h"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/kernel/default_gemm_universal.h"
+#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h"
+
+namespace xla::gpu::kernel::gemm_universal {
+
+namespace {
+
+using ElementA = cutlass::bfloat16_t;
+using ElementB = float;
+using ElementOutput = float;
+using ElementAccumulator = float;
+
+}  // namespace
+
+using GemmOperation = cutlass::gemm::device::GemmUniversal<
+    ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::RowMajor,
+    ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
+    cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
+    cutlass::gemm::GemmShape<128, 32, 8>, cutlass::gemm::GemmShape<64, 32, 8>,
+    cutlass::gemm::GemmShape<1, 1, 1>,
+    cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
+    2,  // stages
+    1,  // A alignment
+    1,  // B alignment
+    cutlass::arch::OpMultiplyAdd>;
+
+XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(Bf16xF32ToF32<Arch::kDefault>,
+                                   GemmOperation);
+template struct Adaptor<Bf16xF32ToF32<Arch::kDefault>>;
+template struct DeviceKernel<Bf16xF32ToF32<Arch::kDefault>>;
+
+}  // namespace xla::gpu::kernel::gemm_universal
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc
new file mode 100644
index 0000000..527d369
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc
@@ -0,0 +1,50 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/arch/mma.h"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/kernel/default_gemm_universal.h"
+#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h"
+
+namespace xla::gpu::kernel::gemm_universal {
+
+namespace {
+
+using ElementA = cutlass::bfloat16_t;
+using ElementB = cutlass::int8_t;
+using ElementOutput = float;
+using ElementAccumulator = float;
+
+}  // namespace
+
+using GemmOperation = cutlass::gemm::device::GemmUniversal<
+    ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::RowMajor,
+    ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
+    cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
+    cutlass::gemm::GemmShape<64, 128, 8>, cutlass::gemm::GemmShape<32, 64, 8>,
+    cutlass::gemm::GemmShape<1, 1, 1>,
+    cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
+    2,  // stages
+    1,  // A alignment
+    1,  // B alignment
+    cutlass::arch::OpMultiplyAdd>;
+
+XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(Bf16xS8ToF32<Arch::kDefault>, GemmOperation);
+template struct Adaptor<Bf16xS8ToF32<Arch::kDefault>>;
+template struct DeviceKernel<Bf16xS8ToF32<Arch::kDefault>>;
+
+}  // namespace xla::gpu::kernel::gemm_universal
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc
new file mode 100644
index 0000000..6ec6963
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc
@@ -0,0 +1,51 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/arch/mma.h"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/kernel/default_gemm_universal.h"
+#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h"
+
+namespace xla::gpu::kernel::gemm_universal {
+
+namespace {
+
+using ElementA = float;
+using ElementB = cutlass::bfloat16_t;
+using ElementOutput = float;
+using ElementAccumulator = float;
+
+}  // namespace
+
+using GemmOperation = cutlass::gemm::device::GemmUniversal<
+    ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::RowMajor,
+    ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
+    cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
+    cutlass::gemm::GemmShape<128, 32, 8>, cutlass::gemm::GemmShape<64, 32, 8>,
+    cutlass::gemm::GemmShape<1, 1, 1>,
+    cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
+    2,  // stages
+    1,  // A alignment
+    1,  // B alignment
+    cutlass::arch::OpMultiplyAdd>;
+
+XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(F32xBf16ToF32<Arch::kDefault>,
+                                   GemmOperation);
+template struct Adaptor<F32xBf16ToF32<Arch::kDefault>>;
+template struct DeviceKernel<F32xBf16ToF32<Arch::kDefault>>;
+
+}  // namespace xla::gpu::kernel::gemm_universal
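The four new translation units above all follow the same recipe: fix the element types, instantiate cutlass::gemm::device::GemmUniversal with the SIMT op class and an Sm70 baseline, and stamp out the traits, adaptor, and device kernel. As a purely illustrative sketch of that recipe, an f32 x s8 -> f32 variant (not added by this change; the F32xS8ToF32 tag is hypothetical and would first need a declaration in cutlass_gemm.h) would look like:

// Includes and namespace mirror the files above (cutlass_gemm_adaptor.cu.h
// plus the CUTLASS headers); shown here only to document the recipe.
using GemmOperation = cutlass::gemm::device::GemmUniversal<
    float, cutlass::layout::RowMajor, int8_t, cutlass::layout::RowMajor,
    float, cutlass::layout::RowMajor, float,
    cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
    cutlass::gemm::GemmShape<128, 32, 8>, cutlass::gemm::GemmShape<64, 32, 8>,
    cutlass::gemm::GemmShape<1, 1, 1>,
    cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
    2,  // stages
    1,  // A alignment
    1,  // B alignment
    cutlass::arch::OpMultiplyAdd>;

XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(F32xS8ToF32<Arch::kDefault>, GemmOperation);
template struct Adaptor<F32xS8ToF32<Arch::kDefault>>;
template struct DeviceKernel<F32xS8ToF32<Arch::kDefault>>;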
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc
index 5aff534..119d724 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc
@@ -51,14 +51,14 @@
 extern "C" bool xla_cutlass_kernel_can_implement(int32_t m, int32_t n,
                                                  int32_t k) {
   Adaptor<CutlassGemm> adaptor;
-  Arguments arguments = {m, n, k};
+  Arguments arguments = {GemmMode::kGemm, /*batch_count=*/1, m, n, k};
   return adaptor.CanImplement(arguments);
 }
 
 extern "C" int64_t xla_cutlass_kernel_workspace_size(int32_t m, int32_t n,
                                                      int32_t k) {
   Adaptor<CutlassGemm> adaptor;
-  Arguments arguments = {m, n, k};
+  Arguments arguments = {GemmMode::kGemm, /*batch_count=*/1, m, n, k};
   return adaptor.WorkspaceSize(arguments);
 }
 
@@ -67,7 +67,9 @@
     void* out, void* workspace, int32_t* out_offset, int32_t device_sms,
     int32_t sm_occupancy) {
   Adaptor<CutlassGemm> adaptor;
-  Arguments arguments = {m, n, k, lhs, rhs, out, workspace, {out_offset}};
+  Arguments arguments = {
+      GemmMode::kGemm, /*batch_count=*/1, m, n, k, lhs, rhs, out,
+      workspace,       {out_offset}};
   adaptor.Initialize(params, arguments, device_sms, sm_occupancy);
 }
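For context on the exported C symbols above: they exist so that out-of-tree kernels can be loaded through dlopen, as the DlOpenedKernel path earlier in this change does. A hedged consumption sketch (library path illustrative, error handling minimal):

#include <dlfcn.h>
#include <cstdint>

// Probes a dlopened CUTLASS kernel library for a given problem size.
bool ProbeCutlassKernel(const char* library_path) {
  void* handle = dlopen(library_path, RTLD_NOW);
  if (handle == nullptr) return false;
  auto* can_implement = reinterpret_cast<bool (*)(int32_t, int32_t, int32_t)>(
      dlsym(handle, "xla_cutlass_kernel_can_implement"));
  return can_implement != nullptr &&
         can_implement(/*m=*/256, /*n=*/128, /*k=*/4096);
}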
 
diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD
index 8951c27..8fc3db5 100644
--- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD
+++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD
@@ -3,6 +3,10 @@
     "if_rocm_is_configured",
 )
 load(
+    "@local_config_sycl//sycl:build_defs.bzl",
+    "if_sycl_is_configured",
+)
+load(
     "@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
@@ -88,6 +92,8 @@
         "@local_config_rocm//rocm:rocm_headers",
         "@llvm-project//llvm:AMDGPUCodeGen",
         "@llvm-project//llvm:AMDGPUAsmParser",
+    ]) + if_sycl_is_configured([
+        "@spirv_llvm_translator//:spirv_llvm_translator",
     ]),
 )
 
@@ -106,3 +112,16 @@
         "@local_tsl//tsl/platform:test",
     ],
 )
+
+xla_cc_test(
+    name = "gpu_backend_lib_test",
+    size = "small",
+    srcs = ["gpu_backend_lib_test.cc"],
+    deps = [
+        ":llvm_gpu_backend",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:xla_internal_test_main",
+        "@llvm-project//llvm:Core",
+        "@local_tsl//tsl/platform:test",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 13e17bb..696e360 100644
--- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -17,6 +17,7 @@
 
 #include <algorithm>
 #include <cstdint>
+#include <cstdlib>
 #include <fstream>
 #include <functional>
 #include <ios>
@@ -105,6 +106,11 @@
 #include "xla/stream_executor/cuda/cuda_asm_compiler.h"
 #endif
 
+#if TENSORFLOW_USE_SYCL
+#include "LLVMSPIRVLib.h"
+#include "LLVMSPIRVOpts.h"
+#endif  // TENSORFLOW_USE_SYCL
+
 namespace xla {
 namespace gpu {
 namespace {
@@ -117,41 +123,6 @@
 // Default inline threshold value to use in llvm.
 const int kDefaultInlineThreshold = 1100;
 
-// Gets the GPU name as it's known to LLVM for a given compute
-// capability.  If we see an unrecognized compute capability, we
-// return the highest one that is known and below the selected device.
-static std::string GetSmName(se::CudaComputeCapability compute_capability) {
-  int compute_capability_version =
-      compute_capability.major * 10 + compute_capability.minor;
-  int sm_version = 30;
-  // If the current compute capability isn't known, fallback to the
-  // most recent version before it.
-  int supported_versions[] = {90, 89, 87, 86, 80, 75, 72, 70, 62,
-                              61, 60, 53, 52, 50, 37, 35, 32, 30};
-  for (int v : supported_versions) {
-    if (v <= compute_capability_version) {
-      sm_version = v;
-      break;
-    }
-  }
-
-  // If the current CC isn't supported by LLVM and it is newer then
-  // the max supported LLVM version, do not warn about it. The end
-  // user can't do anything about this. E.g., PTX compiled for SM75 will
-  // run on SM80 too.
-  if (sm_version != compute_capability_version &&
-      compute_capability_version < supported_versions[0]) {
-    LOG(WARNING) << "Unknown compute capability "
-                 << compute_capability.ToString()
-                 << ". Defaulting to telling LLVM that we're compiling for sm_"
-                 << sm_version;
-  }
-  // If the target is sm_90, hard code it to sm_90a so that all instructions
-  // can be used. We don't need the portability that sm_90 gives.
-  std::string_view extension = sm_version == 90 ? "a" : "";
-  return absl::StrCat("sm_", sm_version, extension);
-}
-
 // NOLINTBEGIN: clang-diagnostic-unused-function
 // Convenience function for producing a name of a temporary compilation product
 // from the input filename.
@@ -378,7 +349,7 @@
 #else
   std::string feature_str;
 #endif  // GOOGLE_CUDA
-  return GetTargetMachine(target_triple, GetSmName(compute_capability),
+  return GetTargetMachine(target_triple, nvptx::GetSmName(compute_capability),
                           debug_options, feature_str);
 }
 
@@ -452,7 +423,9 @@
   llvm::CGSCCAnalysisManager cgam;
   llvm::ModuleAnalysisManager mam;
 
-  fam.registerPass([&] { return target_machine->getTargetIRAnalysis(); });
+  if (target_machine) {
+    fam.registerPass([&] { return target_machine->getTargetIRAnalysis(); });
+  }
 
   llvm::PipelineTuningOptions pto;
   pto.SLPVectorization = true;
@@ -569,6 +542,40 @@
 
 namespace nvptx {
 
+std::string GetSmName(se::CudaComputeCapability compute_capability) {
+  int compute_capability_version =
+      compute_capability.major * 10 + compute_capability.minor;
+  int sm_version = 30;
+  // If the current compute capability isn't known, fall back to the
+  // most recent version before it.
+  int supported_versions[] = {90, 89, 87, 86, 80, 75, 72, 70, 62,
+                              61, 60, 53, 52, 50, 37, 35, 32, 30};
+  for (int v : supported_versions) {
+    if (v <= compute_capability_version) {
+      sm_version = v;
+      break;
+    }
+  }
+
+  // If the current CC isn't supported by LLVM and it is newer than
+  // the max supported LLVM version, do not warn about it. The end
+  // user can't do anything about this. E.g., PTX compiled for SM75 will
+  // run on SM80 too.
+  if (sm_version != compute_capability_version &&
+      compute_capability_version < supported_versions[0]) {
+    LOG(WARNING) << "Unknown compute capability "
+                 << compute_capability.ToString()
+                 << ". Defaulting to telling LLVM that we're compiling for sm_"
+                 << sm_version;
+  }
+  // On Hopper, default to sm_90a so that all instructions can be used. But
+  // only sm_90 is forward compatible, so don't use sm_90a with newer hardware:
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility
+  std::string_view extension =
+      (compute_capability.major == 9 && sm_version == 90) ? "a" : "";
+  return absl::StrCat("sm_", sm_version, extension);
+}
+
 std::string CantFindCudaMessage(absl::string_view msg,
                                 absl::string_view xla_gpu_cuda_data_dir) {
   return absl::StrCat(
@@ -855,7 +862,12 @@
     ir_fs->flush();
   }
   // Locate lld.
-  std::string lld_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin");
+  std::string lld_path;
+  if (std::getenv("LLVM_PATH")) {
+    lld_path = tsl::io::JoinPath(std::getenv("LLVM_PATH"), "bin");
+  } else {
+    lld_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin");
+  }
   auto lld_program = llvm::sys::findProgramByName("ld.lld", {lld_path});
   if (!lld_program) {
     return xla::Internal("unable to find ld.lld in PATH: %s",
@@ -1132,5 +1144,95 @@
 
 }  // namespace amdgpu
 
+namespace {
+
+std::unique_ptr<llvm::TargetMachine> SPIRGetTargetMachine(
+    llvm::Triple target_triple, se::GpuComputeCapability gpu_version,
+    const DebugOptions& debug_options) {
+  return nullptr;
+}
+
+absl::Status SPIRTargetModuleLinker(
+    llvm::Module* module, se::GpuComputeCapability gpu_version,
+    const DebugOptions& debug_options,
+    const std::string& device_bitcode_dir_path) {
+  return absl::OkStatus();
+}
+
+absl::StatusOr<std::string> EmitModuleToSpir(
+    llvm::Module* module, se::GpuComputeCapability gpu_version,
+    const DebugOptions& debug_options) {
+#if TENSORFLOW_USE_SYCL
+  SPIRV::TranslatorOpts::ExtensionsStatusMap ExtensionsStatus;
+  SPIRV::TranslatorOpts opts(SPIRV::VersionNumber::MaximumVersion,
+                             ExtensionsStatus);
+  opts.enableAllExtensions();  // enable all SPIR-V extensions first
+
+  std::ostringstream oss;
+  std::string err;
+  bool success = llvm::writeSpirv(module, opts, oss, err);
+  if (!success) {
+    return xla::Internal("Fails to convert LLVM as SPIR-V: %s", err);
+  }
+  return oss.str();
+#else
+  return absl::UnimplementedError("Not implemented for SYCL");
+#endif
+}
+
+void SPIRBackendInit(const DebugOptions& debug_options) {
+  FeedLLVMWithFlags({
+      "-slp-vectorize-hor=false",
+      "-slp-min-reg-size=64",
+      "-slp-max-reg-size=64",
+  });
+
+  llvm_ir::InitializeLLVMCommandLineOptions(
+      debug_options.xla_backend_extra_options());
+
+  llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry();
+  InitializePasses(registry);
+}
+
+}  // namespace
+
+namespace spir {
+
+absl::StatusOr<std::vector<uint8_t>> CompileToSpir(
+    llvm::Module* module, se::GpuComputeCapability gpu_version,
+    const DebugOptions& debug_options) {
+  std::string libdevice_dir_path;
+  static absl::once_flag backend_init_flag;
+  absl::call_once(backend_init_flag, SPIRBackendInit, debug_options);
+
+  std::string spir;
+  {
+    XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str());
+
+    // If the module has no functions or globals, there's nothing to compile.
+    if (module->empty() && module->global_empty()) {
+      VLOG(2) << "Module '" << module->getName().str()
+              << "' is empty. Skipping compilation.";
+      return std::vector<uint8_t>();
+    }
+
+    llvm::Triple default_target_triple("spir64-unknown-unknown");
+    std::unique_ptr<llvm::TargetMachine> target_machine =
+        SPIRGetTargetMachine(default_target_triple, gpu_version, debug_options);
+
+    TF_RETURN_IF_ERROR(LinkAndOptimizeModule(
+        module, gpu_version, debug_options, libdevice_dir_path,
+        SPIRTargetModuleLinker, default_target_triple, target_machine.get(),
+        kDefaultInlineThreshold));
+
+    // Lower optimized LLVM module to SPIR.
+    TF_ASSIGN_OR_RETURN(spir,
+                        EmitModuleToSpir(module, gpu_version, debug_options));
+  }
+  return std::vector<uint8_t>(spir.begin(), spir.end());
+}
+
+}  // namespace spir
+
 }  // namespace gpu
 }  // namespace xla
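Worked examples of the nvptx::GetSmName mapping relocated above, assuming the supported_versions table and the Hopper-only "a" suffix behave as written; `se` is the usual stream_executor namespace alias.

// se::CudaComputeCapability(8, 6)  -> "sm_86"   (exact match in the table)
// se::CudaComputeCapability(8, 9)  -> "sm_89"   (exact match in the table)
// se::CudaComputeCapability(9, 0)  -> "sm_90a"  (Hopper: enable all instructions)
// se::CudaComputeCapability(10, 0) -> "sm_90"   (unknown CC: fall back and keep
//                                                the PTX forward compatible)
std::string sm = xla::gpu::nvptx::GetSmName(se::CudaComputeCapability(9, 0));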
diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
index 3ab5d6d..1814291 100644
--- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
+++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
@@ -37,6 +37,11 @@
 
 namespace nvptx {
 
+// Gets the GPU name as it's known to LLVM for a given compute
+// capability.  If we see an unrecognized compute capability, we
+// return the highest one that is known and below the selected device.
+std::string GetSmName(se::CudaComputeCapability compute_capability);
+
 std::string CantFindCudaMessage(absl::string_view msg,
                                 absl::string_view xla_gpu_cuda_data_dir);
 
@@ -73,6 +78,13 @@
     const std::string& module_config_cache_key);
 }  // namespace amdgpu
 
+namespace spir {
+// Compiles the argument module to SPIR and returns the compiled binary.
+absl::StatusOr<std::vector<uint8_t>> CompileToSpir(
+    llvm::Module* module, se::GpuComputeCapability gpu_version,
+    const DebugOptions& debug_options);
+}  // namespace spir
+
 }  // namespace gpu
 }  // namespace xla
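A hedged usage sketch of the new spir::CompileToSpir declaration above; `llvm_module`, `gpu_version`, and `debug_options` are assumed to be provided by the caller inside a function returning absl::Status, and a real SPIR-V binary is only produced in builds configured with TENSORFLOW_USE_SYCL.

TF_ASSIGN_OR_RETURN(
    std::vector<uint8_t> spirv,
    xla::gpu::spir::CompileToSpir(llvm_module.get(), gpu_version,
                                  debug_options));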
 
diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc
new file mode 100644
index 0000000..9e65f34
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc
@@ -0,0 +1,38 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
+
+#include "xla/stream_executor/device_description.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+namespace se = ::stream_executor;
+
+TEST(UtilsTest, TestGetSmName) {
+  se::CudaComputeCapability cc_hopper(9, 0);
+  ASSERT_EQ(nvptx::GetSmName(cc_hopper), "sm_90a");
+  // Do not default to sm_90a after Hopper, because it is not forward
+  // compatible.
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility
+  se::CudaComputeCapability cc_next(10, 0);
+  ASSERT_EQ(nvptx::GetSmName(cc_next), "sm_90");
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc
index fe4982e..49270de 100644
--- a/third_party/xla/xla/service/gpu/matmul_utils.cc
+++ b/third_party/xla/xla/service/gpu/matmul_utils.cc
@@ -456,7 +456,11 @@
     const HloInstruction* gemm) {
   TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
                       gemm->backend_config<GpuBackendConfig>());
-  const GemmBackendConfig& config = gpu_config.gemm_backend_config();
+  return For(gemm, gpu_config.gemm_backend_config());
+}
+
+/*static*/ absl::StatusOr<GemmConfig> GemmConfig::For(
+    const HloInstruction* gemm, const GemmBackendConfig& config) {
   std::optional<int64_t> algorithm;
   if (config.algorithm_case() != GemmBackendConfig::ALGORITHM_NOT_SET) {
     algorithm = config.selected_algorithm();
diff --git a/third_party/xla/xla/service/gpu/matmul_utils.h b/third_party/xla/xla/service/gpu/matmul_utils.h
index 22d7f17..5f128e4 100644
--- a/third_party/xla/xla/service/gpu/matmul_utils.h
+++ b/third_party/xla/xla/service/gpu/matmul_utils.h
@@ -108,6 +108,11 @@
 
   static absl::StatusOr<GemmConfig> For(const HloInstruction* gemm);
 
+  // Gets the GemmConfig of the `gemm` instruction with overridden
+  // GemmBackendConfig.
+  static absl::StatusOr<GemmConfig> For(const HloInstruction* gemm,
+                                        const GemmBackendConfig& config);
+
   static absl::StatusOr<GemmConfig> For(
       const Shape& lhs_shape, absl::Span<const int64_t> lhs_batch_dims,
       absl::Span<const int64_t> lhs_contracting_dims, const Shape& rhs_shape,
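A hedged sketch of the new GemmConfig::For overload declared above: it lets a caller re-derive the config with an overridden GemmBackendConfig, for example to evaluate a candidate algorithm without mutating the instruction. The setup of `gemm` is assumed; the algorithm id is illustrative.

TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
                    gemm->backend_config<GpuBackendConfig>());
GemmBackendConfig overridden = gpu_config.gemm_backend_config();
overridden.set_selected_algorithm(2);  // illustrative algorithm id
TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm, overridden));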
diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD
index e4b5048..6173fa5 100644
--- a/third_party/xla/xla/service/gpu/model/BUILD
+++ b/third_party/xla/xla/service/gpu/model/BUILD
@@ -475,7 +475,6 @@
         "//xla/service:gather_simplifier",
         "//xla/service/gpu:hlo_traversal",
         "//xla/service/gpu:matmul_utils",
-        "//xla/service/gpu/fusions:tiling_util",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -549,7 +548,6 @@
         ":indexing_test_utils",
         "//xla/hlo/ir:hlo",
         "//xla/service/gpu:hlo_traversal",
-        "//xla/service/gpu/fusions:tiling_util",
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",
         "@com_google_absl//absl/strings:string_view",
diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc
index 11ebb82..6a18288 100644
--- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc
+++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc
@@ -52,7 +52,7 @@
 
   std::vector<bool> IsReadCoalescedPerOperand(const HloInstruction* root) {
     auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root);
-    auto analysis = AnalyzeFusion(*root, device_info_);
+    auto analysis = HloFusionAnalysis::Create(*root, device_info_);
     auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis});
     auto fusion = dynamic_cast<KernelFusionInterface*>(emitter.get());
     EXPECT_NE(fusion, nullptr);
@@ -71,7 +71,7 @@
   bool IsReadCoalescedHeuristic(absl::string_view hlo_string) {
     auto module = ParseAndReturnVerifiedModule(hlo_string).value();
     HloInstruction* root = module->entry_computation()->root_instruction();
-    auto analysis = AnalyzeFusion(*root, device_info_);
+    auto analysis = HloFusionAnalysis::Create(*root, device_info_);
     return xla::gpu::IsReadCoalescedHeuristic(analysis.GetEmitterFusionKind(),
                                               root->operand(0), root);
   }
diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc
index ba033fb..f4369df 100644
--- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc
+++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc
@@ -33,7 +33,8 @@
     }
   }
 
-  HloFusionAnalysis analysis = AnalyzeFusion(instruction, device_info_);
+  HloFusionAnalysis analysis =
+      HloFusionAnalysis::Create(instruction, device_info_);
   absl::MutexLock lock(&mutex_);
 
   // If some other thread created an entry for this key concurrently, return
@@ -59,7 +60,7 @@
   }
 
   HloFusionAnalysis analysis =
-      AnalyzeProducerConsumerFusion(producer, consumer, device_info_);
+      HloFusionAnalysis::Create(producer, consumer, device_info_);
   absl::MutexLock lock(&mutex_);
 
   // If some other thread created an entry for this key concurrently, return
diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
index 49b914e..f3417c9 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc
@@ -250,7 +250,7 @@
                                /*exec_time=*/absl::ZeroDuration()};
   }
 
-  auto fusion_analysis = AnalyzeFusion(*producer, *device_info_);
+  auto fusion_analysis = HloFusionAnalysis::Create(*producer, *device_info_);
 
   bool is_coalesced = IsReadCoalescedHeuristic(
       fusion_analysis.GetEmitterFusionKind(), producer);
@@ -261,7 +261,7 @@
 GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForProducerConsumer(
     const HloInstruction* producer, const HloInstruction* consumer) {
   auto fusion_analysis =
-      AnalyzeProducerConsumerFusion(*producer, *consumer, *device_info_);
+      HloFusionAnalysis::Create(*producer, *consumer, *device_info_);
 
   bool is_coalesced = IsReadCoalescedHeuristic(
       fusion_analysis.GetEmitterFusionKind(), producer, consumer);
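The renames in this file and the surrounding ones follow a single pattern: the free functions AnalyzeFusion and AnalyzeProducerConsumerFusion become HloFusionAnalysis::Create overloads. A minimal sketch, assuming `producer`, `consumer`, and `device_info` are in scope:

HloFusionAnalysis analysis = HloFusionAnalysis::Create(*producer, device_info);
HloFusionAnalysis fused_analysis =
    HloFusionAnalysis::Create(*producer, *consumer, device_info);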
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc
index c2057e0..6bb4071 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc
@@ -56,7 +56,7 @@
   // TODO(jreiffers): Remove this once all callers use a cache.
   std::optional<HloFusionAnalysis> local_analysis;
   if (!config.fusion_analysis_cache) {
-    local_analysis = AnalyzeFusion(*instr, device_info);
+    local_analysis = HloFusionAnalysis::Create(*instr, device_info);
   }
   const auto& fusion_analysis = config.fusion_analysis_cache
                                     ? config.fusion_analysis_cache->Get(*instr)
@@ -144,7 +144,7 @@
     // TODO(jreiffers): Remove this once all callers use a cache.
     std::optional<HloFusionAnalysis> local_analysis;
     if (!config.fusion_analysis_cache) {
-      local_analysis = AnalyzeFusion(*fused_consumer, device_info);
+      local_analysis = HloFusionAnalysis::Create(*fused_consumer, device_info);
     }
     const auto& analysis_unfused =
         config.fusion_analysis_cache
@@ -193,7 +193,7 @@
   std::optional<HloFusionAnalysis> local_analysis_fused;
   if (!config.fusion_analysis_cache) {
     local_analysis_fused =
-        AnalyzeProducerConsumerFusion(*producer, *consumer, device_info);
+        HloFusionAnalysis::Create(*producer, *consumer, device_info);
   }
   const auto& fusion_analysis =
       config.fusion_analysis_cache
@@ -296,8 +296,8 @@
 
     std::optional<HloFusionAnalysis> local_analysis_fused;
     if (!config.fusion_analysis_cache) {
-      local_analysis_fused = AnalyzeProducerConsumerFusion(
-          *producer, *fused_consumer, device_info);
+      local_analysis_fused =
+          HloFusionAnalysis::Create(*producer, *fused_consumer, device_info);
     }
     const auto& analysis_fused =
         config.fusion_analysis_cache
@@ -345,8 +345,9 @@
     const GpuPerformanceModelOptions& config,
     absl::Span<const HloInstruction* const> fused_consumers,
     bool multi_output) {
-  EstimateRunTimeData producer_runtime = EstimateRunTimeForInstructionCached(
-      producer, device_info, cost_analysis, config);
+  auto cache_result = config.gpu_performance_model_cache->Get(*producer);
+  CHECK(cache_result.has_value());
+  EstimateRunTimeData producer_runtime = *cache_result;
 
   absl::Duration time_unfused =
       kKernelLaunchOverhead * (fused_consumers.size() + 1) +
@@ -357,8 +358,10 @@
   for (auto fused_consumer : fused_consumers) {
     VLOG(8) << "Fused consumer: " << fused_consumer->name();
 
-    EstimateRunTimeData consumer_runtime = EstimateRunTimeForInstructionCached(
-        fused_consumer, device_info, cost_analysis, config);
+    auto cache_result =
+        config.gpu_performance_model_cache->Get(*fused_consumer);
+    CHECK(cache_result.has_value());
+    EstimateRunTimeData consumer_runtime = *cache_result;
 
     time_unfused += consumer_runtime.exec_time;
 
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc
index 56e34c3..08ae2e5 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc
@@ -88,8 +88,6 @@
 
 std::optional<EstimateRunTimeData> GpuPerformanceModelCache::Get(
     const HloInstruction& instruction) {
-  absl::MutexLock lock(&mutex_);
-
   auto it = instruction_runtime_data_.find(&instruction);
   if (it != instruction_runtime_data_.end()) {
     return it->second;
@@ -113,8 +111,6 @@
 
 void GpuPerformanceModelCache::Set(const HloInstruction& instruction,
                                    const EstimateRunTimeData& runtime_data) {
-  absl::MutexLock lock(&mutex_);
-
   instruction_runtime_data_[&instruction] = runtime_data;
 }
 
@@ -126,8 +122,6 @@
 }
 
 void GpuPerformanceModelCache::Invalidate(const HloInstruction& instruction) {
-  absl::MutexLock lock(&mutex_);
-
   // Remove runtime data for the instruction.
   instruction_runtime_data_.erase(&instruction);
 
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc
index a3eb170..b645800 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc
@@ -211,7 +211,7 @@
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           ParseAndReturnVerifiedModule(hlo_string));
 
-  auto fusion_analysis = AnalyzeFusion(
+  auto fusion_analysis = HloFusionAnalysis::Create(
       *module->entry_computation()->root_instruction(), device_info_);
   auto launch_dimensions =
       GpuPerformanceModelBase::EstimateFusionLaunchDimensions(fusion_analysis);
@@ -247,7 +247,7 @@
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           ParseAndReturnVerifiedModule(hlo_string));
 
-  auto fusion_analysis = AnalyzeFusion(
+  auto fusion_analysis = HloFusionAnalysis::Create(
       *module->entry_computation()->root_instruction(), device_info_);
   auto launch_dimensions =
       GpuPerformanceModelBase::EstimateFusionLaunchDimensions(fusion_analysis);
@@ -276,7 +276,7 @@
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           ParseAndReturnVerifiedModule(hlo_string));
 
-  auto fusion_analysis = AnalyzeFusion(
+  auto fusion_analysis = HloFusionAnalysis::Create(
       *module->entry_computation()->root_instruction(), device_info_);
   auto launch_dimensions =
       GpuPerformanceModelBase::EstimateFusionLaunchDimensions(fusion_analysis);
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc
index 4c0c35e..7c88789 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc
@@ -68,9 +68,19 @@
   GpuPerformanceModel::RunTimes EstimateRunTimesForPriorityFusion(
       const HloInstruction* producer,
       std::vector<HloInstruction*> fused_consumers = {}) {
+    auto config = GpuPerformanceModelOptions::PriorityFusion(
+        &fusion_analysis_cache_, &gpu_performance_model_cache_);
+
+    auto runtime_data = GpuPerformanceModel::EstimateRunTimeForInstruction(
+        producer, device_info_, &analysis_, config);
+    gpu_performance_model_cache_.Set(*producer, runtime_data);
+    for (auto consumer : fused_consumers) {
+      auto runtime_data = GpuPerformanceModel::EstimateRunTimeForInstruction(
+          consumer, device_info_, &analysis_, config);
+      gpu_performance_model_cache_.Set(*consumer, runtime_data);
+    }
     return GpuPerformanceModel::EstimateRunTimesForPriorityFusion(
-        producer, device_info_, &analysis_,
-        GpuPerformanceModelOptions::PriorityFusion(), fused_consumers);
+        producer, device_info_, &analysis_, config, fused_consumers);
   }
 
   mlir::MLIRContext mlir_context_;
@@ -82,6 +92,7 @@
   se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()};
   HloFusionAnalysisCache fusion_analysis_cache_{device_info_};
   GpuHloCostAnalysis analysis_{options_, device_info_};
+  GpuPerformanceModelCache gpu_performance_model_cache_;
 
   GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{
       &device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(),
@@ -674,16 +685,16 @@
 }
 
 fused_computation.0 {
-  p0 = f32[4,28672,32] parameter(0)
-  tanh = f32[4,28672,32] tanh(p0)
+  p0 = f32[4,256,32] parameter(0)
+  tanh = f32[4,256,32] tanh(p0)
   c1 = f32[] constant(72)
-  broadcast = f32[4,28672,32] broadcast(c1), dimensions={}
-  ROOT mul = f32[4,28672,32] multiply(tanh, broadcast)
+  broadcast = f32[4,256,32] broadcast(c1), dimensions={}
+  ROOT mul = f32[4,256,32] multiply(tanh, broadcast)
 }
 
 ENTRY fusion {
-  p0 = f32[4,28672,32] parameter(0)
-  fusion = f32[4,28672,32] fusion(p0), kind=kLoop, calls=fused_computation.0
+  p0 = f32[4,256,32] parameter(0)
+  fusion = f32[4,256,32] fusion(p0), kind=kLoop, calls=fused_computation.0
   c0 = f32[] constant(0)
   ROOT reduce = f32[4,32] reduce(fusion, c0), to_apply=add, dimensions={1}
 })";
diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc
index 8f81cb4..e8842b3 100644
--- a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc
+++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc
@@ -48,7 +48,6 @@
 #include "xla/layout.h"
 #include "xla/permutation_util.h"
 #include "xla/service/gather_simplifier.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/matmul_utils.h"
 #include "xla/service/gpu/model/affine_map_printer.h"
@@ -1142,13 +1141,6 @@
   return out;
 }
 
-AffineMap GetTilingAffineMap(llvm::ArrayRef<AffineExpr> exprs,
-                             int64_t num_symbols) {
-  return AffineMap::get(
-      /*dimCount=*/6, /*symbolCount=*/num_symbols, exprs,
-      exprs[0].getContext());
-}
-
 }  // namespace
 
 IndexingMap CreateIdentityMap(const Shape& shape, MLIRContext* mlir_context) {
@@ -1219,83 +1211,6 @@
       shape.dimensions(), {});
 }
 
-AffineMap GetBlockOffsetsForTiling(
-    absl::Span<const int64_t> num_blocks,
-    absl::Span<const int64_t> tile_sizes_per_block, int64_t rank,
-    MLIRContext* mlir_context) {
-  auto offsets =
-      DelinearizeInBoundsIndex(getAffineDimExpr(3, mlir_context), num_blocks);
-  for (auto&& [offset, tile_size] : llvm::zip(offsets, tile_sizes_per_block)) {
-    offset = offset * tile_size;
-  }
-  return GetTilingAffineMap(offsets, rank);
-}
-
-AffineMap GetBlockOffsetsForTiling(const Tiling& tiling,
-                                   MLIRContext* mlir_context) {
-  return GetBlockOffsetsForTiling(tiling.GetBlockCounts(),
-                                  tiling.GetBlockTileSize(),
-                                  tiling.GetShape().size(), mlir_context);
-}
-
-AffineMap GetThreadOffsetsForTiling(
-    absl::Span<const int64_t> num_threads,
-    absl::Span<const int64_t> tile_sizes_per_thread, int64_t rank,
-    MLIRContext* mlir_context) {
-  auto offsets =
-      DelinearizeInBoundsIndex(getAffineDimExpr(0, mlir_context), num_threads);
-  for (int dim = 0; dim < rank; ++dim) {
-    if (tile_sizes_per_thread[dim] > 1) {
-      offsets[dim] = offsets[dim] +
-                     getAffineSymbolExpr(dim, mlir_context) * num_threads[dim];
-    }
-  }
-  return GetTilingAffineMap(offsets, rank);
-}
-
-AffineMap GetThreadOffsetsForTiling(const Tiling& tiling,
-                                    MLIRContext* mlir_context) {
-  return GetThreadOffsetsForTiling(tiling.GetThreadsPerBlock(),
-                                   tiling.GetThreadTileSize(),
-                                   tiling.GetShape().size(), mlir_context);
-}
-
-IndexingMap GetIndexingMapForTiling(const Tiling& tiling,
-                                    MLIRContext* mlir_context) {
-  return GetIndexingMapForTiling(
-      GetBlockOffsetsForTiling(tiling, mlir_context),
-      GetThreadOffsetsForTiling(tiling, mlir_context),
-      tiling.GetNumThreadsPerBlock(), tiling.GetNumBlocks(),
-      tiling.GetThreadTileSize(), tiling.GetShape());
-}
-
-IndexingMap GetIndexingMapForTiling(AffineMap block_offsets,
-                                    AffineMap thread_offsets,
-                                    int64_t threads_per_block,
-                                    int64_t num_blocks,
-                                    absl::Span<const int64_t> thread_tile_sizes,
-                                    absl::Span<const int64_t> tiled_shape) {
-  auto* mlir_context = block_offsets.getContext();
-  llvm::SmallVector<AffineExpr, 4> offsets;
-  offsets.reserve(block_offsets.getNumResults());
-  for (auto [block, thread] :
-       llvm::zip(block_offsets.getResults(), thread_offsets.getResults())) {
-    offsets.push_back(block + thread);
-  }
-  std::vector<DimVar> dimension_ranges{
-      {{0, threads_per_block - 1}}, {}, {}, {{0, num_blocks - 1}}, {}, {},
-  };
-  auto affine_map = mlir::AffineMap::get(block_offsets.getNumDims(),
-                                         block_offsets.getNumSymbols(), offsets,
-                                         mlir_context);
-  IndexingMap map{affine_map, dimension_ranges,
-                  RangeVarsFromTensorSizes(thread_tile_sizes), /*rt_vars=*/{}};
-  for (int i = 0; i < tiled_shape.size(); ++i) {
-    map.AddConstraint(affine_map.getResult(i), {0, tiled_shape[i] - 1});
-  }
-  return map;
-}
-
 bool HloInstructionIndexing::Simplify() {
   bool any_simplified = false;
   for (auto& operand_indexing : indexing_maps) {
diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.h b/third_party/xla/xla/service/gpu/model/indexing_analysis.h
index 201b8f6..e475b5a 100644
--- a/third_party/xla/xla/service/gpu/model/indexing_analysis.h
+++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.h
@@ -18,7 +18,6 @@
 #define XLA_SERVICE_GPU_MODEL_INDEXING_ANALYSIS_H_
 
 #include <cstdint>
-#include <functional>
 #include <ostream>
 #include <string>
 #include <vector>
@@ -31,7 +30,6 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/MLIRContext.h"
 #include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/model/affine_map_printer.h"
 #include "xla/service/gpu/model/indexing_map.h"
@@ -145,35 +143,6 @@
 IndexingMap GetIndexingMapFromLogicalToPhysicalLayout(
     const Shape& shape, mlir::MLIRContext* mlir_context);
 
-// Creates an indexing map from thread and block IDs to elements of the tiled
-// shape. Uses the same convention as KernelFusionInterface: dimensions 0 to 2
-// are thread indices (currently only 0 is used), dimensions 3 to 5 are block
-// indices (currently only 3 is used).
-mlir::AffineMap GetBlockOffsetsForTiling(
-    absl::Span<const int64_t> num_blocks,
-    absl::Span<const int64_t> tile_sizes_per_block, int64_t rank,
-    mlir::MLIRContext* mlir_context);
-mlir::AffineMap GetBlockOffsetsForTiling(const Tiling& tiling,
-                                         mlir::MLIRContext* mlir_context);
-mlir::AffineMap GetThreadOffsetsForTiling(
-    absl::Span<const int64_t> num_threads,
-    absl::Span<const int64_t> tile_sizes_per_thread, int64_t rank,
-    mlir::MLIRContext* mlir_context);
-mlir::AffineMap GetThreadOffsetsForTiling(const Tiling& tiling,
-                                          mlir::MLIRContext* mlir_context);
-
-// Convenience functions for the two functions above
-// (`GetBlockOffsestsForTiling` + `GetThreadOffsetsForTiling`). Also sets up
-// the ranges of dimensions and symbols.
-IndexingMap GetIndexingMapForTiling(const Tiling& tiling,
-                                    mlir::MLIRContext* mlir_context);
-IndexingMap GetIndexingMapForTiling(mlir::AffineMap block_offsets,
-                                    mlir::AffineMap thread_offsets,
-                                    int64_t threads_per_block,
-                                    int64_t num_blocks,
-                                    absl::Span<const int64_t> thread_tile_sizes,
-                                    absl::Span<const int64_t> tiled_shape);
-
 // Returns the shape of the output of the instruction.
 const Shape& GetOutputShape(const HloInstruction* instr, int64_t output_id);
 
diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc
index 30fd805..d30c963 100644
--- a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc
+++ b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc
@@ -20,7 +20,6 @@
 #include "absl/strings/string_view.h"
 #include "mlir/IR/MLIRContext.h"
 #include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/gpu/fusions/tiling_util.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/model/indexing_test_utils.h"
 #include "xla/tests/hlo_test_base.h"
@@ -2564,32 +2563,6 @@
                   ElementsAre(UndefinedMap()), ElementsAre(UndefinedMap())));
 }
 
-TEST_F(IndexingAnalysisTest, TilingIndexing) {
-  Tiling tiling{/*shape=*/{1022, 256, 16},
-                /*tile_sizes=*/{8, 1, 4},
-                /*num_threads=*/{1, 4, 4}};
-  auto indexing_map = GetIndexingMapForTiling(tiling, &mlir_context_);
-  indexing_map.Simplify();
-  EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"(
-        (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
-          (d3 floordiv 64) * 8 + s0,
-          (d3 mod 64) * 4 + d0 floordiv 4,
-          d0 mod 4 + s2 * 4
-        )
-        domain:
-        d0 in [0, 15]
-        d1 in [0, 0]
-        d2 in [0, 0]
-        d3 in [0, 8191]
-        d4 in [0, 0]
-        d5 in [0, 0]
-        s0 in [0, 7]
-        s1 in [0, 0]
-        s2 in [0, 3]
-        (d3 floordiv 64) * 8 + s0 in [0, 1021]
-      )"));
-}
-
 TEST_F(IndexingAnalysisTest, EpilogueIndexing) {
   auto module = ParseAndReturnVerifiedModule(R"(
     HloModule m
diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc
index f0d34e8..da21c34 100644
--- a/third_party/xla/xla/service/gpu/model/indexing_map.cc
+++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc
@@ -1113,6 +1113,20 @@
   return eval.getConstantResults();
 }
 
+bool IndexingMap::IsSymbolConstrained(int64_t symbol_id) const {
+  for (const auto& [expr, _] : constraints_) {
+    bool result = false;
+    expr.walk([&](mlir::AffineExpr leaf) {
+      auto sym = mlir::dyn_cast<mlir::AffineSymbolExpr>(leaf);
+      if (sym && sym.getPosition() == symbol_id) {
+        result = true;
+      }
+    });
+    if (result) return true;
+  }
+  return false;
+}
+
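// A minimal usage sketch of the new IsSymbolConstrained predicate (an
// illustration assuming the XLA headers and xla::gpu namespace shown in this
// file, not part of this diff):
//
//   #include "xla/service/gpu/model/indexing_map.h"
//
//   // True iff any constraint registered on `map` mentions symbol s0.
//   bool SymbolZeroIsConstrained(const xla::gpu::IndexingMap& map) {
//     return map.IsSymbolConstrained(/*symbol_id=*/0);
//   }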
 RangeEvaluator::RangeEvaluator(const IndexingMap& indexing_map,
                                MLIRContext* mlir_context, bool use_constraints)
     : mlir_context_(mlir_context),
diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h
index 478e0ec..2e6cc13 100644
--- a/third_party/xla/xla/service/gpu/model/indexing_map.h
+++ b/third_party/xla/xla/service/gpu/model/indexing_map.h
@@ -369,6 +369,9 @@
       llvm::ArrayRef<mlir::AffineExpr> dim_const_exprs,
       llvm::ArrayRef<mlir::AffineExpr> symbol_const_exprs) const;
 
+  // Returns true if there is a constraint on the given symbol.
+  bool IsSymbolConstrained(int64_t symbol_id) const;
+
   // Returns true if the domain is empty. If it returns false, that does not
   // mean that the domain is not effectively empty.
   // For example, if there are two constraints 0 <= d0 mod 7 <= 0 and
diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc
index dc5c2f2..1ea107c 100644
--- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc
+++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc
@@ -602,23 +602,8 @@
     AffineExpr strided_indexing, absl::Span<Interval const> dimension_intervals,
     absl::Span<Interval const> symbol_intervals) {
   MLIRContext* ctx = strided_indexing.getContext();
-  // Deal with the symbol case (capturing a whole untiled dimension).
-  // TODO(b/330906085): concatenating across a reduction dimension needs to be
-  // handled by this code.
-  if (auto symbol = llvm::dyn_cast<AffineSymbolExpr>(strided_indexing)) {
-    const Interval& symbol_interval = symbol_intervals[symbol.getPosition()];
-    if (symbol_interval.lower != 0) {
-      return std::nullopt;
-    }
-
-    return SizeAndStrideExpression(
-        /*size=*/getAffineConstantExpr(symbol_interval.upper + 1, ctx),
-        /*stride=*/getAffineConstantExpr(1, ctx));
-  }
-
   AffineMapPrinter printer;
 
-  // TODO(b/328427138): support multivariate size expressions.
   switch (strided_indexing.getKind()) {
     case AffineExprKind::DimId:
       return SizeAndStrideExpression(/*size=*/strided_indexing,
@@ -626,23 +611,15 @@
     case mlir::AffineExprKind::Mul: {
       const auto mul = llvm::cast<mlir::AffineBinaryOpExpr>(strided_indexing);
       AffineExpr lhs = mul.getLHS();
-      // The stride may not be fully collapsed if it is negative; in that case,
-      // we need to extract the negative multiplier first.
-      if (const auto rhs = llvm::dyn_cast<AffineConstantExpr>(mul.getRHS());
-          rhs && rhs.getValue() == -1) {
-        std::optional<SizeAndStrideExpression> maybe_size_and_stride =
-            ExtractSizeAndStride(lhs, dimension_intervals, symbol_intervals);
-        if (!maybe_size_and_stride.has_value()) {
-          return std::nullopt;
-        }
-
-        return SizeAndStrideExpression(
-            /*size=*/maybe_size_and_stride->size,
-            /*stride=*/maybe_size_and_stride->stride * rhs);
+      std::optional<SizeAndStrideExpression> maybe_size_and_stride =
+          ExtractSizeAndStride(lhs, dimension_intervals, symbol_intervals);
+      if (!maybe_size_and_stride.has_value()) {
+        return std::nullopt;
       }
-      CHECK(lhs.getKind() == AffineExprKind::DimId);
-      return SizeAndStrideExpression(/*size=*/lhs,
-                                     /*stride=*/mul.getRHS());
+
+      return SizeAndStrideExpression(
+          /*size=*/maybe_size_and_stride->size,
+          /*stride=*/maybe_size_and_stride->stride * mul.getRHS());
     }
     case mlir::AffineExprKind::Mod: {
       auto mod = llvm::cast<mlir::AffineBinaryOpExpr>(strided_indexing);
@@ -656,15 +633,18 @@
     case mlir::AffineExprKind::Constant:
       return SizeAndStrideExpression(/*size=*/getAffineConstantExpr(1, ctx),
                                      /*stride=*/getAffineConstantExpr(0, ctx));
-    case mlir::AffineExprKind::SymbolId:
-      VLOG(1) << "Encountered complex size expression involving symbol "
-              << printer.ToString(strided_indexing);
-      // It's currently not checked separately, but RTVars shouldn't appear in
-      // the strided indexing expressions.
-      return std::nullopt;
+    case mlir::AffineExprKind::SymbolId: {
+      auto symbol = llvm::cast<AffineSymbolExpr>(strided_indexing);
+      const Interval& symbol_interval = symbol_intervals[symbol.getPosition()];
+      if (symbol_interval.lower != 0) {
+        return std::nullopt;
+      }
+
+      return SizeAndStrideExpression(
+          /*size=*/getAffineConstantExpr(symbol_interval.upper + 1, ctx),
+          /*stride=*/getAffineConstantExpr(1, ctx));
+    }
     case mlir::AffineExprKind::Add: {
-      // TODO(b/328427138): this should only be necessary in the multivariate
-      // case, and will be implemented later.
       std::optional<std::vector<SizeAndStrideExpression>>
           maybe_sizes_and_strides =
               ExtractSizesAndStridesFromMultivariateSummation(
diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc
index 1db5537..92c851d 100644
--- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc
+++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc
@@ -549,6 +549,61 @@
       )")));
 }
 
+TEST_F(SymbolicTileTest, CanPropagateTileThroughSplitReductionOfSplittedAxis) {
+  // A split reshape followed by reductions creates a sum of strided symbols.
+  auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"(
+    HloModule m
+    add {
+      p0 = f32[] parameter(0)
+      p1 = f32[] parameter(1)
+      ROOT add = f32[] add(p0, p1)
+    }
+
+    computation {
+      p0 = f32[18] parameter(0)
+      bitcast = f32[9,2] bitcast(p0)
+      c0 = f32[] constant(0)
+      reduce_0 = f32[9] reduce(bitcast, c0), dimensions={1}, to_apply=add
+      ROOT reduce_1 = f32[] reduce(reduce_0, c0), dimensions={0}, to_apply=add
+    }
+
+    ENTRY e {
+      p0 = f32[18] parameter(0)
+      ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=computation
+    }
+  )"));
+
+  EXPECT_THAT(
+      SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()),
+      Optional(MatchSymbolicTileString(R"(
+      Symbolic tile with
+        offset_map: () -> (0)
+        size_map: () -> (18)
+        stride_map: () -> (1)
+      )")));
+}
+
+TEST_F(SymbolicTileTest, CanPropagateTileThroughSummationOfSymbols) {
+  // Such an indexing map is representative of a sequence of HLOs containing a
+  // bitcast followed by two sequential reductions of the split axis, i.e.
+  // something like
+  //   p0 = f32[18] parameter(0)
+  //   bitcast = f32[9,2] bitcast(p0)
+  //   reduce_0 = f32[9] reduce(bitcast), dimensions={1}
+  //   reduce_1 = f32[] reduce(reduce_0), dimensions={0}
+  IndexingMap indexing_map = IndexingMap::FromTensorSizes(
+      ParseAffineMap("()[s0, s1] -> (s1 * 2 + s0)", &mlir_context_), {},
+      {2, 9});
+
+  EXPECT_THAT(SymbolicTile::FromIndexingMap(indexing_map),
+              Optional(MatchSymbolicTileString(R"(
+      Symbolic tile with
+        offset_map: () -> (0)
+        size_map: () -> (18)
+        stride_map: () -> (1)
+      )")));
+}
+
 TEST_F(SymbolicTileTest,
        FailsGracefullyAtPropagatingTileThroughSliceOfSplitReshape) {
   // TODO(b/349487906): constraints should allow us to unblock this use case.
diff --git a/third_party/xla/xla/service/gpu/move_copy_to_users.cc b/third_party/xla/xla/service/gpu/move_copy_to_users.cc
deleted file mode 100644
index acc10db..0000000
--- a/third_party/xla/xla/service/gpu/move_copy_to_users.cc
+++ /dev/null
@@ -1,240 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/move_copy_to_users.h"
-
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor {
-  // Turn copy->pad into pad->copy
-  absl::Status HandlePad(HloInstruction* hlo) override {
-    HloInstruction* operand = hlo->mutable_operand(0);
-    HloInstruction* c = hlo->mutable_operand(1);
-    if (operand->opcode() == HloOpcode::kCopy) {
-      HloInstruction* copied = operand->mutable_operand(0);
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * earlier_pad,
-          MakePadHlo(copied, c, hlo->padding_config(), &hlo->metadata()));
-      // MakePadHlo fails to propagate layout.
-      *earlier_pad->mutable_shape()->mutable_layout() =
-          copied->shape().layout();
-      HloInstruction* later_copy = MakeCopyHlo(earlier_pad, hlo->shape());
-      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
-    }
-    return absl::OkStatus();
-  }
-
-  // Turn copy->slice into slice->copy, as slice is layout-preserving.
-  absl::Status HandleSlice(HloInstruction* hlo) override {
-    HloInstruction* operand = hlo->mutable_operand(0);
-    if (operand->opcode() == HloOpcode::kCopy) {
-      HloInstruction* copied = operand->mutable_operand(0);
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * earlier_slice,
-          MakeSliceHlo(copied, hlo->slice_starts(), hlo->slice_limits(),
-                       hlo->slice_strides(), &hlo->metadata()));
-      *earlier_slice->mutable_shape()->mutable_layout() =
-          copied->shape().layout();
-      HloInstruction* later_copy = MakeCopyHlo(earlier_slice, hlo->shape());
-      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
-    }
-    return absl::OkStatus();
-  }
-
-  // Turn copy->dynamic-slice into dynamic-slice->copy, as dynamic-slice is
-  // layout-preserving.
-  absl::Status HandleDynamicSlice(HloInstruction* hlo) override {
-    HloInstruction* operand = hlo->mutable_operand(0);
-    if (operand->opcode() == HloOpcode::kCopy) {
-      HloInstruction* copied = operand->mutable_operand(0);
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * earlier_slice,
-          MakeDynamicSliceHlo(
-              copied,
-              absl::Span<HloInstruction* const>(hlo->operands()).subspan(1),
-              hlo->dynamic_slice_sizes(), &hlo->metadata()));
-      *earlier_slice->mutable_shape()->mutable_layout() =
-          copied->shape().layout();
-      HloInstruction* later_copy = MakeCopyHlo(earlier_slice, hlo->shape());
-      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
-    }
-    return absl::OkStatus();
-  }
-
-  // Turn copy->reduce_window into reduce_window->copy, as reduce_window is
-  // layout-preserving.
-  absl::Status HandleReduceWindow(HloInstruction* hlo) override {
-    HloInstruction* operand = hlo->mutable_operand(0);
-    if (operand->opcode() == HloOpcode::kCopy) {
-      HloInstruction* copied = operand->mutable_operand(0);
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * earlier_reduce_window,
-          MakeReduceWindowHlo(copied, hlo->mutable_operand(1), hlo->window(),
-                              hlo->called_computations()[0], &hlo->metadata()));
-      *earlier_reduce_window->mutable_shape()->mutable_layout() =
-          copied->shape().layout();
-      HloInstruction* later_copy =
-          MakeCopyHlo(earlier_reduce_window, hlo->shape());
-      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
-    }
-    return absl::OkStatus();
-  }
-
-  absl::Status HandleReduce(HloInstruction* hlo) override {
-    HloInstruction* operand = hlo->mutable_operand(0);
-    // Reductions can handle transposes, e.g. via column reduction.
-    if (operand->opcode() == HloOpcode::kCopy && !hlo->shape().IsTuple()) {
-      HloInstruction* new_reduce = hlo->AddInstruction(
-          hlo->CloneWithNewOperands(hlo->shape(), {operand->mutable_operand(0),
-                                                   hlo->mutable_operand(1)}));
-      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, new_reduce));
-    }
-    return absl::OkStatus();
-  }
-
-  absl::Status HandleBitcastConvert(HloInstruction* hlo) override {
-    return absl::OkStatus();
-  }
-
-  // Sink kCopy across elementwise unary.
-  absl::Status HandleElementwiseUnary(HloInstruction* hlo) override {
-    HloInstruction* operand = hlo->mutable_operand(0);
-    if (hlo->opcode() == HloOpcode::kReducePrecision) {
-      return absl::OkStatus();
-    }
-    if (operand->opcode() == HloOpcode::kCopy) {
-      HloInstruction* copied = operand->mutable_operand(0);
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * earlier_elementwise,
-          MakeUnaryHlo(hlo->opcode(), copied, &hlo->metadata()));
-      HloInstruction* later_copy =
-          MakeCopyHlo(earlier_elementwise, hlo->shape());
-      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
-    }
-    return absl::OkStatus();
-  }
-
-  // Sink kCopy across reverse
-  absl::Status HandleReverse(HloInstruction* hlo) override {
-    HloInstruction* operand = hlo->mutable_operand(0);
-    if (operand->opcode() == HloOpcode::kCopy) {
-      HloInstruction* copied = operand->mutable_operand(0);
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * earlier_reverse,
-          MakeReverseHlo(copied, hlo->dimensions(), &hlo->metadata()));
-      HloInstruction* later_copy = MakeCopyHlo(earlier_reverse, hlo->shape());
-      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
-    }
-    return absl::OkStatus();
-  }
-
-  // Sink kCopy across convert.
-  absl::Status HandleConvert(HloInstruction* hlo) override {
-    HloInstruction* operand = hlo->mutable_operand(0);
-    if (operand->opcode() == HloOpcode::kCopy) {
-      HloInstruction* copied = operand->mutable_operand(0);
-      HloInstruction* earlier_convert = MakeConvertToHlo(
-          copied, hlo->shape().element_type(), &hlo->metadata());
-      HloInstruction* later_copy = MakeCopyHlo(earlier_convert, hlo->shape());
-      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
-    }
-    return absl::OkStatus();
-  }
-
-  // Sink kCopy across elementwise binary.
-  absl::Status HandleElementwiseBinary(HloInstruction* hlo) override {
-    HloInstruction* a = hlo->mutable_operand(0);
-    HloInstruction* b = hlo->mutable_operand(1);
-    if (a->opcode() == HloOpcode::kCopy && b->opcode() == HloOpcode::kCopy) {
-      HloInstruction* copied_a = a->mutable_operand(0);
-      HloInstruction* copied_b = b->mutable_operand(0);
-      if (copied_a->shape() == copied_b->shape()) {
-        HloInstruction* earlier_elementwise;
-        if (hlo->opcode() == HloOpcode::kCompare) {
-          TF_ASSIGN_OR_RETURN(
-              earlier_elementwise,
-              MakeCompareHlo(hlo->comparison_direction(), copied_a, copied_b,
-                             &hlo->metadata()));
-        } else {
-          TF_ASSIGN_OR_RETURN(earlier_elementwise,
-                              MakeBinaryHlo(hlo->opcode(), copied_a, copied_b,
-                                            &hlo->metadata()));
-        }
-        HloInstruction* later_copy =
-            MakeCopyHlo(earlier_elementwise, hlo->shape());
-        TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
-      }
-    }
-    return absl::OkStatus();
-  }
-
-  // Move copy across kConcat if it occurs on all operands.
-  absl::Status HandleConcatenate(HloInstruction* hlo) override {
-    const HloInstruction* first = hlo->operand(0);
-    if (first->opcode() != HloOpcode::kCopy) {
-      return absl::OkStatus();
-    }
-    const HloInstruction* inner_op = first->operand(0);
-    const Layout& inner_op_layout = inner_op->shape().layout();
-
-    std::vector<HloInstruction*> new_operands;
-    new_operands.reserve(hlo->operand_count());
-    for (HloInstruction* op : hlo->mutable_operands()) {
-      if (op->opcode() != HloOpcode::kCopy ||
-          op->operand(0)->shape().layout() != inner_op_layout) {
-        VLOG(3) << "Mismatch between " << op->ToString()
-                << " and expected op layout " << inner_op_layout.ToString();
-        return absl::OkStatus();
-      }
-      new_operands.push_back(op->mutable_operand(0));
-    }
-
-    TF_ASSIGN_OR_RETURN(
-        HloInstruction * new_concat,
-        MakeConcatHlo(new_operands, hlo->concatenate_dimension()));
-    *new_concat->mutable_shape()->mutable_layout() = inner_op_layout;
-
-    HloInstruction* new_copy = MakeCopyHlo(new_concat, hlo->shape());
-    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, new_copy));
-    return absl::OkStatus();
-  }
-};
-
-}  // end namespace
-
-absl::StatusOr<bool> MoveCopyToUsers::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  return MoveCopyToUsersVisitor{}.RunOnModule(module, execution_threads);
-}
-
-}  // end namespace xla
diff --git a/third_party/xla/xla/service/gpu/move_copy_to_users.h b/third_party/xla/xla/service/gpu/move_copy_to_users.h
deleted file mode 100644
index 4a7dfb4..0000000
--- a/third_party/xla/xla/service/gpu/move_copy_to_users.h
+++ /dev/null
@@ -1,39 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_
-#define XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// Sink kCopy operations as far down the graph as possible.
-class MoveCopyToUsers : public HloModulePass {
- public:
-  absl::string_view name() const override { return "move_copy_to_users"; }
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // end namespace xla
-
-#endif  // XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_
diff --git a/third_party/xla/xla/service/gpu/move_copy_to_users_test.cc b/third_party/xla/xla/service/gpu/move_copy_to_users_test.cc
deleted file mode 100644
index 10179c1..0000000
--- a/third_party/xla/xla/service/gpu/move_copy_to_users_test.cc
+++ /dev/null
@@ -1,274 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/move_copy_to_users.h"
-
-#include <optional>
-
-#include "absl/strings/string_view.h"
-#include "xla/service/layout_assignment.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace {
-
-class MoveCopyToUsersTest : public HloTestBase {
- public:
-  MoveCopyToUsersTest()
-      : HloTestBase(/*verifier_layout_sensitive=*/true,
-                    /*allow_mixed_precision_in_hlo_verifier=*/true,
-                    LayoutAssignment::InstructionCanChangeLayout) {}
-  void CheckMoveCopyToUsers(absl::string_view hlo,
-                            std::optional<absl::string_view> expected) {
-    RunAndFilecheckHloRewrite(hlo, MoveCopyToUsers{}, expected);
-  }
-};
-
-TEST_F(MoveCopyToUsersTest, Pad) {
-  const char* hlo = R"(
-HloModule module
-
-ENTRY main {
-  input = s8[1,17,9,9]{3,1,2,0} parameter(0)
-  copy = s8[1,17,9,9]{1,3,2,0} copy(input)
-  constant = s8[] constant(0)
-  ROOT pad = s8[1,32,9,9]{1,3,2,0} pad(copy, constant), padding=0_0x0_15x0_0x0_0
-}
-)";
-
-  CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[constant_0:%[^ ]+]] = s8[] constant(0)
-// CHECK: [[pad_1_1:%[^ ]+]] = s8[1,32,9,9]{3,1,2,0} pad([[input_2:%[^ ]+]], [[constant_0]]), padding=0_0x0_15x0_0x0_0
-// CHECK: ROOT [[copy_1_3:%[^ ]+]] = s8[1,32,9,9]{1,3,2,0} copy([[pad_1_1]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, Unary) {
-  const char* hlo = R"(
-HloModule module
-
-ENTRY main {
-  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
-  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
-  ROOT pad = f32[1,17,9,9]{1,3,2,0} sqrt(copy)
-}
-)";
-
-  CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
-// CHECK: [[sqrt_1:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} sqrt([[input_0]])
-// CHECK: ROOT [[copy_1_2:%[^ ]+]] = f32[1,17,9,9]{1,3,2,0} copy([[sqrt_1]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, Reverse) {
-  const char* hlo = R"(
-HloModule module
-
-ENTRY main {
-  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
-  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
-  ROOT pad = f32[1,17,9,9]{1,3,2,0} reverse(copy), dimensions={1,2}
-}
-)";
-
-  CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
-// CHECK: [[sqrt_1:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} reverse([[input_0]]), dimensions={1,2}
-// CHECK: ROOT [[copy_1_2:%[^ ]+]] = f32[1,17,9,9]{1,3,2,0} copy([[sqrt_1]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, Convert) {
-  const char* hlo = R"(
-HloModule module
-
-ENTRY main {
-  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
-  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
-  ROOT converted = f16[1,17,9,9]{1,3,2,0} convert(copy)
-}
-)";
-
-  CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
-// CHECK: [[sqrt_1:%[^ ]+]] = f16[1,17,9,9]{3,2,1,0} convert([[input_0]])
-// CHECK: ROOT [[copy_1_2:%[^ ]+]] = f16[1,17,9,9]{1,3,2,0} copy([[sqrt_1]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, Slice) {
-  const char* hlo = R"(
-HloModule module
-
-ENTRY main {
-  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
-  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
-  ROOT slice = f32[1,4,6,6]{1,3,2,0} slice(copy), slice={[0:1],[0:4],[0:6],[0:6]}
-}
-)";
-
-  CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[slice_0:%[^ ]+]] = f32[1,4,6,6]{3,2,1,0} slice([[input_1:%[^ ]+]]), slice={[0:1], [0:4], [0:6], [0:6]}
-// CHECK-NEXT: ROOT [[copy_1_2:%[^ ]+]] = f32[1,4,6,6]{1,3,2,0} copy([[slice_0]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, DynamicSlice) {
-  const char* hlo = R"(
-HloModule module
-
-ENTRY main {
-  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
-  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
-  p0 = s32[] parameter(1)
-  p1 = s32[] parameter(2)
-  p2 = s32[] parameter(3)
-  p3 = s32[] parameter(4)
-  ROOT ds = f32[1,4,6,6]{1,3,2,0} dynamic-slice(copy, p0, p1, p2, p3), dynamic_slice_sizes={1,4,6,6}
-}
-)";
-
-  CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[ds:%[^ ]+]] = f32[1,4,6,6]{3,2,1,0} dynamic-slice({{.*}}), dynamic_slice_sizes={1,4,6,6}
-// CHECK-NEXT: ROOT {{.*}} = f32[1,4,6,6]{1,3,2,0} copy([[ds]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, ReduceWindow) {
-  const char* hlo = R"(
-HloModule R2Window
-
-mul {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT mul = f32[] multiply(lhs, rhs)
-}
-
-ENTRY R2Window {
-  operand = f32[256,384]{1,0} parameter(0)
-  c = f32[256,384]{0,1} copy(operand)
-  constant = f32[] constant(1)
-  ROOT reduce-window = f32[256,384]{0,1} reduce-window(c, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
-}
-)";
-
-  CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[reduce_window_1_0:%[^ ]+]] = f32[256,384]{1,0} reduce-window([[operand_1:%[^ ]+]], [[constant_2:%[^ ]+]]), window={size=2x3 pad=0_1x1_1}, to_apply=[[mul_3:%[^ ]+]]
-// CHECK-NEXT: ROOT [[copy_4:%[^ ]+]] = f32[256,384]{0,1} copy([[reduce_window_1_0]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, Reduce) {
-  const char* hlo = R"(
-HloModule R2
-
-mul {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT mul = f32[] multiply(lhs, rhs)
-}
-
-ENTRY R2 {
-  operand = f32[256,384,10]{2,1,0} parameter(0)
-  c = f32[256,384,10]{0,1,2} copy(operand)
-  constant = f32[] constant(1)
-  ROOT reduce = f32[384,10]{0,1} reduce(c, constant), dimensions={0}, to_apply=mul
-}
-)";
-
-  CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[operand:%[^ ]+]] = f32[256,384,10]{2,1,0} parameter(0)
-// CHECK: ROOT [[reduce:%[^ ]+]] = f32[384,10]{0,1} reduce([[operand]], [[constant_2:%[^ ]+]]), dimensions={0}, to_apply=[[mul_3:%[^ ]+]]
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, Binary) {
-  const char* hlo = R"(
-HloModule module
-
-ENTRY main {
-  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
-  input2 = f32[1,17,9,9]{3,2,1,0} parameter(1)
-  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
-  copy2 = f32[1,17,9,9]{1,3,2,0} copy(input2)
-  ROOT add = f32[1,17,9,9]{1,3,2,0} add(copy, copy2)
-}
-)";
-
-  CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
-// CHECK: [[input2_1:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(1)
-// CHECK: [[add_1_2:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} add([[input_0]], [[input2_1]])
-// CHECK: ROOT [[copy_1_3:%[^ ]+]] = f32[1,17,9,9]{1,3,2,0} copy([[add_1_2]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, BinaryDifferentLayoutNoChange) {
-  const char* hlo = R"(
-HloModule module
-
-ENTRY main {
-  input = f32[1,17,9,9]{3,2,0,1} parameter(0)
-  input2 = f32[1,17,9,9]{3,2,1,0} parameter(1)
-  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
-  copy2 = f32[1,17,9,9]{1,3,2,0} copy(input2)
-  ROOT add = f32[1,17,9,9]{1,3,2,0} add(copy, copy2)
-}
-)";
-
-  CheckMoveCopyToUsers(hlo, std::nullopt);
-}
-
-TEST_F(MoveCopyToUsersTest, Concat) {
-  const char* hlo = R"(
-HloModule module
-
-ENTRY main {
-  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
-  input2 = f32[5,17,9,9]{3,2,1,0} parameter(1)
-  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
-  copy2 = f32[5,17,9,9]{1,3,2,0} copy(input2)
-  ROOT add = f32[6,17,9,9]{1,3,2,0} concatenate(copy, copy2), dimensions={0}
-}
-)";
-
-  CheckMoveCopyToUsers(hlo, R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
-// CHECK: [[input2_1:%[^ ]+]] = f32[5,17,9,9]{3,2,1,0} parameter(1)
-// CHECK: [[concat:%[^ ]+]] = f32[6,17,9,9]{3,2,1,0} concatenate([[input_0]], [[input2_1]])
-// CHECK: ROOT [[copy_1_3:%[^ ]+]] = f32[6,17,9,9]{1,3,2,0} copy([[concat]])
-)");
-}
-
-TEST_F(MoveCopyToUsersTest, ConcatDifferentLayoutNoChange) {
-  const char* hlo = R"(
-HloModule module
-
-ENTRY main {
-  input = f32[1,17,9,9]{3,2,0,1} parameter(0)
-  input2 = f32[1,17,9,9]{3,2,1,0} parameter(1)
-  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
-  copy2 = f32[1,17,9,9]{1,3,2,0} copy(input2)
-  ROOT add = f32[2,17,9,9]{1,3,2,0} concatenate(copy, copy2), dimensions={0}
-}
-)";
-
-  CheckMoveCopyToUsers(hlo, std::nullopt);
-}
-
-}  // namespace
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/multi_output_fusion.cc
deleted file mode 100644
index 6ac1217..0000000
--- a/third_party/xla/xla/service/gpu/multi_output_fusion.cc
+++ /dev/null
@@ -1,521 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/multi_output_fusion.h"
-
-#include <algorithm>
-#include <cstddef>
-#include <cstdint>
-#include <iterator>
-#include <memory>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "xla/debug_options_flags.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_dfs_reachability.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/gpu/model/gpu_performance_model.h"
-#include "xla/service/gpu/model/gpu_performance_model_base.h"
-#include "xla/service/hlo_graph_dumper.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-bool IsProfitableOperand(HloInstruction* instr) {
-  // Effective scalars are not a profitable shared operand. Skip them.
-  return !ShapeUtil::IsEffectiveScalar(instr->shape());
-}
-
-// Finds and returns the unique `slice` op where `parent` is used in `instr`.
-// Returns `nullptr` if no such `slice` exists.
-const HloSliceInstruction* FindUniqueSlice(const HloInstruction* parent,
-                                           const HloInstruction* instr) {
-  if (const auto* slice = DynCast<HloSliceInstruction>(instr)) {
-    return slice;
-  } else if (const auto* fusion = DynCast<HloFusionInstruction>(instr)) {
-    const HloSliceInstruction* result = nullptr;
-    for (size_t i = 0; i < fusion->operand_count(); ++i) {
-      if (fusion->operand(i) == parent) {
-        // Parameter used more than once -> there's no unique slice.
-        if (result) return nullptr;
-
-        auto* called_param = fusion->fused_parameter(i);
-        if (called_param->user_count() != 1) return nullptr;
-
-        result = FindUniqueSlice(called_param, called_param->users()[0]);
-        if (!result) return nullptr;
-      }
-    }
-    return result;
-  } else {
-    return nullptr;
-  }
-}
-
-FusionDecision ParameterSlicesAreNonOverlapping(const HloInstruction& instr1,
-                                                const HloInstruction& instr2,
-                                                const HloInstruction* parent) {
-  if (parent->shape().IsTuple()) return {};
-  // Allow MOF if the parameter is small, even if there's no overlap. 1024 bytes
-  // were arbitrarily chosen as the threshold.
-  if (ShapeUtil::ByteSizeOfElements(parent->shape()) < 1024) return {};
-
-  const HloSliceInstruction* slice1 = FindUniqueSlice(parent, &instr1);
-  const HloSliceInstruction* slice2 = FindUniqueSlice(parent, &instr2);
-  if (!slice1 || !slice2) return {};
-
-  // TODO(jreiffers): Check strides as well.
-  auto& starts1 = slice1->slice_starts();
-  auto& starts2 = slice2->slice_starts();
-  auto& limits1 = slice1->slice_limits();
-  auto& limits2 = slice2->slice_limits();
-
-  for (int64_t dim = 0; dim < parent->shape().rank(); ++dim) {
-    bool overlap = starts1[dim] < limits2[dim] && starts2[dim] < limits1[dim];
-    if (!overlap) {
-      return "slices are non-overlapping";
-    }
-  }
-  return {};
-}
-
-FusionDecision LegalToFuse(const HloInstruction& instr1,
-                           const HloInstruction& instr2,
-                           const se::DeviceDescription& device_info,
-                           FusionInfoCache* fusion_info_cache) {
-  CHECK(instr1.opcode() == HloOpcode::kFusion);
-
-  // The emitter only supports in-place DUS for fusions with a single DUS at the
-  // root. Don't sibling fuse DUS for now.
-  // TODO(b/119178699): Multi-output fusing DUS can improve performance if we
-  // share the input and output buffers and add support to the emitter.
-  if (instr1.fused_expression_root()->opcode() ==
-          HloOpcode::kDynamicUpdateSlice ||
-      (instr2.opcode() == HloOpcode::kFusion &&
-       instr2.fused_expression_root()->opcode() ==
-           HloOpcode::kDynamicUpdateSlice)) {
-    return "can't fuse multiple DUSs";
-  }
-
-  // Do this check last, as it may be expensive.
-  return FusionFitsInBudget(instr1, instr2, device_info,
-                            /*is_consumer_producer_fusion=*/false,
-                            fusion_info_cache);
-}
-
-// We prefer multi-output fusions over other fusions over unfused ops, because
-// we want to preserve fusion opportunities if possible.
-int FusionPriority(const HloInstruction* instr) {
-  if (instr->IsMultiOutputFusion()) {
-    return 2;
-  }
-  if (instr->opcode() == HloOpcode::kFusion) {
-    return 1;
-  }
-  return 0;
-}
-
-HloInstruction* SelectPreferredFusionCandidate(
-    const std::vector<HloInstruction*> candidates) {
-  if (candidates.empty()) {
-    return nullptr;
-  }
-  return *std::max_element(
-      candidates.begin(), candidates.end(),
-      [](const HloInstruction* a, const HloInstruction* b) {
-        return FusionPriority(a) < FusionPriority(b);
-      });
-}
-
-// Do not fuse a producer if the other operands of the fusion are
-// reachable from the producer, this would create a cycle.
-FusionDecision OperandReachableFromProducer(
-    const HloInstruction& producer, const HloInstruction& consumer,
-    const HloDfsReachability& reachability) {
-  for (const auto* operand : consumer.operands()) {
-    // If a get-tuple-element instruction is not in the reachability
-    // map, it has been created by fusion in this pass. Simply move
-    // on to its operand, which is in the reachability map.
-    if (!reachability.IsPresent(operand) &&
-        operand->opcode() == HloOpcode::kGetTupleElement) {
-      operand = operand->operand(0);
-    }
-    CHECK(reachability.IsPresent(operand) && reachability.IsPresent(&producer))
-        << "Reachability map is incomplete. This should never "
-           "happen.";
-    if (&producer != operand && reachability.IsReachable(&producer, operand)) {
-      return {
-          absl::StrCat(producer.name(), " would introduce a cycle when fused")};
-    }
-  }
-  return {};
-}
-
-FusionDecision ProducerCandidateIsFusible(
-    const HloInstruction& producer, const HloInstruction& consumer,
-    const HloDfsReachability& reachability, FusionInfoCache* fusion_info_cache,
-    const se::DeviceDescription& device_info,
-    GpuHloCostAnalysis* cost_analysis) {
-  if (!IsFusibleAsMultiOutputFusionRoot(consumer)) {
-    return "consumer not eligible as multi-output fusion root.";
-  }
-
-  RETURN_IF_NOT_FUSIBLE(
-      ShapesCompatibleForMultiOutputFusion(consumer, producer));
-
-  RETURN_IF_NOT_FUSIBLE(
-      OperandReachableFromProducer(producer, consumer, reachability));
-
-  RETURN_IF_NOT_FUSIBLE(FusionFitsInBudget(
-      producer, consumer, device_info,
-      /*is_consumer_producer_fusion=*/false, fusion_info_cache));
-
-  if (cost_analysis->ProducerConsumerMergedTooLarge(producer, consumer)) {
-    return "will generate too large IR";
-  }
-
-  GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
-      &producer, device_info, cost_analysis,
-      GpuPerformanceModelOptions::Default(),
-      /*fused_consumers=*/{&consumer},
-      /*multi_output=*/true);
-  if (t.time_fused > t.time_unfused) {
-    return "will execute slower if fused";
-  }
-
-  return {};
-}
-
-std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
-    const HloInstruction* producer, const HloDfsReachability& reachability,
-    FusionInfoCache* fusion_info_cache,
-    const se::DeviceDescription& device_info,
-    GpuHloCostAnalysis* cost_analysis) {
-  std::vector<HloInstruction*> fusion_candidates;
-  const HloComputation* computation = producer->parent();
-  const HloModule* module = computation->parent();
-  bool dump_fusion =
-      module->config().debug_options().xla_dump_fusion_visualization();
-
-  // If the producer is not a valid candidate for MOF, no need to check any of
-  // its users.
-  if (!IsProducerMultiOutputFusible(*producer)) {
-    return fusion_candidates;
-  }
-
-  // If there is only one user, and it is not a multi-output fusion node, this
-  // fusion possibility was already considered and rejected by the FusionMerger
-  // pass. No need to try again!
-  if (producer->user_count() == 1 &&
-      !producer->users()[0]->IsMultiOutputFusion()) {
-    return fusion_candidates;
-  }
-
-  for (HloInstruction* consumer : producer->users()) {
-    VLOG(3) << "Looking at producer " << producer->name()
-            << " and its consumer " << consumer->name();
-
-    if (auto decision = ProducerCandidateIsFusible(
-            *producer, *consumer, reachability, fusion_info_cache, device_info,
-            cost_analysis)) {
-      fusion_candidates.push_back(consumer);
-    } else if (dump_fusion) {
-      RegisterFusionState(
-          *computation,
-          absl::StrCat("Not considering fusion of producer |", producer->name(),
-                       "| into consumer |", consumer->name(),
-                       "| due to: ", decision.Explain()),
-          *consumer, producer);
-    }
-  }
-  return fusion_candidates;
-}
-
-bool IsSiblingFusionCandidate(const HloInstruction* instr) {
-  if (instr->users().empty() || !IsFusibleAsMultiOutputFusionRoot(*instr) ||
-      IsNestableVariadicReduction(*instr)) {
-    return false;
-  }
-  // Check if the users of multioutput fusion is not a get-tuple-element.
-  // If this is the case, we bail out because the transformation assumes
-  // the users are get-tuple-element.
-  return (!instr->IsMultiOutputFusion() ||
-          absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
-            return user->opcode() == HloOpcode::kGetTupleElement;
-          }));
-}
-
-FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1,
-                               const HloInstruction& sibling_consumer_2,
-                               const HloInstruction& common_producer,
-                               const HloDfsReachability& reachability,
-                               FusionInfoCache* fusion_info_cache,
-                               const se::DeviceDescription& device_info) {
-  if (reachability.IsConnected(&sibling_consumer_1, &sibling_consumer_2)) {
-    return {absl::StrCat(sibling_consumer_1.name(), " and ",
-                         sibling_consumer_2.name(), " are connected")};
-  }
-
-  RETURN_IF_NOT_FUSIBLE(ShapesCompatibleForMultiOutputFusion(
-      sibling_consumer_1, sibling_consumer_2));
-
-  // Technically, this check is order-dependent (e.g. siblings A, B, C where
-  // {A, B} and {B, C} overlap, but {A, C} do not). If the priority order is
-  // [C, A, B], only {C, B} will be fused, and A will only be fused in the
-  // next iteration of the fusion pipeline, potentially requiring several
-  // iterations to converge. We assume this case to be very rare in
-  // practice.
-  RETURN_IF_NOT_FUSIBLE(ParameterSlicesAreNonOverlapping(
-      sibling_consumer_1, sibling_consumer_2, &common_producer));
-
-  // This check should be last, as it may be expensive.
-  RETURN_IF_NOT_FUSIBLE(LegalToFuse(sibling_consumer_1, sibling_consumer_2,
-                                    device_info, fusion_info_cache));
-  return {};
-}
-
-}  // namespace
-
-void GpuMultiOutputFusion::RecomputeReachability() {
-  reachability_ = HloDfsReachability::Build(computation_);
-}
-
-bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent,
-                                        FusionInfoCache* fusion_info_cache,
-                                        GpuHloCostAnalysis* cost_analysis) {
-  const HloComputation* computation = parent->parent();
-  const HloModule* module = computation->parent();
-  bool dump_fusion =
-      module->config().debug_options().xla_dump_fusion_visualization();
-
-  if (!IsProfitableOperand(parent)) {
-    VLOG(3) << "Operand " << parent->ToShortString() << " is not profitable";
-    return false;
-  }
-  bool changed = false;
-  std::vector<HloInstruction*> siblings;
-  // Only consider siblings that are fusion candidates.
-  absl::c_copy_if(parent->users(), std::back_inserter(siblings),
-                  IsSiblingFusionCandidate);
-  // Sort the siblings such that multi-output fusion ops occur first, followed
-  // by fusion ops, followed by unfused ops.
-  absl::c_stable_sort(siblings,
-                      [](const HloInstruction* a, const HloInstruction* b) {
-                        return FusionPriority(a) > FusionPriority(b);
-                      });
-
-  for (auto i = siblings.begin(); i != siblings.end(); ++i) {
-    VLOG(3) << "Considering " << (*i)->name();
-    if ((*i)->opcode() != HloOpcode::kFusion) {
-      continue;
-    }
-    for (auto j = i + 1; j != siblings.end();) {
-      VLOG(3) << "Considering " << (*i)->name() << " and " << (*j)->name();
-
-      if (auto fusible = CanFuseSiblings(**i, **j, *parent, *reachability_,
-                                         fusion_info_cache, device_info_);
-          !fusible) {
-        // Arbitrarily pick one of the two siblings as the consumer below.
-        if (dump_fusion) {
-          RegisterFusionState(
-              *computation,
-              absl::StrCat("Not fusing siblings |", (**i).name(), "| and |",
-                           (**j).name(), "| due to: ", fusible.Explain()),
-              // Randomly pick one consumer.
-              /*consumer=*/**i,
-              /*producer=*/parent);
-        }
-        ++j;
-        continue;
-      }
-      if (!ConsumeFuel(name(), [&] {
-            return absl::StrFormat("Not fusing siblings %s and %s.",
-                                   (*i)->name(), (*j)->name());
-          })) {
-        ++j;
-        continue;
-      }
-      VLOG(2) << "Fuse siblings " << (*i)->name() << " and " << (*j)->name();
-      fusion_info_cache->Invalidate(*i);
-      fusion_info_cache->Invalidate(*j);
-      HloInstruction* remaining = *i;
-      HloInstruction* fused = *j;
-      TF_CHECK_OK(cost_analysis->RemoveInstruction(remaining));
-      TF_CHECK_OK(cost_analysis->RemoveInstruction(fused));
-
-      DumpFusionState(*remaining,
-                      absl::StrCat("About to fuse sibling |", fused->name(),
-                                   "| into sibling |", remaining->name(),
-                                   "| inside multi-output fusion"),
-                      /*producer=*/fused);
-
-      if (fused->opcode() == HloOpcode::kFusion) {
-        remaining->MergeFusionInstructionIntoMultiOutput(fused);
-        if (fused->IsInputFusion()) {
-          remaining->set_fusion_kind(HloInstruction::FusionKind::kInput);
-        }
-      } else {
-        remaining->FuseInstructionIntoMultiOutput(fused);
-        CHECK_EQ(0, fused->user_count());
-        TF_CHECK_OK(computation_->RemoveInstruction(fused));
-      }
-      DumpFusionState(*remaining,
-                      absl::StrCat("Fused into |", remaining->name(),
-                                   "| inside multi-output fusion"));
-      TF_CHECK_OK(cost_analysis->RevisitInstruction(remaining));
-      changed = true;
-      siblings.erase(j);
-      RecomputeReachability();
-    }
-  }
-  return changed;
-}
-
-absl::StatusOr<bool> GpuMultiOutputFusion::DoMultiOutputFusion() {
-  bool changed = false;
-  RecomputeReachability();
-  GpuHloCostAnalysis cost_analysis({shape_size_function_,
-                                    /*per_second_rates=*/{},
-                                    /*count_multiple_input_accesses=*/true},
-                                   device_info_);
-  TF_RETURN_IF_ERROR(computation_->Accept(&cost_analysis));
-  std::vector<HloInstruction*> defs_before_uses =
-      computation_->MakeInstructionPostOrder();
-
-  FusionInfoCache fusion_info_cache;
-  // Traverse the HLO in uses-before-defs order.
-  for (auto it = defs_before_uses.rbegin(); it != defs_before_uses.rend();
-       ++it) {
-    auto* producer = *it;
-    // Never multi-output fuse constants.  To the extent that we want to fuse
-    // constants, that should be handled by the regular fusion pass.
-    if (producer->opcode() == HloOpcode::kConstant) {
-      VLOG(3) << producer->name() << " is a constant.";
-      continue;
-    }
-    if (producer->IsCustomFusion()) {
-      continue;
-    }
-    // First, fuse the consumer ops of the current op, which are siblings.
-    if (FuseSiblings(/*parent=*/producer, &fusion_info_cache, &cost_analysis)) {
-      changed = true;
-    }
-    // Second, perform producer-consumer multi-output fusion. This order will
-    // ensure that all get-tuple-element ops inserted as a by-product of
-    // multi-output fusion will occur before the current op in the order of
-    // traversal, and hence, not get in the way of subsequent fusion attempts.
-    const auto candidates = GetProducerConsumerMultiOutputFusionCandidates(
-        producer, *reachability_, &fusion_info_cache, device_info_,
-        &cost_analysis);
-    auto* consumer_for_fusion = SelectPreferredFusionCandidate(candidates);
-    if (consumer_for_fusion == nullptr) {
-      continue;
-    }
-    if (!ConsumeFuel(name(), [&] {
-          return absl::StrFormat("Not fusing %s and %s.", producer->name(),
-                                 consumer_for_fusion->name());
-        })) {
-      continue;
-    }
-    changed = true;
-    fusion_info_cache.Invalidate(producer);
-    fusion_info_cache.Invalidate(consumer_for_fusion);
-    TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(producer));
-    TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(consumer_for_fusion));
-
-    HloInstruction* input_fusion;
-    if (consumer_for_fusion->opcode() == HloOpcode::kFusion) {
-      input_fusion = consumer_for_fusion;
-      VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
-              << consumer_for_fusion->name();
-    } else {
-      input_fusion = computation_->AddInstruction(HloInstruction::CreateFusion(
-          consumer_for_fusion->shape(),
-          ChooseFusionKind(*producer, *consumer_for_fusion),
-          consumer_for_fusion));
-      VLOG(2) << "Fuse producer " << producer->name() << " and its consumer "
-              << consumer_for_fusion->name() << " into "
-              << input_fusion->name();
-      TF_CHECK_OK(
-          computation_->ReplaceInstruction(consumer_for_fusion, input_fusion));
-    }
-
-    DumpFusionState(*input_fusion,
-                    absl::StrCat("About to fuse producer |", producer->name(),
-                                 "| into consumer |", input_fusion->name(),
-                                 "| inside multi-output fusion"),
-                    /*producer=*/producer);
-
-    if (producer->opcode() == HloOpcode::kFusion) {
-      input_fusion->MergeFusionInstructionIntoMultiOutput(producer);
-    } else {
-      input_fusion->FuseInstructionIntoMultiOutput(producer);
-      CHECK_EQ(0, producer->user_count());
-      TF_CHECK_OK(computation_->RemoveInstruction(producer));
-    }
-    TF_RETURN_IF_ERROR(cost_analysis.RevisitInstruction(input_fusion));
-
-    DumpFusionState(*input_fusion,
-                    absl::StrCat("Fused into |", input_fusion->name(),
-                                 "| inside multi-output fusion"));
-    RecomputeReachability();
-  }
-  return changed;
-}
-
-void GpuMultiOutputFusion::DumpFusionState(const HloInstruction& consumer,
-                                           absl::string_view label,
-                                           const HloInstruction* producer) {
-  if (consumer.GetModule()
-          ->config()
-          .debug_options()
-          .xla_dump_fusion_visualization()) {
-    RegisterFusionState(*computation_, label, consumer, producer);
-  }
-}
-
-absl::StatusOr<bool> GpuMultiOutputFusion::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (auto* computation : GetFusibleComputations(*module, execution_threads)) {
-    computation_ = computation;
-    TF_ASSIGN_OR_RETURN(bool computation_changed, DoMultiOutputFusion());
-    changed |= computation_changed;
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
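
A small standalone C++ sketch of the sibling ordering and pairing scheme that FuseSiblings above describes: the parent's users are sorted so that multi-output fusions come first, followed by regular fusions, followed by unfused ops, and a pair of siblings is only considered if neither is reachable from the other. The `Node` type, `FusionPriorityOf` helper, and `connected` predicate are illustrative placeholders, not the XLA API.

#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string name;
  bool is_fusion = false;               // corresponds to a kFusion opcode
  bool is_multi_output_fusion = false;  // fusion whose root is a tuple
};

// MOF > fusion > unfused, mirroring the sorting comment in FuseSiblings.
int FusionPriorityOf(const Node& n) {
  if (n.is_multi_output_fusion) return 2;
  if (n.is_fusion) return 1;
  return 0;
}

// Returns the sibling pairs the pass would consider, in priority order.
// `connected(a, b)` stands in for the HloDfsReachability query: fusing two
// siblings with a data-dependency path between them would create a cycle.
template <typename ConnectedFn>
std::vector<std::pair<const Node*, const Node*>> SiblingPairsToTry(
    std::vector<const Node*> siblings, ConnectedFn connected) {
  std::stable_sort(siblings.begin(), siblings.end(),
                   [](const Node* a, const Node* b) {
                     return FusionPriorityOf(*a) > FusionPriorityOf(*b);
                   });
  std::vector<std::pair<const Node*, const Node*>> pairs;
  for (std::size_t i = 0; i < siblings.size(); ++i) {
    // As in the pass, the first element of each candidate pair must already
    // be a fusion.
    if (!siblings[i]->is_fusion) continue;
    for (std::size_t j = i + 1; j < siblings.size(); ++j) {
      if (connected(*siblings[i], *siblings[j])) continue;
      pairs.push_back({siblings[i], siblings[j]});
    }
  }
  return pairs;
}

In the actual pass the candidate list is mutated as pairs get fused and reachability is recomputed, so the pairing is incremental rather than precomputed as in this sketch.
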
diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion.h b/third_party/xla/xla/service/gpu/multi_output_fusion.h
deleted file mode 100644
index 82789d3..0000000
--- a/third_party/xla/xla/service/gpu/multi_output_fusion.h
+++ /dev/null
@@ -1,134 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
-#define XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
-
-#include <memory>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_dfs_reachability.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// Multi-output fusion of sibling and producer-consumer instructions for the
-// GPU backend to reduce memory bandwidth requirements.
-//
-//   0) Before multi-    1) Sibling multi-    2) Producer-consumer
-//      output fusion       output fusion        multi-output fusion
-//
-//          p                    p                    p
-//          |                    |                    |
-//          v                    v                    v
-//          A                    A               +-fusion--+
-//        /   \                  |               |    A    |
-//       |     |            +-fusion--+          |   / \   |
-//       v     v            |   / \   |          |  B   |  |
-//       B     C            |  B   C  |          |  |   |  |
-//        \   /             |  |   |  |          |  v   v  |
-//         v v              |  v   v  |          |  tuple  |
-//        ROOT              |  tuple  |          +---------+
-//                          +---------+            /    \
-//                            /    \            gte_b  gte_a
-//                         gte_b  gte_c           |      |
-//                           |      |             |      v
-//                            \    /              |      C
-//                             v  v                \    /
-//                             ROOT                 v  v
-//                                                  ROOT
-//
-// Multi-output fusion ops have a tuple op at their root containing multiple
-// elements as outputs. GetTupleElement ops (depicted as gte_* above) are
-// inserted to extract tuple elements for consumers.
-//
-// The two different flavors of multi-output fusion this pass performs are
-// depicted above.
-// 1) Fusion of sibling ops reduces memory bandwidth requirements, because
-//    common input parameters have to be read only once.
-// 2) Fusion of producer-consumer ops reduces memory bandwidth requirements by
-//    saving one read from memory. In the example above, B does not need to read
-//    the output of A from memory, while C still does (using gte_a).
-// Note that sibling (1) and producer-consumer (2) multi-output fusion can be
-// combined.
-//
-// The GpuMultiOutputFusion pass modifies the HLO in reverse post-order (defs
-// before uses). First, it attempts to fuse the consumer ops of the current op,
-// which are siblings (1). Hereafter, it attempts to fuse the current op with
-// one of its consumers (2). This order avoids a phase ordering issue (described
-// in go/fusionfusion). It ensures that all GetTupleElement ops inserted as a
-// by-product of multi-output fusion will occur before the current op in the
-// order of traversal, and hence, not get into the way of subsequent fusion
-// attempts.
-//
-// The GpuMultiOutputFusion pass ensures several conditions are met for fusion.
-// Some of them are relevant for correctness. In particular, no cycles must be
-// introduced into the HLO module. Moreover, the code emitters for multi-output
-// fusion must support the combination of ops and their shapes. Other
-// restrictions are rather arbitrary and lifting them could be beneficial.
-// * Sibling fusion (1) requires at least one op to be a kFusion.
-// * Sibling fusion (1) does not fuse kInput fusions with kLoop fusions, i.e.
-//   the fusion kinds must match.
-
-class GpuMultiOutputFusion : public HloModulePass {
- public:
-  explicit GpuMultiOutputFusion(
-      const se::DeviceDescription& device_info,
-      HloCostAnalysis::ShapeSizeFunction shape_size_function)
-      : device_info_(device_info), shape_size_function_(shape_size_function) {}
-
-  absl::string_view name() const override { return "multi_output_fusion"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  bool FuseSiblings(HloInstruction* parent, FusionInfoCache* fusion_info_cache,
-                    GpuHloCostAnalysis* cost_analysis);
-
-  absl::StatusOr<bool> DoMultiOutputFusion();
-
-  // Recompute reachability for the current computation.
-  void RecomputeReachability();
-
-  void DumpFusionState(const HloInstruction& consumer, absl::string_view label,
-                       const HloInstruction* producer = nullptr);
-
-  // Computation for the pass.
-  HloComputation* computation_;
-
-  // The reachability map of the current computation.
-  std::unique_ptr<HloDfsReachability> reachability_;
-
-  se::DeviceDescription device_info_;
-  HloCostAnalysis::ShapeSizeFunction shape_size_function_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
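
The header's diagrams correspond to HLO along the following lines: a minimal, hypothetical before/after sketch of sibling multi-output fusion, written as raw-string module constants in the style of the test file that follows. Module names, computation names, and shapes are illustrative only and are not taken from this change.

// Before: two sibling fusions read the same input.
constexpr char kSiblingMofBefore[] = R"(
HloModule before_sibling_mof

fused_mul {
  p = f32[64] parameter(0)
  ROOT m = f32[64] multiply(p, p)
}

fused_exp {
  p = f32[64] parameter(0)
  ROOT e = f32[64] exponential(p)
}

ENTRY entry {
  a = f32[64] parameter(0)
  b = f32[64] fusion(a), kind=kLoop, calls=fused_mul
  c = f32[64] fusion(a), kind=kLoop, calls=fused_exp
  ROOT root = (f32[64], f32[64]) tuple(b, c)
})";

// After sibling multi-output fusion: one fusion with a tuple root; the
// original consumers are fed through get-tuple-element ops and the shared
// input is read only once.
constexpr char kSiblingMofAfter[] = R"(
HloModule after_sibling_mof

fused_mof {
  p = f32[64] parameter(0)
  m = f32[64] multiply(p, p)
  e = f32[64] exponential(p)
  ROOT t = (f32[64], f32[64]) tuple(m, e)
}

ENTRY entry {
  a = f32[64] parameter(0)
  mof = (f32[64], f32[64]) fusion(a), kind=kLoop, calls=fused_mof
  gte_b = f32[64] get-tuple-element(mof), index=0
  gte_c = f32[64] get-tuple-element(mof), index=1
  ROOT root = (f32[64], f32[64]) tuple(gte_b, gte_c)
})";

Producer-consumer multi-output fusion yields the same tuple/GTE structure, except that the producer is pulled into the fusion alongside one consumer while the remaining consumers read the corresponding tuple element.
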
diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc b/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc
deleted file mode 100644
index b333a04..0000000
--- a/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc
+++ /dev/null
@@ -1,2236 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/multi_output_fusion.h"
-
-#include <cstdint>
-#include <optional>
-#include <vector>
-
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-
-namespace m = ::xla::match;
-
-class MultiOutputFusionTest : public HloTestBase {
-  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
-    return [&](const Shape& shape) {
-      constexpr int64_t kPointerSize = 8;
-      return ShapeUtil::ByteSizeOf(shape, kPointerSize);
-    };
-  }
-
- public:
-  GpuMultiOutputFusion mof_{
-      TestGpuDeviceInfo::RTXA6000DeviceInfo(),
-      ShapeSizeBytesFunction()};
-
-  void CheckGpuMultiOutputFusion(absl::string_view hlo,
-                                 std::optional<absl::string_view> expected) {
-    RunAndFilecheckHloRewrite(
-        hlo,
-        GpuMultiOutputFusion{
-            TestGpuDeviceInfo::RTXA6000DeviceInfo(),
-            ShapeSizeBytesFunction()},
-        expected);
-  }
-};
-
-const char kModulePrefix[] = R"(
-    HloModule test_module
-
-    scalar_add_computation {
-      scalar_lhs.0 = f32[] parameter(0)
-      scalar_rhs.0 = f32[] parameter(1)
-      ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0)
-    }
-    scalar_mul_computation {
-      scalar_lhs.1 = f32[] parameter(0)
-      scalar_rhs.1 = f32[] parameter(1)
-      ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
-    })";
-
-static int64_t CountMultiOutputFusions(const HloModule* module) {
-  int multi_output_fusion_count = 0;
-  for (auto* computation : module->MakeNonfusionComputations()) {
-    for (auto* instr : computation->instructions()) {
-      if (instr->IsMultiOutputFusion()) {
-        multi_output_fusion_count++;
-      }
-    }
-  }
-  return multi_output_fusion_count;
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
-  // Fusion with reduce instruction root and a sibling reduce instruction
-  // sharing the same input param.
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation {
-      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
-      const.1 = f32[] parameter(0)
-      ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
-    }
-
-    ENTRY entry {
-      p0 = f32[] parameter(0)
-      p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
-      const.2 = f32[] constant(1)
-      fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation
-      reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
-      ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* fusion =
-      module->entry_computation()->root_instruction()->operand(0)->operand(0);
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation_1 {
-      p1.1 = f32[6400]{0} parameter(1)
-      mul = f32[6400]{0} multiply(p1.1, p1.1)
-      const.1 = f32[] parameter(0)
-      ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0}, to_apply=scalar_add_computation
-    }
-
-    fused_computation_2 {
-      p1.2 = f32[6400]{0} parameter(1)
-      r1 = f32[64,100]{0,1} reshape(p1.2)
-      const.2 = f32[] parameter(0)
-      ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation
-    }
-
-    ENTRY entry {
-      p0 = f32[] parameter(0)
-      p1 = f32[6400]{0} parameter(1)
-      fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1
-      fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2
-      ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2)
-    })"))
-                    .value();
-  ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, ReduceMofDifferentTypes) {
-  // Fusion with reduce instruction root and a sibling reduce instruction
-  // sharing the same input param.
-  const char* hlo = R"(
-HloModule module
-
-scalar_add_computation {
-  scalar_lhs.1 = f32[] parameter(0)
-  scalar_rhs.1 = f32[] parameter(1)
-  ROOT add.1 = f32[] add(scalar_lhs.1, scalar_rhs.1)
-}
-
-scalar_add_computation_f16 {
-  scalar_lhs.0 = f16[] parameter(0)
-  scalar_rhs.0 = f16[] parameter(1)
-  ROOT add.0 = f16[] add(scalar_lhs.0, scalar_rhs.0)
-}
-
-fused_computation {
-  param_0.2 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-  c.1 = f16[128,512,28,28]{3,2,1,0} convert(param_0.2)
-  const.0 = f16[] constant(0)
-  ROOT reduce.0 = f16[512]{0} reduce(c.1, const.0), dimensions={0,2,3}, to_apply=scalar_add_computation_f16
-}
-
-ENTRY entry {
-  p0 = f32[] parameter(0)
-  p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
-  const.2 = f32[] constant(0)
-  reduce.1 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
-  fusion = f16[512]{0} fusion(p1), kind=kInput, calls=fused_computation
-  ROOT root = (f32[512]{0}, f16[512]{0}) tuple(reduce.1, fusion)
-})";
-
-  CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation
-// CHECK-NEXT:   [[param_0_2_0:%[^ ]+]] = f32[128,512,28,28]{3,2,1,0} parameter(0)
-// CHECK-NEXT:   [[c_1_1:%[^ ]+]] = f16[128,512,28,28]{3,2,1,0} convert([[param_0_2_0]])
-// CHECK-NEXT:   [[const_0_2:%[^ ]+]] = f16[] constant(0)
-// CHECK-NEXT:   [[reduce_0_3:%[^ ]+]] = f16[512]{0} reduce([[c_1_1]], [[const_0_2]]), dimensions={0,2,3}, to_apply=[[scalar_add_computation_f16_4:%[^ ]+]]
-// CHECK-NEXT:   [[param_1_5:%[^ ]+]] = f32[] parameter(1)
-// CHECK-NEXT:   [[reduce_2_6:%[^ ]+]] = f32[512]{0} reduce([[param_0_2_0]], [[param_1_5]]), dimensions={0,2,3}, to_apply=[[scalar_add_computation_7:%[^ ]+]]
-// CHECK-NEXT:   ROOT [[tuple_8:%[^ ]+]] = (f16[512]{0}, f32[512]{0}) tuple([[reduce_0_3]], [[reduce_2_6]])
-// CHECK:   [[fusion_9:%[^ ]+]] = (f16[512]{0}, f32[512]{0}) fusion([[p1_10:%[^ ]+]], [[const_2_11:%[^ ]+]]), kind=kInput, calls=[[fused_computation_12:%[^ ]+]]
-)");
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation_1 {
-      p1.1 = f32[10,10]{1,0} parameter(1)
-      mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
-      const.1 = f32[] parameter(0)
-      ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0,1}, to_apply=scalar_add_computation
-    }
-
-    fused_computation_2 {
-      p1.2 = f32[10,10]{1,0} parameter(1)
-      const.2 = f32[] parameter(0)
-      ROOT reduce.2 = f32[10]{0} reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
-    }
-
-    ENTRY entry {
-      p0 = f32[] parameter(0)
-      p1.3 = f32[10,10]{1,0} parameter(1)
-      fusion.1 = f32[] fusion(p0, p1.3), kind=kInput, calls=fused_computation_1
-      p2 = f32[] parameter(2)
-      fusion.2 = f32[10]{0} fusion(p2, p1.3), kind=kInput, calls=fused_computation_2
-      ROOT root = (f32[], f32[10]{0}) tuple(fusion.1, fusion.2)
-    })"))
-                    .value();
-  ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) {
-  // Two sibling fusions with reduce instruction roots sharing the same input
-  // param.
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation_1 {
-      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
-      const.1 = f32[] parameter(0)
-      ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
-    }
-
-    fused_computation_2 {
-      p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1)
-      const.2 = f32[] parameter(0)
-      ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
-    }
-
-    ENTRY entry {
-      p0 = f32[] parameter(0)
-      p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
-      fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1
-      fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2
-      ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* fusion =
-      module->entry_computation()->root_instruction()->operand(0)->operand(0);
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionNoSiblingFusionForCommonScalar) {
-  // Two sibling fusions with bitcast roots sharing the same scalar input param.
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation_1 {
-      param_0.87 = bf16[32,4096,16384]{2,1,0} parameter(0)
-      param_1.4620 = s32[] parameter(1)
-      constant_3949 = s32[] constant(0)
-      compare.1026 = pred[] compare(param_1.4620, constant_3949), direction=LT
-      constant_5437 = s32[] constant(32)
-      add.6859 = s32[] add(param_1.4620, constant_5437)
-      select.1599 = s32[] select(compare.1026, add.6859, param_1.4620)
-      dynamic-slice.59 = bf16[1,4096,16384]{2,1,0} dynamic-slice(param_0.87, select.1599, constant_3949, constant_3949), dynamic_slice_sizes={1,4096,16384}
-      ROOT bitcast.41089 = bf16[4096,16384]{1,0} bitcast(dynamic-slice.59)
-    }
-
-    fused_computation_2 {
-      param_0 = bf16[32,4096,16384]{2,1,0} parameter(0)
-      param_1 = s32[] parameter(1)
-      constant = s32[] constant(0)
-      compare = pred[] compare(param_1, constant), direction=LT
-      constant.32 = s32[] constant(32)
-      add = s32[] add(param_1, constant.32)
-      select = s32[] select(compare, add, param_1)
-      dynamic-slice = bf16[1,4096,16384]{2,1,0} dynamic-slice(param_0, select, constant, constant), dynamic_slice_sizes={1,4096,16384}
-      ROOT bitcast.41087 = bf16[4096,16384]{1,0} bitcast(dynamic-slice)
-    }
-
-    ENTRY entry {
-      p0 = s32[] parameter(0)
-      p1 = bf16[32,4096,16384]{2,1,0} parameter(1)
-      p2 = bf16[32,4096,16384]{2,1,0} parameter(2)
-      fusion.1 = bf16[4096,16384]{1,0} fusion(p1, p0), kind=kLoop, calls=fused_computation_1
-      fusion.2 = bf16[4096,16384]{1,0} fusion(p2, p0), kind=kLoop, calls=fused_computation_2
-      ROOT root = (bf16[4096,16384]{1,0}, bf16[4096,16384]{1,0}) tuple(fusion.1, fusion.2)
-    })"))
-                    .value();
-  ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest,
-       MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
-  // Multi-output fusion whose root contains two reduce instructions, plus a
-  // sibling reduce instruction sharing the same input param.
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) {
-      const.1 = f32[] constant(1)
-      p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1)
-      reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
-      reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
-      ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2)
-    }
-
-    ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) {
-      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
-      const = f32[] constant(1)
-      fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation
-      get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0
-      get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1
-      reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation
-      ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* fusion =
-      module->entry_computation()->root_instruction()->operand(0)->operand(0);
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Reduce(), m::Reduce(), m::Reduce())));
-}
-
-TEST_F(MultiOutputFusionTest,
-       MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) {
-  // Verify that, if we already have a multi-output fusion, we prefer to pick
-  // a reduce op from its operands when checking shape compatibility.
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation_1 {
-      p1.1 = f32[10,10]{1,0} parameter(1)
-      mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
-      const.1 = f32[] parameter(0)
-      reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0,1}, to_apply=scalar_add_computation
-      ROOT tuple = (f32[10,10], f32[]) tuple(mul, reduce.1)
-    }
-
-    fused_computation_2 {
-      p1.2 = f32[10,10]{1,0} parameter(1)
-      const.2 = f32[] parameter(0)
-      ROOT reduce.2 = f32[10] reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
-    }
-
-    ENTRY entry {
-      p0 = f32[] parameter(0)
-      p1 = f32[10,10]{1,0} parameter(1)
-      p2 = f32[] parameter(2)
-      fusion.1 = (f32[10,10], f32[]) fusion(p0, p1), kind=kInput, calls=fused_computation_1
-      get-tuple-element.1 = f32[10,10] get-tuple-element((f32[10,10], f32[]) fusion.1), index=0
-      get-tuple-element.2 = f32[] get-tuple-element((f32[10,10], f32[]) fusion.1), index=1
-      fusion.2 = f32[10] fusion(p2, p1), kind=kInput, calls=fused_computation_2
-      ROOT root = (f32[10,10], f32[], f32[10]) tuple(get-tuple-element.1, get-tuple-element.2, fusion.2)
-    })"))
-                    .value();
-  ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, LoopVariadicReductionFusions) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation.94 {
-      tmp_0 = f32[] parameter(0)
-      tmp_1 = f32[] parameter(1)
-      tmp_2 = pred[] compare(tmp_0, tmp_1), direction=GE
-      tmp_3 = f32[] select(tmp_2, tmp_0, tmp_1)
-      tmp_4 = pred[] compare(tmp_0, tmp_1), direction=EQ
-      tmp_5 = s32[] parameter(2)
-      tmp_6 = s32[] parameter(3)
-      tmp_7 = s32[] minimum(tmp_5, tmp_6)
-      tmp_8 = s32[] select(tmp_2, tmp_5, tmp_6)
-      tmp_9 = s32[] select(tmp_4, tmp_7, tmp_8)
-      ROOT tmp_10 = (f32[], s32[]) tuple(tmp_3, tmp_9)
-    }
-
-    minmax_func.1536 {
-      tmp_0 = f32[] parameter(0)
-      tmp_1 = f32[] parameter(2)
-      tmp_2 = s32[] parameter(1)
-      tmp_3 = s32[] parameter(3)
-      ROOT tmp_4 = (f32[], s32[]) fusion(tmp_0, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=fused_computation.94
-    }
-
-    fused_computation {
-      tmp_0 = f32[554112,10]{1,0} parameter(0)
-      tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
-      tmp_2 = f32[] constant(-inf)
-      tmp_3 = s32[] constant(0)
-      ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
-    }
-
-    fused_computation2 {
-      tmp_0 = f32[554112,10]{1,0} parameter(0)
-      tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
-      tmp_2 = f32[] constant(inf)
-      tmp_3 = s32[] constant(1)
-      ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
-    }
-
-    ENTRY e {
-      tmp_0 = f32[554112,10]{1,0} parameter(0)
-      tmp_1 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_0), kind=kLoop, calls=fused_computation
-      tmp_2 = s32[554112]{0} get-tuple-element(tmp_1), index=1
-      tmp_4 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_0), kind=kLoop, calls=fused_computation2
-      tmp_5 = s32[554112]{0} get-tuple-element(tmp_4), index=1
-      ROOT tmp_6 = s32[554112]{0} add(tmp_2, tmp_5)
-    })"))
-                    .value();
-  EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, InputVariadicReductionFusions) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation.1117 {
-      param_0.2433 = f32[] parameter(0)
-      param_1.2571 = f32[] parameter(1)
-      compare.1770 = pred[] compare(param_0.2433, param_1.2571), direction=LE
-      select.682 = f32[] select(compare.1770, param_0.2433, param_1.2571)
-      compare.1303.clone.1 = pred[] compare(param_0.2433, param_1.2571), direction=EQ
-      param_2.6460 = s32[] parameter(2)
-      param_3.6755 = s32[] parameter(3)
-      minimum.633.clone.1 = s32[] minimum(param_2.6460, param_3.6755)
-      select.398.clone.1 = s32[] select(compare.1770, param_2.6460, param_3.6755)
-      select.397.clone.1 = s32[] select(compare.1303.clone.1, minimum.633.clone.1, select.398.clone.1)
-      ROOT tuple.151 = (f32[], s32[]) tuple(select.682, select.397.clone.1)
-    }
-
-    minmax_func.223 {
-      lhs_value.224 = f32[] parameter(0)
-      rhs_value.226 = f32[] parameter(2)
-      lhs_index.225 = s32[] parameter(1)
-      rhs_index.227 = s32[] parameter(3)
-      ROOT fusion.1117 = (f32[], s32[]) fusion(lhs_value.224, rhs_value.226, lhs_index.225, rhs_index.227), kind=kLoop, calls=fused_computation.1117
-    }
-
-    fused_computation.73 {
-      bitcast.86661 = f32[3,1024,300]{2,1,0} parameter(0)
-      iota.734 = s32[3,1,1024,300]{3,2,1,0} iota(), iota_dimension=3
-      bitcast.97555 = s32[3,1024,300]{2,1,0} bitcast(iota.734)
-      constant_3917 = f32[] constant(inf)
-      constant_3918 = s32[] constant(0)
-      ROOT reduce.1069 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) reduce(bitcast.86661, bitcast.97555, constant_3917, constant_3918), dimensions={2}, to_apply=minmax_func.223
-    }
-
-    fused_computation.84 {
-      bitcast.86676 = f32[3,1024,300]{2,1,0} parameter(0)
-      iota.732 = s32[3,1,1024,300]{3,2,1,0} iota(), iota_dimension=3
-      bitcast.97553 = s32[3,1024,300]{2,1,0} bitcast(iota.732)
-      constant_3915 = f32[] constant(inf)
-      constant_3916 = s32[] constant(0)
-      ROOT reduce.1070 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) reduce(bitcast.86676, bitcast.97553, constant_3915, constant_3916), dimensions={2}, to_apply=minmax_func.223
-    }
-
-    ENTRY e {
-      p0 = f32[3,1024,300]{2,1,0} parameter(0)
-      fusion.84 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) fusion(p0), kind=kInput, calls=fused_computation.84
-      gte.391 = s32[3,1024]{1,0} get-tuple-element(fusion.84), index=1
-      fusion.73 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) fusion(p0), kind=kInput, calls=fused_computation.73
-      gte.393 = s32[3,1024]{1,0} get-tuple-element(fusion.73), index=1
-      ROOT r = s32[3,1024]{1,0} add(gte.391, gte.393)
-    })"))
-                    .value();
-  EXPECT_TRUE(mof_.Run(module.get()).value());
-  EXPECT_EQ(module->entry_computation()->parameter_instruction(0)->user_count(),
-            1);
-  const HloInstruction* fusion =
-      module->entry_computation()->parameter_instruction(0)->users()[0];
-  EXPECT_THAT(fusion, GmockMatch(m::Fusion()));
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation_1 {
-      p0.1 = f32[6400]{0} parameter(0)
-      ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
-    }
-
-    fused_computation_2 {
-      p0.2 = f32[6400]{0} parameter(0)
-      const.2 = f32[] constant(1)
-      broadcast = f32[6400]{0} broadcast(const.2), dimensions={}
-      ROOT div = f32[6400]{0} divide(p0.2, broadcast)
-    }
-
-    ENTRY entry {
-      p0 = f32[6400]{0} parameter(0)
-      fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
-      fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2
-      ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* fusion =
-      module->entry_computation()->root_instruction()->operand(0)->operand(0);
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Multiply(), m::Divide())));
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation_1 {
-      p0.1 = f32[6400]{0} parameter(0)
-      ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
-    }
-
-    ENTRY entry {
-      p0 = f32[6400]{0} parameter(0)
-      fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
-      const.2 = f32[] constant(1)
-      broadcast = f32[6400]{0} broadcast(const.2), dimensions={}
-      div = f32[6400]{0} divide(p0, broadcast)
-      ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* fusion =
-      module->entry_computation()->root_instruction()->operand(0)->operand(0);
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Multiply(), m::Divide())));
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation_1 {
-      p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
-      ROOT mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1)
-    }
-
-    fused_computation_2 {
-      p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
-      const.2 = f32[] constant(0)
-      ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2), dimensions={0,3}, to_apply=scalar_add_computation
-    }
-
-    ENTRY entry {
-      p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
-      fusion.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1
-      fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
-      ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0}) tuple(fusion.1, fusion.2)
-    })"))
-                    .value();
-  ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation_1 {
-      p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
-      mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
-      exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
-      ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
-        f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
-    }
-
-    fused_computation_2 {
-      p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
-      const.2 = f32[] constant(0)
-      broadcast = f32[8,1,5,16,1,1]{5,4,3,2,1,0} broadcast(const.2),
-        dimensions={}
-      ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, broadcast)
-    }
-
-    ENTRY entry {
-      p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
-      fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
-        f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop,
-        calls=fused_computation_1
-      fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop,
-        calls=fused_computation_2
-      gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
-      gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
-      ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
-        f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0})
-        tuple(gte0, gte1, fusion.2)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* fusion =
-      module->entry_computation()->root_instruction()->operand(0)->operand(0);
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Multiply(), m::Exp(), m::Add())));
-}
-
-TEST_F(MultiOutputFusionTest,
-       MultiOutputFusionSiblingMultiOutputLoopAndMultiOutputLoop) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation_1 {
-      p0.1 = f32[8,16]{1,0} parameter(0)
-      mul = f32[8,16]{1,0} multiply(p0.1, p0.1)
-      exp = f32[8,16]{1,0} exponential(p0.1)
-      ROOT tuple = (f32[8,16]{1,0}, f32[8,16]{1,0}) tuple(mul, exp)
-    }
-
-    fused_computation_2 {
-      p0.2 = f32[8,16]{1,0} parameter(0)
-      const.2 = f32[] constant(0)
-      broadcast = f32[8,16]{1,0} broadcast(const.2),
-        dimensions={}
-      add = f32[8,16]{1,0} add(p0.2, broadcast)
-      ROOT tuple.1 = (f32[8,16]{1,0}, f32[8,16]{1,0}) tuple(add, broadcast)
-    }
-
-    ENTRY entry {
-      p0 = f32[8,16]{1,0} parameter(0)
-      fusion.1 = (f32[8,16]{1,0}, f32[8,16]{1,0}) fusion(p0), kind=kLoop,
-        calls=fused_computation_1
-      fusion.2 = (f32[8,16]{1,0}, f32[8,16]{1,0}) fusion(p0), kind=kLoop,
-        calls=fused_computation_2
-      gte0 = f32[8,16]{1,0} get-tuple-element(fusion.1), index=0
-      gte1 = f32[8,16]{1,0} get-tuple-element(fusion.1), index=1
-      gte2 = f32[8,16]{1,0} get-tuple-element(fusion.2), index=0
-      gte3 = f32[8,16]{1,0} get-tuple-element(fusion.2), index=1
-      ROOT root = (f32[8,16]{1,0}, f32[8,16]{1,0}, f32[8,16]{1,0},
-        f32[8,16]{1,0})
-        tuple(gte0, gte1, gte2, gte3)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* fusion =
-      module->entry_computation()->root_instruction()->operand(0)->operand(0);
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(
-      fusion->fused_expression_root(),
-      GmockMatch(m::Tuple(m::Multiply(), m::Exp(), m::Add(), m::Broadcast())));
-}
-
-TEST_F(MultiOutputFusionTest,
-       MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation_1 {
-      p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
-      mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1)
-      exp = f32[8,1,5,16,1,2]{5,4,3,2,1,0} exponential(p0.1)
-      ROOT tuple = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
-        f32[8,1,5,16,1,2]{5,4,3,2,1,0}) tuple(mul, exp)
-    }
-
-    fused_computation_2 {
-      p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
-      const.2 = f32[] constant(0)
-      ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2),
-        dimensions={0,3}, to_apply=scalar_add_computation
-    }
-
-    ENTRY entry {
-      p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
-      fusion.1 = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
-        f32[8,1,5,16,1,2]{5,4,3,2,1,0}) fusion(p0), kind=kLoop,
-        calls=fused_computation_1
-      fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop,
-        calls=fused_computation_2
-      gte0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
-      gte1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
-      ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
-        f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0})
-        tuple(gte0, gte1, fusion.2)
-    })"))
-                    .value();
-  ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, SiblingFusionBitcastAndLoopFusionNotFused) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test
-
-fused_computation_1 {
-  p0.1 = f32[2048,16000]{1,0} parameter(0)
-  bitcast = f32[2048,1,16000]{2,1,0} bitcast(p0.1)
-  ROOT exp = f32[2048,1,16000]{2,1,0} exponential(bitcast)
-}
-
-ENTRY main {
-  param_0 = f32[2048,16000]{1,0} parameter(0)
-  fusion = f32[2048,1,16000]{2,1,0} fusion(param_0), kind=kLoop, calls=fused_computation_1
-  bitcast = f32[16000,1,2048]{2,1,0} bitcast(param_0)
-  ROOT tuple.143 = (f32[16000,1,2048]{2,1,0}, f32[2048,1,16000]{2,1,0}) tuple(bitcast, fusion)
-})")
-                    .value();
-  EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest,
-       ProducerConsumerFusionBitcastAndElementwiseNotFused) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test
-
-ENTRY main {
-  param_0 = f32[2048,16000]{1,0} parameter(0)
-  convert = bf16[2048,16000]{1,0} convert(param_0)
-  bitcast = bf16[16000,1,2048]{2,1,0} bitcast(convert)
-  ROOT tuple.143 = (bf16[16000,1,2048]{2,1,0}, bf16[2048,16000]{1,0}) tuple(bitcast, convert)
-})")
-                    .value();
-  EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    ENTRY reduce {
-      p0 = f32[32,32,32]{2,1,0} parameter(0)
-      c0 = f32[] constant(0)
-      exp = f32[32,32,32]{2,1,0} exponential(p0)
-      reduce = f32[32,32]{1,0} reduce(exp, c0), dimensions={2},
-        to_apply=scalar_add_computation
-      ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, exp)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                        m::GetTupleElement())));
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Reduce(), m::Exp())));
-}
-
-TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_add {
-      p0.1 = f32[32,32,32]{2,1,0} parameter(0)
-      p1.1 = f32[32,32,32]{2,1,0} parameter(1)
-      ROOT add = f32[32,32,32]{2,1,0} add(p0.1, p1.1)
-    }
-
-    ENTRY reduce {
-      p0 = f32[32,32,32]{2,1,0} parameter(0)
-      p1 = f32[32,32,32]{2,1,0} parameter(1)
-      c0 = f32[] constant(0)
-      add = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
-      reduce = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
-        to_apply=scalar_add_computation
-      ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, add)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                        m::GetTupleElement())));
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Reduce(), m::Add())));
-}
-
-TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_select {
-      p1.1 = f32[32,32,32]{2,1,0} parameter(1)
-      c0 = f32[] constant(0)
-      broadcast = f32[32,32,32]{2,1,0} broadcast(f32[] c0), dimensions={}
-      greater-than = pred[32,32,32]{2,1,0} compare(f32[32,32,32]{2,1,0} p1.1,
-        f32[32,32,32]{2,1,0} broadcast), direction=GT
-      p0.1 = f32[32,32,32]{2,1,0} parameter(0)
-      ROOT select = f32[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
-        greater-than, f32[32,32,32]{2,1,0} p0.1, f32[32,32,32]{2,1,0} broadcast)
-    }
-
-    fused_reduce {
-      p0.2 = f32[32,32,32]{2,1,0} parameter(0)
-      c1 = f32[] constant(0)
-      r1 = f32[32,32]{1,0} reduce(p0.2, c1), dimensions={2},
-        to_apply=scalar_add_computation
-      mul = f32[32,32,32]{2,1,0} multiply(p0.2, p0.2)
-      r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
-        to_apply=scalar_add_computation
-      ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
-    }
-
-    ENTRY reduce {
-      p0 = f32[32,32,32]{2,1,0} parameter(0)
-      p1 = f32[32,32,32]{2,1,0} parameter(1)
-      select = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
-      fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
-        calls=fused_reduce
-      gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
-      gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
-      ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32,32]{2,1,0})
-        tuple(gte1, gte1, select)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root,
-              GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                  m::GetTupleElement(), m::GetTupleElement())));
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Reduce(), m::Reduce(), m::Select())));
-}
-
-TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_element_wise {
-      p0.1 = f32[2,2,2]{2,1,0} parameter(0)
-      p1.1 = f32[2,2,2]{2,1,0} parameter(1)
-      ROOT root = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
-    }
-
-    fused_reduce {
-      p0.2 = f32[2,2,2]{2,1,0} parameter(0)
-      mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2,
-        f32[2,2,2]{2,1,0} p0.2)
-      broadcast = f32[2,2,2,2]{3,2,1,0} broadcast(mul), dimensions={3,2,1}
-      c1 = f32[] constant(0)
-      ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2,2]{3,2,1,0} broadcast,
-        f32[] c1), dimensions={1,3}, to_apply=scalar_add_computation
-    }
-
-    ENTRY reduce {
-      p0 = f32[2,2,2]{2,1,0} parameter(0)
-      p1 = f32[2,2,2]{2,1,0} parameter(1)
-      element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise
-      fusion = f32[2,2]{1,0} fusion(element_wise), kind=kLoop, calls=fused_reduce
-      ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise)
-    })"))
-                    .value();
-  ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest,
-       ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_select {
-      p1.1 = f16[32,32,32]{2,1,0} parameter(1)
-      c0 = f16[] constant(0)
-      broadcast = f16[32,32,32]{2,1,0} broadcast(f16[] c0), dimensions={}
-      greater-than = pred[32,32,32]{2,1,0} compare(f16[32,32,32]{2,1,0} p1.1,
-        f16[32,32,32]{2,1,0} broadcast), direction=GT
-      p0.1 = f16[32,32,32]{2,1,0} parameter(0)
-      ROOT select = f16[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
-        greater-than, f16[32,32,32]{2,1,0} p0.1, f16[32,32,32]{2,1,0} broadcast)
-    }
-    fused_reduce {
-      p0.2 = f16[32,32,32]{2,1,0} parameter(0)
-      convert = f32[32,32,32]{2,1,0} convert(p0.2)
-      c1 = f32[] constant(0)
-      r1 = f32[32,32]{1,0} reduce(convert, c1), dimensions={2},
-        to_apply=scalar_add_computation
-      mul = f32[32,32,32]{2,1,0} multiply(convert, convert)
-      r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
-        to_apply=scalar_add_computation
-      ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
-    }
-    ENTRY reduce {
-      p0 = f16[32,32,32]{2,1,0} parameter(0)
-      p1 = f16[32,32,32]{2,1,0} parameter(1)
-      select = f16[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
-      fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
-        calls=fused_reduce
-      gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
-      gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
-      ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f16[32,32,32]{2,1,0})
-        tuple(gte1, gte1, select)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root,
-              GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
-                                  m::GetTupleElement(), m::GetTupleElement())));
-  ASSERT_TRUE(fusion->IsMultiOutputFusion());
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Tuple(m::Reduce(), m::Reduce(), m::Select())));
-}
-
-TEST_F(MultiOutputFusionTest,
-       ProducerConsumerFusionReduceUnfriendlyLoopFusion) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    mixed_input_layouts_computation {
-      p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
-      p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
-      copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
-      c0 = f16[] constant(0)
-      broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
-      greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
-      ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
-    }
-    fused_reduce {
-      p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
-      convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
-      c0.2 = f32[] constant(0)
-      ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add_computation
-    }
-    ENTRY reduce {
-      p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
-      p1 = f16[128,1024,32,32]{1,3,2,0} parameter(1)
-      loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
-      reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
-      ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
-    })"))
-                    .value();
-  ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, ProducerConsumerFusionAvoidsCycles) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_add {
-      p0 = f32[32,32,32]{2,1,0} parameter(0)
-      p1 = f32[32,32,32]{2,1,0} parameter(1)
-      ROOT add = f32[32,32,32]{2,1,0} add(p0, p1)
-    }
-
-    fused_mul {
-      p2 = f32[64,64,64]{2,1,0} parameter(0)
-      p3 = f32[64,64,64]{2,1,0} parameter(1)
-      ROOT multiply = f32[64,64,64]{2,1,0} multiply(p2, p3)
-    }
-
-    fused_reduce_1 {
-      p4 = f32[32,32,32]{2,1,0} parameter(0)
-      p5 = f32[64,64,64]{2,1,0} parameter(1)
-      slice = f32[32,32,32]{2,1,0} slice(p5), slice={[0:32], [0:32], [0:32]}
-      add = f32[32,32,32]{2,1,0} add(p4, slice)
-      c0 = f32[] constant(0)
-      ROOT r1 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
-        to_apply=scalar_add_computation
-    }
-
-    fused_reduce_2 {
-      p6 = f32[32,32,32]{2,1,0} parameter(0)
-      p7 = f32[64,64,64]{2,1,0} parameter(1)
-      c0 = f32[] constant(0)
-      pad = f32[64,64,64]{2,1,0} pad(p6, c0), padding=16_16x16_16x16_16
-      mul = f32[64,64,64]{2,1,0} multiply(pad, p7)
-      ROOT r1 = f32[64,64]{1,0} reduce(mul, c0), dimensions={2},
-        to_apply=scalar_add_computation
-    }
-
-    ENTRY reduce {
-      p8 = f32[32,32,32]{2,1,0} parameter(0)
-      p9 = f32[64,64,64]{2,1,0} parameter(1)
-      // `add` and `mul` can be multi-output fused with `reduce1` and `reduce2`,
-      // respectively. However, fusing both is not possible, because multi-output
-      // fusion will introduce an extra dependency from `reduce1` to `reduce2` or
-      // vice versa. Hence, the second multi-output fusion would introduce a cycle.
-      add = f32[32,32,32]{2,1,0} fusion(p8, p8), kind=kLoop, calls=fused_add
-      mul = f32[64,64,64]{2,1,0} fusion(p9, p9), kind=kLoop, calls=fused_mul
-
-      reduce1 = f32[32,32]{1,0} fusion(add, mul), kind=kInput,
-          calls=fused_reduce_1
-      reduce2 = f32[64,64]{1,0} fusion(add, mul), kind=kInput,
-          calls=fused_reduce_2
-      ROOT root = (f32[32,32,32]{2,1,0}, f32[32,32]{1,0}, f32[64,64]{1,0},
-                   f32[64,64,64]{2,1,0}) tuple(add, reduce1, reduce2, mul)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  EXPECT_EQ(1, CountMultiOutputFusions(module.get()));
-}
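
To make the cycle argument in the comment above concrete: merging a producer into a consumer is only legal if the consumer cannot be reached from the producer through some path other than the direct producer-to-consumer edge, because after the merge such a path would start and end at the same node. Below is a minimal, hypothetical sketch of that reachability test in plain C++ (node indices stand in for HLO instructions; this is not the pass's actual implementation).

#include <vector>

// Returns true if `to` is reachable from `from` in the DAG `succ` while
// ignoring the direct edge from `from` to `to`. If it is, fusing `from`
// into `to` would create a cycle.
bool FusionWouldCreateCycle(const std::vector<std::vector<int>>& succ,
                            int from, int to) {
  std::vector<bool> seen(succ.size(), false);
  std::vector<int> stack;
  for (int next : succ[from]) {
    if (next != to) stack.push_back(next);  // Skip the direct edge.
  }
  while (!stack.empty()) {
    int node = stack.back();
    stack.pop_back();
    if (node == to) return true;
    if (seen[node]) continue;
    seen[node] = true;
    for (int next : succ[node]) stack.push_back(next);
  }
  return false;
}

In the test above, once `add` and `reduce1` have been fused, the resulting node is both an operand of `reduce2` (through `add`) and a user of `mul` (through `reduce1`), so the candidate fusion of `mul` with `reduce2` fails this kind of check; only one of the two candidate fusions is performed, which is why the test expects exactly one multi-output fusion.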
-
-TEST_F(MultiOutputFusionTest, PreferFuseProducerIntoFusionConsumer) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_add {
-      p0 = f32[32,32,32]{2,1,0} parameter(0)
-      p1 = f32[32,32,32]{2,1,0} parameter(1)
-      ROOT add = f32[32,32,32]{2,1,0} add(p0, p1)
-    }
-    fused_reduce {
-      p0 = f32[32,32,32]{2,1,0} parameter(0)
-      p1 = f32[64,64,64]{2,1,0} parameter(1)
-      slice = f32[32,32,32]{2,1,0} slice(p1), slice={[0:32], [0:32], [0:32]}
-      add = f32[32,32,32]{2,1,0} add(p0, slice)
-      c0 = f32[] constant(0)
-      ROOT r1 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
-        to_apply=scalar_add_computation
-    }
-    ENTRY reduce {
-      p0 = f32[32,32,32]{2,1,0} parameter(0)
-      p1 = f32[64,64,64]{2,1,0} parameter(1)
-      add = f32[32,32,32]{2,1,0} fusion(p0, p0), kind=kLoop, calls=fused_add
-      c0 = f32[] constant(0)
-      reduce2 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
-        to_apply=scalar_add_computation
-      reduce = f32[32,32]{1,0} fusion(add, p1), kind=kInput, calls=fused_reduce
-      ROOT root = (f32[32,32,32]{2,1,0}, f32[32,32]{1,0}, f32[32,32]{1,0})
-                  tuple(add, reduce, reduce2)
-    })"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  int multi_output_fusion_count = 0;
-  for (auto* computation : module->MakeNonfusionComputations()) {
-    for (auto* instr : computation->instructions()) {
-      if (instr->IsMultiOutputFusion()) {
-        multi_output_fusion_count++;
-      }
-    }
-  }
-  EXPECT_EQ(1, multi_output_fusion_count);
-}
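
The counting loop above is presumably the same computation as the CountMultiOutputFusions helper used by the neighbouring tests. Written as a standalone sketch against the HloModule API already used in this file, it would look roughly like this:

// Sketch only: counts fusion instructions that produce multiple outputs,
// mirroring the inline loop in the test above.
int CountMultiOutputFusionsSketch(HloModule* module) {
  int count = 0;
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instr : computation->instructions()) {
      if (instr->IsMultiOutputFusion()) {
        ++count;
      }
    }
  }
  return count;
}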
-
-// Check that we limit the number of operands and outputs of the fusions we create.
-TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) {
-  constexpr int64_t kNumParams = 200;
-  ASSERT_GT(kNumParams, MaxOperandsAndOutputsPerFusion());
-
-  // Compute
-  //   p0 * p1,
-  //   p0 * p1 + p1 * p2
-  //   p0 * p1 + p1 * p2 + p2 * p3
-  //   ...
-  // where each of the (pi * pj)'s is represented as a fusion node so that
-  // multi-output fusion will pay attention to it.
-  auto module = CreateNewVerifiedModule();
-  HloComputation::Builder b(TestName());
-  Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
-
-  std::vector<HloInstruction*> params;
-  for (int64_t i = 0; i < kNumParams; ++i) {
-    params.push_back(
-        b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")));
-  }
-
-  // Creates a fusion node that calculates x*y.
-  auto make_fusion = [&](HloInstruction* x, HloInstruction* y) {
-    HloComputation::Builder sub_builder("subcomp");
-    auto* p0 = sub_builder.AddInstruction(
-        HloInstruction::CreateParameter(0, shape, "p"));
-    auto* p1 = sub_builder.AddInstruction(
-        HloInstruction::CreateParameter(1, shape, "p"));
-    sub_builder.AddInstruction(
-        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
-    HloComputation* subcomp =
-        module->AddEmbeddedComputation(sub_builder.Build());
-    return HloInstruction::CreateFusion(
-        shape, HloInstruction::FusionKind::kLoop, {x, y}, subcomp);
-  };
-
-  auto* sum = b.AddInstruction(make_fusion(params[0], params[1]));
-  for (int64_t i = 2; i < kNumParams; ++i) {
-    sum = b.AddInstruction(HloInstruction::CreateBinary(
-        shape, HloOpcode::kAdd, sum,
-        b.AddInstruction(make_fusion(params[i - 1], params[i]))));
-  }
-  auto computation = module->AddEntryComputation(b.Build());
-  EXPECT_TRUE(mof_.Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  for (const HloInstruction* instr : computation->instructions()) {
-    EXPECT_LE(instr->operand_count() + ShapeUtil::SubshapeCount(instr->shape()),
-              MaxOperandsAndOutputsPerFusion())
-        << instr->ToString();
-  }
-}
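
The per-instruction assertion above can be read as a single budget predicate: the number of operands plus the number of output buffers (subshapes of the result shape) must not exceed MaxOperandsAndOutputsPerFusion(). As a hypothetical helper, using only the calls that already appear in the loop above:

// Sketch only: the operand/output budget the test asserts for every
// instruction after the pass has run.
bool WithinFusionParameterBudget(const HloInstruction* instr) {
  return instr->operand_count() + ShapeUtil::SubshapeCount(instr->shape()) <=
         MaxOperandsAndOutputsPerFusion();
}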
-
-TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) {
-  auto module = ParseAndReturnVerifiedModule(R"(HloModule dus_mof
-    fusion.1 {
-      p.0 = f16[50,96,1024]{2,1,0} parameter(0)
-      p.1 = f16[1,96,1024]{2,1,0} parameter(1)
-      c.0 = s32[3]{0} constant({0, 0, 0})
-      ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
-    }
-
-    fusion.2 {
-      p.0 = f16[50,96,1024]{2,1,0} parameter(0)
-      p.1 = f16[1,96,1024]{2,1,0} parameter(1)
-      c.0 = s32[3]{0} constant({0, 0, 0})
-      ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
-    }
-
-    ENTRY entry {
-      p.00 = f16[50,96,1024]{2,1,0} parameter(0)
-      p.01 = f16[50,96,1024]{2,1,0} parameter(1)
-      p.1 = f16[1,96,1024]{2,1,0} parameter(2)
-
-      f1 = f16[50,96,1024] fusion(p.00, p.1), kind=kLoop, calls=fusion.1
-      f2 = f16[50,96,1024] fusion(p.01, p.1), kind=kLoop, calls=fusion.2
-      ROOT tuple = (f16[50,96,1024],f16[50,96,1024]) tuple(f1, f2)
-    })")
-                    .value();
-  ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-// Check that we don't fuse too many reductions together.
-TEST_F(MultiOutputFusionTest, SharedMemoryBudget) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation0 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
-        to_apply=scalar_add_computation
-    }
-    fused_computation1 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
-        to_apply=scalar_add_computation
-    }
-    fused_computation2 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
-        to_apply=scalar_add_computation
-    }
-    fused_computation3 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
-        to_apply=scalar_add_computation
-    }
-    fused_computation4 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
-        to_apply=scalar_add_computation
-    }
-    fused_computation5 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
-        to_apply=scalar_add_computation
-    }
-    fused_computation6 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
-        to_apply=scalar_add_computation
-    }
-    fused_computation7 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
-        to_apply=scalar_add_computation
-    }
-    fused_computation8 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
-        to_apply=scalar_add_computation
-    }
-    fused_computation9 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
-        to_apply=scalar_add_computation
-    }
-    ENTRY computation {
-      zero = f32[] constant(0)
-      param0 = f32[64,64] parameter(0)
-      param1 = f32[64,64] parameter(1)
-      param2 = f32[64,64] parameter(2)
-      param3 = f32[64,64] parameter(3)
-      param4 = f32[64,64] parameter(4)
-      param5 = f32[64,64] parameter(5)
-      param6 = f32[64,64] parameter(6)
-      param7 = f32[64,64] parameter(7)
-      param8 = f32[64,64] parameter(8)
-      param9 = f32[64,64] parameter(9)
-      out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0
-      out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1
-      out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2
-      out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3
-      out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4
-      out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5
-      out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6
-      out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7
-      out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8
-      out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9
-      ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9)
-    }
-  )"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-
-  EXPECT_EQ(5, CountMultiOutputFusions(module.get()));
-}
-
-TEST_F(MultiOutputFusionTest, DoNotGroupTooManyReductions) {
-  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
-    fused_computation0 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
-        to_apply=scalar_add_computation
-    }
-    fused_computation1 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
-        to_apply=scalar_add_computation
-    }
-    fused_computation2 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
-        to_apply=scalar_add_computation
-    }
-    fused_computation3 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
-        to_apply=scalar_add_computation
-    }
-    fused_computation4 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
-        to_apply=scalar_add_computation
-    }
-    fused_computation5 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
-        to_apply=scalar_add_computation
-    }
-    fused_computation6 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
-        to_apply=scalar_add_computation
-    }
-    fused_computation7 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
-        to_apply=scalar_add_computation
-    }
-    fused_computation8 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
-        to_apply=scalar_add_computation
-    }
-    fused_computation9 {
-      p0 = f32[64,64] parameter(0)
-      p1 = f32[64,64] parameter(1)
-      p2 = f32[] parameter(2)
-      add = f32[64,64] add(p0, p1)
-      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
-        to_apply=scalar_add_computation
-    }
-    ENTRY computation {
-      zero = f32[] constant(0)
-      param0 = f32[64,64] parameter(0)
-      param1 = f32[64,64] parameter(1)
-      param2 = f32[64,64] parameter(2)
-      param3 = f32[64,64] parameter(3)
-      param4 = f32[64,64] parameter(4)
-      param5 = f32[64,64] parameter(5)
-      param6 = f32[64,64] parameter(6)
-      param7 = f32[64,64] parameter(7)
-      param8 = f32[64,64] parameter(8)
-      param9 = f32[64,64] parameter(9)
-      out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0
-      out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1
-      out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2
-      out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3
-      out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4
-      out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5
-      out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6
-      out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7
-      out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8
-      out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9
-      ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9)
-    }
-  )"))
-                    .value();
-  ASSERT_TRUE(mof_.Run(module.get()).value());
-
-  EXPECT_EQ(2, CountMultiOutputFusions(module.get()));
-}
-
-TEST_F(MultiOutputFusionTest, NoFusionToAvoidUsingTooMuchSharedMemory) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule xla_computation_update_step.10931
-
-%scalar_add_computation.1 (scalar_lhs.1: f64[], scalar_rhs.1: f64[]) -> f64[] {
-  %scalar_lhs.1 = f64[] parameter(0)
-  %scalar_rhs.1 = f64[] parameter(1)
-  ROOT %add.1257 = f64[] add(f64[] %scalar_lhs.1, f64[] %scalar_rhs.1)
-}
-
-%fused_computation.1 (param_0.8: f64[64,64], param_1.11: f64[64,64], param_2.9: f64[64,64]) -> (f64[64], f64[64]) {
-  %param_0.8 = f64[64,64]{1,0} parameter(0)
-  %param_1.11 = f64[64,64]{1,0} parameter(1)
-  %multiply.2 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.8, f64[64,64]{1,0} %param_1.11)
-  %constant_5217.3 = f64[] constant(0)
-  %broadcast.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.3), dimensions={}
-  %multiply.0 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.2, f64[64,64]{1,0} %broadcast.1)
-  %reduce.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.0, f64[] %constant_5217.3), dimensions={0}, to_apply=%scalar_add_computation.1
-  %param_2.9 = f64[64,64]{1,0} parameter(2)
-  %multiply.1514.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_2.9, f64[64,64]{1,0} %param_1.11)
-  %constant_5217.1.clone.1 = f64[] constant(0)
-  %broadcast.0.clone.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.1.clone.1), dimensions={}
-  %multiply.1341.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.1514.clone.0.clone.1, f64[64,64]{1,0} %broadcast.0.clone.1)
-  %reduce.630.clone.0.clone.1 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1341.clone.0.clone.1, f64[] %constant_5217.1.clone.1), dimensions={0}, to_apply=%scalar_add_computation.1
-  ROOT %tuple = (f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %reduce.0, f64[64]{0} %reduce.630.clone.0.clone.1)
-}
-
-%primitive_computation_add__1.6426 (parameter.6427: f64[], parameter.6428: f64[]) -> f64[] {
-  %parameter.6427 = f64[] parameter(0)
-  %parameter.6428 = f64[] parameter(1)
-  ROOT %add.6429 = f64[] add(f64[] %parameter.6427, f64[] %parameter.6428)
-}
-
-%fused_computation.2 (param_0.7: f64[64,64], param_1.9: f64[64,64]) -> f64[64] {
-  %param_0.7 = f64[64,64]{1,0} parameter(0)
-  %param_1.9 = f64[64,64]{1,0} parameter(1)
-  %multiply.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.7, f64[64,64]{1,0} %param_1.9)
-  %constant_5217.2 = f64[] constant(0)
-  ROOT %reduce.740.clone.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1, f64[] %constant_5217.2), dimensions={0}, to_apply=%primitive_computation_add__1.6426
-}
-
-ENTRY %reproducer (param_0.1090: f64[64,64], param_1.1377: f64[64,64], param_2.1948: f64[64,64]) -> (f64[64], f64[64], f64[64]) {
-  %param_0.1090 = f64[64,64]{1,0} parameter(0)
-  %param_1.1377 = f64[64,64]{1,0} parameter(1)
-  %param_2.1948 = f64[64,64]{1,0} parameter(2)
-  %fusion.1 = (f64[64]{0}, f64[64]{0}) fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377, f64[64,64]{1,0} %param_2.1948), kind=kInput, calls=%fused_computation.1
-  %get-tuple-element = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=0
-  %fusion.2 = f64[64]{0} fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377), kind=kInput, calls=%fused_computation.2
-  %get-tuple-element.1 = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=1
-  ROOT %tuple.428 = (f64[64]{0}, f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %get-tuple-element, f64[64]{0} %fusion.2, f64[64]{0} %get-tuple-element.1)
-}
-  )")
-                    .value();
-  EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, NoFusionToAvoidCodeDuplication) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule module
-
-and.reduce_sub_computation {
-  x = pred[] parameter(0)
-  y = pred[] parameter(1)
-  ROOT and = pred[] and(x, y)
-}
-
-fused_computation.1 {
-  param_4.658 = f32[2,20,256]{2,0,1} parameter(4)
-  slice.1385 = f32[2,1,256]{2,0,1} slice(param_4.658), slice={[0:2], [11:12], [0:256]}
-  constant.6847 = s32[] constant(0)
-  broadcast.4823 = s32[3]{0} broadcast(constant.6847), dimensions={}
-  param_9.415 = s32[3]{0} parameter(9)
-  compare.700 = pred[3]{0} compare(broadcast.4823, param_9.415), direction=LE
-  constant.6846 = pred[] constant(true)
-  reduce.221 = pred[] reduce(compare.700, constant.6846), dimensions={0}, to_apply=and.reduce_sub_computation
-  broadcast.2933 = pred[2,1,256]{2,0,1} broadcast(reduce.221), dimensions={}
-  param_5.528 = f32[2,512]{1,0} parameter(5)
-  slice.1384 = f32[2,256]{1,0} slice(param_5.528), slice={[0:2], [0:256]}
-  bitcast.341 = f32[2,1,256]{2,0,1} bitcast(slice.1384)
-  constant.5418 = f32[] constant(0)
-  broadcast.3227 = f32[2,1,256]{2,0,1} broadcast(constant.5418), dimensions={}
-  select.173 = f32[2,1,256]{2,0,1} select(broadcast.2933, bitcast.341, broadcast.3227)
-  add.573 = f32[2,1,256]{2,0,1} add(slice.1385, select.173)
-  param_0.299 = s32[] parameter(0)
-  constant.5157 = s32[] constant(11)
-  dynamic-update-slice.189 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.658, add.573, param_0.299, constant.5157, param_0.299)
-  slice.1383 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.189), slice={[0:2], [10:11], [0:256]}
-  constant.6800 = s32[] constant(0)
-  broadcast.4803 = s32[3]{0} broadcast(constant.6800), dimensions={}
-  param_8.484 = s32[3]{0} parameter(8)
-  compare.681 = pred[3]{0} compare(broadcast.4803, param_8.484), direction=LE
-  constant.6798 = pred[] constant(true)
-  reduce.203 = pred[] reduce(compare.681, constant.6798), dimensions={0}, to_apply=and.reduce_sub_computation
-  broadcast.2932 = pred[2,1,256]{2,0,1} broadcast(reduce.203), dimensions={}
-  param_3.1169 = f32[2,512]{1,0} parameter(3)
-  slice.1382 = f32[2,256]{1,0} slice(param_3.1169), slice={[0:2], [0:256]}
-  bitcast.340 = f32[2,1,256]{2,0,1} bitcast(slice.1382)
-  select.172 = f32[2,1,256]{2,0,1} select(broadcast.2932, bitcast.340, broadcast.3227)
-  add.572 = f32[2,1,256]{2,0,1} add(slice.1383, select.172)
-  constant.5154 = s32[] constant(10)
-  dynamic-update-slice.188 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.189, add.572, param_0.299, constant.5154, param_0.299)
-  slice.1381 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.188), slice={[0:2], [9:10], [0:256]}
-  constant.6794 = s32[] constant(0)
-  broadcast.4801 = s32[3]{0} broadcast(constant.6794), dimensions={}
-  param_7.478 = s32[3]{0} parameter(7)
-  compare.679 = pred[3]{0} compare(broadcast.4801, param_7.478), direction=LE
-  constant.6793 = pred[] constant(true)
-  reduce.201 = pred[] reduce(compare.679, constant.6793), dimensions={0}, to_apply=and.reduce_sub_computation
-  broadcast.2930 = pred[2,1,256]{2,0,1} broadcast(reduce.201), dimensions={}
-  param_2.1685 = f32[2,512]{1,0} parameter(2)
-  slice.1380 = f32[2,256]{1,0} slice(param_2.1685), slice={[0:2], [0:256]}
-  bitcast.339 = f32[2,1,256]{2,0,1} bitcast(slice.1380)
-  select.171 = f32[2,1,256]{2,0,1} select(broadcast.2930, bitcast.339, broadcast.3227)
-  add.571 = f32[2,1,256]{2,0,1} add(slice.1381, select.171)
-  constant.5153 = s32[] constant(9)
-  dynamic-update-slice.187 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.188, add.571, param_0.299, constant.5153, param_0.299)
-  slice.1379 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.187), slice={[0:2], [8:9], [0:256]}
-  constant.6788 = s32[] constant(0)
-  broadcast.4799 = s32[3]{0} broadcast(constant.6788), dimensions={}
-  param_6.495 = s32[3]{0} parameter(6)
-  compare.677 = pred[3]{0} compare(broadcast.4799, param_6.495), direction=LE
-  constant.6786 = pred[] constant(true)
-  reduce.199 = pred[] reduce(compare.677, constant.6786), dimensions={0}, to_apply=and.reduce_sub_computation
-  broadcast.2929 = pred[2,1,256]{2,0,1} broadcast(reduce.199), dimensions={}
-  param_1.1408 = f32[2,512]{1,0} parameter(1)
-  slice.1378 = f32[2,256]{1,0} slice(param_1.1408), slice={[0:2], [0:256]}
-  bitcast.338 = f32[2,1,256]{2,0,1} bitcast(slice.1378)
-  select.170 = f32[2,1,256]{2,0,1} select(broadcast.2929, bitcast.338, broadcast.3227)
-  add.570 = f32[2,1,256]{2,0,1} add(slice.1379, select.170)
-  constant.5152 = s32[] constant(8)
-  ROOT dynamic-update-slice.186 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.187, add.570, param_0.299, constant.5152, param_0.299)
-}
-
-fused_computation.2 {
-  param_4.655 = f32[2,20,256]{2,0,1} parameter(4)
-  slice.1369 = f32[2,1,256]{2,0,1} slice(param_4.655), slice={[0:2], [7:8], [0:256]}
-  param_6.483 = pred[] parameter(6)
-  broadcast.2927 = pred[2,1,256]{2,0,1} broadcast(param_6.483), dimensions={}
-  param_5.525 = f32[2,512]{1,0} parameter(5)
-  slice.1368 = f32[2,256]{1,0} slice(param_5.525), slice={[0:2], [0:256]}
-  bitcast.333 = f32[2,1,256]{2,0,1} bitcast(slice.1368)
-  constant.5415 = f32[] constant(0)
-  broadcast.3225 = f32[2,1,256]{2,0,1} broadcast(constant.5415), dimensions={}
-  select.161 = f32[2,1,256]{2,0,1} select(broadcast.2927, bitcast.333, broadcast.3225)
-  add.549 = f32[2,1,256]{2,0,1} add(slice.1369, select.161)
-  param_0.265 = s32[] parameter(0)
-  constant.5151 = s32[] constant(7)
-  dynamic-update-slice.185 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.655, add.549, param_0.265, constant.5151, param_0.265)
-  slice.1367 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.185), slice={[0:2], [6:7], [0:256]}
-  constant.6782 = s32[] constant(0)
-  broadcast.4797 = s32[3]{0} broadcast(constant.6782), dimensions={}
-  param_9.391 = s32[3]{0} parameter(9)
-  compare.675 = pred[3]{0} compare(broadcast.4797, param_9.391), direction=LE
-  constant.6781 = pred[] constant(true)
-  reduce.197 = pred[] reduce(compare.675, constant.6781), dimensions={0}, to_apply=and.reduce_sub_computation
-  broadcast.2926 = pred[2,1,256]{2,0,1} broadcast(reduce.197), dimensions={}
-  param_3.1167 = f32[2,512]{1,0} parameter(3)
-  slice.1366 = f32[2,256]{1,0} slice(param_3.1167), slice={[0:2], [0:256]}
-  bitcast.332 = f32[2,1,256]{2,0,1} bitcast(slice.1366)
-  select.160 = f32[2,1,256]{2,0,1} select(broadcast.2926, bitcast.332, broadcast.3225)
-  add.548 = f32[2,1,256]{2,0,1} add(slice.1367, select.160)
-  constant.5150 = s32[] constant(6)
-  dynamic-update-slice.184 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.185, add.548, param_0.265, constant.5150, param_0.265)
-  slice.1365 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.184), slice={[0:2], [5:6], [0:256]}
-  constant.6776 = s32[] constant(0)
-  broadcast.4794 = s32[3]{0} broadcast(constant.6776), dimensions={}
-  param_8.464 = s32[3]{0} parameter(8)
-  compare.673 = pred[3]{0} compare(broadcast.4794, param_8.464), direction=LE
-  constant.6775 = pred[] constant(true)
-  reduce.195 = pred[] reduce(compare.673, constant.6775), dimensions={0}, to_apply=and.reduce_sub_computation
-  broadcast.2925 = pred[2,1,256]{2,0,1} broadcast(reduce.195), dimensions={}
-  param_2.1684 = f32[2,512]{1,0} parameter(2)
-  slice.1364 = f32[2,256]{1,0} slice(param_2.1684), slice={[0:2], [0:256]}
-  bitcast.331 = f32[2,1,256]{2,0,1} bitcast(slice.1364)
-  select.159 = f32[2,1,256]{2,0,1} select(broadcast.2925, bitcast.331, broadcast.3225)
-  add.547 = f32[2,1,256]{2,0,1} add(slice.1365, select.159)
-  constant.5149 = s32[] constant(5)
-  dynamic-update-slice.183 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.184, add.547, param_0.265, constant.5149, param_0.265)
-  slice.1363 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.183), slice={[0:2], [4:5], [0:256]}
-  constant.6770 = s32[] constant(0)
-  broadcast.4792 = s32[3]{0} broadcast(constant.6770), dimensions={}
-  param_7.458 = s32[3]{0} parameter(7)
-  compare.671 = pred[3]{0} compare(broadcast.4792, param_7.458), direction=LE
-  constant.6769 = pred[] constant(true)
-  reduce.193 = pred[] reduce(compare.671, constant.6769), dimensions={0}, to_apply=and.reduce_sub_computation
-  broadcast.2924 = pred[2,1,256]{2,0,1} broadcast(reduce.193), dimensions={}
-  param_1.1405 = f32[2,512]{1,0} parameter(1)
-  slice.1362 = f32[2,256]{1,0} slice(param_1.1405), slice={[0:2], [0:256]}
-  bitcast.330 = f32[2,1,256]{2,0,1} bitcast(slice.1362)
-  select.158 = f32[2,1,256]{2,0,1} select(broadcast.2924, bitcast.330, broadcast.3225)
-  add.546 = f32[2,1,256]{2,0,1} add(slice.1363, select.158)
-  constant.5148 = s32[] constant(4)
-  ROOT dynamic-update-slice.182 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.183, add.546, param_0.265, constant.5148, param_0.265)
-}
-
-ENTRY main {
-  param_0.0 = s32[] parameter(0)
-  param_1.0 = f32[2,512]{1,0} parameter(1)
-  param_2.0 = f32[2,512]{1,0} parameter(2)
-  param_3.0 = f32[2,512]{1,0} parameter(3)
-  param_4.0 = f32[2,20,256]{2,1,0} parameter(4)
-  param_5.0 = f32[2,512]{1,0} parameter(5)
-  param_6.0 = s32[3]{0} parameter(6)
-  param_7.0 = s32[3]{0} parameter(7)
-  param_8.0 = s32[3]{0} parameter(8)
-  param_9.0 = s32[3]{0} parameter(9)
-  fusion.1 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, param_4.0, param_5.0, param_6.0, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.1
-  param_10 = pred[] parameter(10)
-  fusion.2 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, fusion.1, param_5.0, param_10, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.2
-  ROOT root = (f32[2,20,256]{2,0,1}, f32[2,20,256]{2,0,1}) tuple(fusion.1, fusion.2)
-}
-  )")
-                    .value();
-  EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, DoNotFuseRoot) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule module
-
-no_op {
-  arg_empty_tuple = () parameter(0)
-  ROOT tuple = () tuple()
-}
-
-fused_computation {
-  param_0 = f32[] parameter(0)
-  ROOT convert = s32[] convert(param_0)
-}
-
-ENTRY main {
-  param_0 = f32[] parameter(0)
-  fusion = s32[] fusion(param_0), kind=kLoop, calls=fused_computation
-  tuple = () tuple()
-  conditional = () conditional(fusion, tuple, tuple), branch_computations={no_op, no_op}
-  constant = f32[] constant(1)
-  ROOT root = f32[] add(param_0, constant)
-}
-  )")
-                    .value();
-  EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, CostBasedNoMerge) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule m
-
-region_3.63 {
-  Arg_0.64 = f32[] parameter(0)
-  Arg_1.65 = f32[] parameter(1)
-  ROOT add.66 = f32[] add(Arg_0.64, Arg_1.65)
-}
-
-fused_computation.29 {
-  param_0.161 = f32[5,32,32,1]{3,2,1,0} parameter(0)
-  multiply.208 = f32[5,32,32,1]{3,2,1,0} multiply(param_0.161, param_0.161)
-  bitcast.67 = f32[5,32,32]{2,1,0} bitcast(multiply.208)
-  constant.265 = f32[] constant(0)
-  reduce-window.81 = f32[5,30,31]{2,1,0} reduce-window(bitcast.67, constant.265), window={size=1x3x2}, to_apply=region_3.63
-  constant.264 = f32[] constant(0.166666672)
-  broadcast.204 = f32[5,30,31]{2,1,0} broadcast(constant.264), dimensions={}
-  multiply.205 = f32[5,30,31]{2,1,0} multiply(reduce-window.81, broadcast.204)
-  constant.263 = f32[] constant(0)
-  reduce-window.80 = f32[5,30,31]{2,1,0} reduce-window(multiply.205, constant.263), window={size=1x2x3 pad=0_0x0_1x1_1}, to_apply=region_3.63
-  constant.262 = f32[] constant(0.0138888899)
-  broadcast.201 = f32[5,30,31]{2,1,0} broadcast(constant.262), dimensions={}
-  multiply.204 = f32[5,30,31]{2,1,0} multiply(reduce-window.80, broadcast.201)
-  constant.261 = f32[] constant(0)
-  reduce-window.78 = f32[5,30,31]{2,1,0} reduce-window(multiply.204, constant.261), window={size=1x1x2 pad=0_0x0_0x0_1}, to_apply=region_3.63
-  constant.113 = f32[] constant(0.5)
-  broadcast.137 = f32[5,30,31]{2,1,0} broadcast(constant.113), dimensions={}
-  multiply.125 = f32[5,30,31]{2,1,0} multiply(reduce-window.78, broadcast.137)
-  constant.114 = f32[] constant(0)
-  ROOT reduce-window.17 = f32[5,30,31]{2,1,0} reduce-window(multiply.125, constant.114), window={size=1x2x1 pad=0_0x0_1x0_0}, to_apply=region_3.63
-}
-
-fused_computation.15 {
-  constant.108 = f32[] constant(0.5)
-  broadcast.105 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.108), dimensions={}
-  param_3.126 = f32[5,30,31]{2,1,0} parameter(3)
-  constant.295 = f32[] constant(0.25)
-  broadcast.234 = f32[5,30,31]{2,1,0} broadcast(constant.295), dimensions={}
-  multiply.242 = f32[5,30,31]{2,1,0} multiply(param_3.126, broadcast.234)
-  broadcast.233 = f32[5,5,30,31]{3,2,1,0} broadcast(multiply.242), dimensions={0,2,3}
-  param_2.154 = f32[5,30,31]{2,1,0} parameter(2)
-  multiply.241 = f32[5,30,31]{2,1,0} multiply(param_2.154, broadcast.234)
-  broadcast.232 = f32[5,5,30,31]{3,2,1,0} broadcast(multiply.241), dimensions={1,2,3}
-  multiply.240 = f32[5,5,30,31]{3,2,1,0} multiply(broadcast.233, broadcast.232)
-  param_1.188 = f32[5,5,30,31]{3,2,1,0} parameter(1)
-  constant.294 = f32[] constant(0.159154937)
-  broadcast.231 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.294), dimensions={}
-  multiply.239 = f32[5,5,30,31]{3,2,1,0} multiply(param_1.188, broadcast.231)
-  param_0.164 = f32[5,5,30,31]{3,2,1,0} parameter(0)
-  add.19 = f32[5,5,30,31]{3,2,1,0} add(multiply.239, param_0.164)
-  constant.293 = f32[] constant(0)
-  reduce-window.90 = f32[5,5,30,31]{3,2,1,0} reduce-window(add.19, constant.293), window={size=1x1x1x2 pad=0_0x0_0x0_0x0_1}, to_apply=region_3.63
-  constant.292 = f32[] constant(0.5)
-  broadcast.230 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.292), dimensions={}
-  multiply.238 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.90, broadcast.230)
-  constant.291 = f32[] constant(0)
-  reduce-window.89 = f32[5,5,30,31]{3,2,1,0} reduce-window(multiply.238, constant.291), window={size=1x1x2x1 pad=0_0x0_0x0_1x0_0}, to_apply=region_3.63
-  constant.290 = f32[] constant(0.25)
-  broadcast.229 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.290), dimensions={}
-  multiply.237 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.89, broadcast.229)
-  multiply.236 = f32[5,5,30,31]{3,2,1,0} multiply(multiply.237, multiply.237)
-  subtract.10 = f32[5,5,30,31]{3,2,1,0} subtract(multiply.240, multiply.236)
-  constant.289 = f32[] constant(0)
-  broadcast.228 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.289), dimensions={}
-  maximum.6 = f32[5,5,30,31]{3,2,1,0} maximum(subtract.10, broadcast.228)
-  sqrt.6 = f32[5,5,30,31]{3,2,1,0} sqrt(maximum.6)
-  constant.110 = f32[] constant(0)
-  broadcast.107 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.110), dimensions={}
-  compare.4 = pred[5,5,30,31]{3,2,1,0} compare(sqrt.6, broadcast.107), direction=EQ
-  constant.243 = f32[] constant(0.159154937)
-  broadcast.193 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.243), dimensions={}
-  multiply.194 = f32[5,5,30,31]{3,2,1,0} multiply(param_1.188, broadcast.193)
-  add.15 = f32[5,5,30,31]{3,2,1,0} add(multiply.194, param_0.164)
-  constant.242 = f32[] constant(0)
-  reduce-window.66 = f32[5,5,30,31]{3,2,1,0} reduce-window(add.15, constant.242), window={size=1x1x1x2 pad=0_0x0_0x0_0x0_1}, to_apply=region_3.63
-  constant.241 = f32[] constant(0.5)
-  broadcast.192 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.241), dimensions={}
-  multiply.193 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.66, broadcast.192)
-  constant.240 = f32[] constant(0)
-  reduce-window.65 = f32[5,5,30,31]{3,2,1,0} reduce-window(multiply.193, constant.240), window={size=1x1x2x1 pad=0_0x0_0x0_1x0_0}, to_apply=region_3.63
-  constant.239 = f32[] constant(0.25)
-  broadcast.191 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.239), dimensions={}
-  multiply.192 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.65, broadcast.191)
-  compare.3 = pred[5,5,30,31]{3,2,1,0} compare(multiply.192, broadcast.107), direction=EQ
-  and.1 = pred[5,5,30,31]{3,2,1,0} and(compare.4, compare.3)
-  constant.109 = f32[] constant(1.57079637)
-  broadcast.104 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.109), dimensions={}
-  atan2.1 = f32[5,5,30,31]{3,2,1,0} atan2(sqrt.6, multiply.192)
-  select.4 = f32[5,5,30,31]{3,2,1,0} select(and.1, broadcast.104, atan2.1)
-  constant.107 = f32[] constant(0.159154937)
-  broadcast.106 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.107), dimensions={}
-  multiply.100 = f32[5,5,30,31]{3,2,1,0} multiply(select.4, broadcast.106)
-  ROOT subtract.3 = f32[5,5,30,31]{3,2,1,0} subtract(broadcast.105, multiply.100)
-}
-
-fused_computation.4 {
-  param_0.172 = f32[5,30,31]{2,1,0} parameter(0)
-  constant.315 = f32[] constant(0.125)
-  broadcast.242 = f32[5,30,31]{2,1,0} broadcast(constant.315), dimensions={}
-  multiply.250 = f32[5,30,31]{2,1,0} multiply(param_0.172, broadcast.242)
-  constant.314 = f32[] constant(0)
-  reduce-window.100 = f32[5,30,31]{2,1,0} reduce-window(multiply.250, constant.314), window={size=1x3x3 pad=0_0x1_1x1_1}, to_apply=region_3.63
-  constant.79 = f32[] constant(0.055555556)
-  broadcast.85 = f32[5,30,31]{2,1,0} broadcast(constant.79), dimensions={}
-  multiply.80 = f32[5,30,31]{2,1,0} multiply(reduce-window.100, broadcast.85)
-  constant.81 = f32[] constant(0)
-  reduce-window.1 = f32[5,30,31]{2,1,0} reduce-window(multiply.80, constant.81), window={size=1x3x3 pad=0_0x1_1x1_1}, to_apply=region_3.63
-  constant.80 = f32[] constant(0.111111112)
-  broadcast.86 = f32[5,30,31]{2,1,0} broadcast(constant.80), dimensions={}
-  multiply.79 = f32[5,30,31]{2,1,0} multiply(reduce-window.1, broadcast.86)
-  bitcast.26 = f32[5,930]{1,0} bitcast(multiply.79)
-  ROOT reduce.8 = f32[5]{0} reduce(bitcast.26, constant.81), dimensions={1}, to_apply=region_3.63
-}
-
-ENTRY e {
-  Arg_0.1 = f32[5,32,32,1]{3,2,1,0} parameter(0)
-  p1 = f32[5,5,30,31]{3,2,1,0} parameter(1)
-  p2 = f32[5,5,30,31]{3,2,1,0} parameter(2)
-  p3 = f32[5,30,31]{2,1,0} parameter(3)
-  fusion.29 = f32[5,30,31]{2,1,0} fusion(Arg_0.1), kind=kLoop, calls=fused_computation.29
-  fusion.15 = f32[5,5,30,31]{3,2,1,0} fusion(p2, p1, p3, fusion.29), kind=kLoop, calls=fused_computation.15
-  ROOT fusion.4 = f32[5]{0} fusion(fusion.29), kind=kInput, calls=fused_computation.4
-})")
-                    .value();
-  EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, NoOverlappingRead) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    fused_computation_1 {
-      p0.1 = f32[100,200]{1,0} parameter(0)
-      slice.0 = f32[50,100]{1,0} slice(p0.1), slice={[0:50],[0:100]}
-      mul = f32[50,100]{1,0} multiply(slice.0, slice.0)
-      exp = f32[50,100]{1,0} exponential(slice.0)
-      ROOT tuple = (f32[50,100]{1,0}, f32[50,100]{1,0}) tuple(mul, exp)
-    }
-
-    fused_computation_2 {
-      p0.2 = f32[100,200]{1,0} parameter(0)
-      slice.1 = f32[50,100]{1,0} slice(p0.2), slice={[0:50],[100:200]}
-      const.2 = f32[] constant(0)
-      broadcast = f32[50,100]{1,0} broadcast(const.2), dimensions={}
-      ROOT add = f32[50,100]{1,0} add(slice.1, broadcast)
-    }
-
-    ENTRY entry {
-      p0 = f32[100,200]{1,0} parameter(0)
-      fusion.1 = (f32[50,100]{1,0}, f32[50,100]{1,0}) fusion(p0), kind=kLoop,
-        calls=fused_computation_1
-      gte0 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=0
-      gte1 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=1
-      fusion.2 = f32[50,100]{1,0} fusion(p0), kind=kLoop,
-        calls=fused_computation_2
-      ROOT root = (f32[50,100]{1,0}, f32[50,100]{1,0}, f32[50,100]{1,0})
-        tuple(gte0, gte1, fusion.2)
-    })")
-                    .value();
-
-  EXPECT_FALSE(mof_.Run(module.get()).value());
-}
-
-TEST_F(MultiOutputFusionTest, OverlappingRead) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    fused_computation_1 {
-      p0.1 = f32[100,200]{1,0} parameter(0)
-      slice.0 = f32[50,100]{1,0} slice(p0.1), slice={[0:50],[50:150]}
-      mul = f32[50,100]{1,0} multiply(slice.0, slice.0)
-      exp = f32[50,100]{1,0} exponential(slice.0)
-      ROOT tuple = (f32[50,100]{1,0}, f32[50,100]{1,0}) tuple(mul, exp)
-    }
-
-    fused_computation_2 {
-      p0.2 = f32[100,200]{1,0} parameter(0)
-      slice.1 = f32[50,100]{1,0} slice(p0.2), slice={[30:80],[20:120]}
-      const.2 = f32[] constant(0)
-      broadcast = f32[50,100]{1,0} broadcast(const.2), dimensions={}
-      ROOT add = f32[50,100]{1,0} add(slice.1, broadcast)
-    }
-
-    ENTRY entry {
-      p0 = f32[100,200]{1,0} parameter(0)
-      fusion.1 = (f32[50,100]{1,0}, f32[50,100]{1,0}) fusion(p0), kind=kLoop,
-        calls=fused_computation_1
-      gte0 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=0
-      gte1 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=1
-      fusion.2 = f32[50,100]{1,0} fusion(p0), kind=kLoop,
-        calls=fused_computation_2
-      ROOT root = (f32[50,100]{1,0}, f32[50,100]{1,0}, f32[50,100]{1,0})
-        tuple(gte0, gte1, fusion.2)
-    })")
-                    .value();
-
-  EXPECT_TRUE(mof_.Run(module.get()).value());
-}
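
The two tests above differ only in whether the slices taken from p0 read overlapping regions. A hypothetical per-dimension interval check (not necessarily the pass's actual heuristic) captures the distinction: two slices of the same operand overlap iff their [start, limit) ranges intersect in every dimension.

#include <cstdint>
#include <utility>
#include <vector>

// Sketch only: `a` and `b` hold per-dimension [start, limit) ranges of two
// slices of the same operand (same rank assumed); the slices overlap iff
// every dimension's ranges intersect.
bool SlicesOverlap(const std::vector<std::pair<int64_t, int64_t>>& a,
                   const std::vector<std::pair<int64_t, int64_t>>& b) {
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i].first >= b[i].second || b[i].first >= a[i].second) {
      return false;  // Disjoint in this dimension, so no overlap at all.
    }
  }
  return true;
}

For NoOverlappingRead the column ranges [0,100) and [100,200) are disjoint, so the two fusions share no reads of p0 and the pass declines to merge them; for OverlappingRead the ranges [0,50)x[50,150) and [30,80)x[20,120) intersect, so merging lets the fusions share reads of p0 and the pass fuses them.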
-
-class TransposeMultiOutputFusionTest : public MultiOutputFusionTest {};
-
-TEST_F(TransposeMultiOutputFusionTest, MultipleCopies) {
-  const char* hlo = R"(
-HloModule module
-
-fused_computation {
-  param_0.1 = f32[16,32]{1,0} parameter(0)
-  s.1 = f32[16,32]{1,0} sqrt(param_0.1)
-  ROOT c.1 = f32[16,32]{0,1} copy(s.1)
-}
-
-ENTRY main {
-  p = f32[16,32]{1,0} parameter(0)
-  fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
-  c1 = f32[16,32]{0,1} copy(p)
-  ROOT t = (f32[16,32]{0,1}, f32[16,32]{0,1}) tuple(fusion, c1)
-}
-  )";
-
-  CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32], f32[16,32]) {
-// CHECK-NEXT:   [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
-// CHECK-NEXT:   [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
-// CHECK-NEXT:   [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
-// CHECK-NEXT:   [[c1_1_3:%[^ ]+]] = f32[16,32]{0,1} copy([[param_0_1_0]])
-// CHECK-NEXT:   ROOT [[tuple_4:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{0,1}) tuple([[c_1_2]], [[c1_1_3]])
-// CHECK-NEXT: }
-
-// CHECK: [[fusion_0:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{0,1}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
-)");
-}
-
-TEST_F(TransposeMultiOutputFusionTest, MultipleTransposes) {
-  const char* hlo = R"(
-HloModule module
-
-fused_computation {
-  param_0.1 = f32[16,32]{1,0} parameter(0)
-  s.1 = f32[16,32]{1,0} sqrt(param_0.1)
-  ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
-}
-
-ENTRY main {
-  p = f32[16,32]{1,0} parameter(0)
-  fusion = f32[32,16]{1,0} fusion(p), kind=kInput, calls=fused_computation
-  c1 = f32[32,16]{1,0} transpose(p), dimensions={1,0}
-  ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(fusion, c1)
-}
-  )";
-
-  CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[32,16], f32[32,16]) {
-// CHECK-NEXT:   [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
-// CHECK-NEXT:   [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
-// CHECK-NEXT:   [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]), dimensions={1,0}
-// CHECK-NEXT:   [[c1_1_3:%[^ ]+]] = f32[32,16]{1,0} transpose([[param_0_1_0]]), dimensions={1,0}
-// CHECK-NEXT:   ROOT [[tuple_4:%[^ ]+]] = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple([[c_1_2]], [[c1_1_3]])
-// CHECK-NEXT: }
-// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16]{1,0}, f32[32,16]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
-)");
-}
-
-TEST_F(TransposeMultiOutputFusionTest, CopyAndTranspose) {
-  const char* hlo = R"(
-HloModule module
-
-fused_computation {
-  param_0.1 = f32[16,32]{1,0} parameter(0)
-  s.1 = f32[16,32]{1,0} sqrt(param_0.1)
-  ROOT c.1 = f32[16,32]{0,1} copy(s.1)
-}
-
-ENTRY main {
-  p = f32[16,32]{1,0} parameter(0)
-  fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
-  c1 = f32[32,16]{1,0} transpose(p), dimensions={1,0}
-  ROOT t = (f32[16,32]{0,1}, f32[32,16]{1,0}) tuple(fusion, c1)
-}
-  )";
-
-  CheckGpuMultiOutputFusion(hlo, R"(
-  // CHECK: %fused_computation ({{[^ ]+}} f32[16,32]) -> (f32[16,32], f32[32,16]) {
-  // CHECK-NEXT: [[param_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
-  // CHECK-NEXT: [[s_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0]])
-  // CHECK-NEXT: [[copy:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1]])
-  // CHECK-NEXT: [[transpose:[^ ]+]] = f32[32,16]{1,0} transpose([[param_0]]), dimensions={1,0}
-  // CHECK-NEXT: ROOT {{[^ ]+}} = (f32[16,32]{0,1}, f32[32,16]{1,0}) tuple([[copy]], [[transpose]])
-  // CHECK: %fusion = (f32[16,32]{0,1}, f32[32,16]{1,0}) fusion(%{{.*}}), kind=kInput, calls=%fused_computation
-)");
-}
-
-TEST_F(TransposeMultiOutputFusionTest, MultipleCopiesDifferentTypes) {
-  const char* hlo = R"(
-HloModule module
-
-fused_computation {
-  param_0.1 = f16[16,32]{1,0} parameter(0)
-  s.1 = f32[16,32]{1,0} convert(param_0.1)
-  ROOT c.1 = f32[16,32]{0,1} copy(s.1)
-}
-
-ENTRY main {
-  p = f16[16,32]{1,0} parameter(0)
-  fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
-  c1 = f16[16,32]{0,1} copy(p)
-  ROOT t = (f32[16,32]{0,1}, f16[16,32]{0,1}) tuple(fusion, c1)
-}
-  )";
-
-  CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation (param_0.1: f16[16,32]) -> (f32[16,32], f16[16,32]) {
-// CHECK-NEXT:   [[param_0_1_0:%[^ ]+]] = f16[16,32]{1,0} parameter(0)
-// CHECK-NEXT:   [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} convert([[param_0_1_0]])
-// CHECK-NEXT:   [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
-// CHECK-NEXT:   [[c1_1_3:%[^ ]+]] = f16[16,32]{0,1} copy([[param_0_1_0]])
-// CHECK-NEXT:   ROOT [[tuple_4:%[^ ]+]] = (f32[16,32]{0,1}, f16[16,32]{0,1}) tuple([[c_1_2]], [[c1_1_3]])
-// CHECK:   [[fusion_5:%[^ ]+]] = (f32[16,32]{0,1}, f16[16,32]{0,1}) fusion([[p_6:%[^ ]+]]), kind=kInput, calls=[[fused_computation_7:%[^ ]+]]
-)");
-}
-
-// Do not group copy and reduction.
-TEST_F(TransposeMultiOutputFusionTest, TiledReduceCopy) {
-  const char* hlo = R"(
-HloModule module
-
-add {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT add = add(lhs, rhs)
-}
-
-fused_computation {
-  param_0.1 = f32[16,32]{1,0} parameter(0)
-  s.1 = f32[16,32]{1,0} sqrt(param_0.1)
-  ROOT c.1 = f32[16,32]{0,1} copy(s.1)
-}
-
-ENTRY main {
-  p = f32[16,32]{1,0} parameter(0)
-  fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
-  z = f32[] constant(0)
-  r1 = f32[32]{0} reduce(p, z), dimensions={0}, to_apply=add
-  ROOT t = (f32[16,32]{0,1}, f32[32]{0}) tuple(fusion, r1)
-}
-  )";
-
-  CheckGpuMultiOutputFusion(hlo, std::nullopt);
-}
-
-// Do not group incompatible transposes.
-TEST_F(TransposeMultiOutputFusionTest, IncompatibleTransposes) {
-  const char* hlo = R"(
-HloModule module
-
-fused_computation {
-  param_0.1 = f32[18,16,32]{2,1,0} parameter(0)
-  param_1.1 = f32[32,16,18]{2,1,0} parameter(1)
-  s.1 = f32[18,16,32]{2,1,0} sqrt(param_0.1)
-  t.1 = f32[32,16,18]{2,1,0} transpose(s.1), dimensions={2,1,0}
-  sub.1 = f32[32,16,18]{2,1,0} subtract(t.1, param_1.1)
-  exp.1 = f32[32,16,18]{2,1,0} exponential(sub.1)
-  ROOT add.1 = f32[32,16,18]{2,1,0} add(exp.1, exp.1)
-}
-
-fused_computation.2 {
-  param_0.2 = f32[18,16,32]{2,1,0} parameter(0)
-  s.2 = f32[18,16,32]{2,1,0} sqrt(param_0.2)
-  ROOT t.2 = f32[18,32,16]{2,1,0} transpose(s.2), dimensions={0,2,1}
-}
-
-ENTRY main {
-  p = f32[18,16,32]{2,1,0} parameter(0)
-  p2 = f32[32,16,18]{2,1,0} parameter(1)
-  fusion = f32[32,16,18]{2,1,0} fusion(p, p2), kind=kLoop, calls=fused_computation
-  fusion2 = f32[18,32,16]{2,1,0} fusion(p), kind=kInput, calls=fused_computation.2
-  ROOT t = (f32[32,16,18]{2,1,0}, f32[18,32,16]{2,1,0}) tuple(fusion, fusion2)
-}
-  )";
-
-  CheckGpuMultiOutputFusion(hlo, std::nullopt);
-}
-
-// A variation of the test above, where no CSE was run.
-TEST_F(TransposeMultiOutputFusionTest, TransposesNoCSE) {
-  const char* hlo = R"(
-HloModule module
-
-fused_computation {
-  param_0.1 = f32[18,16,32]{2,1,0} parameter(0)
-  param_1.1 = f32[32,16,18]{2,1,0} parameter(1)
-  s.1 = f32[18,16,32]{2,1,0} sqrt(param_0.1)
-  t.1 = f32[32,16,18]{2,1,0} transpose(s.1), dimensions={2,1,0}
-  sub.1 = f32[32,16,18]{2,1,0} subtract(t.1, param_1.1)
-  exp.1 = f32[32,16,18]{2,1,0} exponential(sub.1)
-  exp.2 = f32[32,16,18]{2,1,0} exponential(sub.1)
-  ROOT add.1 = f32[32,16,18]{2,1,0} add(exp.1, exp.2)
-}
-
-fused_computation.2 {
-  param_0.2 = f32[18,16,32]{2,1,0} parameter(0)
-  s.2 = f32[18,16,32]{2,1,0} sqrt(param_0.2)
-  ROOT t.2 = f32[18,32,16]{2,1,0} transpose(s.2), dimensions={0,2,1}
-}
-
-ENTRY main {
-  p = f32[18,16,32]{2,1,0} parameter(0)
-  p2 = f32[32,16,18]{2,1,0} parameter(1)
-  fusion = f32[32,16,18]{2,1,0} fusion(p, p2), kind=kLoop, calls=fused_computation
-  fusion2 = f32[18,32,16]{2,1,0} fusion(p), kind=kInput, calls=fused_computation.2
-  ROOT t = (f32[32,16,18]{2,1,0}, f32[18,32,16]{2,1,0}) tuple(fusion, fusion2)
-}
-  )";
-
-  CheckGpuMultiOutputFusion(hlo, std::nullopt);
-}
-
-TEST_F(TransposeMultiOutputFusionTest, CopyAndInput) {
-  const char* hlo = R"(
-HloModule module
-
-fused_computation {
-  param_0.1 = f32[16,32]{1,0} parameter(0)
-  s.1 = f32[16,32]{1,0} sqrt(param_0.1)
-  ROOT c.1 = f32[16,32]{0,1} copy(s.1)
-}
-
-ENTRY main {
-  p = f32[16,32]{1,0} parameter(0)
-  fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
-  c1 = exponential(p)
-  ROOT t = tuple(fusion, c1)
-}
-  )";
-
-  CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32], f32[16,32]) {
-// CHECK-NEXT:   [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
-// CHECK-NEXT:   [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
-// CHECK-NEXT:   [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
-// CHECK-NEXT:   [[c1_1_3:%[^ ]+]] = f32[16,32]{1,0} exponential([[param_0_1_0]])
-// CHECK-NEXT:   ROOT [[tuple_4:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{1,0}) tuple([[c_1_2]], [[c1_1_3]])
-// CHECK-NEXT: }
-// CHECK:   [[fusion_0:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
-)");
-}
-
-TEST_F(TransposeMultiOutputFusionTest, TransposeAndInputEpilogueFusion) {
-  const char* hlo = R"(
-HloModule module
-
-fused_computation {
-  param_0.1 = f32[16,32]{1,0} parameter(0)
-  s.1 = f32[16,32]{1,0} sqrt(param_0.1)
-  t.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
-  ROOT out = f32[32,16,1]{2,1,0} bitcast(t.1)
-}
-
-ENTRY main {
-  p = f32[16,32]{1,0} parameter(0)
-  fusion = f32[32,16,1]{2,1,0} fusion(p), kind=kInput, calls=fused_computation
-  c1 = exponential(p)
-  ROOT t = tuple(fusion, c1)
-}
-  )";
-
-  CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation
-// CHECK-NEXT:   [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
-// CHECK-NEXT:   [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
-// CHECK-NEXT:   [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]])
-// CHECK-NEXT:   [[out_3:%[^ ]+]] = f32[32,16,1]{2,1,0} bitcast([[c_1_2]])
-// CHECK-NEXT:   [[c1_1_4:%[^ ]+]] = f32[16,32]{1,0} exponential([[param_0_1_0]])
-// CHECK-NEXT:   ROOT [[tuple_5:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) tuple([[out_3]], [[c1_1_4]])
-// CHECK-NEXT: }
-// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
-)");
-}
-
-class ReduceMultiOutputFusionTest : public MultiOutputFusionTest {};
-
-TEST_F(ReduceMultiOutputFusionTest, ReduceAndLoop) {
-  const char* hlo = R"(
-HloModule module
-
-add {
-  a = f32[] parameter(0)
-  b = f32[] parameter(1)
-  ROOT c = f32[] add(a, b)
-}
-
-fused_reduction {
-  p = f32[200] parameter(0)
-  z = f32[] constant(0)
-  e = f32[200] exponential(p)
-  ROOT r = f32[] reduce(e, z), dimensions={0}, to_apply=add
-}
-
-fused_elementwise {
-  p = f32[200] parameter(0)
-  ROOT r = f32[200] sqrt(p)
-}
-
-ENTRY computation {
-  p = f32[200] parameter(0)
-  o1 = f32[200] fusion(p), kind=kLoop, calls=fused_elementwise
-  o2 = f32[] fusion(p), kind=kInput, calls=fused_reduction
-  ROOT out = (f32[200], f32[]) tuple(o1, o2)
-}
-
-)";
-
-  CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_elementwise
-// CHECK-NEXT:  [[p_1_0:%[^ ]+]] = f32[200]{0} parameter(0)
-// CHECK-NEXT:  [[r_1_1:%[^ ]+]] = f32[200]{0} sqrt([[p_1_0]])
-// CHECK-NEXT:  [[e_2:%[^ ]+]].clone.1 = f32[200]{0} exponential([[p_1_0]])
-// CHECK-NEXT:  [[z_3:%[^ ]+]].clone.1 = f32[] constant(0)
-// CHECK-NEXT:  [[r_4:%[^ ]+]].clone.1 = f32[] reduce([[e_2]].clone.1, [[z_3]].clone.1), dimensions={0}, to_apply=[[add_5:%[^ ]+]]
-// CHECK-NEXT:  ROOT [[tuple_6:%[^ ]+]] = (f32[200]{0}, f32[]) tuple([[r_1_1]], [[r_4]].clone.1)
-// CHECK-NEXT:}
-// CHECK: [[o1_0:%[^ ]+]] = (f32[200]{0}, f32[]) fusion([[p_2_1:%[^ ]+]]), kind=kInput, calls=[[fused_elementwise_2:%[^ ]+]]
-  )");
-}
-
-TEST_F(ReduceMultiOutputFusionTest, ReduceAndLoopDifferentShape) {
-  const char* hlo = R"(
-HloModule module
-
-add {
-  a = f32[] parameter(0)
-  b = f32[] parameter(1)
-  ROOT c = f32[] add(a, b)
-}
-
-fused_reduction {
-  p = f32[10,20] parameter(0)
-  z = f32[] constant(0)
-  e = f32[10,20] exponential(p)
-  b = f32[200] bitcast(e)
-  ROOT r = f32[] reduce(b, z), dimensions={0}, to_apply=add
-}
-
-fused_elementwise {
-  p = f32[10,20] parameter(0)
-  ROOT r = f32[10,20] sqrt(p)
-}
-
-ENTRY computation {
-  p = f32[10,20] parameter(0)
-  o1 = f32[10,20] fusion(p), kind=kLoop, calls=fused_elementwise
-  o2 = f32[] fusion(p), kind=kInput, calls=fused_reduction
-  ROOT out = (f32[10,20], f32[]) tuple(o1, o2)
-}
-)";
-
-  CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_elementwise (p.1: f32[10,20]) -> (f32[10,20], f32[]) {
-// CHECK-NEXT:   [[p_1_0:%[^ ]+]] = f32[10,20]{1,0} parameter(0)
-// CHECK-NEXT:   [[r_1_1:%[^ ]+]] = f32[10,20]{1,0} sqrt([[p_1_0]])
-// CHECK-NEXT:   [[e_2:%[^ ]+]].clone.1 = f32[10,20]{1,0} exponential([[p_1_0]])
-// CHECK-NEXT:   [[b_1_3:%[^ ]+]].clone.1 = f32[200]{0} bitcast([[e_2]].clone.1)
-// CHECK-NEXT:   [[z_4:%[^ ]+]].clone.1 = f32[] constant(0)
-// CHECK-NEXT:   [[r_5:%[^ ]+]].clone.1 = f32[] reduce([[b_1_3]].clone.1, [[z_4]].clone.1), dimensions={0}, to_apply=[[add_6:%[^ ]+]]
-// CHECK-NEXT:   ROOT [[tuple_7:%[^ ]+]] = (f32[10,20]{1,0}, f32[]) tuple([[r_1_1]], [[r_5]].clone.1)
-// CHECK-NEXT: }
-  )");
-}
-
-TEST_F(ReduceMultiOutputFusionTest, ReduceAndLoopDifferentShapeDifferentType) {
-  const char* hlo = R"(
-HloModule module, entry_computation_layout={(f16[100,200]{1,0},f32[],f32[])->(f16[100,200]{1,0}, f32[])}
-
-max {
-  a = f32[] parameter(0)
-  b = f32[] parameter(1)
-  ROOT c = f32[] maximum(a, b)
-}
-
-fused_computation {
-  one_5 = f32[] constant(1)
-  one_b.5 = f32[100,200]{1,0} broadcast(one_5), dimensions={}
-  param_1.15 = f16[100,200]{1,0} parameter(1)
-  c.6 = f32[100,200]{1,0} convert(param_1.15)
-  param_0.11 = f32[] parameter(0)
-  b.6 = f32[100,200]{1,0} broadcast(param_0.11), dimensions={}
-  d.5 = f32[100,200]{1,0} divide(c.6, b.6)
-  a.6 = f32[100,200]{1,0} add(one_b.5, d.5)
-  bitcast.1 = f32[20000]{0} bitcast(a.6)
-  z_1 = f32[] constant(0)
-  ROOT r.1 = f32[] reduce(bitcast.1, z_1), dimensions={0}, to_apply=max
-}
-
-fused_computation.1 {
-  one_3 = f32[] constant(1)
-  one_b.3 = f32[100,200]{1,0} broadcast(one_3), dimensions={}
-  param_2.7 = f16[100,200]{1,0} parameter(2)
-  c.4 = f32[100,200]{1,0} convert(param_2.7)
-  param_1.10 = f32[] parameter(1)
-  b.4 = f32[100,200]{1,0} broadcast(param_1.10), dimensions={}
-  d.3 = f32[100,200]{1,0} divide(c.4, b.4)
-  a.4 = f32[100,200]{1,0} add(one_b.3, d.3)
-  param_0.8 = f32[] parameter(0)
-  output_scale_broadcast.1 = f32[100,200]{1,0} broadcast(param_0.8), dimensions={}
-  a_scaled.1 = f32[100,200]{1,0} multiply(a.4, output_scale_broadcast.1)
-  ROOT a_scaled_converted.1 = f16[100,200]{1,0} convert(a_scaled.1)
-}
-
-ENTRY computation {
-  output_scale = f32[] parameter(2)
-  input_scale = f32[] parameter(1)
-  p = f16[100,200]{1,0} parameter(0)
-  fusion.1 = f16[100,200]{1,0} fusion(output_scale, input_scale, p), kind=kLoop, calls=fused_computation.1
-  fusion = f32[] fusion(input_scale, p), kind=kInput, calls=fused_computation
-  ROOT out = (f16[100,200]{1,0}, f32[]) tuple(fusion.1, fusion)
-}
-)";
-
-  CheckGpuMultiOutputFusion(hlo, R"(
-// CHECK: %fused_computation.1 (param_0.8: f32[], param_1.10: f32[], param_2.7: f16[100,200]) -> (f16[100,200], f32[]) {
-// CHECK-NEXT:   [[one_3_0:%[^ ]+]] = f32[] constant(1)
-// CHECK-NEXT:   [[one_b_3_1:%[^ ]+]] = f32[100,200]{1,0} broadcast([[one_3_0]]), dimensions={}
-// CHECK-NEXT:   [[param_2_7_2:%[^ ]+]] = f16[100,200]{1,0} parameter(2)
-// CHECK-NEXT:   [[c_4_3:%[^ ]+]] = f32[100,200]{1,0} convert([[param_2_7_2]])
-// CHECK-NEXT:   [[param_1_10_4:%[^ ]+]] = f32[] parameter(1)
-// CHECK-NEXT:   [[b_4_5:%[^ ]+]] = f32[100,200]{1,0} broadcast([[param_1_10_4]]), dimensions={}
-// CHECK-NEXT:   [[d_3_6:%[^ ]+]] = f32[100,200]{1,0} divide([[c_4_3]], [[b_4_5]])
-// CHECK-NEXT:   [[a_4_7:%[^ ]+]] = f32[100,200]{1,0} add([[one_b_3_1]], [[d_3_6]])
-// CHECK-NEXT:   [[param_0_8_8:%[^ ]+]] = f32[] parameter(0)
-// CHECK-NEXT:   [[output_scale_broadcast_1_9:%[^ ]+]] = f32[100,200]{1,0} broadcast([[param_0_8_8]]), dimensions={}
-// CHECK-NEXT:   [[a_scaled_1_10:%[^ ]+]] = f32[100,200]{1,0} multiply([[a_4_7]], [[output_scale_broadcast_1_9]])
-// CHECK-NEXT:   [[a_scaled_converted_1_11:%[^ ]+]] = f16[100,200]{1,0} convert([[a_scaled_1_10]])
-// CHECK-NEXT:   [[one_5_12:%[^ ]+]].clone.1 = f32[] constant(1)
-// CHECK-NEXT:   [[one_b_5_13:%[^ ]+]].clone.1 = f32[100,200]{1,0} broadcast([[one_5_12]].clone.1), dimensions={}
-// CHECK-NEXT:   [[c_6_14:%[^ ]+]].clone.1 = f32[100,200]{1,0} convert([[param_2_7_2]])
-// CHECK-NEXT:   [[b_6_15:%[^ ]+]].clone.1 = f32[100,200]{1,0} broadcast([[param_1_10_4]]), dimensions={}
-// CHECK-NEXT:   [[d_5_16:%[^ ]+]].clone.1 = f32[100,200]{1,0} divide([[c_6_14]].clone.1, [[b_6_15]].clone.1)
-// CHECK-NEXT:   [[a_6_17:%[^ ]+]].clone.1 = f32[100,200]{1,0} add([[one_b_5_13]].clone.1, [[d_5_16]].clone.1)
-// CHECK-NEXT:   [[bitcast_1_18:%[^ ]+]].clone.1 = f32[20000]{0} bitcast([[a_6_17]].clone.1)
-// CHECK-NEXT:   [[z_1_19:%[^ ]+]].clone.1 = f32[] constant(0)
-// CHECK-NEXT:   [[r_1_20:%[^ ]+]].clone.1 = f32[] reduce([[bitcast_1_18]].clone.1, [[z_1_19]].clone.1), dimensions={0}, to_apply=[[max_21:%[^ ]+]]
-// CHECK-NEXT:   ROOT [[tuple_22:%[^ ]+]] = (f16[100,200]{1,0}, f32[]) tuple([[a_scaled_converted_1_11]], [[r_1_20]].clone.1)
-// CHECK-NEXT: }
-  )");
-}
-
-TEST_F(ReduceMultiOutputFusionTest, GetTupleElementMakeTupleSequence) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    fusion {
-      p0 = s32[] parameter(0)
-      p1 = s32[32] parameter(1)
-      custom-call = (bf16[], s32[], u32[]) custom-call(p1), custom_call_target="my_custom_call"
-      get-tuple-element.0 = bf16[] get-tuple-element(custom-call), index=0
-      get-tuple-element.1 = s32[] get-tuple-element(custom-call), index=1
-      bitcast = s32[1] bitcast(get-tuple-element.1)
-      dynamic-update-slice = s32[32] dynamic-update-slice(p1, bitcast, p0)
-      get-tuple-element.2 = u32[] get-tuple-element(custom-call), index=2
-      ROOT tuple.30 = (bf16[], s32[32], u32[]) tuple(get-tuple-element.0, dynamic-update-slice, get-tuple-element.2)
-    }
-
-    ENTRY entry{
-      p0 = s32[] parameter(0)
-      bitcast = s32[32] bitcast(p0)
-      ROOT address_computation.7.0 = (bf16[], s32[32], u32[]) fusion(p0, bitcast), kind=kCustom, calls=fusion
-    }
-  )")
-                    .value();
-
-  ASSERT_FALSE(mof_.Run(module.get()).value());
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc
index 40eff87..7925ebc 100644
--- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc
+++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc
@@ -15,6 +15,7 @@
 
 #include "xla/service/gpu/nvptx_compiler.h"
 
+#include <algorithm>
 #include <array>
 #include <cstdint>
 #include <fstream>
@@ -52,34 +53,35 @@
 #include "xla/service/dump.h"
 #include "xla/service/float_normalization.h"
 #include "xla/service/float_support.h"
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/autotuning/conv_algorithm_picker.h"
+#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h"
+#include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h"
 #include "xla/service/gpu/buffer_sharing.h"
-#include "xla/service/gpu/conv_algorithm_picker.h"
-#include "xla/service/gpu/cublas_pad_for_gemms.h"
 #include "xla/service/gpu/cublas_padding_requirements.h"
-#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
-#include "xla/service/gpu/cudnn_fused_mha_rewriter.h"
-#include "xla/service/gpu/cudnn_fused_mha_transpose_fusion.h"
-#include "xla/service/gpu/cudnn_fusion_compiler.h"
-#include "xla/service/gpu/cudnn_norm_rewriter.h"
-#include "xla/service/gpu/cudnn_pad_for_convolutions.h"
-#include "xla/service/gpu/cudnn_simplify_padding.h"
-#include "xla/service/gpu/cudnn_vectorize_convolutions.h"
-#include "xla/service/gpu/cudnn_workspace_rewriter.h"
-#include "xla/service/gpu/cusolver_rewriter.h"
-#include "xla/service/gpu/dot_sparsity_rewriter.h"
-#include "xla/service/gpu/gemm_algorithm_picker.h"
-#include "xla/service/gpu/gemm_fusion_autotuner.h"
-#include "xla/service/gpu/gpu_algebraic_simplifier.h"
 #include "xla/service/gpu/gpu_asm_opts_util.h"
 #include "xla/service/gpu/gpu_compiler.h"
-#include "xla/service/gpu/gpu_conv_padding_legalization.h"
-#include "xla/service/gpu/gpu_conv_rewriter.h"
-#include "xla/service/gpu/gpu_sort_rewriter.h"
+#include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "xla/service/gpu/metrics.h"
 #include "xla/service/gpu/target_constants.h"
-#include "xla/service/gpu/triangular_solve_rewriter.h"
+#include "xla/service/gpu/transforms/algebraic_simplifier.h"
+#include "xla/service/gpu/transforms/conv_padding_legalization.h"
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h"
+#include "xla/service/gpu/transforms/cudnn_custom_call_compiler.h"
+#include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h"
+#include "xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h"
+#include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h"
+#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h"
+#include "xla/service/gpu/transforms/cudnn_norm_rewriter.h"
+#include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h"
+#include "xla/service/gpu/transforms/cudnn_simplify_padding.h"
+#include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h"
+#include "xla/service/gpu/transforms/dot_sparsity_rewriter.h"
+#include "xla/service/gpu/transforms/gpusolver_rewriter.h"
+#include "xla/service/gpu/transforms/sort_rewriter.h"
+#include "xla/service/gpu/transforms/triangular_solve_rewriter.h"
 #include "xla/service/hlo_constant_folding.h"
 #include "xla/service/hlo_cse.h"
 #include "xla/service/hlo_dataflow_analysis.h"
@@ -94,6 +96,8 @@
 #include "xla/stream_executor/cuda/cuda_asm_compiler.h"
 #include "xla/stream_executor/cuda/cuda_diagnostics.h"
 #include "xla/stream_executor/cuda/cuda_platform_id.h"
+#include "xla/stream_executor/cuda/nvjitlink.h"
+#include "xla/stream_executor/cuda/nvjitlink_support.h"
 #include "xla/stream_executor/cuda/ptx_compilation_method.h"
 #include "xla/stream_executor/cuda/ptx_compiler.h"
 #include "xla/stream_executor/cuda/ptx_compiler_support.h"
@@ -187,7 +191,7 @@
   auto cuda_compute_capability =
       std::get<se::CudaComputeCapability>(gpu_version);
   // Convert convolutions into CustomCalls to cudnn, then canonicalize them
-  // (GpuConvPaddingLegalization). Also expand cuSolver calls.
+  // (ConvPaddingLegalization). Also expand cuSolver calls.
   HloPassPipeline pipeline("conv_canonicalization");
   pipeline.AddInvariantCheckerDebug<HloVerifier>(
       /*layout_sensitive=*/false,
@@ -202,10 +206,10 @@
   pipeline.AddPass<FloatNormalization>(&matmul_bf16_support);
 
   pipeline.AddPass<GpusolverRewriter>();
-  pipeline.AddPass<GpuConvRewriter>(cuda_compute_capability);
+  pipeline.AddPass<ConvRewriter>(cuda_compute_capability);
   pipeline.AddPass<CudnnFusedConvRewriter>(cuda_compute_capability, dnn_version,
                                            GetToolkitVersion());
-  pipeline.AddPass<GpuConvPaddingLegalization>();
+  pipeline.AddPass<ConvPaddingLegalization>();
   pipeline.AddPass<CudnnPadForConvolutions>(cuda_compute_capability);
   pipeline.AddPass<CudnnVectorizeConvolutions>(cuda_compute_capability,
                                                dnn_version);
@@ -230,7 +234,7 @@
   // e.g. clean up unnecessary nop `convert`s.
   pipeline.AddPass<CudnnSimplifyPadding>();
 
-  // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and
+  // tf2xla bridge, DepthwiseConvolutionConverter, ConvRewriter, and
   // CudnnSimplifyPadding introduce reshapes and transposes.  Run ReshapeMover
   // to a fixed point.  Include algsimp because ReshapeMover relies on it.
   [&, &pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
@@ -252,7 +256,7 @@
     pipeline.AddPass<GpuAlgebraicSimplifier>(algsimp_options, gpu_version);
   }();
 
-  // GpuConvRewriter, GpuConvPaddingLegalization and
+  // ConvRewriter, ConvPaddingLegalization and
   // CudnnConvPadForTensorCores may add instructions which can be simplified
   // by constant folding.
   pipeline.AddPass<HloConstantFolding>();
@@ -338,9 +342,6 @@
   // Transform TriangularSolve ops into custom-calls, so we can add temp
   // memory.
   post_pipeline.AddPass<TriangularSolveRewriter>();
-  if (stream_exec) {
-    post_pipeline.AddPass<CuDnnWorkspaceRewriter>(*stream_exec);
-  }
   TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status());
 
   return absl::OkStatus();
@@ -386,20 +387,22 @@
 absl::Status NVPTXCompiler::AddCustomKernelReplacementPasses(
     HloPassPipeline* pipeline, const DebugOptions& debug_options) {
   if (debug_options.xla_gpu_enable_cub_radix_sort()) {
-    pipeline->AddPass<GpuSortRewriter>();
+    pipeline->AddPass<SortRewriter>();
   }
   return absl::OkStatus();
 }
 
-absl::Status NVPTXCompiler::RunCudnnFusionCompilerPass(
+absl::Status NVPTXCompiler::RunCudnnCompilerPasses(
     HloModule* module, se::StreamExecutor* stream_exec,
     BinaryMap* dnn_compiled_graphs) {
   tsl::profiler::ScopedAnnotation annotation([&] {
     return absl::StrFormat("XlaCompileCudnnFusion:#module=%s,program_id=%d#",
                            module->name(), module->unique_id());
   });
-  CuDnnFusionCompiler cudnn_compiler(*stream_exec, *dnn_compiled_graphs);
-  return cudnn_compiler.Run(module).status();
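+  // Compile cuDNN graph fusions first, then cuDNN custom-calls; both compilers
+  // are constructed with stream_exec and dnn_compiled_graphs.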
+  CuDnnFusionCompiler fusion_compiler(*stream_exec, *dnn_compiled_graphs);
+  TF_RETURN_IF_ERROR(fusion_compiler.Run(module).status());
+  CuDnnCustomCallCompiler call_compiler(*stream_exec, *dnn_compiled_graphs);
+  return call_compiler.Run(module).status();
 }
 
 namespace {
@@ -531,6 +534,8 @@
   return &CanShareBufferHint;
 }
 
+// Prefix stored in front of PTX text when the PTX-to-CUBIN compilation of a
+// relocatable module is deferred to the linking step; LinkModules checks for
+// this prefix to tell deferred PTX apart from a regular CUBIN.
+constexpr const uint8_t kPtxPrefix[] = {'P', 'T', 'X', ':', ' '};
+
 absl::StatusOr<GpuCompiler::BackendCompileResult>
 NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
                                    llvm::Module* llvm_module,
@@ -568,6 +573,20 @@
     RecordLlvmPassesAndLlvmToPtxDuration(end_usecs - start_usecs);
   }
 
+  TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method,
+                      ChooseLinkingMethod(module_config.debug_options()));
+
+  if (linking_method == se::PtxLinkingMethod::kNvJitLink && relocatable) {
+    VLOG(2) << "Deferring the PTX to CUBIN compilation of the relocatable "
+               "module to the linking step.";
+    std::vector<uint8_t> binary;
+    binary.reserve(sizeof(kPtxPrefix) + ptx.size() + 1);
+    binary.insert(binary.end(), kPtxPrefix, kPtxPrefix + sizeof(kPtxPrefix));
+    binary.insert(binary.end(), ptx.begin(), ptx.end());
+    binary.emplace_back('\0');
+    return BackendCompileResult{std::move(ptx), std::move(binary)};
+  }
+
   absl::StatusOr<std::vector<uint8_t>> maybe_cubin =
       CompileGpuAsmOrGetCachedResult(
           ptx, std::get<se::CudaComputeCapability>(gpu_version), module_config,
@@ -588,6 +607,9 @@
   if (se::IsLibNvPtxCompilerSupported()) {
     methods.emplace_back(PtxCompilationMethod::kNvPtxCompiler);
   }
+  if (se::IsLibNvJitLinkSupported()) {
+    methods.emplace_back(PtxCompilationMethod::kNvJitLink);
+  }
   methods.emplace_back(PtxCompilationMethod::kPtxas);
   return methods;
 }
@@ -608,11 +630,26 @@
     }
   };
 
+  if (!debug_options.xla_gpu_enable_libnvjitlink()) {
+    VLOG(3) << "Discarding NvJitLink since it is disabled.";
+    remove_compilation_method(PtxCompilationMethod::kNvJitLink);
+  }
   if (!debug_options.xla_gpu_enable_libnvptxcompiler()) {
     VLOG(3) << "Discarding NvPtxCompiler since it is disabled.";
     remove_compilation_method(PtxCompilationMethod::kNvPtxCompiler);
   }
 
+  VLOG(2) << "Supported and enabled compilation methods: "
+          << absl::StrJoin(compilation_methods, ", ");
+
+  if (relocatable && absl::c_linear_search(compilation_methods,
+                                           PtxCompilationMethod::kNvJitLink)) {
+    // NvJitLink can't produce relocatable CUBINs.
+    VLOG(3) << "Discarding NvJitLink since it can't produce the requested "
+               "relocatable CUBIN.";
+    remove_compilation_method(PtxCompilationMethod::kNvJitLink);
+  }
+
   VLOG(2) << "Considered compilation methods: "
           << absl::StrJoin(compilation_methods, ", ");
 
@@ -655,6 +692,16 @@
 
   absl::StatusOr<std::vector<uint8_t>> maybe_cubin = [&] {
     switch (compilation_method) {
+      case PtxCompilationMethod::kNvJitLink:
+        return se::CompileAndLinkUsingLibNvJitLink(
+            cc.major, cc.minor,
+            {se::NvJitLinkInput{
+                se::NvJitLinkInput::Type::kPtx,
+                absl::Span<const uint8_t>{
+                    reinterpret_cast<const uint8_t*>(ptx.c_str()),
+                    ptx.size() + 1 /* We need the null terminator. */}}},
+            ptxas_config, cancel_if_reg_spill);
+
       case PtxCompilationMethod::kNvPtxCompiler:
         return se::CompileGpuAsmUsingLibNvPtxCompiler(
             cc.major, cc.minor, ptx.c_str(), ptxas_config, cancel_if_reg_spill);
@@ -815,6 +862,12 @@
   std::string& preferred_cuda_dir = ptxas_config.preferred_cuda_dir;
 
   using LinkingMethod = se::PtxLinkingMethod;
+
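+  // Prefer nvJitLink when the library is available and enabled via
+  // DebugOptions (xla_gpu_enable_libnvjitlink); otherwise fall through to the
+  // linking-method selection below.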
+  if (stream_executor::IsLibNvJitLinkSupported() &&
+      debug_options.xla_gpu_enable_libnvjitlink()) {
+    return se::PtxLinkingMethod::kNvJitLink;
+  }
+
   TF_ASSIGN_OR_RETURN(auto asm_compiler_version,
                       GetAsmCompilerVersion(debug_options, preferred_cuda_dir));
 
@@ -859,28 +912,60 @@
 }
 
 absl::StatusOr<std::vector<uint8_t>> NVPTXCompiler::LinkModules(
-    se::GpuComputeCapability cc, se::StreamExecutor* stream_exec,
-    std::vector<std::vector<uint8_t>> modules,
+    se::GpuComputeCapability compute_capability,
+    se::StreamExecutor* stream_exec, std::vector<std::vector<uint8_t>> modules,
     const DebugOptions& debug_options) {
   if (modules.empty()) return std::vector<uint8_t>{};
 
+  auto cc =
+      std::get<stream_executor::CudaComputeCapability>(compute_capability);
+
   TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method,
                       ChooseLinkingMethod(debug_options));
   VLOG(1) << "Linking " << modules.size()
           << " modules with linking method: " << linking_method;
 
-  std::vector<stream_executor::CubinOrPTXImage> images;
-  images.reserve(modules.size());
-  for (std::vector<uint8_t>& module : modules) {
-    images.push_back({"", std::move(module)});
+  if (linking_method == se::PtxLinkingMethod::kNvJitLink) {
+    const auto module_contains_ptx =
+        [](const std::vector<uint8_t>& module) -> bool {
+      return module.size() >= sizeof(kPtxPrefix) &&
+             std::equal(std::begin(kPtxPrefix), std::end(kPtxPrefix),
+                        std::begin(module));
+    };
+
+    std::vector<stream_executor::NvJitLinkInput> nvjitlink_inputs;
+    nvjitlink_inputs.reserve(modules.size());
+    for (std::vector<uint8_t>& module : modules) {
+      if (module_contains_ptx(module)) {
+        nvjitlink_inputs.push_back(
+            {se::NvJitLinkInput::Type::kPtx,
+             absl::Span<const uint8_t>(module).subspan(sizeof(kPtxPrefix))});
+      } else {
+        nvjitlink_inputs.push_back({se::NvJitLinkInput::Type::kCubin, module});
+      }
+    }
+
+    se::GpuAsmOpts ptxas_config = PtxOptsFromDebugOptions(debug_options);
+    return stream_executor::CompileAndLinkUsingLibNvJitLink(
+        cc.major, cc.minor, nvjitlink_inputs, ptxas_config,
+        /*cancel_if_reg_spill=*/false);
   }
+
+  std::vector<stream_executor::CubinOrPTXImage> cubin_images;
+  cubin_images.reserve(modules.size());
+  for (std::vector<uint8_t>& module : modules) {
+    {
+      std::string profile = absl::StrCat("sm_", cc.major, cc.minor);
+      cubin_images.push_back({std::move(profile), std::move(module)});
+    }
+  }
+
   auto context = se::gpu::ExtractGpuExecutor(stream_exec)->gpu_context();
   if (linking_method == se::PtxLinkingMethod::kNvLink) {
-    return LinkUsingNvlink(std::get<se::CudaComputeCapability>(cc),
-                           debug_options.xla_gpu_cuda_data_dir(), context,
-                           images);
+    return LinkUsingNvlink(cc, debug_options.xla_gpu_cuda_data_dir(), context,
+                           cubin_images);
   }
-  return LinkGpuAsm(std::get<se::CudaComputeCapability>(cc), context, images);
+  return LinkGpuAsm(cc, context, cubin_images);
 }
 
 }  // namespace gpu
diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h
index 25fa268..6d84deb 100644
--- a/third_party/xla/xla/service/gpu/nvptx_compiler.h
+++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h
@@ -32,7 +32,7 @@
 #include "xla/autotune_results.pb.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/pjrt/distributed/key_value_store_interface.h"
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
 #include "xla/service/gpu/gpu_compiler.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/hlo_dataflow_analysis.h"
@@ -84,9 +84,9 @@
   absl::Status AddCustomKernelReplacementPasses(
       HloPassPipeline* pipeline, const DebugOptions& debug_options) override;
 
-  absl::Status RunCudnnFusionCompilerPass(
-      HloModule* module, se::StreamExecutor* stream_exec,
-      BinaryMap* dnn_compiled_graphs) override;
+  absl::Status RunCudnnCompilerPasses(HloModule* module,
+                                      se::StreamExecutor* stream_exec,
+                                      BinaryMap* dnn_compiled_graphs) override;
 
   HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const override;
 
@@ -100,7 +100,8 @@
 
  private:
   absl::StatusOr<std::vector<uint8_t>> LinkModules(
-      se::GpuComputeCapability cc, se::StreamExecutor* stream_exec,
+      se::GpuComputeCapability gpu_compute_capability,
+      se::StreamExecutor* stream_exec,
       std::vector<std::vector<uint8_t>> modules,
       const DebugOptions& debug_options) override;
 
diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc b/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc
deleted file mode 100644
index d0e841c..0000000
--- a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc
+++ /dev/null
@@ -1,703 +0,0 @@
-
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/pipelined_p2p_rewriter.h"
-
-#include <cstdint>
-#include <optional>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/ir/hlo_schedule.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-// Maps a computation to a boolean that indicates whether any collective
-// operations are directly or indirectly invoked in the computation.
-using CollectiveInComputation =
-    absl::flat_hash_map<const HloComputation*, bool>;
-
-using InstructionVector = HloInstruction::InstructionVector;
-
-// Records the starting and ending indices of the pipelined SendDone/RecvDone
-// block in a while-op. They are indices into the while-loop operands.
-struct PipelinedP2PInfo {
-  int64_t opnd_start;
-  int64_t opnd_end;
-};
-
-// Returns whether the instruction is a collective operation.
-bool IsCollectiveOp(const HloInstruction* op) {
-  HloOpcode opcode = op->opcode();
-  // TODO(NVIDIA/4364298): The information is recorded in b/309639264.
-  // We need to avoid overlapping custom-calls with Send/Recv to work around
-  // the bug. Remove custom-calls here when the bug is fixed.
-  if (opcode == HloOpcode::kCustomCall) {
-    return true;
-  }
-
-  return hlo_query::IsCollectiveCommunicationOp(opcode) ||
-         opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv;
-}
-
-// Returns whether the instruction may invoke collective operations directly
-// or indirectly.
-bool MayInvokeCollectiveOp(
-    const HloInstruction* hlo,
-    const CollectiveInComputation& collective_in_computation) {
-  if (IsCollectiveOp(hlo)) {
-    return true;
-  }
-  for (HloComputation* callee : hlo->called_computations()) {
-    auto collective_in_comp = collective_in_computation.find(callee);
-    CHECK(collective_in_comp != collective_in_computation.end());
-    if (collective_in_comp->second) {
-      return true;
-    }
-  }
-  return false;
-}
-
-// Returns the unique get-tuple-element user with the given idx or nullptr if
-// there isn't such a unique user.
-HloInstruction* FindUniqueGTEUserWithIndex(const HloInstruction* op,
-                                           int64_t idx) {
-  CHECK(op->shape().IsTuple());
-
-  HloInstruction* gte = nullptr;
-  for (auto user : op->users()) {
-    if (user->opcode() != HloOpcode::kGetTupleElement) {
-      continue;
-    }
-    if (user->tuple_index() == idx) {
-      if (gte == nullptr) {
-        gte = user;
-      } else {
-        return nullptr;
-      }
-    }
-  }
-  return gte;
-}
-
-// Returns whether there is any get-tuple-element user with the given idx.
-bool HasGTEUserWithIndex(const HloInstruction* op, int64_t idx) {
-  CHECK(op->shape().IsTuple());
-
-  for (auto user : op->users()) {
-    if (user->opcode() != HloOpcode::kGetTupleElement) {
-      continue;
-    }
-    if (user->tuple_index() == idx) {
-      return true;
-    }
-  }
-  return false;
-}
-
-// Returns the instruction hidden behind a trivial tuple or `op`. This allows
-// the discovery of recv-done for the following case, for which the indirection
-// would have been removed by tuple-simplification.
-//    gte.0 = f32[1,1024,1024] get-tuple-element(recv-done), index=0
-//    gte.1 = token get-tuple-element(recv-done.p), index=1
-//    op = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
-//
-// TODO(bixia): investigate the possibility of implementing
-// m::TrivialTuple(m::RecvDone(&instr)) as suggested by code review.
-HloInstruction* MaySkipTrivialTuple(HloInstruction* op) {
-  if (op->opcode() != HloOpcode::kTuple) {
-    return op;
-  }
-  HloInstruction* hidden_op = nullptr;
-  for (auto opnd : op->mutable_operands()) {
-    if (opnd->opcode() != HloOpcode::kGetTupleElement) {
-      return op;
-    }
-    if (hidden_op == nullptr) {
-      hidden_op = opnd->mutable_operand(0);
-    } else if (opnd->mutable_operand(0) != hidden_op) {
-      return op;
-    }
-  }
-  return hidden_op;
-}
-
-// This routine is similar to the non-const version above except that the
-// given instruction is used for pattern checking only and can't be mutated.
-const HloInstruction* MaySkipTrivialTuple(const HloInstruction* op) {
-  // Use const_cast to avoid repeating the non-const version above to find
-  // operands of the instruction through operands() instead of
-  // mutable_operands().
-  return MaySkipTrivialTuple(const_cast<HloInstruction*>(op));
-}
-
-// Finds a consecutive block of balanced SendDone/RecvDone in the while_init
-// of a while-loop, assuming its while_init is a tuple.
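-// For example, for while_init = (count, recv-done.0, send-done.0, data),
-// where both done ops carry the kSendRecvPipelineAttr frontend attribute, the
-// returned block is {opnd_start=1, opnd_end=3}.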
-std::optional<PipelinedP2PInfo>
-FindConsecutiveAndBalanceBlockOfSendDoneRecvDone(
-    const HloInstruction* while_init) {
-  PipelinedP2PInfo pipelined_p2p_info{0, 0};
-  // Return whether the first SendDone/RecvDone has been seen.
-  auto has_started = [&]() {
-    return pipelined_p2p_info.opnd_start != pipelined_p2p_info.opnd_end;
-  };
-  // Record the difference between the number of SendDone and RecvDone in a
-  // consecutive block.
-  int difference = 0;
-  // If SendDone/RecvDone exists in a consecutive block in the while_init
-  // tuple, find such block.
-  for (int64_t i = 0; i < while_init->operand_count(); ++i) {
-    const HloInstruction* op = while_init->operand(i);
-    if ((op->opcode() == HloOpcode::kRecvDone ||
-         op->opcode() == HloOpcode::kSendDone) &&
-        op->frontend_attributes().map().count(kSendRecvPipelineAttr) > 0) {
-      if (op->opcode() == HloOpcode::kRecvDone) {
-        difference++;
-      } else {
-        difference--;
-      }
-      if (!has_started()) {
-        pipelined_p2p_info.opnd_start = i;
-      }
-      pipelined_p2p_info.opnd_end = i + 1;
-    } else {
-      if (has_started()) {
-        VLOG(10) << "End a consecutive block";
-        break;
-      }
-    }
-  }
-
-  if (difference != 0) {
-    VLOG(10) << "Mismatch number of SendDone and RecvDone: " << difference;
-    return std::nullopt;
-  }
-
-  if (has_started()) {
-    // Check for SendDone/RecvDone outside the consecutive block.
-    for (int64_t i = pipelined_p2p_info.opnd_end;
-         i < while_init->operand_count(); ++i) {
-      const HloInstruction* op = while_init->operand(i);
-      if (op->opcode() == HloOpcode::kRecvDone ||
-          op->opcode() == HloOpcode::kSendDone) {
-        VLOG(10) << "SendDone/RecvDone outside the consecutive block";
-        return std::nullopt;
-        break;
-      }
-    }
-  }
-
-  if (!has_started()) {
-    VLOG(10) << "No SendDone/RecvDone in while-init ";
-    return std::nullopt;
-  }
-
-  return pipelined_p2p_info;
-}
-
-// Checks whether the while-op, its while-body and while-condition have a
-// recognized pipelined pattern. If a pipelined pattern is found, returns the
-// first and last indices for the pipelined instruction in the while-init tuple.
-// For pipelined Send/Recv to work, the SendDone/RecvDone doesn't have to be in
-// a consecutive block, but this simplifies the implementation and is the
-// pattern that the current gpu-p2p-pipeliner generates.
-//
-// As a summary, this is what the routine looks for:
-//
-// . The while-init has a tuple with a single user.
-// . The while-init has a consecutive block of SendDone and RecvDone. The
-//   numbers of SendDone and RecvDone are the same, and there isn't any other
-//   SendDone or RecvDone outside the block.
-// . The while-body has a single tuple parameter.
-// . For the while-op result tuple and the while-body parameter tuple:
-//     The index corresponding to the index of SendDone in while-init should not
-//       correspond to any get-tuple-element user.
-//     The index corresponding to the index of RecvDone in while-init should
-//       correspond to a single get-tuple-element user.
-// . In the while-body result tuple, the operand with an index corresponding to
-//   the index in the while-init SendDone and RecvDone should also be a SendDone
-//   or RecvDone.
-//
-// TODO(bixia): support pipelined SendDone/RecvDone not in a consecutive block
-// if the gpu-p2p-pipeliner will ever generate such code in the future.
-std::optional<PipelinedP2PInfo> FindPipelinedP2P(
-    const HloInstruction* while_op) {
-  VLOG(10) << "while_op: " << while_op->ToString();
-  const HloInstruction* while_init = while_op->while_init();
-  if (while_init->opcode() != HloOpcode::kTuple ||
-      while_init->user_count() != 1) {
-    return std::nullopt;
-  }
-
-  // The while-body and while-condition should have one parameter of a tuple
-  // shape.
-  const HloComputation* while_body = while_op->while_body();
-  const HloComputation* while_condition = while_op->while_condition();
-  if (while_body->num_parameters() != 1 ||
-      while_condition->num_parameters() != 1) {
-    return std::nullopt;
-  }
-
-  std::optional<PipelinedP2PInfo> pipelined_p2p_info =
-      FindConsecutiveAndBalanceBlockOfSendDoneRecvDone(while_init);
-  if (!pipelined_p2p_info.has_value()) {
-    return std::nullopt;
-  }
-
-  VLOG(10) << "opnd_start " << pipelined_p2p_info->opnd_start << " opnd_end "
-           << pipelined_p2p_info->opnd_end;
-
-  // In the while-result or while-body parameter, the index for RecvDone should
-  // correspond to one get-tuple-element user and the index for SendDone should
-  // not correspond to any get-tuple-element user.
-  for (int64_t i = pipelined_p2p_info->opnd_start;
-       i < pipelined_p2p_info->opnd_end; ++i) {
-    const HloInstruction* op = while_init->operand(i);
-    if (op->opcode() == HloOpcode::kRecvDone) {
-      if (!FindUniqueGTEUserWithIndex(while_op, i)) {
-        VLOG(10) << "While result get-tuple-element user with index " << i
-                 << " not unique";
-        return std::nullopt;
-      }
-      if (!FindUniqueGTEUserWithIndex(while_body->parameter_instruction(0),
-                                      i)) {
-        VLOG(10) << "While-body parameter get-tuple-element user with index "
-                 << i << " not unique";
-        return std::nullopt;
-      }
-    } else {
-      CHECK(op->opcode() == HloOpcode::kSendDone);
-      if (HasGTEUserWithIndex(while_op, i) ||
-          HasGTEUserWithIndex(while_body->parameter_instruction(0), i)) {
-        VLOG(10) << "SendDone with index " << i << " has unexpected users";
-        return std::nullopt;
-      }
-    }
-  }
-
-  // Each element in the while-body result tuple must have the same opcode as
-  // the corresponding pipelined SendDone/RecvDone in the while-init.
-  const HloInstruction* root = while_body->root_instruction();
-  for (int64_t i = pipelined_p2p_info->opnd_start;
-       i < pipelined_p2p_info->opnd_end; ++i) {
-    const HloInstruction* op_init = while_init->operand(i);
-    const HloInstruction* op_root = root->operand(i);
-    op_root = MaySkipTrivialTuple(op_root);
-    if (op_init->opcode() != op_root->opcode()) {
-      VLOG(10) << "Mismatching opcode, op_init: " << op_init->ToString()
-               << " op_root: " << op_root->ToString();
-      return std::nullopt;
-    }
-  }
-
-  return pipelined_p2p_info.value();
-}
-
-absl::Status RemoveOpFromParent(HloInstruction* op) {
-  TF_RETURN_IF_ERROR(op->DropAllControlDeps());
-  TF_RETURN_IF_ERROR(op->parent()->RemoveInstruction(op));
-  return absl::OkStatus();
-}
-
-absl::Status ReplaceOpInSequence(HloInstruction* old_op, HloInstruction* new_op,
-                                 HloInstructionSequence& instruction_sequence) {
-  VLOG(10) << "old_op: " << old_op->ToString();
-  VLOG(10) << "new_op: " << new_op->ToString();
-  instruction_sequence.replace_instruction(old_op, new_op);
-  return RemoveOpFromParent(old_op);
-}
-
-absl::Status ReplaceUsesAndUpdateSequence(
-    HloInstruction* old_op, HloInstruction* new_op,
-    HloInstructionSequence& instruction_sequence, bool diff_shape = false) {
-  VLOG(10) << "old_op: " << old_op->ToString();
-  VLOG(10) << "new_op: " << new_op->ToString();
-  if (diff_shape) {
-    TF_RETURN_IF_ERROR(old_op->ReplaceAllUsesWithDifferentShape(new_op));
-  } else {
-    TF_RETURN_IF_ERROR(old_op->ReplaceAllUsesWith(new_op));
-  }
-  return ReplaceOpInSequence(old_op, new_op, instruction_sequence);
-}
-
-absl::Status ReplaceUsesAndUpdateSequence(
-    const InstructionVector& old_ops, const InstructionVector& new_ops,
-    HloInstructionSequence& instruction_sequence) {
-  CHECK(old_ops.size() == new_ops.size());
-  for (int64_t i = 0; i < old_ops.size(); ++i) {
-    TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(old_ops[i], new_ops[i],
-                                                    instruction_sequence));
-  }
-  return absl::OkStatus();
-}
-
-absl::Status RemoveDoneOpsAndUpdateSequence(
-    const InstructionVector& ops,
-    HloInstructionSequence& instruction_sequence) {
-  auto remove_op = [&](HloInstruction* op) {
-    VLOG(10) << "op: " << op->ToString();
-    TF_RETURN_IF_ERROR(RemoveOpFromParent(op));
-    instruction_sequence.remove_instruction(op);
-    return absl::OkStatus();
-  };
-  for (auto op : ops) {
-    if (op->opcode() == HloOpcode::kTuple) {
-      InstructionVector to_remove;
-      HloInstruction* tuple_op = op;
-      op = MaySkipTrivialTuple(tuple_op);
-      to_remove.push_back(tuple_op);
-      for (auto opnd : tuple_op->mutable_operands()) {
-        to_remove.push_back(opnd);
-      }
-      for (auto opnd : to_remove) {
-        TF_RETURN_IF_ERROR(remove_op(opnd));
-      }
-    }
-    TF_RETURN_IF_ERROR(remove_op(op));
-  }
-  return absl::OkStatus();
-}
-
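-// Inserts the given ops into instruction_sequence right before the first op at
-// or after position idx that may invoke a collective operation, advancing idx
-// and idx_tot to account for the inserted ops. Returns whether such an op was
-// found and the insertion happened.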
-bool InsertBeforeFirstCollectiveOp(
-    const InstructionVector& ops,
-    const CollectiveInComputation& collective_in_computation,
-    HloInstructionSequence& instruction_sequence, int64_t& idx,
-    int64_t& idx_tot) {
-  bool inserted = false;
-  while (idx < idx_tot) {
-    HloInstruction* hlo = instruction_sequence.instructions()[idx];
-    if (MayInvokeCollectiveOp(hlo, collective_in_computation)) {
-      for (auto op : ops) {
-        instruction_sequence.insert_instruction(op, idx);
-        idx++;
-        idx_tot++;
-      }
-      inserted = true;
-      break;
-    }
-    idx++;
-  }
-  return inserted;
-}
-
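-// Copies the name (with a ".clone" suffix), metadata, frontend attributes and
-// backend config from old_op to new_op.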
-void CopyInstructionInfo(const HloInstruction* old_op, HloInstruction* new_op) {
-  new_op->SetAndSanitizeName(absl::StrCat(old_op->name(), ".clone"));
-  new_op->set_metadata(old_op->metadata());
-  new_op->add_frontend_attributes(old_op->frontend_attributes());
-  new_op->CopyBackendConfigFrom(old_op);
-}
-
-HloInstruction* CreateRecvDoneFrom(const HloInstruction* old_recv_done,
-                                   HloInstruction* recv,
-                                   HloComputation* computation) {
-  HloInstruction* recv_done =
-      computation->AddInstruction(HloInstruction::CreateRecvDone(
-          recv, old_recv_done->channel_id().value()));
-  CopyInstructionInfo(old_recv_done, recv_done);
-  return recv_done;
-}
-
-HloInstruction* CreateSendDoneFrom(const HloInstruction* old_send_done,
-                                   HloInstruction* send,
-                                   HloComputation* computation) {
-  HloInstruction* send_done =
-      computation->AddInstruction(HloInstruction::CreateSendDone(
-          send, old_send_done->channel_id().value()));
-  CopyInstructionInfo(old_send_done, send_done);
-  return send_done;
-}
-
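-// Rewrites the while-body for the pipelined pattern: the body parameter now
-// carries Send/Recv (rather than SendDone/RecvDone) at the pipelined indices,
-// fresh RecvDone/SendDone ops are created from those parameter elements, and
-// the root tuple is rebuilt to forward the Send/Recv ops to the next
-// iteration.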
-absl::Status RewritePipelinedP2PWhileBody(
-    const CollectiveInComputation& collective_in_computation,
-    const std::vector<Shape>& new_parameter_shapes, HloInstruction* while_op,
-    int64_t opnd_start, int64_t opnd_end) {
-  HloComputation* computation = while_op->while_body();
-  HloInstruction* while_init = while_op->while_init();
-  HloInstruction* root = computation->root_instruction();
-  HloInstructionSequence& instruction_sequence =
-      computation->parent()->schedule().GetOrCreateSequence(computation);
-
-  HloInstruction* param = computation->parameter_instruction(0);
-  *param->mutable_shape() = ShapeUtil::MakeTupleShape(new_parameter_shapes);
-
-  InstructionVector recv_dones;
-  InstructionVector new_recv_dones;
-  InstructionVector new_send_dones;
-  for (int64_t i = opnd_start; i < opnd_end; ++i) {
-    const HloInstruction* op = root->operand(i);
-    op = MaySkipTrivialTuple(op);
-    if (op->opcode() == HloOpcode::kRecvDone) {
-      HloInstruction* gte = FindUniqueGTEUserWithIndex(param, i);
-      CHECK(gte != nullptr);
-      recv_dones.push_back(gte);
-
-      // Create the new RecvDone using the new while-body parameter.
-      HloInstruction* recv = computation->AddInstruction(
-          HloInstruction::CreateGetTupleElement(param, i));
-
-      HloInstruction* recv_done = CreateRecvDoneFrom(op, recv, computation);
-      new_recv_dones.push_back(recv_done);
-      continue;
-    }
-    CHECK(op->opcode() == HloOpcode::kSendDone);
-    // Create the new SendDone using the new while-body parameter.
-    HloInstruction* send = computation->AddInstruction(
-        HloInstruction::CreateGetTupleElement(param, i));
-    HloInstruction* send_done = CreateSendDoneFrom(op, send, computation);
-    new_send_dones.push_back(send_done);
-  }
-  TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(recv_dones, new_recv_dones,
-                                                  instruction_sequence));
-
-  // Create a new root tuple.
-  InstructionVector done_ops;
-  InstructionVector new_opnds;
-  for (int64_t i = 0; i < while_init->operand_count(); ++i) {
-    HloInstruction* op = root->mutable_operand(i);
-    if (i >= opnd_start && i < opnd_end) {
-      new_opnds.push_back(MaySkipTrivialTuple(op)->mutable_operand(0));
-      done_ops.push_back(op);
-    } else {
-      new_opnds.push_back(op);
-    }
-  }
-  HloInstruction* new_root =
-      computation->AddInstruction(HloInstruction::CreateTuple(new_opnds));
-  computation->set_root_instruction(new_root,
-                                    /*accept_different_shape=*/true);
-  TF_RETURN_IF_ERROR(computation->RemoveInstruction(root));
-  instruction_sequence.replace_instruction(root, new_root);
-
-  TF_RETURN_IF_ERROR(
-      RemoveDoneOpsAndUpdateSequence(done_ops, instruction_sequence));
-
-  // Find a place to put the new SendDone ops: insert them right before the
-  // first op that may invoke a collective, which is either an op outside the
-  // pipelined Send/Recv chain or the first op of the pipelined Send/Recv chain.
-  int64_t idx = 0;
-  int64_t idx_end = instruction_sequence.size();
-  bool inserted =
-      InsertBeforeFirstCollectiveOp(new_send_dones, collective_in_computation,
-                                    instruction_sequence, idx, idx_end);
-  CHECK(inserted);  // The while-body contains Send/Recv, so insertion must succeed.
-  CHECK(idx_end == instruction_sequence.size());
-
-  // The module schedule will be updated at the end of the pass.
-  return absl::OkStatus();
-}
-
-void RewritePipelinedP2PWhileCond(
-    const std::vector<Shape>& new_parameter_shapes, HloInstruction* while_op) {
-  HloComputation* computation = while_op->while_condition();
-  HloInstruction* param = computation->parameter_instruction(0);
-  *param->mutable_shape() = ShapeUtil::MakeTupleShape(new_parameter_shapes);
-  VLOG(10) << computation->ToString();
-}
-
-// Rewrites the while-op with a recognized pipelined SendDone/RecvDone pattern
-// to pipeline Send/Recv instead.
-absl::Status TransformLoop(
-    const PipelinedP2PInfo& pipelined_info,
-    const CollectiveInComputation& collective_in_computation, int64_t& idx,
-    int64_t& idx_end, HloInstructionSequence& instruction_sequence,
-    HloInstruction* while_op) {
-  HloComputation* computation = while_op->parent();
-  int64_t opnd_start = pipelined_info.opnd_start;
-  int64_t opnd_end = pipelined_info.opnd_end;
-  VLOG(10) << "Transform pipelined while-op " << while_op->ToString();
-  HloInstruction* while_init = while_op->while_init();
-  InstructionVector new_while_init_opnds;
-  std::vector<Shape> new_parameter_shapes;
-  for (int64_t i = 0; i < while_init->operand_count(); ++i) {
-    HloInstruction* op = while_init->mutable_operand(i);
-    if (i >= opnd_start && i < opnd_end) {
-      // Get Send/Recv from SendDone/RecvDone.
-      new_while_init_opnds.push_back(op->mutable_operand(0));
-    } else {
-      new_while_init_opnds.push_back(op);
-    }
-    new_parameter_shapes.push_back(new_while_init_opnds.back()->shape());
-  }
-
-  RewritePipelinedP2PWhileCond(new_parameter_shapes, while_op);
-  TF_RETURN_IF_ERROR(RewritePipelinedP2PWhileBody(
-      collective_in_computation, new_parameter_shapes, while_op, opnd_start,
-      opnd_end));
-  HloInstruction* new_while_init = computation->AddInstruction(
-      HloInstruction::CreateTuple(new_while_init_opnds), "while-init");
-  VLOG(10) << "new_while_init: " << new_while_init->ToString();
-  HloInstruction* new_while_op = computation->AddInstruction(
-      HloInstruction::CreateWhile(
-          while_op->while_body()->root_instruction()->shape(),
-          while_op->while_condition(), while_op->while_body(), new_while_init),
-      "while-result");
-  CopyInstructionInfo(while_op, new_while_op);
-  VLOG(10) << "new_while_op: " << new_while_op->ToString();
-
-  InstructionVector recv_dones;
-  InstructionVector new_recv_dones;
-  InstructionVector new_send_dones;
-  InstructionVector done_ops;
-  for (int64_t i = opnd_start; i < opnd_end; ++i) {
-    HloInstruction* op = while_init->mutable_operand(i);
-    done_ops.push_back(op);
-    if (op->opcode() == HloOpcode::kRecvDone) {
-      HloInstruction* gte = FindUniqueGTEUserWithIndex(while_op, i);
-      CHECK(gte != nullptr);
-      recv_dones.push_back(gte);
-
-      // Create the new RecvDone using the new while-op result.
-      HloInstruction* recv = computation->AddInstruction(
-          HloInstruction::CreateGetTupleElement(new_while_op, i));
-      HloInstruction* recv_done = computation->AddInstruction(
-          HloInstruction::CreateRecvDone(recv, op->channel_id().value()));
-      new_recv_dones.push_back(recv_done);
-      CopyInstructionInfo(op, recv_done);
-      continue;
-    }
-    CHECK(op->opcode() == HloOpcode::kSendDone);
-    //  Create the new SendDone using the new while-op result.
-    HloInstruction* send = computation->AddInstruction(
-        HloInstruction::CreateGetTupleElement(new_while_op, i));
-    HloInstruction* send_done = computation->AddInstruction(
-        HloInstruction::CreateSendDone(send, op->channel_id().value()));
-    new_send_dones.push_back(send_done);
-    CopyInstructionInfo(op, send_done);
-  }
-
-  TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(
-      while_op, new_while_op, instruction_sequence, /*diff_shape*/ true));
-  TF_RETURN_IF_ERROR(
-      ReplaceOpInSequence(while_init, new_while_init, instruction_sequence));
-  TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(recv_dones, new_recv_dones,
-                                                  instruction_sequence));
-  TF_RETURN_IF_ERROR(
-      RemoveDoneOpsAndUpdateSequence(done_ops, instruction_sequence));
-
-  int64_t opnd_tot = opnd_end - opnd_start;
-  // Verify that the number of ops we have removed from the sequence is
-  // opnd_tot and that they were before the position of the new while-op.
-  CHECK(idx_end == instruction_sequence.size() + opnd_tot);
-  CHECK(instruction_sequence.instructions()[idx - opnd_tot] == new_while_op);
-
-  // Update idx_end to reflect the current size of the instruction sequence.
-  // Update idx to right after the new while-op.
-  idx_end -= opnd_tot;
-  idx = idx - opnd_tot + 1;
-  bool inserted =
-      InsertBeforeFirstCollectiveOp(new_send_dones, collective_in_computation,
-                                    instruction_sequence, idx, idx_end);
-  CHECK(idx_end == instruction_sequence.size());
-  // If there aren't any may-invoke-collective ops after the while-op, add
-  // the new SendDone ops before the last instruction in the sequence.
-  if (!inserted) {
-    CHECK(idx_end == idx);
-    idx--;
-    for (auto send_done : new_send_dones) {
-      instruction_sequence.insert_instruction(send_done, idx++);
-    }
-  }
-  return absl::OkStatus();
-}
-
-// Finds while-loops with pipelined Send/Recv and rotates the SendDone/RecvDone
-// for such while-loops.
-absl::StatusOr<bool> ProcessComputation(
-    HloModule* module, HloComputation* computation,
-    CollectiveInComputation& collective_in_computation) {
-  VLOG(10) << "Process compuation " << computation->name();
-  bool changed = false;
-  HloInstructionSequence& instruction_sequence =
-      module->schedule().GetOrCreateSequence(computation);
-  int64_t idx = 0;
-  int64_t idx_end = instruction_sequence.size();
-  while (idx < idx_end) {
-    HloInstruction* hlo = instruction_sequence.instructions()[idx];
-
-    if (MayInvokeCollectiveOp(hlo, collective_in_computation)) {
-      collective_in_computation[computation] = true;
-    }
-
-    if (hlo->opcode() != HloOpcode::kWhile) {
-      idx++;
-      continue;
-    }
-
-    std::optional<PipelinedP2PInfo> pipelined_info = FindPipelinedP2P(hlo);
-    if (!pipelined_info.has_value()) {
-      idx++;
-      continue;
-    }
-    TF_RETURN_IF_ERROR(TransformLoop(pipelined_info.value(),
-                                     collective_in_computation, idx, idx_end,
-                                     instruction_sequence, hlo));
-    changed = true;
-  }
-  return changed;
-}
-}  // namespace
-
-absl::StatusOr<bool> PipelinedP2PRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  if (!module->has_schedule()) return changed;
-  CollectiveInComputation collective_in_computation;
-  // Visit the computations in callee-to-caller order, so that a while-body is
-  // processed before its while-op.
-  for (auto* computation :
-       module->MakeComputationPostOrder(execution_threads)) {
-    if (computation->IsFusionComputation()) {
-      collective_in_computation[computation] = false;
-      continue;
-    }
-
-    TF_ASSIGN_OR_RETURN(
-        bool cur_changed,
-        ProcessComputation(module, computation, collective_in_computation));
-    changed |= cur_changed;
-  }
-
-  if (changed) {
-    TF_RETURN_IF_ERROR(module->schedule().Update());
-  }
-
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h b/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h
deleted file mode 100644
index 88b6bb6..0000000
--- a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h
+++ /dev/null
@@ -1,133 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_
-#define XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_
-
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// PipelinedP2PRewriter is a pass that rewrites pipelined Send/Recv related
-// code for point-to-point communication to rotate SendDone and RecvDone at the
-// end of a loop iteration to the beginning of the next iteration. This pass
-// operates on a scheduled module and updates the instruction sequence.
-//
-// In particular, a pipelined Send/Recv chain with one channel group with this
-// code pattern:
-//
-// main:
-//    recv
-//    send
-//    recv-done
-//    send-done
-//    while-init = (recv-done, send-done, ...)
-//    while-op = while(while-init) ...
-//
-// while-body:
-//    ...
-//    recv
-//    send
-//    recv-done
-//    send-done
-//    ROOT tuple(recv-done, send-done, ...)
-//
-// Will be transformed to:
-//
-// main:
-//    recv
-//    send
-//    while-init = (recv, send, ...)
-//    while-op = while(while-init) ...
-//    recv-done
-//    send-done
-//
-// while-body:
-//    recv-done
-//    ...
-//    send-done
-//    recv
-//    send
-//    ROOT tuple(recv, send, ...)
-//
-// A pipelined Send/Recv chain with two channel groups with this code pattern:
-//
-// main:
-//    recv.0
-//    send.0
-//    recv.1
-//    send.1
-//    recv-done.0
-//    send-done.0
-//    recv-done.1
-//    send-done.1
-//    while-init = (recv-done.0, send-done.0, recv-done.1, send-done.1, ...)
-//    while-op = while(while-init) ...
-//
-// while-body:
-//    ...
-//    recv.0
-//    send.0
-//    recv.1
-//    send.1
-//    recv-done.0
-//    send-done.0
-//    recv-done.1
-//    send-done.1
-//    ROOT = tuple(recv-done.0, send-done.0, recv-done.1, send-done.1, ...)
-//
-// Will be transformed to:
-//
-// main:
-//
-//    recv.0
-//    send.0
-//    recv.1
-//    send.1
-//    while-init = (recv.0, send.0, recv.1, send.1, ...)
-//    while-op = while(while-init) ...
-//    recv-done.0
-//    send-done.0
-//    recv-done.1
-//    send-done.1
-//
-// while-body:
-//    recv-done.0
-//    recv-done.1
-//    ...
-//    send-done.0
-//    send-done.1
-//    recv.0
-//    send.0
-//    recv.1
-//    send.1
-//    ROOT tuple(recv.0, send.0, recv.1, send.1, ...)
-//
-class PipelinedP2PRewriter : public HloModulePass {
- public:
-  absl::string_view name() const override { return "pipelined-p2p-rewriter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc b/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc
deleted file mode 100644
index a0d5830..0000000
--- a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc
+++ /dev/null
@@ -1,674 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/pipelined_p2p_rewriter.h"
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class PipelinedP2pRewriterTest : public HloTestBase {
- protected:
-  void DoFileCheck(const HloModule* module, absl::string_view expected) {
-    HloPrintOptions options;
-    options.set_print_operand_shape(false);
-    options.set_print_result_shape(false);
-    TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
-                            RunFileCheck(module->ToString(options), expected));
-    EXPECT_TRUE(filecheck_matched);
-  }
-};
-
-TEST_F(PipelinedP2pRewriterTest, SendRecUnpipelinedNotTransform) {
-  const char* kModuleStr = R"(
-HloModule test
-
-cond {
-    param = (u32[], u32[2]) parameter(0)
-    count = get-tuple-element(%param), index=0
-    ub = u32[] constant(11)
-    ROOT result = pred[] compare(count, ub), direction=LT
- }
-
-body {
-    param = (u32[], u32[2]) parameter(0)
-    count = get-tuple-element(param), index=0
-    send-data = u32[2] get-tuple-element(param), index=1
-
-    after-all.0.n = token[] after-all()
-    recv.0 = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{3,0}}",
-        _xla_send_recv_pipeline="0"
-      }
-    send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0.n),
-      channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{3,0}}",
-        _xla_send_recv_pipeline="0"
-      }
-    recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    send-done.0 = token[] send-done(send.0), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-
-    recv-data = u32[2] get-tuple-element(recv-done.0), index=0
-
-    c1 = u32[] constant(1)
-    new_count = u32[] add(count, c1)
-
-    r = u32[2] broadcast(c1), dimensions={}
-    s = u32[2] add(r, recv-data)
-
-    ROOT result = (u32[], u32[2]) tuple(new_count, s)
-  }
-
-  ENTRY test_computation {
-    c0 = u32[] constant(0)
-    c1 = u32[] constant(1)
-    r = u32[] replica-id()
-    a = u32[] add(c1, r)
-    init = u32[2] broadcast(a), dimensions={}
-    while_init = (u32[], u32[2]) tuple(c0, init)
-    while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond,
-      backend_config={"known_trip_count":{"n":"11"}}
-    ROOT recv-data = u32[2] get-tuple-element(while_result), index=1
-  }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr));
-  PipelinedP2PRewriter rewriter;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
-  EXPECT_FALSE(changed);
-}
-
-// Tests the rewrite for a pipelined Send/Recv chain with only one channel
-// group.
-TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined1) {
-  const char* kModuleStr = R"(
-  HloModule test, is_scheduled=true
-
-  while-cond {
-    param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
-    count = get-tuple-element(param), index=0
-    ub = u32[] constant(25)
-    ROOT cond-result = pred[] compare(count, ub), direction=LT
-  }
-
-  while-body {
-    param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
-    count = get-tuple-element(param), index=0
-
-    recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
-    recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0
-
-    c1 = u32[] constant(1)
-    new-count = u32[] add(count, c1)
-    replica = u32[] replica-id()
-    c10 = u32[] constant(10)
-    sum = u32[] add(replica, c10)
-    sum2 = u32[] add(sum, count)
-    conv = f32[] convert(sum2)
-    p = f32[1, 1024, 1024] broadcast(conv), dimensions={}
-    b = f32[1, 1024, 1024] add(p, recv-data)
-    c = f32[1, 1024, 1024] multiply(b, b)
-    d = f32[1, 1024, 1024] tan(c)
-    s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0},
-      lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-    send-data = f32[1, 1024, 1024] add(c, s)
-
-    after-all = token[] after-all()
-    recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
-        _xla_send_recv_pipeline="0"
-      }
-    send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
-      channel_id=1, frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
-        _xla_send_recv_pipeline="0"
-      }
-    recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    send-done.p = token[] send-done(send), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
-    gte.1 = token[] get-tuple-element(recv-done.p), index=1
-    recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
-    ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
-      tuple(new-count, recv-done-tuple, send-done.p)
-  }
-
-  ENTRY main {
-    c0 = u32[] constant(0)
-    f0 = f32[] constant(0.0)
-    init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
-
-    after-all.1 = token[] after-all()
-    recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
-        _xla_send_recv_pipeline="0"
-      }
-    send.1 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.1), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
-        _xla_send_recv_pipeline="0"
-      }
-    recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    send-done.1.p = token[] send-done(send.1), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    while-init.p =  (u32[], (f32[1,1024,1024], token[]), token[])
-      tuple(c0, recv-done.1.p, send-done.1.p)
-    while-result.p = (u32[], (f32[1,1024,1024], token[]), token[])
-      while(while-init.p),
-      body=while-body, condition=while-cond,
-      backend_config={"known_trip_count":{"n":"25"}}
-
-    recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1
-
-    ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0
-  }
-  )";
-
-  const char* kExpected = R"(
-  CHECK: %while-body (param.1: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[])) {
-  CHECK: %param.1 = parameter(0)
-  CHECK: %get-tuple-element = get-tuple-element(%param.1), index=1
-  CHECK: %get-tuple-element.1 = get-tuple-element(%param.1), index=2
-  CHECK: %count.1 = get-tuple-element(%param.1), index=0
-  CHECK: %recv-done.p.clone = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
-  CHECK: %recv-data = get-tuple-element(%recv-done.p.clone), index=0
-  CHECK: %c1 = constant(1)
-  CHECK: %new-count = add(%count.1, %c1)
-  CHECK: %replica = replica-id()
-  CHECK: %c10 = constant(10)
-  CHECK: %sum = add(%replica, %c10)
-  CHECK: %sum2 = add(%sum, %count.1)
-  CHECK: %conv = convert(%sum2)
-  CHECK: %p = broadcast(%conv), dimensions={}
-  CHECK: %b = add(%p, %recv-data)
-  CHECK: %c = multiply(%b, %b)
-  CHECK: %d = tan(%c)
-  CHECK: %s = dot(%c, %d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-  CHECK: %send-data = add(%c, %s)
-  CHECK: %after-all = after-all()
-  CHECK: %send-done.p.clone = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
-  CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"}
-  CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"}
-  CHECK: ROOT %tuple = tuple(%new-count, %recv, %send)
-  CHECK: }
-
-  CHECK: %while-cond (param: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> pred[] {
-  CHECK: %param = parameter(0)
-  CHECK: %count = get-tuple-element(%param), index=0
-  CHECK: %ub = constant(25)
-  CHECK: ROOT %cond-result = compare(%count, %ub), direction=LT
-  CHECK: }
-
-  CHECK: ENTRY %main () -> f32[1,1024,1024] {
-  CHECK: %c0 = constant(0)
-  CHECK: %f0 = constant(0)
-  CHECK: %init = broadcast(%f0), dimensions={}
-  CHECK: %after-all.1 = after-all()
-  CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"}
-  CHECK{LITERAL}: %send.1 = send(%init, %after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"}
-  CHECK: %while-init = tuple(%c0, %recv.1, %send.1)
-  CHECK: %while-result.p.clone = while(%while-init), condition=%while-cond, body=%while-body,
-  CHECK-SAME{LITERAL}: backend_config={"known_trip_count":{"n":"25"}}
-  CHECK: %get-tuple-element.2 = get-tuple-element(%while-result.p.clone), index=1
-  CHECK: %get-tuple-element.3 = get-tuple-element(%while-result.p.clone), index=2
-  CHECK: %recv-done.1.p.clone = recv-done(%get-tuple-element.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
-  CHECK: %send-done.1.p.clone = send-done(%get-tuple-element.3), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
-  CHECK: ROOT %entry-result = get-tuple-element(%recv-done.1.p.clone), index=0
-  CHECK: })";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr));
-  PipelinedP2PRewriter rewriter;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  DoFileCheck(module.get(), kExpected);
-}
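For context, the `DoFileCheck` helper invoked by these tests is defined earlier in the file, outside this excerpt. Below is only a plausible sketch of such a helper, assuming XLA's `RunFileCheck` test utility and the test file's existing includes; `DoFileCheckSketch` is a hypothetical name and the real helper may configure printing differently.

void DoFileCheckSketch(const HloModule* module, absl::string_view expected) {
  // Print the module without operand/result shapes so that the CHECK patterns
  // stay short, then match the printed module against them with FileCheck.
  HloPrintOptions options;
  options.set_print_operand_shape(false);
  options.set_print_result_shape(false);
  TF_ASSERT_OK_AND_ASSIGN(bool matched,
                          RunFileCheck(module->ToString(options), expected));
  EXPECT_TRUE(matched);
}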
-
-// Repeats the Send/Recv pattern from the previous test to verify that we can
-// rewrite a routine with multiple pipelined loops without crashing.
-TEST_F(PipelinedP2pRewriterTest, SendRecvTwoPipelinedWhileLoops) {
-  const char* kModuleStr = R"(
-  HloModule test, is_scheduled=true
-
-  while-cond {
-    param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
-    count = get-tuple-element(param), index=0
-    ub = u32[] constant(25)
-    ROOT cond-result = pred[] compare(count, ub), direction=LT
-  }
-
-  while-body {
-    param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
-    count = get-tuple-element(param), index=0
-
-    recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
-    send-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0
-
-    c1 = u32[] constant(1)
-    new-count = u32[] add(count, c1)
-
-    after-all = token[] after-all()
-    recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
-        _xla_send_recv_pipeline="0"
-      }
-    send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
-      channel_id=1, frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
-        _xla_send_recv_pipeline="0"
-      }
-    recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    send-done.p = token[] send-done(send), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
-    gte.1 = token[] get-tuple-element(recv-done.p), index=1
-    recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
-    ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
-      tuple(new-count, recv-done-tuple, send-done.p)
-  }
-
-  while-cond-2 {
-    param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
-    count = get-tuple-element(param), index=0
-    ub = u32[] constant(25)
-    ROOT cond-result = pred[] compare(count, ub), direction=LT
-  }
-
-  while-body-2 {
-    param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
-    count = get-tuple-element(param), index=0
-
-    recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
-    send-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0
-
-    c1 = u32[] constant(1)
-    new-count = u32[] add(count, c1)
-
-    after-all = token[] after-all()
-    recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
-        _xla_send_recv_pipeline="0"
-      }
-    send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
-      channel_id=1, frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
-        _xla_send_recv_pipeline="0"
-      }
-    recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    send-done.p = token[] send-done(send), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
-    gte.1 = token[] get-tuple-element(recv-done.p), index=1
-    recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
-    ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
-      tuple(new-count, recv-done-tuple, send-done.p)
-  }
-
-  ENTRY main {
-    c0 = u32[] constant(0)
-    f0 = f32[] constant(0.0)
-    init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
-
-    after-all.1 = token[] after-all()
-    recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
-        _xla_send_recv_pipeline="0"
-      }
-    send.1 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.1), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
-        _xla_send_recv_pipeline="0"
-      }
-    recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    send-done.1.p = token[] send-done(send.1), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    while-init.p =  (u32[], (f32[1,1024,1024], token[]), token[])
-      tuple(c0, recv-done.1.p, send-done.1.p)
-    while-result.p = (u32[], (f32[1,1024,1024], token[]), token[])
-      while(while-init.p),
-      body=while-body, condition=while-cond,
-      backend_config={"known_trip_count":{"n":"25"}}
-
-    recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1
-
-    after-all-2.1 = token[] after-all()
-    recv-2.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all-2.1), channel_id=2,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
-        _xla_send_recv_pipeline="0"
-      }
-    send-2.1 = (f32[1, 1024, 1024], u32[], token[]) send(recv-done.1.q, after-all-2.1), channel_id=2,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
-        _xla_send_recv_pipeline="0"
-      }
-    recv-done-2.1.p = (f32[1,1024,1024], token[]) recv-done(recv-2.1), channel_id=2,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    send-done-2.1.p = token[] send-done(send-2.1), channel_id=2,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    while-init-2.p =  (u32[], (f32[1,1024,1024], token[]), token[])
-      tuple(c0, recv-done-2.1.p, send-done-2.1.p)
-    while-result-2.p = (u32[], (f32[1,1024,1024], token[]), token[])
-      while(while-init-2.p),
-      body=while-body-2, condition=while-cond-2,
-      backend_config={"known_trip_count":{"n":"25"}}
-
-    recv-done-2.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result-2.p), index=1
-
-    ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done-2.1.q), index=0
-  }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr));
-  PipelinedP2PRewriter rewriter;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
-  // Check that we transform the module without crashing.
-  EXPECT_TRUE(changed);
-}
-
-// Tests the rewrite for a pipelined Send/Recv chain with two channel groups.
-TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined2) {
-  const char* kModuleStr = R"(
-  HloModule test, is_scheduled=true
-
-  while-cond {
-    param = (u32[], (f32[1,1024,1024], token[]), token[],
-      (f32[1,1024,1024], token[]), token[]) parameter(0)
-    count = get-tuple-element(param), index=0
-    ub = u32[] constant(25)
-    ROOT cond-result = pred[] compare(count, ub), direction=LT
-  }
-
-  while-body {
-    param = (u32[], (f32[1,1024,1024], token[]), token[],
-      (f32[1,1024,1024], token[]), token[]) parameter(0)
-    count = get-tuple-element(param), index=0
-
-    recv-done.0.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
-    recv-data.0 = f32[1, 1024, 1024] get-tuple-element(recv-done.0.q), index=0
-    recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=3
-    recv-data.1 = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0
-
-    replica = u32[] replica-id()
-    constant0 = u32[] constant(0)
-    compare0 = pred[] compare(replica, constant0), direction=EQ
-    compare = pred[1, 1024, 1024] broadcast(compare0), dimensions={}
-    recv-data = f32[1, 1024, 1024] select(compare, recv-data.0, recv-data.1)
-
-    c1 = u32[] constant(1)
-    new-count = u32[] add(count, c1)
-    c10 = u32[] constant(10)
-    sum = u32[] add(replica, c10)
-    sum2 = u32[] add(sum, count)
-    conv = f32[] convert(sum2)
-    p = f32[1, 1024, 1024] broadcast(conv), dimensions={}
-    b = f32[1, 1024, 1024] add(p, recv-data)
-    c = f32[1, 1024, 1024] multiply(b, b)
-    d = f32[1, 1024, 1024] tan(c)
-    s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0},
-      lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-    send-data = f32[1, 1024, 1024] add(c, s)
-
-    after-all = token[] after-all()
-    recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{3,0}}",
-        _xla_send_recv_pipeline="0"
-      }
-    send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
-      channel_id=1, frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{3,0}}",
-        _xla_send_recv_pipeline="0"
-      }
-    recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    send-done.p = token[] send-done(send), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-
-    after-all.1 = token[] after-all()
-    recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=2,
-      frontend_attributes={
-        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
-        _xla_send_recv_pipeline="1"
-      }
-    send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1),
-      channel_id=2, frontend_attributes={
-       _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
-       _xla_send_recv_pipeline="1"
-      }
-    recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=2,
-      frontend_attributes={
-        _xla_send_recv_pipeline="1"
-      }
-    send-done.1.p = token[] send-done(send.1), channel_id=2,
-      frontend_attributes={
-        _xla_send_recv_pipeline="1"
-      }
-
-    ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[],
-      (f32[1,1024,1024], token[]), token[])
-      tuple(new-count, recv-done.p, send-done.p, recv-done.1.p, send-done.1.p)
-  }
-
-  ENTRY main {
-    c0 = u32[] constant(0)
-    f0 = f32[] constant(0.0)
-    init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
-
-    after-all.2 = token[] after-all()
-    recv.2 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.2), channel_id=1,
-      frontend_attributes={
-       _xla_send_recv_source_target_pairs="{{3,0}}",
-       _xla_send_recv_pipeline="0"
-    }
-    send.2 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.2), channel_id=1,
-      frontend_attributes={
-       _xla_send_recv_source_target_pairs="{{3,0}}",
-       _xla_send_recv_pipeline="0"
-    }
-    recv-done.2.p = (f32[1,1024,1024], token[]) recv-done(recv.2), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-    send-done.2.p = token[] send-done(send.2), channel_id=1,
-      frontend_attributes={
-        _xla_send_recv_pipeline="0"
-      }
-
-    after-all.3 = token[] after-all()
-    recv.3 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.3), channel_id=2,
-      frontend_attributes={
-       _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
-       _xla_send_recv_pipeline="1"
-    }
-    send.3 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.3), channel_id=2,
-      frontend_attributes={
-       _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
-       _xla_send_recv_pipeline="1"
-    }
-    recv-done.3.p = (f32[1,1024,1024], token[]) recv-done(recv.3), channel_id=2,
-      frontend_attributes={
-        _xla_send_recv_pipeline="1"
-      }
-    send-done.3.p = token[] send-done(send.3), channel_id=2,
-      frontend_attributes={
-        _xla_send_recv_pipeline="1"
-      }
-
-    while-init.p =  (u32[], (f32[1,1024,1024], token[]), token[],
-      (f32[1,1024,1024], token[]), token[]) tuple(c0, recv-done.2.p, send-done.2.p, recv-done.3.p, send-done.3.p)
-    while-result.p = (u32[], (f32[1,1024,1024], token[]), token[],
-      (f32[1,1024,1024], token[]), token[]) while(while-init.p),
-      body=while-body, condition=while-cond,
-      backend_config={"known_trip_count":{"n":"25"}}
-
-    recv-done.2.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1
-    recv-data.2 = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0
-    recv-done.3.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=3
-    recv-data.3 = f32[1, 1024, 1024] get-tuple-element(recv-done.3.q), index=0
-
-    replica = u32[] replica-id()
-    constant0 = u32[] constant(0)
-    compare0 = pred[] compare(replica, constant0), direction=EQ
-    compare = pred[1, 1024, 1024] broadcast(compare0), dimensions={}
-    ROOT entry-result = f32[1, 1024, 1024] select(compare, recv-data.2, recv-data.3)
-  }
-  )";
-
-  const char* kExpected = R"(
-  CHECK: %while-body (param.1: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[])) {
-  CHECK: %param.1 = parameter(0)
-  CHECK: %get-tuple-element = get-tuple-element(%param.1), index=1
-  CHECK: %get-tuple-element.1 = get-tuple-element(%param.1), index=2
-  CHECK: %get-tuple-element.2 = get-tuple-element(%param.1), index=3
-  CHECK: %get-tuple-element.3 = get-tuple-element(%param.1), index=4
-  CHECK: %count.1 = get-tuple-element(%param.1), index=0
-  CHECK: %recv-done.p.clone = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
-  CHECK: %recv-data.0 = get-tuple-element(%recv-done.p.clone), index=0
-  CHECK: %recv-done.1.p.clone = recv-done(%get-tuple-element.2), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
-  CHECK: %recv-data.1 = get-tuple-element(%recv-done.1.p.clone), index=0
-  CHECK: %replica = replica-id()
-  CHECK: %constant0 = constant(0)
-  CHECK: %compare0 = compare(%replica, %constant0), direction=EQ
-  CHECK: %compare = broadcast(%compare0), dimensions={}
-  CHECK: %recv-data.2 = select(%compare, %recv-data.0, %recv-data.1)
-  CHECK: %c1 = constant(1)
-  CHECK: %new-count = add(%count.1, %c1)
-  CHECK: %c10 = constant(10)
-  CHECK: %sum = add(%replica, %c10)
-  CHECK: %sum2 = add(%sum, %count.1)
-  CHECK: %conv = convert(%sum2)
-  CHECK: %p = broadcast(%conv), dimensions={}
-  CHECK: %b = add(%p, %recv-data.2)
-  CHECK: %c = multiply(%b, %b)
-  CHECK: %d = tan(%c)
-  CHECK: %s = dot(%c, %d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
-  CHECK: %send-data = add(%c, %s)
-  CHECK: %after-all = after-all()
-  CHECK: %send-done.p.clone = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
-  CHECK: %send-done.1.p.clone = send-done(%get-tuple-element.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
-  CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"}
-  CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"}
-  CHECK: %after-all.1 = after-all()
-  CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"}
-  CHECK{LITERAL}: %send.1 = send(%send-data, %after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"}
-  CHECK: ROOT %tuple = tuple(%new-count, %recv, %send, %recv.1, %send.1)
-  CHECK: }
-
-  CHECK: %while-cond (param: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> pred[] {
-  CHECK: %param = parameter(0)
-  CHECK: %count = get-tuple-element(%param), index=0
-  CHECK: %ub = constant(25)
-  CHECK: ROOT %cond-result = compare(%count, %ub), direction=LT
-  CHECK: }
-
-  CHECK: ENTRY %main () -> f32[1,1024,1024] {
-  CHECK: %c0 = constant(0)
-  CHECK: %f0 = constant(0)
-  CHECK: %init = broadcast(%f0), dimensions={}
-  CHECK: %after-all.2 = after-all()
-  CHECK{LITERAL}: %recv.2 = recv(%after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"}
-  CHECK{LITERAL}: %send.2 = send(%init, %after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"}
-  CHECK: %after-all.3 = after-all()
-  CHECK{LITERAL}: %recv.3 = recv(%after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"}
-  CHECK{LITERAL}: %send.3 = send(%init, %after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"}
-  CHECK: %while-init = tuple(%c0, %recv.2, %send.2, %recv.3, %send.3)
-  CHECK{LITERAL}: %while-result.p.clone = while(%while-init), condition=%while-cond, body=%while-body, backend_config={"known_trip_count":{"n":"25"}}
-  CHECK: %get-tuple-element.4 = get-tuple-element(%while-result.p.clone), index=1
-  CHECK: %get-tuple-element.5 = get-tuple-element(%while-result.p.clone), index=2
-  CHECK: %get-tuple-element.6 = get-tuple-element(%while-result.p.clone), index=3
-  CHECK: %get-tuple-element.7 = get-tuple-element(%while-result.p.clone), index=4
-  CHECK: %recv-done.2.p.clone = recv-done(%get-tuple-element.4), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
-  CHECK: %recv-data.3 = get-tuple-element(%recv-done.2.p.clone), index=0
-  CHECK: %recv-done.3.p.clone = recv-done(%get-tuple-element.6), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
-  CHECK: %recv-data.4 = get-tuple-element(%recv-done.3.p.clone), index=0
-  CHECK: %replica.1 = replica-id()
-  CHECK: %constant0.1 = constant(0)
-  CHECK: %compare0.1 = compare(%replica.1, %constant0.1), direction=EQ
-  CHECK: %compare.1 = broadcast(%compare0.1), dimensions={}
-  CHECK: %send-done.2.p.clone = send-done(%get-tuple-element.5), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
-  CHECK: %send-done.3.p.clone = send-done(%get-tuple-element.7), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
-  CHECK: ROOT %entry-result = select(%compare.1, %recv-data.3, %recv-data.4)
-  CHECK: })";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr));
-  PipelinedP2PRewriter rewriter;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  DoFileCheck(module.get(), kExpected);
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc
index 928f4bf..7ced52c 100644
--- a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc
+++ b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc
@@ -22,10 +22,10 @@
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/copy_insertion.h"
 #include "xla/service/cpu_gpu_shape_verifier.h"
-#include "xla/service/gpu/alias_passthrough_params.h"
-#include "xla/service/gpu/copy_fusion.h"
-#include "xla/service/gpu/gpu_sanitize_constant_names.h"
-#include "xla/service/gpu/horizontal_loop_fusion.h"
+#include "xla/service/gpu/transforms/alias_passthrough_params.h"
+#include "xla/service/gpu/transforms/copy_fusion.h"
+#include "xla/service/gpu/transforms/horizontal_loop_fusion.h"
+#include "xla/service/gpu/transforms/sanitize_constant_names.h"
 #include "xla/service/hlo_dataflow_analysis.h"
 #include "xla/service/hlo_dce.h"
 #include "xla/service/hlo_pass_pipeline.h"
@@ -78,14 +78,14 @@
   }
 
   // We are using a sub-pipeline here, so that the verifier only runs after both
-  // GpuHorizontalLoopFusion and HloDCE.
+  // HorizontalLoopFusion and HloDCE.
   auto& sub_pipeline =
       pipeline.AddPass<HloPassPipeline>("horizontal-loop-fusion-for-copy");
   // To fuse the copy.
   sub_pipeline.AddPass<CopyFusion>();
-  sub_pipeline.AddPass<GpuHorizontalLoopFusion>("copy_");
+  sub_pipeline.AddPass<HorizontalLoopFusion>("copy_");
   sub_pipeline.AddPass<HloDCE>();
-  pipeline.AddPass<GpuSanitizeConstantNames>();
+  pipeline.AddPass<SanitizeConstantNames>();
   return pipeline;
 }
 
diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc
deleted file mode 100644
index d57a83c..0000000
--- a/third_party/xla/xla/service/gpu/priority_fusion.cc
+++ /dev/null
@@ -1,871 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/priority_fusion.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <functional>
-#include <iterator>
-#include <limits>
-#include <map>
-#include <memory>
-#include <string>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/meta/type_traits.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/time/time.h"
-#include "llvm/ADT/STLExtras.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/debug_options_flags.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/dump.h"
-#include "xla/service/fusion_queue.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/fusion_process_dump.pb.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/model/fusion_analysis_cache.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/gpu/model/gpu_performance_model.h"
-#include "xla/service/gpu/model/gpu_performance_model_base.h"
-#include "xla/service/gpu/model/symbolic_tile_analysis.h"
-#include "xla/service/hlo_graph_dumper.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/blocking_counter.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/threadpool.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-bool ElementIsF32OrF16(const Shape& shape) {
-  PrimitiveType type = shape.element_type();
-  return type == F32 || type == F16;
-}
-
-bool IsFusible(const HloInstruction& instr) {
-  // Side-effecting operations are not fusible.
-  if (!instr.IsFusible()) {
-    return false;
-  }
-
-  // Element-wise operations are always fusible.
-  if (instr.IsElementwise()) {
-    return true;
-  }
-
-  // Other non-elementwise ops also supported by elemental fusion.
-  switch (instr.opcode()) {
-    case HloOpcode::kFusion:
-      return instr.fusion_kind() != HloInstruction::FusionKind::kCustom;
-
-    case HloOpcode::kCopy:
-    case HloOpcode::kIota:
-    case HloOpcode::kConstant:
-    case HloOpcode::kReduce:
-    case HloOpcode::kBitcast:
-    case HloOpcode::kBroadcast:
-    case HloOpcode::kConcatenate:
-    case HloOpcode::kDynamicSlice:
-    case HloOpcode::kDynamicUpdateSlice:
-    case HloOpcode::kGather:
-    case HloOpcode::kPad:
-    case HloOpcode::kReduceWindow:
-    case HloOpcode::kReshape:
-    case HloOpcode::kReverse:
-    case HloOpcode::kScatter:
-    case HloOpcode::kSlice:
-    case HloOpcode::kTranspose:
-      return true;
-    default:
-      return false;
-  }
-}
-
-// An implementation of FusionQueue that determines whether to fuse instructions
-// according to a cost model, and chooses the next fusion candidate according to
-// dynamically updated priorities. The elements in the queue are producer nodes
-// that could be fused, and the priority of a producer is the benefit in
-// performance when fusing it to all of its fusible users. We greedily pick the
-// max-benefit producer to fuse, and update the estimated benefits of the fused
-// nodes and their operands.
-class GpuPriorityFusionQueue {
-  using Priority = int64_t;
-  using CanFuseCallback = std::function<FusionDecision(
-      HloInstruction* /*producer*/, int64_t /*consumer operand_index*/)>;
-
- public:
-  GpuPriorityFusionQueue(
-      HloComputation* computation,
-      const GpuHloCostAnalysis::Options& cost_analysis_options,
-      const se::DeviceDescription* device_info,
-      FusionProcessDumpProto* fusion_process_dump,
-      tsl::thread::ThreadPool* thread_pool, mlir::MLIRContext* mlir_context,
-      HloFusionAnalysisCache& fusion_analysis_cache,
-      bool triton_softmax_priority_fusion_enabled)
-      : computation_(computation),
-        device_info_(device_info),
-        cost_analysis_(cost_analysis_options, *device_info),
-        fusion_process_dump_(fusion_process_dump),
-        thread_pool_(thread_pool),
-        mlir_context_(mlir_context),
-        fusion_analysis_cache_(fusion_analysis_cache),
-        triton_softmax_priority_fusion_enabled_(
-            triton_softmax_priority_fusion_enabled) {
-    VLOG(2) << "Running full HLO cost analysis for " << computation_->name();
-    TF_CHECK_OK(computation_->Accept(&cost_analysis_));
-
-    dump_fusion_visualization_ = computation->parent()
-                                     ->config()
-                                     .debug_options()
-                                     .xla_dump_fusion_visualization();
-
-    // Initializes the priority queue.
-    std::vector<HloInstruction*> instructions;
-    for (auto* instruction : computation->MakeInstructionPostOrder()) {
-      if (instruction->opcode() == HloOpcode::kParameter ||
-          instruction->user_count() == 0 || !instruction->IsFusible() ||
-          instruction->opcode() == HloOpcode::kTuple ||
-          instruction->opcode() == HloOpcode::kGetTupleElement) {
-        continue;
-      }
-      instructions.push_back(instruction);
-    }
-
-    ComputeAndSetPriorities(instructions);
-  }
-
-  void ComputeAndSetPriorities(
-      const std::vector<HloInstruction*>& instructions) {
-    std::vector<Priority> priorities = ComputePriorities(instructions);
-
-    for (auto [instruction, priority] : llvm::zip(instructions, priorities)) {
-      auto key = std::make_pair(priority, instruction->unique_id());
-
-      // Remove instruction with the old priority from the queue.
-      auto reverse_it = reverse_map_.find(instruction);
-      if (reverse_it != reverse_map_.end()) {
-        const PriorityQueue::iterator& queue_it = reverse_it->second;
-        // Priority didn't change. Nothing to do.
-        if (key == queue_it->first) {
-          continue;
-        }
-        producer_priority_queue_.erase(queue_it);
-        reverse_map_.erase(reverse_it);
-      }
-
-      // If the priority is negative, it's not helpful to perform fusion on this
-      // instruction.
-      if (priority < 0) {
-        continue;
-      }
-
-      auto emplace_result = producer_priority_queue_.emplace(key, instruction);
-      reverse_map_.emplace(instruction, emplace_result.first);
-    }
-  }
-
-  std::vector<Priority> ComputePriorities(
-      const std::vector<HloInstruction*>& instructions) {
-    auto schedule_or_run = [this](std::function<void()> fn) {
-      if (thread_pool_) {
-        thread_pool_->Schedule(std::move(fn));
-      } else {
-        fn();
-      }
-    };
-    tsl::BlockingCounter counter(instructions.size());
-    std::vector<Priority> priorities(instructions.size());
-
-    for (size_t i = 0; i < instructions.size(); ++i) {
-      schedule_or_run([&, i] {
-        priorities[i] = CalculateProducerPriority(instructions[i]);
-        counter.DecrementCount();
-      });
-    }
-    counter.Wait();
-    return priorities;
-  }
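As an aside, the fan-out-and-wait pattern in `ComputePriorities` above (schedule one task per instruction, then block until all have finished) can be illustrated with standard-library primitives alone. This is only a sketch: `ComputeAll` is a hypothetical name, and `std::thread` stands in for `tsl::thread::ThreadPool` plus `tsl::BlockingCounter`.

#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Fan out one task per item, then wait for all of them before returning.
std::vector<long> ComputeAll(const std::vector<int>& items,
                             const std::function<long(int)>& compute) {
  std::vector<long> results(items.size());
  std::vector<std::thread> workers;
  workers.reserve(items.size());
  for (size_t i = 0; i < items.size(); ++i) {
    // Capture `i` by value so each worker writes its own slot.
    workers.emplace_back([&, i] { results[i] = compute(items[i]); });
  }
  for (std::thread& t : workers) t.join();  // Block until every task is done.
  return results;
}

int main() {
  std::vector<int> items = {1, 2, 3};
  auto results =
      ComputeAll(items, [](int x) { return static_cast<long>(x) * x; });
  return results[2] == 9 ? 0 : 1;
}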
-
-  // Gets the next pair of (producer, consumers) from the queue for fusion.
-  // Returns true if there is another producer to fuse, otherwise false. Stores
-  // the producer and consumers in `current_producer_` and `current_consumers_`.
-  bool DequeueNextProducer() {
-    current_producer_ = nullptr;
-    current_consumers_.clear();
-
-    while (!producer_priority_queue_.empty() && current_consumers_.empty()) {
-      auto next_it = std::prev(producer_priority_queue_.end());
-
-      current_producer_ = next_it->second;
-      producer_priority_queue_.erase(next_it);
-      reverse_map_.erase(current_producer_);
-
-      current_consumers_ = current_producer_->users();
-
-      if (current_producer_->opcode() == HloOpcode::kBitcast) {
-        // We don't check if bitcasts can be fused with all consumers, so we
-        // have to do it here.
-        llvm::erase_if(current_consumers_, [&](HloInstruction* consumer) {
-          return !CanFuseCached(current_producer_, consumer);
-        });
-      }
-    }
-
-    return !current_consumers_.empty();
-  }
-
-  // Update priorities of all affected ops.
-  void UpdatePriorities() {
-    // Revisit costs of all updated ops. It's important to update cost analysis
-    // before recalculating priorities.
-    for (auto instruction : to_update_priority_) {
-      TF_CHECK_OK(cost_analysis_.RevisitInstruction(instruction));
-    }
-
-    ComputeAndSetPriorities(std::vector<HloInstruction*>{
-        to_update_priority_.begin(), to_update_priority_.end()});
-
-    to_update_priority_.clear();
-  }
-
-  // Prepares the producer and consumer instructions to be fused. Invalidates
-  // caches and writes logs.
-  void PreFusion(HloInstruction* producer, HloInstruction* consumer) {
-    if (dump_fusion_visualization_) {
-      RegisterFusionState(
-          *computation_,
-          absl::StrCat("About to fuse |", producer->name(), "| into |",
-                       consumer->name(), "| inside PriorityFusion"),
-          *consumer, producer);
-    }
-
-    InvalidateCaches(producer);
-    InvalidateCaches(consumer);
-  }
-
-  // Invalidates all cached values related to this instruction. Called before the
-  // instruction is fused. The instruction can be either producer or consumer.
-  void InvalidateCaches(HloInstruction* instruction) {
-    can_fuse_cache_.erase(instruction);
-    for (const HloInstruction* operand : instruction->operands()) {
-      auto it = can_fuse_cache_.find(operand);
-      if (it != can_fuse_cache_.end()) {
-        it->second.erase(instruction);
-      }
-    }
-
-    gpu_performance_model_cache_.Invalidate(*instruction);
-    fusion_analysis_cache_.Invalidate(*instruction);
-    fusion_info_cache_.Invalidate(instruction);
-  }
-
-  // Updates data for the new fusion instruction and its users and operands.
-  void OnFusingInstruction(HloInstruction* fusion,
-                           HloInstruction* original_producer,
-                           HloInstruction* original_consumer) {
-    if (fusion_process_dump_) {
-      auto* fusion_step =
-          fusion_process_dump_->add_fusion_steps()->mutable_fusion();
-
-      // Explicit std::string is needed for OSS proto implementation.
-      fusion_step->set_fusion_name(std::string(fusion->name()));
-      fusion_step->set_producer_name(std::string(original_producer->name()));
-      fusion_step->set_consumer_name(std::string(original_consumer->name()));
-    }
-
-    if (dump_fusion_visualization_) {
-      RegisterFusionState(
-          *computation_,
-          absl::StrCat("Fused |", original_producer->name(), "| into |",
-                       fusion->name(), "| inside PriorityFusion"),
-          *fusion);
-    }
-
-    // The original consumer was replaced with the fusion, but its pointer can
-    // still be referenced somewhere, for example, in to_update_priority_.
-    // Priority recomputation is called before DCE. Remove all references to
-    // the original consumer here.
-    if (fusion != original_consumer) {
-      RemoveInstruction(original_consumer);
-    }
-
-    // Detach 'original_producer' from its operands if it has no users.
-    // This avoids having it appear as a "phantom" user in subsequent priority
-    // calculations on 'fusion.operands' below, before it is finally removed
-    // in 'RemoveInstruction'.
-    if (original_producer->user_count() == 0) {
-      original_producer->DetachFromOperandsAndUsers();
-    }
-
-    // Collect the instructions whose priorities need to be updated.
-    for (HloInstruction* operand : fusion->operands()) {
-      if (operand == original_producer ||
-          operand->opcode() == HloOpcode::kConstant ||
-          operand->opcode() == HloOpcode::kGetTupleElement) {
-        continue;
-      }
-      // Need to consider only instructions that are fusible, e.g., an rng with
-      // more than one user is not fusible.
-      if (!operand->IsFusible()) {
-        continue;
-      }
-
-      to_update_priority_.insert(operand);
-    }
-    to_update_priority_.insert(fusion);
-  }
-
-  // Removes data for the instruction.
-  void RemoveInstruction(HloInstruction* instruction) {
-    to_update_priority_.erase(instruction);
-    fusion_analysis_cache_.Invalidate(*instruction);
-
-    auto reverse_it = reverse_map_.find(instruction);
-    if (reverse_it == reverse_map_.end()) {
-      return;
-    }
-    producer_priority_queue_.erase(reverse_it->second);
-    reverse_map_.erase(reverse_it);
-  }
-
-  HloInstruction* current_producer() { return current_producer_; }
-
-  const std::vector<HloInstruction*>& current_consumers() {
-    return current_consumers_;
-  }
-
- private:
-  // Returns the priority of the producer based on its current operands and
-  // users.
-  Priority CalculateProducerPriority(HloInstruction* producer) {
-    // Bitcasts should always be fused first, since they are no-ops.
-    if (producer->opcode() == HloOpcode::kBitcast) {
-      return std::numeric_limits<Priority>::max();
-    }
-    // We always fuse constants, but the cost model doesn't handle them very
-    // well: fusing constants changes costs significantly. Also, there's no
-    // point recomputing priorities. Therefore, we fuse all of them at the end.
-    if (producer->opcode() == HloOpcode::kConstant) {
-      return std::numeric_limits<Priority>::min();
-    }
-
-    // Don't fuse if we can't fuse into all users.
-    if (auto fusion_decision = CanFuseWithAllNonBitcastUsers(producer);
-        !fusion_decision) {
-      if (fusion_process_dump_) {
-        absl::MutexLock lock(&fusion_process_dump_mutex_);
-        auto* step = fusion_process_dump_->add_fusion_steps()
-                         ->mutable_producer_ineligible();
-        step->set_producer_name(std::string(producer->name()));
-        step->set_reason(fusion_decision.Explain());
-      }
-      return std::numeric_limits<Priority>::min();
-    }
-
-    GpuPerformanceModel::RunTimes run_times =
-        GpuPerformanceModel::EstimateRunTimesForPriorityFusion(
-            producer, *device_info_, &cost_analysis_,
-            GpuPerformanceModelOptions::PriorityFusion(
-                &fusion_analysis_cache_, &gpu_performance_model_cache_),
-            producer->users());
-
-    if (fusion_process_dump_) {
-      absl::MutexLock lock(&fusion_process_dump_mutex_);
-      auto* step =
-          fusion_process_dump_->add_fusion_steps()->mutable_update_priority();
-      step->set_producer_name(std::string(producer->name()));
-      for (auto* consumer : producer->users()) {
-        step->add_consumer_names(std::string(consumer->name()));
-      }
-      step->set_us_fused(absl::ToDoubleMicroseconds(run_times.time_fused));
-      step->set_us_unfused(absl::ToDoubleMicroseconds(run_times.time_unfused));
-    }
-    return absl::ToInt64Nanoseconds(run_times.time_unfused -
-                                    run_times.time_fused);
-  }
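To make the priority arithmetic above concrete with hypothetical numbers: if the performance model estimates time_unfused = 105 µs and time_fused = 95 µs for a producer and its users, the producer's priority is 105,000 ns − 95,000 ns = 10,000 ns. A producer whose fusion is estimated to slow things down gets a negative priority and is dropped from the queue by ComputeAndSetPriorities, while bitcasts and constants bypass the model entirely via the max/min sentinels above.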
-
-  FusionDecision CanFuseTriton(HloInstruction* producer,
-                               HloInstruction* consumer) {
-    if (!triton_softmax_priority_fusion_enabled_) {
-      return "triton softmax fusion is not enabled";
-    }
-
-    if (IsGenericTritonFusion(*producer)) {
-      if (!IsFusible(*consumer)) {
-        return "the consumer is not fusible";
-      }
-    } else {
-      if (!IsFusible(*producer)) {
-        return "the producer is not fusible";
-      }
-    }
-
-    auto fusion = HloFusionAdaptor::ForProducerConsumer(producer, consumer);
-
-    SymbolicTileAnalysisOrError symbolic_tile_analysis_or =
-        SymbolicTileAnalysis::AnalyzeFusion(*fusion, mlir_context_);
-
-    if (const auto* fusion_decision =
-            std::get_if<FusionDecision>(&symbolic_tile_analysis_or)) {
-      return {
-          absl::StrCat("Fusion can not be tiled with SymbolicTileAnalysis: ",
-                       fusion_decision->Explain())};
-    }
-
-    return {};
-  }
-
-  FusionDecision CanFuse(HloInstruction* producer, HloInstruction* consumer) {
-    if (IsGenericTritonFusion(*producer) || IsGenericTritonFusion(*consumer)) {
-      return CanFuseTriton(producer, consumer);
-    }
-
-    if (!IsFusible(*producer)) {
-      return "the producer is not fusible";
-    }
-
-    if (!IsFusible(*consumer)) {
-      return "the consumer is not fusible";
-    }
-
-    if (consumer->opcode() == HloOpcode::kBitcast) {
-      return "not fusing into a single bitcast as consumer";
-    }
-
-    // Scatter is special as it has no elemental version but is still input
-    // fusible. Block attempts to create scatter fusions we can't codegen.
-    if (auto can_fuse = CanEmitInputFusedScatter(*producer, *consumer);
-        !can_fuse) {
-      return can_fuse;
-    }
-
-    // Avoid fusing reduce into reduce. Our cost model doesn't currently
-    // understand this case due to a lack of tiling analysis.
-    // TODO(b/312200883): Remove this.
-    auto contains_significant_reduce = [&](const HloInstruction* instr) {
-      auto fusion = HloFusionAdaptor::ForInstruction(instr);
-      return HloAnyOf(*fusion, [](auto node) {
-        if (!(node.opcode() == HloOpcode::kReduce && node.shape().IsArray())) {
-          return false;
-        }
-
-        int64_t reduction_size =
-            ShapeUtil::ElementsIn(node.instruction().operand(0)->shape()) /
-            ShapeUtil::ElementsIn(node.shape());
-
-        // Small reductions are emitted using the elemental emitter anyway.
-        return reduction_size >= 16;
-      });
-    };
-    if (contains_significant_reduce(producer) &&
-        contains_significant_reduce(consumer)) {
-      return "both the producer and the consumer contain a reduce";
-    }
-
-    // Avoid doing fusions into the output of an "input" fusion when it would
-    // switch it to the loop emitter. This often occurs during epilog fusion for
-    // reductions, which suffer from limited emitter support.
-    // TODO(b/312686229): Cost model should handle this.
-    const auto& analysis = fusion_analysis_cache_.Get(*producer);
-    if (analysis.GetEmitterFusionKind() ==
-        HloFusionAnalysis::EmitterFusionKind::kReduction) {
-      const auto& analysis_fused =
-          fusion_analysis_cache_.Get(*producer, *consumer);
-      if (analysis_fused.GetEmitterFusionKind() ==
-          HloFusionAnalysis::EmitterFusionKind::kLoop) {
-        return "fusion into output of a reduce fusion would create a loop "
-               "fusion";
-      }
-    }
-
-    // Avoid cases where we'd create a fusion that hits limitations in ptxas.
-    // Would be nice to model this with cost instead.
-    if (auto fits_budget = FusionFitsInBudget(
-            *consumer, *producer, *device_info_,
-            /*is_consumer_producer_fusion=*/true, &fusion_info_cache_);
-        !fits_budget) {
-      return fits_budget;
-    }
-
-    // Also check that our emitter can handle the fusion node. We currently can
-    // have exponential time/memory requirements for emitting certain fusion
-    // kernels, in which case we don't want to fuse.
-    // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
-    if (cost_analysis_.ProducerConsumerMergedTooLarge(*producer, *consumer)) {
-      return "the fusion would result in an overly large code duplication";
-    }
-
-    // Don't fuse across a root instruction. There are situations when a root
-    // instruction is not the last in the computation. Instructions after the
-    // root are not necessarily dead. They can be inputs to instructions with side
-    // effects, like outfeed.
-    if (producer == producer->parent()->root_instruction()) {
-      return "not fusing into the output of the root instruction";
-    }
-
-    return InstructionFusion::ShouldFuseInPlaceOp(producer, consumer);
-  }
-
-  FusionDecision CanFuseCached(HloInstruction* producer,
-                               HloInstruction* consumer) {
-    {
-      absl::MutexLock lock(&can_fuse_cache_mutex_);
-      auto& producer_cache = can_fuse_cache_[producer];
-
-      auto it = producer_cache.find(consumer);
-      if (it != producer_cache.end()) {
-        return it->second;
-      }
-    }
-
-    auto fusion_decision = CanFuse(producer, consumer);
-
-    // The lock is required because writing to a flat_hash_map is not
-    // thread-safe even for different keys. We never call this computation
-    // concurrently for the same producer, so it's guaranteed that we don't
-    // overwrite any value.
-    {
-      absl::MutexLock lock(&can_fuse_cache_mutex_);
-      can_fuse_cache_[producer][consumer] = fusion_decision;
-    }
-
-    return fusion_decision;
-  }
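The comment in `CanFuseCached` above explains why the decision is computed outside the lock and written back under it. The following standalone sketch illustrates that discipline only; `DecisionCache`, `GetOrCompute`, and `ExpensiveDecision` are hypothetical names, and `std::mutex` stands in for `absl::Mutex`.

#include <mutex>
#include <string>
#include <unordered_map>

// Look up under the lock, compute the expensive result without holding it,
// then store under the lock. The map is always accessed under the mutex,
// because hash-map writes are not thread-safe even for distinct keys.
class DecisionCache {
 public:
  std::string GetOrCompute(int key) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      auto it = cache_.find(key);
      if (it != cache_.end()) return it->second;
    }
    // Computed outside the lock; a duplicate computation for the same key
    // would simply store an equal value again.
    std::string decision = ExpensiveDecision(key);
    {
      std::lock_guard<std::mutex> lock(mutex_);
      cache_[key] = decision;
    }
    return decision;
  }

 private:
  static std::string ExpensiveDecision(int key) {
    return key % 2 == 0 ? "" : "odd keys are not fusible";  // Placeholder.
  }
  std::mutex mutex_;
  std::unordered_map<int, std::string> cache_;
};

int main() {
  DecisionCache cache;
  return cache.GetOrCompute(3) == cache.GetOrCompute(3) ? 0 : 1;
}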
-
-  FusionDecision CanFuseWithAllNonBitcastUsers(HloInstruction* producer) {
-    if (producer->users().empty()) {
-      return "No users to fuse";
-    }
-
-    FusionDecision result;
-    bool has_non_bitcast_user = false;
-    for (const auto& user : producer->users()) {
-      if (user->opcode() == HloOpcode::kBitcast) {
-        continue;
-      }
-      has_non_bitcast_user = true;
-      if (auto fusion_decision = CanFuseCached(producer, user);
-          !fusion_decision) {
-        VLOG(10) << "Cannot fuse " << producer->name() << " with "
-                 << user->name() << ", because: " << fusion_decision.Explain();
-        return fusion_decision;
-      }
-    }
-    if (!has_non_bitcast_user) {
-      return "not fusing because there are only bitcast users";
-    }
-    return {};
-  }
-
-  // Store computation for cost analysis.
-  HloComputation* computation_;
-
-  const se::DeviceDescription* device_info_;
-
-  // Reference to cost model that defines priorities in the queue.
-  GpuHloCostAnalysis cost_analysis_;
-
-  // The priority queue of producers, implemented as an ordered map, where a
-  // key is a pair: the first element is the priority and the second element is
-  // the unique ID of the instruction to break ties.
-  using PriorityQueue = std::map<std::pair<Priority, int>, HloInstruction*>;
-  PriorityQueue producer_priority_queue_;
-
-  // A reverse map that helps find an instruction in the priority queue.
-  absl::flat_hash_map<HloInstruction*, PriorityQueue::iterator> reverse_map_;
-
-  // The current producer being visited.
-  HloInstruction* current_producer_;
-
-  // The current consumers being visited.
-  std::vector<HloInstruction*> current_consumers_;
-
-  // The set of producers whose priorities need to be updated. Their
-  // priorities are changed because their neighbors got fused, but we delay
-  // the priority updates until current_consumers_ becomes empty. This is to
-  // avoid recomputing priorities multiple times before we dequeue a new
-  // producer.
-  absl::flat_hash_set<HloInstruction*> to_update_priority_;
-
-  // Proto with structured logs of fusion decisions. Used only for debugging. If
-  // null, logging is disabled.
-  FusionProcessDumpProto* fusion_process_dump_;
-  absl::Mutex fusion_process_dump_mutex_;
-
-  tsl::thread::ThreadPool* thread_pool_;
-
-  mlir::MLIRContext* mlir_context_;
-
-  HloFusionAnalysisCache& fusion_analysis_cache_;
-
-  // Caches result of can_fuse for a (producer, consumer) pair. A cache entry is
-  // invalidated if producer or consumer is modified.
-  absl::flat_hash_map<
-      const HloInstruction*,
-      absl::flat_hash_map<const HloInstruction*, FusionDecision>>
-      can_fuse_cache_;
-  absl::Mutex can_fuse_cache_mutex_;
-
-  GpuPerformanceModelCache gpu_performance_model_cache_;
-
-  // Cache for `FusionFitsInBudget` to avoid recomputing expensive properties
-  // like shared memory usage or number of unnested reductions of fusion nodes.
-  FusionInfoCache fusion_info_cache_;
-
-  bool triton_softmax_priority_fusion_enabled_;
-
-  bool dump_fusion_visualization_;
-};
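The member comments above describe the queue's core data structure: an ordered map keyed by (priority, unique id) together with a reverse map used to locate and reprioritize entries (see ComputeAndSetPriorities and RemoveInstruction). The following self-contained sketch illustrates that pattern in isolation; `Node`, `UpdatablePriorityQueue`, `SetPriority`, and `PopBest` are hypothetical names, not part of this pass.

#include <cstdint>
#include <iterator>
#include <map>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for HloInstruction*; only the unique id matters here.
struct Node {
  int unique_id;
};

// An updatable priority queue: an ordered map keyed by (priority, unique id)
// for max-priority extraction, plus a reverse map from node to map iterator so
// an existing entry can be found and replaced in O(log n).
class UpdatablePriorityQueue {
 public:
  void SetPriority(Node* node, int64_t priority) {
    const auto key = std::make_pair(priority, node->unique_id);
    auto reverse_it = reverse_map_.find(node);
    if (reverse_it != reverse_map_.end()) {
      if (reverse_it->second->first == key) return;  // Priority unchanged.
      queue_.erase(reverse_it->second);
      reverse_map_.erase(reverse_it);
    }
    auto emplaced = queue_.emplace(key, node);
    reverse_map_.emplace(node, emplaced.first);
  }

  // Removes and returns the highest-priority node, or nullptr if empty.
  Node* PopBest() {
    if (queue_.empty()) return nullptr;
    auto best = std::prev(queue_.end());
    Node* node = best->second;
    reverse_map_.erase(node);
    queue_.erase(best);
    return node;
  }

 private:
  using Queue = std::map<std::pair<int64_t, int>, Node*>;
  Queue queue_;
  std::unordered_map<Node*, Queue::iterator> reverse_map_;
};

int main() {
  Node a{1}, b{2};
  UpdatablePriorityQueue queue;
  queue.SetPriority(&a, 10);
  queue.SetPriority(&b, 20);
  queue.SetPriority(&a, 30);  // Reprioritize: a now outranks b.
  return queue.PopBest() == &a ? 0 : 1;
}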
-
-}  // namespace
-
-/*static*/ bool GpuPriorityFusion::IsExpensive(
-    const HloInstruction& instruction) {
-  // Some floating-point math ops are cheap on the GPU.
-  switch (instruction.opcode()) {
-    case HloOpcode::kDivide:
-    case HloOpcode::kSqrt:
-    case HloOpcode::kRsqrt:
-    case HloOpcode::kExp:
-      if (ElementIsF32OrF16(instruction.shape())) {
-        return false;
-      }
-      break;
-    // Loop fusions are cheap.
-    case HloOpcode::kFusion:
-      return false;
-    default:
-      break;
-  }
-  return InstructionFusion::IsExpensive(instruction);
-}
-
-// Returns true if instr is a small constant.
-//
-// There is no single definition of what counts as a small constant in XLA.
-// IrEmitterContext::emit_constant treats as small only constants of 1 element.
-// HloPrintOptions::print_large_constants is effective for constants larger
-// than 10 elements.
-//
-// This function matches the emitter logic.
-bool IsSmallConstant(const HloInstruction* instr) {
-  return instr->opcode() == HloOpcode::kConstant && instr->shape().IsArray() &&
-         ShapeUtil::ElementsIn(instr->shape()) <= 1;
-}
-
-bool GpuPriorityFusion::ConsumeFuel(HloInstruction* producer,
-                                    HloInstruction* consumer) {
-  return xla::ConsumeFuel(name(), /*ran_out_of_fuel_msg=*/[&] {
-    return absl::StrFormat("Not fusing producer %s with consumer %s",
-                           producer->name(), consumer->name());
-  });
-}
-
-absl::StatusOr<bool> GpuPriorityFusion::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool dump_enabled =
-      DumpingEnabledForHloPass(name(), module->config().debug_options());
-  if (dump_enabled) {
-    fusion_process_dump_ = std::make_unique<FusionProcessDumpProto>();
-    *fusion_process_dump_->mutable_gpu_device_info() =
-        device_info_.ToGpuProto();
-  }
-
-  // Compute the computations within which more fusion is possible.
-  auto fusible_computations =
-      GetFusibleComputations(*module, execution_threads);
-
-  // Appends ".0" suffix to all instructions.
-  //
-  // Every time an instruction is duplicated, the last integer suffix is
-  // incremented.
-  // Before: broadcast.123 -> broadcast.124
-  // After: broadcast.123.0 -> broadcast.123.1
-  //
-  // With this modification it will be easier to match instructions before and
-  // after fusion passes, because they will have the same unique prefix. Names
-  // are not used in the pipeline, but it makes debugging much easier.
-  for (auto* computation : fusible_computations) {
-    for (auto* instruction : computation->instructions()) {
-      module->SetAndUniquifyInstrName(instruction,
-                                      absl::StrCat(instruction->name(), ".0"));
-    }
-  }
-
-  if (dump_enabled) {
-    fusion_process_dump_->set_hlo_module_before_fusion(
-        module->ToString(HloPrintOptions::ShortParsable()));
-  }
-
-  bool triton_softmax_priority_fusion_enabled =
-      module->config()
-          .debug_options()
-          .xla_gpu_enable_triton_softmax_priority_fusion();
-
-  bool changed = false;
-  for (auto* computation : fusible_computations) {
-    CHECK(!computation->IsFusionComputation());
-
-    auto fusion_queue = std::make_unique<GpuPriorityFusionQueue>(
-        computation, cost_analysis_options_, &device_info_,
-        fusion_process_dump_.get(), thread_pool_, &mlir_context_,
-        fusion_analysis_cache_, triton_softmax_priority_fusion_enabled);
-
-    while (fusion_queue->DequeueNextProducer()) {
-      auto producer = fusion_queue->current_producer();
-
-      for (auto* consumer : fusion_queue->current_consumers()) {
-        // Don't fuse into single bitcasts. We ignore them in the check
-        // CanFuseWithAllNonBitcastUsers(), so we need to check it here.
-        if (consumer->opcode() == HloOpcode::kBitcast) {
-          continue;
-        }
-        if (!ConsumeFuel(producer, consumer)) continue;
-
-        VLOG(5) << "next: " << consumer->name() << "(" << consumer << ") + "
-                << producer->name() << "(" << producer << ")";
-
-        fusion_queue->PreFusion(producer, consumer);
-        auto fusion_instruction = Fuse(producer, consumer, computation);
-        fusion_queue->OnFusingInstruction(fusion_instruction, producer,
-                                          consumer);
-
-        changed = true;
-      }
-
-      if (producer->user_count() == 0) {
-        fusion_queue->RemoveInstruction(producer);
-        // Remove from computation.
-        TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer));
-      }
-
-      fusion_queue->UpdatePriorities();
-    }
-
-    // Fuse all constants.
-    std::vector<HloInstruction*> constants;
-    for (auto* instruction : computation->instructions()) {
-      // Small constants should be fused, because they can be folded and
-      // codegened efficiently.
-      // Fusing large constants doesn't give much benefit, because they're
-      // treated like parameters and read from global memory anyway. Fusing
-      // and duplicating large constants can, however, cause problems if we
-      // want to dump the HLO and parse it back, because the duplicated
-      // constants will then be filled with different data.
-      if (IsSmallConstant(instruction)) {
-        constants.push_back(instruction);
-      }
-    }
-    for (auto* constant : constants) {
-      auto users = constant->users();
-      for (auto* user : users) {
-        if (IsFusible(*user) && CanEmitInputFusedScatter(*constant, *user)) {
-          Fuse(constant, user, computation);
-          changed = true;
-        }
-      }
-    }
-  }
-
-  // The fusion analysis cache uses unique_id as its key. IDs are only unique
-  // inside one module, so it's important to fully clear the cache before the
-  // same instance of the pass is run on a different module.
-  fusion_analysis_cache_.Clear();
-
-  if (dump_enabled) {
-    DumpPerModuleProtobufToFile(*module, *fusion_process_dump_,
-                                module->config().debug_options(),
-                                "priority_fusion_dump");
-  }
-
-  return changed;
-}
-
-FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer,
-                                             int64_t operand_index) {
-  // This method is called in `InstructionFusion::Run` right before fusion, but
-  // it will always return true. Fusion decisions are fully controlled by the
-  // PriorityQueue. If the queue returns a producer that shouldn't be fused,
-  // it's a bug and should be fixed in the queue logic.
-  return {};
-}
-
-HloInstruction::FusionKind GpuPriorityFusion::ChooseKind(
-    const HloInstruction* producer, const HloInstruction* consumer) {
-  // Derive kInput/kLoop fusion kinds from fusion analysis. This shouldn't
-  // matter, but some passes downstream still query these kinds instead of the
-  // fusion analysis.
-  const auto& analysis = fusion_analysis_cache_.Get(*producer, *consumer);
-  switch (analysis.GetEmitterFusionKind()) {
-    case HloFusionAnalysis::EmitterFusionKind::kLoop:
-      return HloInstruction::FusionKind::kLoop;
-    case HloFusionAnalysis::EmitterFusionKind::kTriton:
-    case HloFusionAnalysis::EmitterFusionKind::kCustomFusion:
-    case HloFusionAnalysis::EmitterFusionKind::kCuDnn:
-      return HloInstruction::FusionKind::kCustom;
-    case HloFusionAnalysis::EmitterFusionKind::kConcatenate:
-    case HloFusionAnalysis::EmitterFusionKind::kReduction:
-    case HloFusionAnalysis::EmitterFusionKind::kTranspose:
-    case HloFusionAnalysis::EmitterFusionKind::kInputSlices:
-    case HloFusionAnalysis::EmitterFusionKind::kScatter:
-      return HloInstruction::FusionKind::kInput;
-  }
-}
-
-HloInstruction* GpuPriorityFusion::FuseInstruction(
-    HloInstruction* fusion_instruction, HloInstruction* producer) {
-  HloInstruction* result = fusion_instruction;
-  if (producer->opcode() == HloOpcode::kFusion) {
-    if (IsGenericTritonFusion(*producer)) {
-      TF_CHECK_OK(fusion_instruction->set_backend_config(
-          *producer->backend_config<GpuBackendConfig>()));
-    }
-
-    fusion_instruction->MergeFusionInstruction(producer);
-  } else {
-    result = InstructionFusion::FuseInstruction(fusion_instruction, producer);
-  }
-  return result;
-}
-
-std::unique_ptr<FusionQueue> GpuPriorityFusion::GetFusionQueue(
-    HloComputation* computation) {
-  return nullptr;
-}
-
-}  // namespace gpu
-}  // namespace xla
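
The Run() loop deleted above drives fusion from a priority queue: it repeatedly dequeues the producer with the highest estimated benefit, fuses it into each eligible consumer, and then recomputes the priorities of the affected instructions. Below is a minimal, self-contained C++ sketch of that shape of loop; the Node type, the benefit field, and the queue layout are illustrative assumptions, not the XLA classes used in the removed file.

// Hypothetical sketch of a priority-driven producer/consumer fusion loop.
#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node*> consumers;
  double benefit = 0.0;  // Estimated saving if this producer is fused away.
};

int main() {
  Node a{"broadcast.0"}, b{"multiply.0"}, c{"reduce.0"};
  a.consumers = {&b};
  a.benefit = 2.0;
  b.consumers = {&c};
  b.benefit = 1.0;

  // Priority queue keyed by estimated benefit, highest first.
  std::multimap<double, Node*, std::greater<double>> queue = {
      {a.benefit, &a}, {b.benefit, &b}};

  while (!queue.empty()) {
    Node* producer = queue.begin()->second;
    queue.erase(queue.begin());
    if (producer->benefit <= 0.0) continue;  // Not profitable to fuse.
    for (Node* consumer : producer->consumers) {
      std::printf("fuse %s into %s\n", producer->name.c_str(),
                  consumer->name.c_str());
      // A real pass would build the fusion instruction here and then
      // update the priorities of every affected producer.
    }
  }
  return 0;
}

A real pass would additionally remove producers that lost all of their users and re-enqueue neighbours whose cost estimates changed, as the deleted Run() does.
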
diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/priority_fusion.h
deleted file mode 100644
index 999eb78..0000000
--- a/third_party/xla/xla/service/gpu/priority_fusion.h
+++ /dev/null
@@ -1,100 +0,0 @@
-/* Copyright 2017 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_PRIORITY_FUSION_H_
-#define XLA_SERVICE_GPU_PRIORITY_FUSION_H_
-
-#include <stdint.h>
-
-#include <memory>
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/fusion_queue.h"
-#include "xla/service/gpu/fusion_process_dump.pb.h"
-#include "xla/service/gpu/model/fusion_analysis_cache.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/stream_executor/device_description.h"
-#include "tsl/platform/threadpool.h"
-
-namespace xla {
-namespace gpu {
-
-class GpuPriorityFusion : public InstructionFusion {
- public:
-  GpuPriorityFusion(tsl::thread::ThreadPool* thread_pool,
-                    const se::DeviceDescription& device,
-                    GpuHloCostAnalysis::Options cost_analysis_options)
-      : InstructionFusion(GpuPriorityFusion::IsExpensive),
-        thread_pool_(thread_pool),
-        device_info_(device),
-        cost_analysis_options_(std::move(cost_analysis_options)),
-        fusion_analysis_cache_(device_info_) {}
-
-  absl::string_view name() const override { return "priority-fusion"; }
-
-  static bool IsExpensive(const HloInstruction& instruction);
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- protected:
-  std::unique_ptr<FusionQueue> GetFusionQueue(
-      HloComputation* computation) override;
-
-  FusionDecision ShouldFuse(HloInstruction* consumer,
-                            int64_t operand_index) override;
-
-  HloInstruction::FusionKind ChooseKind(
-      const HloInstruction* producer, const HloInstruction* consumer) override;
-
- private:
-  HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
-                                  HloInstruction* producer) override;
-
-  // Consumes a unit of compiler fuel and returns true if we should
-  // continue with the transformation.
-  bool ConsumeFuel(HloInstruction* producer, HloInstruction* consumer);
-
-  tsl::thread::ThreadPool* thread_pool_;
-  se::DeviceDescription device_info_;
-
-  // Cost model options that define priorities in the queue.
-  GpuHloCostAnalysis::Options cost_analysis_options_;
-
-  // Proto with structured logs of fusion decisions. Used only for debugging. If
-  // null, logging is disabled.
-  std::unique_ptr<FusionProcessDumpProto> fusion_process_dump_;
-
-  HloFusionAnalysisCache fusion_analysis_cache_;
-
-  mlir::MLIRContext mlir_context_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_PRIORITY_FUSION_H_
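
The ConsumeFuel hook declared in the header above gates every producer/consumer fusion on a per-pass budget, which makes it possible to bisect a miscompile by limiting how many fusions are performed. Below is a minimal sketch of that kind of fuel gating, assuming a plain counter rather than XLA's fuel machinery; the FuelGate name is made up for this illustration.

// Hypothetical fuel gate: allow at most `fuel` transformations per run.
#include <cstdio>

struct FuelGate {
  int fuel;  // Remaining number of transformations allowed.
  bool Consume(const char* producer, const char* consumer) {
    if (fuel <= 0) {
      std::printf("Not fusing producer %s with consumer %s\n", producer,
                  consumer);
      return false;
    }
    --fuel;
    return true;
  }
};

int main() {
  FuelGate gate{1};
  if (gate.Consume("exp.0", "add.0")) std::printf("fused exp.0 into add.0\n");
  if (gate.Consume("log.0", "add.0")) std::printf("fused log.0 into add.0\n");
  return 0;
}
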
diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc
deleted file mode 100644
index 4f71a51..0000000
--- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc
+++ /dev/null
@@ -1,941 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/priority_fusion.h"
-
-#include <stdint.h>
-
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace m = ::xla::match;
-
-using ::testing::UnorderedElementsAre;
-using ::tsl::testing::IsOk;
-using ::tsl::testing::IsOkAndHolds;
-
-namespace xla {
-namespace gpu {
-
-class PriorityFusionTest : public HloTestBase {
-  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
-    return [&](const Shape& shape) {
-      constexpr int64_t kPointerSize = 8;
-      return ShapeUtil::ByteSizeOf(shape, kPointerSize);
-    };
-  }
-
- public:
-  std::vector<HloFusionAnalysis::EmitterFusionKind> RunAndGetFusionKinds(
-      absl::string_view hlo) {
-    auto module = ParseAndReturnVerifiedModule(hlo).value();
-    EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
-    EXPECT_THAT(module->RemoveUnusedComputations(), IsOk());
-    std::vector<HloFusionAnalysis::EmitterFusionKind> kinds;
-    for (auto computation : module->computations()) {
-      if (!computation->FusionInstruction()) continue;
-
-      auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
-      auto analysis = HloFusionAnalysis::Create(
-          Cast<HloFusionInstruction>(computation->FusionInstruction()),
-          &device_info);
-      kinds.push_back(analysis.GetEmitterFusionKind());
-    }
-    return kinds;
-  }
-
-  GpuPriorityFusion priority_fusion_{
-      /*thread_pool=*/nullptr, TestGpuDeviceInfo::RTXA6000DeviceInfo(),
-      GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(),
-                                  /*per_second_rates=*/{},
-                                  /*count_multiple_input_accesses=*/true}};
-};
-
-TEST_F(PriorityFusionTest, FuseWithSharedArgument) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    ENTRY main {
-      %p0 = f32[] parameter(0)
-      %p1 = f32[] parameter(1)
-      %subtract = f32[] subtract(%p0, %p1)
-      %compare = pred[] compare(%subtract, %subtract), direction=NE
-      %add = f32[] add(%p0, %p1)
-      %abs = f32[] abs(%subtract)
-      ROOT %select = f32[] select(%compare, %add, %abs)
-    })")
-                    .value();
-
-  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, GmockMatch(m::Fusion()));
-  EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
-}
-
-TEST_F(PriorityFusionTest, FusionFusionWithDuplication) {
-  absl::string_view kHlo = R"(
-    HloModule test_module
-
-    square {
-      p = f32[16384]{0} parameter(0)
-      ROOT m = f32[16384]{0} multiply(p, p)
-    }
-
-    exp {
-      p = f32[16384]{0} parameter(0)
-      ROOT e = f32[16384]{0} exponential(p)
-    }
-
-    log {
-      p = f32[16384]{0} parameter(0)
-      ROOT l = f32[16384]{0} log(p)
-    }
-
-    ENTRY main {
-      p = f32[16384]{0} parameter(0)
-      s = f32[16384]{0} fusion(p), kind=kLoop, calls=square
-      e = f32[16384]{0} fusion(s), kind=kLoop, calls=exp
-      l = f32[16384]{0} fusion(s), kind=kInput, calls=log
-      ROOT t = (f32[16384], f32[16384]) tuple(l, e)
-    })";
-
-  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ENTRY
-CHECK-NEXT: %[[PARAM:.*]] = f32[16384]{0} parameter(0)
-CHECK-NEXT: %[[FUSION_0:.*]] = f32[16384]{0} fusion(%[[PARAM]])
-CHECK-NEXT: %[[FUSION_1:.*]] = f32[16384]{0} fusion(%[[PARAM]])
-CHECK-NEXT: ROOT {{.*}} tuple(%[[FUSION_0]], %[[FUSION_1]])
-  )");
-}
-
-TEST_F(PriorityFusionTest, FuseBroadcastIntoBitcastConsumers) {
-  absl::string_view kHlo = R"(
-    HloModule test_module
-
-    ENTRY main {
-      param_0 = f32[96]{0} parameter(0)
-      broadcast = f32[8,96,128,7]{3,2,1,0} broadcast(param_0), dimensions={1}
-      bitcast.6079.2 = f32[8,24,4,128,7]{4,3,2,1,0} bitcast(broadcast)
-      ROOT transpose.1990.2 = f32[8,24,128,7,4]{4,3,2,1,0} transpose(bitcast.6079.2), dimensions={0,1,3,4,2}
-    }
-  )";
-  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK:      ENTRY
-CHECK-NEXT: %[[PARAM:.*]] = f32[96]{0} parameter(0)
-CHECK-NEXT: ROOT %{{.*}} fusion(%[[PARAM]])
-  )");
-}
-
-TEST_F(PriorityFusionTest, FuseWideningConvertIntoConsumers) {
-  absl::string_view kHlo = R"(
-    HloModule test_module
-
-    ENTRY main {
-      p = f16[512]{0} parameter(0)
-      a = f16[512]{0} add(p, p)
-      c = f32[512]{0} convert(a)
-      s = f32[512]{0} multiply(c, c)
-      bc = s32[512]{0} bitcast(c)
-      ROOT t = (f32[512], s32[512]) tuple(s, bc)
-    })";
-
-  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK:      ENTRY
-CHECK-NEXT: %[[PARAM:.*]] = f16[512]{0} parameter(0)
-CHECK-NEXT: %[[FUSION_F32:.*]] = f32[512]{0} fusion(%[[PARAM]])
-CHECK-NEXT: %[[CONVERT_FUSION:.*]] = f32[512]{0} fusion(%[[PARAM]])
-CHECK-NEXT: %[[BITCAST:.*]] = s32[512]{0} bitcast(%[[CONVERT_FUSION]])
-CHECK-NEXT: ROOT %{{.*}} = (f32[512]{0}, s32[512]{0}) tuple(%[[FUSION_F32]], %[[BITCAST]])
-  )");
-}
-
-TEST_F(PriorityFusionTest, FuseConvertIntoReduce) {
-  absl::string_view kHlo = R"(
-    HloModule test_module
-
-    add {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT add.13235 = f32[] add(p0, p1)
-    }
-
-    ENTRY main {
-      param_0_0.79 = bf16[1024,8192]{1,0} parameter(0)
-      param_1_0.79 = bf16[1024,8192]{1,0} parameter(1)
-      param_2.483 = f32[8192]{0} parameter(2)
-      param_4.2892 = bf16[1024,8192]{1,0} parameter(3)
-      convert.21854 = f32[1024,8192]{1,0} convert(param_0_0.79)
-      convert.21855 = f32[1024,8192]{1,0} convert(param_1_0.79)
-      constant_7773 = f32[] constant(0)
-      broadcast.14555 = f32[1024,8192]{1,0} broadcast(param_2.483), dimensions={1}
-      multiply.6906 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21854)
-      reduce.4813 = f32[1024]{0} reduce(multiply.6906, constant_7773), dimensions={1}, to_apply=add
-      convert.13970 = bf16[1024]{0} convert(reduce.4813)
-      convert.21534 = f32[1024,8192]{1,0} convert(param_4.2892)
-      multiply.6910.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21534)
-      reduce.4811.clone.1 = f32[1024]{0} reduce(multiply.6910.clone.1, constant_7773), dimensions={1}, to_apply=add
-      convert.13967.clone.1 = bf16[1024]{0} convert(reduce.4811.clone.1)
-      multiply.6908.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21855)
-      reduce.4812.clone.1 = f32[1024]{0} reduce(multiply.6908.clone.1, constant_7773), dimensions={1}, to_apply=add
-      convert.13969.clone.1 = bf16[1024]{0} convert(reduce.4812.clone.1)
-      ROOT fusion.241 = (bf16[1024]{0}, bf16[1024]{0}, bf16[1024]{0}) tuple(convert.13970, convert.13967.clone.1, convert.13969.clone.1)
-    })";
-
-  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK-COUNT-3: ROOT {{.*}} convert(
-CHECK: ENTRY %main
-CHECK-COUNT-3: fusion
-  )");
-}
-
-TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) {
-  // Regression test for epilogue fusion of convert into a reduction, even if
-  // the convert has a bitcast as consumer.
-  absl::string_view kHlo = R"(
-    HloModule test_module
-
-    add {
-      rhs.407 = f32[] parameter(1)
-      lhs.407 = f32[] parameter(0)
-      ROOT add.24451 = f32[] add(lhs.407, rhs.407)
-    }
-
-    ENTRY main {
-      param_1.15162 = f32[2752]{0} parameter(1)
-      convert.44829 = bf16[2752]{0} convert(param_1.15162)
-      bitcast.24686 = bf16[1,1,2752]{2,1,0} bitcast(convert.44829)
-      convert.44468 = f32[1,1,2752]{2,1,0} convert(bitcast.24686)
-      constant_13722 = bf16[] constant(1)
-      convert.17451 = f32[] convert(constant_13722)
-      broadcast.17565 = f32[1,1,2752]{2,1,0} broadcast(convert.17451), dimensions={}
-      negate.167 = f32[1,1,2752]{2,1,0} negate(convert.44468)
-      exponential.569 = f32[1,1,2752]{2,1,0} exponential(negate.167)
-      add.1850 = f32[1,1,2752]{2,1,0} add(broadcast.17565, exponential.569)
-      divide.1376 = f32[1,1,2752]{2,1,0} divide(broadcast.17565, add.1850)
-      multiply.9709 = f32[1,1,2752]{2,1,0} multiply(convert.44468, divide.1376)
-      param_0.15005 = f32[2752]{0} parameter(0)
-      convert.44826 = bf16[2752]{0} convert(param_0.15005)
-      bitcast.24683 = bf16[1,1,2752]{2,1,0} bitcast(convert.44826)
-      convert.44467 = f32[1,1,2752]{2,1,0} convert(bitcast.24683)
-      multiply.9708 = f32[1,1,2752]{2,1,0} multiply(multiply.9709, convert.44467)
-      convert.16959 = bf16[1,1,2752]{2,1,0} convert(multiply.9708)
-      fusion.3203 = bf16[2752]{0} bitcast(convert.16959)
-      convert.15093 = f32[2752]{0} convert(fusion.3203)
-      broadcast.13841 = f32[8192,2752]{1,0} broadcast(convert.15093), dimensions={1}
-      param_0.15525 = bf16[8192,2752]{1,0} parameter(2)
-      convert.13738 = f32[8192,2752]{1,0} convert(param_0.15525)
-      multiply.6422 = f32[8192,2752]{1,0} multiply(broadcast.13841, convert.13738)
-      constant_14382 = f32[] constant(0)
-      fusion.339 = f32[8192]{0} reduce(multiply.6422, constant_14382), dimensions={1}, to_apply=add
-      convert.44633 = bf16[8192]{0} convert(fusion.339)
-      ROOT bitcast.24487 = bf16[1,1,8192]{2,1,0} bitcast(convert.44633)
-    }
-  )";
-
-  EXPECT_THAT(
-      RunAndGetFusionKinds(kHlo),
-      UnorderedElementsAre(HloFusionAnalysis::EmitterFusionKind::kLoop,
-                           HloFusionAnalysis::EmitterFusionKind::kReduction));
-
-  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ENTRY
-CHECK: ROOT {{.*}} bitcast({{.*}}fusion{{.*}})
-  )");
-}
-
-TEST_F(PriorityFusionTest, DoNotChangeReductionFusionToLoopFusion) {
-  // Regression test for epilogue fusion of slice into a reduction. The fusion
-  // kind for the reduction fusion is intentionally set to kLoop, as we cannot
-  // rely on reductions always having fusion kind kInput.
-  auto module = *ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      rhs.407 = f32[] parameter(1)
-      lhs.407 = f32[] parameter(0)
-      ROOT add.24451 = f32[] add(lhs.407, rhs.407)
-    }
-
-    fused_computation {
-      p0 = f32[16,64]{1,0} parameter(0)
-      zero = f32[] constant(0.0)
-      ROOT reduce = f32[16]{0} reduce(p0, zero), dimensions={1}, to_apply=add
-    }
-
-    ENTRY main {
-      param0 = f32[16,64]{1,0} parameter(0)
-      fusion = f32[16]{0} fusion(param0), kind=kLoop, calls=fused_computation
-      ROOT slice = f32[8]{0} slice(fusion), slice={[0:8]}
-    })");
-  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) {
-  absl::string_view kHlo = R"(
-    HloModule test_module
-
-    add {
-      Arg_1.1046 = f32[] parameter(1)
-      Arg_0.1045 = f32[] parameter(0)
-      ROOT add.3303 = f32[] add(Arg_0.1045, Arg_1.1046)
-    }
-
-    ENTRY main {
-      param_0.17323 = pred[2048,2048]{1,0} parameter(0)
-      broadcast.22829 = pred[1,12,2048,2048]{3,2,1,0} broadcast(param_0.17323), dimensions={2,3}
-      param_1.19761 = bf16[2048,24576]{1,0} parameter(1)
-      convert.29880.clone.1 = f32[2048,24576]{1,0} convert(param_1.19761)
-      constant_10033_clone_1 = bf16[] constant(0.02002)
-      convert.30056.clone.1 = f32[] convert(constant_10033_clone_1)
-      broadcast.18898.clone.1 = f32[2048,24576]{1,0} broadcast(convert.30056.clone.1), dimensions={}
-      multiply.13451.clone.1 = f32[2048,24576]{1,0} multiply(convert.29880.clone.1, broadcast.18898.clone.1)
-      tanh.798.clone.1 = f32[2048,24576]{1,0} tanh(multiply.13451.clone.1)
-      constant_10244_clone_1 = bf16[] constant(50)
-      convert.30039.clone.1 = f32[] convert(constant_10244_clone_1)
-      broadcast.18310.clone.1 = f32[2048,24576]{1,0} broadcast(convert.30039.clone.1), dimensions={}
-      multiply.12550.clone.1 = f32[2048,24576]{1,0} multiply(tanh.798.clone.1, broadcast.18310.clone.1)
-      convert.29370.clone.1 = bf16[2048,24576]{1,0} convert(multiply.12550.clone.1)
-      bitcast.22330 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1)
-      transpose.6582 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22330), dimensions={0,3,2,1}
-      convert.33705 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6582)
-      constant_10212 = f32[] constant(-2.38197633e+38)
-      broadcast.22828 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_10212), dimensions={}
-      select.589 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22829, convert.33705, broadcast.22828)
-      bitcast.22075 = f32[12,2048,2048]{2,1,0} bitcast(select.589)
-      constant_10192 = f32[] constant(-inf)
-      reduce.1614 = f32[12,2048]{1,0} reduce(bitcast.22075, constant_10192), dimensions={2}, to_apply=add
-
-      predarg = pred[1,1,2048,2048]{3,2,1,0} parameter(2)
-      bitcast.11069 = pred[2048,2048]{1,0} bitcast(predarg)
-
-      broadcast.22825 = pred[1,12,2048,2048]{3,2,1,0} broadcast(bitcast.11069), dimensions={2,3}
-      bitcast.22331 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1)
-      transpose.6580 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22331), dimensions={0,3,2,1}
-      convert.33703 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6580)
-      constant_10213 = f32[] constant(-2.38197633e+38)
-      broadcast.22824 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_10213), dimensions={}
-      select.587 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22825, convert.33703, broadcast.22824)
-      broadcast.22819 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1614), dimensions={1,2}
-      subtract.1129 = f32[1,12,2048,2048]{3,2,1,0} subtract(select.587, broadcast.22819)
-      exponential.418 = f32[1,12,2048,2048]{3,2,1,0} exponential(subtract.1129)
-      bitcast.22074 = f32[12,2048,2048]{2,1,0} bitcast(exponential.418)
-      constant_10490 = f32[] constant(0)
-      reduce.1613 = f32[12,2048]{1,0} reduce(bitcast.22074, constant_10490), dimensions={2}, to_apply=add
-
-      constant_468 = f32[] constant(-2.38197633e+38)
-      broadcast.22833 = pred[1,12,2048,2048]{3,2,1,0} broadcast(bitcast.11069), dimensions={2,3}
-      bitcast.22332 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1)
-      transpose.6584 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22332), dimensions={0,3,2,1}
-      convert.33707 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6584)
-      broadcast.22832 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_468), dimensions={}
-      select.591 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22833, convert.33707, broadcast.22832)
-      broadcast.22821 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1614), dimensions={1,2}
-      subtract.1131 = f32[1,12,2048,2048]{3,2,1,0} subtract(select.591, broadcast.22821)
-      exponential.420 = f32[1,12,2048,2048]{3,2,1,0} exponential(subtract.1131)
-      broadcast.18351 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1613), dimensions={1,2}
-      divide.340 = f32[1,12,2048,2048]{3,2,1,0} divide(exponential.420, broadcast.18351)
-      ROOT convert.29418 = bf16[1,12,2048,2048]{3,2,1,0} convert(divide.340)
-    })";
-
-  using Kind = HloFusionAnalysis::EmitterFusionKind;
-  EXPECT_THAT(
-      RunAndGetFusionKinds(kHlo),
-      UnorderedElementsAre(Kind::kLoop, Kind::kLoop, Kind::kLoop,
-                           Kind::kReduction, Kind::kReduction, Kind::kTranspose,
-                           Kind::kTranspose, Kind::kTranspose));
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduce) {
-  absl::string_view kHlo = R"(
-    HloModule test_module
-
-    add {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT add.13235 = f32[] add(p0, p1)
-    }
-
-    ENTRY main {
-      p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
-      c0 = f32[] constant(0)
-      r0 = f32[8,4,128]{2,1,0} reduce(p0, c0), dimensions={3}, to_apply=add
-      ROOT r1 = f32[8,4]{1,0} reduce(r0, c0), dimensions={2}, to_apply=add
-    })";
-
-  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ROOT {{.*}} reduce(
-CHECK: ROOT {{.*}} reduce(
-  )");
-}
-
-TEST_F(PriorityFusionTest, ConvertFusedIntoReduce) {
-  absl::string_view kHlo = R"(
-    HloModule test_module
-
-    add {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT add.13235 = f32[] add(p0, p1)
-    }
-
-    ENTRY main {
-      param_0_0.79 = bf16[1024,8192]{1,0} parameter(0)
-      param_1_0.79 = bf16[1024,8192]{1,0} parameter(1)
-      param_2.483 = f32[8192]{0} parameter(2)
-      param_4.2892 = bf16[1024,8192]{1,0} parameter(3)
-      convert.21854 = f32[1024,8192]{1,0} convert(param_0_0.79)
-      convert.21855 = f32[1024,8192]{1,0} convert(param_1_0.79)
-      constant_7773 = f32[] constant(0)
-      broadcast.14555 = f32[1024,8192]{1,0} broadcast(param_2.483), dimensions={1}
-      multiply.6906 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21854)
-      reduce.4813 = f32[1024]{0} reduce(multiply.6906, constant_7773), dimensions={1}, to_apply=add
-      convert.13970 = bf16[1024]{0} convert(reduce.4813)
-      convert.21534 = f32[1024,8192]{1,0} convert(param_4.2892)
-      multiply.6910.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21534)
-      reduce.4811.clone.1 = f32[1024]{0} reduce(multiply.6910.clone.1, constant_7773), dimensions={1}, to_apply=add
-      convert.13967.clone.1 = bf16[1024]{0} convert(reduce.4811.clone.1)
-      multiply.6908.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21855)
-      reduce.4812.clone.1 = f32[1024]{0} reduce(multiply.6908.clone.1, constant_7773), dimensions={1}, to_apply=add
-      convert.13969.clone.1 = bf16[1024]{0} convert(reduce.4812.clone.1)
-      ROOT fusion.241 = (bf16[1024]{0}, bf16[1024]{0}, bf16[1024]{0}) tuple(convert.13970, convert.13967.clone.1, convert.13969.clone.1)
-    })";
-
-  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK-COUNT-3: ROOT {{.*}} convert(
-CHECK: ENTRY %main
-CHECK-COUNT-3: fusion(
-CHECK-NOT: fusion(
-  )");
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseDynamicUpdateSliceIntoReduce) {
-  GTEST_SKIP() << "b/294198633";
-  absl::string_view kHlo = R"(
-    HloModule test_module
-
-add {
-  Arg_1.1046 = f32[] parameter(1)
-  Arg_0.1045 = f32[] parameter(0)
-  ROOT add.3303 = f32[] add(Arg_0.1045, Arg_1.1046)
-}
-
-ENTRY main {
-  param_0.10549 = f32[4,2112]{1,0} parameter(0)
-  param_5.2561 = pred[] parameter(5)
-  broadcast.19725 = pred[4,1]{1,0} broadcast(param_5.2561), dimensions={}
-  param_1.11587 = pred[4]{0} parameter(1)
-  constant_5837 = f32[] constant(1)
-  broadcast.19723 = f32[4]{0} broadcast(constant_5837), dimensions={}
-  param_2.5952 = f32[4,8000]{1,0} parameter(2)
-  param_3.4004 = f32[4]{0} parameter(3)
-  broadcast.19718 = f32[4,8000]{1,0} broadcast(param_3.4004), dimensions={0}
-  subtract.1112 = f32[4,8000]{1,0} subtract(param_2.5952, broadcast.19718)
-  exponential.418 = f32[4,8000]{1,0} exponential(subtract.1112)
-  constant_6254 = f32[] constant(0)
-  reduce.1154 = f32[4]{0} reduce(exponential.418, constant_6254), dimensions={1}, to_apply=add
-  log.38 = f32[4]{0} log(reduce.1154)
-  broadcast.19717 = f32[4,8000]{1,0} broadcast(log.38), dimensions={0}
-  subtract.1111 = f32[4,8000]{1,0} subtract(subtract.1112, broadcast.19717)
-  iota.170 = s32[4,1]{1,0} iota(), iota_dimension=0
-  constant_6281 = s32[] constant(0)
-  broadcast.19735 = s32[4]{0} broadcast(constant_6281), dimensions={}
-  param_4.3400 = s32[4,8000]{1,0} parameter(4)
-  slice.3186 = s32[4,40]{1,0} slice(param_4.3400), slice={[0:4], [0:40]}
-  iota.168 = s32[4,1]{1,0} iota(), iota_dimension=0
-  param_7.1596 = s32[4]{0} parameter(7)
-  compare.341 = pred[4]{0} compare(param_7.1596, broadcast.19735), direction=LT
-  constant_5833 = s32[] constant(40)
-  broadcast.19731 = s32[4]{0} broadcast(constant_5833), dimensions={}
-  add.8348 = s32[4]{0} add(param_7.1596, broadcast.19731)
-  select.418 = s32[4]{0} select(compare.341, add.8348, param_7.1596)
-  bitcast.20942 = s32[4,1]{1,0} bitcast(select.418)
-  concatenate.1337 = s32[4,2]{1,0} concatenate(iota.168, bitcast.20942), dimensions={1}
-  gather.43 = s32[4,1,1]{2,1,0} gather(slice.3186, concatenate.1337), offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}
-  bitcast.20941 = s32[4]{0} bitcast(gather.43)
-  select.398 = s32[4]{0} select(param_1.11587, broadcast.19735, bitcast.20941)
-  compare.334 = pred[4]{0} compare(select.398, broadcast.19735), direction=LT
-  constant_6260 = s32[] constant(8000)
-  broadcast.19720 = s32[4]{0} broadcast(constant_6260), dimensions={}
-  add.8336 = s32[4]{0} add(select.398, broadcast.19720)
-  select.396 = s32[4]{0} select(compare.334, add.8336, select.398)
-  bitcast.20830 = s32[4,1]{1,0} bitcast(select.396)
-  concatenate.1308 = s32[4,2]{1,0} concatenate(iota.170, bitcast.20830), dimensions={1}
-  gather.41 = f32[4,1,1]{2,1,0} gather(subtract.1111, concatenate.1308), offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}
-  bitcast.20824 = f32[4]{0} bitcast(gather.41)
-  select.389 = f32[4]{0} select(param_1.11587, broadcast.19723, bitcast.20824)
-  bitcast.20823 = f32[4,1]{1,0} bitcast(select.389)
-  param_6.1719 = s32[] parameter(6)
-  constant_6323 = s32[] constant(2048)
-  add.8549 = s32[] add(param_6.1719, constant_6323)
-  compare.388 = pred[] compare(add.8549, constant_6281), direction=LT
-  constant_5436 = s32[] constant(4160)
-  add.8339 = s32[] add(param_6.1719, constant_5436)
-  select.409 = s32[] select(compare.388, add.8339, add.8549)
-  dynamic-slice.36 = f32[4,1]{1,0} dynamic-slice(param_0.10549, constant_6281, select.409), dynamic_slice_sizes={4,1}
-  select.388 = f32[4,1]{1,0} select(broadcast.19725, bitcast.20823, dynamic-slice.36)
-  ROOT dynamic-update-slice.307 = f32[4,2112]{1,0} dynamic-update-slice(param_0.10549, select.388, constant_6281, select.409)
-})";
-
-  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ROOT {{.*}} dynamic-update-slice(
-CHECK: %[[REDUCE:.*]] = {{.*}} reduce(
-CHECK: ROOT {{.*}} log(%[[REDUCE]])
-CHECK: ENTRY
-CHECK-COUNT-2: fusion(
-  )");
-}
-
-TEST_F(PriorityFusionTest, DontFuseIntoFirstOperandOfScatter) {
-  auto module = *ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      lhs = s32[] parameter(0)
-      rhs = s32[] parameter(1)
-      ROOT add = s32[] add(lhs, rhs)
-    }
-
-    ENTRY FuseIntoScatter {
-      p0 = s32[3,3] parameter(0)
-      operand = s32[3,3] add(p0, p0)
-      p1 = s32[2] parameter(1)
-      indices = s32[2] add(p1, p1)
-      p2 = s32[2,3] parameter(2)
-      updates = s32[2,3] add(p2, p2)
-      scatter = s32[3,3] scatter(operand, indices, updates),
-          to_apply=add,
-          update_window_dims={1},
-          inserted_window_dims={0},
-          scatter_dims_to_operand_dims={0},
-          index_vector_dim=1
-      ROOT add = s32[3,3] add(scatter, scatter)
-    })");
-
-  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  const HloInstruction* fusion = nullptr;
-  ASSERT_THAT(root, GmockMatch(m::Add(m::Fusion(&fusion), m::Fusion())));
-  EXPECT_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kInput);
-  EXPECT_THAT(fusion->fused_expression_root(),
-              GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
-}
-
-// This test is similar to DontFuseIntoFirstOperandOfScatter, but PriorityFusion
-// has a separate run to fuse constants. Fusing anything into a scatter fusion
-// will fail in the emitter.
-TEST_F(PriorityFusionTest, DontFuseConstantIntoFirstOperandOfScatter) {
-  auto module = *ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      lhs = s32[] parameter(0)
-      rhs = s32[] parameter(1)
-      ROOT add = s32[] add(lhs, rhs)
-    }
-
-    ENTRY FuseIntoScatter {
-      operand = s32[1] constant({0})
-      indices = s32[24,1] parameter(0)
-      constant = s32[] constant(1)
-      updates = s32[24,1] broadcast(constant)
-      ROOT scatter = s32[1] scatter(operand, indices, updates),
-          to_apply=add,
-          update_window_dims={1},
-          inserted_window_dims={},
-          scatter_dims_to_operand_dims={0},
-          index_vector_dim=1
-    })");
-
-  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  ASSERT_THAT(root, GmockMatch(m::Fusion(m::Constant(), m::Parameter())));
-  EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kInput);
-  EXPECT_THAT(root->fused_expression_root(),
-              GmockMatch(m::Scatter(m::Parameter(), m::Parameter(),
-                                    m::Broadcast(m::Constant()))));
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduceEvenIfOccupancyIsHigh) {
-  constexpr absl::string_view kHlo = R"(
-    HloModule test_module
-
-    add {
-      lhs = f32[] parameter(0)
-      rhs = f32[] parameter(1)
-      ROOT add = f32[] add(lhs, rhs)
-    }
-
-    ENTRY main {
-      p0 = f32[4,3584,128,168]{3,2,1,0} parameter(0)
-      c = f32[] constant(0)
-      r1 = f32[4,3584,128]{2,1,0} reduce(p0, c), dimensions={3}, to_apply=add
-      ROOT r2 = f32[4,3584]{1,0} reduce(r1, c), dimensions={2}, to_apply=add
-    })";
-
-  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ROOT {{.*}} reduce(
-CHECK: ROOT {{.*}} reduce(
-  )");
-}
-
-TEST_F(PriorityFusionTest, FuseReductionEpilogueWithMultipleUsers) {
-  // Regression test that verifies we correctly fuse the `log` into the reduce.
-  constexpr absl::string_view kHlo = R"(
-    HloModule test_module
-
-    add {
-      x = f32[] parameter(0)
-      y = f32[] parameter(1)
-      ROOT add = f32[] add(x, y)
-    }
-
-    fused_computation {
-      p0 = f32[64,16384]{1,0} parameter(0)
-      c0 = f32[] constant(0)
-      ROOT reduce.858 = f32[64]{0} reduce(p0, c0), dimensions={1}, to_apply=add
-    }
-
-    ENTRY main {
-      p0 = f32[64,16384]{1,0} parameter(0)
-      fusion = f32[64]{0} fusion(p0), kind=kInput, calls=fused_computation
-      log = f32[64]{0} log(fusion)
-      negate = f32[64]{0} custom-call(log), custom_call_target="negate"
-      ROOT add = f32[64]{0} add(negate, log)
-    }
-  )";
-
-  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-    CHECK: ENTRY
-    CHECK: %[[PARAM:.*]] = {{.*}} parameter(0)
-    CHECK: %[[FUSION:.*]] = {{.*}} fusion(%[[PARAM]])
-    CHECK: custom-call(%[[FUSION]])
-  )");
-}
-
-TEST_F(PriorityFusionTest, EpilogueFusion) {
-  absl::string_view kHlo = R"(
-    HloModule test_module
-
-    add {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT add.13235 = f32[] add(p0, p1)
-    }
-
-    fused_computation.1 {
-      p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
-      c0 = f32[] constant(0)
-      ROOT r0 = f32[8,4,128]{2,1,0} reduce(p0, c0), dimensions={3}, to_apply=add
-    }
-
-    fused_computation.2 {
-      p0 = f32[8,4,128]{2,1,0} parameter(0)
-      r1 = f32[8,4,128]{2,1,0} log(p0)
-      ROOT r2 = f32[8,4,128]{2,1,0} log(r1)
-    }
-
-    ENTRY main {
-      p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
-      f1 = f32[8,4,128]{2,1,0} fusion(p0), kind=kInput, calls=%fused_computation.1
-      ROOT fusion = f32[8,4,128]{2,1,0} fusion(f1), kind=kLoop, calls=%fused_computation.2
-    })";
-
-  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
-CHECK: ROOT {{.*}} = f32[8,4,128]{2,1,0} fusion(%p{{.*}}), kind=kInput, calls=%fused_computation)");
-}
-
-TEST_F(PriorityFusionTest, EpilogueFusionFails) {
-  auto module = *ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    add {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT add.13235 = f32[] add(p0, p1)
-    }
-
-    fused_computation.1 {
-      p0 = f32[28672,4096]{1,0} parameter(0)
-      c0 = f32[] constant(0)
-      ROOT r = f32[28672]{0} reduce(p0, c0), dimensions={1}, to_apply=add
-    }
-
-    fused_computation.2 {
-      p0 = f32[28672]{0} parameter(0)
-      p1 = f32[28672]{0} parameter(1)
-      ROOT a = f32[28672]{0} add(p0, p1)
-    }
-
-    ENTRY main {
-      p0 = f32[28672,4096]{1,0} parameter(0)
-      p1 = f32[28672]{0} parameter(1)
-      f = f32[28672]{0} fusion(p0), kind=kInput, calls=%fused_computation.1
-      ROOT fusion = f32[28672]{0} fusion(f,p1), kind=kLoop, calls=%fused_computation.2
-    })");
-
-  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseIntoRoot) {
-  auto module = *ParseAndReturnVerifiedModule(R"(
-    HloModule test_module
-
-    ENTRY %main (p.0: u32[2], p.1: u32[]) -> u32[2] {
-      %p.0 = u32[2]{0} parameter(0)
-      %p.1 = u32[] parameter(1)
-      ROOT %broadcast = u32[2]{0} broadcast(u32[] %p.1), dimensions={}, sharding={replicated}
-      %add = u32[2]{0} add(u32[2]{0} %p.0, u32[2]{0} %broadcast)
-      %tuple.1 = (u32[2]{0}) tuple(u32[2]{0} %add)
-      %token.0 = token[] after-all()
-      %outfeed.6 = token[] outfeed((u32[2]{0}) %tuple.1, token[] %token.0), outfeed_shape=(u32[2]{0}), sharding={maximal device=0}
-    })");
-
-  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(PriorityFusionTest, DontFuseConcat) {
-  // Regression test that verifies we don't fuse concat into a column reduction.
-  auto module = *ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    %maximum (param_0: f32[], param_1: f32[]) -> f32[] {
-      %param_0 = f32[] parameter(0)
-      %param_1 = f32[] parameter(1)
-      ROOT %maximum = f32[] maximum(f32[] %param_0, f32[] %param_1)
-    }
-
-    %fused_concat (param_0: f32[1,4,401,8,8], param_1: f32[1,1,4,1023,8], param_2: bf16[1,4,1023,8,8]) -> f32[1,4,1424,8,8] {
-      %param_2 = bf16[1,4,1023,8,8]{4,3,2,1,0} parameter(2)
-      %convert = f32[1,4,1023,8,8]{4,3,2,1,0} convert(bf16[1,4,1023,8,8]{4,3,2,1,0} %param_2)
-      %param_1 = f32[1,1,4,1023,8]{4,3,2,1,0} parameter(1)
-      %bitcast = f32[4,1023,8]{2,1,0} bitcast(f32[1,1,4,1023,8]{4,3,2,1,0} %param_1)
-      %broadcast = f32[1,4,1023,8,8]{4,3,2,1,0} broadcast(f32[4,1023,8]{2,1,0} %bitcast), dimensions={1,2,4}
-      %add = f32[1,4,1023,8,8]{4,3,2,1,0} add(f32[1,4,1023,8,8]{4,3,2,1,0} %convert, f32[1,4,1023,8,8]{4,3,2,1,0} %broadcast)
-      %param_0 = f32[1,4,401,8,8]{4,3,2,1,0} parameter(0)
-      ROOT %concatenate = f32[1,4,1424,8,8]{4,3,2,1,0} concatenate(f32[1,4,1023,8,8]{4,3,2,1,0} %add, f32[1,4,401,8,8]{4,3,2,1,0} %param_0), dimensions={2}
-    }
-
-    %fused_reduce (param_0: f32[], param_1: f32[1,4,1424,8,8]) -> f32[4,8,8] {
-      %param_1 = f32[1,4,1424,8,8]{4,3,2,1,0} parameter(1)
-      %bitcast = f32[4,1424,8,8]{3,2,1,0} bitcast(f32[1,4,1424,8,8]{4,3,2,1,0} %param_1)
-      %param_0 = f32[] parameter(0)
-      ROOT %reduce = f32[4,8,8]{2,1,0} reduce(f32[4,1424,8,8]{3,2,1,0} %bitcast, f32[] %param_0), dimensions={1}, to_apply=%maximum
-    }
-
-    %fused_broadcast (param_0: f32[1,4,1424,8,8], param_1: f32[4,8,8]) -> f32[1,4,1424,8,8] {
-      %param_0 = f32[1,4,1424,8,8]{4,3,2,1,0} parameter(0)
-      %param_1 = f32[4,8,8]{2,1,0} parameter(1)
-      %broadcast = f32[1,4,1424,8,8]{4,3,2,1,0} broadcast(f32[4,8,8]{2,1,0} %param_1), dimensions={1,3,4}
-      ROOT %subtract = f32[1,4,1424,8,8]{4,3,2,1,0} subtract(f32[1,4,1424,8,8]{4,3,2,1,0} %param_0, f32[1,4,1424,8,8]{4,3,2,1,0} %broadcast)
-    }
-
-    ENTRY fusion {
-      %param_0 = f32[1,4,401,8,8]{4,3,2,1,0} parameter(0)
-      %param_1 = f32[1,1,4,1023,8]{4,3,2,1,0} parameter(1)
-      %param_2 = bf16[1,4,1023,8,8]{4,3,2,1,0} parameter(2)
-      %concat = f32[1,4,1424,8,8]{4,3,2,1,0} fusion(%param_0, %param_1, %param_2), kind=kLoop, calls=fused_concat
-      %param_3 = f32[] parameter(3)
-      %reduce = f32[4,8,8]{2,1,0} fusion(%param_3, %concat), kind=kLoop, calls=fused_reduce
-      %param_4 = f32[4,8,8]{2,1,0} parameter(4)
-      %broadcast = f32[1,4,1424,8,8]{4,3,2,1,0} fusion(%concat, %param_4), kind=kLoop, calls=fused_broadcast
-      ROOT tuple = (f32[4,8,8]{2,1,0}, f32[1,4,1424,8,8]{4,3,2,1,0}) tuple(%reduce, %broadcast)
-    }
-  )");
-
-  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(PriorityFusionTest, FuseOnlySmallConstant) {
-  auto module = *ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    ENTRY main {
-      param_0 = f32[32,32]{1,0} parameter(0)
-      c_1 = f32[] constant(1)
-      c_2 = f32[32,32] constant({...})
-      broadcast = f32[32,32]{1,0} broadcast(c_1), dimensions={}
-      add = f32[32,32]{1,0} add(param_0, broadcast)
-      ROOT mul = f32[32,32]{1,0} multiply(c_2, add)
-    }
-  )");
-  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  ASSERT_THAT(root, GmockMatch(m::Fusion(m::Constant(), m::Parameter())));
-  EXPECT_THAT(root->fused_expression_root(),
-              GmockMatch(m::Multiply(
-                  m::Parameter(),
-                  m::Add(m::Parameter(), m::Broadcast(m::Constant())))));
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) {
-  auto module = *ParseAndReturnVerifiedModule(R"(
-    HloModule module
-
-    fused_computation.1 {
-      iota.9.7 = s32[3,1,1]{2,1,0} iota(), iota_dimension=0
-      param_3.29 = s32[] parameter(2)
-      pad.2.7 = s32[3,1,2]{2,1,0} pad(iota.9.7, param_3.29), padding=0_0x0_0x0_1
-      param_2.39 = s32[] parameter(1)
-      broadcast.76.1 = s32[3,1,2]{2,1,0} broadcast(param_2.39), dimensions={}
-      compare.9.1 = pred[3,1,2]{2,1,0} compare(pad.2.7, broadcast.76.1), direction=GE
-      param_1.73 = s32[2]{0} parameter(0)
-      broadcast.78.1 = s32[3,2]{1,0} broadcast(param_1.73), dimensions={1}
-      bitcast.1 = s32[3,2]{1,0} bitcast(pad.2.7)
-      compare.10.1 = pred[3,2]{1,0} compare(bitcast.1, broadcast.78.1), direction=LE
-      bitcast.2 = pred[3,1,2]{2,1,0} bitcast(compare.10.1)
-      ROOT and.3.1 = pred[3,1,2]{2,1,0} and(compare.9.1, bitcast.2)
-    }
-
-    and {
-      x = pred[] parameter(0)
-      y = pred[] parameter(1)
-      ROOT and = pred[] and(x, y)
-    }
-
-    fused_computation.2 {
-      param0 = pred[3,1,2]{2,1,0} parameter(0)
-      slice = pred[1,1,2]{2,1,0} slice(param0), slice={[0:1], [0:1], [0:2]}
-      bitcast = pred[2]{0} bitcast(slice)
-      init = pred[] constant(true)
-      reduce = pred[2]{0} reduce(param0, init), dimensions={0,1}, to_apply=and
-      and = pred[2]{0} and(bitcast, reduce)
-      pad = pred[3]{0} pad(and, init), padding=0_1
-      broadcast = pred[3,2]{1,0} broadcast(pad), dimensions={0}
-      bitcast2 = pred[6]{0} bitcast(broadcast)
-      broadcast2 = pred[2,3]{1,0} broadcast(pad), dimensions={1}
-      bitcast3 = pred[6]{0} bitcast(broadcast2)
-      ROOT and2 = pred[6]{0} and(bitcast2, bitcast3)
-    }
-
-    ENTRY main {
-      p0 = s32[2]{0} parameter(0)
-      p1 = s32[] parameter(1)
-      p2 = s32[] parameter(2)
-      fusion1 = pred[3,1,2]{2,1,0} fusion(p0, p1, p2), kind=kLoop, calls=fused_computation.1
-      ROOT fusion2 = pred[6]{0} fusion(fusion1), kind=kInput, calls=fused_computation.2
-    }
-  )");
-  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
-}
-
-TEST_F(PriorityFusionTest, CanMergeTritonFusionWithBothProducerAndConsumer) {
-#ifndef GOOGLE_CUDA
-  GTEST_SKIP() << "Triton fusion only enable for CUDA devices.";
-#endif
-
-  const std::string kHloText = R"(
-HloModule t
-add {
-  Arg_0 = f32[] parameter(0)
-  Arg_1 = f32[] parameter(1)
-  ROOT add = f32[] add(Arg_0, Arg_1)
-}
-
-producer_computation {
-  parameter_0 = f32[125]{0} parameter(0)
-  ROOT broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0}
-}
-
-consumer_computation {
-  parameter_0 = f32[125,127]{1,0} parameter(0)
-  parameter_1 = f32[125,127]{1,0} parameter(1)
-  ROOT multiply = f32[125,127]{1,0} multiply(parameter_1, parameter_0)
-}
-
-triton_softmax_computation {
-  parameter_0 = f32[125,127]{1,0} parameter(0)
-  multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0)
-  constant_0 = f32[] constant(0)
-  reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add
-  broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0}
-  ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4)
-}
-
-ENTRY main {
-  param_0 = f32[125]{0} parameter(0)
-  param_1 = f32[125,127]{1,0} parameter(1)
-  producer_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=producer_computation
-  triton_softmax = f32[125,127]{1,0} fusion(producer_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
-  ROOT consumer_fusion = f32[125,127]{1,0} fusion(param_1, triton_softmax), kind=kLoop, calls=consumer_computation
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
-  auto debug_options = module->config().debug_options();
-  debug_options.set_xla_gpu_enable_triton_softmax_priority_fusion(true);
-  module->mutable_config().set_debug_options(debug_options);
-
-  EXPECT_TRUE(priority_fusion_.Run(module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-
-  HloInstruction* root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
-  EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kCustom);
-  EXPECT_TRUE(IsGenericTritonFusion(*root));
-}
-
-TEST_F(PriorityFusionTest, DoNotFuseInsideReducer) {
-  auto module = *ParseAndReturnVerifiedModule(R"(
-    %reducer {
-      p0 = f32[] parameter(0)
-      p1 = f32[] parameter(1)
-      add = f32[] add(p0, p1)
-      ROOT max = f32[] maximum(add, p0)
-    }
-
-    %fused_reduce {
-      p0 = f32[256] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT reduce = f32[] reduce(p0, p1), dimensions={0}, to_apply=%reducer
-    }
-
-    ENTRY fusion {
-      p0 = f32[256] parameter(0)
-      p1 = f32[] parameter(1)
-      ROOT %reduce = f32[] fusion(p0, p1), kind=kInput, calls=fused_reduce
-    }
-  )");
-  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
-}
-
-}  // namespace gpu
-}  // namespace xla
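
The FuseOnlySmallConstant test above exercises the rule, spelled out in the removed pass, that only small constants are worth duplicating into fusions, while large ones are left alone because they are read from global memory anyway. A standalone sketch of such a size cutoff follows; the 16-byte threshold and the SmallConstant helper are illustrative assumptions, not the predicate the pass actually uses.

// Hypothetical predicate: treat a constant as fusible only if it is tiny.
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Returns true if a constant with the given dimensions and element size
// stays under an (assumed) 16-byte duplication budget.
bool SmallConstant(const std::vector<int64_t>& dims, int64_t element_bytes) {
  int64_t elements = std::accumulate(dims.begin(), dims.end(), int64_t{1},
                                     std::multiplies<int64_t>());
  return elements * element_bytes <= 16;
}

int main() {
  std::printf("f32[] scalar: %d\n", SmallConstant({}, 4));        // 1: fused
  std::printf("f32[32,32]:   %d\n", SmallConstant({32, 32}, 4));  // 0: kept
  return 0;
}
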
diff --git a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc
index b8496b3..cd0be0b 100644
--- a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc
+++ b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc
@@ -41,6 +41,7 @@
 #include "xla/service/gpu/gpu_executable.h"
 #include "xla/service/gpu/nvptx_compiler.h"
 #include "xla/service/hlo_module_config.h"
+#include "xla/stream_executor/cuda/nvjitlink_support.h"
 #include "xla/stream_executor/cuda/ptx_compilation_method.h"
 #include "xla/stream_executor/cuda/ptx_compiler_support.h"
 #include "xla/stream_executor/cuda/ptx_linking_method.h"
@@ -155,6 +156,29 @@
       // Compiled without libnvptxcompiler support
       GTEST_SKIP() << "libnvptxcompiler is not supported in this build.";
     }
+
+    if (!stream_executor::IsLibNvJitLinkSupported() &&
+        (compilation_method == PtxCompilationMethod::kNvJitLink ||
+         linking_method == PtxLinkingMethod::kNvJitLink)) {
+      // Compiled without libnvjitlink support
+      GTEST_SKIP() << "libnvjitlink is not supported in this build.";
+    }
+
+    if (compilation_method == PtxCompilationMethod::kNvJitLink &&
+        linking_method != PtxLinkingMethod::kNvJitLink) {
+      // When compilation method is NvJitLink, linking method must be NvJitLink
+      // as well.
+      GTEST_SKIP() << "Compilation method NvJitLink is only supported if the "
+                      "linking method is NvJitLink as well.";
+    }
+
+    if (compilation_method == PtxCompilationMethod::kPtxas &&
+        linking_method == PtxLinkingMethod::kNvJitLink) {
+      // We could support this combination, but it would require some
+      // refactoring of the flags.
+      GTEST_SKIP() << "Compilation method Ptxas is not supported with linking "
+                      "method NvJitLink.";
+    }
   }
 
   void SetDebugOptionsFromPtxSettings(DebugOptions* debug_options,
@@ -163,6 +187,10 @@
     debug_options->set_xla_gpu_enable_libnvptxcompiler(
         compilation_method == PtxCompilationMethod::kNvPtxCompiler);
 
+    debug_options->set_xla_gpu_enable_libnvjitlink(
+        compilation_method == PtxCompilationMethod::kNvJitLink ||
+        linking_method == PtxLinkingMethod::kNvJitLink);
+
     debug_options->set_xla_gpu_enable_llvm_module_compilation_parallelism(
         linking_method != PtxLinkingMethod::kNone);
     debug_options->set_xla_gpu_force_compilation_parallelism(12);
@@ -316,9 +344,11 @@
     ::testing::Combine(
         ::testing::Values("simple", "parallel_compilation", "requires_sm90a"),
         ::testing::Values(PtxCompilationMethod::kNvPtxCompiler,
-                          PtxCompilationMethod::kPtxas),
+                          PtxCompilationMethod::kPtxas,
+                          PtxCompilationMethod::kNvJitLink),
         ::testing::Values(PtxLinkingMethod::kNone, PtxLinkingMethod::kNvLink,
-                          PtxLinkingMethod::kDriver)),
+                          PtxLinkingMethod::kDriver,
+                          PtxLinkingMethod::kNvJitLink)),
     [](const ::testing::TestParamInfo<std::tuple<
            std::string_view, PtxCompilationMethod, PtxLinkingMethod>>& info) {
       return GenerateParametrizedTestname(std::get<0>(info.param),
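
The skip conditions added above encode which (compilation method, linking method) pairs the test matrix exercises: NvJitLink compilation is only run together with NvJitLink linking, and ptxas compilation is not combined with NvJitLink linking. Below is a small sketch of that compatibility check, using stand-in enums rather than the stream_executor types referenced in the diff.

// Hypothetical compatibility check mirroring the test's skip conditions.
#include <cstdio>

enum class Compile { kPtxas, kNvPtxCompiler, kNvJitLink };
enum class Link { kNone, kNvLink, kDriver, kNvJitLink };

bool SupportedCombination(Compile c, Link l) {
  // NvJitLink compilation only makes sense when NvJitLink also links.
  if (c == Compile::kNvJitLink && l != Link::kNvJitLink) return false;
  // ptxas output is not fed into NvJitLink in this test setup.
  if (c == Compile::kPtxas && l == Link::kNvJitLink) return false;
  return true;
}

int main() {
  std::printf("%d\n", SupportedCombination(Compile::kNvJitLink, Link::kDriver));     // 0
  std::printf("%d\n", SupportedCombination(Compile::kNvJitLink, Link::kNvJitLink));  // 1
  return 0;
}
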
diff --git a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.cc b/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.cc
deleted file mode 100644
index ac5419c..0000000
--- a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.cc
+++ /dev/null
@@ -1,131 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_degenerate_dim_remover.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor {
- public:
-  absl::Status HandleReduce(HloInstruction *hlo) override {
-    auto instr = Cast<HloReduceInstruction>(hlo);
-    absl::InlinedVector<HloInstruction *, 2> input_reshapes;
-    absl::InlinedVector<Shape, 2> canonical_reduce_shapes;
-
-    int idx = -1;
-    std::vector<int64_t> updated_reduced_dimensions;
-    for (HloInstruction *reduced_op : instr->inputs()) {
-      idx++;
-      const Shape &input_shape = reduced_op->shape();
-      const Shape &reduce_shape = instr->shape().IsTuple()
-                                      ? instr->shape().tuple_shapes(idx)
-                                      : instr->shape();
-
-      if (!ShapeUtil::HasDegenerateDimensions(reduced_op->shape())) {
-        return absl::OkStatus();
-      }
-      Shape canonical_input_shape =
-          ShapeUtil::DropDegenerateDimensions(input_shape);
-
-      Shape canonical_reduce_shape =
-          ShapeUtil::DropDegenerateDimensions(reduce_shape);
-
-      auto reduced_dimensions = instr->dimensions();
-      int64_t shift = 0;
-
-      for (int dim = 0; dim < input_shape.rank(); dim++) {
-        if (input_shape.dimensions(dim) == 1) {
-          shift++;
-        } else {
-          if (absl::c_linear_search(reduced_dimensions, dim) && idx == 0) {
-            // Only populate on first iteration.
-            updated_reduced_dimensions.push_back(dim - shift);
-          }
-        }
-      }
-
-      if (updated_reduced_dimensions.empty()) {
-        std::unique_ptr<HloInstruction> reshape =
-            HloInstruction::CreateBitcast(reduce_shape, reduced_op);
-        return ReplaceWithNewInstruction(instr, std::move(reshape));
-      }
-
-      input_reshapes.push_back(instr->parent()->AddInstruction(
-          HloInstruction::CreateBitcast(canonical_input_shape, reduced_op)));
-      canonical_reduce_shapes.push_back(canonical_reduce_shape);
-    }
-
-    Shape canonical_reduce_shape =
-        ShapeUtil::MakeMaybeTupleShape(canonical_reduce_shapes);
-    const Shape &orig_reduce_shape = instr->shape();
-    std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
-        canonical_reduce_shape, input_reshapes, instr->init_values(),
-        updated_reduced_dimensions, instr->to_apply());
-    instr->SetupDerivedInstruction(new_reduce.get());
-
-    if (canonical_reduce_shape != instr->shape()) {
-      HloInstruction *wrapped_reduce =
-          instr->parent()->AddInstruction(std::move(new_reduce));
-      absl::InlinedVector<HloInstruction *, 2> out;
-      if (!canonical_reduce_shape.IsTuple()) {
-        new_reduce =
-            HloInstruction::CreateBitcast(orig_reduce_shape, wrapped_reduce);
-      } else {
-        for (int oidx = 0; oidx < instr->input_count(); oidx++) {
-          HloInstruction *gte = instr->parent()->AddInstruction(
-              HloInstruction::CreateGetTupleElement(wrapped_reduce, oidx));
-          out.push_back(
-              instr->parent()->AddInstruction(HloInstruction::CreateBitcast(
-                  orig_reduce_shape.tuple_shapes(oidx), gte)));
-        }
-        new_reduce = HloInstruction::CreateTuple(out);
-      }
-    }
-
-    return ReplaceWithNewInstruction(instr, std::move(new_reduce));
-  }
-};
-
-absl::StatusOr<bool> ReductionDegenerateDimRemover::Run(
-    HloModule *module,
-    const absl::flat_hash_set<absl::string_view> &execution_threads) {
-  TF_ASSIGN_OR_RETURN(bool changed,
-                      ReductionDegenerateDimRemoverVisitor().RunOnModule(
-                          module, execution_threads));
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
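
The shift bookkeeping in the visitor deleted above remaps each reduced dimension index after the size-1 dimensions have been dropped from the input shape. The following is a minimal standalone sketch of that remapping under the same convention; RemapReducedDims is a hypothetical name introduced only for this illustration.

// Hypothetical sketch of remapping reduced dimensions after dropping
// degenerate (size-1) dimensions, as the removed visitor does.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> RemapReducedDims(const std::vector<int64_t>& dims,
                                      const std::vector<int64_t>& reduced) {
  std::vector<int64_t> remapped;
  int64_t shift = 0;  // Number of size-1 dimensions seen so far.
  for (int64_t dim = 0; dim < static_cast<int64_t>(dims.size()); ++dim) {
    if (dims[dim] == 1) {
      ++shift;  // Degenerate dimension disappears; later indices move left.
    } else if (std::find(reduced.begin(), reduced.end(), dim) !=
               reduced.end()) {
      remapped.push_back(dim - shift);
    }
  }
  return remapped;
}

int main() {
  // f[100,1,1] reduced over {0,1} becomes f[100] reduced over {0}.
  for (int64_t d : RemapReducedDims({100, 1, 1}, {0, 1}))
    std::printf("%ld\n", (long)d);
  return 0;
}
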
diff --git a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.h b/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.h
deleted file mode 100644
index 03d6819..0000000
--- a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_
-#define XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Enforces the invariant that reduction inputs and outputs have no degenerate
-// (size 1) dimensions. Since these dimensions are physically meaningless, they
-// are removed using bitcasts.
-//
-// For example,
-//
-//   f[1] out = reduce(f[100, 1, 1] input, dimensions={0, 1})
-//
-// becomes:
-//
-//
-//   f[100] tmp1 = f[100] bitcast(f[100, 1, 1], input)
-//   f[] tmp2 = reduce(f[100] tmp1, dimensions={0})
-//   f[1] out = f[] bitcast(tmp2)
-//
-class ReductionDegenerateDimRemover : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "reduction-degenerate-dim-remover";
-  }
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_
diff --git a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.cc b/third_party/xla/xla/service/gpu/reduction_dimension_grouper.cc
deleted file mode 100644
index 8ab4fcf..0000000
--- a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.cc
+++ /dev/null
@@ -1,122 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_dimension_grouper.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/layout_util.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-class ReduceDimensionGroupVisitor : public DfsHloRewriteVisitor {
- public:
-  absl::Status HandleReduce(HloInstruction *hlo) override {
-    auto reduce = Cast<HloReduceInstruction>(hlo);
-
-    VLOG(4) << "Input: " << reduce->ToString();
-
-    absl::InlinedVector<HloInstruction *, 2> reduce_inputs_grouped;
-    std::vector<int64_t> reduced_dims_grouped;
-
-    int idx = -1;
-    for (HloInstruction *operand : reduce->inputs()) {
-      idx++;
-      std::vector<int64_t> new_grouped_dims;
-      const Shape &shape = operand->shape();
-      CHECK(shape == LayoutUtil::GetWithDefaultLayout(shape))
-          << "Default layout should be enforced on reduction operand";
-      auto is_reduced = [&](int dim) {
-        return absl::c_linear_search(reduce->dimensions(), dim);
-      };
-
-      bool changed = false;
-      int64_t next_dim_size = 1;
-
-      // Since we have enforced the standard layout, iteration over logical
-      // dimensions is equivalent to iteration over the major-to-minor order.
-      for (int logical_dim = 0; logical_dim < shape.rank(); logical_dim++) {
-        VLOG(5) << "Processing dimension " << logical_dim << " of size "
-                << shape.dimensions(logical_dim);
-        if (is_reduced(logical_dim) && logical_dim < shape.rank() - 1 &&
-            is_reduced(logical_dim + 1)) {
-          VLOG(5) << "This and consecutive dimension are reduced, merging";
-          changed = true;
-          next_dim_size *= shape.dimensions(logical_dim);
-          continue;
-        }
-
-        if (is_reduced(logical_dim)) {
-          new_grouped_dims.push_back(next_dim_size *
-                                     shape.dimensions(logical_dim));
-          if (idx == 0) {
-            // Only populate for first argument.
-            reduced_dims_grouped.push_back(new_grouped_dims.size() - 1);
-          }
-          next_dim_size = 1;
-        } else {
-          new_grouped_dims.push_back(shape.dimensions(logical_dim));
-        }
-      }
-
-      if (!changed) {  // All inputs have the same shape dimensions.
-        return absl::OkStatus();
-      }
-
-      Shape grouped_shape =
-          ShapeUtil::MakeShape(shape.element_type(), new_grouped_dims);
-      reduce_inputs_grouped.push_back(reduce->parent()->AddInstruction(
-          HloInstruction::CreateBitcast(grouped_shape, operand),
-          &operand->metadata()));
-      VLOG(5) << "Adding bitcast: " << reduce_inputs_grouped.back()->ToString();
-    }
-
-    std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
-        reduce->shape(), reduce_inputs_grouped, reduce->init_values(),
-        reduced_dims_grouped, reduce->to_apply());
-    VLOG(5) << "Generated new reduction: " << new_reduce->ToString();
-    return ReplaceWithNewInstruction(reduce, std::move(new_reduce));
-  }
-};
-
-absl::StatusOr<bool> ReductionDimensionGrouper::Run(
-    HloModule *module,
-    const absl::flat_hash_set<absl::string_view> &execution_threads) {
-  TF_ASSIGN_OR_RETURN(bool changed, ReduceDimensionGroupVisitor().RunOnModule(
-                                        module, execution_threads));
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.h b/third_party/xla/xla/service/gpu/reduction_dimension_grouper.h
deleted file mode 100644
index 8ee4efd..0000000
--- a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_
-#define XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Groups adjacent (logically and physically) reduced dimensions in the
-// reduction input.
-//
-// Precondition: ReductionLayoutNormalizer has been run (physical proximity and
-// logical proximity become the same).
-//
-// For example,
-//
-//   f[] out = reduce(f[10,20,30] input, dimensions={0,1,2})
-//
-// becomes:
-//
-//   f[6000] tmp = f[6000] bitcast(f[10,20,30] input)
-//   f[] out = reduce(f[6000] tmp, dimensions={0})
-//
-class ReductionDimensionGrouper : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "reduction-dimension-grouper";
-  }
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_
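
A self-contained sketch of the adjacent-dimension grouping described above, on plain vectors; the struct and function are illustrative stand-ins for the HloReduceInstruction handling in the deleted visitor.

  #include <algorithm>
  #include <cassert>
  #include <cstdint>
  #include <vector>

  struct Grouped {
    std::vector<int64_t> dims;          // grouped input dimensions
    std::vector<int64_t> reduced_dims;  // reduced dims in the grouped shape
  };

  Grouped GroupReducedDims(const std::vector<int64_t>& dims,
                           const std::vector<int64_t>& reduced) {
    auto is_reduced = [&](int64_t d) {
      return std::find(reduced.begin(), reduced.end(), d) != reduced.end();
    };
    Grouped out;
    int64_t run = 1;  // product of the current run of adjacent reduced dims
    for (int64_t d = 0; d < static_cast<int64_t>(dims.size()); ++d) {
      if (is_reduced(d) && d + 1 < static_cast<int64_t>(dims.size()) &&
          is_reduced(d + 1)) {
        run *= dims[d];  // merge into the next reduced dimension
        continue;
      }
      if (is_reduced(d)) {
        out.dims.push_back(run * dims[d]);
        out.reduced_dims.push_back(static_cast<int64_t>(out.dims.size()) - 1);
        run = 1;
      } else {
        out.dims.push_back(dims[d]);
      }
    }
    return out;
  }

  int main() {
    // f[] out = reduce(f[10,20,30] input, dimensions={0,1,2}) becomes
    // reduce(f[6000] tmp, dimensions={0}).
    Grouped g = GroupReducedDims({10, 20, 30}, {0, 1, 2});
    assert(g.dims == std::vector<int64_t>({6000}));
    assert(g.reduced_dims == std::vector<int64_t>({0}));
    return 0;
  }
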
diff --git a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.cc b/third_party/xla/xla/service/gpu/reduction_layout_normalizer.cc
deleted file mode 100644
index a91fdf7..0000000
--- a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.cc
+++ /dev/null
@@ -1,203 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_layout_normalizer.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/layout.h"
-#include "xla/layout_util.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor {
-  absl::Status HandleReduce(HloInstruction *hlo) override {
-    auto reduce = Cast<HloReduceInstruction>(hlo);
-    VLOG(5) << "Input: " << reduce->ToString();
-
-    int operand_idx = -1;
-
-    absl::InlinedVector<HloInstruction *, 2> canonical_reduce_inputs;
-    absl::InlinedVector<Shape, 2> new_reduce_shapes;
-
-    DimensionVector out_reduce_dimensions;
-    const Shape &first_instruction_shape = reduce->inputs()[0]->shape();
-
-    for (HloInstruction *operand : reduce->inputs()) {
-      operand_idx++;
-
-      if (operand_idx != 0 &&
-          operand->shape().layout() != first_instruction_shape.layout()) {
-        HloInstruction *copy =
-            reduce->parent()->AddInstruction(HloInstruction::CreateUnary(
-                operand->shape(), HloOpcode::kCopy, operand));
-
-        LayoutUtil::ClearLayout(copy->mutable_shape());
-        TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
-            first_instruction_shape, copy->mutable_shape()));
-
-        copy->set_metadata(operand->metadata());
-        operand = copy;
-        VLOG(3) << "Copying to establish consistent inputs layout: "
-                << copy->ToString();
-      }
-
-      const Shape &operand_shape = operand->shape();
-      const Layout &operand_layout = operand_shape.layout();
-
-      const Shape &reduce_shape =
-          reduce->shape().IsTuple() ? reduce->shape().tuple_shapes(operand_idx)
-                                    : reduce->shape();
-
-      DimensionVector new_reduce_dimensions;
-      DimensionVector new_operand_shape_data;
-      DimensionVector new_reduce_shape_data;
-
-      // The layout order of the reduction output can be different from the
-      // ordering of the kept dimensions in the input operand, so we need to
-      // compute the new layout.
-      DimensionVector new_reduce_shape_layout(reduce_shape.rank());
-      std::vector<int64_t> reduce_shape_logical_to_physical =
-          LayoutUtil::MakeLogicalToPhysical(reduce_shape.layout());
-
-      auto to_reduce_logical_dim = [&](int64_t op_logical_dim) {
-        return op_logical_dim -
-               absl::c_count_if(reduce->dimensions(), [&](int64_t dim) {
-                 CHECK(dim != op_logical_dim);
-                 return dim < op_logical_dim;
-               });
-      };
-
-      for (int i = 0; i < operand_shape.rank(); i++) {
-        // Process the dimensions in major-to-minor order to enforce the
-        // default layout.
-        int64_t major_to_minor_dim_idx = operand_shape.rank() - i - 1;
-        int64_t logical_dim =
-            operand_layout.minor_to_major(major_to_minor_dim_idx);
-        int64_t dim_size = operand_shape.dimensions(logical_dim);
-        VLOG(5) << "Processing logical dimension " << logical_dim << " of size "
-                << dim_size;
-        new_operand_shape_data.push_back(dim_size);
-
-        if (absl::c_linear_search(reduce->dimensions(), logical_dim)) {
-          new_reduce_dimensions.push_back(i);
-        } else {
-          new_reduce_shape_data.push_back(dim_size);
-          int64_t logical_reduce_dim = to_reduce_logical_dim(logical_dim);
-          int64_t physical_reduce_dim =
-              reduce_shape_logical_to_physical[logical_reduce_dim];
-          VLOG(5) << "logical_reduce_dim = " << logical_reduce_dim << ", "
-                  << "physical_reduce_dim = " << physical_reduce_dim;
-          new_reduce_shape_layout[reduce_shape.rank() - physical_reduce_dim -
-                                  1] = new_reduce_shape_data.size() - 1;
-        }
-      }
-
-      Shape new_operand_shape = ShapeUtil::MakeShape(
-          operand_shape.element_type(), new_operand_shape_data);
-      Shape new_reduce_shape = ShapeUtil::MakeShapeWithDenseLayout(
-          reduce_shape.element_type(), new_reduce_shape_data,
-          new_reduce_shape_layout);
-
-      if (new_operand_shape == operand_shape && reduce->inputs().size() == 1) {
-        return absl::OkStatus();
-      }
-
-      HloInstruction *canonical_reduce_input =
-          new_operand_shape != operand_shape
-              ? reduce->parent()->AddInstruction(
-                    HloInstruction::CreateBitcast(new_operand_shape, operand))
-              : operand;
-      canonical_reduce_input->set_metadata(operand->metadata());
-      VLOG(5) << "Reduction input: " << canonical_reduce_input->ToString();
-
-      new_reduce_shapes.push_back(new_reduce_shape);
-      canonical_reduce_inputs.push_back(canonical_reduce_input);
-
-      if (out_reduce_dimensions.empty()) {
-        out_reduce_dimensions = new_reduce_dimensions;
-      } else {
-        TF_RET_CHECK(out_reduce_dimensions == new_reduce_dimensions);
-      }
-    }
-
-    Shape new_reduce_shape = ShapeUtil::MakeMaybeTupleShape(new_reduce_shapes);
-
-    std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
-        new_reduce_shape, canonical_reduce_inputs, reduce->init_values(),
-        out_reduce_dimensions, reduce->to_apply());
-    VLOG(5) << "Generated new reduction: " << new_reduce->ToString();
-    const Shape &orig_reduce_shape = reduce->shape();
-
-    if (new_reduce_shape != orig_reduce_shape) {
-      HloInstruction *wrapped_reduce =
-          reduce->parent()->AddInstruction(std::move(new_reduce));
-
-      if (!new_reduce_shape.IsTuple()) {
-        new_reduce =
-            HloInstruction::CreateBitcast(reduce->shape(), wrapped_reduce);
-      } else {
-        // Bitcast each element of the tuple.
-        absl::InlinedVector<HloInstruction *, 2> out;
-        for (int oidx = 0; oidx < reduce->input_count(); oidx++) {
-          HloInstruction *gte = reduce->parent()->AddInstruction(
-              HloInstruction::CreateGetTupleElement(wrapped_reduce, oidx));
-          out.push_back(
-              reduce->parent()->AddInstruction(HloInstruction::CreateBitcast(
-                  orig_reduce_shape.tuple_shapes(oidx), gte)));
-        }
-        new_reduce = HloInstruction::CreateTuple(out);
-      }
-    }
-
-    VLOG(5) << "Generated output: " << new_reduce->ToString();
-    return ReplaceWithNewInstruction(reduce, std::move(new_reduce));
-  }
-};
-
-absl::StatusOr<bool> ReductionLayoutNormalizer::Run(
-    HloModule *module,
-    const absl::flat_hash_set<absl::string_view> &execution_threads) {
-  TF_ASSIGN_OR_RETURN(bool changed,
-                      EnforceMinorToMajorReduceOpVisitor().RunOnModule(
-                          module, execution_threads));
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.h b/third_party/xla/xla/service/gpu/reduction_layout_normalizer.h
deleted file mode 100644
index 7d2d207..0000000
--- a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.h
+++ /dev/null
@@ -1,54 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_
-#define XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Enforces the default (minor-to-major) layout on all reduction inputs.
-// Note that since the reduction output can request a custom layout,
-// this pass only guarantees the standard layout for the inputs.
-//
-// For example,
-//
-//   f[20,30]{0,1} out = reduce(f[10,20,30]{2,0,1} input, dimensions={0})
-//
-// becomes:
-//
-//   f[20,10,30] tmp = f[20,10,30] bitcast(f[10,20,30]{2,0,1} input)
-//   f[20,30]{0,1} out = reduce(f[20,10,30]{2,1,0} tmp, dimensions={1})
-class ReductionLayoutNormalizer : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "reduction-layout-normalizer";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_
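
A hedged, self-contained sketch of the shape normalization in the example above; plain vectors stand in for xla::Shape and xla::Layout, and the helper is illustrative rather than the pass itself.

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Lists the dimensions in physical major-to-minor order, so that the logical
  // order of the result matches the default {rank-1, ..., 1, 0} layout.
  std::vector<int64_t> NormalizeToDefaultLayout(
      const std::vector<int64_t>& dims,
      const std::vector<int64_t>& minor_to_major) {
    std::vector<int64_t> normalized;
    for (int64_t i = static_cast<int64_t>(minor_to_major.size()) - 1; i >= 0;
         --i) {
      normalized.push_back(dims[minor_to_major[i]]);
    }
    return normalized;
  }

  int main() {
    // f[10,20,30]{2,0,1} input becomes the f[20,10,30] default-layout operand
    // from the example above; the reduced dimension 0 (size 10) then sits at
    // logical position 1, which is why the rewritten reduce uses dimensions={1}.
    assert(NormalizeToDefaultLayout({10, 20, 30}, {2, 0, 1}) ==
           std::vector<int64_t>({20, 10, 30}));
    return 0;
  }
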
diff --git a/third_party/xla/xla/service/gpu/reduction_splitter.cc b/third_party/xla/xla/service/gpu/reduction_splitter.cc
deleted file mode 100644
index cd37319..0000000
--- a/third_party/xla/xla/service/gpu/reduction_splitter.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_splitter.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <cstdlib>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/reduction_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-class ReductionSplitterVisitor : public DfsHloRewriteVisitor {
- public:
-  explicit ReductionSplitterVisitor(bool ignore_small_dims)
-      : ignore_small_dims_(ignore_small_dims) {}
-  absl::Status HandleReduce(HloInstruction *reduce) override {
-    VLOG(4) << "Input: " << reduce->ToString();
-
-    // Reductions with contiguous dimensions are lowered to efficient code. No
-    // need to split such ops.
-    if (IsReductionFromOrToContiguousDimensions(*reduce)) {
-      VLOG(4) << "Reduction with contiguous dimensions. Return.";
-      return absl::OkStatus();
-    }
-    if (reduce->dimensions().size() < 2) {
-      return absl::OkStatus();
-    }
-    if (!reduce->shape().IsArray()) {
-      // TODO(cheshire): Handle variadic reduction.
-      return absl::OkStatus();
-    }
-
-    HloInstruction *operand = reduce->mutable_operand(0);
-    const Shape &shape = operand->shape();
-    CHECK(shape == LayoutUtil::GetWithDefaultLayout(shape))
-        << "Default layout should be enforced on reduction operand";
-    // Verify that contiguous dimensions have been grouped by the
-    // ReductionDimensionGrouper pass.
-    for (int64_t i = 0; i < reduce->dimensions().size(); ++i) {
-      for (int64_t j = i + 1; j < reduce->dimensions().size(); ++j) {
-        CHECK(abs(reduce->dimensions(i) - reduce->dimensions(j)) > 1)
-            << "Reduction dimensions must not be consecutive";
-      }
-    }
-
-    // The reduce op has non-contiguous dimensions. Look for the reduce
-    // dimension with the largest extent: reducing along it first shrinks the
-    // intermediate output most effectively.
-    int64_t max_shape_dim = 0;
-    int64_t max_reduce_dim = 0;
-    const auto &input_shape = reduce->operand(0)->shape();
-    for (int64_t i = 0; i < reduce->dimensions().size(); ++i) {
-      if (input_shape.dimensions(reduce->dimensions(i)) > max_shape_dim) {
-        max_reduce_dim = reduce->dimensions(i);
-        max_shape_dim = input_shape.dimensions(max_reduce_dim);
-      }
-    }
-    if (ignore_small_dims_ && max_shape_dim <= 8) {
-      return absl::OkStatus();
-    }
-
-    // Split the reduction into a pre-reduction and a final reduction.
-    VLOG(3) << "Splitting reduction " << reduce->name() << " at dimension "
-            << max_reduce_dim;
-    std::vector<int64_t> pre_reduce_dims;
-    pre_reduce_dims.push_back(max_reduce_dim);
-    std::vector<int64_t> pre_reduce_shape_dims(input_shape.dimensions().begin(),
-                                               input_shape.dimensions().end());
-    pre_reduce_shape_dims.erase(pre_reduce_shape_dims.begin() + max_reduce_dim);
-    Shape pre_reduce_shape = ShapeUtil::MakeShape(
-        reduce->shape().element_type(), pre_reduce_shape_dims);
-    std::unique_ptr<HloInstruction> pre_reduce = HloInstruction::CreateReduce(
-        pre_reduce_shape, reduce->mutable_operand(0),
-        reduce->mutable_operand(1), pre_reduce_dims, reduce->to_apply());
-    pre_reduce->set_metadata(reduce->metadata());
-
-    std::vector<int64_t> final_reduce_dims(reduce->dimensions().begin(),
-                                           reduce->dimensions().end());
-    final_reduce_dims.erase(
-        std::remove(final_reduce_dims.begin(), final_reduce_dims.end(),
-                    max_reduce_dim),
-        final_reduce_dims.end());
-    for (int64_t i = 0; i < final_reduce_dims.size(); ++i) {
-      if (final_reduce_dims[i] > max_reduce_dim) {
-        final_reduce_dims[i]--;
-      }
-    }
-    std::unique_ptr<HloInstruction> final_reduce = HloInstruction::CreateReduce(
-        reduce->shape(),
-        reduce->parent()->AddInstruction(std::move(pre_reduce)),
-        reduce->mutable_operand(1), final_reduce_dims, reduce->to_apply());
-    return ReplaceWithNewInstruction(reduce, std::move(final_reduce));
-  }
-
- private:
-  bool ignore_small_dims_;
-};
-
-absl::StatusOr<bool> ReductionSplitter::Run(
-    HloModule *module,
-    const absl::flat_hash_set<absl::string_view> &execution_threads) {
-  TF_ASSIGN_OR_RETURN(bool changed,
-                      ReductionSplitterVisitor(ignore_small_dims_)
-                          .RunOnModule(module, execution_threads));
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/reduction_splitter.h b/third_party/xla/xla/service/gpu/reduction_splitter.h
deleted file mode 100644
index 7e76525..0000000
--- a/third_party/xla/xla/service/gpu/reduction_splitter.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_
-#define XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Splits a reduce op into two consecutive reduce ops if the reduce dimensions
-// are not contiguous. Ignores small reduce dimensions if `ignore_small_dims` is
-// set.
-//
-// Reductions with non-contiguous dimensions are emitted as simple element-wise
-// loops. This is inefficient when reducing large input shape dimensions.
-// Splitting such reductions allows using more efficient reduction emitters.
-//
-// Each invocation splits a reduce op into two consecutive reduce ops; run the
-// pass to a fixpoint to split reductions along multiple dimensions.
-//
-// Precondition: ReductionDimensionGrouper has been run and adjacent reduce
-// dimensions have been grouped. Reduction layouts have been normalized.
-
-class ReductionSplitter : public HloModulePass {
- public:
-  explicit ReductionSplitter(bool ignore_small_dims)
-      : ignore_small_dims_(ignore_small_dims) {}
-  absl::string_view name() const override { return "reduction-splitter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  bool ignore_small_dims_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_
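
A small sketch of the split arithmetic on plain vectors, matching the SplitReductionAtDimensionZero test in the deleted reduction_splitter_test.cc below; the struct and helper are illustrative only, since the real visitor builds HLO instructions.

  #include <cassert>
  #include <cstdint>
  #include <vector>

  struct Split {
    int64_t pre_reduce_dim;                  // dimension reduced first
    std::vector<int64_t> final_reduce_dims;  // remaining dims, renumbered
  };

  Split SplitReduction(const std::vector<int64_t>& input_dims,
                       const std::vector<int64_t>& reduce_dims) {
    // Pick the reduce dimension with the largest extent.
    int64_t max_dim = reduce_dims[0];
    for (int64_t d : reduce_dims) {
      if (input_dims[d] > input_dims[max_dim]) max_dim = d;
    }
    // The remaining reduce dims shift left by one if they follow the removed
    // dimension.
    Split split{max_dim, {}};
    for (int64_t d : reduce_dims) {
      if (d == max_dim) continue;
      split.final_reduce_dims.push_back(d > max_dim ? d - 1 : d);
    }
    return split;
  }

  int main() {
    // f[16,64] out = reduce(f[1024,16,512,64,128] p0, dimensions={2,0,4})
    // splits into a pre-reduce over dimension 0 and a final reduce over
    // dimensions {1, 3}.
    Split split = SplitReduction({1024, 16, 512, 64, 128}, {2, 0, 4});
    assert(split.pre_reduce_dim == 0);
    assert(split.final_reduce_dims == std::vector<int64_t>({1, 3}));
    return 0;
  }
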
diff --git a/third_party/xla/xla/service/gpu/reduction_splitter_test.cc b/third_party/xla/xla/service/gpu/reduction_splitter_test.cc
deleted file mode 100644
index 13a5210..0000000
--- a/third_party/xla/xla/service/gpu/reduction_splitter_test.cc
+++ /dev/null
@@ -1,152 +0,0 @@
-/* Copyright 2016 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_splitter.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape_util.h"
-#include "xla/test.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class ReductionSplitterTest : public HloTestBase {};
-
-TEST_F(ReductionSplitterTest, SplitReductionAtDimensionTwo) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test
-
-  add_computation {
-    x = f32[] parameter(0)
-    y = f32[] parameter(1)
-    ROOT add = f32[] add(x, y)
-  }
-
-  ENTRY entry_computation {
-    param_0 = f16[6,16,512,64]{3,2,1,0} parameter(0)
-    transpose.1781 = f16[6,512,16,64]{3,1,2,0} transpose(param_0), dimensions={0,2,1,3}
-    convert.6986 = f32[6,512,16,64]{3,1,2,0} convert(transpose.1781)
-    bitcast.2136 = f32[6,16,512,64]{3,2,1,0} bitcast(convert.6986)
-    constant_11111 = f32[] constant(0)
-    ROOT reduce.982 = f32[16,64]{1,0} reduce(bitcast.2136, constant_11111), dimensions={0,2}, to_apply=add_computation
-  }
-  )")
-                    .value();
-  ASSERT_TRUE(
-      ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root_reduction =
-      module->entry_computation()->root_instruction();
-  ASSERT_THAT(root_reduction,
-              GmockMatch(m::Reduce(m::Reduce(), m::Constant())));
-
-  auto* pre_reduction = root_reduction->operand(0);
-  EXPECT_THAT(pre_reduction->dimensions(), std::vector<int64_t>({2}));
-  EXPECT_THAT(pre_reduction->shape(), ShapeUtil::MakeShape(F32, {6, 16, 64}));
-  EXPECT_THAT(root_reduction->dimensions(), std::vector<int64_t>({0}));
-  EXPECT_THAT(root_reduction->shape(), ShapeUtil::MakeShape(F32, {16, 64}));
-}
-
-TEST_F(ReductionSplitterTest, SplitReductionAtDimensionZero) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test
-
-  add_computation {
-    x = f32[] parameter(0)
-    y = f32[] parameter(1)
-    ROOT add = f32[] add(x, y)
-  }
-
-  ENTRY entry_computation {
-    param_0 = f32[1024,16,512,64,128]{4,3,2,1,0} parameter(0)
-    constant_11111 = f32[] constant(0)
-    ROOT reduce.982 = f32[16,64]{1,0} reduce(param_0, constant_11111), dimensions={2,0,4}, to_apply=add_computation
-  }
-  )")
-                    .value();
-  ASSERT_TRUE(
-      ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
-  SCOPED_TRACE(module->ToString());
-  const HloInstruction* root_reduction =
-      module->entry_computation()->root_instruction();
-  ASSERT_THAT(root_reduction,
-              GmockMatch(m::Reduce(m::Reduce(), m::Constant())));
-
-  auto* pre_reduction = root_reduction->operand(0);
-  EXPECT_THAT(pre_reduction->dimensions(), std::vector<int64_t>({0}));
-  EXPECT_THAT(pre_reduction->shape(),
-              ShapeUtil::MakeShape(F32, {16, 512, 64, 128}));
-  EXPECT_THAT(root_reduction->dimensions(), std::vector<int64_t>({1, 3}));
-  EXPECT_THAT(root_reduction->shape(), ShapeUtil::MakeShape(F32, {16, 64}));
-}
-
-TEST_F(ReductionSplitterTest, DontSplitReductionWithSmallDimensions) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test
-
-  add_computation {
-    x = f32[] parameter(0)
-    y = f32[] parameter(1)
-    ROOT add = f32[] add(x, y)
-  }
-
-  ENTRY entry_computation {
-    param_0 = f32[16,8,1024,8]{3,2,1,0} parameter(0)
-    constant_11111 = f32[] constant(0)
-    ROOT reduce.982 = f32[16,1024]{1,0} reduce(param_0, constant_11111), dimensions={3,1}, to_apply=add_computation
-  }
-  )")
-                    .value();
-  EXPECT_FALSE(
-      ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value());
-  EXPECT_TRUE(
-      ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
-}
-
-TEST_F(ReductionSplitterTest, DontSplitReductionsWithContiguousDimensions) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule test
-
-  add_computation {
-    x = f32[] parameter(0)
-    y = f32[] parameter(1)
-    ROOT add = f32[] add(x, y)
-  }
-
-  ENTRY entry_computation {
-    param_0 = f32[128,128,64,128]{3,2,1,0} parameter(0)
-    constant_11111 = f32[] constant(0)
-    // The dimensions to keep (1 and 2) are contiguous.
-    ROOT reduce.982 = f32[128,64]{1,0} reduce(param_0, constant_11111), dimensions={3,0}, to_apply=add_computation
-  }
-  )")
-                    .value();
-  EXPECT_FALSE(
-      ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/rename_fusions.cc b/third_party/xla/xla/service/gpu/rename_fusions.cc
deleted file mode 100644
index a2a3048..0000000
--- a/third_party/xla/xla/service/gpu/rename_fusions.cc
+++ /dev/null
@@ -1,92 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/rename_fusions.h"
-
-#include <memory>
-#include <string>
-
-#include "absl/container/btree_set.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/str_replace.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-constexpr absl::string_view FusionKindToString(
-    HloInstruction::FusionKind kind) {
-  switch (kind) {
-    case HloInstruction::FusionKind::kCustom:
-      return "custom";
-    case HloInstruction::FusionKind::kLoop:
-      return "loop";
-    case HloInstruction::FusionKind::kInput:
-      return "input";
-    case HloInstruction::FusionKind::kOutput:
-      return "output";
-  }
-}
-
-std::string MakeFusionHeroNames(const HloInstruction* instruction) {
-  std::unique_ptr<HloFusionAdaptor> fusion_adaptor =
-      HloFusionAdaptor::ForInstruction(instruction);
-  absl::btree_set<absl::string_view> heroes;
-
-  for (auto root : fusion_adaptor->GetRoots()) {
-    heroes.insert(HloOpcodeString(FindNonTrivialHero(root).opcode()));
-  }
-  return absl::StrReplaceAll(absl::StrJoin(heroes, "_"), {{"-", "_"}});
-}
-
-void RenameFusion(HloModule* module, HloInstruction* instruction) {
-  std::string hero_names = MakeFusionHeroNames(instruction);
-  module->SetAndUniquifyInstrName(
-      instruction, absl::StrCat(FusionKindToString(instruction->fusion_kind()),
-                                "_", hero_names, "_fusion"));
-  module->SetAndUniquifyComputationName(
-      instruction->fused_instructions_computation(),
-      absl::StrCat("fused_", hero_names));
-}
-
-}  // namespace
-
-absl::StatusOr<bool> RenameFusions::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  for (HloComputation* computation : module->MakeNonfusionComputations()) {
-    for (HloInstruction* instruction : computation->instructions()) {
-      if (instruction->opcode() != HloOpcode::kFusion ||
-          instruction->fusion_kind() == HloInstruction::FusionKind::kCustom) {
-        continue;
-      }
-      RenameFusion(module, instruction);
-    }
-  }
-  return true;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/rename_fusions.h b/third_party/xla/xla/service/gpu/rename_fusions.h
deleted file mode 100644
index c3065a4..0000000
--- a/third_party/xla/xla/service/gpu/rename_fusions.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_RENAME_FUSIONS_H_
-#define XLA_SERVICE_GPU_RENAME_FUSIONS_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// An HLO pass that gives fusions and fused computations descriptive names.
-//
-// The name is based on the hero instructions and the fusion kind:
-// fusions are named "<fusion kind>_<hero instructions>_fusion",
-// and fused computations are named "fused_<hero instructions>".
-// In the case of multiple roots, the hero instructions in the name are
-// underscore-separated and alphabetically sorted.
-
-class RenameFusions : public HloModulePass {
-  absl::string_view name() const override { return "rename_fusions"; }
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_RENAME_FUSIONS_H_
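
A tiny sketch of the naming scheme described above, on plain strings; the helper is hypothetical and only mimics what the pass derives from HloFusionAdaptor roots and FindNonTrivialHero.

  #include <iostream>
  #include <set>
  #include <string>

  std::string FusionName(const std::string& kind,
                         const std::set<std::string>& heroes) {
    std::string joined;
    for (const std::string& hero : heroes) {  // std::set keeps heroes sorted
      if (!joined.empty()) joined += "_";
      joined += hero;
    }
    for (char& c : joined) {
      if (c == '-') c = '_';  // e.g. "all-reduce" becomes "all_reduce"
    }
    return kind + "_" + joined + "_fusion";
  }

  int main() {
    // A loop fusion whose hero is a multiply is named "loop_multiply_fusion";
    // its fused computation would be named "fused_multiply".
    std::cout << FusionName("loop", {"multiply"}) << "\n";
    return 0;
  }
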
diff --git a/third_party/xla/xla/service/gpu/rename_fusions_test.cc b/third_party/xla/xla/service/gpu/rename_fusions_test.cc
deleted file mode 100644
index 60c97cf..0000000
--- a/third_party/xla/xla/service/gpu/rename_fusions_test.cc
+++ /dev/null
@@ -1,83 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/rename_fusions.h"
-
-#include <utility>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace gpu {
-
-class RenameFusionsTest : public HloTestBase {
- protected:
-  RenameFusions rename_fusions_;
-};
-
-TEST_F(RenameFusionsTest, FusionInstructionNames) {
-  absl::string_view kHlo = R"(
-      HloModule test_module
-
-      square {
-        p = f32[16384] parameter(0)
-        ROOT m = f32[16384] multiply(p, p)
-      }
-
-      exp {
-        p = f32[16384] parameter(0)
-        ROOT e = f32[16384] exponential(p)
-      }
-
-      log {
-        p = f32[16384] parameter(0)
-        ROOT l = f32[16384] log(p)
-      }
-
-      add {
-        p0 = f32[] parameter(0)
-        p1 = f32[] parameter(1)
-        ROOT add = f32[] add(p0, p1)
-      }
-
-      ENTRY main {
-        p0 = bf16[1024,8192] parameter(0)
-        p1 = f32[8192] parameter(1)
-        p2 = f32[16384] parameter(2)
-        convert = f32[1024,8192] convert(p0)
-        broadcast = f32[1024,8192] broadcast(p1), dimensions={1}
-        c0 = f32[] constant(0)
-        multiply = f32[1024,8192] multiply(broadcast, convert)
-        reduce = f32[1024] reduce(multiply, c0), dimensions={1}, to_apply=add
-        convert.1 = bf16[1024] convert(reduce)
-        s = f32[16384] fusion(p2), kind=kLoop, calls=square
-        e = f32[16384] fusion(s), kind=kLoop, calls=exp
-        l = f32[16384] fusion(s), kind=kInput, calls=log
-        ROOT result = (bf16[1024]{0}, f32[16384]{0}, f32[16384]{0}) tuple(convert.1, l, e)
-      })";
-
-  RunAndFilecheckHloRewrite(kHlo, std::move(rename_fusions_), R"(
-CHECK: ENTRY %main
-CHECK: %loop_multiply_fusion{{.*}} calls=%fused_multiply
-CHECK: %input_log_fusion{{.*}} calls=%fused_log
-CHECK: %loop_exponential_fusion{{.*}} calls=%fused_exponential
-CHECK: ROOT %result
-  )");
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD
index 861fcca..d6b6f1f 100644
--- a/third_party/xla/xla/service/gpu/runtime/BUILD
+++ b/third_party/xla/xla/service/gpu/runtime/BUILD
@@ -79,7 +79,6 @@
         "//xla/service:executable",
         "//xla/service:global_device_id",
         "//xla/service/gpu:buffer_allocations",
-        "//xla/service/gpu:gpu_fused_mha_runner",
         "//xla/service/gpu:launch_dimensions",
         "//xla/service/gpu:matmul_utils",
         "//xla/service/gpu:stream_executor_util",
@@ -121,7 +120,6 @@
         ":copy_thunk",
         ":cudnn_thunk",
         ":custom_call_thunk",
-        ":fused_mha_thunk",
         ":gemm_thunk",
         ":gpublas_lt_matmul_thunk",
         ":kernel_thunk",
@@ -661,28 +659,6 @@
 )
 
 cc_library(
-    name = "fused_mha_thunk",
-    srcs = ["fused_mha_thunk.cc"],
-    hdrs = ["fused_mha_thunk.h"],
-    deps = [
-        ":thunk",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/service:buffer_assignment",
-        "//xla/service/gpu:buffer_allocations",
-        "//xla/service/gpu:gpu_fused_mha_runner",
-        "//xla/stream_executor",
-        "//xla/stream_executor:lazy_op_runner",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/synchronization",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-cc_library(
     name = "gemm_thunk",
     srcs = ["gemm_thunk.cc"],
     hdrs = ["gemm_thunk.h"],
diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc
index 5a05bb4..ce99623 100644
--- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc
+++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc
@@ -1168,314 +1168,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// FusedMHACmd
-//===----------------------------------------------------------------------===//
-
-FusedMHACmd::FusedMHACmd(
-    ExecutionStreamId execution_stream_id, GpufMHAConfig config,
-    BufferAllocation::Slice lhs_bmm1, BufferAllocation::Slice rhs_bmm1,
-    BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output,
-    BufferAllocation::Slice scratch, BufferAllocation::Slice mask,
-    BufferAllocation::Slice bias, BufferAllocation::Slice activation,
-    BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k)
-    : TracedCommandBufferCmd(CommandBufferCmdType::kFusedMHACmd,
-                             execution_stream_id),
-      config_(std::move(config)),
-      lhs_bmm1_buffer_(lhs_bmm1),
-      rhs_bmm1_buffer_(rhs_bmm1),
-      rhs_bmm2_buffer_(rhs_bmm2),
-      output_buffer_(output),
-      scratch_buffer_(scratch),
-      bias_buffer_(bias),
-      activation_buffer_(activation),
-      seqlen_q_buffer_(seqlen_q),
-      seqlen_k_buffer_(seqlen_k) {}
-
-FusedMultiHeadedAttentionRunner& FusedMHACmd::GetOrCreateRunner(
-    const stream_executor::Stream* stream) {
-  absl::MutexLock lock(&mutex_);
-  auto it = runner_cache_.find(stream);
-  if (it == runner_cache_.end()) {
-    it = runner_cache_
-             .insert({stream, std::make_unique<FusedMultiHeadedAttentionRunner>(
-                                  config_)})
-             .first;
-  }
-  return *it->second;
-}
-
-absl::Status FusedMHACmd::Initialize(const Thunk::InitializeParams& params,
-                                     StateManager& state) {
-  se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>* lazy_runner =
-      GetOrCreateRunner(params.command_buffer_trace_stream).AsFusedMHARunner();
-  TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHAOpConfig());
-  return lazy_runner
-      ->GetOrCreateRunner(config, params.command_buffer_trace_stream)
-      .status();
-}
-
-absl::Status FusedMHACmd::Record(const Thunk::ExecuteParams& execute_params,
-                                 const RecordParams& record_params,
-                                 se::CommandBuffer* command_buffer) {
-  se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>* lazy_runner =
-      GetOrCreateRunner(execute_params.command_buffer_trace_stream)
-          .AsFusedMHARunner();
-  CHECK(lazy_runner) << "FusedMHA lazy runner cache should have been populated";
-
-  const auto& buffer_allocations = *execute_params.buffer_allocations;
-  se::DeviceMemoryBase lhs_bmm1_buffer =
-      buffer_allocations.GetDeviceAddress(lhs_bmm1_buffer_);
-  se::DeviceMemoryBase rhs_bmm1_buffer =
-      buffer_allocations.GetDeviceAddress(rhs_bmm1_buffer_);
-  se::DeviceMemoryBase rhs_bmm2_buffer =
-      buffer_allocations.GetDeviceAddress(rhs_bmm2_buffer_);
-  se::DeviceMemoryBase output_buffer =
-      buffer_allocations.GetDeviceAddress(output_buffer_);
-  se::DeviceMemoryBase scratch_buffer =
-      buffer_allocations.GetDeviceAddress(scratch_buffer_);
-
-  std::optional<se::DeviceMemoryBase> bias_buffer =
-      AssignBufferIfNotNull(buffer_allocations, bias_buffer_);
-  std::optional<se::DeviceMemoryBase> activation_buffer =
-      AssignBufferIfNotNull(buffer_allocations, activation_buffer_);
-  std::optional<se::DeviceMemoryBase> seqlen_q_buffer =
-      AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_);
-  std::optional<se::DeviceMemoryBase> seqlen_k_buffer =
-      AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_);
-
-  ExecutionScopeId execution_scope_id = GetExecutionScope(record_params);
-  VLOG(5) << "FusedMHACmd with execution_scope_id: "
-          << execution_scope_id.value();
-  VLOG(5) << " lhs_bmm1_buffer: " << lhs_bmm1_buffer_.ToString();
-  VLOG(5) << " rhs_bmm1_buffer: " << rhs_bmm1_buffer_.ToString();
-  VLOG(5) << " rhs_bmm2_buffer: " << rhs_bmm2_buffer_.ToString();
-  VLOG(5) << " output_buffer: " << output_buffer_.ToString();
-  VLOG(5) << " scratch_buffer: " << scratch_buffer_.ToString();
-  VLOG(5) << " bias_buffer: " << bias_buffer_.ToString();
-  VLOG(5) << " activation_buffer: " << activation_buffer_.ToString();
-  VLOG(5) << " seqlen_q_buffer: " << seqlen_q_buffer_.ToString();
-  VLOG(5) << " seqlen_k_buffer: " << seqlen_k_buffer_.ToString();
-
-  RunFusedMHAOptions opts;
-  opts.runner_cache =
-      &GetOrCreateRunner(execute_params.command_buffer_trace_stream);
-  return AddTracedCommandBuffer(
-      execute_params, record_params, command_buffer, [&](se::Stream* stream) {
-        return RunGpuFMHA(config_, lhs_bmm1_buffer, rhs_bmm1_buffer,
-                          rhs_bmm2_buffer, output_buffer, scratch_buffer,
-                          bias_buffer, activation_buffer, seqlen_q_buffer,
-                          seqlen_k_buffer, stream, opts);
-      });
-}
-
-FusedMHACmd::BufferUsageVector FusedMHACmd::buffers() {
-  BufferUsageVector buffer_usage;
-  buffer_usage.reserve(9);
-  buffer_usage.push_back({lhs_bmm1_buffer_, MemoryAccess::kRead});
-  buffer_usage.push_back({rhs_bmm1_buffer_, MemoryAccess::kRead});
-  buffer_usage.push_back({rhs_bmm2_buffer_, MemoryAccess::kRead});
-  buffer_usage.push_back({output_buffer_, MemoryAccess::kWrite});
-  buffer_usage.push_back({scratch_buffer_, MemoryAccess::kWrite});
-  if (bias_buffer_.allocation() != nullptr) {
-    buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead});
-  }
-  if (activation_buffer_.allocation() != nullptr) {
-    buffer_usage.push_back({activation_buffer_, MemoryAccess::kRead});
-  }
-  if (seqlen_q_buffer_.allocation() != nullptr) {
-    buffer_usage.push_back({seqlen_q_buffer_, MemoryAccess::kRead});
-  }
-  if (seqlen_k_buffer_.allocation() != nullptr) {
-    buffer_usage.push_back({seqlen_k_buffer_, MemoryAccess::kRead});
-  }
-  return buffer_usage;
-}
-
-//===----------------------------------------------------------------------===//
-// FusedMHABackwardCmd
-//===----------------------------------------------------------------------===//
-
-FusedMHABackwardCmd::FusedMHABackwardCmd(
-    ExecutionStreamId execution_stream_id, GpufMHABackwardConfig config,
-    BufferAllocation::Slice bmm1_grad_gemm1_rhs,
-    BufferAllocation::Slice bmm1_grad_gemm2_rhs,
-    BufferAllocation::Slice bmm2_grad_gemm1_lhs,
-    BufferAllocation::Slice bmm2_grad_gemm2_rhs,
-    BufferAllocation::Slice d_output, BufferAllocation::Slice scratch,
-    BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs,
-    BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s,
-    BufferAllocation::Slice d_bias, BufferAllocation::Slice fwd_output,
-    BufferAllocation::Slice bias, BufferAllocation::Slice seqlen_q,
-    BufferAllocation::Slice seqlen_k)
-    : TracedCommandBufferCmd(CommandBufferCmdType::kFusedMHABackwardCmd,
-                             execution_stream_id),
-      config_(std::move(config)),
-      bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs),
-      bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs),
-      bmm2_grad_gemm1_lhs_buffer_(bmm2_grad_gemm1_lhs),
-      bmm2_grad_gemm2_rhs_buffer_(bmm2_grad_gemm2_rhs),
-      d_output_buffer_(d_output),
-      scratch_buffer_(scratch),
-      d_bmm1_lhs_buffer_(d_bmm1_lhs),
-      d_bmm1_rhs_buffer_(d_bmm1_rhs),
-      d_bmm2_rhs_buffer_(d_bmm2_rhs),
-      d_s_buffer_(d_s),
-      d_bias_buffer_(d_bias),
-      fwd_output_buffer_(fwd_output),
-      bias_buffer_(bias),
-      seqlen_q_buffer_(seqlen_q),
-      seqlen_k_buffer_(seqlen_k) {}
-
-FusedMultiHeadedAttentionBackwardRunner& FusedMHABackwardCmd::GetOrCreateRunner(
-    const stream_executor::Stream* stream) {
-  absl::MutexLock lock(&mutex_);
-  auto it = runner_cache_.find(stream);
-  if (it == runner_cache_.end()) {
-    it = runner_cache_
-             .insert({stream,
-                      std::make_unique<FusedMultiHeadedAttentionBackwardRunner>(
-                          config_)})
-             .first;
-  }
-  return *it->second;
-}
-
-absl::Status FusedMHABackwardCmd::Initialize(
-    const Thunk::InitializeParams& params, StateManager& state) {
-  se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>* lazy_runner =
-      GetOrCreateRunner(params.command_buffer_trace_stream)
-          .AsFusedMHABackwardRunner();
-  TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHABackwardOpConfig());
-  return lazy_runner
-      ->GetOrCreateRunner(config, params.command_buffer_trace_stream)
-      .status();
-}
-
-absl::Status FusedMHABackwardCmd::Record(
-    const Thunk::ExecuteParams& execute_params,
-    const RecordParams& record_params, se::CommandBuffer* command_buffer) {
-  se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>* lazy_runner =
-      GetOrCreateRunner(execute_params.command_buffer_trace_stream)
-          .AsFusedMHABackwardRunner();
-  CHECK(lazy_runner)
-      << "FusedMHABackward lazy runner cache should have been populated";
-
-  const auto& buffer_allocations = *execute_params.buffer_allocations;
-  se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer =
-      buffer_allocations.GetDeviceAddress(bmm1_grad_gemm1_rhs_buffer_);
-
-  se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer =
-      buffer_allocations.GetDeviceAddress(bmm1_grad_gemm2_rhs_buffer_);
-
-  se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer =
-      buffer_allocations.GetDeviceAddress(bmm2_grad_gemm1_lhs_buffer_);
-
-  se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer =
-      buffer_allocations.GetDeviceAddress(bmm2_grad_gemm2_rhs_buffer_);
-
-  se::DeviceMemoryBase d_output_buffer =
-      buffer_allocations.GetDeviceAddress(d_output_buffer_);
-
-  se::DeviceMemoryBase scratch_buffer =
-      buffer_allocations.GetDeviceAddress(scratch_buffer_);
-
-  se::DeviceMemoryBase d_bmm1_lhs_buffer =
-      buffer_allocations.GetDeviceAddress(d_bmm1_lhs_buffer_);
-
-  se::DeviceMemoryBase d_bmm1_rhs_buffer =
-      buffer_allocations.GetDeviceAddress(d_bmm1_rhs_buffer_);
-
-  se::DeviceMemoryBase d_bmm2_rhs_buffer =
-      buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_);
-
-  std::optional<se::DeviceMemoryBase> d_s_buffer =
-      AssignBufferIfNotNull(buffer_allocations, d_s_buffer_);
-  std::optional<se::DeviceMemoryBase> d_bias_buffer =
-      AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_);
-  std::optional<se::DeviceMemoryBase> fwd_output_buffer =
-      AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_);
-  std::optional<se::DeviceMemoryBase> bias_buffer =
-      AssignBufferIfNotNull(buffer_allocations, bias_buffer_);
-  std::optional<se::DeviceMemoryBase> seqlen_q_buffer =
-      AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_);
-  std::optional<se::DeviceMemoryBase> seqlen_k_buffer =
-      AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_);
-
-  ExecutionScopeId execution_scope_id = GetExecutionScope(record_params);
-  VLOG(5) << "FusedMHABackwardCmd with execution_scope_id: "
-          << execution_scope_id.value();
-  VLOG(5) << "bmm1_grad_gemm1_rhs_buffer"
-          << bmm1_grad_gemm1_rhs_buffer_.ToString();
-  VLOG(5) << "bmm1_grad_gemm2_rhs_buffer"
-          << bmm1_grad_gemm2_rhs_buffer_.ToString();
-  VLOG(5) << "bmm2_grad_gemm1_lhs_buffer"
-          << bmm2_grad_gemm1_lhs_buffer_.ToString();
-  VLOG(5) << "bmm2_grad_gemm2_rhs_buffer"
-          << bmm2_grad_gemm2_rhs_buffer_.ToString();
-  VLOG(5) << "d_output_buffer" << d_output_buffer_.ToString();
-  VLOG(5) << "scratch_buffer" << scratch_buffer_.ToString();
-  VLOG(5) << "d_bmm1_lhs_buffer" << d_bmm1_lhs_buffer_.ToString();
-  VLOG(5) << "d_bmm1_rhs_buffer" << d_bmm1_rhs_buffer_.ToString();
-  VLOG(5) << "d_bmm2_rhs_buffer" << d_bmm2_rhs_buffer_.ToString();
-  VLOG(5) << "d_s_buffer" << d_s_buffer_.ToString();
-  VLOG(5) << "d_bias_buffer" << d_bias_buffer_.ToString();
-  VLOG(5) << "fwd_output_buffer" << fwd_output_buffer_.ToString();
-  VLOG(5) << "bias_buffer" << bias_buffer_.ToString();
-  VLOG(5) << "seqlen_q_buffer" << seqlen_q_buffer_.ToString();
-  VLOG(5) << "seqlen_k_buffer" << seqlen_k_buffer_.ToString();
-
-  RunFusedMHABackwardOptions opts;
-  opts.runner_cache =
-      &GetOrCreateRunner(execute_params.command_buffer_trace_stream);
-  return AddTracedCommandBuffer(
-      execute_params, record_params, command_buffer, [&](se::Stream* stream) {
-        return RunGpuFMHABackward(
-            config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer,
-            bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer,
-            d_output_buffer, scratch_buffer, d_bmm1_lhs_buffer,
-            d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer,
-            fwd_output_buffer, bias_buffer, seqlen_q_buffer, seqlen_k_buffer,
-            stream, opts);
-      });
-}
-
-FusedMHABackwardCmd::BufferUsageVector FusedMHABackwardCmd::buffers() {
-  BufferUsageVector buffer_usage;
-  buffer_usage.reserve(15);
-
-  buffer_usage.push_back({bmm1_grad_gemm1_rhs_buffer_, MemoryAccess::kRead});
-  buffer_usage.push_back({bmm1_grad_gemm2_rhs_buffer_, MemoryAccess::kRead});
-  buffer_usage.push_back({bmm2_grad_gemm1_lhs_buffer_, MemoryAccess::kRead});
-  buffer_usage.push_back({bmm2_grad_gemm2_rhs_buffer_, MemoryAccess::kRead});
-  buffer_usage.push_back({d_output_buffer_, MemoryAccess::kWrite});
-  buffer_usage.push_back({scratch_buffer_, MemoryAccess::kWrite});
-  buffer_usage.push_back({d_bmm1_lhs_buffer_, MemoryAccess::kRead});
-  buffer_usage.push_back({d_bmm1_rhs_buffer_, MemoryAccess::kRead});
-  buffer_usage.push_back({d_bmm2_rhs_buffer_, MemoryAccess::kRead});
-
-  if (d_s_buffer_.allocation() != nullptr) {
-    buffer_usage.push_back({d_s_buffer_, MemoryAccess::kRead});
-  }
-  if (d_bias_buffer_.allocation() != nullptr) {
-    buffer_usage.push_back({d_bias_buffer_, MemoryAccess::kRead});
-  }
-  if (fwd_output_buffer_.allocation() != nullptr) {
-    buffer_usage.push_back({fwd_output_buffer_, MemoryAccess::kRead});
-  }
-  if (bias_buffer_.allocation() != nullptr) {
-    buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead});
-  }
-  if (seqlen_q_buffer_.allocation() != nullptr) {
-    buffer_usage.push_back({seqlen_q_buffer_, MemoryAccess::kRead});
-  }
-  if (seqlen_k_buffer_.allocation() != nullptr) {
-    buffer_usage.push_back({seqlen_k_buffer_, MemoryAccess::kRead});
-  }
-
-  return buffer_usage;
-}
-
-//===----------------------------------------------------------------------===//
 // CublasLtCmd
 //===----------------------------------------------------------------------===//
 
@@ -1835,8 +1527,9 @@
           execute_params.command_buffer_trace_stream, [&](se::Stream* stream) {
             ffi::CallOptions options = {
                 execute_params.buffer_allocations->device_ordinal(),
-                execute_params.stream,
-                execute_params.buffer_allocations->memory_allocator(),
+                ffi::CallOptions::GpuOptions{
+                    execute_params.stream,
+                    execute_params.buffer_allocations->memory_allocator()},
                 /*called_computation=*/nullptr,  // TODO(b/342285364)
                 execute_params.ffi_execution_context};
             return ffi::Call(handler_, call_frame, options);
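
The hunk above (and the matching change in custom_call_thunk.cc further down) switches the XLA FFI call options from passing the stream and memory allocator as two separate CallOptions members to grouping them in a nested GpuOptions aggregate. A minimal before/after sketch of that construction; the unqualified names (device_ordinal, stream, allocator, called_computation, execution_context) are placeholders for the values shown in the diff, not new APIs:

    // Before: GPU-specific values were flat members of ffi::CallOptions.
    //   ffi::CallOptions options = {device_ordinal, stream, allocator,
    //                               called_computation, execution_context};
    //
    // After: the stream and allocator are wrapped in CallOptions::GpuOptions,
    // presumably so that backend-specific state stays grouped together.
    ffi::CallOptions options = {
        device_ordinal, ffi::CallOptions::GpuOptions{stream, allocator},
        called_computation, execution_context};
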
diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h
index b7a077e..27e8fea 100644
--- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h
+++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h
@@ -40,7 +40,6 @@
 #include "xla/service/buffer_assignment.h"
 #include "xla/service/collective_ops_utils.h"
 #include "xla/service/gpu/buffer_allocations.h"
-#include "xla/service/gpu/gpu_fused_mha_runner.h"
 #include "xla/service/gpu/kernels/custom_kernel.h"
 #include "xla/service/gpu/launch_dimensions.h"
 #include "xla/service/gpu/matmul_utils.h"
@@ -81,8 +80,6 @@
   V(kReduceScatter, "ReduceScatterCmd")                  \
   V(kAllGatherCmd, "AllGatherCmd")                       \
   V(kCollectiveBroadcastCmd, "CollectiveBroadcastCmd")   \
-  V(kFusedMHACmd, "FusedMHACmd")                         \
-  V(kFusedMHABackwardCmd, "FusedMHABackwardCmd")         \
   V(kUnknownCmd, "UnknownCmd") \
   // clang-format on
 
@@ -783,112 +780,6 @@
 };
 
 //===----------------------------------------------------------------------===//
-// FusedMHACmd
-//===----------------------------------------------------------------------===//
-
-class FusedMHACmd : public TracedCommandBufferCmd {
- public:
-  FusedMHACmd(ExecutionStreamId execution_stream_id, GpufMHAConfig config,
-              BufferAllocation::Slice lhs_bmm1,
-              BufferAllocation::Slice rhs_bmm1,
-              BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output,
-              BufferAllocation::Slice scratch, BufferAllocation::Slice mask,
-              BufferAllocation::Slice bias, BufferAllocation::Slice activation,
-              BufferAllocation::Slice seqlen_q,
-              BufferAllocation::Slice seqlen_k);
-
-  absl::Status Initialize(const Thunk::InitializeParams& params,
-                          StateManager& state) override;
-
-  absl::Status Record(const Thunk::ExecuteParams& execute_params,
-                      const RecordParams& record_params,
-                      se::CommandBuffer* command_buffer) override;
-
-  BufferUsageVector buffers() override;
-
-  bool IsNestedCommandBuffer() const final { return true; }
-
- private:
-  FusedMultiHeadedAttentionRunner& GetOrCreateRunner(
-      const stream_executor::Stream* stream);
-
-  const GpufMHAConfig config_;
-  BufferAllocation::Slice lhs_bmm1_buffer_;
-  BufferAllocation::Slice rhs_bmm1_buffer_;
-  BufferAllocation::Slice rhs_bmm2_buffer_;
-  BufferAllocation::Slice output_buffer_;
-  BufferAllocation::Slice scratch_buffer_;
-  BufferAllocation::Slice bias_buffer_;
-  BufferAllocation::Slice activation_buffer_;
-  BufferAllocation::Slice seqlen_q_buffer_;
-  BufferAllocation::Slice seqlen_k_buffer_;
-
-  // FusedMHA config
-  absl::Mutex mutex_;
-  absl::flat_hash_map<const stream_executor::Stream*,
-                      std::unique_ptr<FusedMultiHeadedAttentionRunner>>
-      runner_cache_ ABSL_GUARDED_BY(mutex_);
-};
-
-//===----------------------------------------------------------------------===//
-// FusedMHABackwardCmd
-//===----------------------------------------------------------------------===//
-
-class FusedMHABackwardCmd : public TracedCommandBufferCmd {
- public:
-  FusedMHABackwardCmd(
-      ExecutionStreamId execution_stream_id, GpufMHABackwardConfig config,
-      BufferAllocation::Slice bmm1_grad_gemm1_rhs,
-      BufferAllocation::Slice bmm1_grad_gemm2_rhs,
-      BufferAllocation::Slice bmm2_grad_gemm1_lhs,
-      BufferAllocation::Slice bmm2_grad_gemm2_rhs,
-      BufferAllocation::Slice d_output, BufferAllocation::Slice scratch,
-      BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs,
-      BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s,
-      BufferAllocation::Slice d_bias, BufferAllocation::Slice fwd_output,
-      BufferAllocation::Slice bias, BufferAllocation::Slice seqlen_q,
-      BufferAllocation::Slice seqlen_k);
-
-  absl::Status Initialize(const Thunk::InitializeParams& params,
-                          StateManager& state) override;
-
-  absl::Status Record(const Thunk::ExecuteParams& execute_params,
-                      const RecordParams& record_params,
-                      se::CommandBuffer* command_buffer) override;
-
-  BufferUsageVector buffers() override;
-
-  bool IsNestedCommandBuffer() const final { return true; }
-
- private:
-  FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner(
-      const stream_executor::Stream* stream);
-
-  const GpufMHABackwardConfig config_;
-  BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer_;
-  BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer_;
-  BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer_;
-  BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer_;
-  BufferAllocation::Slice d_output_buffer_;
-  BufferAllocation::Slice scratch_buffer_;
-  BufferAllocation::Slice d_bmm1_lhs_buffer_;
-  BufferAllocation::Slice d_bmm1_rhs_buffer_;
-  BufferAllocation::Slice d_bmm2_rhs_buffer_;
-  BufferAllocation::Slice d_s_buffer_;
-  BufferAllocation::Slice d_bias_buffer_;
-  BufferAllocation::Slice fwd_output_buffer_;
-  BufferAllocation::Slice bias_buffer_;
-  BufferAllocation::Slice seqlen_q_buffer_;
-  BufferAllocation::Slice seqlen_k_buffer_;
-
-  // FusedMHA config
-  absl::Mutex mutex_;
-  absl::flat_hash_map<const stream_executor::Stream*,
-                      std::unique_ptr<FusedMultiHeadedAttentionBackwardRunner>>
-      runner_cache_ ABSL_GUARDED_BY(mutex_);
-};
-
-//===----------------------------------------------------------------------===//
 // CublasLtCmd
 //===----------------------------------------------------------------------===//
 
diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc
index 54e01fa..230d050 100644
--- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc
+++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc
@@ -29,7 +29,6 @@
 #include "xla/service/gpu/runtime/copy_thunk.h"
 #include "xla/service/gpu/runtime/cudnn_thunk.h"
 #include "xla/service/gpu/runtime/custom_call_thunk.h"
-#include "xla/service/gpu/runtime/fused_mha_thunk.h"
 #include "xla/service/gpu/runtime/gemm_thunk.h"
 #include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h"
 #include "xla/service/gpu/runtime/kernel_thunk.h"
@@ -143,27 +142,6 @@
       thunk.workspace().value());
 }
 
-static absl::StatusOr<Command> Convert(const FusedMHAThunk& thunk) {
-  return std::make_unique<FusedMHACmd>(
-      thunk.execution_stream_id(), thunk.config(), thunk.lhs_bmm1_buffer(),
-      thunk.rhs_bmm1_buffer(), thunk.rhs_bmm2_buffer(), thunk.output_buffer(),
-      thunk.scratch_buffer(), BufferAllocation::Slice(), thunk.bias_buffer(),
-      thunk.activation_buffer(), thunk.seqlen_q_buffer(),
-      thunk.seqlen_k_buffer());
-}
-
-static absl::StatusOr<Command> Convert(const FusedMHABackwardThunk& thunk) {
-  return std::make_unique<FusedMHABackwardCmd>(
-      thunk.execution_stream_id(), thunk.config(),
-      thunk.bmm1_grad_gemm1_rhs_buffer(), thunk.bmm1_grad_gemm2_rhs_buffer(),
-      thunk.bmm2_grad_gemm1_lhs_buffer(), thunk.bmm2_grad_gemm2_rhs_buffer(),
-      thunk.d_output_buffer(), thunk.scratch_buffer(),
-      thunk.d_bmm1_lhs_buffer(), thunk.d_bmm1_rhs_buffer(),
-      thunk.d_bmm2_rhs_buffer(), thunk.d_s_buffer(), thunk.d_bias_buffer(),
-      thunk.fwd_output_buffer(), thunk.bias_buffer(), thunk.seqlen_q_buffer(),
-      thunk.seqlen_k_buffer());
-}
-
 static absl::StatusOr<Command> Convert(
     const ConditionalThunk& thunk,
     CommandBufferCmdSequence::SynchronizationMode synchronization_mode) {
@@ -276,10 +254,6 @@
       return append(Convert<CustomCallThunk>(thunk));
     case Thunk::Kind::kCustomKernel:
       return append(Convert<CustomKernelThunk>(thunk));
-    case Thunk::Kind::kFusedMHA:
-      return append(Convert<FusedMHAThunk>(thunk));
-    case Thunk::Kind::kFusedMHABackward:
-      return append(Convert<FusedMHABackwardThunk>(thunk));
     case Thunk::Kind::kKernel:
       return append(Convert<KernelThunk>(thunk));
     case Thunk::Kind::kGemm:
diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc
index e994fac..f77653e 100644
--- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc
+++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc
@@ -198,9 +198,9 @@
   builder.AddAttributes(attrs.Build());
   CallFrame call_frame = builder.Build();
 
-  CallOptions options = {device_ordinal,    stream,
-                         allocator,         called_computation_,
-                         execution_context, execution_state_.get()};
+  CallOptions options = {
+      device_ordinal, CallOptions::GpuOptions{stream, allocator},
+      called_computation_, execution_context, execution_state_.get()};
   return Call(handler, call_frame, options, stage);
 }
 
diff --git a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc b/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc
index 604bce1..eda2bc6 100644
--- a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc
+++ b/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc
@@ -74,8 +74,6 @@
     case Thunk::kCustomKernel:
     case Thunk::kCuDnn:
     case Thunk::kFft:
-    case Thunk::kFusedMHA:
-    case Thunk::kFusedMHABackward:
     case Thunk::kGemm:
     case Thunk::kInfeed:
     case Thunk::kKernel:
diff --git a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc b/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc
deleted file mode 100644
index ee13689..0000000
--- a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc
+++ /dev/null
@@ -1,230 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/runtime/fused_mha_thunk.h"
-
-#include <memory>
-#include <optional>
-#include <utility>
-
-#include "absl/status/status.h"
-#include "absl/synchronization/mutex.h"
-#include "xla/service/buffer_assignment.h"
-#include "xla/service/gpu/buffer_allocations.h"
-#include "xla/service/gpu/gpu_fused_mha_runner.h"
-#include "xla/service/gpu/runtime/thunk.h"
-#include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/lazy_op_runner.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-FusedMHAThunk::FusedMHAThunk(
-    ThunkInfo thunk_info, GpufMHAConfig config,
-    BufferAllocation::Slice lhs_bmm1, BufferAllocation::Slice rhs_bmm1,
-    BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output,
-    BufferAllocation::Slice scratch, BufferAllocation::Slice mask,
-    BufferAllocation::Slice bias, BufferAllocation::Slice activation,
-    BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k)
-    : Thunk(Kind::kFusedMHA, thunk_info),
-      lhs_bmm1_buffer_(lhs_bmm1),
-      rhs_bmm1_buffer_(rhs_bmm1),
-      rhs_bmm2_buffer_(rhs_bmm2),
-      output_buffer_(output),
-      scratch_buffer_(scratch),
-      bias_buffer_(bias),
-      activation_buffer_(activation),
-      seqlen_q_buffer_(seqlen_q),
-      seqlen_k_buffer_(seqlen_k),
-      config_(std::move(config)) {}
-
-FusedMultiHeadedAttentionRunner& FusedMHAThunk::GetOrCreateRunner(
-    const stream_executor::Stream* stream) {
-  absl::MutexLock lock(&mu_);
-  auto it = runner_cache_.find(stream);
-  if (it == runner_cache_.end()) {
-    it = runner_cache_
-             .insert({stream, std::make_unique<FusedMultiHeadedAttentionRunner>(
-                                  config_)})
-             .first;
-  }
-  return *it->second;
-}
-
-std::optional<se::DeviceMemoryBase> AssignBufferIfNotNull(
-    const BufferAllocations& buffer_allocations,
-    BufferAllocation::Slice& slice) {
-  return slice.allocation() != nullptr
-             ? std::optional<se::DeviceMemoryBase>{buffer_allocations
-                                                       .GetDeviceAddress(slice)}
-             : std::nullopt;
-}
-
-absl::Status FusedMHAThunk::Initialize(const InitializeParams& params) {
-  se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>* lazy_runner =
-      GetOrCreateRunner(params.stream).AsFusedMHARunner();
-  TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHAOpConfig());
-  return lazy_runner->GetOrCreateRunner(config, params.stream).status();
-}
-
-absl::Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) {
-  const auto& buffer_allocations = *params.buffer_allocations;
-  se::DeviceMemoryBase lhs_bmm1_buffer =
-      buffer_allocations.GetDeviceAddress(lhs_bmm1_buffer_);
-  se::DeviceMemoryBase rhs_bmm1_buffer =
-      buffer_allocations.GetDeviceAddress(rhs_bmm1_buffer_);
-  se::DeviceMemoryBase rhs_bmm2_buffer =
-      buffer_allocations.GetDeviceAddress(rhs_bmm2_buffer_);
-  se::DeviceMemoryBase output_buffer =
-      buffer_allocations.GetDeviceAddress(output_buffer_);
-  se::DeviceMemoryBase scratch_buffer =
-      buffer_allocations.GetDeviceAddress(scratch_buffer_);
-
-  std::optional<se::DeviceMemoryBase> bias_buffer =
-      AssignBufferIfNotNull(buffer_allocations, bias_buffer_);
-  std::optional<se::DeviceMemoryBase> activation_buffer =
-      AssignBufferIfNotNull(buffer_allocations, activation_buffer_);
-  std::optional<se::DeviceMemoryBase> seqlen_q_buffer =
-      AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_);
-  std::optional<se::DeviceMemoryBase> seqlen_k_buffer =
-      AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_);
-  RunFusedMHAOptions opts;
-  opts.runner_cache = &GetOrCreateRunner(params.stream);
-  TF_RETURN_IF_ERROR(RunGpuFMHA(config_, lhs_bmm1_buffer, rhs_bmm1_buffer,
-                                rhs_bmm2_buffer, output_buffer, scratch_buffer,
-                                bias_buffer, activation_buffer, seqlen_q_buffer,
-                                seqlen_k_buffer, params.stream, opts));
-
-  if (!params.stream->ok()) {
-    return Internal("FusedMHAThunk::ExecuteOnStream failed.");
-  }
-  return absl::OkStatus();
-}
-FusedMHABackwardThunk::FusedMHABackwardThunk(
-    ThunkInfo thunk_info, GpufMHABackwardConfig config,
-    BufferAllocation::Slice bmm1_grad_gemm1_rhs,
-    BufferAllocation::Slice bmm1_grad_gemm2_rhs,
-    BufferAllocation::Slice bmm2_grad_gemm1_lhs,
-    BufferAllocation::Slice bmm2_grad_gemm2_rhs,
-    BufferAllocation::Slice d_output, BufferAllocation::Slice scratch,
-    BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs,
-    BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s,
-    BufferAllocation::Slice mask, BufferAllocation::Slice d_bias,
-    BufferAllocation::Slice fwd_output, BufferAllocation::Slice bias,
-    BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k)
-    : Thunk(Kind::kFusedMHABackward, thunk_info),
-      bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs),
-      bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs),
-      bmm2_grad_gemm1_lhs_buffer_(bmm2_grad_gemm1_lhs),
-      bmm2_grad_gemm2_rhs_buffer_(bmm2_grad_gemm2_rhs),
-      d_output_buffer_(d_output),
-      scratch_buffer_(scratch),
-      d_bmm1_lhs_buffer_(d_bmm1_lhs),
-      d_bmm1_rhs_buffer_(d_bmm1_rhs),
-      d_bmm2_rhs_buffer_(d_bmm2_rhs),
-      d_s_buffer_(d_s),
-      d_bias_buffer_(d_bias),
-      fwd_output_buffer_(fwd_output),
-      bias_buffer_(bias),
-      seqlen_q_buffer_(seqlen_q),
-      seqlen_k_buffer_(seqlen_k),
-      config_(std::move(config)) {}
-
-FusedMultiHeadedAttentionBackwardRunner&
-FusedMHABackwardThunk::GetOrCreateRunner(
-    const stream_executor::Stream* stream) {
-  absl::MutexLock lock(&mu_);
-  auto it = runner_cache_.find(stream);
-  if (it == runner_cache_.end()) {
-    it = runner_cache_
-             .insert({stream,
-                      std::make_unique<FusedMultiHeadedAttentionBackwardRunner>(
-                          config_)})
-             .first;
-  }
-  return *it->second;
-}
-
-absl::Status FusedMHABackwardThunk::Initialize(const InitializeParams& params) {
-  se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>* lazy_runner =
-      GetOrCreateRunner(params.stream).AsFusedMHABackwardRunner();
-  TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHABackwardOpConfig());
-  return lazy_runner->GetOrCreateRunner(config, params.stream).status();
-}
-
-absl::Status FusedMHABackwardThunk::ExecuteOnStream(
-    const ExecuteParams& params) {
-  const auto& buffer_allocations = *params.buffer_allocations;
-  se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer =
-      buffer_allocations.GetDeviceAddress(bmm1_grad_gemm1_rhs_buffer_);
-
-  se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer =
-      buffer_allocations.GetDeviceAddress(bmm1_grad_gemm2_rhs_buffer_);
-
-  se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer =
-      buffer_allocations.GetDeviceAddress(bmm2_grad_gemm1_lhs_buffer_);
-
-  se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer =
-      buffer_allocations.GetDeviceAddress(bmm2_grad_gemm2_rhs_buffer_);
-
-  se::DeviceMemoryBase d_output_buffer =
-      buffer_allocations.GetDeviceAddress(d_output_buffer_);
-
-  se::DeviceMemoryBase scratch_buffer =
-      buffer_allocations.GetDeviceAddress(scratch_buffer_);
-
-  se::DeviceMemoryBase d_bmm1_lhs_buffer =
-      buffer_allocations.GetDeviceAddress(d_bmm1_lhs_buffer_);
-
-  se::DeviceMemoryBase d_bmm1_rhs_buffer =
-      buffer_allocations.GetDeviceAddress(d_bmm1_rhs_buffer_);
-
-  se::DeviceMemoryBase d_bmm2_rhs_buffer =
-      buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_);
-
-  std::optional<se::DeviceMemoryBase> d_s_buffer =
-      AssignBufferIfNotNull(buffer_allocations, d_s_buffer_);
-  std::optional<se::DeviceMemoryBase> d_bias_buffer =
-      AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_);
-  std::optional<se::DeviceMemoryBase> fwd_output_buffer =
-      AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_);
-  std::optional<se::DeviceMemoryBase> bias_buffer =
-      AssignBufferIfNotNull(buffer_allocations, bias_buffer_);
-  std::optional<se::DeviceMemoryBase> seqlen_q_buffer =
-      AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_);
-  std::optional<se::DeviceMemoryBase> seqlen_k_buffer =
-      AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_);
-  RunFusedMHABackwardOptions opts;
-
-  opts.runner_cache = &GetOrCreateRunner(params.stream);
-
-  TF_RETURN_IF_ERROR(RunGpuFMHABackward(
-      config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer,
-      bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer,
-      scratch_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer,
-      d_s_buffer, d_bias_buffer, fwd_output_buffer, bias_buffer,
-      seqlen_q_buffer, seqlen_k_buffer, params.stream, opts));
-  if (!params.stream->ok()) {
-    return Internal("FusedMHABackwardThunk::ExecuteOnStream failed.");
-  }
-  return absl::OkStatus();
-}
-
-}  // namespace gpu
-}  // namespace xla
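
The file deleted above lazily created one FusedMHA runner per stream and cached it behind a mutex. The sketch below restates that caching idiom in isolation; Stream, Runner, and Config are generic placeholders rather than the real XLA/StreamExecutor types, so treat this as an illustration of the pattern, not the removed implementation:

    #include <memory>
    #include <utility>

    #include "absl/base/thread_annotations.h"
    #include "absl/container/flat_hash_map.h"
    #include "absl/synchronization/mutex.h"

    // Per-stream lazy cache: the first lookup for a stream constructs a runner
    // from the shared config; later lookups return the cached instance.
    template <typename Stream, typename Runner, typename Config>
    class PerStreamRunnerCache {
     public:
      explicit PerStreamRunnerCache(Config config)
          : config_(std::move(config)) {}

      Runner& GetOrCreate(const Stream* stream) {
        absl::MutexLock lock(&mu_);  // serializes concurrent first lookups
        auto it = cache_.find(stream);
        if (it == cache_.end()) {
          it = cache_.insert({stream, std::make_unique<Runner>(config_)}).first;
        }
        return *it->second;
      }

     private:
      const Config config_;
      absl::Mutex mu_;
      absl::flat_hash_map<const Stream*, std::unique_ptr<Runner>> cache_
          ABSL_GUARDED_BY(mu_);
    };
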
diff --git a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h b/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h
deleted file mode 100644
index 99a8327..0000000
--- a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h
+++ /dev/null
@@ -1,184 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_
-#define XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_
-
-#include <memory>
-
-#include "absl/base/thread_annotations.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/status/status.h"
-#include "absl/synchronization/mutex.h"
-#include "xla/service/buffer_assignment.h"
-#include "xla/service/gpu/gpu_fused_mha_runner.h"
-#include "xla/service/gpu/runtime/thunk.h"
-#include "xla/stream_executor/stream_executor.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-
-// This class stores everything that StreamExecutor needs to launch a DNN
-// fMHA. It is generated by IrEmitter.
-//
-// This is thread-compatible.
-class FusedMHAThunk : public Thunk {
- public:
-  // Constructs a thunk for launching a DNN FMHA.
-  FusedMHAThunk(ThunkInfo thunk_info, GpufMHAConfig config,
-                BufferAllocation::Slice lhs_bmm1_slice,
-                BufferAllocation::Slice rhs_bmm1_slice,
-                BufferAllocation::Slice rhs_bmm2_slice,
-                BufferAllocation::Slice output_slice,
-                BufferAllocation::Slice scratch_slice,
-                BufferAllocation::Slice mask_slice, /* may be null */
-                BufferAllocation::Slice bias_slice /* may be null */,
-                BufferAllocation::Slice activation_slice /* may be null */,
-                BufferAllocation::Slice seqlen_q_slice /* may be null */,
-                BufferAllocation::Slice seqlen_k_slice /* may be null */);
-
-  FusedMHAThunk(const FusedMHAThunk&) = delete;
-  FusedMHAThunk& operator=(const FusedMHAThunk&) = delete;
-
-  BufferAllocation::Slice lhs_bmm1_buffer() const { return lhs_bmm1_buffer_; }
-  BufferAllocation::Slice rhs_bmm1_buffer() const { return rhs_bmm1_buffer_; }
-  BufferAllocation::Slice rhs_bmm2_buffer() const { return rhs_bmm2_buffer_; }
-  BufferAllocation::Slice output_buffer() const { return output_buffer_; }
-  BufferAllocation::Slice scratch_buffer() const { return scratch_buffer_; }
-  BufferAllocation::Slice bias_buffer() const { return bias_buffer_; }
-  BufferAllocation::Slice activation_buffer() const {
-    return activation_buffer_;
-  }
-  BufferAllocation::Slice seqlen_q_buffer() const { return seqlen_q_buffer_; }
-  BufferAllocation::Slice seqlen_k_buffer() const { return seqlen_k_buffer_; }
-
-  GpufMHAConfig config() const { return config_; }
-  absl::Status Initialize(const InitializeParams& params) override;
-  absl::Status ExecuteOnStream(const ExecuteParams& params) override;
-
- private:
-  BufferAllocation::Slice lhs_bmm1_buffer_;
-  BufferAllocation::Slice rhs_bmm1_buffer_;
-  BufferAllocation::Slice rhs_bmm2_buffer_;
-  BufferAllocation::Slice output_buffer_;
-  BufferAllocation::Slice scratch_buffer_;
-  BufferAllocation::Slice bias_buffer_;
-  BufferAllocation::Slice activation_buffer_;
-  BufferAllocation::Slice seqlen_q_buffer_;
-  BufferAllocation::Slice seqlen_k_buffer_;
-
-  FusedMultiHeadedAttentionRunner& GetOrCreateRunner(
-      const stream_executor::Stream* stream);
-
-  // FusedMHA config
-  const GpufMHAConfig config_;
-  absl::Mutex mu_;
-  absl::flat_hash_map<const stream_executor::Stream*,
-                      std::unique_ptr<FusedMultiHeadedAttentionRunner>>
-      runner_cache_ ABSL_GUARDED_BY(mu_);
-};
-
-class FusedMHABackwardThunk : public Thunk {
- public:
-  // Constructs a thunk for launching a DNN FMHA backward.
-  FusedMHABackwardThunk(ThunkInfo thunk_info, GpufMHABackwardConfig config,
-                        BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice,
-                        BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice,
-                        BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice,
-                        BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice,
-                        BufferAllocation::Slice d_output_slice,
-                        BufferAllocation::Slice scratch_slice,
-                        BufferAllocation::Slice d_bmm1_lhs_slice,
-                        BufferAllocation::Slice d_bmm1_rhs_slice,
-                        BufferAllocation::Slice d_bmm2_rhs_slice,
-                        BufferAllocation::Slice d_s_slice,
-                        BufferAllocation::Slice mask_slice,
-                        BufferAllocation::Slice d_bias_slice,
-                        BufferAllocation::Slice fwd_output_slice,
-                        BufferAllocation::Slice bias_slice,
-                        BufferAllocation::Slice seqlen_q_slice,
-                        BufferAllocation::Slice seqlen_k_slice);
-
-  FusedMHABackwardThunk(const FusedMHABackwardThunk&) = delete;
-  FusedMHABackwardThunk& operator=(const FusedMHABackwardThunk&) = delete;
-
-  BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer() const {
-    return bmm1_grad_gemm1_rhs_buffer_;
-  }
-  BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer() const {
-    return bmm1_grad_gemm2_rhs_buffer_;
-  }
-  BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer() const {
-    return bmm2_grad_gemm1_lhs_buffer_;
-  }
-  BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer() const {
-    return bmm2_grad_gemm2_rhs_buffer_;
-  }
-  BufferAllocation::Slice d_output_buffer() const { return d_output_buffer_; }
-  BufferAllocation::Slice scratch_buffer() const { return scratch_buffer_; }
-  BufferAllocation::Slice d_bmm1_lhs_buffer() const {
-    return d_bmm1_lhs_buffer_;
-  }
-  BufferAllocation::Slice d_bmm1_rhs_buffer() const {
-    return d_bmm1_rhs_buffer_;
-  }
-  BufferAllocation::Slice d_bmm2_rhs_buffer() const {
-    return d_bmm2_rhs_buffer_;
-  }
-  BufferAllocation::Slice d_s_buffer() const { return d_s_buffer_; }
-  BufferAllocation::Slice d_bias_buffer() const { return d_bias_buffer_; }
-  BufferAllocation::Slice fwd_output_buffer() const {
-    return fwd_output_buffer_;
-  }
-  BufferAllocation::Slice bias_buffer() const { return bias_buffer_; }
-  BufferAllocation::Slice seqlen_q_buffer() const { return seqlen_q_buffer_; }
-  BufferAllocation::Slice seqlen_k_buffer() const { return seqlen_k_buffer_; }
-
-  GpufMHABackwardConfig config() const { return config_; }
-
-  absl::Status Initialize(const InitializeParams& params) override;
-  absl::Status ExecuteOnStream(const ExecuteParams& params) override;
-
- private:
-  BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer_;
-  BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer_;
-  BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer_;
-  BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer_;
-  BufferAllocation::Slice d_output_buffer_;
-  BufferAllocation::Slice scratch_buffer_;
-  BufferAllocation::Slice d_bmm1_lhs_buffer_;
-  BufferAllocation::Slice d_bmm1_rhs_buffer_;
-  BufferAllocation::Slice d_bmm2_rhs_buffer_;
-  BufferAllocation::Slice d_s_buffer_;
-  BufferAllocation::Slice d_bias_buffer_;
-  BufferAllocation::Slice fwd_output_buffer_;
-  BufferAllocation::Slice bias_buffer_;
-  BufferAllocation::Slice seqlen_q_buffer_;
-  BufferAllocation::Slice seqlen_k_buffer_;
-
-  FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner(
-      const stream_executor::Stream* stream);
-
-  // FusedMHA backward config
-  const GpufMHABackwardConfig config_;
-  absl::Mutex mu_;
-  absl::flat_hash_map<const stream_executor::Stream*,
-                      std::unique_ptr<FusedMultiHeadedAttentionBackwardRunner>>
-      runner_cache_ ABSL_GUARDED_BY(mu_);
-};
-}  // namespace gpu
-}  // namespace xla
-#endif  // XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_
diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc
index fc3c0cf..6f3081a 100644
--- a/third_party/xla/xla/service/gpu/runtime/thunk.cc
+++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc
@@ -286,8 +286,6 @@
     CASE(kSequential);
     CASE(kTriangularSolve);
     CASE(kWhile);
-    CASE(kFusedMHA);
-    CASE(kFusedMHABackward);
     CASE(kWaitForStreams);
     CASE(kCuDnn);
   }
diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h
index 3466649..cd26323 100644
--- a/third_party/xla/xla/service/gpu/runtime/thunk.h
+++ b/third_party/xla/xla/service/gpu/runtime/thunk.h
@@ -165,8 +165,6 @@
     kSendDone,
     kTriangularSolve,
     kWhile,
-    kFusedMHA,
-    kFusedMHABackward,
     kWaitForStreams,
     kCuDnn
   };
diff --git a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.cc b/third_party/xla/xla/service/gpu/scatter_slice_simplifier.cc
deleted file mode 100644
index 9672bf2..0000000
--- a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.cc
+++ /dev/null
@@ -1,264 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/scatter_slice_simplifier.h"
-
-#include <cstdint>
-#include <iterator>
-#include <optional>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-// Returns whether the instruction could be an operand for a slice instruction.
-bool IsValidIntermediaryUser(const HloInstruction* instruction) {
-  // Allow elementwise instructions, as they don't depend on the truncated
-  // elements. In case of multi-output scatters, the resulting shape is a tuple.
-  return instruction->IsElementwise() ||
-         instruction->opcode() == HloOpcode::kGetTupleElement;
-}
-
-// Matches the "Scatter -> Elementwise (zero or more) -> Slice" pattern.
-// Calculates the resulting scatter dimensions from the slice users.
-class ScatterSliceMatcher {
- public:
-  explicit ScatterSliceMatcher(const HloScatterInstruction* scatter)
-      : scatter_(scatter),
-        operand_dimensions_(
-            scatter->scatter_operands()[0]->shape().dimensions()),
-        result_dimensions_(operand_dimensions_.begin(),
-                           operand_dimensions_.end()) {}
-
-  // Determine the scatter shape from the user slice instructions.
-  // If any of the users are not truncation slices, return `nullopt`.
-  std::optional<Shape> InferShape() {
-    VLOG(10) << "Evaluating scatter " << scatter_->name();
-    if (!AreAllUsersValid(scatter_)) {
-      return std::nullopt;
-    }
-    std::vector<Shape> result_shapes;
-    absl::c_transform(scatter_->scatter_operands(),
-                      std::back_inserter(result_shapes),
-                      [&](const HloInstruction* op) {
-                        return ShapeUtil::MakeShape(op->shape().element_type(),
-                                                    result_dimensions_);
-                      });
-    return ShapeUtil::MakeMaybeTupleShape(result_shapes);
-  }
-
- private:
-  // Update the resulting scatter dimensions from the slice configuration and
-  // the original scatter dimensions. Return `false` if the update is not
-  // possible.
-  bool UpdateDimensions(const HloSliceInstruction* slice) {
-    int64_t rank = slice->shape().rank();
-    for (int64_t i = 0; i < rank; ++i) {
-      if (slice->slice_starts(i) != 0 || slice->slice_strides(i) != 1) {
-        return false;  // The slice is not a truncation.
-      }
-      if (slice->slice_limits(i) != result_dimensions_[i]) {
-        if (result_dimensions_[i] != operand_dimensions_[i]) {
-          return false;  // Another slice has incompatible dimensions.
-        }
-        auto& update_window_dims =
-            scatter_->scatter_dimension_numbers().update_window_dims();
-        if (absl::c_binary_search(update_window_dims, i)) {
-          return false;  // Update dimensions cannot be truncated.
-        }
-        result_dimensions_[i] = slice->slice_limits(i);
-        VLOG(10) << "Dimension " << i << " truncated to size "
-                 << result_dimensions_[i];
-      }
-    }
-    return true;
-  }
-
-  // Verify that the instruction is a valid scatter user, i.e., it is either a
-  // slice operation or an elementwise operation whose users are all valid
-  // (checked recursively).
-  bool IsUserValid(const HloInstruction* op) {
-    VLOG(10) << "Visiting user " << op->name();
-
-    // If the user is a slice operation, verify the configuration and update
-    // the resulting dimensions.
-    if (auto* slice = DynCast<HloSliceInstruction>(op)) {
-      return UpdateDimensions(slice);
-    }
-    // If the user is an elementwise operation, verify the users recursively
-    // (unless already visited).
-    bool is_valid = visited_set_.contains(op) ||
-                    (IsValidIntermediaryUser(op) && AreAllUsersValid(op));
-    if (is_valid) {
-      visited_set_.emplace(op);
-    }
-    return is_valid;
-  }
-
-  // Verify that all users are valid (see the definition of IsUserValid).
-  // If we reach the root instruction, fail the matching (slice is not found).
-  bool AreAllUsersValid(const HloInstruction* op) {
-    if (op->user_count() == 0) {
-      return !op->IsRoot();
-    }
-    return absl::c_all_of(op->users(), [this](const HloInstruction* user) {
-      return IsUserValid(user);
-    });
-  }
-
-  const HloScatterInstruction* scatter_;
-  absl::flat_hash_set<const HloInstruction*> visited_set_;
-  absl::Span<const int64_t> operand_dimensions_;
-  DimensionVector result_dimensions_;
-};
-
-// Create a replacement operand for the scatter instruction.
-HloInstruction* CreateSliceFrom(HloInstruction* operand, const Shape& shape) {
-  std::vector<int64_t> start_indices(shape.rank(), 0);
-  std::vector<int64_t> limit_indices(shape.rank());
-  std::vector<int64_t> strides(shape.rank(), 1);
-  for (int64_t i = 0; i < shape.rank(); ++i) {
-    limit_indices[i] = shape.dimensions(i);
-  }
-  return operand->AddInstruction(HloInstruction::CreateSlice(
-      shape, operand, start_indices, limit_indices, strides));
-}
-
-// Create a replacement for the scatter instruction.
-HloInstruction* CreateScatterFrom(HloScatterInstruction* scatter,
-                                  const Shape& shape) {
-  std::vector<HloInstruction*> operands(scatter->scatter_operand_count());
-  for (int64_t i = 0; i < operands.size(); ++i) {
-    operands[i] =
-        CreateSliceFrom(scatter->scatter_operands()[i],
-                        shape.IsTuple() ? shape.tuple_shapes(i) : shape);
-  }
-  return scatter->AddInstruction(HloInstruction::CreateScatter(
-      shape, absl::MakeSpan(operands), scatter->scatter_indices(),
-      scatter->scatter_updates(), scatter->called_computations()[0],
-      scatter->scatter_dimension_numbers(), scatter->indices_are_sorted(),
-      scatter->unique_indices()));
-}
-
-class ScatterSliceSimplifierVisitor : public DfsHloRewriteVisitor {
- public:
-  absl::Status HandleScatter(HloInstruction* instruction) override {
-    auto* scatter = Cast<HloScatterInstruction>(instruction);
-
-    // Infer scatter shape from the slice users.
-    std::optional<Shape> result_shape =
-        ScatterSliceMatcher(scatter).InferShape();
-    if (!result_shape.has_value()) {
-      return absl::OkStatus();
-    }
-    VLOG(2) << "Matched scatter " << scatter->name() << " with shape "
-            << scatter->shape().ToString() << ", inferred result shape "
-            << result_shape->ToString() << " (from the slice users)";
-
-    // Replace slice user instructions.
-    HloInstruction* new_scatter = CreateScatterFrom(scatter, *result_shape);
-    return ReplaceAllUsersRecursive(scatter, new_scatter);
-  }
-
- private:
-  // Create a replacement for every user. If the user is a slice operation,
-  // replace it in the computation graph; the old branch will be removed.
-  absl::Status ReplaceAllUsersRecursive(HloInstruction* old_instruction,
-                                        HloInstruction* new_instruction) {
-    // Maintain the replacement map, needed for non-unary elementwise users.
-    replacements_[old_instruction] = new_instruction;
-
-    // It's important to make a copy of the users list, as it may be modified
-    // during the iteration.
-    std::vector<HloInstruction*> users = old_instruction->users();
-    for (HloInstruction* user : users) {
-      if (user->parent() == nullptr) {
-        VLOG(3) << "Skipping user " << user->name() << " (already replaced)";
-        continue;
-      }
-      TF_RETURN_IF_ERROR(ReplaceUserRecursive(user, new_instruction));
-    }
-    return absl::OkStatus();
-  }
-
-  // Replace the slice user with a new scatter (or a new chain of operations
-  // starting with a scatter). For elementwise operations, create a new user
-  // with updated operands (build the chain).
-  absl::Status ReplaceUserRecursive(HloInstruction* user,
-                                    HloInstruction* operand) {
-    VLOG(3) << "Replacing scatter user " << user->name();
-    if (user->opcode() == HloOpcode::kSlice) {
-      return ReplaceInstruction(user, operand);
-    }
-
-    // Create the replacement instruction with new shape.
-    HloInstruction* new_user = nullptr;
-    if (user->IsElementwise()) {
-      auto new_shape = [operand](HloInstruction* from) {
-        return ShapeUtil::MakeShape(from->shape().element_type(),
-                                    operand->shape().dimensions());
-      };
-      std::vector<HloInstruction*> new_operands;
-      absl::c_transform(user->operands(), std::back_inserter(new_operands),
-                        [&](HloInstruction* op) {
-                          auto it = replacements_.find(op);
-                          return it != replacements_.end()
-                                     ? it->second
-                                     : CreateSliceFrom(op, new_shape(op));
-                        });
-      new_user = user->AddInstruction(
-          user->CloneWithNewOperands(new_shape(user), new_operands));
-    } else {
-      auto* gte = Cast<HloGetTupleElementInstruction>(user);
-      TF_ASSIGN_OR_RETURN(new_user,
-                          MakeGetTupleElementHlo(operand, gte->tuple_index(),
-                                                 &user->metadata()));
-    }
-
-    // Replace slice user instructions recursively.
-    return ReplaceAllUsersRecursive(user, new_user);
-  }
-
-  absl::flat_hash_map<HloInstruction*, HloInstruction*> replacements_;
-};
-
-}  // namespace
-
-absl::StatusOr<bool> ScatterSliceSimplifier::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  return ScatterSliceSimplifierVisitor{}.RunOnModule(module, execution_threads);
-}
-
-}  // namespace xla
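
The matcher in the deleted pass above only accepted slices that are pure truncations: every dimension starts at 0 with stride 1, and only dimensions outside the scatter's update_window_dims may shrink (the real code additionally reconciles the limits chosen by different slice users). A simplified, self-contained restatement of that per-slice check, using plain vectors instead of HLO instructions, so it illustrates the rule rather than reproducing the pass:

    #include <cstdint>
    #include <vector>

    #include "absl/algorithm/container.h"

    // Returns true if a slice described by (starts, strides, limits) merely
    // truncates `operand_dims`. `update_window_dims` is assumed sorted, as in
    // ScatterDimensionNumbers; those dimensions must keep their full size.
    bool IsTruncationSlice(const std::vector<int64_t>& starts,
                           const std::vector<int64_t>& strides,
                           const std::vector<int64_t>& limits,
                           const std::vector<int64_t>& operand_dims,
                           const std::vector<int64_t>& update_window_dims) {
      for (int64_t i = 0; i < static_cast<int64_t>(operand_dims.size()); ++i) {
        if (starts[i] != 0 || strides[i] != 1) {
          return false;  // Not a truncation: offset or strided slice.
        }
        if (limits[i] == operand_dims[i]) {
          continue;  // Dimension is kept whole.
        }
        if (absl::c_binary_search(update_window_dims, i)) {
          return false;  // Update-window dimensions cannot be truncated.
        }
      }
      return true;
    }
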
diff --git a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.h b/third_party/xla/xla/service/gpu/scatter_slice_simplifier.h
deleted file mode 100644
index 3498377..0000000
--- a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_
-#define XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// Replaces scatters followed by truncation slices with a new scatter that
-// produces the truncated output shape directly, eliminating the slices.
-//
-// (a) Single output     (b) Multiple outputs    (c) Elementwise users
-//
-//   T[N+1] scatter       (T1, T2) scatter         T scatter  T constant
-//          v                v       v                  v     v
-//     T[N] slice          T1 gte    T2 gte            T maximum
-//                           v       v                     v
-//                         T1 slice  T2 slice           T slice
-//
-// This pattern is used when the last element of the scatter output is intended
-// to accumulate updates from the input elements that should be ignored.
-// This is slow if there are many indices mapped to the last output index and
-// the scatter is implemented using atomics, so everything collides on that one
-// memory location.
-// As OOB scatter indices are dropped by the GPU implementation, we can remove
-// the slice step entirely and avoid the memory congestion in the scatter step.
-
-class ScatterSliceSimplifier : public HloModulePass {
- public:
-  absl::string_view name() const override { return "scatter-slice-simplifier"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_
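
For reference, the deleted pass implemented the standard HloModulePass interface, so before this change it was invoked like any other module pass (the tests below exercise it through RunHloPass). A hedged usage sketch, with `module` assumed to be a valid HloModule*:

    // Hypothetical stand-alone invocation of the removed pass; the interface
    // is taken from the deleted header above.
    ScatterSliceSimplifier simplifier;
    absl::StatusOr<bool> changed =
        simplifier.Run(module, /*execution_threads=*/{});
    if (changed.ok() && *changed) {
      // Scatters followed by truncation slices were rewritten in place.
    }
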
diff --git a/third_party/xla/xla/service/gpu/scatter_slice_simplifier_test.cc b/third_party/xla/xla/service/gpu/scatter_slice_simplifier_test.cc
deleted file mode 100644
index 281a4f0..0000000
--- a/third_party/xla/xla/service/gpu/scatter_slice_simplifier_test.cc
+++ /dev/null
@@ -1,336 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/scatter_slice_simplifier.h"
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-
-namespace xla {
-namespace {
-
-namespace m = ::xla::match;
-
-using ScatterSliceSimplifierTest = HloTestBase;
-
-TEST_F(ScatterSliceSimplifierTest, Scatter1D) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32 {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %add = f32[] add(%lhs, %rhs)
-}
-
-ENTRY main {
-  %indices = s32[4] parameter(0)
-  %updates = f32[4] parameter(1)
-  %operands = f32[9] constant(0)
-  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
-  ROOT %slice = f32[8] slice(%scatter), slice={[0:8]}
-}
-  )")
-                    .value();
-  ScatterSliceSimplifier test_pass;
-  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Scatter(m::Slice(m::Constant()), m::Parameter(0),
-                                    m::Parameter(1))
-                             .WithShape(F32, {8})));
-}
-
-TEST_F(ScatterSliceSimplifierTest, Scatter3D) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32 {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %add = f32[] add(%lhs, %rhs)
-}
-
-ENTRY main {
-  %indices = s32[2] parameter(0)
-  %updates = f32[2,4,4] parameter(1)
-  %operands = f32[5,4,4] constant(0)
-  %scatter = f32[5,4,4] scatter(%operands, %indices, %updates), update_window_dims={1,2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
-  ROOT %slice = f32[4,4,4] slice(%scatter), slice={[0:4], [0:4], [0:4]}
-}
-  )")
-                    .value();
-  ScatterSliceSimplifier test_pass;
-  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Scatter(m::Slice(m::Constant()), m::Parameter(0),
-                                    m::Parameter(1))
-                             .WithShape(F32, {4, 4, 4})));
-}
-
-TEST_F(ScatterSliceSimplifierTest, ScatterMultiOutput) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32_add_F16 {
-  %lhs.0 = f32[] parameter(0)
-  %rhs.0 = f32[] parameter(2)
-  %add.0 = f32[] add(%lhs.0, %rhs.0)
-  %lhs.1 = f16[] parameter(1)
-  %rhs.1 = f16[] parameter(3)
-  %add.1 = f16[] add(%lhs.1, %rhs.1)
-  ROOT %tuple = (f32[], f16[]) tuple(%add.0, %add.1)
-}
-
-ENTRY main {
-  %indices = s32[4] parameter(0)
-  %updates.0 = f32[4] parameter(1)
-  %updates.1 = f16[4] parameter(2)
-  %operands.0 = f32[9] constant(0)
-  %operands.1 = f16[9] constant(0)
-  %scatter = (f32[9], f16[9]) scatter(%operands.0, %operands.1, %indices, %updates.0, %updates.1), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32_add_F16
-  %gte.0 = f32[9] get-tuple-element(%scatter), index=0
-  %slice.0 = f32[8] slice(%gte.0), slice={[0:8]}
-  %gte.1 = f16[9] get-tuple-element(%scatter), index=1
-  %slice.1 = f16[8] slice(%gte.1), slice={[0:8]}
-  ROOT %tuple = (f32[8], f16[8]) tuple(%slice.0, %slice.1)
-}
-  )")
-                    .value();
-  ScatterSliceSimplifier test_pass;
-  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
-  auto expected_scatter =
-      m::Scatter(m::Slice(m::Constant()), m::Slice(m::Constant()),
-                 m::Parameter(0), m::Parameter(1), m::Parameter(2));
-
-  Shape expected_shape = ShapeUtil::MakeTupleShape(
-      {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F16, {8})});
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(m::GetTupleElement(expected_scatter),
-                                  m::GetTupleElement(expected_scatter))
-                             .WithShapeEqualTo(&expected_shape)));
-}
-
-TEST_F(ScatterSliceSimplifierTest, NotMatching) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32 {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %add = f32[] add(%lhs, %rhs)
-}
-
-slice_not_truncation {
-  %indices = s32[4] parameter(0)
-  %updates = f32[4] parameter(1)
-  %operands = f32[9] constant(0)
-  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
-  ROOT %slice = f32[8] slice(%scatter), slice={[1:9]}
-}
-
-slice_with_stride {
-  %indices = s32[4] parameter(0)
-  %updates = f32[4] parameter(1)
-  %operands = f32[9] constant(0)
-  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
-  ROOT %slice = f32[4] slice(%scatter), slice={[0:8:2]}
-}
-
-scatter_multiple_users {
-  %indices = s32[4] parameter(0)
-  %updates = f32[4] parameter(1)
-  %operands = f32[9] constant(0)
-  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
-  %slice = f32[8] slice(%scatter), slice={[0:8]}
-  ROOT %tuple = (f32[9], f32[8]) tuple(%scatter, %slice)
-}
-
-scatter_incompatible_slices {
-  %indices = s32[2] parameter(0)
-  %updates = f32[2,4] parameter(1)
-  %operands = f32[4,4] constant(0)
-  %scatter = f32[4,4] scatter(%operands, %indices, %updates), update_window_dims={1}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
-  %slice.0 = f32[3,4] slice(%scatter), slice={[0:3], [0:4]}
-  %slice.1 = f32[4,3] slice(%scatter), slice={[0:4], [0:3]}
-  ROOT %tuple = (f32[3,4], f32[4,3]) tuple(%slice.0, %slice.1)
-}
-
-slice_not_found {
-  %indices = s32[4] parameter(0)
-  %updates = f32[4] parameter(1)
-  %operands = f32[8] constant(0)
-  %scatter = f32[8] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
-  ROOT %exp = f32[8] exponential(%scatter)
-}
-
-slice_update_dimensions {
-  %indices = s32[10] parameter(0)
-  %updates = f32[10,1,128] parameter(1)
-  %operands = f32[100,128] constant(0)
-  %scatter = f32[100,128] scatter(%operands, %indices, %updates), update_window_dims={1,2}, inserted_window_dims={}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
-  ROOT %slice = f32[100,64] slice(%scatter), slice={[0:100], [0:64]}
-}
-  )")
-                    .value();
-  ScatterSliceSimplifier test_pass;
-  ASSERT_FALSE(RunHloPass(&test_pass, module.get()).value());
-}
-
-TEST_F(ScatterSliceSimplifierTest, IntermediaryUsers) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32 {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %add = f32[] add(%lhs, %rhs)
-}
-
-ENTRY main {
-  %indices = s32[4] parameter(0)
-  %updates = f32[4] parameter(1)
-  %operands = f32[9] constant(0)
-  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
-  %unary = f32[9] abs(%scatter)
-  %slice.0 = f32[8] slice(%unary), slice={[0:8]}
-  %binary = f32[9] maximum(%scatter, %operands)
-  %slice.1 = f32[8] slice(%binary), slice={[0:8]}
-  ROOT %tuple = (f32[8], f32[8]) tuple(%slice.0, %slice.1)
-}
-  )")
-                    .value();
-  ScatterSliceSimplifier test_pass;
-  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
-  auto expected_scatter =
-      m::Scatter(m::Slice(m::Constant()), m::Parameter(0), m::Parameter(1));
-
-  Shape expected_shape = ShapeUtil::MakeTupleShape(
-      {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})});
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(m::Abs(expected_scatter),
-                          m::Maximum(expected_scatter, m::Slice(m::Constant())))
-                     .WithShapeEqualTo(&expected_shape)));
-}
-
-TEST_F(ScatterSliceSimplifierTest, IntermediaryChain) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32 {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %add = f32[] add(%lhs, %rhs)
-}
-
-ENTRY main {
-  %indices = s32[4] parameter(0)
-  %updates = f32[4] parameter(1)
-  %operands = f32[9] constant(0)
-  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
-  %elementwise.0 = f32[9] abs(%scatter)
-  %elementwise.1 = f32[9] exponential(%elementwise.0)
-  %elementwise.2 = f32[9] add(%elementwise.0, %elementwise.1)
-  ROOT %result = f32[8] slice(%elementwise.2), slice={[0:8]}
-}
-  )")
-                    .value();
-  ScatterSliceSimplifier test_pass;
-  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
-  auto expected_scatter =
-      m::Scatter(m::Slice(m::Constant()), m::Parameter(0), m::Parameter(1));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Add(m::Abs(expected_scatter),
-                                m::Exp(m::Abs(expected_scatter)))
-                             .WithShape(F32, {8})));
-}
-
-TEST_F(ScatterSliceSimplifierTest, DiamondShape) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32_mul_F32 {
-  %lhs.0 = f32[] parameter(0)
-  %rhs.0 = f32[] parameter(2)
-  %add.0 = f32[] add(%lhs.0, %rhs.0)
-  %lhs.1 = f32[] parameter(1)
-  %rhs.1 = f32[] parameter(3)
-  %mul.1 = f32[] multiply(%lhs.1, %rhs.1)
-  ROOT %tuple = (f32[], f32[]) tuple(%add.0, %mul.1)
-}
-
-ENTRY main {
-  %indices = s32[4] parameter(0)
-  %updates.0 = f32[4] parameter(1)
-  %updates.1 = f32[4] parameter(2)
-  %operands.0 = f32[9] constant(0)
-  %operands.1 = f32[9] constant(0)
-  %scatter = (f32[9], f32[9]) scatter(%operands.0, %operands.1, %indices, %updates.0, %updates.1), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32_mul_F32
-  %gte.0 = f32[9] get-tuple-element(%scatter), index=0
-  %gte.1 = f32[9] get-tuple-element(%scatter), index=1
-  %consumer = f32[9] add(%gte.0, %gte.1)
-  ROOT %slice = f32[8] slice(%consumer), slice={[0:8]}
-}
-  )")
-                    .value();
-  ScatterSliceSimplifier test_pass;
-  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
-  auto expected_scatter =
-      m::Scatter(m::Slice(m::Constant()), m::Slice(m::Constant()),
-                 m::Parameter(0), m::Parameter(1), m::Parameter(2));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Add(m::GetTupleElement(expected_scatter),
-                                m::GetTupleElement(expected_scatter))
-                             .WithShape(F32, {8})));
-}
-
-TEST_F(ScatterSliceSimplifierTest, ElementwiseSelect) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-HloModule test_module
-
-%add_F32 {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %add = f32[] add(%lhs, %rhs)
-}
-
-ENTRY main {
-  %indices = s32[4] parameter(0)
-  %updates = f32[4] parameter(1)
-  %operands = f32[9] constant(0)
-  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
-  %pred_ = pred[9] parameter(2)
-  %select = f32[9] select(%pred_, %scatter, %operands)
-  ROOT %slice = f32[8] slice(%select), slice={[0:8]}
-}
-  )")
-                    .value();
-  ScatterSliceSimplifier test_pass;
-  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
-  auto expected_scatter =
-      m::Scatter(m::Slice(m::Constant()), m::Parameter(0), m::Parameter(1));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Select(m::Slice(m::Parameter(2)), expected_scatter,
-                                   m::Slice(m::Constant()))
-                             .WithShape(F32, {8})));
-}
-
-}  // namespace
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.cc b/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.cc
deleted file mode 100644
index fbf1b2c..0000000
--- a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.cc
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/scheduling_instruction_annotator.h"
-
-#include <string>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-// Populates `OpMetadata`'s `scheduling_name` field for all of the instructions
-// belonging to `computation`.
-absl::StatusOr<bool> AnnotateSchedulingInstructionNames(
-    HloComputation& computation) {
-  bool changed = false;
-  for (HloInstruction* inst : computation.instructions()) {
-    if (!inst->metadata().scheduling_name().empty()) {
-      continue;
-    }
-    inst->set_metadata_scheduling_name(std::string(inst->name()));
-    changed = true;
-  }
-  return changed;
-}
-
-}  // namespace
-
-absl::StatusOr<bool> SchedulingInstructionAnnotator::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  CHECK(module->has_schedule())
-      << "The pass is supposed to run in the beginning of post-scheduling!";
-  bool changed = false;
-
-  // We visit computations in the order of callees to callers, as information is
-  // propagated from callees to callers.
-  for (HloComputation* computation :
-       module->MakeComputationPostOrder(execution_threads)) {
-    TF_ASSIGN_OR_RETURN(bool result,
-                        AnnotateSchedulingInstructionNames(*computation));
-    changed |= result;
-  }
-
-  return changed;
-}
-
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.h b/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.h
deleted file mode 100644
index 3f9b769..0000000
--- a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.h
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_
-#define XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla::gpu {
-
-// The pass amends the `OpMetadata` with the instruction name present at
-// scheduling time. This is later used to make sure instructions are not renamed
-// post scheduling. Enforcing this is necessary because otherwise anything that
-// refers to instructions by their scheduling-time names could silently go out
-// of sync with the module.
-class SchedulingInstructionAnnotator : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "scheduling-instruction-annotator";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_
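For context, here is a minimal usage sketch rather than anything from the original change: the function name, the pipeline name, and the placement are illustrative only. It shows the intended slot for the pass, namely at the very start of a post-scheduling pipeline; Run() CHECK-fails on a module that has no schedule yet.

#include "xla/service/gpu/scheduling_instruction_annotator.h"
#include "xla/service/hlo_pass_pipeline.h"

namespace xla::gpu {

absl::StatusOr<bool> AnnotateAfterSchedulingSketch(HloModule* module) {
  HloPassPipeline pipeline("post-scheduling-sketch");
  // Records each instruction's current name in OpMetadata::scheduling_name,
  // unless the field is already populated.
  pipeline.AddPass<SchedulingInstructionAnnotator>();
  return pipeline.Run(module);
}

}  // namespace xla::gpu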
diff --git a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator_test.cc b/third_party/xla/xla/service/gpu/scheduling_instruction_annotator_test.cc
deleted file mode 100644
index 146607f..0000000
--- a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator_test.cc
+++ /dev/null
@@ -1,97 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/scheduling_instruction_annotator.h"
-
-#include <memory>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-using SchedulingInstructionAnnotatorTest = HloTestBase;
-
-TEST_F(SchedulingInstructionAnnotatorTest,
-       AnnotatesAllInstructionsWithTheirRespectiveNames) {
-  constexpr absl::string_view kHloString = R"(
-    HloModule module, is_scheduled=true
-
-    ENTRY entry {
-      p0 = f32[1] parameter(0)
-      p1 = f32[1] parameter(1)
-      add0 = f32[1] add(p0,p1)
-      ROOT exp0 = f32[1] exponential(add0)
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  SchedulingInstructionAnnotator pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-
-  ASSERT_TRUE(changed);
-  for (const auto* comp : module->computations()) {
-    for (const auto* instruction : comp->instructions()) {
-      EXPECT_EQ(instruction->name(), instruction->metadata().scheduling_name());
-    }
-  }
-  constexpr absl::string_view kExpected = R"(
-// CHECK:       %[[P0:.+]] = {{.*}} parameter(0)
-// CHECK-SAME:  scheduling_name="[[P0]]"
-// CHECK:       %[[P1:.+]] = {{.*}} parameter(1)
-// CHECK-SAME:  scheduling_name="[[P1]]"
-// CHECK:       %[[ADD0:.+]] = {{.*}} add(%[[P0]], %[[P1]])
-// CHECK-SAME:  scheduling_name="[[ADD0]]"
-// CHECK: ROOT  %[[EXP0:.+]] = {{.*}} exponential(%[[ADD0]])
-// CHECK-SAME:  scheduling_name="[[EXP0]]"
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(
-      bool filecheck_matches,
-      RunFileCheck(
-          module->ToString(HloPrintOptions().set_print_operand_shape(false)),
-          kExpected));
-  EXPECT_TRUE(filecheck_matches);
-}
-
-TEST_F(SchedulingInstructionAnnotatorTest,
-       DoesNotAnnotateAllInstructionsWithTheirRespectiveNames) {
-  constexpr absl::string_view kHloString = R"(
-    HloModule module, is_scheduled=true
-
-    ENTRY entry {
-      p0 = f32[1] parameter(0), metadata={scheduling_name="p0"}
-      p1 = f32[1] parameter(1), metadata={scheduling_name="p1"}
-      add0 = f32[1] add(p0,p1), metadata={scheduling_name="add0"}
-      ROOT exp0 = f32[1] exponential(add0), metadata={scheduling_name="exp0"}
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  SchedulingInstructionAnnotator pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
-
-  EXPECT_FALSE(changed);
-}
-
-}  // namespace
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc
deleted file mode 100644
index c6bd796..0000000
--- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc
+++ /dev/null
@@ -1,797 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/softmax_rewriter_triton.h"
-
-#include <cstdint>
-#include <functional>
-#include <string>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/string_view.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/layout_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/fusions/triton/triton_support.h"
-#include "xla/service/gpu/hlo_traversal.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/gpu/model/fusion_analysis_cache.h"
-#include "xla/service/gpu/model/gpu_indexing_performance_model.h"
-#include "xla/service/gpu/model/symbolic_tile_analysis.h"
-#include "xla/service/gpu/model/tiled_hlo_computation.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using hlo_query::IsBroadcastOfParameter;
-using hlo_query::IsBroadcastOfScalarConstant;
-
-bool HasDefaultLayout(const Shape& shape) {
-  return shape.has_layout() &&
-         LayoutUtil::IsMonotonicWithDim0Major(shape.layout());
-}
-
-// Returns true if a trivially connected producer of 'consumer' with opcode
-// 'opcode' exists. If such an instruction is found, the value of 'producer' is
-// set to it. The definition of "trivial" operations is as given in
-// 'IsTriviallyFusible'.
-bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer,
-                 HloOpcode opcode, const se::GpuComputeCapability& gpu_version);
-
-bool BitcastIsTilingNoop(HloInstruction* bitcast,
-                         const se::GpuComputeCapability& gpu_version) {
-  CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
-
-  if (ShapeUtil::IsEffectiveScalar(bitcast->shape())) {
-    return true;
-  }
-
-  // In the Softmax rewriter, for now, tiling is derived from a hero reduction
-  // operation, which should be reducing its input on the last axis. Therefore,
-  // a bitcast is always a no-op with regard to a tile if
-  //   (1) it does not change the size of the reduction dimension of its input
-  //       (the last one); if its input is already reduced, then (1) holds
-  //       trivially, and
-  //   (2) the layout of its output is ordered in the same way as the layout of
-  //       its input. This is a fuzzy definition, but since we assume fusible
-  //       ops always have a default layout, we can just check whether both the
-  //       bitcast and its input have a default layout.
-  auto last_dimension = [](const HloInstruction* instr) {
-    return instr->shape().dimensions().back();
-  };
-
-  HloInstruction* reduce = nullptr;
-  TrivialEdge(&reduce, bitcast->mutable_operand(0), HloOpcode::kReduce,
-              gpu_version);
-
-  return (HasDefaultLayout(bitcast->shape()) &&
-          HasDefaultLayout(bitcast->operand(0)->shape()) &&
-          (reduce != nullptr ||
-           last_dimension(bitcast->operand(0)) == last_dimension(bitcast)));
-}
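To make condition (1) above concrete, the following is a small hypothetical sketch, not part of the original pass; it relies only on ShapeUtil and the CHECK macros this file already includes. A bitcast whose output keeps the trailing (reduction) dimension is a tiling no-op candidate, while one that changes it is not.

void TilingNoopBitcastSketch() {
  // Default layouts (descending minor-to-major), as condition (2) requires.
  const Shape input = ShapeUtil::MakeShapeWithDenseLayout(F32, {2, 4, 127}, {2, 1, 0});
  const Shape noop = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 127}, {1, 0});
  const Shape not_noop = ShapeUtil::MakeShapeWithDenseLayout(F32, {127, 8}, {1, 0});
  // Condition (1): the size of the trailing dimension is unchanged (127 vs. 127),
  // so a bitcast from `input` to `noop` is a tiling no-op ...
  CHECK_EQ(input.dimensions().back(), noop.dimensions().back());
  // ... while a bitcast from `input` to `not_noop` is not (127 vs. 8).
  CHECK_NE(input.dimensions().back(), not_noop.dimensions().back());
}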
-
-inline bool HasOneUse(const HloInstruction* instr) {
-  return instr->user_count() == 1;
-}
-
-// Supports two types of broadcast of parameters. Either to one batch
-// dim, or one reduction dim. For example the following cases are supported:
-//
-// Case #1:
-// p = f32[a] parameter(0)
-// b = f32[a,x] broadcast(p), dimensions={0}
-//
-// Case #2:
-// p = f32[a] parameter(0)
-// b = f32[x,a] broadcast(p), dimensions={1}
-//
-// Case #3:
-// p = f32[a,b] parameter(0)
-// b = f32[x,a,b] broadcast(p), dimensions={1,2}
-//
-// Other broadcast tiling patterns are currently unsupported.
-// See b/328049138 for details.
-//
-// Unsupported case #1:
-// p = f32[a] parameter(0)
-// b = f32[x,a,y] broadcast(p), dimensions={1}
-//
-// Unsupported case #2:
-// p = f32[a,b] parameter(0)
-// b = f32[x,a,y,b] broadcast(p), dimensions={1,3}
-//
-// Unsupported case #3:
-// p = f32[a] parameter(0)
-// b = f32[x,y,a] broadcast(p), dimensions={2}
-//
-// Unsupported case #4:
-// p = f32[a,b] parameter(0)
-// b = f32[a,x,b] broadcast(p), dimensions={0,2}
-bool IsBatchOrReductionDimBroadcast(const HloInstruction& hlo) {
-  CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
-      << "Expected broadcast " << hlo.ToShortString();
-  CHECK_EQ(hlo.operand(0)->opcode(), HloOpcode::kParameter)
-      << "Expected parameter " << hlo.operand(0)->ToShortString();
-
-  const HloBroadcastInstruction* broadcast =
-      Cast<HloBroadcastInstruction>(&hlo);
-
-  const HloParameterInstruction* parameter =
-      Cast<HloParameterInstruction>(hlo.operand(0));
-
-  // Support only one dim broadcast.
-  if (parameter->shape().dimensions_size() + 1 !=
-      broadcast->shape().dimensions_size()) {
-    return false;
-  }
-
-  // It is enough to ensure that the broadcast does not preserve both the first
-  // and the last dimension of the parameter at the same time. Otherwise the
-  // broadcast is the unsupported case #4.
-  //
-  // Preserve the first dim:
-  //   p = f32[a,b] parameter(0)
-  //   b1 = f32[a,b,c] broadcast(p), dimensions={0,1}
-  bool preserve_first_dim = broadcast->dimensions().front() == 0;
-  // Preserve the last dim:
-  //   p = f32[a,b] parameter(0)
-  //   b1 = f32[c,a,b] broadcast(p), dimensions={1,2}
-  bool preserve_last_dim = broadcast->dimensions().back() ==
-                           broadcast->shape().dimensions_size() - 1;
-  // We do not want to preserve both first and last dim, as it means the
-  // broadcast is not expanding on outermost dims.
-  return !(preserve_first_dim && preserve_last_dim);
-}
-
-bool IsBroadcastOfAScalar(const HloInstruction& hlo) {
-  CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
-      << "Expected broadcast " << hlo.ToShortString();
-  return ShapeUtil::IsScalar(hlo.operand(0)->shape());
-}
-
-bool IsSingleRowParameterBroadcast(const HloInstruction& hlo) {
-  CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
-      << "Expected broadcast " << hlo.ToShortString();
-  CHECK_EQ(hlo.operand(0)->opcode(), HloOpcode::kParameter)
-      << "Expected parameter " << hlo.operand(0)->ToShortString();
-
-  const HloBroadcastInstruction* broadcast =
-      Cast<HloBroadcastInstruction>(&hlo);
-  const HloParameterInstruction* parameter =
-      Cast<HloParameterInstruction>(hlo.operand(0));
-
-  if (parameter->shape().dimensions_size() != 1) {
-    return false;
-  }
-  return broadcast->dimensions()[0] == broadcast->shape().dimensions_size() - 1;
-}
-
-bool IsSupportedBroadcastOfParameter(const HloInstruction& hlo) {
-  return IsBroadcastOfParameter(hlo) &&
-         (IsBatchOrReductionDimBroadcast(hlo) || IsBroadcastOfAScalar(hlo) ||
-          IsSingleRowParameterBroadcast(hlo));
-}
-
-// Chooses which operand to use for fusion processing. Given a unary or binary
-// instruction, returns the first non-splat operand. If none is present,
-// returns any operand.
-HloInstruction* ChooseOperandForFusionProcessing(HloInstruction* instr) {
-  CHECK_GT(instr->operand_count(), 0);
-  CHECK_LE(instr->operand_count(), 2);
-
-  // TODO(b/326217416): Extend the broadcast of splat constants/parameters to a
-  // broadcast of any op.
-  if (instr->operand_count() > 1 &&
-      (IsBroadcastOfScalarConstant(*instr->operand(0)) ||
-       IsSupportedBroadcastOfParameter(*instr->operand(0)))) {
-    return instr->mutable_operand(1);
-  }
-  return instr->mutable_operand(0);
-}
-
-bool IsTriviallyFusible(HloInstruction* instr,
-                        const se::GpuComputeCapability& gpu_version,
-                        int num_allowed_users = 1) {
-  // Checks whether an op is trivially fusible. An op is said to be trivially
-  // fusible if it does not increase the amount of memory read/written by the
-  // resulting fusion, is compatible with any chosen tiling, and can be
-  // codegen'd using Triton. The op is allowed to have up to num_allowed_users
-  // users.
-  if (instr->user_count() > num_allowed_users ||
-      !HasDefaultLayout(instr->shape())) {
-    return false;
-  }
-
-  if (instr->opcode() == HloOpcode::kBitcast &&
-      BitcastIsTilingNoop(instr, gpu_version)) {
-    return true;
-  }
-
-  if (instr->IsElementwise() && instr->operand_count() == 1) {
-    return static_cast<bool>(IsTritonSupportedInstruction(*instr, gpu_version));
-  }
-
-  // Elementwise binary ops are trivially fusible if their operands are the same,
-  // or if exactly one of the operands is a broadcast of a scalar constant or a
-  // supported broadcast of a parameter.
-  if (instr->IsElementwiseBinary()) {
-    const HloInstruction* operand_0 = instr->operand(0);
-    const HloInstruction* operand_1 = instr->operand(1);
-
-    // Elementwise binary ops should be fused if both operands are the same and
-    // if the operand is triton supported.
-    if (operand_0 == operand_1) {
-      return static_cast<bool>(
-          IsTritonSupportedInstruction(*instr, gpu_version));
-    }
-
-    // For simplicity we only fuse elementwise binary ops with splat operands
-    // if they contain one non-splat operand.
-    // TODO(b/326217416): Extend the broadcast of splat constants/parameters to
-    // a broadcast of any op.
-    if ((IsBroadcastOfScalarConstant(*operand_0) ||
-         IsSupportedBroadcastOfParameter(*operand_0)) ^
-        (IsBroadcastOfScalarConstant(*operand_1) ||
-         IsSupportedBroadcastOfParameter(*operand_1))) {
-      return static_cast<bool>(
-          IsTritonSupportedInstruction(*instr, gpu_version));
-    }
-  }
-
-  return false;
-}
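As a concrete illustration of the splat-operand rule above, the following hypothetical HLO fragment (not part of the original file; a single user of `multiply` is assumed) would be considered trivially fusible, because exactly one operand of the elementwise binary op is a broadcast of a scalar constant.

constexpr absl::string_view kTriviallyFusibleBinarySketch = R"(
  param = f32[127,125]{1,0} parameter(0)
  two = f32[] constant(2)
  splat = f32[127,125]{1,0} broadcast(two), dimensions={}
  multiply = f32[127,125]{1,0} multiply(param, splat)
)";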
-
-bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer,
-                 HloOpcode opcode,
-                 const se::GpuComputeCapability& gpu_version) {
-  while (consumer->opcode() != opcode) {
-    if (IsTriviallyFusible(consumer, gpu_version)) {
-      consumer = ChooseOperandForFusionProcessing(consumer);
-    } else {
-      return false;
-    }
-  }
-
-  *producer = consumer;
-  return true;
-}
-
-bool IsTriviallyConnectedProducerOf(
-    HloInstruction* producer, HloInstruction* consumer,
-    const se::GpuComputeCapability& gpu_version) {
-  if (producer == consumer) {
-    return true;
-  }
-
-  HloInstruction* found_producer = consumer;
-  while (
-      TrivialEdge(&found_producer, consumer, producer->opcode(), gpu_version)) {
-    if (found_producer == producer) {
-      return true;
-    }
-
-    if (!IsTriviallyFusible(found_producer, gpu_version)) {
-      return false;
-    }
-
-    consumer = found_producer->mutable_operand(0);
-  }
-
-  return false;
-}
-
-// Finds the first non-fusible producer of a diamond. This instruction is either
-//   1. the direct producer of the diamond, if that producer is used more than
-//      twice and/or is not otherwise trivially fusible, or
-//   2. the first ancestor (transitive operand) of the diamond's producer that
-//      is used more than once and/or is not trivially fusible.
-HloInstruction* FindFirstNonFusibleDiamondProducer(
-    HloInstruction* diamond_producer,
-    const se::GpuComputeCapability& gpu_version) {
-  if (IsTriviallyFusible(diamond_producer, gpu_version,
-                         /*num_allowed_users=*/2)) {
-    diamond_producer = ChooseOperandForFusionProcessing(diamond_producer);
-    while (IsTriviallyFusible(diamond_producer, gpu_version)) {
-      diamond_producer = ChooseOperandForFusionProcessing(diamond_producer);
-    }
-  }
-
-  return diamond_producer;
-}
-
-// Creates a fusion corresponding to the input diamond chain. The resulting
-// fusion instruction is added to the module, but is not yet inserted into the
-// graph as a replacement of the original instructions.
-//
-// TODO(b/347956491): this awkward abstraction is needed to work around
-// limitations of HloFusionAdaptor, which underpins the implementation of
-// SymbolicTileAnalysis. We need to come up with a better solution.
-absl::StatusOr<HloFusionInstruction*> MakeFusionForDiamondChain(
-    const DiamondChainDescriptor& diamond_chain) {
-  auto [root, producer] = diamond_chain;
-
-  std::string suggested_name = "triton_softmax";
-  HloComputation::Builder builder(absl::StrCat(suggested_name, "_computation"));
-  // Original instruction -> fused one.
-  absl::flat_hash_map<const HloInstruction*, HloInstruction*>
-      old_to_new_mapping;
-
-  int param = 0;
-  old_to_new_mapping[producer] =
-      builder.AddInstruction(HloInstruction::CreateParameter(
-          param, producer->shape(), absl::StrCat("parameter_", param)));
-  param++;
-
-  std::vector<HloInstruction*> parameters = {producer};
-
-  std::function<void(HloInstruction*)> create_computation =
-      [&](HloInstruction* instr) -> void {
-    if (old_to_new_mapping.contains(instr)) {
-      return;
-    }
-    std::vector<HloInstruction*> new_operands;
-    for (HloInstruction* operand : instr->mutable_operands()) {
-      create_computation(operand);
-      new_operands.push_back(old_to_new_mapping[operand]);
-    }
-    if (instr->opcode() == HloOpcode::kParameter) {
-      old_to_new_mapping[instr] =
-          builder.AddInstruction(HloInstruction::CreateParameter(
-              param, instr->shape(), absl::StrCat("parameter_", param)));
-      parameters.push_back(instr);
-      param++;
-    } else {
-      old_to_new_mapping[instr] = builder.AddInstruction(
-          instr->CloneWithNewOperands(instr->shape(), new_operands));
-    }
-  };
-  create_computation(root);
-
-  HloComputation* computation =
-      root->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(),
-                                                           /*is_entry=*/false);
-
-  HloInstruction* softmax_fusion =
-      root->parent()->AddInstruction(HloInstruction::CreateFusion(
-          root->shape(), HloInstruction::FusionKind::kCustom, parameters,
-          computation));
-
-  softmax_fusion->GetModule()->SetAndUniquifyInstrName(softmax_fusion,
-                                                       "triton_softmax");
-  TF_ASSIGN_OR_RETURN(auto gpu_config,
-                      softmax_fusion->backend_config<GpuBackendConfig>());
-  FusionBackendConfig& backend_config =
-      *gpu_config.mutable_fusion_backend_config();
-  backend_config.set_kind(std::string(kTritonFusionKind));
-  TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(gpu_config));
-  return xla::Cast<HloFusionInstruction>(softmax_fusion);
-}
-
-absl::Status FuseDiamondChainImpl(
-    const DiamondChainDescriptor& diamond_chain,
-    GpuPerformanceModelWithIndexingAnalysis& indexing_performance_model) {
-  TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion,
-                      MakeFusionForDiamondChain(diamond_chain));
-  HloInstruction* root = diamond_chain.root;
-
-  auto fusion_adaptor = HloFusionAdaptor::ForInstruction(softmax_fusion);
-
-  TF_ASSIGN_OR_RETURN(
-      TiledRunTimeDataOrError tiled_runtime_data_or,
-      indexing_performance_model.TryFindBestTilingForFusion(*fusion_adaptor));
-
-  if (const auto* fusion_decision =
-          std::get_if<FusionDecision>(&tiled_runtime_data_or)) {
-    return absl::FailedPreconditionError(absl::StrCat(
-        "SymbolicTileAnalysis failed. ", fusion_decision->Explain()));
-  }
-
-  TiledRunTimeData tiled_runtime_data =
-      std::get<TiledRunTimeData>(std::move(tiled_runtime_data_or));
-
-  TF_ASSIGN_OR_RETURN(auto backend_config,
-                      softmax_fusion->backend_config<GpuBackendConfig>());
-  *backend_config.mutable_fusion_backend_config()
-       ->mutable_block_level_fusion_config() =
-      tiled_runtime_data.block_level_parameters.ToBlockLevelFusionConfig();
-  TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(backend_config));
-
-  if (root->IsRoot()) {
-    root->parent()->set_root_instruction(softmax_fusion);
-    TF_RETURN_IF_ERROR(
-        root->parent()->RemoveInstructionAndUnusedOperands(root));
-  } else {
-    TF_RETURN_IF_ERROR(
-        root->parent()->ReplaceInstruction(root, softmax_fusion));
-  }
-
-  VLOG(5) << softmax_fusion->ToString();
-  return absl::OkStatus();
-}
-
-// Returns `true` if the diamond chain passed as a parameter can be tiled
-// correctly using `SymbolicTileAnalysis`.
-absl::StatusOr<bool> CanSymbolicTileAnalysisTileDiamondChain(
-    const DiamondChainDescriptor& diamond_chain) {
-  TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion,
-                      MakeFusionForDiamondChain(diamond_chain));
-  mlir::MLIRContext context;
-  SymbolicTileAnalysisOrError symbolic_tile_analysis_or_error =
-      SymbolicTileAnalysis::AnalyzeComputation(
-          *softmax_fusion->called_computation(), &context);
-
-  bool can_tile = std::holds_alternative<SymbolicTileAnalysis>(
-      symbolic_tile_analysis_or_error);
-
-  TF_RETURN_IF_ERROR(diamond_chain.root->GetModule()->RemoveEmbeddedComputation(
-      softmax_fusion->called_computation()));
-  TF_RETURN_IF_ERROR(
-      diamond_chain.root->parent()->RemoveInstruction(softmax_fusion));
-
-  return can_tile;
-}
-
-FusionDecision ShouldFuseReduction(const HloInstruction& reduce,
-                                   const se::GpuComputeCapability& cc) {
-  if (CodegenDecision is_supported = IsTritonSupportedInstruction(reduce, cc);
-      !is_supported) {
-    return FusionDecision(is_supported.Explain());
-  }
-
-  // Ensure that the reduction's identity is either a constant or a supported
-  // convert of a constant.
-  const HloInstruction* identity = reduce.operand(1);
-  bool should_fuse_identity =
-      identity->opcode() == HloOpcode::kConstant ||
-      (identity->opcode() == HloOpcode::kConvert &&
-       identity->operand(0)->opcode() == HloOpcode::kConstant &&
-       IsTritonSupportedInstruction(*identity, cc));
-  if (!should_fuse_identity) {
-    return "Reduction identity is not a constant or a supported convert of a "
-           "constant.";
-  }
-
-  return {};
-}
-
-DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl(
-    HloInstruction* instr, const se::GpuComputeCapability& cc) {
-  if (!instr->IsElementwiseBinary()) {
-    return "Root is not elementwise binary.";
-  }
-
-  if (!IsTritonSupportedInstruction(*instr, cc)) {
-    return "Root is not supported for Triton instruction.";
-  }
-
-  HloInstruction* producer;
-  HloInstruction* broadcast;
-  HloInstruction* reduce;
-
-  if (!TrivialEdge(&broadcast, instr->mutable_operand(1), HloOpcode::kBroadcast,
-                   cc)) {
-    return "Could not find a trivial connection from root to a broadcast.";
-  }
-
-  if (!TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce,
-                   cc)) {
-    return "Could not find a trivial connection from matched broadcast to a "
-           "reduction.";
-  }
-
-  if (!(HasDefaultLayout(broadcast->shape()) &&
-        HasDefaultLayout(reduce->shape()))) {
-    return "Broadcast or reduce have non-default layouts.";
-  }
-
-  if (FusionDecision should_fuse_reduction = ShouldFuseReduction(*reduce, cc);
-      !should_fuse_reduction) {
-    VLOG(2) << should_fuse_reduction.Explain();
-    return should_fuse_reduction;
-  }
-
-  // Ensure that the reduction's identity is either a constant or a supported
-  // convert of a constant.
-  const HloInstruction* identity = reduce->operand(1);
-  bool should_fuse_identity =
-      identity->opcode() == HloOpcode::kConstant ||
-      (identity->opcode() == HloOpcode::kConvert &&
-       identity->operand(0)->opcode() == HloOpcode::kConstant &&
-       IsTritonSupportedInstruction(*identity, cc));
-  if (!should_fuse_identity) {
-    return "Reduction identity is not a constant or a supported convert of a "
-           "constant.";
-  }
-
-  if (!HasOneUse(broadcast) || !HasOneUse(reduce)) {
-    return "More than one use of broadcast or reduce.";
-  }
-
-  producer = reduce->mutable_operand(0);
-
-  if (absl::c_linear_search(broadcast->dimensions(),
-                            broadcast->shape().rank() - 1)) {
-    return "Broadcast is not along the reduction dimension.";
-  }
-
-  while (IsTriviallyFusible(producer, cc)) {
-    producer = ChooseOperandForFusionProcessing(producer);
-  }
-
-  if (!HasDefaultLayout(producer->shape())) {
-    return "Producer has non-default layout.";
-  }
-
-  if (!IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0),
-                                      cc)) {
-    return "Producer is not trivially connected.";
-  }
-
-  if (producer != instr->operand(0) && instr->operand(0)->user_count() != 1) {
-    return "Unsupported root-producer connection.";
-  }
-
-  VLOG(5) << "Matched Softmax diamond with: ";
-  VLOG(5) << "root: " << instr->ToString();
-  VLOG(5) << "producer: " << producer->ToString();
-  VLOG(5) << "broadcast: " << broadcast->ToString();
-  VLOG(5) << "reduce: " << reduce->ToString();
-
-  return producer;
-}
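For orientation, the sketch below spells out the minimal closed reduction diamond this matcher is meant to accept; it simply mirrors the CanFuseSingleNormalizationDiamond test further down in this change. For this module the function returns `param_0` as the producer.

constexpr absl::string_view kClosedReductionDiamondSketch = R"(
HloModule softmax
max_computation {
  arg_0 = f32[] parameter(0)
  arg_1 = f32[] parameter(1)
  ROOT maximum = f32[] maximum(arg_0, arg_1)
}
ENTRY main {
  param_0 = f32[127,125]{1,0} parameter(0)
  constant_neg_inf = f32[] constant(-inf)
  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
})";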
-
-// Returns a vector containing all the single diamonds in the parameter module.
-// The diamonds are returned in def-before-use order, and grouped by
-// computation.
-absl::StatusOr<std::vector<DiamondChainDescriptor>> FindAllFusibleDiamonds(
-    HloModule& module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads,
-    const se::GpuComputeCapability& cc) {
-  std::vector<DiamondChainDescriptor> matched_diamonds;
-
-  for (HloComputation* comp :
-       module.MakeNonfusionComputations(execution_threads)) {
-    if (comp->IsCustomCallComputation()) {
-      continue;
-    }
-    for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-      auto producer =
-          MatchesTritonCompatibleClosedReductionDiamondImpl(instr, cc);
-      if (std::holds_alternative<HloInstruction*>(producer)) {
-        DiamondChainDescriptor diamond_chain{
-            /*root=*/instr, /*producer=*/std::get<HloInstruction*>(producer)};
-        // We filter out the diamond chains that cannot be tiled correctly using
-        // `SymbolicTileAnalysis`.
-        TF_ASSIGN_OR_RETURN(
-            bool can_tile_diamond_chain,
-            CanSymbolicTileAnalysisTileDiamondChain(diamond_chain));
-        if (can_tile_diamond_chain) {
-          matched_diamonds.push_back(diamond_chain);
-        } else {
-          VLOG(5) << "Cannot tile the diamond pattern described by "
-                  << "instructions " << instr->ToString() << " and "
-                  << std::get<HloInstruction*>(producer)->ToString() << ".";
-          continue;
-        }
-
-      } else {
-        VLOG(5) << "Cannot match the diamond pattern for instruction "
-                << instr->ToString()
-                << ". Reason: " << std::get<FusionDecision>(producer).Explain();
-      }
-    }
-  }
-
-  return std::move(matched_diamonds);
-}
-
-// Returns the size of the reduction dimension of the input diamond.
-int64_t GetReductionDimensionSizeForDiamond(
-    const DiamondChainDescriptor& diamond_chain) {
-  HloInstruction* diamond_root = diamond_chain.root;
-  HloInstruction* instr = diamond_root->mutable_operand(1);
-  while (instr->opcode() != HloOpcode::kReduce) {
-    instr = ChooseOperandForFusionProcessing(instr);
-  }
-
-  int operand_rank = instr->operand(0)->shape().rank();
-  CHECK_EQ(instr->dimensions().size(), 1);
-  CHECK_EQ(instr->dimensions(0), operand_rank - 1);
-  return instr->operand(0)->shape().dimensions(operand_rank - 1);
-}
-
-// Returns a pointer to the last user of `instr` that is trivially fusible.
-HloInstruction* GetLastTriviallyFusibleUser(
-    HloInstruction* instr, const se::GpuComputeCapability& cc) {
-  while (HasOneUse(instr) && !instr->IsRoot() &&
-         IsTriviallyFusible(instr->users().front(), cc)) {
-    instr = instr->users().front();
-  }
-
-  // We do not care about the number of users for the last instruction of the
-  // fusion, so attempt to fuse one more instruction with this relaxed
-  // restriction.
-  if (HasOneUse(instr) && !instr->IsRoot() &&
-      IsTriviallyFusible(
-          instr->users().front(), cc,
-          /*num_allowed_users=*/instr->users().front()->user_count())) {
-    instr = instr->users().front();
-  }
-  return instr;
-}
-
-}  // anonymous namespace
-
-DiamondMatchingDecision
-SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond(
-    HloInstruction* instr) const {
-  return MatchesTritonCompatibleClosedReductionDiamondImpl(
-      instr, device_info_.gpu_compute_capability());
-}
-
-absl::StatusOr<std::vector<DiamondChainDescriptor>>
-SoftmaxRewriterTriton::FindAllFusibleDiamondChains(
-    HloModule& module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) const {
-  const se::GpuComputeCapability& cc = device_info_.gpu_compute_capability();
-  TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> matched_diamonds,
-                      FindAllFusibleDiamonds(module, execution_threads, cc));
-
-  if (matched_diamonds.empty()) {
-    return std::vector<DiamondChainDescriptor>();
-  }
-
-  // If we matched several diamonds, it may be possible for some of them to be
-  // fused together. This is the case if the following conditions hold:
-  //   1. The path between the root of diamond n towards the producer of
-  //      diamond n+1 is composed only of trivially fusible operations. In that
-  //      case, the first non-trivially fusible producer of diamond n+1 must be
-  //      exactly the root of diamond n.
-  //   2. The root of diamond n (equivalently, the first non-fusible producer of
-  //      diamond n+1) must have
-  //        a. exactly one user if it is not exactly the producer of diamond
-  //           n+1;
-  //        b. exactly two users otherwise.
-  //   3. The axis being reduced must have the same length in all the diamonds
-  //      being fused together.
-  //
-  // Crucially, this approach relies on a diamond root never being considered a
-  // trivially fusible operation.
-  std::vector<DiamondChainDescriptor> diamond_chains;
-  diamond_chains.reserve(matched_diamonds.size());
-
-  HloInstruction* current_fusion_producer =
-      FindFirstNonFusibleDiamondProducer(matched_diamonds.front().producer, cc);
-  int current_reduce_dimension_size =
-      GetReductionDimensionSizeForDiamond(matched_diamonds.front());
-
-  for (int diamond_idx = 1; diamond_idx < matched_diamonds.size();
-       ++diamond_idx) {
-    HloInstruction* diamond_producer = matched_diamonds[diamond_idx].producer;
-    HloInstruction* previous_diamond_root =
-        matched_diamonds[diamond_idx - 1].root;
-
-    HloInstruction* first_non_fusible_diamond_producer =
-        FindFirstNonFusibleDiamondProducer(diamond_producer, cc);
-
-    int diamond_reduce_dimension_size =
-        GetReductionDimensionSizeForDiamond(matched_diamonds[diamond_idx]);
-
-    if (first_non_fusible_diamond_producer == previous_diamond_root &&  // 1
-        ((first_non_fusible_diamond_producer != diamond_producer &&
-          HasOneUse(first_non_fusible_diamond_producer)) ||  // 2.a
-         (first_non_fusible_diamond_producer == diamond_producer &&
-          first_non_fusible_diamond_producer->user_count() == 2)) &&  // 2.b
-        diamond_reduce_dimension_size == current_reduce_dimension_size) {  // 3
-      continue;
-    }
-
-    // The "last trivially fusible user" chain of diamond chain n should never
-    // intersect with the "first non fusible diamond producer" chain of diamond
-    // chain n+1: if these chains intersected, then all the intermediate ops
-    // between the diamond chains could be trivially fused, and both diamond
-    // chains could be fused into a single diamond chain. Note that this only
-    // holds insofar as we do not allow fusing in bitcasts that modify the last
-    // dimension of the input array. It is however possible for the last
-    // trivially fusible user of diamond chain n to be the first non fusible
-    // diamond producer of diamond chain n+1.
-    diamond_chains.push_back(DiamondChainDescriptor{
-        GetLastTriviallyFusibleUser(previous_diamond_root, cc),
-        current_fusion_producer,
-    });
-
-    current_fusion_producer = first_non_fusible_diamond_producer;
-    current_reduce_dimension_size = diamond_reduce_dimension_size;
-  }
-
-  // The last diamond chain is still open; close it.
-  diamond_chains.push_back(DiamondChainDescriptor{
-      GetLastTriviallyFusibleUser(matched_diamonds.back().root, cc),
-      current_fusion_producer});
-
-  // We filter out the diamond chains that cannot be tiled correctly using
-  // `SymbolicTileAnalysis`.
-  std::vector<DiamondChainDescriptor> filtered_diamond_chains;
-  for (const DiamondChainDescriptor& diamond_chain : diamond_chains) {
-    TF_ASSIGN_OR_RETURN(bool can_tile_diamond_chain,
-                        CanSymbolicTileAnalysisTileDiamondChain(diamond_chain));
-    if (can_tile_diamond_chain) {
-      filtered_diamond_chains.push_back(diamond_chain);
-    }
-  }
-  return filtered_diamond_chains;
-}
-
-absl::Status SoftmaxRewriterTriton::FuseDiamondChain(
-    const DiamondChainDescriptor& diamond_chain) {
-  HloFusionAnalysisCache fusion_analysis_cache(device_info_);
-  GpuPerformanceModelWithIndexingAnalysis indexing_performance_model(
-      &device_info_, &fusion_analysis_cache, shape_size_, &mlir_context_);
-
-  return FuseDiamondChainImpl(diamond_chain, indexing_performance_model);
-}
-
-absl::StatusOr<bool> SoftmaxRewriterTriton::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  TF_RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability(
-      device_info_.gpu_compute_capability()));
-
-  TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> diamond_chains,
-                      FindAllFusibleDiamondChains(*module, execution_threads));
-
-  if (diamond_chains.empty()) {
-    return false;
-  }
-
-  // The diamond chains must be emitted in reverse order, to make sure that
-  // producer instructions are emitted correctly when the root of
-  // diamond chain n is exactly the producer of diamond chain n+1.
-  for (auto diamond_chain = diamond_chains.rbegin();
-       diamond_chain != diamond_chains.rend(); ++diamond_chain) {
-    TF_RET_CHECK(FuseDiamondChain(*diamond_chain).ok());
-  }
-  return true;
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h
deleted file mode 100644
index 9da8cc5..0000000
--- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h
+++ /dev/null
@@ -1,101 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_
-#define XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_
-
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "mlir/IR/MLIRContext.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_cost_analysis.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-struct DiamondChainDescriptor {
-  HloInstruction* root = nullptr;
-  HloInstruction* producer = nullptr;
-};
-
-using DiamondMatchingDecision = std::variant<FusionDecision, HloInstruction*>;
-
-// Rewrite compatible Softmax into a custom fusion region to be code-generated
-// with the Triton-based Softmax emitter.
-class SoftmaxRewriterTriton : public HloModulePass {
- public:
-  explicit SoftmaxRewriterTriton(const se::DeviceDescription& device_info,
-                                 HloCostAnalysis::ShapeSizeFunction shape_size)
-      : device_info_(device_info), shape_size_(shape_size) {}
-
-  absl::string_view name() const override { return "triton-softmax-rewriter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
-  // Finds and returns all the fusible diamond chains in the module. The
-  // resulting vector is sorted according to a post-order matching (i.e. within
-  // the same computation, producer diamonds appear before consumer diamonds).
-  absl::StatusOr<std::vector<DiamondChainDescriptor>>
-  FindAllFusibleDiamondChains(
-      HloModule& module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) const;
-
-  // Constructs a Softmax fusion containing all the instructions between the
-  // root and the producer of a diamond chain. The producer is excluded from the
-  // fusion.
-  absl::Status FuseDiamondChain(const DiamondChainDescriptor& diamond_chain);
-
-  // Returns the producer of the following pattern:
-  //
-  // producer
-  // |    \
-  // |  reduce_{max,sum,...}
-  // |     |
-  // |  broadcast
-  // |   /
-  // binop (elementwise)
-  //
-  // where each edge is also allowed to contain trivial operations that can be
-  // generated by Triton. By "trivial" we mean operations that do not increase
-  // the amount of memory read/written by the fusion and that are compatible
-  // with any chosen tiling.
-  //
-  // We also assume that the reduction is done on the last axis of the producer
-  // array.
-  DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamond(
-      HloInstruction* instr) const;
-
- private:
-  const se::DeviceDescription& device_info_;
-  const HloCostAnalysis::ShapeSizeFunction shape_size_;
-  mlir::MLIRContext mlir_context_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_
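A minimal usage sketch, not from the original change: the wrapper name is made up, the shape-size callback mirrors ShapeSizeBytesFunction() from the tests that follow, and the extra includes are assumptions.

#include <cstdint>

#include "xla/service/gpu/softmax_rewriter_triton.h"
#include "xla/shape_util.h"

namespace xla::gpu {

absl::StatusOr<bool> RewriteSoftmaxSketch(HloModule* module,
                                          const se::DeviceDescription& device_info) {
  HloCostAnalysis::ShapeSizeFunction shape_size = [](const Shape& shape) {
    constexpr int64_t kPointerSize = 8;
    return ShapeUtil::ByteSizeOf(shape, kPointerSize);
  };
  SoftmaxRewriterTriton rewriter(device_info, shape_size);
  // Returns true iff at least one diamond chain was rewritten into a
  // block-level Triton fusion.
  return rewriter.Run(module, /*execution_threads=*/{});
}

}  // namespace xla::gpu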
diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc
deleted file mode 100644
index 8488031..0000000
--- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc
+++ /dev/null
@@ -1,1590 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/softmax_rewriter_triton.h"
-
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <variant>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/fusions/triton/triton_support.h"
-#include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
-#include "xla/service/instruction_fusion.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-using ::testing::HasSubstr;
-
-GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() {
-  return [&](const Shape& shape) {
-    constexpr int64_t kPointerSize = 8;
-    return ShapeUtil::ByteSizeOf(shape, kPointerSize);
-  };
-}
-
-bool HasBlockLevelFusionConfig(const HloInstruction* fusion) {
-  return fusion->opcode() == HloOpcode::kFusion &&
-         fusion->has_backend_config() &&
-         fusion->backend_config<GpuBackendConfig>().ok() &&
-         fusion->backend_config<GpuBackendConfig>()
-             ->fusion_backend_config()
-             .has_block_level_fusion_config();
-}
-
-// Wrapper around SoftmaxRewriterTriton(gpu_version).Run(module) that finds
-// and fuses as many diamond chains as possible without invoking any kind of
-// cost analysis.
-absl::StatusOr<bool> SoftmaxRewriterTritonMatchAndRewrite(
-    const se::DeviceDescription& device_info, HloModule* module) {
-  CHECK_NE(module, nullptr);
-  SoftmaxRewriterTriton softmax_rewriter_triton(device_info,
-                                                ShapeSizeBytesFunction());
-  TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> diamond_chains,
-                      softmax_rewriter_triton.FindAllFusibleDiamondChains(
-                          *module, /*execution_threads=*/{}));
-
-  for (auto diamond_chain = diamond_chains.rbegin();
-       diamond_chain != diamond_chains.rend(); ++diamond_chain) {
-    TF_RETURN_IF_ERROR(
-        softmax_rewriter_triton.FuseDiamondChain(*diamond_chain));
-  }
-
-  return !diamond_chains.empty();
-}
-
-class SoftmaxRewriterTritonTest
-    : public HloTestBase,
-      public ::testing::WithParamInterface<PrimitiveType> {
- protected:
-  se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()};
-};
-
-TEST_F(SoftmaxRewriterTritonTest, CanFuseExactSoftmaxF32) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
-  arg_0.1 = f32[] parameter(0)
-  arg_1.1 = f32[] parameter(1)
-  ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-  exponential = f32[127,125]{1,0} exponential(subtract)
-  constant_zero = f32[] constant(0)
-  second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
-  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
-  ROOT divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  VLOG(2) << module->ToString();
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseSoftmaxLikeComputationWithNonF32DataType) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f16[] parameter(0)
-  arg_1 = f16[] parameter(1)
-  ROOT maximum = f16[] maximum(arg_0, arg_1)
-}
-add_computation {
-  arg_0.1 = f16[] parameter(0)
-  arg_1.1 = f16[] parameter(1)
-  ROOT add = f16[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f16[127,125]{1,0} parameter(0)
-  constant_neg_inf = f16[] constant(-inf)
-  reduce = f16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f16[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f16[127,125]{1,0} subtract(param_0, broadcast)
-  // Replace Softmax exponential with abs, because Triton doesn't support
-  // non-f32 exponentials.
-  abs = f16[127,125]{1,0} abs(subtract)
-  constant_zero = f16[] constant(0)
-  second_reduce = f16[127]{0} reduce(abs, constant_zero), dimensions={1}, to_apply=add_computation
-  second_broadcast = f16[127,125]{1,0} broadcast(second_reduce), dimensions={0}
-  // Replace divide with multiply, because Triton doesn't support f16
-  // divisions.
-  ROOT multiply = f16[127,125]{1,0} multiply(abs, second_broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest, CanFuseSingleNormalizationDiamond) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       DoesNotFuseDiamondInvolvingUnsupportedTritonInstruction) {
-  const std::string hlo_string = R"(
-HloModule softmax
-add_computation {
-  arg_0.1 = bf16[] parameter(0)
-  arg_1.1 = bf16[] parameter(1)
-  ROOT add = bf16[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = bf16[127,125]{1,0} parameter(0)
-  constant_zero = bf16[] constant(0)
-  reduce = bf16[127]{0} reduce(param_0, constant_zero), dimensions={1}, to_apply=add_computation
-  broadcast = bf16[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT divide = bf16[127,125]{1,0} divide(param_0, broadcast)
-})";
-
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  const HloInstruction* bf16_divide =
-      module->entry_computation()->root_instruction();
-  EXPECT_FALSE(IsTritonSupportedInstruction(
-      *bf16_divide, device_info_.gpu_compute_capability()));
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       DoesNotFuseInstructionsUnsupportedByTritonIntoDiamonds) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = bf16[] parameter(0)
-  arg_1 = bf16[] parameter(1)
-  ROOT maximum = bf16[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = bf16[127,125]{1,0} parameter(0)
-  constant_neg_inf = bf16[] constant(-inf)
-  reduce = bf16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = bf16[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = bf16[127,125]{1,0} subtract(param_0, broadcast)
-  ROOT exponential = bf16[127,125]{1,0} exponential(subtract)
-})";
-
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  const HloInstruction* bf16_exponential =
-      hlo_query::GetFirstInstructionWithOpcode(*module->entry_computation(),
-                                               HloOpcode::kExp);
-  EXPECT_FALSE(IsTritonSupportedInstruction(
-      *bf16_exponential, device_info_.gpu_compute_capability()));
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Exp(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))));
-}
-
-TEST_F(SoftmaxRewriterTritonTest, CanNotFuseSoftmaxDiamondWithWrongLayout) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{0,1} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanNotFuseSoftmaxDiamondWithWrongReduceDimension) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[125]{0} reduce(param_0, constant_neg_inf), dimensions={0}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={1}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanNotFuseSoftmaxDiamondWithWrongBroadcastDimension) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[125,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[125]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[125,125]{1,0} broadcast(reduce), dimensions={1}
-  ROOT subtract = f32[125,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanNotFuseSoftmaxDiamondWithExtraBroadcastUsage) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-  ROOT multiply = f32[127,125]{1,0} multiply(broadcast, subtract)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseSoftmaxWithIntermediateUnaryElementwise) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
-  arg_0.1 = f32[] parameter(0)
-  arg_1.1 = f32[] parameter(1)
-  ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-  abs = f32[127,125]{1,0} abs(subtract)
-  exponential = f32[127,125]{1,0} exponential(abs)
-  constant_zero = f32[] constant(0)
-  second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
-  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
-  ROOT divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseTwoDiamondsWithSecondDiamondProducerEqualToFirstDiamondRoot) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
-  arg_0.1 = f32[] parameter(0)
-  arg_1.1 = f32[] parameter(1)
-  ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-  constant_zero = f32[] constant(0)
-  second_reduce = f32[127]{0} reduce(subtract, constant_zero), dimensions={1}, to_apply=add_computation
-  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
-  ROOT divide = f32[127,125]{1,0} divide(subtract, second_broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseDiamondWithTrailingUnaryElementwiseAtTheRoot) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-  ROOT abs = f32[127,125]{1,0} abs(subtract)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest, CanFuseDiamondWithUnaryElementwisePrefix) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  abs = f32[127,125]{1,0} abs(param_0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(abs, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseDiamondWithMultipleBroadcastDimensions) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[1,3,125,125]{3,2,1,0} parameter(0)
-  bitcast = f32[3,125,125]{2,1,0} bitcast(f32[1,3,125,125]{3,2,1,0} param_0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[3,125]{1,0} reduce(f32[3,125,125]{2,1,0} bitcast, f32[] constant_neg_inf), dimensions={2}, to_apply=max_computation
-  broadcast = f32[1,3,125,125]{3,2,1,0} broadcast(f32[3,125]{1,0} reduce), dimensions={1,2}
-  ROOT subtract = f32[1,3,125,125]{3,2,1,0} subtract(f32[1,3,125,125]{3,2,1,0} param_0, f32[1,3,125,125]{3,2,1,0} broadcast)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanNotFuseSoftmaxDiamondWithParameterReducerIdentity) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  identity = f32[] parameter(1)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, identity), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanNotFuseSoftmaxDiamondWithTritonIncompatibleReducer) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  if_0 = pred[] is-finite(arg_0)
-  c = f32[] convert(if_0)
-  ROOT maximum = f32[] maximum(c, arg_1)
-}
-
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseSoftmaxDiamondWithLastDimensionBitcastAfterReduce) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-
-ENTRY main {
-  param_0 = f32[3,127,125]{2,1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[3,127]{1,0} reduce(param_0, constant_neg_inf), dimensions={2}, to_apply=max_computation
-  bitcasted_reduce = f32[381]{0} bitcast(reduce)
-  broadcast = f32[381,125]{1,0} broadcast(bitcasted_reduce), dimensions={0}
-  bitcasted_broadcast = f32[3,127,125]{2,1,0} bitcast(broadcast)
-  ROOT subtract = f32[3,127,125]{2,1,0} subtract(param_0, bitcasted_broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanNotFuseSoftmaxDiamondWithTransposeBitcast) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-
-ENTRY main {
-  param_0 = f32[1,127,125]{2,1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  bitcasted_param_0 = f32[127,1,125]{2,0,1} bitcast(param_0)
-  reduce = f32[127,1]{0,1} reduce(bitcasted_param_0, constant_neg_inf), dimensions={2}, to_apply=max_computation
-  broadcast = f32[127,1,125]{2,0,1} broadcast(reduce), dimensions={0,1}
-  bitcasted_broadcast = f32[1,127,125]{2,1,0} bitcast(broadcast)
-  ROOT subtract = f32[1,127,125]{2,1,0} subtract(param_0, bitcasted_broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanNotFuseTwoDiamondsWithDifferentReductionAxisSizeTogether) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
-  arg_0.1 = f32[] parameter(0)
-  arg_1.1 = f32[] parameter(1)
-  ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f32[127,625]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,625]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,625]{1,0} subtract(param_0, broadcast)
-  bitcasted_subtract = f32[127,5,125] bitcast(subtract)
-  exponential = f32[127,5,125] exponential(bitcasted_subtract)
-  constant_zero = f32[] constant(0)
-  second_reduce = f32[127,5] reduce(exponential, constant_zero), dimensions={2}, to_apply=add_computation
-  second_broadcast = f32[127,5,125] broadcast(second_reduce), dimensions={0,1}
-  ROOT divide = f32[127,5,125] divide(exponential, second_broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Bitcast(m::Fusion(m::Parameter())
-                                   .WithPredicate(HasBlockLevelFusionConfig)))
-              .WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanNotFuseTwoDiamondsWithExtraUsageForFirstDiamondRoot) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
-  arg_0.1 = f32[] parameter(0)
-  arg_1.1 = f32[] parameter(1)
-  ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-  exponential = f32[127,125]{1,0} exponential(subtract)
-  constant_zero = f32[] constant(0)
-  second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
-  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
-  divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
-  ROOT tuple = (f32[127,125]{1,0}, f32[127,125]{1,0}) tuple(divide, subtract)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::Fusion(m::Fusion()).WithPredicate(HasBlockLevelFusionConfig),
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanNotFuseTwoDiamondsWithExtraUsageForSecondDiamondProducer) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
-  arg_0.1 = f32[] parameter(0)
-  arg_1.1 = f32[] parameter(1)
-  ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-  exponential = f32[127,125]{1,0} exponential(subtract)
-  constant_zero = f32[] constant(0)
-  second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
-  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
-  divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
-  ROOT tuple = (f32[127,125]{1,0}, f32[127,125]{1,0}) tuple(divide, exponential)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::Fusion(m::Fusion()).WithPredicate(HasBlockLevelFusionConfig),
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseSoftmaxDiamondWithTritonIncompatibleProducer) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-
-ENTRY main {
-  param_0 = f16[127,125]{1,0} parameter(0)
-  exponential = f16[127,125] exponential(param_0)
-  convert = f32[127,125] convert(exponential)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(convert, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(convert, broadcast)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Fusion(m::Exp(m::Parameter()))
-                             .WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanNotFuseSoftmaxDiamondWithNonFusibleBitcastBetweenReduceAndProducer) {
-  const std::string hlo_string = R"(
-HloModule softmax
-
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-
-ENTRY main {
-  param_0 = f32[1,127,5,25]{3,2,1,0} parameter(0)
-  bitcast_0 = f32[127,125] bitcast(param_0)
-  bitcast_1 = f32[127,125] bitcast(param_0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(bitcast_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(bitcast_1, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseSoftmaxDiamondWithBitcastProducerFollowedByBitcastsOnEachUse) {
-  const std::string hlo_string = R"(
-HloModule softmax
-
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-
-ENTRY main {
-  param_0 = f32[1,1,127,125]{3,2,1,0} parameter(0)
-  bitcast_parent = f32[127,125]{1,0} bitcast(param_0)
-  bitcast_0 = f32[127,125]{1,0} bitcast(bitcast_parent)
-  bitcast_1 = f32[127,125]{1,0} bitcast(bitcast_parent)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(bitcast_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(bitcast_1, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest, RewriterBailsOutOnPreAmpereCudaGpu) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = bf16[127,125]{1,0} parameter(0)
-  param_0_f32 = f32[127,125]{1,0} convert(param_0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast)
-})";
-
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
-  EXPECT_THAT(
-      SoftmaxRewriterTriton(
-          TestGpuDeviceInfo::RTXA6000DeviceInfo(
-              se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}),
-          ShapeSizeBytesFunction())
-          .Run(module.get()),
-      tsl::testing::StatusIs(
-          tsl::error::FAILED_PRECONDITION,
-          ::testing::HasSubstr("Triton support is only enabled for Ampere GPUs "
-                               "(compute capability 8.0) and up, but got")));
-}
-
-TEST_F(SoftmaxRewriterTritonTest, RewriterSucceedsOnNonCudaGpu) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = bf16[127,125]{1,0} parameter(0)
-  param_0_f32 = f32[127,125]{1,0} convert(param_0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast)
-})";
-
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-
-  EXPECT_TRUE(SoftmaxRewriterTriton(TestGpuDeviceInfo::AMDMI210DeviceInfo(),
-                                    ShapeSizeBytesFunction())
-                  .Run(module.get())
-                  .ok());
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseBinaryElementwiseProducerIntoDiamondWhenBothOperandsAreTheSame) {
-  const std::string hlo_string = R"(
-HloModule fusible_diamond
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  multiply = f32[127,125]{1,0} multiply(param_0, param_0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(multiply, broadcast)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    CanFuseIntermediateBinaryElementwiseWithinDiamondWhenBothOperandsAreTheSame) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule fusible_diamond
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  multiply = f32[127]{0} multiply(reduce, reduce)
-  broadcast = f32[127,125]{1,0} broadcast(multiply), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseBinaryElementwiseWhenBothOperandsAreTheSameBetweenDiamonds) {
-  const std::string hlo_string = R"(
-HloModule fusible_diamonds
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
-  arg_0.1 = f32[] parameter(0)
-  arg_1.1 = f32[] parameter(1)
-  ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-  multiply = f32[127,125]{1,0} multiply(subtract, subtract)
-  constant_zero = f32[] constant(0)
-  second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
-  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
-  ROOT subtract_second = f32[127,125]{1,0} subtract(multiply, second_broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseBinaryElementwiseConsumerWhereBothOperandsAreTheSameIntoDiamond) {
-  const std::string hlo_string = R"(
-HloModule fusible_diamond
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-add_computation {
-  arg_0.1 = f32[] parameter(0)
-  arg_1.1 = f32[] parameter(1)
-  ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-  ROOT multiply = f32[127,125]{1,0} multiply(subtract, subtract)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    DoesNotFuseIntermediateBinaryElementwiseWithBothSplatOperandsIntoDiamond) {
-  const std::string hlo_string = R"(
-HloModule nonfusible_splat
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  constant_0 = f32[] constant(0.333333343)
-  splat_0 = f32[127,125]{1,0} broadcast(constant_0), dimensions={}
-  constant_1 = f32[] constant(0.66666)
-  splat_1 = f32[127,125]{1,0} broadcast(constant_1), dimensions={}
-  param_0 = f32[127,125]{1,0} parameter(0)
-  multiply_splats = f32[127,125]{1,0} multiply(splat_0, splat_1)
-  multiply_splat_param = f32[127,125]{1,0} multiply(multiply_splats, param_0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(multiply_splat_param, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    DoesNotFuseIntermediateBinaryElementwiseWithSameSplatOperandsIntoDiamond) {
-  const std::string hlo_string = R"(
-HloModule nonfusible_splat_diamond
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  constant_0 = f32[] constant(0.333333343)
-  splat = f32[127,125]{1,0} broadcast(constant_0), dimensions={}
-  param_0 = f32[127,125]{1,0} parameter(0)
-  multiply = f32[127,125]{1,0} multiply(splat, splat)
-  add = f32[127,125]{1,0} add(param_0, multiply)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(add, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest, CanFuseRMSNormDiamond) {
-  const std::string hlo_string = R"(
-HloModule rms_norm
-add_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT add.1 = f32[] add(arg_0, arg_1)
-}
-ENTRY main.30 {
-  param_0 = f32[10,10,10,128]{3,2,1,0} parameter(0)
-  multiply_param = f32[10,10,10,128]{3,2,1,0} multiply(param_0, param_0)
-  constant_0 = f32[] constant(0)
-  reduce = f32[10,10,10]{2,1,0} reduce(multiply_param, constant_0), dimensions={3}, to_apply=add_computation
-  constant_1 = f32[] constant(0.333333343)
-  splat = f32[10,10,10]{2,1,0} broadcast(constant_1), dimensions={}
-  multiply_splat = f32[10,10,10]{2,1,0} multiply(reduce, splat)
-  epsilon = f32[] constant(1e-06)
-  splat_epsilon = f32[10,10,10]{2,1,0} broadcast(epsilon), dimensions={}
-  add = f32[10,10,10]{2,1,0} add(multiply_splat, splat_epsilon)
-  rsqrt = f32[10,10,10]{2,1,0} rsqrt(add)
-  broadcast = f32[10,10,10,128]{3,2,1,0} broadcast(rsqrt), dimensions={0,1,2}
-  ROOT multiply = f32[10,10,10,128]{3,2,1,0} multiply(param_0, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantBetweenDiamonds) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule fusible_diamonds
-add_computation {
-  arg_0.1 = f32[] parameter(0)
-  arg_1.1 = f32[] parameter(1)
-  ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-  constant = f32[] constant(0.333333343)
-  broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={}
-  multiply = f32[127,125]{1,0} multiply(broadcast_splat, subtract)
-  constant_zero = f32[] constant(0)
-  second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
-  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
-  ROOT second_subtract = f32[127,125]{1,0} subtract(multiply, second_broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    CanFuseAndEmitBinaryElementwiseWhereTheSecondOperandIsASplatConstantBetweenDiamonds) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule fusible_diamonds
-add_computation {
-  arg_0.1 = f32[] parameter(0)
-  arg_1.1 = f32[] parameter(1)
-  ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-  constant = f32[] constant(0.333333343)
-  broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={}
-  multiply = f32[127,125]{1,0} multiply(subtract, broadcast_splat)
-  constant_zero = f32[] constant(0)
-  second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
-  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
-  ROOT second_subtract = f32[127,125]{1,0} subtract(multiply, second_broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    CanFuseBinaryElementwiseWhereTheFirstOperandIsASplatConstantWithinDiamond) {
-  const std::string hlo_string = R"(
-HloModule fusible_diamond
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  constant = f32[] constant(0.333333343)
-  broadcast_splat = f32[127]{0} broadcast(constant), dimensions={}
-  multiply = f32[127]{0} multiply(broadcast_splat, reduce)
-  broadcast = f32[127,125]{1,0} broadcast(multiply), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseBinaryElementwiseConsumerWhereTheFirstOperandIsASplatConstant) {
-  const std::string hlo_string = R"(
-HloModule fusible_diamond
-add_computation {
-  arg_0.1 = f32[] parameter(0)
-  arg_1.1 = f32[] parameter(1)
-  ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-  constant = f32[] constant(0.333333343)
-  broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={}
-  ROOT multiply = f32[127,125]{1,0} multiply(broadcast_splat, subtract)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(SoftmaxRewriterTritonTest,
-       CanFuseBinaryElementwiseOperationWhereOneOperandIsASharedSplatProducer) {
-  const std::string hlo_string = R"(
-HloModule fusible_diamond
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT max = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_2 = f32[] constant(0.333333343)
-  broadcast_splat = f32[127,125]{1,0} broadcast(constant_2), dimensions={}
-  param_1 = f32[127,125]{1,0} parameter(1)
-  multiply_splat = f32[127,125]{1,0} multiply(broadcast_splat, param_1)
-  multiply = f32[127,125]{1,0} multiply(param_0, broadcast_splat)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
-  VLOG(2) << module->ToString();
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    DoesNotFuseBinaryElementwiseOperationWhereFirstOperandIsASplatAndSecondOperandIsASharedSplatProducer) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule nonfusible_diamond
-add_computation {
-  arg_0.1 = f32[] parameter(0)
-  arg_1.1 = f32[] parameter(1)
-  ROOT add = f32[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  constant_2 = f32[] constant(0.333333343)
-  broadcast_splat_shared = f32[127,125]{1,0} broadcast(constant_2), dimensions={}
-  param_1 = f32[127,125]{1,0} parameter(1)
-  multiply_splat_shared = f32[127,125]{1,0} multiply(broadcast_splat_shared, param_1)
-  constant_3 = f32[] constant(0.5)
-  broadcast_splat = f32[127,125]{1,0} broadcast(constant_3), dimensions={}
-  multiply_splat = f32[127,125]{1,0} multiply(broadcast_splat, broadcast_splat_shared)
-  multiply = f32[127,125]{1,0} multiply(param_0, multiply_splat)
-  constant_neg_inf = f32[] constant(-inf)
-  reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=add_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(SoftmaxRewriterTritonTest, FusionDecisionIsCapturedExplicitly) {
-  const std::string hlo_string = R"(
-HloModule softmax
-max_computation {
-  arg_0 = f32[] parameter(0)
-  arg_1 = f32[] parameter(1)
-  ROOT maximum = f32[] maximum(arg_0, arg_1)
-}
-ENTRY main {
-  param_0 = f32[127,125]{1,0} parameter(0)
-  identity_f8 = f8e5m2[] parameter(1)
-  identity = f32[] convert(identity_f8)
-  reduce = f32[127]{0} reduce(param_0, identity), dimensions={1}, to_apply=max_computation
-  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
-  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
-}
-)";
-
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  SoftmaxRewriterTriton softmax_rewriter_triton(device_info_,
-                                                ShapeSizeBytesFunction());
-  int unmatched = 0, matched = 0;
-  for (HloInstruction* instruction :
-       module->entry_computation()->MakeInstructionPostOrder()) {
-    DiamondMatchingDecision decision =
-        softmax_rewriter_triton.MatchesTritonCompatibleClosedReductionDiamond(
-            instruction);
-    if (std::holds_alternative<FusionDecision>(decision)) {
-      std::string actual_decision =
-          std::get<FusionDecision>(decision).Explain();
-      EXPECT_THAT(
-          actual_decision,
-          AnyOf(
-              HasSubstr("Root is not elementwise binary"),
-              HasSubstr("identity is not a constant or a supported convert")));
-      unmatched++;
-    } else {
-      matched++;
-    }
-  }
-  EXPECT_EQ(unmatched, 6);
-  EXPECT_EQ(matched, 0);
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    FusesBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongReductionDimAsParameter) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
-  y = f32[] parameter(1)
-  x = f32[] parameter(0)
-  ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
-  p0 = f32[32]{0} parameter(0)
-  p1 = f32[32,16]{1,0} parameter(1)
-  c = f32[] constant(0)
-
-  r0 = f32[32]{0} reduce(p1, c), dimensions={1}, to_apply=add_computation
-  b0 = f32[32,16]{1,0} broadcast(r0), dimensions={0}
-  b1 = f32[32,16]{1,0} broadcast(p0), dimensions={0}
-  add0 = f32[32,16]{1,0} add(b1, p1)
-  ROOT add1 = f32[32,16]{1,0} add(add0, b0)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    FusesBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongBatchDimAsParameter) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
-  y = f32[] parameter(1)
-  x = f32[] parameter(0)
-  ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
-  p0 = f32[16]{0} parameter(0)
-  p1 = f32[32,16]{1,0} parameter(1)
-  c = f32[] constant(0)
-
-  r0 = f32[32]{0} reduce(p1, c), dimensions={1}, to_apply=add_computation
-  b0 = f32[32,16]{1,0} broadcast(r0), dimensions={0}
-  b1 = f32[32,16]{1,0} broadcast(p0), dimensions={1}
-  add0 = f32[32,16]{1,0} add(b1, p1)
-  ROOT add1 = f32[32,16]{1,0} add(add0, b0)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    FusesBinaryElementwiseIfIntermediateDiamondOpWithMultiDimTensorBroadcastAlongBatchDimAsParameter) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
-  y = f32[] parameter(1)
-  x = f32[] parameter(0)
-  ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
-  p0 = f32[32,16]{1,0} parameter(0)
-  p1 = f32[64,32,16]{2,1,0} parameter(1)
-  c = f32[] constant(0)
-
-  r0 = f32[64,32]{1,0} reduce(p1, c), dimensions={2}, to_apply=add_computation
-  b0 = f32[64,32,16]{2,1,0} broadcast(r0), dimensions={0,1}
-  b1 = f32[64,32,16]{2,1,0} broadcast(p0), dimensions={1,2}
-  add0 = f32[64,32,16]{2,1,0} add(b1, p1)
-  ROOT add1 = f32[64,32,16]{2,1,0} add(add0, b0)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    FusesBinaryElementwiseIfIntermediateDiamondOpWithZeroDimTensorBroadcastAsParameter) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
-  y = f32[] parameter(1)
-  x = f32[] parameter(0)
-  ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
-  parameter_0 = f32[] parameter(0)
-  parameter_1 = f32[64,32,16]{2,1,0} parameter(1)
-  c = f32[] constant(0)
-
-  reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation
-  broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1}
-  broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={}
-  add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1)
-  ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    FusesBinaryElementwiseIfIntermediateDiamondOpIsBroadcastOf1DParameterAlongNonReductionDimensions) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
-  y = f32[] parameter(1)
-  x = f32[] parameter(0)
-  ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
-  parameter_0 = f32[16] parameter(0)
-  parameter_1 = f32[64,32,16]{2,1,0} parameter(1)
-  c = f32[] constant(0)
-
-  reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation
-  broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1}
-  broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={2}
-  add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1)
-  ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_TRUE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpIsBroadcastOf1DParameterAlongBothBatchAndReductionDimensions) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
-  y = f32[] parameter(1)
-  x = f32[] parameter(0)
-  ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
-  parameter_0 = f32[64] parameter(0)
-  parameter_1 = f32[64,32,16]{2,1,0} parameter(1)
-  c = f32[] constant(0)
-
-  reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation
-  broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1}
-  broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={0}
-  add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1)
-  ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongBatchAndReductionDimAsParameter) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
-  y = f32[] parameter(1)
-  x = f32[] parameter(0)
-  ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
-  p0 = f32[8]{0} parameter(0)
-  p1 = f32[32,8,16]{2,1,0} parameter(1)
-  c = f32[] constant(0)
-
-  r0 = f32[32,8]{1,0} reduce(p1, c), dimensions={2}, to_apply=add_computation
-  b0 = f32[32,8,16]{2,1,0} broadcast(r0), dimensions={0,1}
-  b1 = f32[32,8,16]{2,1,0} broadcast(p0), dimensions={1}
-  add0 = f32[32,8,16]{2,1,0} add(b1, p1)
-  ROOT add1 = f32[32,8,16]{2,1,0} add(add0, b0)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithPartialBroadcastToBatchDim) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
-  y = f32[] parameter(1)
-  x = f32[] parameter(0)
-  ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
-  p0 = f32[16,64]{1,0} parameter(0)
-  p1 = f32[8,16,32,64]{3,2,1,0} parameter(1)
-  c = f32[] constant(0)
-
-  r0 = f32[8,16,32]{2,1,0} reduce(p1, c), dimensions={3}, to_apply=add_computation
-  b0 = f32[8,16,32,64]{3,2,1,0} broadcast(r0), dimensions={0,1,2}
-  b1 = f32[8,16,32,64]{3,2,1,0} broadcast(p0), dimensions={1,3}
-  add0 = f32[8,16,32,64]{3,2,1,0} add(b1, p1)
-  ROOT add1 = f32[8,16,32,64]{3,2,1,0} add(add0, b0)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithMultiDimBroadcastAlongBatchDimAsParameter) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
-  y = f32[] parameter(1)
-  x = f32[] parameter(0)
-  ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
-  p0 = f32[32,16]{1,0} parameter(0)
-  p1 = f32[128,64,32,16]{3,2,1,0} parameter(1)
-  c = f32[] constant(0)
-
-  r0 = f32[128,64,32]{2,1,0} reduce(p1, c), dimensions={3}, to_apply=add_computation
-  b0 = f32[128,64,32,16]{3,2,1,0} broadcast(r0), dimensions={0,1,2}
-  b1 = f32[128,64,32,16]{3,2,1,0} broadcast(p0), dimensions={2,3}
-  add0 = f32[128,64,32,16]{3,2,1,0} add(b1, p1)
-  ROOT add1 = f32[128,64,32,16]{3,2,1,0} add(add0, b0)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(
-      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
-}
-
-}  // anonymous namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc
deleted file mode 100644
index 7e54ea5..0000000
--- a/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc
+++ /dev/null
@@ -1,219 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/stream_attribute_annotator.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/runtime/thunk.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
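-// Returns true if the computation's root (which must not be a tuple) runs on a
-// non-default stream and every other instruction runs either on the default
-// stream or on the root's stream.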
-bool IsOnlyRootNonDefaultStream(HloComputation* computation) {
-  HloInstruction* root = computation->root_instruction();
-  auto root_gpu_config = root->backend_config<GpuBackendConfig>();
-  if (!root_gpu_config.ok() || root->opcode() == HloOpcode::kTuple) {
-    return false;
-  }
-  int64_t root_stream_id = root_gpu_config->operation_queue_id();
-  VLOG(2) << "Found fusion computation's root stream id to be "
-          << root_stream_id;
-  if (root_stream_id == Thunk::kDefaultExecutionStreamId.value()) {
-    return false;
-  }
-  for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
-    if (instr == root) {
-      continue;
-    }
-    int64_t instr_stream_id =
-        instr->backend_config<GpuBackendConfig>()->operation_queue_id();
-    if (instr_stream_id != Thunk::kDefaultExecutionStreamId.value() &&
-        instr_stream_id != root_stream_id) {
-      return false;
-    }
-  }
-  return true;
-}
-
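-// Copies the stream assignment of a fusion computation's root onto the fusion
-// instruction itself, provided the fusion still runs on the default stream.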
-absl::StatusOr<bool> AnnotateStreamAttributesForInstruction(
-    HloInstruction* instr, GpuBackendConfig& instr_gpu_config) {
-  if (instr->called_computations().size() != 1) {
-    return false;
-  }
-  HloComputation* called_comp = instr->called_computations()[0];
-  int64_t stream_id = instr_gpu_config.operation_queue_id();
-
-  if (!IsOnlyRootNonDefaultStream(called_comp) ||
-      stream_id != Thunk::kDefaultExecutionStreamId.value()) {
-    return false;
-  }
-
-  auto comp_root_gpu_config =
-      called_comp->root_instruction()->backend_config<GpuBackendConfig>();
-
-  instr_gpu_config.set_operation_queue_id(
-      comp_root_gpu_config->operation_queue_id());
-  *instr_gpu_config.mutable_wait_on_operation_queues() =
-      comp_root_gpu_config->wait_on_operation_queues();
-  TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config));
-  return true;
-}
-
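-// Assigns `channel_id` as the execution stream of a copy-start instruction
-// that is still running on the default stream.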
-absl::StatusOr<bool> AnnotateStreamAttributesForCopyStart(
-    HloInstruction* instr, int64_t channel_id,
-    GpuBackendConfig& instr_gpu_config) {
-  // Do nothing if copy-start has already been annotated
-  if (instr_gpu_config.operation_queue_id() !=
-      Thunk::kDefaultExecutionStreamId.value()) {
-    return false;
-  }
-  instr_gpu_config.set_operation_queue_id(channel_id);
-  TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config));
-  VLOG(3) << "Add copy-start's backend config: " << channel_id;
-  return true;
-}
-
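-// Wraps `instruction` into a new "wrapped_<opcode>" fusion, replaces it in the
-// computation (and in the schedule, if any), and assigns the fusion to stream
-// `channel_id`.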
-absl::StatusOr<bool> WrapIntoFusionAndAnnotateStreamAttributes(
-    HloInstruction* instruction, int64_t channel_id,
-    GpuBackendConfig& instr_gpu_config) {
-  auto* computation = instruction->parent();
-  auto* module = computation->parent();
-  auto* fusion_instruction =
-      computation->AddInstruction(HloInstruction::CreateFusion(
-          instruction->shape(), ChooseFusionKind(*instruction, *instruction),
-          instruction));
-  const absl::string_view wrapped_opcode =
-      HloOpcodeString(instruction->opcode());
-  module->SetAndUniquifyInstrName(fusion_instruction,
-                                  absl::StrCat("wrapped_", wrapped_opcode));
-  module->SetAndUniquifyComputationName(
-      fusion_instruction->fused_instructions_computation(),
-      absl::StrCat("wrapped_", wrapped_opcode, "_computation"));
-  if (module->has_schedule()) {
-    module->schedule().replace_instruction(computation, instruction,
-                                           fusion_instruction);
-  }
-  TF_RETURN_IF_ERROR(fusion_instruction->CopyAllControlDepsFrom(instruction));
-  TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
-  TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
-  TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
-
-  instr_gpu_config.set_operation_queue_id(channel_id);
-  TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(instr_gpu_config));
-  VLOG(3) << "Add async stream " << channel_id << " and wrapped instruction "
-          << instruction->ToString();
-  VLOG(3) << "  Fusion wrapper: " << fusion_instruction->ToString();
-  return true;
-}
-
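-// If `instr` runs on a non-default stream, makes every consumer (looking
-// through get-tuple-element) wait on that stream unless it already runs on or
-// waits on it.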
-absl::StatusOr<bool> AnnotateStreamAttributesForUsers(
-    HloInstruction* instr, GpuBackendConfig& instr_gpu_config) {
-  bool changed = false;
-  int64_t stream_id = instr_gpu_config.operation_queue_id();
-  if (stream_id == Thunk::kDefaultExecutionStreamId.value()) {
-    return changed;
-  }
-  std::vector<HloInstruction*> all_consumers;
-  for (auto user : instr->users()) {
-    if (user->opcode() == HloOpcode::kGetTupleElement) {
-      user = user->users()[0];
-    }
-    all_consumers.push_back(user);
-  }
-
-  for (auto user : all_consumers) {
-    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                        user->backend_config<GpuBackendConfig>());
-    auto it = absl::c_find(gpu_config.wait_on_operation_queues(), stream_id);
-    if (it == gpu_config.wait_on_operation_queues().end() &&
-        gpu_config.operation_queue_id() != stream_id) {
-      gpu_config.mutable_wait_on_operation_queues()->Add(stream_id);
-      TF_RETURN_IF_ERROR(user->set_backend_config(gpu_config));
-      changed = true;
-    }
-  }
-
-  return changed;
-}
-}  // namespace
-
-absl::StatusOr<bool> StreamAttributeAnnotator::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  XLA_VLOG_LINES(
-      5, "StreamAttributeAnnotator::Run(), before:\n" + module->ToString());
-  bool changed = false;
-  int64_t channel_id = hlo_query::NextChannelId(*module);
-  for (const HloComputation* comp :
-       module->MakeComputationPostOrder(execution_threads)) {
-    for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-      auto instr_gpu_config = instr->backend_config<GpuBackendConfig>();
-      if (!instr_gpu_config.ok()) {
-        continue;
-      }
-      // For a fusion instruction, only annotate when the root of the
-      // fusion is a single instruction running on a non-default stream.
-      if (instr->opcode() == HloOpcode::kFusion) {
-        TF_ASSIGN_OR_RETURN(bool comp_result,
-                            AnnotateStreamAttributesForInstruction(
-                                instr, instr_gpu_config.value()));
-        changed |= comp_result;
-      } else if (instr->opcode() == HloOpcode::kCopyStart) {
-        TF_ASSIGN_OR_RETURN(bool comp_result,
-                            AnnotateStreamAttributesForCopyStart(
-                                instr, channel_id, instr_gpu_config.value()));
-        changed |= comp_result;
-        continue;
-      } else if (comp->IsAsyncComputation() &&
-                 (instr->opcode() == HloOpcode::kDynamicSlice ||
-                  instr->opcode() == HloOpcode::kDynamicUpdateSlice)) {
-        TF_ASSIGN_OR_RETURN(bool comp_result,
-                            WrapIntoFusionAndAnnotateStreamAttributes(
-                                instr, channel_id, instr_gpu_config.value()));
-        changed |= comp_result;
-        continue;
-      }
-
-      TF_ASSIGN_OR_RETURN(
-          bool user_result,
-          AnnotateStreamAttributesForUsers(instr, instr_gpu_config.value()));
-      changed |= user_result;
-    }
-  }
-  XLA_VLOG_LINES(
-      5, "StreamAttributeAnnotator::Run(), after:\n" + module->ToString());
-  return changed;
-}
-
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator.h b/third_party/xla/xla/service/gpu/stream_attribute_annotator.h
deleted file mode 100644
index 8a0284a..0000000
--- a/third_party/xla/xla/service/gpu/stream_attribute_annotator.h
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_
-#define XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla::gpu {
-
-// This pass checks whether:
-//  1.  any instruction that consumes data from other compute
-//      streams is missing the "wait_on_operation_queues" attribute.
-//  2.  any fusion instruction has a root running on a
-//      non-default stream.
-// It annotates the corresponding instruction with the correct
-// attribute in GpuBackendConfig.
-// Instructions annotated with operation_queue_id > 0 are wrapped
-// in an AsyncInstruction and split into AsyncStart and AsyncDone
-// by the StreamAttributeAsyncWrapper pass.
-// The pass also checks whether any user of a non-default-stream
-// instruction is missing the correct "wait_on_operation_queues"
-// attribute and, if so, sets it to the producer's operation_queue_id.
-// "wait_on_operation_queues" is needed by the emitter to emit the
-// correct WaitForStreams thunk.
-
-class StreamAttributeAnnotator : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "stream-attribute-annotator";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_
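For readers tracing the comment above, the sketch below shows the core GpuBackendConfig manipulation the annotator performs: read the producer's operation_queue_id and add it to a consumer's wait_on_operation_queues. It is a minimal illustration, not part of this change; the helper name WaitOnProducerStream is invented, and only APIs already used in the deleted implementation above are called.

#include <cstdint>

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla::gpu {
// Illustrative helper (hypothetical name): make `consumer` wait on the
// stream that `producer` was annotated with.
absl::Status WaitOnProducerStream(HloInstruction* producer,
                                  HloInstruction* consumer) {
  TF_ASSIGN_OR_RETURN(GpuBackendConfig producer_cfg,
                      producer->backend_config<GpuBackendConfig>());
  const int64_t stream_id = producer_cfg.operation_queue_id();

  TF_ASSIGN_OR_RETURN(GpuBackendConfig consumer_cfg,
                      consumer->backend_config<GpuBackendConfig>());
  if (!absl::c_linear_search(consumer_cfg.wait_on_operation_queues(),
                             stream_id)) {
    // Record that the consumer must wait on the producer's stream.
    consumer_cfg.mutable_wait_on_operation_queues()->Add(stream_id);
    TF_RETURN_IF_ERROR(consumer->set_backend_config(consumer_cfg));
  }
  return absl::OkStatus();
}
}  // namespace xla::gpu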
diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc b/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc
deleted file mode 100644
index 17d9b2f..0000000
--- a/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc
+++ /dev/null
@@ -1,287 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/stream_attribute_annotator.h"
-
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "absl/algorithm/container.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-using StreamAttributeAnnotatorTest = HloTestBase;
-
-TEST_F(StreamAttributeAnnotatorTest, AllUsersAreAnnotated) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule ModuleWithAsync
-
-  ENTRY entry {
-    p1_32 = f32[1] parameter(0)
-    p2_32 = f32[1] parameter(1)
-    add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[]}
-    exp_32 = f32[1] exponential(add_32)
-
-    neg32 = f32[1] negate(add_32)
-    ROOT add_out_32 = f32[1] add(neg32, exp_32)
-  }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  StreamAttributeAnnotator attr_annotator;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  const HloInstruction* add = FindInstruction(module.get(), "add_32");
-  for (auto user : add->users()) {
-    // Every user should have an annotation.
-    EXPECT_TRUE(user->has_backend_config());
-    TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                            user->backend_config<GpuBackendConfig>());
-    EXPECT_EQ(gpu_config.wait_on_operation_queues()[0], 1);
-  }
-}
-
-TEST_F(StreamAttributeAnnotatorTest, MultipleStreamsAreCombined) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule ModuleWithAsync
-
-  ENTRY entry {
-    p1_32 = f32[1] parameter(0)
-    p2_32 = f32[1] parameter(1)
-    add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[]}
-    exp_32 = f32[1] exponential(p2_32), backend_config={"operation_queue_id":"2", "wait_on_operation_queues":[]}
-
-    ROOT add_out_32 = f32[1] add(add_32, exp_32)
-  }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  StreamAttributeAnnotator attr_annotator;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  const HloInstruction* root = module->entry_computation()->root_instruction();
-  // Root should wait on 2 streams.
-  EXPECT_TRUE(root->has_backend_config());
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          root->backend_config<GpuBackendConfig>());
-  std::vector<int64_t> expected_stream_ids = {1, 2};
-  for (auto id : expected_stream_ids) {
-    auto it = absl::c_find(gpu_config.wait_on_operation_queues(), id);
-    EXPECT_NE(it, gpu_config.wait_on_operation_queues().end());
-  }
-}
-
-TEST_F(StreamAttributeAnnotatorTest, GTEUserIsAnnotated) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule ModuleWithAsync
-
-  ENTRY entry {
-    p1_32 = f32[16,32] parameter(0)
-    p2_32 = f32[32,16] parameter(1)
-
-    custom-call.3 = (f32[16,16], s8[1028]{0}) custom-call(p1_32, p2_32), custom_call_target="__cublas$gemm", backend_config={"operation_queue_id":"1","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","grad_x":false,"grad_y":false}}
-    get-tuple-element.24 = f32[16,16] get-tuple-element(custom-call.3), index=0
-
-    exp_32 = f32[16,16] exponential(get-tuple-element.24)
-
-    ROOT neg32 = f32[16,16] negate(exp_32)
-  }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  StreamAttributeAnnotator attr_annotator;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  const HloInstruction* exp = FindInstruction(module.get(), "exp_32");
-  EXPECT_TRUE(exp->has_backend_config());
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          exp->backend_config<GpuBackendConfig>());
-  EXPECT_EQ(gpu_config.wait_on_operation_queues()[0], 1);
-}
-
-TEST_F(StreamAttributeAnnotatorTest, FusionIsAnnotated) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule ModuleWithFusion
-
-  fused_computation.1 {
-    fusion_p0_32 = f32[16,16] parameter(0)
-    fusion_p2_32 = f32[16,16] parameter(1)
-    ROOT add = f32[16,16] add(fusion_p0_32, fusion_p2_32), backend_config={"operation_queue_id":"1","wait_on_operation_queues":[]}
-  }
-
-  ENTRY entry {
-    p1_32 = f32[16,16] parameter(0)
-    p2_32 = f32[16,16] parameter(1)
-    ROOT fusion.1 = f32[16,16] fusion(p1_32, p2_32), kind=kLoop, calls=fused_computation.1
-  }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  StreamAttributeAnnotator attr_annotator;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  const HloInstruction* fusion = FindInstruction(module.get(), "fusion.1");
-  EXPECT_TRUE(fusion->has_backend_config());
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          fusion->backend_config<GpuBackendConfig>());
-  EXPECT_EQ(gpu_config.operation_queue_id(), 1);
-}
-
-TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule offloading
-    ENTRY %main (param_0: f32[1024], param_1: f32[1024]) -> f32[1024] {
-    %param_1 = f32[1024]{0} parameter(1)
-    %param_0 = f32[1024]{0} parameter(0)
-    %res_3 = f32[1024]{0} add(f32[1024]{0} %param_0, f32[1024]{0} %param_1)
-    %copy-start = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_3)
-    %res_4 = f32[1024]{0} tanh(f32[1024]{0} %res_3)
-    %copy-start.2 = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_4)
-    %res_5 = f32[1024]{0} tanh(f32[1024]{0} %res_4)
-    %copy-done = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start)
-    %res_6 = f32[1024]{0} tanh(f32[1024]{0} %res_5)
-    %copy-done.2 = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start.2)
-    %copy-start.3 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done.2)
-    %res_7 = f32[1024]{0} add(f32[1024]{0} %res_6, f32[1024]{0} %res_6)
-    %copy-start.1 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done)
-    %res_8 = f32[1024]{0} add(f32[1024]{0} %res_7, f32[1024]{0} %res_5)
-    %copy-done.3 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.3)
-    %res_9 = f32[1024]{0} add(f32[1024]{0} %res_8, f32[1024]{0} %copy-done.3)
-    %copy-done.1 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.1)
-    %res_10 = f32[1024]{0} add(f32[1024]{0} %res_9, f32[1024]{0} %copy-done.1)
-    ROOT %res_11 = f32[1024]{0} tanh(f32[1024]{0} %res_10)
-  }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  StreamAttributeAnnotator attr_annotator;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  for (std::string i : {"", ".1", ".2", ".3"}) {
-    const HloInstruction* cp_start =
-        FindInstruction(module.get(), "copy-start" + i);
-    EXPECT_TRUE(cp_start->has_backend_config());
-    TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                            cp_start->backend_config<GpuBackendConfig>());
-    EXPECT_EQ(gpu_config.operation_queue_id(), 1);
-  }
-}
-
-TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule ModuleWithAsyncDynamicUpdateSlice
-
-  ENTRY entry (param_0: f32[256,128,128], param_1: f32[1,128,128]) -> f32[256,128,128] {
-    param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0)
-    param_1 = f32[1,128,128]{2,1,0} parameter(1)
-    izero = s32[] constant(0)
-    dynamic-update-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, f32[1,128,128]{2,1,0}, s32[], s32[], s32[]), f32[256,128,128]{2,1,0:S(5)}, u32[])
-        dynamic-update-slice-start(param_0, param_1, izero, izero, izero)
-    ROOT dynamic-update-slice-done.2 = f32[256,128,128]{2,1,0:S(5)}
-        dynamic-update-slice-done(dynamic-update-slice-start.2)
-  }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed,
-                          StreamAttributeAnnotator().Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  // Check that the dynamic-update-slice instruction is wrapped in a fusion
-  // and the fusion is annotated with the correct operation_queue_id.
-  const HloInstruction* dus =
-      FindInstruction(module.get(), HloOpcode::kDynamicUpdateSlice);
-  const HloComputation* computation = dus->parent();
-  EXPECT_TRUE(computation->IsFusionComputation());
-  const HloInstruction* fusion = computation->FusionInstruction();
-  EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);
-  EXPECT_TRUE(fusion->parent()->IsAsyncComputation());
-
-  EXPECT_TRUE(fusion->has_backend_config());
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          fusion->backend_config<GpuBackendConfig>());
-  EXPECT_EQ(gpu_config.operation_queue_id(), 1);
-}
-
-TEST_F(StreamAttributeAnnotatorTest, DynamicSliceWrappedAndAnnotated) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule ModuleWithAsyncDynamicSlice
-
-  ENTRY entry (param_0: f32[256,128,128]) -> f32[1,128,128] {
-    param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0)
-    izero = s32[] constant(0)
-    dynamic-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, s32[], s32[], s32[]), f32[1,128,128]{2,1,0}, u32[])
-        dynamic-slice-start(param_0, izero, izero, izero), dynamic_slice_sizes={1,128,128}
-    ROOT dynamic-slice-done.2 = f32[1,128,128]{2,1,0}
-        dynamic-slice-done(dynamic-slice-start.2)
-  }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  TF_ASSERT_OK_AND_ASSIGN(bool changed,
-                          StreamAttributeAnnotator().Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  // Check that the dynamic-slice instruction is wrapped in a fusion
-  // and the fusion is annotated with the correct operation_queue_id.
-  const HloInstruction* ds =
-      FindInstruction(module.get(), HloOpcode::kDynamicSlice);
-  const HloComputation* computation = ds->parent();
-  EXPECT_TRUE(computation->IsFusionComputation());
-  const HloInstruction* fusion = computation->FusionInstruction();
-  EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);
-  EXPECT_TRUE(fusion->parent()->IsAsyncComputation());
-
-  EXPECT_TRUE(fusion->has_backend_config());
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          fusion->backend_config<GpuBackendConfig>());
-  EXPECT_EQ(gpu_config.operation_queue_id(), 1);
-}
-}  // namespace
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.cc b/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.cc
deleted file mode 100644
index 822c647..0000000
--- a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.cc
+++ /dev/null
@@ -1,74 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/stream_attribute_async_wrapper.h"
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/runtime/thunk.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-
-namespace {
-static absl::StatusOr<bool> AsynchronizeInstruction(HloInstruction* instr) {
-  auto instr_gpu_config = instr->backend_config<GpuBackendConfig>();
-  if (!instr_gpu_config.ok() || instr_gpu_config->operation_queue_id() ==
-                                    Thunk::kDefaultExecutionStreamId.value()) {
-    return false;
-  }
-  HloComputation* computation = instr->parent();
-  TF_ASSIGN_OR_RETURN(
-      HloInstruction * done,
-      computation->CreateAsyncInstructions(
-          instr, {}, StreamAttributeAsyncWrapper::kParallelExecutionThread,
-          /*replace=*/true));
-  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
-                      done->backend_config<GpuBackendConfig>());
-  // Clear force_earliest_schedule on the done op so that it can be
-  // scheduled far apart from the start op.
-  gpu_config.set_force_earliest_schedule(false);
-  TF_RETURN_IF_ERROR(done->set_backend_config(gpu_config));
-  VLOG(5) << "Created async instruction: " << done->ToString();
-  return true;
-}
-}  // namespace
-
-absl::StatusOr<bool> StreamAttributeAsyncWrapper::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  XLA_VLOG_LINES(
-      2, "StreamAttributeAsyncWrapper::Run(), before:\n" + module->ToString());
-  bool changed = false;
-  for (const HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    for (HloInstruction* instr : comp->instructions()) {
-      TF_ASSIGN_OR_RETURN(bool result, AsynchronizeInstruction(instr));
-      changed |= result;
-    }
-  }
-  XLA_VLOG_LINES(
-      2, "StreamAttributeAsyncWrapper::Run(), after:\n" + module->ToString());
-  return changed;
-}
-
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.h b/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.h
deleted file mode 100644
index 95fe7bb..0000000
--- a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_
-#define XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla::gpu {
-
-// This pass finds the instructions that were annotated with a
-// non-default stream id in their backend configs by the
-// StreamAttributeAnnotator pass and wraps them in AsyncStart/AsyncDone
-// pairs to achieve asynchronous execution.
-class StreamAttributeAsyncWrapper : public HloModulePass {
- public:
-  inline static constexpr char kParallelExecutionThread[] = "parallel";
-
-  absl::string_view name() const override {
-    return "async-stream-attribute-wrapper";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_
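As a rough sketch of how the two deleted passes compose (annotate first, then wrap annotated instructions into async pairs): the pipeline name and wrapper function below are illustrative assumptions, not taken from this change; the AddPass/Run calls follow the standard HloPassPipeline interface, and the pass headers are the ones removed here.

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/stream_attribute_annotator.h"
#include "xla/service/gpu/stream_attribute_async_wrapper.h"
#include "xla/service/hlo_pass_pipeline.h"

namespace xla::gpu {
// Illustrative wrapper (hypothetical): run the annotator, then the async
// wrapper, so annotated instructions become async-start/async-done pairs.
absl::StatusOr<bool> RunStreamAttributePasses(HloModule* module) {
  HloPassPipeline pipeline("stream-attribute-passes");
  pipeline.AddPass<StreamAttributeAnnotator>();
  pipeline.AddPass<StreamAttributeAsyncWrapper>();
  return pipeline.Run(module);
}
}  // namespace xla::gpu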
diff --git a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper_test.cc b/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper_test.cc
deleted file mode 100644
index 8b3dcb2..0000000
--- a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper_test.cc
+++ /dev/null
@@ -1,77 +0,0 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/stream_attribute_async_wrapper.h"
-
-#include <memory>
-
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-namespace {
-
-using StreamAttributeAsyncWrapperTest = HloTestBase;
-
-TEST_F(StreamAttributeAsyncWrapperTest, NonDefaultOpIsWrapped) {
-  constexpr absl::string_view kHloString = R"(
-  HloModule ModuleWithAsync
-
-  ENTRY entry {
-    p1_32 = f32[1] parameter(0)
-    p2_32 = f32[1] parameter(1)
-    add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[], "force_earliest_schedule":true}
-    ROOT exp_32 = f32[1] exponential(add_32), backend_config={"operation_queue_id":"0", "wait_on_operation_queues":[1]}
-  }
-  )";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  StreamAttributeAsyncWrapper async_wrapper;
-  bool changed;
-  TF_ASSERT_OK_AND_ASSIGN(changed, async_wrapper.Run(module.get()));
-  EXPECT_TRUE(changed);
-  const HloInstruction* producer =
-      module->entry_computation()->root_instruction()->operand(0);
-  EXPECT_EQ(producer->opcode(), HloOpcode::kAsyncDone);
-  // Verify that the force_earliest_schedule is set to false for the done op.
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig done_gpu_config,
-                          producer->backend_config<GpuBackendConfig>());
-  EXPECT_EQ(done_gpu_config.force_earliest_schedule(), false);
-
-  const HloInstruction* producer_start = producer->operand(0);
-  EXPECT_EQ(producer_start->opcode(), HloOpcode::kAsyncStart);
-
-  const xla::HloAsyncInstruction* async =
-      Cast<HloAsyncInstruction>(producer_start);
-  EXPECT_EQ(async->async_wrapped_opcode(), HloOpcode::kAdd);
-  // Verify that the backend config is kept intact
-  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
-                          async->backend_config<GpuBackendConfig>());
-  EXPECT_EQ(gpu_config.operation_queue_id(), 1);
-  EXPECT_EQ(gpu_config.force_earliest_schedule(), true);
-  EXPECT_EQ(async->async_execution_thread(), "parallel");
-}
-}  // namespace
-}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD
index 487b6d9..e18dd57 100644
--- a/third_party/xla/xla/service/gpu/tests/BUILD
+++ b/third_party/xla/xla/service/gpu/tests/BUILD
@@ -162,44 +162,6 @@
     ],
 )
 
-xla_cc_test(
-    name = "gpu_reduce_scatter_creator_test",
-    srcs = ["gpu_reduce_scatter_creator_test.cc"],
-    deps = [
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_module_config",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/service/gpu:gpu_reduce_scatter_creator",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:string_view",
-        "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
-xla_cc_test(
-    name = "gpu_all_gather_optimizer_test",
-    srcs = ["gpu_all_gather_optimizer_test.cc"],
-    deps = [
-        "//xla:util",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_module_config",
-        "//xla/service/gpu:gpu_all_gather_optimizer",
-        "//xla/tests:hlo_test_base",
-        "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test",
-    ],
-)
-
 xla_test(
     name = "gpu_spmd_e2e_compile_test",
     size = "small",
@@ -221,65 +183,6 @@
 )
 
 xla_test(
-    name = "gemm_rewrite_test",
-    srcs = ["gemm_rewrite_test.cc"],
-    backends = ["gpu"],
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
-        "TENSORFLOW_USE_ROCM=1",
-    ]),
-    deps = [
-        ":gpu_codegen_test",
-        "//xla:error_spec",
-        "//xla:test",
-        "//xla:xla_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:buffer_assignment",
-        "//xla/service:executable",
-        "//xla/service:hlo_module_config",
-        "//xla/service:hlo_pass",
-        "//xla/service:pattern_matcher",
-        "//xla/service:pattern_matcher_gmock",
-        "//xla/service/gpu:gemm_rewriter",
-        "//xla/service/gpu:gpu_executable",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:device_memory_allocator",
-        "//xla/stream_executor:stream_executor_memory_allocator",
-        "//xla/tests:filecheck",
-        "//xla/tests:verified_hlo_module",
-        "//xla/tsl/lib/core:status_test_util",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test_main",
-    ] + if_cuda_is_configured([
-        "@local_config_cuda//cuda:cuda_headers",
-    ]) + if_rocm_is_configured([
-        "@local_config_rocm//rocm:rocm_headers",
-    ]),
-)
-
-xla_test(
-    name = "gemm_broadcast_folding_rewrite_test",
-    srcs = ["gemm_broadcast_folding_rewrite_test.cc"],
-    backends = ["gpu"],
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
-        "TENSORFLOW_USE_ROCM=1",
-    ]),
-    deps = [
-        ":gpu_codegen_test",
-        "//xla:error_spec",
-        "//xla/hlo/ir:hlo",
-        "//xla/service/gpu:gemm_broadcast_folding_rewriter",
-        "//xla/service/gpu:gemm_rewriter",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/platform:test",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
-xla_test(
     name = "gpu_too_many_blocks_test",
     srcs = [
         "gpu_too_many_blocks_test.cc",
@@ -297,51 +200,6 @@
     ],
 )
 
-xla_cc_test(
-    name = "reduction_degenerate_dim_remover_test",
-    srcs = [
-        "reduction_degenerate_dim_remover_test.cc",
-    ],
-    deps = [
-        "//xla/service/gpu:reduction_degenerate_dim_remover",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:test",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
-xla_test(
-    name = "reduction_layout_normalizer_test",
-    srcs = [
-        "reduction_layout_normalizer_test.cc",
-    ],
-    backends = ["gpu"],
-    deps = [
-        "//xla:error_spec",
-        "//xla/service/gpu:reduction_layout_normalizer",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:test",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
-xla_cc_test(
-    name = "tree_reduction_rewriter_test",
-    srcs = [
-        "tree_reduction_rewriter_test.cc",
-    ],
-    deps = [
-        "//xla/service/gpu:tree_reduction_rewriter",
-        "//xla/stream_executor:device_description",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/strings",
-        "@local_tsl//tsl/platform:test",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
 xla_test(
     name = "swap_conv_operands_test",
     srcs = [
@@ -375,20 +233,6 @@
     ],
 )
 
-xla_cc_test(
-    name = "reduction_dimension_grouper_test",
-    srcs = [
-        "reduction_dimension_grouper_test.cc",
-    ],
-    deps = [
-        "//xla/service/gpu:reduction_dimension_grouper",
-        "//xla/tests:hlo_test_base",
-        "@com_google_absl//absl/strings:string_view",
-        "@local_tsl//tsl/platform:test",
-        "@local_tsl//tsl/platform:test_main",
-    ],
-)
-
 xla_test(
     name = "parallel_reduction_test",
     srcs = [
@@ -640,7 +484,7 @@
         "//xla/hlo/ir:hlo",
         "//xla/service/gpu:gpu_device_info_for_tests",
         "//xla/service/gpu:gpu_fusible",
-        "//xla/service/gpu:instruction_fusion",
+        "//xla/service/gpu/transforms:instruction_fusion",
         "@com_google_absl//absl/strings:string_view",
         "@local_tsl//tsl/platform:test_main",
     ],
@@ -655,10 +499,10 @@
         "//xla:shape_util",
         "//xla/service:hlo_cost_analysis",
         "//xla/service:hlo_pass_pipeline",
-        "//xla/service/gpu:fusion_merger",
         "//xla/service/gpu:gpu_device_info_for_tests",
-        "//xla/service/gpu:instruction_fusion",
-        "//xla/service/gpu:multi_output_fusion",
+        "//xla/service/gpu/transforms:fusion_merger",
+        "//xla/service/gpu/transforms:instruction_fusion",
+        "//xla/service/gpu/transforms:multi_output_fusion",
         "//xla/stream_executor:device_description",
         "@com_google_absl//absl/strings:string_view",
         "@local_tsl//tsl/platform:test_main",
@@ -857,8 +701,7 @@
 #     name = "xla-opt",
 #     srcs = ["xla-opt.cc"],
 #     deps = [
-#         "//xla/service/gpu/fusions/triton:prevent_mmav3_loop_unrolling",
-#         "//xla/service/gpu/fusions/triton:sparse_extensions",
+#         "//xla/service/gpu/fusions/triton:passes",
 #         "@llvm-project//mlir:AllExtensions",
 #         "@llvm-project//mlir:MlirOptLib",
 #         "@triton//:AllPassesAndDialects",
@@ -972,7 +815,7 @@
         "//xla:shape_util",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
-        "//xla/service/gpu:gpu_sort_rewriter",
+        "//xla/service/gpu/transforms:sort_rewriter",
         "//xla/tests:hlo_test_base",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
diff --git a/third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc b/third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc
deleted file mode 100644
index 7e98e97..0000000
--- a/third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc
+++ /dev/null
@@ -1,229 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <memory>
-
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/gemm_broadcast_folding_rewriter.h"
-#include "xla/service/gpu/gemm_rewriter.h"
-#include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-class GemmBroadcastFoldingRewriteTest : public GpuCodegenTest {
- protected:
-  const auto& GpuComputeComp() {
-    return backend()
-        .default_stream_executor()
-        ->GetDeviceDescription()
-        .gpu_compute_capability();
-  }
-
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
-    // These tests test the cuBLAS rewriter so we have to make sure that we use
-    // cuBLAS for them.
-    debug_options.set_xla_gpu_enable_triton_gemm(false);
-    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
-    return debug_options;
-  }
-};
-
-TEST_F(GemmBroadcastFoldingRewriteTest, BroadcastedStridedRewriteRhs) {
-  const char* hlo_text = R"(
-HloModule BroadcastedInput
-
-ENTRY AddDotsFunc {
-  x = f32[3,2,2]{2,1,0} parameter(0)
-  y = f32[2,2]{1,0} parameter(1)
-  y_broadcast = f32[3,2,2]{2,1,0} broadcast(y), dimensions={1,2}
-  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x, y_broadcast), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2,2], {{.*}}: f32[2,2]) -> f32[3,2,2] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,2,2]{2,1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas${{(lt\$matmul|gemm)}}",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["2"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":["0"]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_F(GemmBroadcastFoldingRewriteTest, BroadcastedStridedRewriteLhs) {
-  const char* hlo_text = R"(
-HloModule BroadcastedInput
-
-ENTRY AddDotsFunc {
-  x = f32[2,2]{1,0} parameter(0)
-  y = f32[3,2,2]{2,1,0} parameter(1)
-  x_broadcast = f32[3,2,2]{2,1,0} broadcast(x), dimensions={1,2}
-  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x_broadcast, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[3,2,2]) -> f32[3,2,2] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,2,2]{2,1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas${{(lt\$matmul|gemm)}}",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":["0"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_F(GemmBroadcastFoldingRewriteTest,
-       BroadcastedStridedRewriteRhsPassChanged) {
-  const char* hlo_text = R"(
-HloModule BroadcastedInput
-
-ENTRY AddDotsFunc {
-  x = f32[3,2,2]{2,1,0} parameter(0)
-  y = f32[2,2]{1,0} parameter(1)
-  y_broadcast = f32[3,2,2]{2,1,0} broadcast(y), dimensions={1,2}
-  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x, y_broadcast), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
-}
-
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  // Use GemmRewriter to generate cublasGemm call.
-  GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed,
-                          this->RunHloPass(&gemm_rewriter, module.get()));
-  EXPECT_TRUE(changed);
-  GemmBroadcastFoldingRewriter pass;
-  TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-}
-
-TEST_F(GemmBroadcastFoldingRewriteTest,
-       BroadcastedStridedRewriteLhsPassChanged) {
-  const char* hlo_text = R"(
-HloModule BroadcastedInput
-
-ENTRY AddDotsFunc {
-  x = f32[2,2]{1,0} parameter(0)
-  y = f32[3,2,2]{2,1,0} parameter(1)
-  x_broadcast = f32[3,2,2]{2,1,0} broadcast(x), dimensions={1,2}
-  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x_broadcast, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
-}
-
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  // Use GemmRewriter to generate cublasGemm call.
-  GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed,
-                          this->RunHloPass(&gemm_rewriter, module.get()));
-  EXPECT_TRUE(changed);
-  GemmBroadcastFoldingRewriter pass;
-  TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-}
-
-TEST_F(GemmBroadcastFoldingRewriteTest, LHSBatchDimNonZero) {
-  const char* hlo_text = R"(
-HloModule LHSBatchDimNonZero
-
-ENTRY %LHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] {
-  %Arg_1 = f32[4,3]{1,0} parameter(0)
-  %Arg_2 = f32[4,7,3]{2,1,0} parameter(1)
-  %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2}
-  ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[7,4,3]{2,1,0} %broadcast.22, f32[4,7,3]{2,1,0} %Arg_2), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}
-}
-)";
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  // Use GemmRewriter to generate cublasGemm call.
-  GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed,
-                          this->RunHloPass(&gemm_rewriter, module.get()));
-  EXPECT_TRUE(changed);
-  GemmBroadcastFoldingRewriter pass;
-  TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_F(GemmBroadcastFoldingRewriteTest, RHSBatchDimNonZero) {
-  const char* hlo_text = R"(
-HloModule RHSBatchDimNonZero
-
-ENTRY %RHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] {
-  %Arg_1 = f32[4,3]{1,0} parameter(0)
-  %Arg_2 = f32[4,7,3]{2,1,0} parameter(1)
-  %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2}
-  ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[4,7,3]{2,1,0} %Arg_2, f32[7,4,3]{2,1,0} %broadcast.22), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={1}, rhs_contracting_dims={2}
-}
-)";
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed,
-                          this->RunHloPass(&gemm_rewriter, module.get()));
-  EXPECT_TRUE(changed);
-  GemmBroadcastFoldingRewriter pass;
-  TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_FALSE(changed);
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc
deleted file mode 100644
index f412f1f..0000000
--- a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc
+++ /dev/null
@@ -1,8313 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <array>
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <optional>
-#include <string>
-#include <tuple>
-#include <utility>
-#include <variant>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_replace.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/error_spec.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/buffer_assignment.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/gemm_rewriter.h"
-#include "xla/service/gpu/gpu_executable.h"
-#include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_memory_allocator.h"
-#include "xla/stream_executor/stream_executor_memory_allocator.h"
-#include "xla/test.h"
-#include "xla/tests/filecheck.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/xla.pb.h"
-#include "tsl/platform/statusor.h"
-
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
-#elif TENSORFLOW_USE_ROCM
-#include "rocm/rocm_config.h"
-#endif
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-namespace m = ::xla::match;
-
-class GemmRewriteTest : public GpuCodegenTest {
-  const auto& device_desc() {
-    return backend().default_stream_executor()->GetDeviceDescription();
-  }
-
- protected:
-  const se::GpuComputeCapability& Capability() {
-    return device_desc().gpu_compute_capability();
-  }
-
-  int32_t GetToolkitVersion() const {
-#if GOOGLE_CUDA
-    return CUDA_VERSION;
-#elif TENSORFLOW_USE_ROCM
-    return TF_ROCM_VERSION;
-#endif
-    return 0;
-  }
-
-  bool IsCuda() {
-    return std::holds_alternative<se::CudaComputeCapability>(Capability());
-  }
-
-  se::GpuComputeCapability CudaHopperOrRocmMI300() {
-    if (IsCuda()) {
-      return se::CudaComputeCapability::Hopper();
-    } else {
-      return se::RocmComputeCapability{"gfx942"};
-    }
-  }
-
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
-    // These tests test the cuBLAS rewriter so we have to make sure that we use
-    // cuBLAS for them.
-    debug_options.set_xla_gpu_enable_triton_gemm(false);
-    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
-    return debug_options;
-  }
-
-  bool SkipGpuBlasLtTest() {
-    return !IsCuda() &&
-           !std::get<se::RocmComputeCapability>(Capability()).has_hipblaslt() &&
-           GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
-  }
-
-  bool HasFp8Support() {
-    if (IsCuda()) {
-      return std::get<se::CudaComputeCapability>(Capability()).IsAtLeast(8, 9);
-    }
-    return std::get<se::RocmComputeCapability>(Capability()).has_fp8_support();
-  }
-
-  bool HasCudaComputeCapability(const se::CudaComputeCapability& cc) {
-    return IsCuda() &&
-           std::get<se::CudaComputeCapability>(Capability()).IsAtLeast(cc);
-  }
-};
-
-TEST_F(GemmRewriteTest, CheckCustomCallTarget) {
-  if (SkipGpuBlasLtTest()) {
-    GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
-  }
-
-  const char* hlo_text = R"(
-HloModule SimpleGemm
-
-ENTRY AddDotsFunc {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-)";
-  DebugOptions debug_options = GetDebugOptionsForTest();
-  if (debug_options.xla_gpu_enable_cublaslt()) {
-    MatchOptimizedHlo(hlo_text,
-                      R"(; CHECK: custom_call_target="__cublas$lt$matmul")");
-  } else {
-    MatchOptimizedHlo(hlo_text,
-                      R"(; CHECK: custom_call_target="__cublas$gemm")");
-  }
-}
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-TEST_F(GemmRewriteTest, TestBatchedAutotuning) {
-  if (HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    GTEST_SKIP()
-        << "There is no autotuning starting with the Nvidia Ampere generation";
-  }
-
-  const char* hlo_text = R"(
-HloModule ComplexDotMultipleNonContracting
-
-ENTRY %test {
-  %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
-  %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
-  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
-}
-
-)";
-
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK: selected_algorithm
-      )");
-}
-#endif
-
-TEST_F(GemmRewriteTest, SimpleRewriteDeterministic) {
-  if (SkipGpuBlasLtTest()) {
-    GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
-  }
-
-  const char* hlo_text = R"(
-HloModule SimpleGemm
-
-ENTRY AddDotsFunc {
-  x = f32[128,128] parameter(0)
-  y = f32[128,128] parameter(1)
-  ROOT dot_a = f32[128,128] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
-  // The error tolerance is the same whether or not cuBLASLt is enabled.
-  ErrorSpec error_spec{1e-3, 1e-3};
-
-  auto get_module = [&]() {
-    HloModuleConfig config;
-    DebugOptions debug_options = GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_exclude_nondeterministic_ops(true);
-    config.set_debug_options(debug_options);
-    return ParseAndReturnVerifiedModule(hlo_text, config);
-  };
-
-  se::StreamExecutorMemoryAllocator allocator(
-      backend().default_stream_executor());
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<HloModule> optimized_module,
-      backend().compiler()->RunHloPasses(
-          *get_module(), backend().default_stream_executor(), &allocator));
-
-  absl::StatusOr<bool> filecheck_result =
-      RunFileCheck(optimized_module->ToString(),
-                   R"(
-; CHECK:    custom_call_target="__cublas${{(lt\$matmul|gemm)}}"
-    )");
-  TF_ASSERT_OK(filecheck_result.status());
-  EXPECT_TRUE(filecheck_result.value());
-  EXPECT_TRUE(RunAndCompare(*get_module(), error_spec));
-}
-
-TEST_F(GemmRewriteTest, BF16GemmCodeGen) {
-  const char* hlo_text = R"(
-HloModule bf16codegendgemm
-
-ENTRY bf16gemm {
-  %parameter.1 = bf16[3]{0} parameter(0)
-  %parameter.2 = bf16[3]{0} parameter(1)
-  ROOT %dot.3 = bf16[] dot(bf16[3]{0} %parameter.1, bf16[3]{0} %parameter.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-}
-  )";
-
-  if (HasCudaComputeCapability(se::CudaComputeCapability::Hopper())) {
-    // The Hopper optimized HLO has a BF16 multiply instruction since Hopper has
-    // native BF16 multiply support.
-    MatchOptimizedHlo(hlo_text, R"(
-  ; CHECK:  [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
-  ; CHECK:  [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
-  ; CHECK:  [[INSTR_2:%[^ ]+]] = bf16[3]{0} multiply([[P0]], [[P1]])
-  ; CHECK:  [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[INSTR_2]])
-  ; CHECK:  [[INSTR_4:%[^ ]+]] = f32[] constant(0)
-  ; CHECK:  [[INSTR_5:%[^ ]+]] = f32[] reduce([[INSTR_3]], [[INSTR_4]]), dimensions={0}, to_apply=[[INSTR_6:%[^ ]+]]
-  ; CHECK:  ROOT [[INSTR_7:%[^ ]+]] = bf16[] convert([[INSTR_5]])
-    )");
-  } else {
-    MatchOptimizedHlo(hlo_text, R"(
-  ; CHECK:  [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
-  ; CHECK:  [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]])
-  ; CHECK:  [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
-  ; CHECK:  [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]])
-  ; CHECK:  [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]])
-  ; CHECK:  [[INSTR_5:%[^ ]+]] = f32[] constant(0)
-  ; CHECK:  [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]]
-  ; CHECK:  ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]])
-    )");
-  }
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-4}));
-}
-
-TEST_F(GemmRewriteTest, BF16Transpose) {
-  const char* hlo_text = R"(
-HloModule broadcast
-
-ENTRY broadcast {
-  p = bf16[9] parameter(0)
-  ROOT out = bf16[1,9] broadcast(p), dimensions={1}
-}
-)";
-
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK: bf16[1,9]{1,0} bitcast
-; CHECK: bf16[1,9]{1,0} copy
-)");
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-}
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-// A test fixture class for tests that should produce similar results with
-// legacy cuBLAS and cuBLASLt.
-class ParameterizedGemmRewriteTest
-    : public GemmRewriteTest,
-      public ::testing::WithParamInterface<bool> {
- public:
-  ParameterizedGemmRewriteTest() {
-    const bool kUsingCublasLt = GetParam();
-    replacements_[kCustomCallTargetPlaceholder] =
-        kUsingCublasLt ? "__cublas$lt$matmul" : "__cublas$gemm";
-  }
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_enable_cublaslt(GetParam());
-    debug_options.set_xla_gpu_enable_triton_gemm(false);
-    return debug_options;
-  }
-  void MatchOptimizedHlo(absl::string_view hlo, const absl::string_view pattern,
-                         bool print_operand_shape = false) {
-    GemmRewriteTest::MatchOptimizedHlo(
-        hlo, absl::StrReplaceAll(pattern, replacements_), print_operand_shape);
-  }
-  absl::string_view CustomCallTarget() {
-    return replacements_[kCustomCallTargetPlaceholder];
-  }
-
- protected:
-  void SetUp() override {
-    if (SkipGpuBlasLtTest()) {
-      GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
-    }
-  }
-
- protected:
-  absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
-
- private:
-  static constexpr const char* kCustomCallTargetPlaceholder{
-      "<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"};
-};
-
-TEST_P(ParameterizedGemmRewriteTest, Simple) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, SimpleRewrite) {
-  const char* hlo_text = R"(
-HloModule SimpleGemm
-
-ENTRY AddDotsFunc {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, MultipleContractingDims) {
-  const char* hlo_text = R"(
-HloModule MultipleContractingCheckGemm
-
-ENTRY AddDotsFunc {
-  x = f32[3,4,2] parameter(0)
-  y = f32[3,4,5] parameter(1)
-  ROOT dot_a = f32[2,5] dot(x, y), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}, operand_precision={highest,highest}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-NOT:     copy
-;
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,4,2], {{.*}}: f32[3,4,5]) -> f32[2,5] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,4,2]{2,1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4,5]{2,1,0} parameter(1)
-; CHECK-DAG:     [[BITCAST0:%[^ ]+]] = f32[2,12]{0,1} bitcast([[P0]])
-; CHECK-DAG:     [[BITCAST1:%[^ ]+]] = f32[12,5]{1,0} bitcast([[P1]])
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[BITCAST0]], [[BITCAST1]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, ArgTransposeFoldCheck) {
-  const char* hlo_text = R"(
-HloModule ArgTransposeFoldGemm
-
-ENTRY AddDotsFunc {
-  x = f32[3,2] parameter(0)
-  y = f32[3,4] parameter(1)
-  x_transposed = f32[2,3] transpose(x), dimensions={1, 0}
-  ROOT dot_a = f32[2,4] dot(x_transposed, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,2]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BatchedArgRowColTransposeFoldCheck) {
-  const char* hlo_text = R"(
-HloModule BatchedArgRowColTransposeFoldGemm
-
-ENTRY AddDotsFunc {
-  x = f32[5,3,2] parameter(0)
-  y = f32[5,3,4] parameter(1)
-  x_transposed = f32[5,2,3] transpose(x), dimensions={0, 2, 1}
-  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,3,2], {{.*}}: f32[5,3,4]) -> f32[5,2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[5,3,2]{2,1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":["0"]
-; CHECK-DAG:           "rhs_batch_dimensions":["0"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BatchRowTransposeFoldCheck) {
-  const char* hlo_text = R"(
-HloModule BatchRowTransposeFoldCheck
-
-ENTRY AddDotsFunc {
-  x = f32[2,5,3] parameter(0)
-  y = f32[5,3,4] parameter(1)
-  x_transposed = f32[5,2,3] transpose(x), dimensions={1, 0, 2}
-  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,5,3], {{.*}}: f32[5,3,4]) -> f32[5,2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,5,3]{2,1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["2"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":["1"]
-; CHECK-DAG:           "rhs_batch_dimensions":["0"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BatchFromMinorDimTransposeIsNotFolded) {
-  const char* hlo_text = R"(
-HloModule BatchFromMinorDimTransposeDoesntFold
-
-ENTRY AddDotsFunc {
-  x = f32[3,2,5] parameter(0)
-  y = f32[5,3,4] parameter(1)
-  x_transposed = f32[5,2,3] transpose(x), dimensions={2, 1, 0}
-  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2,5], {{.*}}: f32[5,3,4]) -> f32[5,2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,2,5]{2,1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-DAG:     [[FUSION:%[^ ]+]] = f32[5,2,3]{2,1,0} transpose([[P0]])
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[FUSION]], [[P1]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["2"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":["0"]
-; CHECK-DAG:           "rhs_batch_dimensions":["0"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, LargeBatch) {
-  const char* hlo_text = R"(
-HloModule BatchedArgRowColTransposeFoldGemm
-
-ENTRY AddDotsFunc {
-  x = f32[20000,4,3,2] parameter(0)
-  y = f32[20000,4,3,4] parameter(1)
-  ROOT dot_a = f32[20000,4,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
-}
-
-)";
-
-  // Batch sizes larger than 2^16-1 are not supported by cublasLt. Ensure that
-  // the custom_call_target is __cublas$gemm.
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[20000,4,3,2], {{.*}}: f32[20000,4,3,4]) -> f32[20000,4,2,4] {
-; CHECK:    [[P0:%[^ ]+]] = f32[20000,4,3,2]{3,2,1,0} parameter(0)
-; CHECK:    [[BC0:%[^ ]+]] = f32[80000,3,2]{2,1,0} bitcast([[P0]])
-; CHECK:    [[P1:%[^ ]+]] = f32[20000,4,3,4]{3,2,1,0} parameter(1)
-; CHECK:    [[BC1:%[^ ]+]] = f32[80000,3,4]{2,1,0} bitcast([[P1]])
-; CHECK:    [[GEMM:%[^ ]+]] = (f32[80000,2,4]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[BC0]], [[BC1]]),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":["0"]
-; CHECK-DAG:           "rhs_batch_dimensions":["0"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK:           }
-; CHECK:   [[OUT:%[^ ]+]] = f32[80000,2,4]{2,1,0} get-tuple-element([[GEMM]]), index=0
-; CHECK:   ROOT {{[^ ]+}} = f32[20000,4,2,4]{3,2,1,0} bitcast([[OUT]])
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, InstrTransposeFoldCheck) {
-  const char* hlo_text = R"(
-HloModule InstrTransposeFoldGemm
-
-ENTRY AddDotsFunc {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[4,2] transpose(dot_a), dimensions={1, 0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[4,2] {
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P1]], [[P0]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BatchedInstrLayoutTransposed) {
-  const char* hlo_text = R"(
-HloModule BatchedInstrLayoutCheck
-
-ENTRY AddDotsFunc {
-  x = f32[5,2,3] parameter(0)
-  y = f32[5,3,4] parameter(1)
-  dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-  ROOT out = f32[2,5,4] transpose(dot_a), dimensions={1, 0, 2}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,2,3], {{.*}}: f32[5,3,4]) -> f32[2,5,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["2"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":["0"]
-; CHECK-DAG:           "rhs_batch_dimensions":["0"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[2,5,4]{2,1,0} bitcast
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BatchedInstrLayoutBatchNotInMinorDim) {
-  const char* hlo_text = R"(
-HloModule BatchedInstrLayoutBatchNotInMinorDim
-
-ENTRY AddDotsFunc {
-  x = f32[5,2,3] parameter(0)
-  y = f32[5,3,4] parameter(1)
-  dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-  ROOT out = f32[2,4,5] transpose(dot_a), dimensions={1, 2, 0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,2,3], {{.*}}: f32[5,3,4]) -> f32[2,4,5] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["2"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":["0"]
-; CHECK-DAG:           "rhs_batch_dimensions":["0"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[2,4,5]{2,1,0} [[OP:[^ ]+]]
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, AlphaSimpleRewrite) {
-  const char* hlo_text = R"(
-HloModule AlphaSimpleRewrite
-
-ENTRY AddDotsFunc {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  k = f32[] constant(3.0)
-  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
-  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  ROOT dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":3
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, F64C64_CublasLtSupportTest) {
-  // This test should fail if the gemm rewriter does not correctly rewrite
-  // F64/C64 dots to cublasLt or legacy cublas calls.
-  {
-    const char* hlo_text = R"(
-HloModule F64_rewrite
-
-ENTRY AddDotsFunc {
-  x = f64[2,2] parameter(0)
-  y = f64[2,2] parameter(1)
-  k = f64[] constant(3.0)
-  k_broadcast = f64[2, 2] broadcast(k), dimensions={}
-  dot_a = f64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT dot_a_multiplied = f64[2, 2] multiply(dot_a, k_broadcast)
-}
-)";
-    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
-  }
-  {
-    const char* hlo_text = R"(
-HloModule C64_rewrite
-
-ENTRY AddDotsFunc {
-  x = c64[2,2] parameter(0)
-  y = c64[2,2] parameter(1)
-  k = c64[] constant((3.0, 3.0))
-  k_broadcast = c64[2, 2] broadcast(k), dimensions={}
-  dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast)
-}
-)";
-    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
-  }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, ComplexAlphaSimpleRewrite) {
-  if (!IsCuda() && GetDebugOptionsForTest().xla_gpu_enable_cublaslt()) {
-    GTEST_SKIP() << "TODO: Unsupported C64 gpublas-lt datatype on ROCM";
-  }
-  const char* hlo_text = R"(
-HloModule ComplexAlphaSimpleRewrite
-
-ENTRY AddDotsFunc {
-  x = c64[2,2] parameter(0)
-  y = c64[2,2] parameter(1)
-  k = c64[] constant((3.0, 3.0))
-  k_broadcast = c64[2, 2] broadcast(k), dimensions={}
-  dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: c64[2,2], {{.*}}: c64[2,2]) -> c64[2,2] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = c64[2,2]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = c64[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":3
-; CHECK-DAG:         "alpha_imag":3
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, AlphaMultipleUsersNoRewrite) {
-  const char* hlo_text = R"(
-HloModule AlphaMultipleUsersNoRewrite
-
-ENTRY AddDotsFunc {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  k = f32[] constant(3.0)
-  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
-  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
-  ROOT out = f32[2,2] add(dot_a_multiplied, dot_a)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK:    {{[^ ]+}} = {{.*}} custom-call({{[^,]+}}, {{[^)]+}}),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, AlphaVectorNoRewrite) {
-  const char* hlo_text = R"(
-HloModule AlphaVectorNoRewrite
-
-ENTRY AddDotsFunc {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  alpha = f32[2] constant({1, 2})
-  alpha_broadcast = f32[2,2] broadcast(alpha), dimensions={1}
-  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT dot_a_multiplied = f32[2, 2] multiply(dot, alpha_broadcast)
-}
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BF16Gemm) {
-  const char* hlo_text = R"(
-HloModule bf16gemm
-
-ENTRY bf16gemm {
-  %parameter.1 = bf16[12,4]{1,0} parameter(0)
-  %parameter.2 = bf16[4,8]{1,0} parameter(1)
-  ROOT %dot.8 = bf16[12,8] dot(bf16[12,4] %parameter.1, bf16[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-  )";
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
-  if (!IsCuda() ||
-      HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    MatchOptimizedHlo(hlo_text,
-                      R"(
-; CHECK: {{.*}} custom-call(bf16[16,8]{1,0} {{.*}}, bf16[8,8]{1,0} {{.*}}), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
-  )",
-                      /*print_operand_shape=*/true);
-  } else {
-    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
-  }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, BF16GemmStrided) {
-  const char* hlo_text = R"(
-HloModule bf16gemm
-
-ENTRY bf16gemm {
-  %parameter.1 = bf16[3,3,4] parameter(0)
-  %parameter.2 = bf16[3,3,2] parameter(1)
-  ROOT %dot.3 = bf16[3,4,2]{2,1,0} dot(bf16[3,3,4]{2,1,0} %parameter.1, bf16[3,3,2]{2,1,0} %parameter.2), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,highest}
-}
-
-  )";
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
-  if (!IsCuda() ||
-      HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    MatchOptimizedHlo(hlo_text,
-                      R"(
-    ; CHECK: {{.*}} custom-call(bf16[3,8,8]{2,1,0} {{.*}}, bf16[3,8,8]{2,1,0} {{.*}}), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
-    )",
-                      /*print_operand_shape=*/true);
-  } else {
-    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
-  }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, Int8Gemm) {
-  const char* hlo_text = R"(
-HloModule int8gemm
-
-ENTRY int8gemm {
-  %parameter.1 = s8[12,4]{1,0} parameter(0)
-  %parameter.2 = s8[4,8]{1,0} parameter(1)
-  ROOT %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-  )";
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
-  if (!IsCuda() ||
-      HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
-    MatchOptimizedHlo(hlo_text,
-                      R"(
-; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
-  )",
-                      /*print_operand_shape=*/true);
-  } else {
-    MatchOptimizedHlo(hlo_text,
-                      R"(
-; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
-  )",
-                      /*print_operand_shape=*/true);
-  }
-}
-
-TEST_F(GemmRewriteTest, Int8GemmRankGreaterThanTwo) {
-  if (!IsCuda()) {
-    GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm";
-  }
-
-  const char* hlo_text = R"(
-HloModule int8gemm
-
-ENTRY main.4 {
-  Arg_0.1 = s8[1,8,2]{2,1,0} parameter(0)
-  Arg_1.2 = s8[2,4]{1,0} parameter(1)
-  ROOT dot.3 = s32[1,8,4]{2,1,0} dot(Arg_0.1, Arg_1.2),
-  lhs_contracting_dims={2}, rhs_contracting_dims={0}
-}
-  )";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
-  if (!IsCuda() ||
-      HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
-    MatchOptimizedHlo(hlo_text,
-                      R"(
-; CHECK: [[GEMM:%[^ ]+]] = (s32[8,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(s8[8,4]{1,0} %{{.*}}, s8[4,4]{0,1} %{{.*}}), custom_call_target="__cublas$gemm",
-  )",
-                      /*print_operand_shape=*/true);
-  }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoAlphaRewrite) {
-  const char* hlo_text = R"(
-HloModule int8gemm
-
-ENTRY int8gemm {
-  %parameter.1 = s8[12,4]{1,0} parameter(0)
-  %parameter.2 = s8[4,8]{1,0} parameter(1)
-  k = s32[] constant(2)
-  k_broadcast = s32[12,8] broadcast(k), dimensions={}
-  %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT dot_multiplied = s32[12,8] multiply(%dot.8, k_broadcast)
-}
-  )";
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
-  if (!IsCuda() ||
-      HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
-    MatchOptimizedHlo(hlo_text,
-                      R"(
-; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           backend_config={
-; CHECK-DAG:       "alpha_real":1
-; CHECK-DAG:       "alpha_imag":0
-  )",
-                      /*print_operand_shape=*/true);
-  } else {
-    MatchOptimizedHlo(hlo_text,
-                      R"(
-; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
-  )",
-                      /*print_operand_shape=*/true);
-  }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoBetaRewrite) {
-  const char* hlo_text = R"(
-HloModule int8gemm
-
-ENTRY int8gemm {
-  %parameter.1 = s8[12,4]{1,0} parameter(0)
-  %parameter.2 = s8[4,8]{1,0} parameter(1)
-  bias = s32[12,8] parameter(2)
-  %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = s32[12,8] add(%dot.8, bias)
-}
-  )";
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
-  if (!IsCuda() ||
-      HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
-    MatchOptimizedHlo(hlo_text,
-                      R"(
-; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           backend_config={
-; CHECK-DAG:       "alpha_real":1
-; CHECK-DAG:       "alpha_imag":0
-; CHECK-DAG:       "beta":0
-  )",
-                      /*print_operand_shape=*/true);
-  } else {
-    MatchOptimizedHlo(hlo_text,
-                      R"(
-; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
-  )",
-                      /*print_operand_shape=*/true);
-  }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, Int8GemmNotMultipleOfFour) {
-  if (!IsCuda()) {
-    GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm";
-  }
-
-  const char* hlo_text = R"(
-HloModule int8gemm
-
-ENTRY int8gemm {
-  %parameter.1 = s8[13,4]{1,0} parameter(0)
-  %parameter.2 = s8[4,9]{1,0} parameter(1)
-  ROOT %dot.9 = s32[13,9] dot(s8[13,4] %parameter.1, s8[4,9] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-  )";
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-
-  if (!IsCuda() ||
-      HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
-    MatchOptimizedHlo(hlo_text,
-                      R"(
-; CHECK: {{.*}} custom-call(s8[16,4]{1,0} [[A:%[^ ]+]], s8[4,12]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
-  )",
-                      /*print_operand_shape=*/true);
-  } else {
-    MatchOptimizedHlo(hlo_text,
-                      R"(
-; CHECK: {{.*}} dot(s32[13,4]{1,0} [[A:%[^ ]+]], s32[4,9]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-
-  )",
-                      /*print_operand_shape=*/true);
-  }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, GemmTypeCombinationCheck) {
-  if (!IsCuda()) {
-    GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm";
-  }
-
-  std::vector<std::tuple<absl::string_view, absl::string_view, bool>>
-      type_combinations = {{"s8", "s8", true},
-                           {"s32", "s32", true},
-                           {"bf16", "bf16", true},
-                           {"f16", "f16", true},
-                           {"f32", "f32", true},
-                           {"f64", "f64", true},
-                           {"c64", "c64", true},
-                           {"c128", "c128", true},
-                           // add mixed-type gemm cases
-                           {"s8", "s32", true},
-                           {"s8", "f32", true},
-                           {"f16", "f32", true},
-                           {"bf16", "f32", true}};
-
-  if (!IsCuda() ||
-      HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    // For compute capabilities before Ampere, we may do upcasting, so it
-    // would be impossible for this test to fail. That is why we only add these
-    // cases when the compute capability is at least Ampere.
-    std::vector<std::tuple<absl::string_view, absl::string_view, bool>>
-        more_type_combinations = {
-            {"s8", "bf16", false},  {"s8", "f16", false},
-            {"s8", "f64", false},   {"s8", "c64", false},
-            {"s8", "c128", false},
-
-            {"s32", "f32", false},  {"s32", "f64", false},
-            {"s32", "c64", false},  {"s32", "c128", false},
-
-            {"f16", "bf16", false}, {"f16", "f64", false},
-            {"f16", "c64", false},  {"f16", "c128", false},
-
-            {"bf16", "f16", false}, {"bf16", "f64", false},
-            {"bf16", "c64", false}, {"bf16", "c128", false},
-
-            {"f32", "f64", false},  {"f32", "c64", false},
-            {"f32", "c128", false},
-
-            {"f64", "c64", false},  {"f64", "c128", false},
-        };
-    type_combinations.insert(type_combinations.end(),
-                             more_type_combinations.begin(),
-                             more_type_combinations.end());
-  }
-
-  for (const auto& type_combination : type_combinations) {
-    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
-    replacements["<<ABType>>"] = std::get<0>(type_combination);
-    replacements["<<DType>>"] = std::get<1>(type_combination);
-    const char* hlo_template = R"(
-  HloModule type_combo
-
-  ENTRY type_combo {
-    %parameter.1 = <<ABType>>[4,4]{1,0} parameter(0)
-    %parameter.2 = <<ABType>>[4,4]{1,0} parameter(1)
-    ROOT %dot = <<DType>>[4,4] dot(%parameter.1, %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  }
-    )";
-    const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
-    if (std::get<2>(type_combination)) {
-      EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-    } else {
-      EXPECT_FALSE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-    }
-  }
-}
-
-TEST_P(ParameterizedGemmRewriteTest, UpcastingBf16ToF64) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  Arg_0.1 = bf16[4,3]{1,0} parameter(0)
-  Arg_1.2 = bf16[3,6]{1,0} parameter(1)
-  ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  // This is a type combination which is not supported by cublasLt, so we
-  // expect GemmRewriter to choose legacy cublas.
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, UpcastingC64ToC128) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  Arg_0.1 = c64[4,3]{1,0} parameter(0)
-  Arg_1.2 = c64[3,6]{1,0} parameter(1)
-  ROOT dot.3 = c128[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  // This is a type combination which is not supported by cublasLt, so we
-  // expect GemmRewriter to choose legacy cublas.
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF32) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  Arg_0.1 = f16[4,3]{1,0} parameter(0)
-  Arg_1.2 = f16[3,6]{1,0} parameter(1)
-  ROOT dot.3 = f32[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest, highest}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0)));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF64) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  Arg_0.1 = f16[4,3]{1,0} parameter(0)
-  Arg_1.2 = f16[3,6]{1,0} parameter(1)
-  ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  // This is a type combination which is not supported by cublasLt, so we
-  // expect GemmRewriter to choose legacy cublas.
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, UpcastingF32ToF64) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  Arg_0.1 = f32[4,3]{1,0} parameter(0)
-  Arg_1.2 = f32[3,6]{1,0} parameter(1)
-  ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  // This is a type combination which is not supported by cublasLt, so we
-  // expect GemmRewriter to choose legacy cublas.
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, DoNotUpconvertOutput) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY main {
-  param_0 = f16[240,88]{1,0} parameter(0)
-  param_1 = f16[88,4]{1,0} parameter(1)
-  dot = f16[240,4]{1,0} dot(param_0, param_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  constant_255 = f16[] constant(255)
-  broadcast = f16[240,4]{1,0} broadcast(constant_255), dimensions={}
-  multiply = f16[240,4]{1,0} multiply(dot, broadcast)
-  ROOT result = f32[240,4]{1,0} convert(multiply)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  // The fp16-input/fp32-output combination is supported by both legacy cublas
-  // and cublasLt, so we expect GemmRewriter to fuse the convert into the gemm.
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Convert(
-                  m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0))));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, UnsupportedMixTypeGemm) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY main {
-  param_0 = f32[240,88]{1,0} parameter(0)
-  param_1 = f32[88,4]{1,0} parameter(1)
-  dot = f32[240,4]{1,0} dot(param_0, param_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  constant_255 = f32[] constant(255)
-  broadcast = f32[240,4]{1,0} broadcast(constant_255), dimensions={}
-  multiply = f32[240,4]{1,0} multiply(dot, broadcast)
-  ROOT result = u8[240,4]{1,0} convert(multiply)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  // u8 is supported by neither legacy cublas nor cublasLt, so we expect
-  // GemmRewriter not to fuse the convert into the gemm.
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Convert(
-                  m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0))));
-}
-
-TEST_P(ParameterizedGemmRewriteTest, CheckIsGemmAliasedBeforeFusion) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY main {
-  Arg_0.1 = f16[8,16]{1,0} parameter(0)
-  Arg_1.2 = f16[16,32]{1,0} parameter(1)
-  dot.8 = f16[8,32]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  Arg_2.3 = f16[8,32]{1,0} parameter(2)
-  constant.5 = f16[] constant(1)
-  broadcast.6 = f16[8,32]{1,0} broadcast(constant.5), dimensions={}
-  add.7 = f16[8,32]{1,0} add(Arg_2.3, broadcast.6)
-  add.9 = f16[8,32]{1,0} add(dot.8, add.7)
-  convert.10 = f32[8,32]{1,0} convert(add.9)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  // The fp16-input/fp32-output combination is supported by both legacy cublas
-  // and cublasLt, but the gemm output is already aliased with one of the
-  // inputs, so we expect GemmRewriter not to fuse the convert into the gemm.
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Convert(
-                  m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0))));
-}
-
-INSTANTIATE_TEST_SUITE_P(CublasTestsBothLegacyAndLt,
-                         ParameterizedGemmRewriteTest, ::testing::Bool());
-#endif
-
-// A test fixture class for tests which are specific to legacy cublas
-class LegacyCublasGemmRewriteTest : public GemmRewriteTest {
- public:
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_enable_triton_gemm(false);
-    debug_options.set_xla_gpu_enable_cublaslt(false);
-    return debug_options;
-  }
-};
-
-TEST_F(LegacyCublasGemmRewriteTest, MatrixVectorMultiplication) {
-  const char* hlo_text = R"(
-HloModule m
-
-ENTRY e {
-  p0 = f32[2048] parameter(0)
-  p1 = f32[2048, 16384] parameter(1)
-  ROOT d = f32[16384] dot(p0, p1),
-    lhs_contracting_dims={0}, rhs_contracting_dims={0}
-})";
-
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(
-          se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
-          /*toolkit_version=*/12040),
-      R"(
-; CHECK:  %[[P0:.+]] = f32[2048]{0} parameter(0)
-; CHECK:  %[[P1:.+]] = f32[2048,16384]{1,0} parameter(1)
-; CHECK:  %[[CUSTOM_CALL:.+]] = (f32[16384]{0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]]), custom_call_target="__cublas$gemm"
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, MatrixVectorMultiplicationWithBatch) {
-  const char* hlo_text = R"(
-HloModule m
-
-ENTRY e {
-  p0 = f32[10, 10, 2048] parameter(0)
-  p1 = f32[10, 10, 2048, 16384] parameter(1)
-  ROOT d = f32[10, 10, 16384] dot(p0, p1),
-   lhs_batch_dims={0, 1}, rhs_batch_dims={0, 1},
-   lhs_contracting_dims={2}, rhs_contracting_dims={2}
-})";
-
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(
-          se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
-          /*toolkit_version=*/12040),
-      R"(
-; CHECK:  %[[P0:.+]] = f32[10,10,2048]{2,1,0} parameter(0)
-; CHECK:  %[[P1:.+]] = f32[10,10,2048,16384]{3,2,1,0} parameter(1)
-; CHECK:  %[[CUSTOM_CALL:.+]] = (f32[10,10,16384]{2,1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]]), custom_call_target="__cublas$gemm"
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, SparseDotNotSupported) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY main {
-  lhs = f16[5,16] parameter(0)
-  rhs = f16[32,10] parameter(1)
-  meta = u16[5,2] parameter(2)
-  ROOT dot = f32[5,10] dot(lhs, rhs, meta),
-      lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4
-})";
-  auto hlo_pass = GemmRewriter(
-      se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
-      /*toolkit_version=*/12040);
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&hlo_pass, module.get()));
-  EXPECT_FALSE(changed);
-}
-
-// Test that the alpha and beta fields of the GemmBackendConfig are updated.
-// A bias must be present for the beta value to be set.
-// In order to have a bias add fused, the bias term must be overwritable.
-// We assume that we may not overwrite parameters of a computation. Hence, we
-// use the third parameter to create a new value which can be overwritten and
-// will be used as the bias. This negate(param_2) has no semantic use; it
-// simply exists so that the bias may be overwritten.
-TEST_F(LegacyCublasGemmRewriteTest, AlphaBetaRewrite) {
-  const char* hlo_text = R"(
-HloModule NonZeroAlphaBeta
-
-ENTRY AddDotsFunc {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  param_2 = f32[2,2] parameter(2)
-  bias = f32[2,2] negate(param_2)
-  k = f32[] constant(3.0)
-  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
-  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
-  ROOT out = f32[2,2] add(dot_a_multiplied, bias)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-DAG:     [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG:     [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK:         [[O:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], {{[^,)]+}}),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           output_to_operand_aliasing={
-; CHECK-SAME:        {0}: (2, {})
-; CHECK-SAME:      }
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":3
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[O]]), index=0
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, BiasMultipleUsersNoOverwrite) {
-  const char* hlo_text = R"(
-HloModule BiasMultipleUsersNoOverwrite
-
-ENTRY AddDotsFunc {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  bias = f32[2,2] parameter(2)
-  k = f32[] constant(3.0)
-  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
-  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
-  biased_out = f32[2,2] add(dot_a_multiplied, bias)
-  ROOT out = f32[2,2] add(biased_out, bias)
-}
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    [[CUSTOM_CALL:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":3
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, BiasParameterNoOverwrite) {
-  const char* hlo_text = R"(
-HloModule BiasParameterNoOverwrite
-
-ENTRY AddDotsFunc {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  bias = f32[2,2] parameter(2)
-  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[2,2] add(dot_a, bias)
-}
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, BiasTupleParameterOverwrite) {
-  const char* hlo_text = R"(
-HloModule BiasTupleParameterOverwrite
-
-ENTRY AddDotsFunc {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  param_2 = (f32[2,2], f32[3,3]) parameter(2)
-  bias = f32[2,2] get-tuple-element(param_2), index=0
-  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[2,2] add(dot_a, bias)
-}
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: (f32[2,2], f32[3,3])) -> f32[2,2] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-DAG:     [[P2:%[^ ]+]] = (f32[2,2]{1,0}, f32[3,3]{1,0}) parameter(2)
-; CHECK-DAG:     [[BIAS:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[P2]]), index=0
-; CHECK-DAG:     [[BIAS_COPY:%[^ ]+]] = f32[2,2]{1,0} copy([[BIAS]])
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS_COPY]]),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           output_to_operand_aliasing={
-; CHECK-SAME:        {0}: (2, {})
-; CHECK-SAME:      }
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, AliasedBiasOverwrite) {
-  const char* hlo_text = R"(
-HloModule AliasedBiasOverwrite, input_output_alias={ {}: (2, {}, must-alias) }
-
-ENTRY AddDotsFunc {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  bias = f32[2,2] parameter(2)
-  k = f32[] constant(3.0)
-  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
-  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
-  ROOT out = f32[2,2] add(dot_a_multiplied, bias)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-DAG:     [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG:     [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-DAG:     [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
-; CHECK:         [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           output_to_operand_aliasing={
-; CHECK-SAME:        {0}: (2, {})
-; CHECK-SAME:      }
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":3
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, LargerBiasMultipleUsersNoRewrite) {
-  const char* hlo_text = R"(
-HloModule LargerBiasMultipleUsersNoRewrite
-
-ENTRY AddDotsFunc {
-  x = f32[1024,1024] parameter(0)
-  y = f32[1024,1024] parameter(1)
-  bias = f32[1024,1024] parameter(2)
-  dot_a = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  biased_out = f32[1024,1024] add(dot_a, bias)
-  ROOT out = f32[1024,1024] add(biased_out, bias)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024]) -> f32[1024,1024] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[1024,1024]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-// In order to have a bias add fused, the bias term must be overwritable.
-// We assume that we may not overwrite parameters of a computation. Hence, we
-// use the third parameter to create a new value which can be overwritten and
-// will be used as the bias. This negate(param_2) has no semantic use; it
-// simply exists so that the bias may be overwritten.
-TEST_F(LegacyCublasGemmRewriteTest, BF16GemmWithBias) {
-  const char* hlo_text = R"(
-HloModule BF16GemmWithBias
-
-ENTRY BF16GemmWithBias {
-  x = bf16[8,8]{1,0} parameter(0)
-  y = bf16[8,8]{1,0} parameter(1)
-  dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  param_2 = bf16[8,8]{1,0} parameter(2)
-  bias = bf16[8,8]{1,0} negate(param_2)
-  ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)
-}
-  )";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2e-3, 2e-3}));
-
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
-  }
-
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %BF16GemmWithBias ({{.*}}: bf16[8,8], {{.*}}: bf16[8,8], {{.*}}: bf16[8,8]) -> bf16[8,8] {
-; CHECK-DAG:    [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
-; CHECK-DAG:    [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
-; CHECK:        [[GEMM:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], {{[^,)]+}}),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           output_to_operand_aliasing={
-; CHECK-SAME:        {0}: (2, {})
-; CHECK-SAME:      }
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
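-
-// A minimal sketch (inferred from the CHECK lines in the test above) of the
-// rewrite that the overwritability comment describes: the freshly computed
-// bias feeds the custom call as operand 2, and the output aliases that
-// operand so cuBLAS can accumulate into it in place with beta=1:
-//
-//   bias = bf16[8,8]{1,0} negate(param_2)
-//   gemm = (bf16[8,8]{1,0}, s8[...]{0}) custom-call(x, y, bias),
-//            custom_call_target="__cublas$gemm",
-//            output_to_operand_aliasing={{0}: (2, {})}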
-
-// In order to have a bias add fused, the bias term must be overwritable.
-// We assume that we may not overwrite parameters of a computation. Hence, we
-// use the third parameter to create a new value that can be overwritten and
-// will be used as the bias. This negate(param_2) has no semantic use; it
-// simply exists so that the bias may be overwritten.
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBias) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  param_2 = f32[2,4] parameter(2)
-  bias = f32[2,4] negate(param_2)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[2,4] add(dot_a, bias)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK:         [[GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], {{[^,)]+}}),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           output_to_operand_aliasing={
-; CHECK-SAME:        {0}: (2, {})
-; CHECK-SAME:      }
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasWhereBiasIsNotAParameter) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  w = f32[2,3] parameter(0)
-  x = f32[3,4] parameter(1)
-  first_dot = f32[2,4] dot(w, x), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  y = f32[2,3] parameter(2)
-  z = f32[3,4] parameter(3)
-  second_dot = f32[2,4] dot(y, z), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[2,4] add(second_dot, first_dot)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-DAG:     [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2)
-; CHECK-DAG:     [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3)
-; CHECK-NEXT:    [[FIRST_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:         [[FIRST_GEMM_OUT:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[FIRST_GEMM]]), index=0
-; CHECK-NEXT:    [[SECOND_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P2]], [[P3]], [[FIRST_GEMM_OUT]]),
-; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           output_to_operand_aliasing={
-; CHECK-SAME:        {0}: (2, {})
-; CHECK-SAME:      }
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-// Test gemm matrix bias add fusion with mixed types.
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixType) {
-  std::vector<std::tuple<absl::string_view, absl::string_view>>
-      type_combinations = {
-          {"f16", "f32"},
-          {"bf16", "f32"},
-      };
-
-  const char* hlo_text_template = R"(
-HloModule test
-
-ENTRY test {
-  x = <<ABType>>[16,32] parameter(0)
-  y = <<ABType>>[32,16] parameter(1)
-  z = <<DType>>[16,16] parameter(2)
-  dot_a = <<ABType>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bias = <<DType>>[16,16] negate(z)
-  convert = <<DType>>[16,16] convert(dot_a)
-  ROOT out = <<DType>>[16,16] add(convert, bias)
-}
-
-)";
-  for (const auto& type_combination : type_combinations) {
-    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
-    replacements["<<ABType>>"] = std::get<0>(type_combination);
-    replacements["<<DType>>"] = std::get<1>(type_combination);
-    const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
-    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
-    if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
-        !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-      continue;
-    }
-
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
-                            GetOptimizedModule(hlo_text));
-    EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
-                GmockMatch(m::GetTupleElement(
-                    m::CustomCall(m::Parameter(0), m::Parameter(1),
-                                  m::Negate(m::Parameter(2))),
-                    0)));
-  }
-}
-
-// Test batched gemm matrix bias add fusion with mixed types.
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeBatched) {
-  std::vector<std::tuple<absl::string_view, absl::string_view>>
-      type_combinations = {
-          {"f16", "f32"},
-          {"bf16", "f32"},
-      };
-
-  const char* hlo_text_template = R"(
-HloModule test
-
-ENTRY test {
-  x = <<ABType>>[4,16,32] parameter(0)
-  y = <<ABType>>[4,32,16] parameter(1)
-  z = <<DType>>[4,16,16] parameter(2)
-  dot_a = <<ABType>>[4,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-  bias = <<DType>>[4,16,16] negate(z)
-  convert = <<DType>>[4,16,16] convert(dot_a)
-  ROOT out = <<DType>>[4,16,16] add(convert, bias)
-})";
-  for (const auto& type_combination : type_combinations) {
-    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
-    replacements["<<ABType>>"] = std::get<0>(type_combination);
-    replacements["<<DType>>"] = std::get<1>(type_combination);
-    const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
-    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
-    if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
-        !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-      continue;
-    }
-
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
-                            GetOptimizedModule(hlo_text));
-    EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
-                GmockMatch(m::GetTupleElement(
-                    m::CustomCall(m::Parameter(0), m::Parameter(1),
-                                  m::Negate(m::Parameter(2))),
-                    0)));
-  }
-}
-#endif
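-
-// A rough sketch (assembled from the matchers above) of the supported
-// mixed-type rewrite: the convert(dot) + add pair collapses into a single
-// custom call that consumes the wider-typed bias directly and yields the
-// bias type as element 0 of its result tuple:
-//
-//   bias = f32[16,16] negate(z)
-//   gemm = (f32[16,16], s8[...]) custom-call(x, y, bias)
-//   out  = f32[16,16] get-tuple-element(gemm), index=0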
-
-// Test gemm matrix bias add fusion with a mixed type that is not supported.
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeNotSupported) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[16,32] parameter(0)
-  y = bf16[32,16] parameter(1)
-  z = f64[16,16] parameter(2)
-  dot_a = bf16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bias = f64[16,16] negate(z)
-  convert = f64[16,16] convert(dot_a)
-  ROOT out = f64[16,16] add(convert, bias)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
-                          GetOptimizedModule(hlo_text));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK:        %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$gemm
-; CHECK:        %[[gte:.*]] = {{.*}} get-tuple-element{{.*}}%[[custom_call]]
-; CHECK:        ROOT {{.*}} fusion({{.*}}%[[gte]]
-)");
-}
-
-// Test gemm matrix bias add fusion with mixed types that is not supported
-// because the bias add has other consumers.
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeAddWithMoreConsumers) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[16,32] parameter(0)
-  y = bf16[32,16] parameter(1)
-  z = f32[16,16] parameter(2)
-  dot_a = bf16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bias = f32[16,16] negate(z)
-  convert = f32[16,16] convert(dot_a)
-  add_bias = f32[16,16] add(convert, bias)
-  ROOT out = f32[16,16] negate(add_bias)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
-  }
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
-                          GetOptimizedModule(hlo_text));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK:        %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$gemm
-; CHECK:        %[[gte:.*]] = {{.*}} get-tuple-element{{.*}}%[[custom_call]]
-; CHECK:        ROOT {{.*}} fusion({{.*}}%[[gte]]
-)");
-}
-
-TEST_F(LegacyCublasGemmRewriteTest, MergeBitcastAndAdd) {
-  const char* hlo_text = R"(
-HloModule test
-ENTRY test {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  bias = f32[4] parameter(2)
-  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[4] add(f32[4] bitcast(dot), bias)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Bitcast(
-              m::GetTupleElement(
-                  m::CustomCall(
-                      {"__cublas$gemm"}, m::Parameter(0), m::Parameter(1),
-                      m::Bitcast(m::Parameter(2)).WithShape(F32, {2, 2})),
-                  0))
-              .WithShape(F32, {4})));
-}
-
-// In order to have a bias add fused, the bias term must be overwritable.
-// We assume that we may not overwrite parameters of a computation. Hence, we
-// use the third parameter to create a new value that can be overwritten and
-// will be used as the bias. This negate(param_2) has no semantic use; it
-// simply exists so that the bias may be overwritten.
-TEST_F(LegacyCublasGemmRewriteTest, FoldConstantBias) {
-  const char* hlo_text = R"(
-HloModule test
-ENTRY test {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
-
-  dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  param_2 = f32[2,2] parameter(2)
-  bias1 = f32[2,2] negate(param_2)
-  sum1 = add(dot1, bias1)
-
-  dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  sum2 = add(dot2, f32[2,2] reshape(bias))
-
-  dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bias3 = f32[2,2] transpose(bias), dimensions={1,0}
-  sum3 = add(dot3, bias3)
-
-  dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  sum4 = add(dot4, f32[2,2] bitcast(bias))
-
-  ROOT root = tuple(sum1, sum2, sum3, sum4)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  SCOPED_TRACE(module->ToString());
-  EXPECT_TRUE(changed);
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(m::CustomCall(m::Parameter(0), m::Parameter(1),
-                                           m::Negate(m::Parameter(2))),
-                             0),
-          m::GetTupleElement(
-              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
-              0),
-          m::GetTupleElement(
-              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
-              0),
-          m::GetTupleElement(
-              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
-              0))));
-}
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-// A test fixture class for tests that are specific to cuBLASLt.
-class CublasLtGemmRewriteTest : public GemmRewriteTest {
- public:
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_enable_cublaslt(true);
-    debug_options.set_xla_gpu_enable_triton_gemm(false);
-    return debug_options;
-  }
-
- protected:
-  void SetUp() override {
-    if (SkipGpuBlasLtTest()) {
-      GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
-    }
-  }
-};
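-
-// For reference, a hedged sketch of applying the same toggles to a module
-// outside this fixture (the two setters mirror the fixture above; the rest
-// is illustrative only and not taken from this test file):
-//
-//   DebugOptions opts = GetDebugOptionsFromFlags();
-//   opts.set_xla_gpu_enable_cublaslt(true);      // route GEMMs to cublasLt
-//   opts.set_xla_gpu_enable_triton_gemm(false);  // avoid Triton claiming them
-//   module->mutable_config().set_debug_options(opts);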
-
-TEST_F(CublasLtGemmRewriteTest, AlphaBetaRewrite) {
-  const char* hlo_text = R"(
-HloModule NonZeroAlphaBeta
-
-ENTRY AddDotsFunc {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  bias = f32[2,2] parameter(2)
-  k = f32[] constant(3.0)
-  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
-  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
-  ROOT out = f32[2,2] add(dot_a_multiplied, bias)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-DAG:     [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG:     [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-DAG:     [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":3
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK-NEXT:  ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[GEMM]]), index=0
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BiasMultipleUsersNoOverwrite) {
-  const char* hlo_text = R"(
-HloModule BiasMultipleUsersNoOverwrite
-
-ENTRY AddDotsFunc {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  bias = f32[2,2] parameter(2)
-  k = f32[] constant(3.0)
-  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
-  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
-  biased_out = f32[2,2] add(dot_a_multiplied, bias)
-  ROOT out = f32[2,2] add(biased_out, bias)
-}
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-DAG:     [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK-NOT:       output_to_operand_aliasing
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":3
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, LargerBiasMultipleUsersNoRewrite) {
-  const char* hlo_text = R"(
-HloModule LargerBiasMultipleUsersNoRewrite
-
-ENTRY AddDotsFunc {
-  x = f32[1024,1024] parameter(0)
-  y = f32[1024,1024] parameter(1)
-  bias = f32[1024,1024] parameter(2)
-  dot_a = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  biased_out = f32[1024,1024] add(dot_a, bias)
-  ROOT out = f32[1024,1024] add(biased_out, bias)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024]) -> f32[1024,1024] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1)
-; CHECK-DAG:     [[BIAS:%[^ ]+]] = f32[1024,1024]{1,0} parameter(2)
-; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f32[1024,1024]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK-NEXT:  [[GEMM:%[^ ]+]] = f32[1024,1024]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK-NEXT:  ROOT [[OUT:%[^ ]+]] = f32[1024,1024]{1,0} add([[GEMM]], [[BIAS]])
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BF16GemmWithBias) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY BF16GemmWithBias {
-  x = bf16[8,8]{1,0} parameter(0)
-  y = bf16[8,8]{1,0} parameter(1)
-  dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bias = bf16[8,8]{1,0} parameter(2)
-  ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)
-}
-  )";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
-  }
-
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %BF16GemmWithBias ({{.*}}: bf16[8,8], {{.*}}: bf16[8,8], {{.*}}: bf16[8,8]) -> bf16[8,8] {
-; CHECK-DAG:    [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
-; CHECK-DAG:    [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
-; CHECK-DAG:    [[BIAS:%[^ ]+]] = bf16[8,8]{1,0} parameter(2)
-; CHECK-NEXT:   [[GEMM:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, MatrixBias) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[2,4] parameter(2)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[2,4] add(dot_a, z)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasWhereBiasIsNotAParameter) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  w = f32[2,3] parameter(0)
-  x = f32[3,4] parameter(1)
-  first_dot = f32[2,4] dot(w, x), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  y = f32[2,3] parameter(2)
-  z = f32[3,4] parameter(3)
-  second_dot = f32[2,4] dot(y, z), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[2,4] add(second_dot, first_dot)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-DAG:     [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2)
-; CHECK-DAG:     [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3)
-; CHECK-NEXT:    [[FIRST_GEMM_TUPLE:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:         [[FIRST_GEMM:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[FIRST_GEMM_TUPLE]]), index=0
-; CHECK-NEXT:    [[SECOND_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P2]], [[P3]], [[FIRST_GEMM]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           output_to_operand_aliasing={
-; CHECK:              {0}: (2, {})
-; CHECK:           }
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBias) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[4] parameter(2)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f32[2,4] broadcast(z), dimensions={1}
-  ROOT out = f32[2,4] add(dot_a, z_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS"
-; CHECK:           }
-)");
-}
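-
-// Sketch of the pattern recognized just above: a bias vector broadcast along
-// the output's last dimension folds into the matmul's "BIAS" epilogue rather
-// than staying as a separate add, roughly
-//
-//   z_bcast = f32[2,4] broadcast(z), dimensions={1}
-//   out     = f32[2,4] add(dot_a, z_bcast)
-// becoming
-//   (f32[2,4], s8[...]) custom-call(x, y, z), custom_call_target="__cublas$lt$matmul",
-//     backend_config={..., "epilogue":"BIAS", ...}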
-
-// Epilogue fusion is disabled when the GEMM has multiple users.
-TEST_F(CublasLtGemmRewriteTest, VectorBiasMultipleUsers) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[4,4] parameter(0)
-  y = f32[4,4] parameter(1)
-  z = f32[4] parameter(2)
-  c = f32[] constant(5)
-  dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  z_bcast = f32[4,4] broadcast(z), dimensions={1}
-  add_a = f32[4,4] add(dot_a, z_bcast)
-  c_bcast = f32[4,4] broadcast(c), dimensions={}
-  dot_b = f32[4,4] dot(dot_a, c_bcast), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  ROOT out = f32[4,4] dot(add_a, dot_b), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK:        [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[4,4], [[DUMMY1:[^ ]+]]: f32[4]) -> f32[4,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,4]{1,0} broadcast([[P1]]), dimensions={1}
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[4,4]{1,0} add([[P0]], [[P2]])
-}
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,4], {{.*}}: f32[4,4], {{.*}}: f32[4]) -> f32[4,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[MATMUL0_TUPLE:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK-NEXT:    [[MATMUL0:%[^ ]+]] = f32[4,4]{1,0} get-tuple-element([[MATMUL0_TUPLE]]), index=0
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT:    [[FUSION:%[^ ]+]] = f32[4,4]{1,0} fusion([[MATMUL0]], [[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]]
-; CHECK:         [[MATMUL1_TUPLE:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[MATMUL0]]
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK-NEXT:    [[MATMUL1:%[^ ]+]] = f32[4,4]{1,0} get-tuple-element([[MATMUL1_TUPLE]]), index=0
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[FUSION]], [[MATMUL1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BatchedVectorBias) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3,4] parameter(0)
-  y = f32[4,5,6] parameter(1)
-  z = f32[3,5,6] parameter(2)
-  dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  z_bcast = f32[2,3,5,6] broadcast(z), dimensions={1,2,3}
-  ROOT out = f32[2,3,5,6] add(dot_a, z_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[3,5,6]) -> f32[2,3,5,6] {
-; CHECK:         [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           output_to_operand_aliasing={
-; CHECK:              {0}: (2, {})
-; CHECK:           }
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK-NEXT:   [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BatchedSharedVectorBias) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3,4] parameter(0)
-  y = f32[4,5,6] parameter(1)
-  z = f32[6] parameter(2)
-  dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  z_bcast = f32[2,3,5,6] broadcast(z), dimensions={3}
-  ROOT out = f32[2,3,5,6] add(dot_a, z_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[6]) -> f32[2,3,5,6] {
-; CHECK:         [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           output_to_operand_aliasing={
-; CHECK:              {0}: (2, {})
-; CHECK:           }
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:         [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasIncorrectAxisFusedAsMatrix) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[2] parameter(2)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f32[2,4] broadcast(z), dimensions={0}
-  add = f32[2,4] add(dot_a, z_bcast)
-  ROOT out = f32[4,2] transpose(add), dimensions={1,0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[4,2] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
-; CHECK-NEXT:    [[MATMUL_TUPLE:%[^ ]+]] = (f32[2,4]{0,1}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS"
-; CHECK:           }
-; CHECK-NEXT:    [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]])
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasSliced) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[4,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[3] parameter(2)
-  dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  slice_a = f32[2,3] slice(dot_a), slice={[0:2], [0:3]}
-  z_bcast = f32[2,3] broadcast(z), dimensions={1}
-  ROOT out = f32[2,3] add(slice_a, z_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,3], {{.*}}: f32[3,4], {{.*}}: f32[3]) -> f32[2,3] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[4,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[3]{0} parameter(2)
-; CHECK-NEXT:    [[MATMUL:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS"
-; CHECK:           }
-; CHECK-NEXT:    [[GETTUPLE:%[^ ]+]] = f32[4,4]{1,0} get-tuple-element([[MATMUL]]), index=0
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,3]{1,0} slice([[GETTUPLE]]), slice={[0:2], [0:3]}
-      )");
-}
-
-// Epilogue fusion is disabled when the slice has multiple users.
-TEST_F(CublasLtGemmRewriteTest, VectorBiasSlicedMultipleUsers) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[2] parameter(2)
-  c = f32[] constant(5)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  slice_a = f32[2,2] slice(dot_a), slice={[0:2], [0:2]}
-  z_bcast = f32[2,2] broadcast(z), dimensions={1}
-  add_a = f32[2,2] add(slice_a, z_bcast)
-  c_bcast = f32[2,2] broadcast(c), dimensions={}
-  dot_b = f32[2,2] dot(slice_a, c_bcast), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[2,2] dot(add_a, dot_b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[2,2] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-DAG:     [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
-; CHECK-NEXT:    [[MATMUL0_TUPLE:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:         [[MATMUL1_TUPLE:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:         [[MATMUL1:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[MATMUL1_TUPLE]]), index=0
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call{{.*}}[[MATMUL1]]
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasTransposed) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[2] parameter(2)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f32[2,4] parameter(3)
-  ROOT out = f32[2,4] add(dot_a, z_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2_BCAST:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2_BCAST]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasThenMatrixBias) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[4] parameter(2)
-  z2 = f32[2,4] parameter(3)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f32[2,4] broadcast(z), dimensions={1}
-  add0 = f32[2,4] add(dot_a, z_bcast)
-  ROOT add1 = f32[2,4] add(add0, z2)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> f32[2,4] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-DAG:     [[VECTOR_BIAS:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-DAG:     [[MATRIX_BIAS:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[MATRIX_BIAS]], [[VECTOR_BIAS]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS"
-; CHECK:           }
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BF16VectorBias) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[16,24] parameter(0)
-  y = bf16[24,32] parameter(1)
-  z = bf16[32] parameter(2)
-  dot_a = bf16[16,32] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = bf16[16,32] broadcast(z), dimensions={1}
-  ROOT out = bf16[16,32] add(dot_a, z_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{3e-3, 1e-3}));
-
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
-  }
-
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: bf16[16,24], {{.*}}: bf16[24,32], {{.*}}: bf16[32]) -> bf16[16,32] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = bf16[16,24]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = bf16[24,32]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = bf16[32]{0} parameter(2)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (bf16[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS"
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BF16VectorBiasPadded) {
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    GTEST_SKIP() << "Padding of GEMM bf16 operands only implemented on "
-                    "architectures with bf16 Tensor Cores.";
-  }
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[2,3] parameter(0)
-  y = bf16[3,4] parameter(1)
-  z = bf16[4] parameter(2)
-  dot_a = bf16[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = bf16[2,4] broadcast(z), dimensions={1}
-  ROOT out = bf16[2,4] add(dot_a, z_bcast)
-})";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4], {{.*}}: bf16[4]) -> bf16[2,4] {
-; CHECK-DAG:    bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5
-; CHECK-DAG:    bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ReluActivation) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  c = f32[] constant(0)
-  c_bcast = f32[2,4] broadcast(c), dimensions={}
-  ROOT out = f32[2,4] maximum(dot_a, c_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"RELU"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BatchedReluActivation) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3,4] parameter(0)
-  y = f32[4,5,6] parameter(1)
-  dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  c = f32[] constant(0)
-  c_bcast = f32[2,3,5,6] broadcast(c), dimensions={}
-  ROOT out = f32[2,3,5,6] maximum(dot_a, c_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6]) -> f32[2,3,5,6] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0)
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1)
-; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0}
-; CHECK-NEXT:    [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"RELU"
-; CHECK:           }
-; CHECK:         [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ReluActivationSliced) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  c = f32[] constant(0)
-  c_bcast = f32[2,2] broadcast(c), dimensions={}
-  slice_a = f32[2,2] slice(dot_a), slice={[0:2], [0:2]}
-  ROOT out = f32[2,2] maximum(slice_a, c_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,2] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[MATMUL_TUPLE:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"RELU"
-; CHECK:           }
-; CHECK:         [[MATMUL:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} slice([[MATMUL]]), slice={[0:2], [0:2]}
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivation) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[2,4] parameter(2)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  add = f32[2,4] add(dot_a, z)
-  c = f32[] constant(0)
-  c_bcast = f32[2,4] broadcast(c), dimensions={}
-  ROOT out = f32[2,4] maximum(add, c_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"RELU"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, SquareMatrixBiasReluActivation) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[4,4] parameter(0)
-  y = f32[4,4] parameter(1)
-  z = f32[4,4] parameter(2)
-  dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  add = f32[4,4] add(dot_a, z)
-  c = f32[] constant(0)
-  c_bcast = f32[4,4] broadcast(c), dimensions={}
-  ROOT out = f32[4,4] maximum(add, c_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,4], {{.*}}: f32[4,4], {{.*}}: f32[4,4]) -> f32[4,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,4]{1,0} parameter(2)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"RELU"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivation) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[4] parameter(2)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f32[2,4] broadcast(z), dimensions={1}
-  add = f32[2,4] add(dot_a, z_bcast)
-  c = f32[] constant(0)
-  c_bcast = f32[2,4] broadcast(c), dimensions={}
-  ROOT out = f32[2,4] maximum(add, c_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS_RELU"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, BatchedVectorBiasReluActivation) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3,4] parameter(0)
-  y = f32[4,5,6] parameter(1)
-  z = f32[3,5,6] parameter(2)
-  dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  z_bcast = f32[2,3,5,6] broadcast(z), dimensions={1,2,3}
-  add = f32[2,3,5,6] add(dot_a, z_bcast)
-  c = f32[] constant(0)
-  c_bcast = f32[2,3,5,6] broadcast(c), dimensions={}
-  ROOT out = f32[2,3,5,6] maximum(add, c_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[3,5,6]) -> f32[2,3,5,6] {
-; CHECK:         [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"RELU"
-; CHECK:           }
-; CHECK-NEXT:    [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
-      )");
-}
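
The f32[6,30] matmul shape checked above is consistent with the rewriter collapsing the batched dot into a single GEMM and bitcasting the result back to f32[2,3,5,6]. A worked sketch of that shape arithmetic, inferred only from the shapes appearing in the test:

  lhs f32[2,3,4]            ->  reshape to f32[2*3, 4]  = f32[6,4]
  rhs f32[4,5,6]            ->  reshape to f32[4, 5*6]  = f32[4,30]
  dot(f32[6,4], f32[4,30])  ->  f32[6,30]               ->  bitcast to f32[2,3,5,6]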
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasTransposedReluActivation) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[2] parameter(2)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f32[2,4] broadcast(z), dimensions={0}
-  add = f32[2,4] add(dot_a, z_bcast)
-  c = f32[] constant(0)
-  c_bcast = f32[2,4] broadcast(c), dimensions={}
-  maximum = f32[2,4] maximum(add, c_bcast)
-  ROOT out = f32[4,2] transpose(maximum), dimensions={1,0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[4,2] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
-; CHECK-NEXT:    [[MATMUL_TUPLE:%[^ ]+]] = (f32[2,4]{0,1}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:       "alpha_real":1
-; CHECK-DAG:       "alpha_imag":0
-; CHECK-DAG:       "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS_RELU"
-; CHECK:           }
-; CHECK-NEXT:    [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} get-tuple-element([[MATMUL_TUPLE]]), index=0
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]])
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasThenMatrixBiasReluActivation) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z_vec = f32[4] parameter(2)
-  z_matrix = f32[2,4] parameter(3)
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f32[2,4] broadcast(z_vec), dimensions={1}
-  add0 = f32[2,4] add(dot_a, z_bcast)
-  add1 = f32[2,4] add(add0, z_matrix)
-  c = f32[] constant(0)
-  c_bcast = f32[2,4] broadcast(c), dimensions={}
-  ROOT out = f32[2,4] maximum(add1, c_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> f32[2,4] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-DAG:     [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P3]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS_RELU"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivation) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  mul.0 = f32[2,4] multiply(dot, dot)
-  mul.1 = f32[2,4] multiply(dot, mul.0)
-  const.0 = f32[] constant(0.044715)
-  bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
-  mul.2 = f32[2,4] multiply(mul.1, bcast.0)
-  add.0 = f32[2,4] add(dot, mul.2)
-  const.1 = f32[] constant(0.797884583)
-  bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
-  mul.3 = f32[2,4] multiply(add.0, bcast.1)
-  tanh = f32[2,4] tanh(mul.3)
-  const.2 = f32[] constant(1)
-  bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
-  add.2 = f32[2,4] add(tanh, bcast.2)
-  const.3 = f32[] constant(0.5)
-  bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
-  mul.4 = f32[2,4] multiply(add.2, bcast.3)
-  ROOT out = f32[2,4] multiply(dot, mul.4)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"GELU"
-; CHECK:           }
-      )");
-}
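
The pattern recognized here appears to be the tanh approximation of GELU: the constants in the HLO line up with sqrt(2/pi) ~= 0.7978845608 and the usual cubic coefficient 0.044715. A minimal reference sketch in plain C++ (the function name is ours, not part of the test file):

#include <cmath>

// Tanh-based GELU approximation corresponding to the HLO graph above:
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
float ApproxGelu(float x) {
  const float kSqrt2OverPi = 0.7978845608f;  // 0.797884583 in the HLO (f32).
  const float kCoeff = 0.044715f;
  return 0.5f * x * (1.0f + std::tanh(kSqrt2OverPi * (x + kCoeff * x * x * x)));
}

The next test perturbs the 0.044715 coefficient to 0.05, which is enough to break this exact-constant match, so no GELU epilogue is emitted there.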
-
-TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWrongConstant) {
-  // Modify one constant slightly so that the GELU pattern no longer matches.
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  mul.0 = f32[2,4] multiply(dot, dot)
-  mul.1 = f32[2,4] multiply(dot, mul.0)
-  const.0 = f32[] constant(0.05)
-  bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
-  mul.2 = f32[2,4] multiply(mul.1, bcast.0)
-  add.0 = f32[2,4] add(dot, mul.2)
-  const.1 = f32[] constant(0.797884583)
-  bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
-  mul.3 = f32[2,4] multiply(add.0, bcast.1)
-  tanh = f32[2,4] tanh(mul.3)
-  const.2 = f32[] constant(1)
-  bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
-  add.2 = f32[2,4] add(tanh, bcast.2)
-  const.3 = f32[] constant(0.5)
-  bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
-  mul.4 = f32[2,4] multiply(add.2, bcast.3)
-  ROOT out = f32[2,4] multiply(dot, mul.4)
-}
-
-)";
-
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-NOT: GELU
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivation) {
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60000
-  auto rocm_switch = false;  // GELU is only available from ROCM 6.0
-#else
-  auto rocm_switch = true;
-#endif
-  if (!IsCuda() && rocm_switch) {
-    GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM";
-  }
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[4] parameter(2)
-  dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f32[2,4] broadcast(z), dimensions={1}
-  add = f32[2,4] add(dot, z_bcast)
-  mul.0 = f32[2,4] multiply(add, add)
-  mul.1 = f32[2,4] multiply(add, mul.0)
-  const.0 = f32[] constant(0.044715)
-  bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
-  mul.2 = f32[2,4] multiply(mul.1, bcast.0)
-  add.0 = f32[2,4] add(add, mul.2)
-  const.1 = f32[] constant(0.797884583)
-  bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
-  mul.3 = f32[2,4] multiply(add.0, bcast.1)
-  tanh = f32[2,4] tanh(mul.3)
-  const.2 = f32[] constant(1)
-  bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
-  add.2 = f32[2,4] add(tanh, bcast.2)
-  const.3 = f32[] constant(0.5)
-  bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
-  mul.4 = f32[2,4] multiply(add.2, bcast.3)
-  ROOT out = f32[2,4] multiply(add, mul.4)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS_GELU"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWithAux) {
-  if (!IsCuda()) {
-    GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM";
-  }
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  mul.0 = f32[2,4] multiply(dot, dot)
-  mul.1 = f32[2,4] multiply(dot, mul.0)
-  const.0 = f32[] constant(0.044715)
-  bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
-  mul.2 = f32[2,4] multiply(mul.1, bcast.0)
-  add.0 = f32[2,4] add(dot, mul.2)
-  const.1 = f32[] constant(0.797884583)
-  bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
-  mul.3 = f32[2,4] multiply(add.0, bcast.1)
-  tanh = f32[2,4] tanh(mul.3)
-  const.2 = f32[] constant(1)
-  bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
-  add.2 = f32[2,4] add(tanh, bcast.2)
-  const.3 = f32[] constant(0.5)
-  bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
-  mul.4 = f32[2,4] multiply(add.2, bcast.3)
-  mul.5 = f32[2,4] multiply(dot, mul.4)
-  ROOT out = (f32[2,4], f32[2,4]) tuple(mul.5, dot)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> (f32[2,4], f32[2,4]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"GELU_AUX"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivationWithAux) {
-  if (!IsCuda()) {
-    GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM";
-  }
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[4] parameter(2)
-  dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f32[2,4] broadcast(z), dimensions={1}
-  add = f32[2,4] add(dot, z_bcast)
-  mul.0 = f32[2,4] multiply(add, add)
-  mul.1 = f32[2,4] multiply(add, mul.0)
-  const.0 = f32[] constant(0.044715)
-  bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
-  mul.2 = f32[2,4] multiply(mul.1, bcast.0)
-  add.0 = f32[2,4] add(add, mul.2)
-  const.1 = f32[] constant(0.797884583)
-  bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
-  mul.3 = f32[2,4] multiply(add.0, bcast.1)
-  tanh = f32[2,4] tanh(mul.3)
-  const.2 = f32[] constant(1)
-  bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
-  add.2 = f32[2,4] add(tanh, bcast.2)
-  const.3 = f32[] constant(0.5)
-  bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
-  mul.4 = f32[2,4] multiply(add.2, bcast.3)
-  mul.5 = f32[2,4] multiply(add, mul.4)
-  ROOT out = (f32[2,4], f32[2,4]) tuple(mul.5, add)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> (f32[2,4], f32[2,4]) {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS_GELU_AUX"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationBF16) {
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    GTEST_SKIP() << "Padding of GEMM bf16 operands only implemented on "
-                    "architectures with bf16 Tensor Cores.";
-  }
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[2,3] parameter(0)
-  y = bf16[3,4] parameter(1)
-  dot = bf16[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  mul.0 = bf16[2,4] multiply(dot, dot)
-  mul.1 = bf16[2,4] multiply(dot, mul.0)
-  const.0 = bf16[] constant(0.044715)
-  bcast.0 = bf16[2,4] broadcast(const.0), dimensions={}
-  mul.2 = bf16[2,4] multiply(mul.1, bcast.0)
-  add.0 = bf16[2,4] add(dot, mul.2)
-  const.1 = bf16[] constant(0.797884583)
-  bcast.1 = bf16[2,4] broadcast(const.1), dimensions={}
-  mul.3 = bf16[2,4] multiply(add.0, bcast.1)
-  tanh = bf16[2,4] tanh(mul.3)
-  const.2 = bf16[] constant(1)
-  bcast.2 = bf16[2,4] broadcast(const.2), dimensions={}
-  add.2 = bf16[2,4] add(tanh, bcast.2)
-  const.3 = bf16[] constant(0.5)
-  bcast.3 = bf16[2,4] broadcast(const.3), dimensions={}
-  mul.4 = bf16[2,4] multiply(add.2, bcast.3)
-  ROOT out = bf16[2,4] multiply(dot, mul.4)
-})";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{5e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4]) -> bf16[2,4] {
-; CHECK-DAG:    bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5
-; CHECK-DAG:    bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationBitcast) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  dot_bitcast = f32[2,2,2] bitcast(dot)
-  mul.0 = f32[2,2,2] multiply(dot_bitcast, dot_bitcast)
-  mul.1 = f32[2,2,2] multiply(dot_bitcast, mul.0)
-  const.0 = f32[] constant(0.044715)
-  bcast.0 = f32[2,2,2] broadcast(const.0), dimensions={}
-  mul.2 = f32[2,2,2] multiply(mul.1, bcast.0)
-  add.0 = f32[2,2,2] add(dot_bitcast, mul.2)
-  const.1 = f32[] constant(0.797884583)
-  bcast.1 = f32[2,2,2] broadcast(const.1), dimensions={}
-  mul.3 = f32[2,2,2] multiply(add.0, bcast.1)
-  tanh = f32[2,2,2] tanh(mul.3)
-  const.2 = f32[] constant(1)
-  bcast.2 = f32[2,2,2] broadcast(const.2), dimensions={}
-  add.2 = f32[2,2,2] add(tanh, bcast.2)
-  const.3 = f32[] constant(0.5)
-  bcast.3 = f32[2,2,2] broadcast(const.3), dimensions={}
-  mul.4 = f32[2,2,2] multiply(add.2, bcast.3)
-  ROOT out = f32[2,2,2] multiply(dot_bitcast, mul.4)
-}
-
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Bitcast(m::GetTupleElement(
-                         m::CustomCall({"__cublas$lt$matmul"},
-                                       m::Parameter(0).WithShape(F32, {2, 3}),
-                                       m::Parameter(1).WithShape(F32, {3, 4})),
-                         0))
-              .WithShape(F32, {2, 2, 2})));
-}
-
-// For F16, the sizes of all dimensions of the operands are required to be
-// multiples of 8 to allow matrix bias fusion.
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasF16) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f16[8,16] parameter(0)
-  y = f16[16,8] parameter(1)
-  z = f16[8,8] parameter(2)
-  dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f16[8,8] add(dot_a, z)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f16[8,16], {{.*}}: f16[16,8], {{.*}}: f16[8,8]) -> f16[8,8] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasF32UnpaddedWithBitcast) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3]{1,0} parameter(0)
-  y = f32[3,4]{1,0} parameter(1)
-  z = f32[2]{0} parameter(2)
-  dot_a = f32[2,4]{0,1} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bitc = f32[4,2]{1,0} bitcast(f32[2,4]{0,1} dot_a)
-  z_bcast = f32[4,2] broadcast(z), dimensions={1}
-  ROOT add = f32[4,2]{1,0} add(bitc, z_bcast)
-}
-
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Bitcast(m::GetTupleElement(
-                         m::CustomCall({"__cublas$lt$matmul"}, m::Parameter(0),
-                                       m::Parameter(1),
-                                       m::Parameter(2).WithShape(F32, {2})),
-                         0)
-                         .WithShape(F32, {2, 4}))
-              .WithShape(F32, {4, 2})));
-}
-
-// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
-// newer architectures) so that the sizes of all dimensions are multiples of 8.
-TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Unpadded) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f16[8,16] parameter(0)
-  y = f16[16,8] parameter(1)
-  z = f16[8] parameter(2)
-  dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f16[8,8] broadcast(z), dimensions={1}
-  ROOT add = f16[8,8] add(dot_a, z_bcast)
-})";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3}));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-NOT:  pad("
-; CHECK:      custom-call
-; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Padded) {
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
-    GTEST_SKIP() << "Padding of GEMM operands only implemented on "
-                    "architectures with Tensor Cores.";
-  }
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f16[6,12] parameter(0)
-  y = f16[12,6] parameter(1)
-  z = f16[6] parameter(2)
-  dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f16[6,6] broadcast(z), dimensions={1}
-  ROOT add = f16[6,6] add(dot_a, z_bcast)
-})";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
-; CHECK-DAG:    f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
-; CHECK-DAG:    f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
-      )");
-}
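
A small sketch of the padding arithmetic the CHECK lines above rely on, assuming the rewriter rounds every GEMM dimension up to the next multiple of 8 (the helper name is ours, not part of the test file):

#include <cstdint>

// Trailing padding needed to round `dim` up to a multiple of 8, the
// granularity used for f16 operands on Tensor Core GPUs in these tests.
int64_t PadAmount(int64_t dim) { return (8 - dim % 8) % 8; }

// For the f16[6,12] x f16[12,6] case above: PadAmount(6) == 2 and
// PadAmount(12) == 4, i.e. f16[6,12] is padded to f16[8,16] with
// padding=0_2x0_4 and f16[12,6] to f16[16,8] with padding=0_4x0_2.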
-
-// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
-// newer architectures) so that the sizes of all dimensions are multiples of 8.
-TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Unpadded) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f16[8,16] parameter(0)
-  y = f16[16,8] parameter(1)
-  dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  c = f16[] constant(0)
-  c_bcast = f16[8,8] broadcast(c), dimensions={}
-  ROOT out = f16[8,8] maximum(dot_a, c_bcast)
-})";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-NOT:  pad("
-; CHECK:      custom-call
-; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Padded) {
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
-    GTEST_SKIP() << "Padding of GEMM operands only implemented on "
-                    "architectures with Tensor Cores.";
-  }
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f16[6,12] parameter(0)
-  y = f16[12,6] parameter(1)
-  dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  c = f16[] constant(0)
-  c_bcast = f16[6,6] broadcast(c), dimensions={}
-  ROOT out = f16[6,6] maximum(dot_a, c_bcast)
-})";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6]) -> f16[6,6] {
-; CHECK-DAG:    f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
-; CHECK-DAG:    f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivationF16) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f16[8,16] parameter(0)
-  y = f16[16,8] parameter(1)
-  z = f16[8,8] parameter(2)
-  dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  add = f16[8,8] add(dot_a, z)
-  c = f16[] constant(0)
-  c_bcast = f16[8,8] broadcast(c), dimensions={}
-  ROOT out = f16[8,8] maximum(add, c_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f16[8,16], {{.*}}: f16[16,8], {{.*}}: f16[8,8]) -> f16[8,8] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"RELU"
-; CHECK:           }
-      )");
-}
-
-// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
-// newer architectures) so that the sizes of all dimensions are multiples of 8.
-TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Unpadded) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f16[8,16] parameter(0)
-  y = f16[16,8] parameter(1)
-  z = f16[8] parameter(2)
-  dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f16[8,8] broadcast(z), dimensions={1}
-  add = f16[8,8] add(dot_a, z_bcast)
-  c = f16[] constant(0)
-  c_bcast = f16[8,8] broadcast(c), dimensions={}
-  ROOT out = f16[8,8] maximum(add, c_bcast)
-})";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-NOT:  pad("
-; CHECK:      custom-call
-; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
-)");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Padded) {
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
-    GTEST_SKIP() << "Padding of GEMM operands only implemented on "
-                    "architectures with Tensor Cores.";
-  }
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f16[6,12] parameter(0)
-  y = f16[12,6] parameter(1)
-  z = f16[6] parameter(2)
-  dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f16[6,6] broadcast(z), dimensions={1}
-  add = f16[6,6] add(dot_a, z_bcast)
-  c = f16[] constant(0)
-  c_bcast = f16[6,6] broadcast(c), dimensions={}
-  ROOT out = f16[6,6] maximum(add, c_bcast)
-})";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
-; CHECK-DAG:   f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
-; CHECK-DAG:   f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
-      )");
-}
-
-// For bfloat16, the sizes of all dimensions of the operands are required to be
-// multiples of 8 to allow matrix bias fusion.
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasBF16) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[8,16] parameter(0)
-  y = bf16[16,8] parameter(1)
-  z = bf16[8,8] parameter(2)
-  dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = bf16[8,8] add(dot_a, z)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
-  }
-
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: bf16[8,16], {{.*}}: bf16[16,8], {{.*}}: bf16[8,8]) -> bf16[8,8] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1)
-; CHECK-DAG:     [[P2:%[^ ]+]] = bf16[8,8]{1,0} parameter(2)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasBitcastBF16) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[8,16] parameter(0)
-  y = bf16[16,8] parameter(1)
-  bias = bf16[2,4,8] parameter(2)
-  dot = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bitcast = bf16[2,4,8] bitcast(dot)
-  ROOT out = bf16[2,4,8] add(bitcast, bias)
-}
-
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Bitcast(
-              m::GetTupleElement(
-                  m::CustomCall(
-                      {"__cublas$lt$matmul"},
-                      m::Parameter(0).WithShape(BF16, {8, 16}),
-                      m::Parameter(1).WithShape(BF16, {16, 8}),
-                      m::Bitcast(m::Parameter(2)).WithShape(BF16, {8, 8})),
-                  0))
-              .WithShape(BF16, {2, 4, 8})));
-}
-
-// For bfloat16, the operands are padded if necessary on Ampere and newer
-// architectures so that the sizes of all dimensions are multiples of 8.
-TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Unpadded) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[8,16] parameter(0)
-  y = bf16[16,8] parameter(1)
-  z = bf16[8] parameter(2)
-  dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = bf16[8,8] broadcast(z), dimensions={1}
-  ROOT add = bf16[8,8] add(dot_a, z_bcast)
-})";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3}));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-NOT:  pad("
-; CHECK:      custom-call
-; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Padded) {
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on "
-                    "Ampere and newer architectures.";
-  }
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[6,12] parameter(0)
-  y = bf16[12,6] parameter(1)
-  z = bf16[6] parameter(2)
-  dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = bf16[6,6] broadcast(z), dimensions={1}
-  ROOT add = bf16[6,6] add(dot_a, z_bcast)
-})";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG:  ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
-; CHECK-DAG:    bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
-; CHECK-DAG:    bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
-      )");
-}
-
-// For bfloat16, the operands are padded if necessary on Ampere and newer
-// architectures so that the sizes of all dimensions are multiples of 8.
-TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Unpadded) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[8,16] parameter(0)
-  y = bf16[16,8] parameter(1)
-  dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  c = bf16[] constant(0)
-  c_bcast = bf16[8,8] broadcast(c), dimensions={}
-  ROOT out = bf16[8,8] maximum(dot_a, c_bcast)
-}
-
-)";
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-NOT:  pad("
-; CHECK:      custom-call
-; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Padded) {
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on "
-                    "Ampere and newer architectures.";
-  }
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[6,12] parameter(0)
-  y = bf16[12,6] parameter(1)
-  dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  c = bf16[] constant(0)
-  c_bcast = bf16[6,6] broadcast(c), dimensions={}
-  ROOT out = bf16[6,6] maximum(dot_a, c_bcast)
-})";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6]) -> bf16[6,6] {
-; CHECK-DAG:     bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
-; CHECK-DAG:     bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
-      )");
-}
-
-// For bfloat16, the operands are padded if necessary on Ampere and newer
-// architectures so that the sizes of all dimensions are multiples of 8.
-TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Unpadded) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[8,16] parameter(0)
-  y = bf16[16,8] parameter(1)
-  z = bf16[8] parameter(2)
-  dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = bf16[8,8] broadcast(z), dimensions={1}
-  add = bf16[8,8] add(dot_a, z_bcast)
-  c = bf16[] constant(0)
-  c_bcast = bf16[8,8] broadcast(c), dimensions={}
-  ROOT out = bf16[8,8] maximum(add, c_bcast)
-})";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-NOT:  pad("
-; CHECK:      custom-call
-; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Padded) {
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on "
-                    "Ampere and newer architectures.";
-  }
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[6,12] parameter(0)
-  y = bf16[12,6] parameter(1)
-  z = bf16[6] parameter(2)
-  dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = bf16[6,6] broadcast(z), dimensions={1}
-  add = bf16[6,6] add(dot_a, z_bcast)
-  c = bf16[] constant(0)
-  c_bcast = bf16[6,6] broadcast(c), dimensions={}
-  ROOT out = bf16[6,6] maximum(add, c_bcast)
-}
-
-)";
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
-; CHECK-DAG:     bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
-; CHECK-DAG:     bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF64) {
-  if (!IsCuda()) {
-    GTEST_SKIP() << "TODO: Unsupported blas-lt F64 datatype on ROCM";
-  }
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f64[2,3] parameter(0)
-  y = f64[3,4] parameter(1)
-  z = f64[4] parameter(2)
-  dot_a = f64[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  z_bcast = f64[2,4] broadcast(z), dimensions={1}
-  add = f64[2,4] add(dot_a, z_bcast)
-  c = f64[] constant(0)
-  c_bcast = f64[2,4] broadcast(c), dimensions={}
-  ROOT out = f64[2,4] maximum(add, c_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-10, 1e-10}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f64[2,3], {{.*}}: f64[3,4], {{.*}}: f64[4]) -> f64[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f64[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f64[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f64[4]{0} parameter(2)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f64[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS_RELU"
-; CHECK:           }
-      )");
-}
-
-TEST_F(CublasLtGemmRewriteTest, AlphaSimpleRewriteBiasAddActivation) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = f32[2,3] parameter(0)
-  y = f32[3,4] parameter(1)
-  z = f32[4] parameter(2)
-  k = f32[] constant(3.0)
-  k_bcast = f32[2,4] broadcast(k), dimensions={}
-  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
-  dot_a_multiplied = f32[2, 4] multiply(dot_a, k_bcast)
-  z_bcast = f32[2,4] broadcast(z), dimensions={1}
-  add = f32[2,4] add(dot_a_multiplied, z_bcast)
-  c = f32[] constant(0)
-  c_bcast = f32[2,4] broadcast(c), dimensions={}
-  ROOT out = f32[2,4] maximum(add, c_bcast)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":3
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS_RELU"
-; CHECK:           }
-      )");
-}
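
Reading the backend_config above against the HLO: the scalar multiply by k = 3.0 is folded into alpha (hence "alpha_real":3), the broadcast vector bias and the ReLU are folded into the BIAS_RELU epilogue, and beta stays 0 because no matrix addend is fused as C. A rough sketch of the fused computation, not the exact cuBLAS LT contract:

  D = epilogue(alpha * (A x B) + beta * C)
  // here alpha = 3, beta = 0, epilogue = add broadcast bias, then max(x, 0)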
-
-TEST_F(CublasLtGemmRewriteTest, FoldConstantBias) {
-  const char* hlo_text = R"(
-HloModule test
-ENTRY test {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
-
-  dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bias1 = f32[2,2] parameter(2)
-  sum1 = add(dot1, bias1)
-
-  dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  sum2 = add(dot2, f32[2,2] reshape(bias))
-
-  dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bias3 = f32[2,2] transpose(bias), dimensions={1,0}
-  sum3 = add(dot3, bias3)
-
-  dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  sum4 = add(dot4, f32[2,2] bitcast(bias))
-
-  ROOT root = tuple(sum1, sum2, sum3, sum4)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(Capability(), GetToolkitVersion());
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  SCOPED_TRACE(module->ToString());
-  EXPECT_TRUE(changed);
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(
-              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter()),
-              0),
-          m::GetTupleElement(
-              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
-              0),
-          m::GetTupleElement(
-              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
-              0),
-          m::GetTupleElement(
-              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
-              0))));
-}
-
-TEST_F(CublasLtGemmRewriteTest, MultipleMaximumUsers) {
-  const char* hlo_text = R"(
-HloModule multiple_maximum_users
-
-relu {
-  Arg_0 = f32[3,896,54]{2,1,0} parameter(0)
-  constant = f32[] constant(0)
-  broadcast = f32[3,896,54]{2,1,0} broadcast(constant), dimensions={}
-  ROOT maximum = f32[3,896,54]{2,1,0} maximum(Arg_0, broadcast)
-}
-
-ENTRY main {
-  constant = f32[] constant(1)
-  broadcast_1 = f32[3,896,1024]{2,1,0} broadcast(constant), dimensions={}
-  Arg_2 = f32[1024,54]{1,0} parameter(2)
-  dot = f32[3,896,54]{2,1,0} dot(broadcast_1, Arg_2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
-  Arg_1 = f32[54]{0} parameter(1)
-  broadcast_2 = f32[3,896,54]{2,1,0} broadcast(Arg_1), dimensions={2}
-  add = f32[3,896,54]{2,1,0} add(dot, broadcast_2)
-  call = f32[3,896,54]{2,1,0} call(add), to_apply=relu
-  Arg_0 = f32[1]{0} parameter(0)
-  reshape_1 = f32[1,1,1]{2,1,0} reshape(Arg_0)
-  broadcast_3 = f32[1,1,1]{2,1,0} broadcast(reshape_1), dimensions={0,1,2}
-  reshape_2 = f32[] reshape(broadcast_3)
-  broadcast_4 = f32[3,896,54]{2,1,0} broadcast(reshape_2), dimensions={}
-  multiply = f32[3,896,54]{2,1,0} multiply(call, broadcast_4)
-  ROOT tuple = (f32[3,896,54]{2,1,0}, f32[3,896,54]{2,1,0}) tuple(multiply, call)
-}
-)";
-
-  // TODO(cjfj): Why do we need to relax the error constraint here?!
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-4}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK:           custom_call_target="__cublas$lt$matmul",
-      )");
-}
-
-// Test gemm matrix bias add fusion with mixed types and an out-of-place
-// update (C != D).
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeOutOfPlace) {
-  if (!IsCuda()) {
-    GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM";
-  }
-  std::vector<std::tuple<absl::string_view, absl::string_view>>
-      type_combinations = {
-          {"f16", "f32"},
-          {"bf16", "f32"},
-      };
-
-  const char* hlo_text_template = R"(
-HloModule test
-
-ENTRY test {
-  x = <<ABType>>[16,32] parameter(0)
-  y = <<ABType>>[32,16] parameter(1)
-  z = <<DType>>[16,16] parameter(2)
-  dot_a = <<ABType>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  convert = <<DType>>[16,16] convert(dot_a)
-  ROOT out = <<DType>>[16,16] add(convert, z)
-})";
-  for (const auto& type_combination : type_combinations) {
-    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
-    replacements["<<ABType>>"] = std::get<0>(type_combination);
-    replacements["<<DType>>"] = std::get<1>(type_combination);
-    const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
-    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
-    if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
-        !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-      continue;
-    }
-
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
-                            GetOptimizedModule(hlo_text));
-    EXPECT_THAT(
-        optimized_module->entry_computation()->root_instruction(),
-        GmockMatch(m::GetTupleElement(
-            m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-            0)));
-  }
-}
-
-// Test batched gemm matrix bias add fusion with mixed types and an
-// out-of-place update (C != D).
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeOutOfPlaceBatched) {
-  if (!IsCuda()) {
-    GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM";
-  }
-  std::vector<std::tuple<absl::string_view, absl::string_view>>
-      type_combinations = {
-          {"f16", "f32"},
-          {"bf16", "f32"},
-      };
-
-  const char* hlo_text_template = R"(
-HloModule test
-
-ENTRY test {
-  x = <<ABType>>[4,16,32] parameter(0)
-  y = <<ABType>>[4,32,16] parameter(1)
-  z = <<DType>>[4,16,16] parameter(2)
-  dot_a = <<ABType>>[4,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-  convert = <<DType>>[4,16,16] convert(dot_a)
-  ROOT out = <<DType>>[4,16,16] add(convert, z)
-})";
-  for (const auto& type_combination : type_combinations) {
-    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
-    replacements["<<ABType>>"] = std::get<0>(type_combination);
-    replacements["<<DType>>"] = std::get<1>(type_combination);
-    const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
-    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
-    if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
-        !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-      continue;
-    }
-
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
-                            GetOptimizedModule(hlo_text));
-    EXPECT_THAT(
-        optimized_module->entry_computation()->root_instruction(),
-        GmockMatch(m::GetTupleElement(
-            m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-            0)));
-  }
-}
-
-// Test gemm matrix bias add fusion with mixed types and an in-place update
-// (C = D).
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeInPlace) {
-  if (!IsCuda()) {
-    GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM";
-  }
-  std::vector<std::tuple<absl::string_view, absl::string_view>>
-      type_combinations = {
-          {"f16", "f32"},
-          {"bf16", "f32"},
-      };
-  const char* hlo_text_template = R"(
-HloModule test
-
-ENTRY test {
-  x = <<ABType>>[16,32] parameter(0)
-  y = <<ABType>>[32,16] parameter(1)
-  z = <<DType>>[16,16] parameter(2)
-  dot_a = <<ABType>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bias = <<DType>>[16,16] negate(z)
-  convert = <<DType>>[16,16] convert(dot_a)
-  ROOT out = <<DType>>[16,16] add(convert, bias)
-})";
-
-  for (const auto& type_combination : type_combinations) {
-    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
-    replacements["<<ABType>>"] = std::get<0>(type_combination);
-    replacements["<<DType>>"] = std::get<1>(type_combination);
-    const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
-    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
-    if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
-        !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-      continue;
-    }
-
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
-                            GetOptimizedModule(hlo_text));
-    EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
-                GmockMatch(m::GetTupleElement(
-                    m::CustomCall(m::Parameter(0), m::Parameter(1),
-                                  m::Negate(m::Parameter(2))),
-                    0)));
-  }
-}
-
-// Test gemm matrix bias add fusion with an unsupported type combination.
-TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeNotSupported) {
-  const char* hlo_text = R"(
-HloModule test
-
-ENTRY test {
-  x = bf16[16,32] parameter(0)
-  y = bf16[32,16] parameter(1)
-  z = f64[16,16] parameter(2)
-  dot_a = bf16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  bias = f64[16,16] negate(z)
-  convert = f64[16,16] convert(dot_a)
-  ROOT out = f64[16,16] add(convert, bias)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
-
-  if (IsCuda() &&
-      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
-    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
-  }
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
-                          GetOptimizedModule(hlo_text));
-  MatchOptimizedHlo(hlo_text, R"(
-; CHECK:        %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$lt$matmul
-; CHECK:        %[[tuple:.*]] = bf16[16,16]{1,0} get-tuple-element(%[[custom_call]]), index=0
-; CHECK:        ROOT {{.*}} fusion({{.*}}%[[tuple]]
-)");
-}
-
-class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTest {
- public:
-  ParameterizedFp8GemmRewriteTest() {
-    replacements_[kF8E4M3DatatypePlaceholder] =
-#if GOOGLE_CUDA
-        "f8e4m3fn";
-#else
-        "f8e4m3fnuz";
-#endif
-    replacements_[kF8E5M2DatatypePlaceholder] =
-#if GOOGLE_CUDA
-        "f8e5m2";
-#else
-        "f8e5m2fnuz";
-#endif
-    replacements_[kF8E4M3AmaxPlaceholder] =
-#if GOOGLE_CUDA
-        "448.";
-#else
-        "240.";
-#endif
-  }
-
- protected:
-  // Check the HLO runs and has an FP8 cuBLAS LT custom call on supported
-  // architectures (Ada, Hopper, and later).
-  void CheckFp8IfSupported(absl::string_view hlo_text,
-                           ErrorSpec error_spec = ErrorSpec{1e-2, 1e-2}) {
-    if (!HasFp8Support()) {
-      return;
-    }
-    std::string replaced_hlo_text =
-        absl::StrReplaceAll(hlo_text, replacements_);
-    EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
-                              error_spec));
-
-    // Most FP8 tests directly create a GemmRewriter and check the output.
-    // Here, also run the entire HLO pass pipeline to ensure no other passes
-    // interfere with GemmRewriter's pattern matching.
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
-                            GetOptimizedModule(replaced_hlo_text));
-    const HloInstruction* call =
-        FindInstruction(optimized_module.get(), HloOpcode::kCustomCall);
-    ASSERT_NE(call, nullptr);
-    EXPECT_EQ(call->custom_call_target(), "__cublas$lt$matmul$f8");
-  }
-
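-  // Applies the datatype replacements to both the HLO text and the expected
-  // pattern before matching.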
-  void MatchOptimizedHlo(absl::string_view hlo, const absl::string_view pattern,
-                         bool print_operand_shape = false) {
-    GemmRewriteTest::MatchOptimizedHlo(
-        absl::StrReplaceAll(hlo, replacements_),
-        absl::StrReplaceAll(pattern, replacements_), print_operand_shape);
-  }
-
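-  // Same as the base version, but applies the datatype replacements to the
-  // HLO text and the expected pattern; does nothing if no pattern is given.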
-  void RunAndFilecheckHloRewrite(
-      absl::string_view hlo, HloPassInterface&& hlo_pass,
-      std::optional<absl::string_view> expected,
-      std::function<void(HloModule*)> after_pass_checks = nullptr,
-      const HloModuleConfig* config = nullptr) {
-    if (expected.has_value()) {
-      std::string replaced_pattern =
-          absl::StrReplaceAll(expected.value(), replacements_);
-      GemmRewriteTest::RunAndFilecheckHloRewrite(
-          absl::StrReplaceAll(hlo, replacements_), std::move(hlo_pass),
-          replaced_pattern, after_pass_checks, config);
-    }
-  }
-
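-  // Parses the HLO text after applying the datatype replacements. The
-  // replica_count and num_partitions arguments are accepted for interface
-  // compatibility but are not forwarded to the base class.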
-  absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
-  ParseAndReturnVerifiedModule(absl::string_view hlo_text,
-                               int64_t replica_count = 1,
-                               int64_t num_partitions = 1) {
-    return GemmRewriteTest::ParseAndReturnVerifiedModule(
-        absl::StrReplaceAll(hlo_text, replacements_));
-  }
-
- private:
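-  // Placeholders substituted into HLO templates and FileCheck patterns.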
-  static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
-  static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
-  static constexpr const char* kF8E4M3AmaxPlaceholder{"<<F8E4M3_AMAX>>"};
-};
-
-TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteToF8OnPreAda) {
-  if (HasFp8Support()) {
-    GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY PreAdaTest {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      ROOT out = <<F8E4M3>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
-                            ErrorSpec{1e-2, 1e-2}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> <<F8E4M3>>[16,16] {
-; CHECK:    {{.*}} = {{.*}} custom-call({{.*}}, {{.*}})
-; CHECK-DAG:  custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
-          )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteOnPreAdaWithF32Output) {
-  if (HasFp8Support()) {
-    GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300.";
-  }
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY PreAdaTest {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      ROOT out = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
-                            ErrorSpec{1e-2, 1e-2}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> f32[16,16] {
-; CHECK:    {{.*}} = {{.*}} custom-call({{.*}}, {{.*}})
-; CHECK-DAG:  custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
-          )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, UnsupportedTypesF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  // Test with types unsupported by cuBLAS LT when FP8 is used. cuBLAS LT with
-  // FP8 requires one of the operands to be F8E4M3FN.
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY unsupported_types {
-      x = <<F8E5M2>>[16,16] parameter(0)
-      y = <<F8E5M2>>[16,16] parameter(1)
-      ROOT out = <<F8E5M2>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-)";
-  EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
-                            ErrorSpec{1e-2, 1e-2}));
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(Capability(), GetToolkitVersion(), /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %unsupported_types ({{.*}}: <<F8E5M2>>[16,16], {{.*}}: <<F8E5M2>>[16,16]) -> <<F8E5M2>>[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E5M2>>[16,16]{1,0} parameter(0)
-; CHECK-NEXT:    [[P0_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E5M2>>[16,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P1]])
-; CHECK-NEXT:    [[DOT:%[^ ]+]] = f16[16,16]{1,0} dot([[P0_CONVERT]], [[P1_CONVERT]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = <<F8E5M2>>[16,16]{1,0} convert([[DOT]])
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      ROOT out = <<F8E4M3>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> <<F8E4M3>>[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[C1:[^ ]+]] = f32[] constant(1)
-)"
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
-      R"(; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-)"
-#else
-      R"(; CHECK-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-)"
-#endif
-      R"(; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-// Verify that the FP8 matrix bias is not fused into the custom call.
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
-  GTEST_SKIP() << "F8 gemm rewrite with an FP8 output (D) and matrix bias is "
-                  "only supported in ROCm 6.2 and above.";
-#endif  // TF_ROCM_VERSION < 60200
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      dot_a = <<F8E4M3>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      b = <<F8E4M3>>[16,16] parameter(2)
-      ROOT out = <<F8E4M3>>[16,16] add(dot_a, b)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: <<F8E4M3>>[16,16]) -> <<F8E4M3>>[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[C1:[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[DOT_TUPLE:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:         [[DOT:%[^ ]+]] = <<F8E4M3>>[16,16]{1,0} get-tuple-element([[DOT_TUPLE]]), index=0
-; CHECK-NEXT:    [[P2:%[^ ]+]] = <<F8E4M3>>[16,16]{1,0} parameter(2)
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = <<F8E4M3>>[16,16]{1,0} add([[DOT]], [[P2]])
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
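-// Operands with non-aligned dimensions are padded before the FP8 custom call
-// and the result is sliced back to the original shape.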
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[13,17] parameter(0)
-      y = <<F8E4M3>>[17,31] parameter(1)
-      x_f32 = f32[13,17] convert(x)
-      y_f32 = f32[17,31] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[13,17] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[17,31] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[13,17] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[17,31] multiply(y_f32, y_scale_bcast)
-      ROOT out = f32[13,31] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[13,17], {{.*}}: <<F8E4M3>>[17,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[13,31] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[13,17]{1,0} parameter(0)
-; CHECK-NEXT:    [[C0:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT:    [[P0_PADDED:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_3x0_15
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[17,31]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[31,17]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT:    [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <<F8E4M3>>[32,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[C4:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[DOT_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2]], [[P3]], [[C4]], /*index=5*/[[C4]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK-NEXT:  [[DOT:%[^ ]+]] = f32[16,32]{1,0} get-tuple-element([[DOT_TUPLE]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[13,31]{1,0} slice([[DOT]]), slice={[0:13], [0:31]}
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[2,8,16] parameter(0)
-      y = <<F8E4M3>>[16,16] parameter(1)
-      x_f32 = f32[2,8,16] convert(x)
-      y_f32 = f32[16,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[2,8,16] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[16,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[2,8,16] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[16,16] multiply(y_f32, y_scale_bcast)
-      x_bitcast = f32[16,16] bitcast(x_unscaled)
-      ROOT out = f32[16,16] dot(x_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                    /*f8_rewrite=*/true);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
-                     .WithShape(F32, {16, 16})));
-}
-
-// Test case where F8 inputs are converted to F32 before the dot, but without
-// any scaling.
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDWithConvertF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      ROOT out = f32[16,16] dot(x_f32, y_f32), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> f32[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[3] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[3] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[3] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[3] multiply(x_f32, x_scale_bcast)
-      zero = f32[] constant(0)
-      x_unscaled_padded = f32[30] pad(x_unscaled, zero), padding=0_27
-      x_unscaled_padded_bcast = f32[30,8,5] broadcast(x_unscaled_padded), dimensions={0}
-      x_unscaled_padded_bcast_sliced = f32[16,8,4] slice(x_unscaled_padded_bcast), slice={[2:18], [0:8], [0:4]}
-      x_unscaled_padded_bcast_sliced_reshaped = f32[16,32] reshape(x_unscaled_padded_bcast_sliced)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      ROOT out = f32[16,16] dot(x_unscaled_padded_bcast_sliced_reshaped, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[3], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[3]{0} parameter(0)
-; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(0)
-; CHECK-NEXT:    [[C0_CONVERT:%[^ ]+]] = <<F8E4M3>>[] convert([[C0]])
-; CHECK-NEXT:    [[P0_U0:%[^ ]+]] = <<F8E4M3>>[30]{0} pad([[P0]], [[C0_CONVERT]]), padding=0_27
-; CHECK-NEXT:    [[P0_U1:%[^ ]+]] = <<F8E4M3>>[30,8,5]{2,1,0} broadcast([[P0_U0]]), dimensions={0}
-; CHECK-NEXT:    [[P0_U2:%[^ ]+]] = <<F8E4M3>>[16,8,4]{2,1,0} slice([[P0_U1]]), slice={[2:18], [0:8], [0:4]}
-; CHECK-NEXT:    [[P0_U3:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} reshape([[P0_U2]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C2]], /*index=5*/[[C2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       UnscaledABUnscaledDUnaryOpsWithConvertF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[3] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[3] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      zero = f32[] constant(0)
-      x_padded = f32[30] pad(x_f32, zero), padding=0_27
-      x_padded_bcast = f32[30,8,5] broadcast(x_padded), dimensions={0}
-      x_padded_bcast_sliced = f32[16,8,4] slice(x_padded_bcast), slice={[2:18], [0:8], [0:4]}
-      x_padded_bcast_sliced_reshaped = f32[16,32] reshape(x_padded_bcast_sliced)
-      ROOT out = f32[16,16] dot(x_padded_bcast_sliced_reshaped, y_f32), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[3], {{.*}}: <<F8E4M3>>[32,16]) -> f32[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[3]{0} parameter(0)
-; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(0)
-; CHECK-NEXT:    [[C0_CONVERT:%[^ ]+]] = <<F8E4M3>>[] convert([[C0]])
-; CHECK-NEXT:    [[P0_U0:%[^ ]+]] = <<F8E4M3>>[30]{0} pad([[P0]], [[C0_CONVERT]]), padding=0_27
-; CHECK-NEXT:    [[P0_U1:%[^ ]+]] = <<F8E4M3>>[30,8,5]{2,1,0} broadcast([[P0_U0]]), dimensions={0}
-; CHECK-NEXT:    [[P0_U2:%[^ ]+]] = <<F8E4M3>>[16,8,4]{2,1,0} slice([[P0_U1]]), slice={[2:18], [0:8], [0:4]}
-; CHECK-NEXT:    [[P0_U3:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} reshape([[P0_U2]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[C2]], [[C2]], [[C2]], /*index=5*/[[C2]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[32,32] parameter(0)
-      y = <<F8E4M3>>[16,32] parameter(1)
-      zero = s32[] constant(0)
-      x_f32 = f32[32,32] convert(x)
-      y_f32 = f32[16,32] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[32,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[32,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
-      dyn_slice = f32[16,32]{1,0} dynamic-slice(x_unscaled, zero, zero), dynamic_slice_sizes={16,32}
-      ROOT dot_a = f32[16,16] dot(dyn_slice, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-          }
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                    /*f8_rewrite=*/true);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[32,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[32,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[C0:%[^ ]+]] = s32[] constant(0)
-; CHECK-NEXT:    [[DYN_SLICE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} dynamic-slice([[P0]], [[C0]], [[C0]]), dynamic_slice_sizes={16,32}
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[DYN_SLICE]], [[P1]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[16,32] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[16,32] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
-      k = pred[16,32] parameter(4)
-      c = f32[] constant(0)
-      c_bcast = f32[16,32] broadcast(c), dimensions={}
-      select_a = f32[16,32] select(k, y_unscaled, c_bcast)
-      ROOT dot_a = f32[16,16] dot(x_unscaled, select_a), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-          }
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                    /*f8_rewrite=*/true);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: pred[16,32]) -> f32[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P4:%[^ ]+]] = pred[16,32]{1,0} parameter(4)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(0)
-; CHECK-NEXT:    [[C0_BCAST:%[^ ]+]] = f32[16,32]{1,0} broadcast([[C0]]), dimensions={}
-; CHECK-NEXT:    [[C0_CONVERT:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} convert([[C0_BCAST]])
-; CHECK-NEXT:    [[SELECT:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} select([[P4]], [[P1]], [[C0_CONVERT]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[SELECT]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
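-// In contrast to the zero-constant case above, a select against a nonzero
-// constant is not rewritten; the pass is expected to leave the module
-// unchanged.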
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       ScaledABUnscaledDSelectNonzeroConstantF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[16,32] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[16,32] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
-      k = pred[16,32] parameter(4)
-      c = f32[] constant(1)
-      c_bcast = f32[16,32] broadcast(c), dimensions={}
-      select_a = f32[16,32] select(k, y_unscaled, c_bcast)
-      ROOT dot_a = f32[16,16] dot(x_unscaled, select_a), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-          }
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                    /*f8_rewrite=*/true);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_FALSE(changed);
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[10,16,32] parameter(0)
-      y = <<F8E4M3>>[10,32,16] parameter(1)
-      x_f32 = f32[10,16,32] convert(x)
-      y_f32 = f32[10,32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[10,16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[10,32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[10,16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[10,32,16] multiply(y_f32, y_scale_bcast)
-      ROOT out = f32[10,16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[10,16,32], {{.*}}: <<F8E4M3>>[10,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[10,16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[10,16,32]{2,1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[10,32,16]{2,1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[10,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[10,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["2"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["2"]
-; CHECK-DAG:           "lhs_batch_dimensions":["0"]
-; CHECK-DAG:           "rhs_batch_dimensions":["0"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
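-// A scalar multiplication of the dot result is folded into the custom call as
-// alpha.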
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      k = f32[] constant(3.0)
-      k_bcast = f32[16,16] broadcast(k), dimensions={}
-      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      ROOT out = f32[16,16] multiply(dot_a, k_bcast)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":3
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      c = f32[] constant(0)
-      c_bcast = f32[16,16] broadcast(c), dimensions={}
-      ROOT out = f32[16,16] maximum(dot_a, c_bcast)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"RELU"
-
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       ScaledABUnscaledDVectorBiasThenApproxGeluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_bf16 = bf16[16,32] convert(x)
-      y_bf16 = bf16[32,16] convert(y)
-      x_scale = bf16[] parameter(2)
-      y_scale = bf16[] parameter(3)
-      bias = bf16[16] parameter(4)
-      x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)
-      y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)
-      dot1 = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      b_bcast = bf16[16,16] broadcast(bias), dimensions={1}
-      dot = bf16[16,16] add(dot1, b_bcast)
-      mul.0 = bf16[16,16] multiply(dot, dot)
-      mul.1 = bf16[16,16] multiply(dot, mul.0)
-      const.0 = bf16[] constant(0.044715)
-      bcast.0 = bf16[16,16] broadcast(const.0), dimensions={}
-      mul.2 = bf16[16,16] multiply(mul.1, bcast.0)
-      add.0 = bf16[16,16] add(dot, mul.2)
-      const.1 = bf16[] constant(0.797884583)
-      bcast.1 = bf16[16,16] broadcast(const.1), dimensions={}
-      mul.3 = bf16[16,16] multiply(add.0, bcast.1)
-      tanh = bf16[16,16] tanh(mul.3)
-      const.2 = bf16[] constant(1)
-      bcast.2 = bf16[16,16] broadcast(const.2), dimensions={}
-      add.2 = bf16[16,16] add(tanh, bcast.2)
-      const.3 = bf16[] constant(0.5)
-      bcast.3 = bf16[16,16] broadcast(const.3), dimensions={}
-      mul.4 = bf16[16,16] multiply(add.2, bcast.3)
-      ROOT out = bf16[16,16] multiply(dot, mul.4)
-          }
-)";
-
-  CheckFp8IfSupported(hlo_text);
-
-// Fusing gelu into FP8 cublas matmuls is disabled on CUDA versions less
-// than 12.4.
-#if (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: bf16[], {{.*}}: bf16[], {{.*}}: bf16[16]) -> bf16[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = bf16[] parameter(2)
-; CHECK-NEXT:    [[XS:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT:    [[P3:%[^ ]+]] = bf16[] parameter(3)
-; CHECK-NEXT:    [[XS1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-)"
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
-      R"(; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
-)"
-#else
-      R"(; CHECK-NEXT:    [[B:%[^ ]+]] = bf16[16]{0} parameter(4)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]),
-)"
-#endif
-      R"(; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-)"
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
-      R"(; CHECK-GCN-DAG:         "epilogue":"DEFAULT"
-)"
-#else
-      R"(; CHECK-DAG:         "epilogue":"BIAS_GELU"
-)"
-#endif
-      R"(; CHECK:           }
-      )");
-#endif  // (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       ScaledABUnscaledDApproxGeluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_bf16 = bf16[16,32] convert(x)
-      y_bf16 = bf16[32,16] convert(y)
-      x_scale = bf16[] parameter(2)
-      y_scale = bf16[] parameter(3)
-      x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)
-      y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)
-      dot = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      mul.0 = bf16[16,16] multiply(dot, dot)
-      mul.1 = bf16[16,16] multiply(dot, mul.0)
-      const.0 = bf16[] constant(0.044715)
-      bcast.0 = bf16[16,16] broadcast(const.0), dimensions={}
-      mul.2 = bf16[16,16] multiply(mul.1, bcast.0)
-      add.0 = bf16[16,16] add(dot, mul.2)
-      const.1 = bf16[] constant(0.797884583)
-      bcast.1 = bf16[16,16] broadcast(const.1), dimensions={}
-      mul.3 = bf16[16,16] multiply(add.0, bcast.1)
-      tanh = bf16[16,16] tanh(mul.3)
-      const.2 = bf16[] constant(1)
-      bcast.2 = bf16[16,16] broadcast(const.2), dimensions={}
-      add.2 = bf16[16,16] add(tanh, bcast.2)
-      const.3 = bf16[] constant(0.5)
-      bcast.3 = bf16[16,16] broadcast(const.3), dimensions={}
-      mul.4 = bf16[16,16] multiply(add.2, bcast.3)
-      ROOT out = bf16[16,16] multiply(dot, mul.4)
-          }
-)";
-
-  CheckFp8IfSupported(hlo_text);
-
-// Fusing gelu into FP8 cublas matmuls is disabled on CUDA versions less
-// than 12.4.
-#if (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
-  // Currently, hipBlasLt does not support a bf16 output datatype for FP8
-  // matmuls, so no fusion is performed in those cases.
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: bf16[], {{.*}}: bf16[]) -> bf16[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = bf16[] parameter(2)
-; CHECK-NEXT:    [[XS:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT:    [[P3:%[^ ]+]] = bf16[] parameter(3)
-; CHECK-NEXT:    [[XS1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-)"
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
-      R"(; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
-)"
-#else
-      R"(; CHECK-NEXT:    [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
-)"
-#endif
-      R"(; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-)"
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
-      R"(; CHECK-GCN-DAG:         "epilogue":"DEFAULT"
-)"
-#else
-      R"(; CHECK-DAG:         "epilogue":"GELU"
-)"
-#endif
-      R"(; CHECK:           }
-      )");
-#endif  // (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
-}
-
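-// Scales applied via divide (i.e. inverse scales) are also recognized and
-// rewritten into the FP8 custom call.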
-TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] divide(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] divide(y_f32, y_scale_bcast)
-      ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      b = f32[16,16] parameter(2)
-      one = f32[] constant(1)
-      ones = f32[16,16] broadcast(one), dimensions={}
-      b_ones = f32[16,16] add(b, ones)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(3)
-      y_scale = f32[] parameter(4)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      ROOT out = add(dot_a, b_ones)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[16,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK:         [[C0:%[^ ]+]] = f32[16,16]{1,0} add({{.*}})
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[P2]], [[P3]], /*index=5*/[[C1]], [[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           output_to_operand_aliasing={
-; CHECK-SAME:        {0}: (2, {})
-; CHECK-SAME:      }
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[14,31] parameter(0)
-      y = <<F8E4M3>>[31,14] parameter(1)
-      b = f32[14,14] parameter(2)
-      x_f32 = f32[14,31] convert(x)
-      y_f32 = f32[31,14] convert(y)
-      x_scale = f32[] parameter(3)
-      y_scale = f32[] parameter(4)
-      x_scale_bcast = f32[14,31] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[31,14] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[14,31] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[31,14] multiply(y_f32, y_scale_bcast)
-      dot_a = f32[14,14] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      ROOT out = add(dot_a, b)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[14,31], {{.*}}: <<F8E4M3>>[31,14], {{.*}}: f32[14,14], {{.*}}: f32[], {{.*}}: f32[]) -> f32[14,14] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[14,31]{1,0} parameter(0)
-; CHECK-NEXT:    [[C0:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT:    [[P0_PADDED:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_1
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[31,14]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[14,31]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT:    [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]), padding=0_2x0_1
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[14,14]{1,0} parameter(2)
-; CHECK-NEXT:    [[C2:%[^ ]+]] = f32[] constant(0)
-; CHECK-NEXT:    [[P2_PADDED:%[^ ]+]] = f32[16,16]{1,0} pad([[P2]], [[C2]]), padding=0_2x0_2
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[P4:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT:    [[C3:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[DOT_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2_PADDED]], [[P3]], [[P4]], /*index=5*/[[C3]], [[C3]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:      [[DOT:%[^ ]+]] = f32[16,16]{1,0} get-tuple-element([[DOT_TUPLE]]), index=0
-; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[14,14]{1,0} slice([[DOT]]), slice={[0:14], [0:14]}
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      z_scale = f32[] parameter(2)
-      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
-      dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
-      c1 = f32[] constant(-448.)
-      c1_bcast = f32[16,16] broadcast(c1), dimensions={}
-      c2 = f32[] constant(448.)
-      c2_bcast = f32[16,16] broadcast(c2), dimensions={}
-      dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
-      ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> <<F8E4M3>>[16,16] {
-; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]])
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      z_scale = f32[] parameter(2)
-      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
-      dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      ROOT dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]])
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      z_scale = f32[] parameter(2)
-      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
-      dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      ROOT dot_a_scaled = f32[16,16] multiply(dot_a, z_scale_bcast)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-// Do not fuse output scaling without type conversion when a matrix bias was
-// fused.
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      b = f32[16,16] parameter(2)
-      z_scale = f32[] parameter(3)
-      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
-      dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a_bias = f32[16,16] add(dot_a, b)
-      ROOT dot_a_scaled = f32[16,16] divide(dot_a_bias, z_scale_bcast)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[16,16]{1,0} parameter(2)
-; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]], [[C0]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK-PTX-NEXT:    [[GEMM:%[^ ]+]] = f32[16,16]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK-PTX-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-PTX-NEXT:    [[P3_BCAST:%[^ ]+]] = f32[16,16]{1,0} broadcast([[P3]]), dimensions={}
-; CHECK-PTX-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[16,16]{1,0} divide([[GEMM]], [[P3_BCAST]])
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      z_scale = f32[] parameter(4)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
-      c1 = f32[] constant(-<<F8E4M3_AMAX>>)
-      c1_bcast = f32[16,16] broadcast(c1), dimensions={}
-      c2 = f32[] constant(<<F8E4M3_AMAX>>)
-      c2_bcast = f32[16,16] broadcast(c2), dimensions={}
-      dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
-      ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <<F8E4M3>>[16,16] {
-; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT:    [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT:    [[P4:%[^ ]+]] = f32[] parameter(4)
-; CHECK-PTX-NEXT:    [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
-; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      z_scale = f32[] parameter(4)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a_scaled = f32[16,16] multiply(dot_a, z_scale_bcast)
-      c1 = f32[] constant(-<<F8E4M3_AMAX>>)
-      c1_bcast = f32[16,16] broadcast(c1), dimensions={}
-      c2 = f32[] constant(<<F8E4M3_AMAX>>)
-      c2_bcast = f32[16,16] broadcast(c2), dimensions={}
-      dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
-      ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-
-; CHECK-NOT:     divide
-
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      z_scale = f32[] parameter(4)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      c = f32[] constant(0)
-      c_bcast = f32[16,16] broadcast(c), dimensions={}
-      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      relu_a = f32[16,16] maximum(dot_a, c_bcast)
-      relu_a_scaled = f32[16,16] divide(relu_a, z_scale_bcast)
-      c1 = f32[] constant(-<<F8E4M3_AMAX>>)
-      c1_bcast = f32[16,16] broadcast(c1), dimensions={}
-      c2 = f32[] constant(<<F8E4M3_AMAX>>)
-      c2_bcast = f32[16,16] broadcast(c2), dimensions={}
-      relu_a_clamped = f32[16,16] clamp(c1_bcast, relu_a_scaled, c2_bcast)
-      ROOT out = <<F8E4M3>>[16,16] convert(relu_a_clamped)
-          }
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <<F8E4M3>>[16,16] {
-; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT:    [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT:    [[P4:%[^ ]+]] = f32[] parameter(4)
-; CHECK-PTX-NEXT:    [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
-; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"RELU"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f16[] parameter(0)
-      b = f16[] parameter(1)
-      ROOT c = f16[] maximum(a, b)
-    }
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f16 = f16[16,32] convert(x)
-      y_f16 = f16[32,16] convert(y)
-      b = f16[16,16] parameter(2)
-      one = f16[] constant(1)
-      ones = f16[16,16] broadcast(one), dimensions={}
-      b_ones = f16[16,16] add(b, ones)
-      x_scale = f16[] parameter(3)
-      y_scale = f16[] parameter(4)
-      z_scale = f16[] parameter(5)
-      x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
-      z_scale_bcast = f16[16,16] broadcast(z_scale), dimensions={}
-      x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
-      y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
-      dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a_bias = f16[16,16] add(dot_a, b_ones)
-      abs_dot_a = f16[16,16] abs(dot_a_bias)
-      c0 = f16[] constant(-inf)
-      amax = f16[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply
-      dot_a_scaled = f16[16,16] divide(dot_a_bias, z_scale_bcast)
-      c1 = f16[] constant(-<<F8E4M3_AMAX>>)
-      c1_bcast = f16[16,16] broadcast(c1), dimensions={}
-      c2 = f16[] constant(<<F8E4M3_AMAX>>)
-      c2_bcast = f16[16,16] broadcast(c2), dimensions={}
-      dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
-      dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
-      ROOT result = (<<F8E4M3>>[16,16], f16[]) tuple(dot_a_f8, amax)
-    }
-)";
-
-  CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1});
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<<F8E4M3>>[16,16], f16[]) {
-; CHECK:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK:         [[C0:%[^ ]+]] = f16[16,16]{1,0} add({{.*}})
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(3)
-; CHECK:         [[P3:%[^ ]+]] = f16[] parameter(4)
-; CHECK:         [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX:         [[P4:%[^ ]+]] = f16[] parameter(5)
-; CHECK-PTX:       [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]),
-; CHECK-NOT:       output_to_operand_aliasing
-; CHECK-GCN:       [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f16 = f16[16,32] convert(x)
-      y_f16 = f16[32,16] convert(y)
-      b = f16[16] parameter(2)
-      b_bcast = f16[16,16] broadcast(b), dimensions={1}
-      x_scale = f16[] parameter(3)
-      y_scale = f16[] parameter(4)
-      z_scale = f16[] parameter(5)
-      x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
-      z_scale_bcast = f16[16,16] broadcast(z_scale), dimensions={}
-      x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
-      y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
-      dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a_bias = f16[16,16] add(dot_a, b_bcast)
-      dot_a_scaled = f16[16,16] divide(dot_a_bias, z_scale_bcast)
-      c1 = f16[] constant(-<<F8E4M3_AMAX>>)
-      c1_bcast = f16[16,16] broadcast(c1), dimensions={}
-      c2 = f16[] constant(<<F8E4M3_AMAX>>)
-      c2_bcast = f16[16,16] broadcast(c2), dimensions={}
-      dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
-      ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1});
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> <<F8E4M3>>[16,16] {
-; CHECK:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(3)
-; CHECK-NEXT:    [[CV:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f16[] parameter(4)
-; CHECK-NEXT:    [[CV1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT:    [[C2:%[^ ]+]] = f16[] constant(1)
-; CHECK-PTX-NEXT:    [[P4:%[^ ]+]] = f16[] parameter(5)
-; CHECK-PTX-NEXT:    [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT:    [[CV2:%[^ ]+]] = f32[] convert([[DV]])
-; CHECK-NEXT:    [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
-; CHECK-PTX:         [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[CV2]], [[VB]]),
-; CHECK-GCN:         [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      b = f32[16] parameter(2)
-      b_bf16 = bf16[16] convert(b)
-      b_f32 = f32[16] convert(b_bf16)
-      b_bcast = f32[16,16] broadcast(b_f32), dimensions={1}
-      x_scale = f32[] parameter(3)
-      y_scale = f32[] parameter(4)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      ROOT out = f32[16,16] add(dot_a, b_bcast)
-           }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[VB:%[^ ]+]] = f32[16]{0} parameter(2)
-; CHECK-NEXT:    [[VBC:%[^ ]+]] = bf16[16]{0} convert([[VB]])
-; CHECK:         [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]], [[VBC]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       ScaledABUnscaledDVectorBiasThenReluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      b = f16[16] parameter(2)
-      b_bcast = f16[16,16] broadcast(b), dimensions={1}
-      x_f32 = f16[16,32] convert(x)
-      y_f32 = f16[32,16] convert(y)
-      x_scale = f16[] parameter(3)
-      y_scale = f16[] parameter(4)
-      x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f16[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f16[32,16] multiply(y_f32, y_scale_bcast)
-      c = f16[] constant(0)
-      c_bcast = f16[16,16] broadcast(c), dimensions={}
-      dot_a0 = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a = f16[16,16] add(dot_a0, b_bcast)
-      ROOT out = f16[16,16] maximum(dot_a, c_bcast)
-          }
-)";
-
-  CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.});
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(3)
-; CHECK-NEXT:    [[CV:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f16[] parameter(4)
-; CHECK-NEXT:    [[CV1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
-; CHECK     :    ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS_RELU"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-    ENTRY test {
-      x = <<F8E4M3>>[4,16,16] parameter(0)
-      y = <<F8E4M3>>[16,32] parameter(1)
-      b = f32[32] parameter(2)
-      b_f16 = f16[32] convert(b)
-      b_bcast = f16[4,16,32] broadcast(b_f16), dimensions={2}
-      x_f16 = f16[4,16,16] convert(x)
-      y_f16 = f16[16,32] convert(y)
-      x_scale = f16[] parameter(3)
-      y_scale = f16[] parameter(4)
-      x_scale_bcast = f16[4,16,16] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f16[16,32] broadcast(y_scale), dimensions={}
-      x_unscaled = f16[4,16,16] multiply(x_f16, x_scale_bcast)
-      x_unscaled_bitcast = f16[64,16] bitcast(x_unscaled)
-      y_unscaled = f16[16,32] multiply(y_f16, y_scale_bcast)
-      dot_a = f16[64,32] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a_bitcast = f16[4,16,32]{2,1,0} bitcast(dot_a)
-      ROOT out = f16[4,16,32] add(dot_a_bitcast, b_bcast)
-          }
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                    /*f8_rewrite=*/true);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Bitcast(m::GetTupleElement(
-                                m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
-                                .WithShape(F16, {64, 32}))
-                     .WithShape(F16, {4, 16, 32})));
-
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[4,16,16], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[32], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,16,32] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[4,16,16]{2,1,0} parameter(0)
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[64,16]{1,0} bitcast([[P0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(3)
-; CHECK-NEXT:    [[P2_CV:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f16[] parameter(4)
-; CHECK-NEXT:    [[P3_CV:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[B:%[^ ]+]] = f32[32]{0} parameter(2)
-; CHECK-NEXT:    [[B_F16:%[^ ]+]] = f16[32]{0} convert([[B]])
-; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[B_F16]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS"
-; CHECK:           }
-; CHECK:         [[GEMM:%[^ ]+]] = f16[64,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK:         ROOT [[OUT:%[^ ]+]] = f16[4,16,32]{2,1,0} bitcast([[GEMM]])
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       Rank3ScaledABUnscaledDVectorBiasPaddedF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-    ENTRY test {
-      x = <<F8E4M3>>[4,15,15] parameter(0)
-      y = <<F8E4M3>>[15,31] parameter(1)
-      b = f32[31] parameter(2)
-      b_f16 = f16[31] convert(b)
-      b_bcast = f16[4,15,31] broadcast(b_f16), dimensions={2}
-      x_f16 = f16[4,15,15] convert(x)
-      y_f16 = f16[15,31] convert(y)
-      x_scale = f16[] parameter(3)
-      y_scale = f16[] parameter(4)
-      x_scale_bcast = f16[4,15,15] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f16[15,31] broadcast(y_scale), dimensions={}
-      x_unscaled = f16[4,15,15] multiply(x_f16, x_scale_bcast)
-      x_unscaled_bitcast = f16[60,15] bitcast(x_unscaled)
-      y_unscaled = f16[15,31] multiply(y_f16, y_scale_bcast)
-      dot_a = f16[60,31] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a_bitcast = f16[4,15,31]{2,1,0} bitcast(dot_a)
-      ROOT out = f16[4,15,31] add(dot_a_bitcast, b_bcast)
-          }
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                    /*f8_rewrite=*/true);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Bitcast(m::Slice(m::GetTupleElement(
-                                  m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
-                                  .WithShape(F16, {64, 32}))
-                         .WithShape(F16, {60, 31}))
-              .WithShape(F16, {4, 15, 31})));
-
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[4,15,15], {{.*}}: <<F8E4M3>>[15,31], {{.*}}: f32[31], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,15,31] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[4,15,15]{2,1,0} parameter(0)
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[60,15]{1,0} bitcast([[P0]])
-; CHECK-NEXT:    [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT:    [[P0_PAD:%[^ ]+]] = <<F8E4M3>>[64,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_4x0_1
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[15,31]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[31,15]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[C2:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT:    [[P1_PAD:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(3)
-; CHECK-NEXT:    [[P2_CV:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f16[] parameter(4)
-; CHECK-NEXT:    [[P3_CV:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[B:%[^ ]+]] = f32[31]{0} parameter(2)
-; CHECK-NEXT:    [[B_F16:%[^ ]+]] = f16[31]{0} convert([[B]])
-; CHECK-NEXT:    [[C3:%[^ ]+]] = f16[] constant(0)
-; CHECK-NEXT:    [[P2_PAD:%[^ ]+]] = f16[32]{0} pad([[B_F16]], [[C3]]), padding=0_1
-; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[P2_PAD]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"BIAS"
-; CHECK:           }
-; CHECK:          [[GEMM:%[^ ]+]] = f16[64,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK-NEXT:     [[SLICE:%[^ ]+]] = f16[60,31]{1,0} slice([[GEMM]]), slice={[0:60], [0:31]}
-; CHECK-NEXT:     ROOT [[OUT:%[^ ]+]] = f16[4,15,31]{2,1,0} bitcast([[SLICE]])
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-    ENTRY test {
-      x = <<F8E4M3>>[4,16,16] parameter(0)
-      y = <<F8E4M3>>[16,32] parameter(1)
-      b = f32[4,16,32] parameter(2)
-      x_f32 = f32[4,16,16] convert(x)
-      y_f32 = f32[16,32] convert(y)
-      x_scale = f32[] parameter(3)
-      y_scale = f32[] parameter(4)
-      x_scale_bcast = f32[4,16,16] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[4,16,16] multiply(x_f32, x_scale_bcast)
-      x_unscaled_bitcast = f32[64,16] bitcast(x_unscaled)
-      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
-      dot_a = f32[64,32] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a_bitcast = f32[4,16,32]{2,1,0} bitcast(dot_a)
-      ROOT out = f32[4,16,32] add(dot_a_bitcast, b)
-          }
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                    /*f8_rewrite=*/true);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(m::Bitcast(m::GetTupleElement(
-                                m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
-                                .WithShape(F32, {64, 32}))
-                     .WithShape(F32, {4, 16, 32})));
-
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[4,16,16], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[4,16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[4,16,32] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[4,16,16]{2,1,0} parameter(0)
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[64,16]{1,0} bitcast([[P0]])
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[B:%[^ ]+]] = f32[4,16,32]{2,1,0} parameter(2)
-; CHECK-NEXT:    [[B_BITCAST:%[^ ]+]] = f32[64,32]{1,0} bitcast([[B]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f32[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:         [[GEMM:%[^ ]+]] = f32[64,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[4,16,32]{2,1,0} bitcast([[GEMM]])
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       Rank3ScaledABUnscaledDMatrixBiasPaddedF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-    ENTRY test {
-      x = <<F8E4M3>>[3,15,15] parameter(0)
-      y = <<F8E4M3>>[15,31] parameter(1)
-      b = f32[3,15,31] parameter(2)
-      x_f32 = f32[3,15,15] convert(x)
-      y_f32 = f32[15,31] convert(y)
-      x_scale = f32[] parameter(3)
-      y_scale = f32[] parameter(4)
-      x_scale_bcast = f32[3,15,15] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[15,31] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[3,15,15] multiply(x_f32, x_scale_bcast)
-      x_unscaled_bitcast = f32[45,15] bitcast(x_unscaled)
-      y_unscaled = f32[15,31] multiply(y_f32, y_scale_bcast)
-      dot_a = f32[45,31] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a_bitcast = f32[3,15,31]{2,1,0} bitcast(dot_a)
-      ROOT out = f32[3,15,31] add(dot_a_bitcast, b)
-          }
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                    /*f8_rewrite=*/true);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Bitcast(m::Slice(m::GetTupleElement(
-                                  m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
-                                  .WithShape(F32, {48, 32}))
-                         .WithShape(F32, {45, 31}))
-              .WithShape(F32, {3, 15, 31})));
-
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[3,15,15], {{.*}}: <<F8E4M3>>[15,31], {{.*}}: f32[3,15,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[3,15,31] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[3,15,15]{2,1,0} parameter(0)
-; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[45,15]{1,0} bitcast([[P0]])
-; CHECK-NEXT:    [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT:    [[P0_PADDED:%[^ ]+]] = <<F8E4M3>>[48,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_3x0_1
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[15,31]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[31,15]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[C2:%[^ ]+]] = <<F8E4M3>>[] constant(0)
-; CHECK-NEXT:    [[P1_PADDED:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1
-; CHECK-NEXT:    [[B:%[^ ]+]] = f32[3,15,31]{2,1,0} parameter(2)
-; CHECK-NEXT:    [[B_BITCAST:%[^ ]+]] = f32[45,31]{1,0} bitcast([[B]])
-; CHECK-NEXT:    [[C3:%[^ ]+]] = f32[] constant(0)
-; CHECK-NEXT:    [[P2_PADDED:%[^ ]+]] = f32[48,32]{1,0} pad([[B_BITCAST]], [[C3]]), padding=0_3x0_1
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK-NEXT:      [[GEMM:%[^ ]+]] = f32[48,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK-NEXT:      [[SLICE:%[^ ]+]] = f32[45,31]{1,0} slice([[GEMM]]), slice={[0:45], [0:31]}
-; CHECK-NEXT:      ROOT [[OUT:%[^ ]+]] = f32[3,15,31]{2,1,0} bitcast([[SLICE]])
-      )");
-}
-
-// Do not fuse matrix bias when there is a slice that does not chop off the ends
-// of dimensions.
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       ScaledABUnscaledDMatrixBiasWithSliceF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-    ENTRY test {
-      x = <<F8E4M3>>[48,16] parameter(0)
-      y = <<F8E4M3>>[16,32] parameter(1)
-      b = f32[32,16] parameter(2)
-      x_f32 = f32[48,16] convert(x)
-      y_f32 = f32[16,32] convert(y)
-      x_scale = f32[] parameter(3)
-      y_scale = f32[] parameter(4)
-      x_scale_bcast = f32[48,16] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[48,16] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
-      dot_a = f32[48,32] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a_sliced = f32[32,16] slice(dot_a), slice={[16:48], [16:32]}
-      ROOT out = f32[32,16] add(dot_a_sliced, b)
-          }
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                    /*f8_rewrite=*/true);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_TRUE(changed);
-
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[48,16], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[32,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[48,16]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:           [[GEMM:%[^_]+]] = f32[48,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-; CHECK-NEXT:      [[SLICE:%[^ ]+]] = f32[32,16]{1,0} slice([[GEMM]]), slice={[16:48], [16:32]}
-; CHECK-NEXT:      [[B:%[^ ]+]] = f32[32,16]{1,0} parameter(2)
-; CHECK-NEXT:      ROOT [[OUT:%[^ ]+]] = f32[32,16]{1,0} add([[SLICE]], [[B]])
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
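-  // After the rewrite, the all-gathers are expected to operate directly on
-  // the FP8 parameters and feed the FP8 custom call, with the scales passed
-  // as operands of the call; see the FileCheck pattern below.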
-  absl::string_view hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[16,32] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[16,32] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
-      all_gather = f32[16,64]{1,0} all-gather(x_unscaled), channel_id=1, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-      all_gather1 = f32[64,32]{1,0} all-gather(y_unscaled), channel_id=2, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true
-      ROOT dot_a = f32[16,32] dot(all_gather, all_gather1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-)";
-
-  HloModuleConfig config = GetModuleConfigForTest();
-  config.set_use_spmd_partitioning(true);
-  config.set_num_partitions(8);
-
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,32] {
-; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK:         [[AG:%[^ ]+]] = <<F8E4M3>>[16,64]{1,0} all-gather([[P0]]), {{[^ ]+}}
-; CHECK:         [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK:         [[AG1:%[^ ]+]] = <<F8E4M3>>[64,32]{1,0} all-gather([[P1]]), {{[^ ]+}}
-; CHECK:         [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,64]{1,0} transpose([[AG1]]), dimensions={1,0}
-; CHECK:         [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK:         [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK:         [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK:         [[GEMM_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AG]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:           ROOT [[GEMM:%[^_]+]] = f32[16,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
-      )",
-      nullptr, &config);
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  absl::string_view hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[16,32] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[16,32] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
-      all_to_all = f32[16,32]{1,0} all-to-all(x_unscaled), channel_id=1, replica_groups={{0,1,2,3},{4,5,6,7}}, dimensions={0}
-      ROOT dot_a = f32[16,16] dot(all_to_all, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-          }
-)";
-
-  HloModuleConfig config = GetModuleConfigForTest();
-  config.set_use_spmd_partitioning(true);
-  config.set_num_partitions(8);
-
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK:         [[AA:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} all-to-all([[P0]]), {{[^ ]+}}
-; CHECK:         [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK:         [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK:         [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK:         [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK:         [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )",
-      nullptr, &config);
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       ScaledABUnscaledDWithCollectivePermuteF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  absl::string_view hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[16,32] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[16,32] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
-      collective_permute = f32[16,32]{1,0} collective-permute(x_unscaled), source_target_pairs={{0,0}, {1,1}, {2,4}, {3,5}, {4,2}, {5,3}, {6,6}, {7,7}}
-      ROOT dot_a = f32[16,16] dot(collective_permute, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-          }
-)";
-
-  HloModuleConfig config = GetModuleConfigForTest();
-  config.set_use_spmd_partitioning(true);
-  config.set_num_partitions(8);
-
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK:         [[AA:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} collective-permute([[P0]]), {{[^ ]+}}
-; CHECK:         [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
-; CHECK:         [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK:         [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK:         [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK:         [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )",
-      nullptr, &config);
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       ScaledABUnscaledDMatrixBiasThenVectorBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
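-  // The matrix bias b2 is expected to be fused into the FP8 custom call as
-  // the C operand (beta = 1), while the vector bias b remains a separate
-  // broadcast + add on the GEMM output.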
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f16 = f16[16,32] convert(x)
-      y_f16 = f16[32,16] convert(y)
-      b = f16[16] parameter(2)
-      b_bcast = f16[16,16] broadcast(b), dimensions={1}
-      b2 = f16[16,16] parameter(3)
-      x_scale = f16[] parameter(4)
-      y_scale = f16[] parameter(5)
-      x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
-      y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
-      dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      dot_a_bias1 = f16[16,16] add(dot_a, b2)
-      ROOT dot_a_bias = f16[16,16] add(dot_a_bias1, b_bcast)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.});
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL:   ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] {
-; CHECK-DAG:     [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
-; CHECK-NEXT:    [[MB:%[^ ]+]] = f16[16,16]{1,0} parameter(3)
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(4)
-; CHECK-NEXT:    [[CV0:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f16[] parameter(5)
-; CHECK-NEXT:    [[CV1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK:         [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK:         [[GEMMOUT_TUPLE:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[MB]], [[CV0]], [[CV1]], /*index=5*/[[C1]], [[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-; CHECK:         [[GEMMOUT:%[^ ]+]] = f16[16,16]{1,0} get-tuple-element([[GEMMOUT_TUPLE]]), index=0
-; CHECK:         [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
-; CHECK:         [[VBC:%[^ ]+]] = f16[16,16]{1,0} broadcast([[VB]]), dimensions={1}
-; CHECK:         ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} add([[GEMMOUT]], [[VBC]])
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
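-  // The HLO below computes the absolute maximum of the GEMM output and
-  // requantizes the output to FP8 (divide by z_scale, clamp to the E4M3
-  // range). On CUDA both the D scaling and the amax computation are expected
-  // to be fused into the cuBLASLt call (note the extra f32[] element in the
-  // result tuple); on ROCm only the unscaled f32 GEMM is rewritten.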
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] maximum(a, b)
-    }
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      z_scale = f32[] parameter(4)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      abs_dot_a = f32[16,16] abs(dot_a)
-      c0 = f32[] constant(-inf)
-      amax = f32[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply
-      dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
-      c1 = f32[] constant(-<<F8E4M3_AMAX>>)
-      c1_bcast = f32[16,16] broadcast(c1), dimensions={}
-      c2 = f32[] constant(<<F8E4M3_AMAX>>)
-      c2_bcast = f32[16,16] broadcast(c2), dimensions={}
-      dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
-      dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
-      ROOT out = (<<F8E4M3>>[16,16], f32[]) tuple(dot_a_f8, amax)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<<F8E4M3>>[16,16], f32[]) {
-; CHECK:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT:    [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT:    [[P4:%[^ ]+]] = f32[] parameter(4)
-; CHECK-PTX-NEXT:    [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
-; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       ScaledABScaledDWithDAmaxF8WithF16Intermediates) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  // This is the same as ScaledABScaledDWithDAmaxF8, but uses F16 intermediate
-  // values instead of F32 intermediate values.
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f16[] parameter(0)
-      b = f16[] parameter(1)
-      ROOT c = f16[] maximum(a, b)
-    }
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f16 = f16[16,32] convert(x)
-      y_f16 = f16[32,16] convert(y)
-      x_scale = f16[] parameter(2)
-      y_scale = f16[] parameter(3)
-      z_scale = f16[] parameter(4)
-      x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
-      z_scale_bcast = f16[16,16] broadcast(z_scale), dimensions={}
-      x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
-      y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
-      dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      abs_dot_a = f16[16,16] abs(dot_a)
-      c0 = f16[] constant(-inf)
-      amax = f16[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply
-      dot_a_scaled = f16[16,16] divide(dot_a, z_scale_bcast)
-      c1 = f16[] constant(-<<F8E4M3_AMAX>>)
-      c1_bcast = f16[16,16] broadcast(c1), dimensions={}
-      c2 = f16[] constant(<<F8E4M3_AMAX>>)
-      c2_bcast = f16[16,16] broadcast(c2), dimensions={}
-      dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
-      dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
-      ROOT out = (<<F8E4M3>>[16,16], f16[]) tuple(dot_a_f8, amax)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<<F8E4M3>>[16,16], f16[]) {
-; CHECK:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(2)
-; CHECK-NEXT:    [[P2_CONVERT:%[^ ]+]] = f32[] convert([[P2]])
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f16[] parameter(3)
-; CHECK-NEXT:    [[P3_CONVERT:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT:    [[C2:%[^ ]+]] = f16[] constant(1)
-; CHECK-PTX-NEXT:    [[P4:%[^ ]+]] = f16[] parameter(4)
-; CHECK-PTX-NEXT:    [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT:    [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]])
-; CHECK-PTX-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[P4_INV_CONVERT]]),
-; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       ScaledABScaledDReluActivationWithDAmaxF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    apply {
-      a = f32[] parameter(0)
-      b = f32[] parameter(1)
-      ROOT c = f32[] maximum(a, b)
-    }
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E4M3>>[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      z_scale = f32[] parameter(4)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-      czero = f32[] constant(0)
-      czero_bcast = f32[16,16] broadcast(czero), dimensions={}
-      dot_a_relu = f32[16,16] maximum(dot_a, czero_bcast)
-      c0 = f32[] constant(-inf)
-      amax = f32[] reduce(dot_a_relu, c0), dimensions={0,1}, to_apply=apply
-      dot_a_scaled = f32[16,16] divide(dot_a_relu, z_scale_bcast)
-      c1 = f32[] constant(-<<F8E4M3_AMAX>>)
-      c1_bcast = f32[16,16] broadcast(c1), dimensions={}
-      c2 = f32[] constant(<<F8E4M3_AMAX>>)
-      c2_bcast = f32[16,16] broadcast(c2), dimensions={}
-      dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
-      dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
-      ROOT out = (<<F8E4M3>>[16,16], f32[]) tuple(dot_a_f8, amax)
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<<F8E4M3>>[16,16], f32[]) {
-; CHECK:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
-; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
-; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT:    [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT:    [[P4:%[^ ]+]] = f32[] parameter(4)
-; CHECK-PTX-NEXT:    [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
-; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
-; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"RELU"
-; CHECK:           }
-      )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
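-  // Run the same FP8 GEMM with operand_precision set to "default" and to
-  // "highest"; the "highest" variant is expected to match the reference more
-  // closely, hence the tighter error tolerance.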
-  const char* raw_hlo_template = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[1600,3200] parameter(0)
-      y = <<F8E4M3>>[3200,1600] parameter(1)
-      x_f32 = f32[1600,3200] convert(x)
-      y_f32 = f32[3200,1600] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[1600,3200] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[3200,1600] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[1600,3200] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[3200,1600] multiply(y_f32, y_scale_bcast)
-      ROOT out = f32[1600,1600] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={<<precision>>,<<precision>>}
-          }
-)";
-
-  std::string hlo_template =
-      absl::StrReplaceAll(raw_hlo_template, replacements_);
-
-  absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
-  replacements["<<precision>>"] = "default";
-  const auto hlo_text_default = absl::StrReplaceAll(hlo_template, replacements);
-  EXPECT_TRUE(RunAndCompare(hlo_text_default, ErrorSpec{1e-3, 1e-3}));
-
-  replacements["<<precision>>"] = "highest";
-  const auto hlo_text_highest = absl::StrReplaceAll(hlo_template, replacements);
-  EXPECT_TRUE(RunAndCompare(hlo_text_highest, ErrorSpec{1e-4, 1e-4}));
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  std::array<std::array<absl::string_view, 7>, 32> combinations;
-  int i = 0;
-
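-  // Enumerate all 2^5 = 32 combinations of output layout, operand layouts,
-  // and contracting dimensions.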
-  for (bool d_is_col : {false, true}) {
-    for (bool a_is_col : {false, true}) {
-      for (bool b_is_col : {false, true}) {
-        for (int lhs_contracting_dim : {0, 1}) {
-          for (int rhs_contracting_dim : {0, 1}) {
-            const absl::string_view lcd =
-                lhs_contracting_dim == 1 ? "{1}" : "{0}";
-            const absl::string_view rcd =
-                rhs_contracting_dim == 1 ? "{1}" : "{0}";
-            const absl::string_view a_shape =
-                lhs_contracting_dim == 1 ? "[64,32]" : "[32,64]";
-            const absl::string_view b_shape =
-                rhs_contracting_dim == 0 ? "[32,16]" : "[16,32]";
-            const absl::string_view a_layout = a_is_col ? "{0,1}" : "{1,0}";
-            const absl::string_view b_layout = b_is_col ? "{0,1}" : "{1,0}";
-            const absl::string_view output_layout =
-                d_is_col ? "{0,1}" : "{1,0}";
-            combinations[i++] = std::array{
-                lcd, rcd, a_shape, b_shape, a_layout, b_layout, output_layout};
-          }
-        }
-      }
-    }
-  }
-
-  const char* hlo_template = R"(
-      HloModule test
-    ENTRY test {
-      x = <<F8E4M3>><<Ashape>><<Alayout>> parameter(0)
-      x_f32 = f32<<Ashape>><<Alayout>> convert(x)
-      x_scale = f32[] parameter(2)
-      x_scale_bcast = f32<<Ashape>> broadcast(x_scale), dimensions={}
-      x_unscaled = f32<<Ashape>> multiply(x_f32, x_scale_bcast)
-      y = <<F8E4M3>><<Bshape>><<Blayout>> parameter(1)
-      y_f32 = f32<<Bshape>><<Blayout>> convert(y)
-      y_scale = f32[] parameter(3)
-      y_scale_bcast = f32<<Bshape>> broadcast(y_scale), dimensions={}
-      y_unscaled = f32<<Bshape>> multiply(y_f32, y_scale_bcast)
-      ROOT out = f32[64,16]<<Olayout>> dot(x_unscaled, y_unscaled), lhs_contracting_dims=<<Lcd>>, rhs_contracting_dims=<<Rcd>>
-    }
-      )";
-  for (const auto& combination : combinations) {
-    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
-    replacements["<<Lcd>>"] = std::get<0>(combination);
-    replacements["<<Rcd>>"] = std::get<1>(combination);
-    replacements["<<Ashape>>"] = std::get<2>(combination);
-    replacements["<<Bshape>>"] = std::get<3>(combination);
-    replacements["<<Alayout>>"] = std::get<4>(combination);
-    replacements["<<Blayout>>"] = std::get<5>(combination);
-    replacements["<<Olayout>>"] = std::get<6>(combination);
-    const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
-    CheckFp8IfSupported(hlo_text);
-
-    RunAndFilecheckHloRewrite(
-        hlo_text,
-        GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                     /*f8_rewrite=*/true),
-        R"(
-    ; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-          )");
-  }
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest,
-       ScaledABUnscaledDF8ParameterizedBatched) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  // TODO(wenscarl): For batched matmuls, not all combinations of A, B, and
-  // output layouts are successfully pattern-matched to the FP8 custom call.
-  // Only a handful of cases are tested here.
-  std::array<std::array<std::string, 7>, 32> combinations;
-  std::string lcd, rcd, a_shape, b_shape, a_layout, b_layout, o_layout;
-  int i = 0;
-  for (bool o_is_col : {false, true}) {
-    for (int lhs_contracting_dim : {2, 1}) {
-      for (int rhs_contracting_dim : {2, 1}) {
-        lcd = lhs_contracting_dim == 2 ? "{2}" : "{1}";
-        rcd = rhs_contracting_dim == 2 ? "{2}" : "{1}";
-        a_shape = lhs_contracting_dim == 2 ? "[2,64,32]" : "[2,32,64]";
-        b_shape = rhs_contracting_dim == 1 ? "[2,32,16]" : "[2,16,32]";
-        o_layout = o_is_col ? "{2, 0, 1}" : "{2, 1, 0}";
-        for (std::string a_layout : {"{2,1,0}", "{1,2,0}"}) {
-          for (std::string b_layout : {"{2,1,0}", "{1,2,0}"}) {
-            combinations[i++] = std::array{lcd,      rcd,      a_shape, b_shape,
-                                           a_layout, b_layout, o_layout};
-          }
-        }
-      }
-    }
-  }
-
-  const char* hlo_template = R"(
-      HloModule m
-ENTRY f {
-  x_q = <<F8E4M3>><<Ashape>><<Alayout>> parameter(0)
-  x_scale = f32[] parameter(2)
-  x_scale_broadcast = f32<<Ashape>><<Alayout>> broadcast(x_scale), dimensions={}
-  x_q_convert = f32<<Ashape>><<Alayout>> convert(x_q)
-  x_qdq = f32<<Ashape>><<Alayout>> multiply(x_q_convert, x_scale_broadcast)
-
-  y_q = <<F8E4M3>><<Bshape>><<Blayout>> parameter(1)
-  y_scale = f32[] parameter(3)
-  y_scale_broadcast = f32<<Bshape>><<Blayout>> broadcast(y_scale), dimensions={}
-  y_q_convert = f32<<Bshape>><<Blayout>> convert(y_q)
-  y_qdq = f32<<Bshape>><<Blayout>> multiply(y_q_convert, y_scale_broadcast)
-
-  ROOT out = f32[2,64,16]<<Olayout>> dot(x_qdq, y_qdq), lhs_batch_dims={0}, lhs_contracting_dims=<<Lcd>>, rhs_batch_dims={0}, rhs_contracting_dims=<<Rcd>>
-}
-     )";
-
-  for (const auto& combination : combinations) {
-    absl::flat_hash_map<std::string, std::string> replacements;
-    replacements["<<Lcd>>"] = std::get<0>(combination);
-    replacements["<<Rcd>>"] = std::get<1>(combination);
-    replacements["<<Ashape>>"] = std::get<2>(combination);
-    replacements["<<Bshape>>"] = std::get<3>(combination);
-    replacements["<<Alayout>>"] = std::get<4>(combination);
-    replacements["<<Blayout>>"] = std::get<5>(combination);
-    replacements["<<Olayout>>"] = std::get<6>(combination);
-
-    const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
-    CheckFp8IfSupported(hlo_text);
-
-    RunAndFilecheckHloRewrite(
-        hlo_text,
-        GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                     /*f8_rewrite=*/true),
-        R"(
-    ; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-          )");
-  }
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = <<F8E4M3>>[16,32] parameter(0)
-      y = <<F8E5M2>>[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-
-)";
-
-  CheckFp8IfSupported(hlo_text);
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-    ; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
-          )");
-}
-
-TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
-#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
-#endif  // TF_ROCM_VERSION < 60000
-
-  // Test that FNUZ FP8 gemms are not rewritten on CUDA, as cuBLAS does not
-  // support them. On ROCm, FNUZ is the native FP8 format and the rewrite is
-  // expected to succeed.
-  const char* hlo_text = R"(
-    HloModule test
-
-    ENTRY test {
-      x = f8e4m3fnuz[16,32] parameter(0)
-      y = f8e4m3fnuz[32,16] parameter(1)
-      x_f32 = f32[16,32] convert(x)
-      y_f32 = f32[32,16] convert(y)
-      x_scale = f32[] parameter(2)
-      y_scale = f32[] parameter(3)
-      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
-      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
-      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
-      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
-      ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-          }
-)";
-#if GOOGLE_CUDA
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                    /*f8_rewrite=*/true);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
-  EXPECT_FALSE(changed);
-#endif
-#if TENSORFLOW_USE_ROCM
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2}));
-  RunAndFilecheckHloRewrite(
-      hlo_text,
-      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
-                   /*f8_rewrite=*/true),
-      R"(
-; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fnuz[16,32], {{.*}}: f8e4m3fnuz[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f8e4m3fnuz[16,32]{1,0} parameter(0)
-; CHECK-PTX-NEXT:    [[P0_CV:%[^ ]+]] = f32[16,32]{1,0} convert([[P0]])
-; CHECK-PTX-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-PTX-NEXT:    [[P2_B:%[^ ]+]] = f32[16,32]{1,0} broadcast([[P2]]), dimensions={}
-; CHECK-PTX-NEXT:    [[P0_UNSCALED:%[^ ]+]] = f32[16,32]{1,0} multiply([[P0_CV]], [[P2_B]])
-; CHECK-PTX-NEXT:    [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1)
-; CHECK-PTX-NEXT:    [[P1_CV:%[^ ]+]] = f32[32,16]{1,0} convert([[P1]])
-; CHECK-PTX-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-PTX-NEXT:    [[P3_B:%[^ ]+]] = f32[32,16]{1,0} broadcast([[P3]]), dimensions={}
-; CHECK-PTX-NEXT:    [[P1_UNSCALED:%[^ ]+]] = f32[32,16]{1,0} multiply([[P1_CV]], [[P3_B]])
-; CHECK-PTX-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0_UNSCALED]], [[P1_UNSCALED]]),
-; CHECK-GCN-NEXT:    [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1)
-; CHECK-GCN-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
-; CHECK-GCN-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
-; CHECK-GCN-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-GCN-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
-; CHECK-GCN:           custom_call_target="__cublas$lt$matmul$f8",
-; CHECK:           backend_config={
-; CHECK-DAG:         "alpha_real":1
-; CHECK-DAG:         "alpha_imag":0
-; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{
-; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
-; CHECK-PTX-DAG:           "rhs_contracting_dimensions":["0"]
-; CHECK-GCN-DAG:           "rhs_contracting_dimensions":["1"]
-; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[]
-; CHECK-DAG:         }
-; CHECK-DAG:         "precision_config":{
-; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
-; CHECK-DAG:         }
-; CHECK-DAG:         "epilogue":"DEFAULT"
-; CHECK:           }
-      )");
-#endif
-}
-
-INSTANTIATE_TEST_SUITE_P(Fp8CublasTestsBothLegacyAndLt,
-                         ParameterizedFp8GemmRewriteTest, ::testing::Bool());
-#endif
-
-TEST_F(GemmRewriteTest, NoFuseBiasBroadcast) {
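-  // The bias operand is built from a reshape/broadcast/reshape chain; verify
-  // that it is not fused into the GEMM as a matrix bias (beta stays 0).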
-  const char* hlo = R"(
-
-HloModule module
-
-ENTRY main.10 {
-  Arg_0.1 = f16[384,128]{1,0} parameter(0)
-  Arg_1.2 = f16[128,256]{1,0} parameter(1)
-  dot.4 = f16[384,256]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  Arg_2.3 = f16[256]{0} parameter(2)
-  reshape.5 = f16[1,256]{1,0} reshape(Arg_2.3)
-  broadcast.6 = f16[1,256]{1,0} broadcast(reshape.5), dimensions={0,1}
-  reshape.7 = f16[256]{0} reshape(broadcast.6)
-  broadcast.8 = f16[384,256]{1,0} broadcast(reshape.7), dimensions={1}
-  ROOT add.9 = f16[384,256]{1,0} add(dot.4, broadcast.8)
-})";
-
-  MatchOptimizedHlo(hlo, R"(
-// CHECK: "beta":0
-  )");
-}
-
-TEST_F(GemmRewriteTest, ReduceOfBatchDot) {
-  absl::string_view hlo_string =
-      R"(
-HloModule test
-
-region_5.50 {
-  Arg_0.51 = f32[] parameter(0)
-  Arg_1.52 = f32[] parameter(1)
-  ROOT add.53 = f32[] add(Arg_0.51, Arg_1.52)
-}
-
-ENTRY main {
-  p0 = bf16[3,32,3,13]{3,2,1,0} parameter(0)
-  p1 = bf16[3,32,3,64]{3,2,1,0} parameter(1)
-  dot.95 = bf16[3,3,13,64]{3,2,1,0} dot(p0, p1), lhs_batch_dims={0,2}, lhs_contracting_dims={1}, rhs_batch_dims={0,2}, rhs_contracting_dims={1}, operand_precision={highest,highest}
-  transpose.96 = bf16[3,64,3,13]{1,3,2,0} transpose(dot.95), dimensions={0,3,1,2}
-  convert.101 = f32[3,64,3,13]{1,3,2,0} convert(transpose.96)
-  constant.66 = f32[] constant(0.0)
-  ROOT reduce.102 = f32[3,64,13]{2,1,0} reduce(convert.101, constant.66), dimensions={2}, to_apply=region_5.50
-}
-)";
-  // Make sure the dot is lowered to a custom call. An algebraic-simplifier
-  // rewrite late in the pipeline could turn the dot into a non-canonical dot,
-  // which would make it unsupported by the GemmRewriter.
-  MatchOptimizedHlo(hlo_string, R"(
-  // CHECK: custom_call_target="__cublas$gemm"
-  )");
-}
-
-class GemmRewriteAllocationTest : public GpuCodegenTest {
- public:
-  void CheckNumberOfAllocations(const std::string& hlo,
-                                int expected_number_of_allocations) {
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
-                            GetOptimizedModule(hlo));
-    if (allocator_ == nullptr) {
-      allocator_ = std::make_unique<se::StreamExecutorMemoryAllocator>(
-          backend().default_stream_executor());
-    }
-    TF_ASSERT_OK_AND_ASSIGN(
-        std::unique_ptr<Executable> executable,
-        backend().compiler()->RunBackend(std::move(optimized_module),
-                                         backend().default_stream_executor(),
-                                         allocator_.get()));
-    GpuExecutable* gpu_executable =
-        static_cast<GpuExecutable*>(executable.get());
-    absl::Span<const BufferAllocation> allocations =
-        gpu_executable->GetAllocations();
-    ASSERT_EQ(allocations.size(), expected_number_of_allocations);
-  }
-
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
-    // Make sure the rewriter does not skip the rewrite because the GEMM is
-    // too small.
-    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
-    return debug_options;
-  }
-
- private:
-  std::unique_ptr<se::DeviceMemoryAllocator> allocator_;
-};
-
-TEST_F(GemmRewriteAllocationTest, SharedBufferAssignment) {
-  const char* hlo_text = R"(
-HloModule SharedBufferAssignment
-
-ENTRY AddDotsFunc {
-  x = f32[2,2] parameter(0)
-  y = f32[2,2] parameter(1)
-  bias = f32[2,2] add(x, y)
-  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[2,2] add(dot, bias)
-}
-
-)";
-
-  // Bias should be fused into the multiplication.
-  CheckNumberOfAllocations(hlo_text, 4);
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-}
-
-class SmallDotGemmRewriteTest : public GemmRewriteTest {
- public:
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
-    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100);
-    return debug_options;
-  }
-};
-
-TEST_F(SmallDotGemmRewriteTest, SkipSmallMatrixMultiplicationRewrite) {
-  const char* hlo_text = R"(
-HloModule SkipSmallMatrixRewrite
-
-ENTRY DotFunc {
-  x = f32[3,3] parameter(0)
-  y = f32[3,3] parameter(1)
-  ROOT out = f32[3,3] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %DotFunc ({{.*}}: f32[3,3], {{.*}}: f32[3,3]) -> f32[3,3] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,3]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,3]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} dot([[P0]], [[P1]]),
-; CHECK:           lhs_contracting_dims={1}, rhs_contracting_dims={0}
-)");
-}
-
-TEST_F(SmallDotGemmRewriteTest, LargeMatrixMultiplicationIsRewritten) {
-  const char* hlo_text = R"(
-HloModule SkipSmallMatrixRewrite
-
-ENTRY DotFunc {
-  x = f32[8,8] parameter(0)
-  y = f32[8,8] parameter(1)
-  ROOT out = f32[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-}
-)";
-
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %DotFunc ({{.*}}: f32[8,8], {{.*}}: f32[8,8]) -> f32[8,8] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[8,8]{1,0} parameter(1)
-; CHECK:         {{[^ ]+}} = {{.*}} custom-call([[P0]], [[P1]])
-)");
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc
deleted file mode 100644
index 5db5ffd..0000000
--- a/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc
+++ /dev/null
@@ -1,232 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_all_gather_optimizer.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <utility>
-
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-class GpuAllGatherOptimizerTest : public HloTestBase {
- public:
-  absl::StatusOr<std::unique_ptr<HloModule>> RunPass(
-      absl::string_view hlo_module, int64_t num_replicas,
-      int64_t num_partitions, bool expect_change) {
-    HloModuleConfig config = GetModuleConfigForTest(
-        /*replica_count=*/num_replicas,
-        /*num_partitions=*/num_partitions);
-    config.set_use_spmd_partitioning(num_partitions > 1);
-    TF_ASSIGN_OR_RETURN(auto module,
-                        ParseAndReturnVerifiedModule(hlo_module, config));
-
-    auto changed = AllGatherOptimizer().Run(module.get());
-    if (!changed.ok()) {
-      return changed.status();
-    }
-    EXPECT_EQ(changed.value(), expect_change);
-    return absl::StatusOr<std::unique_ptr<HloModule>>(std::move(module));
-  }
-
-  template <HloOpcode oc>
-  size_t CollectiveCount(std::unique_ptr<HloModule> &module) {
-    return absl::c_count_if(module->entry_computation()->instructions(),
-                            HloPredicateIsOp<oc>);
-  }
-};
-
-TEST_F(GpuAllGatherOptimizerTest, BranchesOptimized) {
-  absl::string_view hlo_string = R"(
-HloModule ReduceScatter
-
-add {
-  x = bf16[] parameter(0)
-  y = bf16[] parameter(1)
-  ROOT add = bf16[] add(x, y)
-}
-
-ENTRY main {
-param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
-param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
-reduce-scatter.1 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.1), channel_id=8, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather.1 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-reduce-scatter.2 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.2), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather.2 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-add.1 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.2)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/1,
-                                               /*expect_change=*/true));
-  // The optimized graph should contain a single all-gather, but since node
-  // removal is deferred, all three all-gathers still exist at this stage.
-  EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 3);
-  EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 2);
-}
-
-TEST_F(GpuAllGatherOptimizerTest, DisbledSPMDPartitioningJAXBug) {
-  absl::string_view hlo_string = R"(
-HloModule pjit_f, entry_computation_layout={(f32[4,8]{1,0}, f32[4,8]{1,0})->f32[8,8]{1,0}}
-
-ENTRY %main.6_spmd (param: f32[4,8], param.1: f32[4,8]) -> f32[8,8] {
-  %param = f32[4,8]{1,0} parameter(0), sharding={devices=[2,1]<=[2]}
-  %all-gather = f32[8,8]{1,0} all-gather(f32[4,8]{1,0} %param), channel_id=1, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(f)/jit(main)/add" source_file="third_party/py/jax/tests/pjit_test.py" source_line=207}
-  %param.1 = f32[4,8]{1,0} parameter(1), sharding={devices=[2,1]<=[2]}
-  %all-gather.1 = f32[8,8]{1,0} all-gather(f32[4,8]{1,0} %param.1), channel_id=2, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(f)/jit(main)/add" source_file="third_party/py/jax/tests/pjit_test.py" source_line=207}
-  ROOT %add.0 = f32[8,8]{1,0} add(f32[8,8]{1,0} %all-gather, f32[8,8]{1,0} %all-gather.1), metadata={op_name="pjit(f)/jit(main)/add" source_file="third_party/py/jax/tests/pjit_test.py" source_line=207}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/1,
-                                               /*num_partitions=*/2,
-                                               /*expect_change=*/true));
-  EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 1);
-}
-
-TEST_F(GpuAllGatherOptimizerTest, MoreThanSingleUserForAllGather) {
-  absl::string_view hlo_string = R"(
-HloModule ReduceScatter
-
-add {
-  x = bf16[] parameter(0)
-  y = bf16[] parameter(1)
-  ROOT add = bf16[] add(x, y)
-}
-
-ENTRY main {
-param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
-param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
-param.3 = bf16[8,128,1024]{2,1,0} parameter(2)
-reduce-scatter.1 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.1), channel_id=8, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather.1 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-reduce-scatter.2 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.2), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather.2 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-reduce-scatter.3 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.3), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather.3 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.3), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-add.1 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.3)
-add.2 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.2)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/1,
-                                               /*expect_change=*/false));
-  // See the comment in the BranchesOptimized test.
-  EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 3);
-  EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 3);
-}
-
-TEST_F(GpuAllGatherOptimizerTest, AllGatherWithOpInBetweenOnRightBranch) {
-  absl::string_view hlo_string = R"(
-HloModule ReduceScatter
-
-add {
-  x = bf16[] parameter(0)
-  y = bf16[] parameter(1)
-  ROOT add = bf16[] add(x, y)
-}
-
-ENTRY main {
-param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
-param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
-param.3 = bf16[8,128,1024]{2,1,0} parameter(2)
-reduce-scatter.1 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.1), channel_id=8, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-reduce-scatter.2 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.2), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-add.1 = bf16[8,64,1024]{2,1,0} add(reduce-scatter.1, reduce-scatter.2)
-all-gather.1 = bf16[8,128,1024]{2,1,0} all-gather(add.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-reduce-scatter.3 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.3), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather.3 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.3), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-add.2 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.3)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/1,
-                                               /*expect_change=*/true));
-  EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 3);
-  EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 3);
-}
-
-TEST_F(GpuAllGatherOptimizerTest, AllGatherOneSided) {
-  absl::string_view hlo_string = R"(
-HloModule ReduceScatter
-
-add {
-  x = bf16[] parameter(0)
-  y = bf16[] parameter(1)
-  ROOT add = bf16[] add(x, y)
-}
-
-ENTRY main {
-param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
-param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
-param.3 = bf16[8,128,1024]{2,1,0} parameter(2)
-
-add.1 = bf16[8,128,1024]{2,1,0} add(param.1, param.2)
-reduce-scatter = bf16[8,64,1024]{2,1,0} reduce-scatter(param.3), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
-all-gather = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-add.2 = bf16[8,128,1024]{2,1,0} add(all-gather, add.1)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/1,
-                                               /*expect_change=*/false));
-  EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 1);
-  EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 1);
-}
-
-TEST_F(GpuAllGatherOptimizerTest, DifferentOperandShapes) {
-  absl::string_view hlo_string = R"(
-HloModule TestModule
-
-ENTRY main {
-param.1 = bf16[8,64,128]{2,1,0} parameter(0)
-param.2 = bf16[8,128,64]{2,1,0} parameter(1)
-all-gather.1 = bf16[8,128,128]{2,1,0} all-gather(param.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
-all-gather.2 = bf16[8,128,128]{2,1,0} all-gather(param.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
-add.1 = bf16[8,128,128]{2,1,0} add(all-gather.1, all-gather.2)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/1,
-                                               /*expect_change=*/false));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc
index 2e5db53..b4124f2 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc
@@ -24,7 +24,7 @@
 #include "xla/error_spec.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/primitive_util.h"
-#include "xla/service/gpu/gpu_sort_rewriter.h"
+#include "xla/service/gpu/transforms/sort_rewriter.h"
 #include "xla/tests/hlo_test_base.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/statusor.h"
@@ -35,7 +35,7 @@
 
 bool HloWasRewrittenToUseCubSort(const HloModule& module) {
   for (const auto& pass_metadata : module.metadata().proto().pass_metadata()) {
-    if (pass_metadata.pass_name() == "gpu-sort-rewriter") {
+    if (pass_metadata.pass_name() == "sort-rewriter") {
       return pass_metadata.module_changed();
     }
   }
@@ -50,13 +50,13 @@
  public:
   void SetUp() override {
     HloTestBase::SetUp();
-    GpuSortRewriter::SetSortSizeThresholdForTestingOnly(33000);
+    SortRewriter::SetSortSizeThresholdForTestingOnly(33000);
   }
 };
 
 TEST_P(CubSortKeysTest, CompareToReference) {
   int batch_size = std::get<2>(GetParam());
-  int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size;
+  int segment_size = SortRewriter::SortSizeThreshold() / batch_size;
 
   const char* kHloTpl = R"(
 HloModule TestSortKeys
@@ -103,7 +103,7 @@
 })";
 
   int batch_size = std::get<2>(GetParam());
-  int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size;
+  int segment_size = SortRewriter::SortSizeThreshold() / batch_size;
   std::string hlo_str = absl::Substitute(
       kHloTpl,
       primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())),
@@ -138,13 +138,13 @@
  public:
   void SetUp() override {
     HloTestBase::SetUp();
-    GpuSortRewriter::SetSortSizeThresholdForTestingOnly(33000);
+    SortRewriter::SetSortSizeThresholdForTestingOnly(33000);
   }
 };
 
 TEST_P(CubSortPairsTest, CompareToReference) {
   int batch_size = std::get<3>(GetParam());
-  int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size;
+  int segment_size = SortRewriter::SortSizeThreshold() / batch_size;
 
   const char* kHloTpl = R"(
 HloModule TestSortPairs
@@ -216,7 +216,7 @@
 })";
 
   int batch_size = std::get<3>(GetParam());
-  int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size;
+  int segment_size = SortRewriter::SortSizeThreshold() / batch_size;
   std::string hlo_str = absl::Substitute(
       kHloTpl,
       primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())),
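Note on the rename above: the gpu_cub_sort_test.cc hunks only swap GpuSortRewriter for SortRewriter and the pass-metadata id "gpu-sort-rewriter" for "sort-rewriter"; the detection logic itself is unchanged. A minimal sketch of that detection pattern in isolation, reusing only the metadata accessors shown in the hunk above and assuming the same includes as the test file (the helper name PassChangedModule is illustrative):

// Scans the module's recorded pass metadata and reports whether the named
// pass ran and changed the module; mirrors HloWasRewrittenToUseCubSort above.
bool PassChangedModule(const HloModule& module, absl::string_view pass_name) {
  for (const auto& pass_metadata : module.metadata().proto().pass_metadata()) {
    if (pass_metadata.pass_name() == pass_name) {
      return pass_metadata.module_changed();
    }
  }
  return false;  // The pass left no metadata, so it did not rewrite anything.
}

// Example use with the renamed pass id from this diff:
//   EXPECT_TRUE(PassChangedModule(*module, "sort-rewriter"));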
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc
index 639cf51..aed017c 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc
@@ -134,7 +134,7 @@
     EXPECT_TRUE(
         LiteralTestUtil::Near(expected_result, actual_result, mha_error_spec_));
 
-    // Run FusedMHA/FuseMHABackward thunk through command buffer
+    // Run through command buffer
     DebugOptions debug_options = GetDebugOptionsForTest();
     debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
     debug_options.set_xla_gpu_graph_min_graph_size(1);
@@ -393,8 +393,8 @@
   void TestImpl_Flash_Attention_BMM1_CausalMask_Softmax_BMM2() {
     if (skip_reason_) GTEST_SKIP() << *skip_reason_;
     if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
-        se::dnn::VersionInfo(8, 9, 4)) {
-      GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4.";
+        se::dnn::VersionInfo(9, 0, 0)) {
+      GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
     }
     XlaBuilder builder(TestName());
     auto lhs_bmm1_literal =
@@ -414,8 +414,8 @@
   void TestImpl_Flash_Attention_Training_BMM1_CausalMask_Softmax_BMM2() {
     if (skip_reason_) GTEST_SKIP() << *skip_reason_;
     if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
-        se::dnn::VersionInfo(8, 9, 4)) {
-      GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4.";
+        se::dnn::VersionInfo(9, 0, 0)) {
+      GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
     }
     XlaBuilder builder(TestName());
     auto lhs_bmm1_literal =
@@ -709,8 +709,8 @@
   void TestImpl_Flash_Attention_BMM1_Bias_Softmax_BMM2() {
     if (skip_reason_) GTEST_SKIP() << *skip_reason_;
     if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
-        se::dnn::VersionInfo(8, 9, 4)) {
-      GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4.";
+        se::dnn::VersionInfo(9, 0, 0)) {
+      GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
     }
     XlaBuilder builder(TestName());
     auto lhs_bmm1_literal =
@@ -731,8 +731,8 @@
   void TestImpl_Flash_Attention_Training_BMM1_Bias_Softmax_BMM2() {
     if (skip_reason_) GTEST_SKIP() << *skip_reason_;
     if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
-        se::dnn::VersionInfo(8, 9, 4)) {
-      GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4.";
+        se::dnn::VersionInfo(9, 0, 0)) {
+      GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
     }
     XlaBuilder builder(TestName());
     auto lhs_bmm1_literal =
@@ -756,9 +756,9 @@
   void TestImpl_Flash_Attention_BMM1_Bias_Softmax_BMM2_Cross_Attention() {
     if (skip_reason_) GTEST_SKIP() << *skip_reason_;
     if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
-        se::dnn::VersionInfo(8, 9, 4)) {
+        se::dnn::VersionInfo(9, 0, 0)) {
       GTEST_SKIP() << "Flash Attention cross attention requires "
-                      "cuDNN >= 8.9.4.";
+                      "cuDNN >= 9.0.0.";
     }
     XlaBuilder builder(TestName());
     auto lhs_bmm1_literal =
@@ -780,10 +780,10 @@
     if (skip_reason_) GTEST_SKIP() << *skip_reason_;
     auto cc = GetCudaComputeCapability();
     if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
-            se::dnn::VersionInfo(8, 9, 6) ||
+            se::dnn::VersionInfo(9, 0, 0) ||
         !cc.IsAtLeastHopper() || cc.minor != 0) {
       GTEST_SKIP()
-          << "Flash Attention dbias requires cuDNN >= 8.9.6 and Hopper arch.";
+          << "Flash Attention dbias requires cuDNN >= 9.0.0 and Hopper arch.";
     }
     XlaBuilder builder(TestName());
     auto lhs_bmm1_literal =
@@ -900,8 +900,8 @@
   void TestImpl_Flash_Attention_Training_BMM1_Softmax_BMM2() {
     if (skip_reason_) GTEST_SKIP() << *skip_reason_;
     if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
-        se::dnn::VersionInfo(8, 9, 4)) {
-      GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4.";
+        se::dnn::VersionInfo(9, 0, 0)) {
+      GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
     }
     XlaBuilder builder(TestName());
     auto lhs_bmm1_literal =
@@ -925,10 +925,10 @@
     if (skip_reason_) GTEST_SKIP() << *skip_reason_;
     auto cc = GetCudaComputeCapability();
     if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
-            se::dnn::VersionInfo(8, 9, 4) ||
+            se::dnn::VersionInfo(9, 0, 0) ||
         !cc.IsAtLeastHopper() || cc.minor != 0) {
       GTEST_SKIP() << "Flash Attention deterministic kernels requires cuDNN >= "
-                      "8.9.4 and Hopper arch.";
+                      "9.0.0 and Hopper arch.";
     }
     XlaBuilder builder(TestName());
     auto lhs_bmm1_literal =
@@ -1085,8 +1085,8 @@
   void TestImpl_Flash_Attention_Training_BMM1_PaddingMask_Softmax_BMM2() {
     if (skip_reason_) GTEST_SKIP() << *skip_reason_;
     if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
-        se::dnn::VersionInfo(8, 9, 3)) {
-      GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.3.";
+        se::dnn::VersionInfo(9, 0, 0)) {
+      GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
     }
     XlaBuilder builder(TestName());
     // pass padding mask as bias
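All of the gpu_fused_mha_test.cc hunks update one and the same guard: each TestImpl_Flash_Attention_* body skips itself when the detected cuDNN version is too old, and the minimum is now uniformly 9.0.0 instead of the earlier 8.9.x values. A hedged sketch of a shared form of that guard, assuming the same fixture context as the tests above (CudnnAtLeast is a hypothetical helper and the se::StreamExecutor* parameter type is assumed; GetDnnVersionInfoOrDefault, se::dnn::VersionInfo, and backend().default_stream_executor() are the calls the tests already use):

// Hypothetical shared guard for the Flash Attention tests above. Returns true
// when the detected cuDNN version meets the given minimum.
bool CudnnAtLeast(se::StreamExecutor* executor, int major, int minor,
                  int patch) {
  // Only operator< is needed, which the existing guards already rely on.
  return !(GetDnnVersionInfoOrDefault(executor) <
           se::dnn::VersionInfo(major, minor, patch));
}

// Usage inside a test body:
//   if (!CudnnAtLeast(backend().default_stream_executor(), 9, 0, 0)) {
//     GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
//   }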
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc
index 3e573eb..f1577f0 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc
@@ -18,11 +18,11 @@
 #include <utility>
 
 #include "absl/strings/string_view.h"
-#include "xla/service/gpu/fusion_merger.h"
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
-#include "xla/service/gpu/instruction_fusion.h"
-#include "xla/service/gpu/multi_output_fusion.h"
 #include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/service/gpu/transforms/fusion_merger.h"
+#include "xla/service/gpu/transforms/instruction_fusion.h"
+#include "xla/service/gpu/transforms/multi_output_fusion.h"
 #include "xla/service/hlo_cost_analysis.h"
 #include "xla/service/hlo_pass_pipeline.h"
 #include "xla/shape.h"
@@ -51,8 +51,7 @@
                                            device_info);
     pipeline.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true, device_info);
     pipeline.AddPass<FusionMerger>(device_info, ShapeSizeBytesFunction());
-    pipeline.AddPass<GpuMultiOutputFusion>(device_info,
-                                           ShapeSizeBytesFunction());
+    pipeline.AddPass<MultiOutputFusion>(device_info, ShapeSizeBytesFunction());
 
     RunAndFilecheckHloRewrite(hlo, std::move(pipeline), expected);
   }
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc
index 849cf1d..8328156 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc
@@ -23,8 +23,8 @@
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
 #include "xla/service/gpu/gpu_fusible.h"
-#include "xla/service/gpu/instruction_fusion.h"
 #include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/service/gpu/transforms/instruction_fusion.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
 
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
index e86f2c0..45e7af6 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
@@ -34,7 +34,11 @@
   // Most tests in this file want to skip layout assignment, but a few need it
   // enabled.
   HloModuleConfig ConfigWithLayoutAssignment() {
-    return GetModuleConfigForTest();
+    HloModuleConfig config;
+    auto debug_options = HloTestBase::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_mlir_emitter_level(3);
+    config.set_debug_options(debug_options);
+    return config;
   }
 
   HloModuleConfig ConfigWithoutLayoutAssignment() {
@@ -42,6 +46,7 @@
     auto debug_options = HloTestBase::GetDebugOptionsForTest();
     // Disable layout_assignment to use the preassigned layouts.
     debug_options.add_xla_disable_hlo_passes("layout-assignment");
+    debug_options.set_xla_gpu_mlir_emitter_level(3);
     config.set_debug_options(debug_options);
     return config;
   }
@@ -635,6 +640,8 @@
   }
   )";
   auto hlo_module = ParseAndReturnVerifiedModule(kHloString).value();
+  auto &debug_options = hlo_module->mutable_config().mutable_debug_options();
+  debug_options.set_xla_gpu_mlir_emitter_level(3);
   auto expected_ir = is_built_with_rocm_ ? R"(
 ; CHECK: %llvm.amdgcn.kernel.input_reduce_fusion.lds.t = type { [4 x [2 x float]] }
 ; CHECK: @llvm.amdgcn.kernel.input_reduce_fusion.lds = internal addrspace(3) global %llvm.amdgcn.kernel.input_reduce_fusion.lds.t poison
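The gpu_kernel_tiling_test.cc edits all pin the MLIR emitter level to 3, either when building a fresh HloModuleConfig or on an already-parsed module. A small sketch of the first pattern as a fixture helper, using only the setters that appear in the hunks above (the helper name ConfigWithEmitterLevel3 is illustrative):

// Inside an HloTestBase fixture: build a module config whose debug options
// force MLIR emitter level 3, mirroring ConfigWithLayoutAssignment above.
HloModuleConfig ConfigWithEmitterLevel3() {
  HloModuleConfig config;
  auto debug_options = HloTestBase::GetDebugOptionsForTest();
  debug_options.set_xla_gpu_mlir_emitter_level(3);
  config.set_debug_options(debug_options);
  return config;
}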
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc
deleted file mode 100644
index b1d2734..0000000
--- a/third_party/xla/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc
+++ /dev/null
@@ -1,572 +0,0 @@
-/* Copyright 2021 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/gpu_reduce_scatter_creator.h"
-
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <utility>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/algorithm/container.h"
-#include "absl/log/log.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-namespace m = ::xla::match;
-
-class GpuReduceScatterCreatorTest : public HloTestBase {
- public:
-  absl::StatusOr<std::unique_ptr<HloModule>> RunPass(
-      absl::string_view hlo_module, int64_t num_replicas,
-      int64_t num_partitions, bool expect_change) {
-    HloModuleConfig config = GetModuleConfigForTest(
-        /*replica_count=*/num_replicas,
-        /*num_partitions=*/num_partitions);
-    config.set_use_spmd_partitioning(num_partitions > 1);
-    TF_ASSIGN_OR_RETURN(auto module,
-                        ParseAndReturnVerifiedModule(hlo_module, config));
-    auto changed = ReduceScatterCreator().Run(module.get());
-    if (!changed.ok()) {
-      return changed.status();
-    }
-    EXPECT_EQ(changed.value(), expect_change);
-    return absl::StatusOr<std::unique_ptr<HloModule>>(std::move(module));
-  }
-
-  size_t AllReduceCount(std::unique_ptr<HloModule> &module) {
-    return CollectiveCount(module, HloOpcode::kAllReduce);
-  }
-
-  size_t ReduceScatterCount(std::unique_ptr<HloModule> &module) {
-    return CollectiveCount(module, HloOpcode::kReduceScatter);
-  }
-
- private:
-  size_t CollectiveCount(std::unique_ptr<HloModule> &module, HloOpcode opcode) {
-    return absl::c_count_if(
-        module->entry_computation()->instructions(),
-        [&opcode](HloInstruction *instr) { return instr->opcode() == opcode; });
-  }
-};
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicas) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[32,8,128]{2,1,0} parameter(0)
-  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
-    replica_groups={}, to_apply=%sum
-  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
-  %rid = u32[] replica-id()
-  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
-  %reshape = s32[] reshape(%id)
-  %slice_size = s32[] constant(4)
-  %offset = s32[] multiply(%reshape, %slice_size)
-  %zero = s32[] constant(0)
-  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
-    dynamic_slice_sizes={4,8,128}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/1,
-                                               /*expect_change=*/true));
-  ASSERT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::ReduceScatter(m::Parameter(0))));
-  const auto *rs = Cast<HloReduceScatterInstruction>(
-      module->entry_computation()->root_instruction());
-  EXPECT_EQ(rs->scatter_dimension(), 0) << rs->ToString();
-  EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithOffsetReshape) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[32,8,128]{2,1,0} parameter(0)
-  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
-    replica_groups={}, to_apply=%sum
-  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
-  %rid = u32[] replica-id()
-  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
-  %slice_size = s32[1] constant({4})
-  %offset = s32[1] multiply(%id, %slice_size)
-  %reshape = s32[] reshape(%offset)
-  %zero = s32[] constant(0)
-  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %reshape, %zero, %zero),
-    dynamic_slice_sizes={4,8,128}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/1,
-                                               /*expect_change=*/true));
-  ASSERT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::ReduceScatter(m::Parameter(0))));
-  const auto *rs = Cast<HloReduceScatterInstruction>(
-      module->entry_computation()->root_instruction());
-  EXPECT_EQ(rs->scatter_dimension(), 0) << rs->ToString();
-  EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshape) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[32,8,128]{2,1,0} parameter(0)
-  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
-    replica_groups={}, to_apply=%sum
-  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
-  %rid = u32[] replica-id()
-  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
-  %reshape = s32[] reshape(%id)
-  %slice_size = s32[] constant(4)
-  %offset = s32[] multiply(%reshape, %slice_size)
-  %zero = s32[] constant(0)
-  %reshape.1 = f32[32,16,64] reshape(%all-reduce)
-  ROOT %dynamic-slice = f32[4,16,64] dynamic-slice(%reshape.1, %offset, %zero, %zero),
-    dynamic_slice_sizes={4,16,64}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/1,
-                                               /*expect_change=*/true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Reshape(m::ReduceScatter(m::Parameter(0)))));
-  EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshapeSplitDimModified) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[336,1024] parameter(0)
-  %all-reduce = f32[336,1024] all-reduce(%param), replica_groups={}, to_apply=%sum
-  %rid = u32[] replica-id()
-  %id = s32[] convert(%rid)
-  %slice_size = s32[] constant(128)
-  %offset = s32[] multiply(%id, %slice_size)
-  %zero = s32[] constant(0)
-  %reshape.1 = f32[4,84,1024] reshape(%all-reduce)
-  ROOT %dynamic-slice = f32[4,84,128] dynamic-slice(%reshape.1, %zero, %zero, %offset),
-    dynamic_slice_sizes={4,84,128}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/1,
-                                               /*expect_change=*/true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Reshape(m::ReduceScatter(m::Parameter(0)))));
-  EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicasDim2) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[32,8,128]{2,1,0} parameter(0)
-  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
-    replica_groups={}, to_apply=%sum
-  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
-  %rid = u32[] replica-id()
-  %rid_s32 = s32[] convert(%rid)
-  %slice_size = s32[] constant(16)
-  %offset = s32[] multiply(%rid_s32, %slice_size)
-  %zero = s32[] constant(0)
-  ROOT %dynamic-slice = f32[32,8,16] dynamic-slice(%all-reduce, %zero, %zero, %offset),
-    dynamic_slice_sizes={32,8,16}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/1,
-                                               /*expect_change=*/true));
-  ASSERT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::ReduceScatter(m::Parameter(0))));
-  const auto *rs = Cast<HloReduceScatterInstruction>(
-      module->entry_computation()->root_instruction());
-  EXPECT_EQ(rs->scatter_dimension(), 2) << rs->ToString();
-  EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicasWrongOffsets) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[32,8,128]{2,1,0} parameter(0)
-  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
-    replica_groups={}, to_apply=%sum
-  %table = s32[8]{0} constant({0,1,2,3,4,5,6,8})
-  %rid = u32[] replica-id()
-  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
-  %reshape = s32[] reshape(%id)
-  %slice_size = s32[] constant(4)
-  %offset = s32[] multiply(%reshape, %slice_size)
-  %zero = s32[] constant(0)
-  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
-    dynamic_slice_sizes={4,8,128}
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/1,
-                                               /*expect_change=*/false));
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReplicasIotaTable) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[32,8,128]{2,1,0} parameter(0)
-  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
-    replica_groups={}, to_apply=%sum
-  %table = s32[8]{0} iota(), iota_dimension=0
-  %rid = u32[] replica-id()
-  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
-  %reshape = s32[] reshape(%id)
-  %slice_size = s32[] constant(4)
-  %offset = s32[] multiply(%reshape, %slice_size)
-  %zero = s32[] constant(0)
-  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
-    dynamic_slice_sizes={4,8,128}
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/2,
-                                               /*expect_change=*/true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::ReduceScatter(m::Parameter(0))));
-  EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, SubgroupedReplicas) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[32,8,128]{2,1,0} parameter(0)
-  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
-    replica_groups={{1,3,2,0},{4,5,6,7}}, to_apply=%sum
-  %gtable = s32[8]{0} constant({3,0,2,1,0,1,2,3})
-  %rid = u32[] replica-id()
-  %id = s32[1] dynamic-slice(%gtable, %rid), dynamic_slice_sizes={1}
-  %reshape.0 = s32[] reshape(%id)
-  %table = s32[4]{0} constant({0,8,16,24})
-  %offset = s32[1] dynamic-slice(%table, %reshape.0), dynamic_slice_sizes={1}
-  %reshape.1 = s32[] reshape(%offset)
-  %zero = s32[] constant(0)
-  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %reshape.1, %zero, %zero),
-    dynamic_slice_sizes={8,8,128}
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/8,
-                                               /*num_partitions=*/2,
-                                               /*expect_change=*/true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::ReduceScatter(m::Parameter(0))));
-  EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllPartitions) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[32,8,128]{2,1,0} parameter(0)
-  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
-    replica_groups={{0},{1}}, to_apply=%sum, channel_id=1
-  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
-  %pid = u32[] partition-id()
-  %id = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
-  %reshape = s32[] reshape(%id)
-  %slice_size = s32[] constant(4)
-  %offset = s32[] multiply(%reshape, %slice_size)
-  %zero = s32[] constant(0)
-  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
-    dynamic_slice_sizes={4,8,128}
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/2,
-                                               /*num_partitions=*/8,
-                                               /*expect_change=*/true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::ReduceScatter(m::Parameter(0))));
-  EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, AllReduceFollowedByAllReduce) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[32,8,128]{2,1,0} parameter(0)
-  %all-reduce.scattered = f32[32,8,128]{2,1,0} all-reduce(%param),
-    replica_groups={{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15}}, to_apply=%sum, use_global_device_ids=true, channel_id=1
-  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
-  %pid = u32[] partition-id()
-  %id = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
-  %reshape = s32[] reshape(%id)
-  %slice_size = s32[] constant(4)
-  %offset = s32[] multiply(%reshape, %slice_size)
-  %zero = s32[] constant(0)
-  %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce.scattered, %offset, %zero, %zero),
-    dynamic_slice_sizes={4,8,128}
-  ROOT %all-reduce.sync = f32[4,8,128]{2,1,0} all-reduce(%dynamic-slice),
-    replica_groups={{0,8},{1,9},{2,10},{3,11},{4,12},{5,13},{6,14},{7,15}}, to_apply=%sum, use_global_device_ids=true, channel_id=2
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/2,
-                                               /*num_partitions=*/8,
-                                               /*expect_change=*/true));
-  EXPECT_EQ(AllReduceCount(module), 1);
-  EXPECT_EQ(ReduceScatterCount(module), 1);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobals) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[32,8,128]{2,1,0} parameter(0)
-  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
-    replica_groups={{1,3,2,0},{4,5,6,7}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
-  %pid = u32[] partition-id()
-  %rid = u32[] replica-id()
-  %pcount = u32[] constant(4)
-  %ridxp = u32[] multiply(%rid, %pcount)
-  %gid = u32[] add(%ridxp, %pid)
-  %gtable = s32[8]{0} constant({3,0,2,1,0,1,2,3})
-  %id = s32[1] dynamic-slice(%gtable, %gid), dynamic_slice_sizes={1}
-  %reshape.0 = s32[] reshape(%id)
-  %table = s32[4]{0} constant({0,8,16,24})
-  %offset = s32[1] dynamic-slice(%table, %reshape.0), dynamic_slice_sizes={1}
-  %reshape.1 = s32[] reshape(%offset)
-  %zero = s32[] constant(0)
-  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %reshape.1, %zero, %zero),
-    dynamic_slice_sizes={8,8,128}
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/2,
-                                               /*num_partitions=*/4,
-                                               /*expect_change=*/true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::ReduceScatter(m::Parameter(0))));
-  EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobalsOrthogonalReplicas) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[32,8,128]{2,1,0} parameter(0)
-  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
-    replica_groups={{1,3,2,0},{5,7,6,4}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
-  %pid = u32[] partition-id()
-  %pid_table = s32[4]{0} constant({3,0,2,1})
-  %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
-  %reshape = s32[] reshape(%offset)
-  %shard_size = s32[] constant(8)
-  %mul = s32[] multiply(%reshape, %shard_size)
-  %zero = s32[] constant(0)
-  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %mul, %zero, %zero),
-    dynamic_slice_sizes={8,8,128}
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/2,
-                                               /*num_partitions=*/4,
-                                               /*expect_change=*/true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::ReduceScatter(m::Parameter(0))));
-  EXPECT_EQ(AllReduceCount(module), 0);
-}
-
-TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobalsNonOrthogonalReplicas) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[32,8,128]{2,1,0} parameter(0)
-  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
-    replica_groups={{1,3,2,0},{7,5,6,4}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
-  %pid = u32[] partition-id()
-  %pid_table = s32[4]{0} constant({3,0,2,1})
-  %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
-  %reshape = s32[] reshape(%offset)
-  %shard_size = s32[] constant(8)
-  %mul = s32[] multiply(%reshape, %shard_size)
-  %zero = s32[] constant(0)
-  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %mul, %zero, %zero),
-    dynamic_slice_sizes={8,8,128}
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/2,
-                                               /*num_partitions=*/4,
-                                               /*expect_change=*/false));
-}
-
-TEST_F(GpuReduceScatterCreatorTest, NonUniformSplit) {
-  absl::string_view hlo_string = R"(
-HloModule AllReduce
-
-%sum {
-  %a = f32[] parameter(0)
-  %b = f32[] parameter(1)
-  ROOT %add = f32[] add(%a, %b)
-}
-
-ENTRY %AllReduce {
-  %param = f32[1,7]{1,0} parameter(0)
-  %all-reduce = f32[1,7]{1,0} all-reduce(%param),
-    replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
-  %pid = u32[] partition-id()
-  %pid_table = s32[8]{0} constant({0, 1, 0, 1, 0, 1, 0, 1})
-  %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
-  %reshape = s32[] reshape(%offset)
-  %shard_size = s32[] constant(3)
-  %mul = s32[] multiply(%reshape, %shard_size)
-  %zero = s32[] constant(0)
-  ROOT %dynamic-slice = f32[1,3] dynamic-slice(%all-reduce, %zero, %mul),
-    dynamic_slice_sizes={1,3}
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
-                                               /*num_replicas=*/1,
-                                               /*num_partitions=*/8,
-                                               /*expect_change=*/true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::ReduceScatter(m::Slice(m::Parameter(0)))));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc b/third_party/xla/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc
deleted file mode 100644
index bb6eb63..0000000
--- a/third_party/xla/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc
+++ /dev/null
@@ -1,133 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_degenerate_dim_remover.h"
-
-#include <optional>
-
-#include "absl/strings/string_view.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-
-namespace {
-
-class ReductionDegenerateDimRemoverTest : public HloTestBase {
- public:
-  void CheckDegenerateDimRemover(absl::string_view hlo,
-                                 std::optional<absl::string_view> expected) {
-    RunAndFilecheckHloRewrite(hlo, gpu::ReductionDegenerateDimRemover{},
-                              expected);
-  }
-};
-
-TEST_F(ReductionDegenerateDimRemoverTest, ReductionWithDegenerateDimensions) {
-  const char* hlo = R"(
-HloModule ReduceWithDegenerateDimensions
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[1,3,1,4,1,5,1] parameter(0)
-  zero = f32[] constant(0)
-
-  ROOT out = f32[1,1,1,1] reduce(input, zero), dimensions={1,3,5}, to_apply=add
-}
-
-)";
-
-  CheckDegenerateDimRemover(hlo, R"(
-// CHECK: [[bitcast_0:%[^ ]+]] = f32[3,4,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK: [[reduce_2:%[^ ]+]] = f32[] reduce([[bitcast_0]], [[zero_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[bitcast_1_5:%[^ ]+]] = f32[1,1,1,1]{3,2,1,0} bitcast([[reduce_2]])
-  )");
-}
-
-TEST_F(ReductionDegenerateDimRemoverTest,
-       ReductionWithDegenerateDimensionsVariadic) {
-  const char* hlo = R"(
-HloModule ReduceWithDegenerateDimensions
-
-argmax {
-  running_max = f32[] parameter(0)
-  running_max_idx = u32[] parameter(1)
-  current_value = f32[] parameter(2)
-  current_value_idx = u32[] parameter(3)
-
-  current = (f32[], u32[]) tuple(running_max, running_max_idx)
-  potential = (f32[], u32[]) tuple(current_value, current_value_idx)
-
-  cmp_code = pred[] compare(current_value, running_max), direction=GT
-
-  new_max = f32[] select(cmp_code, current_value, running_max)
-  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
-
-  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
-}
-
-ENTRY main {
-  input = f32[1,3,1,4,1,5,1] parameter(0)
-  idxs = u32[1,3,1,4,1,5,1] parameter(1)
-  zero = f32[] constant(0)
-  zero_idx = u32[] constant(0)
-
-  ROOT out = (f32[1,1,1,1], u32[1,1,1,1]) reduce(input, idxs, zero, zero_idx), dimensions={1,3,5}, to_apply=argmax
-}
-
-)";
-
-  CheckDegenerateDimRemover(hlo, R"(
-// CHECK:  [[bitcast_0:%[^ ]+]] = f32[3,4,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK:  [[bitcast_1_2:%[^ ]+]] = u32[3,4,5]{2,1,0} bitcast([[idxs_3:%[^ ]+]])
-// CHECK:  [[reduce_4:%[^ ]+]] = (f32[], u32[]) reduce([[bitcast_0]], [[bitcast_1_2]], [[zero_5:%[^ ]+]], [[zero_idx_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
-// CHECK-NEXT:  [[get_tuple_element_8:%[^ ]+]] = f32[] get-tuple-element([[reduce_4]]), index=0
-// CHECK-NEXT:  [[bitcast_2_9:%[^ ]+]] = f32[1,1,1,1]{3,2,1,0} bitcast([[get_tuple_element_8]])
-// CHECK-NEXT:  [[get_tuple_element_1_10:%[^ ]+]] = u32[] get-tuple-element([[reduce_4]]), index=1
-// CHECK-NEXT:  [[bitcast_3_11:%[^ ]+]] = u32[1,1,1,1]{3,2,1,0} bitcast([[get_tuple_element_1_10]])
-// CHECK-NEXT:  ROOT [[tuple_12:%[^ ]+]] = (f32[1,1,1,1]{3,2,1,0}, u32[1,1,1,1]{3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
-)");
-}
-
-TEST_F(ReductionDegenerateDimRemoverTest, DegenerateWithEmptyDimension) {
-  const char* hlo = R"(
-HloModule ReduceWithDegenerateDimensions
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[1,3,1,4,1,5,1] parameter(0)
-  zero = f32[] constant(0)
-
-  ROOT out = f32[3,4,5,1] reduce(input, zero), dimensions={0,2,4}, to_apply=add
-}
-)";
-
-  CheckDegenerateDimRemover(hlo,
-                            R"(
-// CHECK: ROOT [[bitcast_0:%[^ ]+]] = f32[3,4,5,1]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
-      )");
-}
-
-}  // namespace
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/reduction_dimension_grouper_test.cc b/third_party/xla/xla/service/gpu/tests/reduction_dimension_grouper_test.cc
deleted file mode 100644
index fa149a1..0000000
--- a/third_party/xla/xla/service/gpu/tests/reduction_dimension_grouper_test.cc
+++ /dev/null
@@ -1,103 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_dimension_grouper.h"
-
-#include <optional>
-
-#include "absl/strings/string_view.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-
-namespace {
-
-class ReductionDimensionGrouperTest : public HloTestBase {
- public:
-  void CheckDimensionGrouper(absl::string_view hlo,
-                             std::optional<absl::string_view> expected) {
-    RunAndFilecheckHloRewrite(hlo, gpu::ReductionDimensionGrouper{}, expected);
-  }
-};
-
-TEST_F(ReductionDimensionGrouperTest, ReductionWithGrouping) {
-  const char* hlo = R"(
-HloModule ReductionWithGrouping
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[100,10,32,3]{3,2,1,0} parameter(0)
-  zero = f32[] constant(0)
-
-  ROOT out = f32[100,10]{0,1} reduce(input, zero), dimensions={2,3}, to_apply=add
-}
-)";
-
-  CheckDimensionGrouper(hlo,
-                        R"(
-// CHECK:  [[input_0:%[^ ]+]] = f32[100,10,32,3]{3,2,1,0} parameter(0)
-// CHECK:  [[bitcast_1:%[^ ]+]] = f32[100,10,96]{2,1,0} bitcast([[input_0]])
-// CHECK:  ROOT [[out_1_2:%[^ ]+]] = f32[100,10]{0,1} reduce([[bitcast_1]], [[zero_3:%[^ ]+]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
-      )");
-}
-
-TEST_F(ReductionDimensionGrouperTest, ReductionWithGroupingVariadic) {
-  const char* hlo = R"(
-HloModule ReductionWithGrouping
-
-argmax {
-  running_max = f32[] parameter(0)
-  running_max_idx = u32[] parameter(1)
-  current_value = f32[] parameter(2)
-  current_value_idx = u32[] parameter(3)
-
-  current = (f32[], u32[]) tuple(running_max, running_max_idx)
-  potential = (f32[], u32[]) tuple(current_value, current_value_idx)
-
-  cmp_code = pred[] compare(current_value, running_max), direction=GT
-
-  new_max = f32[] select(cmp_code, current_value, running_max)
-  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
-
-  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
-}
-
-ENTRY main {
-  input = f32[100,10,32,3]{3,2,1,0} parameter(0)
-  idxs = u32[100,10,32,3]{3,2,1,0} parameter(1)
-  zero = f32[] constant(0)
-  zero_idx = u32[] constant(0)
-
-  ROOT out = (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce(input, idxs, zero, zero_idx), dimensions={2,3}, to_apply=argmax
-}
-)";
-
-  CheckDimensionGrouper(hlo, R"(
-// CHECK:  [[input_0:%[^ ]+]] = f32[100,10,32,3]{3,2,1,0} parameter(0)
-// CHECK:  [[bitcast_1:%[^ ]+]] = f32[100,10,96]{2,1,0} bitcast([[input_0]])
-// CHECK:  [[idxs_2:%[^ ]+]] = u32[100,10,32,3]{3,2,1,0} parameter(1)
-// CHECK:  [[bitcast_1_3:%[^ ]+]] = u32[100,10,96]{2,1,0} bitcast([[idxs_2]])
-// CHECK:  ROOT [[out_1_4:%[^ ]+]] = (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[zero_5:%[^ ]+]], [[zero_idx_6:%[^ ]+]]), dimensions={2}, to_apply=[[argmax_7:%[^ ]+]]
-)");
-}
-
-}  // namespace
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/reduction_layout_normalizer_test.cc b/third_party/xla/xla/service/gpu/tests/reduction_layout_normalizer_test.cc
deleted file mode 100644
index 817d9c7..0000000
--- a/third_party/xla/xla/service/gpu/tests/reduction_layout_normalizer_test.cc
+++ /dev/null
@@ -1,164 +0,0 @@
-/* Copyright 2019 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/reduction_layout_normalizer.h"
-
-#include <optional>
-
-#include "absl/strings/string_view.h"
-#include "xla/error_spec.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-
-namespace {
-
-class ReductionLayoutNormalizerTest : public HloTestBase {
- public:
-  void CheckReductionLayoutNormalizer(
-      absl::string_view hlo, std::optional<absl::string_view> expected) {
-    RunAndFilecheckHloRewrite(hlo, gpu::ReductionLayoutNormalizer{}, expected);
-  }
-};
-
-TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTest) {
-  const char* hlo = R"(
-HloModule ReduceWithLayoutChange
-
-add {
-  x0 = f32[] parameter(0)
-  y0 = f32[] parameter(1)
-  ROOT add0 = f32[] add(x0, y0)
-}
-
-ENTRY main {
-  arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1}  parameter(0)
-  constant0 = f32[] constant(0)
-  ROOT reduce0 = f32[4,5,16,12,12]{4,3,2,1,0} reduce(arg0, constant0),
-    dimensions={1,6,7}, to_apply=add
-}
-
-)";
-
-  CheckReductionLayoutNormalizer(hlo,
-                                 R"(
-// CHECK:  [[bitcast_0:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_1:%[^ ]+]])
-// CHECK:  [[reduce_2:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} reduce([[bitcast_0]], [[constant0_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
-// CHECK:  ROOT [[bitcast_1_5:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[reduce_2]])
-      )");
-}
-
-TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTestVariadic) {
-  const char* hlo = R"(
-HloModule ReduceWithLayoutChangeVariadic
-
-
-argmax {
-  running_max = f32[] parameter(0)
-  running_max_idx = u32[] parameter(1)
-  current_value = f32[] parameter(2)
-  current_value_idx = u32[] parameter(3)
-
-  current = (f32[], u32[]) tuple(running_max, running_max_idx)
-  potential = (f32[], u32[]) tuple(current_value, current_value_idx)
-
-  cmp_code = pred[] compare(current_value, running_max), direction=GT
-
-  new_max = f32[] select(cmp_code, current_value, running_max)
-  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
-
-  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
-}
-
-ENTRY main {
-  arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1}  parameter(0)
-  idxs = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1}  parameter(1)
-  constant0 = f32[] constant(0)
-  constant1 = u32[] constant(0)
-  ROOT reduce0 = (
-      f32[4,5,16,12,12]{4,3,2,1,0},
-      u32[4,5,16,12,12]{4,3,2,1,0}
-    ) reduce(arg0, idxs, constant0,constant1), dimensions={1,6,7}, to_apply=argmax
-}
-
-
-)";
-
-  CheckReductionLayoutNormalizer(hlo,
-                                 R"(
-// CHECK:  [[arg0_0:%[^ ]+]] = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
-// CHECK:  [[bitcast_1:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_0]])
-// CHECK:  [[idxs_2:%[^ ]+]] = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(1)
-// CHECK:  [[bitcast_1_3:%[^ ]+]] = u32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[idxs_2]])
-// CHECK:  [[reduce_4:%[^ ]+]] = (f32[4,12,12,16,5]{2,1,3,4,0}, u32[4,12,12,16,5]{2,1,3,4,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[constant0_5:%[^ ]+]], [[constant1_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
-// CHECK:  [[get_tuple_element_8:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=0
-// CHECK:  [[bitcast_2_9:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_8]])
-// CHECK:  [[get_tuple_element_1_10:%[^ ]+]] = u32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=1
-// CHECK:  [[bitcast_3_11:%[^ ]+]] = u32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_1_10]])
-// CHECK:  ROOT [[tuple_12:%[^ ]+]] = (f32[4,5,16,12,12]{4,3,2,1,0}, u32[4,5,16,12,12]{4,3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
-      )");
-}
-
-TEST_F(ReductionLayoutNormalizerTest,
-       LayoutCanonicalizerTestVariadicDifferentLayouts) {
-  const char* hlo = R"(
-HloModule ReduceWithLayoutChangeVariadicDifferent
-
-argmax {
-  running_max = f32[] parameter(0)
-  running_max_idx = u32[] parameter(1)
-  current_value = f32[] parameter(2)
-  current_value_idx = u32[] parameter(3)
-
-  current = (f32[], u32[]) tuple(running_max, running_max_idx)
-  potential = (f32[], u32[]) tuple(current_value, current_value_idx)
-
-  cmp_code = pred[] compare(current_value, running_max), direction=GT
-
-  new_max = f32[] select(cmp_code, current_value, running_max)
-  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
-
-  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
-}
-
-ENTRY main {
-  arg0 = f32[2,3,4,7]{2,1,0,3}  parameter(0)
-  idxs = u32[2,3,4,7]{3,2,1,0}  parameter(1)
-  constant0 = f32[] constant(0)
-  constant1 = u32[] constant(0)
-  ROOT reduce0 = (
-      f32[2,3,4]{2,1,0},
-      u32[2,3,4]{2,1,0}
-    ) reduce(arg0, idxs, constant0,constant1), dimensions={3}, to_apply=argmax
-}
-
-
-)";
-
-  CheckReductionLayoutNormalizer(hlo,
-                                 R"(
-// CHECK:  [[arg0_0:%[^ ]+]] = f32[2,3,4,7]{2,1,0,3} parameter(0)
-// CHECK:  [[bitcast_1:%[^ ]+]] = f32[7,2,3,4]{3,2,1,0} bitcast([[arg0_0]])
-// CHECK:  [[idxs_2:%[^ ]+]] = u32[2,3,4,7]{3,2,1,0} parameter(1)
-// CHECK:  [[copy_3:%[^ ]+]] = u32[2,3,4,7]{2,1,0,3} copy([[idxs_2]])
-// CHECK:  [[bitcast_1_4:%[^ ]+]] = u32[7,2,3,4]{3,2,1,0} bitcast([[copy_3]])
-// CHECK:  ROOT [[reduce0_5:%[^ ]+]] = (f32[2,3,4]{2,1,0}, u32[2,3,4]{2,1,0}) reduce([[bitcast_1]], [[bitcast_1_4]], [[constant0_6:%[^ ]+]], [[constant1_7:%[^ ]+]]), dimensions={0}, to_apply=[[argmax_8:%[^ ]+]]
-      )");
-  EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-5, 1e-5}));
-}
-
-}  // namespace
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir b/third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir
new file mode 100644
index 0000000..d8abd2d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir
@@ -0,0 +1,46 @@
+// RUN: xla-opt %s \
+// RUN:   -convert-triton-to-tritongpu='target=cuda:80' \
+// RUN:   -sparse-add-encoding -canonicalize \
+// RUN: | FileCheck %s
+
+// Note: 'canonicalize' folds redundant (back-and-forth) convert_layout ops.
+
+// CHECK-DAG: #[[BLOCKED4x4:.*]] = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> 
+// CHECK-DAG: #[[BLOCKED1x1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+
+module {
+  // CHECK: @sparse_dot
+  tt.func @sparse_dot() {
+    // CHECK-NEXT: %[[A:.*]] = arith.constant dense<1.000000e+00>
+    // CHECK-SAME:   : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[BLOCKED4x4]]}>> 
+    // CHECK-NEXT: %[[B:.*]] = arith.constant dense<2.000000e+00>
+    // CHECK-SAME:   : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[BLOCKED4x4]]}>> 
+    // CHECK-NEXT: %[[C:.*]] = arith.constant dense<0.000000e+00>
+    // CHECK-SAME:   : tensor<64x64xf32, #[[BLOCKED4x4]]> 
+    // CHECK-NEXT: %[[META:.*]] = arith.constant dense<13107>
+    // CHECK-SAME:   : tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #[[BLOCKED4x4]]}>> 
+    %a = arith.constant dense<1.00e+00> : tensor<64x32xf16>
+    %b = arith.constant dense<2.00e+00> : tensor<64x64xf16>
+    %c = arith.constant dense<0.00e+00> : tensor<64x64xf32>
+    %meta = arith.constant dense<0x3333> : tensor<64x4xi16>
+
+    // CHECK-NEXT: %[[D:.*]] = triton_gpu.sparse_dot %[[A]], %[[B]], %[[C]], %[[META]]
+    // CHECK-SAME:   : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[BLOCKED4x4]]}>>
+    // CHECK-SAME:     meta tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #[[BLOCKED4x4]]}>>
+    // CHECK-SAME:     * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[BLOCKED4x4]]}>>
+    // CHECK-SAME:     -> tensor<64x64xf32, #[[BLOCKED4x4]]>
+    %d = triton_gpu.sparse_dot %a, %b, %c, %meta
+      : tensor<64x32xf16> meta tensor<64x4xi16> * tensor<64x64xf16> -> tensor<64x64xf32>
+
+    // CHECK-NEXT: %[[CVT:.*]] = triton_gpu.convert_layout %[[D]]
+    // CHECK-SAME:   : tensor<64x64xf32, #[[BLOCKED4x4]]>
+    // CHECK-SAME:     -> tensor<64x64xf32, #[[BLOCKED1x1]]>
+    // CHECK-NEXT: tt.print "" {hex = false, isSigned = array<i32: 0>} : %[[CVT]]
+    // CHECK-SAME:   : tensor<64x64xf32, #[[BLOCKED1x1]]>
+    // A use with side effects so we don't DCE the whole function.
+    tt.print "" { hex = false, isSigned = array<i32: 0>} : %d : tensor<64x64xf32>
+
+    // CHECK-NEXT: tt.return 
+    tt.return
+  }
+}
diff --git a/third_party/xla/xla/service/gpu/tests/sparse_add_layout.mlir b/third_party/xla/xla/service/gpu/tests/sparse_add_layout.mlir
deleted file mode 100644
index 10b3e45..0000000
--- a/third_party/xla/xla/service/gpu/tests/sparse_add_layout.mlir
+++ /dev/null
@@ -1,46 +0,0 @@
-// RUN: xla-opt %s \
-// RUN:   -convert-triton-to-tritongpu='target=cuda:80' \
-// RUN:   -add-sparse-encoding -canonicalize \
-// RUN: | FileCheck %s
-
-// Note: 'canonicalize' folds redundant (back-and-forth) convert_layout ops.
-
-// CHECK-DAG: #[[BLOCKED4x4:.*]] = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> 
-// CHECK-DAG: #[[BLOCKED1x1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
-
-module {
-  // CHECK: @sparse_dot
-  tt.func @sparse_dot() {
-    // CHECK-NEXT: %[[A:.*]] = arith.constant dense<1.000000e+00>
-    // CHECK-SAME:   : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[BLOCKED4x4]]}>> 
-    // CHECK-NEXT: %[[B:.*]] = arith.constant dense<2.000000e+00>
-    // CHECK-SAME:   : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[BLOCKED4x4]]}>> 
-    // CHECK-NEXT: %[[C:.*]] = arith.constant dense<0.000000e+00>
-    // CHECK-SAME:   : tensor<64x64xf32, #[[BLOCKED4x4]]> 
-    // CHECK-NEXT: %[[META:.*]] = arith.constant dense<13107>
-    // CHECK-SAME:   : tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #[[BLOCKED4x4]]}>> 
-    %a = arith.constant dense<1.00e+00> : tensor<64x32xf16>
-    %b = arith.constant dense<2.00e+00> : tensor<64x64xf16>
-    %c = arith.constant dense<0.00e+00> : tensor<64x64xf32>
-    %meta = arith.constant dense<0x3333> : tensor<64x4xi16>
-
-    // CHECK-NEXT: %[[D:.*]] = triton_gpu.sparse_dot %[[A]], %[[B]], %[[C]], %[[META]]
-    // CHECK-SAME:   : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[BLOCKED4x4]]}>>
-    // CHECK-SAME:     meta tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #[[BLOCKED4x4]]}>>
-    // CHECK-SAME:     * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[BLOCKED4x4]]}>>
-    // CHECK-SAME:     -> tensor<64x64xf32, #[[BLOCKED4x4]]>
-    %d = triton_gpu.sparse_dot %a, %b, %c, %meta
-      : tensor<64x32xf16> meta tensor<64x4xi16> * tensor<64x64xf16> -> tensor<64x64xf32>
-
-    // CHECK-NEXT: %[[CVT:.*]] = triton_gpu.convert_layout %[[D]]
-    // CHECK-SAME:   : tensor<64x64xf32, #[[BLOCKED4x4]]>
-    // CHECK-SAME:     -> tensor<64x64xf32, #[[BLOCKED1x1]]>
-    // CHECK-NEXT: tt.print "" {hex = false, isSigned = array<i32: 0>} : %[[CVT]]
-    // CHECK-SAME:   : tensor<64x64xf32, #[[BLOCKED1x1]]>
-    // A use with side effects so we don't DCE the whole function.
-    tt.print "" { hex = false, isSigned = array<i32: 0>} : %d : tensor<64x64xf32>
-
-    // CHECK-NEXT: tt.return 
-    tt.return
-  }
-}
diff --git a/third_party/xla/xla/service/gpu/tests/sparse_convert_triton_to_tritongpu.mlir b/third_party/xla/xla/service/gpu/tests/sparse_convert_triton_to_tritongpu.mlir
new file mode 100644
index 0000000..ad61162
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/tests/sparse_convert_triton_to_tritongpu.mlir
@@ -0,0 +1,25 @@
+// RUN: xla-opt %s -convert-triton-to-tritongpu='target=cuda:80' | FileCheck %s
+
+module attributes {} {
+  tt.func @gemm_fusion_dot_1_impl() {
+    %c0_i32 = arith.constant 0 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %acc = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
+    %a = arith.constant dense<0.000000e+00> : tensor<32x16xbf16>
+    // CHECK: %[[A:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x16xbf16, {{.+}}> -> tensor<32x16xbf16>
+    %b = arith.constant dense<0.000000e+00> : tensor<32x32xbf16>
+    // CHECK: %[[B:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x32xbf16, {{.+}}> -> tensor<32x32xbf16>
+    %meta = arith.constant dense<0> : tensor<32x2xi16>
+    // CHECK: %[[META:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x2xi16, {{.+}}> -> tensor<32x2xi16>
+    %35:1 = scf.for %arg4 = %c0_i32 to %c32_i32 step %c32_i32 iter_args(%arg8 = %acc) -> (tensor<32x32xf32>)  : i32 {
+      // CHECK: %[[ACC:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x32xf32, {{.+}}> -> tensor<32x32xf32>
+      // CHECK-NEXT: %[[D:.*]] = triton_gpu.sparse_dot %[[A]], %[[B]], %[[ACC]], %[[META]]
+      // CHECK-SAME:   : tensor<32x16xbf16> meta tensor<32x2xi16>
+      // CHECK-SAME:     * tensor<32x32xbf16> -> tensor<32x32xf32>
+      %74 = triton_gpu.sparse_dot %a, %b, %arg8, %meta : tensor<32x16xbf16> meta tensor<32x2xi16> * tensor<32x32xbf16> -> tensor<32x32xf32>
+      // CHECK: %[[ACC:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x32xf32> -> tensor<32x32xf32, {{.+}}>
+      scf.yield %74 : tensor<32x32xf32>
+    }
+    tt.return
+  }
+}
diff --git a/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc b/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc
deleted file mode 100644
index ef6e189..0000000
--- a/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc
+++ /dev/null
@@ -1,557 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/tree_reduction_rewriter.h"
-
-#include <optional>
-
-#include "absl/strings/string_view.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/test.h"
-
-namespace xla {
-
-namespace {
-
-class TreeReductionRewriterTest : public HloTestBase {
- public:
-  void CheckTreeRewriter(absl::string_view hlo,
-                         std::optional<absl::string_view> expected) {
-    RunAndFilecheckHloRewrite(
-        hlo,
-#if TENSORFLOW_USE_ROCM
-        gpu::GpuTreeReductionRewriter{se::RocmComputeCapability {
-          "908"
-        }},
-#else
-        gpu::GpuTreeReductionRewriter{se::CudaComputeCapability{8, 1}},
-#endif
-        expected);
-  }
-};
-
-TEST_F(TreeReductionRewriterTest, RowReductionSingleDimensionNoBatched) {
-  const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[50021] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-// CHECK: [[pad_0:%[^ ]+]] = f32[50022]{0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_1
-// CHECK: [[bitcast_3:%[^ ]+]] = f32[126,397]{1,0} bitcast([[pad_0]])
-// CHECK: [[reduce_4:%[^ ]+]] = f32[126]{0} reduce([[bitcast_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_5:%[^ ]+]]
-// CHECK: ROOT [[out_1_6:%[^ ]+]] = f32[] reduce([[reduce_4]], [[zero_2]]), dimensions={0}, to_apply=[[add_5]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, RowReductionWeirdOutputLayout) {
-  const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[2,4,17000]{2,1,0} parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[2,4]{0,1} reduce(input, zero), dimensions={2}, to_apply=add
-}
-)";
-
-  // Check that we preserve the layout.
-  CheckTreeRewriter(hlo,
-                    R"(
-// CHECK: f32[2,4]{0,1} reduce(
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest,
-       RowReductionSingleDimensionNoBatchedDivisible) {
-  const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[50048] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-// CHECK: [[input_0:%[^ ]+]] = f32[50048]{0} parameter(0)
-// CHECK: [[bitcast_1:%[^ ]+]] = f32[128,391]{1,0} bitcast([[input_0]])
-// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[reduce_3:%[^ ]+]] = f32[128]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, RowReductionNoBatched) {
-  const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[100,10,65536] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[100,10] reduce(input, zero), dimensions={2}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-// CHECK: [[bitcast_0:%[^ ]+]] = f32[100,10,256,256]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[reduce_3:%[^ ]+]] = f32[100,10,256]{2,1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={3}, to_apply=[[add_4:%[^ ]+]]
-// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100,10]{1,0} reduce([[reduce_3]], [[zero_2]]), dimensions={2}, to_apply=[[add_4]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest,
-       RowReductionSingleDimensionNoBatchedLargeInput) {
-  const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[1048576] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-// CHECK:  [[input_0:%[^ ]+]] = f32[1048576]{0} parameter(0)
-// CHECK:  [[bitcast_1:%[^ ]+]] = f32[1024,1024]{1,0} bitcast([[input_0]])
-// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:  [[reduce_3:%[^ ]+]] = f32[1024]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
-// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionFits) {
-  const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[8,100,65536] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-// CHECK:  [[bitcast_0:%[^ ]+]] = f32[8,100,256,256]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:  [[reduce_3:%[^ ]+]] = f32[100,256]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={0,3}, to_apply=[[add_4:%[^ ]+]]
-// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionDoesNotFit) {
-  const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[32,100,90000] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-// CHECK: [[reduce_0:%[^ ]+]] = f32[32,100]{1,0} reduce([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), dimensions={2}, to_apply=[[add_3:%[^ ]+]]
-// CHECK:  ROOT [[out_1_4:%[^ ]+]] = f32[100]{0} reduce([[reduce_0]], [[zero_2]]), dimensions={0}, to_apply=[[add_3]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, ColumnReductionSimple) {
-  const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[16384,100] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-
-// CHECK:  [[input_0:%[^ ]+]] = f32[16384,100]{1,0} parameter(0)
-// CHECK:  [[bitcast_1:%[^ ]+]] = f32[128,128,100]{2,1,0} bitcast([[input_0]])
-// CHECK:  [[reduce_2:%[^ ]+]] = f32[128,100]{1,0} reduce([[bitcast_1]], [[zero_3:%[^ ]+]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
-// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_2]], [[zero_3]]), dimensions={0}, to_apply=[[add_4]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, ColumnReductionSimpleNoDivisible) {
-  const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[10303,100] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-// CHECK:  [[input_0:%[^ ]+]] = f32[10303,100]{1,0} parameter(0)
-// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:  [[pad_0:%[^ ]+]] = f32[10304,100]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_1x0_0
-// CHECK:  [[bitcast_1:%[^ ]+]] = f32[64,161,100]{2,1,0} bitcast([[pad_0]])
-// CHECK:  [[reduce_3:%[^ ]+]] = f32[64,100]{1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
-// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, ColumnReductionOtherIndex) {
-  const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[16384,2,2,2] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[2,2,2] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-// CHECK:  [[input_0:%[^ ]+]] = f32[16384,2,2,2]{3,2,1,0} parameter(0)
-// CHECK:  [[bitcast_1:%[^ ]+]] = f32[128,128,2,2,2]{4,3,2,1,0} bitcast([[input_0]])
-// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:  [[reduce_3:%[^ ]+]] = f32[128,2,2,2]{3,2,1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
-// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[2,2,2]{2,1,0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, ColumnReductionVeryLargeInput) {
-  const char* hlo = R"(
-HloModule ReduceWithPadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[1048576,5] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[5] reduce(input, zero), dimensions={0}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-
-// CHECK:  [[bitcast_0:%[^ ]+]] = f32[1024,1024,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:  [[reduce_3:%[^ ]+]] = f32[1024,5]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
-// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[5]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, VariadicReductionLargeRow) {
-  const char* hlo = R"(
-HloModule Reduce_R1x2_to_R0x2_argmax
-
-argmax {
-  running_max = f32[] parameter(0)
-  running_max_idx = u32[] parameter(1)
-  current_value = f32[] parameter(2)
-  current_value_idx = u32[] parameter(3)
-
-  current = (f32[], u32[]) tuple(running_max, running_max_idx)
-  potential = (f32[], u32[]) tuple(current_value, current_value_idx)
-
-  cmp_code = pred[] compare(current_value, running_max), direction=GT
-
-  new_max = f32[] select(cmp_code, current_value, running_max)
-  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
-
-  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
-}
-
-ENTRY main {
-  input = f32[2,100003] parameter(0)
-  idxs = u32[2,100003] iota(), iota_dimension=0
-  zero = f32[] constant(0)
-  zero_idx = u32[] constant(0)
-
-  ROOT out = (f32[2], u32[2]) reduce(
-    input, idxs, zero, zero_idx),
-    dimensions={1},
-    to_apply=%argmax
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-// CHECK:  [[pad_0:%[^ ]+]] = f32[2,100005]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_0x0_2
-// CHECK:  [[bitcast_3:%[^ ]+]] = f32[2,295,339]{2,1,0} bitcast([[pad_0]])
-// CHECK:  [[zero_idx_4:%[^ ]+]] = u32[] constant(0)
-// CHECK:  [[pad_1_5:%[^ ]+]] = u32[2,100005]{1,0} pad([[idxs_6:%[^ ]+]], [[zero_idx_4]]), padding=0_0x0_2
-// CHECK:  [[bitcast_1_7:%[^ ]+]] = u32[2,295,339]{2,1,0} bitcast([[pad_1_5]])
-// CHECK:  [[reduce_8:%[^ ]+]] = (f32[2,295]{1,0}, u32[2,295]{1,0}) reduce([[bitcast_3]], [[bitcast_1_7]], [[zero_2]], [[zero_idx_4]]), dimensions={2}, to_apply=[[argmax_9:%[^ ]+]]
-// CHECK:  [[get_tuple_element_10:%[^ ]+]] = f32[2,295]{1,0} get-tuple-element([[reduce_8]]), index=0
-// CHECK:  [[get_tuple_element_1_11:%[^ ]+]] = u32[2,295]{1,0} get-tuple-element([[reduce_8]]), index=1
-// CHECK:  ROOT [[out_1_12:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_10]], [[get_tuple_element_1_11]], [[zero_2]], [[zero_idx_4]]), dimensions={1}, to_apply=[[argmax_9]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, VariadicReductionLargeBatchSize) {
-  const char* hlo = R"(
-HloModule Reduce_R1x2_to_R0x2_argmax
-
-argmax {
-  running_max = f32[] parameter(0)
-  running_max_idx = u32[] parameter(1)
-  current_value = f32[] parameter(2)
-  current_value_idx = u32[] parameter(3)
-
-  current = (f32[], u32[]) tuple(running_max, running_max_idx)
-  potential = (f32[], u32[]) tuple(current_value, current_value_idx)
-
-  cmp_code = pred[] compare(current_value, running_max), direction=GT
-
-  new_max = f32[] select(cmp_code, current_value, running_max)
-  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
-
-  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
-}
-
-ENTRY main {
-  input = f32[20,2,100] parameter(0)
-  idxs = u32[20,2,100] iota(), iota_dimension=0
-  zero = f32[] constant(0)
-  zero_idx = u32[] constant(0)
-
-  ROOT out = (f32[2], u32[2]) reduce(
-    input, idxs, zero, zero_idx),
-    dimensions={0,2},
-    to_apply=%argmax
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-// CHECK:  [[reduce_0:%[^ ]+]] = (f32[20,2]{1,0}, u32[20,2]{1,0}) reduce([[input_1:%[^ ]+]], [[idxs_2:%[^ ]+]], [[zero_3:%[^ ]+]], [[zero_idx_4:%[^ ]+]]), dimensions={2}, to_apply=[[argmax_5:%[^ ]+]]
-// CHECK:  [[get_tuple_element_6:%[^ ]+]] = f32[20,2]{1,0} get-tuple-element([[reduce_0]]), index=0
-// CHECK:  [[get_tuple_element_1_7:%[^ ]+]] = u32[20,2]{1,0} get-tuple-element([[reduce_0]]), index=1
-// CHECK:  ROOT [[out_1_8:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_6]], [[get_tuple_element_1_7]], [[zero_3]], [[zero_idx_4]]), dimensions={0}, to_apply=[[argmax_5]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, KeepInnerReductionVectorized) {
-  const char* hlo = R"(
-HloModule KeepInnerRowReductionVectorized
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[1024,73984] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-
-// CHECK:  [[bitcast_0:%[^ ]+]] = f32[1024,289,256]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:  [[reduce_3:%[^ ]+]] = f32[1024,289]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
-// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, PreferLargeVectorizedDimension) {
-  const char* hlo = R"(
-HloModule PreferLargeVectorizedDimension
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[1024,98304] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-
-// CHECK:  [[bitcast_0:%[^ ]+]] = f32[1024,256,384]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:  [[reduce_3:%[^ ]+]] = f32[1024,256]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
-// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, SwapIfNonAlignedBeforePadding) {
-  const char* hlo = R"(
-HloModule SwapIfNonAlignedBeforePadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[1024,19739] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-
-// CHECK-DAG:  [[bitcast_0:%[^ ]+]] = f32[1024,140,141]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK-DAG:  [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:  [[reduce_3:%[^ ]+]] = f32[1024,140]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
-// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, DontSwapIfNonAlignedBeforePadding) {
-  const char* hlo = R"(
-HloModule DontSwapIfNonAlignedBeforePadding
-
-add {
-  accum = f32[] parameter(0)
-  op = f32[] parameter(1)
-  ROOT out = f32[] add(accum, op)
-}
-
-ENTRY main {
-  input = f32[1024,19459] parameter(0)
-  zero = f32[] constant(0)
-  ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
-}
-)";
-
-  CheckTreeRewriter(hlo,
-                    R"(
-
-// CHECK-DAG:  [[bitcast_0:%[^ ]+]] = f32[1024,140,139]{2,1,0} bitcast([[input_1:%[^ ]+]])
-// CHECK-DAG:  [[zero_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:  [[reduce_3:%[^ ]+]] = f32[1024,140]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
-// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
-      )");
-}
-
-TEST_F(TreeReductionRewriterTest, NonCosequtiveReductionDims) {
-  const char* hlo = R"(
-    HloModule NonCosequtiveReductionDims
-
-    add {
-      accum = f32[] parameter(0)
-      op = f32[] parameter(1)
-      ROOT out = f32[] add(accum, op)
-    }
-
-    ENTRY main {
-      input = f32[5,3,4,5] parameter(0)
-      zero = f32[] constant(0)
-      ROOT out = f32[5,4] reduce(input, zero), dimensions={1,3}, to_apply=add
-    }
-  )";
-
-  CheckTreeRewriter(hlo, std::nullopt);
-}
-
-}  // namespace
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/xla-opt.cc b/third_party/xla/xla/service/gpu/tests/xla-opt.cc
index 30bd45f..cd5eeb8 100644
--- a/third_party/xla/xla/service/gpu/tests/xla-opt.cc
+++ b/third_party/xla/xla/service/gpu/tests/xla-opt.cc
@@ -15,16 +15,14 @@
 
 #include "mlir/InitAllExtensions.h"
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
-#include "xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h"
-#include "xla/service/gpu/fusions/triton/sparse_extensions.h"
+#include "xla/service/gpu/fusions/triton/passes.h"
 #include "third_party/triton/bin/RegisterTritonDialects.h"
 
 int main(int argc, char **argv) {
   mlir::DialectRegistry registry;
   mlir::registerAllExtensions(registry);
   registerTritonDialects(registry);  // This registers all passes as well.
-  xla::gpu::RegisterSparsePasses();
-  xla::gpu::RegisterPreventMmaV3LoopUnrollingPass();
+  xla::gpu::registerTritonFusionTransformsPasses();
 
   return mlir::asMainReturnCode(mlir::MlirOptMain(
       argc, argv, "xla-opt modular optimizer driver\n", registry));
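
For context on the consolidated call above: a single grouped registration entry point such as registerTritonFusionTransformsPasses() is, in MLIR, typically either generated by tablegen (GEN_PASS_REGISTRATION) or written by hand as a thin wrapper over mlir::registerPass. The sketch below shows the hand-written form only as an illustration of the pattern; the namespace, pass names, and factory functions are hypothetical and are not taken from this change.

  #include <memory>

  #include "mlir/Pass/Pass.h"
  #include "mlir/Pass/PassRegistry.h"

  // Sketch of a grouped pass-registration helper. The factories below are
  // hypothetical stand-ins; real projects often generate the equivalent
  // function with tablegen's GEN_PASS_REGISTRATION instead.
  namespace example {

  std::unique_ptr<mlir::Pass> CreateFooConversionPass();  // hypothetical
  std::unique_ptr<mlir::Pass> CreateBarCleanupPass();     // hypothetical

  inline void registerExampleTransformsPasses() {
    // Each registerPass call records a factory with MLIR's global pass
    // registry, making the pass available to *-opt style drivers.
    mlir::registerPass([] { return CreateFooConversionPass(); });
    mlir::registerPass([] { return CreateBarCleanupPass(); });
  }

  }  // namespace example

An opt driver's main() then needs only one call to expose the whole group, which is the shape of the single registration call in the hunk above.
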
diff --git a/third_party/xla/xla/service/gpu/topk_specializer.cc b/third_party/xla/xla/service/gpu/topk_specializer.cc
deleted file mode 100644
index bd01a07..0000000
--- a/third_party/xla/xla/service/gpu/topk_specializer.cc
+++ /dev/null
@@ -1,113 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/topk_specializer.h"
-
-#include <stddef.h>
-
-#include <initializer_list>
-#include <string>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/primitive_util.h"
-#include "xla/service/hlo.pb.h"
-#include "xla/service/tuple_util.h"
-#include "xla/shape.h"
-#include "xla/status_macros.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-
-absl::StatusOr<HloInstruction*> SmallBufferOptimization(
-    HloCustomCallInstruction* topk) {
-  Shape data_shape = topk->operand(0)->shape();
-  auto supported_dtypes = {F32, BF16};
-  if (!absl::c_linear_search(supported_dtypes, data_shape.element_type())) {
-    return InvalidArgument(
-        "Invalid Dtype: %s",
-        primitive_util::LowercasePrimitiveTypeName(data_shape.element_type()));
-  }
-  // We only support topk of the shape [x] or [batch, x].
-  if (data_shape.dimensions_size() > 2) {
-    return InvalidArgument("Invalid input dimensions: %s",
-                           data_shape.ToString());
-  }
-  bool has_batch = data_shape.dimensions_size() == 2;
-  constexpr size_t max_k = 16;
-  constexpr size_t min_n = 1024;
-  size_t n = data_shape.dimensions(has_batch ? 1 : 0);
-  size_t k = topk->shape().tuple_shapes(0).dimensions(has_batch ? 1 : 0);
-  if (k > max_k) {
-    return InvalidArgument("k too large (%d), must be <= %d", k, max_k);
-  }
-  if (n < min_n) {
-    return InvalidArgument("Input too small (n=%d, min_n=%d)", n, min_n);
-  }
-  HloComputation* comp = topk->parent();
-  HloInstruction* new_topk =
-      comp->AddInstruction(HloInstruction::CreateCustomCall(
-          topk->shape(), topk->operands(),
-          // We don't need the original to_apply, but keeping it around allows
-          // us to round-trip this CustomCall on tests.
-          topk->to_apply(), "__gpu$TopK",
-          /*opaque=*/"", CustomCallApiVersion::API_VERSION_TYPED_FFI));
-  return TupleUtil::ExtractPrefix(new_topk, 2);
-}
-
-class SpecializeTopkVisitor : public DfsHloRewriteVisitor {
- public:
-  absl::Status HandleCustomCall(HloInstruction* inst) override {
-    HloCustomCallInstruction* topk = DynCast<HloCustomCallInstruction>(inst);
-    if (topk == nullptr || topk->custom_call_target() != "TopK") {
-      return absl::OkStatus();
-    }
-    TF_RET_CHECK(topk->operand_count() == 1);
-
-    if (auto small_topk = SmallBufferOptimization(topk); small_topk.ok()) {
-      return ReplaceInstruction(topk, *small_topk);
-    } else {
-      VLOG(2) << "Small TopK optimization doesn't match: "
-              << small_topk.status();
-    }
-
-    return absl::OkStatus();
-  }
-};
-
-}  // namespace
-
-absl::StatusOr<bool> TopkSpecializer::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  return SpecializeTopkVisitor().RunOnModule(module, execution_threads);
-}
-
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/topk_specializer.h b/third_party/xla/xla/service/gpu/topk_specializer.h
deleted file mode 100644
index 5b57f57..0000000
--- a/third_party/xla/xla/service/gpu/topk_specializer.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_
-#define XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla::gpu {
-
-// This pass transforms eligible TopK CustomCall into a call to be executed by
-// runtime/topk.cc.
-class TopkSpecializer : public HloModulePass {
- public:
-  absl::string_view name() const override { return "topk-specializer"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_
diff --git a/third_party/xla/xla/service/gpu/topk_splitter.cc b/third_party/xla/xla/service/gpu/topk_splitter.cc
deleted file mode 100644
index d20116d..0000000
--- a/third_party/xla/xla/service/gpu/topk_splitter.cc
+++ /dev/null
@@ -1,154 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/topk_splitter.h"
-
-#include <algorithm>
-#include <cmath>
-#include <cstddef>
-#include <cstdint>
-#include <string>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/log/log.h"
-#include "absl/numeric/bits.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-constexpr size_t kRequiredAlignment = 1024;
-constexpr size_t kMaximumBatchSize = 1024;
-
-class TopkSplitterVisitor : public DfsHloRewriteVisitor {
- public:
-  explicit TopkSplitterVisitor(size_t split_threshold)
-      : split_threshold_(split_threshold) {}
-
-  absl::Status HandleCustomCall(HloInstruction* inst) override {
-    HloCustomCallInstruction* topk = DynCast<HloCustomCallInstruction>(inst);
-    if (topk == nullptr || topk->custom_call_target() != "TopK") {
-      return absl::OkStatus();
-    }
-    HloComputation* comp = inst->parent();
-    Shape data_shape = topk->operand(0)->shape();
-    bool has_batch = data_shape.dimensions_size() == 2;
-    // TODO(doak): Support multiple batches.
-    if (has_batch && data_shape.dimensions(0) != 1) {
-      return absl::OkStatus();
-    }
-    size_t n = data_shape.dimensions(has_batch ? 1 : 0);
-    int64_t k = topk->shape().tuple_shapes(0).dimensions(has_batch ? 1 : 0);
-    // If K approaches N, splitting the input will not be beneficial anymore.
-    if (k > sqrt(n)) {
-      return absl::OkStatus();
-    }
-    // TODO(doak): Relax this alignment requirement.
-    if (n % kRequiredAlignment != 0) {
-      return absl::OkStatus();
-    }
-    if (n < split_threshold_) return absl::OkStatus();
-    int new_batch =
-        std::min(absl::bit_floor(n / split_threshold_), kMaximumBatchSize);
-    int new_n = n / new_batch;
-    // Split the input into B batches and compute TopK over the batched arrays.
-    Shape split_input_shape =
-        ShapeUtil::MakeShape(data_shape.element_type(), {new_batch, new_n});
-    TF_ASSIGN_OR_RETURN(
-        HloInstruction * reshaped,
-        MakeReshapeHlo(split_input_shape, topk->mutable_operand(0)));
-    Shape batch_topk_shape = ShapeUtil::MakeTupleShape(
-        {ShapeUtil::MakeShape(data_shape.element_type(), {new_batch, k}),
-         ShapeUtil::MakeShape(S32, {new_batch, k})});
-    HloInstruction* batch_topk =
-        comp->AddInstruction(HloInstruction::CreateCustomCall(
-            batch_topk_shape, {reshaped}, topk->to_apply(), "TopK",
-            /*opaque=*/""));
-    // Fix indices, adding j*split_N to the j-th batch of indices.
-    TF_ASSIGN_OR_RETURN(HloInstruction * indices,
-                        MakeGetTupleElementHlo(batch_topk, 1));
-    TF_ASSIGN_OR_RETURN(HloInstruction * values,
-                        MakeGetTupleElementHlo(batch_topk, 0));
-    Shape iota_shape = ShapeUtil::MakeShape(S32, {new_batch});
-    TF_ASSIGN_OR_RETURN(
-        HloInstruction * fix,
-        MakeBinaryHlo(
-            HloOpcode::kMultiply, MakeIotaHlo(comp, iota_shape, 0),
-            MakeBroadcastHlo(MakeR0ConstantHlo<int32_t>(comp, new_n),
-                             /*broadcast_dimensions=*/{}, iota_shape)));
-    TF_ASSIGN_OR_RETURN(
-        indices, MakeBinaryHlo(HloOpcode::kAdd, indices,
-                               MakeBroadcastHlo(fix, {0}, indices->shape())));
-    // With the indices restored, compute a final top-k. Since this topk uses
-    // arbitrary indices, we need to use sort+slice.
-    Shape linear_index_shape = ShapeUtil::MakeShape(S32, {k * new_batch});
-    Shape linear_shape = ShapeUtil::ChangeElementType(
-        linear_index_shape, data_shape.element_type());
-    Shape linear_sort_shape =
-        ShapeUtil::MakeTupleShape({linear_shape, linear_index_shape});
-    // Assuming the outputs of the TopK above are stably sorted, using a stable
-    // sort here is enough to guarantee global stable sorting:
-    //  - Within a blocks elements are stably sorted by TopK.
-    //  - Since blocks are organized linearly from smallest to largest, the
-    //    index used on the stable sort below will also respect block ordering.
-    HloInstruction* aggregated_sort =
-        comp->AddInstruction(HloInstruction::CreateSort(
-            linear_sort_shape, 0,
-            {*MakeReshapeHlo(linear_shape, values),
-             *MakeReshapeHlo(linear_index_shape, indices)},
-            topk->to_apply(), /*is_stable=*/true));
-    auto slice_tuple = [&](HloInstruction* sort, const size_t index) {
-      return *MakeReshapeHlo(
-          topk->shape().tuple_shapes(index),
-          *MakeSliceHlo(*MakeGetTupleElementHlo(sort, index), {0}, {k}, {1}));
-    };
-    return ReplaceInstruction(topk,
-                              comp->AddInstruction(HloInstruction::CreateTuple({
-                                  slice_tuple(aggregated_sort, 0),
-                                  slice_tuple(aggregated_sort, 1),
-                              })));
-  }
-
- private:
-  size_t split_threshold_;
-};
-
-}  // namespace
-
-absl::StatusOr<bool> TopKSplitter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  return TopkSplitterVisitor(split_threshold_)
-      .RunOnModule(module, execution_threads);
-}
-
-}  // namespace gpu
-}  // namespace xla
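
The deleted TopkSplitterVisitor above rests on the argument that, once per-block results are laid out in block order, a stable sort by value preserves a globally stable ordering among equal values. That argument can be checked in isolation with a small standalone snippet; the sketch below assumes k=2 candidates per block and purely illustrative values and indices, and is not derived from the deleted pass itself.

  #include <algorithm>
  #include <cstdint>
  #include <iostream>
  #include <utility>
  #include <vector>

  // Two blocks of k=2 candidates each, concatenated in block order:
  // block 0 contributes global indices {0, 1}, block 1 contributes
  // {1024, 1025}. The value 42.0 appears in both blocks.
  int main() {
    std::vector<std::pair<float, int64_t>> candidates = {
        {42.0f, 0}, {7.0f, 1}, {42.0f, 1024}, {3.0f, 1025}};

    // Stable sort descending by value: ties keep their original (block)
    // order, so index 0 stays ahead of index 1024 for the duplicated 42.0.
    std::stable_sort(
        candidates.begin(), candidates.end(),
        [](const auto& a, const auto& b) { return a.first > b.first; });

    for (const auto& [value, index] : candidates) {
      std::cout << value << " @ " << index << "\n";
    }
    return 0;
  }

Running this prints 42 @ 0, 42 @ 1024, 7 @ 1, 3 @ 1025, matching the ordering property the deleted comment describes.
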
diff --git a/third_party/xla/xla/service/gpu/topk_splitter.h b/third_party/xla/xla/service/gpu/topk_splitter.h
deleted file mode 100644
index 8fee2dc..0000000
--- a/third_party/xla/xla/service/gpu/topk_splitter.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_TOPK_SPLITTER_H_
-#define XLA_SERVICE_GPU_TOPK_SPLITTER_H_
-
-#include <cstddef>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Splits large TopK into batches of smaller TopKs, followed by sorting and
-// slicing the results of those smaller topks. We consider TopKs to be 'large'
-// the last dimension of the TopK is larger than `split_threshold`.
-class TopKSplitter : public HloModulePass {
- public:
-  explicit TopKSplitter(size_t split_threshold = 1024 * 1024)
-      : split_threshold_(split_threshold) {}
-  absl::string_view name() const override { return "topk-splitter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  const size_t split_threshold_;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_TOPK_SPLITTER_H_
diff --git a/third_party/xla/xla/service/gpu/topk_splitter_test.cc b/third_party/xla/xla/service/gpu/topk_splitter_test.cc
deleted file mode 100644
index 834185f..0000000
--- a/third_party/xla/xla/service/gpu/topk_splitter_test.cc
+++ /dev/null
@@ -1,210 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/topk_splitter.h"
-
-#include <stdint.h>
-
-#include <cstddef>
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-
-#include "absl/strings/string_view.h"
-#include "absl/strings/substitute.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_dce.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/service/topk_rewriter.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/verified_hlo_module.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-
-namespace m = ::xla::match;
-
-namespace xla {
-namespace gpu {
-namespace {
-
-using ::tsl::testing::IsOkAndHolds;
-using TopkSplitterTest = HloTestBase;
-
-constexpr absl::string_view kComparator = R"(
-  %compare {
-    %p.1.lhs.40628 = s32[] parameter(2)
-    %p.1.rhs.40629 = s32[] parameter(3)
-    %constant.40630 = pred[] constant(true)
-    %broadcast.40631 = pred[] broadcast(pred[] %constant.40630), dimensions={}
-    %p.0.lhs.40626 = f32[] parameter(0)
-    %p.0.rhs.40627 = f32[] parameter(1)
-    %compare.40632 = pred[] compare(f32[] %p.0.lhs.40626, f32[] %p.0.rhs.40627), direction=GT, type=TOTALORDER
-    ROOT %select.40633 = pred[] select(pred[] %broadcast.40631, pred[] %compare.40632, pred[] %broadcast.40631)
-  })";
-
-TEST_F(TopkSplitterTest, SplitsTopK) {
-  const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
-  %arg.1 = f32[1,1073741824] parameter(0)
-  ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
-})",
-                                                  kComparator);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  EXPECT_THAT(RunHloPass(TopKSplitter(), module.get()), IsOkAndHolds(true));
-  auto first_topk = m::CustomCall(m::Reshape(m::Parameter(0)));
-  auto slice_result = [&](auto input, size_t i) {
-    return m::Reshape(m::Slice(m::GetTupleElement(input, i)));
-  };
-  auto index_correction =
-      m::Broadcast(m::Multiply(m::Iota(), m::Broadcast(m::Constant())));
-  auto sorted = m::Sort(
-      m::Reshape(m::GetTupleElement(first_topk, 0)),
-      m::Reshape(m::Add(m::GetTupleElement(first_topk, 1), index_correction)));
-  EXPECT_TRUE(
-      Match(module->entry_computation()->root_instruction(),
-            m::Tuple(slice_result(sorted, 0), slice_result(sorted, 1))));
-}
-
-TEST_F(TopkSplitterTest, SplitsTopKNoBatchDimension) {
-  const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
-  %arg.1 = f32[1073741824] parameter(0)
-  ROOT %cc.2 = (f32[5], s32[5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
-})",
-                                                  kComparator);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  EXPECT_THAT(RunHloPass(TopKSplitter(), module.get()), IsOkAndHolds(true));
-  auto first_topk = m::CustomCall(m::Reshape(m::Parameter(0)));
-  auto slice_result = [&](auto input, size_t i) {
-    return m::Reshape(m::Slice(m::GetTupleElement(input, i)));
-  };
-  auto index_correction =
-      m::Broadcast(m::Multiply(m::Iota(), m::Broadcast(m::Constant())));
-  auto sorted = m::Sort(
-      m::Reshape(m::GetTupleElement(first_topk, 0)),
-      m::Reshape(m::Add(m::GetTupleElement(first_topk, 1), index_correction)));
-  EXPECT_TRUE(
-      Match(module->entry_computation()->root_instruction(),
-            m::Tuple(slice_result(sorted, 0), slice_result(sorted, 1))));
-}
-
-TEST_F(TopkSplitterTest, SplitFailsUnderThreshold) {
-  const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
-  %arg.1 = f32[1,524288] parameter(0)
-  ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
-})",
-                                                  kComparator);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  EXPECT_THAT(
-      RunHloPass(TopKSplitter(/*split_threshold=*/1048576), module.get()),
-      IsOkAndHolds(false));
-}
-
-TEST_F(TopkSplitterTest, SplitFailsUnaligned) {
-  const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
-  %arg.1 = f32[1,524289] parameter(0)
-  ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
-})",
-                                                  kComparator);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  EXPECT_THAT(RunHloPass(TopKSplitter(/*split_threshold=*/1024), module.get()),
-              IsOkAndHolds(false));
-}
-
-TEST_F(TopkSplitterTest, SplitFailsLargeK) {
-  const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
-  %arg.1 = f32[1,524288] parameter(0)
-  ROOT %cc.2 = (f32[1,1024], s32[1,1024]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
-})",
-                                                  kComparator);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  EXPECT_THAT(RunHloPass(TopKSplitter(/*split_threshold=*/1024), module.get()),
-              IsOkAndHolds(false));
-}
-
-TEST_F(TopkSplitterTest, Equivalent) {
-  const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
-  %arg.1 = f32[1,16384] parameter(0)
-  ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
-})",
-                                                  kComparator);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  EXPECT_THAT(TopkDecomposer().Run(module.get()), IsOkAndHolds(true));
-  auto round_trip = [](HloModule* module) {
-    EXPECT_THAT(TopkRewriter([](const HloSortInstruction*, int64_t) {
-                  return true;
-                }).Run(module),
-                IsOkAndHolds(true));
-    EXPECT_THAT(TopKSplitter(1024).Run(module), IsOkAndHolds(true));
-    EXPECT_THAT(TopkDecomposer().Run(module), IsOkAndHolds(true));
-    EXPECT_TRUE(HloDCE().Run(module).status().ok());
-  };
-  EXPECT_TRUE(RunAndCompare(std::move(module), std::nullopt, round_trip));
-}
-
-TEST_F(TopkSplitterTest, StableSorts) {
-  const std::string hlo_string = absl::Substitute(R"(
-HloModule module
-$0
-ENTRY cluster {
-  %constant.1 = f32[] constant(42)
-  %broadcast.2= f32[1,16384] broadcast(f32[] %constant.1), dimensions={}
-  ROOT %cc.3 = (f32[1,5], s32[1,5]) custom-call(%broadcast.2), custom_call_target= "TopK", to_apply=%compare
-})",
-                                                  kComparator);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  EXPECT_THAT(TopkDecomposer().Run(module.get()), IsOkAndHolds(true));
-  auto round_trip = [](HloModule* module) {
-    EXPECT_THAT(TopkRewriter([](const HloSortInstruction*, int64_t) {
-                  return true;
-                }).Run(module),
-                IsOkAndHolds(true));
-    EXPECT_THAT(TopKSplitter(1024).Run(module), IsOkAndHolds(true));
-    EXPECT_THAT(TopkDecomposer().Run(module), IsOkAndHolds(true));
-    EXPECT_TRUE(HloDCE().Run(module).status().ok());
-  };
-  EXPECT_TRUE(RunAndCompare(std::move(module), std::nullopt, round_trip));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/topk_test.cc b/third_party/xla/xla/service/gpu/topk_test.cc
deleted file mode 100644
index 43e25b8..0000000
--- a/third_party/xla/xla/service/gpu/topk_test.cc
+++ /dev/null
@@ -1,159 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <stddef.h>
-
-#include <memory>
-#include <optional>
-#include <string>
-#include <string_view>
-#include <tuple>
-#include <utility>
-
-#include <gtest/gtest.h>
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/strings/substitute.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/topk_specializer.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/platform_util.h"
-#include "xla/service/topk_rewriter.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace {
-
-using ::testing::Combine;
-using ::testing::Values;
-
-// Params:
-//  - n_kb: number of elements in kilobytes.
-//  - k: number of elements to return.
-//  - batch_size
-//  - dtype
-using ParameterizedInterface =
-    ::testing::WithParamInterface<std::tuple<int, int, int, std::string_view>>;
-
-class TopkTest : public HloTestBase, public ParameterizedInterface {
- public:
-  TopkTest()
-      : HloTestBase(*PlatformUtil::GetPlatform("gpu"),
-                    *PlatformUtil::GetPlatform("gpu"), true, true, {}) {}
-
- protected:
-  absl::StatusOr<std::unique_ptr<HloModule>> TopkHlo(int n, int k,
-                                                     int batch_size,
-                                                     std::string_view dtype) {
-    return ParseAndReturnVerifiedModule(absl::Substitute(
-        R"(
-      %compare {
-        %p.1.lhs.40628 = s32[] parameter(2)
-        %p.1.rhs.40629 = s32[] parameter(3)
-        %constant.40630 = pred[] constant(true)
-        %broadcast.40631 = pred[] broadcast(pred[] %constant.40630), dimensions={}
-        %p.0.lhs.40626 = f32[] parameter(0)
-        %p.0.rhs.40627 = f32[] parameter(1)
-        %compare.40632 = pred[] compare(f32[] %p.0.lhs.40626, f32[] %p.0.rhs.40627), direction=GT, type=TOTALORDER
-        ROOT %select.40633 = pred[] select(pred[] %broadcast.40631, pred[] %compare.40632, pred[] %broadcast.40631)
-      }
-
-      ENTRY top_k {
-        %arg = $3[$2,$0] parameter(0)
-        ROOT %result = ($3[$2,$1], s32[$2,$1]) custom-call(%arg), custom_call_target="TopK", to_apply=%compare
-      }
-    )",
-        n, k, batch_size, dtype));
-  }
-};
-
-class GeneralizeTopkVisitor : public DfsHloRewriteVisitor {
- public:
-  absl::Status HandleCustomCall(HloInstruction* inst) override {
-    HloCustomCallInstruction* topk = DynCast<HloCustomCallInstruction>(inst);
-    if (topk == nullptr || topk->custom_call_target() != "__gpu$TopK") {
-      return absl::OkStatus();
-    }
-    HloComputation* comp = topk->parent();
-    auto original_shape = ShapeUtil::SliceTuple(topk->shape(), 0, 2);
-    HloInstruction* original_topk =
-        comp->AddInstruction(HloInstruction::CreateCustomCall(
-            original_shape, topk->operands(), topk->to_apply(), "TopK"));
-    // TupleUtil::ExtractPrefix creates the following structure:
-    //      TopK
-    //   -------------
-    //   |     |     |
-    //  Get   Get   Get
-    //    \    |     /
-    //     CreateTuple
-    // Here we walk to Create Tuple and replace it with the original topk.
-    HloInstruction* new_tuple = topk->users()[0]->users()[0];
-    return ReplaceInstruction(new_tuple, original_topk);
-  }
-};
-
-class GeneralizeTopk : public HloModulePass {
- public:
-  absl::string_view name() const override { return "generalized-topk"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(HloModule* module,
-                           const absl::flat_hash_set<absl::string_view>&
-                               execution_threads) override {
-    return GeneralizeTopkVisitor().RunOnModule(module, execution_threads);
-  }
-};
-
-void ToSortAndSlice(HloModule* module) {
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, GeneralizeTopk().Run(module));
-  ASSERT_TRUE(changed);
-  TF_ASSERT_OK_AND_ASSIGN(changed, TopkDecomposer().Run(module));
-  ASSERT_TRUE(changed);
-}
-
-TEST_P(TopkTest, ProducesCorrectResult) {
-  const auto [n_kb, k, batch_size, dtype] = GetParam();
-  const size_t n = n_kb * 1024;
-  TF_ASSERT_OK_AND_ASSIGN(auto topk_module, TopkHlo(n, k, batch_size, dtype));
-  TF_ASSERT_OK_AND_ASSIGN(bool changed,
-                          gpu::TopkSpecializer().Run(topk_module.get()));
-  ASSERT_TRUE(changed);
-  EXPECT_TRUE(
-      RunAndCompare(std::move(topk_module), std::nullopt, ToSortAndSlice));
-}
-
-INSTANTIATE_TEST_SUITE_P(
-    TopkTests, TopkTest,
-    Combine(
-        /*n_kb=*/Values(1, 8, 12, 32),
-        /*k=*/Values(1, 2, 4, 8, 16, 7, 12),
-        /*batch_size=*/Values(1, 16, 32, 64, 128),
-        /*dtype=*/Values(absl::string_view("f32"), "bf16")),
-    [](const auto& info) {
-      return absl::Substitute("n$0KiB_k$1_batch_size$2_$3",
-                              std::get<0>(info.param), std::get<1>(info.param),
-                              std::get<2>(info.param), std::get<3>(info.param));
-    });
-
-}  // namespace
-}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD
new file mode 100644
index 0000000..21a841a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/BUILD
@@ -0,0 +1,2963 @@
+load(
+    "@local_config_rocm//rocm:build_defs.bzl",
+    "if_rocm_is_configured",
+)
+load(
+    "@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
+    "if_cuda_is_configured",
+)
+load("//xla:xla.bzl", "xla_cc_test")
+load(
+    "//xla/stream_executor:build_defs.bzl",
+    "if_gpu_is_configured",
+)
+load("//xla/tests:build_defs.bzl", "xla_test")
+load("//xla/tsl:tsl.bzl", "if_google", "if_oss")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = ["//xla/service/gpu:__subpackages__"],
+    licenses = ["notice"],
+)
+
+cc_library(
+    name = "algebraic_simplifier",
+    srcs = [
+        "algebraic_simplifier.cc",
+    ],
+    hdrs = [
+        "algebraic_simplifier.h",
+    ],
+    deps = [
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:algebraic_simplifier",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu/fusions/triton:triton_support",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+xla_cc_test(
+    name = "algebraic_simplifier_test",
+    srcs = ["algebraic_simplifier_test.cc"],
+    deps = [
+        ":algebraic_simplifier",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:algebraic_simplifier",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+# End-to-end tested via //third_party/tensorflow/compiler/xla/service/gpu:dot_algorithm_support_test
+cc_library(
+    name = "algorithm_checker",
+    srcs = ["algorithm_checker.cc"],
+    hdrs = ["algorithm_checker.h"],
+    deps = [
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:algorithm_util",
+        "//xla/service:hlo_pass",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+    ],
+)
+
+cc_library(
+    name = "alias_passthrough_params",
+    srcs = ["alias_passthrough_params.cc"],
+    hdrs = ["alias_passthrough_params.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+    ],
+)
+
+xla_cc_test(
+    name = "alias_passthrough_params_test",
+    srcs = ["alias_passthrough_params_test.cc"],
+    tags = [
+        "nomsan",
+    ],
+    deps = [
+        ":alias_passthrough_params",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "//xla/tsl/lib/core:status_test_util",
+        "@local_tsl//tsl/platform:test",
+    ],
+)
+
+cc_library(
+    name = "all_gather_optimizer",
+    srcs = ["all_gather_optimizer.cc"],
+    hdrs = ["all_gather_optimizer.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:collective_ops_utils",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+    ],
+)
+
+xla_cc_test(
+    name = "all_gather_optimizer_test",
+    srcs = ["all_gather_optimizer_test.cc"],
+    deps = [
+        ":all_gather_optimizer",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_module_config",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+    ],
+)
+
+cc_library(
+    name = "all_reduce_blueconnect",
+    srcs = ["all_reduce_blueconnect.cc"],
+    hdrs = ["all_reduce_blueconnect.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:collective_ops_utils",
+        "//xla/service:computation_placer_hdr",
+        "//xla/service:global_device_id",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:btree",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "all_reduce_blueconnect_test",
+    srcs = ["all_reduce_blueconnect_test.cc"],
+    deps = [
+        ":all_reduce_blueconnect",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:computation_placer_hdr",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "all_reduce_splitter",
+    srcs = ["all_reduce_splitter.cc"],
+    hdrs = ["all_reduce_splitter.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:collective_opt_utils",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/cleanup",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "all_reduce_splitter_test",
+    srcs = ["all_reduce_splitter_test.cc"],
+    deps = [
+        ":all_reduce_splitter",
+        ":reduce_scatter_creator",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_pass_pipeline",
+        "//xla/tests:filecheck",
+        "//xla/tests:hlo_test_base",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "async_collective_annotator",
+    srcs = ["async_collective_annotator.cc"],
+    hdrs = ["async_collective_annotator.h"],
+    deps = [
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:backend_configs_cc",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "async_collective_annotator_test",
+    srcs = ["async_collective_annotator_test.cc"],
+    deps = [
+        ":async_collective_annotator",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:test_macros_header",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "async_wrapper",
+    srcs = ["async_wrapper.cc"],
+    hdrs = ["async_wrapper.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service:hlo_proto_cc",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+    ],
+)
+
+xla_cc_test(
+    name = "async_wrapper_test",
+    srcs = ["async_wrapper_test.cc"],
+    deps = [
+        ":async_wrapper",
+        "//xla:literal",
+        "//xla:literal_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service:hlo_proto_cc",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:literal_test_util",
+        "//xla/tests:verified_hlo_module",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "collective_permute_cycle_decomposer",
+    srcs = ["collective_permute_cycle_decomposer.cc"],
+    hdrs = ["collective_permute_cycle_decomposer.h"],
+    deps = [
+        "//xla:comparison_util",
+        "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:collective_ops_utils",
+        "//xla/service:hlo_parser",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:backend_configs_cc",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+    ],
+)
+
+xla_cc_test(
+    name = "collective_permute_cycle_decomposer_test",
+    srcs = ["collective_permute_cycle_decomposer_test.cc"],
+    deps = [
+        ":collective_permute_cycle_decomposer",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_parser",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "collective_permute_valid_iteration_annotator",
+    srcs = ["collective_permute_valid_iteration_annotator.cc"],
+    hdrs = ["collective_permute_valid_iteration_annotator.h"],
+    deps = [
+        "//xla:literal_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:collective_ops_utils",
+        "//xla/service:hlo_pass",
+        "//xla/service:pattern_matcher",
+        "//xla/service:while_loop_analysis",
+    ],
+)
+
+xla_cc_test(
+    name = "collective_permute_valid_iteration_annotator_test",
+    srcs = ["collective_permute_valid_iteration_annotator_test.cc"],
+    deps = [
+        ":collective_permute_valid_iteration_annotator",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:collective_ops_utils",
+        "//xla/service:hlo_pass_pipeline",
+        "//xla/service:while_loop_trip_count_annotator",
+        "//xla/tests:hlo_test_base",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "command_buffer_scheduling",
+    srcs = ["command_buffer_scheduling.cc"],
+    hdrs = ["command_buffer_scheduling.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/ffi:ffi_api",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:hlo_traversal",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:variant_visitor",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "command_buffer_scheduling_test",
+    srcs = ["command_buffer_scheduling_test.cc"],
+    deps = [
+        ":command_buffer_scheduling",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_parser",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:filecheck",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:verified_hlo_module",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "conv_padding_legalization",
+    srcs = ["conv_padding_legalization.cc"],
+    hdrs = ["conv_padding_legalization.h"],
+    deps = [
+        "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:window_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "//xla/service:shape_inference",
+        "//xla/service/gpu:cublas_cudnn",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "conv_padding_legalization_test",
+    srcs = ["conv_padding_legalization_test.cc"],
+    deps = [
+        ":conv_padding_legalization",
+        "//xla:shape_util",
+        "//xla:test",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@local_tsl//tsl/platform:test",
+    ],
+)
+
+cc_library(
+    name = "conv_rewriter",
+    srcs = ["conv_rewriter.cc"],
+    hdrs = ["conv_rewriter.h"],
+    deps = [
+        "//xla:permutation_util",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:window_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "conv_rewriter_test",
+    srcs = ["conv_rewriter_test.cc"],
+    deps = [
+        ":conv_rewriter",
+        "//xla:array4d",
+        "//xla:literal_util",
+        "//xla:protobuf_util",
+        "//xla:shape_util",
+        "//xla:test",
+        "//xla:test_helpers",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service:shape_inference",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings:str_format",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+    ],
+)
+
+cc_library(
+    name = "convert_async_collectives_to_sync",
+    srcs = ["convert_async_collectives_to_sync.cc"],
+    hdrs = ["convert_async_collectives_to_sync.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "//xla/service:convert_async_collectives_to_sync",
+        "//xla/service/gpu:backend_configs_cc",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "convert_async_collectives_to_sync_test",
+    srcs = ["convert_async_collectives_to_sync_test.cc"],
+    deps = [
+        ":convert_async_collectives_to_sync",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "copy_fusion",
+    srcs = ["copy_fusion.cc"],
+    hdrs = ["copy_fusion.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/service/gpu:hlo_traversal",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:reduction_utils",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+    ],
+)
+
+xla_cc_test(
+    name = "copy_fusion_test",
+    srcs = ["copy_fusion_test.cc"],
+    deps = [
+        ":copy_fusion",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+cc_library(
+    name = "cublas_pad_for_gemms",
+    srcs = ["cublas_pad_for_gemms.cc"],
+    hdrs = ["cublas_pad_for_gemms.h"],
+    deps = [
+        ":gemm_fusion",
+        "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu/fusions/triton:triton_support",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "cublas_pad_for_gemms_test",
+    srcs = ["cublas_pad_for_gemms_test.cc"],
+    tags = [
+        "nomsan",
+    ],
+    deps = [
+        ":cublas_pad_for_gemms",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # build_cleaner: keep
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+cc_library(
+    name = "cudnn_custom_call_converter",
+    srcs = ["cudnn_custom_call_converter.cc"],
+    hdrs = ["cudnn_custom_call_converter.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:ir_emission_utils",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+    ],
+)
+
+xla_cc_test(
+    name = "cudnn_custom_call_converter_test",
+    srcs = ["cudnn_custom_call_converter_test.cc"],
+    deps = [
+        ":cudnn_custom_call_converter",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+cc_library(
+    name = "cudnn_fused_conv_rewriter",
+    srcs = ["cudnn_fused_conv_rewriter.cc"],
+    hdrs = ["cudnn_fused_conv_rewriter.h"],
+    deps = [
+        "//xla:comparison_util",
+        "//xla:debug_options_flags",
+        "//xla:literal",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "//xla/service:pattern_matcher",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/stream_executor",
+        "//xla/stream_executor:dnn",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:ml_dtypes",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_test(
+    name = "cudnn_fused_conv_rewriter_test",
+    srcs = ["cudnn_fused_conv_rewriter_test.cc"],
+    backend_tags = {
+        "gpu_a100": [
+            "noasan",
+            "nomsan",
+            "no_rocm",
+        ],
+    },
+    backends = [
+        "gpu_a100",
+        "gpu_amd_any",
+    ] + if_oss(["gpu_any"]),
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
+        "TENSORFLOW_USE_ROCM=1",
+    ]),
+    shard_count = 10,
+    deps = [
+        ":conv_rewriter",
+        ":cudnn_fused_conv_rewriter",
+        "//xla:comparison_util",
+        "//xla:error_spec",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:algebraic_simplifier",
+        "//xla/service:convert_mover",
+        "//xla/service:hlo_constant_folding",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_pass",
+        "//xla/service:hlo_pass_pipeline",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service:reshape_mover",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/service/gpu/tests:gpu_codegen_test",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:dnn",
+        "//xla/tests:filecheck",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:verified_hlo_module",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test_main",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_headers",
+        "@local_config_cuda//cuda:cudnn_header",
+    ]) + if_rocm_is_configured([
+        "@local_config_rocm//rocm:rocm_headers",
+    ]),
+)
+
+cc_library(
+    name = "cudnn_fused_mha_rewriter",
+    srcs = ["cudnn_fused_mha_rewriter.cc"],
+    hdrs = ["cudnn_fused_mha_rewriter.h"],
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+    deps = [
+        "//xla:permutation_util",
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla:types",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service:pattern_matcher",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/stream_executor",
+        "//xla/stream_executor:dnn",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_headers",
+    ]),
+)
+
+xla_test(
+    name = "cudnn_fused_mha_rewriter_test",
+    srcs = ["cudnn_fused_mha_rewriter_test.cc"],
+    backend_tags = {"gpu": [
+        "requires-gpu-nvidia",
+        "no_rocm",
+    ]},
+    backends = [
+        "gpu",
+    ],
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+    deps = [
+        ":cudnn_fused_mha_rewriter",
+        ":cudnn_fused_mha_transpose_fusion",
+        "//xla:error_spec",
+        "//xla:test_helpers",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:algebraic_simplifier",
+        "//xla/service:computation_layout",
+        "//xla/service:hlo_cse",
+        "//xla/service:hlo_dce",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_parser",
+        "//xla/service:hlo_verifier",
+        "//xla/service:layout_normalization",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service:reshape_decomposer",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:dnn",
+        "//xla/tests:hlo_test_base",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test_main",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_headers",
+        "@local_config_cuda//cuda:cudnn_header",
+    ]),
+)
+
+# Tested via cudnn_fused_mha_rewriter_test.
+cc_library(
+    name = "cudnn_fused_mha_transpose_fusion",
+    srcs = ["cudnn_fused_mha_transpose_fusion.cc"],
+    hdrs = ["cudnn_fused_mha_transpose_fusion.h"],
+    deps = [
+        "//xla:permutation_util",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service:pattern_matcher",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:matmul_utils",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+# Tested via //third_party/tensorflow/compiler/xla/service/gpu/fusions:cudnn_test
+cc_library(
+    name = "cudnn_fusion_compiler",
+    srcs = if_cuda_is_configured(["cudnn_fusion_compiler.cc"]),
+    hdrs = if_cuda_is_configured(["cudnn_fusion_compiler.h"]),
+    deps = if_cuda_is_configured([
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cudnn_support_utils",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:kernel_reuse_cache",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/service/gpu:triton_fusion_analysis",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_config_cuda//cuda:cudnn_header",
+        "//xla:shape_util",
+        "//xla:comparison_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:hlo_pass",
+        "//xla/stream_executor:dnn",
+        "//xla/stream_executor:stream_executor_h",
+        "//xla/service:dump",
+        "//xla/stream_executor/cuda:cudnn_frontend_helpers",
+        "//xla/stream_executor/cuda:cudnn_plugin",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ]),
+)
+
+cc_library(
+    name = "cudnn_norm_rewriter",
+    srcs = ["cudnn_norm_rewriter.cc"],
+    hdrs = ["cudnn_norm_rewriter.h"],
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+    deps = [
+        "//xla:shape_util",
+        "//xla:types",
+        "//xla:util",
+        "//xla:window_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "//xla/service:pattern_matcher",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/stream_executor",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/protobuf:dnn_proto_cc",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_headers",
+        "@local_config_cuda//cuda:cudnn_header",
+    ]) + if_google([
+        "@com_google_protobuf//:wrappers_cc_proto",
+    ]),
+)
+
+xla_test(
+    name = "cudnn_norm_rewriter_test",
+    srcs = ["cudnn_norm_rewriter_test.cc"],
+    backends = ["gpu"],
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+    deps = [
+        ":cudnn_norm_rewriter",
+        "//xla:error_spec",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu/tests:gpu_codegen_test",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:filecheck",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_googletest//:gtest_main",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_headers",
+        "@local_config_cuda//cuda:cudnn_header",
+    ]),
+)
+
+cc_library(
+    name = "cudnn_pad_for_convolutions",
+    srcs = ["cudnn_pad_for_convolutions.cc"],
+    hdrs = ["cudnn_pad_for_convolutions.h"],
+    deps = [
+        "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:cudnn_support_utils",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/stream_executor",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/functional:bind_front",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "cudnn_pad_for_convolutions_test",
+    srcs = ["cudnn_pad_for_convolutions_test.cc"],
+    deps = [
+        ":cudnn_pad_for_convolutions",
+        "//xla/service:hlo_parser",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # build_cleaner: keep
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+cc_library(
+    name = "cudnn_simplify_padding",
+    srcs = ["cudnn_simplify_padding.cc"],
+    hdrs = ["cudnn_simplify_padding.h"],
+    deps = [
+        "//xla:literal",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "//xla/service:pattern_matcher",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "cudnn_simplify_padding_test",
+    srcs = ["cudnn_simplify_padding_test.cc"],
+    deps = [
+        ":cudnn_pad_for_convolutions",
+        ":cudnn_simplify_padding",
+        ":cudnn_vectorize_convolutions",
+        "//xla:literal",
+        "//xla:util",
+        "//xla/service:algebraic_simplifier",
+        "//xla/service:call_inliner",
+        "//xla/service:hlo_pass",
+        "//xla/service:hlo_pass_pipeline",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service:reshape_mover",
+        "//xla/service:tuple_simplifier",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:dnn",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # build_cleaner: keep
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/functional:function_ref",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "cudnn_vectorize_convolutions",
+    srcs = ["cudnn_vectorize_convolutions.cc"],
+    hdrs = ["cudnn_vectorize_convolutions.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/client:xla_builder",
+        "//xla/client:xla_computation",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:cudnn_support_utils",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/stream_executor",
+        "//xla/stream_executor:dnn",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "cudnn_vectorize_convolutions_test",
+    srcs = ["cudnn_vectorize_convolutions_test.cc"],
+    deps = [
+        ":cudnn_vectorize_convolutions",
+        "//xla:util",
+        "//xla/service:call_inliner",
+        "//xla/service:hlo_parser",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:dnn",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # build_cleaner: keep
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+# TODO(b/358278858): Currently lacking test coverage.
+cc_library(
+    name = "cudnn_custom_call_compiler",
+    srcs = if_cuda_is_configured(["cudnn_custom_call_compiler.cc"]),
+    hdrs = if_cuda_is_configured(["cudnn_custom_call_compiler.h"]),
+    deps = if_cuda_is_configured([
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_config_cuda//cuda:cudnn_header",
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu/runtime:cudnn_thunk",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/stream_executor:dnn",
+        "//xla/stream_executor:stream_executor_h",
+        "//xla/stream_executor/cuda:cudnn_frontend_helpers",
+        "//xla/stream_executor/cuda:cudnn_plugin",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ]),
+)
+
+cc_library(
+    name = "custom_kernel_fusion_rewriter",
+    srcs = ["custom_kernel_fusion_rewriter.cc"],
+    hdrs = ["custom_kernel_fusion_rewriter.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu/kernels:custom_fusion_library",
+        "//xla/service/gpu/kernels:custom_kernel_fusion_pattern",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "custom_kernel_fusion_rewriter_test",
+    srcs = ["custom_kernel_fusion_rewriter_test.cc"],
+    deps = [
+        ":custom_kernel_fusion_rewriter",
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu/kernels:custom_kernel_fusion_pattern",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "dot_dimension_sorter",
+    srcs = ["dot_dimension_sorter.cc"],
+    hdrs = ["dot_dimension_sorter.h"],
+    deps = [
+        "//xla:permutation_util",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+    ],
+)
+
+xla_test(
+    name = "dot_dimension_sorter_test",
+    srcs = ["dot_dimension_sorter_test.cc"],
+    backends = ["gpu"],
+    deps = [
+        ":dot_dimension_sorter",
+        "//xla:error_spec",
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu/tests:gpu_codegen_test",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "dot_operand_converter",
+    srcs = ["dot_operand_converter.cc"],
+    hdrs = ["dot_operand_converter.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:op_expander_pass",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+    ],
+)
+
+xla_test(
+    name = "dot_operand_converter_test",
+    srcs = if_gpu_is_configured(["dot_operand_converter_test.cc"]),
+    backends = [
+        "gpu_a100",
+        "gpu_p100",
+        "gpu_v100",
+        "gpu_amd_any",
+    ],
+    deps = if_gpu_is_configured(
+        [
+            ":dot_operand_converter",
+            "@com_google_googletest//:gtest",
+            "@com_google_absl//absl/strings",
+            "@com_google_absl//absl/strings:string_view",
+            "//xla:shape_util",
+            "//xla:xla_data_proto_cc",
+            "//xla/hlo/ir:hlo",
+            "//xla/hlo/utils:hlo_matchers",
+            "//xla/service:pattern_matcher",
+            "//xla/tests:hlo_test_base",
+            "//xla/tests:xla_internal_test_main",
+            "@local_tsl//tsl/platform:statusor",
+        ],
+    ) + [
+        # b/317293391
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "dot_sparsity_rewriter",
+    srcs = ["dot_sparsity_rewriter.cc"],
+    hdrs = ["dot_sparsity_rewriter.h"],
+    deps = [
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "dot_sparsity_rewriter_test",
+    srcs = ["dot_sparsity_rewriter_test.cc"],
+    deps = [
+        ":dot_sparsity_rewriter",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "double_buffer_loop_unrolling",
+    srcs = ["double_buffer_loop_unrolling.cc"],
+    hdrs = ["double_buffer_loop_unrolling.h"],
+    deps = [
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/ir:hlo_instruction_utils",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:collective_ops_utils",
+        "//xla/service:flatten_call_graph",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "double_buffer_loop_unrolling_test",
+    srcs = ["double_buffer_loop_unrolling_test.cc"],
+    deps = [
+        ":double_buffer_loop_unrolling",
+        "//xla:test",
+        "//xla:xla_data_proto_cc",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:tuple_simplifier",
+        "//xla/tests:filecheck",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "dynamic_slice_fusion_rewriter",
+    srcs = ["dynamic_slice_fusion_rewriter.cc"],
+    hdrs = ["dynamic_slice_fusion_rewriter.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/ffi:ffi_api",
+        "//xla/ffi/api:c_api",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:custom_call_target_registry",
+        "//xla/service:hlo_pass",
+        "//xla/service:pattern_matcher",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:gpu_constants",
+        "//xla/service/gpu:hlo_traversal",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu/kernels:custom_fusion_library",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "dynamic_slice_fusion_rewriter_test",
+    srcs = if_cuda_is_configured(["dynamic_slice_fusion_rewriter_test.cc"]),
+    deps = [
+        ":dynamic_slice_fusion_rewriter",
+        "//xla:shape_util",
+        "//xla/client:xla_builder",
+        "//xla/client/lib:constants",
+        "//xla/ffi",
+        "//xla/ffi:ffi_api",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:buffer_value",
+        "//xla/service:custom_call_target_registry",
+        "//xla/service:executable",
+        "//xla/service:hlo_memory_scheduler",
+        "//xla/service:hlo_module_config",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/stream_executor",
+        "//xla/stream_executor/gpu:gpu_types_header",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/status",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "fusion_merger",
+    srcs = ["fusion_merger.cc"],
+    hdrs = ["fusion_merger.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_cost_analysis",
+        "//xla/service:hlo_graph_dumper",
+        "//xla/service:hlo_pass",
+        "//xla/service:instruction_fusion",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+        "//xla/service/gpu/model:gpu_performance_model",
+        "//xla/service/gpu/model:gpu_performance_model_base",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:status",
+    ],
+)
+
+xla_cc_test(
+    name = "fusion_merger_test",
+    srcs = ["fusion_merger_test.cc"],
+    tags = [
+        "nomsan",
+    ],
+    deps = [
+        ":fusion_merger",
+        "//xla:shape_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_cost_analysis",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "fusion_wrapper",
+    srcs = ["fusion_wrapper.cc"],
+    hdrs = ["fusion_wrapper.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:gpu_fusible",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:errors",
+    ],
+)
+
+xla_cc_test(
+    name = "fusion_wrapper_test",
+    srcs = ["fusion_wrapper_test.cc"],
+    deps = [
+        ":fusion_wrapper",
+        "//xla/tests:hlo_test_base",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "gemm_broadcast_folding_rewriter",
+    srcs = ["gemm_broadcast_folding_rewriter.cc"],
+    hdrs = ["gemm_broadcast_folding_rewriter.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service:pattern_matcher",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_test(
+    name = "gemm_broadcast_folding_rewriter_test",
+    srcs = ["gemm_broadcast_folding_rewriter_test.cc"],
+    backends = ["gpu"],
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
+        "TENSORFLOW_USE_ROCM=1",
+    ]),
+    deps = [
+        ":gemm_broadcast_folding_rewriter",
+        ":gemm_rewriter",
+        "//xla:error_spec",
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu/tests:gpu_codegen_test",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "gemm_fusion",
+    srcs = ["gemm_fusion.cc"],
+    hdrs = ["gemm_fusion.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service:instruction_fusion",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_padding_requirements",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu:triton_fusion_analysis",
+        "//xla/service/gpu:triton_tiling_propagation",
+        "//xla/service/gpu/fusions/triton:triton_support",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:tensor_float_32_utils",
+    ],
+)
+
+xla_cc_test(
+    name = "gemm_fusion_test",
+    srcs = ["gemm_fusion_test.cc"],
+    deps = [
+        ":gemm_fusion",
+        "//xla:autotuning_proto_cc",
+        "//xla:xla_data_proto_cc",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:cublas_padding_requirements",
+        "//xla/service/gpu:triton_fusion_analysis",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:filecheck",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:verified_hlo_module",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "gemm_rewriter",
+    srcs = ["gemm_rewriter.cc"],
+    hdrs = ["gemm_rewriter.h"],
+    deps = [
+        "//xla:literal",
+        "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla:types",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/evaluator:hlo_evaluator",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:algorithm_util",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "//xla/service:pattern_matcher",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/stream_executor:blas",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor/gpu:gpu_blas_lt",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:ml_dtypes",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/protobuf:dnn_proto_cc",
+    ],
+)
+
+xla_test(
+    name = "gemm_rewriter_test",
+    srcs = ["gemm_rewriter_test.cc"],
+    backends = ["gpu"],
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
+        "TENSORFLOW_USE_ROCM=1",
+    ]),
+    deps = [
+        ":gemm_rewriter",
+        "//xla:error_spec",
+        "//xla:test",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:buffer_assignment",
+        "//xla/service:executable",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_pass",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:gpu_executable",
+        "//xla/service/gpu/tests:gpu_codegen_test",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:device_memory_allocator",
+        "//xla/stream_executor:stream_executor_memory_allocator",
+        "//xla/tests:filecheck",
+        "//xla/tests:verified_hlo_module",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test_main",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_headers",
+    ]) + if_rocm_is_configured([
+        "@local_config_rocm//rocm:rocm_headers",
+    ]),
+)
+
+cc_library(
+    name = "gemv_rewriter",
+    srcs = ["gemv_rewriter.cc"],
+    hdrs = ["gemv_rewriter.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "gemv_rewriter_test",
+    srcs = ["gemv_rewriter_test.cc"],
+    deps = [
+        ":gemv_rewriter",
+        "//xla/hlo/ir:hlo",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+# TODO(b/358278858): Currently lacking test coverage.
+cc_library(
+    name = "gpusolver_rewriter",
+    srcs = if_gpu_is_configured(["gpusolver_rewriter.cc"]),
+    hdrs = if_gpu_is_configured(["gpusolver_rewriter.h"]),
+    deps = if_gpu_is_configured([
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "//xla:comparison_util",
+        "//xla:literal",
+        "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:cusolver_context",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/stream_executor",
+        "//xla/stream_executor:blas",
+        "//xla/stream_executor:device_memory_allocator",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+    ]),
+)
+
+cc_library(
+    name = "horizontal_input_fusion",
+    srcs = ["horizontal_input_fusion.cc"],
+    hdrs = ["horizontal_input_fusion.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_test(
+    name = "horizontal_input_fusion_test",
+    srcs = ["horizontal_input_fusion_test.cc"],
+    backends = ["gpu"],
+    deps = [
+        ":horizontal_input_fusion",
+        "//xla:error_spec",
+        "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla:test",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu/tests:gpu_codegen_test",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:xla_internal_test_main",
+    ],
+)
+
+cc_library(
+    name = "horizontal_loop_fusion",
+    srcs = ["horizontal_loop_fusion.cc"],
+    hdrs = ["horizontal_loop_fusion.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "//xla/service:sub_byte_normalization",
+        "//xla/service/gpu:gpu_fusible",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_test(
+    name = "horizontal_loop_fusion_test",
+    srcs = ["horizontal_loop_fusion_test.cc"],
+    backends = ["gpu"],
+    deps = [
+        ":horizontal_loop_fusion",
+        ":instruction_fusion",
+        "//xla:error_spec",
+        "//xla:shape_util",
+        "//xla:test",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_dce",
+        "//xla/service:hlo_parser",
+        "//xla/service:hlo_pass",
+        "//xla/service:hlo_pass_pipeline",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/log",
+    ],
+)
+
+cc_library(
+    name = "instruction_fusion",
+    srcs = ["instruction_fusion.cc"],
+    hdrs = ["instruction_fusion.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:fusion_node_indexing_evaluation",
+        "//xla/service:fusion_queue",
+        "//xla/service:hlo_pass",
+        "//xla/service:instruction_fusion",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/meta:type_traits",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+xla_cc_test(
+    name = "instruction_fusion_test",
+    srcs = ["instruction_fusion_test.cc"],
+    tags = [
+        "nomsan",
+        "not_run:arm",
+    ],
+    deps = [
+        ":instruction_fusion",
+        "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:test_utils",
+        "//xla/tests:verified_hlo_module",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "layout_assignment",
+    srcs = ["layout_assignment.cc"],
+    hdrs = ["layout_assignment.h"],
+    deps = [
+        "//xla:shape_layout",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:window_util",
+        "//xla:xla_data_proto_cc",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:computation_layout",
+        "//xla/service:host_memory_offload_annotations_hdr",
+        "//xla/service:layout_assignment",
+        "//xla/service:logical_buffer",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu:reduction_utils",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/stream_executor",
+        "//xla/stream_executor:dnn",
+        "//xla/tsl/util:env_var",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "layout_assignment_test",
+    srcs = ["layout_assignment_test.cc"],
+    deps = [
+        ":layout_assignment",
+        "//xla:shape_layout",
+        "//xla:shape_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:computation_layout",
+        "//xla/service:hlo_parser",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:stream_executor_util",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:dnn",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # build_cleaner: keep
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "move_copy_to_users",
+    srcs = ["move_copy_to_users.cc"],
+    hdrs = ["move_copy_to_users.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "move_copy_to_users_test",
+    srcs = ["move_copy_to_users_test.cc"],
+    deps = [
+        ":move_copy_to_users",
+        "//xla/service:layout_assignment",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "multi_output_fusion",
+    srcs = ["multi_output_fusion.cc"],
+    hdrs = ["multi_output_fusion.h"],
+    deps = [
+        "//xla:debug_options_flags",
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/ir:hlo_dfs_reachability",
+        "//xla/service:hlo_cost_analysis",
+        "//xla/service:hlo_graph_dumper",
+        "//xla/service:hlo_pass",
+        "//xla/service:instruction_fusion",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+        "//xla/service/gpu/model:gpu_performance_model",
+        "//xla/service/gpu/model:gpu_performance_model_base",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "multi_output_fusion_test",
+    srcs = ["multi_output_fusion_test.cc"],
+    tags = [
+        "nomsan",
+    ],
+    deps = [
+        ":multi_output_fusion",
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_cost_analysis",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+cc_library(
+    name = "pipelined_p2p_rewriter",
+    srcs = ["pipelined_p2p_rewriter.cc"],
+    hdrs = ["pipelined_p2p_rewriter.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:collective_ops_utils",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "pipelined_p2p_rewriter_test",
+    srcs = ["pipelined_p2p_rewriter_test.cc"],
+    deps = [
+        ":pipelined_p2p_rewriter",
+        "//xla/hlo/ir:hlo",
+        "//xla/tests:filecheck",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "priority_fusion",
+    srcs = ["priority_fusion.cc"],
+    hdrs = ["priority_fusion.h"],
+    deps = [
+        "//xla:debug_options_flags",
+        "//xla:shape_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:dump",
+        "//xla/service:fusion_queue",
+        "//xla/service:hlo_cost_analysis",
+        "//xla/service:hlo_graph_dumper",
+        "//xla/service:hlo_pass",
+        "//xla/service:instruction_fusion",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:fusion_process_dump_proto_cc",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:hlo_traversal",
+        "//xla/service/gpu/model:fusion_analysis_cache",
+        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+        "//xla/service/gpu/model:gpu_performance_model",
+        "//xla/service/gpu/model:gpu_performance_model_base",
+        "//xla/service/gpu/model:symbolic_tile_analysis",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/meta:type_traits",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/time",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:blocking_counter",
+        "@local_tsl//tsl/platform:env",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:status",
+    ],
+)
+
+xla_cc_test(
+    name = "priority_fusion_test",
+    srcs = ["priority_fusion_test.cc"],
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+    tags = ["no_pip"],
+    deps = [
+        ":priority_fusion",
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_cost_analysis",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:verified_hlo_module",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "reduce_scatter_creator",
+    srcs = ["reduce_scatter_creator.cc"],
+    hdrs = ["reduce_scatter_creator.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:collective_opt_utils",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+    ],
+)
+
+xla_cc_test(
+    name = "reduce_scatter_creator_test",
+    srcs = ["reduce_scatter_creator_test.cc"],
+    deps = [
+        ":reduce_scatter_creator",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_module_config",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "reduction_degenerate_dim_remover",
+    srcs = ["reduction_degenerate_dim_remover.cc"],
+    hdrs = ["reduction_degenerate_dim_remover.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "reduction_degenerate_dim_remover_test",
+    srcs = [
+        "reduction_degenerate_dim_remover_test.cc",
+    ],
+    deps = [
+        ":reduction_degenerate_dim_remover",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "reduction_dimension_grouper",
+    srcs = ["reduction_dimension_grouper.cc"],
+    hdrs = ["reduction_dimension_grouper.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "reduction_dimension_grouper_test",
+    srcs = [
+        "reduction_dimension_grouper_test.cc",
+    ],
+    deps = [
+        ":reduction_dimension_grouper",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "reduction_layout_normalizer",
+    srcs = ["reduction_layout_normalizer.cc"],
+    hdrs = ["reduction_layout_normalizer.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_test(
+    name = "reduction_layout_normalizer_test",
+    srcs = [
+        "reduction_layout_normalizer_test.cc",
+    ],
+    backends = ["gpu"],
+    deps = [
+        ":reduction_layout_normalizer",
+        "//xla:error_spec",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "reduction_splitter",
+    srcs = ["reduction_splitter.cc"],
+    hdrs = ["reduction_splitter.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:reduction_utils",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "reduction_splitter_test",
+    srcs = ["reduction_splitter_test.cc"],
+    deps = [
+        ":reduction_splitter",
+        "//xla:shape_util",
+        "//xla:test",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_parser",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+    ],
+)
+
+cc_library(
+    name = "rename_fusions",
+    srcs = ["rename_fusions.cc"],
+    hdrs = ["rename_fusions.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:hlo_traversal",
+        "//xla/service/gpu:ir_emission_utils",
+        "@com_google_absl//absl/container:btree",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+xla_cc_test(
+    name = "rename_fusions_test",
+    srcs = ["rename_fusions_test.cc"],
+    deps = [
+        ":rename_fusions",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+cc_library(
+    name = "sanitize_constant_names",
+    srcs = ["sanitize_constant_names.cc"],
+    hdrs = ["sanitize_constant_names.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service:name_uniquer",
+        "//xla/service/llvm_ir:buffer_assignment_util",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:logging",
+    ],
+)
+
+xla_cc_test(
+    name = "sanitize_constant_names_test",
+    srcs = ["sanitize_constant_names_test.cc"],
+    deps = [
+        ":sanitize_constant_names",
+        "//xla:literal_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+    ],
+)
+
+cc_library(
+    name = "scatter_slice_simplifier",
+    srcs = ["scatter_slice_simplifier.cc"],
+    hdrs = ["scatter_slice_simplifier.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+# TODO(b/358278858): Currently lacking test coverage.
+cc_library(
+    name = "scatter_expander",
+    srcs = ["scatter_expander.cc"],
+    hdrs = ["scatter_expander.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:scatter_expander",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+xla_cc_test(
+    name = "scatter_slice_simplifier_test",
+    srcs = ["scatter_slice_simplifier_test.cc"],
+    deps = [
+        ":scatter_slice_simplifier",
+        "//xla:shape_util",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "schedule_postprocessing",
+    srcs = ["schedule_postprocessing.cc"],
+    hdrs = ["schedule_postprocessing.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:backend_configs_cc",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "schedule_postprocessing_test",
+    srcs = ["schedule_postprocessing_test.cc"],
+    deps = [
+        ":schedule_postprocessing",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_parser",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "scheduling_instruction_annotator",
+    srcs = ["scheduling_instruction_annotator.cc"],
+    hdrs = ["scheduling_instruction_annotator.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "scheduling_instruction_annotator_test",
+    srcs = ["scheduling_instruction_annotator_test.cc"],
+    deps = [
+        ":scheduling_instruction_annotator",
+        "//xla/hlo/ir:hlo",
+        "//xla/tests:filecheck",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "softmax_rewriter_triton",
+    srcs = ["softmax_rewriter_triton.cc"],
+    hdrs = ["softmax_rewriter_triton.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:hlo_cost_analysis",
+        "//xla/service:hlo_pass",
+        "//xla/service:instruction_fusion",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:hlo_traversal",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu/fusions/triton:triton_support",
+        "//xla/service/gpu/model:fusion_analysis_cache",
+        "//xla/service/gpu/model:gpu_indexing_performance_model",
+        "//xla/service/gpu/model:symbolic_tile_analysis",
+        "//xla/service/gpu/model:tiled_hlo_computation",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "softmax_rewriter_triton_test",
+    srcs = ["softmax_rewriter_triton_test.cc"],
+    deps = [
+        ":softmax_rewriter_triton",
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:instruction_fusion",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu/fusions/triton:triton_support",
+        "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # build_cleaner: keep
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "sort_rewriter",
+    srcs = if_gpu_is_configured(
+        ["sort_rewriter.cc"],
+        ["sort_rewriter_stub.cc"],
+    ),
+    hdrs = ["sort_rewriter.h"],
+    deps = [
+        "//xla:comparison_util",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service:stable_sort_expander",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu/runtime:cub_sort_thunk",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_test(
+    name = "sort_rewriter_test",
+    srcs = if_cuda_is_configured(["sort_rewriter_test.cc"]),
+    backends = ["gpu"],
+    tags = ["no_oss"],
+    deps = [
+        ":sort_rewriter",
+        "//xla:error_spec",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+    ],
+)
+
+cc_library(
+    name = "stream_attribute_annotator",
+    srcs = ["stream_attribute_annotator.cc"],
+    hdrs = ["stream_attribute_annotator.h"],
+    deps = [
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/service/gpu/runtime:thunk",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "stream_attribute_annotator_test",
+    srcs = ["stream_attribute_annotator_test.cc"],
+    deps = [
+        ":stream_attribute_annotator",
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/tests:filecheck",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "stream_attribute_async_wrapper",
+    srcs = ["stream_attribute_async_wrapper.cc"],
+    hdrs = ["stream_attribute_async_wrapper.h"],
+    deps = [
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu/runtime:thunk",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "stream_attribute_async_wrapper_test",
+    srcs = ["stream_attribute_async_wrapper_test.cc"],
+    deps = [
+        ":stream_attribute_async_wrapper",
+        "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "topk_specializer",
+    srcs = ["topk_specializer.cc"],
+    hdrs = ["topk_specializer.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service:hlo_proto_cc",
+        "//xla/service:tuple_util",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+xla_test(
+    name = "topk_specializer_test",
+    srcs = ["topk_specializer_test.cc"],
+    backends = ["gpu"],
+    deps = [
+        ":topk_specializer",
+        "//xla:shape_util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "//xla/service:platform_util",
+        "//xla/service:topk_rewriter",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+cc_library(
+    name = "topk_splitter",
+    srcs = ["topk_splitter.cc"],
+    hdrs = ["topk_splitter.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/numeric:bits",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "topk_splitter_test",
+    srcs = ["topk_splitter_test.cc"],
+    deps = [
+        ":topk_splitter",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_dce",
+        "//xla/service:pattern_matcher",
+        "//xla/service:topk_rewriter",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:verified_hlo_module",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+    ],
+)
+
+cc_library(
+    name = "tree_reduction_rewriter",
+    srcs = ["tree_reduction_rewriter.cc"],
+    hdrs = ["tree_reduction_rewriter.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:collective_ops_utils",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:reduction_utils",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/numeric:bits",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "tree_reduction_rewriter_test",
+    srcs = [
+        "tree_reduction_rewriter_test.cc",
+    ],
+    deps = [
+        ":tree_reduction_rewriter",
+        "//xla/stream_executor:device_description",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
+# TODO(b/358278858): Currently lacking test coverage.
+cc_library(
+    name = "triangular_solve_rewriter",
+    srcs = ["triangular_solve_rewriter.cc"],
+    hdrs = ["triangular_solve_rewriter.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "//xla/service/gpu:cublas_cudnn",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
+    name = "triton_fusion_numerics_verifier",
+    srcs = ["triton_fusion_numerics_verifier.cc"],
+    hdrs = ["triton_fusion_numerics_verifier.h"],
+    tags = ["gpu"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:executable",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_pass",
+        "//xla/service:shaped_buffer",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:buffer_comparator",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu/autotuning:autotuner_compile_util",
+        "//xla/service/gpu/autotuning:autotuner_util",
+        "//xla/stream_executor:stream",
+        "//xla/tools:hlo_decomposer_lib",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/functional:any_invocable",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_test(
+    name = "triton_fusion_numerics_verifier_test",
+    srcs = ["triton_fusion_numerics_verifier_test.cc"],
+    backend_tags = {"gpu": [
+        "requires-gpu-sm80",
+    ]},
+    backends = ["gpu"],
+    deps = [
+        ":triton_fusion_numerics_verifier",
+        "//xla:shape_util",
+        "//xla:test_helpers",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:platform_util",
+        "//xla/service/gpu/autotuning:autotuner_compile_util",
+        "//xla/service/gpu/autotuning:autotuner_util",
+        "//xla/stream_executor:platform",
+        "//xla/tests:hlo_test_base",
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "variadic_op_splitter",
+    srcs = ["variadic_op_splitter.cc"],
+    hdrs = ["variadic_op_splitter.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "variadic_op_splitter_test",
+    srcs = ["variadic_op_splitter_test.cc"],
+    tags = [
+        "nomsan",
+    ],
+    deps = [
+        ":variadic_op_splitter",
+        "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_parser",
+        "//xla/service:pattern_matcher",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "windowed_einsum_handler",
+    srcs = ["windowed_einsum_handler.cc"],
+    hdrs = ["windowed_einsum_handler.h"],
+    deps = [
+        "//xla:literal_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:hlo_creation_utils",
+        "//xla/service:hlo_pass",
+        "//xla/service:pattern_matcher",
+        "//xla/service:shape_inference",
+        "//xla/service/gpu:backend_configs_cc",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "windowed_einsum_handler_test",
+    srcs = ["windowed_einsum_handler_test.cc"],
+    deps = [
+        ":windowed_einsum_handler",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/tests:filecheck",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc
new file mode 100644
index 0000000..bba27a0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc
@@ -0,0 +1,66 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/algebraic_simplifier.h"
+
+#include "absl/log/check.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/fusions/triton/triton_support.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla::gpu {
+
+bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce(
+    const HloInstruction* hlo) {
+  if (!options_.enable_dot_strength_reduction()) {
+    return false;
+  }
+
+  const HloDotInstruction* dot = DynCast<HloDotInstruction>(hlo);
+  if (dot == nullptr) {
+    return false;
+  }
+
+  const HloInstruction* lhs = dot->operand(0);
+  const HloInstruction* rhs = dot->operand(1);
+  DotDimensionNumbers dnums = dot->dot_dimension_numbers();
+  bool lhs_is_vector = (dnums.lhs_batch_dimensions_size() +
+                            dnums.lhs_contracting_dimensions_size() ==
+                        lhs->shape().rank());
+  bool rhs_is_vector = (dnums.rhs_batch_dimensions_size() +
+                            dnums.rhs_contracting_dimensions_size() ==
+                        rhs->shape().rank());
+  // Strength-reduce vector-vector dots since they are not supported by
+  // GemmFusion.
+  if (lhs_is_vector && rhs_is_vector) {
+    return true;
+  }
+
+  absl::StatusOr<bool> is_too_small =
+      IsMatrixMultiplicationTooSmallForRewriting(*hlo, /*threshold=*/10000000);
+  CHECK_OK(is_too_small.status());
+  if (is_too_small.value()) {
+    return true;
+  }
+
+  // If GemmFusion cannot handle this dot, we should strength-reduce it so that
+  // it can be handled by the fusion pipeline.
+  return !legacy_triton::CanTritonHandleGEMM(*dot, compute_capability_);
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h
new file mode 100644
index 0000000..f29b31e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h
@@ -0,0 +1,78 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALGEBRAIC_SIMPLIFIER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALGEBRAIC_SIMPLIFIER_H_
+
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/algebraic_simplifier.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+
+namespace xla::gpu {
+
+class GpuAlgebraicSimplifierVisitor : public AlgebraicSimplifierVisitor {
+ public:
+  explicit GpuAlgebraicSimplifierVisitor(
+      const AlgebraicSimplifierOptions& options,
+      se::GpuComputeCapability compute_capability,
+      AlgebraicSimplifier* simplifier)
+      : AlgebraicSimplifierVisitor(options, simplifier),
+        compute_capability_(std::move(compute_capability)) {}
+
+  bool ShouldStrengthReduceDotToReduce(const HloInstruction* hlo) override;
+
+ private:
+  se::GpuComputeCapability compute_capability_;
+};
+
+class GpuAlgebraicSimplifier : public AlgebraicSimplifier {
+ public:
+  explicit GpuAlgebraicSimplifier(const AlgebraicSimplifierOptions& options,
+                                  se::GpuComputeCapability compute_capability)
+      : AlgebraicSimplifier(options),
+        compute_capability_(std::move(compute_capability)) {}
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(HloModule* module,
+                           const absl::flat_hash_set<absl::string_view>&
+                               execution_threads) override {
+    XLA_VLOG_LINES(
+        2, "GpuAlgebraicSimplifier::Run(), before:\n" + module->ToString());
+    bool changed = false;
+    GpuAlgebraicSimplifierVisitor visitor(options_, compute_capability_, this);
+    for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
+      if (visitor.Run(comp, options_, this)) {
+        changed = true;
+      }
+    }
+    XLA_VLOG_LINES(
+        2, "GpuAlgebraicSimplifier::Run(), after:\n" + module->ToString());
+    return changed;
+  }
+
+ private:
+  se::GpuComputeCapability compute_capability_;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_ALGEBRAIC_SIMPLIFIER_H_
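
A minimal usage sketch for GpuAlgebraicSimplifier as declared above; the helper name SimplifyForAmpere, the SM 8.0 target, and the option choices are illustrative assumptions rather than anything prescribed by the pass itself.

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/algebraic_simplifier.h"
#include "xla/service/gpu/transforms/algebraic_simplifier.h"
#include "xla/stream_executor/device_description.h"

namespace se = ::stream_executor;

// Runs GpuAlgebraicSimplifier over all non-fusion computations of `module`
// and reports whether any instruction was rewritten.
absl::StatusOr<bool> SimplifyForAmpere(xla::HloModule* module) {
  xla::AlgebraicSimplifierOptions options;
  options.set_enable_dot_strength_reduction(true);
  // Assumed target: an SM 8.0 (Ampere-class) device.
  se::GpuComputeCapability cc = se::CudaComputeCapability(8, 0);
  xla::gpu::GpuAlgebraicSimplifier pass(options, cc);
  // Single-argument overload inherited from HloPassInterface; runs on all
  // execution threads.
  return pass.Run(module);
}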
diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc
new file mode 100644
index 0000000..c1e52e9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc
@@ -0,0 +1,141 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/algebraic_simplifier.h"
+
+#include <string>
+
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/algebraic_simplifier.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+class GpuAlgebraicSimplifierTest : public HloTestBase {};
+
+TEST_F(GpuAlgebraicSimplifierTest, VectorVectorDotShouldBeStrengthReduced) {
+  const std::string& hlo_string = R"(
+HloModule m
+
+ENTRY entry {
+  p0 = f32[32, 500] parameter(0)
+  p1 = f32[32, 500] parameter(1)
+  ROOT dot = f32[32] dot(p0, p1), lhs_batch_dims={0},
+    lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  const HloInstruction* dot = module->entry_computation()->root_instruction();
+  AlgebraicSimplifierOptions options;
+  options.set_enable_dot_strength_reduction(true);
+  se::CudaComputeCapability ampere(8, 0);
+  GpuAlgebraicSimplifier simplifier(options, ampere);
+  GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
+  EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
+}
+
+TEST_F(GpuAlgebraicSimplifierTest, MatrixVectorDotShouldNotBeStrengthReduced) {
+  const std::string& hlo_string = R"(
+HloModule m
+
+ENTRY entry {
+  p0 = f32[32, 5000, 7000] parameter(0)
+  p1 = f32[32, 5000] parameter(1)
+  ROOT dot = f32[32,7000] dot(p0, p1), lhs_batch_dims={0},
+    lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1},
+    algorithm=dot_bf16_bf16_f32_x6
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  const HloInstruction* dot = module->entry_computation()->root_instruction();
+  AlgebraicSimplifierOptions options;
+  options.set_enable_dot_strength_reduction(true);
+  se::CudaComputeCapability ampere(8, 0);
+  GpuAlgebraicSimplifier simplifier(options, ampere);
+  GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
+  EXPECT_FALSE(visitor.ShouldStrengthReduceDotToReduce(dot));
+}
+
+TEST_F(GpuAlgebraicSimplifierTest,
+       DotWithTypeUnsupportedByGemmFusionShouldBeStrengthReduced) {
+  const std::string& hlo_string = R"(
+HloModule m
+
+ENTRY entry {
+  p0 = c64[32, 5000, 7000] parameter(0)
+  p1 = c64[32, 5000] parameter(1)
+  ROOT dot = c64[32,7000] dot(p0, p1), lhs_batch_dims={0},
+    lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  const HloInstruction* dot = module->entry_computation()->root_instruction();
+  AlgebraicSimplifierOptions options;
+  options.set_enable_dot_strength_reduction(true);
+  se::CudaComputeCapability ampere(8, 0);
+  GpuAlgebraicSimplifier simplifier(options, ampere);
+  GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
+  EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
+}
+
+TEST_F(GpuAlgebraicSimplifierTest, SmallDotShouldBeStrengthReduced) {
+  const std::string& hlo_string = R"(
+HloModule m
+
+ENTRY entry {
+  p0 = f32[32, 50, 70] parameter(0)
+  p1 = f32[32, 50] parameter(1)
+  ROOT dot = f32[32,70] dot(p0, p1), lhs_batch_dims={0},
+    lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1},
+    algorithm=dot_bf16_bf16_f32_x6
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  const HloInstruction* dot = module->entry_computation()->root_instruction();
+  AlgebraicSimplifierOptions options;
+  options.set_enable_dot_strength_reduction(true);
+  se::CudaComputeCapability ampere(8, 0);
+  GpuAlgebraicSimplifier simplifier(options, ampere);
+  GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
+  EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
+}
+
+TEST_F(GpuAlgebraicSimplifierTest, SmallDotShouldBeStrengthReduced2) {
+  const std::string& hlo_string = R"(
+HloModule m
+
+ENTRY entry {
+  p0 = f32[2000, 3000] parameter(0)
+  p1 = f32[2000] parameter(1)
+  ROOT dot = f32[3000] dot(p0, p1), lhs_contracting_dims={0},
+    rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_f32_x6
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  const HloInstruction* dot = module->entry_computation()->root_instruction();
+  AlgebraicSimplifierOptions options;
+  options.set_enable_dot_strength_reduction(true);
+  se::CudaComputeCapability ampere(8, 0);
+  GpuAlgebraicSimplifier simplifier(options, ampere);
+  GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
+  EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
+}
+
+}  // namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/algorithm_checker.cc b/third_party/xla/xla/service/gpu/transforms/algorithm_checker.cc
new file mode 100644
index 0000000..664d7b2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/algorithm_checker.cc
@@ -0,0 +1,117 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/algorithm_checker.h"
+
+#include <utility>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/algorithm_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+bool HasNonDefaultOperandPrecision(const PrecisionConfig& config) {
+  return absl::c_any_of(config.operand_precision(), [](int precision) {
+    return static_cast<PrecisionConfig::Precision>(precision) !=
+           PrecisionConfig::DEFAULT;
+  });
+}
+
+class AlgorithmCheckerVisitor : public ConstDfsHloVisitorWithDefault {
+ public:
+  explicit AlgorithmCheckerVisitor(
+      se::GpuComputeCapability gpu_compute_capability)
+      : gpu_compute_capability_(std::move(gpu_compute_capability)) {}
+
+  absl::Status RunOnModule(
+      const HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads = {}) {
+    for (HloComputation* computation :
+         module->MakeNonfusionComputations(execution_threads)) {
+      TF_RETURN_IF_ERROR(computation->Accept(this));
+    }
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleDot(const HloInstruction* hlo) override {
+    VLOG(1) << "Handling dot: " << hlo->ToString();
+    const PrecisionConfig& config = hlo->precision_config();
+
+    if (config.algorithm() != PrecisionConfig::ALG_UNSET &&
+        HasNonDefaultOperandPrecision(config)) {
+      LOG(WARNING)
+          << "There is no need to set precisions when we set the algorithm: "
+          << hlo->ToString();
+    }
+
+    if (config.algorithm() == PrecisionConfig::ALG_UNSET) {
+      return absl::OkStatus();
+    }
+
+    PrimitiveType lhs_storage_type = hlo->operand(0)->shape().element_type();
+    PrimitiveType rhs_storage_type = hlo->operand(1)->shape().element_type();
+    PrimitiveType output_storage_type = hlo->shape().element_type();
+
+    if (lhs_storage_type != rhs_storage_type) {
+      return absl::UnimplementedError(absl::StrFormat(
+          "Dot operands must have the same type when using an algorithm: %s",
+          hlo->ToString()));
+    }
+
+    return algorithm_util::IsSupportedDotAlgorithmOnGpu(
+               config.algorithm(), gpu_compute_capability_, lhs_storage_type,
+               output_storage_type)
+               ? absl::OkStatus()
+               : absl::UnimplementedError(absl::StrFormat(
+                     "Unsupported algorithm on the current device(s): %s",
+                     PrecisionConfig::Algorithm_Name(config.algorithm())));
+  }
+
+  absl::Status DefaultAction(const HloInstruction* hlo) override {
+    return absl::OkStatus();
+  }
+
+ private:
+  se::GpuComputeCapability gpu_compute_capability_;
+};
+
+}  // namespace
+
+absl::StatusOr<bool> AlgorithmChecker::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  TF_RETURN_IF_ERROR(AlgorithmCheckerVisitor(gpu_compute_capability_)
+                         .RunOnModule(module, execution_threads));
+  // No change was made.
+  return false;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/algorithm_checker.h b/third_party/xla/xla/service/gpu/transforms/algorithm_checker.h
new file mode 100644
index 0000000..c2cf0d2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/algorithm_checker.h
@@ -0,0 +1,54 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALGORITHM_CHECKER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALGORITHM_CHECKER_H_
+
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// This pass checks whether the requested dot algorithms are supported, giving
+// an early and specific error if an unsupported algorithm is requested.
+//
+// Note: Maybe we can make this more generic and move it outside of GPU.
+class AlgorithmChecker : public HloModulePass {
+ public:
+  explicit AlgorithmChecker(se::GpuComputeCapability gpu_compute_capability)
+      : gpu_compute_capability_(std::move(gpu_compute_capability)) {}
+
+  absl::string_view name() const override { return "algorithm-checker"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  se::GpuComputeCapability gpu_compute_capability_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_ALGORITHM_CHECKER_H_
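For context, a minimal sketch of wiring the checker into a pass pipeline follows; it is not part of this change, and the pipeline name, the HloPassPipeline usage, and the CUDA SM 8.0 capability are illustrative assumptions.

// Sketch only; the pipeline name, header choices, and the SM 8.0 capability
// are assumptions, not part of this change.
#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/transforms/algorithm_checker.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/stream_executor/device_description.h"

namespace xla {
namespace gpu {

absl::Status CheckDotAlgorithms(HloModule* module) {
  // se::GpuComputeCapability is a variant; here we assume CUDA SM 8.0.
  se::GpuComputeCapability capability = se::CudaComputeCapability{8, 0};
  HloPassPipeline pipeline("algorithm-checker-sketch");
  pipeline.AddPass<AlgorithmChecker>(capability);
  // AlgorithmChecker never changes the HLO; Run() only surfaces an
  // UnimplementedError for unsupported dot algorithms.
  return pipeline.Run(module).status();
}

}  // namespace gpu
}  // namespace xla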
diff --git a/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.cc b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.cc
new file mode 100644
index 0000000..0d6bff3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.cc
@@ -0,0 +1,69 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/alias_passthrough_params.h"
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> AliasPassthroughParams::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  if (module->entry_computation()->num_parameters() == 0 ||
+      root->opcode() != HloOpcode::kTuple) {
+    return false;
+  }
+  bool changed = false;
+  absl::flat_hash_set<int64_t> used_params;
+  for (int64_t i = 0; i < root->operand_count(); ++i) {
+    if (root->operand(i)->opcode() == HloOpcode::kParameter &&
+        used_params.count(root->operand(i)->parameter_number()) == 0) {
+      VLOG(2) << "Parameter " << root->operand(i)->parameter_number()
+              << " with shape " << root->operand(i)->shape().ToString()
+              << " in module " << module->name()
+              << " is passed-through to root tuple element " << i << ": "
+              << root->shape().ToString();
+
+      if (module->input_output_alias_config().OutputHasAlias({i}) ||
+          module->input_output_alias_config().ParameterHasAlias(
+              root->operand(i)->parameter_number(), /*param_index=*/{})) {
+        VLOG(2) << "Skip setting the above pass-through alias as an alias may"
+                << " have been set up for alising resource update.";
+        continue;
+      }
+
+      TF_RETURN_IF_ERROR(module->input_output_alias_config().SetUpAlias(
+          /*output_index=*/{i},
+          /*param_number=*/root->operand(i)->parameter_number(),
+          /*param_index=*/{}));
+      used_params.insert(root->operand(i)->parameter_number());
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.h b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.h
new file mode 100644
index 0000000..4cd4ab4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.h
@@ -0,0 +1,50 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALIAS_PASSTHROUGH_PARAMS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALIAS_PASSTHROUGH_PARAMS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// This pass aliases input and output buffers that are associated with a
+// parameter that is passed through to the module root unmodified.
+//
+// This pass assumes that parameters and the root use unnested shapes, which is
+// the case for XLA:GPU.
+//
+// This pass must run prior to copy insertion.
+class AliasPassthroughParams : public HloModulePass {
+ public:
+  AliasPassthroughParams() = default;
+  ~AliasPassthroughParams() override = default;
+  absl::string_view name() const override { return "alias_passthrough_params"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_ALIAS_PASSTHROUGH_PARAMS_H_
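A short sketch of running the pass and inspecting the resulting aliasing follows; the helper name is hypothetical, while the single-argument Run overload and the alias-config accessors are the same ones the tests below rely on.

// Illustrative only. Runs AliasPassthroughParams on an existing module and
// logs any alias recorded for root tuple output {0}.
#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/transforms/alias_passthrough_params.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace gpu {

absl::Status AliasPassthroughAndLog(HloModule* module) {
  TF_ASSIGN_OR_RETURN(bool changed, AliasPassthroughParams().Run(module));
  if (changed) {
    const auto& alias_config = module->input_output_alias_config();
    // A parameter passed through unmodified to root tuple element {0} is now
    // recorded as an input/output alias.
    if (auto alias = alias_config.GetAliasedParameter({0})) {
      VLOG(1) << "Root output {0} aliases parameter "
              << alias->parameter_number;
    }
  }
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace xla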
diff --git a/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params_test.cc b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params_test.cc
new file mode 100644
index 0000000..32c1e5b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/alias_passthrough_params.h"
+
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+
+class AliasPassthroughParamsTest : public HloTestBase {};
+
+TEST_F(AliasPassthroughParamsTest, AliasPassThroughParams) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    p0 = f16[2048,1024] parameter(0)
+    p1 = f16[2048,1024] parameter(1)
+    sum = f16[2048,1024] add(p0, p1)
+    ROOT root = (f16[2048,1024], f16[2048,1024], f16[2048,1024]) tuple(p0, sum, p1)
+  })")
+                    .value();
+  EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).value());
+  const auto& alias_config = module->input_output_alias_config();
+  EXPECT_EQ(0, alias_config.GetAliasedParameter({0})->parameter_number);
+  EXPECT_FALSE(alias_config.OutputHasAlias({1}));
+  EXPECT_EQ(1, alias_config.GetAliasedParameter({2})->parameter_number);
+}
+
+TEST_F(AliasPassthroughParamsTest, DoNotAliasPassThroughParamsMoreThanOnce) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    p0 = f16[2048,1024] parameter(0)
+    ROOT root = (f16[2048,1024], f16[2048,1024]) tuple(p0, p0)
+  })")
+                    .value();
+  EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).value());
+  const auto& alias_config = module->input_output_alias_config();
+  EXPECT_EQ(0, alias_config.GetAliasedParameter({0})->parameter_number);
+  EXPECT_FALSE(alias_config.OutputHasAlias({1}));
+}
+
+TEST_F(AliasPassthroughParamsTest, PresetAliases) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    p0 = f16[2048,1024] parameter(0)
+    p1 = f16[2048,1024] parameter(1)
+    sum = f16[2048,1024] add(p0, p1)
+    ROOT root = (f16[2048,1024], f16[2048,1024], f16[2048,1024]) tuple(p0, sum, p1)
+  })")
+                    .value();
+
+  // Preset an alias from p0 to `sum`. This could happen in the case of
+  // `alias_resource_update`.
+  auto& preset_alias = module->input_output_alias_config();
+  TF_EXPECT_OK(preset_alias.SetUpAlias(/*output_index=*/{1},
+                                       /*param_number=*/0,
+                                       /*param_index=*/{}));
+
+  EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).value());
+  const auto& alias_result = module->input_output_alias_config();
+  // Assert that `AliasPassthroughParams` established an alias from p1 to
+  // root tuple output {2}.
+  EXPECT_EQ(1, alias_result.GetAliasedParameter({2})->parameter_number);
+  EXPECT_FALSE(alias_result.OutputHasAlias({0}));
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.cc b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.cc
new file mode 100644
index 0000000..2f7c130
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.cc
@@ -0,0 +1,109 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_gather_optimizer.h"
+
+#include <cstdint>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> AllGatherOptimizer::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    for (HloInstruction* instruction :
+         computation->MakeInstructionPostOrder()) {
+      if (!HloOpcodeIsBinaryCommutative(instruction->opcode())) {
+        continue;
+      }
+
+      HloInstruction* left_op = instruction->mutable_operand(0);
+      HloInstruction* right_op = instruction->mutable_operand(1);
+
+      if (right_op->opcode() != HloOpcode::kAllGather ||
+          left_op->opcode() != HloOpcode::kAllGather) {
+        VLOG(2) << "Binary op's operands are not all-gather deduced types.";
+        continue;
+      }
+
+      auto* left_all_gather = Cast<HloAllGatherInstruction>(left_op);
+      auto* right_all_gather = Cast<HloAllGatherInstruction>(right_op);
+
+      if (right_all_gather->constrain_layout() !=
+              left_all_gather->constrain_layout() ||
+          right_all_gather->use_global_device_ids() !=
+              left_all_gather->use_global_device_ids() ||
+          !ReplicaGroupsEqual(right_all_gather->replica_groups(),
+                              left_all_gather->replica_groups())) {
+        VLOG(2) << "The right and left all-gather ops are not compatible "
+                   "to merge. ";
+        continue;
+      }
+
+      if (!ShapeUtil::Equal(left_all_gather->operand(0)->shape(),
+                            right_all_gather->operand(0)->shape())) {
+        VLOG(2) << "all-gather operands have different shapes";
+        continue;
+      }
+
+      if (right_all_gather->user_count() != 1 ||
+          left_all_gather->user_count() != 1) {
+        VLOG(2) << "all-gather user_count > 1 ";
+        continue;
+      }
+      auto index_in_full_shape =
+          computation->AddInstruction(HloInstruction::CreateBinary(
+              right_all_gather->operand(0)->shape(), instruction->opcode(),
+              left_all_gather->mutable_operand(0),
+              right_all_gather->mutable_operand(0)));
+
+      int64_t all_gather_dimension =
+          Cast<HloAllGatherInstruction>(right_all_gather)
+              ->all_gather_dimension();
+
+      auto combined = HloInstruction::CreateAllGather(
+          left_all_gather->shape(), {index_in_full_shape}, all_gather_dimension,
+          left_all_gather->device_list(),
+          /*constrain_layout=*/false, left_all_gather->channel_id(),
+          Cast<HloAllGatherInstruction>(left_all_gather)
+              ->use_global_device_ids());
+
+      TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+          instruction, std::move(combined)));
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h
new file mode 100644
index 0000000..988c1f6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h
@@ -0,0 +1,46 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_OPTIMIZER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_OPTIMIZER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Transforms binary_op(all-gather(reduce-scatter(a)),
+// all-gather(reduce-scatter(b))) into
+// all-gather(binary_op(reduce-scatter(a), reduce-scatter(b))).
+
+class AllGatherOptimizer : public HloModulePass {
+ public:
+  AllGatherOptimizer() = default;
+  absl::string_view name() const override { return "all-gather-optimizer"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_OPTIMIZER_H_
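The sketch below restates the rewrite as comments and shows one assumed way to invoke the pass; the helper name is illustrative, and removal of the now-dead all-gathers is deferred to a later cleanup pass, as the tests below note.

// Sketch only. The pass rewrites, e.g.,
//   add(all-gather(reduce-scatter(a)), all-gather(reduce-scatter(b)))
// into
//   all-gather(add(reduce-scatter(a), reduce-scatter(b)))
// so the binary op runs on the smaller scattered operands and only one
// all-gather remains; the now-dead all-gathers are left for later DCE.
#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/transforms/all_gather_optimizer.h"

namespace xla {
namespace gpu {

absl::StatusOr<bool> OptimizeAllGathers(HloModule* module) {
  // Returns true if at least one binary op over compatible all-gathers was
  // rewritten.
  return AllGatherOptimizer().Run(module);
}

}  // namespace gpu
}  // namespace xla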
diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer_test.cc b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer_test.cc
new file mode 100644
index 0000000..27f6d65
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer_test.cc
@@ -0,0 +1,232 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_gather_optimizer.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class GpuAllGatherOptimizerTest : public HloTestBase {
+ public:
+  absl::StatusOr<std::unique_ptr<HloModule>> RunPass(
+      absl::string_view hlo_module, int64_t num_replicas,
+      int64_t num_partitions, bool expect_change) {
+    HloModuleConfig config = GetModuleConfigForTest(
+        /*replica_count=*/num_replicas,
+        /*num_partitions=*/num_partitions);
+    config.set_use_spmd_partitioning(num_partitions > 1);
+    TF_ASSIGN_OR_RETURN(auto module,
+                        ParseAndReturnVerifiedModule(hlo_module, config));
+
+    auto changed = AllGatherOptimizer().Run(module.get());
+    if (!changed.ok()) {
+      return changed.status();
+    }
+    EXPECT_EQ(changed.value(), expect_change);
+    return absl::StatusOr<std::unique_ptr<HloModule>>(std::move(module));
+  }
+
+  template <HloOpcode oc>
+  size_t CollectiveCount(std::unique_ptr<HloModule>& module) {
+    return absl::c_count_if(module->entry_computation()->instructions(),
+                            HloPredicateIsOp<oc>);
+  }
+};
+
+TEST_F(GpuAllGatherOptimizerTest, BranchesOptimized) {
+  absl::string_view hlo_string = R"(
+HloModule ReduceScatter
+
+add {
+  x = bf16[] parameter(0)
+  y = bf16[] parameter(1)
+  ROOT add = bf16[] add(x, y)
+}
+
+ENTRY main {
+param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
+param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
+reduce-scatter.1 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.1), channel_id=8, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather.1 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+reduce-scatter.2 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.2), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather.2 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+add.1 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/1,
+                                               /*expect_change=*/true));
+  // The graph should contain 1 all-gather, but since node removal is
+  // deferred, the original all-gathers still exist at this stage.
+  EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 3);
+  EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 2);
+}
+
+TEST_F(GpuAllGatherOptimizerTest, DisabledSPMDPartitioningJAXBug) {
+  absl::string_view hlo_string = R"(
+HloModule pjit_f, entry_computation_layout={(f32[4,8]{1,0}, f32[4,8]{1,0})->f32[8,8]{1,0}}
+
+ENTRY %main.6_spmd (param: f32[4,8], param.1: f32[4,8]) -> f32[8,8] {
+  %param = f32[4,8]{1,0} parameter(0), sharding={devices=[2,1]<=[2]}
+  %all-gather = f32[8,8]{1,0} all-gather(f32[4,8]{1,0} %param), channel_id=1, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(f)/jit(main)/add" source_file="third_party/py/jax/tests/pjit_test.py" source_line=207}
+  %param.1 = f32[4,8]{1,0} parameter(1), sharding={devices=[2,1]<=[2]}
+  %all-gather.1 = f32[8,8]{1,0} all-gather(f32[4,8]{1,0} %param.1), channel_id=2, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(f)/jit(main)/add" source_file="third_party/py/jax/tests/pjit_test.py" source_line=207}
+  ROOT %add.0 = f32[8,8]{1,0} add(f32[8,8]{1,0} %all-gather, f32[8,8]{1,0} %all-gather.1), metadata={op_name="pjit(f)/jit(main)/add" source_file="third_party/py/jax/tests/pjit_test.py" source_line=207}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/1,
+                                               /*num_partitions=*/2,
+                                               /*expect_change=*/true));
+  EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 1);
+}
+
+TEST_F(GpuAllGatherOptimizerTest, MoreThanSingleUserForAllGather) {
+  absl::string_view hlo_string = R"(
+HloModule ReduceScatter
+
+add {
+  x = bf16[] parameter(0)
+  y = bf16[] parameter(1)
+  ROOT add = bf16[] add(x, y)
+}
+
+ENTRY main {
+param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
+param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
+param.3 = bf16[8,128,1024]{2,1,0} parameter(2)
+reduce-scatter.1 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.1), channel_id=8, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather.1 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+reduce-scatter.2 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.2), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather.2 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+reduce-scatter.3 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.3), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather.3 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.3), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+add.1 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.3)
+add.2 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/1,
+                                               /*expect_change=*/false));
+  // See the comment in the BranchesOptimized test.
+  EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 3);
+  EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 3);
+}
+
+TEST_F(GpuAllGatherOptimizerTest, AllGatherWithOpInBetweenOnRightBranch) {
+  absl::string_view hlo_string = R"(
+HloModule ReduceScatter
+
+add {
+  x = bf16[] parameter(0)
+  y = bf16[] parameter(1)
+  ROOT add = bf16[] add(x, y)
+}
+
+ENTRY main {
+param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
+param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
+param.3 = bf16[8,128,1024]{2,1,0} parameter(2)
+reduce-scatter.1 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.1), channel_id=8, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+reduce-scatter.2 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.2), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+add.1 = bf16[8,64,1024]{2,1,0} add(reduce-scatter.1, reduce-scatter.2)
+all-gather.1 = bf16[8,128,1024]{2,1,0} all-gather(add.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+reduce-scatter.3 = bf16[8,64,1024]{2,1,0} reduce-scatter(param.3), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather.3 = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter.3), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+add.2 = bf16[8,128,1024]{2,1,0} add(all-gather.1, all-gather.3)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/1,
+                                               /*expect_change=*/true));
+  EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 3);
+  EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 3);
+}
+
+TEST_F(GpuAllGatherOptimizerTest, AllGatherOneSided) {
+  absl::string_view hlo_string = R"(
+HloModule ReduceScatter
+
+add {
+  x = bf16[] parameter(0)
+  y = bf16[] parameter(1)
+  ROOT add = bf16[] add(x, y)
+}
+
+ENTRY main {
+param.1 = bf16[8,128,1024]{2,1,0} parameter(0)
+param.2 = bf16[8,128,1024]{2,1,0} parameter(1)
+param.3 = bf16[8,128,1024]{2,1,0} parameter(2)
+
+add.1 = bf16[8,128,1024]{2,1,0} add(param.1, param.2)
+reduce-scatter = bf16[8,64,1024]{2,1,0} reduce-scatter(param.3), channel_id=9, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+all-gather = bf16[8,128,1024]{2,1,0} all-gather(reduce-scatter), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+add.2 = bf16[8,128,1024]{2,1,0} add(all-gather, add.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/1,
+                                               /*expect_change=*/false));
+  EXPECT_EQ(CollectiveCount<HloOpcode::kAllGather>(module), 1);
+  EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 1);
+}
+
+TEST_F(GpuAllGatherOptimizerTest, DifferentOperandShapes) {
+  absl::string_view hlo_string = R"(
+HloModule TestModule
+
+ENTRY main {
+param.1 = bf16[8,64,128]{2,1,0} parameter(0)
+param.2 = bf16[8,128,64]{2,1,0} parameter(1)
+all-gather.1 = bf16[8,128,128]{2,1,0} all-gather(param.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+all-gather.2 = bf16[8,128,128]{2,1,0} all-gather(param.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+add.1 = bf16[8,128,128]{2,1,0} add(all-gather.1, all-gather.2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/1,
+                                               /*expect_change=*/false));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.cc
new file mode 100644
index 0000000..0e0e67a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.cc
@@ -0,0 +1,373 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_reduce_blueconnect.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/btree_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/computation_placer.h"
+#include "xla/service/global_device_id.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+std::vector<HloInstruction*> GetOutputs(HloInstruction& instruction) {
+  if (!instruction.shape().IsTuple()) {
+    return {&instruction};
+  }
+
+  std::vector<HloInstruction*> outputs;
+  outputs.reserve(instruction.shape().tuple_shapes_size());
+
+  HloComputation& computation = *instruction.parent();  // never null
+  for (int i = 0; i < instruction.shape().tuple_shapes_size(); ++i) {
+    outputs.push_back(computation.AddInstruction(
+        HloInstruction::CreateGetTupleElement(&instruction, i)));
+  }
+  return outputs;
+}
+
+struct DecomposedReplicaGroups {
+  std::vector<ReplicaGroup> scatter_gather_groups;
+  std::vector<ReplicaGroup> new_all_reduce_groups;
+};
+
+// Returns the global device id for the given replica id. Returns nullopt if
+// the replica id can refer to multiple devices, or if the pass does not
+// support the CollectiveOpGroupMode.
+std::optional<GlobalDeviceId> TryConvertingReplicaIdToDeviceId(
+    int64_t replica_id, const DeviceAssignment& device_assignment,
+    CollectiveOpGroupMode collective_group_mode) {
+  if (collective_group_mode == CollectiveOpGroupMode::kCrossReplica) {
+    if (device_assignment.computation_count() != 1) {
+      // If there are multiple partitions, the replica_id may refer to multiple
+      // devices on different partitions.
+      return std::nullopt;
+    }
+    return GlobalDeviceId{device_assignment(replica_id, /*computation_id=*/0)};
+  } else if (collective_group_mode == CollectiveOpGroupMode::kFlattenedID) {
+    int partition_count = device_assignment.computation_count();
+    int64_t actual_replica_id = replica_id / partition_count;
+    int64_t partition_id = replica_id % partition_count;
+    return GlobalDeviceId{device_assignment(actual_replica_id, partition_id)};
+  }
+
+  // kCrossPartition and kCrossReplicaAndPartition are unsupported.
+  VLOG(1) << "Skip AllReduceBlueConnect because of unsupported "
+             "CollectiveOpGroupMode "
+          << CollectiveOpGroupModeToString(collective_group_mode);
+  return std::nullopt;
+}
+
+absl::StatusOr<std::optional<DecomposedReplicaGroups>> TryDecomposeReplicaGroup(
+    const ReplicaGroup& replica_group,
+    const DeviceAssignment& device_assignment, size_t num_devices_per_host,
+    CollectiveOpGroupMode collective_group_mode) {
+  int group_size = replica_group.replica_ids_size();
+  TF_RET_CHECK(group_size > 0);
+
+  absl::btree_map<int, std::vector<int64_t>> replica_ids_by_host;
+  for (int64_t replica_id : replica_group.replica_ids()) {
+    std::optional<GlobalDeviceId> device_id = TryConvertingReplicaIdToDeviceId(
+        replica_id, device_assignment, collective_group_mode);
+    if (!device_id.has_value()) {
+      return {std::nullopt};
+    }
+    TF_RET_CHECK(*device_id >= 0);
+    // We assume that devices are ordered by host.
+    int host_id = device_id->value() / num_devices_per_host;
+    replica_ids_by_host[host_id].push_back(replica_id);
+  }
+
+  size_t num_local_devices = replica_ids_by_host.begin()->second.size();
+  bool same_num_devices_on_each_host =
+      absl::c_all_of(replica_ids_by_host, [&](const auto& entry) {
+        return entry.second.size() == num_local_devices;
+      });
+
+  if (!same_num_devices_on_each_host) {
+    return {std::nullopt};
+  }
+
+  std::vector<int64_t> sorted_replica_group;
+  sorted_replica_group.reserve(group_size);
+  for (const auto& entry : replica_ids_by_host) {
+    absl::c_copy(entry.second, std::back_inserter(sorted_replica_group));
+  }
+
+  size_t scatter_group_size = std::max(num_local_devices, size_t(2));
+  size_t num_scatter_groups = group_size / scatter_group_size;
+
+  if ((group_size % scatter_group_size != 0) || (num_scatter_groups < 2)) {
+    return {std::nullopt};
+  }
+
+  std::vector<ReplicaGroup> scatter_gather_groups(num_scatter_groups);
+  std::vector<ReplicaGroup> new_all_reduce_groups(scatter_group_size);
+
+  for (size_t i = 0; i < group_size; ++i) {
+    int64_t replica_id = sorted_replica_group[i];
+    scatter_gather_groups[i / scatter_group_size].add_replica_ids(replica_id);
+    new_all_reduce_groups[i % scatter_group_size].add_replica_ids(replica_id);
+  }
+
+  return {DecomposedReplicaGroups{std::move(scatter_gather_groups),
+                                  std::move(new_all_reduce_groups)}};
+}
+
+absl::StatusOr<std::optional<DecomposedReplicaGroups>>
+TryDecomposeReplicaGroups(const HloAllReduceInstruction& all_reduce,
+                          size_t num_devices_per_host) {
+  const DeviceAssignment& device_assignment =
+      all_reduce.GetModule()->config().static_device_assignment();
+
+  absl::Span<const ReplicaGroup> replica_groups = all_reduce.replica_groups();
+
+  ReplicaGroup all_replicas;  // only populated if replica groups not present.
+  if (replica_groups.empty()) {
+    for (int i = 0; i < device_assignment.replica_count(); ++i) {
+      all_replicas.add_replica_ids(i);
+    }
+    replica_groups = absl::MakeSpan(&all_replicas, 1);
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      CollectiveOpGroupMode collective_op_group_mode,
+      GetCollectiveOpGroupMode(all_reduce.channel_id().has_value(),
+                               all_reduce.use_global_device_ids()));
+
+  std::vector<ReplicaGroup> scatter_gather_groups;
+  std::vector<ReplicaGroup> new_all_reduce_groups;
+
+  // Try to find a valid decomposition for each replica group.
+  for (const ReplicaGroup& replica_group : replica_groups) {
+    TF_ASSIGN_OR_RETURN(
+        std::optional<DecomposedReplicaGroups> decomposed_groups,
+        TryDecomposeReplicaGroup(replica_group, device_assignment,
+                                 num_devices_per_host,
+                                 collective_op_group_mode));
+
+    if (!decomposed_groups) return {std::nullopt};
+
+    int scatter_group_size =
+        decomposed_groups->scatter_gather_groups[0].replica_ids_size();
+
+    if (scatter_gather_groups.empty()) {
+      // Check that every operand is exactly divisible by scatter group sizes.
+      for (const HloInstruction* operand : all_reduce.operands()) {
+        TF_RET_CHECK(operand->shape().IsArray());
+        int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
+        if (num_elements % scatter_group_size != 0) {
+          return {std::nullopt};
+        }
+      }
+
+      scatter_gather_groups.reserve(
+          replica_groups.size() *
+          decomposed_groups->scatter_gather_groups.size());
+      new_all_reduce_groups.reserve(
+          replica_groups.size() *
+          decomposed_groups->new_all_reduce_groups.size());
+    } else if (scatter_group_size !=
+               scatter_gather_groups[0].replica_ids_size()) {
+      // Reduce-scatter would have different output shapes on different devices.
+      return {std::nullopt};
+    }
+
+    absl::c_move(decomposed_groups->scatter_gather_groups,
+                 std::back_inserter(scatter_gather_groups));
+    absl::c_move(decomposed_groups->new_all_reduce_groups,
+                 std::back_inserter(new_all_reduce_groups));
+  }
+
+  return {DecomposedReplicaGroups{std::move(scatter_gather_groups),
+                                  std::move(new_all_reduce_groups)}};
+}
+
+// Attempts to decompose all-reduces as described by the BlueConnect paper.
+//
+// If possible, the all-reduce will be transformed into:
+// 1. reduce-scatter
+// 2. all-reduce
+// 3. all-gather
+//
+// If the all-reduce replica groups have more than one device within the same
+// host, the reduce-scatter will be performed over all devices within each
+// host. Otherwise, the reduce-scatter will be performed between pairs of
+// devices on different hosts.
+//
+// When applied repeatedly, this transformation will reproduce the same pattern
+// as described in the BlueConnect paper.
+absl::StatusOr<bool> TryDecomposeAllReduce(HloAllReduceInstruction* all_reduce,
+                                           size_t num_devices_per_host) {
+  TF_RET_CHECK(all_reduce);
+  TF_RET_CHECK(!all_reduce->has_sharding());
+
+  HloComputation& computation = *all_reduce->parent();  // never null
+  PrimitiveType element_type = all_reduce->operand(0)->shape().element_type();
+
+  TF_ASSIGN_OR_RETURN(
+      std::optional<DecomposedReplicaGroups> decomposed_groups,
+      TryDecomposeReplicaGroups(*all_reduce, num_devices_per_host));
+
+  if (!decomposed_groups) return false;
+
+  // Bitcast operands to 1D to guarantee that first dimension is divisible by
+  // scatter group size (we checked num elements was divisible above).
+  std::vector<HloInstruction*> flat_operands;
+  flat_operands.reserve(all_reduce->operand_count());
+  std::vector<Shape> flat_shapes;
+  flat_shapes.reserve(all_reduce->operand_count());
+  std::vector<Shape> scattered_shapes;
+  scattered_shapes.reserve(all_reduce->operand_count());
+
+  int scatter_group_size =
+      decomposed_groups->scatter_gather_groups[0].replica_ids_size();
+
+  for (HloInstruction* operand : all_reduce->operands()) {
+    TF_RET_CHECK(operand->shape().IsArray());
+    int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
+    Shape flat_shape = ShapeUtil::MakeShape(element_type, {num_elements});
+    flat_operands.push_back(computation.AddInstruction(
+        HloInstruction::CreateBitcast(flat_shape, operand)));
+    flat_shapes.push_back(std::move(flat_shape));
+    scattered_shapes.push_back(ShapeUtil::MakeShape(
+        element_type, {num_elements / scatter_group_size}));
+  }
+
+  Shape reduce_scatter_shape = ShapeUtil::MakeMaybeTupleShape(scattered_shapes);
+
+  int64_t next_channel_id = hlo_query::NextChannelId(*computation.parent());
+  auto get_channel_id = [&]() -> std::optional<int64_t> {
+    if (all_reduce->channel_id().has_value()) {
+      return next_channel_id++;
+    }
+    return std::nullopt;
+  };
+
+  HloInstruction* reduce_scatter =
+      computation.AddInstruction(HloInstruction::CreateReduceScatter(
+          reduce_scatter_shape, flat_operands, all_reduce->to_apply(),
+          CollectiveDeviceList(decomposed_groups->scatter_gather_groups),
+          /*constrain_layout=*/false, get_channel_id(),
+          all_reduce->use_global_device_ids(),
+          /*scatter_dimension=*/0));
+
+  HloInstruction* new_all_reduce =
+      computation.AddInstruction(HloInstruction::CreateAllReduce(
+          reduce_scatter_shape, GetOutputs(*reduce_scatter),
+          all_reduce->to_apply(),
+          CollectiveDeviceList(decomposed_groups->new_all_reduce_groups),
+          /*constrain_layout=*/false, all_reduce->channel_id(),
+          all_reduce->use_global_device_ids()));
+
+  HloInstruction* all_gather =
+      computation.AddInstruction(HloInstruction::CreateAllGather(
+          ShapeUtil::MakeMaybeTupleShape(flat_shapes),
+          GetOutputs(*new_all_reduce),
+          /*all_gather_dimension=*/0,
+          CollectiveDeviceList(decomposed_groups->scatter_gather_groups),
+          /*constrain_layout=*/false, get_channel_id(),
+          all_reduce->use_global_device_ids()));
+
+  // Bitcast back to the original shapes and replace all-reduce with decomposed
+  // implementation.
+  std::vector<HloInstruction*> outputs = GetOutputs(*all_gather);
+  for (int64_t i = 0; i < outputs.size(); ++i) {
+    outputs[i] = computation.AddInstruction(HloInstruction::CreateBitcast(
+        all_reduce->operand(i)->shape(), outputs[i]));
+  }
+  HloInstruction* replacement = MaybeMakeTuple(outputs);
+
+  TF_RETURN_IF_ERROR(
+      all_reduce->CopyAllControlDepsTo(reduce_scatter, replacement));
+
+  TF_RETURN_IF_ERROR(all_reduce->DropAllControlDeps());
+  TF_RETURN_IF_ERROR(computation.ReplaceInstruction(all_reduce, replacement));
+
+  // Try to apply decomposition recursively.
+  TF_RETURN_IF_ERROR(
+      TryDecomposeAllReduce(Cast<HloAllReduceInstruction>(new_all_reduce),
+                            num_devices_per_host)
+          .status());
+  return true;
+}
+
+}  // namespace
+
+absl::StatusOr<bool> AllReduceBlueConnect::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  VLOG(1) << "Running AllReduceBlueConnect";
+
+  if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
+    VLOG(1)
+        << "Skip AllReduceBlueConnect because the module contains all-reduce "
+           "with constrained layouts";
+    return false;
+  }
+  if (!module->config().has_static_device_assignment()) {
+    VLOG(1)
+        << "Skip AllReduceBlueConnect because the module doesn't have static "
+           "device assignment";
+    return false;
+  }
+
+  std::vector<HloAllReduceInstruction*> all_reduces;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    for (HloInstruction* instruction : computation->instructions()) {
+      if (instruction->opcode() == HloOpcode::kAllReduce) {
+        all_reduces.push_back(Cast<HloAllReduceInstruction>(instruction));
+      }
+    }
+  }
+
+  bool changed = false;
+  for (HloAllReduceInstruction* all_reduce : all_reduces) {
+    TF_ASSIGN_OR_RETURN(
+        bool all_reduce_changed,
+        TryDecomposeAllReduce(all_reduce, num_devices_per_host_));
+    changed |= all_reduce_changed;
+  }
+
+  return changed;
+}
+
+}  // namespace xla
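The replica-group decomposition above can be reproduced with a small standalone C++ program; the concrete numbers (one group of 8 replicas, 4 devices per host) are taken from the OneStage test below and are otherwise an assumption.

// Element i of the host-sorted group goes to scatter/gather group
// i / scatter_group_size and to new all-reduce group i % scatter_group_size.
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> sorted_group = {0, 1, 2, 3, 4, 5, 6, 7};
  const int scatter_group_size = 4;  // num_local_devices (devices per host)
  std::vector<std::vector<int>> scatter_gather_groups(
      sorted_group.size() / scatter_group_size);
  std::vector<std::vector<int>> new_all_reduce_groups(scatter_group_size);
  for (int i = 0; i < static_cast<int>(sorted_group.size()); ++i) {
    scatter_gather_groups[i / scatter_group_size].push_back(sorted_group[i]);
    new_all_reduce_groups[i % scatter_group_size].push_back(sorted_group[i]);
  }
  auto print_groups = [](const char* name,
                         const std::vector<std::vector<int>>& groups) {
    std::printf("%s:", name);
    for (const auto& group : groups) {
      std::printf(" {");
      for (int id : group) std::printf(" %d", id);
      std::printf(" }");
    }
    std::printf("\n");
  };
  // Prints {0 1 2 3} {4 5 6 7} and {0 4} {1 5} {2 6} {3 7}, matching the
  // groups asserted in the OneStage test.
  print_groups("scatter/gather groups", scatter_gather_groups);
  print_groups("new all-reduce groups", new_all_reduce_groups);
  return 0;
}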
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.h b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.h
new file mode 100644
index 0000000..6da0bbe
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.h
@@ -0,0 +1,56 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_BLUECONNECT_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_BLUECONNECT_H_
+
+#include <cstddef>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// Decomposes all-reduce operations using the BlueConnect algorithm.
+//
+// Paper: "BLUECONNECT: DECOMPOSING ALL-REDUCE FOR DEEP LEARNING ON
+// HETEROGENEOUS NETWORK HIERARCHY"
+// https://mlsys.org/Conferences/2019/doc/2019/130.pdf
+//
+// This algorithm attempts to minimize the number of levels of network hierarchy
+// traversed for as much data transfer as possible. This implementation assumes
+// that host IDs are ordered corresponding to network hierarchy.
+class AllReduceBlueConnect : public HloModulePass {
+ public:
+  explicit AllReduceBlueConnect(size_t num_devices_per_host)
+      : num_devices_per_host_(num_devices_per_host) {}
+
+  absl::string_view name() const override { return "all-reduce-blueconnect"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  size_t num_devices_per_host_;
+};
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_BLUECONNECT_H_
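A minimal pipeline sketch for this pass follows; the HloPassPipeline usage and the value 8 for num_devices_per_host are illustrative assumptions rather than part of this change.

// Sketch only; num_devices_per_host must reflect the actual topology, and the
// module needs a static device assignment or the pass is a no-op.
#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/transforms/all_reduce_blueconnect.h"
#include "xla/service/hlo_pass_pipeline.h"

namespace xla {

absl::StatusOr<bool> DecomposeAllReducesBlueConnect(HloModule* module) {
  HloPassPipeline pipeline("blueconnect-sketch");
  pipeline.AddPass<AllReduceBlueConnect>(/*num_devices_per_host=*/8);
  // Returns true if any all-reduce was decomposed into
  // reduce-scatter + all-reduce + all-gather.
  return pipeline.Run(module);
}

}  // namespace xla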
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect_test.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect_test.cc
new file mode 100644
index 0000000..cafbf24
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect_test.cc
@@ -0,0 +1,414 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_reduce_blueconnect.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/computation_placer.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/util.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+using ::tsl::testing::IsOkAndHolds;
+namespace m = ::xla::match;
+
+using AllReduceBlueConnectTest = HloTestBase;
+
+HloPredicate MatchChannelId(std::optional<int64_t> channel_id) {
+  return [channel_id](const HloInstruction* instruction) {
+    return instruction->channel_id() == channel_id;
+  };
+}
+
+void SetModuleConfig(HloModuleConfig* module_config, size_t replica_count,
+                     size_t partition_count = 1) {
+  DeviceAssignment device_assignment(replica_count,
+                                     /*computation_count=*/partition_count);
+  device_assignment.FillIota(0);
+  module_config->set_replica_count(replica_count);
+  module_config->set_num_partitions(partition_count);
+  module_config->set_static_device_assignment(device_assignment);
+}
+
+void SetModuleConfig(HloModule& module, size_t replica_count,
+                     size_t partition_count = 1) {
+  SetModuleConfig(&module.mutable_config(), replica_count, partition_count);
+}
+
+TEST_F(AllReduceBlueConnectTest, OneStage) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+  p0 = f32[4,4] parameter(0)
+  ROOT crs = f32[4,4] all-reduce(p0), to_apply=add
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  SetModuleConfig(*module, /*replica_count=*/8);
+
+  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
+
+  // clang-format off
+  std::vector<std::vector<int64_t>> scatter_gather_groups = {
+      {0, 1, 2, 3}, {4, 5, 6, 7}};
+  std::vector<std::vector<int64_t>> new_all_reduce_groups = {
+      {0, 4}, {1, 5}, {2, 6}, {3, 7}};
+  // clang-format on
+
+  auto bitcast = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
+  auto reduce_scatter = m::ReduceScatter(bitcast)
+                            .WithShape(F32, {4})
+                            .WithReplicaGroups(scatter_gather_groups)
+                            .WithPredicate(MatchChannelId(std::nullopt));
+  auto all_reduce = m::AllReduce(reduce_scatter)
+                        .WithShape(F32, {4})
+                        .WithReplicaGroups(new_all_reduce_groups)
+                        .WithPredicate(MatchChannelId(std::nullopt));
+  auto all_gather = m::AllGather(all_reduce)
+                        .WithShape(F32, {16})
+                        .WithReplicaGroups(scatter_gather_groups)
+                        .WithPredicate(MatchChannelId(std::nullopt));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Bitcast(all_gather).WithShape(F32, {4, 4})));
+}
+
+TEST_F(AllReduceBlueConnectTest, TwoStage) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+  p0 = f32[4,4] parameter(0)
+  ROOT crs = f32[4,4] all-reduce(p0), to_apply=add
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  SetModuleConfig(*module, /*replica_count=*/16);
+
+  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
+
+  std::vector<std::vector<int64_t>> outer_scatter_gather_groups = {
+      {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}};
+  std::vector<std::vector<int64_t>> inner_scatter_gather_groups = {
+      {0, 4}, {8, 12}, {1, 5}, {9, 13}, {2, 6}, {10, 14}, {3, 7}, {11, 15}};
+  std::vector<std::vector<int64_t>> new_all_reduce_groups = {
+      {0, 8}, {4, 12}, {1, 9}, {5, 13}, {2, 10}, {6, 14}, {3, 11}, {7, 15}};
+
+  auto bitcast0 = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
+  auto reduce_scatter0 =
+      m::ReduceScatter(bitcast0).WithShape(F32, {4}).WithReplicaGroups(
+          outer_scatter_gather_groups);
+  auto bitcast1 = m::Bitcast(reduce_scatter0).WithShape(F32, {4});
+  auto reduce_scatter1 =
+      m::ReduceScatter(bitcast1).WithShape(F32, {2}).WithReplicaGroups(
+          inner_scatter_gather_groups);
+  auto all_reduce = m::AllReduce(reduce_scatter1)
+                        .WithShape(F32, {2})
+                        .WithReplicaGroups(new_all_reduce_groups);
+  auto all_gather0 = m::AllGather(all_reduce)
+                         .WithShape(F32, {4})
+                         .WithReplicaGroups(inner_scatter_gather_groups);
+  auto bitcast2 = m::Bitcast(all_gather0).WithShape(F32, {4});
+  auto all_gather1 =
+      m::AllGather(bitcast2).WithShape(F32, {16}).WithReplicaGroups(
+          outer_scatter_gather_groups);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Bitcast(all_gather1).WithShape(F32, {4, 4})));
+}
+
+TEST_F(AllReduceBlueConnectTest, TwoOperands) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+  p0 = f32[4,4] parameter(0)
+  p1 = f32[4,4,2] parameter(1)
+  ROOT crs = (f32[4,4], f32[4,4,2]) all-reduce(p0, p1), to_apply=add
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  SetModuleConfig(*module, /*replica_count=*/8);
+
+  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
+
+  // clang-format off
+  std::vector<std::vector<int64_t>> scatter_gather_groups = {
+      {0, 1, 2, 3}, {4, 5, 6, 7}};
+  std::vector<std::vector<int64_t>> new_all_reduce_groups = {
+      {0, 4}, {1, 5}, {2, 6}, {3, 7}};
+  // clang-format on
+
+  auto bitcast0 = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
+  auto bitcast1 = m::Bitcast(m::Parameter(1)).WithShape(F32, {32});
+
+  Shape expected0 = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(F32, {4}), ShapeUtil::MakeShape(F32, {8})});
+  Shape expected1 = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(F32, {16}), ShapeUtil::MakeShape(F32, {32})});
+  auto reduce_scatter = m::ReduceScatter(bitcast0, bitcast1)
+                            .WithShapeEqualTo(&expected0)
+                            .WithReplicaGroups(scatter_gather_groups);
+  auto all_reduce = m::AllReduce(m::GetTupleElement(reduce_scatter, 0),
+                                 m::GetTupleElement(reduce_scatter, 1))
+                        .WithShapeEqualTo(&expected0)
+                        .WithReplicaGroups(new_all_reduce_groups);
+  auto all_gather = m::AllGather(m::GetTupleElement(all_reduce, 0),
+                                 m::GetTupleElement(all_reduce, 1))
+                        .WithShapeEqualTo(&expected1)
+                        .WithReplicaGroups(scatter_gather_groups);
+  auto bitcast2 =
+      m::Bitcast(m::GetTupleElement(all_gather, 0)).WithShape(F32, {4, 4});
+  auto bitcast3 =
+      m::Bitcast(m::GetTupleElement(all_gather, 1)).WithShape(F32, {4, 4, 2});
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(bitcast2, bitcast3)));
+}
+
+TEST_F(AllReduceBlueConnectTest, MultiplePartitionsFilecheck) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+  p0 = f32[8,8] parameter(0)
+  ROOT crs = f32[8,8] all-reduce(p0), channel_id=1,
+    replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=add
+})";
+  HloModuleConfig module_config;
+  SetModuleConfig(&module_config, /*replica_count=*/1, /*partition_count=*/8);
+
+  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+  // Note: When matching strings like "replica_groups={{0,1,2,3}}", FileCheck
+  // interprets the string inside the double braces as regex. So to match such
+  // strings, we use "replica_groups={{..0,1,2,3..}}", where the dots match the
+  // opening and closing braces.
+  RunAndFilecheckHloRewrite(hlo_string, std::move(pass), R"(
+  CHECK:       %p0 = f32[8,8]{1,0} parameter(0)
+  CHECK-NEXT:  [[bitcast:%[^ ]+]] = f32[64]{0} bitcast(%p0)
+  CHECK-NEXT:  [[reduce_scatter:%[^ ]+]] = f32[16]{0} reduce-scatter([[bitcast]]), channel_id=2, replica_groups={{..0,1,2,3.,.4,5,6,7..}}, use_global_device_ids=true, dimensions={0}, to_apply=%add
+  CHECK-NEXT:  [[all_reduce:%[^ ]+]] = f32[16]{0} all-reduce([[reduce_scatter]]), channel_id=1, replica_groups={{..0,4.,.1,5.,.2,6.,.3,7..}}, use_global_device_ids=true, to_apply=%add
+  CHECK-NEXT:  [[all_gather:%[^ ]+]] = f32[64]{0} all-gather([[all_reduce]]), channel_id=3, replica_groups={{..0,1,2,3.,.4,5,6,7..}}, dimensions={0}, use_global_device_ids=true
+  CHECK-NEXT:  ROOT [[output:%[^ ]+]] = f32[8,8]{1,0} bitcast([[all_gather]])
+}
+)",
+                            /*after_pass_checks=*/nullptr, &module_config);
+}
+
+TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesWithinReplicaGroup) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+  p0 = f32[4,4] parameter(0)
+  ROOT crs = f32[4,4] all-reduce(p0),
+    replica_groups={{0,1,2,7},{3,4,5,6}}, to_apply=add
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  SetModuleConfig(*module, /*replica_count=*/8);
+
+  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesAcrossReplicaGroups) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+  p0 = f32[4,4] parameter(0)
+  ROOT crs = f32[4,4] all-reduce(p0),
+    replica_groups={{0,1,4,5},{2,3,6,7},{8,9,10,11},{12,13,14,15}}, to_apply=add
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  SetModuleConfig(*module, /*replica_count=*/16);
+
+  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(AllReduceBlueConnectTest, OperandIndivisible) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+  p0 = f32[4,4] parameter(0)
+  p1 = f32[9] parameter(1)
+  ROOT crs = (f32[4,4], f32[9]) all-reduce(p0, p1), to_apply=add
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  SetModuleConfig(*module, /*replica_count=*/8);
+
+  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(AllReduceBlueConnectTest, ControlDeps) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+  p0 = f32[4,4] parameter(0)
+  p1 = f32[4,4] parameter(1)
+  add = f32[4,4] add(p0, p1)
+  crs = f32[4,4] all-reduce(p0), to_apply=add, control-predecessors={add}
+  ROOT add1 = f32[4,4] add(crs, add), control-predecessors={crs}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  SetModuleConfig(*module, /*replica_count=*/8);
+
+  // Remember the all-reduce's control predecessors and successors.
+  const HloInstruction* ar =
+      module->entry_computation()->root_instruction()->operand(0);
+  auto expected_preds = ar->control_predecessors();
+  auto expected_succs = ar->control_successors();
+
+  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true));
+
+  // clang-format off
+  std::vector<std::vector<int64_t>> scatter_gather_groups = {
+      {0, 1, 2, 3}, {4, 5, 6, 7}};
+  std::vector<std::vector<int64_t>> new_all_reduce_groups = {
+      {0, 4}, {1, 5}, {2, 6}, {3, 7}};
+  // clang-format on
+
+  const HloInstruction *matched_rs, *matched_bitcast;
+  auto bitcast = m::Bitcast(m::Parameter(0)).WithShape(F32, {16});
+  auto reduce_scatter = m::ReduceScatter(&matched_rs, bitcast)
+                            .WithShape(F32, {4})
+                            .WithReplicaGroups(scatter_gather_groups);
+  auto all_reduce = m::AllReduce(reduce_scatter)
+                        .WithShape(F32, {4})
+                        .WithReplicaGroups(new_all_reduce_groups);
+  auto all_gather = m::AllGather(all_reduce)
+                        .WithShape(F32, {16})
+                        .WithReplicaGroups(scatter_gather_groups);
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, GmockMatch(m::Add()));
+
+  EXPECT_THAT(
+      root->operand(0),
+      GmockMatch(
+          m::Bitcast(&matched_bitcast, all_gather).WithShape(F32, {4, 4})));
+
+  // Verify that control dependencies are transferred correctly.
+  EXPECT_THAT(matched_rs, GmockMatch(m::Op().WithControlDeps(
+                              absl::MakeSpan(expected_preds), {})));
+  EXPECT_THAT(matched_bitcast, GmockMatch(m::Op().WithControlDeps(
+                                   {}, absl::MakeSpan(expected_succs))));
+}
+
+TEST_F(AllReduceBlueConnectTest, ReduceScatterUnchanged) {
+  // Tests that this pass does not affect reduce-scatter. In principle, the
+  // BlueConnect algorithm could be applied to reduce-scatter, but this pass
+  // does not currently do so.
+  constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+%add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY %comp {
+  p0 = f32[8,4] parameter(0)
+  ROOT crs = f32[1,4] reduce-scatter(p0), dimensions={0}, to_apply=add
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  SetModuleConfig(*module, /*replica_count=*/8);
+
+  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
+  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc
new file mode 100644
index 0000000..51f71c0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc
@@ -0,0 +1,436 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_reduce_splitter.h"
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <variant>
+#include <vector>
+
+#include "absl/cleanup/cleanup.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "xla/hlo/ir/collective_device_list.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/collective_opt_utils.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+// Structure containing the newly calculated replica groups.
+struct ARReplicaGroups {
+  // First AR's replica group.
+  std::vector<ReplicaGroup> first_ar_replica_groups;
+  // Second AR's replica group.
+  std::vector<ReplicaGroup> second_ar_replica_groups;
+};
+
+// Contains relevant data to rewrite the AR + DS into AR + DS + AR.
+struct AllReduceRewriteSpec {
+  // Determines a dimension on which DS occurs.
+  int split_dim;
+  // Determines the size of the process group.
+  int group_size;
+  // AllReduce instruction to be rewritten.
+  HloAllReduceInstruction* all_reduce;
+  // DynamicSlice following the `all_reduce` indicating logical RS.
+  HloDynamicSliceInstruction* dynamic_slice;
+  // New replica groups for an `all_reduce`.
+  ARReplicaGroups replica_groups;
+
+  std::string ToString() {
+    return absl::Substitute(
+        "{\n split_dim=$0\n group_size=$1\n all_reduce=$2\n "
+        "dynamic_slice=$3\n}\n",
+        split_dim, group_size, all_reduce->ToString(),
+        dynamic_slice->ToString());
+  }
+};
+
+// Contains the relevant metadata for debugging why a rewrite is infeasible.
+struct RewriteInfeasibleReason {
+  // Instruction for which the rewrite is infeasible.
+  const HloInstruction* ar;
+  // Describes the reason for the infeasibility.
+  std::string message;
+};
+
+// Hashable container to hold replica groups.
+struct ReplicaGroups {
+  std::vector<ReplicaGroup> replica_groups;
+
+  template <typename H>
+  friend H AbslHashValue(H h, const ReplicaGroups& rg) {
+    return H::combine(std::move(h), rg.replica_groups.size());
+  }
+
+  friend bool operator==(const ReplicaGroups& item,
+                         const ReplicaGroups& other) {
+    if (item.replica_groups.size() != other.replica_groups.size()) {
+      return false;
+    }
+    for (int i = 0; i < item.replica_groups.size(); i++) {
+      const ReplicaGroup& item_replica_group = item.replica_groups[i];
+      const ReplicaGroup& other_replica_group = other.replica_groups[i];
+      if (item_replica_group.replica_ids_size() !=
+          other_replica_group.replica_ids_size()) {
+        return false;
+      }
+      for (int j = 0; j < item_replica_group.replica_ids_size(); j++) {
+        if (item_replica_group.replica_ids(j) !=
+            other_replica_group.replica_ids(j)) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+};
+
+using ARReplicaGroupMap =
+    absl::flat_hash_map<ReplicaGroups,
+                        std::vector<const HloAllReduceInstruction*>>;
+
+using RewriteDecision =
+    std::variant<AllReduceRewriteSpec, RewriteInfeasibleReason>;
+
+// Returns the single dimension along which `ds` slices the result of `ar`.
+// Returns std::nullopt if there is no such dimension or if more than one
+// dimension differs.
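+// For example (illustrative only): if `ar` produces f32[32,16] and `ds`
+// slices it down to f32[8,16], only dimension 0 differs, so 0 is returned.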
+std::optional<int> GetSplitDim(const HloAllReduceInstruction& ar,
+                               const HloDynamicSliceInstruction& ds) {
+  int split_dim = -1;
+  int num_dims = 0;
+  for (int64_t dim = 0; dim < ar.shape().rank(); ++dim) {
+    if (ar.shape().dimensions(dim) != ds.shape().dimensions(dim)) {
+      num_dims++;
+      split_dim = dim;
+    }
+  }
+  if (num_dims != 1) {
+    VLOG(2) << "No support for multiple nor 0 split dims.";
+    return std::nullopt;
+  }
+  return split_dim;
+}
+
+// For the collective instruction `ar`, returns the process group size
+// (i.e. the number of shards).
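+// For example (illustrative only): 32 elements along the split dimension
+// divided into 8-element slices give a group size of 32 / 8 = 4.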
+std::optional<int> GetProcessGroupSize(const HloAllReduceInstruction& ar,
+                                       const HloDynamicSliceInstruction& ds) {
+  CHECK(ds.operand(0) == &ar) << "Irrelevant AR + DS pair.";
+  std::optional<int> split_dim = GetSplitDim(ar, ds);
+  if (!split_dim.has_value()) {
+    return std::nullopt;
+  }
+
+  return ar.shape().dimensions(*split_dim) /
+         ds.dynamic_slice_sizes()[*split_dim];
+}
+
+ARReplicaGroupMap GetReplicaGroupsMap(HloComputation& computation) {
+  ARReplicaGroupMap map;
+  hlo_query::ForEachInstructionWithOpcode(
+      computation, HloOpcode::kAllReduce,
+      [&map](const HloInstruction* instruction) {
+        const HloAllReduceInstruction* ar =
+            Cast<HloAllReduceInstruction>(instruction);
+        auto rgs = ReplicaGroups{ar->replica_groups()};
+        map[rgs].push_back(ar);
+      });
+  return map;
+}
+
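+// Builds the replica groups for the two all-reduces produced by the split.
+// For example, group_size=2 and num_partitions=4 yield
+// first_ar_replica_groups={{0,1},{2,3}} and
+// second_ar_replica_groups={{0,2},{1,3}}.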
+ARReplicaGroups GetNewReplicaGroups(int group_size, int num_partitions) {
+  CHECK_EQ(num_partitions % group_size, 0);
+
+  std::vector<ReplicaGroup> first_ar_rgs, second_ar_rgs;
+  int num_units = num_partitions / group_size;
+  first_ar_rgs.reserve(num_units);
+  second_ar_rgs.reserve(group_size);
+
+  // Construct first AR replica groups.
+  for (int u = 0; u < group_size * num_units; u += group_size) {
+    ReplicaGroup& group = first_ar_rgs.emplace_back();
+    for (int r = u; r < u + group_size; r++) {
+      group.add_replica_ids(r);
+    }
+  }
+
+  // Construct second AR replica groups.
+  for (int g = 0; g < group_size; g++) {
+    ReplicaGroup& group = second_ar_rgs.emplace_back();
+    for (int r = g; r < group_size * num_units; r += group_size) {
+      group.add_replica_ids(r);
+    }
+  }
+  return {
+      /*first_ar_replica_groups=*/first_ar_rgs,
+      /*second_ar_replica_groups=*/second_ar_rgs,
+  };
+}
+
+// Returns true if `spec` can be transformed into a logical reduce scatter.
+// False otherwise.
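+// The check works by temporarily materializing the first all-reduce of the
+// proposed split, rewiring the dynamic-slice onto it, and asking
+// MatchReduceScatter whether the result forms a reduce-scatter; the
+// absl::Cleanup inside undoes the temporary rewiring afterwards.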
+bool IsLogicalReduceScatter(const HloModule& module,
+                            const AllReduceRewriteSpec& spec,
+                            HloComputation& computation) {
+  HloAllReduceInstruction& ar = *spec.all_reduce;
+  CHECK_EQ(ar.user_count(), 1);
+  CHECK_EQ(module.config().replica_count(), 1);
+
+  HloInstruction* first_ar =
+      computation.AddInstruction(HloInstruction::CreateAllReduce(
+          ar.shape(), ar.operands(), ar.to_apply(),
+          CollectiveDeviceList(spec.replica_groups.first_ar_replica_groups),
+          ar.constrain_layout(), hlo_query::NextChannelId(module),
+          ar.use_global_device_ids()));
+
+  HloInstruction* ds = ar.users()[0];
+  auto* old_operand = ds->mutable_operand(0);
+  if (!ds->ReplaceOperandWith(0, first_ar).ok()) {
+    return false;
+  }
+  absl::Cleanup _ = [&] {
+    CHECK_OK(ds->ReplaceOperandWith(0, old_operand));
+    CHECK_OK(computation.RemoveInstruction(first_ar));
+  };
+  return MatchReduceScatter(Cast<HloAllReduceInstruction>(first_ar),
+                            module.config().num_partitions(),
+                            module.config().replica_count(),
+                            /*allow_multiple_split_dims=*/false,
+                            /*allow_intervening_reshape=*/true)
+      .has_value();
+}
+
+// Determines whether it is profitable to split the AllReduce in `spec`.
+// The current heuristic is simple: it checks whether at least one existing
+// all-reduce already uses the same replica groups as either of the
+// all-reduces that would be produced by the split.
+bool IsProfitableToSplit(const ARReplicaGroupMap& replica_map,
+                         const AllReduceRewriteSpec& spec) {
+  auto new_rgs = spec.replica_groups;
+  bool first_replica_exists =
+      replica_map.contains(ReplicaGroups{new_rgs.first_ar_replica_groups});
+  bool second_replica_exists =
+      replica_map.contains(ReplicaGroups{new_rgs.second_ar_replica_groups});
+  return first_replica_exists || second_replica_exists;
+}
+
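+// Checks the preconditions for splitting `instruction` and, if they all hold,
+// returns a fully populated AllReduceRewriteSpec; otherwise returns a
+// RewriteInfeasibleReason describing why the rewrite cannot be performed.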
+RewriteDecision CanRewrite(const HloModule& module,
+                           const ARReplicaGroupMap& replica_map,
+                           HloComputation& computation,
+                           HloInstruction& instruction) {
+  // We rely on SPMD partitioning being enabled, hence we require
+  // `replica_count` == 1.
+  const HloModuleConfig& config = module.config();
+  if (config.use_auto_spmd_partitioning() || !config.use_spmd_partitioning() ||
+      config.replica_count() != 1) {
+    return RewriteInfeasibleReason{
+        &instruction,
+        "Supporting only SPMD partitioning scheme.",
+    };
+  }
+
+  if (instruction.opcode() != HloOpcode::kAllReduce) {
+    return RewriteInfeasibleReason{
+        &instruction,
+        "Cannot rewrite an AllReduce, since it's not AllReduce.",
+    };
+  }
+
+  auto* ar = Cast<HloAllReduceInstruction>(&instruction);
+
+  if (!ar->use_global_device_ids()) {
+    return RewriteInfeasibleReason{
+        &instruction,
+        "Only global ids are supported currently.",
+    };
+  }
+
+  if (ar->user_count() != 1 ||
+      ar->users().front()->opcode() != HloOpcode::kDynamicSlice) {
+    return RewriteInfeasibleReason{
+        &instruction,
+        "Cannot rewrite AllReduce if it is not a logical reduce scatter.",
+    };
+  }
+
+  auto* ds = Cast<HloDynamicSliceInstruction>(ar->users().front());
+
+  if (ds->user_count() > 1) {
+    return RewriteInfeasibleReason{
+        &instruction,
+        "Exactly one user of dynamic slice is required for a rewrite.",
+    };
+  }
+
+  int num_partitions = config.num_partitions();
+
+  std::vector<ReplicaGroup> rgs = ar->replica_groups();
+  if (rgs.size() != 1 || rgs.front().replica_ids_size() != num_partitions) {
+    return RewriteInfeasibleReason{
+        &instruction,
+        absl::StrCat("Cannot determine a valid split with num_partitions: ",
+                     num_partitions),
+    };
+  }
+
+  std::optional<int> split_dim = GetSplitDim(*ar, *ds);
+  if (!split_dim.has_value()) {
+    return RewriteInfeasibleReason{
+        &instruction,
+        "Cannot get a split dim.",
+    };
+  }
+
+  std::optional<int> group_size = GetProcessGroupSize(*ar, *ds);
+  if (!group_size.has_value()) {
+    return RewriteInfeasibleReason{
+        &instruction,
+        "Cannot determine a group size.",
+    };
+  }
+
+  if (num_partitions == group_size) {
+    return RewriteInfeasibleReason{
+        &instruction,
+        "Nothing to rewrite",
+    };
+  }
+
+  if (num_partitions % *group_size != 0) {
+    return RewriteInfeasibleReason{
+        &instruction,
+        "Group size does not evenly divide the number of partitions",
+    };
+  }
+
+  auto spec = AllReduceRewriteSpec{
+      /*split_dim=*/*split_dim,
+      /*group_size=*/*group_size,
+      /*all_reduce=*/ar,
+      /*dynamic_slice=*/ds,
+      /*replica_groups=*/GetNewReplicaGroups(*group_size, num_partitions),
+  };
+
+  if (!IsLogicalReduceScatter(module, spec, computation)) {
+    return RewriteInfeasibleReason{
+        &instruction,
+        "Not a logical reduce scatter.",
+    };
+  }
+
+  if (!IsProfitableToSplit(replica_map, spec)) {
+    return RewriteInfeasibleReason{
+        &instruction,
+        "Splitting is not profitable.",
+    };
+  }
+
+  return spec;
+}
+
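+// Materializes the rewrite described by `spec`: the original all-reduce is
+// replaced with an all-reduce over `first_ar_replica_groups`, and a second
+// all-reduce over `second_ar_replica_groups` is added after the dynamic-slice,
+// keeping the result numerically equivalent to the original global all-reduce.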
+absl::StatusOr<bool> SplitAllReduce(const HloModuleConfig& config,
+                                    AllReduceRewriteSpec spec,
+                                    HloComputation& computation) {
+  int64_t next_channel_id =
+      hlo_query::NextChannelId(*spec.all_reduce->GetModule());
+  VLOG(1) << "AR splitting spec: " << spec.ToString();
+  // Create first AR.
+  int num_partitions = config.num_partitions();
+  // # of shards within a replica
+  int group_size = spec.group_size;
+
+  CHECK_EQ(num_partitions % group_size, 0);
+
+  HloAllReduceInstruction& ar = *spec.all_reduce;
+  HloDynamicSliceInstruction& ds = *spec.dynamic_slice;
+
+  const auto& [first_ar_replica_groups, second_ar_replica_groups] =
+      spec.replica_groups;
+  int channel_id = next_channel_id++;
+  HloInstruction* first_ar =
+      computation.AddInstruction(HloInstruction::CreateAllReduce(
+          ar.shape(), ar.operands(), ar.to_apply(),
+          CollectiveDeviceList(first_ar_replica_groups), ar.constrain_layout(),
+          channel_id, ar.use_global_device_ids()));
+
+  // Create second AR.
+  channel_id = next_channel_id++;
+  HloInstruction* second_ar =
+      computation.AddInstruction(HloInstruction::CreateAllReduce(
+          ds.shape(), {&ds}, ar.to_apply(),
+          CollectiveDeviceList(second_ar_replica_groups), ar.constrain_layout(),
+          channel_id, ar.use_global_device_ids()));
+
+  // Rewire.
+  TF_RETURN_IF_ERROR(computation.ReplaceInstruction(&ar, first_ar));
+  if (ds.IsRoot()) {
+    computation.set_root_instruction(second_ar);
+  }
+  TF_RETURN_IF_ERROR(ds.ReplaceAllUsesWith(second_ar));
+  return true;  // changed
+}
+
+// Splits `instruction` if doing so is both feasible and profitable.
+// Returns true if `instruction` has been rewritten, false otherwise.
+absl::StatusOr<bool> SplitAllReduce(const HloModule& module,
+                                    const ARReplicaGroupMap& replica_map,
+                                    HloComputation& computation,
+                                    HloInstruction& instruction) {
+  RewriteDecision spec =
+      CanRewrite(module, replica_map, computation, instruction);
+  if (std::holds_alternative<RewriteInfeasibleReason>(spec)) {
+    auto reason = std::get<RewriteInfeasibleReason>(spec);
+    VLOG(1) << "Cannot process {" << reason.ar->ToString()
+            << "} due to : " << reason.message;
+    return false;  // changed
+  }
+  return SplitAllReduce(module.config(), std::get<AllReduceRewriteSpec>(spec),
+                        computation);  // changed
+}
+
+}  // namespace
+
+absl::StatusOr<bool> AllReduceSplitter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+
+  for (auto* computation : module->computations(execution_threads)) {
+    ARReplicaGroupMap replica_map = GetReplicaGroupsMap(*computation);
+    for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
+      TF_ASSIGN_OR_RETURN(bool rewritten, SplitAllReduce(*module, replica_map,
+                                                         *computation, *instr));
+      changed |= rewritten;
+    }
+  }
+
+  return changed;
+}
+
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.h b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.h
new file mode 100644
index 0000000..91e0811
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.h
@@ -0,0 +1,77 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_SPLITTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_SPLITTER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// Rewrites a global AR, when it appears in the form AR + DS and matches the
+// replica groups of existing collectives, into a logical RS followed by an AR.
+//
+// If the pass detects an AR followed by a DS, it checks whether it is
+// profitable to break it down into a logical RS (still expressed as AR + DS),
+// followed by an AR that keeps the rewrite numerically equivalent.
+//
+// Consider the following example:
+//
+// Input program:
+//   HloModule m, num_partitions=8
+//     p = partition_id()
+//     ar = bf16[32] all-reduce(x), replica_groups={{0,1,2,3,4,5,6,7}}
+//     ds = dynamic-slice(ar, pointer(partition_id)), dynamic_slice_sizes={8}
+//
+// There is a global AR performing a reduction over 8 partitions. However, the
+// DS takes an 8-element slice of a 32-element tensor, which implies only 4
+// distinct slices, and hence that each slice is computed on 2 partitions. This
+// can be expressed as an RS within groups of partitions followed by an AR
+// across the groups. The transformation limits the collectives to the data
+// that is actually needed for the requested slice.
+//
+// Output program:
+//   HloModule m, num_partitions=8
+//     p = partition_id()
+//     ar = bf16[32] all-reduce(x), replica_groups={{0,1,2,3},{4,5,6,7}}
+//     ds = dynamic-slice(ar, pointer(partition_id)), dynamic_slice_sizes={8}
+//     ar.2 = bf16[8] all-reduce(ds), replica_groups={{0,4},{1,5},{2,6},{3,7}}
+//
+// In addition, the pass performs the rewrite only if it finds it profitable to
+// do so. The profitability heuristic is simple: it checks whether any existing
+// collective uses the same replica groups. If one does, the combiner pass can
+// pick it up and fuse it into the same NCCL call.
+//
+// While the transformation is orthogonal to existing known distribution
+// patterns, in practice it is profitable for HSDP-style communication patterns.
+// https://arxiv.org/pdf/2203.11014
+//
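+// A minimal usage sketch (illustrative only; `module` is assumed to be an
+// HloModule*), mirroring how the tests pair this pass with
+// gpu::ReduceScatterCreator so that the logical RS becomes a real
+// reduce-scatter:
+//
+//   HloPassPipeline pipeline("all-reduce-splitter-rewrite");
+//   pipeline.AddPass<AllReduceSplitter>();
+//   pipeline.AddPass<gpu::ReduceScatterCreator>();
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));
+//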
+class AllReduceSplitter : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "all-reduce-splitter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_SPLITTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc
new file mode 100644
index 0000000..ec2e66d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc
@@ -0,0 +1,506 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/all_reduce_splitter.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/algorithm/container.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/transforms/reduce_scatter_creator.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::tsl::testing::IsOkAndHolds;
+
+class AllReduceSplitterTest : public HloTestBase {
+ public:
+  absl::StatusOr<std::unique_ptr<HloModule>> PrepareModule(
+      absl::string_view hlo_module, int64_t num_replicas,
+      int64_t num_partitions) {
+    HloModuleConfig config = GetModuleConfigForTest(
+        /*replica_count=*/num_replicas,
+        /*num_partitions=*/num_partitions);
+    config.set_use_spmd_partitioning(num_partitions > 1);
+    return ParseAndReturnVerifiedModule(hlo_module, config);
+  }
+
+  size_t AllReduceCount(const HloModule &module) {
+    return CollectiveCount(module, HloOpcode::kAllReduce);
+  }
+
+ private:
+  size_t CollectiveCount(const HloModule &module, HloOpcode opcode) {
+    return absl::c_count_if(
+        module.entry_computation()->instructions(),
+        [&opcode](HloInstruction *instr) { return instr->opcode() == opcode; });
+  }
+};
+
+class AllReduceSplitterFilecheckTest : public AllReduceSplitterTest {
+ public:
+  absl::Status FileCheck(const std::string &hlo_text,
+                         absl::string_view pattern) {
+    TF_ASSIGN_OR_RETURN(bool matched, RunFileCheck(hlo_text, pattern));
+    if (!matched) {
+      return absl::InternalError("Filecheck failed.");
+    }
+    return absl::OkStatus();
+  }
+};
+
+TEST_F(
+    AllReduceSplitterFilecheckTest,
+    MatchBasicPatternIfDynamicSliceIsRootAndThereExistsAllReduceWithSameReplicaGroups) {  // NOLINT
+  absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+  a = bf16[] parameter(0)
+  b = bf16[] parameter(1)
+  ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+  p = bf16[2,4096,4096] parameter(0)
+  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+  zero = bf16[] constant(0)
+  reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
+  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
+  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+  pid = u32[] partition-id()
+  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+  reshape = s32[] reshape(id)
+  slice_size = s32[] constant(1024)
+  offset = s32[] multiply(reshape, slice_size)
+  ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(true));
+  TF_EXPECT_OK(FileCheck(module->ToString(), R"(
+    CHECK-DAG:    %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
+    CHECK:        %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
+    CHECK-SAME:   replica_groups={[[DESIRED_RGS:.*]]}
+    CHECK-DAG:    %[[ZERO:.*]] = bf16[] constant(0)
+    CHECK-DAG:    %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[AR0]], bf16[] %[[ZERO]])
+    CHECK:        %[[AR1:.*]] = bf16[4096]{0} all-reduce(bf16[4096]{0} %[[LOCAL_REDUCE]])
+    CHECK-SAME:   replica_groups={[[DESIRED_RGS]]}
+    CHECK:        %[[DS:.*]] = bf16[1024]{0} dynamic-slice(bf16[4096]{0} %[[AR1]], s32[] %[[_:.*]])
+    CHECK-SAME:   dynamic_slice_sizes={1024}
+    CHECK-NEXT:   ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[DS]])
+    CHECK-SAME:   replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
+    )"));
+}
+
+TEST_F(
+    AllReduceSplitterTest,
+    DoesNotMatchBasicPatternIfDynamicSliceIsRootAndThereIsNoAllReduceWithSameReplicaGroups) {  // NOLINT
+  absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+  a = bf16[] parameter(0)
+  b = bf16[] parameter(1)
+  ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+  p = bf16[2,4096,4096] parameter(0)
+  zero = bf16[] constant(0)
+  reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
+  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
+  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+  pid = u32[] partition-id()
+  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+  reshape = s32[] reshape(id)
+  slice_size = s32[] constant(1024)
+  offset = s32[] multiply(reshape, slice_size)
+  ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
+
+  EXPECT_EQ(AllReduceCount(*module), 1);
+}
+
+TEST_F(
+    AllReduceSplitterFilecheckTest,
+    MatchBasicPatternIfDynamicSliceIsNotRootAndThereExistsAllReduceWithSameReplicaGroups) {  // NOLINT
+  absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+  a = bf16[] parameter(0)
+  b = bf16[] parameter(1)
+  ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+  p = bf16[2,4096,4096] parameter(0)
+  zero = bf16[] constant(0)
+  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+  reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
+  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+  pid = u32[] partition-id()
+  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+  reshape = s32[] reshape(id)
+  slice_size = s32[] constant(1024)
+  offset = s32[] multiply(reshape, slice_size)
+  dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+  broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
+  ROOT _ = tuple(broadcast, first.ar)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(true));
+  TF_EXPECT_OK(FileCheck(module->ToString(), R"(
+    CHECK-DAG:    %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
+    CHECK-DAG:    %[[ZERO:.*]] = bf16[] constant(0)
+    CHECK-DAG:    %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[P0]], bf16[] %[[ZERO]])
+    CHECK:        %[[AR0:.*]] = bf16[4096]{0} all-reduce(bf16[4096]{0} %[[LOCAL_REDUCE]])
+    CHECK-SAME:   replica_groups={[[DESIRED_RGS:.*]]}
+    CHECK:        %[[DS:.*]] = bf16[1024]{0} dynamic-slice(bf16[4096]{0} %[[AR0]], s32[] %[[_:.*]])
+    CHECK-SAME:   dynamic_slice_sizes={1024}
+    CHECK-NEXT:   %[[AR1:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[DS]])
+    CHECK-SAME:   replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
+    CHECK:        %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
+    CHECK-SAME:   replica_groups={[[DESIRED_RGS]]}
+    CHECK:        ROOT
+    CHECK-NOT:    %[[AR1]]
+    CHECK-SAME:   %[[EXISTING_AR]]
+    )"));
+}
+
+TEST_F(
+    AllReduceSplitterTest,
+    DoesNotMatchBasicPatternIfDynamicSliceIsNotRootAndThereIsNoAllReduceWithSameReplicaGroups) {  // NOLINT
+  absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+  a = bf16[] parameter(0)
+  b = bf16[] parameter(1)
+  ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+  p = bf16[2,4096,4096] parameter(0)
+  p.1 = bf16[2,4096,4096] parameter(1)
+  zero = bf16[] constant(0)
+  reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
+  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+  pid = u32[] partition-id()
+  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+  reshape = s32[] reshape(id)
+  slice_size = s32[] constant(1024)
+  offset = s32[] multiply(reshape, slice_size)
+  dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+  broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
+  add = bf16[2,4096,4096] add(p,p.1)
+  ROOT _ = tuple(broadcast, add)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
+  EXPECT_EQ(AllReduceCount(*module), 1);
+}
+
+TEST_F(AllReduceSplitterTest,
+       DoesNotMatchBasicPatternIfDynamicSliceIsFullySharded) {  // NOLINT
+  absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+  a = bf16[] parameter(0)
+  b = bf16[] parameter(1)
+  ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+  p = bf16[2,4096,4096] parameter(0)
+  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+  zero = bf16[] constant(0)
+  reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
+  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
+  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+  pid = u32[] partition-id()
+  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+  reshape = s32[] reshape(id)
+  slice_size = s32[] constant(512)
+  offset = s32[] multiply(reshape, slice_size)
+  ROOT _ = bf16[512] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={512}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
+  EXPECT_EQ(AllReduceCount(*module), 2);
+}
+
+TEST_F(AllReduceSplitterTest,
+       DoesNotMatchBasicPatternIfItIsNotCompiledWithSPMDPartitioning) {
+  absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+  a = bf16[] parameter(0)
+  b = bf16[] parameter(1)
+  ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+  p = bf16[2,4096,4096] parameter(0)
+  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+  zero = bf16[] constant(0)
+  reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
+  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
+  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+  pid = u32[] partition-id()
+  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+  reshape = s32[] reshape(id)
+  slice_size = s32[] constant(1024)
+  offset = s32[] multiply(reshape, slice_size)
+  ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+}
+)";
+  HloModuleConfig config =
+      GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/8);
+  config.set_use_spmd_partitioning(false);
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string, config));
+
+  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
+  EXPECT_EQ(AllReduceCount(*module), 2);
+}
+
+TEST_F(AllReduceSplitterTest,
+       DoesNotMatchBasicPatternIfUseGlobalDeviceIdsIsFalse) {  // NOLINT
+  absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+  a = bf16[] parameter(0)
+  b = bf16[] parameter(1)
+  ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+  p = bf16[2,4096,4096] parameter(0)
+  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, channel_id=1
+  zero = bf16[] constant(0)
+  reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
+  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, channel_id=2
+  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+  pid = u32[] partition-id()
+  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+  reshape = s32[] reshape(id)
+  slice_size = s32[] constant(1024)
+  offset = s32[] multiply(reshape, slice_size)
+  ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
+
+  EXPECT_EQ(AllReduceCount(*module), 2);
+}
+
+TEST_F(AllReduceSplitterTest,
+       DoesNotMatchBasicPatternIfIsNotCrossAllPartitionsAllReduce) {  // NOLINT
+  absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+  a = bf16[] parameter(0)
+  b = bf16[] parameter(1)
+  ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+  p = bf16[2,4096,4096] parameter(0)
+  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+  zero = bf16[] constant(0)
+  reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
+  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
+  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+  pid = u32[] partition-id()
+  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+  reshape = s32[] reshape(id)
+  slice_size = s32[] constant(1024)
+  offset = s32[] multiply(reshape, slice_size)
+  ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+  EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(false));
+
+  EXPECT_EQ(AllReduceCount(*module), 2);
+}
+
+TEST_F(
+    AllReduceSplitterFilecheckTest,
+    PipelineMatchesBasicPatternWithDynamicSliceAsRootAndRewritesToReduceScatter) {  // NOLINT
+  absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+  a = bf16[] parameter(0)
+  b = bf16[] parameter(1)
+  ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+  p = bf16[2,4096,4096] parameter(0)
+  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+  zero = bf16[] constant(0)
+  reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
+  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
+  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+  pid = u32[] partition-id()
+  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+  reshape = s32[] reshape(id)
+  slice_size = s32[] constant(1024)
+  offset = s32[] multiply(reshape, slice_size)
+  ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+  HloPassPipeline pipeline("all-reduce-splitter-rewrite");
+  pipeline.AddPass<AllReduceSplitter>();
+  pipeline.AddPass<ReduceScatterCreator>();
+  EXPECT_THAT(pipeline.Run(module.get()), IsOkAndHolds(true));
+  TF_EXPECT_OK(FileCheck(module->ToString(), R"(
+    CHECK-DAG:    %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
+    CHECK:        %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
+    CHECK-SAME:   replica_groups={[[DESIRED_RGS:.*]]}
+    CHECK-DAG:    %[[ZERO:.*]] = bf16[] constant(0)
+    CHECK-DAG:    %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[AR0]], bf16[] %[[ZERO]])
+    CHECK:        %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(bf16[4096]{0} %[[LOCAL_REDUCE]])
+    CHECK-SAME:   replica_groups={[[DESIRED_RGS]]}
+    CHECK-NEXT:   ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[REDUCE_SCATTER]])
+    CHECK-SAME:   replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
+    )"));
+}
+
+TEST_F(
+    AllReduceSplitterFilecheckTest,
+    PipelineMatchesBasicPatternWithDynamicSliceNotAsRootAndRewritesToReduceScatter) {  // NOLINT
+  absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+  a = bf16[] parameter(0)
+  b = bf16[] parameter(1)
+  ROOT _ = bf16[] add(a,b)
+}
+
+ENTRY main {
+  p = bf16[2,4096,4096] parameter(0)
+  zero = bf16[] constant(0)
+  first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+  reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
+  all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
+  table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
+  pid = u32[] partition-id()
+  id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
+  reshape = s32[] reshape(id)
+  slice_size = s32[] constant(1024)
+  offset = s32[] multiply(reshape, slice_size)
+  dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
+  broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
+  ROOT _ = tuple(broadcast, first.ar)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));
+
+  HloPassPipeline pipeline("all-reduce-splitter-rewrite");
+  pipeline.AddPass<AllReduceSplitter>();
+  pipeline.AddPass<ReduceScatterCreator>();
+  EXPECT_THAT(pipeline.Run(module.get()), IsOkAndHolds(true));
+  TF_EXPECT_OK(FileCheck(module->ToString(), R"(
+    CHECK-DAG:    %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
+    CHECK-DAG:    %[[ZERO:.*]] = bf16[] constant(0)
+    CHECK-DAG:    %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[P0]], bf16[] %[[ZERO]])
+    CHECK:        %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(bf16[4096]{0} %[[LOCAL_REDUCE]])
+    CHECK-NEXT:   %[[AR1:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[REDUCE_SCATTER]])
+    CHECK-SAME:   replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
+    CHECK:        %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
+    CHECK:        ROOT
+    CHECK-NOT:    %[[AR1]]
+    CHECK-SAME:   %[[EXISTING_AR]]
+    )"));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.cc b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.cc
new file mode 100644
index 0000000..aa76aff
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.cc
@@ -0,0 +1,55 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/async_collective_annotator.h"
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> AsyncCollectiveAnnotator::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    for (HloInstruction* instruction : computation->instructions()) {
+      if (!hlo_query::IsAsyncCollectiveStartOp(instruction)) {
+        continue;
+      }
+      CollectiveBackendConfig config;
+      config.set_is_sync(!is_collective_async_(instruction));
+      TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                          instruction->backend_config<GpuBackendConfig>());
+      *gpu_config.mutable_collective_backend_config() = config;
+      TF_RETURN_IF_ERROR(instruction->set_backend_config(gpu_config));
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.h b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.h
new file mode 100644
index 0000000..1b41d50
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.h
@@ -0,0 +1,52 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ASYNC_COLLECTIVE_ANNOTATOR_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ASYNC_COLLECTIVE_ANNOTATOR_H_
+
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/util.h"
+
+namespace xla {
+namespace gpu {
+
+// Annotate async collectives with CollectiveBackendConfig.
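+//
+// A minimal usage sketch (illustrative only; `module` is assumed to be an
+// HloModule*): mark every all-reduce-start as asynchronous and leave the
+// remaining collectives synchronous.
+//
+//   AsyncCollectiveAnnotator annotator(
+//       HloPredicateIsOp<HloOpcode::kAllReduceStart>);
+//   TF_ASSIGN_OR_RETURN(bool changed, annotator.Run(module));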
+class AsyncCollectiveAnnotator : public HloModulePass {
+ public:
+  explicit AsyncCollectiveAnnotator(HloPredicate is_collective_async)
+      : is_collective_async_(std::move(is_collective_async)) {}
+  absl::string_view name() const override {
+    return "async-collective-annotator";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  HloPredicate is_collective_async_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_ASYNC_COLLECTIVE_ANNOTATOR_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/async_collective_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator_test.cc
new file mode 100644
index 0000000..6622a7b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/async_collective_annotator.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/test_macros.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+constexpr absl::string_view kHloString = R"(
+  HloModule ModuleWithAsync
+
+  addf32 {
+    p0 = f32[] parameter(0)
+    p1 = f32[] parameter(1)
+    ROOT add = f32[] add(p0, p1)
+  }
+
+  addf16 {
+    p0 = f16[] parameter(0)
+    p1 = f16[] parameter(1)
+    ROOT add = f16[] add(p0, p1)
+  }
+
+  reduce_scatterf32 {
+    p0 = f32[2] parameter(0)
+    ROOT result = f32[1] reduce-scatter(p0), replica_groups={},
+                      dimensions={0}, to_apply=addf32
+  }
+
+  ENTRY entry {
+    pf32 = f32[1] parameter(0)
+    pf16 = f16[1] parameter(1)
+
+    arf32-start = f32[1] all-reduce-start(pf32), to_apply=addf32
+    arf32-done = f32[1] all-reduce-done(arf32-start)
+
+    arf16-start = f16[1] all-reduce-start(pf16), to_apply=addf16
+    arf16-done = f16[1] all-reduce-done(arf16-start)
+
+    agf32-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}
+    agf32-done = f32[2] all-gather-done(agf32-start)
+
+    agf16-start = (f16[1], f16[2]) all-gather-start(pf16), dimensions={0}
+    agf16-done = f16[2] all-gather-done(agf16-start)
+
+    cpf32-start = (f32[1], f32[1], u32[], u32[]) collective-permute-start(pf32),
+                    source_target_pairs={{0,1}, {1,0}}
+    cpf32-done = f32[1] collective-permute-done(cpf32-start)
+
+    cpf16-start = (f16[1], f16[1], u32[], u32[]) collective-permute-start(pf16),
+                    source_target_pairs={{0,1}, {1,0}}
+    cpf16-done = f16[1] collective-permute-done(cpf16-start)
+
+    rsf32-start = ((f32[2]), f32[1]) async-start(agf32-done), calls=reduce_scatterf32
+    rsf32-done = f32[1] async-done(rsf32-start), calls=reduce_scatterf32
+
+    ROOT tuple = (f32[1], f16[1], f32[2], f16[2], f32[1], f16[1], f32[1])
+                tuple(arf32-done, arf16-done, agf32-done, agf16-done, cpf32-done,
+                      cpf16-done, rsf32-done)
+  }
+)";
+
+struct TestCase {
+  std::string test_name;
+  HloPredicate is_async_predicate;
+  absl::flat_hash_set<absl::string_view> expected_async;
+  absl::flat_hash_set<absl::string_view> expected_sync;
+};
+
+class AsyncCollectiveAnnotatorTest
+    : public HloTestBase,
+      public ::testing::WithParamInterface<TestCase> {};
+
+XLA_TEST_P(AsyncCollectiveAnnotatorTest, Test) {
+  const TestCase& test_case = GetParam();
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      ParseAndReturnVerifiedModule(kHloString, /*replica_count=*/2));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      AsyncCollectiveAnnotator(test_case.is_async_predicate).Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // Assert that all async collectives are annotated with the backend config.
+  for (const HloInstruction* hlo :
+       module->entry_computation()->instructions()) {
+    if (!hlo_query::IsAsyncCollectiveStartOp(hlo)) {
+      continue;
+    }
+    auto gpu_config = hlo->backend_config<GpuBackendConfig>();
+    ASSERT_TRUE(gpu_config.ok());
+
+    const CollectiveBackendConfig& backend_config =
+        gpu_config.value().collective_backend_config();
+    if (test_case.expected_async.contains(hlo->name())) {
+      EXPECT_FALSE(backend_config.is_sync());
+    }
+
+    if (test_case.expected_sync.contains(hlo->name())) {
+      EXPECT_TRUE(backend_config.is_sync());
+    }
+  }
+}
+
+std::vector<TestCase> TestCases() {
+  HloPredicate is_f16 = [](const HloInstruction* hlo) {
+    return hlo->operand(0)->shape().element_type() == PrimitiveType::F16;
+  };
+
+  return {
+      {"all_async",
+       HloPredicateTrue, /*expected_async=*/
+       {"arf32-start", "arf16-start", "agf32-start", "agf16-start",
+        "cpf32-start", "cpf16-start", "rsf32-start"},
+       /*expected_sync=*/{}},
+      {"all_sync",
+       HloPredicateFalse,
+       /*expected_async=*/{},
+       /*expected_sync=*/
+       {"arf32-start", "arf16-start", "agf32-start", "agf16-start",
+        "cpf32-start", "cpf16-start", "rsf32-start"}},
+      {"ar_async",
+       HloPredicateIsOp<HloOpcode::kAllReduceStart>,
+       /*expected_async=*/
+       {"arf32-start", "arf16-start"},
+       /*expected_sync=*/
+       {"agf32-start", "agf16-start", "cpf32-start", "cpf16-start",
+        "rsf32-start"}},
+      {"cp_async",
+       HloPredicateIsOp<HloOpcode::kCollectivePermuteStart>,
+       /*expected_async=*/
+       {"cpf32-start", "cpf16-start"},
+       /*expected_sync=*/
+       {"arf32-start", "arf16-start", "agf32-start", "agf16-start",
+        "rsf32-start"}},
+      {"f16_async",
+       is_f16,
+       /*expected_async=*/{"arf16-start", "agf16-start", "cpf16-start"},
+       /*expected_sync=*/
+       {"arf32-start", "agf32-start", "cpf32-start", "rsf32-start"}},
+  };
+}
+
+std::string TestCaseName(const ::testing::TestParamInfo<TestCase>& test_case) {
+  return test_case.param.test_name;
+}
+
+INSTANTIATE_TEST_SUITE_P(AsyncCollectiveAnnotatorTest,
+                         AsyncCollectiveAnnotatorTest,
+                         ::testing::ValuesIn(TestCases()), TestCaseName);
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc
new file mode 100644
index 0000000..baeecba
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc
@@ -0,0 +1,70 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/async_wrapper.h"
+
+#include <algorithm>
+#include <deque>
+#include <iterator>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/errors.h"
+
+namespace xla::gpu {
+
+absl::StatusOr<bool> AsyncWrapper::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+
+  std::deque<HloComputation*> computations;
+  computations.push_back(module->entry_computation());
+  while (!computations.empty()) {
+    HloComputation* computation = computations.front();
+    computations.pop_front();
+
+    for (HloInstruction* instruction :
+         computation->MakeInstructionPostOrder()) {
+      if (predicate_(instruction)) {
+        // If the predicate matches, wrap the instruction in an async block.
+        TF_RETURN_IF_ERROR(
+            computation
+                ->CreateAsyncInstructions(instruction,
+                                          {ShapeUtil::MakeScalarShape(U32)})
+                .status());
+        changed = true;
+        continue;
+      }
+
+      // Otherwise, follow any `calls` to discover other instructions that can
+      // potentially be made async.
+      if (instruction->opcode() == HloOpcode::kCall) {
+        std::copy(instruction->called_computations().begin(),
+                  instruction->called_computations().end(),
+                  std::back_inserter(computations));
+      }
+    }
+  }
+  return changed;
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper.h b/third_party/xla/xla/service/gpu/transforms/async_wrapper.h
new file mode 100644
index 0000000..d6cefe8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper.h
@@ -0,0 +1,51 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_ASYNC_WRAPPER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_ASYNC_WRAPPER_H_
+
+#include <functional>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// AsyncWrapper wraps instructions that match a given `predicate` into async
+// blocks (i.e. `async-start` and `async-done` instructions) so that they run
+// concurrently.
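+//
+// A minimal usage sketch (mirroring the unit test; the parsed `module` is
+// assumed to be owned by the caller and is not part of this pass):
+//
+//   AsyncWrapper wrapper([](const HloInstruction* instruction) {
+//     return instruction->opcode() == HloOpcode::kFusion;
+//   });
+//   TF_ASSIGN_OR_RETURN(bool changed,
+//                       wrapper.HloModulePass::Run(module.get()));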
+class AsyncWrapper : public HloModulePass {
+ public:
+  using Predicate = std::function<bool(HloInstruction*)>;
+  explicit AsyncWrapper(Predicate predicate)
+      : predicate_(std::move(predicate)) {}
+
+  absl::string_view name() const override { return "async-wrapper"; }
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  const Predicate predicate_;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_ASYNC_WRAPPER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc
new file mode 100644
index 0000000..9d69899
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/async_wrapper.h"
+
+#include <memory>
+#include <utility>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/literal.h"
+#include "xla/literal_util.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/literal_test_util.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "tsl/platform/status_matchers.h"
+
+namespace xla::gpu {
+namespace {
+
+using ::tsl::testing::IsOkAndHolds;
+
+class AsyncWrapperTest : public HloTestBase {};
+
+int CountAsyncInstructions(HloComputation* computation) {
+  int count = 0;
+  for (const HloInstruction* instruction : computation->instructions()) {
+    if (instruction->IsAsynchronous()) ++count;
+  }
+  return count;
+}
+
+TEST_F(AsyncWrapperTest, BasicFusion) {
+  const char* hlo_text = R"(
+  HloModule m
+
+  double1 {
+    p0 = f32[1] parameter(0)
+    ROOT add = f32[1] add(p0, p0)
+  }
+
+  double2 {
+    p0 = f32[1] parameter(0)
+    ROOT add = f32[1] add(p0, p0)
+  }
+
+  ENTRY main {
+    p0 = f32[1] parameter(0)
+    agg1 = f32[1] fusion(p0), kind=kLoop, calls=double1
+    agg2 = f32[1] fusion(p0), kind=kLoop, calls=double2
+    ROOT done = f32[1] add(agg1, agg2)
+  })";
+
+  std::unique_ptr<VerifiedHloModule> module =
+      ParseAndReturnVerifiedModule(hlo_text).value();
+
+  AsyncWrapper wrapper([](const HloInstruction* instruction) {
+    return instruction->opcode() == HloOpcode::kFusion;
+  });
+  EXPECT_THAT(wrapper.HloModulePass::Run(module.get()), IsOkAndHolds(true));
+  EXPECT_EQ(CountAsyncInstructions(module->entry_computation()), 4);
+
+  Literal argument = LiteralUtil::CreateR1<float>({1.0});
+  Literal expected = LiteralUtil::CreateR1<float>({4.0});
+
+  Literal result = ExecuteNoHloPasses(std::move(module), {&argument});
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
+}  // namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc
new file mode 100644
index 0000000..07f52fd
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc
@@ -0,0 +1,231 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/collective_permute_cycle_decomposer.h"
+
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "xla/comparison_util.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/literal_util.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+
+namespace xla {
+
+namespace {
+using SourceTargetPair = std::pair<int64_t, int64_t>;
+using SourceTargetPairs = std::vector<SourceTargetPair>;
+enum class CycleType { kUnknown, kForward, kBackward };
+
+// Returns the cycle type if the CollectivePermute instruction has a cycle in
+// its source-target pairs and should be decomposed, and CycleType::kUnknown
+// otherwise.
+CycleType ShouldDecomposeWithCycleType(
+    const HloCollectivePermuteInstruction& collective_permute,
+    int64_t threshold_in_bytes) {
+  if (!collective_permute.channel_id().has_value()) {
+    return CycleType::kUnknown;
+  }
+
+  if (collective_permute.operand_count() != 1) {
+    return CycleType::kUnknown;
+  }
+
+  const Shape& result_shape = collective_permute.shape();
+  // Skip the transformation if there is any context data.
+  if (result_shape.IsTuple()) {
+    return CycleType::kUnknown;
+  }
+
+  CHECK(result_shape.IsArray());
+  if (ShapeUtil::ByteSizeOf(result_shape) < threshold_in_bytes) {
+    return CycleType::kUnknown;
+  }
+
+  const SourceTargetPairs& pairs = collective_permute.source_target_pairs();
+  if (pairs.size() == 1) {
+    return CycleType::kUnknown;
+  }
+
+  return IsForwardCycle(pairs)    ? CycleType::kForward
+         : IsBackwardCycle(pairs) ? CycleType::kBackward
+                                  : CycleType::kUnknown;
+}
+
+// Constructs the frontend attributes for the two decomposed CollectivePermute
+// instructions.
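+//
+// For example (mirroring the unit tests), for a forward cycle annotated with
+// `_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"`, the backedge
+// CollectivePermute receives `{{3,10}}` and the other-edges CollectivePermute
+// receives `{{0,7},{1,8},{2,9}}`.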
+absl::Status GetFrontendAttributes(HloCollectivePermuteInstruction* cp,
+                                   CycleType cycle_type,
+                                   xla::FrontendAttributes& cp1_attr,
+                                   xla::FrontendAttributes& cp2_attr) {
+  cp1_attr = cp->frontend_attributes();
+  cp2_attr = cp->frontend_attributes();
+  auto validation_it =
+      cp->frontend_attributes().map().find(kSendRecvValidationAttr);
+  if (validation_it == cp->frontend_attributes().map().end() ||
+      validation_it->second == "invalid") {
+    return absl::OkStatus();
+  }
+
+  auto statusor_bounds = ParseReplicaGroupsOnly(validation_it->second);
+  if (!statusor_bounds.ok()) {
+    return statusor_bounds.status();
+  }
+  const std::vector<ReplicaGroup>& bounds = statusor_bounds.value();
+  if (bounds.size() < 2) {
+    return Internal("Invalid number of replica groups");
+  }
+
+  int64_t num_pairs = bounds.size();
+  // A forward cycle has its backedge at the end while a backward cycle has its
+  // backedge at the beginning.
+  auto backedge_start = cycle_type == CycleType::kBackward
+                            ? bounds.begin()
+                            : bounds.begin() + num_pairs - 1;
+  auto other_edges_start =
+      cycle_type == CycleType::kBackward ? bounds.begin() + 1 : bounds.begin();
+  std::vector<ReplicaGroup> cp1_bounds(backedge_start, backedge_start + 1);
+  std::vector<ReplicaGroup> cp2_bounds(other_edges_start,
+                                       other_edges_start + num_pairs - 1);
+  auto bounds_to_string = [](const std::vector<ReplicaGroup>& groups) {
+    return "{" +
+           absl::StrJoin(groups, ",",
+                         [](std::string* out, const ReplicaGroup& value) {
+                           absl::StrAppend(out, "{", value.replica_ids(0), ",",
+                                           value.replica_ids(1), "}");
+                         }) +
+           "}";
+  };
+  std::string cp1_validation_str = bounds_to_string(cp1_bounds);
+  std::string cp2_validation_str = bounds_to_string(cp2_bounds);
+  (*cp1_attr.mutable_map())[kSendRecvValidationAttr] = cp1_validation_str;
+  (*cp2_attr.mutable_map())[kSendRecvValidationAttr] = cp2_validation_str;
+  return absl::OkStatus();
+}
+
+// Decomposes a CollectivePermute instruction with a cycle in its source-target
+// pairs into two CollectivePermute instructions.
+absl::Status DecomposeCollectivePermuteCycle(
+    HloCollectivePermuteInstruction* cp, HloComputation* computation,
+    HloModule* module, int64_t next_channel_id, CycleType cycle_type) {
+  const SourceTargetPairs& pairs = cp->source_target_pairs();
+  int64_t num_pairs = pairs.size();
+  // A forward cycle has its backedge at the end as in
+  // {{0,1},{1,2},{2,3},{3,0}} while a backward cycle has its backedge at the
+  // beginning as in {{0,3},{1,0},{2,1},{3,2}}.
+  auto backedge_start = cycle_type == CycleType::kBackward
+                            ? pairs.begin()
+                            : pairs.begin() + num_pairs - 1;
+  auto other_edges_start =
+      cycle_type == CycleType::kBackward ? pairs.begin() + 1 : pairs.begin();
+  SourceTargetPairs backedge(backedge_start, backedge_start + 1);
+  SourceTargetPairs other_edges(other_edges_start,
+                                other_edges_start + num_pairs - 1);
+  const OpMetadata& metadata = cp->metadata();
+  xla::FrontendAttributes cp1_attr, cp2_attr;
+  TF_RETURN_IF_ERROR(GetFrontendAttributes(cp, cycle_type, cp1_attr, cp2_attr));
+
+  // Create the CollectivePermute instruction for the communication represented
+  // by the backedge.
+  HloInstruction* cp1 =
+      computation->AddInstruction(HloInstruction::CreateCollectivePermute(
+          cp->shape(), cp->mutable_operand(0), backedge,
+          cp->channel_id().value()));
+  cp1->set_metadata(metadata);
+  cp1->set_frontend_attributes(cp1_attr);
+  int64_t cp1_receiver = backedge.back().second;
+
+  // Create the CollectivePermute instruction for the communication represented
+  // by the other edges.
+  HloInstruction* cp2 =
+      computation->AddInstruction(HloInstruction::CreateCollectivePermute(
+          cp->shape(), cp->mutable_operand(0), other_edges, next_channel_id));
+  cp2->set_metadata(metadata);
+  cp2->set_frontend_attributes(cp2_attr);
+
+  // Calculate the received data as follows:
+  //   partition = u32[] partition-id()
+  //   constant = u32[] constant(cp1_receiver)
+  //   compare0 = pred[] compare(partition, constant), direction=EQ
+  //   compare = pred[?] broadcast(compare0), dimensions={}
+  //   recv-data = type[?] select(compare, cp1, cp2)
+  HloInstruction* partition =
+      computation->AddInstruction(HloInstruction::CreatePartitionId());
+  HloInstruction* constant = computation->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0(U32, cp1_receiver)));
+  HloInstruction* compare0 = computation->AddInstruction(
+      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), partition,
+                                    constant, Comparison::Direction::kEq));
+  HloInstruction* compare =
+      computation->AddInstruction(HloInstruction::CreateBroadcast(
+          ShapeUtil::MakeShape(PRED, cp1->shape().dimensions()), compare0, {}));
+  HloInstruction* recv_data =
+      computation->AddInstruction(HloInstruction::CreateTernary(
+          cp1->shape(), HloOpcode::kSelect, compare, cp1, cp2));
+
+  TF_RETURN_IF_ERROR(cp->ReplaceAllUsesWith(recv_data));
+  TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(cp));
+
+  return absl::OkStatus();
+}
+}  // namespace
+
+absl::StatusOr<bool> CollectivePermuteCycleDecomposer::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  int64_t next_channel_id;
+  for (auto comp : module->computations(execution_threads)) {
+    for (auto hlo : comp->MakeInstructionPostOrder()) {
+      if (hlo->opcode() != HloOpcode::kCollectivePermute) {
+        continue;
+      }
+      auto collective_permute = Cast<HloCollectivePermuteInstruction>(hlo);
+      CycleType cycle_type = ShouldDecomposeWithCycleType(*collective_permute,
+                                                          threshold_in_bytes_);
+      if (cycle_type != CycleType::kUnknown) {
+        if (!changed) {
+          next_channel_id = hlo_query::NextChannelId(*module);
+          changed = true;
+        }
+        TF_RETURN_IF_ERROR(DecomposeCollectivePermuteCycle(
+            collective_permute, comp, module, next_channel_id++, cycle_type));
+      }
+    }
+  }
+  return changed;
+}
+
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h
new file mode 100644
index 0000000..cfacd66
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h
@@ -0,0 +1,73 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// CollectivePermuteCycleDecomposer is a pass that converts CollectivePermute
+// instructions with all participants forming either a forward cycle (such as
+// {{0,1},{1,2},{2,3},{3,0}}) or a backward cycle (such as {{3,2},{2,1},{1,0},
+// {0,3}}) into two CollectivePermute instructions. We currently restrict
+// this transformation to CollectivePermute using partition mode, with one
+// input, without any context data. Here is an example.
+//
+// before transformation:
+//     start = (<rt>, <rt>) collective-permute(data),
+//       source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
+//
+// after transformation:
+//     partition-id = u32[] partition-id()
+//     constant = u32[] constant(0)
+//     compare = pred[] compare(u32[] partition-id, u32[] constant),
+//       direction=EQ
+//     pred = pred[] broadcast(pred[] compare), dimensions={}
+//     cp1 = (<rt>, <rt>) collective-permute(data), source_target_pairs={{3,0}}
+//     cp2 = (<rt>, <rt>) collective-permute(data),
+//       source_target_pairs={{0,1},{1,2},{2,3}}
+//     data = <rt> select(pred, cp1, cp2)
+//
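+// A minimal usage sketch (mirroring the unit tests; the parsed module is
+// assumed to be owned by the caller):
+//
+//   CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
+//   TF_ASSIGN_OR_RETURN(bool changed, decomposer.Run(module.get()));
+//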
+class CollectivePermuteCycleDecomposer : public HloModulePass {
+ public:
+  explicit CollectivePermuteCycleDecomposer(int64_t threshold_in_bytes)
+      : threshold_in_bytes_(threshold_in_bytes) {}
+  absl::string_view name() const override {
+    return "collective-permute-cycle-decomposer";
+  }
+
+  using HloPassInterface::Run;
+  // Runs CollectivePermuteCycleDecomposer pass on computations in 'module'.
+  // Returns whether the 'module' was changed.
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  // Transform only if the size of the CollectivePermute data >= threshold.
+  int64_t threshold_in_bytes_;
+};
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc
new file mode 100644
index 0000000..ae537d9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc
@@ -0,0 +1,234 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/collective_permute_cycle_decomposer.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+using ::testing::HasSubstr;
+using CollectivePermuteCycleDecomposerTest = HloTestBase;
+
+TEST_F(CollectivePermuteCycleDecomposerTest, DefaultChannelNotTransformed) {
+  const absl::string_view kModuleStr = R"(
+      HloModule test
+      ENTRY test_computation {
+        p = u32[] replica-id()
+        ROOT start = u32[] collective-permute(p),
+          source_target_pairs={{0,1},{1,0}}
+      }
+    )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule((kModuleStr)));
+  CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) {
+  const absl::string_view kModuleStr = R"(
+      HloModule test
+      ENTRY test_computation {
+        p = u32[] partition-id()
+        ROOT start = u32[] collective-permute(p), channel_id=1,
+          source_target_pairs={{0,0}}
+      }
+    )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule((kModuleStr)));
+  CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed) {
+  const absl::string_view kModuleStr = R"(
+      HloModule test
+      ENTRY test_computation {
+        p = u32[] partition-id()
+        ROOT start = u32[] collective-permute(p), channel_id=1,
+          source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
+      }
+    )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule((kModuleStr)));
+  CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/33);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) {
+  const absl::string_view kModuleStr = R"(
+      HloModule test
+      ENTRY test_computation {
+        p = u32[] partition-id()
+        ROOT start = u32[3,2] collective-permute(p), channel_id=1,
+          source_target_pairs={{0,1},{1,2},{2,3},{3,0}},
+          frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"},
+          metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
+      }
+    )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule((kModuleStr)));
+  CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  auto check_metadata = [](const HloInstruction* inst) {
+    EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add");
+    EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py");
+    EXPECT_EQ(inst->metadata().source_line(), 35);
+  };
+
+  HloCollectivePermuteInstruction* cp1 =
+      DynCast<HloCollectivePermuteInstruction>(
+          FindInstruction(module.get(), "collective-permute"));
+  HloCollectivePermuteInstruction* cp2 =
+      DynCast<HloCollectivePermuteInstruction>(
+          FindInstruction(module.get(), "collective-permute.1"));
+  EXPECT_NE(cp1, nullptr);
+  EXPECT_NE(cp2, nullptr);
+  EXPECT_EQ(cp1->operand(0), cp2->operand(0));
+  EXPECT_GT(cp2->channel_id().value(), cp1->channel_id().value());
+  EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{3,0}}"));
+  EXPECT_THAT(cp1->ToString(), HasSubstr("_xla_send_recv_validation={{3,10}}"));
+  EXPECT_THAT(cp2->ToString(),
+              HasSubstr("source_target_pairs={{0,1},{1,2},{2,3}}"));
+  EXPECT_THAT(cp2->ToString(),
+              HasSubstr("_xla_send_recv_validation={{0,7},{1,8},{2,9}}"));
+  check_metadata(cp1);
+  check_metadata(cp2);
+}
+
+TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) {
+  const absl::string_view kModuleStr = R"(
+  HloModule test
+
+  while_cond {
+    param = (u32[], f32[2,2], f32[2,2]) parameter(0)
+    iter = u32[] get-tuple-element(param), index=0
+    max_iter = u32[] constant(3)
+    ROOT cmp = pred[] compare(iter, max_iter), direction=LT
+  }
+
+  while_body {
+    param = (u32[], f32[2,2], f32[2,2]) parameter(0)
+    iter = u32[] get-tuple-element(param), index=0
+    data = f32[2,2] get-tuple-element(param), index=1
+    weights = f32[2,2] get-tuple-element(param), index=2
+    cp = f32[2,2] collective-permute(data),
+      channel_id=1,
+      source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}},
+      frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"}
+    matmul = f32[2,2] dot(weights, cp), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    iter_increment = u32[] constant(1)
+    next_iter = u32[] add(iter, iter_increment)
+    ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights)
+  }
+
+  ENTRY test_computation {
+    iter = u32[] constant(0)
+    data = f32[2,2] parameter(0)
+    weights = f32[2,2] parameter(1)
+    input = (u32[], f32[2,2], f32[2,2]) tuple(iter, data, weights)
+    while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body
+    ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule((kModuleStr)));
+  CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+  EXPECT_TRUE(changed);
+  HloCollectivePermuteInstruction* cp1 =
+      DynCast<HloCollectivePermuteInstruction>(
+          FindInstruction(module.get(), "collective-permute"));
+  HloCollectivePermuteInstruction* cp2 =
+      DynCast<HloCollectivePermuteInstruction>(
+          FindInstruction(module.get(), "collective-permute.1"));
+  EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{3,0}}"));
+  EXPECT_THAT(cp1->ToString(), HasSubstr("_xla_send_recv_validation={{3,10}}"));
+  EXPECT_THAT(cp2->ToString(),
+              HasSubstr("source_target_pairs={{0,1},{1,2},{2,3}}"));
+  EXPECT_THAT(cp2->ToString(),
+              HasSubstr("_xla_send_recv_validation={{0,7},{1,8},{2,9}}"));
+}
+
+TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) {
+  const absl::string_view kModuleStr = R"(
+      HloModule test
+      ENTRY test_computation {
+        p = u32[] partition-id()
+        ROOT start = u32[] collective-permute(p), channel_id=1,
+          source_target_pairs={{0,3},{1,0},{2,1},{3,2}},
+          frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"},
+          metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
+      }
+    )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule((kModuleStr)));
+  CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
+  EXPECT_TRUE(changed);
+  auto check_metadata = [](const HloInstruction* inst) {
+    EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add");
+    EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py");
+    EXPECT_EQ(inst->metadata().source_line(), 35);
+  };
+
+  HloCollectivePermuteInstruction* cp1 =
+      DynCast<HloCollectivePermuteInstruction>(
+          FindInstruction(module.get(), "collective-permute"));
+  HloCollectivePermuteInstruction* cp2 =
+      DynCast<HloCollectivePermuteInstruction>(
+          FindInstruction(module.get(), "collective-permute.1"));
+  EXPECT_NE(cp1, nullptr);
+  EXPECT_NE(cp2, nullptr);
+  EXPECT_EQ(cp1->operand(0), cp2->operand(0));
+  EXPECT_GT(cp2->channel_id().value(), cp1->channel_id().value());
+  EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{0,3}}"));
+  EXPECT_THAT(cp1->ToString(), HasSubstr("_xla_send_recv_validation={{0,7}}"));
+  EXPECT_THAT(cp2->ToString(),
+              HasSubstr("source_target_pairs={{1,0},{2,1},{3,2}}"));
+  EXPECT_THAT(cp2->ToString(),
+              HasSubstr("_xla_send_recv_validation={{1,8},{2,9},{3,10}}"));
+  check_metadata(cp1);
+  check_metadata(cp2);
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc
new file mode 100644
index 0000000..e9df22a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc
@@ -0,0 +1,163 @@
+/* Copyright 2024 The OpenXLA Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h"
+
+#include "xla/literal_util.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/while_loop_analysis.h"
+
+namespace xla {
+
+// Finds and returns the non-constant operand in instr.
+//
+// CHECK-fails if instr doesn't have exactly one unique non-constant operand.
+static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
+  const HloInstruction* result = nullptr;
+  for (const HloInstruction* operand : instr->operands()) {
+    if (!operand->IsConstant()) {
+      if (result != nullptr) {
+        CHECK_EQ(result, operand);
+      }
+      result = operand;
+    }
+  }
+  CHECK_NE(result, nullptr);
+  return result;
+}
+
+// Finds the step (k) for the while instruction, if the loop is of the form:
+//
+// while(cond) {
+//   ind_var = ind_var + k
+// }
+//
+// If this pattern is not found, it returns std::nullopt.
+std::optional<int64_t> GetStep(HloInstruction* while_inst) {
+  // Get the update operation
+  std::optional<int64_t> indvar_tuple_idx =
+      GetLoopInductionVarTupleIdx(while_inst);
+  if (!indvar_tuple_idx) {
+    return std::nullopt;
+  }
+  auto* while_body_indvar_update =
+      while_inst->while_body()->root_instruction()->mutable_operand(
+          *indvar_tuple_idx);
+  auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
+
+  HloInstruction* trip_count_increase_step_instr = nullptr;
+  if (!Match(while_body_indvar_update,
+             match::AddAnyOrder(match::Op().Is(while_body_indvar),
+                                match::Op(&trip_count_increase_step_instr)))) {
+    return std::nullopt;
+  }
+  return LiteralUtil::LiteralAsScalarInt64(
+      trip_count_increase_step_instr->literal());
+}
+
+absl::StatusOr<bool> CollectivePermuteValidIterationAnnotator::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* comp : module->computations(execution_threads)) {
+    for (HloInstruction* inst : comp->instructions()) {
+      if (inst->opcode() != HloOpcode::kCollectivePermute) {
+        continue;
+      }
+
+      if (inst->frontend_attributes().map().find(kSendRecvValidationAttr) !=
+          inst->frontend_attributes().map().end()) {
+        continue;
+      }
+      auto sourceTargetPairs = inst->source_target_pairs();
+      if (!IsForwardCycle(sourceTargetPairs) &&
+          !IsBackwardCycle(sourceTargetPairs)) {
+        continue;
+      }
+
+      VLOG(2) << "Collective permute with cycle: " << inst->ToString();
+
+      int64_t max_device_num = -1;
+      for (auto [source, target] : sourceTargetPairs) {
+        max_device_num = std::max(std::max(source, target), max_device_num);
+      }
+      int64_t num_devices = max_device_num + 1;
+
+      HloInstruction* whileOp = inst->parent()->WhileCallInstruction();
+      if (whileOp == nullptr) {
+        VLOG(2) << "No surrounding while op found. Ignoring " << inst->name();
+        continue;
+      }
+      if (!whileOp->frontend_attributes().map().contains(
+              "is_pipelined_while_loop"))
+        continue;
+      TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config,
+                          whileOp->backend_config<WhileLoopBackendConfig>());
+      if (!config.has_known_trip_count()) {
+        VLOG(2) << "Trip count for while loop (" << whileOp->name()
+                << "): unknown";
+        continue;
+      }
+
+      int64_t trip_count = config.known_trip_count().n();
+      std::optional<int64_t> step = GetStep(whileOp);
+      VLOG(2) << "Trip count for while loop (" << whileOp->name()
+              << "): " << trip_count;
+      if (!step) {
+        VLOG(2) << "Could not find step for while operation";
+        continue;
+      }
+      VLOG(2) << "Step for while loop (" << whileOp->name() << "): " << *step;
+      if (*step != 1) {
+        VLOG(2) << "Step is not 1. Skipping...";
+        continue;
+      }
+
+      // For each source i, the send/recv iteration instances are {i, i+offset}
+      // where offset is `number_of_microbatches * CR - 1`. We know that
+      // `trip_count = number_of_microbatches * CR + num_devices - 1`, so
+      // offset = number_of_microbatches * CR - 1 = trip_count - num_devices.
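+      // For example (matching the unit tests), with trip_count = 10 and
+      // num_devices = 4, offset = 6, and a forward cycle is annotated with
+      // {{0,6},{1,7},{2,8},{3,9}}.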
+      int64_t offset = trip_count - num_devices;
+
+      std::vector<std::pair<int64_t, int64_t>> sendRecvValidation(
+          sourceTargetPairs.size());
+
+      for (size_t currIdx = 0; currIdx < sourceTargetPairs.size(); currIdx++) {
+        sendRecvValidation[currIdx] = {currIdx, currIdx + offset};
+      }
+
+      if (IsBackwardCycle(sourceTargetPairs)) {
+        std::reverse(sendRecvValidation.begin(), sendRecvValidation.end());
+      }
+
+      xla::FrontendAttributes attributes;
+      std::string iteration_instances =
+          "{" +
+          absl::StrJoin(sendRecvValidation, ",",
+                        [](std::string* out, std::pair<int64_t, int64_t> item) {
+                          absl::StrAppend(out, "{", item.first, ",",
+                                          item.second, "}");
+                        }) +
+          "}";
+      (*attributes.mutable_map())[kSendRecvValidationAttr] =
+          iteration_instances;
+
+      inst->add_frontend_attributes(attributes);
+      VLOG(1) << "Adding " << kSendRecvValidationAttr << " to " << inst->name()
+              << ": " << iteration_instances;
+      changed = true;
+    }
+  }
+  return changed;
+}
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h
new file mode 100644
index 0000000..f8999a9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h
@@ -0,0 +1,58 @@
+/* Copyright 2024 The OpenXLA Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_
+
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// This is an unsafe transformation that is triggered only if the attribute
+// `is_pipelined_while_loop` is present on a while loop.
+//
+// If a while loop is known to be a pipelined while loop, has a known trip count
+// and increments with step=1, then this pass annotates the `collective-permute`
+// operations within the while loop with valid iterations for each GPU. This is
+// only done when the source-target pairs of the `collective-permute` operation
+// form a forward or backward cycle.
+//
+// For example, if the trip count is 10 (iteration 0 to 9), with step=1, and the
+// source-target pairs of a `collective-permute` operation are
+// `{{0,1},{1,2},{2,3},{3,0}}`, then this pass would annotate such operation
+// with `_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"`. This annotation
+// means that
+//   - for GPU index 0, the valid iterations are 0,1,2,3,4,5,6.
+//   - for GPU index 1, the valid iterations are 1,2,3,4,5,6,7.
+//   - for GPU index 2, the valid iterations are 2,3,4,5,6,7,8.
+//   - for GPU index 3, the valid iterations are 3,4,5,6,7,8,9.
+//
+// The index in the list denotes the device index and the bounds {start,end} are
+// inclusive. For more examples, look at
+// `xla/service/spmd/collective_permute_valid_iteration_annotator_tests.cc`.
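+//
+// A minimal usage sketch (mirroring the unit tests): the pass is typically run
+// after WhileLoopTripCountAnnotator so the trip count is known.
+//
+//   HloPassPipeline pipeline("my-pass-pipeline");
+//   pipeline.AddPass<WhileLoopTripCountAnnotator>();
+//   pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module.get()));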
+class CollectivePermuteValidIterationAnnotator : public HloModulePass {
+ public:
+  CollectivePermuteValidIterationAnnotator() = default;
+  absl::string_view name() const override {
+    return "collective-permute-valid-iteration-annotator";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc
new file mode 100644
index 0000000..3585acc
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc
@@ -0,0 +1,174 @@
+/* Copyright 2024 The OpenXLA Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h"
+
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/service/while_loop_trip_count_annotator.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace {
+
+using CollectivePermuteValidIterationAnnotatorTest = HloTestBase;
+
+TEST_F(CollectivePermuteValidIterationAnnotatorTest, NoChange) {
+  // We expect no changes here because the while loop is not labelled as
+  // `is_pipelined_while_loop`.
+  absl::string_view hlo_string = R"(
+    HloModule test, entry_computation_layout={()->(s32[], s32[])}
+    %Body (param: (s32[], s32[])) -> (s32[], s32[]) {
+      %param = (s32[], s32[]) parameter(0)
+      %i = s32[] get-tuple-element((s32[], s32[]) %param), index=1
+      %one = s32[] constant(1)
+      %i_plus_one = s32[] add(s32[] %i, s32[] %one)
+      %permute = s32[] collective-permute(%i_plus_one), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
+      ROOT %tuple = (s32[], s32[]) tuple(s32[] %permute, s32[] %permute)
+    }
+    %Cond (param.1: (s32[], s32[])) -> pred[] {
+      %param.1 = (s32[], s32[]) parameter(0)
+      %i.1 = s32[] get-tuple-element((s32[], s32[]) %param.1), index=1
+      %trip_count = s32[] constant(10)
+      ROOT %done = pred[] compare(s32[] %i.1, s32[] %trip_count), direction=LT
+    }
+    ENTRY %test () -> (s32[], s32[]) {
+      %i_start = s32[] constant(0)
+      %p_start = s32[] constant(0)
+      %initial_tuple = (s32[], s32[]) tuple(s32[] %i_start, s32[] %p_start)
+      ROOT %while = (s32[], s32[]) while((s32[], s32[]) %initial_tuple), condition=%Cond, body=%Body
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string, 1, 4));
+
+  HloPassPipeline pipeline("my-pass-pipeline");
+
+  pipeline.AddPass<WhileLoopTripCountAnnotator>();
+  pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+  EXPECT_FALSE(changed);
+
+  HloCollectivePermuteInstruction* cp =
+      DynCastOrNull<HloCollectivePermuteInstruction>(
+          FindInstruction(module.get(), HloOpcode::kCollectivePermute));
+
+  ASSERT_NE(cp, nullptr);
+
+  auto sendRecvValidationIt =
+      cp->frontend_attributes().map().find(kSendRecvValidationAttr);
+  ASSERT_EQ(sendRecvValidationIt, cp->frontend_attributes().map().end());
+}
+
+TEST_F(CollectivePermuteValidIterationAnnotatorTest, ForwardCycle) {
+  absl::string_view hlo_string = R"(
+    HloModule test, entry_computation_layout={()->(s32[], s32[])}
+    %Body (param: (s32[], s32[])) -> (s32[], s32[]) {
+      %param = (s32[], s32[]) parameter(0)
+      %i = s32[] get-tuple-element((s32[], s32[]) %param), index=1
+      %one = s32[] constant(1)
+      %i_plus_one = s32[] add(s32[] %i, s32[] %one)
+      %permute = s32[] collective-permute(%i_plus_one), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
+      ROOT %tuple = (s32[], s32[]) tuple(s32[] %permute, s32[] %i_plus_one)
+    }
+    %Cond (param.1: (s32[], s32[])) -> pred[] {
+      %param.1 = (s32[], s32[]) parameter(0)
+      %i.1 = s32[] get-tuple-element((s32[], s32[]) %param.1), index=1
+      %trip_count = s32[] constant(10)
+      ROOT %done = pred[] compare(s32[] %i.1, s32[] %trip_count), direction=LT
+    }
+    ENTRY %test () -> (s32[], s32[]) {
+      %i_start = s32[] constant(0)
+      %p_start = s32[] constant(0)
+      %initial_tuple = (s32[], s32[]) tuple(s32[] %i_start, s32[] %p_start)
+      ROOT %while = (s32[], s32[]) while((s32[], s32[]) %initial_tuple), condition=%Cond, body=%Body, frontend_attributes={is_pipelined_while_loop="true"}
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string, 1, 4));
+
+  HloPassPipeline pipeline("my-pass-pipeline");
+
+  pipeline.AddPass<WhileLoopTripCountAnnotator>();
+  pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  HloCollectivePermuteInstruction* cp =
+      DynCastOrNull<HloCollectivePermuteInstruction>(
+          FindInstruction(module.get(), HloOpcode::kCollectivePermute));
+
+  ASSERT_NE(cp, nullptr);
+
+  auto sendRecvValidationIt =
+      cp->frontend_attributes().map().find(kSendRecvValidationAttr);
+  ASSERT_NE(sendRecvValidationIt, cp->frontend_attributes().map().end());
+  std::string sendRecvValidationAttr = sendRecvValidationIt->second;
+  EXPECT_EQ(sendRecvValidationAttr, "{{0,6},{1,7},{2,8},{3,9}}");
+}
+
+TEST_F(CollectivePermuteValidIterationAnnotatorTest, BackwardCycle) {
+  absl::string_view hlo_string = R"(
+    HloModule test, entry_computation_layout={()->(s32[], s32[])}
+    %Body (param: (s32[], s32[])) -> (s32[], s32[]) {
+      %param = (s32[], s32[]) parameter(0)
+      %i = s32[] get-tuple-element((s32[], s32[]) %param), index=1
+      %one = s32[] constant(1)
+      %i_plus_one = s32[] add(s32[] %i, s32[] %one)
+      %permute = s32[] collective-permute(%i_plus_one), channel_id=1, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
+      ROOT %tuple = (s32[], s32[]) tuple(s32[] %permute, s32[] %i_plus_one)
+    }
+    %Cond (param.1: (s32[], s32[])) -> pred[] {
+      %param.1 = (s32[], s32[]) parameter(0)
+      %i.1 = s32[] get-tuple-element((s32[], s32[]) %param.1), index=1
+      %trip_count = s32[] constant(10)
+      ROOT %done = pred[] compare(s32[] %i.1, s32[] %trip_count), direction=LT
+    }
+    ENTRY %test () -> (s32[], s32[]) {
+      %i_start = s32[] constant(0)
+      %p_start = s32[] constant(0)
+      %initial_tuple = (s32[], s32[]) tuple(s32[] %i_start, s32[] %p_start)
+      ROOT %while = (s32[], s32[]) while((s32[], s32[]) %initial_tuple), condition=%Cond, body=%Body, frontend_attributes={is_pipelined_while_loop="true"}
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string, 1, 4));
+
+  HloPassPipeline pipeline("my-pass-pipeline");
+
+  pipeline.AddPass<WhileLoopTripCountAnnotator>();
+  pipeline.AddPass<CollectivePermuteValidIterationAnnotator>();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  HloCollectivePermuteInstruction* cp =
+      DynCastOrNull<HloCollectivePermuteInstruction>(
+          FindInstruction(module.get(), HloOpcode::kCollectivePermute));
+
+  ASSERT_NE(cp, nullptr);
+
+  auto sendRecvValidationIt =
+      cp->frontend_attributes().map().find(kSendRecvValidationAttr);
+  ASSERT_NE(sendRecvValidationIt, cp->frontend_attributes().map().end());
+  std::string sendRecvValidationAttr = sendRecvValidationIt->second;
+  EXPECT_EQ(sendRecvValidationAttr, "{{3,9},{2,8},{1,7},{0,6}}");
+}
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc
new file mode 100644
index 0000000..2d4aa52
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc
@@ -0,0 +1,811 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/command_buffer_scheduling.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <memory>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/ffi/ffi_api.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/variant_visitor.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+
+using CommandBuffer = CommandBufferScheduling::CommandBuffer;
+using CommandBufferConfig = CommandBufferScheduling::CommandBufferConfig;
+
+// Returns true if HLO computation can be executed as a command buffer.
+static bool IsCommand(const HloComputation* computation,
+                      const CommandBufferConfig& config);
+
+//===----------------------------------------------------------------------===//
+// No-op HLO operations.
+//===----------------------------------------------------------------------===//
+
+// Some HLO operations do not have corresponding operations at run time, so
+// they can be safely wrapped into command buffers together with load-bearing
+// commands.
+
+static bool IsConstant(const HloInstruction* hlo) {
+  return hlo->opcode() == HloOpcode::kConstant;
+}
+
+static bool IsParameter(const HloInstruction* hlo) {
+  return hlo->opcode() == HloOpcode::kParameter;
+}
+
+// Returns true if the instruction is a no-op at run time and doesn't have a
+// corresponding Thunk or Command (a metadata-only operation).
+static bool IsNoOp(const HloInstruction* hlo) {
+  return HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kTuple,
+                          HloOpcode::kGetTupleElement>(hlo);
+}
+
+//===----------------------------------------------------------------------===//
+// Synchronous HLO operations mapped to commands.
+//===----------------------------------------------------------------------===//
+
+// Synchronous HLO operations can be wrapped into command buffers when they
+// have corresponding commands.
+
+// This is a template for defining pattern-matching functions for HLO
+// instructions that do not have a dedicated HLO instruction class.
+template <HloOpcode op>
+static bool IsCommand(const HloInstruction*, const CommandBufferConfig&);
+
+// While loops can be executed inside command buffers only if condition and body
+// regions can be executed as command buffers.
+template <>
+bool IsCommand<HloOpcode::kWhile>(const HloInstruction* hlo,
+                                  const CommandBufferConfig& config) {
+  return config.enabled_commands.contains(DebugOptions::CONDITIONALS) &&
+         IsCommand(hlo->while_body(), config) &&
+         IsCommand(hlo->while_condition(), config);
+}
+
+// Conditional can be executed inside command buffers only if all regions of its
+// branches can be executed as command buffers.
+template <>
+bool IsCommand<HloOpcode::kConditional>(const HloInstruction* hlo,
+                                        const CommandBufferConfig& config) {
+  return config.enabled_commands.contains(DebugOptions::CONDITIONALS) &&
+         absl::c_all_of(hlo->branch_computations(),
+                        [&](const HloComputation* comp) {
+                          return IsCommand(comp, config);
+                        });
+}
+
+static bool IsCommand(const HloCustomCallInstruction* hlo,
+                      const CommandBufferConfig& config) {
+  // cuBLAS gemms are represented in HLO as custom call instructions.
+  if (config.enabled_commands.contains(DebugOptions::CUBLAS) &&
+      IsLegacyCublasMatmul(*hlo)) {
+    return true;
+  }
+
+  if (config.enabled_commands.contains(DebugOptions::CUBLASLT) &&
+      (IsCublasLtMatmul(*hlo) || IsCublasLtMatmulF8(*hlo))) {
+    return true;
+  }
+
+  if (config.enabled_commands.contains(DebugOptions::CUDNN) &&
+      IsCustomCallTofMHA(*hlo)) {
+    VLOG(3) << "Recording FusedMHA, target " << hlo->custom_call_target()
+            << " into command buffer.";
+    return true;
+  }
+
+  if (!config.enabled_commands.contains(DebugOptions::CUSTOM_CALL)) {
+    return false;
+  }
+
+  if (config.enabled_legacy_custom_call_targets.contains(
+          hlo->custom_call_target())) {
+    VLOG(3) << "Recording legacy custom call target "
+            << hlo->custom_call_target() << " into command buffer.";
+    return true;
+  }
+
+  // A special case for jax-triton kernel while it is not ported to FFI.
+  if (hlo->custom_call_target() == "triton_kernel_call" &&
+      // TODO(b/327718087): This is an ugly hack to prevent capturing triton
+      // custom calls that might do autotuning at run time.
+      !absl::StrContains(hlo->metadata().op_name(), "Autotuner")) {
+    return true;
+  }
+
+  // Check if FFI handler is compatible with command buffers.
+  auto registration = ffi::FindHandler(hlo->custom_call_target(), "gpu");
+  return registration.ok()
+             ? ffi::IsCommandBufferCompatible(registration->traits)
+             : false;
+}
+
+static bool IsCommand(const HloInstruction* hlo,
+                      const CommandBufferConfig& config) {
+  if (auto* fusion = DynCast<HloFusionInstruction>(hlo)) {
+    auto gpu_config = fusion->backend_config<GpuBackendConfig>();
+    const FusionBackendConfig& backend_config =
+        gpu_config->fusion_backend_config();
+    if (backend_config.kind() == kCuDnnFusionKind) {
+      return config.enabled_commands.contains(DebugOptions::CUDNN);
+    }
+    const auto& custom_config = backend_config.custom_fusion_config();
+    if (custom_config.name() == "address_computation") {
+      auto fusion_analysis =
+          HloFusionAnalysis::Create(*hlo, config.device_description);
+      const HloFusionAdaptor& adaptor = fusion_analysis.fusion();
+      auto custom_call_adaptor = HloBfsFindIf(
+          adaptor.GetRoots(), adaptor,
+          [](auto node) { return node.opcode() == HloOpcode::kCustomCall; });
+      const auto* custom_call = static_cast<const HloCustomCallInstruction*>(
+          &custom_call_adaptor->instruction());
+      return IsCommand(custom_call, config);
+    }
+    if (custom_config.name() == "dynamic_address_computation") {
+      return false;
+    }
+    return config.enabled_commands.contains(DebugOptions::FUSION);
+  }
+
+  if (auto* sort = DynCast<HloSortInstruction>(hlo))
+    return config.enabled_commands.contains(DebugOptions::FUSION);
+
+  if (hlo->opcode() == HloOpcode::kPartitionId ||
+      hlo->opcode() == HloOpcode::kReplicaId) {
+    return config.enabled_commands.contains(DebugOptions::FUSION);
+  }
+
+  if (auto* custom_call = DynCast<HloCustomCallInstruction>(hlo))
+    return IsCommand(custom_call, config);
+
+  if (hlo->opcode() == HloOpcode::kWhile)
+    return IsCommand<HloOpcode::kWhile>(hlo, config);
+
+  if (hlo->opcode() == HloOpcode::kConditional)
+    return IsCommand<HloOpcode::kConditional>(hlo, config);
+
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Asynchronous HLO operations mapped to commands.
+//===----------------------------------------------------------------------===//
+
+// Asynchronous HLO operations can be wrapped into command buffers only when
+// both the start and done operations can be put into the same command buffer.
+// Command buffer semantics imply that when command buffer execution completes,
+// all recorded commands have also completed, which means that if the done
+// operation is not part of the same command buffer, we would change the
+// execution semantics and create an additional synchronization point.
+
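+// Returns true if `hlo` is an asynchronous start operation whose command type
+// is enabled in `config`.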
+static bool IsAsyncStartCommand(const HloInstruction* hlo,
+                                const CommandBufferConfig& config) {
+  if (hlo->opcode() == HloOpcode::kAllReduceStart ||
+      hlo->opcode() == HloOpcode::kAllGatherStart) {
+    return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
+  }
+
+  if (hlo->opcode() == HloOpcode::kAsyncStart) {
+    if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) {
+      return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
+    }
+  }
+
+  return false;
+}
+
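+// Returns true if `hlo` is an asynchronous done operation whose command type
+// is enabled in `config`.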
+static bool IsAsyncDoneCommand(const HloInstruction* hlo,
+                               const CommandBufferConfig& config) {
+  if (hlo->opcode() == HloOpcode::kAllReduceDone ||
+      hlo->opcode() == HloOpcode::kAllGatherDone) {
+    return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
+  }
+
+  if (hlo->opcode() == HloOpcode::kAsyncDone) {
+    if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) {
+      return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
+    }
+  }
+
+  return false;
+}
+
+// Finds the async-done HLO operation corresponding to an async-start one.
+static HloInstruction* FindAsyncDoneCommand(const HloInstruction* start) {
+  if (start->opcode() == HloOpcode::kAllReduceStart ||
+      start->opcode() == HloOpcode::kAllGatherStart) {
+    CHECK(start->users().size() == 1);  // NOLINT, checked by HLO verifier
+    return start->users().front();
+  } else if (start->opcode() == HloOpcode::kAsyncStart) {
+    return start->async_chain_done();
+  }
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// HLO computations mapped to command buffers.
+//===----------------------------------------------------------------------===//
+
+// Returns true if HLO computation can be executed as a command buffer.
+static bool IsCommand(const HloComputation* computation,
+                      const CommandBufferConfig& config) {
+  return absl::c_all_of(
+      computation->instructions(), [&](const HloInstruction* inst) {
+        return IsNoOp(inst) || IsConstant(inst) || IsParameter(inst) ||
+               IsCommand(inst, config) || IsAsyncStartCommand(inst, config) ||
+               IsAsyncDoneCommand(inst, config);
+      });
+}
+
+//===----------------------------------------------------------------------===//
+
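+// Removes no-op instructions from the end of the sequence so that a command
+// buffer never ends with a non-load-bearing instruction.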
+static void RemoveTrailingNoOps(HloInstructionSequence& seq) {
+  std::vector<HloInstruction*> instructions = seq.instructions();
+  for (int i = instructions.size() - 1; i >= 0; i--) {
+    if (HloInstruction* inst = instructions[i]; IsNoOp(inst)) {
+      seq.remove_instruction(inst);
+    } else {
+      break;
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Discovering sequences of compatible Hlo instructions
+//===----------------------------------------------------------------------===//
+
+// The input is a scheduled sequence of instructions. This function collects
+// subsequences that will be extracted as command buffers.
+std::vector<HloInstructionSequence>
+CommandBufferScheduling::CollectCommandBufferSequences(
+    const HloInstructionSequence schedule, const CommandBufferConfig& config,
+    int32_t min_num_commands) {
+  std::vector<HloInstructionSequence> sequences;
+
+  HloInstructionSequence current_seq;
+  int64_t num_commands_in_current_seq = 0;
+
+  // Adds `current_seq` to `sequences` if it has enough commands in it.
+  auto collect_current_seq = [&]() {
+    if (num_commands_in_current_seq >= std::max(1, min_num_commands)) {
+      RemoveTrailingNoOps(current_seq);
+      sequences.push_back(std::move(current_seq));
+    }
+    current_seq = HloInstructionSequence();
+    num_commands_in_current_seq = 0;
+  };
+
+  auto& instructions = schedule.instructions();
+
+  // Collect the sequence of instructions that contains the async start and its
+  // corresponding done instruction. If there is another start instruction
+  // between the original start and done, we may potentially extend the sequence
+  // to include its corresponding done instruction. For example, if we call this
+  // function on async_start_a in the following sequence:
+  //
+  // async_start_a
+  // async_start_b
+  // async_done_a
+  // async_done_b
+  //
+  // The returned sequence will contain async_done_b, so that all async pairs
+  // are captured by the same command buffer.
+  auto collect_async_region = [&](const HloInstruction* start) {
+    auto get_index = [&](const HloInstruction* inst) -> size_t {
+      auto it = std::find(instructions.begin(), instructions.end(), inst);
+      return std::distance(instructions.begin(), it);
+    };
+
+    HloInstructionSequence seq;
+    size_t done_index = get_index(FindAsyncDoneCommand(start));
+    for (size_t i = get_index(start); i <= done_index; i++) {
+      HloInstruction* inst = instructions.at(i);
+      if (IsAsyncStartCommand(inst, config)) {
+        const HloInstruction* done = FindAsyncDoneCommand(inst);
+        done_index = std::max(done_index, get_index(done));
+      }
+      seq.push_back(inst);
+    }
+    return seq;
+  };
+
+  // Check that instructions are safe to be captured by a command buffer, and
+  // that we do not capture an unmatched async done instruction.
+  auto check_async_region = [&](const HloInstructionSequence& seq) {
+    if (!absl::c_all_of(seq.instructions(), [&](HloInstruction* inst) {
+          return IsNoOp(inst) || IsCommand(inst, config) ||
+                 IsAsyncStartCommand(inst, config) ||
+                 IsAsyncDoneCommand(inst, config);
+        })) {
+      return false;
+    }
+
+    absl::flat_hash_set<HloInstruction*> done_instructions;
+    for (const HloInstruction* inst : seq.instructions()) {
+      if (IsAsyncStartCommand(inst, config)) {
+        done_instructions.insert(FindAsyncDoneCommand(inst));
+      }
+      if (IsAsyncDoneCommand(inst, config)) {
+        if (!done_instructions.contains(inst)) {
+          return false;
+        }
+      }
+    }
+    return true;
+  };
+
+  for (size_t i = 0; i < instructions.size(); i++) {
+    HloInstruction* inst = instructions.at(i);
+
+    // We add no-op instructions to the current sequence only if they act as
+    // glue between commands. We do not create command sequences consisting
+    // only of no-op instructions. The first and last instructions in the
+    // command buffer are always load-bearing commands.
+    if (IsNoOp(inst) && num_commands_in_current_seq) {
+      current_seq.push_back(inst);
+      continue;
+    }
+
+    // Synchronous commands can always be added to the instruction sequence.
+    if (IsCommand(inst, config)) {
+      num_commands_in_current_seq++;
+      current_seq.push_back(inst);
+      continue;
+    }
+
+    // We capture async commands if all instructions between start and done can
+    // be outlined into a command buffer.
+    if (IsAsyncStartCommand(inst, config)) {
+      HloInstructionSequence seq = collect_async_region(inst);
+      if (check_async_region(seq)) {
+        num_commands_in_current_seq += seq.instructions().size();
+        for (HloInstruction* inst : seq.instructions()) {
+          current_seq.push_back(inst);
+        }
+        i += seq.instructions().size() - 1;
+        continue;
+      }
+    }
+
+    // If we didn't find the next command, collect the current sequence and
+    // start a new one.
+    collect_current_seq();
+  }
+
+  // Don't forget to collect the final command sequence.
+  collect_current_seq();
+  return sequences;
+}
+
+// This function moves kParameter and kConstant instructions in a computation to
+// the beginning of the computation. This simplifies the construction of command
+// buffer computations because we don't need to deal with parameters and
+// constants that have users outside of a command buffer.
+absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront(
+    HloComputation* computation) {
+  HloInstructionSequence new_sequence;
+  HloSchedule& schedule = computation->parent()->schedule();
+  HloInstructionSequence& sequence = schedule.GetOrCreateSequence(computation);
+
+  for (HloInstruction* inst : sequence.instructions()) {
+    if (IsParameter(inst) || IsConstant(inst)) {
+      new_sequence.push_back(inst);
+
+      // Because we move the instruction to the front of the computation, it
+      // can't keep any control predecessors. However, silently dropping them
+      // is unsafe, as we can have transitive dependencies that define the
+      // schedule order, so we forward control predecessors to all users.
+      for (HloInstruction* control_predecessor : inst->control_predecessors()) {
+        for (HloInstruction* user : inst->users()) {
+          TF_RETURN_IF_ERROR(control_predecessor->AddControlDependencyTo(user));
+        }
+      }
+      TF_RETURN_IF_ERROR(inst->DropAllControlDeps());
+    }
+  }
+
+  for (HloInstruction* inst : sequence.instructions()) {
+    if (!IsParameter(inst) && !IsConstant(inst)) {
+      new_sequence.push_back(inst);
+    }
+  }
+
+  schedule.set_sequence(computation, new_sequence);
+  return absl::OkStatus();
+}
+
+//===----------------------------------------------------------------------===//
+// Prepares command buffer from sequence of instructions
+//===----------------------------------------------------------------------===//
+
+absl::StatusOr<CommandBuffer> CommandBufferScheduling::PrepareCommandBuffer(
+    const HloInstructionSequence& seq, HloModule* module) {
+  auto builder = HloComputation::Builder("command_buffer");
+
+  absl::Span<HloInstruction* const> instructions =
+      absl::MakeSpan(seq.instructions());
+
+  // A set of instructions that will be moved into command buffer computation.
+  absl::flat_hash_set<HloInstruction*> in_command_buffer(instructions.begin(),
+                                                         instructions.end());
+
+  // The sequence might use results of instructions that are not captured by the
+  // sequence. We pass those results as parameters and map the producers of the
+  // results to their corresponding parameter instructions.
+  absl::flat_hash_map<HloInstruction*, HloParameterInstruction*> parameters;
+
+  // Mapping from command buffer instructions to their clones in the command
+  // buffer computation body.
+  absl::flat_hash_map<HloInstruction*, HloInstruction*> inst_mapping;
+
+  // Maps HLO instructions in the original computation to instructions in the
+  // command buffer: (a) a parameter corresponding to a captured value, or
+  // (b) a cloned instruction corresponding to a command.
+  auto mapped_operands = [&](HloInstruction* instr) {
+    absl::InlinedVector<HloInstruction*, 4> operands;
+    for (HloInstruction* operand : instr->operands()) {
+      if (auto it = inst_mapping.find(operand); it != inst_mapping.end())
+        operands.push_back(it->second);
+    }
+    return operands;
+  };
+
+  // Create parameters in the command buffer computation for captured values.
+  for (HloInstruction* inst : instructions) {
+    for (HloInstruction* operand : inst->operands()) {
+      // We have already mapped this operand to a parameter.
+      if (parameters.contains(operand)) continue;
+
+      // Operand instruction is a part of the command buffer.
+      if (in_command_buffer.contains(operand)) continue;
+
+      // Create a new parameter for value defined outside of a command buffer.
+      int64_t parameter_id = parameters.size();
+      auto* parameter = Cast<HloParameterInstruction>(
+          builder.AddInstruction(HloInstruction::CreateParameter(
+              parameter_id, operand->shape(), "p")));
+
+      parameter->UniquifyName(module);
+      parameter->UniquifyId(module);
+      inst_mapping[operand] = parameters[operand] = parameter;
+    }
+  }
+
+  // Clone commands into the command buffer body with mapped operands.
+  for (HloInstruction* inst : seq.instructions()) {
+    HloCloneContext ctx(inst->GetModule());
+
+    // Cloned instructions should call the same computations as the original
+    // instructions, since the original instructions will be dead code
+    // eliminated.
+    for (HloComputation* called_computation : inst->called_computations()) {
+      // Async computations can only be referenced by a single async chain at
+      // a time. Detach the current chain to let its copy bind to the
+      // computation.
+      if (called_computation->IsAsyncComputation()) {
+        called_computation->RemoveAsyncStart();
+      }
+      ctx.MapComputation(called_computation, called_computation);
+    }
+
+    inst_mapping[inst] = builder.AddInstruction(
+        inst->CloneWithNewOperands(inst->shape(), mapped_operands(inst), &ctx));
+    inst_mapping[inst]->UniquifyId(module);
+  }
+
+  // Convert parameters to command buffer arguments.
+  std::vector<HloInstruction*> arguments(parameters.size());
+  for (auto& [argument, parameter] : parameters) {
+    arguments[parameter->parameter_number()] = argument;
+  }
+
+  // Collect command buffer `results` (instructions replaced in the original
+  // computation) and `returned` (the corresponding instructions cloned into
+  // the command buffer).
+  std::vector<HloInstruction*> results;
+  std::vector<HloInstruction*> returned;
+
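+  // An instruction must be returned from the command buffer if it is the root
+  // of the original computation or has users outside of the command buffer.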
+  auto has_external_users = [&](HloInstruction* inst) {
+    return inst->IsRoot() || absl::c_any_of(inst->users(), [&](auto* user) {
+             return !in_command_buffer.contains(user);
+           });
+  };
+
+  for (HloInstruction* inst : instructions) {
+    if (has_external_users(inst)) {
+      results.push_back(inst);
+      returned.push_back(inst_mapping[inst]);
+    }
+  }
+
+  // If we return multiple results, wrap them into a tuple.
+  if (returned.size() > 1) {
+    HloInstruction* inst =
+        builder.AddInstruction(HloInstruction::CreateTuple(returned));
+    inst->UniquifyName(module);
+    inst->UniquifyId(module);
+  }
+
+  std::unique_ptr<HloComputation> comp = builder.Build();
+  comp->UniquifyName(module);
+  comp->SetUniqueId(comp->root_instruction()->unique_id());
+
+  return CommandBuffer{std::move(arguments), std::move(results),
+                       std::move(comp), std::move(inst_mapping)};
+}
+
+//===----------------------------------------------------------------------===//
+// Rewrites original computation into command buffer call
+//===----------------------------------------------------------------------===//
+
+absl::StatusOr<HloComputation*> CommandBufferScheduling::RewriteCommandBuffer(
+    HloComputation* parent, const HloInstructionSequence& seq,
+    CommandBuffer command_buffer) {
+  if (command_buffer.results.empty())
+    return absl::InternalError("command buffer results must not be empty");
+
+  // If we have more than one result we return them as a tuple and read
+  // individual values using `get-tuple-element` instructions. Otherwise we
+  // simply return the result from the command buffer computation.
+  Shape cmd_buffer_result_shape;
+  bool has_single_result = command_buffer.results.size() == 1;
+
+  if (has_single_result) {
+    cmd_buffer_result_shape = command_buffer.results[0]->shape();
+  } else {
+    absl::InlinedVector<Shape, 4> shapes;
+    shapes.reserve(command_buffer.results.size());
+    for (auto* res : command_buffer.results) shapes.push_back(res->shape());
+    cmd_buffer_result_shape = ShapeUtil::MakeTupleShape(shapes);
+  }
+
+  HloComputation* computation =
+      parent->parent()->AddComputation(std::move(command_buffer.computation),
+                                       /*is_entry=*/false);
+
+  HloInstruction* call = parent->AddInstruction(HloInstruction::CreateCall(
+      cmd_buffer_result_shape, command_buffer.arguments, computation));
+
+  // Replace all users of the original results with the command buffer results.
+  if (has_single_result) {
+    TF_RETURN_IF_ERROR(command_buffer.results[0]->ReplaceAllUsesWith(call));
+  } else {
+    for (int i = 0; i < command_buffer.results.size(); i++) {
+      TF_RETURN_IF_ERROR(
+          command_buffer.results[i]->ReplaceAllUsesWith(parent->AddInstruction(
+              HloInstruction::CreateGetTupleElement(call, i))));
+    }
+  }
+
+  // As we are running after scheduling, we have to keep the schedule valid.
+  HloSchedule& schedule = parent->parent()->schedule();
+
+  // Update schedule to replace the last instruction with a command buffer call.
+  // Removal of the rest of the instructions in the sequence is handled by
+  // schedule update below.
+  HloInstructionSequence& sequence = schedule.GetOrCreateSequence(parent);
+  sequence.replace_instruction(seq.instructions().back(), call);
+
+  // Rebuild original instruction sequence schedule in a newly created
+  // command buffer computation to guarantee that we'll get exactly the same
+  // buffer assignment result as if we were running without command buffers.
+  HloInstructionSequence cmd_buffer_schedule;
+  for (auto* argument : command_buffer.arguments) {
+    cmd_buffer_schedule.push_back(command_buffer.inst_mapping[argument]);
+  }
+  for (auto* inst : seq.instructions()) {
+    cmd_buffer_schedule.push_back(command_buffer.inst_mapping[inst]);
+  }
+  if (!has_single_result) {
+    cmd_buffer_schedule.push_back(computation->root_instruction());
+  }
+  schedule.set_sequence(computation, cmd_buffer_schedule);
+
+  // Forward control dependencies between original instructions to instructions
+  // in the command buffer computation.
+  auto& inst_mapping = command_buffer.inst_mapping;
+  for (HloInstruction* inst : seq.instructions()) {
+    HloInstruction* cmd_inst = inst_mapping[inst];
+
+    // Forward control dependencies to the new instruction inside the command
+    // buffer. If the dependent instruction is not captured by the command
+    // buffer, forward the dependency to the command buffer call instead.
+    for (HloInstruction* predecessor : inst->control_predecessors()) {
+      if (auto it = inst_mapping.find(predecessor); it != inst_mapping.end()) {
+        // If the predecessor is mapped to a parameter instruction, we need to
+        // forward the control dependency to the call operation; otherwise we
+        // add a control dependency between commands in the command buffer.
+        HloInstruction* cmd_predecessor = it->second;
+        if (IsParameter(cmd_predecessor)) {
+          TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(call));
+        } else {
+          TF_RETURN_IF_ERROR(cmd_predecessor->AddControlDependencyTo(cmd_inst));
+        }
+      } else {
+        TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(call));
+      }
+    }
+
+    for (HloInstruction* successor : inst->control_successors()) {
+      if (auto it = inst_mapping.find(successor); it != inst_mapping.end()) {
+        HloInstruction* cmd_successor = it->second;
+        TF_RETURN_IF_ERROR(cmd_inst->AddControlDependencyTo(cmd_successor));
+      } else {
+        TF_RETURN_IF_ERROR(call->AddControlDependencyTo(successor));
+      }
+    }
+
+    TF_RETURN_IF_ERROR(inst->DropAllControlDeps());
+  }
+
+  // Traverse in reverse order, as the original sequence was topologically
+  // sorted and we can't remove instructions that still have users.
+  for (int32_t i = seq.instructions().size() - 1; i >= 0; i--) {
+    TF_RETURN_IF_ERROR(parent->RemoveInstruction(seq.instructions()[i]));
+  }
+
+  return computation;
+}
+
+//===----------------------------------------------------------------------===//
+
+CommandBufferScheduling::CommandBufferScheduling(
+    const se::DeviceDescription& device_description,
+    int32_t gpu_toolkit_version, int32_t gpu_driver_version)
+    : device_description_(device_description),
+      gpu_toolkit_version_(gpu_toolkit_version),
+      gpu_driver_version_(gpu_driver_version) {}
+
+absl::StatusOr<bool> CommandBufferScheduling::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  // We run command buffer scheduling after regular scheduling to guarantee
+  // that command buffers will not change the execution order and buffer
+  // assignment compared to regular execution. Some operations (e.g. async
+  // collectives) can't be captured into command buffers, and forming command
+  // buffers that are too large too early can impact the scheduling of async
+  // operations.
+  if (!module->has_schedule()) return Internal("module is not scheduled");
+
+  const DebugOptions& debug_options = module->config().debug_options();
+
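+  // Collect the set of command types enabled in the debug options.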
+  absl::flat_hash_set<DebugOptions::CommandBufferCmdType> commands;
+  for (auto cmd_type : debug_options.xla_gpu_enable_command_buffer()) {
+    commands.insert(static_cast<DebugOptions::CommandBufferCmdType>(cmd_type));
+  }
+
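+  // Collect legacy (non-FFI) custom call targets that were explicitly allowed
+  // to be captured into command buffers.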
+  absl::flat_hash_set<std::string> legacy_custom_call_targets;
+  for (const auto& target :
+       debug_options.legacy_command_buffer_custom_call_targets()) {
+    legacy_custom_call_targets.insert(target);
+  }
+
+  CommandBufferConfig config{std::move(commands),
+                             std::move(legacy_custom_call_targets),
+                             device_description_};
+
+  // Erase command buffer cmd types that are not supported by the gpu runtime.
+  static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONALS};
+  static constexpr auto kRequireTracing = {
+      DebugOptions::CUBLAS, DebugOptions::CUBLASLT, DebugOptions::CUDNN,
+      DebugOptions::CUSTOM_CALL, DebugOptions::COLLECTIVES};
+
+  auto erase = [&](absl::Span<const DebugOptions::CommandBufferCmdType> cmds) {
+    for (auto cmd : cmds) {
+      if (config.enabled_commands.erase(cmd)) {
+        VLOG(1) << "Removed command buffer support for "
+                << DebugOptions::CommandBufferCmdType_Name(cmd)
+                << " as it's not supported with gpu toolkit version "
+                << gpu_toolkit_version_ << " and driver version "
+                << gpu_driver_version_
+                << ". This might negatively impact peformance. To enable "
+                << DebugOptions::CommandBufferCmdType_Name(cmd)
+                << " support in command buffers use cuda-compat package: "
+#if defined(PLATFORM_GOOGLE)
+                << "set CUDA_COMPAT_LOAD=1 env variable.";
+#else
+                << "https://docs.nvidia.com/deploy/cuda-compatibility/.";
+#endif
+      }
+    }
+  };
+
+  // Check if CUDA/ROCM driver supports required features.
+  auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) {
+    if (std::min(gpu_toolkit_version_, gpu_driver_version_) < 12030) {
+      erase(kRequireTracing);       // cuStreamBeginCaptureToGraph
+      erase(kRequireConditionals);  // on-device control flow
+    }
+  };
+  auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) {
+    erase(kRequireConditionals);  // on-device control flow
+  };
+
+  std::visit(VariantVisitor{erase_cuda, erase_rocm},
+             device_description_.gpu_compute_capability());
+
+  auto order = module->MakeComputationPostOrder();
+  std::reverse(order.begin(), order.end());
+  absl::flat_hash_set<HloComputation*> processed_command_buffers;
+
+  for (HloComputation* comp : order) {
+    // Skip special computations that do not have a lowering to thunks.
+    if (comp->IsFusionComputation() || comp->IsAsyncComputation() ||
+        comp->IsCustomCallComputation())
+      continue;
+
+    // Skip computations that are already part of command buffers.
+    if (processed_command_buffers.contains(comp)) continue;
+
+    TF_RETURN_IF_ERROR(MoveParametersAndConstantsToFront(comp));
+
+    std::vector<HloInstructionSequence> sequences =
+        CollectCommandBufferSequences(
+            module->schedule().sequence(comp), config,
+            debug_options.xla_gpu_graph_min_graph_size());
+
+    for (const HloInstructionSequence& seq : sequences) {
+      TF_ASSIGN_OR_RETURN(CommandBuffer command_buffer,
+                          PrepareCommandBuffer(seq, comp->parent()));
+      TF_ASSIGN_OR_RETURN(
+          HloComputation * command_buffer_computation,
+          RewriteCommandBuffer(comp, seq, std::move(command_buffer)));
+
+      // All computations reachable from a command buffer computation are
+      // nested command buffers (e.g. body computations attached to a while
+      // operation).
+      for (HloComputation* called :
+           command_buffer_computation->MakeEmbeddedComputationsList()) {
+        processed_command_buffers.insert(called);
+      }
+    }
+  }
+  TF_RETURN_IF_ERROR(module->schedule().Update());
+
+  return true;
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h
new file mode 100644
index 0000000..30e0249
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h
@@ -0,0 +1,143 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_COMMAND_BUFFER_SCHEDULING_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_COMMAND_BUFFER_SCHEDULING_H_
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla::gpu {
+
+// Lift fusion instructions to command buffers.
+//
+// Before the pass:
+//   %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
+//     ...
+//   }
+//
+//   ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+//     %a = s32[] parameter(0)
+//     %b = s32[] parameter(1)
+//     ROOT %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop,
+//       calls=%fused_computation
+//   }
+//
+// After the pass:
+//   %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
+//     ...
+//   }
+//
+//   %command_buffer (param_0: s32[], param_1: s32[]) -> s32[] {
+//     %param_0 = s32[] parameter(0)
+//     %param_1 = s32[] parameter(1)
+//     ROOT %fusion = s32[] fusion(s32[] %param_0, s32[] %param_1), kind=kLoop,
+//       calls=%fused_computation
+//   }
+//
+//   ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+//     %a = s32[] parameter(0)
+//     %b = s32[] parameter(1)
+//     ROOT %call = s32[] call(s32[] %a, s32[] %b), to_apply=%command_buffer
+//  }
+//
+// We currently do not have a command_buffer HLO operation, so we'll start with
+// a kCall opcode with an attached HLO computation. We'll consider graduating it
+// to a first-class operation later.
+class CommandBufferScheduling : public HloModulePass {
+ public:
+  struct CommandBufferConfig {
+    // DebugOptions control which commands are enabled. Long term we want to
+    // remove that flag and enable all supported commands by default.
+    absl::flat_hash_set<DebugOptions::CommandBufferCmdType> enabled_commands;
+    absl::flat_hash_set<std::string> enabled_legacy_custom_call_targets;
+    const se::DeviceDescription& device_description;
+  };
+
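+  // `gpu_toolkit_version` and `gpu_driver_version` are encoded as integers
+  // (e.g. 12030 for CUDA 12.3) and are used to disable command types that the
+  // installed toolkit or driver does not support.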
+  CommandBufferScheduling(const se::DeviceDescription& device_description,
+                          int32_t gpu_toolkit_version,
+                          int32_t gpu_driver_version);
+
+  absl::string_view name() const override {
+    return "command-buffer-scheduling";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
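+  // Collects sequences of instructions from the schedule that can be outlined
+  // into command buffers. Sequences with fewer than `min_num_commands`
+  // load-bearing commands are discarded.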
+  static std::vector<HloInstructionSequence> CollectCommandBufferSequences(
+      HloInstructionSequence schedule, const CommandBufferConfig& config,
+      int32_t min_num_commands = 1);
+
+  // Moves kParameter and kConstant instructions in a computation to
+  // the beginning of the computation. This simplifies the construction of
+  // command buffer computations because we don't need to deal with parameters
+  // and constants that have users outside of a command buffer.
+  static absl::Status MoveParametersAndConstantsToFront(
+      HloComputation* computation);
+
+  struct CommandBuffer {
+    // Command buffer arguments (call instruction arguments).
+    std::vector<HloInstruction*> arguments;
+
+    // Command buffer result (call instruction result tuple).
+    std::vector<HloInstruction*> results;
+
+    // Hlo computation corresponding to a command buffer body.
+    std::unique_ptr<HloComputation> computation;
+
+    // Mapping from original instructions to their clones in the command buffer.
+    absl::flat_hash_map<HloInstruction*, HloInstruction*> inst_mapping;
+  };
+
+  // Prepares a command buffer from the instruction sequence. Used values
+  // constructed by instructions outside of the sequence are passed in as
+  // parameters. Results of instructions in the sequence are returned in a tuple
+  // (if the command buffer has a single result we don't wrap it into a tuple).
+  static absl::StatusOr<CommandBuffer> PrepareCommandBuffer(
+      const HloInstructionSequence& seq, HloModule* module);
+
+  // Rewrites the prepared command buffer computation into HLO operations in
+  // the parent computation (calls the command buffer and replaces all users).
+  static absl::StatusOr<HloComputation*> RewriteCommandBuffer(
+      HloComputation* parent, const HloInstructionSequence& seq,
+      CommandBuffer command_buffer);
+
+ private:
+  se::DeviceDescription device_description_;
+  // For NVIDIA GPUs, XLA can be compiled with a CUDA version that is newer
+  // than the version supported by the driver, e.g. we can compile for CUDA
+  // 12.3 but have a 12.1 driver installed. When deciding what command buffer
+  // features we can use, we have to consider both versions.
+  int32_t gpu_toolkit_version_;
+  int32_t gpu_driver_version_;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_COMMAND_BUFFER_SCHEDULING_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc
new file mode 100644
index 0000000..43d0dae
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc
@@ -0,0 +1,1018 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/command_buffer_scheduling.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+class CommandBufferSchedulingTest : public HloTestBase {
+ public:
+  // Use CUDA 12.3 version for testing as it has all the features we rely on.
+  static constexpr int32_t kCudaVersion = 12030;
+
+  const se::DeviceDescription& device_desc() {
+    return backend().default_stream_executor()->GetDeviceDescription();
+  }
+
+  DebugOptions GetDebugOptionsForTest() override {
+    auto debug_options = HloTestBase::GetDebugOptionsForTest();
+    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
+    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONALS);
+    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
+    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
+    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLASLT);
+    debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL);
+    debug_options.set_xla_gpu_graph_min_graph_size(2);
+    return debug_options;
+  }
+};
+
+using CommandBuffer = CommandBufferScheduling::CommandBuffer;
+
+TEST_F(CommandBufferSchedulingTest, SingleCommandBuffer) {
+  const char* hlo = R"(
+      HloModule TestModule, is_scheduled=true
+
+      %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+        %a = s32[] parameter(0)
+        %b = s32[] parameter(1)
+        %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
+        %fusion.1 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.1
+        ROOT %custom-call = s32[] custom-call(s32[] %fusion, s32[] %fusion.1), custom_call_target="some target"
+      })";
+
+  const char* expected = R"(
+// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> (s32[], s32[]) {
+// CHECK:   %[[P0]] = s32[] parameter(0)
+// CHECK:   %[[P1]] = s32[] parameter(1)
+// CHECK:   %fusion = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
+// CHECK:   %fusion.1 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1
+// CHECK:   ROOT %tuple = (s32[], s32[]) tuple(%fusion, %fusion.1)
+// CHECK: }
+//
+// CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+// CHECK:   %a = s32[] parameter(0)
+// CHECK:   %b = s32[] parameter(1)
+// CHECK:   %call = (s32[], s32[]) call(%a, %b), to_apply=%command_buffer
+// CHECK:   %get-tuple-element = s32[] get-tuple-element(%call), index=0
+// CHECK:   %get-tuple-element.1 = s32[] get-tuple-element(%call), index=1
+// CHECK:   ROOT %custom-call = s32[] custom-call(%get-tuple-element, %get-tuple-element.1), custom_call_target="some target"
+// CHECK: })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, MultipleCommandBuffers) {
+  const char* hlo = R"(
+      HloModule TestModule, is_scheduled=true
+
+      %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      %fused_computation.2(param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      %fused_computation.3(param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] {
+        %a = s32[] parameter(0)
+        %b = s32[] parameter(1)
+        %c = (s32[], s32[]) parameter(2)
+        %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
+        %d = s32[] get-tuple-element((s32[], s32[]) %c), index=0
+        %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %d), kind=kLoop, calls=%fused_computation.1
+        %e = s32[] get-tuple-element((s32[], s32[]) %c), index=1
+        %custom-call = s32[] custom-call(s32[] %fusion.1, s32[] %e), custom_call_target="some target"
+        %fusion.2 = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.2
+        %fusion.3 = s32[] fusion(s32[] %custom-call, s32[] %fusion.2), kind=kLoop, calls=%fused_computation.3
+        ROOT %custom-call.1 = s32[] custom-call(s32[] %fusion.3), custom_call_target="some target"
+      })";
+
+  const char* expected = R"(
+// CHECK:  %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[], [[P2:.+]]: (s32[], s32[])) -> s32[] {
+// CHECK:    %[[P0]] = s32[] parameter(0)
+// CHECK:    %[[P1]] = s32[] parameter(1)
+// CHECK:    %[[P2]] = (s32[], s32[]) parameter(2)
+// CHECK:    %[[F0:.+]] = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
+// CHECK:    %[[V0:.+]] = s32[] get-tuple-element(%[[P2]]), index=0
+// CHECK:    ROOT {{.*}} = s32[] fusion(%[[F0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1
+// CHECK:  }
+
+// CHECK:  %command_buffer.2 ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] {
+// CHECK:    %[[P0]] = s32[] parameter(0)
+// CHECK:    %[[P1]] = s32[] parameter(1)
+// CHECK:    %[[F2:.+]] = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.2
+// CHECK:    ROOT {{.*}} = s32[] fusion(%[[P0]], %[[F2]]), kind=kLoop, calls=%fused_computation.3
+// CHECK:  }
+
+// CHECK:  ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] {
+// CHECK:    %a = s32[] parameter(0)
+// CHECK:    %b = s32[] parameter(1)
+// CHECK:    %c = (s32[], s32[]) parameter(2)
+// CHECK:    %[[CMD0:.+]] = s32[] call(%a, %b, %c), to_apply=%command_buffer
+// CHECK:    %e = s32[] get-tuple-element(%c), index=1
+// CHECK:    %[[CALL:.+]] = s32[] custom-call(%[[CMD0]], %e), custom_call_target="some target"
+// CHECK:    %[[CMD1:.+]] = s32[] call(%[[CALL]], %a), to_apply=%command_buffer.2
+// CHECK:    ROOT {{.*}} = s32[] custom-call(%[[CMD1]]), custom_call_target="some target"
+// CHECK:  })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByDone) {
+  const char* hlo = R"(
+    HloModule TestModule, is_scheduled=true
+
+    %add (p0: s32[4], p1: s32[4]) -> s32[4] {
+      %p0 = s32[4] parameter(0)
+      %p1 = s32[4] parameter(1)
+      ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
+    }
+
+    ENTRY %main (a: s32[4]) -> s32[4] {
+      %a = s32[4] parameter(0)
+      %start = s32[4]{0} all-reduce-start(s32[4]{0} %a),
+        replica_groups={{0,1}}, to_apply=%add,
+        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+      ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start)
+    })";
+
+  const char* expected = R"(
+    CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] {
+    CHECK:   %[[P0]] = s32[4]{0} parameter(0)
+    CHECK:   %[[START:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
+    CHECK:   ROOT %[[DONE:.+]] = s32[4]{0} all-reduce-done(%[[START]])
+    CHECK: }
+
+    CHECK: ENTRY %main (a: s32[4]) -> s32[4] {
+    CHECK:   %[[A:.+]] = s32[4]{0} parameter(0)
+    CHECK:   ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
+    CHECK:     to_apply=%command_buffer
+    CHECK: })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, AllGatherStartFollowedByDone) {
+  const char* hlo = R"(
+    HloModule TestModule, is_scheduled=true
+
+    ENTRY %main (a: s32[2]) -> s32[4] {
+      %a = s32[2] parameter(0)
+
+      %start = (s32[2]{0}, s32[4]{0}) all-gather-start(%a),
+        channel_id=555, replica_groups={{0,1}}, dimensions={0},
+        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+
+      ROOT %done = s32[4]{0} all-gather-done(%start)
+    })";
+
+  const char* expected = R"(
+    CHECK: %command_buffer ([[P0:.+]]: s32[2]) -> s32[4] {
+    CHECK:   %[[P0]] = s32[2]{0} parameter(0)
+    CHECK:   %[[START:.+]] = {{.*}} all-gather-start(%[[P0]])
+    CHECK:   ROOT %[[DONE:.+]] = s32[4]{0} all-gather-done(%[[START]])
+    CHECK: }
+
+    CHECK: ENTRY %main (a: s32[2]) -> s32[4] {
+    CHECK:   %[[A:.+]] = s32[2]{0} parameter(0)
+    CHECK:   ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
+    CHECK:     to_apply=%command_buffer
+    CHECK: })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, ReduceScatterStartFollowedByDone) {
+  const char* hlo = R"(
+    HloModule TestModule, is_scheduled=true
+
+    %add (p0: s32[], p1: s32[]) -> s32[] {
+      %p0 = s32[] parameter(0)
+      %p1 = s32[] parameter(1)
+      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+    }
+
+    ENTRY %main (a: s32[4]) -> s32[2] {
+      %a = s32[4] parameter(0)
+
+      %start = ((s32[4]{0}), s32[2]{0}) reduce-scatter-start(%a),
+        channel_id=555, replica_groups={{0,1}}, dimensions={0}, to_apply=add,
+        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+
+      ROOT %done = s32[2]{0} reduce-scatter-done(%start)
+    })";
+
+  const char* expected = R"(
+    CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[2] {
+    CHECK:   %[[P0]] = s32[4]{0} parameter(0)
+    CHECK:   %[[START:.+]] = {{.*}} reduce-scatter-start(%[[P0]])
+    CHECK:   ROOT %[[DONE:.+]] = s32[2]{0} reduce-scatter-done(%[[START]])
+    CHECK: }
+
+    CHECK: ENTRY %main (a: s32[4]) -> s32[2] {
+    CHECK:   %[[A:.+]] = s32[4]{0} parameter(0)
+    CHECK:   ROOT %[[CALL:.+]] = s32[2]{0} call(%[[A]]),
+    CHECK:     to_apply=%command_buffer
+    CHECK: })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByBitcast) {
+  const char* hlo = R"(
+    HloModule TestModule, is_scheduled=true
+
+    %add (p0: s32[4], p1: s32[4]) -> s32[4] {
+      %p0 = s32[4] parameter(0)
+      %p1 = s32[4] parameter(1)
+      ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
+    }
+
+    ENTRY %main (a: s32[4]) -> s32[4] {
+      %a = s32[4] parameter(0)
+      %start = s32[4]{0} all-reduce-start(s32[4]{0} %a),
+        replica_groups={{0,1}}, to_apply=%add,
+        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+      %bitcast = s32[4] bitcast(s32[4]{0} %a)
+      ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start)
+    })";
+
+  const char* expected = R"(
+    CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] {
+    CHECK:   %[[P0]] = s32[4]{0} parameter(0)
+    CHECK:   %[[START:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
+    CHECK:   %[[BITCAST:.+]] = s32[4]{0} bitcast(%[[P0]])
+    CHECK:   ROOT %[[DONE:.+]] = s32[4]{0} all-reduce-done(%[[START]])
+    CHECK: }
+
+    CHECK: ENTRY %main (a: s32[4]) -> s32[4] {
+    CHECK:   %[[A:.+]] = s32[4]{0} parameter(0)
+    CHECK:   ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
+    CHECK:     to_apply=%command_buffer
+    CHECK: })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedAllReduceStart) {
+  const char* hlo = R"(
+    HloModule TestModule, is_scheduled=true
+
+    %add (p0: s32[4], p1: s32[4]) -> s32[4] {
+      %p0 = s32[4] parameter(0)
+      %p1 = s32[4] parameter(1)
+      ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
+    }
+
+    ENTRY %main (a: s32[4]) -> s32[4] {
+      %a = s32[4] parameter(0)
+      %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
+        replica_groups={{0,1}}, to_apply=%add,
+        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+      %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
+        replica_groups={{0,1}}, to_apply=%add,
+        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+      %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1)
+      ROOT %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2)
+    })";
+
+  const char* expected = R"(
+    CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] {
+    CHECK:   %[[P0]] = s32[4]{0} parameter(0)
+    CHECK:   %[[START1:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
+    CHECK:   %[[START2:.+]] = s32[4]{0} all-reduce-start(%[[P0]])
+    CHECK:   %[[DONE1:.+]] = s32[4]{0} all-reduce-done(%[[START1]])
+    CHECK:   ROOT %[[DONE2:.+]] = s32[4]{0} all-reduce-done(%[[START2]])
+    CHECK: }
+
+    CHECK: ENTRY %main (a: s32[4]) -> s32[4] {
+    CHECK:   %[[A:.+]] = s32[4]{0} parameter(0)
+    CHECK:   ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]),
+    CHECK:     to_apply=%command_buffer
+    CHECK: })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, DoNotCaptureUnmatchedAsyncDone) {
+  const char* hlo = R"(
+    HloModule TestModule, is_scheduled=true
+
+    %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] {
+      %p0 = s32[] parameter(0)
+      %p1 = s32[] parameter(1)
+      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+    }
+
+    %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
+      %p0 = s32[] parameter(0)
+      %p1 = s32[] parameter(1)
+      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+    }
+
+    %add (p0: s32[4], p1: s32[4]) -> s32[4] {
+      %p0 = s32[4] parameter(0)
+      %p1 = s32[4] parameter(1)
+      ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1)
+    }
+
+    ENTRY %main (a: s32[4], b:s32[]) -> s32[] {
+      %a = s32[4] parameter(0)
+      %b = s32[] parameter(1)
+      %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
+        replica_groups={{0,1}}, to_apply=%add,
+        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+      %c = s32[] custom-call(), custom_call_target="target"
+      %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a),
+        replica_groups={{0,1}}, to_apply=%add,
+        backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}}
+      %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1)
+      %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2)
+      %fusion = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation
+      ROOT %fusion.1 = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation.1
+    })";
+
+  const char* expected = R"(
+    CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] {
+    CHECK:   %[[P0]] = s32[] parameter(0)
+    CHECK:   %[[P1]] = s32[] parameter(1)
+    CHECK:   %fusion = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
+    CHECK:   ROOT %fusion.1 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1
+    CHECK: }
+
+    CHECK: ENTRY %main (a: s32[4], b: s32[]) -> s32[] {
+    CHECK:   %[[A:.+]] = s32[4]{0} parameter(0)
+    CHECK:   %[[B:.+]] = s32[] parameter(1)
+    CHECK:   %[[START1:.+]] = s32[4]{0} all-reduce-start(%[[A]])
+    CHECK:   %[[C:.+]] = s32[] custom-call()
+    CHECK:   %[[START2:.+]] = s32[4]{0} all-reduce-start(%[[A]])
+    CHECK:   %[[DONE1:.+]] = s32[4]{0} all-reduce-done(%[[START1]])
+    CHECK:   %[[DONE2:.+]] = s32[4]{0} all-reduce-done(%[[START2]])
+    CHECK:   %call = s32[] call(%b, %c), to_apply=%command_buffer
+    CHECK: })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, CollectCommandBufferSequence) {
+  const char* hlo = R"(
+      HloModule TestModule, is_scheduled=true
+
+      %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      %fused_computation.2(param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      %fused_computation.3(param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] {
+        %a = s32[] parameter(0)
+        %b = s32[] parameter(1)
+        %c = (s32[], s32[]) parameter(2)
+        %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
+        %d = s32[] get-tuple-element((s32[], s32[]) %c), index=0
+        %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %d), kind=kLoop, calls=%fused_computation.1
+        %e = s32[] get-tuple-element((s32[], s32[]) %c), index=1
+        %custom-call = s32[] custom-call(s32[] %fusion.1, s32[] %e), custom_call_target="some target"
+        %fusion.2 = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.2
+        ROOT %fusion.3 = s32[] fusion(s32[] %custom-call, s32[] %fusion.2), kind=kLoop, calls=%fused_computation.3
+      })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(hlo));
+
+  HloInstructionSequence seq;
+  for (HloInstruction* x : module->entry_computation()->instructions()) {
+    seq.push_back(x);
+  }
+  EXPECT_EQ(seq.size(), 10);
+
+  CommandBufferScheduling::CommandBufferConfig config{
+      {DebugOptions::FUSION}, {}, device_desc()};
+
+  std::vector<HloInstructionSequence> command_buffer_sequences =
+      CommandBufferScheduling::CollectCommandBufferSequences(seq, config);
+  EXPECT_EQ(command_buffer_sequences.size(), 2);
+
+  std::vector<HloInstruction*> seq_0 =
+      command_buffer_sequences[0].instructions();
+  EXPECT_EQ(seq_0.size(), 3);
+  EXPECT_EQ(seq_0[0]->opcode(), HloOpcode::kFusion);
+  EXPECT_EQ(seq_0[1]->opcode(), HloOpcode::kGetTupleElement);
+  EXPECT_EQ(seq_0[2]->opcode(), HloOpcode::kFusion);
+
+  std::vector<HloInstruction*> seq_1 =
+      command_buffer_sequences[1].instructions();
+  EXPECT_EQ(seq_1.size(), 2);
+  EXPECT_EQ(seq_1[0]->opcode(), HloOpcode::kFusion);
+  EXPECT_EQ(seq_1[1]->opcode(), HloOpcode::kFusion);
+}
+
+TEST_F(CommandBufferSchedulingTest, MoveParametersToFront) {
+  const char* hlo = R"(
+      HloModule TestModule, is_scheduled=true
+
+      %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      ENTRY %main (a: s32[], b: s32[], c: s32[]) -> s32[] {
+        %a = s32[] parameter(0)
+        %b = s32[] parameter(1)
+        %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
+        %c = s32[] parameter(2)
+        ROOT %fusion.1 = s32[] fusion(s32[] %a, s32[] %c), kind=kLoop, calls=%fused_computation.1
+      })";
+
+  const char* expected = R"(
+// CHECK: ENTRY %main (a: s32[], b: s32[], c: s32[]) -> s32[] {
+// CHECK:   %a = s32[] parameter(0)
+// CHECK:   %b = s32[] parameter(1)
+// CHECK:   %c = s32[] parameter(2)
+// CHECK:   %fusion = s32[] fusion(%a, %b), kind=kLoop, calls=%fused_computation
+// CHECK:   ROOT %fusion.1 = s32[] fusion(%a, %c), kind=kLoop, calls=%fused_computation.1
+// CHECK: })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(hlo));
+  TF_ASSERT_OK(CommandBufferScheduling::MoveParametersAndConstantsToFront(
+      module->entry_computation()));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool filecheck_matches,
+      RunFileCheck(
+          module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
+          expected));
+  EXPECT_TRUE(filecheck_matches);
+}
+
+TEST_F(CommandBufferSchedulingTest, PrepareCommandBuffer) {
+  const char* hlo = R"(
+      HloModule TestModule, is_scheduled=true
+
+      %fused_computation(param_0: s32[], param_1: s32[]) -> (s32[], s32[]) {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %tuple.1 = (s32[], s32[]) tuple(s32[] %p0, s32[] %p1)
+      }
+
+      %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+        %a = s32[] parameter(0)
+        %b = s32[] custom-call(), custom_call_target="target"
+        %fusion = (s32[], s32[]) fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation
+        %d = s32[] get-tuple-element((s32[], s32[]) %fusion), index=0
+        %fusion.1 = s32[] fusion(s32[] %a, s32[] %d), kind=kLoop, calls=%fused_computation.1
+        ROOT %custom-call = s32[] custom-call(s32[] %fusion.1, s32[] %d), custom_call_target="some target"
+      })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule(hlo));
+
+  EXPECT_EQ(module->entry_computation()->instruction_count(), 6);
+  std::vector<HloInstruction*> instructions;
+  HloInstructionSequence seq;
+  for (HloInstruction* inst : module->entry_computation()->instructions()) {
+    if (inst->opcode() == HloOpcode::kFusion ||
+        inst->opcode() == HloOpcode::kGetTupleElement) {
+      seq.push_back(inst);
+    }
+    instructions.push_back(inst);
+  }
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      CommandBuffer command_buffer,
+      CommandBufferScheduling::PrepareCommandBuffer(seq, module.get()));
+  HloComputation* computation = module->AddComputation(
+      std::move(command_buffer.computation), /*is_entry=*/false);
+
+  const char* expected = R"(
+// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> (s32[], s32[]) {
+// CHECK:  %[[P0]] = s32[] parameter(0)
+// CHECK:  %[[P1]] = s32[] parameter(1)
+// CHECK:  %fusion = (s32[], s32[]) fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation
+// CHECK:  %[[V0:.+]] = s32[] get-tuple-element(%fusion), index=0
+// CHECK:  %fusion.1 = s32[] fusion(%[[P0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1
+// CHECK:  ROOT {{.*}} = (s32[], s32[]) tuple(%[[V0]], %fusion.1)
+// CHECK:})";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool filecheck_matches,
+      RunFileCheck(computation->ToString(
+                       HloPrintOptions{}.set_print_operand_shape(false)),
+                   expected));
+  EXPECT_TRUE(filecheck_matches);
+
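+  // %a (parameter) and %b (custom-call) are produced outside the captured
+  // sequence, so they become the command buffer arguments; the values used
+  // after the sequence (%d and %fusion.1) become its results.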
+  auto& arguments = command_buffer.arguments;
+  ASSERT_EQ(arguments.size(), 2);
+  EXPECT_EQ(arguments[0], instructions[0]);
+  EXPECT_EQ(arguments[1], instructions[1]);
+
+  auto& results = command_buffer.results;
+  ASSERT_EQ(results.size(), 2);
+  EXPECT_EQ(results[0], instructions[3]);
+  EXPECT_EQ(results[1], instructions[4]);
+}
+
+TEST_F(CommandBufferSchedulingTest, ForwardControlDependencies) {
+  const char* hlo = R"(
+    HloModule TestModule, is_scheduled=true
+
+    %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
+      %p0 = s32[] parameter(0)
+      %p1 = s32[] parameter(1)
+      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+    }
+
+    %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
+      %p0 = s32[] parameter(0)
+      %p1 = s32[] parameter(1)
+      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+    }
+
+    %fused_computation.2 (param_0: s32[], param_1: s32[]) -> s32[] {
+      %p0 = s32[] parameter(0)
+      %p1 = s32[] parameter(1)
+      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+    }
+
+    ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+      %a = s32[] parameter(0)
+      %b = s32[] parameter(1)
+      %custom-call = s32[] custom-call(), custom_call_target="some target"
+      %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation, control-predecessors={%custom-call}
+      %fusion.1 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.1, control-predecessors={%fusion}
+      %custom-call.1 = s32[] custom-call(), custom_call_target="some target"
+      %fusion.2 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%fusion.1}
+      ROOT %custom-call.2 = s32[] custom-call(s32[] %fusion.1, s32[] %fusion.2), custom_call_target="some target"
+    })";
+
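+  // The rewrite should keep %custom-call as a control predecessor of the new
+  // %call, and dependencies on instructions captured inside the command
+  // buffer (%fusion.1) should be remapped to the %call itself.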
+  const char* expected = R"(
+    CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] {
+    CHECK:   %[[P0]] = s32[] parameter(0)
+    CHECK:   %[[P1]] = s32[] parameter(1)
+    CHECK:   %[[F0:.+]] = s32[] fusion(%[[P0]], %[[P1]])
+    CHECK:   ROOT {{.*}} = s32[] fusion(%[[P0]], %[[P1]]), {{.*}} control-predecessors={%[[F0]]}
+    CHECK: }
+
+    CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+    CHECK:   %a = s32[] parameter(0)
+    CHECK:   %b = s32[] parameter(1)
+    CHECK:   %custom-call = s32[] custom-call(), custom_call_target="some target"
+    CHECK:   %call = s32[] call(%a, %b), to_apply=%command_buffer, control-predecessors={%custom-call}
+    CHECK:   %custom-call.1 = s32[] custom-call(), custom_call_target="some target"
+    CHECK:   %[[F3:.+]] = s32[] fusion(%a, %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%call}
+    CHECK:   ROOT %custom-call.2 = s32[] custom-call(%call, %[[F3]]), custom_call_target="some target"
+    CHECK: })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, ForwardControlDependenciesToParams) {
+  const char* hlo = R"(
+    HloModule TestModule, is_scheduled=true
+
+    %fused_computation.0 (p0: s32[], p1: s32[]) -> s32[] {
+      %p0 = s32[] parameter(0)
+      %p1 = s32[] parameter(1)
+      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+    }
+
+    %fused_computation.1 (p0: s32[], p1: s32[]) -> s32[] {
+      %p0 = s32[] parameter(0)
+      %p1 = s32[] parameter(1)
+      ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+    }
+
+    ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+      %a = s32[] parameter(0)
+      %b = s32[] parameter(1)
+      %custom-call = s32[] custom-call(), custom_call_target="some target"
+      %fusion = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.0, control-predecessors={%custom-call}
+      ROOT %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %b), kind=kLoop, calls=%fused_computation.1
+    })";
+
+  const char* expected = R"(
+    CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+    CHECK:   %a = s32[] parameter(0)
+    CHECK:   %b = s32[] parameter(1)
+    CHECK:   %[[CUSTOM_CALL:.+]] = s32[] custom-call(), custom_call_target="some target"
+    CHECK:   ROOT {{.*}} call(%[[CUSTOM_CALL]], %a, %b), to_apply=%command_buffer, control-predecessors={%[[CUSTOM_CALL]]}
+    CHECK: })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, WhileNotCommand) {
+  const char* hlo = R"(
+    HloModule TestModule, is_scheduled=true
+
+    %fused_computation (param_0: f32[1]) -> f32[1] {
+      %param_0 = f32[1]{0} parameter(0)
+      ROOT %copy.5 = f32[1]{0} copy(f32[1]{0} %param_0)
+    }
+
+    %fused_computation.1 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] {
+      %param_0.1 = f32[1]{0} parameter(0)
+      %param_1 = f32[1]{0} parameter(1)
+      ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1)
+    }
+
+    %fused_computation.2 (param_0.2: f32[1], param_1.1: f32[1]) -> pred[1] {
+      %param_0.2 = f32[1]{0} parameter(0)
+      %param_1.1 = f32[1]{0} parameter(1)
+      ROOT %compare.3 = pred[1]{0} compare(f32[1]{0} %param_0.2, f32[1]{0} %param_1.1), direction=LT
+    }
+
+    %fused_computation.3 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] {
+      %param_0.1 = f32[1]{0} parameter(0)
+      %param_1 = f32[1]{0} parameter(1)
+      ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1)
+    }
+
+    %body (Arg_.3: f32[1]) -> f32[1] {
+      %constant_4 = f32[1]{0} constant({1})
+      %Arg_.3 = f32[1]{0} parameter(0)
+      %custom-call = s32[] custom-call(), custom_call_target="some target"
+      %add = f32[1]{0} fusion(f32[1]{0} %Arg_.3, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.1, control-predecessors={%custom-call}
+      ROOT %wrapped_add.1 = f32[1]{0} fusion(f32[1]{0} %add, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.3, control-predecessors={%custom-call}
+    }
+
+    %cond (Arg_.11: f32[1]) -> pred[] {
+      %constant = f32[1]{0} constant({100})
+      %Arg_.11 = f32[1]{0} parameter(0)
+      %wrapped_compare.2 = pred[1]{0} fusion(f32[1]{0} %Arg_.11, f32[1]{0} %constant), kind=kLoop, calls=%fused_computation.2
+      ROOT %bitcast = pred[] bitcast(pred[1]{0} %wrapped_compare.2)
+    }
+
+    ENTRY %main.18 (Arg_0.1: f32[1]) -> f32[] {
+      %Arg_0.1 = f32[1]{0} parameter(0), sharding={replicated}
+      %wrapped_copy.4 = f32[1]{0} fusion(f32[1]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation
+      %while.16 = f32[1]{0} while(f32[1]{0} %wrapped_copy.4), condition=%cond, body=%body
+      ROOT %bitcast.1 = f32[] bitcast(f32[1]{0} %while.16)
+    })";
+
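+  // The loop body contains a custom-call, so the while itself is not captured;
+  // only the two fusions inside the body are wrapped into a command buffer.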
+  const char* expected = R"(
+    CHECK: %command_buffer ([[P0:.+]]: f32[1], [[P1:.+]]: f32[1]) -> f32[1] {
+    CHECK:   %[[P0]] = f32[1]{0} parameter(0)
+    CHECK:   %[[P1]] = f32[1]{0} parameter(1)
+    CHECK:   %[[ADD:.*]] = f32[1]{0} fusion(%[[P0]], %[[P1]]), kind=kLoop
+    CHECK:   ROOT {{.*}} = f32[1]{0} fusion(%[[ADD]], %[[P1]]), kind=kLoop
+    CHECK: }
+
+    CHECK: %[[BODY:[a-z_0-9.]+]] ([[P0:.+]]: f32[1]) -> f32[1] {
+    CHECK:   %[[C1:.*]] = f32[1]{0} constant({1})
+    CHECK:   %[[P0]] = f32[1]{0} parameter(0)
+    CHECK:   %[[CC:.*]] = s32[] custom-call(), custom_call_target="some target"
+    CHECK:   ROOT %call = f32[1]{0} call(%[[P0]], %[[C1]]), to_apply=%command_buffer, control-predecessors={%[[CC]]}
+    CHECK: }
+
+    CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: f32[1]) -> f32[] {
+    CHECK:   %[[ARG0]] = f32[1]{0} parameter(0)
+    CHECK:   %[[COPY:.*]] = f32[1]{0} fusion(%[[ARG0]]), kind=kLoop
+    CHECK:   %[[WHILE:.*]] = f32[1]{0} while(%[[COPY]]), condition=%[[COND:[a-z_0-9.]+]], body=%[[BODY]]
+    CHECK:   ROOT %[[BC:.+]] = f32[] bitcast(%[[WHILE]])
+    CHECK: })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, While) {
+  const char* hlo = R"(
+    HloModule TestModule, is_scheduled=true
+
+    %fused_computation (param_0: f32[1]) -> f32[1] {
+      %param_0 = f32[1]{0} parameter(0)
+      ROOT %copy.5 = f32[1]{0} copy(f32[1]{0} %param_0)
+    }
+
+    %fused_computation.1 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] {
+      %param_0.1 = f32[1]{0} parameter(0)
+      %param_1 = f32[1]{0} parameter(1)
+      ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1)
+    }
+
+    %fused_computation.2 (param_0.2: f32[1], param_1.1: f32[1]) -> pred[1] {
+      %param_0.2 = f32[1]{0} parameter(0)
+      %param_1.1 = f32[1]{0} parameter(1)
+      ROOT %compare.3 = pred[1]{0} compare(f32[1]{0} %param_0.2, f32[1]{0} %param_1.1), direction=LT
+    }
+
+    %body (Arg_.3: f32[1]) -> f32[1] {
+      %constant_4 = f32[1]{0} constant({1})
+      %Arg_.3 = f32[1]{0} parameter(0)
+      ROOT %wrapped_add.1 = f32[1]{0} fusion(f32[1]{0} %Arg_.3, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.1
+    }
+
+    %cond (Arg_.11: f32[1]) -> pred[] {
+      %constant = f32[1]{0} constant({100})
+      %Arg_.11 = f32[1]{0} parameter(0)
+      %wrapped_compare.2 = pred[1]{0} fusion(f32[1]{0} %Arg_.11, f32[1]{0} %constant), kind=kLoop, calls=%fused_computation.2
+      ROOT %bitcast = pred[] bitcast(pred[1]{0} %wrapped_compare.2)
+    }
+
+    ENTRY %main.18 (Arg_0.1: f32[1]) -> f32[] {
+      %Arg_0.1 = f32[1]{0} parameter(0), sharding={replicated}
+      %wrapped_copy.4 = f32[1]{0} fusion(f32[1]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation
+      %while.16 = f32[1]{0} while(f32[1]{0} %wrapped_copy.4), condition=%cond, body=%body
+      ROOT %bitcast.1 = f32[] bitcast(f32[1]{0} %while.16)
+    })";
+
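+  // Both the condition and the body consist only of fusions, so the while
+  // itself (together with the leading copy fusion) is captured into the
+  // command buffer.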
+  const char* expected = R"(
+    CHECK: %command_buffer ([[P0:.+]]: f32[1]) -> f32[1] {
+    CHECK:   %[[P0]] = f32[1]{0} parameter(0)
+    CHECK:   %[[COPY:.*]] = f32[1]{0} fusion(%[[P0]]), kind=kLoop
+    CHECK:   ROOT {{.*}} = f32[1]{0} while(%[[COPY]]), condition=%[[COND:[a-z_0-9.]+]], body=%[[BODY:[a-z_0-9.]+]]
+    CHECK: }
+
+    CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: f32[1]) -> f32[] {
+    CHECK:   %[[ARG0]] = f32[1]{0} parameter(0)
+    CHECK:   %call = f32[1]{0} call(%[[ARG0]]), to_apply=%command_buffer
+    CHECK:   ROOT %[[BC:.+]] = f32[] bitcast(%call)
+    CHECK: })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, Conditional) {
+  const char* hlo = R"(
+    HloModule TestModule, is_scheduled=true
+
+    %fused_computation.1 (param_0.2: s32[5]) -> s32[5] {
+      %param_0.2 = s32[5]{0} parameter(0)
+      ROOT %negate.2 = s32[5]{0} negate(s32[5]{0} %param_0.2)
+    }
+
+    %region_0.7 (Arg_.8: s32[5]) -> (s32[5]) {
+      %Arg_.8 = s32[5]{0} parameter(0)
+      %wrapped_negate.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.8), kind=kLoop, calls=%fused_computation.1
+      ROOT %tuple.3 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_negate.1)
+    }
+
+    %fused_computation.2 (param_0.3: s32[5]) -> s32[5] {
+      %param_0.3 = s32[5]{0} parameter(0)
+      ROOT %not.2 = s32[5]{0} not(s32[5]{0} %param_0.3)
+    }
+
+    %region_1.10 (Arg_.11: s32[5]) -> (s32[5]) {
+      %Arg_.11 = s32[5]{0} parameter(0)
+      %wrapped_not.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.11), kind=kLoop, calls=%fused_computation.2
+      ROOT %tuple.4 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_not.1)
+    }
+
+    %fused_computation.3 (param_0.4: s32[5]) -> s32[5] {
+      %param_0.4 = s32[5]{0} parameter(0)
+      ROOT %multiply.2 = s32[5]{0} multiply(s32[5]{0} %param_0.4, s32[5]{0} %param_0.4)
+    }
+
+    %region_2.13 (Arg_.14: s32[5]) -> (s32[5]) {
+      %Arg_.14 = s32[5]{0} parameter(0)
+      %wrapped_multiply.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.14), kind=kLoop, calls=%fused_computation.3
+      ROOT %tuple.5 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_multiply.1)
+    }
+
+    %fused_computation (param_0.1: s64[]) -> s32[] {
+      %constant_1 = s32[] constant(0)
+      %param_0.1 = s64[] parameter(0)
+      %convert.2 = s32[] convert(s64[] %param_0.1)
+      %constant_0 = s32[] constant(2)
+      ROOT %clamp.2 = s32[] clamp(s32[] %constant_1, s32[] %convert.2, s32[] %constant_0)
+    }
+
+    ENTRY %main.17 (Arg_0.1: s64[], Arg_1.2: s32[5]) -> s32[5] {
+      %Arg_0.1 = s64[] parameter(0), sharding={replicated}
+      %fusion = s32[] fusion(s64[] %Arg_0.1), kind=kLoop, calls=%fused_computation
+      %Arg_1.2 = s32[5]{0} parameter(1), sharding={replicated}
+      %conditional.16.clone = (s32[5]{0}) conditional(s32[] %fusion, s32[5]{0} %Arg_1.2, s32[5]{0} %Arg_1.2, s32[5]{0} %Arg_1.2), branch_computations={%region_0.7, %region_1.10, %region_2.13}
+      ROOT %get-tuple-element = s32[5]{0} get-tuple-element((s32[5]{0}) %conditional.16.clone), index=0
+    })";
+
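+  // All branch computations are fusion-only, so the conditional and the
+  // fusion computing its branch index are captured into the command buffer.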
+  const char* expected = R"(
+    CHECK: %command_buffer ([[P0:.+]]: s64[], [[P1:.+]]: s32[5]) -> (s32[5]) {
+    CHECK:   %[[P0]] = s64[] parameter(0)
+    CHECK:   %[[P1]] = s32[5]{0} parameter(1)
+    CHECK:   %[[FUSION:.*]] = s32[] fusion(%[[P0]]), kind=kLoop
+    CHECK:   ROOT {{.*}} = (s32[5]{0}) conditional(%[[FUSION]], %[[P1]], %[[P1]], %[[P1]]), branch_computations={%[[B1:[a-z_0-9.]+]], %[[B2:[a-z_0-9.]+]], %[[B3:[a-z_0-9.]+]]}
+    CHECK: }
+
+    CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: s64[], [[ARG1:.+]]: s32[5]) -> s32[5] {
+    CHECK:   %[[ARG0]] = s64[] parameter(0)
+    CHECK:   %[[ARG1]] = s32[5]{0} parameter(1)
+    CHECK:   %call = (s32[5]{0}) call(%[[ARG0]], %[[ARG1]]), to_apply=%command_buffer
+    CHECK:   ROOT %[[GEP:.+]] = s32[5]{0} get-tuple-element(%call)
+    CHECK: })";
+
+  RunAndFilecheckHloRewrite(
+      hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      expected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+TEST_F(CommandBufferSchedulingTest, CuDnnFusionGraphCaptureWorks) {
+  const std::string kHloText = R"(
+HloModule m, is_scheduled=true
+
+fusion0 {
+  p0 = f32[64,64] parameter(0)
+  p1 = f32[64,64] parameter(1)
+  ROOT d = f32[64,64] dot(p0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+fusion1 {
+  p0 = f32[64,64] parameter(0)
+  p1 = f32[64,64] parameter(1)
+  ROOT d = f32[64,64] dot(p0, p1),
+    lhs_contracting_dims={0}, rhs_contracting_dims={1}
+}
+
+fusion_a {
+  p0 = f32[64,64] parameter(0)
+  p1 = f32[64,64] parameter(1)
+  ROOT a = f32[64,64] add(p0, p1)
+}
+
+ENTRY e {
+  p0 = f32[64,64] parameter(0)
+  p1 = f32[64,64] parameter(1)
+  d0 = f32[64,64] fusion(p0, p1), kind=kCustom,
+    calls=fusion0,
+    backend_config={"fusion_backend_config": {"kind":"__cudnn$fusion"}}
+  a = f32[64,64] fusion(d0, d0), kind=kLoop, calls=fusion_a
+  ROOT d1 = f32[64,64] fusion(a, p1), kind=kCustom,
+    calls=fusion1,
+    backend_config={"fusion_backend_config": {"kind":"__cudnn$fusion"}}
+})";
+
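+  // Both __cudnn$fusion custom fusions and the loop fusion between them are
+  // expected to end up in a single command buffer call.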
+  const std::string kExpected = R"(
+; CHECK: ENTRY
+; CHECK-NEXT: parameter
+; CHECK-NEXT: parameter
+; CHECK-NEXT: ROOT
+; CHECK-SAME: call(
+; CHECK-SAME: to_apply=%command_buffer
+})";
+
+  RunAndFilecheckHloRewrite(
+      kHloText,
+      CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
+      kExpected, [](HloModule* module) {
+        EXPECT_TRUE(module->has_schedule());
+        TF_CHECK_OK(module->schedule().Verify());
+      });
+}
+
+}  // namespace
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc
new file mode 100644
index 0000000..f072a91
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc
@@ -0,0 +1,461 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/conv_padding_legalization.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/shape_inference.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/window_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
+  CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
+        conv.custom_call_target() ==
+            kCudnnConvBiasActivationForwardCallTarget ||
+        conv.custom_call_target() == kCudnnConvForwardGraphCallTarget);
+  return window_util::HasSymmetricPadding(conv.window()) &&
+         !window_util::HasNegativePadding(conv.window()) &&
+         !window_util::HasDilation(conv.window());
+}
+
+// If the (positive and negative) padding on the input operand of a convolution
+// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and
+// dilation), returns kPad and/or kSlice instructions that explicitly apply the
+// padding; otherwise returns the original input operand. When there is both
+// positive padding (including dilation) and negative padding, we insert both
+// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved
+// into a kPad or kSlice op.
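+// For example, a spatial dimension with padding_low=2 and padding_high=-1 is
+// handled by a kPad with edge_padding_low=2 followed by a kSlice that trims
+// one element from the high end, after which the window padding is reset to
+// zero.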
+HloInstruction* MaybePaddedAndSlicedInput(
+    Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums,
+    HloInstruction* input) {
+  HloComputation* computation = input->parent();
+  if (!window_util::HasSymmetricPadding(*conv_window) ||
+      window_util::HasBaseDilation(*conv_window)) {
+    // If padding is uneven or has dilation, we insert a kPad instruction that
+    // applies positive padding and dilation.
+    //
+    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
+    // moving all the padding into an explicit pad op, we should keep as much
+    // padding inside of cudnn as possible, on the assumption that padding
+    // within cudnn is basically free, whereas a kPad's cost increases as the
+    // amount of padding increases.
+    PaddingConfig padding_config =
+        MakeNoPaddingConfig(input->shape().dimensions_size());
+    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
+      int64_t dim = conv_dnums.input_spatial_dimensions(i);
+      if (conv_window->dimensions(i).padding_low() > 0) {
+        padding_config.mutable_dimensions(dim)->set_edge_padding_low(
+            conv_window->dimensions(i).padding_low());
+        conv_window->mutable_dimensions(i)->set_padding_low(0);
+      }
+      if (conv_window->dimensions(i).padding_high() > 0) {
+        padding_config.mutable_dimensions(dim)->set_edge_padding_high(
+            conv_window->dimensions(i).padding_high());
+        conv_window->mutable_dimensions(i)->set_padding_high(0);
+      }
+      if (conv_window->dimensions(i).base_dilation() != 1) {
+        padding_config.mutable_dimensions(dim)->set_interior_padding(
+            conv_window->dimensions(i).base_dilation() - 1);
+        conv_window->mutable_dimensions(i)->set_base_dilation(1);
+      }
+    }
+    PrimitiveType element_type = input->shape().element_type();
+    HloInstruction* padding = computation->AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
+    input =
+        MakePadHlo(input, padding, padding_config, &input->metadata()).value();
+  }
+
+  if (window_util::HasNegativePadding(*conv_window)) {
+    // If the window has negative padding, insert a kSlice that explicitly
+    // applies negative padding.
+    //
+    // For each dimension, initialize the start index to 0 and the limit index
+    // to the size of that dimension.
+    std::vector<int64_t> start_indices(input->shape().dimensions_size(), 0);
+    std::vector<int64_t> limit_indices(input->shape().dimensions().begin(),
+                                       input->shape().dimensions().end());
+    std::vector<int64_t> strides(input->shape().dimensions_size(), 1);
+    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
+      int64_t dim = conv_dnums.input_spatial_dimensions(i);
+      // If dimension "dim" has negative padding, increase the start index or
+      // decrement the limit index by the amount of negative padding.
+      if (conv_window->dimensions(i).padding_low() < 0) {
+        start_indices[dim] += -conv_window->dimensions(i).padding_low();
+        conv_window->mutable_dimensions(i)->set_padding_low(0);
+      }
+      if (conv_window->dimensions(i).padding_high() < 0) {
+        limit_indices[dim] -= -conv_window->dimensions(i).padding_high();
+        conv_window->mutable_dimensions(i)->set_padding_high(0);
+      }
+    }
+
+    input = MakeSliceHlo(input, start_indices, limit_indices, strides).value();
+  }
+
+  return input;
+}
+
+// If the padding on the kernel operand of a convolution can't be folded into a
+// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
+// explicitly applies the padding; otherwise returns the original kernel
+// operand.
+HloInstruction* MaybePaddedKernel(const Window& conv_window,
+                                  const ConvolutionDimensionNumbers& conv_dnums,
+                                  HloInstruction* kernel) {
+  if (!window_util::HasWindowDilation(conv_window)) {
+    return kernel;
+  }
+
+  // Build the padding config of the kPad to be inserted.
+  PaddingConfig padding_config;
+  padding_config.mutable_dimensions()->Reserve(
+      kernel->shape().dimensions_size());
+  for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
+    padding_config.add_dimensions();
+  }
+  for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
+    int64_t dim = conv_dnums.kernel_spatial_dimensions(i);
+    padding_config.mutable_dimensions(dim)->set_interior_padding(
+        conv_window.dimensions(i).window_dilation() - 1);
+  }
+
+  HloComputation* computation = kernel->parent();
+  PrimitiveType element_type = kernel->shape().element_type();
+  HloInstruction* padding = computation->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
+  return MakePadHlo(kernel, padding, padding_config, &kernel->metadata())
+      .value();
+}
+}  // namespace
+
+bool ConvPaddingLegalization::CanonicalizeForwardConvolution(
+    HloInstruction* conv) {
+  if (IsForwardConvolutionCanonical(*conv)) {
+    return false;
+  }
+
+  // Insert slices and/or pads between the convolution and its input and/or
+  // kernel operand.
+  Window new_conv_window = conv->window();
+  HloInstruction* new_input = MaybePaddedAndSlicedInput(
+      &new_conv_window, conv->convolution_dimension_numbers(),
+      conv->mutable_operand(0));
+  HloInstruction* new_kernel =
+      MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(),
+                        conv->mutable_operand(1));
+
+  // Remove the window dilation from the convolution's window field. The
+  // dilation is now applied explicitly by the kPad inserted in
+  // MaybePaddedKernel().
+  for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
+    WindowDimension* dim = new_conv_window.mutable_dimensions(i);
+
+    // The size of the kernel may have changed so update the Window to match.
+    dim->set_size(new_kernel->shape().dimensions(
+        conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
+    dim->set_window_dilation(1);
+  }
+
+  // The conv CustomCall returns a tuple (conv_result, scratch_buffer); the
+  // clone below keeps that tuple shape unchanged.
+  VLOG(1) << "Canonicalizing forward conv";
+  std::vector<HloInstruction*> operands(conv->operands().begin(),
+                                        conv->operands().end());
+  operands[0] = new_input;
+  operands[1] = new_kernel;
+  auto new_conv = conv->parent()->AddInstruction(
+      conv->CloneWithNewOperands(conv->shape(), operands));
+  new_conv->set_window(new_conv_window);
+  VLOG(1) << "Replacing:\n  " << conv->ToString() << "\nwith:\n  "
+          << new_conv->ToString();
+  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
+  return true;
+}
+
+namespace {
+void IncreasePaddingLowBy(int64_t delta, WindowDimension* window_dim) {
+  window_dim->set_padding_low(window_dim->padding_low() + delta);
+}
+
+void IncreasePaddingHighBy(int64_t delta, WindowDimension* window_dim) {
+  window_dim->set_padding_high(window_dim->padding_high() + delta);
+}
+}  // namespace
+
+bool ConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
+    HloInstruction* backward_conv) {
+  CHECK_EQ(backward_conv->custom_call_target(),
+           kCudnnConvBackwardFilterCallTarget);
+  if (window_util::HasSymmetricPadding(backward_conv->window())) {
+    return false;
+  }
+
+  // A backward filter convolution with uneven padding can be canonicalized to
+  // one with even padding by padding the activations (input) beforehand. For
+  // example,
+  //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
+  // is equivalent to
+  //   ABCD0 = Pad(ABCD, padding_high=1)
+  //   BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
+  // We choose the lesser of padding_low and padding_high as the new padding.
+  HloInstruction* input = backward_conv->mutable_operand(0);
+  Window new_backward_conv_window = backward_conv->window();
+  // input_padding_config is the config of the kPad to be inserted.
+  PaddingConfig input_padding_config =
+      MakeNoPaddingConfig(input->shape().rank());
+  ConvolutionDimensionNumbers backward_conv_dnums =
+      backward_conv->convolution_dimension_numbers();
+  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
+    int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
+    int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
+    if (padding_low < 0 || padding_high < 0) {
+      // TODO(b/32744257): The following canonicalization wouldn't remove
+      // negative padding in a backward convolution, and would therefore cause
+      // cuDNN convolution (which doesn't support negative padding) to fail.
+      return false;
+    }
+    // Compute the new, even padding for the backward conv operation.
+    int64_t new_conv_padding = std::min(padding_low, padding_high);
+    int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
+    input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
+        padding_low - new_conv_padding);
+    input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
+        padding_high - new_conv_padding);
+
+    // Since we move some padding from the backward convolution to the kPad, we
+    // need to accordingly reduce the padding amount of the backward convolution
+    // and its inner forward convolution.
+    auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
+    new_dim->set_padding_low(new_conv_padding);
+    new_dim->set_padding_high(new_conv_padding);
+  }
+
+  // Create a new backward convolution replacing the old one.
+  HloComputation* computation = backward_conv->parent();
+  HloInstruction* output = backward_conv->mutable_operand(1);
+  HloInstruction* padding =
+      computation->AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::Zero(input->shape().element_type())));
+  HloInstruction* padded_input =
+      MakePadHlo(input, padding, input_padding_config).value();
+
+  // The shape of the backward_conv CustomCall is a tuple (conv_result,
+  // scratch_buffer); the clone below keeps that tuple shape unchanged.
+  HloInstruction* new_backward_conv =
+      computation->AddInstruction(backward_conv->CloneWithNewOperands(
+          backward_conv->shape(), {padded_input, output}));
+  new_backward_conv->set_window(new_backward_conv_window);
+
+  VLOG(1) << "Canonicalizing backward filter conv";
+  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
+          << new_backward_conv->ToString();
+
+  TF_CHECK_OK(
+      computation->ReplaceInstruction(backward_conv, new_backward_conv));
+  return true;
+}
+
+bool ConvPaddingLegalization::CanonicalizeBackwardInputConvolution(
+    HloInstruction* backward_conv) {
+  if (window_util::HasSymmetricPadding(backward_conv->window())) {
+    return false;
+  }
+
+  Window new_backward_conv_window = backward_conv->window();
+  ConvolutionDimensionNumbers backward_conv_dnums =
+      backward_conv->convolution_dimension_numbers();
+
+  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
+  // Get the shape of conv_result.
+  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
+
+  Shape new_backward_conv_shape = backward_conv_shape;
+  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
+    int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
+    int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
+    if (padding_low < 0 || padding_high < 0) {
+      // TODO(b/32744257): The following canonicalization wouldn't remove
+      // negative padding in a backward convolution, and would therefore cause
+      // cuDNN convolution (which doesn't support negative padding) to fail.
+      return false;
+    }
+    // If the backward convolution has uneven padding on the activations, we
+    // move some padding on the larger end to "internal" padding, so that the
+    // backward convolution produces larger activations which get sliced later.
+    //
+    // For example, suppose we have a non-canonical HLO
+    //   [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
+    // where the amount of padding low is larger, we can canonicalize it to
+    //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
+    //   [A] = Slice([B A])
+    if (padding_low > padding_high) {
+      IncreasePaddingLowBy(padding_high - padding_low,
+                           new_backward_conv_window.mutable_dimensions(i));
+    } else if (padding_low < padding_high) {
+      IncreasePaddingHighBy(padding_low - padding_high,
+                            new_backward_conv_window.mutable_dimensions(i));
+    }
+    // Decreasing the padding by X *increases* the size of our output by X.
+    // Note that we have swapped input spatial dimensions with output spatial
+    // dimensions to be compatible with the cuDNN API, so
+    // input_spatial_dimensions(i) gives the i-th spatial dimension of the
+    // output.
+    int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
+    new_backward_conv_shape.set_dimensions(
+        dim, new_backward_conv_shape.dimensions(dim) +
+                 std::abs(padding_low - padding_high));
+  }
+
+  // Create a new backward convolution replacing the old one.
+  HloComputation* computation = backward_conv->parent();
+  HloInstruction* output = backward_conv->mutable_operand(0);
+  HloInstruction* filter = backward_conv->mutable_operand(1);
+
+  HloInstruction* new_backward_conv_call =
+      computation->AddInstruction(backward_conv->CloneWithNewOperands(
+          ShapeUtil::MakeTupleShape(
+              {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
+          {output, filter}));
+  new_backward_conv_call->set_window(new_backward_conv_window);
+
+  // The CustomCall created above returns a tuple (conv_result, scratch_memory).
+  // Extract out the two elements.
+  HloInstruction* new_backward_conv =
+      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+          new_backward_conv_shape, new_backward_conv_call, 0));
+  HloInstruction* new_backward_conv_scratch =
+      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+          new_backward_conv_call->shape().tuple_shapes(1),
+          new_backward_conv_call, 1));
+
+  // Slice the new backward convolution.
+  //
+  // Initialize start_indices and limit_indices as no slicing.
+  std::vector<int64_t> start_indices(
+      new_backward_conv->shape().dimensions_size(), 0LL);
+  std::vector<int64_t> limit_indices(
+      new_backward_conv->shape().dimensions().begin(),
+      new_backward_conv->shape().dimensions().end());
+  std::vector<int64_t> strides(new_backward_conv->shape().dimensions_size(),
+                               1LL);
+  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
+    int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
+    int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
+    // Note that we have swapped input spatial dimensions with output spatial
+    // dimensions to be compatible with the cuDNN API, so
+    // input_spatial_dimensions(i) gives the i-th spatial dimension of the
+    // output.
+    int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
+    if (padding_low > padding_high) {
+      // If the amount of low padding (of the old backward convolution) is
+      // larger, we internally pad the low end of the activations and slice
+      // internal padding out here.
+      start_indices[dim] += padding_low - padding_high;
+    } else if (padding_low < padding_high) {
+      // If the amount of high padding is larger, we slice out the internal
+      // padding on the high end.
+      limit_indices[dim] -= padding_high - padding_low;
+    }
+  }
+
+  // Replace the old backward convolution with the slice.
+  Shape slice_shape =
+      ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
+                                      limit_indices, strides)
+          .value();
+  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
+      << ShapeUtil::HumanString(slice_shape) << " vs "
+      << ShapeUtil::HumanString(backward_conv_shape);
+
+  HloInstruction* slice = computation->AddInstruction(
+      HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
+                                  start_indices, limit_indices, strides));
+  HloInstruction* new_tuple = computation->AddInstruction(
+      HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));
+
+  VLOG(1) << "Canonicalizing backward input conv";
+  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
+          << new_tuple->ToString();
+
+  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
+  return true;
+}
+
+absl::StatusOr<bool> ConvPaddingLegalization::RunOnComputation(
+    HloComputation* computation) {
+  bool changed = false;
+  std::vector<HloCustomCallInstruction*> convs;
+  for (auto* instr : computation->instructions()) {
+    if (IsCustomCallToDnnConvolution(*instr)) {
+      convs.push_back(Cast<HloCustomCallInstruction>(instr));
+    }
+  }
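+  // Rewrite each cuDNN convolution custom-call according to its kind. The
+  // convolutions are collected first because canonicalization replaces
+  // instructions while we would otherwise be iterating over them.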
+  for (HloCustomCallInstruction* instruction : convs) {
+    TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction));
+    changed |= [&] {
+      switch (kind) {
+        case CudnnConvKind::kForward:
+        case CudnnConvKind::kForwardActivation:
+        case CudnnConvKind::kForwardGraph:
+          return CanonicalizeForwardConvolution(instruction);
+        case CudnnConvKind::kBackwardInput:
+          return CanonicalizeBackwardInputConvolution(instruction);
+        case CudnnConvKind::kBackwardFilter:
+          return CanonicalizeBackwardFilterConvolution(instruction);
+      }
+    }();
+  }
+  return changed;
+}
+
+absl::StatusOr<bool> ConvPaddingLegalization::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+    changed |= result;
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.h b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.h
new file mode 100644
index 0000000..1841c92
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CONV_PADDING_LEGALIZATION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CONV_PADDING_LEGALIZATION_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// An HLO pass that canonicalizes convolution instructions for GPU codegen. It
+// inserts Pad instructions before Convolution instructions with non-canonical
+// padding, so that they can be lowered to cuDNN/MIOpen convolutions.
+class ConvPaddingLegalization : public HloModulePass {
+ public:
+  absl::string_view name() const override {
+    return "conv-padding-legalization";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
+  // Each of the following returns true if any change was made to the parent
+  // computation.
+  bool CanonicalizeForwardConvolution(HloInstruction* conv);
+  bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv);
+  bool CanonicalizeBackwardInputConvolution(HloInstruction* backward_conv);
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CONV_PADDING_LEGALIZATION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization_test.cc b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization_test.cc
new file mode 100644
index 0000000..06682e7
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization_test.cc
@@ -0,0 +1,96 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/conv_padding_legalization.h"
+
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/test.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+using ConvPaddingLegalizationTest = HloTestBase;
+
+TEST_F(ConvPaddingLegalizationTest, BackwardInputConvolve) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule convolution_module
+ENTRY %convolution (operand f64[2,2,2,3]{3,2,1,0}) -> (f64[2,2,4,4]{3,2,1,0}, u8[0]) {
+  %operand = f64[2,2,2,3]{3,2,1,0} parameter(0)
+  %kernel = f64[2,3,2,3]{3,2,1,0} constant(
+  {
+    { /*i0=0*/
+    { /*i1=0*/
+      { 0.29629629629629628, 0.30246913580246915, 0.30864197530864196 },
+      { 0.31481481481481483, 0.32098765432098764, 0.3271604938271605 }
+    },
+    { /*i1=1*/
+      { 0.25925925925925924, 0.26543209876543211, 0.27160493827160492 },
+      { 0.27777777777777779, 0.2839506172839506, 0.29012345679012347 }
+    },
+    { /*i1=2*/
+      { 0.22222222222222221, 0.22839506172839505, 0.23456790123456789 },
+      { 0.24074074074074073, 0.24691358024691357, 0.25308641975308643 }
+    }
+    },
+    { /*i0=1*/
+    { /*i1=0*/
+      { 0.18518518518518517, 0.19135802469135801, 0.19753086419753085 },
+      { 0.20370370370370369, 0.20987654320987653, 0.21604938271604937 }
+    },
+    { /*i1=1*/
+      { 0.14814814814814814, 0.15432098765432098, 0.16049382716049382 },
+      { 0.16666666666666666, 0.1728395061728395, 0.17901234567901234 }
+    },
+    { /*i1=2*/
+      { 0.1111111111111111, 0.11728395061728394, 0.12345679012345678 },
+      { 0.12962962962962962, 0.13580246913580246, 0.1419753086419753 }
+    }
+    }
+  })
+  %reverse = f64[2,3,2,3]{3,2,1,0} reverse(%kernel), dimensions={0,1}
+  ROOT %custom-call = (f64[2,2,4,4]{3,2,1,0}, u8[0]{0}) custom-call(f64[2,2,2,3]{3,2,1,0} %operand, f64[2,3,2,3]{3,2,1,0} %reverse), window={size=2x3 stride=2x2 pad=0_0x0_1}, dim_labels=bf01_01io->b01f, custom_call_target="__cudnn$convBackwardInput", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
+}
+                                               )")
+                    .value();
+  ASSERT_TRUE(ConvPaddingLegalization().Run(module.get()).value());
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::Tuple(
+                        m::Slice(m::GetTupleElement(
+                            m::CustomCall({kCudnnConvBackwardInputCallTarget},
+                                          m::Op(), m::Reverse(m::Constant())),
+                            0)),
+                        m::GetTupleElement())));
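+  // The uneven pad=0_0x0_1 is legalized to symmetric padding, which enlarges
+  // the conv result to [2,2,4,5]; the inserted slice restores the original
+  // [2,2,4,4] shape.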
+  auto slice = root->operand(0);
+  Shape expected_slice_shape = ShapeUtil::MakeShape(F64, {2, 2, 4, 4});
+  EXPECT_TRUE(ShapeUtil::Equal(slice->shape(), expected_slice_shape));
+  auto conv = slice->operand(0);
+  Shape expected_conv_shape = ShapeUtil::MakeShape(F64, {2, 2, 4, 5});
+  EXPECT_TRUE(ShapeUtil::Equal(conv->shape(), expected_conv_shape));
+}
+
+}  // anonymous namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc
new file mode 100644
index 0000000..e19622d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc
@@ -0,0 +1,869 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/permutation_util.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "xla/window_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+absl::Status CheckTypes(HloInstruction* conv,
+                        const se::GpuComputeCapability cc) {
+  auto valid_shape = [conv, &cc](const Shape& shape) -> absl::Status {
+    PrimitiveType type = shape.element_type();
+    if (!primitive_util::IsFloatingPointType(type) &&
+        !primitive_util::IsIntegralType(type)) {
+      // Among integral types, only S8 is supported. But CudnnFusedConvRewriter
+      // may rewrite convolutions of wider types into S8 convolutions, so allow
+      // all integral convolutions here.
+      return Unimplemented(
+          "Convolutions must have floating-point or integral operands/outputs, "
+          "but got convolution with type %s: %s",
+          primitive_util::LowercasePrimitiveTypeName(type), conv->ToString());
+    }
+    if (primitive_util::IsF8Type(type)) {
+      if (type != F8E4M3FN && type != F8E5M2) {
+        return Unimplemented(
+            "The only FP8 types supported in convolutions are f8e5m2 and "
+            "f8e4m3, "
+            "but got convolution with FP8 type %s: %s",
+            primitive_util::LowercasePrimitiveTypeName(type), conv->ToString());
+      }
+      if (!std::holds_alternative<se::CudaComputeCapability>(cc)) {
+        return Unimplemented(
+            "FP8 convolutions are only supported on CUDA GPUs, but got "
+            "FP8 convolution on ROCm GPU: %s",
+            conv->ToString());
+      } else if (!std::get<se::CudaComputeCapability>(cc).IsAtLeastHopper()) {
+        return Unimplemented(
+            "FP8 convolutions are only supported on CUDA GPUs with compute "
+            "capability at least 9.0, but got "
+            "FP8 convolution on GPU with compute capability %s: %s",
+            std::get<se::CudaComputeCapability>(cc).ToString(),
+            conv->ToString());
+      }
+    }
+    return absl::OkStatus();
+  };
+
+  TF_RETURN_IF_ERROR(valid_shape(conv->shape()));
+  TF_RETURN_IF_ERROR(valid_shape(conv->operand(0)->shape()));
+  TF_RETURN_IF_ERROR(valid_shape(conv->operand(1)->shape()));
+  return absl::OkStatus();
+}
+
+using ConvolutionMatch = std::optional<
+    std::tuple<Window, ConvolutionDimensionNumbers, HloInstruction*>>;
+
+// Determines whether this 2D convolution is really a 1D convolution whose
+// filter was reshaped to 2D by inserting a trivial spatial dimension.
+bool MaybeConv1dToConv2d(HloInstruction* conv) {
+  if (conv->window().dimensions().size() != 2) {
+    return false;
+  }
+  if (conv->operand(1)->opcode() != HloOpcode::kReshape) {
+    return false;
+  }
+  auto filter = conv->operand(1);
+  std::optional<ShapeUtil::ShapeEqualityDescriptor> reshape_degenerate =
+      filter->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
+  if (reshape_degenerate.has_value() &&
+      reshape_degenerate->deleted_dimensions.empty() &&
+      reshape_degenerate->inserted_dimensions.size() == 1) {
+    const auto& dnums = conv->convolution_dimension_numbers();
+    for (auto dim : dnums.kernel_spatial_dimensions()) {
+      if (dim == reshape_degenerate->inserted_dimensions[0]) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
+  const ConvolutionDimensionNumbers& dnums =
+      conv->convolution_dimension_numbers();
+  if (dnums.input_spatial_dimensions_size() > 3) {
+    return false;
+  }
+
+  // CuDNN does not accept zero-element arguments
+  if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) ||
+      ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) {
+    return false;
+  }
+
+  // CuDNN can perform either cross correlation (no reversal),
+  // or convolution (all dimensions reversed).
+  if (dnums.input_spatial_dimensions_size() == 2
+          ? !window_util::AllOrNoneReversed(conv->window())
+          : window_util::HasWindowReversal(conv->window())) {
+    return false;
+  }
+  return true;
+}
+
+// Try to match a backward filter pattern that contains "conv".
+// Precondition: "conv" is a kConvolution.
+ConvolutionMatch MatchBackwardFilter(HloInstruction* conv) {
+  VLOG(2) << "Trying to match convolution backward filter.";
+
+  if (conv->feature_group_count() > 1) {
+    VLOG(1) << conv->ToString()
+            << " is a forward convolution. All grouped backward filters are "
+               "mapped to batch grouped convolutions in tf2xla bridge. Hence "
+               "backward filter "
+               "convolutions cannot have feature groups greater than 1 at this "
+               "point. No need to fold to backward filter.";
+    return std::nullopt;
+  }
+
+  // Step 1: match the instruction pattern without considering the paddings and
+  // dimension numbers just yet. We may need some generic pattern matcher
+  // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
+  //
+  // Backward filter convolution is implemented in XLA as the forward
+  // convolution of padded activations and dilated gradients. Padding on
+  // activations and dilation on gradients are specified in the "window" field
+  // of the forward convolution.
+  //
+  //        activations  gradients
+  //              \         /
+  //               v       v
+  //              Convolution
+  //                 conv
+  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
+
+  // Step 2: match paddings and dimension numbers of the forward convolution.
+  const ConvolutionDimensionNumbers& conv_dnums =
+      conv->convolution_dimension_numbers();
+  auto input_batch_dim = conv_dnums.input_batch_dimension();
+  auto input_feature_dim = conv_dnums.input_feature_dimension();
+  auto input_spatial_dims = conv_dnums.input_spatial_dimensions();
+  auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension();
+  auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension();
+  auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions();
+  auto output_batch_dim = conv_dnums.output_batch_dimension();
+  auto output_feature_dim = conv_dnums.output_feature_dimension();
+  auto output_spatial_dims = conv_dnums.output_spatial_dimensions();
+  for (const WindowDimension& window_dim : conv->window().dimensions()) {
+    if (window_dim.stride() != 1) {
+      VLOG(1) << "Forward convolution's window "
+              << conv->window().ShortDebugString()
+              << " should have stride of 1.";
+      return std::nullopt;
+    }
+    if (window_dim.base_dilation() != 1) {
+      VLOG(1) << "Forward convolution's window "
+              << conv->window().ShortDebugString()
+              << " should have no base (LHS) dilation.";
+      return std::nullopt;
+    }
+    if (window_dim.padding_low() < 0) {
+      VLOG(1) << "Padding low should be non-negative.";
+      return std::nullopt;
+    }
+    if (window_dim.window_reversal()) {
+      VLOG(1) << "Window reversal field not supported";
+      return std::nullopt;
+    }
+    // Padding high will be checked in Step 3.
+  }
+  // Mathematically, there is no difference between convolution forward vs
+  // backward filter. A backward filter:
+  //   [N, O, H+h-1, W+w-1] x [N, C, H, W] -> [O, C, h, w]
+  // Can be treated as a forward convolution with `N` treated as the new
+  // contracting (feature) dimension, `O` treated as the new batch dimension,
+  // and `C` treated as the new output feature dimension. The only difference is
+  // layouts and performance.
+  //
+  // Since there is no way to precisely tell whether we want a forward conv or
+  // a backward filter conv, we have to rely on heuristics. Empirically,
+  // forward convolutions have very small kernel dimensions, while in the
+  // backward pass the "kernel dimensions" are large. If the kernel dimensions
+  // are smaller than the output dimensions, we classify the op as a forward
+  // conv; otherwise we proceed with the backward filter conv. Conv1d is a
+  // special case: because a 1D filter is always reshaped into a 2D filter,
+  // one small kernel dimension exists in both the forward and the backward
+  // pass, so we handle that case separately below.
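+  // For example, a [3, 3] kernel producing a [64, 64] output counts two small
+  // kernel dimensions and (absent window dilation) is treated as a forward
+  // conv, while a [62, 62] gradient "kernel" producing a [3, 3] output counts
+  // none and falls through to the backward filter path.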
+  int small_kernel_dimension_num = 0;
+  for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
+    if (conv->operand(1)->shape().dimensions(kernel_spatial_dims[i]) <=
+        conv->shape().dimensions(output_spatial_dims[i])) {
+      small_kernel_dimension_num += 1;
+    }
+  }
+  if ((kernel_spatial_dims.empty() || small_kernel_dimension_num > 1 ||
+       (!MaybeConv1dToConv2d(conv) && small_kernel_dimension_num == 1)) &&
+      !window_util::HasWindowDilation(conv->window())) {
+    VLOG(1) << conv->ToString()
+            << " is a regular forward convolution. No need "
+               "to fold it to a backward filter convolution....";
+    return std::nullopt;
+  }
+
+  // Step 3: fuse the matched HLOs into a backward convolution instruction.
+  //
+  // Compute the window of the backward convolution.
+  Window backward_conv_window;
+  for (int i = 0; i < input_spatial_dims.size(); ++i) {
+    WindowDimension* dim = backward_conv_window.add_dimensions();
+    // The window size of the backward convolution equals the output size of the
+    // forward convolution.
+    int64_t filter_size = conv->shape().dimensions(output_spatial_dims[i]);
+    dim->set_size(filter_size);
+    // The window stride equals the window dilation of the forward convolution.
+    dim->set_stride(conv->window().dimensions(i).window_dilation());
+    // The window's low padding is the same as the low padding of the
+    // activations.
+    dim->set_padding_low(conv->window().dimensions(i).padding_low());
+    dim->set_base_dilation(1);
+    dim->set_window_dilation(1);
+
+    int64_t input_size =
+        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
+    int64_t output_size = conv->window().dimensions(i).size();
+    // Compute the range of the amount of valid high padding. We first compute
+    // min_padding_high, the amount of padding on the right/bottom to ensure the
+    // last patch ends at the border, i.e.,
+    //
+    //   input_size + dim->padding_low() + min_padding_high
+    //     = (output_size - 1) * stride + filter_size
+    //
+    // Because convolution ignores trailing incomplete windows, any amount of
+    // padding high from min_padding_high to min_padding_high+stride-1
+    // (max_padding_high) has the same effect.
+    int64_t padded_input_size = filter_size + (output_size - 1) * dim->stride();
+    int64_t min_padding_high =
+        padded_input_size - input_size - dim->padding_low();
+    int64_t max_padding_high = min_padding_high + dim->stride() - 1;
+    CHECK_GE(dim->padding_low(), 0);
+    // In practice, since cuDNN convolution only supports even padding, we make
+    // the amount of high padding the same as the amount of low padding as long
+    // as it is between min_padding_high and max_padding_high. If it is not in
+    // that range, we pick the one that's closest to dim->padding_low() and let
+    // GpuConvPaddingLegalization canonicalize the resultant backward
+    // convolution later. Picking the closest one minimizes the cost of the kPad
+    // instruction to be inserted by GpuConvPaddingLegalization.
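+    // As a concrete (illustrative) example: with filter_size=3, stride=2,
+    // output_size=2, input_size=5 and padding_low=0, padded_input_size is
+    // 3 + (2-1)*2 = 5, so min_padding_high = 0 and max_padding_high = 1;
+    // padding_low (0) lies in [0, 1], so padding_high is set to 0.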
+    if (dim->padding_low() >= min_padding_high &&
+        dim->padding_low() <= max_padding_high) {
+      dim->set_padding_high(dim->padding_low());
+    } else {
+      if (dim->padding_low() < min_padding_high) {
+        dim->set_padding_high(min_padding_high);
+      } else {
+        dim->set_padding_high(max_padding_high);
+      }
+    }
+    if (dim->padding_high() < 0) {
+      LOG(WARNING)
+          << "Fusing this pattern to backward filter convolution would cause "
+             "negative padding ("
+          << dim->padding_high()
+          << ") on right/bottom of the weight gradients, which is not "
+             "supported by GpuConvPaddingLegalization (b/32744257). "
+             "Falling back to "
+             "unfused convolution for instruction: "
+          << conv->ToString();
+      return std::nullopt;
+    }
+  }
+
+  // Restore the dimension numbers of the backward convolution from the forward
+  // convolution. The two activation dimensions are reversed (batch and
+  // feature).
+  ConvolutionDimensionNumbers backward_conv_dnums;
+  backward_conv_dnums.set_input_batch_dimension(input_feature_dim);
+  backward_conv_dnums.set_input_feature_dimension(input_batch_dim);
+  for (int i = 0; i < input_spatial_dims.size(); ++i) {
+    backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]);
+  }
+  backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim);
+  backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim);
+  for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
+    backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]);
+  }
+  // The dimension numbering of the output of the forward convolution (before
+  // transposition) is the same as that of the activations (according to the
+  // semantics of kConvolution). The batch dimension of the activations should
+  // be treated as the input feature dimension, and the feature dimension should
+  // be treated as the output feature.
+  backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim);
+  backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim);
+  for (int i = 0; i < output_spatial_dims.size(); ++i) {
+    backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
+  }
+
+  HloInstruction* lhs = conv->mutable_operand(0);
+  return std::make_tuple(backward_conv_window, backward_conv_dnums, lhs);
+}
+
+// Try to match a backward input pattern that contains "conv".
+// Precondition: "conv" is a kConvolution.
+ConvolutionMatch MatchBackwardInput(HloInstruction* conv) {
+  VLOG(2) << "Trying to match convolution backward input.";
+
+  // TODO(timshen): Theoretically cuDNN also supports grouped convolutions for
+  // the backward input convolution, but in cuDNN's current state there is not
+  // much performance improvement from using the cuDNN backward input API for
+  // grouped convs. This needs to be re-evaluated for future cuDNN versions.
+  // Note that we already have the necessary code below; the only thing needed
+  // to enable it is to remove the following early return.
+  if (conv->feature_group_count() > 1) {
+    return std::nullopt;
+  }
+
+  // Match instruction pattern.
+  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
+  HloInstruction* reverse_filter = conv->mutable_operand(1);
+  ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
+
+  // Match BackwardInput for a depthwise convolution and thunk it to a forward
+  // convolution. The output feature dimension and the input feature dimension
+  // have been swapped in the bridge, so to get the actual input feature count
+  // we need to query the output feature dimension.
+  auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension();
+  auto kernel_out_features =
+      reverse_filter->shape().dimensions(kernel_out_feature_dim);
+
+  // For a depthwise convolution, the input feature count must equal the
+  // feature_group_count. We can leverage this property to match a depthwise
+  // convolution and thunk it to a forward conv.
+  if (conv->feature_group_count() > 1 &&
+      kernel_out_features == conv->feature_group_count()) {
+    return std::nullopt;
+  }
+
+  // We pattern-match to a backwards input conv if:
+  //
+  //  - all spatial dims of the filter are reversed
+  //
+  // OR
+  //
+  //  - filter is 1x1 or a constant AND
+  //  - conv has base dilation (otherwise this is just a regular forward conv).
+  //
+  // The final criterion above is just for canonicalization; cudnn seems to run
+  // just as fast if we canonicalize 1x1/constant filters without base dilation
+  // to forward or backward convs.  We canonicalize to forward conv because (a)
+  // it's more natural (constant filters usually show up when doing inference,
+  // and having backwards convolutions in inference graphs would be weird), and
+  // (b) cudnn has special fusions for forward conv plus bias and activation,
+  // and we want to pattern-match to that after running this pass.
+  bool is_reversed_filter =
+      reverse_filter->opcode() == HloOpcode::kReverse &&
+      absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
+                             reverse_filter->dimensions());
+  // For a conv1d that has been reshaped to a conv2d, the filter reverse
+  // pattern is reshape(reverse(filter)). It seems we could reuse the conv2d
+  // backward input pattern matcher, but after the algsimp pass this pattern
+  // becomes reverse(reshape(filter)) and fails to match. So matching conv1d
+  // backward input needs different processing logic.
+  bool is_reversed_conv1d_filter =
+      MaybeConv1dToConv2d(conv) &&
+      reverse_filter->operand(0)->opcode() == HloOpcode::kReverse;
+  bool is_1x1_filter =
+      absl::c_all_of(conv->window().dimensions(),
+                     [](const WindowDimension& d) { return d.size() == 1; });
+  if (!is_reversed_filter && !is_reversed_conv1d_filter &&
+      !(window_util::HasBaseDilation(conv->window()) &&
+        (reverse_filter->IsConstant() || is_1x1_filter))) {
+    VLOG(1) << "Can't match to backwards convolution. Either filter is not "
+               "kReverse, or it's not a base-dilated conv with a 1x1 or "
+               "constant filter.";
+    return std::nullopt;
+  }
+
+  // Match padding and dilation of the forward convolution.
+  for (const WindowDimension& window_dim : conv->window().dimensions()) {
+    if (window_dim.stride() != 1) {
+      VLOG(1) << "Forward convolution's window "
+              << conv->window().ShortDebugString()
+              << " should have stride of 1.";
+      return std::nullopt;
+    }
+    if (window_dim.window_dilation() != 1) {
+      VLOG(1) << "Forward convolution's window "
+              << conv->window().ShortDebugString()
+              << " should have no window dilation.";
+      return std::nullopt;
+    }
+    if (window_dim.window_reversal()) {
+      VLOG(1) << "Window reversal field not supported";
+      return std::nullopt;
+    }
+  }
+
+  const auto& input_spatial_dims = dnums.input_spatial_dimensions();
+  const auto& output_spatial_dims = dnums.output_spatial_dimensions();
+  CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size());
+  CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size());
+
+  const Window& old_window = conv->window();
+  Window new_window = old_window;
+  for (size_t i = 0; i < input_spatial_dims.size(); ++i) {
+    // Restore backward convolution's padding config from the matched pattern.
+    // See the comment in tensorflow/core/kernels/conv_grad_ops.h for how we
+    // convert backward input convolution to a variant of forward convolution.
+    //
+    // The stride of the backward convolution
+    // = the base dilation factor of the forward convolution
+    auto dim = new_window.mutable_dimensions(i);
+    dim->set_stride(old_window.dimensions(i).base_dilation());
+    dim->set_base_dilation(1);
+
+    // The low padding = kernel_size - 1 - low padding on the gradients
+    // Make sure the low padding is not negative.
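+    // For example (cf. the BackwardInputConvolveEvenPadding test), a 7-wide
+    // kernel with forward padding_low=3 gives
+    // backward_padding_low = 7 - 1 - 3 = 3.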
+    auto kernel_size = old_window.dimensions(i).size();
+    auto backward_padding_low =
+        kernel_size - 1 - old_window.dimensions(i).padding_low();
+    if (backward_padding_low < 0) {
+      LOG(WARNING)
+          << "The low padding of the backward convolution would be negative ("
+          << backward_padding_low
+          << "), which isn't supported by GpuConvPaddingLegalization "
+             "for now (b/32744257).";
+      return std::nullopt;
+    }
+    dim->set_padding_low(backward_padding_low);
+
+    // Compute the range of the amount of padding on the right/bottom of the
+    // activations. XLA's convolution requires all patches to be within the
+    // padded base. This gives us flexibility to choose the amount of high
+    // padding from a set of values without changing the result of the backward
+    // convolution. The minimum amount (min_padding_high) makes the last patch
+    // end at the border. The maximum amount (max_padding_high) equals
+    // min_padding_high+stride-1 -- max_padding_high+1 would cause the output
+    // size to change.
+    auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]);
+    auto output_size =
+        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
+    auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
+    auto total_pad_size = padded_input_size - unpadded_input_size;
+    auto min_padding_high = total_pad_size - backward_padding_low;
+    auto max_padding_high = min_padding_high + dim->stride() - 1;
+
+    if (backward_padding_low >= min_padding_high &&
+        backward_padding_low <= max_padding_high) {
+      // In the best case (most likely), if backward_padding_low is in the range
+      // of the amounts of valid high padding, we choose backward_padding_low
+      // because cuDNN supports even padding only.
+      dim->set_padding_high(backward_padding_low);
+    } else {
+      // Otherwise, we choose the amount that's closest to backward_padding_low,
+      // and GpuConvPaddingLegalization will later insert kSlice
+      // instructions to enforce even padding.
+      //
+      // For example, consider the backward convolution pattern
+      //
+      //   ab     xy
+      //   | pad  | reverse
+      //  .a.b    yx
+      //     \   /
+      //      ABC
+      //
+      // The amount of low padding on activations (in backward convolution) is
+      //   backward_padding_low = kernel_size - 1 - forward_padding_low
+      //                        = 2 - 1 - 1 = 0
+      //
+      // The amount of padding high must be between 1 and 2, in order to make
+      // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in
+      // the range of [1,2], so we pick the closest valid amount of padding
+      // high, which is 1 in this case. Therefore, we fuse the above pattern to
+      //
+      //   ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1)
+      if (backward_padding_low < min_padding_high) {
+        dim->set_padding_high(min_padding_high);
+      } else {
+        dim->set_padding_high(max_padding_high);
+      }
+    }
+    // GpuConvPaddingLegalization doesn't handle backward input
+    // convolution with negative padding for now. So fall back to unfused
+    // convolution in case of negative padding. For example,
+    //   ABCD = Conv(abc, reverse(xy), padding_high=2)
+    // could be fused to
+    //   ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
+    // with positive padding low but negative padding high.
+    if (dim->padding_high() < 0) {
+      LOG(WARNING) << "Fusing this pattern to backward convolution would cause "
+                      "negative padding ("
+                   << dim->padding_high()
+                   << ") on right/bottom of the activations, which is not "
+                      "supported by GpuConvPaddingLegalization (b/32744257). "
+                      "Falling back to unfused convolution for instruction: "
+                   << conv->ToString();
+      return std::nullopt;
+    }
+  }
+
+  // OK, it's a match! Switch the input feature dimension with the output
+  // feature dimension. Also switch the output with the input. This is the way
+  // cuDNN expects it to be.
+  auto conv_dnums = conv->convolution_dimension_numbers();
+  dnums.set_kernel_input_feature_dimension(
+      conv_dnums.kernel_output_feature_dimension());
+  dnums.set_kernel_output_feature_dimension(
+      conv_dnums.kernel_input_feature_dimension());
+  for (int i = 0; i < input_spatial_dims.size(); ++i) {
+    dnums.set_input_spatial_dimensions(i,
+                                       conv_dnums.output_spatial_dimensions(i));
+    dnums.set_output_spatial_dimensions(i,
+                                        conv_dnums.input_spatial_dimensions(i));
+  }
+  dnums.set_input_feature_dimension(conv_dnums.output_feature_dimension());
+  dnums.set_input_batch_dimension(conv_dnums.output_batch_dimension());
+  dnums.set_output_feature_dimension(conv_dnums.input_feature_dimension());
+  dnums.set_output_batch_dimension(conv_dnums.input_batch_dimension());
+
+  // If we matched against a constant, we need to add a reverse op that can be
+  // subsumed by the cuDNN call. algebraic-simplifier will later remove any
+  // unnecessary reverses.
+  if (reverse_filter->opcode() != HloOpcode::kReverse &&
+      reverse_filter->IsConstant()) {
+    // Create a double-reverse, which is a nop.
+    HloComputation* c = conv->parent();
+    reverse_filter = c->AddInstruction(
+        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
+                                      dnums.kernel_spatial_dimensions()));
+    reverse_filter = c->AddInstruction(
+        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
+                                      dnums.kernel_spatial_dimensions()));
+    TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_num=*/1, reverse_filter));
+  }
+
+  // Calculate the 'rhs' that goes into the backward input convolution.
+  HloInstruction* rhs = reverse_filter;
+  // One reverse is subsumed by the cuDNN call.
+  if (rhs->opcode() == HloOpcode::kReverse) {
+    rhs = rhs->mutable_operand(0);
+  } else if (is_reversed_conv1d_filter) {
+    auto src = rhs->mutable_operand(0)->mutable_operand(0);
+    rhs = conv->parent()->AddInstruction(
+        HloInstruction::CreateReshape(rhs->shape(), src));
+  }
+  if (conv->feature_group_count() == 1) {
+    return std::make_tuple(new_window, dnums, rhs);
+  }
+
+  // Handle grouped convolutions. Because we swapped the input feature dimension
+  // with the output feature dimension, we need to also reshape the kernel so
+  // that the 'feature_group_count' parameter still makes sense. The
+  // 'feature_group_count' parameter essentially specifies how often the
+  // 'kernel_input_feature_dimension' is repeated. So when we swap these
+  // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
+  // 'feature_group_count' and multiply the new
+  // 'kernel_output_feature_dimension' by 'feature_group_count'.
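+  // As an illustration with made-up sizes: for feature_group_count G=2 and a
+  // kernel of shape [H, W, in_depth=8, out_depth/G=3], the reshape/transpose/
+  // reshape sequence below produces a kernel of shape
+  // [H, W, in_depth/G=4, out_depth=6].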
+  int64_t input_feature_dimension = dnums.kernel_input_feature_dimension();
+  int64_t output_feature_dimension = dnums.kernel_output_feature_dimension();
+  // The following code assumes that input_feature_dimension and
+  // output_feature_dimension are adjacent.
+  if (std::abs(input_feature_dimension - output_feature_dimension) != 1) {
+    return std::nullopt;
+  }
+
+  int64_t input_features = rhs->shape().dimensions(input_feature_dimension);
+  int64_t output_features = rhs->shape().dimensions(output_feature_dimension);
+
+  // Reshape [H, W, ..., in_depth, out_depth / G] -> [H, W, ..., G, in_depth/G,
+  // out_depth / G]
+  std::vector<int64_t> reshape_dims = SpanToVector(rhs->shape().dimensions());
+  auto num_groups = conv->feature_group_count();
+  CHECK_EQ(input_features % num_groups, 0)
+      << "Input feature count should be an exact multiple of feature group "
+         "count";
+  reshape_dims[input_feature_dimension] =
+      reshape_dims[input_feature_dimension] / num_groups;
+  reshape_dims.insert(reshape_dims.begin() + input_feature_dimension,
+                      num_groups);
+
+  HloComputation* c = conv->parent();
+  rhs = c->AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(rhs->shape().element_type(), reshape_dims), rhs));
+
+  // Transpose [H, W, ..., G, in_depth/G, out_depth / G] -> [H, W, ...,
+  // in_depth/G, G, out_depth / G]
+  std::vector<int64_t> transpose_dims(rhs->shape().dimensions_size());
+  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
+  transpose_dims.erase(transpose_dims.begin() + input_feature_dimension);
+  transpose_dims.insert(transpose_dims.begin() + output_feature_dimension,
+                        input_feature_dimension);
+  std::vector<int64_t> transpose_reshape_dims =
+      SpanToVector(rhs->shape().dimensions());
+  transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
+                               input_feature_dimension);
+  transpose_reshape_dims.insert(
+      transpose_reshape_dims.begin() + output_feature_dimension, num_groups);
+  rhs = c->AddInstruction(HloInstruction::CreateTranspose(
+      ShapeUtil::MakeShape(rhs->shape().element_type(), transpose_reshape_dims),
+      rhs, transpose_dims));
+
+  // Reshape [H, W, ..., in_depth/G, G, out_depth / G] -> [H, W, ...,
+  // in_depth/G, out_depth]
+  Shape new_shape = rhs->shape();
+  new_shape.DeleteDimension(output_feature_dimension);
+  new_shape.set_dimensions(output_feature_dimension,
+                           output_features * num_groups);
+  rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
+  return std::make_tuple(new_window, dnums, rhs);
+}
+
+HloInstruction* CreateGpuConv(absl::string_view call_target, const Shape& shape,
+                              HloInstruction* lhs, HloInstruction* rhs,
+                              const Window& window,
+                              const ConvolutionDimensionNumbers& dnums,
+                              int64_t feature_group_count,
+                              const PrecisionConfig& precision_config,
+                              const OpMetadata& metadata) {
+  HloComputation* computation = lhs->parent();
+
+  // This call returns a tuple of (conv_result, scratch_memory), where
+  // conv_result is the actual result of the convolution, and scratch_memory is
+  // temporary memory used by cudnn.
+  //
+  // At the moment, we don't know how much scratch memory this conv is going to
+  // use, so we put u8[0] in this place.  Later on another pass will choose
+  // which conv algorithm to use, and at that point we'll modify the shape of
+  // this second tuple element.
+  Shape call_shape =
+      ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
+
+  HloInstruction* custom_call = computation->AddInstruction(
+      HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
+  custom_call->set_window(window);
+  custom_call->set_convolution_dimension_numbers(dnums);
+  custom_call->set_feature_group_count(feature_group_count);
+  *custom_call->mutable_precision_config() = precision_config;
+  custom_call->set_metadata(metadata);
+
+  // Give the custom call a user-friendly name.
+  std::optional<std::string> name;
+  if (call_target == kCudnnConvForwardCallTarget) {
+    name = "cudnn-conv";
+  } else if (call_target == kCudnnConvBackwardInputCallTarget) {
+    name = "cudnn-conv-bw-input";
+  } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
+    name = "cudnn-conv-bw-filter";
+  } else if (call_target == kCudnnConvBiasActivationForwardCallTarget) {
+    name = "cudnn-conv-bias-activation";
+  }
+  if (name.has_value()) {
+    computation->parent()->SetAndUniquifyInstrName(custom_call, *name);
+  }
+
+  return custom_call;
+}
+
+HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution(
+    HloInstruction* conv) {
+  CHECK_EQ(conv->feature_group_count(), 1);
+  int64_t num_groups = conv->batch_group_count();
+  auto dim_numbers = conv->convolution_dimension_numbers();
+  auto lhs = conv->mutable_operand(0);
+  auto rhs = conv->mutable_operand(1);
+
+  int64_t input_batch_dimension = dim_numbers.input_batch_dimension();
+
+  Shape output_shape = conv->shape();
+  int64_t input_feature_dimension = dim_numbers.input_feature_dimension();
+  int64_t input_feature = lhs->shape().dimensions(input_feature_dimension);
+
+  HloComputation* computation = lhs->parent();
+  auto add = [&](std::unique_ptr<HloInstruction> inst) {
+    return computation->AddInstruction(std::move(inst));
+  };
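+  // As an illustration with made-up sizes: for batch_group_count G=4 and an
+  // NHWC lhs of shape [N=8, H, W, C=16], the reshape/transpose/reshape below
+  // produce an lhs of shape [2, H, W, 64], and the convolution is then
+  // rewritten with feature_group_count=4.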
+  // Reshape batch_dim N -> [G, N/G]
+  std::vector<int64_t> reshape_dims = SpanToVector(lhs->shape().dimensions());
+  reshape_dims[input_batch_dimension] =
+      reshape_dims[input_batch_dimension] / num_groups;
+  reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups);
+  lhs = add(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), lhs));
+
+  // Transpose G to the axis before C, e.g. [G, N/G, H, W, C] -> [N/G, H, W,
+  // G, C].
+  std::vector<int64_t> transpose_dims(lhs->shape().dimensions_size());
+  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
+  transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
+  transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
+                        input_batch_dimension);
+  std::vector<int64_t> transpose_reshape_dims =
+      ComposePermutations(lhs->shape().dimensions(), transpose_dims);
+  lhs = add(HloInstruction::CreateTranspose(
+      ShapeUtil::MakeShape(lhs->shape().element_type(), transpose_reshape_dims),
+      lhs, transpose_dims));
+
+  // Merge [G,C] -> [C*G]
+  Shape new_shape = lhs->shape();
+  new_shape.DeleteDimension(input_feature_dimension);
+  new_shape.set_dimensions(input_feature_dimension, input_feature * num_groups);
+  lhs = add(HloInstruction::CreateReshape(new_shape, lhs));
+
+  std::vector<HloInstruction*> new_operands = {lhs, rhs};
+  auto new_conv = conv->CloneWithNewOperands(output_shape, new_operands);
+  new_conv->set_feature_group_count(num_groups);
+  new_conv->set_batch_group_count(1);
+  new_conv->set_convolution_dimension_numbers(dim_numbers);
+  return computation->AddInstruction(std::move(new_conv));
+}
+
+CudnnConvBackendConfig GetDefaultBackendConfig() {
+  CudnnConvBackendConfig config;
+  config.set_conv_result_scale(1);
+  return config;
+}
+
+// Helper function to create a custom_call instruction to replace the given
+// conv instruction.
+static absl::StatusOr<HloInstruction*> CreateCustomCallHelper(
+    HloInstruction* conv, const se::GpuComputeCapability& cc) {
+  TF_RETURN_IF_ERROR(CheckTypes(conv, cc));
+  if (ConvolutionMatch m = MatchBackwardInput(conv)) {
+    auto& [window, dnums, rhs] = *m;
+    return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
+                         conv->mutable_operand(0), rhs, window, dnums,
+                         conv->feature_group_count(), conv->precision_config(),
+                         conv->metadata());
+  }
+
+  if (ConvolutionMatch m = MatchBackwardFilter(conv)) {
+    auto& [window, dnums, lhs] = *m;
+    return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs,
+                         conv->mutable_operand(1), window, dnums,
+                         conv->batch_group_count(), conv->precision_config(),
+                         conv->metadata());
+  }
+
+  // If all else fails, try a forward convolution.
+  if (CanImplementAsGpuForwardConv(conv)) {
+    if (conv->batch_group_count() > 1) {
+      conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv);
+    }
+
+    return CreateGpuConv(kCudnnConvForwardCallTarget, conv->shape(),
+                         conv->mutable_operand(0), conv->mutable_operand(1),
+                         conv->window(), conv->convolution_dimension_numbers(),
+                         conv->feature_group_count(), conv->precision_config(),
+                         conv->metadata());
+  }
+
+  return nullptr;
+}
+
+// Tries to rewrite a single convolution into a call to cudnn/miopen.
+absl::StatusOr<bool> RunOnInstruction(HloInstruction* conv,
+                                      const se::GpuComputeCapability& cc) {
+  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
+
+  TF_ASSIGN_OR_RETURN(HloInstruction * custom_call,
+                      CreateCustomCallHelper(conv, cc));
+  if (custom_call == nullptr) {
+    return false;
+  }
+
+  GpuBackendConfig gpu_backend_config;
+  *gpu_backend_config.mutable_cudnn_conv_backend_config() =
+      GetDefaultBackendConfig();
+  TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
+
+  VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
+          << custom_call->ToString();
+
+  // The CustomCall returns a tuple (conv_result, scratch_memory).  Extract
+  // out the conv result and replace `conv` with it.
+  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
+      conv,
+      HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
+  return true;
+}
+
+// Rewrites the convolutions in the given computation into calls to
+// cudnn/miopen.
+// Returns true if it made any changes.
+absl::StatusOr<bool> RunOnComputation(HloComputation* computation,
+                                      const se::GpuComputeCapability& cc) {
+  std::vector<HloInstruction*> convs;
+  for (auto* hlo : computation->instructions()) {
+    if (hlo->opcode() == HloOpcode::kConvolution) {
+      convs.push_back(hlo);
+    }
+  }
+
+  bool changed = false;
+  for (HloInstruction* conv : convs) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv, cc));
+    changed |= result;
+  }
+  return changed;
+}
+}  // namespace
+
+absl::StatusOr<bool> ConvRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  XLA_VLOG_LINES(2, "ConvRewriter::Run(), before:\n" + module->ToString());
+  bool changed = false;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool result,
+                        RunOnComputation(computation, compute_capability_));
+    changed |= result;
+  }
+  XLA_VLOG_LINES(2, "ConvRewriter::Run(), after:\n" + module->ToString());
+  return changed;
+}
+
+/*static*/ bool ConvRewriter::ConvIsLowerable(HloInstruction* conv) {
+  return CanImplementAsGpuForwardConv(conv) || MatchBackwardFilter(conv) ||
+         MatchBackwardInput(conv);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h
new file mode 100644
index 0000000..69369f1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CONV_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CONV_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites plain convolutions, backwards-filter convolutions, and
+// backwards-input convolutions into CustomCall HLOs that call into
+// Cudnn/Miopen.
+//
+// This pass does not fuse other ops into the convolution. Instead, specific
+// patterns of ops will be matched and fused into the custom call in
+// CudnnFusedConvRewriter.
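+//
+// Illustrative usage (mirroring conv_rewriter_test.cc; `compute_capability`
+// here stands for an se::GpuComputeCapability describing the target device):
+//
+//   bool changed = ConvRewriter(compute_capability).Run(module).value();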
+
+class ConvRewriter : public HloModulePass {
+ public:
+  explicit ConvRewriter(const se::GpuComputeCapability& compute_capability)
+      : compute_capability_(compute_capability) {}
+
+  absl::string_view name() const override { return "conv-rewriter"; }
+
+  static bool ConvIsLowerable(HloInstruction* conv);
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  se::GpuComputeCapability compute_capability_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CONV_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc
new file mode 100644
index 0000000..d01ffd1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc
@@ -0,0 +1,807 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+
+#include <optional>
+#include <string>
+
+#include "absl/log/check.h"
+#include "absl/strings/str_format.h"
+#include "xla/array4d.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/literal_util.h"
+#include "xla/protobuf_util.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/service/shape_inference.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/test.h"
+#include "xla/test_helpers.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class ConvRewriterTest : public HloTestBase {
+ public:
+  ConvRewriterTest()
+      : HloTestBase(/*verifier_layout_sensitive=*/true,
+                    /*allow_mixed_precision_in_hlo_verifier=*/false) {
+    for (int i = 0; i < 2; ++i) {
+      WindowDimension* window_dim = default_conv_window_.add_dimensions();
+      window_dim->set_size(1);
+      window_dim->set_stride(1);
+      window_dim->set_padding_low(0);
+      window_dim->set_padding_high(0);
+      window_dim->set_window_dilation(1);
+      window_dim->set_base_dilation(1);
+    }
+    // TF data shapes are by default in the NHWC order, and filter shape is by
+    // default in HWIO order. For backward filter convolution, we need to swap
+    // the batch and feature dimension in the activations, and treat the batch
+    // dimension in gradients as the input feature dimension in the filter.
+    //
+    // TODO(jingyue): Add more tests on NCHW input order, which TF also
+    // supports.
+    tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3);
+    tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0);
+    tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1);
+    tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2);
+    tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0);
+    tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension(
+        3);
+    tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1);
+    tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2);
+    tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0);
+    tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1);
+    tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2);
+    tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3);
+
+    tf_default_dnums_for_backward_input_.set_input_batch_dimension(0);
+    tf_default_dnums_for_backward_input_.set_output_batch_dimension(0);
+    tf_default_dnums_for_backward_input_.set_input_feature_dimension(3);
+    tf_default_dnums_for_backward_input_.set_output_feature_dimension(3);
+    tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1);
+    tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1);
+    tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2);
+    tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2);
+    tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3);
+    tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2);
+    tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0);
+    tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1);
+  }
+
+ protected:
+  const se::GpuComputeCapability& GetComputeCapability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability();
+  }
+
+  bool RunPass(HloModule* module) {
+    return ConvRewriter(GetComputeCapability()).Run(module).value();
+  }
+
+  // A convolution window with stride 1 and zero padding. The size fields are
+  // not set.
+  Window default_conv_window_;
+  ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_;
+  ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
+};
+
+TEST_F(ConvRewriterTest, BackwardFilterConvolve) {
+  HloComputation::Builder builder(TestName());
+  HloInstruction* activations =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
+  HloInstruction* gradients =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients"));
+  Window conv_window = default_conv_window_;
+  conv_window.mutable_dimensions(1)->set_size(2);
+  conv_window.mutable_dimensions(1)->set_window_dilation(2);
+  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeInference::InferConvolveShape(
+          activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+          /*batch_group_count=*/1, conv_window,
+          tf_default_dnums_for_backward_filter_,
+          /*preferred_element_type=*/std::nullopt)
+          .value(),
+      activations, gradients, /*feature_group_count=*/1,
+      /*batch_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
+
+  OpMetadata metadata;
+  metadata.set_op_name("foo");
+  conv->set_metadata(metadata);
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  ASSERT_THAT(entry_computation->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
+
+  // Check that metadata was preserved.
+  const auto& md_after_opt =
+      entry_computation->root_instruction()->operand(0)->metadata();
+  EXPECT_TRUE(protobuf_util::ProtobufEquals(md_after_opt, metadata))
+      << md_after_opt.DebugString() << " vs " << metadata.DebugString();
+}
+
+TEST_F(ConvRewriterTest, BackwardFilterConvolveEquivalentToForwardConvolution) {
+  HloComputation::Builder builder(TestName());
+  HloInstruction* activations =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
+  HloInstruction* gradients =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients"));
+  Window conv_window = default_conv_window_;
+  conv_window.mutable_dimensions(1)->set_size(3);
+  builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeInference::InferConvolveShape(
+          activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+          /*batch_group_count=*/1, conv_window,
+          tf_default_dnums_for_backward_filter_,
+          /*preferred_element_type=*/std::nullopt)
+          .value(),
+      activations, gradients, /*feature_group_count=*/1,
+      /*batch_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
+}
+
+// Extracted from block35 training.
+TEST_F(ConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) {
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* activations =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
+  HloInstruction* gradients =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
+
+  Window conv_window = default_conv_window_;
+  for (int i = 0; i < 2; ++i) {
+    conv_window.mutable_dimensions(i)->set_size(35);
+    conv_window.mutable_dimensions(i)->set_padding_low(1);
+    conv_window.mutable_dimensions(i)->set_padding_high(1);
+  }
+  builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
+      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
+}
+
+// Extracted from inception v3 training.
+TEST_F(ConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) {
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* activations =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations"));
+  HloInstruction* gradients =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients"));
+
+  Window conv_window = default_conv_window_;
+  for (int i = 0; i < 2; ++i) {
+    conv_window.mutable_dimensions(i)->set_size(4);
+    conv_window.mutable_dimensions(i)->set_padding_high(-1);
+    conv_window.mutable_dimensions(i)->set_window_dilation(2);
+  }
+  builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
+      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
+}
+
+TEST_F(ConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* activations =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
+  HloInstruction* gradients =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
+
+  Window conv_window = default_conv_window_;
+  for (int i = 0; i < 2; ++i) {
+    conv_window.mutable_dimensions(i)->set_size(35);
+    // Uneven padding: padding_low=0, padding_high=1
+    conv_window.mutable_dimensions(i)->set_padding_high(1);
+  }
+  builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
+      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0)));
+}
+
+TEST_F(ConvRewriterTest, BackwardInputConvolveEvenPadding) {
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* output =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output"));
+  HloInstruction* kernel =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel"));
+  HloInstruction* reverse_kernel = builder.AddInstruction(
+      HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3}));
+
+  Window conv_window = default_conv_window_;
+  for (int i = 0; i < 2; ++i) {
+    conv_window.mutable_dimensions(i)->set_size(7);
+    conv_window.mutable_dimensions(i)->set_padding_low(3);
+    conv_window.mutable_dimensions(i)->set_padding_high(3);
+  }
+  ConvolutionDimensionNumbers conv_dnums;
+  conv_dnums.set_input_batch_dimension(0);
+  conv_dnums.set_output_batch_dimension(0);
+  conv_dnums.set_input_feature_dimension(1);
+  conv_dnums.set_output_feature_dimension(1);
+  conv_dnums.add_input_spatial_dimensions(2);
+  conv_dnums.add_output_spatial_dimensions(2);
+  conv_dnums.add_input_spatial_dimensions(3);
+  conv_dnums.add_output_spatial_dimensions(3);
+  conv_dnums.set_kernel_input_feature_dimension(0);
+  conv_dnums.set_kernel_output_feature_dimension(1);
+  conv_dnums.add_kernel_spatial_dimensions(2);
+  conv_dnums.add_kernel_spatial_dimensions(3);
+
+  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
+      /*rhs=*/reverse_kernel, /*feature_group_count=*/1,
+      /*batch_group_count=*/1, conv_window, conv_dnums,
+      DefaultPrecisionConfig(2)));
+  // Verify the convolution's shape is consistent with ShapeInference.
+  CHECK(ShapeUtil::Compatible(
+      conv->shape(),
+      ShapeInference::InferConvolveShape(
+          output->shape(), reverse_kernel->shape(),
+          /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+          conv_dnums, /*preferred_element_type=*/std::nullopt)
+          .value()));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+
+  ASSERT_THAT(entry_computation->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
+  const HloInstruction* custom_call =
+      entry_computation->root_instruction()->operand(0);
+  for (int i = 0; i < 2; ++i) {
+    const WindowDimension& window_dim = custom_call->window().dimensions(i);
+    // Low padding of the backward input convolution
+    //   = kernel_size - 1 - low padding on gradients.
+    EXPECT_EQ(3, window_dim.padding_low());
+    EXPECT_EQ(3, window_dim.padding_high());
+    EXPECT_EQ(1, window_dim.stride());
+    EXPECT_EQ(1, window_dim.base_dilation());
+  }
+}
+
+// Convolve([abc], [x], base_dilation=2)
+//   = Convolve([abc], Reverse([x]), base_dilation=2)
+//   = BackwardInputConvolve([abc], [x], stride=2)
+TEST_F(ConvRewriterTest, BackwardInputConvolve1x1Filter) {
+  auto builder = HloComputation::Builder(TestName());
+  // NHWC dimension order.
+  HloInstruction* output =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
+  // HWOI dimension order.
+  HloInstruction* kernel =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
+
+  Window conv_window = default_conv_window_;
+  conv_window.mutable_dimensions(1)->set_base_dilation(2);
+
+  builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeInference::InferConvolveShape(
+          output->shape(), kernel->shape(),
+          /*feature_group_count=*/1,
+          /*batch_group_count=*/1, conv_window,
+          tf_default_dnums_for_backward_input_,
+          /*preferred_element_type=*/std::nullopt)
+          .value(),
+      /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
+      /*batch_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
+}
+
+// BackwardInputConvolve([abc], [x], stride=1) is equivalent to
+// ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input
+// convolution.
+TEST_F(ConvRewriterTest,
+       BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) {
+  auto builder = HloComputation::Builder(TestName());
+  // NHWC dimension order.
+  HloInstruction* output =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
+  // HWOI dimension order.
+  HloInstruction* kernel =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
+
+  builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeInference::InferConvolveShape(
+          output->shape(), kernel->shape(), /*feature_group_count=*/1,
+          /*batch_group_count=*/1, default_conv_window_,
+          tf_default_dnums_for_backward_input_,
+          /*preferred_element_type=*/std::nullopt)
+          .value(),
+      /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
+      /*batch_group_count=*/1, default_conv_window_,
+      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
+}
+
+// Extracted from Inception V3 training.
+//
+//                                  filter(HWIO)
+//                                  3x3x192x320
+//                                      |
+//                                      v
+//      gradients(NHWC)              reverse
+//        20x4x4x320               3x3x192x320
+//                    \            /
+//                     \          /
+//  conv (NHWC) with padding (low=2,high=3,interior=1)
+//                     20x10x10x192
+//
+// Gradients are padded unevenly.
+TEST_F(ConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) {
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* output =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
+  HloInstruction* kernel =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
+  HloInstruction* reverse_kernel = builder.AddInstruction(
+      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
+
+  Window conv_window = default_conv_window_;
+  for (int i = 0; i < 2; ++i) {
+    conv_window.mutable_dimensions(i)->set_size(3);
+    conv_window.mutable_dimensions(i)->set_padding_low(2);
+    conv_window.mutable_dimensions(i)->set_padding_high(3);
+    // Interior padding = 1.
+    conv_window.mutable_dimensions(i)->set_base_dilation(2);
+  }
+  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
+      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
+  // Verify the convolution's shape is consistent with ShapeInference.
+  CHECK(ShapeUtil::Compatible(
+      conv->shape(), ShapeInference::InferConvolveShape(
+                         output->shape(), reverse_kernel->shape(),
+                         /*feature_group_count=*/1, /*batch_group_count=*/1,
+                         conv_window, tf_default_dnums_for_backward_input_,
+                         /*preferred_element_type=*/std::nullopt)
+                         .value()));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  ASSERT_THAT(entry_computation->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
+  const HloInstruction* custom_call =
+      entry_computation->root_instruction()->operand(0);
+  for (int i = 0; i < 2; ++i) {
+    const WindowDimension& window_dim = custom_call->window().dimensions(i);
+    EXPECT_EQ(0, window_dim.padding_low());
+    EXPECT_EQ(0, window_dim.padding_high());
+    EXPECT_EQ(2, window_dim.stride());
+    EXPECT_EQ(1, window_dim.base_dilation());
+  }
+}
+
+// Similar to BackwardInputConvolveUnevenPaddingOnGradients, but the low
+// padding of the gradients exceeds kernel_size - 1. Therefore, this pattern
+// cannot be fused.
+TEST_F(ConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* output =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
+  HloInstruction* kernel =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
+  HloInstruction* reverse_kernel = builder.AddInstruction(
+      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
+
+  Window conv_window = default_conv_window_;
+  for (int i = 0; i < 2; ++i) {
+    conv_window.mutable_dimensions(i)->set_size(3);
+    conv_window.mutable_dimensions(i)->set_padding_low(3);
+    conv_window.mutable_dimensions(i)->set_padding_high(2);
+    conv_window.mutable_dimensions(i)->set_base_dilation(2);
+  }
+  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
+      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
+  // Verify the convolution's shape is consistent with ShapeInference.
+  CHECK(ShapeUtil::Compatible(
+      conv->shape(), ShapeInference::InferConvolveShape(
+                         output->shape(), reverse_kernel->shape(),
+                         /*feature_group_count=*/1, /*batch_group_count=*/1,
+                         conv_window, tf_default_dnums_for_backward_input_,
+                         /*preferred_element_type=*/std::nullopt)
+                         .value()));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
+}
+
+// Extracted from Resnet-50.
+//
+// For simplicity, we focus on the column dimension and ignore other dimensions.
+// We use [?] to represent the shape instead of the content.
+//
+// Suppose operator FC does
+//   [4] = conv([14], [3], stride=2, padding_high=1)  // Padding::kSame
+//
+// BC = BackwardInput(FC) does:
+//   [14] = conv([7], reverse([3]),
+//               padding_low=2, padding_high=1, base_dilation=2)
+//
+// We should fuse BC even though padding on activations is uneven, because
+// GpuConvPaddingLegalization will canonicalize the fusion HLO.
+TEST_F(ConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) {
+  auto builder = HloComputation::Builder(TestName());
+  // The gradients are in NCHW layout.
+  HloInstruction* output =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output"));
+  // The kernel is in HWIO layout.
+  HloInstruction* kernel =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel"));
+  HloInstruction* reverse_kernel = builder.AddInstruction(
+      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
+
+  Window conv_window = default_conv_window_;
+  WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
+  forward_conv_col_dim->set_size(3);
+  forward_conv_col_dim->set_padding_low(2);
+  forward_conv_col_dim->set_padding_high(1);
+  forward_conv_col_dim->set_base_dilation(2);
+  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
+      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
+  // Verify the convolution's shape is consistent with ShapeInference.
+  CHECK(ShapeUtil::Compatible(
+      conv->shape(), ShapeInference::InferConvolveShape(
+                         output->shape(), reverse_kernel->shape(),
+                         /*feature_group_count=*/1, /*batch_group_count=*/1,
+                         conv_window, tf_default_dnums_for_backward_input_,
+                         /*preferred_element_type=*/std::nullopt)
+                         .value()));
+
+  auto module = CreateNewVerifiedModule();
+  const HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  ASSERT_THAT(entry_computation->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvBackwardInputCallTarget}), 0)));
+  const WindowDimension& backward_conv_col_dim =
+      entry_computation->root_instruction()->operand(0)->window().dimensions(1);
+  EXPECT_EQ(0, backward_conv_col_dim.padding_low());
+  EXPECT_EQ(1, backward_conv_col_dim.padding_high());
+}
+
+// For simplicity, we focus on the column dimension and ignore other dimensions.
+// We use [?] to represent the shape instead of the content.
+//
+// Suppose operator FC does
+//   [3] = conv([4], [2], padding_low=1, padding_high=-1)
+//
+// BC = BackwardInput(FC) does:
+//   [4] = conv([3], reverse([2]), padding_high=2)
+//
+// We currently don't fuse BC because GpuConvPaddingLegalization
+// doesn't support negative padding on the gradients of backward convolution
+// (b/32744257).
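+//
+// A rough check with the same size formula as above:
+//   FC: (4 + 1 - 1 - 2) / 1 + 1 = 3
+//   BC: (3 + 2 - 2) / 1 + 1 = 4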
+TEST_F(ConvRewriterTest,
+       BackwardInputConvolveNegativePaddingHighOnActivations) {
+  auto builder = HloComputation::Builder(TestName());
+  // The gradients are in NCHW layout.
+  HloInstruction* output =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
+  // The kernel is in HWIO layout.
+  HloInstruction* kernel =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel"));
+  HloInstruction* reverse_kernel = builder.AddInstruction(
+      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
+
+  Window conv_window = default_conv_window_;
+  WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
+  forward_conv_col_dim->set_size(2);
+  forward_conv_col_dim->set_padding_high(2);
+  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
+      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
+  // Verify the convolution's shape is consistent with ShapeInference.
+  CHECK(ShapeUtil::Compatible(
+      conv->shape(), ShapeInference::InferConvolveShape(
+                         output->shape(), reverse_kernel->shape(),
+                         /*feature_group_count=*/1, /*batch_group_count=*/1,
+                         conv_window, tf_default_dnums_for_backward_input_,
+                         /*preferred_element_type=*/std::nullopt)
+                         .value()));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(entry_computation->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvForwardCallTarget}), 0)));
+}
+
+// Check that we will materialize a reversed version of a constant in order to
+// pattern-match a backwards input convolution.
+TEST_F(ConvRewriterTest, BackwardInputConvolveConstantFilter) {
+  Array4D<float> constant_arr(4, 4, 2, 2);
+  constant_arr.FillIota(0);
+  std::string constant_str =
+      LiteralUtil::CreateR4FromArray4D(constant_arr).ToStringWithoutShape();
+
+  const std::string module_str = absl::StrFormat(R"(
+    HloModule test
+
+    ENTRY entry_computation {
+      param0 = f32[128,2,16,16]{3,2,1,0} parameter(0)
+      constant = f32[4,4,2,2]{3,2,1,0} constant(%s)
+      ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant),
+          window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2},
+          dim_labels=bf01_01oi->bf01, feature_group_count=1
+    })",
+                                                 constant_str);
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  EXPECT_TRUE(RunPass(m.get()));
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvBackwardInputCallTarget},
+                                m::Parameter(), m::Reverse(m::Constant())),
+                  0)));
+}
+
+TEST_F(ConvRewriterTest, TestBackwardFilterPatternMatch) {
+  // All filter dimensions are larger than the corresponding output dimensions.
+  // This must be a backward filter convolution.
+  const std::string module_str = absl::StrFormat(R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = f32[8,120,256,256] parameter(0)
+      filter = f32[8,120,256,256] parameter(1)
+
+      ROOT conv = f32[120,120,3,3] convolution(input, filter), window={size=256x256 pad=1_1x1_1}, dim_labels=fb01_io01->fb01
+    })");
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  EXPECT_TRUE(RunPass(m.get()));
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvBackwardFilterCallTarget},
+                                m::Parameter(0), m::Parameter(1)),
+                  0)));
+}
+
+TEST_F(ConvRewriterTest, TestBackwardFilterPatternNoMatch) {
+  // At least one filter dimension is smaller than the corresponding output
+  // dimension. This must be a forward convolution.
+  const std::string module_str = absl::StrFormat(R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = f32[8,128,2,32] parameter(0)
+      filter = f32[3,3,128,128] parameter(1)
+
+      ROOT conv = f32[8,128,2,32] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
+    })");
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  EXPECT_TRUE(RunPass(m.get()));
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvForwardCallTarget}, m::Parameter(0),
+                                m::Parameter(1)),
+                  0)));
+}
+
+TEST_F(ConvRewriterTest, TestConv1dBackwardFilterPatternMatch) {
+  // One kernel dimension is equal to the corresponding output dimension, so we
+  // regard the convolution as a backward filter convolution since it is 1D.
+  const std::string module_str = absl::StrFormat(R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = f32[8,256,128] parameter(0)
+      filter = f32[8,254,128] parameter(1)
+      reshape.1 = f32[8,1,256,128] reshape(input)
+      reshape.2 = f32[8,1,254,128] reshape(filter)
+      ROOT conv = f32[1,3,128,128] convolution(reshape.1, reshape.2), window={size=1x254}, dim_labels=f01b_i01o->01bf
+    })");
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  EXPECT_TRUE(RunPass(m.get()));
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvBackwardFilterCallTarget},
+                                m::Reshape(), m::Reshape()),
+                  0)));
+}
+
+TEST_F(ConvRewriterTest, TestConv1dBackwardInputPatternMatch) {
+// For conv1d backward input, the filter may be reversed first and then
+// reshaped.
+  const std::string module_str = absl::StrFormat(R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = f32[8,254,128] parameter(0)
+      filter = f32[3,128,128] parameter(1)
+      reverse = f32[3,128,128] reverse(filter), dimensions={0}
+      reshape.1 = f32[8,1,254,128] reshape(input)
+      reshape.2 = f32[1,3,128,128] reshape(reverse)
+      ROOT conv = f32[8,1,256,128] convolution(reshape.1, reshape.2), window={size=1x3 pad=0_0x2_2}, dim_labels=b01f_01oi->b01f
+    })");
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  EXPECT_TRUE(RunPass(m.get()));
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                  m::CustomCall({kCudnnConvBackwardInputCallTarget},
+                                m::Reshape(), m::Reshape()),
+                  0)));
+}
+
+TEST_F(ConvRewriterTest, TestInvalidTypes) {
+  const std::string module_str = absl::StrFormat(R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = TYPE[1,17,9,9] parameter(0)
+      filter = TYPE[3,3,17,32] parameter(1)
+      ROOT conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+    })");
+
+  // Test complex types
+  for (std::string_view type : {"c64", "c128"}) {
+    const std::string module_with_type =
+        absl::StrReplaceAll(module_str, {{"TYPE", type}});
+    TF_ASSERT_OK_AND_ASSIGN(auto m,
+                            ParseAndReturnVerifiedModule(module_with_type));
+
+    absl::Status s = ConvRewriter(GetComputeCapability()).Run(m.get()).status();
+    EXPECT_THAT(
+        s, tsl::testing::StatusIs(
+               absl::StatusCode::kUnimplemented,
+               ::testing::HasSubstr("Convolutions must have floating-point or "
+                                    "integral operands/outputs")));
+  }
+
+  // Test FP8 type on unsupported GPUs
+  std::string module_with_type =
+      absl::StrReplaceAll(module_str, {{"TYPE", "f8e4m3fn"}});
+  TF_ASSERT_OK_AND_ASSIGN(auto m,
+                          ParseAndReturnVerifiedModule(module_with_type));
+  absl::Status s =
+      ConvRewriter(se::CudaComputeCapability::Ampere()).Run(m.get()).status();
+  EXPECT_THAT(s, tsl::testing::StatusIs(
+                     absl::StatusCode::kUnimplemented,
+                     ::testing::HasSubstr(
+                         "FP8 convolutions are only supported on CUDA "
+                         "GPUs with compute capability at least 9.0")));
+  s = ConvRewriter(se::RocmComputeCapability{"gfx942"}).Run(m.get()).status();
+  EXPECT_THAT(s, tsl::testing::StatusIs(
+                     absl::StatusCode::kUnimplemented,
+                     ::testing::HasSubstr(
+                         "FP8 convolutions are only supported on CUDA GPUs")));
+
+  // Test unsupported FP8 type
+  module_with_type = absl::StrReplaceAll(module_str, {{"TYPE", "f8e4m3fnuz"}});
+  TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(module_with_type));
+  s = ConvRewriter(GetComputeCapability()).Run(m.get()).status();
+  EXPECT_THAT(s,
+              tsl::testing::StatusIs(
+                  absl::StatusCode::kUnimplemented,
+                  ::testing::HasSubstr("The only FP8 types supported in "
+                                       "convolutions are f8e5m2 and f8e4m3")));
+}
+
+}  // anonymous namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.cc b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.cc
new file mode 100644
index 0000000..a7dc96e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.cc
@@ -0,0 +1,80 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h"
+
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
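+// Marks each async start in `async_pairs` as synchronous by setting
+// is_sync=true in its CollectiveBackendConfig, then rewrites the computation's
+// schedule so that every done instruction immediately follows its start.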
+absl::Status GpuConvertAsyncCollectivesToSync::ConvertAsyncInstructionsToSync(
+    HloComputation* computation,
+    absl::Span<const std::pair<HloInstruction*, HloInstruction*>> async_pairs)
+    const {
+  absl::flat_hash_map<HloInstruction*, HloInstruction*> replaced_ops;
+  CollectiveBackendConfig sync_config;
+  sync_config.set_is_sync(true);
+  for (auto& [async_start, async_done] : async_pairs) {
+    // Tag the async start with is_sync = true.
+    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                        async_start->backend_config<GpuBackendConfig>());
+    *gpu_config.mutable_collective_backend_config() = sync_config;
+    TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config));
+    replaced_ops[async_start] = nullptr;
+    replaced_ops[async_done] = async_start;
+  }
+
+  // Update schedule.
+  HloModule* module = computation->parent();
+  const HloInstructionSequence& sequence =
+      module->schedule().sequence(computation);
+  std::vector<HloInstruction*> new_sequence;
+  new_sequence.reserve(sequence.size());
+  for (HloInstruction* instr : sequence.instructions()) {
+    auto it = replaced_ops.find(instr);
+    // If it is not a start or done op, add it to the new schedule.
+    if (it == replaced_ops.end()) {
+      new_sequence.push_back(instr);
+      continue;
+    }
+
+    // If it is a start op, do not add it to the schedule yet.
+    if (it->second == nullptr) {
+      continue;
+    }
+
+    // It is a done op: first add the matching start and then the done.
+    new_sequence.push_back(it->second);
+    new_sequence.push_back(instr);
+  }
+  module->schedule().set_sequence(computation, new_sequence);
+  return absl::OkStatus();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h
new file mode 100644
index 0000000..6507080
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h
@@ -0,0 +1,47 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_
+
+#include <utility>
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/convert_async_collectives_to_sync.h"
+
+namespace xla {
+namespace gpu {
+
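+// GPU-specific variant of ConvertAsyncCollectivesToSync. When an async
+// collective start/done pair is converted, the start instruction is tagged
+// with is_sync=true in its CollectiveBackendConfig and the module schedule is
+// updated so that the done immediately follows the start.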
+class GpuConvertAsyncCollectivesToSync : public ConvertAsyncCollectivesToSync {
+ public:
+  using ConvertAsyncCollectivesToSync::ConvertAsyncCollectivesToSync;
+  absl::string_view name() const override {
+    return "gpu-convert-async-collectives-to-sync";
+  }
+
+  absl::Status ConvertAsyncInstructionsToSync(
+      HloComputation* computation,
+      absl::Span<const std::pair<HloInstruction*, HloInstruction*>> async_pairs)
+      const override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc
new file mode 100644
index 0000000..d38ab70
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc
@@ -0,0 +1,347 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h"
+
+#include <string_view>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::testing::IsFalse;
+using ::testing::IsTrue;
+
+// Note: The pass only processes modules that are already scheduled. If a test
+// does not work as expected, make sure "is_scheduled=true" is present in the
+// HLO module string.
+class GpuConvertAsyncCollectivesToSyncTest : public HloTestBase {
+ public:
+  absl::Status RunPass(HloModule *module, bool expect_change,
+                       HloPredicate is_nop = {}) {
+    TF_ASSIGN_OR_RETURN(bool changed,
+                        GpuConvertAsyncCollectivesToSync{is_nop}.Run(module));
+    EXPECT_EQ(changed, expect_change);
+    return absl::OkStatus();
+  }
+
+  // Returns true if the instruction with the given name is synchronous.
+  bool IsSync(HloModule *module, std::string_view name) {
+    const HloInstruction *inst = FindInstruction(module, name);
+    if (inst == nullptr) {
+      return false;
+    }
+    auto backend_config = inst->backend_config<GpuBackendConfig>()
+                              .value()
+                              .collective_backend_config();
+    return backend_config.is_sync();
+  }
+
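+  // Predicate treating bitcasts, get-tuple-elements, and parameters as no-ops,
+  // i.e. instructions that may appear between an async start and its done
+  // without preventing the pair from being converted to a sync collective.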
+  HloPredicate is_nop_simple_ =
+      HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kGetTupleElement,
+                       HloOpcode::kParameter>;
+};
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduce) {
+  const absl::string_view hlo_string = R"(
+      HloModule test, is_scheduled=true
+
+      apply_op {
+        x = u32[] parameter(0)
+        y = u32[] parameter(1)
+        ROOT apply_op = u32[] add(x, y)
+      }
+
+      ENTRY test_computation {
+        id = u32[] replica-id()
+        start = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
+        ROOT done = u32[] all-reduce-done(start)
+      }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+  EXPECT_THAT(IsSync(module.get(), "start"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduceWithNop) {
+  const absl::string_view hlo_string = R"(
+      HloModule test, is_scheduled=true
+
+      apply_op {
+        x = u32[] parameter(0)
+        y = u32[] parameter(1)
+        ROOT apply_op = u32[] add(x, y)
+      }
+
+      ENTRY test_computation {
+        id = u32[] replica-id()
+        start = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3, replica_groups={{0,1}, {2,3}}
+        id2 = f32[] bitcast(id)
+        ROOT done = u32[] all-reduce-done(start)
+      }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true, is_nop_simple_));
+  EXPECT_THAT(IsSync(module.get(), "start"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleCollectiveBroadcast) {
+  const absl::string_view hlo_string = R"(
+  HloModule test, is_scheduled=true
+
+  collective_broadcast {
+    p0 = u32[8] parameter(0)
+    ROOT result = u32[8] collective-broadcast(p0), replica_groups={{0,1}, {2,3}}
+  }
+
+  ENTRY main {
+    data = u32[8] parameter(0)
+    cb-start = ((u32[8]{0}), u32[8]{0}) async-start(u32[8]{0} %data), calls=collective_broadcast
+    ROOT %ars = u32[8]{0} async-done(((u32[8]{0}), u32[8]{0}) %cb-start), calls=collective_broadcast
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+  EXPECT_THAT(IsSync(module.get(), "cb-start"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduceWithNonNop) {
+  const absl::string_view hlo_string = R"(
+      HloModule test, is_scheduled=true
+
+      apply_op {
+        x = u32[] parameter(0)
+        y = u32[] parameter(1)
+        ROOT apply_op = u32[] add(x, y)
+      }
+
+      ENTRY test_computation {
+        id = u32[] replica-id()
+        start = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
+        id2 = u32[] add(id, id)
+        ROOT done = u32[] all-reduce-done(start)
+      }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/false));
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllGather) {
+  const absl::string_view hlo_string = R"(
+  HloModule test, is_scheduled=true
+  ENTRY test_computation {
+    a1 = u32[1, 2] parameter(0)
+    ags = (u32[1, 2], u32[2, 2]) all-gather-start(a1), dimensions={0}, channel_id=3
+    ROOT allgather = u32[2,2] all-gather-done(ags)
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+  EXPECT_THAT(IsSync(module.get(), "ags"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleCollectivePermute) {
+  const absl::string_view hlo_string = R"(
+  HloModule test, is_scheduled=true
+
+  ENTRY test_computation {
+    p = u32[2] parameter(0)
+    start = (u32[2], u32[2], u32[], u32[]) collective-permute-start(p), source_target_pairs={{0,1}, {1,0}}
+    ROOT done = u32[2] collective-permute-done(start)
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+  EXPECT_THAT(IsSync(module.get(), "start"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleReduceScatter) {
+  const absl::string_view hlo_string = R"(
+  HloModule test, is_scheduled=true
+
+  add {
+    lhs = u32[] parameter(0)
+    rhs = u32[] parameter(1)
+    ROOT add = u32[] add(lhs, rhs)
+  }
+
+  reduce_scatter {
+    p0 = u32[8] parameter(0)
+    ROOT result = u32[4] reduce-scatter(p0), replica_groups={{0,3}, {1,2}},
+                      dimensions={0}, to_apply=add
+  }
+
+  ENTRY main {
+    data = u32[8] parameter(0)
+    rs-start = ((u32[8]{0}), u32[4]{0}) async-start(u32[8]{0} %data), calls=reduce_scatter
+    ROOT %ars = u32[4]{0} async-done(((u32[8]{0}), u32[4]{0}) %rs-start), calls=reduce_scatter
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+  EXPECT_THAT(IsSync(module.get(), "rs-start"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllToAll) {
+  const absl::string_view hlo_string = R"(
+  HloModule test, is_scheduled=true
+
+  all_to_all {
+    p0 = u32[2] parameter(0)
+    ROOT result = u32[2] all-to-all(p0), dimensions={0}, replica_groups={{0,1},{2,3}}
+  }
+
+  ENTRY test_computation {
+    a1 = u32[2] parameter(0)
+    a2a-start = ((u32[2]), u32[2]) async-start(u32[2] a1), calls=all_to_all
+    ROOT a2s = u32[2] async-done(a2a-start), calls=all_to_all
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+  EXPECT_THAT(IsSync(module.get(), "a2a-start"), IsTrue());
+}
+
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, ControlDeps) {
+  const absl::string_view hlo_string = R"(
+      HloModule test, is_scheduled=true
+
+      apply_op {
+        x = u32[] parameter(0)
+        y = u32[] parameter(1)
+        ROOT apply_op = u32[] add(x, y)
+      }
+
+      ENTRY test_computation {
+        id = u32[] replica-id()
+        start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
+        done1 = u32[] all-reduce-done(start1)
+        start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4, control-predecessors={done1}
+        done2 = u32[] all-reduce-done(start2)
+        ROOT x = u32[] add(done1, done2)
+      }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+  EXPECT_THAT(IsSync(module.get(), "start1"), IsTrue());
+  EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
+}
+
+// Test multiple in-flight collectives that are ordered in a streaming fashion:
+// i.e., ends are in start order (FIFO).
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, MultipleInFlightStreaming) {
+  const absl::string_view hlo_string = R"(
+      HloModule test, is_scheduled=true
+
+      apply_op {
+        x = u32[] parameter(0)
+        y = u32[] parameter(1)
+        ROOT apply_op = u32[] add(x, y)
+      }
+
+      ENTRY test_computation {
+        id = u32[] replica-id()
+        start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
+        start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4
+        done1 = u32[] all-reduce-done(start1)
+        done2 = u32[] all-reduce-done(start2)
+        ROOT x = u32[] add(done1, done2)
+      }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+  EXPECT_THAT(IsSync(module.get(), "start1"), IsTrue());
+  EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
+}
+
+// Test multiple in-flight collectives that are nested: {s0,{s1,e1},e0}
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, MultipleInFlightNested) {
+  const absl::string_view hlo_string = R"(
+      HloModule test, is_scheduled=true
+
+      apply_op {
+        x = u32[] parameter(0)
+        y = u32[] parameter(1)
+        ROOT apply_op = u32[] add(x, y)
+      }
+
+      ENTRY test_computation {
+        id = u32[] replica-id()
+        start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
+        start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4
+        done2 = u32[] all-reduce-done(start2)
+        done1 = u32[] all-reduce-done(start1)
+        ROOT x = u32[] add(done1, done2)
+      }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+  EXPECT_THAT(IsSync(module.get(), "start1"), IsTrue());
+  EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
+}
+
+// Test multiple in-flight collectives that are nested: {s0,{s1,e1},e0}, where
+// the inner pair can be converted but the outer one cannot.
+TEST_F(GpuConvertAsyncCollectivesToSyncTest, MultipleInFlightNestedPartial) {
+  const absl::string_view hlo_string = R"(
+      HloModule test, is_scheduled=true
+
+      apply_op {
+        x = u32[] parameter(0)
+        y = u32[] parameter(1)
+        ROOT apply_op = u32[] add(x, y)
+      }
+
+      ENTRY test_computation {
+        id = u32[] replica-id()
+        start1 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=3
+        start2 = u32[] all-reduce-start(id), to_apply=apply_op, channel_id=4
+        done2 = u32[] all-reduce-done(start2)
+        id2 = u32[] add(done2, done2)
+        done1 = u32[] all-reduce-done(start1)
+        ROOT x = u32[] add(done1, done2)
+      }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true));
+  EXPECT_THAT(IsSync(module.get(), "start1"), IsFalse());
+  EXPECT_THAT(IsSync(module.get(), "start2"), IsTrue());
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc
new file mode 100644
index 0000000..eb43ca2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc
@@ -0,0 +1,197 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/copy_fusion.h"
+
+#include <cstdint>
+#include <queue>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/reduction_utils.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
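+// Returns true if every instruction reachable from the parameters of
+// `fused_computation` is either elementwise on the traversed operand (and not
+// a copy), a bitcast, or a tuple.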
+bool OnlyElementwiseOpsReachableFromParams(HloComputation* fused_computation) {
+  std::queue<const HloInstruction*> q;
+  absl::flat_hash_set<const HloInstruction*> visited;
+  for (auto param : fused_computation->parameter_instructions()) {
+    q.push(param);
+    visited.insert(param);
+  }
+  while (!q.empty()) {
+    const HloInstruction* hlo = q.front();
+    q.pop();
+    for (auto user : hlo->users()) {
+      if ((!user->IsElementwiseOnOperand(user->operand_index(hlo)) ||
+           user->opcode() == HloOpcode::kCopy) &&
+          user->opcode() != HloOpcode::kBitcast &&
+          user->opcode() != HloOpcode::kTuple) {
+        return false;
+      }
+      if (visited.insert(user).second) {
+        q.push(user);
+      }
+    }
+  }
+  return true;
+}
+
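+// For every fusion in `computation`, finds copies of the fusion's outputs that
+// can be folded into it, clones the copied values into the fused computation,
+// turns the fusion root into a tuple that also yields the copied results, and
+// replaces the original copies with get-tuple-elements of the fusion.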
+absl::StatusOr<bool> CopyFusion::DoCopyFusion(HloComputation* computation) {
+  bool changed = false;
+  std::vector<HloInstruction*> defs_before_uses =
+      computation->MakeInstructionPostOrder();
+
+  for (HloInstruction* hlo : defs_before_uses) {
+    if (hlo->opcode() != HloOpcode::kFusion) {
+      continue;
+    }
+    std::vector<HloInstruction*> copies;
+    std::vector<HloInstruction*> other_users;
+    HloComputation* fused_computation = hlo->fused_instructions_computation();
+    if (!OnlyElementwiseOpsReachableFromParams(fused_computation)) {
+      continue;
+    }
+    HloInstruction* root = fused_computation->root_instruction();
+    if (IsReductionFromOrToContiguousDimensions(*root) ||
+        root->opcode() == HloOpcode::kScatter ||
+        (hlo->IsMultiOutputFusion() &&
+         absl::c_all_of(root->operands(), [](const HloInstruction* slice) {
+           return slice->opcode() == HloOpcode::kSlice;
+         }))) {
+      continue;
+    }
+    for (auto user : hlo->users()) {
+      HloInstruction* copy_user = user;
+      // Look through get-tuple-element ops that have a single user.
+      if (copy_user->opcode() == HloOpcode::kGetTupleElement &&
+          copy_user->user_count() == 1) {
+        if (IsReductionFromOrToContiguousDimensions(
+                *(root->operand(copy_user->tuple_index())))) {
+          other_users.push_back(user);
+          continue;
+        }
+        copy_user = copy_user->users()[0];
+      }
+      // Look through single-user bitcast ops.
+      if (copy_user->opcode() == HloOpcode::kBitcast &&
+          copy_user->user_count() == 1) {
+        copy_user = copy_user->users()[0];
+      }
+      if (copy_user->opcode() == HloOpcode::kCopy &&
+          copy_user->shape() == copy_user->operand(0)->shape() &&
+          !copy_user->shape().IsTuple() &&
+          !copy_user->HasControlDependencies()) {
+        copies.push_back(copy_user);
+      } else {
+        other_users.push_back(user);
+      }
+    }
+    if (copies.empty()) {
+      continue;
+    }
+    auto fusion_adaptor = HloFusionAdaptor::ForComputation(fused_computation);
+    auto dynamic_update_slices =
+        GetOutputDefiningDynamicUpdateSlices(fusion_adaptor->GetRoots());
+    // Skip dynamic update slice fusions which might be emitted in-place.
+    if (!dynamic_update_slices.empty() &&
+        (root->opcode() != HloOpcode::kTuple ||
+         dynamic_update_slices.size() == root->shape().tuple_shapes_size())) {
+      continue;
+    }
+    changed = true;
+
+    HloInstruction::InstructionVector tuple_elements;
+    int64_t num_outputs =
+        hlo->IsMultiOutputFusion() ? root->operand_count() : int64_t{1};
+    tuple_elements.reserve(copies.size() + num_outputs);
+    if (hlo->IsMultiOutputFusion()) {
+      for (HloInstruction* operand : root->operands()) {
+        tuple_elements.push_back(operand);
+      }
+    } else {
+      tuple_elements.push_back(root);
+    }
+
+    for (auto copy : copies) {
+      HloInstruction* user = copy;
+      std::vector<HloInstruction*> operand_chain;
+      operand_chain.push_back(user);
+      while (user->operand(0) != hlo) {
+        user = user->mutable_operand(0);
+        operand_chain.push_back(user);
+      }
+      HloInstruction* clone_operand = root;
+      if (hlo->IsMultiOutputFusion()) {
+        clone_operand = root->mutable_operand(user->tuple_index());
+        CHECK_EQ(operand_chain.back()->opcode(), HloOpcode::kGetTupleElement);
+        operand_chain.pop_back();
+      }
+      for (int64_t i = operand_chain.size() - 1; i >= 0; --i) {
+        HloInstruction* user = operand_chain[i];
+        clone_operand = fused_computation->AddInstruction(
+            user->CloneWithNewOperands(user->shape(), {clone_operand}));
+      }
+      tuple_elements.push_back(clone_operand);
+    }
+
+    HloInstruction* new_root = fused_computation->AddInstruction(
+        HloInstruction::CreateTuple(tuple_elements));
+    fused_computation->set_root_instruction(new_root,
+                                            /*accept_different_shape=*/true);
+    *hlo->mutable_shape() = new_root->shape();
+
+    if (root->opcode() == HloOpcode::kTuple) {
+      TF_RETURN_IF_ERROR(fused_computation->RemoveInstruction(root));
+    } else {
+      auto get_tuple_element_root = computation->AddInstruction(
+          HloInstruction::CreateGetTupleElement(hlo, 0));
+      TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWithDifferentShape(
+          other_users, get_tuple_element_root));
+    }
+    for (int64_t i = 0; i < copies.size(); ++i) {
+      auto get_tuple_element = computation->AddInstruction(
+          HloInstruction::CreateGetTupleElement(hlo, num_outputs + i));
+      TF_RETURN_IF_ERROR(
+          computation->ReplaceInstruction(copies[i], get_tuple_element));
+    }
+  }
+  return changed;
+}
+
+absl::StatusOr<bool> CopyFusion::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  // Only for the entry computation can we be sure that the copies do not share
+  // a buffer with a parameter of the fusion that they will be fused with. For
+  // example, while-loop computations have tuple parameters that need to share
+  // buffers with the output tuples, and copies inserted by the CopyInsertion
+  // pass will share a buffer with the tuple output (and thus with the tuple
+  // input as well).
+  return DoCopyFusion(module->entry_computation());
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion.h b/third_party/xla/xla/service/gpu/transforms/copy_fusion.h
new file mode 100644
index 0000000..a6a1ae4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion.h
@@ -0,0 +1,49 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_COPY_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_COPY_FUSION_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// CopyFusion checks whether a fusion is followed by one or more copies and, if
+// so, adds those copies to the fusion, replacing them with get-tuple-elements.
+class CopyFusion : public HloModulePass {
+ public:
+  CopyFusion() = default;
+
+  absl::string_view name() const override { return "copy_fusion"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  absl::StatusOr<bool> DoCopyFusion(HloComputation* computation);
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_COPY_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc
new file mode 100644
index 0000000..1bd2d11
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc
@@ -0,0 +1,500 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/copy_fusion.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/str_cat.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+
+namespace m = ::xla::match;
+
+class CopyFusionTest : public HloTestBase {
+ public:
+  CopyFusion cf_;
+};
+
+const char kModulePrefix[] = R"(
+    HloModule test_module
+
+    scalar_add_computation {
+      scalar_lhs.0 = f32[] parameter(0)
+      scalar_rhs.0 = f32[] parameter(1)
+      ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0)
+    }
+    scalar_mul_computation {
+      scalar_lhs.1 = f32[] parameter(0)
+      scalar_rhs.1 = f32[] parameter(1)
+      ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
+    })";
+
+TEST_F(CopyFusionTest, CopyFusionTransposeOfBroadcastedConstantTwoCopies) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      two = f32[] constant(2.0)
+      broadcast = f32[16,32]{1,0} broadcast(two), dimensions={}
+      s.1 = f32[16,32]{1,0} sqrt(broadcast)
+      ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
+    }
+
+    ENTRY main {
+      fusion = f32[32,16]{1,0} fusion(), kind=kInput, calls=fused_computation
+      copy.1 = f32[32,16]{1,0} copy(fusion)
+      copy.2 = f32[32,16]{1,0} copy(fusion)
+      ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(copy.2, copy.1)
+    })"))
+                    .value();
+  ASSERT_TRUE(cf_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                        m::GetTupleElement())));
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Transpose(), m::Copy(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionTransposeTwoCopies) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      param_0.1 = f32[16,32]{1,0} parameter(0)
+      s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+      ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
+    }
+
+    ENTRY main {
+      p = f32[16,32]{1,0} parameter(0)
+      fusion = f32[32,16]{1,0} fusion(p), kind=kInput, calls=fused_computation
+      copy.1 = f32[32,16]{1,0} copy(fusion)
+      copy.2 = f32[32,16]{1,0} copy(fusion)
+      ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(copy.2, copy.1)
+    })"))
+                    .value();
+  ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionNegateAndTwoCopies) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+      ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
+    }
+
+    ENTRY entry {
+      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
+      copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
+      copy.2 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
+      ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, copy.2)
+    })"))
+                    .value();
+  ASSERT_TRUE(cf_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                        m::GetTupleElement())));
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Negate(), m::Copy(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionShouldNotRunWithReduce) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+      const.1 = f32[] parameter(0)
+      ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[] parameter(0)
+      p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+      fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation
+      copy.1 = f32[512]{0} copy(fusion)
+      copy.2 = f32[512]{0} copy(fusion)
+      ROOT root = (f32[512]{0}, f32[512]{0}) tuple(copy.1, copy.2)
+    })"))
+                    .value();
+  ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionShouldRunWithUncopiedReduce) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      two = f32[] constant(2.0)
+      broadcast = f32[128,512,28,28]{3,2,1,0} broadcast(two)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(broadcast, broadcast)
+      const = f32[] constant(0.0)
+      reduce = f32[512]{0} reduce(mul, const), dimensions={0,2,3}, to_apply=scalar_add_computation
+      ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[512]{0}) tuple(mul, reduce)
+    }
+
+    ENTRY entry {
+      fusion = (f32[128,512,28,28]{3,2,1,0}, f32[512]) fusion(), kind=kInput, calls=fused_computation
+      gte = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=0
+      gte.2 = f32[512]{0} get-tuple-element(fusion), index=1
+      copy.1 = f32[128,512,28,28]{3,2,1,0} copy(gte)
+      ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512]{0}) tuple(copy.1, gte.2)
+    })"))
+                    .value();
+  ASSERT_TRUE(cf_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                        m::GetTupleElement())));
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Multiply(), m::Reduce(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionShouldNotFuseForSliceMultioutputFusion) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      p1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1)
+      slice1 = f32[128,100,28,28]{3,2,1,0} slice(mul), slice={[0:128],[0:100],[0:28],[0:28]}
+      slice2 = f32[128,200,28,28]{3,2,1,0} slice(mul), slice={[0:128],[50:250],[0:28],[0:28]}
+      ROOT tuple = (f32[128,100,28,28]{3,2,1,0}, f32[128,200,28,28]{3,2,1,0}) tuple(slice1, slice2)
+    }
+
+    ENTRY entry {
+      p1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      ROOT fusion = (f32[128,100,28,28]{3,2,1,0}, f32[128,200,28,28]{3,2,1,0}) fusion(p1), kind=kInput, calls=fused_computation
+    })"))
+                    .value();
+  ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionShouldNotRunWithScatter) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      p0 = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
+      scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+      updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
+      input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} negate(p0)
+      ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(input_tensor, scatter_indices, updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=scalar_add_computation
+}
+
+    ENTRY entry {
+      param.0 = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
+      param.1 = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+      param.2 = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
+      fusion = f32[50,49,48,47,46]{4,3,2,1,0} fusion(param.0, param.1, param.2), kind=kInput, calls=fused_computation
+      ROOT copy = f32[50,49,48,47,46]{4,3,2,1,0} copy(fusion)
+    })"))
+                    .value();
+  ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionShouldNotRunOutsideEntryComputation) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+fused_computation.549 {
+  param_0.8511 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(0)
+  bitcast.52601 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_0.8511)
+  slice = bf16[15,1,2,48,128,1]{5,4,3,2,1,0} slice(bitcast.52601), slice={[0:15:1], [0:1:1], [0:2:1], [0:48:1], [0:128:1], [0:1:1]}
+  bitcast = bf16[15,1,2,48,128]{4,3,2,1,0} bitcast(slice)
+  ROOT broadcast = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} broadcast(bitcast), dimensions={0,1,2,3,4}
+}
+
+condition {
+  constant_6915 = s32[] constant(15)
+  param.218 = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) parameter(0)
+  get-tuple-element.3714 = s32[] get-tuple-element(param.218), index=1
+  ROOT compare.1738 = pred[] compare(get-tuple-element.3714, constant_6915), direction=LT
+}
+
+body {
+  tuple_param = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) parameter(0)
+  param_0 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(tuple_param), index=0
+  param_1 = s32[] get-tuple-element(tuple_param), index=1
+  fusion.549 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} fusion(param_0), kind=kLoop, calls=fused_computation.549
+  bitcast = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} bitcast(fusion.549)
+  copy = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(bitcast)
+  constant_one = s32[] constant(1)
+  add = s32[] add(param_1, constant_one), control-predecessors={fusion.549}
+  ROOT tuple = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) tuple(copy, add)
+}
+
+ENTRY main {
+  param_0 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(0)
+  zero = s32[] constant(0)
+  copy.0 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(param_0)
+  copy.1 = s32[] copy(zero)
+  tuple = tuple(copy.0, copy.1)
+  ROOT while = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, s32[]) while(tuple), condition=condition, body=body, backend_config="{\"known_trip_count\":{\"n\":\"15\"}}"
+})"))
+                    .value();
+  ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionShouldNotRunWithDynamicUpdateSliceInplace) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      p.0 = f16[50,96,1024]{2,1,0} parameter(0)
+      p.1 = f16[1,96,1024]{2,1,0} parameter(1)
+      c.0 = s32[3]{0} constant({0, 0, 0})
+      ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
+    }
+
+    ENTRY entry {
+      p0 = f16[50,96,1024]{2,1,0} parameter(0)
+      p1 = f16[1,96,1024]{2,1,0} parameter(1)
+      fusion = f16[50,96,1024]{2,1,0} fusion(p0, p1), kind=kInput, calls=fused_computation
+      copy.1 = f16[50,96,1024]{2,1,0} copy(fusion)
+      copy.2 = f16[50,96,1024]{2,1,0} copy(fusion)
+      ROOT root = (f16[50,96,1024]{2,1,0}, f16[50,96,1024]{2,1,0}) tuple(copy.1, copy.2)
+    })"))
+                    .value();
+  ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionWithDynamicUpdateSliceNotInplace) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      one = f32[] constant(1.0)
+      zero = f32[] constant(0.0)
+      p.0 = f16[50,96,1024]{2,1,0} broadcast(one), dimensions={}
+      p.1 = f16[1,96,1024]{2,1,0} broadcast(zero), dimensions={}
+      c.0 = s32[3]{0} constant({0, 0, 0})
+      dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
+      neg = f16[50,96,1024]{2,1,0} negate(dynamic-update-slice)
+      ROOT tuple = (f16[50,96,1024]{2,1,0}, f16[50,96,1024]{2,1,0}) tuple(dynamic-update-slice, neg)
+    }
+
+    ENTRY entry {
+      fusion = (f16[50,96,1024]{2,1,0}, f16[50,96,1024]{2,1,0}) fusion(), kind=kInput, calls=fused_computation
+      gte.0 = f16[50,96,1024]{2,1,0} get-tuple-element(fusion), index=0
+      gte.1 = f16[50,96,1024]{2,1,0} get-tuple-element(fusion), index=1
+      bitcast = f16[1,50,96,1024]{3,2,1,0} bitcast(gte.0)
+      copy = f16[1,50,96,1024]{3,2,1,0} copy(bitcast)
+      ROOT root = (f16[1,50,96,1024]{3,2,1,0}, f16[50,96,1024]{2,1,0}) tuple(copy, gte.1)
+    })"))
+                    .value();
+  ASSERT_TRUE(cf_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                        m::GetTupleElement())));
+  EXPECT_THAT(
+      fusion->fused_expression_root(),
+      GmockMatch(m::Tuple(m::DynamicUpdateSlice(), m::Negate(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionTransposeAndThreeCopies) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      two = f32[] constant(2.0)
+      param_0.1 = f32[16,32]{1,0} broadcast(two), dimensions={}
+      s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+      ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
+    }
+
+    ENTRY entry {
+      fusion = f32[32,16]{1,0} fusion(), kind=kInput, calls=fused_computation
+      copy.1 = f32[32,16]{1,0} copy(fusion)
+      copy.2 = f32[32,16]{1,0} copy(fusion)
+      copy.3 = f32[32,16]{1,0} copy(fusion)
+      ROOT root = (f32[32,16]{1,0}, f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(copy.1, copy.2, copy.3)
+    })"))
+                    .value();
+  ASSERT_TRUE(cf_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root,
+              GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                  m::GetTupleElement(), m::GetTupleElement())));
+  EXPECT_THAT(
+      fusion->fused_expression_root(),
+      GmockMatch(m::Tuple(m::Transpose(), m::Copy(), m::Copy(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionRunWithOnlyOneCopy) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+      ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
+    }
+
+    ENTRY entry {
+      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
+      ROOT copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
+    })"))
+                    .value();
+  ASSERT_TRUE(cf_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::GetTupleElement(m::Fusion(&fusion))));
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Negate(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionNegateAndTwoCopiesAndTransposeCopy) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+      ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
+    }
+
+    ENTRY entry {
+      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
+      copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
+      transpose = f32[128,512,28,28]{2,3,0,1} copy(fusion)
+      bitcast = f32[512,128,28,28]{3,2,1,0} bitcast(transpose)
+      copy.2 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
+      ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, bitcast, copy.2)
+    })"))
+                    .value();
+  ASSERT_TRUE(cf_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                        m::Bitcast(), m::GetTupleElement())));
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Negate(), m::Copy(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionRunWithOnlyOneNonTransposeCopy) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+      ROOT neg = f32[128,512,28,28]{3,2,1,0} negate(mul)
+    }
+
+    ENTRY entry {
+      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      fusion = f32[128,512,28,28]{3,2,1,0} fusion(p0), kind=kInput, calls=fused_computation
+      copy.1 = f32[128,512,28,28]{3,2,1,0} copy(fusion)
+      transpose.1 = f32[128,512,28,28]{2,3,0,1} copy(fusion)
+      bitcast.1 = f32[512,128,28,28]{3,2,1,0} bitcast(transpose.1)
+      transpose.2 = f32[128,512,28,28]{2,3,0,1} copy(fusion)
+      bitcast.2 = f32[512,128,28,28]{3,2,1,0} bitcast(transpose.2)
+      ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}) tuple(copy.1, bitcast.1, bitcast.2)
+    })"))
+                    .value();
+  ASSERT_TRUE(cf_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                        m::Bitcast(), m::Bitcast())));
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Negate(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionSkipTupleCopies) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+      neg.1 = f32[128,512,28,28]{3,2,1,0} negate(mul)
+      neg.2 = f32[128,512,28,28]{3,2,1,0} negate(mul)
+      ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(neg.1, neg.2)
+    }
+
+    ENTRY entry {
+      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      fusion = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_computation
+      copy.1 = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) copy(fusion)
+      copy.2 = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) copy(fusion)
+      ROOT root = ((f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}),(f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0})) tuple(copy.1, copy.2)
+    })"))
+                    .value();
+  ASSERT_FALSE(cf_.Run(module.get()).value());
+}
+
+TEST_F(CopyFusionTest, CopyFusionTupleAndGetTuple) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+      neg.1 = f32[128,512,28,28]{3,2,1,0} negate(mul)
+      neg.2 = f32[128,512,28,28]{3,2,1,0} negate(mul)
+      ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(neg.1, neg.2)
+    }
+
+    ENTRY entry {
+      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      fusion = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_computation
+      gte.1 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=0
+      gte.2 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=1
+      copy.1 = f32[128,512,28,28]{3,2,1,0} copy(gte.1)
+      copy.2 = f32[128,512,28,28]{3,2,1,0} copy(gte.2)
+      ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, copy.2)
+    })"))
+                    .value();
+  ASSERT_TRUE(cf_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                        m::GetTupleElement())));
+  EXPECT_THAT(
+      fusion->fused_expression_root(),
+      GmockMatch(m::Tuple(m::Negate(), m::Negate(), m::Copy(), m::Copy())));
+}
+
+TEST_F(CopyFusionTest, CopyFusionWithFusionReturningTupleAndOtherUser) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+      neg.1 = f32[128,512,28,28]{3,2,1,0} negate(mul)
+      neg.2 = f32[128,512,28,28]{3,2,1,0} negate(mul)
+      ROOT tuple = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(neg.1, neg.2)
+    }
+
+    ENTRY entry {
+      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      fusion = (f32[128,512,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_computation
+      gte.1 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=0
+      gte.2 = f32[128,512,28,28]{3,2,1,0} get-tuple-element(fusion), index=1
+      copy.1 = f32[128,512,28,28]{3,2,1,0} copy(gte.1)
+      copy.2 = f32[128,512,28,28]{3,2,1,0} copy(gte.2)
+      transpose = f32[128,512,28,28]{2,3,0,1} copy(gte.1)
+      bitcast = f32[512,128,28,28]{3,2,1,0} bitcast(transpose)
+      ROOT root = (f32[128,512,28,28]{3,2,1,0}, f32[512,128,28,28]{3,2,1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(copy.1, bitcast, copy.2)
+    })"))
+                    .value();
+  ASSERT_TRUE(cf_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root,
+              GmockMatch(m::Tuple(m::Copy(), m::Bitcast(),
+                                  m::GetTupleElement(m::Fusion(&fusion)))));
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Negate(), m::Negate(), m::Copy())));
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc
new file mode 100644
index 0000000..43f1242
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc
@@ -0,0 +1,210 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/fusions/triton/triton_support.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/transforms/gemm_fusion.h"
+#include "xla/shape.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+static absl::StatusOr<bool> PadForGemm(HloDotInstruction* dot,
+                                       PrimitiveType datatype,
+                                       int pad_to_multiple_of) {
+  auto* lhs = dot->mutable_operand(0);
+  auto* rhs = dot->mutable_operand(1);
+
+  Shape lshape = lhs->shape();
+  Shape rshape = rhs->shape();
+  Shape result_shape = dot->shape();
+
+  if (lshape.element_type() != datatype || rshape.element_type() != datatype) {
+    return false;
+  }
+
+  auto pad_dim = [&](Shape& s, int dim) {
+    s.set_dimensions(dim,
+                     RoundUpTo<int64_t>(s.dimensions(dim), pad_to_multiple_of));
+  };
+
+  auto pad_matrix_dims = [&pad_dim](Shape s) {
+    // Since the dot instruction is canonicalized, the last two dimensions for
+    // each operand represent non-batch dimensions, and the others are the same
+    // for both operands and correspond to batch dimensions.
+    pad_dim(s, s.rank() - 2);
+    pad_dim(s, s.rank() - 1);
+    return s;
+  };
+
+  Shape new_lshape = pad_matrix_dims(lshape);
+  Shape new_rshape = pad_matrix_dims(rshape);
+  Shape new_result_shape = pad_matrix_dims(result_shape);
+
+  if (new_lshape == lshape && new_rshape == rshape) {
+    return false;
+  }
+
+  VLOG(3) << "old shape: " << lshape << " " << rshape << " " << result_shape;
+  VLOG(3) << "new shape: " << new_lshape << " " << new_rshape << " "
+          << new_result_shape;
+
+  auto create_padding_config = [](Shape& shape, Shape& new_shape) {
+    PaddingConfig padding_config;
+    for (int i = 0; i < shape.rank(); ++i) {
+      auto dimension = padding_config.add_dimensions();
+      dimension->set_edge_padding_high(new_shape.dimensions()[i] -
+                                       shape.dimensions()[i]);
+      dimension->set_edge_padding_low(0);
+      dimension->set_interior_padding(0);
+    }
+    return padding_config;
+  };
+
+  auto l_padding_config = create_padding_config(lshape, new_lshape);
+  auto r_padding_config = create_padding_config(rshape, new_rshape);
+
+  HloComputation* parent = dot->parent();
+
+  HloInstruction* zero_float = parent->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::Zero(datatype)));
+  zero_float->set_metadata(dot->metadata());
+
+  HloInstruction* lpad = parent->AddInstruction(
+      HloInstruction::CreatePad(new_lshape, lhs, zero_float, l_padding_config));
+  lpad->set_metadata(dot->metadata());
+
+  HloInstruction* rpad = parent->AddInstruction(
+      HloInstruction::CreatePad(new_rshape, rhs, zero_float, r_padding_config));
+  rpad->set_metadata(dot->metadata());
+
+  HloInstruction* new_dot = parent->AddInstruction(
+      dot->CloneWithNewOperands(new_result_shape, {lpad, rpad}));
+
+  std::vector<int64_t> start_indices(result_shape.rank(), 0);
+  std::vector<int64_t> strides(result_shape.rank(), 1);
+  HloInstruction* slice = parent->AddInstruction(
+      HloInstruction::CreateSlice(result_shape, new_dot, start_indices,
+                                  result_shape.dimensions(), strides));
+  slice->set_metadata(dot->metadata());
+
+  bool is_root = dot->user_count() == 0;
+
+  TF_CHECK_OK(parent->ReplaceInstruction(dot, slice));
+
+  if (is_root) {
+    parent->set_root_instruction(slice);
+  }
+
+  return true;
+}
+
+namespace {
+
+// We need this check because PadForGemm works under the assumption that
+// the dot instruction is canonicalized.
+bool CheckCanonical(HloDotInstruction* dot) {
+  const auto& dimension_numbers = dot->dot_dimension_numbers();
+
+  if (dimension_numbers.lhs_batch_dimensions_size() + 2 !=
+          dot->operand(0)->shape().rank() ||
+      dimension_numbers.rhs_batch_dimensions_size() + 2 !=
+          dot->operand(1)->shape().rank()) {
+    VLOG(2)
+        << dot->ToString()
+        << " is not canonical: Expected all dimensions but 2 to be "
+           "batch_dimensions. Hence, this dot is not a candidate for padding.";
+    return false;
+  }
+
+  std::vector<int64_t> canonical_batch_dims(
+      dimension_numbers.lhs_batch_dimensions_size());
+  absl::c_iota(canonical_batch_dims, 0);
+  if (!absl::c_equal(dimension_numbers.lhs_batch_dimensions(),
+                     canonical_batch_dims) ||
+      !absl::c_equal(dimension_numbers.rhs_batch_dimensions(),
+                     canonical_batch_dims)) {
+    VLOG(2)
+        << dot->ToString()
+        << " is not canonical: Expected batch dimensions to be all "
+           "dimensions except for the last 2 ones. Hence, this dot is not a "
+           "candidate for padding.";
+    return false;
+  }
+
+  return true;
+}
+
+}  // namespace
+
+static std::vector<HloDotInstruction*> GetRelevantDots(
+    const se::GpuComputeCapability& gpu_compute_capability,
+    HloComputation* comp, PrimitiveType datatype) {
+  std::vector<HloDotInstruction*> gemms;
+
+  for (HloInstruction* instr : comp->instructions()) {
+    if (IsMatrixMultiplication(*instr)) {
+      HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
+      if (instr->operand(0)->shape().element_type() == datatype &&
+          CheckCanonical(dot) &&
+          !(instr->GetModule()
+                ->config()
+                .debug_options()
+                .xla_gpu_enable_triton_gemm() &&
+            legacy_triton::IsTritonSupportedInstruction(
+                *dot, gpu_compute_capability) &&
+            ShouldTritonHandleGEMM(*dot, gpu_compute_capability))) {
+        gemms.push_back(dot);
+      }
+    }
+  }
+  return gemms;
+}
+
+absl::StatusOr<bool> CublasPadForGemms::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    for (HloDotInstruction* dot :
+         GetRelevantDots(gpu_compute_capability_, comp, datatype_)) {
+      TF_ASSIGN_OR_RETURN(bool result,
+                          PadForGemm(dot, datatype_, pad_to_multiple_of_));
+      changed |= result;
+    }
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
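Illustrative sketch, not part of the patch: PadForGemm above rounds only the two
trailing (non-batch) matrix dimensions up to pad_to_multiple_of and later slices
the result back to the original shape. A minimal standalone C++ sketch of that
shape math, assuming pad_to_multiple_of == 8 and hypothetical dimensions:

#include <cstdint>
#include <iostream>

// Same rounding as xla::RoundUpTo, reimplemented here so the sketch is
// self-contained.
int64_t RoundUpTo(int64_t value, int64_t multiple) {
  return ((value + multiple - 1) / multiple) * multiple;
}

int main() {
  // Hypothetical f16 GEMM: [2048,1024] x [1024,33708] -> [2048,33708].
  const int64_t m = 2048, k = 1024, n = 33708, pad_to = 8;
  const int64_t padded_k = RoundUpTo(k, pad_to);  // 1024, already a multiple.
  const int64_t padded_n = RoundUpTo(n, pad_to);  // 33712.
  std::cout << "padded dot: [" << m << "," << padded_k << "] x [" << padded_k
            << "," << padded_n << "] -> [" << m << "," << padded_n << "]\n";
  // PadForGemm then slices the [2048,33712] result back to [2048,33708].
  return 0;
}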
diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h
new file mode 100644
index 0000000..8c7d8e5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h
@@ -0,0 +1,63 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUBLAS_PAD_FOR_GEMMS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUBLAS_PAD_FOR_GEMMS_H_
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// Adds padding to dot operations to make them run faster on GPUs.
+//
+// This can be used to pad f16 dots so they run on tensor cores, or to pad s8
+// dots to multiples of four.
+//
+// This pass depends on the xla::DotDecomposer pass, so it must run strictly
+// after it.
+class CublasPadForGemms : public HloModulePass {
+ public:
+  CublasPadForGemms(const se::GpuComputeCapability gpu_compute_capability,
+                    PrimitiveType datatype, int32_t pad_to_multiple_of)
+      : gpu_compute_capability_(gpu_compute_capability),
+        datatype_(datatype),
+        pad_to_multiple_of_(pad_to_multiple_of) {}
+
+  absl::string_view name() const override { return "cublas-pad-for-gemms"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  const se::GpuComputeCapability gpu_compute_capability_;
+  PrimitiveType datatype_;
+  int32_t pad_to_multiple_of_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CUBLAS_PAD_FOR_GEMMS_H_
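A hedged usage sketch of the pass declared above, assuming it is wired into an
xla::HloPassPipeline after DotDecomposer like other GPU passes; the helper
function name and compute-capability argument are placeholders, and the
multiples (8 for f16, 4 for s8) mirror the values used in the tests below:

#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h"
#include "xla/service/hlo_pass_pipeline.h"

// Hypothetical helper: registers the padding passes in a pipeline that has
// already run DotDecomposer.
void AddGemmPaddingPasses(xla::HloPassPipeline& pipeline,
                          const stream_executor::GpuComputeCapability& cc) {
  pipeline.AddPass<xla::gpu::CublasPadForGemms>(cc, xla::PrimitiveType::F16, 8);
  pipeline.AddPass<xla::gpu::CublasPadForGemms>(cc, xla::PrimitiveType::S8, 4);
}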
diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc
new file mode 100644
index 0000000..77a32c9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc
@@ -0,0 +1,306 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace m = ::xla::match;
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CublasGemmPadForTensorCoresTest : public HloTestBase {
+ protected:
+  bool PadForF16Gemms(HloModule* module) {
+    return CublasPadForGemms(se::CudaComputeCapability(7, 0),
+                             PrimitiveType::F16, 8)
+        .Run(module)
+        .value();
+  }
+
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+    // Some pads are not added if we detect that Triton will handle the given
+    // dot operation.
+    debug_options.set_xla_gpu_triton_gemm_any(false);
+    return debug_options;
+  }
+};
+
+TEST_F(CublasGemmPadForTensorCoresTest, OneDotRootComputation) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    %param1 = f16[2048,1024] parameter(0)
+    %param2 = f16[1024,33708] parameter(1)
+    ROOT %dot.2309 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1,
+                f16[1024,33708]{0,1} %param2),
+                lhs_contracting_dims={1}, rhs_contracting_dims={0}
+                })")
+                    .value();
+
+  EXPECT_TRUE(PadForF16Gemms(module.get()));
+  SCOPED_TRACE(module->ToString());
+
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      GmockMatch(
+          m::Slice(m::Dot(m::Pad(m::Parameter().WithShape(F16, {2048, 1024}),
+                                 m::Constant().WithShape(F16, {}))
+                              .WithShape(F16, {2048, 1024}),
+                          m::Pad(m::Parameter().WithShape(F16, {1024, 33708}),
+                                 m::Constant().WithShape(F16, {}))
+                              .WithShape(F16, {1024, 33712}))
+                       .WithShape(F16, {2048, 33712})
+                       .WithContractingDims(/*lhs_contracting_dims=*/{1},
+                                            /*rhs_contracting_dims=*/{0}))
+              .WithShape(F16, {2048, 33708})));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, OneDotS8RootComputation) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    %param1 = s8[2047,1023] parameter(0)
+    %param2 = s8[1023,33707] parameter(1)
+    ROOT %dot.2309 = s32[2047,33707]{1,0} dot(s8[2047,1023]{1,0} %param1,
+                s8[1023,33707]{0,1} %param2),
+                lhs_contracting_dims={1}, rhs_contracting_dims={0}
+                })")
+                    .value();
+
+  EXPECT_TRUE(
+      CublasPadForGemms(se::CudaComputeCapability(7, 0), PrimitiveType::S8, 4)
+          .Run(module.get())
+          .value());
+  SCOPED_TRACE(module->ToString());
+
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      GmockMatch(
+          m::Slice(m::Dot(m::Pad(m::Parameter().WithShape(S8, {2047, 1023}),
+                                 m::Constant().WithShape(S8, {}))
+                              .WithShape(S8, {2048, 1024}),
+                          m::Pad(m::Parameter().WithShape(S8, {1023, 33707}),
+                                 m::Constant().WithShape(S8, {}))
+                              .WithShape(S8, {1024, 33708}))
+                       .WithShape(S32, {2048, 33708})
+                       .WithContractingDims(/*lhs_contracting_dims=*/{1},
+                                            /*rhs_contracting_dims=*/{0}))
+              .WithShape(S32, {2047, 33707})));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, TwoDotsComputation) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    %param1 = f16[2048, 1024] parameter(0)
+    %param2 = f16[1024, 33708] parameter(1)
+    %param3 = f16[33708, 1] parameter(2)
+    %dot1 = f16[2048, 33708]{1,0} dot(f16[2048, 1024]{1,0} %param1,
+                f16[1024, 33708]{0,1} %param2),
+                lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    ROOT %dot2 = f16[2048, 1]{1,0} dot(f16[2048, 33708]{1,0} %dot1,
+                f16[33708, 1]{0,1} %param3),
+                lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  })")
+                    .value();
+
+  EXPECT_TRUE(PadForF16Gemms(module.get()));
+  SCOPED_TRACE(module->ToString());
+
+  auto* root = module->entry_computation()->root_instruction();
+  const HloInstruction* dot2 = nullptr;
+  ASSERT_THAT(
+      root,
+      GmockMatch(
+          m::Slice(
+              m::Dot(
+                  m::Pad(m::Slice(m::Dot(&dot2,
+                                         m::Pad().WithShape(F16, {2048, 1024}),
+                                         m::Pad().WithShape(F16, {1024, 33712}))
+                                      .WithContractingDims(
+                                          /*lhs_contracting_dims=*/{1},
+                                          /*rhs_contracting_dims=*/{0})
+                                      .WithShape(F16, {2048, 33712}))
+                             .WithShape(F16, {2048, 33708}),
+                         m::Constant().WithShape(F16, {}))
+                      .WithShape(F16, {2048, 33712}),
+
+                  m::Pad(m::Parameter().WithShape(F16, {33708, 1}),
+                         m::Constant().WithShape(F16, {}))
+                      .WithShape(F16, {33712, 8}))
+                  .WithShape(F16, {2048, 8})
+                  .WithContractingDims(/*lhs_contracting_dims=*/{1},
+                                       /*rhs_contracting_dims=*/{0}))
+              .WithShape(F16, {2048, 1})));
+
+  EXPECT_THAT(
+      dot2,
+      GmockMatch(m::Dot(m::Pad(m::Parameter().WithShape(F16, {2048, 1024}),
+                               m::Constant().WithShape(F16, {}))
+                            .WithShape(F16, {2048, 1024}),
+                        m::Pad(m::Parameter().WithShape(F16, {1024, 33708}),
+                               m::Constant().WithShape(F16, {}))
+                            .WithShape(F16, {1024, 33712}))
+                     .WithContractingDims(/*lhs_contracting_dims=*/{1},
+                                          /*rhs_contracting_dims=*/{0})));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, DotWithBatchDimensions) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    %param1 = f16[3, 5, 2048, 1024] parameter(0)
+    %param2 = f16[3, 5, 1024, 33708] parameter(1)
+    ROOT %dot.2309 = f16[3, 5, 2048, 33708]{3, 2, 1,0} dot(f16[3, 5, 2048, 1024]{3, 2, 1,0} %param1,
+                f16[3, 5, 1024, 33708]{2, 3, 0,1} %param2), lhs_batch_dims={0, 1}, rhs_batch_dims={0, 1}, lhs_contracting_dims={3}, rhs_contracting_dims={2}})")
+                    .value();
+
+  EXPECT_TRUE(PadForF16Gemms(module.get()));
+  SCOPED_TRACE(module->ToString());
+
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      GmockMatch(
+          m::Slice(
+              m::Dot(m::Pad(m::Parameter().WithShape(F16, {3, 5, 2048, 1024}),
+                            m::Constant().WithShape(F16, {}))
+                         .WithShape(F16, {3, 5, 2048, 1024}),
+                     m::Pad(m::Parameter().WithShape(F16, {3, 5, 1024, 33708}),
+                            m::Constant().WithShape(F16, {}))
+                         .WithShape(F16, {3, 5, 1024, 33712}))
+                  .WithShape(F16, {3, 5, 2048, 33712})
+                  .WithContractingDims(/*lhs_contracting_dims=*/{3},
+                                       /*rhs_contracting_dims=*/{2}))
+              .WithShape(F16, {3, 5, 2048, 33708})));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, NoDotComputation) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    %x = f32[] parameter(0)
+    %y = f32[] parameter(1)
+    ROOT %maximum = f32[] maximum(f32[] %x, f32[] %y)
+  })")
+                    .value();
+
+  EXPECT_FALSE(PadForF16Gemms(module.get()));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, F32DotComputation) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    %param1 = f32[2048,1024] parameter(0)
+    %param2 = f32[1024,33708] parameter(1)
+    ROOT %dot.2309 = f32[2048,33708]{1,0} dot(f32[2048,1024]{1,0} %param1,
+                f32[1024,33708]{0,1} %param2),
+                lhs_contracting_dims={1}, rhs_contracting_dims={0}})")
+                    .value();
+
+  EXPECT_FALSE(PadForF16Gemms(module.get()));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, F64DotComputation) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    %param1 = f64[2048,1024] parameter(0)
+    %param2 = f64[1024,33708] parameter(1)
+    ROOT %dot.2309 = f64[2048,33708]{1,0} dot(f64[2048,1024]{1,0} %param1,
+                f64[1024,33708]{0,1} %param2),
+                lhs_contracting_dims={1}, rhs_contracting_dims={0}})")
+                    .value();
+
+  EXPECT_FALSE(PadForF16Gemms(module.get()));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, MultiplesOf8DotComputation) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    %param1 = f16[2048,1024] parameter(0)
+    %param2 = f16[1024,33712] parameter(1)
+    ROOT %dot.2309 = f16[2048,33712]{1,0} dot(f16[2048,1024]{1,0} %param1,
+                f16[1024,33712]{0,1} %param2),
+                lhs_contracting_dims={1}, rhs_contracting_dims={0}})")
+                    .value();
+
+  EXPECT_FALSE(PadForF16Gemms(module.get()));
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, CheckSavingMetadata) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    %param1 = f16[2048,1024] parameter(0)
+    %param2 = f16[1024,33708] parameter(1)
+    ROOT %dot.2309 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1,
+                f16[1024,33708]{0,1} %param2),
+                lhs_contracting_dims={1}, rhs_contracting_dims={0},
+                metadata={op_type="MatMul" op_name="transformer_v2/Transformer/decode/embedding_shared_weights_1/presoftmax_linear/MatMul"}
+                })")
+                    .value();
+
+  SCOPED_TRACE(module->ToString());
+
+  EXPECT_TRUE(PadForF16Gemms(module.get()));
+  auto metadata = module->entry_computation()->root_instruction()->metadata();
+  EXPECT_EQ("MatMul", metadata.op_type());
+  EXPECT_EQ(
+      "transformer_v2/Transformer/decode/embedding_shared_weights_1/"
+      "presoftmax_linear/MatMul",
+      metadata.op_name());
+}
+
+TEST_F(CublasGemmPadForTensorCoresTest, NotCanonicalizedDot) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    %param1 = f16[3, 5, 2048, 1024] parameter(0)
+    %param2 = f16[3, 5, 1024, 33708] parameter(1)
+    ROOT %dot.2309 = f16[3,2048, 33708]{2, 1, 0} dot(f16[3, 5, 2048, 1024]{3, 2, 1, 0} %param1, f16[3, 5, 1024, 33708]{3, 2, 1, 0} %param2), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={3, 1}, rhs_contracting_dims={2, 1}})")
+                    .value();
+
+  EXPECT_FALSE(PadForF16Gemms(module.get()));
+}
+
+}  // anonymous namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
new file mode 100644
index 0000000..00b73c9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
@@ -0,0 +1,660 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_custom_call_compiler.h"
+
+#include <optional>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/stream_executor/cuda/cuda_dnn.h"
+#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+inline absl::StatusOr<CudnnfMHAMaskKind> AsCudnnFmhaMaskKind(
+    CudnnfMHABackendConfig_MaskType mask_type) {
+  switch (mask_type) {
+    case CudnnfMHABackendConfig::NO_MASK:
+      return CudnnfMHAMaskKind::kNoMask;
+    case CudnnfMHABackendConfig::PADDING:
+      return CudnnfMHAMaskKind::kPadding;
+    case CudnnfMHABackendConfig::CAUSAL:
+      return CudnnfMHAMaskKind::kCausal;
+    case CudnnfMHABackendConfig::PADDING_CAUSAL:
+      return CudnnfMHAMaskKind::kPaddingCausal;
+    case CudnnfMHABackendConfig::ALIBI:
+      return CudnnfMHAMaskKind::kAlibi;
+    default:
+      return xla::Internal("Unknown fmha mask kind.");
+  }
+}
+
+// This is an interim structure that holds the parameters needed to construct
+// a GpufMHAConfig. It describes the properties of an FMHA without being tied
+// to specific IR and is used to help build FMHA thunks from either XLA HLO or
+// the LHLO GPU dialect in MLIR.
+struct GpufMHADescriptor {
+  CudnnfMHAKind kind;
+  CudnnfMHABackendConfig backend_config;
+  CudnnfMHAMaskKind mask_type;
+  Shape lhs_bmm1_shape;
+  Shape rhs_bmm1_shape;
+  Shape rhs_bmm2_shape;
+  Shape intermediate_lhs_bmm2_shape;
+  // This will contain both the output shape and the activation shape.
+  absl::InlinedVector<Shape, 2> output_shapes;
+  DotDimensionNumbers bmm1_dnums;
+  DotDimensionNumbers bmm2_dnums;
+
+  std::optional<Shape> mask_shape;
+  std::optional<Shape> bias_shape;
+};
+
+struct GpufMHABackwardDescriptor {
+  CudnnfMHAKind kind;
+  CudnnfMHABackendConfig backend_config;
+  CudnnfMHAMaskKind mask_type;
+  Shape bmm1_grad_gemm1_rhs_shape;
+  Shape bmm1_grad_gemm2_rhs_shape;
+  Shape bmm2_grad_gemm1_lhs_shape;
+  Shape bmm2_grad_gemm2_rhs_shape;
+  Shape d_output_shape;
+  Shape d_bmm1_lhs_shape;
+  Shape d_bmm1_rhs_shape;
+  Shape d_bmm2_rhs_shape;
+  DotDimensionNumbers bmm1_grad_gemm1_dnums;
+  DotDimensionNumbers bmm1_grad_gemm2_dnums;
+  DotDimensionNumbers bmm2_grad_gemm1_dnums;
+  DotDimensionNumbers bmm2_grad_gemm2_dnums;
+
+  std::optional<Shape> d_s_shape;
+  std::optional<Shape> fwd_output_shape;
+  std::optional<Shape> mask_shape;
+  std::optional<Shape> d_bias_shape;
+  std::optional<Shape> bias_shape;
+  bool force_deterministic;
+};
+
+// Structure to describe static properties of a GPU fused Multi-Headed
+// Attention.
+struct GpufMHAConfig {
+  static absl::StatusOr<GpufMHAConfig> For(const GpufMHADescriptor &fmha_desc);
+  PrimitiveType
+      input_type;  // Capture the primitive type of one of the inputs of BMM1
+  PrimitiveType output_type;
+  CudnnfMHAKind kind;
+  std::optional<double> fmha_scale;
+  std::optional<double> dropout_rate;
+  std::optional<int64_t> seed;
+
+  se::dnn::AlgorithmDesc algorithm;
+  CudnnfMHAMaskKind mask_type;
+  // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len]
+  // mask -> [batch_size, 1, q_seq_len, kv_seq_len]
+  se::dnn::MatmulTensorDescriptor lhs_bmm1;
+  se::dnn::MatmulTensorDescriptor rhs_bmm1;
+  se::dnn::MatmulTensorDescriptor rhs_bmm2;
+  se::dnn::MatmulTensorDescriptor intermediate_lhs_bmm2;
+  se::dnn::TensorDescriptor output;
+
+  std::optional<se::dnn::TensorDescriptor> activation;
+  std::optional<se::dnn::TensorDescriptor> mask;
+  std::optional<se::dnn::TensorDescriptor> bias;
+};
+
+// Structure to describe static properties of a GPU fused Multi-Headed
+// Attention backward.
+struct GpufMHABackwardConfig {
+  static absl::StatusOr<GpufMHABackwardConfig> For(
+      const GpufMHABackwardDescriptor &fmha_desc);
+  PrimitiveType
+      input_type;  // Capture the primitive type of one of the inputs of BMM1
+  PrimitiveType output_type;
+  CudnnfMHAKind kind;
+  std::optional<double> fmha_scale;
+  std::optional<double> dropout_rate;
+  std::optional<int64_t> seed;
+
+  se::dnn::AlgorithmDesc algorithm;
+  CudnnfMHAMaskKind mask_type;
+  // mask -> [batch_size, 1, q_seq_len, kv_seq_len]
+  // d_bias -> [1, num_heads, q_seq_len, kv_seq_len]
+  se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs;
+  se::dnn::MatmulTensorDescriptor bmm1_grad_gemm2_rhs;
+  se::dnn::MatmulTensorDescriptor bmm2_grad_gemm1_lhs;
+  se::dnn::MatmulTensorDescriptor bmm2_grad_gemm2_rhs;
+  se::dnn::MatmulTensorDescriptor d_output;
+  se::dnn::TensorDescriptor d_bmm1_lhs;
+  se::dnn::TensorDescriptor d_bmm1_rhs;
+  se::dnn::TensorDescriptor d_bmm2_rhs;
+  std::optional<se::dnn::TensorDescriptor> d_s;
+  std::optional<se::dnn::TensorDescriptor> mask;
+  std::optional<se::dnn::TensorDescriptor> d_bias;
+  std::optional<se::dnn::TensorDescriptor> fwd_output;
+  std::optional<se::dnn::TensorDescriptor> bias;
+};
+
+using se::DeviceMemory;
+using se::DeviceMemoryBase;
+using se::dnn::DataType;
+using se::dnn::MatmulTensorDescriptor;
+using se::dnn::TensorDescriptor;
+
+/*static*/ absl::StatusOr<GpufMHAConfig> GpufMHAConfig::For(
+    const GpufMHADescriptor &desc) {
+  // Get shapes from desc.
+  const Shape &lhs_bmm1_shape = desc.lhs_bmm1_shape;
+  const Shape &rhs_bmm1_shape = desc.rhs_bmm1_shape;
+  const Shape &rhs_bmm2_shape = desc.rhs_bmm2_shape;
+  const Shape &intermediate_lhs_bmm2_shape = desc.intermediate_lhs_bmm2_shape;
+  const Shape &output_shape = desc.output_shapes[0];
+
+  // Get DNN dtype from primitive types.
+  TF_ASSIGN_OR_RETURN(
+      DataType lhs_bmm1_type,
+      GetDNNDataTypeFromPrimitiveType(lhs_bmm1_shape.element_type()));
+  TF_ASSIGN_OR_RETURN(
+      DataType rhs_bmm1_type,
+      GetDNNDataTypeFromPrimitiveType(rhs_bmm1_shape.element_type()));
+
+  TF_ASSIGN_OR_RETURN(
+      DataType rhs_bmm2_type,
+      GetDNNDataTypeFromPrimitiveType(rhs_bmm2_shape.element_type()));
+  TF_ASSIGN_OR_RETURN(DataType lhs_bmm2_type,
+                      GetDNNDataTypeFromPrimitiveType(
+                          intermediate_lhs_bmm2_shape.element_type()));
+  TF_ASSIGN_OR_RETURN(DataType output_type, GetDNNDataTypeFromPrimitiveType(
+                                                output_shape.element_type()));
+  GpufMHAConfig config;
+  config.input_type = lhs_bmm1_shape.element_type();
+  config.output_type = output_shape.element_type();
+
+  // Get MatmulTensorDescriptors for BMM1
+  config.lhs_bmm1 =
+      MatmulTensorDescriptor::For(lhs_bmm1_type, lhs_bmm1_shape.dimensions(),
+                                  desc.lhs_bmm1_shape.layout().minor_to_major(),
+                                  desc.bmm1_dnums.lhs_batch_dimensions(),
+                                  desc.bmm1_dnums.lhs_contracting_dimensions());
+  config.rhs_bmm1 =
+      MatmulTensorDescriptor::For(rhs_bmm1_type, rhs_bmm1_shape.dimensions(),
+                                  desc.rhs_bmm1_shape.layout().minor_to_major(),
+                                  desc.bmm1_dnums.rhs_batch_dimensions(),
+                                  desc.bmm1_dnums.rhs_contracting_dimensions());
+
+  // Get MatmulTensorDescriptors for BMM2
+  config.rhs_bmm2 =
+      MatmulTensorDescriptor::For(rhs_bmm2_type, rhs_bmm2_shape.dimensions(),
+                                  desc.rhs_bmm2_shape.layout().minor_to_major(),
+                                  desc.bmm2_dnums.rhs_batch_dimensions(),
+                                  desc.bmm2_dnums.rhs_contracting_dimensions());
+
+  config.intermediate_lhs_bmm2 = MatmulTensorDescriptor::For(
+      lhs_bmm2_type, intermediate_lhs_bmm2_shape.dimensions(),
+      desc.intermediate_lhs_bmm2_shape.layout().minor_to_major(),
+      desc.bmm2_dnums.lhs_batch_dimensions(),
+      desc.bmm2_dnums.lhs_contracting_dimensions());
+
+  config.output = TensorDescriptor::For(output_type, output_shape.dimensions(),
+                                        output_shape.layout().minor_to_major());
+
+  if (desc.output_shapes.size() > 1) {
+    const Shape &activation_shape = desc.output_shapes.back();
+    // Generally, the activation should have the same type as the output, but
+    // set it explicitly just to be safe.
+    TF_ASSIGN_OR_RETURN(
+        DataType activation_type,
+        GetDNNDataTypeFromPrimitiveType(activation_shape.element_type()));
+    config.activation =
+        TensorDescriptor::For(activation_type, activation_shape.dimensions(),
+                              activation_shape.layout().minor_to_major());
+  }
+
+  if (desc.mask_shape) {
+    const Shape &mask_shape = *desc.mask_shape;
+    TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType(
+                                                mask_shape.element_type()));
+    config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(),
+                                        mask_shape.layout().minor_to_major());
+  }
+
+  if (desc.bias_shape) {
+    const Shape &bias_shape = *desc.bias_shape;
+    TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType(
+                                                bias_shape.element_type()));
+    config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(),
+                                        bias_shape.layout().minor_to_major());
+  }
+  config.kind = desc.kind;
+  config.mask_type = desc.mask_type;
+  const CudnnfMHABackendConfig &backend_config = desc.backend_config;
+  config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm());
+  config.fmha_scale.emplace(backend_config.fmha_scale());
+  config.dropout_rate.emplace(backend_config.dropout_rate());
+  config.seed.emplace(backend_config.seed());
+  return config;
+}
+
+/*static*/ absl::StatusOr<GpufMHABackwardConfig> GpufMHABackwardConfig::For(
+    const GpufMHABackwardDescriptor &desc) {
+  // Get shapes from desc.
+  const Shape &bmm1_grad_gemm1_rhs_shape = desc.bmm1_grad_gemm1_rhs_shape;
+  const Shape &bmm1_grad_gemm2_rhs_shape = desc.bmm1_grad_gemm2_rhs_shape;
+  const Shape &bmm2_grad_gemm1_lhs_shape = desc.bmm2_grad_gemm1_lhs_shape;
+  const Shape &bmm2_grad_gemm2_rhs_shape = desc.bmm2_grad_gemm2_rhs_shape;
+  const Shape &d_output_shape = desc.d_output_shape;
+  const Shape &d_bmm1_lhs_shape = desc.d_bmm1_lhs_shape;
+  const Shape &d_bmm1_rhs_shape = desc.d_bmm1_rhs_shape;
+  const Shape &d_bmm2_rhs_shape = desc.d_bmm2_rhs_shape;
+  // Get DNN dtype from primitive types.
+  TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm1_rhs_type,
+                      GetDNNDataTypeFromPrimitiveType(
+                          bmm1_grad_gemm1_rhs_shape.element_type()));
+
+  TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm2_rhs_type,
+                      GetDNNDataTypeFromPrimitiveType(
+                          bmm1_grad_gemm2_rhs_shape.element_type()));
+
+  TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm1_lhs_type,
+                      GetDNNDataTypeFromPrimitiveType(
+                          bmm2_grad_gemm1_lhs_shape.element_type()));
+
+  TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm2_rhs_type,
+                      GetDNNDataTypeFromPrimitiveType(
+                          bmm2_grad_gemm2_rhs_shape.element_type()));
+
+  TF_ASSIGN_OR_RETURN(
+      DataType d_output_type,
+      GetDNNDataTypeFromPrimitiveType(d_output_shape.element_type()));
+
+  TF_ASSIGN_OR_RETURN(
+      DataType d_bmm1_lhs_type,
+      GetDNNDataTypeFromPrimitiveType(d_bmm1_lhs_shape.element_type()));
+
+  TF_ASSIGN_OR_RETURN(
+      DataType d_bmm1_rhs_type,
+      GetDNNDataTypeFromPrimitiveType(d_bmm1_rhs_shape.element_type()));
+
+  TF_ASSIGN_OR_RETURN(
+      DataType d_bmm2_rhs_type,
+      GetDNNDataTypeFromPrimitiveType(d_bmm2_rhs_shape.element_type()));
+
+  GpufMHABackwardConfig config;
+  config.input_type = bmm1_grad_gemm1_rhs_shape.element_type();
+  config.output_type = d_bmm1_lhs_shape.element_type();
+
+  // Get MatmulTensorDescriptors for lhs of BMM1 grad GEMM 1
+  config.bmm1_grad_gemm1_rhs = MatmulTensorDescriptor::For(
+      bmm1_grad_gemm1_rhs_type, bmm1_grad_gemm1_rhs_shape.dimensions(),
+      desc.bmm1_grad_gemm1_rhs_shape.layout().minor_to_major(),
+      desc.bmm1_grad_gemm1_dnums.rhs_batch_dimensions(),
+      desc.bmm1_grad_gemm1_dnums.rhs_contracting_dimensions());
+
+  // Get MatmulTensorDescriptors for rhs of BMM1 grad GEMM 2
+  config.bmm1_grad_gemm2_rhs = MatmulTensorDescriptor::For(
+      bmm1_grad_gemm2_rhs_type, bmm1_grad_gemm2_rhs_shape.dimensions(),
+      desc.bmm1_grad_gemm2_rhs_shape.layout().minor_to_major(),
+      desc.bmm1_grad_gemm2_dnums.rhs_batch_dimensions(),
+      desc.bmm1_grad_gemm2_dnums.rhs_contracting_dimensions());
+
+  // Get MatmulTensorDescriptors for BMM2 grad GEMM 1
+  config.bmm2_grad_gemm1_lhs = MatmulTensorDescriptor::For(
+      bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(),
+      desc.bmm2_grad_gemm1_lhs_shape.layout().minor_to_major(),
+      desc.bmm2_grad_gemm1_dnums.lhs_batch_dimensions(),
+      desc.bmm2_grad_gemm1_dnums.lhs_contracting_dimensions());
+
+  config.d_output = MatmulTensorDescriptor::For(
+      d_output_type, d_output_shape.dimensions(),
+      desc.d_output_shape.layout().minor_to_major(),
+      desc.bmm2_grad_gemm1_dnums.rhs_batch_dimensions(),
+      desc.bmm2_grad_gemm1_dnums.rhs_contracting_dimensions());
+
+  // Get MatmulTensorDescriptors for BMM2 grad GEMM 2
+  config.bmm2_grad_gemm2_rhs = MatmulTensorDescriptor::For(
+      bmm2_grad_gemm2_rhs_type, bmm2_grad_gemm2_rhs_shape.dimensions(),
+      desc.bmm2_grad_gemm2_rhs_shape.layout().minor_to_major(),
+      desc.bmm2_grad_gemm2_dnums.rhs_batch_dimensions(),
+      desc.bmm2_grad_gemm2_dnums
+          .rhs_contracting_dimensions());  // FMHA TODO: transpose here?
+
+  config.d_bmm1_lhs =
+      TensorDescriptor::For(d_bmm1_lhs_type, d_bmm1_lhs_shape.dimensions(),
+                            d_bmm1_lhs_shape.layout().minor_to_major());
+  config.d_bmm1_rhs =
+      TensorDescriptor::For(d_bmm1_rhs_type, d_bmm1_rhs_shape.dimensions(),
+                            d_bmm1_rhs_shape.layout().minor_to_major());
+  config.d_bmm2_rhs =
+      TensorDescriptor::For(d_bmm2_rhs_type, d_bmm2_rhs_shape.dimensions(),
+                            d_bmm2_rhs_shape.layout().minor_to_major());
+  config.d_s = TensorDescriptor::For(
+      bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(),
+      bmm2_grad_gemm1_lhs_shape.layout().minor_to_major());
+
+  if (desc.d_bias_shape) {
+    const Shape &d_bias_shape = *desc.d_bias_shape;
+    // Get DNN dtype from primitive types.
+    TF_ASSIGN_OR_RETURN(DataType d_bias_type, GetDNNDataTypeFromPrimitiveType(
+                                                  d_bias_shape.element_type()));
+    config.d_bias =
+        TensorDescriptor::For(d_bias_type, d_bias_shape.dimensions(),
+                              d_bias_shape.layout().minor_to_major());
+  }
+
+  if (desc.mask_shape) {
+    const Shape &mask_shape = *desc.mask_shape;
+    TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType(
+                                                mask_shape.element_type()));
+    config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(),
+                                        mask_shape.layout().minor_to_major());
+  }
+  if (desc.fwd_output_shape) {
+    const Shape &fwd_output_shape = *desc.fwd_output_shape;
+    TF_ASSIGN_OR_RETURN(
+        DataType fwd_output_type,
+        GetDNNDataTypeFromPrimitiveType(fwd_output_shape.element_type()));
+    config.fwd_output =
+        TensorDescriptor::For(fwd_output_type, fwd_output_shape.dimensions(),
+                              fwd_output_shape.layout().minor_to_major());
+  }
+
+  if (desc.bias_shape) {
+    const Shape &bias_shape = *desc.bias_shape;
+    TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType(
+                                                bias_shape.element_type()));
+    config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(),
+                                        bias_shape.layout().minor_to_major());
+  }
+
+  config.kind = desc.kind;
+  config.mask_type = desc.mask_type;
+  const CudnnfMHABackendConfig &backend_config = desc.backend_config;
+  config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm());
+  config.fmha_scale.emplace(backend_config.fmha_scale());
+  config.dropout_rate.emplace(backend_config.dropout_rate());
+  config.seed.emplace(backend_config.seed());
+  return config;
+}
+
+absl::StatusOr<se::gpu::CudnnGraph> HloCustomCallToCuDnnGraph(
+    se::dnn::DnnSupport &dnn_support, HloCustomCallInstruction *custom_call) {
+  if (IsFwdCustomCallTofMHA(*custom_call)) {
+    TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind,
+                        xla::gpu::GetCudnnfMHAKind(custom_call));
+    std::optional<Shape> mask_shape, bias_shape;
+    {
+      bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax ||
+                      kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout;
+
+      if (has_bias) {
+        const HloInstruction *bias = custom_call->operand(3);
+        bias_shape = bias->shape();
+      }
+    }
+
+    TF_ASSIGN_OR_RETURN(
+        const auto gpu_config,
+        custom_call->backend_config<xla::gpu::GpuBackendConfig>());
+    const xla::gpu::CudnnfMHABackendConfig &config =
+        gpu_config.cudnn_fmha_backend_config();
+    Shape intermediate_tensor_shape(config.intermediate_tensor_shape());
+    absl::InlinedVector<Shape, 2> output_shapes = {
+        ShapeUtil::GetSubshape(custom_call->shape(), {0})};
+
+    bool has_activation =
+        xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3;
+    if (has_activation) {
+      output_shapes.push_back(
+          ShapeUtil::GetSubshape(custom_call->shape(), {1}));
+    }
+
+    Shape q_shape = custom_call->operand(0)->shape();
+    Shape k_shape = custom_call->operand(1)->shape();
+    Shape v_shape = custom_call->operand(2)->shape();
+    TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
+                        AsCudnnFmhaMaskKind(config.mask_type()));
+    GpufMHADescriptor descriptor = {kind,
+                                    config,
+                                    cudnn_mask_type,
+                                    q_shape,
+                                    k_shape,
+                                    v_shape,
+                                    intermediate_tensor_shape,
+                                    output_shapes,
+                                    config.bmm1_dot_dimension_numbers(),
+                                    config.bmm2_dot_dimension_numbers(),
+                                    mask_shape,
+                                    bias_shape};
+
+    TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config,
+                        GpufMHAConfig::For(descriptor));
+    TF_ASSIGN_OR_RETURN(
+        se::dnn::FMHAMaskKind dnn_mask_type,
+        GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type));
+    TF_ASSIGN_OR_RETURN(
+        se::gpu::CudnnGraph graph,
+        se::gpu::GetCudnnFlashAttentionOperationGraph(
+            dnn_support, fmha_config.lhs_bmm1, fmha_config.rhs_bmm1,
+            fmha_config.rhs_bmm2, fmha_config.output, fmha_config.bias,
+            fmha_config.activation, static_cast<float>(*fmha_config.fmha_scale),
+            fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0,
+            fmha_config.dropout_rate, dnn_mask_type));
+    return std::move(graph);
+  } else {
+    TF_ASSIGN_OR_RETURN(
+        auto gpu_config,
+        custom_call->backend_config<xla::gpu::GpuBackendConfig>());
+    xla::gpu::CudnnfMHABackendConfig &config =
+        *gpu_config.mutable_cudnn_fmha_backend_config();
+
+    int input_index = 0;
+    Shape bmm1_grad_gemm1_rhs_shape =
+        custom_call->operand(input_index++)->shape();
+    Shape bmm1_grad_gemm2_rhs_shape =
+        custom_call->operand(input_index++)->shape();
+    Shape bmm2_grad_gemm2_rhs_shape =
+        custom_call->operand(input_index++)->shape();
+    Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape());
+    input_index++;
+    Shape d_output_shape = custom_call->operand(input_index++)->shape();
+
+    TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind,
+                        GetCudnnfMHAKind(custom_call));
+    std::optional<Shape> mask_shape;
+
+    bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax ||
+                     kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout);
+    std::optional<Shape> bias_shape;
+    if (has_bias) {
+      bias_shape = custom_call->operand(input_index++)->shape();
+    }
+
+    std::optional<Shape> fwd_output_shape =
+        custom_call->operand(input_index++)->shape();
+    if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING ||
+        config.mask_type() ==
+            xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) {
+      // Skip q_seqlen and kv_seqlen.
+      input_index += 2;
+    }
+    TF_RET_CHECK(input_index == custom_call->operand_count());
+
+    int output_index = 0;
+    Shape d_bmm1_lhs_shape =
+        ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
+    Shape d_bmm1_rhs_shape =
+        ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
+    Shape d_bmm2_rhs_shape =
+        ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
+    std::optional<Shape> d_s_shape;
+    std::optional<Shape> d_bias_shape;
+    bool has_dbias = custom_call->shape().tuple_shapes().size() == 5;
+    if (has_dbias) {
+      d_bias_shape =
+          ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
+    }
+    // The last one is the workspace.
+    TF_RET_CHECK(output_index ==
+                 custom_call->shape().tuple_shapes().size() - 1);
+    TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
+                        AsCudnnFmhaMaskKind(config.mask_type()));
+
+    const DebugOptions &debug_options =
+        custom_call->GetModule()->config().debug_options();
+    bool force_deterministic =
+        debug_options.xla_gpu_deterministic_ops() ||
+        debug_options.xla_gpu_exclude_nondeterministic_ops();
+    config.set_force_deterministic(force_deterministic);
+    TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config));
+
+    GpufMHABackwardDescriptor descriptor = {
+        kind,
+        config,
+        cudnn_mask_type,
+        bmm1_grad_gemm1_rhs_shape,
+        bmm1_grad_gemm2_rhs_shape,
+        bmm2_grad_gemm1_lhs_shape,
+        bmm2_grad_gemm2_rhs_shape,
+        d_output_shape,
+        d_bmm1_lhs_shape,
+        d_bmm1_rhs_shape,
+        d_bmm2_rhs_shape,
+        config.bmm1_grad_gemm1_dot_dimension_numbers(),
+        config.bmm1_grad_gemm2_dot_dimension_numbers(),
+        config.bmm2_grad_gemm1_dot_dimension_numbers(),
+        config.bmm2_grad_gemm2_dot_dimension_numbers(),
+        d_s_shape,
+        fwd_output_shape,
+        mask_shape,
+        d_bias_shape,
+        bias_shape,
+        force_deterministic};
+
+    TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_config,
+                        GpufMHABackwardConfig::For(descriptor));
+    TF_ASSIGN_OR_RETURN(
+        se::dnn::FMHAMaskKind dnn_mask_type,
+        GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type));
+
+    TF_ASSIGN_OR_RETURN(
+        se::gpu::CudnnGraph graph,
+        se::gpu::GetCudnnFlashAttentionBackwardOperationGraph(
+            dnn_support, fmha_config.bmm1_grad_gemm1_rhs,
+            fmha_config.bmm1_grad_gemm2_rhs, fmha_config.bmm2_grad_gemm1_lhs,
+            fmha_config.bmm2_grad_gemm2_rhs, fmha_config.d_output,
+            fmha_config.d_bmm1_lhs, fmha_config.d_bmm1_rhs,
+            fmha_config.d_bmm2_rhs, fmha_config.bias, fmha_config.dropout_rate,
+            fmha_config.seed, *fmha_config.fmha_scale,
+            fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0,
+            fmha_config.bias != std::nullopt, dnn_mask_type,
+            force_deterministic));
+    return std::move(graph);
+  }
+}
+
+class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor {
+ public:
+  explicit CuDnnCustomCallVisitor(se::dnn::DnnSupport &dnn_support,
+                                  BinaryMap &compilation_results)
+      : dnn_support_(dnn_support), compilation_results_(compilation_results) {}
+
+  void AddWorkspace(HloInstruction &hlo, int64_t workspace_size) {
+    if (workspace_size == 0) {
+      return;
+    }
+    VLOG(4) << "Applying workspace size " << workspace_size << " to "
+            << hlo.ToString();
+    Shape *shape = hlo.mutable_shape();
+    shape->mutable_tuple_shapes()->back().set_dimensions(0, workspace_size);
+  }
+
+  absl::Status HandleCustomCall(HloInstruction *hlo) override {
+    if (!IsCustomCallTofMHA(*hlo)) {
+      return absl::OkStatus();
+    }
+
+    TF_ASSIGN_OR_RETURN(const std::string fingerprint_without_workspace,
+                        FingerprintWithBackendConfig<GpuBackendConfig>(*hlo));
+    auto workspace_size_it =
+        workspace_sizes_.find(fingerprint_without_workspace);
+    if (workspace_size_it == workspace_sizes_.cend()) {
+      TF_ASSIGN_OR_RETURN(
+          se::gpu::CudnnGraph graph,
+          HloCustomCallToCuDnnGraph(dnn_support_,
+                                    DynCast<HloCustomCallInstruction>(hlo)));
+
+      const int64_t workspace_size = graph.Graph().get_workspace_size();
+      workspace_sizes_.insert(workspace_size_it,
+                              {fingerprint_without_workspace, workspace_size});
+      AddWorkspace(*hlo, workspace_size);
+
+      std::vector<uint8_t> serialized_graph;
+      RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph));
+      // Compute a new fingerprint that includes the potential workspace so
+      // that the compilation results match the fingerprint computed by the
+      // emitter.
+      TF_ASSIGN_OR_RETURN(const std::string fingerprint_with_workspace,
+                          FingerprintWithBackendConfig<GpuBackendConfig>(*hlo));
+      compilation_results_[fingerprint_with_workspace] =
+          std::string(reinterpret_cast<char *>(serialized_graph.data()),
+                      serialized_graph.size());
+    } else {
+      VLOG(4) << "Cache hit.";
+      AddWorkspace(*hlo, workspace_size_it->second);
+    }
+
+    MarkAsChanged();
+    return absl::OkStatus();
+  }
+
+ private:
+  se::dnn::DnnSupport &dnn_support_;
+  BinaryMap &compilation_results_;
+  absl::flat_hash_map<std::string, int64_t> workspace_sizes_;
+};
+
+}  // namespace
+
+absl::StatusOr<bool> CuDnnCustomCallCompiler::Run(
+    HloModule *module,
+    const absl::flat_hash_set<absl::string_view> &execution_threads) {
+  XLA_SCOPED_LOGGING_TIMER_LEVEL("cuDNN custom call compiler", 8);
+  return CuDnnCustomCallVisitor(dnn_support_, compilation_results_)
+      .RunOnModule(module, execution_threads);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h
new file mode 100644
index 0000000..810286f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h
@@ -0,0 +1,57 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/stream_executor/stream_executor.h"
+
+namespace xla {
+namespace gpu {
+
+// Compiles cuDNN custom calls to binaries and serializes them.
+// Also adjusts the calls in HLO to have the correct workspace size.
+class CuDnnCustomCallCompiler : public HloModulePass {
+ public:
+  explicit CuDnnCustomCallCompiler(se::StreamExecutor& stream_exec,
+                                   BinaryMap& compilation_results)
+      : dnn_support_(*stream_exec.AsDnn()),
+        compilation_results_(compilation_results) {}
+
+  absl::string_view name() const override {
+    return "cudnn-custom-call-compiler";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  se::dnn::DnnSupport& dnn_support_;
+  BinaryMap& compilation_results_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.cc
new file mode 100644
index 0000000..71ed08c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.cc
@@ -0,0 +1,65 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_custom_call_converter.h"
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "tsl/platform/errors.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CustomCallVisitor : public DfsHloRewriteVisitor {
+ public:
+  absl::Status HandleCustomCall(HloInstruction *hlo) override {
+    if (hlo->custom_call_target() != kCuDnnFusionKind) {
+      return absl::OkStatus();
+    }
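+    // Clone the called computation into the module and wrap it in a kCustom
+    // fusion whose fusion backend config carries the custom call's target as
+    // its kind.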
+    HloComputation *computation = hlo->GetModule()->AddEmbeddedComputation(
+        hlo->called_computations()[0]->Clone());
+    HloInstruction *fusion =
+        hlo->parent()->AddInstruction(HloInstruction::CreateFusion(
+            hlo->shape(), HloInstruction::FusionKind::kCustom, hlo->operands(),
+            computation));
+    GpuBackendConfig gpu_config;
+    FusionBackendConfig &backend_config =
+        *gpu_config.mutable_fusion_backend_config();
+    backend_config.set_kind(hlo->custom_call_target());
+    TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config));
+    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, fusion));
+    return absl::OkStatus();
+  }
+};
+
+}  // namespace
+
+absl::StatusOr<bool> CuDnnCustomCallConverter::Run(
+    HloModule *module,
+    const absl::flat_hash_set<absl::string_view> &execution_threads) {
+  return CustomCallVisitor().RunOnModule(module, execution_threads);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.h
new file mode 100644
index 0000000..5397a4d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.h
@@ -0,0 +1,47 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_CONVERTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_CONVERTER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Converts custom calls targeting kCuDnnFusionKind into kCustom fusions whose
+// backend config carries the same kind. Frameworks can pass computations
+// outlined this way through StableHLO; after the conversion they can be
+// processed by XLA using the existing pipeline for custom fusions.
+class CuDnnCustomCallConverter : public HloModulePass {
+ public:
+  absl::string_view name() const override {
+    return "cudnn-custom-call-converter";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_CONVERTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter_test.cc
new file mode 100644
index 0000000..ad29e15
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter_test.cc
@@ -0,0 +1,47 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_custom_call_converter.h"
+
+#include <gtest/gtest.h>
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ConverterTest = HloTestBase;
+
+TEST_F(ConverterTest, CustomCallGetsConvertedToCustomFusion) {
+  RunAndFilecheckHloRewrite(R"(
+f {
+  a = s8[] parameter(0)
+  ROOT r = s8[] add(a, a)
+}
+
+ENTRY e {
+  b = s8[] parameter(0)
+  ROOT c = s8[] custom-call(b),
+    custom_call_target="__cudnn$fusion", called_computations={f}
+})",
+                            CuDnnCustomCallConverter(), R"(
+; CHECK: ROOT %fusion = s8[] fusion(%b), kind=kCustom, calls=%f
+; CHECK-SAME: "fusion_backend_config":{"kind":"__cudnn$fusion"}
+                          )");
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc
new file mode 100644
index 0000000..c51a76f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc
@@ -0,0 +1,1566 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h"
+
+#include <algorithm>
+#include <array>
+#include <cstdint>
+#include <functional>
+#include <limits>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "xla/comparison_util.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/literal.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/stream_executor/stream_executor.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/ml_dtypes.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = match;
+
+bool IsConvCustomCall(const HloInstruction* instr) {
+  return instr->opcode() == HloOpcode::kCustomCall &&
+         (instr->custom_call_target() == kCudnnConvForwardCallTarget ||
+          instr->custom_call_target() ==
+              kCudnnConvBiasActivationForwardCallTarget);
+}
+
+bool IsConvDepthwise(const HloInstruction* instr) {
+  int64_t feature_group_count = instr->feature_group_count();
+  if (feature_group_count == 1) {
+    return false;
+  }
+
+  const HloInstruction* input = instr->operand(0);
+  int64_t input_feature_dimension =
+      instr->convolution_dimension_numbers().input_feature_dimension();
+  int64_t input_feature_count =
+      input->shape().dimensions(input_feature_dimension);
+  return input_feature_count == feature_group_count;
+}
+
+// We don't want to upgrade depthwise convolutions to ConvBiasActivation,
+// because the fused CUDNN functions are slower for some of those.
+bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) {
+  return IsConvCustomCall(instr) && !IsConvDepthwise(instr);
+}
+
+bool IsROCm(se::GpuComputeCapability cc) {
+  return std::holds_alternative<se::RocmComputeCapability>(cc);
+}
+
+// elu, relu6, and leaky-relu activations are supported in cudnn via the
+// "runtime fusion" engine, which JIT compiles C++ code.  This can be slow to
+// compile, so we guard it with a debug option.
+//
+// nvidia currently recommends that we enable this only on Ampere+, but we've
+// tested on Turing (sm75) and it seems to work fine.
+//
+// Note that as of writing, xla_gpu_use_runtime_fusion is disabled by default
+// due to apparent bugs in cudnn 8.9.0.  See debug_options_flags.cc for details.
+bool ShouldUseCudnnRuntimeFusion(const DebugOptions& debug_opts,
+                                 se::GpuComputeCapability cc) {
+  const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(&cc);
+  if (cuda_cc != nullptr)
+    return debug_opts.xla_gpu_use_runtime_fusion() && cuda_cc->IsAtLeast(7, 5);
+  else
+    return true;
+}
+
+bool IsSuitableForCudnnRuntimeFusion(HloInstruction* conv) {
+  // cudnn runtime fusion is pathologically slow on convs with side-inputs.
+  // TODO(kaixih@nvidia): remove this check when cuDNN fixes it.
+  if (conv->operands().size() > 3) {
+    return false;
+  }
+
+  // cuDNN runtime fusion kernels require 32-bit aligned data access, which
+  // means that the number of in/out channels must be divisible by 2 for fp16.
+  // (We don't currently do runtime fusion for int8.)
+  if (conv->operand(0)->shape().element_type() != F16) {
+    return false;
+  }
+  const Shape& shape = conv->operand(1)->shape();
+  int64_t num_input_features = shape.dimensions(
+      conv->convolution_dimension_numbers().kernel_input_feature_dimension());
+  int64_t num_output_features = shape.dimensions(
+      conv->convolution_dimension_numbers().kernel_output_feature_dimension());
+  if (num_input_features % 2 != 0 || num_output_features % 2 != 0) {
+    return false;
+  }
+
+  return true;
+}
+
+// Can instr be converted to type `dst_ty` without losing any precision?  For
+// our purposes, this is true if:
+//
+//  - instr already has type dst_ty, or
+//  - instr is convert<wider type>(op_with_dst_ty), or
+//  - instr is a constant which we can convert orig_ty -> dst_ty -> orig_ty and
+//    get back exactly the original value, or
+//  - instr is a broadcast, reshape, or transpose of one of the above.
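+//
+// For example, convert<f32>(x) where x is f16 is losslessly convertible back
+// to f16, and an f32 constant 2.0 is losslessly convertible to f16.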
+bool IsLosslesslyConvertibleTo(const HloInstruction* instr,
+                               PrimitiveType dst_ty) {
+  if (instr->shape().element_type() == dst_ty) {
+    return true;
+  }
+
+  if (Match(instr, m::Convert(m::Op().WithElementType(dst_ty)))) {
+    // Check that the convert from dst_ty to instr->element_type() doesn't lose
+    // precision.  Otherwise, this convert is not lossless.
+    return primitive_util::CastPreservesValues(dst_ty,
+                                               instr->shape().element_type());
+  }
+
+  if (instr->opcode() == HloOpcode::kConstant) {
+    if (!instr->shape().IsArray()) {
+      return false;
+    }
+    // Check if instr's literal roundtrips to dst_ty and back to its original
+    // type without modification.
+    PrimitiveType orig_ty = instr->shape().element_type();
+
+    // The only reason Convert() should fail is if we don't support converting
+    // from x to y, which indeed means it's not losslessly-convertible.
+    absl::StatusOr<Literal> converted1 = instr->literal().Convert(dst_ty);
+    if (!converted1.ok()) {
+      return false;
+    }
+    absl::StatusOr<Literal> converted2 = converted1->Convert(orig_ty);
+    if (!converted2.ok()) {
+      return false;
+    }
+
+    return instr->literal() == *converted2;
+  }
+
+  if (instr->opcode() == HloOpcode::kBroadcast ||
+      instr->opcode() == HloOpcode::kReshape ||
+      instr->opcode() == HloOpcode::kTranspose) {
+    return IsLosslesslyConvertibleTo(instr->operand(0), dst_ty);
+  }
+
+  return false;
+}
+
+// Helpers suitable for use in m::Op().WithPredicate(...).
+bool IsLosslesslyConvertibleToS8(const HloInstruction* instr) {
+  return IsLosslesslyConvertibleTo(instr, S8);
+}
+bool IsLosslesslyConvertibleToF16(const HloInstruction* instr) {
+  return IsLosslesslyConvertibleTo(instr, F16);
+}
+
+// If `conv` is a vanilla forward conv, transforms it into a
+// conv-bias-activation.  If it's already a conv-bias-activation, does nothing.
+//
+// If `conv` is anything else, returns an error.
+absl::StatusOr<HloInstruction*> EnsureIsConvBiasActivation(
+    HloInstruction* conv) {
+  CHECK_EQ(conv->opcode(), HloOpcode::kCustomCall);
+
+  if (conv->custom_call_target() == kCudnnConvBiasActivationForwardCallTarget) {
+    return conv;
+  }
+
+  if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
+    HloComputation* comp = conv->parent();
+
+    const Shape& shape = conv->shape().tuple_shapes(0);
+    int64_t num_output_features = shape.dimensions(
+        conv->convolution_dimension_numbers().output_feature_dimension());
+
+    // bias for integer convs is always f32, see
+    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
+    PrimitiveType bias_ty;
+    if (primitive_util::IsIntegralType(shape.element_type())) {
+      bias_ty = F32;
+    } else {
+      bias_ty = shape.element_type();
+    }
+    auto bias = BroadcastZeros(comp, bias_ty, {num_output_features});
+
+    absl::InlinedVector<HloInstruction*, 3> new_operands(
+        conv->operands().begin(), conv->operands().end());
+    new_operands.push_back(bias);
+
+    HloInstruction* new_conv = comp->AddInstruction(
+        conv->CloneWithNewOperands(conv->shape(), new_operands));
+    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
+    new_conv->set_custom_call_target(kCudnnConvBiasActivationForwardCallTarget);
+    comp->parent()->SetAndUniquifyInstrName(new_conv,
+                                            "cudnn-conv-bias-activation");
+    return new_conv;
+  }
+
+  return FailedPrecondition("Unsupported conv: %s", conv->ToString());
+}
+
+// convert<cvt_type>(gte(custom-call<conv_type>(int8_x, int8_w))) ->
+// gte(custom-call<cvt_type>(int8_x, int8_w))
+absl::StatusOr<bool> FuseConvertTypeIntoConv(HloComputation* comp,
+                                             PrimitiveType conv_type,
+                                             PrimitiveType cvt_type) {
+  bool changed = false;
+  for (auto instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction* conv = nullptr;
+    auto tuple_elem =
+        m::GetTupleElement(m::Op(&conv).WithPredicate(IsConvCustomCall), 0)
+            .WithElementType(conv_type);
+    auto pattern =
+        m::Convert(tuple_elem.WithOneUser()).WithElementType(cvt_type);
+    if (!Match(instr, pattern)) {
+      continue;
+    }
+    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+          return absl::StrCat("FuseConvertTypeIntoConv: ", conv->ToString());
+        })) {
+      continue;
+    }
+
+    Shape new_shape = conv->shape();
+    new_shape.mutable_tuple_shapes(0)->set_element_type(cvt_type);
+    HloInstruction* new_conv =
+        comp->AddInstruction(conv->CloneWithNewShape(new_shape));
+    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
+    TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
+                        MakeGetTupleElementHlo(new_conv, 0));
+    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_gte));
+
+    changed = true;
+  }
+
+  return changed;
+}
+
+struct ConvConvertTypes {
+  PrimitiveType convolution_type;
+  PrimitiveType conversion_type;
+};
+
+// Removes the convert around a convolution by making the result type of the
+// convolution custom call the same as the conversion result type.
+// For example: convert<float>(gte(custom-call<int32>(int8_x, int8_w))) ->
+// gte(custom-call<float>(int8_x, int8_w))
+absl::StatusOr<bool> FuseRemoveConvertInConv(HloComputation* comp) {
+  bool changed = false;
+  // Note: We are eliminating F16->F32 because it fails on internal tests.
+  std::array<ConvConvertTypes, 3> types{{
+      {S32, F32},
+      {S8, F32},
+      {F32, S8},
+  }};
+  for (auto [conv_type, cvt_type] : types) {
+    TF_ASSIGN_OR_RETURN(bool curr_change,
+                        FuseConvertTypeIntoConv(comp, conv_type, cvt_type));
+    changed |= curr_change;
+  }
+  return changed;
+}
+
+// alpha * gte(custom-call(...)) ->
+// gte(custom-call(..., backend_config={alpha})).
+absl::StatusOr<bool> FuseConvAlpha(HloComputation* comp) {
+  bool changed = false;
+  for (auto instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction* conv = nullptr;
+    HloInstruction* gte = nullptr;
+    HloInstruction* alpha = nullptr;
+
+    auto pattern = m::MultiplyAnyOrder(
+        m::GetTupleElement(
+            &gte, m::Op(&conv).WithPredicate(IsNonDepthwiseConvCustomCall), 0)
+            .WithOneUse(),
+        m::Broadcast(m::ConstantEffectiveScalar(&alpha)));
+    if (!Match(instr, pattern)) {
+      continue;
+    }
+
+    // alpha is f32 except for f64 convs, where it's f64.  See
+    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
+    PrimitiveType alpha_ty = gte->shape().element_type() == F64 ? F64 : F32;
+    if (!IsLosslesslyConvertibleTo(alpha, alpha_ty)) {
+      continue;
+    }
+
+    TF_ASSIGN_OR_RETURN(auto gpu_config,
+                        conv->backend_config<GpuBackendConfig>());
+    CudnnConvBackendConfig& config =
+        *gpu_config.mutable_cudnn_conv_backend_config();
+
+    if (config.conv_result_scale() != 1) {
+      continue;
+    }
+    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+          return absl::StrCat("FuseConvAlpha: ", conv->ToString());
+        })) {
+      continue;
+    }
+
+    // StreamExecutor doesn't support the alpha parameter on non-bias-activation
+    // convs, so we have to upgrade `conv`.
+    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
+
+    TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
+    config.set_conv_result_scale(alpha_f64.GetFirstElement<double>());
+    TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
+    TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(instr, gte));
+
+    changed = true;
+  }
+  return changed;
+}
+
+// The format of the serialized graph describing a sequence of ops fused
+// into the cuDNN convolution Custom Call is
+// "UID:[output_type]conv();UID[output_type]:op_name(operand
+// UID);UID:[output_type]op_name(operand UID);..." with the convolution assumed
+// to be the first op in the graph. Operand UIDs identifying ops outside the
+// serialized graph are elided. Currently, multiplication and division by a
+// broadcast scalar, addition of a matrix bias, the application of a ReLU
+// activation and the calculation of the maximum of the absolute value are
+// supported.
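+//
+// For example (with illustrative UIDs), a convolution whose result is scaled,
+// passed through a ReLU and finally saturate-cast to F8E4M3FN could serialize
+// to:
+//   "4:[f32]conv();5:[f32]scale(4);6:[f8e4m3fn]relu(5);"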
+class GraphString {
+ public:
+  GraphString() = default;
+
+  bool AppendOp(std::string op_name, HloInstruction* op,
+                std::vector<HloInstruction*> operands = {}) {
+    std::optional<int64_t> operand_uid;
+    int num_operands_in_graph = 0;
+    for (HloInstruction* operand : operands) {
+      if (OpInGraph(operand->unique_id())) {
+        num_operands_in_graph++;
+        // Ops with more than one operand in the graph are not supported.
+        if (num_operands_in_graph > 1) {
+          return false;
+        }
+        operand_uid = operand->unique_id();
+      }
+    }
+    graph_.emplace_back(OpDescriptor(
+        {op->unique_id(), op->shape().element_type(), op_name, operand_uid}));
+    return true;
+  }
+
+  void ChangeDataType(PrimitiveType type) {
+    DCHECK(!graph_.empty());
+    graph_.back().output_type = type;
+  }
+
+  std::string Graph() const {
+    std::string graph;
+    for (OpDescriptor op : graph_) {
+      graph.append(std::to_string(op.uid));
+      graph.append(":[" +
+                   primitive_util::LowercasePrimitiveTypeName(op.output_type) +
+                   "]");
+      graph.append(op.name);
+      graph.append("(");
+      if (op.operand.has_value()) {
+        graph.append(std::to_string(*op.operand));
+      }
+      graph.append(");");
+    }
+    return graph;
+  }
+
+  bool OpInGraph(int64_t uid, std::string op_name = "") const {
+    auto op_filter = [&](OpDescriptor op) -> bool {
+      if (op_name.empty()) {
+        return op.uid == uid;
+      } else {
+        return op.uid == uid && op.name == op_name;
+      }
+    };
+    return std::find_if(graph_.begin(), graph_.end(), op_filter) !=
+           graph_.end();
+  }
+
+ private:
+  struct OpDescriptor {
+    int64_t uid;
+    PrimitiveType output_type;
+    std::string name;
+    std::optional<int64_t> operand;
+  };
+
+  std::vector<OpDescriptor> graph_;
+};
+
+bool IsF8Type(const HloInstruction* instr) {
+  return primitive_util::IsF8Type(instr->shape().element_type());
+}
+
+bool IsScalar(const HloInstruction* instr) {
+  return ShapeUtil::IsScalar(instr->shape());
+}
+
+std::optional<PrimitiveType> IsSaturatingCastToF8(HloInstruction* instr) {
+  HloInstruction *op, *clamp_lower, *clamp_upper;
+  if (Match(instr,
+            m::Convert(
+                &op,
+                m::Clamp(m::Broadcast(m::ConstantScalar(&clamp_lower)), m::Op(),
+                         m::Broadcast(m::ConstantScalar(&clamp_upper))))) &&
+      ((op->shape().element_type() == F8E4M3FN &&
+        clamp_lower->literal().IsAllFloat(static_cast<float>(
+            std::numeric_limits<tsl::float8_e4m3fn>::lowest())) &&
+        clamp_upper->literal().IsAllFloat(static_cast<float>(
+            std::numeric_limits<tsl::float8_e4m3fn>::max()))) ||
+       (op->shape().element_type() == F8E5M2 &&
+        clamp_lower->literal().IsAllFloat(static_cast<float>(
+            std::numeric_limits<tsl::float8_e5m2>::lowest())) &&
+        clamp_upper->literal().IsAllFloat(static_cast<float>(
+            std::numeric_limits<tsl::float8_e5m2>::max()))))) {
+    return op->shape().element_type();
+  }
+  return std::nullopt;
+}
+
+// Returns whether the HLO Computation applied by `op` calculates the largest
+// element.
+bool AppliesMaxReduce(HloInstruction* op) {
+  HloComputation* reduce_comp = op->to_apply();
+  HloInstruction* reduce_comp_root = reduce_comp->root_instruction();
+  return ShapeUtil::IsScalar(op->shape()) &&
+         ShapeUtil::IsScalar(op->operand(1)->shape()) &&
+         op->operand(1)->IsConstant() &&
+         op->operand(1)->literal().GetAsDouble({}) <= 0. &&
+         reduce_comp_root->opcode() == HloOpcode::kMaximum &&
+         reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter &&
+         reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter;
+}
+
+// Recursively captures and serializes the graph of pointwise operations
+// operating on the convolution.
+void CaptureConvGraphRecursive(HloInstruction* instr,
+                               std::vector<HloInstruction*>& operands,
+                               std::vector<HloInstruction*>& aux_outputs,
+                               GraphString& graph_string,
+                               absl::flat_hash_set<int>& visited_instrs,
+                               HloInstruction*& final_instr) {
+  // Avoid visiting the same instruction more than once.
+  if (!visited_instrs.emplace(instr->unique_id()).second) {
+    return;
+  }
+  final_instr = instr;
+
+  // Copy the current state in case fusion will be unsuccessful or unfavorable.
+  GraphString init_graph_string = graph_string;
+  std::vector<HloInstruction*> init_operands = operands,
+                               init_aux_outputs = aux_outputs;
+  // The loop adds each user of `instr` that supports fusion into the
+  // cuDNN convolution Custom Call to GraphString. Most ops following the
+  // convolution describe a linear sequence that generates a single return
+  // tensor. The identification of one of these linear ops is followed by a
+  // recursive call of CaptureConvGraphRecursive to match and potentially fuse
+  // its users. The calculation of the scalar maximum of the absolute value
+  // (Amax) of a preceding op is considered a nonlinear user as it adds a
+  // return value to the convolution. The users of a nonlinear op are
+  // not considered for fusion into the Custom Call. The numbers of linear and
+  // nonlinear users of `instr` are stored in `num_linear_users` and
+  // `num_nonlinear_users`.
+  int num_linear_users = 0, num_nonlinear_users = 0;
+  for (HloInstruction* user : instr->users()) {
+    HloInstruction *op, *operand0, *operand1;
+    // Add
+    if (Match(user, m::AddAnyOrder(&op, m::Op(&operand0), m::Op(&operand1)))) {
+      if (graph_string.AppendOp("add", op, {operand0, operand1})) {
+        operands.push_back(operand0 == instr ? operand1 : operand0);
+        num_linear_users++;
+        CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
+                                  visited_instrs, final_instr);
+      }
+      continue;
+    }
+    // Scale
+    if (Match(user, m::MultiplyAnyOrder(&op, m::Op(&operand0),
+                                        m::Broadcast(m::Op(&operand1)))) &&
+        ShapeUtil::IsScalar(operand1->shape())) {
+      if (graph_string.AppendOp("scale", op, {operand0, operand1})) {
+        operands.push_back(operand1);
+        num_linear_users++;
+        CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
+                                  visited_instrs, final_instr);
+      }
+      continue;
+    }
+    // Inverse Scale
+    if (Match(user, m::Divide(&op, m::Op(&operand0),
+                              m::Broadcast(m::Op(&operand1)))) &&
+        ShapeUtil::IsScalar(operand1->shape())) {
+      if (graph_string.AppendOp("invscale", op, {operand0, operand1})) {
+        operands.push_back(operand1);
+        num_linear_users++;
+        CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
+                                  visited_instrs, final_instr);
+      }
+      continue;
+    }
+    // ReLU
+    if (Match(user, m::MaximumAnyOrder(&op, m::Op(&operand0),
+                                       m::Broadcast(m::ConstantScalar(0))))) {
+      if (graph_string.AppendOp("relu", op, {operand0})) {
+        num_linear_users++;
+        CaptureConvGraphRecursive(user, operands, aux_outputs, graph_string,
+                                  visited_instrs, final_instr);
+      }
+      continue;
+    }
+    //  Maximum of the absolute value (Amax) following ReLU (elided Abs) -- not
+    //  a linear user
+    if (Match(user, m::Reduce(&op, m::Op(&operand0), m::Op())) &&
+        graph_string.OpInGraph(operand0->unique_id(), "relu") &&
+        AppliesMaxReduce(op)) {
+      if (graph_string.AppendOp("amax", op, {operand0})) {
+        aux_outputs.emplace_back(op);
+        num_nonlinear_users++;
+      }
+      continue;
+    }
+
+    // The following patterns match the user of `user`.
+    if (!user->users().empty()) {
+      HloInstruction* users_user = user->users()[0];
+      // Convert with Clamp to FP8 types
+      std::optional<PrimitiveType> f8_type = IsSaturatingCastToF8(users_user);
+      if (f8_type.has_value()) {
+        graph_string.ChangeDataType(f8_type.value());
+        num_linear_users++;
+        CaptureConvGraphRecursive(users_user, operands, aux_outputs,
+                                  graph_string, visited_instrs, final_instr);
+        continue;
+      }
+      // Maximum of the absolute value (Amax) -- not a linear user
+      if (Match(users_user,
+                m::Reduce(&op, m::Abs(m::Op(&operand0)), m::Op())) &&
+          AppliesMaxReduce(op)) {
+        if (graph_string.AppendOp("amax", op, {operand0})) {
+          aux_outputs.emplace_back(op);
+          num_nonlinear_users++;
+        }
+        continue;
+      }
+    }
+  }
+  // Do not fuse into the cuDNN convolution Custom Call when there is more
+  // than one linear or nonlinear user, or when the number of users eligible
+  // for fusion is less than the total number of users.
+  if (num_linear_users > 1 || num_nonlinear_users > 1 ||
+      num_linear_users + num_nonlinear_users < instr->user_count()) {
+    graph_string = init_graph_string;
+    operands = init_operands;
+    aux_outputs = init_aux_outputs;
+    final_instr = instr;
+  }
+}
+
+// Captures in a GraphString the subgraph of pointwise operations operating on
+// the convolution that will be fused into the cuDNN convolution Custom Call.
+absl::StatusOr<
+    std::tuple<std::vector<HloInstruction*>, std::vector<HloInstruction*>,
+               GraphString, HloInstruction*>>
+CaptureConvGraph(HloInstruction* instr, HloInstruction* convolution,
+                 HloInstruction* wide_input, HloInstruction* wide_filter,
+                 HloInstruction* input_scale, HloInstruction* filter_scale,
+                 bool x_mult_scale, bool w_mult_scale) {
+  GraphString graph_string;
+  graph_string.AppendOp("conv", instr);
+
+  // Shift the scaling of the input and filter to the output of the convolution.
+  HloInstruction *input_scaled_conv, *filter_scaled_conv;
+  if (input_scale) {
+    TF_RETURN_IF_ERROR(convolution->ReplaceOperandWith(0, wide_input));
+    HloInstruction* bcast_input_scale = instr->AddInstruction(
+        HloInstruction::CreateBroadcast(instr->shape(), input_scale, {}));
+    input_scaled_conv = instr->AddInstruction(HloInstruction::CreateBinary(
+        instr->shape(),
+        x_mult_scale ? HloOpcode::kMultiply : HloOpcode::kDivide, instr,
+        bcast_input_scale));
+    TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(input_scaled_conv));
+  }
+  if (filter_scale) {
+    TF_RETURN_IF_ERROR(convolution->ReplaceOperandWith(1, wide_filter));
+    HloInstruction* bcast_filter_scale = instr->AddInstruction(
+        HloInstruction::CreateBroadcast(instr->shape(), filter_scale, {}));
+    filter_scaled_conv = instr->AddInstruction(HloInstruction::CreateBinary(
+        instr->shape(),
+        w_mult_scale ? HloOpcode::kMultiply : HloOpcode::kDivide,
+        input_scale ? input_scaled_conv : instr, bcast_filter_scale));
+    TF_RETURN_IF_ERROR((input_scale ? input_scaled_conv : instr)
+                           ->ReplaceAllUsesWith(filter_scaled_conv));
+  }
+
+  std::vector<HloInstruction*> operands, aux_outputs;
+  absl::flat_hash_set<int> visited_instrs;
+  HloInstruction* final_instr;
+  CaptureConvGraphRecursive(instr, operands, aux_outputs, graph_string,
+                            visited_instrs, final_instr);
+  return std::make_tuple(operands, aux_outputs, graph_string, final_instr);
+}
+
+// Matches convolutions operating on FP8 inputs and filters and rewrites them
+// into a ForwardGraph Custom Call. For scaled FP8 convolutions on Hopper
+// systems, the following steps are elided and fused into that Custom Call:
+//
+// 1. Cast the filter and input from FP8 to a wider type such as FP16 or FP32.
+// 2. Optionally unscale the filter and input by multiplying or dividing by
+// scalars.
+// 3. Evaluate the convolution based on the scaled filter and input.
+// 4. Apply a series of elementwise transformations, where a transformation can
+// be adding a matrix bias, applying a ReLU activation, or multiplying or
+// dividing by a broadcast scalar.
+// 5. Optionally calculate the maximum of the absolute value of the result.
+// 6. Optionally cast the output back to FP8.
+absl::StatusOr<bool> F8GraphConv(HloComputation* comp,
+                                 se::CudaComputeCapability cc,
+                                 se::dnn::VersionInfo dnn_version,
+                                 int32_t toolkit_version) {
+  bool changed = false;
+
+  if (dnn_version < se::dnn::VersionInfo(8, 9, 0)) {
+    return false;
+  }
+  if (toolkit_version < 12000) {
+    return false;
+  }
+  if (!cc.IsAtLeast(se::CudaComputeCapability::HOPPER)) {
+    return false;
+  }
+  for (auto instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction *convolution, *gte, *input, *filter,
+        *input_scale = nullptr, *filter_scale = nullptr,
+        *input_scale_op = nullptr, *filter_scale_op = nullptr,
+        *wide_input = nullptr, *wide_filter = nullptr;
+
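+    // Matches a conv operand that is an FP8 tensor, optionally converted to
+    // a wider type and optionally scaled (multiplied or divided) by a
+    // broadcast scalar.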
+    auto conv_operand_maybe_scaled = [](HloInstruction** operand,
+                                        HloInstruction** wide_operand,
+                                        HloInstruction** scale_op,
+                                        HloInstruction** scale) {
+      return m::AnyOf<HloInstruction>(
+          m::Op(operand).WithPredicate(IsF8Type),
+          m::Convert(wide_operand, m::Op(operand).WithPredicate(IsF8Type)),
+          m::Divide(
+              scale_op,
+              m::Convert(wide_operand, m::Op(operand).WithPredicate(IsF8Type)),
+              m::Broadcast(m::Op(scale).WithPredicate(IsScalar))),
+          m::MultiplyAnyOrder(
+              scale_op,
+              m::Convert(wide_operand, m::Op(operand).WithPredicate(IsF8Type)),
+              m::Broadcast(m::Op(scale).WithPredicate(IsScalar))));
+    };
+
+    // TODO(philipphack): Consider allowing ops between dequantization and
+    // convolution.
+    auto pattern = m::GetTupleElement(
+        &gte,
+        m::CustomCall(
+            &convolution,
+            conv_operand_maybe_scaled(&input, &wide_input, &input_scale_op,
+                                      &input_scale),
+            conv_operand_maybe_scaled(&filter, &wide_filter, &filter_scale_op,
+                                      &filter_scale))
+            .WithPredicate(IsConvCustomCall),
+        0);
+    if (Match(instr, pattern)) {
+      if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+            return absl::StrCat("F8GraphConv: ", convolution->ToString());
+          })) {
+        continue;
+      }
+
+      std::vector<HloInstruction*> operands, aux_outputs;
+      GraphString graph_string;
+      HloInstruction* final_instr;
+
+      TF_ASSIGN_OR_RETURN(
+          std::tie(operands, aux_outputs, graph_string, final_instr),
+          CaptureConvGraph(
+              instr, convolution, wide_input, wide_filter, input_scale,
+              filter_scale,
+              input_scale_op ? input_scale_op->opcode() == HloOpcode::kMultiply
+                             : false,
+              filter_scale_op
+                  ? filter_scale_op->opcode() == HloOpcode::kMultiply
+                  : false));
+      TF_ASSIGN_OR_RETURN(auto gpu_config,
+                          convolution->backend_config<GpuBackendConfig>());
+      CudnnConvBackendConfig& config =
+          *gpu_config.mutable_cudnn_conv_backend_config();
+
+      config.set_serialized_graph(graph_string.Graph());
+      operands.insert(operands.begin(), input);
+      operands.insert(operands.begin() + 1, filter);
+
+      std::vector<Shape> output_shapes;
+      output_shapes.emplace_back(ShapeUtil::ChangeElementType(
+          ShapeUtil::GetTupleElementShape(convolution->shape(), 0),
+          final_instr->shape().element_type()));
+      for (HloInstruction* aux_output : aux_outputs) {
+        output_shapes.emplace_back(aux_output->shape());
+      }
+      output_shapes.emplace_back(
+          ShapeUtil::GetTupleElementShape(convolution->shape(), 1));
+
+      HloInstruction* new_convolution =
+          comp->AddInstruction(convolution->CloneWithNewOperands(
+              ShapeUtil::MakeTupleShape(output_shapes), operands));
+
+      new_convolution->set_custom_call_target(kCudnnConvForwardGraphCallTarget);
+      TF_RETURN_IF_ERROR(new_convolution->set_backend_config(gpu_config));
+      TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
+                          MakeGetTupleElementHlo(new_convolution, 0));
+      TF_RETURN_IF_ERROR(comp->ReplaceInstruction(final_instr, new_gte));
+
+      for (int i = 0; i < aux_outputs.size(); ++i) {
+        TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
+                            MakeGetTupleElementHlo(new_convolution, i + 1));
+        TF_RETURN_IF_ERROR(comp->ReplaceInstruction(aux_outputs[i], new_gte));
+      }
+
+      changed = true;
+    }
+  }
+  return changed;
+}
+
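+// Fuses add(gte(conv(...)), addend) into the convolution, treating the
+// addend as a bias when possible and otherwise as a side input.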
+absl::StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
+  bool changed = false;
+  for (auto instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction* conv = nullptr;
+    HloInstruction* gte = nullptr;
+    HloInstruction* addend = nullptr;
+
+    auto pattern = m::AddAnyOrder(
+        m::GetTupleElement(&gte,
+                           m::Op(&conv)
+                               .WithPredicate(IsNonDepthwiseConvCustomCall)
+                               .WithOneUse(),
+                           0)
+            .WithOneUse(),
+        m::Op(&addend));
+    if (!Match(instr, pattern)) {
+      continue;
+    }
+
+    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+          return absl::StrCat("FuseBiasOrSideInput: ", conv->ToString());
+        })) {
+      continue;
+    }
+
+    // If it's a vanilla forward conv, upgrade it to a bias-activation conv.  We
+    // only want to do this if the fusion will succeed, but we're guaranteed
+    // that it will, because the only reason we'll bail at this point is if
+    // !can_accept_bias && !can_accept_side_input, and our shiny new
+    // bias-activation conv will be able to accept both.
+    if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
+      TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
+    }
+
+    // Can't fuse bias or side-input if the conv already has a relu (or other
+    // activation), because bias and side-input are added before the activation
+    // is applied.
+    TF_ASSIGN_OR_RETURN(auto gpu_config,
+                        conv->backend_config<GpuBackendConfig>());
+    CudnnConvBackendConfig& config =
+        *gpu_config.mutable_cudnn_conv_backend_config();
+    if (config.activation_mode() != se::dnn::kNone) {
+      continue;
+    }
+
+    // Does `conv` already have a (nonzero) bias?  Does it already have a
+    // side_input?
+    bool can_accept_bias =
+        Match(conv->operand(2), m::Broadcast(m::ConstantEffectiveScalar(0)));
+    bool can_accept_side_input = conv->operand_count() < 4;
+
+    // The addend can be fused as a bias if
+    //  - it is 1D broadcasted in the output feature dimension, and
+    //  - it is losslessly-convertible to the correct type (f32 for s8/f32/u32
+    //    convs, and conv_ty for floating-point convs)
+    PrimitiveType conv_ty = gte->shape().element_type();
+    PrimitiveType bias_ty =
+        primitive_util::IsFloatingPointType(conv_ty) ? conv_ty : F32;
+    bool addend_may_be_rank1_bias =
+        addend->opcode() == HloOpcode::kBroadcast &&
+        addend->dimensions().size() == 1 &&
+        addend->dimensions(0) ==
+            conv->convolution_dimension_numbers().output_feature_dimension() &&
+        IsLosslesslyConvertibleTo(addend, bias_ty);
+
+    bool addend_may_be_rank0_bias = addend->opcode() == HloOpcode::kBroadcast &&
+                                    addend->dimensions().empty() &&
+                                    IsLosslesslyConvertibleTo(addend, bias_ty);
+
+    absl::InlinedVector<HloInstruction*, 4> new_operands(
+        conv->operands().begin(), conv->operands().end());
+    if (can_accept_bias && addend_may_be_rank1_bias) {
+      new_operands[2] = MakeConvertToHlo(addend->mutable_operand(0), bias_ty,
+                                         &addend->operand(0)->metadata());
+    } else if (can_accept_bias && addend_may_be_rank0_bias) {
+      new_operands[2] = MakeBroadcastHlo(
+          MakeConvertToHlo(addend->mutable_operand(0), bias_ty,
+                           &addend->operand(0)->metadata()),
+          /*broadcast_dimensions=*/{},
+          /*result_shape_bounds=*/
+          {gte->shape().dimensions(conv->convolution_dimension_numbers()
+                                       .output_feature_dimension())});
+    } else if (can_accept_side_input) {
+      CHECK_EQ(new_operands.size(), 3);
+      new_operands.push_back(addend);
+      config.set_side_input_scale(1);
+    } else {
+      // Can't fuse; this op already has a bias and a side-input.
+      continue;
+    }
+
+    HloInstruction* new_conv = comp->AddInstruction(
+        conv->CloneWithNewOperands(conv->shape(), new_operands));
+    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
+    TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config));
+    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
+                        MakeGetTupleElementHlo(new_conv, 0));
+    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
+    changed = true;
+  }
+  return changed;
+}
+
+// custom-call(..., alpha * side_input) ->
+// custom-call(..., side_input, backend_config={alpha}).
+//
+// We also have to support the more complicated case of
+//
+//   custom-call(..., reshape(side_input * alpha)) -->
+//   custom-call(..., reshape(side_input), backend_config={alpha}),
+//
+// where `reshape` can be an arbitrary chain of reshapes+transposes.  This idiom
+// is created by the ReshapeMover pass.
+absl::StatusOr<bool> FuseSideInputAlpha(HloComputation* comp) {
+  bool changed = false;
+  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction* conv;
+    HloInstruction* side_input;
+    auto pattern = m::Op(&conv)
+                       .WithPredicate(IsConvCustomCall)
+                       .WithOperand(3, m::Op(&side_input));
+    if (!Match(instr, pattern)) {
+      continue;
+    }
+    TF_ASSIGN_OR_RETURN(auto gpu_config,
+                        conv->backend_config<GpuBackendConfig>());
+    CudnnConvBackendConfig& config =
+        *gpu_config.mutable_cudnn_conv_backend_config();
+    if (config.side_input_scale() != 1) {
+      continue;
+    }
+
+    // Given side_input, pattern match the following (working from bottom up).
+    //
+    // before_reshape = multiply(base, broadcast(alpha))
+    // side_input = chain_of_reshapes_and_transposes(before_reshape)
+    //
+    // where alpha is a scalar constant.
+    //
+    // alpha is f32 except for f64 convs, where it's f64.  See
+    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
+    HloInstruction* before_reshape = side_input;
+    while (before_reshape->opcode() == HloOpcode::kReshape ||
+           before_reshape->opcode() == HloOpcode::kTranspose) {
+      before_reshape = before_reshape->mutable_operand(0);
+    }
+
+    PrimitiveType conv_ty = conv->shape().tuple_shapes(0).element_type();
+    PrimitiveType alpha_ty = conv_ty == F64 ? F64 : F32;
+    HloInstruction* base;
+    HloInstruction* alpha;
+    if (!Match(
+            before_reshape,
+            m::MultiplyAnyOrder(
+                m::Op(&base),
+                m::Broadcast(m::ConstantEffectiveScalar(&alpha).WithPredicate(
+                    [&](const HloInstruction* instr) {
+                      return IsLosslesslyConvertibleTo(instr, alpha_ty);
+                    }))))) {
+      continue;
+    }
+    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+          return absl::StrCat("FuseSideInputAlpha: ", conv->ToString());
+        })) {
+      continue;
+    }
+
+    // Rewrite conv's operand 3 to
+    //
+    //   chain_of_reshapes_and_transposes(before_reshape).
+    //
+    // and store alpha in the conv's backend config.
+    //
+    // We're going to do something bad here: We aren't going to check that the
+    // chain of reshapes/transposes has one use, so we're potentially
+    // duplicating all these instructions (once with alpha and once without).
+    //
+    // This is justified because
+    //
+    //  - duplicating reshapes/transposes shouldn't be "that bad" -- these
+    //    instructions can usually be fused, and
+    //
+    //  - *not* fusing alpha can be catastrophic.  For s8->s8 convolutions, the
+    //    side-input must be s8.  But the product side_input * alpha is f32, so
+    //    we can only see that side-input is s8 if we fuse alpha. IOW not fusing
+    //    alpha means we'll run this s8->s8 conv as s8->f32, which is *much*
+    //    slower than some extra transposes.
+
+    // Recursively clone chain_of_reshapes_and_transposes until we get to
+    // `before_reshape`, at which point we skip the multiply(base, alpha) and
+    // just return base.
+    std::function<HloInstruction*(const HloInstruction*)> clone =
+        [&](const HloInstruction* instr) {
+          if (instr == before_reshape) {
+            return base;
+          }
+          CHECK(instr->opcode() == HloOpcode::kReshape ||
+                instr->opcode() == HloOpcode::kTranspose)
+              << "Must be reshape or transpose: " << instr->ToString();
+          return comp->AddInstruction(instr->CloneWithNewOperands(
+              instr->shape(), {clone(instr->operand(0))}));
+        };
+    absl::InlinedVector<HloInstruction*, 4> new_operands(
+        conv->operands().begin(), conv->operands().end());
+    new_operands[3] = clone(side_input);
+
+    HloInstruction* new_conv = comp->AddInstruction(
+        conv->CloneWithNewOperands(conv->shape(), new_operands));
+    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
+
+    TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
+    config.set_side_input_scale(alpha_f64.GetFirstElement<double>());
+    TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config));
+
+    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
+    changed = true;
+  }
+  return changed;
+}
+
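+// Fuses elu(conv(...)) into the convolution, where elu is matched as
+// select(conv > 0, conv, expm1(conv)).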
+absl::StatusOr<bool> FuseElu(HloComputation* comp,
+                             se::GpuComputeCapability cc) {
+  if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
+                                   cc)) {
+    return false;
+  }
+
+  bool changed = false;
+  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction *gte1, *gte2, *gte3;
+    HloInstruction* conv;
+    HloInstruction* expm1;
+
+    if (!Match(instr,
+               m::Select(m::Compare(m::GetTupleElement(&gte1, m::Op()),
+                                    m::Broadcast(m::ConstantEffectiveScalar(0)))
+                             .WithComparisonDirection(ComparisonDirection::kGt)
+                             .WithOneUse(),
+                         m::GetTupleElement(
+                             &gte2,
+                             m::Op(&conv)
+                                 .WithPredicate(IsNonDepthwiseConvCustomCall)
+                                 .WithOneUse(),
+                             /*tuple_index=*/0)
+                             // TODO(jlebar): Why only fp16?
+                             .WithElementType(F16),
+                         m::Op(&expm1)
+                             .WithOpcode(HloOpcode::kExpm1)
+                             .WithOperand(0, m::GetTupleElement(&gte3, m::Op()))
+                             .WithOneUse()))) {
+      continue;
+    }
+
+    // The three GTEs should be the same, and these should be the only uses.
+    if (gte1 != gte2 || gte2 != gte3 || gte1->user_count() != 3) {
+      continue;
+    }
+
+    if (!IsSuitableForCudnnRuntimeFusion(conv)) {
+      continue;
+    }
+
+    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                        conv->backend_config<GpuBackendConfig>());
+    CudnnConvBackendConfig& config =
+        *gpu_config.mutable_cudnn_conv_backend_config();
+    if (config.activation_mode() != se::dnn::kNone) {
+      continue;
+    }
+
+    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+          return absl::StrCat("FuseElu: ", conv->ToString());
+        })) {
+      continue;
+    }
+    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
+    config.set_activation_mode(se::dnn::kElu);
+    TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
+    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1));
+    changed = true;
+  }
+  return changed;
+}
+
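+// Fuses max(conv(...), 0) into the convolution as a ReLU activation.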
+absl::StatusOr<bool> FuseRelu(HloComputation* comp) {
+  bool changed = false;
+  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction* gte;
+    HloInstruction* conv;
+    if (!Match(instr,
+               m::MaximumAnyOrder(
+                   m::Broadcast(m::ConstantEffectiveScalar(0)),
+                   m::GetTupleElement(
+                       &gte, m::Op(&conv)
+                                 .WithPredicate(IsNonDepthwiseConvCustomCall)
+                                 .WithOneUse())
+                       .WithOneUse()))) {
+      continue;
+    }
+    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                        conv->backend_config<GpuBackendConfig>());
+    CudnnConvBackendConfig& config =
+        *gpu_config.mutable_cudnn_conv_backend_config();
+    if (config.activation_mode() != se::dnn::kNone) {
+      continue;
+    }
+
+    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+          return absl::StrCat("FuseRelu: ", conv->ToString());
+        })) {
+      continue;
+    }
+    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
+    config.set_activation_mode(se::dnn::kRelu);
+    TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
+    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
+    changed = true;
+  }
+  return changed;
+}
+
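+// Fuses clamp(0, conv(...), 6) into the convolution as a ReLU6 activation;
+// this requires the cuDNN runtime fusion engine.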
+absl::StatusOr<bool> FuseRelu6(HloComputation* comp,
+                               se::GpuComputeCapability cc) {
+  if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
+                                   cc)) {
+    return false;
+  }
+
+  bool changed = false;
+  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction *gte, *conv;
+    if (!Match(
+            instr,
+            m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(0)),
+                     m::GetTupleElement(
+                         &gte, m::Op(&conv)
+                                   .WithPredicate(IsNonDepthwiseConvCustomCall)
+                                   .WithOneUse())
+                         // TODO(jlebar): Why only fp16?
+                         .WithElementType(F16)
+                         .WithOneUse(),
+                     m::Broadcast(m::ConstantEffectiveScalar(6))))) {
+      continue;
+    }
+    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                        conv->backend_config<GpuBackendConfig>());
+    CudnnConvBackendConfig& config =
+        *gpu_config.mutable_cudnn_conv_backend_config();
+    if (config.activation_mode() != se::dnn::kNone) {
+      continue;
+    }
+
+    if (!IsSuitableForCudnnRuntimeFusion(conv)) {
+      continue;
+    }
+
+    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+          return absl::StrCat("FuseRelu6: ", conv->ToString());
+        })) {
+      continue;
+    }
+    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
+    config.set_activation_mode(se::dnn::kRelu6);
+    TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
+    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
+    changed = true;
+  }
+  return changed;
+}
+
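+// Fuses select(conv > 0, conv, alpha * conv) into the convolution as a
+// leaky-ReLU activation with the given alpha; this requires the cuDNN
+// runtime fusion engine.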
+absl::StatusOr<bool> FuseLeakyRelu(HloComputation* comp,
+                                   se::GpuComputeCapability cc) {
+  if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
+                                   cc)) {
+    return false;
+  }
+
+  bool changed = false;
+  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction *gte1, *gte2, *gte3, *conv, *alpha;
+    if (!Match(instr,
+               m::Select(
+                   m::Compare(m::GetTupleElement(&gte1, m::Op()),
+                              m::Broadcast(m::ConstantEffectiveScalar(0)))
+                       .WithComparisonDirection(ComparisonDirection::kGt)
+                       .WithOneUse(),
+                   m::GetTupleElement(
+                       &gte2, m::Op(&conv)
+                                  .WithPredicate(IsNonDepthwiseConvCustomCall)
+                                  .WithOneUse())
+                       // TODO(jlebar): Why only fp16?
+                       .WithElementType(F16),
+                   m::Multiply(m::GetTupleElement(&gte3, m::Op()),
+                               m::Broadcast(m::ConstantEffectiveScalar(&alpha)))
+                       .WithOneUse()))) {
+      continue;
+    }
+
+    // The three GTEs should all be the same instruction, and the compare,
+    // select, and multiply above should be its only users.
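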
+    if (gte1 != gte2 || gte2 != gte3 || gte1->user_count() != 3) {
+      continue;
+    }
+
+    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                        conv->backend_config<GpuBackendConfig>());
+    CudnnConvBackendConfig& config =
+        *gpu_config.mutable_cudnn_conv_backend_config();
+    if (config.activation_mode() != se::dnn::kNone) {
+      continue;
+    }
+
+    if (!IsSuitableForCudnnRuntimeFusion(conv)) {
+      continue;
+    }
+
+    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+          return absl::StrCat("FuseLeakyRelu: ", conv->ToString());
+        })) {
+      continue;
+    }
+    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
+    config.set_activation_mode(se::dnn::kLeakyRelu);
+    TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
+    config.set_leakyrelu_alpha(alpha_f64.GetFirstElement<double>());
+    TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
+    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1));
+    changed = true;
+  }
+  return changed;
+}
+
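+// Rewrites convert_f16(conv_f32(...)) into an f16 convolution when every
+// operand of the f32 convolution is losslessly convertible to f16.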
+absl::StatusOr<bool> FuseConvertToF16(HloComputation* comp) {
+  bool changed = false;
+  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction* gte = nullptr;
+    HloInstruction* conv = nullptr;
+
+    auto f32_convertible_to_f16_pat =
+        m::Op().WithElementType(F32).WithPredicate(
+            IsLosslesslyConvertibleToF16);
+    if (!MatchAndLogIfFailed(
+            instr, "f16 conv",
+            m::Convert(
+                m::GetTupleElement(
+                    &gte,
+                    m::Op(&conv)
+                        .WithPredicate(IsConvCustomCall)
+                        .WithOperand(0, f32_convertible_to_f16_pat)
+                        .WithOperand(1, f32_convertible_to_f16_pat)
+                        .WithOperandIfPresent(2, f32_convertible_to_f16_pat)
+                        .WithOperandIfPresent(3, f32_convertible_to_f16_pat),
+                    0)
+                    .WithOneUse())
+                .WithElementType(F16),
+            VLOG_IS_ON(3),
+            m::Op().WithOperand(0, m::GetTupleElement(m::Op().WithPredicate(
+                                       IsConvCustomCall))))) {
+      continue;
+    }
+    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+          return absl::StrCat("FuseConvertToF16: ", conv->ToString());
+        })) {
+      continue;
+    }
+
+    VLOG(2) << "Matched fp16 conv: " << conv->ToString();
+
+    // In fp16 convs, all operands, including `bias`, must be fp16.  This is
+    // different from int8 convs, where the bias is fp32.  See table of
+    // supported datatypes at
+    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
+    absl::InlinedVector<HloInstruction*, 4> new_operands;
+    for (HloInstruction* operand : conv->operands()) {
+      new_operands.push_back(
+          MakeConvertToHlo(operand, F16, &operand->metadata()));
+    }
+
+    Shape new_shape = conv->shape();
+    new_shape.mutable_tuple_shapes(0)->set_element_type(F16);
+
+    HloInstruction* new_conv = comp->AddInstruction(
+        conv->CloneWithNewOperands(new_shape, new_operands));
+    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
+    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
+                        MakeGetTupleElementHlo(new_conv, 0));
+    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
+    changed = true;
+  }
+  return changed;
+}
+
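+// Rewrites f32 convolutions whose input and filter are losslessly convertible
+// to s8 into s8 convolutions, producing either s8 output (when the result is
+// clamped to [-128, 127] and converted to s8) or f32 output.  Not applied on
+// ROCm.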
+absl::StatusOr<bool> FuseConvertToS8(HloComputation* comp,
+                                     se::GpuComputeCapability cc) {
+  if (IsROCm(cc)) return false;
+  bool changed = false;
+  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction* gte = nullptr;
+    HloInstruction* conv = nullptr;
+
+    auto conv_pattern =
+        m::Op(&conv)
+            .WithPredicate(IsConvCustomCall)
+            .WithOperand(0, m::Op().WithPredicate(IsLosslesslyConvertibleToS8))
+            .WithOperand(1, m::Op().WithPredicate(IsLosslesslyConvertibleToS8));
+
+    PrimitiveType conv_output_ty;
+    if (MatchAndLogIfFailed(
+            instr, "s8->s8 conv",
+            m::Convert(m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(-128)),
+                                m::GetTupleElement(
+                                    &gte,
+                                    conv_pattern.WithOperandIfPresent(
+                                        3, m::Op().WithPredicate(
+                                               IsLosslesslyConvertibleToS8)),
+                                    0)
+                                    .WithOneUse(),
+                                m::Broadcast(m::ConstantEffectiveScalar(127))))
+                .WithElementType(S8),
+            VLOG_IS_ON(3),
+            m::Convert(m::Clamp(m::Op(),
+                                m::GetTupleElement(
+                                    m::Op().WithPredicate(IsConvCustomCall)),
+                                m::Op()))
+                .WithElementType(S8))) {
+      conv_output_ty = S8;
+    } else if (MatchAndLogIfFailed(
+                   instr, "s8->f32 conv",
+                   m::GetTupleElement(&gte,
+                                      conv_pattern.WithOperandIfPresent(
+                                          3, m::Op().WithElementType(F32)),
+                                      0)
+                       .WithElementType(F32),
+                   VLOG_IS_ON(3),
+                   m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall))
+                       .WithElementType(F32))) {
+      conv_output_ty = F32;
+    } else {
+      continue;
+    }
+    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+          return absl::StrCat("FuseConvertToS8: ", conv->ToString());
+        })) {
+      continue;
+    }
+
+    absl::InlinedVector<HloInstruction*, 4> new_operands(
+        conv->operands().begin(), conv->operands().end());
+    new_operands[0] =
+        MakeConvertToHlo(new_operands[0], S8, &new_operands[0]->metadata());
+    new_operands[1] =
+        MakeConvertToHlo(new_operands[1], S8, &new_operands[1]->metadata());
+    // Don't convert bias (operand 2); it's always f32 for s8 ops in cudnn.  See
+    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
+    if (new_operands.size() >= 4) {
+      // The side-input always matches the conv output type.  We checked in
+      // the patterns above that it's losslessly-convertible to this type.
+      new_operands[3] = MakeConvertToHlo(new_operands[3], conv_output_ty,
+                                         &new_operands[3]->metadata());
+    }
+
+    Shape new_shape = conv->shape();
+    new_shape.mutable_tuple_shapes(0)->set_element_type(conv_output_ty);
+
+    HloInstruction* new_conv = comp->AddInstruction(
+        conv->CloneWithNewOperands(new_shape, new_operands));
+    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
+    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
+                        MakeGetTupleElementHlo(new_conv, 0));
+    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
+    changed = true;
+  }
+  return changed;
+}
+
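+// Returns an error if `comp` still contains conv custom-calls that consume or
+// produce integer types other than s8; cudnn cannot run those.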
+absl::Status CheckNoIllegalIntegerConvs(HloComputation* comp) {
+  auto is_integral_not_s8 = [](const Shape& s) {
+    return primitive_util::IsIntegralType(s.element_type()) &&
+           s.element_type() != S8;
+  };
+
+  std::vector<HloInstruction*> bad_convs;
+  for (HloInstruction* instr : comp->instructions()) {
+    if (!IsConvCustomCall(instr)) {
+      continue;
+    }
+    if (is_integral_not_s8(instr->shape().tuple_shapes(0)) ||
+        is_integral_not_s8(instr->operand(0)->shape()) ||
+        is_integral_not_s8(instr->operand(1)->shape()) ||
+        (instr->operand_count() >= 4 &&
+         is_integral_not_s8(instr->operand(3)->shape()))) {
+      bad_convs.push_back(instr);
+    }
+  }
+
+  if (bad_convs.empty()) {
+    return absl::OkStatus();
+  }
+
+  return Unimplemented(
+      R"(
+Can't lower one or more integer convolutions to idioms supported by CuDNN.
+
+CuDNN integer convolutions must have:
+
+  - s8 input and filter,
+  - f32 bias (if present),
+  - s8 or f32 output, and
+  - s8 side_input (if present) if output is s8.
+
+For each of the unsupported convs below, we weren't able to lower one of the
+operands or the output to the appropriate type.
+
+See specific HLO idioms in cudnn_fused_conv_rewriter.h, and see cudnn semantics:
+
+https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward and
+https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
+
+Unsupported convs:
+%s
+
+******* Full HLO module *******
+%s
+)",
+      absl::StrJoin(bad_convs, "\n",
+                    [](std::string* out, HloInstruction* instr) {
+                      absl::StrAppend(out, " - ", instr->ToString());
+                    }),
+      comp->parent()->ToString());
+}
+
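+// At VLOG level 1, logs a summary of the conv custom-calls left in `module`:
+// counts by fusion status, element types, bias/side-input presence, scales,
+// and activation mode.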
+void VlogStats(HloModule* module) {
+  if (!VLOG_IS_ON(1)) {
+    return;
+  }
+
+  VLOG(1) << "Results of CudnnFusedConvRewriter for " << module->name();
+  absl::flat_hash_map<std::string, int> stats;
+  for (HloComputation* comp : module->MakeNonfusionComputations()) {
+    for (HloInstruction* instr : comp->instructions()) {
+      if (!Match(instr, m::Op().WithPredicate(IsConvCustomCall))) {
+        continue;
+      }
+
+      VLOG(3) << instr->ToString();
+
+      if (instr->custom_call_target() == kCudnnConvForwardCallTarget) {
+        ++stats["01 non-fused forward convs"];
+      } else if (instr->custom_call_target() ==
+                 kCudnnConvBiasActivationForwardCallTarget) {
+        ++stats["02 fused forward convs"];
+      }
+
+      PrimitiveType conv_in_ty = instr->operand(0)->shape().element_type();
+      PrimitiveType conv_out_ty = instr->shape().tuple_shapes(0).element_type();
+      if (conv_in_ty == F32) {
+        ++stats["10 f32 convs"];
+      } else if (conv_in_ty == F16) {
+        ++stats["11 f16 convs"];
+      } else if (conv_in_ty == S8) {
+        if (conv_out_ty == S8) {
+          ++stats["12 s8->s8 convs"];
+        } else if (conv_out_ty == F32) {
+          ++stats["13 s8->f32 convs"];
+        } else {
+          LOG(ERROR) << "Unexpected conv: " << instr->ToString();
+        }
+      }
+
+      if (instr->operand_count() > 2) {
+        ++stats["20 convs with bias"];
+        if (Match(instr->operand(2),
+                  m::Broadcast(m::ConstantEffectiveScalar(0)))) {
+          ++stats["21 convs with 0 bias"];
+        }
+      }
+      if (instr->operand_count() > 3) {
+        ++stats["22 convs with side-input"];
+      }
+
+      auto gpu_config = instr->backend_config<GpuBackendConfig>();
+      if (!gpu_config.ok()) {
+        LOG(ERROR) << "Couldn't parse backend config for " << instr->ToString();
+        continue;
+      }
+      const CudnnConvBackendConfig& config =
+          gpu_config->cudnn_conv_backend_config();
+      if (config.conv_result_scale() != 1) {
+        ++stats["30 convs with result scale"];
+      }
+      if (config.side_input_scale() != 0 && config.side_input_scale() != 1) {
+        ++stats["31 convs with side-input scale"];
+      }
+      ++stats[absl::StrCat(
+          "32 convs with activation mode ",
+          se::dnn::ActivationMode_Name(config.activation_mode()))];
+    }
+  }
+
+  std::vector<std::pair<std::string, int>> stats_sorted(stats.begin(),
+                                                        stats.end());
+  absl::c_sort(stats_sorted);
+  for (const auto& kv : stats_sorted) {
+    VLOG(1) << absl::StreamFormat("%4d %s", kv.second,
+                                  absl::string_view(kv.first).substr(3));
+  }
+}
+
+}  // namespace
+
+absl::StatusOr<bool> CudnnFusedConvRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool any_changed = false;
+
+  for (HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    bool changed = false;
+    // Rewrite FP8 convolutions and supported adjacent pointwise ops into a
+    // ForwardGraph Custom Call.
+    if (!IsROCm(compute_capability_)) {
+      auto cc = std::get<se::CudaComputeCapability>(compute_capability_);
+      TF_ASSIGN_OR_RETURN(
+          changed, F8GraphConv(comp, cc, dnn_version_, toolkit_version_));
+      if (changed) {
+        return changed;
+      }
+    }
+    // Fuse "inside out" starting with the operations closest to the conv.
+    TF_ASSIGN_OR_RETURN(changed, FuseRemoveConvertInConv(comp));
+    any_changed |= changed;
+
+    TF_ASSIGN_OR_RETURN(changed, FuseConvAlpha(comp));
+    any_changed |= changed;
+
+    // s8 convs' bias and side-input appear before conversion to s8.
+    //
+    // Run FuseBiasOrSideInput twice, so we get both the bias and the side
+    // input, if both are present.
+    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
+    any_changed |= changed;
+    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
+    any_changed |= changed;
+    TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
+    any_changed |= changed;
+
+    // Relu might appear before or after convert-to-f16/s8, so we check in both
+    // cases.
+    TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
+    any_changed |= changed;
+    TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_));
+    any_changed |= changed;
+    TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_));
+    any_changed |= changed;
+    TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_));
+    any_changed |= changed;
+
+    TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp));
+    any_changed |= changed;
+
+    TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp, compute_capability_));
+    any_changed |= changed;
+
+    // f16 convs' bias+side-input can appear before or after conversion to f16.
+    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
+    any_changed |= changed;
+    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
+    any_changed |= changed;
+    TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
+    any_changed |= changed;
+
+    TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
+    any_changed |= changed;
+    TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_));
+    any_changed |= changed;
+    TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_));
+    any_changed |= changed;
+    TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_));
+    any_changed |= changed;
+
+    // Check that we don't have any convs outputting integer types other than
+    // s8 - cudnn does not support these.  They should have been transformed to
+    // int8->int8 or int8->float above.
+    TF_RETURN_IF_ERROR(CheckNoIllegalIntegerConvs(comp));
+  }
+
+  VlogStats(module);
+
+  return any_changed;
+}
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h
new file mode 100644
index 0000000..5caf9c0
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h
@@ -0,0 +1,135 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_CONV_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_CONV_REWRITER_H_
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites custom-calls targeting cudnnConvolutionForward to
+// cudnnConvolutionBiasActivationForward by fusing operations following forward
+// convolution.  This transform must run after GpuConvRewriter.
+//
+// Semantics of underlying cudnn ops:
+//
+// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#cudnnconvolutionforward
+// https://docs.nvidia.com/deeplearning/cudnn/latest/developer/misc.html#scaling-parameters
+//
+// ## Floating-point convs
+//
+// A "complete" fused floating-point conv has the form
+//
+//   max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)),
+//
+// which we fuse to
+//
+//   cudnnConvolutionBiasActivationForward(x, w, bias, side_input).
+//
+// You can leave out side_input, bias, alpha1, alpha2, and max(x, 0) and still
+// get a fused convolution.  alpha1/2 must be broadcasts of scalar constants.
+//
+// f16 convs accumulate in f32.  We represent this in HLO as an f32 convolution
+// whose inputs can be converted to f16 without loss of precision and whose
+// output is immediately converted to f16.  A fused f16 conv must follow one of
+// the following idioms.
+//
+//   1. convert_f16(conv_f32(x_f32, w_f32)) +
+//      side_input_f16 + broadcast(bias_f16)
+//
+//   2. convert_f16(conv_f32(x_f32, w_f32) +
+//                  side_input_f32 + broadcast(bias_f32))
+//
+// (These are not strictly mathematically equivalent, but cudnn doesn't tell us
+// which one it does, and we deem them "close enough".)
+//
+// The foo_f32 HLOs must all be losslessly-convertible to f16.  Some valid
+// examples:
+//
+//   - foo_f32 = convert_f32(foo_f16)
+//   - foo_f32 = an f32 constant whose values all fit within f16
+//   - foo_f32 = broadcast/transpose/reshape(one of the above)
+//
+// If you have a relu, it can appear before or after the convert_f16.
+//
+// Note that here `bias` must be losslessly-convertible to f16; this differs
+// from s8 convolutions, where bias is f32.
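+//
+// As a purely illustrative sketch (the names here are invented), idiom 1 with
+// a bias and a relu looks roughly like
+//
+//   conv       = conv_f32(convert_f32(x_f16), convert_f32(w_f16))
+//   sum        = convert_f16(conv) + broadcast(bias_f16)
+//   result_f16 = max(0, sum)
+//
+// which this pass fuses into a single cudnnConvolutionBiasActivationForward
+// custom-call with relu activation.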
+//
+// ## Integer convs
+//
+// In pure HLO, a "complete" integer conv is spelled as one of the following
+// `result`s.
+//
+//   base = alpha1_f32 * convert_f32(conv_s32(input_s32, filter_s32)) +
+//          alpha2_f32 * side_input +
+//          bias_f32
+//
+//   result_f32        = max(base, 0)
+//   result_s8_option1 = max(convert_s8(clamp(-128, base, 127)), 0)
+//   result_s8_option2 = convert_s8(clamp(-128, max(base, 0), 127))
+//
+// The foo_s32 HLOs must be losslessly-convertible to s8.  In the `result_s8`
+// case, side_input should be an f32 HLO that's losslessly-convertible to s8;
+// otherwise, it should be losslessly-convertible to f32.
+//
+// In the `result_s8` case where there's no bias, side-input, or alpha1, you can
+// skip the convert_f32 on conv.
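+//
+// For illustration (names invented here), a minimal s8->s8 conv with no bias,
+// side-input, or alpha1 is spelled
+//
+//   result_s8 = convert_s8(clamp(-128, conv_s32(input_s32, filter_s32), 127))
+//
+// where input_s32 and filter_s32 are losslessly convertible to s8.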
+//
+// If you have an integer convolution that doesn't fit one of these idioms, this
+// pass returns an error -- cudnn will not be able to run it.
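+//
+// A minimal usage sketch (the pipeline name and version numbers below are
+// placeholders, not taken from a real build):
+//
+//   HloPassPipeline pipeline("gpu-conv-fusion");
+//   pipeline.AddPass<CudnnFusedConvRewriter>(
+//       se::CudaComputeCapability{8, 0},
+//       se::dnn::VersionInfo(8, 9, 0), /*toolkit_version=*/12030);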
+class CudnnFusedConvRewriter : public HloModulePass {
+ public:
+  CudnnFusedConvRewriter(se::CudaComputeCapability cc,
+                         se::dnn::VersionInfo dnn_version,
+                         int32_t toolkit_version)
+      : compute_capability_(cc),
+        dnn_version_(dnn_version),
+        toolkit_version_(toolkit_version) {}
+  CudnnFusedConvRewriter(se::RocmComputeCapability cc,
+                         se::dnn::VersionInfo dnn_version,
+                         int32_t toolkit_version)
+      : compute_capability_(cc),
+        dnn_version_(dnn_version),
+        toolkit_version_(toolkit_version) {}
+
+  absl::string_view name() const override {
+    return "cudnn-fused-convolution-rewriter";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  const se::GpuComputeCapability compute_capability_;
+  const se::dnn::VersionInfo dnn_version_;
+  const int32_t toolkit_version_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_CONV_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc
new file mode 100644
index 0000000..5e22a6f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc
@@ -0,0 +1,3169 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h"
+
+#include <array>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <thread>  // NOLINT
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "xla/comparison_util.h"
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/algebraic_simplifier.h"
+#include "xla/service/convert_mover.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/service/gpu/transforms/conv_rewriter.h"
+#include "xla/service/hlo_constant_folding.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_fix.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/service/reshape_mover.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "tsl/platform/statusor.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
+#elif TENSORFLOW_USE_ROCM
+#include "rocm/rocm_config.h"
+#endif  // GOOGLE_CUDA
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// TODO(b/210165681): The tests in this file are fragile to HLO op names.
+
+namespace m = match;
+
+using ::testing::HasSubstr;
+using ::testing::Not;
+
+// TODO: Use constexpr vector once XLA is compiled with C++20.
+const auto* kf16f32f64 = new std::vector<std::string>({"f16", "f32", "f64"});
+const auto* kf16f32 = new std::vector<std::string>({"f16", "f32"});
+
+class CudnnFusedConvRewriterHloTest : public HloTestBase {
+ public:
+  bool IsCuda() {
+    return std::holds_alternative<se::CudaComputeCapability>(
+        backend()
+            .default_stream_executor()
+            ->GetDeviceDescription()
+            .gpu_compute_capability());
+  }
+  se::CudaComputeCapability GetCudaComputeCapability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .cuda_compute_capability();
+  }
+  stream_executor::dnn::VersionInfo GetDnnVersion() {
+    return GetDnnVersionInfoOrDefault(backend().default_stream_executor());
+  }
+
+  int32_t GetToolkitVersion() const {
+#if GOOGLE_CUDA
+    return CUDA_VERSION;
+#elif TENSORFLOW_USE_ROCM
+    return TF_ROCM_VERSION;
+#endif
+    return 0;
+  }
+
+  CudnnFusedConvRewriterHloTest()
+      : HloTestBase(/*verifier_layout_sensitive=*/false,
+                    /*allow_mixed_precision_in_hlo_verifier=*/false,
+                    /*instruction_can_change_layout_func=*/{}) {}
+};
+
+class CudnnFusedConvRewriterTest : public GpuCodegenTest {
+ public:
+  bool IsCuda() {
+    return std::holds_alternative<se::CudaComputeCapability>(
+        backend()
+            .default_stream_executor()
+            ->GetDeviceDescription()
+            .gpu_compute_capability());
+  }
+  se::CudaComputeCapability GetCudaComputeCapability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .cuda_compute_capability();
+  }
+  stream_executor::dnn::VersionInfo GetDnnVersion() {
+    return GetDnnVersionInfoOrDefault(backend().default_stream_executor());
+  }
+
+  int32_t GetToolkitVersion() const {
+#if GOOGLE_CUDA
+    return CUDA_VERSION;
+#elif TENSORFLOW_USE_ROCM
+    return TF_ROCM_VERSION;
+#endif
+    return 0;
+  }
+
+ protected:
+  std::string GetOptimizedHlo(absl::string_view hlo_string) {
+    // cudnn_vectorize_convolutions transforms convolutions, making it hard to
+    // match them here in this test.  What's worse, the transforms it does
+    // depend on the GPU that's available!  So just disable that pass in this
+    // function that gets the optimized HLO.  When we actually run the module,
+    // the pass will still be enabled.
+    HloModuleConfig config = GetModuleConfigForTest();
+    DebugOptions debug_opts = config.debug_options();
+    debug_opts.add_xla_disable_hlo_passes("cudnn_vectorize_convolutions");
+    debug_opts.set_xla_gpu_use_runtime_fusion(true);
+    config.set_debug_options(debug_opts);
+
+    auto result = backend().compiler()->RunHloPasses(
+        ParseAndReturnVerifiedModule(hlo_string, config).value(),
+        backend().default_stream_executor(), backend().memory_allocator());
+    if (!result.status().ok()) {
+      TF_EXPECT_OK(result.status())
+          << "HLO compilation failed: " << result.status();
+      return "";
+    }
+    HloPrintOptions print_opts;
+    print_opts.set_print_operand_shape(false);
+    return (*result)->ToString(print_opts);
+  }
+
+  void TestMatchWithAllTypes(absl::string_view hlo_string) {
+    for (absl::string_view type : *(IsCuda() ? kf16f32f64 : kf16f32)) {
+      const std::string hlo_with_new_type =
+          absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+      std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
+      EXPECT_THAT(optimized_hlo_string,
+                  Not(HasSubstr(kCudnnConvForwardCallTarget)))
+          << optimized_hlo_string;
+      EXPECT_THAT(optimized_hlo_string,
+                  HasSubstr(kCudnnConvBiasActivationForwardCallTarget));
+
+      TF_ASSERT_OK_AND_ASSIGN(auto module,
+                              ParseAndReturnVerifiedModule(hlo_with_new_type));
+      DebugOptions debug_opts = module->config().debug_options();
+      debug_opts.set_xla_gpu_use_runtime_fusion(true);
+      module->mutable_config().set_debug_options(debug_opts);
+      EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01}))
+          << optimized_hlo_string;
+    }
+  }
+
+  void TestClamp(absl::string_view pre_hlo_string,
+                 absl::string_view post_hlo_string) {
+    std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string);
+    EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert")));
+    EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv"));
+    EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.01}))
+        << pre_hlo_string;
+
+    absl::StatusOr<bool> filecheck_result =
+        RunFileCheck(optimized_hlo_string, post_hlo_string);
+    ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
+    EXPECT_TRUE(*filecheck_result);
+  }
+
+  void TestNotMatchWithAllTypes(absl::string_view hlo_string) {
+    for (absl::string_view type : *(IsCuda() ? kf16f32f64 : kf16f32)) {
+      const std::string hlo_with_new_type =
+          absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+      std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
+      SCOPED_TRACE(optimized_hlo_string);
+      EXPECT_THAT(optimized_hlo_string, HasSubstr(kCudnnConvForwardCallTarget));
+      EXPECT_THAT(optimized_hlo_string,
+                  Not(HasSubstr(kCudnnConvBiasActivationForwardCallTarget)));
+    }
+  }
+
+  void TestF8(std::string pre_hlo_string, std::string custom_call_string,
+              std::string serialized_graph_string) {
+    if (!IsCuda()) return;
+    if (GetCudaComputeCapability().IsAtLeast(
+            se::CudaComputeCapability::HOPPER)) {
+      // On Hopper and newer architectures, test numerical correctness and
+      // verify the HLO of the Custom Call (with operand and return layouts)
+      // and the serialized graph produced by the full compiler pipeline.
+      std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string);
+      EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert")));
+      EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv"));
+      EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.15, 0.15}))
+          << pre_hlo_string;
+
+      absl::StatusOr<bool> filecheck_result =
+          RunFileCheck(optimized_hlo_string, custom_call_string);
+      ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
+      EXPECT_TRUE(*filecheck_result);
+
+      filecheck_result =
+          RunFileCheck(optimized_hlo_string, serialized_graph_string);
+      ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
+      EXPECT_TRUE(*filecheck_result);
+    } else {
+      // On older architectures, disregard layout information and only verify
+      // the basic configuration of the convolution Custom Call (number of
+      // operands, window_size, and serialized graph attributes) produced by
+      // the ConvRewriter and CudnnFusedConvRewriter passes.
+      std::string::size_type p0 = custom_call_string.find(':');
+      std::string::size_type p1 = custom_call_string.find("custom-call");
+      custom_call_string.erase(p0 + 1, p1 - p0 - 2);
+      p0 = custom_call_string.find(", dim_labels");
+      custom_call_string.erase(p0);
+
+      TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                              ParseAndReturnVerifiedModule(pre_hlo_string));
+      TF_ASSERT_OK_AND_ASSIGN(
+          bool changed,
+          RunHloPass(ConvRewriter(GetCudaComputeCapability()), module.get()));
+      EXPECT_TRUE(changed);
+      RunAndFilecheckHloRewrite(
+          module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
+          CudnnFusedConvRewriter(
+              se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0},
+              GetDnnVersion(), GetToolkitVersion()),
+          custom_call_string);
+      RunAndFilecheckHloRewrite(
+          module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
+          CudnnFusedConvRewriter(
+              se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0},
+              GetDnnVersion(), GetToolkitVersion()),
+          serialized_graph_string);
+    }
+  }
+
+  void TestF8Parameterized(std::string template_pre_hlo_string,
+                           std::string template_custom_call_string,
+                           std::string template_serialized_graph_string) {
+    std::array<absl::string_view, 2> types = {"f8e4m3fn", "f8e5m2"};
+    std::array<absl::string_view, 2> clamp_lower = {"-448.", "-57344."};
+    std::array<absl::string_view, 2> clamp_upper = {"448.", "57344."};
+    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+    for (int i = 0; i < 2; ++i) {
+      replacements["<<InputType>>"] = types[i];
+      for (int j = 0; j < 2; ++j) {
+        replacements["<<FilterType>>"] = types[j];
+        for (int k = 0; k < 2; ++k) {
+          replacements["<<OutputType>>"] = types[k];
+          replacements["<<ClampLower>>"] = clamp_lower[k];
+          replacements["<<ClampUpper>>"] = clamp_upper[k];
+          TestF8(absl::StrReplaceAll(template_pre_hlo_string, replacements),
+                 absl::StrReplaceAll(template_custom_call_string, replacements),
+                 absl::StrReplaceAll(template_serialized_graph_string,
+                                     replacements));
+        }
+      }
+    }
+  }
+};
+
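+// MAYBE_SKIP_TEST(CAUSE) skips a test when the current build can't run it:
+// on CUDA, FP8 ("F8") tests require CUDA 12 and cuDNN 8.9; on ROCm, every
+// fusion named by CAUSE is skipped.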
+#if GOOGLE_CUDA
+#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900)
+#define MAYBE_SKIP_TEST(CAUSE)                                           \
+  do {                                                                   \
+    if (absl::string_view(CAUSE) == "F8")                                \
+      GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; \
+  } while (0)
+#else
+#define MAYBE_SKIP_TEST(CAUSE)
+#endif
+#else
+#define MAYBE_SKIP_TEST(CAUSE)                                \
+  do {                                                        \
+    GTEST_SKIP() << "ROCm does not support " CAUSE " fusion"; \
+  } while (0)
+#endif
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) {
+  // max(0, conv(x, w));
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+
+      input = TYPE[1,17,9,9] parameter(0)
+      filter = TYPE[3,3,17,32] parameter(1)
+
+      conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+      ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, DontFuseReluWithDepthwiseConv) {
+  // max(0, conv(x, w));
+  TestNotMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,17,9,9] broadcast(zero), dimensions={}
+
+      input = TYPE[1,17,9,9] parameter(0)
+      filter = TYPE[3,3,1,17] parameter(1)
+
+      conv = TYPE[1,17,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=17
+      ROOT relu = TYPE[1,17,9,9] maximum(zeros, conv)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestBias) {
+  // max(0, conv(x, w) + bias);
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,64,64] parameter(1)
+      bias = TYPE[64] parameter(2)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+      add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, Test3D) {
+  // max(0, conv(x, w) + bias);
+  std::string body = R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,5,7,64] broadcast(zero), dimensions={}
+
+      input = TYPE[1,3,5,7,64] parameter(0)
+      filter = TYPE[3,3,3,64,64] parameter(1)
+      bias = TYPE[64] parameter(2)
+
+      conv = TYPE[1,3,5,7,64] convolution(input, filter), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f, feature_group_count=1
+      broadcasted_bias = TYPE[1,3,5,7,64] broadcast(bias), dimensions={4}
+      add1 = TYPE[1,3,5,7,64] add(conv, broadcasted_bias)
+    )";
+
+  std::string relu = R"(
+      ROOT relu = TYPE[1,3,5,7,64] maximum(zeros, add1)
+    })";
+
+  std::string elu = R"(
+      cmp = pred[1,3,5,7,64] compare(add1, zeros), direction=GT
+      expm1 = TYPE[1,3,5,7,64] exponential-minus-one(add1)
+      ROOT elu = TYPE[1,3,5,7,64] select(cmp, add1, expm1)
+    })";
+
+  TestMatchWithAllTypes(body + relu);
+  if (!IsCuda()) TestMatchWithAllTypes(body + elu);
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestBiasMultiCall) {
+  // max(0, conv(x, w) + bias);
+  std::string code = R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,<<<format>>>,64] broadcast(zero), dimensions={}
+
+      input = TYPE[1,<<<format>>>,64] parameter(0)
+      filter = TYPE[3,3,64,64] parameter(1)
+      bias = TYPE[64] parameter(2)
+
+      conv = TYPE[1,<<<format>>>,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+      broadcasted_bias = TYPE[1,<<<format>>>,64] broadcast(bias), dimensions={3}
+      add1 = TYPE[1,<<<format>>>,64] add(conv, broadcasted_bias)
+      ROOT relu = TYPE[1,<<<format>>>,64] maximum(zeros, add1)
+    })";
+  absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+  replacements["<<<format>>>"] = "3,3";
+  TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements));
+  replacements["<<<format>>>"] = "5,5";
+  TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements));
+  replacements["<<<format>>>"] = "3,3";
+  TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements));
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestBiasNoRelu) {
+  // conv(x, w) + bias;
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,64,64] parameter(1)
+      bias = TYPE[64] parameter(2)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+      ROOT add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, DontFuseBiasWithDepthwiseConv) {
+  // conv(x, w) + bias;
+  TestNotMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,1,64] parameter(1)
+      bias = TYPE[64] parameter(2)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
+      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+      add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestElu) {
+  // sum = conv(x, w) + bias
+  // select(compare(sum, 0, GT), sum, exponential-minus-one(sum));
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,64,64] parameter(1)
+      bias = TYPE[64] parameter(2)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+      sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+      cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT
+      expm1 = TYPE[1,3,3,64] exponential-minus-one(sum)
+      ROOT elu = TYPE[1,3,3,64] select(cmp, sum, expm1)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, DontFuseEluWithDepthwiseConv) {
+  // sum = conv(x, w) + bias
+  // select(compare(sum, 0, GT), sum, exponential-minus-one(sum));
+  TestNotMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,1,64] parameter(1)
+      bias = TYPE[64] parameter(2)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
+      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+      sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+      cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT
+      expm1 = TYPE[1,3,3,64] exponential-minus-one(sum)
+      ROOT elu = TYPE[1,3,3,64] select(cmp, sum, expm1)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestRelu6) {
+  if (IsCuda() && !GetCudaComputeCapability().IsAtLeast(
+                      se::CudaComputeCapability::AMPERE)) {
+    GTEST_SKIP() << "Conv-Bias-Relu6 fusion is supported and recommended with "
+                    "the Nvidia Ampere+ GPUs.";
+  }
+  // sum = conv(x, w) + bias
+  // clamp(0, sum, 6);
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+      six = TYPE[] constant(6)
+      sixes = TYPE[1,3,3,64] broadcast(six), dimensions={}
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,64,64] parameter(1)
+      bias = TYPE[64] parameter(2)
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+      sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+      ROOT relu6 = TYPE[1,3,3,64] clamp(zeros, sum, sixes)
+    })");
+}
+
+// At time of writing, cudnn runtime fusion cannot handle f16 convs with an odd
+// number of input/output channels.  Check that we don't try to run this conv
+// with runtime fusion (or, if we do, that it works!).
+TEST_F(CudnnFusedConvRewriterTest, TestRelu6OddChannels) {
+  if (IsCuda() && !GetCudaComputeCapability().IsAtLeast(
+                      se::CudaComputeCapability::AMPERE)) {
+    GTEST_SKIP() << "Conv-Bias-Relu6 fusion requires an Nvidia Ampere or "
+                    "newer GPU.";
+  }
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+    ENTRY Test {
+      zeros = TYPE[1,384,1024,32] broadcast(TYPE[] constant(0)), dimensions={}
+      sixes = TYPE[1,384,1024,32] broadcast(TYPE[] constant(6)), dimensions={}
+      input = TYPE[1,769,2049,3] parameter(0)
+      filter = TYPE[32,3,3,3] parameter(1)
+      bias = TYPE[32] parameter(2)
+      conv = TYPE[1,384,1024,32] convolution(input, filter), window={size=3x3 stride=2x2}, dim_labels=b01f_o01i->b01f
+      broadcasted_bias = TYPE[1,384,1024,32] broadcast(bias), dimensions={3}
+      sum = add(conv, broadcasted_bias)
+      ROOT relu6 = clamp(zeros, sum, sixes)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestLeakyRelu) {
+  if (IsCuda() && !GetCudaComputeCapability().IsAtLeast(
+                      se::CudaComputeCapability::AMPERE)) {
+    GTEST_SKIP()
+        << "Conv-Bias-LeakyRelu fusion requires an Nvidia Ampere or newer "
+           "GPU.";
+  }
+  // sum = conv(x, w) + bias
+  // select(compare(sum, 0, GT), sum, multiply(sum, alpha));
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+      alpha = TYPE[] constant(0.2)
+      alphas = TYPE[1,3,3,64] broadcast(alpha), dimensions={}
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,64,64] parameter(1)
+      bias = TYPE[64] parameter(2)
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+      sum = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+      cmp = pred[1,3,3,64] compare(sum, zeros), direction=GT
+      mul = TYPE[1,3,3,64] multiply(sum, alphas)
+      ROOT elu = TYPE[1,3,3,64] select(cmp, sum, mul)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestSideInputOnly) {
+  // max(0, conv(x, w) + side_input);
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,64,64] parameter(1)
+      side_input = TYPE[1,3,3,64] parameter(2)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+      add1 = TYPE[1,3,3,64] add(conv, side_input)
+      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, DontFuseSideInputWithDepthwiseConv) {
+  // max(0, conv(x, w) + side_input);
+  TestNotMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,1,64] parameter(1)
+      side_input = TYPE[1,3,3,64] parameter(2)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
+      add1 = TYPE[1,3,3,64] add(conv, side_input)
+      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestBiasAndSideInput) {
+  // max(0, conv(x, w) + side_input + bias);
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,64,64] parameter(1)
+      side_input = TYPE[1,3,3,64] parameter(2)
+      bias = TYPE[64] parameter(3)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+      add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+      add2 = TYPE[1,3,3,64] add(add1, side_input)
+      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestScaledConv) {
+  // max(0, 0.999994934 * conv(x, w));
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+      alpha_conv_scalar = TYPE[] constant(0.999994934)
+
+      input = TYPE[1,17,9,9] parameter(0)
+      filter = TYPE[3,3,17,32] parameter(1)
+
+      conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+      alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
+      scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv)
+      ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, DontFuseScaledDepthwiseConv) {
+  // max(0, 0.999994934 * conv(x, w));
+  TestNotMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,17,9,9] broadcast(zero), dimensions={}
+      alpha_conv_scalar = TYPE[] constant(0.999994934)
+
+      input = TYPE[1,17,9,9] parameter(0)
+      filter = TYPE[3,3,1,17] parameter(1)
+
+      conv = TYPE[1,17,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=17
+      alpha_conv = TYPE[1,17,9,9] broadcast(alpha_conv_scalar), dimensions={}
+      scaled_conv = TYPE[1,17,9,9] multiply(conv, alpha_conv)
+      ROOT relu = TYPE[1,17,9,9] maximum(zeros, scaled_conv)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestNoCrashOnInf) {
+  EXPECT_TRUE(RunAndCompare(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = f32[] constant(inf)
+      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
+      alpha_conv_scalar = f32[] constant(0.999994934)
+
+      input = f32[1,17,9,9] parameter(0)
+      filter = f32[3,3,17,32] parameter(1)
+
+      conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+      alpha_conv = f32[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
+      scaled_conv = f32[1,32,9,9] multiply(conv, alpha_conv)
+      ROOT relu = f32[1,32,9,9] maximum(zeros, scaled_conv)
+    })",
+                            ErrorSpec{0.01}));
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvAndScaledSideInput) {
+  // max(0, conv(x, w) + 0.899994934 * side_input);
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+      alpha_side_input_scalar = TYPE[] constant(0.899994934)
+      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,64,64] parameter(1)
+      side_input = TYPE[1,3,3,64] parameter(2)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+      add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
+      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, DontFuseDepthwiseConvWithScaledSideInput) {
+  // max(0, conv(x, w) + 0.899994934 * side_input);
+  TestNotMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+      alpha_side_input_scalar = TYPE[] constant(0.899994934)
+      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,1,64] parameter(1)
+      side_input = TYPE[1,3,3,64] parameter(2)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=64
+      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+      add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
+      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInput) {
+  // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input);
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+      alpha_conv_scalar = TYPE[] constant(0.999994934)
+      alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+      alpha_side_input_scalar = TYPE[] constant(0.899994934)
+      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,64,64] parameter(1)
+      side_input = TYPE[1,3,3,64] parameter(2)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+      scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+      add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input)
+      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInputWithBias) {
+  // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias);
+  TestMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+      alpha_conv_scalar = TYPE[] constant(0.999994934)
+      alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+      alpha_side_input_scalar = TYPE[] constant(0.899994934)
+      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,64,64] parameter(1)
+      side_input = TYPE[1,3,3,64] parameter(2)
+      bias = TYPE[64] parameter(3)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+      scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+      add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias)
+      add2 = TYPE[1,3,3,64] add(add1, scaled_side_input)
+      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestMatchMaxZeroOnly) {
+  // max(0.1, conv(x, w)) shouldn't match.
+  TestNotMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      point_one = TYPE[] constant(0.1)
+      point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={}
+
+      input = TYPE[1,17,9,9] parameter(0)
+      filter = TYPE[3,3,17,32] parameter(1)
+
+      conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+      ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv)
+    })");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) {
+  const char* kHloString = R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = f32[] constant(0)
+      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
+
+      input = f32[1,17,9,9] parameter(0)
+      filter = f32[3,3,17,32] parameter(1)
+
+      conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo" op_name="bar"}
+      ROOT relu = f32[1,32,9,9] maximum(zeros, conv)
+    })";
+
+  const std::string optimized_hlo_string =
+      backend()
+          .compiler()
+          ->RunHloPasses(
+              ParseAndReturnVerifiedModule(kHloString, GetModuleConfigForTest())
+                  .value(),
+              backend().default_stream_executor(), backend().memory_allocator())
+          .value()
+          ->ToString();
+  EXPECT_THAT(optimized_hlo_string,
+              ::testing::ContainsRegex(
+                  R"(custom-call.*metadata=\{op_type="foo" op_name="bar"\})"));
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) {
+  // The convolution below would crash if feature_count is not preserved.
+  const char* kHloString = R"(
+    HloModule jaxpr_computation__6.19
+
+    primitive_computation__1.4 {
+      parameter.5 = f32[] parameter(0)
+      parameter.6 = f32[] parameter(1)
+      ROOT add.7 = f32[] add(parameter.5, parameter.6)
+    }
+
+    ENTRY jaxpr_computation__7.8 {
+      parameter.11 = f32[2,64,64,53]{3,2,1,0} parameter(1)
+      parameter.10 = f32[3,3,1,53]{3,2,1,0} parameter(0)
+      convolution.12 = f32[2,64,64,53]{3,2,1,0} convolution(parameter.11, parameter.10), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=53
+      constant.13 = f32[] constant(0)
+      broadcast.14 = f32[2,64,64,53]{3,2,1,0} broadcast(constant.13), dimensions={}
+      maximum.15 = f32[2,64,64,53]{3,2,1,0} maximum(convolution.12, broadcast.14)
+      ROOT reduce.17 = f32[] reduce(maximum.15, constant.13), dimensions={0,1,2,3}, to_apply=primitive_computation__1.4
+    }
+  )";
+  EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.01}));
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvF8) {
+  MAYBE_SKIP_TEST("F8");
+  TestF8(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+       input = f8e4m3fn[1,128,6,6] parameter(0)
+       filter = f8e4m3fn[3,3,128,16] parameter(1)
+       ROOT conv_a = f8e4m3fn[1,16,6,6] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+
+    })",
+      // custom_call
+      R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+  )",
+      // serialized_graph
+      R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f8e4m3fn]conv();"
+      )");
+}
+
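+// The multiply by z_scale after the FP8 convolution is expected to be fused
+// into the custom-call as a "scale" node in the serialized cuDNN graph.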
+TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputF8) {
+  MAYBE_SKIP_TEST("F8");
+  TestF8(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+       input = f8e4m3fn[1,128,6,6] parameter(0)
+       filter = f8e4m3fn[3,3,128,16] parameter(1)
+       input_f32 = f32[1,128,6,6] convert(input)
+       filter_f32 = f32[3,3,128,16] convert(filter)
+       z_scale = f32[] parameter(2)
+       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+       conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+       conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
+       c1 = f32[] constant(-448.)
+       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+       c2 = f32[] constant(448.)
+       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+       conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
+       ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
+
+    })",
+      // custom_call
+      R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+  )",
+      // serialized_graph
+      R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE_UID:[0-9]+]]:[f8e4m3fn]scale([[CONV_UID]]);"
+      )");
+}
+
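+// Dividing the convolution output by z_scale is expected to be fused as an
+// "invscale" node in the serialized cuDNN graph.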
+TEST_F(CudnnFusedConvRewriterTest, TestConvInvscaledOutputF8) {
+  MAYBE_SKIP_TEST("F8");
+  TestF8(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+       input = f8e4m3fn[1,128,6,6] parameter(0)
+       filter = f8e4m3fn[3,3,128,16] parameter(1)
+       input_f32 = f32[1,128,6,6] convert(input)
+       filter_f32 = f32[3,3,128,16] convert(filter)
+       z_scale = f32[] parameter(2)
+       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+       conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+       conv_a_scaled = f32[1,16,6,6] divide(conv_a, z_scale_bcast)
+       c1 = f32[] constant(-448.)
+       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+       c2 = f32[] constant(448.)
+       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+       conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
+       ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
+
+    })",
+      // custom_call
+      R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+  )",
+      // serialized_graph
+      R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f8e4m3fn]invscale([[CONV_UID]]);"
+      )");
+}
+
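+// Parameterized variant of the scaled-output test; the <<InputType>>,
+// <<FilterType>>, <<OutputType>>, and clamp-bound placeholders are filled in
+// by TestF8Parameterized.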
+TEST_F(CudnnFusedConvRewriterTest, TestConvScaledF8Parameterized) {
+  MAYBE_SKIP_TEST("F8");
+  TestF8Parameterized(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+       input = <<InputType>>[1,128,6,6] parameter(0)
+       filter = <<FilterType>>[3,3,128,16] parameter(1)
+       input_scale = f32[] parameter(2)
+       input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
+       filter_scale = f32[] parameter(3)
+       filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
+       input_f32 = f32[1,128,6,6] convert(input)
+       input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
+       filter_f32 = f32[3,3,128,16] convert(filter)
+       filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
+       z_scale = f32[] parameter(4)
+       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+       conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+       conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
+       c1 = f32[] constant(<<ClampLower>>)
+       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+       c2 = f32[] constant(<<ClampUpper>>)
+       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+       conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
+       ROOT conv_f8 = <<OutputType>>[1,16,6,6] convert(conv_a_clamped)
+
+    })",
+      // custom_call
+      R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (<<OutputType>>[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+      )",
+      // serialized_graph
+      R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[SCALE2_UID:[0-9]+]]:[<<OutputType>>]scale([[SCALE1_UID]]);"
+      )");
+}
+
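+// Input/filter descaling, the bias add, and the output scaling should all be
+// folded into the FP8 convolution graph (conv, scale, scale, add, scale).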
+TEST_F(CudnnFusedConvRewriterTest, TestConvScaledBiasF8) {
+  MAYBE_SKIP_TEST("F8");
+  TestF8(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+       input = f8e4m3fn[1,128,6,6] parameter(0)
+       filter = f8e4m3fn[3,3,128,16] parameter(1)
+       input_scale = f32[] parameter(2)
+       input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
+       filter_scale = f32[] parameter(3)
+       filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
+       input_f32 = f32[1,128,6,6] convert(input)
+       input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
+       filter_f32 = f32[3,3,128,16] convert(filter)
+       filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
+       bias = f32[1,16,6,6] parameter(4)
+       z_scale = f32[] parameter(5)
+       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+       conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+       conv_a_bias = f32[1,16,6,6] add(conv_a, bias)
+       conv_a_scaled = f32[1,16,6,6] multiply(conv_a_bias, z_scale_bcast)
+       c1 = f32[] constant(-448.)
+       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+       c2 = f32[] constant(448.)
+       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+       conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
+       ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
+
+    })",
+      // custom_call
+      R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]], /*index=5*/[[OPERAND5:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+      )",
+      // serialized_graph
+      R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[ADD_UID:[0-9]+]]:[f32]add([[SCALE1_UID]]);[[SCALE2_UID:[0-9]+]]:[f8e4m3fn]scale([[ADD_UID]]);"
+      )");
+}
+
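+// A ReLU between the convolution and the output scaling is expected to be
+// fused as a "relu" node.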
+TEST_F(CudnnFusedConvRewriterTest, TestConvScaledReluF8) {
+  MAYBE_SKIP_TEST("F8");
+  TestF8(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+       input = f8e4m3fn[1,128,6,6] parameter(0)
+       filter = f8e4m3fn[3,3,128,16] parameter(1)
+       input_f32 = f32[1,128,6,6] convert(input)
+       filter_f32 = f32[3,3,128,16] convert(filter)
+       z_scale = f32[] parameter(2)
+       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+       c = f32[] constant(0)
+       c_bcast = f32[1,16,6,6] broadcast(c), dimensions={}
+       conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+       relu_a = f32[1,16,6,6] maximum(conv_a, c_bcast)
+       relu_a_scaled = f32[1,16,6,6] multiply(relu_a, z_scale_bcast)
+       c1 = f32[] constant(-448.)
+       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+       c2 = f32[] constant(448.)
+       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+       relu_a_clamped = f32[1,16,6,6] clamp(c1_bcast, relu_a_scaled, c2_bcast)
+       ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(relu_a_clamped)
+
+    })",
+      // custom_call
+      R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+  )",
+      // serialized_graph
+      R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[RELU_UID:[0-9]+]]:[f32]relu([[CONV_UID]]);[[SCALE0_UID:[0-9]+]]:[f8e4m3fn]scale([[RELU_UID]]);"
+      )");
+}
+
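+// The amax (reduce-max over the absolute value of the conv result) should
+// become an extra f32 output of the custom-call alongside the scaled FP8
+// result.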
+TEST_F(CudnnFusedConvRewriterTest, TestConvAmaxF8) {
+  MAYBE_SKIP_TEST("F8");
+  TestF8(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] maximum(a, b)
+    }
+
+    ENTRY Test {
+       input = f8e4m3fn[1,128,6,6] parameter(0)
+       filter = f8e4m3fn[3,3,128,16] parameter(1)
+       input_scale = f32[] parameter(2)
+       input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
+       filter_scale = f32[] parameter(3)
+       filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
+       input_f32 = f32[1,128,6,6] convert(input)
+       input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
+       filter_f32 = f32[3,3,128,16] convert(filter)
+       filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
+       conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+       z_scale = f32[] parameter(4)
+       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+       conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
+       c1 = f32[] constant(-448.)
+       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+       c2 = f32[] constant(448.)
+       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+       conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
+       conv_a_clamped_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
+       abs_conv_a = f32[1,16,6,6] abs(conv_a)
+       c0 = f32[] constant(-inf)
+       amax = f32[] reduce(abs_conv_a, c0), dimensions={0,1,2,3}, to_apply=apply
+       ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f32[]) tuple(conv_a_clamped_f8, amax)
+
+    })",
+      // custom_call
+      R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, f32[], u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+    )",
+      // serialized_graph
+      R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[SCALE2_UID:[0-9]+]]:[f8e4m3fn]scale([[SCALE1_UID]]);[[AMAX_UID:[0-9]+]]:[f32]amax([[SCALE1_UID]]);"
+      )");
+}
+
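+// Same as above, but the amax is taken over the ReLU output, which is also
+// fused into the graph.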
+TEST_F(CudnnFusedConvRewriterTest, TestConvReluAmaxF8) {
+  MAYBE_SKIP_TEST("F8");
+  TestF8(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] maximum(a, b)
+    }
+
+    ENTRY Test {
+       input = f8e4m3fn[1,128,6,6] parameter(0)
+       filter = f8e4m3fn[3,3,128,16] parameter(1)
+       input_scale = f32[] parameter(2)
+       input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={}
+       filter_scale = f32[] parameter(3)
+       filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={}
+       input_f32 = f32[1,128,6,6] convert(input)
+       input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast)
+       filter_f32 = f32[3,3,128,16] convert(filter)
+       filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast)
+       c = f32[] constant(0)
+       c_bcast = f32[1,16,6,6] broadcast(c), dimensions={}
+       conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+       relu_a = f32[1,16,6,6] maximum(conv_a, c_bcast)
+       z_scale = f32[] parameter(4)
+       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+       relu_a_scaled = f32[1,16,6,6] multiply(relu_a, z_scale_bcast)
+       c1 = f32[] constant(-448.)
+       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+       c2 = f32[] constant(448.)
+       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+       relu_a_clamped = f32[1,16,6,6] clamp(c1_bcast, relu_a_scaled, c2_bcast)
+       relu_a_clamped_f8 = f8e4m3fn[1,16,6,6] convert(relu_a_clamped)
+       abs_relu_a = f32[1,16,6,6] abs(relu_a)
+       c0 = f32[] constant(-inf)
+       amax = f32[] reduce(abs_relu_a, c0), dimensions={0,1,2,3}, to_apply=apply
+       ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f32[]) tuple(relu_a_clamped_f8, amax)
+
+    })",
+      // custom_call
+      R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, f32[], u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+    )",
+      // serialized_graph
+      R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();[[SCALE0_UID:[0-9]+]]:[f32]scale([[CONV_UID]]);[[SCALE1_UID:[0-9]+]]:[f32]scale([[SCALE0_UID]]);[[RELU_UID:[0-9]+]]:[f32]relu([[SCALE1_UID]]);[[SCALE2_UID:[0-9]+]]:[f8e4m3fn]scale([[RELU_UID]]);[[AMAX_UID:[0-9]+]]:[f32]amax([[RELU_UID]]);"
+      )");
+}
+
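+// If the convolution result feeds multiple scaled users, only the convolution
+// itself is rewritten; the scales stay outside the custom-call.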
+TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputMultipleUsersF8) {
+  MAYBE_SKIP_TEST("F8");
+  TestF8(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+       input = f8e4m3fn[1,128,6,6] parameter(0)
+       filter = f8e4m3fn[3,3,128,16] parameter(1)
+       input_f32 = f32[1,128,6,6] convert(input)
+       filter_f32 = f32[3,3,128,16] convert(filter)
+       z_scale0 = f32[] parameter(2)
+       z_scale0_bcast = f32[1,16,6,6] broadcast(z_scale0), dimensions={}
+       z_scale1 = f32[] parameter(3)
+       z_scale1_bcast = f32[1,16,6,6] broadcast(z_scale1), dimensions={}
+       conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+       conv_a_scaled0 = f32[1,16,6,6] multiply(conv_a, z_scale0_bcast)
+       conv_a_scaled1 = f32[1,16,6,6] multiply(conv_a, z_scale1_bcast)
+       c1 = f32[] constant(-448.)
+       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+       c2 = f32[] constant(448.)
+       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+       conv_a_clamped0 = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled0, c2_bcast)
+       conv_a_clamped1 = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled1, c2_bcast)
+       conv_a_convert0 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped0)
+       conv_a_convert1 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped1)
+       ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f8e4m3fn[1,16,6,6]) tuple(conv_a_convert0, conv_a_convert1)
+
+    })",
+      // custom_call
+      R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f32[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+  )",
+      // serialized_graph
+      R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();"
+      )");
+}
+
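+// A user that cannot be represented in the cuDNN graph (cosine) prevents
+// fusing the scale; only the plain convolution is rewritten.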
+TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputUnsupportedUserF8) {
+  MAYBE_SKIP_TEST("F8");
+  TestF8(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+       input = f8e4m3fn[1,128,6,6] parameter(0)
+       filter = f8e4m3fn[3,3,128,16] parameter(1)
+       input_f32 = f32[1,128,6,6] convert(input)
+       filter_f32 = f32[3,3,128,16] convert(filter)
+       z_scale = f32[] parameter(2)
+       z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={}
+       conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+       conv_a_cos = f32[1,16,6,6] cosine(conv_a)
+       conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast)
+       c1 = f32[] constant(-448.)
+       c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={}
+       c2 = f32[] constant(448.)
+       c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={}
+       conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast)
+       conv_a_convert = f8e4m3fn[1,16,6,6] convert(conv_a_clamped)
+       ROOT conv_f8 = (f8e4m3fn[1,16,6,6], f32[1,16,6,6]) tuple(conv_a_convert, conv_a_cos)
+
+    })",
+      // custom_call
+      R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f32[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+  )",
+      // serialized_graph
+      R"(
+// CHECK: "serialized_graph":"[[CONV_UID:[0-9]+]]:[f32]conv();"
+      )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8) {
+  MAYBE_SKIP_TEST("I8");
+  // max(0, clamp(conv(x, w))); for int8_t
+  TestClamp(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = s8[] constant(0)
+      zeros = s8[1,32,9,9] broadcast(zero), dimensions={}
+
+      input = s8[1,17,9,9] parameter(0)
+      filter = s8[3,3,17,32] parameter(1)
+
+      inputs32 = s32[1,17,9,9] convert(input)
+      filters32 = s32[3,3,17,32] convert(filter)
+
+      conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+
+      lower = s32[] constant(-128)
+      lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
+      upper = s32[] constant(127)
+      uppers = s32[1,32,9,9] broadcast(upper), dimensions={}
+
+      clamp = s32[1,32,9,9] clamp(lowers, conv, uppers)
+
+      ROOT convert = s8[1,32,9,9] convert(clamp)
+    })",
+      // post_hlo
+      R"(
+// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (s8[1,9,9,32]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[fusion_2_1:%[^ ]+]], [[fusion_1_2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward"
+      )");
+}
+
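+// s8 inputs with an f32 result should be rewritten to a forward-conv
+// custom-call whose first tuple element has f32 type.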
+TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) {
+  MAYBE_SKIP_TEST("I8");
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = s8[1,17,9,9] parameter(0)
+      filter = s8[3,3,17,32] parameter(1)
+
+      inputs32 = s32[1,17,9,9] convert(input)
+      filters32 = s32[3,3,17,32] convert(filter)
+
+      conv = s32[1,32,9,9] convolution(inputs32, filters32),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+
+      ROOT convert = f32[1,32,9,9] convert(conv)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::GetTupleElement(
+                             m::CustomCall({kCudnnConvForwardCallTarget}), 0)
+                             .WithShape(F32, {1, 32, 9, 9})));
+}
+
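+// Bias and side input should be fused as additional operands of the
+// bias-activation custom-call, with the result produced directly as s8.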
+TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) {
+  MAYBE_SKIP_TEST("I8");
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
+      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
+      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
+      side_input = f32[1,32,9,9] convert(s8[1,32,9,9] parameter(3))
+
+      conv = s32[1,32,9,9] convolution(input, filter),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      conv_f32 = f32[1,32,9,9] convert(conv)
+      ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
+                                             add(add(conv_f32, bias), side_input),
+                                             f32[1,32,9,9] broadcast(f32[] constant(127))))
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`s that may be added to the graph.
+  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+                     m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
+                                   m::Parameter(0), m::Parameter(1),
+                                   m::Parameter(2), m::Parameter(3)),
+                     0)
+                     .WithShape(S8, {1, 32, 9, 9})));
+}
+
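+// A ReLU applied after converting the clamped s32 conv result to s8 should
+// still be fused as the custom-call's activation (the bias operand becomes a
+// broadcast zero).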
+TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) {
+  MAYBE_SKIP_TEST("I8");
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
+      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
+
+      conv = s32[1,32,9,9] convolution(input, filter),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      conv_s8 = s8[1,32,9,9] convert(clamp(s32[1,32,9,9] broadcast(s32[] constant(-128)),
+                                           conv,
+                                           s32[1,32,9,9] broadcast(s32[] constant(127))))
+      zeros = s8[1,32,9,9] broadcast(s8[] constant(0)), dimensions={}
+      ROOT root = s8[1,32,9,9] maximum(conv_s8, zeros)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`s that may be added to the graph.
+  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(
+                  &conv, {kCudnnConvBiasActivationForwardCallTarget},
+                  m::Parameter(0),  //
+                  m::Parameter(1),  //
+                  m::Broadcast(
+                      m::ConstantEffectiveScalar(0).WithElementType(F32))),
+              0)
+              .WithShape(S8, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) {
+  MAYBE_SKIP_TEST("I8");
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = s8[1,17,9,9] parameter(0)
+      filter = s8[3,3,17,32] parameter(1)
+      bias = f32[32] parameter(2)
+      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
+      side_input_f32 = f32[1,32,9,9] parameter(3)
+
+      inputs32 = s32[1,17,9,9] convert(input)
+      filters32 = s32[3,3,17,32] convert(filter)
+
+      conv = s32[1,32,9,9] convolution(inputs32, filters32),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      conv_f32 = f32[1,32,9,9] convert(conv)
+      sum1 = add(conv_f32, bias_broadcast)
+      ROOT sum2 = add(sum1, side_input_f32)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`s that may be added to the graph.
+  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+                     m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
+                                   m::Parameter(0), m::Parameter(1),
+                                   m::Parameter(2), m::Parameter(3)),
+                     0)
+                     .WithShape(F32, {1, 32, 9, 9})));
+}
+
+// The ReshapeMover pass changes
+//   reshape(side_input) * alpha -->
+//   reshape(side_input * alpha).
+// Make sure we can pattern-match this.
+TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) {
+  MAYBE_SKIP_TEST("I8");
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
+      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
+      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
+      side_input_scale = f32[2592] broadcast(f32[] constant(0.25)), dimensions={}
+      side_input = f32[1,32,9,9] reshape(multiply(f32[2592] convert(s8[2592] parameter(3)), side_input_scale))
+
+      conv = s32[1,32,9,9] convolution(input, filter),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
+                                             add(add(f32[1,32,9,9] convert(conv), bias), side_input),
+                                             f32[1,32,9,9] broadcast(f32[] constant(127))))
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`s that may be added to the graph.
+  HloPassFix<HloPassPipeline> simplify("simplify");
+  simplify.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions{});
+  simplify.AddPass<ReshapeMover>();
+  simplify.AddPass<ConvertMover>();
+  TF_ASSERT_OK(RunHloPass(&simplify, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv = nullptr;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(
+                  &conv, {kCudnnConvBiasActivationForwardCallTarget},
+                  m::Parameter(0),  //
+                  m::Parameter(1),  //
+                  m::Parameter(2),  //
+                  m::Reshape(m::Parameter(3)).WithShape(S8, {1, 32, 9, 9})),
+              0)
+              .WithShape(S8, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.conv_result_scale(), 1);
+  EXPECT_EQ(config.side_input_scale(), 0.25);
+}
+
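+// A constant alpha multiplying convert(conv) should be folded into the
+// custom-call's conv_result_scale.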
+TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) {
+  MAYBE_SKIP_TEST("I8");
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = s8[1,17,9,9] parameter(0)
+      filter = s8[3,3,17,32] parameter(1)
+      inputs32 = s32[1,17,9,9] convert(input)
+      filters32 = s32[3,3,17,32] convert(filter)
+      alpha = f32[] constant(42)
+      alpha_broadcast = f32[1,32,9,9] broadcast(alpha), dimensions={}
+
+      conv = s32[1,32,9,9] convolution(inputs32, filters32),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      convert = f32[1,32,9,9] convert(conv)
+      ROOT root = multiply(convert, alpha_broadcast)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv = nullptr;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget}),
+              0)
+              .WithShape(F32, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.conv_result_scale(), 42);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] parameter(0)
+      filters = f32[3,3,17,32] parameter(1)
+      bias = f32[32] parameter(2)
+      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
+      zero = f32[] constant(0)
+      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      sum = add(conv, bias_broadcast)
+      ROOT relu = maximum(sum, zeros)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+              0)
+              .WithShape(F32, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] parameter(0)
+      filters = f32[3,3,17,32] parameter(1)
+      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
+      zeros = f32[1,32,9,9] broadcast(f32[] constant(0)), dimensions={}
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      sum = add(conv, bias)
+      relu = maximum(sum, zeros)
+      not_relu = minimum(sum, zeros)
+      ROOT root = tuple(relu, not_relu)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::MaximumAnyOrder(
+              m::Broadcast(m::ConstantEffectiveScalar(0)),
+              m::GetTupleElement(
+                  m::CustomCall(
+                      &conv, {kCudnnConvBiasActivationForwardCallTarget},
+                      m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+                  0)
+                  .WithShape(F32, {1, 32, 9, 9})),
+          m::Minimum())));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
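+// Elu, i.e. select(x > 0, x, expm1(x)) applied to conv+bias, should be fused
+// as the activation. The rewriter is constructed for Ampere (8.0), where this
+// fusion is available, and runtime fusion is enabled via
+// xla_gpu_use_runtime_fusion.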
+TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f16[1,16,9,9] parameter(0)
+      filters = f16[3,3,16,32] parameter(1)
+      bias = f16[32] parameter(2)
+      bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1}
+      zero = f16[] constant(0)
+      zeros = f16[1,32,9,9] broadcast(zero), dimensions={}
+      conv = f16[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      sum = add(conv, bias_broadcast)
+      cmp = compare(sum, zeros), direction=GT
+      expm1 = exponential-minus-one(sum)
+      ROOT elu = select(cmp, sum, expm1)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  DebugOptions debug_opts = m->config().debug_options();
+  debug_opts.set_xla_gpu_use_runtime_fusion(true);
+  m->mutable_config().set_debug_options(debug_opts);
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  // Elu fusion is only enabled on Ampere+.
+  CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+              0)
+              .WithShape(F16, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.activation_mode(), se::dnn::kElu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f16[1,16,9,9] parameter(0)
+      filters = f16[3,3,16,32] parameter(1)
+      bias = f16[32] parameter(2)
+      bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1}
+      zero = f16[] constant(0)
+      zeros = f16[1,32,9,9] broadcast(zero), dimensions={}
+      conv = f16[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      sum = add(conv, bias_broadcast)
+      cmp = compare(sum, zeros), direction=GT
+      expm1 = exponential-minus-one(sum)
+      elu = select(cmp, sum, expm1)
+      not_elu = minimum(sum, zeros)
+      ROOT root = tuple(elu, not_elu)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  DebugOptions debug_opts = m->config().debug_options();
+  debug_opts.set_xla_gpu_use_runtime_fusion(true);
+  m->mutable_config().set_debug_options(debug_opts);
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  auto gte_pattern =
+      m::GetTupleElement(
+          m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                        m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+          0)
+          .WithShape(F16, {1, 32, 9, 9});
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::Select(m::Compare(gte_pattern,
+                               m::Broadcast(m::ConstantEffectiveScalar(0)))
+                        .WithComparisonDirection(ComparisonDirection::kGt),
+                    gte_pattern,
+                    m::Op()
+                        .WithPredicate(HloPredicateIsOp<HloOpcode::kExpm1>)
+                        .WithOperand(0, gte_pattern)),
+          m::Minimum())));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) {
+  const std::string module_str = R"(
+    HloModule Test
+    ENTRY Test {
+      inputs = f16[1,18,9,9] parameter(0)
+      filters = f16[3,3,18,32] parameter(1)
+      bias = f16[32] parameter(2)
+      bias_broadcast = f16[1,32,9,9] broadcast(bias), dimensions={1}
+      zero = f16[] constant(0)
+      zeros = f16[1,32,9,9] broadcast(zero), dimensions={}
+      sixes = f16[1,32,9,9] broadcast(f16[] constant(6)), dimensions={}
+      conv = f16[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      sum = add(conv, bias_broadcast)
+      ROOT relu = clamp(zeros, sum, sixes)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  DebugOptions debug_opts = m->config().debug_options();
+  debug_opts.set_xla_gpu_use_runtime_fusion(true);
+  m->mutable_config().set_debug_options(debug_opts);
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  // relu6 fusion is only enabled on Ampere+.
+  CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+              0)
+              .WithShape(F16, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu6);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) {
+  const std::string module_str = R"(
+    HloModule Test
+    ENTRY Test {
+      inputs = f16[1,18,9,9] parameter(0)
+      filters = f16[3,3,18,32] parameter(1)
+      bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
+      zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={}
+      sixes = f16[1,32,9,9] broadcast(f16[] constant(6)), dimensions={}
+      conv = f16[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      sum = add(conv, bias)
+      relu = clamp(zeros, sum, sixes)
+      not_relu = minimum(sum, zeros)
+      ROOT root = tuple(relu, not_relu)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  DebugOptions debug_opts = m->config().debug_options();
+  debug_opts.set_xla_gpu_use_runtime_fusion(true);
+  m->mutable_config().set_debug_options(debug_opts);
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(0)),
+                   m::GetTupleElement(
+                       m::CustomCall(
+                           &conv, {kCudnnConvBiasActivationForwardCallTarget},
+                           m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+                       0)
+                       .WithShape(F16, {1, 32, 9, 9}),
+                   m::Broadcast(m::ConstantEffectiveScalar(6))),
+          m::Minimum())));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) {
+  const std::string module_str = R"(
+    HloModule Test
+    ENTRY Test {
+      inputs = f16[1,16,9,9] parameter(0)
+      filters = f16[3,3,16,32] parameter(1)
+      bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
+      zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={}
+      alphas = f16[1,32,9,9] broadcast(f16[] constant(0.2)), dimensions={}
+      conv = f16[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      sum = add(conv, bias)
+      cmp = compare(sum, zeros), direction=GT
+      mul = multiply(sum, alphas)
+      ROOT leaky_relu = select(cmp, sum, mul)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  DebugOptions debug_opts = m->config().debug_options();
+  debug_opts.set_xla_gpu_use_runtime_fusion(true);
+  m->mutable_config().set_debug_options(debug_opts);
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  // Leaky-relu fusion is only enabled on Ampere+.
+  CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+              0)
+              .WithShape(F16, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.activation_mode(), se::dnn::kLeakyRelu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) {
+  const std::string module_str = R"(
+    HloModule Test
+    ENTRY Test {
+      inputs = f16[1,16,9,9] parameter(0)
+      filters = f16[3,3,16,32] parameter(1)
+      bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
+      zeros = f16[1,32,9,9] broadcast(f16[] constant(0)), dimensions={}
+      alphas = f16[1,32,9,9] broadcast(f16[] constant(0.2)), dimensions={}
+      conv = f16[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      sum = add(conv, bias)
+      cmp = compare(sum, zeros), direction=GT
+      mul = multiply(sum, alphas)
+      leaky_relu = select(cmp, sum, mul)
+      not_leaky_relu = minimum(sum, zeros)
+      ROOT root = tuple(leaky_relu, not_leaky_relu)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  DebugOptions debug_opts = m->config().debug_options();
+  debug_opts.set_xla_gpu_use_runtime_fusion(true);
+  m->mutable_config().set_debug_options(debug_opts);
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  auto gte_pattern =
+      m::GetTupleElement(
+          m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                        m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+          0)
+          .WithShape(F16, {1, 32, 9, 9});
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::Select(m::Compare(gte_pattern,
+                               m::Broadcast(m::ConstantEffectiveScalar(0)))
+                        .WithComparisonDirection(ComparisonDirection::kGt)
+                        .WithOneUse(),
+                    gte_pattern,
+                    m::Multiply(gte_pattern,
+                                m::Broadcast(m::ConstantEffectiveScalar()))),
+          m::Minimum())));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] parameter(0)
+      filters = f32[3,3,17,32] parameter(1)
+      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
+      alpha = f32[1,32,9,9] broadcast(f32[] parameter(3)), dimensions={}
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      sum = add(multiply(alpha, conv), bias)
+      ROOT root = tuple(conv, sum)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv1;
+  const HloInstruction* conv2;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::GetTupleElement(m::CustomCall(&conv1), 0),
+          m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
+                         m::MultiplyAnyOrder(
+                             m::Broadcast(m::Parameter(3)),
+                             m::GetTupleElement(m::CustomCall(&conv2), 0))))));
+  EXPECT_EQ(conv1, conv2);
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv1->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.conv_result_scale(), 1);
+  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] parameter(0)
+      filters = f32[3,3,17,32] parameter(1)
+      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      ROOT root = tuple(conv, add(conv, bias))
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv1;
+  const HloInstruction* conv2;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::GetTupleElement(m::CustomCall(&conv1), 0),
+          m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
+                         m::GetTupleElement(m::CustomCall(&conv2), 0)))));
+  EXPECT_EQ(conv1, conv2);
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv1->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.conv_result_scale(), 1);
+  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
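+// An add applied after the ReLU cannot become a fused side input; the ReLU is
+// still fused as the activation and the add stays outside the custom-call.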
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputThroughRelu) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] parameter(0)
+      filters = f32[3,3,17,32] parameter(1)
+      side_input = f32[1,32,9,9] parameter(2)
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      relu = maximum(conv, f32[1,32,9,9] broadcast(f32[] constant(0)))
+      ROOT root = add(relu, side_input)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::AddAnyOrder(
+          m::Parameter(2),
+          m::GetTupleElement(
+              m::CustomCall(&conv, m::Parameter(0), m::Parameter(1),
+                            m::Broadcast(m::ConstantEffectiveScalar(0))),
+              0))));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.conv_result_scale(), 1);
+  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasThroughRelu) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] parameter(0)
+      filters = f32[3,3,17,32] parameter(1)
+      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      relu = maximum(conv, f32[1,32,9,9] broadcast(f32[] constant(0)))
+      ROOT root = add(relu, bias)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::AddAnyOrder(
+                  m::Broadcast(m::Parameter(2)),
+                  m::GetTupleElement(m::CustomCall(
+                      &conv, m::Parameter(0), m::Parameter(1),
+                      m::Broadcast(m::ConstantEffectiveScalar(0)))))));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.conv_result_scale(), 1);
+  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] parameter(0)
+      filters = f32[3,3,17,32] parameter(1)
+      side_input = f32[1,32,9,9] parameter(2)
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      ROOT root = tuple(conv, add(conv, side_input))
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv1;
+  const HloInstruction* conv2;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::GetTupleElement(m::CustomCall(&conv1), 0),
+          m::AddAnyOrder(m::Parameter(2),
+                         m::GetTupleElement(m::CustomCall(&conv2), 0)))));
+  EXPECT_EQ(conv1, conv2);
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv1->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.conv_result_scale(), 1);
+  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] convert(f16[1,17,9,9] parameter(0))
+      filters = f32[3,3,17,32] convert(f16[3,3,17,32] parameter(1))
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      ROOT root = tuple(conv, f16[1,32,9,9] convert(conv))
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv1;
+  const HloInstruction* conv2;
+  ASSERT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(
+                  m::GetTupleElement(m::CustomCall(&conv1), 0),
+                  m::Convert(m::GetTupleElement(m::CustomCall(&conv2), 0)))));
+  EXPECT_EQ(conv1, conv2);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
+      filters = f32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      conv_s8 = s8[1,32,9,9] convert(clamp(
+                  f32[1,32,9,9] broadcast(f32[] constant(-128)),
+                  conv,
+                  f32[1,32,9,9] broadcast(f32[] constant(127))))
+      ROOT root = tuple(conv, conv_s8)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv1;
+  const HloInstruction* conv2;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::GetTupleElement(m::CustomCall(&conv1), 0),
+          m::Convert(m::Clamp(m::Op(),  //
+                              m::GetTupleElement(m::CustomCall(&conv2), 0),
+                              m::Op())))));
+  EXPECT_EQ(conv1, conv2);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) {
+  MAYBE_SKIP_TEST("I8");
+  const std::string_view module_str = R"(
+    HloModule Test
+
+    ENTRY test_entry {
+      inputs = s8[1, 17, 9, 9] parameter(0)
+      filters = s8[3, 3, 17, 32] parameter(1)
+      mult_op  = f32[1, 32, 9, 9] parameter(2)
+      conv = s32[1, 32, 9, 9] convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
+      ROOT ret = multiply(f32[1, 32, 9, 9] convert(conv), mult_op)
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+  SCOPED_TRACE(m->ToString());
+  HloInstruction* conv1 = nullptr;
+  // Checks that it removed the Convert inside multiply around conv.
+  ASSERT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Multiply(m::GetTupleElement(m::CustomCall(&conv1)),
+                                     m::Parameter(2))));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) {
+  MAYBE_SKIP_TEST("I8");
+  const std::string_view module_str = R"(
+    HloModule Test
+
+    ENTRY test_entry {
+      inputs = s8[1, 17, 9, 9] parameter(0)
+      filters = s8[3, 3, 17, 32] parameter(1)
+      mult_op  = f32[1, 32, 9, 9] parameter(2)
+      conv = convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
+      ROOT ret = multiply(f32[1, 32, 9, 9] convert(conv), mult_op)
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+  SCOPED_TRACE(m->ToString());
+  HloInstruction* conv1 = nullptr;
+  // Checks that it removed the Convert inside multiply around conv.
+  ASSERT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Multiply(m::GetTupleElement(m::CustomCall(&conv1)),
+                                     m::Parameter(2))));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingF32ToS8) {
+  MAYBE_SKIP_TEST("I8");
+  const std::string_view module_str = R"(
+    HloModule Test
+
+    ENTRY test_entry {
+      inputs = f32[1, 17, 9, 9] parameter(0)
+      filters = f32[3, 3, 17, 32] parameter(1)
+      mult_op  = s8[1, 32, 9, 9] parameter(2)
+      conv = convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
+      ROOT ret = multiply(s8[1, 32, 9, 9] convert(conv), mult_op)
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+  SCOPED_TRACE(m->ToString());
+  HloInstruction* conv1 = nullptr;
+  // Checks that it removed the Convert inside multiply around conv.
+  ASSERT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Multiply(m::GetTupleElement(m::CustomCall(&conv1)),
+                                     m::Parameter(2))));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, DontRemoveConvertDuetoMultpleUser) {
+  const std::string_view module_str = R"(
+    HloModule Test
+
+    ENTRY test_entry {
+      inputs = f32[1, 17, 9, 9] parameter(0)
+      filters = f32[3, 3, 17, 32] parameter(1)
+      mult_op  = s8[1, 32, 9, 9] parameter(2)
+      sub_op = s8[1, 32, 9, 9] parameter(3)
+      conv = convolution(inputs, filters), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
+      another = subtract(s8[1, 32, 9, 9] convert(conv), sub_op)
+      ROOT ret = multiply(s8[1, 32, 9, 9] convert(conv), mult_op)
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+  SCOPED_TRACE(m->ToString());
+  HloInstruction* conv1 = nullptr;
+  // Checks that the Convert around conv is kept, since conv has another user.
+  ASSERT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Multiply(
+                  m::Convert(m::GetTupleElement(m::CustomCall(&conv1))),
+                  m::Parameter(2))));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] parameter(0)
+      filters = f32[3,3,17,32] parameter(1)
+      bias = f32[32] parameter(2)
+      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      ROOT root = add(conv, bias_broadcast)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
+                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+              0)
+              .WithShape(F32, {1, 32, 9, 9})));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] parameter(0)
+      filters = f32[3,3,17,32] parameter(1)
+      side_input = f32[1,32,9,9] parameter(2)
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      ROOT root = add(conv, side_input)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                            m::Parameter(0), m::Parameter(1),
+                            m::Broadcast(m::ConstantEffectiveScalar(0))
+                                .WithShape(F32, {32}),
+                            m::Parameter(2)),
+              0)
+              .WithShape(F32, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] parameter(0)
+      filters = f32[3,3,17,32] parameter(1)
+      side_input = f32[1,32,9,9] parameter(2)
+      side_input_scale = f32[] constant(42)
+      side_input_scale_broadcast = f32[1,32,9,9] broadcast(side_input_scale), dimensions={}
+      side_input_product = multiply(side_input, side_input_scale_broadcast)
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      ROOT root = add(conv, side_input_product)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                            m::Parameter(0), m::Parameter(1),
+                            m::Broadcast(m::ConstantEffectiveScalar(0))
+                                .WithShape(F32, {32}),
+                            m::Parameter(2)),
+              0)
+              .WithShape(F32, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.side_input_scale(), 42);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] parameter(0)
+      filters = f32[3,3,17,32] parameter(1)
+      bias = f32[32] parameter(2)
+      side_input = f32[1,32,9,9] parameter(3)
+      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      sum = add(conv, side_input)
+      ROOT sum2 = add(sum, bias_broadcast)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                            m::Parameter(0), m::Parameter(1), m::Parameter(2),
+                            m::Parameter(3)),
+              0)
+              .WithShape(F32, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] parameter(0)
+      filters = f32[3,3,17,32] parameter(1)
+      bias = f32[1,32,9,9] broadcast(f32[] parameter(2)), dimensions={}
+      conv = f32[1,32,9,9] convolution(inputs, filters),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      ROOT root = add(conv, bias)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                            m::Parameter(0), m::Parameter(1),
+                            m::Broadcast(m::Parameter(2)).WithShape(F32, {32})),
+              0)
+              .WithShape(F32, {1, 32, 9, 9})));
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f16[1,17,9,9] parameter(0)
+      filters = f16[3,3,17,32] parameter(1)
+      bias = f16[32] parameter(2)
+      side_input = f16[1,32,9,9] parameter(3)
+
+      inputs_f32 = f32[1,17,9,9] convert(inputs)
+      filters_f32 = f32[3,3,17,32] convert(filters)
+      bias_f32 = f32[32] convert(bias)
+      bias_broadcast = f32[1,32,9,9] broadcast(bias_f32), dimensions={1}
+      side_input_f32 = f32[1,32,9,9] convert(side_input)
+      conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      sum = add(conv, side_input_f32)
+      sum2 = add(sum, bias_broadcast)
+      ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`'s that may be added to the graph.
+  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                            m::Parameter(0), m::Parameter(1), m::Parameter(2),
+                            m::Parameter(3)),
+              0)
+              .WithShape(F16, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+// We should be able to lower this to an f16 convolution even though the
+// f16-ness of the inputs is hidden behind broadcast/transpose/reshape.
+TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f32[1,17,9,9] reshape(f32[1377] convert(f16[1377] parameter(0)))
+      filters = f32[3,3,17,32] transpose(f32[17,32,3,3] convert(f16[17,32,3,3] parameter(1))), dimensions={2,3,0,1}
+      bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
+      side_input = f16[1,32,9,9] reshape(f16[2592] parameter(3))
+
+      conv_f32 = f32[1,32,9,9] convolution(inputs, filters),
+                 window={size=3x3 pad=1_1x1_1},
+                 dim_labels=bf01_01io->bf01
+      conv_f16 = f16[1,32,9,9] convert(conv_f32)
+      ROOT root = f16[1,32,9,9] add(add(conv_f16, side_input), bias)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`'s that may be added to the graph.
+  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+                     m::CustomCall(
+                         &conv, {kCudnnConvBiasActivationForwardCallTarget},
+                         m::Convert(m::Reshape(m::Convert(m::Parameter(0))))
+                             .WithElementType(F16),
+                         m::Convert(m::Transpose(m::Convert(m::Parameter(1))))
+                             .WithElementType(F16),
+                         m::Parameter(2), m::Reshape(m::Parameter(3))),
+                     0)
+                     .WithShape(F16, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f16[1,17,9,9] parameter(0)
+      filters = f16[3,3,17,32] parameter(1)
+      bias = f32[32] parameter(2)
+      side_input = f16[1,32,9,9] parameter(3)
+
+      inputs_f32 = f32[1,17,9,9] convert(inputs)
+      filters_f32 = f32[3,3,17,32] convert(filters)
+      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
+      side_input_f32 = f32[1,32,9,9] convert(side_input)
+      conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
+               window={size=3x3 pad=1_1x1_1},
+               dim_labels=bf01_01io->bf01
+      sum = add(conv, side_input_f32)
+      sum2 = add(sum, bias_broadcast)
+      ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`'s that may be added to the graph.
+  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  // fp16 convs only support fp16 biases.  Because bias is fp32, it doesn't get
+  // fused in, and we get an fp32 conv.
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Convert(m::GetTupleElement(
+                         m::CustomCall(
+                             &conv, {kCudnnConvBiasActivationForwardCallTarget},
+                             m::Convert(m::Parameter(0)).WithElementType(F32),
+                             m::Convert(m::Parameter(1)).WithElementType(F32),
+                             m::Parameter(2),
+                             m::Convert(m::Parameter(3)).WithElementType(F32)),
+                         0))
+              .WithShape(F16, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f16[1,2,2,2] parameter(0)
+      filters_f32 = f32[1,1,2,2] constant({{{{1, 2},{3, 4}}}})
+      bias = f16[2] parameter(1)
+      bias_f32 = f32[2] convert(bias)
+      side_input_f32 = f32[1,2,2,2] constant({{
+        {{0.5, 0.25}, {0.125, 0.0625}},
+        {{0.5, 0.25}, {0.125, 0.0625}}
+      }})
+
+      inputs_f32 = f32[1,2,2,2] convert(inputs)
+      bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
+      conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
+               window={size=1x1}, dim_labels=bf01_01io->bf01
+      sum = add(conv, side_input_f32)
+      sum2 = add(sum, bias_broadcast)
+      ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`'s that may be added to the graph, and fold
+  // convert back into constants.
+  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+  HloConstantFolding constant_folding;
+  TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+                     m::CustomCall(
+                         &conv, {kCudnnConvBiasActivationForwardCallTarget},
+                         m::Parameter(0), m::Constant().WithElementType(F16),
+                         m::Parameter(1), m::Constant().WithElementType(F16)),
+                     0)
+                     .WithShape(F16, {1, 2, 2, 2})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) {
+  const std::string module_str = R"(
+    HloModule Test
+
+    ENTRY Test {
+      inputs = f16[1,2,2,2] parameter(0)
+      filters_f32 = f32[1,1,2,2] constant({{{{1, 2.123456789},{3, 4}}}})
+      bias = f16[2] parameter(1)
+      bias_f32 = f32[2] convert(bias)
+      side_input_f32 = f32[1,2,2,2] constant({{
+        {{0.1, 0.2}, {0.3, 0.4}},
+        {{0.5, 0.6}, {0.7, 0.8}}
+      }})
+
+      inputs_f32 = f32[1,2,2,2] convert(inputs)
+      bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
+      conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
+               window={size=1x1}, dim_labels=bf01_01io->bf01
+      sum = add(conv, side_input_f32)
+      sum2 = add(sum, bias_broadcast)
+      ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`'s that may be added to the graph, and fold
+  // convert back into constants.
+  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+  HloConstantFolding constant_folding;
+  TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  // This doesn't get transformed into an f16 conv because the filters param is
+  // not losslessly expressible as f16.
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Convert(m::GetTupleElement(
+                         m::CustomCall(
+                             &conv, {kCudnnConvBiasActivationForwardCallTarget},
+                             m::Convert(m::Parameter(0)).WithElementType(F32),
+                             m::Constant().WithElementType(F32),
+                             m::Convert(m::Parameter(1)).WithElementType(F32),
+                             m::Constant().WithElementType(F32)),
+                         0)
+                         .WithShape(F32, {1, 2, 2, 2}))
+              .WithElementType(F16)));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.side_input_scale(), 1);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) {
+  MAYBE_SKIP_TEST("I8");
+  const std::string module_str = R"(
+  HloModule Test
+
+  ENTRY Test {
+    input = s8[1,17,9,9] parameter(0)
+    filter = s8[3,3,17,32] parameter(1)
+    inputs32 = s32[1,17,9,9] convert(input)
+    filters32 = s32[3,3,17,32] convert(filter)
+
+    conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+
+    zero = s32[] constant(0)
+    zeros = s32[1,32,9,9] broadcast(zero), dimensions={}
+    relu = maximum(conv, zeros)
+
+    lower = s32[] constant(-128)
+    lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
+    upper = s32[] constant(127)
+    uppers = s32[1,32,9,9] broadcast(upper), dimensions={}
+
+    clamp = s32[1,32,9,9] clamp(lowers, relu, uppers)
+
+    ROOT convert = s8[1,32,9,9] convert(clamp)
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`'s that may be added to the graph.
+  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                            m::Parameter(0),  //
+                            m::Parameter(1),  //
+                            m::Broadcast(m::ConstantEffectiveScalar(0))
+                                .WithShape(F32, {32})),
+              0)
+              .WithShape(S8, {1, 32, 9, 9})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          conv->backend_config<GpuBackendConfig>());
+  const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config();
+  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
+}
+
+TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) {
+  MAYBE_SKIP_TEST("F64");
+  const std::string module_str = R"(
+  HloModule Test
+
+  ENTRY Test {
+    input = f64[1,17,9,9] parameter(0)
+    filter = f64[3,3,17,32] parameter(1)
+    bias = f64[1,32,9,9] broadcast(f64[32] convert(f32[32] parameter(2))), dimensions={1}
+    conv = f64[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+    ROOT root = f64[1,32,9,9] add(conv, bias)
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ConvRewriter rewriter{GetCudaComputeCapability()};
+  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
+  CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
+                               GetToolkitVersion()};
+  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
+
+  // Simplify new `convert`'s that may be added to the graph.
+  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
+  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
+  const HloInstruction* conv;
+  ASSERT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::GetTupleElement(
+              m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget},
+                            m::Parameter(0),  //
+                            m::Parameter(1),  //
+                            m::Convert(m::Parameter(2)).WithShape(F64, {32})),
+              0)
+              .WithShape(F64, {1, 32, 9, 9})));
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8) {
+  MAYBE_SKIP_TEST("I8");
+  // clamp(max(0, conv(x, w)+bias)); for int8_t
+  TestClamp(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = f32[] constant(0)
+      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
+
+      input = s8[1,3,3,64] parameter(0)
+      filter = s8[3,3,64,64] parameter(1)
+      bias = f32[64] parameter(2)
+
+      inputs32 = s32[1,3,3,64] convert(input)
+      filters32 = s32[3,3,64,64] convert(filter)
+
+      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+
+      convfloat = f32[1,3,3,64] convert(conv)
+      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
+      add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
+      relu = f32[1,3,3,64] maximum(zeros, add1)
+
+      lower = f32[] constant(-128)
+      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
+      upper = f32[] constant(127)
+      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
+
+      clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
+
+      ROOT convert = s8[1,3,3,64] convert(clamp)
+    })",
+      // post_hlo
+      R"(
+// CHECK: [[cudnn_conv_bias_activation_7_0:%[^ ]+]] = (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
+      )");
+}
+
+// Disabled per b/190854862 or nvbugs/3326122.
+TEST_F(CudnnFusedConvRewriterTest, DISABLED_TestFusedConvInt8ToFloat) {
+  MAYBE_SKIP_TEST("I8");
+  // max(0, convert<float>(conv<int32_t>(int8_x, int8_w)) + float_bias);
+  // int8_t to float via bias.
+  TestClamp(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = f32[] constant(0)
+      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
+
+      input = s8[1,3,3,64] parameter(0)
+      filter = s8[3,3,64,64] parameter(1)
+      bias = f32[64] parameter(2)
+
+      inputs32 = s32[1,3,3,64] convert(input)
+      filters32 = s32[3,3,64,64] convert(filter)
+
+      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+
+      convfloat = f32[1,3,3,64] convert(conv)
+      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
+      add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
+      ROOT relu = f32[1,3,3,64] maximum(zeros, add1)
+    })",
+      // post_hlo
+      R"(
+      ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], bias: f32[64]) -> f32[1,3,3,64] {
+      ; CHECK:  [[custom_call_0:%[^ ]+]]{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]*}}]{0}) custom-call([[input_1:%[^ ]+]], [[copy_2:%[^ ]+]]{{(\.[0-9])?}}, [[bias_3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
+      ; CHECK-NEXT:  ROOT [[get_tuple_element_4:%[^ ]+]]{{(\.[0-9])?}} = f32[1,3,3,64]{3,2,1,0} get-tuple-element([[custom_call_0]]{{(\.[0-9])?}}), index=0
+      )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest,
+       TestFusedConvWithScaledInt8SideInputBiasInt8ToInt8) {
+  MAYBE_SKIP_TEST("I8");
+  // clamp(max(0, alpha_conv * conv(x, w) + alpha_side *
+  // convert<float>(int8_side_input) + bias)); for int8_t
+  TestClamp(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = f32[] constant(0)
+      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
+      alpha_conv_scalar = f32[] constant(0.999994934)
+      alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+      alpha_side_input_scalar = f32[] constant(0.899994934)
+      alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+      input = s8[1,3,3,64] parameter(0)
+      filter = s8[3,3,64,64] parameter(1)
+      side_input = s8[1,3,3,64] parameter(2)
+      bias = f32[64] parameter(3)
+
+      inputs32 = s32[1,3,3,64] convert(input)
+      filters32 = s32[3,3,64,64] convert(filter)
+      side_input_f32 = f32[1,3,3,64] convert(side_input)
+
+      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+
+      convfloat = f32[1,3,3,64] convert(conv)
+      scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
+      scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input)
+      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
+      add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
+      add2 = f32[1,3,3,64] add(add1, scaled_side_input)
+      relu = f32[1,3,3,64] maximum(zeros, add2)
+
+      lower = f32[] constant(-128)
+      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
+      upper = f32[] constant(127)
+      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
+
+      clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
+
+      ROOT convert = s8[1,3,3,64] convert(clamp)
+    })",
+      // post_hlo
+      R"(
+// CHECK: [[cudnn_conv_bias_activation_11_0:%[^ ]+]] = (s8[1,3,3,64]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]], [[side_input_4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
+      )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest,
+       TestFusedConvWithScaledFloatSideInputBiasInt8ToInt8) {
+  MAYBE_SKIP_TEST("I8");
+  // From:
+  //   convert<int8_t>(clamp(max(0, alpha_conv * conv(x, w) +
+  //                              alpha_side * float_side_input + bias)));
+  // To:
+  //   convert<int8_t>(clamp(conv(int8_x, int8_w, float_alpha_side,
+  //                              float_side_input, float_bias)));
+  TestClamp(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = f32[] constant(0)
+      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
+      alpha_conv_scalar = f32[] constant(0.999994934)
+      alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+      alpha_side_input_scalar = f32[] constant(0.899994934)
+      alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+      input = s8[1,3,3,64] parameter(0)
+      filter = s8[3,3,64,64] parameter(1)
+      side_input = f32[1,3,3,64] parameter(2)
+      bias = f32[64] parameter(3)
+
+      inputs32 = s32[1,3,3,64] convert(input)
+      filters32 = s32[3,3,64,64] convert(filter)
+
+      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+
+      convfloat = f32[1,3,3,64] convert(conv)
+      scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
+      scaled_side_input = f32[1,3,3,64] multiply(side_input, alpha_side_input)
+      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
+      add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
+      add2 = f32[1,3,3,64] add(add1, scaled_side_input)
+      relu = f32[1,3,3,64] maximum(zeros, add2)
+
+      lower = f32[] constant(-128)
+      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
+      upper = f32[] constant(127)
+      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
+
+      clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
+
+      ROOT convert = s8[1,3,3,64] convert(clamp)
+    })",
+      //  post_hlo
+      R"(
+// CHECK: [[cudnn_conv_bias_activation_9_0:%[^ ]+]] = (f32[1,3,3,64]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]], [[side_input_4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
+      )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest,
+       TestFusedConvWithScaledInt8SideInputBiasInt8ToFloat) {
+  MAYBE_SKIP_TEST("I8");
+  // From:
+  //   clamp(max(0, alpha_conv * conv(x, w) +
+  //                alpha_side * convert<float>(int8_side_input) + bias));
+  // To:
+  //   clamp(conv(int8_x, int8_w, float_alpha_side,
+  //              convert<float>(int8_side_input), float_bias));
+  TestClamp(
+      // pre_hlo
+      R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = f32[] constant(0)
+      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
+      alpha_conv_scalar = f32[] constant(0.999994934)
+      alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+      alpha_side_input_scalar = f32[] constant(0.899994934)
+      alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+      input = s8[1,3,3,64] parameter(0)
+      filter = s8[3,3,64,64] parameter(1)
+      side_input = s8[1,3,3,64] parameter(2)
+      bias = f32[64] parameter(3)
+
+      inputs32 = s32[1,3,3,64] convert(input)
+      filters32 = s32[3,3,64,64] convert(filter)
+      side_input_f32 = f32[1,3,3,64] convert(side_input)
+
+      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+
+      convfloat = f32[1,3,3,64] convert(conv)
+      scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
+      scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input)
+      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
+      add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
+      add2 = f32[1,3,3,64] add(add1, scaled_side_input)
+      relu = f32[1,3,3,64] maximum(zeros, add2)
+
+      lower = f32[] constant(-128)
+      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
+      upper = f32[] constant(127)
+      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}
+
+      ROOT clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
+    })",
+      // post_hlo
+      R"(
+// CHECK: [[cudnn_conv_bias_activation_9_0:%[^ ]+]] = (f32[1,3,3,64]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[input_1:%[^ ]+]], [[transpose_2:%[^ ]+]], [[bias_3:%[^ ]+]], [[fusion_1_4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward"
+      )");
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8NoClamp) {
+  MAYBE_SKIP_TEST("I8");
+  // Check that integer convolution without clamp to int8_t is not allowed.
+  // convert<int8_t>(custom_call<int32_t>(int32_x, int32_w,
+  // cudnnConvolutionForward))
+  const std::string module_str = absl::StrFormat(R"(
+    HloModule Test
+
+    ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
+      zero = s8[] constant(0)
+      zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={}
+      input = s8[1,17,9,9]{3,2,1,0} parameter(0)
+      filter = s8[3,3,17,32]{3,2,1,0} parameter(1)
+      custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}"
+      get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0
+      convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element)
+      ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert)
+    })");
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ASSERT_FALSE(CudnnFusedConvRewriter(GetCudaComputeCapability(),
+                                      GetDnnVersion(), GetToolkitVersion())
+                   .Run(m.get())
+                   .ok());
+}
+
+TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8NoClamp) {
+  MAYBE_SKIP_TEST("I8");
+  // Although bias and so on are fused with the forward convolution, the
+  // rewrite is still not allowed if the output is not clamped/converted to
+  // int8_t:
+  // max(0, alpha_conv * conv(x, w) + alpha_side * side_input + bias); for
+  // int8_t
+
+  const std::string module_str = absl::StrFormat(R"(
+    HloModule Test
+
+    ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
+      zero = s8[] constant(0)
+      zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={}
+      input = s8[1,17,9,9]{3,2,1,0} parameter(0)
+      filter = s8[3,3,17,32]{3,2,1,0} parameter(1)
+      custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}"
+      get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0
+      convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element)
+      ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert)
+    })");
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ASSERT_FALSE(CudnnFusedConvRewriter(GetCudaComputeCapability(),
+                                      GetDnnVersion(), GetToolkitVersion())
+                   .Run(m.get())
+                   .ok());
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc
new file mode 100644
index 0000000..7fbd589
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc
@@ -0,0 +1,1777 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <numeric>
+#include <optional>
+#include <queue>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/permutation_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/types.h"
+#include "xla/util.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
+#endif
+
+namespace xla {
+namespace gpu {
+namespace {
+namespace m = match;
+
+// A struct that contains all the matched nodes and results from pattern
+// matching the forward graph.
+struct MatchFwdResult {
+  HloInstruction* matched_bmm_1 = nullptr;
+  HloInstruction* matched_bmm_2 = nullptr;
+  HloInstruction* matched_bias = nullptr;
+  HloInstruction* matched_scale = nullptr;
+  HloInstruction* matched_softmax_input = nullptr;
+  HloInstruction* matched_reduce_sum = nullptr;
+
+  double matched_dropout_rate = 0.0;
+  bool need_canonicalization = false;
+  bool is_training = false;
+  // We use this to keep track of whether the bias applied to bmm1 is a
+  // causal mask; cuDNN can generate the causal mask inside the attention
+  // kernel to save I/O.
+  bool is_causal_mask = false;
+  bool has_match = false;
+  std::string matched_custom_call_name;
+};
+
+// A struct that contains all the matched nodes and results from pattern
+// matching the backward graph.
+struct MatchBwdResult {
+  HloInstruction* matched_bmm_1_grad_1 = nullptr;
+  HloInstruction* matched_bmm_1_grad_2 = nullptr;
+
+  HloInstruction* matched_bmm_2_grad_1 = nullptr;
+  HloInstruction* matched_bmm_2_grad_2 = nullptr;
+  HloInstruction* matched_dbias = nullptr;
+  // We use this to keep track of all gradient bmms that need
+  // canonicalization.
+  bool bmm_1_grad_1_need_canonicalization = false;
+  bool bmm_1_grad_2_need_canonicalization = false;
+  bool bmm_2_grad_1_need_canonicalization = false;
+  bool bmm_2_grad_2_need_canonicalization = false;
+
+  bool has_match = false;
+  std::string matched_custom_call_name;
+};
+
+template <typename Pattern>
+auto OptionalReshape(Pattern pattern) {
+  auto shared = m::SharedSubpattern(pattern);
+  return m::AnyOf<HloInstruction>(m::Reshape(shared), shared);
+}
+
+template <typename Pattern>
+auto OptionalConvert(Pattern pattern) {
+  auto shared = m::SharedSubpattern(pattern);
+  return m::AnyOf<HloInstruction>(m::Convert(shared), shared);
+}
+
+template <typename Pattern>
+auto OptionalBitcast(Pattern pattern) {
+  auto shared = m::SharedSubpattern(pattern);
+  return m::AnyOf<HloInstruction>(m::Bitcast(shared), shared);
+}
+
+template <typename Pattern>
+auto OptionalBroadcast(Pattern pattern) {
+  auto shared = m::SharedSubpattern(pattern);
+  return m::AnyOf<HloInstruction>(m::Broadcast(shared), shared);
+}
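+// For illustration only: OptionalConvert(m::Parameter(0)) matches either
+// `convert(parameter(0))` or a bare `parameter(0)`; the other Optional*
+// helpers behave the same way for reshape, bitcast, and broadcast.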
+
+bool IsBatchedMatmul(const HloInstruction* instr) {
+  if (instr->opcode() != HloOpcode::kDot) return false;
+  if (Cast<HloDotInstruction>(instr)->sparse_operands()) return false;
+  const DotDimensionNumbers& dot_dims = instr->dot_dimension_numbers();
+  bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() ||
+                      !dot_dims.rhs_batch_dimensions().empty();
+  return is_batch_dot;
+}
+
+// We need to check whether the current gemm shares an operand with a forward
+// fMHA call: when we match backward gemms, the only way to be sure a gemm is
+// a backward gemm is to see whether it shares an operand with the forward
+// fMHA call (i.e. the Q, K, V, or activation tensors). We can also use this
+// function to infer whether a gemm is a forward fMHA gemm. We check this with
+// a BFS over each operand's users, looking for a forward fMHA custom call.
+// The traversal continues through shape ops (bitcast, reshape, transpose)
+// until it either finds a forward fMHA call or runs out of shape ops on the
+// path, in which case the current node cannot share an operand with a forward
+// fMHA call.
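+// Hypothetical example: if a backward gemm computes dot(d_s, k) and the same
+// `k` also feeds the forward fMHA custom call, the BFS starting from operand
+// `k` reaches that custom call and this function returns true.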
+bool IsSharingOperandWithFwdMha(HloInstruction* gemm) {
+  for (int64_t i = 0; i < gemm->operands().size(); i++) {
+    std::queue<HloInstruction*> visit_list;
+    visit_list.push(gemm->mutable_operand(i));
+    while (!visit_list.empty()) {
+      HloInstruction* current_instr = visit_list.front();
+      for (auto user : current_instr->users()) {
+        switch (user->opcode()) {
+          case HloOpcode::kBitcast:
+          case HloOpcode::kReshape:
+          case HloOpcode::kTranspose: {
+            visit_list.push(user);
+            break;
+          }
+          case HloOpcode::kCustomCall: {
+            if (IsFwdCustomCallTofMHA(*user)) {
+              return true;
+            }
+          } break;
+          default:
+            break;
+        }
+      }
+      visit_list.pop();
+    }
+  }
+  return false;
+}
+// When we reach a gemm instruction, it can be one of 3 cases:
+//   1. one of the 2 gemms in a forward fMHA call
+//   2. one of the 4 gemms in a backward fMHA call
+//   3. a gemm from some other, unrelated layer
+// Case 3 is easily ruled out by the pattern matcher.
+// However, cases 1 and 2 have very similar bmm-bmm structures.
+// To match exactly case 1, a forward gemm must have the following properties:
+//    - It is a batched matmul.
+//    - Neither operand is a forward fMHA call, which would make it a backward
+//      gemm.
+//    - It does not directly or indirectly share an operand with any other
+//      fMHA call, which would also make it a backward gemm.
+bool IsFirstFwdMatmul(HloInstruction* gemm) {
+  return IsBatchedMatmul(gemm) && !IsFwdCustomCallTofMHA(*gemm->operand(0)) &&
+         !IsFwdCustomCallTofMHA(*gemm->operand(1)) &&
+         !IsSharingOperandWithFwdMha(gemm);
+}
+
+bool IsScalar(const HloInstruction* instr) {
+  return ShapeUtil::IsEffectiveScalar(instr->shape());
+}
+
+bool IsReduceMax(const HloInstruction* instr) {
+  return instr->opcode() == HloOpcode::kReduce &&
+         instr->to_apply()->root_instruction()->opcode() == HloOpcode::kMaximum;
+}
+
+bool IsReduceSum(const HloInstruction* instr) {
+  return instr->opcode() == HloOpcode::kReduce &&
+         instr->to_apply()->root_instruction()->opcode() == HloOpcode::kAdd;
+}
+
+// Set up subpatterns for re-use.
+// Matches the softmax sub-pattern:
+// divide(exp(Subtract(producer, reduce_max(producer))),
+//        broadcast(reduce_add(exp(Subtract(...))))).
+// There might be reshape and convert nodes between reduce and Subtract.
+// TODO TJ: Make this more general to any pattern that has this structure once
+// the cudnn runner supports generic cudnnOpGraphs.
+// producer
+// |   \
+// |  reduce
+// |     |
+// |  broadcast
+// |   /
+// root
+auto GetUnfusedReduceMaxSumSoftmaxPattern(
+    HloInstruction** softmax_input = nullptr,
+    HloInstruction** softmax_reduce_sum = nullptr,
+    HloInstruction** softmax_reduce_sum_bcast = nullptr) {
+  // The reduce-max part of the softmax.
+  // reduce_max and subtract will always have exactly 1 user
+  // in both training and inference.
+  // softmax_input should always have exactly 2 users.
+  auto unfused_softmax_max_subpattern = m::SharedSubpattern(
+      m::Subtract(
+          m::Op(),
+          m::Broadcast(OptionalConvert(
+              m::Op()
+                  .WithPredicate(IsReduceMax)
+                  .WithOneUse()
+                  .WithOperand(0, OptionalBitcast(OptionalConvert(
+                                      m::Op(softmax_input).WithNumUser(2)))))))
+          .WithOneUse());
+  // The reduce-add part of the softmax.
+  // reduce_sum and reduce_sum_broadcast should have 2 users in training
+  // and 1 user in inference.
+  auto unfused_softmax_sum_subpattern = m::SharedSubpattern(m::Divide(
+      OptionalBitcast(m::Exp(unfused_softmax_max_subpattern)),
+      m::Broadcast(
+          softmax_reduce_sum_bcast,
+          OptionalConvert(
+              m::Op(softmax_reduce_sum)
+                  .WithOperand(0, OptionalBitcast(OptionalConvert(
+                                      m::Exp(unfused_softmax_max_subpattern))))
+                  .WithPredicate(IsReduceSum)
+                  .WithAtMostNumUser(2)))
+          .WithAtMostNumUser(2)));
+  return unfused_softmax_sum_subpattern;
+}
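+// For illustration (names are hypothetical), the pattern above corresponds to
+// HLO of roughly this shape, possibly with reshape/convert/bitcast nodes
+// interleaved:
+//   max   = reduce(producer, init_min), to_apply=maximum
+//   shift = subtract(producer, broadcast(max))
+//   exp   = exponential(shift)
+//   sum   = reduce(exp, init_zero), to_apply=add
+//   out   = divide(exp, broadcast(sum))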
+
+std::optional<double> GetConstantValue(const HloInstruction* inst) {
+  if (!IsScalar(inst)) {
+    return std::nullopt;
+  }
+  switch (inst->shape().element_type()) {
+    case F16:
+      return static_cast<float>(inst->literal().GetFirstElement<half>());
+    case BF16:
+      return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
+    case F32:
+      return inst->literal().GetFirstElement<float>();
+    case F64:
+      return inst->literal().GetFirstElement<double>();
+    default:
+      return std::nullopt;
+  }
+}
+
+double GetDropoutRateFromHlo(HloInstruction* dropout) {
+  std::optional<double> dropout_rate_inv;
+  dropout_rate_inv = GetConstantValue(dropout);
+  if (!dropout_rate_inv.has_value()) {
+    return 0.0;
+  }
+  // In dropout, inputs are divided by (1 - rate), so to recover the actual
+  // dropout rate we take the constant from the dropout node, invert it, and
+  // subtract the result from 1.
+  return (1.0 - (1.0 / *dropout_rate_inv));
+}
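+// Worked example (values are illustrative): dropout scales kept values by
+// 1 / (1 - rate), so a scale constant of 1.25 corresponds to
+// rate = 1 - 1 / 1.25 = 0.2.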
+
+bool IsComputeCapabilityAndCudnnSupported(
+    stream_executor::CudaComputeCapability cc,
+    stream_executor::dnn::VersionInfo cudnn_version,
+    stream_executor::dnn::VersionInfo supported_cudnn_version) {
+  if (cc.IsAtLeastAmpere() && cudnn_version >= supported_cudnn_version) {
+    return true;
+  }
+  VLOG(2) << absl::StrFormat(
+      "CudnnFusedMHARewriter did not run. Unsupported compute "
+      "capability(%s; major should be >= 8, minor should be 0) or cudnn version"
+      "(%s; should be >= %s)",
+      cc.ToString(), cudnn_version.ToString(),
+      supported_cudnn_version.ToString());
+  return false;
+}
+
+bool IsSupportedPrimitiveType(const HloInstruction* bmm) {
+  PrimitiveType dtype = bmm->shape().element_type();
+  return dtype == BF16 || dtype == F16;
+}
+
+std::vector<int64_t> GetDimensionVector(absl::Span<const int64_t> dimensions,
+                                        absl::Span<const int64_t> dim_nums) {
+  std::vector<int64_t> vec(dim_nums.size());
+  for (int i = 0; i < dim_nums.size(); i++) {
+    vec[i] = dimensions.at(dim_nums.at(i));
+  }
+  return vec;
+}
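+// For illustration: with hypothetical dimensions = {2, 8, 128, 64} and
+// dim_nums = {0, 1}, GetDimensionVector returns {2, 8}.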
+
+struct QKVLayout {
+  int64_t batch;
+  int64_t num_heads;
+  int64_t seqlen_q;
+  int64_t seqlen_kv;
+  int64_t hidden_dim;
+};
+
+absl::StatusOr<std::optional<QKVLayout>> GetQKVLayout(
+    HloInstruction* bmm_1, HloInstruction* bmm_2, bool need_canonicalization) {
+  // get layout from bmm1
+  const DotDimensionNumbers& bmm1_dnums = bmm_1->dot_dimension_numbers();
+  TF_ASSIGN_OR_RETURN(
+      std::vector<int64_t> bmm1_s_q_dims,
+      GetNonContractingDims(bmm_1->operand(0)->shape(),
+                            bmm1_dnums.lhs_batch_dimensions(),
+                            bmm1_dnums.lhs_contracting_dimensions()));
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<int64_t> bmm1_s_kv_dims,
+      GetNonContractingDims(bmm_1->operand(1)->shape(),
+                            bmm1_dnums.rhs_batch_dimensions(),
+                            bmm1_dnums.rhs_contracting_dimensions()));
+
+  std::vector<int64_t> bmm1_bh =
+      GetDimensionVector(bmm_1->operand(0)->shape().dimensions(),
+                         bmm1_dnums.lhs_batch_dimensions());
+
+  std::vector<int64_t> bmm1_s_q = GetDimensionVector(
+      bmm_1->operand(0)->shape().dimensions(), bmm1_s_q_dims);
+
+  std::vector<int64_t> bmm1_s_kv = GetDimensionVector(
+      bmm_1->operand(1)->shape().dimensions(), bmm1_s_kv_dims);
+
+  std::vector<int64_t> bmm1_d =
+      GetDimensionVector(bmm_1->operand(0)->shape().dimensions(),
+                         bmm1_dnums.lhs_contracting_dimensions());
+
+  TF_RET_CHECK(bmm1_bh.size() == 2);
+  TF_RET_CHECK(bmm1_s_q.size() == 1);
+  TF_RET_CHECK(bmm1_s_kv.size() == 1);
+  TF_RET_CHECK(bmm1_d.size() == 1);
+
+  // get layout from bmm2
+  const DotDimensionNumbers& bmm2_dnums = bmm_2->dot_dimension_numbers();
+  TF_ASSIGN_OR_RETURN(
+      std::vector<int64_t> bmm2_lhs_non_contracting_dims,
+      GetNonContractingDims(bmm_2->operand(0)->shape(),
+                            bmm2_dnums.lhs_batch_dimensions(),
+                            bmm2_dnums.lhs_contracting_dimensions()));
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<int64_t> bmm2_rhs_non_contracting_dims,
+      GetNonContractingDims(bmm_2->operand(1)->shape(),
+                            bmm2_dnums.rhs_batch_dimensions(),
+                            bmm2_dnums.rhs_contracting_dimensions()));
+
+  std::vector<int64_t> bmm2_bh =
+      GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
+                         bmm2_dnums.lhs_batch_dimensions());
+
+  std::vector<int64_t> bmm2_s_kv =
+      GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
+                         bmm2_dnums.lhs_contracting_dimensions());
+
+  std::vector<int64_t> bmm2_s_q =
+      need_canonicalization
+          ? GetDimensionVector(bmm_2->operand(1)->shape().dimensions(),
+                               bmm2_rhs_non_contracting_dims)
+          : GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
+                               bmm2_lhs_non_contracting_dims);
+
+  std::vector<int64_t> bmm2_d =
+      need_canonicalization
+          ? GetDimensionVector(bmm_2->operand(0)->shape().dimensions(),
+                               bmm2_lhs_non_contracting_dims)
+          : GetDimensionVector(bmm_2->operand(1)->shape().dimensions(),
+                               bmm2_rhs_non_contracting_dims);
+
+  TF_RET_CHECK(bmm2_bh.size() == 2);
+  TF_RET_CHECK(bmm2_s_q.size() == 1);
+  TF_RET_CHECK(bmm2_s_kv.size() == 1);
+  TF_RET_CHECK(bmm2_d.size() == 1);
+
+  // Check that the batch, num_heads, sequence lengths and hidden dim (bhsd)
+  // agree between bmm1 and bmm2.
+  if (bmm1_bh[0] != bmm2_bh[0] || bmm1_bh[1] != bmm2_bh[1] ||
+      bmm1_s_q[0] != bmm2_s_q[0] || bmm1_s_kv[0] != bmm2_s_kv[0] ||
+      bmm1_d[0] != bmm2_d[0]) {
+    return std::nullopt;
+  }
+
+  QKVLayout qkv_layout;
+  qkv_layout.batch = bmm1_bh[0];
+  qkv_layout.num_heads = bmm1_bh[1];
+  qkv_layout.seqlen_q = bmm1_s_q[0];
+  qkv_layout.seqlen_kv = bmm1_s_kv[0];
+  qkv_layout.hidden_dim = bmm1_d[0];
+  return qkv_layout;
+}
+
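+// Decides whether the matched attention block can use cuDNN flash attention
+// based on the sequence lengths and hidden dim in `qkv_layout`, plus the
+// compute capability and cuDNN version constraints checked below.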
+absl::StatusOr<bool> IsFlashAttention(
+    QKVLayout qkv_layout, bool is_training,
+    stream_executor::CudaComputeCapability cc,
+    stream_executor::dnn::VersionInfo cudnn_version) {
+  int64_t s_q = qkv_layout.seqlen_q;
+  int64_t s_kv = qkv_layout.seqlen_kv;
+  int64_t hidden_dim = qkv_layout.hidden_dim;
+  // Start with the most relaxed constraints.
+  bool is_seqlen_supported = (!is_training || (s_q % 2 == 0 && s_kv % 2 == 0));
+  bool is_hidden_dim_supported = hidden_dim <= 128 && hidden_dim % 8 == 0;
+  bool is_flash_attention = is_seqlen_supported && is_hidden_dim_supported;
+  if (!is_flash_attention) return false;
+
+  // Then walk backwards through cuDNN versions to check compatibility.
+  if ((is_training && (s_q < 64 || s_kv < 64)) &&
+      !IsComputeCapabilityAndCudnnSupported(
+          cc, cudnn_version, stream_executor::dnn::VersionInfo(9, 0, 0))) {
+    VLOG(2) << "Flash attention training with seq < 64 not supported cuDNN < "
+               "9.0.0.";
+    return false;
+  }
+
+  if ((hidden_dim != 64 && hidden_dim != 128) &&
+      !IsComputeCapabilityAndCudnnSupported(
+          cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 6))) {
+    VLOG(2) << "Flash attention head dim != 64 or 128 not supported with cuDNN "
+               "< 8.9.6.";
+    return false;
+  }
+
+  if ((is_training && s_kv % 64 != 0) &&
+      !IsComputeCapabilityAndCudnnSupported(
+          cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 5))) {
+    VLOG(2) << "Flash attention training with seq kv % 64 != 0 not supported "
+               "with cuDNN < 8.9.5.";
+    return false;
+  }
+
+  if (!IsComputeCapabilityAndCudnnSupported(
+          cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 4))) {
+    VLOG(2) << "Require cuDNN 8.9.4 to run flash attention.";
+    return false;
+  }
+  return is_flash_attention;
+}
+
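+// Matches the HLO produced for a causal mask: a select over
+// compare(iota, iota) with broadcast constants, possibly rematerialized,
+// broadcast, or threaded through a while-loop parameter.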
+bool IsCausalMaskPattern(HloInstruction* mask) {
+  auto causal_mask =
+      m::Select(m::Compare(m::Iota(), m::Iota()), m::Broadcast(m::Constant()),
+                m::Broadcast(m::Constant()));
+  auto causal_mask_pattern_fwd_remat =
+      m::Broadcast(OptionalBitcast(causal_mask));
+  auto causal_mask_pattern_bwd = m::Broadcast(m::Convert(OptionalBitcast(
+      m::Minimum(m::Op(), m::Broadcast(OptionalBitcast(causal_mask))))));
+  HloInstruction* param = nullptr;
+  HloInstruction* gte = nullptr;
+  auto causal_mask_pattern_fwd = m::Broadcast(
+      OptionalBitcast(m::GetTupleElement(&gte, m::Parameter(&param))));
+  auto causal_mask_pattern = m::AnyOf<HloInstruction>(
+      causal_mask_pattern_fwd_remat, causal_mask_pattern_fwd,
+      causal_mask_pattern_bwd);
+  if (Match(mask, causal_mask_pattern)) {
+    if (param != nullptr && param->parent()->IsWhileBodyComputation()) {
+      // We need to look outside of the while-loop body to find the real mask.
+      auto while_instr = param->parent()->WhileCallInstruction();
+      auto mask_index = gte->tuple_index();
+      auto actual_mask =
+          while_instr->mutable_operand(0)->mutable_operand(mask_index);
+      auto causal_mask_pattern_fwd =
+          OptionalBitcast(m::Convert(m::MinimumAnyOrder(
+              m::Op(),
+              OptionalBitcast(m::MinimumAnyOrder(
+                  m::Op(), m::Broadcast(OptionalBitcast(causal_mask)))))));
+      return Match(actual_mask, causal_mask_pattern_fwd);
+    }
+    return true;
+  }
+  return false;
+}
+
+MatchFwdResult MatchSoftmaxDropoutBmm(MatchFwdResult previous_result,
+                                      int64_t bmm2_operand_position,
+                                      HloInstruction* instr) {
+  // Matches the dropout-softmax subpattern.
+  // Softmax_output is a divide.
+  // Dropout can take multiple forms; based on heuristics we capture three
+  // forms here.
+  // Form 1 -> softmax - mul - select(dropout) - BMM2
+  MatchFwdResult match_result = previous_result;
+  HloInstruction* softmax_reduce_sum;
+  HloInstruction* softmax_reduce_sum_bcast;
+  HloInstruction* bmm_2;
+  HloInstruction* softmax_input;
+  HloInstruction* dropout = nullptr;
+  auto dropout_softmax_pattern_form_1 = m::Select(
+      m::Op(),
+      OptionalConvert(m::MultiplyAnyOrder(
+          OptionalBitcast(OptionalReshape(
+              OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern(
+                  &softmax_input, &softmax_reduce_sum,
+                  &softmax_reduce_sum_bcast)))),
+          m::Broadcast(
+              OptionalConvert(m::Constant(&dropout).WithPredicate(IsScalar))))),
+      m::Op());
+
+  // Form 2 -> softmax - mul - BMM2
+  //                     /
+  //                    /
+  //                 select(dropout)
+  auto dropout_softmax_pattern_form_2 =
+      OptionalBitcast(OptionalBitcast(OptionalConvert(m::MultiplyAnyOrder(
+          OptionalReshape(OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern(
+              &softmax_input, &softmax_reduce_sum, &softmax_reduce_sum_bcast))),
+          m::Broadcast(
+              OptionalConvert(OptionalBitcast(OptionalReshape(m::Select(
+                  m::Op(),
+                  m::Broadcast(m::Constant(&dropout).WithPredicate(IsScalar)),
+                  m::Op())))))))));
+
+  // Form 3 -> softmax - mul(dropout) - mul(scale) - BMM2
+  auto dropout_softmax_pattern_form_3 = m::MultiplyAnyOrder(
+      m::MultiplyAnyOrder(
+          OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern(
+              &softmax_input, &softmax_reduce_sum, &softmax_reduce_sum_bcast)),
+          m::Op()),
+      m::Broadcast(m::Constant(&dropout).WithPredicate(IsScalar)));
+
+  // Try matching BMM1 - (Scale) - (Bias) - Softmax - (Dropout) - BMM2.
+  // Dropout with a non-zero drop rate shows up as
+  // select(divide(softmax_output, broadcast(1 - dropout_rate))).
+  auto softmax_dropout_bmm2_pattern =
+      m::Op(&bmm_2)
+          .WithPredicate(IsBatchedMatmul)
+          .WithOperand(bmm2_operand_position,
+                       m::AnyOf<HloInstruction>(
+                           OptionalBitcast(OptionalConvert(
+                               GetUnfusedReduceMaxSumSoftmaxPattern(
+                                   &softmax_input, &softmax_reduce_sum,
+                                   &softmax_reduce_sum_bcast))),
+                           dropout_softmax_pattern_form_1,
+                           dropout_softmax_pattern_form_2,
+                           dropout_softmax_pattern_form_3));
+
+  if (!Match(instr, softmax_dropout_bmm2_pattern) ||
+      !IsSupportedPrimitiveType(bmm_2)) {
+    match_result.has_match = false;
+    return match_result;
+  }
+  if (softmax_reduce_sum->users()[0]->opcode() == HloOpcode::kConvert) {
+    softmax_reduce_sum = softmax_reduce_sum->users()[0];
+  }
+  match_result.is_training = softmax_reduce_sum->user_count() == 2 &&
+                             softmax_reduce_sum_bcast->user_count() == 2;
+  match_result.matched_bmm_2 = bmm_2;
+  if (dropout) {
+    match_result.matched_dropout_rate = GetDropoutRateFromHlo(dropout);
+  }
+  match_result.matched_softmax_input = softmax_input;
+  match_result.matched_reduce_sum = softmax_reduce_sum;
+  match_result.has_match = true;
+  return match_result;
+}
+
+MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result,
+                                               HloInstruction* softmax_input,
+                                               bool has_dropout) {
+  MatchFwdResult match_result = previous_result;
+  HloInstruction* bmm_1;
+  HloInstruction* bias = nullptr;
+  HloInstruction* scale = nullptr;
+  // bmm1/scale/bias-add should have 2 users if it is connected to softmax;
+  // otherwise it should have exactly 1 user.
+  auto first_bmm_pattern =
+      m::SharedSubpattern(m::Op(&bmm_1).WithPredicate(IsBatchedMatmul));
+  auto unfused_scaled_bmm_subpattern = m::MultiplyAnyOrder(
+      OptionalConvert(first_bmm_pattern.WithOneUse()),
+      OptionalConvert(
+          m::Broadcast(m::Constant(&scale).WithPredicate(IsScalar))));
+  if (Match(softmax_input,
+            OptionalConvert(OptionalBitcast(m::AnyOf<HloInstruction>(
+                first_bmm_pattern, unfused_scaled_bmm_subpattern))))) {
+    // bmm1 - (scale) - softmax
+    match_result.matched_bmm_1 = bmm_1;
+    match_result.matched_scale = scale;
+    match_result.matched_custom_call_name =
+        has_dropout ? kCudnnfMHASoftmaxDropoutCallTarget
+                    : kCudnnfMHASoftmaxCallTarget;
+    match_result.has_match = true;
+  } else if (Match(softmax_input,
+                   OptionalBitcast(m::AddAnyOrder(
+                       OptionalConvert(OptionalBitcast(m::AnyOf<HloInstruction>(
+                           unfused_scaled_bmm_subpattern.WithOneUse(),
+                           first_bmm_pattern.WithOneUse()))),
+                       m::Op(&bias))))) {
+    // bmm1 - (scale) - bias - softmax
+    match_result.matched_bmm_1 = bmm_1;
+    match_result.matched_scale = scale;
+    match_result.matched_custom_call_name =
+        has_dropout ? kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget
+                    : kCudnnfMHAScaleBiasSoftmaxCallTarget;
+    match_result.is_causal_mask |= IsCausalMaskPattern(bias);
+    if (!match_result.is_causal_mask &&
+        bias->opcode() == HloOpcode::kBroadcast) {
+      // we can take the bias before broadcast
+      auto dims = Cast<HloBroadcastInstruction>(bias)->dimensions();
+      if (dims == std::vector<int64_t>{2, 3} ||
+          dims == std::vector<int64_t>{0, 2, 3} ||
+          dims == std::vector<int64_t>{1, 2, 3}) {
+        // shapes [1, 1, s, s], [b, 1, s, s], [1, h, s, s] are supported
+        HloInstruction* bias_bc = bias->mutable_operand(0);
+        // bitcast bias_before_broadcast to be 4D
+        std::vector<int64_t> bitcast_dims(bias->shape().rank(), 1);
+        for (int dim : dims) {
+          bitcast_dims[dim] = bias->shape().dimensions()[dim];
+        }
+        bias = bias_bc->AddInstruction(HloInstruction::CreateBitcast(
+            ShapeUtil::MakeShape(bias->shape().element_type(), bitcast_dims),
+            bias_bc));
+      }
+    }
+    match_result.matched_bias = bias;
+    match_result.has_match = true;
+  } else {
+    match_result.has_match = false;
+  }
+  return match_result;
+}
+
+// We will try to match all the patterns below:
+// BMM1 - Scale - bias - Softmax - Dropout - BMM2
+// BMM1 - Scale - bias - Softmax - BMM2
+// BMM1 - Softmax - Dropout - BMM2
+// BMM1 - Softmax - BMM2
+MatchFwdResult MatchFwdMHAPatternsForCanonicalization(HloInstruction* instr) {
+  // We need to match 2 general cases:
+  // 1. bmm1 --> (intermediate nodes) --> bmm2 <-- V matrix
+  // 2. V matrix --> bmm2 <-- (intermediate nodes) <-- bmm1
+  // to determine if we need to canonicalize bmm2.
+  // So we go through both of bmm2's operands and see which one matches our
+  // desired patterns; if operand 1 consumes them, we need to canonicalize.
+  MatchFwdResult match_result;
+  for (auto bmm2_operand_pos : {0, 1}) {
+    if (bmm2_operand_pos == 1) {
+      match_result.need_canonicalization = true;
+    }
+
+    bool has_dropout = false;
+    // We first check if bmm2 is connected to a softmax or dropout.
+    // If so, we set softmax input and dropout rate to their corresponding
+    // values.
+    match_result =
+        MatchSoftmaxDropoutBmm(match_result, bmm2_operand_pos, instr);
+    if (!match_result.has_match) {
+      continue;
+    }
+    has_dropout = match_result.matched_dropout_rate > 0.0;
+    match_result = MatchBmm1UnfusedBiasSoftmaxBmm2(
+        match_result, match_result.matched_softmax_input, has_dropout);
+    if (match_result.has_match) {
+      return match_result;
+    }
+  }
+  // Didn't find any match
+  match_result.need_canonicalization = false;
+  return match_result;
+}
+
+bool IsBmm2GradGemm2(HloInstruction* instr) {
+  // Check whether the input bmm is bmm2's gradient gemm2. It needs to have
+  // either:
+  // 1. one user, in the dropout case, or
+  // 2. two users, in all other cases.
+  return (instr->user_count() == 1) || (instr->user_count() == 2);
+}
+
+MatchBwdResult MatchBmm1GradGemm1(MatchBwdResult previous_result,
+                                  HloInstruction* bmm_1) {
+  MatchBwdResult match_result = previous_result;
+  match_result.has_match = false;
+  const HloInstruction* q_tensor = bmm_1->operand(0);
+  for (int64_t i = 0; i < q_tensor->user_count(); i++) {
+    HloInstruction* q_tensor_user_i = q_tensor->users()[i];
+    if (IsBatchedMatmul(q_tensor_user_i) && q_tensor_user_i != bmm_1) {
+      match_result.matched_bmm_1_grad_1 = q_tensor_user_i;
+      // Check for canonicalization.
+      if (match_result.matched_bmm_1_grad_1->operand_index(q_tensor) != 1) {
+        match_result.bmm_1_grad_1_need_canonicalization = true;
+      }
+      match_result.has_match = true;
+    }
+  }
+  return match_result;
+}
+
+MatchBwdResult MatchBmm1GradGemm2(MatchBwdResult previous_result,
+                                  HloInstruction* fwd_fmha_call) {
+  HloInstruction* bmm_1_grad_2 = nullptr;
+  MatchBwdResult match_result = previous_result;
+  match_result.has_match = false;
+  // bmm1 gradient gemm2 shares the same input d_s as bmm1 gradient gemm1.
+  // Check whether bmm1 grad gemm1 needs canonicalization; if not, the shared
+  // input is the first operand.
+  int64_t d_s_index = match_result.bmm_1_grad_1_need_canonicalization ? 1 : 0;
+  HloInstruction* d_s_user_0 = match_result.matched_bmm_1_grad_1;
+
+  HloInstruction* d_s = d_s_user_0->mutable_operand(d_s_index);
+  if (d_s->opcode() == HloOpcode::kBitcast && d_s->user_count() == 1) {
+    d_s = d_s->mutable_operand(0);
+  }
+
+  auto bmm_1_grad_2_it = std::find_if(
+      d_s->users().begin(), d_s->users().end(), [&](HloInstruction* instr) {
+        return instr != match_result.matched_bmm_1_grad_1 &&
+               instr->opcode() == HloOpcode::kDot;
+      });
+  if (bmm_1_grad_2_it != d_s->users().end()) {
+    bmm_1_grad_2 = *bmm_1_grad_2_it;
+  } else {
+    return match_result;
+  }
+
+  match_result.matched_bmm_1_grad_2 = bmm_1_grad_2;
+
+  if (match_result.matched_bmm_1_grad_2->operand_index(d_s) != 0) {
+    match_result.bmm_1_grad_2_need_canonicalization = true;
+  }
+  match_result.has_match = true;
+  return match_result;
+}
+
+MatchBwdResult MatchBmm2GradGemm1(HloInstruction* fwd_fmha_call) {
+  HloInstruction* bmm_2_grad_1 = nullptr;
+  MatchBwdResult matched_result;
+  // The second GTE of the forward MHA call is the input of bmm2's gradient
+  // gemm1; we check whether the current gemm satisfies the above condition.
+  int64_t activation_out_gte_index = 1;
+  if (fwd_fmha_call->user_count() < 2 ||
+      fwd_fmha_call->users()[activation_out_gte_index]->opcode() !=
+          HloOpcode::kGetTupleElement ||
+      fwd_fmha_call->users()[activation_out_gte_index]->user_count() > 1 ||
+      !IsBatchedMatmul(
+          fwd_fmha_call->users()[activation_out_gte_index]->users()[0])) {
+    matched_result.has_match = false;
+    return matched_result;
+  }
+  // Found fmha->GTE->gemm, assign it to bmm_2_grad_1 and check to see if it
+  // needs canonicalization.
+  bmm_2_grad_1 = fwd_fmha_call->users()[activation_out_gte_index]->users()[0];
+  matched_result.matched_bmm_2_grad_1 = bmm_2_grad_1;
+  if (bmm_2_grad_1->operand_index(
+          fwd_fmha_call->users()[activation_out_gte_index]) != 0) {
+    matched_result.bmm_2_grad_1_need_canonicalization = true;
+  }
+
+  matched_result.has_match = true;
+  return matched_result;
+}
+
+MatchBwdResult MatchBmm2GradGemm2(MatchBwdResult previous_result,
+                                  HloInstruction* fwd_fmha_call,
+                                  bool v_transposed) {
+  MatchBwdResult match_result = previous_result;
+  match_result.has_match = false;
+  // If v tensor is transposed by forward fmha call, then we need to take fmha v
+  // input's producer's producer.
+  const HloInstruction* v_tensor = v_transposed
+                                       ? fwd_fmha_call->operand(2)->operand(0)
+                                       : fwd_fmha_call->operand(2);
+  for (int64_t i = 0; i < v_tensor->user_count(); i++) {
+    HloInstruction* v_tensor_user_i = v_tensor->users()[i];
+    if (IsBatchedMatmul(v_tensor_user_i) && IsBmm2GradGemm2(v_tensor_user_i)) {
+      match_result.matched_bmm_2_grad_2 = v_tensor_user_i;
+      // Check for canonicalization.
+      if (match_result.matched_bmm_2_grad_2->operand_index(v_tensor) != 1) {
+        match_result.bmm_2_grad_2_need_canonicalization = true;
+      }
+      match_result.has_match = true;
+    }
+  }
+  return match_result;
+}
+
+MatchBwdResult MatchDbias(MatchBwdResult previous_result,
+                          HloInstruction* d_intermediate,
+                          const absl::flat_hash_set<HloInstruction*> users) {
+  MatchBwdResult match_result = previous_result;
+  auto user_count = d_intermediate->user_count();
+  HloInstruction* dbias_user = nullptr;
+  HloInstruction* dbias = nullptr;
+  for (auto user : d_intermediate->users()) {
+    if (users.contains(user)) {
+      user_count -= 1;
+    } else {
+      dbias_user = user;
+    }
+  }
+  auto ConsumeExtraConvert = [](HloInstruction* instr) {
+    Match(instr->users()[0], m::Convert(&instr, m::Op()).WithOneUse());
+    return true;
+  };
+  // user_count == 1 && (reduce-> {convert} ->bitcast)
+  match_result.has_match =
+      user_count == 1 &&
+      Match(dbias_user, m::Reduce(&dbias, m::Op(), m::Op()).WithOneUse()) &&
+      dbias->shape().rank() == 3 && ConsumeExtraConvert(dbias);
+  if (match_result.has_match) {
+    // cuDNN only supports dbias for [1, h, s, s]
+    // make sure reduce dimension is on batch dim
+    auto reduce_dim = dbias->dimensions();
+    if (reduce_dim.size() == 1 && reduce_dim[0] == 0) {
+      match_result.matched_dbias = dbias;
+    } else {
+      match_result.has_match = false;
+    }
+  }
+  return match_result;
+}
+
+MatchBwdResult MatchBwdBmmSoftmaxDropoutBmm(MatchBwdResult previous_result,
+                                            HloInstruction* fwd_fmha_call) {
+  MatchBwdResult match_result = previous_result;
+  bool is_bmm1_grad1_canonicalized =
+      match_result.bmm_1_grad_1_need_canonicalization;
+  match_result.has_match = false;
+  bool has_scale = false;
+  bool has_dropout = false;
+  // Backward dropout pattern
+  // select(mask, bmm2_grad2, broadcast())
+  auto bwd_dropout_pattern_form_1 = m::SharedSubpattern(
+      OptionalBitcast(OptionalReshape(OptionalConvert(m::Select(
+          m::Op(), m::Op().WithPredicate([&](const HloInstruction* instr) {
+            return instr == match_result.matched_bmm_2_grad_2;
+          }),
+          m::Broadcast(
+              OptionalConvert(m::Constant().WithPredicate(IsScalar))))))));
+
+  // multiply(bmm2_grad2, broadcast(select(mask, broadcast(), op())))
+  auto bwd_dropout_pattern_form_2 =
+      m::SharedSubpattern(OptionalBitcast(m::MultiplyAnyOrder(
+          OptionalConvert(
+              m::Op().WithPredicate([&](const HloInstruction* instr) {
+                return instr == match_result.matched_bmm_2_grad_2;
+              })),
+          m::Broadcast(OptionalConvert(OptionalBitcast(OptionalReshape(
+              m::Select(m::Op(),
+                        m::Broadcast(OptionalConvert(
+                            m::Constant().WithPredicate(IsScalar))),
+                        m::Op()))))))));
+  auto bwd_dropout_pattern_form_3 = OptionalConvert(m::MultiplyAnyOrder(
+      m::MultiplyAnyOrder(
+          m::Op().WithPredicate([&](const HloInstruction* instr) {
+            return instr == match_result.matched_bmm_2_grad_2;
+          }),
+          m::Broadcast(m::Constant().WithPredicate(IsScalar))),
+      m::Op()));
+  auto bwd_dropout_pattern = m::AnyOf<HloInstruction>(
+      bwd_dropout_pattern_form_1, bwd_dropout_pattern_form_2,
+      bwd_dropout_pattern_form_3);
+  // Backward softmax pattern
+  HloInstruction* bwd_softmax_input = nullptr;
+  HloInstruction* exp_1;
+  HloInstruction* exp_2;
+  HloInstruction* d_softmax;
+
+  // d_softmax = exp * (dy / s_b - sum(dy * exp * 1 / s^2))
+  // There can be at most 3 users of d_softmax: bmm1grad1, bmm1grad2 and
+  // dbias.
+  auto bwd_softmax_pattern = OptionalBitcast(OptionalConvert(
+      m::MultiplyAnyOrder(
+          &d_softmax,
+          m::AddAnyOrder(
+              m::Divide().WithOneUse(),
+              m::Broadcast(OptionalBitcast(OptionalConvert(
+                  m::Negate(
+                      OptionalBitcast(
+                          m::Op()
+                              .WithPredicate(IsReduceSum)
+                              .WithOneUse()
+                              .WithOperand(
+                                  0, OptionalBitcast(
+                                         m::MultiplyAnyOrder(
+                                             m::MultiplyAnyOrder(
+                                                 m::Op(&bwd_softmax_input),
+                                                 m::Broadcast())
+                                                 .WithOneUse(),
+                                             m::Exp(&exp_2, m::Op()))
+                                             .WithOneUse()))))
+                      .WithOneUse())))),
+          m::Exp(&exp_1, m::Op()))
+          .WithAtMostNumUser(3)));
+
+  // Backward scale input pattern
+  HloInstruction* bwd_scale_input = nullptr;
+  HloInstruction* bwd_scale = nullptr;
+
+  auto bwd_scale_pattern =
+      m::MultiplyAnyOrder(&bwd_scale, m::Op(&bwd_scale_input),
+                          m::Broadcast(m::Constant().WithPredicate(IsScalar)))
+          .WithNumUser(2);
+  int intermediate_input_pos = is_bmm1_grad1_canonicalized ? 1 : 0;
+
+  HloInstruction* intermediate_input =
+      match_result.matched_bmm_1_grad_1->mutable_operand(
+          intermediate_input_pos);
+
+  has_scale = Match(intermediate_input, bwd_scale_pattern);
+
+  if (has_scale) {
+    intermediate_input = bwd_scale_input;
+  }
+
+  if (!Match(intermediate_input, bwd_softmax_pattern) || exp_1 != exp_2) {
+    return match_result;
+  }
+  has_dropout = Match(bwd_softmax_input, bwd_dropout_pattern);
+  // If no dropout but softmax input is not coming from bmm2 gradient gemm 2,
+  // then it's not the pattern that we care about.
+  if (!has_dropout &&
+      !Match(bwd_softmax_input,
+             OptionalConvert((OptionalBitcast(
+                 m::Op().WithPredicate([&](const HloInstruction* instr) {
+                   return instr == match_result.matched_bmm_2_grad_2;
+                 })))))) {
+    return match_result;
+  }
+
+  if (has_dropout) {
+    // has bias
+    if (fwd_fmha_call->custom_call_target() ==
+        kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget)
+      match_result.matched_custom_call_name =
+          kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget;
+    // no bias
+    if (fwd_fmha_call->custom_call_target() ==
+        kCudnnfMHASoftmaxDropoutCallTarget)
+      match_result.matched_custom_call_name =
+          kCudnnfMHASoftmaxDropoutBackwardCallTarget;
+  } else {
+    // has bias
+    if (fwd_fmha_call->custom_call_target() ==
+        kCudnnfMHAScaleBiasSoftmaxCallTarget)
+      match_result.matched_custom_call_name =
+          kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget;
+    // no bias
+    if (fwd_fmha_call->custom_call_target() == kCudnnfMHASoftmaxCallTarget)
+      match_result.matched_custom_call_name =
+          kCudnnfMHASoftmaxBackwardCallTarget;
+  }
+  // try to pattern match dbias
+  HloInstruction* dS = d_softmax;
+  if (dS->users()[0]->opcode() == HloOpcode::kConvert) {
+    dS = dS->users()[0];
+  }
+  if (has_scale) {
+    // bmm1-(scale)-(bias)-softmax pattern: extra users could be dbias or the
+    // scale bwd.
+    if (dS->user_count() == 1) {
+      // no dbias
+      match_result.has_match = true;
+    } else if (dS->user_count() == 2) {
+      match_result = MatchDbias(match_result, dS, {bwd_scale});
+    } else {
+      match_result.has_match = false;
+    }
+  } else {
+    // bmm1-(bias)-softmax pattern: besides bmm1grad1 and bmm1grad2, the
+    // extra user could be dbias.
+    if (dS->user_count() == 2) {
+      match_result.has_match = true;
+    } else if (dS->user_count() == 3) {
+      match_result = MatchDbias(match_result, dS,
+                                {match_result.matched_bmm_1_grad_1,
+                                 match_result.matched_bmm_1_grad_2});
+    } else {
+      match_result.has_match = false;
+    }
+  }
+  return match_result;
+}
+
+// First, we look for the bmm2 gradient gemm1, which takes the activation
+// output from a forward fmha call.
+// Second, we look for the bmm2 gradient gemm2, which takes the v tensor as an
+// input. We take the v tensor from the third operand of the forward fmha
+// call. If the forward pass is canonicalized, we skip the additional
+// transpose in between.
+// Then we look for bmm1 gradient gemm1 by searching for gemms that share the
+// q tensor with the current fmha call.
+MatchBwdResult MatchBackwardBmms(HloInstruction* fwd_fmha_call,
+                                 HloInstruction* bmm_1, bool v_transposed) {
+  MatchBwdResult matched_result = MatchBmm2GradGemm1(fwd_fmha_call);
+  if (!matched_result.has_match) {
+    return matched_result;
+  }
+
+  matched_result =
+      MatchBmm2GradGemm2(matched_result, fwd_fmha_call, v_transposed);
+  if (!matched_result.has_match) {
+    return matched_result;
+  }
+
+  matched_result = MatchBmm1GradGemm1(matched_result, bmm_1);
+  if (!matched_result.has_match) {
+    return matched_result;
+  }
+
+  matched_result = MatchBmm1GradGemm2(matched_result, fwd_fmha_call);
+  if (!matched_result.has_match) {
+    return matched_result;
+  }
+  return matched_result;
+}
+
+// We will match the backward graphs for all forward patterns defined in
+// MatchFwdMHAPatternsForCanonicalization.
+MatchBwdResult MatchBwdMHAPatternsForCanonicalization(
+    HloInstruction* fwd_fmha_call, HloInstruction* bmm_1, bool v_transposed) {
+  MatchBwdResult match_result =
+      MatchBackwardBmms(fwd_fmha_call, bmm_1, v_transposed);
+  if (!match_result.has_match) {
+    return match_result;
+  }
+  match_result = MatchBwdBmmSoftmaxDropoutBmm(match_result, fwd_fmha_call);
+  return match_result;
+}
+
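+// Checks whether the matched attention block can be lowered to a cuDNN fMHA
+// custom call: supported dtypes (BF16/F16), rank-4 gemms, a consistent QKV
+// layout, and the flash attention constraints. If the bias is a causal mask,
+// the bias is dropped from `custom_call_name`.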
+absl::StatusOr<bool> IsMHABlockSupported(
+    HloInstruction* bmm_1, HloInstruction* bmm_2, bool need_canonicalization,
+    bool is_training, bool is_causal_mask, std::string& custom_call_name,
+    const DebugOptions& debug_options,
+    stream_executor::CudaComputeCapability cc,
+    stream_executor::dnn::VersionInfo cudnn_version) {
+  if (MHACallHasDropout(custom_call_name) &&
+      !debug_options.xla_gpu_fused_attention_use_cudnn_rng()) {
+    VLOG(3) << "Using CUDNN RNG for fused attention dropout is not enabled.\n";
+    return false;
+  }
+
+  // cuDNN 8.8 currently only supports BF16 and F16 data types.
+  if (!IsSupportedPrimitiveType(bmm_1) || !IsSupportedPrimitiveType(bmm_2)) {
+    if (VLOG_IS_ON(2)) {
+      VLOG(2) << "Unsupported primitive type for cuDNN MHA fusion:\n"
+              << bmm_1->ToString() << "\nOR\n"
+              << bmm_2->ToString() << "\n"
+              << "BF16 and F16 are the supported Dtypes.";
+    }
+    return false;
+  }
+
+  if (bmm_1->shape().rank() != 4 || bmm_2->shape().rank() != 4) {
+    if (VLOG_IS_ON(2)) {
+      VLOG(2) << "Unsupported bmm rank for cuDNN MHA fusion:\n"
+              << bmm_1->ToString() << "\nOR\n"
+              << bmm_2->ToString() << "\n"
+              << "Only bmm with rank 4 is supported.";
+    }
+    return false;
+  }
+
+  // get batch/num heads/sequence length/hidden dim from bmm1 and bmm2
+  // also make sure they are the same between bmm1 and bmm2
+  TF_ASSIGN_OR_RETURN(std::optional<QKVLayout> qkv_layout,
+                      GetQKVLayout(bmm_1, bmm_2, need_canonicalization));
+  if (!qkv_layout.has_value()) {
+    VLOG(2) << "bmm1 and bmm2 have different qkv layout.";
+    return false;
+  }
+
+  // check if matched attention block is supported by cuDNN flash attention.
+  TF_ASSIGN_OR_RETURN(
+      bool is_flash_attention,
+      IsFlashAttention(qkv_layout.value(), is_training, cc, cudnn_version));
+  if (is_flash_attention) {
+    if (is_causal_mask) {
+      // If the bias is a causal mask, switch to the corresponding bias-free
+      // custom-call name.
+      if (custom_call_name == kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget) {
+        custom_call_name = kCudnnfMHASoftmaxDropoutCallTarget;
+      } else if (custom_call_name == kCudnnfMHAScaleBiasSoftmaxCallTarget) {
+        custom_call_name = kCudnnfMHASoftmaxCallTarget;
+      }
+    }
+  }
+  return is_flash_attention;
+}
+
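+// Swaps the operands (and dot dimension numbers) of `bmm` and transposes the
+// result, so downstream rewriting can assume a canonical operand order.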
+absl::StatusOr<HloInstruction*> CanonicalizeBatchedGemmForcuDNNFMHA(
+    HloInstruction* bmm, HloComputation* comp) {
+  if (VLOG_IS_ON(3)) {
+    VLOG(3) << "Before FMHA Dot Cannonicalization: \n"
+            << comp->parent()->ToString();
+  }
+  HloInstruction* lhs_bmm = bmm->mutable_operand(0);
+  HloInstruction* rhs_bmm = bmm->mutable_operand(1);
+  const DotDimensionNumbers& dnums = bmm->dot_dimension_numbers();
+
+  int64_t rank = bmm->shape().dimensions_size();
+  std::vector<int64_t> perm(rank);
+  std::iota(perm.begin(), perm.end(), 0);
+  // Swap the non-contracting dims of BMM shape. By contract, the
+  // non-contracting dims in the output are the last two dimensions.
+  std::swap(perm[rank - 1], perm[rank - 2]);
+
+  DotDimensionNumbers new_dnums = dnums;
+  std::swap(*new_dnums.mutable_lhs_contracting_dimensions(),
+            *new_dnums.mutable_rhs_contracting_dimensions());
+  std::swap(*new_dnums.mutable_lhs_batch_dimensions(),
+            *new_dnums.mutable_rhs_batch_dimensions());
+  auto original_bmm_shape = bmm->shape();
+  HloInstruction* new_dot = comp->AddInstruction(HloInstruction::CreateDot(
+      ShapeUtil::MakeShape(original_bmm_shape.element_type(),
+                           Permute(original_bmm_shape.dimensions(), perm)),
+      /* lhs */ rhs_bmm, /* rhs */ lhs_bmm, new_dnums,
+      bmm->precision_config()));
+
+  TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+      bmm, HloInstruction::CreateTranspose(original_bmm_shape, new_dot, perm)));
+  if (VLOG_IS_ON(2)) {
+    VLOG(2) << "After FMHA Dot Cannonicalization: \n"
+            << comp->parent()->ToString();
+  }
+  return new_dot;
+}
+
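+// Makes the checked dimension (contracting or non-contracting, depending on
+// `should_contracting_be_fastest`) of the chosen bmm operand the fastest
+// moving (minor-most) dimension, inserting a transpose and updating the dot
+// dimension numbers if necessary. Returns the possibly-transposed operand.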
+absl::StatusOr<HloInstruction*> ChangeCheckedDimToFastest(
+    HloComputation* comp, HloInstruction* bmm, bool is_lhs,
+    bool should_contracting_be_fastest) {
+  const DotDimensionNumbers& dot_dims_bmm = bmm->dot_dimension_numbers();
+  DotDimensionNumbers new_dot_dims_bmm = dot_dims_bmm;
+  int64_t bmm_operand = is_lhs ? 0 : 1;
+  absl::Span<const int64_t> contracting_dims =
+      is_lhs ? dot_dims_bmm.lhs_contracting_dimensions()
+             : dot_dims_bmm.rhs_contracting_dimensions();
+  absl::Span<const int64_t> batch_dims =
+      is_lhs ? dot_dims_bmm.lhs_batch_dimensions()
+             : dot_dims_bmm.rhs_batch_dimensions();
+  absl::Span<const int64_t> lhs_minor_to_major_bmm =
+      bmm->operand(0)->shape().layout().minor_to_major();
+  absl::Span<const int64_t> rhs_minor_to_major_bmm =
+      bmm->operand(1)->shape().layout().minor_to_major();
+
+  absl::Span<const int64_t>& minor_to_major_to_check =
+      is_lhs ? lhs_minor_to_major_bmm : rhs_minor_to_major_bmm;
+
+  CHECK_EQ(contracting_dims.size(), 1);
+  TF_ASSIGN_OR_RETURN(std::vector<int64_t> non_contracting_dims,
+                      GetNonContractingDims(bmm->operand(bmm_operand)->shape(),
+                                            batch_dims, contracting_dims));
+  CHECK_EQ(non_contracting_dims.size(), 1);
+  HloInstruction* operand_bmm = bmm->mutable_operand(bmm_operand);
+  int64_t hidden_dim = should_contracting_be_fastest ? contracting_dims[0]
+                                                     : non_contracting_dims[0];
+  int64_t minor_dim = minor_to_major_to_check[0];
+  // If the hidden dim of the target operand is not the fastest moving
+  // dimension, make it so.
+  if (minor_dim != hidden_dim) {
+    std::vector<int64_t> perm(bmm->shape().dimensions_size());
+    std::iota(perm.begin(), perm.end(), 0);
+    std::swap(perm[hidden_dim], perm[minor_dim]);
+
+    if (is_lhs) {
+      new_dot_dims_bmm.set_lhs_contracting_dimensions(0,
+                                                      non_contracting_dims[0]);
+    } else {
+      new_dot_dims_bmm.set_rhs_contracting_dimensions(0,
+                                                      non_contracting_dims[0]);
+    }
+
+    operand_bmm = comp->AddInstruction(
+        HloInstruction::CreateTranspose(
+            ShapeUtil::MakeShapeWithDenseLayout(
+                bmm->shape().element_type(),
+                Permute(operand_bmm->shape().dimensions(), perm),
+                minor_to_major_to_check),
+            operand_bmm, perm),
+        &operand_bmm->metadata());
+    *((DynCast<HloDotInstruction>(bmm))->mutable_dot_dimension_numbers()) =
+        new_dot_dims_bmm;
+  }
+  return operand_bmm;
+}
+
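+// Replaces the matched forward attention subgraph with a cuDNN fMHA custom
+// call whose outputs are {O, fwd activation (training only), workspace}.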
+absl::StatusOr<HloInstruction*> FuseFwdMultiHeadedAttentionBlock(
+    HloComputation* comp, HloInstruction* bmm_1, HloInstruction* bmm_2,
+    HloInstruction* bias, HloInstruction* scale, HloInstruction* reduce_sum,
+    HloInstruction* softmax_input, double dropout_rate,
+    std::string& custom_call_name, stream_executor::CudaComputeCapability cc,
+    bool is_training, bool& changed, bool& v_transposed, bool is_causal_mask) {
+  double scale_value = 1.0;
+  HloInstruction* lhs_bmm1;
+  HloInstruction* rhs_bmm1;
+  HloInstruction* rhs_bmm2;
+  DotDimensionNumbers orig_bmm1_dot_dim = bmm_1->dot_dimension_numbers();
+  DotDimensionNumbers orig_bmm2_dot_dim = bmm_2->dot_dimension_numbers();
+
+  TF_ASSIGN_OR_RETURN(rhs_bmm1, ChangeCheckedDimToFastest(
+                                    comp, bmm_1, false /*is_lhs*/,
+                                    true /*should_contracting_be_fastest*/));
+  TF_ASSIGN_OR_RETURN(lhs_bmm1, ChangeCheckedDimToFastest(
+                                    comp, bmm_1, true /*is_lhs*/,
+                                    true /*should_contracting_be_fastest*/));
+
+  TF_ASSIGN_OR_RETURN(rhs_bmm2, ChangeCheckedDimToFastest(
+                                    comp, bmm_2, false /*is_lhs*/,
+                                    false /*should_contracting_be_fastest*/));
+
+  if (rhs_bmm2 != bmm_2->mutable_operand(1)) {
+    v_transposed = true;
+  }
+
+  GpuBackendConfig gpu_config;
+  CudnnfMHABackendConfig& fmha_config =
+      *gpu_config.mutable_cudnn_fmha_backend_config();
+
+  *fmha_config.mutable_bmm1_dot_dimension_numbers() =
+      bmm_1->dot_dimension_numbers();
+  *fmha_config.mutable_bmm2_dot_dimension_numbers() =
+      bmm_2->dot_dimension_numbers();
+
+  TF_RET_CHECK((dropout_rate >= 0.0 && dropout_rate <= 1.0));
+  // Restore original DotDimensionNumbers.
+  *((DynCast<HloDotInstruction>(bmm_1))->mutable_dot_dimension_numbers()) =
+      orig_bmm1_dot_dim;
+  *((DynCast<HloDotInstruction>(bmm_2))->mutable_dot_dimension_numbers()) =
+      orig_bmm2_dot_dim;
+
+  // If scale node is assigned, extract value from it.
+  if (scale != nullptr) {
+    std::optional<double> value = GetConstantValue(scale);
+    TF_RET_CHECK(value.has_value());
+    scale_value = *value;
+  }
+
+  fmha_config.set_fmha_scale(scale_value);
+  fmha_config.set_dropout_rate(dropout_rate);
+  // Set to an arbitrary seed for now; the seed is not exposed to XLA in the
+  // HLO graph.
+  // TODO Find a way to compute original seed from dropout keys.
+  fmha_config.set_seed(42);
+
+  *fmha_config.mutable_intermediate_tensor_shape() = bmm_1->shape().ToProto();
+  {
+    auto* algorithm = fmha_config.mutable_algorithm();
+    algorithm->set_algo_id(0);  // engine id
+    algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
+    std::vector<int64_t> knob_ids = /* {0, 1} */ {17, 24};
+    std::vector<int64_t> knob_vals = {1, 0};
+    for (int i = 0; i < knob_ids.size(); ++i) {
+      (*algorithm->mutable_tuning_knobs())[knob_ids[i]] = knob_vals[i];
+    }
+    algorithm->set_is_cudnn_frontend(true);
+    algorithm->mutable_workspace_size()->set_value(0);
+  }
+  // Set the mask type here: choose whether to generate the causal mask
+  // inside cuDNN attention or not.
+  fmha_config.set_mask_type(is_causal_mask ? CudnnfMHABackendConfig::CAUSAL
+                                           : CudnnfMHABackendConfig::NO_MASK);
+
+  const Shape& output_shape = bmm_2->shape();
+
+  Shape call_shape;
+  // Activation output is used by backward gemm.
+  HloInstruction* activation_output = nullptr;
+
+  // Output Order: {O, Fwd act*, workspace}
+  std::vector<Shape> output_shapes = {output_shape};
+  if (is_training) {
+    activation_output = bmm_2->mutable_operand(0);
+    // Sometimes the activation output is a bitcast; the actual activation is
+    // the other user of the producer of bmm_2's first operand.
+    if (activation_output->user_count() < 2 &&
+        activation_output->opcode() == HloOpcode::kBitcast) {
+      HloInstruction* producer = activation_output->mutable_operand(0);
+      TF_RET_CHECK(producer->user_count() == 2);
+      HloInstruction* bmm2_grad2_user =
+          producer->users()[0] == activation_output ? producer->users()[1]
+                                                    : producer->users()[0];
+      // might be (transpose) - bmm2_grad2
+      if (IsBatchedMatmul(bmm2_grad2_user)) {
+        activation_output = producer;
+      } else if (bmm2_grad2_user->opcode() == HloOpcode::kTranspose) {
+        activation_output = bmm2_grad2_user;
+      } else {
+        return Internal("Unexpected activation patterns");
+      }
+    }
+    // For flash attention, we should output the softmax stats to the bwd.
+    TF_RET_CHECK(reduce_sum != nullptr);
+    output_shapes.push_back(
+        ShapeUtil::MakeShape(F32, reduce_sum->shape().dimensions()));
+  }
+  output_shapes.push_back(ShapeUtil::MakeShape(U8, {0}));
+  call_shape = ShapeUtil::MakeTupleShape(output_shapes);
+
+  // Input Order: {Q, K, V, bias*}
+  std::vector<HloInstruction*> operands = {lhs_bmm1, rhs_bmm1, rhs_bmm2};
+  if (!is_causal_mask && bias != nullptr) {
+    HloInstruction* original_bias;
+    HloInstruction* original_broadcast;
+    // There will be cases where the bias is up-cast to a wider float type;
+    // we need to take the original bias node and broadcast it without
+    // converting.
+    if (Match(bias, m::Broadcast(
+                        &original_broadcast,
+                        m::Convert(
+                            m::Op(&original_bias)
+                                .WithPredicate([](const HloInstruction* instr) {
+                                  return instr->shape().element_type() == F16 ||
+                                         instr->shape().element_type() == BF16;
+                                }))
+                            .WithPredicate([](const HloInstruction* instr) {
+                              return instr->shape().element_type() == F32 ||
+                                     instr->shape().element_type() == F64;
+                            })))) {
+      absl::Span<const int64_t> original_bcast_dims =
+          (DynCast<HloBroadcastInstruction>(original_broadcast))->dimensions();
+      // This is to deal with cases like paxml where an extra dimension of 1 is
+      // added to the left of the tensor.
+      // TODO Make this logic more generic
+      absl::Span<const int64_t> original_broadcast_shape_dims =
+          original_broadcast->shape().dimensions();
+      int64_t starting_index = original_broadcast_shape_dims.size() == 5 &&
+                                       original_broadcast_shape_dims[0] == 1
+                                   ? 1
+                                   : 0;
+      std::vector<int64_t> bcast_dimensions;
+      for (auto& dim : original_bcast_dims) {
+        bcast_dimensions.push_back(dim - starting_index);
+      }
+
+      const Shape& bcast_shape = bmm_1->shape();
+      bias = comp->AddInstruction(HloInstruction::CreateBroadcast(
+          bcast_shape, original_bias, bcast_dimensions));
+    }
+    operands.push_back(bias);
+  }
+
+  HloInstruction* fmha_call =
+      comp->AddInstruction(HloInstruction::CreateCustomCall(
+          call_shape, operands, absl::string_view(custom_call_name)));
+  TF_RETURN_IF_ERROR(fmha_call->set_backend_config(gpu_config));
+  TF_RETURN_IF_ERROR(SetFMHAInstructionName(bmm_1->GetModule(), fmha_call));
+
+  TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+      bmm_2,
+      HloInstruction::CreateGetTupleElement(bmm_2->shape(), fmha_call, 0)));
+
+  if (activation_output) {
+    HloInstruction* activation_gte =
+        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+            activation_output->shape(), fmha_call, 1));
+    TF_RETURN_IF_ERROR(comp->ReplaceInstructionWithDifferentShape(
+                               activation_output, activation_gte,
+                               /*preserve_sharding=*/false,
+                               /*relay_control_dependency=*/false,
+                               /*remove_unused_operands=*/false)
+                           .status());
+  }
+
+  if (VLOG_IS_ON(2)) {
+    VLOG(2) << "After CudnnFusedMHARewriter: \n" << comp->parent()->ToString();
+  }
+  changed = true;
+  return fmha_call;
+}
+
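+// Replaces the matched backward attention subgraph with a cuDNN fMHA
+// backward custom call whose outputs are {dQ, dK, dV, dbias*, workspace}.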
+absl::StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
+    HloComputation* comp, HloInstruction* bmm_1_grad_1,
+    HloInstruction* bmm_1_grad_2, HloInstruction* bmm_2_grad_1,
+    HloInstruction* bmm_2_grad_2, HloInstruction* fwd_fmha_call,
+    HloInstruction* dbias, HloInstruction* bias,
+    std::string& bwd_custom_call_name) {
+  HloInstruction* rhs_bmm1_grad_gemm1;
+  HloInstruction* lhs_bmm1_grad_gemm2;
+  HloInstruction* rhs_bmm2_grad_gemm2;
+  HloInstruction* d_output_grad;
+
+  DotDimensionNumbers orig_bmm1_grad1_config =
+      bmm_1_grad_1->dot_dimension_numbers();
+  DotDimensionNumbers orig_bmm1_grad2_config =
+      bmm_1_grad_2->dot_dimension_numbers();
+  DotDimensionNumbers orig_bmm2_grad1_config =
+      bmm_2_grad_1->dot_dimension_numbers();
+  DotDimensionNumbers orig_bmm2_grad2_config =
+      bmm_2_grad_2->dot_dimension_numbers();
+
+  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                      fwd_fmha_call->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& fwd_config =
+      gpu_config.cudnn_fmha_backend_config();
+  bool is_causal_mask =
+      fwd_config.mask_type() == CudnnfMHABackendConfig::CAUSAL;
+  CudnnfMHABackendConfig bwd_fmha_config;
+  // Q tensor
+  TF_ASSIGN_OR_RETURN(
+      rhs_bmm1_grad_gemm1,
+      ChangeCheckedDimToFastest(comp, bmm_1_grad_1, false /*is_lhs*/,
+                                false /*should_contracting_be_fastest*/));
+  // K tensor
+  TF_ASSIGN_OR_RETURN(
+      lhs_bmm1_grad_gemm2,
+      ChangeCheckedDimToFastest(comp, bmm_1_grad_2, false /*is_lhs*/,
+                                false /*should_contracting_be_fastest*/));
+
+  // Forward activation
+  // softmax_stats
+  HloInstruction* fwd_act;
+  int64_t fwd_act_index = 1;
+  fwd_act = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+      fwd_fmha_call->shape().tuple_shapes(fwd_act_index), fwd_fmha_call,
+      fwd_act_index));
+
+  // V tensor
+  TF_ASSIGN_OR_RETURN(
+      rhs_bmm2_grad_gemm2,
+      ChangeCheckedDimToFastest(comp, bmm_2_grad_2, false /*is_lhs*/,
+                                true /*should_contracting_be_fastest*/));
+  // d output to bmm2_grad2
+  // Since d_o is the input of 2 bmms, we set the dim numbers using the
+  // constraint that the contracting dimension of the lhs of bmm_2_grad_2
+  // needs to be the fastest moving dimension.
+  TF_ASSIGN_OR_RETURN(
+      d_output_grad,
+      ChangeCheckedDimToFastest(comp, bmm_2_grad_2, true /*is_lhs*/,
+                                true /*should_contracting_be_fastest*/));
+  // d output to bmm2_grad1
+  // We don't use this value, but we call this to make sure the dot dimension
+  // numbers are set correctly.
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * bmm_2_grad_1_rhs,
+      ChangeCheckedDimToFastest(comp, bmm_2_grad_1, false /*is_lhs*/,
+                                false /*should_contracting_be_fastest*/));
+  (void)bmm_2_grad_1_rhs;
+  // Operand order: {Q, K, V, Fwd act, d_o, bias*, O*}
+  std::vector<HloInstruction*> operands = {
+      rhs_bmm1_grad_gemm1, lhs_bmm1_grad_gemm2, rhs_bmm2_grad_gemm2, fwd_act,
+      d_output_grad};
+
+  if (!is_causal_mask && bias) {
+    operands.push_back(bias);
+  }
+  // For flash attention, add the fwd output to the input list.
+  HloInstruction* fwd_output = nullptr;
+  for (auto user : fwd_fmha_call->users()) {
+    if (user->opcode() == HloOpcode::kGetTupleElement &&
+        user->tuple_index() == 0) {
+      fwd_output = user;
+    }
+  }
+  // should be able to find the instruction
+  TF_RET_CHECK(fwd_output != nullptr);
+  // Check that dO and O have the same layout, as required by cuDNN.
+  TF_RET_CHECK(fwd_output->shape() == d_output_grad->shape());
+  operands.push_back(fwd_output);
+
+  *bwd_fmha_config.mutable_bmm1_grad_gemm1_dot_dimension_numbers() =
+      bmm_1_grad_1->dot_dimension_numbers();
+  *bwd_fmha_config.mutable_bmm1_grad_gemm2_dot_dimension_numbers() =
+      bmm_1_grad_2->dot_dimension_numbers();
+  *bwd_fmha_config.mutable_bmm2_grad_gemm1_dot_dimension_numbers() =
+      bmm_2_grad_1->dot_dimension_numbers();
+  *bwd_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() =
+      bmm_2_grad_2->dot_dimension_numbers();
+
+  // Restore original DotDimensionNumbers
+  *((DynCast<HloDotInstruction>(bmm_1_grad_1))
+        ->mutable_dot_dimension_numbers()) = orig_bmm1_grad1_config;
+  *((DynCast<HloDotInstruction>(bmm_1_grad_2))
+        ->mutable_dot_dimension_numbers()) = orig_bmm1_grad2_config;
+  *((DynCast<HloDotInstruction>(bmm_2_grad_1))
+        ->mutable_dot_dimension_numbers()) = orig_bmm2_grad1_config;
+  *((DynCast<HloDotInstruction>(bmm_2_grad_2))
+        ->mutable_dot_dimension_numbers()) = orig_bmm2_grad2_config;
+
+  bwd_fmha_config.set_fmha_scale(fwd_config.fmha_scale());
+  bwd_fmha_config.set_dropout_rate(fwd_config.dropout_rate());
+  // Set to an arbitrary seed for now; the seed is not exposed to XLA in the
+  // HLO graph.
+  // TODO Find a way to compute original seed from dropout keys.
+  bwd_fmha_config.set_seed(fwd_config.seed());
+  bwd_fmha_config.set_mask_type(is_causal_mask
+                                    ? CudnnfMHABackendConfig::CAUSAL
+                                    : CudnnfMHABackendConfig::NO_MASK);
+
+  *bwd_fmha_config.mutable_intermediate_tensor_shape() =
+      fwd_config.intermediate_tensor_shape();
+  {
+    auto* algorithm = bwd_fmha_config.mutable_algorithm();
+    algorithm->set_algo_id(0);  // engine id
+    algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
+    std::vector<int64_t> knob_ids = /* {0, 1} */ {17, 24};
+    std::vector<int64_t> knob_vals = {1, 0};
+    for (int i = 0; i < knob_ids.size(); ++i) {
+      (*algorithm->mutable_tuning_knobs())[knob_ids[i]] = knob_vals[i];
+    }
+    algorithm->set_is_cudnn_frontend(true);
+    algorithm->mutable_workspace_size()->set_value(0);
+  }
+
+  // Output order:
+  // {dQ(bmm_1_grad_2), dK(bmm_1_grad_1), dV(bmm_2_grad_1), dbias*, workspace}
+  std::vector<Shape> output_shapes = {
+      bmm_1_grad_2->shape(), bmm_1_grad_1->shape(), bmm_2_grad_1->shape()};
+
+  if (dbias) {
+    // The cuDNN kernel only outputs dbias in the shape [1, num_heads, seq,
+    // seq], so we add a leading dimension of 1 to the existing dbias' shape.
+    std::vector<int64_t> dbias_shape_vector =
+        SpanToVector(dbias->shape().dimensions());
+    dbias_shape_vector.insert(dbias_shape_vector.begin(), 1);
+    Shape cudnn_dbias_shape =
+        ShapeUtil::MakeShape(dbias->shape().element_type(), dbias_shape_vector);
+    output_shapes.push_back(cudnn_dbias_shape);
+  }
+  // Reserved placeholder for workspace
+  output_shapes.push_back(ShapeUtil::MakeShape(U8, {0}));
+  Shape call_shape = ShapeUtil::MakeTupleShape(output_shapes);
+  HloInstruction* fmha_bwd_call =
+      comp->AddInstruction(HloInstruction::CreateCustomCall(
+          call_shape, operands, absl::string_view(bwd_custom_call_name)));
+  GpuBackendConfig bwd_gpu_config;
+  *bwd_gpu_config.mutable_cudnn_fmha_backend_config() = bwd_fmha_config;
+  TF_RETURN_IF_ERROR(fmha_bwd_call->set_backend_config(bwd_gpu_config));
+  TF_RETURN_IF_ERROR(
+      SetFMHAInstructionName(bmm_1_grad_1->GetModule(), fmha_bwd_call));
+
+  // Q gradient
+  TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+      bmm_1_grad_2, HloInstruction::CreateGetTupleElement(bmm_1_grad_2->shape(),
+                                                          fmha_bwd_call, 0)));
+  // K gradient
+  TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+      bmm_1_grad_1, HloInstruction::CreateGetTupleElement(bmm_1_grad_1->shape(),
+                                                          fmha_bwd_call, 1)));
+  // V gradient
+  TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+      bmm_2_grad_1, HloInstruction::CreateGetTupleElement(bmm_2_grad_1->shape(),
+                                                          fmha_bwd_call, 2)));
+
+  if (dbias) {
+    // Reshape the fmha dbias output to the original user's input shape.
+    // If the reshape doesn't involve physical data movement, the algebraic
+    // simplifier can change it to a no-op bitcast.
+    Shape original_shape = dbias->shape();
+    HloInstruction* dbias_user = dbias->users()[0];
+    HloInstruction* cudnn_dbias_output =
+        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+            output_shapes[3], fmha_bwd_call, 3));
+    HloInstruction* reshape_dbias = comp->AddInstruction(
+        HloInstruction::CreateReshape(original_shape, cudnn_dbias_output));
+    TF_RETURN_IF_ERROR(dbias_user->ReplaceOperandWith(
+        dbias_user->operand_index(dbias), reshape_dbias));
+
+    TF_RETURN_IF_ERROR(
+        comp->ReplaceInstructionWithDifferentShape(dbias, cudnn_dbias_output));
+  }
+  return true;
+}
+
+absl::Status RestoreFwdGraph(
+    HloComputation* comp, HloInstruction* fwd_fmha_call, HloInstruction* bmm2,
+    HloInstruction* activation, HloInstruction* original_bmm2_producer0,
+    HloInstruction* original_bmm2_producer1,
+    std::vector<HloInstruction*>& original_activation_producers,
+    bool bmm_2_need_canonicalization) {
+  // If the backward pattern is not matched, we need to restore the original
+  // graph structure by replacing the new GTEs added by the forward FMHA call
+  // with clones of the old activation and bmm2.
+  HloInstruction* output_gte = fwd_fmha_call->users()[0];
+  HloInstruction* activation_gte = fwd_fmha_call->users()[1];
+  std::string suffix = "fmha_no_match_clone";
+  HloInstruction* cloned_activation =
+      comp->AddInstruction(activation->CloneWithNewOperands(
+          activation->shape(), original_activation_producers, suffix));
+
+  // Since old activation is detached by forward FMHA rewrite, we need
+  // to use the newly cloned activation.
+  HloInstruction* lhs = activation == original_bmm2_producer0
+                            ? cloned_activation
+                            : original_bmm2_producer0;
+  HloInstruction* rhs = activation == original_bmm2_producer0
+                            ? original_bmm2_producer1
+                            : cloned_activation;
+  HloInstruction* cloned_bmm2 = comp->AddInstruction(
+      bmm2->CloneWithNewOperands(bmm2->shape(), {lhs, rhs}, suffix));
+  if (bmm_2_need_canonicalization) {
+    TF_RET_CHECK(output_gte->users()[0]->opcode() == HloOpcode::kTranspose);
+    TF_RETURN_IF_ERROR(
+        comp->ReplaceInstruction(output_gte->users()[0], cloned_bmm2));
+  } else {
+    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2));
+  }
+  TF_RETURN_IF_ERROR(
+      comp->ReplaceInstruction(activation_gte, cloned_activation));
+  return absl::OkStatus();
+}
+}  // namespace
+
+absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool any_changed = false;
+  // We use this set to keep track of all already-matched attention blocks.
+  absl::flat_hash_set<HloInstruction*> matched_bmm1;
+  for (HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    const DebugOptions& debug_options =
+        comp->parent()->config().debug_options();
+    const se::dnn::VersionInfo cudnn_version =
+        GetDnnVersionInfoOrDefault(stream_executor_, cudnn_version_);
+#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
+    // CUDA needs to be >= 12.0 for cuDNN to work with all supported hardware.
+    // Some cuDNN versions work with CUDA 11, but it is impractical for us to
+    // test those combinations so just disable them.
+    return false;
+#endif
+    if (!debug_options.xla_gpu_enable_cudnn_fmha() ||
+        !IsComputeCapabilityAndCudnnSupported(
+            compute_capability_, cudnn_version,
+            stream_executor::dnn::VersionInfo(9, 0, 0))) {
+      return false;
+    }
+    for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+      bool v_transposed = false;
+      bool changed = false;
+      MatchFwdResult matched_result =
+          MatchFwdMHAPatternsForCanonicalization(instr);
+      if (!matched_result.has_match) {
+        continue;
+      }
+      // We check the validity of the bmms here, before canonicalization, so
+      // that we don't modify the graph if MHA fusion is not possible.
+      TF_ASSIGN_OR_RETURN(
+          bool is_mha_module_supported,
+          IsMHABlockSupported(
+              matched_result.matched_bmm_1, matched_result.matched_bmm_2,
+              matched_result.need_canonicalization, matched_result.is_training,
+              matched_result.is_causal_mask,
+              matched_result.matched_custom_call_name, debug_options,
+              compute_capability_, cudnn_version));
+
+      if (!is_mha_module_supported) continue;
+      // If the activation has more than one user in non-training mode, we
+      // cannot rewrite the graph, so skip processing the rest.
+      HloInstruction* activation =
+          matched_result.need_canonicalization
+              ? matched_result.matched_bmm_2->mutable_operand(1)
+              : matched_result.matched_bmm_2->mutable_operand(0);
+      if (!matched_result.is_training && activation->user_count() > 1) {
+        VLOG(2)
+            << "Activation: " << activation->ToString()
+            << " cannot have more than one user in non-training mode. Skipping.";
+        continue;
+      }
+      HloInstruction* original_bmm2_producer0 =
+          matched_result.matched_bmm_2->mutable_operand(0);
+      HloInstruction* original_bmm2_producer1 =
+          matched_result.matched_bmm_2->mutable_operand(1);
+
+      HloInstruction* original_bmm2 = matched_result.matched_bmm_2;
+      std::vector<HloInstruction*> original_activation_producers;
+      for (HloInstruction* operand : activation->mutable_operands()) {
+        original_activation_producers.push_back(operand);
+      }
+      // Make sure that no attention block is matched and replaced twice.
+      if (!matched_bmm1.insert(matched_result.matched_bmm_1).second) {
+        continue;
+      }
+      // If we need to canonicalize the bmm, we will assign the newly
+      // canonicalized bmm to bmm_2.
+      if (matched_result.need_canonicalization) {
+        TF_ASSIGN_OR_RETURN(matched_result.matched_bmm_2,
+                            CanonicalizeBatchedGemmForcuDNNFMHA(
+                                matched_result.matched_bmm_2, comp));
+      }
+
+      // Fuse the bmms and intermediate nodes into an fMHA call; the fused
+      // call will replace bmm_2.
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * fwd_fmha_call,
+          FuseFwdMultiHeadedAttentionBlock(
+              comp, matched_result.matched_bmm_1, matched_result.matched_bmm_2,
+              matched_result.matched_bias, matched_result.matched_scale,
+              matched_result.matched_reduce_sum,
+              matched_result.matched_softmax_input,
+              matched_result.matched_dropout_rate,
+              matched_result.matched_custom_call_name, compute_capability_,
+              matched_result.is_training, changed, v_transposed,
+              matched_result.is_causal_mask));
+      any_changed |= changed;
+      if (matched_result.is_training) {
+        MatchBwdResult matched_bwd_result =
+            MatchBwdMHAPatternsForCanonicalization(
+                fwd_fmha_call, matched_result.matched_bmm_1, v_transposed);
+        if (!matched_bwd_result.has_match) {
+          VLOG(2) << "Backward pattern not matching, skipping.";
+          // Restore the fwd graph if the bwd pattern match failed.
+          TF_RETURN_IF_ERROR(
+              RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
+                              original_bmm2_producer0, original_bmm2_producer1,
+                              original_activation_producers,
+                              matched_result.need_canonicalization));
+          continue;
+        }
+        if (matched_bwd_result.matched_dbias &&
+            !(compute_capability_.IsAtLeastHopper() &&
+              cudnn_version >= stream_executor::dnn::VersionInfo(9, 0, 0))) {
+          VLOG(2) << "Flash attention dbias requires cudnn 9.0.0 + hopper.";
+          // Restore the fwd graph since this dbias pattern is not supported.
+          TF_RETURN_IF_ERROR(
+              RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation,
+                              original_bmm2_producer0, original_bmm2_producer1,
+                              original_activation_producers,
+                              matched_result.need_canonicalization));
+          continue;
+        }
+        // Canonicalize gemms
+        if (matched_bwd_result.bmm_1_grad_1_need_canonicalization) {
+          TF_ASSIGN_OR_RETURN(
+              matched_bwd_result.matched_bmm_1_grad_1,
+              CanonicalizeBatchedGemmForcuDNNFMHA(
+                  matched_bwd_result.matched_bmm_1_grad_1, comp));
+        }
+        if (matched_bwd_result.bmm_1_grad_2_need_canonicalization) {
+          TF_ASSIGN_OR_RETURN(
+              matched_bwd_result.matched_bmm_1_grad_2,
+              CanonicalizeBatchedGemmForcuDNNFMHA(
+                  matched_bwd_result.matched_bmm_1_grad_2, comp));
+        }
+        if (matched_bwd_result.bmm_2_grad_1_need_canonicalization) {
+          TF_ASSIGN_OR_RETURN(
+              matched_bwd_result.matched_bmm_2_grad_1,
+              CanonicalizeBatchedGemmForcuDNNFMHA(
+                  matched_bwd_result.matched_bmm_2_grad_1, comp));
+        }
+        if (matched_bwd_result.bmm_2_grad_2_need_canonicalization) {
+          TF_ASSIGN_OR_RETURN(
+              matched_bwd_result.matched_bmm_2_grad_2,
+              CanonicalizeBatchedGemmForcuDNNFMHA(
+                  matched_bwd_result.matched_bmm_2_grad_2, comp));
+        }
+
+        // Fuse the corresponding gradient graph into an fMHA fused call.
+        TF_ASSIGN_OR_RETURN(
+            changed,
+            FuseBwdMultiHeadedAttentionBlock(
+                comp, matched_bwd_result.matched_bmm_1_grad_1,
+                matched_bwd_result.matched_bmm_1_grad_2,
+                matched_bwd_result.matched_bmm_2_grad_1,
+                matched_bwd_result.matched_bmm_2_grad_2, fwd_fmha_call,
+                matched_bwd_result.matched_dbias, matched_result.matched_bias,
+                matched_bwd_result.matched_custom_call_name));
+        any_changed |= changed;
+      }
+    }
+  }
+
+  return any_changed;
+}
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h
new file mode 100644
index 0000000..6a985ee
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h
@@ -0,0 +1,59 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/dnn.h"
+
+namespace xla {
+namespace gpu {
+
+class CudnnFusedMHARewriter : public HloModulePass {
+ public:
+  explicit CudnnFusedMHARewriter(se::CudaComputeCapability cc,
+                                 se::StreamExecutor* stream_executor)
+      : compute_capability_(cc), stream_executor_(stream_executor) {}
+
+  explicit CudnnFusedMHARewriter(se::CudaComputeCapability cc,
+                                 se::dnn::VersionInfo cudnn_version)
+      : compute_capability_(cc), cudnn_version_(cudnn_version) {}
+
+  absl::string_view name() const override {
+    return "cudnn-fused-multi-headed-attention-rewriter";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  const se::CudaComputeCapability compute_capability_;
+  se::StreamExecutor* stream_executor_ = nullptr;
+  const se::dnn::VersionInfo cudnn_version_;
+};
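+
+// A minimal usage sketch (illustrative only; the pipeline name and variables
+// below are assumptions, not part of this change): the pass is typically
+// appended to an HloPassPipeline, e.g.
+//
+//   HloPassPipeline pipeline("cudnn-fmha");
+//   pipeline.AddPass<CudnnFusedMHARewriter>(
+//       cuda_compute_capability, se::dnn::VersionInfo(9, 0, 0));
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));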
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc
new file mode 100644
index 0000000..a64fd062
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc
@@ -0,0 +1,3021 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h"
+
+#include <cstddef>
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/algebraic_simplifier.h"
+#include "xla/service/computation_layout.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h"
+#include "xla/service/hlo_cse.h"
+#include "xla/service/hlo_dce.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/hlo_verifier.h"
+#include "xla/service/layout_normalization.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/service/reshape_decomposer.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/test_helpers.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cudnn/cudnn.h"  // IWYU pragma: keep
+#endif
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = xla::match;
+
+class CudnnFusedMhaRewriterTestHloTest : public HloTestBase {
+ public:
+  se::CudaComputeCapability GetCudaComputeCapability() {
+    // Fake a supported compute capability to run the tests; we don't run any
+    // kernels in these tests, so they should be safe to run anywhere.
+    return se::CudaComputeCapability(8, 0);
+  }
+
+  se::CudaComputeCapability GetRealCudaComputeCapability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .cuda_compute_capability();
+  }
+
+  se::dnn::VersionInfo GetCudnnVersion() {
+    // Fake a supported cuDNN version to run the tests; we don't run any
+    // kernels in these tests, so they should be safe to run anywhere.
+    return se::dnn::VersionInfo(9, 0, 0);
+  }
+
+  CudnnFusedMhaRewriterTestHloTest()
+      : HloTestBase(/*verifier_layout_sensitive=*/false,
+                    /*allow_mixed_precision_in_hlo_verifier=*/false,
+                    /*instruction_can_change_layout_func=*/{}) {
+#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
+    skip_reason_ = "cuDNN fused MHA requires CUDA 12 or later.";
+    return;
+#endif
+  }
+
+ protected:
+  size_t CountFusedAttentionCall(HloModule* module, bool is_backward = false) {
+    return absl::c_count_if(module->entry_computation()->instructions(),
+                            [&](const HloInstruction* instr) {
+                              if (is_backward) {
+                                return IsBwdCustomCallTofMHA(*instr);
+                              } else {
+                                return IsFwdCustomCallTofMHA(*instr);
+                              }
+                            });
+  }
+
+  DebugOptions GetDebugOptionsForTest() override {
+    auto debug_options = HloTestBase::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_enable_cudnn_fmha(true);
+    debug_options.set_xla_gpu_fused_attention_use_cudnn_rng(true);
+    return debug_options;
+  }
+
+  HloModuleConfig GetModuleConfig() {
+    DebugOptions debug_options = GetDebugOptionsForTest();
+    HloModuleConfig config_with_fmha;
+    config_with_fmha.set_debug_options(debug_options);
+    return config_with_fmha;
+  }
+
+  // Centralize skip checks in the constructor. Unfortunately, we cannot call
+  // GTEST_SKIP from the constructor. Instead, we set `skip_reason_` if needed
+  // and then check it at the start of every test.
+  // An alternative would be to use the SetUp() override, but for that to be
+  // correct we'd have to ensure that all of the parents' SetUp() methods are
+  // called, which is error-prone.
+  std::optional<absl::string_view> skip_reason_;
+};
+
+constexpr absl::string_view
+    hlo_BF16Bmm1SoftmaxBmm2Pattern_k_hidden_not_most_minor = R"(
+HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
+
+region_0.7 {
+  Arg_0.8 = bf16[] parameter(0)
+  Arg_1.9 = bf16[] parameter(1)
+  ROOT maximum = bf16[] maximum(Arg_0.8, Arg_1.9)
+}
+
+region_1.19 {
+  Arg_0.20 = f32[] parameter(0)
+  Arg_1.21 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.20, Arg_1.21)
+}
+
+ENTRY main.6 {
+  Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
+  Arg_0.1 = bf16[16,16,256,64]{3,2,1,0} parameter(0)
+  Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1)
+  dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+  constant = bf16[] constant(-inf)
+  reduce.11 = bf16[16,16,256]{2,1,0} reduce(dot.0, constant), dimensions={3}, to_apply=region_0.7
+  broadcast.3 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
+  subtract.1 = bf16[16,16,256,256]{3,2,1,0} subtract(dot.0, broadcast.3)
+  exponential.1 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.1)
+  convert.1 = f32[16,16,256,256]{3,2,1,0} convert(exponential.1)
+  constant.1 = f32[] constant(0)
+  reduce.23 = f32[16,16,256]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
+  convert.2 = bf16[16,16,256]{2,1,0} convert(reduce.23)
+  broadcast.4 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
+  divide = bf16[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.4)
+  ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
+})";
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       BF16Bmm1SoftmaxBmm2Pattern_bmm1_rhs_contracting_dim_not_most_minor) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto m, ParseAndReturnVerifiedModule(
+                  hlo_BF16Bmm1SoftmaxBmm2Pattern_k_hidden_not_most_minor));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get()));
+  EXPECT_TRUE(result);
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+                     m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
+                     .WithShape(BF16, {16, 16, 256, 64})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+  EXPECT_EQ(config.bmm1_dot_dimension_numbers().rhs_contracting_dimensions()[0],
+            2);
+}
+
+constexpr absl::string_view
+    hlo_BF16Bmm1SoftmaxBmm2Pattern_q_hidden_not_most_minor = R"(
+HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
+
+region_0.7 {
+  Arg_0.8 = bf16[] parameter(0)
+  Arg_1.9 = bf16[] parameter(1)
+  ROOT maximum = bf16[] maximum(Arg_0.8, Arg_1.9)
+}
+
+region_1.19 {
+  Arg_0.20 = f32[] parameter(0)
+  Arg_1.21 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.20, Arg_1.21)
+}
+
+ENTRY main.6 {
+  Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
+  Arg_0.1 = bf16[16,16,256,64]{2,3,1,0} parameter(0)
+  Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1)
+  dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+  constant = bf16[] constant(-inf)
+  reduce.11 = bf16[16,16,256]{2,1,0} reduce(dot.0, constant), dimensions={3}, to_apply=region_0.7
+  broadcast.3 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
+  subtract.1 = bf16[16,16,256,256]{3,2,1,0} subtract(dot.0, broadcast.3)
+  exponential.1 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.1)
+  convert.1 = f32[16,16,256,256]{3,2,1,0} convert(exponential.1)
+  constant.1 = f32[] constant(0)
+  reduce.23 = f32[16,16,256]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
+  convert.2 = bf16[16,16,256]{2,1,0} convert(reduce.23)
+  broadcast.4 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
+  divide = bf16[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.4)
+  ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
+})";
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       BF16Bmm1SoftmaxBmm2Pattern_bmm1_lhs_contracting_dim_not_most_minor) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto m, ParseAndReturnVerifiedModule(
+                  hlo_BF16Bmm1SoftmaxBmm2Pattern_q_hidden_not_most_minor));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get()));
+  EXPECT_TRUE(result);
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+                     m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
+                     .WithShape(BF16, {16, 16, 256, 64})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+  EXPECT_EQ(config.bmm1_dot_dimension_numbers().lhs_contracting_dimensions()[0],
+            2);
+  EXPECT_EQ(config.bmm1_dot_dimension_numbers().rhs_contracting_dimensions()[0],
+            2);
+}
+
+constexpr absl::string_view
+    hlo_BF16Bmm1SoftmaxBmm2Pattern_v_hidden_dim_not_most_minor = R"(
+HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
+
+region_0.7 {
+  Arg_0.8 = bf16[] parameter(0)
+  Arg_1.9 = bf16[] parameter(1)
+  ROOT maximum = bf16[] maximum(Arg_0.8, Arg_1.9)
+}
+
+region_1.19 {
+  Arg_0.20 = f32[] parameter(0)
+  Arg_1.21 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.20, Arg_1.21)
+}
+
+ENTRY main.6 {
+  Arg_2.3 = bf16[16,16,256,64]{2,3,1,0} parameter(2)
+  Arg_0.1 = bf16[16,16,256,64]{2,3,1,0} parameter(0)
+  Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1)
+  dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+  constant = bf16[] constant(-inf)
+  reduce.11 = bf16[16,16,256]{2,1,0} reduce(dot.0, constant), dimensions={3}, to_apply=region_0.7
+  broadcast.3 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
+  subtract.1 = bf16[16,16,256,256]{3,2,1,0} subtract(dot.0, broadcast.3)
+  exponential.1 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.1)
+  convert.1 = f32[16,16,256,256]{3,2,1,0} convert(exponential.1)
+  constant.1 = f32[] constant(0)
+  reduce.23 = f32[16,16,256]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
+  convert.2 = bf16[16,16,256]{2,1,0} convert(reduce.23)
+  broadcast.4 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
+  divide = bf16[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.4)
+  ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
+})";
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       BF16Bmm1SoftmaxBmm2Pattern_bmm2_non_contracting_dim_not_most_minor) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto m, ParseAndReturnVerifiedModule(
+                  hlo_BF16Bmm1SoftmaxBmm2Pattern_v_hidden_dim_not_most_minor));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get()));
+  EXPECT_TRUE(result);
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+                     m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
+                     .WithShape(BF16, {16, 16, 256, 64})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+  EXPECT_EQ(config.bmm2_dot_dimension_numbers().lhs_contracting_dimensions()[0],
+            3);
+  EXPECT_EQ(config.bmm2_dot_dimension_numbers().rhs_contracting_dimensions()[0],
+            3);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1CombinedMaskBiasSoftmaxBmm2) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_,
+entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0},pred[16,1,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}}
+
+region_0.32.clone {
+  Arg_0.0 = f32[] parameter(0)
+  Arg_1.0 = f32[] parameter(1)
+  ROOT maximum.1 = f32[] maximum(Arg_0.0, Arg_1.0)
+}
+
+region_1.44 {
+  Arg_0.45 = f32[] parameter(0)
+  Arg_1.46 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.45, Arg_1.46)
+}
+
+ENTRY main.61 {
+  Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
+  transpose.5 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_2.3), dimensions={0,2,3,1}
+  Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
+  transpose.6 = bf16[16,16,256,64]{3,2,1,0} transpose(Arg_0.1), dimensions={0,2,1,3}
+  Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
+  transpose.7 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,3,1}
+  Arg_4.5 = pred[16,1,256,256]{3,2,1,0} parameter(4), sharding={replicated}
+  bitcast.35 = pred[16,256,256]{2,1,0} bitcast(Arg_4.5)
+  convert.49 = s32[16,256,256]{2,1,0} convert(bitcast.35)
+  constant.5 = s32[] constant(0)
+  broadcast.10 = s32[16,256,256]{2,1,0} broadcast(constant.5), dimensions={}
+  compare = pred[16,256,256]{2,1,0} compare(convert.49, broadcast.10), direction=GT
+  constant.7 = bf16[] constant(0)
+  broadcast.12 = bf16[16,256,256]{2,1,0} broadcast(constant.7), dimensions={}
+  constant.9 = bf16[] constant(-9.999e+09)
+  broadcast.13 = bf16[16,256,256]{2,1,0} broadcast(constant.9), dimensions={}
+  select = bf16[16,256,256]{2,1,0} select(compare, broadcast.12, broadcast.13)
+  convert.51 = f32[16,256,256]{2,1,0} convert(select)
+  broadcast.14 = f32[16,16,256,256]{3,2,1,0} broadcast(convert.51), dimensions={0,2,3}
+  Arg_3.4 = bf16[1,16,256,256]{3,2,1,0} parameter(3), sharding={replicated}
+  bitcast.52 = bf16[16,256,256]{2,1,0} bitcast(Arg_3.4)
+  convert.52 = f32[16,256,256]{2,1,0} convert(bitcast.52)
+  broadcast.15 = f32[16,16,256,256]{3,2,1,0} broadcast(convert.52), dimensions={1,2,3}
+  add.1 = f32[16,16,256,256]{3,2,1,0} add(broadcast.14, broadcast.15)
+  dot.2 = bf16[16,16,256,256]{3,2,1,0} dot(transpose.6, transpose.7), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+  convert.55 = f32[16,16,256,256]{3,2,1,0} convert(dot.2)
+  add.18 = f32[16,16,256,256]{3,2,1,0} add(convert.55, add.1)
+  constant.11 = f32[] constant(-inf)
+  reduce.36 = f32[16,16,256]{2,1,0} reduce(add.18, constant.11), dimensions={3}, to_apply=region_0.32.clone
+  broadcast.17 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
+  subtract.1 = f32[16,16,256,256]{3,2,1,0} subtract(add.18, broadcast.17)
+  exponential.1 = f32[16,16,256,256]{3,2,1,0} exponential(subtract.1)
+  constant.14 = f32[] constant(0)
+  reduce.48 = f32[16,16,256]{2,1,0} reduce(exponential.1, constant.14), dimensions={3}, to_apply=region_1.44
+  broadcast.18 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.48), dimensions={0,1,2}
+  divide = f32[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.18)
+  convert.68 = bf16[16,16,256,256]{3,2,1,0} convert(divide)
+  dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.5, convert.68), lhs_contracting_dims={3}, rhs_contracting_dims={3}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+  ROOT transpose.8 = bf16[16,256,16,64]{3,2,1,0} transpose(dot.1), dimensions={0,3,1,2}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Transpose(
+              m::Transpose(m::GetTupleElement(
+                  m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}),
+                  0)))
+              .WithShape(BF16, {16, 256, 16, 64})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+  EXPECT_EQ(fmha->operands().size(), 4);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1UnfusedSoftmaxBmm2) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->f16[2,6,40,64]{3,2,1,0}}
+
+region_0.7 {
+  Arg_0.8 = f16[] parameter(0)
+  Arg_1.9 = f16[] parameter(1)
+  ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9)
+}
+
+region_1.19 {
+  Arg_0.20 = f32[] parameter(0)
+  Arg_1.21 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.20, Arg_1.21)
+}
+
+ENTRY main.31 {
+  Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated}
+  Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated}
+  dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+  constant = f16[] constant(-inf)
+  reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7
+  broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
+  subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3)
+  exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1)
+  convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1)
+  constant.1 = f32[] constant(0)
+  reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
+  convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23)
+  broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
+  divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4)
+  Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated}
+  ROOT dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+                     m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
+                     .WithShape(F16, {2, 6, 40, 64})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+  EXPECT_FLOAT_EQ(config.fmha_scale(), 1.0);
+  EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0);
+  EXPECT_EQ(fmha->operands().size(), 3);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       BF16Bmm1ConvertedMaskAddedAfterFirstGemmSoftmaxBmm2) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},pred[16,1,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}}
+
+region_0.27.clone {
+  Arg_0.0 = f32[] parameter(0)
+  Arg_1.0 = f32[] parameter(1)
+  ROOT maximum.1 = f32[] maximum(Arg_0.0, Arg_1.0)
+}
+
+region_1.39 {
+  Arg_0.40 = f32[] parameter(0)
+  Arg_1.41 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.40, Arg_1.41)
+}
+
+ENTRY main.56 {
+  Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
+  transpose.5 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_2.3), dimensions={0,2,3,1}
+  Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
+  transpose.6 = bf16[16,16,256,64]{3,2,1,0} transpose(Arg_0.1), dimensions={0,2,1,3}
+  Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
+  transpose.7 = bf16[16,16,64,256]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,3,1}
+  dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose.6, transpose.7), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+  convert.47 = f32[16,16,256,256]{3,2,1,0} convert(dot)
+  Arg_3.4 = pred[16,1,256,256]{3,2,1,0} parameter(3), sharding={replicated}
+  bitcast.37 = pred[16,256,256]{2,1,0} bitcast(Arg_3.4)
+  convert.42 = s32[16,256,256]{2,1,0} convert(bitcast.37)
+  constant.6 = s32[] constant(0)
+  broadcast.9 = s32[16,256,256]{2,1,0} broadcast(constant.6), dimensions={}
+  compare = pred[16,256,256]{2,1,0} compare(convert.42, broadcast.9), direction=GT
+  constant.8 = bf16[] constant(0)
+  broadcast.11 = bf16[16,256,256]{2,1,0} broadcast(constant.8), dimensions={}
+  constant.10 = bf16[] constant(-9.999e+09)
+  broadcast.12 = bf16[16,256,256]{2,1,0} broadcast(constant.10), dimensions={}
+  select = bf16[16,256,256]{2,1,0} select(compare, broadcast.11, broadcast.12)
+  convert.48 = f32[16,256,256]{2,1,0} convert(select)
+  broadcast.14 = f32[16,16,256,256]{3,2,1,0} broadcast(convert.48), dimensions={0,2,3}
+  add.2 = f32[16,16,256,256]{3,2,1,0} add(convert.47, broadcast.14)
+  constant.13 = f32[] constant(-inf)
+  reduce.31 = f32[16,16,256]{2,1,0} reduce(add.2, constant.13), dimensions={3}, to_apply=region_0.27.clone
+  broadcast.16 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.31), dimensions={0,1,2}
+  subtract.1 = f32[16,16,256,256]{3,2,1,0} subtract(add.2, broadcast.16)
+  exponential.1 = f32[16,16,256,256]{3,2,1,0} exponential(subtract.1)
+  constant.14 = f32[] constant(0)
+  reduce.43 = f32[16,16,256]{2,1,0} reduce(exponential.1, constant.14), dimensions={3}, to_apply=region_1.39
+  broadcast.17 = f32[16,16,256,256]{3,2,1,0} broadcast(reduce.43), dimensions={0,1,2}
+  divide = f32[16,16,256,256]{3,2,1,0} divide(exponential.1, broadcast.17)
+  convert.63 = bf16[16,16,256,256]{3,2,1,0} convert(divide)
+  dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.5, convert.63), lhs_contracting_dims={3}, rhs_contracting_dims={3}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+  ROOT transpose.8 = bf16[16,256,16,64]{3,2,1,0} transpose(dot.1), dimensions={0,3,1,2}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Transpose(
+              m::Transpose(m::GetTupleElement(
+                  m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}),
+                  0)))
+              .WithShape(BF16, {16, 256, 16, 64})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+  EXPECT_EQ(fmha->operands().size(), 4);
+}
+
+// Negative test: the pattern should not be rewritten into an fMHA call.
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       BF16Bmm1Bmm2Pattern_bmm1_contracting_dim_not_equal_64) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}}
+ENTRY main.6 {
+  Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
+  Arg_0.1 = bf16[16,16,256,32]{3,2,1,0} parameter(0)
+  Arg_1.2 = bf16[16,16,256,32]{3,2,1,0} parameter(1)
+  dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+  ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
+}
+
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(&fmha, m::Dot(m::Parameter(0), m::Parameter(1)),
+                                m::Parameter(2))
+                             .WithShape(BF16, {16, 16, 256, 64})));
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       BF16Bmm1Bmm2Pattern_bmm2_rhs_non_contracting_dim_not_equal_64) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0})->bf16[16,16,256,32]{3,2,1,0}}
+ENTRY main.6 {
+  Arg_2.3 = bf16[16,16,256,32]{3,2,1,0} parameter(2)
+  Arg_0.1 = bf16[16,16,256,64]{3,2,1,0} parameter(0)
+  Arg_1.2 = bf16[16,16,256,64]{3,2,1,0} parameter(1)
+  dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+  ROOT dot.1 = bf16[16,16,256,32]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={}
+}
+
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(&fmha, m::Op(), m::Parameter(2))
+                             .WithShape(BF16, {16, 16, 256, 32})));
+}
+
+// Check that canonicalization does not kick in if MHA is unsupported.
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       BF16Bmm1Bmm2PatternUncanonicalized_bmm1_contracting_dim_not_equal_64) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,64,256]{3,2,1,0}}
+
+ENTRY main.6 {
+  Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2)
+  Arg_0.1 = bf16[16,16,256,32]{3,2,1,0} parameter(0)
+  Arg_1.2 = bf16[16,16,256,32]{3,2,1,0} parameter(1)
+  dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+  ROOT dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(Arg_2.3, dot.0), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={}
+}
+
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(&fmha, m::Parameter(2), m::Op())
+                             .WithShape(BF16, {16, 16, 64, 256})));
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1BiasSoftmaxDropoutBmm2) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}}
+
+region_0.34 {
+  Arg_0.35 = bf16[] parameter(0)
+  Arg_1.36 = bf16[] parameter(1)
+  ROOT maximum.37 = bf16[] maximum(Arg_0.35, Arg_1.36)
+}
+
+region_1.46 {
+  Arg_0.47 = f32[] parameter(0)
+  Arg_1.48 = f32[] parameter(1)
+  ROOT add.49 = f32[] add(Arg_0.47, Arg_1.48)
+}
+
+ENTRY main.82 {
+  Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
+  copy = bf16[16,256,16,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
+  transpose.2 = bf16[16,16,64,256]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
+  Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
+  copy.1 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
+  transpose = bf16[16,16,256,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
+  Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
+  copy.2 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
+  transpose.1 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
+  dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  Arg_3.4 = bf16[1,16,256,256]{3,2,1,0} parameter(3), sharding={replicated}
+  reshape.31 = bf16[16,256,256]{2,1,0} reshape(Arg_3.4)
+  broadcast.32 = bf16[16,16,256,256]{3,2,1,0} broadcast(reshape.31), dimensions={1,2,3}
+  add.33 = bf16[16,16,256,256]{3,2,1,0} add(dot, broadcast.32)
+  constant.21 = bf16[] constant(-inf)
+  reduce.38 = bf16[16,16,256]{2,1,0} reduce(add.33, constant.21), dimensions={3}, to_apply=region_0.34
+  broadcast.42 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.38), dimensions={0,1,2}
+  subtract.43 = bf16[16,16,256,256]{3,2,1,0} subtract(add.33, broadcast.42)
+  exponential.44 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.43)
+  convert.45 = f32[16,16,256,256]{3,2,1,0} convert(exponential.44)
+  constant.9 = f32[] constant(0)
+  reduce.50 = f32[16,16,256]{2,1,0} reduce(convert.45, constant.9), dimensions={3}, to_apply=region_1.46
+  convert.1 = bf16[16,16,256]{2,1,0} convert(reduce.50)
+  broadcast.55 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.1), dimensions={0,1,2}
+  divide.56 = bf16[16,16,256,256]{3,2,1,0} divide(exponential.44, broadcast.55)
+  constant.18 = u32[1]{0} constant({255383827})
+  constant.17 = u32[1]{0} constant({267815257})
+  constant.2 = u32[1]{0} constant({0})
+  constant.19 = u32[1]{0} constant({3213575472})
+  custom-call.26 = (u32[1]{0}, u32[1]{0}) custom-call(constant.18, constant.17, constant.2, constant.19), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
+  get-tuple-element.27 = u32[1]{0} get-tuple-element(custom-call.26), index=0
+  reshape.58 = u32[] reshape(get-tuple-element.27)
+  broadcast.62 = u32[32768]{0} broadcast(reshape.58), dimensions={}
+  get-tuple-element.28 = u32[1]{0} get-tuple-element(custom-call.26), index=1
+  reshape.59 = u32[] reshape(get-tuple-element.28)
+  broadcast.63 = u32[32768]{0} broadcast(reshape.59), dimensions={}
+  iota.57 = u32[65536]{0} iota(), iota_dimension=0
+  slice.60 = u32[32768]{0} slice(iota.57), slice={[0:32768]}
+  slice.61 = u32[32768]{0} slice(iota.57), slice={[32768:65536]}
+  custom-call.64 = (u32[32768]{0}, u32[32768]{0}) custom-call(broadcast.62, broadcast.63, slice.60, slice.61), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[32768]{0}, u32[32768]{0}, u32[32768]{0}, u32[32768]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\200\000\000\000\000\000\000"
+  get-tuple-element.65 = u32[32768]{0} get-tuple-element(custom-call.64), index=0
+  get-tuple-element.66 = u32[32768]{0} get-tuple-element(custom-call.64), index=1
+  concatenate.67 = u32[65536]{0} concatenate(get-tuple-element.65, get-tuple-element.66), dimensions={0}
+  constant.15 = u32[] constant(9)
+  broadcast.3 = u32[65536]{0} broadcast(constant.15), dimensions={}
+  shift-right-logical.0 = u32[65536]{0} shift-right-logical(concatenate.67, broadcast.3)
+  constant.13 = u32[] constant(1065353216)
+  broadcast.11 = u32[65536]{0} broadcast(constant.13), dimensions={}
+  or.0 = u32[65536]{0} or(shift-right-logical.0, broadcast.11)
+  bitcast-convert.0 = f32[65536]{0} bitcast-convert(or.0)
+  constant.3 = f32[] constant(-1)
+  broadcast.17 = f32[65536]{0} broadcast(constant.3), dimensions={}
+  add.1 = f32[65536]{0} add(bitcast-convert.0, broadcast.17)
+  broadcast.18 = f32[65536]{0} broadcast(constant.9), dimensions={}
+  maximum.0 = f32[65536]{0} maximum(add.1, broadcast.18)
+  constant.7 = f32[] constant(0.9)
+  broadcast.19 = f32[65536]{0} broadcast(constant.7), dimensions={}
+  compare.0 = pred[65536]{0} compare(maximum.0, broadcast.19), direction=LT
+  constant = bf16[] constant(1.109)
+  broadcast.20 = bf16[65536]{0} broadcast(constant), dimensions={}
+  constant.4 = bf16[] constant(0)
+  broadcast.21 = bf16[65536]{0} broadcast(constant.4), dimensions={}
+  select.1 = bf16[65536]{0} select(compare.0, broadcast.20, broadcast.21)
+  reshape.19 = bf16[16,16,256]{2,1,0} reshape(select.1)
+  broadcast.9 = bf16[16,16,256,256]{3,2,1,0} broadcast(reshape.19), dimensions={0,1,3}
+  multiply.79 = bf16[16,16,256,256]{3,2,1,0} multiply(divide.56, broadcast.9)
+  dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.2, multiply.79), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  transpose.81 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
+  ROOT copy.3 = bf16[16,256,16,64]{3,2,1,0} copy(transpose.81)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Copy(m::Transpose(m::Transpose(m::GetTupleElement(
+                      m::CustomCall(
+                          &fmha, {kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget}),
+                      0))))
+              .WithShape(BF16, {16, 256, 16, 64})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+  EXPECT_EQ(fmha->operands().size(), 4);
+  EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       BF16Bmm1ScaleBiasSoftmaxDropoutForm2Bmm2) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[32,40,60,64]{3,2,1,0},bf16[32,40,60,64]{3,2,1,0},bf16[32,40,60,64]{3,2,1,0})->bf16[32,40,60,64]{3,2,1,0}}, allow_spmd_sharding_propagation_to_output={true}
+
+region_0.29 {
+  Arg_0.30 = bf16[] parameter(0)
+  Arg_1.31 = bf16[] parameter(1)
+  ROOT maximum.32 = bf16[] maximum(Arg_0.30, Arg_1.31)
+}
+
+region_1.41 {
+  Arg_0.42 = f32[] parameter(0)
+  Arg_1.43 = f32[] parameter(1)
+  ROOT add.44 = f32[] add(Arg_0.42, Arg_1.43)
+}
+
+ENTRY main.79 {
+  Arg_2.3 = bf16[32,40,60,64]{3,2,1,0} parameter(2), sharding={replicated}
+  copy = bf16[32,40,60,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
+  transpose.2 = bf16[32,60,64,40]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
+  constant.19 = u32[1]{0} constant({2718843009})
+  constant.18 = u32[1]{0} constant({1272950319})
+  constant.2 = u32[1]{0} constant({0})
+  constant.20 = u32[1]{0} constant({2711844646})
+  custom-call.54 = (u32[1]{0}, u32[1]{0}) custom-call(constant.19, constant.18, constant.2, constant.20), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
+  get-tuple-element.55 = u32[1]{0} get-tuple-element(custom-call.54), index=0
+  reshape.58 = u32[] reshape(get-tuple-element.55)
+  broadcast.62 = u32[1536000]{0} broadcast(reshape.58), dimensions={}
+  get-tuple-element.56 = u32[1]{0} get-tuple-element(custom-call.54), index=1
+  reshape.59 = u32[] reshape(get-tuple-element.56)
+  broadcast.63 = u32[1536000]{0} broadcast(reshape.59), dimensions={}
+  iota.57 = u32[3072000]{0} iota(), iota_dimension=0
+  slice.60 = u32[1536000]{0} slice(iota.57), slice={[0:1536000]}
+  slice.61 = u32[1536000]{0} slice(iota.57), slice={[1536000:3072000]}
+  custom-call.64 = (u32[1536000]{0}, u32[1536000]{0}) custom-call(broadcast.62, broadcast.63, slice.60, slice.61), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1536000]{0}, u32[1536000]{0}, u32[1536000]{0}, u32[1536000]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000p\027\000\000\000\000\000"
+  get-tuple-element.65 = u32[1536000]{0} get-tuple-element(custom-call.64), index=0
+  get-tuple-element.66 = u32[1536000]{0} get-tuple-element(custom-call.64), index=1
+  concatenate.67 = u32[3072000]{0} concatenate(get-tuple-element.65, get-tuple-element.66), dimensions={0}
+  constant.16 = u32[] constant(9)
+  broadcast.2 = u32[3072000]{0} broadcast(constant.16), dimensions={}
+  shift-right-logical.0 = u32[3072000]{0} shift-right-logical(concatenate.67, broadcast.2)
+  constant.14 = u32[] constant(1065353216)
+  broadcast.6 = u32[3072000]{0} broadcast(constant.14), dimensions={}
+  or.0 = u32[3072000]{0} or(shift-right-logical.0, broadcast.6)
+  bitcast-convert.0 = f32[3072000]{0} bitcast-convert(or.0)
+  constant.3 = f32[] constant(-1)
+  broadcast.8 = f32[3072000]{0} broadcast(constant.3), dimensions={}
+  add.1 = f32[3072000]{0} add(bitcast-convert.0, broadcast.8)
+  constant.10 = f32[] constant(0)
+  broadcast.10 = f32[3072000]{0} broadcast(constant.10), dimensions={}
+  maximum.0 = f32[3072000]{0} maximum(add.1, broadcast.10)
+  constant.8 = f32[] constant(0.9)
+  broadcast.12 = f32[3072000]{0} broadcast(constant.8), dimensions={}
+  compare.0 = pred[3072000]{0} compare(maximum.0, broadcast.12), direction=LT
+  reshape.18 = pred[32,60,40,40]{3,2,1,0} reshape(compare.0)
+  Arg_0.1 = bf16[32,40,60,64]{3,2,1,0} parameter(0), sharding={replicated}
+  copy.1 = bf16[32,40,60,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
+  transpose = bf16[32,60,40,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
+  Arg_1.2 = bf16[32,40,60,64]{3,2,1,0} parameter(1), sharding={replicated}
+  copy.2 = bf16[32,40,60,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
+  transpose.1 = bf16[32,60,64,40]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
+  dot = bf16[32,60,40,40]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  constant.25 = bf16[] constant(1)
+  broadcast.26 = bf16[32,60,40,40]{3,2,1,0} broadcast(constant.25), dimensions={}
+  add.28 = bf16[32,60,40,40]{3,2,1,0} add(dot, broadcast.26)
+  constant.24 = bf16[] constant(-inf)
+  reduce.33 = bf16[32,60,40]{2,1,0} reduce(add.28, constant.24), dimensions={3}, to_apply=region_0.29
+  broadcast.37 = bf16[32,60,40,40]{3,2,1,0} broadcast(reduce.33), dimensions={0,1,2}
+  subtract.38 = bf16[32,60,40,40]{3,2,1,0} subtract(add.28, broadcast.37)
+  exponential.39 = bf16[32,60,40,40]{3,2,1,0} exponential(subtract.38)
+  convert.40 = f32[32,60,40,40]{3,2,1,0} convert(exponential.39)
+  reduce.45 = f32[32,60,40]{2,1,0} reduce(convert.40, constant.10), dimensions={3}, to_apply=region_1.41
+  convert.0 = bf16[32,60,40]{2,1,0} convert(reduce.45)
+  broadcast.50 = bf16[32,60,40,40]{3,2,1,0} broadcast(convert.0), dimensions={0,1,2}
+  divide.51 = bf16[32,60,40,40]{3,2,1,0} divide(exponential.39, broadcast.50)
+  constant = bf16[] constant(1.109)
+  broadcast.1 = bf16[32,60,40,40]{3,2,1,0} broadcast(constant), dimensions={}
+  multiply = bf16[32,60,40,40]{3,2,1,0} multiply(divide.51, broadcast.1)
+  constant.4 = bf16[] constant(0)
+  broadcast.5 = bf16[32,60,40,40]{3,2,1,0} broadcast(constant.4), dimensions={}
+  select.76 = bf16[32,60,40,40]{3,2,1,0} select(reshape.18, multiply, broadcast.5)
+  dot.1 = bf16[32,60,64,40]{3,2,1,0} dot(transpose.2, select.76), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  transpose.78 = bf16[32,40,60,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
+  ROOT copy.3 = bf16[32,40,60,64]{3,2,1,0} copy(transpose.78)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Copy(m::Transpose(m::Transpose(m::GetTupleElement(
+                      m::CustomCall(
+                          &fmha, {kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget}),
+                      0))))
+              .WithShape(BF16, {32, 40, 60, 64})));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+  EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2);
+  EXPECT_EQ(fmha->operands().size(), 4);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16TrainingBmm1Bmm2) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0})->(bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0})}
+
+ENTRY main.17 {
+  Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
+  copy = bf16[16,256,16,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
+  transpose.2 = bf16[16,16,64,256]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
+  Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
+  copy.1 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
+  transpose = bf16[16,16,256,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
+  Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
+  copy.2 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
+  transpose.1 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
+  dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.2, dot), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  transpose.7 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
+  Arg_3.4 = bf16[16,256,16,64]{3,2,1,0} parameter(3), sharding={replicated}
+  copy.3 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_3.4), sharding={replicated}
+  transpose.4 = bf16[16,16,256,64]{3,2,1,0} transpose(copy.3), dimensions={0,2,1,3}
+  dot.2 = bf16[16,16,256,256]{3,2,1,0} dot(transpose.4, transpose.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  copy.4 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_1.2), sharding={replicated}
+  transpose.12 = bf16[16,16,256,64]{3,2,1,0} transpose(copy.4), dimensions={0,2,1,3}
+  dot.4 = bf16[16,16,256,64]{3,2,1,0} dot(dot.2, transpose.12), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.15 = bf16[16,256,16,64]{3,1,2,0} transpose(dot.4), dimensions={0,2,1,3}
+  dot.3 = bf16[16,16,256,64]{3,2,1,0} dot(dot.2, transpose), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.13 = bf16[16,256,16,64]{3,1,2,0} transpose(dot.3), dimensions={0,2,1,3}
+  copy.5 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_3.4), sharding={replicated}
+  transpose.8 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.5), dimensions={0,2,3,1}
+  dot.10 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.8, dot), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.11 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.10), dimensions={0,3,1,2}
+  tuple.16 = (bf16[16,256,16,64]{1,3,2,0}, bf16[16,256,16,64]{3,1,2,0}, bf16[16,256,16,64]{3,1,2,0}, bf16[16,256,16,64]{1,3,2,0}) tuple(transpose.7, transpose.15, transpose.13, transpose.11)
+  get-tuple-element = bf16[16,256,16,64]{1,3,2,0} get-tuple-element(tuple.16), index=0
+  copy.6 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element)
+  get-tuple-element.1 = bf16[16,256,16,64]{3,1,2,0} get-tuple-element(tuple.16), index=1
+  copy.7 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.1)
+  get-tuple-element.2 = bf16[16,256,16,64]{3,1,2,0} get-tuple-element(tuple.16), index=2
+  copy.8 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.2)
+  get-tuple-element.3 = bf16[16,256,16,64]{1,3,2,0} get-tuple-element(tuple.16), index=3
+  copy.9 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.3)
+  ROOT tuple = (bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}) tuple(copy.6, copy.7, copy.8, copy.9)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
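+  // The rewriter is expected to make no change to this module.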
+  const auto status_or = RunHloPass(&fusedMhaRewriter, m.get());
+  TF_ASSERT_OK(status_or.status());
+  EXPECT_FALSE(status_or.value());
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16MiniT5xTest) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__lambda_, entry_computation_layout={(bf16[12,512,32,64]{3,2,1,0},bf16[12,512,2,32,64]{4,3,2,1,0},f32[12,512]{1,0},f32[12,512]{1,0})->(bf16[], bf16[12,512,32,64]{3,2,1,0}, bf16[12,512,2,32,64]{4,3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true}
+
+region_0.51 {
+  Arg_0.52 = bf16[] parameter(0)
+  Arg_1.53 = bf16[] parameter(1)
+  ROOT maximum.54 = bf16[] maximum(Arg_0.52, Arg_1.53)
+}
+
+region_1.63 {
+  Arg_0.64 = f32[] parameter(0)
+  Arg_1.65 = f32[] parameter(1)
+  ROOT add.66 = f32[] add(Arg_0.64, Arg_1.65)
+}
+
+region_3.99 {
+  Arg_0.100 = bf16[] parameter(0)
+  Arg_1.101 = bf16[] parameter(1)
+  ROOT add.102 = bf16[] add(Arg_0.100, Arg_1.101)
+}
+
+ENTRY main.129 {
+  Arg_1.2 = bf16[12,512,2,32,64]{4,3,2,1,0} parameter(1), sharding={replicated}
+  copy = bf16[12,512,2,32,64]{1,4,3,0,2} copy(Arg_1.2), sharding={replicated}
+  slice.42 = bf16[12,512,1,32,64]{1,4,3,0,2} slice(copy), slice={[0:12], [0:512], [1:2], [0:32], [0:64]}
+  reshape.44 = bf16[12,512,32,64]{1,3,2,0} reshape(slice.42)
+  transpose.5 = bf16[12,32,64,512]{3,2,1,0} transpose(reshape.44), dimensions={0,2,3,1}
+  Arg_0.1 = bf16[12,512,32,64]{3,2,1,0} parameter(0), sharding={replicated}
+  copy.1 = bf16[12,512,32,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
+  constant.5 = bf16[] constant(0.125)
+  broadcast.6 = bf16[12,512,32,64]{3,1,2,0} broadcast(constant.5), dimensions={}
+  multiply.45 = bf16[12,512,32,64]{3,1,2,0} multiply(copy.1, broadcast.6)
+  transpose = bf16[12,32,512,64]{3,2,1,0} transpose(multiply.45), dimensions={0,2,1,3}
+  copy.2 = bf16[12,512,2,32,64]{1,4,3,0,2} copy(Arg_1.2), sharding={replicated}
+  slice.41 = bf16[12,512,1,32,64]{1,4,3,0,2} slice(copy.2), slice={[0:12], [0:512], [0:1], [0:32], [0:64]}
+  reshape.43 = bf16[12,512,32,64]{1,3,2,0} reshape(slice.41)
+  transpose.1 = bf16[12,32,64,512]{3,2,1,0} transpose(reshape.43), dimensions={0,2,3,1}
+  dot = bf16[12,32,512,512]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  Arg_2.3 = f32[12,512]{1,0} parameter(2), sharding={replicated}
+  constant.14 = f32[] constant(0)
+  broadcast.19 = f32[12,512]{1,0} broadcast(constant.14), dimensions={}
+  compare.24 = pred[12,512]{1,0} compare(Arg_2.3, broadcast.19), direction=GT
+  broadcast.30 = pred[12,512,512]{2,1,0} broadcast(compare.24), dimensions={0,1}
+  Arg_3.4 = f32[12,512]{1,0} parameter(3), sharding={replicated}
+  compare.25 = pred[12,512]{1,0} compare(Arg_3.4, broadcast.19), direction=GT
+  broadcast.33 = pred[12,512,512]{2,1,0} broadcast(compare.25), dimensions={0,2}
+  and.34 = pred[12,512,512]{2,1,0} and(broadcast.30, broadcast.33)
+  convert.4 = s32[12,512,512]{2,1,0} convert(and.34)
+  constant.16 = s32[] constant(0)
+  broadcast.21 = s32[12,512,512]{2,1,0} broadcast(constant.16), dimensions={}
+  compare.0 = pred[12,512,512]{2,1,0} compare(convert.4, broadcast.21), direction=GT
+  constant.20 = bf16[] constant(0)
+  broadcast.22 = bf16[12,512,512]{2,1,0} broadcast(constant.20), dimensions={}
+  constant.11 = bf16[] constant(-9.999e+09)
+  broadcast.23 = bf16[12,512,512]{2,1,0} broadcast(constant.11), dimensions={}
+  select.0 = bf16[12,512,512]{2,1,0} select(compare.0, broadcast.22, broadcast.23)
+  broadcast.49 = bf16[12,32,512,512]{3,2,1,0} broadcast(select.0), dimensions={0,2,3}
+  add.50 = bf16[12,32,512,512]{3,2,1,0} add(dot, broadcast.49)
+  constant.22 = bf16[] constant(-inf)
+  reduce.55 = bf16[12,32,512]{2,1,0} reduce(add.50, constant.22), dimensions={3}, to_apply=region_0.51
+  broadcast.59 = bf16[12,32,512,512]{3,2,1,0} broadcast(reduce.55), dimensions={0,1,2}
+  subtract.60 = bf16[12,32,512,512]{3,2,1,0} subtract(add.50, broadcast.59)
+  exponential.61 = bf16[12,32,512,512]{3,2,1,0} exponential(subtract.60)
+  convert.62 = f32[12,32,512,512]{3,2,1,0} convert(exponential.61)
+  reduce.67 = f32[12,32,512]{2,1,0} reduce(convert.62, constant.14), dimensions={3}, to_apply=region_1.63
+  convert.5 = bf16[12,32,512]{2,1,0} convert(reduce.67)
+  broadcast.72 = bf16[12,32,512,512]{3,2,1,0} broadcast(convert.5), dimensions={0,1,2}
+  divide.73 = bf16[12,32,512,512]{3,2,1,0} divide(exponential.61, broadcast.72)
+  dot.1 = bf16[12,32,64,512]{3,2,1,0} dot(transpose.5, divide.73), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  convert.6 = f32[12,32,64,512]{3,2,1,0} convert(dot.1)
+  reduce.83 = f32[] reduce(convert.6, constant.14), dimensions={0,3,1,2}, to_apply=region_1.63
+  convert.84 = bf16[] convert(reduce.83)
+  constant.2 = bf16[] constant(0.0007935)
+  multiply.86 = bf16[] multiply(convert.84, constant.2)
+  broadcast.9 = bf16[12,32,512,64]{3,2,1,0} broadcast(constant.2), dimensions={}
+  dot.2 = bf16[12,32,512,512]{3,2,1,0} dot(broadcast.9, transpose.5), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  divide.109 = bf16[12,32,512,512]{3,2,1,0} divide(dot.2, broadcast.72)
+  constant.10 = bf16[] constant(1)
+  broadcast.24 = bf16[12,32,512]{2,1,0} broadcast(constant.10), dimensions={}
+  multiply.4 = bf16[12,32,512]{2,1,0} multiply(convert.5, convert.5)
+  divide.0 = bf16[12,32,512]{2,1,0} divide(broadcast.24, multiply.4)
+  broadcast.96 = bf16[12,32,512,512]{3,2,1,0} broadcast(divide.0), dimensions={0,1,2}
+  multiply.97 = bf16[12,32,512,512]{3,2,1,0} multiply(dot.2, broadcast.96)
+  multiply.98 = bf16[12,32,512,512]{3,2,1,0} multiply(multiply.97, exponential.61)
+  reduce.103 = bf16[12,32,512]{2,1,0} reduce(multiply.98, constant.20), dimensions={3}, to_apply=region_3.99
+  negate.0 = bf16[12,32,512]{2,1,0} negate(reduce.103)
+  broadcast.10 = bf16[12,32,512,512]{3,2,1,0} broadcast(negate.0), dimensions={0,1,2}
+  add.118 = bf16[12,32,512,512]{3,2,1,0} add(divide.109, broadcast.10)
+  multiply.119 = bf16[12,32,512,512]{3,2,1,0} multiply(add.118, exponential.61)
+  transpose.9 = bf16[12,32,512,64]{2,3,1,0} transpose(reshape.43), dimensions={0,2,1,3}
+  copy.3 = bf16[12,32,512,64]{3,2,1,0} copy(transpose.9)
+  dot.4 = bf16[12,32,512,64]{3,2,1,0} dot(multiply.119, copy.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  broadcast.12 = bf16[12,32,512,64]{3,2,1,0} broadcast(constant.5), dimensions={}
+  multiply.3 = bf16[12,32,512,64]{3,2,1,0} multiply(dot.4, broadcast.12)
+  transpose.11 = bf16[12,512,32,64]{3,1,2,0} transpose(multiply.3), dimensions={0,2,1,3}
+  broadcast.7 = bf16[12,32,64,512]{3,2,1,0} broadcast(constant.2), dimensions={}
+  dot.90 = bf16[12,32,64,512]{3,2,1,0} dot(broadcast.7, divide.73), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.91 = bf16[12,512,32,64]{1,3,2,0} transpose(dot.90), dimensions={0,3,1,2}
+  reshape.92 = bf16[12,512,1,32,64]{1,4,3,0,2} reshape(transpose.91)
+  pad.93 = bf16[12,512,2,32,64]{1,4,3,0,2} pad(reshape.92, constant.20), padding=0_0x0_0x1_0x0_0x0_0
+  dot.3 = bf16[12,32,512,64]{3,2,1,0} dot(multiply.119, transpose), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  copy.4 = bf16[12,32,512,64]{2,3,1,0} copy(dot.3)
+  transpose.121 = bf16[12,512,32,64]{1,3,2,0} transpose(copy.4), dimensions={0,2,1,3}
+  reshape.124 = bf16[12,512,1,32,64]{1,4,3,0,2} reshape(transpose.121)
+  pad.125 = bf16[12,512,2,32,64]{1,4,3,0,2} pad(reshape.124, constant.20), padding=0_0x0_0x0_1x0_0x0_0
+  add.126 = bf16[12,512,2,32,64]{1,4,3,0,2} add(pad.93, pad.125)
+  tuple.128 = (bf16[], bf16[12,512,32,64]{3,1,2,0}, bf16[12,512,2,32,64]{1,4,3,0,2}) tuple(multiply.86, transpose.11, add.126)
+  get-tuple-element = bf16[] get-tuple-element(tuple.128), index=0
+  get-tuple-element.1 = bf16[12,512,32,64]{3,1,2,0} get-tuple-element(tuple.128), index=1
+  copy.5 = bf16[12,512,32,64]{3,2,1,0} copy(get-tuple-element.1)
+  get-tuple-element.2 = bf16[12,512,2,32,64]{1,4,3,0,2} get-tuple-element(tuple.128), index=2
+  copy.6 = bf16[12,512,2,32,64]{4,3,2,1,0} copy(get-tuple-element.2)
+  ROOT tuple = (bf16[], bf16[12,512,32,64]{3,2,1,0}, bf16[12,512,2,32,64]{4,3,2,1,0}) tuple(get-tuple-element, copy.5, copy.6)
+}
+
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
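+  // Pre-process the module (reshape decomposition, layout normalization,
+  // CSE, algebraic simplification) before running the MHA rewriter.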
+  AlgebraicSimplifierOptions alg_sim_options;
+  alg_sim_options.set_supports_non_canonical_dots(false);
+  alg_sim_options.set_is_layout_sensitive(true);
+  alg_sim_options.set_enable_conv_operand_swap(false);
+  AlgebraicSimplifier alge_simp{alg_sim_options};
+  ReshapeDecomposer reshape_decomposer;
+  LayoutNormalization layout_normalizer;
+  HloCSE cse{/*is_layout_sensitive=*/true};
+  TF_ASSERT_OK(RunHloPass(&reshape_decomposer, m.get()).status());
+  TF_ASSERT_OK(RunHloPass(&layout_normalizer, m.get()).status());
+  TF_ASSERT_OK(RunHloPass(&cse, m.get()).status());
+  TF_ASSERT_OK(RunHloPass(&alge_simp, m.get()).status());
+
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+
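+  // Simplify again, fuse the surrounding transposes into the fMHA custom
+  // calls, then remove dead code.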
+  CudnnFusedMHATransposeFusion fmha_transpose_fusion;
+
+  HloDCE dce;
+  TF_ASSERT_OK(RunHloPass(&alge_simp, m.get()).status());
+  TF_ASSERT_OK(RunHloPass(&fmha_transpose_fusion, m.get()).status());
+
+  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
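+  // Both the forward and backward graphs should now be fused attention calls.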
+  EXPECT_EQ(CountFusedAttentionCall(m.get()), 1);
+  EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward=*/true), 1);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       ActivationHasMoreThan1UserShouldNotLower) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule test
+
+%region_50.2457 (Arg_0.2458: bf16[], Arg_1.2459: bf16[]) -> bf16[] {
+  %Arg_0.2458 = bf16[] parameter(0)
+  %Arg_1.2459 = bf16[] parameter(1)
+  ROOT %maximum.2 = bf16[] maximum(bf16[] %Arg_0.2458, bf16[] %Arg_1.2459)
+}
+
+%region_36.2316 (Arg_0.2317: f32[], Arg_1.2318: f32[]) -> f32[] {
+  %Arg_0.2317 = f32[] parameter(0)
+  %Arg_1.2318 = f32[] parameter(1)
+  ROOT %add.342 = f32[] add(f32[] %Arg_0.2317, f32[] %Arg_1.2318)
+}
+
+ENTRY main {
+  %transpose.482 = bf16[4,5,64]{2,1,0} parameter(0)
+  %transpose.484 = bf16[4,64,5]{2,1,0} parameter(1)
+  %dot.20 = bf16[4,5,5]{2,1,0} dot(bf16[4,5,64]{2,1,0} %transpose.482, bf16[4,64,5]{2,1,0} %transpose.484), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+  %constant.2515 = bf16[] constant(0.125)
+  %broadcast.789 = bf16[4,5,5]{2,1,0} broadcast(bf16[] %constant.2515), dimensions={}
+  %multiply.267 = bf16[4,5,5]{2,1,0} multiply(bf16[4,5,5]{2,1,0} %dot.20, bf16[4,5,5]{2,1,0} %broadcast.789)
+  %constant.287 = f32[] constant(-1)
+  %broadcast.792 = bf16[4,5,5]{2,1,0} parameter(3)
+  %add.348 = bf16[4,5,5]{2,1,0} add(bf16[4,5,5]{2,1,0} %multiply.267, bf16[4,5,5]{2,1,0} %broadcast.792)
+  %constant.2510 = bf16[] constant(-inf)
+  %reduce.2550 = bf16[4,5]{1,0} reduce(bf16[4,5,5]{2,1,0} %add.348, bf16[] %constant.2510), dimensions={2}, to_apply=%region_50.2457
+  %broadcast.793 = bf16[4,5,5]{2,1,0} broadcast(bf16[4,5]{1,0} %reduce.2550), dimensions={0,1}
+  %subtract.81 = bf16[4,5,5]{2,1,0} subtract(bf16[4,5,5]{2,1,0} %add.348, bf16[4,5,5]{2,1,0} %broadcast.793)
+  %exponential.21 = bf16[4,5,5]{2,1,0} exponential(bf16[4,5,5]{2,1,0} %subtract.81)
+  %convert.180 = f32[4,5,5]{2,1,0} convert(bf16[4,5,5]{2,1,0} %exponential.21)
+  %constant.2509 = f32[] constant(0)
+  %reduce.2558 = f32[4,5]{1,0} reduce(f32[4,5,5]{2,1,0} %convert.180, f32[] %constant.2509), dimensions={2}, to_apply=%region_36.2316
+  %convert.182 = bf16[4,5]{1,0} convert(f32[4,5]{1,0} %reduce.2558)
+  %broadcast.794 = bf16[4,5,5]{2,1,0} broadcast(bf16[4,5]{1,0} %convert.182), dimensions={0,1}
+  %divide.25 = bf16[4,5,5]{2,1,0} divide(bf16[4,5,5]{2,1,0} %exponential.21, bf16[4,5,5]{2,1,0} %broadcast.794)
+  %transpose.481 = bf16[4,64,5]{2,1,0} parameter(2)
+  %dot.21 = bf16[4,64,5]{2,1,0} dot(bf16[4,64,5]{2,1,0} %transpose.481, bf16[4,5,5]{2,1,0} %divide.25), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}
+  ROOT %tuple.2668 = (bf16[4,5,5]{2,1,0}, bf16[4,64,5]{2,1,0}) tuple(bf16[4,5,5]{2,1,0} %divide.25, bf16[4,64,5]{2,1,0} %dot.21)
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  HloDCE dce;
+  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+  HloVerifier verifier(/*layout_sensitive=*/false,
+                       /*allow_mixed_precision=*/true);
+  ASSERT_IS_OK(verifier.Run(m.get()).status());
+
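+  // The softmax activation has more than one user, so nothing should be
+  // rewritten into a fused attention call.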
+  EXPECT_EQ(CountFusedAttentionCall(m.get()), 0);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       F16InvalidTrainingBmm1ScaleBiasMaskSoftmaxBmm2ShouldNotBeLowered) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+
+region_0.21 {
+  Arg_0.22 = f16[] parameter(0)
+  Arg_1.23 = f16[] parameter(1)
+  ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
+}
+
+region_1.33 {
+  Arg_0.34 = f32[] parameter(0)
+  Arg_1.35 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.34, Arg_1.35)
+}
+
+region_2.55 {
+  Arg_0.56 = f16[] parameter(0)
+  Arg_1.57 = f16[] parameter(1)
+  ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
+}
+
+ENTRY main.82 {
+  constant.18 = pred[2,6,128,128]{3,2,1,0} constant({...})
+  Arg_0.1 = f16[2,6,128,64]{3,2,1,0} parameter(0), sharding={replicated}
+  Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
+  dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  constant.22 = f16[] constant(2)
+  broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
+  multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
+  constant.19 = f16[] constant(1)
+  broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
+  add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
+  constant.21 = f16[] constant(0)
+  broadcast.23 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.21), dimensions={}
+  select.1 = f16[2,6,128,128]{3,2,1,0} select(constant.18, add.3, broadcast.23)
+  constant.15 = f16[] constant(-inf)
+  reduce.25 = f16[2,6,128]{2,1,0} reduce(select.1, constant.15), dimensions={3}, to_apply=region_0.21
+  broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
+  subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(select.1, broadcast.17)
+  exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
+  convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
+  constant.17 = f32[] constant(0)
+  reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
+  convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
+  broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+  divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
+  Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
+  dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
+  dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
+  broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
+  multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
+  divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
+  broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+  multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
+  multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
+  reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
+  broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.59), dimensions={0,1,2}
+  add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
+  multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
+  select.3 = f16[2,6,128,128]{3,2,1,0} select(constant.18, multiply.8, broadcast.23)
+  multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(select.3, broadcast.24)
+  dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  HloDCE dce;
+  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  HloVerifier verifier(/*layout_sensitive=*/false,
+                       /*allow_mixed_precision=*/true);
+  ASSERT_IS_OK(verifier.Run(m.get()).status());
+
+  // The backward pattern in the graph is not a valid fMHA pattern, so we
+  // expect no rewrite to happen.
+  EXPECT_EQ(CountFusedAttentionCall(m.get()), 0);
+  EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward=*/true), 0);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       F16InvalidTrainingBmm1ScaleBiasMaskSoftmaxDropoutBmm2ShouldNotLower) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+
+region_0.38 {
+  Arg_0.39 = f16[] parameter(0)
+  Arg_1.40 = f16[] parameter(1)
+  ROOT maximum.1 = f16[] maximum(Arg_0.39, Arg_1.40)
+}
+
+region_1.50 {
+  Arg_0.51 = f32[] parameter(0)
+  Arg_1.52 = f32[] parameter(1)
+  ROOT add.2 = f32[] add(Arg_0.51, Arg_1.52)
+}
+
+region_2.99 {
+  Arg_0.100 = f16[] parameter(0)
+  Arg_1.101 = f16[] parameter(1)
+  ROOT add.3 = f16[] add(Arg_0.100, Arg_1.101)
+}
+
+ENTRY main.126 {
+  constant.6 = u32[1]{0} constant({2718843009})
+  constant.8 = u32[1]{0} constant({1272950319})
+  constant.10 = u32[1]{0} constant({0})
+  constant.12 = u32[1]{0} constant({2711844646})
+  custom-call.65 = (u32[1]{0}, u32[1]{0}) custom-call(constant.6, constant.8, constant.10, constant.12), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
+  get-tuple-element.66 = u32[1]{0} get-tuple-element(custom-call.65), index=0
+  bitcast.343 = u32[] bitcast(get-tuple-element.66)
+  broadcast.27 = u32[98304]{0} broadcast(bitcast.343), dimensions={}
+  get-tuple-element.67 = u32[1]{0} get-tuple-element(custom-call.65), index=1
+  bitcast.344 = u32[] bitcast(get-tuple-element.67)
+  broadcast.28 = u32[98304]{0} broadcast(bitcast.344), dimensions={}
+  iota.68 = u32[196608]{0} iota(), iota_dimension=0
+  slice = u32[98304]{0} slice(iota.68), slice={[0:98304]}
+  slice.1 = u32[98304]{0} slice(iota.68), slice={[98304:196608]}
+  custom-call.75 = (u32[98304]{0}, u32[98304]{0}) custom-call(broadcast.27, broadcast.28, slice, slice.1), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[98304]{0}, u32[98304]{0}, u32[98304]{0}, u32[98304]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\200\001\000\000\000\000\000"
+  get-tuple-element.76 = u32[98304]{0} get-tuple-element(custom-call.75), index=0
+  get-tuple-element.77 = u32[98304]{0} get-tuple-element(custom-call.75), index=1
+  concatenate.2 = u32[196608]{0} concatenate(get-tuple-element.76, get-tuple-element.77), dimensions={0}
+  constant.56 = u32[] constant(9)
+  broadcast.63 = u32[196608]{0} broadcast(constant.56), dimensions={}
+  shift-right-logical.3 = u32[196608]{0} shift-right-logical(concatenate.2, broadcast.63)
+  constant.57 = u32[] constant(1065353216)
+  broadcast.64 = u32[196608]{0} broadcast(constant.57), dimensions={}
+  or.3 = u32[196608]{0} or(shift-right-logical.3, broadcast.64)
+  bitcast-convert.3 = f32[196608]{0} bitcast-convert(or.3)
+  constant.58 = f32[] constant(-1)
+  broadcast.65 = f32[196608]{0} broadcast(constant.58), dimensions={}
+  add.10 = f32[196608]{0} add(bitcast-convert.3, broadcast.65)
+  constant.48 = f32[] constant(0)
+  broadcast.66 = f32[196608]{0} broadcast(constant.48), dimensions={}
+  maximum.4 = f32[196608]{0} maximum(add.10, broadcast.66)
+  constant.59 = f32[] constant(0.9)
+  broadcast.67 = f32[196608]{0} broadcast(constant.59), dimensions={}
+  compare.3 = pred[196608]{0} compare(maximum.4, broadcast.67), direction=LT
+  bitcast.308 = pred[2,6,128,128]{3,2,1,0} bitcast(compare.3)
+  constant.44 = pred[2,6,128,128]{3,2,1,0} constant({...})
+  Arg_0.1 = f16[2,6,128,64]{3,2,1,0} parameter(0), sharding={replicated}
+  Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
+  dot.34 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  constant.55 = f16[] constant(2)
+  broadcast.61 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.55), dimensions={}
+  multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(dot.34, broadcast.61)
+  constant.52 = f16[] constant(1)
+  broadcast.39 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.52), dimensions={}
+  add.6 = f16[2,6,128,128]{3,2,1,0} add(multiply.8, broadcast.39)
+  constant.54 = f16[] constant(0)
+  broadcast.52 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.54), dimensions={}
+  select.1 = f16[2,6,128,128]{3,2,1,0} select(constant.44, add.6, broadcast.52)
+  constant.41 = f16[] constant(-inf)
+  reduce.42 = f16[2,6,128]{2,1,0} reduce(select.1, constant.41), dimensions={3}, to_apply=region_0.38
+  broadcast.42 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.42), dimensions={0,1,2}
+  subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(select.1, broadcast.42)
+  exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
+  convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
+  reduce.54 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.48), dimensions={3}, to_apply=region_1.50
+  convert.9 = f16[2,6,128]{2,1,0} convert(reduce.54)
+  broadcast.68 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+  divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.68)
+  constant.60 = f16[] constant(1.1113)
+  broadcast.69 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.60), dimensions={}
+  multiply.20 = f16[2,6,128,128]{3,2,1,0} multiply(divide.5, broadcast.69)
+  select.8 = f16[2,6,128,128]{3,2,1,0} select(bitcast.308, multiply.20, broadcast.52)
+  Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
+  dot.88 = f16[2,6,128,64]{3,2,1,0} dot(select.8, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  bitcast.248 = pred[2,6,128,128]{3,2,1,0} bitcast(compare.3)
+  Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
+  dot.91 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  select.6 = f16[2,6,128,128]{3,2,1,0} select(bitcast.248, dot.91, broadcast.52)
+  multiply.17 = f16[2,6,128,128]{3,2,1,0} multiply(select.6, broadcast.69)
+  divide.4 = f16[2,6,128,128]{3,2,1,0} divide(multiply.17, broadcast.68)
+  broadcast.55 = f16[2,6,128]{2,1,0} broadcast(constant.52), dimensions={}
+  multiply.11 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
+  divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.55, multiply.11)
+  broadcast.56 = f16[2,6,128]{2,1,0} broadcast(constant.60), dimensions={}
+  multiply.12 = f16[2,6,128]{2,1,0} multiply(divide.3, broadcast.56)
+  broadcast.58 = f16[2,6,128,128]{3,2,1,0} broadcast(multiply.12), dimensions={0,1,2}
+  multiply.13 = f16[2,6,128,128]{3,2,1,0} multiply(select.6, broadcast.58)
+  multiply.14 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.13, exponential.1)
+  reduce.103 = f16[2,6,128]{2,1,0} reduce(multiply.14, constant.54), dimensions={3}, to_apply=region_2.99
+  broadcast.62 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.103), dimensions={0,1,2}
+  add.9 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.62)
+  multiply.18 = f16[2,6,128,128]{3,2,1,0} multiply(add.9, exponential.1)
+  select.7 = f16[2,6,128,128]{3,2,1,0} select(constant.44, multiply.18, broadcast.52)
+  multiply.19 = f16[2,6,128,128]{3,2,1,0} multiply(select.7, broadcast.61)
+  dot.124 = f16[2,6,128,64]{3,2,1,0} dot(multiply.19, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.19), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  dot.1 = f16[2,6,128,64]{3,2,1,0} dot(select.8, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  ROOT tuple.125 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.88, dot.124, dot, dot.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  HloDCE dce;
+  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  HloVerifier verifier(/*layout_sensitive=*/false,
+                       /*allow_mixed_precision=*/true);
+  ASSERT_IS_OK(verifier.Run(m.get()).status());
+
+  // The backward pattern in the graph is not a valid fMHA pattern, so we
+  // expect no rewrite to happen.
+  EXPECT_EQ(CountFusedAttentionCall(m.get()), 0);
+  EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward=*/true), 0);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       F16TrainingBmm1ScaleBiasSoftmaxBmm2QTranspose) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+
+region_0.21 {
+  Arg_0.22 = f16[] parameter(0)
+  Arg_1.23 = f16[] parameter(1)
+  ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
+}
+
+region_1.33 {
+  Arg_0.34 = f32[] parameter(0)
+  Arg_1.35 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.34, Arg_1.35)
+}
+
+region_2.55 {
+  Arg_0.56 = f16[] parameter(0)
+  Arg_1.57 = f16[] parameter(1)
+  ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
+}
+
+ENTRY main.82 {
+  Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated}
+  Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
+  dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  constant.22 = f16[] constant(2)
+  broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
+  multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
+  constant.19 = f16[] constant(1)
+  broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
+  add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
+  constant.21 = f16[] constant(0)
+  constant.15 = f16[] constant(-inf)
+  reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
+  broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
+  subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17)
+  exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
+  convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
+  constant.17 = f32[] constant(0)
+  reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
+  convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
+  broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+  divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
+  Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
+  dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
+  dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
+  broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
+  multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
+  divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
+  broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+  multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
+  multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
+  reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
+  negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59)
+  broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
+  add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
+  multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
+  multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24)
+  dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  HloDCE dce;
+  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  const HloInstruction* fmha = nullptr;
+
+  SCOPED_TRACE(m->ToString());
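+  // Forward and backward should each be rewritten into a scale-bias-softmax
+  // fMHA custom call.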
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::GetTupleElement(
+              m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0)
+              .WithShape(F16, {2, 6, 128, 64}),
+          m::GetTupleElement(
+              m::CustomCall(&fmha,
+                            {kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+              0)
+              .WithShape(F16, {2, 6, 128, 64}),
+          m::Transpose(
+              m::GetTupleElement(
+                  m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+                  1))
+              .WithShape(F16, {2, 6, 64, 128}),
+          m::GetTupleElement(
+              m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), 2)
+              .WithShape(F16, {2, 6, 128, 64}))));
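+  // The forward fMHA call should have 7 operands and a near-zero dropout
+  // rate in its backend config.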
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+  EXPECT_EQ(fmha->operands().size(), 7);
+  EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       F16Bmm1UnfusedSoftmaxBmm2IncorrectBmm1NumUsers) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->(f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0})}
+
+region_0.7 {
+  Arg_0.8 = f16[] parameter(0)
+  Arg_1.9 = f16[] parameter(1)
+  ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9)
+}
+
+region_1.19 {
+  Arg_0.20 = f32[] parameter(0)
+  Arg_1.21 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.20, Arg_1.21)
+}
+
+ENTRY main.31 {
+  Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated}
+  Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated}
+  dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+  // extra user of bmm1
+  neg.1 = f16[2,6,40,40]{3,2,1,0} negate(dot)
+  constant = f16[] constant(-inf)
+  reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7
+  broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
+  subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3)
+  exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1)
+  convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1)
+  constant.1 = f32[] constant(0)
+  reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
+  convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23)
+  broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
+  divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4)
+  Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated}
+  dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+  ROOT tuple.81 = (f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0}) tuple(dot.1, neg.1)
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
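+  // bmm1 has an extra user, so the graph should be left unfused.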
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::Dot(), m::Negate())));
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       F16Bmm1UnfusedSoftmaxBmm2IncorrectSoftmaxNumUsers) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->(f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0})}
+
+region_0.7 {
+  Arg_0.8 = f16[] parameter(0)
+  Arg_1.9 = f16[] parameter(1)
+  ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9)
+}
+
+region_1.19 {
+  Arg_0.20 = f32[] parameter(0)
+  Arg_1.21 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.20, Arg_1.21)
+}
+
+ENTRY main.31 {
+  Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated}
+  Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated}
+  dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+  constant = f16[] constant(-inf)
+  reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7
+  broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2}
+  subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3)
+  // extra user of the softmax subtract node
+  neg.1 = f16[2,6,40,40]{3,2,1,0} negate(subtract.1)
+  exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1)
+  convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1)
+  constant.1 = f32[] constant(0)
+  reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19
+  convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23)
+  broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2}
+  divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4)
+  Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated}
+  dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+  ROOT tuple.81 = (f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0}) tuple(dot.1, neg.1)
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+
+  SCOPED_TRACE(m->ToString());
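+  // The softmax subtract node has an extra user, so the graph should be left
+  // unfused.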
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::Dot(), m::Negate())));
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       F16TrainingBmm1ScaleBiasSoftmaxBmm2IncorrectSoftmaxBwdNumUsers) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+
+region_0.21 {
+  Arg_0.22 = f16[] parameter(0)
+  Arg_1.23 = f16[] parameter(1)
+  ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
+}
+
+region_1.33 {
+  Arg_0.34 = f32[] parameter(0)
+  Arg_1.35 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.34, Arg_1.35)
+}
+
+region_2.55 {
+  Arg_0.56 = f16[] parameter(0)
+  Arg_1.57 = f16[] parameter(1)
+  ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
+}
+
+ENTRY main.82 {
+  Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated}
+  Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
+  dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  constant.22 = f16[] constant(2)
+  broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
+  multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
+  constant.19 = f16[] constant(1)
+  broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
+  add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
+  constant.21 = f16[] constant(0)
+  constant.15 = f16[] constant(-inf)
+  reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
+  broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
+  subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17)
+  exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
+  convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
+  constant.17 = f32[] constant(0)
+  reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
+  convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
+  broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+  divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
+  Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
+  dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
+  dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
+  // extra user of the softmax backward divide node
+  neg.1 = f16[2,6,128,128]{3,2,1,0} negate(divide.4)
+  broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
+  multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
+  divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
+  broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+  multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
+  multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
+  reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
+  negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59)
+  broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
+  add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
+  multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
+  multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24)
+  dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1, neg.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  HloDCE dce;
+  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  SCOPED_TRACE(m->ToString());
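+  // The softmax backward divide node has an extra user, so the backward
+  // pattern is rejected and the dots are left unfused.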
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot(),
+                                  m::Negate())));
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1SoftmaxBmm2IncorrectRank) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule reproducer, entry_computation_layout={(f16[1,8,16,5,128]{4,3,2,1,0}, f16[1,8,16,5,128]{4,3,2,1,0}, f16[1,8,16,5,128]{4,3,2,1,0}, f32[128,2,64]{2,1,0}, f32[2,64]{1,0}, /*index=5*/f32[128,2,64]{2,1,0}, f32[2,64]{1,0}, f32[128,2,64]{2,1,0}, f32[2,64]{1,0})->f16[8,16,2,5,64]{4,3,2,1,0}}
+
+region_0.36 {
+  Arg_0.37 = f16[] parameter(0)
+  Arg_1.38 = f16[] parameter(1)
+  ROOT maximum = f16[] maximum(Arg_0.37, Arg_1.38)
+}
+
+region_1.48 {
+  Arg_0.49 = f32[] parameter(0)
+  Arg_1.50 = f32[] parameter(1)
+  ROOT add.1 = f32[] add(Arg_0.49, Arg_1.50)
+}
+
+ENTRY main {
+  arg2.3 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(2), parameter_replication={false}
+  bitcast.31 = f16[640,128]{1,0} bitcast(arg2.3)
+  arg5.6 = f32[128,2,64]{2,1,0} parameter(5), parameter_replication={false}
+  convert.3 = f16[128,2,64]{2,1,0} convert(arg5.6)
+  bitcast.36 = f16[128,128]{1,0} bitcast(convert.3)
+  dot = f16[640,128]{1,0} dot(bitcast.31, bitcast.36), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}
+  bitcast.39 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot)
+  transpose.27 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} transpose(bitcast.39), dimensions={0,1,2,4,3,5}, frontend_attributes={grad_x="false",grad_y="false"}
+  arg6.7 = f32[2,64]{1,0} parameter(6), parameter_replication={false}
+  convert.4 = f16[2,64]{1,0} convert(arg6.7)
+  broadcast.9 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} broadcast(convert.4), dimensions={3,5}
+  add.2 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} add(transpose.27, broadcast.9)
+  bitcast.49 = f16[8,16,2,5,64]{4,3,2,1,0} bitcast(add.2)
+  arg0.1 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(0), parameter_replication={false}
+  bitcast.53 = f16[640,128]{1,0} bitcast(arg0.1)
+  arg3.4 = f32[128,2,64]{2,1,0} parameter(3), parameter_replication={false}
+  convert.5 = f16[128,2,64]{2,1,0} convert(arg3.4)
+  bitcast.58 = f16[128,128]{1,0} bitcast(convert.5)
+  dot.1 = f16[640,128]{1,0} dot(bitcast.53, bitcast.58), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}
+  bitcast.61 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot.1)
+  transpose.28 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} transpose(bitcast.61), dimensions={0,1,2,4,5,3}, frontend_attributes={grad_x="false",grad_y="false"}
+  arg4.5 = f32[2,64]{1,0} parameter(4), parameter_replication={false}
+  convert.6 = f16[2,64]{1,0} convert(arg4.5)
+  broadcast.10 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} broadcast(convert.6), dimensions={3,4}
+  add.3 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} add(transpose.28, broadcast.10)
+  constant.29 = f16[] constant(0.125)
+  broadcast.11 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} broadcast(constant.29), dimensions={}
+  multiply = f16[1,8,16,2,64,5]{5,4,3,2,1,0} multiply(add.3, broadcast.11)
+  bitcast.74 = f16[8,16,2,64,5]{4,3,2,1,0} bitcast(multiply)
+  dot.6 = f16[8,16,2,5,5]{4,3,2,1,0} dot(bitcast.49, bitcast.74), lhs_batch_dims={0,1,2}, lhs_contracting_dims={4}, rhs_batch_dims={0,1,2}, rhs_contracting_dims={3}, frontend_attributes={grad_x="false",grad_y="false"}
+  constant.35 = f16[] constant(-inf)
+  reduce.1 = f16[8,16,2,5]{3,2,1,0} reduce(dot.6, constant.35), dimensions={3}, to_apply=region_0.36
+  broadcast.12 = f16[8,16,2,5,5]{4,3,2,1,0} broadcast(reduce.1), dimensions={0,1,2,4}
+  subtract.2 = f16[8,16,2,5,5]{4,3,2,1,0} subtract(dot.6, broadcast.12)
+  exponential.2 = f16[8,16,2,5,5]{4,3,2,1,0} exponential(subtract.2)
+  convert.7 = f32[8,16,2,5,5]{4,3,2,1,0} convert(exponential.2)
+  constant.34 = f32[] constant(0)
+  reduce.3 = f32[8,16,2,5]{3,2,1,0} reduce(convert.7, constant.34), dimensions={3}, to_apply=region_1.48
+  convert.8 = f16[8,16,2,5]{3,2,1,0} convert(reduce.3)
+  broadcast.13 = f16[8,16,2,5,5]{4,3,2,1,0} broadcast(convert.8), dimensions={0,1,2,4}
+  divide.2 = f16[8,16,2,5,5]{4,3,2,1,0} divide(exponential.2, broadcast.13)
+  bitcast.98 = f16[8,16,2,5,5]{3,4,2,1,0} bitcast(divide.2)
+  arg1.2 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(1), parameter_replication={false}
+  bitcast.102 = f16[640,128]{1,0} bitcast(arg1.2)
+  arg7.8 = f32[128,2,64]{2,1,0} parameter(7), parameter_replication={false}
+  convert.9 = f16[128,2,64]{2,1,0} convert(arg7.8)
+  bitcast.107 = f16[128,128]{1,0} bitcast(convert.9)
+  dot.3 = f16[640,128]{1,0} dot(bitcast.102, bitcast.107), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}
+  bitcast.110 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot.3)
+  transpose.30 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} transpose(bitcast.110), dimensions={0,1,2,4,3,5}, frontend_attributes={grad_x="false",grad_y="false"}
+  arg8.9 = f32[2,64]{1,0} parameter(8), parameter_replication={false}
+  convert.10 = f16[2,64]{1,0} convert(arg8.9)
+  broadcast.14 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} broadcast(convert.10), dimensions={3,5}
+  add.4 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} add(transpose.30, broadcast.14)
+  bitcast.120 = f16[8,16,2,5,64]{4,3,2,1,0} bitcast(add.4)
+  ROOT dot.7 = f16[8,16,2,5,64]{4,3,2,1,0} dot(bitcast.98, bitcast.120), lhs_batch_dims={0,1,2}, lhs_contracting_dims={4}, rhs_batch_dims={0,1,2}, rhs_contracting_dims={3}, frontend_attributes={grad_x="false",grad_y="false"}
+} // main
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  const auto status_or = RunHloPass(&fusedMhaRewriter, m.get());
+  TF_ASSERT_OK(status_or.status());
+  EXPECT_FALSE(status_or.value());
+
+  HloDCE dce;
+  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  SCOPED_TRACE(m->ToString());
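+  // The rank-5 dots do not match the supported MHA pattern, so the module
+  // should be left unchanged.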
+  EXPECT_THAT(m->entry_computation()->root_instruction(), GmockMatch(m::Dot()));
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, F16TrainingBmm2Grad1IncorrectPattern) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+
+region_0.21 {
+  Arg_0.22 = f16[] parameter(0)
+  Arg_1.23 = f16[] parameter(1)
+  ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23)
+}
+
+region_1.33 {
+  Arg_0.34 = f32[] parameter(0)
+  Arg_1.35 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.34, Arg_1.35)
+}
+
+region_2.55 {
+  Arg_0.56 = f16[] parameter(0)
+  Arg_1.57 = f16[] parameter(1)
+  ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57)
+}
+
+ENTRY main.82 {
+  Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated}
+  Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated}
+  dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  constant.22 = f16[] constant(2)
+  broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={}
+  multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24)
+  constant.19 = f16[] constant(1)
+  broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={}
+  add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13)
+  constant.21 = f16[] constant(0)
+  constant.15 = f16[] constant(-inf)
+  reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21
+  broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2}
+  subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17)
+  exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1)
+  convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1)
+  constant.17 = f32[] constant(0)
+  reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33
+  convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37)
+  broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+  divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26)
+  Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated}
+  dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated}
+  dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26)
+  broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={}
+  multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9)
+  divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3)
+  broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+  multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21)
+  multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1)
+  reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55
+  negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59)
+  broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
+  add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25)
+  multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1)
+  multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24)
+  dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  // Add another user of dS (multiply.9) here; because of neg.1, the pattern should not be matched as bmm2grad1.
+  neg.1 = f16[2,6,128,128]{3,2,1,0} negate(multiply.9)
+  dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1, neg.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  HloDCE dce;
+  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  SCOPED_TRACE(m->ToString());
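+  // multiply.9 has an extra user (neg.1), so the backward pattern is not
+  // matched and the dots are left unfused.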
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot(),
+                                  m::Negate())));
+}
+
+// flash attention
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       FlashAttentionBF16TrainingBmm1CausalMaskSoftmaxBmm2Pattern) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+region_0.32 {
+  Arg_0.33 = bf16[] parameter(0)
+  Arg_1.34 = bf16[] parameter(1)
+  ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34)
+}
+region_1.44 {
+  Arg_0.45 = f32[] parameter(0)
+  Arg_1.46 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.45, Arg_1.46)
+}
+region_2.66 {
+  Arg_0.67 = bf16[] parameter(0)
+  Arg_1.68 = bf16[] parameter(1)
+  ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68)
+}
+ENTRY main.92 {
+  Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated}
+  Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated}
+  dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  constant.17 = bf16[] constant(2)
+  broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={}
+  multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29)
+  iota.2 = s32[2048,2048]{1,0} iota(), iota_dimension=0
+  iota.5 = s32[2048,2048]{1,0} iota(), iota_dimension=1
+  compare.1 = pred[2048,2048]{1,0} compare(iota.2, iota.5), direction=LT
+  constant.6 = bf16[] constant(-2.366e+38)
+  broadcast.16 = bf16[2048,2048]{1,0} broadcast(constant.6), dimensions={}
+  constant.16 = bf16[] constant(0)
+  broadcast.17 = bf16[2048,2048]{1,0} broadcast(constant.16), dimensions={}
+  select.2 = bf16[2048,2048]{1,0} select(compare.1, broadcast.16, broadcast.17)
+  broadcast.19 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(select.2), dimensions={2,3}
+  add.3 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.2, broadcast.19)
+  constant.10 = bf16[] constant(-inf)
+  reduce.36 = bf16[2,6,2048]{2,1,0} reduce(add.3, constant.10), dimensions={3}, to_apply=region_0.32
+  broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
+  subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.3, broadcast.21)
+  exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1)
+  convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1)
+  constant.14 = f32[] constant(0)
+  reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44
+  convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48)
+  broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+  divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32)
+  Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated}
+  dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated}
+  dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32)
+  constant.15 = bf16[] constant(1)
+  broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={}
+  multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9)
+  divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3)
+  broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+  multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26)
+  multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1)
+  reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66
+  negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70)
+  broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
+  add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31)
+  multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1)
+  multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29)
+  dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  HloDCE dce;
+  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
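+  // After the rewriter and DCE, the forward pass should lower to a cuDNN
+  // softmax fMHA custom call and the backward pass to its backward
+  // counterpart, with the causal mask recorded in the backend config.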
+  const HloInstruction* fwd_fmha;
+  const HloInstruction* bwd_fmha;
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::GetTupleElement(
+              m::CustomCall(&fwd_fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
+              .WithShape(BF16, {2, 6, 2048, 128}),
+          m::GetTupleElement(
+              m::CustomCall(&bwd_fmha, {kCudnnfMHASoftmaxBackwardCallTarget}),
+              0)
+              .WithShape(BF16, {2, 6, 2048, 128}),
+          m::Transpose(
+              m::GetTupleElement(
+                  m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 1))
+              .WithShape(BF16, {2, 6, 128, 2048}),
+          m::GetTupleElement(
+              m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 2)
+              .WithShape(BF16, {2, 6, 2048, 128}))));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fwd_fmha->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+  EXPECT_EQ(fwd_fmha->operands().size(), 3);
+  EXPECT_EQ(bwd_fmha->operands().size(), 6);
+  EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
+  EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::CAUSAL);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       FlashAttentionBF16TrainingBmm1BiasSoftmaxBmm2Pattern) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,2048]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+region_0.32 {
+  Arg_0.33 = bf16[] parameter(0)
+  Arg_1.34 = bf16[] parameter(1)
+  ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34)
+}
+region_1.44 {
+  Arg_0.45 = f32[] parameter(0)
+  Arg_1.46 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.45, Arg_1.46)
+}
+region_2.66 {
+  Arg_0.67 = bf16[] parameter(0)
+  Arg_1.68 = bf16[] parameter(1)
+  ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68)
+}
+ENTRY main.92 {
+  Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated}
+  Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated}
+  dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  constant.17 = bf16[] constant(2)
+  broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={}
+  multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29)
+  // bias
+  Arg_4.5 = bf16[2,6,2048,2048]{3,2,1,0} parameter(4), sharding={replicated}
+  add.3 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.2, Arg_4.5)
+  constant.10 = bf16[] constant(-inf)
+  constant.16 = bf16[] constant(0)
+  reduce.36 = bf16[2,6,2048]{2,1,0} reduce(add.3, constant.10), dimensions={3}, to_apply=region_0.32
+  broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
+  subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.3, broadcast.21)
+  exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1)
+  convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1)
+  constant.14 = f32[] constant(0)
+  reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44
+  convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48)
+  broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+  divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32)
+  Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated}
+  dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated}
+  dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32)
+  constant.15 = bf16[] constant(1)
+  broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={}
+  multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9)
+  divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3)
+  broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+  multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26)
+  multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1)
+  reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66
+  negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70)
+  broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
+  add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31)
+  multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1)
+  multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29)
+  dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  HloDCE dce;
+  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
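+  // The bias variant should lower to the scale-bias-softmax cuDNN fMHA custom
+  // calls, with no mask recorded in the backend config.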
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::GetTupleElement(
+              m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0)
+              .WithShape(BF16, {2, 6, 2048, 128}),
+          m::GetTupleElement(
+              m::CustomCall(&fmha,
+                            {kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+              0)
+              .WithShape(BF16, {2, 6, 2048, 128}),
+          m::Transpose(
+              m::GetTupleElement(
+                  m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+                  1))
+              .WithShape(BF16, {2, 6, 128, 2048}),
+          m::GetTupleElement(
+              m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), 2)
+              .WithShape(BF16, {2, 6, 2048, 128}))));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+  EXPECT_EQ(fmha->operands().size(), 7);
+  EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
+  EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       FlashAttentionBF16TrainingBmm1SoftmaxBmm2Pattern) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+region_0.32 {
+  Arg_0.33 = bf16[] parameter(0)
+  Arg_1.34 = bf16[] parameter(1)
+  ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34)
+}
+region_1.44 {
+  Arg_0.45 = f32[] parameter(0)
+  Arg_1.46 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.45, Arg_1.46)
+}
+region_2.66 {
+  Arg_0.67 = bf16[] parameter(0)
+  Arg_1.68 = bf16[] parameter(1)
+  ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68)
+}
+ENTRY main.92 {
+  Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated}
+  Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated}
+  dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  constant.17 = bf16[] constant(2)
+  broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={}
+  multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29)
+  constant.10 = bf16[] constant(-inf)
+  constant.16 = bf16[] constant(0)
+  reduce.36 = bf16[2,6,2048]{2,1,0} reduce(multiply.2, constant.10), dimensions={3}, to_apply=region_0.32
+  broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2}
+  subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(multiply.2, broadcast.21)
+  exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1)
+  convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1)
+  constant.14 = f32[] constant(0)
+  reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44
+  convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48)
+  broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2}
+  divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32)
+  Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated}
+  dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated}
+  dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32)
+  constant.15 = bf16[] constant(1)
+  broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={}
+  multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9)
+  divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3)
+  broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2}
+  multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26)
+  multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1)
+  reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66
+  negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70)
+  broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2}
+  add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31)
+  multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1)
+  multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29)
+  dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  HloDCE dce;
+  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
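+  // The plain softmax variant should lower to the softmax cuDNN fMHA custom
+  // calls, with the scale of 2 carried into the backend config and no mask.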
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::GetTupleElement(
+              m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0)
+              .WithShape(BF16, {2, 6, 2048, 128}),
+          m::GetTupleElement(
+              m::CustomCall(&fmha, {kCudnnfMHASoftmaxBackwardCallTarget}), 0)
+              .WithShape(BF16, {2, 6, 2048, 128}),
+          m::Transpose(
+              m::GetTupleElement(
+                  m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 1))
+              .WithShape(BF16, {2, 6, 128, 2048}),
+          m::GetTupleElement(
+              m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 2)
+              .WithShape(BF16, {2, 6, 2048, 128}))));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+  EXPECT_EQ(fmha->operands().size(), 6);
+  EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
+  EXPECT_FLOAT_EQ(config.fmha_scale(), 2);
+  EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK);
+}
+
+// GPT-3 (5B) training pattern
+TEST_F(CudnnFusedMhaRewriterTestHloTest, FlashAttentionBF16TrainingGPT3_5B) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={((s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}))->(s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0})}
+add {
+  x = bf16[] parameter(0)
+  y = bf16[] parameter(1)
+  ROOT add.580 = bf16[] add(x, y)
+}
+
+region_20.962 {
+  Arg_0.963 = f32[] parameter(0)
+  Arg_1.964 = f32[] parameter(1)
+  ROOT add.579 = f32[] add(Arg_0.963, Arg_1.964)
+}
+
+region_39.1120 {
+  Arg_0.1121 = f32[] parameter(0)
+  Arg_1.1122 = f32[] parameter(1)
+  ROOT maximum.21 = f32[] maximum(Arg_0.1121, Arg_1.1122)
+}
+
+main {
+  param.3 = (s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}) parameter(0)
+  get-tuple-element.31 = s32[] get-tuple-element(param.3), index=0
+  constant.1961 = s32[] constant(1)
+  add.581 = s32[] add(get-tuple-element.31, constant.1961)
+  get-tuple-element.32 = bf16[24,32,2048,2048]{2,1,3,0} get-tuple-element(param.3), index=25
+  bitcast.187 = bf16[24,2048,32,2048]{3,2,1,0} bitcast(get-tuple-element.32)
+  constant.1977 = s32[] constant(23)
+  subtract.221 = s32[] subtract(constant.1977, get-tuple-element.31)
+  constant.1980 = s32[] constant(0)
+  compare.210 = pred[] compare(subtract.221, constant.1980), direction=LT
+  constant.1979 = s32[] constant(47)
+  subtract.222 = s32[] subtract(constant.1979, get-tuple-element.31)
+  select.372 = s32[] select(compare.210, subtract.222, subtract.221)
+  dynamic-slice.324 = bf16[1,2048,32,2048]{3,2,1,0} dynamic-slice(bitcast.187, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,2048,32,2048}
+  bitcast.756 = bf16[2048,32,2048]{2,1,0} bitcast(dynamic-slice.324)
+  convert.282 = f32[2048,32,2048]{2,1,0} convert(bitcast.756)
+  constant.1991 = bf16[] constant(1)
+  broadcast.1270 = bf16[32,2048]{1,0} broadcast(constant.1991), dimensions={}
+  get-tuple-element.33 = bf16[32,2048]{1,0} get-tuple-element(param.3), index=27
+  subtract.229 = bf16[32,2048]{1,0} subtract(broadcast.1270, get-tuple-element.33)
+  convert.285 = f32[32,2048]{1,0} convert(subtract.229)
+  broadcast.1228 = f32[2048,32,2048]{2,1,0} broadcast(convert.285), dimensions={1,2}
+  multiply.656 = f32[2048,32,2048]{2,1,0} multiply(convert.282, broadcast.1228)
+  bitcast.367 = f32[32,2048,2048]{1,0,2} bitcast(multiply.656)
+  constant.1968 = f32[] constant(0)
+  reduce.84 = f32[] reduce(bitcast.367, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
+  all-reduce.230 = f32[] all-reduce(reduce.84), channel_id=278, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
+  broadcast.1221 = f32[32,2048,4096]{2,1,0} broadcast(convert.285), dimensions={0,1}
+  reduce.85 = f32[] reduce(broadcast.1221, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
+  all-reduce.14 = f32[] all-reduce(reduce.85), channel_id=49, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=region_20.962
+  constant.2005 = f32[] constant(1)
+  maximum.24 = f32[] maximum(all-reduce.14, constant.2005)
+  divide.96 = f32[] divide(all-reduce.230, maximum.24)
+  broadcast.1223 = f32[2048,32,2048]{2,1,0} broadcast(divide.96), dimensions={}
+  subtract.219 = f32[2048,32,2048]{2,1,0} subtract(convert.282, broadcast.1223)
+  multiply.644 = f32[2048,32,2048]{2,1,0} multiply(subtract.219, broadcast.1228)
+  multiply.645 = f32[2048,32,2048]{2,1,0} multiply(multiply.644, multiply.644)
+  bitcast.271 = f32[32,2048,2048]{1,0,2} bitcast(multiply.645)
+  reduce.86 = f32[] reduce(bitcast.271, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
+  all-reduce.231 = f32[] all-reduce(reduce.86), channel_id=279, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
+  divide.99 = f32[] divide(all-reduce.231, maximum.24)
+  rsqrt.16 = f32[] rsqrt(divide.99)
+  multiply.650 = f32[] multiply(rsqrt.16, constant.1968)
+  divide.100 = f32[] divide(multiply.650, maximum.24)
+  constant.1974 = f32[] constant(2)
+  multiply.652 = f32[] multiply(divide.100, constant.1974)
+  broadcast.1227 = f32[2048,32,2048]{2,1,0} broadcast(multiply.652), dimensions={}
+  multiply.653 = f32[2048,32,2048]{2,1,0} multiply(multiply.644, broadcast.1227)
+  multiply.654 = f32[2048,32,2048]{2,1,0} multiply(multiply.653, broadcast.1228)
+  negate.56 = f32[2048,32,2048]{2,1,0} negate(multiply.654)
+  bitcast.321 = f32[32,2048,2048]{1,0,2} bitcast(negate.56)
+  reduce.87 = f32[] reduce(bitcast.321, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
+  all-reduce.232 = f32[] all-reduce(reduce.87), channel_id=280, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
+  divide.101 = f32[] divide(all-reduce.232, maximum.24)
+  broadcast.1229 = f32[32,2048]{1,0} broadcast(divide.101), dimensions={}
+  multiply.655 = f32[32,2048]{1,0} multiply(broadcast.1229, convert.285)
+  broadcast.1230 = f32[2048,32,2048]{2,1,0} broadcast(multiply.655), dimensions={1,2}
+  add.582 = f32[2048,32,2048]{2,1,0} add(multiply.654, broadcast.1230)
+  broadcast.1236 = f32[2048,32,2048]{2,1,0} broadcast(constant.1968), dimensions={}
+  compare.208 = pred[2048,32,2048]{2,1,0} compare(multiply.656, broadcast.1236), direction=GE
+  abs.22 = f32[2048,32,2048]{2,1,0} abs(multiply.656)
+  bitcast.373 = f32[32,2048,2048]{1,0,2} bitcast(abs.22)
+  constant.1989 = f32[] constant(-inf)
+  reduce.88 = f32[] reduce(bitcast.373, constant.1989), dimensions={0,1,2}, to_apply=region_39.1120
+  all-reduce.233 = f32[] all-reduce(reduce.88), channel_id=281, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_39.1120
+  broadcast.1233 = f32[2048,32,2048]{2,1,0} broadcast(all-reduce.233), dimensions={}
+  compare.207 = pred[2048,32,2048]{2,1,0} compare(abs.22, broadcast.1233), direction=EQ
+  convert.286 = f32[2048,32,2048]{2,1,0} convert(compare.207)
+  bitcast.393 = f32[32,2048,2048]{1,0,2} bitcast(convert.286)
+  reduce.89 = f32[] reduce(bitcast.393, constant.1968), dimensions={0,1,2}, to_apply=region_20.962
+  all-reduce.234 = f32[] all-reduce(reduce.89), channel_id=282, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962
+  divide.103 = f32[] divide(constant.1968, all-reduce.234)
+  broadcast.1238 = f32[2048,32,2048]{2,1,0} broadcast(divide.103), dimensions={}
+  select.370 = f32[2048,32,2048]{2,1,0} select(compare.207, broadcast.1238, broadcast.1236)
+  select.369 = f32[2048,32,2048]{2,1,0} select(compare.208, select.370, broadcast.1236)
+  constant.1976 = pred[] constant(false)
+  broadcast.1237 = pred[2048,32,2048]{2,1,0} broadcast(constant.1976), dimensions={}
+  compare.209 = pred[2048,32,2048]{2,1,0} compare(compare.208, broadcast.1237), direction=EQ
+  select.371 = f32[2048,32,2048]{2,1,0} select(compare.209, select.370, broadcast.1236)
+  negate.57 = f32[2048,32,2048]{2,1,0} negate(select.371)
+  add.583 = f32[2048,32,2048]{2,1,0} add(select.369, negate.57)
+  multiply.658 = f32[2048,32,2048]{2,1,0} multiply(add.583, broadcast.1228)
+  add.585 = f32[2048,32,2048]{2,1,0} add(add.582, multiply.658)
+  convert.287 = bf16[2048,32,2048]{2,1,0} convert(add.585)
+  get-tuple-element.34 = bf16[32,2048,2048]{1,0,2} get-tuple-element(param.3), index=1
+  bitcast.1652 = bf16[2048,32,2048]{2,1,0} bitcast(get-tuple-element.34)
+  get-tuple-element.35 = bf16[24,3,1024,16,128]{4,3,1,2,0} get-tuple-element(param.3), index=22
+  bitcast.461 = bf16[24,1024,3,16,128]{4,3,2,1,0} bitcast(get-tuple-element.35)
+  dynamic-slice.325 = bf16[1,1024,3,16,128]{4,3,2,1,0} dynamic-slice(bitcast.461, select.372, constant.1980, constant.1980, constant.1980, /*index=5*/constant.1980), dynamic_slice_sizes={1,1024,3,16,128}
+  bitcast.485 = bf16[3,1024,16,128]{3,2,0,1} bitcast(dynamic-slice.325)
+  all-gather.7 = bf16[3,4096,16,128]{3,2,0,1} all-gather(bitcast.485), channel_id=60, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1}, use_global_device_ids=true
+  bitcast.1420 = bf16[6144,4096]{0,1} bitcast(all-gather.7)
+  bitcast.500 = f32[32,2048,2048]{1,0,2} bitcast(convert.282)
+  reduce.90 = f32[32,2048]{1,0} reduce(bitcast.500, constant.1968), dimensions={2}, to_apply=region_20.962
+  all-reduce.23 = f32[32,2048]{1,0} all-reduce(reduce.90), channel_id=58, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+  constant.1983 = f32[] constant(0.000244140625)
+  broadcast.1243 = f32[32,2048]{1,0} broadcast(constant.1983), dimensions={}
+  multiply.660 = f32[32,2048]{1,0} multiply(all-reduce.23, broadcast.1243)
+  broadcast.1242 = f32[2048,32,2048]{2,1,0} broadcast(multiply.660), dimensions={1,2}
+  subtract.224 = f32[2048,32,2048]{2,1,0} subtract(convert.282, broadcast.1242)
+  multiply.661 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, subtract.224)
+  bitcast.527 = f32[32,2048,2048]{1,0,2} bitcast(multiply.661)
+  reduce.91 = f32[32,2048]{1,0} reduce(bitcast.527, constant.1968), dimensions={2}, to_apply=region_20.962
+  all-reduce.24 = f32[32,2048]{1,0} all-reduce(reduce.91), channel_id=59, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+  multiply.662 = f32[32,2048]{1,0} multiply(all-reduce.24, broadcast.1243)
+  constant.1990 = f32[] constant(1e-05)
+  broadcast.1264 = f32[32,2048]{1,0} broadcast(constant.1990), dimensions={}
+  add.587 = f32[32,2048]{1,0} add(multiply.662, broadcast.1264)
+  bitcast.1447 = f32[1,32,2048]{2,1,0} bitcast(add.587)
+  rsqrt.20 = f32[1,32,2048]{2,1,0} rsqrt(bitcast.1447)
+  bitcast.1892 = f32[32,2048]{1,0} bitcast(rsqrt.20)
+  broadcast.1337 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1892), dimensions={1,2}
+  multiply.754 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, broadcast.1337)
+  convert.314 = bf16[2048,32,2048]{2,1,0} convert(multiply.754)
+  get-tuple-element.36 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=20
+  dynamic-slice.326 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.36, select.372, constant.1980), dynamic_slice_sizes={1,2048}
+  broadcast.1266 = bf16[1,2048]{1,0} broadcast(constant.1991), dimensions={}
+  add.588 = bf16[1,2048]{1,0} add(dynamic-slice.326, broadcast.1266)
+  bitcast.1992 = bf16[2048]{0} bitcast(add.588)
+  broadcast.1338 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1992), dimensions={0}
+  multiply.755 = bf16[2048,32,2048]{2,1,0} multiply(convert.314, broadcast.1338)
+  get-tuple-element.37 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=19
+  dynamic-slice.327 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.37, select.372, constant.1980), dynamic_slice_sizes={1,2048}
+  bitcast.1998 = bf16[2048]{0} bitcast(dynamic-slice.327)
+  broadcast.1339 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1998), dimensions={0}
+  add.640 = bf16[2048,32,2048]{2,1,0} add(multiply.755, broadcast.1339)
+  bitcast.2003 = bf16[32,2048,2048]{1,0,2} bitcast(add.640)
+  all-gather.8 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2003), channel_id=61, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+  bitcast.597 = bf16[4096,65536]{1,0} bitcast(all-gather.8)
+  dot.42 = bf16[6144,65536]{1,0} dot(bitcast.1420, bitcast.597), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bitcast.623 = bf16[3,16,128,32,2048]{4,3,2,1,0} bitcast(dot.42)
+  transpose.112 = bf16[3,32,16,128,2048]{4,3,2,1,0} transpose(bitcast.623), dimensions={0,3,1,2,4}
+  get-tuple-element.38 = bf16[24,3,16,128]{3,2,1,0} get-tuple-element(param.3), index=21
+  dynamic-slice.328 = bf16[1,3,16,128]{3,2,1,0} dynamic-slice(get-tuple-element.38, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,3,16,128}
+  bitcast.626 = bf16[3,16,128]{2,1,0} bitcast(dynamic-slice.328)
+  broadcast.1250 = bf16[3,32,16,128,2048]{4,3,2,1,0} broadcast(bitcast.626), dimensions={0,2,3}
+  add.591 = bf16[3,32,16,128,2048]{4,3,2,1,0} add(transpose.112, broadcast.1250)
+  slice.87 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[2:3], [0:32], [0:16], [0:128], [0:2048]}
+  bitcast.1280 = bf16[32,16,128,2048]{3,2,1,0} bitcast(slice.87)
+  slice.88 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[0:1], [0:32], [0:16], [0:128], [0:2048]}
+  constant.2007 = bf16[] constant(0.08838)
+  broadcast.1251 = bf16[1,32,16,128,2048]{4,3,2,1,0} broadcast(constant.2007), dimensions={}
+  multiply.666 = bf16[1,32,16,128,2048]{4,3,2,1,0} multiply(slice.88, broadcast.1251)
+  bitcast.1330 = bf16[32,16,128,2048]{3,2,1,0} bitcast(multiply.666)
+  transpose.113 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.1330), dimensions={0,1,3,2}
+  slice.89 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[1:2], [0:32], [0:16], [0:128], [0:2048]}
+  bitcast.647 = bf16[32,16,128,2048]{3,2,1,0} bitcast(slice.89)
+  dot.43 = bf16[32,16,2048,2048]{3,2,1,0} dot(transpose.113, bitcast.647), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  convert.291 = f32[32,16,2048,2048]{3,2,1,0} convert(dot.43)
+  get-tuple-element.39 = bf16[32,1,2048,2048]{3,2,0,1} get-tuple-element(param.3), index=26
+  bitcast.651 = bf16[1,32,2048,2048]{3,2,1,0} bitcast(get-tuple-element.39)
+  iota.38 = s32[2048,2048]{1,0} iota(), iota_dimension=0
+  iota.39 = s32[2048,2048]{1,0} iota(), iota_dimension=1
+  compare.211 = pred[2048,2048]{1,0} compare(iota.38, iota.39), direction=LT
+  constant.1987 = bf16[] constant(-2.366e+38)
+  broadcast.1252 = bf16[2048,2048]{1,0} broadcast(constant.1987), dimensions={}
+  constant.2006 = bf16[] constant(0)
+  broadcast.1253 = bf16[2048,2048]{1,0} broadcast(constant.2006), dimensions={}
+  select.373 = bf16[2048,2048]{1,0} select(compare.211, broadcast.1252, broadcast.1253)
+  broadcast.1254 = bf16[1,32,2048,2048]{3,2,1,0} broadcast(select.373), dimensions={2,3}
+  minimum.5 = bf16[1,32,2048,2048]{3,2,1,0} minimum(bitcast.651, broadcast.1254)
+  bitcast.673 = bf16[32,2048,2048]{2,1,0} bitcast(minimum.5)
+  convert.292 = f32[32,2048,2048]{2,1,0} convert(bitcast.673)
+  broadcast.1256 = f32[32,16,2048,2048]{3,2,1,0} broadcast(convert.292), dimensions={0,2,3}
+  add.593 = f32[32,16,2048,2048]{3,2,1,0} add(convert.291, broadcast.1256)
+  reduce.92 = f32[32,16,2048]{2,1,0} reduce(add.593, constant.1989), dimensions={3}, to_apply=region_39.1120
+  broadcast.1258 = f32[32,16,2048,2048]{3,2,1,0} broadcast(reduce.92), dimensions={0,1,2}
+  subtract.226 = f32[32,16,2048,2048]{3,2,1,0} subtract(add.593, broadcast.1258)
+  exponential.8 = f32[32,16,2048,2048]{3,2,1,0} exponential(subtract.226)
+  reduce.93 = f32[32,16,2048]{2,1,0} reduce(exponential.8, constant.1968), dimensions={3}, to_apply=region_20.962
+  broadcast.1309 = f32[32,16,2048,2048]{3,2,1,0} broadcast(reduce.93), dimensions={0,1,2}
+  divide.109 = f32[32,16,2048,2048]{3,2,1,0} divide(exponential.8, broadcast.1309)
+  convert.306 = bf16[32,16,2048,2048]{3,2,1,0} convert(divide.109)
+  dot.44 = bf16[32,16,128,2048]{3,2,1,0} dot(bitcast.1280, convert.306), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  transpose.116 = bf16[32,2048,16,128]{3,2,1,0} transpose(dot.44), dimensions={0,3,1,2}
+  bitcast.711 = bf16[65536,2048]{1,0} bitcast(transpose.116)
+  get-tuple-element.40 = bf16[24,1024,16,128]{3,2,1,0} get-tuple-element(param.3), index=24
+  dynamic-slice.329 = bf16[1,1024,16,128]{3,2,1,0} dynamic-slice(get-tuple-element.40, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,16,128}
+  bitcast.724 = bf16[1024,16,128]{2,1,0} bitcast(dynamic-slice.329)
+  all-gather.9 = bf16[4096,16,128]{2,1,0} all-gather(bitcast.724), channel_id=62, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true
+  bitcast.729 = bf16[2048,4096]{0,1} bitcast(all-gather.9)
+  dot.57 = bf16[65536,4096]{0,1} dot(bitcast.711, bitcast.729), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bitcast.733 = bf16[32,2048,4096]{1,0,2} bitcast(dot.57)
+  reduce-scatter = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.733), channel_id=322, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add
+  bitcast.763 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter)
+  get-tuple-element.41 = bf16[24,1024]{1,0} get-tuple-element(param.3), index=23
+  dynamic-slice.330 = bf16[1,1024]{1,0} dynamic-slice(get-tuple-element.41, select.372, constant.1980), dynamic_slice_sizes={1,1024}
+  bitcast.748 = bf16[1024]{0} bitcast(dynamic-slice.330)
+  collective-permute.1 = bf16[1024]{0} collective-permute(bitcast.748), channel_id=64, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
+  all-gather.10 = bf16[2048]{0} all-gather(collective-permute.1), channel_id=65, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={0}, use_global_device_ids=true
+  broadcast.1261 = bf16[2048,32,2048]{2,1,0} broadcast(all-gather.10), dimensions={0}
+  add.596 = bf16[2048,32,2048]{2,1,0} add(bitcast.763, broadcast.1261)
+  add.597 = bf16[2048,32,2048]{2,1,0} add(add.596, bitcast.756)
+  convert.295 = f32[2048,32,2048]{2,1,0} convert(add.597)
+  bitcast.774 = f32[32,2048,2048]{1,0,2} bitcast(convert.295)
+  reduce.94 = f32[32,2048]{1,0} reduce(bitcast.774, constant.1968), dimensions={2}, to_apply=region_20.962
+  all-reduce.26 = f32[32,2048]{1,0} all-reduce(reduce.94), channel_id=66, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+  multiply.668 = f32[32,2048]{1,0} multiply(all-reduce.26, broadcast.1243)
+  broadcast.1263 = f32[2048,32,2048]{2,1,0} broadcast(multiply.668), dimensions={1,2}
+  subtract.228 = f32[2048,32,2048]{2,1,0} subtract(convert.295, broadcast.1263)
+  multiply.669 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, subtract.228)
+  bitcast.809 = f32[32,2048,2048]{1,0,2} bitcast(multiply.669)
+  reduce.95 = f32[32,2048]{1,0} reduce(bitcast.809, constant.1968), dimensions={2}, to_apply=region_20.962
+  all-reduce.27 = f32[32,2048]{1,0} all-reduce(reduce.95), channel_id=67, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+  multiply.670 = f32[32,2048]{1,0} multiply(all-reduce.27, broadcast.1243)
+  add.598 = f32[32,2048]{1,0} add(multiply.670, broadcast.1264)
+  bitcast.1148 = f32[1,32,2048]{2,1,0} bitcast(add.598)
+  rsqrt.19 = f32[1,32,2048]{2,1,0} rsqrt(bitcast.1148)
+  bitcast.1602 = f32[32,2048]{1,0} bitcast(rsqrt.19)
+  broadcast.1329 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1602), dimensions={1,2}
+  multiply.750 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, broadcast.1329)
+  convert.312 = bf16[2048,32,2048]{2,1,0} convert(multiply.750)
+  get-tuple-element.42 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=18
+  dynamic-slice.331 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.42, select.372, constant.1980), dynamic_slice_sizes={1,2048}
+  add.599 = bf16[1,2048]{1,0} add(dynamic-slice.331, broadcast.1266)
+  bitcast.1609 = bf16[2048]{0} bitcast(add.599)
+  broadcast.1330 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1609), dimensions={0}
+  multiply.745 = bf16[2048,32,2048]{2,1,0} multiply(convert.312, broadcast.1330)
+  get-tuple-element.43 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=17
+  dynamic-slice.332 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.43, select.372, constant.1980), dynamic_slice_sizes={1,2048}
+  bitcast.1615 = bf16[2048]{0} bitcast(dynamic-slice.332)
+  broadcast.1331 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1615), dimensions={0}
+  add.636 = bf16[2048,32,2048]{2,1,0} add(multiply.745, broadcast.1331)
+  bitcast.1620 = bf16[32,2048,2048]{1,0,2} bitcast(add.636)
+  all-gather.12 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1620), channel_id=69, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+  bitcast.877 = bf16[65536,4096]{0,1} bitcast(all-gather.12)
+  get-tuple-element.44 = bf16[24,1024,8192]{2,1,0} get-tuple-element(param.3), index=15
+  dynamic-slice.333 = bf16[1,1024,8192]{2,1,0} dynamic-slice(get-tuple-element.44, select.372, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,8192}
+  bitcast.890 = bf16[1024,8192]{1,0} bitcast(dynamic-slice.333)
+  all-gather.11 = bf16[4096,8192]{1,0} all-gather(bitcast.890), channel_id=68, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true
+  dot.45 = bf16[65536,8192]{1,0} dot(bitcast.877, all-gather.11), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  get-tuple-element.45 = bf16[24,8192]{1,0} get-tuple-element(param.3), index=14
+  dynamic-slice.334 = bf16[1,8192]{1,0} dynamic-slice(get-tuple-element.45, select.372, constant.1980), dynamic_slice_sizes={1,8192}
+  bitcast.906 = bf16[8192]{0} bitcast(dynamic-slice.334)
+  broadcast.1269 = bf16[65536,8192]{1,0} broadcast(bitcast.906), dimensions={1}
+  add.601 = bf16[65536,8192]{1,0} add(dot.45, broadcast.1269)
+  bitcast.997 = bf16[32,2048,8192]{2,1,0} bitcast(add.601)
+  broadcast.1333 = bf16[2048,32,2048]{2,1,0} broadcast(subtract.229), dimensions={1,2}
+  multiply.746 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1652, broadcast.1333)
+  bitcast.1739 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.746)
+  all-gather.14 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1739), channel_id=71, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+  bitcast.934 = bf16[65536,4096]{0,1} bitcast(all-gather.14)
+  get-tuple-element.46 = bf16[24,8192,1024]{1,2,0} get-tuple-element(param.3), index=16
+  bitcast.935 = bf16[24,1024,8192]{2,1,0} bitcast(get-tuple-element.46)
+  dynamic-slice.335 = bf16[1,1024,8192]{2,1,0} dynamic-slice(bitcast.935, select.372, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,8192}
+  bitcast.947 = bf16[8192,1024]{0,1} bitcast(dynamic-slice.335)
+  all-gather.13 = bf16[8192,4096]{0,1} all-gather(bitcast.947), channel_id=70, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1}, use_global_device_ids=true
+  dot.46 = bf16[65536,8192]{1,0} dot(bitcast.934, all-gather.13), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+  bitcast.1092 = bf16[32,2048,8192]{2,1,0} bitcast(dot.46)
+  broadcast.1335 = bf16[32,2048,8192]{2,1,0} broadcast(subtract.229), dimensions={0,1}
+  multiply.703 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.1092, broadcast.1335)
+  multiply.685 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.703)
+  constant.2002 = bf16[] constant(0.5)
+  broadcast.1288 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2002), dimensions={}
+  multiply.686 = bf16[32,2048,8192]{2,1,0} multiply(multiply.685, broadcast.1288)
+  broadcast.1287 = bf16[32,2048,8192]{2,1,0} broadcast(constant.1991), dimensions={}
+  multiply.700 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, bitcast.997)
+  multiply.693 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.700)
+  constant.1998 = bf16[] constant(0.04468)
+  broadcast.1282 = bf16[32,2048,8192]{2,1,0} broadcast(constant.1998), dimensions={}
+  multiply.694 = bf16[32,2048,8192]{2,1,0} multiply(multiply.693, broadcast.1282)
+  add.605 = bf16[32,2048,8192]{2,1,0} add(bitcast.997, multiply.694)
+  constant.2010 = bf16[] constant(0.7969)
+  broadcast.1324 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2010), dimensions={}
+  multiply.695 = bf16[32,2048,8192]{2,1,0} multiply(add.605, broadcast.1324)
+  tanh.7 = bf16[32,2048,8192]{2,1,0} tanh(multiply.695)
+  subtract.231 = bf16[32,2048,8192]{2,1,0} subtract(broadcast.1287, tanh.7)
+  multiply.691 = bf16[32,2048,8192]{2,1,0} multiply(multiply.686, subtract.231)
+  multiply.737 = bf16[32,2048,8192]{2,1,0} multiply(multiply.691, tanh.7)
+  add.630 = bf16[32,2048,8192]{2,1,0} add(multiply.691, multiply.737)
+  multiply.738 = bf16[32,2048,8192]{2,1,0} multiply(add.630, broadcast.1324)
+  constant.2011 = bf16[] constant(0.03564)
+  broadcast.1326 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2011), dimensions={}
+  multiply.739 = bf16[32,2048,8192]{2,1,0} multiply(add.630, broadcast.1326)
+  constant.2012 = bf16[] constant(3)
+  broadcast.1327 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2012), dimensions={}
+  multiply.740 = bf16[32,2048,8192]{2,1,0} multiply(multiply.700, broadcast.1327)
+  multiply.741 = bf16[32,2048,8192]{2,1,0} multiply(multiply.739, multiply.740)
+  add.632 = bf16[32,2048,8192]{2,1,0} add(multiply.738, multiply.741)
+  add.637 = bf16[32,2048,8192]{2,1,0} add(tanh.7, broadcast.1287)
+  multiply.747 = bf16[32,2048,8192]{2,1,0} multiply(add.637, broadcast.1288)
+  multiply.743 = bf16[32,2048,8192]{2,1,0} multiply(multiply.703, multiply.747)
+  add.635 = bf16[32,2048,8192]{2,1,0} add(add.632, multiply.743)
+  bitcast.1629 = bf16[65536,8192]{1,0} bitcast(add.635)
+  dot.47 = bf16[65536,4096]{0,1} dot(bitcast.1629, all-gather.11), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+  bitcast.1130 = bf16[32,2048,4096]{1,0,2} bitcast(dot.47)
+  reduce-scatter.1 = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.1130), channel_id=323, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add
+  bitcast.1766 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter.1)
+  multiply.712 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1766, broadcast.1330)
+  convert.299 = f32[2048,32,2048]{2,1,0} convert(multiply.712)
+  multiply.707 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, convert.299)
+  bitcast.1135 = f32[32,2048,2048]{1,0,2} bitcast(multiply.707)
+  reduce.96 = f32[32,2048]{1,0} reduce(bitcast.1135, constant.1968), dimensions={2}, to_apply=region_20.962
+  all-reduce.29 = f32[32,2048]{1,0} all-reduce(reduce.96), channel_id=73, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+  bitcast.1140 = f32[1,32,2048]{2,1,0} bitcast(all-reduce.29)
+  divide.105 = f32[1,32,2048]{2,1,0} divide(rsqrt.19, bitcast.1148)
+  constant.2008 = f32[] constant(-0.5)
+  broadcast.1313 = f32[1,32,2048]{2,1,0} broadcast(constant.2008), dimensions={}
+  multiply.708 = f32[1,32,2048]{2,1,0} multiply(divide.105, broadcast.1313)
+  multiply.709 = f32[1,32,2048]{2,1,0} multiply(bitcast.1140, multiply.708)
+  constant.2009 = f32[] constant(0.00048828125)
+  broadcast.1315 = f32[1,32,2048]{2,1,0} broadcast(constant.2009), dimensions={}
+  multiply.710 = f32[1,32,2048]{2,1,0} multiply(multiply.709, broadcast.1315)
+  bitcast.1235 = f32[32,2048]{1,0} bitcast(multiply.710)
+  broadcast.1296 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1235), dimensions={1,2}
+  multiply.717 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, broadcast.1296)
+  multiply.718 = f32[2048,32,2048]{2,1,0} multiply(convert.299, broadcast.1329)
+  add.617 = f32[2048,32,2048]{2,1,0} add(multiply.717, multiply.718)
+  negate.58 = f32[2048,32,2048]{2,1,0} negate(multiply.717)
+  bitcast.1189 = f32[32,2048,2048]{1,0,2} bitcast(negate.58)
+  reduce.97 = f32[32,2048]{1,0} reduce(bitcast.1189, constant.1968), dimensions={2}, to_apply=region_20.962
+  negate.59 = f32[2048,32,2048]{2,1,0} negate(multiply.718)
+  bitcast.1203 = f32[32,2048,2048]{1,0,2} bitcast(negate.59)
+  reduce.98 = f32[32,2048]{1,0} reduce(bitcast.1203, constant.1968), dimensions={2}, to_apply=region_20.962
+  add.613 = f32[32,2048]{1,0} add(reduce.97, reduce.98)
+  all-reduce.274 = f32[32,2048]{1,0} all-reduce(add.613), channel_id=335, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+  multiply.719 = f32[32,2048]{1,0} multiply(all-reduce.274, broadcast.1243)
+  broadcast.1297 = f32[2048,32,2048]{2,1,0} broadcast(multiply.719), dimensions={1,2}
+  add.618 = f32[2048,32,2048]{2,1,0} add(add.617, broadcast.1297)
+  convert.301 = bf16[2048,32,2048]{2,1,0} convert(add.618)
+  add.619 = bf16[2048,32,2048]{2,1,0} add(bitcast.1652, convert.301)
+  add.616 = bf16[2048,32,2048]{2,1,0} add(convert.287, add.619)
+  bitcast.2063 = bf16[32,2048,2048]{1,0,2} bitcast(add.619)
+  all-gather.15 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2063), channel_id=76, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+  bitcast.1263 = bf16[65536,4096]{0,1} bitcast(all-gather.15)
+  bitcast.1269 = bf16[4096,2048]{1,0} bitcast(all-gather.9)
+  dot.48 = bf16[65536,2048]{1,0} dot(bitcast.1263, bitcast.1269), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bitcast.1381 = bf16[32,2048,16,128]{3,2,1,0} bitcast(dot.48)
+  transpose.122 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.1381), dimensions={0,2,1,3}
+  dot.49 = bf16[32,16,2048,2048]{3,2,1,0} dot(transpose.122, bitcast.1280), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  convert.303 = f32[32,16,2048,2048]{3,2,1,0} convert(dot.49)
+  broadcast.1298 = f32[32,16,2048]{2,1,0} broadcast(constant.2005), dimensions={}
+  multiply.720 = f32[32,16,2048]{2,1,0} multiply(reduce.93, reduce.93)
+  divide.106 = f32[32,16,2048]{2,1,0} divide(broadcast.1298, multiply.720)
+  broadcast.1299 = f32[32,16,2048,2048]{3,2,1,0} broadcast(divide.106), dimensions={0,1,2}
+  multiply.721 = f32[32,16,2048,2048]{3,2,1,0} multiply(convert.303, broadcast.1299)
+  multiply.722 = f32[32,16,2048,2048]{3,2,1,0} multiply(multiply.721, exponential.8)
+  reduce.99 = f32[32,16,2048]{2,1,0} reduce(multiply.722, constant.1968), dimensions={3}, to_apply=region_20.962
+  negate.61 = f32[32,16,2048]{2,1,0} negate(reduce.99)
+  broadcast.1305 = f32[32,16,2048,2048]{3,2,1,0} broadcast(negate.61), dimensions={0,1,2}
+  divide.108 = f32[32,16,2048,2048]{3,2,1,0} divide(convert.303, broadcast.1309)
+  add.622 = f32[32,16,2048,2048]{3,2,1,0} add(broadcast.1305, divide.108)
+  multiply.724 = f32[32,16,2048,2048]{3,2,1,0} multiply(add.622, exponential.8)
+  convert.305 = bf16[32,16,2048,2048]{3,2,1,0} convert(multiply.724)
+  dot.50 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.305, transpose.113), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  bitcast.1934 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(dot.50)
+  pad.6 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1934, constant.2006), padding=1_1x0_0x0_0x0_0x0_0
+  transpose.120 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.647), dimensions={0,1,3,2}
+  dot.51 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.305, transpose.120), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  broadcast.1307 = bf16[32,16,2048,128]{3,2,1,0} broadcast(constant.2007), dimensions={}
+  multiply.725 = bf16[32,16,2048,128]{3,2,1,0} multiply(dot.51, broadcast.1307)
+  bitcast.1941 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(multiply.725)
+  pad.7 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1941, constant.2006), padding=0_2x0_0x0_0x0_0x0_0
+  add.638 = bf16[3,32,16,2048,128]{4,3,2,1,0} add(pad.6, pad.7)
+  transpose.123 = bf16[32,16,128,2048]{3,2,1,0} transpose(bitcast.1381), dimensions={0,2,3,1}
+  dot.89 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.306, transpose.123), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  bitcast.1949 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(dot.89)
+  pad.8 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1949, constant.2006), padding=2_0x0_0x0_0x0_0x0_0
+  add.639 = bf16[3,32,16,2048,128]{4,3,2,1,0} add(add.638, pad.8)
+  transpose.127 = bf16[32,2048,3,16,128]{4,3,2,1,0} transpose(add.639), dimensions={1,3,0,2,4}
+  bitcast.1416 = bf16[65536,6144]{1,0} bitcast(transpose.127)
+  dot.52 = bf16[65536,4096]{0,1} dot(bitcast.1416, bitcast.1420), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bitcast.1424 = bf16[32,2048,4096]{1,0,2} bitcast(dot.52)
+  reduce-scatter.2 = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.1424), channel_id=324, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add
+  bitcast.1851 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter.2)
+  multiply.732 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1851, broadcast.1338)
+  convert.308 = f32[2048,32,2048]{2,1,0} convert(multiply.732)
+  multiply.727 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, convert.308)
+  bitcast.1434 = f32[32,2048,2048]{1,0,2} bitcast(multiply.727)
+  reduce.100 = f32[32,2048]{1,0} reduce(bitcast.1434, constant.1968), dimensions={2}, to_apply=region_20.962
+  all-reduce.33 = f32[32,2048]{1,0} all-reduce(reduce.100), channel_id=78, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+  bitcast.1439 = f32[1,32,2048]{2,1,0} bitcast(all-reduce.33)
+  divide.110 = f32[1,32,2048]{2,1,0} divide(rsqrt.20, bitcast.1447)
+  multiply.728 = f32[1,32,2048]{2,1,0} multiply(divide.110, broadcast.1313)
+  multiply.729 = f32[1,32,2048]{2,1,0} multiply(bitcast.1439, multiply.728)
+  multiply.730 = f32[1,32,2048]{2,1,0} multiply(multiply.729, broadcast.1315)
+  bitcast.1485 = f32[32,2048]{1,0} bitcast(multiply.730)
+  broadcast.1321 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1485), dimensions={1,2}
+  multiply.734 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, broadcast.1321)
+  multiply.735 = f32[2048,32,2048]{2,1,0} multiply(convert.308, broadcast.1337)
+  add.625 = f32[2048,32,2048]{2,1,0} add(multiply.734, multiply.735)
+  negate.62 = f32[2048,32,2048]{2,1,0} negate(multiply.734)
+  bitcast.1491 = f32[32,2048,2048]{1,0,2} bitcast(negate.62)
+  reduce.101 = f32[32,2048]{1,0} reduce(bitcast.1491, constant.1968), dimensions={2}, to_apply=region_20.962
+  negate.63 = f32[2048,32,2048]{2,1,0} negate(multiply.735)
+  bitcast.1505 = f32[32,2048,2048]{1,0,2} bitcast(negate.63)
+  reduce.102 = f32[32,2048]{1,0} reduce(bitcast.1505, constant.1968), dimensions={2}, to_apply=region_20.962
+  add.626 = f32[32,2048]{1,0} add(reduce.101, reduce.102)
+  all-reduce.275 = f32[32,2048]{1,0} all-reduce(add.626), channel_id=336, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962
+  multiply.736 = f32[32,2048]{1,0} multiply(all-reduce.275, broadcast.1243)
+  broadcast.1323 = f32[2048,32,2048]{2,1,0} broadcast(multiply.736), dimensions={1,2}
+  add.628 = f32[2048,32,2048]{2,1,0} add(add.625, broadcast.1323)
+  convert.309 = bf16[2048,32,2048]{2,1,0} convert(add.628)
+  add.629 = bf16[2048,32,2048]{2,1,0} add(add.616, convert.309)
+  bitcast.1525 = bf16[32,2048,2048]{1,0,2} bitcast(add.629)
+  get-tuple-element.47 = bf16[24,8192]{1,0} get-tuple-element(param.3), index=2
+  reduce.103 = bf16[8192]{0} reduce(add.635, constant.2006), dimensions={0,1}, to_apply=add
+  all-reduce.36 = bf16[8192]{0} all-reduce(reduce.103), channel_id=81, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+  bitcast.1583 = bf16[1,8192]{1,0} bitcast(all-reduce.36)
+  dynamic-update-slice.28 = bf16[24,8192]{1,0} dynamic-update-slice(get-tuple-element.47, bitcast.1583, select.372, constant.1980)
+  get-tuple-element.48 = bf16[24,1024,8192]{2,1,0} get-tuple-element(param.3), index=3
+  all-gather.16 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1620), channel_id=82, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+  bitcast.1625 = bf16[4096,65536]{1,0} bitcast(all-gather.16)
+  dot.53 = bf16[4096,8192]{1,0} dot(bitcast.1625, bitcast.1629), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  reduce-scatter.3 = bf16[1024,8192]{1,0} reduce-scatter(dot.53), channel_id=325, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+  bitcast.1634 = bf16[1,1024,8192]{2,1,0} bitcast(reduce-scatter.3)
+  dynamic-update-slice.29 = bf16[24,1024,8192]{2,1,0} dynamic-update-slice(get-tuple-element.48, bitcast.1634, select.372, constant.1980, constant.1980)
+  get-tuple-element.49 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=4
+  collective-permute.2 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.49), channel_id=85, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
+  all-gather.17 = bf16[24,2048]{0,1} all-gather(collective-permute.2), channel_id=86, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
+  bitcast.1649 = bf16[2048,24]{1,0} bitcast(all-gather.17)
+  reduce.104 = bf16[2048]{0} reduce(bitcast.1739, constant.2006), dimensions={0,1}, to_apply=add
+  all-reduce.38 = bf16[2048]{0} all-reduce(reduce.104), channel_id=84, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+  bitcast.1671 = bf16[2048,1]{1,0} bitcast(all-reduce.38)
+  dynamic-update-slice.30 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1649, bitcast.1671, constant.1980, select.372)
+  constant.2013 = s32[8]{0} constant({0, 2048, 0, 2048, 1024, 3072, 1024, 3072})
+  partition-id.3 = u32[] partition-id()
+  dynamic-slice.336 = s32[1]{0} dynamic-slice(constant.2013, partition-id.3), dynamic_slice_sizes={1}
+  constant.2014 = s32[8]{0} constant({0, 2048, 0, 2048, 0, 2048, 0, 2048})
+  dynamic-slice.337 = s32[1]{0} dynamic-slice(constant.2014, partition-id.3), dynamic_slice_sizes={1}
+  subtract.232 = s32[1]{0} subtract(dynamic-slice.336, dynamic-slice.337)
+  bitcast.2087 = s32[] bitcast(subtract.232)
+  dynamic-slice.338 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.30, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
+  bitcast.1695 = bf16[24,1024]{0,1} bitcast(dynamic-slice.338)
+  collective-permute.9 = bf16[24,1024]{0,1} collective-permute(bitcast.1695), channel_id=109, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+  get-tuple-element.50 = bf16[24,8192,1024]{1,2,0} get-tuple-element(param.3), index=5
+  bitcast.1698 = bf16[24,1024,8192]{2,1,0} bitcast(get-tuple-element.50)
+  multiply.748 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.747)
+  multiply.749 = bf16[32,2048,8192]{2,1,0} multiply(multiply.748, broadcast.1335)
+  bitcast.1735 = bf16[8192,65536]{0,1} bitcast(multiply.749)
+  all-gather.18 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1739), channel_id=87, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+  bitcast.1743 = bf16[65536,4096]{0,1} bitcast(all-gather.18)
+  dot.54 = bf16[8192,4096]{0,1} dot(bitcast.1735, bitcast.1743), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  reduce-scatter.4 = bf16[8192,1024]{0,1} reduce-scatter(dot.54), channel_id=326, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add
+  bitcast.1748 = bf16[1,1024,8192]{2,1,0} bitcast(reduce-scatter.4)
+  dynamic-update-slice.31 = bf16[24,1024,8192]{2,1,0} dynamic-update-slice(bitcast.1698, bitcast.1748, select.372, constant.1980, constant.1980)
+  bitcast.1758 = bf16[24,8192,1024]{1,2,0} bitcast(dynamic-update-slice.31)
+  get-tuple-element.51 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=6
+  collective-permute.3 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.51), channel_id=90, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
+  all-gather.19 = bf16[24,2048]{0,1} all-gather(collective-permute.3), channel_id=91, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
+  bitcast.1763 = bf16[2048,24]{1,0} bitcast(all-gather.19)
+  reduce.105 = bf16[2048]{0} reduce(reduce-scatter.1, constant.2006), dimensions={0,1}, to_apply=add
+  all-reduce.40 = bf16[2048]{0} all-reduce(reduce.105), channel_id=89, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+  bitcast.1779 = bf16[2048,1]{1,0} bitcast(all-reduce.40)
+  dynamic-update-slice.32 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1763, bitcast.1779, constant.1980, select.372)
+  dynamic-slice.339 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.32, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
+  bitcast.1794 = bf16[24,1024]{0,1} bitcast(dynamic-slice.339)
+  collective-permute.10 = bf16[24,1024]{0,1} collective-permute(bitcast.1794), channel_id=110, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+  get-tuple-element.52 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=7
+  collective-permute.4 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.52), channel_id=93, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
+  all-gather.20 = bf16[24,2048]{0,1} all-gather(collective-permute.4), channel_id=94, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
+  bitcast.1801 = bf16[2048,24]{1,0} bitcast(all-gather.20)
+  multiply.751 = bf16[2048,32,2048]{2,1,0} multiply(convert.312, bitcast.1766)
+  bitcast.1817 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.751)
+  reduce.106 = bf16[2048]{0} reduce(bitcast.1817, constant.2006), dimensions={0,1}, to_apply=add
+  all-reduce.41 = bf16[2048]{0} all-reduce(reduce.106), channel_id=92, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+  bitcast.1826 = bf16[2048,1]{1,0} bitcast(all-reduce.41)
+  dynamic-update-slice.33 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1801, bitcast.1826, constant.1980, select.372)
+  dynamic-slice.340 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.33, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
+  bitcast.1841 = bf16[24,1024]{0,1} bitcast(dynamic-slice.340)
+  collective-permute.11 = bf16[24,1024]{0,1} collective-permute(bitcast.1841), channel_id=111, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+  get-tuple-element.53 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=8
+  collective-permute.5 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.53), channel_id=96, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
+  all-gather.21 = bf16[24,2048]{0,1} all-gather(collective-permute.5), channel_id=97, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
+  bitcast.1848 = bf16[2048,24]{1,0} bitcast(all-gather.21)
+  reduce.107 = bf16[2048]{0} reduce(reduce-scatter.2, constant.2006), dimensions={0,1}, to_apply=add
+  all-reduce.42 = bf16[2048]{0} all-reduce(reduce.107), channel_id=95, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+  bitcast.1864 = bf16[2048,1]{1,0} bitcast(all-reduce.42)
+  dynamic-update-slice.34 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1848, bitcast.1864, constant.1980, select.372)
+  dynamic-slice.341 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.34, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
+  bitcast.1879 = bf16[24,1024]{0,1} bitcast(dynamic-slice.341)
+  collective-permute.12 = bf16[24,1024]{0,1} collective-permute(bitcast.1879), channel_id=112, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+  get-tuple-element.54 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=9
+  collective-permute.6 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.54), channel_id=99, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}}
+  all-gather.22 = bf16[24,2048]{0,1} all-gather(collective-permute.6), channel_id=100, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true
+  bitcast.1886 = bf16[2048,24]{1,0} bitcast(all-gather.22)
+  multiply.753 = bf16[2048,32,2048]{2,1,0} multiply(convert.314, bitcast.1851)
+  bitcast.1905 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.753)
+  reduce.108 = bf16[2048]{0} reduce(bitcast.1905, constant.2006), dimensions={0,1}, to_apply=add
+  all-reduce.43 = bf16[2048]{0} all-reduce(reduce.108), channel_id=98, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+  bitcast.1914 = bf16[2048,1]{1,0} bitcast(all-reduce.43)
+  dynamic-update-slice.35 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1886, bitcast.1914, constant.1980, select.372)
+  dynamic-slice.342 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.35, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24}
+  bitcast.1929 = bf16[24,1024]{0,1} bitcast(dynamic-slice.342)
+  collective-permute.13 = bf16[24,1024]{0,1} collective-permute(bitcast.1929), channel_id=113, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+  get-tuple-element.55 = bf16[24,3,16,128]{3,2,1,0} get-tuple-element(param.3), index=10
+  bitcast.1979 = bf16[3,32,2048,16,128]{4,2,3,1,0} bitcast(add.639)
+  reduce.109 = bf16[3,16,128]{2,1,0} reduce(bitcast.1979, constant.2006), dimensions={1,2}, to_apply=add
+  all-reduce.44 = bf16[3,16,128]{2,1,0} all-reduce(reduce.109), channel_id=101, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+  bitcast.1963 = bf16[1,3,16,128]{3,2,1,0} bitcast(all-reduce.44)
+  dynamic-update-slice.36 = bf16[24,3,16,128]{3,2,1,0} dynamic-update-slice(get-tuple-element.55, bitcast.1963, select.372, constant.1980, constant.1980, /*index=5*/constant.1980)
+  get-tuple-element.56 = bf16[24,3,1024,16,128]{4,3,1,2,0} get-tuple-element(param.3), index=11
+  bitcast.1974 = bf16[24,1024,3,16,128]{4,3,2,1,0} bitcast(get-tuple-element.56)
+  transpose.130 = bf16[3,16,128,32,2048]{4,3,2,1,0} transpose(add.639), dimensions={0,2,4,1,3}
+  bitcast.1983 = bf16[6144,65536]{1,0} bitcast(transpose.130)
+  all-gather.23 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2003), channel_id=102, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+  bitcast.2007 = bf16[65536,4096]{0,1} bitcast(all-gather.23)
+  dot.55 = bf16[6144,4096]{0,1} dot(bitcast.1983, bitcast.2007), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bitcast.2011 = bf16[3,16,128,4096]{2,1,0,3} bitcast(dot.55)
+  reduce-scatter.5 = bf16[3,16,128,1024]{2,1,0,3} reduce-scatter(bitcast.2011), channel_id=327, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={3}, to_apply=add
+  bitcast.2015 = bf16[1,1024,3,16,128]{4,3,2,1,0} bitcast(reduce-scatter.5)
+  dynamic-update-slice.37 = bf16[24,1024,3,16,128]{4,3,2,1,0} dynamic-update-slice(bitcast.1974, bitcast.2015, select.372, constant.1980, constant.1980, /*index=5*/constant.1980, constant.1980)
+  bitcast.2025 = bf16[24,3,1024,16,128]{4,3,1,2,0} bitcast(dynamic-update-slice.37)
+  get-tuple-element.57 = bf16[24,1024]{1,0} get-tuple-element(param.3), index=12
+  reduce.110 = bf16[2048]{0} reduce(bitcast.2063, constant.2006), dimensions={0,1}, to_apply=add
+  all-reduce.46 = bf16[2048]{0} all-reduce(reduce.110), channel_id=104, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+  dynamic-slice.343 = bf16[1024]{0} dynamic-slice(all-reduce.46, bitcast.2087), dynamic_slice_sizes={1024}
+  bitcast.2046 = bf16[1,1024]{1,0} bitcast(dynamic-slice.343)
+  collective-permute.7 = bf16[1,1024]{1,0} collective-permute(bitcast.2046), channel_id=105, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+  dynamic-update-slice.38 = bf16[24,1024]{1,0} dynamic-update-slice(get-tuple-element.57, collective-permute.7, select.372, constant.1980)
+  get-tuple-element.58 = bf16[24,1024,16,128]{3,2,1,0} get-tuple-element(param.3), index=13
+  bitcast.2066 = bf16[2048,65536]{1,0} bitcast(add.619)
+  transpose.133 = bf16[16,32,2048,128]{3,2,1,0} transpose(dot.44), dimensions={1,0,3,2}
+  bitcast.2072 = bf16[32,2048,16,128]{3,1,0,2} bitcast(transpose.133)
+  all-gather.24 = bf16[32,2048,32,128]{3,1,0,2} all-gather(bitcast.2072), channel_id=106, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+  bitcast.2073 = bf16[32,32,2048,128]{3,2,1,0} bitcast(all-gather.24)
+  transpose.134 = bf16[32,2048,32,128]{3,2,1,0} transpose(bitcast.2073), dimensions={1,2,0,3}
+  bitcast.2077 = bf16[65536,4096]{1,0} bitcast(transpose.134)
+  dot.56 = bf16[2048,4096]{1,0} dot(bitcast.2066, bitcast.2077), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bitcast.2081 = bf16[2048,32,128]{2,1,0} bitcast(dot.56)
+  all-reduce.47 = bf16[2048,32,128]{2,1,0} all-reduce(bitcast.2081), channel_id=107, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
+  constant.2015 = s32[8]{0} constant({0, 0, 16, 16, 0, 0, 16, 16})
+  dynamic-slice.344 = s32[1]{0} dynamic-slice(constant.2015, partition-id.3), dynamic_slice_sizes={1}
+  bitcast.2095 = s32[] bitcast(dynamic-slice.344)
+  dynamic-slice.345 = bf16[1024,16,128]{2,1,0} dynamic-slice(all-reduce.47, bitcast.2087, bitcast.2095, constant.1980), dynamic_slice_sizes={1024,16,128}
+  bitcast.2102 = bf16[1,1024,16,128]{3,2,1,0} bitcast(dynamic-slice.345)
+  collective-permute.8 = bf16[1,1024,16,128]{3,2,1,0} collective-permute(bitcast.2102), channel_id=108, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}}
+  dynamic-update-slice.39 = bf16[24,1024,16,128]{3,2,1,0} dynamic-update-slice(get-tuple-element.58, collective-permute.8, select.372, constant.1980, constant.1980, /*index=5*/constant.1980)
+  ROOT tuple.2 = (s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}) tuple(add.581, bitcast.1525, dynamic-update-slice.28, dynamic-update-slice.29, collective-permute.9, /*index=5*/bitcast.1758, collective-permute.10, collective-permute.11, collective-permute.12, collective-permute.13, /*index=10*/dynamic-update-slice.36, bitcast.2025, dynamic-update-slice.38, dynamic-update-slice.39, get-tuple-element.45, /*index=15*/get-tuple-element.44, get-tuple-element.46, get-tuple-element.43, get-tuple-element.42, get-tuple-element.37, /*index=20*/get-tuple-element.36, get-tuple-element.38, get-tuple-element.35, get-tuple-element.41, get-tuple-element.40, /*index=25*/get-tuple-element.32, get-tuple-element.39, get-tuple-element.33)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  HloInstruction* fwd_instruction = nullptr;
+  HloInstruction* bwd_instruction = nullptr;
+  SCOPED_TRACE(m->ToString());
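+  // Scan the rewritten module for the forward and backward softmax FMHA
+  // custom calls produced by the rewriter.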
+  for (HloInstruction* instr :
+       m->entry_computation()->MakeInstructionPostOrder()) {
+    if (instr->opcode() == HloOpcode::kCustomCall &&
+        instr->custom_call_target() == kCudnnfMHASoftmaxCallTarget) {
+      fwd_instruction = instr;
+    }
+    if (instr->opcode() == HloOpcode::kCustomCall &&
+        instr->custom_call_target() == kCudnnfMHASoftmaxBackwardCallTarget) {
+      bwd_instruction = instr;
+    }
+  }
+  EXPECT_NE(fwd_instruction, nullptr);
+  EXPECT_NE(bwd_instruction, nullptr);
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fwd_instruction->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+  EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::CAUSAL);
+}
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       BF16TrainingBmm2CanonicalizationRestoreFwdGraph) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  const char* module_str = R"(
+HloModule pjit__unnamed_function_, entry_computation_layout={(bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,4,256,256]{3,2,1,0})->(bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={false,false,false,false}, num_partitions=4
+
+region_0.6 {
+  Arg_0.7 = bf16[] parameter(0)
+  Arg_1.8 = bf16[] parameter(1)
+  ROOT maximum.5 = bf16[] maximum(Arg_0.7, Arg_1.8)
+}
+
+region_1.10 {
+  Arg_0.11 = f32[] parameter(0)
+  Arg_1.12 = f32[] parameter(1)
+  ROOT add.14 = f32[] add(Arg_0.11, Arg_1.12)
+}
+
+add.clone {
+  x.1 = u32[] parameter(0)
+  y.1 = u32[] parameter(1)
+  ROOT add.15 = u32[] add(x.1, y.1)
+}
+
+region_2.65 {
+  Arg_0.66 = bf16[] parameter(0)
+  Arg_1.67 = bf16[] parameter(1)
+  ROOT add.16 = bf16[] add(Arg_0.66, Arg_1.67)
+}
+
+ENTRY main.164_spmd {
+  param = bf16[2,256,4,64]{3,2,1,0} parameter(2), sharding={devices=[2,1,2,1]<=[4]}
+  transpose.26 = bf16[2,4,64,256]{3,2,1,0} transpose(param), dimensions={0,2,3,1}
+  param.1 = bf16[2,256,4,64]{3,2,1,0} parameter(0), sharding={devices=[2,1,2,1]<=[4]}
+  transpose.27 = bf16[2,4,256,64]{3,2,1,0} transpose(param.1), dimensions={0,2,1,3}
+  constant.46 = bf16[] constant(0.5)
+  broadcast.126 = bf16[2,4,256,64]{3,2,1,0} broadcast(constant.46), dimensions={}
+  multiply.34 = bf16[2,4,256,64]{3,2,1,0} multiply(transpose.27, broadcast.126)
+  param.2 = bf16[2,256,4,64]{3,2,1,0} parameter(1), sharding={devices=[2,1,2,1]<=[4]}
+  transpose.29 = bf16[2,4,64,256]{3,2,1,0} transpose(param.2), dimensions={0,2,3,1}
+  dot.12 = bf16[2,4,256,256]{3,2,1,0} dot(multiply.34, transpose.29), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  param.3 = bf16[2,4,256,256]{3,2,1,0} parameter(4), sharding={devices=[2,2,1,1]<=[4]}
+  add.17 = bf16[2,4,256,256]{3,2,1,0} add(dot.12, param.3)
+  constant.47 = bf16[] constant(-inf)
+  reduce.4 = bf16[2,4,256]{2,1,0} reduce(add.17, constant.47), dimensions={3}, to_apply=region_0.6
+  broadcast.127 = bf16[2,4,256,256]{3,2,1,0} broadcast(reduce.4), dimensions={0,1,2}
+  subtract.14 = bf16[2,4,256,256]{3,2,1,0} subtract(add.17, broadcast.127)
+  exponential.2 = bf16[2,4,256,256]{3,2,1,0} exponential(subtract.14)
+  convert.46 = f32[2,4,256,256]{3,2,1,0} convert(exponential.2)
+  constant.48 = f32[] constant(0)
+  reduce.5 = f32[2,4,256]{2,1,0} reduce(convert.46, constant.48), dimensions={3}, to_apply=region_1.10
+  convert.47 = bf16[2,4,256]{2,1,0} convert(reduce.5)
+  broadcast.128 = bf16[2,4,256,256]{3,2,1,0} broadcast(convert.47), dimensions={0,1,2}
+  divide.7 = bf16[2,4,256,256]{3,2,1,0} divide(exponential.2, broadcast.128)
+  broadcast.129 = f32[4096]{0} broadcast(constant.48), dimensions={}
+  constant.50 = u32[] constant(0)
+  broadcast.131 = u32[8192]{0} broadcast(constant.50), dimensions={}
+  broadcast.133 = u32[4096]{0} broadcast(constant.50), dimensions={}
+  iota.3 = u32[8192]{0} iota(), iota_dimension=0
+  slice.14 = u32[4096]{0} slice(iota.3), slice={[0:4096]}
+  slice.15 = u32[4096]{0} slice(iota.3), slice={[4096:8192]}
+  custom-call.3 = (u32[4096]{0}, u32[4096]{0}) custom-call(broadcast.133, broadcast.133, slice.14, slice.15), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[4096]{0}, u32[4096]{0}, u32[4096]{0}, u32[4096]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\020\000\000\000\000\000\000"
+  get-tuple-element.6 = u32[4096]{0} get-tuple-element(custom-call.3), index=0
+  constant.115 = u32[1]{0} constant({0})
+  constant.52 = u32[4]{0} constant({0, 0, 1, 1})
+  partition-id = u32[] partition-id()
+  dynamic-slice.21 = u32[1]{0} dynamic-slice(constant.52, partition-id), dynamic_slice_sizes={1}
+  constant.116 = u32[1]{0} constant({1})
+  clamp.3 = u32[1]{0} clamp(constant.115, dynamic-slice.21, constant.116)
+  convert.48 = s32[1]{0} convert(clamp.3)
+  constant.117 = s32[1]{0} constant({2048})
+  multiply.35 = s32[1]{0} multiply(convert.48, constant.117)
+  bitcast.105 = s32[] bitcast(multiply.35)
+  dynamic-slice.22 = u32[2048]{0} dynamic-slice(get-tuple-element.6, bitcast.105), dynamic_slice_sizes={2048}
+  constant.58 = s32[4]{0} constant({0, 0, 1, 1})
+  dynamic-slice.23 = s32[1]{0} dynamic-slice(constant.58, partition-id), dynamic_slice_sizes={1}
+  multiply.36 = s32[1]{0} multiply(dynamic-slice.23, constant.117)
+  bitcast.108 = s32[] bitcast(multiply.36)
+  dynamic-update-slice.2 = u32[8192]{0} dynamic-update-slice(broadcast.131, dynamic-slice.22, bitcast.108)
+  get-tuple-element.7 = u32[4096]{0} get-tuple-element(custom-call.3), index=1
+  dynamic-slice.24 = u32[2048]{0} dynamic-slice(get-tuple-element.7, bitcast.105), dynamic_slice_sizes={2048}
+  constant.65 = s32[] constant(4096)
+  add.18 = s32[] add(bitcast.108, constant.65)
+  dynamic-update-slice.3 = u32[8192]{0} dynamic-update-slice(dynamic-update-slice.2, dynamic-slice.24, add.18)
+  all-reduce = u32[8192]{0} all-reduce(dynamic-update-slice.3), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=add.clone
+  constant.118 = s32[1]{0} constant({4096})
+  multiply.37 = s32[1]{0} multiply(dynamic-slice.23, constant.118)
+  bitcast.119 = s32[] bitcast(multiply.37)
+  dynamic-slice.25 = u32[4096]{0} dynamic-slice(all-reduce, bitcast.119), dynamic_slice_sizes={4096}
+  constant.69 = u32[] constant(9)
+  broadcast.134 = u32[4096]{0} broadcast(constant.69), dimensions={}
+  shift-right-logical.6 = u32[4096]{0} shift-right-logical(dynamic-slice.25, broadcast.134)
+  constant.70 = u32[] constant(1065353216)
+  broadcast.135 = u32[4096]{0} broadcast(constant.70), dimensions={}
+  or.5 = u32[4096]{0} or(shift-right-logical.6, broadcast.135)
+  bitcast-convert.5 = f32[4096]{0} bitcast-convert(or.5)
+  constant.71 = f32[] constant(-1)
+  broadcast.136 = f32[4096]{0} broadcast(constant.71), dimensions={}
+  add.19 = f32[4096]{0} add(bitcast-convert.5, broadcast.136)
+  maximum.6 = f32[4096]{0} maximum(broadcast.129, add.19)
+  constant.72 = f32[] constant(0.5)
+  broadcast.137 = f32[4096]{0} broadcast(constant.72), dimensions={}
+  compare.4 = pred[4096]{0} compare(maximum.6, broadcast.137), direction=LT
+  bitcast.135 = pred[2,8,256]{2,1,0} bitcast(compare.4)
+  convert.49 = bf16[2,8,256]{2,1,0} convert(bitcast.135)
+  constant.80 = s32[] constant(0)
+  constant.78 = s32[4]{0} constant({0, 4, 0, 4})
+  dynamic-slice.26 = s32[1]{0} dynamic-slice(constant.78, partition-id), dynamic_slice_sizes={1}
+  bitcast.181 = s32[] bitcast(dynamic-slice.26)
+  dynamic-slice.27 = bf16[2,4,256]{2,1,0} dynamic-slice(convert.49, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256}
+  broadcast.139 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.27), dimensions={0,1,3}
+  multiply.38 = bf16[2,4,256,256]{3,2,1,0} multiply(divide.7, broadcast.139)
+  constant.93 = bf16[] constant(2)
+  broadcast.141 = bf16[2,4,256,256]{3,2,1,0} broadcast(constant.93), dimensions={}
+  multiply.39 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.38, broadcast.141)
+  dot.13 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.26, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  transpose.31 = bf16[4,2,64,256]{3,2,1,0} transpose(dot.13), dimensions={1,0,2,3}
+  bitcast.154 = bf16[2,256,4,64]{1,3,0,2} bitcast(transpose.31)
+  all-gather = bf16[2,256,8,64]{1,3,0,2} all-gather(bitcast.154), channel_id=2, replica_groups={{0,1},{2,3}}, dimensions={2}, use_global_device_ids=true
+  bitcast.155 = bf16[8,2,64,256]{3,2,1,0} bitcast(all-gather)
+  transpose.32 = bf16[2,8,64,256]{3,2,1,0} transpose(bitcast.155), dimensions={1,0,2,3}
+  bitcast.157 = bf16[2,256,8,64]{1,3,2,0} bitcast(transpose.32)
+  all-gather.1 = bf16[4,256,8,64]{1,3,2,0} all-gather(bitcast.157), channel_id=3, replica_groups={{0,2},{1,3}}, dimensions={0}, use_global_device_ids=true
+  bitcast.236 = bf16[4,8,64,256]{3,2,1,0} bitcast(all-gather.1)
+  transpose.38 = bf16[4,256,8,64]{3,2,1,0} transpose(bitcast.236), dimensions={0,3,1,2}
+  param.4 = bf16[2,256,4,64]{3,2,1,0} parameter(3), sharding={devices=[2,1,2,1]<=[4]}
+  transpose.33 = bf16[2,4,256,64]{3,2,1,0} transpose(param.4), dimensions={0,2,1,3}
+  dot.14 = bf16[2,4,256,256]{3,2,1,0} dot(transpose.33, transpose.26), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  broadcast.142 = bf16[4096]{0} broadcast(constant.93), dimensions={}
+  constant.95 = bf16[] constant(0)
+  broadcast.143 = bf16[4096]{0} broadcast(constant.95), dimensions={}
+  select.4 = bf16[4096]{0} select(compare.4, broadcast.142, broadcast.143)
+  bitcast.176 = bf16[2,8,256]{2,1,0} bitcast(select.4)
+  dynamic-slice.28 = bf16[2,4,256]{2,1,0} dynamic-slice(bitcast.176, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256}
+  broadcast.145 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.28), dimensions={0,1,3}
+  multiply.40 = bf16[2,4,256,256]{3,2,1,0} multiply(dot.14, broadcast.145)
+  divide.8 = bf16[2,4,256,256]{3,2,1,0} divide(multiply.40, broadcast.128)
+  constant.106 = bf16[] constant(1)
+  broadcast.146 = bf16[2,4,256]{2,1,0} broadcast(constant.106), dimensions={}
+  multiply.41 = bf16[2,4,256]{2,1,0} multiply(convert.47, convert.47)
+  divide.9 = bf16[2,4,256]{2,1,0} divide(broadcast.146, multiply.41)
+  broadcast.147 = bf16[2,4,256,256]{3,2,1,0} broadcast(divide.9), dimensions={0,1,2}
+  multiply.42 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.40, broadcast.147)
+  multiply.43 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.42, exponential.2)
+  reduce.6 = bf16[2,4,256]{2,1,0} reduce(multiply.43, constant.95), dimensions={3}, to_apply=region_2.65
+  negate.4 = bf16[2,4,256]{2,1,0} negate(reduce.6)
+  broadcast.148 = bf16[2,4,256,256]{3,2,1,0} broadcast(negate.4), dimensions={0,1,2}
+  add.20 = bf16[2,4,256,256]{3,2,1,0} add(divide.8, broadcast.148)
+  multiply.44 = bf16[2,4,256,256]{3,2,1,0} multiply(add.20, exponential.2)
+  transpose.34 = bf16[2,4,256,64]{3,2,1,0} transpose(param.2), dimensions={0,2,1,3}
+  dot.15 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, transpose.34), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  multiply.45 = bf16[2,4,256,64]{3,2,1,0} multiply(dot.15, broadcast.126)
+  transpose.39 = bf16[2,256,4,64]{3,2,1,0} transpose(multiply.45), dimensions={0,2,1,3}
+  dot.16 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, multiply.34), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.40 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.16), dimensions={0,2,1,3}
+  transpose.36 = bf16[2,4,64,256]{3,2,1,0} transpose(param.4), dimensions={0,2,3,1}
+  dot.11 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.36, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.41 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.11), dimensions={0,3,1,2}
+  ROOT tuple.2 = (bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}) tuple(transpose.38, transpose.39, transpose.40, transpose.41)
+} // main.164_spmd
+)";
+  // The dropout bwd pattern is not supported, so the fwd graph should not be
+  // lowered either.
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(),
+                                         GetCudnnVersion()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+  SCOPED_TRACE(m->ToString());
+  // Check that the fwd graph has been restored with the cloned activation.
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::Transpose(), m::Transpose(), m::Transpose(),
+          m::Transpose(m::Dot(
+              m::Op(), m::Op().WithPredicate([](const HloInstruction* instr) {
+                return instr->name() == "multiply.39.fmha_no_match_clone";
+              }))))));
+}
+
+constexpr absl::string_view hlo_BF16Bmm1BiasSoftmaxBmm2Pattern_dbias = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0})->(bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true,true,true,true}
+
+region_0.14 {
+  Arg_0.15 = bf16[] parameter(0)
+  Arg_1.16 = bf16[] parameter(1)
+  ROOT maximum = bf16[] maximum(Arg_0.15, Arg_1.16)
+}
+
+region_1.27 {
+  Arg_0.28 = f32[] parameter(0)
+  Arg_1.29 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0.28, Arg_1.29)
+}
+
+region_2.56 {
+  Arg_0.57 = bf16[] parameter(0)
+  Arg_1.58 = bf16[] parameter(1)
+  ROOT add.1 = bf16[] add(Arg_0.57, Arg_1.58)
+}
+
+ENTRY main.87 {
+  Arg_2.3 = bf16[2,1024,4,64]{3,2,1,0} parameter(2)
+  transpose.12 = bf16[2,4,64,1024]{3,2,1,0} transpose(Arg_2.3), dimensions={0,2,3,1}
+  Arg_0.1 = bf16[2,1024,4,64]{3,2,1,0} parameter(0)
+  transpose.13 = bf16[2,4,1024,64]{3,2,1,0} transpose(Arg_0.1), dimensions={0,2,1,3}
+  Arg_1.2 = bf16[2,1024,4,64]{3,2,1,0} parameter(1)
+  transpose.15 = bf16[2,4,64,1024]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,3,1}
+  dot = bf16[2,4,1024,1024]{3,2,1,0} dot(transpose.13, transpose.15), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  Arg_4.5 = bf16[4,1024,1024]{2,1,0} parameter(4)
+  broadcast.9 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(Arg_4.5), dimensions={1,2,3}
+  add.2 = bf16[2,4,1024,1024]{3,2,1,0} add(dot, broadcast.9)
+  constant.10 = bf16[] constant(-inf)
+  reduce.18 = bf16[2,4,1024]{2,1,0} reduce(add.2, constant.10), dimensions={3}, to_apply=region_0.14
+  broadcast.10 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(reduce.18), dimensions={0,1,2}
+  subtract = bf16[2,4,1024,1024]{3,2,1,0} subtract(add.2, broadcast.10)
+  exponential = bf16[2,4,1024,1024]{3,2,1,0} exponential(subtract)
+  convert.5 = f32[2,4,1024,1024]{3,2,1,0} convert(exponential)
+  constant.9 = f32[] constant(0)
+  reduce.31 = f32[2,4,1024]{2,1,0} reduce(convert.5, constant.9), dimensions={3}, to_apply=region_1.27
+  convert.6 = bf16[2,4,1024]{2,1,0} convert(reduce.31)
+  broadcast.11 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(convert.6), dimensions={0,1,2}
+  divide.2 = bf16[2,4,1024,1024]{3,2,1,0} divide(exponential, broadcast.11)
+  dot.1 = bf16[2,4,64,1024]{3,2,1,0} dot(transpose.12, divide.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  transpose.22 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.1), dimensions={0,3,1,2}
+  Arg_3.4 = bf16[2,1024,4,64]{3,2,1,0} parameter(3)
+  transpose.17 = bf16[2,4,1024,64]{3,2,1,0} transpose(Arg_3.4), dimensions={0,2,1,3}
+  dot.2 = bf16[2,4,1024,1024]{3,2,1,0} dot(transpose.17, transpose.12), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  divide.3 = bf16[2,4,1024,1024]{3,2,1,0} divide(dot.2, broadcast.11)
+  constant.0 = bf16[] constant(1)
+  broadcast.13 = bf16[2,4,1024]{2,1,0} broadcast(constant.0), dimensions={}
+  multiply.2 = bf16[2,4,1024]{2,1,0} multiply(convert.6, convert.6)
+  divide.4 = bf16[2,4,1024]{2,1,0} divide(broadcast.13, multiply.2)
+  broadcast.14 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(divide.4), dimensions={0,1,2}
+  multiply.3 = bf16[2,4,1024,1024]{3,2,1,0} multiply(dot.2, broadcast.14)
+  multiply.4 = bf16[2,4,1024,1024]{3,2,1,0} multiply(multiply.3, exponential)
+  constant.8 = bf16[] constant(0)
+  reduce.60 = bf16[2,4,1024]{2,1,0} reduce(multiply.4, constant.8), dimensions={3}, to_apply=region_2.56
+  negate.1 = bf16[2,4,1024]{2,1,0} negate(reduce.60)
+  broadcast.15 = bf16[2,4,1024,1024]{3,2,1,0} broadcast(negate.1), dimensions={0,1,2}
+  add.3 = bf16[2,4,1024,1024]{3,2,1,0} add(divide.3, broadcast.15)
+  multiply.5 = bf16[2,4,1024,1024]{3,2,1,0} multiply(add.3, exponential)
+  transpose.18 = bf16[2,4,1024,64]{3,2,1,0} transpose(Arg_1.2), dimensions={0,2,1,3}
+  dot.4 = bf16[2,4,1024,64]{3,2,1,0} dot(multiply.5, transpose.18), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.23 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.4), dimensions={0,2,1,3}
+  dot.3 = bf16[2,4,1024,64]{3,2,1,0} dot(multiply.5, transpose.13), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.24 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.3), dimensions={0,2,1,3}
+  transpose.20 = bf16[2,4,64,1024]{3,2,1,0} transpose(Arg_3.4), dimensions={0,2,3,1}
+  dot.49 = bf16[2,4,64,1024]{3,2,1,0} dot(transpose.20, divide.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.25 = bf16[2,1024,4,64]{3,2,1,0} transpose(dot.49), dimensions={0,3,1,2}
+  reduce.81 = bf16[4,1024,1024]{2,1,0} reduce(multiply.5, constant.8), dimensions={0}, to_apply=region_2.56
+  ROOT tuple = (bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0}) tuple(transpose.22, transpose.23, transpose.24, transpose.25, reduce.81)
+} // main.87
+)";
+
+TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1BiasSoftmaxBmm2PatternDbias) {
+  if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto m,
+      ParseAndReturnVerifiedModule(hlo_BF16Bmm1BiasSoftmaxBmm2Pattern_dbias));
+  // dbias requires cuDNN 8.9.6+ and Hopper.
+  CudnnFusedMHARewriter fusedMhaRewriter{se::CudaComputeCapability(9, 0),
+                                         se::dnn::VersionInfo(9, 0, 0)};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  const HloInstruction* fmha;
+
+  SCOPED_TRACE(m->ToString());
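+  // The root tuple should now read its outputs from the fused forward and
+  // backward cuDNN FMHA custom calls, with dbias taken from tuple index 3 of
+  // the backward call.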
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::Transpose(
+              m::Transpose(m::GetTupleElement(
+                  m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}),
+                  0)))
+              .WithShape(BF16, {2, 1024, 4, 64}),
+          m::Transpose(
+              m::GetTupleElement(
+                  m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+                  0))
+              .WithShape(BF16, {2, 1024, 4, 64}),
+          m::Transpose(
+              m::GetTupleElement(
+                  m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+                  1))
+              .WithShape(BF16, {2, 1024, 4, 64}),
+          m::Transpose(
+              m::Transpose(m::GetTupleElement(
+                  m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+                  2)))
+              .WithShape(BF16, {2, 1024, 4, 64}),
+          m::Reshape(
+              m::GetTupleElement(
+                  m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}),
+                  3))
+              .WithShape(BF16, {4, 1024, 1024}))));
+  TF_ASSERT_OK_AND_ASSIGN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+  const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config();
+  EXPECT_EQ(fmha->operands().size(), 4);
+  EXPECT_EQ(fmha->operand(3)->shape(),
+            ShapeUtil::MakeShape(BF16, {1, 4, 1024, 1024}));
+  EXPECT_EQ(config.fmha_scale(), 1.0);
+  EXPECT_EQ(config.dropout_rate(), 0.0);
+  EXPECT_EQ(config.mask_type(), CudnnfMHABackendConfig::NO_MASK);
+}
+}  // anonymous namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc
new file mode 100644
index 0000000..7299643
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc
@@ -0,0 +1,668 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <iterator>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/permutation_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+namespace m = match;
+
+bool IsFMHACustomCall(const HloInstruction* instr) {
+  return IsCustomCallTofMHA(*instr);
+}
+
+bool IsFwdFMHACustomCall(const HloInstruction* instr) {
+  return IsFwdCustomCallTofMHA(*instr);
+}
+
+bool IsBwdFMHACustomCall(const HloInstruction* instr) {
+  return IsBwdCustomCallTofMHA(*instr);
+}
+
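+// Fuses the transpose feeding operand `operand_index` of the given FMHA
+// custom call into the call itself: the transpose's operand becomes the new
+// FMHA operand, and the corresponding dot dimension numbers in the backend
+// config are remapped through the inverse permutation. Returns false if the
+// fusion would violate cuDNN's fastest-moving head-dim constraint.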
+absl::StatusOr<bool> FuseArgPrologueTransposeWithcuDNNFMHA(
+    HloInstruction* fmha, int64_t operand_index, bool is_lhs,
+    bool should_contracting_be_fastest) {
+  HloInstruction* transpose_arg = fmha->mutable_operand(operand_index);
+  HloInstruction* transpose_arg_operand = transpose_arg->mutable_operand(0);
+  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                      fmha->backend_config<GpuBackendConfig>());
+  CudnnfMHABackendConfig config = gpu_config.cudnn_fmha_backend_config();
+  CudnnfMHABackendConfig& new_fmha_config =
+      *gpu_config.mutable_cudnn_fmha_backend_config();
+
+  std::vector<int64_t> inverse_perm =
+      InversePermutation(transpose_arg->dimensions());
+  DotDimensionNumbers new_bmm_dot_dims;
+  if (IsFwdCustomCallTofMHA(*fmha)) {
+    if (operand_index == 0 || operand_index == 1) {
+      new_bmm_dot_dims = config.bmm1_dot_dimension_numbers();
+    } else {
+      new_bmm_dot_dims = config.bmm2_dot_dimension_numbers();
+    }
+  } else {
+    switch (operand_index) {
+      case 0:
+        // Q
+        new_bmm_dot_dims = config.bmm1_grad_gemm1_dot_dimension_numbers();
+        break;
+      case 1:
+        // K
+        new_bmm_dot_dims = config.bmm1_grad_gemm2_dot_dimension_numbers();
+        break;
+      case 2:
+        // V
+        new_bmm_dot_dims = config.bmm2_grad_gemm2_dot_dimension_numbers();
+        break;
+      case 3:
+        // Forward activation
+        new_bmm_dot_dims = config.bmm2_grad_gemm1_dot_dimension_numbers();
+        break;
+      case 4:
+        // Output gradient
+        new_bmm_dot_dims = config.bmm2_grad_gemm2_dot_dimension_numbers();
+        break;
+      default:
+        return Internal("Invalid operand index.");
+    }
+  }
+  absl::Span<const int64_t> checked_dims;
+  std::vector<int64_t> checked_dims_vec;
+
+  // `should_contracting_be_fastest` indicates whether the contracting dim is
+  // the head dim. cuDNN requires the head dim to be the fastest-moving dim,
+  // so fwd bmm1 and bwd bmm2grad1 should set this value to true.
+  if (should_contracting_be_fastest) {
+    checked_dims = is_lhs ? new_bmm_dot_dims.lhs_contracting_dimensions()
+                          : new_bmm_dot_dims.rhs_contracting_dimensions();
+  } else {
+    absl::Span<const int64_t> batch_dims =
+        is_lhs ? new_bmm_dot_dims.lhs_batch_dimensions()
+               : new_bmm_dot_dims.rhs_batch_dimensions();
+    absl::Span<const int64_t> contracting_dims =
+        is_lhs ? new_bmm_dot_dims.lhs_contracting_dimensions()
+               : new_bmm_dot_dims.rhs_contracting_dimensions();
+
+    TF_ASSIGN_OR_RETURN(checked_dims_vec,
+                        GetNonContractingDims(transpose_arg->shape(),
+                                              batch_dims, contracting_dims));
+    checked_dims = checked_dims_vec;
+  }
+
+  int64_t checked_dims_bmm_size = checked_dims.size();
+  std::vector<int64_t> new_bmm_checked_dims(checked_dims_bmm_size);
+  for (int i = 0; i < checked_dims_bmm_size; i++) {
+    auto itr =
+        std::find(inverse_perm.begin(), inverse_perm.end(), checked_dims[i]);
+    if (itr == inverse_perm.end()) {
+      return Internal("Invalid inverse perm");
+    }
+    new_bmm_checked_dims[i] = std::distance(inverse_perm.begin(), itr);
+  }
+  // We want to make sure that making the transpose's argument an input to
+  // fmha doesn't break the cuDNN constraint that the head dim of the
+  // corresponding BMM operand is the fastest-moving dimension.
+  // One exception is the forward activation, which doesn't have this
+  // constraint since it has no head dim.
+  absl::Span<const int64_t> minor_to_major_bmm =
+      transpose_arg_operand->shape().layout().minor_to_major();
+  if ((minor_to_major_bmm[0] != new_bmm_checked_dims[0]) &&
+      !(IsBwdCustomCallTofMHA(*fmha) && operand_index == 3)) {
+    return false;
+  }
+  if (should_contracting_be_fastest) {
+    if (is_lhs) {
+      new_bmm_dot_dims.clear_lhs_contracting_dimensions();
+      *new_bmm_dot_dims.mutable_lhs_contracting_dimensions() = {
+          new_bmm_checked_dims.begin(), new_bmm_checked_dims.end()};
+    } else {
+      new_bmm_dot_dims.clear_rhs_contracting_dimensions();
+      *new_bmm_dot_dims.mutable_rhs_contracting_dimensions() = {
+          new_bmm_checked_dims.begin(), new_bmm_checked_dims.end()};
+    }
+  }
+  auto& batch_dims = is_lhs ? new_bmm_dot_dims.lhs_batch_dimensions()
+                            : new_bmm_dot_dims.rhs_batch_dimensions();
+  int64_t batch_dims_bmm_size = batch_dims.size();
+  std::vector<int64_t> new_bmm_batch_dims(batch_dims_bmm_size);
+  for (int i = 0; i < batch_dims_bmm_size; i++) {
+    auto itr =
+        std::find(inverse_perm.begin(), inverse_perm.end(), batch_dims[i]);
+    if (itr == inverse_perm.end()) {
+      return Internal("Invalid inverse perm");
+    }
+    new_bmm_batch_dims[i] = std::distance(inverse_perm.begin(), itr);
+  }
+
+  if (is_lhs) {
+    new_bmm_dot_dims.clear_lhs_batch_dimensions();
+    *new_bmm_dot_dims.mutable_lhs_batch_dimensions() = {
+        new_bmm_batch_dims.begin(), new_bmm_batch_dims.end()};
+
+  } else {
+    new_bmm_dot_dims.clear_rhs_batch_dimensions();
+    *new_bmm_dot_dims.mutable_rhs_batch_dimensions() = {
+        new_bmm_batch_dims.begin(), new_bmm_batch_dims.end()};
+  }
+
+  if (!should_contracting_be_fastest) {
+    // Given the non-contracting dimensions, we can reuse GetNonContractingDims
+    // to find the new contracting dims: simply pass the non-contracting
+    // dimensions in place of the contracting dimensions.
+    TF_ASSIGN_OR_RETURN(
+        std::vector<int64_t> new_bmm_contracting_dims,
+        GetNonContractingDims(transpose_arg_operand->shape(),
+                              new_bmm_batch_dims, new_bmm_checked_dims));
+    if (is_lhs) {
+      new_bmm_dot_dims.clear_lhs_contracting_dimensions();
+      *new_bmm_dot_dims.mutable_lhs_contracting_dimensions() = {
+          new_bmm_contracting_dims.begin(), new_bmm_contracting_dims.end()};
+
+    } else {
+      new_bmm_dot_dims.clear_rhs_contracting_dimensions();
+      *new_bmm_dot_dims.mutable_rhs_contracting_dimensions() = {
+          new_bmm_contracting_dims.begin(), new_bmm_contracting_dims.end()};
+    }
+  }
+  if (IsFwdCustomCallTofMHA(*fmha)) {
+    if (operand_index == 0 || operand_index == 1) {
+      // Q or K
+      *new_fmha_config.mutable_bmm1_dot_dimension_numbers() = new_bmm_dot_dims;
+    } else {
+      // V
+      *new_fmha_config.mutable_bmm2_dot_dimension_numbers() = new_bmm_dot_dims;
+    }
+  } else {
+    switch (operand_index) {
+      case 0:
+        // Q
+        *new_fmha_config.mutable_bmm1_grad_gemm1_dot_dimension_numbers() =
+            new_bmm_dot_dims;
+        break;
+      case 1:
+        // K
+        *new_fmha_config.mutable_bmm1_grad_gemm2_dot_dimension_numbers() =
+            new_bmm_dot_dims;
+        break;
+      case 2:
+        // V
+        *new_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() =
+            new_bmm_dot_dims;
+        break;
+      case 3:
+        // Forward activation
+        *new_fmha_config.mutable_bmm2_grad_gemm1_dot_dimension_numbers() =
+            new_bmm_dot_dims;
+        break;
+      case 4: {
+        // Output gradient
+        *new_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() =
+            new_bmm_dot_dims;
+        DotDimensionNumbers bmm2_grad_gemm1_dot_dims =
+            config.bmm2_grad_gemm1_dot_dimension_numbers();
+        absl::Span<const int64_t> bmm2_grad_gemm1_contracting_dims =
+            bmm2_grad_gemm1_dot_dims.rhs_contracting_dimensions();
+        CHECK_EQ(bmm2_grad_gemm1_contracting_dims.size(), 1);
+        absl::Span<const int64_t> transpose_permutation =
+            transpose_arg->dimensions();
+        auto itr = std::find(transpose_permutation.begin(),
+                             transpose_permutation.end(),
+                             bmm2_grad_gemm1_contracting_dims[0]);
+        if (itr == transpose_permutation.end()) {
+          return Internal(
+              "bmm2 gradient gemm1 contracting dimension not found.");
+        }
+        int64_t index = std::distance(transpose_permutation.begin(), itr);
+        std::vector<int64_t> new_bmm2_grad_gemm1_rhs_contracting_dims = {index};
+        // Find the new batch dimensions. This is done by passing the new
+        // contracting dimensions and the contracting dimension of the lhs of
+        // bmm2_grad_gemm2 (which is the non-contracting dimension of the rhs
+        // of bmm2_grad_gemm1) to GetNonContractingDims.
+        TF_ASSIGN_OR_RETURN(
+            std::vector<int64_t> new_bmm2_grad_gemm1_rhs_batch_dims,
+            GetNonContractingDims(
+                transpose_arg_operand->shape(),
+                new_bmm2_grad_gemm1_rhs_contracting_dims,
+                new_bmm_dot_dims.lhs_contracting_dimensions()));
+        bmm2_grad_gemm1_dot_dims.clear_rhs_contracting_dimensions();
+        bmm2_grad_gemm1_dot_dims.clear_rhs_batch_dimensions();
+        *bmm2_grad_gemm1_dot_dims.mutable_rhs_contracting_dimensions() = {
+            new_bmm2_grad_gemm1_rhs_contracting_dims.begin(),
+            new_bmm2_grad_gemm1_rhs_contracting_dims.end()};
+        *bmm2_grad_gemm1_dot_dims.mutable_rhs_batch_dimensions() = {
+            new_bmm2_grad_gemm1_rhs_batch_dims.begin(),
+            new_bmm2_grad_gemm1_rhs_batch_dims.end()};
+        *new_fmha_config.mutable_bmm2_grad_gemm1_dot_dimension_numbers() =
+            bmm2_grad_gemm1_dot_dims;
+        break;
+      }
+      default:
+        return Internal("Invalid operand index.");
+    }
+  }
+
+  TF_RETURN_IF_ERROR(fmha->set_backend_config(gpu_config));
+
+  TF_RETURN_IF_ERROR(fmha->ReplaceOperandWithDifferentShape(
+      operand_index, transpose_arg_operand));
+
+  return true;
+}
+
+/* Let's say A is transposed to B with perm {3, 0, 2, 1} as shown below:
+A[16, 256, 32, 64]
+      |
+      |
+      | Transpose with perm = {3, 0, 2, 1}
+      |
+      \/
+B[64, 16, 32, 256]
+
+The inverse perm to obtain A from B would be {1, 3, 2, 0}. That is
+B[64, 16, 32, 256]
+      |
+      |
+      | Transpose' with inv_perm = {1, 3, 2, 0}
+      |
+      \/
+A[16, 256, 32, 64]
+
+Now, let's say B is the lhs of a BatchedMatmul and the lhs_contracting
+dim is 3 (i.e. the dim of size 256). In order to make A the lhs of the
+BatchedMatmul (thus consuming the Transpose from A->B), we need to find
+the dimension number in A that corresponds to dimension number 3 in B.
+This can be done by finding the index of dim num 3 in inv_perm. That
+would be 1. Hence, dim num 3 in B is equivalent to dim num 1 in A. Thus
+the new lhs_contracting dim, if A were to be the new lhs, would be 1.
+
+Similarly, we need to find the corresponding batch dimensions as well.
+*/
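+/* A minimal sketch (not part of the pass) of how that remapping is computed,
+   using the shapes from the example above and the same helpers this file
+   already uses (InversePermutation from xla/permutation_util.h):
+
+     std::vector<int64_t> perm = {3, 0, 2, 1};                  // A -> B
+     std::vector<int64_t> inv_perm = InversePermutation(perm);  // {1, 3, 2, 0}
+     int64_t b_contracting_dim = 3;  // the dim of size 256 in B
+     auto it = std::find(inv_perm.begin(), inv_perm.end(), b_contracting_dim);
+     int64_t a_contracting_dim = std::distance(inv_perm.begin(), it);  // == 1
+*/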
+absl::StatusOr<bool> FusePrologueTransposeWithcuDNNFMHA(HloComputation* comp) {
+  bool changed = false;
+  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction *transpose_arg0, *transpose_arg0_operand;
+    HloInstruction *transpose_arg1, *transpose_arg1_operand;
+    HloInstruction *transpose_arg2, *transpose_arg2_operand;
+    HloInstruction *transpose_arg3, *transpose_arg3_operand;
+    HloInstruction *transpose_arg4, *transpose_arg4_operand;
+
+    HloInstruction* fmha;
+
+    // Arg0 is common between forward and backward fmha calls, so we match
+    // either of these.
+    auto pattern_arg0 =
+        m::Op(&fmha)
+            .WithPredicate(IsFMHACustomCall)
+            .WithOperand(0, m::Transpose(&transpose_arg0,
+                                         m::Op(&transpose_arg0_operand)));
+    if (Match(instr, pattern_arg0)) {
+      if (VLOG_IS_ON(2)) {
+        VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 0: \n"
+                << comp->parent()->ToString();
+      }
+      if (IsFwdFMHACustomCall(fmha)) {
+        // Q tensor in forward graph is lhs with constraint on contracting dim.
+        TF_ASSIGN_OR_RETURN(changed,
+                            FuseArgPrologueTransposeWithcuDNNFMHA(
+                                fmha, 0, true /*is_lhs*/,
+                                true /*should_contracting_be_fastest*/));
+      } else {
+        // Q tensor in backward graph is rhs with constraint on non-contracting
+        // dim.
+        TF_ASSIGN_OR_RETURN(changed,
+                            FuseArgPrologueTransposeWithcuDNNFMHA(
+                                fmha, 0, false /*is_lhs*/,
+                                false /*should_contracting_be_fastest*/));
+      }
+
+      if (changed && VLOG_IS_ON(2)) {
+        VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 0: \n"
+                << comp->parent()->ToString();
+      }
+    }
+
+    // Arg1 is common between forward and backward fmha calls, so we match
+    // either of these.
+    auto pattern_arg1 =
+        m::Op(&fmha)
+            .WithPredicate(IsFMHACustomCall)
+            .WithOperand(1, m::Transpose(&transpose_arg1,
+                                         m::Op(&transpose_arg1_operand)));
+    if (Match(instr, pattern_arg1)) {
+      if (VLOG_IS_ON(2)) {
+        VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 1: \n"
+                << comp->parent()->ToString();
+      }
+      if (IsFwdFMHACustomCall(fmha)) {
+        // K tensor in forward graph is rhs with constraint on contracting dim.
+        TF_ASSIGN_OR_RETURN(changed,
+                            FuseArgPrologueTransposeWithcuDNNFMHA(
+                                fmha, 1, false /*is_lhs*/,
+                                true /*should_contracting_be_fastest*/));
+      } else {
+        // K tensor in backward graph is rhs with constraint on non-contracting
+        // dim.
+        TF_ASSIGN_OR_RETURN(changed,
+                            FuseArgPrologueTransposeWithcuDNNFMHA(
+                                fmha, 1, false /*is_lhs*/,
+                                false /*should_contracting_be_fastest*/));
+      }
+
+      if (changed && VLOG_IS_ON(2)) {
+        VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 1: \n"
+                << comp->parent()->ToString();
+      }
+    }
+
+    // Arg2 is common between forward and backward fmha calls, so we match
+    // either of these.
+    auto pattern_arg2 =
+        m::Op(&fmha)
+            .WithPredicate(IsFMHACustomCall)
+            .WithOperand(2, m::Transpose(&transpose_arg2,
+                                         m::Op(&transpose_arg2_operand)));
+    if (Match(instr, pattern_arg2)) {
+      if (VLOG_IS_ON(2)) {
+        VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 2: \n"
+                << comp->parent()->ToString();
+      }
+      if (IsFwdFMHACustomCall(fmha)) {
+        // V tensor in forward graph is rhs with constraint on non-contracting
+        // dim.
+        TF_ASSIGN_OR_RETURN(changed,
+                            FuseArgPrologueTransposeWithcuDNNFMHA(
+                                fmha, 2, false /*is_lhs*/,
+                                false /*should_contracting_be_fastest*/));
+      } else {
+        // V tensor in backward graph is rhs with constraint on contracting dim.
+        TF_ASSIGN_OR_RETURN(changed,
+                            FuseArgPrologueTransposeWithcuDNNFMHA(
+                                fmha, 2, false /*is_lhs*/,
+                                true /*should_contracting_be_fastest*/));
+      }
+
+      if (changed && VLOG_IS_ON(2)) {
+        VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 2: \n"
+                << comp->parent()->ToString();
+      }
+    }
+
+    // We only care about arg3 of backward
+    auto pattern_arg3 =
+        m::Op(&fmha)
+            .WithPredicate(IsBwdFMHACustomCall)
+            .WithOperand(3, m::Transpose(&transpose_arg3,
+                                         m::Op(&transpose_arg3_operand)));
+    if (Match(instr, pattern_arg3)) {
+      if (VLOG_IS_ON(2)) {
+        VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 3: \n"
+                << comp->parent()->ToString();
+      }
+      // Forward activation tensor in backward graph is lhs with constraint on
+      // non-contracting dim.
+      TF_ASSIGN_OR_RETURN(changed,
+                          FuseArgPrologueTransposeWithcuDNNFMHA(
+                              fmha, 3, true /*is_lhs*/,
+                              false /*should_contracting_be_fastest*/));
+
+      if (changed && VLOG_IS_ON(2)) {
+        VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 3: \n"
+                << comp->parent()->ToString();
+      }
+    }
+
+    // We only care about arg4 of backward
+    auto pattern_arg4 =
+        m::Op(&fmha)
+            .WithPredicate(IsBwdFMHACustomCall)
+            .WithOperand(4, m::Transpose(&transpose_arg4,
+                                         m::Op(&transpose_arg4_operand)));
+    if (Match(instr, pattern_arg4)) {
+      if (VLOG_IS_ON(2)) {
+        VLOG(2) << "Before CudnnFusedMHATransposeFusion Arg 4: \n"
+                << comp->parent()->ToString();
+      }
+      // D_output tensor in backward graph is lhs with constraint on
+      // contracting dim.
+      // Make sure we don't change the layout of dO in the flash attention
+      // case, as dO should have the same layout as O.
+      TF_ASSIGN_OR_RETURN(auto gpu_config,
+                          fmha->backend_config<GpuBackendConfig>());
+      if (changed && VLOG_IS_ON(2)) {
+        VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 4: \n"
+                << comp->parent()->ToString();
+      }
+    }
+  }
+  return changed;
+}
+
+/* Let's say FMHA out is transposed to result with perm {1, 2, 0, 3} as shown
+below: FMHA_out[b0, b1, n, m]{}
+      |
+      |
+  Transpose with perm = {1, 2, 0, 3}
+      |
+      \/
+result[b1, n, b0, m]{1, 0, 3, 2}
+The goal is to find the minor_to_major of 'FMHA_out' such that its physical
+layout matches the physical layout of 'result', thus eliminating the need for an
+explicit transpose. cuDNN can perform an implicit transpose by knowing the
+corresponding strides (inferred from the corresponding minor_to_major).
+
+In order to find the required minor_to_major of 'FMHA_out', we first determine
+the inverse perm to obtain 'FMHA_out' from 'result'. The function
+"ShapeUtil::PermuteDimensions" generates a transposed shape such that the
+physical layout of the transposed shape is equivalent to the input shape.
+Calling this function with 'result' shape as the input shape and the inverse
+perm as the permutation will generate an output shape whose dimensions match
+'FMHA_out' dimensions but the physical layout is equivalent to 'result'. This is
+exactly what we want.
+
+The FMHA output should have exactly one gte instruction for a given tuple index
+so that we can safely fuse the transpose following that gte into the FMHA call:
+
+FMHA_out = gte(FMHA, index=0)
+FMHA_out_t = transpose(FMHA_out)
+use(FMHA_out_t)
+
+after fusion:
+
+FMHA_out_t = gte(FMHA, index=0)
+use(FMHA_out_t)
+*/
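+/* A minimal sketch (not part of the pass) of the shape computation described
+   above; `result_shape` is a placeholder for the shape of 'result', and the
+   helpers are the same ones used in the pass below:
+
+     // 'result' was obtained by transposing FMHA_out with perm = {1, 2, 0, 3}.
+     std::vector<int64_t> inv_perm =
+         InversePermutation({1, 2, 0, 3});  // -> {2, 0, 1, 3}
+     // Dimensions match FMHA_out; physical layout matches 'result'.
+     Shape expected_fmha_shape =
+         ShapeUtil::PermuteDimensions(inv_perm, result_shape);
+*/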
+
+absl::StatusOr<bool> FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) {
+  bool changed = false;
+
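+  // Returns true iff exactly one get-tuple-element user of `instr` reads the
+  // given tuple index; only then can the following transpose be fused safely.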
+  auto only_one_gte_with_spec_index = [](const HloInstruction* instr,
+                                         int64_t index) {
+    int count = 0;
+    for (auto user : instr->users()) {
+      if (user->opcode() == HloOpcode::kGetTupleElement &&
+          user->tuple_index() == index) {
+        count += 1;
+      }
+    }
+    return count == 1;
+  };
+
+  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+    HloInstruction* fmha;
+    HloInstruction* transpose;
+    HloInstruction* gte;
+    auto fwd_tuple_elem =
+        m::GetTupleElement(&gte,
+                           m::Op(&fmha).WithPredicate(IsFwdFMHACustomCall), 0)
+            .WithOneUser();
+    // Note that we don't match any specific tuple index in the matcher for
+    // the backward pass.
+    auto bwd_tuple_elem =
+        m::GetTupleElement(&gte,
+                           m::Op(&fmha).WithPredicate(IsBwdFMHACustomCall))
+            .WithOneUser();
+    auto fwd_pattern = m::Transpose(&transpose, fwd_tuple_elem);
+    auto bwd_pattern = m::Transpose(&transpose, bwd_tuple_elem);
+
+    if (Match(instr, fwd_pattern)) {
+      // Check that only one gte with this tuple index exists.
+      int64_t tuple_index = gte->tuple_index();
+      if (!only_one_gte_with_spec_index(fmha, tuple_index)) continue;
+
+      std::vector<int64_t> inverse_perm =
+          InversePermutation(transpose->dimensions());
+
+      auto expected_fmha_shape =
+          ShapeUtil::PermuteDimensions(inverse_perm, transpose->shape());
+
+      // cuDNN requires the last dimension of the output to be the fastest
+      // moving.
+      if (expected_fmha_shape.layout().minor_to_major()[0] !=
+          expected_fmha_shape.dimensions_size() - 1) {
+        VLOG(3) << "cuDNN requires the last dimension of the FMHA output to be "
+                   "the fastest moving. The last dimension is dim: "
+                << expected_fmha_shape.dimensions_size() - 1
+                << " but the upon fusion with transpose, the fmha output shape "
+                   "would have been "
+                << expected_fmha_shape.ToString(true)
+                << " and the fastest moving "
+                   "dimension would be dim: "
+                << expected_fmha_shape.layout().minor_to_major()[0];
+        continue;
+      }
+      Shape call_shape = fmha->shape();
+      *call_shape.mutable_tuple_shapes(0) = expected_fmha_shape;
+      HloInstruction* new_fmha_custom_call =
+          comp->AddInstruction(HloInstruction::CreateCustomCall(
+              call_shape, fmha->operands(),
+              absl::string_view(fmha->custom_call_target())));
+
+      TF_ASSIGN_OR_RETURN(GpuBackendConfig config,
+                          fmha->backend_config<GpuBackendConfig>());
+      TF_RETURN_IF_ERROR(new_fmha_custom_call->set_backend_config(config));
+      TF_RETURN_IF_ERROR(
+          SetFMHAInstructionName(fmha->GetModule(), new_fmha_custom_call));
+      new_fmha_custom_call->set_metadata(fmha->metadata());
+
+      auto gte = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+          new_fmha_custom_call->shape().tuple_shapes(0), new_fmha_custom_call,
+          0));
+      TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+          instr, HloInstruction::CreateBitcast(transpose->shape(), gte)));
+      TF_RETURN_IF_ERROR(fmha->ReplaceAllUsesWith(new_fmha_custom_call));
+
+      if (VLOG_IS_ON(2)) {
+        VLOG(2) << "After forward FuseEpilogueTransposeWithcuDNNFMHA: \n"
+                << comp->parent()->ToString();
+      }
+      changed |= true;
+    } else if (Match(instr, bwd_pattern)) {
+      // Check that only one gte with this tuple index exists.
+      int64_t operand_tuple_idx = gte->tuple_index();
+      if (!only_one_gte_with_spec_index(fmha, operand_tuple_idx)) continue;
+
+      std::vector<int64_t> inverse_perm =
+          InversePermutation(transpose->dimensions());
+
+      auto expected_fmha_shape =
+          ShapeUtil::PermuteDimensions(inverse_perm, transpose->shape());
+
+      // cuDNN requires the last dimension of the output to be the fastest
+      // moving.
+      if (expected_fmha_shape.layout().minor_to_major()[0] !=
+          expected_fmha_shape.dimensions_size() - 1) {
+        VLOG(3) << "cuDNN requires the last dimension of the FMHA output to be "
+                   "the fastest moving. The last dimension is dim: "
+                << expected_fmha_shape.dimensions_size() - 1
+                << " but the upon fusion with transpose, the fmha output shape "
+                   "would have been "
+                << expected_fmha_shape.ToString(true)
+                << " and the fastest moving "
+                   "dimension would be dim: "
+                << expected_fmha_shape.layout().minor_to_major()[0];
+        continue;
+      }
+      Shape call_shape = fmha->shape();
+      *call_shape.mutable_tuple_shapes(operand_tuple_idx) = expected_fmha_shape;
+      HloInstruction* new_fmha_custom_call =
+          comp->AddInstruction(HloInstruction::CreateCustomCall(
+              call_shape, fmha->operands(),
+              absl::string_view(fmha->custom_call_target())));
+
+      TF_ASSIGN_OR_RETURN(GpuBackendConfig config,
+                          fmha->backend_config<GpuBackendConfig>());
+      TF_RETURN_IF_ERROR(new_fmha_custom_call->set_backend_config(config));
+      TF_RETURN_IF_ERROR(
+          SetFMHAInstructionName(fmha->GetModule(), new_fmha_custom_call));
+      new_fmha_custom_call->set_metadata(fmha->metadata());
+      TF_RETURN_IF_ERROR(fmha->ReplaceAllUsesWith(new_fmha_custom_call));
+
+      auto gte = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+          new_fmha_custom_call->shape().tuple_shapes(operand_tuple_idx),
+          new_fmha_custom_call, operand_tuple_idx));
+      TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
+          instr, HloInstruction::CreateBitcast(transpose->shape(), gte)));
+
+      if (VLOG_IS_ON(2)) {
+        VLOG(2) << "After backward FuseEpilogueTransposeWithcuDNNFMHA: \n"
+                << comp->parent()->ToString();
+      }
+      changed |= true;
+    }
+  }
+  return changed;
+}
+}  // namespace
+
+absl::StatusOr<bool> CudnnFusedMHATransposeFusion::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool any_changed = false;
+  for (HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    bool changed = false;
+    TF_ASSIGN_OR_RETURN(changed, FusePrologueTransposeWithcuDNNFMHA(comp));
+    any_changed |= changed;
+    TF_ASSIGN_OR_RETURN(changed, FuseEpilogueTransposeWithcuDNNFMHA(comp));
+    any_changed |= changed;
+  }
+
+  return any_changed;
+}
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h
new file mode 100644
index 0000000..825d97e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h
@@ -0,0 +1,45 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+class CudnnFusedMHATransposeFusion : public HloModulePass {
+ public:
+  CudnnFusedMHATransposeFusion() = default;
+
+  absl::string_view name() const override {
+    return "cudnn-fused-multi-headed-attention-transpose-fusion";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc
new file mode 100644
index 0000000..519b495
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc
@@ -0,0 +1,739 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "third_party/gpus/cudnn/cudnn_version.h"
+#include "xla/comparison_util.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/primitive_util.h"
+#include "xla/service/dump.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cudnn_support_utils.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/kernel_reuse_cache.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/gpu/triton_fusion_analysis.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/cuda/cuda_dnn.h"
+#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/stream_executor/stream_executor.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+namespace fe = cudnn_frontend;
+namespace graph = fe::graph;
+
+inline std::optional<fe::PointwiseMode_t> GetElementwiseMode(
+    const HloInstruction& instruction) {
+  const HloOpcode opcode = instruction.opcode();
+  using m = fe::PointwiseMode_t;
+  switch (opcode) {
+    case HloOpcode::kAbs:
+      return m::ABS;
+    case HloOpcode::kAdd:
+      return m::ADD;
+    case HloOpcode::kCeil:
+      return m::CEIL;
+    case HloOpcode::kCompare:
+      switch (instruction.comparison_direction()) {
+        case Comparison::Direction::kEq:
+          return m::CMP_EQ;
+        case Comparison::Direction::kNe:
+          return m::CMP_NEQ;
+        case Comparison::Direction::kGe:
+          return m::CMP_GE;
+        case Comparison::Direction::kGt:
+          return m::CMP_GT;
+        case Comparison::Direction::kLe:
+          return m::CMP_LE;
+        case Comparison::Direction::kLt:
+          return m::CMP_LT;
+      }
+      break;
+    case HloOpcode::kConvert:
+      return m::IDENTITY;
+    case HloOpcode::kCos:
+      return m::COS;
+    case HloOpcode::kDivide:
+      return m::DIV;
+    case HloOpcode::kExp:
+      return m::EXP;
+    case HloOpcode::kFloor:
+      return m::FLOOR;
+    case HloOpcode::kLog:
+      return m::LOG;
+    case HloOpcode::kMaximum:
+      return m::MAX;
+    case HloOpcode::kMinimum:
+      return m::MIN;
+    case HloOpcode::kMultiply:
+      return m::MUL;
+    case HloOpcode::kNegate:
+      return m::NEG;
+    case HloOpcode::kPower:
+      return m::POW;
+    case HloOpcode::kRsqrt:
+      return m::RSQRT;
+#if CUDNN_VERSION >= 90100
+    case HloOpcode::kSelect:
+      return m::BINARY_SELECT;
+#endif  // CUDNN_VERSION
+    case HloOpcode::kSin:
+      return m::SIN;
+    case HloOpcode::kSqrt:
+      return m::SQRT;
+    case HloOpcode::kSubtract:
+      return m::SUB;
+    case HloOpcode::kTan:
+      return m::TAN;
+    case HloOpcode::kTanh:
+      return m::TANH_FWD;
+    default:
+      return std::nullopt;
+  }
+}
+
+inline std::optional<fe::DataType_t> ToCudnnDataType(const PrimitiveType type) {
+  using t = fe::DataType_t;
+  switch (type) {
+    case PrimitiveType::F32:
+      return t::FLOAT;
+    case PrimitiveType::F16:
+      return t::HALF;
+    case PrimitiveType::BF16:
+      return t::BFLOAT16;
+    case PrimitiveType::S32:
+      return t::INT32;
+    case PrimitiveType::S8:
+      return t::INT8;
+    case PrimitiveType::PRED:
+      return t::INT8;
+    case PrimitiveType::F8E5M2:
+      return t::FP8_E5M2;
+    case PrimitiveType::F8E4M3FN:
+      return t::FP8_E4M3;
+    default:
+      return std::nullopt;
+  }
+}
+
+inline std::optional<fe::DataType_t> GetComputeDataType(
+    const PrimitiveType type) {
+  fe::DataType_t compute_dtype = fe::DataType_t::FLOAT;
+  if (primitive_util::IsIntegralType(type)) {
+#if CUDNN_VERSION >= 90100
+    compute_dtype = fe::DataType_t::INT32;
+#else
+    VLOG(3) << "Integer math requires cuDNN 9.1+.";
+    return std::nullopt;
+#endif  // CUDNN_VERSION
+  }
+  return compute_dtype;
+}
+
+int FusionLevel(const HloInstruction& hlo) {
+  return hlo.GetModule()
+      ->config()
+      .debug_options()
+      .xla_gpu_cudnn_gemm_fusion_level();
+};
+
+// Extracts dimensions and strides from HLO tensors in the format expected by
+// cuDNN.
+class GemmDimensionAdapter {
+  explicit GemmDimensionAdapter(const HloDotInstruction& dot,
+                                TritonFusionAnalysis analysis)
+      : analysis_(std::move(analysis)), dot_(dot) {};
+
+ public:
+  const TritonFusionAnalysis analysis_;
+
+  static absl::StatusOr<std::optional<GemmDimensionAdapter>> Create(
+      const HloComputation& computation) {
+    const HloInstruction* maybe_dot =
+        hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot);
+    if (maybe_dot == nullptr) {
+      VLOG(3) << "Not a GEMM fusion.";
+      return std::nullopt;
+    }
+    const HloDotInstruction* dot = DynCast<HloDotInstruction>(
+        hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot));
+    if (absl::c_any_of(dot->precision_config().operand_precision(),
+                       [](int x) { return x != PrecisionConfig::DEFAULT; })) {
+      VLOG(3) << "Non-default precision is not supported.";
+      return std::nullopt;
+    }
+    TF_ASSIGN_OR_RETURN(auto analysis,
+                        TritonFusionAnalysis::Execute(computation));
+    return GemmDimensionAdapter{*dot, std::move(analysis)};
+  }
+
+  bool DimensionsAndStrides(const HloInstruction& hlo,
+                            const TritonFusionAnalysis::Scope scope,
+                            std::vector<int64_t>& dimensions,
+                            std::vector<int64_t>& strides) {
+    const DotDimensionNumbers& dims = dot_.dot_dimension_numbers();
+    // GEMM fusions require a specific canonical order of dimensions.
+    constexpr int kBatchDimensionIndex = 0;
+    constexpr int kOutputLHSNonContractingDimensionIndex = 1;
+    std::vector<int64_t> dim_indices;
+    int lhs_noncontracting_index = -1;
+    switch (scope) {
+      case TritonFusionAnalysis::Scope::LHS:
+        lhs_noncontracting_index =
+            GetNonContractingDims(dot_.operand(0)->shape(),
+                                  dims.lhs_batch_dimensions(),
+                                  dims.lhs_contracting_dimensions())
+                .value()[0];
+        dim_indices = {
+            dims.lhs_batch_dimensions().empty() ? -1
+                                                : dims.lhs_batch_dimensions(0),
+            lhs_noncontracting_index, dims.lhs_contracting_dimensions(0)};
+        break;
+      case TritonFusionAnalysis::Scope::RHS:
+        dim_indices = {dims.rhs_batch_dimensions().empty()
+                           ? -1
+                           : dims.rhs_batch_dimensions(0),
+                       dims.rhs_contracting_dimensions(0),
+                       GetNonContractingDims(dot_.operand(1)->shape(),
+                                             dims.rhs_batch_dimensions(),
+                                             dims.rhs_contracting_dimensions())
+                           .value()[0]};
+        break;
+      case TritonFusionAnalysis::Scope::OUTPUT:
+        lhs_noncontracting_index = dot_.shape().rank() - 2;
+        dim_indices = {dims.lhs_batch_dimensions().empty() ? -1 : 0,
+                       lhs_noncontracting_index, dot_.shape().rank() - 1};
+        break;
+      case TritonFusionAnalysis::Scope::META:
+        LOG(FATAL) << "Unsupported scope.";
+    }
+    dimensions.reserve(dim_indices.size());
+    strides.reserve(dim_indices.size());
+    for (const int index : dim_indices) {
+      const auto* spec = analysis_.IterSpec(scope, &hlo, index);
+      if (spec == nullptr) {
+        dimensions.push_back(1);
+        strides.push_back(strides.empty() ? 1 : strides.back());
+        continue;
+      } else {
+        if (spec->size() == 1) {
+          // The dimension is not split, nothing to do.
+        } else if (spec->size() == 2) {
+          if (FusionLevel(hlo) < 3) {
+            return false;
+          }
+          if (!dims.lhs_batch_dimensions().empty()) {
+            VLOG(8) << "Noncontracting dimension split is not compatible with "
+                       "batch dimensions.";
+            return false;
+          }
+          if (index != lhs_noncontracting_index) {
+            VLOG(8) << "Only LHS noncontracting dimension can be split.";
+            return false;
+          }
+          switch (scope) {
+            case TritonFusionAnalysis::Scope::LHS:
+              lhs_noncontracting_split_ = spec->back().count;
+              break;
+            case TritonFusionAnalysis::Scope::OUTPUT:
+              if (lhs_noncontracting_split_ != spec->back().count) {
+                VLOG(8) << "Output non-contracting dimension has to be split "
+                           "the same way as the LHS input one if it is split.";
+                return false;
+              }
+              break;
+            default:
+              VLOG(8) << "Only LHS noncontracting dimension can be split.";
+              return false;
+          }
+          // Assign the major part of the noncontracting dimension to the
+          // unused batch one.
+          CHECK_EQ(dimensions[kBatchDimensionIndex], 1);
+          dimensions[kBatchDimensionIndex] = spec->back().count;
+          strides[kBatchDimensionIndex] = spec->back().stride;
+        } else {
+          VLOG(8) << "The dimension is split multiple times.";
+          return false;
+        }
+        dimensions.push_back(spec->front().count);
+        strides.push_back(spec->front().stride);
+      }
+    }
+    if (lhs_noncontracting_split_ > 1 &&
+        scope == TritonFusionAnalysis::Scope::OUTPUT &&
+        dimensions[kBatchDimensionIndex] == 1) {
+      // LHS input noncontracting dimension is split but the corresponding
+      // output one is not. Assign part of the output one to the unused batch
+      // dimension.
+      dimensions[kBatchDimensionIndex] = lhs_noncontracting_split_;
+      dimensions[kOutputLHSNonContractingDimensionIndex] /=
+          lhs_noncontracting_split_;
+      strides[kBatchDimensionIndex] =
+          strides[kOutputLHSNonContractingDimensionIndex] *
+          dimensions[kOutputLHSNonContractingDimensionIndex];
+    }
+    return true;
+  }
+
+ private:
+  int64_t lhs_noncontracting_split_ = 1;
+  const HloDotInstruction& dot_;
+};
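+// Illustrative sketch only (hypothetical shapes, not derived from any
+// particular fusion): for an unsplit, contiguous lhs of logical shape
+// [batch = 4, m = 16, k = 32], DimensionsAndStrides for Scope::LHS would
+// typically report dimensions = {4, 16, 32} and strides = {512, 32, 1}, i.e.
+// the canonical [batch, non-contracting, contracting] order expected by
+// cuDNN. A missing dimension (e.g. no batch dimension) is reported with size
+// 1 and stride 1, or the previously pushed stride if one exists.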
+
+template <PrimitiveType XlaT, typename T>
+std::shared_ptr<graph::Tensor_attributes> LiteralToCudnnTensor(
+    const HloInstruction& hlo, graph::Graph& graph) {
+  using NativeT = typename primitive_util::PrimitiveTypeToNative<XlaT>::type;
+  return graph.tensor(T(hlo.literal().GetFirstElement<NativeT>()));
+}
+
+std::optional<std::shared_ptr<graph::Tensor_attributes>>
+HandleConstantHloToCudnnGraph(const HloInstruction& hlo, graph::Graph& graph) {
+  CHECK(hlo.IsConstant()) << "HLO is not a constant: " << hlo.ToShortString();
+  if (!ShapeUtil::IsScalar(hlo.shape())) {
+    VLOG(3) << "Currently only support fusing scalar in the graph";
+    return std::nullopt;
+  }
+  PrimitiveType constant_type = hlo.shape().element_type();
+  switch (constant_type) {
+    case BF16:
+      return LiteralToCudnnTensor<BF16, __nv_bfloat16>(hlo, graph);
+    case F32:
+      return LiteralToCudnnTensor<F32, float>(hlo, graph);
+    case S32:
+      return LiteralToCudnnTensor<S32, int>(hlo, graph);
+    default:
+      VLOG(3) << "Unsupported constant type: "
+              << PrimitiveType_Name(constant_type);
+      return std::nullopt;
+  }
+}
+
+std::optional<std::shared_ptr<graph::Tensor_attributes>>
+HandleClampToCudnnGraph(
+    const HloInstruction& hlo, graph::Graph& graph,
+    absl::flat_hash_map<const HloInstruction*,
+                        std::shared_ptr<graph::Tensor_attributes>>
+        hlo_to_cudnn,
+    fe::DataType_t data_type, fe::DataType_t compute_dtype) {
+  CHECK(hlo.opcode() == HloOpcode::kClamp)
+      << "HLO is not a clamp: " << hlo.ToShortString();
+  CHECK(hlo.operands().size() == 3)
+      << "Clamp requires to have 3 operands: " << hlo.ToShortString();
+  // clamp = max(lower, min(value, upper));
+  const auto min_attrs = graph::Pointwise_attributes()
+                             .set_mode(fe::PointwiseMode_t::MIN)
+                             .set_compute_data_type(compute_dtype);
+  std::shared_ptr<graph::Tensor_attributes> min_tensor = graph.pointwise(
+      hlo_to_cudnn[hlo.operand(1)], hlo_to_cudnn[hlo.operand(2)], min_attrs);
+  min_tensor->set_data_type(data_type).set_name(std::string(hlo.name()));
+  const auto max_attrs = graph::Pointwise_attributes()
+                             .set_mode(fe::PointwiseMode_t::MAX)
+                             .set_compute_data_type(compute_dtype);
+  return graph.pointwise(min_tensor, hlo_to_cudnn[hlo.operand(0)], max_attrs);
+}
+
+// Traverses the computation of a fusion and creates a cuDNN graph out of it.
+absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
+    const HloFusionInstruction& fusion) {
+  const HloComputation& computation = *fusion.fused_instructions_computation();
+  VLOG(5) << fusion.ToString();
+  VLOG(5) << computation.ToString();
+  graph::Graph graph;
+  std::vector<HloInstruction*> instructions =
+      computation.MakeInstructionPostOrder();
+  absl::flat_hash_map<const HloInstruction*,
+                      std::shared_ptr<graph::Tensor_attributes>>
+      hlo_to_cudnn;
+  TF_ASSIGN_OR_RETURN(std::optional<GemmDimensionAdapter> adapter,
+                      GemmDimensionAdapter::Create(computation));
+  if (!adapter.has_value()) {
+    return std::nullopt;
+  }
+  auto add_parameter = [&](const HloInstruction& parameter,
+                           std::vector<int64_t>& dimensions,
+                           std::vector<int64_t> strides) {
+    const std::optional<fe::DataType_t> data_type =
+        ToCudnnDataType(parameter.shape().element_type());
+    if (!data_type.has_value()) {
+      VLOG(3) << "Unsupported data type.";
+      return false;
+    }
+    hlo_to_cudnn[&parameter] = graph.tensor(
+        graph::Tensor_attributes()
+            .set_dim(dimensions)
+            .set_stride(strides)
+            .set_data_type(*data_type)
+            .set_name(std::string(parameter.name()))
+            .set_uid(se::gpu::CuDnnTensorUID(parameter.parameter_number())));
+    return true;
+  };
+  for (const TritonFusionAnalysis::Scope scope :
+       {TritonFusionAnalysis::Scope::LHS, TritonFusionAnalysis::Scope::RHS,
+        TritonFusionAnalysis::Scope::OUTPUT}) {
+    for (const HloInstruction* parameter :
+         adapter->analysis_.ScopeParameters(scope)) {
+      std::vector<int64_t> dimensions;
+      std::vector<int64_t> strides;
+      if (!adapter->DimensionsAndStrides(*parameter, scope, dimensions,
+                                         strides)) {
+        VLOG(3) << "Unsupported dimensions.";
+        return std::nullopt;
+      }
+      if (!add_parameter(*parameter, dimensions, strides)) {
+        return std::nullopt;
+      }
+    }
+  }
+
+  for (const HloInstruction* hlo : instructions) {
+    VLOG(5) << hlo->ToShortString();
+    auto operand = [&hlo_to_cudnn, &hlo](int i) {
+      return hlo_to_cudnn[hlo->operand(i)];
+    };
+    const auto data_type = ToCudnnDataType(hlo->shape().element_type());
+    if (!data_type.has_value()) {
+      VLOG(3) << "Unimplemented data type: " << hlo->shape().element_type();
+      return std::nullopt;
+    }
+    if (hlo->opcode() == HloOpcode::kParameter) {
+      CHECK(hlo_to_cudnn.contains(hlo));
+      continue;
+    } else if (hlo->opcode() == HloOpcode::kCustomCall) {
+      if (hlo->user_count() != 1 ||
+          !IsWorkspaceAllocationRoot(*hlo->users()[0])) {
+        VLOG(3) << "Custom calls are only expected to be used for workspace "
+                   "allocation.";
+        return std::nullopt;
+      }
+      continue;
+    } else if (hlo->opcode() == HloOpcode::kTuple) {
+      if (!IsWorkspaceAllocationRoot(*hlo)) {
+        VLOG(3) << "Tuples are only expected at outputs for workspace "
+                   "allocation.";
+        return std::nullopt;
+      }
+      continue;
+    } else if (FusionLevel(fusion) >= 2 &&
+               hlo->opcode() == HloOpcode::kConstant) {
+      if (const auto const_tensor = HandleConstantHloToCudnnGraph(*hlo, graph);
+          const_tensor.has_value()) {
+        hlo_to_cudnn[hlo] = const_tensor.value();
+      } else {
+        return std::nullopt;
+      }
+    } else if (hlo->opcode() == HloOpcode::kReshape ||
+               hlo->opcode() == HloOpcode::kBitcast ||
+               hlo->opcode() == HloOpcode::kTranspose ||
+               hlo->opcode() == HloOpcode::kCopy ||
+               (FusionLevel(fusion) >= 2 &&
+                hlo->opcode() == HloOpcode::kBroadcast)) {
+      // All these are accounted for separately as transformations of strides.
+      hlo_to_cudnn[hlo] = operand(0);
+    } else if (hlo->IsElementwise()) {
+      const auto compute_dtype =
+          GetComputeDataType(hlo->shape().element_type());
+      if (!compute_dtype.has_value()) {
+        return std::nullopt;
+      }
+      if (hlo->opcode() == HloOpcode::kClamp) {
+        const auto clamp =
+            HandleClampToCudnnGraph(*hlo, graph, hlo_to_cudnn,
+                                    data_type.value(), compute_dtype.value());
+        if (!clamp.has_value()) {
+          return std::nullopt;
+        }
+        hlo_to_cudnn[hlo] = clamp.value();
+      } else {
+        const auto mode = GetElementwiseMode(*hlo);
+        if (!mode.has_value()) {
+          VLOG(3) << "Unsupported elementwise operation.";
+          return std::nullopt;
+        }
+        const auto attrs = graph::Pointwise_attributes()
+                               .set_mode(mode.value())
+                               .set_compute_data_type(compute_dtype.value());
+        if (hlo->operand_count() == 1) {
+          hlo_to_cudnn[hlo] = graph.pointwise(operand(0), attrs);
+          // Set the dimensions of unary ops whose operands are broadcasts so
+          // that cuDNN can infer their input shapes. A constant has dimension
+          // [1] while cuDNN requires it to have dimension [1,1,1]. Not
+          // setting the output shape of such unary ops results in the
+          // rejection of the cuDNN graph.
+          if (hlo->operand(0)->opcode() == HloOpcode::kBroadcast) {
+            const auto scope = adapter->analysis_.QueryInstructionScope(*hlo);
+            std::vector<int64_t> dimensions;
+            std::vector<int64_t> strides;
+            if (!scope.has_value()) {
+              LOG(FATAL) << "No scope for instruction: "
+                         << hlo->ToShortString();
+            }
+            if (!adapter->DimensionsAndStrides(*hlo, scope.value(), dimensions,
+                                               strides)) {
+              VLOG(3) << "Unsupported hlo for querying dimensions: "
+                      << hlo->ToShortString();
+            } else {
+              hlo_to_cudnn[hlo]->set_dim(dimensions);
+            }
+          }
+        } else if (hlo->operand_count() == 2) {
+          hlo_to_cudnn[hlo] = graph.pointwise(operand(0), operand(1), attrs);
+        } else if (hlo->operand_count() == 3) {
+          if (hlo->opcode() != HloOpcode::kSelect) {
+            VLOG(3) << "Unexpected ternary operation: " << hlo->ToString();
+            return std::nullopt;
+          }
+          // Operand order for select differs between HLO and cuDNN.
+          hlo_to_cudnn[hlo] =
+              graph.pointwise(operand(1), operand(2), operand(0), attrs);
+        } else {
+          VLOG(3) << "Unimplemented elementwise operation.";
+          return std::nullopt;
+        }
+      }
+    } else if (hlo->opcode() == HloOpcode::kDot) {
+      const auto compute_dtype =
+          GetComputeDataType(hlo->shape().element_type());
+      if (!compute_dtype.has_value()) {
+        return std::nullopt;
+      }
+      hlo_to_cudnn[hlo] =
+          graph.matmul(operand(0), operand(1),
+                       graph::Matmul_attributes().set_compute_data_type(
+                           compute_dtype.value()));
+    } else {
+      VLOG(3) << "Unimplemented operation.";
+      return std::nullopt;
+    }
+    if (hlo_to_cudnn[hlo] == nullptr) {
+      VLOG(3) << "Creation of the operation failed.";
+      return std::nullopt;
+    }
+    hlo_to_cudnn[hlo]
+        ->set_data_type(data_type.value())
+        .set_name(std::string(hlo->name()));
+  }
+  const HloInstruction* output = instructions.back();
+  if (instructions.back()->shape().IsTuple()) {
+    output = instructions.back()->operand(0);
+  }
+  std::vector<int64_t> dimensions;
+  std::vector<int64_t> strides;
+  if (!adapter->DimensionsAndStrides(
+          *output, TritonFusionAnalysis::Scope::OUTPUT, dimensions, strides)) {
+    VLOG(3) << "Unsupported dimensions.";
+    return std::nullopt;
+  }
+  hlo_to_cudnn[output]
+      ->set_output(true)
+      .set_dim(dimensions)
+      .set_stride(strides)
+      .set_uid(se::gpu::CuDnnTensorUID(fusion.operand_count()));
+  if (!fusion.GetModule()->config().debug_options().xla_dump_to().empty()) {
+    json dump;
+    graph.serialize(dump);
+    DumpToFileInDirOrStdout(
+        /*module=*/*fusion.GetModule(),
+        /*file_prefix=*/"",
+        /*file_suffix=*/
+        absl::StrCat("cudnn_fusion_", fusion.name(), ".json"),
+        /*contents=*/dump.dump(1));
+  }
+
+  return se::gpu::CudnnGraph(std::move(graph));
+}
+
+// Creates a cuDNN graph and queries cuDNN whether it is supported.
+absl::StatusOr<se::gpu::CudnnGraph> PrepareGraph(
+    se::dnn::DnnSupport& dnn_support, const HloFusionInstruction& hlo) {
+  TF_ASSIGN_OR_RETURN(std::optional<se::gpu::CudnnGraph> graph,
+                      HloFusionToCuDnnGraph(hlo));
+  if (!graph.has_value()) {
+    return absl::InternalError("Construction of cuDNN graph failed.");
+  }
+  TF_RETURN_IF_ERROR(graph->Prepare(
+      dnn_support,
+      se::NumericOptions{RequireDeterminism(hlo.GetModule()->config()),
+                         /*allow_tf32=*/true}));
+  return *graph;
+}
+
+absl::StatusOr<HloInstruction*> AddWorkspace(HloInstruction& fusion,
+                                             const int64_t workspace_size) {
+  HloComputation* computation = fusion.fused_instructions_computation();
+  HloInstruction* custom_call =
+      computation->AddInstruction(HloInstruction::CreateCustomCall(
+          ShapeUtil::MakeShape(S8, {workspace_size}), {},
+          kWorkspaceAllocationCustomCallTarget));
+  HloInstruction* output_tuple =
+      computation->AddInstruction(HloInstruction::CreateTuple(
+          {computation->root_instruction(), custom_call}));
+  computation->set_root_instruction(output_tuple, true);
+  HloInstruction* new_fusion = fusion.parent()->AddInstruction(
+      fusion.CloneWithNewShape(output_tuple->shape()));
+  TF_RETURN_IF_ERROR(fusion.ReplaceAllUsesWith(fusion.parent()->AddInstruction(
+      HloInstruction::CreateGetTupleElement(new_fusion, 0))));
+  TF_RETURN_IF_ERROR(fusion.parent()->RemoveInstruction(&fusion));
+  return new_fusion;
+}
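+// Rough sketch of the rewrite performed by AddWorkspace (names illustrative):
+//   before:  %fusion = f32[...] fusion(...), kind=kCustom
+//   after:   %fusion = (f32[...], s8[workspace_size]) fusion(...)
+// The fused computation's new root is a tuple of the old root and a
+// custom-call to the workspace-allocation target, and former users of the
+// fusion now consume get-tuple-element(%fusion), index=0.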
+
+class CuDnnFusionVisitor : public DfsHloRewriteVisitor {
+ public:
+  explicit CuDnnFusionVisitor(se::dnn::DnnSupport& dnn_support,
+                              BinaryMap& compilation_results)
+      : dnn_support_(dnn_support), compilation_results_(compilation_results) {}
+
+  absl::Status HandleFusion(HloInstruction* hlo) override {
+    TF_ASSIGN_OR_RETURN(auto gpu_config,
+                        hlo->backend_config<GpuBackendConfig>());
+    const auto& fusion_backend_config = gpu_config.fusion_backend_config();
+    if (fusion_backend_config.kind() != kCuDnnFusionKind) {
+      return absl::OkStatus();
+    }
+    int64_t plan_id = -1;
+    if (fusion_backend_config.has_cudnn_fusion_config()) {
+      plan_id = fusion_backend_config.cudnn_fusion_config().plan_id();
+    }
+
+    VLOG(4) << "Processing " << hlo->ToString();
+    VLOG(4) << "Plan ID: " << plan_id;
+
+    auto add_workspace = [&](const int64_t workspace_size) {
+      if (workspace_size > 0) {
+        TF_ASSIGN_OR_RETURN(hlo, AddWorkspace(*hlo, workspace_size));
+        SetVisited(*hlo);
+      }
+      return absl::OkStatus();
+    };
+    const std::string fingerprint_without_workspace =
+        GetComputationFingerprint(hlo->fused_instructions_computation(), {});
+    auto workspace_size_it =
+        workspace_sizes_.find(fingerprint_without_workspace);
+    if (workspace_size_it == workspace_sizes_.cend()) {
+      TF_ASSIGN_OR_RETURN(
+          se::gpu::CudnnGraph graph,
+          PrepareGraph(dnn_support_, *DynCast<HloFusionInstruction>(hlo)));
+
+      if (plan_id >= 0) {
+        // Build single plan with given ID.
+        if (plan_id >= graph.Graph().get_execution_plan_count()) {
+          return absl::InternalError("cuDNN graph plan does not exist.");
+        }
+        TF_RETURN_IF_ERROR(graph.Build(dnn_support_, plan_id));
+      } else {
+        // When no plan_id was provided, build plans one by one until the
+        // first one that succeeds.
+        for (plan_id = 0; plan_id < graph.Graph().get_execution_plan_count();
+             ++plan_id) {
+          VLOG(7) << "Trying plan ID " << plan_id;
+          if (graph.Build(dnn_support_, plan_id).ok()) {
+            VLOG(7) << "Successfully built plan ID " << plan_id;
+            break;
+          }
+        }
+        if (plan_id == graph.Graph().get_execution_plan_count()) {
+          return absl::InternalError("No cuDNN plans can be built.");
+        }
+      }
+      const int64_t workspace_size = graph.Graph().get_workspace_size();
+      workspace_sizes_.insert(workspace_size_it,
+                              {fingerprint_without_workspace, workspace_size});
+      TF_RETURN_IF_ERROR(add_workspace(workspace_size));
+
+      std::vector<uint8_t> serialized_graph;
+      RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph));
+      // Compute a new fingerprint that includes the potential workspace so
+      // that the compilation results match the fingerprint computed by the
+      // emitter.
+      compilation_results_[GetComputationFingerprint(
+          hlo->fused_instructions_computation(), {})] =
+          std::string(reinterpret_cast<char*>(serialized_graph.data()),
+                      serialized_graph.size());
+    } else {
+      VLOG(4) << "Cache hit.";
+      TF_RETURN_IF_ERROR(add_workspace(workspace_size_it->second));
+    }
+    auto cudnn_config = gpu_config.mutable_fusion_backend_config()
+                            ->mutable_cudnn_fusion_config();
+    cudnn_config->set_plan_id(plan_id);
+    TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config));
+
+    MarkAsChanged();
+    return absl::OkStatus();
+  }
+
+ private:
+  se::dnn::DnnSupport& dnn_support_;
+  // <HLO computation fingerprint, serialized compiled cuDNN graph>.
+  BinaryMap& compilation_results_;
+  absl::flat_hash_map<std::string, int64_t> workspace_sizes_;
+};
+
+}  // namespace
+
+absl::StatusOr<bool> CuDnnFusionCompiler::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  XLA_SCOPED_LOGGING_TIMER("cuDNN fusion compiler");
+  return CuDnnFusionVisitor(dnn_support_, compilation_results_)
+      .RunOnModule(module, execution_threads);
+}
+
+int CuDnnFusionCompiler::GetAvailablePlanCount(
+    se::StreamExecutor& stream_exec, const HloFusionInstruction& hlo) {
+  auto graph = PrepareGraph(*stream_exec.AsDnn(), hlo);
+  if (!graph.ok()) {
+    return 0;
+  }
+  return std::min(
+      static_cast<int32_t>(graph->Graph().get_execution_plan_count()),
+      hlo.GetModule()->config().debug_options().xla_gpu_cudnn_gemm_max_plans());
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.h b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.h
new file mode 100644
index 0000000..4917914
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.h
@@ -0,0 +1,59 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSION_COMPILER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSION_COMPILER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/stream_executor/stream_executor.h"
+
+namespace xla {
+namespace gpu {
+
+// Converts HLO fusions with cuDNN backend config to cuDNN graphs,
+// compiles them using a cuDNN handle and serializes them.
+class CuDnnFusionCompiler : public HloModulePass {
+ public:
+  explicit CuDnnFusionCompiler(se::StreamExecutor& stream_exec,
+                               BinaryMap& compilation_results)
+      : dnn_support_(*stream_exec.AsDnn()),
+        compilation_results_(compilation_results) {}
+
+  absl::string_view name() const override { return "cudnn-fusion-compiler"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+  static int GetAvailablePlanCount(se::StreamExecutor& stream_exec,
+                                   const HloFusionInstruction& hlo);
+
+ private:
+  se::dnn::DnnSupport& dnn_support_;
+  BinaryMap& compilation_results_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSION_COMPILER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc
new file mode 100644
index 0000000..5d5e089
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc
@@ -0,0 +1,1553 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_norm_rewriter.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/wrappers.pb.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/types.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/protobuf/dnn.pb.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"  // IWYU pragma: keep
+#include "third_party/gpus/cudnn/cudnn.h"        // IWYU pragma: keep
+#include "third_party/gpus/cudnn/cudnn_version.h"
+#endif
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+namespace m = match;
+
+// Traverses the graph upward starting at instr and returns the
+// first instruction that is not a convert, bitcast or reshape.
+const HloInstruction* SkipUnaryOps(const HloInstruction* instr) {
+  while (instr->opcode() == HloOpcode::kConvert ||
+         instr->opcode() == HloOpcode::kBitcast ||
+         instr->opcode() == HloOpcode::kReshape) {
+    instr = instr->operand(0);
+  }
+  return instr;
+}
+
+// Recursively traverses the graph downward starting at instr and stores in
+// instrs the users that are not a convert, bitcast or reshape.
+void SkipUnaryOpsTopDownRecursive(HloInstruction* instr,
+                                  std::vector<HloInstruction*>& instrs) {
+  if (instr->opcode() == HloOpcode::kConvert ||
+      instr->opcode() == HloOpcode::kBitcast ||
+      instr->opcode() == HloOpcode::kReshape) {
+    for (HloInstruction* user : instr->users()) {
+      SkipUnaryOpsTopDownRecursive(user, instrs);
+    }
+  } else {
+    instrs.emplace_back(instr);
+  }
+}
+
+// Holds auxiliary information about individual layer norm patterns rewritten
+// into a cuDNN Custom Call.
+struct NormMetadata {
+  // Transposes applied to the input and output of the forward layer norm to
+  // order the normalization and non-normalization dimensions as required by
+  // cuDNN. Nullptr if no transposes were inserted.
+  HloInstruction *x_transpose, *y_transpose;
+  // The reduction and non-reduction dimensions of the input into the forward
+  // layer norm before the potential application of transposes and adjusted for
+  // the removal of any degenerate dimensions in the input to the norm.
+  std::vector<int64_t> norm_dims_adjusted, non_norm_dims_adjusted;
+};
+
+// Map from the instruction pointer of a layer norm Custom Call to its metadata.
+using NormMetadataMap = absl::flat_hash_map<HloInstruction*, NormMetadata>;
+
+// Captures multiple HloInstruction pointers and verifies that their target
+// is identical.
+//
+// Example:
+// Pattern cos(x) / sin(x) with cos and sin intended to operate on the same
+// HloInstruction:
+//  UniqueHloInstruction x;
+//  bool m = Match(
+//      instr,
+//      m::Divide(m::Cos(m::Op().WithPredicate(x.GetCaptureOrVerifyFn())),
+//                m::Sin(m::Op().WithPredicate(x.GetCaptureOrVerifyFn()))));
+// m is true and x.Instr() returns an HloInstruction pointer to the operand of
+// cosine and sine iff HloInstruction *instr points to a division of a cosine by
+// a sine that both operate on the same instruction.
+class UniqueHloInstruction {
+ public:
+  UniqueHloInstruction()
+      : is_set_(false), instr_(nullptr), capture_or_verify_() {}
+  HloInstruction* Instr() const { return instr_; }
+  void SetInstr(HloInstruction* instr) {
+    is_set_ = true;
+    instr_ = instr;
+  }
+
+  // Stores instr when invoked the first time. Otherwise, compares instr to the
+  // stored value and sets the stored value to nullptr if the comparison fails.
+  bool CaptureOrVerify(HloInstruction* instr) {
+    if (is_set_ && instr != instr_) {
+      instr_ = nullptr;
+    }
+    if (!is_set_) {
+      is_set_ = true;
+      instr_ = instr;
+    }
+    return instr_;
+  }
+
+  // Returns a std::function for capturing or verifying an instruction using
+  // WithPredicate.
+  std::function<bool(const HloInstruction*)> GetCaptureOrVerifyFn() {
+    if (!capture_or_verify_) {
+      capture_or_verify_ = [this](const HloInstruction* instr) -> bool {
+        return CaptureOrVerify(const_cast<HloInstruction*>(instr));
+      };
+    }
+    return capture_or_verify_;
+  }
+
+ private:
+  bool is_set_;
+  HloInstruction* instr_;
+  std::function<bool(const HloInstruction*)> capture_or_verify_;
+};
+
+// Returns an architecture-specific constant for the calculation of an upper
+// bound for the size of the scratch space for layer norm kernels.
+absl::StatusOr<int64_t> CConstant(
+    se::CudaComputeCapability cuda_compute_capability) {
+  if (cuda_compute_capability.major == se::CudaComputeCapability::AMPERE) {
+    return 32 * 128;
+  } else if (cuda_compute_capability.major ==
+             se::CudaComputeCapability::HOPPER) {
+    return 32 * 144;
+  }
+  return xla::Internal("Norm kernels require Ampere or Hopper architecture.");
+}
+
+// Returns whether the element type of instr is compatible with layer norm
+// kernels.
+bool CompatibleElementType(const HloInstruction* instr) {
+  PrimitiveType element_type = instr->shape().element_type();
+  return element_type == BF16 || element_type == F16 || element_type == F32;
+}
+
+// Returns the dimensions associated with shape, adjusted for the removal of any
+// degenerate dimensions in shape. Specifically, for each dimension d in
+// dimensions, returns the new index of d if all dimensions of size 1 are
+// removed from shape. If d has size 1, it is not included in the returned
+// vector.
+std::vector<int64_t> AdjustedDimensions(const Shape& shape,
+                                        absl::Span<const int64_t> dimensions) {
+  absl::flat_hash_map<int64_t, int64_t> dimension_map;
+  for (int64_t dimension = 0, non_degen_dimension = 0; dimension < shape.rank();
+       ++dimension) {
+    if (shape.dimensions(dimension) > 1) {
+      dimension_map.insert({dimension, non_degen_dimension});
+      non_degen_dimension++;
+    }
+  }
+  std::vector<int64_t> adjusted_dimensions;
+  for (int64_t dimension : dimensions) {
+    auto non_degenerate_dimension = dimension_map.find(dimension);
+    if (non_degenerate_dimension != dimension_map.end()) {
+      adjusted_dimensions.emplace_back(non_degenerate_dimension->second);
+    }
+  }
+  return adjusted_dimensions;
+}
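+// For example, for shape [1, 4, 1, 8] and dimensions {1, 3}, dropping the
+// degenerate dimensions 0 and 2 gives {0, 1}; for dimensions {0, 3} the
+// result is {1}, since dimension 0 has size 1 and is therefore omitted.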
+
+// Returns the dimensions of broadcast or reduction instructions, adjusted for
+// the removal of any degenerate dimensions in the output or input.
+std::vector<int64_t> AdjustedDimensions(const HloInstruction* instr) {
+  Shape shape;
+  if (instr->opcode() == HloOpcode::kBroadcast) {
+    shape = instr->shape();
+  } else if (instr->opcode() == HloOpcode::kReduce) {
+    shape = instr->operand(0)->shape();
+  } else {
+    return {};
+  }
+  return AdjustedDimensions(shape, instr->dimensions());
+}
+
+// Returns whether the HLO Computation applied by instr calculates the sum of
+// the elements. When provided, compares reduce_dims to the dimensions of the
+// reduction.
+bool AppliesAddReduce(const HloInstruction* instr,
+                      absl::Span<const int64_t> reduce_dims = {}) {
+  if (instr->opcode() != HloOpcode::kReduce) {
+    return false;
+  }
+
+  // Verify the dimensions of the reduction.
+  if (!reduce_dims.empty() && AdjustedDimensions(instr) != reduce_dims) {
+    return false;
+  }
+
+  HloComputation* reduce_comp = instr->to_apply();
+  HloInstruction* reduce_comp_root = reduce_comp->root_instruction();
+  return instr->operand_count() == 2 &&
+         instr->operand(1)->opcode() == HloOpcode::kConstant &&
+         ShapeUtil::IsScalar(instr->operand(1)->shape()) &&
+         instr->operand(1)->literal().GetAsDouble({}) == 0. &&
+         reduce_comp_root->opcode() == HloOpcode::kAdd &&
+         reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter &&
+         reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter;
+}
+
+// Returns whether instr multiplies the result of a reduction by one over the
+// number of reduced elements.
+bool CalculatesExpectation(const HloInstruction* instr) {
+  instr = SkipUnaryOps(instr);
+  if (instr->opcode() != HloOpcode::kMultiply) {
+    return false;
+  }
+  bool bcast_operand = instr->operand(0)->opcode() != HloOpcode::kBroadcast;
+  const HloInstruction *broadcast = instr->operand(bcast_operand),
+                       *reduce = SkipUnaryOps(instr->operand(!bcast_operand));
+  if (reduce->opcode() != HloOpcode::kReduce ||
+      broadcast->opcode() != HloOpcode::kBroadcast ||
+      broadcast->operand(0)->opcode() != HloOpcode::kConstant) {
+    return false;
+  }
+
+  float actual_r_nelems =
+      broadcast->operand(0)->literal().GetAsDouble({}).value();
+  int64_t nelems = 1;
+  for (int64_t norm_dim : reduce->dimensions()) {
+    nelems *= reduce->operand(0)->shape().dimensions()[norm_dim];
+  }
+  // The absolute value of the difference between the actual scaling factor
+  // and the reference value must not exceed a prescribed threshold.
+  float r_nelems = 1. / static_cast<float>(nelems);
+  float numerical_epsilon = std::numeric_limits<bfloat16>::epsilon();
+  return abs(actual_r_nelems - r_nelems) <
+         ((actual_r_nelems + r_nelems) * numerical_epsilon);
+}
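+// For example, for a reduction over 1024 elements the broadcast constant must
+// be close to 1/1024 ~ 0.0009766. A constant of 0.001 would differ from the
+// reference by about 2.3e-5, which exceeds the tolerance
+// (actual + reference) * epsilon ~ 1.5e-5 given the bfloat16 epsilon of 2^-7,
+// so the pattern would be rejected.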
+
+// Returns whether target can be reached from instr by recursively traversing
+// the graph across converts, bitcasts and reshapes.
+bool FindTargetRecursive(
+    const HloInstruction* instr, const HloInstruction* target,
+    absl::flat_hash_set<const HloInstruction*>& visited_instrs,
+    const HloInstruction* transpose) {
+  visited_instrs.emplace(instr);
+  const absl::flat_hash_set<HloOpcode> supported_ops = {
+      HloOpcode::kConvert, HloOpcode::kBitcast, HloOpcode::kReshape};
+  if (instr == target) {
+    return true;
+  }
+  // Look for target among the users of instr.
+  for (HloInstruction* user : instr->users()) {
+    if ((supported_ops.contains(user->opcode()) || user == transpose) &&
+        !visited_instrs.contains(user)) {
+      return FindTargetRecursive(user, target, visited_instrs, transpose);
+    }
+  }
+  // Ascend the graph if target is not found and instr is a convert, bitcast
+  // or reshape.
+  if (supported_ops.contains(instr->opcode())) {
+    return FindTargetRecursive(instr->operand(0), target, visited_instrs,
+                               transpose);
+  }
+  return false;
+}
+
+bool FindTarget(const HloInstruction* custom_call, const HloInstruction* instr,
+                const HloInstruction* target,
+                const NormMetadataMap& norm_metadata) {
+  absl::flat_hash_set<const HloInstruction*> visited_instrs;
+  auto custom_call_metadata = norm_metadata.find(custom_call);
+  if (custom_call_metadata == norm_metadata.end()) {
+    return false;
+  }
+  return FindTargetRecursive(instr, target, visited_instrs,
+                             custom_call_metadata->second.x_transpose);
+}
+
+// Maps the dimension numbers in dimensions from shape original_shape to shape
+// reshaped_shape, assuming that the shapes are related through a strict
+// reshape. Returns an empty vector if a dimension mapping is not found.
+std::vector<int64_t> MapDimensions(const Shape& original_shape,
+                                   const Shape& reshaped_shape,
+                                   const absl::Span<const int64_t> dimensions) {
+  auto dimension_product =
+      [](const Shape& shape,
+         absl::Span<const int64_t> product_dimensions) -> int64_t {
+    int64_t product = 1;
+    for (int64_t product_dimension : product_dimensions) {
+      product *= shape.dimensions(product_dimension);
+    }
+    return product;
+  };
+  // Construct the dimension mapping.
+  absl::flat_hash_map<int64_t, std::vector<int64_t>> dimensions_map;
+  std::vector<int64_t> original_dimensions, reshaped_dimensions;
+  for (int64_t original_dimension = 0, reshaped_dimension = 0;
+       original_dimension < original_shape.rank(); ++original_dimension) {
+    original_dimensions.emplace_back(original_dimension);
+    while ((reshaped_dimensions.empty() ||
+            dimension_product(reshaped_shape, reshaped_dimensions) <
+                dimension_product(original_shape, original_dimensions)) &&
+           reshaped_dimension < reshaped_shape.rank()) {
+      reshaped_dimensions.emplace_back(reshaped_dimension++);
+    }
+
+    // Many-to-many dimension mappings are not supported.
+    if (original_dimensions.size() > 1 && reshaped_dimensions.size() > 1) {
+      return {};
+    }
+
+    if (dimension_product(original_shape, original_dimensions) ==
+        dimension_product(reshaped_shape, reshaped_dimensions)) {
+      std::vector<int64_t> original_dimensions_in_dimensions;
+      std::set_intersection(
+          original_dimensions.begin(), original_dimensions.end(),
+          dimensions.begin(), dimensions.end(),
+          std::back_inserter(original_dimensions_in_dimensions));
+      // The unique mapping of dimensions requires either all or none of the
+      // entries of original_dimensions to be an element of dimensions.
+      if (!original_dimensions_in_dimensions.empty() &&
+          original_dimensions_in_dimensions.size() !=
+              original_dimensions.size()) {
+        return {};
+      }
+      for (int64_t dimension : original_dimensions) {
+        dimensions_map.insert({dimension, reshaped_dimensions});
+      }
+      original_dimensions.clear();
+      reshaped_dimensions.clear();
+    }
+  }
+
+  // Map the dimension numbers to the reshaped shape.
+  std::vector<int64_t> mapped_dimensions;
+  for (int64_t dimension : dimensions) {
+    auto mapped_dimension = dimensions_map.find(dimension);
+    if (mapped_dimension == dimensions_map.end()) {
+      return {};
+    }
+    mapped_dimensions.insert(mapped_dimensions.end(),
+                             mapped_dimension->second.begin(),
+                             mapped_dimension->second.end());
+  }
+
+  // Eliminate duplicates in the mapped dimension numbers.
+  mapped_dimensions.erase(
+      std::unique(mapped_dimensions.begin(), mapped_dimensions.end()),
+      mapped_dimensions.end());
+  return mapped_dimensions;
+}
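+// For example, mapping dimensions {0, 1} of an original shape [2, 3, 4] onto
+// the reshaped shape [6, 4] yields {0}: dimensions 0 and 1 collapse into the
+// single reshaped dimension 0, and since all collapsed dimensions are
+// requested the mapping is unique. Requesting only {1} returns {} because
+// dimension 1 collapses into reshaped dimension 0 together with dimension 0,
+// which is not requested.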
+
+// Recursively traverses the graph across converts, bitcasts and reshapes,
+// starting from instr, and returns the first addition-reduction identified.
+// Returns nullptr if no addition-reduction is found.
+HloInstruction* FindAddReduceRecursive(
+    HloInstruction* instr, const Shape& orig_instr_shape,
+    const absl::Span<const int64_t> reduce_dims,
+    absl::flat_hash_set<HloInstruction*>& visited_instrs) {
+  visited_instrs.emplace(instr);
+  const absl::flat_hash_set<HloOpcode> supported_ops = {
+      HloOpcode::kConvert, HloOpcode::kBitcast, HloOpcode::kReshape};
+  // Look for a reduction among the users of instr.
+  for (HloInstruction* user : instr->users()) {
+    if (user->opcode() == HloOpcode::kReduce) {
+      std::vector<int64_t> mapped_reduce_dims =
+          MapDimensions(orig_instr_shape, instr->shape(), reduce_dims);
+      if (!mapped_reduce_dims.empty() &&
+          AppliesAddReduce(user, mapped_reduce_dims)) {
+        return user;
+      }
+    }
+    if (supported_ops.contains(user->opcode()) &&
+        !visited_instrs.contains(user)) {
+      return FindAddReduceRecursive(user, orig_instr_shape, reduce_dims,
+                                    visited_instrs);
+    }
+  }
+  // Ascend the graph if the addition-reduction is not found and instr is a
+  // convert, bitcast or reshape.
+  if (supported_ops.contains(instr->opcode())) {
+    return FindAddReduceRecursive(instr->mutable_operand(0), orig_instr_shape,
+                                  reduce_dims, visited_instrs);
+  }
+  return nullptr;
+}
+
+HloInstruction* FindAddReduce(HloInstruction* instr,
+                              const absl::Span<const int64_t> reduce_dims) {
+  absl::flat_hash_set<HloInstruction*> visited_instrs;
+  return FindAddReduceRecursive(instr, instr->shape(), reduce_dims,
+                                visited_instrs);
+}
+
+// Type conversion from and to any of BF16, FP16 and FP32.
+template <typename Pattern>
+auto SupportedConvert(Pattern pattern) {
+  auto supported_convert = [](const HloInstruction* instr) -> bool {
+    return CompatibleElementType(instr) &&
+           CompatibleElementType(instr->operand(0));
+  };
+  return m::Convert(pattern).WithPredicate(supported_convert);
+}
+
+// Bitcast or reshape adding or removing degenerate dimensions.
+template <typename Pattern>
+auto SupportedBitcastOrReshape(Pattern pattern) {
+  auto supported_bitcast_or_reshape = [](const HloInstruction* instr) -> bool {
+    return ShapeUtil::Equal(
+        ShapeUtil::DropDegenerateDimensions(instr->shape()),
+        ShapeUtil::DropDegenerateDimensions(instr->operand(0)->shape()));
+  };
+  return m::AnyOf<HloInstruction>(
+      m::Bitcast(pattern).WithPredicate(supported_bitcast_or_reshape),
+      m::Reshape(pattern).WithPredicate(supported_bitcast_or_reshape));
+}
+
+// Matches pattern, SupportedConvert(pattern),
+// SupportedBitcastOrReshape(pattern),
+// SupportedConvert(SupportedBitcastOrReshape(pattern)) and
+// SupportedBitcastOrReshape(SupportedConvert(pattern)).
+template <typename Pattern>
+auto OptionalSupportedTransform(Pattern pattern) {
+  auto shared_subpattern = m::SharedSubpattern(pattern);
+  return m::AnyOf<HloInstruction>(
+      SupportedConvert(SupportedBitcastOrReshape(shared_subpattern)),
+      SupportedBitcastOrReshape(SupportedConvert(shared_subpattern)),
+      SupportedConvert(shared_subpattern),
+      SupportedBitcastOrReshape(shared_subpattern), shared_subpattern);
+}
+
+// Bitcast or reshape with optional supported type conversion and/or addition or
+// removal of degenerate dimensions.
+template <typename Pattern>
+auto BitcastOrReshape(Pattern pattern) {
+  return OptionalSupportedTransform(
+      m::AnyOf<HloInstruction>(m::Bitcast(pattern), m::Reshape(pattern)));
+}
+
+// Transpose with optional supported type conversion and/or addition or removal
+// of degenerate dimensions.
+template <typename Pattern>
+auto Transpose(Pattern pattern) {
+  return OptionalSupportedTransform(m::Transpose(pattern));
+}
+
+// Rsqrt with optional supported type conversion and/or addition or removal of
+// degenerate dimensions.
+template <typename Pattern>
+auto Rsqrt(HloInstruction** rsqrt, Pattern pattern) {
+  return OptionalSupportedTransform(m::Rsqrt(rsqrt, pattern));
+}
+
+// AddAnyOrder with optional supported type conversion and/or addition or
+// removal of degenerate dimensions.
+template <typename Pattern0, typename Pattern1>
+auto AddAnyOrder(Pattern0 pattern0, Pattern1 pattern1) {
+  return OptionalSupportedTransform(m::AddAnyOrder(pattern0, pattern1));
+}
+
+// Subtract with optional supported type conversion and/or addition or removal
+// of degenerate dimensions.
+template <typename Pattern0, typename Pattern1>
+auto Subtract(Pattern0 pattern0, Pattern1 pattern1) {
+  return OptionalSupportedTransform(m::Subtract(pattern0, pattern1));
+}
+
+// Capturing subtract with optional supported type conversion and/or addition or
+// removal of degenerate dimensions.
+template <typename Pattern0, typename Pattern1>
+auto Subtract(HloInstruction** subtract, Pattern0 pattern0, Pattern1 pattern1) {
+  return OptionalSupportedTransform(m::Subtract(subtract, pattern0, pattern1));
+}
+
+// Multiply with optional supported type conversion and/or addition or removal
+// of degenerate dimensions.
+template <typename Pattern0, typename Pattern1>
+auto MultiplyAnyOrder(Pattern0 pattern0, Pattern1 pattern1) {
+  return OptionalSupportedTransform(m::MultiplyAnyOrder(pattern0, pattern1));
+}
+
+// Capturing multiply with optional supported type conversion and/or addition or
+// removal of degenerate dimensions.
+template <typename Pattern0, typename Pattern1>
+auto MultiplyAnyOrder(HloInstruction** multiply, Pattern0 pattern0,
+                      Pattern1 pattern1) {
+  return OptionalSupportedTransform(
+      m::MultiplyAnyOrder(multiply, pattern0, pattern1));
+}
+
+// Multiplication of pattern by itself with optional supported type conversion
+// and/or addition or removal of degenerate dimensions.
+template <typename Pattern>
+auto Square(Pattern pattern) {
+  return MultiplyAnyOrder(pattern, pattern)
+      .WithPredicate([](const HloInstruction* instr) {
+        return instr->unique_operands().size() == 1;
+      });
+}
+
+// Multiplication of the square of pattern by pattern with optional supported
+// type conversion and/or addition or removal of degenerate dimensions. The root
+// instruction of pattern cannot be a multiplication.
+template <typename Pattern>
+auto Cube(Pattern pattern) {
+  auto unique_cube = [](const HloInstruction* instr) -> bool {
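+    // square_operand is the index of the operand produced by Square, i.e. the
+    // operand that is itself a multiplication. The other operand must not be a
+    // multiplication and must be identical to the base of the square.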
+    bool square_operand = instr->operand(0)->opcode() != HloOpcode::kMultiply;
+    return instr->operand(!square_operand)->opcode() != HloOpcode::kMultiply &&
+           instr->operand(square_operand)->operand(0) ==
+               instr->operand(!square_operand);
+  };
+  return MultiplyAnyOrder(Square(pattern), pattern).WithPredicate(unique_cube);
+}
+
+// Addition-reduction of pattern with optional supported type conversion and/or
+// addition or removal of degenerate dimensions.
+template <typename Pattern>
+auto AddReduce(Pattern pattern) {
+  return OptionalSupportedTransform(
+      m::Reduce(pattern, m::Op())
+          .WithPredicate([](const HloInstruction* instr) {
+            return AppliesAddReduce(instr);
+          }));
+}
+
+// Capturing addition-reduction of pattern with optional supported type
+// conversion and/or addition or removal of degenerate dimensions.
+template <typename Pattern>
+auto AddReduce(HloInstruction** reduction, Pattern pattern) {
+  return OptionalSupportedTransform(
+      m::Reduce(reduction, pattern, m::Op())
+          .WithPredicate([](const HloInstruction* instr) {
+            return AppliesAddReduce(instr);
+          }));
+}
+
+// Negated addition-reduction.
+template <typename Pattern>
+auto NegateAddReduce(HloInstruction** reduction, Pattern pattern) {
+  return m::AnyOf<HloInstruction>(AddReduce(reduction, m::Negate(pattern)),
+                                  m::Negate(AddReduce(reduction, pattern)));
+}
+
+// Expected value, or mean, with optional broadcast.
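+// The mean is matched as Broadcast(ConstantScalar) * AddReduce(pattern), with
+// CalculatesExpectation verifying that the scaling constant is consistent with
+// an expectation, i.e. one over the number of reduced elements.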
+template <typename Pattern>
+auto Expectation(Pattern pattern) {
+  auto shared_subpattern =
+      MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), AddReduce(pattern))
+          .WithPredicate([](const HloInstruction* instr) {
+            return CalculatesExpectation(instr);
+          });
+  return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
+                                  shared_subpattern);
+}
+
+// Expected value, or mean, with optional broadcast.
+template <typename Pattern>
+auto Expectation(UniqueHloInstruction* expectation, Pattern pattern) {
+  auto shared_subpattern = OptionalSupportedTransform(
+      m::MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), AddReduce(pattern))
+          .WithPredicate([](const HloInstruction* instr) {
+            return CalculatesExpectation(instr);
+          })
+          .WithPredicate(expectation->GetCaptureOrVerifyFn()));
+  return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
+                                  shared_subpattern);
+}
+
+// Expected value, or mean, with optional broadcast.
+template <typename Pattern>
+auto Expectation(UniqueHloInstruction* expectation, HloInstruction** reduce,
+                 Pattern pattern) {
+  auto shared_subpattern = OptionalSupportedTransform(
+      m::MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()),
+                          AddReduce(reduce, pattern))
+          .WithPredicate([](const HloInstruction* instr) {
+            return CalculatesExpectation(instr);
+          })
+          .WithPredicate(expectation->GetCaptureOrVerifyFn()));
+  return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
+                                  shared_subpattern);
+}
+
+// Variance, expressed as expectation(X^2) - expectation(X)^2 or
+// expectation((X - expectation(X))^2).
+auto Variance(UniqueHloInstruction* variance, UniqueHloInstruction* expectation,
+              UniqueHloInstruction* x) {
+  return m::AnyOf<HloInstruction>(
+      Subtract(
+          Expectation(Square(OptionalSupportedTransform(
+              m::Op().WithPredicate(x->GetCaptureOrVerifyFn())))),
+          Square(Expectation(expectation,
+                             OptionalSupportedTransform(m::Op().WithPredicate(
+                                 x->GetCaptureOrVerifyFn())))))
+          .WithPredicate(variance->GetCaptureOrVerifyFn()),
+      Expectation(
+          Square(Subtract(
+              OptionalSupportedTransform(
+                  m::Op().WithPredicate(x->GetCaptureOrVerifyFn())),
+              Expectation(expectation,
+                          OptionalSupportedTransform(m::Op().WithPredicate(
+                              x->GetCaptureOrVerifyFn()))))))
+          .WithPredicate(variance->GetCaptureOrVerifyFn()));
+}
+
+// Reciprocal of the square root of variance + epsilon with optional broadcast.
+auto NormFactor(HloInstruction** norm_factor, UniqueHloInstruction* x,
+                UniqueHloInstruction* variance,
+                UniqueHloInstruction* expectation,
+                UniqueHloInstruction* epsilon) {
+  auto shared_subpattern = m::SharedSubpattern(Rsqrt(
+      norm_factor, AddAnyOrder(Variance(variance, expectation, x),
+                               m::Broadcast(m::ConstantScalar().WithPredicate(
+                                   epsilon->GetCaptureOrVerifyFn())))));
+  return m::AnyOf<HloInstruction>(m::Broadcast(shared_subpattern),
+                                  shared_subpattern);
+}
+
+// Any order of p0 * p1 * p2.
+template <typename P0, typename P1, typename P2>
+auto MultiplyMultiplyAnyOrder(P0 p0, P1 p1, P2 p2) {
+  return m::AnyOf<HloInstruction>(
+      MultiplyAnyOrder(p0, MultiplyAnyOrder(p1, p2)),
+      MultiplyAnyOrder(p1, MultiplyAnyOrder(p0, p2)),
+      MultiplyAnyOrder(p2, MultiplyAnyOrder(p0, p1)));
+}
+
+// Any order of p0 + p1 + p2.
+template <typename P0, typename P1, typename P2>
+auto AddAddAnyOrder(P0 p0, P1 p1, P2 p2) {
+  return m::AnyOf<HloInstruction>(AddAnyOrder(p0, AddAnyOrder(p1, p2)),
+                                  AddAnyOrder(p1, AddAnyOrder(p0, p2)),
+                                  AddAnyOrder(p2, AddAnyOrder(p0, p1)));
+}
+
+// Any order of p0 * (p1 + p2).
+template <typename P0, typename P1, typename P2>
+auto MultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2) {
+  return m::AnyOf<HloInstruction>(
+      MultiplyAnyOrder(p0, AddAnyOrder(p1, p2)),
+      AddAnyOrder(MultiplyAnyOrder(p0, p1), MultiplyAnyOrder(p0, p2)));
+}
+
+// Any order of p0 - p1 + p2.
+template <typename P0, typename P1, typename P2>
+auto SubtractAddAnyOrder(P0 p0, P1 p1, P2 p2) {
+  return m::AnyOf<HloInstruction>(AddAnyOrder(Subtract(p0, p1), p2),
+                                  AddAnyOrder(Subtract(p2, p1), p0),
+                                  Subtract(AddAnyOrder(p0, p2), p1));
+}
+
+// Any order of (p0 - p1) * p2 * p3 + p4.
+template <typename P0, typename P1, typename P2, typename P3, typename P4>
+auto SubtractMultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2, P3 p3, P4 p4) {
+  return m::AnyOf<HloInstruction>(
+      SubtractAddAnyOrder(MultiplyMultiplyAnyOrder(p0, p2, p3),
+                          MultiplyMultiplyAnyOrder(p1, p2, p3), p4),
+      AddAnyOrder(MultiplyMultiplyAnyOrder(Subtract(p0, p1), p2, p3), p4));
+}
+
+// Expectation fused into a layer norm Custom Call.
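+// In training mode (LAYER_FWD_TRAIN), the Custom Call returns the tuple
+// (Y, expectation, norm_factor, workspace); the expectation is tuple element 1.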
+auto FusedExpectation(UniqueHloInstruction* custom_call) {
+  auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
+      m::CustomCall({kCudnnNormCallTarget})
+          .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
+      1));
+  return m::AnyOf<HloInstruction>(shared_subpattern,
+                                  BitcastOrReshape(shared_subpattern));
+}
+
+// Expectation fused into a layer norm Custom Call.
+auto FusedExpectation(UniqueHloInstruction* fused_expectation,
+                      UniqueHloInstruction* custom_call) {
+  auto shared_subpattern = m::SharedSubpattern(
+      m::GetTupleElement(
+          m::CustomCall({kCudnnNormCallTarget})
+              .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
+          1)
+          .WithPredicate(fused_expectation->GetCaptureOrVerifyFn()));
+  return m::AnyOf<HloInstruction>(shared_subpattern,
+                                  BitcastOrReshape(shared_subpattern));
+}
+
+// Norm factor fused into a layer norm Custom Call.
+auto FusedNormFactor(UniqueHloInstruction* custom_call) {
+  auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement(
+      m::CustomCall({kCudnnNormCallTarget})
+          .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
+      2));
+  return m::AnyOf<HloInstruction>(shared_subpattern,
+                                  BitcastOrReshape(shared_subpattern));
+}
+
+// Norm factor fused into a layer norm Custom Call.
+auto FusedNormFactor(UniqueHloInstruction* fused_norm_factor,
+                     UniqueHloInstruction* custom_call) {
+  auto shared_subpattern = m::SharedSubpattern(
+      m::GetTupleElement(
+          m::CustomCall({kCudnnNormCallTarget})
+              .WithPredicate(custom_call->GetCaptureOrVerifyFn()),
+          2)
+          .WithPredicate(fused_norm_factor->GetCaptureOrVerifyFn()));
+  return m::AnyOf<HloInstruction>(shared_subpattern,
+                                  BitcastOrReshape(shared_subpattern));
+}
+
+// Derivative of the norm factor w.r.t. variance + epsilon,
+// d(norm_factor)/d(variance + epsilon)
+// = d((variance + epsilon)^-1/2)/d(variance + epsilon)
+// = -1/2 * norm_factor^3.
+// Forwards custom_call to FusedNormFactor for verification.
+auto DNormFactor(UniqueHloInstruction* custom_call) {
+  return MultiplyAnyOrder(m::Broadcast(m::ConstantScalar(-0.5)),
+                          Cube(FusedNormFactor(custom_call)));
+}
+
+// Zero-centered input of the layer norm, X - expectation(X). Captures or
+// verifies X in x if custom_call is a forward layer norm fusing X. Forwards
+// custom_call to FusedExpectation for verification.
+auto XCenter(UniqueHloInstruction* x, UniqueHloInstruction* custom_call,
+             const NormMetadataMap& norm_metadata) {
+  auto capture_or_verify_x =
+      [x, custom_call, &norm_metadata](const HloInstruction* instr) -> bool {
+    return x->CaptureOrVerify(
+        FindTarget(custom_call->Instr(), instr->operand(0),
+                   custom_call->Instr()->operand(0), norm_metadata)
+            ? custom_call->Instr()->mutable_operand(0)
+            : nullptr);
+  };
+  return Subtract(m::Op(), m::Broadcast(FusedExpectation(custom_call)))
+      .WithPredicate(capture_or_verify_x);
+}
+
+// Zero-centered input of the layer norm, X - expectation(X). Captures X in x if
+// custom_call is a forward layer norm fusing X. Forwards custom_call to
+// FusedExpectation for comparison.
+auto XCenter(UniqueHloInstruction* x_center, UniqueHloInstruction* x,
+             UniqueHloInstruction* fused_expectation,
+             UniqueHloInstruction* custom_call,
+             const NormMetadataMap& norm_metadata) {
+  auto capture_or_verify_x =
+      [x, custom_call, &norm_metadata](const HloInstruction* instr) -> bool {
+    return x->CaptureOrVerify(
+        FindTarget(custom_call->Instr(), instr->operand(0),
+                   custom_call->Instr()->operand(0), norm_metadata)
+            ? custom_call->Instr()->mutable_operand(0)
+            : nullptr);
+  };
+  return Subtract(m::Op(), m::Broadcast(FusedExpectation(fused_expectation,
+                                                         custom_call)))
+      .WithPredicate(x_center->GetCaptureOrVerifyFn())
+      .WithPredicate(capture_or_verify_x);
+}
+
+// Addition-reduction of the product of XCenter, the broadcasted scale and DY,
+// XCenter * scale * DY. Captures the scale in scale if custom_call is a forward
+// layer norm fusing the scale. Forwards custom_call to XCenter for comparison.
+auto F0(UniqueHloInstruction* custom_call, UniqueHloInstruction* scale,
+        UniqueHloInstruction* dy, UniqueHloInstruction* x,
+        HloInstruction** reduce, const NormMetadataMap& norm_metadata) {
+  auto capture_or_verify_scale = [scale, custom_call, &norm_metadata](
+                                     const HloInstruction* instr) -> bool {
+    return scale->CaptureOrVerify(FindTarget(custom_call->Instr(), instr,
+                                             custom_call->Instr()->operand(1),
+                                             norm_metadata)
+                                      ? custom_call->Instr()->mutable_operand(1)
+                                      : nullptr);
+  };
+  return AddReduce(
+      reduce, MultiplyMultiplyAnyOrder(
+                  XCenter(x, custom_call, norm_metadata),
+                  m::Broadcast(m::Op().WithPredicate(capture_or_verify_scale)),
+                  m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
+}
+
+// Product of XCenter and the scaled and broadcasted product of F0 and
+// d(norm_factor)/d(variance + epsilon), XCenter * F0 * DNormFactor * 2 /
+// nelems. Forwards custom_call to XCenter, F0 and DNormFactor for capture or
+// verification.
+auto F1(UniqueHloInstruction* x, UniqueHloInstruction* x_center,
+        UniqueHloInstruction* fused_expectation,
+        UniqueHloInstruction* custom_call, UniqueHloInstruction* scale,
+        UniqueHloInstruction* dy, HloInstruction** reduce,
+        const NormMetadataMap& norm_metadata) {
+  auto broadcasts_two_over_nelems = [](const HloInstruction* instr) -> bool {
+    const HloInstruction* multiply = SkipUnaryOps(instr->operand(0));
+    bool bcast_operand =
+        multiply->operand(0)->opcode() != HloOpcode::kBroadcast;
+
+    // The captured scalar must be two over the number of elements in the
+    // broadcasted dimensions.
+    float actual_two_over_nelems = multiply->operand(bcast_operand)
+                                       ->operand(0)
+                                       ->literal()
+                                       .GetAsDouble({})
+                                       .value();
+    int64_t nelems = 1;
+    for (int i = 0; i < instr->shape().dimensions_size(); ++i) {
+      if (!absl::c_linear_search(instr->dimensions(), i)) {
+        nelems *= instr->shape().dimensions()[i];
+      }
+    }
+    // The absolute value of the difference between the actual scaling factor
+    // and the reference value must not exceed a prescribed threshold.
+    float two_over_nelems = 2. / static_cast<float>(nelems);
+    float numerical_epsilon = std::numeric_limits<bfloat16>::epsilon();
+    return std::abs(actual_two_over_nelems - two_over_nelems) <
+           ((actual_two_over_nelems + two_over_nelems) * numerical_epsilon);
+  };
+
+  return MultiplyAnyOrder(
+      XCenter(x_center, x, fused_expectation, custom_call, norm_metadata),
+      m::Broadcast(
+          MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()),
+                           MultiplyAnyOrder(DNormFactor(custom_call),
+                                            F0(custom_call, scale, dy, x,
+                                               reduce, norm_metadata))))
+          .WithPredicate(broadcasts_two_over_nelems));
+}
+
+// Product of the norm factor, scale and DY, NormFactor * scale * DY. Captures
+// the scale in scale if custom_call is a forward layer norm fusing the scale.
+// Forwards custom_call to FusedNormFactor for comparison.
+auto F2(UniqueHloInstruction* fused_norm_factor, UniqueHloInstruction* scale,
+        UniqueHloInstruction* dy, UniqueHloInstruction* custom_call,
+        const NormMetadataMap& norm_metadata) {
+  auto capture_or_verify_scale = [scale, custom_call, &norm_metadata](
+                                     const HloInstruction* instr) -> bool {
+    return scale->CaptureOrVerify(
+        FindTarget(custom_call->Instr(), instr->operand(0),
+                   custom_call->Instr()->operand(1), norm_metadata)
+            ? custom_call->Instr()->mutable_operand(1)
+            : nullptr);
+  };
+  return MultiplyAnyOrder(
+      m::Broadcast(
+          BitcastOrReshape(FusedNormFactor(fused_norm_factor, custom_call))),
+      MultiplyAnyOrder(m::Broadcast().WithPredicate(capture_or_verify_scale),
+                       m::Op().WithPredicate(dy->GetCaptureOrVerifyFn())));
+}
+
+class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
+ public:
+  explicit CudnnNormRewriterVisitor(
+      const se::CudaComputeCapability cuda_compute_capability)
+      : cuda_compute_capability_(cuda_compute_capability) {}
+
+  absl::Status HandleAdd(HloInstruction* instr) override {
+    TF_RETURN_IF_ERROR(MatchLayerNorm(instr));
+    TF_RETURN_IF_ERROR(MatchLayerNormGradient(instr));
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleSubtract(HloInstruction* instr) override {
+    return MatchLayerNorm(instr);
+  }
+
+  // Matches and rewrites layer norm patterns,
+  // Y = (X - expectation(X))/sqrt(variance(X) + epsilon) * scale + bias,
+  // into Custom Calls to cuDNN.
+  absl::Status MatchLayerNorm(HloInstruction* instr) {
+    UniqueHloInstruction x, expectation, variance, epsilon;
+    HloInstruction *scale, *bias, *reduce, *norm_factor, *broadcast_scale,
+        *broadcast_bias;
+    if (Match(
+            instr,
+            SubtractMultiplyAddAnyOrder(
+                OptionalSupportedTransform(
+                    m::Op().WithPredicate(x.GetCaptureOrVerifyFn())),
+                Expectation(&expectation, &reduce,
+                            OptionalSupportedTransform(m::Op().WithPredicate(
+                                x.GetCaptureOrVerifyFn()))),
+                NormFactor(&norm_factor, &x, &variance, &expectation, &epsilon),
+                m::Broadcast(&broadcast_scale, m::Op(&scale)),
+                m::Broadcast(&broadcast_bias, m::Op(&bias))))) {
+#if CUDNN_VERSION < 8905
+      // Layer norm kernels are available with cuDNN 8.9.5 and above.
+      VLOG(1) << "Layer norm Custom Calls require cuDNN 8.9.5.";
+      return absl::OkStatus();
+#endif  // CUDNN_VERSION < 8905
+
+      if (!instr->GetModule()
+               ->config()
+               .debug_options()
+               .xla_gpu_enable_cudnn_layer_norm()) {
+        VLOG(1) << "Layer norm Custom Calls disabled.";
+        return absl::OkStatus();
+      }
+
+      // Layer norm kernels require Ampere or Hopper architectures.
+      if (cuda_compute_capability_.major != se::CudaComputeCapability::AMPERE &&
+          cuda_compute_capability_.major != se::CudaComputeCapability::HOPPER) {
+        VLOG(1) << "Layer norm Custom Calls require Ampere or Hopper "
+                   "architectures.";
+        return absl::OkStatus();
+      }
+
+      // Verify the uniqueness of the inputs.
+      if (!x.Instr() || !expectation.Instr() || !variance.Instr() ||
+          !epsilon.Instr()) {
+        VLOG(1) << "Layer norm operands not unique.";
+        return absl::OkStatus();
+      }
+
+      // Verify the input and output layouts.
+      // TODO(philipphack): Consider supporting more general cases.
+      if (!LayoutUtil::IsMonotonicWithDim0Major(x.Instr()->shape().layout()) ||
+          !LayoutUtil::IsMonotonicWithDim0Major(scale->shape().layout()) ||
+          !LayoutUtil::IsMonotonicWithDim0Major(bias->shape().layout()) ||
+          !LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout())) {
+        VLOG(1) << "Layer norm input and/or output layouts nor supported.";
+        return absl::OkStatus();
+      }
+
+      // Verify the element types. The element types of input and output and the
+      // shapes of scale and bias must match.
+      if (!CompatibleElementType(instr) || !CompatibleElementType(scale) ||
+          !CompatibleElementType(bias) ||
+          !ShapeUtil::SameElementType(instr->shape(), x.Instr()->shape()) ||
+          !ShapeUtil::Equal(scale->shape(), bias->shape())) {
+        VLOG(1) << "Layer norm input types or shapes not supported.";
+        return absl::OkStatus();
+      }
+
+      // Verify that the shapes of scale and bias are compatible with the
+      // operation. The adjusted norm dimensions are the dimensions of the
+      // reduction after removing any degenerate dimensions from the input of
+      // the reduction.
+      std::vector<int64_t> norm_dims(reduce->dimensions().begin(),
+                                     reduce->dimensions().end());
+      std::vector<int64_t> norm_dims_adjusted = AdjustedDimensions(reduce);
+      if (norm_dims_adjusted.size() !=
+          ShapeUtil::DropDegenerateDimensions(scale->shape())
+              .dimensions_size()) {
+        VLOG(1) << "Layer norm input dimensions not supported.";
+        return absl::OkStatus();
+      }
+
+      // Verify the broadcasts of scale and bias.
+      if (!ShapeUtil::EqualIgnoringElementType(
+              ShapeUtil::DropDegenerateDimensions(reduce->operand(0)->shape()),
+              ShapeUtil::DropDegenerateDimensions(broadcast_scale->shape())) ||
+          !ShapeUtil::EqualIgnoringElementType(
+              ShapeUtil::DropDegenerateDimensions(reduce->operand(0)->shape()),
+              ShapeUtil::DropDegenerateDimensions(broadcast_bias->shape())) ||
+          norm_dims_adjusted != AdjustedDimensions(broadcast_scale) ||
+          norm_dims_adjusted != AdjustedDimensions(broadcast_bias)) {
+        VLOG(1) << "Layer norm operand broadcast not supported.";
+        return absl::OkStatus();
+      }
+
+      // If necessary, transpose the input so that the dimensions not being
+      // normalized are the leading dimensions.
+      std::vector<int64_t> non_norm_dims;
+      for (int64_t x_dim = 0; x_dim < x.Instr()->shape().rank(); ++x_dim) {
+        if (std::find(norm_dims.begin(), norm_dims.end(), x_dim) ==
+            norm_dims.end()) {
+          non_norm_dims.emplace_back(x_dim);
+        }
+      }
+      std::vector<int64_t> non_norm_dims_adjusted =
+          AdjustedDimensions(x.Instr()->shape(), non_norm_dims);
+
+      std::vector<int64_t> x_transpose_order = non_norm_dims;
+      x_transpose_order.insert(x_transpose_order.end(), norm_dims.begin(),
+                               norm_dims.end());
+
+      bool apply_transpose = false;
+      for (int i = 0; i < x_transpose_order.size(); ++i) {
+        if (x_transpose_order[i] != i) {
+          apply_transpose = true;
+          break;
+        }
+      }
+
+      std::optional<HloInstruction*> x_transpose;
+      // The transpose applied to the output is the inverse of the transpose
+      // applied to the input.
+      std::vector<int64_t> y_transpose_order(x_transpose_order.size());
+      if (apply_transpose) {
+        for (int k = 0; k < x_transpose_order.size(); ++k) {
+          y_transpose_order[x_transpose_order[k]] = k;
+        }
+        TF_ASSIGN_OR_RETURN(x_transpose,
+                            MakeTransposeHlo(x.Instr(), x_transpose_order));
+      }
+
+      // Combine the dimensions not normalized into the first dimension of the
+      // input as required by cuDNN.
+      std::vector<int64_t> reshaped_dims = {1};
+      for (auto non_norm_dim : non_norm_dims) {
+        reshaped_dims[0] *= x.Instr()->shape().dimensions(non_norm_dim);
+      }
+      for (auto norm_dim : norm_dims) {
+        reshaped_dims.emplace_back(x.Instr()->shape().dimensions(norm_dim));
+      }
+      // cuDNN requires tensors to have at least four dimensions.
+      while (reshaped_dims.size() < 4) {
+        reshaped_dims.emplace_back(1);
+      }
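+      // For example, an input of shape [2,4,6,8] normalized over dimension 3
+      // is reshaped to [48,8,1,1].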
+
+      Shape reshaped_shape = ShapeUtil::MakeShape(
+          x.Instr()->shape().element_type(), reshaped_dims);
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * x_reshape,
+          MakeReshapeHlo(reshaped_shape, x_transpose.value_or(x.Instr())));
+
+      // Reshape the scale and bias. The first dimension corresponds to the
+      // combined non-normalization dimensions of the norm input and must have
+      // size 1.
+      std::vector<int64_t> reshaped_scale_dims = reshaped_dims;
+      reshaped_scale_dims[0] = 1;
+
+      Shape scale_bias_shape = ShapeUtil::MakeShape(
+          scale->shape().element_type(), reshaped_scale_dims);
+      TF_ASSIGN_OR_RETURN(HloInstruction * scale_reshape,
+                          MakeReshapeHlo(scale_bias_shape, scale));
+      TF_ASSIGN_OR_RETURN(HloInstruction * bias_reshape,
+                          MakeReshapeHlo(scale_bias_shape, bias));
+      GpuBackendConfig gpu_backend_config;
+      CudnnNormBackendConfig& backend_config =
+          *gpu_backend_config.mutable_cudnn_norm_backend_config();
+      backend_config.set_epsilon(
+          epsilon.Instr()->literal().GetAsDouble({}).value());
+      backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_INFER);
+      auto* algorithm = backend_config.mutable_algorithm();
+      algorithm->set_algo_id(0);
+      algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
+      algorithm->set_is_cudnn_frontend(true);
+
+      // Set the workspace size to its upper bound.
+      // TODO(philipphack): Consider autotuning the norm kernels.
+      TF_ASSIGN_OR_RETURN(const int64_t c_constant,
+                          CConstant(cuda_compute_capability_));
+      const int64_t workspace_size =
+          (2 * c_constant * (4 + 256)) + (2 * reshaped_dims[0] * 4) + 64;
+      algorithm->mutable_workspace_size()->set_value(workspace_size);
+
+      // The output of the Custom Call is a tuple, the second element of which
+      // describes the scratch space.
+      Shape custom_call_shape = ShapeUtil::MakeTupleShape(
+          {x_reshape->shape(), ShapeUtil::MakeShape(U8, {workspace_size})});
+
+      HloInstruction* custom_call =
+          instr->AddInstruction(HloInstruction::CreateCustomCall(
+              custom_call_shape, {x_reshape, scale_reshape, bias_reshape},
+              kCudnnNormCallTarget));
+      TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
+
+      TF_ASSIGN_OR_RETURN(HloInstruction * gte,
+                          MakeGetTupleElementHlo(custom_call, 0));
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * y_reshape,
+          MakeReshapeHlo(x_transpose.value_or(instr)->shape(), gte));
+
+      std::optional<HloInstruction*> y_transpose;
+      if (apply_transpose) {
+        TF_ASSIGN_OR_RETURN(y_transpose,
+                            MakeTransposeHlo(y_reshape, y_transpose_order));
+      }
+      TF_RETURN_IF_ERROR(
+          ReplaceInstruction(instr, y_transpose.value_or(y_reshape)));
+
+      // Store metadata for potential use in the backward graph.
+      norm_metadata_.insert(
+          {custom_call,
+           NormMetadata({x_transpose.value_or(nullptr),
+                         y_transpose.value_or(nullptr), norm_dims_adjusted,
+                         non_norm_dims_adjusted})});
+
+      VLOG(1) << "Layer norm rewritten into Custom Call.";
+
+      // The layer norm training graph separately contains the norm factor
+      // divided by the sum of variance and epsilon.
+      for (HloInstruction* user : norm_factor->users()) {
+        if (user->opcode() == HloOpcode::kDivide &&
+            user->operand_index(norm_factor) == 0) {
+          TF_ASSIGN_OR_RETURN(bool changed,
+                              MatchNormFactor(user, custom_call, variance,
+                                              expectation, epsilon));
+          if (changed) {
+            break;
+          }
+        }
+      }
+    }
+
+    return absl::OkStatus();
+  }
+
+  // The layer norm training graph separately contains the expectation as well
+  // as the norm factor and its cube, (variance + epsilon)^-1/2 and (variance +
+  // epsilon)^-3/2. When identified in the graph, these quantities are fused
+  // into the layer norm Custom Call.
+  absl::StatusOr<bool> MatchNormFactor(HloInstruction* instr,
+                                       HloInstruction* custom_call,
+                                       UniqueHloInstruction& variance,
+                                       UniqueHloInstruction& expectation,
+                                       UniqueHloInstruction& epsilon) {
+    HloInstruction* gte = custom_call->users()[0];
+    if (Match(instr,
+              m::Divide(
+                  m::Op(),
+                  AddAnyOrder(
+                      m::Op().WithPredicate(variance.GetCaptureOrVerifyFn()),
+                      m::Broadcast(m::ConstantScalar().WithPredicate(
+                          epsilon.GetCaptureOrVerifyFn())))))) {
+      // Verify the uniqueness of the operands.
+      if (!variance.Instr() || !epsilon.Instr()) {
+        VLOG(1) << "Layer norm operands not unique.";
+        return false;
+      }
+
+      // Verify the element types.
+      if (!CompatibleElementType(instr) ||
+          !CompatibleElementType(expectation.Instr())) {
+        VLOG(1) << "Layer norm input types not compatible.";
+        return false;
+      }
+
+      // Retrieve metadata of the forward layer norm.
+      auto norm_metadata = norm_metadata_.extract(custom_call);
+      if (!norm_metadata) {
+        VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call.";
+        return false;
+      }
+
+      // The shape of the expectation and norm factor return values of the
+      // Custom Call is [nelems, 1, 1, 1], where nelems is the
+      // number of elements in the expectation and norm factor shapes.
+      auto make_compatible_shape = [](Shape shape) -> Shape {
+        return ShapeUtil::MakeShape(shape.element_type(),
+                                    {ShapeUtil::ElementsIn(shape), 1, 1, 1});
+      };
+
+      Shape expectation_shape =
+          make_compatible_shape(expectation.Instr()->shape());
+      Shape norm_factor_shape = make_compatible_shape(instr->shape());
+
+      // The augmented Custom Call additionally returns the expectation and the
+      // norm factor.
+      std::vector<Shape> tuple_shapes = custom_call->shape().tuple_shapes();
+      tuple_shapes.insert(tuple_shapes.begin() + 1,
+                          {expectation_shape, norm_factor_shape});
+
+      Shape custom_call_shape = ShapeUtil::MakeTupleShape(tuple_shapes);
+
+      HloInstruction* new_custom_call = instr->AddInstruction(
+          custom_call->CloneWithNewShape(custom_call_shape));
+
+      TF_ASSIGN_OR_RETURN(
+          GpuBackendConfig gpu_backend_config,
+          custom_call->backend_config<xla::gpu::GpuBackendConfig>());
+      CudnnNormBackendConfig& backend_config =
+          *gpu_backend_config.mutable_cudnn_norm_backend_config();
+      backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_TRAIN);
+
+      // Update the workspace size.
+      TF_ASSIGN_OR_RETURN(const int64_t c_constant,
+                          CConstant(cuda_compute_capability_));
+      const int64_t workspace_size = (2 * c_constant * (4 + 256)) + 32;
+      backend_config.mutable_algorithm()->mutable_workspace_size()->set_value(
+          workspace_size);
+      TF_RETURN_IF_ERROR(
+          new_custom_call->set_backend_config(gpu_backend_config));
+
+      auto replace_with_new_cc = [new_custom_call, this](
+                                     HloInstruction* old_instr,
+                                     int tuple_index) -> absl::Status {
+        TF_ASSIGN_OR_RETURN(
+            HloInstruction * new_gte,
+            MakeGetTupleElementHlo(new_custom_call, tuple_index));
+        HloInstruction* new_instr = new_gte;
+        if (!ShapeUtil::Equal(new_gte->shape(), old_instr->shape())) {
+          TF_ASSIGN_OR_RETURN(new_instr,
+                              MakeReshapeHlo(old_instr->shape(), new_gte));
+        }
+        if (old_instr->opcode() != HloOpcode::kDivide) {
+          // Replace the result of the layer norm or the expectation.
+          TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr));
+        } else {
+          // Replace the norm factor, (variance + epsilon)^-1/2.
+          TF_RETURN_IF_ERROR(
+              ReplaceInstruction(old_instr->mutable_operand(0), new_instr));
+          // Also replace the norm factor to the power of 3, (variance +
+          // epsilon)^-1/2 / (variance + epsilon) = ((variance +
+          // epsilon)^-1/2)^3.
+          TF_ASSIGN_OR_RETURN(
+              HloInstruction * new_multiply0,
+              MakeBinaryHlo(HloOpcode::kMultiply, new_instr, new_instr));
+          TF_ASSIGN_OR_RETURN(
+              HloInstruction * new_multiply1,
+              MakeBinaryHlo(HloOpcode::kMultiply, new_multiply0, new_instr));
+          TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_multiply1));
+        }
+        return absl::OkStatus();
+      };
+
+      // Replace the result of the original Custom Call as well as the
+      // expectation and the norm factor with the augmented Custom Call.
+      TF_RETURN_IF_ERROR(replace_with_new_cc(gte, 0));
+      TF_RETURN_IF_ERROR(replace_with_new_cc(expectation.Instr(), 1));
+      TF_RETURN_IF_ERROR(replace_with_new_cc(instr, 2));
+
+      // Update the Custom Call associated with the metadata of the forward
+      // norm.
+      norm_metadata.key() = new_custom_call;
+      norm_metadata_.insert(std::move(norm_metadata));
+
+      VLOG(1)
+          << "Expectation and norm factor fused into layer norm Custom Call.";
+    }
+
+    return true;
+  }
+
+  // Matches and rewrites the backward graph of layer norm patterns into Custom
+  // Calls to cuDNN when the associated forward graph has been rewritten into a
+  // cuDNN Custom Call. The gradients are
+  //   DX = F1 + F2 - AddReduce(F1 + F2) / nelems,
+  //   Dscale = AddReduce(DY * XCenter * NormFactor),
+  //   Dbias = AddReduce(DY),
+  // with
+  //   F0 = XCenter * scale * DY,
+  //   F1 = XCenter * F0 * DNormFactor * 2 / nelems,
+  //   F2 = NormFactor * scale * DY,
+  //   XCenter = X - expectation(X),
+  //   NormFactor = (variance(X) + epsilon)^-1/2 and
+  //   DNormFactor = -1/2 * NormFactor^3.
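+  // The term -AddReduce(F1 + F2) / nelems is matched below as
+  // Broadcast(1/nelems * (NegateAddReduce(F1) + NegateAddReduce(F2))).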
+  absl::Status MatchLayerNormGradient(HloInstruction* instr) {
+    UniqueHloInstruction fwd_custom_call, x, x_center, scale, dy,
+        fused_expectation, fused_norm_factor;
+    HloInstruction *broadcast, *scalar, *dscale, *dbias, *reduce0, *reduce1,
+        *reduce2, *reduce3;
+    if (Match(instr,
+              AddAddAnyOrder(
+                  m::Broadcast(
+                      &broadcast,
+                      MultiplyAddAnyOrder(
+                          m::Broadcast(m::ConstantScalar(&scalar)),
+                          NegateAddReduce(&reduce0,
+                                          F1(&x, &x_center, &fused_expectation,
+                                             &fwd_custom_call, &scale, &dy,
+                                             &reduce2, norm_metadata_)),
+                          NegateAddReduce(
+                              &reduce1, F2(&fused_norm_factor, &scale, &dy,
+                                           &fwd_custom_call, norm_metadata_)))),
+                  F2(&fused_norm_factor, &scale, &dy, &fwd_custom_call,
+                     norm_metadata_),
+                  F1(&x, &x_center, &fused_expectation, &fwd_custom_call,
+                     &scale, &dy, &reduce3, norm_metadata_)))) {
+      // Skip initial convert, if present.
+      if (instr->user_count() == 1 &&
+          instr->users()[0]->opcode() == HloOpcode::kConvert &&
+          CompatibleElementType(instr->users()[0])) {
+        instr = instr->users()[0];
+      }
+
+      // Verify the uniqueness of the captured Custom Call and inputs.
+      if (!fwd_custom_call.Instr() || !x.Instr() || !dy.Instr() ||
+          !x_center.Instr() || !scale.Instr() || !fused_expectation.Instr() ||
+          !fused_norm_factor.Instr()) {
+        VLOG(1) << "Layer norm gradient inputs not unique.";
+        return absl::OkStatus();
+      }
+
+      // Retrieve metadata of the forward layer norm.
+      auto norm_metadata = norm_metadata_.find(fwd_custom_call.Instr());
+      if (norm_metadata == norm_metadata_.end()) {
+        VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call.";
+        return absl::OkStatus();
+      }
+
+      // Verify the dimensions of reductions in the backward graph.
+      if (AdjustedDimensions(reduce0) !=
+              norm_metadata->second.norm_dims_adjusted ||
+          AdjustedDimensions(reduce1) !=
+              norm_metadata->second.norm_dims_adjusted ||
+          AdjustedDimensions(reduce2) !=
+              norm_metadata->second.norm_dims_adjusted ||
+          AdjustedDimensions(reduce3) !=
+              norm_metadata->second.norm_dims_adjusted) {
+        VLOG(1) << "Unexpected reductions dimensions in layer norm gradient.";
+        return absl::OkStatus();
+      }
+
+      // The captured scalar must be one over the number of elements in the
+      // broadcasted dimensions.
+      float actual_r_nelems = scalar->literal().GetAsDouble({}).value();
+      int64_t nelems = 1;
+      for (int i = 0; i < broadcast->shape().dimensions_size(); ++i) {
+        if (!absl::c_linear_search(broadcast->dimensions(), i)) {
+          nelems *= broadcast->shape().dimensions()[i];
+        }
+      }
+      // The absolute value of the difference between the actual scaling factor
+      // and the reference value must not exceed a prescribed threshold.
+      float r_nelems = 1. / static_cast<float>(nelems);
+      float numerical_epsilon = std::numeric_limits<bfloat16>::epsilon();
+      if (!(std::abs(actual_r_nelems - r_nelems) <
+            ((actual_r_nelems + r_nelems) * numerical_epsilon))) {
+        VLOG(1)
+            << "Layer norm backward broadcast operand outside expected range.";
+        return absl::OkStatus();
+      }
+
+      // Identify Dscale = AddReduce(DY * XCenter * norm factor) with factor0
+      // and factor1 intended to be XCenter and DY or DY and XCenter.
+      auto find_dscale =
+          [&fused_norm_factor, &norm_metadata](
+              const UniqueHloInstruction& factor0,
+              const UniqueHloInstruction& factor1) -> HloInstruction* {
+        for (HloInstruction* factor0_user : factor0.Instr()->users()) {
+          std::vector<HloInstruction*> users;
+          SkipUnaryOpsTopDownRecursive(factor0_user, users);
+          // One of the users of factor0 must be a chained multiplication by the
+          // fused norm factor and factor1.
+          for (HloInstruction* user : users) {
+            if (Match(user,
+                      MultiplyAnyOrder(
+                          m::Op(), MultiplyAnyOrder(
+                                       m::Broadcast(BitcastOrReshape(m::Op().Is(
+                                           fused_norm_factor.Instr()))),
+                                       m::Op().Is(factor1.Instr()))))) {
+              // Dscale is an addition-reduction of the product.
+              for (HloInstruction* multiply_user : user->users()) {
+                if (AppliesAddReduce(
+                        multiply_user,
+                        norm_metadata->second.non_norm_dims_adjusted)) {
+                  return multiply_user;
+                }
+              }
+            }
+          }
+        }
+        return nullptr;
+      };
+      if (!(dscale = find_dscale(x_center, dy)) &&
+          !(dscale = find_dscale(dy, x_center))) {
+        VLOG(1) << "Unable to identify Dscale in graph.";
+        return absl::OkStatus();
+      }
+
+      // Find Dbias, i.e. an addition-reduction of DY, starting from DY.
+      // Rewriting proceeds without fusing Dbias if unsuccessful.
+      dbias = FindAddReduce(dy.Instr(),
+                            norm_metadata->second.non_norm_dims_adjusted);
+
+      // Verify the input and output layouts.
+      // TODO(philipphack): Consider supporting more general cases.
+      if (!LayoutUtil::IsMonotonicWithDim0Major(dy.Instr()->shape().layout()) ||
+          !LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout()) ||
+          !LayoutUtil::IsMonotonicWithDim0Major(dscale->shape().layout()) ||
+          (dbias &&
+           !LayoutUtil::IsMonotonicWithDim0Major(dbias->shape().layout()))) {
+        VLOG(1) << "Layer norm input and/or output layouts nor supported.";
+        return absl::OkStatus();
+      }
+
+      // The types of X and DX must match.
+      if (x.Instr()->shape().element_type() != instr->shape().element_type()) {
+        VLOG(1) << "The types of X and DX must match.";
+        return absl::OkStatus();
+      }
+
+      // The types and shapes of scale, Dscale and Dbias (if present) must
+      // match.
+      if (!ShapeUtil::Equal(
+              ShapeUtil::DropDegenerateDimensions(scale.Instr()->shape()),
+              ShapeUtil::DropDegenerateDimensions(dscale->shape())) ||
+          (dbias &&
+           !ShapeUtil::Equal(
+               ShapeUtil::DropDegenerateDimensions(scale.Instr()->shape()),
+               ShapeUtil::DropDegenerateDimensions(dbias->shape())))) {
+        VLOG(1) << "Backward layer norm types not supported.";
+        return absl::OkStatus();
+      }
+
+      // Verify the element types.
+      if (!CompatibleElementType(dy.Instr())) {
+        VLOG(1) << "Backward layer norm types not supported.";
+        return absl::OkStatus();
+      }
+
+      // cuDNN requires the byte size of the element type of X to be at least
+      // that of DY and scale.
+      if (ShapeUtil::ByteSizeOfPrimitiveType(
+              x.Instr()->shape().element_type()) <
+              ShapeUtil::ByteSizeOfPrimitiveType(
+                  dy.Instr()->shape().element_type()) ||
+          ShapeUtil::ByteSizeOfPrimitiveType(
+              x.Instr()->shape().element_type()) <
+              ShapeUtil::ByteSizeOfPrimitiveType(
+                  scale.Instr()->shape().element_type())) {
+        VLOG(1) << "Backward layer norm types not supported.";
+        return absl::OkStatus();
+      }
+
+      // Transpose DY applying the stored transpose order of X from the forward
+      // graph.
+      HloInstruction* transposed_dy = dy.Instr();
+      if (norm_metadata->second.x_transpose) {
+        TF_ASSIGN_OR_RETURN(
+            transposed_dy,
+            MakeTransposeHlo(dy.Instr(),
+                             norm_metadata->second.x_transpose->dimensions()));
+      }
+      TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_dy,
+                          MakeReshapeHlo(x.Instr()->shape(), transposed_dy));
+
+      Shape dx_shape = ShapeUtil::MakeShape(instr->shape().element_type(),
+                                            x.Instr()->shape().dimensions());
+
+      Shape dscale_dbias_shape = ShapeUtil::MakeShape(
+          dscale->shape().element_type(), scale.Instr()->shape().dimensions());
+
+      GpuBackendConfig gpu_backend_config;
+      CudnnNormBackendConfig& backend_config =
+          *gpu_backend_config.mutable_cudnn_norm_backend_config();
+      backend_config.set_kind(CudnnNormBackendConfig::LAYER_BWD);
+      auto* algorithm = backend_config.mutable_algorithm();
+      algorithm->set_algo_id(0);
+      algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH);
+      algorithm->set_is_cudnn_frontend(true);
+
+      // Set the workspace size to its upper bound.
+      // TODO(philipphack): Consider autotuning the norm kernels.
+      TF_ASSIGN_OR_RETURN(const int64_t c_constant,
+                          CConstant(cuda_compute_capability_));
+      const int64_t workspace_size =
+          (2 * c_constant * (4 + 256)) +
+          (2 * x.Instr()->shape().dimensions(0) * 4) + 64;
+      algorithm->mutable_workspace_size()->set_value(workspace_size);
+
+      // The output of the Custom Call is a tuple. The output shape of Dscale
+      // and Dbias is that of scale.
+      Shape custom_call_shape = ShapeUtil::MakeTupleShape(
+          {dx_shape, dscale_dbias_shape, dscale_dbias_shape,
+           ShapeUtil::MakeShape(U8, {workspace_size})});
+
+      HloInstruction* custom_call =
+          instr->AddInstruction(HloInstruction::CreateCustomCall(
+              custom_call_shape,
+              {x.Instr(), scale.Instr(), reshaped_dy, fused_expectation.Instr(),
+               fused_norm_factor.Instr()},
+              kCudnnNormCallTarget));
+      TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
+
+      auto replace_with_cc = [custom_call, norm_metadata, transposed_dy, this](
+                                 HloInstruction* old_instr,
+                                 int tuple_index) -> absl::Status {
+        TF_ASSIGN_OR_RETURN(HloInstruction * gte,
+                            MakeGetTupleElementHlo(custom_call, tuple_index));
+        HloInstruction* new_instr;
+        // Transpose DX applying the stored transpose order of Y from the
+        // forward graph.
+        if (tuple_index == 0 && norm_metadata->second.y_transpose) {
+          TF_ASSIGN_OR_RETURN(new_instr,
+                              MakeReshapeHlo(transposed_dy->shape(), gte));
+          TF_ASSIGN_OR_RETURN(
+              new_instr,
+              MakeTransposeHlo(
+                  new_instr, norm_metadata->second.y_transpose->dimensions()));
+        } else {
+          TF_ASSIGN_OR_RETURN(new_instr,
+                              MakeReshapeHlo(old_instr->shape(), gte));
+        }
+        TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr));
+        return absl::OkStatus();
+      };
+
+      TF_RETURN_IF_ERROR(replace_with_cc(instr, 0));
+      TF_RETURN_IF_ERROR(replace_with_cc(dscale, 1));
+      if (dbias) {
+        TF_RETURN_IF_ERROR(replace_with_cc(dbias, 2));
+      }
+      VLOG(1) << "Gradients w.r.t. x"
+              << (dbias ? ", scale and bias" : " and scale")
+              << " rewritten into layer norm backward Custom Call.";
+    }
+
+    return absl::OkStatus();
+  }
+
+ private:
+  se::CudaComputeCapability cuda_compute_capability_;
+  NormMetadataMap norm_metadata_;
+};
+
+absl::StatusOr<bool> RunOnComputation(
+    HloComputation* computation,
+    se::CudaComputeCapability cuda_compute_capability) {
+  CudnnNormRewriterVisitor visitor(cuda_compute_capability);
+  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
+  return visitor.changed();
+}
+
+}  // anonymous namespace
+
+CudnnNormRewriter::CudnnNormRewriter(
+    se::CudaComputeCapability cuda_compute_capability)
+    : cuda_compute_capability_(cuda_compute_capability) {}
+
+absl::StatusOr<bool> CudnnNormRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(
+        bool result, RunOnComputation(computation, cuda_compute_capability_));
+    changed |= result;
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.h
new file mode 100644
index 0000000..a2332d3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.h
@@ -0,0 +1,48 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_NORM_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_NORM_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites norm patterns into Custom Calls to the cuDNN library. Currently, the
+// forward and backward passes of layer norm patterns are implemented.
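+//
+// Illustrative usage (assuming the pass is added to an HloPassPipeline):
+//   pipeline.AddPass<CudnnNormRewriter>(cuda_compute_capability);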
+class CudnnNormRewriter : public HloModulePass {
+ public:
+  explicit CudnnNormRewriter(se::CudaComputeCapability cuda_compute_capability);
+  absl::string_view name() const override { return "norm-rewriter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  se::CudaComputeCapability cuda_compute_capability_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_NORM_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc
new file mode 100644
index 0000000..d130c08
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc
@@ -0,0 +1,1610 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include <gtest/gtest.h>
+#include "xla/error_spec.h"
+#include "xla/stream_executor/device_description.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cudnn/cudnn.h"  // IWYU pragma: keep
+#include "third_party/gpus/cudnn/cudnn_version.h"
+#endif
+
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CudnnNormRewriterTest : public GpuCodegenTest {
+ public:
+  se::CudaComputeCapability GetCudaComputeCapability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .cuda_compute_capability();
+  }
+
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_enable_cudnn_layer_norm(true);
+    return debug_options;
+  }
+
+ protected:
+  void SetUp() override {
+#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
+    GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#endif
+    if (!(GetCudaComputeCapability().major ==
+          se::CudaComputeCapability::AMPERE) &&
+        !(GetCudaComputeCapability().major ==
+          se::CudaComputeCapability::HOPPER)) {
+      GTEST_SKIP()
+          << "Layer norm kernels require Ampere or Hopper architectures.";
+    }
+  }
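+
+  // Runs hlo_text for numerical comparison against the reference and checks
+  // that the optimized HLO matches the optimized_hlo pattern.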
+  void TestNorm(std::string hlo_text, std::string optimized_hlo) {
+    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+    MatchOptimizedHlo(hlo_text, optimized_hlo);
+  }
+};
+
+// The following tests evaluate LayerNormXDY configurations, with X the rank of
+// the input and Y the dimensions that are normalized.
+TEST_F(CudnnNormRewriterTest, LayerNorm2D1) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4] parameter(0)
+        input_square = f32[2,4] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply
+        r_nelems = f32[] constant(0.25)
+        r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2] multiply(input_square_sum, r_nelems_bcast)
+        input_sum = f32[2] reduce(input, c0), dimensions={1}, to_apply=apply
+        input_mean = f32[2] multiply(input_sum, r_nelems_bcast)
+        input_mean_square = f32[2] multiply(input_mean, input_mean)
+        variance = f32[2] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2] add(variance, epsilon_bcast)
+        norm_factor = f32[2] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0}
+        input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0}
+        input_center = f32[2,4] subtract(input, input_mean_bcast)
+        norm = f32[2,4] multiply(norm_factor_bcast, input_center)
+        scale = f32[4] parameter(1)
+        scale_bcast = f32[2,4] broadcast(scale), dimensions={1}
+        norm_scale = f32[2,4] multiply(norm, scale_bcast)
+        bias = f32[4] parameter(2)
+        bias_broadcast = f32[2,4] broadcast(bias), dimensions={1}
+        ROOT out = f32[2,4] add(norm_scale, bias_broadcast)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4]) -> f32[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK:           }
+; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT:  ROOT [[GTE_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
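
For reference when reading this and the subsequent forward (inference) tests: the HLO above spells out, over the normalized dimensions, the computation the rewriter is expected to match. Writing gamma for scale, beta for bias, N for the number of normalized elements (so r_nelems = 1/N), and epsilon = 0.001 as in every test, that computation is

$$\mu = \frac{1}{N}\sum_i x_i,\qquad \sigma^2 = \frac{1}{N}\sum_i x_i^2 - \mu^2,\qquad y = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}}\,\gamma + \beta,$$

i.e. the variance is formed as E[x^2] - (E[x])^2 exactly as the input_square_mean / input_mean_square instructions do, and the expected optimized HLO is a single __cudnn$norm custom call surrounded by layout-only bitcasts (and, where needed, transposes).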
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D3) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,6,8] parameter(0)
+        input_square = f32[2,4,6,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
+        r_nelems = f32[] constant(0.125)
+        r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
+        input_sum = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
+        input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast)
+        input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
+        variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
+        norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
+        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
+        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+        scale = f32[8] parameter(1)
+        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
+        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+        bias = f32[8] parameter(2)
+        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
+        ROOT out = f32[2,4,6,8] add(norm_scale, bias_bcast)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[2,4,6,8] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK:           }
+; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT:  ROOT [[GTE_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D3Degenerate0) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[1,4,6,8] parameter(0)
+        input_square = f32[1,4,6,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[1,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
+        r_nelems = f32[] constant(0.125)
+        r_nelems_bcast = f32[1,4,6] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[1,4,6] multiply(input_square_sum, r_nelems_bcast)
+        input_sum = f32[1,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
+        input_mean = f32[1,4,6] multiply(input_sum, r_nelems_bcast)
+        input_mean_square = f32[1,4,6] multiply(input_mean, input_mean)
+        variance = f32[1,4,6] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[1,4,6] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[1,4,6] add(variance, epsilon_bcast)
+        norm_factor = f32[1,4,6] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[1,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
+        input_mean_bcast = f32[1,4,6,8] broadcast(input_mean), dimensions={0,1,2}
+        input_center = f32[1,4,6,8] subtract(input, input_mean_bcast)
+        norm = f32[1,4,6,8] multiply(norm_factor_bcast, input_center)
+        scale = f32[8] parameter(1)
+        scale_bcast = f32[1,4,6,8] broadcast(scale), dimensions={3}
+        norm_scale = f32[1,4,6,8] multiply(norm, scale_bcast)
+        bias = f32[8] parameter(2)
+        bias_bcast = f32[1,4,6,8] broadcast(bias), dimensions={3}
+        ROOT out = f32[1,4,6,8] add(norm_scale, bias_bcast)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[1,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[1,4,6,8] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[1,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[24,8,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[24,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK:           }
+; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[24,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT:  ROOT [[GTE_BITCAST:%[^ ]+]] = f32[1,4,6,8]{3,2,1,0} bitcast([[GTE]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D2) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,6,8] parameter(0)
+        input_square = f32[2,4,6,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,4,8] reduce(input_square, c0), dimensions={2}, to_apply=apply
+        r_nelems = f32[] constant(0.166667)
+        r_nelems_bcast = f32[2,4,8] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,4,8] multiply(input_square_sum, r_nelems_bcast)
+        reduce = f32[2,4,8] reduce(input, c0), dimensions={2}, to_apply=apply
+        input_mean = f32[2,4,8] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,4,8] multiply(input_mean, input_mean)
+        variance = f32[2,4,8] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,4,8] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,4,8] add(variance, epsilon_bcast)
+        norm_factor = f32[2,4,8] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,3}
+        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,3}
+        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+        scale = f32[6] parameter(1)
+        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={2}
+        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+        bias = f32[6] parameter(2)
+        bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={2}
+        ROOT out = f32[2,4,6,8] add(norm_scale, bias_broadcast)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,4,6,8] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[TRANSPOSE:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2}
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[6]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[6]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK:           }
+; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT:  ROOT [[FUSION:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
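
A note on the expected HLO in this test: because the normalized dimension (size 6, dimension 2) is not innermost, the CHECK lines require the rewriter to first transpose it to the back and then bitcast into cuDNN's 4D layout, i.e.

$$[2,4,6,8] \xrightarrow{\ \text{transpose } \{0,1,3,2\}\ } [2,4,8,6] \xrightarrow{\ \text{bitcast}\ } [64,6,1,1], \qquad 64 = 2\cdot 4\cdot 8,$$

with a loop fusion at the end restoring the original [2,4,6,8] shape. The degenerate-dimension variants below follow the same scheme, with the size-1 dimensions folded away.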
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D2Degenerate1) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,1,6,8] parameter(0)
+        input_square = f32[2,1,6,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,1,8] reduce(input_square, c0), dimensions={2}, to_apply=apply
+        r_nelems = f32[] constant(0.166667)
+        r_nelems_bcast = f32[2,1,8] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,1,8] multiply(input_square_sum, r_nelems_bcast)
+        reduce = f32[2,1,8] reduce(input, c0), dimensions={2}, to_apply=apply
+        input_mean = f32[2,1,8] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,1,8] multiply(input_mean, input_mean)
+        variance = f32[2,1,8] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,1,8] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,1,8] add(variance, epsilon_bcast)
+        norm_factor = f32[2,1,8] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,1,6,8] broadcast(norm_factor), dimensions={0,1,3}
+        input_mean_bcast = f32[2,1,6,8] broadcast(input_mean), dimensions={0,1,3}
+        input_center = f32[2,1,6,8] subtract(input, input_mean_bcast)
+        norm = f32[2,1,6,8] multiply(norm_factor_bcast, input_center)
+        scale = f32[6] parameter(1)
+        scale_bcast = f32[2,1,6,8] broadcast(scale), dimensions={2}
+        norm_scale = f32[2,1,6,8] multiply(norm, scale_bcast)
+        bias = f32[6] parameter(2)
+        bias_broadcast = f32[2,1,6,8] broadcast(bias), dimensions={2}
+        ROOT out = f32[2,1,6,8] add(norm_scale, bias_broadcast)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,1,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,1,6,8] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,6]{3,2,1,0} transpose([[P0]]), dimensions={1,0,3,2}
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[6]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[6]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[16,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK:           }
+; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT:  ROOT [[FUSION:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D12) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,6,8] parameter(0)
+        input_square = f32[2,4,6,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
+        r_nelems = f32[] constant(0.041667)
+        r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
+        reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
+        input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,8] multiply(input_mean, input_mean)
+        variance = f32[2,8] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
+        norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3}
+        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3}
+        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+        scale = f32[4,6] parameter(1)
+        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2}
+        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+        bias = f32[4,6] parameter(2)
+        bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2}
+        ROOT out = f32[2,4,6,8] add(norm_scale, bias_broadcast)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> f32[2,4,6,8] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2}
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK:           }
+; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT:  ROOT [[FUSION:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D12Degenerate2) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,1,8] parameter(0)
+        input_square = f32[2,4,1,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
+        r_nelems = f32[] constant(0.25)
+        r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
+        reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
+        input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,8] multiply(input_mean, input_mean)
+        variance = f32[2,8] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
+        norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3}
+        input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3}
+        input_center = f32[2,4,1,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,1,8] multiply(norm_factor_bcast, input_center)
+        scale = f32[4,1] parameter(1)
+        scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2}
+        norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast)
+        bias = f32[4,1] parameter(2)
+        bias_broadcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2}
+        ROOT out = f32[2,4,1,8] add(norm_scale, bias_broadcast)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> f32[2,4,1,8] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1}
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK:           }
+; CHECK-NEXT:    [[GTE:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT:  ROOT [[FUSION:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,2,2,2] parameter(0)
+        input_square = f32[2,2,2,2] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,2,2] reduce(input_square, c0), dimensions={3}, to_apply=apply
+        r_nelems = f32[] constant(0.5)
+        r_nelems_bcast = f32[2,2,2] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,2,2] multiply(input_square_sum, r_nelems_bcast)
+        input_sum = f32[2,2,2] reduce(input, c0), dimensions={3}, to_apply=apply
+        input_mean = f32[2,2,2] multiply(input_sum, r_nelems_bcast)
+        input_mean_square = f32[2,2,2] multiply(input_mean, input_mean)
+        variance = f32[2,2,2] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,2,2] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,2,2] add(variance, epsilon_bcast)
+        norm_factor = f32[2,2,2] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,2,2,2] broadcast(norm_factor), dimensions={0,1,2}
+        input_mean_bcast = f32[2,2,2,2] broadcast(input_mean), dimensions={0,1,2}
+        input_center = f32[2,2,2,2] subtract(input, input_mean_bcast)
+        norm = f32[2,2,2,2] multiply(norm_factor_bcast, input_center)
+        scale = f32[2] parameter(1)
+        scale_bcast = f32[2,2,2,2] broadcast(scale), dimensions={2}
+        norm_scale = f32[2,2,2,2] multiply(norm, scale_bcast)
+        bias = f32[2] parameter(2)
+        bias_bcast = f32[2,2,2,2] broadcast(bias), dimensions={3}
+        ROOT out = f32[2,2,2,2] add(norm_scale, bias_bcast)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,2,2,2], {{.*}}: f32[2], {{.*}}: f32[2]) -> f32[2,2,2,2] {
+; CHECK-NOT:           custom_call_target="__cudnn$norm"
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNorm4D3InputOutputTypeMismatch) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f16[2,4,6,8] parameter(0)
+        input_f32 = f32[2,4,6,8] convert(input)
+        input_square = f32[2,4,6,8] multiply(input_f32, input_f32)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
+        r_nelems = f32[] constant(0.125)
+        r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
+        input_sum = f32[2,4,6] reduce(input_f32, c0), dimensions={3}, to_apply=apply
+        input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast)
+        input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
+        variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
+        norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
+        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
+        input_center = f32[2,4,6,8] subtract(input_f32, input_mean_bcast)
+        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+        scale = f32[8] parameter(1)
+        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
+        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+        bias = f32[8] parameter(2)
+        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
+        ROOT out = f32[2,4,6,8] add(norm_scale, bias_bcast)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f16[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[2,4,6,8] {
+; CHECK-NOT:           custom_call_target="__cudnn$norm"
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrain2D1) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4] parameter(0)
+        input_square = f32[2,4] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply
+        r_nelems = f32[] constant(0.25)
+        r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2] multiply(input_square_sum,r_nelems_bcast)
+        reduce = f32[2] reduce(input, c0), dimensions={1}, to_apply=apply
+        input_mean = f32[2] multiply(reduce,r_nelems_bcast)
+        input_mean_square = f32[2] multiply(input_mean,input_mean)
+        variance = f32[2] subtract(input_square_mean,input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2] add(variance, epsilon_bcast)
+        norm_factor = f32[2] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0}
+        input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0}
+        input_center = f32[2,4] subtract(input,input_mean_bcast)
+        norm = f32[2,4] multiply(norm_factor_bcast,input_center)
+        scale = f32[4] parameter(1)
+        scale_bcast = f32[2,4] broadcast(scale), dimensions={1}
+        norm_scale = f32[2,4] multiply(norm,scale_bcast)
+        bias = f32[4] parameter(2)
+        bias_broadcast = f32[2,4] broadcast(bias), dimensions={1}
+        norm_scale_bias = f32[2,4] add(norm_scale, bias_broadcast)
+        norm_factor_cube = f32[2] divide(norm_factor, variance_plus_epsilon)
+        ROOT out = (f32[2,4], f32[2], f32[2], f32[2]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4]) -> (f32[2,4], f32[2], f32[2], f32[2]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK:           }
+; CHECK-NEXT:    [[GTE0:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT:    [[GTE0_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE0]])
+; CHECK-NEXT:    [[GTE1:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
+; CHECK-NEXT:    [[GTE1_BITCAST:%[^ ]+]] = f32[2]{0} bitcast([[GTE1]])
+; CHECK-NEXT:    [[GTE2:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
+; CHECK-NEXT:    [[GTE2_BITCAST:%[^ ]+]] = f32[2]{0} bitcast([[GTE2]])
+; CHECK-NEXT:    [[FUSION:%[^ ]+]] = f32[2]{0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+; CHECK-NEXT:  ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple([[GTE0_BITCAST]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
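
For the training-mode tests, the CHECK patterns expect the custom call to return two extra outputs, the per-row mean and the norm factor (sigma^2 + epsilon)^(-1/2), alongside the normalized result. The only arithmetic left outside the call is the norm_factor_cube element of the tuple; the HLO's divide is exactly the cube of the returned norm factor, since

$$\frac{(\sigma^2+\varepsilon)^{-1/2}}{\sigma^2+\varepsilon} = (\sigma^2+\varepsilon)^{-3/2},$$

which is why a small loop fusion over the third custom-call result (index 2) suffices.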
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrain4D3) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,6,8] parameter(0)
+        input_square = f32[2,4,6,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
+        r_nelems = f32[] constant(0.125)
+        r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast)
+        reduce = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
+        input_mean = f32[2,4,6] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,4,6] multiply(input_mean, input_mean)
+        variance = f32[2,4,6] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
+        norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
+        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
+        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+        scale = f32[8] parameter(1)
+        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
+        norm_scale = f32[2,4,6,8] multiply(norm,scale_bcast)
+        bias = f32[8] parameter(2)
+        bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
+        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_broadcast)
+        norm_factor_cube = f32[2,4,6] divide(norm_factor, variance_plus_epsilon)
+        ROOT out = (f32[2,4,6,8], f32[2,4,6], f32[2,4,6], f32[2,4,6]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> (f32[2,4,6,8], f32[2,4,6], f32[2,4,6], f32[2,4,6]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK:           }
+; CHECK-NEXT:    [[GTE0:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT:    [[GTE0_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE0]])
+; CHECK-NEXT:    [[GTE1:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
+; CHECK-NEXT:    [[GTE1_BITCAST:%[^ ]+]] = f32[2,4,6]{2,1,0} bitcast([[GTE1]])
+; CHECK-NEXT:    [[GTE2:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
+; CHECK-NEXT:    [[GTE2_BITCAST:%[^ ]+]] = f32[2,4,6]{2,1,0} bitcast([[GTE2]])
+; CHECK-NEXT:    [[FUSION:%[^ ]+]] = f32[2,4,6]{2,1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+; CHECK-NEXT:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6]{2,1,0}, f32[2,4,6]{2,1,0}, f32[2,4,6]{2,1,0}) tuple([[GTE0_BITCAST]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,6,8] parameter(0)
+        input_square = f32[2,4,6,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
+        r_nelems = f32[] constant(0.041667)
+        r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
+        reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
+        input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,8] multiply(input_mean, input_mean)
+        variance = f32[2,8] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
+        norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3}
+        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3}
+        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center)
+        scale = f32[4,6] parameter(1)
+        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2}
+        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+        bias = f32[4,6] parameter(2)
+        bias_broadcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2}
+        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_broadcast)
+        norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
+        ROOT out = (f32[2,4,6,8], f32[2,8], f32[2,8], f32[2,8]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> (f32[2,4,6,8], f32[2,8], f32[2,8], f32[2,8]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2}
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK:           }
+; CHECK-NEXT:    [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT:    [[FUSION0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
+; CHECK-NEXT:    [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
+; CHECK-NEXT:    [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]])
+; CHECK-NEXT:    [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
+; CHECK-NEXT:    [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]])
+; CHECK-NEXT:    [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
+; CHECK-NEXT:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12Degenerate2) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,1,8] parameter(0)
+        input_square = f32[2,4,1,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
+        r_nelems = f32[] constant(0.25)
+        r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast)
+        reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
+        input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,8] multiply(input_mean, input_mean)
+        variance = f32[2,8] subtract(input_square_mean, input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
+        norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3}
+        input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3}
+        input_center = f32[2,4,1,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,1,8] multiply(norm_factor_bcast, input_center)
+        scale = f32[4,1] parameter(1)
+        scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2}
+        norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast)
+        bias = f32[4,1] parameter(2)
+        bias_broadcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2}
+        norm_scale_bias = f32[2,4,1,8] add(norm_scale, bias_broadcast)
+        norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
+        ROOT out = (f32[2,4,1,8], f32[2,8], f32[2,8], f32[2,8]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> (f32[2,4,1,8], f32[2,8], f32[2,8], f32[2,8]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1}
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK:           }
+; CHECK-NEXT:    [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0
+; CHECK-NEXT:    [[FUSION0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
+; CHECK-NEXT:    [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1
+; CHECK-NEXT:    [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]])
+; CHECK-NEXT:    [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2
+; CHECK-NEXT:    [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]])
+; CHECK-NEXT:    [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
+; CHECK-NEXT:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward2D1) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4] parameter(0)
+        input_square = f32[2,4] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply
+        reduce = f32[2] reduce(input, c0), dimensions={1}, to_apply=apply
+        r_nelems = f32[] constant(0.25)
+        r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2] multiply(input_square_sum,r_nelems_bcast)
+        input_mean = f32[2] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2] multiply(input_mean,input_mean)
+        variance = f32[2] subtract(input_square_mean,input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2] add(variance, epsilon_bcast)
+        norm_factor = f32[2] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0}
+        input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0}
+        input_center = f32[2,4] subtract(input, input_mean_bcast)
+        norm = f32[2,4] multiply(input_center, norm_factor_bcast)
+        scale = f32[4] parameter(1)
+        scale_bcast = f32[2,4] broadcast(scale), dimensions={1}
+        norm_scale = f32[2,4] multiply(norm, scale_bcast)
+        bias = f32[4] parameter(2)
+        bias_bcast = f32[2,4] broadcast(bias), dimensions={1}
+        norm_scale_bias = f32[2,4] add(norm_scale, bias_bcast)
+        doutput = f32[2,4] parameter(3)
+        dbias = f32[4] reduce(doutput, c0), dimensions={0}, to_apply=apply
+        norm_doutput = f32[2,4] multiply(norm, doutput)
+        dscale = f32[4] reduce(norm_doutput, c0), dimensions={0}, to_apply=apply
+        scale_doutput = f32[2,4] multiply(scale_bcast, doutput)
+        input_center_scale_doutput = f32[2,4] multiply(input_center, scale_doutput)
+        f0 = f32[2] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply
+        norm_factor_cube = f32[2] divide(norm_factor, variance_plus_epsilon)
+        c1 = f32[] constant(-0.5)
+        c1_bcast = f32[2] broadcast(c1), dimensions={}
+        dnorm_factor = f32[2] multiply(norm_factor_cube, c1_bcast)
+        f0_dnorm_factor = f32[2] multiply(f0, dnorm_factor)
+        c2 = f32[] constant(0.5)
+        c2_bcast = f32[2] broadcast(c2), dimensions={}
+        f0_dnorm_factor_scaled = f32[2] multiply(f0_dnorm_factor, c2_bcast)
+        f0_dnorm_factor_scaled_bcast = f32[2,4] broadcast(f0_dnorm_factor_scaled), dimensions={0}
+        f1 = f32[2,4] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+        minus_f1 = f32[2,4] negate(f1)
+        minus_f1_sum = f32[2] reduce(minus_f1, c0), dimensions={1}, to_apply=apply
+        f2 = f32[2,4] multiply(norm_factor_bcast, scale_doutput)
+        minus_f2 = f32[2,4] negate(f2)
+        minus_f2_sum = f32[2] reduce(minus_f2, c0), dimensions={1}, to_apply=apply
+        minus_f1_f2_sum = f32[2] add(minus_f1_sum, minus_f2_sum)
+        minus_f1_f2_sum_scaled = f32[2] multiply(minus_f1_f2_sum, r_nelems_bcast)
+        minus_f1_f2_sum_scaled_bcast = f32[2,4] broadcast(minus_f1_f2_sum_scaled), dimensions={0}
+        f1_f2 = f32[2,4] add(f1, f2)
+        dinput = f32[2,4] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+        ROOT out = (f32[2,4], f32[2,4], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> (f32[2,4], f32[2,4], f32[4], f32[4]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
+; CHECK:           }
+; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG:     [[GTE0_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE0]])
+; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
+; CHECK-DAG:     [[P3_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P3]])
+; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0
+; CHECK-DAG:         "kind":"LAYER_BWD"
+; CHECK:           }
+; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG:     [[GTE3_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE3]])
+; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]])
+; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]])
+; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTE0_BITCAST]], [[GTE3_BITCAST]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
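
The backward tests extend the pattern with the standard layer norm input gradient, which the HLO above builds explicitly. Restating it with gamma for scale, dy for doutput, N for the normalized size, and f1, f2 named as in the HLO:

$$f_2 = \frac{\gamma\, dy}{\sqrt{\sigma^2+\varepsilon}},\qquad f_1 = -\frac{x-\mu}{N\,(\sigma^2+\varepsilon)^{3/2}}\sum_i (x_i-\mu)\,\gamma_i\, dy_i,\qquad dx = f_1 + f_2 - \frac{1}{N}\sum_i \left(f_{1,i} + f_{2,i}\right),$$

together with dscale as the batch reduction of norm * dy and dbias as the batch reduction of dy. The CHECK patterns require the forward pass to become a LAYER_FWD_TRAIN custom call and all of the gradient arithmetic to collapse into a single LAYER_BWD custom call that consumes the saved mean and norm factor.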
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D3) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,6,8] parameter(0)
+        input_square = f32[2,4,6,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply
+        reduce = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply
+        r_nelems = f32[] constant(0.125)
+        r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,4,6] multiply(input_square_sum,r_nelems_bcast)
+        input_mean = f32[2,4,6] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,4,6] multiply(input_mean,input_mean)
+        variance = f32[2,4,6] subtract(input_square_mean,input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast)
+        norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2}
+        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2}
+        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
+        scale = f32[8] parameter(1)
+        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3}
+        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+        bias = f32[8] parameter(2)
+        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3}
+        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
+        doutput = f32[2,4,6,8] parameter(3)
+        dbias = f32[8] reduce(doutput, c0), dimensions={0,1,2}, to_apply=apply
+        norm_doutput = f32[2,4,6,8] multiply(norm, doutput)
+        dscale = f32[8] reduce(norm_doutput, c0), dimensions={0,1,2}, to_apply=apply
+        scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput)
+        input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
+        f0 = f32[2,4,6] reduce(input_center_scale_doutput, c0), dimensions={3}, to_apply=apply
+        norm_factor_cube = f32[2,4,6] divide(norm_factor, variance_plus_epsilon)
+        c1 = f32[] constant(-0.5)
+        c1_bcast = f32[2,4,6] broadcast(c1), dimensions={}
+        dnorm_factor = f32[2,4,6] multiply(norm_factor_cube, c1_bcast)
+        f0_dnorm_factor = f32[2,4,6] multiply(f0, dnorm_factor)
+        c2 = f32[] constant(0.25)
+        c2_bcast = f32[2,4,6] broadcast(c2), dimensions={}
+        f0_dnorm_factor_scaled = f32[2,4,6] multiply(f0_dnorm_factor, c2_bcast)
+        f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,1,2}
+        f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+        minus_f1 = f32[2,4,6,8] negate(f1)
+        minus_f1_sum = f32[2,4,6] reduce(minus_f1, c0), dimensions={3}, to_apply=apply
+        f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
+        minus_f2 = f32[2,4,6,8] negate(f2)
+        minus_f2_sum = f32[2,4,6] reduce(minus_f2, c0), dimensions={3}, to_apply=apply
+        minus_f1_f2_sum = f32[2,4,6] add(minus_f1_sum, minus_f2_sum)
+        minus_f1_f2_sum_scaled = f32[2,4,6] multiply(minus_f1_f2_sum, r_nelems_bcast)
+        minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,1,2}
+        f1_f2 = f32[2,4,6,8] add(f1, f2)
+        dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+        ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[8], f32[8]) tuple(norm_scale_bias, dinput, dscale, dbias)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[8], f32[8]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[8]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[8]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
+; CHECK:           }
+; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG:     [[GTE0_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE0]])
+; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3)
+; CHECK-DAG:     [[P3_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P3]])
+; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[1,8,1,1]{3,2,1,0}, f32[1,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0
+; CHECK-DAG:         "kind":"LAYER_BWD"
+; CHECK:           }
+; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG:     [[GTE3_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE3]])
+; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[8]{0} bitcast([[GTE4]])
+; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[8]{0} bitcast([[GTE5]])
+; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[8]{0}, f32[8]{0}) tuple([[GTE0_BITCAST]], [[GTE3_BITCAST]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,6,8] parameter(0)
+        input_square = f32[2,4,6,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,4,8] reduce(input_square, c0), dimensions={2}, to_apply=apply
+        reduce = f32[2,4,8] reduce(input, c0), dimensions={2}, to_apply=apply
+        r_nelems = f32[] constant(0.166667)
+        r_nelems_bcast = f32[2,4,8] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,4,8] multiply(input_square_sum,r_nelems_bcast)
+        input_mean = f32[2,4,8] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,4,8] multiply(input_mean,input_mean)
+        variance = f32[2,4,8] subtract(input_square_mean,input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,4,8] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,4,8] add(variance, epsilon_bcast)
+        norm_factor = f32[2,4,8] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,3}
+        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,3}
+        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
+        scale = f32[6] parameter(1)
+        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={2}
+        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+        bias = f32[6] parameter(2)
+        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={2}
+        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
+        doutput = f32[2,4,6,8] parameter(3)
+        dbias = f32[6] reduce(doutput, c0), dimensions={0,1,3}, to_apply=apply
+        norm_doutput = f32[2,4,6,8] multiply(norm, doutput)
+        dscale = f32[6] reduce(norm_doutput, c0), dimensions={0,1,3}, to_apply=apply
+        scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput)
+        input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
+        f0 = f32[2,4,8] reduce(input_center_scale_doutput, c0), dimensions={2}, to_apply=apply
+        norm_factor_cube = f32[2,4,8] divide(norm_factor, variance_plus_epsilon)
+        c1 = f32[] constant(-0.5)
+        c1_bcast = f32[2,4,8] broadcast(c1), dimensions={}
+        dnorm_factor = f32[2,4,8] multiply(norm_factor_cube, c1_bcast)
+        f0_dnorm_factor = f32[2,4,8] multiply(f0, dnorm_factor)
+        c2 = f32[] constant(0.333333)
+        c2_bcast = f32[2,4,8] broadcast(c2), dimensions={}
+        f0_dnorm_factor_scaled = f32[2,4,8] multiply(f0_dnorm_factor, c2_bcast)
+        f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,1,3}
+        f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+        minus_f1 = f32[2,4,6,8] negate(f1)
+        minus_f1_sum = f32[2,4,8] reduce(minus_f1, c0), dimensions={2}, to_apply=apply
+        f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
+        minus_f2 = f32[2,4,6,8] negate(f2)
+        minus_f2_sum = f32[2,4,8] reduce(minus_f2, c0), dimensions={2}, to_apply=apply
+        minus_f1_f2_sum = f32[2,4,8] add(minus_f1_sum, minus_f2_sum)
+        minus_f1_f2_sum_scaled = f32[2,4,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
+        minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,1,3}
+        f1_f2 = f32[2,4,6,8] add(f1, f2)
+        dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+        ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) tuple(norm_scale_bias, dinput, dscale, dbias)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[TRANSPOSE0:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2}
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[6]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[6]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[64,1,1,1]{3,2,1,0}, f32[64,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
+; CHECK:           }
+; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3)
+; CHECK-NEXT:    [[TRANSPOSE1:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P3]]), dimensions={0,1,3,2}
+; CHECK-DAG:     [[P3_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]])
+; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[1,6,1,1]{3,2,1,0}, f32[1,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0
+; CHECK-DAG:         "kind":"LAYER_BWD"
+; CHECK:           }
+; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG:     [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+; CHECK-DAG:     [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0
+; CHECK-DAG:     [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1
+; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE4]])
+; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE5]])
+; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[6]{0}, f32[6]{0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,6,8] parameter(0)
+        input_square = f32[2,4,6,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
+        reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
+        r_nelems = f32[] constant(0.041667)
+        r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,8] multiply(input_square_sum,r_nelems_bcast)
+        input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,8] multiply(input_mean,input_mean)
+        variance = f32[2,8] subtract(input_square_mean,input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
+        norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3}
+        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3}
+        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
+        scale = f32[4,6] parameter(1)
+        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2}
+        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+        bias = f32[4,6] parameter(2)
+        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2}
+        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
+        doutput = f32[2,4,6,8] parameter(3)
+        dbias = f32[4,6] reduce(doutput, c0), dimensions={0,3}, to_apply=apply
+        norm_doutput = f32[2,4,6,8] multiply(norm, doutput)
+        dscale = f32[4,6] reduce(norm_doutput, c0), dimensions={0,3}, to_apply=apply
+        scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput)
+        input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
+        f0 = f32[2,8] reduce(input_center_scale_doutput, c0), dimensions={1,2}, to_apply=apply
+        norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
+        c1 = f32[] constant(-0.5)
+        c1_bcast = f32[2,8] broadcast(c1), dimensions={}
+        dnorm_factor = f32[2,8] multiply(norm_factor_cube, c1_bcast)
+        f0_dnorm_factor = f32[2,8] multiply(f0, dnorm_factor)
+        c2 = f32[] constant(0.083333)
+        c2_bcast = f32[2,8] broadcast(c2), dimensions={}
+        f0_dnorm_factor_scaled = f32[2,8] multiply(f0_dnorm_factor, c2_bcast)
+        f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,3}
+        f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+        minus_f1 = f32[2,4,6,8] negate(f1)
+        minus_f1_sum = f32[2,8] reduce(minus_f1, c0), dimensions={1,2}, to_apply=apply
+        f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
+        minus_f2 = f32[2,4,6,8] negate(f2)
+        minus_f2_sum = f32[2,8] reduce(minus_f2, c0), dimensions={1,2}, to_apply=apply
+        minus_f1_f2_sum = f32[2,8] add(minus_f1_sum, minus_f2_sum)
+        minus_f1_f2_sum_scaled = f32[2,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
+        minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,3}
+        f1_f2 = f32[2,4,6,8] add(f1, f2)
+        dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+        ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) tuple(norm_scale_bias, dinput, dscale, dbias)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[TRANSPOSE0:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2}
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
+; CHECK:           }
+; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3)
+; CHECK-NEXT:    [[TRANSPOSE1:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P3]]), dimensions={0,3,1,2}
+; CHECK-DAG:     [[P3_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE1]])
+; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[1,4,6,1]{3,2,1,0}, f32[1,4,6,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0
+; CHECK-DAG:         "kind":"LAYER_BWD"
+; CHECK:           }
+; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG:     [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]]
+; CHECK-DAG:     [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0
+; CHECK-DAG:     [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1
+; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE4]])
+; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE5]])
+; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4,6]{1,0}, f32[4,6]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,1,8] parameter(0)
+        input_square = f32[2,4,1,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply
+        reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply
+        r_nelems = f32[] constant(0.25)
+        r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,8] multiply(input_square_sum,r_nelems_bcast)
+        input_mean = f32[2,8] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,8] multiply(input_mean,input_mean)
+        variance = f32[2,8] subtract(input_square_mean,input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast)
+        norm_factor = f32[2,8] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3}
+        input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3}
+        input_center = f32[2,4,1,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,1,8] multiply(input_center, norm_factor_bcast)
+        scale = f32[4,1] parameter(1)
+        scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2}
+        norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast)
+        bias = f32[4,1] parameter(2)
+        bias_bcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2}
+        norm_scale_bias = f32[2,4,1,8] add(norm_scale, bias_bcast)
+        doutput = f32[2,4,1,8] parameter(3)
+        dbias = f32[4,1] reduce(doutput, c0), dimensions={0,3}, to_apply=apply
+        norm_doutput = f32[2,4,1,8] multiply(norm, doutput)
+        dscale = f32[4,1] reduce(norm_doutput, c0), dimensions={0,3}, to_apply=apply
+        scale_doutput = f32[2,4,1,8] multiply(scale_bcast, doutput)
+        input_center_scale_doutput = f32[2,4,1,8] multiply(input_center, scale_doutput)
+        f0 = f32[2,8] reduce(input_center_scale_doutput, c0), dimensions={1,2}, to_apply=apply
+        norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon)
+        c1 = f32[] constant(-0.5)
+        c1_bcast = f32[2,8] broadcast(c1), dimensions={}
+        dnorm_factor = f32[2,8] multiply(norm_factor_cube, c1_bcast)
+        f0_dnorm_factor = f32[2,8] multiply(f0, dnorm_factor)
+        c2 = f32[] constant(0.5)
+        c2_bcast = f32[2,8] broadcast(c2), dimensions={}
+        f0_dnorm_factor_scaled = f32[2,8] multiply(f0_dnorm_factor, c2_bcast)
+        f0_dnorm_factor_scaled_bcast = f32[2,4,1,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,3}
+        f1 = f32[2,4,1,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+        minus_f1 = f32[2,4,1,8] negate(f1)
+        minus_f1_sum = f32[2,8] reduce(minus_f1, c0), dimensions={1,2}, to_apply=apply
+        f2 = f32[2,4,1,8] multiply(norm_factor_bcast, scale_doutput)
+        minus_f2 = f32[2,4,1,8] negate(f2)
+        minus_f2_sum = f32[2,8] reduce(minus_f2, c0), dimensions={1,2}, to_apply=apply
+        minus_f1_f2_sum = f32[2,8] add(minus_f1_sum, minus_f2_sum)
+        minus_f1_f2_sum_scaled = f32[2,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
+        minus_f1_f2_sum_scaled_bcast = f32[2,4,1,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,3}
+        f1_f2 = f32[2,4,1,8] add(f1, f2)
+        dinput = f32[2,4,1,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+        ROOT out = (f32[2,4,1,8], f32[2,4,1,8], f32[4,1], f32[4,1]) tuple(norm_scale_bias, dinput, dscale, dbias)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1], {{.*}}: f32[2,4,1,8]) -> (f32[2,4,1,8], f32[2,4,1,8], f32[4,1], f32[4,1]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[TRANSPOSE0:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1}
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
+; CHECK:           }
+; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(3)
+; CHECK-NEXT:    [[TRANSPOSE1:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P3]]), dimensions={2,0,3,1}
+; CHECK-DAG:     [[P3_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]])
+; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0
+; CHECK-DAG:         "kind":"LAYER_BWD"
+; CHECK:           }
+; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG:     [[FUSION0:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
+; CHECK-DAG:     [[GTEF0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=0
+; CHECK-DAG:     [[GTEF1:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=1
+; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE4]])
+; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE5]])
+; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}, f32[4,1]{1,0}, f32[4,1]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+// TODO(b/343124533) Reenable when fixed
+TEST_F(CudnnNormRewriterTest,
+       DISABLED_LayerNormTrainBackward4D1DoutputReshapeSplit) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,6,8] parameter(0)
+        input_square = f32[2,4,6,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,6,8] reduce(input_square, c0), dimensions={1}, to_apply=apply
+        reduce = f32[2,6,8] reduce(input, c0), dimensions={1}, to_apply=apply
+        r_nelems = f32[] constant(0.25)
+        r_nelems_bcast = f32[2,6,8] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,6,8] multiply(input_square_sum,r_nelems_bcast)
+        input_mean = f32[2,6,8] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,6,8] multiply(input_mean,input_mean)
+        variance = f32[2,6,8] subtract(input_square_mean,input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,6,8] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,6,8] add(variance, epsilon_bcast)
+        norm_factor = f32[2,6,8] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,2,3}
+        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,2,3}
+        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
+        scale = f32[4] parameter(1)
+        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1}
+        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+        bias = f32[4] parameter(2)
+        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1}
+        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
+        doutput = f32[2,4,48] parameter(3)
+        dbias = f32[4] reduce(doutput, c0), dimensions={0,2}, to_apply=apply
+        doutput_bitcast = f32[2,4,6,8] reshape(doutput)
+        norm_doutput = f32[2,4,6,8] multiply(norm, doutput_bitcast)
+        dscale = f32[4] reduce(norm_doutput, c0), dimensions={0,2,3}, to_apply=apply
+        scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput_bitcast)
+        input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
+        f0 = f32[2,6,8] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply
+        norm_factor_cube = f32[2,6,8] divide(norm_factor, variance_plus_epsilon)
+        c1 = f32[] constant(-0.5)
+        c1_bcast = f32[2,6,8] broadcast(c1), dimensions={}
+        dnorm_factor = f32[2,6,8] multiply(norm_factor_cube, c1_bcast)
+        f0_dnorm_factor = f32[2,6,8] multiply(f0, dnorm_factor)
+        c2 = f32[] constant(0.5)
+        c2_bcast = f32[2,6,8] broadcast(c2), dimensions={}
+        f0_dnorm_factor_scaled = f32[2,6,8] multiply(f0_dnorm_factor, c2_bcast)
+        f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,2,3}
+        f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+        minus_f1 = f32[2,4,6,8] negate(f1)
+        minus_f1_sum = f32[2,6,8] reduce(minus_f1, c0), dimensions={1}, to_apply=apply
+        f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
+        minus_f2 = f32[2,4,6,8] negate(f2)
+        minus_f2_sum = f32[2,6,8] reduce(minus_f2, c0), dimensions={1}, to_apply=apply
+        minus_f1_f2_sum = f32[2,6,8] add(minus_f1_sum, minus_f2_sum)
+        minus_f1_f2_sum_scaled = f32[2,6,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
+        minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,2,3}
+        f1_f2 = f32[2,4,6,8] add(f1, f2)
+        dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+        ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4,48]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[TRANSPOSE0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} transpose([[P0]]), dimensions={0,2,3,1}
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
+; CHECK:           }
+; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4,48]{2,1,0} parameter(3)
+; CHECK-DAG:     [[FUSION0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} fusion([[P3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
+; CHECK-DAG:     [[FUSION0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[FUSION0]])
+; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION0_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0
+; CHECK-DAG:         "kind":"LAYER_BWD"
+; CHECK:           }
+; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG:     [[FUSION1:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
+; CHECK-DAG:     [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=0
+; CHECK-DAG:     [[GTEF2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=1
+; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]])
+; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]])
+; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTEF1]], [[GTEF2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+// TODO(b/343124533) Reenable when fixed
+TEST_F(CudnnNormRewriterTest,
+       DISABLED_LayerNormTrainBackward4D1DoutputReshapeCombine) {
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] add(a,b)
+    }
+
+    ENTRY test {
+        input = f32[2,4,6,8] parameter(0)
+        input_square = f32[2,4,6,8] multiply(input, input)
+        c0 = f32[] constant(0)
+        input_square_sum = f32[2,6,8] reduce(input_square, c0), dimensions={1}, to_apply=apply
+        reduce = f32[2,6,8] reduce(input, c0), dimensions={1}, to_apply=apply
+        r_nelems = f32[] constant(0.25)
+        r_nelems_bcast = f32[2,6,8] broadcast(r_nelems), dimensions={}
+        input_square_mean = f32[2,6,8] multiply(input_square_sum,r_nelems_bcast)
+        input_mean = f32[2,6,8] multiply(reduce, r_nelems_bcast)
+        input_mean_square = f32[2,6,8] multiply(input_mean,input_mean)
+        variance = f32[2,6,8] subtract(input_square_mean,input_mean_square)
+        epsilon = f32[] constant(0.001)
+        epsilon_bcast = f32[2,6,8] broadcast(epsilon), dimensions={}
+        variance_plus_epsilon = f32[2,6,8] add(variance, epsilon_bcast)
+        norm_factor = f32[2,6,8] rsqrt(variance_plus_epsilon)
+        norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,2,3}
+        input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,2,3}
+        input_center = f32[2,4,6,8] subtract(input, input_mean_bcast)
+        norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast)
+        scale = f32[4] parameter(1)
+        scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1}
+        norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast)
+        bias = f32[4] parameter(2)
+        bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1}
+        norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast)
+        doutput = f32[2,4,6,2,2,2] parameter(3)
+        dbias = f32[4] reduce(doutput, c0), dimensions={0,2,3,4,5}, to_apply=apply
+        doutput_bitcast = f32[2,4,6,8] reshape(doutput)
+        norm_doutput = f32[2,4,6,8] multiply(norm, doutput_bitcast)
+        dscale = f32[4] reduce(norm_doutput, c0), dimensions={0,2,3}, to_apply=apply
+        scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput_bitcast)
+        input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput)
+        f0 = f32[2,6,8] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply
+        norm_factor_cube = f32[2,6,8] divide(norm_factor, variance_plus_epsilon)
+        c1 = f32[] constant(-0.5)
+        c1_bcast = f32[2,6,8] broadcast(c1), dimensions={}
+        dnorm_factor = f32[2,6,8] multiply(norm_factor_cube, c1_bcast)
+        f0_dnorm_factor = f32[2,6,8] multiply(f0, dnorm_factor)
+        c2 = f32[] constant(0.5)
+        c2_bcast = f32[2,6,8] broadcast(c2), dimensions={}
+        f0_dnorm_factor_scaled = f32[2,6,8] multiply(f0_dnorm_factor, c2_bcast)
+        f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,2,3}
+        f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast)
+        minus_f1 = f32[2,4,6,8] negate(f1)
+        minus_f1_sum = f32[2,6,8] reduce(minus_f1, c0), dimensions={1}, to_apply=apply
+        f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput)
+        minus_f2 = f32[2,4,6,8] negate(f2)
+        minus_f2_sum = f32[2,6,8] reduce(minus_f2, c0), dimensions={1}, to_apply=apply
+        minus_f1_f2_sum = f32[2,6,8] add(minus_f1_sum, minus_f2_sum)
+        minus_f1_f2_sum_scaled = f32[2,6,8] multiply(minus_f1_f2_sum, r_nelems_bcast)
+        minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,2,3}
+        f1_f2 = f32[2,4,6,8] add(f1, f2)
+        dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast)
+        ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias)
+    })";
+
+  const char* optimized_hlo = R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4,6,2,2,2]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0)
+; CHECK-NEXT:    [[TRANSPOSE0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} transpose([[P0]]), dimensions={0,2,3,1}
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT:    [[P2_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P2]])
+; CHECK-NEXT:    [[CC0:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0.001
+; CHECK-DAG:         "kind":"LAYER_FWD_TRAIN"
+; CHECK:           }
+; CHECK-DAG:     [[GTE0:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0
+; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4,6,2,2,2]{5,4,3,2,1,0} parameter(3)
+; CHECK-DAG:     [[FUSION0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} fusion([[P3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]]
+; CHECK-DAG:     [[FUSION0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[FUSION0]])
+; CHECK-DAG:     [[GTE1:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1
+; CHECK-DAG:     [[GTE2:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2
+; CHECK-NEXT:    [[CC1:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION0_BITCAST]], [[GTE1]], [[GTE2]]),
+; CHECK:           custom_call_target="__cudnn$norm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "epsilon":0
+; CHECK-DAG:         "kind":"LAYER_BWD"
+; CHECK:           }
+; CHECK-DAG:     [[GTE3:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0
+; CHECK-DAG:     [[FUSION1:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]]
+; CHECK-DAG:     [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=0
+; CHECK-DAG:     [[GTEF2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=1
+; CHECK-DAG:     [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1
+; CHECK-DAG:     [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]])
+; CHECK-DAG:     [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2
+; CHECK-DAG:     [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]])
+; CHECK-DAG:  ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTEF1]], [[GTEF2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]])
+  )";
+
+  TestNorm(hlo_text, optimized_hlo);
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc
new file mode 100644
index 0000000..2acdf9a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc
@@ -0,0 +1,528 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h"
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/functional/bind_front.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/literal_util.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/cudnn_support_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+// Creates and returns an HLO that zero-pads one or more dimensions in the given
+// instruction so that its shape is equal to the given shape.
+//
+// Padding is added to the end of each relevant dimension.
+//
+// If the instruction already has the given shape, simply returns it without an
+// intervening pad.
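+//
+// Illustrative example: padding f16[10,20,30,41] to the new shape
+// f16[10,20,30,48] emits a zero constant and a pad instruction with
+// edge_padding_high=7 on the last dimension, leaving the other dimensions
+// untouched.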
+static HloInstruction* PadInstruction(HloInstruction* instr,
+                                      const Shape& new_shape) {
+  HloComputation* comp = instr->parent();
+
+  const Shape& shape = instr->shape();
+  PaddingConfig pad_config = MakeNoPaddingConfig(shape.rank());
+
+  bool added_padding = false;
+  for (int64_t dim = 0; dim < shape.rank(); ++dim) {
+    if (shape.dimensions(dim) == new_shape.dimensions(dim)) {
+      continue;
+    }
+    CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim));
+    pad_config.mutable_dimensions(dim)->set_edge_padding_high(
+        new_shape.dimensions(dim) - shape.dimensions(dim));
+    added_padding = true;
+  }
+  if (!added_padding) {
+    return instr;
+  }
+
+  auto* zero = comp->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
+  return comp->AddInstruction(
+      HloInstruction::CreatePad(new_shape, instr, zero, pad_config),
+      &instr->metadata());
+}
+
+// Modifies the given convolution to have the given input and result shapes.
+static absl::Status PadConv(HloCustomCallInstruction* conv,
+                            absl::Span<const Shape> new_input_shapes,
+                            const Shape& new_result_shape) {
+  CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0))
+      << "conv must use 0 scratch bytes, i.e. this pass must be run "
+         "before CudnnConvAlgorithmPicker.";
+  std::vector<HloInstruction*> new_operands;
+  new_operands.reserve(conv->operand_count());
+  for (int i = 0; i < conv->operand_count(); ++i) {
+    new_operands.push_back(
+        PadInstruction(conv->mutable_operand(i), new_input_shapes[i]));
+  }
+  const Shape& result_shape = conv->shape().tuple_shapes(0);
+
+  bool changed = false;
+  for (int i = 0; i < conv->operand_count(); ++i) {
+    changed |= (new_operands[i] != conv->mutable_operand(i));
+  }
+  CHECK(changed) << "We should have had to pad at least one input operand.";
+
+  auto add = [&](std::unique_ptr<HloInstruction> new_instr) {
+    return conv->parent()->AddInstruction(std::move(new_instr));
+  };
+
+  Shape new_conv_shape = ShapeUtil::MakeTupleShape(
+      {new_result_shape, ShapeUtil::MakeShape(U8, {0})});
+  auto* new_conv =
+      add(conv->CloneWithNewOperands(new_conv_shape, new_operands));
+
+  // Clone conv's name to new_conv.  This is safe because we're going to remove
+  // conv below.
+  new_conv->SetAndSanitizeName(conv->name());
+
+  VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with "
+          << new_conv->ToString();
+
+  // Slice the new conv result if necessary, keeping in mind that new_conv
+  // has tuple shape (new_result_shape, u8[0]).
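+  //
+  // For example, if the padded result shape is f16[10,20,30,48] while the
+  // original result was f16[10,20,30,41], this emits a slice with start
+  // indices 0, limit indices {10,20,30,41}, and strides 1, and re-tuples the
+  // sliced result with an empty u8[0] scratch buffer.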
+  if (!ShapeUtil::Equal(result_shape, new_result_shape)) {
+    std::vector<int64_t> start_indices(result_shape.dimensions_size(), 0);
+    std::vector<int64_t> end_indices(result_shape.dimensions().begin(),
+                                     result_shape.dimensions().end());
+    std::vector<int64_t> strides(result_shape.dimensions_size(), 1);
+
+    auto* new_conv_result = add(
+        HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0));
+    auto* empty_temp_buffer =
+        add(HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint8_t>({})));
+    auto* sliced_result = add(HloInstruction::CreateSlice(
+        result_shape, new_conv_result, start_indices, end_indices, strides));
+    new_conv =
+        add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer}));
+  }
+
+  return conv->parent()->ReplaceInstruction(conv, new_conv);
+}
+
+static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
+    HloComputation* comp) {
+  std::vector<HloCustomCallInstruction*> convs;
+  for (HloInstruction* instr : comp->instructions()) {
+    if (IsCustomCallToDnnConvolution(*instr)) {
+      convs.push_back(Cast<HloCustomCallInstruction>(instr));
+    }
+  }
+  return convs;
+}
+
+// This is the main function of the transform.  It runs on a given custom-call
+// node to a cuDNN convolution, calls resolve_pad_shapes to determine the
+// desired input/output feature map shapes, and adds the necessary padding and
+// slicing nodes around it.
+//
+// resolve_pad_shapes takes conv, a custom-call instruction to a cuDNN
+// convolution that may need padding, figures out the desired padded input and
+// output tensor shapes, and stores them in new_input_shapes and
+// new_result_shape.  Note that new_input_shapes is a vector because there may
+// be multiple input tensors.  In addition to the status, the callback returns
+// true if padding is necessary and false otherwise.
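+//
+// The pass body below invokes this as, for example,
+//   ResolveAndPad(conv, TryResolvePaddedShapesForTensorCore)
+// and, for integer convolutions,
+//   ResolveAndPad(conv,
+//                 absl::bind_front(TryResolvePaddedShapesForIntegerConvolution,
+//                                  /*pad_to=*/32, compute_capability_));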
+static absl::StatusOr<bool> ResolveAndPad(
+    HloCustomCallInstruction* conv,
+    std::function<absl::StatusOr<bool>(HloCustomCallInstruction* conv,
+                                       std::vector<Shape>* new_input_shapes,
+                                       Shape* new_result_shape)>
+        resolve_pad_shapes) {
+  std::vector<Shape> new_input_shapes;
+  Shape new_result_shape;
+  TF_ASSIGN_OR_RETURN(bool result, resolve_pad_shapes(conv, &new_input_shapes,
+                                                      &new_result_shape));
+  if (result) {
+    TF_RETURN_IF_ERROR(PadConv(conv, new_input_shapes, new_result_shape));
+    return true;
+  }
+  return false;
+}
+
+// Adds padding to cudnn convolutions to make them run faster on GPUs with
+// tensor cores.
+//
+//  - f16 convolutions are padded to have input/output channel dimensions that
+//    are multiples of 8, so that we can use tensor cores.
+//
+//  - f16 convolutions with 3 input channels and 32 or 64 output channels are
+//    padded to 4 input channels.  There's a special-cased cudnn algorithm just
+//    for this.
+//
+// Don't run this pass on GPUs without tensor cores -- it will make them slower!
+//
+// TODO(jlebar): Also pad dots.
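+//
+// Illustrative example (mirroring the unit tests): an f16 forward convolution
+// with input f16[10,20,30,41] and filter f16[2,2,41,40] has its 41 input
+// features rounded up to 48 on both the input and the filter, while the 40
+// output features are already a multiple of 8 and are left unchanged.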
+static absl::StatusOr<bool> TryResolvePaddedShapesForTensorCore(
+    HloCustomCallInstruction* conv, std::vector<Shape>* new_input_shapes_ptr,
+    Shape* new_result_shape_ptr) {
+  TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
+  const auto& dnums = conv->convolution_dimension_numbers();
+  auto* lhs = conv->mutable_operand(0);
+  auto* rhs = conv->mutable_operand(1);
+  const Shape& result_shape = conv->shape().tuple_shapes(0);
+
+  // Nothing to do on non-f16 convolutions.
+  if (result_shape.element_type() != PrimitiveType::F16) {
+    return false;
+  }
+
+  // When the convolution is grouped, the input and output feature shapes must
+  // stay consistent with the group size, so we cannot pad them independently.
+  if (conv->feature_group_count() > 1 || conv->batch_group_count() > 1) {
+    VLOG(2) << "Do not pad grouped convolution.";
+    return false;
+  }
+
+  // TODO(timshen): Don't skip forward-activation convs if we find a benchmark
+  // where there's a speedup.
+  if (kind == CudnnConvKind::kForwardActivation) {
+    return false;
+  }
+
+  Shape new_lhs_shape = lhs->shape();
+  Shape new_rhs_shape = rhs->shape();
+  Shape& new_result_shape = *new_result_shape_ptr;
+  new_result_shape = conv->shape().tuple_shapes(0);
+
+  // new_{input,filter,output}_shape each point to the appropriate one of
+  // new_{lhs,rhs,result}_shape, depending on the convolution kind.
+  Shape* new_input_shape;
+  Shape* new_filter_shape;
+  Shape* new_output_shape;
+  std::tie(new_input_shape, new_filter_shape, new_output_shape) = [&] {
+    switch (kind) {
+      case CudnnConvKind::kForward:
+      case CudnnConvKind::kForwardActivation:
+      case CudnnConvKind::kForwardGraph:
+        return std::make_tuple(&new_lhs_shape, &new_rhs_shape,
+                               &new_result_shape);
+      case CudnnConvKind::kBackwardInput:
+        return std::make_tuple(&new_result_shape, &new_rhs_shape,
+                               &new_lhs_shape);
+      case CudnnConvKind::kBackwardFilter:
+        return std::make_tuple(&new_lhs_shape, &new_result_shape,
+                               &new_rhs_shape);
+    }
+  }();
+
+  // If there are 3 input features and 32 or 64 output features, pad the input
+  // features to 4.  Otherwise, try padding to multiples of 8 and check that
+  // this doesn't make any of the conv buffers too much larger.
+  auto input_features =
+      new_input_shape->dimensions(dnums.input_feature_dimension());
+  auto output_features =
+      new_output_shape->dimensions(dnums.output_feature_dimension());
+  if (input_features == 3 && (output_features == 32 || output_features == 64)) {
+    new_input_shape->set_dimensions(dnums.input_feature_dimension(), 4);
+    new_filter_shape->set_dimensions(dnums.kernel_input_feature_dimension(), 4);
+  } else {
+    auto pad_dim = [](Shape* s, int64_t dim) {
+      s->set_dimensions(dim, RoundUpTo<int64_t>(s->dimensions(dim), 8));
+    };
+    pad_dim(new_input_shape, dnums.input_feature_dimension());
+    pad_dim(new_filter_shape, dnums.kernel_input_feature_dimension());
+    pad_dim(new_filter_shape, dnums.kernel_output_feature_dimension());
+    pad_dim(new_output_shape, dnums.output_feature_dimension());
+
+    // We won't pad a conv if doing so increases the total number of bytes in
+    // the lhs, rhs, or result by more than this amount.
+    //
+    // TODO(jlebar): This number was tuned experimentally.  It represents a
+    // compromise on our current benchmarks; it speeds some up significantly,
+    // and doesn't slow any down.  But we can observe by changing this value
+    // that there's additional room for speedups.  Achieving those speedups
+    // without also slowing other things down will likely require a more
+    // sophisticated heuristic, possibly some form of auto-tuning.
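+    //
+    // For example, with the 1.35x bound below, padding 41 channels up to 48
+    // grows that buffer by 48/41 ~= 1.17x and is accepted, whereas padding 5
+    // channels up to 8 (1.6x) would be rejected.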
+    static constexpr double kMaxBytesTouchedBound = 1.35;
+
+    // Check that padding wouldn't increase the total bytes read/written by this
+    // operation too much.
+    auto check_size_increase = [&](const Shape& old_shape,
+                                   const Shape& new_shape) {
+      int64_t old_bytes = ShapeUtil::ByteSizeOf(old_shape);
+      int64_t new_bytes = ShapeUtil::ByteSizeOf(new_shape);
+      if (new_bytes <= old_bytes * kMaxBytesTouchedBound) {
+        return true;
+      }
+      VLOG(3)
+          << "Not padding convolution; doing so would change input / result "
+             "shape from "
+          << ShapeUtil::HumanString(old_shape) << " to "
+          << ShapeUtil::HumanString(new_shape) << ", a size increase of "
+          << new_bytes / static_cast<double>(old_bytes) << "x > "
+          << kMaxBytesTouchedBound << "x: " << conv->ToString();
+      return false;
+    };
+
+    if (!check_size_increase(lhs->shape(), new_lhs_shape) ||
+        !check_size_increase(rhs->shape(), new_rhs_shape) ||
+        !check_size_increase(result_shape, new_result_shape)) {
+      return false;
+    }
+  }
+
+  if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) &&
+      ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) {
+    VLOG(3) << "No need to pad features of " << conv->ToString();
+    return false;
+  }
+
+  new_input_shapes_ptr->push_back(new_lhs_shape);
+  new_input_shapes_ptr->push_back(new_rhs_shape);
+  return true;
+}
+
+// Adds padding to cudnn integer convolutions to make input and output feature
+// maps multiples of pad_to (usually 4 or 32).
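+//
+// Illustrative example: with pad_to == 32, an s8 convolution with 20 input
+// features and 30 output features has both feature dimensions rounded up to
+// 32, provided cuDNN supports the optimized integer layout for that
+// convolution and the size-increase bound below is not exceeded.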
+absl::StatusOr<bool> TryResolvePaddedShapesForIntegerConvolution(
+    int pad_to, const se::CudaComputeCapability& compute_capability,
+    HloCustomCallInstruction* conv, std::vector<Shape>* new_input_shapes_ptr,
+    Shape* new_result_shape_ptr) {
+  TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
+  const Shape& input_shape = conv->operand(0)->shape();
+  const Shape& kernel_shape = conv->operand(1)->shape();
+  const Shape& result_shape = conv->shape().tuple_shapes(0);
+
+  // Integer convolution only
+  if (!primitive_util::IsIntegralType(input_shape.element_type())) {
+    return false;
+  }
+
+  // kForward and kForwardActivation only
+  if (kind != CudnnConvKind::kForward &&
+      kind != CudnnConvKind::kForwardActivation) {
+    return false;
+  }
+
+  const auto& dnums = conv->convolution_dimension_numbers();
+  std::vector<Shape>& new_input_shapes = *new_input_shapes_ptr;
+  for (auto operand : conv->operands()) {
+    new_input_shapes.push_back(operand->shape());
+  }
+  Shape& new_result_shape = *new_result_shape_ptr;
+  new_result_shape = conv->shape().tuple_shapes(0);
+
+  // The input/kernel/output might already be vectorized (i.e. cudnn layout
+  // NCHW_VECT_C).  If so, we pad the features dim so that
+  // size(features_dim) * size(vect_dim) is a multiple of pad_to.
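+  //
+  // For example, with pad_to == 32 and an input already vectorized with a
+  // vect_dim of size 4 (int8x4), the features dimension is rounded up to a
+  // multiple of 32 / 4 = 8; e.g. 5 vectorized feature groups become 8, i.e.
+  // 20 logical channels become 32.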
+  std::optional<int64_t> input_vect_dim;
+  std::optional<int64_t> kernel_vect_dim;
+  std::optional<int64_t> result_vect_dim;
+  std::tie(input_vect_dim, kernel_vect_dim, result_vect_dim) =
+      FindVectorizedFeatureDims(dnums, input_shape, kernel_shape, result_shape);
+
+  int64_t input_vect_size =
+      input_vect_dim.has_value() ? input_shape.dimensions(*input_vect_dim) : 1;
+  int64_t kernel_vect_size = kernel_vect_dim.has_value()
+                                 ? kernel_shape.dimensions(*kernel_vect_dim)
+                                 : 1;
+  int64_t result_vect_size = result_vect_dim.has_value()
+                                 ? result_shape.dimensions(*result_vect_dim)
+                                 : 1;
+  if (pad_to % input_vect_size != 0 || pad_to % kernel_vect_size != 0 ||
+      pad_to % result_vect_size != 0) {
+    // If the conv is already vectorized but pad_to is not a multiple of the
+    // vector size, we choose not to pad.  This is a weird case, because the
+    // only useful vector sizes in cudnn (as of writing) are 4 and 32, and
+    // those are also the only pad_to values this pass is called with.
+    return false;
+  }
+
+  // Check that cuDNN supports our desired integer padding/vectorization.
+  TF_ASSIGN_OR_RETURN(bool cudnn_supports,
+                      CudnnSupportsOptimizedIntegerConvolution(
+                          compute_capability, *conv, pad_to));
+  if (!cudnn_supports) {
+    return false;
+  }
+
+  // Pad the features to multiples of pad_to.
+  {
+    auto pad_dim = [&](Shape* s, int64_t dim, int64_t cur_vect_size) {
+      CHECK_EQ(pad_to % cur_vect_size, 0);
+      s->set_dimensions(
+          dim, RoundUpTo<int64_t>(s->dimensions(dim), pad_to / cur_vect_size));
+    };
+
+    switch (kind) {
+      case CudnnConvKind::kForward:
+        CHECK_EQ(new_input_shapes.size(), 2);
+        // Input feature maps
+        pad_dim(new_input_shapes.data(), dnums.input_feature_dimension(),
+                input_vect_size);
+        // Kernel for the input feature maps
+        pad_dim(&new_input_shapes[1], dnums.kernel_input_feature_dimension(),
+                kernel_vect_size);
+        // Kernel for the output feature maps.  In the NCHW_VECT_C layout, only
+        // the kernel input feature dim is vectorized, so this has cur_vect_size 1.
+        pad_dim(&new_input_shapes[1], dnums.kernel_output_feature_dimension(),
+                /*cur_vect_size=*/1);
+        // Output feature maps
+        pad_dim(&new_result_shape, dnums.output_feature_dimension(),
+                result_vect_size);
+        break;
+      case CudnnConvKind::kForwardActivation:
+        CHECK(new_input_shapes.size() == 3 || new_input_shapes.size() == 4);
+        // Input feature maps
+        pad_dim(new_input_shapes.data(), dnums.input_feature_dimension(),
+                input_vect_size);
+        // Kernel for the input feature maps
+        pad_dim(&new_input_shapes[1], dnums.kernel_input_feature_dimension(),
+                kernel_vect_size);
+        // Kernel for the output feature maps.  In the NCHW_VECT_C layout, only
+        // the kernel input feature dim is vectorized, so this has cur_vect_size 1.
+        pad_dim(&new_input_shapes[1], dnums.kernel_output_feature_dimension(),
+                /*cur_vect_size=*/1);
+
+        // Bias.  This is a 1D vector of length output-depth, and it's unclear if
+        // we *have* to pad it.  But hey, we might as well.  cur_vect_size 1
+        // because NCHW_VECT_C doesn't apply here (there is no channels
+        // dimension!).
+        pad_dim(&new_input_shapes[2], /*dim=*/0, /*cur_vect_size=*/1);
+
+        if (new_input_shapes.size() == 4) {
+          // Optional side input.  Same layout as result, so gets padded the
+          // same.
+          pad_dim(&new_input_shapes[3], dnums.output_feature_dimension(),
+                  result_vect_size);
+        }
+        // Output feature maps
+        pad_dim(&new_result_shape, dnums.output_feature_dimension(),
+                result_vect_size);
+        break;
+      default:
+        CHECK(false);
+    }
+
+    // We won't pad a conv if doing so increases the total number of bytes in
+    // the lhs, rhs, or result by a factor of this much or more.
+    //
+    // Note: It's important that this bound is exclusive.  It's a performance
+    // regression to pad and increase input/output size by 2x, so we only pad
+    // strictly less than 2x.
+    //
+    // TODO(jlebar): This number was tuned experimentally, but without much
+    // experimental evidence.
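+    //
+    // For example, padding 17 channels up to 32 (~1.88x) is accepted, while
+    // padding 16 up to 32 (exactly 2x) is rejected because the bound is
+    // exclusive.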
+    static constexpr double kMaxBytesTouchedBound = 2;
+
+    // Check that padding wouldn't increase the total bytes read/written by this
+    // operation too much.
+    auto check_size_increase = [&](const Shape& old_shape,
+                                   const Shape& new_shape) {
+      int64_t old_bytes = ShapeUtil::ByteSizeOf(old_shape);
+      int64_t new_bytes = ShapeUtil::ByteSizeOf(new_shape);
+      if (new_bytes < old_bytes * kMaxBytesTouchedBound) {
+        return true;
+      }
+      VLOG(3)
+          << "Not padding convolution; doing so would change input / result "
+             "shape from "
+          << ShapeUtil::HumanString(old_shape) << " to "
+          << ShapeUtil::HumanString(new_shape) << ", a size increase of "
+          << new_bytes / static_cast<double>(old_bytes)
+          << "x >= " << kMaxBytesTouchedBound << "x: " << conv->ToString();
+      return false;
+    };
+
+    // Check size increase only on the input and output.  No need to check the
+    // filter, since that's determined by the input/output.  The bias (if
+    // present) is tiny (1D array of length output-depth), so padding doesn't
+    // matter.  And the side-input, if present, is the same shape as the input.
+    if (!check_size_increase(conv->operand(0)->shape(), new_input_shapes[0]) ||
+        !check_size_increase(result_shape, new_result_shape)) {
+      return false;
+    }
+  }
+
+  bool changed = false;
+  for (int64_t i = 0; i < conv->operand_count(); ++i) {
+    changed |=
+        !ShapeUtil::Equal(conv->operand(i)->shape(), new_input_shapes[i]);
+  }
+  if (!changed) {
+    VLOG(3) << "No need to pad features of " << conv->ToString();
+  }
+
+  return changed;
+}
+
+absl::StatusOr<bool> CudnnPadForConvolutions::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
+      // On Turing and later (sm75+), pad to multiples of 32 bytes if possible,
+      // because that lets us use the fast int8x32 data type.
+      bool local_changed = false;
+      if (compute_capability_.IsAtLeast(7, 5)) {
+        TF_ASSIGN_OR_RETURN(
+            local_changed,
+            ResolveAndPad(conv, absl::bind_front(
+                                    TryResolvePaddedShapesForIntegerConvolution,
+                                    32, compute_capability_)));
+      }
+      if (!local_changed) {
+        TF_ASSIGN_OR_RETURN(
+            local_changed,
+            ResolveAndPad(conv, absl::bind_front(
+                                    TryResolvePaddedShapesForIntegerConvolution,
+                                    4, compute_capability_)));
+      }
+      changed |= local_changed;
+    }
+    if (compute_capability_.IsAtLeast(se::CudaComputeCapability::VOLTA)) {
+      for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
+        TF_ASSIGN_OR_RETURN(
+            bool local_changed,
+            ResolveAndPad(conv, TryResolvePaddedShapesForTensorCore));
+        changed |= local_changed;
+      }
+    }
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.h b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.h
new file mode 100644
index 0000000..719efec
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.h
@@ -0,0 +1,53 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_PAD_FOR_CONVOLUTIONS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_PAD_FOR_CONVOLUTIONS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+
+namespace xla {
+namespace gpu {
+
+// Two zero-paddings for CuDNN thunking are done in this transform: padding for
+// tensor cores and padding for integer convolutions.  This transform also adds
+// slice instructions to remove the unnecessary output features introduced by
+// the padding.
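+//
+// Typical usage (mirroring the unit tests): construct the pass with the target
+// compute capability and run it on an HloModule, e.g.
+//   bool changed = CudnnPadForConvolutions({7, 5}).Run(module.get()).value();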
+class CudnnPadForConvolutions : public HloModulePass {
+ public:
+  explicit CudnnPadForConvolutions(se::CudaComputeCapability compute_capability)
+      : compute_capability_(compute_capability) {}
+
+  absl::string_view name() const override {
+    return "cudnn_pad_for_convolutions";
+  }
+  // Runs the pass on the given module and returns whether any change was
+  // made.
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  const se::CudaComputeCapability compute_capability_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_PAD_FOR_CONVOLUTIONS_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc
new file mode 100644
index 0000000..7cee2c5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc
@@ -0,0 +1,456 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = xla::match;
+
+class CudnnPadForConvolutionsTest : public HloTestBase {};
+
+TEST_F(CudnnPadForConvolutionsTest, DoNotPadF16ForwardConvWhenGrouped) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = f16[704,48,1,49]{3,2,1,0} parameter(0)
+    filter = f16[44,768,1,50]{3,2,1,0} parameter(1)
+    ROOT result = (f16[1,128,48,768]{3,2,1,0}, u8[0]{0})
+      custom-call(input, filter)
+      , window={size=1x50 pad=0_0x64_64}
+      , dim_labels=fb01_io01->01bf
+      , feature_group_count=16
+      , custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  EXPECT_FALSE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadF16ForwardConvInputChannels) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = f16[10,20,30,41] parameter(0)
+    filter = f16[2,2,41,40] parameter(1)
+    ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+
+  SCOPED_TRACE(module->ToString());
+
+  EXPECT_THAT(
+      root,
+      GmockMatch(m::CustomCall(
+          {kCudnnConvForwardCallTarget},
+          m::Pad(m::Parameter(0), m::Op()).WithShape(F16, {10, 20, 30, 48}),
+          m::Pad(m::Parameter(1), m::Op()).WithShape(F16, {2, 2, 48, 40}))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardInputConvOutputChannels) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    output = f16[10,20,30,41] parameter(0)
+    filter = f16[2,2,40,41] parameter(1)
+    ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convBackwardInput"
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      GmockMatch(m::CustomCall(
+          {kCudnnConvBackwardInputCallTarget},
+          m::Pad(m::Parameter(0), m::Op()).WithShape(F16, {10, 20, 30, 48}),
+          m::Pad(m::Parameter(1), m::Op()).WithShape(F16, {2, 2, 40, 48}))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadF16ForwardConvOutputChannels) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = f16[10,20,30,40] parameter(0)
+    filter = f16[2,2,40,41] parameter(1)
+    ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::Tuple(
+                        m::Slice(m::GetTupleElement(m::CustomCall(
+                            {kCudnnConvForwardCallTarget}, m::Parameter(0),
+                            m::Pad(m::Parameter(1), m::Op())))),
+                        m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardInputConvInputChannels) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    output = f16[10,20,30,40] parameter(0)
+    filter = f16[2,2,41,40] parameter(1)
+    result = (f16[10,20,30,41], u8[0]) custom-call(output, filter),
+              window={size=2x2}, dim_labels=b01f_01io->b01f,
+              custom_call_target="__cudnn$convBackwardInput"
+    ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root,
+              GmockMatch(m::GetTupleElement(m::Tuple(
+                  m::Slice(m::GetTupleElement(m::CustomCall(
+                      {kCudnnConvBackwardInputCallTarget}, m::Parameter(0),
+                      m::Pad(m::Parameter(1), m::Op())))),
+                  m::Op()))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardFilterConvInputChannels) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = f16[10,20,30,41] parameter(0)
+    output = f16[10,20,30,40] parameter(1)
+    result = (f16[2,2,41,40], u8[0]) custom-call(input, output),
+              window={size=2x2}, dim_labels=b01f_01io->b01f,
+              custom_call_target="__cudnn$convBackwardFilter"
+    ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root,
+              GmockMatch(m::GetTupleElement(m::Tuple(
+                  m::Slice(m::GetTupleElement(m::CustomCall(
+                      {kCudnnConvBackwardFilterCallTarget},
+                      m::Pad(m::Parameter(0), m::Op()), m::Parameter(1)))),
+                  m::Op()))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadF16BackwardFilterConvOutputChannels) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = f16[10,20,30,40] parameter(0)
+    output = f16[10,20,30,41] parameter(1)
+    result = (f16[2,2,40,41], u8[0]) custom-call(input, output),
+              window={size=2x2}, dim_labels=b01f_01io->b01f,
+              custom_call_target="__cudnn$convBackwardFilter"
+    ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root,
+              GmockMatch(m::GetTupleElement(m::Tuple(
+                  m::Slice(m::GetTupleElement(m::CustomCall(
+                      {kCudnnConvBackwardFilterCallTarget}, m::Parameter(0),
+                      m::Pad(m::Parameter(1), m::Op())))),
+                  m::Op()))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadInputFeatures3To4) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = f16[10,20,30,3] parameter(0)
+    filter = f16[2,2,3,32] parameter(1)
+    ROOT result = (f16[10,20,30,32], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+
+  SCOPED_TRACE(module->ToString());
+  EXPECT_THAT(
+      root,
+      GmockMatch(m::CustomCall(
+          {kCudnnConvForwardCallTarget},
+          m::Pad(m::Parameter(0), m::Op()).WithShape(F16, {10, 20, 30, 4}),
+          m::Pad(m::Parameter(1), m::Op()).WithShape(F16, {2, 2, 4, 32}))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadIntForwardConvInputChannels) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,41] parameter(0)
+    filter = s8[2,2,41,40] parameter(1)
+    ROOT result = (f32[10,20,30,40], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+
+  SCOPED_TRACE(module->ToString());
+  EXPECT_THAT(
+      root,
+      GmockMatch(m::CustomCall(
+          {kCudnnConvForwardCallTarget},
+          m::Pad(m::Parameter(0), m::Op()).WithShape(S8, {10, 20, 30, 44}),
+          m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 44, 40}))));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadIntForwardConvOutputChannels) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,40] parameter(0)
+    filter = s8[2,2,40,41] parameter(1)
+    ROOT result = (f32[10,20,30,41], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::Tuple(
+                        m::Slice(m::GetTupleElement(m::CustomCall(
+                            {kCudnnConvForwardCallTarget}, m::Parameter(0),
+                            m::Pad(m::Parameter(1), m::Op())))),
+                        m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadInt8To32OnSm75) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,40] parameter(0)
+    filter = s8[2,2,40,41] parameter(1)
+    ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Slice(m::GetTupleElement(m::CustomCall(
+              {kCudnnConvForwardCallTarget},
+              m::Pad(m::Parameter(0), m::Op()).WithShape(S8, {10, 20, 30, 64}),
+              m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 64, 64})))),
+          m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, NoPadInt8To32OnSm70) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,40] parameter(0)
+    filter = s8[2,2,40,41] parameter(1)
+    ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Slice(m::GetTupleElement(m::CustomCall(
+              {kCudnnConvForwardCallTarget}, m::Parameter(0),
+              m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 40, 44})))),
+          m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, NoPadInt8To32FloatOutputSm75) {
+  // This test checks that the padding pass correctly calls
+  // CudnnSupportsOptimizedIntegerConvolution(), which should reject int8x32
+  // padding for this convolution because its output type is f32.  The
+  // convolution should instead be padded to int8x4, which supports float
+  // outputs.
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,38] parameter(0)
+    filter = s8[2,2,38,41] parameter(1)
+    ROOT result = (f32[10,20,30,41], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Slice(m::GetTupleElement(m::CustomCall(
+              {kCudnnConvForwardCallTarget},
+              m::Pad(m::Parameter(0), m::Op()).WithShape(S8, {10, 20, 30, 40}),
+              m::Pad(m::Parameter(1), m::Op()).WithShape(S8, {2, 2, 40, 44})))),
+          m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, NoPadInt8UnsupportedFilterTypeOutputSm75) {
+  // This test checks that the padding pass correctly calls
+  // CudnnSupportsOptimizedIntegerConvolution() which should reject this
+  // convolution because the kernel type is f32.
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,38] parameter(0)
+    filter = f32[2,2,38,41] parameter(1)
+    ROOT result = (s8[10,20,30,41], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  EXPECT_FALSE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+}
+
+TEST_F(CudnnPadForConvolutionsTest, NoPadToInt8x32ExcessiveBlowup) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[128,4,48,48] parameter(0)
+    filter = s8[64,4,3,3] parameter(1)
+    ROOT result = (f32[128,64,48,48], u8[0]) custom-call(input, filter),
+                  window={size=3x3}, dim_labels=bf01_io01->bf01,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  EXPECT_FALSE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadInt8x4To32) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,41,4] parameter(0)
+    filter = s8[2,2,41,4,168] parameter(1)
+    ROOT result = (s8[10,20,30,42,4], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f?_01i?o->b01f?,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Slice(m::GetTupleElement(
+                       m::CustomCall({kCudnnConvForwardCallTarget},
+                                     m::Pad(m::Parameter(0), m::Op())
+                                         .WithShape(S8, {10, 20, 30, 48, 4}),
+                                     m::Pad(m::Parameter(1), m::Op())
+                                         .WithShape(S8, {2, 2, 48, 4, 192})))
+                       .WithShape(S8, {10, 20, 30, 48, 4})),
+          m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest, PadInt8x4To32BiasActivation) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,41,4] parameter(0)
+    filter = s8[2,2,41,4,168] parameter(1)
+    bias = f32[10] parameter(2)
+    side_input = s8[10,20,30,42,4] parameter(3)
+    ROOT result = (s8[10,20,30,42,4], u8[0]) custom-call(input, filter, bias, side_input),
+                  window={size=2x2}, dim_labels=b01f?_01i?o->b01f?,
+                  custom_call_target="__cudnn$convBiasActivationForward"
+  })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 5}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Slice(
+              m::GetTupleElement(
+                  m::CustomCall(
+                      {kCudnnConvBiasActivationForwardCallTarget},
+                      m::Pad(m::Parameter(0), m::Op())
+                          .WithShape(S8, {10, 20, 30, 48, 4}),
+                      m::Pad(m::Parameter(1), m::Op())
+                          .WithShape(S8, {2, 2, 48, 4, 192}),
+                      m::Pad(m::Parameter(2), m::Op()).WithShape(F32, {32}),
+                      m::Pad(m::Parameter(3), m::Op())
+                          .WithShape(S8, {10, 20, 30, 48, 4})))
+                  .WithShape(S8, {10, 20, 30, 48, 4})),
+          m::Op())));
+}
+
+TEST_F(CudnnPadForConvolutionsTest,
+       PadIntFusedForwardConvInputAndOutputChannels) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule Test
+
+    ENTRY %Test (input: s8[1,3,3,2], filter: s8[3,3,2,5], side_input: s8[1,3,3,5], bias: s8[5]) -> f32[1,3,3,5] {
+    %input = s8[1,3,3,3]{3,2,1,0} parameter(0)
+    %filter = s8[3,3,2,5]{3,2,1,0} parameter(1)
+    %bias = s8[5]{0} parameter(3)
+    %convert = f32[5]{0} convert(s8[5]{0} %bias)
+    %side_input = f32[1,3,3,5]{3,2,1,0} parameter(2)
+    %custom-call.1 = (f32[1,3,3,5]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,3,3,3]{3,2,1,0} %input, s8[3,3,2,5]{3,2,1,0} %filter, f32[5]{0} %convert, f32[1,3,3,5]{3,2,1,0} %side_input), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"activationMode\":\"2\",\"convResultScale\":1,\"sideInputScale\":1}"
+    ROOT %get-tuple-element.1 = f32[1,3,3,5]{3,2,1,0} get-tuple-element((f32[1,3,3,5]{3,2,1,0}, u8[0]{0}) %custom-call.1), index=0
+    })")
+                    .value();
+  EXPECT_TRUE(CudnnPadForConvolutions({7, 0}).Run(module.get()).value());
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::GetTupleElement(m::Tuple(
+                        m::Slice(m::GetTupleElement(m::CustomCall(
+                            {kCudnnConvBiasActivationForwardCallTarget},
+                            m::Pad(m::Parameter(0), m::Op()),
+                            m::Pad(m::Parameter(1), m::Op()),
+                            m::Pad(m::Convert(m::Parameter(3)), m::Op()),
+                            m::Pad(m::Parameter(2), m::Op())))),
+                        m::Op()))));
+}
+
+}  // anonymous namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc
new file mode 100644
index 0000000..30e0a4b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc
@@ -0,0 +1,482 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_simplify_padding.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <iterator>
+#include <optional>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/literal.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+
+namespace {
+namespace m = ::xla::match;
+
+// If exactly one element of `vals` is false, returns the index of that
+// element.  If zero or more than one element is false, returns nullopt.
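+//
+// For example, {true, false, true} yields 1, while {false, false, true} and
+// {true, true, true} both yield nullopt.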
+std::optional<int64_t> FindFalseIndex(absl::Span<const bool> vals) {
+  std::optional<int64_t> missing_dim;
+  for (int i = 0; i < vals.size(); i++) {
+    if (vals[i]) {
+      continue;
+    }
+    if (missing_dim.has_value()) {
+      VLOG(2) << "Multiple dimensions are missing from conv dnums; can't "
+                 "determine which is vect_c dimension";
+      return std::nullopt;
+    }
+    missing_dim = i;
+  }
+  return missing_dim;
+}
+
+// Finds the vect_c dimension in the convolution's output.
+//
+// The vect_c dimension is the output dimension that's not mentioned in
+// `dnums`.  If there's zero or more than one such dimension, returns nullopt.
+std::optional<int64_t> FindOutputVectCDim(HloInstruction* conv) {
+  const ConvolutionDimensionNumbers& dnums =
+      conv->convolution_dimension_numbers();
+  int64_t num_dims = conv->shape().tuple_shapes(0).dimensions_size();
+  absl::InlinedVector<bool, 5> seen_dims(num_dims);
+  seen_dims[dnums.output_batch_dimension()] = true;
+  seen_dims[dnums.output_feature_dimension()] = true;
+  for (int64_t d : dnums.output_spatial_dimensions()) {
+    seen_dims[d] = true;
+  }
+  return FindFalseIndex(seen_dims);
+}
+
+// Finds the vect_c dimension in the convolution's kernel.
+std::optional<int64_t> FindKernelVectCDim(HloInstruction* conv) {
+  const ConvolutionDimensionNumbers& dnums =
+      conv->convolution_dimension_numbers();
+  int64_t num_dims = conv->operand(1)->shape().dimensions_size();
+  absl::InlinedVector<bool, 5> seen_dims(num_dims);
+  seen_dims[dnums.kernel_input_feature_dimension()] = true;
+  seen_dims[dnums.kernel_output_feature_dimension()] = true;
+  for (int64_t d : dnums.kernel_spatial_dimensions()) {
+    seen_dims[d] = true;
+  }
+  return FindFalseIndex(seen_dims);
+}
+
+// Attempts to count the number of output features at the end of conv that are
+// guaranteed to be 0.
+//
+// This is the same as counting the number of values o at the end of the kernel
+// for which kernel[i,o,h,w] is 0 for all values i,h,w.
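+//
+// For example (hypothetical shapes), if the kernel is
+// pad(s8[3,3,10,6], s8[] constant(0)) padded up to s8[3,3,10,10] along the
+// output-feature dimension, the last 4 output features are known to be zero.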
+std::optional<int64_t> NumTrailingZeroOutputFeatures(HloInstruction* conv) {
+  const ConvolutionDimensionNumbers& dnums =
+      conv->convolution_dimension_numbers();
+  int64_t feature_dim = dnums.kernel_output_feature_dimension();
+  const HloInstruction* weights = conv->operand(1);
+
+  // If the filter is reordered for an int8x32 NCHW_VECT_C convolution, find the
+  // original, un-reordered filter and check *it* for trailing zero output
+  // features.
+  auto backend_config = conv->backend_config<GpuBackendConfig>();
+  if (backend_config.ok() &&
+      backend_config->cudnn_conv_backend_config().reordered_int8_nchw_vect()) {
+    VLOG(2) << "Matched int8x32 convolution with filter reordering";
+
+    // Try to set weights to the original, un-reordered value.
+    const HloInstruction *reshape, *transpose;
+    bool matched =
+        Match(weights, m::Reshape(m::Transpose(
+                           &transpose, m::Reshape(&reshape, m::Op(&weights)))));
+
+    // Verify some properties of the reshape-transpose-reshape pattern.
+    // If these don't hold, it means that some pass (e.g. constant folding)
+    // has modified the filter, making it infeasible to get the original,
+    // un-reordered value.
+    if (!matched || feature_dim != 0 || transpose->shape().rank() != 8) {
+      VLOG(2) << "The filter output feature dimension cannot be determined, as "
+                 "the reordering sequence is modified";
+      return std::nullopt;
+    }
+
+    // Calculate the actual output feature dimension before the transpose.
+    // For example: the input filter [I, O, H, W] will get reshaped to
+    // [I/32, 8, 4, O/8, 4, 2, H, W], transposed in a way that is compatible
+    // with cuDNN INT8x32_CONFIG convolutions (see 'cudnn_support_utils.h') and
+    // reshaped again to [O, I/32, H, W, 32]. While the output feature
+    // dimension of the reordered filter is zero, we need the corresponding
+    // dimension in the original shape (which equals one in this example).
+    const auto& transpose_dimensions =
+        Cast<HloTransposeInstruction>(transpose)->dimensions();
+
+    // Calculate the combined size of the dimensions that precede the first
+    // output-feature component [O/8], which appears at position 3 of the
+    // transpose.
+    int64_t preceding_size = 1;
+    for (int64_t i = transpose_dimensions.at(3) - 1; i >= 0; --i) {
+      preceding_size *= reshape->shape().dimensions(i);
+    }
+
+    // Skip dimensions in front until the output features dimension is found.
+    int64_t accumulated_size = 1;
+    for (int64_t size : weights->shape().dimensions()) {
+      if (accumulated_size < preceding_size) {
+        accumulated_size *= size;
+        ++feature_dim;
+      } else {
+        break;
+      }
+    }
+    // Sanity check; if this condition doesn't hold, something is off.
+    if (accumulated_size != preceding_size) {
+      VLOG(2) << "Something is really wrong here, I give up";
+      return std::nullopt;
+    }
+    VLOG(2) << "Computed output feature dimension: " << feature_dim;
+  }
+
+  VLOG(2) << "Computing NumTrailingZeroOutputFeatures of " << conv->ToString()
+          << "\nwith weights " << weights->ToString();
+  if (Match(weights, m::Pad(m::Op(), m::ConstantEffectiveScalar(0)))) {
+    const PaddingConfig::PaddingConfigDimension& padding_config =
+        weights->padding_config().dimensions(feature_dim);
+    // The last N output feature weights are all 0.
+    VLOG(2) << "Success: Weights is a pad; padding on output feature dim is "
+            << padding_config.edge_padding_high();
+    return padding_config.edge_padding_high();
+  } else if (const HloInstruction * pad; Match(
+                 weights, m::Reshape(m::Pad(&pad, m::Op(),
+                                            m::ConstantEffectiveScalar(0))))) {
+    // Check that the reshape merely adds a VECT_C to the kernel input features.
+    // That is, we reshape from [I,O,H,W] (in some order) to [I/k,k,O,H,W] (in
+    // the same order) for some constant k (probably 32).  Then check how much
+    // the pad adds to the O dimension.
+    std::optional<int64_t> vect_c_dim = FindKernelVectCDim(conv);
+    if (!vect_c_dim.has_value()) {
+      VLOG(2) << "fail: Can't find vect_c dimension in conv.";
+      return std::nullopt;
+    }
+    if (*vect_c_dim != dnums.kernel_input_feature_dimension() + 1) {
+      VLOG(2) << "fail: vect_c dim is in the wrong place; should be right "
+                 "after kernel input feature dims in conv.";
+      return std::nullopt;
+    }
+    absl::InlinedVector<int64_t, 5> expected_pad_dim_sizes(
+        weights->shape().dimensions().begin(),
+        weights->shape().dimensions().end());
+    expected_pad_dim_sizes[dnums.kernel_input_feature_dimension()] *=
+        weights->shape().dimensions(*vect_c_dim);
+    expected_pad_dim_sizes.erase(expected_pad_dim_sizes.begin() + *vect_c_dim);
+    if (pad->shape().dimensions() != expected_pad_dim_sizes) {
+      VLOG(2) << "fail: Reshape doesn't simply merge vect_c dimension into "
+                 "input features dim "
+              << weights->ToString() << " but expected dims "
+              << absl::StrJoin(expected_pad_dim_sizes, ",");
+      return std::nullopt;
+    }
+
+    // If the filter dnums are e.g. [I,O,H,W], then after the reshape they are
+    // [I/k,k,O,H,W] and the new index of O is one greater than before the
+    // reshape (which we know only adds the I/k and k dims, which we also know
+    // are contiguous).  OTOH if the O comes before the I in the original, then
+    // the index of O doesn't change after the reshape.
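+    // For example, [I,O,H,W] reshaped to [I/32,32,O,H,W] moves O from index 1
+    // to index 2, while [O,I,H,W] reshaped to [O,I/32,32,H,W] leaves O at
+    // index 0.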
+    int64_t feature_dim_before_reshape = feature_dim;
+    if (dnums.kernel_output_feature_dimension() >
+        dnums.kernel_input_feature_dimension()) {
+      feature_dim_before_reshape--;
+    }
+    const PaddingConfig::PaddingConfigDimension& padding_config =
+        pad->padding_config().dimensions(feature_dim_before_reshape);
+
+    // The last N output feature weights are all 0.
+    VLOG(2) << "Success: Weights is a reshape of a pad; padding on output "
+               "feature dim is "
+            << padding_config.edge_padding_high();
+    return padding_config.edge_padding_high();
+  } else if (Match(weights, m::Constant())) {
+    // Iterate backwards over `weights` to find the index of the first nonzero
+    // value.
+    //
+    // TODO(jlebar): This may be slow, because it iterates over potentially the
+    // whole constant and does a multi_index -> linear_index conversion for each
+    // element. If necessary we could rewrite this by using linear indices, but
+    // we'd have to be careful of the fact that literals can have arbitrary
+    // layouts, so you can't just iterate over the literal's bytes.
+    const Literal& lit = weights->literal();
+    const auto& dims = weights->shape().dimensions();
+    absl::InlinedVector<int64_t, 5> multi_index;
+    for (int64_t dim : dims) {
+      multi_index.push_back(dim - 1);
+    }
+    // This iterates backwards through the literal with feature_dim as the
+    // most-major (slowest-varying) dimension, looking for the last non-zero
+    // feature.
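+    // For example (hypothetical dims {2,2,3} with feature_dim == 2), starting
+    // from {1,1,2} the visit order is {1,1,2}, {0,1,2}, {1,0,2}, {0,0,2},
+    // {1,1,1}, and so on.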
+    auto decrement_multi_index = [&] {
+      for (int i = 0; i < multi_index.size(); ++i) {
+        if (i != feature_dim) {
+          int64_t& idx = multi_index[i];
+          --idx;
+          if (idx == -1) {
+            idx = dims[i] - 1;
+          } else {
+            return true;
+          }
+        }
+      }
+      int64_t& idx = multi_index[feature_dim];
+      --idx;
+      return idx != -1;
+    };
+    do {
+      if (!lit.IsZero(multi_index)) {
+        break;
+      }
+    } while (decrement_multi_index());
+
+    // The iteration stops either at a non-zero element or when the feature
+    // index wraps to -1 (the constant is all zeros).  The first trailing-zero
+    // feature is the next index, i.e. multi_index[feature_dim] + 1 (which is 0
+    // in the all-zeros case).
+    int64_t first_trailing_zero_feature = multi_index[feature_dim] + 1;
+
+    if (first_trailing_zero_feature == 0) {
+      VLOG(2) << "Weights constant is entirely zero.";
+    } else {
+      VLOG(2) << "First nonzero index in weights constant is "
+              << absl::StrJoin(multi_index, ",");
+    }
+    int64_t ret =
+        std::max<int64_t>(0, weights->shape().dimensions(feature_dim) -
+                                 first_trailing_zero_feature);
+    VLOG(2) << "Success: weights is a constant; num zero trailing output "
+               "features is "
+            << ret;
+    return ret;
+  }
+  return std::nullopt;
+}
+
+absl::StatusOr<bool> TrySimplifyPadding(HloInstruction* instr) {
+  // Match one of the following patterns.
+  //   conv -> slice -> pad
+  //   conv -> reshape -> slice -> pad
+  //   conv -> transpose -> reshape -> slice -> pad
+  //
+  // where `pad` (the root of the pattern) is `instr`.
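+  //
+  // A concrete instance of the first pattern (from the PaddedWeights test in
+  // cudnn_simplify_padding_test.cc): slice = s8[10,10,10,6] slice(conv_result)
+  // followed by pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5.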
+  HloInstruction* conv;
+  HloInstruction* transpose = nullptr;  // optional
+  HloInstruction* reshape = nullptr;    // optional
+  HloInstruction* slice;
+  HloInstruction* pad;
+  auto conv_matcher = m::GetTupleElement(
+      m::CustomCall(&conv).WithPredicate([](const HloInstruction* instr) {
+        return instr->custom_call_target() == kCudnnConvForwardCallTarget ||
+               instr->custom_call_target() ==
+                   kCudnnConvBiasActivationForwardCallTarget;
+      }),
+      0);
+  auto pad_matcher = m::Pad(m::Op(), m::ConstantEffectiveScalar(0));
+  if (!MatchAndLogIfFailed(instr, "conv-slice-pad",
+                           m::Pad(&pad, m::Slice(&slice, conv_matcher),
+                                  m::ConstantEffectiveScalar(0)),
+                           VLOG_IS_ON(3), pad_matcher) &&
+      !MatchAndLogIfFailed(
+          instr, "conv-reshape-slice-pad",
+          m::Pad(&pad, m::Slice(&slice, m::Reshape(&reshape, conv_matcher)),
+                 m::ConstantEffectiveScalar(0)),
+          VLOG_IS_ON(3), pad_matcher) &&
+      !MatchAndLogIfFailed(
+          instr, "conv-transpose-reshape-slice-pad",
+          m::Pad(&pad,
+                 m::Slice(&slice,
+                          m::Reshape(&reshape,
+                                     m::Transpose(&transpose, conv_matcher))),
+                 m::ConstantEffectiveScalar(0)),
+          VLOG_IS_ON(3), pad_matcher)) {
+    return false;
+  }
+
+  VLOG(2) << "Found pattern to attempt to simplify:\n"
+          << "conv: " << conv->ToString()  //
+          << "\ntranspose: "
+          << (transpose != nullptr ? transpose->ToString() : "(null)")
+          << "\nreshape: "
+          << (reshape != nullptr ? reshape->ToString() : "(null)")
+          << "\nslice: " << slice->ToString()  //
+          << "\npad: " << pad->ToString();
+
+  // Now check that we can merge the slice into the pad, because the slice is
+  // slicing off elements that we know are 0 and the pad is just adding those 0s
+  // back.
+  //
+  // First, we have to check whether any of the output features at the end of
+  // the conv are known to be 0.
+  std::optional<int64_t> num_known_zero_output_features =
+      NumTrailingZeroOutputFeatures(conv);
+  if (!num_known_zero_output_features.has_value() ||
+      *num_known_zero_output_features == 0) {
+    VLOG(2) << "fail: Didn't find any known-zero output features";
+    return false;
+  }
+
+  // We now know that the last *num_known_zero_output_features output features
+  // of the conv are zero.  Check whether the optional-reshape +
+  // optional-transpose + slice + pad combination only removes and re-adds
+  // these known-zero features.  If so, we can merge the slice into the pad.
+  const auto& dnums = conv->convolution_dimension_numbers();
+  int64_t output_feature_dim;
+  if (reshape == nullptr) {
+    CHECK_EQ(transpose, nullptr);
+    output_feature_dim = dnums.output_feature_dimension();
+  } else {
+    std::optional<int64_t> vect_c_dim_before_transpose =
+        FindOutputVectCDim(conv);
+    if (!vect_c_dim_before_transpose.has_value()) {
+      VLOG(2) << "Couldn't find vect_c output dim in conv.";
+      return false;
+    }
+
+    // If there's no transpose, check that the vect_c dim is immediately after
+    // the feature dim.  OTOH if there is a transpose, check that the transpose
+    // moves the vect_c dim immediately after the feature dim.
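+    // For example, with conv output dims [N,C/32,H,W,32] (feature dim 1,
+    // vect_c dim 4) and transpose dimensions {0,1,4,2,3}, the feature dim maps
+    // to 1 and the vect_c dim to 2, which satisfies the check below.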
+    int64_t feature_dim_after_transpose;
+    int64_t vect_c_dim_after_transpose;
+    if (transpose == nullptr) {
+      feature_dim_after_transpose = dnums.output_feature_dimension();
+      vect_c_dim_after_transpose = *vect_c_dim_before_transpose;
+    } else {
+      const auto& transpose_dims = transpose->dimensions();
+      feature_dim_after_transpose = std::distance(
+          transpose->dimensions().begin(),
+          absl::c_find(transpose_dims, dnums.output_feature_dimension()));
+      vect_c_dim_after_transpose = std::distance(
+          transpose->dimensions().begin(),
+          absl::c_find(transpose_dims, *vect_c_dim_before_transpose));
+    }
+    if (vect_c_dim_after_transpose != feature_dim_after_transpose + 1) {
+      VLOG(2) << "fail: after transpose (if present), vect_c dim must appear "
+                 "immediately after output feature dim: Computed "
+                 "vect_d_dim_after_transpose to be "
+              << vect_c_dim_after_transpose;
+      return false;
+    }
+
+    // Now check that the reshape merges the feature + vect_c dims and
+    // doesn't do anything else.
+    absl::InlinedVector<int64_t, 5> expected_reshape_dim_sizes(
+        reshape->operand(0)->shape().dimensions().begin(),
+        reshape->operand(0)->shape().dimensions().end());
+    expected_reshape_dim_sizes[feature_dim_after_transpose] *=
+        expected_reshape_dim_sizes[vect_c_dim_after_transpose];
+    expected_reshape_dim_sizes.erase(expected_reshape_dim_sizes.begin() +
+                                     vect_c_dim_after_transpose);
+    if (reshape->shape().dimensions() != expected_reshape_dim_sizes) {
+      VLOG(2) << "fail: Reshape doesn't merge vect_c with feature dimension.";
+      return false;
+    }
+
+    output_feature_dim = feature_dim_after_transpose;
+  }
+
+  // Check that `slice` slices only the output feature dimension.
+  if (!absl::c_all_of(slice->slice_starts(), [](auto v) { return v == 0; }) ||
+      !absl::c_all_of(slice->slice_strides(), [](auto v) { return v == 1; })) {
+    VLOG(2) << "fail: Slice doesn't start at the front or has stride != 1.";
+    return false;
+  }
+
+  // We're only allowed to slice the feature dim.
+  for (int64_t dim = 0; dim < slice->slice_limits().size(); dim++) {
+    if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 ||
+        (dim != output_feature_dim &&
+         slice->slice_limits(dim) !=
+             slice->operand(0)->shape().dimensions(dim))) {
+      VLOG(2) << "fail: Slice removes something other than the features dim.";
+      return false;
+    }
+  }
+  int64_t num_sliced_from_feature_dim =
+      slice->operand(0)->shape().dimensions(output_feature_dim) -
+      slice->slice_limits(output_feature_dim);
+
+  // If we slice off more than the known-zero output features, then we need to
+  // keep the slice -- it's "doing something".
+  if (num_sliced_from_feature_dim > *num_known_zero_output_features) {
+    VLOG(2) << "fail: Slice removes " << num_sliced_from_feature_dim
+            << " features from the conv, but only "
+            << *num_known_zero_output_features
+            << " features in the conv are known to be zero.";
+    return false;
+  }
+
+  // Check if we can merge the slice into the pad.
+  if (pad->padding_config().dimensions(output_feature_dim).interior_padding() !=
+      0) {
+    VLOG(2)
+        << "fail: Can't merge slice into pad because pad adds interior padding "
+           "in feature dimension.";
+    return false;
+  }
+
+  // Okay!  If we got here, it's legal to fold the slice into the pad.  We pad
+  // less, because we know that the sliced-off elements are all 0.  Ideally, the
+  // pad becomes a nop and gets eliminated by algsimp later.
+  VLOG(1) << "Eliminating " << num_sliced_from_feature_dim
+          << " elements of padding from conv " << conv->name();
+  PaddingConfig new_padding_config = pad->padding_config();
+  PaddingConfig::PaddingConfigDimension* new_pad_feature_dim =
+      new_padding_config.mutable_dimensions(output_feature_dim);
+  // This is safe even if the new edge_padding_high is negative -- negative
+  // padding is allowed.
+  new_pad_feature_dim->set_edge_padding_high(
+      new_pad_feature_dim->edge_padding_high() - num_sliced_from_feature_dim);
+  TF_ASSIGN_OR_RETURN(HloInstruction * new_pad,
+                      MakePadHlo(slice->mutable_operand(0),
+                                 pad->mutable_operand(1), new_padding_config));
+  TF_RETURN_IF_ERROR(pad->parent()->ReplaceInstruction(pad, new_pad));
+  return true;
+}
+
+}  // anonymous namespace
+
+absl::StatusOr<bool> CudnnSimplifyPadding::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+      TF_ASSIGN_OR_RETURN(bool c, TrySimplifyPadding(instr));
+      changed |= c;
+    }
+  }
+  return changed;
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.h b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.h
new file mode 100644
index 0000000..67580b4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.h
@@ -0,0 +1,67 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_SIMPLIFY_PADDING_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_SIMPLIFY_PADDING_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// Simplifies or eliminates padding introduced by CudnnPadForConvolutions and
+// CudnnVectorizeConvolutions.
+//
+// CudnnVectorizeConvolutions will generate code that does the following.
+//  - pad input and output features to a multiple of 32 (or 4),
+//  - reshape input from [N,C,H,W] to [N,C/32,H,W,32] and reshape kernel from
+//    [I,O,H,W] to [I/32,32,O,H,W],
+//  - run the conv,
+//  - reshape output from [N,C/32,H,W,32] to [N,C,H,W], and finally
+//  - slice off the padding on the C channel.
+//
+// But if this is followed by another convolution (very common), then the slice
+// is immediately followed by another pad. This may be redundant; we know that
+// the trailing channels sliced off from the first conv are 0.
+//
+// Ideally we can eliminate the whole reshape+slice+pad+reshape sequence between
+// the two convolutions.
+//
+// Specifically, this pass tries to merge the slice at the end of the sequence
+// above into the pad from the next convolution (when we can prove that the
+// sliced-off elements are all 0). We then rely on algsimp to remove the pad if
+// it's a nop and then to merge and eliminate the remaining reshapes.
+//
+// This pass should run after CudnnVectorizeConvolutions and there should be no
+// simplification passes in between that modify the reshape-transpose-reshape
+// introduced by int8x32 convolution filter reordering.
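+//
+// Illustrative pipeline ordering (a sketch based on the end-to-end test in
+// cudnn_simplify_padding_test.cc; `compute_capability` and `cudnn_version`
+// are placeholders):
+//
+//   HloPassPipeline pipeline("cudnn-conv-padding");
+//   pipeline.AddPass<CudnnPadForConvolutions>(compute_capability);
+//   pipeline.AddPass<CudnnVectorizeConvolutions>(compute_capability,
+//                                                cudnn_version);
+//   pipeline.AddPass<CudnnSimplifyPadding>();
+//   // ...followed by AlgebraicSimplifier (and ReshapeMover) to clean up the
+//   // now-redundant pads and reshapes.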
+class CudnnSimplifyPadding : public HloModulePass {
+ public:
+  CudnnSimplifyPadding() = default;
+
+  absl::string_view name() const override { return "cudnn_simplify_padding"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_SIMPLIFY_PADDING_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc
new file mode 100644
index 0000000..e924cca
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc
@@ -0,0 +1,771 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_simplify_padding.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/functional/function_ref.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "xla/literal.h"
+#include "xla/service/algebraic_simplifier.h"
+#include "xla/service/call_inliner.h"
+#include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h"
+#include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h"
+#include "xla/service/hlo_pass_fix.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/service/reshape_mover.h"
+#include "xla/service/tuple_simplifier.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class CudnnSimplifyPaddingTest : public HloTestBase {
+ protected:
+  // Runs the whole relevant pass pipeline starting at CudnnPadForConvolutions.
+  // This lets us test that we're matching the patterns that actually get
+  // generated by padding+vectorization.
+  absl::StatusOr<bool> RunEndToEnd(std::pair<int, int> compute_capability,
+                                   HloModule* module) {
+    se::CudaComputeCapability cc{compute_capability.first,
+                                 compute_capability.second};
+
+    TF_RETURN_IF_ERROR(
+        RunHloPass(CudnnPadForConvolutions(cc), module).status());
+
+    TF_RETURN_IF_ERROR(
+        RunHloPass(CudnnVectorizeConvolutions(
+                       cc, /*cudnn_version=*/se::dnn::VersionInfo{8, 3, 0}),
+                   module)
+            .status());
+    VLOG(1) << "after vectorizing convs:\n" << module->ToString();
+
+    TF_RETURN_IF_ERROR(RunHloPass(CallInliner(), module).status());
+    VLOG(1) << "after inliner:\n" << module->ToString();
+
+    TF_RETURN_IF_ERROR(RunHloPass(TupleSimplifier(), module).status());
+    VLOG(1) << "after tuple simplifier:\n" << module->ToString();
+
+    TF_ASSIGN_OR_RETURN(bool changed,
+                        RunHloPass(CudnnSimplifyPadding(), module));
+    VLOG(1) << "after simplify_padding:\n" << module->ToString();
+
+    {
+      // reshape-mover expects to be run alongside algsimp.
+      HloPassFix<HloPassPipeline> pipeline("reshape-mover and algsimp");
+      pipeline.AddPass<ReshapeMover>();
+      pipeline.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions());
+      TF_RETURN_IF_ERROR(RunHloPass(pipeline, module).status());
+    }
+    VLOG(1) << "after reshape mover + algsimp:\n" << module->ToString();
+
+    return changed;
+  }
+
+  absl::StatusOr<bool> RunJustThisPass(HloModule* module) {
+    TF_ASSIGN_OR_RETURN(bool changed,
+                        RunHloPass(CudnnSimplifyPadding(), module));
+    VLOG(1) << "after simplify_padding:\n" << module->ToString();
+
+    // I know the name says "just this pass", but you really want algsimp too,
+    // otherwise the resulting patterns are ugly/hard to match.
+    TF_RETURN_IF_ERROR(RunHloPass(HloPassFix<AlgebraicSimplifier>(
+                                      AlgebraicSimplifierOptions()),
+                                  module)
+                           .status());
+    return changed;
+  }
+};
+
+void ExpectOnlyPadsOneDim(int64_t dim, int64_t padding_high,
+                          const PaddingConfig& p) {
+  SCOPED_TRACE(p.DebugString());
+  for (int i = 0; i < p.dimensions_size(); ++i) {
+    SCOPED_TRACE(absl::StrCat("dimension ", i));
+    EXPECT_EQ(p.dimensions(i).edge_padding_low(), 0);
+    if (i == dim) {
+      EXPECT_EQ(p.dimensions(i).edge_padding_high(), padding_high);
+    } else {
+      EXPECT_EQ(p.dimensions(i).edge_padding_high(), 0);
+    }
+  }
+}
+
+template <typename NativeT>
+void SetConstantValue(
+    HloInstruction* instr,
+    absl::FunctionRef<NativeT(absl::Span<const int64_t>, NativeT)> value_fn) {
+  Literal new_literal = instr->literal().Clone();
+  new_literal.MutableEachCell<int8_t>(value_fn);
+  TF_EXPECT_OK(instr->parent()->ReplaceWithNewInstruction(
+      instr, HloInstruction::CreateConstant(std::move(new_literal))));
+}
+
+TEST_F(CudnnSimplifyPaddingTest, EndToEnd) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    conv1 = (s8[10,20,30,190], u8[0]) custom-call(
+        s8[10,20,30,63] parameter(0), s8[3,5,63,190] parameter(1),
+        f32[10] parameter(2), s8[10,20,30,190] parameter(3)),
+      window={size=3x5}, dim_labels=b01f_01io->b01f,
+      custom_call_target="__cudnn$convBiasActivationForward"
+    conv1_result = get-tuple-element(conv1), index=0
+    ROOT conv2 = (s8[10,20,30,29], u8[0]) custom-call(
+        conv1_result, s8[3,5,190,29] parameter(4),
+        f32[10] parameter(5), s8[10,20,30,29] parameter(6)),
+      window={size=3x5}, dim_labels=b01f_01io->b01f,
+      custom_call_target="__cudnn$convBiasActivationForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  // conv2 should be fed directly from conv1, without any intervening
+  // reshapes/pads.
+  EXPECT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Slice(m::Reshape(m::GetTupleElement(m::CustomCall(
+              {"__cudnn$convBiasActivationForward"},
+              m::GetTupleElement(
+                  m::CustomCall({"__cudnn$convBiasActivationForward"}), 0),
+              m::Op(), m::Op(), m::Op())))),
+          m::Op())));
+}
+
+TEST_F(CudnnSimplifyPaddingTest, EndToEndNCHW) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    conv1 = (s8[1,64,480,400], u8[0]) custom-call(
+        s8[1,112,480,400] parameter(0), s8[3,3,112,64] parameter(1),
+        f32[64] parameter(2)),
+      window={size=3x3}, dim_labels=bf01_01io->bf01,
+      custom_call_target="__cudnn$convBiasActivationForward"
+    conv1_result = get-tuple-element(conv1), index=0
+    convert = f32[1,64,480,400] convert(conv1_result)
+    constant = f32[] constant(0.349002093)
+    broadcast = f32[1,64,480,400] broadcast(constant)
+    ROOT multiply = f32[1,64,480,400] multiply(convert, broadcast)
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
+  // The SimplifyPadding pass itself does not do anything.
+  EXPECT_FALSE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  // The reshape introduced by CudnnVectorizeConvolutions should have been moved
+  // to the root.
+  EXPECT_THAT(root, GmockMatch(m::Reshape(m::Multiply())));
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PaddedWeights) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+      conv = (s8[10,10,10,10], u8[0]) custom-call(
+          s8[10,10,10,10] parameter(1),
+          weights
+        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+  const HloInstruction* pad = nullptr;
+  ASSERT_THAT(root,
+              GmockMatch(m::Pad(&pad, m::GetTupleElement(m::CustomCall(), 0),
+                                m::ConstantScalar(0))));
+
+  ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
+}
+
+// This is similar to PaddedWeights, except that only 3 elements of the weights
+// are padded to 0 while we slice off 4 elements from the output features. As a
+// result, not all of the sliced elements are 0, and we can't merge the slice
+// into the pad that follows.
+TEST_F(CudnnSimplifyPaddingTest, PaddedWeightsNotPaddedEnough) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_3
+      conv = (s8[10,10,10,10], u8[0]) custom-call(
+          s8[10,10,10,10] parameter(1),
+          weights
+        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PaddedAndReshapedWeightsNCHW) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
+      weights = s8[2,32,64,3,3] reshape(weights_p)
+      conv = (s8[10,2,32,10,10], u8[0]) custom-call(
+          s8[10,2,32,10,10] parameter(1),
+          weights
+        ), window={size=3x3}, dim_labels=bf?01_i?o01->bf?01,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_result)), slice={[0:10], [0:60], [0:10], [0:10]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_5x0_0x0_0
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+  const HloInstruction* pad = nullptr;
+  ASSERT_THAT(
+      root, GmockMatch(
+                m::Pad(&pad, m::Reshape(m::GetTupleElement(m::CustomCall(), 0)),
+                       m::ConstantScalar(0))));
+
+  ExpectOnlyPadsOneDim(/*dim=*/1, /*padding_high=*/1, pad->padding_config());
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PaddedAndReshapedWeightsNHWC) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights_p = pad(s8[3,3,64,60] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+      weights = s8[3,3,2,32,64] reshape(weights_p)
+      conv = (s8[10,10,10,2,32], u8[0]) custom-call(
+          s8[10,10,10,2,32] parameter(1),
+          weights
+        ), window={size=3x3}, dim_labels=b01f?_01i?o->b01f?,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      slice = s8[10,10,10,60] slice(s8[10,10,10,64] reshape(conv_result)), slice={[0:10], [0:10], [0:10], [0:60]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+  const HloInstruction* pad = nullptr;
+  ASSERT_THAT(
+      root, GmockMatch(
+                m::Pad(&pad, m::Reshape(m::GetTupleElement(m::CustomCall(), 0)),
+                       m::ConstantScalar(0))));
+
+  ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PaddedTransposedAndReshapedOutput) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
+      weights = s8[2,32,64,3,3] reshape(weights_p)
+      conv = (s8[10,2,10,10,32], u8[0]) custom-call(
+          s8[10,2,10,10,32] parameter(1),
+          weights
+        ), window={size=3x3}, dim_labels=bf01?_i?o01->bf01?,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      conv_transposed = s8[10,2,32,10,10] transpose(conv_result), dimensions={0,1,4,2,3}
+      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+  const HloInstruction* pad = nullptr;
+  ASSERT_THAT(
+      root,
+      GmockMatch(m::Pad(
+          &pad,
+          m::Reshape(m::Transpose(m::GetTupleElement(m::CustomCall(), 0))),
+          m::ConstantScalar(0))));
+
+  ExpectOnlyPadsOneDim(/*dim=*/1, /*padding_high=*/2, pad->padding_config());
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PaddedConstantWeight) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      conv = (s8[10,10,10,10], u8[0]) custom-call(
+          s8[10,10,10,10] parameter(0),
+          s8[3,3,10,10] constant({...})
+        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+    }
+  )")
+                    .value();
+
+  // Set the constant's value.  (The HLO text above sets it to all 0s.)
+  {
+    HloInstruction* weights = nullptr;
+    ASSERT_THAT(module->entry_computation()->root_instruction(),
+                GmockMatch(m::Pad(m::Slice(m::GetTupleElement(m::CustomCall(
+                                      m::Op(), m::Constant(&weights)))),
+                                  m::Op())));
+    SetConstantValue<int8_t>(
+        weights, [](absl::Span<const int64_t> dims, int8_t old_val) -> int8_t {
+          if (dims[3] < 6) return 1;
+          return 0;
+        });
+  }
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+  const HloInstruction* pad = nullptr;
+  ASSERT_THAT(root,
+              GmockMatch(m::Pad(&pad, m::GetTupleElement(m::CustomCall(), 0),
+                                m::ConstantScalar(0))));
+
+  ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PaddedConstantWeightIsNotLargeEnough) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      conv = (s8[10,10,10,10], u8[0]) custom-call(
+          s8[10,10,10,10] parameter(0),
+          s8[3,3,10,10] constant({...})
+        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+    }
+  )")
+                    .value();
+
+  // Set the constant's value.  (The HLO text above sets it to all 0s.)
+  {
+    HloInstruction* weights = nullptr;
+    ASSERT_THAT(module->entry_computation()->root_instruction(),
+                GmockMatch(m::Pad(m::Slice(m::GetTupleElement(m::CustomCall(
+                                      m::Op(), m::Constant(&weights)))),
+                                  m::Op())));
+    SetConstantValue<int8_t>(
+        weights, [](absl::Span<const int64_t> dims, int8_t old_val) -> int8_t {
+          // Only the weights for output features [0, 5) are 0; weights for
+          // features 5 and above are nonzero, so the sliced-off output
+          // features are not guaranteed to be 0.
+          if (dims[3] < 5 /*|| (dims[3] == 5 && dims[2] > 1)*/) return 0;
+          return 1;
+        });
+  }
+
+  // Some of the values sliced off are not 0, so we can't merge the slice into
+  // the pad.
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, ReshapeDoesntMergeVectCDim) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
+      weights = s8[2,64,3,3,32] reshape(weights_p)
+      conv = (s8[10,2,10,10,32], u8[0]) custom-call(
+          s8[10,2,10,10,32] parameter(1),
+          weights_p
+        ), window={size=3x3}, dim_labels=bf01?_io01?->bf01?,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_result)), slice={[0:10], [0:60], [0:10], [0:10]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, TwoVectCDimsInOutput) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
+      weights = s8[2,64,3,3,32] reshape(weights_p)
+      conv = (s8[10,2,10,10,4,8], u8[0]) custom-call(
+          s8[10,2,10,10,32] parameter(1),
+          weights
+        ), window={size=3x3}, dim_labels=bf01?_io01?->bf01??,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      conv_transposed = s8[10,2,4,8,10,10] transpose(conv_result), dimensions={0,1,4,5,2,3}
+      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, TwoVectCDimsInKernel) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
+      weights = s8[2,64,3,3,4,8] reshape(weights_p)
+      conv = (s8[10,2,10,10,32], u8[0]) custom-call(
+          s8[10,2,10,10,32] parameter(1),
+          weights
+        ), window={size=3x3}, dim_labels=bf01?_io01??->bf01?,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      conv_transposed = s8[10,2,32,10,10] transpose(conv_result), dimensions={0,1,4,2,3}
+      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, SliceDoesntStartAtBeginning) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+      conv = (s8[10,10,10,10], u8[0]) custom-call(
+          s8[10,10,10,10] parameter(1),
+          weights
+        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      slice = s8[10,9,10,6] slice(conv_result), slice={[0:10], [1:10], [0:10], [0:6]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, SliceDoesntStartAtBeginningOfFeatureDim) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+      conv = (s8[10,10,10,10], u8[0]) custom-call(
+          s8[10,10,10,10] parameter(1),
+          weights
+        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      slice = s8[10,10,10,5] slice(conv_result), slice={[0:10], [0:10], [0:10], [1:6]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, SliceHasStride) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+      conv = (s8[10,10,10,10], u8[0]) custom-call(
+          s8[10,10,10,10] parameter(1),
+          weights
+        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      slice = s8[10,10,10,3] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6:2]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, PadAddsInteriorPadding) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+      conv = (s8[10,10,10,10], u8[0]) custom-call(
+          s8[10,10,10,10] parameter(1),
+          weights
+        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5_1
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, SliceMoreElementsThanPad) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
+      conv = (s8[10,10,10,10], u8[0]) custom-call(
+          s8[10,10,10,10] parameter(1),
+          weights
+        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
+        custom_call_target="__cudnn$convForward"
+      conv_result = get-tuple-element(conv), index=0
+      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
+      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_2
+    }
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+  const HloInstruction* slice = nullptr;
+  // The pass creates a pad with negative padding; this is simplified by algsimp
+  // into a slice.
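+  // Here, slicing output features [0:6] and then padding 2 elements high is
+  // equivalent to slicing [0:8] directly from the conv result, which is what
+  // the matcher below checks.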
+  ASSERT_THAT(root, GmockMatch(m::Slice(
+                        &slice, m::GetTupleElement(m::CustomCall(), 0))));
+  for (int64_t i = 0; i < slice->shape().dimensions_size(); ++i) {
+    SCOPED_TRACE(i);
+    EXPECT_EQ(slice->slice_starts(i), 0);
+    EXPECT_EQ(slice->slice_strides(i), 1);
+    if (i != 3) {
+      EXPECT_EQ(slice->slice_limits(i), 10);
+    } else {
+      EXPECT_EQ(slice->slice_limits(i), 8);
+    }
+  }
+}
+
+TEST_F(CudnnSimplifyPaddingTest, NoChangeOnNonTrivialConstants) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule jit_outer
+
+ENTRY main.26 {
+  reshape.2 = f32[1,3,3,12]{3,2,1,0} parameter(0)
+  constant.1 = f32[3,3,1,12]{3,2,1,0} constant({ {
+    { /*i1=0*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=1*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+  }, {
+    { /*i1=0*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=1*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+  }, {
+    { /*i1=0*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=1*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=2*/ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } } } })
+  cudnn-conv = (f32[1,5,5,12]{3,2,1,0}, u8[0]{0}) custom-call(reshape.2, constant.1), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, feature_group_count=12, custom_call_target="__cudnn$convForward"
+  get-tuple-element = f32[1,5,5,12]{3,2,1,0} get-tuple-element(cudnn-conv), index=0
+  slice.2 = f32[1,5,1,12]{3,2,1,0} slice(get-tuple-element), slice={[0:1], [0:5], [0:1], [0:12]}
+  constant.0 = f32[] constant(0)
+  ROOT pad.1 = f32[1,5,3,12]{3,2,1,0} pad(slice.2, constant.0), padding=0_0x0_0x2_0x0_0
+}
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, NoChangeOnComplexSlices) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule jit_outer
+
+ENTRY main.26 {
+  reshape.2 = f32[1,3,3,12]{3,2,1,0} parameter(0)
+  constant.1 = f32[3,3,1,12]{3,2,1,0} constant({ {
+    { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+  }, {
+    { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+  }, {
+    { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } } } })
+  cudnn-conv = (f32[1,5,5,12]{3,2,1,0}, u8[0]{0}) custom-call(reshape.2, constant.1), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, feature_group_count=12, custom_call_target="__cudnn$convForward"
+  get-tuple-element = f32[1,5,5,12]{3,2,1,0} get-tuple-element(cudnn-conv), index=0
+  slice.2 = f32[1,5,5,4]{3,2,1,0} slice(get-tuple-element), slice={[0:1], [0:5], [0:5], [2:6]}
+  constant.0 = f32[] constant(0)
+  ROOT pad.1 = f32[1,5,5,12]{3,2,1,0} pad(slice.2, constant.0), padding=0_0x0_0x0_0x0_8
+}
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, ScanOrderFeatureDimLast) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule jit_outer
+
+ENTRY main.26 {
+  reshape.2 = f32[1,3,3,12]{3,2,1,0} parameter(0)
+  constant.1 = f32[3,3,1,12]{3,2,1,0} constant({ {
+    { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+  }, {
+    { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 } },
+    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }
+  }, {
+    { /*i1=0*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=1*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } },
+    { /*i1=2*/ { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } } } })
+  cudnn-conv = (f32[1,5,5,12]{3,2,1,0}, u8[0]{0}) custom-call(reshape.2, constant.1), window={size=3x3 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, feature_group_count=12, custom_call_target="__cudnn$convForward"
+  get-tuple-element = f32[1,5,5,12]{3,2,1,0} get-tuple-element(cudnn-conv), index=0
+  slice.2 = f32[1,5,5,6]{3,2,1,0} slice(get-tuple-element), slice={[0:1], [0:5], [0:5], [0:6]}
+  constant.0 = f32[] constant(0)
+  ROOT pad.1 = f32[1,5,5,12]{3,2,1,0} pad(slice.2, constant.0), padding=0_0x0_0x0_0x0_6
+}
+  )")
+                    .value();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, Int8FilterReorderedOutputFirst) {
+  // Test feature dimension calculation from reordering transpose (oi01)
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    conv.1 = (s8[1,63,80,80], u8[0]) custom-call(
+        s8[1,112,80,80] parameter(0), s8[63,112,3,3] parameter(1)),
+      window={size=3x3}, dim_labels=bf01_oi01->bf01,
+      custom_call_target="__cudnn$convForward"
+    gte.1 = s8[1,63,80,80] get-tuple-element(conv.1), index=0
+    const.0 = s8[] constant(0)
+    ROOT pad.1 = s8[1,64,80,80] pad(gte.1, const.0), padding=0_0x0_1x0_0x0_0
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+}
+
+TEST_F(CudnnSimplifyPaddingTest, Int8FilterReorderedOutputLast) {
+  // Test feature dimension calculation from reordering transpose (01io)
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    conv.1 = (s8[1,63,80,80], u8[0]) custom-call(
+        s8[1,112,80,80] parameter(0), s8[3,3,112,63] parameter(1)),
+      window={size=3x3}, dim_labels=bf01_01io->bf01,
+      custom_call_target="__cudnn$convForward"
+    gte.1 = s8[1,63,80,80] get-tuple-element(conv.1), index=0
+    const.0 = s8[] constant(0)
+    ROOT pad.1 = s8[1,64,80,80] pad(gte.1, const.0), padding=0_0x0_1x0_0x0_0
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+}
+
+}  // anonymous namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc
new file mode 100644
index 0000000..698b8fb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc
@@ -0,0 +1,650 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h"
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "xla/client/xla_builder.h"
+#include "xla/client/xla_computation.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/cudnn_support_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Finds convolutions that this pass may be able to transform, namely int8_t
+// cudnn forward or forward-bias-activation convolutions.
+//
+// cudnn as of v8.2 supports the following data type combinations for forward
+// and forward-bias-activation convolutions.  We have to make sure we only
+// vectorize to one of these supported configs.
+//
+//   in       out
+//   int8x1   int8x1
+//   int8x1   float
+//   int8x1   int32_t
+//
+//   int8x4   int8x4
+//   int8x4   float
+//
+//   int8x32  int8x32
+//
+// https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionForward
+//
+// For now we restrict ourselves to only the int8xN -> int8xN cases.  We could
+// allow the int8x4 -> float case in the future if desirable.
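+//
+// For example, an s8 -> s8 __cudnn$convForward custom-call is returned here,
+// while s8 -> f32 and s8 -> s32 convolutions are skipped, per the restriction
+// above to the int8xN -> int8xN cases.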
+static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
+    HloComputation* comp) {
+  std::vector<HloCustomCallInstruction*> convs;
+  for (HloInstruction* instr : comp->instructions()) {
+    if (instr->opcode() != HloOpcode::kCustomCall ||
+        (instr->custom_call_target() != kCudnnConvForwardCallTarget &&
+         instr->custom_call_target() !=
+             kCudnnConvBiasActivationForwardCallTarget) ||
+        instr->operand_count() < 2) {
+      continue;
+    }
+
+    PrimitiveType input_ty = instr->operand(0)->shape().element_type();
+    PrimitiveType output_ty = instr->shape().tuple_shapes(0).element_type();
+    if (input_ty == output_ty && (input_ty == S8 || input_ty == U8)) {
+      convs.push_back(Cast<HloCustomCallInstruction>(instr));
+    }
+  }
+  return convs;
+}
+
+// Converts an XlaBuilder into an HloComputation in the same module as
+// `sibling_computation`.
+//
+// Yes, we serialize/deserialize as a proto.  :)
+static absl::StatusOr<HloComputation*> BuilderToHloComputation(
+    XlaBuilder& b, XlaOp root, HloComputation* sibling_computation) {
+  TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root));
+  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape());
+  HloModuleConfig config(program_shape);
+  TF_ASSIGN_OR_RETURN(auto new_module,
+                      HloModule::CreateFromProto(comp.proto(), config));
+
+  HloModule* dest_module = sibling_computation->parent();
+  HloCloneContext context(dest_module);
+  return dest_module->DeepCloneComputation(new_module->entry_computation(),
+                                           &context);
+}
+
+// Reshapes `instr` so that it has an extra dimension of size `vect_size` right
+// after `dim`.
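+//
+// For example, given an op of shape s8[10, 32, 20], dim=1, vect_size=4, this
+// returns an op of shape s8[10, 8, 4, 20].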
+static XlaOp SplitAtDim(XlaOp instr, int64_t dim, int64_t vect_size) {
+  XlaBuilder& b = *instr.builder();
+  Shape shape = b.GetShape(instr).value();
+  DimensionVector new_dims(shape.dimensions().begin(),
+                           shape.dimensions().end());
+  CHECK_EQ(new_dims[dim] % vect_size, 0);
+  new_dims[dim] /= vect_size;
+  new_dims.insert(new_dims.begin() + dim + 1, vect_size);
+  return Reshape(instr, new_dims);
+}
+
+// Reshapes `shape` so that there's an extra dimension of size `vect_size` right
+// after `dim`.
+//
+// For example given shape=s8[10, 32, 20], dim=1, vect_size=4, returns
+// s8[10, 8, 4, 20].
+static Shape SplitShapeAtDim(Shape shape, int64_t dim, int64_t vect_size) {
+  DimensionVector new_dims(shape.dimensions().begin(),
+                           shape.dimensions().end());
+  CHECK_EQ(new_dims[dim] % vect_size, 0);
+  new_dims[dim] /= vect_size;
+  new_dims.insert(new_dims.begin() + dim + 1, vect_size);
+  return ShapeUtil::MakeShape(shape.element_type(), new_dims);
+}
+
+// Transposes dimension `src` to right before `dst`.
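+//
+// For example, for a rank-4 op with src=1 and dst=3, the computed permutation
+// is {0, 2, 1, 3}, so a shape [N, V, H, W] becomes [N, H, V, W], with V now
+// sitting right before W.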
+static XlaOp MoveDim(XlaOp instr, int64_t src, int64_t dst) {
+  XlaBuilder& b = *instr.builder();
+  int64_t rank = b.GetShape(instr)->dimensions_size();
+
+  DimensionVector idxs(rank);
+  absl::c_iota(idxs, 0);
+  if (src < dst) {
+    idxs.insert(idxs.begin() + dst, src);
+    idxs.erase(idxs.begin() + src);
+  } else {
+    idxs.erase(idxs.begin() + src);
+    idxs.insert(idxs.begin() + dst, src);
+  }
+  return Transpose(instr, idxs);
+}
+
+// Reshapes instr so that dimension `vect_dim` has size `vect_size`, by stealing
+// elements from `dim`.
+//
+// Requires that this is possible without merging and re-splitting the two
+// dimensions.  I.e. there should be some amount of dim that we can "split off"
+// and add to vect_dim to get it to have size vect_size.
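+//
+// For example, given instr of shape s8[10, 16, 20, 20, 4], dim=1, vect_dim=4,
+// vect_size=32: split_factor is 32/4 = 8, so we split 8 elements off of dim 1
+// and fold them into vect_dim, producing an op of shape s8[10, 2, 20, 20, 32].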
+static XlaOp RevectorizeInstr(XlaOp instr, int64_t dim, int64_t vect_dim,
+                              int64_t vect_size) {
+  XlaBuilder& b = *instr.builder();
+  Shape shape = b.GetShape(instr).value();
+  auto size = [&](int64_t d) { return shape.dimensions(d); };
+
+  CHECK_LE(size(vect_dim), vect_size);
+  CHECK_EQ(vect_size % size(vect_dim), 0);
+
+  int64_t split_factor = vect_size / size(vect_dim);
+  CHECK_EQ(size(dim) % split_factor, 0);
+
+  // Split dim into [C, split_factor].
+  instr = SplitAtDim(instr, dim, split_factor);
+
+  // SplitAtDim may have added a dimension before vect_dim.
+  if (vect_dim > dim) {
+    vect_dim++;
+  }
+
+  // Move the split_factor dimension to right before vect_dim.
+  instr = MoveDim(instr, dim + 1, vect_dim);
+
+  // Moving the split_factor dimension may have *removed* a dimension before
+  // vect_dim.
+  if (vect_dim > dim) {
+    vect_dim--;
+  }
+
+  // Collapse the split_factor dimension into vect_dim.
+  return Collapse(instr, {vect_dim, vect_dim + 1});
+}
+
+// Inverse of RevectorizeInstr.  Reshapes instr so that dimension `vect_dim` has
+// size `orig_vect_size`, moving excess elements into `dim`.
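+//
+// For example, given instr of shape s8[10, 2, 20, 20, 32], dim=1, vect_dim=4,
+// orig_vect_size=4, this returns an op of shape s8[10, 16, 20, 20, 4].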
+static XlaOp UnrevectorizeInstr(XlaOp instr, int64_t dim, int64_t vect_dim,
+                                int64_t orig_vect_size) {
+  XlaBuilder& b = *instr.builder();
+  Shape shape = b.GetShape(instr).value();
+  auto size = [&](int64_t d) { return shape.dimensions(d); };
+
+  CHECK_GE(size(vect_dim), orig_vect_size);
+  CHECK_EQ(size(vect_dim) % orig_vect_size, 0);
+
+  // Split vect_dim into [C, orig_vect_size].
+  instr = SplitAtDim(instr, vect_dim, orig_vect_size);
+
+  // SplitAtDim may have added a dimension before dim.
+  if (dim > vect_dim) {
+    dim++;
+  }
+
+  // Move the `C` dimension to right after `dim`.  Take into account that
+  // SplitAtDim may have added a dimension before dim.
+  instr = MoveDim(instr, vect_dim, dim + 1);
+
+  // MoveDim may have *removed* a dimension before dim.
+  if (dim > vect_dim) {
+    dim--;
+  }
+
+  // Collapse the `C` and `dim` dimensions.
+  return Collapse(instr, {dim, dim + 1});
+}
+
+// Adds a vectorized-feature dimension to dnums right after the current feature
+// dimension.
+//
+// ConvolutionDimensionNumbers doesn't represent the vectorized-feature
+// dimension explicitly, because the whole concept of a vectorized-feature
+// dimension is specific to cudnn.  Rather, the vectorized-feature dimension is
+// implicit; it's the first dimension that *doesn't* appear in the dnums.
+//
+// This function "makes room" in dnums for the new vectorized dimension by
+// incrementing any dimensions which appear after the feature dim.  The implicit
+// vector dim is then in this "empty" spot.
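+//
+// For example, with NCHW dnums (batch=0, feature=1, spatial={2,3}), the
+// spatial dims are bumped to {3,4}, leaving dim 2 -- right after the feature
+// dim -- as the implicit vector dim.  With NHWC dnums (batch=0, spatial={1,2},
+// feature=3), nothing shifts and the implicit vector dim is dim 4.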
+static ConvolutionDimensionNumbers VectorizeDnums(
+    ConvolutionDimensionNumbers dnums, bool reordered_filter) {
+  int64_t input_vect_dim = dnums.input_feature_dimension();
+  if (dnums.input_batch_dimension() > input_vect_dim) {
+    dnums.set_input_batch_dimension(dnums.input_batch_dimension() + 1);
+  }
+  for (int64_t& d : *dnums.mutable_input_spatial_dimensions()) {
+    if (d > input_vect_dim) {
+      ++d;
+    }
+  }
+
+  if (!reordered_filter) {
+    int64_t kernel_vect_dim = dnums.kernel_input_feature_dimension();
+    if (dnums.kernel_output_feature_dimension() > kernel_vect_dim) {
+      dnums.set_kernel_output_feature_dimension(
+          dnums.kernel_output_feature_dimension() + 1);
+    }
+    for (int64_t& d : *dnums.mutable_kernel_spatial_dimensions()) {
+      if (d > kernel_vect_dim) {
+        ++d;
+      }
+    }
+  }
+
+  int64_t output_vect_dim = dnums.output_feature_dimension();
+  if (dnums.output_batch_dimension() > output_vect_dim) {
+    dnums.set_output_batch_dimension(dnums.output_batch_dimension() + 1);
+  }
+  for (int64_t& d : *dnums.mutable_output_spatial_dimensions()) {
+    if (d > output_vect_dim) {
+      ++d;
+    }
+  }
+
+  return dnums;
+}
+
+// Reorders the convolution's filter and bias (if present) according to
+// cudnnReorderFilterAndBias.  Also marks that the filter + bias are reordered
+// in the conv's backend-config.
+absl::Status ReorderInt8NchwVect(HloCustomCallInstruction* conv,
+                                 XlaOp* operands) {
+  bool has_bias = conv->operand_count() > 2;
+  VLOG(1) << "Reordering filter" << (has_bias ? " and bias" : "")
+          << " (replacement for cudnnReorderFilterAndBias)";
+
+  auto builder = operands->builder();
+  ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
+
+  // Update convolution backend config.
+  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                      conv->backend_config<GpuBackendConfig>());
+  CudnnConvBackendConfig& config =
+      *gpu_config.mutable_cudnn_conv_backend_config();
+  config.set_reordered_int8_nchw_vect(true);
+  TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config));
+
+  // Reorder the filter.
+  TF_ASSIGN_OR_RETURN(Shape filter_shape, builder->GetShape(operands[1]));
+  TF_ASSIGN_OR_RETURN(auto reorder, CudnnInferTransposeForFilterReordering(
+                                        filter_shape, dnums));
+  XlaOp reshape = Reshape(reorder.transpose_shape, operands[1]);
+  XlaOp transpose = Transpose(reshape, reorder.permutation);
+  operands[1] = Reshape(reorder.result_shape, transpose);
+
+  // The reshape-transpose-reshape we did above makes sure the resulting filter
+  // has dimension numbers corresponding to "oihw?", so update them.
+  dnums.set_kernel_output_feature_dimension(0);
+  dnums.set_kernel_input_feature_dimension(1);
+  dnums.set_kernel_spatial_dimensions(0, 2);
+  dnums.set_kernel_spatial_dimensions(1, 3);
+  conv->set_convolution_dimension_numbers(dnums);
+
+  if (has_bias) {
+    // Reorder the bias.
+    TF_ASSIGN_OR_RETURN(Shape bias_shape, builder->GetShape(operands[2]));
+    TF_ASSIGN_OR_RETURN(reorder,
+                        CudnnInferTransposeForBiasReordering(bias_shape));
+    reshape = Reshape(reorder.transpose_shape, operands[2]);
+    transpose = Transpose(reshape, reorder.permutation);
+    operands[2] = Reshape(reorder.result_shape, transpose);
+  }
+  return absl::OkStatus();
+}
+
+// Tries to vectorize an already-vectorized convolution.
+//
+// That is, given a convolution of shape [N, C/k, H, W, k], changes it to have
+// shape [N, C/vect_size, H, W, vect_size].  Similarly changes the filter from
+// [H, W, I/k, k, O] to [H, W, I/vect_size, vect_size, O].
+//
+// (The dimensions can appear in any order; which one is N/C/etc. is determined
+// by the convolution's dnums.)
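+//
+// For example, an NHWC int8x4 conv whose input is s8[10, 20, 30, 16, 4] can be
+// re-vectorized with vect_size=32 into a conv whose input is
+// s8[10, 20, 30, 2, 32]; the new conv is wrapped in reshapes/transposes so the
+// replacement call keeps the original conv's shape.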
+static absl::StatusOr<bool> TryRevectorizeConv(
+    const se::CudaComputeCapability& compute_capability,
+    const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
+    int vect_size) {
+  const Shape& input_shape = conv->operand(0)->shape();
+  const Shape& kernel_shape = conv->operand(1)->shape();
+  const Shape& output_shape = conv->shape().tuple_shapes(0);
+  const ConvolutionDimensionNumbers* dnums =
+      &conv->convolution_dimension_numbers();
+
+  // Find the vectorized-features dim in the input/kernel/output.
+  std::optional<int64_t> input_vect_dim;
+  std::optional<int64_t> kernel_vect_dim;
+  std::optional<int64_t> output_vect_dim;
+  std::tie(input_vect_dim, kernel_vect_dim, output_vect_dim) =
+      FindVectorizedFeatureDims(*dnums, input_shape, kernel_shape,
+                                output_shape);
+
+  if (!input_vect_dim.has_value() || !kernel_vect_dim.has_value() ||
+      !output_vect_dim.has_value()) {
+    return false;
+  }
+
+  int64_t input_feat_size =
+      input_shape.dimensions(dnums->input_feature_dimension());
+  int64_t output_feat_size =
+      output_shape.dimensions(dnums->output_feature_dimension());
+  int64_t input_vect_size = input_shape.dimensions(*input_vect_dim);
+  int64_t output_vect_size = output_shape.dimensions(*output_vect_dim);
+  if (vect_size % input_vect_size != 0 || vect_size % output_vect_size != 0 ||
+      input_feat_size % (vect_size / input_vect_size) != 0 ||
+      output_feat_size % (vect_size / output_vect_size) != 0) {
+    return false;
+  }
+
+  // If this is an integer convolution check that we only vectorize when cuDNN
+  // supports the vectorized implementation.
+  if (primitive_util::IsIntegralType(input_shape.element_type())) {
+    TF_ASSIGN_OR_RETURN(bool supported_target_vectorization,
+                        CudnnSupportsOptimizedIntegerConvolution(
+                            compute_capability, *conv, vect_size));
+    if (!supported_target_vectorization) {
+      VLOG(3) << "Skipping re-vectorization of conv to vector size: "
+              << vect_size << ": " << conv->ToString();
+      return false;
+    }
+  }
+
+  VLOG(1) << "Re-vectorizing conv channels from "
+          << input_shape.dimensions(*input_vect_dim) << " to " << vect_size
+          << ": " << conv->ToString();
+
+  // We use XlaBuilder because it's a lot easier to get these tricky
+  // reshape/transposes correct using that API.
+  XlaBuilder b(absl::StrCat(conv->name(), ".revectorized"));
+  b.SetOpMetadata(conv->metadata());
+
+  XlaOp filter = Parameter(&b, 1, conv->operand(1)->shape(), "filter");
+  absl::InlinedVector<XlaOp, 4> new_operands = {
+      RevectorizeInstr(Parameter(&b, 0, conv->operand(0)->shape(), "input"),
+                       dnums->input_feature_dimension(), *input_vect_dim,
+                       vect_size),
+      RevectorizeInstr(filter, dnums->kernel_input_feature_dimension(),
+                       *kernel_vect_dim, vect_size),
+  };
+  if (conv->operand_count() > 2) {
+    // Bias, if present.  This is passed through unmodified.
+    new_operands.push_back(Parameter(&b, 2, conv->operand(2)->shape(), "bias"));
+  }
+  if (conv->operand_count() > 3) {
+    new_operands.push_back(RevectorizeInstr(
+        Parameter(&b, 3, conv->operand(3)->shape(), "side_input"),
+        dnums->input_feature_dimension(), *input_vect_dim, vect_size));
+  }
+
+  if (conv->operand_count() > 4) {
+    return InvalidArgument(
+        "Don't understand a conv with more than 4 arguments: %s",
+        conv->ToString());
+  }
+
+  // Reorder filter and bias for the int8x32 convolutions.  This requires cudnn
+  // >= 8.3.0.
+  //
+  // TODO(jlebar): Remove this guard once JAX no longer supports cudnn 8.3.
+  const auto& debug_options = conv->GetModule()->config().debug_options();
+  bool use_reordering =
+      input_shape.element_type() == xla::S8 && vect_size == 32 &&
+      debug_options.xla_gpu_enable_cudnn_int8x32_convolution_reordering() &&
+      cudnn_version >= se::dnn::VersionInfo{8, 3, 0};
+  if (use_reordering) {
+    // Reordering helper supports vector sizes of 4 and 32, so an additional
+    // reshape-transpose-reshape is not necessary in these cases.
+    int64_t kernel_vect_size = kernel_shape.dimensions(*kernel_vect_dim);
+    if (kernel_vect_size == 4 || kernel_vect_size == 32) {
+      new_operands[1] = filter;
+    }
+    TF_RETURN_IF_ERROR(ReorderInt8NchwVect(conv, new_operands.data()));
+    dnums = &conv->convolution_dimension_numbers();
+  }
+
+  // The custom-call returns a tuple (new_output_shape, u8[0]), where the second
+  // value in the tuple represents the convolution's scratch memory.
+  DimensionVector new_output_dims(output_shape.dimensions().begin(),
+                                  output_shape.dimensions().end());
+  new_output_dims[dnums->output_feature_dimension()] /=
+      (vect_size / output_vect_size);
+  new_output_dims[*output_vect_dim] = vect_size;
+  XlaOp new_conv = CustomCallWithConvDnums(
+      &b, conv->custom_call_target(), new_operands,
+      ShapeUtil::MakeTupleShape(
+          {ShapeUtil::MakeShape(output_shape.element_type(), new_output_dims),
+           ShapeUtil::MakeShape(U8, {0})}),
+      /*operand_shapes_with_layout=*/{},
+      /*opaque=*/conv->raw_backend_config_string(), /*has_side_effect=*/false,
+      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
+      /*window=*/conv->window(),
+      /*dnums=*/*dnums);
+
+  XlaOp new_conv_result = GetTupleElement(new_conv, 0);
+  XlaOp new_conv_scratch = GetTupleElement(new_conv, 1);
+
+  XlaOp new_conv_result_unrevectorized = UnrevectorizeInstr(
+      new_conv_result, dnums->output_feature_dimension(), *output_vect_dim,
+      /*orig_vect_size=*/output_shape.dimensions(*output_vect_dim));
+
+  TF_ASSIGN_OR_RETURN(
+      HloComputation * new_conv_comp,
+      BuilderToHloComputation(
+          b, Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}),
+          conv->parent()));
+
+  // Set the name on the new conv.  This is purely cosmetic, but we attempt to
+  // preserve e.g. "cudnn-conv.42" instead of "custom-call.42".
+  auto new_conv_comp_instrs = new_conv_comp->instructions();
+  auto new_conv_it =
+      absl::c_find_if(new_conv_comp_instrs, [](HloInstruction* instr) {
+        return instr->opcode() == HloOpcode::kCustomCall;
+      });
+  if (new_conv_it != new_conv_comp_instrs.end()) {
+    new_conv_comp->parent()->SetAndUniquifyInstrName(*new_conv_it,
+                                                     conv->name());
+  }
+
+  // Replace the old conv with a call to the computation we just created.
+  VLOG(1) << "Re-vectorized conv to " << new_conv_comp->ToString();
+  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
+      conv, HloInstruction::CreateCall(conv->shape(), conv->operands(),
+                                       new_conv_comp)));
+
+  return true;
+}
+
+// Tries to vectorize a convolution.
+//
+// Given a convolution of dimensions [N, C, H, W], tries to convert it to have
+// shape [N, C/vect_size, H, W, vect_size].  Similarly, given a kernel of shape
+// [H, W, I, O], tries to convert it to [H, W, I/vect_size, vect_size, O].
+//
+// This requires that C be a multiple of vect_size.  CudnnPadForConvolutions can
+// add padding to make this true.
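+//
+// For example, an NHWC conv with input s8[10, 20, 30, 40] and filter
+// s8[2, 2, 40, 44], vectorized with vect_size=4, gets its input reshaped to
+// s8[10, 20, 30, 10, 4] and its filter to s8[2, 2, 10, 4, 44] (this matches
+// the VectorizeTo4 case in the unit tests).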
+static absl::StatusOr<bool> TryVectorizeConv(
+    const se::CudaComputeCapability& compute_capability,
+    const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv,
+    int64_t vect_size) {
+  const Shape& input_shape = conv->operand(0)->shape();
+  const Shape& output_shape = conv->shape().tuple_shapes(0);
+  const ConvolutionDimensionNumbers* dnums =
+      &conv->convolution_dimension_numbers();
+  int64_t in_channels =
+      input_shape.dimensions(dnums->input_feature_dimension());
+  int64_t out_channels =
+      output_shape.dimensions(dnums->output_feature_dimension());
+
+  if (in_channels % vect_size != 0 || out_channels % vect_size != 0) {
+    return false;
+  }
+
+  if (input_shape.dimensions_size() >
+      2 + dnums->input_spatial_dimensions_size()) {
+    // Conv already has an extra dimension, which we assume is the vectorized
+    // features dim.
+    return false;
+  }
+
+  // If this is an integer convolution check that we only vectorize when cuDNN
+  // supports the vectorized implementation.
+  if (primitive_util::IsIntegralType(input_shape.element_type())) {
+    TF_ASSIGN_OR_RETURN(bool supported_target_vectorization,
+                        CudnnSupportsOptimizedIntegerConvolution(
+                            compute_capability, *conv, vect_size));
+    if (!supported_target_vectorization) {
+      VLOG(3) << "Skipping vectorization of conv to vector size: " << vect_size
+              << ": " << conv->ToString();
+      return false;
+    }
+  }
+
+  VLOG(1) << "Vectorizing conv channels by " << vect_size << ": "
+          << conv->ToString();
+
+  // We use XlaBuilder because it's a lot easier to get these tricky
+  // reshape/transposes correct using that API.
+  XlaBuilder b(absl::StrCat(conv->name(), ".revectorized"));
+  b.SetOpMetadata(conv->metadata());
+
+  XlaOp filter = Parameter(&b, 1, conv->operand(1)->shape(), "filter");
+  absl::InlinedVector<XlaOp, 4> new_operands = {
+      SplitAtDim(Parameter(&b, 0, conv->operand(0)->shape(), "input"),
+                 dnums->input_feature_dimension(), vect_size),
+      SplitAtDim(filter, dnums->kernel_input_feature_dimension(), vect_size),
+  };
+  if (conv->operand_count() > 2) {
+    // Bias, if present.  This is passed through unmodified.
+    new_operands.push_back(Parameter(&b, 2, conv->operand(2)->shape(), "bias"));
+  }
+  if (conv->operand_count() > 3) {
+    // Handle side input, which has same shape as the output.
+    new_operands.push_back(
+        SplitAtDim(Parameter(&b, 3, conv->operand(3)->shape(), "side_input"),
+                   dnums->output_feature_dimension(), vect_size));
+  }
+  if (conv->operand_count() > 4) {
+    return InvalidArgument(
+        "Don't understand a conv with more than 4 arguments: %s",
+        conv->ToString());
+  }
+
+  // Reorder filter and bias for the int8x32 convolutions.  This requires cudnn
+  // >= 8.3.0.
+  //
+  // TODO(jlebar): Remove this guard once JAX no longer supports cudnn 8.3.
+  const auto& debug_options = conv->GetModule()->config().debug_options();
+  bool use_reordering =
+      input_shape.element_type() == xla::S8 && vect_size == 32 &&
+      debug_options.xla_gpu_enable_cudnn_int8x32_convolution_reordering() &&
+      cudnn_version >= se::dnn::VersionInfo{8, 3, 0};
+  if (use_reordering) {
+    new_operands[1] = filter;
+    TF_RETURN_IF_ERROR(ReorderInt8NchwVect(conv, new_operands.data()));
+    dnums = &conv->convolution_dimension_numbers();
+  }
+
+  // The custom-call returns a tuple (new_output_shape, u8[0]), where the second
+  // value in the tuple represents the convolution's scratch memory.
+  Shape new_output_shape = SplitShapeAtDim(
+      output_shape, dnums->output_feature_dimension(), vect_size);
+  XlaOp new_conv = CustomCallWithConvDnums(
+      &b, conv->custom_call_target(), new_operands,
+      ShapeUtil::MakeTupleShape(
+          {new_output_shape, ShapeUtil::MakeShape(U8, {0})}),
+      /*operand_shapes_with_layout=*/{},
+      /*opaque=*/conv->raw_backend_config_string(), /*has_side_effect=*/false,
+      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
+      /*window=*/conv->window(),
+      /*dnums=*/VectorizeDnums(*dnums, use_reordering));
+
+  XlaOp new_conv_result = GetTupleElement(new_conv, 0);
+  XlaOp new_conv_scratch = GetTupleElement(new_conv, 1);
+
+  // Reshape back to the original shape.
+  XlaOp conv_result_collapsed =
+      Collapse(new_conv_result, {dnums->output_feature_dimension(),
+                                 dnums->output_feature_dimension() + 1});
+
+  TF_ASSIGN_OR_RETURN(
+      HloComputation * new_conv_comp,
+      BuilderToHloComputation(
+          b, Tuple(&b, {conv_result_collapsed, new_conv_scratch}),
+          conv->parent()));
+
+  // Create a tuple and replace the old conv with it!
+  VLOG(1) << "Vectorized conv to: " << new_conv_comp->ToString();
+  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
+      conv, HloInstruction::CreateCall(conv->shape(), conv->operands(),
+                                       new_conv_comp)));
+  return true;
+}
+
+}  // namespace
+
+absl::StatusOr<bool> CudnnVectorizeConvolutions::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
+      // Try to (re)vectorize to int8x32 if this is an sm75+ GPU.  If we can't,
+      // fall back to int8x4.
+      bool local_changed = false;
+      if (compute_capability_.IsAtLeast(7, 5)) {
+        TF_ASSIGN_OR_RETURN(
+            local_changed,
+            TryRevectorizeConv(compute_capability_, cudnn_version_, conv, 32));
+        if (!local_changed) {
+          TF_ASSIGN_OR_RETURN(
+              local_changed,
+              TryVectorizeConv(compute_capability_, cudnn_version_, conv, 32));
+        }
+      }
+      if (!local_changed) {
+        TF_ASSIGN_OR_RETURN(
+            local_changed,
+            TryVectorizeConv(compute_capability_, cudnn_version_, conv, 4));
+      }
+      changed |= local_changed;
+    }
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.h b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.h
new file mode 100644
index 0000000..6fd2f6e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.h
@@ -0,0 +1,73 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_VECTORIZE_CONVOLUTIONS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_VECTORIZE_CONVOLUTIONS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+
+namespace xla {
+namespace gpu {
+
+// Changes the shape of cudnn convolutions to allow faster "vectorized"
+// algorithms.
+//
+// On sm61+ will convert int8_t convolutions from
+//
+//   - [N, C, H, W] to [N, C/4, H, W, 4],
+//
+// assuming C is divisible by 4.
+//
+// On sm75+ will convert int8_t convolutions from
+//
+//   - [N, C, H, W]      to [N, C/32, H, W, 32],
+//   - [N, C/4, H, W, 4] to [N, C/32, H, W, 32], and
+//   - [N, C, H, W]      to [N,  C/4, H, W,  4] (same as sm61+),
+//
+// assuming C is divisible by 4 or 32.
+//
+// This pass will not pad the channel dim to a multiple of 4 or 32, so you
+// should run CudnnPadForConvolutions before this.
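+//
+// Typical usage (a minimal sketch; the compute capability and cudnn version
+// shown here are illustrative):
+//
+//   CudnnVectorizeConvolutions pass(
+//       se::CudaComputeCapability{7, 5}, se::dnn::VersionInfo(8, 3, 0));
+//   TF_ASSIGN_OR_RETURN(bool changed, pass.Run(module));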
+class CudnnVectorizeConvolutions : public HloModulePass {
+ public:
+  explicit CudnnVectorizeConvolutions(
+      se::CudaComputeCapability compute_capability,
+      se::dnn::VersionInfo cudnn_version)
+      : compute_capability_(compute_capability),
+        cudnn_version_(cudnn_version) {}
+
+  absl::string_view name() const override {
+    return "cudnn_vectorize_convolutions";
+  }
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  const se::CudaComputeCapability compute_capability_;
+  const se::dnn::VersionInfo cudnn_version_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_VECTORIZE_CONVOLUTIONS_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc
new file mode 100644
index 0000000..7528870
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc
@@ -0,0 +1,758 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h"
+
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/algorithm/container.h"
+#include "absl/status/statusor.h"
+#include "xla/service/call_inliner.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class CudnnVectorizeConvolutionsTest : public HloTestBase {
+ protected:
+  // Runs this pass and some cleanup to make pattern-matching easier.
+  absl::StatusOr<bool> Run(std::pair<int, int> compute_capability,
+                           HloModule* module) {
+    CudnnVectorizeConvolutions pass(
+        se::CudaComputeCapability{compute_capability.first,
+                                  compute_capability.second},
+        se::dnn::VersionInfo(8, 3, 0));
+    TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(&pass, module));
+
+    CallInliner inliner;
+    TF_RETURN_IF_ERROR(RunHloPass(&inliner, module).status());
+
+    return changed;
+  }
+};
+
+TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo4) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,40] parameter(0)
+    filter = s8[2,2,40,44] parameter(1)
+    ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward",
+                  backend_config="{bar: 0}"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  ASSERT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Reshape(m::GetTupleElement(
+                         m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
+                                       m::Reshape(m::Parameter(0))
+                                           .WithShape(S8, {10, 20, 30, 10, 4}),
+                                       m::Reshape(m::Parameter(1))
+                                           .WithShape(S8, {2, 2, 10, 4, 44}))
+                             .WithConvDnums("b01f?_01i?o->b01f?"))
+                         .WithShape(S8, {10, 20, 30, 11, 4})),
+          m::Op())));
+
+  EXPECT_EQ(conv->raw_backend_config_string(), "{bar: 0}");
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4UnsupportedFilterType) {
+  // This test checks that the vectorize pass correctly calls
+  // CudnnSupportsOptimizedIntegerConvolution() which should reject this
+  // convolution because its filter type is f32.
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,40] parameter(0)
+    filter = f32[2,2,40,44] parameter(1)
+    ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward",
+                  backend_config="{bar: 0}"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo4NCHW) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,48,20,30] parameter(0)
+    filter = s8[48,44,2,2] parameter(1)
+    ROOT result = (s8[10,44,20,30], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=bf01_io01->bf01,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  ASSERT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Reshape(m::GetTupleElement(
+                         m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
+                                       m::Reshape(m::Parameter(0))
+                                           .WithShape(S8, {10, 12, 4, 20, 30}),
+                                       m::Reshape(m::Parameter(1))
+                                           .WithShape(S8, {12, 4, 44, 2, 2}))
+                             .WithConvDnums("bf?01_i?o01->bf?01"))
+                         .WithShape(S8, {10, 11, 4, 20, 30})),
+          m::Op())));
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, IncrementAllDnums) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[16,16,16,16] parameter(0)
+    filter = s8[16,16,3,3] parameter(1)
+    ROOT result = (s8[16,16,16,16], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=fb01_i01o->fb01,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  ASSERT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Reshape(m::GetTupleElement(
+                         m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
+                                       m::Reshape(m::Parameter(0))
+                                           .WithShape(S8, {4, 4, 16, 16, 16}),
+                                       m::Reshape(m::Parameter(1))
+                                           .WithShape(S8, {4, 4, 16, 3, 3}))
+                             .WithConvDnums("f?b01_i?01o->f?b01"))
+                         .WithShape(S8, {4, 4, 16, 16, 16})),
+          m::Op())));
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, FilterDnums) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[1,20,9,9] parameter(0)
+    filter = s8[3,3,20,32] parameter(1)
+    ROOT result = (s8[1,32,9,9], u8[0]) custom-call(s8[1,20,9,9] input, s8[3,3,20,32] filter),
+                  window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  ASSERT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Reshape(m::GetTupleElement(
+                         m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
+                                       m::Reshape(m::Parameter(0))
+                                           .WithShape(S8, {1, 5, 4, 9, 9}),
+                                       m::Reshape(m::Parameter(1))
+                                           .WithShape(S8, {3, 3, 5, 4, 32}))
+                             .WithConvDnums("bf?01_01i?o->bf?01"))
+                         .WithShape(S8, {1, 8, 4, 9, 9})),
+          m::Op())));
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,41] parameter(0)
+    filter = s8[2,2,41,44] parameter(1)
+    ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  CudnnVectorizeConvolutions pass(
+      /*compute_capability=*/{7, 5},
+      /*cudnn_version=*/se::dnn::VersionInfo{8, 3, 0});
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+
+  SCOPED_TRACE(module->ToString());
+  EXPECT_FALSE(changed);
+}
+
+// Don't vectorize int8_t -> int32_t into int8x4 or int8x32; this is not
+// supported in cudnn.
+TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4IfOutputIsS32) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,41] parameter(0)
+    filter = s8[2,2,41,44] parameter(1)
+    ROOT result = (s32[10,20,30,44], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  SCOPED_TRACE(module->ToString());
+  EXPECT_FALSE(changed);
+}
+
+// Don't vectorize int8_t -> float into int8x4 or int8x32.  Vectorizing to
+// int8x4 *is* allowed by cudnn, but we don't do it at the moment.
+TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4IfOutputIsF32) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,41] parameter(0)
+    filter = s8[2,2,41,44] parameter(1)
+    ROOT result = (f32[10,20,30,44], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  SCOPED_TRACE(module->ToString());
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo32) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,64] parameter(0)
+    filter = s8[2,2,64,128] parameter(1)
+    ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  ASSERT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Reshape(
+              m::GetTupleElement(
+                  m::CustomCall(
+                      &conv, {kCudnnConvForwardCallTarget},
+                      m::Reshape(m::Parameter(0))
+                          .WithShape(S8, {10, 20, 30, 2, 32}),
+                      m::Reshape(
+                          m::Transpose(
+                              m::Reshape(m::Parameter(1))
+                                  .WithShape(S8, {2, 2, 2, 8, 4, 16, 4, 2}))
+                              .WithShape(S8, {2, 2, 2, 16, 2, 8, 4, 4})
+                              .WithPredicate([](const HloInstruction* instr) {
+                                return absl::c_equal(
+                                    instr->dimensions(),
+                                    std::vector<int64_t>{2, 0, 1, 5, 7, 3, 6,
+                                                         4});
+                              }))
+                          .WithShape(S8, {128, 2, 2, 2, 32})))
+                  .WithShape(S8, {10, 20, 30, 4, 32})),
+          m::Op())));
+
+  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+                  ->cudnn_conv_backend_config()
+                  .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, BiasAndSideInput) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,64] parameter(0)
+    filter = s8[2,2,64,128] parameter(1)
+    bias = f32[128] parameter(2)
+    side_input = s8[10,20,30,64] parameter(3)
+
+    ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter, bias, side_input),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  ASSERT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Reshape(
+              m::GetTupleElement(
+                  m::CustomCall(
+                      &conv, {kCudnnConvForwardCallTarget},
+                      m::Reshape(m::Parameter(0))
+                          .WithShape(S8, {10, 20, 30, 2, 32}),
+                      m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))))
+                          .WithShape(S8, {128, 2, 2, 2, 32}),
+                      m::Reshape(
+                          m::Transpose(m::Reshape(m::Parameter(2))
+                                           .WithShape(F32, {4, 4, 2, 4}))
+                              .WithShape(F32, {4, 2, 4, 4})
+                              .WithPredicate([](const HloInstruction* instr) {
+                                return absl::c_equal(
+                                    instr->dimensions(),
+                                    std::vector<int64_t>{0, 2, 1, 3});
+                              }))
+                          .WithShape(F32, {128}),
+                      m::Reshape(m::Parameter(3))
+                          .WithShape(S8, {10, 20, 30, 2, 32})))
+                  .WithShape(S8, {10, 20, 30, 4, 32})),
+          m::Op())));
+
+  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+                  ->cudnn_conv_backend_config()
+                  .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, InputNHWC_OutputNCHW) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,64] parameter(0)
+    filter = s8[2,2,64,128] parameter(1)
+    bias = f32[128] parameter(2)
+    side_input = s8[10,128,20,30] parameter(3)
+
+    ROOT result = (s8[10,128,20,30], u8[0]) custom-call(input, filter, bias, side_input),
+                  window={size=2x2}, dim_labels=b01f_01io->bf01,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  ASSERT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Reshape(
+              m::GetTupleElement(
+                  m::CustomCall(
+                      &conv, {kCudnnConvForwardCallTarget},
+                      m::Reshape(m::Parameter(0))
+                          .WithShape(S8, {10, 20, 30, 2, 32}),
+                      m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))))
+                          .WithShape(S8, {128, 2, 2, 2, 32}),
+                      m::Reshape(
+                          m::Transpose(m::Reshape(m::Parameter(2))
+                                           .WithShape(F32, {4, 4, 2, 4}))
+                              .WithShape(F32, {4, 2, 4, 4})
+                              .WithPredicate([](const HloInstruction* instr) {
+                                return absl::c_equal(
+                                    instr->dimensions(),
+                                    std::vector<int64_t>{0, 2, 1, 3});
+                              }))
+                          .WithShape(F32, {128}),
+                      m::Reshape(m::Parameter(3))
+                          .WithShape(S8, {10, 4, 32, 20, 30})))
+                  .WithShape(S8, {10, 4, 32, 20, 30})),
+          m::Op())));
+
+  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+                  ->cudnn_conv_backend_config()
+                  .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo32) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,64] parameter(0)
+    filter = s8[2,2,64,128] parameter(1)
+    ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 0}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  ASSERT_THAT(
+      root,
+      GmockMatch(m::Tuple(
+          m::Reshape(m::GetTupleElement(
+                         m::CustomCall(&conv, {kCudnnConvForwardCallTarget},
+                                       m::Reshape(m::Parameter(0))
+                                           .WithShape(S8, {10, 20, 30, 16, 4}),
+                                       m::Reshape(m::Parameter(1))
+                                           .WithShape(S8, {2, 2, 16, 4, 128})))
+                         .WithShape(S8, {10, 20, 30, 32, 4})),
+          m::Op())));
+
+  EXPECT_FALSE(conv->backend_config<GpuBackendConfig>()
+                   ->cudnn_conv_backend_config()
+                   .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,16,4] parameter(0)
+    filter = s8[3,5,16,192,4] parameter(1)
+    bias = f32[64] parameter(2)
+    side_input = s8[10,20,30,16,4] parameter(3)
+    ROOT result = (s8[10,20,30,48,4], u8[0]) custom-call(input, filter, bias, side_input),
+                  window={size=3x5}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  auto conv_pat =
+      m::GetTupleElement(
+          m::CustomCall(
+              &conv, {kCudnnConvForwardCallTarget},
+              m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
+                                          .WithShape(S8, {10, 20, 30, 2, 8, 4}))
+                             .WithShape(S8, {10, 20, 30, 2, 8, 4}))
+                  .WithShape(S8, {10, 20, 30, 2, 32}),
+              m::Reshape(
+                  m::Transpose(m::Reshape(m::Parameter(1))
+                                   .WithShape(S8, {3, 5, 2, 8, 24, 4, 2, 4}))
+                      .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4})
+                      .WithPredicate([](const HloInstruction* instr) {
+                        return absl::c_equal(
+                            instr->dimensions(),
+                            std::vector<int64_t>{2, 0, 1, 4, 6, 3, 5, 7});
+                      }))
+                  .WithShape(S8, {192, 2, 3, 5, 32}),
+              m::Reshape(m::Transpose(m::Reshape(m::Parameter(2)))),
+              m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
+                                          .WithShape(S8, {10, 20, 30, 2, 8, 4}))
+                             .WithShape(S8, {10, 20, 30, 2, 8, 4}))
+                  .WithShape(S8, {10, 20, 30, 2, 32}))
+              .WithConvDnums("b01f?_oi01?->b01f?"))
+          .WithShape(S8, {10, 20, 30, 6, 32});
+  ASSERT_THAT(root, GmockMatch(m::Tuple(
+                        m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
+                                                    S8, {10, 20, 30, 6, 8, 4}))
+                                       .WithShape(S8, {10, 20, 30, 6, 8, 4}))
+                            .WithShape(S8, {10, 20, 30, 48, 4}),
+                        m::Op())));
+
+  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+                  ->cudnn_conv_backend_config()
+                  .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32NCHW) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,16,20,30,4] parameter(0)
+    filter = s8[16,128,2,2,4] parameter(1)
+    bias = f32[64] parameter(2)
+    side_input = s8[10,16,20,30,4] parameter(3)
+    ROOT result = (s8[10,32,20,30,4], u8[0]) custom-call(input, filter, bias, side_input),
+                  window={size=2x2}, dim_labels=bf01_io01->bf01,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  auto conv_pat =
+      m::GetTupleElement(
+          m::CustomCall(
+              &conv, {kCudnnConvForwardCallTarget},
+              m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
+                                          .WithShape(S8, {10, 2, 8, 20, 30, 4}))
+                             .WithShape(S8, {10, 2, 20, 30, 8, 4}))
+                  .WithShape(S8, {10, 2, 20, 30, 32}),
+              m::Reshape(
+                  m::Transpose(m::Reshape(m::Parameter(1))
+                                   .WithShape(S8, {2, 8, 16, 4, 2, 2, 2, 4}))
+                      .WithShape(S8, {2, 2, 2, 16, 2, 8, 4, 4})
+                      .WithPredicate([](const HloInstruction* instr) {
+                        return absl::c_equal(
+                            instr->dimensions(),
+                            std::vector<int64_t>{0, 5, 6, 2, 4, 1, 3, 7});
+                      }))
+                  .WithShape(S8, {128, 2, 2, 2, 32}),
+              m::Reshape(m::Transpose(m::Reshape(m::Parameter(2)))),
+              m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
+                                          .WithShape(S8, {10, 2, 8, 20, 30, 4}))
+                             .WithShape(S8, {10, 2, 20, 30, 8, 4}))
+                  .WithShape(S8, {10, 2, 20, 30, 32}))
+              .WithConvDnums("bf01_oi01->bf01"))
+          .WithShape(S8, {10, 4, 20, 30, 32});
+  ASSERT_THAT(root, GmockMatch(m::Tuple(
+                        m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
+                                                    S8, {10, 4, 20, 30, 8, 4}))
+                                       .WithShape(S8, {10, 4, 8, 20, 30, 4}))
+                            .WithShape(S8, {10, 32, 20, 30, 4}),
+                        m::Op())));
+
+  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+                  ->cudnn_conv_backend_config()
+                  .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32VectorDimFirst) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[4,10,20,30,16] parameter(0)
+    filter = s8[4,3,5,16,192] parameter(1)
+    bias = f32[64] parameter(2)
+    side_input = s8[4,10,20,30,16] parameter(3)
+    ROOT result = (s8[4,10,20,30,48], u8[0]) custom-call(input, filter, bias, side_input),
+                  window={size=3x5}, dim_labels=?b01f_?01io->?b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  auto conv_pat =
+      m::GetTupleElement(
+          m::CustomCall(
+              &conv, {kCudnnConvForwardCallTarget},
+              m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
+                                          .WithShape(S8, {4, 10, 20, 30, 2, 8}))
+                             .WithShape(S8, {8, 4, 10, 20, 30, 2}))
+                  .WithShape(S8, {32, 10, 20, 30, 2}),
+              m::Reshape(
+                  m::Transpose(m::Reshape(m::Parameter(1))
+                                   .WithShape(S8, {4, 3, 5, 2, 8, 24, 4, 2}))
+                      .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4})
+                      .WithPredicate([](const HloInstruction* instr) {
+                        return absl::c_equal(
+                            instr->dimensions(),
+                            std::vector<int64_t>{3, 1, 2, 5, 7, 4, 6, 0});
+                      }))
+                  .WithShape(S8, {192, 2, 3, 5, 32}),
+              m::Reshape(m::Transpose(m::Reshape(m::Parameter(2)))),
+              m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
+                                          .WithShape(S8, {4, 10, 20, 30, 2, 8}))
+                             .WithShape(S8, {8, 4, 10, 20, 30, 2}))
+                  .WithShape(S8, {32, 10, 20, 30, 2}))
+              .WithConvDnums("?b01f_oi01->?b01f"))
+          .WithShape(S8, {32, 10, 20, 30, 6});
+  ASSERT_THAT(root, GmockMatch(m::Tuple(
+                        m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
+                                                    S8, {8, 4, 10, 20, 30, 6}))
+                                       .WithShape(S8, {4, 10, 20, 30, 6, 8}))
+                            .WithShape(S8, {4, 10, 20, 30, 48}),
+                        m::Op())));
+
+  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+                  ->cudnn_conv_backend_config()
+                  .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, NoVectorize4To32) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,16,4] parameter(0)
+    filter = s8[2,2,16,128,4] parameter(1)
+    bias = f32[10] parameter(2)
+    side_input = s8[10,20,30,16,4] parameter(3)
+    ROOT result = (s8[10,20,30,32,4], u8[0]) custom-call(input, filter, bias, side_input),
+                  window={size=2x2}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 0}, module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, Vectorize16To32) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,4,16] parameter(0)
+    filter = s8[3,5,4,192,16] parameter(1)
+    ROOT result = (s8[10,20,30,12,16], u8[0]) custom-call(input, filter),
+                  window={size=3x5}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  auto filter_pat =
+      m::Reshape(
+          m::Transpose(
+              m::Reshape(m::Parameter(1)).WithShape(S8, {3, 5, 2, 2, 192, 16}))
+              .WithShape(S8, {3, 5, 2, 192, 2, 16}))
+          .WithShape(S8, {3, 5, 2, 192, 32});
+  auto conv_pat =
+      m::GetTupleElement(
+          m::CustomCall(
+              &conv, {kCudnnConvForwardCallTarget},
+              m::Reshape(
+                  m::Transpose(m::Reshape(m::Parameter(0))
+                                   .WithShape(S8, {10, 20, 30, 2, 2, 16}))
+                      .WithShape(S8, {10, 20, 30, 2, 2, 16}))
+                  .WithShape(S8, {10, 20, 30, 2, 32}),
+              m::Reshape(
+                  m::Transpose(m::Reshape(filter_pat)
+                                   .WithShape(S8, {3, 5, 2, 24, 4, 2, 8, 4}))
+                      .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4}))
+                  .WithShape(S8, {192, 2, 3, 5, 32}))
+              .WithConvDnums("b01f_oi01->b01f"))
+          .WithShape(S8, {10, 20, 30, 6, 32});
+  ASSERT_THAT(root, GmockMatch(m::Tuple(
+                        m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
+                                                    S8, {10, 20, 30, 6, 2, 16}))
+                                       .WithShape(S8, {10, 20, 30, 6, 2, 16}))
+                            .WithShape(S8, {10, 20, 30, 12, 16}),
+                        m::Op())));
+  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+                  ->cudnn_conv_backend_config()
+                  .reordered_int8_nchw_vect());
+}
+
+TEST_F(CudnnVectorizeConvolutionsTest, VectorizeMixedTo32) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    input = s8[10,20,30,8,8] parameter(0)
+    filter = s8[3,5,2,192,32] parameter(1)
+    ROOT result = (s8[10,20,30,96,2], u8[0]) custom-call(input, filter),
+                  window={size=3x5}, dim_labels=b01f_01io->b01f,
+                  custom_call_target="__cudnn$convForward"
+  })")
+                    .value();
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
+  EXPECT_TRUE(changed);
+
+  SCOPED_TRACE(module->ToString());
+  auto* root = module->entry_computation()->root_instruction();
+
+  const HloInstruction* conv = nullptr;
+  auto conv_pat =
+      m::GetTupleElement(
+          m::CustomCall(
+              &conv, {kCudnnConvForwardCallTarget},
+              m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
+                                          .WithShape(S8, {10, 20, 30, 2, 4, 8}))
+                             .WithShape(S8, {10, 20, 30, 2, 4, 8}))
+                  .WithShape(S8, {10, 20, 30, 2, 32}),
+              m::Reshape(
+                  m::Transpose(m::Reshape(m::Parameter(1))
+                                   .WithShape(S8, {3, 5, 2, 24, 4, 2, 8, 4}))
+                      .WithShape(S8, {2, 3, 5, 24, 2, 8, 4, 4}))
+                  .WithShape(S8, {192, 2, 3, 5, 32}))
+              .WithConvDnums("b01f_oi01->b01f"))
+          .WithShape(S8, {10, 20, 30, 6, 32});
+  ASSERT_THAT(root, GmockMatch(m::Tuple(
+                        m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
+                                                    S8, {10, 20, 30, 6, 16, 2}))
+                                       .WithShape(S8, {10, 20, 30, 6, 16, 2}))
+                            .WithShape(S8, {10, 20, 30, 96, 2}),
+                        m::Op())));
+  EXPECT_TRUE(conv->backend_config<GpuBackendConfig>()
+                  ->cudnn_conv_backend_config()
+                  .reordered_int8_nchw_vect());
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc
new file mode 100644
index 0000000..af9591a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc
@@ -0,0 +1,240 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h"
+
+#include <cstdint>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+
+CustomKernelFusionRewriter::CustomKernelFusionRewriter(
+    const se::DeviceDescription* device,
+    const CustomKernelFusionPatternRegistry* patterns)
+    : device_(device), patterns_(patterns) {}
+
+// Returns the set of instructions that have users outside of the matched
+// pattern and therefore need a replacement applied after building the new
+// custom fusion instruction. Only the root instruction may have external users
+// without requiring a replacement, as the fusion itself replaces it. If any
+// other instruction has external users but no registered replacement, returns
+// an empty optional.
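+//
+// As an illustrative (hypothetical) example: if a matched GEMM pattern
+// contains a bitcast whose result is also consumed by an instruction outside
+// the pattern, that bitcast must have a registered replacement; otherwise the
+// whole match is skipped.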
+static std::optional<absl::flat_hash_set<HloInstruction*>>
+GetPatternReplacements(const CustomKernelFusionPattern::Match& match) {
+  absl::flat_hash_set<HloInstruction*> requires_replacement;
+  absl::flat_hash_set<HloInstruction*> instructions_set(
+      match.instructions().begin(), match.instructions().end());
+
+  for (HloInstruction* instr : match.instructions()) {
+    for (HloInstruction* user : instr->users()) {
+      if (instr == match.root() || instructions_set.contains(user)) continue;
+
+      if (match.HasReplacement(instr)) {
+        requires_replacement.insert(instr);
+        continue;
+      }
+
+      VLOG(3) << "Custom kernel fusion intermediate result " << instr->name()
+              << " has users outside of a matched pattern: " << user->name();
+      return std::nullopt;
+    }
+  }
+
+  return requires_replacement;
+}
+
+// Returns the instructions that have to become custom kernel fusion
+// parameters, i.e. operands of matched instructions that are defined outside
+// of the matched pattern.
+static absl::InlinedVector<HloInstruction*, 4> GetPatternCaptures(
+    const CustomKernelFusionPattern::Match& match) {
+  absl::InlinedVector<HloInstruction*, 4> captures;
+
+  absl::flat_hash_set<HloInstruction*> instructions_set(
+      match.instructions().begin(), match.instructions().end());
+
+  for (HloInstruction* instr : match.instructions()) {
+    for (HloInstruction* operand : instr->operands()) {
+      if (!instructions_set.contains(operand) &&
+          absl::c_find(captures, operand) == captures.end()) {
+        captures.emplace_back(operand);
+      }
+    }
+  }
+
+  return captures;
+}
+
+// Creates a custom kernel fusion computation and moves all matched
+// instructions into it.
+static absl::StatusOr<HloComputation*> CreateFusionBody(
+    HloModule* module, const CustomKernelFusionPattern::Match& match,
+    absl::Span<HloInstruction* const> captures) {
+  HloComputation::Builder builder(match.config().name());
+
+  // A mapping from original instructions to instructions in the fusion body.
+  absl::flat_hash_map<const HloInstruction*, HloInstruction*> instr_mapping;
+
+  auto mapped_operands = [&](HloInstruction* instr) {
+    absl::InlinedVector<HloInstruction*, 4> operands;
+    for (HloInstruction* operand : instr->operands()) {
+      operands.push_back(instr_mapping.at(operand));
+    }
+    return operands;
+  };
+
+  // For every captured value create a parameter instruction in the computation
+  // body and set up instruction mapping.
+  for (const HloInstruction* capture : captures) {
+    int64_t index = instr_mapping.size();
+    instr_mapping[capture] =
+        builder.AddInstruction(HloInstruction::CreateParameter(
+            index, capture->shape(), absl::StrCat("p", index)));
+  }
+
+  // TODO(ezhulenev): Instructions in the pattern must be topologically sorted,
+  // otherwise we'll get a crash! Figure out how to do it!
+  for (HloInstruction* instr : match.instructions()) {
+    instr_mapping[instr] = builder.AddInstruction(
+        instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
+  }
+
+  HloInstruction* root = builder.last_added_instruction();
+
+  // If the custom kernel fusion requires a workspace, we add a custom call
+  // that allocates the workspace and return a tuple of the "real" result and
+  // the workspace.
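+  // (For example, with a 1024-byte workspace the fusion root becomes a
+  // (result_shape, u8[1024]) tuple, as exercised in the rewriter tests.)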
+  if (match.workspace_size_bytes() > 0) {
+    auto workspace_shape =
+        ShapeUtil::MakeShape(PrimitiveType::U8, {match.workspace_size_bytes()});
+    HloInstruction* workspace =
+        builder.AddInstruction(HloInstruction::CreateCustomCall(
+            workspace_shape, {}, CustomKernelFusionPattern::kWorkspace, "",
+            CustomCallApiVersion::API_VERSION_TYPED_FFI));
+    builder.AddInstruction(HloInstruction::CreateTuple({root, workspace}));
+  }
+
+  return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false);
+}
+
+static absl::StatusOr<HloInstruction*> CreateFusionInstruction(
+    HloModule* module, const CustomKernelFusionPattern::Match& match,
+    absl::Span<HloInstruction* const> captures, HloComputation* body) {
+  // We'll be replacing the root operation of the matched custom kernel fusion
+  // with a fusion instruction calling the outlined fusion computation.
+  HloInstruction* root = match.root();
+  HloComputation* parent = root->parent();
+
+  // Add a fusion operation calling the outlined fusion computation.
+  HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion(
+      body->root_instruction()->shape(), HloInstruction::FusionKind::kCustom,
+      captures, body));
+  module->SetAndUniquifyInstrName(fusion, match.config().name());
+
+  // Set the backend config to the matched custom fusion config.
+  GpuBackendConfig gpu_config;
+  FusionBackendConfig& backend_config =
+      *gpu_config.mutable_fusion_backend_config();
+  backend_config.set_kind("__custom_fusion");
+  *backend_config.mutable_custom_fusion_config() = match.config();
+  backend_config.mutable_custom_fusion_config()->set_kernel_index(0);
+  TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config)));
+
+  // If we don't have a workspace, we can return the constructed fusion
+  // instruction.
+  if (match.workspace_size_bytes() == 0) return fusion;
+
+  // Otherwise we have to get the tuple element corresponding to the original
+  // result value.
+  return parent->AddInstruction(
+      HloInstruction::CreateGetTupleElement(fusion, 0));
+}
+
+absl::StatusOr<bool> CustomKernelFusionRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  std::vector<CustomKernelFusionPattern::Match> matches;
+
+  // Collect all potential custom fusion matches in the module.
+  for (HloComputation* computation : module->computations()) {
+    for (HloInstruction* instr : computation->instructions()) {
+      auto matched = patterns_->Match(*device_, instr);
+      matches.insert(matches.end(), matched.begin(), matched.end());
+    }
+  }
+
+  if (matches.empty()) return false;
+
+  for (const CustomKernelFusionPattern::Match& match : matches) {
+    VLOG(2) << "Matched custom kernel fusion " << match.config().name()
+            << "; root instruction: " << match.instructions().back()->name();
+
+    auto replacements = GetPatternReplacements(match);
+    if (!replacements.has_value()) continue;
+
+    auto captures = GetPatternCaptures(match);
+
+    TF_ASSIGN_OR_RETURN(HloComputation * fusion_body,
+                        CreateFusionBody(module, match, captures));
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * fusion,
+        CreateFusionInstruction(module, match, captures, fusion_body));
+
+    VLOG(2) << "Added a fusion instruction: " << fusion->name()
+            << " for custom kernel fusion " << match.config().name()
+            << " (instruction count = " << match.instructions().size() << ")";
+
+    for (HloInstruction* instr : *replacements) {
+      VLOG(2) << "Replace matched instruction: " << instr->name()
+              << " with a pattern replacement";
+
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * replacement,
+          match.BuildReplacement(instr, Cast<HloFusionInstruction>(fusion)));
+
+      TF_RETURN_IF_ERROR(
+          instr->ReplaceAllUsesWith(replacement, match.config().name()));
+
+      VLOG(2) << "Replaced instruction: " << instr->name()
+              << " with: " << replacement->name();
+    }
+
+    VLOG(2) << "Replace custom kernel fusion root instruction "
+            << match.root()->name() << " with " << fusion->name();
+    HloComputation* parent = match.root()->parent();
+    TF_RETURN_IF_ERROR(parent->ReplaceInstruction(match.root(), fusion));
+  }
+
+  return true;
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h
new file mode 100644
index 0000000..849fdbb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h
@@ -0,0 +1,86 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUSTOM_KERNEL_FUSION_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_CUSTOM_KERNEL_FUSION_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla::gpu {
+
+// Pattern-matches HLO instructions to custom kernel fusions (hand-written CUDA
+// C++ kernels, e.g. custom GEMMs implemented with CUTLASS) and rewrites them
+// into fusion instructions and fusion computations.
+//
+// Example: pattern matching dot operation into CUTLASS gemm
+//
+//  ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
+//    %p0 = f16[15,19]{1,0} parameter(0)
+//    %p1 = f16[19,17]{1,0} parameter(1)
+//    ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
+//      lhs_contracting_dims={1}, rhs_contracting_dims={0}
+//  }
+//
+// After the pass:
+//
+//  %cutlass_gemm (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
+//    %p0 = f16[15,19]{1,0} parameter(0)
+//    %p1 = f16[19,17]{1,0} parameter(1)
+//    ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
+//      lhs_contracting_dims={1}, rhs_contracting_dims={0}
+//  }
+//
+//  ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
+//    %p0 = f16[15,19]{1,0} parameter(0)
+//    %p1 = f16[19,17]{1,0} parameter(1)
+//    ROOT %r = f16[15,17]{1,0} fusion(%p0, %p1), kind=kCustom,
+//      calls=%cutlass_gemm,
+//      backend_config={kind: "__custom_fusion",
+//                      custom_fusion_config: {"name":"cutlass_gemm"}}
+//  }
+//
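+// A minimal usage sketch (the device setup here is an assumption about the
+// surrounding pipeline, not part of this header):
+//
+//   const se::DeviceDescription& device = ...;  // e.g. from the stream executor
+//   CustomKernelFusionRewriter rewriter(&device);
+//   TF_ASSIGN_OR_RETURN(bool changed, rewriter.Run(module));
+//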
+class CustomKernelFusionRewriter : public HloModulePass {
+ public:
+  explicit CustomKernelFusionRewriter(
+      const se::DeviceDescription* device,
+      const CustomKernelFusionPatternRegistry* patterns =
+          CustomKernelFusionPatternRegistry::Default());
+
+  absl::string_view name() const override {
+    return "custom-kernel-fusion-rewriter";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  const se::DeviceDescription* device_;
+  const CustomKernelFusionPatternRegistry* patterns_;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_CUSTOM_KERNEL_FUSION_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc
new file mode 100644
index 0000000..235e9de
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc
@@ -0,0 +1,138 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h"
+
+#include <cstdint>
+#include <optional>
+#include <utility>
+
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla::gpu {
+
+//===----------------------------------------------------------------------===//
+// Simple pattern matchers for testing the custom kernel fusion rewriter.
+//===----------------------------------------------------------------------===//
+
+struct SimpleGemmPattern : public CustomKernelFusionPattern {
+  explicit SimpleGemmPattern(int64_t workspace = 0) : workspace(workspace) {}
+
+  std::optional<Match> TryMatch(const se::DeviceDescription& device,
+                                HloInstruction* instr) const override {
+    if (auto* dot = DynCast<HloDotInstruction>(instr)) {
+      CustomFusionConfig config;
+      config.set_name("simple_gemm");
+      return Match{config, {instr}, workspace};
+    }
+    return std::nullopt;
+  }
+
+  int64_t workspace;
+};
+
+//===----------------------------------------------------------------------===//
+
+class CustomKernelFusionRewriterTest : public HloTestBase {};
+
+TEST_F(CustomKernelFusionRewriterTest, SimpleGemm) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
+      %p0 = f16[15,19]{1,0} parameter(0)
+      %p1 = f16[19,17]{1,0} parameter(1)
+      ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
+        lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK: %simple_gemm {{.*}} {
+    ; CHECK:   [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0)
+    ; CHECK:   [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1)
+    ; CHECK:   ROOT [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]),
+    ; CHECK:     lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    ; CHECK: }
+
+    ; CHECK: ENTRY %main {{.*}} {
+    ; CHECK:   ROOT [[FUSION:%[^ ]+]] = f16[15,17]{1,0} fusion
+    ; CHECK:     kind=kCustom, calls=%simple_gemm,
+    ; CHECK:     backend_config={
+    ; CHECK:       "kind":"__custom_fusion",
+    ; CHECK:       "custom_fusion_config":{"name":"simple_gemm","kernel_index":0}
+    ; CHECK:     }
+    ; CHECK: }
+  )";
+
+  CustomKernelFusionPatternRegistry patterns;
+  patterns.Emplace<SimpleGemmPattern>();
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  CustomKernelFusionRewriter pass(&device, &patterns);
+  RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
+}
+
+TEST_F(CustomKernelFusionRewriterTest, SimpleGemmWithWorkspace) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] {
+      %p0 = f16[15,19]{1,0} parameter(0)
+      %p1 = f16[19,17]{1,0} parameter(1)
+      ROOT %r = f16[15,17]{1,0} dot(%p0, %p1),
+        lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK: %simple_gemm {{.*}} {
+    ; CHECK:   [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0)
+    ; CHECK:   [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1)
+    ; CHECK:   [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]),
+    ; CHECK:     lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    ; CHECK:   [[WORKSPACE:%[^ ]+]] = u8[1024]{0} custom-call(),
+    ; CHECK:     custom_call_target="__custom_kernel_fusion$workspace"
+    ; CHECK:   ROOT [[TUPLE:%[^ ]+]] = (f16[15,17]{1,0}, u8[1024]{0})
+    ; CHECK:     tuple([[DOT]], [[WORKSPACE]])
+    ; CHECK: }
+
+    ; CHECK: ENTRY %main {{.*}} {
+    ; CHECK:   [[FUSION:%[^ ]+]] = (f16[15,17]{1,0}, u8[1024]{0}) fusion
+    ; CHECK:     kind=kCustom, calls=%simple_gemm,
+    ; CHECK:     backend_config={
+    ; CHECK:       "kind":"__custom_fusion",
+    ; CHECK:       "custom_fusion_config":{"name":"simple_gemm","kernel_index":0}
+    ; CHECK:     }
+    ; CHECK:   ROOT {{.*}} get-tuple-element([[FUSION]]), index=0
+    ; CHECK: }
+  )";
+
+  CustomKernelFusionPatternRegistry patterns;
+  patterns.Emplace<SimpleGemmPattern>(1024);
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  CustomKernelFusionRewriter pass(&device, &patterns);
+  RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc
new file mode 100644
index 0000000..b1e0b98
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc
@@ -0,0 +1,136 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_dimension_sorter.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
+#include "xla/permutation_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// Sorts the contracting dimensions of a dot() instruction, preserving lhs-rhs
+// pairs.
+absl::Status SortDotDimensions(HloDotInstruction* dot) {
+  const DotDimensionNumbers& dims = dot->dot_dimension_numbers();
+  DotDimensionNumbers new_dims(dims);
+  new_dims.clear_lhs_contracting_dimensions();
+  new_dims.clear_rhs_contracting_dimensions();
+  const bool sort_by_lhs =
+      DistinctNumbersAreConsecutiveIfSorted(dims.lhs_contracting_dimensions());
+  // Sort lhs and rhs by sort_key using the fact that
+  // sort_key is guaranteed to have only distinct consecutive numbers.
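+  // For example, with lhs_contracting_dims={3,2} and rhs_contracting_dims=
+  // {2,1}, sort_key={3,2} gives permutation={1,0}, so both sides are permuted
+  // in lockstep to lhs={2,3} and rhs={1,2}.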
+  const absl::Span<const int64_t>& sort_key =
+      sort_by_lhs ? dims.lhs_contracting_dimensions()
+                  : dims.rhs_contracting_dimensions();
+  std::vector<int64_t> permutation;
+  for (const int64_t a : sort_key) {
+    permutation.push_back(a - *absl::c_min_element(sort_key));
+  }
+  const std::vector<int64_t> sorted_lhs =
+      Permute(dims.lhs_contracting_dimensions(), permutation);
+  *new_dims.mutable_lhs_contracting_dimensions() = {sorted_lhs.begin(),
+                                                    sorted_lhs.end()};
+  const std::vector<int64_t> sorted_rhs =
+      Permute(dims.rhs_contracting_dimensions(), permutation);
+  *new_dims.mutable_rhs_contracting_dimensions() = {sorted_rhs.begin(),
+                                                    sorted_rhs.end()};
+  std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
+      dot->shape(), dot->mutable_operand(0), dot->mutable_operand(1), new_dims,
+      dot->precision_config(), {dot->sparsity().begin(), dot->sparsity().end()},
+      absl::MakeSpan(dot->operands()).subspan(HloDotInstruction::kOperands));
+  dot->SetupDerivedInstruction(new_dot.get());
+
+  VLOG(3) << "Sorted dot() dimensions:\n"
+          << "\t before: " << dot->ToString() << "\n"
+          << "\t after: " << new_dot->ToString();
+  return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
+}
+
+}  // namespace
+
+absl::StatusOr<bool> DotDimensionSorter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  std::vector<HloInstruction*> dots_to_process;
+  for (const HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    for (HloInstruction* instr : computation->instructions()) {
+      if (instr->opcode() != HloOpcode::kDot) {
+        continue;
+      }
+      // TODO(b/265688934): should non-default layouts be expected here at all?
+      if ((instr->operand(0)->shape().has_layout() &&
+           !LayoutUtil::IsMonotonicWithDim0Major(
+               instr->operand(0)->shape().layout())) ||
+          (instr->operand(1)->shape().has_layout() &&
+           !LayoutUtil::IsMonotonicWithDim0Major(
+               instr->operand(1)->shape().layout()))) {
+        continue;
+      }
+      const DotDimensionNumbers& dims = instr->dot_dimension_numbers();
+      if (dims.lhs_contracting_dimensions_size() == 0) {
+        continue;
+      }
+      const bool cons_lhs = DistinctNumbersAreConsecutiveIfSorted(
+          dims.lhs_contracting_dimensions());
+      const bool cons_rhs = DistinctNumbersAreConsecutiveIfSorted(
+          dims.rhs_contracting_dimensions());
+      const bool sorted_lhs =
+          absl::c_is_sorted(dims.lhs_contracting_dimensions());
+      const bool sorted_rhs =
+          absl::c_is_sorted(dims.rhs_contracting_dimensions());
+      // The side to be sorted has to be consecutive and not sorted yet;
+      // the other side should not get worsened.
+      // TODO(b/265688934): we may still want to change which one is sorted
+      // if this reduces the amount of transposed data.
+      if ((cons_lhs && !sorted_lhs && !cons_rhs) ||
+          (cons_rhs && !sorted_rhs && !cons_lhs) ||
+          (cons_lhs && !sorted_lhs && cons_rhs && !sorted_rhs)) {
+        dots_to_process.push_back(instr);
+      }
+    }
+  }
+  if (dots_to_process.empty()) {
+    return false;
+  }
+  for (HloInstruction* dot : dots_to_process) {
+    TF_RETURN_IF_ERROR(SortDotDimensions(Cast<HloDotInstruction>(dot)));
+  }
+  return true;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.h b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.h
new file mode 100644
index 0000000..872fb72
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.h
@@ -0,0 +1,52 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_DIMENSION_SORTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_DOT_DIMENSION_SORTER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Sorts contracting dimensions of dot() operands when this reduces the
+// number of transposes. Example:
+// dot(p0, p1), lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}  ->
+// dot(p0, p1), lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
+// The first case gets transposes inserted by dot_decomposer, the second one
+// does not and thus is generally more efficient.
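+//
+// Minimal invocation sketch (mirroring the pass's own tests; the module setup
+// is assumed):
+//
+//   TF_ASSIGN_OR_RETURN(bool changed, DotDimensionSorter().Run(module.get()));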
+
+// TODO(b/265688934): do the same for batch dimensions?
+
+class DotDimensionSorter : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "dot_dimension_sorter"; }
+
+  // Run the pass on computations in 'module'.
+  // Returns whether the 'module' was changed.
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_DOT_DIMENSION_SORTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc
new file mode 100644
index 0000000..364c140
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc
@@ -0,0 +1,191 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_dimension_sorter.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class WithoutDotDimensionSorterTest : public GpuCodegenTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
+    // The pass is disabled here to preserve the suboptimal dimension order in
+    // 1) UnsortedDimsCreateTransposes, to reveal the transposes, and
+    // 2) DimOrderCanBeChanged, for the comparison of ordered vs. unordered.
+    // The pass does not touch SortedDimsDoNotCreateTransposes anyway, because
+    // the dimensions are already ordered there.
+    debug_options.add_xla_disable_hlo_passes("dot_dimension_sorter");
+    return debug_options;
+  }
+};
+
+TEST_F(WithoutDotDimensionSorterTest, UnsortedDimsCreateTransposes) {
+  const char* hlo_text = R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[1,14,9,32] parameter(0)
+  p1 = f16[12,9,32] parameter(1)
+  ROOT _ = f16[1,14,12] dot(p0, p1),
+    lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
+}
+)";
+
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK: transpose
+)");
+}
+
+TEST_F(WithoutDotDimensionSorterTest, SortedDimsDoNotCreateTransposes) {
+  const char* hlo_text = R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[1,14,9,32] parameter(0)
+  p1 = f16[12,9,32] parameter(1)
+  ROOT _ = f16[1,14,12] dot(p0, p1),
+    lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
+}
+)";
+
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-NOT: transpose
+)");
+}
+
+TEST_F(WithoutDotDimensionSorterTest, DimOrderCanBeChanged) {
+  const char* hlo_text_ref = R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[1,14,9,32] parameter(0)
+  p1 = f16[12,9,32] parameter(1)
+  ROOT _ = f16[1,14,12] dot(p0, p1),
+    lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
+}
+)";
+
+  const char* hlo_text_modified = R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[1,14,9,32] parameter(0)
+  p1 = f16[12,9,32] parameter(1)
+  ROOT _ = f16[1,14,12] dot(p0, p1),
+    lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
+}
+)";
+
+  EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_modified,
+                                      ErrorSpec{1e-5, 1e-3},
+                                      /*run_hlo_passes=*/true));
+}
+
+using DotDimensionSorterTest = GpuCodegenTest;
+
+TEST_F(DotDimensionSorterTest, SortContractingDims) {
+  const char* module_string = R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[1,144,96,32] parameter(0)
+  p1 = f16[122,96,32] parameter(1)
+  ROOT _ = f16[1,144,122] dot(p0, p1),
+    lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  const auto& dims =
+      module->entry_computation()->root_instruction()->dot_dimension_numbers();
+
+  EXPECT_EQ(dims.lhs_contracting_dimensions(0), 3);
+  EXPECT_EQ(dims.lhs_contracting_dimensions(1), 2);
+
+  EXPECT_EQ(dims.rhs_contracting_dimensions(0), 2);
+  EXPECT_EQ(dims.rhs_contracting_dimensions(1), 1);
+
+  TF_ASSERT_OK_AND_ASSIGN(bool modified,
+                          DotDimensionSorter().Run(module.get()));
+  EXPECT_TRUE(modified);
+  const auto& dims2 =
+      module->entry_computation()->root_instruction()->dot_dimension_numbers();
+
+  EXPECT_EQ(dims2.lhs_contracting_dimensions(0), 2);
+  EXPECT_EQ(dims2.lhs_contracting_dimensions(1), 3);
+
+  EXPECT_EQ(dims2.rhs_contracting_dimensions(0), 1);
+  EXPECT_EQ(dims2.rhs_contracting_dimensions(1), 2);
+}
+
+TEST_F(DotDimensionSorterTest, NothingToReorder) {
+  const char* module_string = R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[1,144,96,32] parameter(0)
+  p1 = f16[122,96,32] parameter(1)
+  ROOT _ = f16[1,144,122] dot(p0, p1),
+    lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+
+  TF_ASSERT_OK_AND_ASSIGN(bool modified,
+                          DotDimensionSorter().Run(module.get()));
+  EXPECT_FALSE(modified);
+}
+
+TEST_F(DotDimensionSorterTest, SparseDotSortContractingDims) {
+  const char* module_string = R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[1,144,96,16] parameter(0)
+  p1 = f16[122,96,32] parameter(1)
+  meta = u16[1,144,96,2] parameter(2)
+  ROOT _ = f16[1,144,122] dot(p0, p1, meta), sparsity=L.3@2:4,
+    lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool modified,
+                          DotDimensionSorter().Run(module.get()));
+  EXPECT_TRUE(modified);
+  HloDotInstruction* dot = DynCast<HloDotInstruction>(
+      module->entry_computation()->root_instruction());
+  EXPECT_TRUE(dot != nullptr && dot->sparse_operands() == 1);
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc
new file mode 100644
index 0000000..d9e095e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc
@@ -0,0 +1,74 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_operand_converter.h"
+
+#include "absl/status/statusor.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/errors.h"
+
+namespace xla::gpu {
+
+bool DotOperandConverter::InstructionMatchesPattern(
+    HloInstruction* instruction) {
+  if (instruction->opcode() != HloOpcode::kDot) {
+    return false;
+  }
+  HloInstruction* lhs = instruction->mutable_operand(0);
+  HloInstruction* rhs = instruction->mutable_operand(1);
+
+  PrimitiveType lhs_type = lhs->shape().element_type();
+  PrimitiveType rhs_type = rhs->shape().element_type();
+
+  if (lhs_type == rhs_type) {
+    return false;
+  }
+
+  // Exclude conversions between FP8 types.
+  absl::flat_hash_set<PrimitiveType> non_converting = {F8E4M3FN, F8E5M2};
+  if (non_converting.contains(lhs_type) && non_converting.contains(rhs_type)) {
+    return false;
+  }
+
+  PrimitiveType desired_type =
+      ShapeUtil::HigherPrecisionElementType(lhs->shape(), rhs->shape());
+
+  return desired_type == lhs_type || desired_type == rhs_type;
+}
+
+absl::StatusOr<HloInstruction*> DotOperandConverter::ExpandInstruction(
+    HloInstruction* instruction) {
+  HloInstruction* lhs = instruction->mutable_operand(0);
+  HloInstruction* rhs = instruction->mutable_operand(1);
+
+  // Find the higher-precision type among the two operands, and add a convert
+  // instruction to cast the less precise operand to that type.
+  PrimitiveType desired_type =
+      ShapeUtil::HigherPrecisionElementType(lhs->shape(), rhs->shape());
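+  // If the LHS already has the desired (higher-precision) type, the RHS
+  // (operand 1) is the one to upcast; otherwise upcast the LHS (operand 0).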
+  int operand_index = desired_type == lhs->shape().element_type() ? 1 : 0;
+  HloInstruction* inst_to_replace =
+      desired_type == lhs->shape().element_type() ? rhs : lhs;
+  auto upcast_shape = inst_to_replace->shape();
+  upcast_shape.set_element_type(desired_type);
+  auto* convert_inst = instruction->AddInstruction(
+      HloInstruction::CreateConvert(upcast_shape, inst_to_replace));
+  TF_RETURN_IF_ERROR(instruction->ReplaceOperandWithDifferentShape(
+      operand_index, convert_inst));
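+  // Returning nullptr signals OpExpanderPass that the dot was updated in
+  // place and no root replacement is required.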
+  return nullptr;
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h
new file mode 100644
index 0000000..b269bed
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h
@@ -0,0 +1,46 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_OPERAND_CONVERTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_DOT_OPERAND_CONVERTER_H_
+
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/op_expander_pass.h"
+#include "xla/util.h"
+
+namespace xla::gpu {
+
+// Converts the lower-precision operand of a dot to the higher-precision
+// operand's element type, so that both operands match.
+class DotOperandConverter : public OpExpanderPass {
+ public:
+  explicit DotOperandConverter(HloPredicate extra_filter = nullptr)
+      : OpExpanderPass(std::move(extra_filter)) {}
+
+  absl::string_view name() const override { return "operand_converter"; }
+
+ protected:
+  bool InstructionMatchesPattern(HloInstruction* instruction) override;
+
+  absl::StatusOr<HloInstruction*> ExpandInstruction(
+      HloInstruction* instruction) override;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_DOT_OPERAND_CONVERTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc
new file mode 100644
index 0000000..be05b67
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc
@@ -0,0 +1,142 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_operand_converter.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/utils/hlo_matchers.h"
+#include "xla/primitive_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace op = ::xla::testing::opcode_matchers;
+
+class DotOperandConverterTest : public HloTestBase {
+ public:
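+  // Builds a dot with the given operand and result element types, runs
+  // DotOperandConverter, and checks that the less precise operand is wrapped
+  // in a convert to the other operand's type.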
+  void TestConvert(bool left_less_precise, PrimitiveType lhs_type,
+                   PrimitiveType rhs_type, PrimitiveType result_type) {
+    absl::string_view module_tmpl = R"(
+  HloModule module
+
+  ENTRY main {
+    p0 = $0[2,3]{1,0} parameter(0)
+    p1 = $1[3,2]{1,0} parameter(1)
+    ROOT dot = $2[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
+                                         rhs_contracting_dims={0}
+  })";
+    auto module_string = absl::Substitute(
+        module_tmpl, primitive_util::LowercasePrimitiveTypeName(lhs_type),
+        primitive_util::LowercasePrimitiveTypeName(rhs_type),
+        primitive_util::LowercasePrimitiveTypeName(result_type));
+    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                            ParseAndReturnVerifiedModule(module_string));
+    TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
+                            DotOperandConverter().Run(module.get()));
+    EXPECT_TRUE(upcasted);
+    if (left_less_precise) {
+      auto original_lhs = op::Parameter(0);
+      auto upcasted_lhs =
+          AllOf(op::Convert(original_lhs),
+                op::Shape(absl::Substitute(
+                    "$0[2,3]{1,0}",
+                    primitive_util::LowercasePrimitiveTypeName(rhs_type))));
+      EXPECT_THAT(
+          module->entry_computation()->root_instruction(),
+          AllOf(op::Dot(upcasted_lhs, op::Parameter(1)),
+                op::Shape(absl::Substitute(
+                    "$0[2,2]{1,0}",
+                    primitive_util::LowercasePrimitiveTypeName(result_type)))));
+    } else {
+      auto original_rhs = op::Parameter(1);
+      auto upcasted_rhs =
+          AllOf(op::Convert(original_rhs),
+                op::Shape(absl::Substitute(
+                    "$0[3,2]{1,0}",
+                    primitive_util::LowercasePrimitiveTypeName(lhs_type))));
+      EXPECT_THAT(
+          module->entry_computation()->root_instruction(),
+          AllOf(op::Dot(op::Parameter(0), upcasted_rhs),
+                op::Shape(absl::Substitute(
+                    "$0[2,2]{1,0}",
+                    primitive_util::LowercasePrimitiveTypeName(result_type)))));
+    }
+  }
+};
+
+TEST_F(DotOperandConverterTest, ConvertsLeftAndRight) {
+  TestConvert(/*left_less_precise=*/true, S8, BF16, F32);
+  TestConvert(/*left_less_precise=*/false, BF16, S8, F32);
+}
+
+TEST_F(DotOperandConverterTest, NoConvertHappensWithSameTypes) {
+  absl::string_view module_string = R"(
+  HloModule module
+
+  ENTRY main {
+    p0 = s8[2,3]{1,0} parameter(0)
+    p1 = s8[3,2]{1,0} parameter(1)
+    ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
+                                         rhs_contracting_dims={0}
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
+                          DotOperandConverter().Run(module.get()));
+  EXPECT_FALSE(upcasted);
+}
+
+TEST_F(DotOperandConverterTest, NoConvertFromF8toF8) {
+  absl::string_view module_string = R"(
+  HloModule module
+
+  ENTRY main {
+    p0 = f8e4m3fn[2,3]{1,0} parameter(0)
+    p1 = f8e5m2[3,2]{1,0} parameter(1)
+    ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
+                                         rhs_contracting_dims={0}
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool upcasted,
+                          DotOperandConverter().Run(module.get()));
+  EXPECT_FALSE(upcasted);
+}
+
+TEST_F(DotOperandConverterTest, CompilerOptimizesUsingDotOperandConverter) {
+  absl::string_view module_string = R"(
+  HloModule module
+
+  ENTRY main {
+    p0 = s8[2,3]{1,0} parameter(0)
+    p1 = bf16[3,2]{1,0} parameter(1)
+    ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
+                                         rhs_contracting_dims={0}
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          GetOptimizedModule(module_string));
+}
+
+}  // namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc
new file mode 100644
index 0000000..637689a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc
@@ -0,0 +1,110 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_sparsity_rewriter.h"
+
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class SparseDotRewriterImpl : public DfsHloRewriteVisitor {
+ public:
+  absl::Status HandleDot(HloInstruction* instr) override {
+    // Only handle sparse dots with a single RHS sparse descriptor.
+    HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
+    if (dot->sparse_operands() != 1 || dot->sparsity().front().index() != 1) {
+      return absl::OkStatus();
+    }
+
+    HloInstruction* lhs = dot->mutable_operand(0);
+    HloInstruction* rhs = dot->mutable_operand(1);
+    HloInstruction* meta = dot->mutable_operand(2);
+
+    // Swap LHS and RHS in the attributes.
+    DotDimensionNumbers dnums = dot->dot_dimension_numbers();
+    std::swap(*dnums.mutable_lhs_batch_dimensions(),
+              *dnums.mutable_rhs_batch_dimensions());
+    std::swap(*dnums.mutable_lhs_contracting_dimensions(),
+              *dnums.mutable_rhs_contracting_dimensions());
+
+    PrecisionConfig precision_config = dot->precision_config();
+    std::swap(precision_config.mutable_operand_precision()->at(0),
+              precision_config.mutable_operand_precision()->at(1));
+
+    SparsityDescriptor sparsity = dot->sparsity().front();
+    sparsity.set_index(0);
+
+    // Create new dot with LHS and RHS swapped.
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * new_dot,
+        MakeDotHlo(rhs, lhs, dnums, precision_config,
+                   dot->shape().element_type(), {std::move(sparsity)}, {meta}));
+    dot->SetupDerivedInstruction(new_dot);
+
+    // Result dimensions: <batch>, <rhs_noncontracting>, <lhs_noncontracting>
+    int batch_dims = dnums.lhs_batch_dimensions().size();
+    int new_lhs_noncontracting = rhs->shape().rank() - batch_dims -
+                                 dnums.lhs_contracting_dimensions().size();
+    int new_rhs_noncontracting = lhs->shape().rank() - batch_dims -
+                                 dnums.rhs_contracting_dimensions().size();
+
+    int rank = dot->shape().rank();
+    DimensionVector dimensions(rank);
+    for (int i = 0; i < batch_dims; ++i) {
+      dimensions[i] = i;
+    }
+    for (int i = 0; i < new_lhs_noncontracting; ++i) {
+      dimensions[i + batch_dims] = i + batch_dims + new_rhs_noncontracting;
+    }
+    for (int i = 0; i < new_rhs_noncontracting; ++i) {
+      dimensions[i + batch_dims + new_lhs_noncontracting] = i + batch_dims;
+    }
+
+    // Transpose the result.
+    TF_ASSIGN_OR_RETURN(HloInstruction * transpose,
+                        MakeTransposeHlo(new_dot, dimensions));
+    transpose->set_metadata(dot->metadata());
+    *transpose->mutable_shape()->mutable_layout() = dot->shape().layout();
+
+    return ReplaceInstruction(dot, transpose);
+  }
+};
+
+}  // namespace
+
+absl::StatusOr<bool> DotSparsityRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  return SparseDotRewriterImpl().RunOnModule(module, execution_threads);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.h b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.h
new file mode 100644
index 0000000..b912e2b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.h
@@ -0,0 +1,42 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_SPARSITY_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_DOT_SPARSITY_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Make sure sparse dot requirements are met (sparse operand is LHS).
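+// If the sparse descriptor is attached to the RHS, the dot is rewritten as
+// dot(rhs, lhs) with swapped dimension numbers, followed by a transpose that
+// restores the original result layout.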
+class DotSparsityRewriter : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "dot_sparsity_rewriter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_DOT_SPARSITY_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter_test.cc
new file mode 100644
index 0000000..28f813f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter_test.cc
@@ -0,0 +1,85 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_sparsity_rewriter.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::testing::ElementsAre;
+
+class DotSparsityRewriterTest : public HloTestBase {
+ public:
+  DotSparsityRewriterTest() : HloTestBase(/*verifier_layout_sensitive=*/true) {}
+};
+
+TEST_F(DotSparsityRewriterTest, SparseDotRhsToLhs) {
+  const char* module_string = R"(
+HloModule m
+
+ENTRY e {
+  lhs = f16[4,2,16,8,64] parameter(0)
+  rhs = f16[2,4,8,32,128] parameter(1)
+  meta = u16[2,4,8,4,128] parameter(2)
+  ROOT dot = f16[4,2,16,128] dot(lhs, rhs, meta),
+    lhs_contracting_dims={3,4}, rhs_contracting_dims={2,3},
+    lhs_batch_dims={0,1}, rhs_batch_dims={1,0}, sparsity=R.3@2:4
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool modified,
+                          DotSparsityRewriter().Run(module.get()));
+  EXPECT_TRUE(modified);
+
+  const HloTransposeInstruction* transpose = DynCast<HloTransposeInstruction>(
+      module->entry_computation()->root_instruction());
+  ASSERT_TRUE(transpose != nullptr);
+  EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 1, 3, 2));
+
+  const HloDotInstruction* dot =
+      DynCast<HloDotInstruction>(transpose->operand(0));
+  ASSERT_TRUE(dot != nullptr);
+
+  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
+  EXPECT_EQ(dnums.lhs_contracting_dimensions(0), 2);
+  EXPECT_EQ(dnums.lhs_contracting_dimensions(1), 3);
+  EXPECT_EQ(dnums.rhs_contracting_dimensions(0), 3);
+  EXPECT_EQ(dnums.rhs_contracting_dimensions(1), 4);
+  EXPECT_EQ(dnums.lhs_batch_dimensions(0), 1);
+  EXPECT_EQ(dnums.lhs_batch_dimensions(1), 0);
+  EXPECT_EQ(dnums.rhs_batch_dimensions(0), 0);
+  EXPECT_EQ(dnums.rhs_batch_dimensions(1), 1);
+
+  EXPECT_EQ(dot->sparse_operands(), 1);
+  EXPECT_EQ(dot->sparsity().front().index(), 0);
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc
new file mode 100644
index 0000000..a4f901f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc
@@ -0,0 +1,569 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <iterator>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instruction_utils.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/flatten_call_graph.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+void SetChannelIdForNewCollective(HloInstruction* new_instr,
+                                  const HloModule* module) {
+  // This is to track mappings of old->new channel id for async collectives
+  // wrapped in the form of HloAsyncInstruction, the start and done need to
+  // have the same unique channel id.
+  absl::flat_hash_map<int64_t, int64_t> old_to_new_channel_id_map;
+  absl::flat_hash_map<int64_t, HloComputation*> channel_id_comp_map;
+  if (new_instr->IsAsynchronous() && hlo_query::IsCollectiveCommunicationOp(
+                                         new_instr->async_wrapped_opcode())) {
+    HloInstruction* wrapped_instr =
+        DynCast<HloAsyncInstruction>(new_instr)->async_wrapped_instruction();
+    int64_t old_channel_id = *wrapped_instr->channel_id();
+    int64_t new_channel_id = old_to_new_channel_id_map[old_channel_id];
+    if (old_to_new_channel_id_map.find(old_channel_id) ==
+        old_to_new_channel_id_map.end()) {
+      new_channel_id = hlo_query::NextChannelId(*module);
+      VLOG(2) << "Generated new channel id " << new_channel_id;
+      old_to_new_channel_id_map[old_channel_id] = new_channel_id;
+    }
+
+    VLOG(2) << "Setting channel id to " << new_channel_id;
+
+    wrapped_instr->set_channel_id(new_channel_id);
+    if (channel_id_comp_map.find(new_channel_id) == channel_id_comp_map.end()) {
+      channel_id_comp_map[new_channel_id] =
+          new_instr->async_wrapped_computation();
+    } else {
+      channel_id_comp_map[new_channel_id]->AddAsyncStart(new_instr);
+    }
+  } else if (hlo_query::IsCollectiveCommunicationOp(new_instr->opcode()) ||
+             hlo_query::IsAsyncCollectiveStartOp(new_instr)) {
+    new_instr->set_channel_id(hlo_query::NextChannelId(*module));
+  }
+}
+
+using Interval = std::pair<int64_t, int64_t>;
+
+// Parses a string of the format `{{a,b},{c,d},{e,f}...}` to a vector of pairs.
+absl::StatusOr<std::vector<Interval>> ParseVectorOfPairs(
+    absl::string_view str) {
+  TF_ASSIGN_OR_RETURN(std::vector<ReplicaGroup> replica_groups,
+                      ParseReplicaGroupsOnly(str));
+  std::vector<Interval> res;
+  res.reserve(replica_groups.size());
+  for (const ReplicaGroup& replica_group : replica_groups) {
+    TF_RET_CHECK(replica_group.replica_ids_size() == 2);
+    int64_t a = replica_group.replica_ids(0);
+    int64_t b = replica_group.replica_ids(1);
+    res.emplace_back(a, b);
+  }
+  return res;
+}
+
+// This function fixes the `_xla_send_recv_validation` attribute for peeled
+// instructions. When the loop trip count is odd, the peeled instructions are
+// moved before the loop. The collectives in these instructions correspond to
+// the first iteration of the original loop. We have to run this peeled
+// collective for all those devices that had the 0-th iteration as a valid
+// iteration.
+absl::Status SetSendRecvValidationForPeeledInstr(HloInstruction* new_instr,
+                                                 HloInstruction* old_instr) {
+  TF_RET_CHECK(new_instr->opcode() == old_instr->opcode())
+      << "cloned instruction and original instruction have different opcodes";
+  if (!HloPredicateIsOp<HloOpcode::kCollectivePermute,
+                        HloOpcode::kCollectivePermuteStart, HloOpcode::kSend,
+                        HloOpcode::kRecv>(old_instr)) {
+    return absl::OkStatus();
+  }
+
+  const auto& attribute_map = new_instr->frontend_attributes().map();
+  if (!attribute_map.contains(kSendRecvValidationAttr)) {
+    return absl::OkStatus();
+  }
+
+  VLOG(3) << "Original send-recv iterations: "
+          << attribute_map.at(kSendRecvValidationAttr);
+
+  TF_ASSIGN_OR_RETURN(
+      auto send_recv_validation_attr,
+      ParseVectorOfPairs(attribute_map.at(kSendRecvValidationAttr)));
+
+  uint64_t n_pairs = send_recv_validation_attr.size();
+  if (n_pairs == 0) {
+    return absl::OkStatus();
+  }
+  std::vector<Interval> send_recv_validation_attr_updated(n_pairs, {1, 0});
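+  // {1, 0} is an empty interval, i.e. by default the peeled collective does
+  // not run on a device.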
+  // Check which of the attributes have iteration number zero as valid
+  // iteration. For all those, set the peeled instruction to run.
+  for (std::uint64_t i = 0; i < send_recv_validation_attr.size(); i++) {
+    if (send_recv_validation_attr[i].first <= 0 &&
+        send_recv_validation_attr[i].second >= 0) {
+      send_recv_validation_attr_updated[i] = {0, 0};
+    }
+  }
+
+  hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute(
+      /*instr=*/new_instr, /*attr_name=*/kSendRecvValidationAttr,
+      /*intervals=*/send_recv_validation_attr_updated);
+  return absl::OkStatus();
+}
+
+// This function fixes the `_xla_send_recv_validation` attribute for the two new
+// collectives inside the loop. The calculation of the new valid iterations
+// depends on whether the loop was peeled or not.
+//
+// If the loop was not peeled, then
+//  - iteration 0 of the new loop corresponds to iterations 0,1 of the old loop.
+//  - iteration 1 of the new loop corresponds to iterations 2,3 of the old loop.
+//  - and so on...
+// If the loop was peeled, then the first iteration runs before the loop. So,
+//  - iteration 0 of the new loop corresponds to iterations 1,2 of the old loop.
+//  - iteration 1 of the new loop corresponds to iterations 3,4 of the old loop.
+//  - and so on...
+//
+// Consider the case when the loop was peeled, and the original attribute for
+// some device was {4,7}. Consider that the two new collectives are
+// `collective.1` and `collective.2` (they execute in this order inside the new
+// loop). In the old loop, iterations 4,5,6,7 were valid. In the new
+// loop,
+//  - collective.2 in iteration 1 of new loop runs 4th iteration of old loop.
+//  - collective.1 in iteration 2 of new loop runs 5th iteration of old loop.
+//  - collective.2 in iteration 2 of new loop runs 6th iteration of old loop.
+//  - collective.1 in iteration 3 of new loop runs 7th iteration of old loop.
+// So, the updated attributes for that device are {1,2} for `collective.2` and
+// {2,3} for `collective.1`.
+//
+// In a similar fashion we can generalize the computation of new values based on
+// the values of the old attribute as done in the logic below.
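+//
+// In terms of the code below, for an old interval {a, b}:
+//  - not peeled: cp1 gets {floor((a+1)/2), floor(b/2)} and
+//    cp2 gets {floor(a/2), max(0, floor((b-1)/2))};
+//  - peeled:     cp1 gets {floor(a/2), max(0, floor((b-1)/2))} and
+//    cp2 gets {max(0, floor((a-1)/2)), max(0, floor((b-2)/2))}.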
+absl::Status SetSendRecvValidation(HloInstruction* cp1, HloInstruction* cp2,
+                                   bool is_peeled) {
+  TF_RET_CHECK(cp2->opcode() == cp1->opcode())
+      << "cloned instruction and original instruction have different opcodes";
+  if (!HloPredicateIsOp<HloOpcode::kCollectivePermute,
+                        HloOpcode::kCollectivePermuteStart, HloOpcode::kSend,
+                        HloOpcode::kRecv>(cp1)) {
+    return absl::OkStatus();
+  }
+  const auto& attribute_map = cp2->frontend_attributes().map();
+  if (!attribute_map.contains(kSendRecvValidationAttr)) {
+    return absl::OkStatus();
+  }
+  VLOG(3) << "Original send-recv iterations: "
+          << attribute_map.at(kSendRecvValidationAttr);
+
+  TF_ASSIGN_OR_RETURN(
+      auto send_recv_validation_attr,
+      ParseVectorOfPairs(attribute_map.at(kSendRecvValidationAttr)));
+
+  if (send_recv_validation_attr.size() == 0) {
+    return absl::OkStatus();
+  }
+
+  std::vector<Interval> send_recv_iterations_new_instr1,
+      send_recv_iterations_new_instr2;
+  send_recv_iterations_new_instr1.reserve(send_recv_validation_attr.size());
+  send_recv_iterations_new_instr2.reserve(send_recv_validation_attr.size());
+  for (const Interval& pair : send_recv_validation_attr) {
+    int64_t a = pair.first;
+    int64_t b = pair.second;
+    if (is_peeled) {
+      send_recv_iterations_new_instr1.emplace_back(
+          std::floor(a / 2.0), std::max(0.0, std::floor((b - 1) / 2.0)));
+      send_recv_iterations_new_instr2.emplace_back(
+          std::max(0.0, std::floor((a - 1) / 2.0)),
+          std::max(0.0, std::floor((b - 2) / 2.0)));
+    } else {
+      send_recv_iterations_new_instr1.emplace_back(std::floor((a + 1) / 2.0),
+                                                   std::floor(b / 2.0));
+      send_recv_iterations_new_instr2.emplace_back(
+          std::floor(a / 2.0), std::max(0.0, std::floor((b - 1) / 2.0)));
+    }
+  }
+
+  hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute(
+      /*instr=*/cp1, /*attr_name=*/kSendRecvValidationAttr,
+      /*intervals=*/send_recv_iterations_new_instr1);
+  hlo_instruction_utils::AddOrUpdateVectorOfPairsAsAttribute(
+      /*instr=*/cp2, /*attr_name=*/kSendRecvValidationAttr,
+      /*intervals=*/send_recv_iterations_new_instr2);
+
+  VLOG(3) << "Updated send-recv iterations for " << cp1->name() << " : "
+          << cp1->frontend_attributes().map().at(kSendRecvValidationAttr);
+  VLOG(3) << "Updated send-recv iterations for " << cp2->name() << " : "
+          << cp2->frontend_attributes().map().at(kSendRecvValidationAttr);
+  return absl::OkStatus();
+}
+
+// Handle control predecessors/successors for every old-new instruction pair.
+// For every new instruction, we find the relevant predecessor/successor
+// relationships of the old instruction and we reconstruct them by looking up
+// new (already created) predecessors/successors.
+//
+// When rewiring dependencies from output of the original body, to the input of
+// the cloned body we skip collectives, and ops in `skip_control_dep_injection`.
+absl::Status HandleControlDependencies(
+    const HloComputation* while_body,
+    const absl::flat_hash_map<HloInstruction*, HloInstruction*>& old_to_new_map,
+    HloInstruction::InstructionVector* old_loop_roots,
+    HloInstruction* input_parameter,
+    const absl::flat_hash_set<HloInstruction*>& skip_control_dep_injection) {
+  for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
+    if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
+      HloInstruction* new_instr = old_to_new_map.at(old_instr);
+      VLOG(2) << "Processing control predecessors for "
+              << new_instr->ToString();
+      std::vector<HloInstruction*> new_control_pred;
+      new_control_pred.reserve(old_instr->control_predecessors().size());
+      for (HloInstruction* pred : old_instr->control_predecessors()) {
+        if (!old_to_new_map.contains(pred)) {
+          continue;
+        }
+        new_control_pred.push_back(old_to_new_map.at(pred));
+      }
+
+      TF_RETURN_IF_ERROR(new_instr->DropAllControlDeps());
+      for (HloInstruction* new_pred : new_control_pred) {
+        TF_RETURN_IF_ERROR(new_pred->AddControlDependencyTo(new_instr));
+        VLOG(2) << "Adding " << new_pred->ToString()
+                << " to control dependency of " << new_instr->ToString();
+      }
+    }
+  }
+  for (HloInstruction* input_consumer : input_parameter->users()) {
+    for (HloInstruction* old_input : input_consumer->users()) {
+      if (old_to_new_map.find(old_input) != old_to_new_map.end()) {
+        HloInstruction* new_input = old_to_new_map.at(old_input);
+        if (skip_control_dep_injection.find(old_input) ==
+                skip_control_dep_injection.end() &&
+            !IsCollective(old_input)) {
+          for (HloInstruction* old_root : *old_loop_roots) {
+            TF_RETURN_IF_ERROR(old_root->AddControlDependencyTo(new_input));
+          }
+        }
+      }
+    }
+  }
+
+  return absl::OkStatus();
+}
+
+absl::StatusOr<bool> FullyUnroll(HloInstruction* while_instr,
+                                 HloModule* module) {
+  HloComputation* while_body = while_instr->while_body();
+  bool changed = false;
+  VLOG(2) << "Processing root " << while_body->root_instruction()->ToString();
+
+  auto loop_roots = while_body->root_instruction()->mutable_operands();
+  HloInstruction* input_parameter = while_body->parameter_instruction(0);
+  VLOG(2) << "Processing input parameter " << input_parameter->ToString();
+
+  absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_map;
+  absl::flat_hash_set<HloInstruction*> skip_control_dep_injection;
+  std::string clone_suffix = "full_unroll_clone";
+
+  TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config,
+                      while_instr->backend_config<WhileLoopBackendConfig>());
+  std::vector<HloInstruction*> ops_to_clone;
+  ops_to_clone.reserve(while_body->MakeInstructionPostOrder().size());
+
+  // Pre-loop prep.
+  HloInstruction* old_input_parameter = input_parameter;
+  HloInstruction* new_input_parameter = while_body->root_instruction();
+  absl::flat_hash_set<HloInstruction*> seen_ops;
+  for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
+    if (seen_ops.contains(old_instr)) {
+      continue;
+    }
+    ops_to_clone.push_back(old_instr);
+    seen_ops.insert(old_instr);
+  }
+
+  int n = config.known_trip_count().n();
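+  // Clone the body n-1 times; after each round the freshly produced root
+  // tuple becomes the "parameter" feeding the next round of clones, so the
+  // body ends up containing n copies of the original iteration.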
+  while (--n) {
+    std::vector<HloInstruction*> new_ops_to_clone;
+    old_to_new_map[old_input_parameter] = new_input_parameter;
+    for (HloInstruction* old_instr : ops_to_clone) {
+      if (old_to_new_map.contains(old_instr)) {
+        continue;
+      }
+      VLOG(2) << "Cloning instruction " << old_instr->ToString();
+      std::vector<HloInstruction*> new_operands;
+      for (HloInstruction* old_operand : old_instr->mutable_operands()) {
+        new_operands.push_back(old_to_new_map[old_operand]);
+      }
+      HloInstruction* new_instr =
+          while_body->AddInstruction(old_instr->CloneWithNewOperands(
+              old_instr->shape(), new_operands, clone_suffix));
+
+      // If an elementwise instruction with constant operand is present, we
+      // won't inject control dependency at the end to allow more constant
+      // folding opportunities.
+      if (old_instr->IsElementwiseBinary() && old_instr->HasConstantOperand()) {
+        skip_control_dep_injection.insert(old_instr);
+      }
+      SetChannelIdForNewCollective(new_instr, module);
+      old_to_new_map[old_instr] = new_instr;
+      new_ops_to_clone.push_back(new_instr);
+      VLOG(2) << "Added instruction " << new_instr->ToString();
+    }
+
+    while_body->set_root_instruction(
+        old_to_new_map[while_body->root_instruction()]);
+    VLOG(2) << "Replaced with new root "
+            << while_body->root_instruction()->ToString();
+
+    TF_RETURN_IF_ERROR(HandleControlDependencies(
+        while_body, old_to_new_map, &loop_roots, old_input_parameter,
+        skip_control_dep_injection));
+
+    // Inductive step update, clean/update necessary buffers to prepare them for
+    // the next unrolling iteration.
+    old_to_new_map.clear();
+    skip_control_dep_injection.clear();
+    loop_roots = while_body->root_instruction()->mutable_operands();
+    old_input_parameter = new_input_parameter;
+    new_input_parameter = while_body->root_instruction();
+    ops_to_clone = std::move(new_ops_to_clone);
+    changed = true;
+  }
+
+  WhileLoopBackendConfig new_config;
+  new_config.mutable_known_trip_count()->set_n(1);
+  TF_RETURN_IF_ERROR(while_instr->set_backend_config(new_config));
+
+  return changed;
+}
+
+absl::Status PeelInstructionsForOddTripCount(HloModule* module,
+                                             HloInstruction* while_instr) {
+  std::string suffix = "peeled_double_buffer";
+  absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_map;
+  HloComputation* while_body = while_instr->while_body();
+  HloInstruction* input_parameter = while_body->parameter_instruction(0);
+  HloInstruction* input_tuple = while_instr->mutable_operand(0);
+
+  auto old_loop_roots = while_body->root_instruction()->mutable_operands();
+  HloComputation* parent_comp = while_instr->parent();
+  old_to_new_map[input_parameter] = input_tuple;
+
+  for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
+    if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
+      continue;
+    }
+    VLOG(2) << "Peeling instruction " << old_instr->ToString();
+    std::vector<HloInstruction*> new_operands(old_instr->operand_count());
+    for (int64_t i = 0; i < old_instr->operand_count(); i++) {
+      new_operands[i] = old_to_new_map[old_instr->mutable_operand(i)];
+    }
+    HloInstruction* new_instr =
+        parent_comp->AddInstruction(old_instr->CloneWithNewOperands(
+            old_instr->shape(), new_operands, suffix));
+
+    SetChannelIdForNewCollective(new_instr, module);
+    TF_CHECK_OK(SetSendRecvValidationForPeeledInstr(new_instr, old_instr));
+    old_to_new_map[old_instr] = new_instr;
+    VLOG(2) << "Added instruction " << new_instr->ToString()
+            << " to parent computation.";
+  }
+
+  std::vector<HloInstruction*> new_roots;
+  for (HloInstruction* instr : old_loop_roots) {
+    new_roots.push_back(old_to_new_map[instr]);
+  }
+  TF_RETURN_IF_ERROR(while_instr->ReplaceOperandWith(
+      0, old_to_new_map[while_body->root_instruction()]));
+  VLOG(2) << "Replaced with new input tuple "
+          << while_instr->operand(0)->ToString();
+
+  // Handle existing control dependencies.
+  for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
+    if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
+      HloInstruction* new_instr = old_to_new_map[old_instr];
+      VLOG(2) << "Processing control predecessors for peeled instruction "
+              << new_instr->ToString();
+      std::vector<HloInstruction*> new_control_pred;
+      new_control_pred.reserve(old_instr->control_predecessors().size());
+      for (HloInstruction* pred : old_instr->control_predecessors()) {
+        new_control_pred.push_back(old_to_new_map[pred]);
+      }
+
+      TF_RETURN_IF_ERROR(new_instr->DropAllControlDeps());
+      for (HloInstruction* new_pred : new_control_pred) {
+        TF_RETURN_IF_ERROR(new_pred->AddControlDependencyTo(new_instr));
+        VLOG(2) << "Adding " << new_pred->ToString()
+                << " to control dependency of peeled instruction: "
+                << new_instr->ToString();
+      }
+    }
+  }
+  return absl::OkStatus();
+}
+
+// TODO(olechwierowicz): Extract common logic of this and `FullyUnroll` to
+// a separate function.
+absl::StatusOr<bool> DoubleBufferingUnroll(HloInstruction* while_instr,
+                                           HloModule* module) {
+  TF_ASSIGN_OR_RETURN(auto config,
+                      while_instr->backend_config<WhileLoopBackendConfig>());
+
+  CHECK(config.has_known_trip_count())
+      << "Only loops with known trip count are supported.";
+  int64_t exact_trip_count = config.known_trip_count().n();
+  VLOG(2) << "Processing while loop " << while_instr->ToString()
+          << " with trip count: " << exact_trip_count;
+
+  HloComputation* while_body = while_instr->while_body();
+
+  VLOG(2) << "Processing root " << while_body->root_instruction()->ToString();
+
+  auto old_loop_roots = while_body->root_instruction()->mutable_operands();
+  HloInstruction* input_parameter = while_body->parameter_instruction(0);
+  VLOG(2) << "Processing input parameter " << input_parameter->ToString();
+  absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_map;
+  absl::flat_hash_set<HloInstruction*> skip_control_dep_injection;
+
+  bool is_peeled = exact_trip_count % 2;
+  if (is_peeled) {
+    VLOG(2) << "Found loops with odd trip count, 1 iteration will be peeled "
+               "outside of the main body.";
+    TF_RETURN_IF_ERROR(PeelInstructionsForOddTripCount(module, while_instr));
+    exact_trip_count -= 1;
+  }
+
+  std::string suffix = "double_buffer_clone";
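+  // The current root tuple acts as the "parameter" of the cloned iteration,
+  // so each cloned instruction reads the previous iteration's outputs.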
+  old_to_new_map[input_parameter] = while_body->root_instruction();
+  for (HloInstruction* old_instr : while_body->MakeInstructionPostOrder()) {
+    if (old_to_new_map.find(old_instr) != old_to_new_map.end()) {
+      continue;
+    }
+    VLOG(2) << "Cloning instruction " << old_instr->ToString();
+    std::vector<HloInstruction*> new_operands;
+    for (HloInstruction* old_operand : old_instr->mutable_operands()) {
+      new_operands.push_back(old_to_new_map[old_operand]);
+    }
+    HloInstruction* new_instr =
+        while_body->AddInstruction(old_instr->CloneWithNewOperands(
+            old_instr->shape(), new_operands, suffix));
+
+    // If an elementwise instruction with constant operand is present, we
+    // won't inject control dependency at the end to allow more constant
+    // folding opportunities.
+    if (old_instr->IsElementwiseBinary() && old_instr->HasConstantOperand()) {
+      skip_control_dep_injection.insert(old_instr);
+    }
+    SetChannelIdForNewCollective(new_instr, module);
+    TF_CHECK_OK(SetSendRecvValidation(old_instr, new_instr, is_peeled));
+    old_to_new_map[old_instr] = new_instr;
+    VLOG(2) << "Added instruction " << new_instr->ToString();
+  }
+
+  while_body->set_root_instruction(
+      old_to_new_map[while_body->root_instruction()]);
+  VLOG(2) << "Replaced with new root "
+          << while_body->root_instruction()->ToString();
+
+  // Handle existing control dependencies.
+  TF_RETURN_IF_ERROR(HandleControlDependencies(while_body, old_to_new_map,
+                                               &old_loop_roots, input_parameter,
+                                               skip_control_dep_injection));
+
+  WhileLoopBackendConfig new_config;
+  new_config.mutable_known_trip_count()->set_n(exact_trip_count / 2);
+  TF_RETURN_IF_ERROR(while_instr->set_backend_config(new_config));
+  return true;  // changed
+}
+
+}  // namespace
+
+absl::StatusOr<bool> DoubleBufferLoopUnrolling::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  std::vector<HloInstruction*> while_instrs;
+  for (auto comp : module->MakeNonfusionComputations()) {
+    absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
+                    HloPredicateIsOp<HloOpcode::kWhile>);
+  }
+  VLOG(2) << "Processing " << while_instrs.size() << " while loops.";
+
+  for (HloInstruction* while_instr : while_instrs) {
+    TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config,
+                        while_instr->backend_config<WhileLoopBackendConfig>());
+    if (!config.has_known_trip_count() || config.known_trip_count().n() == 1) {
+      VLOG(2) << while_instr->ToString()
+              << " doesn't have exact trip count, skipping loop unrolling "
+                 "for now";
+      continue;
+    }
+
+    if (unroll_strategy_ == UnrollStrategy::kFullUnroll) {
+      TF_ASSIGN_OR_RETURN(changed, FullyUnroll(while_instr, module));
+    } else if (unroll_strategy_ == UnrollStrategy::kDoubleBuffer) {
+      TF_ASSIGN_OR_RETURN(changed, DoubleBufferingUnroll(while_instr, module));
+    } else {
+      LOG(FATAL) << absl::StrCat("Unhandled unrolling strategy: ",
+                                 unroll_strategy_);
+    }
+  }
+
+  VLOG(2) << "LoopDoubleBufferTransformer output: " << module->ToString();
+
+  // Run necessary cleanup to ensure LoopDoubleBufferTransformer behaves
+  // correctly.
+  if (changed) {
+    // The call graph will not be flat if one of the loops that was unrolled
+    // contains any kind of call to another computation---since the call will
+    // be duplicated, thereby adding a second callsite for that computation.
+    TF_RETURN_IF_ERROR(
+        FlattenCallGraph().Run(module, execution_threads).status());
+  }
+
+  return changed;
+}
+
+}  // end namespace gpu
+}  // end namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h
new file mode 100644
index 0000000..26bb178
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h
@@ -0,0 +1,75 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOUBLE_BUFFER_LOOP_UNROLLING_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_DOUBLE_BUFFER_LOOP_UNROLLING_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// With `kDoubleBuffer` strategy:
+//   This pass performs the unrolling-by-2 loop transformation
+//   to effectively achieve double buffering between inputs and outputs
+//   of previously rolled iterations.
+//   This pass only runs on loops with known trip counts.
+//   For an even number of iterations, unrolling-by-2 is done directly.
+//   For an odd number of iterations, the first iteration of the loop is
+//   peeled outside of the while loop to make the trip count even, and the
+//   loop is then unrolled by 2.
+//   It also updates the trip count property of the loop to the correct value
+//   (n/2).
+//
+// With `kFullUnroll` strategy:
+//   This pass fully unrolls the loop, applying the same cloning strategy as
+//   `kDoubleBuffer` but repeating it until the body contains trip-count
+//   copies of the original loop body.
+//   It updates the trip count of the while loop to 1, and relies on other
+//   passes (like `WhileLoopSimplifier`) to simplify/get rid of the while loop
+//   eventually.
+//
+// Note that this pass will flatten the call graph if any loop has been
+// unrolled.
+// TODO(olechwierowicz): Rename the loop unroller to something more generic like
+// 'DoubleBufferLoopUnrolling'.
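+//
+// Example usage (a sketch mirroring the accompanying tests, where `module`
+// is assumed to be an HloModule*):
+//   DoubleBufferLoopUnrolling unroller(
+//       DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+//   TF_ASSIGN_OR_RETURN(bool changed, unroller.Run(module));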
+class DoubleBufferLoopUnrolling : public HloModulePass {
+ public:
+  enum class UnrollStrategy { kDoubleBuffer, kFullUnroll };
+
+  explicit DoubleBufferLoopUnrolling(
+      UnrollStrategy unroll_strategy = UnrollStrategy::kDoubleBuffer)
+      : unroll_strategy_(unroll_strategy) {};
+  ~DoubleBufferLoopUnrolling() override = default;
+
+  absl::string_view name() const override {
+    return "loop-double-buffer-transformer";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  UnrollStrategy unroll_strategy_;
+};
+
+}  // end namespace gpu
+}  // end namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_DOUBLE_BUFFER_LOOP_UNROLLING_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc
new file mode 100644
index 0000000..05e704d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc
@@ -0,0 +1,1243 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h"
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+
+#include "absl/container/flat_hash_set.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/tuple_simplifier.h"
+#include "xla/test.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using tsl::testing::IsOkAndHolds;
+
+int64_t CountInstructions(const HloComputation& computation, HloOpcode opcode) {
+  int64_t count = 0;
+  for (const auto& instruction : computation.instructions()) {
+    if (instruction->opcode() == opcode) {
+      count++;
+    }
+  }
+  return count;
+}
+
+int64_t CountInstructions(const HloModule& module, HloOpcode opcode) {
+  int64_t count = 0;
+  for (const auto& computation : module.computations()) {
+    count += CountInstructions((*computation), opcode);
+  }
+  return count;
+}
+
+class GpuLoopDoubleBufferTransformerTest : public HloTestBase {
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_enable_while_loop_double_buffering(true);
+    return debug_options;
+  }
+};
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, FullUnrollOddTripCountTest) {
+  const char* const kModuleString = R"(
+HloModule all_gather_overlapping
+condition {
+  input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=3
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+ param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
+ param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
+ cond = s32[] get-tuple-element(input_tuple), index=3
+ c0 = f32[] constant(0)
+ splat_c0 = f32[1,128] broadcast(c0), dimensions={}
+ add = f32[1,128] add(splat_c0, param_0)
+ all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
+ all-gather-done = f32[2,128] all-gather-done(all-gather-start)
+ ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[1,128] parameter(0)
+ param_1 = f32[2,128] parameter(1)
+ param_2 = s32[] constant(0)
+ tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
+ ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+  TupleSimplifier tuple_simp;
+  bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get()));
+  EXPECT_TRUE(changed);
+  TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get()));
+  EXPECT_TRUE(changed);
+  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+      *module->entry_computation(), HloOpcode::kWhile);
+  TF_ASSERT_OK_AND_ASSIGN(
+      WhileLoopBackendConfig config,
+      while_instruction->backend_config<WhileLoopBackendConfig>());
+  int64_t exact_trip_count = config.known_trip_count().n();
+  EXPECT_EQ(exact_trip_count, 1);
+  EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
+                              HloOpcode::kAllGatherStart),
+            11);
+  EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 11);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, FullUnrollEvenTripCountTest) {
+  const char* const kModuleString = R"(
+HloModule all_gather_overlapping
+condition {
+  input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=3
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+ param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
+ param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
+ cond = s32[] get-tuple-element(input_tuple), index=3
+ c0 = f32[] constant(0)
+ splat_c0 = f32[1,128] broadcast(c0), dimensions={}
+ add = f32[1,128] add(splat_c0, param_0)
+ // Start all-gather communication
+ all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
+ all-gather-done = f32[2,128] all-gather-done(all-gather-start)
+ ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[1,128] parameter(0)
+ param_1 = f32[2,128] parameter(1)
+ param_2 = s32[] constant(0)
+ tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
+ ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+  TupleSimplifier tuple_simp;
+  bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get()));
+  EXPECT_TRUE(changed);
+  TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+      *module->entry_computation(), HloOpcode::kWhile);
+  TF_ASSERT_OK_AND_ASSIGN(
+      WhileLoopBackendConfig config,
+      while_instruction->backend_config<WhileLoopBackendConfig>());
+  int64_t exact_trip_count = config.known_trip_count().n();
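+  // Full unrolling should collapse the loop to a single trip.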
+  EXPECT_EQ(exact_trip_count, 1);
+  EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
+                              HloOpcode::kAllGatherStart),
+            10);
+  EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 10);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, UnrolledLoopEvenTripCount) {
+  const char* const kModuleString = R"(
+HloModule all_gather_overlapping
+condition {
+  input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=3
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+ param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
+ param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
+ cond = s32[] get-tuple-element(input_tuple), index=3
+ c0 = f32[] constant(0)
+ splat_c0 = f32[1,128] broadcast(c0), dimensions={}
+ add = f32[1,128] add(splat_c0, param_0)
+ // Start all-gather communication
+ all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
+ // Intertwined with the all-gather communication, an operation happens which
+ // depends on param_1, but crucially has a different output shape (which
+ // excludes reusing param_1's buffer for its output).
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
+ // The all-gather communication finishes
+ all-gather-done = f32[2,128] all-gather-done(all-gather-start)
+ ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[1,128] parameter(0)
+ param_1 = f32[2,128] parameter(1)
+ param_2 = s32[] constant(0)
+ tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
+ ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer;
+  TupleSimplifier tuple_simp;
+  bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get()));
+  EXPECT_TRUE(changed);
+  TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+      *module->entry_computation(), HloOpcode::kWhile);
+  TF_ASSERT_OK_AND_ASSIGN(
+      WhileLoopBackendConfig config,
+      while_instruction->backend_config<WhileLoopBackendConfig>());
+  int64_t exact_trip_count = config.known_trip_count().n();
+  // After unrolling, the trip count should be halved (10 -> 5).
+  EXPECT_EQ(exact_trip_count, 5);
+  // After unrolling, there should be 2 all-gather-starts, both in the while
+  // body.
+  EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
+                              HloOpcode::kAllGatherStart),
+            2);
+  EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 2);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, UnrolledLoopOddTripCount) {
+  const char* const kModuleString = R"(
+HloModule all_gather_overlapping
+condition {
+  input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=3
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) parameter(0)
+ param_0 = f32[1,128] get-tuple-element(input_tuple), index=0
+ param_1 = f32[2,128] get-tuple-element(input_tuple), index=2
+ cond = s32[] get-tuple-element(input_tuple), index=3
+ c0 = f32[] constant(0)
+ splat_c0 = f32[1,128] broadcast(c0), dimensions={}
+ add = f32[1,128] add(splat_c0, param_0)
+ // Start all-gather communication
+ all-gather-start = (f32[1,128], f32[2,128]) all-gather-start(add), channel_id=1337, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true
+ // Intertwined with the all-gather communication, an operation happens which
+ // depends on param_1, but crucially has a different output shape (which
+ // excludes reusing param_1's buffer for its output).
+ c1_s32 = s32[] constant(1)
+ c0_s32 = s32[] constant(0)
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ dynamic-slice = f32[1,128] dynamic-slice(param_1, c1_s32, c0_s32), dynamic_slice_sizes={1,128}
+ // The all-gather communication finishes
+ all-gather-done = f32[2,128] all-gather-done(all-gather-start)
+ ROOT output_tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, dynamic-slice, all-gather-done, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[1,128] parameter(0)
+ param_1 = f32[2,128] parameter(1)
+ param_2 = s32[] constant(0)
+ tuple = (f32[1,128], f32[1,128], f32[2,128], s32[]) tuple(param_0, param_0, param_1, param_2)
+ ROOT while = (f32[1,128], f32[1,128], f32[2,128], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer;
+  TupleSimplifier tuple_simp;
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
+
+  // We expect that for the while loop, no further copy needs to be added to the
+  // module.
+  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+      *module->entry_computation(), HloOpcode::kWhile);
+  TF_ASSERT_OK_AND_ASSIGN(
+      WhileLoopBackendConfig config,
+      while_instruction->backend_config<WhileLoopBackendConfig>());
+  int64_t exact_trip_count = config.known_trip_count().n();
+  // After peeling one iteration and unrolling, the trip count should be
+  // halved (11 -> 5).
+  EXPECT_EQ(exact_trip_count, 5);
+
+  // After unrolling, there should be 3 all-gather-starts: 1 in the parent
+  // computation (the peeled iteration) and 2 in the while body.
+  EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
+                              HloOpcode::kAllGatherStart),
+            2);
+  EXPECT_EQ(CountInstructions((*module), HloOpcode::kAllGatherStart), 3);
+
+  // After unrolling, the third operand of the while's input tuple should be
+  // the peeled all-gather-done.
+  EXPECT_EQ(while_instruction->operand(0)->operand(2)->opcode(),
+            HloOpcode::kAllGatherDone);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+       UnrolledLoopNoControlDepsForConstantAdd) {
+  const char* const kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+  input_tuple = (f32[], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ c2 = f32[] constant(2)
+ add = f32[] add(c2, param_0)
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(add, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer;
+  TupleSimplifier tuple_simp;
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
+
+  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+      *module->entry_computation(), HloOpcode::kWhile);
+  TF_ASSERT_OK_AND_ASSIGN(
+      WhileLoopBackendConfig config,
+      while_instruction->backend_config<WhileLoopBackendConfig>());
+  int64_t exact_trip_count = config.known_trip_count().n();
+  // After peeling one iteration and unrolling, the trip count should be
+  // halved (11 -> 5).
+  EXPECT_EQ(exact_trip_count, 5);
+
+  // After unrolling, there should be 4 adds in the while body.
+  EXPECT_EQ(
+      CountInstructions((*while_instruction->while_body()), HloOpcode::kAdd),
+      4);
+
+  // After unrolling, the first operand of the output tuple should not have
+  // any control dependency, since it is an elementwise add with a constant
+  // operand.
+  EXPECT_EQ(while_instruction->while_body()
+                ->root_instruction()
+                ->operand(0)
+                ->control_predecessors()
+                .size(),
+            0);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+       UnrolledLoopNoControlDepsForCollective) {
+  const char* const kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+  input_tuple = (f32[], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+ar_add {
+  Arg_1 = f32[] parameter(1)
+  Arg_0 = f32[] parameter(0)
+  ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ all-reduce-start = f32[] all-reduce-start(param_0), channel_id=8, replica_groups={{0}}, to_apply=ar_add, backend_config="{\"is_sync\":false}"
+ one = s32[] constant(1)
+ all-reduce-done = f32[] all-reduce-done(all-reduce-start)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(all-reduce-done, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer;
+  TupleSimplifier tuple_simp;
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
+
+  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+      *module->entry_computation(), HloOpcode::kWhile);
+  TF_ASSERT_OK_AND_ASSIGN(
+      WhileLoopBackendConfig config,
+      while_instruction->backend_config<WhileLoopBackendConfig>());
+  int64_t exact_trip_count = config.known_trip_count().n();
+  // After unrolling, the trip count should be halved (10 -> 5).
+  EXPECT_EQ(exact_trip_count, 5);
+
+  // After unrolling, there should be 2 all-reduce-starts in the while body.
+  EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
+                              HloOpcode::kAllReduceStart),
+            2);
+  absl::flat_hash_set<int64_t> channel_ids;
+  hlo_query::ForEachInstructionWithOpcode(
+      *while_instruction->while_body(), HloOpcode::kAllReduceStart,
+      [&channel_ids](HloInstruction* ar) {
+        // We expect that after unrolling, all-reduces should not have any
+        // control deps.
+        EXPECT_EQ(ar->control_predecessors().size(), 0);
+        channel_ids.insert(*(ar->channel_id()));
+      });
+  // We expect the 2 all-reduces to have distinct channel ids.
+  EXPECT_EQ(channel_ids.size(), 2);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+       FullyUnrolledLoopNoControlDepsForCollective) {
+  const char* const kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+  input_tuple = (f32[], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+ar_add {
+  Arg_1 = f32[] parameter(1)
+  Arg_0 = f32[] parameter(0)
+  ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+
+body {
+ input_tuple = (f32[], s32[]) parameter(0)
+ param_0 = f32[] get-tuple-element(input_tuple), index=0
+ cond = s32[] get-tuple-element(input_tuple), index=1
+ all-reduce-start = f32[] all-reduce-start(param_0), channel_id=8, replica_groups={{0}}, to_apply=ar_add, backend_config="{\"is_sync\":false}"
+ one = s32[] constant(1)
+ all-reduce-done = f32[] all-reduce-done(all-reduce-start)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output_tuple = (f32[], s32[]) tuple(all-reduce-done, cond_plus_1)
+}
+
+ENTRY main {
+ param_0 = f32[] parameter(0)
+ param_2 = s32[] constant(0)
+ tuple = (f32[], s32[]) tuple(param_0, param_2)
+ ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+  TupleSimplifier tuple_simp;
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true));
+
+  HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode(
+      *module->entry_computation(), HloOpcode::kWhile);
+  TF_ASSERT_OK_AND_ASSIGN(
+      WhileLoopBackendConfig config,
+      while_instruction->backend_config<WhileLoopBackendConfig>());
+  int64_t exact_trip_count = config.known_trip_count().n();
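+  // Full unrolling should leave a trip count of 1.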
+  EXPECT_EQ(exact_trip_count, 1);
+
+  // After full unrolling, there should be 10 all-reduce-starts in the while
+  // body.
+  EXPECT_EQ(CountInstructions((*while_instruction->while_body()),
+                              HloOpcode::kAllReduceStart),
+            10);
+  absl::flat_hash_set<int64_t> channel_ids;
+  hlo_query::ForEachInstructionWithOpcode(
+      *while_instruction->while_body(), HloOpcode::kAllReduceStart,
+      [&channel_ids](HloInstruction* ar) {
+        // We expect that after unrolling, all-reduces should not have any
+        // control deps.
+        EXPECT_EQ(ar->control_predecessors().size(), 0);
+        channel_ids.insert(*(ar->channel_id()));
+      });
+  // We expect all 10 all-reduces to have distinct channel ids.
+  EXPECT_EQ(channel_ids.size(), 10);
+}
+
+// The following 2 tests also address the regression described here:
+// https://github.com/openxla/xla/issues/6353
+TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopRemainsFlattened) {
+  const char* const kModuleString = R"(
+HloModule loop_unrolling_nested_while_loop_remains_flattened
+
+condition_nested {
+  input_tuple = (s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=0
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output = (s32[]) tuple(cond_plus_1)
+}
+
+condition {
+  input_tuple = (s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=0
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+  input_tuple = (s32[]) parameter(0)
+  ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested
+}
+
+ENTRY main {
+ param_0 = (s32[]) parameter(0)
+ ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer;
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+  absl::flat_hash_set<const HloComputation*> while_loops_callees;
+
+  hlo_query::ForEachInstructionWithOpcode(
+      *module, HloOpcode::kWhile,
+      [&while_loops_callees](HloInstruction* instr) {
+        EXPECT_TRUE(
+            while_loops_callees.insert(instr->while_condition()).second);
+        EXPECT_TRUE(while_loops_callees.insert(instr->while_body()).second);
+      });
+
+  // The nested while loop should be duplicated, along with its associated
+  // computations: 3 while loops in total, each with a distinct condition and
+  // body.
+  EXPECT_EQ(while_loops_callees.size(), 6);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+       NestedWhileLoopRemainsFlattenedOddTripCount) {
+  const char* const kModuleString = R"(
+HloModule loop_unrolling_nested_while_loop_remains_flattened
+
+condition_nested {
+  input_tuple = (s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=0
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output = (s32[]) tuple(cond_plus_1)
+}
+
+condition {
+  input_tuple = (s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=0
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+  input_tuple = (s32[]) parameter(0)
+  ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested
+}
+
+ENTRY main {
+ param_0 = (s32[]) parameter(0)
+ ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer;
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+  absl::flat_hash_set<const HloComputation*> while_loops_callees;
+
+  hlo_query::ForEachInstructionWithOpcode(
+      *module, HloOpcode::kWhile,
+      [&while_loops_callees](HloInstruction* instr) {
+        EXPECT_TRUE(
+            while_loops_callees.insert(instr->while_condition()).second);
+        EXPECT_TRUE(while_loops_callees.insert(instr->while_body()).second);
+      });
+
+  // With the peeled iteration, the nested while loop should be duplicated
+  // twice, along with its associated computations: 4 while loops in total,
+  // each with a distinct condition and body.
+  EXPECT_EQ(while_loops_callees.size(), 8);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+       NestedWhileLoopRemainsFlattenedWhenFullyUnrolled) {
+  const char* const kModuleString = R"(
+HloModule loop_unrolling_nested_while_loop_remains_flattened
+
+condition_nested {
+  input_tuple = (s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=0
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output = (s32[]) tuple(cond_plus_1)
+}
+
+condition {
+  input_tuple = (s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=0
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+
+body {
+  input_tuple = (s32[]) parameter(0)
+  ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested
+}
+
+ENTRY main {
+ param_0 = (s32[]) parameter(0)
+ ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+  absl::flat_hash_set<const HloComputation*> while_loops_callees;
+
+  hlo_query::ForEachInstructionWithOpcode(
+      *module, HloOpcode::kWhile,
+      [&while_loops_callees](HloInstruction* instr) {
+        EXPECT_TRUE(
+            while_loops_callees.insert(instr->while_condition()).second);
+        EXPECT_TRUE(while_loops_callees.insert(instr->while_body()).second);
+      });
+
+  hlo_query::ForEachInstructionWithOpcode(
+      *module->entry_computation(), HloOpcode::kWhile,
+      [](HloInstruction* instr) {
+        TF_ASSERT_OK_AND_ASSIGN(
+            WhileLoopBackendConfig config,
+            instr->backend_config<WhileLoopBackendConfig>());
+        int64_t exact_trip_count = config.known_trip_count().n();
+        EXPECT_EQ(exact_trip_count, 1);
+      });
+
+  // We expect the nested while loop to be duplicated 10 times by the full
+  // unroll. Together with the remaining outer loop that makes 11 while
+  // instructions, each calling a distinct condition and body: 22 callee
+  // computations in total.
+  EXPECT_EQ(while_loops_callees.size(), 22);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopAreUnrolled) {
+  const char* const kModuleString = R"(
+HloModule loop_unrolling_nested_are_unrolled
+condition_nested {
+  input_tuple = (s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=0
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+body_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output = (s32[]) tuple(cond_plus_1)
+}
+condition {
+  input_tuple = (s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=0
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+body {
+  input_tuple = (s32[]) parameter(0)
+  ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"11"}}
+}
+ENTRY main {
+ param_0 = (s32[]) parameter(0)
+ ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer;
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+  int64_t num_whiles = 0;
+  hlo_query::ForEachInstructionWithOpcode(
+      *module, HloOpcode::kWhile, [&num_whiles](HloInstruction* instr) {
+        EXPECT_EQ(instr->backend_config<WhileLoopBackendConfig>()
+                      ->known_trip_count()
+                      .n(),
+                  5);
+        ++num_whiles;
+      });
+  // We expect 4 while loops in total after unrolling: the outer loop, the 2
+  // copies of the nested loop in its unrolled body, and the peeled copy in
+  // the entry computation.
+  EXPECT_EQ(num_whiles, 4);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopAreFullyUnrolled) {
+  const char* const kModuleString = R"(
+HloModule loop_unrolling_nested_are_unrolled
+condition_nested {
+  input_tuple = (s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=0
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+body_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output = (s32[]) tuple(cond_plus_1)
+}
+condition {
+  input_tuple = (s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=0
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+body {
+  input_tuple = (s32[]) parameter(0)
+  ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"11"}}
+}
+ENTRY main {
+ param_0 = (s32[]) parameter(0)
+ ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+  int64_t num_whiles = 0;
+  hlo_query::ForEachInstructionWithOpcode(
+      *module, HloOpcode::kWhile, [&num_whiles](HloInstruction* instr) {
+        EXPECT_EQ(instr->backend_config<WhileLoopBackendConfig>()
+                      ->known_trip_count()
+                      .n(),
+                  1);
+        ++num_whiles;
+      });
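+  // We expect 12 while loops in total: the fully unrolled outer loop plus 11
+  // copies of the nested loop in its expanded body.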
+  EXPECT_EQ(num_whiles, 12);
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, WhileLoopWithCollectivePermute) {
+  const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+  input_tuple = (f32[], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+  Arg_1 = f32[] parameter(1)
+  Arg_0 = f32[] parameter(0)
+  ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+  input_tuple = (f32[], s32[]) parameter(0)
+  param_0 = f32[] get-tuple-element(input_tuple), index=0
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}},
+                             frontend_attributes={_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"}
+  one = s32[] constant(1)
+  cond_plus_1 = s32[] add(cond, one)
+  ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
+}
+ENTRY main {
+  param_0 = f32[] parameter(0)
+  param_2 = s32[] constant(0)
+  tuple = (f32[], s32[]) tuple(param_0, param_2)
+  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+  VLOG(0) << module->ToString();
+  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+    // CHECK: %body {{.+}} {
+    // CHECK:   %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}}
+    // CHECK:   %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
+    // CHECK:   %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
+    // CHECK:   %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}}
+    // CHECK:   ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
+    // CHECK: }
+    // CHECK: ENTRY %main {{.+}} {
+    // CHECK-NOT: collective-permute
+    // CHECK: }
+  )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+       WhileLoopWithCollectivePermutePeeled) {
+  const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+  input_tuple = (f32[], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  trip_count = s32[] constant(15)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+  Arg_1 = f32[] parameter(1)
+  Arg_0 = f32[] parameter(0)
+  ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+  input_tuple = (f32[], s32[]) parameter(0)
+  param_0 = f32[] get-tuple-element(input_tuple), index=0
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,0}},
+                             frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10},{4,11},{5,12},{6,13},{7,14}}"}
+  one = s32[] constant(1)
+  cond_plus_1 = s32[] add(cond, one)
+  ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
+}
+ENTRY main {
+  param_0 = f32[] parameter(0)
+  param_2 = s32[] constant(0)
+  tuple = (f32[], s32[]) tuple(param_0, param_2)
+  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"15"}}
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+  VLOG(0) << module->ToString();
+
+  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+    // CHECK: %body
+    // CHECK:   %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6},{3,6}{{[}]}}}
+    // CHECK:   %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
+    // CHECK:   %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]])
+    // CHECK:   %[[cp2:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6}{{[}]}}}
+    // CHECK:   ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
+    // CHECK: ENTRY %main {{.+}} {
+    // CHECK:   %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}}
+    // CHECK:   %[[out_peeled:.+]] = {{.+}} tuple({{.+}} %[[cp_peeled]], {{.+}})
+    // CHECK:   %[[while:.+]] = {{.+}} while({{.+}} %[[out_peeled]])
+    // CHECK: }
+    )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+       WhileLoopWithCollectivePermuteBackwardCycle) {
+  const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+  input_tuple = (f32[], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  trip_count = s32[] constant(14)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+  Arg_1 = f32[] parameter(1)
+  Arg_0 = f32[] parameter(0)
+  ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+  input_tuple = (f32[], s32[]) parameter(0)
+  param_0 = f32[] get-tuple-element(input_tuple), index=0
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}},
+                             frontend_attributes={_xla_send_recv_validation="{{7,13},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}}"}
+  one = s32[] constant(1)
+  cond_plus_1 = s32[] add(cond, one)
+  ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
+}
+ENTRY main {
+  param_0 = f32[] parameter(0)
+  param_2 = s32[] constant(0)
+  tuple = (f32[], s32[]) tuple(param_0, param_2)
+  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"14"}}
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+    // CHECK: %body
+    // CHECK:   %[[cp1:.+]] = f32[] collective-permute(f32[] %param_0), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{4,6},{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3}{{[}]}}}
+    // CHECK:   %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
+    // CHECK:   %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
+    // CHECK:   %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3},{0,2}{{[}]}}}
+    // CHECK:   ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
+    // CHECK: ENTRY %main
+    // CHECK-NOT: collective-permute
+    // CHECK: }
+  )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+       WhileLoopWithCollectivePermuteBackwardCyclePeeled) {
+  const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+  input_tuple = (f32[], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  trip_count = s32[] constant(15)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+  Arg_1 = f32[] parameter(1)
+  Arg_0 = f32[] parameter(0)
+  ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+  input_tuple = (f32[], s32[]) parameter(0)
+  param_0 = f32[] get-tuple-element(input_tuple), index=0
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  collective-permute = f32[] collective-permute(param_0), channel_id=1, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}},
+                             frontend_attributes={_xla_send_recv_validation="{{7,14},{6,13},{5,12},{4,11},{3,10},{2,9},{1,8},{0,7}}"}
+  one = s32[] constant(1)
+  cond_plus_1 = s32[] add(cond, one)
+  ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
+}
+ENTRY main {
+  param_0 = f32[] parameter(0)
+  param_2 = s32[] constant(0)
+  tuple = (f32[], s32[]) tuple(param_0, param_2)
+  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"15"}}
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+    // CHECK: %body
+    // CHECK:   %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3}{{[}]}}}
+    // CHECK:   %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
+    // CHECK:   %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
+    // CHECK:   %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3},{0,2}{{[}]}}}
+    // CHECK:   ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
+    // CHECK: }
+    // CHECK: ENTRY %main
+    // CHECK:   %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}}
+    // CHECK:   %[[out_peeled:.+]] = {{.+}} tuple({{.+}} %[[cp_peeled]], {{.+}})
+    // CHECK:   ROOT {{.+}} = {{.+}} while({{.+}} %[[out_peeled]])
+    // CHECK: }
+  )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+       WhileLoopWithCollectivePermuteStartDone) {
+  const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+  input_tuple = (f32[], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+  Arg_1 = f32[] parameter(1)
+  Arg_0 = f32[] parameter(0)
+  ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+  input_tuple = (f32[], s32[]) parameter(0)
+  param_0 = f32[] get-tuple-element(input_tuple), index=0
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  collective-permute-start = (f32[], f32[], u32[], u32[]) collective-permute-start(param_0), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}},
+                             frontend_attributes={_xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"}
+  collective-permute = f32[] collective-permute-done(collective-permute-start)
+  one = s32[] constant(1)
+  cond_plus_1 = s32[] add(cond, one)
+  ROOT output_tuple = (f32[], s32[]) tuple(collective-permute, cond_plus_1)
+}
+ENTRY main {
+  param_0 = f32[] parameter(0)
+  param_2 = s32[] constant(0)
+  tuple = (f32[], s32[]) tuple(param_0, param_2)
+  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+    // CHECK: %body
+    // CHECK:   %[[cp_start1:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}}
+    // CHECK:   %[[cp1:.+]] = {{.+}} collective-permute-done({{.+}} %[[cp_start1]])
+    // CHECK:   %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}})
+    // CHECK:   %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0
+    // CHECK:   %[[cp_start2:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}}
+    // CHECK:   %[[cp2:.+]] = {{.+}} collective-permute-done({{.+}} %[[cp_start2]])
+    // CHECK:   ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}})
+    // CHECK: }
+    // CHECK: ENTRY %main
+    // CHECK-NOT: collective-permute
+    // CHECK: }
+  )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, WhileLoopWithRecvDone) {
+  const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+  input_tuple = (f32[], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+  Arg_1 = f32[] parameter(1)
+  Arg_0 = f32[] parameter(0)
+  ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+  input_tuple = (f32[], s32[]) parameter(0)
+  param_0 = f32[] get-tuple-element(input_tuple), index=0
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  after-all.0 = token[] after-all()
+  recv.0 = (f32[], u32[], token[]) recv(after-all.0), channel_id=1,
+        frontend_attributes={
+          _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,0}}",
+          _xla_send_recv_pipeline="0",
+          _xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"
+        }
+  recv-done.0 = (f32[], token[]) recv-done(recv.0), channel_id=1,
+        frontend_attributes={
+          _xla_send_recv_pipeline="0"
+        }
+  recv-data = f32[] get-tuple-element(recv-done.0), index=0
+  one = s32[] constant(1)
+  cond_plus_1 = s32[] add(cond, one)
+  ROOT output_tuple = (f32[], s32[]) tuple(recv-data, cond_plus_1)
+}
+ENTRY main {
+  param_0 = f32[] parameter(0)
+  param_2 = s32[] constant(0)
+  tuple = (f32[], s32[]) tuple(param_0, param_2)
+  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+    // CHECK: %body
+    // CHECK:   %[[recv1:.+]] = {{.+}} recv({{.+}}), {{.+}},_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}
+    // CHECK:   %[[recv2:.+]] = {{.+}} recv({{.+}}), {{.+}},_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}
+    // CHECK: ENTRY %main
+    // CHECK-NOT: recv
+    // CHECK: }
+  )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest, WhileLoopWithSendDone) {
+  const char* kModuleString = R"(
+HloModule loop_unrolling_no_deps
+condition {
+  input_tuple = (f32[], s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  trip_count = s32[] constant(10)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+ar_add {
+  Arg_1 = f32[] parameter(1)
+  Arg_0 = f32[] parameter(0)
+  ROOT add_ar = f32[] add(Arg_1, Arg_0)
+}
+body {
+  input_tuple = (f32[], s32[]) parameter(0)
+  param_0 = f32[] get-tuple-element(input_tuple), index=0
+  cond = s32[] get-tuple-element(input_tuple), index=1
+  after-all.0 = token[] after-all()
+  send.0 = (f32[], u32[], token[]) send(param_0, after-all.0), channel_id=1,
+        frontend_attributes={
+          _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,0}}",
+          _xla_send_recv_pipeline="0",
+          _xla_send_recv_validation="{{0,6},{1,7},{2,8},{3,9}}"
+        }
+  send-done.0 = token[] send-done(send.0), channel_id=1,
+        frontend_attributes={
+          _xla_send_recv_pipeline="0"
+        }
+  one = s32[] constant(1)
+  cond_plus_1 = s32[] add(cond, one)
+  ROOT output_tuple = (f32[], s32[]) tuple(param_0, cond_plus_1)
+}
+ENTRY main {
+  param_0 = f32[] parameter(0)
+  param_2 = s32[] constant(0)
+  tuple = (f32[], s32[]) tuple(param_0, param_2)
+  ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}}
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer);
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true));
+
+  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+    // CHECK: %body
+    // CHECK:   %[[send1:.+]] = {{.+}} send({{.+}}), {{.+}},_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}
+    // CHECK:   %[[send2:.+]] = {{.+}} send({{.+}}), {{.+}},_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}
+    // CHECK: ENTRY %main
+    // CHECK-NOT: send
+    // CHECK: }
+  )"));
+}
+
+TEST_F(GpuLoopDoubleBufferTransformerTest,
+       WhileLoopWithTripCount1ShouldBeSkipped) {
+  const char* const kModuleString = R"(
+HloModule loop_unrolling_skipped
+condition_nested {
+  input_tuple = (s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=0
+  trip_count = s32[] constant(0)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+body_nested {
+ input_tuple = (s32[]) parameter(0)
+ cond = s32[] get-tuple-element(input_tuple), index=0
+ one = s32[] constant(1)
+ cond_plus_1 = s32[] add(cond, one)
+ ROOT output = (s32[]) tuple(cond_plus_1)
+}
+condition {
+  input_tuple = (s32[]) parameter(0)
+  cond = s32[] get-tuple-element(input_tuple), index=0
+  trip_count = s32[] constant(0)
+  ROOT done = pred[] compare(cond, trip_count), direction=LT
+}
+body {
+  input_tuple = (s32[]) parameter(0)
+  ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"1"}}
+}
+ENTRY main {
+ param_0 = (s32[]) parameter(0)
+ ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"1"}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  DoubleBufferLoopUnrolling double_buffer(
+      DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll);
+  // The processing of the loop should be completely skipped.
+  EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(false));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc
new file mode 100644
index 0000000..f059607
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc
@@ -0,0 +1,627 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <iterator>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/ffi/ffi_api.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/custom_call_target_registry.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/gpu_constants.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+namespace m = ::xla::match;
+
+// A dataflow path flowing from a definition to a user.
+using DefUseDataflowPath = absl::InlinedVector<HloInstruction*, 2>;
+
+// All dataflow paths flowing from a definition to all users. Each user will
+// have a separate entry in the vector.
+using DefUseDataflowPaths = absl::InlinedVector<DefUseDataflowPath, 4>;
+
+// A dataflow path flowing from a user to a definition.
+using UseDefDataflowPath = absl::InlinedVector<HloInstruction*, 4>;
+
+// All dataflow paths flowing from a user to all definitions of its operands.
+using UseDefDataflowPaths = absl::InlinedVector<HloInstruction*, 8>;
+
+using DataflowPathView = absl::Span<HloInstruction* const>;
+using DataflowPathsView = absl::Span<DataflowPathView>;
+
+using InstructionSet = absl::flat_hash_set<HloInstruction*>;
+
+bool IsNoOp(const HloInstruction* hlo) {
+  return HloPredicateIsOp<HloOpcode::kBitcast, HloOpcode::kTuple,
+                          HloOpcode::kGetTupleElement>(hlo);
+}
+
+bool IsCustomCall(const HloInstruction* hlo, absl::string_view platform_name) {
+  auto* custom_call = DynCast<HloCustomCallInstruction>(hlo);
+  if (custom_call == nullptr) return false;
+
+  // TODO(vuson): properly handle token by following
+  // `LhloDialectEmitter::EmitCustomCallOp`'s `CreateOperands` logic for
+  // `LhloDialectEmitter::EmitFusionOp`'s `RewriteFusionOperand`
+  if (custom_call->shape().IsTuple() &&
+      absl::c_any_of(
+          custom_call->shape().tuple_shapes(),
+          [&](const Shape& sub_shape) { return sub_shape.IsToken(); }))
+    return false;
+
+  const std::string call_target_name = custom_call->custom_call_target();
+
+  bool is_ffi_custom_call =
+      custom_call->api_version() == CustomCallApiVersion::API_VERSION_TYPED_FFI;
+
+  void* call_target = CustomCallTargetRegistry::Global()->Lookup(
+      call_target_name, std::string(platform_name));
+
+  absl::StatusOr<ffi::HandlerRegistration> handler_registration =
+      ffi::FindHandler(call_target_name, platform_name);
+
+  // At least one implementation should be available at run time.
+  bool found_custom_call = !is_ffi_custom_call && call_target != nullptr;
+  bool found_ffi_handler = is_ffi_custom_call && handler_registration.ok();
+
+  return found_custom_call || found_ffi_handler;
+}
+
+// Returns true if the slice is 128-byte-aligned. The slice starting
+// address is determined by the product of all non-sliced dimensions and an
+// offset defined by `slice_starts` of the slice op.
+//
+// For dynamic cases, we have no information about the start indices, so we
+// have to be conservative and only accept sliced shapes whose product of all
+// non-sliced dimensions is a multiple of `kXlaAllocatedBufferAlignBytes`.
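+//
+// Illustrative example (assuming a 128-byte alignment requirement): slicing
+// f32[16,64] along the leading dimension covers contiguous chunks of
+// 64 * 4 = 256 bytes, so every such slice start lands on a multiple of the
+// required alignment.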
+bool IsAlignedSlice(const HloInstruction* slice) {
+  DCHECK(slice->opcode() == HloOpcode::kSlice ||
+         slice->opcode() == HloOpcode::kDynamicSlice ||
+         slice->opcode() == HloOpcode::kDynamicUpdateSlice)
+      << "Unknown slice operation: " << slice->ToString();
+
+  if (!IsContiguousSlice(*slice)) return false;
+
+  auto [full_shape, slice_shape] = [&] {
+    if (auto* dus = DynCast<HloDynamicUpdateSliceInstruction>(slice)) {
+      return std::make_pair(dus->shape(), dus->update()->shape());
+    }
+    return std::make_pair(slice->operand(0)->shape(), slice->shape());
+  }();
+
+  auto strides = ShapeUtil::ByteStrides(slice_shape);
+  if (!strides.has_value()) return false;
+
+  for (auto dim : slice_shape.layout().minor_to_major()) {
+    if ((strides.value()[dim] % kXlaAllocatedBufferAlignBytes) == 0) {
+      return true;
+    }
+    if (slice_shape.dimensions(dim) < full_shape.dimensions(dim)) {
+      return (slice->opcode() == HloOpcode::kSlice &&
+              (((*strides)[dim] * slice->slice_starts(dim)) %
+                   kXlaAllocatedBufferAlignBytes ==
+               0));
+    }
+  }
+  return true;
+}
+
+// Pattern matches the following IR (generated by `jax.lax.scan`) to check if
+// the offset is a loop iteration number:
+
+// clang-format off
+//   param = (s32[], s32[], s32[16]{0}, s32[16]{0}) parameter(0)
+//   // the index in `gte` has to be the loop iteration index
+//   gte = s32[] get-tuple-element(param), index=0
+//   c0 = s32[] constant(0)
+//   compare = pred[] compare(gte, c0), direction=LT
+//   c_trip_count = s32[] constant(16)
+//   add = s32[] add(gte, c_trip_count)
+//   select = s32[] select(compare, add, gte)
+// clang-format on
+
+bool IsLoopIterationNumber(const HloInstruction& offset) {
+  const HloComputation* parent = offset.parent();
+  if (!parent->IsWhileBodyComputation()) return false;
+
+  // A scan loop's trip count must be known at compile time, since it iterates
+  // over the leading dimension of a statically shaped input.
+  const HloInstruction* while_instr = parent->WhileCallInstruction();
+  auto config = while_instr->backend_config<xla::WhileLoopBackendConfig>();
+  if (!config.ok() || !config->has_known_trip_count()) return false;
+  int32_t trip_count = config->known_trip_count().n();
+
+  // First, check the offset computation pattern.
+  if (!Match(&offset, m::Select(m::Lt(m::GetTupleElement(m::Parameter(0)),
+                                      m::ConstantScalar<int32_t>(0)),
+                                m::Add(m::GetTupleElement(m::Parameter(0)),
+                                       m::ConstantScalar(trip_count)),
+                                m::GetTupleElement(m::Parameter())))) {
+    return false;
+  }
+
+  // Next, check that the parameter used in the offset computation is the loop
+  // induction variable.
+  int64_t param_idx = offset.operand(2)->tuple_index();
+  const HloInstruction* root = offset.parent()->root_instruction();
+  if (root->opcode() != HloOpcode::kTuple) {
+    return false;
+  }
+  // Check the induction variable's update operation.
+  const HloInstruction* updated_var =
+      offset.parent()->root_instruction()->operand(param_idx);
+  if (!Match(updated_var, m::Add(m::GetTupleElement(m::Parameter(0), param_idx),
+                                 m::ConstantScalar(1)))) {
+    return false;
+  }
+  // Check that the while condition compares this induction variable against
+  // the trip count.
+  const HloInstruction* condition_root =
+      while_instr->while_condition()->root_instruction();
+  if (!Match(condition_root,
+             m::Lt(m::GetTupleElement(m::Parameter(0), param_idx),
+                   m::ConstantScalar(trip_count)))) {
+    return false;
+  }
+  // Check that the induction variable is initialized to zero.
+  const HloInstruction* init_loop_iter =
+      while_instr->operand(0)->operand(param_idx);
+  if (!Match(init_loop_iter, m::ConstantScalar(0))) {
+    return false;
+  }
+
+  return true;
+}
+
+// Returns true for constants that are handled by the dynamic slice fusion
+// runtime. These constants do not force a D2H copy and hence preserve the
+// CUDA graph.
+bool IsHandledConstantForDynamicSliceFusion(const HloInstruction& offset) {
+  if (auto* cst = DynCast<HloConstantInstruction>(&offset)) {
+    switch (cst->shape().element_type()) {
+      case PrimitiveType::S32:
+      case PrimitiveType::S64:
+      case PrimitiveType::U32:
+      case PrimitiveType::U64:
+        return true;
+      default:
+        return false;
+    }
+  }
+  return false;
+}
+
+// Returns true if every offset of a dynamic index operation is either a
+// handled constant or the loop iteration number.
+bool HasConstantOrLoopIterationOffsets(
+    const HloDynamicIndexInstruction& instr) {
+  return llvm::all_of(instr.index_operands(), [](const HloInstruction* offset) {
+    return IsLoopIterationNumber(*offset) ||
+           IsHandledConstantForDynamicSliceFusion(*offset);
+  });
+}
+
+UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) {
+  UseDefDataflowPaths sliced_operand_paths;
+
+  // This set is used to avoid duplicates in the matched results. It contains
+  // the matched instructions that we have seen so far.
+  InstructionSet processed_instrs;
+
+  const auto& aliasing_pairs =
+      Cast<HloCustomCallInstruction>(instr)->output_to_operand_aliasing();
+  absl::flat_hash_set<int64_t> aliased_operands;
+  for (const auto& pair : aliasing_pairs) {
+    aliased_operands.insert(pair.second.first);
+  }
+
+  for (const auto* operand : instr->operands()) {
+    // output_to_operand_aliasing means the operand is to be materialized, which
+    // is against the whole idea of address computation fusion. Skip this
+    // operand.
+    if (aliased_operands.contains(instr->operand_index(operand))) continue;
+    UseDefDataflowPath maybe_sliced_operand_path;
+    bool slice_found = false;
+    // TODO: Currently HloFindIf exits upon encountering the first node that
+    // matches. This works well if each operand only has one dataflow path
+    // (i.e. only flows through unary ops). We might want to keep searching
+    // until the queue is empty: if the operand is a tuple, it might have
+    // different dataflow paths (i.e. one for each element).
+    auto maybe_slice_instr =
+        HloBfsFindIf({operand}, [&](const HloInstruction* cur) {
+          // If the node is a match that has been processed, stop the traversal.
+          if (processed_instrs.contains(cur)) return true;
+
+          maybe_sliced_operand_path.push_back(const_cast<HloInstruction*>(cur));
+
+          if (IsOpcodeAnyOf<HloOpcode::kDynamicSlice, HloOpcode::kSlice>(cur)) {
+            if (IsAlignedSlice(cur)) {
+              slice_found = true;
+              return slice_found;
+            }
+          }
+
+          return !IsNoOp(cur);
+        });
+
+    if (maybe_slice_instr == std::nullopt) continue;
+    auto dynamic_index_operation =
+        DynCast<HloDynamicIndexInstruction>(maybe_slice_instr.value());
+    bool valid_slice_found =
+        slice_found &&
+        ((dynamic_index_operation &&
+          HasConstantOrLoopIterationOffsets(*dynamic_index_operation)) ||
+         (*maybe_slice_instr)->opcode() == HloOpcode::kSlice);
+    if (valid_slice_found ||
+        processed_instrs.contains(maybe_slice_instr.value())) {
+      // Even in the case of stopping at a match that has been processed, we
+      // still need to add instructions encountered in the sliced operand path
+      // during the latest traversal.
+      sliced_operand_paths.insert(sliced_operand_paths.end(),
+                                  maybe_sliced_operand_path.rbegin(),
+                                  maybe_sliced_operand_path.rend());
+      processed_instrs.insert(maybe_sliced_operand_path.begin(),
+                              maybe_sliced_operand_path.end());
+    }
+  }
+
+  sliced_operand_paths.push_back(const_cast<HloInstruction*>(instr));
+  return sliced_operand_paths;
+}
+
+// Each user of `instr` whose dataflow ends in a DUS gets an entry in the
+// returned vector. Each entry contains the sliced path for that user, i.e. the
+// sequence of ops following the dataflow from the user itself to the DUS
+// (inclusive).
+DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) {
+  DefUseDataflowPaths sliced_user_paths;
+  // This set is used to avoid duplicates in the matched results. It contains
+  // the matched instructions that we have seen so far.
+  InstructionSet processed_instrs;
+
+  auto traverse_hlo_and_collect = [&](HloInstruction* start) {
+    DefUseDataflowPath maybe_sliced_user_path;
+    bool dus_found = false;
+    auto maybe_dus_instr = HloBfsFindIf(
+        {start},
+        [&](const HloInstruction* cur) {
+          // If the node is a match that has been processed, stop the
+          // traversal.
+          if (processed_instrs.contains(cur)) return true;
+          maybe_sliced_user_path.push_back(const_cast<HloInstruction*>(cur));
+          if (const auto slice_instr =
+                  DynCast<HloDynamicUpdateSliceInstruction>(cur)) {
+            if (IsAlignedSlice(slice_instr)) {
+              dus_found = true;
+              return true;
+            }
+          }
+          return cur->user_count() > 1 || !IsNoOp(cur);
+        },
+        /*visit_operands=*/false);
+    if (maybe_dus_instr == std::nullopt) return;
+    auto dynamic_index_operation =
+        DynCast<HloDynamicIndexInstruction>(maybe_dus_instr.value());
+    bool valid_dus_found =
+        dus_found && dynamic_index_operation &&
+        HasConstantOrLoopIterationOffsets(*dynamic_index_operation);
+    if (valid_dus_found || processed_instrs.contains(maybe_dus_instr.value())) {
+      // Even in the case of stopping at a match that has been processed, we
+      // still need to add instructions encountered in the sliced user path
+      // during the latest traversal.
+      processed_instrs.insert(maybe_sliced_user_path.begin(),
+                              maybe_sliced_user_path.end());
+      sliced_user_paths.push_back(std::move(maybe_sliced_user_path));
+    }
+  };
+
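+  // For tuple-shaped results (e.g. a gemm that also returns a workspace) we
+  // follow each get-tuple-element user; for non-tuple results we only start
+  // the traversal when `instr` has exactly one user.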
+  if (instr->shape().IsTuple()) {
+    for (auto* user : instr->users()) {
+      if (DynCast<HloGetTupleElementInstruction>(user)) {
+        traverse_hlo_and_collect(user);
+      }
+    }
+  } else {
+    if (instr->user_count() == 1) {
+      traverse_hlo_and_collect(instr->users().front());
+    }
+  }
+
+  return sliced_user_paths;
+}
+
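+// Returns the values captured by the matched instructions, i.e. operands that
+// are produced outside of the matched dataflow paths. These become the
+// parameters of the outlined fusion computation.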
+absl::InlinedVector<HloInstruction*, 4> GetPatternCaptures(
+    DataflowPathView matches) {
+  absl::InlinedVector<HloInstruction*, 4> captures;
+
+  InstructionSet matched_instrs(matches.begin(), matches.end());
+
+  for (HloInstruction* instr : matches) {
+    for (HloInstruction* operand : instr->operands()) {
+      if (!matched_instrs.contains(operand) &&
+          absl::c_find(captures, operand) == captures.end()) {
+        captures.emplace_back(operand);
+      }
+    }
+  }
+
+  return captures;
+}
+
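+// Builds the root tuple of the fusion body: tuple elements that are updated by
+// a sliced user path are taken from the last instruction of that path, while
+// all other elements are read back from the hero via get-tuple-element.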
+absl::Status CreateRootTuple(
+    HloInstruction* hero, HloComputation::Builder& builder,
+    DataflowPathsView sliced_user_paths,
+    absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
+        instr_mapping) {
+  unsigned tuple_size = hero->shape().tuple_shapes_size();
+
+  std::vector<HloInstruction*> sliced_elems(tuple_size, nullptr);
+  for (auto& sliced_user_path : sliced_user_paths) {
+    auto gte = Cast<HloGetTupleElementInstruction>(sliced_user_path.front());
+    sliced_elems[gte->tuple_index()] = sliced_user_path.back();
+  }
+
+  std::vector<HloInstruction*> elements;
+  for (size_t i = 0; i < tuple_size; ++i) {
+    if (sliced_elems[i] != nullptr) {
+      elements.push_back(instr_mapping[sliced_elems[i]]);
+      continue;
+    }
+    auto* gte = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(instr_mapping[hero], i));
+    if (hero->shape().tuple_shapes(i).IsTuple()) {
+      instr_mapping[gte] = gte;
+      TF_RETURN_IF_ERROR(CreateRootTuple(gte, builder, {}, instr_mapping));
+      elements.push_back(builder.last_added_instruction());
+    } else {
+      elements.push_back(gte);
+    }
+  }
+  if (elements.size() > 1)
+    builder.AddInstruction(HloInstruction::CreateTuple(elements));
+
+  return absl::OkStatus();
+}
+
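+// Outlines the matched pattern into a new fusion computation: captures become
+// parameters, the sliced operand paths (ending with the hero) and the sliced
+// user paths are cloned in order, and a root tuple is created if the hero
+// returns a tuple.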
+absl::StatusOr<HloComputation*> CreateFusionBody(
+    HloModule* module, DataflowPathView sliced_operand_paths,
+    DataflowPathsView sliced_user_paths, DataflowPathView captures) {
+  HloComputation::Builder builder("address-computation");
+
+  // A mapping from original instructions to instructions in the fusion body.
+  absl::flat_hash_map<const HloInstruction*, HloInstruction*> instr_mapping;
+
+  auto mapped_operands = [&](HloInstruction* instr) {
+    absl::InlinedVector<HloInstruction*, 4> operands;
+    for (HloInstruction* operand : instr->operands()) {
+      operands.push_back(instr_mapping.at(operand));
+    }
+    return operands;
+  };
+
+  // For every captured value create a parameter instruction in the computation
+  // body and set up instruction mapping.
+  for (const HloInstruction* capture : captures) {
+    int64_t index = instr_mapping.size();
+    instr_mapping[capture] =
+        builder.AddInstruction(HloInstruction::CreateParameter(
+            index, capture->shape(), absl::StrCat("p", index)));
+  }
+
+  // Instructions in the pattern are already topologically sorted, as we
+  // visited them along the use-def path and then reversed the list.
+  HloInstruction* hero = nullptr;
+  for (HloInstruction* instr : sliced_operand_paths) {
+    instr_mapping[instr] = builder.AddInstruction(
+        instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
+    hero = instr;
+  }
+
+  for (auto& sliced_user_path : sliced_user_paths) {
+    for (HloInstruction* instr : sliced_user_path) {
+      instr_mapping[instr] = builder.AddInstruction(
+          instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
+    }
+  }
+
+  // If the hero returns a tuple, create a root tuple so that a buffer is
+  // assigned for each of its elements. Skip empty (nil) tuples.
+  if (hero->shape().IsTuple() && hero->shape().tuple_shapes_size() > 0) {
+    TF_RETURN_IF_ERROR(
+        CreateRootTuple(hero, builder, sliced_user_paths, instr_mapping));
+  }
+
+  return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false);
+}
+
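+// Creates a kCustom fusion instruction calling `body` and tags it with a
+// "__custom_fusion" backend config named "dynamic_address_computation" or
+// "address_computation", depending on whether the matched pattern contains
+// dynamic slices.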
+absl::StatusOr<HloInstruction*> CreateFusionInstruction(
+    HloModule* module, HloInstruction* orig, DataflowPathView captures,
+    HloComputation* body, bool dynamic) {
+  HloComputation* parent = orig->parent();
+
+  // Add a fusion operation calling outlined fusion computation.
+  HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion(
+      body->root_instruction()->shape(), HloInstruction::FusionKind::kCustom,
+      captures, body));
+  module->SetAndUniquifyInstrName(fusion, "address_computation");
+
+  // We don't need to set/update output_to_operand_aliasing for the new fusion
+  // instruction because all buffers are already assigned at this point.
+
+  // Set backends config to a matched custom fusion config.
+  GpuBackendConfig gpu_config;
+  FusionBackendConfig& backend_config =
+      *gpu_config.mutable_fusion_backend_config();
+  backend_config.set_kind("__custom_fusion");
+  CustomFusionConfig config;
+  config.set_name(dynamic ? "dynamic_address_computation"
+                          : "address_computation");
+  *backend_config.mutable_custom_fusion_config() = config;
+  TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config)));
+
+  return fusion;
+}
+
+}  // namespace
+
+absl::StatusOr<bool> DynamicSliceFusionRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  absl::flat_hash_map<HloInstruction*,
+                      std::pair<UseDefDataflowPaths, DefUseDataflowPaths>>
+      matches;
+
+  // Collect all potential custom call matches in the non-fusion computations.
+  for (HloComputation* computation : module->computations()) {
+    if (computation->IsFusionComputation()) continue;
+    for (HloInstruction* instr : computation->instructions()) {
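+      // cuBLAS gemms and platform custom calls can have both sliced operands
+      // and sliced users; reduce-scatter is only matched on the user (DUS)
+      // side.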
+      UseDefDataflowPaths sliced_operand_paths = {instr};
+      bool has_sliced_operand_paths = false;
+      if (IsLegacyCublasMatmul(*instr) || IsCustomCall(instr, platform_name_)) {
+        sliced_operand_paths = GetSlicedOperandPaths(instr);
+        has_sliced_operand_paths = sliced_operand_paths.size() > 1;
+      }
+      if (instr->opcode() == HloOpcode::kReduceScatter ||
+          IsLegacyCublasMatmul(*instr) || IsCustomCall(instr, platform_name_)) {
+        DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr);
+        bool has_sliced_user_paths = absl::c_any_of(
+            sliced_user_paths,
+            [&](auto& sliced_user_path) { return !sliced_user_path.empty(); });
+
+        if (absl::c_any_of(sliced_user_paths, [&](auto& sliced_user_path) {
+              return DynCast<HloDynamicUpdateSliceInstruction>(
+                         sliced_user_path.back()) == nullptr;
+            })) {
+          return absl::InternalError(
+              "Expect sliced user path to end with a DUS.");
+        }
+
+        if (has_sliced_operand_paths || has_sliced_user_paths) {
+          matches[instr] = std::make_pair(std::move(sliced_operand_paths),
+                                          std::move(sliced_user_paths));
+        }
+      }
+    }
+  }
+
+  if (matches.empty()) return false;
+
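+  // Outline each matched pattern into its own custom fusion and replace the
+  // original hero (or its sliced user paths) with the new fusion instruction.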
+  for (auto& [hero, paths] : matches) {
+    auto& [sliced_operand_paths, sliced_user_paths] = paths;
+    std::vector<HloInstruction*> matched_instrs;
+    absl::c_copy(sliced_operand_paths, std::back_inserter(matched_instrs));
+
+    std::vector<DataflowPathView> sliced_user_paths_view;
+    for (auto& sliced_user_path : sliced_user_paths) {
+      absl::c_copy(sliced_user_path, std::back_inserter(matched_instrs));
+      DataflowPathView sliced_user_path_view{&sliced_user_path.front(),
+                                             sliced_user_path.size()};
+      sliced_user_paths_view.push_back(std::move(sliced_user_path_view));
+    }
+
+    auto captures = GetPatternCaptures(matched_instrs);
+
+    TF_ASSIGN_OR_RETURN(
+        HloComputation * fusion_body,
+        CreateFusionBody(module, sliced_operand_paths,
+                         DataflowPathsView(sliced_user_paths_view), captures));
+
+    bool has_dynamic_slices = absl::c_any_of(matched_instrs, [&](auto* instr) {
+      return DynCast<HloDynamicIndexInstruction>(instr) != nullptr;
+    });
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * fusion,
+        CreateFusionInstruction(module, hero, captures, fusion_body,
+                                has_dynamic_slices));
+
+    HloComputation* parent = hero->parent();
+    if (fusion->shape().IsTuple()) {
+      TF_RETURN_IF_ERROR(parent->ReplaceInstructionWithDifferentShape(
+          const_cast<HloInstruction*>(hero), fusion));
+      for (auto& sliced_user_path : sliced_user_paths) {
+        auto old_gte =
+            Cast<HloGetTupleElementInstruction>(sliced_user_path.front());
+        HloInstruction* gte =
+            parent->AddInstruction(HloInstruction::CreateGetTupleElement(
+                fusion, old_gte->tuple_index()));
+        TF_RETURN_IF_ERROR(
+            parent->ReplaceInstruction(sliced_user_path.back(), gte));
+      }
+    } else {
+      auto* instr_to_be_replaced = const_cast<HloInstruction*>(hero);
+      if (sliced_user_paths.empty()) {
+        // The only case where a tuple-shaped original hero op is fused into a
+        // non-tuple-shaped fusion is when only one element of the original
+        // tuple is used. In that case, we need to replace that single
+        // get-tuple-element (instead of the hero op) with the fusion
+        // instruction.
+        if (hero->shape().IsTuple()) {
+          if (hero->user_count() != 1 ||
+              !DynCast<HloGetTupleElementInstruction>(hero->users().front())) {
+            return absl::InternalError(
+                "Expect a single get-tuple-element user of the original "
+                "tuple-shaped hero op when address computation fusion does "
+                "not return a tuple");
+          }
+          instr_to_be_replaced = hero->users().front();
+        }
+      } else {
+        instr_to_be_replaced = sliced_user_paths.front().back();
+      }
+      TF_RETURN_IF_ERROR(
+          parent->ReplaceInstruction(instr_to_be_replaced, fusion));
+      // This is required for collective operations, which would not be
+      // removed automatically.
+      if (hero->parent()) {
+        TF_RETURN_IF_ERROR(hero->parent()->RemoveInstruction(hero));
+      }
+    }
+  }
+
+  return true;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h
new file mode 100644
index 0000000..ad996de
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h
@@ -0,0 +1,91 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_DYNAMIC_SLICE_FUSION_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_DYNAMIC_SLICE_FUSION_REWRITER_H_
+
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Pattern matches (slice(s) + custom call) to custom address computation
+// fusions and rewrites them into fusion instructions and fusion computations.
+//
+// Example:
+//
+//  ENTRY %main {
+//    %p0 = bf16[2,8,8]{2,1,0} parameter(0)
+//    %p1 = bf16[2,8,8]{2,1,0} parameter(1)
+//    %slice_lhs = bf16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+//    %bitcast_lhs = bf16[8,8]{1,0} bitcast(%slice_lhs)
+//    %slice_rhs = bf16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+//    %bitcast_rhs = bf16[8,8]{1,0} bitcast(%slice_rhs)
+//    ROOT %dot = bf16[8,8]{1,0} custom-call(%bitcast_lhs, %bitcast_rhs),
+//      custom_call_target="__cublas$gemm"
+//  }
+//
+// After the pass:
+//
+//  %address_computation {
+//    %p0 = bf16[2,8,8]{2,1,0} parameter(0)
+//    %p1 = bf16[2,8,8]{2,1,0} parameter(1)
+//    %slice_lhs = bf16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+//    %bitcast_lhs = bf16[8,8]{1,0} bitcast(%slice_lhs)
+//    %slice_rhs = bf16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+//    %bitcast_rhs = bf16[8,8]{1,0} bitcast(%slice_rhs)
+//    ROOT %dot = bf16[8,8]{1,0} custom-call(%bitcast_lhs, %bitcast_rhs),
+//      custom_call_target="__cublas$gemm"
+//  }
+//
+//  ENTRY %main {
+//    %p0 = bf16[2,8,8]{2,1,0} parameter(0)
+//    %p1 = bf16[2,8,8]{2,1,0} parameter(1)
+//    ROOT %fusion.2 = bf16[8,8]{1,0} fusion(%p0, %p1),
+//        kind=kCustom, calls=%address_computation,
+//        backend_config={"fusion_backend_config":{
+//            "kind":"__custom_fusion",
+//            "custom_fusion_config":{"name":"address_computation"}
+//        }}
+//  }
+//
+class DynamicSliceFusionRewriter : public HloModulePass {
+ public:
+  absl::string_view name() const override {
+    return "address-computation-fusion-rewriter";
+  }
+
+  explicit DynamicSliceFusionRewriter(std::string platform_name)
+      : platform_name_(std::move(platform_name)) {}
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  std::string platform_name_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_DYNAMIC_SLICE_FUSION_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc
new file mode 100644
index 0000000..2bd7168
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc
@@ -0,0 +1,2064 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
+
+#include <cstddef>
+#include <optional>
+
+#include "absl/status/status.h"
+#include "xla/client/lib/constants.h"
+#include "xla/client/xla_builder.h"
+#include "xla/ffi/ffi.h"
+#include "xla/ffi/ffi_api.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/service/custom_call_target_registry.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/gpu/gpu_types.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla::gpu {
+
+class DynamicSliceFusionRewriterTest : public HloTestBase {};
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemm) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[2,8,8]{2,1,0} parameter(0)
+      %p1 = f16[2,8,8]{2,1,0} parameter(1)
+      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+      ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmWithWorkspace) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[2,8,8]{2,1,0} parameter(0)
+      %p1 = f16[2,8,8]{2,1,0} parameter(1)
+      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+      ROOT %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:       [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
+    ; CHECK:       [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
+    ; CHECK:       ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
+    ; CHECK:              tuple([[DOT]], [[WORKSPACE]])
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmWorkspaceIgnored) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[2,8,8]{2,1,0} parameter(0)
+      %p1 = f16[2,8,8]{2,1,0} parameter(1)
+      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+      %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+      ROOT %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:       [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
+    ; CHECK:       [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
+    ; CHECK:       ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
+    ; CHECK:              tuple([[DOT]], [[WORKSPACE]])
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:       ROOT [[DOT_MAIN:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[FUSION]]), index=0
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNotRoot) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[2,8,8]{2,1,0} parameter(0)
+      %p1 = f16[2,8,8]{2,1,0} parameter(1)
+      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+      %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+      ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %custom-call.1)
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:       ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]])
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandHasMultipleUsers) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[2,8,8]{2,1,0} parameter(0)
+      %p1 = f16[4,8,8]{2,1,0} parameter(1)
+      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[2:3], [0:8], [0:8]}
+      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+      %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+      ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %bitcast.41)
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[2:3], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[P0]], [[P1]])
+    ; CHECK-DAG:     kind=kCustom, calls=%address-computation,
+    ; CHECK-DAG:     backend_config={
+    ; CHECK-DAG:       "kind":"__custom_fusion",
+    ; CHECK-DAG:       "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK-DAG:     }
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK:       ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[B0]])
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandsHaveMultipleUsers) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[2,8,8]{2,1,0} parameter(0)
+      %p1 = f16[2,8,8]{2,1,0} parameter(1)
+      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+      %custom-call.0 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+
+      ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation{{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+    ; CHECK:     %address-computation{{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmSlicingNotParameter) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[4,8,8]{2,1,0} parameter(0)
+      %p1 = f16[2,8,8]{2,1,0} parameter(1)
+      %slice.12 = f16[2,8,8]{2,1,0} slice(%p0), slice={[0:2], [0:8], [0:8]}
+      %slice.13 = f16[1,8,8]{2,1,0} slice(%slice.12), slice={[1:2], [0:8], [0:8]}
+      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+      %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+      ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %custom-call.1)
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[2,8,8]{2,1,0} slice([[P0]]), slice={[0:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+    ; CHECK:       [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[S0]], [[P1]])
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:       ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]])
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNotContiguousSlice) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[2,8,8]{2,1,0} parameter(0)
+      %p1 = f16[2,8,8]{2,1,0} parameter(1)
+      %slice.13 = f16[1,4,6]{2,1,0} slice(%p0), slice={[1:2], [0:4], [0:6]}
+      %bitcast.41 = f16[4,6]{1,0} bitcast(%slice.13)
+      %slice.14 = f16[1,6,4]{2,1,0} slice(%p1), slice={[1:2], [0:6], [0:4]}
+      %bitcast.42 = f16[6,4]{1,0} bitcast(%slice.14)
+
+      ROOT %custom-call.1 = f16[4,4]{1,0} custom-call(%bitcast.41, %bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+    }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"),
+                            std::nullopt);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNonNoOpInSliceChain) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[2,8,8]{2,1,0} parameter(0)
+      %p1 = f16[2,8,8]{2,1,0} parameter(1)
+      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
+      %slice.14 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]}
+      %add.0 = f16[1,8,8]{2,1,0} add(%slice.13, %slice.14)
+      %bitcast.41 = f16[8,8]{1,0} bitcast(%add.0)
+      %slice.15 = f16[1,8,8]{2,1,0} slice(%p1), slice={[0:1], [0:8], [0:8]}
+      %slice.16 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+      %add.1 = f16[1,8,8]{2,1,0} add(%slice.15, %slice.16)
+      %bitcast.42 = f16[8,8]{1,0} bitcast(%add.1)
+
+      ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+    }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"),
+                            std::nullopt);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmDuplicateOperand) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main {
+      %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0)
+      %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0
+      %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1
+      %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0}
+      %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240),
+        custom_call_target="__cublas$gemm",
+        backend_config={
+          "gemm_backend_config":{
+            "alpha_real":1,
+            "beta":0,
+            "dot_dimension_numbers":{
+              "lhs_contracting_dimensions":["1"],
+              "rhs_contracting_dimensions":["0"],
+              "lhs_batch_dimensions":[],
+              "rhs_batch_dimensions":[]
+            },
+            "alpha_imag":0,
+            "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]},
+            "epilogue":"DEFAULT",
+            "lhs_stride":"20000",
+            "rhs_stride":"10000",
+            "grad_x":false,
+            "grad_y":false
+          }
+        }
+      %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0
+      %slice.26 = f32[100,100]{1,0} slice(%get-tuple-element.97), slice={[0:100], [0:100]}
+      ROOT %custom-call.17 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.26, %slice.26),
+        custom_call_target="__cublas$gemm",
+        backend_config={
+          "gemm_backend_config":{
+            "alpha_real":1,
+            "beta":0,
+            "dot_dimension_numbers":{
+              "lhs_contracting_dimensions":["1"],
+              "rhs_contracting_dimensions":["0"],
+              "lhs_batch_dimensions":[],
+              "rhs_batch_dimensions":[]
+            },
+            "alpha_imag":0,
+            "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]},
+            "epilogue":"DEFAULT",
+            "lhs_stride":"10000",
+            "rhs_stride":"10000",
+            "grad_x":false,
+            "grad_y":false
+          }
+        }
+  })";
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK:       [[P0:%[^ ]+]] = f32[200,100]{1,0} parameter(0)
+    ; CHECK:       [[S0:%[^ ]+]] = f32[100,100]{1,0} slice([[P0]]), slice={[0:100], [0:100]}
+    ; CHECK-NOT:   slice
+    ; CHECK:       [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[80000]{0}) custom-call([[S0]], [[S0]]),
+    ; CHECK:         custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[80000]{0}) fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmReverseOperandOrder) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[2,8,8]{2,1,0} parameter(1)
+      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
+      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+      %p1 = f16[2,8,8]{2,1,0} parameter(0)
+      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+      ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[0:1], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK-DAG:   [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[A1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]], [[A1]])
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmReverseOperandOrder2) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[2,8,8]{2,1,0} parameter(0)
+      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
+      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+      %p1 = f16[2,8,8]{2,1,0} parameter(1)
+      %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]}
+      %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14)
+
+      ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[0:1], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK-DAG:   [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1)
+    ; CHECK-DAG:   [[A1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]], [[A1]])
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandAliasingOutput) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0)
+      %get-tuple-element.287 = f32[100,100]{1,0} get-tuple-element(%p0), index=0
+      %get-tuple-element.288 = f32[100,100]{1,0} get-tuple-element(%p0), index=1
+      %concatenate.12 = f32[200,100]{1,0} concatenate(%get-tuple-element.287, %get-tuple-element.288), dimensions={0}
+      %slice.30 = f32[100,100]{1,0} slice(%concatenate.12), slice={[16:116], [0:100]}
+      %slice.34 = f32[100,100]{1,0} slice(%concatenate.12), slice={[99:199], [0:100]}
+      ROOT %cublas-gemm.15 = (f32[100,100]{1,0}, s8[120000]{0}) custom-call(%get-tuple-element.287, %slice.30, %slice.34),
+        custom_call_target="__cublas$gemm",
+        output_to_operand_aliasing={{0}: (2, {})},
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":1,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"10000",
+          "rhs_stride":"10000",
+          "grad_x":false,
+          "grad_y":false
+        }}
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK-DAG:   [[P2:%[^ ]+]] = f32[100,100]{1,0} parameter(2)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f32[100,100]{1,0} parameter(1)
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f32[200,100]{1,0} parameter(0)
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f32[100,100]{1,0} slice([[P0]]), slice={[16:116], [0:100]}
+    ; CHECK:       [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) custom-call([[P1]], [[S1]], [[P2]]),
+    ; CHECK:         custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       [[P:%[^ ]+]] = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0)
+    ; CHECK:       [[GTE0:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=0
+    ; CHECK:       [[GTE1:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=1
+    ; CHECK:       [[CONCAT:%[^ ]+]] = f32[200,100]{1,0} concatenate([[GTE0]], [[GTE1]]), dimensions={0}
+    ; CHECK:       [[S:%[^ ]+]] = f32[100,100]{1,0} slice([[CONCAT]]), slice={[99:199], [0:100]}
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) fusion([[CONCAT]], [[GTE0]], [[S]])
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandsFromSameSlice) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[2,8,8]{2,1,0} parameter(0)
+      %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]}
+      %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13)
+      %bitcast.42 = f16[8,8]{0,1} bitcast(%slice.13)
+
+      ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[0:1], [0:8], [0:8]}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{0,1} bitcast([[S0]])
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK-DAG:   [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]])
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
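+// Test-only typed-FFI handler: copies `src` into `dst` with a device-to-device
+// memcpy on the given stream. It is registered below as the
+// "__xla_test$$memcpy" custom-call target used by the SimpleCustomCall test.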
+static absl::Status Memcpy(se::Stream* stream, ffi::AnyBuffer src,
+                           ffi::AnyBuffer dst) {
+  se::DeviceMemoryBase dst_mem = dst.device_memory();
+  se::DeviceMemoryBase src_mem = src.device_memory();
+  return stream->MemcpyD2D(&dst_mem, src_mem, src_mem.size());
+}
+
+XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy,
+                       ffi::Ffi::Bind()
+                           .Ctx<ffi::Stream>()
+                           .Arg<ffi::AnyBuffer>()  // src
+                           .Arg<ffi::AnyBuffer>()  // dst
+);
+XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", "gpu",
+                         kMemcpy);
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCall) {
+  XlaBuilder b(TestName());
+  CustomCall(&b, "__xla_test$$memcpy",
+             /*operands=*/
+             {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0},
+                    {128}, {1})},
+             ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"",
+             /*has_side_effect=*/false,
+             /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
+             /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
+             /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
+  TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
+  xla::HloModuleConfig hlo_config(
+      xla::ProgramShape(computation.proto().host_program_shape()),
+      /*ignore_layouts=*/false);
+  DebugOptions debug_options = GetDebugOptionsForTest();
+  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+  hlo_config.set_debug_options(debug_options);
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
+                                        computation.proto(), hlo_config));
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK:       [[P0:%[^ ]+]] = f32[256]{0} parameter(0)
+    ; CHECK:       [[S0:%[^ ]+]] = f32[128]{0} slice([[P0]]), slice={[0:128]}
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[S0]]),
+    ; CHECK:              custom_call_target="__xla_test$$memcpy",
+    ; CHECK:              api_version=API_VERSION_TYPED_FFI
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %{{.*}} {
+    ; CHECK:       [[C0:%[^ ]+]] = f32[] constant(42)
+    ; CHECK:       [[BC:%[^ ]+]] = f32[256]{0} broadcast([[C0]])
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion([[BC]])
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
+                            expected);
+}
+
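+// No-op custom-call target for the legacy (untyped) custom-call API; it is
+// registered for "gpu" below and used by the *CustomCallLegacy and
+// UnalignedSlice tests.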
+void Callback_Void(se::gpu::GpuStreamHandle stream, void** buffers,
+                   const char* /*opaque*/, size_t /*opaque_len*/) {}
+
+XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Void, "gpu");
+
+TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCallLegacy) {
+  XlaBuilder b(TestName());
+  CustomCall(&b, "Callback_Void",
+             /*operands=*/
+             {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0},
+                    {128}, {1})},
+             ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"");
+  TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
+  xla::HloModuleConfig hlo_config(
+      xla::ProgramShape(computation.proto().host_program_shape()),
+      /*ignore_layouts=*/false);
+  DebugOptions debug_options = GetDebugOptionsForTest();
+  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+  hlo_config.set_debug_options(debug_options);
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
+                                        computation.proto(), hlo_config));
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK:       [[P0:%[^ ]+]] = f32[256]{0} parameter(0)
+    ; CHECK:       [[S0:%[^ ]+]] = f32[128]{0} slice([[P0]]), slice={[0:128]}
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[S0]]),
+    ; CHECK:              custom_call_target="Callback_Void"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %{{.*}} {
+    ; CHECK:       [[C0:%[^ ]+]] = f32[] constant(42)
+    ; CHECK:       [[BC:%[^ ]+]] = f32[256]{0} broadcast([[C0]])
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion([[BC]])
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
+                            expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, TupleSliceCustomCallLegacy) {
+  XlaBuilder b(TestName());
+  CustomCall(
+      &b, "Callback_Void",
+      /*operands=*/
+      {
+          Tuple(&b,
+                {
+                    Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}),
+                          {0, 0}, {4, 8}, {1, 1}),
+                    Broadcast(ConstantR0WithType(&b, F32, 2), {256}),
+                }),
+          Tuple(&b,
+                {
+                    Broadcast(ConstantR0WithType(&b, F32, 3), {1024}),
+                    Broadcast(ConstantR0WithType(&b, F32, 4), {8}),
+                }),
+      },
+      ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"");
+  TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
+  xla::HloModuleConfig hlo_config(
+      xla::ProgramShape(computation.proto().host_program_shape()),
+      /*ignore_layouts=*/false);
+  DebugOptions debug_options = GetDebugOptionsForTest();
+  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+  hlo_config.set_debug_options(debug_options);
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
+                                        computation.proto(), hlo_config));
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P0]]), slice={[0:4], [0:8]}
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f32[256]{0} parameter(1)
+    ; CHECK-DAG:   [[T0:%[^ ]+]] = (f32[4,8]{1,0}, f32[256]{0}) tuple([[S0]], [[P1]])
+    ; CHECK-DAG:   [[P2:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(2)
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[T0]], [[P2]]),
+    ; CHECK:              custom_call_target="Callback_Void"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %{{.*}} {
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion(
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
+                            expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, TupledOutputCustomCallLegacy) {
+  XlaBuilder b(TestName());
+  auto custom_call = CustomCall(
+      &b, "Callback_Void",
+      /*operands=*/
+      {
+          Tuple(&b,
+                {
+                    Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}),
+                          {0, 0}, {4, 8}, {1, 1}),
+                    Broadcast(ConstantR0WithType(&b, F32, 2), {256}),
+                }),
+          Tuple(&b,
+                {
+                    Broadcast(ConstantR0WithType(&b, F32, 3), {1024}),
+                    Broadcast(ConstantR0WithType(&b, F32, 4), {8}),
+                }),
+      },
+      ShapeUtil::MakeTupleShape({
+          ShapeUtil::MakeShape(F32, {8}),
+          ShapeUtil::MakeTupleShape({
+              ShapeUtil::MakeShape(F32, {128}),
+              ShapeUtil::MakeShape(F32, {256}),
+          }),
+          ShapeUtil::MakeShape(F32, {1024}),
+          ShapeUtil::MakeShape(F32, {4, 8}),
+      }),
+      /*opaque=*/"");
+  Tuple(&b, {GetTupleElement(GetTupleElement(custom_call, 1), 0),
+             GetTupleElement(custom_call, 2)});
+  TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
+  xla::HloModuleConfig hlo_config(
+      xla::ProgramShape(computation.proto().host_program_shape()),
+      /*ignore_layouts=*/false);
+  DebugOptions debug_options = GetDebugOptionsForTest();
+  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+  hlo_config.set_debug_options(debug_options);
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
+                                        computation.proto(), hlo_config));
+
+  const char* expected = R"(
+    ; CHECK:     %address-computation {{.*}} {
+    ; CHECK-DAG:   [[P2:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(2)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f32[256]{0} parameter(1)
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P0]]), slice={[0:4], [0:8]}
+    ; CHECK-DAG:   [[T0:%[^ ]+]] = (f32[4,8]{1,0}, f32[256]{0}) tuple([[S0]], [[P1]])
+    ; CHECK:       [[CC:%[^ ]+]] = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) custom-call([[T0]], [[P2]]),
+    ; CHECK:              custom_call_target="Callback_Void"
+    ; CHECK-DAG:   [[GTE0:%[^ ]+]] = f32[8]{0} get-tuple-element([[CC]]), index=0
+    ; CHECK-DAG:   [[GTE1:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) get-tuple-element([[CC]]), index=1
+    ; CHECK-DAG:   [[GTE2:%[^ ]+]] = f32[128]{0} get-tuple-element([[GTE1]]), index=0
+    ; CHECK-DAG:   [[GTE3:%[^ ]+]] = f32[256]{0} get-tuple-element([[GTE1]]), index=1
+    ; CHECK-DAG:   [[T1:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) tuple([[GTE2]], [[GTE3]])
+    ; CHECK-DAG:   [[GTE4:%[^ ]+]] = f32[1024]{0} get-tuple-element([[CC]]), index=2
+    ; CHECK-DAG:   [[GTE5:%[^ ]+]] = f32[4,8]{1,0} get-tuple-element([[CC]]), index=3
+    ; CHECK:       ROOT {{.*}} = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) tuple([[GTE0]], [[T1]], [[GTE4]], [[GTE5]])
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %{{.*}} {
+    ; CHECK:       [[FUSION:%[^ ]+]] = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK-DAG:   [[GTE6:%[^ ]+]] = f32[1024]{0} get-tuple-element([[FUSION]]), index=2
+    ; CHECK-DAG:   [[GTE7:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) get-tuple-element([[FUSION]]), index=1
+    ; CHECK-DAG:   [[GTE8:%[^ ]+]] = f32[128]{0} get-tuple-element([[GTE7]]), index=0
+    ; CHECK:       ROOT {{.*}} = (f32[128]{0}, f32[1024]{0}) tuple([[GTE8]], [[GTE6]])
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
+                            expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, UnalignedSlice) {
+  XlaBuilder b(TestName());
+  CustomCall(
+      &b, "Callback_Void",
+      /*operands=*/
+      {Slice(Broadcast(ConstantR0WithType(&b, S32, 42), {17}), {1}, {17}, {1})},
+      ShapeUtil::MakeShape(S32, {16}), /*opaque=*/"");
+  TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
+  xla::HloModuleConfig hlo_config(
+      xla::ProgramShape(computation.proto().host_program_shape()),
+      /*ignore_layouts=*/false);
+  DebugOptions debug_options = GetDebugOptionsForTest();
+  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+  hlo_config.set_debug_options(debug_options);
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto(
+                                        computation.proto(), hlo_config));
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo->ToString(), DynamicSliceFusionRewriter("gpu"),
+                            std::nullopt);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemm) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY main.9 {
+      p0 = f16[2,8,8]{2,1,0} parameter(0)
+      p1 = f16[2,8,8]{2,1,0} parameter(1)
+      c1_s32 = s32[] constant(1)
+      c0_s32 = s32[] constant(0)
+      slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+      bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
+      slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+      bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
+
+      ROOT custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
+    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(1)
+    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(2)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmWithWorkspace) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY main.9 {
+      p0 = f16[2,8,8]{2,1,0} parameter(0)
+      p1 = f16[2,8,8]{2,1,0} parameter(1)
+      c1_s32 = s32[] constant(1)
+      c0_s32 = s32[] constant(0)
+      slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+      bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
+      slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+      bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
+
+      ROOT custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
+    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(1)
+    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(2)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:       [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
+    ; CHECK:       [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
+    ; CHECK:       ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
+    ; CHECK:              tuple([[DOT]], [[WORKSPACE]])
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmWorkspaceIgnored) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY main.9 {
+      p0 = f16[2,8,8]{2,1,0} parameter(0)
+      p1 = f16[2,8,8]{2,1,0} parameter(1)
+      c1_s32 = s32[] constant(1)
+      c0_s32 = s32[] constant(0)
+      slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+      bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
+      slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+      bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
+
+      custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+      ROOT get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(custom-call.1), index=0
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
+    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(1)
+    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(2)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:       [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
+    ; CHECK:       [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
+    ; CHECK:       ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0})
+    ; CHECK:              tuple([[DOT]], [[WORKSPACE]])
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:       ROOT [[DOT_MAIN:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[FUSION]]), index=0
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmNotRoot) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY main.9 {
+      p0 = f16[2,8,8]{2,1,0} parameter(0)
+      p1 = f16[2,8,8]{2,1,0} parameter(1)
+      c1_s32 = s32[] constant(1)
+      c0_s32 = s32[] constant(0)
+      slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+      bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
+      slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+      bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
+
+      custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+      ROOT res = f16[8,8]{1,0} add(custom-call.1, custom-call.1)
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
+    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(1)
+    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(2)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:       ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]])
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemm) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY main.9 {
+      p0 = f16[1,8,8]{2,1,0} parameter(0)
+      p1 = f16[1,8,8]{2,1,0} parameter(1)
+      p2 = f16[4,8,8]{2,1,0} parameter(2)
+      c1_s32 = s32[] constant(1)
+      c0_s32 = s32[] constant(0)
+      bitcast.41 = f16[8,8]{1,0} bitcast(p0)
+      bitcast.42 = f16[8,8]{1,0} bitcast(p1)
+
+      custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+      bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
+      ROOT dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[8,8]{1,0} parameter(1)
+    ; CHECK-DAG:   [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(2)
+    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(3)
+    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(4)
+    ; CHECK-DAG:   [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]]),
+    ; CHECK-DAG:          custom_call_target="__cublas$gemm"
+    ; CHECK-DAG:   [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[CC]])
+    ; CHECK:       ROOT {{.*}} = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       ROOT [[FUSION:%[^ ]+]] = f16[4,8,8]{2,1,0} fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmNotRoot) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY main.9 {
+      p0 = f16[2,8,8]{2,1,0} parameter(0)
+      p1 = f16[2,8,8]{2,1,0} parameter(1)
+      p2 = f16[4,8,8]{2,1,0} parameter(2)
+      c1_s32 = s32[] constant(1)
+      c0_s32 = s32[] constant(0)
+      slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+      bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
+      slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+      bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
+
+      custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+      bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
+      dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
+      ROOT res = f16[4,8,8]{2,1,0} log(dus)
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
+    ; CHECK-DAG:   [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(4)
+    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(1)
+    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(2)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK-DAG:   [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
+    ; CHECK-DAG:          custom_call_target="__cublas$gemm"
+    ; CHECK-DAG:   [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[CC]])
+    ; CHECK:       ROOT {{.*}} = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       [[FUSION:%[^ ]+]] = f16[4,8,8]{2,1,0} fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:       ROOT {{.*}} = f16[4,8,8]{2,1,0} log([[FUSION]])
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmWithWorkspace) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY main.9 {
+      p0 = f16[2,8,8]{2,1,0} parameter(0)
+      p1 = f16[2,8,8]{2,1,0} parameter(1)
+      p2 = f16[4,8,8]{2,1,0} parameter(2)
+      c1_s32 = s32[] constant(1)
+      c0_s32 = s32[] constant(0)
+      slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+      bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
+      slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
+      bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
+
+      custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+
+      get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(custom-call.1), index=0
+      bitcast.43 = f16[1,8,8]{2,1,0} bitcast(get-tuple-element.0)
+      dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
+      get-tuple-element.1 = s8[256]{0} get-tuple-element(custom-call.1), index=1
+      ROOT tuple = (f16[4,8,8]{2,1,0}, s8[256]{0}) tuple(dus, get-tuple-element.1)
+    }
+  )";
+
+  const char* expected = R"(
+    ; CHECK:     address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
+    ; CHECK-DAG:   [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(4)
+    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(1)
+    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(2)
+    ; CHECK-DAG:   [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+    ; CHECK-DAG:   [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
+    ; CHECK-DAG:   [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
+    ; CHECK-DAG:   [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
+    ; CHECK:       [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]),
+    ; CHECK:              custom_call_target="__cublas$gemm"
+    ; CHECK:       [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
+    ; CHECK:       [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[DOT]])
+    ; CHECK:       [[DUS:%[^ ]+]] = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
+    ; CHECK:       [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
+    ; CHECK:       ROOT [[TUPLE:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0})
+    ; CHECK:              tuple([[DUS]], [[WORKSPACE]])
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       [[FUSION:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:       [[DUS_MAIN:%[^ ]+]] = f16[4,8,8]{2,1,0} get-tuple-element([[FUSION]]), index=0
+    ; CHECK:       [[WORKSPACE_MAIN:%[^ ]+]] = s8[256]{0} get-tuple-element([[FUSION]]), index=1
+    ; CHECK:       ROOT {{.*}} = (f16[4,8,8]{2,1,0}, s8[256]{0})
+    ; CHECK:              tuple([[DUS_MAIN]], [[WORKSPACE_MAIN]])
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmWorkspaceIgnored) {
+  const char* hlo = R"(
+    HloModule test
+
+    ENTRY %main.9 {
+      %p0 = f16[8,8]{1,0} parameter(0)
+      %p1 = f16[8,8]{1,0} parameter(1)
+      %p2 = f16[4,8,8]{2,1,0} parameter(2)
+      %c1_s32 = s32[] constant(1)
+      %c0_s32 = s32[] constant(0)
+
+      %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%p0, %p1),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+      %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0
+      %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0)
+      ROOT %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32)
+    })";
+
+  const char* expected = R"(
+    ; CHECK:     address-computation {{.*}} {
+    ; CHECK-DAG:   [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(0)
+    ; CHECK-DAG:   [[P1:%[^ ]+]] = f16[8,8]{1,0} parameter(1)
+    ; CHECK-DAG:   [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(2)
+    ; CHECK-DAG:   [[C1:%[^ ]+]] = s32[] parameter(3)
+    ; CHECK-DAG:   [[C0:%[^ ]+]] = s32[] parameter(4)
+    ; CHECK-DAG:   [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[P0]], [[P1]]),
+    ; CHECK-DAG:          custom_call_target="__cublas$gemm"
+    ; CHECK-DAG:   [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0
+    ; CHECK-DAG:   [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[DOT]])
+    ; CHECK-DAG:   [[DUS:%[^ ]+]] = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
+    ; CHECK-DAG:   [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1
+    ; CHECK:       ROOT [[TUPLE:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0})
+    ; CHECK:              tuple([[DUS]], [[WORKSPACE]])
+    ; CHECK:     }
+
+    ; CHECK:     ENTRY %main{{.*}} {
+    ; CHECK:       [[FUSION:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion
+    ; CHECK:         kind=kCustom, calls=%address-computation,
+    ; CHECK:         backend_config={
+    ; CHECK:           "kind":"__custom_fusion",
+    ; CHECK:           "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}
+    ; CHECK:         }
+    ; CHECK:       ROOT [[DOT_MAIN:%[^ ]+]] = f16[4,8,8]{2,1,0} get-tuple-element([[FUSION]]), index=0
+    ; CHECK:     }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSConstantOffset) {
+  const char* hlo = R"(
+  HloModule test, replica_count=2
+
+  add {
+    param_0 = f16[] parameter(0)
+    param_1 = f16[] parameter(1)
+    ROOT add.1 = f16[] add(param_0, param_1)
+  }
+
+  ENTRY main.9 {
+    param_0 = f16[128,128]{1,0} parameter(0)
+    param_1 = f16[128,128]{1,0} parameter(1)
+    constant_20 = u32[] constant(20)
+    constant_0 = u32[] constant(0)
+    reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+    ROOT loop_dynamic_update_slice_fusion = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, constant_20, constant_0)
+  }
+  )";
+
+  const char* expected = R"(
+  // CHECK: %address-computation{{.+}} {
+  // CHECK:   %[[RS:.+]] = f16[64,128]{1,0} reduce-scatter({{.+}})
+  // CHECK:   ROOT %{{.+}} = f16[128,128]{1,0} dynamic-update-slice(%{{.+}}, %[[RS]], %{{.+}}, %{{.+}})
+  // CHECK: }
+  // CHECK: ENTRY {{.+}} {
+  // CHECK-NOT: reduce-scatter
+  // CHECK:   ROOT %{{.+}} = {{.+}} fusion(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}), kind=kCustom, calls=%address-computation, {{.+}}"name":"dynamic_address_computation"
+  // CHECK: }
+  )";
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSParameterOffset) {
+  const char* hlo = R"(
+  HloModule test, replica_count=2
+
+  add.clone {
+    x.1 = f16[] parameter(0)
+    y.1 = f16[] parameter(1)
+    ROOT add.462 = f16[] add(x.1, y.1)
+  }
+
+  ENTRY %main.9 {
+    param_0 = f16[128,128]{1,0} parameter(0)
+    param_1 = f16[128,128]{1,0} parameter(1)
+    param_2 = u32[] parameter(2)
+    constant_0 = u32[] constant(0)
+    reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add.clone
+    ROOT dynamic-update-slice = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, param_2, constant_0)
+  })";
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"),
+                            std::nullopt);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSLoopIterationOffset) {
+  const char* hlo = R"(
+  HloModule jit_scan, replica_count=2
+
+  add {
+    param_0 = f32[] parameter(0)
+    param_1 = f32[] parameter(1)
+    ROOT add.6 = f32[] add(param_0, param_1)
+  }
+
+  Body {
+    arg_tuple.1 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0)
+    get-tuple-element.5 = s32[] get-tuple-element(arg_tuple.1), index=0
+    constant.1 = s32[] constant(1)
+    add.7 = s32[] add(get-tuple-element.5, constant.1)
+    get-tuple-element.6 = f32[128,128]{1,0} get-tuple-element(arg_tuple.1), index=3
+    get-tuple-element.7 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.1), index=2
+    reduce-scatter.0 = f32[64,128]{1,0} reduce-scatter(get-tuple-element.6), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+    bitcast.63 = f32[1,64,128]{2,1,0} bitcast(reduce-scatter.0)
+    constant.2 = s32[] constant(0)
+    compare.4 = pred[] compare(get-tuple-element.5, constant.2), direction=LT
+    constant.3 = s32[] constant(128)
+    add.8 = s32[] add(get-tuple-element.5, constant.3)
+    select.2 = s32[] select(compare.4, add.8, get-tuple-element.5)
+    dynamic-update-slice.2 = f32[128,128,128]{2,1,0} dynamic-update-slice(get-tuple-element.7, bitcast.63, select.2, constant.2, constant.2)
+    ROOT tuple.1 = tuple(add.7, get-tuple-element.6, dynamic-update-slice.2, get-tuple-element.6)
+  } // Body
+
+  Cond {
+    arg_tuple.0 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0)
+    get-tuple-element.4 = s32[] get-tuple-element(arg_tuple.0), index=0
+    constant.0 = s32[] constant(128)
+    ROOT compare.5 = pred[] compare(get-tuple-element.4, constant.0), direction=LT
+  }
+
+  ENTRY main.55 {
+    Arg_2.3 = f32[128,128,128]{2,1,0} parameter(2)
+    constant.4 = s32[] constant(0)
+    Arg_1.2 = f32[128,128]{1,0} parameter(1)
+    constant.5 = f32[] constant(0)
+    broadcast.1 = f32[128,128,128]{2,1,0} broadcast(constant.5), dimensions={}
+    Arg_0.1 = f32[128,128]{1,0} parameter(0)
+    tuple = tuple(constant.4, Arg_1.2, broadcast.1, Arg_0.1)
+    while = while(tuple), condition=Cond, body=Body, backend_config={"known_trip_count":{"n":"128"}}
+    get-tuple-element.50 = f32[128,128]{1,0} get-tuple-element(while), index=1
+    get-tuple-element.51 = f32[128,128,128]{2,1,0} get-tuple-element(while), index=2
+    ROOT tuple.54 = (f32[128,128]{1,0}, f32[128,128,128]{2,1,0}) tuple(get-tuple-element.50, get-tuple-element.51)
+  })";
+  const char* expected = R"(
+  // CHECK: %address-computation{{.*}}{
+  // CHECK:   {{.+}} = {{.*}}reduce-scatter({{.+}})
+  // CHECK:   {{.+}} = {{.*}}dynamic-update-slice({{.+}})
+  // CHECK: }
+  // CHECK: Body{{.+}}{
+  // CHECK-NOT: {{.+}} = {{.*}}reduce-scatter({{.+}})
+  // CHECK:   {{.+}} = {{.+}}fusion({{.+}}), kind=kCustom, calls=%address-computation{{.*}}"name":"dynamic_address_computation"
+  // CHECK: }
+  )";
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLoopIteration) {
+  const char* hlo = R"(
+  HloModule test
+
+  %Body {
+    param = (f16[1,8,8]{2,1,0}, f16[1,8,8]{2,1,0}, f16[4,8,8]{2,1,0}, u32[]) parameter(0)
+    p0 = get-tuple-element(param), index=0
+    p1 = get-tuple-element(param), index=1
+    p2 = get-tuple-element(param), index=2
+    loop_iter = get-tuple-element(param), index=3
+
+    bitcast.41 = f16[8,8]{1,0} bitcast(p0)
+    bitcast.42 = f16[8,8]{1,0} bitcast(p1)
+    custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), custom_call_target="__cublas$gemm", backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+          "lhs_contracting_dimensions":["1"],
+          "rhs_contracting_dimensions":["0"],
+          "lhs_batch_dimensions":[],
+          "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+      }}
+    bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
+    c0 = u32[] constant(0)
+    c_trip_count = u32[] constant(11)
+    compare = pred[] compare(loop_iter, c0), direction=LT
+    add = u32[] add(loop_iter, c_trip_count)
+    offset = u32[] select(compare, add, loop_iter)
+    dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, offset, c0, c0)
+    c1 = u32[] constant(1)
+    add2 = u32[] add(loop_iter, c1)
+    ROOT tuple = tuple(p0, p1, dus, u32[] add2)
+  }
+
+  %Cond {
+    %param.1 = (f16[1,8,8]{2,1,0}, f16[1,8,8]{2,1,0}, f16[4,8,8]{2,1,0}, u32[]) parameter(0)
+    %i.1 = u32[] get-tuple-element(%param.1), index=3
+    %trip_count = u32[] constant(11)
+    ROOT %done = pred[] compare(u32[] %i.1, u32[] %trip_count), direction=LT
+  }
+
+  ENTRY %test {
+    %p0.1 = f16[1,8,8]{2,1,0} parameter(0)
+    %p1.1 = f16[1,8,8]{2,1,0} parameter(1)
+    %p2.1 = f16[4,8,8]{2,1,0} parameter(2)
+    %c0.1 = u32[] constant(0)
+    %initial_tuple = tuple(%p0.1, %p1.1, %p2.1, u32[] %c0.1)
+    ROOT %while = while(%initial_tuple), condition=%Cond, body=%Body, backend_config={"known_trip_count":{"n":"11"}}
+  })";
+
+  const char* expected = R"(
+  // CHECK: %Body{{.+}}{
+  // CHECK:   %[[PARAM:.+]] = {{.+}} parameter(0)
+  // CHECK:   %[[LOOP_ITER:.+]] = u32[] get-tuple-element(%[[PARAM]]), index=3
+  // CHECK:   %[[OFFSET:.+]] = u32[] select({{.+}})
+  // CHECK:   %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, {{.+}}, {{.+}}, %[[OFFSET]], %{{.+}}), kind=kCustom, calls=%address-computation, {{.+}}"name":"dynamic_address_computation"
+  // CHECK:   ROOT %tuple = {{.+}} tuple(%{{.+}}, %{{.+}}, %[[ADDRESS_COMPUTATION]], %{{.+}})
+  // CHECK: }
+  // CHECK: ENTRY %test{{.+}}{
+  // CHECK:   ROOT %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"11"}}
+  // CHECK: }
+  )";
+
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmParameterOffset) {
+  const char* hlo = R"(
+      HloModule test
+
+    ENTRY main.9 {
+      p0 = f16[1,8,8]{2,1,0} parameter(0)
+      p1 = f16[1,8,8]{2,1,0} parameter(1)
+      p2 = f16[4,8,8]{2,1,0} parameter(2)
+      p3 = s32[] parameter(3)
+      c1_s32 = s32[] constant(1)
+      c0_s32 = s32[] constant(0)
+      bitcast.41 = f16[8,8]{1,0} bitcast(p0)
+      bitcast.42 = f16[8,8]{1,0} bitcast(p1)
+
+      custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
+        custom_call_target="__cublas$gemm",
+        backend_config={"gemm_backend_config":{
+          "alpha_real":1,
+          "beta":0,
+          "dot_dimension_numbers":{
+            "lhs_contracting_dimensions":["1"],
+            "rhs_contracting_dimensions":["0"],
+            "lhs_batch_dimensions":[],
+            "rhs_batch_dimensions":[]
+          },
+          "alpha_imag":0,
+          "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
+          "epilogue":"DEFAULT",
+          "lhs_stride":"64",
+          "rhs_stride":"64",
+          "grad_x":false,
+          "grad_y":false
+        }}
+      bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
+      ROOT dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, p3, c0_s32, c0_s32)
+    })";
+
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"),
+                            std::nullopt);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLaxScan) {
+  const char* hlo = R"(
+  HloModule lax_scan
+
+  // This is the HLO generated for the following:
+  // 
+  // inp = jax.random.uniform(jax.random.key(128), (128, 128, 128))
+  // init = jnp.identity(128)
+  // ans = jax.lax.scan(lambda carry, x : (init, x@carry), init, inp)
+
+  Body {
+    arg_tuple.15 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0)
+    get-tuple-element.16 = s32[] get-tuple-element(arg_tuple.15), index=0
+    constant.21 = s32[] constant(1)
+    add.2 = s32[] add(get-tuple-element.16, constant.21)
+    get-tuple-element.30 = f32[128,128]{1,0} get-tuple-element(arg_tuple.15), index=4
+    get-tuple-element.18 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.15), index=2
+    get-tuple-element.19 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.15), index=3
+    constant.23 = s32[] constant(0)
+    compare.2 = pred[] compare(get-tuple-element.16, constant.23), direction=LT
+    constant.22 = s32[] constant(128)
+    add.3 = s32[] add(get-tuple-element.16, constant.22)
+    select.1 = s32[] select(compare.2, add.3, get-tuple-element.16)
+    dynamic-slice.1 = f32[1,128,128]{2,1,0} dynamic-slice(get-tuple-element.19, select.1, constant.23, constant.23), dynamic_slice_sizes={1,128,128}
+    bitcast.72 = f32[128,128]{1,0} bitcast(dynamic-slice.1)
+    get-tuple-element.17 = f32[128,128]{1,0} get-tuple-element(arg_tuple.15), index=1
+    custom-call.1 = (f32[128,128]{1,0}, s8[131072]{0}) custom-call(bitcast.72, get-tuple-element.17), custom_call_target="__cublas$gemm"
+    get-tuple-element = f32[128,128]{1,0} get-tuple-element(custom-call.1), index=0
+    bitcast.77 = f32[1,128,128]{2,1,0} bitcast(get-tuple-element)
+    dynamic-update-slice.1 = f32[128,128,128]{2,1,0} dynamic-update-slice(get-tuple-element.18, bitcast.77, select.1, constant.23, constant.23)
+    ROOT tuple.38 = tuple(add.2, get-tuple-element.30, dynamic-update-slice.1, get-tuple-element.19, get-tuple-element.30)
+  } // Body
+
+  Cond {
+    arg_tuple.40 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0)
+    get-tuple-element.41 = s32[] get-tuple-element(arg_tuple.40), index=0
+    constant.46 = s32[] constant(128)
+    ROOT compare.3 = pred[] compare(get-tuple-element.41, constant.46), direction=LT
+  }
+
+  ENTRY main {
+    constant.4 = s32[] constant(0)
+    Arg_1.2 = f32[128,128]{1,0} parameter(1)
+    constant.5 = f32[] constant(0)
+    broadcast.1 = f32[128,128,128]{2,1,0} broadcast(constant.5), dimensions={}
+    Arg_2.3 = f32[128,128,128]{2,1,0} parameter(2)
+    Arg_0.1 = f32[128,128]{1,0} parameter(0)
+    tuple.7 = tuple(constant.4, Arg_1.2, broadcast.1, Arg_2.3, Arg_0.1)
+    while.48 = while(tuple.7), condition=Cond, body=Body, backend_config={"known_trip_count":{"n":"128"}}
+    get-tuple-element.50 = f32[128,128]{1,0} get-tuple-element(while.48), index=1
+    get-tuple-element.51 = f32[128,128,128]{2,1,0} get-tuple-element(while.48), index=2
+    ROOT tuple.54 = (f32[128,128]{1,0}, f32[128,128,128]{2,1,0}) tuple(get-tuple-element.50, get-tuple-element.51)
+  } // main
+
+)";
+  auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  const char* expected = R"(
+  // CHECK: %address-computation{{.*}} {{.+}} {
+  // CHECK:   {{.+}} = {{.+}}dynamic-slice
+  // CHECK:   {{.+}} = {{.+}}custom-call
+  // CHECK:   {{.+}} = {{.+}}dynamic-update-slice
+  // CHECK: }
+  // CHECK: %Body{{.+}}{
+  // CHECK:   %[[PARAM:.+]] = {{.+}} parameter(0)
+  // CHECK:   %[[LOOP_ITER:.+]] = s32[] get-tuple-element(%[[PARAM]]), index=0
+  // CHECK:   %[[OFFSET:.+]] = s32[] select({{.+}})
+  // CHECK:   %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, %[[OFFSET]], %{{.+}}), kind=kCustom, calls=%address-computation{{.+}}"name":"dynamic_address_computation"
+  // CHECK:   %[[GTE:.+]] = {{.+}} get-tuple-element(%[[ADDRESS_COMPUTATION]]), index=0
+  // CHECK:   ROOT %{{.+}} = {{.+}} tuple(%{{.+}}, %[[GTE]], %{{.+}})
+  // CHECK: }
+  // CHECK: ENTRY %main{{.+}}{
+  // CHECK:   %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"128"}}
+  // CHECK: }
+  )";
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc
new file mode 100644
index 0000000..5a09bf5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc
@@ -0,0 +1,327 @@
+/* Copyright 2016 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/fusion_merger.h"
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/model/gpu_performance_model.h"
+#include "xla/service/gpu/model/gpu_performance_model_base.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/hlo_graph_dumper.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/status.h"
+
+namespace xla {
+namespace gpu {
+
+// For each fusion F, attempts to fuse F into *all* of F's users; if F cannot
+// be fused into every one of its users, it is not fused at all.
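+//
+// A minimal sketch of the rewrite, with illustrative instruction names (not
+// taken from any real module), assuming the users are themselves loop fusions:
+//
+//   f = fusion(p0), calls=comp_f
+//   g = fusion(f, p1), calls=comp_g    ==>  g = fusion(p0, p1), calls=comp_g'
+//   h = fusion(f, p2), calls=comp_h    ==>  h = fusion(p0, p2), calls=comp_h'
+//
+// where comp_g' and comp_h' have comp_f merged into them, and f is removed
+// afterwards, which is why the merge only happens when every user can absorb
+// it.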
+class FusionInstructionMerger {
+ public:
+  explicit FusionInstructionMerger(
+      HloComputation* computation, const se::DeviceDescription& gpu_device_info,
+      HloCostAnalysis::ShapeSizeFunction shape_size_function)
+      : computation_(computation),
+        shape_size_function_(shape_size_function),
+        gpu_device_info_(gpu_device_info),
+        dump_fusion_visualization_(computation->parent()
+                                       ->config()
+                                       .debug_options()
+                                       .xla_dump_fusion_visualization()) {}
+
+  absl::Status Run();
+
+  bool changed() const { return changed_; }
+
+ private:
+  FusionDecision ShouldFuse(HloInstruction* producer);
+  absl::Status FuseIntoAllUsers(HloInstruction* producer);
+
+  HloComputation* computation_;
+  HloCostAnalysis::ShapeSizeFunction shape_size_function_;
+  // Many cheap checks can prevent fusion merging, so postpone running the full
+  // HLO cost analysis of the computation; it may not be needed at all.
+  std::optional<GpuHloCostAnalysis> cost_analysis_;
+  FusionInfoCache fusion_info_cache_;
+  const se::DeviceDescription& gpu_device_info_;
+  bool changed_ = false;
+  bool dump_fusion_visualization_ = false;
+
+  // Fusion instruction merge stats.
+  int total_visited_ = 0;
+  int total_merged_ = 0;
+  int num_fail_no_users_ = 0;
+  int num_fail_not_loop_fusion_ = 0;
+  int num_fail_merge_all_users_ = 0;
+  int num_fail_inefficient_fusion_emitter_ = 0;
+  int num_fail_fusion_too_large_ = 0;
+  int num_fail_uncoalesced_read_ = 0;
+  int num_fail_slower_if_fused_ = 0;
+
+  FusionInstructionMerger(const FusionInstructionMerger&) = delete;
+  FusionInstructionMerger& operator=(const FusionInstructionMerger&) = delete;
+};
+
+absl::Status FusionInstructionMerger::FuseIntoAllUsers(
+    HloInstruction* producer) {
+  // Merge the fused instructions from 'producer' into each of its users.
+  std::vector<HloInstruction*> users = producer->users();
+  for (HloInstruction* user : users) {
+    if (dump_fusion_visualization_) {
+      RegisterFusionState(
+          *computation_,
+          absl::StrCat("About to fuse |", producer->name(), "| into |",
+                       user->name(), "| inside FusionMerger"),
+          /*consumer=*/*user,
+          /*producer=*/producer);
+    }
+
+    TF_RETURN_IF_ERROR(cost_analysis_->RemoveInstruction(user));
+
+    // Wrap consumers which are not fusions first.
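+    // (e.g. an unfused `add` user is first wrapped into a fusion containing
+    // just that `add`, so that MergeFusionInstruction below can clone the
+    // producer's fused instructions into it).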
+    HloInstruction* consumer = user;
+    if (consumer->opcode() != HloOpcode::kFusion) {
+      consumer = computation_->AddInstruction(HloInstruction::CreateFusion(
+          user->shape(), ChooseFusionKind(*producer, *user), user));
+      TF_CHECK_OK(computation_->ReplaceInstruction(user, consumer));
+    }
+
+    consumer->MergeFusionInstruction(producer);
+    TF_RETURN_IF_ERROR(cost_analysis_->RevisitInstruction(consumer));
+    fusion_info_cache_.Invalidate(consumer);
+
+    if (dump_fusion_visualization_) {
+      RegisterFusionState(*computation_,
+                          absl::StrCat("Fused |", producer->name(), "| into |",
+                                       user->name(), "| inside FusionMerger"),
+                          *consumer);
+    }
+
+    changed_ = true;
+  }
+
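+  // Every user has absorbed its own copy of 'producer', so 'producer' must now
+  // be dead and can be dropped from the computation and the cost analysis.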
+  CHECK_EQ(0, producer->user_count()) << producer->ToString();
+  TF_RETURN_IF_ERROR(computation_->RemoveInstruction(producer));
+  TF_RETURN_IF_ERROR(cost_analysis_->RemoveInstruction(producer));
+  fusion_info_cache_.Invalidate(producer);
+  VLOG(2) << "Merged fusion instruction: " << producer->name()
+          << " into users { "
+          << absl::StrJoin(users, ", ",
+                           [](std::string* out, HloInstruction* user) {
+                             absl::StrAppend(out, user->name());
+                           })
+          << " }";
+  return absl::OkStatus();
+}
+
+absl::Status FusionInstructionMerger::Run() {
+  for (HloInstruction* producer : computation_->MakeInstructionPostOrder()) {
+    if (producer->opcode() != HloOpcode::kFusion) {
+      continue;
+    }
+    FusionDecision should_fuse = ShouldFuse(producer);
+    if (should_fuse) {
+      TF_RETURN_IF_ERROR(FuseIntoAllUsers(producer));
+      ++total_merged_;
+    } else {
+      VLOG(3) << "Not fusing fusion |" << producer->name()
+              << "| with all of it's users due to: " << should_fuse.Explain();
+      if (dump_fusion_visualization_ && !producer->users().empty()) {
+        RegisterFusionState(
+            *computation_,
+            absl::StrCat(
+                "Not fusing fusion |", producer->name(),
+                "| into all of its users due to: ", should_fuse.Explain()),
+            // Just pick any consumer, since we are trying to merge into all.
+            /*consumer=*/*producer->users()[0],
+            /*producer=*/producer);
+      }
+    }
+  }
+
+  VLOG(1) << "FusionInstructionMerger EXIT"
+          << " computation: " << computation_->name()
+          << " total_visited: " << total_visited_
+          << " total_merged: " << total_merged_ << " merge failures { "
+          << " no_users: " << num_fail_no_users_
+          << " not_loop_fusion: " << num_fail_not_loop_fusion_
+          << " merge_all_users: " << num_fail_merge_all_users_
+          << " uncoalesced_read: " << num_fail_uncoalesced_read_
+          << " inefficient_fusion_emitter: "
+          << num_fail_inefficient_fusion_emitter_
+          << " slower_if_fused: " << num_fail_slower_if_fused_
+          << " fusion_too_large: " << num_fail_fusion_too_large_ << " }";
+  return absl::OkStatus();
+}
+
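+// Returns true if the instructions that physically transpose data account for
+// at least half of the fusion's output elements. For example (a made-up case
+// for illustration), a loop fusion whose root is a layout-changing copy of its
+// full [16,16,256] input scores 1.0 and is treated as transposing most of its
+// data.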
+bool TransposesMostData(const HloInstruction& fusion) {
+  float score = 0;
+
+  for (const HloInstruction* instr : fusion.fused_instructions()) {
+    if (IsPhysicallyTransposing(*instr)) {
+      score += 1.0 * ShapeUtil::ElementsInRecursive(instr->shape()) /
+               ShapeUtil::ElementsInRecursive(fusion.shape());
+      if (score >= 0.5) {
+        VLOG(3) << fusion.ToString() << " transpose ratio " << score
+                << " exceeds the 0.5 threshold";
+        return true;
+      }
+    }
+  }
+
+  return false;
+}
+
+FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) {
+  ++total_visited_;
+
+  VLOG(4) << "Considering producer " << producer->name();
+
+  // Skip 'producer' instruction if there are no users into which we can
+  // merge.
+  if (producer->users().empty()) {
+    ++num_fail_no_users_;
+    return "fusion has no users";
+  }
+
+  // Skip 'producer' instruction if it is not a loop fusion. Library fusion
+  // instructions match specific patterns, so they shouldn't be further fused.
+  // Input fusion instructions need to be rooted at a particular HLO (e.g.
+  // kReduce), so they shouldn't be further fused either.
+  if (!producer->IsLoopFusion()) {
+    ++num_fail_not_loop_fusion_;
+    return "not a loop fusion";
+  }
+
+  auto producer_hero = GetRealHeroForMultiOutputFusion(*producer);
+
+  bool has_reduction_user = false;
+  for (const HloInstruction* user : producer->users()) {
+    if (user->opcode() == HloOpcode::kBitcast) {
+      ++num_fail_merge_all_users_;
+      return "not fusing bitcast ops";
+    }
+    if (user->IsCustomFusion()) {
+      ++num_fail_merge_all_users_;
+      return "not fusing custom fusions";
+    }
+    auto consumer_hero = GetRealHeroForMultiOutputFusion(*user);
+    if (auto compatible =
+            FusionHeroesAreCompatible(producer_hero, consumer_hero);
+        !compatible) {
+      return compatible;
+    }
+    FusionDecision fusible = IsProducerConsumerFusible(*producer, *user);
+    if (!fusible) {
+      ++num_fail_merge_all_users_;
+      VLOG(9) << user->ToString();
+      return fusible;
+    }
+    if (IsInputFusibleReduction(*user)) {
+      has_reduction_user = true;
+    }
+  }
+
+  // We do not want to worsen a reduction's memory access pattern by connecting
+  // it to a producer which transposes most of its data.
+  if (has_reduction_user && TransposesMostData(*producer)) {
+    ++num_fail_uncoalesced_read_;
+    return "would read mostly uncoalesced";
+  }
+
+  for (const HloInstruction* user : producer->users()) {
+    // Skip 'producer' if merging it into any one of its users would make the
+    // resulting fusion use too much shared memory or registers.
+    FusionDecision fits = FusionFitsInBudget(
+        *user, *producer, gpu_device_info_,
+        /*is_consumer_producer_fusion=*/true, &fusion_info_cache_);
+    if (!fits) {
+      ++num_fail_fusion_too_large_;
+      return fits;
+    }
+  }
+
+  if (!cost_analysis_) {
+    VLOG(2) << "Running full HLO cost analysis for " << computation_->name();
+    cost_analysis_.emplace(
+        GpuHloCostAnalysis::Options{shape_size_function_,
+                                    /*per_second_rates=*/{},
+                                    /*count_multiple_input_accesses=*/true},
+        gpu_device_info_);
+    TF_CHECK_OK(computation_->Accept(&cost_analysis_.value()));
+  }
+
+  for (const HloInstruction* user : producer->users()) {
+    if (cost_analysis_->ProducerConsumerMergedTooLarge(*producer, *user)) {
+      ++num_fail_inefficient_fusion_emitter_;
+      return FusionDecision{} << "if merged with " << user->name()
+                              << " will generate huge IR";
+    }
+  }
+
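+  // Finally, ask the performance model whether duplicating 'producer' into all
+  // of its users is expected to be at least as fast as keeping it as a
+  // separate fusion; merge only if it is.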
+  GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
+      producer, gpu_device_info_, &*cost_analysis_,
+      GpuPerformanceModelOptions::Default(), producer->users());
+  if (t.time_fused > t.time_unfused) {
+    ++num_fail_slower_if_fused_;
+    return "will execute slower if fused";
+  }
+
+  return {};
+}
+
+absl::StatusOr<bool> FusionMerger::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  VLOG(1) << "FusionMerger for module: " << module->name();
+  for (auto* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    VLOG(9) << "Before running FusionInstructionMerger for computation: "
+            << computation->name();
+    XLA_VLOG_LINES(9, computation->ToString());
+
+    FusionInstructionMerger fusion_merger(computation, gpu_device_info_,
+                                          shape_size_function_);
+    TF_RETURN_IF_ERROR(fusion_merger.Run());
+    changed |= fusion_merger.changed();
+
+    VLOG(9) << "After running FusionInstructionMerger for computation: "
+            << computation->name() << " changed: " << changed;
+    XLA_VLOG_LINES(9, computation->ToString());
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_merger.h b/third_party/xla/xla/service/gpu/transforms/fusion_merger.h
new file mode 100644
index 0000000..15ea960
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger.h
@@ -0,0 +1,85 @@
+/* Copyright 2016 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_FUSION_MERGER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_FUSION_MERGER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// An HLO pass that attempts to merge fusion instructions to reduce memory
+// bandwidth requirements and kernel launch overhead.
+//
+// Consider the example below. On the left-hand side, op A is the producer and
+// ops B and C are its consumers. FusionMerger duplicates producer ops and fuses
+// them into all consumers. The result is depicted on the right-hand side below.
+//
+//        p                    p
+//        |                  /   \
+//        v                 /     \
+//        A            +fusion+  +fusion+
+//      /   \          |  A'  |  |  A"  |
+//     |     |         |  |   |  |  |   |
+//     v     v         |  v   |  |  v   |
+//     B     C         |  B   |  |  C   |
+//                     +------+  +------+
+//
+// Op A has been cloned twice and fused with B and C. The number of kernel
+// launches is reduced from 3 to 2, and the memory bandwidth requirements may
+// be reduced as well: we trade 1 read of input(A), 1 write of output(A), and
+// 2 reads of output(A) for 2 reads of input(A). In general, the achievable
+// savings in memory bandwidth depend on the relative sizes of the data read
+// and written and on the number of consumers. The FusionMerger pass takes this
+// into account when making fusion decisions.
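+//
+// For example, if input(A) is 4 MiB and output(A) is 64 MiB (sizes made up
+// purely for illustration), the unfused graph moves 4 + 64 + 2*64 = 196 MiB
+// for A and its consumers' reads of A, while the merged graph moves only
+// 2*4 = 8 MiB.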
+//
+// The pass traverses the HLO module in post-order (defs before uses).
+// Fusion instructions are merged into their users if some conditions are met:
+// * The result of merging the fusion instruction into its users would not
+//   increase bytes transferred.
+// * Producer ops are fusible with _all_ consumers. If they are not fusible
+//   with even one consumer, they won't be fused at all.
+// * Producers are kLoop fusion ops.
+//
+// None of these restrictions are necessary for correctness. In fact, lifting
+// the latter two could be beneficial.
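+//
+// A minimal usage sketch (illustrative only; `gpu_device_info` and
+// `ShapeSizeBytesFunction` stand in for whatever the caller already has, as in
+// the unit tests):
+//
+//   FusionMerger merger(gpu_device_info, ShapeSizeBytesFunction());
+//   TF_ASSIGN_OR_RETURN(bool changed, merger.Run(module));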
+
+class FusionMerger : public HloModulePass {
+ public:
+  explicit FusionMerger(const se::DeviceDescription& d,
+                        HloCostAnalysis::ShapeSizeFunction f)
+      : gpu_device_info_(d), shape_size_function_(f) {}
+  absl::string_view name() const override { return "fusion_merger"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  se::DeviceDescription gpu_device_info_;
+  HloCostAnalysis::ShapeSizeFunction shape_size_function_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_FUSION_MERGER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc
new file mode 100644
index 0000000..5068f65
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc
@@ -0,0 +1,1170 @@
+/* Copyright 2016 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/fusion_merger.h"
+
+#include <cstdint>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class FusionMergerTest : public HloTestBase {
+  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
+    return [&](const Shape& shape) {
+      constexpr int64_t kPointerSize = 8;
+      return ShapeUtil::ByteSizeOf(shape, kPointerSize);
+    };
+  }
+
+ public:
+  FusionMerger fusion_merger_{TestGpuDeviceInfo::RTXA6000DeviceInfo(),
+                              ShapeSizeBytesFunction()};
+  FusionMergerTest() : HloTestBase() {}
+};
+
+// Tests that we can merge a fusion instruction that is below the threshold.
+//
+// Computation after the fusion merger pass (fusion.4 is merged into fusion.5
+// and fusion.6):
+//                     Param
+//                   /   |   \
+//          fusion.3 fusion.5 fusion.6
+//                   \   |   /
+//                     Tuple
+//
+TEST_F(FusionMergerTest, MergeSharedFusionInstruction) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule MergeSharedFusionInstruction
+
+comp.3 {
+  constant.param_0 = f32[4]{0} parameter(0)
+  param.param_1.2 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(1)
+  get-tuple-element.6 = f32[4]{0} get-tuple-element(param.param_1.2), index=0
+  ROOT add.7 = f32[4]{0} add(constant.param_0, get-tuple-element.6)
+}
+
+comp.2 {
+  param.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
+  get-tuple-element.4 = f32[4]{0} get-tuple-element(param.param_1.1), index=1
+  get-tuple-element.5 = f32[4]{0} get-tuple-element(param.param_1.1), index=2
+  ROOT add.6 = f32[4]{0} add(get-tuple-element.4, get-tuple-element.5)
+}
+
+comp.1 {
+  add.1.param_1.1 = f32[4]{0} parameter(1)
+  constant.param_1.3 = f32[4]{0} parameter(0)
+  add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3)
+  ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3)
+}
+
+comp {
+  add.1.param_1 = f32[4]{0} parameter(1)
+  constant.param_1.1 = f32[4]{0} parameter(0)
+  multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1)
+  ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1)
+}
+
+ENTRY MergeSharedFusionInstruction.Computation0 {
+  constant = f32[4]{0} constant({1, 1, 1, 1})
+  param = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
+  fusion.3 = f32[4]{0} fusion(constant, param), kind=kLoop, calls=comp.3
+  fusion.4 = f32[4]{0} fusion(param), kind=kLoop, calls=comp.2
+  fusion.5 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp.1
+  fusion.6 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp
+  ROOT tuple = (f32[4]{0}, f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.5, fusion.6)
+})")
+                    .value();
+  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+
+  auto* root = module->entry_computation()->root_instruction();
+  EXPECT_EQ(HloOpcode::kTuple, root->opcode());
+  // Check operand 0 (not merged). Should have 4 instructions.
+  auto* operand0 = root->operand(0);
+  EXPECT_EQ(HloOpcode::kFusion, operand0->opcode());
+  EXPECT_EQ(4, operand0->fused_instruction_count());
+  // Check operand 1 (should have merged in its operand fusion instruction).
+  auto* operand1 = root->operand(1);
+  EXPECT_EQ(HloOpcode::kFusion, operand1->opcode());
+  EXPECT_EQ(7, operand1->fused_instruction_count());
+  // Check operand 2 (should have merged in its operand fusion instruction).
+  auto* operand2 = root->operand(2);
+  EXPECT_EQ(HloOpcode::kFusion, operand2->opcode());
+  EXPECT_EQ(7, operand2->fused_instruction_count());
+}
+
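+// Merging comp0 into both reducing consumers would duplicate its eight
+// f32[2048] parameter reads (an extra 64 KiB of traffic) while only saving the
+// write and the two reads of its single f32[2048] output (24 KiB), so the pass
+// should leave the graph unchanged.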
+TEST_F(FusionMergerTest, MoreMemoryAccessIfFused) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+f32add {
+  x = f32[] parameter(0)
+  y = f32[] parameter(1)
+  ROOT _ = f32[] add(x, y)
+}
+
+comp0 {
+  p = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(0)
+  gte0 = f32[2048] get-tuple-element(p), index=0
+  gte1 = f32[2048] get-tuple-element(p), index=1
+  add.9 = f32[2048] add(gte0, gte1)
+  gte2 = f32[2048] get-tuple-element(p), index=2
+  add.10 = f32[2048] add(add.9, gte2)
+  gte3 = f32[2048] get-tuple-element(p), index=3
+  add.11 = f32[2048] add(add.10, gte3)
+  p1 = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(1)
+  gte4 = f32[2048] get-tuple-element(p1), index=0
+  gte5 = f32[2048] get-tuple-element(p1), index=1
+  add.12 = f32[2048] add(gte4, gte5)
+  gte6 = f32[2048] get-tuple-element(p1), index=2
+  add.13 = f32[2048] add(add.12, gte6)
+  gte7 = f32[2048] get-tuple-element(p1), index=3
+  add.14 = f32[2048] add(add.13, gte7)
+  ROOT r = f32[2048] add(add.14, add.11)
+}
+
+comp1 {
+  p = f32[2048] parameter(0)
+  c0 = f32[] constant(0)
+  ROOT r = f32[] reduce(p, c0), dimensions={0}, to_apply=f32add
+}
+
+comp2 {
+  p = f32[2048] parameter(0)
+  c0 = f32[] constant(0)
+  r = f32[] reduce(p, c0), dimensions={0}, to_apply=f32add
+  ROOT n = f32[] negate(r)
+}
+
+ENTRY m.Computation2 {
+  p0 = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(0)
+  p1 = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(1)
+  fusion.0 = f32[2048] fusion(p0, p1), kind=kLoop, calls=comp0
+  fusion.1 = f32[] fusion(fusion.0), kind=kLoop, calls=comp1
+  fusion.2 = f32[] fusion(fusion.0), kind=kLoop, calls=comp2
+  ROOT tuple = (f32[], f32[]) tuple(fusion.1, fusion.2)
+}
+)")
+                    .value();
+  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, LessMemoryAccessIfFused) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+comp.2 {
+  state.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
+  get-tuple-element.5 = f32[4]{0} get-tuple-element(state.param_1.1), index=0
+  get-tuple-element.6 = f32[4]{0} get-tuple-element(state.param_1.1), index=1
+  add.7 = f32[4]{0} add(get-tuple-element.5, get-tuple-element.6)
+  get-tuple-element.7 = f32[4]{0} get-tuple-element(state.param_1.1), index=2
+  ROOT add.8 = f32[4]{0} add(add.7, get-tuple-element.7)
+}
+
+comp.1 {
+  add.1.param_1.1 = f32[4]{0} parameter(1)
+  constant.param_1.3 = f32[4]{0} parameter(0)
+  add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3)
+  ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3)
+}
+
+comp {
+  add.1.param_1 = f32[4]{0} parameter(1)
+  constant.param_1.1 = f32[4]{0} parameter(0)
+  multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1)
+  ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1)
+}
+
+ENTRY m.Computation2 {
+  constant = f32[4]{0} constant({1, 1, 1, 1})
+  state = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0)
+  fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2
+  fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1
+  fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp
+  ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4)
+})")
+                    .value();
+  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+// Check that we're willing to merge f1_computation into f2_computation, even
+// though f2 is an input fusion node.
+TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule m
+
+    f1_computation {
+      f1_p0 = f32[32]{0} parameter(0)
+      ROOT f1_root = f32[32]{0} add(f1_p0, f1_p0)
+    }
+
+    add_computation {
+      add_lhs = f32[] parameter(0)
+      add_rhs = f32[] parameter(1)
+      ROOT add_root = f32[] add(add_lhs, add_rhs)
+    }
+
+    f2_computation {
+      f2_p0 = f32[32]{0} parameter(0)
+      f2_mul = f32[32]{0} multiply(f2_p0, f2_p0)
+      f2_zero = f32[] constant(0)
+      ROOT f2_root = f32[] reduce(f2_mul, f2_zero), dimensions={0},
+             to_apply=add_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[32]{0} parameter(0)
+      f1 = f32[32]{0} fusion(p0), kind=kLoop, calls=f1_computation
+      ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
+    })")
+                    .value();
+  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Fusion(m::Parameter())));
+}
+
+TEST_F(FusionMergerTest, WillMergeIntoUnfusedConsumer) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule jit_matmul.36
+
+    max (parameter.13: f32[], parameter.14: f32[]) -> f32[] {
+      parameter.13 = f32[] parameter(0)
+      parameter.14 = f32[] parameter(1)
+      ROOT maximum.15 = f32[] maximum(f32[] parameter.13, f32[] parameter.14)
+    }
+
+    add (parameter.29: f32[], parameter.30: f32[]) -> f32[] {
+      parameter.29 = f32[] parameter(0)
+      parameter.30 = f32[] parameter(1)
+      ROOT add.31 = f32[] add(f32[] parameter.29, f32[] parameter.30)
+    }
+
+    fused_computation.1 (param_1.4: f32[200,200,200], param_2.1: f32[200,200]) -> f32[200,200] {
+      param_1.4 = f32[200,200,200]{2,1,0} parameter(0)
+      param_2.1 = f32[200,200]{1,0} parameter(1)
+      broadcast.3 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_2.1), dimensions={0,2}
+      subtract.0 = f32[200,200,200]{2,1,0} subtract(f32[200,200,200]{2,1,0} param_1.4, f32[200,200,200]{2,1,0} broadcast.3)
+      exponential.0 = f32[200,200,200]{2,1,0} exponential(f32[200,200,200]{2,1,0} subtract.0)
+      constant.27 = f32[] constant(0)
+      ROOT reduce.0 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} exponential.0, f32[] constant.27), dimensions={1}, to_apply=add
+    }
+
+    fused_computation.3 (param_0.7: f32[200,200], param_1.9: f32[200,200]) -> f32[200,200,200] {
+      param_1.9 = f32[200,200]{1,0} parameter(1)
+      broadcast.10 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_1.9), dimensions={0,1}
+      param_0.7 = f32[200,200]{1,0} parameter(0)
+      broadcast.8 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_0.7), dimensions={1,2}
+      ROOT add.1 = f32[200,200,200]{2,1,0} add(f32[200,200,200]{2,1,0} broadcast.10, f32[200,200,200]{2,1,0} broadcast.8)
+    }
+
+    ENTRY entry (parameter.1: f32[200,200], parameter.2: f32[200,200]) -> f32[200,200] {
+      parameter.2 = f32[200,200]{1,0} parameter(1)
+      parameter.1 = f32[200,200]{1,0} parameter(0)
+      fusion.3 = f32[200,200,200]{2,1,0} fusion(f32[200,200]{1,0} parameter.2, f32[200,200]{1,0} parameter.1), kind=kLoop, calls=fused_computation.3
+      constant.11 = f32[] constant(-inf)
+      reduce.16 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} fusion.3, f32[] constant.11), dimensions={1}, to_apply=max
+      ROOT fusion.1 = f32[200,200]{1,0} fusion(f32[200,200,200]{2,1,0} fusion.3, f32[200,200]{1,0} reduce.16), kind=kInput, calls=fused_computation.1
+    })")
+                    .value();
+  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Fusion(), m::Parameter(), m::Parameter())));
+}
+
+TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
+  // TODO(b/247762001): the case here does not represent the problem -
+  // profiling shows that it works faster if merged (even on larger dimensions).
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule m
+
+    f1_computation {
+      f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
+      add = f32[16,16,256]{0,1,2} add(f1_p0, f1_p0)
+      // Note that the copy changes the layout from {0,1,2} to {2,1,0}.
+      ROOT f1_root = f32[16,16,256]{2,1,0} copy(add)
+    }
+
+    add_computation {
+      add_lhs = f32[] parameter(0)
+      add_rhs = f32[] parameter(1)
+      ROOT add_root = f32[] add(add_lhs, add_rhs)
+    }
+
+    f2_computation {
+      f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
+      f2_zero = f32[] constant(0)
+      ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2},
+             to_apply=add_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[16,16,256]{0,1,2} parameter(0)
+      f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
+      ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
+    })")
+                    .value();
+  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, WillMergeReduceNotTooUnfriendlyLayouts) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule m
+
+    f1_computation {
+      f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
+      slice1 = f32[5,16,256]{0,1,2} slice(f1_p0), slice={[0:5], [0:16], [0:256]}
+      // Here the copy changes the layout only of a part of the data.
+      f1_copy = f32[5,16,256]{2,1,0} copy(slice1)
+      slice2 = f32[11,16,256]{0,1,2} slice(f1_p0), slice={[0:11], [0:16], [0:256]}
+      bitcast = f32[11,16,256]{2,1,0} bitcast(slice2)
+      ROOT f1_root = f32[16,16,256]{2,1,0} concatenate(f1_copy, bitcast), dimensions={0}
+    }
+
+    add_computation {
+      add_lhs = f32[] parameter(0)
+      add_rhs = f32[] parameter(1)
+      ROOT add_root = f32[] add(add_lhs, add_rhs)
+    }
+
+    f2_computation {
+      f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
+      f2_zero = f32[] constant(0)
+      ROOT f2_root = f32[16,16] reduce(f2_p0, f2_zero), dimensions={2},
+             to_apply=add_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[16,16,256]{0,1,2} parameter(0)
+      f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
+      ROOT f2 = f32[16,16] fusion(f1), kind=kInput, calls=f2_computation
+    })")
+                    .value();
+  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+// Check that we limit the number of operands to fusions we create.
+TEST_F(FusionMergerTest, AvoidsLargeFusion) {
+  constexpr int64_t kNumParams = MaxOperandsAndOutputsPerFusion() + 1;
+
+  // Compute
+  //   p0 + p1 + p2 + ... + pn,
+  // using so many parameters that they do not fit into one fusion.
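+  // The first fusion below already uses MaxOperandsAndOutputsPerFusion()
+  // operands; merging it into its consumer fusion would push the consumer over
+  // the limit, so FusionMerger should leave the module unchanged.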
+  auto module = CreateNewVerifiedModule();
+  HloComputation::Builder b(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
+
+  std::vector<HloInstruction*> entry_params;
+
+  for (int64_t i = 0; i < kNumParams; ++i) {
+    entry_params.push_back(
+        b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")));
+  }
+  auto make_fusion = [&](absl::Span<HloInstruction* const> params) {
+    // Build a fusion computation for calculating the sum of all parameters.
+    HloComputation::Builder sub_builder("subcomp");
+    HloInstruction* sum = nullptr;
+    for (int64_t i = 0; i < params.size(); ++i) {
+      auto p = sub_builder.AddInstruction(
+          HloInstruction::CreateParameter(i, shape, "p"));
+      if (sum == nullptr) {
+        sum = p;
+      } else {
+        sum = sub_builder.AddInstruction(
+            HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, p));
+      }
+    }
+    HloComputation* subcomp =
+        module->AddEmbeddedComputation(sub_builder.Build());
+    return HloInstruction::CreateFusion(
+        shape, HloInstruction::FusionKind::kLoop, params, subcomp);
+  };
+  auto fusion = b.AddInstruction(
+      make_fusion(absl::MakeSpan(entry_params)
+                      .subspan(0, MaxOperandsAndOutputsPerFusion())));
+  b.AddInstruction(make_fusion({entry_params.back(), fusion}));
+  module->AddEntryComputation(b.Build());
+  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+// TODO(b/119692968): Remove this test once fusion emitter is fixed.
+TEST_F(FusionMergerTest, WillNotMergeIfFusionEmitterIsInefficient) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+f1 {
+  Arg_0.5 = f32[200000] parameter(0)
+  slice.7 = f32[100000] slice(Arg_0.5), slice={[0:199999:2]}
+  slice.8 = f32[100000] slice(Arg_0.5), slice={[1:200000:2]}
+  add.9 = f32[100000] add(slice.7, slice.8)
+  slice.10 = f32[50000] slice(add.9), slice={[0:99999:2]}
+  slice.11 = f32[50000] slice(add.9), slice={[1:100000:2]}
+  add.12 = f32[50000] add(slice.10, slice.11)
+  slice.13 = f32[25000] slice(add.12), slice={[0:49999:2]}
+  slice.14 = f32[25000] slice(add.12), slice={[1:50000:2]}
+  add.15 = f32[25000] add(slice.13, slice.14)
+  slice.16 = f32[12500] slice(add.15), slice={[0:24999:2]}
+  slice.17 = f32[12500] slice(add.15), slice={[1:25000:2]}
+  add.18 = f32[12500] add(slice.16, slice.17)
+  slice.19 = f32[6250] slice(add.18), slice={[0:12499:2]}
+  slice.20 = f32[6250] slice(add.18), slice={[1:12500:2]}
+  add.21 = f32[6250] add(slice.19, slice.20)
+  slice.22 = f32[3125] slice(add.21), slice={[0:6249:2]}
+  slice.23 = f32[3125] slice(add.21), slice={[1:6250:2]}
+  ROOT add.24 = f32[3125] add(slice.22, slice.23)
+}
+
+f2 {
+  Arg_0 = f32[3125] parameter(0)
+  slice.25 = f32[1562] slice(Arg_0), slice={[0:3124:2]}
+  slice.26 = f32[1562] slice(Arg_0), slice={[1:3125:2]}
+  add.27 = f32[1562] add(slice.25, slice.26)
+  slice.28 = f32[781] slice(add.27), slice={[0:1561:2]}
+  slice.29 = f32[781] slice(add.27), slice={[1:1562:2]}
+  add.30 = f32[781] add(slice.28, slice.29)
+  slice.31 = f32[390] slice(add.30), slice={[0:780:2]}
+  slice.32 = f32[390] slice(add.30), slice={[1:781:2]}
+  add.33 = f32[390] add(slice.31, slice.32)
+  slice.34 = f32[195] slice(add.33), slice={[0:389:2]}
+  slice.35 = f32[195] slice(add.33), slice={[1:390:2]}
+  add.36 = f32[195] add(slice.34, slice.35)
+  slice.37 = f32[97] slice(add.36), slice={[0:194:2]}
+  slice.38 = f32[97] slice(add.36), slice={[1:195:2]}
+  add.39 = f32[97] add(slice.37, slice.38)
+  slice.40 = f32[48] slice(add.39), slice={[0:96:2]}
+  slice.41 = f32[48] slice(add.39), slice={[1:97:2]}
+  ROOT add.42 = f32[48] add(slice.40, slice.41)
+}
+
+ENTRY e {
+  p0 = f32[200000] parameter(0)
+  f1 = f32[3125] fusion(p0), kind=kLoop, calls=f1
+  ROOT r = f32[48] fusion(f1), kind=kLoop, calls=f2
+})")
+                    .value();
+  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, WillMergeSliceIntoReusingConsumer) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+f1 {
+  p01 = s8[1000000] parameter(0)
+  ROOT s0 = s8[10] slice(p01), slice={[0:10]}
+}
+
+f2 {
+  p02 = s8[10] parameter(0)
+  ROOT b0 = s8[10,1000000] broadcast(p02), dimensions={0}
+}
+
+ENTRY e {
+  p0 = s8[1000000] parameter(0)
+  f1 = s8[10] fusion(p0), kind=kLoop, calls=f1
+  ROOT r = s8[10,1000000] fusion(f1), kind=kLoop, calls=f2
+})")
+                    .value();
+  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, WillMergeExpensiveFusionsIfSavesMemory) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule m
+
+    %f_a (p: f32[]) -> f32[1024,1024,1024] {
+      %p = f32[] parameter(0)
+      %b = f32[1024,1024,1024] broadcast(%p), dimensions={}
+      ROOT %t = f32[1024,1024,1024] tanh(%b)
+    }
+
+    %f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
+      %p = f32[1024,1024,1024] parameter(0)
+      ROOT %t = f32[1024,1024,1024] tanh(%p)
+    }
+
+    %f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
+      %p = f32[1024,1024,1024] parameter(0)
+      ROOT %t = f32[1024,1024,1024] tanh(%p)
+    }
+
+    ENTRY entry {
+      p0 = f32[] parameter(0)
+      f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_a
+      f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_b
+      f3 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c
+      ROOT f4 = f32[1024,1024,1024] add(f2, f3)
+    })")
+                    .value();
+  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, WillMergeExpensiveFusionsWithSingleConsumer) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule m
+
+    %f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
+      %p = f32[1024,1024,1024] parameter(0)
+      ROOT %t = f32[1024,1024,1024] tanh(%p)
+    }
+
+    %f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
+      %p = f32[1024,1024,1024] parameter(0)
+      ROOT %t = f32[1024,1024,1024] add(%p, %p)
+    }
+
+    ENTRY entry {
+      p0 = f32[1024,1024,1024] parameter(0)
+      f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_b
+      ROOT f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c
+    })")
+                    .value();
+  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, WillNotMergeExpensiveFusionsWithReusingConsumer) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule m
+
+    %f_b {
+      %p = f32[1024,1024,1024] parameter(0)
+      %t1 = f32[1024,1024,1024] tanh(%p)
+      %t2 = f32[1024,1024,1024] tanh(%t1)
+      %t3 = f32[1024,1024,1024] tanh(%t2)
+      %t4 = f32[1024,1024,1024] tanh(%t3)
+      %t5 = f32[1024,1024,1024] tanh(%t4)
+      %t6 = f32[1024,1024,1024] tanh(%t5)
+      %t7 = f32[1024,1024,1024] tanh(%t6)
+      %t8 = f32[1024,1024,1024] tanh(%t7)
+      ROOT %t9 = f32[1024,1024,1024] tanh(%t8)
+    }
+
+    %f_c {
+      %p = f32[1024,1024,1024] parameter(0)
+      ROOT %t = f32[1024,1024,1024,2048] broadcast(%p), dimensions={0,1,2}
+    }
+
+    ENTRY entry {
+      p0 = f32[1024,1024,1024] parameter(0)
+      f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_b
+      ROOT f2 = f32[1024,1024,1024,2048] fusion(f1), kind=kLoop, calls=%f_c
+    })")
+                    .value();
+  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, NoMergeWithBitcast) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+f32add {
+  x.634 = f32[] parameter(0)
+  y.635 = f32[] parameter(1)
+  ROOT add.636 = f32[] add(x.634, y.635)
+}
+
+fused_computation.103 {
+  param_0.310 = f16[1,8,512,1536]{2,3,1,0} parameter(0)
+  param_1.420 = f32[8,512]{1,0} parameter(1)
+  bitcast.1144 = f32[1,8,512]{2,1,0} bitcast(param_1.420)
+  convert.252 = f16[1,8,512]{2,1,0} convert(bitcast.1144)
+  bitcast.1143 = f16[8,512]{1,0} bitcast(convert.252)
+  broadcast.481 = f16[1,8,512,1536]{2,3,1,0} broadcast(bitcast.1143), dimensions={1,2}
+  divide.15 = f16[1,8,512,1536]{2,3,1,0} divide(param_0.310, broadcast.481)
+  ROOT bitcast.1142 = f16[8,512,1536]{1,2,0} bitcast(divide.15)
+}
+
+fused_computation.105 {
+  param_1.426 = f16[8,1536,512]{2,1,0} parameter(1)
+  bitcast.1896 = f16[1,8,1536,512]{3,2,1,0} bitcast(param_1.426)
+  transpose.238 = f16[1,8,512,1536]{2,3,1,0} transpose(bitcast.1896), dimensions={0,1,3,2}
+  param_0.315 = f16[8,512]{1,0} parameter(0)
+  broadcast.482 = f16[1,8,512,1536]{2,3,1,0} broadcast(param_0.315), dimensions={1,2}
+  subtract.22 = f16[1,8,512,1536]{2,3,1,0} subtract(transpose.238, broadcast.482)
+  ROOT exponential.15 = f16[1,8,512,1536]{2,3,1,0} exponential(subtract.22)
+}
+
+fused_computation.104 {
+  param_0.1000 = f16[8,1536,512]{2,1,0} parameter(0)
+  convert.652 = f32[8,1536,512]{2,1,0} convert(param_0.1000)
+  constant_752 = f32[] constant(-0)
+  ROOT reduce.232 = f32[8,512]{1,0} reduce(convert.652, constant_752),
+  dimensions={1}, to_apply=f32add
+}
+
+ENTRY entry {
+  p0 = f16[8,1536,512]{2,1,0} parameter(0)
+  p1 = f16[8,512]{1,0} parameter(1)
+  fusion.105 = f16[1,8,512,1536]{2,3,1,0} fusion(p1, p0), kind=kLoop, calls=fused_computation.105
+  bitcast.1787 = f16[8,1536,512]{2,1,0} bitcast(fusion.105)
+  fusion.104 = f32[8,512]{1,0} fusion(bitcast.1787), kind=kInput, calls=fused_computation.104
+  ROOT fusion.103 = f16[8,512,1536]{1,2,0} fusion(fusion.105, fusion.104), kind=kLoop, calls=fused_computation.103
+}
+    )")
+                    .value();
+  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, CostBasedMerge) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+fused_computation.45 {
+  param_1.194 = f16[8,1536,512]{2,1,0} parameter(1)
+  bitcast.1042 = f16[1,8,512,1536]{2,3,1,0} bitcast(param_1.194)
+  param_0.135 = f16[8,512]{1,0} parameter(0)
+  broadcast.391 = f16[1,8,512,1536]{2,3,1,0} broadcast(param_0.135), dimensions={1,2}
+  subtract.6 = f16[1,8,512,1536]{2,3,1,0} subtract(bitcast.1042, broadcast.391)
+  ROOT exponential.11 = f16[1,8,512,1536]{2,3,1,0} exponential(subtract.6)
+}
+
+f32add {
+  x.634 = f32[] parameter(0)
+  y.635 = f32[] parameter(1)
+  ROOT add.636 = f32[] add(x.634, y.635)
+}
+
+fused_computation.44 {
+  param_0.869 = f16[1,8,512,1536]{2,3,1,0} parameter(0)
+  convert.221 = f32[1,8,512,1536]{2,3,1,0} convert(param_0.869)
+  transpose.212 = f32[1,8,1536,512]{3,2,1,0} transpose(convert.221), dimensions={0,1,3,2}
+  bitcast.1041 = f32[8,1536,512]{2,1,0} bitcast(transpose.212)
+  constant_429 = f32[] constant(0)
+  ROOT reduce.149 = f32[8,512]{1,0} reduce(bitcast.1041, constant_429), dimensions={1}, to_apply=f32add
+}
+
+fused_computation.43 {
+  param_0.130 = f16[1,8,512,1536]{2,3,1,0} parameter(0)
+  param_1.188 = f32[8,512]{1,0} parameter(1)
+  bitcast.1040 = f32[1,8,512]{2,1,0} bitcast(param_1.188)
+  convert.220 = f16[1,8,512]{2,1,0} convert(bitcast.1040)
+  bitcast.1039 = f16[8,512]{1,0} bitcast(convert.220)
+  broadcast.390 = f16[1,8,512,1536]{2,3,1,0} broadcast(bitcast.1039), dimensions={1,2}
+  divide.11 = f16[1,8,512,1536]{2,3,1,0} divide(param_0.130, broadcast.390)
+  ROOT bitcast.1038 = f16[8,512,1536]{1,2,0} bitcast(divide.11)
+}
+
+ENTRY entry {
+  p0 = f16[8,1536,512]{2,1,0} parameter(0)
+  p1 = f16[8,512]{1,0} parameter(1)
+  fusion.45 = f16[1,8,512,1536]{2,3,1,0} fusion(p1, p0), kind=kLoop, calls=fused_computation.45
+  fusion.44 = f32[8,512]{1,0} fusion(fusion.45), kind=kInput, calls=fused_computation.44
+  ROOT fusion.43 = f16[8,512,1536]{1,2,0} fusion(fusion.45, fusion.44), kind=kLoop, calls=fused_computation.43
+}
+    )")
+                    .value();
+  auto& debug_options = module->mutable_config().mutable_debug_options();
+  // For some reason, we would not merge any fusions when using the MLIR
+  // reduction emitter. The cost model queries the reduction emitter regarding
+  // the launch dimensions, so it seems likely that it is caused by different
+  // launch dimensions.
+  debug_options.set_xla_gpu_mlir_emitter_level(3);
+  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+// Outputs of fusions 66 and 67 here are heavily reused by fusion 59 - so
+// it is better to not merge here.
+TEST_F(FusionMergerTest, CostBasedNoMerge) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+add_float_.56 {
+  x.57 = f32[] parameter(0)
+  y.58 = f32[] parameter(1)
+  ROOT add.59 = f32[] add(x.57, y.58)
+}
+
+fused_computation.66 {
+  constant.635 = f32[] constant(0)
+  broadcast.257 = f32[459,3]{1,0} broadcast(constant.635), dimensions={}
+  constant.641 = f32[] constant(1)
+  broadcast.256 = f32[459,3]{1,0} broadcast(constant.641), dimensions={}
+  broadcast.255 = f32[459]{0} broadcast(constant.635), dimensions={}
+  iota.28 = f32[459]{0} iota(), iota_dimension=0
+  constant.629 = f32[] constant(1.49891067)
+  broadcast.253 = f32[459]{0} broadcast(constant.629), dimensions={}
+  multiply.39 = f32[459]{0} multiply(iota.28, broadcast.253)
+  constant.633 = f32[] constant(-1)
+  broadcast.252 = f32[459]{0} broadcast(constant.633), dimensions={}
+  add.31 = f32[459]{0} add(multiply.39, broadcast.252)
+  ceil.11 = f32[459]{0} ceil(add.31)
+  constant.630 = f32[] constant(685)
+  broadcast.251 = f32[459]{0} broadcast(constant.630), dimensions={}
+  clamp.49 = f32[459]{0} clamp(broadcast.255, ceil.11, broadcast.251)
+  subtract.11 = f32[459]{0} subtract(clamp.49, multiply.39)
+  broadcast.249 = f32[459,3]{1,0} broadcast(subtract.11), dimensions={0}
+  iota.26 = f32[459,3]{1,0} iota(), iota_dimension=1
+  add.30 = f32[459,3]{1,0} add(broadcast.249, iota.26)
+  abs.3 = f32[459,3]{1,0} abs(add.30)
+  subtract.10 = f32[459,3]{1,0} subtract(broadcast.256, abs.3)
+  maximum.6 = f32[459,3]{1,0} maximum(broadcast.257, subtract.10)
+  ROOT reduce.3 = f32[459]{0} reduce(maximum.6, constant.635), dimensions={1}, to_apply=add_float_.56
+}
+
+fused_computation.67 {
+  constant.684 = f32[] constant(0)
+  broadcast.296 = f32[1130,3]{1,0} broadcast(constant.684), dimensions={}
+  constant.685 = f32[] constant(1)
+  broadcast.295 = f32[1130,3]{1,0} broadcast(constant.685), dimensions={}
+  broadcast.294 = f32[1130]{0} broadcast(constant.684), dimensions={}
+  iota.41 = f32[1130]{0} iota(), iota_dimension=0
+  constant.675 = f32[] constant(1.34513271)
+  broadcast.293 = f32[1130]{0} broadcast(constant.675), dimensions={}
+  multiply.47 = f32[1130]{0} multiply(iota.41, broadcast.293)
+  constant.677 = f32[] constant(-1)
+  broadcast.290 = f32[1130]{0} broadcast(constant.677), dimensions={}
+  add.39 = f32[1130]{0} add(multiply.47, broadcast.290)
+  ceil.15 = f32[1130]{0} ceil(add.39)
+  constant.676 = f32[] constant(1517)
+  broadcast.289 = f32[1130]{0} broadcast(constant.676), dimensions={}
+  clamp.53 = f32[1130]{0} clamp(broadcast.294, ceil.15, broadcast.289)
+  subtract.19 = f32[1130]{0} subtract(clamp.53, multiply.47)
+  broadcast.287 = f32[1130,3]{1,0} broadcast(subtract.19), dimensions={0}
+  iota.39 = f32[1130,3]{1,0} iota(), iota_dimension=1
+  add.38 = f32[1130,3]{1,0} add(broadcast.287, iota.39)
+  abs.7 = f32[1130,3]{1,0} abs(add.38)
+  subtract.18 = f32[1130,3]{1,0} subtract(broadcast.295, abs.7)
+  maximum.10 = f32[1130,3]{1,0} maximum(broadcast.296, subtract.18)
+  ROOT reduce.4 = f32[1130]{0} reduce(maximum.10, constant.684), dimensions={1}, to_apply=add_float_.56
+}
+
+fused_computation.59 {
+  constant.532 = f32[] constant(0)
+  broadcast.316 = f32[1130,3]{1,0} broadcast(constant.532), dimensions={}
+  constant.663 = f32[] constant(1)
+  broadcast.315 = f32[1130,3]{1,0} broadcast(constant.663), dimensions={}
+  broadcast.314 = f32[1130]{0} broadcast(constant.532), dimensions={}
+  iota.47 = f32[1130]{0} iota(), iota_dimension=0
+  constant.579 = f32[] constant(1.34513271)
+  broadcast.311 = f32[1130]{0} broadcast(constant.579), dimensions={}
+  multiply.51 = f32[1130]{0} multiply(iota.47, broadcast.311)
+  constant.578 = f32[] constant(-1)
+  broadcast.310 = f32[1130]{0} broadcast(constant.578), dimensions={}
+  add.43 = f32[1130]{0} add(multiply.51, broadcast.310)
+  ceil.17 = f32[1130]{0} ceil(add.43)
+  constant.576 = f32[] constant(1517)
+  broadcast.309 = f32[1130]{0} broadcast(constant.576), dimensions={}
+  clamp.55 = f32[1130]{0} clamp(broadcast.314, ceil.17, broadcast.309)
+  subtract.24 = f32[1130]{0} subtract(clamp.55, multiply.51)
+  broadcast.306 = f32[1130,3]{1,0} broadcast(subtract.24), dimensions={0}
+  iota.45 = f32[1130,3]{1,0} iota(), iota_dimension=1
+  add.42 = f32[1130,3]{1,0} add(broadcast.306, iota.45)
+  abs.9 = f32[1130,3]{1,0} abs(add.42)
+  subtract.23 = f32[1130,3]{1,0} subtract(broadcast.315, abs.9)
+  maximum.12 = f32[1130,3]{1,0} maximum(broadcast.316, subtract.23)
+  param_2.183 = f32[1130]{0} parameter(2)
+  broadcast.172 = f32[1130,3]{1,0} broadcast(param_2.183), dimensions={0}
+  divide.3 = f32[1130,3]{1,0} divide(maximum.12, broadcast.172)
+  bitcast.53 = f32[3390]{0} bitcast(divide.3)
+  broadcast.171 = f32[3390,1377]{1,0} broadcast(bitcast.53), dimensions={0}
+  broadcast.276 = f32[459,3]{1,0} broadcast(constant.532), dimensions={}
+  broadcast.275 = f32[459,3]{1,0} broadcast(constant.663), dimensions={}
+  broadcast.274 = f32[459]{0} broadcast(constant.532), dimensions={}
+  iota.35 = f32[459]{0} iota(), iota_dimension=0
+  constant.614 = f32[] constant(1.49891067)
+  broadcast.273 = f32[459]{0} broadcast(constant.614), dimensions={}
+  multiply.43 = f32[459]{0} multiply(iota.35, broadcast.273)
+  broadcast.272 = f32[459]{0} broadcast(constant.578), dimensions={}
+  add.35 = f32[459]{0} add(multiply.43, broadcast.272)
+  ceil.13 = f32[459]{0} ceil(add.35)
+  constant.611 = f32[] constant(685)
+  broadcast.269 = f32[459]{0} broadcast(constant.611), dimensions={}
+  clamp.51 = f32[459]{0} clamp(broadcast.274, ceil.13, broadcast.269)
+  subtract.15 = f32[459]{0} subtract(clamp.51, multiply.43)
+  broadcast.267 = f32[459,3]{1,0} broadcast(subtract.15), dimensions={0}
+  iota.33 = f32[459,3]{1,0} iota(), iota_dimension=1
+  add.34 = f32[459,3]{1,0} add(broadcast.267, iota.33)
+  abs.5 = f32[459,3]{1,0} abs(add.34)
+  subtract.14 = f32[459,3]{1,0} subtract(broadcast.275, abs.5)
+  maximum.8 = f32[459,3]{1,0} maximum(broadcast.276, subtract.14)
+  param_1.177 = f32[459]{0} parameter(1)
+  broadcast.170 = f32[459,3]{1,0} broadcast(param_1.177), dimensions={0}
+  divide.2 = f32[459,3]{1,0} divide(maximum.8, broadcast.170)
+  bitcast.52 = f32[1377]{0} bitcast(divide.2)
+  broadcast.169 = f32[3390,1377]{1,0} broadcast(bitcast.52), dimensions={1}
+  multiply.15 = f32[3390,1377]{1,0} multiply(broadcast.171, broadcast.169)
+  bitcast.61 = f32[1130,3,459,3]{3,2,1,0} bitcast(multiply.15)
+  transpose.68 = f32[459,1130,3,3]{2,0,3,1} transpose(bitcast.61), dimensions={2,0,3,1}
+  copy.1 = f32[459,1130,3,3]{3,2,1,0} copy(transpose.68)
+  bitcast.50 = f32[1130,459,9]{2,1,0} bitcast(copy.1)
+  broadcast.168 = f32[1130,459,6,9]{3,2,1,0} broadcast(bitcast.50), dimensions={0,1,3}
+  param_0.171 = u8[1,688,1520,6]{3,2,1,0} parameter(0)
+  bitcast.49 = u8[688,1520,1,6]{3,1,0,2} bitcast(param_0.171)
+  convert.175 = f32[688,1520,1,6]{3,1,0,2} convert(bitcast.49)
+  broadcast.167 = f32[459,1130,1]{2,1,0} broadcast(clamp.51), dimensions={0}
+  broadcast.166 = f32[459,1130,1]{2,1,0} broadcast(clamp.55), dimensions={1}
+  concatenate.3 = f32[459,1130,2]{2,1,0} concatenate(broadcast.167, broadcast.166), dimensions={2}
+  convert.174 = s32[459,1130,2]{2,1,0} convert(concatenate.3)
+  bitcast.48 = s32[518670,2]{1,0} bitcast(convert.174)
+  gather.1 = f32[518670,3,3,1,6]{2,1,4,0,3} gather(convert.175, bitcast.48), offset_dims={1,2,3,4}, collapsed_slice_dims={}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={3,3,1,6}
+  transpose.69 = f32[1,518670,6,3,3]{4,3,2,1,0} transpose(gather.1), dimensions={3,0,4,1,2}
+  bitcast.47 = f32[1130,459,6,9]{3,2,1,0} bitcast(transpose.69)
+  multiply.14 = f32[1130,459,6,9]{3,2,1,0} multiply(broadcast.168, bitcast.47)
+  reduce.2 = f32[1130,459,6]{2,1,0} reduce(multiply.14, constant.532), dimensions={3}, to_apply=add_float_.56
+  convert.173 = f16[1130,459,6]{2,1,0} convert(reduce.2)
+  bitcast.46 = f16[1,459,1130,6]{3,2,1,0} bitcast(convert.173)
+  constant.533 = f16[] constant(0)
+  pad.9 = f16[1,480,1130,6]{3,2,1,0} pad(bitcast.46, constant.533), padding=0_0x0_21x0_0x0_0
+  pad.8 = f16[1,480,1152,6]{3,2,1,0} pad(pad.9, constant.533), padding=0_0x0_0x0_22x0_0
+  constant.532f16 = f16[] constant(0)
+  ROOT pad.7 = f16[1,485,1157,6]{3,2,1,0} pad(pad.8, constant.532f16), padding=0_0x2_3x2_3x0_0
+}
+
+ENTRY e {
+  arg0.1 = u8[1,688,1520,6]{3,2,1,0} parameter(0), parameter_replication={false}
+  fusion.66 = f32[459]{0} fusion(), kind=kLoop, calls=fused_computation.66
+  fusion.67 = f32[1130]{0} fusion(), kind=kLoop, calls=fused_computation.67
+  ROOT fusion.59 = f16[1,485,1157,6]{2,1,3,0} fusion(arg0.1, fusion.66, fusion.67), kind=kLoop, calls=fused_computation.59
+}
+    )")
+                    .value();
+  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, NoMergeBecauseTooManyBasicBlockSplits) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+region_6.97 {
+  Arg_0.98 = pred[] parameter(0)
+  Arg_1.99 = pred[] parameter(1)
+  ROOT or.100 = pred[] or(Arg_0.98, Arg_1.99)
+}
+
+region_4.50 {
+  Arg_0.51 = f64[] parameter(0)
+  Arg_1.52 = f64[] parameter(1)
+  ROOT add.53 = f64[] add(Arg_0.51, Arg_1.52)
+}
+
+f2 {
+  param_0 = s64[1]{0} parameter(0)
+  constant_70 = f64[] constant(0)
+  convert.41.clone.1 = f64[1]{0} convert(param_0)
+  ROOT pad.99.clone.1 = f64[3]{0} pad(convert.41.clone.1, constant_70), padding=0_2
+}
+
+f1 {
+  param_0.361 = pred[5]{0} parameter(0)
+  broadcast.107 = pred[10,5]{1,0} broadcast(param_0.361), dimensions={1}
+  param_6.244 = pred[5]{0} parameter(6)
+  broadcast.111.clone.1 = pred[10,5]{1,0} broadcast(param_6.244), dimensions={1}
+  param_1.450 = f64[10,5]{1,0} parameter(1)
+  constant_294_clone_1 = f64[] constant(1)
+  broadcast.153.clone.1 = f64[10,5]{1,0} broadcast(constant_294_clone_1), dimensions={}
+  compare.22.clone.1 = pred[10,5]{1,0} compare(param_1.450, broadcast.153.clone.1), direction=GE
+  constant_75_clone_1 = f64[] constant(-1)
+  broadcast.109.clone.1 = f64[10,5]{1,0} broadcast(constant_75_clone_1), dimensions={}
+  add.34.clone.1 = f64[10,5]{1,0} add(param_1.450, broadcast.109.clone.1)
+  param_5.322 = f64[10,5,4]{1,0,2} parameter(5)
+  slice.45.clone.1 = f64[10,5,1]{1,0,2} slice(param_5.322), slice={[0:10], [0:5], [3:4]}
+  bitcast.94.clone.1 = f64[10,5]{1,0} bitcast(slice.45.clone.1)
+  divide.7.clone.1 = f64[10,5]{1,0} divide(add.34.clone.1, bitcast.94.clone.1)
+  add.33.clone.1 = f64[10,5]{1,0} add(divide.7.clone.1, broadcast.153.clone.1)
+  constant_70 = f64[] constant(0)
+  broadcast.157.clone.1 = f64[10,5]{1,0} broadcast(constant_70), dimensions={}
+  compare.26.clone.1 = pred[10,5]{1,0} compare(param_1.450, broadcast.157.clone.1), direction=LE
+  slice.46.clone.1 = f64[10,5,1]{1,0,2} slice(param_5.322), slice={[0:10], [0:5], [0:1]}
+  bitcast.93.clone.1 = f64[10,5]{1,0} bitcast(slice.46.clone.1)
+  divide.6.clone.1 = f64[10,5]{1,0} divide(param_1.450, bitcast.93.clone.1)
+  broadcast.295.clone.1 = f64[10,5,3]{1,0,2} broadcast(param_1.450), dimensions={0,1}
+  param_4.368 = f64[10,5,2]{1,0,2} parameter(4)
+  pad.103.clone.1 = f64[10,5,3]{1,0,2} pad(param_4.368, constant_70), padding=0_0x0_0x1_0
+  compare.121.clone.1 = pred[10,5,3]{1,0,2} compare(broadcast.295.clone.1, pad.103.clone.1), direction=GE
+  pad.102.clone.1 = f64[10,5,3]{1,0,2} pad(param_4.368, constant_294_clone_1), padding=0_0x0_0x0_1
+  compare.120.clone.1 = pred[10,5,3]{1,0,2} compare(broadcast.295.clone.1, pad.102.clone.1), direction=LT
+  and.39.clone.1 = pred[10,5,3]{1,0,2} and(compare.121.clone.1, compare.120.clone.1)
+  transpose.9 = pred[3,10,5]{2,1,0} transpose(and.39.clone.1), dimensions={2,0,1}
+  constant_296_clone_1 = pred[] constant(false)
+  reduce.91.clone.1 = pred[10,5]{1,0} reduce(transpose.9, constant_296_clone_1), dimensions={0}, to_apply=region_6.97
+  broadcast.294.clone.1 = pred[10,5,3]{1,0,2} broadcast(reduce.91.clone.1), dimensions={0,1}
+  pad.99.clone.1 = f64[3]{0} parameter(3)
+  broadcast.292.clone.1 = f64[3]{0} broadcast(constant_70), dimensions={}
+  compare.117.clone.1 = pred[3]{0} compare(pad.99.clone.1, broadcast.292.clone.1), direction=NE
+  broadcast.290.clone.1 = pred[10,5,3]{1,0,2} broadcast(compare.117.clone.1), dimensions={2}
+  select.67.clone.1 = pred[10,5,3]{1,0,2} select(broadcast.294.clone.1, and.39.clone.1, broadcast.290.clone.1)
+  convert.40.clone.1 = f64[10,5,3]{1,0,2} convert(select.67.clone.1)
+  broadcast.288.clone.1 = f64[10,5,3,3]{1,0,2,3} broadcast(convert.40.clone.1), dimensions={0,1,2}
+  param_2.361 = f64[10,5,4,3]{1,0,2,3} parameter(2)
+  slice.114.clone.1 = f64[10,5,3,3]{1,0,2,3} slice(param_2.361), slice={[0:10], [0:5], [1:4], [0:3]}
+  multiply.53.clone.1 = f64[10,5,3,3]{1,0,2,3} multiply(broadcast.288.clone.1, slice.114.clone.1)
+  transpose.10 = f64[3,3,10,5]{3,2,1,0} transpose(multiply.53.clone.1), dimensions={3,2,0,1}
+  reduce.90.clone.1 = f64[3,10,5]{2,1,0} reduce(transpose.10, constant_70), dimensions={1}, to_apply=region_4.50
+  transpose.11 = f64[10,5,3]{1,0,2} transpose(reduce.90.clone.1), dimensions={1,2,0}
+  slice.28.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.11), slice={[0:10], [0:5], [0:1]}
+  bitcast.99.clone.1 = f64[10,5]{1,0} bitcast(slice.28.clone.1)
+  slice.108.clone.1 = f64[10,5,3,3]{1,0,2,3} slice(param_2.361), slice={[0:10], [0:5], [0:3], [0:3]}
+  multiply.49.clone.1 = f64[10,5,3,3]{1,0,2,3} multiply(broadcast.288.clone.1, slice.108.clone.1)
+  transpose.12 = f64[3,3,10,5]{3,2,1,0} transpose(multiply.49.clone.1), dimensions={3,2,0,1}
+  reduce.82.clone.1 = f64[3,10,5]{2,1,0} reduce(transpose.12, constant_70), dimensions={1}, to_apply=region_4.50
+  transpose.13 = f64[10,5,3]{1,0,2} transpose(reduce.82.clone.1), dimensions={1,2,0}
+  slice.107.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.13), slice={[0:10], [0:5], [0:1]}
+  bitcast.240.clone.1 = f64[10,5]{1,0} bitcast(slice.107.clone.1)
+  subtract.27.clone.1 = f64[10,5]{1,0} subtract(bitcast.99.clone.1, bitcast.240.clone.1)
+  slice.27.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.13), slice={[0:10], [0:5], [2:3]}
+  bitcast.98.clone.1 = f64[10,5]{1,0} bitcast(slice.27.clone.1)
+  slice.26.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.11), slice={[0:10], [0:5], [2:3]}
+  bitcast.97.clone.1 = f64[10,5]{1,0} bitcast(slice.26.clone.1)
+  add.36.clone.1 = f64[10,5]{1,0} add(bitcast.97.clone.1, bitcast.98.clone.1)
+  slice.24.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.11), slice={[0:10], [0:5], [1:2]}
+  bitcast.95.clone.1 = f64[10,5]{1,0} bitcast(slice.24.clone.1)
+  slice.121.clone.1 = f64[10,5,1]{1,0,2} slice(transpose.13), slice={[0:10], [0:5], [1:2]}
+  bitcast.274.clone.1 = f64[10,5]{1,0} bitcast(slice.121.clone.1)
+  subtract.26.clone.1 = f64[10,5]{1,0} subtract(bitcast.95.clone.1, bitcast.274.clone.1)
+  divide.21 = f64[10,5]{1,0} divide(subtract.26.clone.1, subtract.27.clone.1)
+  constant_77_clone_1 = f64[] constant(2)
+  broadcast.117.clone.1 = f64[10,5]{1,0} broadcast(constant_77_clone_1), dimensions={}
+  multiply.37.clone.1 = f64[10,5]{1,0} multiply(divide.21, broadcast.117.clone.1)
+  subtract.25.clone.1 = f64[10,5]{1,0} subtract(add.36.clone.1, multiply.37.clone.1)
+  subtract.24.clone.1 = f64[10,5]{1,0} subtract(param_1.450, bitcast.274.clone.1)
+  divide.9.clone.1 = f64[10,5]{1,0} divide(subtract.24.clone.1, subtract.26.clone.1)
+  clamp.7.clone.1 = f64[10,5]{1,0} clamp(broadcast.157.clone.1, divide.9.clone.1, broadcast.153.clone.1)
+  multiply.36.clone.1 = f64[10,5]{1,0} multiply(subtract.25.clone.1, clamp.7.clone.1)
+  subtract.23.clone.1 = f64[10,5]{1,0} subtract(bitcast.98.clone.1, multiply.36.clone.1)
+  compare.13.clone.1 = pred[10,5]{1,0} compare(subtract.23.clone.1, broadcast.157.clone.1), direction=GE
+  negate.19.clone.1 = f64[10,5]{1,0} negate(divide.21)
+  multiply.35.clone.1 = f64[10,5]{1,0} multiply(negate.19.clone.1, clamp.7.clone.1)
+  multiply.34.clone.1 = f64[10,5]{1,0} multiply(multiply.35.clone.1, broadcast.117.clone.1)
+  negate.18.clone.1 = f64[10,5]{1,0} negate(subtract.23.clone.1)
+  multiply.33.clone.1 = f64[10,5]{1,0} multiply(subtract.23.clone.1, subtract.23.clone.1)
+  subtract.22.clone.1 = f64[10,5]{1,0} subtract(divide.21, subtract.23.clone.1)
+  constant_78_clone_1 = f64[] constant(4)
+  broadcast.113.clone.1 = f64[10,5]{1,0} broadcast(constant_78_clone_1), dimensions={}
+  multiply.32.clone.1 = f64[10,5]{1,0} multiply(subtract.22.clone.1, broadcast.113.clone.1)
+  multiply.31.clone.1 = f64[10,5]{1,0} multiply(multiply.32.clone.1, multiply.35.clone.1)
+  subtract.21.clone.1 = f64[10,5]{1,0} subtract(multiply.33.clone.1, multiply.31.clone.1)
+  compare.12.clone.1 = pred[10,5]{1,0} compare(subtract.21.clone.1, broadcast.157.clone.1), direction=GT
+  constant_79_clone_1 = f64[] constant(2.2250738585072014e-308)
+  broadcast.112.clone.1 = f64[10,5]{1,0} broadcast(constant_79_clone_1), dimensions={}
+  maximum.18.clone.1 = f64[10,5]{1,0} maximum(broadcast.112.clone.1, subtract.21.clone.1)
+  sqrt.1.clone.1 = f64[10,5]{1,0} sqrt(maximum.18.clone.1)
+  select.47.clone.1 = f64[10,5]{1,0} select(compare.12.clone.1, sqrt.1.clone.1, broadcast.157.clone.1)
+  add.35.clone.1 = f64[10,5]{1,0} add(negate.18.clone.1, select.47.clone.1)
+  select.46.clone.1 = f64[10,5]{1,0} select(compare.13.clone.1, multiply.34.clone.1, add.35.clone.1)
+  subtract.20.clone.1 = f64[10,5]{1,0} subtract(negate.18.clone.1, select.47.clone.1)
+  multiply.30.clone.1 = f64[10,5]{1,0} multiply(subtract.22.clone.1, broadcast.117.clone.1)
+  select.45.clone.1 = f64[10,5]{1,0} select(compare.13.clone.1, subtract.20.clone.1, multiply.30.clone.1)
+  divide.8.clone.1 = f64[10,5]{1,0} divide(select.46.clone.1, select.45.clone.1)
+  clamp.6.clone.1 = f64[10,5]{1,0} clamp(broadcast.157.clone.1, divide.8.clone.1, broadcast.153.clone.1)
+  multiply.29.clone.1 = f64[10,5]{1,0} multiply(subtract.27.clone.1, clamp.6.clone.1)
+  add.32.clone.1 = f64[10,5]{1,0} add(multiply.29.clone.1, bitcast.240.clone.1)
+  select.44.clone.1 = f64[10,5]{1,0} select(compare.26.clone.1, divide.6.clone.1, add.32.clone.1)
+  select.43.clone.1 = f64[10,5]{1,0} select(compare.22.clone.1, add.33.clone.1, select.44.clone.1)
+  select.42.clone.1 = f64[10,5]{1,0} select(broadcast.111.clone.1, param_1.450, select.43.clone.1)
+  select.41 = f64[10,5]{1,0} select(broadcast.107, select.42.clone.1, broadcast.157.clone.1)
+  ROOT tuple.14 = (f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}) tuple(select.41, select.42.clone.1, clamp.6.clone.1, subtract.25.clone.1, bitcast.97.clone.1, multiply.37.clone.1, bitcast.98.clone.1, divide.21)
+}
+
+ENTRY e {
+  p3 = s64[1]{0} parameter(3)
+  f2 = f64[3]{0} fusion(p3), kind=kLoop, calls=f2
+
+  p0 = pred[5]{0} parameter(0)
+  p1 = f64[10,5]{1,0} parameter(1)
+  p2 = f64[10,5,4,3]{1,0,2,3} parameter(2)
+  p4 = f64[10,5,2]{1,0,2} parameter(4)
+  p5 = f64[10,5,4]{1,0,2} parameter(5)
+  p6 = pred[5]{0} parameter(6)
+  ROOT ret = (f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}, f64[10,5]{1,0}) fusion(p0, p1, p2, f2, p4, p5, p6), kind=kLoop, calls=f1
+}
+  )")
+                    .value();
+  auto& debug_options = module->mutable_config().mutable_debug_options();
+  debug_options.set_xla_gpu_mlir_emitter_level(3);
+  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, CommonElementwiseUsedParameter) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule m
+
+    p {
+      p0 = f32[10000000] parameter(0)
+      p1 = f32[10000000] parameter(1)
+      p2 = f32[10000000] parameter(2)
+      p3 = f32[10000000] parameter(3)
+      a0 = f32[10000000] add(p1, p2)
+      a1 = f32[10000000] add(a0, p3)
+      ROOT _ = add(p0, a1)
+    }
+
+    c1 {
+      p0 = f32[10000000] parameter(0)
+      p1 = f32[10000000] parameter(1)
+      ROOT _ = add(p0, p1)
+    }
+
+    c2 {
+      p0 = f32[10000000] parameter(0)
+      p1 = f32[10000000] parameter(1)
+      ROOT _ = multiply(p0, p1)
+    }
+
+    ENTRY entry {
+      p0 = f32[10000000] parameter(0)
+      p1 = f32[10000000] parameter(1)
+      p2 = f32[10000000] parameter(2)
+      p3 = f32[10000000] parameter(3)
+      f = f32[10000000] fusion(p0, p1, p2, p3), kind=kLoop, calls=p
+      f1 = f32[10000000] fusion(p0, f), kind=kLoop, calls=c1
+      f2 = f32[10000000] fusion(p1, f), kind=kLoop, calls=c2
+      ROOT _ = (f32[10000000], f32[10000000]) tuple(f1, f2)
+    }
+    )")
+                    .value();
+  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, IncompatibleNonTrivialHeroes) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fused_computation {
+      param_0.1 = f32[18,16,32]{2,1,0} parameter(0)
+      param_1.1 = f32[32,16,18]{2,1,0} parameter(1)
+      s.1 = f32[18,16,32]{2,1,0} sqrt(param_0.1)
+      t.1 = f32[32,16,18]{2,1,0} transpose(s.1), dimensions={2,1,0}
+      sub.1 = f32[32,16,18]{2,1,0} subtract(t.1, param_1.1)
+      exp.1 = f32[32,16,18]{2,1,0} exponential(sub.1)
+      ROOT add.1 = f32[32,16,18]{2,1,0} add(exp.1, exp.1)
+    }
+
+    fused_computation.2 {
+      param_0.2 = f32[32,16,18]{2,1,0} parameter(0)
+      s.2 = f32[32,16,18]{2,1,0} sqrt(param_0.2)
+      ROOT t.2 = f32[32,18,16]{2,1,0} transpose(s.2), dimensions={0,2,1}
+    }
+
+    ENTRY main {
+      p = f32[18,16,32]{2,1,0} parameter(0)
+      p2 = f32[32,16,18]{2,1,0} parameter(1)
+      fusion = f32[32,16,18]{2,1,0} fusion(p, p2), kind=kLoop, calls=fused_computation
+      ROOT fusion2 = f32[32,18,16]{2,1,0} fusion(fusion), kind=kInput, calls=fused_computation.2
+    }
+    )")
+                    .value();
+  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, DoNotMergeDUSFusions) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    %fused_computation (param_0: f32[8], param_1.2: f32[], param_2.3: f32[8]) -> f32[8] {
+      %param_0 = f32[8]{0} parameter(0)
+      %param_2.3 = f32[8]{0} parameter(2)
+      %slice.2 = f32[5]{0} slice(f32[8]{0} %param_2.3), slice={[0:5]}
+      %param_1.2 = f32[] parameter(1)
+      %broadcast.2 = f32[5]{0} broadcast(f32[] %param_1.2), dimensions={}
+      %add.2 = f32[5]{0} add(f32[5]{0} %slice.2, f32[5]{0} %broadcast.2)
+      %two.1 = s32[] constant(2)
+      ROOT %dynamic-update-slice.2 = f32[8]{0} dynamic-update-slice(f32[8]{0} %param_0, f32[5]{0} %add.2, s32[] %two.1)
+    }
+
+    %fused_computation.1 (param_0.1: f32[8], param_1.4: f32[6], param_2.6: f32[]) -> f32[8] {
+      %param_0.1 = f32[8]{0} parameter(0)
+      %param_1.4 = f32[6]{0} parameter(1)
+      %param_2.6 = f32[] parameter(2)
+      %broadcast.3 = f32[6]{0} broadcast(f32[] %param_2.6), dimensions={}
+      %add.3 = f32[6]{0} add(f32[6]{0} %param_1.4, f32[6]{0} %broadcast.3)
+      %three.1 = s32[] constant(3)
+      ROOT %dynamic-update-slice.3 = f32[8]{0} dynamic-update-slice(f32[8]{0} %param_0.1, f32[6]{0} %add.3, s32[] %three.1)
+    }
+
+    ENTRY %Test (parameter: f32[8]) -> f32[8] {
+      %parameter = f32[8]{0} parameter(0)
+      %slice.1 = f32[6]{0} slice(f32[8]{0} %parameter), slice={[0:6]}
+      %one = f32[] constant(1)
+      %fusion.1 = f32[8]{0} fusion(f32[8]{0} %parameter, f32[6]{0} %slice.1, f32[] %one), kind=kLoop, calls=%fused_computation.1
+      ROOT %fusion = f32[8]{0} fusion(f32[8]{0} %fusion.1, f32[] %one, f32[8]{0} %parameter), kind=kLoop, calls=%fused_computation
+    }
+    )")
+                    .value();
+  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, MergeDUSFusionWithElementwiseFusion) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    %fused_computation {
+      %param_0 = f32[1,8]{1,0} parameter(0)
+      %bitcast = f32[8]{0} bitcast(%param_0)
+      ROOT %neg = f32[8]{0} negate(%bitcast)
+    }
+
+    %fused_computation.1 {
+      %param_0.1 = f32[8]{0} parameter(0)
+      %param_1.4 = f32[5]{0} parameter(1)
+      %three.1 = s32[] constant(3)
+      %exp = f32[5]{0} exponential(%param_1.4)
+      ROOT %dynamic-update-slice.3 = f32[8]{0} dynamic-update-slice(f32[8]{0} %param_0.1, f32[5]{0} %exp, s32[] %three.1)
+    }
+
+    ENTRY %Test {
+      %parameter = f32[5]{0} parameter(0)
+      %parameter.1 = f32[1,8]{1,0} parameter(1)
+      %fusion = f32[8]{0} fusion(f32[1,8]{1,0} %parameter.1), kind=kLoop, calls=%fused_computation
+      ROOT %fusion.1 = f32[8]{0} fusion(f32[8]{0} %fusion, f32[5]{0} %parameter), kind=kLoop, calls=%fused_computation.1
+    }
+    )")
+                    .value();
+  EXPECT_TRUE(fusion_merger_.Run(module.get()).value());
+}
+
+TEST_F(FusionMergerTest, DoNotMergeTwoReduces) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      p0 = f32[] parameter(0)
+      p1 = f32[] parameter(1)
+      ROOT add.13235 = f32[] add(p0, p1)
+    }
+
+    ENTRY main {
+      p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
+      c0 = f32[] constant(0)
+      r0 = f32[8,4,128]{2,1,0} reduce(p0, c0), dimensions={3}, to_apply=add
+      ROOT r1 = f32[8,4]{1,0} reduce(r0, c0), dimensions={2}, to_apply=add
+    }
+    )")
+                    .value();
+  EXPECT_FALSE(fusion_merger_.Run(module.get()).value());
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc
new file mode 100644
index 0000000..d7f8505
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc
@@ -0,0 +1,152 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/fusion_wrapper.h"
+
+#include <functional>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "tsl/platform/errors.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> FusionWrapper::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  auto instructions = module->entry_computation()->MakeInstructionPostOrder();
+  bool changed = false;
+
+  std::function<absl::Status(HloInstruction*)> handle_instruction;
+  handle_instruction = [&](HloInstruction* instruction) -> absl::Status {
+    switch (instruction->opcode()) {
+      case HloOpcode::kConditional:
+      case HloOpcode::kWhile:
+        for (auto* computation : instruction->called_computations()) {
+          for (auto* inner_instruction :
+               computation->MakeInstructionPostOrder()) {
+            TF_RETURN_IF_ERROR(handle_instruction(inner_instruction));
+          }
+        }
+        break;
+      case HloOpcode::kAbs:
+      case HloOpcode::kAdd:
+      case HloOpcode::kAnd:
+      case HloOpcode::kAtan2:
+      case HloOpcode::kBitcastConvert:
+      case HloOpcode::kBroadcast:
+      case HloOpcode::kCeil:
+      case HloOpcode::kCbrt:
+      case HloOpcode::kClamp:
+      case HloOpcode::kClz:
+      case HloOpcode::kCompare:
+      case HloOpcode::kComplex:
+      case HloOpcode::kConcatenate:
+      case HloOpcode::kConvert:
+      case HloOpcode::kCopy:
+      case HloOpcode::kCos:
+      case HloOpcode::kDivide:
+      case HloOpcode::kDot:
+      case HloOpcode::kDynamicSlice:
+      case HloOpcode::kDynamicUpdateSlice:
+      case HloOpcode::kErf:
+      case HloOpcode::kExp:
+      case HloOpcode::kExpm1:
+      case HloOpcode::kFloor:
+      case HloOpcode::kGather:
+      case HloOpcode::kImag:
+      case HloOpcode::kIota:
+      case HloOpcode::kIsFinite:
+      case HloOpcode::kLog:
+      case HloOpcode::kLog1p:
+      case HloOpcode::kMap:
+      case HloOpcode::kMaximum:
+      case HloOpcode::kMinimum:
+      case HloOpcode::kMultiply:
+      case HloOpcode::kNegate:
+      case HloOpcode::kNot:
+      case HloOpcode::kOr:
+      case HloOpcode::kPad:
+      case HloOpcode::kPopulationCount:
+      case HloOpcode::kPower:
+      case HloOpcode::kReal:
+      case HloOpcode::kReshape:
+      case HloOpcode::kReduce:
+      case HloOpcode::kReducePrecision:
+      case HloOpcode::kReduceWindow:
+      case HloOpcode::kRemainder:
+      case HloOpcode::kReverse:
+      case HloOpcode::kRoundNearestAfz:
+      case HloOpcode::kRoundNearestEven:
+      case HloOpcode::kRsqrt:
+      case HloOpcode::kScatter:
+      case HloOpcode::kSelect:
+      case HloOpcode::kShiftLeft:
+      case HloOpcode::kShiftRightLogical:
+      case HloOpcode::kShiftRightArithmetic:
+      case HloOpcode::kSign:
+      case HloOpcode::kSin:
+      case HloOpcode::kSlice:
+      case HloOpcode::kSqrt:
+      case HloOpcode::kSubtract:
+      case HloOpcode::kStochasticConvert:
+      case HloOpcode::kTan:
+      case HloOpcode::kTanh:
+      case HloOpcode::kTranspose:
+      case HloOpcode::kXor: {
+        auto* computation = instruction->parent();
+        auto* fusion_instruction =
+            computation->AddInstruction(HloInstruction::CreateFusion(
+                instruction->shape(),
+                ChooseFusionKind(*instruction, *instruction), instruction));
+        const absl::string_view wrapped_opcode =
+            HloOpcodeString(instruction->opcode());
+        module->SetAndUniquifyInstrName(
+            fusion_instruction, absl::StrCat("wrapped_", wrapped_opcode));
+        module->SetAndUniquifyComputationName(
+            fusion_instruction->fused_instructions_computation(),
+            absl::StrCat("wrapped_", wrapped_opcode, "_computation"));
+        if (module->has_schedule()) {
+          module->schedule().replace_instruction(computation, instruction,
+                                                 fusion_instruction);
+        }
+        TF_RETURN_IF_ERROR(
+            fusion_instruction->CopyAllControlDepsFrom(instruction));
+        TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
+        TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
+        TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
+        changed = true;
+        break;
+      }
+      default:
+        break;
+    }
+    return absl::OkStatus();
+  };
+
+  for (auto* instruction : instructions) {
+    TF_RETURN_IF_ERROR(handle_instruction(instruction));
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
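
For orientation, a minimal test-style sketch of driving this new pass (illustrative only; it follows the HloTestBase pattern used by fusion_wrapper_test.cc later in this diff, and the test name plus HLO snippet are assumptions, not part of the patch):

#include <gtest/gtest.h>

#include "xla/service/gpu/transforms/fusion_wrapper.h"
#include "xla/tests/hlo_test_base.h"

namespace xla {
namespace gpu {
namespace {

using FusionWrapperUsageTest = HloTestBase;

TEST_F(FusionWrapperUsageTest, WrapsUnfusedNegate) {
  // A lone elementwise instruction in the entry computation...
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule m

    ENTRY e {
      p0 = f32[8] parameter(0)
      ROOT n = f32[8] negate(p0)
    })")
                    .value();
  // ...gets wrapped into a single-instruction kLoop fusion, so the pass
  // reports a change.
  EXPECT_TRUE(FusionWrapper().Run(module.get()).value());
}

}  // namespace
}  // namespace gpu
}  // namespace xla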
diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h
new file mode 100644
index 0000000..30b1c8a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h
@@ -0,0 +1,42 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_FUSION_WRAPPER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_FUSION_WRAPPER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Wraps leftover unfused instructions in the entry computation that have no
+// LHLO equivalent in fusions containing just that instruction.
+class FusionWrapper : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "fusion-wrapper"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_FUSION_WRAPPER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc
new file mode 100644
index 0000000..a46338f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc
@@ -0,0 +1,188 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/fusion_wrapper.h"
+
+#include <optional>
+
+#include <gtest/gtest.h>
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class FusionWrapperTest : public HloTestBase {};
+
+TEST_F(FusionWrapperTest, SimpleOp) {
+  RunAndFilecheckHloRewrite(R"(
+      HloModule TestModule
+
+      ENTRY TestComputation {
+        p0 = f16[30,41] parameter(0)
+        p1 = f16[30,41] parameter(1)
+        ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0}
+      })",
+                            FusionWrapper(), R"(
+// CHECK: %wrapped_concatenate_computation (param_0: f16[30,41], param_1: f16[30,41]) -> f16[60,41] {
+// CHECK:   %param_0 = f16[30,41]{1,0} parameter(0)
+// CHECK:   %param_1 = f16[30,41]{1,0} parameter(1)
+// CHECK:   ROOT %result.1 = f16[60,41]{1,0} concatenate(%param_0, %param_1), dimensions={0}
+// CHECK: }
+
+// CHECK: ENTRY %TestComputation (p0: f16[30,41], p1: f16[30,41]) -> f16[60,41] {
+// CHECK:   %p0 = f16[30,41]{1,0} parameter(0)
+// CHECK:   %p1 = f16[30,41]{1,0} parameter(1)
+// CHECK:   ROOT %wrapped_concatenate = f16[60,41]{1,0} fusion(%p0, %p1), kind=kLoop, calls=%wrapped_concatenate_computation
+// CHECK: })");
+}
+
+TEST_F(FusionWrapperTest, Scatter) {
+  RunAndFilecheckHloRewrite(R"(
+      HloModule ScatterIntoScalar
+
+      update_s32 {
+        lhs = s32[] parameter(0)
+        ROOT rhs = s32[] parameter(1)
+      }
+
+      ENTRY main {
+        parameter.1 = s32[] parameter(0)
+        parameter.2 = s32[0]{0} parameter(1)
+        parameter.3 = s32[] parameter(2)
+        ROOT scatter_ScatterIntoScalar = s32[] scatter(parameter.1, parameter.2, parameter.3),
+            update_window_dims={},
+            inserted_window_dims={},
+            scatter_dims_to_operand_dims={},
+            index_vector_dim=0,
+            to_apply=update_s32
+      })",
+                            FusionWrapper(), R"(
+// CHECK: wrapped_scatter_computation
+// CHECK:   %[[param_0:.*]] = s32[] parameter(0)
+// CHECK:   %[[param_1:.*]] = s32[0]{0} parameter(1)
+// CHECK:   %[[param_2:.*]] = s32[] parameter(2)
+// CHECK:   ROOT %{{.*}} = s32[] scatter(%[[param_0]], %[[param_1]], %[[param_2]])
+
+// CHECK: ENTRY
+// CHECK:   %[[p0:.*]] = s32[] parameter(0)
+// CHECK:   %[[p1:.*]] = s32[0]{0} parameter(1)
+// CHECK:   %[[p2:.*]] = s32[] parameter(2)
+// CHECK:   ROOT %{{.*}} = s32[] fusion(%[[p0]], %[[p1]], %[[p2]]), kind=kInput, calls=%wrapped_scatter_computation
+// CHECK: })");
+}
+
+TEST_F(FusionWrapperTest, ControlDependency) {
+  RunAndFilecheckHloRewrite(R"(
+      HloModule TestModule
+
+      fusion {
+        ROOT param = f32[] parameter(0)
+      }
+
+      ENTRY main {
+        param = f32[] parameter(0)
+        fusion = f32[] fusion(param), kind=kLoop, calls=fusion
+        constant_one = f32[] constant(1)
+        ROOT add = f32[] add(param, constant_one), control-predecessors={fusion}
+      })",
+                            FusionWrapper(), R"(
+// CHECK: ROOT %wrapped_add = f32[] fusion(%param.1, %constant_one),
+// CHECK-SAME: control-predecessors={%fusion})");
+}
+
+TEST_F(FusionWrapperTest, While) {
+  RunAndFilecheckHloRewrite(R"(
+      HloModule While
+
+      %body {
+        %parameter.5 = (f32[5]{0}) parameter(0)
+        %constant_8 = f32[] constant(0)
+        %broadcast.9 = f32[5]{0} broadcast(f32[] %constant_8), dimensions={}
+        ROOT %tuple.2 = (f32[5]{0}) tuple(f32[5]{0} %broadcast.9)
+      }
+
+      %cond {
+        %parameter.12 = (f32[5]{0}) parameter(0)
+        ROOT %constant_1 = pred[] constant(false)
+      }
+
+      ENTRY %main (parameter.1: f32[5]) -> (f32[5]) {
+        %parameter.1 = f32[5]{0} parameter(0)
+        %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1)
+        %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3)
+        ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body
+      })",
+                            FusionWrapper(), R"(
+// CHECK: %wrapped_broadcast_computation {{.*}} {
+// CHECK:  %param_0.1 = f32[] parameter(0)
+// CHECK:  ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0.1), dimensions={}
+// CHECK: }
+// CHECK: %body {{.*}} {
+// CHECK:   %parameter.5 = (f32[5]{0}) parameter(0)
+// CHECK:   %constant_8 = f32[] constant(0)
+// CHECK:   %wrapped_broadcast = f32[5]{0} fusion(%constant_8), kind=kLoop, calls=%wrapped_broadcast_computation
+// CHECK:   ROOT %tuple.2 = (f32[5]{0}) tuple(%wrapped_broadcast)
+// CHECK: }
+// CHECK: %cond {{.*}} {
+// CHECK:   %parameter.12 = (f32[5]{0}) parameter(0)
+// CHECK:   ROOT %constant_1 = pred[] constant(false)
+// CHECK: }
+// CHECK: %wrapped_copy_computation {{.*}} {
+// CHECK:   %param_0 = f32[5]{0} parameter(0)
+// CHECK:   ROOT %copy.0 = f32[5]{0} copy(%param_0)
+// CHECK: }
+// CHECK: ENTRY %main {{.*}} {
+// CHECK:   %parameter.1 = f32[5]{0} parameter(0)
+// CHECK:   %wrapped_copy = f32[5]{0} fusion(%parameter.1), kind=kLoop, calls=%wrapped_copy_computation
+// CHECK:   %tuple = (f32[5]{0}) tuple(%wrapped_copy)
+// CHECK:   ROOT %while.19 = (f32[5]{0}) while(%tuple), condition=%cond, body=%body
+// CHECK: })");
+}
+
+TEST_F(FusionWrapperTest, WhileInFusion) {
+  RunAndFilecheckHloRewrite(R"(
+      HloModule While
+
+      %body {
+        %parameter.5 = (f32[5]{0}) parameter(0)
+        %constant_8 = f32[] constant(0)
+        %broadcast.9 = f32[5]{0} broadcast(f32[] %constant_8), dimensions={}
+        ROOT %tuple.2 = (f32[5]{0}) tuple(f32[5]{0} %broadcast.9)
+      }
+
+      %cond {
+        %parameter.12 = (f32[5]{0}) parameter(0)
+        ROOT %constant_1 = pred[] constant(false)
+      }
+
+      %fusion {
+        %parameter.1 = f32[5]{0} parameter(0)
+        %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1)
+        %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3)
+        ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body
+      }
+
+      ENTRY %main (parameter.1: f32[5]) -> (f32[5]) {
+        %parameter.1 = f32[5]{0} parameter(0)
+        ROOT %fusion = (f32[5]{0}) fusion(f32[5]{0} %parameter.1), kind=kLoop, calls=%fusion
+      })",
+                            FusionWrapper(),
+                            // No change
+                            std::nullopt);
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc
new file mode 100644
index 0000000..3395ed8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc
@@ -0,0 +1,124 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h"
+
+#include <cstdint>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/pattern_matcher.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace m = match;
+
+class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor {
+ public:
+  absl::Status HandleCustomCall(HloInstruction *instr) override {
+    HloInstruction *existing_gemm;
+    HloInstruction *bcast;
+    if (Match(instr, m::CustomCall(&existing_gemm,
+                                   {kGemmCallTarget, kCublasLtMatmulCallTarget})
+                         .WithOperand(0, m::Broadcast(&bcast, m::Op()))) ||
+        (Match(instr, m::CustomCall(&existing_gemm, {kGemmCallTarget,
+                                                     kCublasLtMatmulCallTarget})
+                          .WithOperand(1, m::Broadcast(&bcast, m::Op()))))) {
+      TF_ASSIGN_OR_RETURN(auto gpu_config,
+                          existing_gemm->backend_config<GpuBackendConfig>());
+      GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+      DotDimensionNumbers *dim_nums = config.mutable_dot_dimension_numbers();
+      int bcast_operand_index = instr->operand_index(bcast);
+      int num_bcast_dims = (bcast->shape().dimensions_size() -
+                            bcast->operand(0)->shape().dimensions_size());
+      int num_batch_dims = dim_nums->lhs_batch_dimensions_size();
+
+      const tsl::protobuf::RepeatedField<int64_t> &batch_dimensions =
+          (bcast_operand_index == 1) ? dim_nums->rhs_batch_dimensions()
+                                     : dim_nums->lhs_batch_dimensions();
+      // This optimization is only valid if the set of broadcasted dimensions
+      // is exactly the set of batch dimensions. First, check that all newly
+      // broadcast dimensions have been inserted on the left, i.e. all new
+      // dimensions must be in [0, num_bcast_dims), or equivalently, all
+      // original dimensions are >= num_bcast_dims.
+      for (int64_t bcast_dim : bcast->dimensions()) {
+        if (bcast_dim < num_bcast_dims) {
+          return absl::OkStatus();
+        }
+        // bcast_dim should not be in batch_dimensions.
+        if (absl::c_linear_search(batch_dimensions, bcast_dim)) {
+          return absl::OkStatus();
+        }
+      }
+
+      // Then check that all batch dimensions are being broadcast, and that
+      // there is at least one batch dimension.
+      CHECK_GT(num_bcast_dims, 0);
+      if (num_bcast_dims != num_batch_dims) {
+        return absl::OkStatus();
+      }
+
+      if (bcast_operand_index == 1) {
+        CHECK_EQ(dim_nums->rhs_contracting_dimensions_size(), 1);
+        dim_nums->set_rhs_contracting_dimensions(
+            0, dim_nums->rhs_contracting_dimensions(0) - num_batch_dims);
+        dim_nums->clear_rhs_batch_dimensions();
+      } else {
+        CHECK_EQ(dim_nums->lhs_contracting_dimensions_size(), 1);
+        dim_nums->set_lhs_contracting_dimensions(
+            0, dim_nums->lhs_contracting_dimensions(0) - num_batch_dims);
+        dim_nums->clear_lhs_batch_dimensions();
+      }
+      TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWithDifferentShape(
+          bcast_operand_index, bcast->mutable_operand(0)));
+      TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config));
+      MarkAsChanged();
+    }
+    return absl::OkStatus();
+  }
+};
+
+static absl::StatusOr<bool> RunOnComputation(HloComputation *computation) {
+  GemmBroadcastFoldingVisitor visitor;
+  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
+  return visitor.changed();
+}
+
+absl::StatusOr<bool> GemmBroadcastFoldingRewriter::Run(
+    HloModule *module,
+    const absl::flat_hash_set<absl::string_view> &execution_threads) {
+  bool changed = false;
+  for (HloComputation *computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+    changed |= result;
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
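
To make the eligibility check above concrete, here is a simplified standalone sketch (illustrative only; the helper name and plain-vector types are assumptions, not part of the patch) of the same decision over one operand's dimension numbers:

#include <cstdint>
#include <optional>
#include <vector>

#include "absl/algorithm/container.h"

namespace example {

// Returns the adjusted contracting dimension if the broadcast can be folded,
// mirroring the checks in GemmBroadcastFoldingVisitor::HandleCustomCall above.
std::optional<int64_t> FoldedContractingDim(
    const std::vector<int64_t>& bcast_dims,  // broadcast->dimensions()
    int64_t num_bcast_dims,                  // newly added leading dims
    const std::vector<int64_t>& batch_dims,  // batch dims of that operand
    int64_t contracting_dim) {
  for (int64_t dim : bcast_dims) {
    // An original dimension must not land among the new leading dimensions,
    // and must not itself be a batch dimension.
    if (dim < num_bcast_dims || absl::c_linear_search(batch_dims, dim)) {
      return std::nullopt;
    }
  }
  // All batch dimensions (and nothing else) must be broadcast.
  if (num_bcast_dims != static_cast<int64_t>(batch_dims.size())) {
    return std::nullopt;
  }
  // The contracting dimension shifts left by the removed batch dimensions.
  return contracting_dim - static_cast<int64_t>(batch_dims.size());
}

}  // namespace example

For the BroadcastedStridedRewriteRhs test below, bcast_dims = {1,2}, num_bcast_dims = 1, batch_dims = {0}, and contracting_dim = 1, so the sketch returns 0, matching the rewritten rhs_contracting_dimensions in the CHECK lines.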
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h
new file mode 100644
index 0000000..8606136
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h
@@ -0,0 +1,51 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_BROADCAST_FOLDING_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_BROADCAST_FOLDING_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// cuBLAS GEMM has support for strided batched calls, where the stride is used
+// to determine the offset between the batches.
+//
+// This allows (kCustomCall:gemm A kBroadcast(B)) or
+//             (kCustomCall:gemm kBroadcast(A) B)
+// to be rewritten as (kCustomCall:gemm A B) with a zero stride for the
+// broadcasted operand if the broadcast operates on all the batch dimensions.
+//
+// This pattern matches the above case and removes the unnecessary broadcast.
+class GemmBroadcastFoldingRewriter : public HloModulePass {
+ public:
+  absl::string_view name() const override {
+    return "cublas-gemm-broadcast-folding-rewriter";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_GEMM_BROADCAST_FOLDING_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc
new file mode 100644
index 0000000..57e68fc
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc
@@ -0,0 +1,230 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h"
+
+#include <memory>
+
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+class GemmBroadcastFoldingRewriteTest : public GpuCodegenTest {
+ protected:
+  const auto& GpuComputeComp() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability();
+  }
+
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
+    // These tests exercise the cuBLAS rewriter, so we have to make sure that
+    // cuBLAS is actually used for them.
+    debug_options.set_xla_gpu_enable_triton_gemm(false);
+    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
+    return debug_options;
+  }
+};
+
+TEST_F(GemmBroadcastFoldingRewriteTest, BroadcastedStridedRewriteRhs) {
+  const char* hlo_text = R"(
+HloModule BroadcastedInput
+
+ENTRY AddDotsFunc {
+  x = f32[3,2,2]{2,1,0} parameter(0)
+  y = f32[2,2]{1,0} parameter(1)
+  y_broadcast = f32[3,2,2]{2,1,0} broadcast(y), dimensions={1,2}
+  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x, y_broadcast), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2,2], {{.*}}: f32[2,2]) -> f32[3,2,2] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,2,2]{2,1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas${{(lt\$matmul|gemm)}}",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["2"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":["0"]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_F(GemmBroadcastFoldingRewriteTest, BroadcastedStridedRewriteLhs) {
+  const char* hlo_text = R"(
+HloModule BroadcastedInput
+
+ENTRY AddDotsFunc {
+  x = f32[2,2]{1,0} parameter(0)
+  y = f32[3,2,2]{2,1,0} parameter(1)
+  x_broadcast = f32[3,2,2]{2,1,0} broadcast(x), dimensions={1,2}
+  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x_broadcast, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[3,2,2]) -> f32[3,2,2] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,2,2]{2,1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas${{(lt\$matmul|gemm)}}",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":["0"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_F(GemmBroadcastFoldingRewriteTest,
+       BroadcastedStridedRewriteRhsPassChanged) {
+  const char* hlo_text = R"(
+HloModule BroadcastedInput
+
+ENTRY AddDotsFunc {
+  x = f32[3,2,2]{2,1,0} parameter(0)
+  y = f32[2,2]{1,0} parameter(1)
+  y_broadcast = f32[3,2,2]{2,1,0} broadcast(y), dimensions={1,2}
+  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x, y_broadcast), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
+}
+
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  // Use GemmRewriter to generate cublasGemm call.
+  GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          this->RunHloPass(&gemm_rewriter, module.get()));
+  EXPECT_TRUE(changed);
+  GemmBroadcastFoldingRewriter pass;
+  TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+}
+
+TEST_F(GemmBroadcastFoldingRewriteTest,
+       BroadcastedStridedRewriteLhsPassChanged) {
+  const char* hlo_text = R"(
+HloModule BroadcastedInput
+
+ENTRY AddDotsFunc {
+  x = f32[2,2]{1,0} parameter(0)
+  y = f32[3,2,2]{2,1,0} parameter(1)
+  x_broadcast = f32[3,2,2]{2,1,0} broadcast(x), dimensions={1,2}
+  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x_broadcast, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
+}
+
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  // Use GemmRewriter to generate cublasGemm call.
+  GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          this->RunHloPass(&gemm_rewriter, module.get()));
+  EXPECT_TRUE(changed);
+  GemmBroadcastFoldingRewriter pass;
+  TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+}
+
+TEST_F(GemmBroadcastFoldingRewriteTest, LHSBatchDimNonZero) {
+  const char* hlo_text = R"(
+HloModule LHSBatchDimNonZero
+
+ENTRY %LHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] {
+  %Arg_1 = f32[4,3]{1,0} parameter(0)
+  %Arg_2 = f32[4,7,3]{2,1,0} parameter(1)
+  %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2}
+  ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[7,4,3]{2,1,0} %broadcast.22, f32[4,7,3]{2,1,0} %Arg_2), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}
+}
+)";
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  // Use GemmRewriter to generate cublasGemm call.
+  GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          this->RunHloPass(&gemm_rewriter, module.get()));
+  EXPECT_TRUE(changed);
+  GemmBroadcastFoldingRewriter pass;
+  TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(GemmBroadcastFoldingRewriteTest, RHSBatchDimNonZero) {
+  const char* hlo_text = R"(
+HloModule RHSBatchDimNonZero
+
+ENTRY %RHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] {
+  %Arg_1 = f32[4,3]{1,0} parameter(0)
+  %Arg_2 = f32[4,7,3]{2,1,0} parameter(1)
+  %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2}
+  ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[4,7,3]{2,1,0} %Arg_2, f32[7,4,3]{2,1,0} %broadcast.22), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={1}, rhs_contracting_dims={2}
+}
+)";
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter gemm_rewriter(GpuComputeComp(), /*toolkit_version=*/12040);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          this->RunHloPass(&gemm_rewriter, module.get()));
+  EXPECT_TRUE(changed);
+  GemmBroadcastFoldingRewriter pass;
+  TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_FALSE(changed);
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc
new file mode 100644
index 0000000..0801830
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc
@@ -0,0 +1,815 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_fusion.h"
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <queue>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_padding_requirements.h"
+#include "xla/service/gpu/fusions/triton/triton_support.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/triton_fusion_analysis.h"
+#include "xla/service/gpu/triton_tiling_propagation.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+using triton_fusion::CombineDotRequirements;
+using triton_fusion::DimensionOrder;
+using triton_fusion::DimOrderMap;
+using triton_fusion::DimOrdersAndReqs;
+using triton_fusion::DimOrdersAndReqsOrError;
+using triton_fusion::DotProperties;
+using triton_fusion::DotRequirements;
+using triton_fusion::DotRequirementsOrError;
+using triton_fusion::FusionContext;
+using triton_fusion::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible;
+using triton_fusion::TransformDirection;
+
+// This represents a directed graph.
+class AdjacencyList {
+ public:
+  using NodeId = int64_t;
+
+  NodeId AddNode() {
+    adj_.emplace_back();
+    return adj_.size() - 1;
+  }
+
+  const std::vector<NodeId>& GetOutNeighbors(NodeId node_id) const {
+    return adj_.at(node_id);
+  }
+
+  void ReserveSpaceForOutNeighbors(NodeId node_id, size_t count) {
+    adj_.at(node_id).reserve(count);
+  }
+
+  void AddArc(NodeId from, NodeId to) { adj_.at(from).push_back(to); }
+
+  // Currently the Root node is the node which was added first.
+  NodeId GetRoot() const {
+    CHECK(!adj_.empty());
+    return 0;
+  }
+
+ private:
+  // Adjacency list: A vector of out-neighbors for each node.
+  std::vector<std::vector<NodeId>> adj_;
+};
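
As a small illustration (names made up, not part of the patch), the graph is used as a user-to-operand DAG whose root is the first node added:

// Assumes the AdjacencyList class defined above is in scope.
void BuildTinyFusionGraph(AdjacencyList& graph) {
  AdjacencyList::NodeId root = graph.AddNode();  // id 0 == graph.GetRoot()
  AdjacencyList::NodeId lhs = graph.AddNode();   // id 1
  AdjacencyList::NodeId rhs = graph.AddNode();   // id 2
  graph.ReserveSpaceForOutNeighbors(root, 2);
  graph.AddArc(root, lhs);  // root "uses" lhs
  graph.AddArc(root, rhs);  // root "uses" rhs
  // graph.GetOutNeighbors(root) now returns {1, 2}.
}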
+
+struct HloAndDimOrder {
+  const HloInstruction* original_hlo = nullptr;
+  DimensionOrder dim_order;
+};
+
+struct HloAndIterSpec {
+  const HloInstruction* original_hlo;
+  TensorIterationSpec iter_spec;
+
+  auto ToTuple() const { return std::make_tuple(original_hlo, iter_spec); }
+  bool operator==(const HloAndIterSpec& other) const {
+    return ToTuple() == other.ToTuple();
+  }
+  template <typename H>
+  friend H AbslHashValue(H h, const HloAndIterSpec& key) {
+    return H::combine(std::move(h), key.ToTuple());
+  }
+};
+
+struct NodeFusionPlan {
+  const HloInstruction* original_hlo = nullptr;
+  bool should_fuse = false;
+};
+
+struct FusionPlan {
+  // The graph describing the structure of the fusion that we build - nodes
+  // corresponding to the instructions and arcs pointing from users to operands.
+  AdjacencyList graph;
+  // The fusion plan for each node.
+  absl::flat_hash_map<AdjacencyList::NodeId, NodeFusionPlan> map;
+};
+
+struct FusionPlanAndRequirements {
+  FusionPlan fusion_plan;
+  DotRequirements requirements;
+};
+
+struct HlosAndRequirements {
+  // The original HLO (which is outside the fusion computation).
+  const HloInstruction* original_hlo = nullptr;
+  // The fused HLO inside the new fusion computation, built by the builder.
+  //
+  // This can have the same opcode as `original_hlo` or it can be a parameter if
+  // the original HLO can't be fused.
+  const HloInstruction* fused_hlo = nullptr;
+  // The requirements imposed by the fused operations.
+  //
+  // If we fuse further operations they may have to conform to these
+  // requirements.
+  DotRequirements requirements;
+};
+
+// Clones the hero kDot operation into the fusion.
+HloInstruction& FuseDot(const HloDotInstruction& dot,
+                        const HloInstruction& fused_lhs,
+                        const HloInstruction& fused_rhs,
+                        std::optional<const HloInstruction*> fused_meta,
+                        HloComputation::Builder& builder  // append
+) {
+  VLOG(3) << "Fusing " << dot.ToString();
+
+  std::vector<HloInstruction*> hlo_new_operands = {
+      const_cast<HloInstruction*>(&fused_lhs),
+      const_cast<HloInstruction*>(&fused_rhs)};
+  if (fused_meta.has_value()) {
+    hlo_new_operands.push_back(const_cast<HloInstruction*>(fused_meta.value()));
+  }
+  return *builder.AddInstruction(
+      dot.CloneWithNewOperands(dot.shape(), hlo_new_operands));
+}
+
+// Tells how many new parameters a fusion gains by fusing the operation as an
+// input.
+int64_t NumAddedParameters(const HloInstruction& hlo) {
+  // Non-scalar constant is equivalent to a parameter: one input, one output.
+  if (hlo.opcode() == HloOpcode::kParameter ||
+      (hlo.opcode() == HloOpcode::kConstant &&
+       !ShapeUtil::IsScalar(hlo.shape()))) {
+    return 0;
+  }
+  // All other instructions add all of their own inputs and remove their own
+  // single output.
+  return hlo.operand_count() - 1;
+}
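
As a quick worked example of the accounting above: fusing a two-operand `add` turns one fusion operand (the add's result) into two (its inputs), so NumAddedParameters returns 2 - 1 = 1, while a parameter or non-scalar constant contributes 0 because it remains a single fusion input.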
+
+// Just a helper to reduce "unwrapping" code where we use this.
+std::optional<DimOrdersAndReqs> GetOperandDimOrdersAndCombinedReqs(
+    const HloInstruction& hlo, const DimensionOrder& dim_order,
+    const DotProperties& properties,
+    const se::GpuComputeCapability& gpu_version,
+    const DotRequirements& requirements) {
+  DimOrdersAndReqsOrError dim_orders_and_new_reqs =
+      GetPropagatedDimOrdersAndRequirements(
+          hlo, dim_order, TransformDirection::kOutputToInput, properties);
+  if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
+    return std::nullopt;
+  }
+  DotRequirementsOrError combined_reqs = CombineDotRequirements(
+      requirements,
+      std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
+  if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
+    return std::nullopt;
+  }
+  return DimOrdersAndReqs{
+      std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).dim_orders,
+      std::get<DotRequirements>(combined_reqs)};
+}
+
+// Just a helper to reduce "unwrapping" code where we use this.
+std::optional<DimOrdersAndReqs> GetOperandDimOrdersAndCombinedReqsIfProfitable(
+    const HloInstruction& hlo, const DimensionOrder& dim_order,
+    const DotProperties& properties,
+    const se::GpuComputeCapability& gpu_version,
+    const DotRequirements& requirements) {
+  DimOrdersAndReqsOrError dim_orders_and_new_reqs =
+      GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
+          hlo, TransformDirection::kOutputToInput,
+          /*src_operand_index=*/std::nullopt, dim_order, gpu_version,
+          properties);
+  if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
+    return std::nullopt;
+  }
+  DotRequirementsOrError combined_reqs = CombineDotRequirements(
+      requirements,
+      std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
+  if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
+    return std::nullopt;
+  }
+  return DimOrdersAndReqs{
+      std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).dim_orders,
+      std::get<DotRequirements>(combined_reqs)};
+}
+
+// Just a helper to reduce "unwrapping" code where we use this.
+std::optional<DimOrdersAndReqs> GetUserDimOrdersAndCombinedReqsIfProfitable(
+    const HloInstruction& hlo, const DimensionOrder& hlo_dim_order,
+    const HloInstruction& user, const DotProperties& properties,
+    const se::GpuComputeCapability& gpu_version,
+    const DotRequirements& requirements) {
+  DimOrdersAndReqsOrError dim_orders_and_new_reqs =
+      GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
+          user, TransformDirection::kInputToOutput, user.operand_index(&hlo),
+          hlo_dim_order, gpu_version, properties);
+  if (!std::holds_alternative<DimOrdersAndReqs>(dim_orders_and_new_reqs)) {
+    return std::nullopt;
+  }
+  DotRequirementsOrError combined_reqs = CombineDotRequirements(
+      requirements,
+      std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).requirements);
+  if (!std::holds_alternative<DotRequirements>(combined_reqs)) {
+    return std::nullopt;
+  }
+  return DimOrdersAndReqs{
+      std::get<DimOrdersAndReqs>(dim_orders_and_new_reqs).dim_orders,
+      std::get<DotRequirements>(combined_reqs)};
+}
+
+// Builds the fusion map and the requirements which can later be used to
+// actually fuse that subgraph.
+FusionPlanAndRequirements BuildFusionPlanTowardOperands(
+    const HloInstruction& root_hlo, const DimensionOrder& root_dim_order,
+    const std::optional<int>& max_params,
+    const se::GpuComputeCapability& gpu_version,
+    const DotProperties& properties,
+    const DotRequirements& requirements_so_far) {
+  CHECK(!max_params.has_value() || max_params.value() >= 1);
+
+  // The graph describing the structure of the fusion that we build - nodes
+  // corresponding to the instructions and arcs pointing from users to operands.
+  // We can build and modify this graph easily without the need to create
+  // HloInstructions at this point.
+  AdjacencyList graph;
+  // Stores the original HLO and the dimension order for each node. This is a
+  // temporary map which is used when processing the nodes in this function.
+  absl::flat_hash_map<AdjacencyList::NodeId, HloAndDimOrder>
+      hlo_and_dim_order_map;
+  // Stores the information needed to build the fused HLO for each node (what
+  // was the original HLO and whether we should fuse it or create a parameter).
+  // This is one of the outputs of this function.
+  absl::flat_hash_map<AdjacencyList::NodeId, NodeFusionPlan> fusion_plan_map;
+  // Allows reusing nodes when multiple instructions iterate over the same HLO
+  // using the same iteration spec. In that case we don't duplicate the
+  // instruction in the fusion.
+  absl::flat_hash_map<HloAndIterSpec, AdjacencyList::NodeId> node_reuse_map;
+  // The requirements imposed by the fusion choices made in this function,
+  // combined with the existing requirements. This is one of the outputs of this
+  // function.
+  DotRequirements combined_reqs = requirements_so_far;
+
+  auto get_or_create_fusion_node =
+      [&](const HloInstruction& hlo, const DimensionOrder& dim_order,
+          bool* is_new_node = nullptr) -> AdjacencyList::NodeId {
+    HloAndIterSpec reuse_key = {&hlo, dim_order.ToTensorIterationSpec()};
+    if (auto it = node_reuse_map.find(reuse_key); it != node_reuse_map.end()) {
+      if (is_new_node != nullptr) {
+        *is_new_node = false;
+      }
+      return it->second;
+    }
+    AdjacencyList::NodeId node_id = graph.AddNode();
+    CHECK(hlo_and_dim_order_map.insert({node_id, {&hlo, dim_order}}).second);
+    CHECK(node_reuse_map.insert({reuse_key, node_id}).second);
+    if (is_new_node != nullptr) {
+      *is_new_node = true;
+    }
+    return node_id;
+  };
+  AdjacencyList::NodeId root =
+      get_or_create_fusion_node(root_hlo, root_dim_order);
+
+  // Nodes at the fusion edge that can either get fused too or become parameters
+  // of the fusion. Used to track the number of parameters.
+  absl::flat_hash_set<AdjacencyList::NodeId> inputs({root});
+  std::queue<AdjacencyList::NodeId> queue({root});
+  int64_t num_requeued = 0;
+  // BFS
+  while (queue.size() > num_requeued) {
+    AdjacencyList::NodeId node_id = queue.front();
+    queue.pop();
+    const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at(node_id);
+    const HloInstruction& original_hlo = *hlo_and_dim_order.original_hlo;
+    const DimensionOrder& dim_order = hlo_and_dim_order.dim_order;
+
+    // Watch the total number of fusion parameters.
+    if (max_params.has_value() &&
+        inputs.size() + NumAddedParameters(original_hlo) > max_params.value()) {
+      // Re-queue: the number of parameters may go down when other instructions
+      // are processed.
+      queue.push(node_id);
+      // Prevent infinite loops.
+      ++num_requeued;
+      continue;
+    }
+    num_requeued = 0;
+    if (original_hlo.opcode() == HloOpcode::kParameter) {
+      CHECK(fusion_plan_map
+                .insert({node_id, {&original_hlo, /*should_fuse=*/false}})
+                .second);
+      continue;
+    }
+    auto opt_result = GetOperandDimOrdersAndCombinedReqsIfProfitable(
+        original_hlo, dim_order, properties, gpu_version, combined_reqs);
+    if (!opt_result.has_value()) {
+      CHECK(fusion_plan_map
+                .insert({node_id, {&original_hlo, /*should_fuse=*/false}})
+                .second);
+      continue;
+    }
+    const DimOrderMap operand_dim_orders = std::move(opt_result->dim_orders);
+    combined_reqs = std::move(opt_result->requirements);
+    inputs.erase(node_id);
+    graph.ReserveSpaceForOutNeighbors(node_id, original_hlo.operand_count());
+    for (int64_t i = 0; i < original_hlo.operand_count(); ++i) {
+      const HloInstruction& operand = *original_hlo.operand(i);
+      const DimensionOrder& operand_dim_order = operand_dim_orders.at(&operand);
+      bool is_new_node = false;
+      AdjacencyList::NodeId operand_node_id =
+          get_or_create_fusion_node(operand, operand_dim_order, &is_new_node);
+      graph.AddArc(node_id, operand_node_id);
+      if (is_new_node) {
+        VLOG(6) << "Enqueueing " << operand.ToString() << ":"
+                << operand_dim_order.ToString();
+        inputs.insert(operand_node_id);
+        queue.push(operand_node_id);
+      }
+    }
+    CHECK(
+        fusion_plan_map.insert({node_id, {&original_hlo, /*should_fuse=*/true}})
+            .second);
+  }
+  // Handle the remaining requeued items.
+  while (!queue.empty()) {
+    AdjacencyList::NodeId node_id = queue.front();
+    queue.pop();
+
+    const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at(node_id);
+    CHECK(fusion_plan_map
+              .insert({node_id,
+                       {hlo_and_dim_order.original_hlo, /*should_fuse=*/false}})
+              .second);
+  }
+  return {{std::move(graph), std::move(fusion_plan_map)},
+          std::move(combined_reqs)};
+}
+
+// Builds the HLO instructions for the fusion represented by `fusion_plan`,
+// starting from `node_id`.
+HloInstruction& BuildFusionTowardOperandsImpl(
+    AdjacencyList::NodeId node_id, const FusionPlan& fusion_plan,
+    absl::flat_hash_map<AdjacencyList::NodeId, HloInstruction*>&
+        fused_hlo_map,                           // read/append
+    HloComputation::Builder& builder,            // append
+    std::vector<HloInstruction*>& fusion_params  // append
+) {
+  if (auto it = fused_hlo_map.find(node_id); it != fused_hlo_map.end()) {
+    return *it->second;
+  }
+
+  const NodeFusionPlan& node_fusion_plan = fusion_plan.map.at(node_id);
+  const bool should_fuse = node_fusion_plan.should_fuse;
+  const HloInstruction& original_hlo = *node_fusion_plan.original_hlo;
+
+  HloInstruction* fused_hlo = nullptr;
+  if (should_fuse) {
+    HloInstruction::InstructionVector new_operands;
+    for (AdjacencyList::NodeId operand_id :
+         fusion_plan.graph.GetOutNeighbors(node_id)) {
+      new_operands.push_back(&BuildFusionTowardOperandsImpl(
+          operand_id, fusion_plan, fused_hlo_map, builder, fusion_params));
+    }
+    fused_hlo = builder.AddInstruction(
+        original_hlo.CloneWithNewOperands(original_hlo.shape(), new_operands));
+  } else {
+    fusion_params.push_back(const_cast<HloInstruction*>(&original_hlo));
+    fused_hlo = builder.AddInstruction(HloInstruction::CreateParameter(
+        fusion_params.size() - 1, original_hlo.shape(),
+        absl::StrCat("parameter_", fusion_params.size() - 1)));
+  }
+
+  CHECK(fused_hlo_map.insert({node_id, fused_hlo}).second);
+  return *fused_hlo;
+}
+
+// Builds the HLO instructions for the fusion represented by `fusion_plan`.
+HloInstruction& BuildFusionTowardOperands(
+    const FusionPlan& fusion_plan,
+    HloComputation::Builder& builder,            // append
+    std::vector<HloInstruction*>& fusion_params  // append
+) {
+  absl::flat_hash_map<AdjacencyList::NodeId, HloInstruction*> fused_hlo_map;
+  return BuildFusionTowardOperandsImpl(fusion_plan.graph.GetRoot(), fusion_plan,
+                                       fused_hlo_map, builder, fusion_params);
+}
+
+// Grows the fusion toward the operands.
+//
+// This always succeeds.
+//
+// If it's not possible to fuse something, it fuses a parameter instead.
+//
+// The fusion can grow until it has `max_params` parameters, and it can only
+// grow with operations for which the DimOrder propagation works and which
+// don't impose requirements contradicting the existing ones.
+//
+// The return value contains the HLOs corresponding to `root_hlo` and the
+// requirements corresponding to the whole fusion so far.
+HlosAndRequirements FuseTowardOperands(
+    const HloInstruction& root_hlo, const DimensionOrder& root_dim_order,
+    const std::optional<int>& max_params,
+    const se::GpuComputeCapability& gpu_version,
+    const DotProperties& properties, const DotRequirements& requirements_so_far,
+    HloComputation::Builder& builder,            // append
+    std::vector<HloInstruction*>& fusion_params  // append
+) {
+  FusionPlanAndRequirements fusion_plan_and_reqs =
+      BuildFusionPlanTowardOperands(root_hlo, root_dim_order, max_params,
+                                    gpu_version, properties,
+                                    requirements_so_far);
+  HloInstruction& fused_hlo_or_param = BuildFusionTowardOperands(
+      fusion_plan_and_reqs.fusion_plan, builder, fusion_params);
+  return HlosAndRequirements{&root_hlo, &fused_hlo_or_param,
+                             fusion_plan_and_reqs.requirements};
+}
+
+// Grows the fusion toward the given dot operand.
+//
+// This always succeeds.
+//
+// If it's not possible to fuse something, it fuses a parameter instead.
+//
+// The fusion can grow until it has `max_params` parameters, and it can only
+// grow with operations for which the DimOrder propagation works and which
+// don't impose requirements contradicting the existing ones.
+//
+// The return value contains the HLOs corresponding to the given dot operand and
+// the requirements corresponding to the whole fusion so far.
+absl::StatusOr<HlosAndRequirements> FuseDotOperand(
+    const HloInstruction& dot, int operand_index,
+    const se::GpuComputeCapability& gpu_version,
+    HloComputation::Builder& builder,            // append
+    std::vector<HloInstruction*>& fusion_params  // append
+) {
+  // Direct dot inputs have well-defined dimension orders.
+  TF_ASSIGN_OR_RETURN(const FusionContext context,
+                      FusionContext::FromDotOperand(dot, operand_index));
+  const HloInstruction& operand = *dot.operand(operand_index);
+  return FuseTowardOperands(operand, context.dim_orders().at(&operand),
+                            TritonFusionAnalysis::kMaxParameterPerDotOperand,
+                            gpu_version, context.dot_properties(),
+                            context.requirements(), builder, fusion_params);
+}
+
+// Grows the fusion toward the users.
+//
+// This always succeeds.
+//
+// The fusion can grow as long as the DimOrder propagation works and the users
+// don't impose requirements contradicting the existing requirements.
+//
+// The return value contains the HLOs corresponding to the "lowest" fused user
+// or `hlo` if no users can be fused.
+//
+// It also grows the fusion upward, toward the "other" operands of the users,
+// but currently only in special cases, such as a binary elementwise operation
+// with a broadcast of a scalar constant.
+HlosAndRequirements FuseTowardUsers(
+    const HloInstruction& hlo, const HloInstruction& fused_hlo,
+    const DimensionOrder& hlo_dim_order,
+    const se::GpuComputeCapability& gpu_version,
+    const DotProperties& properties, const DotRequirements& requirements,
+    HloComputation::Builder& builder,            // append
+    std::vector<HloInstruction*>& fusion_params  // append
+) {
+  const HlosAndRequirements existing_hlos_and_requirements = {&hlo, &fused_hlo,
+                                                              requirements};
+  if (hlo.user_count() != 1) {
+    return existing_hlos_and_requirements;
+  }
+  const HloInstruction& user = *hlo.users()[0];
+  if (!legacy_triton::IsDistributiveOverAddition(user)) {
+    return existing_hlos_and_requirements;
+  }
+
+  // Get the dim orders for the user.
+  auto opt_user_result = GetUserDimOrdersAndCombinedReqsIfProfitable(
+      hlo, hlo_dim_order, user, properties, gpu_version, requirements);
+  if (!opt_user_result.has_value()) {
+    return existing_hlos_and_requirements;
+  }
+  DimensionOrder user_dim_order = opt_user_result->dim_orders.at(&user);
+  DotRequirements combined_requirements = opt_user_result->requirements;
+
+  HloInstruction::InstructionVector new_operands;
+  if (user.operand_count() == 1) {
+    new_operands.push_back(const_cast<HloInstruction*>(&fused_hlo));
+  } else {
+    // Get the dim orders for the operands of the user.
+    // We shouldn't do a profitability check here; that decision was already
+    // made in GetUserDimOrdersAndCombinedReqsIfProfitable.
+    auto opt_operand_result = GetOperandDimOrdersAndCombinedReqs(
+        user, user_dim_order, properties, gpu_version, combined_requirements);
+    // This shouldn't fail, because currently we only encounter this when we
+    // have just propagated down the DimOrders on a binary elementwise
+    // operation (user). In that case propagating up the DimOrders should always
+    // work.
+    if (!opt_operand_result.has_value()) {
+      return existing_hlos_and_requirements;
+    }
+    DimOrderMap operand_dim_orders = opt_operand_result->dim_orders;
+    combined_requirements = opt_operand_result->requirements;
+
+    // Fuse the other operands of the user.
+    for (int i = 0; i < user.operand_count(); ++i) {
+      const HloInstruction& operand = *user.operand(i);
+      if (&operand == &hlo) {
+        new_operands.push_back(const_cast<HloInstruction*>(&fused_hlo));
+      } else {
+        HlosAndRequirements hlos_and_requirements = FuseTowardOperands(
+            operand, operand_dim_orders.at(&operand),
+            /*max_params=*/std::nullopt, gpu_version, properties,
+            combined_requirements, builder, fusion_params);
+        new_operands.push_back(
+            const_cast<HloInstruction*>(hlos_and_requirements.fused_hlo));
+        combined_requirements = hlos_and_requirements.requirements;
+      }
+    }
+  }
+
+  const HloInstruction& fused_user = *builder.AddInstruction(
+      user.CloneWithNewOperands(user.shape(), new_operands));
+  return FuseTowardUsers(user, fused_user, user_dim_order, gpu_version,
+                         properties, combined_requirements, builder,
+                         fusion_params);
+}
+
+// Grows the fusion toward the users of the dot.
+//
+// This always succeeds.
+//
+// The fusion can grow as long as the DimOrder propagation works and the users
+// don't impose requirements contradicting the existing requirements.
+//
+// The return value contains the HLOs corresponding to the "lowest" fused user
+// or `dot` if no users can be fused.
+//
+// It also grows the fusion toward the "other" operands of the users, but
+// currently only in special cases, such as a binary elementwise operation with
+// a broadcast of a scalar constant.
+HlosAndRequirements FuseDotOutput(
+    const HloInstruction& dot, const HloInstruction& fused_dot,
+    const se::GpuComputeCapability& gpu_version,
+    const DotRequirements& requirements,
+    HloComputation::Builder& builder,            // append
+    std::vector<HloInstruction*>& fusion_params  // append
+) {
+  const auto context =
+      FusionContext::FromDotOutput(dot, /*split_k=*/1, requirements);
+  return FuseTowardUsers(dot, fused_dot, context.dim_orders().at(&dot),
+                         gpu_version, context.dot_properties(),
+                         context.requirements(), builder, fusion_params);
+}
+
+// Fuses dot together with the operations around it that are compatible and
+// profitable to fuse into a new fusion computation constructed using the
+// builder. fusion_inputs is populated with the non-fused instructions that
+// become operands of the call to this fusion. fusion_output_ptr (if not
+// nullptr) is assigned the original instruction that has to be replaced by the
+// call to the fusion.
+absl::StatusOr<FusionDecision> CreateDotFusion(
+    const HloDotInstruction& dot, const se::GpuComputeCapability gpu_version,
+    HloComputation::Builder& builder,
+    std::vector<HloInstruction*>& fusion_inputs,
+    HloInstruction** fusion_output_ptr) {
+  VLOG(5) << dot.ToString();
+  if (CodegenDecision is_supported =
+          legacy_triton::IsTritonSupportedInstruction(dot, gpu_version);
+      !is_supported) {
+    VLOG(3) << is_supported.Explain();
+    return is_supported;
+  }
+
+  // Verify sparse dot constraints.
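+  // 2:4 structured sparsity means that at most 2 of every 4 consecutive
+  // elements along the sparse dimension are non-zero; it is currently only
+  // accepted on the LHS operand.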
+  if (dot.sparse_operands()) {
+    const SparsityDescriptor& descriptor = dot.sparsity().front();
+    if (dot.sparse_operands() != 1 || descriptor.index() != 0) {
+      return InvalidArgument("Sparsity is only supported on left operand");
+    }
+    if (descriptor.type() != SparsityType::SPARSITY_STRUCTURED_N_M ||
+        descriptor.n() != 2 || descriptor.m() != 4) {
+      return InvalidArgument("Only 2:4 structured sparsity is supported");
+    }
+    // DotDimensionSorter pass makes sure the sparse dimension is minor.
+    CHECK_EQ(descriptor.dimension(), dot.operand(0)->shape().rank() - 1);
+  }
+
+  TF_ASSIGN_OR_RETURN(HlosAndRequirements lhs_hlos_and_reqs,
+                      FuseDotOperand(dot, /*operand_index=*/0, gpu_version,
+                                     builder, fusion_inputs));
+  TF_ASSIGN_OR_RETURN(HlosAndRequirements rhs_hlos_and_reqs,
+                      FuseDotOperand(dot, /*operand_index=*/1, gpu_version,
+                                     builder, fusion_inputs));
+  std::optional<const HloInstruction*> meta_hlo;
+  if (dot.sparse_operands()) {
+    TF_ASSIGN_OR_RETURN(HlosAndRequirements meta_hlos_and_reqs,
+                        FuseDotOperand(dot, /*operand_index=*/2, gpu_version,
+                                       builder, fusion_inputs));
+    meta_hlo.emplace(meta_hlos_and_reqs.fused_hlo);
+  }
+  HloInstruction& fused_dot =
+      FuseDot(dot, *lhs_hlos_and_reqs.fused_hlo, *rhs_hlos_and_reqs.fused_hlo,
+              meta_hlo, builder);
+  // For now the RHS doesn't support splits, so it also doesn't impose any
+  // requirements.
+  HlosAndRequirements fused_output_and_reqs =
+      FuseDotOutput(dot, fused_dot, gpu_version, lhs_hlos_and_reqs.requirements,
+                    builder, fusion_inputs);
+
+  if (fusion_output_ptr != nullptr) {
+    *fusion_output_ptr =
+        const_cast<HloInstruction*>(fused_output_and_reqs.original_hlo);
+  }
+
+  const PrecisionConfig::Algorithm algorithm =
+      dot.precision_config().algorithm();
+  if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 ||
+      algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
+      dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() ||
+      dot.sparse_operands()) {
+    return FusionDecision{};
+  }
+
+  bool is_pure_matmul = true;
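+  // Walk over every instruction added to the builder so far. Returning a
+  // non-OK status from the callback ends the walk early; the resulting status
+  // is intentionally discarded via the (void) cast.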
+  (void)builder.ForEachInstruction([&](const HloInstruction* fused_hlo) {
+    static constexpr std::array<HloOpcode, 4> kPureOpcodes = {
+        HloOpcode::kBitcast, HloOpcode::kDot, HloOpcode::kParameter,
+        HloOpcode::kReshape};
+    if (absl::c_find(kPureOpcodes, fused_hlo->opcode()) == kPureOpcodes.end()) {
+      is_pure_matmul = false;
+      // Stop iterating.
+      return absl::CancelledError();
+    }
+    return absl::OkStatus();
+  });
+  if (!is_pure_matmul) {
+    return FusionDecision{};
+  }
+
+  return "No profitable operations to fuse.";
+}
+
+// Extracts parts of the HLO graph, including dot() operations that can target
+// the Triton GEMM emitter, into fused computations.
+class GemmFusionVisitor : public DfsHloRewriteVisitor {
+ public:
+  explicit GemmFusionVisitor(const se::GpuComputeCapability& gpu_version)
+      : gpu_version_(gpu_version) {}
+  // Checks whether a dot() should target the Triton GEMM emitter; if so,
+  // fuses all its compatible inputs and outputs into a new computation and
+  // replaces the original dot() with a call to that computation.
+  absl::Status HandleDot(HloInstruction* dot) override {
+    CHECK_EQ(dot->opcode(), HloOpcode::kDot);
+
+    int64_t gemm_rewrite_size_threshold =
+        dot->GetModule()
+            ->config()
+            .debug_options()
+            .xla_gpu_gemm_rewrite_size_threshold();
+    TF_ASSIGN_OR_RETURN(bool is_matmul_tiny,
+                        IsMatrixMultiplicationTooSmallForRewriting(
+                            *dot, gemm_rewrite_size_threshold));
+    if (is_matmul_tiny && IsDotSupportedByClassicalEmitters(*dot)) {
+      return absl::OkStatus();
+    }
+
+    std::string fusion_name = absl::StrCat("gemm_fusion_", dot->name());
+    HloComputation::Builder builder(absl::StrCat(fusion_name, "_computation"));
+    std::vector<HloInstruction*> fusion_inputs;
+    HloInstruction* fusion_output = nullptr;
+    TF_ASSIGN_OR_RETURN(
+        const FusionDecision should_fuse,
+        CreateDotFusion(*Cast<HloDotInstruction>(dot), gpu_version_, builder,
+                        fusion_inputs, &fusion_output));
+    if (builder.last_added_instruction() == nullptr) {
+      return absl::OkStatus();
+    }
+    // If a GEMM that requires padding for cuBLAS is encountered here, it is
+    // because ShouldTritonHandleGEMM() accepted it earlier and the padding was
+    // skipped. Accept it here as well, ignoring the profitability checks.
+    // TODO(rocm): check ROCM padding requirements.
+    if (std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
+      if (!CublasRequiresPadding(
+              *Cast<HloDotInstruction>(dot),
+              std::get<se::CudaComputeCapability>(gpu_version_)) &&
+          !should_fuse) {
+        return absl::OkStatus();
+      }
+    }
+
+    HloComputation* computation =
+        dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(),
+                                                            /*is_entry=*/false);
+    HloInstruction* dot_fusion =
+        dot->parent()->AddInstruction(HloInstruction::CreateFusion(
+            computation->root_instruction()->shape(),
+            HloInstruction::FusionKind::kCustom, fusion_inputs, computation));
+    // Copy the metadata of the `dot` to the newly created `fusion` op. This is
+    // convenient for handling the metadata during the subsequent split-k
+    // rewriting.
+    dot_fusion->set_metadata(dot->metadata());
+    dot_fusion->GetModule()->SetAndUniquifyInstrName(dot_fusion, fusion_name);
+
+    TF_ASSIGN_OR_RETURN(auto gpu_config,
+                        dot_fusion->backend_config<GpuBackendConfig>());
+    FusionBackendConfig& backend_config =
+        *gpu_config.mutable_fusion_backend_config();
+    backend_config.set_kind(std::string(kTritonGemmFusionKind));
+    TF_RETURN_IF_ERROR(dot_fusion->set_backend_config(gpu_config));
+
+    if (fusion_output->IsRoot()) {
+      fusion_output->parent()->set_root_instruction(dot_fusion);
+      TF_RETURN_IF_ERROR(
+          fusion_output->parent()->RemoveInstructionAndUnusedOperands(
+              fusion_output));
+      MarkAsChanged();
+    } else {
+      TF_RETURN_IF_ERROR(ReplaceInstruction(fusion_output, dot_fusion));
+    }
+    XLA_VLOG_LINES(5, computation->ToString(HloPrintOptions::ShortParsable()));
+    return absl::OkStatus();
+  }
+
+ private:
+  se::GpuComputeCapability gpu_version_;
+};
+
+absl::StatusOr<bool> RunOnComputation(
+    HloComputation* computation, const se::GpuComputeCapability& gpu_version) {
+  GemmFusionVisitor visitor(gpu_version);
+  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
+  return visitor.changed();
+}
+
+}  // namespace
+
+bool ShouldTritonHandleGEMM(HloDotInstruction& dot,
+                            const se::GpuComputeCapability& gpu_version) {
+  std::vector<HloInstruction*> fusion_inputs;
+  HloComputation::Builder builder("disposable");
+  return CreateDotFusion(dot, gpu_version, builder, fusion_inputs,
+                         /*fusion_output_ptr=*/nullptr)
+      ->CanFuse();
+}
+
+absl::StatusOr<bool> GemmFusion::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  TF_RETURN_IF_ERROR(
+      EnsureTritonSupportsComputeCapability(compute_capability_));
+
+  bool changed = false;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool result,
+                        RunOnComputation(computation, compute_capability_));
+    changed |= result;
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.h b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.h
new file mode 100644
index 0000000..7f8fe6f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.h
@@ -0,0 +1,57 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_
+
+// This file contains the code for fusing dots and other operations into Triton
+// GEMM fusions.
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// Returns true for GEMMs that are better handled using Triton.
+bool ShouldTritonHandleGEMM(HloDotInstruction&,
+                            const se::GpuComputeCapability&);
+
+// Rewrites compatible dot() calls into custom fusions whose computations
+// target the Triton-based matmul emitter.
+class GemmFusion : public HloModulePass {
+ public:
+  explicit GemmFusion(const se::GpuComputeCapability& compute_capability)
+      : compute_capability_(compute_capability) {}
+  absl::string_view name() const override { return "triton-gemm-rewriter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  se::GpuComputeCapability compute_capability_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_
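For orientation, a minimal sketch of how this pass can be driven, mirroring the pattern used throughout gemm_fusion_test.cc below; the HLO text and the Ampere compute capability are illustrative assumptions (an HloTestBase-style test fixture is presumed), not part of this change:

    // Parse a module and run GemmFusion for a chosen compute capability. The
    // returned bool reports whether any dot() was rewritten into a kCustom
    // (__triton_gemm) fusion.
    se::GpuComputeCapability gpu_version{
        se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}};
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                            ParseAndReturnVerifiedModule(R"(
    ENTRY e {
      p0 = s8[128,64] parameter(0)
      c0 = f16[128,64] convert(p0)
      p1 = f16[64,32] parameter(1)
      ROOT d = f16[128,32] dot(c0, p1),
        lhs_contracting_dims={1}, rhs_contracting_dims={0}
    })"));
    TF_ASSERT_OK_AND_ASSIGN(bool changed,
                            GemmFusion(gpu_version).Run(module.get()));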
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc
new file mode 100644
index 0000000..85ad2e8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc
@@ -0,0 +1,1334 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_fusion.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/autotuning.pb.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/cublas_padding_requirements.h"
+#include "xla/service/gpu/triton_fusion_analysis.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::FieldsAre;
+
+namespace m = ::xla::match;
+
+class GemmFusionTest : public HloTestBase {
+ public:
+  GemmFusionTest()
+      : HloTestBase(/*verifier_layout_sensitive=*/true,
+                    /*allow_mixed_precision_in_hlo_verifier=*/false) {}
+
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_triton_gemm_any(false);
+    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
+    return debug_options;
+  }
+
+  se::GpuComputeCapability gpu_version_{
+      se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}};
+
+  void MatchHloModule(HloModule& module, absl::string_view pattern) {
+    TF_ASSERT_OK_AND_ASSIGN(bool filecheck_result,
+                            RunFileCheck(module.ToString(), pattern));
+    EXPECT_TRUE(filecheck_result);
+  }
+};
+
+TEST_F(GemmFusionTest, TransposeSubdimensionGroup) {
+  // This HLO is artificial because unnecessary reshapes get optimized
+  // out during compilation. It tests the ability of GemmFusion
+  // to handle transposes of groups of subdimensions.
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+  p0 = f32[32,3] parameter(0)
+  t1 = f32[3,32] transpose(p0), dimensions={1,0}
+  r1 = f32[3,8,4] reshape(t1)
+  r0 = f32[3,32] reshape(r1)
+  p1 = f16[32,7] parameter(1)
+  c1 = f32[32,7] convert(p1)
+  ROOT d = f32[3,7] dot(r0, c1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})")
+                    .value();
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionTest, UnsupportedTransposeIsNotFused) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f16[1,512,8,1024]{3,1,0,2} parameter(0)
+  c = f16[1,512,8,1024]{3,2,1,0} copy(p0)
+  b = f16[4096,1024]{1,0} bitcast(c)
+  p1 = f16[128,1024]{1,0} parameter(1)
+  ROOT d = f16[4096,128]{1,0} dot(b, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})")
+                    .value();
+  EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, BitcastChain) {
+  // This HLO is artificial because unnecessary reshapes get optimized
+  // out during compilation. It tests the ability of GemmFusion
+  // to handle various kinds of bitcasts.
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+  p0 = s8[60,5] parameter(0)
+  r0 = s8[3,20,5] reshape(p0)
+  c0 = f16[3,20,5] convert(r0)
+  p1 = f16[3,200] parameter(1)
+  r12 = f16[600] reshape(p1)
+  r11 = f16[30,20] reshape(r12)
+  r1 = f16[3,10,20] reshape(r11)
+  ROOT d = f16[3,5,10] dot(c0, r1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={2},
+    lhs_batch_dims={0}, rhs_batch_dims={0}
+})")
+                    .value();
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionTest, SplitDimensionTwice) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = s8[4,2,32,4,2] parameter(0)
+  r1 = s8[8,32,8] reshape(p0)
+  t1 = s8[32,8,8] transpose(r1), dimensions={1,0,2}
+  r0 = s8[32,64] reshape(t1)
+  p1 = s8[32,32] parameter(1)
+  c0 = f16[32,32] convert(p1)
+  ROOT d = f16[64,32] dot(r0, c0),
+    lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})")
+                    .value();
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionTest, DoNotTriggerOnUnsupportedOutputConversions) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f16[128,256] parameter(0)
+  p1 = f16[256,512] parameter(1)
+  r = f16[128,512] dot(p0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT c = u8[128,512] convert(r)
+})"));
+  EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, FuseDotWithTrivialNoncontractingDim) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+  p0 = s8[60,5] parameter(0)
+  r0 = s8[3,20,5] reshape(p0)
+  c0 = f16[3,20,5] convert(r0)
+  p1 = f16[3,1,20] parameter(1)
+  ROOT d = f16[3,5,1] dot(c0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={2},
+    lhs_batch_dims={0}, rhs_batch_dims={0}
+})")
+                    .value();
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionTest, HandleDotIfCublasRequiresPadding) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[5,3] parameter(0)
+  p1 = f16[5,7] parameter(1)
+  ROOT d = f16[3,7] dot(p0, p1),
+    lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+
+  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+  EXPECT_TRUE(CublasRequiresPadding(
+      *xla::Cast<HloDotInstruction>(
+          module->entry_computation()->root_instruction()),
+      cc));
+  EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, FuseSliceOfParameterWithOtherUsers) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f32[97,121] parameter(0)
+  s0 = f32[7,101] slice(p0), slice={[3:10], [10:111]}
+  p1 = f32[101,16] parameter(1)
+  d = f32[16,7] dot(p1, s0),
+    lhs_contracting_dims={0}, rhs_contracting_dims={1}
+  s1 = f32[3,33] slice(p0), slice={[10:13], [20:53]}
+  ROOT t = tuple(d, s1)
+})"));
+
+  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+  EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, DoNotFuseSliceOfMixedDimensions) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = bf16[768,64] parameter(0)
+  s0 = bf16[768,32] slice(p0), slice={[0:768], [0:32]}
+  b0 = bf16[256,3,32] reshape(s0)
+  b1 = bf16[256,96] reshape(b0)
+  p1 = bf16[256,96] parameter(1)
+  ROOT d = bf16[96,96] dot(b1, p1),
+    lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+
+  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+  EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, DoNotFuseSlicesOfNonMajorFragments) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f32[2,2,256,256] parameter(0)
+  s0 = f32[1,1,256,256] slice(p0),
+    slice={[0:1], [0:1], [0:256], [0:256]}
+  r0 = f32[256,256] reshape(s0)
+  p1 = f16[2,2,256,256] parameter(1)
+  s1 = f16[1,1,256,256] slice(p1),
+    slice={[0:1], [0:1], [0:256], [0:256]}
+  r1 = f16[256,256] reshape(s1)
+  ROOT d = f32[256,256] dot(r0, r1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+
+  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+  EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, DynamicSliceIsFused) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  dot_lhs = f32[2,18] parameter(0)
+  dynamic_slice_input = f32[2,64,2] parameter(1)
+  start_index0 = s32[] parameter(2)
+  start_index1_2 = s32[] constant(0)
+  dynamic_slice = f32[1,64,2] dynamic-slice(dynamic_slice_input, start_index0, start_index1_2, start_index1_2),
+                  dynamic_slice_sizes={1,64,2}
+  reshape = f32[64,2] reshape(dynamic_slice)
+  ROOT dot = f16[18,64] dot(dot_lhs, reshape),
+             lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})"));
+
+  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+                             se::CudaComputeCapability::AMPERE, 0})
+                  .Run(module.get())
+                  .value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
+                                    m::Parameter(), m::Constant()))));
+}
+
+TEST_F(GemmFusionTest, DynamicSlicesAreFusedEvenIfTheyShareIndices) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f32[2,64,2] parameter(0)
+  p1 = s32[] parameter(1)
+  p2 = s32[] parameter(2)
+  p3 = s32[] parameter(3)
+  ds0 = f32[1,64,2] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,64,2}
+  a = f32[64,2] reshape(ds0)
+  ds1 = f32[1,64,2] dynamic-slice(p0, p3, p2, p1), dynamic_slice_sizes={1,64,2}
+  b = f32[64,2] reshape(ds1)
+  ROOT d = f16[64,64] dot(a, b),
+           lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+                             se::CudaComputeCapability::AMPERE, 0})
+                  .Run(module.get())
+                  .value());
+  // TODO(b/339810582): Don't duplicate scalar parameters to dot fusions,
+  // because they are never tiled differently.
+  // TODO(b/339814210): Don't count scalar parameters towards dot fusion
+  // parameter limit.
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(),
+                            m::Parameter(), m::Parameter(), m::Parameter(),
+                            m::Parameter(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionTest, DoNotFuseDynamicSliceOfNonMajorFragments) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  dot_lhs = f32[2,4]{1,0} parameter(0)
+  dynamic_slice_input = f32[4,5,2]{2,1,0} parameter(1)
+  c0 = s32[] constant(0)
+  c2 = s32[] constant(2)
+  dynamic_slice = f32[4,1,2]{2,1,0} dynamic-slice(dynamic_slice_input, c0, c2, c0),
+                  dynamic_slice_sizes={4,1,2}
+  reshape = f32[4,2]{1,0} reshape(dynamic_slice)
+  ROOT dot = f32[4,4]{1,0} dot(dot_lhs, reshape),
+             lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})"));
+  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+  // FusionDecision "Unsupported dynamic slice on non-major-most dimension."
+  EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
+}
+
+TEST_F(GemmFusionTest, CanFuseDynamicSliceOfContractingDimIfItIsMajor) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  dot_lhs = f32[2,4]{1,0} parameter(0)
+  dynamic_slice_input = f32[5,5]{1,0} parameter(1)
+  start_index0 = s32[] constant(2)
+  start_index1 = s32[] constant(0)
+  dynamic_slice = f32[2,5]{1,0} dynamic-slice(dynamic_slice_input, start_index0, start_index1),
+                  dynamic_slice_sizes={2,5}
+  ROOT d = f32[4,5]{1,0} dot(dot_lhs, dynamic_slice),
+           lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+                             se::CudaComputeCapability::AMPERE, 0})
+                  .Run(module.get())
+                  .value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
+                                    m::Constant(), m::Constant()))));
+}
+
+TEST_F(GemmFusionTest, SliceToDegenerateIsSkipped) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p = f32[3] parameter(0)
+  s = f32[1] slice(p), slice={[2:3]}
+  r = f32[] reshape(s)
+  b = f32[3,3] broadcast(r), dimensions={}
+  ROOT d = f32[3,3] dot(b, b),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)"));
+  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+
+  ASSERT_TRUE(GemmFusion(cc).Run(module.get()).value());
+
+  // Slice is not fused.
+  MatchHloModule(*module, R"(
+; CHECK-NOT: slice
+; CHECK: ENTRY
+; CHECK: slice
+)");
+}
+
+TEST_F(GemmFusionTest, MultipleUsesAreHandled) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  c = f32[] constant(1)
+  b = f32[6,8] broadcast(c), dimensions={}
+  p0 = f32[6,8] parameter(0)
+  a1 = f32[6,8] add(p0, b)
+  e = f32[6,8] exponential(a1)
+  a2 = f32[6,8] add(e, b)
+  d = f32[6,8] divide(b, a2)
+  p2 = f16[8,6] parameter(1)
+  cv = f32[8,6] convert(p2)
+  ROOT r = f32[6,6] dot(d, cv),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+  EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionTest, BinaryElementwiseOfBroadcastIsFused) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p2 = f32[3072] parameter(2)
+  b = f32[8192,3072] broadcast(p2), dimensions={1}
+  p0 = f16[8192,3072] parameter(0)
+  p0c = f32[8192,3072] convert(p0)
+  a = f32[8192,3072] add(p0c, b)
+  p1 = f32[3072,768] parameter(1)
+  ROOT r = f32[8192,768] dot(a, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+  EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionTest, BinaryElementwiseOfUnsupportedBroadcastIsNotFused) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p2 = f32[768] parameter(2)
+  b = f32[8192,768,4] broadcast(p2), dimensions={1}
+  s = f32[8192,3072] bitcast(b)
+  p0 = f16[8192,3072] parameter(0)
+  p0c = f32[8192,3072] convert(p0)
+  a = f32[8192,3072] add(p0c, s)
+  p1 = f32[3072,768] parameter(1)
+  ROOT r = f32[8192,768] dot(a, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+  EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
+}
+
+class GemmFusionLevel2Test : public GemmFusionTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = GemmFusionTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_triton_fusion_level(2);
+    return debug_options;
+  }
+};
+
+TEST_F(GemmFusionTest, ConcatenationDivisibleBy64IsFused) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = bf16[8192,1]{1,0} parameter(0)
+  p1 = bf16[2752,8192]{1,0} parameter(1)
+  p2 = bf16[2752,8192]{1,0} parameter(2)
+  concat = bf16[5504,8192]{1,0} concatenate(p1, p2), dimensions={0}
+  bitcast = bf16[8192,5504]{0,1} bitcast(concat)
+  ROOT r = f32[1,5504]{1,0} dot(p0, bitcast),
+    lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+  EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionLevel2Test, ReshapeToScalarIsHandled) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = s8[5,3] parameter(0)
+  c = f16[5,3] convert(p0)
+  p1 = f16[1] parameter(1)
+  r = f16[] reshape(p1)
+  b = f16[5,7] broadcast(r)
+  ROOT d = f16[3,7] dot(c, b),
+    lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionLevel2Test, DoNotFuseIncompatibleDimensionSplits) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p1 = s8[5,7,2,3]{3,2,1,0} parameter(1)
+  t1 = s8[7,5,2,3]{3,2,1,0} transpose(p1), dimensions={1,0,2,3}
+  r1 = s8[7,30]{1,0} reshape(t1)
+  cvt = f16[7,30]{1,0} convert(r1)
+  p2 = f16[2,7,5,3]{3,2,1,0} parameter(2)
+  t2 = f16[7,2,5,3]{3,2,1,0} transpose(p2), dimensions={1,0,2,3}
+  r2 = f16[7,30]{1,0} reshape(t2)
+  a = f16[7,30]{1,0} add(cvt, r2)
+  p0 = f16[7,79]{1,0} parameter(0)
+  ROOT dot = f16[30,79]{1,0} dot(a, p0),
+    lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Transpose(), m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmFusionLevel2Test, DoNotFuseTooManyParameters) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  tmp_0 = f32[] constant(1)
+  tmp_1 = f32[3,49]{1,0} broadcast(tmp_0), dimensions={}
+  tmp_2 = f32[3,49]{1,0} parameter(6)
+  tmp_3 = f32[] constant(0)
+  tmp_4 = f32[3,49]{1,0} broadcast(tmp_3), dimensions={}
+  tmp_5 = pred[3,49]{1,0} compare(tmp_2, tmp_4), direction=GT
+  tmp_6 = f32[3,49]{1,0} convert(tmp_5)
+  tmp_7 = f32[3,49]{1,0} subtract(tmp_1, tmp_6)
+  tmp_8 = s32[] parameter(13)
+  tmp_9 = f32[] convert(tmp_8)
+  tmp_10 = f32[] maximum(tmp_9, tmp_0)
+  tmp_11 = f32[] divide(tmp_3, tmp_10)
+  tmp_12 = f32[3,49]{1,0} broadcast(tmp_11), dimensions={}
+  tmp_13 = pred[3,49]{1,0} parameter(7)
+  tmp_14 = pred[3,49]{1,0} parameter(10)
+  tmp_15 = pred[3,49]{1,0} and(tmp_13, tmp_14)
+  tmp_16 = f32[3,49]{1,0} convert(tmp_15)
+  tmp_17 = f32[3,49]{1,0} multiply(tmp_12, tmp_16)
+  tmp_18 = f32[3,49]{1,0} negate(tmp_17)
+  tmp_19 = f32[3,49]{1,0} multiply(tmp_7, tmp_18)
+  tmp_20 = f32[3,49]{1,0} parameter(19)
+  tmp_21 = f32[3,49]{1,0} subtract(tmp_1, tmp_20)
+  tmp_22 = f32[3,49]{1,0} divide(tmp_19, tmp_21)
+  tmp_23 = f32[3,49]{1,0} negate(tmp_22)
+  tmp_24 = f32[3,49]{1,0} negate(tmp_6)
+  tmp_25 = f32[3,49]{1,0} multiply(tmp_24, tmp_17)
+  tmp_26 = f32[3,49]{1,0} divide(tmp_25, tmp_20)
+  tmp_27 = f32[3,49]{1,0} add(tmp_23, tmp_26)
+  tmp_28 = f32[3,49]{1,0} parameter(18)
+  tmp_29 = f32[3,49]{1,0} multiply(tmp_27, tmp_28)
+  tmp_30 = f32[3,49]{1,0} parameter(17)
+  tmp_31 = f32[3,49]{1,0} multiply(tmp_29, tmp_30)
+  tmp_32 = f32[3,49]{1,0} parameter(16)
+  tmp_33 = f32[3,49]{1,0} multiply(tmp_31, tmp_32)
+  tmp_34 = f32[3,49]{1,0} parameter(15)
+  tmp_35 = f32[3,49]{1,0} add(tmp_33, tmp_34)
+  tmp_36 = f32[3,49]{1,0} parameter(14)
+  tmp_37 = f32[3,49]{1,0} add(tmp_35, tmp_36)
+  tmp_38 = f32[1,1]{1,0} constant({ {0} })
+  tmp_39 = f32[1,1]{1,0} broadcast(tmp_38), dimensions={0,1}
+  tmp_40 = f32[] reshape(tmp_39)
+  tmp_41 = f32[3,32]{1,0} broadcast(tmp_40), dimensions={}
+  tmp_42 = u32[48]{0} parameter(11)
+  tmp_43 = u32[48]{0} parameter(5)
+  tmp_44 = u32[96]{0} concatenate(tmp_42, tmp_43), dimensions={0}
+  tmp_45 = u32[3,32]{1,0} reshape(tmp_44)
+  tmp_46 = u32[96]{0} reshape(tmp_45)
+  tmp_47 = u32[] constant(1)
+  tmp_48 = u32[3,32]{1,0} broadcast(tmp_47), dimensions={}
+  tmp_49 = u32[96]{0} reshape(tmp_48)
+  tmp_50 = u32[96]{0} shift-right-logical(tmp_46, tmp_49)
+  tmp_51 = u32[3,32]{1,0} reshape(tmp_50)
+  tmp_52 = u32[3,32]{1,0} or(tmp_51, tmp_48)
+  tmp_53 = f32[3,32]{1,0} bitcast-convert(tmp_52)
+  tmp_54 = f32[3,32]{1,0} broadcast(tmp_0), dimensions={}
+  tmp_55 = f32[3,32]{1,0} subtract(tmp_53, tmp_54)
+  tmp_56 = f32[1,1]{1,0} constant({ {1} })
+  tmp_57 = f32[1,1]{1,0} broadcast(tmp_56), dimensions={0,1}
+  tmp_58 = f32[] reshape(tmp_57)
+  tmp_59 = f32[3,32]{1,0} broadcast(tmp_58), dimensions={}
+  tmp_60 = f32[3,32]{1,0} multiply(tmp_55, tmp_59)
+  tmp_61 = f32[3,32]{1,0} add(tmp_60, tmp_41)
+  tmp_62 = f32[3,32]{1,0} maximum(tmp_41, tmp_61)
+  tmp_63 = f32[3,32]{1,0} broadcast(tmp_3), dimensions={}
+  tmp_64 = pred[3,32]{1,0} compare(tmp_62, tmp_63), direction=LT
+  tmp_65 = f32[3,32]{1,0} convert(tmp_64)
+  tmp_66 = f32[3,49]{1,0} parameter(9)
+  tmp_67 = f32[49]{0} parameter(4)
+  tmp_68 = f32[3,49]{1,0} broadcast(tmp_67), dimensions={1}
+  tmp_69 = f32[3,49]{1,0} add(tmp_66, tmp_68)
+  tmp_70 = f32[1,49]{1,0} parameter(12)
+  tmp_71 = f32[1,49]{1,0} broadcast(tmp_0), dimensions={}
+  tmp_72 = f32[1,49]{1,0} divide(tmp_70, tmp_71)
+  tmp_73 = f32[1,49]{1,0} broadcast(tmp_72), dimensions={0,1}
+  tmp_74 = f32[49]{0} reshape(tmp_73)
+  tmp_75 = f32[3,49]{1,0} broadcast(tmp_74), dimensions={1}
+  tmp_76 = f32[3,49]{1,0} subtract(tmp_69, tmp_75)
+  tmp_77 = f32[1,49]{1,0} parameter(3)
+  tmp_78 = f32[1,49]{1,0} parameter(8)
+  tmp_79 = f32[1,49]{1,0} divide(tmp_78, tmp_71)
+  tmp_80 = f32[1,49]{1,0} multiply(tmp_72, tmp_72)
+  tmp_81 = f32[1,49]{1,0} subtract(tmp_79, tmp_80)
+  tmp_82 = f32[1,49]{1,0} add(tmp_81, tmp_71)
+  tmp_83 = f32[1,49]{1,0} rsqrt(tmp_82)
+  tmp_84 = f32[1,49]{1,0} multiply(tmp_77, tmp_83)
+  tmp_85 = f32[1,49]{1,0} broadcast(tmp_84), dimensions={0,1}
+  tmp_86 = f32[49]{0} reshape(tmp_85)
+  tmp_87 = f32[3,49]{1,0} broadcast(tmp_86), dimensions={1}
+  tmp_88 = f32[3,49]{1,0} multiply(tmp_76, tmp_87)
+  tmp_89 = f32[1,49]{1,0} parameter(2)
+  tmp_90 = f32[1,49]{1,0} broadcast(tmp_89), dimensions={0,1}
+  tmp_91 = f32[49]{0} reshape(tmp_90)
+  tmp_92 = f32[3,49]{1,0} broadcast(tmp_91), dimensions={1}
+  tmp_93 = f32[3,49]{1,0} add(tmp_88, tmp_92)
+  tmp_94 = f32[49,32]{1,0} parameter(1)
+  tmp_95 = f32[3,32]{1,0} dot(tmp_93, tmp_94), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  tmp_96 = f32[32]{0} parameter(0)
+  tmp_97 = f32[3,32]{1,0} broadcast(tmp_96), dimensions={1}
+  tmp_98 = f32[3,32]{1,0} add(tmp_95, tmp_97)
+  tmp_99 = f32[3,32]{1,0} multiply(tmp_65, tmp_98)
+  tmp_100 = f32[3,32]{1,0} divide(tmp_99, tmp_63)
+  tmp_101 = f32[3,32]{1,0} maximum(tmp_100, tmp_63)
+  ROOT tmp_102 = f32[49,32]{1,0} dot(tmp_37, tmp_101), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
+            HloOpcode::kFusion);
+  EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(),
+            HloInstruction::FusionKind::kCustom);
+  EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(),
+            TritonFusionAnalysis::kMaxParameterPerDotOperand * 2);
+}
+
+TEST_F(GemmFusionLevel2Test,
+       DoNotFuseTooManyParametersWhenAnInstructionWouldAddMultipleParameters) {
+  static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4,
+                "We have to update this test.");
+  // If we fuse the select, it adds 2 additional parameters at once (not 3,
+  // because the select instruction itself is removed from the parameters).
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  a = f32[3,49]{1,0} parameter(0)
+  b = f32[3,49]{1,0} parameter(1)
+  c = pred[3,49]{1,0} parameter(2)
+  d = f32[3,49]{1,0} parameter(3)
+  e = f32[3,49]{1,0} parameter(4)
+  add0 = f32[3,49]{1,0} add(a, b)
+  select = f32[3,49]{1,0} select(c, d, e)
+  add1 = f32[3,49]{1,0} add(add0, select)
+  f = f32[3,32]{1,0} parameter(5)
+  ROOT tmp_102 = f32[49,32]{1,0} dot(add1, f), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})"));
+
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
+            HloOpcode::kFusion);
+  EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(),
+            HloInstruction::FusionKind::kCustom);
+  EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(),
+            TritonFusionAnalysis::kMaxParameterPerDotOperand + 1);
+}
+
+TEST_F(GemmFusionLevel2Test, DoNotFuseTooManyParametersForConcat) {
+  static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4,
+                "We have to update this test.");
+  // The concat shouldn't push the fusion over the allowed parameter limit.
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  a = f32[3,3]{1,0} parameter(0)
+  b = f32[3,3]{1,0} parameter(1)
+  c = f32[3,3]{1,0} parameter(2)
+  d = f32[3,3]{1,0} parameter(3)
+  e = f32[3,3]{1,0} parameter(4)
+  f = f16[3,3]{1,0} parameter(5)
+  concat = f32[15,3]{1,0} concatenate(a, b, c, d, e), dimensions={0}
+  convert = f32[3,3]{1,0} convert(f)
+  ROOT dot = f32[15,3]{1,0} dot(concat, convert), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
+            HloOpcode::kFusion);
+  EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(),
+            HloInstruction::FusionKind::kCustom);
+  EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(),
+            TritonFusionAnalysis::kMaxParameterPerDotOperand + 1);
+}
+
+TEST_F(GemmFusionLevel2Test,
+       InstructionsReachableFromMultipleOperandsAreHandledCorrectly) {
+  static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4,
+                "We have to update this test.");
+  // There was a bug where dead code was generated into some fusions in a
+  // specific edge case. When some instructions were reachable through both the
+  // LHS and the RHS operands, the BFS (breadth-first search) through the LHS
+  // operand "marked" one operation as non-fusible because it would exceed the
+  // limit on fusion parameters per operand. But the BFS through the RHS operand
+  // went through that node and fused some more operands. As a result, the
+  // fusion was not connected and caused errors. This test case checks that such
+  // configurations now generate correct HLO.
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  a = f32[2,4]{1,0} parameter(0)
+  b = f32[2,4]{1,0} parameter(1)
+  c = f32[2,4]{1,0} parameter(2)
+  d = f32[2,4]{1,0} parameter(3)
+  e = f32[2,4]{1,0} parameter(4)
+  add0 = f32[2,4]{1,0} add(a, b)
+  add1 = f32[2,4]{1,0} add(add0, c)
+  add2 = f32[2,4]{1,0} add(add1, d)
+  add3 = f32[2,4]{1,0} add(add2, e)
+  ROOT r = f32[2,2]{1,0} dot(add3, add0),
+           lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  // ~VerifiedHloModule() will verify the module.
+}
+
+TEST_F(GemmFusionLevel2Test, EachScopeIsFusedToASeparateSubgraph) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  a = f32[2,4]{1,0} parameter(0)
+  b = f32[2,4]{1,0} parameter(1)
+  add = f32[2,4]{1,0} add(a, b)
+  ROOT r = f32[2,2]{1,0} dot(add, add),
+           lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+  MatchHloModule(*module, R"(
+CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
+CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1)
+CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]])
+CHECK-DAG: %[[P2:.*]] = f32[2,4]{1,0} parameter(2)
+CHECK-DAG: %[[P3:.*]] = f32[2,4]{1,0} parameter(3)
+CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P2]], f32[2,4]{1,0} %[[P3]])
+CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]])
+CHECK: ENTRY
+CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
+CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1)
+CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0}
+CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]]),
+CHECK-SAME: kind=kCustom
+CHECK-SAME: __triton_gemm
+})");
+}
+
+// The 2 inputs of the add operation are the same and they are iterated the
+// same way, so the same parameter node is reused for them.
+// The reuse happens per "operand fusion", so the add in the LHS scope and the
+// add in the RHS scope still use different nodes.
+TEST_F(GemmFusionLevel2Test, ParamNodesAreReusedIfTheyHaveTheSameIterSpec) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  a = f32[2,4]{1,0} parameter(0)
+  add = f32[2,4]{1,0} add(a, a)
+  ROOT r = f32[2,2]{1,0} dot(add, add),
+           lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+  MatchHloModule(*module, R"(
+CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
+CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P0]])
+CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1)
+CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P1]])
+CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]])
+CHECK: ENTRY
+CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0)
+CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0}
+CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P0]])
+CHECK-SAME: kind=kCustom
+CHECK-SAME: __triton_gemm
+})");
+}
+
+// NEGATE has the same iteration spec at both usages, so the node is reused
+// (implying that P0 is also reused).
+TEST_F(GemmFusionLevel2Test, NonParamNodesAreReusedIfTheyHaveTheSameIterSpec) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  a = f32[4,4]{1,0} parameter(0)
+  b = f32[4,4]{1,0} parameter(1)
+  negate = f32[4,4]{1,0} negate(a)
+  sine = f32[4,4]{1,0} sine(negate)
+  add = f32[4,4]{1,0} add(negate, sine)
+  ROOT r = f32[4,4]{1,0} dot(add, b),
+           lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+  MatchHloModule(*module, R"(
+CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
+CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
+CHECK-DAG: %[[NEGATE:.*]] = f32[4,4]{1,0} negate(f32[4,4]{1,0} %[[P0]])
+CHECK-DAG: %[[SINE:.*]] = f32[4,4]{1,0} sine(f32[4,4]{1,0} %[[NEGATE]])
+CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(f32[4,4]{1,0} %[[NEGATE]], f32[4,4]{1,0} %[[SINE]])
+CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(f32[4,4]{1,0} %[[ADD]], f32[4,4]{1,0} %[[P1]])
+CHECK: ENTRY
+CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
+CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
+CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0}
+CHECK-SAME: fusion(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P1]])
+CHECK-SAME: kind=kCustom
+CHECK-SAME: __triton_gemm
+})");
+}
+
+// The direct read of the input and the transposed read of the input have
+// different iteration specs, so we don't reuse the node.
+TEST_F(GemmFusionLevel2Test, NodesAreNotReusedIfTheyHaveDifferentIterSpecs) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  a = f32[4,4]{1,0} parameter(0)
+  b = f32[4,4]{1,0} parameter(1)
+  tr_a = f32[4,4]{1,0} transpose(a), dimensions={1,0}
+  add = f32[4,4]{1,0} add(a, tr_a)
+  ROOT r = f32[4,4]{1,0} dot(add, b),
+           lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+  MatchHloModule(*module, R"(
+CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
+CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
+CHECK-DAG: %[[P2:.*]] = f32[4,4]{1,0} parameter(2)
+CHECK-DAG: %[[TRANSPOSE:.*]] = f32[4,4]{1,0} transpose(f32[4,4]{1,0} %[[P1]])
+CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[TRANSPOSE]])
+CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(f32[4,4]{1,0} %[[ADD]], f32[4,4]{1,0} %[[P2]])
+CHECK: ENTRY
+CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0)
+CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1)
+CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0}
+CHECK-SAME: fusion(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P1]])
+CHECK-SAME: kind=kCustom
+CHECK-SAME: __triton_gemm
+})");
+}
+
+TEST_F(GemmFusionLevel2Test, OperationsAddingMoreParametersGetMultipleTries) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+e {
+  p0 = f32[2,2] parameter(0)
+  c0 = f32[] constant(12345)
+  b0 = f32[2,2] broadcast(c0), dimensions={}
+  m0 = f32[2,2] multiply(p0, b0)
+  c1 = f32[] constant(34567)
+  b1 = f32[2,2] broadcast(c1), dimensions={}
+  a0 = f32[2,2] add(m0, b1)
+  b3 = f32[2,2,2] broadcast(a0), dimensions={0,1}
+  p2 = f32[2,2,2] parameter(2)
+  m2 = f32[2,2,2] multiply(p2, b3)
+  p1 = f32[2]{0} parameter(1)
+  c2 = f32[] constant(5678)
+  b2 = f32[2] broadcast(c2), dimensions={}
+  a1 = f32[2]{0} add(p1, b2)
+  b4 = f32[2,2,2] broadcast(a1), dimensions={2}
+  m1 = f32[2,2,2] multiply(m2, b4)
+  b = f32[4,2] bitcast(m1)
+  p3 = f16[2,2] parameter(3)
+  p3c = f32[2,2] convert(p3)
+  ROOT r = f32[4,2] dot(b, p3c),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
+                                    m::Parameter(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionLevel2Test, GemmFusionBailsOutPreAmpere) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f32[2,53] parameter(0)
+  p0e = f32[2,53] exponential(p0)
+  p1 = s16[53,2] parameter(1)
+  p1c = f32[53,2] convert(p1)
+  ROOT dot = f32[2,2] dot(p0e, p1c),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+  EXPECT_THAT(
+      GemmFusion(se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0})
+          .Run(module.get()),
+      tsl::testing::StatusIs(
+          absl::StatusCode::kFailedPrecondition,
+          ::testing::HasSubstr("Triton support is only enabled for Ampere GPUs "
+                               "(compute capability 8.0) and up, but got")));
+}
+
+TEST_F(GemmFusionLevel2Test, GemmFusionSucceedsOnNonCudaGpu) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f32[2,53] parameter(0)
+  p0e = f32[2,53] exponential(p0)
+  p1 = s16[53,2] parameter(1)
+  p1c = f32[53,2] convert(p1)
+  ROOT dot = f32[2,2] dot(p0e, p1c),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+  EXPECT_TRUE(GemmFusion(se::RocmComputeCapability{}).Run(module.get()).ok());
+}
+
+TEST_F(GemmFusionLevel2Test, ParameterUsedElementwiseTwiceIsFused) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[2,35] parameter(0)
+  p0n = f32[2,35] negate(p0)
+  p0e = f32[2,35] exponential(p0)
+  a = f32[2,35] add(p0e, p0n)
+  p1 = f16[35,2] parameter(1)
+  p1c = f32[35,2] convert(p1)
+  ROOT dot = f32[2,2] dot(a, p1c),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+                             se::CudaComputeCapability::AMPERE, 0})
+                  .Run(module.get())
+                  .value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch((m::Fusion(m::Parameter(), m::Parameter()))));
+  TF_ASSERT_OK_AND_ASSIGN(
+      const auto analysis,
+      TritonFusionAnalysis::Execute(*module->entry_computation()
+                                         ->root_instruction()
+                                         ->called_computations()[0]));
+  EXPECT_EQ(analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).size(),
+            1);
+  EXPECT_EQ(analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).size(),
+            1);
+}
+
+TEST_F(GemmFusionLevel2Test,
+       ParameterUsedNonElementwiseTwiceIsFusedOnBothPaths) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[4,4] parameter(0)
+  p0t = f32[4,4] transpose(p0), dimensions={1,0}
+  a = f32[4,4] add(p0, p0t)
+  p1 = f16[4,5] parameter(1)
+  p1c = f32[4,5] convert(p1)
+  ROOT dot = f32[4,5] dot(a, p1c),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+                             se::CudaComputeCapability::AMPERE, 0})
+                  .Run(module.get())
+                  .value());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionLevel2Test,
+       ComputationParameterWithMultipleUsersIsNotTrivialToFuse) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = f32[400,400] parameter(0)
+
+  c0 = f16[400,400] convert(p0)
+  p1 = f16[400,400] parameter(1)
+  dot0 = f16[400,400] dot(c0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+  c1 = f16[400,400] convert(p0)
+  p2 = f16[400,400] parameter(2)
+  dot1 = f16[400,400] dot(c1, p2),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+  ROOT a = f16[400,400] add(dot0, dot1)
+})"));
+  EXPECT_FALSE(GemmFusion(se::CudaComputeCapability{
+                              se::CudaComputeCapability::AMPERE, 0})
+                   .Run(module.get())
+                   .value());
+}
+
+TEST_F(GemmFusionLevel2Test, NarrowingConversionIsAlwaysBetterToFuse) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p0 = s8[512,512] parameter(0)
+  c0 = f16[512,512] convert(p0)
+  p1 = f16[512,512] parameter(1)
+  dot0 = f16[512,512] dot(c0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+  n = f16[512,512] negate(c0)
+  ROOT a = f16[512,512] add(dot0, n)
+})"));
+  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+                             se::CudaComputeCapability::AMPERE, 0})
+                  .Run(module.get())
+                  .value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch((m::Add(m::Fusion(m::Parameter(), m::Parameter()),
+                                 m::Negate()))));
+}
+
+TEST_F(GemmFusionLevel2Test, NestedSlicingIsAnalyzedCorrectly) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+triton_gemm_d_computation {
+  p0 = f32[6,24]{1,0} parameter(0)
+  slice1 = f32[5,20]{1,0} slice(p0), slice={[1:6], [3:23]}
+  n1 = f32[5,20]{1,0} negate(slice1)
+  slice2 = f32[3,7]{1,0} slice(n1), slice={[1:4], [13:20]}
+  p1 = f32[7,37]{1,0} parameter(1)
+  ROOT d = f32[3,37]{1,0} dot(slice2, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+ENTRY e {
+  p0 = f32[7,37]{1,0} parameter(0)
+  p1 = f32[6,24]{1,0} parameter(1)
+  ROOT triton_gemm_d = f32[3,37]{1,0} fusion(p1, p0), kind=kCustom,
+    calls=triton_gemm_d_computation
+})"));
+  const HloComputation* computation =
+      module->entry_computation()->root_instruction()->called_computations()[0];
+  TF_ASSERT_OK_AND_ASSIGN(const auto analysis,
+                          TritonFusionAnalysis::Execute(*computation));
+  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS,
+                                 computation->parameter_instruction(0), 0),
+              ElementsAre(FieldsAre(/*stride=*/24, /*count=*/6,
+                                    /*slice_start=*/2, /*sliced_count=*/3,
+                                    /*subfragments=*/ElementsAre(3))));
+  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS,
+                                 computation->parameter_instruction(0), 1),
+              ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24,
+                                    /*slice_start=*/16, /*sliced_count=*/7,
+                                    /*subfragments=*/ElementsAre(7))));
+}
+
+TEST_F(GemmFusionLevel2Test, FusedConcatenationIsAnalyzedCorrectly) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+e {
+  p0 = s8[153,1536] parameter(0)
+  p1 = s8[153,128] parameter(1)
+  p2 = s8[153,256] parameter(2)
+  cat = s8[153,1920] concatenate(p0, p1, p2), dimensions={1}
+  cvt = bf16[153,1920] convert(cat)
+  p3 = bf16[16,153] parameter(3)
+  ROOT d = bf16[16,1920] dot(p3, cvt),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+                             se::CudaComputeCapability::AMPERE, 0})
+                  .Run(module.get())
+                  .value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
+                                    m::Parameter(), m::Parameter()))));
+  const HloComputation* computation =
+      module->entry_computation()->root_instruction()->called_computations()[0];
+  TF_ASSERT_OK_AND_ASSIGN(const auto analysis,
+                          TritonFusionAnalysis::Execute(*computation));
+
+  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
+                                 computation->parameter_instruction(1), 0),
+              ElementsAre(FieldsAre(/*stride=*/1536, /*count=*/153,
+                                    /*slice_start=*/0, /*sliced_count=*/153,
+                                    /*subfragments=*/ElementsAre(153))));
+  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
+                                 computation->parameter_instruction(1), 1),
+              ElementsAre(FieldsAre(/*stride=*/1, /*count=*/1536,
+                                    /*slice_start=*/0, /*sliced_count=*/1536,
+                                    /*subfragments=*/ElementsAre(1536))));
+
+  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
+                                 computation->parameter_instruction(2), 0),
+              ElementsAre(FieldsAre(/*stride=*/128, /*count=*/153,
+                                    /*slice_start=*/0, /*sliced_count=*/153,
+                                    /*subfragments=*/ElementsAre(153))));
+  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
+                                 computation->parameter_instruction(2), 1),
+              ElementsAre(FieldsAre(/*stride=*/1, /*count=*/128,
+                                    /*slice_start=*/-1536, /*sliced_count=*/128,
+                                    /*subfragments=*/ElementsAre(128))));
+
+  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
+                                 computation->parameter_instruction(3), 0),
+              ElementsAre(FieldsAre(/*stride=*/256, /*count=*/153,
+                                    /*slice_start=*/0, /*sliced_count=*/153,
+                                    /*subfragments=*/ElementsAre(153))));
+  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS,
+                                 computation->parameter_instruction(3), 1),
+              ElementsAre(FieldsAre(/*stride=*/1, /*count=*/256,
+                                    /*slice_start=*/-1536 - 128,
+                                    /*sliced_count=*/256,
+                                    /*subfragments=*/ElementsAre(256))));
+}
+
+TEST_F(GemmFusionLevel2Test, IndivisibleConcatenationIsNotFused) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+e {
+  p0 = s8[124,1024] parameter(0)
+  p1 = s8[124,1001] parameter(1)
+  cat = s8[124,2025] concatenate(p0, p1), dimensions={1}
+  cvt = f16[124,2025] convert(cat)
+  p2 = f16[123,124] parameter(2)
+  ROOT d = f16[2025,123] dot(cvt, p2),
+    lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})"));
+  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+                             se::CudaComputeCapability::AMPERE, 0})
+                  .Run(module.get())
+                  .value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch((m::Fusion(m::Concatenate(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionLevel2Test, ConcatenationOfContractingIsNotFused) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+e {
+  p0 = s8[124,1024] parameter(0)
+  p1 = s8[124,1024] parameter(1)
+  cat = s8[124,2048] concatenate(p0, p1), dimensions={1}
+  cvt = f16[124,2048] convert(cat)
+  p2 = f16[123,2048] parameter(2)
+  ROOT d = f16[124,123] dot(cvt, p2),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+})"));
+  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+                             se::CudaComputeCapability::AMPERE, 0})
+                  .Run(module.get())
+                  .value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch((m::Fusion(m::Concatenate(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionLevel2Test, ConcatenationOfBatchIsNotFused) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+e {
+  p0 = s8[124,1024,50] parameter(0)
+  p1 = s8[124,1024,50] parameter(1)
+  cat = s8[124,2048,50] concatenate(p0, p1), dimensions={1}
+  cvt = f16[124,2048,50] convert(cat)
+  p2 = f16[123,2048,50] parameter(2)
+  ROOT d = f16[2048,124,123] dot(cvt, p2),
+    lhs_batch_dims={1}, rhs_batch_dims={1},
+    lhs_contracting_dims={2}, rhs_contracting_dims={2}
+})"));
+  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+                             se::CudaComputeCapability::AMPERE, 0})
+                  .Run(module.get())
+                  .value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch((m::Fusion(m::Concatenate(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionLevel2Test,
+       DifferentConcatenationOfSameParametersIsFusedViaNodeDuplication) {
+  // This means that the same input is passed to the fusion multiple times and
+  // is read differently for each use.
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+e {
+  p0 = s8[128,2] parameter(0)
+  p1 = s8[128,2] parameter(1)
+  cat0 = s8[256,2] concatenate(p0, p1), dimensions={0}
+  cvt0 = f16[256,2] convert(cat0)
+  cat1 = s8[256,2] concatenate(p1, p0), dimensions={0}
+  n1 = s8[256,2] negate(cat1)
+  cvt1 = f16[256,2] convert(n1)
+  a = f16[256,2] add(cvt1, cvt0)
+  p2 = f16[2,18] parameter(2)
+  ROOT d = f16[18,256] dot(p2, a),
+    lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})"));
+
+  EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{
+                             se::CudaComputeCapability::AMPERE, 0})
+                  .Run(module.get())
+                  .value());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(),
+                            m::Parameter(), m::Parameter()))));
+}
+
+TEST_F(GemmFusionTest, CopiesDotMetadataToFusionOp) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[2,18] parameter(0)
+  p1 = f16[256,2] parameter(1)
+  ROOT d = f16[18,256] dot(p0, p1),
+    lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="foo"}
+})")
+                    .value();
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  EXPECT_EQ(
+      module->entry_computation()->root_instruction()->metadata().op_name(),
+      "foo");
+}
+
+// A test fixture class for testing the threshold for small matrices.
+class SmallDotGemmFusionTest : public GemmFusionTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = GemmFusionTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100);
+    return debug_options;
+  }
+};
+
+TEST_F(SmallDotGemmFusionTest, SkipSmallMatrixMultiplicationRewrite) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[2,10] parameter(0)
+  p1 = f16[10,2] parameter(1)
+  ROOT d = f16[10,10] dot(p0, p1),
+    lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})")
+                    .value();
+
+  EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+  MatchHloModule(*module, R"(
+; CHECK-LABEL: ENTRY %e ({{.*}}: f16[2,10], {{.*}}: f16[10,2]) -> f16[10,10] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,10]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f16[10,2]{1,0} parameter(1)
+; CHECK:      ROOT {{.*}} = f16[10,10]{1,0} dot(f16[2,10]{1,0} [[P0]], f16[10,2]{1,0} [[P1]])
+})");
+}
+
+TEST_F(SmallDotGemmFusionTest, LargeMatrixMultiplicationIsRewritten) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+ENTRY e {
+  p0 = f16[2,18] parameter(0)
+  p1 = f16[50,2] parameter(1)
+  ROOT d = f16[18,50] dot(p0, p1),
+    lhs_contracting_dims={0}, rhs_contracting_dims={1}
+})")
+                    .value();
+
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+  MatchHloModule(*module, R"(
+; CHECK-LABEL: ENTRY %e ({{.*}}: f16[2,18], {{.*}}: f16[50,2]) -> f16[18,50] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,18]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f16[50,2]{1,0} parameter(1)
+; CHECK:      ROOT {{.*}} = f16[18,50]{1,0}
+; CHECK:        fusion(f16[2,18]{1,0} [[P0]], f16[50,2]{1,0} [[P1]]),
+; CHECK:        kind=kCustom
+; CHECK:        __triton_gemm
+})");
+}
+
+class SparseDotTest : public GemmFusionTest {};
+
+TEST_F(SparseDotTest, DotWithSparseLhsOperandIsRewritten) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test
+ENTRY main {
+  lhs = f16[2,16] parameter(0)
+  rhs = f16[32,2] parameter(1)
+  meta = u16[2,2] parameter(2)
+  ROOT dot = f32[2,2] dot(lhs, rhs, meta),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4
+})")
+                    .value();
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+
+  MatchHloModule(*module, R"(
+; CHECK-LABEL: ENTRY %main ({{.*}}: f16[2,16], {{.*}}: f16[32,2], {{.*}}: u16[2,2]) -> f32[2,2] {
+; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,16]{1,0} parameter(0)
+; CHECK-NEXT: [[P1:%[^ ]+]] = f16[32,2]{1,0} parameter(1)
+; CHECK-NEXT: [[META:%[^ ]+]] = u16[2,2]{1,0} parameter(2)
+; CHECK:      ROOT {{.*}} = f32[2,2]{1,0}
+; CHECK-SAME:   fusion(f16[2,16]{1,0} [[P0]], f16[32,2]{1,0} [[P1]], u16[2,2]{1,0} [[META]]),
+; CHECK-SAME:   kind=kCustom
+; CHECK-SAME:   __triton_gemm
+})");
+}
+
+TEST_F(SparseDotTest, DotWithSparseRhsOperandIsNotSupported) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test
+ENTRY main {
+  lhs = f16[2,32] parameter(0)
+  rhs = f16[16,2] parameter(1)
+  meta = u16[2,2] parameter(2)
+  ROOT dot = f32[2,2] dot(lhs, rhs, meta),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=R.0@2:4
+})")
+                    .value();
+  auto result = GemmFusion(gpu_version_).Run(module.get());
+  EXPECT_FALSE(result.ok());
+}
+
+TEST_F(SparseDotTest, UnsupportedSparsityType) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test
+ENTRY main {
+  lhs = f16[2,8] parameter(0)
+  rhs = f16[32,2] parameter(1)
+  meta = u16[2,1] parameter(2)
+  ROOT dot = f32[2,2] dot(lhs, rhs, meta),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@1:4
+})")
+                    .value();
+  auto result = GemmFusion(gpu_version_).Run(module.get());
+  EXPECT_FALSE(result.ok());
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
new file mode 100644
index 0000000..a48d4f9
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
@@ -0,0 +1,2370 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+#include <limits>
+#include <memory>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/evaluator/hlo_evaluator.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/literal.h"
+#include "xla/literal_util.h"
+#include "xla/primitive_util.h"
+#include "xla/service/algorithm_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/blas.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/gpu/gpu_blas_lt.h"
+#include "xla/types.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/ml_dtypes.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/protobuf/dnn.pb.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = match;
+
+// Give this instruction a more useful name than "custom-call.42".
+absl::Status SetName(HloModule *module, HloInstruction *gemm) {
+  if (IsCublasLtMatmul(*gemm)) {
+    module->SetAndUniquifyInstrName(gemm, "cublas-lt-matmul");
+    return absl::OkStatus();
+  }
+
+  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                      gemm->backend_config<GpuBackendConfig>());
+  const GemmBackendConfig &config = gpu_config.gemm_backend_config();
+  const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers();
+  bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() ||
+                      !dot_dims.rhs_batch_dimensions().empty();
+
+  module->SetAndUniquifyInstrName(
+      gemm, is_batch_dot ? "cublas-batch-gemm" : "cublas-gemm");
+  return absl::OkStatus();
+}
+
+// Returns whether a given PrimitiveType is supported by cuBLASLt Epilogue
+// Fusion. A table of supported data types can be found in the cuBLASLt
+// documentation: https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul.
+// Note that `Ctype` also describes the output type of the GEMM. Rows with
+// `Non-default epilogue not supported` entries in the last column indicate data
+// types not compatible with Epilogue Fusion.
+bool SupportsEpilogueFusion(PrimitiveType type) {
+  switch (type) {
+    case F8E4M3FN:
+    case F8E5M2:
+    case F16:
+    case BF16:
+    case F32:
+    case F64:
+      return true;
+    default:
+      return false;
+  }
+}
+
+bool IsF8Type(const HloInstruction *instr) {
+  return primitive_util::IsF8Type(instr->shape().element_type());
+}
+
+// Returns a new shape with non-batch dimensions padded to multiples of 16, as
+// required by cuBLASLt FP8 gemms.
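+// For example (hypothetical shapes, for illustration only): with
+// batch_dims={0}, a [3,5,20] shape would be padded to [3,16,32].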
+Shape PadShapeToMultipleOf16(const Shape old_shape,
+                             const absl::Span<const int64_t> batch_dims) {
+  Shape padded_shape = old_shape;
+  for (int i = 0; i < old_shape.rank(); ++i) {
+    if (!absl::c_linear_search(batch_dims, i)) {
+      int64_t padded_dimension =
+          RoundUpTo<int64_t>(old_shape.dimensions(i), 16);
+      padded_shape.set_dimensions(i, padded_dimension);
+    }
+  }
+  return padded_shape;
+}
+
+// Pad the dimensions of the operand to the target shape.
+HloInstruction *PadOperandToTargetShape(const Shape &target,
+                                        HloInstruction *x) {
+  if (ShapeUtil::Equal(target, x->shape()) ||
+      !ShapeUtil::SameElementType(x->shape(), target)) {
+    return x;
+  }
+
+  PaddingConfig padding_config;
+  for (int i = 0; i < x->shape().rank(); ++i) {
+    auto dimension = padding_config.add_dimensions();
+    dimension->set_edge_padding_low(0);
+    dimension->set_edge_padding_high(target.dimensions(i) -
+                                     x->shape().dimensions(i));
+    dimension->set_interior_padding(0);
+  }
+
+  HloInstruction *zero = x->AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::Zero(x->shape().element_type())));
+  return x->AddInstruction(
+      HloInstruction::CreatePad(target, x, zero, padding_config));
+}
+
+// Pad the non-batch dimensions of the operands to multiples of 16 as required
+// by cuBLASLt FP8 gemms.
+HloInstruction *PadOperandToMultipleOf16(absl::Span<const int64_t> batch_dims,
+                                         HloInstruction *x) {
+  Shape padded_shape = PadShapeToMultipleOf16(x->shape(), batch_dims);
+  return PadOperandToTargetShape(padded_shape, x);
+}
+
+// Calculates the reciprocal of scalar when invert is true and converts to FP32.
+absl::StatusOr<HloInstruction *> InvertAndConvertScalar(HloInstruction *scalar,
+                                                        bool invert) {
+  DCHECK(ShapeUtil::IsScalar(scalar->shape()));
+
+  if (invert) {
+    Literal one_literal = LiteralUtil::One(scalar->shape().element_type());
+    HloInstruction *one = scalar->parent()->AddInstruction(
+        HloInstruction::CreateConstant(one_literal.Clone()));
+    TF_ASSIGN_OR_RETURN(scalar, MakeBinaryHlo(HloOpcode::kDivide, one, scalar,
+                                              &scalar->metadata()));
+  }
+  if (scalar->shape().element_type() != F32) {
+    scalar = MakeConvertToHlo(scalar, F32, &scalar->metadata());
+  }
+
+  return scalar;
+}
+
+// A path of instructions by traversing downwards through users, as (op,
+// operand_index) pairs. operand_index is the index to get to the previous
+// element in the path. I.e.,
+// path[i].first->operand(path[i].second) == path[i-1].first
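+// For example (hypothetical HLO): for the chain x -> convert -> multiply,
+// where convert->operand(0) == x and multiply->operand(0) == convert, the
+// path from x to multiply is {{x, -1}, {convert, 0}, {multiply, 0}} (the
+// operand index of the first element is unused; see below).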
+using InstrPath = std::vector<std::pair<HloInstruction *, int>>;
+
+// From 'instr', recursively traverses operands until an FP8 instruction is
+// encountered. Only unary ops and a few types of non-unary ops are traversed.
+// If an FP8 instruction is found, returns the path from the FP8 instruction to
+// 'instr'. Returns nullopt when no FP8 instruction is reached.
+//
+// The intent is, given 'instr' is the operand of a dot, to find a sequence of
+// instructions that can potentially be fused into a cuBLAS LT FP8 gemm.
+std::optional<InstrPath> FindF8SubgraphRecursive(
+    HloInstruction *instr, absl::flat_hash_set<int> &visited_instrs) {
+  // Avoid visiting the same instruction more than once.
+  if (!visited_instrs.emplace(instr->unique_id()).second) {
+    return std::nullopt;
+  }
+  if (IsF8Type(instr)) {
+    // The initial operand index is meaningless. Arbitrarily use -1.
+    return InstrPath{{instr, -1}};
+  }
+  if (instr->operand_count() == 1 || instr->opcode() == HloOpcode::kDivide ||
+      instr->opcode() == HloOpcode::kDynamicSlice ||
+      instr->opcode() == HloOpcode::kPad) {
+    std::optional<InstrPath> subgraph =
+        FindF8SubgraphRecursive(instr->mutable_operand(0), visited_instrs);
+    if (subgraph) {
+      subgraph->emplace_back(std::make_pair(instr, 0));
+    }
+    return subgraph;
+  } else if (instr->opcode() == HloOpcode::kMultiply ||
+             instr->opcode() == HloOpcode::kSelect) {
+    for (int k = 0; k < 2; ++k) {
+      // Iterate over operands 0 and 1 for multiply and operands 1 and 2 for
+      // select.
+      int operand_idx = k + (instr->opcode() == HloOpcode::kSelect);
+      std::optional<InstrPath> subgraph = FindF8SubgraphRecursive(
+          instr->mutable_operand(operand_idx), visited_instrs);
+      if (subgraph) {
+        subgraph->emplace_back(std::make_pair(instr, operand_idx));
+        return subgraph;
+      }
+    }
+  }
+  return std::nullopt;
+}
+
+// Contains information on a parameter (either the LHS or RHS) for a
+// gemm that can be potentially pattern-matched into an FP8 cublasLT gemm.
+struct MatchedFp8Param {
+  // The FP8 input to the gemm.
+  HloInstruction *fp8_input = nullptr;
+  // If non-null, the scale for 'x' (the FP8 input).
+  HloInstruction *scale = nullptr;
+  // Whether the scale, if present, multiplies or divides 'x'.
+  bool mult_scale = false;
+  // A list of instructions from 'x' to the dot instruction that are
+  // commutative with dequantization. Such instructions can be moved before
+  // the FP8 gemm.
+  InstrPath commutative_ops;
+};
+
+// Given an operand of a dot, `instr`, returns a MatchedFp8Param if this operand
+// allows rewriting the dot in an FP8 cublasLT custom call, optionally with
+// scaling. In particular, returns a MatchedFp8Param if either 'instr' is FP8
+// or there is a path from an FP8 instruction 'fp8_input' to 'instr'
+// consisting of the following.
+// 1. A convert to a wider type.
+// 2. Optionally, a multiplication/division by a scalar, representing the scale.
+//    If present, the scalar scale is returned as 'scale' and 'mult_scale'
+//    is set to true or false depending on whether there is a multiplication or
+//    a division.
+// 3. A possibly-empty set of ops commutative with steps (1) and (2), meaning
+//    they can be safely moved before step (1). Such ops are returned in
+//    'commutative_ops'.
+// Steps (1) and (2) together are a dequantization, and can be fused into a
+// cublas LT matmul. Step (3) can be moved before the cublas LT matmul.
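+//
+// Illustrative example (hypothetical HLO, not from a real module):
+//   x = f8e4m3fn[16,32] parameter(0)
+//   c = f32[16,32] convert(x)                        // step (1)
+//   s = f32[16,32] multiply(c, broadcast(scale))     // step (2)
+//   t = f32[32,16] transpose(s), dimensions={1,0}    // step (3)
+// Matching 't' yields fp8_input=x, scale=scale, mult_scale=true and
+// commutative_ops containing the transpose.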
+std::optional<MatchedFp8Param> MatchFp8Param(HloInstruction *instr) {
+  absl::flat_hash_set<int> visited_instrs;
+  std::optional<InstrPath> maybe_subgraph =
+      FindF8SubgraphRecursive(instr, visited_instrs);
+  if (!maybe_subgraph) {
+    return std::nullopt;
+  }
+  InstrPath &subgraph = maybe_subgraph.value();
+
+  MatchedFp8Param param;
+
+  // Directly operating on an FP8 operand.
+  if (subgraph.size() == 1) {
+    CHECK(IsF8Type(subgraph[0].first));
+    param.fp8_input = subgraph[0].first;
+    return param;
+  }
+
+  int num_dequant_ops;
+  // When not operating directly on an FP8 operand, the second and
+  // third instructions in the subgraph can describe a dequantization, i.e. a
+  // convert instruction followed by a multiply/divide instruction.
+  if (subgraph.size() > 2 &&
+      Match(subgraph[2].first,
+            m::MultiplyAnyOrder(m::Convert(m::Op(&param.fp8_input)),
+                                m::Broadcast(m::Op(&param.scale))))) {
+    param.mult_scale = true;
+    num_dequant_ops = 2;
+  } else if (subgraph.size() > 2 &&
+             Match(subgraph[2].first,
+                   m::Divide(m::Convert(m::Op(&param.fp8_input)),
+                             m::Broadcast(m::Op(&param.scale))))) {
+    param.mult_scale = false;
+    num_dequant_ops = 2;
+  } else if (subgraph.size() > 1 &&
+             Match(subgraph[1].first, m::Convert(m::Op(&param.fp8_input)))) {
+    // We have a convert from FP8 without a scale in this case.
+    param.scale = nullptr;
+    num_dequant_ops = 1;
+  } else {
+    VLOG(1) << "Possible intended FP8 GEMM operating on "
+            << instr->ToShortString() << " not rewritten into FP8 Custom Call.";
+    return std::nullopt;
+  }
+
+  auto preserves_element_type = [](const HloInstruction *instr) -> bool {
+    return ShapeUtil::SameElementType(instr->shape(),
+                                      instr->operand(0)->shape());
+  };
+  auto use_spmd_partitioning = [](const HloInstruction *instr) -> bool {
+    return instr->GetModule()->config().use_spmd_partitioning();
+  };
+
+  // Skip the initial FP8 instruction and the dequantization instructions.
+  int start = 1 + num_dequant_ops;
+  for (int i = start; i < subgraph.size(); ++i) {
+    // The remaining instructions must be commutative with dequantization.
+    // Bitcast, broadcast, copy, dynamic-slice, pad, reshape, select, slice,
+    // transpose, all-gather, all-to-all and collective-permute instructions are
+    // supported. Specifically, the all-gather, all-to-all and
+    // collective-permute operations are permitted only in SPMD cases since the
+    // optimization cannot be guaranteed to be applied to all replicas in the
+    // MPMD scenario.
+    if (!Match(
+            subgraph[i].first,
+            m::AnyOf<HloInstruction>(
+                m::Bitcast().WithPredicate(preserves_element_type),
+                m::Broadcast(), m::Copy(), m::DynamicSlice(), m::Pad(),
+                m::Reshape(), m::Select(), m::Slice(), m::Transpose(),
+                m::AllGather().WithPredicate(use_spmd_partitioning),
+                m::AllToAll().WithPredicate(use_spmd_partitioning),
+                m::CollectivePermute().WithPredicate(use_spmd_partitioning)))) {
+      VLOG(1) << "Possible intended FP8 GEMM operating on "
+              << instr->ToShortString()
+              << " not rewritten into FP8 Custom Call.";
+      return std::nullopt;
+    }
+    // One of the operands of select must be zero for the op to be commutative
+    // with dequantization.
+    if (Match(subgraph[i].first, m::Select()) &&
+        !Match(subgraph[i].first->operand(subgraph[i].second == 2 ? 1 : 2),
+               m::Broadcast(m::ConstantScalar(0)))) {
+      VLOG(1) << "Possible intended FP8 GEMM operating on "
+              << instr->ToShortString()
+              << " not rewritten into FP8 Custom Call. Select requires a zero "
+                 "operand to be exchanged with dequantization.";
+      return std::nullopt;
+    }
+  }
+
+  param.commutative_ops = {subgraph.begin() + start, subgraph.end()};
+  return param;
+}
+
+// Transposes a matrix by swapping the contracting and non-contracting
+// dimension. There must be only one contracting and only one non-contracting
+// dimension. Keeps the layout the same.
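+// For example (hypothetical): with batch_dims={0} and contracting_dim=2, a
+// [B,M,K] operand gets permutation {0,2,1} and the result has shape [B,K,M]
+// with the original layout.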
+HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim,
+                                absl::Span<const int64_t> batch_dims) {
+  // Identify the dimensional order which describes a transpose of the
+  // contracting and non-contracting dimensions of the GEMM.
+  std::vector<int64_t> permutation(instr->shape().dimensions_size(), -1);
+  // Discard the batch dimensions.
+  for (int64_t batch_dim : batch_dims) {
+    permutation[batch_dim] = batch_dim;
+  }
+  // Identify the non-contracting dimension.
+  int non_contracting_dim;
+  for (int i = 0; i < instr->shape().dimensions_size(); ++i) {
+    if (permutation[i] == -1 && contracting_dim != i) {
+      non_contracting_dim = i;
+    }
+  }
+  permutation[non_contracting_dim] = contracting_dim;
+  permutation[contracting_dim] = non_contracting_dim;
+
+  Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape());
+  *new_shape.mutable_layout() = instr->shape().layout();
+  return instr->AddInstruction(
+      HloInstruction::CreateTranspose(new_shape, instr, permutation));
+}
+
+// If the bias is a sequence of ops that depend only on broadcasts of
+// constants, materialize the bias if it's small.
+//
+// Normally the constant-folding pass would materialize the bias if it is
+// calculated entirely from constants. But if the bias is a broadcast of a
+// constant, constant-folding won't expand the broadcast, on the theory that
+// folding broadcasts of constants causes us to consume more memory and can
+// actually make things slower (because any op which reads the constant has
+// to read more memory).
+//
+// OTOH in our case, we don't want to run an op that just broadcasts a
+// constant so we can fuse it into this gemm. That would defeat the whole
+// purpose of this fusion, which is to launch fewer kernels.  So if we can,
+// we expand out this constant ourselves.
+//
+// TODO(b/192499646): Even better would be to use cublasLT to fuse the
+// broadcasted bias, if it supports that fusion efficiently.
+HloInstruction *MaybeConstantFoldBias(HloInstruction *bias) {
+  // This limit was not chosen carefully.
+  constexpr int kMaxMaterializeBiasBytes = 8 * 1024 * 1024;
+
+  // Don't fold broadcasts of scalars -- algsimp will just collapse it again.
+  auto is_nonscalar = [](const HloInstruction *instr) {
+    return !ShapeUtil::IsEffectiveScalar(instr->shape());
+  };
+
+  // For now, only fold broadcast(constant) or
+  // reshape/transpose/bitcast(broadcast(constant)). This lets us avoid the
+  // complexity in the constant-folding pass about what is and isn't legal to
+  // fold.
+  auto broadcast_of_nonscalar =
+      m::Broadcast(m::Constant().WithPredicate(is_nonscalar));
+
+  if (ShapeUtil::ByteSizeOf(bias->shape()) <= kMaxMaterializeBiasBytes &&
+      (Match(bias, broadcast_of_nonscalar) ||
+       Match(bias, m::Reshape(broadcast_of_nonscalar)) ||
+       Match(bias, m::Transpose(broadcast_of_nonscalar)) ||
+       Match(bias, m::Bitcast(broadcast_of_nonscalar)))) {
+    HloEvaluator evaluator(/*max_loop_iterations=*/0);
+    Literal result;
+    if (evaluator.TryEvaluate(
+            bias, &result,
+            /*recursively_evaluate_nonconstant_operands=*/true)) {
+      return bias->parent()->AddInstruction(
+          HloInstruction::CreateConstant(std::move(result)));
+    }
+  }
+
+  return bias;
+}
+
+auto Gemm(HloInstruction **instr) {
+  return m::CustomCall(instr, {kGemmCallTarget});
+}
+
+auto CublasLtMatmul(HloInstruction **instr) {
+  return m::CustomCall(instr, {kCublasLtMatmulCallTarget});
+}
+
+auto CublasLtMatmulF8(HloInstruction **instr) {
+  return m::CustomCall(instr, {kCublasLtMatmulF8CallTarget});
+}
+
+auto CublasLtMatmulMaybeF8(HloInstruction **instr) {
+  return m::CustomCall(
+      instr, {kCublasLtMatmulCallTarget, kCublasLtMatmulF8CallTarget});
+}
+
+auto GemmOrCublasLtMatmul(HloInstruction **instr) {
+  return m::CustomCall(instr, {kGemmCallTarget, kCublasLtMatmulCallTarget});
+}
+
+auto GemmOrCublasLtMatmulMaybeF8(HloInstruction **instr) {
+  return m::CustomCall(instr, {kGemmCallTarget, kCublasLtMatmulCallTarget,
+                               kCublasLtMatmulF8CallTarget});
+}
+
+auto BcastConstScalar(HloInstruction **instr, double value) {
+  return m::Broadcast(instr, m::ConstantScalar(value));
+}
+
+auto BcastConstScalar(double value) { return BcastConstScalar(nullptr, value); }
+
+auto BcastConstScalarNear(double value) {
+  return m::Broadcast(m::ConstantScalar().WithPredicate(
+      [expected = value](const HloInstruction *instr) {
+        // Not a very robust floating-point comparison, but good enough for our
+        // purposes.
+        std::optional<double> actual =
+            xla::Cast<const HloConstantInstruction>(instr)
+                ->literal()
+                .GetAsDouble({});
+        if (!actual.has_value()) return false;
+        double epsilon;
+        switch (instr->shape().element_type()) {
+          case F16:
+            epsilon = 128 * std::numeric_limits<Eigen::half>::epsilon();
+            break;
+          case BF16:
+            epsilon = 128 * std::numeric_limits<bfloat16>::epsilon();
+            break;
+          case F32:
+            epsilon = 128 * std::numeric_limits<float>::epsilon();
+            break;
+          case F64:
+            epsilon = 128 * std::numeric_limits<double>::epsilon();
+            break;
+          default:
+            return false;
+        }
+        return abs(*actual - expected) < (abs(*actual + expected) * epsilon);
+      }));
+}
+
+template <typename Pattern>
+auto OptionalSlice(HloInstruction **optional_slice, Pattern pattern) {
+  return m::AnyOf<HloInstruction>(m::Slice(optional_slice, pattern),
+                                  std::move(pattern));
+}
+
+template <typename Pattern>
+auto OptionalConvert(HloInstruction **optional_convert, Pattern pattern) {
+  return m::AnyOf<HloInstruction>(m::Convert(optional_convert, pattern),
+                                  std::move(pattern));
+}
+
+template <typename Pattern>
+auto OptionalBitcast(HloInstruction **optional_bitcast, Pattern pattern) {
+  return m::AnyOf<HloInstruction>(m::Bitcast(optional_bitcast, pattern),
+                                  std::move(pattern));
+}
+
+// The rewriting proceeds in a bottom-up way:
+//
+// (kDot A B) is rewritten into a (kCustomCall:gemm A B)
+//
+// (kMultiply (kCustomCall:gemm A B) C) is rewritten by folding C (provided
+// it is a constant) into the alpha parameter of the custom call.
+//
+// (kAdd (kCustomCall:gemm A B) C) is rewritten into (kCustomCall:gemm A B C),
+// where the "beta" parameter is set to 1 (provided it was zero before,
+// and provided C has no other users).
+// We then guide the buffer assignment to alias the buffer of the custom call
+// and C.
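+//
+// Illustrative sketch (hypothetical HLO, names are not from a real module):
+//   d = f32[32,64] dot(a, b)
+//     --> d = f32[32,64] custom-call(a, b)           // target kGemmCallTarget
+//   r = f32[32,64] add(d, c)
+//     --> r = f32[32,64] custom-call(a, b, c)        // with beta set to 1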
+//
+// For scaled FP8 GEMMs on Hopper systems, the following steps mentioned in
+// RFC #22 (https://github.com/openxla/xla/discussions/22) are elided and
+// rewritten into a Custom Call:
+//
+// 1. Cast each input from FP8 to a wider type such as FP16 or FP32.
+// 2. Unscale each input by multiplying each input by the corresponding input
+// scale.
+// 3. Evaluate the matrix multiplication on the scaled inputs.
+// 4. Compute the maximum of the absolute values in the result of the GEMM
+// (DAmax).
+// 5. Scale the output by dividing the output by the output scale.
+// 6. Cast the output back to FP8. Since saturation should be done on
+// overflow, this is represented by a Clamp instruction followed by a Convert
+// instruction.
+
+// Steps 1 through 3 can be elided independently of the remainder. Steps 5 and
+// 6 are elided only if steps 1 through 3 were successfully transformed. Step
+// 4 requires steps 5 and 6, i.e. the computation of DAmax can be elided only
+// when the output of the GEMM is requested in FP8 format.
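+//
+// Illustrative sketch of the output side, steps 5 and 6 (hypothetical HLO;
+// +/-448 is the finite range of f8e4m3fn):
+//   s = f32[16,16] divide(d, broadcast(d_scale))     // step 5
+//   c = f32[16,16] clamp(broadcast(-448), s, broadcast(448))
+//   o = f8e4m3fn[16,16] convert(c)                   // step 6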
+class GemmRewriterVisitor : public DfsHloRewriteVisitor {
+ public:
+  explicit GemmRewriterVisitor(const se::GpuComputeCapability &gpu_version,
+                               const int32_t toolkit_version,
+                               const GemmRewriterOptions options)
+      : gpu_version_(gpu_version),
+        toolkit_version_(toolkit_version),
+        options_(options) {}
+
+  absl::Status HandleDot(HloInstruction *instr) override {
+    if (!IsMatrixMultiplication(*instr) &&
+        !IsMatrixVectorMultiplication(*instr)) {
+      return absl::OkStatus();
+    }
+    // Sparse dot is not supported.
+    if (Cast<HloDotInstruction>(instr)->sparse_operands()) {
+      return absl::OkStatus();
+    }
+
+    int64_t gemm_rewrite_size_threshold =
+        instr->GetModule()
+            ->config()
+            .debug_options()
+            .xla_gpu_gemm_rewrite_size_threshold();
+    TF_ASSIGN_OR_RETURN(bool is_matmul_tiny,
+                        IsMatrixMultiplicationTooSmallForRewriting(
+                            *instr, gemm_rewrite_size_threshold));
+    if (is_matmul_tiny && IsDotSupportedByClassicalEmitters(*instr)) {
+      return absl::OkStatus();
+    }
+
+    CHECK(!instr->IsRank2Transpose());
+    if (instr->operand(0)->IsRank2Transpose() ||
+        instr->operand(1)->IsRank2Transpose()) {
+      return absl::OkStatus();
+    }
+    // Create a GemmBackendConfig based on the instruction.
+    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
+                        instr->backend_config<GpuBackendConfig>());
+    GemmBackendConfig &gemm_backend_config =
+        *gpu_backend_config.mutable_gemm_backend_config();
+    gemm_backend_config.set_alpha_real(1.0);
+    gemm_backend_config.set_alpha_imag(0.0);
+    gemm_backend_config.set_beta(0.0);
+    *gemm_backend_config.mutable_dot_dimension_numbers() =
+        instr->dot_dimension_numbers();
+    *gemm_backend_config.mutable_precision_config() = instr->precision_config();
+
+    HloInstruction *lhs = instr->mutable_operand(0);
+    HloInstruction *rhs = instr->mutable_operand(1);
+    auto attributes = instr->frontend_attributes().map();
+    gemm_backend_config.set_grad_x(attributes["grad_x"] == "true");
+    gemm_backend_config.set_grad_y(attributes["grad_y"] == "true");
+
+    int64_t lhs_batch_dims_size =
+        instr->dot_dimension_numbers().lhs_batch_dimensions_size();
+    bool is_lhs_vector =
+        lhs->shape().dimensions_size() == lhs_batch_dims_size + 1;
+    bool is_rhs_vector =
+        rhs->shape().dimensions_size() == lhs_batch_dims_size + 1;
+    int64_t lhs_stride =
+        is_lhs_vector ? lhs->shape().dimensions(lhs_batch_dims_size)
+                      : lhs->shape().dimensions(lhs_batch_dims_size) *
+                            lhs->shape().dimensions(lhs_batch_dims_size + 1);
+    int64_t rhs_stride =
+        is_rhs_vector ? rhs->shape().dimensions(lhs_batch_dims_size)
+                      : rhs->shape().dimensions(lhs_batch_dims_size) *
+                            rhs->shape().dimensions(lhs_batch_dims_size + 1);
+
+    gemm_backend_config.set_lhs_stride(lhs_stride);
+    gemm_backend_config.set_rhs_stride(rhs_stride);
+
+    switch (options_.dtype) {
+      case GemmRewriterOptions::DType::kFp8Only: {
+        // Rewrite FP8 GEMMs into a type-specific cublasLT Custom Call.
+        TF_ASSIGN_OR_RETURN(
+            bool supported_by_cublaslt,
+            GemmIsSupportedByCublasLt(*instr, gemm_backend_config));
+        std::optional<MatchedFp8Param> a, b;
+        if (supported_by_cublaslt && instr->opcode() == HloOpcode::kDot &&
+            (a = MatchFp8Param(
+                 const_cast<HloInstruction *>(instr->operand(0)))) &&
+            (b = MatchFp8Param(
+                 const_cast<HloInstruction *>(instr->operand(1))))) {
+          if (IsRocm(gpu_version_) && toolkit_version_ < 60200 &&
+              instr->shape().element_type() != F16 &&
+              instr->shape().element_type() != F32) {
+            TF_ASSIGN_OR_RETURN(
+                instr, TurnF8DotWithUnsupportedOutputTypeIntoF32(instr));
+          }
+          TF_ASSIGN_OR_RETURN(bool created_call,
+                              CreateF8CustomCall(instr, gpu_backend_config,
+                                                 a.value(), b.value()));
+          if (created_call) {
+            return absl::OkStatus();
+          }
+        }
+        if (IsF8Type(instr->operand(0))) {
+          // FP8 rewriter couldn't rewrite dot with FP8 inputs into cublasLt
+          // custom call, so turn into an FP16 dot which may be rewritten as an
+          // FP16 Triton, cublas or cublasLt call.
+          TF_ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr));
+        }
+        break;
+      }
+      case GemmRewriterOptions::DType::kNonFp8Only: {
+        // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call.
+        TF_ASSIGN_OR_RETURN(
+            absl::string_view gemm_custom_call_target,
+            GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config));
+        const Shape &output_shape = instr->shape();
+        HloInstruction *gemm_call =
+            instr->AddInstruction(HloInstruction::CreateCustomCall(
+                output_shape,
+                {instr->mutable_operand(0), instr->mutable_operand(1)},
+                gemm_custom_call_target));
+        TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config));
+        TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call));
+      } break;
+    };
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleMultiply(HloInstruction *instr) override {
+    HloInstruction *alpha, *existing_gemm;
+    if (Match(instr,
+              m::MultiplyAnyOrder(
+                  GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(),
+                  m::Broadcast(m::ConstantScalar(&alpha)).WithOneUser()))) {
+      TF_ASSIGN_OR_RETURN(auto gpu_config,
+                          existing_gemm->backend_config<GpuBackendConfig>());
+      GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+      // Do not fuse alpha into S32 GEMM, as they only support fixed values for
+      // alpha/beta.
+      if (existing_gemm->shape().element_type() == S32) {
+        return absl::OkStatus();
+      }
+
+      if (config.beta() == 0.0 && existing_gemm->user_count() == 1) {
+        complex128 prev_alpha = {config.alpha_real(), config.alpha_imag()};
+        complex128 new_alpha =
+            *alpha->literal().GetAsComplex128({}) * prev_alpha;
+        config.set_alpha_real(new_alpha.real());
+        config.set_alpha_imag(new_alpha.imag());
+        TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config));
+        return ReplaceInstruction(instr, existing_gemm);
+      }
+    }
+
+    HloInstruction *d_scale;
+    if (Match(instr, m::MultiplyAnyOrder(
+                         CublasLtMatmulF8(&existing_gemm).WithOneUser(),
+                         m::Broadcast(m::Op(&d_scale)).WithOneUser()))) {
+      return F8ScaleD(instr, existing_gemm, d_scale);
+    }
+
+    // Attempt to match approximate GELU activation
+    // (https://arxiv.org/abs/1606.08415), where:
+    // approx_gelu(x) = x * cdf(x)
+    // cdf(x) = 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
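+    // (M_2_PI below is the standard math constant 2/pi, so
+    // BcastConstScalarNear(sqrt(M_2_PI)) matches the sqrt(2 / pi) factor of
+    // this approximation, up to a small tolerance.)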
+    HloInstruction *cdf, *slice_or_bitcast = nullptr;
+    if (Match(instr, m::MultiplyAnyOrder(
+                         m::AnyOf<HloInstruction>(
+                             m::Slice(&slice_or_bitcast,
+                                      CublasLtMatmulMaybeF8(&existing_gemm)),
+                             m::Bitcast(&slice_or_bitcast,
+                                        CublasLtMatmulMaybeF8(&existing_gemm)),
+                             CublasLtMatmulMaybeF8(&existing_gemm)),
+                         m::Op(&cdf).WithOneUser())) &&
+        Match(cdf,
+              m::MultiplyAnyOrder(
+                  BcastConstScalar(0.5),
+                  m::AddAnyOrder(
+                      BcastConstScalar(1.0),
+                      m::Tanh(
+                          m::MultiplyAnyOrder(
+                              BcastConstScalarNear(sqrt(M_2_PI)),
+                              m::AddAnyOrder(
+                                  m::Op().Is(slice_or_bitcast ? slice_or_bitcast
+                                                              : existing_gemm),
+                                  m::MultiplyAnyOrder(
+                                      BcastConstScalarNear(0.044715),
+                                      m::MultiplyAnyOrder(
+                                          m::Op().Is(slice_or_bitcast
+                                                         ? slice_or_bitcast
+                                                         : existing_gemm),
+                                          m::MultiplyAnyOrder(
+                                              m::Op().Is(slice_or_bitcast
+                                                             ? slice_or_bitcast
+                                                             : existing_gemm),
+                                              m::Op().Is(slice_or_bitcast
+                                                             ? slice_or_bitcast
+                                                             : existing_gemm))
+                                              .WithOneUser())
+                                          .WithOneUser())
+                                      .WithOneUser())
+                                  .WithOneUser())
+                              .WithOneUser())
+                          .WithOneUser())))) {
+      return FuseGeluActivation(instr, existing_gemm, slice_or_bitcast);
+    }
+    return absl::OkStatus();
+  }
+
+  // Fuse the scaling of an FP8 GEMM into the Custom Call.
+  absl::Status HandleDivide(HloInstruction *instr) override {
+    HloInstruction *existing_gemm, *d_scale;
+    if (Match(instr, m::Divide(CublasLtMatmulF8(&existing_gemm).WithOneUser(),
+                               m::Broadcast(m::Op(&d_scale)).WithOneUser()))) {
+      return F8ScaleD(instr, existing_gemm, d_scale);
+    }
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleAdd(HloInstruction *instr) override {
+    if (options_.bias_mode == GemmRewriterOptions::BiasMode::kNoBias) {
+      // See comments for `GemmRewriterOptions::BiasMode` for details.
+      return absl::OkStatus();
+    }
+
+    HloInstruction *bias, *existing_gemm = nullptr;
+    HloInstruction *optional_slice = nullptr;
+    HloInstruction *optional_convert = nullptr;
+    HloInstruction *optional_bitcast = nullptr;
+    // Attempt to elide broadcast and fuse addition of a vector bias into
+    // GEMM, including when slicing is applied to the result.
+    if (Match(instr,
+              m::AddAnyOrder(
+                  OptionalBitcast(
+                      &optional_bitcast,
+                      OptionalSlice(
+                          &optional_slice,
+                          CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())
+                          .WithOneUser())
+                      .WithOneUser(),
+                  m::Broadcast(&bias,
+                               OptionalConvert(&optional_convert, m::Op()))))) {
+      TF_ASSIGN_OR_RETURN(
+          bool was_fused,
+          FuseVectorBiasAdd(instr, bias, existing_gemm, optional_slice,
+                            optional_convert, optional_bitcast));
+
+      if (was_fused) {
+        return absl::OkStatus();
+      }
+    }
+    // Attempt to elide broadcast and fuse addition of a vector bias into
+    // *batched* GEMM as a matrix bias addition using FuseMatrixBiasAdd.
+    // add(bitcast(gemm(a, b)), broadcast(bias)) ->
+    //   bitcast(add(gemm(a, b), bitcast(broadcast(bias)))) ->
+    //   bitcast(gemm(a, b, bitcast(broadcast(bias)))) (FuseMatrixBiasAdd)
+    //
+    if (Match(
+            instr,
+            m::AddAnyOrder(
+                m::Bitcast(CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())
+                    .WithOneUser(),
+                m::Broadcast(&bias, m::Op()).WithOneUser()))) {
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * new_add,
+          MakeBinaryHlo(HloOpcode::kAdd, existing_gemm,
+                        MakeBitcastHlo(bias, existing_gemm->shape())));
+      TF_RETURN_IF_ERROR(
+          ReplaceInstruction(instr, MakeBitcastHlo(new_add, instr->shape())));
+
+      // Continue below.
+      instr = new_add;
+    }
+
+    // Do not fuse broadcast unless we can fuse its input, as it will cause
+    // broadcast materialization.
+    auto is_not_broadcast = [](const HloInstruction *instr) {
+      return instr->opcode() != HloOpcode::kBroadcast;
+    };
+
+    // add(bitcast(gemm(a, b)), bias) ->
+    //   bitcast(add(gemm(a, b), bitcast(bias))) ->
+    //   bitcast(gemm(a, b, bitcast(bias))) (later down in this function).
+    //
+    // We see this idiom in models that contain batch-dots, where we cast
+    // between a rank-2 shape for non-batch dots and a higher-rank shape for
+    // batch-dots.
+    //
+    // The last stage of the transform may fail (because of any of the checks in
+    // FuseMatrixBiasAdd), but if so that's okay -- we'll have done a useless
+    // transformation, but it doesn't hurt anything.
+    if (Match(instr,
+              m::AddAnyOrder(
+                  m::Bitcast(
+                      GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())
+                      .WithOneUser(),
+                  m::Op(&bias).WithPredicate(is_not_broadcast)))) {
+      HloInstruction *new_bitcast =
+          MakeBitcastHlo(bias, existing_gemm->shape(), &bias->metadata());
+      TF_ASSIGN_OR_RETURN(HloInstruction * new_add,
+                          MakeBinaryHlo(HloOpcode::kAdd, existing_gemm,
+                                        new_bitcast, &bias->metadata()));
+      TF_RETURN_IF_ERROR(
+          ReplaceInstruction(instr, MakeBitcastHlo(new_add, instr->shape())));
+
+      // Continue below transforming new_add.
+      instr = new_add;
+    }
+
+    // Attempt to fuse matrix bias into gemm with optional convert
+    // add(convert(gemm(a, b)), c) -> gemm(a, b, c)
+    // add(gemm(a, b), c) -> gemm(a, b, c)
+    if (Match(instr,
+              m::AddAnyOrder(
+                  m::AnyOf<HloInstruction>(
+                      GemmOrCublasLtMatmul(&existing_gemm).WithOneUser(),
+                      m::Convert(
+                          GemmOrCublasLtMatmul(&existing_gemm).WithOneUser())
+                          .WithOneUser()),
+                  m::Op(&bias).WithPredicate(is_not_broadcast)))) {
+      TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
+                          existing_gemm->backend_config<GpuBackendConfig>());
+      const GemmBackendConfig &gemm_backend_config =
+          gpu_backend_config.gemm_backend_config();
+      // Check whether this type combination is supported here.
+      TF_ASSIGN_OR_RETURN(
+          bool types_are_supported,
+          IsLegacyCublasMatmul(*existing_gemm)
+              ? TypesAreSupportedByLegacyCublas(*existing_gemm,
+                                                gemm_backend_config, instr)
+              : TypesAreSupportedByCublasLt(*existing_gemm, gemm_backend_config,
+                                            instr));
+
+      // For mixed-type GEMMs, only fuse the add if its result has no
+      // consumers, i.e. the add is the root (possibly via a root tuple):
+      //   ROOT add
+      //   ROOT tuple(add)
+      bool has_no_consumer =
+          instr->shape().element_type() ==
+              existing_gemm->shape().element_type() ||
+          instr->user_count() == 0 ||
+          (instr->user_count() == 1 &&
+           instr->users()[0]->opcode() == HloOpcode::kTuple &&
+           instr->users()[0]->user_count() == 0);
+
+      if (types_are_supported && has_no_consumer) {
+        return FuseMatrixBiasAdd(instr, bias, existing_gemm);
+      }
+    }
+
+    HloInstruction *optional_bitcast_matrix = nullptr;
+    HloInstruction *optional_slice_matrix = nullptr;
+    if (Match(instr,
+              m::AddAnyOrder(
+                  OptionalBitcast(
+                      &optional_bitcast_matrix,
+                      OptionalSlice(&optional_slice_matrix,
+                                    GemmOrCublasLtMatmulMaybeF8(&existing_gemm)
+                                        .WithOneUser()))
+                      .WithOneUser(),
+                  m::Op(&bias).WithPredicate(is_not_broadcast)))) {
+      // The matrix bias must not be FP8, see
+      // https://docs.nvidia.com/cuda/cublas/index.html.
+      if (!IsF8Type(bias)) {
+        return FuseMatrixBiasAdd(instr, bias, existing_gemm,
+                                 optional_bitcast_matrix,
+                                 optional_slice_matrix);
+      }
+    }
+
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleMaximum(HloInstruction *instr) override {
+    HloInstruction *existing_gemm, *zeros;
+    HloInstruction *optional_slice_or_bitcast = nullptr;
+    // Attempt to elide maximum and fuse ReLU activation into GEMM, including
+    // when slicing or bitcasting is applied to the result.
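+    // For example (illustrative HLO, not taken from a real module):
+    //   maximum(cublas-lt-matmul(a, b), broadcast(0)) ->
+    //     cublas-lt-matmul(a, b) with a RELU epilogue
+    // (see FuseReluActivation).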
+    if (Match(instr,
+              m::MaximumAnyOrder(
+                  m::AnyOf<HloInstruction>(
+                      m::Slice(
+                          &optional_slice_or_bitcast,
+                          CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()),
+                      m::Bitcast(
+                          &optional_slice_or_bitcast,
+                          CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()),
+                      CublasLtMatmulMaybeF8(&existing_gemm))
+                      .WithOneUser(),
+                  m::Broadcast(&zeros, m::ConstantScalar(0))))) {
+      TF_RETURN_IF_ERROR(FuseReluActivation(instr, zeros, existing_gemm,
+                                            optional_slice_or_bitcast));
+    }
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleConvert(HloInstruction *instr) override {
+    HloInstruction *clamp_lower, *clamp_upper, *existing_gemm,
+        *d_scale = nullptr, *binary = nullptr;
+    // Attempt to elide the scaling and conversion of the result of an FP8
+    // GEMM, including the optional calculation of the maximum of the absolute
+    // values before scaling, and adapt the Custom Call.
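+    // For example (illustrative HLO, assuming a multiplicative scale):
+    //   convert(clamp(lo, multiply(fp8-gemm, broadcast(d_scale)), hi)) ->
+    //     fp8-gemm with d_scale forwarded as a scaling operand
+    // (see F8ConvertD).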
+    if (Match(instr,
+              m::Convert(
+                  m::Clamp(
+                      m::Broadcast(m::ConstantScalar(&clamp_lower)),
+                      m::AnyOf<HloInstruction>(
+                          CublasLtMatmulF8(&existing_gemm),
+                          m::Divide(&binary, CublasLtMatmulF8(&existing_gemm),
+                                    m::Broadcast(m::Op(&d_scale))),
+                          m::MultiplyAnyOrder(&binary,
+                                              CublasLtMatmulF8(&existing_gemm),
+                                              m::Broadcast(m::Op(&d_scale)))),
+                      m::Broadcast(m::ConstantScalar(&clamp_upper)))
+                      .WithOneUser()))) {
+      return F8ConvertD(
+          instr, existing_gemm, d_scale, clamp_lower, clamp_upper,
+          /*mult_scale=*/(binary && binary->opcode() == HloOpcode::kMultiply));
+    }
+    return absl::OkStatus();
+  }
+
+  static bool IsCuda(const se::GpuComputeCapability &gpu_version) {
+    return std::holds_alternative<se::CudaComputeCapability>(gpu_version);
+  }
+
+  static absl::StatusOr<se::CudaComputeCapability> GetCudaComputeCapability(
+      const se::GpuComputeCapability &gpu_version) {
+    auto *cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version);
+    if (cuda_cc == nullptr) {
+      return absl::InvalidArgumentError("Compute Capability is not CUDA.");
+    }
+    return *cuda_cc;
+  }
+
+  static bool IsRocm(const se::GpuComputeCapability &gpu_version) {
+    return std::holds_alternative<se::RocmComputeCapability>(gpu_version);
+  }
+
+  static absl::StatusOr<se::RocmComputeCapability> GetRocmComputeCapability(
+      const se::GpuComputeCapability &gpu_version) {
+    auto rocm_cc = std::get_if<se::RocmComputeCapability>(&gpu_version);
+    if (rocm_cc == nullptr) {
+      return absl::InvalidArgumentError("Compute Capability is not ROCm.");
+    }
+    return *rocm_cc;
+  }
+
+  absl::StatusOr<bool> CreateF8CustomCall(HloInstruction *instr,
+                                          GpuBackendConfig &gpu_backend_config,
+                                          MatchedFp8Param a,
+                                          MatchedFp8Param b) {
+    GemmBackendConfig &gemm_backend_config =
+        *gpu_backend_config.mutable_gemm_backend_config();
+    if (IsCuda(gpu_version_)) {
+      TF_ASSIGN_OR_RETURN(auto cuda_compute_capability,
+                          GetCudaComputeCapability(gpu_version_));
+      // FP8 GEMM kernels are only available on Ada, Hopper, and later
+      // architectures.
+      if (!cuda_compute_capability.IsAtLeast(8, 9)) {
+        VLOG(1) << "FP8 Custom Calls require Ada, Hopper, or later "
+                   "architectures. Got: "
+                << cuda_compute_capability.ToString()
+                << " and toolkit version: " << toolkit_version_;
+        return false;
+      }
+      // FP8 GEMM kernels are only available with CUDA 12.0 and above
+      if (toolkit_version_ < 12000) {
+        VLOG(1) << "FP8 Custom Calls require CUDA 12.0 or newer.";
+        return false;
+      }
+    }
+
+    if (IsRocm(gpu_version_)) {
+      TF_ASSIGN_OR_RETURN(auto rocm_compute_capability,
+                          GetRocmComputeCapability(gpu_version_));
+      if (!rocm_compute_capability.has_fp8_support()) {
+        VLOG(1) << "FP8 Custom Calls require MI300, or later architectures.";
+        return false;
+      }
+      if (toolkit_version_ < 60000) {
+        // FP8 GEMM kernels are only available with ROCm 6.0 and above
+        VLOG(1) << "FP8 Custom Calls require ROCm 6.0 or newer.";
+        return false;
+      }
+    }
+
+    PrimitiveType a_type = a.fp8_input->shape().element_type();
+    PrimitiveType b_type = b.fp8_input->shape().element_type();
+
+    // cuBLASLt FP8 GEMM kernels require one of the two operands to be in
+    // F8E4M3FN format.
+    if (IsCuda(gpu_version_)) {
+      if (a_type == F8E5M2 && b_type == F8E5M2) {
+        VLOG(1)
+            << "Failed to rewrite " << instr->ToShortString()
+            << " into FP8 Custom Call. The element type of one of the operands "
+               "must be F8E4M3FN.";
+        return false;
+      }
+      if ((a_type != F8E5M2 && a_type != F8E4M3FN) ||
+          (b_type != F8E5M2 && b_type != F8E4M3FN)) {
+        VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                << " into FP8 Custom Call. The input types must be F8E5M2 or "
+                   "F8E4M3FN, but got "
+                << PrimitiveType_Name(a_type) << " and "
+                << PrimitiveType_Name(b_type);
+        return false;
+      }
+    }
+
+    if (IsRocm(gpu_version_)) {
+      if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
+        VLOG(1)
+            << "Failed to rewrite " << instr->ToShortString()
+            << " into FP8 Custom Call. The element type of one of the operands "
+               "must be F8E4M3FNUZ.";
+        return false;
+      }
+      if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
+          (b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
+        VLOG(1)
+            << "Failed to rewrite " << instr->ToShortString()
+            << " into FP8 Custom Call. The input types must be F8E5M2FNUZ or "
+               "F8E4M3FNUZ, but got "
+            << PrimitiveType_Name(a_type) << " and "
+            << PrimitiveType_Name(b_type);
+        return false;
+      }
+    }
+
+    absl::Span<const int64_t> batch_dims =
+        gemm_backend_config.dot_dimension_numbers().rhs_batch_dimensions();
+
+    // cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32
+    // format. Set the factors to one when no scaling factors were captured.
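+    // Illustrative example (assumed values): if a.scale is a BF16 scalar s
+    // and a.mult_scale is false, the Custom Call receives convert(1 / s) in
+    // F32; if no scale was captured, it receives the constant 1.0 instead.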
+    Literal one_literal = LiteralUtil::One(F32);
+    HloInstruction *one = instr->AddInstruction(
+        HloInstruction::CreateConstant(one_literal.Clone()));
+    std::array<bool, 2> mult_scale{a.mult_scale, b.mult_scale};
+    std::array<HloInstruction *, 2> scales{a.scale, b.scale}, inv_scales,
+        scales_f32;
+    for (int i = 0; i < scales.size(); ++i) {
+      if (scales[i]) {
+        if (!ShapeUtil::IsScalar(scales[i]->shape())) {
+          VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                  << " into FP8 Custom Call. The scaling factors must be "
+                     "scalars.";
+          return false;
+        }
+        if (!mult_scale[i]) {
+          inv_scales[i] = instr->AddInstruction(HloInstruction::CreateBinary(
+              scales[i]->shape(), HloOpcode::kDivide, one, scales[i]));
+        }
+        scales_f32[i] = mult_scale[i] ? scales[i] : inv_scales[i];
+        if (scales_f32[i]->shape().element_type() != F32) {
+          scales_f32[i] = instr->AddInstruction(HloInstruction::CreateConvert(
+              ShapeUtil::MakeScalarShape(F32), scales_f32[i]));
+        }
+      } else {
+        scales_f32[i] = one;
+      }
+    }
+
+    PrimitiveType d_type = instr->shape().element_type();
+    bool supported_d_type = (d_type == BF16 || d_type == F16 || d_type == F32);
+    if (IsCuda(gpu_version_) && (d_type == F8E4M3FN || d_type == F8E5M2)) {
+      supported_d_type = true;
+    }
+    if (IsRocm(gpu_version_) && toolkit_version_ >= 60200 &&
+        (d_type == F8E4M3FNUZ || d_type == F8E5M2FNUZ)) {
+      supported_d_type = true;
+    }
+    if (!supported_d_type) {
+      VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+              << " into FP8 Custom Call. Output element type must be "
+              << (IsCuda(gpu_version_) ? "F8E4M3FN, F8E5M2, BF16, F16 or F32. "
+                  : toolkit_version_ >= 60200
+                      ? "F8E4M3FNUZ, F8E5M2FNUZ, BF16, F16 or F32. "
+                      : "BF16, F16 or F32. ")
+              << "Actual element type is " << PrimitiveType_Name(d_type);
+      return false;
+    }
+
+    // Each operand must have exactly one contracting and one non-contracting
+    // dimension.
+    absl::Span<const int64_t> a_contracting_dims =
+        gemm_backend_config.dot_dimension_numbers()
+            .lhs_contracting_dimensions();
+    absl::Span<const int64_t> b_contracting_dims =
+        gemm_backend_config.dot_dimension_numbers()
+            .rhs_contracting_dimensions();
+    if (a_contracting_dims.size() != 1 || b_contracting_dims.size() != 1) {
+      VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+              << " into FP8 Custom Call. A and B must have one contracting "
+                 "dimension.";
+      return false;
+    }
+    if ((a.commutative_ops.empty() ? a.fp8_input
+                                   : a.commutative_ops.back().first)
+                    ->shape()
+                    .dimensions_size() -
+                batch_dims.size() !=
+            2 ||
+        (b.commutative_ops.empty() ? b.fp8_input
+                                   : b.commutative_ops.back().first)
+                    ->shape()
+                    .dimensions_size() -
+                batch_dims.size() !=
+            2) {
+      VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+              << "into FP8 Custom Call. A and B must have one non-contracting "
+                 "dimension.";
+      return false;
+    }
+
+    // Sequentially apply the collected unary, dynamic-slice, pad and select ops
+    // to the unconverted and unscaled operands.
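+    // For example (illustrative): if the matched operand was
+    // pad(convert(fp8_input), c), the pad is re-created directly on the FP8
+    // input, with its padding value c converted to the FP8 type, so the
+    // convert and scaling can be absorbed into the Custom Call.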
+    auto shift_ops = [&instr](HloInstruction *&x, InstrPath &x_ops) -> void {
+      for (std::pair<HloInstruction *, int> op : x_ops) {
+        std::vector<HloInstruction *> operands = {x};
+        // Insert the additional operands of dynamic-slice ops.
+        if (op.first->opcode() == HloOpcode::kDynamicSlice) {
+          for (int i = 1; i < op.first->operand_count(); ++i) {
+            operands.emplace_back(op.first->mutable_operand(i));
+          }
+        }
+        // Convert the second operand of pad ops.
+        if (op.first->opcode() == HloOpcode::kPad) {
+          HloInstruction *convert =
+              instr->AddInstruction(HloInstruction::CreateConvert(
+                  ShapeUtil::ChangeElementType(op.first->operand(1)->shape(),
+                                               x->shape().element_type()),
+                  op.first->mutable_operand(1)));
+          operands.emplace_back(convert);
+        }
+        // Convert and insert the additional operands of select ops.
+        if (op.first->opcode() == HloOpcode::kSelect) {
+          // The first operand is the predicate.
+          operands.emplace(operands.begin(), op.first->mutable_operand(0));
+          // Convert the remaining operand.
+          int operand_idx = op.second == 2 ? 1 : 2;
+          HloInstruction *convert =
+              instr->AddInstruction(HloInstruction::CreateConvert(
+                  ShapeUtil::ChangeElementType(
+                      op.first->operand(operand_idx)->shape(),
+                      x->shape().element_type()),
+                  op.first->mutable_operand(operand_idx)));
+          operands.emplace(operands.begin() + operand_idx, convert);
+        }
+        x = instr->AddInstruction(op.first->CloneWithNewOperands(
+            ShapeUtil::MakeShapeWithDenseLayout(
+                x->shape().element_type(), op.first->shape().dimensions(),
+                op.first->shape().layout().minor_to_major()),
+            operands));
+      }
+      return;
+    };
+    shift_ops(a.fp8_input, a.commutative_ops);
+    shift_ops(b.fp8_input, b.commutative_ops);
+
+    TF_ASSIGN_OR_RETURN(GemmConfig gemm_config,
+                        GemmConfig::For(instr, gemm_backend_config));
+
+    DotDimensionNumbers *dim_nums =
+        gemm_backend_config.mutable_dot_dimension_numbers();
+    int batch_dim_offset = batch_dims.size();
+
+    // cuBLASLt FP8 GEMM kernels currently require the first operand, i.e. A, to
+    // be row-major. If A is column-major, swap the contracting and
+    // non-contracting dimension and transpose the matrix to effectively make it
+    // row-major.
+    // TODO(philipphack): Remove once cuBLASLt supports A being column-major
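+    // For example (illustrative): for a non-batched A with layout {0,1}
+    // (column-major), the contracting dimension recorded in dim_nums is
+    // swapped and A is physically transposed, so cuBLASLt effectively sees a
+    // row-major operand computing the same product.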
+    if (gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor) {
+      CHECK(a_contracting_dims[0] == batch_dim_offset ||
+            a_contracting_dims[0] == batch_dim_offset + 1);
+      if (a_contracting_dims[0] == batch_dim_offset) {
+        dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset + 1);
+      } else {
+        dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset);
+      }
+      a.fp8_input =
+          TransposeMatrix(a.fp8_input, a_contracting_dims[0], batch_dims);
+    }
+
+    // Similarly, cuBLASLt requires the second operand to be column-major, so
+    // make it column-major if it is currently row-major.
+    if (gemm_config.rhs_layout.order == MatrixLayout::Order::kRowMajor) {
+      CHECK(b_contracting_dims[0] == batch_dim_offset ||
+            b_contracting_dims[0] == batch_dim_offset + 1);
+      if (b_contracting_dims[0] == batch_dim_offset) {
+        dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset + 1);
+      } else {
+        dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset);
+      }
+      b.fp8_input =
+          TransposeMatrix(b.fp8_input, b_contracting_dims[0], batch_dims);
+    }
+
+    a.fp8_input = PadOperandToMultipleOf16(batch_dims, a.fp8_input);
+    b.fp8_input = PadOperandToMultipleOf16(batch_dims, b.fp8_input);
+    Shape new_output_shape = PadShapeToMultipleOf16(instr->shape(), batch_dims);
+
+    std::vector<HloInstruction *> operands_list = {
+        a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1], one, one};
+
+    HloInstruction *new_custom_call =
+        instr->AddInstruction(HloInstruction::CreateCustomCall(
+            ShapeUtil::MakeShapeWithDenseLayout(
+                instr->shape().element_type(), new_output_shape.dimensions(),
+                instr->shape().layout().minor_to_major()),
+            operands_list, kCublasLtMatmulF8CallTarget));
+    TF_RETURN_IF_ERROR(new_custom_call->set_backend_config(gpu_backend_config));
+    TF_RETURN_IF_ERROR(SetName(instr->GetModule(), new_custom_call));
+
+    // Slice the result of the GEMM if the operands were padded.
+    HloInstruction *slice = nullptr;
+    if (new_output_shape.dimensions() != instr->shape().dimensions()) {
+      std::vector<int64_t> start_indices(instr->shape().rank(), 0);
+      std::vector<int64_t> strides(instr->shape().rank(), 1);
+      slice = instr->AddInstruction(HloInstruction::CreateSlice(
+          instr->shape(), new_custom_call, start_indices,
+          instr->shape().dimensions(), strides));
+    }
+
+    TF_RETURN_IF_ERROR(
+        ReplaceInstruction(instr, slice ? slice : new_custom_call));
+    VLOG(1) << instr->ToString() << " rewritten into FP8 Custom Call.";
+    return true;
+  }
+
+  absl::Status F8ScaleD(HloInstruction *instr, HloInstruction *existing_gemm,
+                        HloInstruction *d_scale) {
+    if (!ShapeUtil::IsScalar(d_scale->shape())) {
+      return absl::OkStatus();
+    }
+
+    // When the output of an FP8 GEMM is scaled but not type converted to FP8,
+    // cublasLT requires the scaling factor to be forwarded to the Custom Call
+    // as a_scale (chosen here) or b_scale. The scaling factor is fused here
+    // when no input scaling factors were fused during the creation of the
+    // Custom Call. When the maximum of the absolute value of the output of an
+    // FP8 GEMM is calculated and the output is scaled and type converted to
+    // FP8, the scaling of the output is fused in F8ConvertD.
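+    // For example (illustrative HLO):
+    //   divide(fp8-gemm(a, b, ...), broadcast(d_scale)) ->
+    //     fp8-gemm(a, b, 1 / d_scale, ...) with the inverted scale forwarded
+    //     as a_scale.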
+    if (!existing_gemm->operand(2)->IsConstant() ||
+        existing_gemm->operand(2)->literal().GetAsDouble({}) != 1.) {
+      return absl::OkStatus();
+    }
+
+    // The application of the scaling of the output to the input (see previous
+    // comment) is not valid for epilogues other than ReLU or when a matrix bias
+    // has been fused.
+    TF_ASSIGN_OR_RETURN(auto gpu_backend_config,
+                        existing_gemm->backend_config<GpuBackendConfig>());
+    const GemmBackendConfig &config = gpu_backend_config.gemm_backend_config();
+    if ((config.epilogue() != GemmBackendConfig::DEFAULT &&
+         config.epilogue() != GemmBackendConfig::RELU) ||
+        config.beta() != 0.) {
+      return absl::OkStatus();
+    }
+
+    // If necessary, invert the scaling factor of D and convert to F32.
+    TF_ASSIGN_OR_RETURN(
+        d_scale,
+        InvertAndConvertScalar(d_scale, instr->opcode() == HloOpcode::kDivide));
+
+    TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(2, d_scale));
+    TF_RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm));
+
+    VLOG(1) << "Scaling of FP8 GEMM fused into Custom Call.";
+    return absl::OkStatus();
+  }
+
+  absl::Status F8ConvertD(HloInstruction *instr, HloInstruction *existing_gemm,
+                          HloInstruction *d_scale, HloInstruction *clamp_lower,
+                          HloInstruction *clamp_upper,
+                          bool mult_scale = false) {
+    // Verify the data types and the operands of clamp.
+    if (instr->shape().element_type() == F8E4M3FN) {
+      if (!clamp_lower->literal().IsAllFloat(static_cast<float>(
+              std::numeric_limits<tsl::float8_e4m3fn>::lowest())) ||
+          !clamp_upper->literal().IsAllFloat(static_cast<float>(
+              std::numeric_limits<tsl::float8_e4m3fn>::max()))) {
+        return absl::OkStatus();
+      }
+    } else if (instr->shape().element_type() == F8E5M2) {
+      if (!clamp_lower->literal().IsAllFloat(static_cast<float>(
+              std::numeric_limits<tsl::float8_e5m2>::lowest())) ||
+          !clamp_upper->literal().IsAllFloat(static_cast<float>(
+              std::numeric_limits<tsl::float8_e5m2>::max()))) {
+        return absl::OkStatus();
+      }
+    } else {
+      return absl::OkStatus();
+    }
+
+    if (d_scale && !ShapeUtil::IsScalar(d_scale->shape())) {
+      return absl::OkStatus();
+    }
+
+    // The possible second user of the GEMM must be the calculation of the
+    // maximum of the absolute value of the result of the GEMM. Since it is
+    // unknown in what form this operation will be used, it is identified in a
+    // top-down approach by inspecting the users of the GEMM.
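+    // For example (illustrative HLO, not taken from a real module):
+    //   abs_result = abs(fp8-gemm)
+    //   damax = reduce(abs_result, -inf), to_apply=maximum
+    // or, with a ReLU epilogue, reduce(fp8-gemm, 0) directly, since
+    // abs(relu(x)) = relu(x).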
+    const std::vector<HloInstruction *> gemm_users = existing_gemm->users();
+    HloInstruction *reduce_damax = nullptr;
+    if (gemm_users.size() == 2) {
+      // In the presence of a ReLU activation, the abs instruction is elided
+      // since abs(ReLU(x)) = ReLU(x).
+      TF_ASSIGN_OR_RETURN(auto gpu_config,
+                          existing_gemm->backend_config<GpuBackendConfig>());
+      const GemmBackendConfig &config = gpu_config.gemm_backend_config();
+      for (int i = 0; i < gemm_users.size(); ++i) {
+        HloInstruction *maybe_reduce = nullptr;
+        if (gemm_users[i]->opcode() == HloOpcode::kAbs) {
+          if (gemm_users[i]->users().size() != 1) continue;
+          maybe_reduce = gemm_users[i]->users()[0];
+        } else {
+          // If there is no abs instruction, a ReLU epilogue is required to
+          // ensure all values are nonnegative.
+          if (config.epilogue() != GemmBackendConfig::BIAS_RELU &&
+              config.epilogue() != GemmBackendConfig::RELU)
+            continue;
+          maybe_reduce = gemm_users[i];
+        }
+
+        if (maybe_reduce->opcode() == HloOpcode::kReduce &&
+            maybe_reduce->operands().size() == 2 &&
+            maybe_reduce->operand(1)->opcode() == HloOpcode::kConstant &&
+            ShapeUtil::IsScalar(maybe_reduce->operand(1)->shape())) {
+          HloInstruction *reduce = maybe_reduce;
+          HloComputation *reduce_comp = reduce->to_apply();
+          HloInstruction *reduce_comp_root = reduce_comp->root_instruction();
+          if (reduce->operand(1)->literal().GetAsDouble({}) <= 0. &&
+              reduce_comp_root->opcode() == HloOpcode::kMaximum &&
+              reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter &&
+              reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter) {
+            reduce_damax = reduce;
+          }
+        }
+      }
+      if (!reduce_damax) {
+        return absl::OkStatus();
+      }
+    } else if (gemm_users.size() > 2) {
+      return absl::OkStatus();
+    }
+
+    TF_ASSIGN_OR_RETURN(auto gpu_backend_config,
+                        existing_gemm->backend_config<GpuBackendConfig>());
+    const GemmBackendConfig &gemm_backend_config =
+        gpu_backend_config.gemm_backend_config();
+
+    if (gemm_backend_config.beta() != 0.0) {
+      if (existing_gemm->operand(2)->shape().element_type() != BF16 &&
+          existing_gemm->operand(2)->shape().element_type() != F16) {
+        VLOG(1) << "The scaling and conversion of the result of "
+                << existing_gemm->ToShortString()
+                << " is not fused into the FP8 Custom Call because it "
+                   "conflicts with the existing fusion of the addition of a "
+                   "matrix bias with element type other than BF16 or F16.";
+        return absl::OkStatus();
+      } else {
+        // Turn off the output to operand aliasing, since the fp8 output and
+        // bf16/fp16 bias have different sizes.
+        xla::Cast<HloCustomCallInstruction>(existing_gemm)
+            ->set_output_to_operand_aliasing({});
+      }
+    }
+
+    // If necessary, invert the scaling factor of D and convert to F32.
+    if (d_scale) {
+      TF_ASSIGN_OR_RETURN(d_scale,
+                          InvertAndConvertScalar(d_scale, !mult_scale));
+      TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(
+          gemm_backend_config.beta() == 0.0 ? 5 : 6, d_scale));
+    }
+
+    // If present, elide the calculation of the maximum of the absolute values
+    // of the result of the GEMM.
+    if (reduce_damax) {
+      return F8AddDAmax(instr, existing_gemm, reduce_damax);
+    }
+
+    std::unique_ptr<HloInstruction> new_gemm =
+        existing_gemm->CloneWithNewShape(instr->shape());
+
+    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(new_gemm)));
+
+    VLOG(1) << "Conversion" << (reduce_damax ? " and amax calculation" : "")
+            << " fused into FP8 GEMM.";
+    return absl::OkStatus();
+  }
+
+  // Adds a scalar DAmax return value to an FP8 GEMM.
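+  // For example (illustrative): the output shape of the Custom Call changes
+  // from D to tuple(D, f32[]) with damax_output set, and the original reduce
+  // is replaced by convert(get-tuple-element(..., 1)).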
+  absl::Status F8AddDAmax(HloInstruction *instr, HloInstruction *existing_gemm,
+                          HloInstruction *reduce_damax) {
+    // Change the output shape of the Custom Call to tuple(D, DAmax).
+    Shape damax_shape = ShapeUtil::MakeScalarShape(F32);
+    Shape tuple_shape =
+        ShapeUtil::MakeTupleShape({instr->shape(), damax_shape});
+    HloInstruction *gemm_and_damax =
+        instr->AddInstruction(existing_gemm->CloneWithNewShape(tuple_shape));
+
+    TF_ASSIGN_OR_RETURN(auto gpu_config,
+                        gemm_and_damax->backend_config<GpuBackendConfig>());
+    GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+    config.set_damax_output(true);
+    TF_RETURN_IF_ERROR(gemm_and_damax->set_backend_config(gpu_config));
+
+    // Obtain D and DAmax separately from the output tuple.
+    HloInstruction *d =
+        instr->AddInstruction(HloInstruction::CreateGetTupleElement(
+            instr->shape(), gemm_and_damax, 0));
+    HloInstruction *damax = instr->AddInstruction(
+        HloInstruction::CreateGetTupleElement(damax_shape, gemm_and_damax, 1));
+
+    // Convert DAmax from FP32 to the requested type and elide reduce.
+    HloInstruction *damax_converted = instr->AddInstruction(
+        HloInstruction::CreateConvert(reduce_damax->shape(), damax));
+    TF_RETURN_IF_ERROR(ReplaceInstruction(reduce_damax, damax_converted));
+    TF_RETURN_IF_ERROR(ReplaceInstruction(instr, d));
+
+    return absl::OkStatus();
+  }
+
+  // Fuses a matrix bias into a cuBLAS call. 'instr' should be an Add
+  // instruction in the following form:
+  //   Add(OptionalBitcast(OptionalSlice(gemm)), bias)
+  // where 'gemm' is expected to be a cuBLAS custom_call. Slice is introduced
+  // when the inputs of the gemm are possibly padded. Bitcast is introduced to
+  // handle high rank input.
+  absl::Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias,
+                                 const HloInstruction *gemm,
+                                 HloInstruction *bitcast = nullptr,
+                                 HloInstruction *slice = nullptr) {
+    TF_RET_CHECK(Shape::Equal().IgnoreElementType()(bias->shape(),
+                                                    bitcast ? bitcast->shape()
+                                                    : slice ? slice->shape()
+                                                            : gemm->shape()));
+
+    // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only
+    // supports fixed values for alpha/beta.
+    if (gemm->shape().element_type() == S32) {
+      return absl::OkStatus();
+    }
+
+    // To ensure correctness, only slices that chop off the ends of dimensions
+    // are supported.
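+    // For example (illustrative): slicing f32[16,16] to f32[12,16] with
+    // starts {0,0} and strides {1,1} is accepted, while a slice starting at
+    // {2,0} is not.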
+    if (slice) {
+      int slice_op_dim = slice->operand(0)->shape().rank();
+      if (slice->slice_starts() != std::vector<int64_t>(slice_op_dim, 0) ||
+          slice->slice_strides() != std::vector<int64_t>(slice_op_dim, 1)) {
+        return absl::OkStatus();
+      }
+    }
+    // Legacy cublas gemm overwrites the bias matrix, so fusion is only
+    // possible if this add is the only user of the bias. CublasLt gemm can
+    // operate out-of-place.
+    bool can_overwrite_bias = [bias]() {
+      if (bias->user_count() > 1) {
+        // There is another user of the data, do not overwrite it.
+        return false;
+      }
+
+      if (bias->opcode() != HloOpcode::kParameter) {
+        // Not a parameter; can overwrite.
+        return true;
+      }
+
+      // The bias is a parameter of the computation; check if it is aliased.
+      if (!bias->parent()->IsEntryComputation()) {
+        // Only the HloModule has input/output aliasing, since this is not the
+        // entry computation, there are no guarantees about aliasing; do not
+        // overwrite.
+        return false;
+      }
+      const auto &in_out_alias_config =
+          bias->GetModule()->input_output_alias_config();
+      // If the parameter is aliased, we can overwrite it.
+      // TODO(victorstone): The assumption when calling ParameterHasAlias is
+      // that bias is not a tuple. This is why we pass {} as the argument for
+      // param_index.
+      return in_out_alias_config.ParameterHasAlias(bias->parameter_number(),
+                                                   /*param_index=*/{});
+    }();
+    bool want_to_fuse_bias = IsCublasLtMatmulF8(*gemm) ||
+                             IsCublasLtMatmul(*gemm) || can_overwrite_bias;
+
+    auto gpu_config = gemm->backend_config<GpuBackendConfig>().value();
+    GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+    // It is possible to fuse into a cublasLt matmul that already has a vector
+    // bias, but no other epilogue will commute with the matrix bias add.
+    bool supported_epilogue =
+        ((config.epilogue() == GemmBackendConfig::DEFAULT) ||
+         (config.epilogue() == GemmBackendConfig::BIAS));
+
+    if ((config.beta() != 0) || !want_to_fuse_bias ||
+        (gemm->user_count() != 1) || !supported_epilogue) {
+      return absl::OkStatus();
+    }
+
+    config.set_beta(1.0);
+
+    std::vector<HloInstruction *> operands(gemm->operands().begin(),
+                                           gemm->operands().end());
+    HloInstruction *maybe_constant_folded_bias = MaybeConstantFoldBias(bias);
+    if (bitcast) {
+      maybe_constant_folded_bias =
+          instr->AddInstruction(HloInstruction::CreateBitcast(
+              slice->shape(), maybe_constant_folded_bias));
+    }
+
+    maybe_constant_folded_bias =
+        PadOperandToTargetShape(gemm->shape(), maybe_constant_folded_bias);
+
+    operands.insert(operands.begin() + 2, maybe_constant_folded_bias);
+
+    std::unique_ptr<HloInstruction> fused_op =
+        gemm->CloneWithNewOperands(gemm->shape(), operands);
+    // For mixed-type GEMMs, set the output element type to the bias type.
+    fused_op->mutable_shape()->set_element_type(bias->shape().element_type());
+    TF_RETURN_IF_ERROR(fused_op->set_backend_config(gpu_config));
+
+    // Choose whether the bias must alias the output. Legacy cublas GEMMs must
+    // operate in place and alias the bias with the output, whereas with
+    // cublasLt we can choose.
+    //
+    // Operating in place is always safe; copy-insertion will insert copies if
+    // necessary.  But (we assume) copying is slower than operating
+    // out-of-place, so for cublasLt (where we have the choice), we try to
+    // operate in place if we think a copy won't be necessary.
+    //
+    // We assume that parameters are always read-only and therefore we'd need to
+    // copy if we were going to operate in place. (This is not quite true; the
+    // param could have input/output aliasing.)  We also assume that if there
+    // are other uses of the bias, we might need to copy.  (Again, not quite
+    // true if those uses all come before this operation.  But copy-insertion
+    // runs before scheduling, so it can't know and has to conservatively insert
+    // copies.)
+    if (IsLegacyCublasMatmul(*fused_op) || can_overwrite_bias) {
+      xla::Cast<HloCustomCallInstruction>(fused_op.get())
+          ->set_output_to_operand_aliasing({{{}, {2, {}}}});
+    }
+    TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get()));
+    if (slice) {
+      fused_op = slice->CloneWithNewOperands(
+          slice->shape(),
+          {slice->parent()->AddInstruction(std::move(fused_op))});
+    }
+
+    if (bitcast) {
+      fused_op = bitcast->CloneWithNewOperands(
+          bitcast->shape(),
+          {bitcast->parent()->AddInstruction(std::move(fused_op))});
+    }
+
+    return ReplaceWithNewInstruction(instr, std::move(fused_op));
+  }
+
+  // Fuses a vector bias into a cuBLAS call. 'instr' should be an Add
+  // instruction in the following form:
+  //   Add(OptionalBitcast(OptionalSlice(gemm)), Broadcast(OptionalConvert()))
+  // where 'gemm' is expected to be a cuBLAS custom_call. The optional
+  // convert is only used for F8 matmuls as cublasLt has specific constraints
+  // on the vector bias type for such matmuls. The optional bitcast is
+  // necessary to handle high rank input cases.
+  absl::StatusOr<bool> FuseVectorBiasAdd(HloInstruction *instr,
+                                         HloInstruction *broadcast,
+                                         HloInstruction *gemm,
+                                         HloInstruction *slice = nullptr,
+                                         HloInstruction *convert = nullptr,
+                                         HloInstruction *bitcast = nullptr) {
+    if (!bitcast) {
+      TF_RET_CHECK(ShapeUtil::Compatible(
+          broadcast->shape(), (slice ? slice->shape() : gemm->shape())));
+    }
+    // Verify that the data type is supported by Epilogue Fusion.
+    if (!SupportsEpilogueFusion(gemm->shape().element_type())) {
+      return false;
+    }
+
+    HloInstruction *bias = broadcast->mutable_operand(0);
+
+    TF_ASSIGN_OR_RETURN(auto gpu_config,
+                        gemm->backend_config<GpuBackendConfig>());
+    GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+    // # output column dims == # non-contracting rhs operand dims.
+    const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers();
+    size_t num_col_dims = gemm->operand(1)->shape().rank() -
+                          dot_dims.rhs_batch_dimensions_size() -
+                          dot_dims.rhs_contracting_dimensions_size();
+
+    if ((gemm->user_count() != 1) ||
+        (config.epilogue() != GemmBackendConfig::DEFAULT) ||
+        (bias->shape().rank() != num_col_dims)) {
+      return false;
+    }
+    // We require the bias vector to have been broadcast in the most major
+    // dimensions; i.e. its most minor physical dimensions align with most minor
+    // physical dimensions of the gemm output.
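+    // For example (illustrative): for a gemm output f32[16,32]{1,0} and a
+    // bias f32[32], the broadcast must map the bias onto dimension 1, the
+    // most minor output dimension.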
+    absl::Span<const int64_t> broadcast_dims = broadcast->dimensions();
+    for (size_t i = 0; i < num_col_dims; ++i) {
+      int64_t dim =
+          (bitcast ? bitcast : gemm)->shape().layout().minor_to_major(i);
+
+      // Find the corresponding dimension from the bias vector.
+      auto it = absl::c_find(broadcast_dims, dim);
+
+      if (it == broadcast_dims.end()) {
+        return false;
+      }
+
+      int64_t vector_dim = it - broadcast_dims.begin();
+      if (bias->shape().layout().minor_to_major(i) != vector_dim) {
+        return false;
+      }
+    }
+
+    std::vector<HloInstruction *> operands(gemm->operands().begin(),
+                                           gemm->operands().end());
+    // When a (non-trivial) matrix bias and a vector bias coexist for an FP8
+    // matmul, only the matrix bias is fused.
+    if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget &&
+        config.beta() != 0.0) {
+      return true;
+    }
+
+    if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget &&
+        bias->shape().element_type() == F32) {
+      if (convert == nullptr) {
+        return false;
+      }
+
+      HloInstruction *bias_f16_or_bf16 = convert->mutable_operand(0);
+      auto compatible_bias_type = [](const PrimitiveType bias_type,
+                                     const PrimitiveType output_type) {
+        if (bias_type == BF16) {
+          return output_type == F8E4M3FN || output_type == F8E5M2 ||
+                 output_type == F32 || output_type == BF16;
+        } else if (bias_type == F16) {
+          return output_type == F16 || output_type == F8E4M3FN ||
+                 output_type == F8E5M2;
+        }
+        return false;
+      };
+
+      // cuBLAS LT does not support FP32 biases on matmuls with FP8 inputs,
+      // even if the matmul output is FP32. We do not unconditionally convert
+      // the bias to a supported precision (F16 or BF16) because this lowers
+      // precision. Instead, we only fuse the bias if the bias itself is a
+      // convert from F16 or BF16, fusing the input of the convert instruction
+      // to the matmul.
+      if (compatible_bias_type(bias_f16_or_bf16->shape().element_type(),
+                               gemm->shape().element_type())) {
+        bias = bias_f16_or_bf16;
+      } else {
+        VLOG(1) << "Epilogue fusion of FP32 vector bias into FP8 GEMM is "
+                   "currently not supported. See the cublasLT support matrix.";
+        return false;
+      }
+    }
+
+    // In the case of high rank input for FP8, it is necessary to consider
+    // potential padding for the bias.
+    if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget && bitcast) {
+      bias = PadOperandToMultipleOf16(
+          config.dot_dimension_numbers().rhs_batch_dimensions(), bias);
+    }
+    // Replace add(gemm, broadcast) with fused new_gemm.
+    operands.push_back(bias);
+    config.set_epilogue(GemmBackendConfig::BIAS);
+    std::unique_ptr<HloInstruction> result =
+        gemm->CloneWithNewOperands(gemm->shape(), operands);
+    TF_RETURN_IF_ERROR(result->set_backend_config(gpu_config));
+    TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get()));
+    if (slice) {
+      result = slice->CloneWithNewOperands(
+          slice->shape(), {slice->parent()->AddInstruction(std::move(result))});
+    }
+
+    if (bitcast) {
+      result = bitcast->CloneWithNewOperands(
+          bitcast->shape(),
+          {bitcast->parent()->AddInstruction(std::move(result))});
+    }
+    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(result)));
+    return true;
+  }
+
+  absl::Status FuseReluActivation(HloInstruction *instr,
+                                  HloInstruction *broadcast,
+                                  HloInstruction *gemm,
+                                  HloInstruction *slice_or_bitcast = nullptr) {
+    TF_RET_CHECK(ShapeUtil::Compatible(
+        broadcast->shape(),
+        (slice_or_bitcast ? slice_or_bitcast->shape() : gemm->shape())));
+
+    if (!SupportsEpilogueFusion(gemm->shape().element_type())) {
+      return absl::OkStatus();
+    }
+
+    if (gemm->user_count() != 1) {
+      return absl::OkStatus();
+    }
+
+    TF_ASSIGN_OR_RETURN(auto gpu_config,
+                        gemm->backend_config<GpuBackendConfig>());
+    GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+    if (config.epilogue() == GemmBackendConfig::DEFAULT) {
+      config.set_epilogue(GemmBackendConfig::RELU);
+    } else if (config.epilogue() == GemmBackendConfig::BIAS) {
+      config.set_epilogue(GemmBackendConfig::BIAS_RELU);
+    } else {
+      return absl::OkStatus();
+    }
+
+    std::unique_ptr<HloInstruction> result = gemm->Clone();
+    TF_RETURN_IF_ERROR(result->set_backend_config(gpu_config));
+    TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get()));
+
+    if (slice_or_bitcast) {
+      result = slice_or_bitcast->CloneWithNewOperands(
+          slice_or_bitcast->shape(),
+          {slice_or_bitcast->parent()->AddInstruction(std::move(result))});
+    }
+
+    return ReplaceWithNewInstruction(instr, std::move(result));
+  }
+
+  absl::Status FuseGeluActivation(HloInstruction *multiply,
+                                  HloInstruction *gemm,
+                                  HloInstruction *slice_or_bitcast = nullptr) {
+    if (!SupportsEpilogueFusion(gemm->shape().element_type())) {
+      return absl::OkStatus();
+    }
+    // For CUDA versions less than 12.3.2, cuBLAS LT returns
+    // CUBLAS_STATUS_NOT_SUPPORTED in some cases when fusing gelu into an FP8
+    // matmul. We cannot check the patch version, so disable this fusion with
+    // CUDA versions less than 12.4.
+    if (IsCuda(gpu_version_) && toolkit_version_ < 12040 &&
+        IsCublasLtMatmulF8(*gemm)) {
+      return absl::OkStatus();
+    }
+
+    // There are four users of the gemm output within the GELU calculation.
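+    // Illustratively (assuming the tanh-based approximation
+    // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))), the GEMM
+    // output x is consumed by the cube, the inner add and the outer
+    // multiplies; a fifth user means the raw GEMM output is needed elsewhere,
+    // so a *_AUX epilogue is used to also return it.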
+    bool has_aux = gemm->user_count() > 4;
+
+    TF_ASSIGN_OR_RETURN(auto gpu_config,
+                        gemm->backend_config<GpuBackendConfig>());
+    GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
+
+    if (config.epilogue() == GemmBackendConfig::DEFAULT) {
+      config.set_epilogue(has_aux ? GemmBackendConfig::GELU_AUX
+                                  : GemmBackendConfig::GELU);
+    } else if (config.epilogue() == GemmBackendConfig::BIAS) {
+      config.set_epilogue(has_aux ? GemmBackendConfig::BIAS_GELU_AUX
+                                  : GemmBackendConfig::BIAS_GELU);
+    } else {
+      return absl::OkStatus();
+    }
+
+    std::unique_ptr<HloInstruction> output = gemm->CloneWithNewShape(
+        has_aux ? ShapeUtil::MakeTupleShape({gemm->shape(), gemm->shape()})
+                : gemm->shape());
+    TF_RETURN_IF_ERROR(output->set_backend_config(gpu_config));
+    TF_RETURN_IF_ERROR(SetName(multiply->GetModule(), output.get()));
+
+    if (slice_or_bitcast) {
+      output = slice_or_bitcast->CloneWithNewOperands(
+          slice_or_bitcast->shape(),
+          {gemm->parent()->AddInstruction(std::move(output))});
+    }
+
+    if (has_aux) {
+      HloInstruction *tuple_output =
+          gemm->parent()->AddInstruction(std::move(output));
+      TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
+          gemm, HloInstruction::CreateGetTupleElement(tuple_output, 1)));
+      output = HloInstruction::CreateGetTupleElement(tuple_output, 0);
+    }
+
+    return ReplaceWithNewInstruction(multiply, std::move(output));
+  }
+
+ private:
+  se::GpuComputeCapability gpu_version_;
+  int32_t toolkit_version_;
+  GemmRewriterOptions options_;
+
+  // Choose cublas or cublasLt for the target of the custom call that instr will
+  // be rewritten into.
+  absl::StatusOr<absl::string_view> GetNonFp8GemmCustomCallTarget(
+      const HloInstruction &instr,
+      const GemmBackendConfig &gemm_backend_config) const {
+    if (!instr.GetModule()
+             ->config()
+             .debug_options()
+             .xla_gpu_enable_cublaslt()) {
+      // cublasLt is not enabled.
+      return absl::string_view(kGemmCallTarget);
+    }
+
+    // cublasLt is enabled, check if other internal conditions are met.
+    const HloInstruction *lhs = instr.operand(0);
+    const HloInstruction *rhs = instr.operand(1);
+    if (lhs->shape().element_type() == S8 ||
+        rhs->shape().element_type() == S8) {
+      // TODO(b/241446501) The XLA usage of cublasLt does not yet handle
+      // int8 matmuls. Fallback to legacy cublas.
+      return absl::string_view(kGemmCallTarget);
+    }
+
+    // All internal conditions are met, check if we meet the requirements of
+    // cublasLt.
+    TF_ASSIGN_OR_RETURN(bool gemm_is_supported_by_cublas_lt,
+                        GemmIsSupportedByCublasLt(instr, gemm_backend_config));
+    if (gemm_is_supported_by_cublas_lt) {
+      return absl::string_view(kCublasLtMatmulCallTarget);
+    }
+
+    // This case is not supported by cublasLt, fallback to legacy cublas.
+    return absl::string_view(kGemmCallTarget);
+  }
+
+  absl::StatusOr<bool> TypesAreSupportedByLegacyCublas(
+      const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config,
+      const HloInstruction *bias = nullptr) const {
+    // Figure out the Atype/Btype.
+    const PrimitiveType a_dtype = instr.operand(0)->shape().element_type();
+    const PrimitiveType b_dtype = instr.operand(1)->shape().element_type();
+    const PrimitiveType output_type =
+        bias ? bias->shape().element_type() : instr.shape().element_type();
+    const std::array<PrimitiveType, 8> supported_type = {
+        PrimitiveType::S8,  PrimitiveType::F16, PrimitiveType::BF16,
+        PrimitiveType::F32, PrimitiveType::S32, PrimitiveType::F64,
+        PrimitiveType::C64, PrimitiveType::C128};
+    // Legacy cublas supports only a fixed set of type combinations, listed
+    // below. Figure out the computeType and scaleType.
+    if (!absl::c_linear_search(supported_type, output_type)) return false;
+    TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype,
+                        se::gpu::AsBlasDataType(output_type));
+    // TODO(tdanyluk): Investigate why we don't use the actual precision (and
+    // algorithm) here. Why do we use the default?
+    TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type,
+                        se::gpu::GetBlasComputationType(
+                            PrecisionConfig::ALG_UNSET, a_dtype, output_type,
+                            stream_executor::blas::kDefaultComputePrecision));
+    se::blas::DataType scale_type =
+        se::gpu::GetScaleType(output_dtype, compute_type);
+
+    using se::blas::ComputationType;
+    using se::blas::DataType;
+    // This matrix of supported types is taken directly from cublas
+    // documentation.
+    // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex
+    const std::array<
+        std::tuple<ComputationType, DataType /*scale_type*/,
+                   PrimitiveType /*a_dtype*/, PrimitiveType /*b_dtype*/,
+                   DataType /*output_dtype*/>,
+        32>
+        supported_type_combinations = {{
+            {ComputationType::kF16, DataType::kHalf, PrimitiveType::F16,
+             PrimitiveType::F16, DataType::kHalf},
+
+            {ComputationType::kI32, DataType::kInt32, PrimitiveType::S8,
+             PrimitiveType::S8, DataType::kInt32},
+
+            {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
+             PrimitiveType::BF16, DataType::kBF16},
+            {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
+             PrimitiveType::F16, DataType::kHalf},
+            {ComputationType::kF32, DataType::kFloat, PrimitiveType::S8,
+             PrimitiveType::S8, DataType::kFloat},
+            {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
+             PrimitiveType::BF16, DataType::kFloat},
+            {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
+             PrimitiveType::F16, DataType::kFloat},
+            {ComputationType::kF32, DataType::kFloat, PrimitiveType::F32,
+             PrimitiveType::F32, DataType::kFloat},
+
+            // There would be an entry here for A/BType complex int8, but we do
+            // not support that type.
+            {ComputationType::kF32, DataType::kComplexFloat, PrimitiveType::C64,
+             PrimitiveType::C64, DataType::kComplexFloat},
+
+            {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F32,
+             PrimitiveType::F32, DataType::kFloat},
+            {ComputationType::kF16AsF32, DataType::kComplexFloat,
+             PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
+
+            {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::F32,
+             PrimitiveType::F32, DataType::kFloat},
+            {ComputationType::kBF16AsF32, DataType::kComplexFloat,
+             PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
+
+            {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32,
+             PrimitiveType::F32, DataType::kFloat},
+            {ComputationType::kTF32AsF32, DataType::kComplexFloat,
+             PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
+
+            {ComputationType::kF64, DataType::kDouble, PrimitiveType::F64,
+             PrimitiveType::F64, DataType::kDouble},
+            {ComputationType::kF64, DataType::kComplexDouble,
+             PrimitiveType::C128, PrimitiveType::C128,
+             DataType::kComplexDouble},
+        }};
+
+    return absl::c_linear_search(
+        supported_type_combinations,
+        std::make_tuple(compute_type, scale_type, a_dtype, b_dtype,
+                        output_dtype));
+  }
+
+  absl::StatusOr<bool> TypesAreSupportedByCublasLt(
+      const HloInstruction &instr, const GemmBackendConfig &backend_config,
+      const HloInstruction *bias = nullptr) const {
+    // Figure out the Atype/Btype.
+    const PrimitiveType a_dtype = instr.operand(0)->shape().element_type();
+    const PrimitiveType b_dtype = instr.operand(1)->shape().element_type();
+    const PrimitiveType output_type =
+        bias ? bias->shape().element_type() : instr.shape().element_type();
+    const std::array<PrimitiveType, 12> supported_type = {
+        PrimitiveType::F8E5M2FNUZ, PrimitiveType::F8E4M3FNUZ,
+        PrimitiveType::F8E5M2,     PrimitiveType::F8E4M3FN,
+        PrimitiveType::S8,         PrimitiveType::F16,
+        PrimitiveType::BF16,       PrimitiveType::F32,
+        PrimitiveType::S32,        PrimitiveType::F64,
+        PrimitiveType::C64,        PrimitiveType::C128};
+    if (!absl::c_linear_search(supported_type, output_type)) return false;
+    // cublasLt has a defined set of combinations of types that it supports.
+    // Figure out the computeType and scaleType.
+    TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype,
+                        se::gpu::AsBlasDataType(output_type));
+    const int max_precision = *absl::c_max_element(
+        backend_config.precision_config().operand_precision());
+    const PrecisionConfig::Algorithm algorithm =
+        backend_config.precision_config().algorithm();
+    if (!algorithm_util::IsSupportedByCublasOrCublasLt(algorithm)) return false;
+
+    TF_ASSIGN_OR_RETURN(
+        const se::blas::ComputationType compute_type,
+        se::gpu::GetBlasComputationType(
+            algorithm, a_dtype, instr.shape().element_type(), max_precision));
+    se::blas::DataType scale_type =
+        se::gpu::GetScaleType(output_dtype, compute_type);
+
+    using se::blas::ComputationType;
+    using se::blas::DataType;
+    using TypeCombinations = std::initializer_list<std::tuple<
+        ComputationType, DataType /*scale_type*/, PrimitiveType /*a_dtype*/,
+        PrimitiveType /*b_dtype*/, DataType /*output_dtype*/>>;
+    // This matrix of supported types is taken directly from cublasLt
+    // documentation.
+    // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul
+    const TypeCombinations supported_cublas_type_combinations = {
+        // FP8 types:
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E4M3FN, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E4M3FN, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E4M3FN, DataType::kFloat},
+
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E5M2, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E5M2, DataType::kF8E4M3FN},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E5M2, DataType::kF8E5M2},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E5M2, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E5M2, DataType::kFloat},
+
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+         PrimitiveType::F8E4M3FN, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+         PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+         PrimitiveType::F8E4M3FN, DataType::kF8E5M2},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+         PrimitiveType::F8E4M3FN, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+         PrimitiveType::F8E4M3FN, DataType::kFloat},
+        // There would be an entry here for A/BType complex int8, but we do
+        // not support that type.
+        {ComputationType::kF32, DataType::kComplexFloat, PrimitiveType::C64,
+         PrimitiveType::C64, DataType::kComplexFloat},
+
+        {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F32,
+         PrimitiveType::F32, DataType::kFloat},
+        {ComputationType::kF16AsF32, DataType::kComplexFloat,
+         PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
+        // The next 4 may be supported by hipblaslt, but they are not
+        // covered by any unit tests
+        {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::F32,
+         PrimitiveType::F32, DataType::kFloat},
+        {ComputationType::kBF16AsF32, DataType::kComplexFloat,
+         PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
+
+        {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32,
+         PrimitiveType::F32, DataType::kFloat},
+        {ComputationType::kTF32AsF32, DataType::kComplexFloat,
+         PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat},
+
+        {ComputationType::kF64, DataType::kDouble, PrimitiveType::F64,
+         PrimitiveType::F64, DataType::kDouble},
+        {ComputationType::kF64, DataType::kComplexDouble, PrimitiveType::C128,
+         PrimitiveType::C128, DataType::kComplexDouble},
+    };
+    if (IsCuda(gpu_version_) &&
+        absl::c_linear_search(supported_cublas_type_combinations,
+                              std::tuple{compute_type, scale_type, a_dtype,
+                                         b_dtype, output_dtype})) {
+      return true;
+    }
+    const TypeCombinations supported_hipblas_type_combinations = {
+        // FP8 types:
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
+
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kF8E4M3FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kF8E5M2FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kFloat},
+
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kF8E5M2FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
+    };
+    if (IsRocm(gpu_version_) &&
+        absl::c_linear_search(supported_hipblas_type_combinations,
+                              std::tuple{compute_type, scale_type, a_dtype,
+                                         b_dtype, output_dtype})) {
+      return true;
+    }
+    const TypeCombinations supported_type_combinations = {
+        // Other data types:
+        {ComputationType::kF16, DataType::kHalf, PrimitiveType::F16,
+         PrimitiveType::F16, DataType::kHalf},
+
+        {ComputationType::kI32, DataType::kInt32, PrimitiveType::S8,
+         PrimitiveType::S8, DataType::kInt32},
+        {ComputationType::kI32, DataType::kFloat, PrimitiveType::S8,
+         PrimitiveType::S8, DataType::kInt8},
+
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
+         PrimitiveType::BF16, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
+         PrimitiveType::F16, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::S8,
+         PrimitiveType::S8, DataType::kFloat},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16,
+         PrimitiveType::BF16, DataType::kFloat},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16,
+         PrimitiveType::F16, DataType::kFloat},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F32,
+         PrimitiveType::F32, DataType::kFloat},
+    };
+
+    return absl::c_linear_search(
+        supported_type_combinations,
+        std::make_tuple(compute_type, scale_type, a_dtype, b_dtype,
+                        output_dtype));
+  }
+
+  absl::StatusOr<bool> GemmIsSupportedByCublasLt(
+      const HloInstruction &instr,
+      const GemmBackendConfig &gemm_backend_config) const {
+    const HloInstruction *lhs = instr.operand(0);
+    const Shape &output_shape = instr.shape();
+
+    TF_ASSIGN_OR_RETURN(
+        bool types_are_supported_by_cublas_lt,
+        TypesAreSupportedByCublasLt(instr, gemm_backend_config));
+    if (!types_are_supported_by_cublas_lt) {
+      return false;
+    }
+
+    // The cublasLt API has two currently known limitations:
+    // 1. Batch count must be <2^16.
+    constexpr int64_t kMaxBatchCount = 65535;
+    // We get the batch dimension size from lhs here, but we could just as well
+    // use rhs; they are guaranteed to be the same (TODO: verify).
+    const auto &batch_dimensions =
+        gemm_backend_config.dot_dimension_numbers().lhs_batch_dimensions();
+    int batch_count = (batch_dimensions.empty() ? 0 : 1);
+    // All batch dimensions get flattened into a single batch dimension.
+    for (auto batch_dimension : batch_dimensions) {
+      batch_count *= lhs->shape().dimensions(batch_dimension);
+    }
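+    // For example, lhs batch dimensions of sizes {20000, 4} flatten to a
+    // single batch of 80000, which exceeds kMaxBatchCount, so this function
+    // returns false and the rewrite uses legacy cuBLAS instead (illustrative
+    // numbers; the LargeBatch test exercises this case).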
+    if (batch_count > kMaxBatchCount) {
+      // This is not supported by cublasLt.
+      return false;
+    }
+
+    if (auto isrocm = std::get_if<se::RocmComputeCapability>(&gpu_version_);
+        isrocm) {
+      if (!isrocm->has_hipblaslt()) {
+        return false;
+      }
+    }
+
+    // 2. cublasLt does not support rhs col dimension size > 4194240 for
+    // C64.
+    constexpr int kMaxDimensionSize{4194240};
+    if (output_shape.element_type() != C64) {
+      // The output type is not C64, so the restriction below does not apply.
+      return true;
+    }
+
+    if (std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
+      if (std::get<se::CudaComputeCapability>(gpu_version_).IsAtLeastAmpere()) {
+        // cuBlasLt has an implementation for complex data with compute type
+        // 32F_FAST_32TF that uses tensor cores and that is free from the
+        // restriction. This implementation only works on Ampere
+        // architecture though (where TF32 was introduced).
+        return true;
+      }
+    }
+
+    TF_ASSIGN_OR_RETURN(GemmConfig gemm_config,
+                        GemmConfig::For(&instr, gemm_backend_config));
+
+    // Check that the size of the non-contracting dimension is not too large.
+    return gemm_config.rhs_layout.num_cols <= kMaxDimensionSize;
+  }
+
+  // Turns an F8 dot with an unsupported output type into an F8 dot with F32
+  // output, converting the F32 output back to the originally requested type.
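+  // Schematically (with OutT standing for the unsupported output type):
+  //   dot     = OutT[M,N] dot(f8 a, f8 b)
+  // becomes
+  //   f32_dot = f32[M,N] dot(f8 a, f8 b)
+  //   convert = OutT[M,N] convert(f32_dot)
+  // where the convert replaces all uses of the original dot.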
+  absl::StatusOr<HloInstruction *> TurnF8DotWithUnsupportedOutputTypeIntoF32(
+      HloInstruction *instr) {
+    Shape output_f32_shape = instr->shape();
+    output_f32_shape.set_element_type(F32);
+    HloInstruction *f32_dot =
+        instr->AddInstruction(instr->CloneWithNewShape(output_f32_shape));
+    HloInstruction *convert = instr->AddInstruction(
+        HloInstruction::CreateConvert(instr->shape(), f32_dot));
+    TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert));
+    return f32_dot;
+  }
+
+  // Turns an F8 dot into an F16 dot by converting the operands to F16 and,
+  // if the output is itself F8, converting the F16 output back to F8.
+  absl::StatusOr<HloInstruction *> TurnF8DotIntoF16Dot(HloInstruction *instr) {
+    DCHECK(IsF8Type(instr->operand(0)));
+    DCHECK(IsF8Type(instr->operand(1)));
+
+    // Convert operands to F16
+    for (int i = 0; i < 2; ++i) {
+      Shape operand_f16_shape = instr->operand(i)->shape();
+      operand_f16_shape.set_element_type(F16);
+      HloInstruction *convert =
+          instr->AddInstruction(HloInstruction::CreateConvert(
+              operand_f16_shape, instr->mutable_operand(i)));
+      TF_RETURN_IF_ERROR(instr->ReplaceOperandWith(i, convert));
+    }
+
+    // If output is F8, change output to F16 and then convert it back to F8
+    if (IsF8Type(instr)) {
+      Shape output_f16_shape = instr->shape();
+      output_f16_shape.set_element_type(F16);
+      HloInstruction *f16_dot =
+          instr->AddInstruction(instr->CloneWithNewShape(output_f16_shape));
+      HloInstruction *convert_to_f8 = instr->AddInstruction(
+          HloInstruction::CreateConvert(instr->shape(), f16_dot));
+      TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert_to_f8));
+      return f16_dot;
+    } else {
+      return instr;
+    }
+  }
+};
+
+// Rewriter that adds a workspace buffer to cuBLAS custom calls (both legacy
+// cuBLAS GEMMs and cublasLt matmuls). We run it separately after the gemm
+// rewriter so that we can do pattern matching without having to match output
+// tuples.
+class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor {
+ public:
+  explicit GemmWorkspaceRewriteVisitor(
+      const se::GpuComputeCapability &gpu_version)
+      : gpu_version_(gpu_version) {}
+
+  absl::Status HandleCustomCall(HloInstruction *instr) override {
+    bool has_aux_output = false;
+    if (instr->custom_call_target() == kCublasLtMatmulCallTarget ||
+        instr->custom_call_target() == kCublasLtMatmulF8CallTarget) {
+      TF_ASSIGN_OR_RETURN(const auto gpu_config,
+                          instr->backend_config<xla::gpu::GpuBackendConfig>());
+      const xla::gpu::GemmBackendConfig &config =
+          gpu_config.gemm_backend_config();
+      xla::gpu::GemmBackendConfig_Epilogue epilogue = config.epilogue();
+      TF_ASSIGN_OR_RETURN(
+          has_aux_output,
+          xla::gpu::gpublas_lt::EpilogueHasAuxiliaryOutput(epilogue));
+
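+      // Only proceed if the call still has the expected shape: a bare array
+      // result, or a tuple of (result [, aux] [, damax]) outputs with no
+      // workspace appended yet.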
+      if (!((instr->shape().IsTuple() &&
+             instr->shape().tuple_shapes_size() ==
+                 has_aux_output + config.damax_output() + 1) ||
+            instr->shape().IsArray())) {
+        return absl::OkStatus();
+      }
+    } else if (instr->custom_call_target() != kGemmCallTarget ||
+               !instr->shape().IsArray()) {
+      return absl::OkStatus();
+    }
+
+    auto *cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version_);
+
+    // Pass a user-managed workspace to legacy cuBLAS operations, as otherwise
+    // cuBLAS will use its own internal pool, which would compete with the XLA
+    // allocator for device memory.
+    int64_t workspace = cuda_cc == nullptr ? GemmConfig::kDefaultWorkspace
+                        : cuda_cc->IsAtLeastHopper()
+                            ? GemmConfig::kHopperWorkspace
+                            : GemmConfig::kDefaultWorkspace;
+
+    // We do not know the workspace size required by cuBLAS, but we can guess
+    // that in the worst case cuBLAS will transpose all operands into a tiled
+    // layout optimal for the tensor cores, so it doesn't make sense to
+    // allocate a workspace larger than the operands themselves.
+    //
+    // TODO(ezhulenev): This is not based on any measurement, just common
+    // sense; we should tweak it to find the minimal workspace size.
+    if (instr->custom_call_target() == kGemmCallTarget) {
+      int64_t operands_byte_size = 0;
+      for (auto &operand : instr->operands()) {
+        operands_byte_size += ShapeUtil::ByteSizeOf(operand->shape());
+      }
+      workspace = std::min(workspace, operands_byte_size);
+    }
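+    // For example, a legacy GEMM with two f32[64,64] operands (32 KiB in
+    // total) gets its workspace capped at 32 KiB, assuming the architecture
+    // default is larger.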
+
+    // Append workspace buffer to instruction outputs.
+    std::vector<Shape> output_shapes = instr->shape().IsArray()
+                                           ? std::vector<Shape>{instr->shape()}
+                                           : instr->shape().tuple_shapes();
+    output_shapes.emplace_back(ShapeUtil::MakeShape(S8, {workspace}));
+    Shape output_shape = ShapeUtil::MakeTupleShape(output_shapes);
+
+    // Clone custom call with a new shape.
+    HloInstruction *new_call = instr->AddInstruction(
+        instr->CloneWithNewOperands(output_shape, instr->operands()));
+
+    // Update operand aliasing if it was a fused gemm with aliased output.
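+    // The GEMM result is now tuple element {0} of the new call, so re-point
+    // the aliasing at it; operand 2 is the aliased bias/C operand.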
+    auto *custom_call = xla::Cast<HloCustomCallInstruction>(new_call);
+    if (!custom_call->output_to_operand_aliasing().empty()) {
+      custom_call->set_output_to_operand_aliasing({{{0}, {2, {}}}});
+    }
+
+    if (instr->shape().IsTuple()) {
+      for (auto user : instr->users()) {
+        auto user_get_tuple =
+            dynamic_cast<HloGetTupleElementInstruction *>(user);
+        TF_RET_CHECK(user_get_tuple);
+        HloInstruction *get_output =
+            instr->AddInstruction(HloInstruction::CreateGetTupleElement(
+                new_call, user_get_tuple->tuple_index()));
+        TF_RETURN_IF_ERROR(ReplaceInstruction(user_get_tuple, get_output));
+      }
+      return absl::OkStatus();
+    } else {
+      HloInstruction *get_output = instr->AddInstruction(
+          HloInstruction::CreateGetTupleElement(new_call, 0));
+      return ReplaceInstruction(instr, get_output);
+    }
+  }
+
+ private:
+  se::GpuComputeCapability gpu_version_;
+};
+
+absl::StatusOr<bool> RunOnComputation(HloComputation *computation,
+                                      se::GpuComputeCapability gpu_version,
+                                      int32_t toolkit_version,
+                                      GemmRewriterOptions options) {
+  GemmRewriterVisitor visitor(gpu_version, toolkit_version, options);
+  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
+  GemmWorkspaceRewriteVisitor workspace_visitor(gpu_version);
+  TF_RETURN_IF_ERROR(computation->Accept(&workspace_visitor));
+  return visitor.changed();
+}
+
+}  // anonymous namespace
+
+GemmRewriter::GemmRewriter(se::GpuComputeCapability gpu_version,
+                           int32_t toolkit_version, GemmRewriterOptions options)
+    : gpu_version_(gpu_version),
+      toolkit_version_(toolkit_version),
+      options_(options) {}
+
+absl::StatusOr<bool> GemmRewriter::Run(
+    HloModule *module,
+    const absl::flat_hash_set<absl::string_view> &execution_threads) {
+  bool changed = false;
+  for (HloComputation *computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool result,
+                        RunOnComputation(computation, gpu_version_,
+                                         toolkit_version_, options_));
+    changed |= result;
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h
new file mode 100644
index 0000000..cce09c4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h
@@ -0,0 +1,98 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_REWRITER_H_
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// cuBLAS GEMM in the most general form can run the following operation:
+//
+// (kAdd
+//    (kMultiply (kDot A B) alpha)
+//    (kMultiply C beta))
+//
+// where A, B, C are matrices or vectors and `alpha` and `beta` are host
+// constants. In matrix-vector multiplication, one operand must be a matrix and
+// the other must be a vector. The additional requirement is that C has no other
+// users (otherwise, it does not make sense to fuse it inside the custom call).
+//
+// Both multiplication and addition can be avoided (equivalent to setting
+// `alpha` to one and `beta` to zero).
+//
+// This pass pattern-matches the most general form of this instruction
+// (we assume transposes are already folded) and rewrites it into a custom call
+// whose three operands are (A, B, C), with `alpha` and `beta` stored in the
+// backend config.
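+//
+// For example (schematic HLO; shapes and names are illustrative only):
+//
+//   dot    = f32[2,4] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+//   scaled = f32[2,4] multiply(dot, broadcast(alpha))
+//   biased = f32[2,4] add(scaled, multiply(c, broadcast(beta)))
+//
+// is rewritten into a single custom call along the lines of
+//
+//   biased = f32[2,4] custom-call(a, b, c), custom_call_target="__cublas$gemm",
+//            backend_config={"alpha_real":..., "beta":..., ...}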
+
+struct GemmRewriterOptions {
+  // The DType of the GEMM to rewrite.
+  enum class DType { kFp8Only, kNonFp8Only };
+  DType dtype = DType::kNonFp8Only;
+
+  // Disabling bias prevents using the `beta * C` term in the GEMM, which can
+  // remove dependencies between multiple matrix multiplications. This, in
+  // turn, can improve the performance of the overall computation by allowing
+  // multiple GEMMs to be scheduled in parallel.
+  //
+  // As an example, consider the following computation: `(A * A) + (B * B)`.
+  // With bias enabled, the `GemmRewriter` will emit the following GEMMs:
+  //
+  // AA := GEMM(A * A)
+  // ROOT := GEMM(B * B + AA)
+  //
+  // Because the second GEMM depends on the first, they cannot be scheduled in
+  // parallel. Instead, with bias disabled, the `GemmRewriter` will emit the
+  // following:
+  //
+  // AA := GEMM(A * A)
+  // BB := GEMM(B * B)
+  // ROOT := AA + BB
+  //
+  // In this case, the two GEMMs can be scheduled in parallel.
+  enum class BiasMode { kBias, kNoBias };
+  BiasMode bias_mode = BiasMode::kBias;
+};
+
+class GemmRewriter : public HloModulePass {
+ public:
+  GemmRewriter(se::GpuComputeCapability gpu_version, int32_t toolkit_version,
+               GemmRewriterOptions options = {});
+  absl::string_view name() const override { return "cublas-gemm-rewriter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  se::GpuComputeCapability gpu_version_;
+  int32_t toolkit_version_;
+  GemmRewriterOptions options_;
+};
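+
+// A minimal usage sketch (the pipeline name and variables are illustrative,
+// not part of this header):
+//
+//   HloPassPipeline pipeline("gemm-rewriter-example");
+//   pipeline.AddPass<GemmRewriter>(gpu_version, toolkit_version,
+//                                  GemmRewriterOptions{
+//                                      GemmRewriterOptions::DType::kFp8Only});
+//   pipeline.AddPass<GemmRewriter>(gpu_version, toolkit_version,
+//                                  GemmRewriterOptions{
+//                                      GemmRewriterOptions::DType::kNonFp8Only});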
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_GEMM_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc
new file mode 100644
index 0000000..e4e2aad
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc
@@ -0,0 +1,8423 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
+
+#include <array>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/gpu_executable.h"
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/stream_executor_memory_allocator.h"
+#include "xla/test.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/xla.pb.h"
+#include "tsl/platform/statusor.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
+#elif TENSORFLOW_USE_ROCM
+#include "rocm/rocm_config.h"
+#endif
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+namespace m = ::xla::match;
+
+class GemmRewriteTest : public GpuCodegenTest {
+  const auto& device_desc() {
+    return backend().default_stream_executor()->GetDeviceDescription();
+  }
+
+ protected:
+  const se::GpuComputeCapability& Capability() {
+    return device_desc().gpu_compute_capability();
+  }
+
+  int32_t GetToolkitVersion() const {
+#if GOOGLE_CUDA
+    return CUDA_VERSION;
+#elif TENSORFLOW_USE_ROCM
+    return TF_ROCM_VERSION;
+#endif
+    return 0;
+  }
+
+  bool IsCuda() {
+    return std::holds_alternative<se::CudaComputeCapability>(Capability());
+  }
+
+  se::GpuComputeCapability CudaHopperOrRocmMI300() {
+    if (IsCuda()) {
+      return se::CudaComputeCapability::Hopper();
+    } else {
+      return se::RocmComputeCapability{"gfx942"};
+    }
+  }
+
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
+    // These tests exercise the cuBLAS rewriter, so make sure that cuBLAS is
+    // actually used for them.
+    debug_options.set_xla_gpu_enable_triton_gemm(false);
+    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
+    return debug_options;
+  }
+
+  bool SkipGpuBlasLtTest() {
+    return !IsCuda() &&
+           !std::get<se::RocmComputeCapability>(Capability()).has_hipblaslt() &&
+           GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
+  }
+
+  bool HasFp8Support() {
+    if (IsCuda()) {
+      return std::get<se::CudaComputeCapability>(Capability()).IsAtLeast(8, 9);
+    }
+    return std::get<se::RocmComputeCapability>(Capability()).has_fp8_support();
+  }
+
+  bool HasCudaComputeCapability(const se::CudaComputeCapability& cc) {
+    return IsCuda() &&
+           std::get<se::CudaComputeCapability>(Capability()).IsAtLeast(cc);
+  }
+};
+
+TEST_F(GemmRewriteTest, CheckCustomCallTarget) {
+  if (SkipGpuBlasLtTest()) {
+    GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
+  }
+
+  const char* hlo_text = R"(
+HloModule SimpleGemm
+
+ENTRY AddDotsFunc {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+)";
+  DebugOptions debug_options = GetDebugOptionsForTest();
+  if (debug_options.xla_gpu_enable_cublaslt()) {
+    MatchOptimizedHlo(hlo_text,
+                      R"(; CHECK: custom_call_target="__cublas$lt$matmul")");
+  } else {
+    MatchOptimizedHlo(hlo_text,
+                      R"(; CHECK: custom_call_target="__cublas$gemm")");
+  }
+}
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+TEST_F(GemmRewriteTest, TestBatchedAutotuning) {
+  if (HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    GTEST_SKIP()
+        << "There is no autotuning starting with the Nvidia Ampere generation";
+  }
+
+  const char* hlo_text = R"(
+HloModule ComplexDotMultipleNonContracting
+
+ENTRY %test {
+  %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
+  %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
+  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
+}
+
+)";
+
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK: selected_algorithm
+      )");
+}
+#endif
+
+TEST_F(GemmRewriteTest, SimpleRewriteDeterministic) {
+  if (SkipGpuBlasLtTest()) {
+    GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
+  }
+
+  const char* hlo_text = R"(
+HloModule SimpleGemm
+
+ENTRY AddDotsFunc {
+  x = f32[128,128] parameter(0)
+  y = f32[128,128] parameter(1)
+  ROOT dot_a = f32[128,128] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+  // The tolerance is the same for the cublasLt and legacy cuBLAS paths.
+  ErrorSpec error_spec{1e-3, 1e-3};
+
+  auto get_module = [&]() {
+    HloModuleConfig config;
+    DebugOptions debug_options = GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_exclude_nondeterministic_ops(true);
+    config.set_debug_options(debug_options);
+    return ParseAndReturnVerifiedModule(hlo_text, config);
+  };
+
+  se::StreamExecutorMemoryAllocator allocator(
+      backend().default_stream_executor());
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> optimized_module,
+      backend().compiler()->RunHloPasses(
+          *get_module(), backend().default_stream_executor(), &allocator));
+
+  absl::StatusOr<bool> filecheck_result =
+      RunFileCheck(optimized_module->ToString(),
+                   R"(
+; CHECK:    custom_call_target="__cublas${{(lt\$matmul|gemm)}}"
+    )");
+  TF_ASSERT_OK(filecheck_result.status());
+  EXPECT_TRUE(filecheck_result.value());
+  EXPECT_TRUE(RunAndCompare(*get_module(), error_spec));
+}
+
+TEST_F(GemmRewriteTest, BF16GemmCodeGen) {
+  const char* hlo_text = R"(
+HloModule bf16codegendgemm
+
+ENTRY bf16gemm {
+  %parameter.1 = bf16[3]{0} parameter(0)
+  %parameter.2 = bf16[3]{0} parameter(1)
+  ROOT %dot.3 = bf16[] dot(bf16[3]{0} %parameter.1, bf16[3]{0} %parameter.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+}
+  )";
+
+  if (HasCudaComputeCapability(se::CudaComputeCapability::Hopper())) {
+    // The Hopper optimized HLO has a BF16 multiply instruction since Hopper has
+    // native BF16 multiply support.
+    MatchOptimizedHlo(hlo_text, R"(
+  ; CHECK:  [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
+  ; CHECK:  [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
+  ; CHECK:  [[INSTR_2:%[^ ]+]] = bf16[3]{0} multiply([[P0]], [[P1]])
+  ; CHECK:  [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[INSTR_2]])
+  ; CHECK:  [[INSTR_4:%[^ ]+]] = f32[] constant(0)
+  ; CHECK:  [[INSTR_5:%[^ ]+]] = f32[] reduce([[INSTR_3]], [[INSTR_4]]), dimensions={0}, to_apply=[[INSTR_6:%[^ ]+]]
+  ; CHECK:  ROOT [[INSTR_7:%[^ ]+]] = bf16[] convert([[INSTR_5]])
+    )");
+  } else {
+    MatchOptimizedHlo(hlo_text, R"(
+  ; CHECK:  [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
+  ; CHECK:  [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]])
+  ; CHECK:  [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
+  ; CHECK:  [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]])
+  ; CHECK:  [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]])
+  ; CHECK:  [[INSTR_5:%[^ ]+]] = f32[] constant(0)
+  ; CHECK:  [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]]
+  ; CHECK:  ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]])
+    )");
+  }
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-4}));
+}
+
+TEST_F(GemmRewriteTest, BF16Transpose) {
+  const char* hlo_text = R"(
+HloModule broadcast
+
+ENTRY broadcast {
+  p = bf16[9] parameter(0)
+  ROOT out = bf16[1,9] broadcast(p), dimensions={1}
+}
+)";
+
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK: bf16[1,9]{1,0} bitcast
+; CHECK: bf16[1,9]{1,0} copy
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// A test fixture class for tests that should have similar results with legacy
+// cuBLAS and cublasLt.
+class ParameterizedGemmRewriteTest
+    : public GemmRewriteTest,
+      public ::testing::WithParamInterface<bool> {
+ public:
+  ParameterizedGemmRewriteTest() {
+    const bool kUsingCublasLt = GetParam();
+    replacements_[kCustomCallTargetPlaceholder] =
+        kUsingCublasLt ? "__cublas$lt$matmul" : "__cublas$gemm";
+  }
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_enable_cublaslt(GetParam());
+    debug_options.set_xla_gpu_enable_triton_gemm(false);
+    return debug_options;
+  }
+  void MatchOptimizedHlo(absl::string_view hlo, const absl::string_view pattern,
+                         bool print_operand_shape = false) {
+    GemmRewriteTest::MatchOptimizedHlo(
+        hlo, absl::StrReplaceAll(pattern, replacements_), print_operand_shape);
+  }
+  absl::string_view CustomCallTarget() {
+    return replacements_[kCustomCallTargetPlaceholder];
+  }
+
+ protected:
+  void SetUp() override {
+    if (SkipGpuBlasLtTest()) {
+      GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
+    }
+  }
+
+ protected:
+  absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
+
+ private:
+  static constexpr const char* kCustomCallTargetPlaceholder{
+      "<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"};
+};
+
+TEST_P(ParameterizedGemmRewriteTest, Simple) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, SimpleRewrite) {
+  const char* hlo_text = R"(
+HloModule SimpleGemm
+
+ENTRY AddDotsFunc {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, MultipleContractingDims) {
+  const char* hlo_text = R"(
+HloModule MultipleContractingCheckGemm
+
+ENTRY AddDotsFunc {
+  x = f32[3,4,2] parameter(0)
+  y = f32[3,4,5] parameter(1)
+  ROOT dot_a = f32[2,5] dot(x, y), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}, operand_precision={highest,highest}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-NOT:     copy
+;
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,4,2], {{.*}}: f32[3,4,5]) -> f32[2,5] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,4,2]{2,1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4,5]{2,1,0} parameter(1)
+; CHECK-DAG:     [[BITCAST0:%[^ ]+]] = f32[2,12]{0,1} bitcast([[P0]])
+; CHECK-DAG:     [[BITCAST1:%[^ ]+]] = f32[12,5]{1,0} bitcast([[P1]])
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[BITCAST0]], [[BITCAST1]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, ArgTransposeFoldCheck) {
+  const char* hlo_text = R"(
+HloModule ArgTransposeFoldGemm
+
+ENTRY AddDotsFunc {
+  x = f32[3,2] parameter(0)
+  y = f32[3,4] parameter(1)
+  x_transposed = f32[2,3] transpose(x), dimensions={1, 0}
+  ROOT dot_a = f32[2,4] dot(x_transposed, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,2]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BatchedArgRowColTransposeFoldCheck) {
+  const char* hlo_text = R"(
+HloModule BatchedArgRowColTransposeFoldGemm
+
+ENTRY AddDotsFunc {
+  x = f32[5,3,2] parameter(0)
+  y = f32[5,3,4] parameter(1)
+  x_transposed = f32[5,2,3] transpose(x), dimensions={0, 2, 1}
+  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,3,2], {{.*}}: f32[5,3,4]) -> f32[5,2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[5,3,2]{2,1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":["0"]
+; CHECK-DAG:           "rhs_batch_dimensions":["0"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BatchRowTransposeFoldCheck) {
+  const char* hlo_text = R"(
+HloModule BatchRowTransposeFoldCheck
+
+ENTRY AddDotsFunc {
+  x = f32[2,5,3] parameter(0)
+  y = f32[5,3,4] parameter(1)
+  x_transposed = f32[5,2,3] transpose(x), dimensions={1, 0, 2}
+  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,5,3], {{.*}}: f32[5,3,4]) -> f32[5,2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,5,3]{2,1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["2"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":["1"]
+; CHECK-DAG:           "rhs_batch_dimensions":["0"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BatchFromMinorDimTransposeIsNotFolded) {
+  const char* hlo_text = R"(
+HloModule BatchFromMinorDimTransposeDoesntFold
+
+ENTRY AddDotsFunc {
+  x = f32[3,2,5] parameter(0)
+  y = f32[5,3,4] parameter(1)
+  x_transposed = f32[5,2,3] transpose(x), dimensions={2, 1, 0}
+  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2,5], {{.*}}: f32[5,3,4]) -> f32[5,2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,2,5]{2,1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
+; CHECK-DAG:     [[FUSION:%[^ ]+]] = f32[5,2,3]{2,1,0} transpose([[P0]])
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[FUSION]], [[P1]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["2"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":["0"]
+; CHECK-DAG:           "rhs_batch_dimensions":["0"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, LargeBatch) {
+  const char* hlo_text = R"(
+HloModule BatchedArgRowColTransposeFoldGemm
+
+ENTRY AddDotsFunc {
+  x = f32[20000,4,3,2] parameter(0)
+  y = f32[20000,4,3,4] parameter(1)
+  ROOT dot_a = f32[20000,4,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+}
+
+)";
+
+  // Batch sizes larger than 2^16-1 are not supported by cublasLt. Ensure that
+  // the custom_call_target is __cublas$gemm.
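+  // Here the flattened batch is 20000 * 4 = 80000 > 65535, so the rewrite
+  // must fall back to legacy cuBLAS.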
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[20000,4,3,2], {{.*}}: f32[20000,4,3,4]) -> f32[20000,4,2,4] {
+; CHECK:    [[P0:%[^ ]+]] = f32[20000,4,3,2]{3,2,1,0} parameter(0)
+; CHECK:    [[BC0:%[^ ]+]] = f32[80000,3,2]{2,1,0} bitcast([[P0]])
+; CHECK:    [[P1:%[^ ]+]] = f32[20000,4,3,4]{3,2,1,0} parameter(1)
+; CHECK:    [[BC1:%[^ ]+]] = f32[80000,3,4]{2,1,0} bitcast([[P1]])
+; CHECK:    [[GEMM:%[^ ]+]] = (f32[80000,2,4]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[BC0]], [[BC1]]),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":["0"]
+; CHECK-DAG:           "rhs_batch_dimensions":["0"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK:           }
+; CHECK:   [[OUT:%[^ ]+]] = f32[80000,2,4]{2,1,0} get-tuple-element([[GEMM]]), index=0
+; CHECK:   ROOT {{[^ ]+}} = f32[20000,4,2,4]{3,2,1,0} bitcast([[OUT]])
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, InstrTransposeFoldCheck) {
+  const char* hlo_text = R"(
+HloModule InstrTransposeFoldGemm
+
+ENTRY AddDotsFunc {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = f32[4,2] transpose(dot_a), dimensions={1, 0}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[4,2] {
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P1]], [[P0]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BatchedInstrLayoutTransposed) {
+  const char* hlo_text = R"(
+HloModule BatchedInstrLayoutCheck
+
+ENTRY AddDotsFunc {
+  x = f32[5,2,3] parameter(0)
+  y = f32[5,3,4] parameter(1)
+  dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+  ROOT out = f32[2,5,4] transpose(dot_a), dimensions={1, 0, 2}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,2,3], {{.*}}: f32[5,3,4]) -> f32[2,5,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["2"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":["0"]
+; CHECK-DAG:           "rhs_batch_dimensions":["0"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[2,5,4]{2,1,0} bitcast
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BatchedInstrLayoutBatchNotInMinorDim) {
+  const char* hlo_text = R"(
+HloModule BatchedInstrLayoutBatchNotInMinorDim
+
+ENTRY AddDotsFunc {
+  x = f32[5,2,3] parameter(0)
+  y = f32[5,3,4] parameter(1)
+  dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+  ROOT out = f32[2,4,5] transpose(dot_a), dimensions={1, 2, 0}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,2,3], {{.*}}: f32[5,3,4]) -> f32[2,4,5] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["2"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":["0"]
+; CHECK-DAG:           "rhs_batch_dimensions":["0"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[2,4,5]{2,1,0} [[OP:[^ ]+]]
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, AlphaSimpleRewrite) {
+  const char* hlo_text = R"(
+HloModule AlphaSimpleRewrite
+
+ENTRY AddDotsFunc {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  k = f32[] constant(3.0)
+  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  ROOT dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":3
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, F64C64_CublasLtSupportTest) {
+  // This test should fail if the gemm rewriter does not correctly rewrite
+  // F64/C64 dots to cublasLt or legacy cuBLAS calls.
+  {
+    const char* hlo_text = R"(
+HloModule F64_rewrite
+
+ENTRY AddDotsFunc {
+  x = f64[2,2] parameter(0)
+  y = f64[2,2] parameter(1)
+  k = f64[] constant(3.0)
+  k_broadcast = f64[2, 2] broadcast(k), dimensions={}
+  dot_a = f64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT dot_a_multiplied = f64[2, 2] multiply(dot_a, k_broadcast)
+}
+)";
+    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
+  }
+  {
+    const char* hlo_text = R"(
+HloModule C64_rewrite
+
+ENTRY AddDotsFunc {
+  x = c64[2,2] parameter(0)
+  y = c64[2,2] parameter(1)
+  k = c64[] constant((3.0, 3.0))
+  k_broadcast = c64[2, 2] broadcast(k), dimensions={}
+  dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast)
+}
+)";
+    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
+  }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, ComplexAlphaSimpleRewrite) {
+  if (!IsCuda() && GetDebugOptionsForTest().xla_gpu_enable_cublaslt()) {
+    GTEST_SKIP() << "TODO: Unsupported C64 gpublas-lt datatype on ROCM";
+  }
+  const char* hlo_text = R"(
+HloModule ComplexAlphaSimpleRewrite
+
+ENTRY AddDotsFunc {
+  x = c64[2,2] parameter(0)
+  y = c64[2,2] parameter(1)
+  k = c64[] constant((3.0, 3.0))
+  k_broadcast = c64[2, 2] broadcast(k), dimensions={}
+  dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: c64[2,2], {{.*}}: c64[2,2]) -> c64[2,2] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = c64[2,2]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = c64[2,2]{1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":3
+; CHECK-DAG:         "alpha_imag":3
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, AlphaMultipleUsersNoRewrite) {
+  const char* hlo_text = R"(
+HloModule AlphaMultipleUsersNoRewrite
+
+ENTRY AddDotsFunc {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  k = f32[] constant(3.0)
+  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+  ROOT out = f32[2,2] add(dot_a_multiplied, dot_a)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK:    {{[^ ]+}} = {{.*}} custom-call({{[^,]+}}, {{[^)]+}}),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, AlphaVectorNoRewrite) {
+  const char* hlo_text = R"(
+HloModule AlphaVectorNoRewrite
+
+ENTRY AddDotsFunc {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  alpha = f32[2] constant({1, 2})
+  alpha_broadcast = f32[2,2] broadcast(alpha), dimensions={1}
+  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT dot_a_multiplied = f32[2, 2] multiply(dot, alpha_broadcast)
+}
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BF16Gemm) {
+  const char* hlo_text = R"(
+HloModule bf16gemm
+
+ENTRY bf16gemm {
+  %parameter.1 = bf16[12,4]{1,0} parameter(0)
+  %parameter.2 = bf16[4,8]{1,0} parameter(1)
+  ROOT %dot.8 = bf16[12,8] dot(bf16[12,4] %parameter.1, bf16[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+  )";
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+  if (!IsCuda() ||
+      HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+; CHECK: {{.*}} custom-call(bf16[16,8]{1,0} {{.*}}, bf16[8,8]{1,0} {{.*}}), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+  )",
+                      /*print_operand_shape=*/true);
+  } else {
+    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+  }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, BF16GemmStrided) {
+  const char* hlo_text = R"(
+HloModule bf16gemm
+
+ENTRY bf16gemm {
+  %parameter.1 = bf16[3,3,4] parameter(0)
+  %parameter.2 = bf16[3,3,2] parameter(1)
+  ROOT %dot.3 = bf16[3,4,2]{2,1,0} dot(bf16[3,3,4]{2,1,0} %parameter.1, bf16[3,3,2]{2,1,0} %parameter.2), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,highest}
+}
+
+  )";
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+  if (!IsCuda() ||
+      HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+    ; CHECK: {{.*}} custom-call(bf16[3,8,8]{2,1,0} {{.*}}, bf16[3,8,8]{2,1,0} {{.*}}), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+    )",
+                      /*print_operand_shape=*/true);
+  } else {
+    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+  }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, Int8Gemm) {
+  const char* hlo_text = R"(
+HloModule int8gemm
+
+ENTRY int8gemm {
+  %parameter.1 = s8[12,4]{1,0} parameter(0)
+  %parameter.2 = s8[4,8]{1,0} parameter(1)
+  ROOT %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+  )";
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+  if (!IsCuda() ||
+      HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
+  )",
+                      /*print_operand_shape=*/true);
+  } else {
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+  )",
+                      /*print_operand_shape=*/true);
+  }
+}
+
+TEST_F(GemmRewriteTest, Int8GemmRankGreaterThanTwo) {
+  if (!IsCuda()) {
+    GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm";
+  }
+
+  const char* hlo_text = R"(
+HloModule int8gemm
+
+ENTRY main.4 {
+  Arg_0.1 = s8[1,8,2]{2,1,0} parameter(0)
+  Arg_1.2 = s8[2,4]{1,0} parameter(1)
+  ROOT dot.3 = s32[1,8,4]{2,1,0} dot(Arg_0.1, Arg_1.2),
+  lhs_contracting_dims={2}, rhs_contracting_dims={0}
+}
+  )";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+  if (!IsCuda() ||
+      HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+; CHECK: [[GEMM:%[^ ]+]] = (s32[8,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(s8[8,4]{1,0} %{{.*}}, s8[4,4]{0,1} %{{.*}}), custom_call_target="__cublas$gemm",
+  )",
+                      /*print_operand_shape=*/true);
+  }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoAlphaRewrite) {
+  const char* hlo_text = R"(
+HloModule int8gemm
+
+ENTRY int8gemm {
+  %parameter.1 = s8[12,4]{1,0} parameter(0)
+  %parameter.2 = s8[4,8]{1,0} parameter(1)
+  k = s32[] constant(2)
+  k_broadcast = s32[12,8] broadcast(k), dimensions={}
+  %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT dot_multiplied = s32[12,8] multiply(%dot.8, k_broadcast)
+}
+  )";
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+  if (!IsCuda() ||
+      HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           backend_config={
+; CHECK-DAG:       "alpha_real":1
+; CHECK-DAG:       "alpha_imag":0
+  )",
+                      /*print_operand_shape=*/true);
+  } else {
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+  )",
+                      /*print_operand_shape=*/true);
+  }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoBetaRewrite) {
+  const char* hlo_text = R"(
+HloModule int8gemm
+
+ENTRY int8gemm {
+  %parameter.1 = s8[12,4]{1,0} parameter(0)
+  %parameter.2 = s8[4,8]{1,0} parameter(1)
+  bias = s32[12,8] parameter(2)
+  %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = s32[12,8] add(%dot.8, bias)
+}
+  )";
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+  if (!IsCuda() ||
+      HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           backend_config={
+; CHECK-DAG:       "alpha_real":1
+; CHECK-DAG:       "alpha_imag":0
+; CHECK-DAG:       "beta":0
+  )",
+                      /*print_operand_shape=*/true);
+  } else {
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+  )",
+                      /*print_operand_shape=*/true);
+  }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, Int8GemmNotMultipleOfFour) {
+  if (!IsCuda()) {
+    GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm";
+  }
+
+  const char* hlo_text = R"(
+HloModule int8gemm
+
+ENTRY int8gemm {
+  %parameter.1 = s8[13,4]{1,0} parameter(0)
+  %parameter.2 = s8[4,9]{1,0} parameter(1)
+  ROOT %dot.9 = s32[13,9] dot(s8[13,4] %parameter.1, s8[4,9] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+  )";
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+
+  if (!IsCuda() ||
+      HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
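+    // The rewriter is expected to pad the s8 operands up to multiples of four
+    // (13x4 -> 16x4, 4x9 -> 4x12) before handing them to __cublas$gemm, as
+    // matched below.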
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+; CHECK: {{.*}} custom-call(s8[16,4]{1,0} [[A:%[^ ]+]], s8[4,12]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
+  )",
+                      /*print_operand_shape=*/true);
+  } else {
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+; CHECK: {{.*}} dot(s32[13,4]{1,0} [[A:%[^ ]+]], s32[4,9]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+
+  )",
+                      /*print_operand_shape=*/true);
+  }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, GemmTypeCombinationCheck) {
+  if (!IsCuda()) {
+    GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm";
+  }
+
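+  // Each tuple is (operand type, output type, whether RunAndCompare is
+  // expected to succeed for that combination).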
+  std::vector<std::tuple<absl::string_view, absl::string_view, bool>>
+      type_combinations = {{"s8", "s8", true},
+                           {"s32", "s32", true},
+                           {"bf16", "bf16", true},
+                           {"f16", "f16", true},
+                           {"f32", "f32", true},
+                           {"f64", "f64", true},
+                           {"c64", "c64", true},
+                           {"c128", "c128", true},
+                           // mixed-type GEMM cases
+                           {"s8", "s32", true},
+                           {"s8", "f32", true},
+                           {"f16", "f32", true},
+                           {"bf16", "f32", true}};
+
+  if (!IsCuda() ||
+      HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    // For compute capabilities before Ampere, upcasting may kick in, in which
+    // case these otherwise-unsupported combinations could still run
+    // successfully. That is why we only add these cases when the compute
+    // capability is at least Ampere.
+    std::vector<std::tuple<absl::string_view, absl::string_view, bool>>
+        more_type_combinations = {
+            {"s8", "bf16", false},  {"s8", "f16", false},
+            {"s8", "f64", false},   {"s8", "c64", false},
+            {"s8", "c128", false},
+
+            {"s32", "f32", false},  {"s32", "f64", false},
+            {"s32", "c64", false},  {"s32", "c128", false},
+
+            {"f16", "bf16", false}, {"f16", "f64", false},
+            {"f16", "c64", false},  {"f16", "c128", false},
+
+            {"bf16", "f16", false}, {"bf16", "f64", false},
+            {"bf16", "c64", false}, {"bf16", "c128", false},
+
+            {"f32", "f64", false},  {"f32", "c64", false},
+            {"f32", "c128", false},
+
+            {"f64", "c64", false},  {"f64", "c128", false},
+        };
+    type_combinations.insert(type_combinations.end(),
+                             more_type_combinations.begin(),
+                             more_type_combinations.end());
+  }
+
+  for (const auto& type_combination : type_combinations) {
+    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+    replacements["<<ABType>>"] = std::get<0>(type_combination);
+    replacements["<<DType>>"] = std::get<1>(type_combination);
+    const char* hlo_template = R"(
+  HloModule type_combo
+
+  ENTRY type_combo {
+    %parameter.1 = <<ABType>>[4,4]{1,0} parameter(0)
+    %parameter.2 = <<ABType>>[4,4]{1,0} parameter(1)
+    ROOT %dot = <<DType>>[4,4] dot(%parameter.1, %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  }
+    )";
+    const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
+    if (std::get<2>(type_combination)) {
+      EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+    } else {
+      EXPECT_FALSE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+    }
+  }
+}
+
+TEST_P(ParameterizedGemmRewriteTest, UpcastingBf16ToF64) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  Arg_0.1 = bf16[4,3]{1,0} parameter(0)
+  Arg_1.2 = bf16[3,6]{1,0} parameter(1)
+  ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  // This type combination is not supported by cublasLt, so we expect
+  // GemmRewriter to choose legacy cublas.
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, UpcastingC64ToC128) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  Arg_0.1 = c64[4,3]{1,0} parameter(0)
+  Arg_1.2 = c64[3,6]{1,0} parameter(1)
+  ROOT dot.3 = c128[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  // This type combination is not supported by cublasLt, so we expect
+  // GemmRewriter to choose legacy cublas.
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF32) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  Arg_0.1 = f16[4,3]{1,0} parameter(0)
+  Arg_1.2 = f16[3,6]{1,0} parameter(1)
+  ROOT dot.3 = f32[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest, highest}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0)));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF64) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  Arg_0.1 = f16[4,3]{1,0} parameter(0)
+  Arg_1.2 = f16[3,6]{1,0} parameter(1)
+  ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  // This type combination is not supported by cublasLt, so we expect
+  // GemmRewriter to choose legacy cublas.
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, UpcastingF32ToF64) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  Arg_0.1 = f32[4,3]{1,0} parameter(0)
+  Arg_1.2 = f32[3,6]{1,0} parameter(1)
+  ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  // This type combination is not supported by cublasLt, so we expect
+  // GemmRewriter to choose legacy cublas.
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, DoNotUpconvertOutput) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY main {
+  param_0 = f16[240,88]{1,0} parameter(0)
+  param_1 = f16[88,4]{1,0} parameter(1)
+  dot = f16[240,4]{1,0} dot(param_0, param_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  constant_255 = f16[] constant(255)
+  broadcast = f16[240,4]{1,0} broadcast(constant_255), dimensions={}
+  multiply = f16[240,4]{1,0} multiply(dot, broadcast)
+  ROOT result = f32[240,4]{1,0} convert(multiply)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  // The fp16-input/fp32-output combination is supported by both legacy cublas
+  // and cublasLt, so we expect GemmRewriter to fuse the convert into the gemm.
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Convert(
+                  m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0))));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, UnsupportedMixTypeGemm) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY main {
+  param_0 = f32[240,88]{1,0} parameter(0)
+  param_1 = f32[88,4]{1,0} parameter(1)
+  dot = f32[240,4]{1,0} dot(param_0, param_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  constant_255 = f32[] constant(255)
+  broadcast = f32[240,4]{1,0} broadcast(constant_255), dimensions={}
+  multiply = f32[240,4]{1,0} multiply(dot, broadcast)
+  ROOT result = u8[240,4]{1,0} convert(multiply)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  // A u8 output is supported by neither legacy cublas nor cublasLt, so we
+  // expect GemmRewriter not to fuse the convert into the gemm.
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Convert(
+                  m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0))));
+}
+
+TEST_P(ParameterizedGemmRewriteTest, CheckIsGemmAliasedBeforeFusion) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY main {
+  Arg_0.1 = f16[8,16]{1,0} parameter(0)
+  Arg_1.2 = f16[16,32]{1,0} parameter(1)
+  dot.8 = f16[8,32]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  Arg_2.3 = f16[8,32]{1,0} parameter(2)
+  constant.5 = f16[] constant(1)
+  broadcast.6 = f16[8,32]{1,0} broadcast(constant.5), dimensions={}
+  add.7 = f16[8,32]{1,0} add(Arg_2.3, broadcast.6)
+  add.9 = f16[8,32]{1,0} add(dot.8, add.7)
+  convert.10 = f32[8,32]{1,0} convert(add.9)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  // The fp16-input/fp32-output combination is supported by both legacy cublas
+  // and cublasLt, but the gemm output is already aliased with one of the
+  // inputs, so we expect GemmRewriter not to fuse the convert into the gemm.
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Convert(
+                  m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0))));
+}
+
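+// The boolean test parameter selects which cuBLAS custom-call target the
+// checks above expect: legacy cuBLAS or cublasLt (see CustomCallTarget() and
+// the <<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>> substitution).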
+INSTANTIATE_TEST_SUITE_P(CublasTestsBothLegacyAndLt,
+                         ParameterizedGemmRewriteTest, ::testing::Bool());
+#endif
+
+// A test fixture class for tests which are specific to legacy cublas
+class LegacyCublasGemmRewriteTest : public GemmRewriteTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
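+    // Steer GEMMs away from Triton and cublasLt so that dots lower through the
+    // legacy __cublas$gemm custom call.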
+    debug_options.set_xla_gpu_enable_triton_gemm(false);
+    debug_options.set_xla_gpu_enable_cublaslt(false);
+    return debug_options;
+  }
+};
+
+TEST_F(LegacyCublasGemmRewriteTest, MatrixVectorMultiplication) {
+  const char* hlo_text = R"(
+HloModule m
+
+ENTRY e {
+  p0 = f32[2048] parameter(0)
+  p1 = f32[2048, 16384] parameter(1)
+  ROOT d = f32[16384] dot(p0, p1),
+    lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})";
+
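+  // Run GemmRewriter standalone against an Ampere target with toolkit version
+  // 12040 and FileCheck the rewritten module directly.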
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(
+          se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
+          /*toolkit_version=*/12040),
+      R"(
+; CHECK:  %[[P0:.+]] = f32[2048]{0} parameter(0)
+; CHECK:  %[[P1:.+]] = f32[2048,16384]{1,0} parameter(1)
+; CHECK:  %[[CUSTOM_CALL:.+]] = (f32[16384]{0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]]), custom_call_target="__cublas$gemm"
+)");
+}
+
+TEST_F(LegacyCublasGemmRewriteTest, MatrixVectorMultiplicationWithBatch) {
+  const char* hlo_text = R"(
+HloModule m
+
+ENTRY e {
+  p0 = f32[10, 10, 2048] parameter(0)
+  p1 = f32[10, 10, 2048, 16384] parameter(1)
+  ROOT d = f32[10, 10, 16384] dot(p0, p1),
+   lhs_batch_dims={0, 1}, rhs_batch_dims={0, 1},
+   lhs_contracting_dims={2}, rhs_contracting_dims={2}
+})";
+
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(
+          se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
+          /*toolkit_version=*/12040),
+      R"(
+; CHECK:  %[[P0:.+]] = f32[10,10,2048]{2,1,0} parameter(0)
+; CHECK:  %[[P1:.+]] = f32[10,10,2048,16384]{3,2,1,0} parameter(1)
+; CHECK:  %[[CUSTOM_CALL:.+]] = (f32[10,10,16384]{2,1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]]), custom_call_target="__cublas$gemm"
+)");
+}
+
+TEST_F(LegacyCublasGemmRewriteTest, SparseDotNotSupported) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY main {
+  lhs = f16[5,16] parameter(0)
+  rhs = f16[32,10] parameter(1)
+  meta = u16[5,2] parameter(2)
+  ROOT dot = f32[5,10] dot(lhs, rhs, meta),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4
+})";
+  auto hlo_pass = GemmRewriter(
+      se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0},
+      /*toolkit_version=*/12040);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&hlo_pass, module.get()));
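+  // The dot carries structured-sparsity metadata, which the rewriter does not
+  // handle, so the pass is expected to leave the module unchanged.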
+  EXPECT_FALSE(changed);
+}
+
+// Test that the alpha and beta fields of the GemmBackendConfig are updated.
+// A bias must be present for the beta value to be set.
+// In order to have a bias add fused, the bias term must be overwritable.
+// We assume that we may not overwrite parameters of a computation. Hence, we
+// use the third parameter to create a new value which can be overwritten and
+// will be used as the bias. This negate(param_2) has no semantic use; it
+// exists only so that the bias may be overwritten.
+TEST_F(LegacyCublasGemmRewriteTest, AlphaBetaRewrite) {
+  const char* hlo_text = R"(
+HloModule NonZeroAlphaBeta
+
+ENTRY AddDotsFunc {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  param_2 = f32[2,2] parameter(2)
+  bias = f32[2,2] negate(param_2)
+  k = f32[] constant(3.0)
+  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+  ROOT out = f32[2,2] add(dot_a_multiplied, bias)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-DAG:     [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG:     [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK:         [[O:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], {{[^,)]+}}),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":3
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[O]]), index=0
+)");
+}
+
+TEST_F(LegacyCublasGemmRewriteTest, BiasMultipleUsersNoOverwrite) {
+  const char* hlo_text = R"(
+HloModule BiasMultipleUsersNoOverwrite
+
+ENTRY AddDotsFunc {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  bias = f32[2,2] parameter(2)
+  k = f32[] constant(3.0)
+  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+  biased_out = f32[2,2] add(dot_a_multiplied, bias)
+  ROOT out = f32[2,2] add(biased_out, bias)
+}
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-NEXT:    [[CUSTOM_CALL:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":3
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+
+)");
+}
+
+TEST_F(LegacyCublasGemmRewriteTest, BiasParameterNoOverwrite) {
+  const char* hlo_text = R"(
+HloModule BiasParameterNoOverwrite
+
+ENTRY AddDotsFunc {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  bias = f32[2,2] parameter(2)
+  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = f32[2,2] add(dot_a, bias)
+}
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_F(LegacyCublasGemmRewriteTest, BiasTupleParameterOverwrite) {
+  const char* hlo_text = R"(
+HloModule BiasTupleParameterOverwrite
+
+ENTRY AddDotsFunc {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  param_2 = (f32[2,2], f32[3,3]) parameter(2)
+  bias = f32[2,2] get-tuple-element(param_2), index=0
+  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = f32[2,2] add(dot_a, bias)
+}
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: (f32[2,2], f32[3,3])) -> f32[2,2] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-DAG:     [[P2:%[^ ]+]] = (f32[2,2]{1,0}, f32[3,3]{1,0}) parameter(2)
+; CHECK-DAG:     [[BIAS:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[P2]]), index=0
+; CHECK-DAG:     [[BIAS_COPY:%[^ ]+]] = f32[2,2]{1,0} copy([[BIAS]])
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS_COPY]]),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_F(LegacyCublasGemmRewriteTest, AliasedBiasOverwrite) {
+  const char* hlo_text = R"(
+HloModule AliasedBiasOverwrite, input_output_alias={ {}: (2, {}, must-alias) }
+
+ENTRY AddDotsFunc {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  bias = f32[2,2] parameter(2)
+  k = f32[] constant(3.0)
+  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+  ROOT out = f32[2,2] add(dot_a_multiplied, bias)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-DAG:     [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG:     [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-DAG:     [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
+; CHECK:         [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":3
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_F(LegacyCublasGemmRewriteTest, LargerBiasMultipleUsersNoRewrite) {
+  const char* hlo_text = R"(
+HloModule LargerBiasMultipleUsersNoRewrite
+
+ENTRY AddDotsFunc {
+  x = f32[1024,1024] parameter(0)
+  y = f32[1024,1024] parameter(1)
+  bias = f32[1024,1024] parameter(2)
+  dot_a = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  biased_out = f32[1024,1024] add(dot_a, bias)
+  ROOT out = f32[1024,1024] add(biased_out, bias)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024]) -> f32[1024,1024] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[1024,1024]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+// In order to have a bias add fused, the bias term must be overwritable.
+// We assume that we may not overwrite parameters of a computation. Hence, we
+// use the third parameter to create a new value which can be overwritten and
+// will be used as the bias. This negate(param_2) has no semantic use; it
+// exists only so that the bias may be overwritten.
+TEST_F(LegacyCublasGemmRewriteTest, BF16GemmWithBias) {
+  const char* hlo_text = R"(
+HloModule BF16GemmWithBias
+
+ENTRY BF16GemmWithBias {
+  x = bf16[8,8]{1,0} parameter(0)
+  y = bf16[8,8]{1,0} parameter(1)
+  dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  param_2 = bf16[8,8]{1,0} parameter(2)
+  bias = bf16[8,8]{1,0} negate(param_2)
+  ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)
+}
+  )";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2e-3, 2e-3}));
+
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+  }
+
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %BF16GemmWithBias ({{.*}}: bf16[8,8], {{.*}}: bf16[8,8], {{.*}}: bf16[8,8]) -> bf16[8,8] {
+; CHECK-DAG:    [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
+; CHECK-DAG:    [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
+; CHECK:        [[GEMM:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], {{[^,)]+}}),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+// In order to have a bias add fused, the bias term must be overwritable.
+// We assume that we may not overwrite parameters of a computation. Hence, we
+// use the third parameter to create a new value which can be overwritten and
+// will be used as the bias. This negate(param_2) has no semantic use; it
+// exists only so that the bias may be overwritten.
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBias) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  param_2 = f32[2,4] parameter(2)
+  bias = f32[2,4] negate(param_2)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = f32[2,4] add(dot_a, bias)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK:         [[GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], {{[^,)]+}}),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasWhereBiasIsNotAParameter) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  w = f32[2,3] parameter(0)
+  x = f32[3,4] parameter(1)
+  first_dot = f32[2,4] dot(w, x), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  y = f32[2,3] parameter(2)
+  z = f32[3,4] parameter(3)
+  second_dot = f32[2,4] dot(y, z), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = f32[2,4] add(second_dot, first_dot)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-DAG:     [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2)
+; CHECK-DAG:     [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3)
+; CHECK-NEXT:    [[FIRST_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:         [[FIRST_GEMM_OUT:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[FIRST_GEMM]]), index=0
+; CHECK-NEXT:    [[SECOND_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P2]], [[P3]], [[FIRST_GEMM_OUT]]),
+; CHECK:           custom_call_target="__cublas$gemm",
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Test gemm matrix bias add fusion with mixed types.
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixType) {
+  std::vector<std::tuple<absl::string_view, absl::string_view>>
+      type_combinations = {
+          {"f16", "f32"},
+          {"bf16", "f32"},
+      };
+
+  const char* hlo_text_template = R"(
+HloModule test
+
+ENTRY test {
+  x = <<ABType>>[16,32] parameter(0)
+  y = <<ABType>>[32,16] parameter(1)
+  z = <<DType>>[16,16] parameter(2)
+  dot_a = <<ABType>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bias = <<DType>>[16,16] negate(z)
+  convert = <<DType>>[16,16] convert(dot_a)
+  ROOT out = <<DType>>[16,16] add(convert, bias)
+}
+
+)";
+  for (const auto& type_combination : type_combinations) {
+    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+    replacements["<<ABType>>"] = std::get<0>(type_combination);
+    replacements["<<DType>>"] = std::get<1>(type_combination);
+    const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
+    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+    if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
+        !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+      continue;
+    }
+
+    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+                            GetOptimizedModule(hlo_text));
+    EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
+                GmockMatch(m::GetTupleElement(
+                    m::CustomCall(m::Parameter(0), m::Parameter(1),
+                                  m::Negate(m::Parameter(2))),
+                    0)));
+  }
+}
+
+// Test batched gemm matrix bias add fusion with mixed types.
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeBatched) {
+  std::vector<std::tuple<absl::string_view, absl::string_view>>
+      type_combinations = {
+          {"f16", "f32"},
+          {"bf16", "f32"},
+      };
+
+  const char* hlo_text_template = R"(
+HloModule test
+
+ENTRY test {
+  x = <<ABType>>[4,16,32] parameter(0)
+  y = <<ABType>>[4,32,16] parameter(1)
+  z = <<DType>>[4,16,16] parameter(2)
+  dot_a = <<ABType>>[4,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+  bias = <<DType>>[4,16,16] negate(z)
+  convert = <<DType>>[4,16,16] convert(dot_a)
+  ROOT out = <<DType>>[4,16,16] add(convert, bias)
+})";
+  for (const auto& type_combination : type_combinations) {
+    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+    replacements["<<ABType>>"] = std::get<0>(type_combination);
+    replacements["<<DType>>"] = std::get<1>(type_combination);
+    const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
+    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+    if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
+        !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+      continue;
+    }
+
+    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+                            GetOptimizedModule(hlo_text));
+    EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
+                GmockMatch(m::GetTupleElement(
+                    m::CustomCall(m::Parameter(0), m::Parameter(1),
+                                  m::Negate(m::Parameter(2))),
+                    0)));
+  }
+}
+#endif
+
+// Test that gemm matrix bias add fusion with mixed types is not applied when
+// the type combination is unsupported.
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeNotSupported) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[16,32] parameter(0)
+  y = bf16[32,16] parameter(1)
+  z = f64[16,16] parameter(2)
+  dot_a = bf16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bias = f64[16,16] negate(z)
+  convert = f64[16,16] convert(dot_a)
+  ROOT out = f64[16,16] add(convert, bias)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+                          GetOptimizedModule(hlo_text));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK:        %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$gemm
+; CHECK:        %[[gte:.*]] = {{.*}} get-tuple-element{{.*}}%[[custom_call]]
+; CHECK:        ROOT {{.*}} fusion({{.*}}%[[gte]]
+)");
+}
+
+// Test that gemm matrix bias add fusion with mixed types is not applied when
+// the bias add has additional consumers.
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeAddWithMoreConsumers) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[16,32] parameter(0)
+  y = bf16[32,16] parameter(1)
+  z = f32[16,16] parameter(2)
+  dot_a = bf16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bias = f32[16,16] negate(z)
+  convert = f32[16,16] convert(dot_a)
+  add_bias = f32[16,16] add(convert, bias)
+  ROOT out = f32[16,16] negate(add_bias)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+  }
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+                          GetOptimizedModule(hlo_text));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK:        %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$gemm
+; CHECK:        %[[gte:.*]] = {{.*}} get-tuple-element{{.*}}%[[custom_call]]
+; CHECK:        ROOT {{.*}} fusion({{.*}}%[[gte]]
+)");
+}
+
+TEST_F(LegacyCublasGemmRewriteTest, MergeBitcastAndAdd) {
+  const char* hlo_text = R"(
+HloModule test
+ENTRY test {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  bias = f32[4] parameter(2)
+  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = f32[4] add(f32[4] bitcast(dot), bias)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
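+  // The f32[4] bias is expected to be bitcast back to f32[2,2] and fused into
+  // the gemm, with the gemm result bitcast to f32[4] at the root.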
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Bitcast(
+              m::GetTupleElement(
+                  m::CustomCall(
+                      {"__cublas$gemm"}, m::Parameter(0), m::Parameter(1),
+                      m::Bitcast(m::Parameter(2)).WithShape(F32, {2, 2})),
+                  0))
+              .WithShape(F32, {4})));
+}
+
+// In order to have a bias add fused, the bias term must be overwritable.
+// We assume that we may not overwrite parameters of a computation. Hence, we
+// use the third parameter to create a new value which can be overwritten and
+// will be used as the bias. This negate(param_2) has no semantic use; it
+// exists only so that the bias may be overwritten.
+TEST_F(LegacyCublasGemmRewriteTest, FoldConstantBias) {
+  const char* hlo_text = R"(
+HloModule test
+ENTRY test {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
+
+  dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  param_2 = f32[2,2] parameter(2)
+  bias1 = f32[2,2] negate(param_2)
+  sum1 = add(dot1, bias1)
+
+  dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  sum2 = add(dot2, f32[2,2] reshape(bias))
+
+  dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bias3 = f32[2,2] transpose(bias), dimensions={1,0}
+  sum3 = add(dot3, bias3)
+
+  dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  sum4 = add(dot4, f32[2,2] bitcast(bias))
+
+  ROOT root = tuple(sum1, sum2, sum3, sum4)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  SCOPED_TRACE(module->ToString());
+  EXPECT_TRUE(changed);
+
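+  // Only the overwritable negate(param_2) bias is expected to survive as a
+  // gemm operand; the reshaped, transposed, and bitcast views of the constant
+  // bias should be folded into constants fed directly to their gemms.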
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::GetTupleElement(m::CustomCall(m::Parameter(0), m::Parameter(1),
+                                           m::Negate(m::Parameter(2))),
+                             0),
+          m::GetTupleElement(
+              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+              0),
+          m::GetTupleElement(
+              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+              0),
+          m::GetTupleElement(
+              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+              0))));
+}
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// A test fixture class for tests which are specific to cublasLt
+class CublasLtGemmRewriteTest : public GemmRewriteTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
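+    // Route GEMMs to cublasLt (and away from Triton) so that the
+    // __cublas$lt$matmul rewrite is exercised.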
+    debug_options.set_xla_gpu_enable_cublaslt(true);
+    debug_options.set_xla_gpu_enable_triton_gemm(false);
+    return debug_options;
+  }
+
+ protected:
+  void SetUp() override {
+    if (SkipGpuBlasLtTest()) {
+      GTEST_SKIP() << "BlasLt is not supported on this GPU architecture";
+    }
+  }
+};
+
+TEST_F(CublasLtGemmRewriteTest, AlphaBetaRewrite) {
+  const char* hlo_text = R"(
+HloModule NonZeroAlphaBeta
+
+ENTRY AddDotsFunc {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  bias = f32[2,2] parameter(2)
+  k = f32[] constant(3.0)
+  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+  ROOT out = f32[2,2] add(dot_a_multiplied, bias)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-DAG:     [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG:     [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-DAG:     [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":3
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[GEMM]]), index=0
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BiasMultipleUsersNoOverwrite) {
+  const char* hlo_text = R"(
+HloModule BiasMultipleUsersNoOverwrite
+
+ENTRY AddDotsFunc {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  bias = f32[2,2] parameter(2)
+  k = f32[] constant(3.0)
+  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
+  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
+  biased_out = f32[2,2] add(dot_a_multiplied, bias)
+  ROOT out = f32[2,2] add(biased_out, bias)
+}
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
+; CHECK-DAG:     [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK-NOT:       output_to_operand_aliasing
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":3
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, LargerBiasMultipleUsersNoRewrite) {
+  const char* hlo_text = R"(
+HloModule LargerBiasMultipleUsersNoRewrite
+
+ENTRY AddDotsFunc {
+  x = f32[1024,1024] parameter(0)
+  y = f32[1024,1024] parameter(1)
+  bias = f32[1024,1024] parameter(2)
+  dot_a = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  biased_out = f32[1024,1024] add(dot_a, bias)
+  ROOT out = f32[1024,1024] add(biased_out, bias)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024]) -> f32[1024,1024] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1)
+; CHECK-DAG:     [[BIAS:%[^ ]+]] = f32[1024,1024]{1,0} parameter(2)
+; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f32[1024,1024]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK-NEXT:  [[GEMM:%[^ ]+]] = f32[1024,1024]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK-NEXT:  ROOT [[OUT:%[^ ]+]] = f32[1024,1024]{1,0} add([[GEMM]], [[BIAS]])
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BF16GemmWithBias) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY BF16GemmWithBias {
+  x = bf16[8,8]{1,0} parameter(0)
+  y = bf16[8,8]{1,0} parameter(1)
+  dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bias = bf16[8,8]{1,0} parameter(2)
+  ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)
+}
+  )";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+  }
+
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %BF16GemmWithBias ({{.*}}: bf16[8,8], {{.*}}: bf16[8,8], {{.*}}: bf16[8,8]) -> bf16[8,8] {
+; CHECK-DAG:    [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
+; CHECK-DAG:    [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
+; CHECK-DAG:    [[BIAS:%[^ ]+]] = bf16[8,8]{1,0} parameter(2)
+; CHECK-NEXT:   [[GEMM:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, MatrixBias) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[2,4] parameter(2)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = f32[2,4] add(dot_a, z)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasWhereBiasIsNotAParameter) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  w = f32[2,3] parameter(0)
+  x = f32[3,4] parameter(1)
+  first_dot = f32[2,4] dot(w, x), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  y = f32[2,3] parameter(2)
+  z = f32[3,4] parameter(3)
+  second_dot = f32[2,4] dot(y, z), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = f32[2,4] add(second_dot, first_dot)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-DAG:     [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2)
+; CHECK-DAG:     [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3)
+; CHECK-NEXT:    [[FIRST_GEMM_TUPLE:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:         [[FIRST_GEMM:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[FIRST_GEMM_TUPLE]]), index=0
+; CHECK-NEXT:    [[SECOND_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P2]], [[P3]], [[FIRST_GEMM]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           output_to_operand_aliasing={
+; CHECK:              {0}: (2, {})
+; CHECK:           }
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBias) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[4] parameter(2)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f32[2,4] broadcast(z), dimensions={1}
+  ROOT out = f32[2,4] add(dot_a, z_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS"
+; CHECK:           }
+)");
+}
+
+// Epilogue fusion is disabled when the GEMM has multiple users.
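+// Because dot_a feeds both the bias add and a second dot, the checks below
+// expect the add to become a separate loop fusion and all three matmuls to
+// keep the "DEFAULT" epilogue.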
+TEST_F(CublasLtGemmRewriteTest, VectorBiasMultipleUsers) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[4,4] parameter(0)
+  y = f32[4,4] parameter(1)
+  z = f32[4] parameter(2)
+  c = f32[] constant(5)
+  dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  z_bcast = f32[4,4] broadcast(z), dimensions={1}
+  add_a = f32[4,4] add(dot_a, z_bcast)
+  c_bcast = f32[4,4] broadcast(c), dimensions={}
+  dot_b = f32[4,4] dot(dot_a, c_bcast), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  ROOT out = f32[4,4] dot(add_a, dot_b), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK:        [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[4,4], [[DUMMY1:[^ ]+]]: f32[4]) -> f32[4,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4]{0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,4]{1,0} broadcast([[P1]]), dimensions={1}
+; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[4,4]{1,0} add([[P0]], [[P2]])
+}
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,4], {{.*}}: f32[4,4], {{.*}}: f32[4]) -> f32[4,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[MATMUL0_TUPLE:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK-NEXT:    [[MATMUL0:%[^ ]+]] = f32[4,4]{1,0} get-tuple-element([[MATMUL0_TUPLE]]), index=0
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT:    [[FUSION:%[^ ]+]] = f32[4,4]{1,0} fusion([[MATMUL0]], [[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]]
+; CHECK:         [[MATMUL1_TUPLE:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[MATMUL0]]
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK-NEXT:    [[MATMUL1:%[^ ]+]] = f32[4,4]{1,0} get-tuple-element([[MATMUL1_TUPLE]]), index=0
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[FUSION]], [[MATMUL1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BatchedVectorBias) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3,4] parameter(0)
+  y = f32[4,5,6] parameter(1)
+  z = f32[3,5,6] parameter(2)
+  dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  z_bcast = f32[2,3,5,6] broadcast(z), dimensions={1,2,3}
+  ROOT out = f32[2,3,5,6] add(dot_a, z_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[3,5,6]) -> f32[2,3,5,6] {
+; CHECK:         [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           output_to_operand_aliasing={
+; CHECK:              {0}: (2, {})
+; CHECK:           }
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK-NEXT:    [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BatchedSharedVectorBias) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3,4] parameter(0)
+  y = f32[4,5,6] parameter(1)
+  z = f32[6] parameter(2)
+  dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  z_bcast = f32[2,3,5,6] broadcast(z), dimensions={3}
+  ROOT out = f32[2,3,5,6] add(dot_a, z_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[6]) -> f32[2,3,5,6] {
+; CHECK:         [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           output_to_operand_aliasing={
+; CHECK:              {0}: (2, {})
+; CHECK:           }
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:         [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasIncorrectAxisFusedAsMatrix) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[2] parameter(2)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f32[2,4] broadcast(z), dimensions={0}
+  add = f32[2,4] add(dot_a, z_bcast)
+  ROOT out = f32[4,2] transpose(add), dimensions={1,0}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[4,2] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
+; CHECK-NEXT:    [[MATMUL_TUPLE:%[^ ]+]] = (f32[2,4]{0,1}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS"
+; CHECK:           }
+; CHECK-NEXT:    [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]])
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasSliced) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[4,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[3] parameter(2)
+  dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  slice_a = f32[2,3] slice(dot_a), slice={[0:2], [0:3]}
+  z_bcast = f32[2,3] broadcast(z), dimensions={1}
+  ROOT out = f32[2,3] add(slice_a, z_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,3], {{.*}}: f32[3,4], {{.*}}: f32[3]) -> f32[2,3] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[4,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[3]{0} parameter(2)
+; CHECK-NEXT:    [[MATMUL:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS"
+; CHECK:           }
+; CHECK-NEXT:    [[GETTUPLE:%[^ ]+]] = f32[4,4]{1,0} get-tuple-element([[MATMUL]]), index=0
+; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,3]{1,0} slice([[GETTUPLE]]), slice={[0:2], [0:3]}
+      )");
+}
+
+// Epilogue fusion is disabled when the slice has multiple users.
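+// slice_a feeds both the bias add and dot_b, so the checks below expect plain
+// GEMMs with the "DEFAULT" epilogue rather than a fused "BIAS" epilogue.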
+TEST_F(CublasLtGemmRewriteTest, VectorBiasSlicedMultipleUsers) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[2] parameter(2)
+  c = f32[] constant(5)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  slice_a = f32[2,2] slice(dot_a), slice={[0:2], [0:2]}
+  z_bcast = f32[2,2] broadcast(z), dimensions={1}
+  add_a = f32[2,2] add(slice_a, z_bcast)
+  c_bcast = f32[2,2] broadcast(c), dimensions={}
+  dot_b = f32[2,2] dot(slice_a, c_bcast), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = f32[2,2] dot(add_a, dot_b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[2,2] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-DAG:     [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
+; CHECK-NEXT:    [[MATMUL0_TUPLE:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:         [[MATMUL1_TUPLE:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:         [[MATMUL1:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[MATMUL1_TUPLE]]), index=0
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call{{.*}}[[MATMUL1]]
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasTransposed) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[2] parameter(2)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f32[2,4] parameter(3)
+  ROOT out = f32[2,4] add(dot_a, z_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2_BCAST:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2_BCAST]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasThenMatrixBias) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[4] parameter(2)
+  z2 = f32[2,4] parameter(3)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f32[2,4] broadcast(z), dimensions={1}
+  add0 = f32[2,4] add(dot_a, z_bcast)
+  ROOT add1 = f32[2,4] add(add0, z2)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> f32[2,4] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-DAG:     [[VECTOR_BIAS:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-DAG:     [[MATRIX_BIAS:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[MATRIX_BIAS]], [[VECTOR_BIAS]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS"
+; CHECK:           }
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BF16VectorBias) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[16,24] parameter(0)
+  y = bf16[24,32] parameter(1)
+  z = bf16[32] parameter(2)
+  dot_a = bf16[16,32] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = bf16[16,32] broadcast(z), dimensions={1}
+  ROOT out = bf16[16,32] add(dot_a, z_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{3e-3, 1e-3}));
+
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    GTEST_SKIP() << "Pre-Ampere GPUs cast bf16 up to f32";
+  }
+
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: bf16[16,24], {{.*}}: bf16[24,32], {{.*}}: bf16[32]) -> bf16[16,32] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = bf16[16,24]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = bf16[24,32]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = bf16[32]{0} parameter(2)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (bf16[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS"
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BF16VectorBiasPadded) {
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    GTEST_SKIP() << "Padding of GEMM bf16 operands only implemented on "
+                    "architectures with bf16 Tensor Cores.";
+  }
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[2,3] parameter(0)
+  y = bf16[3,4] parameter(1)
+  z = bf16[4] parameter(2)
+  dot_a = bf16[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = bf16[2,4] broadcast(z), dimensions={1}
+  ROOT out = bf16[2,4] add(dot_a, z_bcast)
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4], {{.*}}: bf16[4]) -> bf16[2,4] {
+; CHECK-DAG:    bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5
+; CHECK-DAG:    bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ReluActivation) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  c = f32[] constant(0)
+  c_bcast = f32[2,4] broadcast(c), dimensions={}
+  ROOT out = f32[2,4] maximum(dot_a, c_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"RELU"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BatchedReluActivation) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3,4] parameter(0)
+  y = f32[4,5,6] parameter(1)
+  dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  c = f32[] constant(0)
+  c_bcast = f32[2,3,5,6] broadcast(c), dimensions={}
+  ROOT out = f32[2,3,5,6] maximum(dot_a, c_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6]) -> f32[2,3,5,6] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1)
+; CHECK-NEXT:    [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0}
+; CHECK-NEXT:    [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"RELU"
+; CHECK:           }
+; CHECK:         [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ReluActivationSliced) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  c = f32[] constant(0)
+  c_bcast = f32[2,2] broadcast(c), dimensions={}
+  slice_a = f32[2,2] slice(dot_a), slice={[0:2], [0:2]}
+  ROOT out = f32[2,2] maximum(slice_a, c_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,2] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[MATMUL_TUPLE:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"RELU"
+; CHECK:           }
+; CHECK:         [[MATMUL:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} slice([[MATMUL]]), slice={[0:2], [0:2]}
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivation) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[2,4] parameter(2)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  add = f32[2,4] add(dot_a, z)
+  c = f32[] constant(0)
+  c_bcast = f32[2,4] broadcast(c), dimensions={}
+  ROOT out = f32[2,4] maximum(add, c_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"RELU"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, SquareMatrixBiasReluActivation) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[4,4] parameter(0)
+  y = f32[4,4] parameter(1)
+  z = f32[4,4] parameter(2)
+  dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  add = f32[4,4] add(dot_a, z)
+  c = f32[] constant(0)
+  c_bcast = f32[4,4] broadcast(c), dimensions={}
+  ROOT out = f32[4,4] maximum(add, c_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,4], {{.*}}: f32[4,4], {{.*}}: f32[4,4]) -> f32[4,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4,4]{1,0} parameter(2)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[4,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"RELU"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivation) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[4] parameter(2)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f32[2,4] broadcast(z), dimensions={1}
+  add = f32[2,4] add(dot_a, z_bcast)
+  c = f32[] constant(0)
+  c_bcast = f32[2,4] broadcast(c), dimensions={}
+  ROOT out = f32[2,4] maximum(add, c_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS_RELU"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, BatchedVectorBiasReluActivation) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3,4] parameter(0)
+  y = f32[4,5,6] parameter(1)
+  z = f32[3,5,6] parameter(2)
+  dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  z_bcast = f32[2,3,5,6] broadcast(z), dimensions={1,2,3}
+  add = f32[2,3,5,6] add(dot_a, z_bcast)
+  c = f32[] constant(0)
+  c_bcast = f32[2,3,5,6] broadcast(c), dimensions={}
+  ROOT out = f32[2,3,5,6] maximum(add, c_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[3,5,6]) -> f32[2,3,5,6] {
+; CHECK:         [[MATMUL_TUPLE:%[^ ]+]] = (f32[6,30]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"RELU"
+; CHECK:           }
+; CHECK-NEXT:    [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]])
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasTransposedReluActivation) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[2] parameter(2)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f32[2,4] broadcast(z), dimensions={0}
+  add = f32[2,4] add(dot_a, z_bcast)
+  c = f32[] constant(0)
+  c_bcast = f32[2,4] broadcast(c), dimensions={}
+  maximum = f32[2,4] maximum(add, c_bcast)
+  ROOT out = f32[4,2] transpose(maximum), dimensions={1,0}
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[4,2] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
+; CHECK-NEXT:    [[MATMUL_TUPLE:%[^ ]+]] = (f32[2,4]{0,1}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS_RELU"
+; CHECK:           }
+; CHECK-NEXT:    [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} get-tuple-element([[MATMUL_TUPLE]]), index=0
+; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]])
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasThenMatrixBiasReluActivation) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z_vec = f32[4] parameter(2)
+  z_matrix = f32[2,4] parameter(3)
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f32[2,4] broadcast(z_vec), dimensions={1}
+  add0 = f32[2,4] add(dot_a, z_bcast)
+  add1 = f32[2,4] add(add0, z_matrix)
+  c = f32[] constant(0)
+  c_bcast = f32[2,4] broadcast(c), dimensions={}
+  ROOT out = f32[2,4] maximum(add1, c_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> f32[2,4] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-DAG:     [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P3]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS_RELU"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivation) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  mul.0 = f32[2,4] multiply(dot, dot)
+  mul.1 = f32[2,4] multiply(dot, mul.0)
+  const.0 = f32[] constant(0.044715)
+  bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
+  mul.2 = f32[2,4] multiply(mul.1, bcast.0)
+  add.0 = f32[2,4] add(dot, mul.2)
+  const.1 = f32[] constant(0.797884583)
+  bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
+  mul.3 = f32[2,4] multiply(add.0, bcast.1)
+  tanh = f32[2,4] tanh(mul.3)
+  const.2 = f32[] constant(1)
+  bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
+  add.2 = f32[2,4] add(tanh, bcast.2)
+  const.3 = f32[] constant(0.5)
+  bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
+  mul.4 = f32[2,4] multiply(add.2, bcast.3)
+  ROOT out = f32[2,4] multiply(dot, mul.4)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"GELU"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWrongConstant) {
+  // Modify one constant slightly so that the pattern no longer matches.
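+  // The checks below only assert that no GELU epilogue appears in the
+  // optimized HLO.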
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  mul.0 = f32[2,4] multiply(dot, dot)
+  mul.1 = f32[2,4] multiply(dot, mul.0)
+  const.0 = f32[] constant(0.05)
+  bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
+  mul.2 = f32[2,4] multiply(mul.1, bcast.0)
+  add.0 = f32[2,4] add(dot, mul.2)
+  const.1 = f32[] constant(0.797884583)
+  bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
+  mul.3 = f32[2,4] multiply(add.0, bcast.1)
+  tanh = f32[2,4] tanh(mul.3)
+  const.2 = f32[] constant(1)
+  bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
+  add.2 = f32[2,4] add(tanh, bcast.2)
+  const.3 = f32[] constant(0.5)
+  bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
+  mul.4 = f32[2,4] multiply(add.2, bcast.3)
+  ROOT out = f32[2,4] multiply(dot, mul.4)
+}
+
+)";
+
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-NOT: GELU
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivation) {
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60000
+  auto rocm_switch = false;  // The GELU epilogue is only available from ROCm 6.0.
+#else
+  auto rocm_switch = true;
+#endif
+  if (!IsCuda() && rocm_switch) {
+    GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM";
+  }
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[4] parameter(2)
+  dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f32[2,4] broadcast(z), dimensions={1}
+  add = f32[2,4] add(dot, z_bcast)
+  mul.0 = f32[2,4] multiply(add, add)
+  mul.1 = f32[2,4] multiply(add, mul.0)
+  const.0 = f32[] constant(0.044715)
+  bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
+  mul.2 = f32[2,4] multiply(mul.1, bcast.0)
+  add.0 = f32[2,4] add(add, mul.2)
+  const.1 = f32[] constant(0.797884583)
+  bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
+  mul.3 = f32[2,4] multiply(add.0, bcast.1)
+  tanh = f32[2,4] tanh(mul.3)
+  const.2 = f32[] constant(1)
+  bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
+  add.2 = f32[2,4] add(tanh, bcast.2)
+  const.3 = f32[] constant(0.5)
+  bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
+  mul.4 = f32[2,4] multiply(add.2, bcast.3)
+  ROOT out = f32[2,4] multiply(add, mul.4)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS_GELU"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWithAux) {
+  if (!IsCuda()) {
+    GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM";
+  }
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  mul.0 = f32[2,4] multiply(dot, dot)
+  mul.1 = f32[2,4] multiply(dot, mul.0)
+  const.0 = f32[] constant(0.044715)
+  bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
+  mul.2 = f32[2,4] multiply(mul.1, bcast.0)
+  add.0 = f32[2,4] add(dot, mul.2)
+  const.1 = f32[] constant(0.797884583)
+  bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
+  mul.3 = f32[2,4] multiply(add.0, bcast.1)
+  tanh = f32[2,4] tanh(mul.3)
+  const.2 = f32[] constant(1)
+  bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
+  add.2 = f32[2,4] add(tanh, bcast.2)
+  const.3 = f32[] constant(0.5)
+  bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
+  mul.4 = f32[2,4] multiply(add.2, bcast.3)
+  mul.5 = f32[2,4] multiply(dot, mul.4)
+  ROOT out = (f32[2,4], f32[2,4]) tuple(mul.5, dot)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> (f32[2,4], f32[2,4]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"GELU_AUX"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivationWithAux) {
+  if (!IsCuda()) {
+    GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM";
+  }
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[4] parameter(2)
+  dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f32[2,4] broadcast(z), dimensions={1}
+  add = f32[2,4] add(dot, z_bcast)
+  mul.0 = f32[2,4] multiply(add, add)
+  mul.1 = f32[2,4] multiply(add, mul.0)
+  const.0 = f32[] constant(0.044715)
+  bcast.0 = f32[2,4] broadcast(const.0), dimensions={}
+  mul.2 = f32[2,4] multiply(mul.1, bcast.0)
+  add.0 = f32[2,4] add(add, mul.2)
+  const.1 = f32[] constant(0.797884583)
+  bcast.1 = f32[2,4] broadcast(const.1), dimensions={}
+  mul.3 = f32[2,4] multiply(add.0, bcast.1)
+  tanh = f32[2,4] tanh(mul.3)
+  const.2 = f32[] constant(1)
+  bcast.2 = f32[2,4] broadcast(const.2), dimensions={}
+  add.2 = f32[2,4] add(tanh, bcast.2)
+  const.3 = f32[] constant(0.5)
+  bcast.3 = f32[2,4] broadcast(const.3), dimensions={}
+  mul.4 = f32[2,4] multiply(add.2, bcast.3)
+  mul.5 = f32[2,4] multiply(add, mul.4)
+  ROOT out = (f32[2,4], f32[2,4]) tuple(mul.5, add)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> (f32[2,4], f32[2,4]) {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS_GELU_AUX"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationBF16) {
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    GTEST_SKIP() << "Padding of GEMM bf16 operands only implemented on "
+                    "architectures with bf16 Tensor Cores.";
+  }
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[2,3] parameter(0)
+  y = bf16[3,4] parameter(1)
+  dot = bf16[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  mul.0 = bf16[2,4] multiply(dot, dot)
+  mul.1 = bf16[2,4] multiply(dot, mul.0)
+  const.0 = bf16[] constant(0.044715)
+  bcast.0 = bf16[2,4] broadcast(const.0), dimensions={}
+  mul.2 = bf16[2,4] multiply(mul.1, bcast.0)
+  add.0 = bf16[2,4] add(dot, mul.2)
+  const.1 = bf16[] constant(0.797884583)
+  bcast.1 = bf16[2,4] broadcast(const.1), dimensions={}
+  mul.3 = bf16[2,4] multiply(add.0, bcast.1)
+  tanh = bf16[2,4] tanh(mul.3)
+  const.2 = bf16[] constant(1)
+  bcast.2 = bf16[2,4] broadcast(const.2), dimensions={}
+  add.2 = bf16[2,4] add(tanh, bcast.2)
+  const.3 = bf16[] constant(0.5)
+  bcast.3 = bf16[2,4] broadcast(const.3), dimensions={}
+  mul.4 = bf16[2,4] multiply(add.2, bcast.3)
+  ROOT out = bf16[2,4] multiply(dot, mul.4)
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{5e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4]) -> bf16[2,4] {
+; CHECK-DAG:    bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5
+; CHECK-DAG:    bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationBitcast) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  dot_bitcast = f32[2,2,2] bitcast(dot)
+  mul.0 = f32[2,2,2] multiply(dot_bitcast, dot_bitcast)
+  mul.1 = f32[2,2,2] multiply(dot_bitcast, mul.0)
+  const.0 = f32[] constant(0.044715)
+  bcast.0 = f32[2,2,2] broadcast(const.0), dimensions={}
+  mul.2 = f32[2,2,2] multiply(mul.1, bcast.0)
+  add.0 = f32[2,2,2] add(dot_bitcast, mul.2)
+  const.1 = f32[] constant(0.797884583)
+  bcast.1 = f32[2,2,2] broadcast(const.1), dimensions={}
+  mul.3 = f32[2,2,2] multiply(add.0, bcast.1)
+  tanh = f32[2,2,2] tanh(mul.3)
+  const.2 = f32[] constant(1)
+  bcast.2 = f32[2,2,2] broadcast(const.2), dimensions={}
+  add.2 = f32[2,2,2] add(tanh, bcast.2)
+  const.3 = f32[] constant(0.5)
+  bcast.3 = f32[2,2,2] broadcast(const.3), dimensions={}
+  mul.4 = f32[2,2,2] multiply(add.2, bcast.3)
+  ROOT out = f32[2,2,2] multiply(dot_bitcast, mul.4)
+}
+
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Bitcast(m::GetTupleElement(
+                         m::CustomCall({"__cublas$lt$matmul"},
+                                       m::Parameter(0).WithShape(F32, {2, 3}),
+                                       m::Parameter(1).WithShape(F32, {3, 4})),
+                         0))
+              .WithShape(F32, {2, 2, 2})));
+}
+
+// For F16, the sizes of all dimensions of the operands are required to be
+// multiples of 8 to allow matrix bias fusion.
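+// The shapes below (8x16, 16x8, 8x8) already satisfy this, so the checks
+// expect the matrix bias to be folded into the cuBLASLt call with beta=1.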
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasF16) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f16[8,16] parameter(0)
+  y = f16[16,8] parameter(1)
+  z = f16[8,8] parameter(2)
+  dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = f16[8,8] add(dot_a, z)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f16[8,16], {{.*}}: f16[16,8], {{.*}}: f16[8,8]) -> f16[8,8] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasF32UnpaddedWithBitcast) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3]{1,0} parameter(0)
+  y = f32[3,4]{1,0} parameter(1)
+  z = f32[2]{0} parameter(2)
+  dot_a = f32[2,4]{0,1} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bitc = f32[4,2]{1,0} bitcast(f32[2,4]{0,1} dot_a)
+  z_bcast = f32[4,2] broadcast(z), dimensions={1}
+  ROOT add = f32[4,2]{1,0} add(bitc, z_bcast)
+}
+
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Bitcast(m::GetTupleElement(
+                         m::CustomCall({"__cublas$lt$matmul"}, m::Parameter(0),
+                                       m::Parameter(1),
+                                       m::Parameter(2).WithShape(F32, {2})),
+                         0)
+                         .WithShape(F32, {2, 4}))
+              .WithShape(F32, {4, 2})));
+}
+
+// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
+// newer architectures) so that the sizes of all dimensions are multiples of 8.
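+// The shapes below are already multiples of 8, so the checks expect no pad
+// instructions in the optimized HLO.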
+TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Unpadded) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f16[8,16] parameter(0)
+  y = f16[16,8] parameter(1)
+  z = f16[8] parameter(2)
+  dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f16[8,8] broadcast(z), dimensions={1}
+  ROOT add = f16[8,8] add(dot_a, z_bcast)
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3}));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-NOT:  pad("
+; CHECK:      custom-call
+; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Padded) {
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+    GTEST_SKIP() << "Padding of GEMM operands only implemented on "
+                    "architectures with Tensor Cores.";
+  }
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f16[6,12] parameter(0)
+  y = f16[12,6] parameter(1)
+  z = f16[6] parameter(2)
+  dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f16[6,6] broadcast(z), dimensions={1}
+  ROOT add = f16[6,6] add(dot_a, z_bcast)
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
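+  // In the expected output, padding configs are low_high per dimension, e.g.
+  // `padding=0_2x0_4` pads the f16[6,12] operand to f16[8,16].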
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
+; CHECK-DAG:    f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
+; CHECK-DAG:    f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
+      )");
+}
+
+// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
+// newer architectures) so that the sizes of all dimensions are multiples of 8.
+TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Unpadded) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f16[8,16] parameter(0)
+  y = f16[16,8] parameter(1)
+  dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  c = f16[] constant(0)
+  c_bcast = f16[8,8] broadcast(c), dimensions={}
+  ROOT out = f16[8,8] maximum(dot_a, c_bcast)
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-NOT:  pad("
+; CHECK:      custom-call
+; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Padded) {
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+    GTEST_SKIP() << "Padding of GEMM operands only implemented on "
+                    "architectures with Tensor Cores.";
+  }
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f16[6,12] parameter(0)
+  y = f16[12,6] parameter(1)
+  dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  c = f16[] constant(0)
+  c_bcast = f16[6,6] broadcast(c), dimensions={}
+  ROOT out = f16[6,6] maximum(dot_a, c_bcast)
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6]) -> f16[6,6] {
+; CHECK-DAG:    f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
+; CHECK-DAG:    f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivationF16) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f16[8,16] parameter(0)
+  y = f16[16,8] parameter(1)
+  z = f16[8,8] parameter(2)
+  dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  add = f16[8,8] add(dot_a, z)
+  c = f16[] constant(0)
+  c_bcast = f16[8,8] broadcast(c), dimensions={}
+  ROOT out = f16[8,8] maximum(add, c_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f16[8,16], {{.*}}: f16[16,8], {{.*}}: f16[8,8]) -> f16[8,8] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"RELU"
+; CHECK:           }
+      )");
+}
+
+// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and
+// newer architectures) so that the sizes of all dimensions are multiples of 8.
+TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Unpadded) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f16[8,16] parameter(0)
+  y = f16[16,8] parameter(1)
+  z = f16[8] parameter(2)
+  dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f16[8,8] broadcast(z), dimensions={1}
+  add = f16[8,8] add(dot_a, z_bcast)
+  c = f16[] constant(0)
+  c_bcast = f16[8,8] broadcast(c), dimensions={}
+  ROOT out = f16[8,8] maximum(add, c_bcast)
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-NOT:  pad("
+; CHECK:      custom-call
+; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
+)");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Padded) {
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Volta())) {
+    GTEST_SKIP() << "Padding of GEMM operands only implemented on "
+                    "architectures with Tensor Cores.";
+  }
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f16[6,12] parameter(0)
+  y = f16[12,6] parameter(1)
+  z = f16[6] parameter(2)
+  dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f16[6,6] broadcast(z), dimensions={1}
+  add = f16[6,6] add(dot_a, z_bcast)
+  c = f16[] constant(0)
+  c_bcast = f16[6,6] broadcast(c), dimensions={}
+  ROOT out = f16[6,6] maximum(add, c_bcast)
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] {
+; CHECK-DAG:   f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
+; CHECK-DAG:   f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
+      )");
+}
+
+// For bfloat16, the sizes of all dimensions of the operands are required to be
+// multiples of 8 to allow matrix bias fusion.
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasBF16) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[8,16] parameter(0)
+  y = bf16[16,8] parameter(1)
+  z = bf16[8,8] parameter(2)
+  dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = bf16[8,8] add(dot_a, z)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+  }
+
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: bf16[8,16], {{.*}}: bf16[16,8], {{.*}}: bf16[8,8]) -> bf16[8,8] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0)
+; CHECK-DAG:     [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1)
+; CHECK-DAG:     [[P2:%[^ ]+]] = bf16[8,8]{1,0} parameter(2)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasBitcastBF16) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[8,16] parameter(0)
+  y = bf16[16,8] parameter(1)
+  bias = bf16[2,4,8] parameter(2)
+  dot = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bitcast = bf16[2,4,8] bitcast(dot)
+  ROOT out = bf16[2,4,8] add(bitcast, bias)
+}
+
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
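+  // The [2,4,8] bias is expected to be bitcast to [8,8] so it can be fused as
+  // the matrix bias, and the custom call's result bitcast back to [2,4,8].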
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Bitcast(
+              m::GetTupleElement(
+                  m::CustomCall(
+                      {"__cublas$lt$matmul"},
+                      m::Parameter(0).WithShape(BF16, {8, 16}),
+                      m::Parameter(1).WithShape(BF16, {16, 8}),
+                      m::Bitcast(m::Parameter(2)).WithShape(BF16, {8, 8})),
+                  0))
+              .WithShape(BF16, {2, 4, 8})));
+}
+
+// For bfloat16, the operands are padded if necessary on Ampere and newer
+// architectures so that the sizes of all dimensions are multiples of 8.
+TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Unpadded) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[8,16] parameter(0)
+  y = bf16[16,8] parameter(1)
+  z = bf16[8] parameter(2)
+  dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = bf16[8,8] broadcast(z), dimensions={1}
+  ROOT add = bf16[8,8] add(dot_a, z_bcast)
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3}));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-NOT:  pad("
+; CHECK:      custom-call
+; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Padded) {
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on "
+                    "Ampere and newer architectures.";
+  }
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[6,12] parameter(0)
+  y = bf16[12,6] parameter(1)
+  z = bf16[6] parameter(2)
+  dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = bf16[6,6] broadcast(z), dimensions={1}
+  ROOT add = bf16[6,6] add(dot_a, z_bcast)
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG:  ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
+; CHECK-DAG:    bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
+; CHECK-DAG:    bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
+      )");
+}
+
+// For bfloat16, the operands are padded if necessary on Ampere and newer
+// architectures so that the sizes of all dimensions are multiples of 8.
+TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Unpadded) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[8,16] parameter(0)
+  y = bf16[16,8] parameter(1)
+  dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  c = bf16[] constant(0)
+  c_bcast = bf16[8,8] broadcast(c), dimensions={}
+  ROOT out = bf16[8,8] maximum(dot_a, c_bcast)
+}
+
+)";
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-NOT:  pad("
+; CHECK:      custom-call
+; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Padded) {
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on "
+                    "Ampere and newer architectures.";
+  }
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[6,12] parameter(0)
+  y = bf16[12,6] parameter(1)
+  dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  c = bf16[] constant(0)
+  c_bcast = bf16[6,6] broadcast(c), dimensions={}
+  ROOT out = bf16[6,6] maximum(dot_a, c_bcast)
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6]) -> bf16[6,6] {
+; CHECK-DAG:     bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
+; CHECK-DAG:     bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
+      )");
+}
+
+// For bfloat16, the operands are padded if necessary on Ampere and newer
+// architectures so that the sizes of all dimensions are multiples of 8.
+TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Unpadded) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[8,16] parameter(0)
+  y = bf16[16,8] parameter(1)
+  z = bf16[8] parameter(2)
+  dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = bf16[8,8] broadcast(z), dimensions={1}
+  add = bf16[8,8] add(dot_a, z_bcast)
+  c = bf16[] constant(0)
+  c_bcast = bf16[8,8] broadcast(c), dimensions={}
+  ROOT out = bf16[8,8] maximum(add, c_bcast)
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-NOT:  pad("
+; CHECK:      custom-call
+; CHECK-SAME: custom_call_target="__cublas$lt$matmul"
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Padded) {
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on "
+                    "Ampere and newer architectures.";
+  }
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[6,12] parameter(0)
+  y = bf16[12,6] parameter(1)
+  z = bf16[6] parameter(2)
+  dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = bf16[6,6] broadcast(z), dimensions={1}
+  add = bf16[6,6] add(dot_a, z_bcast)
+  c = bf16[] constant(0)
+  c_bcast = bf16[6,6] broadcast(c), dimensions={}
+  ROOT out = bf16[6,6] maximum(add, c_bcast)
+}
+
+)";
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] {
+; CHECK-DAG:     bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4
+; CHECK-DAG:     bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF64) {
+  if (!IsCuda()) {
+    GTEST_SKIP() << "TODO: Unsupported blas-lt F64 datatype on ROCM";
+  }
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f64[2,3] parameter(0)
+  y = f64[3,4] parameter(1)
+  z = f64[4] parameter(2)
+  dot_a = f64[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  z_bcast = f64[2,4] broadcast(z), dimensions={1}
+  add = f64[2,4] add(dot_a, z_bcast)
+  c = f64[] constant(0)
+  c_bcast = f64[2,4] broadcast(c), dimensions={}
+  ROOT out = f64[2,4] maximum(add, c_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-10, 1e-10}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f64[2,3], {{.*}}: f64[3,4], {{.*}}: f64[4]) -> f64[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f64[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f64[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f64[4]{0} parameter(2)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f64[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS_RELU"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, AlphaSimpleRewriteBiasAddActivation) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = f32[2,3] parameter(0)
+  y = f32[3,4] parameter(1)
+  z = f32[4] parameter(2)
+  k = f32[] constant(3.0)
+  k_bcast = f32[2,4] broadcast(k), dimensions={}
+  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,highest}
+  dot_a_multiplied = f32[2, 4] multiply(dot_a, k_bcast)
+  z_bcast = f32[2,4] broadcast(z), dimensions={1}
+  add = f32[2,4] add(dot_a_multiplied, z_bcast)
+  c = f32[] constant(0)
+  c_bcast = f32[2,4] broadcast(c), dimensions={}
+  ROOT out = f32[2,4] maximum(add, c_bcast)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[P2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":3
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["HIGHEST","HIGHEST"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS_RELU"
+; CHECK:           }
+      )");
+}
+
+TEST_F(CublasLtGemmRewriteTest, FoldConstantBias) {
+  const char* hlo_text = R"(
+HloModule test
+ENTRY test {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
+
+  dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bias1 = f32[2,2] parameter(2)
+  sum1 = add(dot1, bias1)
+
+  dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  sum2 = add(dot2, f32[2,2] reshape(bias))
+
+  dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bias3 = f32[2,2] transpose(bias), dimensions={1,0}
+  sum3 = add(dot3, bias3)
+
+  dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  sum4 = add(dot4, f32[2,2] bitcast(bias))
+
+  ROOT root = tuple(sum1, sum2, sum3, sum4)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(Capability(), GetToolkitVersion());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  SCOPED_TRACE(module->ToString());
+  EXPECT_TRUE(changed);
+
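+  // sum1 keeps its parameter bias; for sum2-sum4 the reshape, transpose, and
+  // bitcast of the broadcasted zero constant are expected to fold into a
+  // constant bias operand of the custom call.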
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::GetTupleElement(
+              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter()),
+              0),
+          m::GetTupleElement(
+              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+              0),
+          m::GetTupleElement(
+              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+              0),
+          m::GetTupleElement(
+              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+              0))));
+}
+
+TEST_F(CublasLtGemmRewriteTest, MultipleMaximumUsers) {
+  const char* hlo_text = R"(
+HloModule multiple_maximum_users
+
+relu {
+  Arg_0 = f32[3,896,54]{2,1,0} parameter(0)
+  constant = f32[] constant(0)
+  broadcast = f32[3,896,54]{2,1,0} broadcast(constant), dimensions={}
+  ROOT maximum = f32[3,896,54]{2,1,0} maximum(Arg_0, broadcast)
+}
+
+ENTRY main {
+  constant = f32[] constant(1)
+  broadcast_1 = f32[3,896,1024]{2,1,0} broadcast(constant), dimensions={}
+  Arg_2 = f32[1024,54]{1,0} parameter(2)
+  dot = f32[3,896,54]{2,1,0} dot(broadcast_1, Arg_2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  Arg_1 = f32[54]{0} parameter(1)
+  broadcast_2 = f32[3,896,54]{2,1,0} broadcast(Arg_1), dimensions={2}
+  add = f32[3,896,54]{2,1,0} add(dot, broadcast_2)
+  call = f32[3,896,54]{2,1,0} call(add), to_apply=relu
+  Arg_0 = f32[1]{0} parameter(0)
+  reshape_1 = f32[1,1,1]{2,1,0} reshape(Arg_0)
+  broadcast_3 = f32[1,1,1]{2,1,0} broadcast(reshape_1), dimensions={0,1,2}
+  reshape_2 = f32[] reshape(broadcast_3)
+  broadcast_4 = f32[3,896,54]{2,1,0} broadcast(reshape_2), dimensions={}
+  multiply = f32[3,896,54]{2,1,0} multiply(call, broadcast_4)
+  ROOT tuple = (f32[3,896,54]{2,1,0}, f32[3,896,54]{2,1,0}) tuple(multiply, call)
+}
+)";
+
+  // TODO(cjfj): Why do we need to relax the error constraint here?!
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-4}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK:           custom_call_target="__cublas$lt$matmul",
+      )");
+}
+
+// Test GEMM matrix bias add fusion with mixed types and an out-of-place update
+// (C != D).
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeOutOfPlace) {
+  if (!IsCuda()) {
+    GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM";
+  }
+  std::vector<std::tuple<absl::string_view, absl::string_view>>
+      type_combinations = {
+          {"f16", "f32"},
+          {"bf16", "f32"},
+      };
+
+  const char* hlo_text_template = R"(
+HloModule test
+
+ENTRY test {
+  x = <<ABType>>[16,32] parameter(0)
+  y = <<ABType>>[32,16] parameter(1)
+  z = <<DType>>[16,16] parameter(2)
+  dot_a = <<ABType>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  convert = <<DType>>[16,16] convert(dot_a)
+  ROOT out = <<DType>>[16,16] add(convert, z)
+})";
+  for (const auto& type_combination : type_combinations) {
+    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+    replacements["<<ABType>>"] = std::get<0>(type_combination);
+    replacements["<<DType>>"] = std::get<1>(type_combination);
+    const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
+    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+    if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
+        !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+      continue;
+    }
+
+    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+                            GetOptimizedModule(hlo_text));
+    EXPECT_THAT(
+        optimized_module->entry_computation()->root_instruction(),
+        GmockMatch(m::GetTupleElement(
+            m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+            0)));
+  }
+}
+
+// Test batched GEMM matrix bias add fusion with mixed types and an
+// out-of-place update (C != D).
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeOutOfPlaceBatched) {
+  if (!IsCuda()) {
+    GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM";
+  }
+  std::vector<std::tuple<absl::string_view, absl::string_view>>
+      type_combinations = {
+          {"f16", "f32"},
+          {"bf16", "f32"},
+      };
+
+  const char* hlo_text_template = R"(
+HloModule test
+
+ENTRY test {
+  x = <<ABType>>[4,16,32] parameter(0)
+  y = <<ABType>>[4,32,16] parameter(1)
+  z = <<DType>>[4,16,16] parameter(2)
+  dot_a = <<ABType>>[4,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+  convert = <<DType>>[4,16,16] convert(dot_a)
+  ROOT out = <<DType>>[4,16,16] add(convert, z)
+})";
+  for (const auto& type_combination : type_combinations) {
+    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+    replacements["<<ABType>>"] = std::get<0>(type_combination);
+    replacements["<<DType>>"] = std::get<1>(type_combination);
+    const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
+    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+    if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
+        !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+      continue;
+    }
+
+    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+                            GetOptimizedModule(hlo_text));
+    EXPECT_THAT(
+        optimized_module->entry_computation()->root_instruction(),
+        GmockMatch(m::GetTupleElement(
+            m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter(2)),
+            0)));
+  }
+}
+
+// Test GEMM matrix bias add fusion with mixed types and an in-place update
+// (C = D).
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeInPlace) {
+  if (!IsCuda()) {
+    GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM";
+  }
+  std::vector<std::tuple<absl::string_view, absl::string_view>>
+      type_combinations = {
+          {"f16", "f32"},
+          {"bf16", "f32"},
+      };
+  const char* hlo_text_template = R"(
+HloModule test
+
+ENTRY test {
+  x = <<ABType>>[16,32] parameter(0)
+  y = <<ABType>>[32,16] parameter(1)
+  z = <<DType>>[16,16] parameter(2)
+  dot_a = <<ABType>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bias = <<DType>>[16,16] negate(z)
+  convert = <<DType>>[16,16] convert(dot_a)
+  ROOT out = <<DType>>[16,16] add(convert, bias)
+})";
+
+  for (const auto& type_combination : type_combinations) {
+    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+    replacements["<<ABType>>"] = std::get<0>(type_combination);
+    replacements["<<DType>>"] = std::get<1>(type_combination);
+    const auto hlo_text = absl::StrReplaceAll(hlo_text_template, replacements);
+    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+    if (std::get<0>(type_combination) == "bf16" && IsCuda() &&
+        !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+      continue;
+    }
+
+    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+                            GetOptimizedModule(hlo_text));
+    EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
+                GmockMatch(m::GetTupleElement(
+                    m::CustomCall(m::Parameter(0), m::Parameter(1),
+                                  m::Negate(m::Parameter(2))),
+                    0)));
+  }
+}
+
+// Test GEMM matrix bias add fusion with a mixed-type combination that is not
+// supported.
+TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeNotSupported) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  x = bf16[16,32] parameter(0)
+  y = bf16[32,16] parameter(1)
+  z = f64[16,16] parameter(2)
+  dot_a = bf16[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bias = f64[16,16] negate(z)
+  convert = f64[16,16] convert(dot_a)
+  ROOT out = f64[16,16] add(convert, bias)
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+
+  if (IsCuda() &&
+      !HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    GTEST_SKIP() << "Pre-Ampere casts up bf16 to fp32";
+  }
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+                          GetOptimizedModule(hlo_text));
+  MatchOptimizedHlo(hlo_text, R"(
+; CHECK:        %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$lt$matmul
+; CHECK:        %[[tuple:.*]] = bf16[16,16]{1,0} get-tuple-element(%[[custom_call]]), index=0
+; CHECK:        ROOT {{.*}} fusion({{.*}}%[[tuple]]
+)");
+}
+
+class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTest {
+ public:
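+  // The placeholders below map to the CUDA FP8 formats (f8e4m3fn/f8e5m2) or
+  // their ROCm fnuz variants; the corresponding E4M3 max finite values are
+  // 448 and 240, respectively.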
+  ParameterizedFp8GemmRewriteTest() {
+    replacements_[kF8E4M3DatatypePlaceholder] =
+#if GOOGLE_CUDA
+        "f8e4m3fn";
+#else
+        "f8e4m3fnuz";
+#endif
+    replacements_[kF8E5M2DatatypePlaceholder] =
+#if GOOGLE_CUDA
+        "f8e5m2";
+#else
+        "f8e5m2fnuz";
+#endif
+    replacements_[kF8E4M3AmaxPlaceholder] =
+#if GOOGLE_CUDA
+        "448.";
+#else
+        "240.";
+#endif
+  }
+
+ protected:
+  // Check that the HLO runs and that an FP8 cuBLAS LT custom call is emitted
+  // on supported architectures (Ada, Hopper, and later).
+  void CheckFp8IfSupported(absl::string_view hlo_text,
+                           ErrorSpec error_spec = ErrorSpec{1e-2, 1e-2}) {
+    if (!HasFp8Support()) {
+      return;
+    }
+    std::string replaced_hlo_text =
+        absl::StrReplaceAll(hlo_text, replacements_);
+    EXPECT_TRUE(RunAndCompare(replaced_hlo_text, error_spec));
+
+    // Most FP8 tests directly create a GemmRewriter and check the output.
+    // Here, also run the entire HLO pass pipeline to ensure no other passes
+    // interfere with GemmRewriter's pattern matching.
+    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+                            GetOptimizedModule(replaced_hlo_text));
+    const HloInstruction* call =
+        FindInstruction(optimized_module.get(), HloOpcode::kCustomCall);
+    ASSERT_NE(call, nullptr);
+    EXPECT_EQ(call->custom_call_target(), "__cublas$lt$matmul$f8");
+  }
+
+  void MatchOptimizedHlo(absl::string_view hlo, const absl::string_view pattern,
+                         bool print_operand_shape = false) {
+    GemmRewriteTest::MatchOptimizedHlo(
+        absl::StrReplaceAll(hlo, replacements_),
+        absl::StrReplaceAll(pattern, replacements_), print_operand_shape);
+  }
+
+  void RunAndFilecheckHloRewrite(
+      absl::string_view hlo, HloPassInterface&& hlo_pass,
+      std::optional<absl::string_view> expected,
+      std::function<void(HloModule*)> after_pass_checks = nullptr,
+      const HloModuleConfig* config = nullptr) {
+    if (expected.has_value()) {
+      std::string replaced_pattern =
+          absl::StrReplaceAll(expected.value(), replacements_);
+      GemmRewriteTest::RunAndFilecheckHloRewrite(
+          absl::StrReplaceAll(hlo, replacements_), std::move(hlo_pass),
+          replaced_pattern, after_pass_checks, config);
+    }
+  }
+
+  absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
+  ParseAndReturnVerifiedModule(absl::string_view hlo_text,
+                               int64_t replica_count = 1,
+                               int64_t num_partitions = 1) {
+    return GemmRewriteTest::ParseAndReturnVerifiedModule(
+        absl::StrReplaceAll(hlo_text, replacements_));
+  }
+
+ private:
+  static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
+  static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
+  static constexpr const char* kF8E4M3AmaxPlaceholder{"<<F8E4M3_AMAX>>"};
+};
+
+TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteToF8OnPreAda) {
+  if (HasFp8Support()) {
+    GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300.";
+  }
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY PreAdaTest {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      ROOT out = <<F8E4M3>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
+                            ErrorSpec{1e-2, 1e-2}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> <<F8E4M3>>[16,16] {
+; CHECK:    {{.*}} = {{.*}} custom-call({{.*}}, {{.*}})
+; CHECK-DAG:  custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+          )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteOnPreAdaWithF32Output) {
+  if (HasFp8Support()) {
+    GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300.";
+  }
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY PreAdaTest {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      ROOT out = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
+                            ErrorSpec{1e-2, 1e-2}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> f32[16,16] {
+; CHECK:    {{.*}} = {{.*}} custom-call({{.*}}, {{.*}})
+; CHECK-DAG:  custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+          )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, UnsupportedTypesF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  // Test with types unsupported by cuBLAS LT when FP8 is used. cuBLAS LT with
+  // FP8 requires one of the operands to be F8E4M3FN.
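+  // With two <<F8E5M2>> operands, no FP8 custom call is expected; the
+  // rewriter should instead convert the operands to f16 and keep a plain dot,
+  // as checked below.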
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY unsupported_types {
+      x = <<F8E5M2>>[16,16] parameter(0)
+      y = <<F8E5M2>>[16,16] parameter(1)
+      ROOT out = <<F8E5M2>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+)";
+  EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_),
+                            ErrorSpec{1e-2, 1e-2}));
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(Capability(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %unsupported_types ({{.*}}: <<F8E5M2>>[16,16], {{.*}}: <<F8E5M2>>[16,16]) -> <<F8E5M2>>[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E5M2>>[16,16]{1,0} parameter(0)
+; CHECK-NEXT:    [[P0_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E5M2>>[16,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P1]])
+; CHECK-NEXT:    [[DOT:%[^ ]+]] = f16[16,16]{1,0} dot([[P0_CONVERT]], [[P1_CONVERT]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = <<F8E5M2>>[16,16]{1,0} convert([[DOT]])
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      ROOT out = <<F8E4M3>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> <<F8E4M3>>[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[C1:[^ ]+]] = f32[] constant(1)
+)"
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
+      R"(; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+)"
+#else
+      R"(; CHECK-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+)"
+#endif
+      R"(; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+// Do not fuse FP8 matrix bias.
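+// The bias add is expected to stay outside the custom call: beta remains 0
+// and the ROOT is an add of the get-tuple-element and the bias parameter.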
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
+  GTEST_SKIP() << "F8 gemm rewrite for D to be fp8 with Matrix Bias is only "
+                  "supported in ROCm 6.2 and above.";
+#endif  // TF_ROCM_VERSION < 60200
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      dot_a = <<F8E4M3>>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      b = <<F8E4M3>>[16,16] parameter(2)
+      ROOT out = <<F8E4M3>>[16,16] add(dot_a, b)
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: <<F8E4M3>>[16,16]) -> <<F8E4M3>>[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[C1:[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[DOT_TUPLE:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:         [[DOT:%[^ ]+]] = <<F8E4M3>>[16,16]{1,0} get-tuple-element([[DOT_TUPLE]]), index=0
+; CHECK-NEXT:    [[P2:%[^ ]+]] = <<F8E4M3>>[16,16]{1,0} parameter(2)
+; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = <<F8E4M3>>[16,16]{1,0} add([[DOT]], [[P2]])
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
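+  // The x_scale/y_scale multiplications are expected to be absorbed into the
+  // custom call as its scale operands (P2 and P3 in the pattern below).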
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[13,17] parameter(0)
+      y = <<F8E4M3>>[17,31] parameter(1)
+      x_f32 = f32[13,17] convert(x)
+      y_f32 = f32[17,31] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[13,17] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[17,31] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[13,17] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[17,31] multiply(y_f32, y_scale_bcast)
+      ROOT out = f32[13,31] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
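+  // As the pattern below checks, the 13x17 operand is padded to 16x32, the
+  // transposed 31x17 operand to 32x32, and the result is sliced back to
+  // 13x31.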
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[13,17], {{.*}}: <<F8E4M3>>[17,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[13,31] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[13,17]{1,0} parameter(0)
+; CHECK-NEXT:    [[C0:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT:    [[P0_PADDED:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_3x0_15
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[17,31]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[31,17]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT:    [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <<F8E4M3>>[32,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[C4:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[DOT_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2]], [[P3]], [[C4]], /*index=5*/[[C4]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK-NEXT:  [[DOT:%[^ ]+]] = f32[16,32]{1,0} get-tuple-element([[DOT_TUPLE]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[13,31]{1,0} slice([[DOT]]), slice={[0:13], [0:31]}
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[2,8,16] parameter(0)
+      y = <<F8E4M3>>[16,16] parameter(1)
+      x_f32 = f32[2,8,16] convert(x)
+      y_f32 = f32[16,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[2,8,16] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[16,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[2,8,16] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[16,16] multiply(y_f32, y_scale_bcast)
+      x_bitcast = f32[16,16] bitcast(x_unscaled)
+      ROOT out = f32[16,16] dot(x_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                    GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
+                     .WithShape(F32, {16, 16})));
+}
+
+// Test case where F8 inputs are converted to F32 before the dot, but without
+// any scaling.
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDWithConvertF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      ROOT out = f32[16,16] dot(x_f32, y_f32), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16]) -> f32[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[3] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[3] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[3] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[3] multiply(x_f32, x_scale_bcast)
+      zero = f32[] constant(0)
+      x_unscaled_padded = f32[30] pad(x_unscaled, zero), padding=0_27
+      x_unscaled_padded_bcast = f32[30,8,5] broadcast(x_unscaled_padded), dimensions={0}
+      x_unscaled_padded_bcast_sliced = f32[16,8,4] slice(x_unscaled_padded_bcast), slice={[2:18], [0:8], [0:4]}
+      x_unscaled_padded_bcast_sliced_reshaped = f32[16,32] reshape(x_unscaled_padded_bcast_sliced)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      ROOT out = f32[16,16] dot(x_unscaled_padded_bcast_sliced_reshaped, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
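+  // The pad/broadcast/slice/reshape chain on the LHS is expected to be moved
+  // onto the FP8 operand (note the <<F8E4M3>> shapes below) so the dot can
+  // still be rewritten into an FP8 custom call.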
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[3], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[3]{0} parameter(0)
+; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(0)
+; CHECK-NEXT:    [[C0_CONVERT:%[^ ]+]] = <<F8E4M3>>[] convert([[C0]])
+; CHECK-NEXT:    [[P0_U0:%[^ ]+]] = <<F8E4M3>>[30]{0} pad([[P0]], [[C0_CONVERT]]), padding=0_27
+; CHECK-NEXT:    [[P0_U1:%[^ ]+]] = <<F8E4M3>>[30,8,5]{2,1,0} broadcast([[P0_U0]]), dimensions={0}
+; CHECK-NEXT:    [[P0_U2:%[^ ]+]] = <<F8E4M3>>[16,8,4]{2,1,0} slice([[P0_U1]]), slice={[2:18], [0:8], [0:4]}
+; CHECK-NEXT:    [[P0_U3:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} reshape([[P0_U2]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C2]], /*index=5*/[[C2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       UnscaledABUnscaledDUnaryOpsWithConvertF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[3] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[3] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      zero = f32[] constant(0)
+      x_padded = f32[30] pad(x_f32, zero), padding=0_27
+      x_padded_bcast = f32[30,8,5] broadcast(x_padded), dimensions={0}
+      x_padded_bcast_sliced = f32[16,8,4] slice(x_padded_bcast), slice={[2:18], [0:8], [0:4]}
+      x_padded_bcast_sliced_reshaped = f32[16,32] reshape(x_padded_bcast_sliced)
+      ROOT out = f32[16,16] dot(x_padded_bcast_sliced_reshaped, y_f32), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[3], {{.*}}: <<F8E4M3>>[32,16]) -> f32[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[3]{0} parameter(0)
+; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(0)
+; CHECK-NEXT:    [[C0_CONVERT:%[^ ]+]] = <<F8E4M3>>[] convert([[C0]])
+; CHECK-NEXT:    [[P0_U0:%[^ ]+]] = <<F8E4M3>>[30]{0} pad([[P0]], [[C0_CONVERT]]), padding=0_27
+; CHECK-NEXT:    [[P0_U1:%[^ ]+]] = <<F8E4M3>>[30,8,5]{2,1,0} broadcast([[P0_U0]]), dimensions={0}
+; CHECK-NEXT:    [[P0_U2:%[^ ]+]] = <<F8E4M3>>[16,8,4]{2,1,0} slice([[P0_U1]]), slice={[2:18], [0:8], [0:4]}
+; CHECK-NEXT:    [[P0_U3:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} reshape([[P0_U2]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[C2]], [[C2]], [[C2]], /*index=5*/[[C2]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
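+  // Verifies that the dynamic-slice of the dequantized operand is moved onto
+  // the FP8 parameter, so the dot still rewrites to the FP8 custom call.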
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[32,32] parameter(0)
+      y = <<F8E4M3>>[16,32] parameter(1)
+      zero = s32[] constant(0)
+      x_f32 = f32[32,32] convert(x)
+      y_f32 = f32[16,32] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[32,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[32,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+      dyn_slice = f32[16,32]{1,0} dynamic-slice(x_unscaled, zero, zero), dynamic_slice_sizes={16,32}
+      ROOT dot_a = f32[16,16] dot(dyn_slice, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+          }
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                    GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[32,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[32,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[C0:%[^ ]+]] = s32[] constant(0)
+; CHECK-NEXT:    [[DYN_SLICE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} dynamic-slice([[P0]], [[C0]], [[C0]]), dynamic_slice_sizes={16,32}
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[DYN_SLICE]], [[P1]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
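+  // Verifies that a select between the dequantized operand and a zero
+  // constant is rewritten as a select on the FP8 values, with the zero
+  // converted to FP8 (see the CHECK lines below).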
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[16,32] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[16,32] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+      k = pred[16,32] parameter(4)
+      c = f32[] constant(0)
+      c_bcast = f32[16,32] broadcast(c), dimensions={}
+      select_a = f32[16,32] select(k, y_unscaled, c_bcast)
+      ROOT dot_a = f32[16,16] dot(x_unscaled, select_a), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+          }
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                    GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: pred[16,32]) -> f32[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P4:%[^ ]+]] = pred[16,32]{1,0} parameter(4)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(0)
+; CHECK-NEXT:    [[C0_BCAST:%[^ ]+]] = f32[16,32]{1,0} broadcast([[C0]]), dimensions={}
+; CHECK-NEXT:    [[C0_CONVERT:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} convert([[C0_BCAST]])
+; CHECK-NEXT:    [[SELECT:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} select([[P4]], [[P1]], [[C0_CONVERT]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[SELECT]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       ScaledABUnscaledDSelectNonzeroConstantF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
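+  // With a non-zero constant (1) in the select, no FP8 rewrite is expected;
+  // the pass must leave the module unchanged.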
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[16,32] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[16,32] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+      k = pred[16,32] parameter(4)
+      c = f32[] constant(1)
+      c_bcast = f32[16,32] broadcast(c), dimensions={}
+      select_a = f32[16,32] select(k, y_unscaled, c_bcast)
+      ROOT dot_a = f32[16,16] dot(x_unscaled, select_a), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+          }
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                    GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
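+  // Verifies the batched case: the batch dimension is carried into the
+  // dot_dimension_numbers and the RHS is transposed within its non-batch
+  // dimensions (see the CHECK lines below).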
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[10,16,32] parameter(0)
+      y = <<F8E4M3>>[10,32,16] parameter(1)
+      x_f32 = f32[10,16,32] convert(x)
+      y_f32 = f32[10,32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[10,16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[10,32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[10,16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[10,32,16] multiply(y_f32, y_scale_bcast)
+      ROOT out = f32[10,16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[10,16,32], {{.*}}: <<F8E4M3>>[10,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[10,16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[10,16,32]{2,1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[10,32,16]{2,1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[10,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[10,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["2"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["2"]
+; CHECK-DAG:           "lhs_batch_dimensions":["0"]
+; CHECK-DAG:           "rhs_batch_dimensions":["0"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
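+  // Verifies that the multiplication of the dot by the constant 3.0 is
+  // absorbed into the GEMM as alpha_real=3.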
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      k = f32[] constant(3.0)
+      k_bcast = f32[16,16] broadcast(k), dimensions={}
+      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      ROOT out = f32[16,16] multiply(dot_a, k_bcast)
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":3
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
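+  // Verifies that the maximum with a zero broadcast after the dot is fused as
+  // the RELU epilogue.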
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      c = f32[] constant(0)
+      c_bcast = f32[16,16] broadcast(c), dimensions={}
+      ROOT out = f32[16,16] maximum(dot_a, c_bcast)
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"RELU"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       ScaledABUnscaledDVectorBiasThenApproxGeluActivationF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
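+  // The HLO below adds a vector bias to the dot and then applies the
+  // tanh-based GELU approximation
+  // 0.5*x*(1+tanh(0.797884583*(x+0.044715*x^3))); on CUDA >= 12.4 this is
+  // expected to fuse as the BIAS_GELU epilogue.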
+  const char* hlo_text = R"(
+    HloModule test
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_bf16 = bf16[16,32] convert(x)
+      y_bf16 = bf16[32,16] convert(y)
+      x_scale = bf16[] parameter(2)
+      y_scale = bf16[] parameter(3)
+      bias = bf16[16] parameter(4)
+      x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)
+      y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)
+      dot1 = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      b_bcast = bf16[16,16] broadcast(bias), dimensions={1}
+      dot = bf16[16,16] add(dot1, b_bcast)
+      mul.0 = bf16[16,16] multiply(dot, dot)
+      mul.1 = bf16[16,16] multiply(dot, mul.0)
+      const.0 = bf16[] constant(0.044715)
+      bcast.0 = bf16[16,16] broadcast(const.0), dimensions={}
+      mul.2 = bf16[16,16] multiply(mul.1, bcast.0)
+      add.0 = bf16[16,16] add(dot, mul.2)
+      const.1 = bf16[] constant(0.797884583)
+      bcast.1 = bf16[16,16] broadcast(const.1), dimensions={}
+      mul.3 = bf16[16,16] multiply(add.0, bcast.1)
+      tanh = bf16[16,16] tanh(mul.3)
+      const.2 = bf16[] constant(1)
+      bcast.2 = bf16[16,16] broadcast(const.2), dimensions={}
+      add.2 = bf16[16,16] add(tanh, bcast.2)
+      const.3 = bf16[] constant(0.5)
+      bcast.3 = bf16[16,16] broadcast(const.3), dimensions={}
+      mul.4 = bf16[16,16] multiply(add.2, bcast.3)
+      ROOT out = bf16[16,16] multiply(dot, mul.4)
+          }
+)";
+
+  CheckFp8IfSupported(hlo_text);
+
+// Fusing GELU into FP8 cuBLAS matmuls is disabled on CUDA versions below
+// 12.4.
+#if (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: bf16[], {{.*}}: bf16[], {{.*}}: bf16[16]) -> bf16[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = bf16[] parameter(2)
+; CHECK-NEXT:    [[XS:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT:    [[P3:%[^ ]+]] = bf16[] parameter(3)
+; CHECK-NEXT:    [[XS1:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+)"
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
+      R"(; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
+)"
+#else
+      R"(; CHECK-NEXT:    [[B:%[^ ]+]] = bf16[16]{0} parameter(4)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]),
+)"
+#endif
+      R"(; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+)"
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
+      R"(; CHECK-GCN-DAG:         "epilogue":"DEFAULT"
+)"
+#else
+      R"(; CHECK-DAG:         "epilogue":"BIAS_GELU"
+)"
+#endif
+      R"(; CHECK:           }
+      )");
+#endif  // (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       ScaledABUnscaledDApproxGeluActivationF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
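+  // Same tanh-based GELU approximation as above, but without the vector bias;
+  // the expected epilogue is GELU rather than BIAS_GELU.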
+  const char* hlo_text = R"(
+    HloModule test
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_bf16 = bf16[16,32] convert(x)
+      y_bf16 = bf16[32,16] convert(y)
+      x_scale = bf16[] parameter(2)
+      y_scale = bf16[] parameter(3)
+      x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)
+      y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)
+      dot = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      mul.0 = bf16[16,16] multiply(dot, dot)
+      mul.1 = bf16[16,16] multiply(dot, mul.0)
+      const.0 = bf16[] constant(0.044715)
+      bcast.0 = bf16[16,16] broadcast(const.0), dimensions={}
+      mul.2 = bf16[16,16] multiply(mul.1, bcast.0)
+      add.0 = bf16[16,16] add(dot, mul.2)
+      const.1 = bf16[] constant(0.797884583)
+      bcast.1 = bf16[16,16] broadcast(const.1), dimensions={}
+      mul.3 = bf16[16,16] multiply(add.0, bcast.1)
+      tanh = bf16[16,16] tanh(mul.3)
+      const.2 = bf16[] constant(1)
+      bcast.2 = bf16[16,16] broadcast(const.2), dimensions={}
+      add.2 = bf16[16,16] add(tanh, bcast.2)
+      const.3 = bf16[] constant(0.5)
+      bcast.3 = bf16[16,16] broadcast(const.3), dimensions={}
+      mul.4 = bf16[16,16] multiply(add.2, bcast.3)
+      ROOT out = bf16[16,16] multiply(dot, mul.4)
+          }
+)";
+
+  CheckFp8IfSupported(hlo_text);
+
+// Fusing GELU into FP8 cuBLAS matmuls is disabled on CUDA versions below
+// 12.4.
+#if (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
+  // Currently, hipBlasLt does not support bf16 outputs for FP8 matmuls, so no
+  // fusion is performed in that case.
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: bf16[], {{.*}}: bf16[]) -> bf16[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = bf16[] parameter(2)
+; CHECK-NEXT:    [[XS:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT:    [[P3:%[^ ]+]] = bf16[] parameter(3)
+; CHECK-NEXT:    [[XS1:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+)"
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
+      R"(; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
+)"
+#else
+      R"(; CHECK-NEXT:    [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
+)"
+#endif
+      R"(; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+)"
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
+      R"(; CHECK-GCN-DAG:         "epilogue":"DEFAULT"
+)"
+#else
+      R"(; CHECK-DAG:         "epilogue":"GELU"
+)"
+#endif
+      R"(; CHECK:           }
+      )");
+#endif  // (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
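+  // Dequantization by division (rather than multiplication by the scales)
+  // must still produce the FP8 custom call.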
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] divide(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] divide(y_f32, y_scale_bcast)
+      ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
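+  // Verifies that the matrix bias (b + ones) is passed as an operand of the
+  // custom call with beta=1, and that the output buffer aliases that operand
+  // (output_to_operand_aliasing).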
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      b = f32[16,16] parameter(2)
+      one = f32[] constant(1)
+      ones = f32[16,16] broadcast(one), dimensions={}
+      b_ones = f32[16,16] add(b, ones)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(3)
+      y_scale = f32[] parameter(4)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      ROOT out = add(dot_a, b_ones)
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[16,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK:         [[C0:%[^ ]+]] = f32[16,16]{1,0} add({{.*}})
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(4)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[P2]], [[P3]], /*index=5*/[[C1]], [[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
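+  // Shapes that are unsupported by the FP8 GEMM are padded up (here to
+  // [16,32] and [16,16]) and the result is sliced back to [14,14], as checked
+  // below.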
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[14,31] parameter(0)
+      y = <<F8E4M3>>[31,14] parameter(1)
+      b = f32[14,14] parameter(2)
+      x_f32 = f32[14,31] convert(x)
+      y_f32 = f32[31,14] convert(y)
+      x_scale = f32[] parameter(3)
+      y_scale = f32[] parameter(4)
+      x_scale_bcast = f32[14,31] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[31,14] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[14,31] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[31,14] multiply(y_f32, y_scale_bcast)
+      dot_a = f32[14,14] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      ROOT out = add(dot_a, b)
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[14,31], {{.*}}: <<F8E4M3>>[31,14], {{.*}}: f32[14,14], {{.*}}: f32[], {{.*}}: f32[]) -> f32[14,14] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[14,31]{1,0} parameter(0)
+; CHECK-NEXT:    [[C0:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT:    [[P0_PADDED:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_1
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[31,14]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[14,31]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT:    [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]), padding=0_2x0_1
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[14,14]{1,0} parameter(2)
+; CHECK-NEXT:    [[C2:%[^ ]+]] = f32[] constant(0)
+; CHECK-NEXT:    [[P2_PADDED:%[^ ]+]] = f32[16,16]{1,0} pad([[P2]], [[C2]]), padding=0_2x0_2
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[P4:%[^ ]+]] = f32[] parameter(4)
+; CHECK-NEXT:    [[C3:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[DOT_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2_PADDED]], [[P3]], [[P4]], /*index=5*/[[C3]], [[C3]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:      [[DOT:%[^ ]+]] = f32[16,16]{1,0} get-tuple-element([[DOT_TUPLE]]), index=0
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[14,14]{1,0} slice([[DOT]]), slice={[0:14], [0:14]}
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
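+  // The divide by z_scale followed by the clamp and convert to FP8 is fused
+  // into the custom call: the rewriter computes 1/z_scale and passes it as a
+  // scale operand ([[P2_INV]] below). On ROCm (CHECK-GCN) the output stays
+  // f32.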
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      z_scale = f32[] parameter(2)
+      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+      dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
+      c1 = f32[] constant(-448.)
+      c1_bcast = f32[16,16] broadcast(c1), dimensions={}
+      c2 = f32[] constant(448.)
+      c2_bcast = f32[16,16] broadcast(c2), dimensions={}
+      dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+      ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> <<F8E4M3>>[16,16] {
+; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]])
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
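+  // Output scaling by division is fused even without a conversion to FP8;
+  // 1/z_scale is passed as a scale operand ([[P2_INV]] below).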
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      z_scale = f32[] parameter(2)
+      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+      dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      ROOT dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]])
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
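+  // When the output is scaled by multiplication, z_scale is forwarded to the
+  // custom call without being inverted ([[P2]] below).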
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      z_scale = f32[] parameter(2)
+      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+      dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      ROOT dot_a_scaled = f32[16,16] multiply(dot_a, z_scale_bcast)
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+// Do not fuse the output scaling when a matrix bias has been fused and the
+// result is not converted to FP8.
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      b = f32[16,16] parameter(2)
+      z_scale = f32[] parameter(3)
+      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+      dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a_bias = f32[16,16] add(dot_a, b)
+      ROOT dot_a_scaled = f32[16,16] divide(dot_a_bias, z_scale_bcast)
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1});
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[16,16]{1,0} parameter(2)
+; CHECK-NEXT:    [[C0:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]], [[C0]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK-PTX-NEXT:    [[GEMM:%[^ ]+]] = f32[16,16]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK-PTX-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-PTX-NEXT:    [[P3_BCAST:%[^ ]+]] = f32[16,16]{1,0} broadcast([[P3]]), dimensions={}
+; CHECK-PTX-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[16,16]{1,0} divide([[GEMM]], [[P3_BCAST]])
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
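+  // Combines operand scales with an output scale: on CUDA (CHECK-PTX) the FP8
+  // output variant with 1/z_scale is expected, while on ROCm (CHECK-GCN) the
+  // output remains f32 and the output scaling is not fused.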
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      z_scale = f32[] parameter(4)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
+      c1 = f32[] constant(-<<F8E4M3_AMAX>>)
+      c1_bcast = f32[16,16] broadcast(c1), dimensions={}
+      c2 = f32[] constant(<<F8E4M3_AMAX>>)
+      c2_bcast = f32[16,16] broadcast(c2), dimensions={}
+      dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+      ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <<F8E4M3>>[16,16] {
+; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT:    [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT:    [[P4:%[^ ]+]] = f32[] parameter(4)
+; CHECK-PTX-NEXT:    [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
+; CHECK-PTX-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
+; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
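+  // The output is scaled by multiplication rather than division, so the
+  // rewrite must not introduce a reciprocal (CHECK-NOT: divide).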
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      z_scale = f32[] parameter(4)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a_scaled = f32[16,16] multiply(dot_a, z_scale_bcast)
+      c1 = f32[] constant(-<<F8E4M3_AMAX>>)
+      c1_bcast = f32[16,16] broadcast(c1), dimensions={}
+      c2 = f32[] constant(<<F8E4M3_AMAX>>)
+      c2_bcast = f32[16,16] broadcast(c2), dimensions={}
+      dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+      ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+
+; CHECK-NOT:     divide
+
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
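+  // Adds a ReLU between the dot and the output scaling; the ReLU is expected
+  // to fuse as the RELU epilogue, with the output scaling handled as in
+  // ScaledABScaledDF8 above.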
+  const char* hlo_text = R"(
+    HloModule test
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      z_scale = f32[] parameter(4)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      c = f32[] constant(0)
+      c_bcast = f32[16,16] broadcast(c), dimensions={}
+      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      relu_a = f32[16,16] maximum(dot_a, c_bcast)
+      relu_a_scaled = f32[16,16] divide(relu_a, z_scale_bcast)
+      c1 = f32[] constant(-<<F8E4M3_AMAX>>)
+      c1_bcast = f32[16,16] broadcast(c1), dimensions={}
+      c2 = f32[] constant(<<F8E4M3_AMAX>>)
+      c2_bcast = f32[16,16] broadcast(c2), dimensions={}
+      relu_a_clamped = f32[16,16] clamp(c1_bcast, relu_a_scaled, c2_bcast)
+      ROOT out = <<F8E4M3>>[16,16] convert(relu_a_clamped)
+          }
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <<F8E4M3>>[16,16] {
+; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT:    [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT:    [[P4:%[^ ]+]] = f32[] parameter(4)
+; CHECK-PTX-NEXT:    [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
+; CHECK-PTX-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
+; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"RELU"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
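+  // In addition to the matrix bias and output scale, the amax reduction over
+  // abs(dot + bias) is expected to be returned as an extra f32[] element of
+  // the custom call's result tuple on CUDA (CHECK-PTX).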
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f16[] parameter(0)
+      b = f16[] parameter(1)
+      ROOT c = f16[] maximum(a, b)
+    }
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f16 = f16[16,32] convert(x)
+      y_f16 = f16[32,16] convert(y)
+      b = f16[16,16] parameter(2)
+      one = f16[] constant(1)
+      ones = f16[16,16] broadcast(one), dimensions={}
+      b_ones = f16[16,16] add(b, ones)
+      x_scale = f16[] parameter(3)
+      y_scale = f16[] parameter(4)
+      z_scale = f16[] parameter(5)
+      x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
+      z_scale_bcast = f16[16,16] broadcast(z_scale), dimensions={}
+      x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
+      y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
+      dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a_bias = f16[16,16] add(dot_a, b_ones)
+      abs_dot_a = f16[16,16] abs(dot_a_bias)
+      c0 = f16[] constant(-inf)
+      amax = f16[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply
+      dot_a_scaled = f16[16,16] divide(dot_a_bias, z_scale_bcast)
+      c1 = f16[] constant(-<<F8E4M3_AMAX>>)
+      c1_bcast = f16[16,16] broadcast(c1), dimensions={}
+      c2 = f16[] constant(<<F8E4M3_AMAX>>)
+      c2_bcast = f16[16,16] broadcast(c2), dimensions={}
+      dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+      dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+      ROOT result = (<<F8E4M3>>[16,16], f16[]) tuple(dot_a_f8, amax)
+    }
+)";
+
+  CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1});
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<<F8E4M3>>[16,16], f16[]) {
+; CHECK:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK:         [[C0:%[^ ]+]] = f16[16,16]{1,0} add({{.*}})
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(3)
+; CHECK:         [[P3:%[^ ]+]] = f16[] parameter(4)
+; CHECK:         [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX:         [[P4:%[^ ]+]] = f16[] parameter(5)
+; CHECK-PTX:       [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]),
+; CHECK-NOT:       output_to_operand_aliasing
+; CHECK-GCN:       [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f16 = f16[16,32] convert(x)
+      y_f16 = f16[32,16] convert(y)
+      b = f16[16] parameter(2)
+      b_bcast = f16[16,16] broadcast(b), dimensions={1}
+      x_scale = f16[] parameter(3)
+      y_scale = f16[] parameter(4)
+      z_scale = f16[] parameter(5)
+      x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
+      z_scale_bcast = f16[16,16] broadcast(z_scale), dimensions={}
+      x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
+      y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
+      dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a_bias = f16[16,16] add(dot_a, b_bcast)
+      dot_a_scaled = f16[16,16] divide(dot_a_bias, z_scale_bcast)
+      c1 = f16[] constant(-<<F8E4M3_AMAX>>)
+      c1_bcast = f16[16,16] broadcast(c1), dimensions={}
+      c2 = f16[] constant(<<F8E4M3_AMAX>>)
+      c2_bcast = f16[16,16] broadcast(c2), dimensions={}
+      dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+      ROOT dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+    }
+
+)";
+
+  CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1});
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> <<F8E4M3>>[16,16] {
+; CHECK:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(3)
+; CHECK-NEXT:    [[CV:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f16[] parameter(4)
+; CHECK-NEXT:    [[CV1:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT:    [[C2:%[^ ]+]] = f16[] constant(1)
+; CHECK-PTX-NEXT:    [[P4:%[^ ]+]] = f16[] parameter(5)
+; CHECK-PTX-NEXT:    [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]])
+; CHECK-PTX-NEXT:    [[CV2:%[^ ]+]] = f32[] convert([[DV]])
+; CHECK-NEXT:    [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
+; CHECK-PTX:         [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[CV2]], [[VB]]),
+; CHECK-GCN:         [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      b = f32[16] parameter(2)
+      b_bf16 = bf16[16] convert(b)
+      b_f32 = f32[16] convert(b_bf16)
+      b_bcast = f32[16,16] broadcast(b_f32), dimensions={1}
+      x_scale = f32[] parameter(3)
+      y_scale = f32[] parameter(4)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      ROOT out = f32[16,16] add(dot_a, b_bcast)
+    }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(4)
+; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[VB:%[^ ]+]] = f32[16]{0} parameter(2)
+; CHECK-NEXT:    [[VBC:%[^ ]+]] = bf16[16]{0} convert([[VB]])
+; CHECK:         [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]], [[VBC]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       ScaledABUnscaledDVectorBiasThenReluActivationF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      b = f16[16] parameter(2)
+      b_bcast = f16[16,16] broadcast(b), dimensions={1}
+      x_f32 = f16[16,32] convert(x)
+      y_f32 = f16[32,16] convert(y)
+      x_scale = f16[] parameter(3)
+      y_scale = f16[] parameter(4)
+      x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f16[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f16[32,16] multiply(y_f32, y_scale_bcast)
+      c = f16[] constant(0)
+      c_bcast = f16[16,16] broadcast(c), dimensions={}
+      dot_a0 = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a = f16[16,16] add(dot_a0, b_bcast)
+      ROOT out = f16[16,16] maximum(dot_a, c_bcast)
+    }
+)";
+
+  CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.});
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(3)
+; CHECK-NEXT:    [[CV:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f16[] parameter(4)
+; CHECK-NEXT:    [[CV1:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
+; CHECK     :    ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS_RELU"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+    ENTRY test {
+      x = <<F8E4M3>>[4,16,16] parameter(0)
+      y = <<F8E4M3>>[16,32] parameter(1)
+      b = f32[32] parameter(2)
+      b_f16 = f16[32] convert(b)
+      b_bcast = f16[4,16,32] broadcast(b_f16), dimensions={2}
+      x_f16 = f16[4,16,16] convert(x)
+      y_f16 = f16[16,32] convert(y)
+      x_scale = f16[] parameter(3)
+      y_scale = f16[] parameter(4)
+      x_scale_bcast = f16[4,16,16] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f16[16,32] broadcast(y_scale), dimensions={}
+      x_unscaled = f16[4,16,16] multiply(x_f16, x_scale_bcast)
+      x_unscaled_bitcast = f16[64,16] bitcast(x_unscaled)
+      y_unscaled = f16[16,32] multiply(y_f16, y_scale_bcast)
+      dot_a = f16[64,32] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a_bitcast = f16[4,16,32]{2,1,0} bitcast(dot_a)
+      ROOT out = f16[4,16,32] add(dot_a_bitcast, b_bcast)
+    }
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                    GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Bitcast(m::GetTupleElement(
+                                m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
+                                .WithShape(F16, {64, 32}))
+                     .WithShape(F16, {4, 16, 32})));
+
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[4,16,16], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[32], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,16,32] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[4,16,16]{2,1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[64,16]{1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(3)
+; CHECK-NEXT:    [[P2_CV:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f16[] parameter(4)
+; CHECK-NEXT:    [[P3_CV:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[B:%[^ ]+]] = f32[32]{0} parameter(2)
+; CHECK-NEXT:    [[B_F16:%[^ ]+]] = f16[32]{0} convert([[B]])
+; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[B_F16]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS"
+; CHECK:           }
+; CHECK:         [[GEMM:%[^ ]+]] = f16[64,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK:         ROOT [[OUT:%[^ ]+]] = f16[4,16,32]{2,1,0} bitcast([[GEMM]])
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       Rank3ScaledABUnscaledDVectorBiasPaddedF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+    ENTRY test {
+      x = <<F8E4M3>>[4,15,15] parameter(0)
+      y = <<F8E4M3>>[15,31] parameter(1)
+      b = f32[31] parameter(2)
+      b_f16 = f16[31] convert(b)
+      b_bcast = f16[4,15,31] broadcast(b_f16), dimensions={2}
+      x_f16 = f16[4,15,15] convert(x)
+      y_f16 = f16[15,31] convert(y)
+      x_scale = f16[] parameter(3)
+      y_scale = f16[] parameter(4)
+      x_scale_bcast = f16[4,15,15] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f16[15,31] broadcast(y_scale), dimensions={}
+      x_unscaled = f16[4,15,15] multiply(x_f16, x_scale_bcast)
+      x_unscaled_bitcast = f16[60,15] bitcast(x_unscaled)
+      y_unscaled = f16[15,31] multiply(y_f16, y_scale_bcast)
+      dot_a = f16[60,31] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a_bitcast = f16[4,15,31]{2,1,0} bitcast(dot_a)
+      ROOT out = f16[4,15,31] add(dot_a_bitcast, b_bcast)
+    }
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                    GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Bitcast(m::Slice(m::GetTupleElement(
+                                  m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
+                                  .WithShape(F16, {64, 32}))
+                         .WithShape(F16, {60, 31}))
+              .WithShape(F16, {4, 15, 31})));
+
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[4,15,15], {{.*}}: <<F8E4M3>>[15,31], {{.*}}: f32[31], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,15,31] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[4,15,15]{2,1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[60,15]{1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT:    [[P0_PAD:%[^ ]+]] = <<F8E4M3>>[64,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_4x0_1
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[15,31]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[31,15]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[C2:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT:    [[P1_PAD:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(3)
+; CHECK-NEXT:    [[P2_CV:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f16[] parameter(4)
+; CHECK-NEXT:    [[P3_CV:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[B:%[^ ]+]] = f32[31]{0} parameter(2)
+; CHECK-NEXT:    [[B_F16:%[^ ]+]] = f16[31]{0} convert([[B]])
+; CHECK-NEXT:    [[C3:%[^ ]+]] = f16[] constant(0)
+; CHECK-NEXT:    [[P2_PAD:%[^ ]+]] = f16[32]{0} pad([[B_F16]], [[C3]]), padding=0_1
+; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[P2_PAD]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"BIAS"
+; CHECK:           }
+; CHECK:          [[GEMM:%[^ ]+]] = f16[64,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK-NEXT:     [[SLICE:%[^ ]+]] = f16[60,31]{1,0} slice([[GEMM]]), slice={[0:60], [0:31]}
+; CHECK-NEXT:     ROOT [[OUT:%[^ ]+]] = f16[4,15,31]{2,1,0} bitcast([[SLICE]])
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+    ENTRY test {
+      x = <<F8E4M3>>[4,16,16] parameter(0)
+      y = <<F8E4M3>>[16,32] parameter(1)
+      b = f32[4,16,32] parameter(2)
+      x_f32 = f32[4,16,16] convert(x)
+      y_f32 = f32[16,32] convert(y)
+      x_scale = f32[] parameter(3)
+      y_scale = f32[] parameter(4)
+      x_scale_bcast = f32[4,16,16] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[4,16,16] multiply(x_f32, x_scale_bcast)
+      x_unscaled_bitcast = f32[64,16] bitcast(x_unscaled)
+      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+      dot_a = f32[64,32] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a_bitcast = f32[4,16,32]{2,1,0} bitcast(dot_a)
+      ROOT out = f32[4,16,32] add(dot_a_bitcast, b)
+    }
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                    GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Bitcast(m::GetTupleElement(
+                                m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
+                                .WithShape(F32, {64, 32}))
+                     .WithShape(F32, {4, 16, 32})));
+
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[4,16,16], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[4,16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[4,16,32] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[4,16,16]{2,1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[64,16]{1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[B:%[^ ]+]] = f32[4,16,32]{2,1,0} parameter(2)
+; CHECK-NEXT:    [[B_BITCAST:%[^ ]+]] = f32[64,32]{1,0} bitcast([[B]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(4)
+; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f32[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:         [[GEMM:%[^ ]+]] = f32[64,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[4,16,32]{2,1,0} bitcast([[GEMM]])
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       Rank3ScaledABUnscaledDMatrixBiasPaddedF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+    ENTRY test {
+      x = <<F8E4M3>>[3,15,15] parameter(0)
+      y = <<F8E4M3>>[15,31] parameter(1)
+      b = f32[3,15,31] parameter(2)
+      x_f32 = f32[3,15,15] convert(x)
+      y_f32 = f32[15,31] convert(y)
+      x_scale = f32[] parameter(3)
+      y_scale = f32[] parameter(4)
+      x_scale_bcast = f32[3,15,15] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[15,31] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[3,15,15] multiply(x_f32, x_scale_bcast)
+      x_unscaled_bitcast = f32[45,15] bitcast(x_unscaled)
+      y_unscaled = f32[15,31] multiply(y_f32, y_scale_bcast)
+      dot_a = f32[45,31] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a_bitcast = f32[3,15,31]{2,1,0} bitcast(dot_a)
+      ROOT out = f32[3,15,31] add(dot_a_bitcast, b)
+    }
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                    GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Bitcast(m::Slice(m::GetTupleElement(
+                                  m::CustomCall({"__cublas$lt$matmul$f8"}), 0)
+                                  .WithShape(F32, {48, 32}))
+                         .WithShape(F32, {45, 31}))
+              .WithShape(F32, {3, 15, 31})));
+
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[3,15,15], {{.*}}: <<F8E4M3>>[15,31], {{.*}}: f32[3,15,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[3,15,31] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[3,15,15]{2,1,0} parameter(0)
+; CHECK-NEXT:    [[P0_BITCAST:%[^ ]+]] = <<F8E4M3>>[45,15]{1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[C1:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT:    [[P0_PADDED:%[^ ]+]] = <<F8E4M3>>[48,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_3x0_1
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[15,31]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[31,15]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[C2:%[^ ]+]] = <<F8E4M3>>[] constant(0)
+; CHECK-NEXT:    [[P1_PADDED:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1
+; CHECK-NEXT:    [[B:%[^ ]+]] = f32[3,15,31]{2,1,0} parameter(2)
+; CHECK-NEXT:    [[B_BITCAST:%[^ ]+]] = f32[45,31]{1,0} bitcast([[B]])
+; CHECK-NEXT:    [[C3:%[^ ]+]] = f32[] constant(0)
+; CHECK-NEXT:    [[P2_PADDED:%[^ ]+]] = f32[48,32]{1,0} pad([[B_BITCAST]], [[C3]]), padding=0_3x0_1
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(4)
+; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK-NEXT:      [[GEMM:%[^ ]+]] = f32[48,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK-NEXT:      [[SLICE:%[^ ]+]] = f32[45,31]{1,0} slice([[GEMM]]), slice={[0:45], [0:31]}
+; CHECK-NEXT:      ROOT [[OUT:%[^ ]+]] = f32[3,15,31]{2,1,0} bitcast([[SLICE]])
+      )");
+}
+
+// Do not fuse the matrix bias when there is a slice that does not chop off
+// the ends of dimensions.
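+// For contrast, in the padded tests above the slice only trims the zero
+// padding of the GEMM output (e.g. slice={[0:45], [0:31]}) and the bias is
+// still fused ("beta":1); here the slice starts at offset 16, so the bias add
+// is expected to stay outside the custom call and "beta" remains 0.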
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       ScaledABUnscaledDMatrixBiasWithSliceF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+    ENTRY test {
+      x = <<F8E4M3>>[48,16] parameter(0)
+      y = <<F8E4M3>>[16,32] parameter(1)
+      b = f32[32,16] parameter(2)
+      x_f32 = f32[48,16] convert(x)
+      y_f32 = f32[16,32] convert(y)
+      x_scale = f32[] parameter(3)
+      y_scale = f32[] parameter(4)
+      x_scale_bcast = f32[48,16] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[48,16] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+      dot_a = f32[48,32] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a_sliced = f32[32,16] slice(dot_a), slice={[16:48], [16:32]}
+      ROOT out = f32[32,16] add(dot_a_sliced, b)
+    }
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                    GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_TRUE(changed);
+
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[48,16], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[32,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[48,16]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(4)
+; CHECK-NEXT:    [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:           [[GEMM:%[^_]+]] = f32[48,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+; CHECK-NEXT:      [[SLICE:%[^ ]+]] = f32[32,16]{1,0} slice([[GEMM]]), slice={[16:48], [16:32]}
+; CHECK-NEXT:      [[B:%[^ ]+]] = f32[32,16]{1,0} parameter(2)
+; CHECK-NEXT:      ROOT [[OUT:%[^ ]+]] = f32[32,16]{1,0} add([[SLICE]], [[B]])
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  absl::string_view hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[16,32] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[16,32] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+      all_gather = f32[16,64]{1,0} all-gather(x_unscaled), channel_id=1, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+      all_gather1 = f32[64,32]{1,0} all-gather(y_unscaled), channel_id=2, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true
+      ROOT dot_a = f32[16,32] dot(all_gather, all_gather1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+)";
+
+  HloModuleConfig config = GetModuleConfigForTest();
+  config.set_use_spmd_partitioning(true);
+  config.set_num_partitions(8);
+
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,32] {
+; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK:         [[AG:%[^ ]+]] = <<F8E4M3>>[16,64]{1,0} all-gather([[P0]]), {{[^ ]+}}
+; CHECK:         [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK:         [[AG1:%[^ ]+]] = <<F8E4M3>>[64,32]{1,0} all-gather([[P1]]), {{[^ ]+}}
+; CHECK:         [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[32,64]{1,0} transpose([[AG1]]), dimensions={1,0}
+; CHECK:         [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK:         [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK:         [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK:         [[GEMM_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AG]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:           ROOT [[GEMM:%[^_]+]] = f32[16,32]{1,0} get-tuple-element([[GEMM_TUPLE]]), index=0
+      )",
+      nullptr, &config);
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  absl::string_view hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[16,32] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[16,32] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+      all_to_all = f32[16,32]{1,0} all-to-all(x_unscaled), channel_id=1, replica_groups={{0,1,2,3},{4,5,6,7}}, dimensions={0}
+      ROOT dot_a = f32[16,16] dot(all_to_all, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+    }
+)";
+
+  HloModuleConfig config = GetModuleConfigForTest();
+  config.set_use_spmd_partitioning(true);
+  config.set_num_partitions(8);
+
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK:         [[AA:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} all-to-all([[P0]]), {{[^ ]+}}
+; CHECK:         [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK:         [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK:         [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK:         [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK:         [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )",
+      nullptr, &config);
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       ScaledABUnscaledDWithCollectivePermuteF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  absl::string_view hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[16,32] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[16,32] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
+      collective_permute = f32[16,32]{1,0} collective-permute(x_unscaled), source_target_pairs={{0,0}, {1,1}, {2,4}, {3,5}, {4,2}, {5,3}, {6,6}, {7,7}}
+      ROOT dot_a = f32[16,16] dot(collective_permute, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+    }
+)";
+
+  HloModuleConfig config = GetModuleConfigForTest();
+  config.set_use_spmd_partitioning(true);
+  config.set_num_partitions(8);
+
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK:         [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK:         [[AA:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} collective-permute([[P0]]), {{[^ ]+}}
+; CHECK:         [[P1:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(1)
+; CHECK:         [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK:         [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK:         [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK:         [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )",
+      nullptr, &config);
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       ScaledABUnscaledDMatrixBiasThenVectorBiasF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f16 = f16[16,32] convert(x)
+      y_f16 = f16[32,16] convert(y)
+      b = f16[16] parameter(2)
+      b_bcast = f16[16,16] broadcast(b), dimensions={1}
+      b2 = f16[16,16] parameter(3)
+      x_scale = f16[] parameter(4)
+      y_scale = f16[] parameter(5)
+      x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
+      y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
+      dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      dot_a_bias1 = f16[16,16] add(dot_a, b2)
+      ROOT dot_a_bias = f16[16,16] add(dot_a_bias1, b_bcast)
+    }
+
+)";
+
+  CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.});
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL:   ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] {
+; CHECK-DAG:     [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
+; CHECK-NEXT:    [[MB:%[^ ]+]] = f16[16,16]{1,0} parameter(3)
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(4)
+; CHECK-NEXT:    [[CV0:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f16[] parameter(5)
+; CHECK-NEXT:    [[CV1:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK:         [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK:         [[GEMMOUT_TUPLE:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[MB]], [[CV0]], [[CV1]], /*index=5*/[[C1]], [[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":1
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+; CHECK:         [[GEMMOUT:%[^ ]+]] = f16[16,16]{1,0} get-tuple-element([[GEMMOUT_TUPLE]]), index=0
+; CHECK:         [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
+; CHECK:         [[VBC:%[^ ]+]] = f16[16,16]{1,0} broadcast([[VB]]), dimensions={1}
+; CHECK:         ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} add([[GEMMOUT]], [[VBC]])
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] maximum(a, b)
+    }
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      z_scale = f32[] parameter(4)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      abs_dot_a = f32[16,16] abs(dot_a)
+      c0 = f32[] constant(-inf)
+      amax = f32[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply
+      dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast)
+      c1 = f32[] constant(-<<F8E4M3_AMAX>>)
+      c1_bcast = f32[16,16] broadcast(c1), dimensions={}
+      c2 = f32[] constant(<<F8E4M3_AMAX>>)
+      c2_bcast = f32[16,16] broadcast(c2), dimensions={}
+      dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+      dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+      ROOT out = (<<F8E4M3>>[16,16], f32[]) tuple(dot_a_f8, amax)
+    }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<<F8E4M3>>[16,16], f32[]) {
+; CHECK:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT:    [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT:    [[P4:%[^ ]+]] = f32[] parameter(4)
+; CHECK-PTX-NEXT:    [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
+; CHECK-PTX-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
+; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       ScaledABScaledDWithDAmaxF8WithF16Intermediates) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  // This is the same as ScaledABScaledDWithDAmaxF8, but uses F16 intermediate
+  // values instead of F32 intermediate values.
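+  // The F16 scales are expected to appear as F32 operands of the custom call,
+  // via the [[P2_CONVERT]]/[[P3_CONVERT]] converts in the CHECK pattern below.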
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f16[] parameter(0)
+      b = f16[] parameter(1)
+      ROOT c = f16[] maximum(a, b)
+    }
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f16 = f16[16,32] convert(x)
+      y_f16 = f16[32,16] convert(y)
+      x_scale = f16[] parameter(2)
+      y_scale = f16[] parameter(3)
+      z_scale = f16[] parameter(4)
+      x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={}
+      z_scale_bcast = f16[16,16] broadcast(z_scale), dimensions={}
+      x_unscaled = f16[16,32] multiply(x_f16, x_scale_bcast)
+      y_unscaled = f16[32,16] multiply(y_f16, y_scale_bcast)
+      dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      abs_dot_a = f16[16,16] abs(dot_a)
+      c0 = f16[] constant(-inf)
+      amax = f16[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply
+      dot_a_scaled = f16[16,16] divide(dot_a, z_scale_bcast)
+      c1 = f16[] constant(-<<F8E4M3_AMAX>>)
+      c1_bcast = f16[16,16] broadcast(c1), dimensions={}
+      c2 = f16[] constant(<<F8E4M3_AMAX>>)
+      c2_bcast = f16[16,16] broadcast(c2), dimensions={}
+      dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+      dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+      ROOT out = (<<F8E4M3>>[16,16], f16[]) tuple(dot_a_f8, amax)
+    }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<<F8E4M3>>[16,16], f16[]) {
+; CHECK:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f16[] parameter(2)
+; CHECK-NEXT:    [[P2_CONVERT:%[^ ]+]] = f32[] convert([[P2]])
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f16[] parameter(3)
+; CHECK-NEXT:    [[P3_CONVERT:%[^ ]+]] = f32[] convert([[P3]])
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT:    [[C2:%[^ ]+]] = f16[] constant(1)
+; CHECK-PTX-NEXT:    [[P4:%[^ ]+]] = f16[] parameter(4)
+; CHECK-PTX-NEXT:    [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]])
+; CHECK-PTX-NEXT:    [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]])
+; CHECK-PTX-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[P4_INV_CONVERT]]),
+; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       ScaledABScaledDReluActivationWithDAmaxF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    apply {
+      a = f32[] parameter(0)
+      b = f32[] parameter(1)
+      ROOT c = f32[] maximum(a, b)
+    }
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E4M3>>[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      z_scale = f32[] parameter(4)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      czero = f32[] constant(0)
+      czero_bcast = f32[16,16] broadcast(czero), dimensions={}
+      dot_a_relu = f32[16,16] maximum(dot_a, czero_bcast)
+      c0 = f32[] constant(-inf)
+      amax = f32[] reduce(dot_a_relu, c0), dimensions={0,1}, to_apply=apply
+      dot_a_scaled = f32[16,16] divide(dot_a_relu, z_scale_bcast)
+      c1 = f32[] constant(-<<F8E4M3_AMAX>>)
+      c1_bcast = f32[16,16] broadcast(c1), dimensions={}
+      c2 = f32[] constant(<<F8E4M3_AMAX>>)
+      c2_bcast = f32[16,16] broadcast(c2), dimensions={}
+      dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast)
+      dot_a_f8 = <<F8E4M3>>[16,16] convert(dot_a_clamped)
+      ROOT out = (<<F8E4M3>>[16,16], f32[]) tuple(dot_a_f8, amax)
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<<F8E4M3>>[16,16], f32[]) {
+; CHECK:    [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT:    [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT:    [[P4:%[^ ]+]] = f32[] parameter(4)
+; CHECK-PTX-NEXT:    [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
+; CHECK-PTX-NEXT:    [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
+; CHECK-GCN-NEXT:    [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"RELU"
+; CHECK:           }
+      )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* raw_hlo_template = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[1600,3200] parameter(0)
+      y = <<F8E4M3>>[3200,1600] parameter(1)
+      x_f32 = f32[1600,3200] convert(x)
+      y_f32 = f32[3200,1600] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[1600,3200] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[3200,1600] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[1600,3200] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[3200,1600] multiply(y_f32, y_scale_bcast)
+      ROOT out = f32[1600,1600] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={<<precision>>,<<precision>>}
+          }
+)";
+
+  std::string hlo_template =
+      absl::StrReplaceAll(raw_hlo_template, replacements_);
+
+  absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+  replacements["<<precision>>"] = "default";
+  const auto hlo_text_default = absl::StrReplaceAll(hlo_template, replacements);
+  EXPECT_TRUE(RunAndCompare(hlo_text_default, ErrorSpec{1e-3, 1e-3}));
+
+  replacements["<<precision>>"] = "highest";
+  const auto hlo_text_highest = absl::StrReplaceAll(hlo_template, replacements);
+  EXPECT_TRUE(RunAndCompare(hlo_text_highest, ErrorSpec{1e-4, 1e-4}));
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  std::array<std::array<absl::string_view, 7>, 32> combinations;
+  int i = 0;
+
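+  // Each combination is {lcd, rcd, a_shape, b_shape, a_layout, b_layout,
+  // output_layout}, substituted into the <<...>> placeholders of the HLO
+  // template below.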
+  for (bool d_is_col : {false, true}) {
+    for (bool a_is_col : {false, true}) {
+      for (bool b_is_col : {false, true}) {
+        for (int lhs_contracting_dim : {0, 1}) {
+          for (int rhs_contracting_dim : {0, 1}) {
+            const absl::string_view lcd =
+                lhs_contracting_dim == 1 ? "{1}" : "{0}";
+            const absl::string_view rcd =
+                rhs_contracting_dim == 1 ? "{1}" : "{0}";
+            const absl::string_view a_shape =
+                lhs_contracting_dim == 1 ? "[64,32]" : "[32,64]";
+            const absl::string_view b_shape =
+                rhs_contracting_dim == 0 ? "[32,16]" : "[16,32]";
+            const absl::string_view a_layout = a_is_col ? "{0,1}" : "{1,0}";
+            const absl::string_view b_layout = b_is_col ? "{0,1}" : "{1,0}";
+            const absl::string_view output_layout =
+                d_is_col ? "{0,1}" : "{1,0}";
+            combinations[i++] = std::array{
+                lcd, rcd, a_shape, b_shape, a_layout, b_layout, output_layout};
+          }
+        }
+      }
+    }
+  }
+
+  const char* hlo_template = R"(
+      HloModule test
+    ENTRY test {
+      x = <<F8E4M3>><<Ashape>><<Alayout>> parameter(0)
+      x_f32 = f32<<Ashape>><<Alayout>> convert(x)
+      x_scale = f32[] parameter(2)
+      x_scale_bcast = f32<<Ashape>> broadcast(x_scale), dimensions={}
+      x_unscaled = f32<<Ashape>> multiply(x_f32, x_scale_bcast)
+      y = <<F8E4M3>><<Bshape>><<Blayout>> parameter(1)
+      y_f32 = f32<<Bshape>><<Blayout>> convert(y)
+      y_scale = f32[] parameter(3)
+      y_scale_bcast = f32<<Bshape>> broadcast(y_scale), dimensions={}
+      y_unscaled = f32<<Bshape>> multiply(y_f32, y_scale_bcast)
+      ROOT out = f32[64,16]<<Olayout>> dot(x_unscaled, y_unscaled), lhs_contracting_dims=<<Lcd>>, rhs_contracting_dims=<<Rcd>>
+    }
+      )";
+  for (const auto& combination : combinations) {
+    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+    replacements["<<Lcd>>"] = std::get<0>(combination);
+    replacements["<<Rcd>>"] = std::get<1>(combination);
+    replacements["<<Ashape>>"] = std::get<2>(combination);
+    replacements["<<Bshape>>"] = std::get<3>(combination);
+    replacements["<<Alayout>>"] = std::get<4>(combination);
+    replacements["<<Blayout>>"] = std::get<5>(combination);
+    replacements["<<Olayout>>"] = std::get<6>(combination);
+    const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
+    CheckFp8IfSupported(hlo_text);
+
+    RunAndFilecheckHloRewrite(
+        hlo_text,
+        GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                     GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+        R"(
+    ; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+          )");
+  }
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest,
+       ScaledABUnscaledDF8ParameterizedBatched) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  // TODO(wenscarl): For batched matmul, not all combinations of A, B and
+  // output layouts get pattern-matched successfully to the FP8 custom call.
+  // Only a handful of cases are tested here.
+  std::array<std::array<std::string, 7>, 32> combinations;
+  std::string lcd, rcd, a_shape, b_shape, a_layout, b_layout, o_layout;
+  int i = 0;
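+  // As above, each combination holds the contracting dimensions, shapes and
+  // layouts substituted into the <<...>> placeholders of the template below.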
+  for (bool o_is_col : {false, true}) {
+    for (int lhs_contracting_dim : {2, 1}) {
+      for (int rhs_contracting_dim : {2, 1}) {
+        lcd = lhs_contracting_dim == 2 ? "{2}" : "{1}";
+        rcd = rhs_contracting_dim == 2 ? "{2}" : "{1}";
+        a_shape = lhs_contracting_dim == 2 ? "[2,64,32]" : "[2,32,64]";
+        b_shape = rhs_contracting_dim == 1 ? "[2,32,16]" : "[2,16,32]";
+        o_layout = o_is_col ? "{2, 0, 1}" : "{2, 1, 0}";
+        for (std::string a_layout : {"{2,1,0}", "{1,2,0}"}) {
+          for (std::string b_layout : {"{2,1,0}", "{1,2,0}"}) {
+            combinations[i++] = std::array{lcd,      rcd,      a_shape, b_shape,
+                                           a_layout, b_layout, o_layout};
+          }
+        }
+      }
+    }
+  }
+
+  const char* hlo_template = R"(
+      HloModule m
+ENTRY f {
+  x_q = <<F8E4M3>><<Ashape>><<Alayout>> parameter(0)
+  x_scale = f32[] parameter(2)
+  x_scale_broadcast = f32<<Ashape>><<Alayout>> broadcast(x_scale), dimensions={}
+  x_q_convert = f32<<Ashape>><<Alayout>> convert(x_q)
+  x_qdq = f32<<Ashape>><<Alayout>> multiply(x_q_convert, x_scale_broadcast)
+
+  y_q = <<F8E4M3>><<Bshape>><<Blayout>> parameter(1)
+  y_scale = f32[] parameter(3)
+  y_scale_broadcast = f32<<Bshape>><<Blayout>> broadcast(y_scale), dimensions={}
+  y_q_convert = f32<<Bshape>><<Blayout>> convert(y_q)
+  y_qdq = f32<<Bshape>><<Blayout>> multiply(y_q_convert, y_scale_broadcast)
+
+  ROOT out = f32[2,64,16]<<Olayout>> dot(x_qdq, y_qdq), lhs_batch_dims={0}, lhs_contracting_dims=<<Lcd>>, rhs_batch_dims={0}, rhs_contracting_dims=<<Rcd>>
+}
+     )";
+
+  for (const auto& combination : combinations) {
+    absl::flat_hash_map<std::string, std::string> replacements;
+    replacements["<<Lcd>>"] = std::get<0>(combination);
+    replacements["<<Rcd>>"] = std::get<1>(combination);
+    replacements["<<Ashape>>"] = std::get<2>(combination);
+    replacements["<<Bshape>>"] = std::get<3>(combination);
+    replacements["<<Alayout>>"] = std::get<4>(combination);
+    replacements["<<Blayout>>"] = std::get<5>(combination);
+    replacements["<<Olayout>>"] = std::get<6>(combination);
+
+    const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
+    CheckFp8IfSupported(hlo_text);
+
+    RunAndFilecheckHloRewrite(
+        hlo_text,
+        GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                     GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+        R"(
+    ; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+          )");
+  }
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = <<F8E4M3>>[16,32] parameter(0)
+      y = <<F8E5M2>>[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+    ; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+          )");
+}
+
+TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) {
+#if GOOGLE_CUDA && CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif
+
+#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
+#endif  // TF_ROCM_VERSION < 60000
+
+  // Test that FNUZ FP8 gemms are not rewritten, as cuBLAS does not support
+  // them.
+  const char* hlo_text = R"(
+    HloModule test
+
+    ENTRY test {
+      x = f8e4m3fnuz[16,32] parameter(0)
+      y = f8e4m3fnuz[32,16] parameter(1)
+      x_f32 = f32[16,32] convert(x)
+      y_f32 = f32[32,16] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[16,32] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[32,16] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[16,32] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast)
+      ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+          }
+)";
+#if GOOGLE_CUDA
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                    GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only});
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  EXPECT_FALSE(changed);
+#endif
+#if TENSORFLOW_USE_ROCM
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2}));
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fnuz[16,32], {{.*}}: f8e4m3fnuz[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f8e4m3fnuz[16,32]{1,0} parameter(0)
+; CHECK-PTX-NEXT:    [[P0_CV:%[^ ]+]] = f32[16,32]{1,0} convert([[P0]])
+; CHECK-PTX-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-PTX-NEXT:    [[P2_B:%[^ ]+]] = f32[16,32]{1,0} broadcast([[P2]]), dimensions={}
+; CHECK-PTX-NEXT:    [[P0_UNSCALED:%[^ ]+]] = f32[16,32]{1,0} multiply([[P0_CV]], [[P2_B]])
+; CHECK-PTX-NEXT:    [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1)
+; CHECK-PTX-NEXT:    [[P1_CV:%[^ ]+]] = f32[32,16]{1,0} convert([[P1]])
+; CHECK-PTX-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-PTX-NEXT:    [[P3_B:%[^ ]+]] = f32[32,16]{1,0} broadcast([[P3]]), dimensions={}
+; CHECK-PTX-NEXT:    [[P1_UNSCALED:%[^ ]+]] = f32[32,16]{1,0} multiply([[P1_CV]], [[P3_B]])
+; CHECK-PTX-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0_UNSCALED]], [[P1_UNSCALED]]),
+; CHECK-GCN-NEXT:    [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1)
+; CHECK-GCN-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
+; CHECK-GCN-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-GCN-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-GCN-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
+; CHECK-GCN:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-PTX-DAG:           "rhs_contracting_dimensions":["0"]
+; CHECK-GCN-DAG:           "rhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "lhs_batch_dimensions":[]
+; CHECK-DAG:           "rhs_batch_dimensions":[]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+#endif
+}
+
+INSTANTIATE_TEST_SUITE_P(Fp8CublasTestsBothLegacyAndLt,
+                         ParameterizedFp8GemmRewriteTest, ::testing::Bool());
+#endif
+
+TEST_F(GemmRewriteTest, NoFuseBiasBroadcast) {
+  const char* hlo = R"(
+
+HloModule module
+
+ENTRY main.10 {
+  Arg_0.1 = f16[384,128]{1,0} parameter(0)
+  Arg_1.2 = f16[128,256]{1,0} parameter(1)
+  dot.4 = f16[384,256]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  Arg_2.3 = f16[256]{0} parameter(2)
+  reshape.5 = f16[1,256]{1,0} reshape(Arg_2.3)
+  broadcast.6 = f16[1,256]{1,0} broadcast(reshape.5), dimensions={0,1}
+  reshape.7 = f16[256]{0} reshape(broadcast.6)
+  broadcast.8 = f16[384,256]{1,0} broadcast(reshape.7), dimensions={1}
+  ROOT add.9 = f16[384,256]{1,0} add(dot.4, broadcast.8)
+})";
+
+  MatchOptimizedHlo(hlo, R"(
+// CHECK: "beta":0
+  )");
+}
+
+TEST_F(GemmRewriteTest, ReduceOfBatchDot) {
+  absl::string_view hlo_string =
+      R"(
+HloModule test
+
+region_5.50 {
+  Arg_0.51 = f32[] parameter(0)
+  Arg_1.52 = f32[] parameter(1)
+  ROOT add.53 = f32[] add(Arg_0.51, Arg_1.52)
+}
+
+ENTRY main {
+  p0 = bf16[3,32,3,13]{3,2,1,0} parameter(0)
+  p1 = bf16[3,32,3,64]{3,2,1,0} parameter(1)
+  dot.95 = bf16[3,3,13,64]{3,2,1,0} dot(p0, p1), lhs_batch_dims={0,2}, lhs_contracting_dims={1}, rhs_batch_dims={0,2}, rhs_contracting_dims={1}, operand_precision={highest,highest}
+  transpose.96 = bf16[3,64,3,13]{1,3,2,0} transpose(dot.95), dimensions={0,3,1,2}
+  convert.101 = f32[3,64,3,13]{1,3,2,0} convert(transpose.96)
+  constant.66 = f32[] constant(0.0)
+  ROOT reduce.102 = f32[3,64,13]{2,1,0} reduce(convert.101, constant.66), dimensions={2}, to_apply=region_5.50
+}
+)";
+  // Make sure the dot is lowered to a custom call. There is an algebraic
+  // simplifier rewrite that could turn the dot into a non-canonical dot late
+  // in the pipeline, which would make it unsupported by the GemmRewriter.
+  MatchOptimizedHlo(hlo_string, R"(
+  // CHECK: custom_call_target="__cublas$gemm"
+  )");
+}
+
+TEST_F(GemmRewriteTest, DotWithBias) {
+  const char* hlo = R"(
+      HloModule m
+
+      ENTRY main {
+        p0 = f32[1024,1024] parameter(0)
+        p1 = f32[1024,1024] parameter(1)
+        p2 = f32[1024,1024] parameter(2)
+        p3 = f32[1024,1024] parameter(3)
+        dot0 = f32[1024,1024] dot(p0, p1),
+            lhs_contracting_dims={1}, rhs_contracting_dims={0}
+        dot1 = f32[1024,1024] dot(p2, p3),
+            lhs_contracting_dims={1}, rhs_contracting_dims={0}
+        ROOT root = f32[1024,1024] add(dot0, dot1)
+  })";
+
+  const char* expected = R"()
+    // CHECK: %[[P0:.*]] = f32[1024,1024]{1,0} parameter(0)
+    // CHECK: %[[P1:.*]] = f32[1024,1024]{1,0} parameter(1)
+    // CHECK: %[[P2:.*]] = f32[1024,1024]{1,0} parameter(2)
+    // CHECK: %[[P3:.*]] = f32[1024,1024]{1,0} parameter(3)
+    // CHECK: %[[TUPLE0:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P2]], %[[P3]])
+    // CHECK: %[[S0:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE0]]), index=0
+    // CHECK: %[[TUPLE1:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]], %[[S0]])
+    // CHECK: ROOT %[[S1:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE1]]), index=0
+  })";
+
+  RunAndFilecheckHloRewrite(
+      hlo,
+      GemmRewriter(
+          se::CudaComputeCapability{}, /*toolkit_version=*/0,
+          GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only}),
+      expected);
+}
+
+TEST_F(GemmRewriteTest, DotWithoutBias) {
+  const char* hlo = R"(
+      HloModule m
+
+      ENTRY main {
+        p0 = f32[1024,1024] parameter(0)
+        p1 = f32[1024,1024] parameter(1)
+        p2 = f32[1024,1024] parameter(2)
+        p3 = f32[1024,1024] parameter(3)
+        dot0 = f32[1024,1024] dot(p0, p1),
+            lhs_contracting_dims={1}, rhs_contracting_dims={0}
+        dot1 = f32[1024,1024] dot(p2, p3),
+            lhs_contracting_dims={1}, rhs_contracting_dims={0}
+        ROOT root = f32[1024,1024] add(dot0, dot1)
+  })";
+
+  const char* expected = R"()
+    // CHECK: %[[P0:.*]] = f32[1024,1024]{1,0} parameter(0)
+    // CHECK: %[[P1:.*]] = f32[1024,1024]{1,0} parameter(1)
+    // CHECK: %[[TUPLE0:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]])
+    // CHECK: %[[S0:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE0]]), index=0
+    // CHECK: %[[P2:.*]] = f32[1024,1024]{1,0} parameter(2)
+    // CHECK: %[[P3:.*]] = f32[1024,1024]{1,0} parameter(3)
+    // CHECK: %[[TUPLE1:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P2]], %[[P3]])
+    // CHECK: %[[S1:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE1]]), index=0
+    // CHECK: ROOT %[[S2:.*]] = f32[1024,1024]{1,0} add(%[[S0]], %[[S1]])
+  })";
+
+  RunAndFilecheckHloRewrite(
+      hlo,
+      GemmRewriter(se::CudaComputeCapability{}, /*toolkit_version=*/0,
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only,
+                                       GemmRewriterOptions::BiasMode::kNoBias}),
+      expected);
+}
+
+TEST_F(CublasLtGemmRewriteTest, CublasLtSuccessfullyMatchesLargeC64Lhs) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  p0 = c64[2000,3000,3]{2,1,0} parameter(0)
+  p1 = c64[3,6]{1,0} parameter(1)
+  ROOT dot = c64[2000,3000,6]{2,1,0} dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+}
+)";
+  // Large lhs is fine for cuBLASlt.
+  MatchOptimizedHlo(hlo_text,
+                    R"(; CHECK: custom_call_target="__cublas$lt$matmul")");
+}
+
+TEST_F(CublasLtGemmRewriteTest, CublasLtOnlyMatchesLargeC64RhsPostAmpere) {
+  const char* hlo_text = R"(
+HloModule test
+
+ENTRY test {
+  p0 = c64[6,3]{1,0} parameter(0)
+  p1 = c64[3,2000,3000]{2,1,0} parameter(1)
+  ROOT dot = c64[6,2000,3000]{2,1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+  if (HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) {
+    // From Ampere onwards, cuBLASlt supports large rhs.
+    MatchOptimizedHlo(hlo_text,
+                      R"(; CHECK: custom_call_target="__cublas$lt$matmul")");
+  } else {
+    // An rhs whose combined non-contracting dimensions exceed 4194240 is not
+    // supported for the C64 type.
+    MatchOptimizedHlo(
+        hlo_text, R"(; CHECK-NOT: custom_call_target="__cublas$lt$matmul")");
+  }
+}
+
+class GemmRewriteAllocationTest : public GpuCodegenTest {
+ public:
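+  // Compiles `hlo`, runs the backend, and checks that the resulting
+  // GpuExecutable uses exactly `expected_number_of_allocations` buffer
+  // allocations.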
+  void CheckNumberOfAllocations(const std::string& hlo,
+                                int expected_number_of_allocations) {
+    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+                            GetOptimizedModule(hlo));
+    if (allocator_ == nullptr) {
+      allocator_ = std::make_unique<se::StreamExecutorMemoryAllocator>(
+          backend().default_stream_executor());
+    }
+    TF_ASSERT_OK_AND_ASSIGN(
+        std::unique_ptr<Executable> executable,
+        backend().compiler()->RunBackend(std::move(optimized_module),
+                                         backend().default_stream_executor(),
+                                         allocator_.get()));
+    GpuExecutable* gpu_executable =
+        static_cast<GpuExecutable*>(executable.get());
+    absl::Span<const BufferAllocation> allocations =
+        gpu_executable->GetAllocations();
+    ASSERT_EQ(allocations.size(), expected_number_of_allocations);
+  }
+
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
+    // Make sure the rewriter does not skip the rewrite because the GEMM is
+    // too small.
+    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
+    return debug_options;
+  }
+
+ private:
+  std::unique_ptr<se::DeviceMemoryAllocator> allocator_;
+};
+
+TEST_F(GemmRewriteAllocationTest, SharedBufferAssignment) {
+  const char* hlo_text = R"(
+HloModule SharedBufferAssignment
+
+ENTRY AddDotsFunc {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  bias = f32[2,2] add(x, y)
+  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = f32[2,2] add(dot, bias)
+}
+
+)";
+
+  // Bias should be fused into the multiplication.
+  CheckNumberOfAllocations(hlo_text, 4);
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+class SmallDotGemmRewriteTest : public GemmRewriteTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100);
+    return debug_options;
+  }
+};
+
+TEST_F(SmallDotGemmRewriteTest, SkipSmallMatrixMultiplicationRewrite) {
+  const char* hlo_text = R"(
+HloModule SkipSmallMatrixRewrite
+
+ENTRY DotFunc {
+  x = f32[3,3] parameter(0)
+  y = f32[3,3] parameter(1)
+  ROOT out = f32[3,3] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %DotFunc ({{.*}}: f32[3,3], {{.*}}: f32[3,3]) -> f32[3,3] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,3]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,3]{1,0} parameter(1)
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} dot([[P0]], [[P1]]),
+; CHECK:           lhs_contracting_dims={1}, rhs_contracting_dims={0}
+)");
+}
+
+TEST_F(SmallDotGemmRewriteTest, LargeMatrixMultiplicationIsRewritten) {
+  const char* hlo_text = R"(
+HloModule SkipSmallMatrixRewrite
+
+ENTRY DotFunc {
+  x = f32[8,8] parameter(0)
+  y = f32[8,8] parameter(1)
+  ROOT out = f32[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %DotFunc ({{.*}}: f32[8,8], {{.*}}: f32[8,8]) -> f32[8,8] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0)
+; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[8,8]{1,0} parameter(1)
+; CHECK:         {{[^ ]+}} = {{.*}} custom-call([[P0]], [[P1]])
+)");
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc
new file mode 100644
index 0000000..fddb9e6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc
@@ -0,0 +1,183 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemv_rewriter.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
+#include "xla/shape.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// Constructs a new layout by adding a new minor-most dimension to the input
+// layout. For example, {3, 2, 1, 0} is extended to {4, 3, 2, 1, 0}. We expect
+// the input layout to be normalized by LayoutNormalizer, so that it has a
+// descending ordering.
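+// For illustration, a hypothetical call on a normalized rank-2 layout:
+//   GetLayoutWithNewMinorMostDimension(LayoutUtil::MakeDescendingLayout(2))
+// returns the rank-3 descending layout {2,1,0}.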
+absl::StatusOr<Layout> GetLayoutWithNewMinorMostDimension(
+    const Layout& layout) {
+  // Check that the layout is normalized.
+  if (!LayoutUtil::IsMonotonicWithDim0Major(layout)) {
+    return absl::InvalidArgumentError("Layout is not normalized.");
+  }
+  return LayoutUtil::MakeDescendingLayout(layout.minor_to_major_size() + 1);
+}
+
+class GemvRewriterVisitor : public DfsHloRewriteVisitor {
+ public:
+  absl::Status HandleDot(HloInstruction* instr) override {
+    HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
+    const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers();
+    HloInstruction* lhs = dot->mutable_operand(0);
+    HloInstruction* rhs = dot->mutable_operand(1);
+
+    // This pass relies on the dot decomposer, which ensures that all non-batch
+    // dimensions are merged into one.
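+    // For example, an f32[32,7] lhs with one contracting dimension and no
+    // batch dimensions has rank 2 == 0 + 1 + 1, so it has a non-contracting
+    // dimension; an f32[7] rhs with one contracting dimension has rank 1 and
+    // does not.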
+    bool lhs_has_non_contracting_dim =
+        lhs->shape().rank() ==
+        dim_numbers.lhs_batch_dimensions_size() +
+            dim_numbers.lhs_contracting_dimensions_size() + 1;
+    bool rhs_has_non_contracting_dim =
+        rhs->shape().rank() ==
+        dim_numbers.rhs_batch_dimensions_size() +
+            dim_numbers.rhs_contracting_dimensions_size() + 1;
+
+    // Skip matrix-matrix multiplication.
+    if (lhs_has_non_contracting_dim && rhs_has_non_contracting_dim) {
+      return absl::OkStatus();
+    }
+
+    // Skip vector-vector multiplication.
+    if (!lhs_has_non_contracting_dim && !rhs_has_non_contracting_dim) {
+      return absl::OkStatus();
+    }
+
+    if (dot->shape().is_dynamic()) {
+      return absl::OkStatus();
+    }
+
+    changed_ = true;
+
+    HloComputation* computation = dot->parent();
+    HloInstruction* new_lhs = lhs;
+    if (!lhs_has_non_contracting_dim) {
+      const Shape& lhs_shape = lhs->shape();
+      absl::Span<const int64_t> lhs_dimensions = lhs_shape.dimensions();
+      std::vector<int64_t> new_lhs_dimensions(lhs_dimensions.begin(),
+                                              lhs_dimensions.end());
+      new_lhs_dimensions.push_back(1);
+      Shape new_lhs_shape(
+          lhs_shape.element_type(), new_lhs_dimensions,
+          absl::InlinedVector<bool, 4>(new_lhs_dimensions.size(), false),
+          /*tuple_shapes=*/{});
+      TF_ASSIGN_OR_RETURN(
+          *new_lhs_shape.mutable_layout(),
+          GetLayoutWithNewMinorMostDimension(lhs_shape.layout()));
+      new_lhs = computation->AddInstruction(
+          HloInstruction::CreateBitcast(new_lhs_shape, lhs));
+    }
+
+    HloInstruction* new_rhs = rhs;
+    if (!rhs_has_non_contracting_dim) {
+      const Shape& rhs_shape = rhs->shape();
+      absl::Span<const int64_t> rhs_dimensions = rhs_shape.dimensions();
+      std::vector<int64_t> new_rhs_dimensions(rhs_dimensions.begin(),
+                                              rhs_dimensions.end());
+      new_rhs_dimensions.push_back(1);
+      Shape new_rhs_shape(
+          rhs_shape.element_type(), new_rhs_dimensions,
+          absl::InlinedVector<bool, 4>(new_rhs_dimensions.size(), false),
+          /*tuple_shapes=*/{});
+      TF_ASSIGN_OR_RETURN(
+          *new_rhs_shape.mutable_layout(),
+          GetLayoutWithNewMinorMostDimension(rhs_shape.layout()));
+      new_rhs = computation->AddInstruction(
+          HloInstruction::CreateBitcast(new_rhs_shape, rhs));
+    }
+
+    std::vector<int64_t> new_out_dimensions;
+    new_out_dimensions.reserve(dot->shape().dimensions().size() + 1);
+    for (int64_t dim_size : dot->shape().dimensions()) {
+      new_out_dimensions.push_back(dim_size);
+    }
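+    // The output also gains a trivial dimension: for a vector-matrix product
+    // [n] @ [n,m] the dot shape [m] becomes [1,m], and for a matrix-vector
+    // product [m,n] @ [n] the dot shape [m] becomes [m,1].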
+    if (!lhs_has_non_contracting_dim) {
+      // Insert the trivial dimension before the non-contracting dimension
+      // from the rhs.
+      int non_contracting_dim_size = new_out_dimensions.back();
+      new_out_dimensions[new_out_dimensions.size() - 1] = 1;
+      new_out_dimensions.push_back(non_contracting_dim_size);
+    } else {
+      new_out_dimensions.push_back(1);
+    }
+
+    Shape new_out_shape(
+        dot->shape().element_type(), new_out_dimensions,
+        absl::InlinedVector<bool, 4>(new_out_dimensions.size(), false),
+        /*tuple_shapes=*/{});
+    TF_ASSIGN_OR_RETURN(
+        *new_out_shape.mutable_layout(),
+        GetLayoutWithNewMinorMostDimension(dot->shape().layout()));
+
+    HloInstruction* new_dot =
+        computation->AddInstruction(HloInstruction::CreateDot(
+            new_out_shape, new_lhs, new_rhs, dot->dot_dimension_numbers(),
+            dot->precision_config()));
+    HloInstruction* bitcast = computation->AddInstruction(
+        HloInstruction::CreateBitcast(dot->shape(), new_dot));
+    return computation->ReplaceInstruction(dot, bitcast);
+  }
+
+  bool changed() const { return changed_; }
+
+ private:
+  bool changed_ = false;
+};
+
+}  // namespace
+
+absl::StatusOr<bool> GemvRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  GemvRewriterVisitor gemv_rewriter;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_RETURN_IF_ERROR(computation->Accept(&gemv_rewriter));
+  }
+  return gemv_rewriter.changed();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.h
new file mode 100644
index 0000000..9339101
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.h
@@ -0,0 +1,44 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMV_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_GEMV_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrite a matrix-vector or a vector-matrix multiplication into a
+// matrix-matrix multiplication with a trivial dimension. For example,
+// [m x n] @ [n] is rewritten to [m x n] @ [n x 1], and [n] @ [m x n] is
+// rewritten to [n x 1] @ [m x n].
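+//
+// A minimal usage sketch (hypothetical, for illustration; assumes an
+// HloModule* `module` is available):
+//
+//   GemvRewriter rewriter;
+//   absl::StatusOr<bool> changed =
+//       rewriter.Run(module, /*execution_threads=*/{});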
+class GemvRewriter : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "gemv-rewriter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_GEMV_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/gemv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter_test.cc
new file mode 100644
index 0000000..d255528
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter_test.cc
@@ -0,0 +1,149 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemv_rewriter.h"
+
+#include <memory>
+#include <optional>
+
+#include <gtest/gtest.h>
+#include "absl/status/statusor.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+class GemvRewriterTest : public HloTestBase {};
+
+TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationToGemm) {
+  const char* hlo = R"(
+  HloModule m
+
+  ENTRY e {
+    p0 = f32[32,7] parameter(0)
+    p1 = f32[7] parameter(1)
+    ROOT d = f32[32] dot(p0, p1),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  })";
+
+  const char* expected = R"()
+// CHECK:  %[[P0:.*]] = f32[32,7]{1,0} parameter(0)
+// CHECK:  %[[P1:.*]] = f32[7]{0} parameter(1)
+// CHECK:  %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P1]])
+// CHECK:  %[[DOT:.*]] = f32[32,1]{1,0} dot(%[[P0]], %[[BITCAST]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+// CHECK:  ROOT %[[ROOT:.*]] = f32[32]{0} bitcast(%[[DOT]])
+})";
+
+  RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected);
+}
+
+TEST_F(GemvRewriterTest, RewriteVectorMatrixMultiplicationToGemm) {
+  const char* hlo = R"(
+  HloModule m
+
+  ENTRY e {
+    p0 = f32[7] parameter(0)
+    p1 = f32[7,32] parameter(1)
+    ROOT d = f32[32] dot(p0, p1),
+      lhs_contracting_dims={0}, rhs_contracting_dims={0}
+  })";
+
+  const char* expected = R"()
+// CHECK:  %[[P0:.*]] = f32[7]{0} parameter(0)
+// CHECK:  %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P0]])
+// CHECK:  %[[P1:.*]] = f32[7,32]{1,0} parameter(1)
+// CHECK:  %[[DOT:.*]] = f32[1,32]{1,0} dot(%[[BITCAST]], %[[P1]]), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+// CHECK:  ROOT %[[ROOT:.*]].1 = f32[32]{0} bitcast(%[[DOT]])
+})";
+
+  RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected);
+}
+
+TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationWithBatch) {
+  const char* hlo = R"(
+  HloModule m
+
+  ENTRY e {
+    p0 = f32[2,5,32,7] parameter(0)
+    p1 = f32[2,5,7] parameter(1)
+    ROOT d = f32[2,5,32] dot(p0, p1),
+      lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
+      lhs_contracting_dims={3}, rhs_contracting_dims={2}
+  })";
+
+  const char* expected = R"()
+// CHECK:  %[[P0:.*]] = f32[2,5,32,7]{3,2,1,0} parameter(0)
+// CHECK:  %[[P1:.*]] = f32[2,5,7]{2,1,0} parameter(1)
+// CHECK:  %[[BITCAST:.*]] = f32[2,5,7,1]{3,2,1,0} bitcast(%[[P1]])
+// CHECK:  %[[DOT:.*]] = f32[2,5,32,1]{3,2,1,0} dot(%[[P0]], %[[BITCAST]]),
+// CHECK-SAME: lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+// CHECK:  ROOT %[[ROOT:.*]] = f32[2,5,32]{2,1,0} bitcast(%[[DOT]])
+})";
+
+  RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected);
+}
+
+TEST_F(GemvRewriterTest, DotNotRewriteVectorVectorMultiplication) {
+  const char* hlo = R"(
+  HloModule m
+
+  ENTRY e {
+    p0 = f32[7] parameter(0)
+    p1 = f32[7] parameter(1)
+    ROOT d = f32[] dot(p0, p1),
+      lhs_contracting_dims={0}, rhs_contracting_dims={0}
+  })";
+
+  RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt);
+}
+
+TEST_F(GemvRewriterTest, DotNotRewriteMatrixMatrixMultiplication) {
+  const char* hlo = R"(
+  HloModule m
+
+  ENTRY e {
+    p0 = f32[5,7] parameter(0)
+    p1 = f32[7,32] parameter(1)
+    ROOT d = f32[5,32] dot(p0, p1),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  })";
+
+  RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt);
+}
+
+TEST_F(GemvRewriterTest, DoNotRewriteDotsWithNonNormalizedLayout) {
+  const char* hlo = R"(
+  HloModule m
+
+  ENTRY e {
+    p0 = f32[5,32,7]{2,1,0} parameter(0)
+    p1 = f32[5,7]{0,1} parameter(1)
+    ROOT d = f32[5,32]{0,1} dot(p0, p1),
+      lhs_batch_dims={0}, rhs_batch_dims={0},
+      lhs_contracting_dims={2}, rhs_contracting_dims={1}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo));
+  GemvRewriter rewriter;
+  absl::StatusOr<bool> result = this->RunHloPass(&rewriter, module.get());
+  EXPECT_FALSE(result.ok());
+  EXPECT_EQ(result.status().message(), "Layout is not normalized.");
+}
+
+}  // namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc
new file mode 100644
index 0000000..ef78dbd
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc
@@ -0,0 +1,201 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gpusolver_rewriter.h"
+
+#include <cstdint>
+#include <functional>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/comparison_util.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
+#include "xla/literal.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/cusolver_context.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/blas.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
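+// Swaps the two minor-most dimensions of the shape's default layout, so that
+// the shape is column-major (Fortran) in its two minor dimensions. For a
+// rank-3 shape, the default minor-to-major layout {2,1,0} becomes {1,2,0}.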
+void SetFortranLayout(Shape* shape) {
+  LayoutUtil::SetToDefaultLayout(shape);
+  int n = shape->mutable_layout()->minor_to_major_size();
+  CHECK_GE(n, 2);
+  std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
+            shape->mutable_layout()->mutable_minor_to_major()->at(1));
+}
+
+absl::StatusOr<HloInstruction*> CreateCholesky(GpuSolverContext* context,
+                                               HloInstruction* operand,
+                                               const CholeskyOptions& options,
+                                               const OpMetadata& metadata) {
+  HloComputation* computation = operand->parent();
+
+  Shape a_shape = operand->shape();
+  int ndim = a_shape.dimensions_size();
+  CHECK_GE(ndim, 2);
+  int64_t n = a_shape.dimensions(ndim - 1);
+
+  std::vector<int64_t> batch_dims(a_shape.dimensions().begin(),
+                                  a_shape.dimensions().end() - 2);
+  std::vector<int64_t> batch_dim_ids(batch_dims.size());
+  absl::c_iota(batch_dim_ids, 0);
+  int64_t batch_size = absl::c_accumulate(batch_dims, 1, std::multiplies<>{});
+
+  // Find the workspace size.
+  se::blas::UpperLower uplo = options.lower() ? se::blas::UpperLower::kLower
+                                              : se::blas::UpperLower::kUpper;
+  int64_t workspace_size;  // Number of elements of size a_shape.element_type()
+  TF_ASSIGN_OR_RETURN(
+      workspace_size,
+      context->PotrfBufferSize(a_shape.element_type(), uplo, n, n, batch_size));
+
+  // TODO(phawkins): Ideally we would relax this constraint. What we actually
+  // want is that:
+  // a) the batch dimensions are major, in no particular order, and
+  // b) the two minor dimensions are in Fortran (column-major) order.
+
+  SetFortranLayout(&a_shape);
+
+  // This call returns a tuple of (cholesky_result, workspace, info) where:
+  // * cholesky_result is the result of the Cholesky decomposition,
+  // * workspace is temporary scratch memory used by cuSolver, and
+  // * info contains the Potrf success/failure status.
+  // We currently have no meaningful way to report an error, so instead of
+  // surfacing the failure status we use it below to replace the affected
+  // batch elements with NaNs. This is still suboptimal.
+  Shape info_shape = ShapeUtil::MakeShape(S32, batch_dims);
+  Shape call_shape = ShapeUtil::MakeTupleShape(
+      {a_shape,
+       ShapeUtil::MakeShape(operand->shape().element_type(), {workspace_size}),
+       info_shape});
+
+  HloInstruction* custom_call =
+      computation->AddInstruction(HloInstruction::CreateCustomCall(
+          call_shape, {operand}, kCusolverCholeskyCallTarget, {a_shape}));
+  custom_call->set_metadata(metadata);
+  TF_RETURN_IF_ERROR(custom_call->set_backend_config(options));
+  HloInstruction* out = computation->AddInstruction(
+      HloInstruction::CreateGetTupleElement(a_shape, custom_call, 0));
+  HloInstruction* info = computation->AddInstruction(
+      HloInstruction::CreateGetTupleElement(info_shape, custom_call, 2));
+
+  // If info is non-zero, indicating that the Cholesky decomposition failed,
+  // return an array full of NaNs for the corresponding batch element.
+  HloInstruction* zero = computation->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
+  HloInstruction* zeros =
+      computation->AddInstruction(HloInstruction::CreateBroadcast(
+          info_shape, zero, /*broadcast_dimensions=*/{}));
+  HloInstruction* ok = computation->AddInstruction(
+      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, batch_dims),
+                                    info, zeros, ComparisonDirection::kEq));
+  ok = computation->AddInstruction(HloInstruction::CreateBroadcast(
+      ShapeUtil::MakeShape(PRED, a_shape.dimensions()), ok,
+      /*broadcast_dimensions=*/batch_dim_ids));
+
+  TF_ASSIGN_OR_RETURN(Literal nan_literal,
+                      LiteralUtil::NanValue(a_shape.element_type()));
+  HloInstruction* nan = computation->AddInstruction(
+      HloInstruction::CreateConstant(std::move(nan_literal)));
+  HloInstruction* nans =
+      computation->AddInstruction(HloInstruction::CreateBroadcast(
+          a_shape, nan, /*broadcast_dimensions=*/{}));
+
+  HloInstruction* select =
+      computation->AddInstruction(HloInstruction::CreateTernary(
+          a_shape, HloOpcode::kSelect, ok, out, nans));
+  return select;
+}
+
+// Tries to rewrite a single Cholesky instruction into a custom call to
+// cuSolver.
+absl::StatusOr<bool> RunOnInstruction(GpuSolverContext* context,
+                                      HloInstruction* instruction) {
+  if (instruction->opcode() != HloOpcode::kCholesky) {
+    return false;
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * custom_call,
+      CreateCholesky(context, instruction->mutable_operand(0),
+                     instruction->cholesky_options(), instruction->metadata()));
+
+  VLOG(1) << "Replacing " << instruction->ToString() << " with "
+          << custom_call->ToString();
+
+  TF_RETURN_IF_ERROR(
+      instruction->parent()->ReplaceInstruction(instruction, custom_call));
+  return true;
+}
+
+}  // namespace
+
+// Rewrites the Cholesky instructions in the given computation into custom
+// calls to cuSolver. Returns true if it made any changes.
+absl::StatusOr<bool> GpusolverRewriter::RunOnComputation(
+    HloComputation* computation) {
+  std::vector<HloInstruction*> cusolver_calls;
+  for (auto* hlo : computation->instructions()) {
+    if (hlo->opcode() == HloOpcode::kCholesky) {
+      cusolver_calls.push_back(hlo);
+    }
+  }
+
+  if (cusolver_calls.empty()) {
+    return false;
+  }
+
+  TF_ASSIGN_OR_RETURN(GpuSolverContext context, GpuSolverContext::Create());
+
+  bool changed = false;
+  for (HloInstruction* instruction : cusolver_calls) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(&context, instruction));
+    changed |= result;
+  }
+  return changed;
+}
+
+GpusolverRewriter::GpusolverRewriter() = default;
+
+absl::StatusOr<bool> GpusolverRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+    changed |= result;
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.h
new file mode 100644
index 0000000..cdc0ff2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.h
@@ -0,0 +1,47 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_GPUSOLVER_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_GPUSOLVER_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites Cholesky calls into CustomCall HLOs that call into cuSolver.
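+//
+// For example (illustrative only), an instruction such as
+//   %c = f32[4,4]{1,0} cholesky(%a), lower=true
+// is replaced by a kCusolverCholeskyCallTarget custom call that returns a
+// (result, workspace, info) tuple; batch elements whose decomposition failed
+// are filled with NaNs.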
+class GpusolverRewriter : public HloModulePass {
+ public:
+  GpusolverRewriter();
+  absl::string_view name() const override { return "gpusolver-rewriter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_GPUSOLVER_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc
new file mode 100644
index 0000000..befe869
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc
@@ -0,0 +1,192 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/horizontal_input_fusion.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// Gets the representative input shape of the multi-output fusion.
+Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
+  // Get the HLO that determines the emitter used for lowering.
+  const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
+  if (real_hero->operands().empty()) {
+    // Simply return an empty shape if the representative node has no input
+    // operands.
+    return Shape();
+  } else {
+    return real_hero->operand(0)->shape();
+  }
+}
+
+class HorizontalInputFusionImpl {
+ public:
+  explicit HorizontalInputFusionImpl(HloComputation* computation,
+                                     const se::DeviceDescription& d)
+      : computation_(computation), device_info_(d) {}
+
+  ~HorizontalInputFusionImpl() = default;
+
+  absl::StatusOr<bool> Run();
+
+ private:
+  HloComputation* computation_;
+  const se::DeviceDescription& device_info_;
+};  // HorizontalInputFusionImpl
+
+// Compares the dimensions of `shape_a` and `shape_b` one by one, from left to
+// right.
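+// For example, [2,3] orders before [2,5], and any rank-2 shape orders before
+// any rank-3 shape.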
+bool CompareShapeDimsFromLeftToRight(const Shape& shape_a,
+                                     const Shape& shape_b) {
+  if (shape_a.rank() != shape_b.rank()) {
+    return shape_a.rank() < shape_b.rank();
+  }
+  auto dims_a = shape_a.dimensions();
+  auto dims_b = shape_b.dimensions();
+  for (size_t i = 0; i < dims_a.size(); ++i) {
+    if (dims_a[i] != dims_b[i]) {
+      return dims_a[i] < dims_b[i];
+    }
+  }
+  return true;
+}
+
+std::vector<HloInstruction*> FindAndSortFusionCandidates(
+    HloInstruction* consumer) {
+  absl::flat_hash_set<HloInstruction*> fusion_instr_set;
+  std::vector<HloInstruction*> fusion_instrs;
+  for (HloInstruction* opnd : consumer->operands()) {
+    HloInstruction* predecessor = opnd->LatestNonGteAncestor();
+    // Find the input fusion instructions whose only consumer is `consumer`.
+    // This guarantees that fusing these candidates will never create a cycle,
+    // as there is no back edge.
+    if (IsInputFusibleReduction(*predecessor) &&
+        IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) {
+      if (fusion_instr_set.insert(predecessor).second) {
+        fusion_instrs.push_back(predecessor);
+      }
+    }
+  }
+
+  std::sort(fusion_instrs.begin(), fusion_instrs.end(),
+            [&](const HloInstruction* a, const HloInstruction* b) {
+              Shape shape_a = GetInputShapeForMultiOutputFusion(*a);
+              Shape shape_b = GetInputShapeForMultiOutputFusion(*b);
+              if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) {
+                // Sort shapes according to their dimensions, so that equal
+                // input shapes are placed adjacent to each other.
+                return CompareShapeDimsFromLeftToRight(shape_a, shape_b);
+              }
+              // Sort `fusion_instrs` according to instruction counts, because
+              // we'd like to fuse together computations of similar sizes.
+              return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
+            });
+
+  return fusion_instrs;
+}
+
+absl::StatusOr<bool> HorizontalInputFusionImpl::Run() {
+  bool changed = false;
+  XLA_VLOG_LINES(3, computation_->ToString());
+
+  // Using def-to-use order is sound since we do not modify users.
+  std::vector<HloInstruction*> def_to_use_order =
+      computation_->MakeInstructionPostOrder();
+  for (HloInstruction* consumer : def_to_use_order) {
+    auto candidates = FindAndSortFusionCandidates(consumer);
+    if (candidates.size() <= 1) {
+      continue;
+    }
+
+    // Convert candidates into fusions if needed.
+    for (size_t j = 0; j < candidates.size(); ++j) {
+      if (candidates[j]->opcode() != HloOpcode::kFusion) {
+        TF_ASSIGN_OR_RETURN(
+            HloInstruction * fusion_instr,
+            MakeFusionInstruction(candidates[j],
+                                  HloInstruction::FusionKind::kInput));
+        candidates[j] = fusion_instr;
+        changed = true;
+      }
+    }
+
+    size_t fusion_anchor_id = 0;
+    for (size_t j = 1; j < candidates.size(); ++j) {
+      HloInstruction* fusion_anchor = candidates[fusion_anchor_id];
+      HloInstruction* fused = candidates[j];
+      if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
+          FusionFitsInBudget(*fusion_anchor, *fused, device_info_)) {
+        VLOG(3) << "Fuse " << fused->ToString() << " into "
+                << fusion_anchor->ToString();
+        fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
+        changed = true;
+      } else {
+        // Update the `fusion_anchor_id` since `fused` is either not compatible
+        // with, or not beneficial to fuse with, the current fusion anchor.
+        VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
+        fusion_anchor_id = j;
+      }
+    }
+  }
+
+  return changed;
+}
+
+}  // namespace
+
+absl::StatusOr<bool> HorizontalInputFusion::RunOnComputation(
+    HloComputation* computation) {
+  HorizontalInputFusionImpl horizontal_fusion_impl(computation, device_info_);
+  return horizontal_fusion_impl.Run();
+}
+
+absl::StatusOr<bool> HorizontalInputFusion::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  VLOG(2) << "Run horizontal input fusion.";
+  for (HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool comp_changed, RunOnComputation(comp));
+    changed |= comp_changed;
+  }
+
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.h b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.h
new file mode 100644
index 0000000..a08168d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.h
@@ -0,0 +1,63 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_INPUT_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_INPUT_FUSION_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// This optimization pass horizontally fuses kInput fusions to reduce kernel
+// launch overhead and to increase the degree of parallelism. See
+// HorizontalLoopFusion for a general description of and motivation for
+// horizontal fusion. HorizontalLoopFusion deals with kLoop fusions, while this
+// pass deals with kInput fusions.
+//
+// Following HorizontalLoopFusion, a simple yet effective heuristic is used to
+// search for fusion candidates while avoiding the creation of cycles. That is,
+// we simply search for fusion candidates by looking for instructions whose
+// outputs are all consumed by the same instruction. This catches the typical
+// target cases; often, the candidate instructions are just consumed by the
+// ROOT tuple of the entry computation.
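+//
+// As a rough illustration (a sketch mirroring the tests for this pass, not a
+// normative example), two kInput reduction fusions feeding the same ROOT
+// tuple,
+//
+//   fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1
+//   fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2
+//   ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2)
+//
+// are merged into a single multi-output kInput fusion whose fused root is a
+// tuple of the two reductions; the original tuple then reads the results
+// through get-tuple-element instructions.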
+class HorizontalInputFusion : public HloModulePass {
+ public:
+  explicit HorizontalInputFusion(const se::DeviceDescription& d)
+      : device_info_(d) {}
+
+  absl::string_view name() const override { return "horizontal_input_fusion"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  absl::StatusOr<bool> RunOnComputation(HloComputation*);
+
+  const se::DeviceDescription& device_info_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_INPUT_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc
new file mode 100644
index 0000000..5fc1a54
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc
@@ -0,0 +1,270 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/horizontal_input_fusion.h"
+
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class HorizontalInputFusionTest : public GpuCodegenTest {
+ public:
+  se::DeviceDescription device_description_{
+      TestGpuDeviceInfo::RTXA6000DeviceInfo()};
+  HorizontalInputFusion horizontal_input_fusion_{device_description_};
+};
+
+TEST_F(HorizontalInputFusionTest, BasicTest) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule BasicTest
+
+  %add_f16 {
+    %x = f16[] parameter(0)
+    %y = f16[] parameter(1)
+    ROOT %add = f16[] add(%x, %y)
+  }
+
+ fused_computation.1 {
+   arg.1 = f16[1024]{0} parameter(0)
+   constant0 = f16[] constant(0)
+   ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
+ }
+
+ fused_computation.2 {
+   arg.1 = f16[1024]{0} parameter(0)
+   constant0 = f16[] constant(0)
+   ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
+ }
+
+ ENTRY entry_computation {
+   arg.1 = f16[1024]{0} parameter(0)
+   arg.2 = f16[1024]{0} parameter(1)
+   fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1
+   fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2
+   ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2)
+ }
+)")
+                    .value();
+
+  EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
+
+  const HloInstruction* entry_root =
+      module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(entry_root,
+              GmockMatch(m::Tuple((m::GetTupleElement(m::Fusion(&fusion))),
+                                  (m::GetTupleElement(m::Fusion())))));
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
+}
+
+TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
+  auto module = CreateNewVerifiedModule();
+
+  HloComputation* reduce_computation;
+  {
+    auto embedded_builder = HloComputation::Builder("add");
+    auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
+        0, ShapeUtil::MakeShape(F32, {}), "lhs"));
+    auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
+        1, ShapeUtil::MakeShape(F32, {}), "rhs"));
+    embedded_builder.AddInstruction(
+        HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
+    reduce_computation =
+        module->AddEmbeddedComputation(embedded_builder.Build());
+  }
+
+  HloComputation::Builder builder(TestName());
+  std::vector<HloInstruction*> var_outs;
+  auto input_shape = ShapeUtil::MakeShape(F32, {1024, 1024});
+  auto output_shape = ShapeUtil::MakeShape(F32, {1024});
+  for (int64_t i = 0; i < 130; ++i) {
+    // %fused_computation.3 (param_0: f32[1024,1024], param_1: f32[]) ->
+    // f32[1024] {
+    //  %param_0 = f32[1024,1024]{1,0} parameter(0)
+    //  %param_1 = f32[] parameter(1)
+    //  %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %param_1),
+    //  dimensions={}
+    //  %multiply = f32[1024,1024]{1,0}
+    //      multiply(f32[1024,1024]{1,0} %param_0, f32[1024,1024]{1,0}
+    //      %broadcast)
+    //  %constant0 = f32[] constant(0)
+    //  ROOT %reduce = f32[1024]{0}
+    //      reduce(f32[1024,1024]{1,0} %multiply, f32[] %constant0),
+    //          dimensions={1}, to_apply=%add
+    // }
+    HloInstruction* param_var_in = builder.AddInstruction(
+        HloInstruction::CreateParameter(i * 2 + 0, input_shape, "var.in"));
+    HloInstruction* param_alpha =
+        builder.AddInstruction(HloInstruction::CreateParameter(
+            i * 2 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
+    auto alpha_broadcasted = builder.AddInstruction(
+        HloInstruction::CreateBroadcast(input_shape, param_alpha, {}));
+    auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+        input_shape, HloOpcode::kMultiply, param_var_in, alpha_broadcasted));
+    HloInstruction* const0 = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
+    auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
+        output_shape, mul, const0, {1}, reduce_computation));
+    var_outs.push_back(reduce);
+  }
+  builder.AddInstruction(HloInstruction::CreateTuple(var_outs));
+  module->AddEntryComputation(builder.Build());
+
+  // Verify that horizontal fusion kicked in: check that multiple `reduce`
+  // instructions are fused into the same fusion.
+  if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() < 4) {
+    // 6 is just an arbitrarily picked number, as we don't know exactly how
+    // large the created fusion will be due to the `FusionFitsInBudget`
+    // constraint.
+    CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)",
+                       /*match_optimized_ir=*/false);
+  } else {
+    // Verify that we produced a multi-output reduction with independent groups.
+    CompileAndVerifyIr(module->Clone(), R"(CHECK: switch {{.*}} label {{.*}} [
+                                           CHECK-NEXT: label)",
+                       /*match_optimized_ir=*/false);
+  }
+
+  // Testing with the entire gpu optimization pipeline.
+  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) {
+  // This tests the below pattern. One known issue is that gtes (to fusions) can
+  // be removed after their producer fusions are merged. In the below case, gte2
+  // and gte6 will be gone if Fusion2 is fused into Fusion1.
+  //
+  // Fusion1   Fusion2
+  //  |   |    |     |
+  //  |  gte1 gte2   |
+  //  |   |    |     |
+  //  |   Fusion3    |
+  //  |    |   |     |
+  // gte3 gte4 gte5 gte6
+  //  \  |     |    /
+  //  =====ROOT=====
+  //
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule MultiOutputFusionTest
+
+  %add_f16 {
+    %x = f16[] parameter(0)
+    %y = f16[] parameter(1)
+    ROOT %add = f16[] add(%x, %y)
+  }
+
+ fused_computation.1 {
+   arg.1 = f16[1024]{0} parameter(0)
+   constant0 = f16[] constant(0)
+   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
+   add.0 = f16[1024] add(arg.1, arg.1)
+   ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
+ }
+
+ fused_computation.2 {
+   arg.1 = f16[1024]{0} parameter(0)
+   constant0 = f16[] constant(0)
+   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
+   add.0 = f16[1024] add(arg.1, arg.1)
+   ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
+ }
+
+ fused_computation.3 {
+   arg.0 = f16[1024]{0} parameter(0)
+   arg.1 = f16[1024]{0} parameter(1)
+   add.0 = f16[1024] add(arg.0, arg.1)
+   mul.0 = f16[1024] multiply(arg.0, arg.1)
+   ROOT tuple.1 = (f16[1024], f16[1024]) tuple(add.0, mul.0)
+ }
+
+ ENTRY entry_computation {
+   arg.1 = f16[1024]{0} parameter(0)
+   arg.2 = f16[1024]{0} parameter(1)
+   fusion.1 = (f16[],f16[1024]) fusion(arg.1), kind=kInput, calls=fused_computation.1
+   fusion.2 = (f16[],f16[1024]) fusion(arg.2), kind=kInput, calls=fused_computation.2
+   gte.3 = f16[] get-tuple-element(fusion.1), index=0
+   gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=1
+   gte.2 = f16[1024]{0} get-tuple-element(fusion.2), index=1
+   gte.6 = f16[] get-tuple-element(fusion.2), index=0
+   fusion.3 = (f16[1024],f16[1024]) fusion(gte.1, gte.2),
+       kind=kLoop, calls=fused_computation.3
+   gte.4 = f16[1024] get-tuple-element(fusion.3), index=0
+   gte.5 = f16[1024]{0} get-tuple-element(fusion.3), index=1
+   ROOT tuple.1 = (f16[], f16[1024], f16[1024]{0}, f16[])
+       tuple(gte.3, gte.4, gte.5, gte.6)
+ }
+)")
+                    .value();
+
+  EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
+}
+
+TEST_F(HorizontalInputFusionTest, NonfusionInstrs) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule NonfusionInstrs
+
+ %add_f16 {
+   %x = f16[] parameter(0)
+   %y = f16[] parameter(1)
+   ROOT %add = f16[] add(%x, %y)
+ }
+
+ ENTRY entry_computation {
+   arg.0 = f16[1024]{0} parameter(0)
+   arg.1 = f16[1024]{0} parameter(1)
+   constant0 = f16[] constant(0)
+   reduce.0 = f16[] reduce(arg.0, constant0), dimensions={0}, to_apply=%add_f16
+   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
+   ROOT tuple.0 = (f16[], f16[]) tuple(reduce.0, reduce.1)
+ }
+)")
+                    .value();
+
+  EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
+
+  const HloInstruction* entry_root =
+      module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(entry_root,
+              GmockMatch(m::Tuple((m::GetTupleElement(m::Fusion(&fusion))),
+                                  (m::GetTupleElement(m::Fusion())))));
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc
new file mode 100644
index 0000000..0a3d705
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc
@@ -0,0 +1,744 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/horizontal_loop_fusion.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/sub_byte_normalization.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+PrimitiveType GetUniqueOutputTypeOfFusible(const HloInstruction& fusible) {
+  auto outputs = GetOutputsOfFusible(fusible);
+  CHECK(!outputs.empty());
+  PrimitiveType first_output_type = outputs[0]->shape().element_type();
+  for (size_t i = 1; i < outputs.size(); ++i) {
+    PrimitiveType cur_output_type = outputs[i]->shape().element_type();
+    CHECK(first_output_type == cur_output_type)
+        << "Output types are expected to be unique, but see "
+        << PrimitiveType_Name(first_output_type) << " and "
+        << PrimitiveType_Name(cur_output_type);
+  }
+
+  return first_output_type;
+}
+
+class HorizontalLoopFusionImpl {
+ public:
+  explicit HorizontalLoopFusionImpl(HloComputation* computation,
+                                    absl::string_view prefix)
+      : computation_(computation), prefix_(prefix) {}
+
+  ~HorizontalLoopFusionImpl() = default;
+
+  absl::StatusOr<bool> Run();
+
+ private:
+  absl::Status Fuse(absl::Span<HloInstruction*> fused_fusion_instrs,
+                    bool sliced_input_fusion,
+                    std::vector<HloInstruction*>& to_fuse_candidates);
+
+  // If `sliced_input_fusion` is true, horizontally fuses `fused_fusion_instrs`
+  // into a kInput computation; otherwise, fuses them into a kLoop computation.
+  //
+  // It is required that each of `fused_fusion_instrs` is a kLoop fusion. Also,
+  // we require their numbers of outputs to be the same, so that each output
+  // will be fused/concatenated with the same number of outputs from other fused
+  // fusion instrs. Then, all the fused outputs still have the same shapes for
+  // kernel generation.
+  //
+  // Returns the fused computation in `uniq_computation` and the operands that
+  // are used by `uniq_computation`.
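+  //
+  // For example (a sketch): fusing two candidates with two operands each
+  // produces a computation whose parameters are, in order,
+  // param_0_0, param_0_1, param_1_0, param_1_1, and `bound_operands` holds the
+  // corresponding original operands in the same order.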
+  absl::Status CreateFusedComputation(
+      absl::Span<HloInstruction*> fused_fusion_instrs,
+      std::unique_ptr<HloComputation>* uniq_computation,
+      std::vector<HloInstruction*>* bound_operands, bool sliced_input_fusion);
+
+  // Horizontally fuses the operands of the consumer instruction.
+  // `sliced_input_fusion` controls whether a kInput or a kLoop fused
+  // instruction is created. `to_fuse_candidates` is a stack of instructions
+  // whose operands we want to try to fuse horizontally; whenever we create a
+  // new fusion instruction, we push it onto the stack in the hope of further
+  // fusing its operands.
+  absl::StatusOr<bool> FuseConsumerOperands(
+      HloInstruction* consumer, bool sliced_input_fusion,
+      std::vector<HloInstruction*>& to_fuse_candidates);
+
+  // FusionCandidates collects profitable candidates for a given consumer
+  // instruction. GetNextSpanOfFusions() can then be iteratively invoked to
+  // acquire the next set of fusion candidates based on some heuristics.
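+  //
+  // A typical usage sketch (mirroring FuseConsumerOperands() below):
+  //
+  //   FusionCandidates candidates(consumer, /*sliced_input_fusion=*/false);
+  //   while (true) {
+  //     auto span = candidates.GetNextSpanOfFusions();
+  //     if (span.empty()) break;
+  //     // ... horizontally fuse the instructions in `span` ...
+  //   }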
+  class FusionCandidates {
+   public:
+    explicit FusionCandidates(HloInstruction* consumer,
+                              bool sliced_input_fusion)
+        : fusible_instrs_(),
+          pos_(0),
+          sliced_input_fusion_(sliced_input_fusion) {
+      Initialize(consumer);
+    }
+
+    // Gets a span of fusions to be fused.
+    absl::Span<HloInstruction*> GetNextSpanOfFusions();
+
+   private:
+    void Initialize(HloInstruction*);
+
+    std::vector<HloInstruction*> fusible_instrs_;
+    // `pos_` points to the start position of the next span.
+    size_t pos_;
+    // The `sliced_input_fusion_` flag controls whether we want to fuse into a
+    // kLoop (false) or kInput (true) type kernel.
+    bool sliced_input_fusion_;
+  };
+
+  HloComputation* computation_;
+  std::string prefix_;
+};  // HorizontalLoopFusionImpl
+
+bool IsFusibleCandidate(const HloInstruction& instr) {
+  // For now, we do not support fusing instruction with control flow.
+  if (!instr.control_successors().empty() ||
+      !instr.control_predecessors().empty()) {
+    return false;
+  }
+
+  if (IsNestableVariadicReduction(instr)) {
+    return false;
+  }
+
+  // Require no further check for element-wise instructions.
+  if (instr.IsElementwise() && instr.operand_count() > 0) {
+    return true;
+  }
+
+  // Exclude fusions other than kLoop.
+  if (!instr.IsLoopFusion()) {
+    return false;
+  }
+
+  // Cannot support fusions that have multiple output types, because the
+  // concatenate (inserted for horizontal fusion) requires the same type
+  // for all of its operands.
+  auto outputs = GetOutputsOfFusible(instr);
+  CHECK(!outputs.empty());
+  const HloInstruction* first_output = outputs[0];
+  for (size_t i = 1; i < outputs.size(); ++i) {
+    if (first_output->shape().element_type() !=
+        outputs[i]->shape().element_type()) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+// Returns whether `instr` is a profitable candidate to be horizontally fused.
+// Since the primary benefit of horizontal fusion comes from reducing the
+// kernel launch overhead, we want to exclude the instructions with
+// insignificant kernel launch overhead. In other words, we exclude instructions
+// if their computation latencies are longer than launch latencies. We estimate
+// the computation latency of a given instruction by its shapes and the
+// instruction count in its fused computation. We roughly observe that if a
+// fusion instruction has shapes smaller than `kShapeThreshold` and has fewer
+// instructions than `kInstrCountThreshold`, it is launch-latency-bound and
+// profitable to fuse horizontally.
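+//
+// As a rough worked example of the thresholds below (illustrative shapes
+// only): with sliced_input_fusion=true the shape threshold is
+// 128 * 2048 = 262,144 elements, so a candidate whose representative output is
+// f32[512,512] (262,144 elements) still passes the size check, while
+// f32[513,512] (262,656 elements) exceeds it and is rejected as likely
+// compute-bound rather than launch-latency-bound.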
+bool IsProfitableFusionCandidate(const HloInstruction& instr,
+                                 bool sliced_input_fusion) {
+  // For a kLoop fused kernel, each GPU thread processes one or more elements
+  // from each horizontally fused operand, while for a kInput fused kernel,
+  // each GPU thread can only process one element. From experience, we use a
+  // larger tensor size threshold for kLoop fusion.
+  const int64_t kShapeThreshold =
+      sliced_input_fusion ? 128 * 2048 : 8192 * 8192;
+  const int64_t kInstrCountThreshold = sliced_input_fusion ? 30 : 128;
+  const HloInstruction* root = (instr.opcode() == HloOpcode::kFusion)
+                                   ? instr.fused_expression_root()
+                                   : &instr;
+
+  // Too large shapes are not easily profitable.
+  if (root->opcode() == HloOpcode::kTuple) {
+    // Since all output shapes are the same, use the first shape as the
+    // representative.
+    Shape shape = root->operand(0)->shape();
+    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
+      VLOG(2) << "Profitable check failed due to element count with "
+                 "sliced_input_fusion="
+              << sliced_input_fusion;
+      return false;
+    }
+  } else {
+    Shape shape = root->shape();
+    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
+      VLOG(2) << "Profiltable check failed due to element size with "
+                 "sliced_input_fusion="
+              << sliced_input_fusion;
+      return false;
+    }
+  }
+
+  // Having too many instructions is not easily profitable.
+  if (instr.opcode() == HloOpcode::kFusion &&
+      instr.fused_instruction_count() > kInstrCountThreshold) {
+    return false;
+  }
+
+  return true;
+}
+
+// Returns whether `instr` has only row-major layouts.
+// The horizontal fusion excludes computations with non-row-major layouts,
+// because fusing computations with different layouts can result in uncoalesced
+// memory accesses and cause significant performance overhead.
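+// For example, a dense f32[128,64]{1,0} array (minor-to-major layout {1,0})
+// is row-major and acceptable, whereas f32[128,64]{0,1} is column-major and
+// causes the candidate to be rejected.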
+bool HasOnlyRowMajorLayout(const HloInstruction& instr) {
+  if (instr.opcode() != HloOpcode::kFusion) {
+    return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout());
+  }
+
+  auto fused_instrs = instr.fused_instructions_computation()->instructions();
+  for (HloInstruction* i : fused_instrs) {
+    if (!LayoutUtil::IsDenseArray(i->shape())) {
+      continue;
+    }
+    if (!LayoutUtil::IsMonotonicWithDim0Major(i->shape().layout())) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Returns whether any operand of `instr` is a parameter instruction that
+// is shared with `fusion_instrs`.
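+// For example, if two fusion candidates both read the same entry parameter,
+// each is rejected as a horizontal fusion candidate in Initialize() below,
+// because the produced horizontal fusion could not be input/output aliased
+// once the concatenate is inserted.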
+bool AnyOpndIsParamSharedAmongFusions(
+    const HloInstruction* instr,
+    const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
+  return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
+    return opnd->opcode() == HloOpcode::kParameter &&
+           absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
+             return user != instr && fusion_instrs.contains(user);
+           });
+  });
+}
+
+void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
+    HloInstruction* consumer) {
+  // First, find out all potential target candidates. We will filter out
+  // unsupported/non-profitable cases below.
+  absl::flat_hash_set<HloInstruction*> fusible_candidates;
+  std::vector<HloInstruction*> ordered_fusible_candidates;
+  for (HloInstruction* opnd : consumer->operands()) {
+    HloInstruction* predecessor = opnd->LatestNonGteAncestor();
+    // We support kLoop fusions and element-wise HLOs now. We may extend the
+    // supported list if the need arises.
+    if (IsFusibleCandidate(*predecessor)) {
+      if (fusible_candidates.insert(predecessor).second) {
+        // Add unseen fusion to ordered list.
+        ordered_fusible_candidates.push_back(predecessor);
+      }
+    }
+  }
+
+  for (HloInstruction* instr : ordered_fusible_candidates) {
+    if (!IsConsumerTheOnlyNonRootUser(*instr, *consumer)) {
+      VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
+              << " rejects maybe illegal instr " << instr->ToString()
+              << "; including it may create cycles in HLO.";
+      continue;
+    } else if (!IsProfitableFusionCandidate(*instr, sliced_input_fusion_)) {
+      VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
+              << " rejects may-not-be profitable fusion instr"
+              << instr->ToString();
+      continue;
+    } else if (!HasOnlyRowMajorLayout(*instr)) {
+      VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
+              << " rejects non-row-major fusion instr " << instr->ToString();
+      continue;
+    } else if (AnyOpndIsParamSharedAmongFusions(instr, fusible_candidates)) {
+      // Don't fuse fusions whose operands are parameter instructions that are
+      // shared among fusions because we cannot i/o alias the produced
+      // horizontal fusion due to the concat insertion.
+      VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_
+              << " rejects the fusion instr because it shares parameter with"
+              << " other fusion candidates, instr: " << instr->ToString();
+      continue;
+    } else {
+      VLOG(2) << "Find a fusion candidate " << instr->ToString();
+      // Encapsulate it into a fusion computation for unified representation
+      // for later processing.
+      fusible_instrs_.push_back(instr);
+    }
+  }
+
+  // Sort `fusible_instrs_` according to output type, the number of outputs,
+  // instruction count, and output tensor element count. For sliced input
+  // fusion, we only fuse instructions with the same number/type of outputs
+  // and whose computations have the same instruction count. For kLoop fusion,
+  // we require the fused instructions to have the same number/type of outputs
+  // and also the same output shape. We sort here so that fusion candidates of
+  // the same kind populate a contiguous span.
+  std::stable_sort(
+      fusible_instrs_.begin(), fusible_instrs_.end(),
+      [&](const HloInstruction* a, const HloInstruction* b) {
+        if (GetUniqueOutputTypeOfFusible(*a) !=
+            GetUniqueOutputTypeOfFusible(*b)) {
+          return GetUniqueOutputTypeOfFusible(*a) <
+                 GetUniqueOutputTypeOfFusible(*b);
+        } else if (GetOutputSizeOfFusible(*a) != GetOutputSizeOfFusible(*b)) {
+          return GetOutputSizeOfFusible(*a) < GetOutputSizeOfFusible(*b);
+        } else if (GetInstrCountOfFusible(*a) != GetInstrCountOfFusible(*b)) {
+          return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
+        } else {
+          return ShapeUtil::ElementsIn(GetOutputsOfFusible(*a)[0]->shape()) <
+                 ShapeUtil::ElementsIn(GetOutputsOfFusible(*b)[0]->shape());
+        }
+      });
+}
+
+// Gets a next span of fusion instructions to be fused.
+absl::Span<HloInstruction*>
+HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
+  if (pos_ >= fusible_instrs_.size()) {
+    return absl::Span<HloInstruction*>();
+  }
+
+  // Fusing too many computations at a time may not be easily profitable and
+  // may increase compile time due to large kernels. Set a limit on it.
+  // From profiling results, we found that a large horizontally fused kernel
+  // could have lower end-to-end performance even though the pure GPU kernel
+  // time is shorter. TODO: understand why end-to-end performance regresses
+  // for large horizontally fused kernels. Use an empirical maximum fusion
+  // batch size that depends on whether the operands are fusion instructions.
+  const auto kMaxFusionBatchSize = [&]() -> int64_t {
+    if (sliced_input_fusion_) {
+      return 32;
+    } else {
+      if (fusible_instrs_[pos_]->opcode() == HloOpcode::kFusion) {
+        return 32;
+      } else {
+        return 64;
+      }
+    }
+  }();
+
+  size_t left = pos_;
+  size_t right = pos_ + 1;
+  size_t first_output_size = GetOutputSizeOfFusible(*fusible_instrs_[left]);
+  PrimitiveType first_output_type =
+      GetUniqueOutputTypeOfFusible(*fusible_instrs_[left]);
+  // CUDA has a parameter size limit of ~4k bytes.
+  constexpr int64_t kMaxCudaParamSize = 4000;
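+  // Each operand and each output of the fused kernel is passed as a kernel
+  // parameter; the check below assumes 8 bytes (one pointer) per parameter, so
+  // `accum_io_size * 8` approximates the accumulated parameter size.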
+  size_t accum_io_size = 0;
+  size_t accum_num_outputs = 0;
+  for (; right < fusible_instrs_.size(); ++right) {
+    PrimitiveType cur_output_type =
+        GetUniqueOutputTypeOfFusible(*fusible_instrs_[right]);
+    if (first_output_type != cur_output_type) {
+      // Cannot fuse computations whose output types differ.
+      break;
+    }
+    if (first_output_size != GetOutputSizeOfFusible(*fusible_instrs_[right])) {
+      // Cannot fuse computations that have different numbers of outputs.
+      break;
+    }
+    if (GetInstrCountOfFusible(*fusible_instrs_[left]) !=
+        GetInstrCountOfFusible(*fusible_instrs_[right])) {
+      // Do not fuse computations of different instruction counts as it may
+      // introduce control divergence. This is a very simple heuristic to avoid
+      // fusing computations with too much discrepancy and we may improve it
+      // when the needs arise.
+      break;
+    }
+    if (!sliced_input_fusion_ &&
+        !ShapeUtil::EqualIgnoringElementType(
+            GetOutputsOfFusible(*fusible_instrs_[left])[0]->shape(),
+            GetOutputsOfFusible(*fusible_instrs_[right])[0]->shape())) {
+      // This is for fusing into a kLoop type kernel, so we require that each
+      // fused operand has the same output shape.
+      break;
+    }
+    size_t num_outputs = GetOutputSizeOfFusible(*fusible_instrs_[right]);
+    accum_num_outputs += num_outputs;
+    if (accum_num_outputs >= kMaxFusionBatchSize) {
+      // Hit max fusion batch size.
+      break;
+    }
+    accum_io_size += fusible_instrs_.at(right)->operand_count() + num_outputs;
+    if (accum_io_size * 8 >= kMaxCudaParamSize) {
+      break;
+    }
+  }
+  VLOG(2) << "horizontal fuse get instruction span with " << (right - left)
+          << " instructions for sliced_input_fusion=" << sliced_input_fusion_
+          << " fusion";
+  pos_ = right;
+  return absl::MakeSpan(fusible_instrs_).subspan(left, right - left);
+}
+
+absl::StatusOr<bool> HorizontalLoopFusionImpl::FuseConsumerOperands(
+    HloInstruction* consumer, bool sliced_input_fusion,
+    std::vector<HloInstruction*>& to_fuse_candidates) {
+  bool changed = false;
+  FusionCandidates loop_fusion_candidates(consumer, sliced_input_fusion);
+  while (true) {
+    auto fusibles = loop_fusion_candidates.GetNextSpanOfFusions();
+    if (fusibles.empty()) {
+      break;
+    } else if (fusibles.size() == 1) {
+      // Skip; there is just one fused_instr.
+      continue;
+    }
+
+    changed = true;
+    // Convert the fusibles into fusion_instrs to simplify the implementation
+    // of
+    // `Fuse()`.
+    std::vector<HloInstruction*> fusion_instrs;
+    for (HloInstruction* instr : fusibles) {
+      if (instr->opcode() == HloOpcode::kFusion) {
+        fusion_instrs.push_back(instr);
+      } else {
+        TF_ASSIGN_OR_RETURN(
+            HloInstruction * fusion_instr,
+            MakeFusionInstruction(instr, HloInstruction::FusionKind::kLoop));
+        fusion_instrs.push_back(fusion_instr);
+      }
+    }
+
+    TF_RETURN_IF_ERROR(Fuse(absl::MakeSpan(fusion_instrs), sliced_input_fusion,
+                            to_fuse_candidates));
+  }
+  return changed;
+}
+
+absl::Status HorizontalLoopFusionImpl::CreateFusedComputation(
+    absl::Span<HloInstruction*> fused_fusion_instrs,
+    std::unique_ptr<HloComputation>* uniq_computation,
+    std::vector<HloInstruction*>* bound_operands, bool sliced_input_fusion) {
+  // First, build a computation with only params.
+  HloComputation::Builder b(prefix_ + "horizontally_fused_computation");
+  size_t fused_comp_param_id = 0;
+  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
+    auto old_params = fused_fusion_instrs[i]->fused_parameters();
+    for (size_t j = 0; j < old_params.size(); ++j) {
+      HloInstruction* bound_opnd = fused_fusion_instrs[i]->mutable_operand(j);
+      // Parameters are named in the form param_i_j.
+      b.AddInstruction(HloInstruction::CreateParameter(
+          fused_comp_param_id++, bound_opnd->shape(),
+          absl::StrCat("param_", i, "_", j)));
+      bound_operands->push_back(bound_opnd);
+    }
+  }
+  // Always create a dummy tuple instruction to serve as the root of the
+  // computation, as the existence of a root instruction is required by the
+  // HloComputation. The real root instruction will replace it below.
+  HloInstruction* dummy_root = b.AddInstruction(
+      HloInstruction::CreateTuple(std::vector<HloInstruction*>{}));
+  *uniq_computation = b.Build(dummy_root);
+  HloComputation* comp = uniq_computation->get();
+
+  // Prepare clone_map, which maps old operands to new operands.
+  absl::flat_hash_map<const HloInstruction*, HloInstruction*> clone_map;
+  size_t new_param_id = 0;
+  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
+    auto old_params = fused_fusion_instrs[i]->fused_parameters();
+    for (size_t j = 0; j < old_params.size(); ++j) {
+      HloInstruction* old_param = old_params[j];
+      HloInstruction* new_param = comp->parameter_instruction(new_param_id++);
+      clone_map.insert({old_param, new_param});
+    }
+  }
+
+  // Clone every fused computation.
+  const OpMetadata* metadata = nullptr;
+  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
+    auto def_to_use_order = fused_fusion_instrs[i]
+                                ->fused_instructions_computation()
+                                ->MakeInstructionPostOrder();
+    for (HloInstruction* old_instr : def_to_use_order) {
+      if (old_instr->opcode() == HloOpcode::kParameter ||
+          (sliced_input_fusion && old_instr->opcode() == HloOpcode::kTuple &&
+           old_instr == fused_fusion_instrs[i]->fused_expression_root())) {
+        // Parameters have been created, and we don't need tuples from
+        // multi-output fusions, as we will directly reference the tuple
+        // operands instead by using GetOutputsOfFusible().
+        continue;
+      }
+      std::vector<HloInstruction*> new_opnds;
+      const auto& old_opnds = old_instr->operands();
+      new_opnds.reserve(old_opnds.size());
+      for (HloInstruction* old_opnd : old_opnds) {
+        CHECK(clone_map.find(old_opnd) != clone_map.end());
+        new_opnds.push_back(clone_map[old_opnd]);
+      }
+      HloInstruction* new_instr = comp->AddInstruction(
+          old_instr->CloneWithNewOperands(old_instr->shape(), new_opnds));
+      clone_map.insert({old_instr, new_instr});
+      // Get the metadata from the last fused instruction.
+      metadata = &old_instr->metadata();
+    }
+  }
+
+  // Since we require each fusion to have the same number of outputs, we can
+  // simply use the first fusion as the representative for output size.
+  size_t fused_instr_output_size =
+      GetOutputSizeOfFusible(*fused_fusion_instrs[0]);
+
+  if (sliced_input_fusion) {
+    // Fusing into kInput fusion
+    std::vector<HloInstruction*> concated_outputs;
+    for (size_t i = 0; i < fused_instr_output_size; ++i) {
+      std::vector<HloInstruction*> instr_outputs(fused_fusion_instrs.size());
+      for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
+        const HloInstruction* old_output =
+            GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
+        HloInstruction* new_output = clone_map[old_output];
+        if (new_output->shape().dimensions_size() == 1) {
+          instr_outputs[j] = new_output;
+        } else {
+          Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout(
+              new_output->shape().element_type(),
+              {ShapeUtil::ElementsIn(new_output->shape())},
+              /*minor_to_major=*/std::vector<int64_t>(1, 0));
+          TF_ASSIGN_OR_RETURN(instr_outputs[j],
+                              MakeReshapeHlo(new_shape, new_output));
+        }
+      }
+      TF_ASSIGN_OR_RETURN(HloInstruction * concated_output,
+                          MakeConcatHlo(instr_outputs, 0));
+      concated_outputs.push_back(concated_output);
+    }
+
+    // Make slices of outputs.
+    std::vector<HloInstruction*> output_slices(concated_outputs.size() *
+                                               fused_fusion_instrs.size());
+    for (size_t i = 0; i < concated_outputs.size(); ++i) {
+      HloInstruction* concated_output = concated_outputs[i];
+      int64_t slice_start = 0;
+      // Create a slice per fused computation.
+      for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
+        const HloInstruction* old_output =
+            GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
+        Shape shape = old_output->shape();
+        int64_t slice_limit = slice_start + ShapeUtil::ElementsIn(shape);
+        TF_ASSIGN_OR_RETURN(
+            output_slices[concated_outputs.size() * j + i],
+            MakeSliceHlo(concated_output, {slice_start}, {slice_limit},
+                         /*strides=*/{1}));
+        slice_start = slice_limit;
+      }
+    }
+
+    // Make a tuple of output_slices.
+    HloInstruction* tuple = comp->AddInstruction(
+        HloInstruction::CreateTuple(output_slices), metadata);
+    comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
+    TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));
+
+  } else {
+    // Fusing into kLoop fusion
+    std::vector<HloInstruction*> tuple_operands(fused_instr_output_size *
+                                                fused_fusion_instrs.size());
+    // When fusing into a kLoop fusion, the new fusion root is a tuple of the
+    // fused fusion computations' roots.
+    for (size_t i = 0; i < fused_instr_output_size; ++i) {
+      for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
+        const HloInstruction* old_output =
+            GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
+        HloInstruction* new_output = clone_map[old_output];
+        tuple_operands[fused_instr_output_size * j + i] = new_output;
+      }
+    }
+    // Make a tuple instruction of fused instruction outputs as
+    // the root of fused computation.
+    HloInstruction* tuple =
+        comp->AddInstruction(HloInstruction::CreateTuple(tuple_operands));
+    comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
+    TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));
+  }
+
+  return absl::OkStatus();
+}
+
+absl::Status HorizontalLoopFusionImpl::Fuse(
+    absl::Span<HloInstruction*> fused_fusion_instrs, bool sliced_input_fusion,
+    std::vector<HloInstruction*>& to_fuse_candidates) {
+  // Fuse fused_fusion_instrs and replace them with the new fused computation.
+  std::unique_ptr<HloComputation> uniq_computation;
+  std::vector<HloInstruction*> bound_operands;
+
+  TF_RETURN_IF_ERROR(CreateFusedComputation(fused_fusion_instrs,
+                                            &uniq_computation, &bound_operands,
+                                            sliced_input_fusion));
+
+  HloComputation* fused_comp = computation_->parent()->AddEmbeddedComputation(
+      std::move(uniq_computation));
+  HloInstruction* hori_fusion_instr = computation_->AddInstruction(
+      HloInstruction::CreateFusion(fused_comp->root_instruction()->shape(),
+                                   sliced_input_fusion
+                                       ? HloInstruction::FusionKind::kInput
+                                       : HloInstruction::FusionKind::kLoop,
+                                   bound_operands, fused_comp, prefix_),
+      &fused_comp->root_instruction()->metadata());
+  fused_comp->SetFusionInstruction(hori_fusion_instr);
+
+  // Push the newly fused instruction onto the fusion candidate stack, because
+  // its operands may now be eligible for further horizontal fusion.
+  to_fuse_candidates.push_back(hori_fusion_instr);
+
+  // Insert bitcasts and replace the corresponding users. Note that we do not
+  // insert the bitcasts in the fused computation, as they do not fit into the
+  // sliced input fusion pattern. However, inserting bitcasts outside the fused
+  // computation incurs no performance cost.
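+  // For example (a sketch), a fused candidate with a single f32[32,32] output
+  // is replaced by a bitcast of get-tuple-element(hori_fusion_instr, i) back
+  // to f32[32,32], while a 1-D output is replaced by the get-tuple-element
+  // alone.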
+  size_t total_output_id = 0;
+  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
+    std::vector<HloInstruction*> bitcasts_or_gte;
+    HloInstruction* fused_instr = fused_fusion_instrs[i];
+    size_t num_outputs = GetOutputSizeOfFusible(*fused_instr);
+    for (size_t j = 0; j < num_outputs; ++j) {
+      const HloInstruction* output = GetOutputsOfFusible(*fused_instr)[j];
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * gep,
+          MakeGetTupleElementHlo(hori_fusion_instr, total_output_id++));
+      // This pass runs late, so useless bitcasts won't be cleaned up.
+      if (output->shape().dimensions_size() == 1) {
+        bitcasts_or_gte.push_back(gep);
+      } else {
+        bitcasts_or_gte.push_back(computation_->AddInstruction(
+            HloInstruction::CreateBitcast(output->shape(), gep)));
+      }
+    }
+    HloInstruction* bitcast_or_tuple =
+        (bitcasts_or_gte.size() == 1)
+            ? bitcasts_or_gte.at(0)
+            : computation_->AddInstruction(
+                  HloInstruction::CreateTuple(bitcasts_or_gte));
+    HloComputation* old_computation =
+        fused_instr->fused_instructions_computation();
+    HloModule* module = old_computation->parent();
+    TF_RETURN_IF_ERROR(
+        computation_->ReplaceInstruction(fused_instr, bitcast_or_tuple));
+    TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(old_computation));
+  }
+
+  TF_RETURN_IF_ERROR(Cast<HloFusionInstruction>(hori_fusion_instr)
+                         ->DeduplicateFusionOperands());
+
+  VLOG(1) << "Fused " << fused_fusion_instrs.size()
+          << " instructions into: " << hori_fusion_instr->ToString();
+  return absl::OkStatus();
+}
+
+absl::StatusOr<bool> HorizontalLoopFusionImpl::Run() {
+  bool changed = false;
+  XLA_VLOG_LINES(3, computation_->ToString());
+
+  // Traverse from use to def. Bitcasts are placed after h-fusions to resolve
+  // shape mismatches, but bitcasts could prevent future h-fusions from
+  // happening. So a bottom-up, use-to-def order should be more favorable. It
+  // also helps reduce the number of compiler iterations needed to reach the
+  // fixed point.
+  std::vector<HloInstruction*> to_fuse_candidates =
+      computation_->MakeInstructionPostOrder();
+
+  while (!to_fuse_candidates.empty()) {
+    HloInstruction* consumer = to_fuse_candidates.back();
+    to_fuse_candidates.pop_back();
+
+    // The consumer may have been an operand of a previously fused instruction,
+    // in which case it is no longer valid; skip it.
+    if (consumer->IsDead()) {
+      continue;
+    }
+
+    // First, try to fuse the operands that have the same shape into a kLoop
+    // fusion instruction.
+    TF_ASSIGN_OR_RETURN(
+        bool loop_fusion_changed,
+        FuseConsumerOperands(consumer, false, to_fuse_candidates));
+
+    // Then, for the remaining operands with different shapes, try to fuse them
+    // into a kInput fusion instruction.
+    TF_ASSIGN_OR_RETURN(
+        bool sliced_input_fusion_changed,
+        FuseConsumerOperands(consumer, true, to_fuse_candidates));
+
+    changed = changed || loop_fusion_changed || sliced_input_fusion_changed;
+  }
+  return changed;
+}
+
+}  // namespace
+
+absl::StatusOr<bool> HorizontalLoopFusion::RunOnComputation(
+    HloComputation* computation) {
+  HorizontalLoopFusionImpl horizontal_fusion_impl(computation, prefix_);
+  return horizontal_fusion_impl.Run();
+}
+
+absl::StatusOr<bool> HorizontalLoopFusion::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  VLOG(2) << "Run horizontal fusion.";
+
+  // Running on the entry computation is sufficient.
+  TF_ASSIGN_OR_RETURN(bool changed,
+                      RunOnComputation(module->entry_computation()));
+
+  if (changed) {
+    // Correctly set element_size_in_bits for any newly added sub-byte slice
+    // and concatenate instructions.
+    TF_ASSIGN_OR_RETURN(
+        [[maybe_unused]] bool unused,
+        SubByteNormalization{SubByteNormalization::SET_ELEMENT_SIZE}.Run(
+            module));
+  }
+
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h
new file mode 100644
index 0000000..f29bcd3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h
@@ -0,0 +1,145 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_LOOP_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_LOOP_FUSION_H_
+
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// This optimization pass horizontally fuses computations for reducing kernel
+// launch overhead while increasing kernel launch dims on GPU. The initial
+// motivation of this horizontal fusion is due to the observation that the
+// training optimizer phase (e.g., AdamOptimizer and L2Loss, etc.) typically
+// has many small kernels as a result of applying the same formula on many
+// training parameters (or variables in TensorFlow). Fusing these small
+// kernels hence provides a performance gain.
+//
+// Theoretically speaking, we may implement a cycle detection algorithm to make
+// sure no cycles are created after fusion. However, cycle detection check is
+// somewhat cumbersome; also, we observe that naive horizontal fusion of
+// arbitrary kernels may not be profitable due to control divergence and
+// possible increase of memory bandwidth pressure due to uncoalesced memory
+// accesses (note that horizontal fusion does not change the amount of memory
+// read+written at all). In practice, a simple yet effective heuristic is used
+// to avoid these issues while addressing the known beneficial cases. That is,
+// we simply search for fusion candidates by looking for instructions whose
+// outputs are all consumed by the same instruction. This catches the cases in
+// the training optimizer phase, as the candidate instructions are typically
+// consumed only by the ROOT tuple of the entry computation.
+//
+// The following illustrates the mechanism of the horizontal fusion. Before
+// fusion, there are two trivial kernels in the illustrative example. One has
+// only a Mul op, while the other consists of only an Add op. Since they are
+// only consumed by the same (ROOT) tuple instruction, horizontal fusion is
+// triggered.
+//
+// i0 i1   i2 i3
+//  | |     | |
+//  v v     v v
+//  Mul     Add
+//   |       |
+//   v       v
+//  (ROOT) tuple
+//
+// We fuse into one of two possible patterns, depending on whether all the
+// fused operations have the same shape or not.
+//
+// case 1: if Mul's and Add's output shapes and types are the same, then we
+// fuse them into the pattern below:
+// i0 i1   i2 i3
+//  | |     | |
+//  v v     v v
+//  Mul     Add
+//   |       |
+//   v       v
+//  (ROOT) tuple
+// the fused kernel will be kLoop type, and GPU code is emitted through
+// the LoopFusion class.
+//
+// case 2: if Mul's and Add's output shapes are different, then we fuse them
+// into the pattern below, which adds extra indexing:
+// i0 i1   i2 i3       +++ (Slice) Input Fusion
+//  | |     | |          +
+//  v v     v v          +
+//  Mul     Add          +
+//   |       |           +
+//   v       v           +
+// Reshape0  Reshape1    +
+//   |       |           +
+//   v       v           +
+//  Concatenate          +
+//   |       |           +
+//   v       v           +
+//  Slice0  Slice1     +++
+//   |       |
+//   v       v
+// Reshape2  Reshape3
+//   |       |
+//   v       v
+//  (ROOT) tuple
+//
+// the fused kernel will be kInput type, and, the GPU code is emitted through
+// the InputSlicesFusion class.
+//
+// In theory, the pattern in case 1 could also be fused into the case 2 target
+// graph, but we prefer to fuse into the kLoop type, because its codegen does
+// not have the slicing range check cost introduced by the case 2 pattern.
+//
+// Note that the fusion style of case 2 provides an important advantage:
+// kernels of different shapes can be horizontally fused. The first pair of
+// reshapes (i.e., Reshape0 and Reshape1) reshape the dims to 1 dimension, so
+// that the outputs of the fused kernels can (always) be concatenated. The
+// second pair of reshapes (Reshape2 and Reshape3) restore the original shapes
+// to the output tensors.
+//
+// No extra copies are introduced by the horizontal fusion. Besides Reshape2
+// and Reshape3, the other instructions are fused into an input fusion; the
+// output dims of the concatenate will be used as the kernel launch dims.
+// Instruction bitcasts can be used for Reshape2 and Reshape3 as long as the
+// outputs of Mul and Add are row-major.
+//
+// Note that reshapes are added only if the tensor isn't already a vector.
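+//
+// As a rough HLO sketch of case 1 (illustrative names and shapes, not taken
+// from a real module), the two kernels
+//
+//   mul.0 = f32[1024] multiply(i0, i1)
+//   add.0 = f32[1024] add(i2, i3)
+//   ROOT tuple.0 = (f32[1024], f32[1024]) tuple(mul.0, add.0)
+//
+// become a single kLoop fusion whose fused computation roots at
+// tuple(multiply, add); the original tuple operands are rewritten into
+// get-tuple-element reads of the new fusion.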
+class HorizontalLoopFusion : public HloModulePass {
+ public:
+  HorizontalLoopFusion() = default;
+  explicit HorizontalLoopFusion(absl::string_view prefix) : prefix_(prefix) {}
+
+  absl::string_view name() const override { return "horizontal_loop_fusion"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  absl::StatusOr<bool> RunOnComputation(HloComputation*);
+  std::string prefix_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_LOOP_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc
new file mode 100644
index 0000000..781d27a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc
@@ -0,0 +1,851 @@
+/* Copyright 2016 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/horizontal_loop_fusion.h"
+
+#include <cstdint>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/log/log.h"
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/transforms/instruction_fusion.h"
+#include "xla/service/hlo_dce.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/hlo_pass_fix.h"
+#include "xla/service/hlo_pass_pipeline.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/test.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class HorizontalLoopFusionTest : public HloTestBase {
+ public:
+  static bool IsFusion(const HloInstruction* instr) {
+    return instr->opcode() == HloOpcode::kFusion;
+  }
+};
+
+TEST_F(HorizontalLoopFusionTest, BasicTest) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule BasicTest
+
+ fused_computation.1 {
+   arg.1 = f16[1024]{0} parameter(0)
+   arg.2 = f16[1024]{0} parameter(1)
+   ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.2 {
+   arg.1 = f16[123]{0} parameter(0)
+   arg.2 = f16[123]{0} parameter(1)
+   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
+ }
+
+ ENTRY entry_computation {
+   arg.1 = f16[1024]{0} parameter(0)
+   arg.2 = f16[1024]{0} parameter(1)
+   arg.3 = f16[123]{0} parameter(2)
+   arg.4 = f16[123]{0} parameter(3)
+   fusion.1 = f16[1024]{0}
+       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
+   fusion.2 = f16[123]{0}
+       fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
+   ROOT tuple.1 = (f16[1024]{0}, f16[123]{0})
+       tuple(fusion.1, fusion.2)
+ }
+)")
+                    .value();
+
+  EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
+  TF_ASSERT_OK(verifier().Run(module.get()).status());
+  EXPECT_FALSE(HloDCE().Run(module.get()).value());
+
+  const HloInstruction* entry_root =
+      module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(entry_root,
+              GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                  m::GetTupleElement(m::Fusion()))));
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Slice(m::Concatenate()),
+                                  m::Slice(m::Concatenate()))));
+}
+
+// Horizontal fusion should not be triggered as fusion will create cycles.
+TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule NegativeTestForCycle
+
+ fused_computation.1 {
+   arg.1 = f16[123]{0} parameter(0)
+   arg.2 = f16[123]{0} parameter(1)
+   ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.2 {
+   arg.1 = f16[123]{0} parameter(0)
+   arg.2 = f16[123]{0} parameter(1)
+   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
+ }
+
+ ENTRY entry_computation {
+   arg.1 = f16[123]{0} parameter(0)
+   arg.2 = f16[123]{0} parameter(1)
+   arg.3 = f16[123]{0} parameter(2)
+   arg.4 = f16[123]{0} parameter(3)
+   // fusion.1 and fusion.2 will not be horizontally fused, as doing so would
+   // create a cycle through fusion.1 -> add.2 -> fusion.2
+   fusion.1 = f16[123]{0}
+       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
+   add.2 = f16[123]{0} add(fusion.1, arg.4)
+   fusion.2 = f16[123]{0}
+       fusion(add.2, arg.3), kind=kLoop, calls=fused_computation.2
+   ROOT tuple.1 = (f16[123]{0}, f16[123]{0}, f16[123]{0})
+       tuple(fusion.1, fusion.2, add.2)
+ }
+)")
+                    .value();
+
+  EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value());
+}
+
+TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule NegativeTestForIncompatibleTypes
+
+ fused_computation.1 {
+   arg.1 = f16[1024]{0} parameter(0)
+   arg.2 = f16[1024]{0} parameter(1)
+   ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.2 {
+   arg.1 = s32[123]{0} parameter(0)
+   arg.2 = s32[123]{0} parameter(1)
+   ROOT add.1 = s32[123]{0} add(arg.1, arg.2)
+ }
+
+ ENTRY entry_computation {
+   arg.1 = f16[1024]{0} parameter(0)
+   arg.2 = f16[1024]{0} parameter(1)
+   arg.3 = s32[123]{0} parameter(2)
+   arg.4 = s32[123]{0} parameter(3)
+   // fusion.1 and fusion.2 will not be horizontally fused because their output
+   // types are different.
+   fusion.1 = f16[1024]{0}
+       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
+   fusion.2 = s32[123]{0}
+       fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
+   ROOT tuple.1 = (f16[1024]{0}, s32[123]{0})
+       tuple(fusion.1, fusion.2)
+ }
+)")
+                    .value();
+
+  EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value());
+}
+
+TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule FusingIntoKLoopAndKInputTogether
+
+ fused_computation.1 {
+   arg.1 = f16[129, 2048]{1, 0} parameter(0)
+   arg.2 = f16[129, 2048]{1, 0} parameter(1)
+   ROOT mul.1 = f16[129,2048]{1, 0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.2 {
+   arg.1 = f16[129, 2048]{1, 0} parameter(0)
+   arg.2 = f16[129, 2048]{1, 0} parameter(1)
+   ROOT mul.1 = f16[129,2048]{1, 0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.3 {
+   arg.1 = f16[130, 2048]{1, 0} parameter(0)
+   arg.2 = f16[130, 2048]{1, 0} parameter(1)
+   ROOT mul.1 = f16[130,2048]{1, 0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.4 {
+   arg.1 = f16[130, 2048]{1, 0} parameter(0)
+   arg.2 = f16[130, 2048]{1, 0} parameter(1)
+   ROOT mul.1 = f16[130,2048]{1, 0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.5 {
+   arg.1 = f16[123]{0} parameter(0)
+   arg.2 = f16[123]{0} parameter(1)
+   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
+ }
+
+ fused_computation.6 {
+   arg.1 = f16[128]{0} parameter(0)
+   arg.2 = f16[128]{0} parameter(1)
+   ROOT add.1 = f16[128]{0} add(arg.1, arg.2)
+ }
+
+ ENTRY entry_computation {
+   arg.1 = f16[129, 2048]{1, 0} parameter(0)
+   arg.2 = f16[129, 2048]{1, 0} parameter(1)
+   arg.3 = f16[129, 2048]{1, 0} parameter(2)
+   arg.4 = f16[129, 2048]{1, 0} parameter(3)
+   arg.5 = f16[130, 2048]{1, 0} parameter(4)
+   arg.6 = f16[130, 2048]{1, 0} parameter(5)
+   arg.7 = f16[130, 2048]{1, 0} parameter(6)
+   arg.8 = f16[130, 2048]{1, 0} parameter(7)
+   arg.9 = f16[123]{0} parameter(8)
+   arg.10 = f16[123]{0} parameter(9)
+   arg.11 = f16[128]{0} parameter(10)
+   arg.12 = f16[128]{0} parameter(11)
+
+   // fusion.1 and fusion.2 will be fused into kLoop fusion
+   // fusion.3 and fusion.4 will be fused into another kLoop fusion
+   // fusion.5 and fusion.6 will be fused into kInput fusion
+
+   fusion.1 = f16[129,2048]{1, 0}
+      fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
+
+   fusion.2 = f16[129,2048]{1, 0}
+      fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
+
+   fusion.3 = f16[130,2048]{1, 0}
+      fusion(arg.5, arg.6), kind=kLoop, calls=fused_computation.3
+
+   fusion.4 = f16[130,2048]{1, 0}
+      fusion(arg.7, arg.8), kind=kLoop, calls=fused_computation.4
+
+   fusion.5 = f16[123]{0}
+      fusion(arg.9, arg.10), kind=kLoop, calls=fused_computation.5
+
+   fusion.6 = f16[128]{0}
+      fusion(arg.11, arg.12), kind=kLoop, calls=fused_computation.6
+
+   ROOT tuple.1 = (f16[129,2048]{1, 0}, f16[129,2048]{1, 0},
+                   f16[130,2048]{1, 0}, f16[130,2048]{1, 0},
+                   f16[123]{0}, f16[128]{0})
+      tuple(fusion.1, fusion.2, fusion.3, fusion.4, fusion.5, fusion.6)
+ }
+)")
+                    .value();
+
+  EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
+
+  int input_fusion_count = 0;
+  int loop_fusion_count = 0;
+  for (auto inst : module->entry_computation()->MakeInstructionPostOrder()) {
+    if (inst->opcode() == HloOpcode::kFusion) {
+      input_fusion_count +=
+          (inst->fusion_kind() == HloInstruction::FusionKind::kInput) ? 1 : 0;
+      loop_fusion_count +=
+          (inst->fusion_kind() == HloInstruction::FusionKind::kLoop) ? 1 : 0;
+    }
+  }
+  EXPECT_EQ(input_fusion_count, 1);
+  EXPECT_EQ(loop_fusion_count, 2);
+}
+
+TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule MergeSharedFusionInstruction
+
+ ENTRY MergeSharedFusionInstruction.Computation0 {
+  param.1.1   = f32[4,1024]{1,0} parameter(0)
+  param.1.2   = f32[4,1024]{1,0} parameter(1)
+  param.1.3   = f32[4,1024]{1,0} parameter(2)
+  param.2.1   = f32[321,5]{1,0} parameter(3)
+  param.2.2   = f32[321,5]{1,0} parameter(4)
+  param.2.3   = f32[321,5]{1,0} parameter(5)
+  const.1     = f32[] constant(3)
+  const.2     = f32[] constant(3)
+  broadcast.1 = f32[4,1024]{1,0} broadcast(const.1), dimensions={}
+  broadcast.2 = f32[321,5]{1,0} broadcast(const.2), dimensions={}
+  mul.1.1     = f32[4,1024]{1,0} multiply(param.1.1, param.1.2)
+  mul.1.2     = f32[4,1024]{1,0} multiply(param.1.3, broadcast.1)
+  add.1       = f32[4,1024]{1,0} add(mul.1.1, mul.1.2)
+  mul.2.1     = f32[321,5]{1,0} multiply(param.2.1, param.2.2)
+  mul.2.2     = f32[321,5]{1,0} multiply(param.2.3, broadcast.2)
+  add.2       = f32[321,5]{1,0} add(mul.2.1, mul.2.2)
+  ROOT tuple = (f32[4,1024]{1,0}, f32[321,5]{1,0}) tuple(add.1, add.2)
+})")
+                    .value();
+
+  HloPassPipeline fusion("fusion");
+  const se::DeviceDescription device_info =
+      TestGpuDeviceInfo::RTXA6000DeviceInfo();
+  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/false,
+                                                 device_info);
+  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/true,
+                                                 device_info);
+  EXPECT_TRUE(fusion.Run(module.get()).value());
+  EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
+  TF_ASSERT_OK(verifier().Run(module.get()).status());
+
+  VLOG(2) << "Dump after horizontal fusion:";
+  VLOG(2) << module->ToString();
+
+  const HloInstruction* entry_root =
+      module->entry_computation()->root_instruction();
+  const HloInstruction* fusion_instr = nullptr;
+  // Check that we add bitcast when needed.
+  ASSERT_THAT(entry_root,
+              GmockMatch(m::Tuple(
+                  m::Bitcast(m::GetTupleElement(m::Fusion(&fusion_instr))),
+                  m::Bitcast(m::GetTupleElement(m::Fusion())))));
+  ASSERT_TRUE(fusion_instr->IsMultiOutputFusion());
+  EXPECT_THAT(fusion_instr->fused_expression_root(),
+              GmockMatch(m::Tuple(
+                  m::Slice(m::Concatenate(m::Reshape(), m::Reshape())),
+                  m::Slice(m::Concatenate(m::Reshape(), m::Reshape())))));
+
+  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
+}
+
+TEST_F(HorizontalLoopFusionTest, GradientDescentOptimizerLike) {
+  HloComputation::Builder builder(TestName());
+
+  std::vector<HloInstruction*> var_outs;
+  for (int64_t i = 0; i < 128; ++i) {
+    // For shapes {1, 1024}, {2, 1024}, ..., {128, 1024}
+    Shape shape = ShapeUtil::MakeShape(F32, {i + 1, 1024});
+    HloInstruction* param_var_in = builder.AddInstruction(
+        HloInstruction::CreateParameter(i * 3 + 0, shape, "var.in"));
+    HloInstruction* param_alpha =
+        builder.AddInstruction(HloInstruction::CreateParameter(
+            i * 3 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
+    HloInstruction* param_delta = builder.AddInstruction(
+        HloInstruction::CreateParameter(i * 3 + 2, shape, "delta"));
+    HloInstruction* alpha_broadcasted = builder.AddInstruction(
+        HloInstruction::CreateBroadcast(shape, param_alpha, {}));
+    HloInstruction* alpha_delta =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            shape, HloOpcode::kMultiply, alpha_broadcasted, param_delta));
+    HloInstruction* var_out =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            shape, HloOpcode::kSubtract, param_var_in, alpha_delta));
+    var_outs.push_back(var_out);
+  }
+  builder.AddInstruction(HloInstruction::CreateTuple(var_outs));
+
+  auto module = CreateNewVerifiedModule();
+  module->AddEntryComputation(builder.Build());
+
+  // Testing with the entire gpu optimization pipeline.
+  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0, 0}));
+}
+
+TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule HeterogeneousMultiOutputFusions
+
+ fused_computation.1 {
+   arg.1 = f16[1024]{0} parameter(0)
+   arg.2 = f16[1024]{0} parameter(1)
+   arg.3 = f16[1024]{0} parameter(2)
+   arg.4 = f16[1024]{0} parameter(3)
+   mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
+   mul.2 = f16[1024]{0} multiply(arg.3, arg.4)
+   add.1 = f16[1024]{0} add(mul.1, mul.2)
+   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}) tuple(add.1, mul.1)
+ }
+
+ fused_computation.2 {
+   arg.1 = f16[123]{0} parameter(0)
+   arg.2 = f16[123]{0} parameter(1)
+   arg.3 = f16[123]{0} parameter(2)
+   arg.4 = f16[123]{0} parameter(3)
+   add.1 = f16[123]{0} add(arg.1, arg.2)
+   add.2 = f16[123]{0} add(arg.3, arg.4)
+   mul.1 = f16[123]{0} multiply(add.1, add.2)
+   ROOT tuple.1 = (f16[123]{0}, f16[123]{0}) tuple(mul.1, add.1)
+ }
+
+ ENTRY entry_computation {
+   arg.1 = f16[1024]{0} parameter(0)
+   arg.2 = f16[1024]{0} parameter(1)
+   arg.3 = f16[1024]{0} parameter(2)
+   arg.4 = f16[1024]{0} parameter(3)
+   arg.5 = f16[123]{0} parameter(4)
+   arg.6 = f16[123]{0} parameter(5)
+   arg.7 = f16[123]{0} parameter(6)
+   arg.8 = f16[123]{0} parameter(7)
+   fusion.1 = (f16[1024]{0}, f16[1024]{0})
+       fusion(arg.1, arg.2, arg.3, arg.4),
+           kind=kLoop, calls=fused_computation.1
+   fusion.2 = (f16[123]{0}, f16[123]{0})
+       fusion(arg.5, arg.6, arg.7, arg.8),
+           kind=kLoop, calls=fused_computation.2
+   gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=0
+   gte.2 = f16[1024]{0} get-tuple-element(fusion.1), index=1
+   gte.3 = f16[123]{0} get-tuple-element(fusion.2), index=0
+   gte.4 = f16[123]{0} get-tuple-element(fusion.2), index=1
+   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}, f16[123]{0}, f16[123]{0})
+       tuple(gte.1, gte.2, gte.3, gte.4)
+ }
+)")
+                    .value();
+
+  EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
+  TF_ASSERT_OK(verifier().Run(module.get()).status());
+  EXPECT_FALSE(HloDCE().Run(module.get()).value());
+
+  VLOG(2) << "Dump after horizontal fusion:";
+  VLOG(2) << module->ToString();
+
+  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
+}
+
+TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
+  HloComputation::Builder builder(TestName());
+
+  std::vector<HloInstruction*> all_outputs;
+  for (int64_t i = 0; i < 48; ++i) {
+    Shape shape = ShapeUtil::MakeShape(F32, {2, 1024 + i});
+    // ms <- grad**2 * (1 - rho) + ms * rho
+    HloInstruction* grad = builder.AddInstruction(
+        HloInstruction::CreateParameter(i * 9 + 0, shape, "grad"));
+    HloInstruction* ms = builder.AddInstruction(
+        HloInstruction::CreateParameter(i * 9 + 1, shape, "ms"));
+    HloInstruction* rho =
+        builder.AddInstruction(HloInstruction::CreateParameter(
+            i * 9 + 2, ShapeUtil::MakeShape(F32, {}), "rho"));
+    HloInstruction* one_minus_rho =
+        builder.AddInstruction(HloInstruction::CreateParameter(
+            i * 9 + 3, ShapeUtil::MakeShape(F32, {}), "one_minus_rho"));
+    HloInstruction* rho_broadcasted =
+        builder.AddInstruction(HloInstruction::CreateBroadcast(shape, rho, {}));
+    HloInstruction* one_minus_rho_broadcasted = builder.AddInstruction(
+        HloInstruction::CreateBroadcast(shape, one_minus_rho, {}));
+    HloInstruction* grad_squared = builder.AddInstruction(
+        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad, grad));
+    HloInstruction* ms_1st_term = builder.AddInstruction(
+        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad_squared,
+                                     one_minus_rho_broadcasted));
+    HloInstruction* ms_2nd_term =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            shape, HloOpcode::kMultiply, ms, rho_broadcasted));
+    HloInstruction* ms_out =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            shape, HloOpcode::kAdd, ms_1st_term, ms_2nd_term));
+
+    // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+    HloInstruction* momentum = builder.AddInstruction(
+        HloInstruction::CreateParameter(i * 9 + 4, shape, "momentum"));
+    HloInstruction* mom = builder.AddInstruction(
+        HloInstruction::CreateParameter(i * 9 + 5, shape, "mom"));
+    HloInstruction* lr = builder.AddInstruction(HloInstruction::CreateParameter(
+        i * 9 + 6, ShapeUtil::MakeShape(F32, {}), "lr"));
+    HloInstruction* epsilon =
+        builder.AddInstruction(HloInstruction::CreateParameter(
+            i * 9 + 7, ShapeUtil::MakeShape(F32, {}), "epsilon"));
+    HloInstruction* lr_broadcasted =
+        builder.AddInstruction(HloInstruction::CreateBroadcast(shape, lr, {}));
+    HloInstruction* epsilon_broadcasted = builder.AddInstruction(
+        HloInstruction::CreateBroadcast(shape, epsilon, {}));
+    HloInstruction* mom_1st_term =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            shape, HloOpcode::kMultiply, momentum, mom));
+    HloInstruction* ms_eps =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            shape, HloOpcode::kAdd, ms_out, epsilon_broadcasted));
+    HloInstruction* ms_eps_rsq = builder.AddInstruction(
+        HloInstruction::CreateUnary(shape, HloOpcode::kRsqrt, ms_eps));
+    HloInstruction* grad_ms_eps_rsq =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            shape, HloOpcode::kMultiply, grad, ms_eps_rsq));
+    HloInstruction* mom_2nd_term =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            shape, HloOpcode::kMultiply, lr_broadcasted, grad_ms_eps_rsq));
+    HloInstruction* mom_out =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            shape, HloOpcode::kAdd, mom_1st_term, mom_2nd_term));
+
+    // var <- var - mom
+    HloInstruction* var = builder.AddInstruction(
+        HloInstruction::CreateParameter(i * 9 + 8, shape, "var"));
+    HloInstruction* var_out =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            shape, HloOpcode::kSubtract, var, mom_out));
+
+    all_outputs.push_back(ms_out);
+    all_outputs.push_back(mom_out);
+    all_outputs.push_back(var_out);
+  }
+  builder.AddInstruction(HloInstruction::CreateTuple(all_outputs));
+
+  auto module = CreateNewVerifiedModule();
+  module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
+}
+
+TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule DynamicUpdateSlice
+
+  fusion.1 {
+    p.0 = f16[5,9,10]{2,1,0} parameter(0)
+    p.1 = s32[] parameter(1)
+    p.2 = f16[1,9,10]{2,1,0} parameter(2)
+    c.0 = s32[] constant(0)
+    ROOT %dynamic-update-slice =
+        f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
+  }
+
+  fusion.2 {
+    p.0 = f16[5,9,10]{2,1,0} parameter(0)
+    p.1 = s32[] parameter(1)
+    p.2 = f16[1,9,10]{2,1,0} parameter(2)
+    c.0 = s32[] constant(0)
+    ROOT %dynamic-update-slice =
+        f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
+  }
+
+  ENTRY entry {
+    p.00 = f16[5,9,10]{2,1,0} parameter(0)
+    p.01 = f16[5,9,10]{2,1,0} parameter(1)
+    p.10 = s32[] parameter(2)
+    p.11 = s32[] parameter(3)
+    p.20 = f16[1,9,10]{2,1,0} parameter(4)
+    p.21 = f16[1,9,10]{2,1,0} parameter(5)
+
+    f1 = f16[5,9,10] fusion(p.00, p.10, p.20), kind=kLoop, calls=fusion.1
+    f2 = f16[5,9,10] fusion(p.01, p.11, p.21), kind=kLoop, calls=fusion.2
+    ROOT tuple = (f16[5,9,10],f16[5,9,10]) tuple(f1, f2)
+  })")
+                    .value();
+
+  EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
+  TF_ASSERT_OK(verifier().Run(module.get()).status());
+  EXPECT_FALSE(HloDCE().Run(module.get()).value());
+
+  VLOG(2) << "Dump after horizontal fusion:";
+  VLOG(2) << module->ToString();
+
+  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
+}
+
+TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule NegativeTestForSharedParam
+
+ fused_computation.1 {
+   arg.1 = f16[123]{0} parameter(0)
+   arg.2 = f16[123]{0} parameter(1)
+   ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.2 {
+   arg.1 = f16[123]{0} parameter(0)
+   arg.2 = f16[123]{0} parameter(1)
+   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
+ }
+
+ ENTRY entry_computation {
+   arg.1 = f16[123]{0} parameter(0)
+   // arg.2 is shared by fusion.1 and fusion.2
+   arg.2 = f16[123]{0} parameter(1)
+   arg.3 = f16[123]{0} parameter(2)
+   fusion.1 = f16[123]{0}
+       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
+   fusion.2 = f16[123]{0}
+       fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2
+   ROOT tuple.1 = (f16[123]{0}, f16[123]{0})
+       tuple(fusion.1, fusion.2)
+ }
+)")
+                    .value();
+
+  EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value());
+}
+
+TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule NonfusionInstrs
+
+ fused_computation.0 {
+   arg.0 = f16[] parameter(0)
+   arg.1 = f16[123]{0} parameter(1)
+   broadcast.0 = f16[123]{0} broadcast(arg.0), dimensions={}
+   ROOT mul.1 = f16[123]{0} multiply(broadcast.0, arg.1)
+ }
+
+ fused_computation.1 {
+   arg.0 = f16[] parameter(0)
+   arg.1 = f16[456]{0} parameter(1)
+   broadcast.0 = f16[456]{0} broadcast(arg.0), dimensions={}
+   ROOT add.1 = f16[456]{0} add(broadcast.0, arg.1)
+ }
+
+ ENTRY entry_computation {
+   arg.0 = f16[] parameter(0)
+   arg.1 = f16[] parameter(1)
+   arg.2 = f16[123]{0} parameter(2)
+   arg.3 = f16[456]{0} parameter(3)
+   // Test fusion of non-fusion instructions. sqrt.0 and sqrt.1 are to be
+   // fused.
+   sqrt.0 = f16[] sqrt(arg.0)
+   sqrt.1 = f16[] sqrt(arg.1)
+   // fusion.0 and fusion.1 are to be fused.
+   fusion.0 = f16[123]{0}
+       fusion(sqrt.0, arg.2), kind=kLoop, calls=fused_computation.0
+   fusion.1 = f16[456]{0}
+       fusion(sqrt.1, arg.3), kind=kLoop, calls=fused_computation.1
+   ROOT tuple.1 = (f16[123]{0}, f16[456]{0}) tuple(fusion.0, fusion.1)
+ }
+)")
+                    .value();
+
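+  // Iterate horizontal fusion and DCE to a fixed point so that both the
+  // pre-existing fusions and the standalone sqrt instructions end up
+  // horizontally fused.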
+  HloPassFix<HloPassPipeline> iterative_h_fusion("iterative_h_fusion");
+  iterative_h_fusion.AddPass<HorizontalLoopFusion>();
+  iterative_h_fusion.AddPass<HloDCE>();
+  EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value());
+
+  // Verify that fusion.0 and fusion.1 are fused.
+  const HloInstruction* entry_root =
+      module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(entry_root,
+              GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                  m::GetTupleElement(m::Fusion()))));
+  EXPECT_TRUE(fusion->IsMultiOutputFusion());
+
+  // Verify that the total number of fusion instructions is 2 so that we
+  // know sqrt.0 and sqrt.1 are fused.
+  EXPECT_EQ(
+      absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
+      2);
+}
+
+TEST_F(HorizontalLoopFusionTest, TraversalOrder) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule cluster
+
+ %fused_computation (param_0: f32[256,256], param_1: f32[], param_2: f32[])
+     -> f32[256,256] {
+   %param_0 = f32[256,256]{1,0} parameter(0)
+   %param_1 = f32[] parameter(1)
+   %param_2 = f32[] parameter(2)
+   %multiply.0 = f32[] multiply(f32[] %param_1, f32[] %param_2)
+   %broadcast.0 = f32[256,256]{1,0} broadcast(f32[] %multiply.0), dimensions={}
+   ROOT %multiply.1 = f32[256,256]{1,0}
+       multiply(f32[256,256]{1,0} %param_0, f32[256,256]{1,0} %broadcast.0)
+ }
+
+ %fused_computation.1 (param_0: f32[256,256], param_1: f32[], param_2: f32[])
+     -> f32[256,256] {
+   %param_0 = f32[256,256]{1,0} parameter(0)
+   %param_1 = f32[] parameter(1)
+   %param_2 = f32[] parameter(2)
+   %multiply.0 = f32[] multiply(f32[] %param_1, f32[] %param_2)
+   %broadcast.0 = f32[256,256]{1,0} broadcast(f32[] %multiply.0), dimensions={}
+   ROOT %multiply.1 = f32[256,256]{1,0}
+       multiply(f32[256,256]{1,0} %param_0, f32[256,256]{1,0} %broadcast.0)
+ }
+
+ ENTRY %entry_computation (arg0: f32[256,256], arg1: f32[256,256], arg2: f32[],
+                           arg3: f32[], arg4: f32[], arg5: f32[])
+                               -> (f32[256,256], f32[256,256]) {
+   %arg0 = f32[256,256]{1,0} parameter(0), parameter_replication={false}
+   %arg1 = f32[256,256]{1,0} parameter(1), parameter_replication={false}
+   %arg2 = f32[] parameter(2), parameter_replication={false}
+   %arg3 = f32[] parameter(3), parameter_replication={false}
+   %arg4 = f32[] parameter(4), parameter_replication={false}
+   %arg5 = f32[] parameter(5), parameter_replication={false}
+   %sqrt = f32[] sqrt(f32[] %arg2)
+   %sqrt.1 = f32[] sqrt(f32[] %arg3)
+   %fusion = f32[256,256]{1,0}
+       fusion(f32[256,256]{1,0} %arg0, f32[] %sqrt, f32[] %sqrt.1),
+       kind=kLoop, calls=%fused_computation
+   %sqrt.2 = f32[] sqrt(f32[] %arg4)
+   %sqrt.3 = f32[] sqrt(f32[] %arg5)
+   %fusion.1 = f32[256,256]{1,0}
+       fusion(f32[256,256]{1,0} %arg1, f32[] %sqrt.2, f32[] %sqrt.3),
+       kind=kLoop, calls=%fused_computation.1
+   ROOT %tuple.163 = (f32[256,256]{1,0}, f32[256,256]{1,0})
+       tuple(f32[256,256]{1,0} %fusion.1, f32[256,256]{1,0} %fusion)
+ }
+)")
+                    .value();
+
+  HloPassFix<HloPassPipeline> iterative_h_fusion("iterative_h_fusion");
+  iterative_h_fusion.AddPass<HorizontalLoopFusion>();
+  EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value());
+
+  // Verify that the total number of fusion instructions is 2 so that we
+  // know all the sqrt instructions are fused into a kernel. Note that if we
+  // traverse from def-to-use (i.e., top-to-down) instead of use-to-def, we
+  // will end up having 3 fusions instead of 2.
+  EXPECT_EQ(
+      absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
+      2);
+}
+
+// Simplified reproducer for Google bug b/242287055.
+// Things that happened:
+//  - horizontal loop fusion joined addition a0 and multiplication m0
+//  - the resulting fusion had 4 inputs: (gte1, gte0, gte1, gte0)
+//  - buffer assignment aliased outputs of this fusion with its inputs
+//  - some threads executed the addition while others executed the
+//    multiplication
+//  - as a result, some inputs were overwritten before they were read
+// The conditional operation is meaningless (its branches are equivalent); it
+// is there only to confuse buffer assignment.
+TEST_F(HorizontalLoopFusionTest, NoBufferAliasingOfDuplicateParameter) {
+  const char* hlo_text = R"(
+HloModule m
+
+branch_a {
+  p0 = s32[] parameter(0)
+  c0 = s32[] constant(1)
+  c1 = s32[] constant(2)
+  b0 = s32[4096] broadcast(c0), dimensions={}
+  b1 = s32[4096] broadcast(c1), dimensions={}
+  ROOT r = (s32[4096], s32[4096]) tuple(b0, b1)
+}
+
+branch_b {
+  p0 = s32[] parameter(0)
+  c0 = s32[] constant(1)
+  c1 = s32[] constant(2)
+  b0 = s32[4096] broadcast(c0), dimensions={}
+  b1 = s32[4096] broadcast(c1), dimensions={}
+  ROOT r = (s32[4096], s32[4096]) tuple(b0, b1)
+}
+
+ENTRY e {
+  p0 = s32[] parameter(0)
+  c0 = s32[] constant(0)
+  cond = (s32[4096], s32[4096]) conditional(p0, c0, c0), branch_computations={branch_a, branch_b}
+  p1 = s32[4096] parameter(1)
+  gte0 = s32[4096] get-tuple-element(cond), index=0
+  gte1 = s32[4096] get-tuple-element(cond), index=1
+  a0 = s32[4096] add(gte1, gte0)
+  m0 = s32[4096] multiply(gte1, gte0)
+  ROOT r = (s32[4096], s32[4096]) tuple(m0, a0)
+}
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt));
+}
+
+TEST_F(HorizontalLoopFusionTest, CopyInsertionFusionControlFlow) {
+  const char* hlo_text = R"(
+HloModule cluster
+
+ENTRY main {
+  cst = f32[1]{0} constant({0})
+  cp1 = f32[1]{0} copy(cst)
+  cp2 = f32[1]{0} copy(cst)
+  cp3 = f32[1]{0} copy(cst)
+  cp4 = f32[1]{0} copy(cst), control-predecessors={cp1}
+  ROOT tuple_out = (f32[1]{0}, f32[1]{0}, f32[1]{0}, f32[1]{0}) tuple(cp1, cp2, cp3, cp4)
+}
+)";
+
+  auto module = ParseAndReturnUnverifiedModule(hlo_text).value();
+  EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
+
+  VLOG(2) << module->ToString();
+
+  // Verify that the total number of fusion instructions is 1.
+  EXPECT_EQ(
+      absl::c_count_if(module->entry_computation()->instructions(), IsFusion),
+      1);
+
+  const HloInstruction* entry_root =
+      module->entry_computation()->root_instruction();
+  // cp2 and cp3 are horizontally fused; cp1 and cp4 are left unfused because
+  // they participate in control dependencies.
+  EXPECT_THAT(entry_root,
+              GmockMatch(m::Tuple(m::Copy(), m::GetTupleElement(m::Fusion()),
+                                  m::GetTupleElement(m::Fusion()), m::Copy())));
+}
+
+TEST_F(HorizontalLoopFusionTest, DoNotMergeVariadicReductions) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule m
+
+  fused_computation.94 {
+    tmp_0 = f32[] parameter(0)
+    tmp_1 = f32[] parameter(1)
+    tmp_2 = pred[] compare(tmp_0, tmp_1), direction=GE
+    tmp_3 = f32[] select(tmp_2, tmp_0, tmp_1)
+    tmp_4 = pred[] compare(tmp_0, tmp_1), direction=EQ
+    tmp_5 = s32[] parameter(2)
+    tmp_6 = s32[] parameter(3)
+    tmp_7 = s32[] minimum(tmp_5, tmp_6)
+    tmp_8 = s32[] select(tmp_2, tmp_5, tmp_6)
+    tmp_9 = s32[] select(tmp_4, tmp_7, tmp_8)
+    ROOT tmp_10 = (f32[], s32[]) tuple(tmp_3, tmp_9)
+  }
+
+  minmax_func.1536 {
+    tmp_0 = f32[] parameter(0)
+    tmp_1 = f32[] parameter(2)
+    tmp_2 = s32[] parameter(1)
+    tmp_3 = s32[] parameter(3)
+    ROOT tmp_4 = (f32[], s32[]) fusion(tmp_0, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=fused_computation.94
+  }
+
+  fused_computation {
+    tmp_0 = f32[554112,10]{1,0} parameter(0)
+    tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
+    tmp_2 = f32[] constant(-inf)
+    tmp_3 = s32[] constant(0)
+    ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
+  }
+
+  fused_computation2 {
+    tmp_0 = f32[554112,10]{1,0} parameter(0)
+    tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
+    tmp_2 = f32[] constant(inf)
+    tmp_3 = s32[] constant(1)
+    ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
+  }
+
+  ENTRY e {
+    tmp_0 = f32[554112,10]{1,0} parameter(0)
+    tmp_1 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_0), kind=kLoop, calls=fused_computation
+    tmp_2 = s32[554112]{0} get-tuple-element(tmp_1), index=1
+    tmp_3 = f32[554112,10]{1,0} parameter(1)
+    tmp_4 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_3), kind=kLoop, calls=fused_computation2
+    tmp_5 = s32[554112]{0} get-tuple-element(tmp_4), index=1
+    ROOT tmp_6 = s32[554112]{0} add(tmp_2, tmp_5)
+  })")
+                    .value();
+
+  EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value());
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc
new file mode 100644
index 0000000..5e32f2e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc
@@ -0,0 +1,187 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/instruction_fusion.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/meta/type_traits.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/fusion_node_indexing_evaluation.h"
+#include "xla/service/fusion_queue.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/shape.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+bool ElementIsF32OrF16(const Shape& shape) {
+  PrimitiveType type = shape.element_type();
+  return type == F32 || type == F16;
+}
+
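+// A fusion queue that never yields an instruction. Returned by
+// GetFusionQueue() for computations that are not in the fusible set, which
+// makes the pass skip them.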
+class EmptyFusionQueue : public FusionQueue {
+ public:
+  std::pair<HloInstruction*, std::vector<int64_t>>
+  DequeueNextInstructionAndOperandsToFuseInOrder() override {
+    return {nullptr, {}};
+  }
+  void RemoveInstruction(HloInstruction* instruction) override {}
+  const std::vector<bool>* FusionConfiguration() override { return nullptr; }
+};
+
+}  // namespace
+
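+// Clears the per-run indexing-evaluation cache and records which computations
+// are fusible before delegating to the base InstructionFusion pass.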
+absl::StatusOr<bool> GpuInstructionFusion::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  fusion_node_evaluations_.clear();
+  auto fusible_computations =
+      GetFusibleComputations(*module, execution_threads);
+  fusible_computations_ = {fusible_computations.begin(),
+                           fusible_computations.end()};
+  return InstructionFusion::Run(module, execution_threads);
+}
+
+/*static*/ bool GpuInstructionFusion::IsExpensive(
+    const HloInstruction& instruction) {
+  // Some floating-point math ops are cheap on the GPU.
+  switch (instruction.opcode()) {
+    case HloOpcode::kDivide:
+    case HloOpcode::kSqrt:
+    case HloOpcode::kRsqrt:
+    case HloOpcode::kExp:
+      if (ElementIsF32OrF16(instruction.shape())) {
+        return false;
+      }
+      break;
+    default:
+      break;
+  }
+  return InstructionFusion::IsExpensive(instruction);
+}
+
+FusionDecision GpuInstructionFusion::ShouldFuseInexpensiveChecks(
+    HloInstruction* consumer, int64_t operand_index) {
+  HloInstruction* producer = consumer->mutable_operand(operand_index);
+
+  // Output fusions are not currently supported on GPUs.
+  if (producer->opcode() == HloOpcode::kFusion) {
+    return "the producer is a fusion";
+  }
+
+  if (consumer->IsCustomFusion()) {
+    return "the consumer is a custom fusion";
+  }
+
+  // Cost condition: do not fuse an expensive producer into a consumer that
+  // reuses the producer's output elements, since the expensive computation
+  // would be re-executed for each reuse.
+  if (is_expensive(*producer) &&
+      ReusesOperandElements(consumer, operand_index)) {
+    return "the producer is expensive, and the consumer reuses inputs";
+  }
+
+  // Do not fuse into fusions if the resulting kernel would suffer from
+  // uncoalesced reads due to a transposed memory access pattern.
+  if (IsInputFusibleReduction(*consumer) &&
+      IsPhysicallyTransposing(*producer)) {
+    return "fusing the producer would break read coalescing";
+  }
+
+  RETURN_IF_NOT_FUSIBLE(IsProducerConsumerFusible(*producer, *consumer));
+
+  if (CreatesHeavyComputation(*producer, *consumer)) {
+    return "the fusion would create a heavy computation";
+  }
+
+  return InstructionFusion::ShouldFuse(consumer, operand_index);
+}
+
+FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
+                                                int64_t operand_index) {
+  RETURN_IF_NOT_FUSIBLE(ShouldFuseInexpensiveChecks(consumer, operand_index));
+
+  auto producer = consumer->operand(operand_index);
+
+  // The following checks are potentially expensive.
+  RETURN_IF_NOT_FUSIBLE(
+      FusionFitsInBudget(*consumer, *producer, device_info_,
+                         /*is_consumer_producer_fusion=*/true));
+
+  if (consumer->opcode() != HloOpcode::kFusion) {
+    return {};
+  }
+
+  // Also check that our emitter can handle the fusion node. We currently can
+  // have exponential time/memory requirements for emitting certain fusion
+  // kernels, in which case we don't want to fuse.
+  // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
+  if (fusion_node_evaluations_.find(consumer) ==
+      fusion_node_evaluations_.end()) {
+    // We have no cached results for this fusion node yet. This can happen when
+    // we run the InstructionFusion pass more than once. We can only cache the
+    // results within one run.
+    fusion_node_evaluations_.emplace(consumer,
+                                     FusionNodeIndexingEvaluation(consumer));
+  }
+  if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) {
+    return "the fusion would result in an overly large code duplication";
+  }
+  return {};
+}
+
+HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
+    const HloInstruction* producer, const HloInstruction* consumer) {
+  return ChooseFusionKind(*producer, *consumer);
+}
+
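+// Fuses 'producer' into 'fusion_instruction' while keeping the cached
+// FusionNodeIndexingEvaluation for that fusion node up to date.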
+HloInstruction* GpuInstructionFusion::FuseInstruction(
+    HloInstruction* fusion_instruction, HloInstruction* producer) {
+  auto evaluation = fusion_node_evaluations_.find(fusion_instruction);
+  if (evaluation == fusion_node_evaluations_.end()) {
+    evaluation = fusion_node_evaluations_
+                     .emplace(fusion_instruction,
+                              FusionNodeIndexingEvaluation(fusion_instruction))
+                     .first;
+  }
+  auto indexing_users = evaluation->second.RemoveFusionOperand(producer);
+  HloInstruction* new_producer =
+      InstructionFusion::FuseInstruction(fusion_instruction, producer);
+  evaluation->second.UpdateEvaluationCache(new_producer, indexing_users);
+  return new_producer;
+}
+
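+// Uses the default fusion queue for fusible computations and an empty queue
+// (i.e., no fusion) for everything else.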
+std::unique_ptr<FusionQueue> GpuInstructionFusion::GetFusionQueue(
+    HloComputation* computation) {
+  if (fusible_computations_.contains(computation)) {
+    return InstructionFusion::GetFusionQueue(computation);
+  }
+  return std::make_unique<EmptyFusionQueue>();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/instruction_fusion.h b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.h
new file mode 100644
index 0000000..d7fb7f2
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.h
@@ -0,0 +1,82 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_INSTRUCTION_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_INSTRUCTION_FUSION_H_
+
+#include <cstdint>
+
+#include <memory>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/fusion_node_indexing_evaluation.h"
+#include "xla/service/fusion_queue.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+class GpuInstructionFusion : public InstructionFusion {
+ public:
+  GpuInstructionFusion(bool may_duplicate, const se::DeviceDescription& d)
+      : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate),
+        device_info_(d) {}
+
+  static bool IsExpensive(const HloInstruction& instruction);
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ protected:
+  std::unique_ptr<FusionQueue> GetFusionQueue(
+      HloComputation* computation) override;
+  FusionDecision ShouldFuse(HloInstruction* consumer,
+                            int64_t operand_index) override;
+
+  HloInstruction::FusionKind ChooseKind(
+      const HloInstruction* producer, const HloInstruction* consumer) override;
+
+ private:
+  // Called by ShouldFuse() to run the computationally inexpensive checks on
+  // whether the operand at 'operand_index' should be fused into 'consumer'.
+  FusionDecision ShouldFuseInexpensiveChecks(HloInstruction* consumer,
+                                             int64_t operand_index);
+
+  HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
+                                  HloInstruction* producer) override;
+
+  // Computations eligible for fusion in the current Run() invocation.
+  absl::flat_hash_set<const HloComputation*> fusible_computations_;
+
+  // Keeps track of the number of times each instruction inside a fusion node
+  // is indexed with different index vectors.
+  absl::flat_hash_map<const HloInstruction*, FusionNodeIndexingEvaluation>
+      fusion_node_evaluations_;
+
+  se::DeviceDescription device_info_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_INSTRUCTION_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc
new file mode 100644
index 0000000..454dab1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc
@@ -0,0 +1,1006 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/instruction_fusion.h"
+
+#include <cstdint>
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/test_utils.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace m = ::xla::match;
+
+namespace xla {
+namespace gpu {
+
+class InstructionFusionTest : public HloTestBase {
+ public:
+  GpuInstructionFusion duplicating_instruction_fusion_{
+      /*may_duplicate=*/true, TestGpuDeviceInfo::RTXA6000DeviceInfo()};
+};
+
+TEST_F(InstructionFusionTest, NoFusionIntoCustomFusionConsumer) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+c {
+  p0 = bf16[3000,53]{1,0} parameter(0)
+  p1 = bf16[22,53]{1,0} parameter(1)
+  d = bf16[3000,22]{1,0} dot(p0, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+  r = bf16[1,1,3000,22]{3,2,1,0} reshape(d)
+  ROOT c = bf16[1,1,3000,22]{2,1,3,0} copy(r)
+}
+
+ENTRY e {
+  p1 = bf16[3000,53]{1,0} parameter(1)
+  p0 = bf16[22,53]{1,0} parameter(0)
+  cp0 = bf16[22,53]{1,0} convert(p0)
+  ROOT f = bf16[1,1,3000,22]{2,1,3,0} fusion(p1, cp0), kind=kCustom, calls=c
+})"));
+
+  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest,
+       CostlyProducerAndOperandElementReusingConsumerNotFused) {
+  HloComputation::Builder builder(TestName());
+  HloInstruction* const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
+  HloInstruction* log1 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(F32, {}), HloOpcode::kLog, const0));
+  HloInstruction* broadcast2 =
+      builder.AddInstruction(HloInstruction::CreateBroadcast(
+          ShapeUtil::MakeShape(F32, {1}), log1, {}));
+
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_EQ(broadcast2, computation->root_instruction());
+  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+  EXPECT_EQ(broadcast2, computation->root_instruction());
+}
+
+TEST_F(InstructionFusionTest,
+       NonCostlyProducerAndOperandElementReusingConsumerFused) {
+  HloComputation::Builder builder(TestName());
+  HloInstruction* const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
+  HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0));
+  HloInstruction* broadcast2 =
+      builder.AddInstruction(HloInstruction::CreateBroadcast(
+          ShapeUtil::MakeShape(S32, {1}), negate1, {}));
+
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_EQ(broadcast2, computation->root_instruction());
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Fusion()));
+}
+
+TEST_F(InstructionFusionTest,
+       CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) {
+  HloComputation::Builder builder(TestName());
+  HloInstruction* const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
+  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(F32, {}), HloOpcode::kExp, const0));
+  HloInstruction* reshape2 = builder.AddInstruction(
+      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), exp1));
+
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_EQ(reshape2, computation->root_instruction());
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Fusion()));
+}
+
+TEST_F(InstructionFusionTest,
+       CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) {
+  HloComputation::Builder builder(TestName());
+  HloInstruction* const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
+  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShape(F32, {}), HloOpcode::kExp, const0));
+  HloInstruction* transpose2 = builder.AddInstruction(
+      HloInstruction::CreateTranspose(ShapeUtil::MakeShape(F32, {}), exp1, {}));
+
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_EQ(transpose2, computation->root_instruction());
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Fusion()));
+}
+
+TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotFused) {
+  HloComputation::Builder builder(TestName());
+  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeShape(F32, {1, 1}), "0"));
+  auto dot1 = builder.AddInstruction(
+      CreateCanonicalDot(ShapeUtil::MakeShape(F32, {1, 1}), param0, param0));
+  auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(F32, {1, 1, 1}), dot1));
+  auto log = builder.AddInstruction(HloInstruction::CreateUnary(
+      reshape2->shape(), xla::HloOpcode::kLog, reshape2));
+
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_EQ(log, computation->root_instruction());
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
+  HloComputation::Builder builder(TestName());
+  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
+  auto dot1 = builder.AddInstruction(
+      CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
+  auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
+      ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
+
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_EQ(transpose2, computation->root_instruction());
+  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+// Tests that broadcasts are fused into a fusion with a reduce root.
+TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY BroadcastIntoReduce {
+      constant = f32[] constant(1)
+      broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), dimensions={}
+      constant.1 = f32[] constant(0)
+      ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3},
+                                                         to_apply=add
+    })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, GmockMatch(m::Fusion()));
+  EXPECT_THAT(
+      root->fused_expression_root(),
+      GmockMatch(m::Reduce(m::Broadcast(m::Constant()), m::Constant())));
+}
+
+TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduce) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY entry {
+      p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
+      copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
+      constant.1 = f32[] constant(0)
+      ROOT reduce = f32[16] reduce(copy, constant.1), dimensions={0,1,2}, to_apply=add
+    })")
+                    .value();
+
+  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    fused_reduce {
+      p0.1 = f32[16,16,16,16]{0,1,2,3} parameter(0)
+      mul = f32[16,16,16,16]{0,1,2,3} multiply(p0.1, p0.1)
+      c0.1 = f32[] constant(0)
+      ROOT root = f32[] reduce(mul, c0.1), dimensions={0,1,2,3}, to_apply=add
+    }
+
+    ENTRY entry {
+      p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
+      copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
+      fusion = f32[] fusion(copy), kind=kInput, calls=fused_reduce
+      ROOT root = (f32[]) tuple(fusion)
+    })")
+                    .value();
+
+  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, DoNotRepeatLargeReduceWindow) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY entry {
+      p0 = s32[512,512,2] parameter(0)
+      p1 = f32[1,1,512,512] parameter(1)
+      constant_1 = f32[] constant(1)
+      reduce-window.1 = reduce-window(p1, constant_1),
+        window={size=1x1x9x9}, to_apply=add
+      ROOT ret = gather(reduce-window.1, p0), offset_dims={0,1,2,3},
+        collapsed_slice_dims={}, start_index_map={1,2},
+        index_vector_dim=2, slice_sizes={1,1,1,1}
+    })")
+                    .value();
+
+  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+    ENTRY entry {
+      p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
+      copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
+      ROOT add = f32[16,16,16,16]{0,1,2,3} add(copy, copy)
+    })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, GmockMatch(m::Fusion()));
+  EXPECT_THAT(root->fused_expression_root(),
+              GmockMatch(m::Add(m::Copy(), m::Copy())));
+}
+
+TEST_F(InstructionFusionTest, BitcastIntoAdd) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    ENTRY BroadcastIntoAdd {
+      p0 = f32[4,1,1]{2,1,0} parameter(0)
+      p1 = f32[4,1]{1,0} parameter(1)
+      bitcast = f32[4,1]{1,0} bitcast(p0)
+      ROOT add = f32[4,1] add(bitcast, p1)
+    })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, GmockMatch(m::Fusion()));
+  EXPECT_THAT(root->fused_expression_root(),
+              GmockMatch(m::Add(m::Bitcast(m::Parameter()), m::Parameter())));
+}
+
+TEST_F(InstructionFusionTest, AddIntoBitcast) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    ENTRY BroadcastIntoAdd {
+      p0 = f32[4,1]{1,0} parameter(0)
+      p1 = f32[4,1]{1,0} parameter(1)
+      add = f32[4,1] add(p0, p1)
+      ROOT bitcast = f32[4,1,1] bitcast(add)
+    })")
+                    .value();
+
+  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, ConvertIntoBitcastBothConsumedByTuple) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test
+
+  ENTRY main {
+    param_0 = f32[2048,16000]{1,0} parameter(0)
+    convert = bf16[2048,16000]{1,0} convert(param_0)
+    bitcast = bf16[16000,1,2048]{2,1,0} bitcast(convert)
+    ROOT tuple.143 = (bf16[16000,1,2048]{2,1,0}, bf16[2048,16000]{1,0}) tuple(bitcast, convert)
+  })")
+                    .value();
+  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, DontFuseGTE) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test_module
+  ENTRY DontFuseGTE {
+    p0 = (f32[10], f32[10]) parameter(0)
+    gte0 = f32[10] get-tuple-element(p0), index=0
+    gte1 = f32[10] get-tuple-element(p0), index=1
+    ROOT add = f32[10] add(gte0, gte1)
+  })")
+                    .value();
+
+  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+// Compute sum(p1/p0), where p0 and p1 have type f32, twice.  Check that the
+// division is duplicated and fused into both reduces.
+TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test_module
+  Add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+  ENTRY TestComputation {
+    zero = f32[] constant(0)
+    p0 = f32[100] parameter(0)
+    p1 = f32[100] parameter(1)
+    recip = f32[100] divide(p1, p0)
+    sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+    sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+    ROOT root = (f32[], f32[]) tuple(sum1, sum2)
+  })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Fusion(), m::Fusion())))
+      << module->ToString();
+}
+
+// Compute sum(p1/p0), where p0 and p1 have type s32, twice.  Check that the
+// division is *not* duplicated and fused into both reduces, because integer
+// division is not considered cheap.
+TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test_module
+  Add {
+    lhs = s32[] parameter(0)
+    rhs = s32[] parameter(1)
+    ROOT add = s32[] add(lhs, rhs)
+  }
+  ENTRY TestComputation {
+    zero = s32[] constant(0)
+    p0 = s32[100] parameter(0)
+    p1 = s32[100] parameter(1)
+    recip = s32[100] divide(p1, p0)
+    sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+    sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+    ROOT mul = (s32[], s32[]) tuple(sum1, sum2)
+  })")
+                    .value();
+
+  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value())
+      << module->ToString();
+}
+
+TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test_module
+  ENTRY NoOutputFusion {
+    alpha = f32[] constant(3)
+    broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
+    p0 = f32[4,3]{1,0} parameter(0)
+    p1 = f32[3,4]{1,0} parameter(1)
+    dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    d = f32[4,4]{1,0} multiply(dot, dot)
+    ROOT mul = f32[4,4] multiply(d, broadcast)
+  })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, GmockMatch(m::Fusion()));
+  EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
+  EXPECT_THAT(
+      root->fused_expression_root(),
+      GmockMatch(m::Multiply(m::Multiply(m::Parameter(), m::Parameter()),
+                             m::Broadcast(m::Constant()))));
+}
+
+// Counts the HLO ops with a given op code in the specified module.
+static int Count(const HloModule& module, HloOpcode op) {
+  int count = 0;
+  for (const auto* computation : module.computations()) {
+    for (const auto* instruction : computation->instructions()) {
+      if (instruction->opcode() == op) {
+        ++count;
+      }
+    }
+  }
+  return count;
+}
+
+TEST_F(InstructionFusionTest, MultiOutputFusion) {
+  // sub --> add --> tuple
+  //  \---------------/
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+    ENTRY OutputFusion {
+     p0 = f32[4,3]{1,0} parameter(0)
+     p1 = f32[4,3]{1,0} parameter(1)
+     p2 = f32[4,3]{1,0} parameter(2)
+     sub = f32[4,3]{1,0} subtract(p0, p2)
+     add = f32[4,3]{1,0} add(sub, p1)
+     ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add)
+    })")
+                    .value();
+
+  // Multi-output fusion is disabled here and performed in the
+  // MultiOutputFusion pass instead.
+  ASSERT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+}
+
+TEST_F(InstructionFusionTest, FuseScalarConstant) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test_module
+
+  ENTRY FuseScalarConstant {
+    p0 = f32[] parameter(0)
+    c0 = f32[] constant(1)
+    add1 = f32[] add(p0, c0)
+    b0 = f32[2]{0} broadcast(add1), dimensions={}
+    c1 = f32[2]{0} constant({1, 2})
+    ROOT add2 = f32[2]{0} add(b0, c1)
+  })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, GmockMatch(m::Fusion()));
+  EXPECT_THAT(
+      root->fused_expression_root(),
+      GmockMatch(m::Add(m::Broadcast(m::Add(m::Parameter(), m::Constant())),
+                        m::Parameter())));
+}
+
+// Check that we limit the number of operands to fusions we create.
+TEST_F(InstructionFusionTest, AvoidsLargeFusion) {
+  constexpr int64_t kNumParams = 200;
+  ASSERT_GT(kNumParams, MaxOperandsAndOutputsPerFusion());
+
+  // Compute p0 + p1 + ... + pN.
+  HloComputation::Builder b(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
+  auto param0 =
+      b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
+  auto sum = param0;
+  for (int64_t i = 1; i < kNumParams; ++i) {
+    auto param =
+        b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p"));
+    sum = b.AddInstruction(
+        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param));
+  }
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputation(b.Build());
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  for (const HloInstruction* instr : computation->instructions()) {
+    EXPECT_LE(instr->operand_count(), MaxOperandsAndOutputsPerFusion())
+        << instr->ToString();
+  }
+}
+
+TEST_F(InstructionFusionTest, FuseIntoScatter) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      lhs = s32[] parameter(0)
+      rhs = s32[] parameter(1)
+      ROOT add = s32[] add(lhs, rhs)
+    }
+
+    ENTRY FuseIntoScatter {
+      p0 = s32[3,3] parameter(0)
+      p1 = s32[2] parameter(1)
+      indices = s32[2] add(p1, p1)
+      p2 = s32[2,3] parameter(2)
+      updates = s32[2,3] add(p2, p2)
+      scatter = s32[3,3] scatter(p0, indices, updates),
+          to_apply=add,
+          update_window_dims={1},
+          inserted_window_dims={0},
+          scatter_dims_to_operand_dims={0},
+          index_vector_dim=1
+      ROOT add = s32[3,3] add(scatter, scatter)
+    })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::Add(m::Fusion(&fusion), m::Fusion())));
+  EXPECT_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kInput);
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
+}
+
+TEST_F(InstructionFusionTest, DontFuseIntoFirstOperandOfScatter) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      lhs = s32[] parameter(0)
+      rhs = s32[] parameter(1)
+      ROOT add = s32[] add(lhs, rhs)
+    }
+
+    ENTRY FuseIntoScatter {
+      p0 = s32[3,3] parameter(0)
+      operand = s32[3,3] add(p0, p0)
+      p1 = s32[2] parameter(1)
+      indices = s32[2] add(p1, p1)
+      p2 = s32[2,3] parameter(2)
+      updates = s32[2,3] add(p2, p2)
+      scatter = s32[3,3] scatter(operand, indices, updates),
+          to_apply=add,
+          update_window_dims={1},
+          inserted_window_dims={0},
+          scatter_dims_to_operand_dims={0},
+          index_vector_dim=1
+      ROOT add = s32[3,3] add(scatter, scatter)
+    })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
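+  // The add feeding the scatter's first operand must stay outside the fusion:
+  // inside the fused computation the scatter's first operand remains a
+  // parameter, since scatter updates that operand in place.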
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::Add(m::Fusion(&fusion), m::Fusion())));
+  EXPECT_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kInput);
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
+}
+
+TEST_F(InstructionFusionTest, ScatterOpShouldNotFuseWithSharedOperand) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test_module
+
+  add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+
+  ENTRY Test {
+    parameter.0 = f32[8,8] parameter(0)
+    parameter.1 = s32[7] parameter(1)
+    indices = s32[7] add(parameter.1, parameter.1)
+    slice = f32[7,8] slice(parameter.0), slice={[0:7],[0:8]}
+    ROOT scatter = f32[8,8] scatter(parameter.0, indices, slice),
+        to_apply=add,
+        update_window_dims={1},
+        inserted_window_dims={0},
+        scatter_dims_to_operand_dims={0},
+        index_vector_dim=1
+  })")
+                    .value();
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+  // Verify that scatter and slice are not fused together: scatter modifies
+  // its input buffer in place, that same buffer is also slice's input, and we
+  // don't know where the scatter indices point to.
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root, GmockMatch(m::Fusion(m::Parameter(), m::Slice(), m::Parameter())));
+}
+
+TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY BroadcastIntoReduce {
+      constant = f32[16] constant({0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15})
+      broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), dimensions={0}
+      constant.1 = f32[] constant(0)
+      ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3},
+                                                         to_apply=add
+    })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+  // The f32[16] constant should not be fused into the reduce, but the f32[]
+  // constant should be.
+  auto* root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, GmockMatch(m::Fusion()));
+  EXPECT_THAT(
+      root->fused_instructions_computation()->root_instruction(),
+      GmockMatch(m::Reduce(m::Broadcast(m::Parameter()), m::Constant())));
+}
+
+TEST_F(InstructionFusionTest, FuseReverse) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    ENTRY Reverse {
+      p0 = f32[50,96,1024]{2,1,0} parameter(0)
+      add = f32[50,96,1024]{2,1,0} add(p0, p0)
+      ROOT reverse = f32[50,96,1024] reverse(add), dimensions={0}
+    })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, GmockMatch(m::Fusion()));
+  EXPECT_THAT(root->fused_expression_root(),
+              GmockMatch(m::Reverse(m::Add(m::Parameter(), m::Parameter()))));
+}
+
+TEST_F(InstructionFusionTest, GpuIsExpensiveF32) {
+  auto m = CreateNewVerifiedModule();
+  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0f32, "param0"));
+
+  HloInstruction* one = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+  HloInstruction* div = builder.AddInstruction(
+      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
+  HloInstruction* rem = builder.AddInstruction(
+      HloInstruction::CreateBinary(r0f32, HloOpcode::kRemainder, param0, one));
+  HloInstruction* sqrt = builder.AddInstruction(
+      HloInstruction::CreateUnary(r0f32, HloOpcode::kSqrt, param0));
+  HloInstruction* rsqrt = builder.AddInstruction(
+      HloInstruction::CreateUnary(r0f32, HloOpcode::kRsqrt, param0));
+  HloInstruction* exp = builder.AddInstruction(
+      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
+
+  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*div));
+  EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*rem));
+  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*sqrt));
+  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*rsqrt));
+  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*exp));
+}
+
+TEST_F(InstructionFusionTest, GpuIsExpensiveF64) {
+  auto m = CreateNewVerifiedModule();
+  Shape r0f64 = ShapeUtil::MakeShape(F64, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0f64, "param0"));
+
+  HloInstruction* one = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+  HloInstruction* div = builder.AddInstruction(
+      HloInstruction::CreateBinary(r0f64, HloOpcode::kDivide, param0, one));
+  HloInstruction* rem = builder.AddInstruction(
+      HloInstruction::CreateBinary(r0f64, HloOpcode::kRemainder, param0, one));
+  HloInstruction* sqrt = builder.AddInstruction(
+      HloInstruction::CreateUnary(r0f64, HloOpcode::kSqrt, param0));
+  HloInstruction* rsqrt = builder.AddInstruction(
+      HloInstruction::CreateUnary(r0f64, HloOpcode::kRsqrt, param0));
+  HloInstruction* exp = builder.AddInstruction(
+      HloInstruction::CreateUnary(r0f64, HloOpcode::kExp, param0));
+
+  EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*div));
+  EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*rem));
+  EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*sqrt));
+  EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*rsqrt));
+  EXPECT_TRUE(GpuInstructionFusion::IsExpensive(*exp));
+}
+
+TEST_F(InstructionFusionTest, GpuIsExpensiveS32) {
+  auto m = CreateNewVerifiedModule();
+  Shape r0s32 = ShapeUtil::MakeShape(S32, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0s32, "param0"));
+
+  HloInstruction* one = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+  HloInstruction* div = builder.AddInstruction(
+      HloInstruction::CreateBinary(r0s32, HloOpcode::kDivide, param0, one));
+  HloInstruction* rem = builder.AddInstruction(
+      HloInstruction::CreateBinary(r0s32, HloOpcode::kRemainder, param0, one));
+
+  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*div));
+  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*rem));
+}
+
+TEST_F(InstructionFusionTest, GpuIsExpensiveBroadcastS32) {
+  auto m = CreateNewVerifiedModule();
+  Shape r1s32 = ShapeUtil::MakeShape(S32, {10});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1s32, "param0"));
+
+  HloInstruction* one = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+  HloInstruction* one_broad =
+      builder.AddInstruction(HloInstruction::CreateBroadcast(r1s32, one, {}));
+
+  HloInstruction* div = builder.AddInstruction(HloInstruction::CreateBinary(
+      r1s32, HloOpcode::kDivide, param0, one_broad));
+  HloInstruction* rem = builder.AddInstruction(HloInstruction::CreateBinary(
+      r1s32, HloOpcode::kRemainder, param0, one_broad));
+
+  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*div));
+  EXPECT_FALSE(GpuInstructionFusion::IsExpensive(*rem));
+}
+
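+// Compute sum(exp(p0)), where p0 has type f32, twice.  Check that the
+// exponential is duplicated and fused into both reduces, because f32 exp is
+// considered cheap (see GpuIsExpensiveF32 above).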
+TEST_F(InstructionFusionTest, FloatingPointExpIsCheap) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test_module
+  Add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+  ENTRY TestComputation {
+    zero = f32[] constant(0)
+    p0 = f32[100] parameter(0)
+    recip = f32[100] exponential(p0)
+    sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+    sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+    ROOT root = (f32[], f32[]) tuple(sum1, sum2)
+  })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Fusion(), m::Fusion())))
+      << module->ToString();
+}
+
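+// Check that a reduction with a small reduced dimension (4 elements) is still
+// emitted as an input (reduction) fusion rather than lowered to a loop fusion.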
+TEST_F(InstructionFusionTest, SmallReducedDimensionIsNotLoweredToLoop) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      lhs = s32[] parameter(0)
+      rhs = s32[] parameter(1)
+      ROOT add = s32[] add(lhs, rhs)
+    }
+
+    ENTRY FuseSmallReduction {
+      p0 = s32[1048576,4] parameter(0)
+      p1 = s32[1048576,4] parameter(1)
+      sum = s32[1048576,4] add(p0, p1)
+      init = s32[] constant(0)
+      ROOT reduce = s32[1048576] reduce(sum, init), dimensions={1}, to_apply=add
+    })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, GmockMatch(m::Fusion()));
+  EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kInput);
+}
+
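+// Check that the iota and the init constants are fused into the variadic
+// reduction fusion, while the f32 input stays a fusion parameter.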
+TEST_F(InstructionFusionTest, IotaIntoVariadicReduction) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule m
+
+  f {
+    tmp_0 = f32[] parameter(0)
+    tmp_1 = f32[] parameter(1)
+    tmp_2 = pred[] compare(tmp_0, tmp_1), direction=GE
+    tmp_3 = f32[] select(tmp_2, tmp_0, tmp_1)
+    tmp_4 = pred[] compare(tmp_0, tmp_1), direction=EQ
+    tmp_5 = s32[] parameter(2)
+    tmp_6 = s32[] parameter(3)
+    tmp_7 = s32[] minimum(tmp_5, tmp_6)
+    tmp_8 = s32[] select(tmp_2, tmp_5, tmp_6)
+    tmp_9 = s32[] select(tmp_4, tmp_7, tmp_8)
+    ROOT tmp_10 = (f32[], s32[]) tuple(tmp_3, tmp_9)
+  }
+
+  minmax {
+    tmp_0 = f32[] parameter(0)
+    tmp_1 = f32[] parameter(2)
+    tmp_2 = s32[] parameter(1)
+    tmp_3 = s32[] parameter(3)
+    ROOT tmp_4 = (f32[], s32[]) fusion(tmp_0, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=f
+  }
+
+  ENTRY e {
+    tmp_0 = f32[554112,10]{1,0} parameter(0)
+    tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
+    tmp_2 = f32[] constant(-inf)
+    tmp_3 = s32[] constant(0)
+    ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax
+  })")
+                    .value();
+
+  EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false,
+                                   TestGpuDeviceInfo::RTXA6000DeviceInfo())
+                  .Run(module.get())
+                  .value());
+  ASSERT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Fusion(m::Parameter())));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction()->fused_expression_root(),
+      GmockMatch(
+          m::Reduce(m::Parameter(), m::Iota(), m::Constant(), m::Constant())));
+}
+
+TEST_F(InstructionFusionTest, InputReductionFusion) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+    add.clone.13 {
+      x.27 = f32[] parameter(0)
+      y.27 = f32[] parameter(1)
+      ROOT add.1036 = f32[] add(x.27, y.27)
+    }
+    add.clone.14 {
+      x.28 = f32[] parameter(0)
+      y.28 = f32[] parameter(1)
+      ROOT add.1037 = f32[] add(x.28, y.28)
+    }
+    add {
+      x = bf16[] parameter(0)
+      convert.448 = f32[] convert(x)
+      y = bf16[] parameter(1)
+      convert.449 = f32[] convert(y)
+      add.597 = f32[] add(convert.448, convert.449)
+      ROOT convert.450 = bf16[] convert(add.597)
+    }
+    ENTRY FuseSmallReduction {
+      param_2.7 = bf16[8,16,64,2048]{3,2,1,0} parameter(2)
+      convert.1395 = f32[8,16,64,2048]{3,2,1,0} convert(param_2.7)
+      param_0.85 = bf16[8,16,64,2048]{3,2,1,0} parameter(0)
+      convert.1393 = f32[8,16,64,2048]{3,2,1,0} convert(param_0.85)
+      multiply.1652 = f32[8,16,64,2048]{3,2,1,0} multiply(convert.1395, convert.1393)
+      convert.1392 = bf16[8,16,64,2048]{3,2,1,0} convert(multiply.1652)
+      bitcast.15934 = bf16[128,64,2048]{2,1,0} bitcast(convert.1392)
+      convert.1391 = f32[128,64,2048]{2,1,0} convert(bitcast.15934)
+      param_1.15 = bf16[] parameter(1)
+      convert.1394 = f32[] convert(param_1.15)
+      reduce.462 = f32[128,64]{1,0} reduce(convert.1391, convert.1394), dimensions={2}, to_apply=add.clone.13
+      reduce.121 = f32[64]{0} reduce(reduce.462, convert.1394), dimensions={0}, to_apply=add.clone.14
+      ROOT convert.890 = bf16[64]{0} convert(reduce.121)
+    })")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  HloInstruction* fused_convert_fusion =
+      module->entry_computation()->root_instruction();
+
+  ASSERT_THAT(fused_convert_fusion, GmockMatch(m::Fusion()));
+  SCOPED_TRACE(module->ToString());
+  EXPECT_EQ(fused_convert_fusion->fusion_kind(),
+            HloInstruction::FusionKind::kInput);
+}
+
+TEST_F(InstructionFusionTest, DotStrengthReductionFusion) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test_module
+
+scalar_add_computation {
+  scalar_rhs = f32[] parameter(1)
+  scalar_lhs = f32[] parameter(0)
+  ROOT add.1 = f32[] add(scalar_lhs, scalar_rhs)
+}
+
+ENTRY main {
+  param_1.3 = f16[16,64,96,6,2,16]{5,4,3,2,1,0} parameter(1)
+  param_0.6 = f16[16,64,96,1,2,16]{5,4,3,2,1,0} parameter(0)
+  bitcast.26 = f16[16,64,96,2,16]{4,3,2,1,0} bitcast(param_0.6)
+  broadcast.4 = f16[16,64,96,6,2,16]{5,4,3,2,1,0} broadcast(bitcast.26), dimensions={0,1,2,4,5}
+  multiply.4 = f16[16,64,96,6,2,16]{5,4,3,2,1,0} multiply(broadcast.4, param_1.3)
+  convert.8 = f32[16,64,96,6,2,16]{5,4,3,2,1,0} convert(multiply.4)
+  constant_2 = f32[] constant(0)
+  reduce.3 = f32[16,64,96,6,2]{3,4,2,1,0} reduce(convert.8, constant_2), dimensions={5}, to_apply=scalar_add_computation
+  bitcast.25 = f32[16,64,96,2,6]{4,3,2,1,0} bitcast(reduce.3)
+  convert.7 = f16[16,64,96,2,6]{4,3,2,1,0} convert(bitcast.25)
+  ROOT bitcast.24 = f16[16,64,96,2,1,6]{5,4,3,2,1,0} bitcast(convert.7)
+})")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  const HloInstruction* fused_convert_fusion =
+      module->entry_computation()->root_instruction()->operand(0);
+
+  ASSERT_THAT(fused_convert_fusion, GmockMatch(m::Fusion()));
+  SCOPED_TRACE(module->ToString());
+  EXPECT_EQ(fused_convert_fusion->fusion_kind(),
+            HloInstruction::FusionKind::kInput);
+  EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1);
+}
+
+TEST_F(InstructionFusionTest, ReductionFusionOtherUnaryElementwiseOpsAreFused) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test_module
+
+scalar_add_computation {
+  scalar_rhs = f32[] parameter(1)
+  scalar_lhs = f32[] parameter(0)
+  ROOT add.1 = f32[] add(scalar_lhs, scalar_rhs)
+}
+
+ENTRY main {
+  param_0 = f16[64,96,6,16]{3,2,1,0} parameter(0)
+  constant_2 = f32[] constant(0)
+  reduce.3 = f32[64,6,16]{2,1,0} reduce(param_0, constant_2), dimensions={1}, to_apply=scalar_add_computation
+  negate = f32[64,6,16]{2,1,0} negate(reduce.3)
+  ROOT sine = f16[64,6,16]{2,1,0} sine(negate)
+})")
+                    .value();
+
+  EXPECT_TRUE(duplicating_instruction_fusion_.Run(module.get()).value());
+
+  HloInstruction* fused_convert_fusion =
+      module->entry_computation()->root_instruction();
+
+  ASSERT_THAT(fused_convert_fusion, GmockMatch(m::Fusion()));
+  SCOPED_TRACE(module->ToString());
+  EXPECT_EQ(fused_convert_fusion->fusion_kind(),
+            HloInstruction::FusionKind::kInput);
+  EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1);
+}
+
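+// Check that fusion does not touch the instructions inside the reduction's
+// to_apply computation.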
+TEST_F(InstructionFusionTest, DoNotFuseInsideReducer) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test_module
+
+scalar_add_computation {
+  scalar_rhs = f32[] parameter(1)
+  scalar_lhs = f32[] parameter(0)
+  add.1 = f32[] add(scalar_lhs, scalar_rhs)
+  ROOT add.2 = f32[] add(add.1, scalar_rhs)
+}
+
+ENTRY main {
+  param_0 = f16[64,96] parameter(0)
+  constant_2 = f32[] constant(0)
+  ROOT reduce = f32[64] reduce(param_0, constant_2), dimensions={1}, to_apply=scalar_add_computation
+})")
+                    .value();
+
+  EXPECT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc
new file mode 100644
index 0000000..caa8d3c1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc
@@ -0,0 +1,596 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/layout_assignment.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/reduction_utils.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/host_memory_offload_annotations.h"
+#include "xla/service/logical_buffer.h"
+#include "xla/shape.h"
+#include "xla/shape_layout.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/tsl/util/env_var.h"
+#include "xla/util.h"
+#include "xla/window_util.h"
+#include "xla/xla.pb.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+using se::dnn::DataLayout;
+using se::dnn::FilterLayout;
+
+// Returns (input, filter, output) layouts.
+static std::tuple<DataLayout, FilterLayout, DataLayout>
+HeuristicLayoutAssignment(const HloInstruction* instr,
+                          const se::GpuComputeCapability& gpu_version,
+                          const se::dnn::VersionInfo& dnn_version) {
+  // DataLayout and FilterLayout use weird enum names. Translations:
+  //   N <=> Batch or Output
+  //   C <=> Depth or Input
+  //   H <=> Y
+  //   W <=> X
+  //
+  // Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
+  //
+  // If you have trouble keeping these straight, consider that all that matters
+  // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)?
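+  //
+  // Applying these translations to the enums used below (for illustration):
+  //   DataLayout::kBatchDepthYX    = N,C,H,W -> NCHW
+  //   DataLayout::kBatchYXDepth    = N,H,W,C -> NHWC
+  //   FilterLayout::kOutputInputYX = O,I,H,W -> the NCHW-style filter layout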
+
+  constexpr auto kAllNCHW =
+      std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX,
+                      DataLayout::kBatchDepthYX);
+  // kBatchDepthYX4 has the same layout as kBatchDepthYX32; they're both VECT_C
+  // layouts as far as cudnn is concerned.
+  constexpr auto kAllNCHW_VECT_C =
+      std::make_tuple(DataLayout::kBatchDepthYX4, FilterLayout::kOutputInputYX4,
+                      DataLayout::kBatchDepthYX4);
+  constexpr auto kAllNHWC =
+      std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
+                      DataLayout::kBatchYXDepth);
+
+  // Integer convolution must use NHWC or NCHW_VECT_C.
+  //
+  // TODO(jlebar): Do non-VECT_C int8_t convs still require NHWC with new
+  // versions of cudnn?
+  const ConvolutionDimensionNumbers& dnums =
+      instr->convolution_dimension_numbers();
+  Shape input_shape = instr->operand(0)->shape();
+  PrimitiveType input_ty = instr->operand(0)->shape().element_type();
+  if (primitive_util::IsIntegralType(input_ty)) {
+    if (input_ty == S8 && dnums.input_spatial_dimensions_size() == 2 &&
+        input_shape.dimensions_size() == 5) {
+      VLOG(2) << "Using NCHW_VECT_C for int8_t conv " << instr->ToString();
+      return kAllNCHW_VECT_C;
+    }
+    VLOG(2) << "Using NHWC for int8_t conv " << instr->ToString();
+    return kAllNHWC;
+  }
+
+  if (primitive_util::IsF8Type(input_ty)) {
+    VLOG(2) << "Using NHWC for FP8 conv " << instr->ToString();
+    return kAllNHWC;
+  }
+
+  const DebugOptions& debug_options =
+      instr->GetModule()->config().debug_options();
+
+  if (debug_options.xla_gpu_force_conv_nchw()) {
+    VLOG(2) << "Overriding layout to NCHW for " << instr->ToString();
+    return kAllNCHW;
+  }
+
+  if (debug_options.xla_gpu_force_conv_nhwc()) {
+    VLOG(2) << "Overriding layout to NHWC for " << instr->ToString();
+    return kAllNHWC;
+  }
+
+  const auto* rocm_compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version);
+  if (rocm_compute_capability && input_ty == F16) return kAllNHWC;
+
+  // If we're not Volta or not fp16/bfloat16, or not conv2D, the decision is
+  // easy: Use NCHW.
+  const bool isFloat16 = (input_ty == F16) || (input_ty == BF16);
+  if (std::holds_alternative<se::CudaComputeCapability>(gpu_version)) {
+    const auto* cuda_compute_capability =
+        std::get_if<se::CudaComputeCapability>(&gpu_version);
+    bool is_volta =
+        cuda_compute_capability &&
+        cuda_compute_capability->IsAtLeast(se::CudaComputeCapability::VOLTA);
+    if (!isFloat16 || !is_volta ||
+        instr->shape().tuple_shapes(0).dimensions_size() != 4) {
+      return kAllNCHW;
+    }
+
+    // Empirically we've found with Volta and cudnn <= 7.3 that backward-input
+    // convs with stride are significantly faster with NCHW layouts.
+    //
+    // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW),
+    // which on paper gives good performance. However, there are two
+    // observations:
+    // * a mixed layout combination is more cuDNN-bug prone, based on empirical
+    //   evidence.
+    // * we've also observed that for mixed layouts, cuDNN transposes data back
+    //   and forth from a different layout combination. If we end up with
+    //   transposes anyway, we prefer to have them in XLA, as they can be fused.
+    if (std::make_tuple(dnn_version.major_version(),
+                        dnn_version.minor_version()) <= std::make_tuple(7, 3) &&
+        instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
+        window_util::HasStride(instr->window())) {
+      return kAllNCHW;
+    }
+  } else if (std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
+    bool is_enabled = false;
+    TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_USE_ROCM_NHWC",
+                                        /*default_val=*/false, &is_enabled));
+    auto rocm_compute_capability =
+        std::get<se::RocmComputeCapability>(gpu_version);
+    if (!isFloat16 || (!rocm_compute_capability.has_nhwc_layout_support()) ||
+        instr->shape().tuple_shapes(0).dimensions_size() != 4 || !is_enabled) {
+      return kAllNCHW;
+    }
+  }
+
+  VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();
+
+  // For the remaining f16/bf16 convolutions (Volta+ CUDA, or ROCm with NHWC
+  // support enabled), use NHWC.
+  return kAllNHWC;
+}
+
+// Adds layout constraints on the cudnn custom-call instruction. The layout
+// constraints are represented in terms of minor_to_major fields of both
+// operands and the output shape. Depending on the underlying algorithm, one of
+// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
+absl::Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
+    HloCustomCallInstruction* instr, LayoutConstraints* constraints) {
+  Shape lhs_shape = instr->operand(0)->shape();
+  Shape rhs_shape = instr->operand(1)->shape();
+  Shape result_shape = instr->shape().tuple_shapes(0);
+
+  Shape* input_shape;
+  Shape* filter_shape;
+  Shape* output_shape;
+
+  TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr));
+  switch (kind) {
+    case CudnnConvKind::kForward:
+    case CudnnConvKind::kForwardActivation:
+    case CudnnConvKind::kForwardGraph:
+      input_shape = &lhs_shape;
+      filter_shape = &rhs_shape;
+      output_shape = &result_shape;
+      break;
+    case CudnnConvKind::kBackwardInput:
+      input_shape = &result_shape;
+      filter_shape = &rhs_shape;
+      output_shape = &lhs_shape;
+      break;
+    case CudnnConvKind::kBackwardFilter:
+      input_shape = &lhs_shape;
+      filter_shape = &result_shape;
+      output_shape = &rhs_shape;
+      break;
+  }
+
+  {
+    DataLayout input;
+    FilterLayout filter;
+    DataLayout output;
+    std::tie(input, filter, output) =
+        HeuristicLayoutAssignment(instr, gpu_version_, dnn_version_);
+
+    TF_ASSIGN_OR_RETURN(
+        std::tie(*input_shape->mutable_layout(),
+                 *filter_shape->mutable_layout(),
+                 *output_shape->mutable_layout()),
+        StreamExecutorConvLayoutsToXlaLayouts(
+            instr->convolution_dimension_numbers(), input, filter, output));
+  }
+
+  // The custom call returns a tuple of (actual_result, scratch_buffer);
+  // call_result_buf is the logical buffer for actual_result, i.e. the buffer
+  // that holds the result of the conv call.
+  TF_ASSIGN_OR_RETURN(
+      const LogicalBuffer* call_result_buf,
+      points_to_analysis_->GetBufferDefinedAt(instr, /*index=*/{0}));
+
+  // Set layouts of the instructions' shapes.
+  TF_RETURN_IF_ERROR(SetOperandLayout(lhs_shape, instr, 0));
+  TF_RETURN_IF_ERROR(SetOperandLayout(rhs_shape, instr, 1));
+  TF_RETURN_IF_ERROR(SetBufferLayout(result_shape.layout(), *call_result_buf));
+  // For fused convolutions, instr->operand(2), if it exists, is the bias
+  // buffer. There is no need to assign a layout to it, as it has only one
+  // dimension. instr->operand(3), if it exists, is the side input buffer.
+  if (kind == CudnnConvKind::kForwardActivation &&
+      instr->operand_count() == 4) {
+    // The side input layout must match the output layout.
+    TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, 3));
+  }
+
+  // For graph convolutions, align the layouts of the non-scalar inputs to any
+  // pointwise ops with the output layout.
+  if (kind == CudnnConvKind::kForwardGraph) {
+    for (int k = 2; k < instr->operand_count(); ++k) {
+      if (!ShapeUtil::IsScalar(instr->operand(k)->shape())) {
+        TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, k));
+      }
+    }
+  }
+
+  if (instr->operand_count() > 2 && kind != CudnnConvKind::kForwardActivation &&
+      kind != CudnnConvKind::kForwardGraph) {
+    return Internal(
+        "Invalid convolution. Conv has a side input, but kind is not fused "
+        "conv forward or graph conv foward: %s",
+        instr->ToString());
+  }
+
+  return absl::OkStatus();
+}
+
+namespace {
+
+// Imposes the default layout, but with the two most minor dimensions swapped,
+// on the input `shape`.
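+//
+// For example (illustrative only):
+//   Shape s = ShapeUtil::MakeShape(F32, {8, 4});  // default layout {1,0}
+//   SetFortranLayout(&s);                         // layout is now {0,1}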
+void SetFortranLayout(Shape* shape) {
+  LayoutUtil::SetToDefaultLayout(shape);
+  int n = shape->mutable_layout()->minor_to_major_size();
+  CHECK_GE(n, 2);
+  std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
+            shape->mutable_layout()->mutable_minor_to_major()->at(1));
+}
+
+bool DotCanSupportShapeWithLayout(const HloInstruction* dot,
+                                  const Shape& shape) {
+  const DotDimensionNumbers& dot_dims = dot->dot_dimension_numbers();
+  // If we are able to construct a `MatrixLayout` then the dot can support
+  // this layout.
+  return MatrixLayout::For(shape, dot_dims.lhs_batch_dimensions().size(),
+                           dot->operand(0)->shape().rank() -
+                               dot_dims.lhs_contracting_dimensions().size() -
+                               dot_dims.lhs_batch_dimensions().size(),
+                           dot_dims.rhs_batch_dimensions().size(),
+                           dot->operand(1)->shape().rank() -
+                               dot_dims.rhs_contracting_dimensions().size() -
+                               dot_dims.rhs_batch_dimensions().size())
+      .ok();
+}
+
+}  // namespace
+
+absl::Status GpuLayoutAssignment::AddBackendConstraints(
+    LayoutConstraints* constraints) {
+  // Add convolution constraints in reverse postorder so that the earliest
+  // convolution's layout propagates first. This reduces the likelihood of
+  // fusion nodes with copies.
+  auto post_order = constraints->computation()->MakeInstructionPostOrder();
+  for (auto iterator = post_order.rbegin(); iterator != post_order.rend();
+       ++iterator) {
+    HloInstruction* instruction = *iterator;
+    if (IsCustomCallToDnnConvolution(*instruction)) {
+      TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall(
+          Cast<HloCustomCallInstruction>(instruction), constraints));
+    }
+
+    CHECK(!IsCublasGemm(*instruction))
+        << "Gemm rewriting should run after layout assignment";
+
+    if (instruction->opcode() == HloOpcode::kDot) {
+      const Shape& output_shape = instruction->shape();
+      const Shape& lhs_shape = instruction->operand(0)->shape();
+      const Shape& rhs_shape = instruction->operand(1)->shape();
+      const DotDimensionNumbers& dot_dims =
+          instruction->dot_dimension_numbers();
+
+      // Matmuls require the batch dimensions to be in consecutive physical
+      // dimensions and likewise for the contracting and non-contracting
+      // dimensions. Additionally, no batch dimension can be in the most
+      // minor physical dimension for inputs or the output.
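+      // For example, an lhs of shape f32[5,2,3] with batch dims {0} and
+      // contracting dims {2} can keep a layout such as {1,2,0}, where the
+      // batch dimension is most major; a layout that puts the batch dimension
+      // most minor would be rejected.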
+      absl::Span<const int64_t> lhs_batch_dims =
+          dot_dims.lhs_batch_dimensions();
+      absl::Span<const int64_t> lhs_contracting_dims =
+          dot_dims.lhs_contracting_dimensions();
+      TF_ASSIGN_OR_RETURN(std::vector<int64_t> lhs_non_contracting_dims,
+                          GetNonContractingDims(lhs_shape, lhs_batch_dims,
+                                                lhs_contracting_dims));
+
+      absl::Span<const int64_t> rhs_batch_dims =
+          dot_dims.rhs_batch_dimensions();
+      absl::Span<const int64_t> rhs_contracting_dims =
+          dot_dims.rhs_contracting_dimensions();
+      TF_ASSIGN_OR_RETURN(std::vector<int64_t> rhs_non_contracting_dims,
+                          GetNonContractingDims(rhs_shape, rhs_batch_dims,
+                                                rhs_contracting_dims));
+
+      const DebugOptions& debug_options =
+          instruction->GetModule()->config().debug_options();
+
+      bool is_bf16_to_bf16 =
+          (output_shape.element_type() == PrimitiveType::BF16 &&
+           lhs_shape.element_type() == PrimitiveType::BF16 &&
+           rhs_shape.element_type() == PrimitiveType::BF16);
+      bool is_s8_to_s32 = (output_shape.element_type() == PrimitiveType::S32 &&
+                           lhs_shape.element_type() == PrimitiveType::S8 &&
+                           rhs_shape.element_type() == PrimitiveType::S8 &&
+                           output_shape.dimensions_size() == 2 &&
+                           lhs_shape.dimensions_size() == 2 &&
+                           rhs_shape.dimensions_size() == 2);
+
+      if (is_s8_to_s32 ||
+          (is_bf16_to_bf16 &&
+           debug_options.xla_gpu_ensure_minor_dot_contraction_dims())) {
+        TF_RETURN_IF_ERROR(SetOperandMajorToMinorLayout(
+            instruction, /*operand=*/0,
+            /*dim_groups=*/
+            {lhs_batch_dims, lhs_non_contracting_dims, lhs_contracting_dims}));
+        TF_RETURN_IF_ERROR(SetOperandMajorToMinorLayout(
+            instruction, /*operand=*/1,
+            /*dim_groups=*/
+            {rhs_batch_dims, rhs_non_contracting_dims, rhs_contracting_dims}));
+        TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints));
+      } else {
+        if (!lhs_batch_dims.empty() || lhs_contracting_dims.size() > 1 ||
+            lhs_non_contracting_dims.size() > 1) {
+          TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 0, lhs_batch_dims,
+                                                 lhs_contracting_dims,
+                                                 lhs_non_contracting_dims));
+        }
+        if (!rhs_batch_dims.empty() || rhs_non_contracting_dims.size() > 1 ||
+            rhs_contracting_dims.size() > 1) {
+          TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 1, rhs_batch_dims,
+                                                 rhs_contracting_dims,
+                                                 rhs_non_contracting_dims));
+        }
+        // If we have at least one batch dimension or there is more than one
+        // non-contracting dimension on lhs or rhs, we need to set a layout for
+        // the dot output.
+        if (!lhs_batch_dims.empty() || lhs_non_contracting_dims.size() > 1 ||
+            rhs_non_contracting_dims.size() > 1) {
+          TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints));
+        }
+      }
+    } else if (instruction->opcode() == HloOpcode::kTranspose) {
+      const HloInstruction* operand = instruction->operand(0);
+      if ((operand->opcode() != HloOpcode::kDot) ||
+          (operand->user_count() > 1)) {
+        continue;
+      }
+
+      // If possible, set layout of the dot operation such that the output of
+      // the transpose (as a bitcast) has the default layout.
+      Shape shape = operand->shape();
+      *shape.mutable_layout() =
+          LayoutUtil::MakeLayoutFromMajorToMinor(instruction->dimensions());
+
+      if (DotCanSupportShapeWithLayout(operand, shape)) {
+        TF_RETURN_IF_ERROR(
+            SetOperandLayout(shape, instruction, /*operand_no=*/0));
+      }
+    } else if (instruction->opcode() == HloOpcode::kFft) {
+      // cuFFT requires a dim0 major layout.
+      Shape op0_shape = instruction->operand(0)->shape();
+      LayoutUtil::SetToDefaultLayout(&op0_shape);
+      Shape output_shape = instruction->shape();
+      LayoutUtil::SetToDefaultLayout(&output_shape);
+      TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
+      TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
+    } else if (instruction->opcode() == HloOpcode::kSort &&
+               instruction->operand(0)->shape().rank() > 1) {
+      // Make sure that all the operands and the output(s) have the same layout.
+      Shape keys_shape = instruction->operand(0)->shape();
+      Layout keys_layout =
+          LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank());
+      for (int64_t i = 0; i < instruction->operand_count(); ++i) {
+        Shape shape = instruction->operand(i)->shape();
+        *shape.mutable_layout() = keys_layout;
+        TF_RETURN_IF_ERROR(SetOperandLayout(shape, instruction, i));
+        const LogicalBuffer* output_buffer;
+        if (instruction->shape().IsArray()) {
+          TF_ASSIGN_OR_RETURN(
+              output_buffer,
+              points_to_analysis_->GetBufferDefinedAt(instruction, {}));
+        } else {
+          TF_ASSIGN_OR_RETURN(
+              output_buffer,
+              points_to_analysis_->GetBufferDefinedAt(instruction, {i}));
+        }
+        TF_RETURN_IF_ERROR(SetBufferLayout(keys_layout, *output_buffer));
+      }
+    } else if (instruction->opcode() == HloOpcode::kTriangularSolve) {
+      // TODO(phawkins): Ideally we would relax this constraint. What we
+      // actually want is that:
+      // a) the batch dimensions are major, in no particular order.
+      // b) the two minor dimensions are in fortran (column-major) order,
+      // although for the 'a' argument we could potentially accept row-major
+      // order and fold the transpose into the operator.
+      Shape op0_shape = instruction->operand(0)->shape();
+      Shape op1_shape = instruction->operand(1)->shape();
+      Shape output_shape = instruction->shape();
+      SetFortranLayout(&op0_shape);
+      SetFortranLayout(&op1_shape);
+      SetFortranLayout(&output_shape);
+      TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
+      TF_RETURN_IF_ERROR(SetOperandLayout(op1_shape, instruction, 1));
+      TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
+    } else if (instruction->opcode() == HloOpcode::kReduceScatter) {
+      // XLA:GPU can only support reduce-scatter where the scatter dimension
+      // is the most major dimension in the layout.
+      auto ars = Cast<HloReduceScatterInstruction>(instruction);
+      TF_RETURN_IF_ERROR(SetInstructionLayout(
+          ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()),
+          ars));
+    } else if (instruction->opcode() == HloOpcode::kAllGather) {
+      // XLA:GPU can only support all-gathers where the gather dimension is the
+      // most major dimension in the layout.
+      auto ag = Cast<HloAllGatherInstruction>(instruction);
+      TF_RETURN_IF_ERROR(SetInstructionLayout(
+          ShapeUtil::MoveDimToMajor(ag->shape(), ag->all_gather_dimension()),
+          ag));
+    } else if (instruction->opcode() == HloOpcode::kAllToAll &&
+               instruction->shape().IsArray()) {
+      // XLA:GPU can only support all-to-all with split dimensions where the
+      // split dimension is the most major dimension in the layout.
+      auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
+      TF_RETURN_IF_ERROR(SetInstructionLayout(
+          ShapeUtil::MoveDimToMajor(all_to_all->shape(),
+                                    *all_to_all->split_dimension()),
+          all_to_all));
+    } else if (instruction->opcode() == HloOpcode::kSend) {
+      Shape s = instruction->operand(0)->shape();
+      LayoutUtil::SetToDefaultLayout(&s);
+      TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction->operand(0)));
+      TF_RETURN_IF_ERROR(
+          SetArrayOperandLayout(s.layout(), instruction->operand(0), 0));
+    } else if (instruction->opcode() == HloOpcode::kRecv) {
+      Shape s = instruction->shape();
+      ShapeUtil::ForEachMutableSubshape(
+          &s, [&](Shape* subshape, const ShapeIndex& index) {
+            LayoutUtil::SetToDefaultLayout(subshape);
+          });
+      TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction));
+    }
+  }
+  return absl::OkStatus();
+}
+
+absl::Status GpuLayoutAssignment::SetDotOperandLayout(
+    const HloInstruction* instruction, int64_t operand,
+    absl::Span<const int64_t> batch_dims, absl::Span<const int64_t> row_dims,
+    absl::Span<const int64_t> col_dims) {
+  Shape shape = instruction->operand(operand)->shape();
+
+  // First, try to use the existing layout, if present.
+  if (shape.has_layout() &&
+      MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
+    // Re-set the operand layout, so it becomes mandatory.
+    return SetOperandLayout(shape, instruction, operand);
+
+  // Next, try the default layout (for the sake of everybody's sanity).
+  LayoutUtil::SetToDefaultLayout(&shape);
+  if (MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
+    return SetOperandLayout(shape, instruction, operand);
+
+  // Otherwise, fall back to forcing a (batch, rows, cols) layout.
+  return SetOperandMajorToMinorLayout(
+      instruction, operand,
+      /*dim_groups=*/{batch_dims, row_dims, col_dims});
+}
+
+absl::Status GpuLayoutAssignment::SetOperandMajorToMinorLayout(
+    const HloInstruction* instruction, int64_t operand,
+    std::initializer_list<absl::Span<const int64_t>> dim_groups) {
+  size_t size = 0;
+  for (auto group : dim_groups) size += group.size();
+  std::vector<int64_t> major_to_minor;
+  major_to_minor.reserve(size);
+  for (const auto& group : dim_groups) {
+    major_to_minor.insert(major_to_minor.end(), group.begin(), group.end());
+  }
+
+  Shape shape = instruction->operand(operand)->shape();
+  *shape.mutable_layout() =
+      LayoutUtil::MakeLayoutFromMajorToMinor(major_to_minor);
+  return SetOperandLayout(shape, instruction, operand);
+}
+
+absl::Status GpuLayoutAssignment::SetDotLayout(
+    const HloInstruction* instruction, LayoutConstraints* constraints) {
+  // If a user has requested a layout that we can support, use that.
+  for (const HloInstruction* user : instruction->users()) {
+    for (int64_t i = 0; i < user->operand_count(); ++i) {
+      if (user->operand(i) != instruction) {
+        continue;
+      }
+
+      const ShapeLayout* constraint = constraints->OperandLayout(user, i);
+      if ((constraint != nullptr) &&
+          DotCanSupportShapeWithLayout(instruction, constraint->shape())) {
+        return SetInstructionLayout(constraint->shape(), instruction);
+      }
+    }
+  }
+
+  // Otherwise, use the default layout.
+  return SetInstructionLayout(
+      LayoutUtil::GetWithDefaultLayout(instruction->shape()), instruction);
+}
+
+bool GpuLayoutAssignment::PropagateReductionLayoutToOperand(
+    const HloInstruction* user) {
+  // We try to propagate a layout to make the reduction a row reduction. But
+  // propagating the layout is only beneficial if the reduction emitter would be
+  // used for the row reduction.
+  int64_t reduction_size = 1;
+  for (int64_t reduction_dim : user->dimensions()) {
+    reduction_size *= user->operand(0)->shape().dimensions(reduction_dim);
+  }
+  int64_t kept_dimension_size = ShapeUtil::ElementsIn(user->shape());
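+  // The reduction is modeled as a row reduction over a [1, kept, reduced]
+  // shape: all kept dimensions are collapsed into one dimension and all
+  // reduced dimensions into the minor-most one.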
+  return IsUnnestedReductionFasterThanElemental(
+      {/*is_row_reduction=*/true, {1, kept_dimension_size, reduction_size}});
+}
+
+bool GpuLayoutAssignment::InstructionCanChangeLayoutInstance(
+    const HloInstruction* instruction) {
+  // The host offloading custom calls will be eventually removed
+  // by the offloader, so we need to make sure that the calls do not change
+  // the layout and thus cause layout mismatches after the removal.
+  const HloCustomCallInstruction* custom_call =
+      DynCast<HloCustomCallInstruction>(instruction);
+  if (custom_call != nullptr &&
+      (custom_call->custom_call_target() ==
+           host_memory_offload_annotations::kMoveToHostCustomCallTarget ||
+       custom_call->custom_call_target() ==
+           host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) {
+    return false;
+  }
+
+  return LayoutAssignment::InstructionCanChangeLayoutInstance(instruction);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.h b/third_party/xla/xla/service/gpu/transforms/layout_assignment.h
new file mode 100644
index 0000000..efa58f3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.h
@@ -0,0 +1,81 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_LAYOUT_ASSIGNMENT_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_LAYOUT_ASSIGNMENT_H_
+
+#include <cstdint>
+#include <initializer_list>
+
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/computation_layout.h"
+#include "xla/service/layout_assignment.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+
+namespace xla {
+namespace gpu {
+
+// GPU-specific layout assignment pass which preassigns layouts to satisfy
+// layout constraints for operands and results of library calls.
+class GpuLayoutAssignment : public LayoutAssignment {
+ public:
+  explicit GpuLayoutAssignment(
+      ComputationLayout* entry_computation_layout,
+      const se::GpuComputeCapability& gpu_version,
+      const se::dnn::VersionInfo& dnn_version,
+      ChannelLayoutConstraints* channel_constraints = nullptr)
+      : LayoutAssignment(entry_computation_layout, channel_constraints),
+        gpu_version_(gpu_version),
+        dnn_version_(dnn_version) {}
+  ~GpuLayoutAssignment() override = default;
+
+ protected:
+  absl::Status AddBackendConstraints(LayoutConstraints* constraints) override;
+
+ private:
+  absl::Status AddBackendConstraintsToDnnConvCustomCall(
+      HloCustomCallInstruction* instr, LayoutConstraints* constraints);
+
+  // dim_groups are ordered from major to minor dimensions.
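+  // For example, passing {batch_dims, non_contracting_dims, contracting_dims}
+  // forces the contracting dimensions to be the most minor ones.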
+  absl::Status SetOperandMajorToMinorLayout(
+      const HloInstruction* instruction, int64_t operand,
+      std::initializer_list<absl::Span<const int64_t>> dim_groups);
+
+  absl::Status SetDotOperandLayout(const HloInstruction* instruction,
+                                   int64_t operand,
+                                   absl::Span<const int64_t> batch_dims,
+                                   absl::Span<const int64_t> row_dims,
+                                   absl::Span<const int64_t> col_dims);
+
+  absl::Status SetDotLayout(const HloInstruction* instruction,
+                            LayoutConstraints* constraints);
+
+  bool PropagateReductionLayoutToOperand(const HloInstruction* user) override;
+
+  bool InstructionCanChangeLayoutInstance(
+      const HloInstruction* instruction) override;
+
+  const se::GpuComputeCapability gpu_version_;
+  const se::dnn::VersionInfo dnn_version_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_LAYOUT_ASSIGNMENT_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc
new file mode 100644
index 0000000..dd1cbc6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc
@@ -0,0 +1,677 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/layout_assignment.h"
+
+#include <cstdint>
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
+#include "xla/service/computation_layout.h"
+#include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_layout.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/dnn.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+using ::tsl::testing::IsOkAndHolds;
+
+class LayoutAssignmentTest : public HloTestBase {
+ public:
+  se::CudaComputeCapability GetCudaComputeCapability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .cuda_compute_capability();
+  }
+
+  se::GpuComputeCapability GetGpuComputeCapability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability();
+  }
+
+  se::dnn::VersionInfo GetDnnVersion() {
+    // GpuLayoutAssignment has a special case heuristic for cudnn <= 7.3, but
+    // none of the tests trigger this heuristic.
+    return GetDnnVersionInfoOrDefault(backend().default_stream_executor(),
+                                      se::dnn::VersionInfo{8, 3, 0});
+  }
+};
+
+TEST_F(LayoutAssignmentTest, Elementwise) {
+  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
+  Shape ashape_in_row_major(ashape);
+  Shape ashape_in_col_major(ashape);
+  *ashape_in_row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+  *ashape_in_col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+
+  // Enumerate all possible combinations of layouts.
+  for (const Shape& lhs_shape_with_layout :
+       {ashape_in_row_major, ashape_in_col_major}) {
+    for (const Shape& rhs_shape_with_layout :
+         {ashape_in_row_major, ashape_in_col_major}) {
+      for (const Shape& result_shape_with_layout :
+           {ashape_in_row_major, ashape_in_col_major}) {
+        // GpuLayoutAssignment should assign the same layout to "add" and its
+        // two operands.
+        auto builder = HloComputation::Builder(TestName());
+        auto x = builder.AddInstruction(
+            HloInstruction::CreateParameter(0, ashape, "x"));
+        auto y = builder.AddInstruction(
+            HloInstruction::CreateParameter(1, ashape, "y"));
+        auto add = builder.AddInstruction(
+            HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y));
+        auto module = CreateNewVerifiedModule();
+        HloComputation* computation =
+            module->AddEntryComputation(builder.Build(add));
+
+        ComputationLayout computation_layout(
+            computation->ComputeProgramShape());
+        *computation_layout.mutable_parameter_layout(0) =
+            ShapeLayout(lhs_shape_with_layout);
+        *computation_layout.mutable_parameter_layout(1) =
+            ShapeLayout(rhs_shape_with_layout);
+        *computation_layout.mutable_result_layout() =
+            ShapeLayout(result_shape_with_layout);
+
+        GpuLayoutAssignment layout_assignment(
+            &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+        EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+
+        for (const HloInstruction* operand : add->operands()) {
+          EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(),
+                                        operand->shape().layout()));
+        }
+      }
+    }
+  }
+}
+
+TEST_F(LayoutAssignmentTest, DotLayoutUnchangedIfValid) {
+  const char* hlo_text = R"(
+  HloModule DotLayout
+  ENTRY dot {
+    p0 = f32[5,2,3]{1,2,0} parameter(0)
+    p1 = f32[5,3,4]{1,2,0} parameter(1)
+    ROOT dot.1330.10585 = f32[5,2,4]{2,1,0} dot(p0, p1),
+      lhs_batch_dims={0}, lhs_contracting_dims={2},
+      rhs_batch_dims={0}, rhs_contracting_dims={1}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 2, 3}, {1, 2, 0}),
+                                m::Op().WithShape(F32, {5, 3, 4}, {1, 2, 0}))
+                             .WithShape(F32, {5, 2, 4}, {2, 1, 0})));
+}
+
+TEST_F(LayoutAssignmentTest, DotLayoutSetToDefaultIfDefaultValid) {
+  const char* hlo_text = R"(
+  HloModule DotLayout
+  ENTRY dot {
+    p0 = f32[5,3,2] parameter(0)
+    p1 = f32[5,4,3]{0,1,2} parameter(1)
+    ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
+      lhs_batch_dims={0}, lhs_contracting_dims={1},
+      rhs_batch_dims={0}, rhs_contracting_dims={2}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 3, 2}, {2, 1, 0}),
+                                m::Op().WithShape(F32, {5, 4, 3}, {2, 1, 0}))
+                             .WithShape(F32, {5, 2, 4}, {2, 1, 0})));
+}
+
+TEST_F(LayoutAssignmentTest, DotOperandLayoutSetToBatchRowsColsOtherwise) {
+  const char* hlo_text = R"(
+  HloModule DotLayout
+  ENTRY dot {
+    p0 = f32[2,3,5]{2,1,0} parameter(0)
+    p1 = f32[3,4,5] parameter(1)
+    ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
+      lhs_batch_dims={2}, lhs_contracting_dims={1},
+      rhs_batch_dims={2}, rhs_contracting_dims={0}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(m::Op().WithShape(F32, {2, 3, 5}, {0, 1, 2}),
+                                m::Op().WithShape(F32, {3, 4, 5}, {1, 0, 2}))));
+}
+
+TEST_F(LayoutAssignmentTest, DotOperandInconsistentDimLayouts) {
+  const char* hlo_text = R"(
+  HloModule DotLayout
+  ENTRY dot {
+    p0 = f32[5,6,2,3] parameter(0)
+    p1 = f32[6,5,3,4] parameter(1)
+    ROOT dot.1330.10585 = f32[5,6,2,4] dot(p0, p1),
+      lhs_batch_dims={0,1}, lhs_contracting_dims={3},
+      rhs_batch_dims={1,0}, rhs_contracting_dims={2}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 6, 2, 3}, {3, 2, 1, 0}),
+                        m::Op().WithShape(F32, {6, 5, 3, 4}, {3, 2, 0, 1}))));
+}
+
+TEST_F(LayoutAssignmentTest, TransposedDotLayout) {
+  const char* hlo_text = R"(
+  HloModule DotLayout
+  ENTRY dot {
+    p0 = f32[5,2,3] parameter(0)
+    p1 = f32[5,3,4,6] parameter(1)
+    dot = f32[5,2,4,6] dot(p0, p1),
+      lhs_batch_dims={0}, lhs_contracting_dims={2},
+      rhs_batch_dims={0}, rhs_contracting_dims={1}
+    ROOT out = f32[2,5,4,6] transpose(dot), dimensions={1,0,2,3}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Transpose(
+                     m::Dot(m::Op().WithShape(F32, {5, 2, 3}, {2, 1, 0}),
+                            m::Op().WithShape(F32, {5, 3, 4, 6}, {3, 2, 1, 0}))
+                         .WithShape(F32, {5, 2, 4, 6}, {3, 2, 0, 1}))
+                     .WithShape(F32, {2, 5, 4, 6}, {3, 2, 1, 0})));
+}
+
+TEST_F(LayoutAssignmentTest, TransposedDotOfDotLayout) {
+  const char* hlo_text = R"(
+  HloModule DotLayout
+  ENTRY dot {
+    p0 = f32[8,50] parameter(0)
+    p1 = f32[2,8,4,4] parameter(1)
+    p2 = f32[4,38] parameter(2)
+    dot.1 = f32[50,2,4,4]{3,2,1,0} dot(p0, p1),
+      lhs_contracting_dims={0}, rhs_contracting_dims={1}
+    dot.2 = f32[50,2,4,38]{3,2,1,0} dot(dot.1, p2),
+      lhs_contracting_dims={2}, rhs_contracting_dims={0}
+    ROOT out = f32[2,50,38,4]{2,3,0,1} transpose(dot.2), dimensions={1,0,3,2}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+  // The layout required by the transpose is not supported by dot.2. Also, we
+  // need a copy between dot.1 and dot.2, because the operand layout that
+  // dot.2 requires for its lhs cannot be used as the output layout of dot.1.
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Transpose(
+              m::Dot(m::Copy(m::Dot(m::Op().WithShape(F32, {8, 50}, {1, 0}),
+                                    m::Op().WithShape(F32, {2, 8, 4, 4},
+                                                      {3, 2, 0, 1}))
+                                 .WithShape(F32, {50, 2, 4, 4}, {3, 2, 1, 0}))
+                         .WithShape(F32, {50, 2, 4, 4}, {3, 1, 0, 2}),
+                     m::Op().WithShape(F32, {4, 38}, {1, 0}))
+                  .WithShape(F32, {50, 2, 4, 38}, {3, 2, 1, 0}))
+              .WithShape(F32, {2, 50, 38, 4}, {2, 3, 0, 1})));
+}
+
+TEST_F(LayoutAssignmentTest, DotLayoutS8) {
+  const char* hlo_text = R"(
+  HloModule DotLayout
+  ENTRY int8_t {
+    p0 = s8[32,64] parameter(0)
+    p1 = s8[64,96] parameter(1)
+    ROOT out = s32[32,96] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(m::Op().WithShape(S8, {32, 64}, {1, 0}),
+                                m::Op().WithShape(S8, {64, 96}, {0, 1}))));
+}
+
+TEST_F(LayoutAssignmentTest, SortLayout) {
+  const char* hlo_text = R"(
+  HloModule SortLayout
+
+  compare {
+    p.0.lhs = f32[] parameter(0)
+    p.0.rhs = f32[] parameter(1)
+    p.1.lhs = f32[] parameter(2)
+    p.1.rhs = f32[] parameter(3)
+    ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
+  }
+
+  ENTRY sort {
+    keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}})
+    values = f32[2,3]{1,0} parameter(0)
+    transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0}
+    ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose),
+      dimensions={1}, to_apply=compare
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Sort(m::Op().WithShape(F32, {3, 2}, {1, 0}),
+                                 m::Op().WithShape(F32, {3, 2}, {1, 0}))));
+}
+
+TEST_F(LayoutAssignmentTest, FftLayout) {
+  const char* hlo_text = R"(
+  HloModule Fft_module
+
+  ENTRY Fft {
+    input = c64[8,32]{0,1} parameter(0)
+    fft = c64[8,32] fft(input), fft_type=FFT, fft_length={32}
+    ROOT transpose = c64[32,8] transpose(fft), dimensions={1,0}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Copy(
+                  m::Transpose(m::Fft(m::Op().WithShape(C64, {8, 32}, {1, 0}))
+                                   .WithShape(C64, {8, 32}, {1, 0})))));
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallConstrainedAlias) {
+  const char* module_str = R"(
+HloModule TestModule
+
+ENTRY entry {
+  Arg_0 = f32[2,5,5]{2,1,0} parameter(0)
+  Arg_1 = f32[2,5,5]{2,1,0} parameter(1)
+  Arg_2 = f32[2,5,5]{2,1,0} parameter(2)
+  dot.0 = f32[2,5,5]{2,1,0} dot(Arg_1, Arg_2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}
+  custom-call.0 = (f32[2,5,5]{1,2,0}, s8[16]{0}, s8[16]{0}) custom-call(Arg_0, dot.0), custom_call_target="dummy_call", operand_layout_constraints={f32[2,5,5]{1,2,0}, f32[2,5,5]{1,2,0}}, output_to_operand_aliasing={{0}: (1, {})}
+  ROOT get-tuple-element.0 = f32[2,5,5]{1,2,0} get-tuple-element(custom-call.0), index=0
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+                          ParseAndReturnVerifiedModule(module_str));
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+
+  const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0");
+  auto expect_layout = [](const Shape& shape,
+                          absl::Span<const int64_t> minor_to_major) {
+    const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
+    EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
+        << "Expected layout " << expected << ", actual " << shape.layout();
+  };
+  expect_layout(ShapeUtil::GetSubshape(call_0->shape(), {0}), {1, 2, 0});
+  expect_layout(call_0->operand(0)->shape(), {1, 2, 0});
+  expect_layout(call_0->operand(1)->shape(), {1, 2, 0});
+}
+
+TEST_F(LayoutAssignmentTest, MoveToHostCustomCallConstrained) {
+  const char* module_str = R"(
+HloModule TestModule
+
+ENTRY entry {
+  Arg_0 = f32[2,5,5]{2,1,0} parameter(0)
+  custom-call.0 = f32[2,5,5] custom-call(Arg_0), custom_call_target="MoveToHost"
+  ROOT custom-call.1 = f32[2,5,5]{2, 1, 0} custom-call(custom-call.0), custom_call_target="fixed_call", operand_layout_constraints={f32[2,5,5]{1,2,0}}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+                          ParseAndReturnVerifiedModule(module_str));
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+
+  const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0");
+  const Layout input_layout = call_0->operand(0)->shape().layout();
+  const Layout output_layout = call_0->shape().layout();
+  EXPECT_TRUE(LayoutUtil::Equal(input_layout, output_layout))
+      << "Expected the same input/output layouts.  Input: " << input_layout
+      << ". Output: " << output_layout;
+}
+
+TEST_F(LayoutAssignmentTest, MoveToDeviceCustomCallConstrained) {
+  const char* module_str = R"(
+HloModule TestModule
+
+ENTRY entry {
+  Arg_0 = f32[2,5,5]{2,1,0} parameter(0)
+  custom-call.0 = f32[2,5,5] custom-call(Arg_0), custom_call_target="MoveToDevice"
+  ROOT custom-call.1 = f32[2,5,5]{2, 1, 0} custom-call(custom-call.0), custom_call_target="fixed_call", operand_layout_constraints={f32[2,5,5]{1,2,0}}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+                          ParseAndReturnVerifiedModule(module_str));
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+
+  const HloInstruction* call_0 = FindInstruction(m.get(), "custom-call.0");
+  const Layout input_layout = call_0->operand(0)->shape().layout();
+  const Layout output_layout = call_0->shape().layout();
+  EXPECT_TRUE(LayoutUtil::Equal(input_layout, output_layout))
+      << "Expected the same input/output layouts.  Input: " << input_layout
+      << ". Output: " << output_layout;
+}
+
+TEST_F(LayoutAssignmentTest, ConvCuDNNF8) {
+  if (!GetCudaComputeCapability().IsAtLeast(
+          se::CudaComputeCapability::HOPPER)) {
+    GTEST_SKIP() << "FP8 convolutions require HOPPER or newer archiecture.";
+  }
+
+  const char* hlo = R"(
+
+  HloModule jit_conv_general_dilated
+
+  ENTRY main.4 {
+    Arg_0 = f8e4m3fn[1,64,64,16]{3,2,1,0} parameter(0)
+    Arg_1 = f8e4m3fn[3,3,16,32]{3,2,1,0} parameter(1)
+    ROOT conv = f8e4m3fn[1,64,64,32]{3,2,1,0} convolution(Arg_0, Arg_1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
+  }
+)";
+
+  MatchOptimizedHlo(hlo, R"(
+  // CHECK: [[P0:%[^ ]+]] = f8e4m3fn[1,64,64,16]{3,2,1,0} parameter(0)
+  // CHECK: [[P1:%[^ ]+]] = f8e4m3fn[3,3,16,32]{3,2,1,0} parameter(1)
+  // CHECK-NEXT: [[P2:%[^ ]+]] = f8e4m3fn[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2}
+  // CHECK-NEXT: [[CONV:%[^ ]+]] = (f8e4m3fn[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph"
+  )");
+}
+
+TEST_F(LayoutAssignmentTest, ConvCuDNNBF16) {
+  if (!GetCudaComputeCapability().IsAtLeast(
+          se::CudaComputeCapability::AMPERE)) {
+    GTEST_SKIP() << "Conv with Bfloat16 uses NHWC layout for "
+                    "architectures with Tensor Cores.";
+  }
+
+  const char* hlo = R"(
+
+  HloModule jit_conv_general_dilated
+
+  ENTRY main.4 {
+    Arg_0.1 = bf16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
+    Arg_1.2 = bf16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
+    ROOT convolution.3 = bf16[1,64,64,32]{3,2,1,0} convolution(Arg_0.1, Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 64, 64, 16) rhs_shape=(3, 3, 16, 32) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.8/dist-packages/flax/linen/linear.py" source_line=438}
+  }
+)";
+
+  MatchOptimizedHlo(hlo, R"(
+  // CHECK: [[P0:%[^ ]+]] = bf16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
+  // CHECK: [[P1:%[^ ]+]] = bf16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
+  // CHECK-NEXT: [[P2:%[^ ]+]] = bf16[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2}
+  // CHECK-NEXT: %cudnn-conv.1 = (bf16[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward"
+  )");
+}
+
+TEST_F(LayoutAssignmentTest, ConvCuDNNFP16) {
+  if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
+    GTEST_SKIP() << "Conv with FP16 uses NHWC layout for "
+                    "architectures with Tensor Cores.";
+  }
+
+  const char* hlo = R"(
+
+  HloModule jit_conv_general_dilated
+
+  ENTRY main.4 {
+    Arg_0.1 = f16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
+    Arg_1.2 = f16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
+    ROOT convolution.3 = f16[1,64,64,32]{3,2,1,0} convolution(Arg_0.1, Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
+  }
+)";
+
+  MatchOptimizedHlo(hlo, R"(
+  // CHECK: [[P0:%[^ ]+]] = f16[1,64,64,16]{3,2,1,0} parameter(0), sharding={replicated}
+  // CHECK: [[P1:%[^ ]+]] = f16[3,3,16,32]{3,2,1,0} parameter(1), sharding={replicated}
+  // CHECK-NEXT: [[P2:%[^ ]+]] = f16[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2}
+  // CHECK-NEXT: %cudnn-conv.1 = (f16[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward"
+  )");
+}
+
+TEST_F(LayoutAssignmentTest, ReduceOperandLayout) {
+  const char* module_str = R"(
+scalar_add_computation {
+  scalar_lhs = c64[] parameter(0)
+  scalar_rhs = c64[] parameter(1)
+  ROOT add.1 = c64[] add(scalar_lhs, scalar_rhs)
+}
+
+ENTRY main {
+  param_0 = c64[512,64,1024,32,128]{4,3,2,1,0} parameter(0)
+  negate = c64[512,64,1024,32,128]{4,3,2,1,0} negate(param_0)
+  constant_7 = c64[] constant((0, 0))
+  ROOT reduce.2 = c64[512,1024,128]{2,1,0} reduce(negate, constant_7), dimensions={1,3}, to_apply=scalar_add_computation
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+                          ParseAndReturnVerifiedModule(module_str));
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+  auto reduce = m->entry_computation()->root_instruction();
+  EXPECT_EQ(reduce->operand(0)->shape().layout().minor_to_major(),
+            LayoutUtil::MakeLayout({3, 1, 4, 2, 0}).minor_to_major());
+}
+
+TEST_F(LayoutAssignmentTest, ReduceOperandLayoutDivisorOfWarpSize) {
+  // Same as ReduceOperandLayout, but with a small reduction dimension that
+  // is a divisor of the warp size.
+  const char* module_str = R"(
+scalar_add_computation {
+  scalar_lhs = c64[] parameter(0)
+  scalar_rhs = c64[] parameter(1)
+  ROOT add.1 = c64[] add(scalar_lhs, scalar_rhs)
+}
+
+ENTRY main {
+  param_0 = c64[512,16,1024,128]{3,2,1,0} parameter(0)
+  negate = c64[512,16,1024,128]{3,2,1,0} negate(param_0)
+  constant_7 = c64[] constant((0, 0))
+  ROOT reduce.2 = c64[512,1024,128]{2,1,0} reduce(negate, constant_7), dimensions={1}, to_apply=scalar_add_computation
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+                          ParseAndReturnVerifiedModule(module_str));
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+  auto reduce = m->entry_computation()->root_instruction();
+  EXPECT_EQ(reduce->operand(0)->shape().layout().minor_to_major(),
+            LayoutUtil::MakeLayout({1, 3, 2, 0}).minor_to_major());
+}
+
+TEST_F(LayoutAssignmentTest, SendRcvLayout) {
+  const char* hlo = R"(
+HloModule Module
+
+condition  {
+    p = (f32[100,100], (f32[100,100], u32[], token[])) parameter(0)
+    ROOT lt = pred[] constant(1)
+}
+
+body {
+    p = (f32[100,100], (f32[100,100], u32[], token[])) parameter(0)
+
+    t1 = f32[100,100] get-tuple-element(p), index=0
+    t = (f32[100,100], u32[], token[]) get-tuple-element(p), index=1
+    sdone = token[] send-done(t), channel_id=3, frontend_attributes={
+      _xla_send_recv_pipeline="0"
+    }
+    tk = token[] after-all()
+
+
+    rcvd = (f32[100,100]{0,1}, u32[], token[]) recv(tk), channel_id=2
+    zz = (f32[100,100]{0,1}, token[]) recv-done(rcvd), channel_id=2
+
+    rcvd_d = get-tuple-element(zz), index=0
+
+    snd = (f32[100,100]{0,1}, u32[], token[]) send(t1, tk), channel_id=3, frontend_attributes={
+      _xla_send_recv_pipeline="0"
+    }
+    a = add(t1, t1)
+
+    b = add(rcvd_d, a)
+
+    ROOT tup =  tuple(b, snd)
+}
+
+ENTRY %main {
+    p0 = f32[100,100] parameter(0)
+    tk = token[] after-all()
+    snd = (f32[100,100]{0,1}, u32[], token[]) send(p0, tk), channel_id=1, frontend_attributes={
+      _xla_send_recv_pipeline="0"
+    }
+    t = tuple(p0, snd)
+    ROOT loop = while(t), condition=condition, body=body
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+                          ParseAndReturnVerifiedModule(hlo));
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  RunAndFilecheckHloRewrite(
+      hlo,
+      GpuLayoutAssignment{&computation_layout, GetGpuComputeCapability(),
+                          GetDnnVersion()},
+      R"(
+// CHECK: (f32[100,100]{1,0}, u32[], token[]) recv
+// CHECK:  (f32[100,100]{1,0}, token[]) recv-done
+// CHECK:  (f32[100,100]{1,0}, u32[], token[]) send
+                                )");
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.cc b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.cc
new file mode 100644
index 0000000..ae66093
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.cc
@@ -0,0 +1,240 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/move_copy_to_users.h"
+
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor {
+  // Turn copy->pad into pad->copy.
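+  // Illustrative sketch of the rewrite, with shapes borrowed from the Pad
+  // test in the accompanying move_copy_to_users_test.cc:
+  //   before: copy = s8[1,17,9,9]{1,3,2,0} copy(input)
+  //           pad  = s8[1,32,9,9]{1,3,2,0} pad(copy, constant)
+  //   after:  pad  = s8[1,32,9,9]{3,1,2,0} pad(input, constant)
+  //           copy = s8[1,32,9,9]{1,3,2,0} copy(pad)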
+  absl::Status HandlePad(HloInstruction* hlo) override {
+    HloInstruction* operand = hlo->mutable_operand(0);
+    HloInstruction* c = hlo->mutable_operand(1);
+    if (operand->opcode() == HloOpcode::kCopy) {
+      HloInstruction* copied = operand->mutable_operand(0);
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * earlier_pad,
+          MakePadHlo(copied, c, hlo->padding_config(), &hlo->metadata()));
+      // MakePadHlo fails to propagate layout.
+      *earlier_pad->mutable_shape()->mutable_layout() =
+          copied->shape().layout();
+      HloInstruction* later_copy = MakeCopyHlo(earlier_pad, hlo->shape());
+      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+    }
+    return absl::OkStatus();
+  }
+
+  // Turn copy->slice into slice->copy, as slice is layout-preserving.
+  absl::Status HandleSlice(HloInstruction* hlo) override {
+    HloInstruction* operand = hlo->mutable_operand(0);
+    if (operand->opcode() == HloOpcode::kCopy) {
+      HloInstruction* copied = operand->mutable_operand(0);
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * earlier_slice,
+          MakeSliceHlo(copied, hlo->slice_starts(), hlo->slice_limits(),
+                       hlo->slice_strides(), &hlo->metadata()));
+      *earlier_slice->mutable_shape()->mutable_layout() =
+          copied->shape().layout();
+      HloInstruction* later_copy = MakeCopyHlo(earlier_slice, hlo->shape());
+      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+    }
+    return absl::OkStatus();
+  }
+
+  // Turn copy->dynamic-slice into dynamic-slice->copy, as dynamic-slice is
+  // layout-preserving.
+  absl::Status HandleDynamicSlice(HloInstruction* hlo) override {
+    HloInstruction* operand = hlo->mutable_operand(0);
+    if (operand->opcode() == HloOpcode::kCopy) {
+      HloInstruction* copied = operand->mutable_operand(0);
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * earlier_slice,
+          MakeDynamicSliceHlo(
+              copied,
+              absl::Span<HloInstruction* const>(hlo->operands()).subspan(1),
+              hlo->dynamic_slice_sizes(), &hlo->metadata()));
+      *earlier_slice->mutable_shape()->mutable_layout() =
+          copied->shape().layout();
+      HloInstruction* later_copy = MakeCopyHlo(earlier_slice, hlo->shape());
+      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+    }
+    return absl::OkStatus();
+  }
+
+  // Turn copy->reduce_window into reduce_window->copy, as reduce_window is
+  // layout-preserving.
+  absl::Status HandleReduceWindow(HloInstruction* hlo) override {
+    HloInstruction* operand = hlo->mutable_operand(0);
+    if (operand->opcode() == HloOpcode::kCopy) {
+      HloInstruction* copied = operand->mutable_operand(0);
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * earlier_reduce_window,
+          MakeReduceWindowHlo(copied, hlo->mutable_operand(1), hlo->window(),
+                              hlo->called_computations()[0], &hlo->metadata()));
+      *earlier_reduce_window->mutable_shape()->mutable_layout() =
+          copied->shape().layout();
+      HloInstruction* later_copy =
+          MakeCopyHlo(earlier_reduce_window, hlo->shape());
+      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+    }
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleReduce(HloInstruction* hlo) override {
+    HloInstruction* operand = hlo->mutable_operand(0);
+    // Reductions can handle transposes, e.g. via column reduction.
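+    // (Illustrative: in the Reduce test for this pass, the layout-changing
+    // copy feeding the reduce is bypassed entirely and the new reduce reads
+    // the original operand directly.)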
+    if (operand->opcode() == HloOpcode::kCopy && !hlo->shape().IsTuple()) {
+      HloInstruction* new_reduce = hlo->AddInstruction(
+          hlo->CloneWithNewOperands(hlo->shape(), {operand->mutable_operand(0),
+                                                   hlo->mutable_operand(1)}));
+      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, new_reduce));
+    }
+    return absl::OkStatus();
+  }
+
+  absl::Status HandleBitcastConvert(HloInstruction* hlo) override {
+    return absl::OkStatus();
+  }
+
+  // Sink kCopy across elementwise unary.
+  absl::Status HandleElementwiseUnary(HloInstruction* hlo) override {
+    HloInstruction* operand = hlo->mutable_operand(0);
+    if (hlo->opcode() == HloOpcode::kReducePrecision) {
+      return absl::OkStatus();
+    }
+    if (operand->opcode() == HloOpcode::kCopy) {
+      HloInstruction* copied = operand->mutable_operand(0);
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * earlier_elementwise,
+          MakeUnaryHlo(hlo->opcode(), copied, &hlo->metadata()));
+      HloInstruction* later_copy =
+          MakeCopyHlo(earlier_elementwise, hlo->shape());
+      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+    }
+    return absl::OkStatus();
+  }
+
+  // Sink kCopy across reverse.
+  absl::Status HandleReverse(HloInstruction* hlo) override {
+    HloInstruction* operand = hlo->mutable_operand(0);
+    if (operand->opcode() == HloOpcode::kCopy) {
+      HloInstruction* copied = operand->mutable_operand(0);
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * earlier_reverse,
+          MakeReverseHlo(copied, hlo->dimensions(), &hlo->metadata()));
+      HloInstruction* later_copy = MakeCopyHlo(earlier_reverse, hlo->shape());
+      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+    }
+    return absl::OkStatus();
+  }
+
+  // Sink kCopy across convert.
+  absl::Status HandleConvert(HloInstruction* hlo) override {
+    HloInstruction* operand = hlo->mutable_operand(0);
+    if (operand->opcode() == HloOpcode::kCopy) {
+      HloInstruction* copied = operand->mutable_operand(0);
+      HloInstruction* earlier_convert = MakeConvertToHlo(
+          copied, hlo->shape().element_type(), &hlo->metadata());
+      HloInstruction* later_copy = MakeCopyHlo(earlier_convert, hlo->shape());
+      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+    }
+    return absl::OkStatus();
+  }
+
+  // Sink kCopy across elementwise binary.
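+  // (Both operands must be copies whose sources have identical shapes,
+  // including layout; otherwise the op is left unchanged, as exercised by the
+  // BinaryDifferentLayoutNoChange test.)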
+  absl::Status HandleElementwiseBinary(HloInstruction* hlo) override {
+    HloInstruction* a = hlo->mutable_operand(0);
+    HloInstruction* b = hlo->mutable_operand(1);
+    if (a->opcode() == HloOpcode::kCopy && b->opcode() == HloOpcode::kCopy) {
+      HloInstruction* copied_a = a->mutable_operand(0);
+      HloInstruction* copied_b = b->mutable_operand(0);
+      if (copied_a->shape() == copied_b->shape()) {
+        HloInstruction* earlier_elementwise;
+        if (hlo->opcode() == HloOpcode::kCompare) {
+          TF_ASSIGN_OR_RETURN(
+              earlier_elementwise,
+              MakeCompareHlo(hlo->comparison_direction(), copied_a, copied_b,
+                             &hlo->metadata()));
+        } else {
+          TF_ASSIGN_OR_RETURN(earlier_elementwise,
+                              MakeBinaryHlo(hlo->opcode(), copied_a, copied_b,
+                                            &hlo->metadata()));
+        }
+        HloInstruction* later_copy =
+            MakeCopyHlo(earlier_elementwise, hlo->shape());
+        TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy));
+      }
+    }
+    return absl::OkStatus();
+  }
+
+  // Move copy across kConcatenate if every operand is a copy with the same
+  // source layout.
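+  // Illustrative sketch, with shapes borrowed from the Concat test in the
+  // accompanying move_copy_to_users_test.cc:
+  //   before: concatenate(copy(f32[1,17,9,9]{3,2,1,0}),
+  //                       copy(f32[5,17,9,9]{3,2,1,0})) with result layout
+  //           {1,3,2,0}
+  //   after:  a single concatenate in layout {3,2,1,0} followed by one
+  //           layout-changing copy of the concatenated result.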
+  absl::Status HandleConcatenate(HloInstruction* hlo) override {
+    const HloInstruction* first = hlo->operand(0);
+    if (first->opcode() != HloOpcode::kCopy) {
+      return absl::OkStatus();
+    }
+    const HloInstruction* inner_op = first->operand(0);
+    const Layout& inner_op_layout = inner_op->shape().layout();
+
+    std::vector<HloInstruction*> new_operands;
+    new_operands.reserve(hlo->operand_count());
+    for (HloInstruction* op : hlo->mutable_operands()) {
+      if (op->opcode() != HloOpcode::kCopy ||
+          op->operand(0)->shape().layout() != inner_op_layout) {
+        VLOG(3) << "Mismatch between " << op->ToString()
+                << " and expected op layout " << inner_op_layout.ToString();
+        return absl::OkStatus();
+      }
+      new_operands.push_back(op->mutable_operand(0));
+    }
+
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * new_concat,
+        MakeConcatHlo(new_operands, hlo->concatenate_dimension()));
+    *new_concat->mutable_shape()->mutable_layout() = inner_op_layout;
+
+    HloInstruction* new_copy = MakeCopyHlo(new_concat, hlo->shape());
+    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, new_copy));
+    return absl::OkStatus();
+  }
+};
+
+}  // end namespace
+
+absl::StatusOr<bool> MoveCopyToUsers::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  return MoveCopyToUsersVisitor{}.RunOnModule(module, execution_threads);
+}
+
+}  // end namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.h b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.h
new file mode 100644
index 0000000..698db04
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.h
@@ -0,0 +1,39 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_MOVE_COPY_TO_USERS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_MOVE_COPY_TO_USERS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// Sink kCopy operations as far down the graph as possible.
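+// For example (borrowing from the accompanying tests), sqrt(copy(x)) is
+// rewritten to copy(sqrt(x)), so the layout-changing copy ends up next to its
+// users rather than its producer.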
+class MoveCopyToUsers : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "move_copy_to_users"; }
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // end namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_MOVE_COPY_TO_USERS_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/move_copy_to_users_test.cc b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users_test.cc
new file mode 100644
index 0000000..85999db
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users_test.cc
@@ -0,0 +1,274 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/move_copy_to_users.h"
+
+#include <optional>
+
+#include "absl/strings/string_view.h"
+#include "xla/service/layout_assignment.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace {
+
+class MoveCopyToUsersTest : public HloTestBase {
+ public:
+  MoveCopyToUsersTest()
+      : HloTestBase(/*verifier_layout_sensitive=*/true,
+                    /*allow_mixed_precision_in_hlo_verifier=*/true,
+                    LayoutAssignment::InstructionCanChangeLayout) {}
+  void CheckMoveCopyToUsers(absl::string_view hlo,
+                            std::optional<absl::string_view> expected) {
+    RunAndFilecheckHloRewrite(hlo, MoveCopyToUsers{}, expected);
+  }
+};
+
+TEST_F(MoveCopyToUsersTest, Pad) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+  input = s8[1,17,9,9]{3,1,2,0} parameter(0)
+  copy = s8[1,17,9,9]{1,3,2,0} copy(input)
+  constant = s8[] constant(0)
+  ROOT pad = s8[1,32,9,9]{1,3,2,0} pad(copy, constant), padding=0_0x0_15x0_0x0_0
+}
+)";
+
+  CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[constant_0:%[^ ]+]] = s8[] constant(0)
+// CHECK: [[pad_1_1:%[^ ]+]] = s8[1,32,9,9]{3,1,2,0} pad([[input_2:%[^ ]+]], [[constant_0]]), padding=0_0x0_15x0_0x0_0
+// CHECK: ROOT [[copy_1_3:%[^ ]+]] = s8[1,32,9,9]{1,3,2,0} copy([[pad_1_1]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, Unary) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+  ROOT pad = f32[1,17,9,9]{1,3,2,0} sqrt(copy)
+}
+)";
+
+  CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
+// CHECK: [[sqrt_1:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} sqrt([[input_0]])
+// CHECK: ROOT [[copy_1_2:%[^ ]+]] = f32[1,17,9,9]{1,3,2,0} copy([[sqrt_1]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, Reverse) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+  ROOT pad = f32[1,17,9,9]{1,3,2,0} reverse(copy), dimensions={1,2}
+}
+)";
+
+  CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
+// CHECK: [[sqrt_1:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} reverse([[input_0]]), dimensions={1,2}
+// CHECK: ROOT [[copy_1_2:%[^ ]+]] = f32[1,17,9,9]{1,3,2,0} copy([[sqrt_1]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, Convert) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+  ROOT converted = f16[1,17,9,9]{1,3,2,0} convert(copy)
+}
+)";
+
+  CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
+// CHECK: [[sqrt_1:%[^ ]+]] = f16[1,17,9,9]{3,2,1,0} convert([[input_0]])
+// CHECK: ROOT [[copy_1_2:%[^ ]+]] = f16[1,17,9,9]{1,3,2,0} copy([[sqrt_1]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, Slice) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+  ROOT slice = f32[1,4,6,6]{1,3,2,0} slice(copy), slice={[0:1],[0:4],[0:6],[0:6]}
+}
+)";
+
+  CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[slice_0:%[^ ]+]] = f32[1,4,6,6]{3,2,1,0} slice([[input_1:%[^ ]+]]), slice={[0:1], [0:4], [0:6], [0:6]}
+// CHECK-NEXT: ROOT [[copy_1_2:%[^ ]+]] = f32[1,4,6,6]{1,3,2,0} copy([[slice_0]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, DynamicSlice) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+  p0 = s32[] parameter(1)
+  p1 = s32[] parameter(2)
+  p2 = s32[] parameter(3)
+  p3 = s32[] parameter(4)
+  ROOT ds = f32[1,4,6,6]{1,3,2,0} dynamic-slice(copy, p0, p1, p2, p3), dynamic_slice_sizes={1,4,6,6}
+}
+)";
+
+  CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[ds:%[^ ]+]] = f32[1,4,6,6]{3,2,1,0} dynamic-slice({{.*}}), dynamic_slice_sizes={1,4,6,6}
+// CHECK-NEXT: ROOT {{.*}} = f32[1,4,6,6]{1,3,2,0} copy([[ds]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, ReduceWindow) {
+  const char* hlo = R"(
+HloModule R2Window
+
+mul {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT mul = f32[] multiply(lhs, rhs)
+}
+
+ENTRY R2Window {
+  operand = f32[256,384]{1,0} parameter(0)
+  c = f32[256,384]{0,1} copy(operand)
+  constant = f32[] constant(1)
+  ROOT reduce-window = f32[256,384]{0,1} reduce-window(c, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
+}
+)";
+
+  CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[reduce_window_1_0:%[^ ]+]] = f32[256,384]{1,0} reduce-window([[operand_1:%[^ ]+]], [[constant_2:%[^ ]+]]), window={size=2x3 pad=0_1x1_1}, to_apply=[[mul_3:%[^ ]+]]
+// CHECK-NEXT: ROOT [[copy_4:%[^ ]+]] = f32[256,384]{0,1} copy([[reduce_window_1_0]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, Reduce) {
+  const char* hlo = R"(
+HloModule R2
+
+mul {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT mul = f32[] multiply(lhs, rhs)
+}
+
+ENTRY R2 {
+  operand = f32[256,384,10]{2,1,0} parameter(0)
+  c = f32[256,384,10]{0,1,2} copy(operand)
+  constant = f32[] constant(1)
+  ROOT reduce = f32[384,10]{0,1} reduce(c, constant), dimensions={0}, to_apply=mul
+}
+)";
+
+  CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[operand:%[^ ]+]] = f32[256,384,10]{2,1,0} parameter(0)
+// CHECK: ROOT [[reduce:%[^ ]+]] = f32[384,10]{0,1} reduce([[operand]], [[constant_2:%[^ ]+]]), dimensions={0}, to_apply=[[mul_3:%[^ ]+]]
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, Binary) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+  input2 = f32[1,17,9,9]{3,2,1,0} parameter(1)
+  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+  copy2 = f32[1,17,9,9]{1,3,2,0} copy(input2)
+  ROOT add = f32[1,17,9,9]{1,3,2,0} add(copy, copy2)
+}
+)";
+
+  CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
+// CHECK: [[input2_1:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(1)
+// CHECK: [[add_1_2:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} add([[input_0]], [[input2_1]])
+// CHECK: ROOT [[copy_1_3:%[^ ]+]] = f32[1,17,9,9]{1,3,2,0} copy([[add_1_2]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, BinaryDifferentLayoutNoChange) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+  input = f32[1,17,9,9]{3,2,0,1} parameter(0)
+  input2 = f32[1,17,9,9]{3,2,1,0} parameter(1)
+  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+  copy2 = f32[1,17,9,9]{1,3,2,0} copy(input2)
+  ROOT add = f32[1,17,9,9]{1,3,2,0} add(copy, copy2)
+}
+)";
+
+  CheckMoveCopyToUsers(hlo, std::nullopt);
+}
+
+TEST_F(MoveCopyToUsersTest, Concat) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+  input = f32[1,17,9,9]{3,2,1,0} parameter(0)
+  input2 = f32[5,17,9,9]{3,2,1,0} parameter(1)
+  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+  copy2 = f32[5,17,9,9]{1,3,2,0} copy(input2)
+  ROOT add = f32[6,17,9,9]{1,3,2,0} concatenate(copy, copy2), dimensions={0}
+}
+)";
+
+  CheckMoveCopyToUsers(hlo, R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[1,17,9,9]{3,2,1,0} parameter(0)
+// CHECK: [[input2_1:%[^ ]+]] = f32[5,17,9,9]{3,2,1,0} parameter(1)
+// CHECK: [[concat:%[^ ]+]] = f32[6,17,9,9]{3,2,1,0} concatenate([[input_0]], [[input2_1]])
+// CHECK: ROOT [[copy_1_3:%[^ ]+]] = f32[6,17,9,9]{1,3,2,0} copy([[concat]])
+)");
+}
+
+TEST_F(MoveCopyToUsersTest, ConcatDifferentLayoutNoChange) {
+  const char* hlo = R"(
+HloModule module
+
+ENTRY main {
+  input = f32[1,17,9,9]{3,2,0,1} parameter(0)
+  input2 = f32[1,17,9,9]{3,2,1,0} parameter(1)
+  copy = f32[1,17,9,9]{1,3,2,0} copy(input)
+  copy2 = f32[1,17,9,9]{1,3,2,0} copy(input2)
+  ROOT add = f32[2,17,9,9]{1,3,2,0} concatenate(copy, copy2), dimensions={0}
+}
+)";
+
+  CheckMoveCopyToUsers(hlo, std::nullopt);
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc
new file mode 100644
index 0000000..35bfe8e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc
@@ -0,0 +1,521 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/multi_output_fusion.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <memory>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_dfs_reachability.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/model/gpu_performance_model.h"
+#include "xla/service/gpu/model/gpu_performance_model_base.h"
+#include "xla/service/hlo_graph_dumper.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+bool IsProfitableOperand(HloInstruction* instr) {
+  // Effective scalars are not a profitable shared operand. Skip them.
+  return !ShapeUtil::IsEffectiveScalar(instr->shape());
+}
+
+// Finds and returns the unique `slice` op through which `parent` is used by
+// `instr`, looking through fusions. Returns `nullptr` if no such unique
+// `slice` exists.
+const HloSliceInstruction* FindUniqueSlice(const HloInstruction* parent,
+                                           const HloInstruction* instr) {
+  if (const auto* slice = DynCast<HloSliceInstruction>(instr)) {
+    return slice;
+  } else if (const auto* fusion = DynCast<HloFusionInstruction>(instr)) {
+    const HloSliceInstruction* result = nullptr;
+    for (size_t i = 0; i < fusion->operand_count(); ++i) {
+      if (fusion->operand(i) == parent) {
+        // Parameter used more than once -> there's no unique slice.
+        if (result) return nullptr;
+
+        auto* called_param = fusion->fused_parameter(i);
+        if (called_param->user_count() != 1) return nullptr;
+
+        result = FindUniqueSlice(called_param, called_param->users()[0]);
+        if (!result) return nullptr;
+      }
+    }
+    return result;
+  } else {
+    return nullptr;
+  }
+}
+
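+// Returns an empty FusionDecision (i.e. no objection) unless `instr1` and
+// `instr2` each read `parent` through a single slice and those slices do not
+// overlap; in that case sibling multi-output fusion is rejected. Parents
+// smaller than the 1024-byte threshold below are always allowed.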
+FusionDecision ParameterSlicesAreNonOverlapping(const HloInstruction& instr1,
+                                                const HloInstruction& instr2,
+                                                const HloInstruction* parent) {
+  if (parent->shape().IsTuple()) return {};
+  // Allow MOF if the parameter is small, even if there's no overlap. 1024 bytes
+  // were arbitrarily chosen as the threshold.
+  if (ShapeUtil::ByteSizeOfElements(parent->shape()) < 1024) return {};
+
+  const HloSliceInstruction* slice1 = FindUniqueSlice(parent, &instr1);
+  const HloSliceInstruction* slice2 = FindUniqueSlice(parent, &instr2);
+  if (!slice1 || !slice2) return {};
+
+  // TODO(jreiffers): Check strides as well.
+  auto& starts1 = slice1->slice_starts();
+  auto& starts2 = slice2->slice_starts();
+  auto& limits1 = slice1->slice_limits();
+  auto& limits2 = slice2->slice_limits();
+
+  for (int64_t dim = 0; dim < parent->shape().rank(); ++dim) {
+    bool overlap = starts1[dim] < limits2[dim] && starts2[dim] < limits1[dim];
+    if (!overlap) {
+      return "slices are non-overlapping";
+    }
+  }
+  return {};
+}
+
+FusionDecision LegalToFuse(const HloInstruction& instr1,
+                           const HloInstruction& instr2,
+                           const se::DeviceDescription& device_info,
+                           FusionInfoCache* fusion_info_cache) {
+  CHECK(instr1.opcode() == HloOpcode::kFusion);
+
+  // The emitter only supports in-place DUS for fusions with a single DUS at the
+  // root. Don't sibling fuse DUS for now.
+  // TODO(b/119178699): Multi-output fusing DUS can improve performance if we
+  // share the input and output buffers and add support to the emitter.
+  if (instr1.fused_expression_root()->opcode() ==
+          HloOpcode::kDynamicUpdateSlice ||
+      (instr2.opcode() == HloOpcode::kFusion &&
+       instr2.fused_expression_root()->opcode() ==
+           HloOpcode::kDynamicUpdateSlice)) {
+    return "can't fuse multiple DUSs";
+  }
+
+  // Do this check last, as it may be expensive.
+  return FusionFitsInBudget(instr1, instr2, device_info,
+                            /*is_consumer_producer_fusion=*/false,
+                            fusion_info_cache);
+}
+
+// We prefer multi-output fusions over other fusions over unfused ops, because
+// we want to preserve fusion opportunities if possible.
+int FusionPriority(const HloInstruction* instr) {
+  if (instr->IsMultiOutputFusion()) {
+    return 2;
+  }
+  if (instr->opcode() == HloOpcode::kFusion) {
+    return 1;
+  }
+  return 0;
+}
+
+HloInstruction* SelectPreferredFusionCandidate(
+    const std::vector<HloInstruction*> candidates) {
+  if (candidates.empty()) {
+    return nullptr;
+  }
+  return *std::max_element(
+      candidates.begin(), candidates.end(),
+      [](const HloInstruction* a, const HloInstruction* b) {
+        return FusionPriority(a) < FusionPriority(b);
+      });
+}
+
+// Do not fuse a producer if the other operands of the fusion are reachable
+// from the producer, as this would create a cycle.
+FusionDecision OperandReachableFromProducer(
+    const HloInstruction& producer, const HloInstruction& consumer,
+    const HloDfsReachability& reachability) {
+  for (const auto* operand : consumer.operands()) {
+    // If a get-tuple-element instruction is not in the reachability
+    // map, it has been created by fusion in this pass. Simply move
+    // on to its operand, which is in the reachability map.
+    if (!reachability.IsPresent(operand) &&
+        operand->opcode() == HloOpcode::kGetTupleElement) {
+      operand = operand->operand(0);
+    }
+    CHECK(reachability.IsPresent(operand) && reachability.IsPresent(&producer))
+        << "Reachability map is incomplete. This should never "
+           "happen.";
+    if (&producer != operand && reachability.IsReachable(&producer, operand)) {
+      return {
+          absl::StrCat(producer.name(), " would introduce a cycle when fused")};
+    }
+  }
+  return {};
+}
+
+FusionDecision ProducerCandidateIsFusible(
+    const HloInstruction& producer, const HloInstruction& consumer,
+    const HloDfsReachability& reachability, FusionInfoCache* fusion_info_cache,
+    const se::DeviceDescription& device_info,
+    GpuHloCostAnalysis* cost_analysis) {
+  if (!IsFusibleAsMultiOutputFusionRoot(consumer)) {
+    return "consumer not eligible as multi-output fusion root.";
+  }
+
+  RETURN_IF_NOT_FUSIBLE(
+      ShapesCompatibleForMultiOutputFusion(consumer, producer));
+
+  RETURN_IF_NOT_FUSIBLE(
+      OperandReachableFromProducer(producer, consumer, reachability));
+
+  RETURN_IF_NOT_FUSIBLE(FusionFitsInBudget(
+      producer, consumer, device_info,
+      /*is_consumer_producer_fusion=*/false, fusion_info_cache));
+
+  if (cost_analysis->ProducerConsumerMergedTooLarge(producer, consumer)) {
+    return "will generate too large IR";
+  }
+
+  GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
+      &producer, device_info, cost_analysis,
+      GpuPerformanceModelOptions::Default(),
+      /*fused_consumers=*/{&consumer},
+      /*multi_output=*/true);
+  if (t.time_fused > t.time_unfused) {
+    return "will execute slower if fused";
+  }
+
+  return {};
+}
+
+std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
+    const HloInstruction* producer, const HloDfsReachability& reachability,
+    FusionInfoCache* fusion_info_cache,
+    const se::DeviceDescription& device_info,
+    GpuHloCostAnalysis* cost_analysis) {
+  std::vector<HloInstruction*> fusion_candidates;
+  const HloComputation* computation = producer->parent();
+  const HloModule* module = computation->parent();
+  bool dump_fusion =
+      module->config().debug_options().xla_dump_fusion_visualization();
+
+  // If the producer is not a valid candidate for MOF, no need to check any of
+  // its users.
+  if (!IsProducerMultiOutputFusible(*producer)) {
+    return fusion_candidates;
+  }
+
+  // If there is only one user, and it is not a multi-output fusion node, this
+  // fusion possibility was already considered and rejected by the FusionMerger
+  // pass. No need to try again!
+  if (producer->user_count() == 1 &&
+      !producer->users()[0]->IsMultiOutputFusion()) {
+    return fusion_candidates;
+  }
+
+  for (HloInstruction* consumer : producer->users()) {
+    VLOG(3) << "Looking at producer " << producer->name()
+            << " and its consumer " << consumer->name();
+
+    if (auto decision = ProducerCandidateIsFusible(
+            *producer, *consumer, reachability, fusion_info_cache, device_info,
+            cost_analysis)) {
+      fusion_candidates.push_back(consumer);
+    } else if (dump_fusion) {
+      RegisterFusionState(
+          *computation,
+          absl::StrCat("Not considering fusion of producer |", producer->name(),
+                       "| into consumer |", consumer->name(),
+                       "| due to: ", decision.Explain()),
+          *consumer, producer);
+    }
+  }
+  return fusion_candidates;
+}
+
+bool IsSiblingFusionCandidate(const HloInstruction* instr) {
+  if (instr->users().empty() || !IsFusibleAsMultiOutputFusionRoot(*instr) ||
+      IsNestableVariadicReduction(*instr)) {
+    return false;
+  }
+  // Check whether any user of the multi-output fusion is not a
+  // get-tuple-element. If so, bail out, because the transformation assumes
+  // that all such users are get-tuple-elements.
+  return (!instr->IsMultiOutputFusion() ||
+          absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
+            return user->opcode() == HloOpcode::kGetTupleElement;
+          }));
+}
+
+FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1,
+                               const HloInstruction& sibling_consumer_2,
+                               const HloInstruction& common_producer,
+                               const HloDfsReachability& reachability,
+                               FusionInfoCache* fusion_info_cache,
+                               const se::DeviceDescription& device_info) {
+  if (reachability.IsConnected(&sibling_consumer_1, &sibling_consumer_2)) {
+    return {absl::StrCat(sibling_consumer_1.name(), " and ",
+                         sibling_consumer_2.name(), " are connected")};
+  }
+
+  RETURN_IF_NOT_FUSIBLE(ShapesCompatibleForMultiOutputFusion(
+      sibling_consumer_1, sibling_consumer_2));
+
+  // Technically, this check is order-dependent (e.g. siblings A, B, C where
+  // {A, B} and {B, C} overlap, but {A, C} do not). If the priority order is
+  // [C, A, B], only {C, B} will be fused, and A will only be fused in the
+  // next iteration of the fusion pipeline, potentially requiring several
+  // iterations to converge. We assume this case to be very rare in
+  // practice.
+  RETURN_IF_NOT_FUSIBLE(ParameterSlicesAreNonOverlapping(
+      sibling_consumer_1, sibling_consumer_2, &common_producer));
+
+  // This check should be last, as it may be expensive.
+  RETURN_IF_NOT_FUSIBLE(LegalToFuse(sibling_consumer_1, sibling_consumer_2,
+                                    device_info, fusion_info_cache));
+  return {};
+}
+
+}  // namespace
+
+void MultiOutputFusion::RecomputeReachability() {
+  reachability_ = HloDfsReachability::Build(computation_);
+}
+
+bool MultiOutputFusion::FuseSiblings(HloInstruction* parent,
+                                     FusionInfoCache* fusion_info_cache,
+                                     GpuHloCostAnalysis* cost_analysis) {
+  const HloComputation* computation = parent->parent();
+  const HloModule* module = computation->parent();
+  bool dump_fusion =
+      module->config().debug_options().xla_dump_fusion_visualization();
+
+  if (!IsProfitableOperand(parent)) {
+    VLOG(3) << "Operand " << parent->ToShortString() << " is not profitable";
+    return false;
+  }
+  bool changed = false;
+  std::vector<HloInstruction*> siblings;
+  // Only consider siblings that are fusion candidates.
+  absl::c_copy_if(parent->users(), std::back_inserter(siblings),
+                  IsSiblingFusionCandidate);
+  // Sort the siblings such that multi-output fusion ops occur first, followed
+  // by fusion ops, followed by unfused ops.
+  absl::c_stable_sort(siblings,
+                      [](const HloInstruction* a, const HloInstruction* b) {
+                        return FusionPriority(a) > FusionPriority(b);
+                      });
+
+  for (auto i = siblings.begin(); i != siblings.end(); ++i) {
+    VLOG(3) << "Considering " << (*i)->name();
+    if ((*i)->opcode() != HloOpcode::kFusion) {
+      continue;
+    }
+    for (auto j = i + 1; j != siblings.end();) {
+      VLOG(3) << "Considering " << (*i)->name() << " and " << (*j)->name();
+
+      if (auto fusible = CanFuseSiblings(**i, **j, *parent, *reachability_,
+                                         fusion_info_cache, device_info_);
+          !fusible) {
+        // We pick `i` arbitrarily as a consumer for the dump below.
+        if (dump_fusion) {
+          RegisterFusionState(
+              *computation,
+              absl::StrCat("Not fusing siblings |", (**i).name(), "| and |",
+                           (**j).name(), "| due to: ", fusible.Explain()),
+              // Randomly pick one consumer.
+              /*consumer=*/**i,
+              /*producer=*/parent);
+        }
+        ++j;
+        continue;
+      }
+      if (!ConsumeFuel(name(), [&] {
+            return absl::StrFormat("Not fusing siblings %s and %s.",
+                                   (*i)->name(), (*j)->name());
+          })) {
+        ++j;
+        continue;
+      }
+      VLOG(2) << "Fuse siblings " << (*i)->name() << " and " << (*j)->name();
+      fusion_info_cache->Invalidate(*i);
+      fusion_info_cache->Invalidate(*j);
+      HloInstruction* remaining = *i;
+      HloInstruction* fused = *j;
+      TF_CHECK_OK(cost_analysis->RemoveInstruction(remaining));
+      TF_CHECK_OK(cost_analysis->RemoveInstruction(fused));
+
+      DumpFusionState(*remaining,
+                      absl::StrCat("About to fuse sibling |", fused->name(),
+                                   "| into sibling |", remaining->name(),
+                                   "| inside multi-output fusion"),
+                      /*producer=*/fused);
+
+      if (fused->opcode() == HloOpcode::kFusion) {
+        remaining->MergeFusionInstructionIntoMultiOutput(fused);
+        if (fused->IsInputFusion()) {
+          remaining->set_fusion_kind(HloInstruction::FusionKind::kInput);
+        }
+      } else {
+        remaining->FuseInstructionIntoMultiOutput(fused);
+        CHECK_EQ(0, fused->user_count());
+        TF_CHECK_OK(computation_->RemoveInstruction(fused));
+      }
+      DumpFusionState(*remaining,
+                      absl::StrCat("Fused into |", remaining->name(),
+                                   "| inside multi-output fusion"));
+      TF_CHECK_OK(cost_analysis->RevisitInstruction(remaining));
+      changed = true;
+      siblings.erase(j);
+      RecomputeReachability();
+    }
+  }
+  return changed;
+}
+
+absl::StatusOr<bool> MultiOutputFusion::DoMultiOutputFusion() {
+  bool changed = false;
+  RecomputeReachability();
+  GpuHloCostAnalysis cost_analysis({shape_size_function_,
+                                    /*per_second_rates=*/{},
+                                    /*count_multiple_input_accesses=*/true},
+                                   device_info_);
+  TF_RETURN_IF_ERROR(computation_->Accept(&cost_analysis));
+  std::vector<HloInstruction*> defs_before_uses =
+      computation_->MakeInstructionPostOrder();
+
+  FusionInfoCache fusion_info_cache;
+  // Traverse the HLO in uses-before-defs order.
+  for (auto it = defs_before_uses.rbegin(); it != defs_before_uses.rend();
+       ++it) {
+    auto* producer = *it;
+    // Never multi-output fuse constants.  To the extent that we want to fuse
+    // constants, that should be handled by the regular fusion pass.
+    if (producer->opcode() == HloOpcode::kConstant) {
+      VLOG(3) << producer->name() << " is a constant.";
+      continue;
+    }
+    if (producer->IsCustomFusion()) {
+      continue;
+    }
+    // First, fuse the consumer ops of the current op, which are siblings.
+    if (FuseSiblings(/*parent=*/producer, &fusion_info_cache, &cost_analysis)) {
+      changed = true;
+    }
+    // Second, perform producer-consumer multi-output fusion. This order will
+    // ensure that all get-tuple-element ops inserted as a by-product of
+    // multi-output fusion will occur before the current op in the order of
+    // traversal, and hence, not get in the way of subsequent fusion attempts.
+    const auto candidates = GetProducerConsumerMultiOutputFusionCandidates(
+        producer, *reachability_, &fusion_info_cache, device_info_,
+        &cost_analysis);
+    auto* consumer_for_fusion = SelectPreferredFusionCandidate(candidates);
+    if (consumer_for_fusion == nullptr) {
+      continue;
+    }
+    if (!ConsumeFuel(name(), [&] {
+          return absl::StrFormat("Not fusing %s and %s.", producer->name(),
+                                 consumer_for_fusion->name());
+        })) {
+      continue;
+    }
+    changed = true;
+    fusion_info_cache.Invalidate(producer);
+    fusion_info_cache.Invalidate(consumer_for_fusion);
+    TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(producer));
+    TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(consumer_for_fusion));
+
+    HloInstruction* input_fusion;
+    if (consumer_for_fusion->opcode() == HloOpcode::kFusion) {
+      input_fusion = consumer_for_fusion;
+      VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
+              << consumer_for_fusion->name();
+    } else {
+      input_fusion = computation_->AddInstruction(HloInstruction::CreateFusion(
+          consumer_for_fusion->shape(),
+          ChooseFusionKind(*producer, *consumer_for_fusion),
+          consumer_for_fusion));
+      VLOG(2) << "Fuse producer " << producer->name() << " and its consumer "
+              << consumer_for_fusion->name() << " into "
+              << input_fusion->name();
+      TF_CHECK_OK(
+          computation_->ReplaceInstruction(consumer_for_fusion, input_fusion));
+    }
+
+    DumpFusionState(*input_fusion,
+                    absl::StrCat("About to fuse producer |", producer->name(),
+                                 "| into consumer |", input_fusion->name(),
+                                 "| inside multi-output fusion"),
+                    /*producer=*/producer);
+
+    if (producer->opcode() == HloOpcode::kFusion) {
+      input_fusion->MergeFusionInstructionIntoMultiOutput(producer);
+    } else {
+      input_fusion->FuseInstructionIntoMultiOutput(producer);
+      CHECK_EQ(0, producer->user_count());
+      TF_CHECK_OK(computation_->RemoveInstruction(producer));
+    }
+    TF_RETURN_IF_ERROR(cost_analysis.RevisitInstruction(input_fusion));
+
+    DumpFusionState(*input_fusion,
+                    absl::StrCat("Fused into |", input_fusion->name(),
+                                 "| inside multi-output fusion"));
+    RecomputeReachability();
+  }
+  return changed;
+}
+
+void MultiOutputFusion::DumpFusionState(const HloInstruction& consumer,
+                                        absl::string_view label,
+                                        const HloInstruction* producer) {
+  if (consumer.GetModule()
+          ->config()
+          .debug_options()
+          .xla_dump_fusion_visualization()) {
+    RegisterFusionState(*computation_, label, consumer, producer);
+  }
+}
+
+absl::StatusOr<bool> MultiOutputFusion::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (auto* computation : GetFusibleComputations(*module, execution_threads)) {
+    computation_ = computation;
+    TF_ASSIGN_OR_RETURN(bool computation_changed, DoMultiOutputFusion());
+    changed |= computation_changed;
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h
new file mode 100644
index 0000000..9ebabe6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h
@@ -0,0 +1,134 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_MULTI_OUTPUT_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_MULTI_OUTPUT_FUSION_H_
+
+#include <memory>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_dfs_reachability.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// Multi-output fusion of sibling and producer-consumer instructions for the
+// GPU backend to reduce memory bandwidth requirements.
+//
+//   0) Before multi-    1) Sibling multi-    2) Producer-consumer
+//      output fusion       output fusion        multi-output fusion
+//
+//          p                    p                    p
+//          |                    |                    |
+//          v                    v                    v
+//          A                    A               +-fusion--+
+//        /   \                  |               |    A    |
+//       |     |            +-fusion--+          |   / \   |
+//       v     v            |   / \   |          |  B   |  |
+//       B     C            |  B   C  |          |  |   |  |
+//        \   /             |  |   |  |          |  v   v  |
+//         v v              |  v   v  |          |  tuple  |
+//        ROOT              |  tuple  |          +---------+
+//                          +---------+            /    \
+//                            /    \            gte_b  gte_a
+//                         gte_b  gte_c           |      |
+//                           |      |             |      v
+//                            \    /              |      C
+//                             v  v                \    /
+//                             ROOT                 v  v
+//                                                  ROOT
+//
+// Multi-output fusion ops have a tuple op at their root containing multiple
+// elements as outputs. GetTupleElement ops (depicted as gte_* above) are
+// inserted to extract tuple elements for consumers.
+//
+// The two different flavors of multi-output fusion this pass performs are
+// depicted above.
+// 1) Fusion of sibling ops reduces memory bandwidth requirements, because
+//    common input parameters have to be read only once.
+// 2) Fusion of producer-consumer ops reduces memory bandwidth requirements by
+//    saving one read from memory. In the example above, B does not need to read
+//    the output of A from memory, while C still does (using gte_a).
+// Note that sibling (1) and producer-consumer (2) multi-output fusion can be
+// combined.
+//
+// The MultiOutputFusion pass modifies the HLO in reverse post-order (uses
+// before defs). First, it attempts to fuse the consumer ops of the current op,
+// which are siblings (1). Afterwards, it attempts to fuse the current op with
+// one of its consumers (2). This order avoids a phase ordering issue (described
+// in go/fusionfusion). It ensures that all GetTupleElement ops inserted as a
+// by-product of multi-output fusion will occur before the current op in the
+// order of traversal, and hence, not get in the way of subsequent fusion
+// attempts.
+//
+// The MultiOutputFusion pass ensures several conditions are met for fusion.
+// Some of them are relevant for correctness. In particular, no cycles must be
+// introduced into the HLO module. Moreover, the code emitters for multi-output
+// fusion must support the combination of ops and their shapes. Other
+// restrictions are rather arbitrary and lifting them could be beneficial.
+// * Sibling fusion (1) requires at least one op to be a kFusion.
+// * Sibling fusion (1) does not fuse kInput fusions with kLoop fusions, i.e.
+//   the fusion kinds must match.
+
+class MultiOutputFusion : public HloModulePass {
+ public:
+  explicit MultiOutputFusion(
+      const se::DeviceDescription& device_info,
+      HloCostAnalysis::ShapeSizeFunction shape_size_function)
+      : device_info_(device_info), shape_size_function_(shape_size_function) {}
+
+  absl::string_view name() const override { return "multi_output_fusion"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  bool FuseSiblings(HloInstruction* parent, FusionInfoCache* fusion_info_cache,
+                    GpuHloCostAnalysis* cost_analysis);
+
+  absl::StatusOr<bool> DoMultiOutputFusion();
+
+  // Recompute reachability for the current computation.
+  void RecomputeReachability();
+
+  void DumpFusionState(const HloInstruction& consumer, absl::string_view label,
+                       const HloInstruction* producer = nullptr);
+
+  // Computation for the pass.
+  HloComputation* computation_;
+
+  // The reachability map of current computation.
+  std::unique_ptr<HloDfsReachability> reachability_;
+
+  se::DeviceDescription device_info_;
+  HloCostAnalysis::ShapeSizeFunction shape_size_function_;
+};
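+
+// A minimal usage sketch, assuming an HloPassPipeline `pipeline`, a
+// se::DeviceDescription `device_info`, and an
+// HloCostAnalysis::ShapeSizeFunction `shape_size_fn` are available (these
+// names are illustrative):
+//
+//   pipeline.AddPass<MultiOutputFusion>(device_info, shape_size_fn);
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));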
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_MULTI_OUTPUT_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc
new file mode 100644
index 0000000..1fd4263
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc
@@ -0,0 +1,2236 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/multi_output_fusion.h"
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+
+namespace m = ::xla::match;
+
+class MultiOutputFusionTest : public HloTestBase {
+  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
+    return [&](const Shape& shape) {
+      constexpr int64_t kPointerSize = 8;
+      return ShapeUtil::ByteSizeOf(shape, kPointerSize);
+    };
+  }
+
+ public:
+  MultiOutputFusion mof_{TestGpuDeviceInfo::RTXA6000DeviceInfo(),
+                         ShapeSizeBytesFunction()};
+
+  void CheckMultiOutputFusion(absl::string_view hlo,
+                              std::optional<absl::string_view> expected) {
+    RunAndFilecheckHloRewrite(
+        hlo,
+        MultiOutputFusion{TestGpuDeviceInfo::RTXA6000DeviceInfo(),
+                          ShapeSizeBytesFunction()},
+        expected);
+  }
+};
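+
+// Note on the assertions below: MultiOutputFusion::Run returns an
+// absl::StatusOr<bool> whose value is true iff the pass changed the module.
+// ASSERT_TRUE(mof_.Run(module.get()).value()) therefore checks that some
+// multi-output fusion was performed, and ASSERT_FALSE checks that none was.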
+
+const char kModulePrefix[] = R"(
+    HloModule test_module
+
+    scalar_add_computation {
+      scalar_lhs.0 = f32[] parameter(0)
+      scalar_rhs.0 = f32[] parameter(1)
+      ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0)
+    }
+    scalar_mul_computation {
+      scalar_lhs.1 = f32[] parameter(0)
+      scalar_rhs.1 = f32[] parameter(1)
+      ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
+    })";
+
+static int64_t CountMultiOutputFusions(const HloModule* module) {
+  int multi_output_fusion_count = 0;
+  for (auto* computation : module->MakeNonfusionComputations()) {
+    for (auto* instr : computation->instructions()) {
+      if (instr->IsMultiOutputFusion()) {
+        multi_output_fusion_count++;
+      }
+    }
+  }
+  return multi_output_fusion_count;
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
+  // Fusion with a reduce instruction root and a sibling reduce instruction
+  // sharing the same input param.
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation {
+      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+      const.1 = f32[] parameter(0)
+      ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[] parameter(0)
+      p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+      const.2 = f32[] constant(1)
+      fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation
+      reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+      ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* fusion =
+      module->entry_computation()->root_instruction()->operand(0)->operand(0);
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation_1 {
+      p1.1 = f32[6400]{0} parameter(1)
+      mul = f32[6400]{0} multiply(p1.1, p1.1)
+      const.1 = f32[] parameter(0)
+      ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0}, to_apply=scalar_add_computation
+    }
+
+    fused_computation_2 {
+      p1.2 = f32[6400]{0} parameter(1)
+      r1 = f32[64,100]{0,1} reshape(p1.2)
+      const.2 = f32[] parameter(0)
+      ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[] parameter(0)
+      p1 = f32[6400]{0} parameter(1)
+      fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1
+      fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2
+      ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2)
+    })"))
+                    .value();
+  ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, ReduceMofDifferentTypes) {
+  // Fusion with a reduce instruction root and a sibling reduce instruction
+  // sharing the same input param.
+  const char* hlo = R"(
+HloModule module
+
+scalar_add_computation {
+  scalar_lhs.1 = f32[] parameter(0)
+  scalar_rhs.1 = f32[] parameter(1)
+  ROOT add.1 = f32[] add(scalar_lhs.1, scalar_rhs.1)
+}
+
+scalar_add_computation_f16 {
+  scalar_lhs.0 = f16[] parameter(0)
+  scalar_rhs.0 = f16[] parameter(1)
+  ROOT add.0 = f16[] add(scalar_lhs.0, scalar_rhs.0)
+}
+
+fused_computation {
+  param_0.2 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+  c.1 = f16[128,512,28,28]{3,2,1,0} convert(param_0.2)
+  const.0 = f16[] constant(0)
+  ROOT reduce.0 = f16[512]{0} reduce(c.1, const.0), dimensions={0,2,3}, to_apply=scalar_add_computation_f16
+}
+
+ENTRY entry {
+  p0 = f32[] parameter(0)
+  p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+  const.2 = f32[] constant(0)
+  reduce.1 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+  fusion = f16[512]{0} fusion(p1), kind=kInput, calls=fused_computation
+  ROOT root = (f32[512]{0}, f16[512]{0}) tuple(reduce.1, fusion)
+})";
+
+  CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation
+// CHECK-NEXT:   [[param_0_2_0:%[^ ]+]] = f32[128,512,28,28]{3,2,1,0} parameter(0)
+// CHECK-NEXT:   [[c_1_1:%[^ ]+]] = f16[128,512,28,28]{3,2,1,0} convert([[param_0_2_0]])
+// CHECK-NEXT:   [[const_0_2:%[^ ]+]] = f16[] constant(0)
+// CHECK-NEXT:   [[reduce_0_3:%[^ ]+]] = f16[512]{0} reduce([[c_1_1]], [[const_0_2]]), dimensions={0,2,3}, to_apply=[[scalar_add_computation_f16_4:%[^ ]+]]
+// CHECK-NEXT:   [[param_1_5:%[^ ]+]] = f32[] parameter(1)
+// CHECK-NEXT:   [[reduce_2_6:%[^ ]+]] = f32[512]{0} reduce([[param_0_2_0]], [[param_1_5]]), dimensions={0,2,3}, to_apply=[[scalar_add_computation_7:%[^ ]+]]
+// CHECK-NEXT:   ROOT [[tuple_8:%[^ ]+]] = (f16[512]{0}, f32[512]{0}) tuple([[reduce_0_3]], [[reduce_2_6]])
+// CHECK:   [[fusion_9:%[^ ]+]] = (f16[512]{0}, f32[512]{0}) fusion([[p1_10:%[^ ]+]], [[const_2_11:%[^ ]+]]), kind=kInput, calls=[[fused_computation_12:%[^ ]+]]
+)");
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation_1 {
+      p1.1 = f32[10,10]{1,0} parameter(1)
+      mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
+      const.1 = f32[] parameter(0)
+      ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0,1}, to_apply=scalar_add_computation
+    }
+
+    fused_computation_2 {
+      p1.2 = f32[10,10]{1,0} parameter(1)
+      const.2 = f32[] parameter(0)
+      ROOT reduce.2 = f32[10]{0} reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[] parameter(0)
+      p1.3 = f32[10,10]{1,0} parameter(1)
+      fusion.1 = f32[] fusion(p0, p1.3), kind=kInput, calls=fused_computation_1
+      p2 = f32[] parameter(2)
+      fusion.2 = f32[10]{0} fusion(p2, p1.3), kind=kInput, calls=fused_computation_2
+      ROOT root = (f32[], f32[10]{0}) tuple(fusion.1, fusion.2)
+    })"))
+                    .value();
+  ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) {
+  // Two sibling fusions with reduce instruction roots sharing the same input
+  // param.
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation_1 {
+      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+      const.1 = f32[] parameter(0)
+      ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+    }
+
+    fused_computation_2 {
+      p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+      const.2 = f32[] parameter(0)
+      ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[] parameter(0)
+      p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+      fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1
+      fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2
+      ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* fusion =
+      module->entry_computation()->root_instruction()->operand(0)->operand(0);
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionNoSiblingFusionForCommonScalar) {
+  // Two sibling fusions with bitcast roots sharing the same scalar input param.
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation_1 {
+      param_0.87 = bf16[32,4096,16384]{2,1,0} parameter(0)
+      param_1.4620 = s32[] parameter(1)
+      constant_3949 = s32[] constant(0)
+      compare.1026 = pred[] compare(param_1.4620, constant_3949), direction=LT
+      constant_5437 = s32[] constant(32)
+      add.6859 = s32[] add(param_1.4620, constant_5437)
+      select.1599 = s32[] select(compare.1026, add.6859, param_1.4620)
+      dynamic-slice.59 = bf16[1,4096,16384]{2,1,0} dynamic-slice(param_0.87, select.1599, constant_3949, constant_3949), dynamic_slice_sizes={1,4096,16384}
+      ROOT bitcast.41089 = bf16[4096,16384]{1,0} bitcast(dynamic-slice.59)
+    }
+
+    fused_computation_2 {
+      param_0 = bf16[32,4096,16384]{2,1,0} parameter(0)
+      param_1 = s32[] parameter(1)
+      constant = s32[] constant(0)
+      compare = pred[] compare(param_1, constant), direction=LT
+      constant.32 = s32[] constant(32)
+      add = s32[] add(param_1, constant.32)
+      select = s32[] select(compare, add, param_1)
+      dynamic-slice = bf16[1,4096,16384]{2,1,0} dynamic-slice(param_0, select, constant, constant), dynamic_slice_sizes={1,4096,16384}
+      ROOT bitcast.41087 = bf16[4096,16384]{1,0} bitcast(dynamic-slice)
+    }
+
+    ENTRY entry {
+      p0 = s32[] parameter(0)
+      p1 = bf16[32,4096,16384]{2,1,0} parameter(1)
+      p2 = bf16[32,4096,16384]{2,1,0} parameter(2)
+      fusion.1 = bf16[4096,16384]{1,0} fusion(p1, p0), kind=kLoop, calls=fused_computation_1
+      fusion.2 = bf16[4096,16384]{1,0} fusion(p2, p0), kind=kLoop, calls=fused_computation_2
+      ROOT root = (bf16[4096,16384]{1,0}, bf16[4096,16384]{1,0}) tuple(fusion.1, fusion.2)
+    })"))
+                    .value();
+  ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest,
+       MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
+  // Multi-output fusion with two reduce instructions at the root and a sibling reduce
+  // instruction sharing the same input param.
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) {
+      const.1 = f32[] constant(1)
+      p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1)
+      reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+      reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+      ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2)
+    }
+
+    ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) {
+      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+      const = f32[] constant(1)
+      fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation
+      get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0
+      get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1
+      reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation
+      ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* fusion =
+      module->entry_computation()->root_instruction()->operand(0)->operand(0);
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Reduce(), m::Reduce(), m::Reduce())));
+}
+
+TEST_F(MultiOutputFusionTest,
+       MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) {
+  // Verify that, if we already have a multi-output fusion, we prefer to pick
+  // a reduce op from its operands for checking shape compatibility.
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation_1 {
+      p1.1 = f32[10,10]{1,0} parameter(1)
+      mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
+      const.1 = f32[] parameter(0)
+      reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0,1}, to_apply=scalar_add_computation
+      ROOT tuple = (f32[10,10], f32[]) tuple(mul, reduce.1)
+    }
+
+    fused_computation_2 {
+      p1.2 = f32[10,10]{1,0} parameter(1)
+      const.2 = f32[] parameter(0)
+      ROOT reduce.2 = f32[10] reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[] parameter(0)
+      p1 = f32[10,10]{1,0} parameter(1)
+      p2 = f32[] parameter(2)
+      fusion.1 = (f32[10,10], f32[]) fusion(p0, p1), kind=kInput, calls=fused_computation_1
+      get-tuple-element.1 = f32[10,10] get-tuple-element((f32[10,10], f32[]) fusion.1), index=0
+      get-tuple-element.2 = f32[] get-tuple-element((f32[10,10], f32[]) fusion.1), index=1
+      fusion.2 = f32[10] fusion(p2, p1), kind=kInput, calls=fused_computation_2
+      ROOT root = (f32[10,10], f32[], f32[10]) tuple(get-tuple-element.1, get-tuple-element.2, fusion.2)
+    })"))
+                    .value();
+  ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, LoopVariadicReductionFusions) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation.94 {
+      tmp_0 = f32[] parameter(0)
+      tmp_1 = f32[] parameter(1)
+      tmp_2 = pred[] compare(tmp_0, tmp_1), direction=GE
+      tmp_3 = f32[] select(tmp_2, tmp_0, tmp_1)
+      tmp_4 = pred[] compare(tmp_0, tmp_1), direction=EQ
+      tmp_5 = s32[] parameter(2)
+      tmp_6 = s32[] parameter(3)
+      tmp_7 = s32[] minimum(tmp_5, tmp_6)
+      tmp_8 = s32[] select(tmp_2, tmp_5, tmp_6)
+      tmp_9 = s32[] select(tmp_4, tmp_7, tmp_8)
+      ROOT tmp_10 = (f32[], s32[]) tuple(tmp_3, tmp_9)
+    }
+
+    minmax_func.1536 {
+      tmp_0 = f32[] parameter(0)
+      tmp_1 = f32[] parameter(2)
+      tmp_2 = s32[] parameter(1)
+      tmp_3 = s32[] parameter(3)
+      ROOT tmp_4 = (f32[], s32[]) fusion(tmp_0, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=fused_computation.94
+    }
+
+    fused_computation {
+      tmp_0 = f32[554112,10]{1,0} parameter(0)
+      tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
+      tmp_2 = f32[] constant(-inf)
+      tmp_3 = s32[] constant(0)
+      ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
+    }
+
+    fused_computation2 {
+      tmp_0 = f32[554112,10]{1,0} parameter(0)
+      tmp_1 = s32[554112,10]{1,0} iota(), iota_dimension=1
+      tmp_2 = f32[] constant(inf)
+      tmp_3 = s32[] constant(1)
+      ROOT tmp_4 = (f32[554112]{0}, s32[554112]{0}) reduce(tmp_0, tmp_1, tmp_2, tmp_3), dimensions={1}, to_apply=minmax_func.1536
+    }
+
+    ENTRY e {
+      tmp_0 = f32[554112,10]{1,0} parameter(0)
+      tmp_1 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_0), kind=kLoop, calls=fused_computation
+      tmp_2 = s32[554112]{0} get-tuple-element(tmp_1), index=1
+      tmp_4 = (f32[554112]{0}, s32[554112]{0}) fusion(tmp_0), kind=kLoop, calls=fused_computation2
+      tmp_5 = s32[554112]{0} get-tuple-element(tmp_4), index=1
+      ROOT tmp_6 = s32[554112]{0} add(tmp_2, tmp_5)
+    })"))
+                    .value();
+  EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, InputVariadicReductionFusions) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation.1117 {
+      param_0.2433 = f32[] parameter(0)
+      param_1.2571 = f32[] parameter(1)
+      compare.1770 = pred[] compare(param_0.2433, param_1.2571), direction=LE
+      select.682 = f32[] select(compare.1770, param_0.2433, param_1.2571)
+      compare.1303.clone.1 = pred[] compare(param_0.2433, param_1.2571), direction=EQ
+      param_2.6460 = s32[] parameter(2)
+      param_3.6755 = s32[] parameter(3)
+      minimum.633.clone.1 = s32[] minimum(param_2.6460, param_3.6755)
+      select.398.clone.1 = s32[] select(compare.1770, param_2.6460, param_3.6755)
+      select.397.clone.1 = s32[] select(compare.1303.clone.1, minimum.633.clone.1, select.398.clone.1)
+      ROOT tuple.151 = (f32[], s32[]) tuple(select.682, select.397.clone.1)
+    }
+
+    minmax_func.223 {
+      lhs_value.224 = f32[] parameter(0)
+      rhs_value.226 = f32[] parameter(2)
+      lhs_index.225 = s32[] parameter(1)
+      rhs_index.227 = s32[] parameter(3)
+      ROOT fusion.1117 = (f32[], s32[]) fusion(lhs_value.224, rhs_value.226, lhs_index.225, rhs_index.227), kind=kLoop, calls=fused_computation.1117
+    }
+
+    fused_computation.73 {
+      bitcast.86661 = f32[3,1024,300]{2,1,0} parameter(0)
+      iota.734 = s32[3,1,1024,300]{3,2,1,0} iota(), iota_dimension=3
+      bitcast.97555 = s32[3,1024,300]{2,1,0} bitcast(iota.734)
+      constant_3917 = f32[] constant(inf)
+      constant_3918 = s32[] constant(0)
+      ROOT reduce.1069 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) reduce(bitcast.86661, bitcast.97555, constant_3917, constant_3918), dimensions={2}, to_apply=minmax_func.223
+    }
+
+    fused_computation.84 {
+      bitcast.86676 = f32[3,1024,300]{2,1,0} parameter(0)
+      iota.732 = s32[3,1,1024,300]{3,2,1,0} iota(), iota_dimension=3
+      bitcast.97553 = s32[3,1024,300]{2,1,0} bitcast(iota.732)
+      constant_3915 = f32[] constant(inf)
+      constant_3916 = s32[] constant(0)
+      ROOT reduce.1070 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) reduce(bitcast.86676, bitcast.97553, constant_3915, constant_3916), dimensions={2}, to_apply=minmax_func.223
+    }
+
+    ENTRY e {
+      p0 = f32[3,1024,300]{2,1,0} parameter(0)
+      fusion.84 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) fusion(p0), kind=kInput, calls=fused_computation.84
+      gte.391 = s32[3,1024]{1,0} get-tuple-element(fusion.84), index=1
+      fusion.73 = (f32[3,1024]{1,0}, s32[3,1024]{1,0}) fusion(p0), kind=kInput, calls=fused_computation.73
+      gte.393 = s32[3,1024]{1,0} get-tuple-element(fusion.73), index=1
+      ROOT r = s32[3,1024]{1,0} add(gte.391, gte.393)
+    })"))
+                    .value();
+  EXPECT_TRUE(mof_.Run(module.get()).value());
+  EXPECT_EQ(module->entry_computation()->parameter_instruction(0)->user_count(),
+            1);
+  const HloInstruction* fusion =
+      module->entry_computation()->parameter_instruction(0)->users()[0];
+  EXPECT_THAT(fusion, GmockMatch(m::Fusion()));
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation_1 {
+      p0.1 = f32[6400]{0} parameter(0)
+      ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
+    }
+
+    fused_computation_2 {
+      p0.2 = f32[6400]{0} parameter(0)
+      const.2 = f32[] constant(1)
+      broadcast = f32[6400]{0} broadcast(const.2), dimensions={}
+      ROOT div = f32[6400]{0} divide(p0.2, broadcast)
+    }
+
+    ENTRY entry {
+      p0 = f32[6400]{0} parameter(0)
+      fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
+      fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2
+      ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* fusion =
+      module->entry_computation()->root_instruction()->operand(0)->operand(0);
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Multiply(), m::Divide())));
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation_1 {
+      p0.1 = f32[6400]{0} parameter(0)
+      ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
+    }
+
+    ENTRY entry {
+      p0 = f32[6400]{0} parameter(0)
+      fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
+      const.2 = f32[] constant(1)
+      broadcast = f32[6400]{0} broadcast(const.2), dimensions={}
+      div = f32[6400]{0} divide(p0, broadcast)
+      ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* fusion =
+      module->entry_computation()->root_instruction()->operand(0)->operand(0);
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Multiply(), m::Divide())));
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation_1 {
+      p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
+      ROOT mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+    }
+
+    fused_computation_2 {
+      p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
+      const.2 = f32[] constant(0)
+      ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2), dimensions={0,3}, to_apply=scalar_add_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
+      fusion.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1
+      fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
+      ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0}) tuple(fusion.1, fusion.2)
+    })"))
+                    .value();
+  ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation_1 {
+      p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+      mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+      exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
+      ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
+        f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
+    }
+
+    fused_computation_2 {
+      p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+      const.2 = f32[] constant(0)
+      broadcast = f32[8,1,5,16,1,1]{5,4,3,2,1,0} broadcast(const.2),
+        dimensions={}
+      ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, broadcast)
+    }
+
+    ENTRY entry {
+      p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+      fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
+        f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop,
+        calls=fused_computation_1
+      fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop,
+        calls=fused_computation_2
+      gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
+      gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
+      ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
+        f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0})
+        tuple(gte0, gte1, fusion.2)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* fusion =
+      module->entry_computation()->root_instruction()->operand(0)->operand(0);
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Multiply(), m::Exp(), m::Add())));
+}
+
+TEST_F(MultiOutputFusionTest,
+       MultiOutputFusionSiblingMultiOutputLoopAndMultiOutputLoop) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation_1 {
+      p0.1 = f32[8,16]{1,0} parameter(0)
+      mul = f32[8,16]{1,0} multiply(p0.1, p0.1)
+      exp = f32[8,16]{1,0} exponential(p0.1)
+      ROOT tuple = (f32[8,16]{1,0}, f32[8,16]{1,0}) tuple(mul, exp)
+    }
+
+    fused_computation_2 {
+      p0.2 = f32[8,16]{1,0} parameter(0)
+      const.2 = f32[] constant(0)
+      broadcast = f32[8,16]{1,0} broadcast(const.2),
+        dimensions={}
+      add = f32[8,16]{1,0} add(p0.2, broadcast)
+      ROOT tuple.1 = (f32[8,16]{1,0}, f32[8,16]{1,0}) tuple(add, broadcast)
+    }
+
+    ENTRY entry {
+      p0 = f32[8,16]{1,0} parameter(0)
+      fusion.1 = (f32[8,16]{1,0}, f32[8,16]{1,0}) fusion(p0), kind=kLoop,
+        calls=fused_computation_1
+      fusion.2 = (f32[8,16]{1,0}, f32[8,16]{1,0}) fusion(p0), kind=kLoop,
+        calls=fused_computation_2
+      gte0 = f32[8,16]{1,0} get-tuple-element(fusion.1), index=0
+      gte1 = f32[8,16]{1,0} get-tuple-element(fusion.1), index=1
+      gte2 = f32[8,16]{1,0} get-tuple-element(fusion.2), index=0
+      gte3 = f32[8,16]{1,0} get-tuple-element(fusion.2), index=1
+      ROOT root = (f32[8,16]{1,0}, f32[8,16]{1,0}, f32[8,16]{1,0},
+        f32[8,16]{1,0})
+        tuple(gte0, gte1, gte2, gte3)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* fusion =
+      module->entry_computation()->root_instruction()->operand(0)->operand(0);
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(
+      fusion->fused_expression_root(),
+      GmockMatch(m::Tuple(m::Multiply(), m::Exp(), m::Add(), m::Broadcast())));
+}
+
+TEST_F(MultiOutputFusionTest,
+       MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation_1 {
+      p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
+      mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+      exp = f32[8,1,5,16,1,2]{5,4,3,2,1,0} exponential(p0.1)
+      ROOT tuple = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
+        f32[8,1,5,16,1,2]{5,4,3,2,1,0}) tuple(mul, exp)
+    }
+
+    fused_computation_2 {
+      p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
+      const.2 = f32[] constant(0)
+      ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2),
+        dimensions={0,3}, to_apply=scalar_add_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
+      fusion.1 = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
+        f32[8,1,5,16,1,2]{5,4,3,2,1,0}) fusion(p0), kind=kLoop,
+        calls=fused_computation_1
+      fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop,
+        calls=fused_computation_2
+      gte0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
+      gte1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
+      ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
+        f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0})
+        tuple(gte0, gte1, fusion.2)
+    })"))
+                    .value();
+  ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, SiblingFusionBitcastAndLoopFusionNotFused) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test
+
+fused_computation_1 {
+  p0.1 = f32[2048,16000]{1,0} parameter(0)
+  bitcast = f32[2048,1,16000]{2,1,0} bitcast(p0.1)
+  ROOT exp = f32[2048,1,16000]{2,1,0} exponential(bitcast)
+}
+
+ENTRY main {
+  param_0 = f32[2048,16000]{1,0} parameter(0)
+  fusion = f32[2048,1,16000]{2,1,0} fusion(param_0), kind=kLoop, calls=fused_computation_1
+  bitcast = f32[16000,1,2048]{2,1,0} bitcast(param_0)
+  ROOT tuple.143 = (f32[16000,1,2048]{2,1,0}, f32[2048,1,16000]{2,1,0}) tuple(bitcast, fusion)
+})")
+                    .value();
+  EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest,
+       ProducerConsumerFusionBitcastAndElementwiseNotFused) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test
+
+ENTRY main {
+  param_0 = f32[2048,16000]{1,0} parameter(0)
+  convert = bf16[2048,16000]{1,0} convert(param_0)
+  bitcast = bf16[16000,1,2048]{2,1,0} bitcast(convert)
+  ROOT tuple.143 = (bf16[16000,1,2048]{2,1,0}, bf16[2048,16000]{1,0}) tuple(bitcast, convert)
+})")
+                    .value();
+  EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    ENTRY reduce {
+      p0 = f32[32,32,32]{2,1,0} parameter(0)
+      c0 = f32[] constant(0)
+      exp = f32[32,32,32]{2,1,0} exponential(p0)
+      reduce = f32[32,32]{1,0} reduce(exp, c0), dimensions={2},
+        to_apply=scalar_add_computation
+      ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, exp)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                        m::GetTupleElement())));
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Reduce(), m::Exp())));
+}
+
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_add {
+      p0.1 = f32[32,32,32]{2,1,0} parameter(0)
+      p1.1 = f32[32,32,32]{2,1,0} parameter(1)
+      ROOT add = f32[32,32,32]{2,1,0} add(p0.1, p1.1)
+    }
+
+    ENTRY reduce {
+      p0 = f32[32,32,32]{2,1,0} parameter(0)
+      p1 = f32[32,32,32]{2,1,0} parameter(1)
+      c0 = f32[] constant(0)
+      add = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
+      reduce = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
+        to_apply=scalar_add_computation
+      ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, add)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                        m::GetTupleElement())));
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Reduce(), m::Add())));
+}
+
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_select {
+      p1.1 = f32[32,32,32]{2,1,0} parameter(1)
+      c0 = f32[] constant(0)
+      broadcast = f32[32,32,32]{2,1,0} broadcast(f32[] c0), dimensions={}
+      greater-than = pred[32,32,32]{2,1,0} compare(f32[32,32,32]{2,1,0} p1.1,
+        f32[32,32,32]{2,1,0} broadcast), direction=GT
+      p0.1 = f32[32,32,32]{2,1,0} parameter(0)
+      ROOT select = f32[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
+        greater-than, f32[32,32,32]{2,1,0} p0.1, f32[32,32,32]{2,1,0} broadcast)
+    }
+
+    fused_reduce {
+      p0.2 = f32[32,32,32]{2,1,0} parameter(0)
+      c1 = f32[] constant(0)
+      r1 = f32[32,32]{1,0} reduce(p0.2, c1), dimensions={2},
+        to_apply=scalar_add_computation
+      mul = f32[32,32,32]{2,1,0} multiply(p0.2, p0.2)
+      r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
+        to_apply=scalar_add_computation
+      ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
+    }
+
+    ENTRY reduce {
+      p0 = f32[32,32,32]{2,1,0} parameter(0)
+      p1 = f32[32,32,32]{2,1,0} parameter(1)
+      select = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
+      fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
+        calls=fused_reduce
+      gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
+      gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
+      ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32,32]{2,1,0})
+        tuple(gte1, gte1, select)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root,
+              GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                  m::GetTupleElement(), m::GetTupleElement())));
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Reduce(), m::Reduce(), m::Select())));
+}
+
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_element_wise {
+      p0.1 = f32[2,2,2]{2,1,0} parameter(0)
+      p1.1 = f32[2,2,2]{2,1,0} parameter(1)
+      ROOT root = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
+    }
+
+    fused_reduce {
+      p0.2 = f32[2,2,2]{2,1,0} parameter(0)
+      mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2,
+        f32[2,2,2]{2,1,0} p0.2)
+      broadcast = f32[2,2,2,2]{3,2,1,0} broadcast(mul), dimensions={3,2,1}
+      c1 = f32[] constant(0)
+      ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2,2]{3,2,1,0} broadcast,
+        f32[] c1), dimensions={1,3}, to_apply=scalar_add_computation
+    }
+
+    ENTRY reduce {
+      p0 = f32[2,2,2]{2,1,0} parameter(0)
+      p1 = f32[2,2,2]{2,1,0} parameter(1)
+      element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise
+      fusion = f32[2,2]{1,0} fusion(element_wise), kind=kLoop, calls=fused_reduce
+      ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise)
+    })"))
+                    .value();
+  ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest,
+       ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_select {
+      p1.1 = f16[32,32,32]{2,1,0} parameter(1)
+      c0 = f16[] constant(0)
+      broadcast = f16[32,32,32]{2,1,0} broadcast(f16[] c0), dimensions={}
+      greater-than = pred[32,32,32]{2,1,0} compare(f16[32,32,32]{2,1,0} p1.1,
+        f16[32,32,32]{2,1,0} broadcast), direction=GT
+      p0.1 = f16[32,32,32]{2,1,0} parameter(0)
+      ROOT select = f16[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
+        greater-than, f16[32,32,32]{2,1,0} p0.1, f16[32,32,32]{2,1,0} broadcast)
+    }
+    fused_reduce {
+      p0.2 = f16[32,32,32]{2,1,0} parameter(0)
+      convert = f32[32,32,32]{2,1,0} convert(p0.2)
+      c1 = f32[] constant(0)
+      r1 = f32[32,32]{1,0} reduce(convert, c1), dimensions={2},
+        to_apply=scalar_add_computation
+      mul = f32[32,32,32]{2,1,0} multiply(convert, convert)
+      r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
+        to_apply=scalar_add_computation
+      ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
+    }
+    ENTRY reduce {
+      p0 = f16[32,32,32]{2,1,0} parameter(0)
+      p1 = f16[32,32,32]{2,1,0} parameter(1)
+      select = f16[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
+      fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
+        calls=fused_reduce
+      gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
+      gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
+      ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f16[32,32,32]{2,1,0})
+        tuple(gte1, gte1, select)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root,
+              GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion(&fusion)),
+                                  m::GetTupleElement(), m::GetTupleElement())));
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Tuple(m::Reduce(), m::Reduce(), m::Select())));
+}
+
+TEST_F(MultiOutputFusionTest,
+       ProducerConsumerFusionReduceUnfriendlyLoopFusion) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    mixed_input_layouts_computation {
+      p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+      p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+      copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
+      c0 = f16[] constant(0)
+      broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
+      greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
+      ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
+    }
+    fused_reduce {
+      p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+      convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
+      c0.2 = f32[] constant(0)
+      ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+    }
+    ENTRY reduce {
+      p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
+      p1 = f16[128,1024,32,32]{1,3,2,0} parameter(1)
+      loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
+      reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
+      ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
+    })"))
+                    .value();
+  ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionAvoidsCycles) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_add {
+      p0 = f32[32,32,32]{2,1,0} parameter(0)
+      p1 = f32[32,32,32]{2,1,0} parameter(1)
+      ROOT add = f32[32,32,32]{2,1,0} add(p0, p1)
+    }
+
+    fused_mul {
+      p2 = f32[64,64,64]{2,1,0} parameter(0)
+      p3 = f32[64,64,64]{2,1,0} parameter(1)
+      ROOT multiply = f32[64,64,64]{2,1,0} multiply(p2, p3)
+    }
+
+    fused_reduce_1 {
+      p4 = f32[32,32,32]{2,1,0} parameter(0)
+      p5 = f32[64,64,64]{2,1,0} parameter(1)
+      slice = f32[32,32,32]{2,1,0} slice(p5), slice={[0:32], [0:32], [0:32]}
+      add = f32[32,32,32]{2,1,0} add(p4, slice)
+      c0 = f32[] constant(0)
+      ROOT r1 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
+        to_apply=scalar_add_computation
+    }
+
+    fused_reduce_2 {
+      p6 = f32[32,32,32]{2,1,0} parameter(0)
+      p7 = f32[64,64,64]{2,1,0} parameter(1)
+      c0 = f32[] constant(0)
+      pad = f32[64,64,64]{2,1,0} pad(p6, c0), padding=16_16x16_16x16_16
+      mul = f32[64,64,64]{2,1,0} multiply(pad, p7)
+      ROOT r1 = f32[64,64]{1,0} reduce(mul, c0), dimensions={2},
+        to_apply=scalar_add_computation
+    }
+
+    ENTRY reduce {
+      p8 = f32[32,32,32]{2,1,0} parameter(0)
+      p9 = f32[64,64,64]{2,1,0} parameter(1)
+      // `add` and `mul` can be multi-output fused with `reduce1` and `reduce2`,
+      // respectively. However, doing both isn't possible: the fusion containing
+      // `add` consumes `mul` and the fusion containing `mul` consumes `add`, so
+      // performing the second multi-output fusion would make the two fusions
+      // mutually dependent and hence introduce a cycle.
+      add = f32[32,32,32]{2,1,0} fusion(p8, p8), kind=kLoop, calls=fused_add
+      mul = f32[64,64,64]{2,1,0} fusion(p9, p9), kind=kLoop, calls=fused_mul
+
+      reduce1 = f32[32,32]{1,0} fusion(add, mul), kind=kInput,
+          calls=fused_reduce_1
+      reduce2 = f32[64,64]{1,0} fusion(add, mul), kind=kInput,
+          calls=fused_reduce_2
+      ROOT root = (f32[32,32,32]{2,1,0}, f32[32,32]{1,0}, f32[64,64]{1,0},
+                   f32[64,64,64]{2,1,0}) tuple(add, reduce1, reduce2, mul)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  EXPECT_EQ(1, CountMultiOutputFusions(module.get()));
+}
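+
+// Illustrative sketch of the cycle the test above guards against (this is a
+// reading of the HLO, not a statement about the pass internals): fusing `add`
+// into `reduce1` yields a fusion that still reads `mul`, while fusing `mul`
+// into `reduce2` yields a fusion that reads `add`. Performing both fusions
+// would therefore make each resulting fusion consume an output of the other:
+//
+//   fusion{add, reduce1}  --reads-->  mul  (inside fusion{mul, reduce2})
+//   fusion{mul, reduce2}  --reads-->  add  (inside fusion{add, reduce1})
+//
+// so only one of the two candidate multi-output fusions can be applied.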
+
+TEST_F(MultiOutputFusionTest, PreferFuseProducerIntoFusionConsumer) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_add {
+      p0 = f32[32,32,32]{2,1,0} parameter(0)
+      p1 = f32[32,32,32]{2,1,0} parameter(1)
+      ROOT add = f32[32,32,32]{2,1,0} add(p0, p1)
+    }
+    fused_reduce {
+      p0 = f32[32,32,32]{2,1,0} parameter(0)
+      p1 = f32[64,64,64]{2,1,0} parameter(1)
+      slice = f32[32,32,32]{2,1,0} slice(p1), slice={[0:32], [0:32], [0:32]}
+      add = f32[32,32,32]{2,1,0} add(p0, slice)
+      c0 = f32[] constant(0)
+      ROOT r1 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
+        to_apply=scalar_add_computation
+    }
+    ENTRY reduce {
+      p0 = f32[32,32,32]{2,1,0} parameter(0)
+      p1 = f32[64,64,64]{2,1,0} parameter(1)
+      add = f32[32,32,32]{2,1,0} fusion(p0, p0), kind=kLoop, calls=fused_add
+      c0 = f32[] constant(0)
+      reduce2 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
+        to_apply=scalar_add_computation
+      reduce = f32[32,32]{1,0} fusion(add, p1), kind=kInput, calls=fused_reduce
+      ROOT root = (f32[32,32,32]{2,1,0}, f32[32,32]{1,0}, f32[32,32]{1,0})
+                  tuple(add, reduce, reduce2)
+    })"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  int multi_output_fusion_count = 0;
+  for (auto* computation : module->MakeNonfusionComputations()) {
+    for (auto* instr : computation->instructions()) {
+      if (instr->IsMultiOutputFusion()) {
+        multi_output_fusion_count++;
+      }
+    }
+  }
+  EXPECT_EQ(1, multi_output_fusion_count);
+}
+
+// Check that we limit the number of operands and outputs of the fusions we
+// create.
+TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) {
+  constexpr int64_t kNumParams = 200;
+  ASSERT_GT(kNumParams, MaxOperandsAndOutputsPerFusion());
+
+  // Compute
+  //   p0 * p1,
+  //   p0 * p1 + p1 * p2
+  //   p0 * p1 + p1 * p2 + p2 * p3
+  //   ...
+  // where each of the (pi * pj)'s is represented as a fusion node so that
+  // multi-output fusion will pay attention to it.
+  auto module = CreateNewVerifiedModule();
+  HloComputation::Builder b(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
+
+  std::vector<HloInstruction*> params;
+  for (int64_t i = 0; i < kNumParams; ++i) {
+    params.push_back(
+        b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")));
+  }
+
+  // Creates a fusion node that calculates x*y.
+  auto make_fusion = [&](HloInstruction* x, HloInstruction* y) {
+    HloComputation::Builder sub_builder("subcomp");
+    auto* p0 = sub_builder.AddInstruction(
+        HloInstruction::CreateParameter(0, shape, "p"));
+    auto* p1 = sub_builder.AddInstruction(
+        HloInstruction::CreateParameter(1, shape, "p"));
+    sub_builder.AddInstruction(
+        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
+    HloComputation* subcomp =
+        module->AddEmbeddedComputation(sub_builder.Build());
+    return HloInstruction::CreateFusion(
+        shape, HloInstruction::FusionKind::kLoop, {x, y}, subcomp);
+  };
+
+  auto* sum = b.AddInstruction(make_fusion(params[0], params[1]));
+  for (int64_t i = 2; i < kNumParams; ++i) {
+    sum = b.AddInstruction(HloInstruction::CreateBinary(
+        shape, HloOpcode::kAdd, sum,
+        b.AddInstruction(make_fusion(params[i - 1], params[i]))));
+  }
+  auto computation = module->AddEntryComputation(b.Build());
+  EXPECT_TRUE(mof_.Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  for (const HloInstruction* instr : computation->instructions()) {
+    EXPECT_LE(instr->operand_count() + ShapeUtil::SubshapeCount(instr->shape()),
+              MaxOperandsAndOutputsPerFusion())
+        << instr->ToString();
+  }
+}
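+
+// Note on the assertion above (a reading of the check, not of the pass's cost
+// model): for every instruction we require
+//   operand_count + ShapeUtil::SubshapeCount(shape) <= MaxOperandsAndOutputsPerFusion()
+// For example, assuming SubshapeCount counts a tuple shape plus its leaves, a
+// multi-output fusion with 3 operands and a 4-element output tuple would
+// contribute 3 + 5 = 8 towards the limit.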
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) {
+  auto module = ParseAndReturnVerifiedModule(R"(HloModule dus_mof
+    fusion.1 {
+      p.0 = f16[50,96,1024]{2,1,0} parameter(0)
+      p.1 = f16[1,96,1024]{2,1,0} parameter(1)
+      c.0 = s32[3]{0} constant({0, 0, 0})
+      ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
+    }
+
+    fusion.2 {
+      p.0 = f16[50,96,1024]{2,1,0} parameter(0)
+      p.1 = f16[1,96,1024]{2,1,0} parameter(1)
+      c.0 = s32[3]{0} constant({0, 0, 0})
+      ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
+    }
+
+    ENTRY entry {
+      p.00 = f16[50,96,1024]{2,1,0} parameter(0)
+      p.01 = f16[50,96,1024]{2,1,0} parameter(1)
+      p.1 = f16[1,96,1024]{2,1,0} parameter(2)
+
+      f1 = f16[50,96,1024] fusion(p.00, p.1), kind=kLoop, calls=fusion.1
+      f2 = f16[50,96,1024] fusion(p.01, p.1), kind=kLoop, calls=fusion.2
+      ROOT tuple = (f16[50,96,1024],f16[50,96,1024]) tuple(f1, f2)
+    })")
+                    .value();
+  ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+// Check that we don't fuse too many reductions together.
+TEST_F(MultiOutputFusionTest, SharedMemoryBudget) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation0 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+        to_apply=scalar_add_computation
+    }
+    fused_computation1 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+        to_apply=scalar_add_computation
+    }
+    fused_computation2 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+        to_apply=scalar_add_computation
+    }
+    fused_computation3 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+        to_apply=scalar_add_computation
+    }
+    fused_computation4 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+        to_apply=scalar_add_computation
+    }
+    fused_computation5 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+        to_apply=scalar_add_computation
+    }
+    fused_computation6 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+        to_apply=scalar_add_computation
+    }
+    fused_computation7 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+        to_apply=scalar_add_computation
+    }
+    fused_computation8 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+        to_apply=scalar_add_computation
+    }
+    fused_computation9 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
+        to_apply=scalar_add_computation
+    }
+    ENTRY computation {
+      zero = f32[] constant(0)
+      param0 = f32[64,64] parameter(0)
+      param1 = f32[64,64] parameter(1)
+      param2 = f32[64,64] parameter(2)
+      param3 = f32[64,64] parameter(3)
+      param4 = f32[64,64] parameter(4)
+      param5 = f32[64,64] parameter(5)
+      param6 = f32[64,64] parameter(6)
+      param7 = f32[64,64] parameter(7)
+      param8 = f32[64,64] parameter(8)
+      param9 = f32[64,64] parameter(9)
+      out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0
+      out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1
+      out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2
+      out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3
+      out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4
+      out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5
+      out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6
+      out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7
+      out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8
+      out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9
+      ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9)
+    }
+  )"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+
+  EXPECT_EQ(5, CountMultiOutputFusions(module.get()));
+}
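+
+// Note: the ten kInput reduce fusions above can only end up in five
+// multi-output fusions if each multi-output fusion absorbs exactly two of
+// them, i.e. the shared-memory budget keeps the pass from grouping more than
+// a pair of these reductions together.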
+
+TEST_F(MultiOutputFusionTest, DoNotGroupTooManyReductions) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    fused_computation0 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+        to_apply=scalar_add_computation
+    }
+    fused_computation1 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+        to_apply=scalar_add_computation
+    }
+    fused_computation2 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+        to_apply=scalar_add_computation
+    }
+    fused_computation3 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+        to_apply=scalar_add_computation
+    }
+    fused_computation4 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+        to_apply=scalar_add_computation
+    }
+    fused_computation5 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+        to_apply=scalar_add_computation
+    }
+    fused_computation6 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+        to_apply=scalar_add_computation
+    }
+    fused_computation7 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+        to_apply=scalar_add_computation
+    }
+    fused_computation8 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+        to_apply=scalar_add_computation
+    }
+    fused_computation9 {
+      p0 = f32[64,64] parameter(0)
+      p1 = f32[64,64] parameter(1)
+      p2 = f32[] parameter(2)
+      add = f32[64,64] add(p0, p1)
+      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
+        to_apply=scalar_add_computation
+    }
+    ENTRY computation {
+      zero = f32[] constant(0)
+      param0 = f32[64,64] parameter(0)
+      param1 = f32[64,64] parameter(1)
+      param2 = f32[64,64] parameter(2)
+      param3 = f32[64,64] parameter(3)
+      param4 = f32[64,64] parameter(4)
+      param5 = f32[64,64] parameter(5)
+      param6 = f32[64,64] parameter(6)
+      param7 = f32[64,64] parameter(7)
+      param8 = f32[64,64] parameter(8)
+      param9 = f32[64,64] parameter(9)
+      out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0
+      out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1
+      out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2
+      out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3
+      out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4
+      out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5
+      out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6
+      out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7
+      out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8
+      out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9
+      ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9)
+    }
+  )"))
+                    .value();
+  ASSERT_TRUE(mof_.Run(module.get()).value());
+
+  EXPECT_EQ(2, CountMultiOutputFusions(module.get()));
+}
+
+TEST_F(MultiOutputFusionTest, NoFusionToAvoidUsingTooMuchSharedMemory) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule xla_computation_update_step.10931
+
+%scalar_add_computation.1 (scalar_lhs.1: f64[], scalar_rhs.1: f64[]) -> f64[] {
+  %scalar_lhs.1 = f64[] parameter(0)
+  %scalar_rhs.1 = f64[] parameter(1)
+  ROOT %add.1257 = f64[] add(f64[] %scalar_lhs.1, f64[] %scalar_rhs.1)
+}
+
+%fused_computation.1 (param_0.8: f64[64,64], param_1.11: f64[64,64], param_2.9: f64[64,64]) -> (f64[64], f64[64]) {
+  %param_0.8 = f64[64,64]{1,0} parameter(0)
+  %param_1.11 = f64[64,64]{1,0} parameter(1)
+  %multiply.2 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.8, f64[64,64]{1,0} %param_1.11)
+  %constant_5217.3 = f64[] constant(0)
+  %broadcast.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.3), dimensions={}
+  %multiply.0 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.2, f64[64,64]{1,0} %broadcast.1)
+  %reduce.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.0, f64[] %constant_5217.3), dimensions={0}, to_apply=%scalar_add_computation.1
+  %param_2.9 = f64[64,64]{1,0} parameter(2)
+  %multiply.1514.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_2.9, f64[64,64]{1,0} %param_1.11)
+  %constant_5217.1.clone.1 = f64[] constant(0)
+  %broadcast.0.clone.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.1.clone.1), dimensions={}
+  %multiply.1341.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.1514.clone.0.clone.1, f64[64,64]{1,0} %broadcast.0.clone.1)
+  %reduce.630.clone.0.clone.1 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1341.clone.0.clone.1, f64[] %constant_5217.1.clone.1), dimensions={0}, to_apply=%scalar_add_computation.1
+  ROOT %tuple = (f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %reduce.0, f64[64]{0} %reduce.630.clone.0.clone.1)
+}
+
+%primitive_computation_add__1.6426 (parameter.6427: f64[], parameter.6428: f64[]) -> f64[] {
+  %parameter.6427 = f64[] parameter(0)
+  %parameter.6428 = f64[] parameter(1)
+  ROOT %add.6429 = f64[] add(f64[] %parameter.6427, f64[] %parameter.6428)
+}
+
+%fused_computation.2 (param_0.7: f64[64,64], param_1.9: f64[64,64]) -> f64[64] {
+  %param_0.7 = f64[64,64]{1,0} parameter(0)
+  %param_1.9 = f64[64,64]{1,0} parameter(1)
+  %multiply.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.7, f64[64,64]{1,0} %param_1.9)
+  %constant_5217.2 = f64[] constant(0)
+  ROOT %reduce.740.clone.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1, f64[] %constant_5217.2), dimensions={0}, to_apply=%primitive_computation_add__1.6426
+}
+
+ENTRY %reproducer (param_0.1090: f64[64,64], param_1.1377: f64[64,64], param_2.1948: f64[64,64]) -> (f64[64], f64[64], f64[64]) {
+  %param_0.1090 = f64[64,64]{1,0} parameter(0)
+  %param_1.1377 = f64[64,64]{1,0} parameter(1)
+  %param_2.1948 = f64[64,64]{1,0} parameter(2)
+  %fusion.1 = (f64[64]{0}, f64[64]{0}) fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377, f64[64,64]{1,0} %param_2.1948), kind=kInput, calls=%fused_computation.1
+  %get-tuple-element = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=0
+  %fusion.2 = f64[64]{0} fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377), kind=kInput, calls=%fused_computation.2
+  %get-tuple-element.1 = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=1
+  ROOT %tuple.428 = (f64[64]{0}, f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %get-tuple-element, f64[64]{0} %fusion.2, f64[64]{0} %get-tuple-element.1)
+}
+  )")
+                    .value();
+  EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, NoFusionToAvoidCodeDuplication) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule module
+
+and.reduce_sub_computation {
+  x = pred[] parameter(0)
+  y = pred[] parameter(1)
+  ROOT and = pred[] and(x, y)
+}
+
+fused_computation.1 {
+  param_4.658 = f32[2,20,256]{2,0,1} parameter(4)
+  slice.1385 = f32[2,1,256]{2,0,1} slice(param_4.658), slice={[0:2], [11:12], [0:256]}
+  constant.6847 = s32[] constant(0)
+  broadcast.4823 = s32[3]{0} broadcast(constant.6847), dimensions={}
+  param_9.415 = s32[3]{0} parameter(9)
+  compare.700 = pred[3]{0} compare(broadcast.4823, param_9.415), direction=LE
+  constant.6846 = pred[] constant(true)
+  reduce.221 = pred[] reduce(compare.700, constant.6846), dimensions={0}, to_apply=and.reduce_sub_computation
+  broadcast.2933 = pred[2,1,256]{2,0,1} broadcast(reduce.221), dimensions={}
+  param_5.528 = f32[2,512]{1,0} parameter(5)
+  slice.1384 = f32[2,256]{1,0} slice(param_5.528), slice={[0:2], [0:256]}
+  bitcast.341 = f32[2,1,256]{2,0,1} bitcast(slice.1384)
+  constant.5418 = f32[] constant(0)
+  broadcast.3227 = f32[2,1,256]{2,0,1} broadcast(constant.5418), dimensions={}
+  select.173 = f32[2,1,256]{2,0,1} select(broadcast.2933, bitcast.341, broadcast.3227)
+  add.573 = f32[2,1,256]{2,0,1} add(slice.1385, select.173)
+  param_0.299 = s32[] parameter(0)
+  constant.5157 = s32[] constant(11)
+  dynamic-update-slice.189 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.658, add.573, param_0.299, constant.5157, param_0.299)
+  slice.1383 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.189), slice={[0:2], [10:11], [0:256]}
+  constant.6800 = s32[] constant(0)
+  broadcast.4803 = s32[3]{0} broadcast(constant.6800), dimensions={}
+  param_8.484 = s32[3]{0} parameter(8)
+  compare.681 = pred[3]{0} compare(broadcast.4803, param_8.484), direction=LE
+  constant.6798 = pred[] constant(true)
+  reduce.203 = pred[] reduce(compare.681, constant.6798), dimensions={0}, to_apply=and.reduce_sub_computation
+  broadcast.2932 = pred[2,1,256]{2,0,1} broadcast(reduce.203), dimensions={}
+  param_3.1169 = f32[2,512]{1,0} parameter(3)
+  slice.1382 = f32[2,256]{1,0} slice(param_3.1169), slice={[0:2], [0:256]}
+  bitcast.340 = f32[2,1,256]{2,0,1} bitcast(slice.1382)
+  select.172 = f32[2,1,256]{2,0,1} select(broadcast.2932, bitcast.340, broadcast.3227)
+  add.572 = f32[2,1,256]{2,0,1} add(slice.1383, select.172)
+  constant.5154 = s32[] constant(10)
+  dynamic-update-slice.188 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.189, add.572, param_0.299, constant.5154, param_0.299)
+  slice.1381 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.188), slice={[0:2], [9:10], [0:256]}
+  constant.6794 = s32[] constant(0)
+  broadcast.4801 = s32[3]{0} broadcast(constant.6794), dimensions={}
+  param_7.478 = s32[3]{0} parameter(7)
+  compare.679 = pred[3]{0} compare(broadcast.4801, param_7.478), direction=LE
+  constant.6793 = pred[] constant(true)
+  reduce.201 = pred[] reduce(compare.679, constant.6793), dimensions={0}, to_apply=and.reduce_sub_computation
+  broadcast.2930 = pred[2,1,256]{2,0,1} broadcast(reduce.201), dimensions={}
+  param_2.1685 = f32[2,512]{1,0} parameter(2)
+  slice.1380 = f32[2,256]{1,0} slice(param_2.1685), slice={[0:2], [0:256]}
+  bitcast.339 = f32[2,1,256]{2,0,1} bitcast(slice.1380)
+  select.171 = f32[2,1,256]{2,0,1} select(broadcast.2930, bitcast.339, broadcast.3227)
+  add.571 = f32[2,1,256]{2,0,1} add(slice.1381, select.171)
+  constant.5153 = s32[] constant(9)
+  dynamic-update-slice.187 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.188, add.571, param_0.299, constant.5153, param_0.299)
+  slice.1379 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.187), slice={[0:2], [8:9], [0:256]}
+  constant.6788 = s32[] constant(0)
+  broadcast.4799 = s32[3]{0} broadcast(constant.6788), dimensions={}
+  param_6.495 = s32[3]{0} parameter(6)
+  compare.677 = pred[3]{0} compare(broadcast.4799, param_6.495), direction=LE
+  constant.6786 = pred[] constant(true)
+  reduce.199 = pred[] reduce(compare.677, constant.6786), dimensions={0}, to_apply=and.reduce_sub_computation
+  broadcast.2929 = pred[2,1,256]{2,0,1} broadcast(reduce.199), dimensions={}
+  param_1.1408 = f32[2,512]{1,0} parameter(1)
+  slice.1378 = f32[2,256]{1,0} slice(param_1.1408), slice={[0:2], [0:256]}
+  bitcast.338 = f32[2,1,256]{2,0,1} bitcast(slice.1378)
+  select.170 = f32[2,1,256]{2,0,1} select(broadcast.2929, bitcast.338, broadcast.3227)
+  add.570 = f32[2,1,256]{2,0,1} add(slice.1379, select.170)
+  constant.5152 = s32[] constant(8)
+  ROOT dynamic-update-slice.186 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.187, add.570, param_0.299, constant.5152, param_0.299)
+}
+
+fused_computation.2 {
+  param_4.655 = f32[2,20,256]{2,0,1} parameter(4)
+  slice.1369 = f32[2,1,256]{2,0,1} slice(param_4.655), slice={[0:2], [7:8], [0:256]}
+  param_6.483 = pred[] parameter(6)
+  broadcast.2927 = pred[2,1,256]{2,0,1} broadcast(param_6.483), dimensions={}
+  param_5.525 = f32[2,512]{1,0} parameter(5)
+  slice.1368 = f32[2,256]{1,0} slice(param_5.525), slice={[0:2], [0:256]}
+  bitcast.333 = f32[2,1,256]{2,0,1} bitcast(slice.1368)
+  constant.5415 = f32[] constant(0)
+  broadcast.3225 = f32[2,1,256]{2,0,1} broadcast(constant.5415), dimensions={}
+  select.161 = f32[2,1,256]{2,0,1} select(broadcast.2927, bitcast.333, broadcast.3225)
+  add.549 = f32[2,1,256]{2,0,1} add(slice.1369, select.161)
+  param_0.265 = s32[] parameter(0)
+  constant.5151 = s32[] constant(7)
+  dynamic-update-slice.185 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.655, add.549, param_0.265, constant.5151, param_0.265)
+  slice.1367 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.185), slice={[0:2], [6:7], [0:256]}
+  constant.6782 = s32[] constant(0)
+  broadcast.4797 = s32[3]{0} broadcast(constant.6782), dimensions={}
+  param_9.391 = s32[3]{0} parameter(9)
+  compare.675 = pred[3]{0} compare(broadcast.4797, param_9.391), direction=LE
+  constant.6781 = pred[] constant(true)
+  reduce.197 = pred[] reduce(compare.675, constant.6781), dimensions={0}, to_apply=and.reduce_sub_computation
+  broadcast.2926 = pred[2,1,256]{2,0,1} broadcast(reduce.197), dimensions={}
+  param_3.1167 = f32[2,512]{1,0} parameter(3)
+  slice.1366 = f32[2,256]{1,0} slice(param_3.1167), slice={[0:2], [0:256]}
+  bitcast.332 = f32[2,1,256]{2,0,1} bitcast(slice.1366)
+  select.160 = f32[2,1,256]{2,0,1} select(broadcast.2926, bitcast.332, broadcast.3225)
+  add.548 = f32[2,1,256]{2,0,1} add(slice.1367, select.160)
+  constant.5150 = s32[] constant(6)
+  dynamic-update-slice.184 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.185, add.548, param_0.265, constant.5150, param_0.265)
+  slice.1365 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.184), slice={[0:2], [5:6], [0:256]}
+  constant.6776 = s32[] constant(0)
+  broadcast.4794 = s32[3]{0} broadcast(constant.6776), dimensions={}
+  param_8.464 = s32[3]{0} parameter(8)
+  compare.673 = pred[3]{0} compare(broadcast.4794, param_8.464), direction=LE
+  constant.6775 = pred[] constant(true)
+  reduce.195 = pred[] reduce(compare.673, constant.6775), dimensions={0}, to_apply=and.reduce_sub_computation
+  broadcast.2925 = pred[2,1,256]{2,0,1} broadcast(reduce.195), dimensions={}
+  param_2.1684 = f32[2,512]{1,0} parameter(2)
+  slice.1364 = f32[2,256]{1,0} slice(param_2.1684), slice={[0:2], [0:256]}
+  bitcast.331 = f32[2,1,256]{2,0,1} bitcast(slice.1364)
+  select.159 = f32[2,1,256]{2,0,1} select(broadcast.2925, bitcast.331, broadcast.3225)
+  add.547 = f32[2,1,256]{2,0,1} add(slice.1365, select.159)
+  constant.5149 = s32[] constant(5)
+  dynamic-update-slice.183 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.184, add.547, param_0.265, constant.5149, param_0.265)
+  slice.1363 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.183), slice={[0:2], [4:5], [0:256]}
+  constant.6770 = s32[] constant(0)
+  broadcast.4792 = s32[3]{0} broadcast(constant.6770), dimensions={}
+  param_7.458 = s32[3]{0} parameter(7)
+  compare.671 = pred[3]{0} compare(broadcast.4792, param_7.458), direction=LE
+  constant.6769 = pred[] constant(true)
+  reduce.193 = pred[] reduce(compare.671, constant.6769), dimensions={0}, to_apply=and.reduce_sub_computation
+  broadcast.2924 = pred[2,1,256]{2,0,1} broadcast(reduce.193), dimensions={}
+  param_1.1405 = f32[2,512]{1,0} parameter(1)
+  slice.1362 = f32[2,256]{1,0} slice(param_1.1405), slice={[0:2], [0:256]}
+  bitcast.330 = f32[2,1,256]{2,0,1} bitcast(slice.1362)
+  select.158 = f32[2,1,256]{2,0,1} select(broadcast.2924, bitcast.330, broadcast.3225)
+  add.546 = f32[2,1,256]{2,0,1} add(slice.1363, select.158)
+  constant.5148 = s32[] constant(4)
+  ROOT dynamic-update-slice.182 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.183, add.546, param_0.265, constant.5148, param_0.265)
+}
+
+ENTRY main {
+  param_0.0 = s32[] parameter(0)
+  param_1.0 = f32[2,512]{1,0} parameter(1)
+  param_2.0 = f32[2,512]{1,0} parameter(2)
+  param_3.0 = f32[2,512]{1,0} parameter(3)
+  param_4.0 = f32[2,20,256]{2,1,0} parameter(4)
+  param_5.0 = f32[2,512]{1,0} parameter(5)
+  param_6.0 = s32[3]{0} parameter(6)
+  param_7.0 = s32[3]{0} parameter(7)
+  param_8.0 = s32[3]{0} parameter(8)
+  param_9.0 = s32[3]{0} parameter(9)
+  fusion.1 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, param_4.0, param_5.0, param_6.0, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.1
+  param_10 = pred[] parameter(10)
+  fusion.2 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, fusion.1, param_5.0, param_10, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.2
+  ROOT root = (f32[2,20,256]{2,0,1}, f32[2,20,256]{2,0,1}) tuple(fusion.1, fusion.2)
+}
+  )")
+                    .value();
+  auto& debug_options = module->mutable_config().mutable_debug_options();
+  debug_options.set_xla_gpu_mlir_emitter_level(3);
+  EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, DoNotFuseRoot) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule module
+
+no_op {
+  arg_empty_tuple = () parameter(0)
+  ROOT tuple = () tuple()
+}
+
+fused_computation {
+  param_0 = f32[] parameter(0)
+  ROOT convert = s32[] convert(param_0)
+}
+
+ENTRY main {
+  param_0 = f32[] parameter(0)
+  fusion = s32[] fusion(param_0), kind=kLoop, calls=fused_computation
+  tuple = () tuple()
+  conditional = () conditional(fusion, tuple, tuple), branch_computations={no_op, no_op}
+  constant = f32[] constant(1)
+  ROOT root = f32[] add(param_0, constant)
+}
+  )")
+                    .value();
+  EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, CostBasedNoMerge) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+region_3.63 {
+  Arg_0.64 = f32[] parameter(0)
+  Arg_1.65 = f32[] parameter(1)
+  ROOT add.66 = f32[] add(Arg_0.64, Arg_1.65)
+}
+
+fused_computation.29 {
+  param_0.161 = f32[5,32,32,1]{3,2,1,0} parameter(0)
+  multiply.208 = f32[5,32,32,1]{3,2,1,0} multiply(param_0.161, param_0.161)
+  bitcast.67 = f32[5,32,32]{2,1,0} bitcast(multiply.208)
+  constant.265 = f32[] constant(0)
+  reduce-window.81 = f32[5,30,31]{2,1,0} reduce-window(bitcast.67, constant.265), window={size=1x3x2}, to_apply=region_3.63
+  constant.264 = f32[] constant(0.166666672)
+  broadcast.204 = f32[5,30,31]{2,1,0} broadcast(constant.264), dimensions={}
+  multiply.205 = f32[5,30,31]{2,1,0} multiply(reduce-window.81, broadcast.204)
+  constant.263 = f32[] constant(0)
+  reduce-window.80 = f32[5,30,31]{2,1,0} reduce-window(multiply.205, constant.263), window={size=1x2x3 pad=0_0x0_1x1_1}, to_apply=region_3.63
+  constant.262 = f32[] constant(0.0138888899)
+  broadcast.201 = f32[5,30,31]{2,1,0} broadcast(constant.262), dimensions={}
+  multiply.204 = f32[5,30,31]{2,1,0} multiply(reduce-window.80, broadcast.201)
+  constant.261 = f32[] constant(0)
+  reduce-window.78 = f32[5,30,31]{2,1,0} reduce-window(multiply.204, constant.261), window={size=1x1x2 pad=0_0x0_0x0_1}, to_apply=region_3.63
+  constant.113 = f32[] constant(0.5)
+  broadcast.137 = f32[5,30,31]{2,1,0} broadcast(constant.113), dimensions={}
+  multiply.125 = f32[5,30,31]{2,1,0} multiply(reduce-window.78, broadcast.137)
+  constant.114 = f32[] constant(0)
+  ROOT reduce-window.17 = f32[5,30,31]{2,1,0} reduce-window(multiply.125, constant.114), window={size=1x2x1 pad=0_0x0_1x0_0}, to_apply=region_3.63
+}
+
+fused_computation.15 {
+  constant.108 = f32[] constant(0.5)
+  broadcast.105 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.108), dimensions={}
+  param_3.126 = f32[5,30,31]{2,1,0} parameter(3)
+  constant.295 = f32[] constant(0.25)
+  broadcast.234 = f32[5,30,31]{2,1,0} broadcast(constant.295), dimensions={}
+  multiply.242 = f32[5,30,31]{2,1,0} multiply(param_3.126, broadcast.234)
+  broadcast.233 = f32[5,5,30,31]{3,2,1,0} broadcast(multiply.242), dimensions={0,2,3}
+  param_2.154 = f32[5,30,31]{2,1,0} parameter(2)
+  multiply.241 = f32[5,30,31]{2,1,0} multiply(param_2.154, broadcast.234)
+  broadcast.232 = f32[5,5,30,31]{3,2,1,0} broadcast(multiply.241), dimensions={1,2,3}
+  multiply.240 = f32[5,5,30,31]{3,2,1,0} multiply(broadcast.233, broadcast.232)
+  param_1.188 = f32[5,5,30,31]{3,2,1,0} parameter(1)
+  constant.294 = f32[] constant(0.159154937)
+  broadcast.231 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.294), dimensions={}
+  multiply.239 = f32[5,5,30,31]{3,2,1,0} multiply(param_1.188, broadcast.231)
+  param_0.164 = f32[5,5,30,31]{3,2,1,0} parameter(0)
+  add.19 = f32[5,5,30,31]{3,2,1,0} add(multiply.239, param_0.164)
+  constant.293 = f32[] constant(0)
+  reduce-window.90 = f32[5,5,30,31]{3,2,1,0} reduce-window(add.19, constant.293), window={size=1x1x1x2 pad=0_0x0_0x0_0x0_1}, to_apply=region_3.63
+  constant.292 = f32[] constant(0.5)
+  broadcast.230 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.292), dimensions={}
+  multiply.238 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.90, broadcast.230)
+  constant.291 = f32[] constant(0)
+  reduce-window.89 = f32[5,5,30,31]{3,2,1,0} reduce-window(multiply.238, constant.291), window={size=1x1x2x1 pad=0_0x0_0x0_1x0_0}, to_apply=region_3.63
+  constant.290 = f32[] constant(0.25)
+  broadcast.229 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.290), dimensions={}
+  multiply.237 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.89, broadcast.229)
+  multiply.236 = f32[5,5,30,31]{3,2,1,0} multiply(multiply.237, multiply.237)
+  subtract.10 = f32[5,5,30,31]{3,2,1,0} subtract(multiply.240, multiply.236)
+  constant.289 = f32[] constant(0)
+  broadcast.228 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.289), dimensions={}
+  maximum.6 = f32[5,5,30,31]{3,2,1,0} maximum(subtract.10, broadcast.228)
+  sqrt.6 = f32[5,5,30,31]{3,2,1,0} sqrt(maximum.6)
+  constant.110 = f32[] constant(0)
+  broadcast.107 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.110), dimensions={}
+  compare.4 = pred[5,5,30,31]{3,2,1,0} compare(sqrt.6, broadcast.107), direction=EQ
+  constant.243 = f32[] constant(0.159154937)
+  broadcast.193 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.243), dimensions={}
+  multiply.194 = f32[5,5,30,31]{3,2,1,0} multiply(param_1.188, broadcast.193)
+  add.15 = f32[5,5,30,31]{3,2,1,0} add(multiply.194, param_0.164)
+  constant.242 = f32[] constant(0)
+  reduce-window.66 = f32[5,5,30,31]{3,2,1,0} reduce-window(add.15, constant.242), window={size=1x1x1x2 pad=0_0x0_0x0_0x0_1}, to_apply=region_3.63
+  constant.241 = f32[] constant(0.5)
+  broadcast.192 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.241), dimensions={}
+  multiply.193 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.66, broadcast.192)
+  constant.240 = f32[] constant(0)
+  reduce-window.65 = f32[5,5,30,31]{3,2,1,0} reduce-window(multiply.193, constant.240), window={size=1x1x2x1 pad=0_0x0_0x0_1x0_0}, to_apply=region_3.63
+  constant.239 = f32[] constant(0.25)
+  broadcast.191 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.239), dimensions={}
+  multiply.192 = f32[5,5,30,31]{3,2,1,0} multiply(reduce-window.65, broadcast.191)
+  compare.3 = pred[5,5,30,31]{3,2,1,0} compare(multiply.192, broadcast.107), direction=EQ
+  and.1 = pred[5,5,30,31]{3,2,1,0} and(compare.4, compare.3)
+  constant.109 = f32[] constant(1.57079637)
+  broadcast.104 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.109), dimensions={}
+  atan2.1 = f32[5,5,30,31]{3,2,1,0} atan2(sqrt.6, multiply.192)
+  select.4 = f32[5,5,30,31]{3,2,1,0} select(and.1, broadcast.104, atan2.1)
+  constant.107 = f32[] constant(0.159154937)
+  broadcast.106 = f32[5,5,30,31]{3,2,1,0} broadcast(constant.107), dimensions={}
+  multiply.100 = f32[5,5,30,31]{3,2,1,0} multiply(select.4, broadcast.106)
+  ROOT subtract.3 = f32[5,5,30,31]{3,2,1,0} subtract(broadcast.105, multiply.100)
+}
+
+fused_computation.4 {
+  param_0.172 = f32[5,30,31]{2,1,0} parameter(0)
+  constant.315 = f32[] constant(0.125)
+  broadcast.242 = f32[5,30,31]{2,1,0} broadcast(constant.315), dimensions={}
+  multiply.250 = f32[5,30,31]{2,1,0} multiply(param_0.172, broadcast.242)
+  constant.314 = f32[] constant(0)
+  reduce-window.100 = f32[5,30,31]{2,1,0} reduce-window(multiply.250, constant.314), window={size=1x3x3 pad=0_0x1_1x1_1}, to_apply=region_3.63
+  constant.79 = f32[] constant(0.055555556)
+  broadcast.85 = f32[5,30,31]{2,1,0} broadcast(constant.79), dimensions={}
+  multiply.80 = f32[5,30,31]{2,1,0} multiply(reduce-window.100, broadcast.85)
+  constant.81 = f32[] constant(0)
+  reduce-window.1 = f32[5,30,31]{2,1,0} reduce-window(multiply.80, constant.81), window={size=1x3x3 pad=0_0x1_1x1_1}, to_apply=region_3.63
+  constant.80 = f32[] constant(0.111111112)
+  broadcast.86 = f32[5,30,31]{2,1,0} broadcast(constant.80), dimensions={}
+  multiply.79 = f32[5,30,31]{2,1,0} multiply(reduce-window.1, broadcast.86)
+  bitcast.26 = f32[5,930]{1,0} bitcast(multiply.79)
+  ROOT reduce.8 = f32[5]{0} reduce(bitcast.26, constant.81), dimensions={1}, to_apply=region_3.63
+}
+
+ENTRY e {
+  Arg_0.1 = f32[5,32,32,1]{3,2,1,0} parameter(0)
+  p1 = f32[5,5,30,31]{3,2,1,0} parameter(1)
+  p2 = f32[5,5,30,31]{3,2,1,0} parameter(2)
+  p3 = f32[5,30,31]{2,1,0} parameter(3)
+  fusion.29 = f32[5,30,31]{2,1,0} fusion(Arg_0.1), kind=kLoop, calls=fused_computation.29
+  fusion.15 = f32[5,5,30,31]{3,2,1,0} fusion(p2, p1, p3, fusion.29), kind=kLoop, calls=fused_computation.15
+  ROOT fusion.4 = f32[5]{0} fusion(fusion.29), kind=kInput, calls=fused_computation.4
+})")
+                    .value();
+  EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, NoOverlappingRead) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fused_computation_1 {
+      p0.1 = f32[100,200]{1,0} parameter(0)
+      slice.0 = f32[50,100]{1,0} slice(p0.1), slice={[0:50],[0:100]}
+      mul = f32[50,100]{1,0} multiply(slice.0, slice.0)
+      exp = f32[50,100]{1,0} exponential(slice.0)
+      ROOT tuple = (f32[50,100]{1,0}, f32[50,100]{1,0}) tuple(mul, exp)
+    }
+
+    fused_computation_2 {
+      p0.2 = f32[100,200]{1,0} parameter(0)
+      slice.1 = f32[50,100]{1,0} slice(p0.2), slice={[0:50],[100:200]}
+      const.2 = f32[] constant(0)
+      broadcast = f32[50,100]{1,0} broadcast(const.2), dimensions={}
+      ROOT add = f32[50,100]{1,0} add(slice.1, broadcast)
+    }
+
+    ENTRY entry {
+      p0 = f32[100,200]{1,0} parameter(0)
+      fusion.1 = (f32[50,100]{1,0}, f32[50,100]{1,0}) fusion(p0), kind=kLoop,
+        calls=fused_computation_1
+      gte0 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=0
+      gte1 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=1
+      fusion.2 = f32[50,100]{1,0} fusion(p0), kind=kLoop,
+        calls=fused_computation_2
+      ROOT root = (f32[50,100]{1,0}, f32[50,100]{1,0}, f32[50,100]{1,0})
+        tuple(gte0, gte1, fusion.2)
+    })")
+                    .value();
+
+  EXPECT_FALSE(mof_.Run(module.get()).value());
+}
+
+TEST_F(MultiOutputFusionTest, OverlappingRead) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fused_computation_1 {
+      p0.1 = f32[100,200]{1,0} parameter(0)
+      slice.0 = f32[50,100]{1,0} slice(p0.1), slice={[0:50],[50:150]}
+      mul = f32[50,100]{1,0} multiply(slice.0, slice.0)
+      exp = f32[50,100]{1,0} exponential(slice.0)
+      ROOT tuple = (f32[50,100]{1,0}, f32[50,100]{1,0}) tuple(mul, exp)
+    }
+
+    fused_computation_2 {
+      p0.2 = f32[100,200]{1,0} parameter(0)
+      slice.1 = f32[50,100]{1,0} slice(p0.2), slice={[30:80],[20:120]}
+      const.2 = f32[] constant(0)
+      broadcast = f32[50,100]{1,0} broadcast(const.2), dimensions={}
+      ROOT add = f32[50,100]{1,0} add(slice.1, broadcast)
+    }
+
+    ENTRY entry {
+      p0 = f32[100,200]{1,0} parameter(0)
+      fusion.1 = (f32[50,100]{1,0}, f32[50,100]{1,0}) fusion(p0), kind=kLoop,
+        calls=fused_computation_1
+      gte0 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=0
+      gte1 = f32[50,100]{1,0} get-tuple-element(fusion.1), index=1
+      fusion.2 = f32[50,100]{1,0} fusion(p0), kind=kLoop,
+        calls=fused_computation_2
+      ROOT root = (f32[50,100]{1,0}, f32[50,100]{1,0}, f32[50,100]{1,0})
+        tuple(gte0, gte1, fusion.2)
+    })")
+                    .value();
+
+  EXPECT_TRUE(mof_.Run(module.get()).value());
+}
+
+class TransposeMultiOutputFusionTest : public MultiOutputFusionTest {};
+
+TEST_F(TransposeMultiOutputFusionTest, MultipleCopies) {
+  const char* hlo = R"(
+HloModule module
+
+fused_computation {
+  param_0.1 = f32[16,32]{1,0} parameter(0)
+  s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+  ROOT c.1 = f32[16,32]{0,1} copy(s.1)
+}
+
+ENTRY main {
+  p = f32[16,32]{1,0} parameter(0)
+  fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
+  c1 = f32[16,32]{0,1} copy(p)
+  ROOT t = (f32[16,32]{0,1}, f32[16,32]{0,1}) tuple(fusion, c1)
+}
+  )";
+
+  CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32], f32[16,32]) {
+// CHECK-NEXT:   [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
+// CHECK-NEXT:   [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
+// CHECK-NEXT:   [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
+// CHECK-NEXT:   [[c1_1_3:%[^ ]+]] = f32[16,32]{0,1} copy([[param_0_1_0]])
+// CHECK-NEXT:   ROOT [[tuple_4:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{0,1}) tuple([[c_1_2]], [[c1_1_3]])
+// CHECK-NEXT: }
+
+// CHECK: [[fusion_0:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{0,1}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
+)");
+}
+
+TEST_F(TransposeMultiOutputFusionTest, MultipleTransposes) {
+  const char* hlo = R"(
+HloModule module
+
+fused_computation {
+  param_0.1 = f32[16,32]{1,0} parameter(0)
+  s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+  ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
+}
+
+ENTRY main {
+  p = f32[16,32]{1,0} parameter(0)
+  fusion = f32[32,16]{1,0} fusion(p), kind=kInput, calls=fused_computation
+  c1 = f32[32,16]{1,0} transpose(p), dimensions={1,0}
+  ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(fusion, c1)
+}
+  )";
+
+  CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[32,16], f32[32,16]) {
+// CHECK-NEXT:   [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
+// CHECK-NEXT:   [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
+// CHECK-NEXT:   [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]), dimensions={1,0}
+// CHECK-NEXT:   [[c1_1_3:%[^ ]+]] = f32[32,16]{1,0} transpose([[param_0_1_0]]), dimensions={1,0}
+// CHECK-NEXT:   ROOT [[tuple_4:%[^ ]+]] = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple([[c_1_2]], [[c1_1_3]])
+// CHECK-NEXT: }
+// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16]{1,0}, f32[32,16]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
+)");
+}
+
+TEST_F(TransposeMultiOutputFusionTest, CopyAndTranspose) {
+  const char* hlo = R"(
+HloModule module
+
+fused_computation {
+  param_0.1 = f32[16,32]{1,0} parameter(0)
+  s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+  ROOT c.1 = f32[16,32]{0,1} copy(s.1)
+}
+
+ENTRY main {
+  p = f32[16,32]{1,0} parameter(0)
+  fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
+  c1 = f32[32,16]{1,0} transpose(p), dimensions={1,0}
+  ROOT t = (f32[16,32]{0,1}, f32[32,16]{1,0}) tuple(fusion, c1)
+}
+  )";
+
+  CheckMultiOutputFusion(hlo, R"(
+  // CHECK: %fused_computation ({{[^ ]+}} f32[16,32]) -> (f32[16,32], f32[32,16]) {
+  // CHECK-NEXT: [[param_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
+  // CHECK-NEXT: [[s_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0]])
+  // CHECK-NEXT: [[copy:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1]])
+  // CHECK-NEXT: [[transpose:[^ ]+]] = f32[32,16]{1,0} transpose([[param_0]]), dimensions={1,0}
+  // CHECK-NEXT: ROOT {{[^ ]+}} = (f32[16,32]{0,1}, f32[32,16]{1,0}) tuple([[copy]], [[transpose]])
+  // CHECK: %fusion = (f32[16,32]{0,1}, f32[32,16]{1,0}) fusion(%{{.*}}), kind=kInput, calls=%fused_computation
+)");
+}
+
+TEST_F(TransposeMultiOutputFusionTest, MultipleCopiesDifferentTypes) {
+  const char* hlo = R"(
+HloModule module
+
+fused_computation {
+  param_0.1 = f16[16,32]{1,0} parameter(0)
+  s.1 = f32[16,32]{1,0} convert(param_0.1)
+  ROOT c.1 = f32[16,32]{0,1} copy(s.1)
+}
+
+ENTRY main {
+  p = f16[16,32]{1,0} parameter(0)
+  fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
+  c1 = f16[16,32]{0,1} copy(p)
+  ROOT t = (f32[16,32]{0,1}, f16[16,32]{0,1}) tuple(fusion, c1)
+}
+  )";
+
+  CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation (param_0.1: f16[16,32]) -> (f32[16,32], f16[16,32]) {
+// CHECK-NEXT:   [[param_0_1_0:%[^ ]+]] = f16[16,32]{1,0} parameter(0)
+// CHECK-NEXT:   [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} convert([[param_0_1_0]])
+// CHECK-NEXT:   [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
+// CHECK-NEXT:   [[c1_1_3:%[^ ]+]] = f16[16,32]{0,1} copy([[param_0_1_0]])
+// CHECK-NEXT:   ROOT [[tuple_4:%[^ ]+]] = (f32[16,32]{0,1}, f16[16,32]{0,1}) tuple([[c_1_2]], [[c1_1_3]])
+// CHECK:   [[fusion_5:%[^ ]+]] = (f32[16,32]{0,1}, f16[16,32]{0,1}) fusion([[p_6:%[^ ]+]]), kind=kInput, calls=[[fused_computation_7:%[^ ]+]]
+)");
+}
+
+// Do not group a copy with a reduction.
+TEST_F(TransposeMultiOutputFusionTest, TiledReduceCopy) {
+  const char* hlo = R"(
+HloModule module
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = add(lhs, rhs)
+}
+
+fused_computation {
+  param_0.1 = f32[16,32]{1,0} parameter(0)
+  s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+  ROOT c.1 = f32[16,32]{0,1} copy(s.1)
+}
+
+ENTRY main {
+  p = f32[16,32]{1,0} parameter(0)
+  fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
+  z = f32[] constant(0)
+  r1 = f32[32]{0} reduce(p, z), dimensions={0}, to_apply=add
+  ROOT t = (f32[16,32]{0,1}, f32[32]{0}) tuple(fusion, r1)
+}
+  )";
+
+  CheckMultiOutputFusion(hlo, std::nullopt);
+}
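+
+// Note: here and in the two tests that follow, CheckMultiOutputFusion is
+// called with std::nullopt instead of a FileCheck pattern; by the usual
+// convention of RunAndFilecheckHloRewrite-style helpers this asserts that the
+// pass performs no rewrite on the module.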
+
+// Do not group incompatible transposes.
+TEST_F(TransposeMultiOutputFusionTest, IncompatibleTransposes) {
+  const char* hlo = R"(
+HloModule module
+
+fused_computation {
+  param_0.1 = f32[18,16,32]{2,1,0} parameter(0)
+  param_1.1 = f32[32,16,18]{2,1,0} parameter(1)
+  s.1 = f32[18,16,32]{2,1,0} sqrt(param_0.1)
+  t.1 = f32[32,16,18]{2,1,0} transpose(s.1), dimensions={2,1,0}
+  sub.1 = f32[32,16,18]{2,1,0} subtract(t.1, param_1.1)
+  exp.1 = f32[32,16,18]{2,1,0} exponential(sub.1)
+  ROOT add.1 = f32[32,16,18]{2,1,0} add(exp.1, exp.1)
+}
+
+fused_computation.2 {
+  param_0.2 = f32[18,16,32]{2,1,0} parameter(0)
+  s.2 = f32[18,16,32]{2,1,0} sqrt(param_0.2)
+  ROOT t.2 = f32[18,32,16]{2,1,0} transpose(s.2), dimensions={0,2,1}
+}
+
+ENTRY main {
+  p = f32[18,16,32]{2,1,0} parameter(0)
+  p2 = f32[32,16,18]{2,1,0} parameter(1)
+  fusion = f32[32,16,18]{2,1,0} fusion(p, p2), kind=kLoop, calls=fused_computation
+  fusion2 = f32[18,32,16]{2,1,0} fusion(p), kind=kInput, calls=fused_computation.2
+  ROOT t = (f32[32,16,18]{2,1,0}, f32[18,32,16]{2,1,0}) tuple(fusion, fusion2)
+}
+  )";
+
+  CheckMultiOutputFusion(hlo, std::nullopt);
+}
+
+// A variation of the test above, where no CSE was run.
+TEST_F(TransposeMultiOutputFusionTest, TransposesNoCSE) {
+  const char* hlo = R"(
+HloModule module
+
+fused_computation {
+  param_0.1 = f32[18,16,32]{2,1,0} parameter(0)
+  param_1.1 = f32[32,16,18]{2,1,0} parameter(1)
+  s.1 = f32[18,16,32]{2,1,0} sqrt(param_0.1)
+  t.1 = f32[32,16,18]{2,1,0} transpose(s.1), dimensions={2,1,0}
+  sub.1 = f32[32,16,18]{2,1,0} subtract(t.1, param_1.1)
+  exp.1 = f32[32,16,18]{2,1,0} exponential(sub.1)
+  exp.2 = f32[32,16,18]{2,1,0} exponential(sub.1)
+  ROOT add.1 = f32[32,16,18]{2,1,0} add(exp.1, exp.2)
+}
+
+fused_computation.2 {
+  param_0.2 = f32[18,16,32]{2,1,0} parameter(0)
+  s.2 = f32[18,16,32]{2,1,0} sqrt(param_0.2)
+  ROOT t.2 = f32[18,32,16]{2,1,0} transpose(s.2), dimensions={0,2,1}
+}
+
+ENTRY main {
+  p = f32[18,16,32]{2,1,0} parameter(0)
+  p2 = f32[32,16,18]{2,1,0} parameter(1)
+  fusion = f32[32,16,18]{2,1,0} fusion(p, p2), kind=kLoop, calls=fused_computation
+  fusion2 = f32[18,32,16]{2,1,0} fusion(p), kind=kInput, calls=fused_computation.2
+  ROOT t = (f32[32,16,18]{2,1,0}, f32[18,32,16]{2,1,0}) tuple(fusion, fusion2)
+}
+  )";
+
+  CheckMultiOutputFusion(hlo, std::nullopt);
+}
+
+TEST_F(TransposeMultiOutputFusionTest, CopyAndInput) {
+  const char* hlo = R"(
+HloModule module
+
+fused_computation {
+  param_0.1 = f32[16,32]{1,0} parameter(0)
+  s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+  ROOT c.1 = f32[16,32]{0,1} copy(s.1)
+}
+
+ENTRY main {
+  p = f32[16,32]{1,0} parameter(0)
+  fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation
+  c1 = exponential(p)
+  ROOT t = tuple(fusion, c1)
+}
+  )";
+
+  CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32], f32[16,32]) {
+// CHECK-NEXT:   [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
+// CHECK-NEXT:   [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
+// CHECK-NEXT:   [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
+// CHECK-NEXT:   [[c1_1_3:%[^ ]+]] = f32[16,32]{1,0} exponential([[param_0_1_0]])
+// CHECK-NEXT:   ROOT [[tuple_4:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{1,0}) tuple([[c_1_2]], [[c1_1_3]])
+// CHECK-NEXT: }
+// CHECK:   [[fusion_0:%[^ ]+]] = (f32[16,32]{0,1}, f32[16,32]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
+)");
+}
+
+TEST_F(TransposeMultiOutputFusionTest, TransposeAndInputEpilogueFusion) {
+  const char* hlo = R"(
+HloModule module
+
+fused_computation {
+  param_0.1 = f32[16,32]{1,0} parameter(0)
+  s.1 = f32[16,32]{1,0} sqrt(param_0.1)
+  t.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0}
+  ROOT out = f32[32,16,1]{2,1,0} bitcast(t.1)
+}
+
+ENTRY main {
+  p = f32[16,32]{1,0} parameter(0)
+  fusion = f32[32,16,1]{2,1,0} fusion(p), kind=kInput, calls=fused_computation
+  c1 = exponential(p)
+  ROOT t = tuple(fusion, c1)
+}
+  )";
+
+  CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation
+// CHECK-NEXT:   [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
+// CHECK-NEXT:   [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
+// CHECK-NEXT:   [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]])
+// CHECK-NEXT:   [[out_3:%[^ ]+]] = f32[32,16,1]{2,1,0} bitcast([[c_1_2]])
+// CHECK-NEXT:   [[c1_1_4:%[^ ]+]] = f32[16,32]{1,0} exponential([[param_0_1_0]])
+// CHECK-NEXT:   ROOT [[tuple_5:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) tuple([[out_3]], [[c1_1_4]])
+// CHECK-NEXT: }
+// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
+)");
+}
+
+class ReduceMultiOutputFusionTest : public MultiOutputFusionTest {};
+
+TEST_F(ReduceMultiOutputFusionTest, ReduceAndLoop) {
+  const char* hlo = R"(
+HloModule module
+
+add {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT c = f32[] add(a, b)
+}
+
+fused_reduction {
+  p = f32[200] parameter(0)
+  z = f32[] constant(0)
+  e = f32[200] exponential(p)
+  ROOT r = f32[] reduce(e, z), dimensions={0}, to_apply=add
+}
+
+fused_elementwise {
+  p = f32[200] parameter(0)
+  ROOT r = f32[200] sqrt(p)
+}
+
+ENTRY computation {
+  p = f32[200] parameter(0)
+  o1 = f32[200] fusion(p), kind=kLoop, calls=fused_elementwise
+  o2 = f32[] fusion(p), kind=kInput, calls=fused_reduction
+  ROOT out = (f32[200], f32[]) tuple(o1, o2)
+}
+
+)";
+
+  CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_elementwise
+// CHECK-NEXT:  [[p_1_0:%[^ ]+]] = f32[200]{0} parameter(0)
+// CHECK-NEXT:  [[r_1_1:%[^ ]+]] = f32[200]{0} sqrt([[p_1_0]])
+// CHECK-NEXT:  [[e_2:%[^ ]+]].clone.1 = f32[200]{0} exponential([[p_1_0]])
+// CHECK-NEXT:  [[z_3:%[^ ]+]].clone.1 = f32[] constant(0)
+// CHECK-NEXT:  [[r_4:%[^ ]+]].clone.1 = f32[] reduce([[e_2]].clone.1, [[z_3]].clone.1), dimensions={0}, to_apply=[[add_5:%[^ ]+]]
+// CHECK-NEXT:  ROOT [[tuple_6:%[^ ]+]] = (f32[200]{0}, f32[]) tuple([[r_1_1]], [[r_4]].clone.1)
+// CHECK-NEXT:}
+// CHECK: [[o1_0:%[^ ]+]] = (f32[200]{0}, f32[]) fusion([[p_2_1:%[^ ]+]]), kind=kInput, calls=[[fused_elementwise_2:%[^ ]+]]
+  )");
+}
+
+TEST_F(ReduceMultiOutputFusionTest, ReduceAndLoopDifferentShape) {
+  const char* hlo = R"(
+HloModule module
+
+add {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT c = f32[] add(a, b)
+}
+
+fused_reduction {
+  p = f32[10,20] parameter(0)
+  z = f32[] constant(0)
+  e = f32[10,20] exponential(p)
+  b = f32[200] bitcast(e)
+  ROOT r = f32[] reduce(b, z), dimensions={0}, to_apply=add
+}
+
+fused_elementwise {
+  p = f32[10,20] parameter(0)
+  ROOT r = f32[10,20] sqrt(p)
+}
+
+ENTRY computation {
+  p = f32[10,20] parameter(0)
+  o1 = f32[10,20] fusion(p), kind=kLoop, calls=fused_elementwise
+  o2 = f32[] fusion(p), kind=kInput, calls=fused_reduction
+  ROOT out = (f32[10,20], f32[]) tuple(o1, o2)
+}
+)";
+
+  CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_elementwise (p.1: f32[10,20]) -> (f32[10,20], f32[]) {
+// CHECK-NEXT:   [[p_1_0:%[^ ]+]] = f32[10,20]{1,0} parameter(0)
+// CHECK-NEXT:   [[r_1_1:%[^ ]+]] = f32[10,20]{1,0} sqrt([[p_1_0]])
+// CHECK-NEXT:   [[e_2:%[^ ]+]].clone.1 = f32[10,20]{1,0} exponential([[p_1_0]])
+// CHECK-NEXT:   [[b_1_3:%[^ ]+]].clone.1 = f32[200]{0} bitcast([[e_2]].clone.1)
+// CHECK-NEXT:   [[z_4:%[^ ]+]].clone.1 = f32[] constant(0)
+// CHECK-NEXT:   [[r_5:%[^ ]+]].clone.1 = f32[] reduce([[b_1_3]].clone.1, [[z_4]].clone.1), dimensions={0}, to_apply=[[add_6:%[^ ]+]]
+// CHECK-NEXT:   ROOT [[tuple_7:%[^ ]+]] = (f32[10,20]{1,0}, f32[]) tuple([[r_1_1]], [[r_5]].clone.1)
+// CHECK-NEXT: }
+  )");
+}
+
+TEST_F(ReduceMultiOutputFusionTest, ReduceAndLoopDifferentShapeDifferentType) {
+  const char* hlo = R"(
+HloModule module, entry_computation_layout={(f16[100,200]{1,0},f32[],f32[])->(f16[100,200]{1,0}, f32[])}
+
+max {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT c = f32[] maximum(a, b)
+}
+
+fused_computation {
+  one_5 = f32[] constant(1)
+  one_b.5 = f32[100,200]{1,0} broadcast(one_5), dimensions={}
+  param_1.15 = f16[100,200]{1,0} parameter(1)
+  c.6 = f32[100,200]{1,0} convert(param_1.15)
+  param_0.11 = f32[] parameter(0)
+  b.6 = f32[100,200]{1,0} broadcast(param_0.11), dimensions={}
+  d.5 = f32[100,200]{1,0} divide(c.6, b.6)
+  a.6 = f32[100,200]{1,0} add(one_b.5, d.5)
+  bitcast.1 = f32[20000]{0} bitcast(a.6)
+  z_1 = f32[] constant(0)
+  ROOT r.1 = f32[] reduce(bitcast.1, z_1), dimensions={0}, to_apply=max
+}
+
+fused_computation.1 {
+  one_3 = f32[] constant(1)
+  one_b.3 = f32[100,200]{1,0} broadcast(one_3), dimensions={}
+  param_2.7 = f16[100,200]{1,0} parameter(2)
+  c.4 = f32[100,200]{1,0} convert(param_2.7)
+  param_1.10 = f32[] parameter(1)
+  b.4 = f32[100,200]{1,0} broadcast(param_1.10), dimensions={}
+  d.3 = f32[100,200]{1,0} divide(c.4, b.4)
+  a.4 = f32[100,200]{1,0} add(one_b.3, d.3)
+  param_0.8 = f32[] parameter(0)
+  output_scale_broadcast.1 = f32[100,200]{1,0} broadcast(param_0.8), dimensions={}
+  a_scaled.1 = f32[100,200]{1,0} multiply(a.4, output_scale_broadcast.1)
+  ROOT a_scaled_converted.1 = f16[100,200]{1,0} convert(a_scaled.1)
+}
+
+ENTRY computation {
+  output_scale = f32[] parameter(2)
+  input_scale = f32[] parameter(1)
+  p = f16[100,200]{1,0} parameter(0)
+  fusion.1 = f16[100,200]{1,0} fusion(output_scale, input_scale, p), kind=kLoop, calls=fused_computation.1
+  fusion = f32[] fusion(input_scale, p), kind=kInput, calls=fused_computation
+  ROOT out = (f16[100,200]{1,0}, f32[]) tuple(fusion.1, fusion)
+}
+)";
+
+  CheckMultiOutputFusion(hlo, R"(
+// CHECK: %fused_computation.1 (param_0.8: f32[], param_1.10: f32[], param_2.7: f16[100,200]) -> (f16[100,200], f32[]) {
+// CHECK-NEXT:   [[one_3_0:%[^ ]+]] = f32[] constant(1)
+// CHECK-NEXT:   [[one_b_3_1:%[^ ]+]] = f32[100,200]{1,0} broadcast([[one_3_0]]), dimensions={}
+// CHECK-NEXT:   [[param_2_7_2:%[^ ]+]] = f16[100,200]{1,0} parameter(2)
+// CHECK-NEXT:   [[c_4_3:%[^ ]+]] = f32[100,200]{1,0} convert([[param_2_7_2]])
+// CHECK-NEXT:   [[param_1_10_4:%[^ ]+]] = f32[] parameter(1)
+// CHECK-NEXT:   [[b_4_5:%[^ ]+]] = f32[100,200]{1,0} broadcast([[param_1_10_4]]), dimensions={}
+// CHECK-NEXT:   [[d_3_6:%[^ ]+]] = f32[100,200]{1,0} divide([[c_4_3]], [[b_4_5]])
+// CHECK-NEXT:   [[a_4_7:%[^ ]+]] = f32[100,200]{1,0} add([[one_b_3_1]], [[d_3_6]])
+// CHECK-NEXT:   [[param_0_8_8:%[^ ]+]] = f32[] parameter(0)
+// CHECK-NEXT:   [[output_scale_broadcast_1_9:%[^ ]+]] = f32[100,200]{1,0} broadcast([[param_0_8_8]]), dimensions={}
+// CHECK-NEXT:   [[a_scaled_1_10:%[^ ]+]] = f32[100,200]{1,0} multiply([[a_4_7]], [[output_scale_broadcast_1_9]])
+// CHECK-NEXT:   [[a_scaled_converted_1_11:%[^ ]+]] = f16[100,200]{1,0} convert([[a_scaled_1_10]])
+// CHECK-NEXT:   [[one_5_12:%[^ ]+]].clone.1 = f32[] constant(1)
+// CHECK-NEXT:   [[one_b_5_13:%[^ ]+]].clone.1 = f32[100,200]{1,0} broadcast([[one_5_12]].clone.1), dimensions={}
+// CHECK-NEXT:   [[c_6_14:%[^ ]+]].clone.1 = f32[100,200]{1,0} convert([[param_2_7_2]])
+// CHECK-NEXT:   [[b_6_15:%[^ ]+]].clone.1 = f32[100,200]{1,0} broadcast([[param_1_10_4]]), dimensions={}
+// CHECK-NEXT:   [[d_5_16:%[^ ]+]].clone.1 = f32[100,200]{1,0} divide([[c_6_14]].clone.1, [[b_6_15]].clone.1)
+// CHECK-NEXT:   [[a_6_17:%[^ ]+]].clone.1 = f32[100,200]{1,0} add([[one_b_5_13]].clone.1, [[d_5_16]].clone.1)
+// CHECK-NEXT:   [[bitcast_1_18:%[^ ]+]].clone.1 = f32[20000]{0} bitcast([[a_6_17]].clone.1)
+// CHECK-NEXT:   [[z_1_19:%[^ ]+]].clone.1 = f32[] constant(0)
+// CHECK-NEXT:   [[r_1_20:%[^ ]+]].clone.1 = f32[] reduce([[bitcast_1_18]].clone.1, [[z_1_19]].clone.1), dimensions={0}, to_apply=[[max_21:%[^ ]+]]
+// CHECK-NEXT:   ROOT [[tuple_22:%[^ ]+]] = (f16[100,200]{1,0}, f32[]) tuple([[a_scaled_converted_1_11]], [[r_1_20]].clone.1)
+// CHECK-NEXT: }
+  )");
+}
+
+TEST_F(ReduceMultiOutputFusionTest, GetTupleElementMakeTupleSequence) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    fusion {
+      p0 = s32[] parameter(0)
+      p1 = s32[32] parameter(1)
+      custom-call = (bf16[], s32[], u32[]) custom-call(p1), custom_call_target="my_custom_call"
+      get-tuple-element.0 = bf16[] get-tuple-element(custom-call), index=0
+      get-tuple-element.1 = s32[] get-tuple-element(custom-call), index=1
+      bitcast = s32[1] bitcast(get-tuple-element.1)
+      dynamic-update-slice = s32[32] dynamic-update-slice(p1, bitcast, p0)
+      get-tuple-element.2 = u32[] get-tuple-element(custom-call), index=2
+      ROOT tuple.30 = (bf16[], s32[32], u32[]) tuple(get-tuple-element.0, dynamic-update-slice, get-tuple-element.2)
+    }
+
+    ENTRY entry{
+      p0 = s32[] parameter(0)
+      bitcast = s32[32] bitcast(p0)
+      ROOT address_computation.7.0 = (bf16[], s32[32], u32[]) fusion(p0, bitcast), kind=kCustom, calls=fusion
+    }
+  )")
+                    .value();
+
+  ASSERT_FALSE(mof_.Run(module.get()).value());
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc
new file mode 100644
index 0000000..493d167
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc
@@ -0,0 +1,703 @@
+
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h"
+
+#include <cstdint>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+// Maps a computation to a boolean that indicates whether any collective
+// operation is directly or indirectly invoked in the computation.
+using CollectiveInComputation =
+    absl::flat_hash_map<const HloComputation*, bool>;
+
+using InstructionVector = HloInstruction::InstructionVector;
+
+// Records the starting index and the (exclusive) ending index of the block of
+// pipelined SendDone/RecvDone ops. They are indices into the while-loop
+// operands.
+struct PipelinedP2PInfo {
+  int64_t opnd_start;
+  int64_t opnd_end;
+};
+
+// Returns whether the instruction is a collective operation.
+bool IsCollectiveOp(const HloInstruction* op) {
+  HloOpcode opcode = op->opcode();
+  // TODO(NVIDIA/4364298): The information is recorded in b/309639264.
+  // To work around the bug, we need to prevent custom-calls from overlapping
+  // with Send/Recv. Remove the custom-call handling here when the bug is fixed.
+  if (opcode == HloOpcode::kCustomCall) {
+    return true;
+  }
+
+  return hlo_query::IsCollectiveCommunicationOp(opcode) ||
+         opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv;
+}
+
+// Returns whether the instruction may invoke collective operations directly
+// or indirectly.
+bool MayInvokeCollectiveOp(
+    const HloInstruction* hlo,
+    const CollectiveInComputation& collective_in_computation) {
+  if (IsCollectiveOp(hlo)) {
+    return true;
+  }
+  for (HloComputation* callee : hlo->called_computations()) {
+    auto collective_in_comp = collective_in_computation.find(callee);
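+    // Computations are visited callees-before-callers (see Run below), so the
+    // called computation is expected to already be in the map.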
+    CHECK(collective_in_comp != collective_in_computation.end());
+    if (collective_in_comp->second) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Returns the unique get-tuple-element user with the given idx or nullptr if
+// there isn't such a unique user.
+HloInstruction* FindUniqueGTEUserWithIndex(const HloInstruction* op,
+                                           int64_t idx) {
+  CHECK(op->shape().IsTuple());
+
+  HloInstruction* gte = nullptr;
+  for (auto user : op->users()) {
+    if (user->opcode() != HloOpcode::kGetTupleElement) {
+      continue;
+    }
+    if (user->tuple_index() == idx) {
+      if (gte == nullptr) {
+        gte = user;
+      } else {
+        return nullptr;
+      }
+    }
+  }
+  return gte;
+}
+
+// Returns whether there is any get-tuple-element user with the given idx.
+bool HasGTEUserWithIndex(const HloInstruction* op, int64_t idx) {
+  CHECK(op->shape().IsTuple());
+
+  for (auto user : op->users()) {
+    if (user->opcode() != HloOpcode::kGetTupleElement) {
+      continue;
+    }
+    if (user->tuple_index() == idx) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Returns the instruction hidden behind a trivial tuple, or `op` itself. This
+// allows the discovery of recv-done for cases like the following, where
+// tuple-simplification would normally have removed the indirection.
+//    gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
+//    gte.1 = token[] get-tuple-element(recv-done.p), index=1
+//    op = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
+//
+// TODO(bixia): investigate the possibility of implementing
+// m::TrivialTuple(m::RecvDone(&instr)) as suggested by code review.
+HloInstruction* MaySkipTrivialTuple(HloInstruction* op) {
+  if (op->opcode() != HloOpcode::kTuple) {
+    return op;
+  }
+  HloInstruction* hidden_op = nullptr;
+  for (auto opnd : op->mutable_operands()) {
+    if (opnd->opcode() != HloOpcode::kGetTupleElement) {
+      return op;
+    }
+    if (hidden_op == nullptr) {
+      hidden_op = opnd->mutable_operand(0);
+    } else if (opnd->mutable_operand(0) != hidden_op) {
+      return op;
+    }
+  }
+  return hidden_op;
+}
+
+// This routine is similar to the non-const version above, except that the
+// given instruction is used for pattern checking only and can't be mutated.
+const HloInstruction* MaySkipTrivialTuple(const HloInstruction* op) {
+  // Use const_cast to avoid duplicating the non-const version above with
+  // operands() in place of mutable_operands().
+  return MaySkipTrivialTuple(const_cast<HloInstruction*>(op));
+}
+
+// Finds a consecutive block of balanced SendDone/RecvDone in the while_init
+// of a while-loop, assuming its while_init is a tuple.
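+//
+// For example (an illustrative sketch with hypothetical operand names), a
+// while_init of the form
+//   tuple(count, recv-done.0, send-done.0, recv-done.1, send-done.1, data)
+// in which the done ops carry the send/recv pipeline attribute would yield
+// {opnd_start=1, opnd_end=5}.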
+std::optional<PipelinedP2PInfo>
+FindConsecutiveAndBalanceBlockOfSendDoneRecvDone(
+    const HloInstruction* while_init) {
+  PipelinedP2PInfo pipelined_p2p_info{0, 0};
+  // Return whether the first SendDone/RecvDone has been seen.
+  auto has_started = [&]() {
+    return pipelined_p2p_info.opnd_start != pipelined_p2p_info.opnd_end;
+  };
+  // Record the difference between the number of SendDone and RecvDone in a
+  // consecutive block.
+  int difference = 0;
+  // If SendDone/RecvDone exists in a consecutive block in the while_init
+  // tuple, find such block.
+  for (int64_t i = 0; i < while_init->operand_count(); ++i) {
+    const HloInstruction* op = while_init->operand(i);
+    if ((op->opcode() == HloOpcode::kRecvDone ||
+         op->opcode() == HloOpcode::kSendDone) &&
+        op->frontend_attributes().map().count(kSendRecvPipelineAttr) > 0) {
+      if (op->opcode() == HloOpcode::kRecvDone) {
+        difference++;
+      } else {
+        difference--;
+      }
+      if (!has_started()) {
+        pipelined_p2p_info.opnd_start = i;
+      }
+      pipelined_p2p_info.opnd_end = i + 1;
+    } else {
+      if (has_started()) {
+        VLOG(10) << "End a consecutive block";
+        break;
+      }
+    }
+  }
+
+  if (difference != 0) {
+    VLOG(10) << "Mismatch number of SendDone and RecvDone: " << difference;
+    return std::nullopt;
+  }
+
+  if (has_started()) {
+    // Check for SendDone/RecvDone outside the consecutive block.
+    for (int64_t i = pipelined_p2p_info.opnd_end;
+         i < while_init->operand_count(); ++i) {
+      const HloInstruction* op = while_init->operand(i);
+      if (op->opcode() == HloOpcode::kRecvDone ||
+          op->opcode() == HloOpcode::kSendDone) {
+        VLOG(10) << "SendDone/RecvDone outside the consecutive block";
+        return std::nullopt;
+      }
+    }
+  }
+
+  if (!has_started()) {
+    VLOG(10) << "No SendDone/RecvDone in while-init ";
+    return std::nullopt;
+  }
+
+  return pipelined_p2p_info;
+}
+
+// Checks whether the while-op, its while-body and while-condition have a
+// recognized pipelined pattern. If a pipelined pattern is found, returns the
+// start and (exclusive) end indices of the pipelined instructions in the
+// while-init tuple. For pipelined Send/Recv to work, the SendDone/RecvDone ops
+// don't have to be in a consecutive block, but requiring this simplifies the
+// implementation and matches the pattern that the current gpu-p2p-pipeliner
+// generates.
+//
+// As a summary, this is what the routine looks for:
+//
+// . The while-init is a tuple with a single user.
+// . The while-init has a consecutive block of SendDone and RecvDone. The
+//   numbers of SendDone and RecvDone are the same, and there isn't any other
+//   SendDone or RecvDone outside the block.
+// . The while-body has a single tuple parameter.
+// . For the while-op result tuple and the while-body parameter tuple:
+//     An index at which while-init holds a SendDone should not have any
+//       get-tuple-element user.
+//     An index at which while-init holds a RecvDone should have exactly one
+//       get-tuple-element user.
+// . In the while-body result tuple, the operand at an index holding a SendDone
+//   or RecvDone in while-init should also be a SendDone or RecvDone with the
+//   same opcode.
+//
+// TODO(bixia): support pipelined SendDone/RecvDone that are not in a
+// consecutive block, if the gpu-p2p-pipeliner ever generates such code in the
+// future.
+std::optional<PipelinedP2PInfo> FindPipelinedP2P(
+    const HloInstruction* while_op) {
+  VLOG(10) << "while_op: " << while_op->ToString();
+  const HloInstruction* while_init = while_op->while_init();
+  if (while_init->opcode() != HloOpcode::kTuple ||
+      while_init->user_count() != 1) {
+    return std::nullopt;
+  }
+
+  // The while-body and while-condition should have one parameter of a tuple
+  // shape.
+  const HloComputation* while_body = while_op->while_body();
+  const HloComputation* while_condition = while_op->while_condition();
+  if (while_body->num_parameters() != 1 ||
+      while_condition->num_parameters() != 1) {
+    return std::nullopt;
+  }
+
+  std::optional<PipelinedP2PInfo> pipelined_p2p_info =
+      FindConsecutiveAndBalanceBlockOfSendDoneRecvDone(while_init);
+  if (!pipelined_p2p_info.has_value()) {
+    return std::nullopt;
+  }
+
+  VLOG(10) << "opnd_start " << pipelined_p2p_info->opnd_start << " opnd_end "
+           << pipelined_p2p_info->opnd_end;
+
+  // In the while-result or while-body parameter, the index for RecvDone should
+  // correspond to one get-tuple-element user and the index for SendDone should
+  // not correspond to any get-tuple-element user.
+  for (int64_t i = pipelined_p2p_info->opnd_start;
+       i < pipelined_p2p_info->opnd_end; ++i) {
+    const HloInstruction* op = while_init->operand(i);
+    if (op->opcode() == HloOpcode::kRecvDone) {
+      if (!FindUniqueGTEUserWithIndex(while_op, i)) {
+        VLOG(10) << "While result get-tuple-element user with index " << i
+                 << " not unique";
+        return std::nullopt;
+      }
+      if (!FindUniqueGTEUserWithIndex(while_body->parameter_instruction(0),
+                                      i)) {
+        VLOG(10) << "While-body parameter get-tuple-element user with index "
+                 << i << " not unique";
+        return std::nullopt;
+      }
+    } else {
+      CHECK(op->opcode() == HloOpcode::kSendDone);
+      if (HasGTEUserWithIndex(while_op, i) ||
+          HasGTEUserWithIndex(while_body->parameter_instruction(0), i)) {
+        VLOG(10) << "SendDone with index " << i << " has unexpected users";
+        return std::nullopt;
+      }
+    }
+  }
+
+  // The elements in the while-body result tuple corresponding to the pipelined
+  // SendDone/RecvDone in the while-init must have the same opcodes.
+  const HloInstruction* root = while_body->root_instruction();
+  for (int64_t i = pipelined_p2p_info->opnd_start;
+       i < pipelined_p2p_info->opnd_end; ++i) {
+    const HloInstruction* op_init = while_init->operand(i);
+    const HloInstruction* op_root = root->operand(i);
+    op_root = MaySkipTrivialTuple(op_root);
+    if (op_init->opcode() != op_root->opcode()) {
+      VLOG(10) << "Mismatching opcode, op_init: " << op_init->ToString()
+               << " op_root: " << op_root->ToString();
+      return std::nullopt;
+    }
+  }
+
+  return pipelined_p2p_info.value();
+}
+
+absl::Status RemoveOpFromParent(HloInstruction* op) {
+  TF_RETURN_IF_ERROR(op->DropAllControlDeps());
+  TF_RETURN_IF_ERROR(op->parent()->RemoveInstruction(op));
+  return absl::OkStatus();
+}
+
+absl::Status ReplaceOpInSequence(HloInstruction* old_op, HloInstruction* new_op,
+                                 HloInstructionSequence& instruction_sequence) {
+  VLOG(10) << "old_op: " << old_op->ToString();
+  VLOG(10) << "new_op: " << new_op->ToString();
+  instruction_sequence.replace_instruction(old_op, new_op);
+  return RemoveOpFromParent(old_op);
+}
+
+absl::Status ReplaceUsesAndUpdateSequence(
+    HloInstruction* old_op, HloInstruction* new_op,
+    HloInstructionSequence& instruction_sequence, bool diff_shape = false) {
+  VLOG(10) << "old_op: " << old_op->ToString();
+  VLOG(10) << "new_op: " << new_op->ToString();
+  if (diff_shape) {
+    TF_RETURN_IF_ERROR(old_op->ReplaceAllUsesWithDifferentShape(new_op));
+  } else {
+    TF_RETURN_IF_ERROR(old_op->ReplaceAllUsesWith(new_op));
+  }
+  return ReplaceOpInSequence(old_op, new_op, instruction_sequence);
+}
+
+absl::Status ReplaceUsesAndUpdateSequence(
+    const InstructionVector& old_ops, const InstructionVector& new_ops,
+    HloInstructionSequence& instruction_sequence) {
+  CHECK(old_ops.size() == new_ops.size());
+  for (int64_t i = 0; i < old_ops.size(); ++i) {
+    TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(old_ops[i], new_ops[i],
+                                                    instruction_sequence));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status RemoveDoneOpsAndUpdateSequence(
+    const InstructionVector& ops,
+    HloInstructionSequence& instruction_sequence) {
+  auto remove_op = [&](HloInstruction* op) {
+    VLOG(10) << "op: " << op->ToString();
+    TF_RETURN_IF_ERROR(RemoveOpFromParent(op));
+    instruction_sequence.remove_instruction(op);
+    return absl::OkStatus();
+  };
+  for (auto op : ops) {
+    if (op->opcode() == HloOpcode::kTuple) {
+      InstructionVector to_remove;
+      HloInstruction* tuple_op = op;
+      op = MaySkipTrivialTuple(tuple_op);
+      to_remove.push_back(tuple_op);
+      for (auto opnd : tuple_op->mutable_operands()) {
+        to_remove.push_back(opnd);
+      }
+      for (auto opnd : to_remove) {
+        TF_RETURN_IF_ERROR(remove_op(opnd));
+      }
+    }
+    TF_RETURN_IF_ERROR(remove_op(op));
+  }
+  return absl::OkStatus();
+}
+
+bool InsertBeforeFirstCollectiveOp(
+    const InstructionVector& ops,
+    const CollectiveInComputation& collective_in_computation,
+    HloInstructionSequence& instruction_sequence, int64_t& idx,
+    int64_t& idx_tot) {
+  bool inserted = false;
+  while (idx < idx_tot) {
+    HloInstruction* hlo = instruction_sequence.instructions()[idx];
+    if (MayInvokeCollectiveOp(hlo, collective_in_computation)) {
+      for (auto op : ops) {
+        instruction_sequence.insert_instruction(op, idx);
+        idx++;
+        idx_tot++;
+      }
+      inserted = true;
+      break;
+    }
+    idx++;
+  }
+  return inserted;
+}
+
+void CopyInstructionInfo(const HloInstruction* old_op, HloInstruction* new_op) {
+  new_op->SetAndSanitizeName(absl::StrCat(old_op->name(), ".clone"));
+  new_op->set_metadata(old_op->metadata());
+  new_op->add_frontend_attributes(old_op->frontend_attributes());
+  new_op->CopyBackendConfigFrom(old_op);
+}
+
+HloInstruction* CreateRecvDoneFrom(const HloInstruction* old_recv_done,
+                                   HloInstruction* recv,
+                                   HloComputation* computation) {
+  HloInstruction* recv_done =
+      computation->AddInstruction(HloInstruction::CreateRecvDone(
+          recv, old_recv_done->channel_id().value()));
+  CopyInstructionInfo(old_recv_done, recv_done);
+  return recv_done;
+}
+
+HloInstruction* CreateSendDoneFrom(const HloInstruction* old_send_done,
+                                   HloInstruction* send,
+                                   HloComputation* computation) {
+  HloInstruction* send_done =
+      computation->AddInstruction(HloInstruction::CreateSendDone(
+          send, old_send_done->channel_id().value()));
+  CopyInstructionInfo(old_send_done, send_done);
+  return send_done;
+}
+
+absl::Status RewritePipelinedP2PWhileBody(
+    const CollectiveInComputation& collective_in_computation,
+    const std::vector<Shape>& new_parameter_shapes, HloInstruction* while_op,
+    int64_t opnd_start, int64_t opnd_end) {
+  HloComputation* computation = while_op->while_body();
+  HloInstruction* while_init = while_op->while_init();
+  HloInstruction* root = computation->root_instruction();
+  HloInstructionSequence& instruction_sequence =
+      computation->parent()->schedule().GetOrCreateSequence(computation);
+
+  HloInstruction* param = computation->parameter_instruction(0);
+  *param->mutable_shape() = ShapeUtil::MakeTupleShape(new_parameter_shapes);
+
+  InstructionVector recv_dones;
+  InstructionVector new_recv_dones;
+  InstructionVector new_send_dones;
+  for (int64_t i = opnd_start; i < opnd_end; ++i) {
+    const HloInstruction* op = root->operand(i);
+    op = MaySkipTrivialTuple(op);
+    if (op->opcode() == HloOpcode::kRecvDone) {
+      HloInstruction* gte = FindUniqueGTEUserWithIndex(param, i);
+      CHECK(gte != nullptr);
+      recv_dones.push_back(gte);
+
+      // Create the new RecvDone using the new while-body parameter.
+      HloInstruction* recv = computation->AddInstruction(
+          HloInstruction::CreateGetTupleElement(param, i));
+
+      HloInstruction* recv_done = CreateRecvDoneFrom(op, recv, computation);
+      new_recv_dones.push_back(recv_done);
+      continue;
+    }
+    CHECK(op->opcode() == HloOpcode::kSendDone);
+    // Create the new SendDone using the new while-body parameter.
+    HloInstruction* send = computation->AddInstruction(
+        HloInstruction::CreateGetTupleElement(param, i));
+    HloInstruction* send_done = CreateSendDoneFrom(op, send, computation);
+    new_send_dones.push_back(send_done);
+  }
+  TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(recv_dones, new_recv_dones,
+                                                  instruction_sequence));
+
+  // Create a new root tuple.
+  InstructionVector done_ops;
+  InstructionVector new_opnds;
+  for (int64_t i = 0; i < while_init->operand_count(); ++i) {
+    HloInstruction* op = root->mutable_operand(i);
+    if (i >= opnd_start && i < opnd_end) {
+      new_opnds.push_back(MaySkipTrivialTuple(op)->mutable_operand(0));
+      done_ops.push_back(op);
+    } else {
+      new_opnds.push_back(op);
+    }
+  }
+  HloInstruction* new_root =
+      computation->AddInstruction(HloInstruction::CreateTuple(new_opnds));
+  computation->set_root_instruction(new_root,
+                                    /*accept_different_shape=*/true);
+  TF_RETURN_IF_ERROR(computation->RemoveInstruction(root));
+  instruction_sequence.replace_instruction(root, new_root);
+
+  TF_RETURN_IF_ERROR(
+      RemoveDoneOpsAndUpdateSequence(done_ops, instruction_sequence));
+
+  // Find a place to put the new SendDone ops: right before the first op that
+  // may invoke a collective operation, which is either an op outside the
+  // pipelined Send/Recv chain or the first op of that chain.
+  int64_t idx = 0;
+  int64_t idx_end = instruction_sequence.size();
+  bool inserted =
+      InsertBeforeFirstCollectiveOp(new_send_dones, collective_in_computation,
+                                    instruction_sequence, idx, idx_end);
+  CHECK(inserted);  // The while-body contains Send/Recv, so insertion succeeds.
+  CHECK(idx_end == instruction_sequence.size());
+
+  // The module schedule will be updated at the end of the pass.
+  return absl::OkStatus();
+}
+
+void RewritePipelinedP2PWhileCond(
+    const std::vector<Shape>& new_parameter_shapes, HloInstruction* while_op) {
+  HloComputation* computation = while_op->while_condition();
+  HloInstruction* param = computation->parameter_instruction(0);
+  *param->mutable_shape() = ShapeUtil::MakeTupleShape(new_parameter_shapes);
+  VLOG(10) << computation->ToString();
+}
+
+// Rewrites the while-op with a recognized pipelined SendDone/RecvDone pattern
+// to pipeline Send/Recv instead.
+absl::Status TransformLoop(
+    const PipelinedP2PInfo& pipelined_info,
+    const CollectiveInComputation& collective_in_computation, int64_t& idx,
+    int64_t& idx_end, HloInstructionSequence& instruction_sequence,
+    HloInstruction* while_op) {
+  HloComputation* computation = while_op->parent();
+  int64_t opnd_start = pipelined_info.opnd_start;
+  int64_t opnd_end = pipelined_info.opnd_end;
+  VLOG(10) << "Transform pipelined while-op " << while_op->ToString();
+  HloInstruction* while_init = while_op->while_init();
+  InstructionVector new_while_init_opnds;
+  std::vector<Shape> new_parameter_shapes;
+  for (int64_t i = 0; i < while_init->operand_count(); ++i) {
+    HloInstruction* op = while_init->mutable_operand(i);
+    if (i >= opnd_start && i < opnd_end) {
+      // Get Send/Recv from SendDone/RecvDone.
+      new_while_init_opnds.push_back(op->mutable_operand(0));
+    } else {
+      new_while_init_opnds.push_back(op);
+    }
+    new_parameter_shapes.push_back(new_while_init_opnds.back()->shape());
+  }
+
+  RewritePipelinedP2PWhileCond(new_parameter_shapes, while_op);
+  TF_RETURN_IF_ERROR(RewritePipelinedP2PWhileBody(
+      collective_in_computation, new_parameter_shapes, while_op, opnd_start,
+      opnd_end));
+  HloInstruction* new_while_init = computation->AddInstruction(
+      HloInstruction::CreateTuple(new_while_init_opnds), "while-init");
+  VLOG(10) << "new_while_init: " << new_while_init->ToString();
+  HloInstruction* new_while_op = computation->AddInstruction(
+      HloInstruction::CreateWhile(
+          while_op->while_body()->root_instruction()->shape(),
+          while_op->while_condition(), while_op->while_body(), new_while_init),
+      "while-result");
+  CopyInstructionInfo(while_op, new_while_op);
+  VLOG(10) << "new_while_op: " << new_while_op->ToString();
+
+  InstructionVector recv_dones;
+  InstructionVector new_recv_dones;
+  InstructionVector new_send_dones;
+  InstructionVector done_ops;
+  for (int64_t i = opnd_start; i < opnd_end; ++i) {
+    HloInstruction* op = while_init->mutable_operand(i);
+    done_ops.push_back(op);
+    if (op->opcode() == HloOpcode::kRecvDone) {
+      HloInstruction* gte = FindUniqueGTEUserWithIndex(while_op, i);
+      CHECK(gte != nullptr);
+      recv_dones.push_back(gte);
+
+      // Create the new RecvDone using the new while-op result.
+      HloInstruction* recv = computation->AddInstruction(
+          HloInstruction::CreateGetTupleElement(new_while_op, i));
+      HloInstruction* recv_done = computation->AddInstruction(
+          HloInstruction::CreateRecvDone(recv, op->channel_id().value()));
+      new_recv_dones.push_back(recv_done);
+      CopyInstructionInfo(op, recv_done);
+      continue;
+    }
+    CHECK(op->opcode() == HloOpcode::kSendDone);
+    //  Create the new SendDone using the new while-op result.
+    HloInstruction* send = computation->AddInstruction(
+        HloInstruction::CreateGetTupleElement(new_while_op, i));
+    HloInstruction* send_done = computation->AddInstruction(
+        HloInstruction::CreateSendDone(send, op->channel_id().value()));
+    new_send_dones.push_back(send_done);
+    CopyInstructionInfo(op, send_done);
+  }
+
+  TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(
+      while_op, new_while_op, instruction_sequence, /*diff_shape*/ true));
+  TF_RETURN_IF_ERROR(
+      ReplaceOpInSequence(while_init, new_while_init, instruction_sequence));
+  TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(recv_dones, new_recv_dones,
+                                                  instruction_sequence));
+  TF_RETURN_IF_ERROR(
+      RemoveDoneOpsAndUpdateSequence(done_ops, instruction_sequence));
+
+  int64_t opnd_tot = opnd_end - opnd_start;
+  // Verify that the number of ops we have removed from the sequence is
+  // opnd_tot and that they were all before the position of the new while-op.
+  CHECK(idx_end == instruction_sequence.size() + opnd_tot);
+  CHECK(instruction_sequence.instructions()[idx - opnd_tot] == new_while_op);
+
+  // Update idx_end to reflect the current size of the instruction sequence.
+  // Update idx to right after the new while-op.
+  idx_end -= opnd_tot;
+  idx = idx - opnd_tot + 1;
+  bool inserted =
+      InsertBeforeFirstCollectiveOp(new_send_dones, collective_in_computation,
+                                    instruction_sequence, idx, idx_end);
+  CHECK(idx_end == instruction_sequence.size());
+  // If there aren't any may-invoke-collective ops after the while-op, add
+  // the new SendDone ops before the last instruction in the sequence.
+  if (!inserted) {
+    CHECK(idx_end == idx);
+    idx--;
+    for (auto send_done : new_send_dones) {
+      instruction_sequence.insert_instruction(send_done, idx++);
+    }
+  }
+  return absl::OkStatus();
+}
+
+// Finds while-loops with pipelined Send/Recv and rotates the SendDone/RecvDone
+// ops for such while-loops.
+absl::StatusOr<bool> ProcessComputation(
+    HloModule* module, HloComputation* computation,
+    CollectiveInComputation& collective_in_computation) {
+  VLOG(10) << "Process compuation " << computation->name();
+  bool changed = false;
+  HloInstructionSequence& instruction_sequence =
+      module->schedule().GetOrCreateSequence(computation);
+  int64_t idx = 0;
+  int64_t idx_end = instruction_sequence.size();
+  while (idx < idx_end) {
+    HloInstruction* hlo = instruction_sequence.instructions()[idx];
+
+    if (MayInvokeCollectiveOp(hlo, collective_in_computation)) {
+      collective_in_computation[computation] = true;
+    }
+
+    if (hlo->opcode() != HloOpcode::kWhile) {
+      idx++;
+      continue;
+    }
+
+    std::optional<PipelinedP2PInfo> pipelined_info = FindPipelinedP2P(hlo);
+    if (!pipelined_info.has_value()) {
+      idx++;
+      continue;
+    }
+    TF_RETURN_IF_ERROR(TransformLoop(pipelined_info.value(),
+                                     collective_in_computation, idx, idx_end,
+                                     instruction_sequence, hlo));
+    changed = true;
+  }
+  return changed;
+}
+}  // namespace
+
+absl::StatusOr<bool> PipelinedP2PRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  if (!module->has_schedule()) return changed;
+  CollectiveInComputation collective_in_computation;
+  // Visit the computations in the order of callees to callers, so that
+  // while-body is processed before while-op.
+  for (auto* computation :
+       module->MakeComputationPostOrder(execution_threads)) {
+    if (computation->IsFusionComputation()) {
+      collective_in_computation[computation] = false;
+      continue;
+    }
+
+    TF_ASSIGN_OR_RETURN(
+        bool cur_changed,
+        ProcessComputation(module, computation, collective_in_computation));
+    changed |= cur_changed;
+  }
+
+  if (changed) {
+    TF_RETURN_IF_ERROR(module->schedule().Update());
+  }
+
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.h b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.h
new file mode 100644
index 0000000..d2aca8c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.h
@@ -0,0 +1,133 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_PIPELINED_P2P_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_PIPELINED_P2P_REWRITER_H_
+
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// PipelinedP2PRewriter is a pass that rewrites pipelined Send/Recv-related
+// code for point-to-point communication to rotate the SendDone and RecvDone at
+// the end of a loop iteration to the beginning of the next iteration. This
+// pass operates on a scheduled module and updates the instruction sequence.
+//
+// In particular, a pipelined Send/Recv chain with one channel group with this
+// code pattern:
+//
+// main:
+//    recv
+//    send
+//    recv-done
+//    send-done
+//    while-init = (recv-done, send-done, ...)
+//    while-op = while(while-init) ...
+//
+// while-body:
+//    ...
+//    recv
+//    send
+//    recv-done
+//    send-done
+//    ROOT tuple(recv-done, send-done, ...)
+//
+// Will be transformed to:
+//
+// main:
+//    recv
+//    send
+//    while-init = (recv, send, ...)
+//    while-op = while(while-init) ...
+//    recv-done
+//    send-done
+//
+// while-body:
+//    recv-done
+//    ...
+//    send-done
+//    recv
+//    send
+//    ROOT tuple(recv, send, ...)
+//
+// A pipelined Send/Recv chain with two channel groups with this code pattern:
+//
+// main:
+//    recv.0
+//    send.0
+//    recv.1
+//    send.1
+//    recv-done.0
+//    send-done.0
+//    recv-done.1
+//    send-done.1
+//    while-init = (recv-done.0, send-done.0, recv-done.1, send-done.1, ...)
+//    while-op = while(while-init) ...
+//
+// while-body:
+//    ...
+//    recv.0
+//    send.0
+//    recv.1
+//    send.1
+//    recv-done.0
+//    send-done.0
+//    recv-done.1
+//    send-done.1
+//    ROOT = tuple(recv-done.0, send-done.0, recv-done.1, send-done.1, ...)
+//
+// Will be transformed to:
+//
+// main:
+//
+//    recv.0
+//    send.0
+//    recv.1
+//    send.1
+//    while-init = (recv.0, send.0, recv.1, send.1, ...)
+//    while-op = while(while-init) ...
+//    recv-done.0
+//    send-done.0
+//    recv-done.1
+//    send-done.1
+//
+// while-body:
+//    recv-done.0
+//    recv-done.1
+//    ...
+//    send-done.0
+//    send-done.1
+//    recv.0
+//    send.0
+//    recv.1
+//    send.1
+//    ROOT tuple(recv.0, send.0, recv.1, send.1, ...)
+//
+class PipelinedP2PRewriter : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "pipelined-p2p-rewriter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
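+
+// Usage (a minimal sketch mirroring the unit tests; it assumes `module` is a
+// scheduled HloModule, e.g. parsed from HLO text with `is_scheduled=true`):
+//
+//   PipelinedP2PRewriter rewriter;
+//   TF_ASSIGN_OR_RETURN(bool changed, rewriter.Run(module.get()));
+//
+// The pass is a no-op (returns false) when the module has no schedule.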
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_PIPELINED_P2P_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc
new file mode 100644
index 0000000..287603c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc
@@ -0,0 +1,674 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h"
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class PipelinedP2pRewriterTest : public HloTestBase {
+ protected:
+  void DoFileCheck(const HloModule* module, absl::string_view expected) {
+    HloPrintOptions options;
+    options.set_print_operand_shape(false);
+    options.set_print_result_shape(false);
+    TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
+                            RunFileCheck(module->ToString(options), expected));
+    EXPECT_TRUE(filecheck_matched);
+  }
+};
+
+TEST_F(PipelinedP2pRewriterTest, SendRecUnpipelinedNotTransform) {
+  const char* kModuleStr = R"(
+HloModule test
+
+cond {
+    param = (u32[], u32[2]) parameter(0)
+    count = get-tuple-element(%param), index=0
+    ub = u32[] constant(11)
+    ROOT result = pred[] compare(count, ub), direction=LT
+ }
+
+body {
+    param = (u32[], u32[2]) parameter(0)
+    count = get-tuple-element(param), index=0
+    send-data = u32[2] get-tuple-element(param), index=1
+
+    after-all.0.n = token[] after-all()
+    recv.0 = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{3,0}}",
+        _xla_send_recv_pipeline="0"
+      }
+    send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0.n),
+      channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{3,0}}",
+        _xla_send_recv_pipeline="0"
+      }
+    recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    send-done.0 = token[] send-done(send.0), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+
+    recv-data = u32[2] get-tuple-element(recv-done.0), index=0
+
+    c1 = u32[] constant(1)
+    new_count = u32[] add(count, c1)
+
+    r = u32[2] broadcast(c1), dimensions={}
+    s = u32[2] add(r, recv-data)
+
+    ROOT result = (u32[], u32[2]) tuple(new_count, s)
+  }
+
+  ENTRY test_computation {
+    c0 = u32[] constant(0)
+    c1 = u32[] constant(1)
+    r = u32[] replica-id()
+    a = u32[] add(c1, r)
+    init = u32[2] broadcast(a), dimensions={}
+    while_init = (u32[], u32[2]) tuple(c0, init)
+    while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond,
+      backend_config={"known_trip_count":{"n":"11"}}
+    ROOT recv-data = u32[2] get-tuple-element(while_result), index=1
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr));
+  PipelinedP2PRewriter rewriter;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+// Tests the rewrite for a pipelined Send/Recv chain with only one channel
+// group.
+TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined1) {
+  const char* kModuleStr = R"(
+  HloModule test, is_scheduled=true
+
+  while-cond {
+    param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
+    count = get-tuple-element(param), index=0
+    ub = u32[] constant(25)
+    ROOT cond-result = pred[] compare(count, ub), direction=LT
+  }
+
+  while-body {
+    param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
+    count = get-tuple-element(param), index=0
+
+    recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
+    recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0
+
+    c1 = u32[] constant(1)
+    new-count = u32[] add(count, c1)
+    replica = u32[] replica-id()
+    c10 = u32[] constant(10)
+    sum = u32[] add(replica, c10)
+    sum2 = u32[] add(sum, count)
+    conv = f32[] convert(sum2)
+    p = f32[1, 1024, 1024] broadcast(conv), dimensions={}
+    b = f32[1, 1024, 1024] add(p, recv-data)
+    c = f32[1, 1024, 1024] multiply(b, b)
+    d = f32[1, 1024, 1024] tan(c)
+    s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0},
+      lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+    send-data = f32[1, 1024, 1024] add(c, s)
+
+    after-all = token[] after-all()
+    recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+        _xla_send_recv_pipeline="0"
+      }
+    send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
+      channel_id=1, frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+        _xla_send_recv_pipeline="0"
+      }
+    recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    send-done.p = token[] send-done(send), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
+    gte.1 = token[] get-tuple-element(recv-done.p), index=1
+    recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
+    ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
+      tuple(new-count, recv-done-tuple, send-done.p)
+  }
+
+  ENTRY main {
+    c0 = u32[] constant(0)
+    f0 = f32[] constant(0.0)
+    init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
+
+    after-all.1 = token[] after-all()
+    recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+        _xla_send_recv_pipeline="0"
+      }
+    send.1 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.1), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+        _xla_send_recv_pipeline="0"
+      }
+    recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    send-done.1.p = token[] send-done(send.1), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    while-init.p =  (u32[], (f32[1,1024,1024], token[]), token[])
+      tuple(c0, recv-done.1.p, send-done.1.p)
+    while-result.p = (u32[], (f32[1,1024,1024], token[]), token[])
+      while(while-init.p),
+      body=while-body, condition=while-cond,
+      backend_config={"known_trip_count":{"n":"25"}}
+
+    recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1
+
+    ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0
+  }
+  )";
+
+  const char* kExpected = R"(
+  CHECK: %while-body (param.1: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[])) {
+  CHECK: %param.1 = parameter(0)
+  CHECK: %get-tuple-element = get-tuple-element(%param.1), index=1
+  CHECK: %get-tuple-element.1 = get-tuple-element(%param.1), index=2
+  CHECK: %count.1 = get-tuple-element(%param.1), index=0
+  CHECK: %recv-done.p.clone = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+  CHECK: %recv-data = get-tuple-element(%recv-done.p.clone), index=0
+  CHECK: %c1 = constant(1)
+  CHECK: %new-count = add(%count.1, %c1)
+  CHECK: %replica = replica-id()
+  CHECK: %c10 = constant(10)
+  CHECK: %sum = add(%replica, %c10)
+  CHECK: %sum2 = add(%sum, %count.1)
+  CHECK: %conv = convert(%sum2)
+  CHECK: %p = broadcast(%conv), dimensions={}
+  CHECK: %b = add(%p, %recv-data)
+  CHECK: %c = multiply(%b, %b)
+  CHECK: %d = tan(%c)
+  CHECK: %s = dot(%c, %d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+  CHECK: %send-data = add(%c, %s)
+  CHECK: %after-all = after-all()
+  CHECK: %send-done.p.clone = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+  CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}}
+  CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}}
+  CHECK: ROOT %tuple = tuple(%new-count, %recv, %send)
+  CHECK: }
+
+  CHECK: %while-cond (param: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> pred[] {
+  CHECK: %param = parameter(0)
+  CHECK: %count = get-tuple-element(%param), index=0
+  CHECK: %ub = constant(25)
+  CHECK: ROOT %cond-result = compare(%count, %ub), direction=LT
+  CHECK: }
+
+  CHECK: ENTRY %main () -> f32[1,1024,1024] {
+  CHECK: %c0 = constant(0)
+  CHECK: %f0 = constant(0)
+  CHECK: %init = broadcast(%f0), dimensions={}
+  CHECK: %after-all.1 = after-all()
+  CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}}
+  CHECK{LITERAL}: %send.1 = send(%init, %after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}}
+  CHECK: %while-init = tuple(%c0, %recv.1, %send.1)
+  CHECK: %while-result.p.clone = while(%while-init), condition=%while-cond, body=%while-body,
+  CHECK-SAME{LITERAL}: backend_config={"known_trip_count":{"n":"25"}}
+  CHECK: %get-tuple-element.2 = get-tuple-element(%while-result.p.clone), index=1
+  CHECK: %get-tuple-element.3 = get-tuple-element(%while-result.p.clone), index=2
+  CHECK: %recv-done.1.p.clone = recv-done(%get-tuple-element.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+  CHECK: %send-done.1.p.clone = send-done(%get-tuple-element.3), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+  CHECK: ROOT %entry-result = get-tuple-element(%recv-done.1.p.clone), index=0
+  CHECK: })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr));
+  PipelinedP2PRewriter rewriter;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  DoFileCheck(module.get(), kExpected);
+}
+
+// Repeats the Send/Recv pattern in the previous test, to test that we can
+// rewrite a routine with multiple pipelined loops without crashing.
+TEST_F(PipelinedP2pRewriterTest, SendRecvTwoPipelinedWhileLoops) {
+  const char* kModuleStr = R"(
+  HloModule test, is_scheduled=true
+
+  while-cond {
+    param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
+    count = get-tuple-element(param), index=0
+    ub = u32[] constant(25)
+    ROOT cond-result = pred[] compare(count, ub), direction=LT
+  }
+
+  while-body {
+    param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
+    count = get-tuple-element(param), index=0
+
+    recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
+    send-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0
+
+    c1 = u32[] constant(1)
+    new-count = u32[] add(count, c1)
+
+    after-all = token[] after-all()
+    recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+        _xla_send_recv_pipeline="0"
+      }
+    send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
+      channel_id=1, frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+        _xla_send_recv_pipeline="0"
+      }
+    recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    send-done.p = token[] send-done(send), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
+    gte.1 = token[] get-tuple-element(recv-done.p), index=1
+    recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
+    ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
+      tuple(new-count, recv-done-tuple, send-done.p)
+  }
+
+  while-cond-2 {
+    param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
+    count = get-tuple-element(param), index=0
+    ub = u32[] constant(25)
+    ROOT cond-result = pred[] compare(count, ub), direction=LT
+  }
+
+  while-body-2 {
+    param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0)
+    count = get-tuple-element(param), index=0
+
+    recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
+    send-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0
+
+    c1 = u32[] constant(1)
+    new-count = u32[] add(count, c1)
+
+    after-all = token[] after-all()
+    recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+        _xla_send_recv_pipeline="0"
+      }
+    send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
+      channel_id=1, frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+        _xla_send_recv_pipeline="0"
+      }
+    recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    send-done.p = token[] send-done(send), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0
+    gte.1 = token[] get-tuple-element(recv-done.p), index=1
+    recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1)
+    ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[])
+      tuple(new-count, recv-done-tuple, send-done.p)
+  }
+
+  ENTRY main {
+    c0 = u32[] constant(0)
+    f0 = f32[] constant(0.0)
+    init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
+
+    after-all.1 = token[] after-all()
+    recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+        _xla_send_recv_pipeline="0"
+      }
+    send.1 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.1), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+        _xla_send_recv_pipeline="0"
+      }
+    recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    send-done.1.p = token[] send-done(send.1), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    while-init.p =  (u32[], (f32[1,1024,1024], token[]), token[])
+      tuple(c0, recv-done.1.p, send-done.1.p)
+    while-result.p = (u32[], (f32[1,1024,1024], token[]), token[])
+      while(while-init.p),
+      body=while-body, condition=while-cond,
+      backend_config={"known_trip_count":{"n":"25"}}
+
+    recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1
+
+    after-all-2.1 = token[] after-all()
+    recv-2.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all-2.1), channel_id=2,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+        _xla_send_recv_pipeline="0"
+      }
+    send-2.1 = (f32[1, 1024, 1024], u32[], token[]) send(recv-done.1.q, after-all-2.1), channel_id=2,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}",
+        _xla_send_recv_pipeline="0"
+      }
+    recv-done-2.1.p = (f32[1,1024,1024], token[]) recv-done(recv-2.1), channel_id=2,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    send-done-2.1.p = token[] send-done(send-2.1), channel_id=2,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    while-init-2.p =  (u32[], (f32[1,1024,1024], token[]), token[])
+      tuple(c0, recv-done-2.1.p, send-done-2.1.p)
+    while-result-2.p = (u32[], (f32[1,1024,1024], token[]), token[])
+      while(while-init-2.p),
+      body=while-body-2, condition=while-cond-2,
+      backend_config={"known_trip_count":{"n":"25"}}
+
+    recv-done-2.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result-2.p), index=1
+
+    ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done-2.1.q), index=0
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr));
+  PipelinedP2PRewriter rewriter;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
+  // Check that we transform the module without crashing.
+  EXPECT_TRUE(changed);
+}
+
+// Tests the rewrite for a pipelined Send/Recv chain with two channel groups.
+TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined2) {
+  const char* kModuleStr = R"(
+  HloModule test, is_scheduled=true
+
+  while-cond {
+    param = (u32[], (f32[1,1024,1024], token[]), token[],
+      (f32[1,1024,1024], token[]), token[]) parameter(0)
+    count = get-tuple-element(param), index=0
+    ub = u32[] constant(25)
+    ROOT cond-result = pred[] compare(count, ub), direction=LT
+  }
+
+  while-body {
+    param = (u32[], (f32[1,1024,1024], token[]), token[],
+      (f32[1,1024,1024], token[]), token[]) parameter(0)
+    count = get-tuple-element(param), index=0
+
+    recv-done.0.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1
+    recv-data.0 = f32[1, 1024, 1024] get-tuple-element(recv-done.0.q), index=0
+    recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=3
+    recv-data.1 = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0
+
+    replica = u32[] replica-id()
+    constant0 = u32[] constant(0)
+    compare0 = pred[] compare(replica, constant0), direction=EQ
+    compare = pred[1, 1024, 1024] broadcast(compare0), dimensions={}
+    recv-data = f32[1, 1024, 1024] select(compare, recv-data.0, recv-data.1)
+
+    c1 = u32[] constant(1)
+    new-count = u32[] add(count, c1)
+    c10 = u32[] constant(10)
+    sum = u32[] add(replica, c10)
+    sum2 = u32[] add(sum, count)
+    conv = f32[] convert(sum2)
+    p = f32[1, 1024, 1024] broadcast(conv), dimensions={}
+    b = f32[1, 1024, 1024] add(p, recv-data)
+    c = f32[1, 1024, 1024] multiply(b, b)
+    d = f32[1, 1024, 1024] tan(c)
+    s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0},
+      lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+    send-data = f32[1, 1024, 1024] add(c, s)
+
+    after-all = token[] after-all()
+    recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{3,0}}",
+        _xla_send_recv_pipeline="0"
+      }
+    send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
+      channel_id=1, frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{3,0}}",
+        _xla_send_recv_pipeline="0"
+      }
+    recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    send-done.p = token[] send-done(send), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+
+    after-all.1 = token[] after-all()
+    recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=2,
+      frontend_attributes={
+        _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
+        _xla_send_recv_pipeline="1"
+      }
+    send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1),
+      channel_id=2, frontend_attributes={
+       _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
+       _xla_send_recv_pipeline="1"
+      }
+    recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=2,
+      frontend_attributes={
+        _xla_send_recv_pipeline="1"
+      }
+    send-done.1.p = token[] send-done(send.1), channel_id=2,
+      frontend_attributes={
+        _xla_send_recv_pipeline="1"
+      }
+
+    ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[],
+      (f32[1,1024,1024], token[]), token[])
+      tuple(new-count, recv-done.p, send-done.p, recv-done.1.p, send-done.1.p)
+  }
+
+  ENTRY main {
+    c0 = u32[] constant(0)
+    f0 = f32[] constant(0.0)
+    init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
+
+    after-all.2 = token[] after-all()
+    recv.2 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.2), channel_id=1,
+      frontend_attributes={
+       _xla_send_recv_source_target_pairs="{{3,0}}",
+       _xla_send_recv_pipeline="0"
+    }
+    send.2 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.2), channel_id=1,
+      frontend_attributes={
+       _xla_send_recv_source_target_pairs="{{3,0}}",
+       _xla_send_recv_pipeline="0"
+    }
+    recv-done.2.p = (f32[1,1024,1024], token[]) recv-done(recv.2), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+    send-done.2.p = token[] send-done(send.2), channel_id=1,
+      frontend_attributes={
+        _xla_send_recv_pipeline="0"
+      }
+
+    after-all.3 = token[] after-all()
+    recv.3 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.3), channel_id=2,
+      frontend_attributes={
+       _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
+       _xla_send_recv_pipeline="1"
+    }
+    send.3 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.3), channel_id=2,
+      frontend_attributes={
+       _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}",
+       _xla_send_recv_pipeline="1"
+    }
+    recv-done.3.p = (f32[1,1024,1024], token[]) recv-done(recv.3), channel_id=2,
+      frontend_attributes={
+        _xla_send_recv_pipeline="1"
+      }
+    send-done.3.p = token[] send-done(send.3), channel_id=2,
+      frontend_attributes={
+        _xla_send_recv_pipeline="1"
+      }
+
+    while-init.p = (u32[], (f32[1,1024,1024], token[]), token[],
+      (f32[1,1024,1024], token[]), token[]) tuple(c0, recv-done.2.p, send-done.2.p, recv-done.3.p, send-done.3.p)
+    while-result.p = (u32[], (f32[1,1024,1024], token[]), token[],
+      (f32[1,1024,1024], token[]), token[]) while(while-init.p),
+      body=while-body, condition=while-cond,
+      backend_config={"known_trip_count":{"n":"25"}}
+
+    recv-done.2.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1
+    recv-data.2 = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0
+    recv-done.3.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=3
+    recv-data.3 = f32[1, 1024, 1024] get-tuple-element(recv-done.3.q), index=0
+
+    replica = u32[] replica-id()
+    constant0 = u32[] constant(0)
+    compare0 = pred[] compare(replica, constant0), direction=EQ
+    compare = pred[1, 1024, 1024] broadcast(compare0), dimensions={}
+    ROOT entry-result = f32[1, 1024, 1024] select(compare, recv-data.2, recv-data.3)
+  }
+  )";
+
+  const char* kExpected = R"(
+  CHECK: %while-body (param.1: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[])) {
+  CHECK: %param.1 = parameter(0)
+  CHECK: %get-tuple-element = get-tuple-element(%param.1), index=1
+  CHECK: %get-tuple-element.1 = get-tuple-element(%param.1), index=2
+  CHECK: %get-tuple-element.2 = get-tuple-element(%param.1), index=3
+  CHECK: %get-tuple-element.3 = get-tuple-element(%param.1), index=4
+  CHECK: %count.1 = get-tuple-element(%param.1), index=0
+  CHECK: %recv-done.p.clone = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+  CHECK: %recv-data.0 = get-tuple-element(%recv-done.p.clone), index=0
+  CHECK: %recv-done.1.p.clone = recv-done(%get-tuple-element.2), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
+  CHECK: %recv-data.1 = get-tuple-element(%recv-done.1.p.clone), index=0
+  CHECK: %replica = replica-id()
+  CHECK: %constant0 = constant(0)
+  CHECK: %compare0 = compare(%replica, %constant0), direction=EQ
+  CHECK: %compare = broadcast(%compare0), dimensions={}
+  CHECK: %recv-data.2 = select(%compare, %recv-data.0, %recv-data.1)
+  CHECK: %c1 = constant(1)
+  CHECK: %new-count = add(%count.1, %c1)
+  CHECK: %c10 = constant(10)
+  CHECK: %sum = add(%replica, %c10)
+  CHECK: %sum2 = add(%sum, %count.1)
+  CHECK: %conv = convert(%sum2)
+  CHECK: %p = broadcast(%conv), dimensions={}
+  CHECK: %b = add(%p, %recv-data.2)
+  CHECK: %c = multiply(%b, %b)
+  CHECK: %d = tan(%c)
+  CHECK: %s = dot(%c, %d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+  CHECK: %send-data = add(%c, %s)
+  CHECK: %after-all = after-all()
+  CHECK: %send-done.p.clone = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+  CHECK: %send-done.1.p.clone = send-done(%get-tuple-element.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
+  CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{3,0}}}
+  CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{3,0}}}
+  CHECK: %after-all.1 = after-all()
+  CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}}}
+  CHECK{LITERAL}: %send.1 = send(%send-data, %after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}}}
+  CHECK: ROOT %tuple = tuple(%new-count, %recv, %send, %recv.1, %send.1)
+  CHECK: }
+
+  CHECK: %while-cond (param: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> pred[] {
+  CHECK: %param = parameter(0)
+  CHECK: %count = get-tuple-element(%param), index=0
+  CHECK: %ub = constant(25)
+  CHECK: ROOT %cond-result = compare(%count, %ub), direction=LT
+  CHECK: }
+
+  CHECK: ENTRY %main () -> f32[1,1024,1024] {
+  CHECK: %c0 = constant(0)
+  CHECK: %f0 = constant(0)
+  CHECK: %init = broadcast(%f0), dimensions={}
+  CHECK: %after-all.2 = after-all()
+  CHECK{LITERAL}: %recv.2 = recv(%after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{3,0}}}
+  CHECK{LITERAL}: %send.2 = send(%init, %after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{3,0}}}
+  CHECK: %after-all.3 = after-all()
+  CHECK{LITERAL}: %recv.3 = recv(%after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}}}
+  CHECK{LITERAL}: %send.3 = send(%init, %after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}}}
+  CHECK: %while-init = tuple(%c0, %recv.2, %send.2, %recv.3, %send.3)
+  CHECK{LITERAL}: %while-result.p.clone = while(%while-init), condition=%while-cond, body=%while-body, backend_config={"known_trip_count":{"n":"25"}}
+  CHECK: %get-tuple-element.4 = get-tuple-element(%while-result.p.clone), index=1
+  CHECK: %get-tuple-element.5 = get-tuple-element(%while-result.p.clone), index=2
+  CHECK: %get-tuple-element.6 = get-tuple-element(%while-result.p.clone), index=3
+  CHECK: %get-tuple-element.7 = get-tuple-element(%while-result.p.clone), index=4
+  CHECK: %recv-done.2.p.clone = recv-done(%get-tuple-element.4), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+  CHECK: %recv-data.3 = get-tuple-element(%recv-done.2.p.clone), index=0
+  CHECK: %recv-done.3.p.clone = recv-done(%get-tuple-element.6), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
+  CHECK: %recv-data.4 = get-tuple-element(%recv-done.3.p.clone), index=0
+  CHECK: %replica.1 = replica-id()
+  CHECK: %constant0.1 = constant(0)
+  CHECK: %compare0.1 = compare(%replica.1, %constant0.1), direction=EQ
+  CHECK: %compare.1 = broadcast(%compare0.1), dimensions={}
+  CHECK: %send-done.2.p.clone = send-done(%get-tuple-element.5), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
+  CHECK: %send-done.3.p.clone = send-done(%get-tuple-element.7), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"}
+  CHECK: ROOT %entry-result = select(%compare.1, %recv-data.3, %recv-data.4)
+  CHECK: })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr));
+  PipelinedP2PRewriter rewriter;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  DoFileCheck(module.get(), kExpected);
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc
new file mode 100644
index 0000000..9ceca0c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc
@@ -0,0 +1,886 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/priority_fusion.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/meta/type_traits.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/dump.h"
+#include "xla/service/fusion_queue.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/fusion_process_dump.pb.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/model/fusion_analysis_cache.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/model/gpu_performance_model.h"
+#include "xla/service/gpu/model/gpu_performance_model_base.h"
+#include "xla/service/gpu/model/symbolic_tile_analysis.h"
+#include "xla/service/hlo_graph_dumper.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/blocking_counter.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/threadpool.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
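+// Returns true if the element type of `shape` is F32 or F16.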
+bool ElementIsF32OrF16(const Shape& shape) {
+  PrimitiveType type = shape.element_type();
+  return type == F32 || type == F16;
+}
+
+bool IsFusible(const HloInstruction& instr) {
+  // Side-effecting operations are not fusible.
+  if (!instr.IsFusible()) {
+    return false;
+  }
+
+  // Element-wise operations are always fusible.
+  if (instr.IsElementwise()) {
+    return true;
+  }
+
+  // Other non-elementwise ops also supported by elemental fusion.
+  switch (instr.opcode()) {
+    case HloOpcode::kFusion:
+      return instr.fusion_kind() != HloInstruction::FusionKind::kCustom;
+
+    case HloOpcode::kCopy:
+    case HloOpcode::kIota:
+    case HloOpcode::kConstant:
+    case HloOpcode::kReduce:
+    case HloOpcode::kBitcast:
+    case HloOpcode::kBroadcast:
+    case HloOpcode::kConcatenate:
+    case HloOpcode::kDynamicSlice:
+    case HloOpcode::kDynamicUpdateSlice:
+    case HloOpcode::kGather:
+    case HloOpcode::kPad:
+    case HloOpcode::kReduceWindow:
+    case HloOpcode::kReshape:
+    case HloOpcode::kReverse:
+    case HloOpcode::kScatter:
+    case HloOpcode::kSlice:
+    case HloOpcode::kTranspose:
+      return true;
+    default:
+      return false;
+  }
+}
+
+// An implementation of FusionQueue that determines whether to fuse instructions
+// according to a cost model, and chooses the next fusion candidate according to
+// dynamically updated priorities. The elements in the queue are producer nodes
+// that could be fused, and the priority of a producer is the estimated
+// performance benefit of fusing it into all of its fusible users. We greedily
+// pick the max-benefit producer to fuse, and update the estimated benefits of
+// the fused nodes and their operands.
+class PriorityFusionQueue {
+  using Priority = int64_t;
+  using CanFuseCallback = std::function<FusionDecision(
+      HloInstruction* /*producer*/, int64_t /*consumer operand_index*/)>;
+
+ public:
+  PriorityFusionQueue(HloComputation* computation,
+                      const GpuHloCostAnalysis::Options& cost_analysis_options,
+                      const se::DeviceDescription* device_info,
+                      FusionProcessDumpProto* fusion_process_dump,
+                      tsl::thread::ThreadPool* thread_pool,
+                      mlir::MLIRContext* mlir_context,
+                      HloFusionAnalysisCache& fusion_analysis_cache,
+                      bool triton_softmax_priority_fusion_enabled)
+      : computation_(computation),
+        device_info_(device_info),
+        cost_analysis_(cost_analysis_options, *device_info),
+        fusion_process_dump_(fusion_process_dump),
+        thread_pool_(thread_pool),
+        mlir_context_(mlir_context),
+        fusion_analysis_cache_(fusion_analysis_cache),
+        triton_softmax_priority_fusion_enabled_(
+            triton_softmax_priority_fusion_enabled) {
+    VLOG(2) << "Running full HLO cost analysis for " << computation_->name();
+    TF_CHECK_OK(computation_->Accept(&cost_analysis_));
+
+    dump_fusion_visualization_ = computation->parent()
+                                     ->config()
+                                     .debug_options()
+                                     .xla_dump_fusion_visualization();
+
+    // Initializes the priority queue.
+    std::vector<HloInstruction*> instructions;
+    for (auto* instruction : computation->MakeInstructionPostOrder()) {
+      UpdatePerformanceModelCache(instruction);
+      if (instruction->opcode() == HloOpcode::kParameter ||
+          instruction->user_count() == 0 || !instruction->IsFusible() ||
+          instruction->opcode() == HloOpcode::kTuple ||
+          instruction->opcode() == HloOpcode::kGetTupleElement) {
+        continue;
+      }
+      instructions.push_back(instruction);
+    }
+    ComputeAndSetPriorities(instructions);
+  }
+
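+  // Computes priorities for `instructions` and inserts or updates them in the
+  // producer priority queue. Instructions whose priority is negative are not
+  // enqueued.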
+  void ComputeAndSetPriorities(
+      const std::vector<HloInstruction*>& instructions) {
+    std::vector<Priority> priorities = ComputePriorities(instructions);
+
+    for (auto [instruction, priority] : llvm::zip(instructions, priorities)) {
+      auto key = std::make_pair(priority, instruction->unique_id());
+
+      // Remove instruction with the old priority from the queue.
+      auto reverse_it = reverse_map_.find(instruction);
+      if (reverse_it != reverse_map_.end()) {
+        const PriorityQueue::iterator& queue_it = reverse_it->second;
+        // Priority didn't change. Nothing to do.
+        if (key == queue_it->first) {
+          continue;
+        }
+        producer_priority_queue_.erase(queue_it);
+        reverse_map_.erase(reverse_it);
+      }
+
+      // If the priority is negative, it's not helpful to perform fusion on this
+      // instruction.
+      if (priority < 0) {
+        continue;
+      }
+
+      auto emplace_result = producer_priority_queue_.emplace(key, instruction);
+      reverse_map_.emplace(instruction, emplace_result.first);
+    }
+  }
+
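+  // Computes the priority of each instruction in `instructions`, in parallel
+  // on `thread_pool_` when one is available.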
+  std::vector<Priority> ComputePriorities(
+      const std::vector<HloInstruction*>& instructions) {
+    auto schedule_or_run = [this](std::function<void()> fn) {
+      if (thread_pool_) {
+        thread_pool_->Schedule(std::move(fn));
+      } else {
+        fn();
+      }
+    };
+    tsl::BlockingCounter counter(instructions.size());
+    std::vector<Priority> priorities(instructions.size());
+
+    for (size_t i = 0; i < instructions.size(); ++i) {
+      schedule_or_run([&, i] {
+        priorities[i] = CalculateProducerPriority(instructions[i]);
+        counter.DecrementCount();
+      });
+    }
+    counter.Wait();
+    return priorities;
+  }
+
+  // Gets the next pair of (producer, consumers) from the queue for fusion.
+  // Returns true if there is a next producer to fuse, false otherwise. Stores
+  // the producer and consumers in `current_producer_` and `current_consumers_`.
+  bool DequeueNextProducer() {
+    current_producer_ = nullptr;
+    current_consumers_.clear();
+
+    while (!producer_priority_queue_.empty() && current_consumers_.empty()) {
+      auto next_it = std::prev(producer_priority_queue_.end());
+
+      current_producer_ = next_it->second;
+      producer_priority_queue_.erase(next_it);
+      reverse_map_.erase(current_producer_);
+
+      current_consumers_ = current_producer_->users();
+
+      if (current_producer_->opcode() == HloOpcode::kBitcast) {
+        // We don't check if bitcasts can be fused with all consumers, so we
+        // have to do it here.
+        llvm::erase_if(current_consumers_, [&](HloInstruction* consumer) {
+          return !CanFuseCached(current_producer_, consumer);
+        });
+      }
+    }
+
+    return !current_consumers_.empty();
+  }
+
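+  // Ensures that the performance model cache has a run-time estimate for
+  // `producer`, if it is fusible.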
+  void UpdatePerformanceModelCache(HloInstruction* producer) {
+    if (!IsFusible(*producer) && !IsGenericTritonFusion(*producer)) {
+      return;
+    }
+
+    auto config = GpuPerformanceModelOptions::PriorityFusion(
+        &fusion_analysis_cache_, &gpu_performance_model_cache_);
+
+    if (!gpu_performance_model_cache_.Get(*producer)) {
+      auto runtime_data = GpuPerformanceModel::EstimateRunTimeForInstruction(
+          producer, *device_info_, &cost_analysis_, config);
+      gpu_performance_model_cache_.Set(*producer, runtime_data);
+    }
+  }
+
+  // Updates priorities of all affected ops.
+  void UpdatePriorities() {
+    // Revisit costs of all updated ops. It's important to update cost analysis
+    // before recalculating priorities.
+    for (auto instruction : to_update_priority_) {
+      TF_CHECK_OK(cost_analysis_.RevisitInstruction(instruction));
+    }
+    for (auto producer : to_update_priority_) {
+      UpdatePerformanceModelCache(producer);
+    }
+
+    ComputeAndSetPriorities(std::vector<HloInstruction*>{
+        to_update_priority_.begin(), to_update_priority_.end()});
+
+    to_update_priority_.clear();
+  }
+
+  // Prepares the producer and consumer instructions to be fused. Invalidates
+  // caches and writes logs.
+  void PreFusion(HloInstruction* producer, HloInstruction* consumer) {
+    if (dump_fusion_visualization_) {
+      RegisterFusionState(
+          *computation_,
+          absl::StrCat("About to fuse |", producer->name(), "| into |",
+                       consumer->name(), "| inside PriorityFusion"),
+          *consumer, producer);
+    }
+    InvalidateCaches(consumer);
+  }
+
+  // Invalidates all cached values related to this instruction. Called before
+  // the instruction is fused. The instruction can be either the producer or a
+  // consumer.
+  void InvalidateCaches(HloInstruction* instruction) {
+    can_fuse_cache_.erase(instruction);
+    for (const HloInstruction* operand : instruction->operands()) {
+      auto it = can_fuse_cache_.find(operand);
+      if (it != can_fuse_cache_.end()) {
+        it->second.erase(instruction);
+      }
+    }
+
+    gpu_performance_model_cache_.Invalidate(*instruction);
+    fusion_analysis_cache_.Invalidate(*instruction);
+    fusion_info_cache_.Invalidate(instruction);
+  }
+
+  // Updates data for the new fusion instruction and its users and operands.
+  void OnFusingInstruction(HloInstruction* fusion,
+                           HloInstruction* original_producer,
+                           HloInstruction* original_consumer) {
+    if (fusion_process_dump_) {
+      auto* fusion_step =
+          fusion_process_dump_->add_fusion_steps()->mutable_fusion();
+
+      // Explicit std::string is needed for OSS proto implementation.
+      fusion_step->set_fusion_name(std::string(fusion->name()));
+      fusion_step->set_producer_name(std::string(original_producer->name()));
+      fusion_step->set_consumer_name(std::string(original_consumer->name()));
+    }
+
+    if (dump_fusion_visualization_) {
+      RegisterFusionState(
+          *computation_,
+          absl::StrCat("Fused |", original_producer->name(), "| into |",
+                       fusion->name(), "| inside PriorityFusion"),
+          *fusion);
+    }
+
+    // The original consumer was replaced with the fusion, but its pointer can
+    // still be referenced somewhere, for example, in to_update_priority_.
+    // Priority recomputation is called before DCE. Remove all references to
+    // the original consumer here.
+    if (fusion != original_consumer) {
+      RemoveInstruction(original_consumer);
+    }
+
+    // Detach 'original_producer' from its operands if it has no users.
+    // This avoids having it appear as a "phantom" user in subsequent priority
+    // calculations on 'fusion.operands' below, before it is finally removed
+    // in 'RemoveInstruction'.
+    if (original_producer->user_count() == 0) {
+      InvalidateCaches(original_producer);
+      original_producer->DetachFromOperandsAndUsers();
+    }
+
+    // Collect the instructions whose priorities need to be updated.
+    for (HloInstruction* operand : fusion->operands()) {
+      if (operand == original_producer ||
+          operand->opcode() == HloOpcode::kConstant ||
+          operand->opcode() == HloOpcode::kGetTupleElement) {
+        continue;
+      }
+      // We only need to consider instructions that are fusible, e.g., an rng
+      // with more than one user is not fusible.
+      if (!operand->IsFusible()) {
+        continue;
+      }
+
+      to_update_priority_.insert(operand);
+    }
+    to_update_priority_.insert(fusion);
+  }
+
+  // Removes data for the instruction.
+  void RemoveInstruction(HloInstruction* instruction) {
+    to_update_priority_.erase(instruction);
+    fusion_analysis_cache_.Invalidate(*instruction);
+
+    auto reverse_it = reverse_map_.find(instruction);
+    if (reverse_it == reverse_map_.end()) {
+      return;
+    }
+    producer_priority_queue_.erase(reverse_it->second);
+    reverse_map_.erase(reverse_it);
+  }
+
+  HloInstruction* current_producer() { return current_producer_; }
+
+  const std::vector<HloInstruction*>& current_consumers() {
+    return current_consumers_;
+  }
+
+ private:
+  // Returns the priority of the producer based on its current operands and
+  // users.
+  Priority CalculateProducerPriority(HloInstruction* producer) {
+    // Bitcasts should always be fused first, since they are no-ops.
+    if (producer->opcode() == HloOpcode::kBitcast) {
+      return std::numeric_limits<Priority>::max();
+    }
+    // We always fuse constants, but the cost model doesn't handle them very
+    // well: fusing constants changes costs significantly. Also, there's no
+    // point recomputing priorities. Therefore, we fuse all of them at the end.
+    if (producer->opcode() == HloOpcode::kConstant) {
+      return std::numeric_limits<Priority>::min();
+    }
+
+    // Don't fuse if we can't fuse in all users.
+    if (auto fusion_decision = CanFuseWithAllNonBitcastUsers(producer);
+        !fusion_decision) {
+      if (fusion_process_dump_) {
+        absl::MutexLock lock(&fusion_process_dump_mutex_);
+        auto* step = fusion_process_dump_->add_fusion_steps()
+                         ->mutable_producer_ineligible();
+        step->set_producer_name(std::string(producer->name()));
+        step->set_reason(fusion_decision.Explain());
+      }
+      return std::numeric_limits<Priority>::min();
+    }
+
+    GpuPerformanceModel::RunTimes run_times =
+        GpuPerformanceModel::EstimateRunTimesForPriorityFusion(
+            producer, *device_info_, &cost_analysis_,
+            GpuPerformanceModelOptions::PriorityFusion(
+                &fusion_analysis_cache_, &gpu_performance_model_cache_),
+            producer->users());
+
+    if (fusion_process_dump_) {
+      absl::MutexLock lock(&fusion_process_dump_mutex_);
+      auto* step =
+          fusion_process_dump_->add_fusion_steps()->mutable_update_priority();
+      step->set_producer_name(std::string(producer->name()));
+      for (auto* consumer : producer->users()) {
+        step->add_consumer_names(std::string(consumer->name()));
+      }
+      step->set_us_fused(absl::ToDoubleMicroseconds(run_times.time_fused));
+      step->set_us_unfused(absl::ToDoubleMicroseconds(run_times.time_unfused));
+    }
+    return absl::ToInt64Nanoseconds(run_times.time_unfused -
+                                    run_times.time_fused);
+  }
+
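+  // Decides whether `producer` and `consumer` can form a generic Triton
+  // fusion, i.e. whether SymbolicTileAnalysis can tile the fused computation.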
+  FusionDecision CanFuseTriton(HloInstruction* producer,
+                               HloInstruction* consumer) {
+    if (!triton_softmax_priority_fusion_enabled_) {
+      return "triton softmax fusion is not enabled";
+    }
+
+    if (IsGenericTritonFusion(*producer)) {
+      if (!IsFusible(*consumer)) {
+        return "the consumer is not fusible";
+      }
+    } else {
+      if (!IsFusible(*producer)) {
+        return "the producer is not fusible";
+      }
+    }
+
+    auto fusion = HloFusionAdaptor::ForProducerConsumer(producer, consumer);
+
+    SymbolicTileAnalysisOrError symbolic_tile_analysis_or =
+        SymbolicTileAnalysis::AnalyzeFusion(*fusion, mlir_context_);
+
+    if (const auto* fusion_decision =
+            std::get_if<FusionDecision>(&symbolic_tile_analysis_or)) {
+      return {
+          absl::StrCat("Fusion can not be tiled with SymbolicTileAnalysis: ",
+                       fusion_decision->Explain())};
+    }
+
+    return {};
+  }
+
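+  // Decides whether `producer` may be fused into `consumer`. On rejection, the
+  // returned FusionDecision carries an explanation.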
+  FusionDecision CanFuse(HloInstruction* producer, HloInstruction* consumer) {
+    if (IsGenericTritonFusion(*producer) || IsGenericTritonFusion(*consumer)) {
+      return CanFuseTriton(producer, consumer);
+    }
+
+    if (!IsFusible(*producer)) {
+      return "the producer is not fusible";
+    }
+
+    if (!IsFusible(*consumer)) {
+      return "the consumer is not fusible";
+    }
+
+    if (consumer->opcode() == HloOpcode::kBitcast) {
+      return "not fusing into a single bitcast as consumer";
+    }
+
+    // Scatter is special as it has no elemental version but is still input
+    // fusible. Block attempts to create scatter fusions we can't codegen.
+    if (auto can_fuse = CanEmitInputFusedScatter(*producer, *consumer);
+        !can_fuse) {
+      return can_fuse;
+    }
+
+    // Avoid fusing reduce into reduce. Our cost model doesn't currently
+    // understand this case due to a lack of tiling analysis.
+    // TODO(b/312200883): Remove this.
+    auto contains_significant_reduce = [&](const HloInstruction* instr) {
+      auto fusion = HloFusionAdaptor::ForInstruction(instr);
+      return HloAnyOf(*fusion, [](auto node) {
+        if (!(node.opcode() == HloOpcode::kReduce && node.shape().IsArray())) {
+          return false;
+        }
+
+        int64_t reduction_size =
+            ShapeUtil::ElementsIn(node.instruction().operand(0)->shape()) /
+            ShapeUtil::ElementsIn(node.shape());
+
+        // Small reductions are emitted using the elemental emitter anyway.
+        return reduction_size >= 16;
+      });
+    };
+    if (contains_significant_reduce(producer) &&
+        contains_significant_reduce(consumer)) {
+      return "both the producer and the consumer contain a reduce";
+    }
+
+    // Avoid doing fusions into the output of an "input" fusion when it would
+    // switch it to the loop emitter. This often occurs during epilogue fusion
+    // for reductions, which suffer from limited emitter support.
+    // TODO(b/312686229): Cost model should handle this.
+    const auto& analysis = fusion_analysis_cache_.Get(*producer);
+    if (analysis.GetEmitterFusionKind() ==
+        HloFusionAnalysis::EmitterFusionKind::kReduction) {
+      const auto& analysis_fused =
+          fusion_analysis_cache_.Get(*producer, *consumer);
+      if (analysis_fused.GetEmitterFusionKind() ==
+          HloFusionAnalysis::EmitterFusionKind::kLoop) {
+        return "fusion into output of a reduce fusion would create a loop "
+               "fusion";
+      }
+    }
+
+    // Avoid cases where we'd create a fusion that hits limitations in ptxas.
+    // It would be nice to model this in the cost model instead.
+    if (auto fits_budget = FusionFitsInBudget(
+            *consumer, *producer, *device_info_,
+            /*is_consumer_producer_fusion=*/true, &fusion_info_cache_);
+        !fits_budget) {
+      return fits_budget;
+    }
+
+    // Also check that our emitter can handle the fusion node. We currently can
+    // have exponential time/memory requirements for emitting certain fusion
+    // kernels, in which case we don't want to fuse.
+    // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
+    if (cost_analysis_.ProducerConsumerMergedTooLarge(*producer, *consumer)) {
+      return "the fusion would result in an overly large code duplication";
+    }
+
+    // Don't fuse across a root instruction. There are situations where the
+    // root instruction is not the last instruction in the computation.
+    // Instructions after the root are not necessarily dead; they can be inputs
+    // to instructions with side effects, like outfeed.
+    if (producer == producer->parent()->root_instruction()) {
+      return "not fusing into the output of the root instruction";
+    }
+
+    return InstructionFusion::ShouldFuseInPlaceOp(producer, consumer);
+  }
+
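+  // Memoized, thread-safe wrapper around CanFuse().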
+  FusionDecision CanFuseCached(HloInstruction* producer,
+                               HloInstruction* consumer) {
+    {
+      absl::MutexLock lock(&can_fuse_cache_mutex_);
+      auto& producer_cache = can_fuse_cache_[producer];
+
+      auto it = producer_cache.find(consumer);
+      if (it != producer_cache.end()) {
+        return it->second;
+      }
+    }
+    auto fusion_decision = CanFuse(producer, consumer);
+
+    // The lock is required because writing to a flat_hash_map is not
+    // thread-safe, even for different keys. We never run this computation
+    // concurrently for the same producer, so it's guaranteed that we don't
+    // overwrite any value.
+    {
+      absl::MutexLock lock(&can_fuse_cache_mutex_);
+      can_fuse_cache_[producer][consumer] = fusion_decision;
+    }
+
+    return fusion_decision;
+  }
+
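+  // Checks that `producer` can be fused into every one of its non-bitcast
+  // users. Producers with no users, or with only bitcast users, are rejected.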
+  FusionDecision CanFuseWithAllNonBitcastUsers(HloInstruction* producer) {
+    if (producer->users().empty()) {
+      return "No users to fuse";
+    }
+
+    FusionDecision result;
+    bool has_non_bitcast_user = false;
+    for (const auto& user : producer->users()) {
+      if (user->opcode() == HloOpcode::kBitcast) {
+        continue;
+      }
+      has_non_bitcast_user = true;
+      if (auto fusion_decision = CanFuseCached(producer, user);
+          !fusion_decision) {
+        VLOG(10) << "Cannot fuse " << producer->name() << " with "
+                 << user->name() << ", because: " << fusion_decision.Explain();
+        return fusion_decision;
+      }
+    }
+    if (!has_non_bitcast_user) {
+      return "not fusing because there are only bitcast users";
+    }
+    return {};
+  }
+
+  // Store computation for cost analysis.
+  HloComputation* computation_;
+
+  const se::DeviceDescription* device_info_;
+
+  // Reference to cost model that defines priorities in the queue.
+  GpuHloCostAnalysis cost_analysis_;
+
+  // The priority queue of producers, implemented as an ordered map, where a
+  // key is a pair: the first element is the priority and the second element is
+  // the unique ID of the instruction to break ties.
+  using PriorityQueue = std::map<std::pair<Priority, int>, HloInstruction*>;
+  PriorityQueue producer_priority_queue_;
+
+  // A reverse map that helps find an instruction in the priority queue.
+  absl::flat_hash_map<HloInstruction*, PriorityQueue::iterator> reverse_map_;
+
+  // The current producer being visited.
+  HloInstruction* current_producer_;
+
+  // The current consumers being visited.
+  std::vector<HloInstruction*> current_consumers_;
+
+  // The set of producers whose priorities need to be updated. Their
+  // priorities are changed because their neighbors got fused, but we delay
+  // the priority updates until current_consumers_ becomes empty. This is to
+  // avoid recomputing priorities multiple times before we dequeue a new
+  // producer.
+  absl::flat_hash_set<HloInstruction*> to_update_priority_;
+
+  // Proto with structured logs of fusion decisions. Used only for debugging. If
+  // null, logging is disabled.
+  FusionProcessDumpProto* fusion_process_dump_;
+  absl::Mutex fusion_process_dump_mutex_;
+
+  tsl::thread::ThreadPool* thread_pool_;
+
+  mlir::MLIRContext* mlir_context_;
+
+  HloFusionAnalysisCache& fusion_analysis_cache_;
+
+  // Caches result of can_fuse for a (producer, consumer) pair. A cache entry is
+  // invalidated if producer or consumer is modified.
+  absl::flat_hash_map<
+      const HloInstruction*,
+      absl::flat_hash_map<const HloInstruction*, FusionDecision>>
+      can_fuse_cache_;
+  absl::Mutex can_fuse_cache_mutex_;
+
+  GpuPerformanceModelCache gpu_performance_model_cache_;
+
+  // Cache for `FusionFitsInBudget` to avoid recomputing expensive properties
+  // like shared memory usage or number of unnested reductions of fusion nodes.
+  FusionInfoCache fusion_info_cache_;
+
+  bool triton_softmax_priority_fusion_enabled_;
+
+  bool dump_fusion_visualization_;
+};
+
+}  // namespace
+
+/*static*/ bool PriorityFusion::IsExpensive(const HloInstruction& instruction) {
+  // Some floating-point math ops are cheap on the GPU.
+  switch (instruction.opcode()) {
+    case HloOpcode::kDivide:
+    case HloOpcode::kSqrt:
+    case HloOpcode::kRsqrt:
+    case HloOpcode::kExp:
+      if (ElementIsF32OrF16(instruction.shape())) {
+        return false;
+      }
+      break;
+    // Loop fusions are cheap.
+    case HloOpcode::kFusion:
+      return false;
+    default:
+      break;
+  }
+  return InstructionFusion::IsExpensive(instruction);
+}
+
+// Returns true if `instr` is a small constant.
+//
+// There is no single definition of what a small constant is in XLA.
+// IrEmitterContext::emit_constant treats as small only constants of 1 element.
+// HloPrintOptions::print_large_constants is effective for constants larger
+// than 10 elements.
+//
+// This function matches the emitter logic.
+bool IsSmallConstant(const HloInstruction* instr) {
+  return instr->opcode() == HloOpcode::kConstant && instr->shape().IsArray() &&
+         ShapeUtil::ElementsIn(instr->shape()) <= 1;
+}
+
+bool PriorityFusion::ConsumeFuel(HloInstruction* producer,
+                                 HloInstruction* consumer) {
+  return xla::ConsumeFuel(name(), /*ran_out_of_fuel_msg=*/[&] {
+    return absl::StrFormat("Not fusing producer %s with consumer %s",
+                           producer->name(), consumer->name());
+  });
+}
+
+absl::StatusOr<bool> PriorityFusion::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool dump_enabled =
+      DumpingEnabledForHloPass(name(), module->config().debug_options());
+  if (dump_enabled) {
+    fusion_process_dump_ = std::make_unique<FusionProcessDumpProto>();
+    *fusion_process_dump_->mutable_gpu_device_info() =
+        device_info_.ToGpuProto();
+  }
+
+  // Collect the computations within which more fusion is possible.
+  auto fusible_computations =
+      GetFusibleComputations(*module, execution_threads);
+
+  // Appends ".0" suffix to all instructions.
+  //
+  // Every time an instruction is duplicated, the last integer suffix is
+  // incremented.
+  // Before: broadcast.123 -> broadcast.124
+  // After: broadcast.123.0 -> broadcast.123.1
+  //
+  // With this modification it will be easier to match instructions before and
+  // after fusion passes, because they will have the same unique prefix. Names
+  // are not used in the pipeline, but it makes debugging much easier.
+  for (auto* computation : fusible_computations) {
+    for (auto* instruction : computation->instructions()) {
+      module->SetAndUniquifyInstrName(instruction,
+                                      absl::StrCat(instruction->name(), ".0"));
+    }
+  }
+
+  if (dump_enabled) {
+    fusion_process_dump_->set_hlo_module_before_fusion(
+        module->ToString(HloPrintOptions::ShortParsable()));
+  }
+
+  bool triton_softmax_priority_fusion_enabled =
+      module->config()
+          .debug_options()
+          .xla_gpu_enable_triton_softmax_priority_fusion();
+
+  bool changed = false;
+  for (auto* computation : fusible_computations) {
+    CHECK(!computation->IsFusionComputation());
+
+    auto fusion_queue = std::make_unique<PriorityFusionQueue>(
+        computation, cost_analysis_options_, &device_info_,
+        fusion_process_dump_.get(), thread_pool_, &mlir_context_,
+        fusion_analysis_cache_, triton_softmax_priority_fusion_enabled);
+
+    while (fusion_queue->DequeueNextProducer()) {
+      auto producer = fusion_queue->current_producer();
+
+      for (auto* consumer : fusion_queue->current_consumers()) {
+        // Don't fuse into single bitcasts. We ignore them in the check
+        // CanFuseWithAllNonBitcastUsers(), so we need to check it here.
+        if (consumer->opcode() == HloOpcode::kBitcast) {
+          continue;
+        }
+        if (!ConsumeFuel(producer, consumer)) continue;
+
+        VLOG(5) << "next: " << consumer->name() << "(" << consumer << ") + "
+                << producer->name() << "(" << producer << ")";
+
+        fusion_queue->PreFusion(producer, consumer);
+        auto fusion_instruction = Fuse(producer, consumer, computation);
+        fusion_queue->OnFusingInstruction(fusion_instruction, producer,
+                                          consumer);
+
+        changed = true;
+      }
+
+      if (producer->user_count() == 0) {
+        fusion_queue->RemoveInstruction(producer);
+        // Remove from computation.
+        TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer));
+      }
+
+      fusion_queue->UpdatePriorities();
+    }
+
+    // Fuse all constants.
+    std::vector<HloInstruction*> constants;
+    for (auto* instruction : computation->instructions()) {
+      // Small constants should be fused, because they can be folded and
+      // codegened efficiently.
+      // Fusing large constants doesn't give much benefit, because they're
+      // treated like parameters and read from global memory anyway. Fusion
+      // and duplication of large constants can, however, cause problems if we
+      // want to dump HLO and parse it back, because in that case duplicated
+      // constants will be filled with different data.
+      if (IsSmallConstant(instruction)) {
+        constants.push_back(instruction);
+      }
+    }
+    for (auto* constant : constants) {
+      auto users = constant->users();
+      for (auto* user : users) {
+        if (IsFusible(*user) && CanEmitInputFusedScatter(*constant, *user)) {
+          Fuse(constant, user, computation);
+          changed = true;
+        }
+      }
+    }
+  }
+
+  // The FusionAnalysis cache uses unique_id as the key. IDs are only unique
+  // within one module, so it's important to fully clear the cache if the same
+  // instance of the pass is later run on a different module.
+  fusion_analysis_cache_.Clear();
+
+  if (dump_enabled) {
+    DumpPerModuleProtobufToFile(*module, *fusion_process_dump_,
+                                module->config().debug_options(),
+                                "priority_fusion_dump");
+  }
+
+  return changed;
+}
+
+FusionDecision PriorityFusion::ShouldFuse(HloInstruction* consumer,
+                                          int64_t operand_index) {
+  // This method is called in `InstructionFusion::Run` right before fusion, but
+  // it will always return true. Fusion decisions are fully controlled by the
+  // PriorityQueue. If the queue returns a producer that shouldn't be fused,
+  // it's a bug and should be fixed in the queue logic.
+  return {};
+}
+
+HloInstruction::FusionKind PriorityFusion::ChooseKind(
+    const HloInstruction* producer, const HloInstruction* consumer) {
+  // Derive kInput/kLoop fusion kinds from fusion analysis. This shouldn't
+  // matter, but some downstream passes still query the fusion kind instead of
+  // the fusion analysis.
+  const auto& analysis = fusion_analysis_cache_.Get(*producer, *consumer);
+  switch (analysis.GetEmitterFusionKind()) {
+    case HloFusionAnalysis::EmitterFusionKind::kLoop:
+      return HloInstruction::FusionKind::kLoop;
+    case HloFusionAnalysis::EmitterFusionKind::kTriton:
+    case HloFusionAnalysis::EmitterFusionKind::kCustomFusion:
+    case HloFusionAnalysis::EmitterFusionKind::kCuDnn:
+      return HloInstruction::FusionKind::kCustom;
+    case HloFusionAnalysis::EmitterFusionKind::kConcatenate:
+    case HloFusionAnalysis::EmitterFusionKind::kReduction:
+    case HloFusionAnalysis::EmitterFusionKind::kTranspose:
+    case HloFusionAnalysis::EmitterFusionKind::kInputSlices:
+    case HloFusionAnalysis::EmitterFusionKind::kScatter:
+      return HloInstruction::FusionKind::kInput;
+  }
+}
+
+HloInstruction* PriorityFusion::FuseInstruction(
+    HloInstruction* fusion_instruction, HloInstruction* producer) {
+  HloInstruction* result = fusion_instruction;
+  if (producer->opcode() == HloOpcode::kFusion) {
+    if (IsGenericTritonFusion(*producer)) {
+      TF_CHECK_OK(fusion_instruction->set_backend_config(
+          *producer->backend_config<GpuBackendConfig>()));
+    }
+
+    fusion_instruction->MergeFusionInstruction(producer);
+  } else {
+    result = InstructionFusion::FuseInstruction(fusion_instruction, producer);
+  }
+  return result;
+}
+
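+// PriorityFusion::Run drives fusion with its own PriorityFusionQueue, so the
+// base-class fusion queue is not used.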
+std::unique_ptr<FusionQueue> PriorityFusion::GetFusionQueue(
+    HloComputation* computation) {
+  return nullptr;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.h b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h
new file mode 100644
index 0000000..fce2be5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h
@@ -0,0 +1,100 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_
+
+#include <stdint.h>
+
+#include <memory>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/fusion_queue.h"
+#include "xla/service/gpu/fusion_process_dump.pb.h"
+#include "xla/service/gpu/model/fusion_analysis_cache.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/stream_executor/device_description.h"
+#include "tsl/platform/threadpool.h"
+
+namespace xla {
+namespace gpu {
+
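+// GPU fusion pass that greedily fuses producers into their consumers, picking
+// candidates from a priority queue ordered by the estimated performance
+// benefit of each fusion (see PriorityFusionQueue in priority_fusion.cc).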
+class PriorityFusion : public InstructionFusion {
+ public:
+  PriorityFusion(tsl::thread::ThreadPool* thread_pool,
+                 const se::DeviceDescription& device,
+                 GpuHloCostAnalysis::Options cost_analysis_options)
+      : InstructionFusion(PriorityFusion::IsExpensive),
+        thread_pool_(thread_pool),
+        device_info_(device),
+        cost_analysis_options_(std::move(cost_analysis_options)),
+        fusion_analysis_cache_(device_info_) {}
+
+  absl::string_view name() const override { return "priority-fusion"; }
+
+  static bool IsExpensive(const HloInstruction& instruction);
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ protected:
+  std::unique_ptr<FusionQueue> GetFusionQueue(
+      HloComputation* computation) override;
+
+  FusionDecision ShouldFuse(HloInstruction* consumer,
+                            int64_t operand_index) override;
+
+  HloInstruction::FusionKind ChooseKind(
+      const HloInstruction* producer, const HloInstruction* consumer) override;
+
+ private:
+  HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
+                                  HloInstruction* producer) override;
+
+  // Consumes a unit of compiler fuel and returns true if we should
+  // continue with the transformation.
+  bool ConsumeFuel(HloInstruction* producer, HloInstruction* consumer);
+
+  tsl::thread::ThreadPool* thread_pool_;
+  se::DeviceDescription device_info_;
+
+  // Cost model options that define priorities in the queue.
+  GpuHloCostAnalysis::Options cost_analysis_options_;
+
+  // Proto with structured logs of fusion decisions. Used only for debugging. If
+  // null, logging is disabled.
+  std::unique_ptr<FusionProcessDumpProto> fusion_process_dump_;
+
+  HloFusionAnalysisCache fusion_analysis_cache_;
+
+  mlir::MLIRContext mlir_context_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc
new file mode 100644
index 0000000..4abd182
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc
@@ -0,0 +1,942 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/priority_fusion.h"
+
+#include <stdint.h>
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace m = ::xla::match;
+
+using ::testing::UnorderedElementsAre;
+using ::tsl::testing::IsOk;
+using ::tsl::testing::IsOkAndHolds;
+
+namespace xla {
+namespace gpu {
+
+class PriorityFusionTest : public HloTestBase {
+  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const {
+    return [&](const Shape& shape) {
+      constexpr int64_t kPointerSize = 8;
+      return ShapeUtil::ByteSizeOf(shape, kPointerSize);
+    };
+  }
+
+ public:
+  std::vector<HloFusionAnalysis::EmitterFusionKind> RunAndGetFusionKinds(
+      absl::string_view hlo) {
+    auto module = ParseAndReturnVerifiedModule(hlo).value();
+    EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
+    EXPECT_THAT(module->RemoveUnusedComputations(), IsOk());
+    std::vector<HloFusionAnalysis::EmitterFusionKind> kinds;
+    for (auto computation : module->computations()) {
+      if (!computation->FusionInstruction()) continue;
+
+      auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+      auto analysis = HloFusionAnalysis::Create(
+          *computation->FusionInstruction(), device_info);
+      kinds.push_back(analysis.GetEmitterFusionKind());
+    }
+    return kinds;
+  }
+
+  PriorityFusion priority_fusion_{
+      /*thread_pool=*/nullptr, TestGpuDeviceInfo::RTXA6000DeviceInfo(),
+      GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(),
+                                  /*per_second_rates=*/{},
+                                  /*count_multiple_input_accesses=*/true}};
+};
+
+TEST_F(PriorityFusionTest, FuseWithSharedArgument) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    ENTRY main {
+      %p0 = f32[] parameter(0)
+      %p1 = f32[] parameter(1)
+      %subtract = f32[] subtract(%p0, %p1)
+      %compare = pred[] compare(%subtract, %subtract), direction=NE
+      %add = f32[] add(%p0, %p1)
+      %abs = f32[] abs(%subtract)
+      ROOT %select = f32[] select(%compare, %add, %abs)
+    })")
+                    .value();
+
+  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::Fusion()));
+  EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
+}
+
+TEST_F(PriorityFusionTest, FusionFusionWithDuplication) {
+  absl::string_view kHlo = R"(
+    HloModule test_module
+
+    square {
+      p = f32[16384]{0} parameter(0)
+      ROOT m = f32[16384]{0} multiply(p, p)
+    }
+
+    exp {
+      p = f32[16384]{0} parameter(0)
+      ROOT e = f32[16384]{0} exponential(p)
+    }
+
+    log {
+      p = f32[16384]{0} parameter(0)
+      ROOT l = f32[16384]{0} log(p)
+    }
+
+    ENTRY main {
+      p = f32[16384]{0} parameter(0)
+      s = f32[16384]{0} fusion(p), kind=kLoop, calls=square
+      e = f32[16384]{0} fusion(s), kind=kLoop, calls=exp
+      l = f32[16384]{0} fusion(s), kind=kInput, calls=log
+      ROOT t = (f32[16384], f32[16384]) tuple(l, e)
+    })";
+
+  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ENTRY
+CHECK-NEXT: %[[PARAM:.*]] = f32[16384]{0} parameter(0)
+CHECK-NEXT: %[[FUSION_0:.*]] = f32[16384]{0} fusion(%[[PARAM]])
+CHECK-NEXT: %[[FUSION_1:.*]] = f32[16384]{0} fusion(%[[PARAM]])
+CHECK-NEXT: ROOT {{.*}} tuple(%[[FUSION_0]], %[[FUSION_1]])
+  )");
+}
+
+TEST_F(PriorityFusionTest, FuseBroadcastIntoBitcastConsumers) {
+  absl::string_view kHlo = R"(
+    HloModule test_module
+
+    ENTRY main {
+      param_0 = f32[96]{0} parameter(0)
+      broadcast = f32[8,96,128,7]{3,2,1,0} broadcast(param_0), dimensions={1}
+      bitcast.6079.2 = f32[8,24,4,128,7]{4,3,2,1,0} bitcast(broadcast)
+      ROOT transpose.1990.2 = f32[8,24,128,7,4]{4,3,2,1,0} transpose(bitcast.6079.2), dimensions={0,1,3,4,2}
+    }
+  )";
+  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK:      ENTRY
+CHECK-NEXT: %[[PARAM:.*]] = f32[96]{0} parameter(0)
+CHECK-NEXT: ROOT %{{.*}} fusion(%[[PARAM]])
+  )");
+}
+
+TEST_F(PriorityFusionTest, FuseWideningConvertIntoConsumers) {
+  absl::string_view kHlo = R"(
+    HloModule test_module
+
+    ENTRY main {
+      p = f16[512]{0} parameter(0)
+      a = f16[512]{0} add(p, p)
+      c = f32[512]{0} convert(a)
+      s = f32[512]{0} multiply(c, c)
+      bc = s32[512]{0} bitcast(c)
+      ROOT t = (f32[512], s32[512]) tuple(s, bc)
+    })";
+
+  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK:      ENTRY
+CHECK-NEXT: %[[PARAM:.*]] = f16[512]{0} parameter(0)
+CHECK-NEXT: %[[FUSION_F32:.*]] = f32[512]{0} fusion(%[[PARAM]])
+CHECK-NEXT: %[[CONVERT_FUSION:.*]] = f32[512]{0} fusion(%[[PARAM]])
+CHECK-NEXT: %[[BITCAST:.*]] = s32[512]{0} bitcast(%[[CONVERT_FUSION]])
+CHECK-NEXT: ROOT %{{.*}} = (f32[512]{0}, s32[512]{0}) tuple(%[[FUSION_F32]], %[[BITCAST]])
+  )");
+}
+
+TEST_F(PriorityFusionTest, FuseConvertIntoReduce) {
+  absl::string_view kHlo = R"(
+    HloModule test_module
+
+    add {
+      p0 = f32[] parameter(0)
+      p1 = f32[] parameter(1)
+      ROOT add.13235 = f32[] add(p0, p1)
+    }
+
+    ENTRY main {
+      param_0_0.79 = bf16[1024,8192]{1,0} parameter(0)
+      param_1_0.79 = bf16[1024,8192]{1,0} parameter(1)
+      param_2.483 = f32[8192]{0} parameter(2)
+      param_4.2892 = bf16[1024,8192]{1,0} parameter(3)
+      convert.21854 = f32[1024,8192]{1,0} convert(param_0_0.79)
+      convert.21855 = f32[1024,8192]{1,0} convert(param_1_0.79)
+      constant_7773 = f32[] constant(0)
+      broadcast.14555 = f32[1024,8192]{1,0} broadcast(param_2.483), dimensions={1}
+      multiply.6906 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21854)
+      reduce.4813 = f32[1024]{0} reduce(multiply.6906, constant_7773), dimensions={1}, to_apply=add
+      convert.13970 = bf16[1024]{0} convert(reduce.4813)
+      convert.21534 = f32[1024,8192]{1,0} convert(param_4.2892)
+      multiply.6910.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21534)
+      reduce.4811.clone.1 = f32[1024]{0} reduce(multiply.6910.clone.1, constant_7773), dimensions={1}, to_apply=add
+      convert.13967.clone.1 = bf16[1024]{0} convert(reduce.4811.clone.1)
+      multiply.6908.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21855)
+      reduce.4812.clone.1 = f32[1024]{0} reduce(multiply.6908.clone.1, constant_7773), dimensions={1}, to_apply=add
+      convert.13969.clone.1 = bf16[1024]{0} convert(reduce.4812.clone.1)
+      ROOT fusion.241 = (bf16[1024]{0}, bf16[1024]{0}, bf16[1024]{0}) tuple(convert.13970, convert.13967.clone.1, convert.13969.clone.1)
+    })";
+
+  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK-COUNT-3: ROOT {{.*}} convert(
+CHECK: ENTRY %main
+CHECK-COUNT-3: fusion
+  )");
+}
+
+TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) {
+  // Regression test for epilogue fusion of convert into a reduction, even if
+  // the convert has a bitcast as consumer.
+  absl::string_view kHlo = R"(
+    HloModule test_module
+
+    add {
+      rhs.407 = f32[] parameter(1)
+      lhs.407 = f32[] parameter(0)
+      ROOT add.24451 = f32[] add(lhs.407, rhs.407)
+    }
+
+    ENTRY main {
+      param_1.15162 = f32[2752]{0} parameter(1)
+      convert.44829 = bf16[2752]{0} convert(param_1.15162)
+      bitcast.24686 = bf16[1,1,2752]{2,1,0} bitcast(convert.44829)
+      convert.44468 = f32[1,1,2752]{2,1,0} convert(bitcast.24686)
+      constant_13722 = bf16[] constant(1)
+      convert.17451 = f32[] convert(constant_13722)
+      broadcast.17565 = f32[1,1,2752]{2,1,0} broadcast(convert.17451), dimensions={}
+      negate.167 = f32[1,1,2752]{2,1,0} negate(convert.44468)
+      exponential.569 = f32[1,1,2752]{2,1,0} exponential(negate.167)
+      add.1850 = f32[1,1,2752]{2,1,0} add(broadcast.17565, exponential.569)
+      divide.1376 = f32[1,1,2752]{2,1,0} divide(broadcast.17565, add.1850)
+      multiply.9709 = f32[1,1,2752]{2,1,0} multiply(convert.44468, divide.1376)
+      param_0.15005 = f32[2752]{0} parameter(0)
+      convert.44826 = bf16[2752]{0} convert(param_0.15005)
+      bitcast.24683 = bf16[1,1,2752]{2,1,0} bitcast(convert.44826)
+      convert.44467 = f32[1,1,2752]{2,1,0} convert(bitcast.24683)
+      multiply.9708 = f32[1,1,2752]{2,1,0} multiply(multiply.9709, convert.44467)
+      convert.16959 = bf16[1,1,2752]{2,1,0} convert(multiply.9708)
+      fusion.3203 = bf16[2752]{0} bitcast(convert.16959)
+      convert.15093 = f32[2752]{0} convert(fusion.3203)
+      broadcast.13841 = f32[8192,2752]{1,0} broadcast(convert.15093), dimensions={1}
+      param_0.15525 = bf16[8192,2752]{1,0} parameter(2)
+      convert.13738 = f32[8192,2752]{1,0} convert(param_0.15525)
+      multiply.6422 = f32[8192,2752]{1,0} multiply(broadcast.13841, convert.13738)
+      constant_14382 = f32[] constant(0)
+      fusion.339 = f32[8192]{0} reduce(multiply.6422, constant_14382), dimensions={1}, to_apply=add
+      convert.44633 = bf16[8192]{0} convert(fusion.339)
+      ROOT bitcast.24487 = bf16[1,1,8192]{2,1,0} bitcast(convert.44633)
+    }
+                                                     )";
+
+  EXPECT_THAT(
+      RunAndGetFusionKinds(kHlo),
+      UnorderedElementsAre(HloFusionAnalysis::EmitterFusionKind::kLoop,
+                           HloFusionAnalysis::EmitterFusionKind::kReduction));
+
+  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ENTRY
+CHECK: ROOT {{.*}} bitcast({{.*}}fusion{{.*}})
+  )");
+}
+
+TEST_F(PriorityFusionTest, DoNotChangeReductionFusionToLoopFusion) {
+  // Regression test for epilogue fusion of a slice into a reduction. The
+  // fusion kind of the reduction fusion is intentionally set to kLoop, as we
+  // cannot rely on reductions always having fusion kind kInput.
+  auto module = *ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      rhs.407 = f32[] parameter(1)
+      lhs.407 = f32[] parameter(0)
+      ROOT add.24451 = f32[] add(lhs.407, rhs.407)
+    }
+
+    fused_computation {
+      p0 = f32[16,64]{1,0} parameter(0)
+      zero = f32[] constant(0.0)
+      ROOT reduce = f32[16]{0} reduce(p0, zero), dimensions={1}, to_apply=add
+    }
+
+    ENTRY main {
+      param0 = f32[16,64]{1,0} parameter(0)
+      fusion = f32[16]{0} fusion(param0), kind=kLoop, calls=fused_computation
+      ROOT slice = f32[8]{0} slice(fusion), slice={[0:8]}
+    })");
+  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) {
+  absl::string_view kHlo = R"(
+    HloModule test_module
+
+    add {
+      Arg_1.1046 = f32[] parameter(1)
+      Arg_0.1045 = f32[] parameter(0)
+      ROOT add.3303 = f32[] add(Arg_0.1045, Arg_1.1046)
+    }
+
+    ENTRY main {
+      param_0.17323 = pred[2048,2048]{1,0} parameter(0)
+      broadcast.22829 = pred[1,12,2048,2048]{3,2,1,0} broadcast(param_0.17323), dimensions={2,3}
+      param_1.19761 = bf16[2048,24576]{1,0} parameter(1)
+      convert.29880.clone.1 = f32[2048,24576]{1,0} convert(param_1.19761)
+      constant_10033_clone_1 = bf16[] constant(0.02002)
+      convert.30056.clone.1 = f32[] convert(constant_10033_clone_1)
+      broadcast.18898.clone.1 = f32[2048,24576]{1,0} broadcast(convert.30056.clone.1), dimensions={}
+      multiply.13451.clone.1 = f32[2048,24576]{1,0} multiply(convert.29880.clone.1, broadcast.18898.clone.1)
+      tanh.798.clone.1 = f32[2048,24576]{1,0} tanh(multiply.13451.clone.1)
+      constant_10244_clone_1 = bf16[] constant(50)
+      convert.30039.clone.1 = f32[] convert(constant_10244_clone_1)
+      broadcast.18310.clone.1 = f32[2048,24576]{1,0} broadcast(convert.30039.clone.1), dimensions={}
+      multiply.12550.clone.1 = f32[2048,24576]{1,0} multiply(tanh.798.clone.1, broadcast.18310.clone.1)
+      convert.29370.clone.1 = bf16[2048,24576]{1,0} convert(multiply.12550.clone.1)
+      bitcast.22330 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1)
+      transpose.6582 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22330), dimensions={0,3,2,1}
+      convert.33705 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6582)
+      constant_10212 = f32[] constant(-2.38197633e+38)
+      broadcast.22828 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_10212), dimensions={}
+      select.589 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22829, convert.33705, broadcast.22828)
+      bitcast.22075 = f32[12,2048,2048]{2,1,0} bitcast(select.589)
+      constant_10192 = f32[] constant(-inf)
+      reduce.1614 = f32[12,2048]{1,0} reduce(bitcast.22075, constant_10192), dimensions={2}, to_apply=add
+
+      predarg = pred[1,1,2048,2048]{3,2,1,0} parameter(2)
+      bitcast.11069 = pred[2048,2048]{1,0} bitcast(predarg)
+
+      broadcast.22825 = pred[1,12,2048,2048]{3,2,1,0} broadcast(bitcast.11069), dimensions={2,3}
+      bitcast.22331 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1)
+      transpose.6580 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22331), dimensions={0,3,2,1}
+      convert.33703 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6580)
+      constant_10213 = f32[] constant(-2.38197633e+38)
+      broadcast.22824 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_10213), dimensions={}
+      select.587 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22825, convert.33703, broadcast.22824)
+      broadcast.22819 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1614), dimensions={1,2}
+      subtract.1129 = f32[1,12,2048,2048]{3,2,1,0} subtract(select.587, broadcast.22819)
+      exponential.418 = f32[1,12,2048,2048]{3,2,1,0} exponential(subtract.1129)
+      bitcast.22074 = f32[12,2048,2048]{2,1,0} bitcast(exponential.418)
+      constant_10490 = f32[] constant(0)
+      reduce.1613 = f32[12,2048]{1,0} reduce(bitcast.22074, constant_10490), dimensions={2}, to_apply=add
+
+      constant_468 = f32[] constant(-2.38197633e+38)
+      broadcast.22833 = pred[1,12,2048,2048]{3,2,1,0} broadcast(bitcast.11069), dimensions={2,3}
+      bitcast.22332 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1)
+      transpose.6584 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22332), dimensions={0,3,2,1}
+      convert.33707 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6584)
+      broadcast.22832 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_468), dimensions={}
+      select.591 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22833, convert.33707, broadcast.22832)
+      broadcast.22821 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1614), dimensions={1,2}
+      subtract.1131 = f32[1,12,2048,2048]{3,2,1,0} subtract(select.591, broadcast.22821)
+      exponential.420 = f32[1,12,2048,2048]{3,2,1,0} exponential(subtract.1131)
+      broadcast.18351 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1613), dimensions={1,2}
+      divide.340 = f32[1,12,2048,2048]{3,2,1,0} divide(exponential.420, broadcast.18351)
+      ROOT convert.29418 = bf16[1,12,2048,2048]{3,2,1,0} convert(divide.340)
+    })";
+
+  using Kind = HloFusionAnalysis::EmitterFusionKind;
+  EXPECT_THAT(
+      RunAndGetFusionKinds(kHlo),
+      UnorderedElementsAre(Kind::kLoop, Kind::kLoop, Kind::kLoop,
+                           Kind::kReduction, Kind::kReduction, Kind::kTranspose,
+                           Kind::kTranspose, Kind::kTranspose));
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduce) {
+  absl::string_view kHlo = R"(
+    HloModule test_module
+
+    add {
+      p0 = f32[] parameter(0)
+      p1 = f32[] parameter(1)
+      ROOT add.13235 = f32[] add(p0, p1)
+    }
+
+    ENTRY main {
+      p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
+      c0 = f32[] constant(0)
+      r0 = f32[8,4,128]{2,1,0} reduce(p0, c0), dimensions={3}, to_apply=add
+      ROOT r1 = f32[8,4]{1,0} reduce(r0, c0), dimensions={2}, to_apply=add
+    })";
+
+  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ROOT {{.*}} reduce(
+CHECK: ROOT {{.*}} reduce(
+  )");
+}
+
+TEST_F(PriorityFusionTest, ConvertFusedIntoReduce) {
+  absl::string_view kHlo = R"(
+    HloModule test_module
+
+    add {
+      p0 = f32[] parameter(0)
+      p1 = f32[] parameter(1)
+      ROOT add.13235 = f32[] add(p0, p1)
+    }
+
+    ENTRY main {
+      param_0_0.79 = bf16[1024,8192]{1,0} parameter(0)
+      param_1_0.79 = bf16[1024,8192]{1,0} parameter(1)
+      param_2.483 = f32[8192]{0} parameter(2)
+      param_4.2892 = bf16[1024,8192]{1,0} parameter(3)
+      convert.21854 = f32[1024,8192]{1,0} convert(param_0_0.79)
+      convert.21855 = f32[1024,8192]{1,0} convert(param_1_0.79)
+      constant_7773 = f32[] constant(0)
+      broadcast.14555 = f32[1024,8192]{1,0} broadcast(param_2.483), dimensions={1}
+      multiply.6906 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21854)
+      reduce.4813 = f32[1024]{0} reduce(multiply.6906, constant_7773), dimensions={1}, to_apply=add
+      convert.13970 = bf16[1024]{0} convert(reduce.4813)
+      convert.21534 = f32[1024,8192]{1,0} convert(param_4.2892)
+      multiply.6910.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21534)
+      reduce.4811.clone.1 = f32[1024]{0} reduce(multiply.6910.clone.1, constant_7773), dimensions={1}, to_apply=add
+      convert.13967.clone.1 = bf16[1024]{0} convert(reduce.4811.clone.1)
+      multiply.6908.clone.1 = f32[1024,8192]{1,0} multiply(broadcast.14555, convert.21855)
+      reduce.4812.clone.1 = f32[1024]{0} reduce(multiply.6908.clone.1, constant_7773), dimensions={1}, to_apply=add
+      convert.13969.clone.1 = bf16[1024]{0} convert(reduce.4812.clone.1)
+      ROOT fusion.241 = (bf16[1024]{0}, bf16[1024]{0}, bf16[1024]{0}) tuple(convert.13970, convert.13967.clone.1, convert.13969.clone.1)
+    })";
+
+  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK-COUNT-3: ROOT {{.*}} convert(
+CHECK: ENTRY %main
+CHECK-COUNT-3: fusion(
+CHECK-NOT: fusion(
+  )");
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseDynamicUpdateSliceIntoReduce) {
+  GTEST_SKIP() << "b/294198633";
+  absl::string_view kHlo = R"(
+    HloModule test_module
+
+add {
+  Arg_1.1046 = f32[] parameter(1)
+  Arg_0.1045 = f32[] parameter(0)
+  ROOT add.3303 = f32[] add(Arg_0.1045, Arg_1.1046)
+}
+
+ENTRY main {
+  param_0.10549 = f32[4,2112]{1,0} parameter(0)
+  param_5.2561 = pred[] parameter(5)
+  broadcast.19725 = pred[4,1]{1,0} broadcast(param_5.2561), dimensions={}
+  param_1.11587 = pred[4]{0} parameter(1)
+  constant_5837 = f32[] constant(1)
+  broadcast.19723 = f32[4]{0} broadcast(constant_5837), dimensions={}
+  param_2.5952 = f32[4,8000]{1,0} parameter(2)
+  param_3.4004 = f32[4]{0} parameter(3)
+  broadcast.19718 = f32[4,8000]{1,0} broadcast(param_3.4004), dimensions={0}
+  subtract.1112 = f32[4,8000]{1,0} subtract(param_2.5952, broadcast.19718)
+  exponential.418 = f32[4,8000]{1,0} exponential(subtract.1112)
+  constant_6254 = f32[] constant(0)
+  reduce.1154 = f32[4]{0} reduce(exponential.418, constant_6254), dimensions={1}, to_apply=add
+  log.38 = f32[4]{0} log(reduce.1154)
+  broadcast.19717 = f32[4,8000]{1,0} broadcast(log.38), dimensions={0}
+  subtract.1111 = f32[4,8000]{1,0} subtract(subtract.1112, broadcast.19717)
+  iota.170 = s32[4,1]{1,0} iota(), iota_dimension=0
+  constant_6281 = s32[] constant(0)
+  broadcast.19735 = s32[4]{0} broadcast(constant_6281), dimensions={}
+  param_4.3400 = s32[4,8000]{1,0} parameter(4)
+  slice.3186 = s32[4,40]{1,0} slice(param_4.3400), slice={[0:4], [0:40]}
+  iota.168 = s32[4,1]{1,0} iota(), iota_dimension=0
+  param_7.1596 = s32[4]{0} parameter(7)
+  compare.341 = pred[4]{0} compare(param_7.1596, broadcast.19735), direction=LT
+  constant_5833 = s32[] constant(40)
+  broadcast.19731 = s32[4]{0} broadcast(constant_5833), dimensions={}
+  add.8348 = s32[4]{0} add(param_7.1596, broadcast.19731)
+  select.418 = s32[4]{0} select(compare.341, add.8348, param_7.1596)
+  bitcast.20942 = s32[4,1]{1,0} bitcast(select.418)
+  concatenate.1337 = s32[4,2]{1,0} concatenate(iota.168, bitcast.20942), dimensions={1}
+  gather.43 = s32[4,1,1]{2,1,0} gather(slice.3186, concatenate.1337), offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}
+  bitcast.20941 = s32[4]{0} bitcast(gather.43)
+  select.398 = s32[4]{0} select(param_1.11587, broadcast.19735, bitcast.20941)
+  compare.334 = pred[4]{0} compare(select.398, broadcast.19735), direction=LT
+  constant_6260 = s32[] constant(8000)
+  broadcast.19720 = s32[4]{0} broadcast(constant_6260), dimensions={}
+  add.8336 = s32[4]{0} add(select.398, broadcast.19720)
+  select.396 = s32[4]{0} select(compare.334, add.8336, select.398)
+  bitcast.20830 = s32[4,1]{1,0} bitcast(select.396)
+  concatenate.1308 = s32[4,2]{1,0} concatenate(iota.170, bitcast.20830), dimensions={1}
+  gather.41 = f32[4,1,1]{2,1,0} gather(subtract.1111, concatenate.1308), offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1}
+  bitcast.20824 = f32[4]{0} bitcast(gather.41)
+  select.389 = f32[4]{0} select(param_1.11587, broadcast.19723, bitcast.20824)
+  bitcast.20823 = f32[4,1]{1,0} bitcast(select.389)
+  param_6.1719 = s32[] parameter(6)
+  constant_6323 = s32[] constant(2048)
+  add.8549 = s32[] add(param_6.1719, constant_6323)
+  compare.388 = pred[] compare(add.8549, constant_6281), direction=LT
+  constant_5436 = s32[] constant(4160)
+  add.8339 = s32[] add(param_6.1719, constant_5436)
+  select.409 = s32[] select(compare.388, add.8339, add.8549)
+  dynamic-slice.36 = f32[4,1]{1,0} dynamic-slice(param_0.10549, constant_6281, select.409), dynamic_slice_sizes={4,1}
+  select.388 = f32[4,1]{1,0} select(broadcast.19725, bitcast.20823, dynamic-slice.36)
+  ROOT dynamic-update-slice.307 = f32[4,2112]{1,0} dynamic-update-slice(param_0.10549, select.388, constant_6281, select.409)
+})";
+
+  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ROOT {{.*}} dynamic-update-slice(
+CHECK: %[[REDUCE:.*]] = {{.*}} reduce(
+CHECK: ROOT {{.*}} log(%[[REDUCE]])
+CHECK: ENTRY
+CHECK-COUNT-2: fusion(
+  )");
+}
+
+TEST_F(PriorityFusionTest, DontFuseIntoFirstOperandOfScatter) {
+  auto module = *ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      lhs = s32[] parameter(0)
+      rhs = s32[] parameter(1)
+      ROOT add = s32[] add(lhs, rhs)
+    }
+
+    ENTRY FuseIntoScatter {
+      p0 = s32[3,3] parameter(0)
+      operand = s32[3,3] add(p0, p0)
+      p1 = s32[2] parameter(1)
+      indices = s32[2] add(p1, p1)
+      p2 = s32[2,3] parameter(2)
+      updates = s32[2,3] add(p2, p2)
+      scatter = s32[3,3] scatter(operand, indices, updates),
+          to_apply=add,
+          update_window_dims={1},
+          inserted_window_dims={0},
+          scatter_dims_to_operand_dims={0},
+          index_vector_dim=1
+      ROOT add = s32[3,3] add(scatter, scatter)
+    })");
+
+  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloInstruction* fusion = nullptr;
+  ASSERT_THAT(root, GmockMatch(m::Add(m::Fusion(&fusion), m::Fusion())));
+  EXPECT_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kInput);
+  EXPECT_THAT(fusion->fused_expression_root(),
+              GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
+}
+
+// This test is similar to DontFuseIntoFirstOperandOfScatter, but PriorityFusion
+// has a separate run to fuse constants. Fusing anything into a scatter fusion
+// will fail in the emitter.
+TEST_F(PriorityFusionTest, DontFuseConstantIntoFirstOperandOfScatter) {
+  auto module = *ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      lhs = s32[] parameter(0)
+      rhs = s32[] parameter(1)
+      ROOT add = s32[] add(lhs, rhs)
+    }
+
+    ENTRY FuseIntoScatter {
+      operand = s32[1] constant({0})
+      indices = s32[24,1] parameter(0)
+      constant = s32[] constant(1)
+      updates = s32[24,1] broadcast(constant)
+      ROOT scatter = s32[1] scatter(operand, indices, updates),
+          to_apply=add,
+          update_window_dims={1},
+          inserted_window_dims={},
+          scatter_dims_to_operand_dims={0},
+          index_vector_dim=1
+    })");
+
+  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, GmockMatch(m::Fusion(m::Constant(), m::Parameter())));
+  EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kInput);
+  EXPECT_THAT(root->fused_expression_root(),
+              GmockMatch(m::Scatter(m::Parameter(), m::Parameter(),
+                                    m::Broadcast(m::Constant()))));
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduceEvenIfOccupancyIsHigh) {
+  constexpr absl::string_view kHlo = R"(
+    HloModule test_module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY main {
+      p0 = f32[4,3584,128,168]{3,2,1,0} parameter(0)
+      c = f32[] constant(0)
+      r1 = f32[4,3584,128]{2,1,0} reduce(p0, c), dimensions={3}, to_apply=add
+      ROOT r2 = f32[4,3584]{1,0} reduce(r1, c), dimensions={2}, to_apply=add
+    })";
+
+  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ROOT {{.*}} reduce(
+CHECK: ROOT {{.*}} reduce(
+  )");
+}
+
+TEST_F(PriorityFusionTest, FuseReductionEpilogueWithMultipleUsers) {
+  // Regression test that verifies we correctly fuse the `log` into the reduce.
+  constexpr absl::string_view kHlo = R"(
+    HloModule test_module
+
+    add {
+      x = f32[] parameter(0)
+      y = f32[] parameter(1)
+      ROOT add = f32[] add(x, y)
+    }
+
+    fused_computation {
+      p0 = f32[64,16384]{1,0} parameter(0)
+      c0 = f32[] constant(0)
+      ROOT reduce.858 = f32[64]{0} reduce(p0, c0), dimensions={1}, to_apply=add
+    }
+
+    ENTRY main {
+      p0 = f32[64,16384]{1,0} parameter(0)
+      fusion = f32[64]{0} fusion(p0), kind=kInput, calls=fused_computation
+      log = f32[64]{0} log(fusion)
+      negate = f32[64]{0} custom-call(log), custom_call_target="negate"
+      ROOT add = f32[64]{0} add(negate, log)
+    }
+  )";
+
+  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+    CHECK: ENTRY
+    CHECK: %[[PARAM:.*]] = {{.*}} parameter(0)
+    CHECK: %[[FUSION:.*]] = {{.*}} fusion(%[[PARAM]])
+    CHECK: custom-call(%[[FUSION]])
+  )");
+}
+
+TEST_F(PriorityFusionTest, EpilogueFusion) {
+  absl::string_view kHlo = R"(
+    HloModule test_module
+
+    add {
+      p0 = f32[] parameter(0)
+      p1 = f32[] parameter(1)
+      ROOT add.13235 = f32[] add(p0, p1)
+    }
+
+    fused_computation.1 {
+      p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
+      c0 = f32[] constant(0)
+      ROOT r0 = f32[8,4,128]{2,1,0} reduce(p0, c0), dimensions={3}, to_apply=add
+    }
+
+    fused_computation.2 {
+      p0 = f32[8,4,128]{2,1,0} parameter(0)
+      r1 = f32[8,4,128]{2,1,0} log(p0)
+      ROOT r2 = f32[8,4,128]{2,1,0} log(r1)
+    }
+
+    ENTRY main {
+      p0 = f32[8,4,128,226]{3,2,1,0} parameter(0)
+      f1 = f32[8,4,128]{2,1,0} fusion(p0), kind=kInput, calls=%fused_computation.1
+      ROOT fusion = f32[8,4,128]{2,1,0} fusion(f1), kind=kLoop, calls=%fused_computation.2
+    })";
+
+  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ROOT {{.*}} = f32[8,4,128]{2,1,0} fusion(%p{{.*}}), kind=kInput, calls=%fused_computation)");
+}
+
+TEST_F(PriorityFusionTest, EpilogueFusionFails) {
+  auto module = *ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    add {
+      p0 = f32[] parameter(0)
+      p1 = f32[] parameter(1)
+      ROOT add.13235 = f32[] add(p0, p1)
+    }
+
+    fused_computation.1 {
+      p0 = f32[28672,4096]{1,0} parameter(0)
+      c0 = f32[] constant(0)
+      ROOT r = f32[28672]{0} reduce(p0, c0), dimensions={1}, to_apply=add
+    }
+
+    fused_computation.2 {
+      p0 = f32[28672]{0} parameter(0)
+      p1 = f32[28672]{0} parameter(1)
+      ROOT a = f32[28672]{0} add(p0, p1)
+    }
+
+    ENTRY main {
+      p0 = f32[28672,4096]{1,0} parameter(0)
+      p1 = f32[28672]{0} parameter(1)
+      f = f32[28672]{0} fusion(p0), kind=kInput, calls=%fused_computation.1
+      ROOT fusion = f32[28672]{0} fusion(f,p1), kind=kLoop, calls=%fused_computation.2
+    })");
+
+  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseIntoRoot) {
+  auto module = *ParseAndReturnVerifiedModule(R"(
+    HloModule test_module
+
+    ENTRY %main (p.0: u32[2], p.1: u32[]) -> u32[2] {
+      %p.0 = u32[2]{0} parameter(0)
+      %p.1 = u32[] parameter(1)
+      ROOT %broadcast = u32[2]{0} broadcast(u32[] %p.1), dimensions={}, sharding={replicated}
+      %add = u32[2]{0} add(u32[2]{0} %p.0, u32[2]{0} %broadcast)
+      %tuple.1 = (u32[2]{0}) tuple(u32[2]{0} %add)
+      %token.0 = token[] after-all()
+      %outfeed.6 = token[] outfeed((u32[2]{0}) %tuple.1, token[] %token.0), outfeed_shape=(u32[2]{0}), sharding={maximal device=0}
+    })");
+
+  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(PriorityFusionTest, DontFuseConcat) {
+  // Regression test that verifies we don't fuse concat into a column reduction.
+  auto module = *ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    %maximum (param_0: f32[], param_1: f32[]) -> f32[] {
+      %param_0 = f32[] parameter(0)
+      %param_1 = f32[] parameter(1)
+      ROOT %maximum = f32[] maximum(f32[] %param_0, f32[] %param_1)
+    }
+
+    %fused_concat (param_0: f32[1,4,401,8,8], param_1: f32[1,1,4,1023,8], param_2: bf16[1,4,1023,8,8]) -> f32[1,4,1424,8,8] {
+      %param_2 = bf16[1,4,1023,8,8]{4,3,2,1,0} parameter(2)
+      %convert = f32[1,4,1023,8,8]{4,3,2,1,0} convert(bf16[1,4,1023,8,8]{4,3,2,1,0} %param_2)
+      %param_1 = f32[1,1,4,1023,8]{4,3,2,1,0} parameter(1)
+      %bitcast = f32[4,1023,8]{2,1,0} bitcast(f32[1,1,4,1023,8]{4,3,2,1,0} %param_1)
+      %broadcast = f32[1,4,1023,8,8]{4,3,2,1,0} broadcast(f32[4,1023,8]{2,1,0} %bitcast), dimensions={1,2,4}
+      %add = f32[1,4,1023,8,8]{4,3,2,1,0} add(f32[1,4,1023,8,8]{4,3,2,1,0} %convert, f32[1,4,1023,8,8]{4,3,2,1,0} %broadcast)
+      %param_0 = f32[1,4,401,8,8]{4,3,2,1,0} parameter(0)
+      ROOT %concatenate = f32[1,4,1424,8,8]{4,3,2,1,0} concatenate(f32[1,4,1023,8,8]{4,3,2,1,0} %add, f32[1,4,401,8,8]{4,3,2,1,0} %param_0), dimensions={2}
+    }
+
+    %fused_reduce (param_0: f32[], param_1: f32[1,4,1424,8,8]) -> f32[4,8,8] {
+      %param_1 = f32[1,4,1424,8,8]{4,3,2,1,0} parameter(1)
+      %bitcast = f32[4,1424,8,8]{3,2,1,0} bitcast(f32[1,4,1424,8,8]{4,3,2,1,0} %param_1)
+      %param_0 = f32[] parameter(0)
+      ROOT %reduce = f32[4,8,8]{2,1,0} reduce(f32[4,1424,8,8]{3,2,1,0} %bitcast, f32[] %param_0), dimensions={1}, to_apply=%maximum
+    }
+
+    %fused_broadcast (param_0: f32[1,4,1424,8,8], param_1: f32[4,8,8]) -> f32[1,4,1424,8,8] {
+      %param_0 = f32[1,4,1424,8,8]{4,3,2,1,0} parameter(0)
+      %param_1 = f32[4,8,8]{2,1,0} parameter(1)
+      %broadcast = f32[1,4,1424,8,8]{4,3,2,1,0} broadcast(f32[4,8,8]{2,1,0} %param_1), dimensions={1,3,4}
+      ROOT %subtract = f32[1,4,1424,8,8]{4,3,2,1,0} subtract(f32[1,4,1424,8,8]{4,3,2,1,0} %param_0, f32[1,4,1424,8,8]{4,3,2,1,0} %broadcast)
+    }
+
+    ENTRY fusion {
+      %param_0 = f32[1,4,401,8,8]{4,3,2,1,0} parameter(0)
+      %param_1 = f32[1,1,4,1023,8]{4,3,2,1,0} parameter(1)
+      %param_2 = bf16[1,4,1023,8,8]{4,3,2,1,0} parameter(2)
+      %concat = f32[1,4,1424,8,8]{4,3,2,1,0} fusion(%param_0, %param_1, %param_2), kind=kLoop, calls=fused_concat
+      %param_3 = f32[] parameter(3)
+      %reduce = f32[4,8,8]{2,1,0} fusion(%param_3, %concat), kind=kLoop, calls=fused_reduce
+      %param_4 = f32[4,8,8]{2,1,0} parameter(4)
+      %broadcast = f32[1,4,1424,8,8]{4,3,2,1,0} fusion(%concat, %param_4), kind=kLoop, calls=fused_broadcast
+      ROOT tuple = (f32[4,8,8]{2,1,0}, f32[1,4,1424,8,8]{4,3,2,1,0}) tuple(%reduce, %broadcast)
+    }
+  )");
+
+  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(PriorityFusionTest, FuseOnlySmallConstant) {
+  auto module = *ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    ENTRY main {
+      param_0 = f32[32,32]{1,0} parameter(0)
+      c_1 = f32[] constant(1)
+      c_2 = f32[32,32] constant({...})
+      broadcast = f32[32,32]{1,0} broadcast(c_1), dimensions={}
+      add = f32[32,32]{1,0} add(param_0, broadcast)
+      ROOT mul = f32[32,32]{1,0} multiply(c_2, add)
+    }
+  )");
+  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, GmockMatch(m::Fusion(m::Constant(), m::Parameter())));
+  EXPECT_THAT(root->fused_expression_root(),
+              GmockMatch(m::Multiply(
+                  m::Parameter(),
+                  m::Add(m::Parameter(), m::Broadcast(m::Constant())))));
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) {
+  auto module = *ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fused_computation.1 {
+      iota.9.7 = s32[3,1,1]{2,1,0} iota(), iota_dimension=0
+      param_3.29 = s32[] parameter(2)
+      pad.2.7 = s32[3,1,2]{2,1,0} pad(iota.9.7, param_3.29), padding=0_0x0_0x0_1
+      param_2.39 = s32[] parameter(1)
+      broadcast.76.1 = s32[3,1,2]{2,1,0} broadcast(param_2.39), dimensions={}
+      compare.9.1 = pred[3,1,2]{2,1,0} compare(pad.2.7, broadcast.76.1), direction=GE
+      param_1.73 = s32[2]{0} parameter(0)
+      broadcast.78.1 = s32[3,2]{1,0} broadcast(param_1.73), dimensions={1}
+      bitcast.1 = s32[3,2]{1,0} bitcast(pad.2.7)
+      compare.10.1 = pred[3,2]{1,0} compare(bitcast.1, broadcast.78.1), direction=LE
+      bitcast.2 = pred[3,1,2]{2,1,0} bitcast(compare.10.1)
+      ROOT and.3.1 = pred[3,1,2]{2,1,0} and(compare.9.1, bitcast.2)
+    }
+
+    and {
+      x = pred[] parameter(0)
+      y = pred[] parameter(1)
+      ROOT and = pred[] and(x, y)
+    }
+
+    fused_computation.2 {
+      param0 = pred[3,1,2]{2,1,0} parameter(0)
+      slice = pred[1,1,2]{2,1,0} slice(param0), slice={[0:1], [0:1], [0:2]}
+      bitcast = pred[2]{0} bitcast(slice)
+      init = pred[] constant(true)
+      reduce = pred[2]{0} reduce(param0, init), dimensions={0,1}, to_apply=and
+      and = pred[2]{0} and(bitcast, reduce)
+      pad = pred[3]{0} pad(and, init), padding=0_1
+      broadcast = pred[3,2]{1,0} broadcast(pad), dimensions={0}
+      bitcast2 = pred[6]{0} bitcast(broadcast)
+      broadcast2 = pred[2,3]{1,0} broadcast(pad), dimensions={1}
+      bitcast3 = pred[6]{0} bitcast(broadcast2)
+      ROOT and2 = pred[6]{0} and(bitcast2, bitcast3)
+    }
+
+    ENTRY main {
+      p0 = s32[2]{0} parameter(0)
+      p1 = s32[] parameter(1)
+      p2 = s32[] parameter(2)
+      fusion1 = pred[3,1,2]{2,1,0} fusion(p0, p1, p2), kind=kLoop, calls=fused_computation.1
+      ROOT fusion2 = pred[6]{0} fusion(fusion1), kind=kInput, calls=fused_computation.2
+    }
+  )");
+  auto& debug_options = module->mutable_config().mutable_debug_options();
+  debug_options.set_xla_gpu_mlir_emitter_level(3);
+  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
+}
+
+TEST_F(PriorityFusionTest, CanMergeTritonFusionWithBothProducerAndConsumer) {
+#ifndef GOOGLE_CUDA
+  GTEST_SKIP() << "Triton fusion is only enabled for CUDA devices.";
+#endif
+
+  const std::string kHloText = R"(
+HloModule t
+add {
+  Arg_0 = f32[] parameter(0)
+  Arg_1 = f32[] parameter(1)
+  ROOT add = f32[] add(Arg_0, Arg_1)
+}
+
+producer_computation {
+  parameter_0 = f32[125]{0} parameter(0)
+  ROOT broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0}
+}
+
+consumer_computation {
+  parameter_0 = f32[125,127]{1,0} parameter(0)
+  parameter_1 = f32[125,127]{1,0} parameter(1)
+  ROOT multiply = f32[125,127]{1,0} multiply(parameter_1, parameter_0)
+}
+
+triton_softmax_computation {
+  parameter_0 = f32[125,127]{1,0} parameter(0)
+  multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0)
+  constant_0 = f32[] constant(0)
+  reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add
+  broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0}
+  ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4)
+}
+
+ENTRY main {
+  param_0 = f32[125]{0} parameter(0)
+  param_1 = f32[125,127]{1,0} parameter(1)
+  producer_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=producer_computation
+  triton_softmax = f32[125,127]{1,0} fusion(producer_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
+  ROOT consumer_fusion = f32[125,127]{1,0} fusion(param_1, triton_softmax), kind=kLoop, calls=consumer_computation
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+  auto debug_options = module->config().debug_options();
+  debug_options.set_xla_gpu_enable_triton_softmax_priority_fusion(true);
+  module->mutable_config().set_debug_options(debug_options);
+
+  EXPECT_TRUE(priority_fusion_.Run(module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
+  EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kCustom);
+  EXPECT_TRUE(IsGenericTritonFusion(*root));
+}
+
+TEST_F(PriorityFusionTest, DoNotFuseInsideReducer) {
+  auto module = *ParseAndReturnVerifiedModule(R"(
+    %reducer {
+      p0 = f32[] parameter(0)
+      p1 = f32[] parameter(1)
+      add = f32[] add(p0, p1)
+      ROOT max = f32[] maximum(add, p0)
+    }
+
+    %fused_reduce {
+      p0 = f32[256] parameter(0)
+      p1 = f32[] parameter(1)
+      ROOT reduce = f32[] reduce(p0, p1), dimensions={0}, to_apply=%reducer
+    }
+
+    ENTRY fusion {
+      p0 = f32[256] parameter(0)
+      p1 = f32[] parameter(1)
+      ROOT %reduce = f32[] fusion(p0, p1), kind=kInput, calls=fused_reduce
+    }
+  )");
+  EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false));
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc
new file mode 100644
index 0000000..d33c849
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc
@@ -0,0 +1,130 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduce_scatter_creator.h"
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/collective_opt_utils.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
+#include "xla/status_macros.h"
+#include "tsl/platform/errors.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> ReduceScatterCreator::Run(
+    HloModule *module,
+    const absl::flat_hash_set<absl::string_view> &execution_threads) {
+  const HloModuleConfig &config = module->config();
+  int64_t next_channel_id = hlo_query::NextChannelId(*module);
+
+  bool changed = false;
+  for (HloComputation *computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    for (HloInstruction *instruction :
+         computation->MakeInstructionPostOrder()) {
+      if (instruction->opcode() != HloOpcode::kAllReduce) {
+        continue;
+      }
+      auto *ar = Cast<HloAllReduceInstruction>(instruction);
+      auto ar_spec = MatchReduceScatter(ar, config.num_partitions(),
+                                        config.replica_count(),
+                                        /*allow_multiple_split_dims=*/false,
+                                        /*allow_intervening_reshape=*/true);
+      if (!ar_spec) {
+        VLOG(2) << "Cannot match reduce-scatter " << ar->ToString();
+        continue;
+      }
+
+      HloInstruction *ds = ar_spec->dynamic_slice;
+
+      // Convert to reduce-scatter. The output shape of the reduce-scatter will
+      // be the same as the input shape, except that the split dim size is that
+      // of the result of the dynamic-slice.
+      const int64_t split_dim = ar_spec->split_dim;
+      Shape scatter_shape = ar->shape();
+      const int64_t split_dim_size = scatter_shape.dimensions(split_dim);
+      HloInstruction *rs_input = ar->mutable_operand(0);
+      const int64_t scatter_dim_size = split_dim_size / ar_spec->group_size;
+      TF_RET_CHECK(scatter_dim_size * ar_spec->group_size <= split_dim_size);
+      if (split_dim_size % ar_spec->group_size != 0) {
+        // The dynamic-slice does not evenly split the scatter dim. In that
+        // case, create a reduce-scatter with the relevant slice of the
+        // all-reduce input.
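+        // For example (values are illustrative, not taken from a specific
+        // model): with split_dim_size = 7 and group_size = 2, scatter_dim_size
+        // is 7 / 2 = 3, so the input is first sliced to 6 along the split dim
+        // and the reduce-scatter below then produces shards of size 3.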
+        scatter_shape.set_dimensions(split_dim,
+                                     scatter_dim_size * ar_spec->group_size);
+        rs_input = computation->AddInstruction(HloInstruction::CreateSlice(
+            scatter_shape, rs_input,
+            std::vector<int64_t>(scatter_shape.rank(), 0),
+            scatter_shape.dimensions(),
+            std::vector<int64_t>(scatter_shape.rank(), 1)));
+      }
+      scatter_shape.set_dimensions(split_dim, scatter_dim_size);
+
+      std::optional<int64_t> channel_id;
+      if (ar->channel_id()) {
+        // We cannot reuse the channel_id on all-reduce for reduce-scatter.
+        channel_id = next_channel_id++;
+      }
+
+      HloInstruction *ars =
+          computation->AddInstruction(HloInstruction::CreateReduceScatter(
+              scatter_shape, {rs_input}, ar->to_apply(), ar->device_list(),
+              ar->constrain_layout(), channel_id, ar->use_global_device_ids(),
+              ar_spec->split_dim));
+
+      // If there was an intervening reshape, reshape the non-split dimensions
+      // to match that existing reshape. In practice this just means reshaping
+      // the reduce-scatter result to the dynamic-slice shape.
+      HloInstruction *result = ars;
+      HloInstruction *reshape = nullptr;
+      if (ds->operand(0) != ar) {
+        reshape = ds->mutable_operand(0);
+        result = computation->AddInstruction(
+            HloInstruction::CreateReshape(ds->shape(), result));
+      }
+
+      // Note that RemoveInstructionAndUnusedOperands may not always remove the
+      // all-reduce operand of the dynamic-slice, so remove all the dead
+      // instructions manually.
+      TF_RETURN_IF_ERROR(ds->ReplaceAllUsesWith(result));
+      TF_RETURN_IF_ERROR(computation->RemoveInstruction(ds));
+      if (reshape) {
+        TF_RETURN_IF_ERROR(computation->RemoveInstruction(reshape));
+      }
+      TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(ar));
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.h b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.h
new file mode 100644
index 0000000..4e74394
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.h
@@ -0,0 +1,43 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCE_SCATTER_CREATOR_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_REDUCE_SCATTER_CREATOR_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Transforms dynamic-slice(all-reduce) to a reduce-scatter.
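+//
+// Illustrative sketch of the rewrite (shapes taken from the accompanying
+// tests; the exact offset computation may involve lookup tables or reshapes):
+//
+//   %ar = f32[32,8,128] all-reduce(%p), replica_groups={}, to_apply=%sum
+//   %offset = s32[] multiply(%replica_id, %slice_size /* 4 */)
+//   ROOT %ds = f32[4,8,128] dynamic-slice(%ar, %offset, %zero, %zero)
+//
+// becomes roughly
+//
+//   ROOT %rs = f32[4,8,128] reduce-scatter(%p), to_apply=%sum, dimensions={0}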
+class ReduceScatterCreator : public HloModulePass {
+ public:
+  ReduceScatterCreator() = default;
+  absl::string_view name() const override { return "reduce-scatter-creator"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
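+
+// Minimal usage sketch (pipeline name is illustrative): the pass can be added
+// to an HloPassPipeline like any other HloModulePass, e.g.
+//
+//   HloPassPipeline pipeline("collective-optimizations");
+//   pipeline.AddPass<ReduceScatterCreator>();
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));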
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_REDUCE_SCATTER_CREATOR_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator_test.cc b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator_test.cc
new file mode 100644
index 0000000..39a2c72
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator_test.cc
@@ -0,0 +1,572 @@
+/* Copyright 2021 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduce_scatter_creator.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/algorithm/container.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class GpuReduceScatterCreatorTest : public HloTestBase {
+ public:
+  absl::StatusOr<std::unique_ptr<HloModule>> RunPass(
+      absl::string_view hlo_module, int64_t num_replicas,
+      int64_t num_partitions, bool expect_change) {
+    HloModuleConfig config = GetModuleConfigForTest(
+        /*replica_count=*/num_replicas,
+        /*num_partitions=*/num_partitions);
+    config.set_use_spmd_partitioning(num_partitions > 1);
+    TF_ASSIGN_OR_RETURN(auto module,
+                        ParseAndReturnVerifiedModule(hlo_module, config));
+    auto changed = ReduceScatterCreator().Run(module.get());
+    if (!changed.ok()) {
+      return changed.status();
+    }
+    EXPECT_EQ(changed.value(), expect_change);
+    return absl::StatusOr<std::unique_ptr<HloModule>>(std::move(module));
+  }
+
+  size_t AllReduceCount(std::unique_ptr<HloModule> &module) {
+    return CollectiveCount(module, HloOpcode::kAllReduce);
+  }
+
+  size_t ReduceScatterCount(std::unique_ptr<HloModule> &module) {
+    return CollectiveCount(module, HloOpcode::kReduceScatter);
+  }
+
+ private:
+  size_t CollectiveCount(std::unique_ptr<HloModule> &module, HloOpcode opcode) {
+    return absl::c_count_if(
+        module->entry_computation()->instructions(),
+        [&opcode](HloInstruction *instr) { return instr->opcode() == opcode; });
+  }
+};
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicas) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[32,8,128]{2,1,0} parameter(0)
+  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+    replica_groups={}, to_apply=%sum
+  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
+  %rid = u32[] replica-id()
+  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
+  %reshape = s32[] reshape(%id)
+  %slice_size = s32[] constant(4)
+  %offset = s32[] multiply(%reshape, %slice_size)
+  %zero = s32[] constant(0)
+  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
+    dynamic_slice_sizes={4,8,128}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/1,
+                                               /*expect_change=*/true));
+  ASSERT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::ReduceScatter(m::Parameter(0))));
+  const auto *rs = Cast<HloReduceScatterInstruction>(
+      module->entry_computation()->root_instruction());
+  EXPECT_EQ(rs->scatter_dimension(), 0) << rs->ToString();
+  EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithOffsetReshape) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[32,8,128]{2,1,0} parameter(0)
+  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+    replica_groups={}, to_apply=%sum
+  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
+  %rid = u32[] replica-id()
+  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
+  %slice_size = s32[1] constant({4})
+  %offset = s32[1] multiply(%id, %slice_size)
+  %reshape = s32[] reshape(%offset)
+  %zero = s32[] constant(0)
+  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %reshape, %zero, %zero),
+    dynamic_slice_sizes={4,8,128}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/1,
+                                               /*expect_change=*/true));
+  ASSERT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::ReduceScatter(m::Parameter(0))));
+  const auto *rs = Cast<HloReduceScatterInstruction>(
+      module->entry_computation()->root_instruction());
+  EXPECT_EQ(rs->scatter_dimension(), 0) << rs->ToString();
+  EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshape) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[32,8,128]{2,1,0} parameter(0)
+  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+    replica_groups={}, to_apply=%sum
+  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
+  %rid = u32[] replica-id()
+  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
+  %reshape = s32[] reshape(%id)
+  %slice_size = s32[] constant(4)
+  %offset = s32[] multiply(%reshape, %slice_size)
+  %zero = s32[] constant(0)
+  %reshape.1 = f32[32,16,64] reshape(%all-reduce)
+  ROOT %dynamic-slice = f32[4,16,64] dynamic-slice(%reshape.1, %offset, %zero, %zero),
+    dynamic_slice_sizes={4,16,64}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/1,
+                                               /*expect_change=*/true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Reshape(m::ReduceScatter(m::Parameter(0)))));
+  EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshapeSplitDimModified) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[336,1024] parameter(0)
+  %all-reduce = f32[336,1024] all-reduce(%param), replica_groups={}, to_apply=%sum
+  %rid = u32[] replica-id()
+  %id = s32[] convert(%rid)
+  %slice_size = s32[] constant(128)
+  %offset = s32[] multiply(%id, %slice_size)
+  %zero = s32[] constant(0)
+  %reshape.1 = f32[4,84,1024] reshape(%all-reduce)
+  ROOT %dynamic-slice = f32[4,84,128] dynamic-slice(%reshape.1, %zero, %zero, %offset),
+    dynamic_slice_sizes={4,84,128}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/1,
+                                               /*expect_change=*/true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Reshape(m::ReduceScatter(m::Parameter(0)))));
+  EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicasDim2) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[32,8,128]{2,1,0} parameter(0)
+  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+    replica_groups={}, to_apply=%sum
+  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
+  %rid = u32[] replica-id()
+  %rid_s32 = s32[] convert(%rid)
+  %slice_size = s32[] constant(16)
+  %offset = s32[] multiply(%rid_s32, %slice_size)
+  %zero = s32[] constant(0)
+  ROOT %dynamic-slice = f32[32,8,16] dynamic-slice(%all-reduce, %zero, %zero, %offset),
+    dynamic_slice_sizes={32,8,16}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/1,
+                                               /*expect_change=*/true));
+  ASSERT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::ReduceScatter(m::Parameter(0))));
+  const auto *rs = Cast<HloReduceScatterInstruction>(
+      module->entry_computation()->root_instruction());
+  EXPECT_EQ(rs->scatter_dimension(), 2) << rs->ToString();
+  EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicasWrongOffsets) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[32,8,128]{2,1,0} parameter(0)
+  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+    replica_groups={}, to_apply=%sum
+  %table = s32[8]{0} constant({0,1,2,3,4,5,6,8})
+  %rid = u32[] replica-id()
+  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
+  %reshape = s32[] reshape(%id)
+  %slice_size = s32[] constant(4)
+  %offset = s32[] multiply(%reshape, %slice_size)
+  %zero = s32[] constant(0)
+  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
+    dynamic_slice_sizes={4,8,128}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/1,
+                                               /*expect_change=*/false));
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReplicasIotaTable) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[32,8,128]{2,1,0} parameter(0)
+  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+    replica_groups={}, to_apply=%sum
+  %table = s32[8]{0} iota(), iota_dimension=0
+  %rid = u32[] replica-id()
+  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
+  %reshape = s32[] reshape(%id)
+  %slice_size = s32[] constant(4)
+  %offset = s32[] multiply(%reshape, %slice_size)
+  %zero = s32[] constant(0)
+  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
+    dynamic_slice_sizes={4,8,128}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/2,
+                                               /*expect_change=*/true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::ReduceScatter(m::Parameter(0))));
+  EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, SubgroupedReplicas) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[32,8,128]{2,1,0} parameter(0)
+  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+    replica_groups={{1,3,2,0},{4,5,6,7}}, to_apply=%sum
+  %gtable = s32[8]{0} constant({3,0,2,1,0,1,2,3})
+  %rid = u32[] replica-id()
+  %id = s32[1] dynamic-slice(%gtable, %rid), dynamic_slice_sizes={1}
+  %reshape.0 = s32[] reshape(%id)
+  %table = s32[4]{0} constant({0,8,16,24})
+  %offset = s32[1] dynamic-slice(%table, %reshape.0), dynamic_slice_sizes={1}
+  %reshape.1 = s32[] reshape(%offset)
+  %zero = s32[] constant(0)
+  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %reshape.1, %zero, %zero),
+    dynamic_slice_sizes={8,8,128}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/2,
+                                               /*expect_change=*/true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::ReduceScatter(m::Parameter(0))));
+  EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllPartitions) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[32,8,128]{2,1,0} parameter(0)
+  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+    replica_groups={{0},{1}}, to_apply=%sum, channel_id=1
+  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
+  %pid = u32[] partition-id()
+  %id = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
+  %reshape = s32[] reshape(%id)
+  %slice_size = s32[] constant(4)
+  %offset = s32[] multiply(%reshape, %slice_size)
+  %zero = s32[] constant(0)
+  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
+    dynamic_slice_sizes={4,8,128}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/2,
+                                               /*num_partitions=*/8,
+                                               /*expect_change=*/true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::ReduceScatter(m::Parameter(0))));
+  EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, AllReduceFollowedByAllReduce) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[32,8,128]{2,1,0} parameter(0)
+  %all-reduce.scattered = f32[32,8,128]{2,1,0} all-reduce(%param),
+    replica_groups={{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15}}, to_apply=%sum, use_global_device_ids=true, channel_id=1
+  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
+  %pid = u32[] partition-id()
+  %id = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
+  %reshape = s32[] reshape(%id)
+  %slice_size = s32[] constant(4)
+  %offset = s32[] multiply(%reshape, %slice_size)
+  %zero = s32[] constant(0)
+  %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce.scattered, %offset, %zero, %zero),
+    dynamic_slice_sizes={4,8,128}
+  ROOT %all-reduce.sync = f32[4,8,128]{2,1,0} all-reduce(%dynamic-slice),
+    replica_groups={{0,8},{1,9},{2,10},{3,11},{4,12},{5,13},{6,14},{7,15}}, to_apply=%sum, use_global_device_ids=true, channel_id=2
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/2,
+                                               /*num_partitions=*/8,
+                                               /*expect_change=*/true));
+  EXPECT_EQ(AllReduceCount(module), 1);
+  EXPECT_EQ(ReduceScatterCount(module), 1);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobals) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[32,8,128]{2,1,0} parameter(0)
+  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+    replica_groups={{1,3,2,0},{4,5,6,7}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
+  %pid = u32[] partition-id()
+  %rid = u32[] replica-id()
+  %pcount = u32[] constant(4)
+  %ridxp = u32[] multiply(%rid, %pcount)
+  %gid = u32[] add(%ridxp, %pid)
+  %gtable = s32[8]{0} constant({3,0,2,1,0,1,2,3})
+  %id = s32[1] dynamic-slice(%gtable, %gid), dynamic_slice_sizes={1}
+  %reshape.0 = s32[] reshape(%id)
+  %table = s32[4]{0} constant({0,8,16,24})
+  %offset = s32[1] dynamic-slice(%table, %reshape.0), dynamic_slice_sizes={1}
+  %reshape.1 = s32[] reshape(%offset)
+  %zero = s32[] constant(0)
+  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %reshape.1, %zero, %zero),
+    dynamic_slice_sizes={8,8,128}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/2,
+                                               /*num_partitions=*/4,
+                                               /*expect_change=*/true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::ReduceScatter(m::Parameter(0))));
+  EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobalsOrthogonalReplicas) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[32,8,128]{2,1,0} parameter(0)
+  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+    replica_groups={{1,3,2,0},{5,7,6,4}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
+  %pid = u32[] partition-id()
+  %pid_table = s32[4]{0} constant({3,0,2,1})
+  %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
+  %reshape = s32[] reshape(%offset)
+  %shard_size = s32[] constant(8)
+  %mul = s32[] multiply(%reshape, %shard_size)
+  %zero = s32[] constant(0)
+  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %mul, %zero, %zero),
+    dynamic_slice_sizes={8,8,128}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/2,
+                                               /*num_partitions=*/4,
+                                               /*expect_change=*/true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::ReduceScatter(m::Parameter(0))));
+  EXPECT_EQ(AllReduceCount(module), 0);
+}
+
+TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobalsNonOrthogonalReplicas) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[32,8,128]{2,1,0} parameter(0)
+  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
+    replica_groups={{1,3,2,0},{7,5,6,4}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
+  %pid = u32[] partition-id()
+  %pid_table = s32[4]{0} constant({3,0,2,1})
+  %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
+  %reshape = s32[] reshape(%offset)
+  %shard_size = s32[] constant(8)
+  %mul = s32[] multiply(%reshape, %shard_size)
+  %zero = s32[] constant(0)
+  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %mul, %zero, %zero),
+    dynamic_slice_sizes={8,8,128}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/2,
+                                               /*num_partitions=*/4,
+                                               /*expect_change=*/false));
+}
+
+TEST_F(GpuReduceScatterCreatorTest, NonUniformSplit) {
+  absl::string_view hlo_string = R"(
+HloModule AllReduce
+
+%sum {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(%a, %b)
+}
+
+ENTRY %AllReduce {
+  %param = f32[1,7]{1,0} parameter(0)
+  %all-reduce = f32[1,7]{1,0} all-reduce(%param),
+    replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
+  %pid = u32[] partition-id()
+  %pid_table = s32[8]{0} constant({0, 1, 0, 1, 0, 1, 0, 1})
+  %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
+  %reshape = s32[] reshape(%offset)
+  %shard_size = s32[] constant(3)
+  %mul = s32[] multiply(%reshape, %shard_size)
+  %zero = s32[] constant(0)
+  ROOT %dynamic-slice = f32[1,3] dynamic-slice(%all-reduce, %zero, %mul),
+    dynamic_slice_sizes={1,3}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/1,
+                                               /*num_partitions=*/8,
+                                               /*expect_change=*/true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::ReduceScatter(m::Slice(m::Parameter(0)))));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc
new file mode 100644
index 0000000..8c2929c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc
@@ -0,0 +1,131 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor {
+ public:
+  absl::Status HandleReduce(HloInstruction *hlo) override {
+    auto instr = Cast<HloReduceInstruction>(hlo);
+    absl::InlinedVector<HloInstruction *, 2> input_reshapes;
+    absl::InlinedVector<Shape, 2> canonical_reduce_shapes;
+
+    int idx = -1;
+    std::vector<int64_t> updated_reduced_dimensions;
+    for (HloInstruction *reduced_op : instr->inputs()) {
+      idx++;
+      const Shape &input_shape = reduced_op->shape();
+      const Shape &reduce_shape = instr->shape().IsTuple()
+                                      ? instr->shape().tuple_shapes(idx)
+                                      : instr->shape();
+
+      if (!ShapeUtil::HasDegenerateDimensions(reduced_op->shape())) {
+        return absl::OkStatus();
+      }
+      Shape canonical_input_shape =
+          ShapeUtil::DropDegenerateDimensions(input_shape);
+
+      Shape canonical_reduce_shape =
+          ShapeUtil::DropDegenerateDimensions(reduce_shape);
+
+      auto reduced_dimensions = instr->dimensions();
+      int64_t shift = 0;
+
+      for (int dim = 0; dim < input_shape.rank(); dim++) {
+        if (input_shape.dimensions(dim) == 1) {
+          shift++;
+        } else {
+          if (absl::c_linear_search(reduced_dimensions, dim) && idx == 0) {
+            // Only populate on first iteration.
+            updated_reduced_dimensions.push_back(dim - shift);
+          }
+        }
+      }
+
+      if (updated_reduced_dimensions.empty()) {
+        std::unique_ptr<HloInstruction> reshape =
+            HloInstruction::CreateBitcast(reduce_shape, reduced_op);
+        return ReplaceWithNewInstruction(instr, std::move(reshape));
+      }
+
+      input_reshapes.push_back(instr->parent()->AddInstruction(
+          HloInstruction::CreateBitcast(canonical_input_shape, reduced_op)));
+      canonical_reduce_shapes.push_back(canonical_reduce_shape);
+    }
+
+    Shape canonical_reduce_shape =
+        ShapeUtil::MakeMaybeTupleShape(canonical_reduce_shapes);
+    const Shape &orig_reduce_shape = instr->shape();
+    std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
+        canonical_reduce_shape, input_reshapes, instr->init_values(),
+        updated_reduced_dimensions, instr->to_apply());
+    instr->SetupDerivedInstruction(new_reduce.get());
+
+    if (canonical_reduce_shape != instr->shape()) {
+      HloInstruction *wrapped_reduce =
+          instr->parent()->AddInstruction(std::move(new_reduce));
+      absl::InlinedVector<HloInstruction *, 2> out;
+      if (!canonical_reduce_shape.IsTuple()) {
+        new_reduce =
+            HloInstruction::CreateBitcast(orig_reduce_shape, wrapped_reduce);
+      } else {
+        for (int oidx = 0; oidx < instr->input_count(); oidx++) {
+          HloInstruction *gte = instr->parent()->AddInstruction(
+              HloInstruction::CreateGetTupleElement(wrapped_reduce, oidx));
+          out.push_back(
+              instr->parent()->AddInstruction(HloInstruction::CreateBitcast(
+                  orig_reduce_shape.tuple_shapes(oidx), gte)));
+        }
+        new_reduce = HloInstruction::CreateTuple(out);
+      }
+    }
+
+    return ReplaceWithNewInstruction(instr, std::move(new_reduce));
+  }
+};
+
+absl::StatusOr<bool> ReductionDegenerateDimRemover::Run(
+    HloModule *module,
+    const absl::flat_hash_set<absl::string_view> &execution_threads) {
+  TF_ASSIGN_OR_RETURN(bool changed,
+                      ReductionDegenerateDimRemoverVisitor().RunOnModule(
+                          module, execution_threads));
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.h b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.h
new file mode 100644
index 0000000..1630aec
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.h
@@ -0,0 +1,56 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DEGENERATE_DIM_REMOVER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DEGENERATE_DIM_REMOVER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Enforces the invariant that reduction input and output have no degenerate
+// (size 1) dimension. Since these dimensions are physically meaningless, they
+// are removed using bitcasts.
+//
+// For example,
+//
+//   f[1] out = reduce(f[100, 1, 1] input, dimensions={0, 1})
+//
+// becomes:
+//
+//   f[100] tmp1 = f[100] bitcast(f[100, 1, 1] input)
+//   f[] tmp2 = reduce(f[100] tmp1, dimensions={0})
+//   f[1] out = f[1] bitcast(tmp2)
+//
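+// A minimal usage sketch, assuming the usual HloPassPipeline registration
+// API; the pipeline name below is illustrative only:
+//
+//   HloPassPipeline pipeline("reduction-canonicalization");
+//   pipeline.AddPass<ReductionDegenerateDimRemover>();
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));
+//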
+class ReductionDegenerateDimRemover : public HloModulePass {
+ public:
+  absl::string_view name() const override {
+    return "reduction-degenerate-dim-remover";
+  }
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DEGENERATE_DIM_REMOVER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover_test.cc
new file mode 100644
index 0000000..7a9b7fa
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover_test.cc
@@ -0,0 +1,133 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h"
+
+#include <optional>
+
+#include "absl/strings/string_view.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+
+namespace {
+
+class ReductionDegenerateDimRemoverTest : public HloTestBase {
+ public:
+  void CheckDegenerateDimRemover(absl::string_view hlo,
+                                 std::optional<absl::string_view> expected) {
+    RunAndFilecheckHloRewrite(hlo, gpu::ReductionDegenerateDimRemover{},
+                              expected);
+  }
+};
+
+TEST_F(ReductionDegenerateDimRemoverTest, ReductionWithDegenerateDimensions) {
+  const char* hlo = R"(
+HloModule ReduceWithDegenerateDimensions
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[1,3,1,4,1,5,1] parameter(0)
+  zero = f32[] constant(0)
+
+  ROOT out = f32[1,1,1,1] reduce(input, zero), dimensions={1,3,5}, to_apply=add
+}
+
+)";
+
+  CheckDegenerateDimRemover(hlo, R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[3,4,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[reduce_2:%[^ ]+]] = f32[] reduce([[bitcast_0]], [[zero_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[bitcast_1_5:%[^ ]+]] = f32[1,1,1,1]{3,2,1,0} bitcast([[reduce_2]])
+  )");
+}
+
+TEST_F(ReductionDegenerateDimRemoverTest,
+       ReductionWithDegenerateDimensionsVariadic) {
+  const char* hlo = R"(
+HloModule ReduceWithDegenerateDimensions
+
+argmax {
+  running_max = f32[] parameter(0)
+  running_max_idx = u32[] parameter(1)
+  current_value = f32[] parameter(2)
+  current_value_idx = u32[] parameter(3)
+
+  current = (f32[], u32[]) tuple(running_max, running_max_idx)
+  potential = (f32[], u32[]) tuple(current_value, current_value_idx)
+
+  cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+  new_max = f32[] select(cmp_code, current_value, running_max)
+  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
+}
+
+ENTRY main {
+  input = f32[1,3,1,4,1,5,1] parameter(0)
+  idxs = u32[1,3,1,4,1,5,1] parameter(1)
+  zero = f32[] constant(0)
+  zero_idx = u32[] constant(0)
+
+  ROOT out = (f32[1,1,1,1], u32[1,1,1,1]) reduce(input, idxs, zero, zero_idx), dimensions={1,3,5}, to_apply=argmax
+}
+
+)";
+
+  CheckDegenerateDimRemover(hlo, R"(
+// CHECK:  [[bitcast_0:%[^ ]+]] = f32[3,4,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK:  [[bitcast_1_2:%[^ ]+]] = u32[3,4,5]{2,1,0} bitcast([[idxs_3:%[^ ]+]])
+// CHECK:  [[reduce_4:%[^ ]+]] = (f32[], u32[]) reduce([[bitcast_0]], [[bitcast_1_2]], [[zero_5:%[^ ]+]], [[zero_idx_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
+// CHECK-NEXT:  [[get_tuple_element_8:%[^ ]+]] = f32[] get-tuple-element([[reduce_4]]), index=0
+// CHECK-NEXT:  [[bitcast_2_9:%[^ ]+]] = f32[1,1,1,1]{3,2,1,0} bitcast([[get_tuple_element_8]])
+// CHECK-NEXT:  [[get_tuple_element_1_10:%[^ ]+]] = u32[] get-tuple-element([[reduce_4]]), index=1
+// CHECK-NEXT:  [[bitcast_3_11:%[^ ]+]] = u32[1,1,1,1]{3,2,1,0} bitcast([[get_tuple_element_1_10]])
+// CHECK-NEXT:  ROOT [[tuple_12:%[^ ]+]] = (f32[1,1,1,1]{3,2,1,0}, u32[1,1,1,1]{3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
+)");
+}
+
+TEST_F(ReductionDegenerateDimRemoverTest, DegenerateWithEmptyDimension) {
+  const char* hlo = R"(
+HloModule ReduceWithDegenerateDimensions
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[1,3,1,4,1,5,1] parameter(0)
+  zero = f32[] constant(0)
+
+  ROOT out = f32[3,4,5,1] reduce(input, zero), dimensions={0,2,4}, to_apply=add
+}
+)";
+
+  CheckDegenerateDimRemover(hlo,
+                            R"(
+// CHECK: ROOT [[bitcast_0:%[^ ]+]] = f32[3,4,5,1]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
+      )");
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc
new file mode 100644
index 0000000..ca4fba4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc
@@ -0,0 +1,122 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_dimension_grouper.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/layout_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+class ReduceDimensionGroupVisitor : public DfsHloRewriteVisitor {
+ public:
+  absl::Status HandleReduce(HloInstruction *hlo) override {
+    auto reduce = Cast<HloReduceInstruction>(hlo);
+
+    VLOG(4) << "Input: " << reduce->ToString();
+
+    absl::InlinedVector<HloInstruction *, 2> reduce_inputs_grouped;
+    std::vector<int64_t> reduced_dims_grouped;
+
+    int idx = -1;
+    for (HloInstruction *operand : reduce->inputs()) {
+      idx++;
+      std::vector<int64_t> new_grouped_dims;
+      const Shape &shape = operand->shape();
+      CHECK(shape == LayoutUtil::GetWithDefaultLayout(shape))
+          << "Default layout should be enforced on reduction operand";
+      auto is_reduced = [&](int dim) {
+        return absl::c_linear_search(reduce->dimensions(), dim);
+      };
+
+      bool changed = false;
+      int64_t next_dim_size = 1;
+
+      // Since we have enforced the standard layout, iteration over logical
+      // dimensions is equivalent to iteration over the major-to-minor order.
+      for (int logical_dim = 0; logical_dim < shape.rank(); logical_dim++) {
+        VLOG(5) << "Processing dimension " << logical_dim << " of size "
+                << shape.dimensions(logical_dim);
+        if (is_reduced(logical_dim) && logical_dim < shape.rank() - 1 &&
+            is_reduced(logical_dim + 1)) {
+          VLOG(5) << "This and consecutive dimension are reduced, merging";
+          changed = true;
+          next_dim_size *= shape.dimensions(logical_dim);
+          continue;
+        }
+
+        if (is_reduced(logical_dim)) {
+          new_grouped_dims.push_back(next_dim_size *
+                                     shape.dimensions(logical_dim));
+          if (idx == 0) {
+            // Only populate for first argument.
+            reduced_dims_grouped.push_back(new_grouped_dims.size() - 1);
+          }
+          next_dim_size = 1;
+        } else {
+          new_grouped_dims.push_back(shape.dimensions(logical_dim));
+        }
+      }
+
+      // All inputs have the same dimensions, so checking one suffices.
+      if (!changed) {
+        return absl::OkStatus();
+      }
+
+      Shape grouped_shape =
+          ShapeUtil::MakeShape(shape.element_type(), new_grouped_dims);
+      reduce_inputs_grouped.push_back(reduce->parent()->AddInstruction(
+          HloInstruction::CreateBitcast(grouped_shape, operand),
+          &operand->metadata()));
+      VLOG(5) << "Adding bitcast: " << reduce_inputs_grouped.back()->ToString();
+    }
+
+    std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
+        reduce->shape(), reduce_inputs_grouped, reduce->init_values(),
+        reduced_dims_grouped, reduce->to_apply());
+    VLOG(5) << "Generated new reduction: " << new_reduce->ToString();
+    return ReplaceWithNewInstruction(reduce, std::move(new_reduce));
+  }
+};
+
+absl::StatusOr<bool> ReductionDimensionGrouper::Run(
+    HloModule *module,
+    const absl::flat_hash_set<absl::string_view> &execution_threads) {
+  TF_ASSIGN_OR_RETURN(bool changed, ReduceDimensionGroupVisitor().RunOnModule(
+                                        module, execution_threads));
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.h b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.h
new file mode 100644
index 0000000..1ebd600
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.h
@@ -0,0 +1,56 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DIMENSION_GROUPER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DIMENSION_GROUPER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Groups adjacent (logically and physically) reduced dimensions in the
+// reduction input.
+//
+// Precondition: ReductionLayoutNormalizer has been run (physical proximity and
+// logical proximity become the same).
+//
+// For example,
+//
+//   f[] out = reduce(f[10,20,30] input, dimensions={0,1,2})
+//
+// becomes:
+//
+//   f[600] tmp = f[600] bitcast(f[10,20,30] input)
+//   f[] out = reduce(f[600] tmp, dimensions={0})
+//
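+// Reduced dimensions are only merged with adjacent reduced dimensions, so
+// kept dimensions are left untouched; for example (mirroring the test that
+// accompanies this pass),
+//
+//   f[100,10] out = reduce(f[100,10,32,3] input, dimensions={2,3})
+//
+// becomes:
+//
+//   f[100,10,96] tmp = f[100,10,96] bitcast(f[100,10,32,3] input)
+//   f[100,10] out = reduce(f[100,10,96] tmp, dimensions={2})
+//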
+class ReductionDimensionGrouper : public HloModulePass {
+ public:
+  absl::string_view name() const override {
+    return "reduction-dimension-grouper";
+  }
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DIMENSION_GROUPER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper_test.cc
new file mode 100644
index 0000000..afbbbec
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper_test.cc
@@ -0,0 +1,103 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_dimension_grouper.h"
+
+#include <optional>
+
+#include "absl/strings/string_view.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+
+namespace {
+
+class ReductionDimensionGrouperTest : public HloTestBase {
+ public:
+  void CheckDimensionGrouper(absl::string_view hlo,
+                             std::optional<absl::string_view> expected) {
+    RunAndFilecheckHloRewrite(hlo, gpu::ReductionDimensionGrouper{}, expected);
+  }
+};
+
+TEST_F(ReductionDimensionGrouperTest, ReductionWithGrouping) {
+  const char* hlo = R"(
+HloModule ReductionWithGrouping
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[100,10,32,3]{3,2,1,0} parameter(0)
+  zero = f32[] constant(0)
+
+  ROOT out = f32[100,10]{0,1} reduce(input, zero), dimensions={2,3}, to_apply=add
+}
+)";
+
+  CheckDimensionGrouper(hlo,
+                        R"(
+// CHECK:  [[input_0:%[^ ]+]] = f32[100,10,32,3]{3,2,1,0} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[100,10,96]{2,1,0} bitcast([[input_0]])
+// CHECK:  ROOT [[out_1_2:%[^ ]+]] = f32[100,10]{0,1} reduce([[bitcast_1]], [[zero_3:%[^ ]+]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
+      )");
+}
+
+TEST_F(ReductionDimensionGrouperTest, ReductionWithGroupingVariadic) {
+  const char* hlo = R"(
+HloModule ReductionWithGrouping
+
+argmax {
+  running_max = f32[] parameter(0)
+  running_max_idx = u32[] parameter(1)
+  current_value = f32[] parameter(2)
+  current_value_idx = u32[] parameter(3)
+
+  current = (f32[], u32[]) tuple(running_max, running_max_idx)
+  potential = (f32[], u32[]) tuple(current_value, current_value_idx)
+
+  cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+  new_max = f32[] select(cmp_code, current_value, running_max)
+  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
+}
+
+ENTRY main {
+  input = f32[100,10,32,3]{3,2,1,0} parameter(0)
+  idxs = u32[100,10,32,3]{3,2,1,0} parameter(1)
+  zero = f32[] constant(0)
+  zero_idx = u32[] constant(0)
+
+  ROOT out = (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce(input, idxs, zero, zero_idx), dimensions={2,3}, to_apply=argmax
+}
+)";
+
+  CheckDimensionGrouper(hlo, R"(
+// CHECK:  [[input_0:%[^ ]+]] = f32[100,10,32,3]{3,2,1,0} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[100,10,96]{2,1,0} bitcast([[input_0]])
+// CHECK:  [[idxs_2:%[^ ]+]] = u32[100,10,32,3]{3,2,1,0} parameter(1)
+// CHECK:  [[bitcast_1_3:%[^ ]+]] = u32[100,10,96]{2,1,0} bitcast([[idxs_2]])
+// CHECK:  ROOT [[out_1_4:%[^ ]+]] = (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[zero_5:%[^ ]+]], [[zero_idx_6:%[^ ]+]]), dimensions={2}, to_apply=[[argmax_7:%[^ ]+]]
+)");
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc
new file mode 100644
index 0000000..fd45f8b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc
@@ -0,0 +1,203 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_layout_normalizer.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor {
+  absl::Status HandleReduce(HloInstruction *hlo) override {
+    auto reduce = Cast<HloReduceInstruction>(hlo);
+    VLOG(5) << "Input: " << reduce->ToString();
+
+    int operand_idx = -1;
+
+    absl::InlinedVector<HloInstruction *, 2> canonical_reduce_inputs;
+    absl::InlinedVector<Shape, 2> new_reduce_shapes;
+
+    DimensionVector out_reduce_dimensions;
+    const Shape &first_instruction_shape = reduce->inputs()[0]->shape();
+
+    for (HloInstruction *operand : reduce->inputs()) {
+      operand_idx++;
+
+      if (operand_idx != 0 &&
+          operand->shape().layout() != first_instruction_shape.layout()) {
+        HloInstruction *copy =
+            reduce->parent()->AddInstruction(HloInstruction::CreateUnary(
+                operand->shape(), HloOpcode::kCopy, operand));
+
+        LayoutUtil::ClearLayout(copy->mutable_shape());
+        TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+            first_instruction_shape, copy->mutable_shape()));
+
+        copy->set_metadata(operand->metadata());
+        operand = copy;
+        VLOG(3) << "Copying to establish a consistent input layout: "
+                << copy->ToString();
+      }
+
+      const Shape &operand_shape = operand->shape();
+      const Layout &operand_layout = operand_shape.layout();
+
+      const Shape &reduce_shape =
+          reduce->shape().IsTuple() ? reduce->shape().tuple_shapes(operand_idx)
+                                    : reduce->shape();
+
+      DimensionVector new_reduce_dimensions;
+      DimensionVector new_operand_shape_data;
+      DimensionVector new_reduce_shape_data;
+
+      // The layout order of the reduction output can differ from the order
+      // of the kept dimensions in the input operand, so we need to compute
+      // the new layout.
+      DimensionVector new_reduce_shape_layout(reduce_shape.rank());
+      std::vector<int64_t> reduce_shape_logical_to_physical =
+          LayoutUtil::MakeLogicalToPhysical(reduce_shape.layout());
+
+      auto to_reduce_logical_dim = [&](int64_t op_logical_dim) {
+        return op_logical_dim -
+               absl::c_count_if(reduce->dimensions(), [&](int64_t dim) {
+                 CHECK(dim != op_logical_dim);
+                 return dim < op_logical_dim;
+               });
+      };
+
+      for (int i = 0; i < operand_shape.rank(); i++) {
+        // Process the dimensions in major-to-minor order so that the
+        // default layout is enforced.
+        int64_t major_to_minor_dim_idx = operand_shape.rank() - i - 1;
+        int64_t logical_dim =
+            operand_layout.minor_to_major(major_to_minor_dim_idx);
+        int64_t dim_size = operand_shape.dimensions(logical_dim);
+        VLOG(5) << "Processing logical dimension " << logical_dim << " of size "
+                << dim_size;
+        new_operand_shape_data.push_back(dim_size);
+
+        if (absl::c_linear_search(reduce->dimensions(), logical_dim)) {
+          new_reduce_dimensions.push_back(i);
+        } else {
+          new_reduce_shape_data.push_back(dim_size);
+          int64_t logical_reduce_dim = to_reduce_logical_dim(logical_dim);
+          int64_t physical_reduce_dim =
+              reduce_shape_logical_to_physical[logical_reduce_dim];
+          VLOG(5) << "logical_reduce_dim = " << logical_reduce_dim << ", "
+                  << "physical_reduce_dim = " << physical_reduce_dim;
+          new_reduce_shape_layout[reduce_shape.rank() - physical_reduce_dim -
+                                  1] = new_reduce_shape_data.size() - 1;
+        }
+      }
+
+      Shape new_operand_shape = ShapeUtil::MakeShape(
+          operand_shape.element_type(), new_operand_shape_data);
+      Shape new_reduce_shape = ShapeUtil::MakeShapeWithDenseLayout(
+          reduce_shape.element_type(), new_reduce_shape_data,
+          new_reduce_shape_layout);
+
+      if (new_operand_shape == operand_shape && reduce->inputs().size() == 1) {
+        return absl::OkStatus();
+      }
+
+      HloInstruction *canonical_reduce_input =
+          new_operand_shape != operand_shape
+              ? reduce->parent()->AddInstruction(
+                    HloInstruction::CreateBitcast(new_operand_shape, operand))
+              : operand;
+      canonical_reduce_input->set_metadata(operand->metadata());
+      VLOG(5) << "Reduction input: " << canonical_reduce_input->ToString();
+
+      new_reduce_shapes.push_back(new_reduce_shape);
+      canonical_reduce_inputs.push_back(canonical_reduce_input);
+
+      if (out_reduce_dimensions.empty()) {
+        out_reduce_dimensions = new_reduce_dimensions;
+      } else {
+        TF_RET_CHECK(out_reduce_dimensions == new_reduce_dimensions);
+      }
+    }
+
+    Shape new_reduce_shape = ShapeUtil::MakeMaybeTupleShape(new_reduce_shapes);
+
+    std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
+        new_reduce_shape, canonical_reduce_inputs, reduce->init_values(),
+        out_reduce_dimensions, reduce->to_apply());
+    VLOG(5) << "Generated new reduction: " << new_reduce->ToString();
+    const Shape &orig_reduce_shape = reduce->shape();
+
+    if (new_reduce_shape != orig_reduce_shape) {
+      HloInstruction *wrapped_reduce =
+          reduce->parent()->AddInstruction(std::move(new_reduce));
+
+      if (!new_reduce_shape.IsTuple()) {
+        new_reduce =
+            HloInstruction::CreateBitcast(reduce->shape(), wrapped_reduce);
+      } else {
+        // Bitcast each element of the tuple.
+        absl::InlinedVector<HloInstruction *, 2> out;
+        for (int oidx = 0; oidx < reduce->input_count(); oidx++) {
+          HloInstruction *gte = reduce->parent()->AddInstruction(
+              HloInstruction::CreateGetTupleElement(wrapped_reduce, oidx));
+          out.push_back(
+              reduce->parent()->AddInstruction(HloInstruction::CreateBitcast(
+                  orig_reduce_shape.tuple_shapes(oidx), gte)));
+        }
+        new_reduce = HloInstruction::CreateTuple(out);
+      }
+    }
+
+    VLOG(5) << "Generated output: " << new_reduce->ToString();
+    return ReplaceWithNewInstruction(reduce, std::move(new_reduce));
+  }
+};
+
+absl::StatusOr<bool> ReductionLayoutNormalizer::Run(
+    HloModule *module,
+    const absl::flat_hash_set<absl::string_view> &execution_threads) {
+  TF_ASSIGN_OR_RETURN(bool changed,
+                      EnforceMinorToMajorReduceOpVisitor().RunOnModule(
+                          module, execution_threads));
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.h b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.h
new file mode 100644
index 0000000..f6e2d7c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.h
@@ -0,0 +1,54 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_LAYOUT_NORMALIZER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_LAYOUT_NORMALIZER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Enforces default (minor-to-major) layout on all reduction inputs.
+// Note that since reduction output can request a custom layout,
+// this pass only guarantees standard layout for the input.
+//
+// For example,
+//
+//   f[20,30]{0,1} out = reduce(f[10,20,30]{2,0,1} input, dimensions={0})
+//
+// becomes:
+//
+//   f[20,10,30] tmp = f[20,10,30] bitcast(f[10,20,30]{2,0,1} input)
+//   f[20,30]{0,1} out = reduce(f[20,10,30]{2,1,0} tmp, dimensions={1})
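+//
+// If the normalized reduce shape no longer matches the shape requested by the
+// original instruction, the result is bitcast back to the requested shape;
+// for variadic reductions this is done per tuple element. A rough sketch
+// (shapes elided):
+//
+//   r = (f[...], u[...]) reduce(...)
+//   a = get-tuple-element(r), index=0
+//   b = get-tuple-element(r), index=1
+//   out = tuple(bitcast(a), bitcast(b))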
+class ReductionLayoutNormalizer : public HloModulePass {
+ public:
+  absl::string_view name() const override {
+    return "reduction-layout-normalizer";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_LAYOUT_NORMALIZER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer_test.cc
new file mode 100644
index 0000000..46f5e93
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_layout_normalizer.h"
+
+#include <optional>
+
+#include "absl/strings/string_view.h"
+#include "xla/error_spec.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+
+namespace {
+
+class ReductionLayoutNormalizerTest : public HloTestBase {
+ public:
+  void CheckReductionLayoutNormalizer(
+      absl::string_view hlo, std::optional<absl::string_view> expected) {
+    RunAndFilecheckHloRewrite(hlo, gpu::ReductionLayoutNormalizer{}, expected);
+  }
+};
+
+TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTest) {
+  const char* hlo = R"(
+HloModule ReduceWithLayoutChange
+
+add {
+  x0 = f32[] parameter(0)
+  y0 = f32[] parameter(1)
+  ROOT add0 = f32[] add(x0, y0)
+}
+
+ENTRY main {
+  arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1}  parameter(0)
+  constant0 = f32[] constant(0)
+  ROOT reduce0 = f32[4,5,16,12,12]{4,3,2,1,0} reduce(arg0, constant0),
+    dimensions={1,6,7}, to_apply=add
+}
+
+)";
+
+  CheckReductionLayoutNormalizer(hlo,
+                                 R"(
+// CHECK:  [[bitcast_0:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_1:%[^ ]+]])
+// CHECK:  [[reduce_2:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} reduce([[bitcast_0]], [[constant0_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[bitcast_1_5:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[reduce_2]])
+      )");
+}
+
+TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTestVariadic) {
+  const char* hlo = R"(
+HloModule ReduceWithLayoutChangeVariadic
+
+
+argmax {
+  running_max = f32[] parameter(0)
+  running_max_idx = u32[] parameter(1)
+  current_value = f32[] parameter(2)
+  current_value_idx = u32[] parameter(3)
+
+  current = (f32[], u32[]) tuple(running_max, running_max_idx)
+  potential = (f32[], u32[]) tuple(current_value, current_value_idx)
+
+  cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+  new_max = f32[] select(cmp_code, current_value, running_max)
+  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
+}
+
+ENTRY main {
+  arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1}  parameter(0)
+  idxs = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1}  parameter(1)
+  constant0 = f32[] constant(0)
+  constant1 = u32[] constant(0)
+  ROOT reduce0 = (
+      f32[4,5,16,12,12]{4,3,2,1,0},
+      u32[4,5,16,12,12]{4,3,2,1,0}
+    ) reduce(arg0, idxs, constant0,constant1), dimensions={1,6,7}, to_apply=argmax
+}
+
+
+)";
+
+  CheckReductionLayoutNormalizer(hlo,
+                                 R"(
+// CHECK:  [[arg0_0:%[^ ]+]] = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_0]])
+// CHECK:  [[idxs_2:%[^ ]+]] = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(1)
+// CHECK:  [[bitcast_1_3:%[^ ]+]] = u32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[idxs_2]])
+// CHECK:  [[reduce_4:%[^ ]+]] = (f32[4,12,12,16,5]{2,1,3,4,0}, u32[4,12,12,16,5]{2,1,3,4,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[constant0_5:%[^ ]+]], [[constant1_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
+// CHECK:  [[get_tuple_element_8:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=0
+// CHECK:  [[bitcast_2_9:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_8]])
+// CHECK:  [[get_tuple_element_1_10:%[^ ]+]] = u32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=1
+// CHECK:  [[bitcast_3_11:%[^ ]+]] = u32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_1_10]])
+// CHECK:  ROOT [[tuple_12:%[^ ]+]] = (f32[4,5,16,12,12]{4,3,2,1,0}, u32[4,5,16,12,12]{4,3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
+      )");
+}
+
+TEST_F(ReductionLayoutNormalizerTest,
+       LayoutCanonicalizerTestVariadicDifferentLayouts) {
+  const char* hlo = R"(
+HloModule ReduceWithLayoutChangeVariadicDifferent
+
+argmax {
+  running_max = f32[] parameter(0)
+  running_max_idx = u32[] parameter(1)
+  current_value = f32[] parameter(2)
+  current_value_idx = u32[] parameter(3)
+
+  current = (f32[], u32[]) tuple(running_max, running_max_idx)
+  potential = (f32[], u32[]) tuple(current_value, current_value_idx)
+
+  cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+  new_max = f32[] select(cmp_code, current_value, running_max)
+  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
+}
+
+ENTRY main {
+  arg0 = f32[2,3,4,7]{2,1,0,3}  parameter(0)
+  idxs = u32[2,3,4,7]{3,2,1,0}  parameter(1)
+  constant0 = f32[] constant(0)
+  constant1 = u32[] constant(0)
+  ROOT reduce0 = (
+      f32[2,3,4]{2,1,0},
+      u32[2,3,4]{2,1,0}
+    ) reduce(arg0, idxs, constant0,constant1), dimensions={3}, to_apply=argmax
+}
+
+
+)";
+
+  CheckReductionLayoutNormalizer(hlo,
+                                 R"(
+// CHECK:  [[arg0_0:%[^ ]+]] = f32[2,3,4,7]{2,1,0,3} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[7,2,3,4]{3,2,1,0} bitcast([[arg0_0]])
+// CHECK:  [[idxs_2:%[^ ]+]] = u32[2,3,4,7]{3,2,1,0} parameter(1)
+// CHECK:  [[copy_3:%[^ ]+]] = u32[2,3,4,7]{2,1,0,3} copy([[idxs_2]])
+// CHECK:  [[bitcast_1_4:%[^ ]+]] = u32[7,2,3,4]{3,2,1,0} bitcast([[copy_3]])
+// CHECK:  ROOT [[reduce0_5:%[^ ]+]] = (f32[2,3,4]{2,1,0}, u32[2,3,4]{2,1,0}) reduce([[bitcast_1]], [[bitcast_1_4]], [[constant0_6:%[^ ]+]], [[constant1_7:%[^ ]+]]), dimensions={0}, to_apply=[[argmax_8:%[^ ]+]]
+      )");
+  EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-5, 1e-5}));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc
new file mode 100644
index 0000000..dce9288
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc
@@ -0,0 +1,140 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_splitter.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/reduction_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+class ReductionSplitterVisitor : public DfsHloRewriteVisitor {
+ public:
+  explicit ReductionSplitterVisitor(bool ignore_small_dims)
+      : ignore_small_dims_(ignore_small_dims) {}
+  absl::Status HandleReduce(HloInstruction *reduce) override {
+    VLOG(4) << "Input: " << reduce->ToString();
+
+    // Reductions with contiguous dimensions are lowered to efficient code. No
+    // need to split such ops.
+    if (IsReductionFromOrToContiguousDimensions(*reduce)) {
+      VLOG(4) << "Reduction with contiguous dimensions. Return.";
+      return absl::OkStatus();
+    }
+    if (reduce->dimensions().size() < 2) {
+      return absl::OkStatus();
+    }
+    if (!reduce->shape().IsArray()) {
+      // TODO(cheshire): Handle variadic reduction.
+      return absl::OkStatus();
+    }
+
+    HloInstruction *operand = reduce->mutable_operand(0);
+    const Shape &shape = operand->shape();
+    CHECK(shape == LayoutUtil::GetWithDefaultLayout(shape))
+        << "Default layout should be enforced on reduction operand";
+    // Verify that contiguous dimensions have been grouped by the
+    // ReductionDimensionGrouper pass.
+    for (int64_t i = 0; i < reduce->dimensions().size(); ++i) {
+      for (int64_t j = i + 1; j < reduce->dimensions().size(); ++j) {
+        CHECK(abs(reduce->dimensions(i) - reduce->dimensions(j)) > 1)
+            << "Reduction dimensions must not be consecutive";
+      }
+    }
+
+    // The reduce op has non-contiguous dimensions. Look for the dimension with
+    // the largest shape dimension. Reducing along this dimension first will
+    // reduce the output size most effectively.
+    int64_t max_shape_dim = 0;
+    int64_t max_reduce_dim = 0;
+    const auto &input_shape = reduce->operand(0)->shape();
+    for (int64_t i = 0; i < reduce->dimensions().size(); ++i) {
+      if (input_shape.dimensions(reduce->dimensions(i)) > max_shape_dim) {
+        max_reduce_dim = reduce->dimensions(i);
+        max_shape_dim = input_shape.dimensions(max_reduce_dim);
+      }
+    }
+    if (ignore_small_dims_ && max_shape_dim <= 8) {
+      return absl::OkStatus();
+    }
+
+    // Split the reduction into a pre-reduction and a final reduction.
+    VLOG(3) << "Splitting reduction " << reduce->name() << " at dimension "
+            << max_reduce_dim;
+    std::vector<int64_t> pre_reduce_dims;
+    pre_reduce_dims.push_back(max_reduce_dim);
+    std::vector<int64_t> pre_reduce_shape_dims(input_shape.dimensions().begin(),
+                                               input_shape.dimensions().end());
+    pre_reduce_shape_dims.erase(pre_reduce_shape_dims.begin() + max_reduce_dim);
+    Shape pre_reduce_shape = ShapeUtil::MakeShape(
+        reduce->shape().element_type(), pre_reduce_shape_dims);
+    std::unique_ptr<HloInstruction> pre_reduce = HloInstruction::CreateReduce(
+        pre_reduce_shape, reduce->mutable_operand(0),
+        reduce->mutable_operand(1), pre_reduce_dims, reduce->to_apply());
+    pre_reduce->set_metadata(reduce->metadata());
+
+    std::vector<int64_t> final_reduce_dims(reduce->dimensions().begin(),
+                                           reduce->dimensions().end());
+    final_reduce_dims.erase(
+        std::remove(final_reduce_dims.begin(), final_reduce_dims.end(),
+                    max_reduce_dim),
+        final_reduce_dims.end());
+    for (int64_t i = 0; i < final_reduce_dims.size(); ++i) {
+      if (final_reduce_dims[i] > max_reduce_dim) {
+        final_reduce_dims[i]--;
+      }
+    }
+    std::unique_ptr<HloInstruction> final_reduce = HloInstruction::CreateReduce(
+        reduce->shape(),
+        reduce->parent()->AddInstruction(std::move(pre_reduce)),
+        reduce->mutable_operand(1), final_reduce_dims, reduce->to_apply());
+    return ReplaceWithNewInstruction(reduce, std::move(final_reduce));
+  }
+
+ private:
+  bool ignore_small_dims_;
+};
+
+absl::StatusOr<bool> ReductionSplitter::Run(
+    HloModule *module,
+    const absl::flat_hash_set<absl::string_view> &execution_threads) {
+  TF_ASSIGN_OR_RETURN(bool changed,
+                      ReductionSplitterVisitor(ignore_small_dims_)
+                          .RunOnModule(module, execution_threads));
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h
new file mode 100644
index 0000000..87520d3
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h
@@ -0,0 +1,59 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_SPLITTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_SPLITTER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Splits a reduce op into two consecutive reduce ops if the reduce dimensions
+// are not contiguous. Ignores small reduce dimensions if `ignore_small_dims` is
+// set.
+//
+// Reductions with non-contiguous dimensions are emitted as simple element-wise
+// loops. This is inefficient when reducing large input shape dimensions.
+// Splitting such reductions allows using more efficient reduction emitters.
+//
+// This pass splits reduce ops into two consecutive reduce ops. Run it to a
+// fixpoint to split reduce ops along multiple dimensions.
+//
+// Precondition: ReductionDimensionGrouper has been run and adjacent reduce
+// dimensions have been grouped. Reduction layouts have been normalized.
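+//
+// For example (mirroring the accompanying tests),
+//
+//   f[16,64] out = reduce(f[6,16,512,64] input, dimensions={0,2})
+//
+// is split at dimension 2 (the largest reduced dimension) into:
+//
+//   f[6,16,64] pre = reduce(f[6,16,512,64] input, dimensions={2})
+//   f[16,64] out = reduce(f[6,16,64] pre, dimensions={0})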
+
+class ReductionSplitter : public HloModulePass {
+ public:
+  explicit ReductionSplitter(bool ignore_small_dims)
+      : ignore_small_dims_(ignore_small_dims) {}
+  absl::string_view name() const override { return "reduction-splitter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  bool ignore_small_dims_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_SPLITTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc
new file mode 100644
index 0000000..4b9f6fb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc
@@ -0,0 +1,152 @@
+/* Copyright 2016 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/reduction_splitter.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape_util.h"
+#include "xla/test.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class ReductionSplitterTest : public HloTestBase {};
+
+TEST_F(ReductionSplitterTest, SplitReductionAtDimensionTwo) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test
+
+  add_computation {
+    x = f32[] parameter(0)
+    y = f32[] parameter(1)
+    ROOT add = f32[] add(x, y)
+  }
+
+  ENTRY entry_computation {
+    param_0 = f16[6,16,512,64]{3,2,1,0} parameter(0)
+    transpose.1781 = f16[6,512,16,64]{3,1,2,0} transpose(param_0), dimensions={0,2,1,3}
+    convert.6986 = f32[6,512,16,64]{3,1,2,0} convert(transpose.1781)
+    bitcast.2136 = f32[6,16,512,64]{3,2,1,0} bitcast(convert.6986)
+    constant_11111 = f32[] constant(0)
+    ROOT reduce.982 = f32[16,64]{1,0} reduce(bitcast.2136, constant_11111), dimensions={0,2}, to_apply=add_computation
+  }
+  )")
+                    .value();
+  ASSERT_TRUE(
+      ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root_reduction =
+      module->entry_computation()->root_instruction();
+  ASSERT_THAT(root_reduction,
+              GmockMatch(m::Reduce(m::Reduce(), m::Constant())));
+
+  auto* pre_reduction = root_reduction->operand(0);
+  EXPECT_THAT(pre_reduction->dimensions(), std::vector<int64_t>({2}));
+  EXPECT_THAT(pre_reduction->shape(), ShapeUtil::MakeShape(F32, {6, 16, 64}));
+  EXPECT_THAT(root_reduction->dimensions(), std::vector<int64_t>({0}));
+  EXPECT_THAT(root_reduction->shape(), ShapeUtil::MakeShape(F32, {16, 64}));
+}
+
+TEST_F(ReductionSplitterTest, SplitReductionAtDimensionZero) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test
+
+  add_computation {
+    x = f32[] parameter(0)
+    y = f32[] parameter(1)
+    ROOT add = f32[] add(x, y)
+  }
+
+  ENTRY entry_computation {
+    param_0 = f32[1024,16,512,64,128]{4,3,2,1,0} parameter(0)
+    constant_11111 = f32[] constant(0)
+    ROOT reduce.982 = f32[16,64]{1,0} reduce(param_0, constant_11111), dimensions={2,0,4}, to_apply=add_computation
+  }
+  )")
+                    .value();
+  ASSERT_TRUE(
+      ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root_reduction =
+      module->entry_computation()->root_instruction();
+  ASSERT_THAT(root_reduction,
+              GmockMatch(m::Reduce(m::Reduce(), m::Constant())));
+
+  auto* pre_reduction = root_reduction->operand(0);
+  EXPECT_THAT(pre_reduction->dimensions(), std::vector<int64_t>({0}));
+  EXPECT_THAT(pre_reduction->shape(),
+              ShapeUtil::MakeShape(F32, {16, 512, 64, 128}));
+  EXPECT_THAT(root_reduction->dimensions(), std::vector<int64_t>({1, 3}));
+  EXPECT_THAT(root_reduction->shape(), ShapeUtil::MakeShape(F32, {16, 64}));
+}
+
+TEST_F(ReductionSplitterTest, DontSplitReductionWithSmallDimensions) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test
+
+  add_computation {
+    x = f32[] parameter(0)
+    y = f32[] parameter(1)
+    ROOT add = f32[] add(x, y)
+  }
+
+  ENTRY entry_computation {
+    param_0 = f32[16,8,1024,8]{3,2,1,0} parameter(0)
+    constant_11111 = f32[] constant(0)
+    ROOT reduce.982 = f32[16,1024]{1,0} reduce(param_0, constant_11111), dimensions={3,1}, to_apply=add_computation
+  }
+  )")
+                    .value();
+  EXPECT_FALSE(
+      ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value());
+  EXPECT_TRUE(
+      ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
+}
+
+TEST_F(ReductionSplitterTest, DontSplitReductionsWithContiguousDimensions) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test
+
+  add_computation {
+    x = f32[] parameter(0)
+    y = f32[] parameter(1)
+    ROOT add = f32[] add(x, y)
+  }
+
+  ENTRY entry_computation {
+    param_0 = f32[128,128,64,128]{3,2,1,0} parameter(0)
+    constant_11111 = f32[] constant(0)
+    // The dimensions to keep (1 and 2) are contiguous.
+    ROOT reduce.982 = f32[128,64]{1,0} reduce(param_0, constant_11111), dimensions={3,0}, to_apply=add_computation
+  }
+  )")
+                    .value();
+  EXPECT_FALSE(
+      ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc b/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc
new file mode 100644
index 0000000..9ab62f6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc
@@ -0,0 +1,92 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/rename_fusions.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/container/btree_set.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+constexpr absl::string_view FusionKindToString(
+    HloInstruction::FusionKind kind) {
+  switch (kind) {
+    case HloInstruction::FusionKind::kCustom:
+      return "custom";
+    case HloInstruction::FusionKind::kLoop:
+      return "loop";
+    case HloInstruction::FusionKind::kInput:
+      return "input";
+    case HloInstruction::FusionKind::kOutput:
+      return "output";
+  }
+}
+
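+// Collects the hero opcode names of all fusion roots into a sorted,
+// deduplicated, underscore-separated string, replacing '-' with '_'
+// (e.g. "all-reduce" becomes "all_reduce").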
+std::string MakeFusionHeroNames(const HloInstruction* instruction) {
+  std::unique_ptr<HloFusionAdaptor> fusion_adaptor =
+      HloFusionAdaptor::ForInstruction(instruction);
+  absl::btree_set<absl::string_view> heroes;
+
+  for (auto root : fusion_adaptor->GetRoots()) {
+    heroes.insert(HloOpcodeString(FindNonTrivialHero(root).opcode()));
+  }
+  return absl::StrReplaceAll(absl::StrJoin(heroes, "_"), {{"-", "_"}});
+}
+
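+// Renames the fusion instruction and its fused computation based on the
+// fusion kind and the hero names.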
+void RenameFusion(HloModule* module, HloInstruction* instruction) {
+  std::string hero_names = MakeFusionHeroNames(instruction);
+  module->SetAndUniquifyInstrName(
+      instruction, absl::StrCat(FusionKindToString(instruction->fusion_kind()),
+                                "_", hero_names, "_fusion"));
+  module->SetAndUniquifyComputationName(
+      instruction->fused_instructions_computation(),
+      absl::StrCat("fused_", hero_names));
+}
+
+}  // namespace
+
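+// Renames every non-custom fusion in the module. Note that the pass
+// unconditionally reports that it changed the module.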
+absl::StatusOr<bool> RenameFusions::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  for (HloComputation* computation : module->MakeNonfusionComputations()) {
+    for (HloInstruction* instruction : computation->instructions()) {
+      if (instruction->opcode() != HloOpcode::kFusion ||
+          instruction->fusion_kind() == HloInstruction::FusionKind::kCustom) {
+        continue;
+      }
+      RenameFusion(module, instruction);
+    }
+  }
+  return true;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/rename_fusions.h b/third_party/xla/xla/service/gpu/transforms/rename_fusions.h
new file mode 100644
index 0000000..5abcd61
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions.h
@@ -0,0 +1,47 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_RENAME_FUSIONS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_RENAME_FUSIONS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// An HLO pass that gives fusions and fused computations descriptive names.
+//
+// The name is based on hero instructions and the fusion kind, i.e.
+// fusions get the name "<fusion kind>_<hero instructions>_fusion",
+// and fused computations get the name "fused_<hero instructions>".
+// In the case of multiple roots, the hero instructions in the name are
+// underscore-separated and alphabetically sorted.
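+//
+// For example, a loop fusion whose hero is a multiply is renamed to
+// "loop_multiply_fusion", and its fused computation to "fused_multiply".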
+
+class RenameFusions : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "rename_fusions"; }
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_RENAME_FUSIONS_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/rename_fusions_test.cc b/third_party/xla/xla/service/gpu/transforms/rename_fusions_test.cc
new file mode 100644
index 0000000..4747085
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/rename_fusions.h"
+
+#include <utility>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+
+class RenameFusionsTest : public HloTestBase {
+ protected:
+  RenameFusions rename_fusions_;
+};
+
+TEST_F(RenameFusionsTest, FusionInstructionNames) {
+  absl::string_view kHlo = R"(
+      HloModule test_module
+
+      square {
+        p = f32[16384] parameter(0)
+        ROOT m = f32[16384] multiply(p, p)
+      }
+
+      exp {
+        p = f32[16384] parameter(0)
+        ROOT e = f32[16384] exponential(p)
+      }
+
+      log {
+        p = f32[16384] parameter(0)
+        ROOT l = f32[16384] log(p)
+      }
+
+      add {
+        p0 = f32[] parameter(0)
+        p1 = f32[] parameter(1)
+        ROOT add = f32[] add(p0, p1)
+      }
+
+      ENTRY main {
+        p0 = bf16[1024,8192] parameter(0)
+        p1 = f32[8192] parameter(1)
+        p2 = f32[16384] parameter(2)
+        convert = f32[1024,8192] convert(p0)
+        broadcast = f32[1024,8192] broadcast(p1), dimensions={1}
+        c0 = f32[] constant(0)
+        multiply = f32[1024,8192] multiply(broadcast, convert)
+        reduce = f32[1024] reduce(multiply, c0), dimensions={1}, to_apply=add
+        convert.1 = bf16[1024] convert(reduce)
+        s = f32[16384] fusion(p2), kind=kLoop, calls=square
+        e = f32[16384] fusion(s), kind=kLoop, calls=exp
+        l = f32[16384] fusion(s), kind=kInput, calls=log
+        ROOT result = (bf16[1024]{0}, f32[16384]{0}, f32[16384]{0}) tuple(convert.1, l, e)
+      })";
+
+  RunAndFilecheckHloRewrite(kHlo, std::move(rename_fusions_), R"(
+CHECK: ENTRY %main
+CHECK: %loop_multiply_fusion{{.*}} calls=%fused_multiply
+CHECK: %input_log_fusion{{.*}} calls=%fused_log
+CHECK: %loop_exponential_fusion{{.*}} calls=%fused_exponential
+CHECK: ROOT %result
+  )");
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.cc b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.cc
new file mode 100644
index 0000000..3841f4a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.cc
@@ -0,0 +1,75 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/sanitize_constant_names.h"
+
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/llvm_ir/buffer_assignment_util.h"
+#include "xla/service/name_uniquer.h"
+#include "tsl/platform/logging.h"
+
+namespace xla {
+
+namespace gpu {
+
+absl::StatusOr<bool> SanitizeConstantNames::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+
+  NameUniquer instr_name_uniquer(/*separator=*/"_");
+  // Collect the names used for the non-constant HLO instructions.
+  for (HloComputation* computation : module->computations(execution_threads)) {
+    for (HloInstruction* instr : computation->instructions()) {
+      if (instr->opcode() == HloOpcode::kConstant) {
+        continue;
+      }
+
+      // Record the non-constant HLO instruction name in the uniquer, and keep
+      // the original instruction name unchanged.
+      instr_name_uniquer.GetUniqueName(instr->name());
+    }
+  }
+
+  // Sanitize the names for the constant HLO instructions and make them unique.
+  // This is not merged into the above loop because we don't want this pass to
+  // change the names of non-constant instructions, that is, if a constant HLO
+  // conflicts with a non-constant HLO, we change the name of the constant HLO
+  // even though the non-constant HLO comes after in the HLO module.
+  for (HloComputation* computation : module->computations(execution_threads)) {
+    for (HloInstruction* instr : computation->instructions()) {
+      if (instr->opcode() != HloOpcode::kConstant) {
+        continue;
+      }
+      std::string sanitized_name = llvm_ir::SanitizeConstantName(*instr);
+      instr->SetAndSanitizeName(sanitized_name);
+      instr->UniquifyName(&instr_name_uniquer);
+      // Register this new name with the module's instruction_name_uniquer to
+      // avoid name collisions that might happen in the future.
+      module->instruction_name_uniquer().GetUniqueName(instr->name());
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.h b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.h
new file mode 100644
index 0000000..f743137
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.h
@@ -0,0 +1,44 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SANITIZE_CONSTANT_NAMES_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SANITIZE_CONSTANT_NAMES_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Sanitizes HLO instruction names for the GPU backend. Currently, it only
+// replaces . and - in the HLO constant instruction names with _ to please the
+// LLVM PTX backend.
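+// For example, constants named "equal-to" or "equal.to" are renamed to
+// "equal_to" (and uniquified if that name is already taken).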
+class SanitizeConstantNames : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "sanitize-constant-names"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_SANITIZE_CONSTANT_NAMES_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names_test.cc b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names_test.cc
new file mode 100644
index 0000000..8e97790
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names_test.cc
@@ -0,0 +1,111 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/sanitize_constant_names.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/literal_util.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+using SanitizeConstantNamesTest = HloTestBase;
+
+TEST_F(SanitizeConstantNamesTest, InstructionNameWithHyphenSanitized) {
+  const char *const kHloString = R"(
+    HloModule HyphenInInstructionName
+      ENTRY kernelEntry {
+        ROOT equal-to = s32[2]{0} constant({42, 73})
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value());
+  HloInstruction *root = module->entry_computation()->root_instruction();
+  EXPECT_EQ(root->name(), "equal_to");
+}
+
+TEST_F(SanitizeConstantNamesTest, InstructionNameWithDotSanitized) {
+  const char *const kHloString = R"(
+    HloModule HyphenInInstructionName
+      ENTRY kernelEntry {
+        ROOT equal.to = s32[2]{0} constant({42, 73})
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value());
+  HloInstruction *root = module->entry_computation()->root_instruction();
+  EXPECT_EQ(root->name(), "equal_to");
+}
+
+TEST_F(SanitizeConstantNamesTest, NewInstructionNameRegisteredWithModule) {
+  const char *const kHloString = R"(
+    HloModule HyphenInInstructionName
+      ENTRY kernelEntry {
+        ROOT equal.to = s32[2]{0} constant({42, 73})
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value());
+  HloInstruction *root = module->entry_computation()->root_instruction();
+  EXPECT_EQ(root->name(), "equal_to");
+
+  auto constant_instr =
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1));
+  constant_instr->SetAndSanitizeName("equal_to");
+  module->entry_computation()->AddInstruction(std::move(constant_instr));
+
+  EXPECT_THAT(FindInstruction(module.get(), "equal_to.1"),
+              GmockMatch(m::Constant()));
+}
+
+TEST_F(SanitizeConstantNamesTest, BufferSanitizedNameCollisionResolved) {
+  const char *const kHloString = R"(
+    HloModule BufferSanitizedName
+      ENTRY kernelEntry {
+      equal.to = s32[2]{0} constant({42, 73})
+      equal-to = s32[2]{0} constant({67, 3})
+      ROOT equal_to = s32[2]{0} add(equal.to, equal-to)
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value());
+  EXPECT_THAT(FindInstruction(module.get(), "equal_to_1"),
+              GmockMatch(m::Constant()));
+  EXPECT_THAT(FindInstruction(module.get(), "equal_to_2"),
+              GmockMatch(m::Constant()));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/scatter_expander.cc b/third_party/xla/xla/service/gpu/transforms/scatter_expander.cc
new file mode 100644
index 0000000..26eb210
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scatter_expander.cc
@@ -0,0 +1,33 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/scatter_expander.h"
+
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/primitive_util.h"
+
+namespace xla {
+
+bool GpuScatterExpander::InstructionMatchesPattern(HloInstruction* inst) {
+  // TODO(b/129698548): Scattering elements larger than 64 bits is not
+  // supported by XLA:GPU.
+  // TODO(b/227486631): Variadic scatter is not yet supported by GPU.
+  return inst->opcode() == HloOpcode::kScatter &&
+         (inst->shape().IsTuple() ||
+          primitive_util::BitWidth(inst->shape().element_type()) > 64);
+}
+
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/scatter_expander.h b/third_party/xla/xla/service/gpu/transforms/scatter_expander.h
new file mode 100644
index 0000000..f86b932
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scatter_expander.h
@@ -0,0 +1,40 @@
+/* Copyright 2019 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCATTER_EXPANDER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SCATTER_EXPANDER_H_
+
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/scatter_expander.h"
+
+namespace xla {
+
+// Legalizes scatters on the GPU.
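+// Currently this covers scatters that the GPU backend cannot handle natively:
+// variadic (tuple-shaped) scatters and scatters whose element type is wider
+// than 64 bits (see InstructionMatchesPattern).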
+class GpuScatterExpander : public ScatterExpander {
+ public:
+  // Although we pass kEliminateAllScatters, we override this behavior in
+  // InstructionMatchesPattern and select only some scatters to expand.
+  GpuScatterExpander() : ScatterExpander(kEliminateAllScatters) {}
+
+  absl::string_view name() const override { return "gpu_scatter_expander"; }
+
+ protected:
+  bool InstructionMatchesPattern(HloInstruction* inst) override;
+};
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_SCATTER_EXPANDER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc
new file mode 100644
index 0000000..d9c1deb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc
@@ -0,0 +1,264 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/scatter_slice_simplifier.h"
+
+#include <cstdint>
+#include <iterator>
+#include <optional>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+// Returns whether the instruction is a valid intermediary between the scatter
+// and the slice instructions.
+bool IsValidIntermediaryUser(const HloInstruction* instruction) {
+  // Allow elementwise instructions, as they don't depend on the truncated
+  // elements. In case of multi-output scatters, the resulting shape is a tuple.
+  return instruction->IsElementwise() ||
+         instruction->opcode() == HloOpcode::kGetTupleElement;
+}
+
+// Matches the "Scatter -> Elementwise (zero or more) -> Slice" pattern.
+// Calculates the resulting scatter dimensions from the slice users.
+class ScatterSliceMatcher {
+ public:
+  explicit ScatterSliceMatcher(const HloScatterInstruction* scatter)
+      : scatter_(scatter),
+        operand_dimensions_(
+            scatter->scatter_operands()[0]->shape().dimensions()),
+        result_dimensions_(operand_dimensions_.begin(),
+                           operand_dimensions_.end()) {}
+
+  // Determine the scatter shape from the user slice instructions.
+  // If any of the users are not truncation slices, return `nullopt`.
+  std::optional<Shape> InferShape() {
+    VLOG(10) << "Evaluating scatter " << scatter_->name();
+    if (!AreAllUsersValid(scatter_)) {
+      return std::nullopt;
+    }
+    std::vector<Shape> result_shapes;
+    absl::c_transform(scatter_->scatter_operands(),
+                      std::back_inserter(result_shapes),
+                      [&](const HloInstruction* op) {
+                        return ShapeUtil::MakeShape(op->shape().element_type(),
+                                                    result_dimensions_);
+                      });
+    return ShapeUtil::MakeMaybeTupleShape(result_shapes);
+  }
+
+ private:
+  // Update the resulting scatter dimensions from the slice configuration and
+  // the original scatter dimensions. Return `false` if the update is not
+  // possible.
+  bool UpdateDimensions(const HloSliceInstruction* slice) {
+    int64_t rank = slice->shape().rank();
+    for (int64_t i = 0; i < rank; ++i) {
+      if (slice->slice_starts(i) != 0 || slice->slice_strides(i) != 1) {
+        return false;  // The slice is not a truncation.
+      }
+      if (slice->slice_limits(i) != result_dimensions_[i]) {
+        if (result_dimensions_[i] != operand_dimensions_[i]) {
+          return false;  // Another slice has incompatible dimensions.
+        }
+        auto& update_window_dims =
+            scatter_->scatter_dimension_numbers().update_window_dims();
+        if (absl::c_binary_search(update_window_dims, i)) {
+          return false;  // Update dimensions cannot be truncated.
+        }
+        result_dimensions_[i] = slice->slice_limits(i);
+        VLOG(10) << "Dimension " << i << " truncated to size "
+                 << result_dimensions_[i];
+      }
+    }
+    return true;
+  }
+
+  // Verify that the instruction is a valid scatter user, i.e. is either a slice
+  // operation or is an elementwise operation that has slice users (recursive).
+  bool IsUserValid(const HloInstruction* op) {
+    VLOG(10) << "Visiting user " << op->name();
+
+    // If the user is a slice operation, verify the configuration and update
+    // the resulting dimensions.
+    if (auto* slice = DynCast<HloSliceInstruction>(op)) {
+      return UpdateDimensions(slice);
+    }
+    // If the user is an elementwise operation, verify the users recursively
+    // (unless already visited).
+    bool is_valid = visited_set_.contains(op) ||
+                    (IsValidIntermediaryUser(op) && AreAllUsersValid(op));
+    if (is_valid) {
+      visited_set_.emplace(op);
+    }
+    return is_valid;
+  }
+
+  // Verify that all users are valid (see the definition of IsUserValid).
+  // If we reach the root instruction, fail the matching (slice is not found).
+  bool AreAllUsersValid(const HloInstruction* op) {
+    if (op->user_count() == 0) {
+      return !op->IsRoot();
+    }
+    return absl::c_all_of(op->users(), [this](const HloInstruction* user) {
+      return IsUserValid(user);
+    });
+  }
+
+  const HloScatterInstruction* scatter_;
+  absl::flat_hash_set<const HloInstruction*> visited_set_;
+  absl::Span<const int64_t> operand_dimensions_;
+  DimensionVector result_dimensions_;
+};
+
+// Create a replacement operand for the scatter instruction.
+HloInstruction* CreateSliceFrom(HloInstruction* operand, const Shape& shape) {
+  std::vector<int64_t> start_indices(shape.rank(), 0);
+  std::vector<int64_t> limit_indices(shape.rank());
+  std::vector<int64_t> strides(shape.rank(), 1);
+  for (int64_t i = 0; i < shape.rank(); ++i) {
+    limit_indices[i] = shape.dimensions(i);
+  }
+  return operand->AddInstruction(HloInstruction::CreateSlice(
+      shape, operand, start_indices, limit_indices, strides));
+}
+
+// Create a replacement for the scatter instruction.
+HloInstruction* CreateScatterFrom(HloScatterInstruction* scatter,
+                                  const Shape& shape) {
+  std::vector<HloInstruction*> operands(scatter->scatter_operand_count());
+  for (int64_t i = 0; i < operands.size(); ++i) {
+    operands[i] =
+        CreateSliceFrom(scatter->scatter_operands()[i],
+                        shape.IsTuple() ? shape.tuple_shapes(i) : shape);
+  }
+  return scatter->AddInstruction(HloInstruction::CreateScatter(
+      shape, absl::MakeSpan(operands), scatter->scatter_indices(),
+      scatter->scatter_updates(), scatter->called_computations()[0],
+      scatter->scatter_dimension_numbers(), scatter->indices_are_sorted(),
+      scatter->unique_indices()));
+}
+
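+// Rewrites matched scatters and the intermediate elementwise/get-tuple-element
+// users on the way to the truncation slices; the slices themselves are
+// replaced by the new (smaller) instructions.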
+class ScatterSliceSimplifierVisitor : public DfsHloRewriteVisitor {
+ public:
+  absl::Status HandleScatter(HloInstruction* instruction) override {
+    auto* scatter = Cast<HloScatterInstruction>(instruction);
+
+    // Infer scatter shape from the slice users.
+    std::optional<Shape> result_shape =
+        ScatterSliceMatcher(scatter).InferShape();
+    if (!result_shape.has_value()) {
+      return absl::OkStatus();
+    }
+    VLOG(2) << "Matched scatter " << scatter->name() << " with shape "
+            << scatter->shape().ToString() << ", inferred result shape "
+            << result_shape->ToString() << " (from the slice users)";
+
+    // Replace slice user instructions.
+    HloInstruction* new_scatter = CreateScatterFrom(scatter, *result_shape);
+    return ReplaceAllUsersRecursive(scatter, new_scatter);
+  }
+
+ private:
+  // Create a replacement for every user. If the user is a slice operation,
+  // replace it in the computation graph; the old branch will be removed.
+  absl::Status ReplaceAllUsersRecursive(HloInstruction* old_instruction,
+                                        HloInstruction* new_instruction) {
+    // Maintain the replacement map, needed for non-unary elementwise users.
+    replacements_[old_instruction] = new_instruction;
+
+    // It's important to make a copy of the users list, as it may be modified
+    // during the iteration.
+    std::vector<HloInstruction*> users = old_instruction->users();
+    for (HloInstruction* user : users) {
+      if (user->parent() == nullptr) {
+        VLOG(3) << "Skipping user " << user->name() << " (already replaced)";
+        continue;
+      }
+      TF_RETURN_IF_ERROR(ReplaceUserRecursive(user, new_instruction));
+    }
+    return absl::OkStatus();
+  }
+
+  // Replace the slice user with a new scatter (or a new chain of operations
+  // starting with a scatter). For elementwise operations, create a new user
+  // with updated operands (build the chain).
+  absl::Status ReplaceUserRecursive(HloInstruction* user,
+                                    HloInstruction* operand) {
+    VLOG(3) << "Replacing scatter user " << user->name();
+    if (user->opcode() == HloOpcode::kSlice) {
+      return ReplaceInstruction(user, operand);
+    }
+
+    // Create the replacement instruction with new shape.
+    HloInstruction* new_user = nullptr;
+    if (user->IsElementwise()) {
+      auto new_shape = [operand](HloInstruction* from) {
+        return ShapeUtil::MakeShape(from->shape().element_type(),
+                                    operand->shape().dimensions());
+      };
+      std::vector<HloInstruction*> new_operands;
+      absl::c_transform(user->operands(), std::back_inserter(new_operands),
+                        [&](HloInstruction* op) {
+                          auto it = replacements_.find(op);
+                          return it != replacements_.end()
+                                     ? it->second
+                                     : CreateSliceFrom(op, new_shape(op));
+                        });
+      new_user = user->AddInstruction(
+          user->CloneWithNewOperands(new_shape(user), new_operands));
+    } else {
+      auto* gte = Cast<HloGetTupleElementInstruction>(user);
+      TF_ASSIGN_OR_RETURN(new_user,
+                          MakeGetTupleElementHlo(operand, gte->tuple_index(),
+                                                 &user->metadata()));
+    }
+
+    // Replace slice user instructions recursively.
+    return ReplaceAllUsersRecursive(user, new_user);
+  }
+
+  absl::flat_hash_map<HloInstruction*, HloInstruction*> replacements_;
+};
+
+}  // namespace
+
+absl::StatusOr<bool> ScatterSliceSimplifier::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  return ScatterSliceSimplifierVisitor{}.RunOnModule(module, execution_threads);
+}
+
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.h b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.h
new file mode 100644
index 0000000..96f39b5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.h
@@ -0,0 +1,58 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCATTER_SLICE_SIMPLIFIER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SCATTER_SLICE_SIMPLIFIER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// Replaces a scatter followed by truncation slices with a new scatter that
+// uses the truncated output shape directly; the slices are then eliminated.
+//
+// (a) Single output     (b) Multiple outputs    (c) Elementwise users
+//
+//   T[N+1] scatter       (T1, T2) scatter         T scatter  T constant
+//          v                v       v                  v     v
+//     T[N] slice          T1 gte    T2 gte            T maximum
+//                           v       v                     v
+//                         T1 slice  T2 slice           T slice
+//
+// This pattern is used when the last element of the scatter output is intended
+// to accumulate updates from the input elements that should be ignored.
+// This is slow if there are many indices mapped to the last output index and
+// the scatter is implemented using atomics, so everything collides on that one
+// memory location.
+// As OOB scatter indices are dropped by the GPU implementation, we can remove
+// the slice step entirely and avoid the memory congestion in the scatter step.
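+//
+// For example (case (a) above), an f32[9] scatter whose only user is an
+// f32[8] truncation slice is rewritten into an f32[8] scatter over sliced
+// operands, and the slice is removed.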
+
+class ScatterSliceSimplifier : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "scatter-slice-simplifier"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_SCATTER_SLICE_SIMPLIFIER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc
new file mode 100644
index 0000000..8f1c93c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc
@@ -0,0 +1,336 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/scatter_slice_simplifier.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace {
+
+namespace m = ::xla::match;
+
+using ScatterSliceSimplifierTest = HloTestBase;
+
+TEST_F(ScatterSliceSimplifierTest, Scatter1D) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32 {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY main {
+  %indices = s32[4] parameter(0)
+  %updates = f32[4] parameter(1)
+  %operands = f32[9] constant(0)
+  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+  ROOT %slice = f32[8] slice(%scatter), slice={[0:8]}
+}
+  )")
+                    .value();
+  ScatterSliceSimplifier test_pass;
+  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Scatter(m::Slice(m::Constant()), m::Parameter(0),
+                                    m::Parameter(1))
+                             .WithShape(F32, {8})));
+}
+
+TEST_F(ScatterSliceSimplifierTest, Scatter3D) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32 {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY main {
+  %indices = s32[2] parameter(0)
+  %updates = f32[2,4,4] parameter(1)
+  %operands = f32[5,4,4] constant(0)
+  %scatter = f32[5,4,4] scatter(%operands, %indices, %updates), update_window_dims={1,2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+  ROOT %slice = f32[4,4,4] slice(%scatter), slice={[0:4], [0:4], [0:4]}
+}
+  )")
+                    .value();
+  ScatterSliceSimplifier test_pass;
+  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Scatter(m::Slice(m::Constant()), m::Parameter(0),
+                                    m::Parameter(1))
+                             .WithShape(F32, {4, 4, 4})));
+}
+
+TEST_F(ScatterSliceSimplifierTest, ScatterMultiOutput) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32_add_F16 {
+  %lhs.0 = f32[] parameter(0)
+  %rhs.0 = f32[] parameter(2)
+  %add.0 = f32[] add(%lhs.0, %rhs.0)
+  %lhs.1 = f16[] parameter(1)
+  %rhs.1 = f16[] parameter(3)
+  %add.1 = f16[] add(%lhs.1, %rhs.1)
+  ROOT %tuple = (f32[], f16[]) tuple(%add.0, %add.1)
+}
+
+ENTRY main {
+  %indices = s32[4] parameter(0)
+  %updates.0 = f32[4] parameter(1)
+  %updates.1 = f16[4] parameter(2)
+  %operands.0 = f32[9] constant(0)
+  %operands.1 = f16[9] constant(0)
+  %scatter = (f32[9], f16[9]) scatter(%operands.0, %operands.1, %indices, %updates.0, %updates.1), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32_add_F16
+  %gte.0 = f32[9] get-tuple-element(%scatter), index=0
+  %slice.0 = f32[8] slice(%gte.0), slice={[0:8]}
+  %gte.1 = f16[9] get-tuple-element(%scatter), index=1
+  %slice.1 = f16[8] slice(%gte.1), slice={[0:8]}
+  ROOT %tuple = (f32[8], f16[8]) tuple(%slice.0, %slice.1)
+}
+  )")
+                    .value();
+  ScatterSliceSimplifier test_pass;
+  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+  auto expected_scatter =
+      m::Scatter(m::Slice(m::Constant()), m::Slice(m::Constant()),
+                 m::Parameter(0), m::Parameter(1), m::Parameter(2));
+
+  Shape expected_shape = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F16, {8})});
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::GetTupleElement(expected_scatter),
+                                  m::GetTupleElement(expected_scatter))
+                             .WithShapeEqualTo(&expected_shape)));
+}
+
+TEST_F(ScatterSliceSimplifierTest, NotMatching) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32 {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+slice_not_truncation {
+  %indices = s32[4] parameter(0)
+  %updates = f32[4] parameter(1)
+  %operands = f32[9] constant(0)
+  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+  ROOT %slice = f32[8] slice(%scatter), slice={[1:9]}
+}
+
+slice_with_stride {
+  %indices = s32[4] parameter(0)
+  %updates = f32[4] parameter(1)
+  %operands = f32[9] constant(0)
+  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+  ROOT %slice = f32[4] slice(%scatter), slice={[0:8:2]}
+}
+
+scatter_multiple_users {
+  %indices = s32[4] parameter(0)
+  %updates = f32[4] parameter(1)
+  %operands = f32[9] constant(0)
+  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+  %slice = f32[8] slice(%scatter), slice={[0:8]}
+  ROOT %tuple = (f32[9], f32[8]) tuple(%scatter, %slice)
+}
+
+scatter_incompatible_slices {
+  %indices = s32[2] parameter(0)
+  %updates = f32[2,4] parameter(1)
+  %operands = f32[4,4] constant(0)
+  %scatter = f32[4,4] scatter(%operands, %indices, %updates), update_window_dims={1}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+  %slice.0 = f32[3,4] slice(%scatter), slice={[0:3], [0:4]}
+  %slice.1 = f32[4,3] slice(%scatter), slice={[0:4], [0:3]}
+  ROOT %tuple = (f32[3,4], f32[4,3]) tuple(%slice.0, %slice.1)
+}
+
+slice_not_found {
+  %indices = s32[4] parameter(0)
+  %updates = f32[4] parameter(1)
+  %operands = f32[8] constant(0)
+  %scatter = f32[8] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+  ROOT %exp = f32[8] exponential(%scatter)
+}
+
+slice_update_dimensions {
+  %indices = s32[10] parameter(0)
+  %updates = f32[10,1,128] parameter(1)
+  %operands = f32[100,128] constant(0)
+  %scatter = f32[100,128] scatter(%operands, %indices, %updates), update_window_dims={1,2}, inserted_window_dims={}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+  ROOT %slice = f32[100,64] slice(%scatter), slice={[0:100], [0:64]}
+}
+  )")
+                    .value();
+  ScatterSliceSimplifier test_pass;
+  ASSERT_FALSE(RunHloPass(&test_pass, module.get()).value());
+}
+
+TEST_F(ScatterSliceSimplifierTest, IntermediaryUsers) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32 {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY main {
+  %indices = s32[4] parameter(0)
+  %updates = f32[4] parameter(1)
+  %operands = f32[9] constant(0)
+  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+  %unary = f32[9] abs(%scatter)
+  %slice.0 = f32[8] slice(%unary), slice={[0:8]}
+  %binary = f32[9] maximum(%scatter, %operands)
+  %slice.1 = f32[8] slice(%binary), slice={[0:8]}
+  ROOT %tuple = (f32[8], f32[8]) tuple(%slice.0, %slice.1)
+}
+  )")
+                    .value();
+  ScatterSliceSimplifier test_pass;
+  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+  auto expected_scatter =
+      m::Scatter(m::Slice(m::Constant()), m::Parameter(0), m::Parameter(1));
+
+  Shape expected_shape = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})});
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(m::Abs(expected_scatter),
+                          m::Maximum(expected_scatter, m::Slice(m::Constant())))
+                     .WithShapeEqualTo(&expected_shape)));
+}
+
+TEST_F(ScatterSliceSimplifierTest, IntermediaryChain) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32 {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY main {
+  %indices = s32[4] parameter(0)
+  %updates = f32[4] parameter(1)
+  %operands = f32[9] constant(0)
+  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+  %elementwise.0 = f32[9] abs(%scatter)
+  %elementwise.1 = f32[9] exponential(%elementwise.0)
+  %elementwise.2 = f32[9] add(%elementwise.0, %elementwise.1)
+  ROOT %result = f32[8] slice(%elementwise.2), slice={[0:8]}
+}
+  )")
+                    .value();
+  ScatterSliceSimplifier test_pass;
+  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+  auto expected_scatter =
+      m::Scatter(m::Slice(m::Constant()), m::Parameter(0), m::Parameter(1));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Add(m::Abs(expected_scatter),
+                                m::Exp(m::Abs(expected_scatter)))
+                             .WithShape(F32, {8})));
+}
+
+TEST_F(ScatterSliceSimplifierTest, DiamondShape) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32_mul_F32 {
+  %lhs.0 = f32[] parameter(0)
+  %rhs.0 = f32[] parameter(2)
+  %add.0 = f32[] add(%lhs.0, %rhs.0)
+  %lhs.1 = f32[] parameter(1)
+  %rhs.1 = f32[] parameter(3)
+  %mul.1 = f32[] multiply(%lhs.1, %rhs.1)
+  ROOT %tuple = (f32[], f32[]) tuple(%add.0, %mul.1)
+}
+
+ENTRY main {
+  %indices = s32[4] parameter(0)
+  %updates.0 = f32[4] parameter(1)
+  %updates.1 = f32[4] parameter(2)
+  %operands.0 = f32[9] constant(0)
+  %operands.1 = f32[9] constant(0)
+  %scatter = (f32[9], f32[9]) scatter(%operands.0, %operands.1, %indices, %updates.0, %updates.1), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32_mul_F32
+  %gte.0 = f32[9] get-tuple-element(%scatter), index=0
+  %gte.1 = f32[9] get-tuple-element(%scatter), index=1
+  %consumer = f32[9] add(%gte.0, %gte.1)
+  ROOT %slice = f32[8] slice(%consumer), slice={[0:8]}
+}
+  )")
+                    .value();
+  ScatterSliceSimplifier test_pass;
+  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+  auto expected_scatter =
+      m::Scatter(m::Slice(m::Constant()), m::Slice(m::Constant()),
+                 m::Parameter(0), m::Parameter(1), m::Parameter(2));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Add(m::GetTupleElement(expected_scatter),
+                                m::GetTupleElement(expected_scatter))
+                             .WithShape(F32, {8})));
+}
+
+TEST_F(ScatterSliceSimplifierTest, ElementwiseSelect) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule test_module
+
+%add_F32 {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY main {
+  %indices = s32[4] parameter(0)
+  %updates = f32[4] parameter(1)
+  %operands = f32[9] constant(0)
+  %scatter = f32[9] scatter(%operands, %indices, %updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%add_F32
+  %pred_ = pred[9] parameter(2)
+  %select = f32[9] select(%pred_, %scatter, %operands)
+  ROOT %slice = f32[8] slice(%select), slice={[0:8]}
+}
+  )")
+                    .value();
+  ScatterSliceSimplifier test_pass;
+  ASSERT_TRUE(RunHloPass(&test_pass, module.get()).value());
+  auto expected_scatter =
+      m::Scatter(m::Slice(m::Constant()), m::Parameter(0), m::Parameter(1));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Select(m::Slice(m::Parameter(2)), expected_scatter,
+                                   m::Slice(m::Constant()))
+                             .WithShape(F32, {8})));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc
new file mode 100644
index 0000000..9929b35
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc
@@ -0,0 +1,165 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/schedule_postprocessing.h"
+
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+// Maps a computation to a boolean that indicates whether the computation may
+// invoke custom-calls directly or indirectly, which can eventually trigger gpu
+// synchronization.
+using CustomCallInComputation =
+    absl::flat_hash_map<const HloComputation*, bool>;
+
+// Returns whether the hlo may invoke custom-calls which may trigger gpu
+// synchronization. Currently, we only check for custom-calls, because they are
+// the only operations that can be parallel with asynchronous collective
+// operations in an hlo-schedule and may trigger gpu synchronization.
+bool MayInvokeCustomCall(
+    const HloInstruction* hlo,
+    const CustomCallInComputation& custom_call_in_computation) {
+  if (hlo->opcode() == HloOpcode::kCustomCall) {
+    return true;
+  }
+
+  return absl::c_any_of(
+      hlo->called_computations(), [&](const HloComputation* callee) {
+        return custom_call_in_computation.find(callee)->second;
+      });
+}
+
+// Returns true if this is an asynchronous collective start operation, excluding
+// P2P operations.
+absl::StatusOr<bool> IsRelevantAsynchronousStart(const HloInstruction* hlo) {
+  if (!hlo_query::IsAsyncCollectiveStartOp(hlo,
+                                           /*include_send_recv=*/false)) {
+    return false;
+  }
+  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                      hlo->backend_config<GpuBackendConfig>());
+  const CollectiveBackendConfig& collective_backend_config =
+      gpu_config.collective_backend_config();
+  return !collective_backend_config.is_sync();
+}
+
+// Returns true if this is a collective done operation, excluding P2P
+// operations.
+absl::StatusOr<bool> IsRelevantAsynchronousDone(const HloInstruction* hlo) {
+  return hlo_query::IsAsyncCollectiveDoneOp(hlo,
+                                            /*include_send_recv=*/false);
+}
+
+// For a given computation, finds all the asynchronous collective operations
+// that aren't parallel with custom-calls and sets their no_parallel_custom_call
+// attributes to true. Also records whether the given computation may invoke
+// custom-calls.
+absl::StatusOr<bool> ProcessComputation(
+    const HloSchedule& schedule, HloComputation* computation,
+    CustomCallInComputation& custom_call_in_computation) {
+  bool changed = false;
+  bool has_custom_call = false;
+  absl::flat_hash_set<HloInstruction*> async_starts;
+  const HloInstructionSequence& sequence = schedule.sequence(computation);
+
+  // Visit instructions in the sequence. Collect relevant asynchronous
+  // collective start ops. When we see a relevant asynchronous collective done
+  // op, remove the corresponding start op from the collection and set its
+  // attribute no_parallel_custom_call to true. When we see a custom-call, clear
+  // the start ops from the collection and keep their attribute
+  // no_parallel_custom_call as false.
+  const std::vector<HloInstruction*>& all_instructions =
+      sequence.instructions();
+  for (HloInstruction* hlo : all_instructions) {
+    if (MayInvokeCustomCall(hlo, custom_call_in_computation)) {
+      async_starts.clear();
+      has_custom_call = true;
+      continue;
+    }
+    TF_ASSIGN_OR_RETURN(bool is_async_start, IsRelevantAsynchronousStart(hlo));
+    if (is_async_start) {
+      async_starts.insert(hlo);
+      continue;
+    }
+
+    TF_ASSIGN_OR_RETURN(bool is_async_done, IsRelevantAsynchronousDone(hlo));
+    if (is_async_done) {
+      HloInstruction* async_start = hlo->mutable_operand(0);
+      if (async_starts.contains(async_start)) {
+        changed = true;
+        TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                            async_start->backend_config<GpuBackendConfig>());
+        CollectiveBackendConfig& collective_backend_config =
+            *gpu_config.mutable_collective_backend_config();
+        collective_backend_config.set_no_parallel_custom_call(true);
+        TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config));
+        async_starts.erase(async_start);
+      }
+    }
+  }
+
+  custom_call_in_computation[computation] = has_custom_call;
+  return changed;
+}
+
+}  // anonymous namespace
+
+absl::StatusOr<bool> SchedulePostprocessing::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  if (!module->has_schedule()) return false;
+  HloSchedule& schedule = module->schedule();
+  bool changed = false;
+  CustomCallInComputation custom_call_in_computation;
+
+  // We visit computations in the order of callees to callers, as information
+  // is propagated from callees to callers.
+  std::vector<HloComputation*> all_computations =
+      module->MakeComputationPostOrder(execution_threads);
+  for (auto iter = all_computations.begin(); iter != all_computations.end();
+       ++iter) {
+    HloComputation* computation = *iter;
+    if (computation->IsFusionComputation()) {
+      custom_call_in_computation[computation] = false;
+      continue;
+    }
+
+    TF_ASSIGN_OR_RETURN(
+        bool result,
+        ProcessComputation(schedule, computation, custom_call_in_computation));
+    changed |= result;
+  }
+
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h
new file mode 100644
index 0000000..899098d
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h
@@ -0,0 +1,50 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Amends a schedule result with the information needed to support a runtime
+// implementation. Currently, this pass refines the `no_parallel_custom_call`
+// attribute of asynchronous collective operations to enable runtime
+// optimizations, such as skipping the rendezvous of all participating threads
+// for NCCL collective operations. In particular, it sets the attribute for
+// collective-start operations with is_sync=false; it leaves the attribute
+// untouched for operations with is_sync=true and for P2P operations, assuming
+// the runtime will not use those values.
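+//
+// For illustration, a hypothetical sketch of the effect on a scheduled
+// asynchronous collective with is_sync=false that cannot execute in parallel
+// with any custom call (names and shapes are only examples):
+//
+//   ags = (f32[1], f32[2]) all-gather-start(p), dimensions={0},
+//     backend_config={"collective_backend_config":{"is_sync":false}}
+//   agd = f32[2] all-gather-done(ags)
+//
+// After the pass, the backend config of `ags` additionally carries
+// "no_parallel_custom_call":true.
+//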
+class SchedulePostprocessing : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "schedule-postprocessing"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc
new file mode 100644
index 0000000..0c9c6e6
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc
@@ -0,0 +1,163 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/schedule_postprocessing.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/util.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using SchedulePostprocessingTest = HloTestBase;
+
+TEST_F(SchedulePostprocessingTest, SynchronousOpsNotChanged) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule module, is_scheduled=true
+
+  ENTRY entry {
+    pf32 = f32[1] parameter(0)
+
+    all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}}
+    ROOT all-gather-done = f32[2] all-gather-done(all-gather-start)
+  }
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule(kHloString));
+  SchedulePostprocessing pass;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(SchedulePostprocessingTest, P2POpsNotChanged) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule module, is_scheduled=true
+
+  ENTRY main {
+    f0 = f32[] constant(0.0)
+    init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
+
+    after-all = token[] after-all()
+    recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=2,
+      frontend_attributes={
+      _xla_send_recv_source_target_pairs="{{0,1}, {1,2}}"
+    }
+    recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=2
+    ROOT recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0
+  }
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule(kHloString));
+  SchedulePostprocessing pass;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(SchedulePostprocessingTest, AsynchronousOpsChanged) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule module, is_scheduled=true
+
+  ENTRY entry {
+    pf32 = f32[1] parameter(0)
+    pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call"
+    all-gather-start = (f32[1], f32[2]) all-gather-start(pf32.2), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
+    ROOT all-gather-done = f32[2] all-gather-done(all-gather-start)
+  }
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule(kHloString));
+  SchedulePostprocessing pass;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          start->backend_config<GpuBackendConfig>());
+  const CollectiveBackendConfig& collective_backend_config =
+      gpu_config.collective_backend_config();
+  EXPECT_TRUE(collective_backend_config.no_parallel_custom_call());
+}
+
+TEST_F(SchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule module, is_scheduled=true
+
+  ENTRY entry {
+    pf32 = f32[1] parameter(0)
+    all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
+    pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call"
+    all-gather-done = f32[2] all-gather-done(all-gather-start)
+    ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done)
+  }
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule(kHloString));
+  SchedulePostprocessing pass;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+  EXPECT_FALSE(changed);
+
+  HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          start->backend_config<GpuBackendConfig>());
+  const CollectiveBackendConfig& collective_backend_config =
+      gpu_config.collective_backend_config();
+  EXPECT_FALSE(collective_backend_config.no_parallel_custom_call());
+}
+
+TEST_F(SchedulePostprocessingTest,
+       AsynchronousOpsWithParallelNestedCustomcall) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule module, is_scheduled=true
+  foo {
+    v = f32[1] parameter(0)
+    ROOT ret = f32[1] custom-call(v), custom_call_target="my_custom_call"
+  }
+
+  ENTRY entry {
+    pf32 = f32[1] parameter(0)
+    all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}}
+    pf32.2 = f32[1] call(f32[1] pf32), to_apply=foo
+    all-gather-done = f32[2] all-gather-done(all-gather-start)
+    ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done)
+  }
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnUnverifiedModule(kHloString));
+  SchedulePostprocessing pass;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+  EXPECT_FALSE(changed);
+
+  HloInstruction* start = FindInstruction(module.get(), "all-gather-start");
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          start->backend_config<GpuBackendConfig>());
+  const CollectiveBackendConfig& collective_backend_config =
+      gpu_config.collective_backend_config();
+  EXPECT_FALSE(collective_backend_config.no_parallel_custom_call());
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc
new file mode 100644
index 0000000..d796213
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc
@@ -0,0 +1,74 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h"
+
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+// Populates `OpMetadata`'s `scheduling_name` field for all of the instructions
+// belonging to `computation`.
+absl::StatusOr<bool> AnnotateSchedulingInstructionNames(
+    HloComputation& computation) {
+  bool changed = false;
+  for (HloInstruction* inst : computation.instructions()) {
+    if (!inst->metadata().scheduling_name().empty()) {
+      continue;
+    }
+    // We skip constants because we might have to sanitize them to satisfy the
+    // LLVM backend, i.e. we allow the `GpuSanitizeConstantNames` pass to run
+    // after scheduling.
+    if (inst->opcode() == HloOpcode::kConstant) {
+      continue;
+    }
+    inst->set_metadata_scheduling_name(inst->name());
+    changed = true;
+  }
+  return changed;
+}
+
+}  // namespace
+
+absl::StatusOr<bool> SchedulingInstructionAnnotator::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  CHECK(module->has_schedule())
+      << "The pass is supposed to run at the beginning of post-scheduling!";
+  bool changed = false;
+
+  // Visit computations in post order (callees before callers).
+  for (HloComputation* computation :
+       module->MakeComputationPostOrder(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool result,
+                        AnnotateSchedulingInstructionNames(*computation));
+    changed |= result;
+  }
+
+  return changed;
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.h b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.h
new file mode 100644
index 0000000..03c21bb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.h
@@ -0,0 +1,44 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCHEDULING_INSTRUCTION_ANNOTATOR_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SCHEDULING_INSTRUCTION_ANNOTATOR_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// The pass amends the `OpMetadata` of each instruction with the instruction's
+// name at scheduling time. This is later used to make sure instructions are
+// not renamed after scheduling; enforcing this is necessary because otherwise
+// the recorded names would no longer identify the scheduled instructions.
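+//
+// A hypothetical sketch of the effect on a single scheduled instruction (the
+// name `add0` is only an example):
+//   before: add0 = f32[1] add(p0, p1)
+//   after:  add0 = f32[1] add(p0, p1), metadata={scheduling_name="add0"}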
+class SchedulingInstructionAnnotator : public HloModulePass {
+ public:
+  absl::string_view name() const override {
+    return "scheduling-instruction-annotator";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_SCHEDULING_INSTRUCTION_ANNOTATOR_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc
new file mode 100644
index 0000000..abe8d50
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc
@@ -0,0 +1,131 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+using SchedulingInstructionAnnotatorTest = HloTestBase;
+
+TEST_F(SchedulingInstructionAnnotatorTest,
+       AnnotatesAllInstructionsWithTheirRespectiveNames) {
+  constexpr absl::string_view kHloString = R"(
+    HloModule module, is_scheduled=true
+
+    ENTRY entry {
+      p0 = f32[1] parameter(0)
+      p1 = f32[1] parameter(1)
+      add0 = f32[1] add(p0,p1)
+      ROOT exp0 = f32[1] exponential(add0)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  SchedulingInstructionAnnotator pass;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+
+  ASSERT_TRUE(changed);
+  for (const auto* comp : module->computations()) {
+    for (const auto* instruction : comp->instructions()) {
+      EXPECT_EQ(instruction->name(), instruction->metadata().scheduling_name());
+    }
+  }
+  constexpr absl::string_view kExpected = R"(
+// CHECK:       %[[P0:.+]] = {{.*}} parameter(0)
+// CHECK-SAME:  scheduling_name="[[P0]]"
+// CHECK:       %[[P1:.+]] = {{.*}} parameter(1)
+// CHECK-SAME:  scheduling_name="[[P1]]"
+// CHECK:       %[[ADD0:.+]] = {{.*}} add(%[[P0]], %[[P1]])
+// CHECK-SAME:  scheduling_name="[[ADD0]]"
+// CHECK: ROOT  %[[EXP0:.+]] = {{.*}} exponential(%[[ADD0]])
+// CHECK-SAME:  scheduling_name="[[EXP0]]"
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool filecheck_matches,
+      RunFileCheck(
+          module->ToString(HloPrintOptions().set_print_operand_shape(false)),
+          kExpected));
+  EXPECT_TRUE(filecheck_matches);
+}
+
+TEST_F(SchedulingInstructionAnnotatorTest, SkipsAnnotatingConstants) {
+  constexpr absl::string_view kHloString = R"(
+    HloModule module, is_scheduled=true
+
+    ENTRY entry {
+      p0 = f32[1] parameter(0)
+      c1 = f32[1] constant(42)
+      ROOT add0 = f32[1] add(p0, c1)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  SchedulingInstructionAnnotator pass;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+
+  ASSERT_TRUE(changed);
+  constexpr absl::string_view kExpected = R"(
+// CHECK:       %[[P0:.+]] = {{.*}} parameter(0)
+// CHECK-SAME:  scheduling_name="[[P0]]"
+// CHECK-NEXT:  %[[C1:.+]] = f32[1]
+// CHECK-NOT:   scheduling_name
+// CHECK-SAME:  constant({42})
+// CHECK:       %[[ADD0:.+]] = {{.*}} add(%[[P0]], %[[C1]])
+// CHECK-SAME:  scheduling_name="[[ADD0]]"
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool filecheck_matches,
+      RunFileCheck(
+          module->ToString(HloPrintOptions().set_print_operand_shape(false)),
+          kExpected));
+  EXPECT_TRUE(filecheck_matches);
+}
+
+TEST_F(SchedulingInstructionAnnotatorTest,
+       DoesNotAnnotateAllInstructionsWithTheirRespectiveNames) {
+  constexpr absl::string_view kHloString = R"(
+    HloModule module, is_scheduled=true
+
+    ENTRY entry {
+      p0 = f32[1] parameter(0), metadata={scheduling_name="p0"}
+      p1 = f32[1] parameter(1), metadata={scheduling_name="p1"}
+      add0 = f32[1] add(p0,p1), metadata={scheduling_name="add0"}
+      ROOT exp0 = f32[1] exponential(add0), metadata={scheduling_name="exp0"}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  SchedulingInstructionAnnotator pass;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
+
+  EXPECT_FALSE(changed);
+}
+
+}  // namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc
new file mode 100644
index 0000000..90bc607
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc
@@ -0,0 +1,797 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/softmax_rewriter_triton.h"
+
+#include <cstdint>
+#include <functional>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/layout_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/fusions/triton/triton_support.h"
+#include "xla/service/gpu/hlo_traversal.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/model/fusion_analysis_cache.h"
+#include "xla/service/gpu/model/gpu_indexing_performance_model.h"
+#include "xla/service/gpu/model/symbolic_tile_analysis.h"
+#include "xla/service/gpu/model/tiled_hlo_computation.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using hlo_query::IsBroadcastOfParameter;
+using hlo_query::IsBroadcastOfScalarConstant;
+
+bool HasDefaultLayout(const Shape& shape) {
+  return shape.has_layout() &&
+         LayoutUtil::IsMonotonicWithDim0Major(shape.layout());
+}
+
+// Returns true if a trivially connected producer of 'consumer' with opcode
+// 'opcode' exists. If such an instruction is found, the value of 'producer' is
+// set to it. The definition of "trivial" operations is as given in
+// 'IsTriviallyFusible'.
+bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer,
+                 HloOpcode opcode, const se::GpuComputeCapability& gpu_version);
+
+bool BitcastIsTilingNoop(HloInstruction* bitcast,
+                         const se::GpuComputeCapability& gpu_version) {
+  CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
+
+  if (ShapeUtil::IsEffectiveScalar(bitcast->shape())) {
+    return true;
+  }
+
+  // In the Softmax rewriter for now, tiling is derived from a hero reduction
+  // operation, which should be reducing its input on the last axis. Therefore,
+  // a bitcast is always a no-op with regard to a tile if
+  //   (1) it does not change the size of the reduction dimension of its input
+  //       (the last one); if its input is already reduced, then (1) is true
+  //       by default
+  //   (2) the layout of its output is ordered in the same way as the layout of
+  //       its input. This is a fuzzy definition, but since we assume fusible
+  //       ops to always have a default layout, we can just check if both the
+  //       bitcast and its input have a default layout
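+  //
+  // For example (a sketch with hypothetical shapes, all in default layout):
+  //   f32[2,3,8] bitcast to f32[6,8]  -> tiling no-op: last dimension stays 8
+  //   f32[2,24]  bitcast to f32[6,8]  -> not a no-op: last dimension changes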
+  auto last_dimension = [](const HloInstruction* instr) {
+    return instr->shape().dimensions().back();
+  };
+
+  HloInstruction* reduce = nullptr;
+  TrivialEdge(&reduce, bitcast->mutable_operand(0), HloOpcode::kReduce,
+              gpu_version);
+
+  return (HasDefaultLayout(bitcast->shape()) &&
+          HasDefaultLayout(bitcast->operand(0)->shape()) &&
+          (reduce != nullptr ||
+           last_dimension(bitcast->operand(0)) == last_dimension(bitcast)));
+}
+
+inline bool HasOneUse(const HloInstruction* instr) {
+  return instr->user_count() == 1;
+}
+
+// Supports two types of broadcast of parameters. Either to one batch
+// dim, or one reduction dim. For example the following cases are supported:
+//
+// Case #1:
+// p = f32[a] parameter(0)
+// b = f32[a,x] broadcast(p), dimensions={0}
+//
+// Case #2:
+// p = f32[a] parameter(0)
+// b = f32[x,a] broadcast(p), dimensions={1}
+//
+// Case #3:
+// p = f32[a,b] parameter(0)
+// b = f32[x,a,b] broadcast(p), dimensions={1,2}
+//
+// Other broadcast tiling patterns are currently unsupported.
+// See b/328049138 for details.
+//
+// Unsupported case #1:
+// p = f32[a] parameter(0)
+// b = f32[x,a,y] broadcast(p), dimensions={1}
+//
+// Unsupported case #2:
+// p = f32[a,b] parameter(0)
+// b = f32[x,a,y,b] broadcast(p), dimensions={1,3}
+//
+// Unsupported case #3:
+// p = f32[a] parameter(0)
+// b = f32[x,y,a] broadcast(p), dimensions={2}
+//
+// Unsupported case #4:
+// p = f32[a,b] parameter(0)
+// b = f32[a,x,b] broadcast(p), dimensions={0,2}
+bool IsBatchOrReductionDimBroadcast(const HloInstruction& hlo) {
+  CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
+      << "Expected broadcast " << hlo.ToShortString();
+  CHECK_EQ(hlo.operand(0)->opcode(), HloOpcode::kParameter)
+      << "Expected parameter " << hlo.operand(0)->ToShortString();
+
+  const HloBroadcastInstruction* broadcast =
+      Cast<HloBroadcastInstruction>(&hlo);
+
+  const HloParameterInstruction* parameter =
+      Cast<HloParameterInstruction>(hlo.operand(0));
+
+  // Support only one dim broadcast.
+  if (parameter->shape().dimensions_size() + 1 !=
+      broadcast->shape().dimensions_size()) {
+    return false;
+  }
+
+  // It is enough to ensure that the broadcast does not preserve both the first
+  // and the last dimensions of the parameter at the same time. Otherwise the
+  // broadcast is the unsupported case #4.
+  //
+  // Preserve the first dim:
+  //   p = f32[a,b] parameter(0)
+  //   b1 = f32[a,b,c] broadcast(p), dimensions={0,1}
+  bool preserve_first_dim = broadcast->dimensions().front() == 0;
+  // Preserve the last dim:
+  //   p = f32[a,b] parameter(0)
+  //   b1 = f32[c,a,b] broadcast(p), dimensions={1,2}
+  bool preserve_last_dim = broadcast->dimensions().back() ==
+                           broadcast->shape().dimensions_size() - 1;
+  // We do not want to preserve both first and last dim, as it means the
+  // broadcast is not expanding on outermost dims.
+  return !(preserve_first_dim && preserve_last_dim);
+}
+
+bool IsBroadcastOfAScalar(const HloInstruction& hlo) {
+  CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
+      << "Expected broadcast " << hlo.ToShortString();
+  return ShapeUtil::IsScalar(hlo.operand(0)->shape());
+}
+
+bool IsSingleRowParameterBroadcast(const HloInstruction& hlo) {
+  CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
+      << "Expected broadcast " << hlo.ToShortString();
+  CHECK_EQ(hlo.operand(0)->opcode(), HloOpcode::kParameter)
+      << "Expected parameter " << hlo.operand(0)->ToShortString();
+
+  const HloBroadcastInstruction* broadcast =
+      Cast<HloBroadcastInstruction>(&hlo);
+  const HloParameterInstruction* parameter =
+      Cast<HloParameterInstruction>(hlo.operand(0));
+
+  if (parameter->shape().dimensions_size() != 1) {
+    return false;
+  }
+  return broadcast->dimensions()[0] == broadcast->shape().dimensions_size() - 1;
+}
+
+bool IsSupportedBroadcastOfParameter(const HloInstruction& hlo) {
+  return IsBroadcastOfParameter(hlo) &&
+         (IsBatchOrReductionDimBroadcast(hlo) || IsBroadcastOfAScalar(hlo) ||
+          IsSingleRowParameterBroadcast(hlo));
+}
+
+// Chooses which operand to use for fusion processing. Given a unary or binary
+// instruction, returns the first non-splat operand; if none is present,
+// returns any operand.
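+//
+// For example (a hypothetical instruction), given
+//   add = f32[a,x] add(broadcast(splat), other)
+// this returns `other`, so the fusion traversal continues along the non-splat
+// operand.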
+HloInstruction* ChooseOperandForFusionProcessing(HloInstruction* instr) {
+  CHECK_GT(instr->operand_count(), 0);
+  CHECK_LE(instr->operand_count(), 2);
+
+  // TODO(b/326217416): Extend the broadcast of splat constants/parameters to a
+  // broadcast of any op.
+  if (instr->operand_count() > 1 &&
+      (IsBroadcastOfScalarConstant(*instr->operand(0)) ||
+       IsSupportedBroadcastOfParameter(*instr->operand(0)))) {
+    return instr->mutable_operand(1);
+  }
+  return instr->mutable_operand(0);
+}
+
+bool IsTriviallyFusible(HloInstruction* instr,
+                        const se::GpuComputeCapability& gpu_version,
+                        int num_allowed_users = 1) {
+  // Checks whether an op is trivially fusible. An op is said to be trivially
+  // fusible if it does not increase the amount of memory read/written by the
+  // resulting fusion, is compatible with any chosen tiling, and can be
+  // codegen'd using Triton. The op is allowed to have up to num_allowed_users
+  // users.
+  if (instr->user_count() > num_allowed_users ||
+      !HasDefaultLayout(instr->shape())) {
+    return false;
+  }
+
+  if (instr->opcode() == HloOpcode::kBitcast &&
+      BitcastIsTilingNoop(instr, gpu_version)) {
+    return true;
+  }
+
+  if (instr->IsElementwise() && instr->operand_count() == 1) {
+    return static_cast<bool>(IsTritonSupportedInstruction(*instr, gpu_version));
+  }
+
+  // Elementwise binary ops are trivially fusible if the operands are the same,
+  // or if exactly one of the operands is a splat constant.
+  if (instr->IsElementwiseBinary()) {
+    const HloInstruction* operand_0 = instr->operand(0);
+    const HloInstruction* operand_1 = instr->operand(1);
+
+    // Elementwise binary ops should be fused if both operands are the same and
+    // the instruction is Triton-supported.
+    if (operand_0 == operand_1) {
+      return static_cast<bool>(
+          IsTritonSupportedInstruction(*instr, gpu_version));
+    }
+
+    // For simplicity we only fuse elementwise binary ops with splat operands
+    // if they contain one non-splat operand.
+    // TODO(b/326217416): Extend the broadcast of splat constants/parameters to
+    // a broadcast of any op.
+    if ((IsBroadcastOfScalarConstant(*operand_0) ||
+         IsSupportedBroadcastOfParameter(*operand_0)) ^
+        (IsBroadcastOfScalarConstant(*operand_1) ||
+         IsSupportedBroadcastOfParameter(*operand_1))) {
+      return static_cast<bool>(
+          IsTritonSupportedInstruction(*instr, gpu_version));
+    }
+  }
+
+  return false;
+}
+
+bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer,
+                 HloOpcode opcode,
+                 const se::GpuComputeCapability& gpu_version) {
+  while (consumer->opcode() != opcode) {
+    if (IsTriviallyFusible(consumer, gpu_version)) {
+      consumer = ChooseOperandForFusionProcessing(consumer);
+    } else {
+      return false;
+    }
+  }
+
+  *producer = consumer;
+  return true;
+}
+
+bool IsTriviallyConnectedProducerOf(
+    HloInstruction* producer, HloInstruction* consumer,
+    const se::GpuComputeCapability& gpu_version) {
+  if (producer == consumer) {
+    return true;
+  }
+
+  HloInstruction* found_producer = consumer;
+  while (
+      TrivialEdge(&found_producer, consumer, producer->opcode(), gpu_version)) {
+    if (found_producer == producer) {
+      return true;
+    }
+
+    if (!IsTriviallyFusible(found_producer, gpu_version)) {
+      return false;
+    }
+
+    consumer = found_producer->mutable_operand(0);
+  }
+
+  return false;
+}
+
+// Finds the first non-fusible producer of a diamond. This instruction is either
+//   1. the direct producer of the diamond, if that producer is used more than
+//      twice and/or is not otherwise trivially fusible, or
+//   2. the first ancestor of the diamond's producer that is used more than
+//      once and/or is not trivially fusible.
+HloInstruction* FindFirstNonFusibleDiamondProducer(
+    HloInstruction* diamond_producer,
+    const se::GpuComputeCapability& gpu_version) {
+  if (IsTriviallyFusible(diamond_producer, gpu_version,
+                         /*num_allowed_users=*/2)) {
+    diamond_producer = ChooseOperandForFusionProcessing(diamond_producer);
+    while (IsTriviallyFusible(diamond_producer, gpu_version)) {
+      diamond_producer = ChooseOperandForFusionProcessing(diamond_producer);
+    }
+  }
+
+  return diamond_producer;
+}
+
+// Creates a fusion corresponding to the input diamond chain. The resulting
+// fusion instruction is added to the module, but is not yet inserted into the
+// graph as a replacement of the original instructions.
+//
+// TODO(b/347956491): this awkward abstraction is needed to work around
+// limitations of HloFusionAdaptor, which underpins the implementation of
+// SymbolicTileAnalysis. We need to come up with a better solution.
+absl::StatusOr<HloFusionInstruction*> MakeFusionForDiamondChain(
+    const DiamondChainDescriptor& diamond_chain) {
+  auto [root, producer] = diamond_chain;
+
+  std::string suggested_name = "triton_softmax";
+  HloComputation::Builder builder(absl::StrCat(suggested_name, "_computation"));
+  // Original instruction -> fused one.
+  absl::flat_hash_map<const HloInstruction*, HloInstruction*>
+      old_to_new_mapping;
+
+  int param = 0;
+  old_to_new_mapping[producer] =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          param, producer->shape(), absl::StrCat("parameter_", param)));
+  param++;
+
+  std::vector<HloInstruction*> parameters = {producer};
+
+  std::function<void(HloInstruction*)> create_computation =
+      [&](HloInstruction* instr) -> void {
+    if (old_to_new_mapping.contains(instr)) {
+      return;
+    }
+    std::vector<HloInstruction*> new_operands;
+    for (HloInstruction* operand : instr->mutable_operands()) {
+      create_computation(operand);
+      new_operands.push_back(old_to_new_mapping[operand]);
+    }
+    if (instr->opcode() == HloOpcode::kParameter) {
+      old_to_new_mapping[instr] =
+          builder.AddInstruction(HloInstruction::CreateParameter(
+              param, instr->shape(), absl::StrCat("parameter_", param)));
+      parameters.push_back(instr);
+      param++;
+    } else {
+      old_to_new_mapping[instr] = builder.AddInstruction(
+          instr->CloneWithNewOperands(instr->shape(), new_operands));
+    }
+  };
+  create_computation(root);
+
+  HloComputation* computation =
+      root->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(),
+                                                           /*is_entry=*/false);
+
+  HloInstruction* softmax_fusion =
+      root->parent()->AddInstruction(HloInstruction::CreateFusion(
+          root->shape(), HloInstruction::FusionKind::kCustom, parameters,
+          computation));
+
+  softmax_fusion->GetModule()->SetAndUniquifyInstrName(softmax_fusion,
+                                                       "triton_softmax");
+  TF_ASSIGN_OR_RETURN(auto gpu_config,
+                      softmax_fusion->backend_config<GpuBackendConfig>());
+  FusionBackendConfig& backend_config =
+      *gpu_config.mutable_fusion_backend_config();
+  backend_config.set_kind(std::string(kTritonFusionKind));
+  TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(gpu_config));
+  return xla::Cast<HloFusionInstruction>(softmax_fusion);
+}
+
+absl::Status FuseDiamondChainImpl(
+    const DiamondChainDescriptor& diamond_chain,
+    GpuPerformanceModelWithIndexingAnalysis& indexing_performance_model) {
+  TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion,
+                      MakeFusionForDiamondChain(diamond_chain));
+  HloInstruction* root = diamond_chain.root;
+
+  auto fusion_adaptor = HloFusionAdaptor::ForInstruction(softmax_fusion);
+
+  TF_ASSIGN_OR_RETURN(
+      TiledRunTimeDataOrError tiled_runtime_data_or,
+      indexing_performance_model.TryFindBestTilingForFusion(*fusion_adaptor));
+
+  if (const auto* fusion_decision =
+          std::get_if<FusionDecision>(&tiled_runtime_data_or)) {
+    return absl::FailedPreconditionError(absl::StrCat(
+        "SymbolicTileAnalysis failed. ", fusion_decision->Explain()));
+  }
+
+  TiledRunTimeData tiled_runtime_data =
+      std::get<TiledRunTimeData>(std::move(tiled_runtime_data_or));
+
+  TF_ASSIGN_OR_RETURN(auto backend_config,
+                      softmax_fusion->backend_config<GpuBackendConfig>());
+  *backend_config.mutable_fusion_backend_config()
+       ->mutable_block_level_fusion_config() =
+      tiled_runtime_data.block_level_parameters.ToBlockLevelFusionConfig();
+  TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(backend_config));
+
+  if (root->IsRoot()) {
+    root->parent()->set_root_instruction(softmax_fusion);
+    TF_RETURN_IF_ERROR(
+        root->parent()->RemoveInstructionAndUnusedOperands(root));
+  } else {
+    TF_RETURN_IF_ERROR(
+        root->parent()->ReplaceInstruction(root, softmax_fusion));
+  }
+
+  VLOG(5) << softmax_fusion->ToString();
+  return absl::OkStatus();
+}
+
+// Returns `true` if the diamond chain passed as a parameter can be tiled
+// correctly using `SymbolicTileAnalysis`.
+absl::StatusOr<bool> CanSymbolicTileAnalysisTileDiamondChain(
+    const DiamondChainDescriptor& diamond_chain) {
+  TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion,
+                      MakeFusionForDiamondChain(diamond_chain));
+  mlir::MLIRContext context;
+  SymbolicTileAnalysisOrError symbolic_tile_analysis_or_error =
+      SymbolicTileAnalysis::AnalyzeComputation(
+          *softmax_fusion->called_computation(), &context);
+
+  bool can_tile = std::holds_alternative<SymbolicTileAnalysis>(
+      symbolic_tile_analysis_or_error);
+
+  TF_RETURN_IF_ERROR(diamond_chain.root->GetModule()->RemoveEmbeddedComputation(
+      softmax_fusion->called_computation()));
+  TF_RETURN_IF_ERROR(
+      diamond_chain.root->parent()->RemoveInstruction(softmax_fusion));
+
+  return can_tile;
+}
+
+FusionDecision ShouldFuseReduction(const HloInstruction& reduce,
+                                   const se::GpuComputeCapability& cc) {
+  if (CodegenDecision is_supported = IsTritonSupportedInstruction(reduce, cc);
+      !is_supported) {
+    return FusionDecision(is_supported.Explain());
+  }
+
+  // Ensure that the reduction's identity is either a constant or a supported
+  // convert of a constant.
+  const HloInstruction* identity = reduce.operand(1);
+  bool should_fuse_identity =
+      identity->opcode() == HloOpcode::kConstant ||
+      (identity->opcode() == HloOpcode::kConvert &&
+       identity->operand(0)->opcode() == HloOpcode::kConstant &&
+       IsTritonSupportedInstruction(*identity, cc));
+  if (!should_fuse_identity) {
+    return "Reduction identity is not a constant or a supported convert of a "
+           "constant.";
+  }
+
+  return {};
+}
+
+DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl(
+    HloInstruction* instr, const se::GpuComputeCapability& cc) {
+  if (!instr->IsElementwiseBinary()) {
+    return "Root is not elementwise binary.";
+  }
+
+  if (!IsTritonSupportedInstruction(*instr, cc)) {
+    return "Root is not supported for Triton instruction.";
+  }
+
+  HloInstruction* producer;
+  HloInstruction* broadcast;
+  HloInstruction* reduce;
+
+  if (!TrivialEdge(&broadcast, instr->mutable_operand(1), HloOpcode::kBroadcast,
+                   cc)) {
+    return "Could not find a trivial connection from root to a broadcast.";
+  }
+
+  if (!TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce,
+                   cc)) {
+    return "Could not find a trivial connection from matched broadcast to a "
+           "reduction.";
+  }
+
+  if (!(HasDefaultLayout(broadcast->shape()) &&
+        HasDefaultLayout(reduce->shape()))) {
+    return "Broadcast or reduce have non-default layouts.";
+  }
+
+  if (FusionDecision should_fuse_reduction = ShouldFuseReduction(*reduce, cc);
+      !should_fuse_reduction) {
+    VLOG(2) << should_fuse_reduction.Explain();
+    return should_fuse_reduction;
+  }
+
+  // Ensure that the reduction's identity is either a constant or a supported
+  // convert of a constant.
+  const HloInstruction* identity = reduce->operand(1);
+  bool should_fuse_identity =
+      identity->opcode() == HloOpcode::kConstant ||
+      (identity->opcode() == HloOpcode::kConvert &&
+       identity->operand(0)->opcode() == HloOpcode::kConstant &&
+       IsTritonSupportedInstruction(*identity, cc));
+  if (!should_fuse_identity) {
+    return "Reduction identity is not a constant or a supported convert of a "
+           "constant.";
+  }
+
+  if (!HasOneUse(broadcast) || !HasOneUse(reduce)) {
+    return "More than one use of broadcast or reduce.";
+  }
+
+  producer = reduce->mutable_operand(0);
+
+  if (absl::c_linear_search(broadcast->dimensions(),
+                            broadcast->shape().rank() - 1)) {
+    return "Broadcast is not along the reduction dimension.";
+  }
+
+  while (IsTriviallyFusible(producer, cc)) {
+    producer = ChooseOperandForFusionProcessing(producer);
+  }
+
+  if (!HasDefaultLayout(producer->shape())) {
+    return "Producer has non-default layout.";
+  }
+
+  if (!IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0),
+                                      cc)) {
+    return "Producer is not trivially connected.";
+  }
+
+  if (producer != instr->operand(0) && instr->operand(0)->user_count() != 1) {
+    return "Unsupported root-producer connection.";
+  }
+
+  VLOG(5) << "Matched Softmax diamond with: ";
+  VLOG(5) << "root: " << instr->ToString();
+  VLOG(5) << "producer: " << producer->ToString();
+  VLOG(5) << "broadcast: " << broadcast->ToString();
+  VLOG(5) << "reduce: " << reduce->ToString();
+
+  return producer;
+}
+
+// Returns a vector containing all the single diamonds in the parameter module.
+// The diamonds are returned in def-before-use order, and grouped by
+// computation.
+absl::StatusOr<std::vector<DiamondChainDescriptor>> FindAllFusibleDiamonds(
+    HloModule& module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads,
+    const se::GpuComputeCapability& cc) {
+  std::vector<DiamondChainDescriptor> matched_diamonds;
+
+  for (HloComputation* comp :
+       module.MakeNonfusionComputations(execution_threads)) {
+    if (comp->IsCustomCallComputation()) {
+      continue;
+    }
+    for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+      auto producer =
+          MatchesTritonCompatibleClosedReductionDiamondImpl(instr, cc);
+      if (std::holds_alternative<HloInstruction*>(producer)) {
+        DiamondChainDescriptor diamond_chain{
+            /*root=*/instr, /*producer=*/std::get<HloInstruction*>(producer)};
+        // We filter out the diamond chains that cannot be tiled correctly using
+        // `SymbolicTileAnalysis`.
+        TF_ASSIGN_OR_RETURN(
+            bool can_tile_diamond_chain,
+            CanSymbolicTileAnalysisTileDiamondChain(diamond_chain));
+        if (can_tile_diamond_chain) {
+          matched_diamonds.push_back(diamond_chain);
+        } else {
+          VLOG(5) << "Cannot tile the diamond pattern described by "
+                  << "instructions " << instr->ToString() << " and "
+                  << std::get<HloInstruction*>(producer)->ToString() << ".";
+          continue;
+        }
+
+      } else {
+        VLOG(5) << "Cannot match the diamond pattern for instruction "
+                << instr->ToString()
+                << ". Reason: " << std::get<FusionDecision>(producer).Explain();
+      }
+    }
+  }
+
+  return std::move(matched_diamonds);
+}
+
+// Returns the size of the reduction dimension of the input diamond.
+int64_t GetReductionDimensionSizeForDiamond(
+    const DiamondChainDescriptor& diamond_chain) {
+  HloInstruction* diamond_root = diamond_chain.root;
+  HloInstruction* instr = diamond_root->mutable_operand(1);
+  while (instr->opcode() != HloOpcode::kReduce) {
+    instr = ChooseOperandForFusionProcessing(instr);
+  }
+
+  int operand_rank = instr->operand(0)->shape().rank();
+  CHECK_EQ(instr->dimensions().size(), 1);
+  CHECK_EQ(instr->dimensions(0), operand_rank - 1);
+  return instr->operand(0)->shape().dimensions(operand_rank - 1);
+}
+
+// Returns a pointer to the last user of `instr` that is trivially fusible.
+HloInstruction* GetLastTriviallyFusibleUser(
+    HloInstruction* instr, const se::GpuComputeCapability& cc) {
+  while (HasOneUse(instr) && !instr->IsRoot() &&
+         IsTriviallyFusible(instr->users().front(), cc)) {
+    instr = instr->users().front();
+  }
+
+  // We do not care about the number of users for the last instruction of the
+  // fusion, so attempt to fuse one more instruction with this relaxed
+  // restriction.
+  if (HasOneUse(instr) && !instr->IsRoot() &&
+      IsTriviallyFusible(
+          instr->users().front(), cc,
+          /*num_allowed_users=*/instr->users().front()->user_count())) {
+    instr = instr->users().front();
+  }
+  return instr;
+}
+
+}  // anonymous namespace
+
+DiamondMatchingDecision
+SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond(
+    HloInstruction* instr) const {
+  return MatchesTritonCompatibleClosedReductionDiamondImpl(
+      instr, device_info_.gpu_compute_capability());
+}
+
+absl::StatusOr<std::vector<DiamondChainDescriptor>>
+SoftmaxRewriterTriton::FindAllFusibleDiamondChains(
+    HloModule& module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) const {
+  const se::GpuComputeCapability& cc = device_info_.gpu_compute_capability();
+  TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> matched_diamonds,
+                      FindAllFusibleDiamonds(module, execution_threads, cc));
+
+  if (matched_diamonds.empty()) {
+    return std::vector<DiamondChainDescriptor>();
+  }
+
+  // If we matched several diamonds, it may be possible for some of them to be
+  // fused together. This is the case if the following conditions hold:
+  //   1. The path between the root of diamond n towards the producer of
+  //      diamond n+1 is composed only of trivially fusible operations. In that
+  //      case, the first non-trivially fusible producer of diamond n+1 must be
+  //      exactly the root of diamond n.
+  //   2. The root of diamond n/first non-fusible producer of diamond n+1 must
+  //      have
+  //        a. exactly one user if it is not exactly the producer of diamond
+  //           n+1;
+  //        b. exactly two users otherwise.
+  //   3. The axis being reduced must have the same length in all the diamonds
+  //      being fused together.
+  //
+  // Crucially, this approach relies on a diamond root never being considered a
+  // trivially fusible operation.
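+  //
+  // As a hypothetical illustration, consider two softmax-like diamonds that
+  // reduce over the same last axis:
+  //   d1 = divide(exp, broadcast(reduce(exp)))   // diamond 1, root d1
+  //   d2 = divide(d1, broadcast(reduce(d1)))     // diamond 2, producer d1
+  // The root of diamond 1 is exactly the producer of diamond 2 and has two
+  // users (the reduce and the divide of diamond 2), so conditions 1-3 hold and
+  // both diamonds end up in the same chain.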
+  std::vector<DiamondChainDescriptor> diamond_chains;
+  diamond_chains.reserve(matched_diamonds.size());
+
+  HloInstruction* current_fusion_producer =
+      FindFirstNonFusibleDiamondProducer(matched_diamonds.front().producer, cc);
+  int current_reduce_dimension_size =
+      GetReductionDimensionSizeForDiamond(matched_diamonds.front());
+
+  for (int diamond_idx = 1; diamond_idx < matched_diamonds.size();
+       ++diamond_idx) {
+    HloInstruction* diamond_producer = matched_diamonds[diamond_idx].producer;
+    HloInstruction* previous_diamond_root =
+        matched_diamonds[diamond_idx - 1].root;
+
+    HloInstruction* first_non_fusible_diamond_producer =
+        FindFirstNonFusibleDiamondProducer(diamond_producer, cc);
+
+    int diamond_reduce_dimension_size =
+        GetReductionDimensionSizeForDiamond(matched_diamonds[diamond_idx]);
+
+    if (first_non_fusible_diamond_producer == previous_diamond_root &&  // 1
+        ((first_non_fusible_diamond_producer != diamond_producer &&
+          HasOneUse(first_non_fusible_diamond_producer)) ||  // 2.a
+         (first_non_fusible_diamond_producer == diamond_producer &&
+          first_non_fusible_diamond_producer->user_count() == 2)) &&  // 2.b
+        diamond_reduce_dimension_size == current_reduce_dimension_size) {  // 3
+      continue;
+    }
+
+    // The "last trivially fusible user" chain of diamond chain n should never
+    // intersect with the "first non fusible diamond producer" chain of diamond
+    // chain n+1: if these chains intersected, then all the intermediate ops
+    // between the diamond chains could be trivially fused, and both diamond
+    // chains could be fused into a single diamond chain. Note that this only
+    // holds insofar as we do not allow fusing in bitcasts that modify the last
+    // dimension of the input array. It is however possible for the last
+    // trivially fusible user of diamond chain n to be the first non fusible
+    // diamond producer of diamond chain n+1.
+    diamond_chains.push_back(DiamondChainDescriptor{
+        GetLastTriviallyFusibleUser(previous_diamond_root, cc),
+        current_fusion_producer,
+    });
+
+    current_fusion_producer = first_non_fusible_diamond_producer;
+    current_reduce_dimension_size = diamond_reduce_dimension_size;
+  }
+
+  // The last diamond chain is still open; close it.
+  diamond_chains.push_back(DiamondChainDescriptor{
+      GetLastTriviallyFusibleUser(matched_diamonds.back().root, cc),
+      current_fusion_producer});
+
+  // We filter out the diamond chains that cannot be tiled correctly using
+  // `SymbolicTileAnalysis`.
+  std::vector<DiamondChainDescriptor> filtered_diamond_chains;
+  for (const DiamondChainDescriptor& diamond_chain : diamond_chains) {
+    TF_ASSIGN_OR_RETURN(bool can_tile_diamond_chain,
+                        CanSymbolicTileAnalysisTileDiamondChain(diamond_chain));
+    if (can_tile_diamond_chain) {
+      filtered_diamond_chains.push_back(diamond_chain);
+    }
+  }
+  return filtered_diamond_chains;
+}
+
+absl::Status SoftmaxRewriterTriton::FuseDiamondChain(
+    const DiamondChainDescriptor& diamond_chain) {
+  HloFusionAnalysisCache fusion_analysis_cache(device_info_);
+  GpuPerformanceModelWithIndexingAnalysis indexing_performance_model(
+      &device_info_, &fusion_analysis_cache, shape_size_, &mlir_context_);
+
+  return FuseDiamondChainImpl(diamond_chain, indexing_performance_model);
+}
+
+absl::StatusOr<bool> SoftmaxRewriterTriton::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  TF_RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability(
+      device_info_.gpu_compute_capability()));
+
+  TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> diamond_chains,
+                      FindAllFusibleDiamondChains(*module, execution_threads));
+
+  if (diamond_chains.empty()) {
+    return false;
+  }
+
+  // The diamond chains must be emitted in reverse order, to make sure that
+  // producer instructions are emitted correctly when the root of
+  // diamond chain n is exactly the producer of diamond chain n+1.
+  for (auto diamond_chain = diamond_chains.rbegin();
+       diamond_chain != diamond_chains.rend(); ++diamond_chain) {
+    TF_RET_CHECK(FuseDiamondChain(*diamond_chain).ok());
+  }
+  return true;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h
new file mode 100644
index 0000000..36f780f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h
@@ -0,0 +1,101 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SOFTMAX_REWRITER_TRITON_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SOFTMAX_REWRITER_TRITON_H_
+
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/MLIRContext.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_cost_analysis.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+struct DiamondChainDescriptor {
+  HloInstruction* root = nullptr;
+  HloInstruction* producer = nullptr;
+};
+
+using DiamondMatchingDecision = std::variant<FusionDecision, HloInstruction*>;
+
+// Rewrites compatible Softmax patterns into custom fusion regions to be
+// code-generated with the Triton-based Softmax emitter.
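+//
+// A minimal usage sketch (`device_info` and `shape_size` stand for a
+// se::DeviceDescription and a HloCostAnalysis::ShapeSizeFunction supplied by
+// the caller):
+//
+//   SoftmaxRewriterTriton rewriter(device_info, shape_size);
+//   TF_ASSIGN_OR_RETURN(bool changed,
+//                       rewriter.Run(module, /*execution_threads=*/{}));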
+class SoftmaxRewriterTriton : public HloModulePass {
+ public:
+  explicit SoftmaxRewriterTriton(const se::DeviceDescription& device_info,
+                                 HloCostAnalysis::ShapeSizeFunction shape_size)
+      : device_info_(device_info), shape_size_(shape_size) {}
+
+  absl::string_view name() const override { return "triton-softmax-rewriter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+  // Finds and returns all the fusible diamond chains in the module. The
+  // resulting vector is sorted according to a post-order matching (i.e. within
+  // the same computation, producer diamonds appear before consumer diamonds).
+  absl::StatusOr<std::vector<DiamondChainDescriptor>>
+  FindAllFusibleDiamondChains(
+      HloModule& module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) const;
+
+  // Constructs a Softmax fusion containing all the instructions between the
+  // root and the producer of a diamond chain. The producer is excluded from the
+  // fusion.
+  absl::Status FuseDiamondChain(const DiamondChainDescriptor& diamond_chain);
+
+  // Return the producer of the following pattern:
+  //
+  // producer
+  // |    \
+  // |  reduce_{max,sum,...}
+  // |     |
+  // |  broadcast
+  // |   /
+  // binop (elementwise)
+  //
+  // where each edge is also allowed to contain trivial operations that can be
+  // generated by Triton. By "trivial" we mean operations that do not increase
+  // the amount of memory read/written by the fusion and that are compatible
+  // with any chosen tiling.
+  //
+  // We also assume that the reduction is done on the last axis of the producer
+  // array.
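+  //
+  // In HLO, a minimal matched diamond looks roughly like this (a sketch; the
+  // shapes and the reduction computation are only examples):
+  //   reduce    = f32[127] reduce(producer, init), dimensions={1},
+  //               to_apply=max_computation
+  //   broadcast = f32[127,125] broadcast(reduce), dimensions={0}
+  //   binop     = f32[127,125] subtract(producer, broadcast)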
+  DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamond(
+      HloInstruction* instr) const;
+
+ private:
+  const se::DeviceDescription& device_info_;
+  const HloCostAnalysis::ShapeSizeFunction shape_size_;
+  mlir::MLIRContext mlir_context_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_SOFTMAX_REWRITER_TRITON_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc
new file mode 100644
index 0000000..1b3139c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc
@@ -0,0 +1,1590 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/softmax_rewriter_triton.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <variant>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/fusions/triton/triton_support.h"
+#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/instruction_fusion.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+using ::testing::HasSubstr;
+
+GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() {
+  return [&](const Shape& shape) {
+    constexpr int64_t kPointerSize = 8;
+    return ShapeUtil::ByteSizeOf(shape, kPointerSize);
+  };
+}
+
+bool HasBlockLevelFusionConfig(const HloInstruction* fusion) {
+  return fusion->opcode() == HloOpcode::kFusion &&
+         fusion->has_backend_config() &&
+         fusion->backend_config<GpuBackendConfig>().ok() &&
+         fusion->backend_config<GpuBackendConfig>()
+             ->fusion_backend_config()
+             .has_block_level_fusion_config();
+}
+
+// Wrapper around SoftmaxRewriterTriton(device_info, shape_size).Run(module)
+// that finds and fuses as many diamond chains as possible without invoking any
+// kind of cost analysis.
+absl::StatusOr<bool> SoftmaxRewriterTritonMatchAndRewrite(
+    const se::DeviceDescription& device_info, HloModule* module) {
+  CHECK_NE(module, nullptr);
+  SoftmaxRewriterTriton softmax_rewriter_triton(device_info,
+                                                ShapeSizeBytesFunction());
+  TF_ASSIGN_OR_RETURN(std::vector<DiamondChainDescriptor> diamond_chains,
+                      softmax_rewriter_triton.FindAllFusibleDiamondChains(
+                          *module, /*execution_threads=*/{}));
+
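+  // Fuse the diamond chains back to front, i.e. in reverse discovery order.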
+  for (auto diamond_chain = diamond_chains.rbegin();
+       diamond_chain != diamond_chains.rend(); ++diamond_chain) {
+    TF_RETURN_IF_ERROR(
+        softmax_rewriter_triton.FuseDiamondChain(*diamond_chain));
+  }
+
+  return !diamond_chains.empty();
+}
+
+class SoftmaxRewriterTritonTest
+    : public HloTestBase,
+      public ::testing::WithParamInterface<PrimitiveType> {
+ protected:
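+  // All tests use an RTX A6000 (Ampere) device description, for which the
+  // Triton softmax rewriter is enabled.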
+  se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()};
+};
+
+TEST_F(SoftmaxRewriterTritonTest, CanFuseExactSoftmaxF32) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+  arg_0.1 = f32[] parameter(0)
+  arg_1.1 = f32[] parameter(1)
+  ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+  exponential = f32[127,125]{1,0} exponential(subtract)
+  constant_zero = f32[] constant(0)
+  second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
+  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+  ROOT divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  VLOG(2) << module->ToString();
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseSoftmaxLikeComputationWithNonF32DataType) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f16[] parameter(0)
+  arg_1 = f16[] parameter(1)
+  ROOT maximum = f16[] maximum(arg_0, arg_1)
+}
+add_computation {
+  arg_0.1 = f16[] parameter(0)
+  arg_1.1 = f16[] parameter(1)
+  ROOT add = f16[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f16[127,125]{1,0} parameter(0)
+  constant_neg_inf = f16[] constant(-inf)
+  reduce = f16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f16[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f16[127,125]{1,0} subtract(param_0, broadcast)
+  // Replace Softmax exponential with abs, because Triton doesn't support
+  // non-f32 exponentials.
+  abs = f16[127,125]{1,0} abs(subtract)
+  constant_zero = f16[] constant(0)
+  second_reduce = f16[127]{0} reduce(abs, constant_zero), dimensions={1}, to_apply=add_computation
+  second_broadcast = f16[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+  // Replace divide with multiply, because Triton doesn't support f16
+  // divisions.
+  ROOT multiply = f16[127,125]{1,0} multiply(abs, second_broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest, CanFuseSingleNormalizationDiamond) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       DoesNotFuseDiamondInvolvingUnsupportedTritonInstruction) {
+  const std::string hlo_string = R"(
+HloModule softmax
+add_computation {
+  arg_0.1 = bf16[] parameter(0)
+  arg_1.1 = bf16[] parameter(1)
+  ROOT add = bf16[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = bf16[127,125]{1,0} parameter(0)
+  constant_zero = bf16[] constant(0)
+  reduce = bf16[127]{0} reduce(param_0, constant_zero), dimensions={1}, to_apply=add_computation
+  broadcast = bf16[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT divide = bf16[127,125]{1,0} divide(param_0, broadcast)
+})";
+
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  const HloInstruction* bf16_divide =
+      module->entry_computation()->root_instruction();
+  EXPECT_FALSE(IsTritonSupportedInstruction(
+      *bf16_divide, device_info_.gpu_compute_capability()));
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       DoesNotFuseInstructionsUnsupportedByTritonIntoDiamonds) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = bf16[] parameter(0)
+  arg_1 = bf16[] parameter(1)
+  ROOT maximum = bf16[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = bf16[127,125]{1,0} parameter(0)
+  constant_neg_inf = bf16[] constant(-inf)
+  reduce = bf16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = bf16[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = bf16[127,125]{1,0} subtract(param_0, broadcast)
+  ROOT exponential = bf16[127,125]{1,0} exponential(subtract)
+})";
+
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  const HloInstruction* bf16_exponential =
+      hlo_query::GetFirstInstructionWithOpcode(*module->entry_computation(),
+                                               HloOpcode::kExp);
+  EXPECT_FALSE(IsTritonSupportedInstruction(
+      *bf16_exponential, device_info_.gpu_compute_capability()));
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Exp(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))));
+}
+
+TEST_F(SoftmaxRewriterTritonTest, CanNotFuseSoftmaxDiamondWithWrongLayout) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{0,1} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanNotFuseSoftmaxDiamondWithWrongReduceDimension) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[125]{0} reduce(param_0, constant_neg_inf), dimensions={0}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={1}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanNotFuseSoftmaxDiamondWithWrongBroadcastDimension) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[125,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[125]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[125,125]{1,0} broadcast(reduce), dimensions={1}
+  ROOT subtract = f32[125,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanNotFuseSoftmaxDiamondWithExtraBroadcastUsage) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+  ROOT multiply = f32[127,125]{1,0} multiply(broadcast, subtract)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseSoftmaxWithIntermediateUnaryElementwise) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+  arg_0.1 = f32[] parameter(0)
+  arg_1.1 = f32[] parameter(1)
+  ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+  abs = f32[127,125]{1,0} abs(subtract)
+  exponential = f32[127,125]{1,0} exponential(abs)
+  constant_zero = f32[] constant(0)
+  second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
+  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+  ROOT divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseTwoDiamondsWithSecondDiamondProducerEqualToFirstDiamondRoot) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+  arg_0.1 = f32[] parameter(0)
+  arg_1.1 = f32[] parameter(1)
+  ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+  constant_zero = f32[] constant(0)
+  second_reduce = f32[127]{0} reduce(subtract, constant_zero), dimensions={1}, to_apply=add_computation
+  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+  ROOT divide = f32[127,125]{1,0} divide(subtract, second_broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseDiamondWithTrailingUnaryElementwiseAtTheRoot) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+  ROOT abs = f32[127,125]{1,0} abs(subtract)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest, CanFuseDiamondWithUnaryElementwisePrefix) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  abs = f32[127,125]{1,0} abs(param_0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(abs, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseDiamondWithMultipleBroadcastDimensions) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[1,3,125,125]{3,2,1,0} parameter(0)
+  bitcast = f32[3,125,125]{2,1,0} bitcast(f32[1,3,125,125]{3,2,1,0} param_0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[3,125]{1,0} reduce(f32[3,125,125]{2,1,0} bitcast, f32[] constant_neg_inf), dimensions={2}, to_apply=max_computation
+  broadcast = f32[1,3,125,125]{3,2,1,0} broadcast(f32[3,125]{1,0} reduce), dimensions={1,2}
+  ROOT subtract = f32[1,3,125,125]{3,2,1,0} subtract(f32[1,3,125,125]{3,2,1,0} param_0, f32[1,3,125,125]{3,2,1,0} broadcast)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanNotFuseSoftmaxDiamondWithParameterReducerIdentity) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  identity = f32[] parameter(1)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, identity), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanNotFuseSoftmaxDiamondWithTritonIncompatibleReducer) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  if_0 = pred[] is-finite(arg_0)
+  c = f32[] convert(if_0)
+  ROOT maximum = f32[] maximum(c, arg_1)
+}
+
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseSoftmaxDiamondWithLastDimensionBitcastAfterReduce) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+
+ENTRY main {
+  param_0 = f32[3,127,125]{2,1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[3,127]{1,0} reduce(param_0, constant_neg_inf), dimensions={2}, to_apply=max_computation
+  bitcasted_reduce = f32[381]{0} bitcast(reduce)
+  broadcast = f32[381,125]{1,0} broadcast(bitcasted_reduce), dimensions={0}
+  bitcasted_broadcast = f32[3,127,125]{2,1,0} bitcast(broadcast)
+  ROOT subtract = f32[3,127,125]{2,1,0} subtract(param_0, bitcasted_broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanNotFuseSoftmaxDiamondWithTransposeBitcast) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+
+ENTRY main {
+  param_0 = f32[1,127,125]{2,1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  bitcasted_param_0 = f32[127,1,125]{2,0,1} bitcast(param_0)
+  reduce = f32[127,1]{0,1} reduce(bitcasted_param_0, constant_neg_inf), dimensions={2}, to_apply=max_computation
+  broadcast = f32[127,1,125]{2,0,1} broadcast(reduce), dimensions={0,1}
+  bitcasted_broadcast = f32[1,127,125]{2,1,0} bitcast(broadcast)
+  ROOT subtract = f32[1,127,125]{2,1,0} subtract(param_0, bitcasted_broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanNotFuseTwoDiamondsWithDifferentReductionAxisSizeTogether) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+  arg_0.1 = f32[] parameter(0)
+  arg_1.1 = f32[] parameter(1)
+  ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f32[127,625]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,625]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,625]{1,0} subtract(param_0, broadcast)
+  bitcasted_subtract = f32[127,5,125] bitcast(subtract)
+  exponential = f32[127,5,125] exponential(bitcasted_subtract)
+  constant_zero = f32[] constant(0)
+  second_reduce = f32[127,5] reduce(exponential, constant_zero), dimensions={2}, to_apply=add_computation
+  second_broadcast = f32[127,5,125] broadcast(second_reduce), dimensions={0,1}
+  ROOT divide = f32[127,5,125] divide(exponential, second_broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Bitcast(m::Fusion(m::Parameter())
+                                   .WithPredicate(HasBlockLevelFusionConfig)))
+              .WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanNotFuseTwoDiamondsWithExtraUsageForFirstDiamondRoot) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+  arg_0.1 = f32[] parameter(0)
+  arg_1.1 = f32[] parameter(1)
+  ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+  exponential = f32[127,125]{1,0} exponential(subtract)
+  constant_zero = f32[] constant(0)
+  second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
+  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+  divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
+  ROOT tuple = (f32[127,125]{1,0}, f32[127,125]{1,0}) tuple(divide, subtract)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::Fusion(m::Fusion()).WithPredicate(HasBlockLevelFusionConfig),
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanNotFuseTwoDiamondsWithExtraUsageForSecondDiamondProducer) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+  arg_0.1 = f32[] parameter(0)
+  arg_1.1 = f32[] parameter(1)
+  ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+  exponential = f32[127,125]{1,0} exponential(subtract)
+  constant_zero = f32[] constant(0)
+  second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
+  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+  divide = f32[127,125]{1,0} divide(exponential, second_broadcast)
+  ROOT tuple = (f32[127,125]{1,0}, f32[127,125]{1,0}) tuple(divide, exponential)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::Fusion(m::Fusion()).WithPredicate(HasBlockLevelFusionConfig),
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseSoftmaxDiamondWithTritonIncompatibleProducer) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+
+ENTRY main {
+  param_0 = f16[127,125]{1,0} parameter(0)
+  exponential = f16[127,125] exponential(param_0)
+  convert = f32[127,125] convert(exponential)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(convert, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(convert, broadcast)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Fusion(m::Exp(m::Parameter()))
+                             .WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanNotFuseSoftmaxDiamondWithNonFusibleBitcastBetweenReduceAndProducer) {
+  const std::string hlo_string = R"(
+HloModule softmax
+
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+
+ENTRY main {
+  param_0 = f32[1,127,5,25]{3,2,1,0} parameter(0)
+  bitcast_0 = f32[127,125] bitcast(param_0)
+  bitcast_1 = f32[127,125] bitcast(param_0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(bitcast_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(bitcast_1, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseSoftmaxDiamondWithBitcastProducerFollowedByBitcastsOnEachUse) {
+  const std::string hlo_string = R"(
+HloModule softmax
+
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+
+ENTRY main {
+  param_0 = f32[1,1,127,125]{3,2,1,0} parameter(0)
+  bitcast_parent = f32[127,125]{1,0} bitcast(param_0)
+  bitcast_0 = f32[127,125]{1,0} bitcast(bitcast_parent)
+  bitcast_1 = f32[127,125]{1,0} bitcast(bitcast_parent)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(bitcast_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(bitcast_1, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest, RewriterBailsOutOnPreAmpereCudaGpu) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = bf16[127,125]{1,0} parameter(0)
+  param_0_f32 = f32[127,125]{1,0} convert(param_0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast)
+})";
+
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
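+  // Volta (compute capability 7.0) is below the Ampere (8.0) requirement, so
+  // the rewriter is expected to fail with a FailedPrecondition error.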
+  EXPECT_THAT(
+      SoftmaxRewriterTriton(
+          TestGpuDeviceInfo::RTXA6000DeviceInfo(
+              se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}),
+          ShapeSizeBytesFunction())
+          .Run(module.get()),
+      tsl::testing::StatusIs(
+          tsl::error::FAILED_PRECONDITION,
+          ::testing::HasSubstr("Triton support is only enabled for Ampere GPUs "
+                               "(compute capability 8.0) and up, but got")));
+}
+
+TEST_F(SoftmaxRewriterTritonTest, RewriterSucceedsOnNonCudaGpu) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = bf16[127,125]{1,0} parameter(0)
+  param_0_f32 = f32[127,125]{1,0} convert(param_0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast)
+})";
+
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+
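+  // The Ampere requirement is specific to CUDA devices; on an AMD MI210 the
+  // pass is expected to run without returning an error.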
+  EXPECT_TRUE(SoftmaxRewriterTriton(TestGpuDeviceInfo::AMDMI210DeviceInfo(),
+                                    ShapeSizeBytesFunction())
+                  .Run(module.get())
+                  .ok());
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseBinaryElementwiseProducerIntoDiamondWhenBothOperandsAreTheSame) {
+  const std::string hlo_string = R"(
+HloModule fusible_diamond
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  multiply = f32[127,125]{1,0} multiply(param_0, param_0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(multiply, broadcast)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    CanFuseIntermediateBinaryElementwiseWithinDiamondWhenBothOperandsAreTheSame) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule fusible_diamond
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  multiply = f32[127]{0} multiply(reduce, reduce)
+  broadcast = f32[127,125]{1,0} broadcast(multiply), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseBinaryElementwiseWhenBothOperandsAreTheSameBetweenDiamonds) {
+  const std::string hlo_string = R"(
+HloModule fusible_diamonds
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+  arg_0.1 = f32[] parameter(0)
+  arg_1.1 = f32[] parameter(1)
+  ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+  multiply = f32[127,125]{1,0} multiply(subtract, subtract)
+  constant_zero = f32[] constant(0)
+  second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
+  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+  ROOT subtract_second = f32[127,125]{1,0} subtract(multiply, second_broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseBinaryElementwiseConsumerWhereBothOperandsAreTheSameIntoDiamond) {
+  const std::string hlo_string = R"(
+HloModule fusible_diamond
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+add_computation {
+  arg_0.1 = f32[] parameter(0)
+  arg_1.1 = f32[] parameter(1)
+  ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+  ROOT multiply = f32[127,125]{1,0} multiply(subtract, subtract)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    DoesNotFuseIntermediateBinaryElementwiseWithBothSplatOperandsIntoDiamond) {
+  const std::string hlo_string = R"(
+HloModule nonfusible_splat
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  constant_0 = f32[] constant(0.333333343)
+  splat_0 = f32[127,125]{1,0} broadcast(constant_0), dimensions={}
+  constant_1 = f32[] constant(0.66666)
+  splat_1 = f32[127,125]{1,0} broadcast(constant_1), dimensions={}
+  param_0 = f32[127,125]{1,0} parameter(0)
+  multiply_splats = f32[127,125]{1,0} multiply(splat_0, splat_1)
+  multiply_splat_param = f32[127,125]{1,0} multiply(multiply_splats, param_0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(multiply_splat_param, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    DoesNotFuseIntermediateBinaryElementwiseWithSameSplatOperandsIntoDiamond) {
+  const std::string hlo_string = R"(
+HloModule nonfusible_splat_diamond
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  constant_0 = f32[] constant(0.333333343)
+  splat = f32[127,125]{1,0} broadcast(constant_0), dimensions={}
+  param_0 = f32[127,125]{1,0} parameter(0)
+  multiply = f32[127,125]{1,0} multiply(splat, splat)
+  add = f32[127,125]{1,0} add(param_0, multiply)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(add, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  SoftmaxRewriterTriton fusion_rewriter(device_info_, ShapeSizeBytesFunction());
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest, CanFuseRMSNormDiamond) {
+  const std::string hlo_string = R"(
+HloModule rms_norm
+add_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT add.1 = f32[] add(arg_0, arg_1)
+}
+ENTRY main.30 {
+  param_0 = f32[10,10,10,128]{3,2,1,0} parameter(0)
+  multiply_param = f32[10,10,10,128]{3,2,1,0} multiply(param_0, param_0)
+  constant_0 = f32[] constant(0)
+  reduce = f32[10,10,10]{2,1,0} reduce(multiply_param, constant_0), dimensions={3}, to_apply=add_computation
+  constant_1 = f32[] constant(0.333333343)
+  splat = f32[10,10,10]{2,1,0} broadcast(constant_1), dimensions={}
+  multiply_splat = f32[10,10,10]{2,1,0} multiply(reduce, splat)
+  epsilon = f32[] constant(1e-06)
+  splat_epsilon = f32[10,10,10]{2,1,0} broadcast(epsilon), dimensions={}
+  add = f32[10,10,10]{2,1,0} add(multiply_splat, splat_epsilon)
+  rsqrt = f32[10,10,10]{2,1,0} rsqrt(add)
+  broadcast = f32[10,10,10,128]{3,2,1,0} broadcast(rsqrt), dimensions={0,1,2}
+  ROOT multiply = f32[10,10,10,128]{3,2,1,0} multiply(param_0, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantBetweenDiamonds) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule fusible_diamonds
+add_computation {
+  arg_0.1 = f32[] parameter(0)
+  arg_1.1 = f32[] parameter(1)
+  ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+  constant = f32[] constant(0.333333343)
+  broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={}
+  multiply = f32[127,125]{1,0} multiply(broadcast_splat, subtract)
+  constant_zero = f32[] constant(0)
+  second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
+  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+  ROOT second_subtract = f32[127,125]{1,0} subtract(multiply, second_broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    CanFuseAndEmitBinaryElementwiseWhereTheSecondOperandIsASplatConstantBetweenDiamonds) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule fusible_diamonds
+add_computation {
+  arg_0.1 = f32[] parameter(0)
+  arg_1.1 = f32[] parameter(1)
+  ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+  constant = f32[] constant(0.333333343)
+  broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={}
+  multiply = f32[127,125]{1,0} multiply(subtract, broadcast_splat)
+  constant_zero = f32[] constant(0)
+  second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation
+  second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+  ROOT second_subtract = f32[127,125]{1,0} subtract(multiply, second_broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    CanFuseBinaryElementwiseWhereTheFirstOperandIsASplatConstantWithinDiamond) {
+  const std::string hlo_string = R"(
+HloModule fusible_diamond
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  constant = f32[] constant(0.333333343)
+  broadcast_splat = f32[127]{0} broadcast(constant), dimensions={}
+  multiply = f32[127]{0} multiply(broadcast_splat, reduce)
+  broadcast = f32[127,125]{1,0} broadcast(multiply), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseBinaryElementwiseConsumerWhereTheFirstOperandIsASplatConstant) {
+  const std::string hlo_string = R"(
+HloModule fusible_diamond
+add_computation {
+  arg_0.1 = f32[] parameter(0)
+  arg_1.1 = f32[] parameter(1)
+  ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+  constant = f32[] constant(0.333333343)
+  broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={}
+  ROOT multiply = f32[127,125]{1,0} multiply(broadcast_splat, subtract)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(SoftmaxRewriterTritonTest,
+       CanFuseBinaryElementwiseOperationWhereOneOperandIsASharedSplatProducer) {
+  const std::string hlo_string = R"(
+HloModule nonfusible_diamond
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT max = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_2 = f32[] constant(0.333333343)
+  broadcast_splat = f32[127,125]{1,0} broadcast(constant_2), dimensions={}
+  param_1 = f32[127,125]{1,0} parameter(1)
+  multiply_splat = f32[127,125]{1,0} multiply(broadcast_splat, param_1)
+  multiply = f32[127,125]{1,0} multiply(param_0, broadcast_splat)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+  EXPECT_TRUE(verifier().Run(module.get()).status().ok());
+  VLOG(2) << module->ToString();
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)));
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    DoesNotFuseBinaryElementwiseOperationWhereFirstOperandIsASplatAndSecondOperandIsASharedSplatProducer) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule nonfusible_diamond
+add_computation {
+  arg_0.1 = f32[] parameter(0)
+  arg_1.1 = f32[] parameter(1)
+  ROOT add = f32[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  constant_2 = f32[] constant(0.333333343)
+  broadcast_splat_shared = f32[127,125]{1,0} broadcast(constant_2), dimensions={}
+  param_1 = f32[127,125]{1,0} parameter(1)
+  multiply_splat_shared = f32[127,125]{1,0} multiply(broadcast_splat_shared, param_1)
+  constant_3 = f32[] constant(0.5)
+  broadcast_splat = f32[127,125]{1,0} broadcast(constant_3), dimensions={}
+  multiply_splat = f32[127,125]{1,0} multiply(broadcast_splat, broadcast_splat_shared)
+  multiply = f32[127,125]{1,0} multiply(param_0, multiply_splat)
+  constant_neg_inf = f32[] constant(-inf)
+  reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=add_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(SoftmaxRewriterTritonTest, FusionDecisionIsCapturedExplicitly) {
+  const std::string hlo_string = R"(
+HloModule softmax
+max_computation {
+  arg_0 = f32[] parameter(0)
+  arg_1 = f32[] parameter(1)
+  ROOT maximum = f32[] maximum(arg_0, arg_1)
+}
+ENTRY main {
+  param_0 = f32[127,125]{1,0} parameter(0)
+  identity_f8 = f8e5m2[] parameter(1)
+  identity = f32[] convert(identity_f8)
+  reduce = f32[127]{0} reduce(param_0, identity), dimensions={1}, to_apply=max_computation
+  broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0}
+  ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast)
+}
+)";
+
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  SoftmaxRewriterTriton softmax_rewriter_triton(device_info_,
+                                                ShapeSizeBytesFunction());
+  int unmatched = 0, matched = 0;
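+  // Query the matching decision for every instruction in the entry
+  // computation; non-matches must come with an explicit explanation.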
+  for (HloInstruction* instruction :
+       module->entry_computation()->MakeInstructionPostOrder()) {
+    DiamondMatchingDecision decision =
+        softmax_rewriter_triton.MatchesTritonCompatibleClosedReductionDiamond(
+            instruction);
+    if (std::holds_alternative<FusionDecision>(decision)) {
+      std::string actual_decision =
+          std::get<FusionDecision>(decision).Explain();
+      EXPECT_THAT(
+          actual_decision,
+          AnyOf(
+              HasSubstr("Root is not elementwise binary"),
+              HasSubstr("identity is not a constant or a supported convert")));
+      unmatched++;
+    } else {
+      matched++;
+    }
+  }
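+  // None of the six instructions in the entry computation matches a
+  // Triton-compatible closed reduction diamond.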
+  EXPECT_EQ(unmatched, 6);
+  EXPECT_EQ(matched, 0);
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    FusesBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongReductionDimAsParameter) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+  y = f32[] parameter(1)
+  x = f32[] parameter(0)
+  ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+  p0 = f32[32]{0} parameter(0)
+  p1 = f32[32,16]{1,0} parameter(1)
+  c = f32[] constant(0)
+
+  r0 = f32[32]{0} reduce(p1, c), dimensions={1}, to_apply=add_computation
+  b0 = f32[32,16]{1,0} broadcast(r0), dimensions={0}
+  b1 = f32[32,16]{1,0} broadcast(p0), dimensions={0}
+  add0 = f32[32,16]{1,0} add(b1, p1)
+  ROOT add1 = f32[32,16]{1,0} add(add0, b0)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    FusesBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongBatchDimAsParameter) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+  y = f32[] parameter(1)
+  x = f32[] parameter(0)
+  ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+  p0 = f32[16]{0} parameter(0)
+  p1 = f32[32,16]{1,0} parameter(1)
+  c = f32[] constant(0)
+
+  r0 = f32[32]{0} reduce(p1, c), dimensions={1}, to_apply=add_computation
+  b0 = f32[32,16]{1,0} broadcast(r0), dimensions={0}
+  b1 = f32[32,16]{1,0} broadcast(p0), dimensions={1}
+  add0 = f32[32,16]{1,0} add(b1, p1)
+  ROOT add1 = f32[32,16]{1,0} add(add0, b0)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    FusesBinaryElementwiseIfIntermediateDiamondOpWithMultiDimTensorBroadcastAlongBatchDimAsParameter) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+  y = f32[] parameter(1)
+  x = f32[] parameter(0)
+  ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+  p0 = f32[32,16]{1,0} parameter(0)
+  p1 = f32[64,32,16]{2,1,0} parameter(1)
+  c = f32[] constant(0)
+
+  r0 = f32[64,32]{1,0} reduce(p1, c), dimensions={2}, to_apply=add_computation
+  b0 = f32[64,32,16]{2,1,0} broadcast(r0), dimensions={0,1}
+  b1 = f32[64,32,16]{2,1,0} broadcast(p0), dimensions={1,2}
+  add0 = f32[64,32,16]{2,1,0} add(b1, p1)
+  ROOT add1 = f32[64,32,16]{2,1,0} add(add0, b0)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    FusesBinaryElementwiseIfIntermediateDiamondOpWithZeroDimTensorBroadcastAsParameter) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+  y = f32[] parameter(1)
+  x = f32[] parameter(0)
+  ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+  parameter_0 = f32[] parameter(0)
+  parameter_1 = f32[64,32,16]{2,1,0} parameter(1)
+  c = f32[] constant(0)
+
+  reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation
+  broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1}
+  broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={}
+  add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1)
+  ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    FusesBinaryElementwiseIfIntermediateDiamondOpIsBroadcastOf1DParameterAlongNonReductionDimensions) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+  y = f32[] parameter(1)
+  x = f32[] parameter(0)
+  ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+  parameter_0 = f32[16] parameter(0)
+  parameter_1 = f32[64,32,16]{2,1,0} parameter(1)
+  c = f32[] constant(0)
+
+  reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation
+  broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1}
+  broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={2}
+  add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1)
+  ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_TRUE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpIsBroadcastOf1DParameterAlongBothBatchAndReductionDimensions) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+  y = f32[] parameter(1)
+  x = f32[] parameter(0)
+  ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+  parameter_0 = f32[64] parameter(0)
+  parameter_1 = f32[64,32,16]{2,1,0} parameter(1)
+  c = f32[] constant(0)
+
+  reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation
+  broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1}
+  broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={0}
+  add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1)
+  ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongBatchAndReductionDimAsParameter) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+  y = f32[] parameter(1)
+  x = f32[] parameter(0)
+  ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+  p0 = f32[8]{0} parameter(0)
+  p1 = f32[32,8,16]{2,1,0} parameter(1)
+  c = f32[] constant(0)
+
+  r0 = f32[32,8]{1,0} reduce(p1, c), dimensions={2}, to_apply=add_computation
+  b0 = f32[32,8,16]{2,1,0} broadcast(r0), dimensions={0,1}
+  b1 = f32[32,8,16]{2,1,0} broadcast(p0), dimensions={1}
+  add0 = f32[32,8,16]{2,1,0} add(b1, p1)
+  ROOT add1 = f32[32,8,16]{2,1,0} add(add0, b0)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithPartialBroadcastToBatchDim) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+  y = f32[] parameter(1)
+  x = f32[] parameter(0)
+  ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+  p0 = f32[16,64]{1,0} parameter(0)
+  p1 = f32[8,16,32,64]{3,2,1,0} parameter(1)
+  c = f32[] constant(0)
+
+  r0 = f32[8,16,32]{2,1,0} reduce(p1, c), dimensions={3}, to_apply=add_computation
+  b0 = f32[8,16,32,64]{3,2,1,0} broadcast(r0), dimensions={0,1,2}
+  b1 = f32[8,16,32,64]{3,2,1,0} broadcast(p0), dimensions={1,3}
+  add0 = f32[8,16,32,64]{3,2,1,0} add(b1, p1)
+  ROOT add1 = f32[8,16,32,64]{3,2,1,0} add(add0, b0)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+TEST_F(
+    SoftmaxRewriterTritonTest,
+    DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithMultiDimBroadcastAlongBatchDimAsParameter) {  // NOLINT(whitespace/line_length)
+  const std::string hlo_string = R"(
+HloModule h1
+
+add_computation {
+  y = f32[] parameter(1)
+  x = f32[] parameter(0)
+  ROOT add = f32[] add(x, y)
+}
+
+ENTRY main {
+  p0 = f32[32,16]{1,0} parameter(0)
+  p1 = f32[128,64,32,16]{3,2,1,0} parameter(1)
+  c = f32[] constant(0)
+
+  r0 = f32[128,64,32]{2,1,0} reduce(p1, c), dimensions={3}, to_apply=add_computation
+  b0 = f32[128,64,32,16]{3,2,1,0} broadcast(r0), dimensions={0,1,2}
+  b1 = f32[128,64,32,16]{3,2,1,0} broadcast(p0), dimensions={2,3}
+  add0 = f32[128,64,32,16]{3,2,1,0} add(b1, p1)
+  ROOT add1 = f32[128,64,32,16]{3,2,1,0} add(add0, b0)
+})";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
+  EXPECT_FALSE(
+      SoftmaxRewriterTritonMatchAndRewrite(device_info_, module.get()).value());
+}
+
+}  // anonymous namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc
new file mode 100644
index 0000000..b299db8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc
@@ -0,0 +1,342 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/sort_rewriter.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/comparison_util.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/runtime/cub_sort_thunk.h"
+#include "xla/service/stable_sort_expander.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Result of analyzing a sort comparator: which operand holds the sort keys
+// and whether the order is descending.
+struct SortComputationAnalysis {
+  int key_operand;  // 0 or 1
+  bool descending;
+};
+
+std::pair<int64_t, int64_t> ParametersFromCmpOperands(
+    const HloCompareInstruction* cmp_op) {
+  if (cmp_op == nullptr) {
+    return std::pair<int64_t, int64_t>(-1, -1);
+  }
+  const HloParameterInstruction* param0 =
+      DynCast<HloParameterInstruction>(cmp_op->operand(0));
+  const HloParameterInstruction* param1 =
+      DynCast<HloParameterInstruction>(cmp_op->operand(1));
+  return (param0 && param1) ? std::make_pair(param0->parameter_number(),
+                                             param1->parameter_number())
+                            : std::pair<int64_t, int64_t>(-1, -1);
+}
+
+// Returns sort info for compatible compare instructions. The instruction may
+// belong to a computation that has 2 or 4 parameters. If it is the root
+// instruction of a computation with 4 parameters, the analysis only succeeds
+// when 2 of the parameters are ignored.
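+//
+// For example (mirroring the SortPairsSwapped test in the accompanying test
+// file), a comparator whose root is
+//   compare(parameter(2), parameter(3)), direction=LT
+// selects operand 1 as the sort key and yields an ascending sort, while the
+// value parameters 0 and 1 are ignored.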
+std::optional<SortComputationAnalysis> AnalyzeCompareOp(
+    const HloInstruction* maybe_compare_op) {
+  // Root instruction must be a comparison with a valid direction.
+  const HloCompareInstruction* compare =
+      DynCast<HloCompareInstruction>(maybe_compare_op);
+  if (compare == nullptr || compare->direction() == ComparisonDirection::kEq ||
+      compare->direction() == ComparisonDirection::kNe) {
+    return std::nullopt;
+  }
+
+  // Compare should operate on the function parameters for a single tensor.
+  auto [index0, index1] = ParametersFromCmpOperands(compare);
+  if (index0 == -1 || index1 == -1) {
+    return std::nullopt;
+  }
+
+  // When sorting a pair of tensors, the parameters should be adjacent.
+  int first_index = std::min(index0, index1);
+  if (first_index % 2 != 0 || std::max(index0, index1) != first_index + 1) {
+    return std::nullopt;
+  }
+
+  // Return the tensor index and the sort direction.
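+  // For example (as in the SortKeysGreaterThanSwapped test), for
+  // compare(parameter(1), parameter(0)), direction=GT, first_index is 0,
+  // descending is true and reverse is true, so the effective order is
+  // ascending.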
+  bool descending = compare->direction() == ComparisonDirection::kGt ||
+                    compare->direction() == ComparisonDirection::kGe;
+  bool reverse = first_index != index0;
+  return SortComputationAnalysis{first_index / 2, descending != reverse};
+}
+
+// Detects a sort with these properties:
+// - Has two operands -- one is an iota op
+// - Has a comparison computation that takes 4 inputs and compares them
+// hierarchically, so that the iota inputs are the final tie-breaker.
+//
+// The above is equivalent to a stable sort where the iota operand is completely
+// ignored. That simpler comparator is the one detected in AnalyzeCompareOp, but
+// that's insufficient, because the StableSortExpander pass expands it into the
+// more complex version detected below.
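+//
+// Schematically (see the SortPairsIotaComparerLikeStableSortExpander test in
+// the accompanying test file), the expanded comparator looks roughly like:
+//
+//   cmp_indices = pred[] compare(lhs_index, rhs_index), direction=LT
+//   cmp_lr      = pred[] compare(lhs, rhs), direction=GT
+//   cmp_rl      = pred[] compare(rhs, lhs), direction=GT
+//   cmp_eq      = pred[] compare(cmp_lr, cmp_rl), direction=EQ
+//   ROOT select = pred[] select(cmp_eq, cmp_indices, cmp_lr)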
+std::optional<SortComputationAnalysis> AnalyzeComplexSortComputation(
+    const HloSortInstruction& sort_op) {
+  auto computation = sort_op.called_computations().front();
+  if (computation->num_parameters() != 4) {
+    return std::nullopt;
+  }
+
+  int64_t iota_operand_index =
+      StableSortExpander::IotaOperandIndexForStableSort(sort_op);
+  if (iota_operand_index < 0) {
+    return std::nullopt;
+  }
+
+  auto root = computation->root_instruction();
+  if (root->opcode() != HloOpcode::kSelect) {
+    return std::nullopt;
+  }
+
+  // Check that the middle operand of the select compares the iota input.
+  auto iota_cmp = DynCast<HloCompareInstruction>(root->operand(1));
+  auto [iotap0, iotap1] = ParametersFromCmpOperands(iota_cmp);
+  if (iota_cmp == nullptr ||
+      iota_cmp->direction() != ComparisonDirection::kLt ||
+      iotap0 != iota_operand_index * 2 ||
+      iotap1 != iota_operand_index * 2 + 1) {
+    return std::nullopt;
+  }
+
+  // Check that the first operand of the select is an EQ comparison of the
+  // value (non-iota) inputs.
+  auto eq_cmp = DynCast<HloCompareInstruction>(root->operand(0));
+  if (eq_cmp == nullptr || eq_cmp->direction() != ComparisonDirection::kEq) {
+    return std::nullopt;
+  }
+
+  // EQ comparison case 1: direct comparison of parameters
+  auto [p0, p1] = ParametersFromCmpOperands(eq_cmp);
+  if (p0 < 0 || p1 < 0) {
+    // EQ comparison case 2: comparison of comparisons. This is what
+    // the StableSortExpander pass currently generates.
+    auto cmp = DynCast<HloCompareInstruction>(eq_cmp->operand(0));
+    auto cmp_reverse = DynCast<HloCompareInstruction>(eq_cmp->operand(1));
+    auto [a, b] = ParametersFromCmpOperands(cmp);
+    auto [p, q] = ParametersFromCmpOperands(cmp_reverse);
+    if (cmp == nullptr || cmp_reverse == nullptr || a < 0 || b < 0 || a != q ||
+        b != p || cmp->direction() != cmp_reverse->direction() ||
+        cmp->direction() == Comparison::Direction::kEq ||
+        cmp->direction() == Comparison::Direction::kNe) {
+      return std::nullopt;
+    }
+  }
+
+  // At this point only the last operand of the select needs to be verified.
+  return AnalyzeCompareOp(root->operand(2));
+}
+
+std::optional<SortComputationAnalysis> AnalyzeSortOp(
+    const HloSortInstruction& sort_op) {
+  auto computation = sort_op.called_computations().front();
+
+  // First, check if the computation is a simple compare op on the operands.
+  auto result = AnalyzeCompareOp(computation->root_instruction());
+  if (!result.has_value()) {
+    // If the above fails, check if the sort instruction and comparer are more
+    // complex, like what is produced by the StableSortExpander pass.
+    result = AnalyzeComplexSortComputation(sort_op);
+  }
+  return result;
+}
+
+// Creates a runner for the CUB sort operation.
+absl::StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateRunner(
+    HloSortInstruction* sort_op, const SortComputationAnalysis& sort_config) {
+  int value_index = 1 - sort_config.key_operand;
+  return CubSortRunnerInterface::Create(
+      sort_op->operand(sort_config.key_operand)->shape().element_type(),
+      sort_op->operand_count() == 2
+          ? std::optional(sort_op->operand(value_index)->shape().element_type())
+          : std::nullopt);
+}
+
+// Verifies that the sort operation (shape, comparator, and element types) is
+// supported by CUB.
+bool IsCubCompatibleSort(HloSortInstruction* sort_op) {
+  VLOG(1) << "Sort instruction: " << sort_op->name();
+  if (sort_op->operand_count() != 1 && sort_op->operand_count() != 2) {
+    VLOG(2) << "Unsupported operand count: " << sort_op->operand_count();
+    return false;
+  }
+
+  const Shape& operand_shape = sort_op->operand(0)->shape();
+  if (sort_op->sort_dimension() != operand_shape.rank() - 1) {
+    VLOG(2) << "Sort dimension should be the minor one";
+    return false;
+  }
+  if (Product(operand_shape.dimensions()) < SortRewriter::SortSizeThreshold()) {
+    VLOG(2) << "Tensor shape size is too small to see an improvement";
+    return false;
+  }
+
+  auto sort_config = AnalyzeSortOp(*sort_op);
+  if (!sort_config.has_value()) {
+    VLOG(2) << "Only simple compare computations are supported";
+    return false;
+  }
+  if (!CreateRunner(sort_op, *sort_config).ok()) {
+    VLOG(2) << "Unsupported operand types (no compiled CUB kernels)";
+    return false;
+  }
+  VLOG(2) << "Sort operation is compatible";
+  return true;
+}
+
+// Restores the result shape after sorting a pair of tensors. The last tuple
+// element of the custom call result is the scratch buffer, which is
+// discarded.
+HloInstruction* UnpackResultPair(HloSortInstruction* sort_op,
+                                 HloInstruction* custom_call, bool swap) {
+  HloComputation* parent = sort_op->parent();
+  HloInstruction* gte0 =
+      parent->AddInstruction(HloInstruction::CreateGetTupleElement(
+          sort_op->operand(0)->shape(), custom_call, swap ? 1 : 0));
+  HloInstruction* gte1 =
+      parent->AddInstruction(HloInstruction::CreateGetTupleElement(
+          sort_op->operand(1)->shape(), custom_call, swap ? 0 : 1));
+  return parent->AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+}
+
+}  // namespace
+
+// Rewrites a single sort instruction with a custom call.
+absl::StatusOr<bool> SortRewriter::RunOnInstruction(
+    HloSortInstruction* sort_op) {
+  // Get the sort tensor index and direction.
+  SortComputationAnalysis sort_config = AnalyzeSortOp(*sort_op).value();
+
+  // Get scratch size requirements from CUB.
+  const Shape& operand_shape = sort_op->operand(0)->shape();
+  int64_t batch_size = Product(operand_shape.dimensions()) /
+                       operand_shape.dimensions(sort_op->sort_dimension());
+
+  TF_ASSIGN_OR_RETURN(auto runner, CreateRunner(sort_op, sort_config));
+  TF_ASSIGN_OR_RETURN(
+      int64_t scratch_size,
+      runner->GetScratchSize(Product(operand_shape.dimensions()), batch_size));
+
+  // Align and increase scratch size to fit the offsets.
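+  // For example (hypothetical sizes), with scratch_size = 1021 and
+  // batch_size = 3, the alignment adds 3 bytes (to 1024) and the offsets add
+  // (3 + 1) * sizeof(int) = 16 bytes, for a total of 1040.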
+  if (batch_size > 1) {
+    scratch_size += sizeof(int) - scratch_size % sizeof(int);
+    scratch_size += (batch_size + 1) * sizeof(int);
+  }
+
+  // Values are only present if sorting a pair of tensors.
+  HloInstruction* keys = sort_op->mutable_operand(0);
+  HloInstruction* values = nullptr;
+  if (sort_op->operand_count() == 2) {
+    values = sort_op->mutable_operand(1);
+    if (sort_config.key_operand == 1) {
+      std::swap(keys, values);
+    }
+  }
+
+  // Build the resulting shape for the custom call.
+  std::vector<Shape> shapes{keys->shape()};
+  std::vector<HloInstruction*> operands{keys};
+  if (values != nullptr) {
+    shapes.push_back(values->shape());
+    operands.push_back(values);
+  }
+  shapes.push_back(ShapeUtil::MakeShape(U8, {scratch_size}));
+  Shape call_shape = ShapeUtil::MakeTupleShape(absl::MakeSpan(shapes));
+
+  // Build the custom call instruction.
+  HloInstruction* custom_call =
+      sort_op->parent()->AddInstruction(HloInstruction::CreateCustomCall(
+          call_shape, absl::MakeSpan(operands), kCubDeviceRadixSortTarget));
+
+  xla::SortOptions backend_config;
+  backend_config.set_descending(sort_config.descending);
+  TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config));
+
+  // Build the replacement instruction.
+  HloInstruction* replacement;
+  if (sort_op->operand_count() == 1) {
+    replacement =
+        sort_op->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+            sort_op->shape(), custom_call, 0));
+  } else {
+    replacement = UnpackResultPair(sort_op, custom_call,
+                                   /*swap=*/sort_config.key_operand == 1);
+  }
+
+  // Replace sort operation with custom call followed by GTE.
+  TF_RETURN_IF_ERROR(
+      sort_op->parent()->ReplaceInstruction(sort_op, replacement));
+  return true;
+}
+
+// Rewrites the sorts in the given computation into calls to CUB.
+absl::StatusOr<bool> SortRewriter::RunOnComputation(
+    HloComputation* computation) {
+  std::vector<HloSortInstruction*> sort_ops;
+  for (auto* inst : computation->instructions()) {
+    HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
+    if (sort != nullptr && IsCubCompatibleSort(sort)) {
+      sort_ops.push_back(sort);
+    }
+  }
+  bool changed = false;
+  for (auto* sort : sort_ops) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(sort));
+    changed |= result;
+  }
+  return changed;
+}
+
+// Replaces compatible sort operations with custom calls.
+absl::StatusOr<bool> SortRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  XLA_VLOG_LINES(2, "SortRewriter::Run(), before:\n" + module->ToString());
+  bool changed = false;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+    changed |= result;
+  }
+  XLA_VLOG_LINES(2, "SortRewriter::Run(), after:\n" + module->ToString());
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h
new file mode 100644
index 0000000..406df7a
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h
@@ -0,0 +1,63 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_SORT_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_SORT_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites sort operations into CustomCall HLOs that call into CUB.
+// Only a subset of shapes is supported: either a single tensor with a simple
+// compare function, or a pair of tensors where the keys are unsigned
+// integers.
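+//
+// Schematically (based on the accompanying tests), a sort such as
+//
+//   %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
+//
+// is rewritten into a custom call to kCubDeviceRadixSortTarget whose result
+// tuple carries the sorted keys plus a scratch buffer, followed by a
+// get-tuple-element that extracts the sorted keys:
+//
+//   %cc = (f32[1000], u8[...]) custom-call(%input)
+//   %result = f32[1000] get-tuple-element(%cc), index=0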
+
+class SortRewriter : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "sort-rewriter"; }
+
+  // CUB radix sort is slower than XLA sort on small shapes, so do not rewrite
+  // tensors with sizes below this limit.
+  static int SortSizeThreshold() { return sort_size_threshold_; }
+  static void SetSortSizeThresholdForTestingOnly(int threshold) {
+    // We need to be able to reduce the threshold for testing, so that the tests
+    // can run and compare against the reference interpreter, which is quite
+    // slow.
+    sort_size_threshold_ = threshold;
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  absl::StatusOr<bool> RunOnInstruction(HloSortInstruction* sort_op);
+  absl::StatusOr<bool> RunOnComputation(HloComputation* computation);
+
+  static inline int sort_size_threshold_ = 16385;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_SORT_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter_stub.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_stub.cc
new file mode 100644
index 0000000..e9bf60c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_stub.cc
@@ -0,0 +1,45 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/transforms/sort_rewriter.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> SortRewriter::RunOnInstruction(
+    HloSortInstruction* sort_op) {
+  return false;
+}
+
+absl::StatusOr<bool> SortRewriter::RunOnComputation(
+    HloComputation* computation) {
+  return false;
+}
+
+absl::StatusOr<bool> SortRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  return false;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc
new file mode 100644
index 0000000..853de5b
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc
@@ -0,0 +1,453 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/sort_rewriter.h"
+
+#include <utility>
+
+#include <gtest/gtest.h>
+#include "xla/error_spec.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class SortRewriterTest : public HloTestBase {
+ public:
+  void SetUp() override {
+    HloTestBase::SetUp();
+    SortRewriter::SetSortSizeThresholdForTestingOnly(1000);
+  }
+
+  bool RunModuleAndPass(HloModule* module) {
+    auto cloned = module->Clone();
+    bool changed = SortRewriter().Run(module).value();
+    if (changed) {
+      // Here we run an end-to-end test to make sure that SortRewriter does
+      // not introduce an incorrect rewrite. To do this, we need to clone the
+      // original module because the interpreter cannot process the already
+      // optimized module.
+      EXPECT_TRUE(RunAndCompare(std::move(cloned), ErrorSpec{0, 0}));
+    }
+    return changed;
+  }
+
+  void ExpectDirection(const HloInstruction* instruction, bool descending) {
+    auto config = instruction->backend_config<xla::SortOptions>();
+    EXPECT_EQ(config->descending(), descending);
+  }
+};
+
+// Basic sort: ascending.
+TEST_F(SortRewriterTest, SortKeysLessThan) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input = f32[1000] parameter(0)
+  ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunModuleAndPass(module.get()));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+                  /*descending=*/false);
+}
+
+// Basic sort: descending.
+TEST_F(SortRewriterTest, SortKeysGreaterThan) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %gt = pred[] compare(%lhs, %rhs), direction=GT
+}
+
+ENTRY %main {
+  %input = f32[1000] parameter(0)
+  ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunModuleAndPass(module.get()));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+                  /*descending=*/true);
+}
+
+// Comparer swaps the parameter order -> direction is reversed.
+TEST_F(SortRewriterTest, SortKeysGreaterThanSwapped) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(1)
+  %rhs = f32[] parameter(0)
+  ROOT %gt = pred[] compare(%lhs, %rhs), direction=GT
+}
+
+ENTRY %main {
+  %input = f32[1000] parameter(0)
+  ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunModuleAndPass(module.get()));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+                  /*descending=*/false);
+}
+
+// Sort a pair of tensors, keys go first.
+TEST_F(SortRewriterTest, SortPairs) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs_key = u32[] parameter(0)
+  %rhs_key = u32[] parameter(1)
+  %lhs_value = f32[] parameter(2)
+  %rhs_value = f32[] parameter(3)
+  ROOT %lt = pred[] compare(%lhs_key, %rhs_key), direction=LT
+}
+
+ENTRY %main {
+  %input_keys = u32[1000] parameter(0)
+  %input_values = f32[1000] parameter(1)
+  ROOT %sort = (u32[1000], f32[1000]) sort(%input_keys, %input_values),
+      dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunModuleAndPass(module.get()));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
+                                  m::GetTupleElement(m::CustomCall(), 1))));
+}
+
+// Sort a pair of tensors, keys go last.
+TEST_F(SortRewriterTest, SortPairsSwapped) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs_value = f32[] parameter(0)
+  %rhs_value = f32[] parameter(1)
+  %lhs_key = u32[] parameter(2)
+  %rhs_key = u32[] parameter(3)
+  ROOT %lt = pred[] compare(%lhs_key, %rhs_key), direction=LT
+}
+
+ENTRY %main {
+  %input_values = f32[1000] parameter(0)
+  %input_keys = u32[1000] parameter(1)
+  ROOT %sort = (f32[1000], u32[1000]) sort(%input_values, %input_keys),
+      dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunModuleAndPass(module.get()));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 1),
+                                  m::GetTupleElement(m::CustomCall(), 0))));
+}
+
+// CUB sort doesn't support more than two tensors.
+TEST_F(SortRewriterTest, NoRewriteManyTensors) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  %unused1 = f64[] parameter(2)
+  %unused2 = f64[] parameter(3)
+  %unused3 = u64[] parameter(4)
+  %unused4 = u64[] parameter(5)
+  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input1 = f32[1000] parameter(0)
+  %input2 = f64[1000] parameter(1)
+  %input3 = u64[1000] parameter(2)
+  ROOT %sort = (f32[1000], f64[1000], u64[1000]) sort(%input1, %input2, %input3),
+      dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_FALSE(RunModuleAndPass(module.get()));
+}
+
+// The sort dimension must be the most minor dimension.
+TEST_F(SortRewriterTest, NoRewriteNonMinorSortDimension) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input = f32[1000,4] parameter(0)
+  ROOT %sort = f32[1000,4] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_FALSE(RunModuleAndPass(module.get()));
+}
+
+// Kernels are compiled for a subset of types.
+TEST_F(SortRewriterTest, NoRewriteUnsupportedType) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = pred[] parameter(0)
+  %rhs = pred[] parameter(1)
+  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input = pred[1000] parameter(0)
+  ROOT %sort = pred[1000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_FALSE(RunModuleAndPass(module.get()));
+}
+
+// Comparer must be a simple function.
+TEST_F(SortRewriterTest, NoRewriteComplexComparer) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %lhs_scaled = f32[] multiply(%lhs, f32[] constant(2))
+  %rhs = f32[] parameter(1)
+  ROOT %lt = pred[] compare(%lhs_scaled, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input = f32[1000] parameter(0)
+  ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_FALSE(RunModuleAndPass(module.get()));
+}
+
+// Comparer must use adjacent input values.
+TEST_F(SortRewriterTest, NoRewriteMixedKeysValues) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs_key = u32[] parameter(0)
+  %rhs_key = u32[] parameter(1)
+  %lhs_value = u32[] parameter(2)
+  %rhs_value = u32[] parameter(3)
+  ROOT %mixed = pred[] compare(%rhs_key, %lhs_value), direction=LT
+}
+
+ENTRY %main {
+  %input_keys = u32[1000] parameter(0)
+  %input_values = u32[1000] parameter(1)
+  ROOT %sort = (u32[1000], u32[1000]) sort(%input_keys, %input_values),
+      dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_FALSE(RunModuleAndPass(module.get()));
+}
+
+// Small shapes do not see improvement from CUB sort.
+TEST_F(SortRewriterTest, NoRewriteSmallSize) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input = f32[100] parameter(0)
+  ROOT %sort = f32[100] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_FALSE(RunModuleAndPass(module.get()));
+}
+
+// Basic sort: with batch dimension.
+TEST_F(SortRewriterTest, SortWithBatchDim) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input = f32[10,100] parameter(0)
+  ROOT %sort = f32[10,100] sort(%input), dimensions={1}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunModuleAndPass(module.get()));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+                  /*descending=*/false);
+}
+
+// Basic sort: with multiple batch dimensions.
+TEST_F(SortRewriterTest, SortWithMultipleBatchDims) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input = f32[10,10,10] parameter(0)
+  ROOT %sort = f32[10,10,10] sort(%input), dimensions={2}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunModuleAndPass(module.get()));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+                  /*descending=*/false);
+}
+
+// Sort a pair of tensors (values, indices generated by iota) with a complex
+// compare.
+TEST_F(SortRewriterTest, SortPairsIotaComparerSimple) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = u16[] parameter(0)
+  %rhs = u16[] parameter(1)
+  %lhs_index = s32[] parameter(2)
+  %rhs_index = s32[] parameter(3)
+
+  cmp_indices = pred[] compare(%lhs_index, %rhs_index), direction=LT
+  cmp_lr = pred[] compare(%lhs, %rhs), direction=GT
+  cmp_eq = pred[] compare(%lhs, %rhs), direction=EQ
+
+  ROOT %lt = pred[] select(cmp_eq, cmp_indices, cmp_lr)
+}
+
+ENTRY %main {
+  %inputs = u16[1000] parameter(0)
+  %iota = s32[1000] iota(), iota_dimension=0
+  ROOT %sort = (u16[1000], s32[1000]) sort(%inputs, %iota),
+      dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunModuleAndPass(module.get()));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
+                                  m::GetTupleElement(m::CustomCall(), 1))));
+}
+
+// Sort a pair of tensors (values, indices generated by iota) with a complex
+// compare computation that matches the output of the StableSortExpander pass.
+TEST_F(SortRewriterTest, SortPairsIotaComparerLikeStableSortExpander) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = u16[] parameter(0)
+  %rhs = u16[] parameter(1)
+  %lhs_index = s32[] parameter(2)
+  %rhs_index = s32[] parameter(3)
+
+  cmp_indices = pred[] compare(%lhs_index, %rhs_index), direction=LT
+  cmp_lr = pred[] compare(%lhs, %rhs), direction=GT
+  cmp_rl = pred[] compare(%rhs, %lhs), direction=GT
+  cmp_eq = pred[] compare(cmp_lr, cmp_rl), direction=EQ
+
+  ROOT %lt = pred[] select(cmp_eq, cmp_indices, cmp_lr)
+}
+
+ENTRY %main {
+  %inputs = u16[1000] parameter(0)
+  %iota = s32[1000] iota(), iota_dimension=0
+  ROOT %sort = (u16[1000], s32[1000]) sort(%inputs, %iota),
+      dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunModuleAndPass(module.get()));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
+                                  m::GetTupleElement(m::CustomCall(), 1))));
+}
+
+TEST_F(SortRewriterTest, SortSizeThresholdIsSet) {
+  EXPECT_EQ(SortRewriter::SortSizeThreshold(), 1000);
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc
new file mode 100644
index 0000000..68805b1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc
@@ -0,0 +1,225 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/stream_attribute_annotator.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/gpu_fusible.h"
+#include "xla/service/gpu/runtime/thunk.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+bool IsOnlyRootNonDefaultStream(HloComputation* computation) {
+  HloInstruction* root = computation->root_instruction();
+  auto root_gpu_config = root->backend_config<GpuBackendConfig>();
+  if (!root_gpu_config.ok() || root->opcode() == HloOpcode::kTuple) {
+    return false;
+  }
+  int64_t root_stream_id = root_gpu_config->operation_queue_id();
+  VLOG(2) << "Found fusion computation's root stream id to be "
+          << root_stream_id;
+  if (root_stream_id == Thunk::kDefaultExecutionStreamId.value()) {
+    return false;
+  }
+  for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
+    if (instr == root) {
+      continue;
+    }
+    int64_t instr_stream_id =
+        instr->backend_config<GpuBackendConfig>()->operation_queue_id();
+    if (instr_stream_id != Thunk::kDefaultExecutionStreamId.value() &&
+        instr_stream_id != root_stream_id) {
+      return false;
+    }
+  }
+  return true;
+}
+
+absl::StatusOr<bool> AnnotateStreamAttributesForInstruction(
+    HloInstruction* instr, GpuBackendConfig& instr_gpu_config) {
+  if (instr->called_computations().size() != 1) {
+    return false;
+  }
+  HloComputation* called_comp = instr->called_computations()[0];
+  int64_t stream_id = instr_gpu_config.operation_queue_id();
+
+  if (!IsOnlyRootNonDefaultStream(called_comp) ||
+      stream_id != Thunk::kDefaultExecutionStreamId.value()) {
+    return false;
+  }
+
+  auto comp_root_gpu_config =
+      called_comp->root_instruction()->backend_config<GpuBackendConfig>();
+
+  instr_gpu_config.set_operation_queue_id(
+      comp_root_gpu_config->operation_queue_id());
+  *instr_gpu_config.mutable_wait_on_operation_queues() =
+      comp_root_gpu_config->wait_on_operation_queues();
+  TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config));
+  return true;
+}
+
+absl::StatusOr<bool> AnnotateStreamAttributesForCopyStart(
+    HloInstruction* instr, int64_t channel_id,
+    GpuBackendConfig& instr_gpu_config) {
+  // Do nothing if the copy-start has already been annotated.
+  if (instr_gpu_config.operation_queue_id() !=
+      Thunk::kDefaultExecutionStreamId.value()) {
+    return false;
+  }
+  instr_gpu_config.set_operation_queue_id(channel_id);
+  TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config));
+  VLOG(3) << "Add copy-start's backend config: " << channel_id;
+  return true;
+}
+
+absl::StatusOr<bool> WrapIntoFusionAndAnnotateStreamAttributes(
+    HloInstruction* instruction, int64_t channel_id,
+    GpuBackendConfig& instr_gpu_config) {
+  auto* computation = instruction->parent();
+  auto* module = computation->parent();
+  auto* fusion_instruction =
+      computation->AddInstruction(HloInstruction::CreateFusion(
+          instruction->shape(), ChooseFusionKind(*instruction, *instruction),
+          instruction));
+  const absl::string_view wrapped_opcode =
+      HloOpcodeString(instruction->opcode());
+  module->SetAndUniquifyInstrName(fusion_instruction,
+                                  absl::StrCat("wrapped_", wrapped_opcode));
+  module->SetAndUniquifyComputationName(
+      fusion_instruction->fused_instructions_computation(),
+      absl::StrCat("wrapped_", wrapped_opcode, "_computation"));
+  if (module->has_schedule()) {
+    // Update the scheduling names of the fusion and its root instruction
+    // to match their newly assigned instruction names during creation.
+    fusion_instruction->set_metadata_scheduling_name(
+        fusion_instruction->name());
+    HloInstruction* root = fusion_instruction->fused_expression_root();
+    root->set_metadata_scheduling_name(root->name());
+    module->schedule().replace_instruction(computation, instruction,
+                                           fusion_instruction);
+  }
+  TF_RETURN_IF_ERROR(fusion_instruction->CopyAllControlDepsFrom(instruction));
+  TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
+  TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
+  TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
+
+  instr_gpu_config.set_operation_queue_id(channel_id);
+  TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(instr_gpu_config));
+  VLOG(3) << "Add async stream " << channel_id << " and wrapped instruction "
+          << instruction->ToString();
+  VLOG(3) << "  Fusion wrapper: " << fusion_instruction->ToString();
+  return true;
+}
+
+absl::StatusOr<bool> AnnotateStreamAttributesForUsers(
+    HloInstruction* instr, GpuBackendConfig& instr_gpu_config) {
+  bool changed = false;
+  int64_t stream_id = instr_gpu_config.operation_queue_id();
+  if (stream_id == Thunk::kDefaultExecutionStreamId.value()) {
+    return changed;
+  }
+  std::vector<HloInstruction*> all_consumers;
+  for (auto user : instr->users()) {
+    if (user->opcode() == HloOpcode::kGetTupleElement) {
+      user = user->users()[0];
+    }
+    all_consumers.push_back(user);
+  }
+
+  for (auto user : all_consumers) {
+    TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                        user->backend_config<GpuBackendConfig>());
+    auto it = absl::c_find(gpu_config.wait_on_operation_queues(), stream_id);
+    if (it == gpu_config.wait_on_operation_queues().end() &&
+        gpu_config.operation_queue_id() != stream_id) {
+      gpu_config.mutable_wait_on_operation_queues()->Add(stream_id);
+      TF_RETURN_IF_ERROR(user->set_backend_config(gpu_config));
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+}  // namespace
+
+absl::StatusOr<bool> StreamAttributeAnnotator::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  XLA_VLOG_LINES(
+      5, "StreamAttributeAnnotator::Run(), before:\n" + module->ToString());
+  bool changed = false;
+  int64_t channel_id = hlo_query::NextChannelId(*module);
+  for (const HloComputation* comp :
+       module->MakeComputationPostOrder(execution_threads)) {
+    for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
+      auto instr_gpu_config = instr->backend_config<GpuBackendConfig>();
+      if (!instr_gpu_config.ok()) {
+        continue;
+      }
+      // For fusion instructions, only annotate when the fusion root is a
+      // single instruction running on a non-default stream.
+      if (instr->opcode() == HloOpcode::kFusion) {
+        TF_ASSIGN_OR_RETURN(bool comp_result,
+                            AnnotateStreamAttributesForInstruction(
+                                instr, instr_gpu_config.value()));
+        changed |= comp_result;
+      } else if (instr->opcode() == HloOpcode::kCopyStart) {
+        TF_ASSIGN_OR_RETURN(bool comp_result,
+                            AnnotateStreamAttributesForCopyStart(
+                                instr, channel_id, instr_gpu_config.value()));
+        changed |= comp_result;
+        continue;
+      } else if (comp->IsAsyncComputation() &&
+                 (instr->opcode() == HloOpcode::kDynamicSlice ||
+                  instr->opcode() == HloOpcode::kDynamicUpdateSlice)) {
+        TF_ASSIGN_OR_RETURN(bool comp_result,
+                            WrapIntoFusionAndAnnotateStreamAttributes(
+                                instr, channel_id, instr_gpu_config.value()));
+        changed |= comp_result;
+        continue;
+      }
+
+      TF_ASSIGN_OR_RETURN(
+          bool user_result,
+          AnnotateStreamAttributesForUsers(instr, instr_gpu_config.value()));
+      changed |= user_result;
+    }
+  }
+  XLA_VLOG_LINES(
+      5, "StreamAttributeAnnotator::Run(), after:\n" + module->ToString());
+  return changed;
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h
new file mode 100644
index 0000000..81816f8
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h
@@ -0,0 +1,60 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ANNOTATOR_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ANNOTATOR_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// This pass checks whether:
+//  1. any instruction that consumes data from other compute streams is
+//     missing the "wait_on_operation_queues" attribute, or
+//  2. any fusion instruction has a root that runs on a non-default stream.
+// It annotates the corresponding instruction with the correct attribute in
+// GpuBackendConfig. Instructions annotated with operation_queue_id > 0 are
+// wrapped into an AsyncInstruction and split into AsyncStart and AsyncDone by
+// the StreamAttributeAsyncWrapper pass.
+// The pass also checks whether any user of a non-default-stream instruction
+// is missing the correct "wait_on_operation_queues" attribute and, if so,
+// sets it to the producer's operation_queue_id.
+// "wait_on_operation_queues" is used by the emitter to emit the correct
+// WaitForStreams thunk.
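+//
+// Schematically (mirroring the accompanying tests), given
+//
+//   add = f32[1] add(p0, p1), backend_config={"operation_queue_id":"1"}
+//   exp = f32[1] exponential(add)
+//
+// the pass annotates the user so that it waits on the producer's stream:
+//
+//   exp = f32[1] exponential(add),
+//       backend_config={"wait_on_operation_queues":[1]}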
+
+class StreamAttributeAnnotator : public HloModulePass {
+ public:
+  absl::string_view name() const override {
+    return "stream-attribute-annotator";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ANNOTATOR_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc
new file mode 100644
index 0000000..c7d2ca5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc
@@ -0,0 +1,340 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/stream_attribute_annotator.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+using StreamAttributeAnnotatorTest = HloTestBase;
+
+TEST_F(StreamAttributeAnnotatorTest, AllUsersAreAnnotated) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule ModuleWithAsync
+
+  ENTRY entry {
+    p1_32 = f32[1] parameter(0)
+    p2_32 = f32[1] parameter(1)
+    add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[]}
+    exp_32 = f32[1] exponential(add_32)
+
+    neg32 = f32[1] negate(add_32)
+    ROOT add_out_32 = f32[1] add(neg32, exp_32)
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  StreamAttributeAnnotator attr_annotator;
+  bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  const HloInstruction* add = FindInstruction(module.get(), "add_32");
+  for (auto user : add->users()) {
+    // Every user should have an annotation.
+    EXPECT_TRUE(user->has_backend_config());
+    TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                            user->backend_config<GpuBackendConfig>());
+    EXPECT_EQ(gpu_config.wait_on_operation_queues()[0], 1);
+  }
+}
+
+TEST_F(StreamAttributeAnnotatorTest, MultipleStreamsAreCombined) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule ModuleWithAsync
+
+  ENTRY entry {
+    p1_32 = f32[1] parameter(0)
+    p2_32 = f32[1] parameter(1)
+    add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[]}
+    exp_32 = f32[1] exponential(p2_32), backend_config={"operation_queue_id":"2", "wait_on_operation_queues":[]}
+
+    ROOT add_out_32 = f32[1] add(add_32, exp_32)
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  StreamAttributeAnnotator attr_annotator;
+  bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  // Root should wait on 2 streams.
+  EXPECT_TRUE(root->has_backend_config());
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          root->backend_config<GpuBackendConfig>());
+  std::vector<int64_t> expected_stream_ids = {1, 2};
+  for (auto id : expected_stream_ids) {
+    auto it = absl::c_find(gpu_config.wait_on_operation_queues(), id);
+    EXPECT_NE(it, gpu_config.wait_on_operation_queues().end());
+  }
+}
+
+TEST_F(StreamAttributeAnnotatorTest, GTEUserIsAnnotated) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule ModuleWithAsync
+
+  ENTRY entry {
+    p1_32 = f32[16,32] parameter(0)
+    p2_32 = f32[32,16] parameter(1)
+
+    custom-call.3 = (f32[16,16], s8[1028]{0}) custom-call(p1_32, p2_32), custom_call_target="__cublas$gemm", backend_config={"operation_queue_id":"1","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","grad_x":false,"grad_y":false}}
+    get-tuple-element.24 = f32[16,16] get-tuple-element(custom-call.3), index=0
+
+    exp_32 = f32[16,16] exponential(get-tuple-element.24)
+
+    ROOT neg32 = f32[16,16] negate(exp_32)
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  StreamAttributeAnnotator attr_annotator;
+  bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  const HloInstruction* exp = FindInstruction(module.get(), "exp_32");
+  EXPECT_TRUE(exp->has_backend_config());
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          exp->backend_config<GpuBackendConfig>());
+  EXPECT_EQ(gpu_config.wait_on_operation_queues()[0], 1);
+}
+
+TEST_F(StreamAttributeAnnotatorTest, FusionIsAnnotated) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule ModuleWithFusion
+
+  fused_computation.1 {
+    fusion_p0_32 = f32[16,16] parameter(0)
+    fusion_p2_32 = f32[16,16] parameter(1)
+    ROOT add = f32[16,16] add(fusion_p0_32, fusion_p2_32), backend_config={"operation_queue_id":"1","wait_on_operation_queues":[]}
+  }
+
+  ENTRY entry {
+    p1_32 = f32[16,16] parameter(0)
+    p2_32 = f32[16,16] parameter(1)
+    ROOT fusion.1 = f32[16,16] fusion(p1_32, p2_32), kind=kLoop, calls=fused_computation.1
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  StreamAttributeAnnotator attr_annotator;
+  bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  const HloInstruction* fusion = FindInstruction(module.get(), "fusion.1");
+  EXPECT_TRUE(fusion->has_backend_config());
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          fusion->backend_config<GpuBackendConfig>());
+  EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+}
+
+TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule offloading
+    ENTRY %main (param_0: f32[1024], param_1: f32[1024]) -> f32[1024] {
+    %param_1 = f32[1024]{0} parameter(1)
+    %param_0 = f32[1024]{0} parameter(0)
+    %res_3 = f32[1024]{0} add(f32[1024]{0} %param_0, f32[1024]{0} %param_1)
+    %copy-start = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_3)
+    %res_4 = f32[1024]{0} tanh(f32[1024]{0} %res_3)
+    %copy-start.2 = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_4)
+    %res_5 = f32[1024]{0} tanh(f32[1024]{0} %res_4)
+    %copy-done = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start)
+    %res_6 = f32[1024]{0} tanh(f32[1024]{0} %res_5)
+    %copy-done.2 = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start.2)
+    %copy-start.3 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done.2)
+    %res_7 = f32[1024]{0} add(f32[1024]{0} %res_6, f32[1024]{0} %res_6)
+    %copy-start.1 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done)
+    %res_8 = f32[1024]{0} add(f32[1024]{0} %res_7, f32[1024]{0} %res_5)
+    %copy-done.3 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.3)
+    %res_9 = f32[1024]{0} add(f32[1024]{0} %res_8, f32[1024]{0} %copy-done.3)
+    %copy-done.1 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.1)
+    %res_10 = f32[1024]{0} add(f32[1024]{0} %res_9, f32[1024]{0} %copy-done.1)
+    ROOT %res_11 = f32[1024]{0} tanh(f32[1024]{0} %res_10)
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  StreamAttributeAnnotator attr_annotator;
+  bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  for (std::string i : {"", ".1", ".2", ".3"}) {
+    const HloInstruction* cp_start =
+        FindInstruction(module.get(), "copy-start" + i);
+    EXPECT_TRUE(cp_start->has_backend_config());
+    TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                            cp_start->backend_config<GpuBackendConfig>());
+    EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+  }
+}
+
+TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule ModuleWithAsyncDynamicUpdateSlice, is_scheduled=true
+
+  ENTRY entry (param_0: f32[256,128,128], param_1: f32[1,128,128]) -> f32[256,128,128] {
+    param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0), metadata={scheduling_name="param_0"}
+    param_1 = f32[1,128,128]{2,1,0} parameter(1), metadata={scheduling_name="param_1"}
+    izero = s32[] constant(0), metadata={scheduling_name="izero"}
+    dynamic-update-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, f32[1,128,128]{2,1,0}, s32[], s32[], s32[]), f32[256,128,128]{2,1,0:S(5)}, u32[])
+        dynamic-update-slice-start(param_0, param_1, izero, izero, izero),
+        metadata={scheduling_name="dynamic-update-slice-start.2"}
+    ROOT dynamic-update-slice-done.2 = f32[256,128,128]{2,1,0:S(5)}
+        dynamic-update-slice-done(dynamic-update-slice-start.2),
+        metadata={scheduling_name="dynamic-update-slice-done.2"}
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+  EXPECT_TRUE(module->has_schedule());
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          StreamAttributeAnnotator().Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // Check that the dynamic-update-slice instruction is wrapped in a fusion
+  // and the fusion is annotated with the correct operation_queue_id.
+  const HloInstruction* dus =
+      FindInstruction(module.get(), HloOpcode::kDynamicUpdateSlice);
+  const HloComputation* computation = dus->parent();
+  EXPECT_TRUE(computation->IsFusionComputation());
+  const HloInstruction* fusion = computation->FusionInstruction();
+  EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);
+  EXPECT_TRUE(fusion->parent()->IsAsyncComputation());
+
+  EXPECT_TRUE(fusion->has_backend_config());
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          fusion->backend_config<GpuBackendConfig>());
+  EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+  // Check that the scheduling name is the same as the instruction name.
+  for (const auto* comp : module->computations()) {
+    for (const auto* instruction : comp->instructions()) {
+      if (!instruction->metadata().scheduling_name().empty()) {
+        EXPECT_EQ(instruction->name(),
+                  instruction->metadata().scheduling_name());
+      }
+    }
+  }
+  constexpr absl::string_view kExpectedSchedulingName = R"(
+// CHECK: %wrapped_dynamic-update-slice_computation
+// CHECK: ROOT %[[DYNAMIC_UPDATE_SLICE:.+]] = f32[256,128,128]{2,1,0:S(5)} dynamic-update-slice(
+// CHECK-SAME: metadata={scheduling_name="[[DYNAMIC_UPDATE_SLICE]]"}
+// CHECK: %[[DYNAMIC_UPDATE_SLICE_START:.+]] = {{.*}} fusion-start(
+// CHECK-SAME: calls=%wrapped_dynamic-update-slice_computation
+// CHECK-SAME: metadata={scheduling_name="[[DYNAMIC_UPDATE_SLICE_START]]"}
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool filecheck_matches,
+      RunFileCheck(
+          module->ToString(HloPrintOptions().set_print_operand_shape(false)),
+          kExpectedSchedulingName));
+  EXPECT_TRUE(filecheck_matches);
+}
+
+TEST_F(StreamAttributeAnnotatorTest, DynamicSliceWrappedAndAnnotated) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule ModuleWithAsyncDynamicSlice, is_scheduled=true
+
+  ENTRY entry (param_0: f32[256,128,128]) -> f32[1,128,128] {
+    param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0), metadata={scheduling_name="param_0"}
+    izero = s32[] constant(0), metadata={scheduling_name="izero"}
+    dynamic-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, s32[], s32[], s32[]), f32[1,128,128]{2,1,0}, u32[])
+        dynamic-slice-start(param_0, izero, izero, izero), dynamic_slice_sizes={1,128,128},
+        metadata={scheduling_name="dynamic-slice-start.2"}
+    ROOT dynamic-slice-done.2 = f32[1,128,128]{2,1,0}
+        dynamic-slice-done(dynamic-slice-start.2),
+        metadata={scheduling_name="dynamic-slice-done.2"}
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  EXPECT_TRUE(module->has_schedule());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          StreamAttributeAnnotator().Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // Check that the dynamic-slice instruction is wrapped in a fusion
+  // and the fusion is annotated with the correct operation_queue_id.
+  const HloInstruction* ds =
+      FindInstruction(module.get(), HloOpcode::kDynamicSlice);
+  const HloComputation* computation = ds->parent();
+  EXPECT_TRUE(computation->IsFusionComputation());
+  const HloInstruction* fusion = computation->FusionInstruction();
+  EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);
+  EXPECT_TRUE(fusion->parent()->IsAsyncComputation());
+
+  EXPECT_TRUE(fusion->has_backend_config());
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          fusion->backend_config<GpuBackendConfig>());
+  EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+  // Check that the scheduling name is the same as the instruction name.
+  for (const auto* comp : module->computations()) {
+    for (const auto* instruction : comp->instructions()) {
+      if (!instruction->metadata().scheduling_name().empty()) {
+        EXPECT_EQ(instruction->name(),
+                  instruction->metadata().scheduling_name());
+      }
+    }
+  }
+  constexpr absl::string_view kExpectedSchedulingName = R"(
+// CHECK: %wrapped_dynamic-slice_computation
+// CHECK: ROOT %[[DYNAMIC_SLICE:.+]] = f32[1,128,128]{2,1,0} dynamic-slice(
+// CHECK-SAME: metadata={scheduling_name="[[DYNAMIC_SLICE]]"}
+// CHECK: %[[DYNAMIC_SLICE_START:.+]] = {{.*}} fusion-start(
+// CHECK-SAME: calls=%wrapped_dynamic-slice_computation
+// CHECK-SAME: metadata={scheduling_name="[[DYNAMIC_SLICE_START]]"}
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool filecheck_matches,
+      RunFileCheck(
+          module->ToString(HloPrintOptions().set_print_operand_shape(false)),
+          kExpectedSchedulingName));
+  EXPECT_TRUE(filecheck_matches);
+}
+}  // namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc
new file mode 100644
index 0000000..be0eb6f
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc
@@ -0,0 +1,74 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h"
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/runtime/thunk.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+
+namespace {
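+// Illustrative sketch of the rewrite performed by AsynchronizeInstruction
+// below, assuming an `add` that StreamAttributeAnnotator tagged with
+// operation_queue_id=1 (instruction names and shapes are hypothetical):
+//
+//   add.0 = f32[1] add(p0, p1), backend_config={"operation_queue_id":"1",...}
+//
+// becomes, roughly,
+//
+//   add-start = (...) add-start(p0, p1), async_execution_thread="parallel",
+//       backend_config={"operation_queue_id":"1",...}
+//   add-done  = f32[1] add-done(add-start)
+//
+// with force_earliest_schedule cleared on the done op.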
+static absl::StatusOr<bool> AsynchronizeInstruction(HloInstruction* instr) {
+  auto instr_gpu_config = instr->backend_config<GpuBackendConfig>();
+  if (!instr_gpu_config.ok() || instr_gpu_config->operation_queue_id() ==
+                                    Thunk::kDefaultExecutionStreamId.value()) {
+    return false;
+  }
+  HloComputation* computation = instr->parent();
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * done,
+      computation->CreateAsyncInstructions(
+          instr, {}, StreamAttributeAsyncWrapper::kParallelExecutionThread,
+          /*replace=*/true));
+  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                      done->backend_config<GpuBackendConfig>());
+  // Set force_earliest_schedule to false for the done op so it can be
+  // scheduled far apart from the start op.
+  gpu_config.set_force_earliest_schedule(false);
+  TF_RETURN_IF_ERROR(done->set_backend_config(gpu_config));
+  VLOG(5) << "Created async instruction: " << done->ToString();
+  return true;
+}
+}  // namespace
+
+absl::StatusOr<bool> StreamAttributeAsyncWrapper::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  XLA_VLOG_LINES(
+      2, "StreamAttributeAsyncWrapper::Run(), before:\n" + module->ToString());
+  bool changed = false;
+  for (const HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    for (HloInstruction* instr : comp->instructions()) {
+      TF_ASSIGN_OR_RETURN(bool result, AsynchronizeInstruction(instr));
+      changed |= result;
+    }
+  }
+  XLA_VLOG_LINES(
+      2, "StreamAttributeAsyncWrapper::Run(), after:\n" + module->ToString());
+  return changed;
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.h b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.h
new file mode 100644
index 0000000..157b579
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.h
@@ -0,0 +1,49 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// This pass finds instructions that were annotated with a non-default stream
+// id in their backend configs by the StreamAttributeAnnotator pass and wraps
+// them in async start/done pairs to achieve asynchronous execution.
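+//
+// A minimal usage sketch (illustrative; the pipeline name is arbitrary):
+//
+//   HloPassPipeline pipeline("stream-attributes");
+//   pipeline.AddPass<StreamAttributeAnnotator>();
+//   pipeline.AddPass<StreamAttributeAsyncWrapper>();
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));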
+class StreamAttributeAsyncWrapper : public HloModulePass {
+ public:
+  inline static constexpr char kParallelExecutionThread[] = "parallel";
+
+  absl::string_view name() const override {
+    return "async-stream-attribute-wrapper";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper_test.cc
new file mode 100644
index 0000000..32ed4c5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper_test.cc
@@ -0,0 +1,77 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+using StreamAttributeAsyncWrapperTest = HloTestBase;
+
+TEST_F(StreamAttributeAsyncWrapperTest, NonDefaultOpIsWrapped) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule ModuleWithAsync
+
+  ENTRY entry {
+    p1_32 = f32[1] parameter(0)
+    p2_32 = f32[1] parameter(1)
+    add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[], "force_earliest_schedule":true}
+    ROOT exp_32 = f32[1] exponential(add_32), backend_config={"operation_queue_id":"0", "wait_on_operation_queues":[1]}
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  StreamAttributeAsyncWrapper async_wrapper;
+  bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, async_wrapper.Run(module.get()));
+  EXPECT_TRUE(changed);
+  const HloInstruction* producer =
+      module->entry_computation()->root_instruction()->operand(0);
+  EXPECT_EQ(producer->opcode(), HloOpcode::kAsyncDone);
+  // Verify that the force_earliest_schedule is set to false for the done op.
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig done_gpu_config,
+                          producer->backend_config<GpuBackendConfig>());
+  EXPECT_EQ(done_gpu_config.force_earliest_schedule(), false);
+
+  const HloInstruction* producer_start = producer->operand(0);
+  EXPECT_EQ(producer_start->opcode(), HloOpcode::kAsyncStart);
+
+  const xla::HloAsyncInstruction* async =
+      Cast<HloAsyncInstruction>(producer_start);
+  EXPECT_EQ(async->async_wrapped_opcode(), HloOpcode::kAdd);
+  // Verify that the backend config is kept intact.
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          async->backend_config<GpuBackendConfig>());
+  EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+  EXPECT_EQ(gpu_config.force_earliest_schedule(), true);
+  EXPECT_EQ(async->async_execution_thread(), "parallel");
+}
+}  // namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc b/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc
new file mode 100644
index 0000000..1cc6206
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc
@@ -0,0 +1,113 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/topk_specializer.h"
+
+#include <stddef.h>
+
+#include <initializer_list>
+#include <string>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/primitive_util.h"
+#include "xla/service/hlo.pb.h"
+#include "xla/service/tuple_util.h"
+#include "xla/shape.h"
+#include "xla/status_macros.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+absl::StatusOr<HloInstruction*> SmallBufferOptimization(
+    HloCustomCallInstruction* topk) {
+  Shape data_shape = topk->operand(0)->shape();
+  auto supported_dtypes = {F32, BF16};
+  if (!absl::c_linear_search(supported_dtypes, data_shape.element_type())) {
+    return InvalidArgument(
+        "Invalid Dtype: %s",
+        primitive_util::LowercasePrimitiveTypeName(data_shape.element_type()));
+  }
+  // We only support topk of the shape [x] or [batch, x].
+  if (data_shape.dimensions_size() > 2) {
+    return InvalidArgument("Invalid input dimensions: %s",
+                           data_shape.ToString());
+  }
+  bool has_batch = data_shape.dimensions_size() == 2;
+  constexpr size_t max_k = 16;
+  constexpr size_t min_n = 1024;
+  size_t n = data_shape.dimensions(has_batch ? 1 : 0);
+  size_t k = topk->shape().tuple_shapes(0).dimensions(has_batch ? 1 : 0);
+  if (k > max_k) {
+    return InvalidArgument("k too large (%d), must be <= %d", k, max_k);
+  }
+  if (n < min_n) {
+    return InvalidArgument("Input too small (n=%d, min_n=%d)", n, min_n);
+  }
+  HloComputation* comp = topk->parent();
+  HloInstruction* new_topk =
+      comp->AddInstruction(HloInstruction::CreateCustomCall(
+          topk->shape(), topk->operands(),
+          // We don't need the original to_apply, but keeping it around allows
+          // us to round-trip this CustomCall on tests.
+          topk->to_apply(), "__gpu$TopK",
+          /*opaque=*/"", CustomCallApiVersion::API_VERSION_TYPED_FFI));
+  return TupleUtil::ExtractPrefix(new_topk, 2);
+}
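+
+// Illustrative examples of the constraints above (shapes are hypothetical):
+//   f32[4,16384]  topk, k=8  -> eligible (rank <= 2, k <= 16, n >= 1024).
+//   f32[512]      topk, k=4  -> rejected (n < min_n).
+//   f32[2,2,4096] topk, k=8  -> rejected (more than two dimensions).
+//   s32[4096]     topk, k=8  -> rejected (dtype not in {F32, BF16}).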
+
+class SpecializeTopkVisitor : public DfsHloRewriteVisitor {
+ public:
+  absl::Status HandleCustomCall(HloInstruction* inst) override {
+    HloCustomCallInstruction* topk = DynCast<HloCustomCallInstruction>(inst);
+    if (topk == nullptr || topk->custom_call_target() != "TopK") {
+      return absl::OkStatus();
+    }
+    TF_RET_CHECK(topk->operand_count() == 1);
+
+    if (auto small_topk = SmallBufferOptimization(topk); small_topk.ok()) {
+      return ReplaceInstruction(topk, *small_topk);
+    } else {
+      VLOG(2) << "Small TopK optimization doesn't match: "
+              << small_topk.status();
+    }
+
+    return absl::OkStatus();
+  }
+};
+
+}  // namespace
+
+absl::StatusOr<bool> TopkSpecializer::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  return SpecializeTopkVisitor().RunOnModule(module, execution_threads);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/topk_specializer.h b/third_party/xla/xla/service/gpu/transforms/topk_specializer.h
new file mode 100644
index 0000000..e3ec565
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer.h
@@ -0,0 +1,41 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPECIALIZER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPECIALIZER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// This pass transforms an eligible TopK CustomCall into a call that is
+// executed by runtime/topk.cc.
+class TopkSpecializer : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "topk-specializer"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPECIALIZER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc b/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc
new file mode 100644
index 0000000..96d7e49
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/topk_specializer.h"
+
+#include <stddef.h>
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <tuple>
+#include <utility>
+
+#include <gtest/gtest.h>
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/platform_util.h"
+#include "xla/service/topk_rewriter.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+using ::testing::Combine;
+using ::testing::Values;
+
+// Params:
+//  - n_kb: number of input elements, in units of 1024.
+//  - k: number of elements to return.
+//  - batch_size
+//  - dtype
+using ParameterizedInterface =
+    ::testing::WithParamInterface<std::tuple<int, int, int, std::string_view>>;
+
+class TopkTest : public HloTestBase, public ParameterizedInterface {
+ public:
+  TopkTest()
+      : HloTestBase(*PlatformUtil::GetPlatform("gpu"),
+                    *PlatformUtil::GetPlatform("gpu"), true, true, {}) {}
+
+ protected:
+  absl::StatusOr<std::unique_ptr<HloModule>> TopkHlo(int n, int k,
+                                                     int batch_size,
+                                                     std::string_view dtype) {
+    return ParseAndReturnVerifiedModule(absl::Substitute(
+        R"(
+      %compare {
+        %p.1.lhs.40628 = s32[] parameter(2)
+        %p.1.rhs.40629 = s32[] parameter(3)
+        %constant.40630 = pred[] constant(true)
+        %broadcast.40631 = pred[] broadcast(pred[] %constant.40630), dimensions={}
+        %p.0.lhs.40626 = f32[] parameter(0)
+        %p.0.rhs.40627 = f32[] parameter(1)
+        %compare.40632 = pred[] compare(f32[] %p.0.lhs.40626, f32[] %p.0.rhs.40627), direction=GT, type=TOTALORDER
+        ROOT %select.40633 = pred[] select(pred[] %broadcast.40631, pred[] %compare.40632, pred[] %broadcast.40631)
+      }
+
+      ENTRY top_k {
+        %arg = $3[$2,$0] parameter(0)
+        ROOT %result = ($3[$2,$1], s32[$2,$1]) custom-call(%arg), custom_call_target="TopK", to_apply=%compare
+      }
+    )",
+        n, k, batch_size, dtype));
+  }
+};
+
+class GeneralizeTopkVisitor : public DfsHloRewriteVisitor {
+ public:
+  absl::Status HandleCustomCall(HloInstruction* inst) override {
+    HloCustomCallInstruction* topk = DynCast<HloCustomCallInstruction>(inst);
+    if (topk == nullptr || topk->custom_call_target() != "__gpu$TopK") {
+      return absl::OkStatus();
+    }
+    HloComputation* comp = topk->parent();
+    auto original_shape = ShapeUtil::SliceTuple(topk->shape(), 0, 2);
+    HloInstruction* original_topk =
+        comp->AddInstruction(HloInstruction::CreateCustomCall(
+            original_shape, topk->operands(), topk->to_apply(), "TopK"));
+    // TupleUtil::ExtractPrefix creates the following structure:
+    //      TopK
+    //   -------------
+    //   |     |     |
+    //  Get   Get   Get
+    //    \    |     /
+    //     CreateTuple
+    // Here we walk to Create Tuple and replace it with the original topk.
+    HloInstruction* new_tuple = topk->users()[0]->users()[0];
+    return ReplaceInstruction(new_tuple, original_topk);
+  }
+};
+
+class GeneralizeTopk : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "generalized-topk"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(HloModule* module,
+                           const absl::flat_hash_set<absl::string_view>&
+                               execution_threads) override {
+    return GeneralizeTopkVisitor().RunOnModule(module, execution_threads);
+  }
+};
+
+void ToSortAndSlice(HloModule* module) {
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, GeneralizeTopk().Run(module));
+  ASSERT_TRUE(changed);
+  TF_ASSERT_OK_AND_ASSIGN(changed, TopkDecomposer().Run(module));
+  ASSERT_TRUE(changed);
+}
+
+TEST_P(TopkTest, ProducesCorrectResult) {
+  const auto [n_kb, k, batch_size, dtype] = GetParam();
+  const size_t n = n_kb * 1024;
+  TF_ASSERT_OK_AND_ASSIGN(auto topk_module, TopkHlo(n, k, batch_size, dtype));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          gpu::TopkSpecializer().Run(topk_module.get()));
+  ASSERT_TRUE(changed);
+  EXPECT_TRUE(
+      RunAndCompare(std::move(topk_module), std::nullopt, ToSortAndSlice));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    TopkTests, TopkTest,
+    Combine(
+        /*n_kb=*/Values(1, 8, 12, 32),
+        /*k=*/Values(1, 2, 4, 8, 16, 7, 12),
+        /*batch_size=*/Values(1, 16, 32, 64, 128),
+        /*dtype=*/Values(absl::string_view("f32"), "bf16")),
+    [](const auto& info) {
+      return absl::Substitute("n$0KiB_k$1_batch_size$2_$3",
+                              std::get<0>(info.param), std::get<1>(info.param),
+                              std::get<2>(info.param), std::get<3>(info.param));
+    });
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc b/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc
new file mode 100644
index 0000000..41ba135
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc
@@ -0,0 +1,154 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/topk_splitter.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
+#include "absl/numeric/bits.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+constexpr size_t kRequiredAlignment = 1024;
+constexpr size_t kMaximumBatchSize = 1024;
+
+class TopkSplitterVisitor : public DfsHloRewriteVisitor {
+ public:
+  explicit TopkSplitterVisitor(size_t split_threshold)
+      : split_threshold_(split_threshold) {}
+
+  absl::Status HandleCustomCall(HloInstruction* inst) override {
+    HloCustomCallInstruction* topk = DynCast<HloCustomCallInstruction>(inst);
+    if (topk == nullptr || topk->custom_call_target() != "TopK") {
+      return absl::OkStatus();
+    }
+    HloComputation* comp = inst->parent();
+    Shape data_shape = topk->operand(0)->shape();
+    bool has_batch = data_shape.dimensions_size() == 2;
+    // TODO(doak): Support multiple batches.
+    if (has_batch && data_shape.dimensions(0) != 1) {
+      return absl::OkStatus();
+    }
+    size_t n = data_shape.dimensions(has_batch ? 1 : 0);
+    int64_t k = topk->shape().tuple_shapes(0).dimensions(has_batch ? 1 : 0);
+    // If K approaches N, splitting the input will not be beneficial anymore.
+    if (k > sqrt(n)) {
+      return absl::OkStatus();
+    }
+    // TODO(doak): Relax this alignment requirement.
+    if (n % kRequiredAlignment != 0) {
+      return absl::OkStatus();
+    }
+    if (n < split_threshold_) return absl::OkStatus();
+    int new_batch =
+        std::min(absl::bit_floor(n / split_threshold_), kMaximumBatchSize);
+    int new_n = n / new_batch;
+    // Split the input into B batches and compute TopK over the batched arrays.
+    Shape split_input_shape =
+        ShapeUtil::MakeShape(data_shape.element_type(), {new_batch, new_n});
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * reshaped,
+        MakeReshapeHlo(split_input_shape, topk->mutable_operand(0)));
+    Shape batch_topk_shape = ShapeUtil::MakeTupleShape(
+        {ShapeUtil::MakeShape(data_shape.element_type(), {new_batch, k}),
+         ShapeUtil::MakeShape(S32, {new_batch, k})});
+    HloInstruction* batch_topk =
+        comp->AddInstruction(HloInstruction::CreateCustomCall(
+            batch_topk_shape, {reshaped}, topk->to_apply(), "TopK",
+            /*opaque=*/""));
+    // Fix indices by adding j * new_n to the j-th batch of indices.
+    TF_ASSIGN_OR_RETURN(HloInstruction * indices,
+                        MakeGetTupleElementHlo(batch_topk, 1));
+    TF_ASSIGN_OR_RETURN(HloInstruction * values,
+                        MakeGetTupleElementHlo(batch_topk, 0));
+    Shape iota_shape = ShapeUtil::MakeShape(S32, {new_batch});
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * fix,
+        MakeBinaryHlo(
+            HloOpcode::kMultiply, MakeIotaHlo(comp, iota_shape, 0),
+            MakeBroadcastHlo(MakeR0ConstantHlo<int32_t>(comp, new_n),
+                             /*broadcast_dimensions=*/{}, iota_shape)));
+    TF_ASSIGN_OR_RETURN(
+        indices, MakeBinaryHlo(HloOpcode::kAdd, indices,
+                               MakeBroadcastHlo(fix, {0}, indices->shape())));
+    // With the indices restored, compute a final top-k. Since this topk uses
+    // arbitrary indices, we need to use sort+slice.
+    Shape linear_index_shape = ShapeUtil::MakeShape(S32, {k * new_batch});
+    Shape linear_shape = ShapeUtil::ChangeElementType(
+        linear_index_shape, data_shape.element_type());
+    Shape linear_sort_shape =
+        ShapeUtil::MakeTupleShape({linear_shape, linear_index_shape});
+    // Assuming the outputs of the TopK above are stably sorted, using a stable
+    // sort here is enough to guarantee global stable sorting:
+    //  - Within a block, elements are stably sorted by TopK.
+    //  - Since blocks are organized linearly from smallest to largest, the
+    //    index used on the stable sort below will also respect block ordering.
+    HloInstruction* aggregated_sort =
+        comp->AddInstruction(HloInstruction::CreateSort(
+            linear_sort_shape, 0,
+            {*MakeReshapeHlo(linear_shape, values),
+             *MakeReshapeHlo(linear_index_shape, indices)},
+            topk->to_apply(), /*is_stable=*/true));
+    auto slice_tuple = [&](HloInstruction* sort, const size_t index) {
+      return *MakeReshapeHlo(
+          topk->shape().tuple_shapes(index),
+          *MakeSliceHlo(*MakeGetTupleElementHlo(sort, index), {0}, {k}, {1}));
+    };
+    return ReplaceInstruction(topk,
+                              comp->AddInstruction(HloInstruction::CreateTuple({
+                                  slice_tuple(aggregated_sort, 0),
+                                  slice_tuple(aggregated_sort, 1),
+                              })));
+  }
+
+ private:
+  size_t split_threshold_;
+};
+
+}  // namespace
+
+absl::StatusOr<bool> TopKSplitter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  return TopkSplitterVisitor(split_threshold_)
+      .RunOnModule(module, execution_threads);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/topk_splitter.h b/third_party/xla/xla/service/gpu/transforms/topk_splitter.h
new file mode 100644
index 0000000..c6fe429
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter.h
@@ -0,0 +1,52 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPLITTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPLITTER_H_
+
+#include <cstddef>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Splits large TopK operations into batches of smaller TopKs, followed by
+// sorting and slicing the results of those smaller TopKs. We consider a TopK
+// to be 'large' when the last dimension of its input is larger than
+// `split_threshold`.
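+//
+// As an illustrative example with the default split_threshold: a
+// (f32[5], s32[5]) = TopK(f32[1073741824]) is rewritten to reshape the input
+// to f32[1024,1048576], run a batched TopK producing (f32[1024,5],
+// s32[1024,5]), add j*1048576 to the indices of batch j, stably sort the
+// resulting 5120 candidates, and slice out the final top-5 values and indices.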
+class TopKSplitter : public HloModulePass {
+ public:
+  explicit TopKSplitter(size_t split_threshold = 1024 * 1024)
+      : split_threshold_(split_threshold) {}
+  absl::string_view name() const override { return "topk-splitter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  const size_t split_threshold_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPLITTER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc
new file mode 100644
index 0000000..8236c26
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc
@@ -0,0 +1,210 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/topk_splitter.h"
+
+#include <stdint.h>
+
+#include <cstddef>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_dce.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/topk_rewriter.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/verified_hlo_module.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace m = ::xla::match;
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using ::tsl::testing::IsOkAndHolds;
+using TopkSplitterTest = HloTestBase;
+
+constexpr absl::string_view kComparator = R"(
+  %compare {
+    %p.1.lhs.40628 = s32[] parameter(2)
+    %p.1.rhs.40629 = s32[] parameter(3)
+    %constant.40630 = pred[] constant(true)
+    %broadcast.40631 = pred[] broadcast(pred[] %constant.40630), dimensions={}
+    %p.0.lhs.40626 = f32[] parameter(0)
+    %p.0.rhs.40627 = f32[] parameter(1)
+    %compare.40632 = pred[] compare(f32[] %p.0.lhs.40626, f32[] %p.0.rhs.40627), direction=GT, type=TOTALORDER
+    ROOT %select.40633 = pred[] select(pred[] %broadcast.40631, pred[] %compare.40632, pred[] %broadcast.40631)
+  })";
+
+TEST_F(TopkSplitterTest, SplitsTopK) {
+  const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+  %arg.1 = f32[1,1073741824] parameter(0)
+  ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
+})",
+                                                  kComparator);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  EXPECT_THAT(RunHloPass(TopKSplitter(), module.get()), IsOkAndHolds(true));
+  auto first_topk = m::CustomCall(m::Reshape(m::Parameter(0)));
+  auto slice_result = [&](auto input, size_t i) {
+    return m::Reshape(m::Slice(m::GetTupleElement(input, i)));
+  };
+  auto index_correction =
+      m::Broadcast(m::Multiply(m::Iota(), m::Broadcast(m::Constant())));
+  auto sorted = m::Sort(
+      m::Reshape(m::GetTupleElement(first_topk, 0)),
+      m::Reshape(m::Add(m::GetTupleElement(first_topk, 1), index_correction)));
+  EXPECT_TRUE(
+      Match(module->entry_computation()->root_instruction(),
+            m::Tuple(slice_result(sorted, 0), slice_result(sorted, 1))));
+}
+
+TEST_F(TopkSplitterTest, SplitsTopKNoBatchDimension) {
+  const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+  %arg.1 = f32[1073741824] parameter(0)
+  ROOT %cc.2 = (f32[5], s32[5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
+})",
+                                                  kComparator);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  EXPECT_THAT(RunHloPass(TopKSplitter(), module.get()), IsOkAndHolds(true));
+  auto first_topk = m::CustomCall(m::Reshape(m::Parameter(0)));
+  auto slice_result = [&](auto input, size_t i) {
+    return m::Reshape(m::Slice(m::GetTupleElement(input, i)));
+  };
+  auto index_correction =
+      m::Broadcast(m::Multiply(m::Iota(), m::Broadcast(m::Constant())));
+  auto sorted = m::Sort(
+      m::Reshape(m::GetTupleElement(first_topk, 0)),
+      m::Reshape(m::Add(m::GetTupleElement(first_topk, 1), index_correction)));
+  EXPECT_TRUE(
+      Match(module->entry_computation()->root_instruction(),
+            m::Tuple(slice_result(sorted, 0), slice_result(sorted, 1))));
+}
+
+TEST_F(TopkSplitterTest, SplitFailsUnderThreshold) {
+  const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+  %arg.1 = f32[1,524288] parameter(0)
+  ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
+})",
+                                                  kComparator);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  EXPECT_THAT(
+      RunHloPass(TopKSplitter(/*split_threshold=*/1048576), module.get()),
+      IsOkAndHolds(false));
+}
+
+TEST_F(TopkSplitterTest, SplitFailsUnaligned) {
+  const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+  %arg.1 = f32[1,524289] parameter(0)
+  ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
+})",
+                                                  kComparator);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  EXPECT_THAT(RunHloPass(TopKSplitter(/*split_threshold=*/1024), module.get()),
+              IsOkAndHolds(false));
+}
+
+TEST_F(TopkSplitterTest, SplitFailsLargeK) {
+  const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+  %arg.1 = f32[1,524288] parameter(0)
+  ROOT %cc.2 = (f32[1,1024], s32[1,1024]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
+})",
+                                                  kComparator);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  EXPECT_THAT(RunHloPass(TopKSplitter(/*split_threshold=*/1024), module.get()),
+              IsOkAndHolds(false));
+}
+
+TEST_F(TopkSplitterTest, Equivalent) {
+  const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+  %arg.1 = f32[1,16384] parameter(0)
+  ROOT %cc.2 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare
+})",
+                                                  kComparator);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  EXPECT_THAT(TopkDecomposer().Run(module.get()), IsOkAndHolds(true));
+  auto round_trip = [](HloModule* module) {
+    EXPECT_THAT(TopkRewriter([](const HloSortInstruction*, int64_t) {
+                  return true;
+                }).Run(module),
+                IsOkAndHolds(true));
+    EXPECT_THAT(TopKSplitter(1024).Run(module), IsOkAndHolds(true));
+    EXPECT_THAT(TopkDecomposer().Run(module), IsOkAndHolds(true));
+    EXPECT_TRUE(HloDCE().Run(module).status().ok());
+  };
+  EXPECT_TRUE(RunAndCompare(std::move(module), std::nullopt, round_trip));
+}
+
+TEST_F(TopkSplitterTest, StableSorts) {
+  const std::string hlo_string = absl::Substitute(R"(
+HloModule module
+$0
+ENTRY cluster {
+  %constant.1 = f32[] constant(42)
+  %broadcast.2= f32[1,16384] broadcast(f32[] %constant.1), dimensions={}
+  ROOT %cc.3 = (f32[1,5], s32[1,5]) custom-call(%broadcast.2), custom_call_target= "TopK", to_apply=%compare
+})",
+                                                  kComparator);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  EXPECT_THAT(TopkDecomposer().Run(module.get()), IsOkAndHolds(true));
+  auto round_trip = [](HloModule* module) {
+    EXPECT_THAT(TopkRewriter([](const HloSortInstruction*, int64_t) {
+                  return true;
+                }).Run(module),
+                IsOkAndHolds(true));
+    EXPECT_THAT(TopKSplitter(1024).Run(module), IsOkAndHolds(true));
+    EXPECT_THAT(TopkDecomposer().Run(module), IsOkAndHolds(true));
+    EXPECT_TRUE(HloDCE().Run(module).status().ok());
+  };
+  EXPECT_TRUE(RunAndCompare(std::move(module), std::nullopt, round_trip));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc
new file mode 100644
index 0000000..fb023fc
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc
@@ -0,0 +1,389 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/transforms/tree_reduction_rewriter.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/numeric/bits.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/gpu/reduction_utils.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+absl::InlinedVector<int64_t, 2> GetSortedReducedDims(
+    HloReduceInstruction *reduce) {
+  absl::InlinedVector<int64_t, 2> reduced_dims{reduce->dimensions().begin(),
+                                               reduce->dimensions().end()};
+  absl::c_sort(reduced_dims);
+  return reduced_dims;
+}
+
+bool IsMinMaxReduction(HloReduceInstruction *reduce) {
+  HloComputation *called = &reduce->to_apply()[0];
+  if (auto reduction_kind = MatchReductionComputation(called)) {
+    return reduction_kind == ReductionKind::MAX ||
+           reduction_kind == ReductionKind::MIN;
+  }
+  return false;
+}
+
+}  // namespace
+
+class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
+ public:
+  explicit ReductionRewriterVisitor(se::GpuComputeCapability gpu_version)
+      : gpu_version_(gpu_version) {}
+
+  absl::Status HandleReduce(HloInstruction *hlo) override {
+    auto *reduce = Cast<HloReduceInstruction>(hlo);
+    VLOG(3) << "Reduction instruction: " << reduce->ToString();
+
+    const HloModuleConfig &config = reduce->GetModule()->config();
+    if (!MatchReductionForSplit(reduce, config)) {
+      return absl::OkStatus();
+    }
+    ReductionDimensions reduction_dims =
+        GetReductionKindAndContiguousComponents(*hlo);
+    if (ReductionIsRaceFree(config, reduction_dims)) {
+      VLOG(3) << "Base case: dimensions fit";
+      return absl::OkStatus();
+    }
+    auto sorted_dims_to_reduce = GetSortedReducedDims(reduce);
+    CHECK_LE(sorted_dims_to_reduce.size(), 2);
+
+    // If the major reduced dimension does not fit, reduce the minor dimension
+    // first, then the major.
+    if (reduction_dims.is_row_reduction &&
+        reduction_dims
+                .dimensions[ReductionDimensions::kRowMajorReducedDimension] >
+            BatchedReductionRaceFreeBound()) {
+      VLOG(2) << "Splitting batched dimension reduce into a separate reduction";
+      return RewriteBatchDimensionLargerThanTile(reduce, reduction_dims,
+                                                 sorted_dims_to_reduce);
+    }
+    SplitParams split_params =
+        ComputeSplitParams(reduce, reduction_dims, sorted_dims_to_reduce);
+    return SplitReductionDimension(reduce, split_params, sorted_dims_to_reduce);
+  }
+
+ private:
+  bool MatchReductionForSplit(HloReduceInstruction *reduce,
+                              const HloModuleConfig &config) {
+    // MLIR emitters only support race-free reductions.
+    // TODO(jreiffers): Verify performance and implement atomics for reductions
+    // if needed.
+    bool reductions_via_mlir_disabled =
+        config.debug_options().xla_gpu_mlir_emitter_level() < 4;
+    if (reductions_via_mlir_disabled && IsMinMaxReduction(reduce)) {
+      // TODO(cheshire): Also enable for integers.
+      VLOG(1) << "Not performing tree expansion on min/max-reduction: "
+              << reduce->ToString()
+              << " since min/max operations are associative";
+      return false;
+    }
+    if (!IsReductionFromOrToContiguousDimensions(*reduce)) {
+      VLOG(3) << "Is not a reduction from or to contiguous dimensions";
+      return false;
+    }
+    VLOG(3) << "Perform rewrite";
+    return true;
+  }
+
+  // We observe that a larger n_div_k can improve tree reduction performance in
+  // most cases by reducing memory stores and the launch overhead of blocks.
+  // Swap k and n_div_k if possible.
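+  //
+  // For example (illustrative numbers): with k1=64, k2=32, an even n and a
+  // race-free bound of at least 64, a row reduction returns true here because
+  // k1 is even, so the caller swaps the pair and the inner reduction processes
+  // 64 elements while the outer reduction processes 32.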
+  bool ShouldSwapInnerAndOuterReducedMinorDimension(uint64_t k1, uint64_t k2,
+                                                    uint64_t n,
+                                                    int64_t race_free_bound,
+                                                    bool is_row_reduction) {
+    CHECK_GE(k1, k2);
+    // Keep inner reduction as race free.
+    if (k1 > race_free_bound) {
+      return false;
+    }
+    // Swapping only affects row reduction vectorization.
+    if (is_row_reduction) {
+      // Rough conditions for row reduction vectorization; this does not mean
+      // that vectorization will definitely occur.
+      bool maybe_vectorized = k2 % 2 == 0 && n % 2 == 0;
+      if (maybe_vectorized) {
+        // Swap if n_div_k is small enough or the k dimension can also be
+        // vectorized.
+        return k2 * 2 < k1 || k1 % 2 == 0;
+      }
+      // The current reduction emitter only checks reduction input dimensions,
+      // not fusion input dimensions. Because the pad and the inner reduction
+      // always fuse into the same computation, each thread may end up reading
+      // multiple unaligned elements that cannot be vectorized, which results
+      // in bad performance. Don't swap if we encounter this situation.
+      return n % 2 == 0 || k1 % 2 != 0;
+    }
+    // We are not aware of any situation where swapping does not improve
+    // column reduction performance, so always swap.
+    return true;
+  }
+
+  // Parameters describing how to split a dimension `dim` with `k` elements
+  // into `k1` x `k2`.
+  struct SplitParams {
+    int64_t k1;
+    int64_t k2;
+    int64_t dim;
+  };
+
+  // Attempts to find the best way to split a dimension `dim` with `k` elements
+  // into `k1` x `k2`.
+  SplitParams ComputeSplitParams(
+      HloReduceInstruction *reduce, const ReductionDimensions &reduction_dims,
+      absl::Span<const int64_t> sorted_dims_to_reduce) {
+    absl::Span<int64_t const> input_shape_dims =
+        reduce->inputs()[0]->shape().dimensions();
+
+    int64_t reduced_dim = sorted_dims_to_reduce.back();
+    int64_t reduced_dim_size = input_shape_dims[reduced_dim];
+    VLOG(3) << "reduced dim size = " << reduced_dim_size;
+
+    // We will do this reduction in two stages.  The first will reduce from k
+    // elements to k1 elements in the reduction dimension.  The second will
+    // reduce further, from k1 to 1 element.
+    //
+    // We do this by splitting the input shape [a, k, b] into [a, k1, k2, b].
+    //
+    // We want to choose k1 to be roughly equal to sqrt(k) so that we process
+    // "most of" the reduction in the first step. But it is also important that
+    // we choose a value of k1 with the least amount of padding we need to add
+    // to k to make it divisible by k1. We search for the best value of k2
+    // between sqrt(k)/2 and sqrt(k). If there are several possible values for
+    // k2 that result in the minimum amount of padding, we also want k2 to
+    // be a power of 2, so that the GPU kernel doesn't spend all its time doing
+    // slow integer divmods to compute indices into the shape [a,k1,k2,b].
+    // Note that by searching in the range between sqrt(k)/2 and sqrt(k), we
+    // will have a power of 2 in that range.
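+    //
+    // Worked example (illustrative, assuming the race-free bound is at least
+    // 63): for k = 4000, sqrt(k) is about 63. Candidate k2 values between 32
+    // and 63 include 50, 40, and 32, all of which divide 4000 exactly (zero
+    // padding); among those, 32 is a power of two, so we end up with k2 = 32
+    // and k1 = 125 before the possible swap below.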
+    uint64_t k2 =
+        static_cast<uint64_t>(std::floor(std::sqrt(reduced_dim_size)));
+    int64_t race_free_bound = ReductionDimensionRaceFreeBound(
+        reduce->GetModule()->config(), reduction_dims);
+    if (k2 > race_free_bound) {
+      // This means we need more than one split. It is best to limit the n/k
+      // dimension to the maximum size that doesn't require further splitting.
+      // Otherwise we might choose a rather small reduce dimension size for the
+      // first step (in the worst case, sqrt(race_free_bound + 1)).
+      k2 = race_free_bound;
+    }
+    uint64_t minimum_padding = (k2 - reduced_dim_size % k2) % k2;
+    uint64_t best_k1 = (reduced_dim_size + minimum_padding) / k2;
+    for (uint64_t i = k2 - 1; i > k2 / 2; --i) {
+      uint64_t padding = (i - reduced_dim_size % i) % i;
+      if (padding < minimum_padding ||
+          (padding == minimum_padding && absl::has_single_bit(i))) {
+        minimum_padding = padding;
+        best_k1 = (reduced_dim_size + padding) / i;
+      }
+    }
+    uint64_t padded_k = reduced_dim_size + minimum_padding;
+
+    // We chose the best {k1, k2} pair based on the padding size and whether
+    // index computation is fast. But we ignored the overhead of memory
+    // reads/writes and block launches, which also matter for kernel
+    // performance. The swapped {k1, k2} pair has the same padding size and
+    // index-computation cost as the original, so we only need to compare
+    // memory reads/writes and block launches to choose the better of the two.
+    uint64_t best_k2 = padded_k / best_k1;
+    if (ShouldSwapInnerAndOuterReducedMinorDimension(
+            best_k1, best_k2, reduced_dim_size, race_free_bound,
+            reduction_dims.is_row_reduction)) {
+      std::swap(best_k1, best_k2);
+    }
+    return SplitParams{static_cast<int64_t>(best_k1),
+                       static_cast<int64_t>(best_k2), reduced_dim};
+  }
+
+  // Replaces the original reduce with pad->reshape->inner_reduce->outer_reduce:
+  // * 1. pads the split dimension of the inputs to k1 * k2 if necessary.
+  // * 2. reshapes the split dimension of the padded inputs into [k1, k2].
+  // * 3. the inner reduction reduces the dims specified in the original
+  //      reduction, but reduces k2 instead of the split dimension.
+  // * 4. the outer reduction reduces k1 only.
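+  //
+  // Illustrative shape flow (sizes are hypothetical), reducing dimension 1 of
+  // f32[8,4000,3] with k1=125 and k2=32:
+  //   pad     : f32[8,4000,3] -> f32[8,4000,3]   (4000 = 125 * 32, no padding)
+  //   bitcast : f32[8,4000,3] -> f32[8,125,32,3]
+  //   inner   : reduce dims {2} -> f32[8,125,3]
+  //   outer   : reduce dims {1} -> f32[8,3]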
+  absl::Status SplitReductionDimension(
+      HloReduceInstruction *reduce, const SplitParams &split_params,
+      absl::Span<const int64_t> sorted_dims_to_reduce) {
+    absl::Span<int64_t const> reduce_input_dims =
+        reduce->inputs()[0]->shape().dimensions();
+    int64_t split_dim_size = reduce_input_dims[split_params.dim];
+    VLOG(2) << "dimension to split = " << split_params.dim << " with "
+            << split_dim_size << " elements into " << split_params.k1 << " by "
+            << split_params.k2;
+
+    // Pad 'k' to 'k1 * k2' if necessary.
+    HloInstruction::InstructionVector padded_inputs(reduce->inputs().begin(),
+                                                    reduce->inputs().end());
+    auto padded_size = split_params.k1 * split_params.k2;
+    absl::InlinedVector<int64_t, 3> padded_dimensions(reduce_input_dims.begin(),
+                                                      reduce_input_dims.end());
+    if (split_dim_size != padded_size) {
+      padded_dimensions[split_params.dim] = padded_size;
+      PaddingConfig padding_config =
+          MakeNoPaddingConfig(reduce_input_dims.size());
+      padding_config.mutable_dimensions(split_params.dim)
+          ->set_edge_padding_high(padded_size - split_dim_size);
+
+      for (int input_idx = 0; input_idx < padded_inputs.size(); ++input_idx) {
+        auto &reduction_input = padded_inputs[input_idx];
+        Shape padded_shape = ShapeUtil::MakeShape(
+            reduction_input->shape().element_type(), padded_dimensions);
+        VLOG(2) << "Generated padded shape: " << padded_shape.ToString();
+        reduction_input = reduce->parent()->AddInstruction(
+            HloInstruction::CreatePad(padded_shape, reduction_input,
+                                      reduce->init_values()[input_idx],
+                                      padding_config),
+            &reduction_input->metadata());
+      }
+    }
+
+    // Compute output type of reshape that expands the split dimension into
+    // [k1, k2].
+    absl::InlinedVector<int64_t, 3> reshaped_dimensions;
+    int64_t input_rank = reduce_input_dims.size();
+    for (int64_t dim_idx = 0; dim_idx < input_rank; dim_idx++) {
+      if (dim_idx == split_params.dim) {
+        reshaped_dimensions.push_back(split_params.k1);
+        reshaped_dimensions.push_back(split_params.k2);
+      } else {
+        reshaped_dimensions.push_back(padded_dimensions[dim_idx]);
+      }
+    }
+
+    // Compute dimensions to reduce for inner reduction.
+    absl::InlinedVector<int64_t, 2> inner_reduce_dims(
+        sorted_dims_to_reduce.begin(), sorted_dims_to_reduce.end());
+    auto split_dim_it = std::find(inner_reduce_dims.begin(),
+                                  inner_reduce_dims.end(), split_params.dim);
+    *split_dim_it += 1;
+
+    // Compute dimension to reduce for outer reduction.
+    absl::InlinedVector<int64_t, 1> outer_reduce_dims{
+        split_params.dim -
+        std::distance(inner_reduce_dims.begin(), split_dim_it)};
+
+    // Compute output shape of the inner reduction.
+    absl::InlinedVector<int64_t, 3> inner_reduce_shape =
+        RemoveElements(inner_reduce_dims, reshaped_dimensions);
+
+    // Reshape the split dimensions of the padded inputs into [k1, k2].
+    HloInstruction::InstructionVector reshaped_padded_inputs;
+    absl::InlinedVector<Shape, 2> inner_reduce_shapes;
+    for (HloInstruction *padded_input : padded_inputs) {
+      Shape reshaped_shape = ShapeUtil::MakeShape(
+          padded_input->shape().element_type(), reshaped_dimensions);
+      HloInstruction *reshaped_padded_input = reduce->parent()->AddInstruction(
+          HloInstruction::CreateBitcast(reshaped_shape, padded_input),
+          &padded_input->metadata());
+      VLOG(2) << "Generated reshape: " << reshaped_padded_input->ToString();
+      reshaped_padded_inputs.push_back(reshaped_padded_input);
+      inner_reduce_shapes.push_back(ShapeUtil::MakeShape(
+          padded_input->shape().element_type(), inner_reduce_shape));
+    }
+
+    // Inner reduce that reduces [k1, k2] to [k1].
+    HloInstruction *inner_reduce = reduce->parent()->AddInstruction(
+        HloInstruction::CreateReduce(
+            ShapeUtil::MakeMaybeTupleShape(inner_reduce_shapes),
+            reshaped_padded_inputs, reduce->init_values(), inner_reduce_dims,
+            reduce->to_apply()),
+        &reduce->metadata());
+    VLOG(1) << "Generated inner reduction: " << inner_reduce->ToString();
+
+    // Outer reduce that reduces the remaining [k1] dimension.
+    std::unique_ptr<HloInstruction> outer_reduce = HloInstruction::CreateReduce(
+        reduce->shape(), inner_reduce, reduce->init_values(), outer_reduce_dims,
+        reduce->to_apply());
+
+    VLOG(1) << "Generated outer reduction: " << outer_reduce->ToString();
+    return ReplaceWithNewInstruction(reduce, std::move(outer_reduce));
+  }
+
+  // Rewrites batch dimension reduction into a separate reduce operation.
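+  // For example (as exercised by the RowReductionBatchedDimensionDoesNotFit
+  // test): f32[100] reduce(f32[32,100,90000], dimensions={0,2}) becomes an
+  // inner reduce over dimensions={2} producing f32[32,100], followed by an
+  // outer reduce over dimensions={0} producing f32[100].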
+  absl::Status RewriteBatchDimensionLargerThanTile(
+      HloReduceInstruction *hlo,
+      const ReductionDimensions &reduction_dimensions,
+      absl::Span<const int64_t> sorted_dims_to_reduce) {
+    // TODO(cheshire): This codepath is essentially the exact reverse of what
+    // algebraic_simplifier is doing; we need to make sure they don't keep
+    // undoing each other.
+    CHECK(reduction_dimensions.is_row_reduction);
+
+    absl::InlinedVector<Shape, 2> tuple_shapes;
+    int64_t minor_reduction_dim = sorted_dims_to_reduce.back();
+    for (HloInstruction *input : hlo->inputs()) {
+      tuple_shapes.push_back(
+          ShapeUtil::DeleteDimension(minor_reduction_dim, input->shape()));
+    }
+
+    HloInstruction *inner_reduce =
+        hlo->parent()->AddInstruction(HloInstruction::CreateReduce(
+            ShapeUtil::MakeMaybeTupleShape(tuple_shapes), hlo->inputs(),
+            hlo->init_values(), {minor_reduction_dim}, hlo->to_apply()));
+
+    VLOG(1) << "Inner reduction: " << inner_reduce->ToString();
+    std::unique_ptr<HloInstruction> out = HloInstruction::CreateReduce(
+        hlo->shape(), inner_reduce, hlo->init_values(), {0}, hlo->to_apply());
+    VLOG(1) << "Generated: " << out->ToString();
+    return ReplaceWithNewInstruction(hlo, std::move(out));
+  }
+
+  se::GpuComputeCapability gpu_version_;
+};
+
+absl::StatusOr<bool> TreeReductionRewriter::Run(
+    HloModule *module,
+    const absl::flat_hash_set<absl::string_view> &execution_threads) {
+  VLOG(5) << "Rewriter input: " << module->ToString();
+  TF_ASSIGN_OR_RETURN(bool changed,
+                      ReductionRewriterVisitor(gpu_version_)
+                          .RunOnModule(module, execution_threads));
+  VLOG(5) << "Rewriter output: " << module->ToString();
+  return changed;
+}
+
+}  // end namespace gpu
+}  // end namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h
new file mode 100644
index 0000000..7f57d21
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h
@@ -0,0 +1,96 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_TREE_REDUCTION_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_TREE_REDUCTION_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/stream_executor/device_description.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites reductions so that they can be implemented without atomics.
+//
+// Rule application: rewrite a single HLO reduce operation into two.
+//
+// Case 1: Row reduction where a batched dimension is present and larger than
+// the Z-tiling size.
+// -----------------------------------------------------------------
+//
+// Rewriting:
+//
+// f32[B] out = reduce(f32[A, B, C] input, dimensions={0, 2})
+//
+// Into:
+//
+// f32[A, B] tmp = reduce(f32[A, B, C] input, dimensions={2})
+// f32[B] out = reduce(f32[A, B] tmp, dimensions={0})
+//
+// Case 2: Row reduction
+// ------------------------------------------------------------------
+//
+// Let M be the thread tiling multiplied by the warp size.
+// We go from (assuming C > M):
+//
+// f32[B] out = reduce(f32[A, B, C] input, dimensions={0, 2})
+//
+// to:
+//
+// f32[A, B, P] padded = pad(input) // Let P = ceil(C/M) * M.
+// f32[A, B, Q, M] reshaped = bitcast(padded) // Let Q = ceil(C/M)
+// f32[B, Q] inner_reduce = reduce(reshaped, dimensions={0, 3})
+// f32[B] outer_reduce = reduce(inner_reduce, dimensions={1})
+//
+// Case 3: Column reduction
+// -------------------------------------------------------------------
+//
+// Let T be the tiling size for the column reduction.
+//
+// We go from (assuming B > T):
+//
+// f32[A, C] out = reduce(f32[A, B, C] input, dimensions={1})
+//
+// to:
+//
+// f32[A, P, C] padded = pad(input) // Let P = ceil(B/T) * T.
+// f32[A, Q, T, C] reshaped = bitcast(padded) // Let Q = ceil(B/T)
+// f32[A, Q, C] inner_reduce = reduce(reshaped, dimensions={2})
+// f32[A, C] outer_reduce = reduce(inner_reduce, dimensions={1})
+//
+class TreeReductionRewriter : public HloModulePass {
+ public:
+  explicit TreeReductionRewriter(se::GpuComputeCapability gpu_version)
+      : gpu_version_(gpu_version) {}
+
+  ~TreeReductionRewriter() override = default;
+  absl::string_view name() const override { return "tree-reduction-rewriter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  se::GpuComputeCapability gpu_version_;
+};
+
+}  // end namespace gpu
+}  // end namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_TREE_REDUCTION_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc
new file mode 100644
index 0000000..bea969e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc
@@ -0,0 +1,557 @@
+/* Copyright 2020 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/tree_reduction_rewriter.h"
+
+#include <optional>
+
+#include "absl/strings/string_view.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+
+namespace {
+
+class TreeReductionRewriterTest : public HloTestBase {
+ public:
+  void CheckTreeRewriter(absl::string_view hlo,
+                         std::optional<absl::string_view> expected) {
+    RunAndFilecheckHloRewrite(
+        hlo,
+#if TENSORFLOW_USE_ROCM
+        gpu::TreeReductionRewriter{se::RocmComputeCapability{"908"}},
+#else
+        gpu::TreeReductionRewriter{se::CudaComputeCapability{8, 1}},
+#endif
+        expected);
+  }
+};
+
+TEST_F(TreeReductionRewriterTest, RowReductionSingleDimensionNoBatched) {
+  const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[50021] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK: [[pad_0:%[^ ]+]] = f32[50022]{0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_1
+// CHECK: [[bitcast_3:%[^ ]+]] = f32[126,397]{1,0} bitcast([[pad_0]])
+// CHECK: [[reduce_4:%[^ ]+]] = f32[126]{0} reduce([[bitcast_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_5:%[^ ]+]]
+// CHECK: ROOT [[out_1_6:%[^ ]+]] = f32[] reduce([[reduce_4]], [[zero_2]]), dimensions={0}, to_apply=[[add_5]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, RowReductionWeirdOutputLayout) {
+  const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[2,4,17000]{2,1,0} parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[2,4]{0,1} reduce(input, zero), dimensions={2}, to_apply=add
+}
+)";
+
+  // Check that we preserve the layout.
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK: f32[2,4]{0,1} reduce(
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest,
+       RowReductionSingleDimensionNoBatchedDivisible) {
+  const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[50048] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[50048]{0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[128,391]{1,0} bitcast([[input_0]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[128]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, RowReductionNoBatched) {
+  const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[100,10,65536] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[100,10] reduce(input, zero), dimensions={2}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[100,10,256,256]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[100,10,256]{2,1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={3}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100,10]{1,0} reduce([[reduce_3]], [[zero_2]]), dimensions={2}, to_apply=[[add_4]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest,
+       RowReductionSingleDimensionNoBatchedLargeInput) {
+  const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[1048576] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK:  [[input_0:%[^ ]+]] = f32[1048576]{0} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[1024,1024]{1,0} bitcast([[input_0]])
+// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[1024]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionFits) {
+  const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[8,100,65536] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK:  [[bitcast_0:%[^ ]+]] = f32[8,100,256,256]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[100,256]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={0,3}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionDoesNotFit) {
+  const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[32,100,90000] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK: [[reduce_0:%[^ ]+]] = f32[32,100]{1,0} reduce([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), dimensions={2}, to_apply=[[add_3:%[^ ]+]]
+// CHECK:  ROOT [[out_1_4:%[^ ]+]] = f32[100]{0} reduce([[reduce_0]], [[zero_2]]), dimensions={0}, to_apply=[[add_3]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, ColumnReductionSimple) {
+  const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[16384,100] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+
+// CHECK:  [[input_0:%[^ ]+]] = f32[16384,100]{1,0} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[128,128,100]{2,1,0} bitcast([[input_0]])
+// CHECK:  [[reduce_2:%[^ ]+]] = f32[128,100]{1,0} reduce([[bitcast_1]], [[zero_3:%[^ ]+]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_2]], [[zero_3]]), dimensions={0}, to_apply=[[add_4]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, ColumnReductionSimpleNoDivisible) {
+  const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[10303,100] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK:  [[input_0:%[^ ]+]] = f32[10303,100]{1,0} parameter(0)
+// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[pad_0:%[^ ]+]] = f32[10304,100]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_1x0_0
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[64,161,100]{2,1,0} bitcast([[pad_0]])
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[64,100]{1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, ColumnReductionOtherIndex) {
+  const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[16384,2,2,2] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[2,2,2] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK:  [[input_0:%[^ ]+]] = f32[16384,2,2,2]{3,2,1,0} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[128,128,2,2,2]{4,3,2,1,0} bitcast([[input_0]])
+// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[128,2,2,2]{3,2,1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[2,2,2]{2,1,0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, ColumnReductionVeryLargeInput) {
+  const char* hlo = R"(
+HloModule ReduceWithPadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[1048576,5] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[5] reduce(input, zero), dimensions={0}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+
+// CHECK:  [[bitcast_0:%[^ ]+]] = f32[1024,1024,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[1024,5]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[5]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, VariadicReductionLargeRow) {
+  const char* hlo = R"(
+HloModule Reduce_R1x2_to_R0x2_argmax
+
+argmax {
+  running_max = f32[] parameter(0)
+  running_max_idx = u32[] parameter(1)
+  current_value = f32[] parameter(2)
+  current_value_idx = u32[] parameter(3)
+
+  current = (f32[], u32[]) tuple(running_max, running_max_idx)
+  potential = (f32[], u32[]) tuple(current_value, current_value_idx)
+
+  cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+  new_max = f32[] select(cmp_code, current_value, running_max)
+  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
+}
+
+ENTRY main {
+  input = f32[2,100003] parameter(0)
+  idxs = u32[2,100003] iota(), iota_dimension=0
+  zero = f32[] constant(0)
+  zero_idx = u32[] constant(0)
+
+  ROOT out = (f32[2], u32[2]) reduce(
+    input, idxs, zero, zero_idx),
+    dimensions={1},
+    to_apply=%argmax
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK:  [[pad_0:%[^ ]+]] = f32[2,100005]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_0x0_2
+// CHECK:  [[bitcast_3:%[^ ]+]] = f32[2,295,339]{2,1,0} bitcast([[pad_0]])
+// CHECK:  [[zero_idx_4:%[^ ]+]] = u32[] constant(0)
+// CHECK:  [[pad_1_5:%[^ ]+]] = u32[2,100005]{1,0} pad([[idxs_6:%[^ ]+]], [[zero_idx_4]]), padding=0_0x0_2
+// CHECK:  [[bitcast_1_7:%[^ ]+]] = u32[2,295,339]{2,1,0} bitcast([[pad_1_5]])
+// CHECK:  [[reduce_8:%[^ ]+]] = (f32[2,295]{1,0}, u32[2,295]{1,0}) reduce([[bitcast_3]], [[bitcast_1_7]], [[zero_2]], [[zero_idx_4]]), dimensions={2}, to_apply=[[argmax_9:%[^ ]+]]
+// CHECK:  [[get_tuple_element_10:%[^ ]+]] = f32[2,295]{1,0} get-tuple-element([[reduce_8]]), index=0
+// CHECK:  [[get_tuple_element_1_11:%[^ ]+]] = u32[2,295]{1,0} get-tuple-element([[reduce_8]]), index=1
+// CHECK:  ROOT [[out_1_12:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_10]], [[get_tuple_element_1_11]], [[zero_2]], [[zero_idx_4]]), dimensions={1}, to_apply=[[argmax_9]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, VariadicReductionLargeBatchSize) {
+  const char* hlo = R"(
+HloModule Reduce_R1x2_to_R0x2_argmax
+
+argmax {
+  running_max = f32[] parameter(0)
+  running_max_idx = u32[] parameter(1)
+  current_value = f32[] parameter(2)
+  current_value_idx = u32[] parameter(3)
+
+  current = (f32[], u32[]) tuple(running_max, running_max_idx)
+  potential = (f32[], u32[]) tuple(current_value, current_value_idx)
+
+  cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+  new_max = f32[] select(cmp_code, current_value, running_max)
+  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
+}
+
+ENTRY main {
+  input = f32[20,2,100] parameter(0)
+  idxs = u32[20,2,100] iota(), iota_dimension=0
+  zero = f32[] constant(0)
+  zero_idx = u32[] constant(0)
+
+  ROOT out = (f32[2], u32[2]) reduce(
+    input, idxs, zero, zero_idx),
+    dimensions={0,2},
+    to_apply=%argmax
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK:  [[reduce_0:%[^ ]+]] = (f32[20,2]{1,0}, u32[20,2]{1,0}) reduce([[input_1:%[^ ]+]], [[idxs_2:%[^ ]+]], [[zero_3:%[^ ]+]], [[zero_idx_4:%[^ ]+]]), dimensions={2}, to_apply=[[argmax_5:%[^ ]+]]
+// CHECK:  [[get_tuple_element_6:%[^ ]+]] = f32[20,2]{1,0} get-tuple-element([[reduce_0]]), index=0
+// CHECK:  [[get_tuple_element_1_7:%[^ ]+]] = u32[20,2]{1,0} get-tuple-element([[reduce_0]]), index=1
+// CHECK:  ROOT [[out_1_8:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_6]], [[get_tuple_element_1_7]], [[zero_3]], [[zero_idx_4]]), dimensions={0}, to_apply=[[argmax_5]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, KeepInnerReductionVectorized) {
+  const char* hlo = R"(
+HloModule KeepInnerRowReductionVectorized
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[1024,73984] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+
+// CHECK:  [[bitcast_0:%[^ ]+]] = f32[1024,289,256]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[1024,289]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, PreferLargeVectorizedDimension) {
+  const char* hlo = R"(
+HloModule PreferLargeVectorizedDimension
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[1024,98304] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+
+// CHECK:  [[bitcast_0:%[^ ]+]] = f32[1024,256,384]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[1024,256]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, SwapIfNonAlignedBeforePadding) {
+  const char* hlo = R"(
+HloModule SwapIfNonAlignedBeforePadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[1024,19739] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+
+// CHECK-DAG:  [[bitcast_0:%[^ ]+]] = f32[1024,140,141]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK-DAG:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[1024,140]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, DontSwapIfNonAlignedBeforePadding) {
+  const char* hlo = R"(
+HloModule DontSwapIfNonAlignedBeforePadding
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[1024,19459] parameter(0)
+  zero = f32[] constant(0)
+  ROOT out = f32[1024] reduce(input, zero), dimensions={1}, to_apply=add
+}
+)";
+
+  CheckTreeRewriter(hlo,
+                    R"(
+
+// CHECK-DAG:  [[bitcast_0:%[^ ]+]] = f32[1024,140,139]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK-DAG:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[1024,140]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[1024]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
+      )");
+}
+
+TEST_F(TreeReductionRewriterTest, NonConsecutiveReductionDims) {
+  const char* hlo = R"(
+    HloModule NonConsecutiveReductionDims
+
+    add {
+      accum = f32[] parameter(0)
+      op = f32[] parameter(1)
+      ROOT out = f32[] add(accum, op)
+    }
+
+    ENTRY main {
+      input = f32[5,3,4,5] parameter(0)
+      zero = f32[] constant(0)
+      ROOT out = f32[5,4] reduce(input, zero), dimensions={1,3}, to_apply=add
+    }
+  )";
+
+  CheckTreeRewriter(hlo, std::nullopt);
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc
new file mode 100644
index 0000000..e81bdae
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc
@@ -0,0 +1,88 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/triangular_solve_rewriter.h"
+
+#include <cstdint>
+#include <numeric>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+absl::StatusOr<bool> TriangularSolveRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    std::vector<HloInstruction*> to_rewrite;
+    for (HloInstruction* instr : comp->instructions()) {
+      if (instr->opcode() == HloOpcode::kTriangularSolve) {
+        to_rewrite.push_back(instr);
+      }
+    }
+
+    for (HloInstruction* instr : to_rewrite) {
+      const Shape& b_shape = instr->operand(1)->shape();
+      int64_t batch_size = std::accumulate(
+          b_shape.dimensions().begin(), b_shape.dimensions().end() - 2,
+          int64_t{1}, [](int64_t a, int64_t b) { return a * b; });
+
+      // Batch-1 triangular solves get 0 temp bytes, because unbatched trsm()
+      // doesn't require temp memory.
+      int64_t temp_bytes = batch_size == 1 ? 0 : 2 * sizeof(void*) * batch_size;
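+      // The rewritten op returns a (solve result, scratch) tuple; the scratch
+      // buffer is what lets us call cublas trsmBatched (see the header comment
+      // for details).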
+      Shape new_shape = ShapeUtil::MakeTupleShape({
+          instr->shape(),
+          ShapeUtil::MakeShape(S8, {temp_bytes}),
+      });
+
+      HloInstruction* custom_call =
+          comp->AddInstruction(HloInstruction::CreateCustomCall(
+              new_shape, instr->operands(), kTriangularSolveCallTarget));
+      module->SetAndUniquifyInstrName(custom_call, "triangular-solve");
+      TF_RETURN_IF_ERROR(
+          custom_call->set_backend_config(instr->triangular_solve_options()));
+
+      // Preserve metadata from `instr`.
+      custom_call->set_metadata(instr->metadata());
+      custom_call->set_frontend_attributes(instr->frontend_attributes());
+
+      // Get the actual result out of the custom call's tuple.
+      TF_ASSIGN_OR_RETURN(HloInstruction * gte,
+                          MakeGetTupleElementHlo(custom_call, 0));
+      TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.h b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.h
new file mode 100644
index 0000000..c52e0ff
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.h
@@ -0,0 +1,60 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRIANGULAR_SOLVE_REWRITER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_TRIANGULAR_SOLVE_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites HLO TriangularSolve ops into a custom-call.
+//
+// The motivation for this is that we need to add temp memory to batched
+// triangular-solve ops in order to call cublas trsmBatched.  We rewrite batch 1
+// ops as well so that we have fewer codepaths to worry about in the backend.
+//
+// cublas trsmBatched takes arrays in GPU memory of pointers to the inputs and
+// outputs, `a` and `b`.  In XLA the inputs/outputs are always contiguous, but
+// we still have to materialize out these arrays.
+//
+// We use the same trick as for cudnn convolutions: the custom-call returns a
+// tuple (actual-result, temp-memory).  In our case the temp buffer always has
+// size 2 * sizeof(void*) * batch_size, because we need two arrays of pointers.
+//
+// The custom-call has a backend-config equal to the TriangularSolveOptions
+// object.
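+//
+// For illustration, assuming 64-bit pointers: with `b` of shape f32[8,32,32]
+// the batch size is 8, so the rewritten op returns the tuple
+// (f32[8,32,32], s8[128]), where 128 = 2 * sizeof(void*) * 8.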
+class TriangularSolveRewriter : public HloModulePass {
+ public:
+  absl::string_view name() const override {
+    return "triangular-solve-rewriter";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_TRIANGULAR_SOLVE_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc
new file mode 100644
index 0000000..10ae640
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc
@@ -0,0 +1,190 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h"
+
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/functional/any_invocable.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/executable.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/buffer_comparator.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/shaped_buffer.h"
+#include "xla/shape.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/tools/hlo_decomposer.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+
+namespace {
+
+using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput;
+
+// Returns the input instruction as a fusion instruction, if it represents a
+// Triton fusion. Otherwise, returns nullptr.
+absl::StatusOr<const HloFusionInstruction*> AsTritonFusion(
+    const HloInstruction* hlo) {
+  if (hlo->opcode() != HloOpcode::kFusion) {
+    return nullptr;
+  }
+  const HloFusionInstruction* fusion = Cast<HloFusionInstruction>(hlo);
+  TF_ASSIGN_OR_RETURN(auto gpu_config,
+                      fusion->backend_config<GpuBackendConfig>());
+  const FusionBackendConfig& backend_config =
+      gpu_config.fusion_backend_config();
+  if (backend_config.kind() == kTritonFusionKind) {
+    return fusion;
+  }
+  return nullptr;
+}
+
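+// Extracts `fusion` into its own standalone HloModule for compilation. When
+// `clear_backend_config` is true, the fusion's backend config is cleared so
+// that the extracted module is compiled with the regular emitters rather than
+// Triton.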
+std::unique_ptr<HloModule> NewHloModuleFromFusion(
+    const HloFusionInstruction& fusion, const DebugOptions& debug_opts,
+    bool clear_backend_config) {
+  std::unique_ptr<HloModule> new_module =
+      ExtractInstructionIntoNewModule(fusion);
+  if (clear_backend_config) {
+    new_module->entry_computation()->root_instruction()->clear_backend_config();
+  }
+  new_module->mutable_config().set_debug_options(debug_opts);
+
+  return new_module;
+}
+
+}  // namespace
+
+namespace triton_fusion_numerics_pass_internal {
+
+absl::StatusOr<ScopedShapedBuffer> CompileAndRunFusion(
+    AutotunerCompileUtil& util, const HloFusionInstruction& fusion,
+    const AutotuneConfig& config, const DebugOptions& debug_opts,
+    bool clear_backend_config) {
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+                      util.Compile([&](const DebugOptions& opts) {
+                        return NewHloModuleFromFusion(fusion, opts,
+                                                      clear_backend_config);
+                      }));
+  TF_ASSIGN_OR_RETURN(auto rz_buffers, RedzoneBuffers::FromInstruction(
+                                           fusion, config, debug_opts,
+                                           RedzoneBuffers::kAllInputs));
+  TF_ASSIGN_OR_RETURN(auto stream, config.GetStream());
+  TF_ASSIGN_OR_RETURN(std::optional<ProfilingOutput> profiling_output,
+                      util.ProfileExecutable(executable.get(), stream,
+                                             rz_buffers.input_buffers(),
+                                             rz_buffers.input_shapes()));
+  if (!profiling_output.has_value()) {
+    return Internal("No output after a successful verification run.");
+  }
+
+  return std::move(profiling_output->output);
+}
+
+absl::Status CompareBuffers(const ScopedShapedBuffer& current,
+                            const ScopedShapedBuffer& expected,
+                            const Shape& shape, const HloModuleConfig& config,
+                            se::Stream* stream) {
+  BufferComparator comparator(
+      shape, config.debug_options().xla_gpu_autotune_gemm_rtol());
+  TF_ASSIGN_OR_RETURN(bool outputs_match,
+                      comparator.CompareEqual(stream, current.root_buffer(),
+                                              expected.root_buffer()));
+
+  if (!outputs_match) {
+    return Internal("Triton fusion output does not match emitters output.");
+  }
+  return absl::OkStatus();
+}
+
+absl::Status ForAllTritonFusions(
+    const HloModule& module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads,
+    absl::AnyInvocable<absl::Status(const HloFusionInstruction&)> fn) {
+  for (HloComputation* computation :
+       module.MakeNonfusionComputations(execution_threads)) {
+    for (HloInstruction* instruction : computation->instructions()) {
+      TF_ASSIGN_OR_RETURN(auto triton_fusion, AsTritonFusion(instruction));
+      if (triton_fusion != nullptr) {
+        TF_RETURN_IF_ERROR(fn(*triton_fusion));
+      }
+    }
+  }
+  return absl::OkStatus();
+}
+
+}  // namespace triton_fusion_numerics_pass_internal
+
+namespace {
+absl::Status VerifyTritonFusion(AutotunerCompileUtil& util,
+                                const HloFusionInstruction& fusion,
+                                const AutotuneConfig& config,
+                                const DebugOptions& debug_opts) {
+  TF_ASSIGN_OR_RETURN(auto triton_result,
+                      triton_fusion_numerics_pass_internal::CompileAndRunFusion(
+                          util, fusion, config, debug_opts,
+                          /*clear_backend_config=*/false));
+  TF_ASSIGN_OR_RETURN(auto emitters_result,
+                      triton_fusion_numerics_pass_internal::CompileAndRunFusion(
+                          util, fusion, config, debug_opts,
+                          /*clear_backend_config=*/true));
+
+  TF_ASSIGN_OR_RETURN(auto stream, config.GetStream());
+  return triton_fusion_numerics_pass_internal::CompareBuffers(
+      triton_result, emitters_result, fusion.shape(),
+      fusion.GetModule()->config(), stream);
+}
+
+}  // namespace
+
+absl::StatusOr<bool> TritonFusionNumericsVerifier::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  if (config_.IsDeviceless()) {
+    return absl::InternalError(
+        "Cannot run TritonFusionNumericsVerifier on a deviceless compilation.");
+  }
+
+  const DebugOptions& debug_options = module->config().debug_options();
+  TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> opt_compile_util,
+                      AutotunerCompileUtil::Create(config_, debug_options));
+  TF_RET_CHECK(opt_compile_util.has_value());
+
+  TF_RETURN_IF_ERROR(triton_fusion_numerics_pass_internal::ForAllTritonFusions(
+      *module, execution_threads, [&](const HloFusionInstruction& fusion) {
+        return VerifyTritonFusion(*opt_compile_util, fusion, config_,
+                                  debug_options);
+      }));
+  return false;
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h
new file mode 100644
index 0000000..e3dc6eb
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h
@@ -0,0 +1,74 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/functional/any_invocable.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/service/shaped_buffer.h"
+#include "xla/shape.h"
+#include "xla/stream_executor/stream.h"
+
+namespace xla::gpu {
+
+// For each Triton fusion in the HLO module, this pass checks that the output
+// of the fusion generated via Triton matches the output of the same fusion
+// generated with the regular emitters.
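+//
+// A minimal usage sketch (assuming an AutotuneConfig named `autotune_config`
+// has already been built for the target device):
+//
+//   TritonFusionNumericsVerifier verifier(autotune_config);
+//   TF_ASSIGN_OR_RETURN(bool changed, verifier.Run(module));
+//
+// The pass never modifies the module (it always reports `changed == false`);
+// instead it returns an error status when a Triton fusion's output does not
+// match the emitter-generated reference.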
+class TritonFusionNumericsVerifier : public HloModulePass {
+ public:
+  explicit TritonFusionNumericsVerifier(const AutotuneConfig& config)
+      : config_(config) {}
+
+  static absl::string_view Name() { return "triton-numerics-verifier"; }
+  absl::string_view name() const override { return Name(); }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  AutotuneConfig config_;
+};
+
+namespace triton_fusion_numerics_pass_internal {
+// These are exposed only for testing. Do not use.
+absl::StatusOr<ScopedShapedBuffer> CompileAndRunFusion(
+    AutotunerCompileUtil& util, const HloFusionInstruction& fusion,
+    const AutotuneConfig& config, const DebugOptions& debug_opts,
+    bool clear_backend_config);
+absl::Status CompareBuffers(const ScopedShapedBuffer& current,
+                            const ScopedShapedBuffer& expected,
+                            const Shape& shape, const HloModuleConfig& config,
+                            se::Stream* stream);
+absl::Status ForAllTritonFusions(
+    const HloModule& module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads,
+    absl::AnyInvocable<absl::Status(const HloFusionInstruction&)> fn);
+}  // namespace triton_fusion_numerics_pass_internal
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc
new file mode 100644
index 0000000..0382577
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc
@@ -0,0 +1,195 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/primitive_util.h"
+#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/platform_util.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/test_helpers.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+
+namespace xla::gpu {
+namespace {
+
+class TritonFusionNumericsVerifierTest
+    : public HloTestBase,
+      public ::testing::WithParamInterface<PrimitiveType> {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    auto options = HloTestBase::GetDebugOptionsForTest();
+    options.set_xla_gpu_enable_triton_softmax_fusion(true);
+    options.set_xla_gpu_verify_triton_fusion_numerics(true);
+    return options;
+  }
+
+ protected:
+  std::unique_ptr<xla::HloModule> Module(absl::string_view hlo_text_template,
+                                         absl::string_view type) {
+    auto m = GetOptimizedModule(absl::Substitute(hlo_text_template, type));
+    TF_EXPECT_OK(m);
+    return std::move(m.value());
+  }
+
+  const HloFusionInstruction* TritonFusion(const xla::HloModule& module) {
+    const HloFusionInstruction* fusion_result = nullptr;
+
+    absl::Status res =
+        triton_fusion_numerics_pass_internal::ForAllTritonFusions(
+            module, /*execution_threads=*/{},
+            [&](const HloFusionInstruction& fusion) -> absl::Status {
+              EXPECT_EQ(fusion_result, nullptr);
+              fusion_result = &fusion;
+              return absl::OkStatus();
+            });
+    return fusion_result;
+  }
+
+  AutotuneConfig CreateAutotuneConfig() {
+    se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
+    auto executors_or = PlatformUtil::GetStreamExecutors(platform);
+    TF_EXPECT_OK(executors_or);
+    return AutotuneConfig{DeviceConfig{executors_or->at(0), nullptr},
+                          GetDebugOptionsForTest()};
+  }
+
+  AutotunerCompileUtil CreateAutotunerCompileUtil(AutotuneConfig& config) {
+    auto opt_compile_util_or =
+        AutotunerCompileUtil::Create(config, GetDebugOptionsForTest());
+    TF_EXPECT_OK(opt_compile_util_or);
+    EXPECT_TRUE(opt_compile_util_or->has_value());
+    return std::move(opt_compile_util_or->value());
+  }
+};
+
+constexpr absl::string_view kSoftmaxHlo = R"(
+HloModule softmax
+max_computation {
+  arg_0 = $0[] parameter(0)
+  arg_1 = $0[] parameter(1)
+  ROOT maximum = $0[] maximum(arg_0, arg_1)
+}
+add_computation {
+  arg_0.1 = $0[] parameter(0)
+  arg_1.1 = $0[] parameter(1)
+  ROOT add = $0[] add(arg_0.1, arg_1.1)
+}
+ENTRY main {
+  param_0 = $0[127,125]{1,0} parameter(0)
+  constant_neg_inf = $0[] constant(-inf)
+  reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
+  broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0}
+  subtract = $0[127,125]{1,0} subtract(param_0, broadcast)
+  exponential = $0[127,125]{1,0} exponential(subtract)
+  constant_zero = $0[] constant(0)
+  second_reduce = $0[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
+  second_broadcast = $0[127,125]{1,0} broadcast(second_reduce), dimensions={0}
+  ROOT divide = $0[127,125]{1,0} divide(exponential, second_broadcast)
+}
+)";
+
+bool HloPassHasRun(const HloModule& module, absl::string_view pass_name) {
+  for (const auto& pass_metadata : module.metadata().proto().pass_metadata()) {
+    if (pass_metadata.pass_name() == pass_name) {
+      return true;
+    }
+  }
+  return false;
+}
+
+TEST_P(TritonFusionNumericsVerifierTest, VerifyExactSoftmaxFusionNumerics) {
+  PrimitiveType data_type = GetParam();
+
+  auto module = Module(kSoftmaxHlo,
+                       primitive_util::LowercasePrimitiveTypeName(data_type));
+
+  // At this point all HLO passes have been executed successfully, because the
+  // Module() function hasn't failed. In particular the numerics verification
+  // pass should have also run and **not** found any issues. Below we just
+  // ensure that the pass has indeed been correctly enabled and that there are
+  // Triton Fusions in the input module.
+
+  EXPECT_TRUE(HloPassHasRun(*module, TritonFusionNumericsVerifier::Name()));
+  auto fusion = TritonFusion(*module);
+  EXPECT_NE(fusion, nullptr);
+}
+
+TEST_F(TritonFusionNumericsVerifierTest, CheckMismatch) {
+  // This test intentionally compares two different Triton modules to each
+  // other. This is to test that the verifier functions correctly catch and
+  // report mismatches.
+  //
+  // Note that as part of computing the two modules below, the numerics
+  // verifier pass also runs individually for each module. These runs compare
+  // each module to the corresponding emitter-generated version, which matches.
+  // In that sense this test covers what is being tested by
+  // VerifyExactSoftmaxFusionNumerics. The reason to keep both tests is that
+  // VerifyExactSoftmaxFusionNumerics is minimal and will be easier to debug if
+  // it fails.
+
+  auto module_f16 = Module(kSoftmaxHlo, "f16");
+  auto fusion_f16 = TritonFusion(*module_f16);
+  EXPECT_NE(fusion_f16, nullptr);
+
+  auto module_f32 = Module(kSoftmaxHlo, "f32");
+  auto fusion_f32 = TritonFusion(*module_f32);
+  EXPECT_NE(fusion_f32, nullptr);
+
+  AutotuneConfig autotune_config = CreateAutotuneConfig();
+  AutotunerCompileUtil compile_util =
+      CreateAutotunerCompileUtil(autotune_config);
+  const DebugOptions& debug_options = GetDebugOptionsForTest();
+
+  auto f16_result = triton_fusion_numerics_pass_internal::CompileAndRunFusion(
+      compile_util, *fusion_f16, autotune_config, debug_options,
+      /*clear_backend_config=*/false);
+  TF_EXPECT_OK(f16_result);
+
+  auto f32_result = triton_fusion_numerics_pass_internal::CompileAndRunFusion(
+      compile_util, *fusion_f32, autotune_config, debug_options,
+      /*clear_backend_config=*/false);
+  TF_EXPECT_OK(f32_result);
+
+  auto stream = autotune_config.GetStream();
+  TF_EXPECT_OK(stream);
+
+  // Intentionally compare the fusions from the different modules, triggering a
+  // mismatch.
+  auto cmp = triton_fusion_numerics_pass_internal::CompareBuffers(
+      *f16_result, *f32_result, fusion_f16->shape(),
+      fusion_f16->GetModule()->config(), *stream);
+
+  EXPECT_FALSE(cmp.ok());
+}
+
+INSTANTIATE_TEST_SUITE_P(TritonFusionNumericsVerifierTestSuite,
+                         TritonFusionNumericsVerifierTest,
+                         ::testing::Values(F32, F16, BF16));
+
+}  // namespace
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.cc b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.cc
new file mode 100644
index 0000000..0712040
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.cc
@@ -0,0 +1,115 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/variadic_op_splitter.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/shape.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+// The parameter space on the GPU device is limited. We pick an arbitrary,
+// fairly low constant here to try to avoid exceeding that parameter space. For
+// a proper fix, we would have to take into account which parameters share a
+// buffer, and how big these buffers are.
+constexpr int32_t kMaxParameters = 128;
+
+absl::StatusOr<bool> SplitConcatenate(HloInstruction* concat,
+                                      HloComputation* comp) {
+  auto operands = concat->operands();
+  std::vector<HloInstruction*> operands_to_split(operands.begin(),
+                                                 operands.end());
+  while (operands_to_split.size() > 1) {
+    std::vector<HloInstruction*> new_operands;
+    absl::Span<HloInstruction*> operands_span(operands_to_split);
+    for (int64_t offset = 0; offset < operands_to_split.size();
+         offset += kMaxParameters) {
+      // Check whether there is a trailing remainder of operands that does not
+      // completely fill one "batch" of exactly 'kMaxParameters' operands. If
+      // fewer than 'kMaxParameters' operands are left in total, we still put
+      // them into a concat right away. Otherwise, we defer the remainder to
+      // the next round so that it can be concatenated together with some of
+      // the newly created concats.
+      if (offset > 0 && offset + kMaxParameters > operands_to_split.size()) {
+        new_operands.insert(new_operands.end(),
+                            operands_to_split.begin() + offset,
+                            operands_to_split.end());
+      } else {
+        Shape new_shape = concat->shape();
+        int64_t concat_dimension_size = 0;
+        for (int64_t i = 0;
+             i < kMaxParameters && offset + i < operands_to_split.size(); ++i) {
+          concat_dimension_size +=
+              operands_to_split[i + offset]->shape().dimensions(
+                  concat->concatenate_dimension());
+        }
+        new_shape.set_dimensions(concat->concatenate_dimension(),
+                                 concat_dimension_size);
+        auto new_concat = comp->AddInstruction(concat->CloneWithNewOperands(
+            new_shape, operands_span.subspan(offset, kMaxParameters)));
+        new_operands.push_back(new_concat);
+      }
+    }
+    operands_to_split = new_operands;
+  }
+  TF_RETURN_IF_ERROR(comp->ReplaceInstruction(concat, operands_to_split[0]));
+  return true;
+}
+
+std::vector<HloInstruction*> GetRelevantVariadicOps(HloComputation* comp) {
+  std::vector<HloInstruction*> ops;
+  for (HloInstruction* instr : comp->instructions()) {
+    if (instr->opcode() == HloOpcode::kConcatenate &&
+        instr->operand_count() > kMaxParameters) {
+      ops.push_back(instr);
+    }
+  }
+  return ops;
+}
+
+}  // namespace
+
+absl::StatusOr<bool> VariadicOpSplitter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    for (HloInstruction* op : GetRelevantVariadicOps(comp)) {
+      // TODO(b/112613927): Handle also other ops than concatenate.
+      TF_ASSIGN_OR_RETURN(bool result, SplitConcatenate(op, comp));
+      changed |= result;
+    }
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
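
For reference, the following standalone sketch (not part of the patch; the program and its output format are illustrative) simulates how many operands remain after each rewrite round of SplitConcatenate, assuming the same kMaxParameters limit of 128. The two operand counts mirror the SplitInto2 and SplitInto3 tests added later in this patch: 255 operands end up as a 128-operand concat whose first operand is itself a 128-operand concat, while 256 operands end up as a concat of two 128-operand concats.

// Sketch only: simulates the batching arithmetic of SplitConcatenate.
#include <cstdint>
#include <iostream>

int main() {
  constexpr int32_t kMaxParameters = 128;
  for (int64_t n : {255, 256}) {
    std::cout << n;
    int64_t remaining = n;
    while (remaining > 1) {
      int64_t next = 0;
      for (int64_t offset = 0; offset < remaining; offset += kMaxParameters) {
        if (offset > 0 && offset + kMaxParameters > remaining) {
          next += remaining - offset;  // Trailing remainder carries over as-is.
        } else {
          ++next;  // A batch of up to kMaxParameters operands becomes one concat.
        }
      }
      remaining = next;
      std::cout << " -> " << remaining;
    }
    std::cout << " operand(s)\n";  // Prints: 255 -> 128 -> 1 and 256 -> 2 -> 1.
  }
  return 0;
}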
diff --git a/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.h b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.h
new file mode 100644
index 0000000..304afa1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.h
@@ -0,0 +1,43 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_VARIADIC_OP_SPLITTER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_VARIADIC_OP_SPLITTER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Splits variadic ops with many operands into pieces such that we don't exceed
+// the parameter space on the GPU. Currently only concatenate ops are split up.
+class VariadicOpSplitter : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "variadic-op-splitter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_VARIADIC_OP_SPLITTER_H_
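
As a usage sketch, the pass can be wired into a standalone pass pipeline as shown below. The wrapper function and pipeline name are illustrative assumptions, not part of this patch; in the GPU compiler the pass is added to the regular compilation pipelines instead.

// Sketch only: run VariadicOpSplitter on a module via HloPassPipeline.
#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/transforms/variadic_op_splitter.h"
#include "xla/service/hlo_pass_pipeline.h"

absl::StatusOr<bool> SplitOversizedConcats(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("variadic-op-splitter-example");
  pipeline.AddPass<xla::gpu::VariadicOpSplitter>();
  // Returns true if any concatenate with more than 128 operands was split.
  return pipeline.Run(module);
}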
diff --git a/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc
new file mode 100644
index 0000000..1d72613
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/variadic_op_splitter.h"
+
+#include <cstdint>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/literal_util.h"
+#include "xla/service/hlo_parser.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/shape_util.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+using match::Concatenate;
+
+class VariadicOpSplitterTest : public HloTestBase {};
+
+TEST_F(VariadicOpSplitterTest, DontSplit) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule TestModule
+
+  ENTRY TestComputation {
+    p0 = f16[30,41] parameter(0)
+    p1 = f16[30,41] parameter(1)
+    ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0}
+  })")
+                    .value();
+  EXPECT_FALSE(VariadicOpSplitter().Run(module.get()).value());
+}
+
+TEST_F(VariadicOpSplitterTest, SplitInto2) {
+  auto builder = HloComputation::Builder(TestName());
+  auto operand = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32_t>({42})));
+  std::vector<HloInstruction*> concat_operands(255, operand);
+  builder.AddInstruction(HloInstruction::CreateConcatenate(
+      ShapeUtil::MakeShape(S32, {255}), concat_operands, 0));
+  auto module = CreateNewVerifiedModule();
+  auto entry_computation = module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(VariadicOpSplitter().Run(module.get()).value());
+  EXPECT_TRUE(Match(entry_computation->root_instruction(),
+                    Concatenate().WithNumOperands(128).WithOperand(
+                        0, Concatenate().WithNumOperands(128))));
+}
+
+TEST_F(VariadicOpSplitterTest, SplitInto3) {
+  auto builder = HloComputation::Builder(TestName());
+  auto operand = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32_t>({42})));
+  std::vector<HloInstruction*> concat_operands(256, operand);
+  builder.AddInstruction(HloInstruction::CreateConcatenate(
+      ShapeUtil::MakeShape(S32, {256}), concat_operands, 0));
+  auto module = CreateNewVerifiedModule();
+  auto entry_computation = module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(VariadicOpSplitter().Run(module.get()).value());
+  EXPECT_TRUE(Match(entry_computation->root_instruction(),
+                    Concatenate(Concatenate().WithNumOperands(128),
+                                Concatenate().WithNumOperands(128))));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc
new file mode 100644
index 0000000..eeb24ce
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc
@@ -0,0 +1,1148 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/windowed_einsum_handler.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/literal_util.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/shape_inference.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace m = match;
+
+// Enables the creation of FP8 GEMM Custom Calls for all-gather and
+// reduce-scatter windowed einsums in gemm_rewriter.cc by moving the scalings
+// and type conversions of FP8 operands into the bodies of their while loops,
+// i.e. rewrites
+//
+//   inputs --> dequant --> while loop {dynamic-slice/collective-permute/dot}
+//
+// into
+//
+//   inputs --> while loop {dequant --> dynamic-slice/collective-permute/dot}.
+absl::Status ShiftDequantizationF8(const HloComputation* comp,
+                                   const std::array<HloInstruction*, 2>& gte) {
+  HloInstruction* while_instr = comp->WhileCallInstruction();
+  if (!while_instr) {
+    return absl::OkStatus();
+  }
+
+  // Identify the scalings and type conversions applied to the inputs of the
+  // while loop.
+  HloInstruction* param_tuple = while_instr->mutable_operand(0);
+  std::array<HloInstruction*, 2> binaries, operands, scales;
+  for (int k = 0; k < 2; ++k) {
+    if (!Match(param_tuple->mutable_operand(k),
+               m::AnyOf<HloInstruction>(
+                   m::Divide(&binaries[k], m::Convert(m::Op(&operands[k])),
+                             m::Broadcast(m::Op(&scales[k]))),
+                   m::MultiplyAnyOrder(&binaries[k],
+                                       m::Convert(m::Op(&operands[k])),
+                                       m::Broadcast(m::Op(&scales[k])))))) {
+      VLOG(5) << "Unable to identify FP8 dequantization pattern.";
+      return absl::OkStatus();
+    }
+  }
+
+  // For the dot to be rewritten by gemm_rewriter.cc into an FP8 GEMM, at most
+  // one of the inputs can be F8E5M2.
+  std::array<PrimitiveType, 2> operand_types{
+      operands[0]->shape().element_type(), operands[1]->shape().element_type()};
+  if (!((operand_types[0] == F8E4M3FN && operand_types[1] == F8E4M3FN) ||
+        (operand_types[0] == F8E4M3FN && operand_types[1] == F8E5M2) ||
+        (operand_types[0] == F8E5M2 && operand_types[1] == F8E4M3FN))) {
+    VLOG(5) << "Unsupported types.";
+    return absl::OkStatus();
+  }
+
+  // The dequantized types must be BF16, FP16 or FP32.
+  for (int k = 0; k < 2; ++k) {
+    if (binaries[k]->shape().element_type() != BF16 &&
+        binaries[k]->shape().element_type() != F16 &&
+        binaries[k]->shape().element_type() != F32) {
+      VLOG(5) << "Unsupported types.";
+      return absl::OkStatus();
+    }
+  }
+
+  // The FP8 scaling operands must be scalars.
+  if (!ShapeUtil::IsScalar(scales[0]->shape()) ||
+      !ShapeUtil::IsScalar(scales[1]->shape())) {
+    VLOG(5) << "Scaling factors must be scalars.";
+    return absl::OkStatus();
+  }
+
+  // Identify the dot and collective-permute or dynamic-slice instructions in
+  // the all-gather or reduce-scatter patterns in while's body.
+  HloComputation* while_body = while_instr->while_body();
+  HloComputation* while_condition = while_instr->while_condition();
+  HloInstruction* while_root = while_body->root_instruction();
+  std::array<HloInstruction*, 2> dots, dyn_slices{nullptr, nullptr},
+      coll_perms{nullptr, nullptr};
+  if (Match(
+          while_root,
+          m::Tuple(m::CollectivePermute(
+                       &coll_perms[1], m::CollectivePermute(
+                                           &coll_perms[0], m::Op().Is(gte[0]))),
+                   m::Op().Is(gte[1]),
+                   m::DynamicUpdateSlice(
+                       m::DynamicUpdateSlice().WithOperand(
+                           1, m::Dot(&dots[0], m::Op().Is(gte[0]),
+                                     m::Op().Is(gte[1]))),
+                       m::Dot(&dots[1], m::Op(), m::Op().Is(gte[1])), m::Op(),
+                       m::Op(), m::Op()),
+                   m::Op(), m::Op()))) {
+    VLOG(5) << "Identified all-gather windowed einsum pattern.";
+  } else if (Match(
+                 while_root,
+                 m::Tuple(m::Op().Is(gte[0]), m::Op().Is(gte[1]),
+                          m::AddAnyOrder(
+                              m::Dot(&dots[0], m::DynamicSlice(&dyn_slices[0]),
+                                     m::Op().Is(gte[1])),
+                              m::Op()),
+                          m::CollectivePermute(m::AddAnyOrder(
+                              m::Dot(&dots[1], m::DynamicSlice(&dyn_slices[1]),
+                                     m::Op().Is(gte[1])),
+                              m::Op())),
+                          m::Op()))) {
+    VLOG(5) << "Identified reduce-scatter windowed einsum pattern.";
+  } else {
+    VLOG(5) << "Unable to identify valid windowed einsum pattern.";
+    return absl::OkStatus();
+  }
+
+  // Replace the dequantized dot operands in the parameter tuple used by while
+  // with FP8 operands.
+  for (int k = 0; k < 2; ++k) {
+    TF_RETURN_IF_ERROR(
+        param_tuple->ReplaceOperandWithDifferentShape(k, operands[k]));
+    ShapeUtil::UpdateTupleShape(operands[k]->shape(), k,
+                                param_tuple->mutable_shape());
+    param_tuple->AppendOperand(scales[k]);
+    ShapeUtil::AppendShapeToTuple(scales[k]->shape(),
+                                  param_tuple->mutable_shape());
+  }
+
+  // Update the parameter tuples of while's body and condition computations.
+  for (HloComputation* while_comp : {while_body, while_condition}) {
+    while_comp->ReplaceParameter(
+        0, HloInstruction::CreateParameter(
+               0, param_tuple->shape(),
+               while_comp->parameter_instruction(0)->name()));
+  }
+
+  // In the while body, replace the existing get-tuple-element instructions
+  // retrieving BF16/FP16/FP32 dot operands with get-tuple-element instructions
+  // retrieving the FP8 dot operands from the input tuple.
+  HloInstruction* body_param = while_body->parameter_instruction(0);
+  for (int k = 0; k < 2; ++k) {
+    TF_ASSIGN_OR_RETURN(HloInstruction * operand_f8,
+                        MakeGetTupleElementHlo(body_param, k));
+
+    if (while_root->operand(k) == gte[k]) {
+      TF_RETURN_IF_ERROR(
+          while_root->ReplaceOperandWithDifferentShape(k, operand_f8));
+      ShapeUtil::UpdateTupleShape(operand_f8->shape(), k,
+                                  while_root->mutable_shape());
+    }
+
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * operand_scale,
+        MakeGetTupleElementHlo(
+            body_param, body_param->shape().tuple_shapes_size() - 2 + k));
+
+    // Also add the scaling factor to the output tuple of the while body.
+    while_root->AppendOperand(operand_scale);
+    ShapeUtil::AppendShapeToTuple(operand_scale->shape(),
+                                  while_root->mutable_shape());
+
+    // Dequantize the operands of the dots and dynamic-slices.
+    HloInstruction* operand_f32 =
+        MakeConvertToHlo(operand_f8, gte[k]->shape().element_type());
+    HloInstruction* broadcast_scale =
+        MakeBroadcastHlo(operand_scale, {}, operand_f32->shape());
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * operand_scaled,
+        MakeBinaryHlo(binaries[k]->opcode(), operand_f32, broadcast_scale));
+
+    // Replace the original get-tuple-element instructions accessing the
+    // operands of the dots and dynamic-slices with the dequantized FP8
+    // operands. The order of dequantization and dynamic-slices will be
+    // exchanged in gemm_rewriter.cc.
+    for (int l = 0; l < 2; ++l) {
+      if (dots[l]->operand(k) == gte[k]) {
+        TF_RETURN_IF_ERROR(dots[l]->ReplaceOperandWith(k, operand_scaled));
+      }
+      if (dyn_slices[l] && dyn_slices[l]->operand(0) == gte[k]) {
+        TF_RETURN_IF_ERROR(
+            dyn_slices[l]->ReplaceOperandWith(0, operand_scaled));
+      }
+    }
+
+    // In the all-gather case, coll_perms[0] has two users, coll_perms[1] and
+    // dots[1], which prevents it from being exchanged with dequantization in
+    // gemm_rewriter.cc. Instead, directly insert the dequantization before
+    // dots[1] here.
+    if (coll_perms[0] && coll_perms[0]->operand(0) == gte[k]) {
+      std::array<HloInstruction*, 2> coll_perms_f8{nullptr, nullptr};
+      // Change the type of both collective-permutes to FP8.
+      coll_perms_f8[0] =
+          while_body->AddInstruction(coll_perms[0]->CloneWithNewOperands(
+              operand_f8->shape(), {operand_f8}));
+      coll_perms_f8[1] =
+          while_body->AddInstruction(coll_perms[1]->CloneWithNewOperands(
+              coll_perms_f8[0]->shape(), {coll_perms_f8[0]}));
+
+      // Insert the dequantization between coll_perms[0] and dots[1].
+      HloInstruction* coll_perm0_f32 =
+          MakeConvertToHlo(coll_perms_f8[0], gte[k]->shape().element_type());
+      TF_ASSIGN_OR_RETURN(HloInstruction * x_scaled,
+                          MakeBinaryHlo(binaries[k]->opcode(), coll_perm0_f32,
+                                        broadcast_scale));
+      TF_RETURN_IF_ERROR(dots[1]->ReplaceOperandWith(0, x_scaled));
+
+      // Update the output tuple.
+      TF_RETURN_IF_ERROR(
+          while_root->ReplaceOperandWithDifferentShape(0, coll_perms_f8[1]));
+      ShapeUtil::UpdateTupleShape(coll_perms_f8[1]->shape(), 0,
+                                  while_root->mutable_shape());
+    }
+  }
+
+  // Update the shape of the while call in the parent computation.
+  TF_RETURN_IF_ERROR(
+      while_instr->ReplaceAllUsesWithDifferentShape(while_instr->AddInstruction(
+          while_instr->CloneWithNewShape(while_root->shape()))));
+  TF_RETURN_IF_ERROR(while_instr->parent()->RemoveInstruction(while_instr));
+
+  if (coll_perms[0]) {
+    TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[1]));
+    TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[0]));
+  }
+  TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gte[0]));
+  TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gte[1]));
+
+  VLOG(5) << "FP8 dequantization moved into while loop.";
+  return absl::OkStatus();
+}
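
To make the dequantization shape concrete, the following minimal sketch (not part of the patch; the helper name is illustrative and capture of the matched instructions is omitted) expresses the per-operand pattern that ShiftDequantizationF8 checks before any type or scalar-shape validation: an FP8 value converted to a wider type and then multiplied or divided by a broadcast scale.

// Sketch only: the dequantization pattern checked on each while-loop input.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/pattern_matcher.h"

namespace example {
namespace m = xla::match;

// Returns true if `input` looks like convert(operand) * broadcast(scale) or
// convert(operand) / broadcast(scale).
bool LooksLikeFp8Dequantization(xla::HloInstruction* input) {
  return xla::Match(
      input, m::AnyOf<xla::HloInstruction>(
                 m::Divide(m::Convert(m::Op()), m::Broadcast(m::Op())),
                 m::MultiplyAnyOrder(m::Convert(m::Op()),
                                     m::Broadcast(m::Op()))));
}

}  // namespace example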
+
+int64_t NumberOfInstructionsInComp(const HloComputation* comp, HloOpcode op) {
+  int64_t total_count = 0;
+  for (const HloInstruction* inst : comp->instructions()) {
+    if (inst->opcode() == op) {
+      ++total_count;
+    }
+  }
+  return total_count;
+}
+
+absl::Status UpdateDotAndConsumerConfig(HloInstruction* dot,
+                                        int64_t stream_id) {
+  auto dot_gpu_config = dot->backend_config<gpu::GpuBackendConfig>();
+  HloInstruction* updater = dot->users()[0];
+  auto updater_gpu_config = updater->backend_config<gpu::GpuBackendConfig>();
+  dot_gpu_config->set_operation_queue_id(stream_id);
+  updater_gpu_config->mutable_wait_on_operation_queues()->Add(stream_id);
+
+  TF_RETURN_IF_ERROR(dot->set_backend_config(dot_gpu_config.value()));
+  TF_RETURN_IF_ERROR(updater->set_backend_config(updater_gpu_config.value()));
+  return absl::OkStatus();
+}
+
+absl::Status SetForceDelayForInstruction(HloInstruction* instr,
+                                         bool force_delay) {
+  auto gpu_config = instr->backend_config<gpu::GpuBackendConfig>();
+
+  gpu_config->set_force_earliest_schedule(force_delay);
+
+  TF_RETURN_IF_ERROR(instr->set_backend_config(gpu_config.value()));
+  return absl::OkStatus();
+}
+
+absl::StatusOr<bool> HandleRsWindowedEinsumLoop(HloComputation* comp,
+                                                int64_t stream_id) {
+  bool changed = false;
+  // If we have an einsum loop with only 1 dot, this means either
+  // the loop is not unrolled or only 1 partition is available.
+  // It's a no-op in either case.
+  if (NumberOfInstructionsInComp(comp, HloOpcode::kDot) <= 1) {
+    return changed;
+  }
+  for (auto inst : comp->MakeInstructionPostOrder()) {
+    HloInstruction* matched_dot;
+    std::array<HloInstruction*, 2> gte;
+    // The dot we'd like to parallelize is consuming the second loop input
+    // as RHS.
+    if (Match(inst,
+              m::Dot(&matched_dot,
+                     m::DynamicSlice().WithOperand(
+                         0, m::GetTupleElement(&gte[0], m::Parameter(), 0)),
+                     m::GetTupleElement(&gte[1], m::Parameter(), 1)))) {
+      // If present, move the dequantization of FP8 operands of the dot into the
+      // while loop to allow gemm_rewriter.cc to rewrite into an FP8 Custom
+      // Call.
+      TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp, gte));
+
+      // Dispatch the dot to additional compute stream.
+      TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id));
+      ++stream_id;
+      changed = true;
+    }
+
+    // We need to ensure that the first collective-permute is always scheduled
+    // at the beginning of the loop.
+    HloInstruction* matched_cp;
+    if (Match(inst, m::CollectivePermute(
+                        &matched_cp, m::GetTupleElement(m::Parameter(), 2)))) {
+      TF_RETURN_IF_ERROR(
+          SetForceDelayForInstruction(matched_cp, /*force_delay=*/true));
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+absl::StatusOr<bool> HandleAgWindowedEinsumLoop(HloComputation* comp,
+                                                int64_t stream_id) {
+  bool changed = false;
+  // If we have an einsum loop with only 1 dot, this means either
+  // the loop is not unrolled or only 1 partition is available.
+  // It's a no-op in either case.
+  if (NumberOfInstructionsInComp(comp, HloOpcode::kDot) <= 1) {
+    return changed;
+  }
+  for (auto inst : comp->MakeInstructionPostOrder()) {
+    HloInstruction* matched_dot;
+    std::array<HloInstruction*, 2> gte;
+    // The dot we'd like to parallelize is consuming the second loop input
+    // as RHS and the first loop input as LHS.
+    if (Match(inst, m::Dot(&matched_dot,
+                           m::GetTupleElement(&gte[0], m::Parameter(), 0),
+                           m::GetTupleElement(&gte[1], m::Parameter(), 1)))) {
+      // If present, move the dequantization of FP8 operands of the dot into the
+      // while loop to allow gemm_rewriter.cc to rewrite into an FP8 Custom
+      // Call.
+      TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp, gte));
+
+      // Dispatch the dot to additional compute stream.
+      TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id));
+      ++stream_id;
+      TF_RETURN_IF_ERROR(
+          SetForceDelayForInstruction(matched_dot, /*force_delay=*/true));
+      changed = true;
+    }
+
+    // We need to ensure that the first collective-permute is always scheduled
+    // at the beginning of the loop.
+    HloInstruction* matched_cp;
+    if (Match(inst, m::CollectivePermute(
+                        &matched_cp, m::GetTupleElement(m::Parameter(), 0)))) {
+      TF_RETURN_IF_ERROR(
+          SetForceDelayForInstruction(matched_cp, /*force_delay=*/true));
+      changed = true;
+    }
+  }
+  return changed;
+}
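
For clarity, the core operand pattern matched by HandleAgWindowedEinsumLoop can be written as a small standalone predicate (a sketch, not part of the patch; the function name is illustrative): a dot whose LHS and RHS are get-tuple-elements 0 and 1 of the while-loop parameter, i.e. a dot consuming both windowed-einsum loop inputs directly.

// Sketch only: the all-gather windowed-einsum dot pattern, without captures.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/pattern_matcher.h"

namespace example {
namespace m = xla::match;

bool IsWindowedAgDot(xla::HloInstruction* inst) {
  // dot(get-tuple-element(param, 0), get-tuple-element(param, 1))
  return xla::Match(inst,
                    m::Dot(m::GetTupleElement(m::Parameter(), 0),
                           m::GetTupleElement(m::Parameter(), 1)));
}

}  // namespace example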
+
+static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) {
+  const HloInstruction* loop_tuple = while_loop->operand(0);
+  const Shape& tuple_shape = loop_tuple->shape();
+  CHECK(tuple_shape.IsTuple());
+  return tuple_shape.tuple_shapes_size();
+}
+
+absl::Status ProcessWindowedEinsumLoopForActivationCaching(
+    WindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop,
+    HloInstruction* ag_with_shared_operand) {
+  HloInstruction* loop = ag_loop.loop;
+  // Transform the while body to cache the allgathered result in the
+  // output buffer to be consumed by the dot.
+  HloComputation* while_body = loop->while_body();
+  HloInstruction* input_gte;
+  for (HloInstruction* gte : while_body->parameter_instruction(0)->users()) {
+    if (gte->tuple_index() == 0) {
+      input_gte = gte;
+    }
+  }
+  // Get the output operand of the full buffer.
+  HloInstruction* root = while_body->root_instruction();
+  // Change loop body to include the new input and output element.
+  HloInstruction* input_tuple = while_body->parameter_instruction(0);
+  const Shape& input_shape = input_tuple->shape();
+  // The full buffer that we will use to cache the accumulated activation
+  // is the last operand in the output tuple.
+  int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop);
+  std::vector<Shape> new_input_shapes(input_shape.tuple_shapes().begin(),
+                                      input_shape.tuple_shapes().end());
+  new_input_shapes.push_back(ag_with_shared_operand->shape());
+  // Update body input shape
+  Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes);
+  *input_tuple->mutable_shape() = new_input_shape;
+  HloInstruction* full_buffer_output_gte =
+      while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
+          ag_with_shared_operand->shape(), input_tuple,
+          full_cache_buffer_index));
+
+  // Update condition input shape
+  HloComputation* cond_comp = loop->while_condition();
+  HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0);
+  *cond_input_tuple->mutable_shape() = new_input_shape;
+
+  // Update input to the while instruction in parent computation
+  HloInstruction* original_while_input = loop->mutable_operand(0);
+  HloComputation* parent_comp = loop->parent();
+  std::vector<HloInstruction*> new_operands(
+      original_while_input->operands().begin(),
+      original_while_input->operands().end());
+  new_operands.push_back(
+      parent_comp->AddInstruction(HloInstruction::CreateBroadcast(
+          ag_with_shared_operand->shape(),
+          parent_comp->AddInstruction(HloInstruction::CreateConstant(
+              LiteralUtil::Zero(new_input_shapes[0].element_type()))),
+          {})));
+  HloInstruction* new_while_input =
+      parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands));
+  TF_RETURN_IF_ERROR(
+      loop->ReplaceOperandWithDifferentShape(0, new_while_input));
+  TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape(
+      original_while_input, new_while_input));
+  *loop->mutable_shape() = new_input_shape;
+
+  HloInstruction* new_full_buffer_output = nullptr;
+  // Find the DUS in the loop body and reuse the slice indices; this should
+  // just be a constant(0).
+  HloInstruction* dus_boundary_constant;
+  // The slice we need this time is the output of the first
+  // collective-permute.
+  HloInstruction* first_cp_output;
+  for (HloInstruction* gte_user : input_gte->users()) {
+    if (gte_user->opcode() == HloOpcode::kCollectivePermute) {
+      first_cp_output = gte_user;
+      break;
+    }
+  }
+  for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) {
+    HloInstruction* slice_indices;
+    // If we have a DUS(PARAM,DS) pattern, we need to update the output
+    // buffer with the first slice.
+    if (Match(inst,
+              m::DynamicUpdateSlice(
+                  m::GetTupleElement(m::Parameter()), m::Op(),
+                  m::Constant(&dus_boundary_constant),
+                  m::Reshape(m::DynamicSlice(&slice_indices, m::Op(), m::Op())),
+                  m::Op()))) {
+      slice_indices = while_body->AddInstruction(HloInstruction::CreateReshape(
+          dus_boundary_constant->shape(), slice_indices));
+      VLOG(5) << "Created slice op for first slice: "
+              << slice_indices->ToString();
+      full_buffer_output_gte =
+          while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+              full_buffer_output_gte->shape(), full_buffer_output_gte,
+              input_gte,
+              {dus_boundary_constant, slice_indices, dus_boundary_constant}));
+    }
+    // If we have a DUS(DUS,DS) pattern, then the einsum loop is unrolled and
+    // we need to update the output buffer again with the second slice. Since
+    // the second slice will have different indices, we need to re-capture
+    // slice_indices.
+    if (Match(inst,
+              m::DynamicUpdateSlice(
+                  m::DynamicUpdateSlice(), m::Op(), m::Constant(),
+                  m::Reshape(m::DynamicSlice(&slice_indices, m::Op(), m::Op())),
+                  m::Op()))) {
+      slice_indices = while_body->AddInstruction(HloInstruction::CreateReshape(
+          dus_boundary_constant->shape(), slice_indices));
+      VLOG(5) << "Created slice op for second slice: "
+              << slice_indices->ToString();
+      new_full_buffer_output =
+          while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+              full_buffer_output_gte->shape(), full_buffer_output_gte,
+              first_cp_output,
+              {dus_boundary_constant, slice_indices, dus_boundary_constant}));
+    }
+
+    // If we have a Dot(DS(parameter_index1)), then operands are sharded along
+    // the contracting dim. Slice indices will be the contracting dim's slices.
+    HloInstruction* slice_index;
+    HloInstruction* ds_index_constant;
+    HloInstruction* remainder;
+    HloInstruction* ds_param;
+    // There will be 2 dynamic-slices for unrolled loops; match each one to
+    // get the slice index that will be used to write the corresponding
+    // received shard into the cached activation buffer. For unrolled loops, we
+    // need to write to the final buffer twice per iteration, so we need to
+    // match the correct slice index based on each DS.
+    if (Match(inst, m::Dot(m::Op(), m::DynamicSlice(&ds_param))) &&
+        Match(ds_param->operand(0), m::GetTupleElement(m::Parameter(), 1))) {
+      for (int64_t ds_op_i = 1; ds_op_i < ds_param->operands().size();
+           ds_op_i++) {
+        if (!Match(
+                ds_param->mutable_operand(ds_op_i),
+                m::Reshape(&slice_index, m::DynamicSlice(m::Constant(),
+                                                         m::Op(&remainder)))) &&
+            !Match(ds_param->mutable_operand(ds_op_i),
+                   m::Constant(&ds_index_constant))) {
+          return absl::OkStatus();
+        }
+      }
+      // The first DS has its slice index calculated from the loop iterator:
+      // Remainder(add(gte, partition_id)).
+      if (Match(remainder,
+                m::Remainder(m::Add(m::GetTupleElement(), m::Op()), m::Op()))) {
+        full_buffer_output_gte =
+            while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+                full_buffer_output_gte->shape(), full_buffer_output_gte,
+                input_gte,
+                {ds_index_constant, ds_index_constant, slice_index}));
+      }
+      // The second DS has its slice index calculated from the loop iterator
+      // plus 1, hence Remainder(add(add(gte, 1), partition_id)).
+      if (Match(remainder,
+                m::Remainder(
+                    m::Add(m::Add(m::GetTupleElement(), m::Op()), m::Op()),
+                    m::Op()))) {
+        new_full_buffer_output =
+            while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+                full_buffer_output_gte->shape(), full_buffer_output_gte,
+                first_cp_output,
+                {ds_index_constant, ds_index_constant, slice_index}));
+      }
+    }
+  }
+  std::vector<HloInstruction*> original_operands(root->operands().begin(),
+                                                 root->operands().end());
+  original_operands.push_back(new_full_buffer_output);
+  HloInstruction* new_output_tuple = while_body->AddInstruction(
+      HloInstruction::CreateTuple(original_operands));
+  TF_RETURN_IF_ERROR(
+      while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple));
+  return absl::OkStatus();
+}
+
+bool HasReplicaGroups(const HloInstruction* inst) {
+  return inst->replica_groups().size() > 0;
+}
+
+bool ShouldAddToChain(const HloInstruction* inst) {
+  switch (inst->opcode()) {
+    case HloOpcode::kTranspose:
+    case HloOpcode::kReshape:
+    case HloOpcode::kCopy:
+      return inst->user_count() == 1;
+    default:
+      return false;
+  }
+}
+
+struct MatchedGemmA2aResult {
+  HloInstruction* producer_gemm;
+  HloInstruction* lhs;
+  HloInstruction* rhs;
+  HloInstruction* a2a_replacement = nullptr;
+  bool matched = false;
+};
+
+class WindowedEinsumVisitor : public DfsHloRewriteVisitor {
+ public:
+  explicit WindowedEinsumVisitor(
+      std::vector<WindowedEinsumHandler::WindowedEinsumAgLoops>& all_ag_loops)
+      : all_ag_loops_(all_ag_loops) {}
+  absl::StatusOr<bool> MatchA2aGemmWithIntermediateReshapes(
+      HloInstruction* dot, HloInstruction** lhs, HloInstruction** rhs) {
+    if (Match(dot, m::Dot(m::AllToAll(lhs).WithOneUse().WithPredicate(
+                              HasReplicaGroups),
+                          m::Op(rhs))) &&
+        !DynCast<HloAllToAllInstruction>((*lhs))->constrain_layout() &&
+        !(*lhs)->shape().IsTuple()) {
+      return true;
+    }
+    std::vector<HloInstruction*> allowed_intermediate_ops(
+        {dot->mutable_operand(0)});
+
+    HloAllToAllInstruction* matched_a2a = nullptr;
+    // We keep pushing until a condition is unmet or we have found the a2a.
+    while (true) {
+      HloInstruction* curr = allowed_intermediate_ops.back();
+      if (ShouldAddToChain(curr)) {
+        allowed_intermediate_ops.insert(allowed_intermediate_ops.end(),
+                                        std::begin(curr->operands()),
+                                        std::end(curr->operands()));
+      } else if (curr->opcode() == HloOpcode::kAllToAll &&
+                 curr->user_count() == 1) {
+        matched_a2a = DynCast<HloAllToAllInstruction>(curr);
+        allowed_intermediate_ops.pop_back();
+        break;
+      } else {
+        return false;
+      }
+    }
+    CHECK(matched_a2a != nullptr);
+    if (matched_a2a->constrain_layout() || matched_a2a->shape().IsTuple() ||
+        !HasReplicaGroups(matched_a2a) || !matched_a2a->split_dimension()) {
+      return false;
+    }
+    // We need to create a new a2a that's a direct producer of the dot and use
+    // it to replace the original a2a. A new reshape will be added to the
+    // original a2a's input. We first need to determine the new split dimension
+    // after all the reshape ops.
+    int64_t split_dimension = *matched_a2a->split_dimension();
+    for (int64_t i = allowed_intermediate_ops.size() - 1; i >= 0; i--) {
+      HloInstruction* current_op = allowed_intermediate_ops[i];
+      if (current_op->opcode() == HloOpcode::kReshape) {
+        std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
+            ShapeUtil::DimensionsUnmodifiedByReshape(
+                current_op->operand(0)->shape(), current_op->shape());
+        auto it = absl::c_find_if(
+            unmodified_dims,
+            [&split_dimension](std::pair<int64_t, int64_t>& dim_pair) {
+              return dim_pair.first == split_dimension;
+            });
+        // The split dimension of the a2a has been modified; we cannot deduce
+        // the new split dim easily, so skip decomposition.
+        if (it == unmodified_dims.end()) {
+          VLOG(5) << "Split dimension of: " << matched_a2a->ToShortString()
+                  << " has been modified by reshapes. Skipping it for "
+                     "decomposition.";
+          return false;
+        }
+        // Assign the new split dim.
+        split_dimension = it->second;
+      } else if (current_op->opcode() == HloOpcode::kTranspose) {
+        const auto& transpose_dims = current_op->dimensions();
+        for (int64_t j = 0; j < transpose_dims.size(); j++) {
+          if ((int64_t)transpose_dims[j] == split_dimension) {
+            split_dimension = j;
+            break;
+          }
+        }
+      }
+    }
+    TF_RETURN_IF_ERROR(allowed_intermediate_ops.back()->ReplaceOperandWith(
+        0, matched_a2a->mutable_operand(0)));
+    HloInstruction* new_a2a =
+        matched_a2a->parent()->AddInstruction(HloInstruction::CreateAllToAll(
+            allowed_intermediate_ops.front()->shape(),
+            {allowed_intermediate_ops.front()}, matched_a2a->replica_groups(),
+            false, hlo_query::NextChannelId(*matched_a2a->GetModule()),
+            split_dimension));
+
+    TF_RETURN_IF_ERROR(dot->ReplaceOperandWith(0, new_a2a));
+    TF_RETURN_IF_ERROR(
+        matched_a2a->parent()->RemoveInstructionAndUnusedOperands(matched_a2a));
+    MarkAsChanged();
+    *lhs = new_a2a;
+    *rhs = dot->mutable_operand(1);
+    return true;
+  }
+
+  absl::Status HandleDot(HloInstruction* dot) override {
+    CHECK_EQ(dot->opcode(), HloOpcode::kDot);
+    HloComputation* comp = dot->parent();
+    // Rewrites an allgather-dot pattern that shares the same operand
+    // with a windowed einsum loop to consume the output of the loop
+    // and remove the all-gather.
+    // Now that we have processed all loops, we can check if there are any
+    // allgather-dot patterns that we can optimize. We'd want to transform:
+    //                       input
+    //                       /    |
+    //                      /     |
+    //                     AG    windowed loop
+    //                     /
+    //                    /
+    //                   dot
+    // to:
+    //                       input
+    //                       |
+    //                       |
+    //                     windowed loop
+    //                       |
+    //                       |
+    //                      dot
+    // The windowed einsum loop will also be rewritten to output the full input
+    // to be consumed by the dot. This is advantageous since the chained dot can
+    // fully utilize all the resources on the GPU while comm is hidden by the
+    // first collective matmul loop.
+    for (WindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop : all_ag_loops_) {
+      HloInstruction* loop = ag_loop.loop;
+      HloInstruction* ag_operand = nullptr;
+
+      if (Match(dot, m::Dot(m::AllGather(&ag_operand), m::Op())) ||
+          Match(dot, m::Dot(m::Op(), m::AllGather(&ag_operand)))) {
+        HloInstruction* windowed_lhs =
+            loop->mutable_operand(0)->mutable_operand(0);
+        HloInstruction* ag_with_shared_operand = nullptr;
+        if (ag_operand && ag_operand->mutable_operand(0) == windowed_lhs) {
+          ag_with_shared_operand = ag_operand;
+        }
+
+        if (!ag_with_shared_operand) {
+          continue;
+        }
+
+        VLOG(5) << "Found all-gather that shares the same operand with a "
+                   "windowed einsum loop : "
+                << loop->ToString();
+
+        if (!ag_loop.consumed) {
+          TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching(
+              ag_loop, ag_with_shared_operand));
+          ag_loop.consumed = true;
+        }
+        int64_t cache_output_index = dot->operand_index(ag_with_shared_operand);
+        HloComputation* comp = dot->parent();
+        HloInstruction* new_gte =
+            comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+                loop, GetAgActivationCacheIndex(loop) - 1));
+        TF_RETURN_IF_ERROR(
+            dot->ReplaceOperandWith(cache_output_index, new_gte));
+        TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand));
+      }
+    }
+    // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms
+    // to minimize communication overhead. To do this, the original input is
+    // sliced into replica-group-size pieces along the contracting dimension,
+    // and a partial all-to-all+gemm is performed on each slice.
+    HloInstruction* lhs;
+    HloInstruction* rhs;
+    std::vector<xla::ReplicaGroup> replica_groups;
+    TF_ASSIGN_OR_RETURN(bool matched,
+                        MatchA2aGemmWithIntermediateReshapes(dot, &lhs, &rhs));
+    if (matched) {
+      replica_groups = lhs->replica_groups();
+      // We split the a2a+gemm along the contracting dimension into multiple
+      // a2a+gemms and perform partial dots; the partial results are added to
+      // the final output buffer.
+      int64_t group_size = replica_groups[0].replica_ids_size();
+      if (absl::c_find_if(replica_groups, [&](ReplicaGroup& group) {
+            return group.replica_ids_size() != group_size;
+          }) != replica_groups.end()) {
+        VLOG(5) << "All-to-all split groups don't have the same number of "
+                   "replicas.";
+        return absl::OkStatus();
+      }
+
+      // Get the dimension to slice for lhs and rhs; we slice on the
+      // contracting dimensions to calculate partial results.
+      const DotDimensionNumbers& original_dot_dnums =
+          dot->dot_dimension_numbers();
+      const PrecisionConfig& original_precision = dot->precision_config();
+      const auto& lhs_contracting_dims =
+          dot->dot_dimension_numbers().lhs_contracting_dimensions();
+      const auto& rhs_contracting_dims =
+          dot->dot_dimension_numbers().rhs_contracting_dimensions();
+
+      if (lhs_contracting_dims.size() != 1 ||
+          rhs_contracting_dims.size() != 1) {
+        VLOG(5) << "Contracting dimensions have multiple elements, all-to-all "
+                   "sharding will be skipped.";
+        return absl::OkStatus();
+      }
+      int64_t lhs_contracting_dim = lhs_contracting_dims[0];
+      int64_t rhs_contracting_dim = rhs_contracting_dims[0];
+      HloAllToAllInstruction* a2a = DynCast<HloAllToAllInstruction>(lhs);
+      int64_t contracting_dim_value =
+          rhs->shape().dimensions()[rhs_contracting_dim];
+
+      // Each split is sliced out of the input buffer, so we need to determine
+      // the slice sizes and increments.
+      std::vector<int64_t> lhs_slice_sizes(a2a->shape().rank(), 0);
+      std::vector<int64_t> lhs_slice_increments(a2a->shape().rank(), 1);
+      std::vector<int64_t> lhs_slice_max_range(
+          a2a->shape().dimensions().begin(), a2a->shape().dimensions().end());
+
+      std::vector<int64_t> rhs_slice_sizes(rhs->shape().rank(), 0);
+      std::vector<int64_t> rhs_slice_increments(rhs->shape().rank(), 1);
+      std::vector<int64_t> rhs_slice_max_range(
+          rhs->shape().dimensions().begin(), rhs->shape().dimensions().end());
+
+      // Create a zero-valued buffer to hold output.
+      HloInstruction* output_buffer =
+          comp->AddInstruction(HloInstruction::CreateBroadcast(
+              dot->shape(),
+              comp->AddInstruction(HloInstruction::CreateConstant(
+                  LiteralUtil::Zero(dot->shape().element_type()))),
+              {}));
+      HloInstruction* a2a_operand = a2a->mutable_operand(0);
+      if (contracting_dim_value % group_size) {
+        VLOG(5) << absl::StrFormat(
+            "Contracting dimension %d needs to be divisible by group_size %d",
+            contracting_dim_value, group_size);
+        return absl::OkStatus();
+      }
+      int64_t size_per_split = contracting_dim_value / group_size;
+
+      // Each split is sliced out of the input buffer, so we need to determine
+      // the slice sizes and increments.
+      lhs_slice_max_range[lhs_contracting_dim] = size_per_split;
+      rhs_slice_max_range[rhs_contracting_dim] = size_per_split;
+
+      Shape lhs_slice_shape = a2a->shape();
+      Shape rhs_slice_shape = rhs->shape();
+
+      lhs_slice_shape.set_dimensions(lhs_contracting_dim, size_per_split);
+      rhs_slice_shape.set_dimensions(rhs_contracting_dim, size_per_split);
+
+      HloInstruction* lhs_slice;
+      HloInstruction* rhs_slice;
+
+      HloInstruction* partial_result = output_buffer;
+
+      Shape partial_all_to_all_shape = lhs_slice_shape;
+
+      TF_ASSIGN_OR_RETURN(
+          Shape partial_dot_shape,
+          ShapeInference::InferDotOpShape(
+              partial_all_to_all_shape, rhs_slice_shape, original_dot_dnums,
+              /*preferred_element_type=*/std::nullopt));
+      int64_t stream_id = hlo_query::NextChannelId(*a2a->GetModule());
+      for (int64_t i = 0; i < group_size; ++i) {
+        lhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
+            lhs_slice_shape, a2a_operand, lhs_slice_sizes, lhs_slice_max_range,
+            lhs_slice_increments));
+        a2a->SetupDerivedInstruction(lhs_slice);
+        lhs_slice_sizes[lhs_contracting_dim] =
+            lhs_slice_max_range[lhs_contracting_dim];
+        lhs_slice_max_range[lhs_contracting_dim] += size_per_split;
+
+        rhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
+            rhs_slice_shape, rhs, rhs_slice_sizes, rhs_slice_max_range,
+            rhs_slice_increments));
+        a2a->SetupDerivedInstruction(rhs_slice);
+        rhs_slice_sizes[rhs_contracting_dim] =
+            rhs_slice_max_range[rhs_contracting_dim];
+        rhs_slice_max_range[rhs_contracting_dim] += size_per_split;
+
+        HloInstruction* partial_all_to_all =
+            comp->AddInstruction(HloInstruction::CreateAllToAll(
+                partial_all_to_all_shape, {lhs_slice}, a2a->device_list(),
+                false, hlo_query::NextChannelId(*a2a->GetModule()),
+                a2a->split_dimension()));
+        a2a->SetupDerivedInstruction(partial_all_to_all);
+
+        HloInstruction* partial_dot =
+            comp->AddInstruction(HloInstruction::CreateDot(
+                partial_dot_shape, partial_all_to_all, rhs_slice,
+                original_dot_dnums, original_precision));
+        partial_result = comp->AddInstruction(
+            HloInstruction::CreateBinary(partial_dot->shape(), HloOpcode::kAdd,
+                                         partial_dot, partial_result));
+        a2a->SetupDerivedInstruction(partial_result);
+        TF_RETURN_IF_ERROR(
+            UpdateDotAndConsumerConfig(partial_dot, stream_id++));
+      }
+      TF_RETURN_IF_ERROR(ReplaceInstruction(dot, partial_result));
+    }
+    return absl::OkStatus();
+  }
+
+  absl::StatusOr<MatchedGemmA2aResult> MatchGemmA2aWithIntermediateReshapes(
+      HloInstruction* inst) {
+    MatchedGemmA2aResult result;
+    HloAllToAllInstruction* a2a = DynCast<HloAllToAllInstruction>(inst);
+    if (!HasReplicaGroups(a2a) || a2a->constrain_layout() ||
+        a2a->shape().IsTuple()) {
+      return result;
+    }
+    if (Match(a2a, m::AllToAll(m::Dot(&result.producer_gemm, m::Op(&result.lhs),
+                                      m::Op(&result.rhs))
+                                   .WithOneUse()))) {
+      result.matched = true;
+      return result;
+    }
+    std::vector<HloInstruction*> allowed_intermediate_ops(
+        {a2a->mutable_operand(0)});
+
+    HloInstruction* matched_dot = nullptr;
+    // We keep pushing until a condition is unmet or we have found the
+    // producer dot.
+    while (true) {
+      HloInstruction* curr = allowed_intermediate_ops.back();
+      if (ShouldAddToChain(curr)) {
+        allowed_intermediate_ops.insert(allowed_intermediate_ops.end(),
+                                        std::begin(curr->operands()),
+                                        std::end(curr->operands()));
+      } else if (curr->opcode() == HloOpcode::kDot && curr->user_count() == 1) {
+        matched_dot = curr;
+        allowed_intermediate_ops.pop_back();
+        break;
+      } else {
+        return result;
+      }
+    }
+    CHECK(matched_dot != nullptr);
+    // We need to create a new a2a that's a direct consumer of the dot and use
+    // it to replace the original a2a. A new reshape will be added to the
+    // original a2a's output. We first need to determine the new split
+    // dimension after all the reshape ops.
+    int64_t split_dimension = *a2a->split_dimension();
+    for (int64_t i = 0; i < allowed_intermediate_ops.size(); i++) {
+      HloInstruction* current_op = allowed_intermediate_ops[i];
+      if (current_op->opcode() == HloOpcode::kReshape) {
+        std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
+            ShapeUtil::DimensionsUnmodifiedByReshape(
+                current_op->operand(0)->shape(), current_op->shape());
+        auto it = absl::c_find_if(
+            unmodified_dims,
+            [&split_dimension](std::pair<int64_t, int64_t>& dim_pair) {
+              return dim_pair.second == split_dimension;
+            });
+        // The split dimension of the a2a has been modified; we cannot deduce
+        // the new split dim easily, so skip decomposition.
+        if (it == unmodified_dims.end()) {
+          VLOG(5) << "Split dimension of: " << a2a->ToShortString()
+                  << " has been modified by reshapes. Skipping it for "
+                     "decomposition.";
+          return result;
+        }
+        // Assign the new split dim.
+        split_dimension = it->first;
+      } else if (current_op->opcode() == HloOpcode::kTranspose) {
+        const auto& transpose_dims = current_op->dimensions();
+        split_dimension = transpose_dims[split_dimension];
+      }
+    }
+    result.a2a_replacement =
+        matched_dot->parent()->AddInstruction(HloInstruction::CreateAllToAll(
+            matched_dot->shape(), {matched_dot}, a2a->replica_groups(), false,
+            hlo_query::NextChannelId(*matched_dot->GetModule()),
+            split_dimension));
+    TF_RETURN_IF_ERROR(allowed_intermediate_ops.back()->ReplaceOperandWith(
+        0, result.a2a_replacement));
+    inst->SetupDerivedInstruction(result.a2a_replacement);
+
+    TF_RETURN_IF_ERROR(
+        ReplaceInstruction(inst, allowed_intermediate_ops.front()));
+    result.lhs = matched_dot->mutable_operand(0);
+    result.rhs = matched_dot->mutable_operand(1);
+    result.producer_gemm = matched_dot;
+    result.matched = true;
+    return result;
+  }
+
+  // Rewrites a gemm+all-to-all into multiple independent partial gemm+a2as
+  // to minimize communication overhead. To do this, the original input is
+  // sliced into replica-group-size pieces along the contracting dimension,
+  // and a partial gemm+all-to-all is performed on each slice.
+  absl::Status HandleAllToAll(HloInstruction* inst) override {
+    CHECK_EQ(inst->opcode(), HloOpcode::kAllToAll);
+    HloComputation* comp = inst->parent();
+    // Rewrites a gemm+alltoall into multiple independent partial gemm+a2as
+    // to minimize communication overhead.
+    std::vector<xla::ReplicaGroup> replica_groups;
+    TF_ASSIGN_OR_RETURN(MatchedGemmA2aResult matched_result,
+                        MatchGemmA2aWithIntermediateReshapes(inst));
+    if (matched_result.matched) {
+      HloInstruction* a2a = inst;
+      if (matched_result.a2a_replacement) {
+        a2a = matched_result.a2a_replacement;
+      }
+      replica_groups = a2a->replica_groups();
+      // Similar to a2a+gemm, we split along contracting dimensions
+      // and aggregate the result at each step.
+      int64_t group_size = replica_groups[0].replica_ids_size();
+
+      if (absl::c_find_if(replica_groups, [&](ReplicaGroup& group) {
+            return group.replica_ids_size() != group_size;
+          }) != replica_groups.end()) {
+        VLOG(5) << "All-to-all split groups don't have the same number of "
+                   "replicas.";
+        return absl::OkStatus();
+      }
+
+      // Get the dimension to slice for lhs and rhs; we slice on the
+      // contracting dimensions to calculate partial results.
+      const DotDimensionNumbers& original_dot_dnums =
+          matched_result.producer_gemm->dot_dimension_numbers();
+      const PrecisionConfig& original_precision =
+          matched_result.producer_gemm->precision_config();
+      const auto& lhs_contracting_dims =
+          matched_result.producer_gemm->dot_dimension_numbers()
+              .lhs_contracting_dimensions();
+      const auto& rhs_contracting_dims =
+          matched_result.producer_gemm->dot_dimension_numbers()
+              .rhs_contracting_dimensions();
+
+      if (lhs_contracting_dims.size() != 1 ||
+          rhs_contracting_dims.size() != 1) {
+        VLOG(5) << "Contracting dimensions have multiple elements, all-to-all "
+                   "sharding will be skipped.";
+        return absl::OkStatus();
+      }
+      int64_t lhs_contracting_dim = lhs_contracting_dims[0];
+      int64_t rhs_contracting_dim = rhs_contracting_dims[0];
+      HloAllToAllInstruction* all_to_all = DynCast<HloAllToAllInstruction>(a2a);
+      int64_t contracting_dim_value =
+          matched_result.rhs->shape().dimensions()[rhs_contracting_dim];
+      // Each split is sliced out of the input buffer, so we need to determine
+      // the slice sizes and increments.
+      std::vector<int64_t> lhs_slice_sizes(matched_result.lhs->shape().rank(),
+                                           0);
+      std::vector<int64_t> lhs_slice_increments(
+          matched_result.lhs->shape().rank(), 1);
+      std::vector<int64_t> lhs_slice_max_range(
+          matched_result.lhs->shape().dimensions().begin(),
+          matched_result.lhs->shape().dimensions().end());
+
+      std::vector<int64_t> rhs_slice_sizes(matched_result.rhs->shape().rank(),
+                                           0);
+      std::vector<int64_t> rhs_slice_increments(
+          matched_result.rhs->shape().rank(), 1);
+      std::vector<int64_t> rhs_slice_max_range(
+          matched_result.rhs->shape().dimensions().begin(),
+          matched_result.rhs->shape().dimensions().end());
+
+      // Create a zero-valued buffer to hold output.
+      HloInstruction* output_buffer =
+          comp->AddInstruction(HloInstruction::CreateBroadcast(
+              all_to_all->shape(),
+              comp->AddInstruction(HloInstruction::CreateConstant(
+                  LiteralUtil::Zero(all_to_all->shape().element_type()))),
+              {}));
+      if (contracting_dim_value % group_size) {
+        VLOG(5) << absl::StrFormat(
+            "Contracting dimension %d needs to be divisible by group_size %d",
+            contracting_dim_value, group_size);
+        return absl::OkStatus();
+      }
+
+      int64_t size_per_split = contracting_dim_value / group_size;
+      // Each split is sliced out of the input buffer, so we need to determine
+      // the slice sizes and increments.
+      lhs_slice_max_range[lhs_contracting_dim] = size_per_split;
+      rhs_slice_max_range[rhs_contracting_dim] = size_per_split;
+
+      Shape lhs_slice_shape = matched_result.lhs->shape();
+      Shape rhs_slice_shape = matched_result.rhs->shape();
+
+      lhs_slice_shape.set_dimensions(lhs_contracting_dim, size_per_split);
+      rhs_slice_shape.set_dimensions(rhs_contracting_dim, size_per_split);
+
+      HloInstruction* lhs_slice;
+      HloInstruction* rhs_slice;
+
+      HloInstruction* partial_result = output_buffer;
+      Shape partial_all_to_all_shape = all_to_all->shape();
+
+      TF_ASSIGN_OR_RETURN(
+          Shape partial_dot_shape,
+          ShapeInference::InferDotOpShape(
+              lhs_slice_shape, rhs_slice_shape, original_dot_dnums,
+              /*preferred_element_type=*/std::nullopt));
+      int64_t stream_id = hlo_query::NextChannelId(*all_to_all->GetModule());
+      for (int64_t i = 0; i < group_size; ++i) {
+        lhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
+            lhs_slice_shape, matched_result.lhs, lhs_slice_sizes,
+            lhs_slice_max_range, lhs_slice_increments));
+        all_to_all->SetupDerivedInstruction(lhs_slice);
+        lhs_slice_sizes[lhs_contracting_dim] =
+            lhs_slice_max_range[lhs_contracting_dim];
+        lhs_slice_max_range[lhs_contracting_dim] += size_per_split;
+
+        rhs_slice = comp->AddInstruction(HloInstruction::CreateSlice(
+            rhs_slice_shape, matched_result.rhs, rhs_slice_sizes,
+            rhs_slice_max_range, rhs_slice_increments));
+
+        all_to_all->SetupDerivedInstruction(rhs_slice);
+        rhs_slice_sizes[rhs_contracting_dim] =
+            rhs_slice_max_range[rhs_contracting_dim];
+        rhs_slice_max_range[rhs_contracting_dim] += size_per_split;
+
+        HloInstruction* partial_dot = comp->AddInstruction(
+            HloInstruction::CreateDot(partial_dot_shape, lhs_slice, rhs_slice,
+                                      original_dot_dnums, original_precision));
+
+        HloInstruction* partial_all_to_all =
+            comp->AddInstruction(HloInstruction::CreateAllToAll(
+                partial_all_to_all_shape, {partial_dot},
+                all_to_all->device_list(), false,
+                hlo_query::NextChannelId(*all_to_all->GetModule()),
+                all_to_all->split_dimension()));
+        all_to_all->SetupDerivedInstruction(partial_all_to_all);
+        partial_result = comp->AddInstruction(HloInstruction::CreateBinary(
+            partial_all_to_all_shape, HloOpcode::kAdd, partial_all_to_all,
+            partial_result));
+        all_to_all->SetupDerivedInstruction(partial_result);
+        TF_RETURN_IF_ERROR(
+            UpdateDotAndConsumerConfig(partial_dot, stream_id++));
+      }
+      TF_RETURN_IF_ERROR(ReplaceInstruction(all_to_all, partial_result));
+    }
+
+    return absl::OkStatus();
+  }
+
+ private:
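+  // Reference to the handler's collection of all-gather windowed einsum
+  // loops, used when chaining additional consumers onto those loops.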
+  std::vector<WindowedEinsumHandler::WindowedEinsumAgLoops>& all_ag_loops_;
+};
+
+}  // namespace
+
+absl::StatusOr<bool> WindowedEinsumHandler::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  XLA_VLOG_LINES(
+      5, "WindowedEinsumHandler::Run(), before:\n" + module->ToString());
+  bool changed = false;
+  int64_t stream_id = hlo_query::NextChannelId(*module);
+
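+  // First pass: annotate the gemms inside reduce-scatter and all-gather
+  // windowed einsum loops with stream ids, and record the all-gather loops
+  // for the consumer-rewriting pass below.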
+  for (HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    if (comp->name().find(kWindowedEinsumRsLoopName) == 0) {
+      VLOG(5) << "Processing computation: " << comp->name();
+      TF_ASSIGN_OR_RETURN(bool comp_result,
+                          HandleRsWindowedEinsumLoop(comp, stream_id));
+      changed |= comp_result;
+    } else if (comp->name().find(kWindowedEinsumAgLoopName) == 0) {
+      VLOG(5) << "Processing computation: " << comp->name();
+      TF_ASSIGN_OR_RETURN(bool comp_result,
+                          HandleAgWindowedEinsumLoop(comp, stream_id));
+      all_ag_loops_.push_back(
+          WindowedEinsumAgLoops(comp->WhileCallInstruction()));
+      changed |= comp_result;
+    }
+  }
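+
+  // Second pass: run the visitor to rewrite dot+all-to-all patterns and to
+  // chain additional consumers onto the recorded all-gather loops.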
+  for (HloComputation* comp :
+       module->MakeNonfusionComputations(execution_threads)) {
+    WindowedEinsumVisitor visitor(all_ag_loops_);
+    TF_RETURN_IF_ERROR(comp->Accept(&visitor));
+    changed |= visitor.changed();
+  }
+
+  XLA_VLOG_LINES(5,
+                 "WindowedEinsumHandler::Run(), after:\n" + module->ToString());
+  return changed;
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.h b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.h
new file mode 100644
index 0000000..bcc7680e
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.h
@@ -0,0 +1,64 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_WINDOWED_EINSUM_HANDLER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_WINDOWED_EINSUM_HANDLER_H_
+
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+
+namespace xla::gpu {
+
+// This pass targets the windowed einsum optimization in the SPMD pipeline,
+// which rewrites all-gather+gemm or gemm+reduce-scatter into sharded loops
+// so that sharded gemms overlap with communication. On GPU, this pass
+// annotates the independent gemms with stream ids in their backend configs;
+// by running them on different streams, the gemms can also overlap with each
+// other.
+class WindowedEinsumHandler : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "windowed-einsum-handler"; }
+
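+  // Tracks an all-gather windowed einsum while loop and whether the visitor
+  // has already reused (consumed) its result.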
+  struct WindowedEinsumAgLoops {
+    explicit WindowedEinsumAgLoops(HloInstruction* loop) : loop(loop) {}
+    HloInstruction* loop;
+    bool consumed = false;
+  };
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
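+  // Computation-name prefixes of the while-loop bodies produced by the SPMD
+  // partitioner for windowed einsum reduce-scatter and all-gather loops.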
+  constexpr static const char* kWindowedEinsumRsLoopName =
+      "windowed_dot_general_body_rs";
+  constexpr static const char* kWindowedEinsumAgLoopName =
+      "windowed_dot_general_body_ag";
+
+ private:
+  std::vector<WindowedEinsumAgLoops> all_ag_loops_;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_WINDOWED_EINSUM_HANDLER_H_
diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc
new file mode 100644
index 0000000..a8ea1b1
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc
@@ -0,0 +1,918 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/windowed_einsum_handler.h"
+
+#include <memory>
+#include <string>
+
+#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+using WindowedEinsumHandlerTest = HloTestBase;
+
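+// Returns the instruction named `name` in `comp`, or nullptr if none exists.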
+HloInstruction* FindInstructionByName(HloComputation* comp,
+                                      absl::string_view name) {
+  for (HloInstruction* inst : comp->instructions()) {
+    if (inst->name() == name) {
+      return inst;
+    }
+  }
+  return nullptr;
+}
+
+TEST_F(WindowedEinsumHandlerTest, AgLoopsHaveStreamIds) {
+  constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,512,24576]{2,1,0}, bf16[24576,24576]{1,0})->bf16[2048,24576]{1,0}}, num_partitions=4
+
+windowed_dot_general_body_ag.1 {
+  param = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) parameter(0)
+  get-tuple-element = bf16[512,24576]{1,0} get-tuple-element(param), index=0
+  collective-permute = bf16[512,24576]{1,0} collective-permute(get-tuple-element), channel_id=2, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+  get-tuple-element.1 = bf16[24576,24576]{1,0} get-tuple-element(param), index=1
+  get-tuple-element.2 = bf16[2048,24576]{1,0} get-tuple-element(param), index=2
+  dot.2 = bf16[512,24576]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+  constant.1 = s32[4]{0} constant({0, 512, 1024, 1536})
+  get-tuple-element.4 = u32[] get-tuple-element(param), index=4
+  partition-id = u32[] partition-id()
+  add = u32[] add(get-tuple-element.4, partition-id)
+  constant = u32[] constant(4)
+  remainder = u32[] remainder(add, constant)
+  dynamic-slice = s32[1]{0} dynamic-slice(constant.1, remainder), dynamic_slice_sizes={1}
+  reshape.4 = s32[] reshape(dynamic-slice)
+  constant.2 = s32[] constant(0)
+  dynamic-update-slice = bf16[2048,24576]{1,0} dynamic-update-slice(get-tuple-element.2, dot.2, reshape.4, constant.2), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+  dot.3 = bf16[512,24576]{1,0} dot(collective-permute, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  constant.3 = u32[] constant(1)
+  add.1 = u32[] add(get-tuple-element.4, constant.3)
+  add.2 = u32[] add(add.1, partition-id)
+  remainder.1 = u32[] remainder(add.2, constant)
+  dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.1, remainder.1), dynamic_slice_sizes={1}
+  reshape.5 = s32[] reshape(dynamic-slice.1)
+  dynamic-update-slice.1 = bf16[2048,24576]{1,0} dynamic-update-slice(dynamic-update-slice, dot.3, reshape.5, constant.2)
+  get-tuple-element.3 = bf16[2048,24576]{1,0} get-tuple-element(param), index=3
+  add.3 = u32[] add(add.1, constant.3)
+  ROOT tuple = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) tuple(collective-permute, get-tuple-element.1, dynamic-update-slice.1, get-tuple-element.3, add.3)
+} // windowed_dot_general_body_ag.1
+
+windowed_dot_general_cond_ag {
+  param.1 = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) parameter(0)
+  get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
+  constant.8 = u32[] constant(4)
+  ROOT compare = pred[] compare(get-tuple-element.5, constant.8), direction=LT
+}
+
+ENTRY test_main {
+  param.4 = bf16[1,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
+  reshape.8 = bf16[512,24576]{1,0} reshape(param.4)
+  param.5 = bf16[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
+  constant.18 = bf16[] constant(0)
+  broadcast = bf16[2048,24576]{1,0} broadcast(constant.18), dimensions={}
+  constant.20 = u32[] constant(0)
+  tuple.2 = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) tuple(reshape.8, param.5, broadcast, broadcast, constant.20)
+  while = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag.1
+  ROOT get-tuple-element.13 = bf16[2048,24576]{1,0} get-tuple-element(while), index=2
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  WindowedEinsumHandler gpu_handler;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, gpu_handler.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  HloInstruction* ag_loop =
+      module->entry_computation()->root_instruction()->mutable_operand(0);
+  HloComputation* ag_loop_body = ag_loop->while_body();
+  HloInstruction* inst = FindInstructionByName(ag_loop_body, "dot.2");
+  EXPECT_GT(inst->backend_config<GpuBackendConfig>()->operation_queue_id(), 0);
+  EXPECT_TRUE(
+      inst->backend_config<GpuBackendConfig>()->force_earliest_schedule());
+
+  HloInstruction* cp1 =
+      FindInstructionByName(ag_loop_body, "collective-permute");
+  EXPECT_TRUE(
+      cp1->backend_config<GpuBackendConfig>()->force_earliest_schedule());
+}
+
+TEST_F(WindowedEinsumHandlerTest, RsLoopsHaveStreamIds) {
+  constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[2048,24576]{1,0})->bf16[512,24576]{1,0}}, num_partitions=4
+
+windowed_dot_general_body_rs_clone.1 {
+  param.2 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) parameter(0)
+  get-tuple-element.6 = bf16[2048,24576]{1,0} get-tuple-element(param.2), index=0
+  get-tuple-element.7 = bf16[24576,24576]{1,0} get-tuple-element(param.2), index=1
+  get-tuple-element.9 = bf16[512,24576]{1,0} get-tuple-element(param.2), index=2
+  collective-permute.1 = bf16[512,24576]{1,0} collective-permute(get-tuple-element.9), channel_id=4, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+  constant.10 = s32[4]{0} constant({0, 512, 1024, 1536})
+  get-tuple-element.11 = u32[] get-tuple-element(param.2), index=4
+  constant.12 = u32[] constant(2)
+  add.8 = u32[] add(get-tuple-element.11, constant.12)
+  constant.13 = u32[] constant(1)
+  add.9 = u32[] add(add.8, constant.13)
+  partition-id.3 = u32[] partition-id()
+  add.10 = u32[] add(add.9, partition-id.3)
+  constant.9 = u32[] constant(4)
+  remainder.3 = u32[] remainder(add.10, constant.9)
+  dynamic-slice.4 = s32[1]{0} dynamic-slice(constant.10, remainder.3), dynamic_slice_sizes={1}
+  reshape.7 = s32[] reshape(dynamic-slice.4)
+  constant.11 = s32[] constant(0)
+  dynamic-slice.5 = bf16[512,24576]{1,0} dynamic-slice(get-tuple-element.6, reshape.7, constant.11), dynamic_slice_sizes={512,24576}
+  dot.7 = bf16[512,24576]{1,0} dot(dynamic-slice.5, get-tuple-element.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+  add.11 = bf16[512,24576]{1,0} add(collective-permute.1, dot.7), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+  get-tuple-element.10 = bf16[512,24576]{1,0} get-tuple-element(param.2), index=3
+  add.6 = u32[] add(get-tuple-element.11, partition-id.3)
+  remainder.2 = u32[] remainder(add.6, constant.9)
+  dynamic-slice.2 = s32[1]{0} dynamic-slice(constant.10, remainder.2), dynamic_slice_sizes={1}
+  reshape.6 = s32[] reshape(dynamic-slice.2)
+  dynamic-slice.3 = bf16[512,24576]{1,0} dynamic-slice(get-tuple-element.6, reshape.6, constant.11), dynamic_slice_sizes={512,24576}
+  dot.5 = bf16[512,24576]{1,0} dot(dynamic-slice.3, get-tuple-element.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+  add.7 = bf16[512,24576]{1,0} add(get-tuple-element.10, dot.5), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]}
+  collective-permute.2 = bf16[512,24576]{1,0} collective-permute(add.7), channel_id=5, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}
+  ROOT tuple.1 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) tuple(get-tuple-element.6, get-tuple-element.7, add.11, collective-permute.2, add.8)
+}
+
+windowed_dot_general_cond_rs {
+  param.3 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) parameter(0)
+  get-tuple-element.12 = u32[] get-tuple-element(param.3), index=4
+  constant.17 = u32[] constant(4)
+  ROOT compare.1 = pred[] compare(get-tuple-element.12, constant.17), direction=LT
+}
+
+ENTRY main.9_spmd {
+  param.6 = bf16[24576,24576]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
+  param.7 = bf16[512,24576]{1,0} parameter(1)
+  param.8 = bf16[2048,24576]{1,0} parameter(2)
+  constant.20 = u32[] constant(0)
+  tuple.3 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) tuple(param.8, param.6, param.7, param.7, constant.20)
+  while.1 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) while(tuple.3), condition=windowed_dot_general_cond_rs, body=windowed_dot_general_body_rs_clone.1
+  ROOT get-tuple-element.14 = bf16[512,24576]{1,0} get-tuple-element(while.1), index=2
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  WindowedEinsumHandler gpu_handler;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, gpu_handler.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  HloInstruction* rs_loop =
+      module->entry_computation()->root_instruction()->mutable_operand(0);
+  HloComputation* rs_loop_body = rs_loop->while_body();
+  HloInstruction* inst = FindInstructionByName(rs_loop_body, "dot.7");
+  EXPECT_GT(inst->backend_config<GpuBackendConfig>()->operation_queue_id(), 0);
+
+  HloInstruction* cp1 =
+      FindInstructionByName(rs_loop_body, "collective-permute.1");
+  EXPECT_TRUE(
+      cp1->backend_config<GpuBackendConfig>()->force_earliest_schedule());
+}
+
+TEST_F(WindowedEinsumHandlerTest, AgLoopsMultipleConsumersAreChained) {
+  constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[24576,24576]{1,0})->bf16[2,2048,24576]{2,1,0}}, num_partitions=4
+
+windowed_dot_general_body_ag {
+  param.1 = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) parameter(0)
+  get-tuple-element.1 = bf16[2,512,24576]{2,1,0} get-tuple-element(param.1), index=0
+  collective-permute = bf16[2,512,24576]{2,1,0} collective-permute(get-tuple-element.1), channel_id=2, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
+  collective-permute.1 = bf16[2,512,24576]{2,1,0} collective-permute(collective-permute), channel_id=3, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
+  get-tuple-element.2 = bf16[24576,24576]{1,0} get-tuple-element(param.1), index=1
+  get-tuple-element.3 = bf16[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=2
+  dot = bf16[2,512,24576]{2,1,0} dot(get-tuple-element.1, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  constant.2 = s32[] constant(0)
+  constant.3 = s32[4]{0} constant({0, 512, 1024, 1536})
+  get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
+  partition-id = u32[] partition-id()
+  add = u32[] add(get-tuple-element.5, partition-id)
+  constant.1 = u32[] constant(4)
+  remainder = u32[] remainder(add, constant.1)
+  dynamic-slice = s32[1]{0} dynamic-slice(constant.3, remainder), dynamic_slice_sizes={1}
+  reshape = s32[] reshape(dynamic-slice)
+  dynamic-update-slice = bf16[2,2048,24576]{2,1,0} dynamic-update-slice(get-tuple-element.3, dot, constant.2, reshape, constant.2)
+  dot.1 = bf16[2,512,24576]{2,1,0} dot(collective-permute, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  constant.5 = u32[] constant(1)
+  add.1 = u32[] add(get-tuple-element.5, constant.5)
+  add.2 = u32[] add(add.1, partition-id)
+  remainder.1 = u32[] remainder(add.2, constant.1)
+  dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.3, remainder.1), dynamic_slice_sizes={1}
+  reshape.1 = s32[] reshape(dynamic-slice.1)
+  dynamic-update-slice.1 = bf16[2,2048,24576]{2,1,0} dynamic-update-slice(dynamic-update-slice, dot.1, constant.2, reshape.1, constant.2)
+  get-tuple-element.4 = bf16[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=3
+  add.3 = u32[] add(add.1, constant.5)
+  ROOT tuple = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) tuple(collective-permute.1, get-tuple-element.2, dynamic-update-slice.1, get-tuple-element.4, add.3)
+} // windowed_dot_general_body_ag
+
+windowed_dot_general_cond_ag {
+  param = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) parameter(0)
+  get-tuple-element = u32[] get-tuple-element(param), index=4
+  constant = u32[] constant(4)
+  ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT
+}
+
+ENTRY main.12_spmd {
+  param.4 = bf16[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
+  param.5 = bf16[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
+  constant.22 = bf16[] constant(0)
+  broadcast = bf16[2,2048,24576]{2,1,0} broadcast(constant.22), dimensions={}
+  constant.24 = u32[] constant(0)
+  tuple.2 = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24)
+  while = (bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[2,2048,24576]{2,1,0}, bf16[2,2048,24576]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
+  get-tuple-element.13 = bf16[2,2048,24576]{2,1,0} get-tuple-element(while), index=2
+  copy.1 = bf16[2,2048,24576]{2,1,0} copy(get-tuple-element.13)
+  all-gather = bf16[2,2048,24576]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true
+  param.6 = bf16[24576,24576]{1,0} parameter(2), sharding={devices=[1,4]<=[4]}
+  ROOT dot.7 = bf16[2,2048,24576]{2,1,0} dot(all-gather, param.6), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  WindowedEinsumHandler gpu_handler;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, gpu_handler.Run(module.get()));
+  EXPECT_TRUE(changed);
+  HloInstruction* ag_loop =
+      FindInstructionByName(module->entry_computation(), "while");
+  HloInstruction* inst =
+      FindInstructionByName(module->entry_computation(), "dot.7");
+  // dot.7 should now consume output of the windowed einsum while loop.
+  EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement);
+  EXPECT_EQ(inst->operand(0)->tuple_index(), 5);
+  EXPECT_EQ(inst->operand(0)->operand(0), ag_loop);
+
+  // The while loop's root should now contain a chain of dynamic-update-slices.
+  HloInstruction* ag_while_root = ag_loop->while_body()->root_instruction();
+  EXPECT_THAT(ag_while_root,
+              GmockMatch(m::Tuple(
+                  m::Op(), m::Op(), m::Op(), m::Op(), m::Op(),
+                  m::DynamicUpdateSlice(
+                      m::DynamicUpdateSlice(
+                          m::GetTupleElement(m::Parameter())
+                              .WithPredicate([](const HloInstruction* instr) {
+                                return instr->tuple_index() == 5;
+                              }),
+                          m::Op(), m::Op(), m::Op(), m::Op()),
+                      m::Op(), m::Op(), m::Op(), m::Op()))));
+}
+
+TEST_F(WindowedEinsumHandlerTest, A2aGemmHaveStreamIds) {
+  constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,4,2048,8192]{3,2,1,0})->bf16[1,4,2048,32768]{3,2,1,0}}, num_partitions=8
+
+ENTRY main.9_spmd {
+  param0 = bf16[1,8192,32768]{2,1,0} parameter(0)
+  param1 = bf16[1,4,2048,8192]{3,2,1,0} parameter(1)
+  all-to-all = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(param1), channel_id=4, replica_groups={{0,1,2,3},{4,5,6,7}}, dimensions={1}
+  ROOT dot.12 = bf16[1,4,2048,32768]{3,2,1,0} dot(all-to-all, param0), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+}
+)";
+
+  const char* kExpected = R"(
+CHECK: ENTRY
+CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} parameter(1)
+
+CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [6144:8192]}
+CHECK: %[[A2A0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE0]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3},{4,5,6,7}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0)
+CHECK-DAG: %[[SLICE4:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [6144:8192], [0:32768]}
+CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A0:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false}
+
+CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [4096:6144]}
+CHECK: %[[A2A1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE1]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3},{4,5,6,7}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[SLICE5:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [4096:6144], [0:32768]}
+CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A1:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false}
+
+CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [2048:4096]}
+CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE2]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3},{4,5,6,7}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[SLICE6:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [2048:4096], [0:32768]}
+CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A2:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false}
+
+CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:2048]}
+CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE3]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3},{4,5,6,7}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[SLICE7:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:2048], [0:32768]}
+CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A3:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"5","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
+CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,32768]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
+CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT0:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[BROADCAST:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["5"],"force_earliest_schedule":false}
+CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT1:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD0:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["6"],"force_earliest_schedule":false}
+CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT2:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD1:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["7"],"force_earliest_schedule":false}
+
+CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT3:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD2:.*]])
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  WindowedEinsumHandler gpu_handler;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, gpu_handler.Run(module.get()));
+  EXPECT_TRUE(changed);
+  TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
+                          RunFileCheck(module->ToString(), kExpected));
+  EXPECT_TRUE(filecheck_matched);
+}
+
+TEST_F(WindowedEinsumHandlerTest, GemmA2aHaveStreamIds) {
+  constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,4,2048,32768]{3,2,1,0})->bf16[1,4,2048,8192]{3,2,1,0}}, num_partitions=4
+
+ENTRY main.9_spmd {
+  param.9 = bf16[1,8192,32768]{2,1,0} parameter(0)
+  param.10 = bf16[1,4,2048,32768]{3,2,1,0} parameter(1)
+  dot.12 = bf16[1,4,2048,8192]{3,2,1,0} dot(param.10, param.9), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}
+  ROOT all-to-all = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(dot.12), channel_id=4, replica_groups={{0,1,2,3}}, dimensions={1}
+}
+)";
+
+  const char* kExpected = R"(
+CHECK: ENTRY
+CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} parameter(1)
+
+CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [24576:32768]}
+CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0)
+CHECK-DAG: %[[SLICE4:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [24576:32768]}
+CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE0:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT0:.*]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+
+CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [16384:24576]}
+CHECK-DAG: %[[SLICE5:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [16384:24576]}
+CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE1:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT1:.*]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+
+CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [8192:16384]}
+CHECK-DAG: %[[SLICE6:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [8192:16384]}
+CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE2:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT2:.*]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+
+CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:8192]}
+CHECK-DAG: %[[SLICE7:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [0:8192]}
+CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE3:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"5","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT3:.*]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
+CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,8192]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
+CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A0:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[BROADCAST:.*]])
+CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A1:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD0:.*]])
+CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A2:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD1:.*]])
+
+CHECK: ROOT {{.*}} = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A3:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD2:.*]])
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  WindowedEinsumHandler gpu_handler;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, gpu_handler.Run(module.get()));
+  EXPECT_TRUE(changed);
+  TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
+                          RunFileCheck(module->ToString(), kExpected));
+  EXPECT_TRUE(filecheck_matched);
+}
+
+TEST_F(WindowedEinsumHandlerTest, A2aTransposeLoopsHaveStreamIds) {
+  constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0})->bf16[1,4,2048,32768]{3,2,1,0}}, num_partitions=4
+
+ENTRY main.9_spmd {
+  param.9 = bf16[1,8192,32768]{2,1,0} parameter(0)
+  param.10 = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} parameter(1)
+  all-to-all = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} all-to-all(param.10), channel_id=4, replica_groups={{0,1,2,3}}, dimensions={3}
+  transpose.15 = bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} transpose(all-to-all), dimensions={0,3,1,2,4,5}
+  reshape.2170 = bf16[1,4,8192,1,2048]{4,3,2,1,0} reshape(transpose.15)
+  reshape.2173 = bf16[4,8192,1,2048]{3,2,1,0} reshape(reshape.2170)
+  transpose.16 = bf16[1,4,2048,8192]{2,0,3,1} transpose(reshape.2173), dimensions={2,0,3,1}
+  copy.53 = bf16[1,4,2048,8192]{3,2,1,0} copy(transpose.16)
+  ROOT dot.12 = bf16[1,4,2048,32768]{3,2,1,0} dot(copy.53, param.9), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+}
+)";
+
+  const char* kExpected = R"(
+CHECK: ENTRY
+CHECK-DAG: %[[P1:.*]] = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} parameter(1)
+CHECK-DAG: %[[TRANSPOSE0:.*]] = bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} transpose(bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} %[[P1:.*]]), dimensions={0,3,1,2,4,5}
+CHECK-DAG: %[[RESHAPE0:.*]] = bf16[1,4,8192,1,2048]{4,3,2,1,0} reshape(bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} %[[TRANSPOSE0:.*]])
+CHECK-DAG: %[[RESHAPE1:.*]] = bf16[4,8192,1,2048]{3,2,1,0} reshape(bf16[1,4,8192,1,2048]{4,3,2,1,0} %[[RESHAPE0:.*]])
+CHECK-DAG: %[[TRANSPOSE1:.*]] = bf16[1,4,2048,8192]{2,0,3,1} transpose(bf16[4,8192,1,2048]{3,2,1,0} %[[RESHAPE1:.*]]), dimensions={2,0,3,1}
+CHECK-DAG: %[[COPY:.*]] = bf16[1,4,2048,8192]{3,2,1,0} copy(bf16[1,4,2048,8192]{2,0,3,1} %[[TRANSPOSE1:.*]])
+
+CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [6144:8192]}
+CHECK: %[[A2A0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE0]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0)
+CHECK-DAG: %[[SLICE4:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [6144:8192], [0:32768]}
+CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A0:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"9","wait_on_operation_queues":[],"force_earliest_schedule":false}
+
+CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [4096:6144]}
+CHECK: %[[A2A1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE1]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[SLICE5:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [4096:6144], [0:32768]}
+CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A1:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false}
+
+CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [2048:4096]}
+CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE2]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[SLICE6:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [2048:4096], [0:32768]}
+CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A2:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false}
+
+CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [0:2048]}
+CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE3]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[SLICE7:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:2048], [0:32768]}
+CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A3:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
+CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,32768]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
+CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT0:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[BROADCAST:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["6"],"force_earliest_schedule":false}
+CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT1:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD0:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["7"],"force_earliest_schedule":false}
+CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT2:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD1:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["8"],"force_earliest_schedule":false}
+
+CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT3:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD2:.*]])
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  WindowedEinsumHandler gpu_handler;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, gpu_handler.Run(module.get()));
+  EXPECT_TRUE(changed);
+  TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
+                          RunFileCheck(module->ToString(), kExpected));
+  EXPECT_TRUE(filecheck_matched);
+}
+
+TEST_F(WindowedEinsumHandlerTest, GemmA2aTransposeLoopsHaveStreamIds) {
+  constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,4,2048,32768]{3,2,1,0}, bf16[1,32768,8192]{2,1,0})->bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0}}, num_partitions=4
+
+ENTRY main.9_spmd {
+  param.9 = bf16[1,4,2048,32768]{3,2,1,0} parameter(0)
+  param.10 = bf16[1,32768,8192]{2,1,0} parameter(1)
+  dot.13 = bf16[1,4,2048,8192]{3,2,1,0} dot(param.9, param.10), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+  copy.55 = bf16[1,4,2048,8192]{3,2,1,0} copy(dot.13)
+  transpose.17 = bf16[4,1,2048,8192]{3,2,0,1} transpose(copy.55), dimensions={1,0,2,3}
+  copy.56 = bf16[4,1,2048,8192]{3,2,1,0} copy(transpose.17)
+  reshape.2216 = bf16[1,4,1,2048,8192]{4,3,2,1,0} reshape(copy.56)
+  reshape.2219 = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(reshape.2216)
+  ROOT all-to-all.1 = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} all-to-all(reshape.2219), channel_id=7, replica_groups={{0,1,2,3}}, dimensions={1}
+}
+)";
+
+  const char* kExpected = R"(
+CHECK: ENTRY
+CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} parameter(0)
+
+CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [24576:32768]}
+CHECK-DAG: %[[P0:.*]] = bf16[1,32768,8192]{2,1,0} parameter(1)
+CHECK-DAG: %[[SLICE4:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [24576:32768], [0:8192]}
+CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE0:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"12","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT0:.*]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+
+CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [16384:24576]}
+CHECK-DAG: %[[SLICE5:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [16384:24576], [0:8192]}
+CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE1:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"11","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT1:.*]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+
+CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [8192:16384]}
+CHECK-DAG: %[[SLICE6:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [8192:16384], [0:8192]}
+CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE2:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"10","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT2:.*]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+
+CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:8192]}
+CHECK-DAG: %[[SLICE7:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [0:8192]}
+CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE3:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"9","wait_on_operation_queues":[],"force_earliest_schedule":false}
+CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT3:.*]]),
+CHECK: replica_groups={
+CHECK:     {0,1,2,3}
+CHECK: }
+CHECK: dimensions={1}
+CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0)
+CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,8192]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={}
+CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A0:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[BROADCAST:.*]])
+CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A1:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD0:.*]])
+CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A2:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD1:.*]])
+CHECK-DAG: %[[ADD3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A3:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD2:.*]])
+
+CHECK-DAG: %[[COPY:.*]] = bf16[1,4,2048,8192]{3,2,1,0} copy(bf16[1,4,2048,8192]{3,2,1,0} %[[ADD3:.*]])
+CHECK-DAG: %[[TRANSPOSE0:.*]] = bf16[4,1,2048,8192]{3,2,0,1} transpose(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), dimensions={1,0,2,3}
+CHECK-DAG: %[[COPY1:.*]] = bf16[4,1,2048,8192]{3,2,1,0} copy(bf16[4,1,2048,8192]{3,2,0,1} %[[TRANSPOSE0:.*]])
+CHECK-DAG: %[[RESHAPE0:.*]] = bf16[1,4,1,2048,8192]{4,3,2,1,0} reshape(bf16[4,1,2048,8192]{3,2,1,0} %[[COPY1:.*]])
+
+CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,2048,8192]{4,3,2,1,0} %[[RESHAPE0:.*]])
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  WindowedEinsumHandler gpu_handler;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, gpu_handler.Run(module.get()));
+  EXPECT_TRUE(changed);
+  TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
+                          RunFileCheck(module->ToString(), kExpected));
+  EXPECT_TRUE(filecheck_matched);
+}
+
+TEST_F(WindowedEinsumHandlerTest, AllGatherF8) {
+  constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4
+
+windowed_dot_general_body_ag {
+  param.1 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0)
+  get-tuple-element.1 = f32[2,512,24576]{2,1,0} get-tuple-element(param.1), index=0
+  collective-permute = f32[2,512,24576]{2,1,0} collective-permute(get-tuple-element.1), channel_id=4, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
+  collective-permute.1 = f32[2,512,24576]{2,1,0} collective-permute(collective-permute), channel_id=5, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}
+  get-tuple-element.2 = f32[24576,24576]{1,0} get-tuple-element(param.1), index=1
+  get-tuple-element.3 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=2
+  dot = f32[2,512,24576]{2,1,0} dot(get-tuple-element.1, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  constant.12 = s32[] constant(0)
+  constant.13 = s32[4]{0} constant({0, 512, 1024, 1536})
+  get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4
+  partition-id = u32[] partition-id()
+  add = u32[] add(get-tuple-element.5, partition-id)
+  constant.11 = u32[] constant(4)
+  remainder = u32[] remainder(add, constant.11)
+  dynamic-slice = s32[1]{0} dynamic-slice(constant.13, remainder), dynamic_slice_sizes={1}
+  reshape = s32[] reshape(dynamic-slice)
+  dynamic-update-slice = f32[2,2048,24576]{2,1,0} dynamic-update-slice(get-tuple-element.3, dot, constant.12, reshape, constant.12)
+  dot.1 = f32[2,512,24576]{2,1,0} dot(collective-permute, get-tuple-element.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  constant.15 = u32[] constant(1)
+  add.1 = u32[] add(get-tuple-element.5, constant.15)
+  add.2 = u32[] add(add.1, partition-id)
+  remainder.1 = u32[] remainder(add.2, constant.11)
+  dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.13, remainder.1), dynamic_slice_sizes={1}
+  reshape.1 = s32[] reshape(dynamic-slice.1)
+  dynamic-update-slice.1 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(dynamic-update-slice, dot.1, constant.12, reshape.1, constant.12)
+  get-tuple-element.4 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=3
+  add.3 = u32[] add(add.1, constant.15)
+  ROOT tuple = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(collective-permute.1, get-tuple-element.2, dynamic-update-slice.1, get-tuple-element.4, add.3)
+} // windowed_dot_general_body_ag
+
+windowed_dot_general_cond_ag {
+  param = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0)
+  get-tuple-element = u32[] get-tuple-element(param), index=4
+  constant.10 = u32[] constant(4)
+  ROOT compare = pred[] compare(get-tuple-element, constant.10), direction=LT
+}
+
+ENTRY test_main {
+  param.4 = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
+  reshape.8 = f8e4m3fn[2,512,24576]{2,1,0} reshape(param.4)
+  param.5 = f8e4m3fn[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
+  constant.18 = f32[] constant(0)
+  broadcast = f32[2,2048,24576]{2,1,0} broadcast(constant.18), dimensions={}
+  constant.20 = u32[] constant(0)
+  scale_lhs = f32[] parameter(2)
+  scale_lhs_bcast = f32[2,512,24576]{2,1,0} broadcast(scale_lhs), dimensions={}
+  lhs_f32 = f32[2,512,24576]{2,1,0} convert(reshape.8)
+  lhs_scaled = f32[2,512,24576]{2,1,0} multiply(lhs_f32, scale_lhs_bcast)
+  scale_rhs = f32[] parameter(3)
+  scale_rhs_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs), dimensions={}
+  rhs_f32 = f32[24576,24576]{1,0} convert(param.5)
+  rhs_scaled = f32[24576,24576]{1,0} multiply(rhs_f32, scale_rhs_bcast)
+  tuple.2 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_scaled, broadcast, broadcast, constant.20)
+  while = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
+  ROOT get-tuple-element.13 = f32[2,2048,24576]{2,1,0} get-tuple-element(while), index=2
+}
+)";
+
+  RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(),
+                            R"(
+; CHECK-LABEL: windowed_dot_general_body_ag
+; CHECK-NEXT:    [[P0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0)
+; CHECK-NEXT:    [[GTE0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=0
+; CHECK-NEXT:    [[CP0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[GTE0]]), channel_id=4
+; CHECK-NEXT:    [[CP1:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[CP0]]), channel_id=5
+; CHECK-NEXT:    [[GTE1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[P0]]), index=1
+; CHECK-NEXT:    [[GTE2:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=2
+; CHECK-NEXT:    [[CONVERT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[GTE0]])
+; CHECK-NEXT:    [[GTE3:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=5
+; CHECK-NEXT:    [[BCAST0:%[^ ]+]] = f32[2,512,24576]{2,1,0} broadcast([[GTE3]]), dimensions={}
+; CHECK-NEXT:    [[MUL0:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT0]], [[BCAST0]])
+; CHECK-NEXT:    [[CONVERT1:%[^ ]+]] = f32[24576,24576]{1,0} convert([[GTE1]])
+; CHECK-NEXT:    [[GTE4:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=6
+; CHECK-NEXT:    [[BCAST1:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[GTE4]]), dimensions={}
+; CHECK-NEXT:    [[MUL1:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[CONVERT1]], [[BCAST1]])
+; CHECK-NEXT:    [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL0]], [[MUL1]]),
+; CHECK-DAG:       lhs_contracting_dims={2},
+; CHECK-DAG:       rhs_contracting_dims={0},
+; CHECK-DAG:       backend_config={
+; CHECK-DAG:         "operation_queue_id":"[[OPQUEUEID:[0-9]+]]",
+; CHECK-DAG:         "wait_on_operation_queues":[],
+; CHECK-DAG:         "force_earliest_schedule":true}
+; CHECK-NEXT:    [[C0:%[^ ]+]] = s32[] constant(0)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = s32[4]{0} constant({0, 512, 1024, 1536})
+; CHECK-NEXT:    [[GTE5:%[^ ]+]] = u32[] get-tuple-element([[P0]]), index=4
+; CHECK-NEXT:    [[PID:%[^ ]+]] = u32[] partition-id()
+; CHECK-NEXT:    [[ADD0:%[^ ]+]] = u32[] add([[GTE5]], [[PID]])
+; CHECK-NEXT:    [[C2:%[^ ]+]] = u32[] constant(4)
+; CHECK-NEXT:    [[REM0:%[^ ]+]] = u32[] remainder([[ADD0]], [[C2]])
+; CHECK-NEXT:    [[DSLICE0:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM0]]), dynamic_slice_sizes={1}
+; CHECK-NEXT:    [[RESHAPE0:%[^ ]+]] = s32[] reshape([[DSLICE0]])
+; CHECK-NEXT:    [[DUPDATESLICE0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[GTE2]], [[DOT0]], [[C0]], [[RESHAPE0]], [[C0]]),
+; CHECK-DAG:       backend_config={
+; CHECK-DAG:         "operation_queue_id":"0",
+; CHECK-DAG:         "wait_on_operation_queues":["[[OPQUEUEID]]"],
+; CHECK-DAG:         "force_earliest_schedule":false}
+; CHECK-NEXT:    [[CONVERT2:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[CP0]])
+; CHECK-NEXT:    [[MUL2:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT2]], [[BCAST0]])
+; CHECK-NEXT:    [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL2]], [[MUL1]]),
+; CHECK-DAG:       lhs_contracting_dims={2},
+; CHECK-DAG:       rhs_contracting_dims={0}
+; CHECK-NEXT:    [[C3:%[^ ]+]] = u32[] constant(1)
+; CHECK-NEXT:    [[ADD1:%[^ ]+]] = u32[] add([[GTE5]], [[C3]])
+; CHECK-NEXT:    [[ADD2:%[^ ]+]] = u32[] add([[ADD1]], [[PID]])
+; CHECK-NEXT:    [[REM1:%[^ ]+]] = u32[] remainder([[ADD2]], [[C2]])
+; CHECK-NEXT:    [[DSLICE1:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM1]]), dynamic_slice_sizes={1}
+; CHECK-NEXT:    [[RESHAPE1:%[^ ]+]] = s32[] reshape([[DSLICE1]])
+; CHECK-NEXT:    [[DUPDATESLICE1:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[DUPDATESLICE0]], [[DOT1]], [[C0]], [[RESHAPE1]], [[C0]])
+; CHECK-NEXT:    [[GTE6:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=3
+; CHECK-NEXT:    [[ADD3:%[^ ]+]] = u32[] add([[ADD1]], [[C3]])
+; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[CP1]], [[GTE1]], [[DUPDATESLICE1]], [[GTE6]], [[ADD3]], /*index=5*/[[GTE3]], [[GTE4]])
+)");
+}
+
+TEST_F(WindowedEinsumHandlerTest, ReduceScatterF8) {
+  constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f8e4m3fn[2,2048,24576]{2,1,0}, f32[], f32[])->f32[2,512,24576]{2,1,0}}, num_partitions=4
+
+windowed_dot_general_body_rs {
+  param.3 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) parameter(0)
+  get-tuple-element.7 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.3), index=0
+  get-tuple-element.8 = f32[24576,24576]{1,0} get-tuple-element(param.3), index=1
+  get-tuple-element.9 = f32[2,512,24576]{2,1,0} get-tuple-element(param.3), index=2
+  collective-permute.2 = f32[2,512,24576]{2,1,0} collective-permute(get-tuple-element.9), channel_id=9, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}
+  constant.23 = s32[] constant(0)
+  constant.24 = s32[4]{0} constant({0, 512, 1024, 1536})
+  get-tuple-element.11 = u32[] get-tuple-element(param.3), index=4
+  constant.26 = u32[] constant(2)
+  add.8 = u32[] add(get-tuple-element.11, constant.26)
+  constant.27 = u32[] constant(1)
+  add.9 = u32[] add(add.8, constant.27)
+  partition-id.3 = u32[] partition-id()
+  add.10 = u32[] add(add.9, partition-id.3)
+  constant.22 = u32[] constant(4)
+  remainder.3 = u32[] remainder(add.10, constant.22)
+  dynamic-slice.4 = s32[1]{0} dynamic-slice(constant.24, remainder.3), dynamic_slice_sizes={1}
+  reshape.3 = s32[] reshape(dynamic-slice.4)
+  dynamic-slice.5 = f32[2,512,24576]{2,1,0} dynamic-slice(get-tuple-element.7, constant.23, reshape.3, constant.23), dynamic_slice_sizes={2,512,24576}
+  dot.3 = f32[2,512,24576]{2,1,0} dot(dynamic-slice.5, get-tuple-element.8), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  add.11 = f32[2,512,24576]{2,1,0} add(collective-permute.2, dot.3)
+  get-tuple-element.10 = f32[2,512,24576]{2,1,0} get-tuple-element(param.3), index=3
+  add.6 = u32[] add(get-tuple-element.11, partition-id.3)
+  remainder.2 = u32[] remainder(add.6, constant.22)
+  dynamic-slice.2 = s32[1]{0} dynamic-slice(constant.24, remainder.2), dynamic_slice_sizes={1}
+  reshape.2 = s32[] reshape(dynamic-slice.2)
+  dynamic-slice.3 = f32[2,512,24576]{2,1,0} dynamic-slice(get-tuple-element.7, constant.23, reshape.2, constant.23), dynamic_slice_sizes={2,512,24576}
+  dot.2 = f32[2,512,24576]{2,1,0} dot(dynamic-slice.3, get-tuple-element.8), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  add.7 = f32[2,512,24576]{2,1,0} add(get-tuple-element.10, dot.2)
+  collective-permute.3 = f32[2,512,24576]{2,1,0} collective-permute(add.7), channel_id=10, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}
+  ROOT tuple.1 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) tuple(get-tuple-element.7, get-tuple-element.8, add.11, collective-permute.3, add.8)
+} // windowed_dot_general_body_rs
+
+windowed_dot_general_cond_rs {
+  param.2 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) parameter(0)
+  get-tuple-element.6 = u32[] get-tuple-element(param.2), index=4
+  constant.21 = u32[] constant(4)
+  ROOT compare.1 = pred[] compare(get-tuple-element.6, constant.21), direction=LT
+}
+
+ENTRY main.9_spmd {
+  param.6 = f8e4m3fn[24576,24576]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
+  param.7 = f32[2,512,24576]{2,1,0} parameter(1)
+  param.8 = f8e4m3fn[2,2048,24576]{2,1,0} parameter(2)
+  constant.20 = u32[] constant(0)
+  scale_lhs = f32[] parameter(3)
+  scale_lhs_bcast = f32[2,2048,24576]{2,1,0} broadcast(scale_lhs), dimensions={}
+  lhs_bf16 = f32[2,2048,24576]{2,1,0} convert(param.8)
+  lhs_scaled = f32[2,2048,24576]{2,1,0} multiply(lhs_bf16, scale_lhs_bcast)
+  scale_rhs = f32[] parameter(4)
+  scale_rhs_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs), dimensions={}
+  rhs_bf16 = f32[24576,24576]{1,0} convert(param.6)
+  rhs_scaled = f32[24576,24576]{1,0} multiply(rhs_bf16, scale_rhs_bcast)
+  tuple.3 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_scaled, param.7, param.7, constant.20)
+  while.1 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) while(tuple.3), condition=windowed_dot_general_cond_rs, body=windowed_dot_general_body_rs
+  ROOT get-tuple-element.14 = f32[2,512,24576]{2,1,0} get-tuple-element(while.1), index=2
+}
+)";
+
+  RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(),
+                            R"(
+; CHECK-LABEL: windowed_dot_general_body_rs
+; CHECK-NEXT:    [[P0:%[^ ]+]] = (f8e4m3fn[2,2048,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0)
+; CHECK-NEXT:    [[GTE0:%[^ ]+]] = f8e4m3fn[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=0
+; CHECK-NEXT:    [[GTE1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[P0]]), index=1
+; CHECK-NEXT:    [[GTE2:%[^ ]+]] = f32[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=2
+; CHECK-NEXT:    [[CP0:%[^ ]+]] = f32[2,512,24576]{2,1,0} collective-permute([[GTE2]]), channel_id=9
+; CHECK-NEXT:    [[CONVERT0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} convert([[GTE0]])
+; CHECK-NEXT:    [[GTE3:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=5
+; CHECK-NEXT:    [[BCAST0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} broadcast([[GTE3]]), dimensions={}
+; CHECK-NEXT:    [[MUL0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} multiply([[CONVERT0]], [[BCAST0]])
+; CHECK-NEXT:    [[C0:%[^ ]+]] = s32[] constant(0)
+; CHECK-NEXT:    [[C1:%[^ ]+]] = s32[4]{0} constant({0, 512, 1024, 1536})
+; CHECK-NEXT:    [[GTE4:%[^ ]+]] = u32[] get-tuple-element([[P0]]), index=4
+; CHECK-NEXT:    [[C2:%[^ ]+]] = u32[] constant(2)
+; CHECK-NEXT:    [[ADD0:%[^ ]+]] = u32[] add([[GTE4]], [[C2]])
+; CHECK-NEXT:    [[C3:%[^ ]+]] = u32[] constant(1)
+; CHECK-NEXT:    [[ADD1:%[^ ]+]] = u32[] add([[ADD0]], [[C3]])
+; CHECK-NEXT:    [[PID:%[^ ]+]] = u32[] partition-id()
+; CHECK-NEXT:    [[ADD2:%[^ ]+]] = u32[] add([[ADD1]], [[PID]])
+; CHECK-NEXT:    [[C4:%[^ ]+]] = u32[] constant(4)
+; CHECK-NEXT:    [[REM0:%[^ ]+]] = u32[] remainder([[ADD2]], [[C4]])
+; CHECK-NEXT:    [[DSLICE0:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM0]]), dynamic_slice_sizes={1}
+; CHECK-NEXT:    [[RESHAPE0:%[^ ]+]] = s32[] reshape([[DSLICE0]])
+; CHECK-NEXT:    [[DSLICE1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dynamic-slice([[MUL0]], [[C0]], [[RESHAPE0]], [[C0]]), dynamic_slice_sizes={2,512,24576}
+; CHECK-NEXT:    [[CONVERT1:%[^ ]+]] = f32[24576,24576]{1,0} convert([[GTE1]])
+; CHECK-NEXT:    [[GTE5:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=6
+; CHECK-NEXT:    [[BCAST1:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[GTE5]]), dimensions={}
+; CHECK-NEXT:    [[MUL1:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[CONVERT1]], [[BCAST1]])
+; CHECK-NEXT:    [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[DSLICE1]], [[MUL1]]),
+; CHECK-DAG:       lhs_contracting_dims={2},
+; CHECK-DAG:       rhs_contracting_dims={0},
+; CHECK-DAG:       backend_config={
+; CHECK-DAG:         "operation_queue_id":"[[OPQUEUEID:[0-9]+]]",
+; CHECK-DAG:         "wait_on_operation_queues":[],
+; CHECK-DAG:         "force_earliest_schedule":false}
+; CHECK-NEXT:    [[ADD3:%[^ ]+]] = f32[2,512,24576]{2,1,0} add([[CP0]], [[DOT0]]),
+; CHECK-DAG:       backend_config={
+; CHECK-DAG:         "operation_queue_id":"0",
+; CHECK-DAG:         "wait_on_operation_queues":["[[OPQUEUEID]]"],
+; CHECK-DAG:         "force_earliest_schedule":false}
+; CHECK-NEXT:    [[GTE6:[^ ]+]] = f32[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=3
+; CHECK-NEXT:    [[ADD4:%[^ ]+]] = u32[] add([[GTE4]], [[PID]])
+; CHECK-NEXT:    [[REM1:%[^ ]+]] = u32[] remainder([[ADD4]], [[C4]])
+; CHECK-NEXT:    [[DSLICE2:%[^ ]+]] = s32[1]{0} dynamic-slice([[C1]], [[REM1]]), dynamic_slice_sizes={1}
+; CHECK-NEXT:    [[RESHAPE1:%[^ ]+]] = s32[] reshape([[DSLICE2]])
+; CHECK-NEXT:    [[DSLICE3:%[^ ]+]] = f32[2,512,24576]{2,1,0} dynamic-slice([[MUL0]], [[C0]], [[RESHAPE1]], [[C0]]), dynamic_slice_sizes={2,512,24576}
+; CHECK-NEXT:    [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[DSLICE3]], [[MUL1]]),
+; CHECK-DAG:       lhs_contracting_dims={2},
+; CHECK-DAG:       rhs_contracting_dims={0}
+; CHECK-NEXT:    [[ADD5:%[^ ]+]] = f32[2,512,24576]{2,1,0} add([[GTE6]], [[DOT1]])
+; CHECK-NEXT:    [[CP1:[^ ]+]] = f32[2,512,24576]{2,1,0} collective-permute([[ADD5]]), channel_id=10
+; CHECK-NEXT:  ROOT [[OUT:[^ ]+]] = (f8e4m3fn[2,2048,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[GTE0]], [[GTE1]], [[ADD3]], [[CP1]], [[ADD0]], /*index=5*/[[GTE3]], [[GTE5]])
+)");
+}
+
+TEST_F(WindowedEinsumHanlderTest,
+       AgLoopsMultipleConsumersAreChainedWithShardedContractingDim) {
+  constexpr absl::string_view kHloString = R"(
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8
+
+windowed_dot_general_body_ag {
+  param.195 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0)
+  get-tuple-element.588 = bf16[16,2048,512]{2,1,0} get-tuple-element(param.195), index=0
+  collective-permute.194 = bf16[16,2048,512]{2,1,0} collective-permute(get-tuple-element.588), channel_id=446, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}
+  collective-permute.195 = bf16[16,2048,512]{2,1,0} collective-permute(collective-permute.194), channel_id=447, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}
+  get-tuple-element.589 = bf16[4096,6288]{1,0} get-tuple-element(param.195), index=1
+  get-tuple-element.590 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=2
+  constant.11432 = s32[8]{0} constant({0, 512, 1024, 1536, 2048, 2560, 3072, 3584})
+  get-tuple-element.592 = u32[] get-tuple-element(param.195), index=4
+  partition-id.194 = u32[] partition-id()
+  add.4309 = u32[] add(get-tuple-element.592, partition-id.194)
+  constant.11431 = u32[] constant(8)
+  remainder.194 = u32[] remainder(add.4309, constant.11431)
+  dynamic-slice.388 = s32[1]{0} dynamic-slice(constant.11432, remainder.194), dynamic_slice_sizes={1}
+  reshape.12959 = s32[] reshape(dynamic-slice.388)
+  constant.11433 = s32[] constant(0)
+  dynamic-slice.389 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12959, constant.11433), dynamic_slice_sizes={512,6288}
+  dot.244 = bf16[16,2048,6288]{2,1,0} dot(get-tuple-element.588, dynamic-slice.389), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  add.4310 = bf16[16,2048,6288]{2,1,0} add(get-tuple-element.590, dot.244)
+  constant.11434 = u32[] constant(1)
+  add.4312 = u32[] add(get-tuple-element.592, constant.11434)
+  add.4313 = u32[] add(add.4312, partition-id.194)
+  remainder.195 = u32[] remainder(add.4313, constant.11431)
+  dynamic-slice.390 = s32[1]{0} dynamic-slice(constant.11432, remainder.195), dynamic_slice_sizes={1}
+  reshape.12960 = s32[] reshape(dynamic-slice.390)
+  dynamic-slice.391 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12960, constant.11433), dynamic_slice_sizes={512,6288}
+  dot.245 = bf16[16,2048,6288]{2,1,0} dot(collective-permute.194, dynamic-slice.391), lhs_contracting_dims={2}, rhs_contracting_dims={0}
+  add.4314 = bf16[16,2048,6288]{2,1,0} add(add.4310, dot.245)
+  get-tuple-element.591 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=3
+  add.4315 = u32[] add(add.4312, constant.11434)
+  ROOT tuple.98 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(collective-permute.195, get-tuple-element.589, add.4314, get-tuple-element.591, add.4315)
+} // windowed_dot_general_body_ag
+
+windowed_dot_general_cond_ag {
+  param = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0)
+  get-tuple-element = u32[] get-tuple-element(param), index=4
+  constant = u32[] constant(4)
+  ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT
+}
+
+ENTRY main.12_spmd {
+  param.4 = bf16[16,2048,512]{2,1,0} parameter(0)
+  param.5 = bf16[4096,6288]{1,0} parameter(1)
+  constant.22 = bf16[] constant(0)
+  broadcast = bf16[16,2048,6288]{2,1,0} broadcast(constant.22), dimensions={}
+  constant.24 = u32[] constant(0)
+  tuple.2 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24)
+  while = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag
+  get-tuple-element.13 = bf16[16,2048,6288]{2,1,0} get-tuple-element(while), index=2
+  all-gather = bf16[16,2048,4096]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, use_global_device_ids=true
+  param.6 = bf16[16,2048,6288]{2,1,0} parameter(2)
+  ROOT dot.7 = bf16[4096,6288]{1,0} dot(all-gather, param.6), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  WindowedEinsumHandler gpu_handler;
+  bool changed;
+  TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  HloInstruction* ag_loop =
+      FindInstructionByName(module->entry_computation(), "while");
+  HloInstruction* inst =
+      FindInstructionByName(module->entry_computation(), "dot.7");
+  // dot.7 should now consume output of the windowed einsum while loop.
+  EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement);
+  EXPECT_EQ(inst->operand(0)->tuple_index(), 5);
+  EXPECT_EQ(inst->operand(0)->operand(0), ag_loop);
+}
+}  // namespace
+}  // namespace xla::gpu
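The assertions at the end of the last test above verify the chaining by walking dot.7's operand chain by hand. An equivalent check could likely be phrased with XLA's HLO opcode matchers; the sketch below is an assumption-laden alternative, presuming the test target can include xla/hlo/utils/hlo_matchers.h and reusing the `inst` pointer found via FindInstructionByName.

  // Hypothetical alternative to the manual operand checks above.
  namespace op = ::xla::testing::opcode_matchers;
  EXPECT_THAT(inst->operand(0), op::GetTupleElement(op::While(), 5));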
diff --git a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc b/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc
deleted file mode 100644
index b54d006..0000000
--- a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc
+++ /dev/null
@@ -1,389 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "xla/service/gpu/tree_reduction_rewriter.h"
-
-#include <algorithm>
-#include <cmath>
-#include <cstdint>
-#include <iterator>
-#include <memory>
-#include <utility>
-
-#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/log/check.h"
-#include "absl/log/log.h"
-#include "absl/numeric/bits.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/collective_ops_utils.h"
-#include "xla/service/gpu/reduction_utils.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-
-absl::InlinedVector<int64_t, 2> GetSortedReducedDims(
-    HloReduceInstruction *reduce) {
-  absl::InlinedVector<int64_t, 2> reduced_dims{reduce->dimensions().begin(),
-                                               reduce->dimensions().end()};
-  absl::c_sort(reduced_dims);
-  return reduced_dims;
-}
-
-bool IsMinMaxReduction(HloReduceInstruction *reduce) {
-  HloComputation *called = &reduce->to_apply()[0];
-  if (auto reduction_kind = MatchReductionComputation(called)) {
-    return reduction_kind == ReductionKind::MAX ||
-           reduction_kind == ReductionKind::MIN;
-  }
-  return false;
-}
-
-}  // namespace
-
-class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
- public:
-  explicit ReductionRewriterVisitor(se::GpuComputeCapability gpu_version)
-      : gpu_version_(gpu_version) {}
-
-  absl::Status HandleReduce(HloInstruction *hlo) override {
-    auto *reduce = Cast<HloReduceInstruction>(hlo);
-    VLOG(3) << "Reduction instruction: " << reduce->ToString();
-
-    const HloModuleConfig &config = reduce->GetModule()->config();
-    if (!MatchReductionForSplit(reduce, config)) {
-      return absl::OkStatus();
-    }
-    ReductionDimensions reduction_dims =
-        GetReductionKindAndContiguousComponents(*hlo);
-    if (ReductionIsRaceFree(config, reduction_dims)) {
-      VLOG(3) << "Base case: dimensions fit";
-      return absl::OkStatus();
-    }
-    auto sorted_dims_to_reduce = GetSortedReducedDims(reduce);
-    CHECK_LE(sorted_dims_to_reduce.size(), 2);
-
-    // If the major reduced dimension does not fit, reduce the minor dimension
-    // first, then the major.
-    if (reduction_dims.is_row_reduction &&
-        reduction_dims
-                .dimensions[ReductionDimensions::kRowMajorReducedDimension] >
-            BatchedReductionRaceFreeBound()) {
-      VLOG(2) << "Splitting batched dimension reduce into a separate reduction";
-      return RewriteBatchDimensionLargerThanTile(reduce, reduction_dims,
-                                                 sorted_dims_to_reduce);
-    }
-    SplitParams split_params =
-        ComputeSplitParams(reduce, reduction_dims, sorted_dims_to_reduce);
-    return SplitReductionDimension(reduce, split_params, sorted_dims_to_reduce);
-  }
-
- private:
-  bool MatchReductionForSplit(HloReduceInstruction *reduce,
-                              const HloModuleConfig &config) {
-    // MLIR emitters only support race-free reductions.
-    // TODO(jreiffers): Verify performance and implement atomics for reductions
-    // if needed.
-    bool reductions_via_mlir_disabled =
-        config.debug_options().xla_gpu_mlir_emitter_level() < 4;
-    if (reductions_via_mlir_disabled && IsMinMaxReduction(reduce)) {
-      // TODO(cheshire): Also enable for integers.
-      VLOG(1) << "Not performing tree expansion on min/max-reduction: "
-              << reduce->ToString()
-              << " since min/max operations are associative";
-      return false;
-    }
-    if (!IsReductionFromOrToContiguousDimensions(*reduce)) {
-      VLOG(3) << "Is not a reduction from or to contiguous dimensions";
-      return false;
-    }
-    VLOG(3) << "Perform rewrite";
-    return true;
-  }
-
-  // We observe that a larger n_div_k improves tree reduction performance in
-  // most cases by reducing memory stores and block launch overhead. Swap
-  // k and n_div_k if possible.
-  bool ShouldSwapInnerAndOuterReducedMinorDimension(uint64_t k1, uint64_t k2,
-                                                    uint64_t n,
-                                                    int64_t race_free_bound,
-                                                    bool is_row_reduction) {
-    CHECK(k1 >= k2);
-    // Keep inner reduction as race free.
-    if (k1 > race_free_bound) {
-      return false;
-    }
-    // Swapping only affects row reduction vectorization.
-    if (is_row_reduction) {
-      // Rough conditions for row reduction vectorization; they do not
-      // guarantee that vectorization will actually occur.
-      bool maybe_vectorized = k2 % 2 == 0 && n % 2 == 0;
-      if (maybe_vectorized) {
-        // Swap if n_div_k is small enough or k dim can be vectorized also.
-        return k2 * 2 < k1 || k1 % 2 == 0;
-      }
-      // The current reduction emitter only checks the reduction input
-      // dimensions, not the fusion input dimensions. Because the pad and the
-      // inner reduction always fuse into the same computation, each thread may
-      // end up reading multiple unaligned elements that cannot be vectorized,
-      // which hurts performance. Don't swap in this situation.
-      return n % 2 == 0 || k1 % 2 != 0;
-    }
-    // For column reductions, swapping always yields a performance gain.
-    return true;
-  }
-
-  // Parameters describing how to split a dimension `dim` with `k` elements
-  // into `k1` x `k2`.
-  struct SplitParams {
-    int64_t k1;
-    int64_t k2;
-    int64_t dim;
-  };
-
-  // Attempts to find the best way to split a dimension `dim` with `k` elements
-  // into `k1` x `k2`.
-  SplitParams ComputeSplitParams(
-      HloReduceInstruction *reduce, const ReductionDimensions &reduction_dims,
-      absl::Span<const int64_t> sorted_dims_to_reduce) {
-    absl::Span<int64_t const> input_shape_dims =
-        reduce->inputs()[0]->shape().dimensions();
-
-    int64_t reduced_dim = sorted_dims_to_reduce.back();
-    int64_t reduced_dim_size = input_shape_dims[reduced_dim];
-    VLOG(3) << "reduced dim size = " << reduced_dim_size;
-
-    // We will do this reduction in two stages.  The first will reduce from k
-    // elements to k1 elements in the reduction dimension.  The second will
-    // reduce further, from k2 to 1 element.
-    //
-    // We do this by splitting the input shape [a, k, b] into [a, k1, k2, b].
-    //
-    // We want to choose k1 to be roughly equal to sqrt(k) so that we process
-    // "most of" the reduction in the first step. But it is also important that
-    // we choose a value of k1 with the least amount of padding we need to add
-    // to n to make it divisible by k1. We search for the best value of k2
-    // between sqrt(k)/2 and sqrt(k). If there are several possible values for
-    // k2 that result in the minimum amount of padding, we also want k2 to
-    // be a power of 2, so that the GPU kernel doesn't spend all its time doing
-    // slow integer divmods to compute indices into the shape [a,k1,k2,b].
-    // Note that by searching in the range between sqrt(k)/2 and sqrt(k), we
-    // will have a power of 2 in that range.
-    uint64_t k2 =
-        static_cast<uint64_t>(std::floor(std::sqrt(reduced_dim_size)));
-    int64_t race_free_bound = ReductionDimensionRaceFreeBound(
-        reduce->GetModule()->config(), reduction_dims);
-    if (k2 > race_free_bound) {
-      // This means we need more than one split. It is best to limit the n/k
-      // dimension to the maximum size that doesn't require further splitting.
-      // Otherwise we might choose a rather small reduce dimension size for the
-      // first step (in the worst case, sqrt(race_free_bound + 1)).
-      k2 = race_free_bound;
-    }
-    uint64_t minimum_padding = (k2 - reduced_dim_size % k2) % k2;
-    uint64_t best_k1 = (reduced_dim_size + minimum_padding) / k2;
-    for (uint64_t i = k2 - 1; i > k2 / 2; --i) {
-      uint64_t padding = (i - reduced_dim_size % i) % i;
-      if (padding < minimum_padding ||
-          (padding == minimum_padding && absl::has_single_bit(i))) {
-        minimum_padding = padding;
-        best_k1 = (reduced_dim_size + padding) / i;
-      }
-    }
-    uint64_t padded_k = reduced_dim_size + minimum_padding;
-
-    // The best {k1, k2} pair was chosen above by padding size and by whether
-    // index computation is fast, ignoring the overhead of memory reads/writes
-    // and block launches, which also matters for kernel performance. Swapping
-    // k1 and k2 leaves the padding size and the index-computation cost
-    // unchanged, so we only need to compare memory traffic and block launch
-    // overhead to pick the better of the two orders.
-    uint64_t best_k2 = padded_k / best_k1;
-    if (ShouldSwapInnerAndOuterReducedMinorDimension(
-            best_k1, best_k2, reduced_dim_size, race_free_bound,
-            reduction_dims.is_row_reduction)) {
-      std::swap(best_k1, best_k2);
-    }
-    return SplitParams{static_cast<int64_t>(best_k1),
-                       static_cast<int64_t>(best_k2), reduced_dim};
-  }
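The k1/k2 selection that ComputeSplitParams documents above is plain integer arithmetic, so it can be illustrated outside the pass. The sketch below is a hypothetical standalone helper that mirrors the padding-minimizing search (it omits the swap heuristic applied afterwards) and assumes k >= 1 and race_free_bound >= 1.

#include <cmath>
#include <cstdint>
#include <utility>

// Picks k2 near sqrt(k), then searches smaller candidates for the one that
// minimizes padding, preferring powers of two on ties. Returns {k1, k2}.
std::pair<uint64_t, uint64_t> ChooseSplit(uint64_t k, uint64_t race_free_bound) {
  uint64_t k2 =
      static_cast<uint64_t>(std::floor(std::sqrt(static_cast<double>(k))));
  if (k2 > race_free_bound) k2 = race_free_bound;
  uint64_t min_padding = (k2 - k % k2) % k2;
  uint64_t best_k1 = (k + min_padding) / k2;
  for (uint64_t i = k2 - 1; i > k2 / 2; --i) {
    uint64_t padding = (i - k % i) % i;
    if (padding < min_padding ||
        (padding == min_padding && (i & (i - 1)) == 0)) {
      min_padding = padding;
      best_k1 = (k + padding) / i;
    }
  }
  return {best_k1, (k + min_padding) / best_k1};
}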
-
-  // Replaces the original reduce with pad->reshape->inner_reduce->outer_reduce.
-  // * 1. pads split dimension of the inputs to k1 * k2 if necessary.
-  // * 2. reshapes split dimension of the padded inputs into [k1, k2].
-  // * 3. inner reduction reduces the dims specified in the original reduction.
-  //      Instead of reducing the split dimension, reduces K2.
-  // * 4. outer_reduction reduces K1 only.
-  absl::Status SplitReductionDimension(
-      HloReduceInstruction *reduce, const SplitParams &split_params,
-      absl::Span<const int64_t> sorted_dims_to_reduce) {
-    absl::Span<int64_t const> reduce_input_dims =
-        reduce->inputs()[0]->shape().dimensions();
-    int64_t split_dim_size = reduce_input_dims[split_params.dim];
-    VLOG(2) << "dimension to split = " << split_params.dim << " with "
-            << split_dim_size << " elements into " << split_params.k1 << " by "
-            << split_params.k2;
-
-    // Pad 'k' to 'k1 * k2' if necessary.
-    HloInstruction::InstructionVector padded_inputs(reduce->inputs().begin(),
-                                                    reduce->inputs().end());
-    auto padded_size = split_params.k1 * split_params.k2;
-    absl::InlinedVector<int64_t, 3> padded_dimensions(reduce_input_dims.begin(),
-                                                      reduce_input_dims.end());
-    if (split_dim_size != padded_size) {
-      padded_dimensions[split_params.dim] = padded_size;
-      PaddingConfig padding_config =
-          MakeNoPaddingConfig(reduce_input_dims.size());
-      padding_config.mutable_dimensions(split_params.dim)
-          ->set_edge_padding_high(padded_size - split_dim_size);
-
-      for (int input_idx = 0; input_idx < padded_inputs.size(); ++input_idx) {
-        auto &reduction_input = padded_inputs[input_idx];
-        Shape padded_shape = ShapeUtil::MakeShape(
-            reduction_input->shape().element_type(), padded_dimensions);
-        VLOG(2) << "Generated padded shape: " << padded_shape.ToString();
-        reduction_input = reduce->parent()->AddInstruction(
-            HloInstruction::CreatePad(padded_shape, reduction_input,
-                                      reduce->init_values()[input_idx],
-                                      padding_config),
-            &reduction_input->metadata());
-      }
-    }
-
-    // Compute output type of reshape that expands the split dimension into
-    // [k1, k2].
-    absl::InlinedVector<int64_t, 3> reshaped_dimensions;
-    int64_t input_rank = reduce_input_dims.size();
-    for (int64_t dim_idx = 0; dim_idx < input_rank; dim_idx++) {
-      if (dim_idx == split_params.dim) {
-        reshaped_dimensions.push_back(split_params.k1);
-        reshaped_dimensions.push_back(split_params.k2);
-      } else {
-        reshaped_dimensions.push_back(padded_dimensions[dim_idx]);
-      }
-    }
-
-    // Compute dimensions to reduce for inner reduction.
-    absl::InlinedVector<int64_t, 2> inner_reduce_dims(
-        sorted_dims_to_reduce.begin(), sorted_dims_to_reduce.end());
-    auto split_dim_it = std::find(inner_reduce_dims.begin(),
-                                  inner_reduce_dims.end(), split_params.dim);
-    *split_dim_it += 1;
-
-    // Compute dimension to reduce for outer reduction.
-    absl::InlinedVector<int64_t, 1> outer_reduce_dims{
-        split_params.dim -
-        std::distance(inner_reduce_dims.begin(), split_dim_it)};
-
-    // Compute output shape of the inner reduction.
-    absl::InlinedVector<int64_t, 3> inner_reduce_shape =
-        RemoveElements(inner_reduce_dims, reshaped_dimensions);
-
-    // Reshape the split dimensions of the padded inputs into [k1, k2].
-    HloInstruction::InstructionVector reshaped_padded_inputs;
-    absl::InlinedVector<Shape, 2> inner_reduce_shapes;
-    for (HloInstruction *padded_input : padded_inputs) {
-      Shape reshaped_shape = ShapeUtil::MakeShape(
-          padded_input->shape().element_type(), reshaped_dimensions);
-      HloInstruction *reshaped_padded_input = reduce->parent()->AddInstruction(
-          HloInstruction::CreateBitcast(reshaped_shape, padded_input),
-          &padded_input->metadata());
-      VLOG(2) << "Generated reshape: " << reshaped_padded_input->ToString();
-      reshaped_padded_inputs.push_back(reshaped_padded_input);
-      inner_reduce_shapes.push_back(ShapeUtil::MakeShape(
-          padded_input->shape().element_type(), inner_reduce_shape));
-    }
-
-    // Inner reduce that reduces [k1, k2] to [k1].
-    HloInstruction *inner_reduce = reduce->parent()->AddInstruction(
-        HloInstruction::CreateReduce(
-            ShapeUtil::MakeMaybeTupleShape(inner_reduce_shapes),
-            reshaped_padded_inputs, reduce->init_values(), inner_reduce_dims,
-            reduce->to_apply()),
-        &reduce->metadata());
-    VLOG(1) << "Generated inner reduction: " << inner_reduce->ToString();
-
-    // Outer reduce that reduces [k2].
-    std::unique_ptr<HloInstruction> outer_reduce = HloInstruction::CreateReduce(
-        reduce->shape(), inner_reduce, reduce->init_values(), outer_reduce_dims,
-        reduce->to_apply());
-
-    VLOG(1) << "Generated outer reduction: " << outer_reduce->ToString();
-    return ReplaceWithNewInstruction(reduce, std::move(outer_reduce));
-  }
-
-  // Rewrites batch dimension reduction into a separate reduce operation.
-  absl::Status RewriteBatchDimensionLargerThanTile(
-      HloReduceInstruction *hlo,
-      const ReductionDimensions &reduction_dimensions,
-      absl::Span<const int64_t> sorted_dims_to_reduce) {
-    // TODO(cheshire): this codepath is essentially the exact reverse of what
-    // algebraic_simplifier is doing, we need to make sure they don't keep
-    // undoing each other.
-    CHECK(reduction_dimensions.is_row_reduction);
-
-    absl::InlinedVector<Shape, 2> tuple_shapes;
-    int64_t minor_reduction_dim = sorted_dims_to_reduce.back();
-    for (HloInstruction *input : hlo->inputs()) {
-      tuple_shapes.push_back(
-          ShapeUtil::DeleteDimension(minor_reduction_dim, input->shape()));
-    }
-
-    HloInstruction *inner_reduce =
-        hlo->parent()->AddInstruction(HloInstruction::CreateReduce(
-            ShapeUtil::MakeMaybeTupleShape(tuple_shapes), hlo->inputs(),
-            hlo->init_values(), {minor_reduction_dim}, hlo->to_apply()));
-
-    VLOG(1) << "Inner reduction: " << inner_reduce->ToString();
-    std::unique_ptr<HloInstruction> out = HloInstruction::CreateReduce(
-        hlo->shape(), inner_reduce, hlo->init_values(), {0}, hlo->to_apply());
-    VLOG(1) << "Generated: " << out->ToString();
-    return ReplaceWithNewInstruction(hlo, std::move(out));
-  }
-
-  se::GpuComputeCapability gpu_version_;
-};
-
-absl::StatusOr<bool> GpuTreeReductionRewriter::Run(
-    HloModule *module,
-    const absl::flat_hash_set<absl::string_view> &execution_threads) {
-  VLOG(5) << "Rewriter input: " << module->ToString();
-  TF_ASSIGN_OR_RETURN(bool changed,
-                      ReductionRewriterVisitor(gpu_version_)
-                          .RunOnModule(module, execution_threads));
-  VLOG(5) << "Rewriter output: " << module->ToString();
-  return changed;
-}
-
-}  // end namespace gpu
-}  // end namespace xla
diff --git a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.h b/third_party/xla/xla/service/gpu/tree_reduction_rewriter.h
deleted file mode 100644
index 5f6edf8..0000000
--- a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.h
+++ /dev/null
@@ -1,99 +0,0 @@
-/* Copyright 2020 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_
-#define XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_
-
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/stream_executor/device_description.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites reductions so that they can be implemented without atomics.
-//
-// Rule application: rewrite a single HLO reduce operation into two.
-//
-// Case 1: Row reduction, batched dimension is present, larger than
-// Z-tiling size.
-// -----------------------------------------------------------------
-//
-// Rewriting:
-//
-// f32[B] out = reduce(f32[A, B, C] input, dimensions={0, 2})
-//
-// Into:
-//
-// f32[A, B] tmp = reduce(f32[A, B, C] input, dimensions={2})
-// f32[B] out = reduce(f32[A, B] tmp, dimensions={0})
-//
-// Case 2: Row reduction
-// ------------------------------------------------------------------
-//
-// Let M be the thread tiling multiplied by the warp size.
-// We go from (assuming C > M):
-//
-// f32[B] out = reduce(f32[A, B, C] input, dimensions={0, 2})
-//
-// to:
-//
-// f32[A, B, P] padded = pad(input) // Let P = ceil(C/M) * M.
-// f32[A, B, Q, M] reshaped = bitcast(padded) // Let Q = ceil(C/M)
-// f32[B, Q] inner_reduce = reduce(reshaped, dimensions={0, 3})
-// f32[B] outer_reduce = reduce(inner_reduce, dimensions={1})
-//
-// Case 3: Column reduction
-// -------------------------------------------------------------------
-//
-// Let T be the tiling size for the column reduction.
-//
-// We go from (assuming B > T):
-//
-// f32[A, C] out = reduce(f32[A, B, C] input, dimensions={1})
-//
-// to:
-//
-// f32[A, P, C] padded = pad(input) // Let P = ceil(B/T) * T.
-// f32[A, Q, T, C] reshaped = bitcast(padded) // Let Q = ceil(B/T)
-// f32[A, Q, C] inner_reduce = reduce(reshaped, dimensions={2})
-// f32[A, C] outer_reduce = reduce(inner_reduce, dimensions={1})
-//
-class GpuTreeReductionRewriter : public HloModulePass {
- public:
-  explicit GpuTreeReductionRewriter(se::GpuComputeCapability gpu_version)
-      : gpu_version_(gpu_version) {}
-
-  ~GpuTreeReductionRewriter() override = default;
-  absl::string_view name() const override {
-    return "gpu-tree-reduction-rewriter";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  se::GpuComputeCapability gpu_version_;
-};
-
-}  // end namespace gpu
-}  // end namespace xla
-
-#endif  // XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_
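To make Case 2 in the removed header comment concrete: with a hypothetical M = 1024 (thread tiling times warp size) and C = 5000, the pass pads the reduced dimension to P = ceil(C/M) * M = 5120 and splits it into Q = 5 chunks of size M. The arithmetic as a self-contained sketch (the numbers are made up for illustration):

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical sizes for Case 2 of the tree reduction rewrite.
  const int64_t M = 1024;             // thread tiling * warp size
  const int64_t C = 5000;             // reduced dimension before padding
  const int64_t Q = (C + M - 1) / M;  // ceil(C / M) = 5
  const int64_t P = Q * M;            // padded reduced dimension = 5120
  std::cout << "P=" << P << " Q=" << Q << "\n";
  return 0;
}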
diff --git a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.cc b/third_party/xla/xla/service/gpu/triangular_solve_rewriter.cc
deleted file mode 100644
index 2dcd365..0000000
--- a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.cc
+++ /dev/null
@@ -1,88 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/triangular_solve_rewriter.h"
-
-#include <cstdint>
-#include <numeric>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/hlo_creation_utils.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-#include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-absl::StatusOr<bool> TriangularSolveRewriter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    std::vector<HloInstruction*> to_rewrite;
-    for (HloInstruction* instr : comp->instructions()) {
-      if (instr->opcode() == HloOpcode::kTriangularSolve) {
-        to_rewrite.push_back(instr);
-      }
-    }
-
-    for (HloInstruction* instr : to_rewrite) {
-      const Shape& b_shape = instr->operand(1)->shape();
-      int64_t batch_size = std::accumulate(
-          b_shape.dimensions().begin(), b_shape.dimensions().end() - 2,
-          int64_t{1}, [](int64_t a, int64_t b) { return a * b; });
-
-      // batch 1 triangular solves get 0 temp bytes, because unbatched trsm()
-      // doesn't require temp memory.
-      int64_t temp_bytes = batch_size == 1 ? 0 : 2 * sizeof(void*) * batch_size;
-      Shape new_shape = ShapeUtil::MakeTupleShape({
-          instr->shape(),
-          ShapeUtil::MakeShape(S8, {temp_bytes}),
-      });
-
-      HloInstruction* custom_call =
-          comp->AddInstruction(HloInstruction::CreateCustomCall(
-              new_shape, instr->operands(), kTriangularSolveCallTarget));
-      module->SetAndUniquifyInstrName(custom_call, "triangular-solve");
-      TF_RETURN_IF_ERROR(
-          custom_call->set_backend_config(instr->triangular_solve_options()));
-
-      // Preserve metadata from `instr`.
-      custom_call->set_metadata(instr->metadata());
-      custom_call->set_frontend_attributes(instr->frontend_attributes());
-
-      // Get the actual result out of the custom call's tuple.
-      TF_ASSIGN_OR_RETURN(HloInstruction * gte,
-                          MakeGetTupleElementHlo(custom_call, 0));
-      TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
-      changed = true;
-    }
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
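The temp_bytes computation above is easy to sanity-check by hand: for a hypothetical batch of 8 on a 64-bit target it is 2 * 8 * 8 = 128 bytes, i.e. the custom call would return a tuple of the solve result and an s8[128] scratch buffer. A minimal sketch of that arithmetic:

#include <cstdint>
#include <iostream>

int main() {
  const int64_t batch_size = 8;  // hypothetical batched triangular solve
  const int64_t temp_bytes =
      batch_size == 1 ? 0
                      : 2 * static_cast<int64_t>(sizeof(void*)) * batch_size;
  std::cout << temp_bytes << "\n";  // prints 128 on a 64-bit target
  return 0;
}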
diff --git a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.h b/third_party/xla/xla/service/gpu/triangular_solve_rewriter.h
deleted file mode 100644
index 6d4b1c1..0000000
--- a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.h
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2022 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_
-#define XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Rewrites HLO TriangularSolve ops into a custom-call.
-//
-// The motivation for this is that we need to add temp memory to batched
-// triangular-solve ops in order to call cublas trsmBatched.  We rewrite batch 1
-// ops as well so that we have fewer codepaths to worry about in the backend.
-//
-// cublas trsmBatched takes arrays in GPU memory of pointers to the inputs and
-// outputs, `a` and `b`.  In XLA the inputs/outputs are always contiguous, but
-// we still have to materialize out these arrays.
-//
-// We use the same trick as for cudnn convolutions: This custom-call returns a
-// tuple (actual-result, temp-memory).  In our case the temp buffer always
-// has size 2 * sizeof(void*) * batch_size, because we need two arrays of
-// pointers.
-//
-// The custom-call has a backend-config equal to the TriangularSolveOptions
-// object.
-class TriangularSolveRewriter : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "triangular-solve-rewriter";
-  }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc
index c911bf1..c5cc0ad 100644
--- a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc
+++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc
@@ -23,7 +23,7 @@
 #include "absl/status/statusor.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/gpu/gemm_fusion.h"
+#include "xla/service/gpu/transforms/gemm_fusion.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/tests/hlo_test_base.h"
 #include "xla/tests/verified_hlo_module.h"
diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc b/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc
deleted file mode 100644
index 75c43fe..0000000
--- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc
+++ /dev/null
@@ -1,190 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/triton_fusion_numerics_verifier.h"
-
-#include <memory>
-#include <optional>
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/functional/any_invocable.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/executable.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/gpu/backend_configs.pb.h"
-#include "xla/service/gpu/buffer_comparator.h"
-#include "xla/service/gpu/ir_emission_utils.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/shaped_buffer.h"
-#include "xla/shape.h"
-#include "xla/status_macros.h"
-#include "xla/stream_executor/stream.h"
-#include "xla/tools/hlo_decomposer.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla::gpu {
-
-namespace {
-
-using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput;
-
-// Returns the input instruction as a fusion instruction, if it represents a
-// Triton fusion. Otherwise, returns nullptr.
-absl::StatusOr<const HloFusionInstruction*> AsTritonFusion(
-    const HloInstruction* hlo) {
-  if (hlo->opcode() != HloOpcode::kFusion) {
-    return nullptr;
-  }
-  const HloFusionInstruction* fusion = Cast<HloFusionInstruction>(hlo);
-  TF_ASSIGN_OR_RETURN(auto gpu_config,
-                      fusion->backend_config<GpuBackendConfig>());
-  const FusionBackendConfig& backend_config =
-      gpu_config.fusion_backend_config();
-  if (backend_config.kind() == kTritonFusionKind) {
-    return fusion;
-  }
-  return nullptr;
-}
-
-std::unique_ptr<HloModule> NewHloModuleFromFusion(
-    const HloFusionInstruction& fusion, const DebugOptions& debug_opts,
-    bool clear_backend_config) {
-  std::unique_ptr<HloModule> new_module =
-      ExtractInstructionIntoNewModule(fusion);
-  if (clear_backend_config) {
-    new_module->entry_computation()->root_instruction()->clear_backend_config();
-  }
-  new_module->mutable_config().set_debug_options(debug_opts);
-
-  return new_module;
-}
-
-}  // namespace
-
-namespace triton_fusion_numerics_pass_internal {
-
-absl::StatusOr<ScopedShapedBuffer> CompileAndRunFusion(
-    AutotunerCompileUtil& util, const HloFusionInstruction& fusion,
-    const AutotuneConfig& config, const DebugOptions& debug_opts,
-    bool clear_backend_config) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
-                      util.Compile([&](const DebugOptions& opts) {
-                        return NewHloModuleFromFusion(fusion, opts,
-                                                      clear_backend_config);
-                      }));
-  TF_ASSIGN_OR_RETURN(auto rz_buffers, RedzoneBuffers::FromInstruction(
-                                           fusion, config, debug_opts,
-                                           RedzoneBuffers::kAllInputs));
-  TF_ASSIGN_OR_RETURN(auto stream, config.GetStream());
-  TF_ASSIGN_OR_RETURN(std::optional<ProfilingOutput> profiling_output,
-                      util.ProfileExecutable(executable.get(), stream,
-                                             rz_buffers.input_buffers(),
-                                             rz_buffers.input_shapes()));
-  if (!profiling_output.has_value()) {
-    return Internal("No output after a successful verification run.");
-  }
-
-  return std::move(profiling_output->output);
-}
-
-absl::Status CompareBuffers(const ScopedShapedBuffer& current,
-                            const ScopedShapedBuffer& expected,
-                            const Shape& shape, const HloModuleConfig& config,
-                            se::Stream* stream) {
-  BufferComparator comparator(
-      shape, config.debug_options().xla_gpu_autotune_gemm_rtol());
-  TF_ASSIGN_OR_RETURN(bool outputs_match,
-                      comparator.CompareEqual(stream, current.root_buffer(),
-                                              expected.root_buffer()));
-
-  if (!outputs_match) {
-    return Internal("Triton fusion output does not match emitters output.");
-  }
-  return absl::OkStatus();
-}
-
-absl::Status ForAllTritonFusions(
-    const HloModule& module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads,
-    absl::AnyInvocable<absl::Status(const HloFusionInstruction&)> fn) {
-  for (HloComputation* computation :
-       module.MakeNonfusionComputations(execution_threads)) {
-    for (HloInstruction* instruction : computation->instructions()) {
-      TF_ASSIGN_OR_RETURN(auto triton_fusion, AsTritonFusion(instruction));
-      if (triton_fusion != nullptr) {
-        TF_RETURN_IF_ERROR(fn(*triton_fusion));
-      }
-    }
-  }
-  return absl::OkStatus();
-}
-
-}  // namespace triton_fusion_numerics_pass_internal
-
-namespace {
-absl::Status VerifyTritonFusion(AutotunerCompileUtil& util,
-                                const HloFusionInstruction& fusion,
-                                const AutotuneConfig& config,
-                                const DebugOptions& debug_opts) {
-  TF_ASSIGN_OR_RETURN(auto triton_result,
-                      triton_fusion_numerics_pass_internal::CompileAndRunFusion(
-                          util, fusion, config, debug_opts,
-                          /*clear_backend_config=*/false));
-  TF_ASSIGN_OR_RETURN(auto emitters_result,
-                      triton_fusion_numerics_pass_internal::CompileAndRunFusion(
-                          util, fusion, config, debug_opts,
-                          /*clear_backend_config=*/true));
-
-  TF_ASSIGN_OR_RETURN(auto stream, config.GetStream());
-  return triton_fusion_numerics_pass_internal::CompareBuffers(
-      triton_result, emitters_result, fusion.shape(),
-      fusion.GetModule()->config(), stream);
-}
-
-}  // namespace
-
-absl::StatusOr<bool> TritonFusionNumericsVerifier::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  if (config_.IsDeviceless()) {
-    return absl::InternalError(
-        "Cannot run TritonFusionNumericsVerifier on a deviceless compilation.");
-  }
-
-  const DebugOptions& debug_options = module->config().debug_options();
-  TF_ASSIGN_OR_RETURN(std::optional<AutotunerCompileUtil> opt_compile_util,
-                      AutotunerCompileUtil::Create(config_, debug_options));
-  TF_RET_CHECK(opt_compile_util.has_value());
-
-  TF_RETURN_IF_ERROR(triton_fusion_numerics_pass_internal::ForAllTritonFusions(
-      *module, execution_threads, [&](const HloFusionInstruction& fusion) {
-        return VerifyTritonFusion(*opt_compile_util, fusion, config_,
-                                  debug_options);
-      }));
-  return false;
-}
-
-}  // namespace xla::gpu
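The verifier above takes an AutotuneConfig and, judging from the deleted test below, is meant to be gated on the xla_gpu_verify_triton_fusion_numerics debug option. A sketch of how such a pass could be added to a pipeline; the pipeline name and the autotune_config variable are illustrative assumptions, not taken from this change:

  // Hypothetical wiring; autotune_config would come from the GPU compiler's
  // existing autotuning setup.
  HloPassPipeline pipeline("triton-fusion-verification");
  if (module->config().debug_options().xla_gpu_verify_triton_fusion_numerics()) {
    pipeline.AddPass<TritonFusionNumericsVerifier>(autotune_config);
  }
  TF_RETURN_IF_ERROR(pipeline.Run(module).status());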
diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.h b/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.h
deleted file mode 100644
index 6d74f46..0000000
--- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.h
+++ /dev/null
@@ -1,74 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_TRITON_FUSION_NUMERICS_VERIFIER_H_
-#define XLA_SERVICE_GPU_TRITON_FUSION_NUMERICS_VERIFIER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/functional/any_invocable.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/hlo_module_config.h"
-#include "xla/service/hlo_pass_interface.h"
-#include "xla/service/shaped_buffer.h"
-#include "xla/shape.h"
-#include "xla/stream_executor/stream.h"
-
-namespace xla::gpu {
-
-// For each Triton fusion in the Hlo module this pass checks that the output
-// of the fusion generated via Triton matches the output of the fusion if
-// generated with the regular emitters.
-class TritonFusionNumericsVerifier : public HloModulePass {
- public:
-  explicit TritonFusionNumericsVerifier(const AutotuneConfig& config)
-      : config_(config) {}
-
-  static absl::string_view Name() { return "triton-numerics-verifier"; }
-  absl::string_view name() const override { return Name(); }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-
- private:
-  AutotuneConfig config_;
-};
-
-namespace triton_fusion_numerics_pass_internal {
-// These are exposed only for testing. Do not use.
-absl::StatusOr<ScopedShapedBuffer> CompileAndRunFusion(
-    AutotunerCompileUtil& util, const HloFusionInstruction& fusion,
-    const AutotuneConfig& config, const DebugOptions& debug_opts,
-    bool clear_backend_config);
-absl::Status CompareBuffers(const ScopedShapedBuffer& current,
-                            const ScopedShapedBuffer& expected,
-                            const Shape& shape, const HloModuleConfig& config,
-                            se::Stream* stream);
-absl::Status ForAllTritonFusions(
-    const HloModule& module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads,
-    absl::AnyInvocable<absl::Status(const HloFusionInstruction&)> fn);
-}  // namespace triton_fusion_numerics_pass_internal
-
-}  // namespace xla::gpu
-
-#endif  // XLA_SERVICE_GPU_TRITON_FUSION_NUMERICS_VERIFIER_H_
diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier_test.cc
deleted file mode 100644
index 8703eff..0000000
--- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier_test.cc
+++ /dev/null
@@ -1,195 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/triton_fusion_numerics_verifier.h"
-
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "absl/status/status.h"
-#include "absl/strings/string_view.h"
-#include "absl/strings/substitute.h"
-#include "xla/hlo/ir/hlo_instructions.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/primitive_util.h"
-#include "xla/service/gpu/autotuner_compile_util.h"
-#include "xla/service/gpu/autotuner_util.h"
-#include "xla/service/platform_util.h"
-#include "xla/stream_executor/platform.h"
-#include "xla/test_helpers.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-
-namespace xla::gpu {
-namespace {
-
-class TritonFusionNumericsVerifierTest
-    : public HloTestBase,
-      public ::testing::WithParamInterface<PrimitiveType> {
- public:
-  DebugOptions GetDebugOptionsForTest() override {
-    auto options = HloTestBase::GetDebugOptionsForTest();
-    options.set_xla_gpu_enable_triton_softmax_fusion(true);
-    options.set_xla_gpu_verify_triton_fusion_numerics(true);
-    return options;
-  }
-
- protected:
-  std::unique_ptr<xla::HloModule> Module(absl::string_view hlo_text_template,
-                                         absl::string_view type) {
-    auto m = GetOptimizedModule(absl::Substitute(hlo_text_template, type));
-    TF_EXPECT_OK(m);
-    return std::move(m.value());
-  }
-
-  const HloFusionInstruction* TritonFusion(const xla::HloModule& module) {
-    const HloFusionInstruction* fusion_result = nullptr;
-
-    absl::Status res =
-        triton_fusion_numerics_pass_internal::ForAllTritonFusions(
-            module, /*execution_threads=*/{},
-            [&](const HloFusionInstruction& fusion) -> absl::Status {
-              EXPECT_EQ(fusion_result, nullptr);
-              fusion_result = &fusion;
-              return absl::OkStatus();
-            });
-    return fusion_result;
-  }
-
-  AutotuneConfig CreateAutotuneConfig() {
-    se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
-    auto executors_or = PlatformUtil::GetStreamExecutors(platform);
-    TF_EXPECT_OK(executors_or);
-    return AutotuneConfig{DeviceConfig{executors_or->at(0), nullptr},
-                          GetDebugOptionsForTest()};
-  }
-
-  AutotunerCompileUtil CreateAutotunerCompileUtil(AutotuneConfig& config) {
-    auto opt_compile_util_or =
-        AutotunerCompileUtil::Create(config, GetDebugOptionsForTest());
-    TF_EXPECT_OK(opt_compile_util_or);
-    EXPECT_TRUE(opt_compile_util_or->has_value());
-    return std::move(opt_compile_util_or->value());
-  }
-};
-
-constexpr absl::string_view kSoftmaxHlo = R"(
-HloModule softmax
-max_computation {
-  arg_0 = $0[] parameter(0)
-  arg_1 = $0[] parameter(1)
-  ROOT maximum = $0[] maximum(arg_0, arg_1)
-}
-add_computation {
-  arg_0.1 = $0[] parameter(0)
-  arg_1.1 = $0[] parameter(1)
-  ROOT add = $0[] add(arg_0.1, arg_1.1)
-}
-ENTRY main {
-  param_0 = $0[127,125]{1,0} parameter(0)
-  constant_neg_inf = $0[] constant(-inf)
-  reduce = $0[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation
-  broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0}
-  subtract = $0[127,125]{1,0} subtract(param_0, broadcast)
-  exponential = $0[127,125]{1,0} exponential(subtract)
-  constant_zero = $0[] constant(0)
-  second_reduce = $0[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation
-  second_broadcast = $0[127,125]{1,0} broadcast(second_reduce), dimensions={0}
-  ROOT divide = $0[127,125]{1,0} divide(exponential, second_broadcast)
-}
-)";
-
-bool HloPassHasRun(const HloModule& module, absl::string_view pass_name) {
-  for (const auto& pass_metadata : module.metadata().proto().pass_metadata()) {
-    if (pass_metadata.pass_name() == pass_name) {
-      return true;
-    }
-  }
-  return false;
-}
-
-TEST_P(TritonFusionNumericsVerifierTest, VerifyExactSoftmaxFusionNumerics) {
-  PrimitiveType data_type = GetParam();
-
-  auto module = Module(kSoftmaxHlo,
-                       primitive_util::LowercasePrimitiveTypeName(data_type));
-
-  // At this point all HLO passes have been executed successfully, because the
-  // Module() function hasn't failed. In particular the numerics verification
-  // pass should have also run and **not** found any issues. Below we just
-  // ensure that the pass has indeed been correctly enabled and that there are
-  // Triton Fusions in the input module.
-
-  EXPECT_TRUE(HloPassHasRun(*module, TritonFusionNumericsVerifier::Name()));
-  auto fusion = TritonFusion(*module);
-  EXPECT_NE(fusion, nullptr);
-}
-
-TEST_F(TritonFusionNumericsVerifierTest, CheckMismatch) {
-  // This test intentionally compares two different Triton modules to each
-  // other. This is to test that the verifier functions correctly catch and
-  // report mismatches.
-  //
-  // Note that as part of computing the two modules below, the numerics verifier
-  // pass also runs individually for each module. These runs compare the
-  // modules to the corresponding emitters generated version, which matches. In
-  // that sense this test covers what is being tested by
-  // VerifyExactSoftmaxFusionNumerics. The reason to keep two tests is that
-  // VerifyExactSoftmaxFusionNumerics is minimal and will be easier to debug if
-  // it fails.
-
-  auto module_f16 = Module(kSoftmaxHlo, "f16");
-  auto fusion_f16 = TritonFusion(*module_f16);
-  EXPECT_NE(fusion_f16, nullptr);
-
-  auto module_f32 = Module(kSoftmaxHlo, "f32");
-  auto fusion_f32 = TritonFusion(*module_f32);
-  EXPECT_NE(fusion_f32, nullptr);
-
-  AutotuneConfig autotune_config = CreateAutotuneConfig();
-  AutotunerCompileUtil compile_util =
-      CreateAutotunerCompileUtil(autotune_config);
-  const DebugOptions& debug_options = GetDebugOptionsForTest();
-
-  auto f16_result = triton_fusion_numerics_pass_internal::CompileAndRunFusion(
-      compile_util, *fusion_f16, autotune_config, debug_options,
-      /*clear_backend_config=*/false);
-  TF_EXPECT_OK(f16_result);
-
-  auto f32_result = triton_fusion_numerics_pass_internal::CompileAndRunFusion(
-      compile_util, *fusion_f32, autotune_config, debug_options,
-      /*clear_backend_config=*/false);
-  TF_EXPECT_OK(f32_result);
-
-  auto stream = autotune_config.GetStream();
-  TF_EXPECT_OK(stream);
-
-  // Intentionally compare the fusions from the different modules, triggering a
-  // mismatch.
-  auto cmp = triton_fusion_numerics_pass_internal::CompareBuffers(
-      *f16_result, *f32_result, fusion_f16->shape(),
-      fusion_f16->GetModule()->config(), *stream);
-
-  EXPECT_FALSE(cmp.ok());
-}
-
-INSTANTIATE_TEST_SUITE_P(TritonFusionNumericsVerifierTestSuite,
-                         TritonFusionNumericsVerifierTest,
-                         ::testing::Values(F32, F16, BF16));
-
-}  // namespace
-}  // namespace xla::gpu
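The deleted test above is a gtest value-parameterized suite: TEST_P runs VerifyExactSoftmaxFusionNumerics once per primitive type listed in INSTANTIATE_TEST_SUITE_P. For readers unfamiliar with that pattern, here is a minimal, self-contained sketch; the fixture and the string parameter are invented stand-ins for HloTestBase / PrimitiveType, and only the gtest macros themselves are the real API.

```cpp
// Minimal sketch of the value-parameterized pattern used by the deleted test.
// Names here (NumericsPatternTest, the dtype strings) are hypothetical.
#include <string>

#include <gtest/gtest.h>

class NumericsPatternTest : public ::testing::TestWithParam<std::string> {};

TEST_P(NumericsPatternTest, RunsOncePerDtype) {
  // GetParam() yields one entry of the Values(...) list per instantiation.
  const std::string& dtype = GetParam();
  EXPECT_FALSE(dtype.empty());
}

INSTANTIATE_TEST_SUITE_P(AllDtypes, NumericsPatternTest,
                         ::testing::Values("f32", "f16", "bf16"));
```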
diff --git a/third_party/xla/xla/service/gpu/variadic_op_splitter.cc b/third_party/xla/xla/service/gpu/variadic_op_splitter.cc
deleted file mode 100644
index f137157..0000000
--- a/third_party/xla/xla/service/gpu/variadic_op_splitter.cc
+++ /dev/null
@@ -1,115 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/variadic_op_splitter.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/shape.h"
-#include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"
-
-namespace xla {
-namespace gpu {
-
-namespace {
-// The parameter space on the GPU device is limited. We pick an arbitrary low
-// constant here to try to prevent exceeding this parameter space. For a proper
-// fix, we would have to take into account which parameters share a buffer, and
-// how big these buffers are.
-constexpr int32_t kMaxParameters = 128;
-
-absl::StatusOr<bool> SplitConcatenate(HloInstruction* concat,
-                                      HloComputation* comp) {
-  auto operands = concat->operands();
-  std::vector<HloInstruction*> operands_to_split(operands.begin(),
-                                                 operands.end());
-  while (operands_to_split.size() > 1) {
-    std::vector<HloInstruction*> new_operands;
-    absl::Span<HloInstruction*> operands_span(operands_to_split);
-    for (int64_t offset = 0; offset < operands_to_split.size();
-         offset += kMaxParameters) {
-      // Check if there is a remainder of operands that does not completely fill
-      // one "batch" of exactly 'kMaxParameters' operands. If there are only
-      // less than 'kMaxParameters' operands left, then we still put them into a
-      // concat together. Otherwise, we spare them for another round so that
-      // they can be put together into a concat with some of the newly created
-      // concats.
-      if (offset > 0 && offset + kMaxParameters > operands_to_split.size()) {
-        new_operands.insert(new_operands.end(),
-                            operands_to_split.begin() + offset,
-                            operands_to_split.end());
-      } else {
-        Shape new_shape = concat->shape();
-        int64_t concat_dimension_size = 0;
-        for (int64_t i = 0;
-             i < kMaxParameters && offset + i < operands_to_split.size(); ++i) {
-          concat_dimension_size +=
-              operands_to_split[i + offset]->shape().dimensions(
-                  concat->concatenate_dimension());
-        }
-        new_shape.set_dimensions(concat->concatenate_dimension(),
-                                 concat_dimension_size);
-        auto new_concat = comp->AddInstruction(concat->CloneWithNewOperands(
-            new_shape, operands_span.subspan(offset, kMaxParameters)));
-        new_operands.push_back(new_concat);
-      }
-    }
-    operands_to_split = new_operands;
-  }
-  TF_RETURN_IF_ERROR(comp->ReplaceInstruction(concat, operands_to_split[0]));
-  return true;
-}
-
-std::vector<HloInstruction*> GetRelevantVariadicOps(HloComputation* comp) {
-  std::vector<HloInstruction*> ops;
-  for (HloInstruction* instr : comp->instructions()) {
-    if (instr->opcode() == HloOpcode::kConcatenate &&
-        instr->operand_count() > kMaxParameters) {
-      ops.push_back(instr);
-    }
-  }
-  return ops;
-}
-
-}  // namespace
-
-absl::StatusOr<bool> VariadicOpSplitter::Run(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
-  for (HloComputation* comp :
-       module->MakeNonfusionComputations(execution_threads)) {
-    for (HloInstruction* op : GetRelevantVariadicOps(comp)) {
-      // TODO(b/112613927): Handle also other ops than concatenate.
-      TF_ASSIGN_OR_RETURN(bool result, SplitConcatenate(op, comp));
-      changed |= result;
-    }
-  }
-  return changed;
-}
-
-}  // namespace gpu
-}  // namespace xla
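The heart of the deleted SplitConcatenate above is its batching loop: fold runs of at most kMaxParameters operands into fresh concatenates, carry any short trailing remainder into the next round, and repeat until a single root remains. The standalone sketch below mirrors that loop with each node reduced to its direct-operand count; all XLA shape bookkeeping is dropped, so treat it as illustrative only.

```cpp
// Sketch of the deleted batching loop. A node is just an int holding its
// direct-operand count (1 == original leaf operand).
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr size_t kMaxParameters = 128;  // same arbitrary cap the pass used

// Returns the direct-operand count of the final root "concatenate".
int64_t RootOperandCountAfterSplitting(size_t num_operands) {
  if (num_operands == 0) return 0;
  std::vector<int64_t> nodes(num_operands, 1);
  while (nodes.size() > 1) {
    std::vector<int64_t> next;
    for (size_t offset = 0; offset < nodes.size(); offset += kMaxParameters) {
      if (offset > 0 && offset + kMaxParameters > nodes.size()) {
        // A trailing remainder smaller than a full batch is carried over
        // unchanged so it can join the freshly created concats next round.
        next.insert(next.end(), nodes.begin() + offset, nodes.end());
      } else {
        // Fold up to kMaxParameters nodes into one new concat node.
        size_t end = std::min(nodes.size(), offset + kMaxParameters);
        next.push_back(static_cast<int64_t>(end - offset));
      }
    }
    nodes = std::move(next);
  }
  return nodes[0];
}
```

With 255 leaves the root ends up with 128 direct operands (one of them a nested 128-way concatenate), and with 256 leaves it ends up with 2, matching the structures asserted by the SplitInto2 and SplitInto3 tests deleted further below.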
diff --git a/third_party/xla/xla/service/gpu/variadic_op_splitter.h b/third_party/xla/xla/service/gpu/variadic_op_splitter.h
deleted file mode 100644
index 4449ce2..0000000
--- a/third_party/xla/xla/service/gpu/variadic_op_splitter.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_
-#define XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/string_view.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace gpu {
-
-// Splits variadic ops with many operands into pieces such that we don't exceed
-// the parameter space on the GPU. Currently only concatenate ops are split up.
-class VariadicOpSplitter : public HloModulePass {
- public:
-  absl::string_view name() const override { return "variadic-op-splitter"; }
-
-  using HloPassInterface::Run;
-  absl::StatusOr<bool> Run(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
-};
-
-}  // namespace gpu
-}  // namespace xla
-
-#endif  // XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_
diff --git a/third_party/xla/xla/service/gpu/variadic_op_splitter_test.cc b/third_party/xla/xla/service/gpu/variadic_op_splitter_test.cc
deleted file mode 100644
index 6d7b72e..0000000
--- a/third_party/xla/xla/service/gpu/variadic_op_splitter_test.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-/* Copyright 2018 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/gpu/variadic_op_splitter.h"
-
-#include <cstdint>
-#include <vector>
-
-#include <gtest/gtest.h>
-#include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/literal_util.h"
-#include "xla/service/hlo_parser.h"
-#include "xla/service/pattern_matcher.h"
-#include "xla/shape_util.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/util.h"
-#include "xla/xla_data.pb.h"
-
-namespace xla {
-namespace gpu {
-namespace {
-using match::Concatenate;
-
-class VariadicOpSplitterTest : public HloTestBase {};
-
-TEST_F(VariadicOpSplitterTest, DontSplit) {
-  auto module = ParseAndReturnVerifiedModule(R"(
-  HloModule TestModule
-
-  ENTRY TestComputation {
-    p0 = f16[30,41] parameter(0)
-    p1 = f16[30,41] parameter(1)
-    ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0}
-  })")
-                    .value();
-  EXPECT_FALSE(VariadicOpSplitter().Run(module.get()).value());
-}
-
-TEST_F(VariadicOpSplitterTest, SplitInto2) {
-  auto builder = HloComputation::Builder(TestName());
-  auto operand = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32_t>({42})));
-  std::vector<HloInstruction*> concat_operands(255, operand);
-  builder.AddInstruction(HloInstruction::CreateConcatenate(
-      ShapeUtil::MakeShape(S32, {255}), concat_operands, 0));
-  auto module = CreateNewVerifiedModule();
-  auto entry_computation = module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(VariadicOpSplitter().Run(module.get()).value());
-  EXPECT_TRUE(Match(entry_computation->root_instruction(),
-                    Concatenate().WithNumOperands(128).WithOperand(
-                        0, Concatenate().WithNumOperands(128))));
-}
-
-TEST_F(VariadicOpSplitterTest, SplitInto3) {
-  auto builder = HloComputation::Builder(TestName());
-  auto operand = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32_t>({42})));
-  std::vector<HloInstruction*> concat_operands(256, operand);
-  builder.AddInstruction(HloInstruction::CreateConcatenate(
-      ShapeUtil::MakeShape(S32, {256}), concat_operands, 0));
-  auto module = CreateNewVerifiedModule();
-  auto entry_computation = module->AddEntryComputation(builder.Build());
-  EXPECT_TRUE(VariadicOpSplitter().Run(module.get()).value());
-  EXPECT_TRUE(Match(entry_computation->root_instruction(),
-                    Concatenate(Concatenate().WithNumOperands(128),
-                                Concatenate().WithNumOperands(128))));
-}
-
-}  // namespace
-}  // namespace gpu
-}  // namespace xla
diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc
index fc319e6..7f6cce5 100644
--- a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc
+++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc
@@ -219,14 +219,12 @@
 absl::StatusOr<int64_t> HeapSimulator::MinimumMemoryForComputation(
     const HloComputation& computation, const HloInstructionSequence& sequence,
     const HloAliasAnalysis& alias_analysis,
-    const LogicalBuffer::SizeFunction& size_function,
-    const absl::flat_hash_map<const HloComputation*, int64_t>*
-        memory_by_computation) {
+    const LogicalBuffer::SizeFunction& size_function) {
   TF_ASSIGN_OR_RETURN(
       HeapSimulator::Result<HloValue> result,
       HeapSimulator::Run(std::make_unique<NoFragmentationStatsHeap<HloValue>>(),
                          computation, sequence, alias_analysis, size_function,
-                         HeapSimulator::Options(), memory_by_computation));
+                         HeapSimulator::Options()));
   return result.heap_size;
 }
 
@@ -267,11 +265,9 @@
     const HloComputation& computation,
     const HloInstructionSequence& instruction_sequence,
     const HloAliasAnalysis& alias_analysis,
-    const BufferValue::SizeFunction& size_fn, const Options& options,
-    const absl::flat_hash_map<const HloComputation*, int64_t>*
-        memory_by_computation) {
+    const BufferValue::SizeFunction& size_fn, const Options& options) {
   HeapSimulator heap(std::move(algorithm), size_fn, options,
-                     /*schedule=*/nullptr, memory_by_computation);
+                     /*schedule=*/nullptr);
   HloSchedule schedule(computation.parent());
   schedule.set_sequence(&computation, instruction_sequence);
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
@@ -291,7 +287,7 @@
     const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule,
     const Options& options) {
   HeapSimulator heap(std::move(algorithm), size_fn, options,
-                     /*schedule=*/schedule, nullptr);
+                     /*schedule=*/schedule);
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<HloLiveRange> hlo_live_range,
       HloLiveRange::Run(*schedule, alias_analysis, &computation));
@@ -492,19 +488,16 @@
   return absl::OkStatus();
 }
 
-HeapSimulator::HeapSimulator(
-    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
-    const BufferValue::SizeFunction& size_fn, const Options& options,
-    const HloSchedule* schedule,
-    const absl::flat_hash_map<const HloComputation*, int64_t>*
-        memory_by_computation)
+HeapSimulator::HeapSimulator(std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
+                             const BufferValue::SizeFunction& size_fn,
+                             const Options& options,
+                             const HloSchedule* schedule)
     : no_fragmentation_stats_(
           std::make_unique<NoFragmentationStatsHeap<HloValue>>()),
       algorithm_(std::move(algorithm)),
       size_fn_(size_fn),
       options_(options),
-      schedule_(schedule),
-      memory_by_computation_(memory_by_computation) {
+      schedule_(schedule) {
   debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
 }
 
@@ -629,21 +622,10 @@
 
 template <typename BufferType>
 void NoFragmentationStatsHeap<BufferType>::AccountForSubcomputationMemory(
-    const HloInstruction* instruction, int64_t alloc_size_by_instruction,
-    const absl::flat_hash_map<const HloComputation*, int64_t>&
-        memory_by_computation) {
+    const HloInstruction* instruction, int64_t alloc_size_by_instruction) {
   // We only count the memory usage of the largest subcomputation, instead of
   // adding them all, because subcomputations won't execute in parallel.
   int64_t max_subcomputation_bytes = 0;
-  for (const auto* c : instruction->called_computations()) {
-    auto it = memory_by_computation.find(c);
-    if (it != memory_by_computation.end()) {
-      int64_t subcomputation_bytes = it->second;
-      if (subcomputation_bytes > max_subcomputation_bytes) {
-        max_subcomputation_bytes = subcomputation_bytes;
-      }
-    }
-  }
   if (max_subcomputation_bytes > 0 &&
       (instruction->opcode() == HloOpcode::kWhile ||
        instruction->opcode() == HloOpcode::kCall ||
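The retained comment above ("we only count the memory usage of the largest subcomputation") captures a simple rule: called subcomputations never execute concurrently, so a caller's footprint is its own allocation plus the maximum child footprint rather than the sum. A toy illustration of that rule under invented names (this is not the HeapSimulator API):

```cpp
// Peak memory attributed to a while/call/conditional caller: own allocation
// plus the largest child peak, not the sum of all children.
#include <algorithm>
#include <cstdint>
#include <vector>

int64_t CallerPeakBytes(int64_t own_alloc_bytes,
                        const std::vector<int64_t>& subcomputation_peaks) {
  int64_t max_child = 0;
  for (int64_t child : subcomputation_peaks) {
    max_child = std::max(max_child, child);
  }
  return own_alloc_bytes + max_child;  // not own + sum(children)
}
```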
diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.h b/third_party/xla/xla/service/heap_simulator/heap_simulator.h
index 09e12d2..e446a4e 100644
--- a/third_party/xla/xla/service/heap_simulator/heap_simulator.h
+++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.h
@@ -148,9 +148,7 @@
   static absl::StatusOr<int64_t> MinimumMemoryForComputation(
       const HloComputation& computation, const HloInstructionSequence& sequence,
       const HloAliasAnalysis& alias_analysis,
-      const LogicalBuffer::SizeFunction& size_function,
-      const absl::flat_hash_map<const HloComputation*, int64_t>*
-          memory_by_computation = nullptr);
+      const LogicalBuffer::SizeFunction& size_function);
 
   static absl::StatusOr<int64_t> MinimumMemoryForComputation(
       const HloComputation& computation, const HloInstructionSequence& sequence,
@@ -184,9 +182,7 @@
       const HloInstructionSequence& instruction_sequence,
       const HloAliasAnalysis& alias_analysis,
       const BufferValue::SizeFunction& size_fn,
-      const Options& options = Options(),
-      const absl::flat_hash_map<const HloComputation*, int64_t>*
-          memory_by_computation = nullptr);
+      const Options& options = Options());
 
   // Same as above, but runs on with a schedule that covers all nested
   // computations.
@@ -204,9 +200,7 @@
   // be run recursively. I.e. the simulation is run over the whole module.
   HeapSimulator(std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
                 const BufferValue::SizeFunction& size_fn,
-                const Options& options, const HloSchedule* schedule = nullptr,
-                const absl::flat_hash_map<const HloComputation*, int64_t>*
-                    memory_by_computation = nullptr);
+                const Options& options, const HloSchedule* schedule = nullptr);
   ~HeapSimulator();
 
   absl::Status RunComputation(
@@ -244,13 +238,10 @@
   const std::unique_ptr<HeapAlgorithm<HloValue>> algorithm_;
   const BufferValue::SizeFunction size_fn_;
   const Options options_;
-  // schedule_ is set by buffer assignment, and memory_by_computation_ is
-  // set by hlo scheduling. Then, in RunComputation, we check both in order to
-  // handle subcomputations. It would be good to unify the handling of
-  // subcomputations, but it's not clear how.
+  // schedule_ is set by buffer assignment. Then, in RunComputation, we use it
+  // to handle subcomputations. It would be good to unify the handling of
+  // subcomputations, but it's not clear how.
   const HloSchedule* schedule_;
-  const absl::flat_hash_map<const HloComputation*, int64_t>*
-      memory_by_computation_;
 
   // Hold some sets for error-checking the sequence of Alloc and Free calls.
   absl::flat_hash_set<const HloValue*> allocated_buffers_;
@@ -290,9 +281,7 @@
   virtual void AccountForSubcomputationMemory(
       const HloInstruction* instruction,
       // The total number of bytes allocated by instruction.
-      int64_t alloc_size_by_instruction,
-      const absl::flat_hash_map<const HloComputation*, int64_t>&
-          memory_by_computation) {}
+      int64_t alloc_size_by_instruction) {}
 
   // Free de-allocates a previously allocated buffer.
   virtual void Free(const BufferType* buffer, int64_t size) = 0;
@@ -328,9 +317,8 @@
   void Alloc(const BufferType* buffer, int64_t size) override;
 
   void AccountForSubcomputationMemory(
-      const HloInstruction* instruction, int64_t alloc_size_by_instruction,
-      const absl::flat_hash_map<const HloComputation*, int64_t>&
-          memory_by_computation) override;
+      const HloInstruction* instruction,
+      int64_t alloc_size_by_instruction) override;
 
   void Free(const BufferType* buffer, int64_t size) override;
 
diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc
index cff0e2f..c0f48d4 100644
--- a/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc
+++ b/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc
@@ -210,9 +210,6 @@
   auto size_fn = [](const BufferValue& buffer) {
     return ShapeUtil::ByteSizeOf(buffer.shape());
   };
-  absl::flat_hash_map<const HloComputation*, int64_t> memory_by_computation;
-  memory_by_computation[cond_computation] = 5;
-  memory_by_computation[body_computation] = 16;
 
   std::unique_ptr<HloAliasAnalysis> alias_analysis =
       HloAliasAnalysis::Run(module.get()).value();
@@ -221,7 +218,7 @@
   // so we don't double count.
   EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
                     *entry_computation, schedule.sequence(entry_computation),
-                    *alias_analysis, size_fn, &memory_by_computation)
+                    *alias_analysis, size_fn)
                     .value());
 }
 
diff --git a/third_party/xla/xla/service/hlo.proto b/third_party/xla/xla/service/hlo.proto
index fdeaa68..83c3b85 100644
--- a/third_party/xla/xla/service/hlo.proto
+++ b/third_party/xla/xla/service/hlo.proto
@@ -112,7 +112,7 @@
 }
 
 // Serialization of HloInstruction.
-// Next ID: 89
+// Next ID: 90
 message HloInstructionProto {
   reserved 10;
   reserved "parameter_name";
@@ -385,6 +385,9 @@
 
   // For HLO value tracking.
   xla.OriginalValueProto original_value = 88;
+
+  // Specifies if a call instruction is a composite.
+  bool is_composite = 89;
 }
 
 // Serialization of HloComputation.
@@ -579,6 +582,7 @@
     FUSION = 2;
     LAYOUT = 3;
     DOT = 4;
+    FLAGNET = 5;
   }
 
   // Information about the optimization profile that this module contains.
diff --git a/third_party/xla/xla/service/hlo_computation_test.cc b/third_party/xla/xla/service/hlo_computation_test.cc
index 16ce018..a7190b3 100644
--- a/third_party/xla/xla/service/hlo_computation_test.cc
+++ b/third_party/xla/xla/service/hlo_computation_test.cc
@@ -15,8 +15,8 @@
 
 #include "xla/hlo/ir/hlo_computation.h"
 
+#include <cstdint>
 #include <memory>
-#include <set>
 #include <string>
 #include <string_view>
 #include <vector>
@@ -24,19 +24,24 @@
 #include <gmock/gmock.h>
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "xla/comparison_util.h"
 #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/utils/hlo_matchers.h"
-#include "xla/literal.h"
+#include "xla/literal_util.h"
 #include "xla/service/hlo_parser.h"
 #include "xla/service/pattern_matcher.h"
 #include "xla/service/pattern_matcher_gmock.h"
 #include "xla/shape.h"
+#include "xla/shape_tree.h"
 #include "xla/shape_util.h"
 #include "xla/test.h"
 #include "xla/test_helpers.h"
 #include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
 
 namespace xla {
@@ -940,5 +945,32 @@
             cloned_done.get()->async_wrapped_computation());
 }
 
+TEST_F(HloComputationTest, CompositeCall) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  add (x: f32[]) -> f32[] {
+    %x = f32[] parameter(0)
+    %constant = f32[] constant(2)
+    ROOT %z = f32[] add(f32[] %x, f32[] %constant)
+  }
+
+  ENTRY %CallR0F32AddScalar.v2 () -> f32[] {
+    %constant.1 = f32[] constant(42)
+    ROOT %call = f32[] call(f32[] %constant.1), to_apply=add, is_composite=true,
+      frontend_attributes={
+        composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},
+        composite.name="foo.bar",
+        composite.version="1"
+      }
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  HloInstruction* composite_call = FindInstruction(module.get(), "call");
+  EXPECT_EQ(composite_call->opcode(), HloOpcode::kCall);
+  EXPECT_TRUE(composite_call->is_composite());
+  EXPECT_EQ(composite_call->frontend_attributes().map().size(), 3);
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc
index 981b967..20f7690 100644
--- a/third_party/xla/xla/service/hlo_instruction_test.cc
+++ b/third_party/xla/xla/service/hlo_instruction_test.cc
@@ -15,6 +15,11 @@
 
 #include "xla/hlo/ir/hlo_instruction.h"
 
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+#include <limits>
+#include <memory>
 #include <optional>
 #include <set>
 #include <string>
@@ -22,16 +27,23 @@
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "xla/comparison_util.h"
+#include "xla/hlo/ir/collective_device_list.h"
 #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/literal.h"
+#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/layout_util.h"
+#include "xla/literal_util.h"
 #include "xla/protobuf_util.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/pattern_matcher.h"
 #include "xla/service/pattern_matcher_gmock.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/test.h"
 #include "xla/test_helpers.h"
@@ -40,6 +52,7 @@
 #include "xla/util.h"
 #include "xla/window_util.h"
 #include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 namespace {
@@ -2752,7 +2765,7 @@
 
   module->AddEntryComputation(main_builder.Build());
   // Should find conditional branch computations in the graph and it should
-  // point to the conditonal instruction.
+  // point to the conditional instruction.
   int num_conditional_branch_comp = 0;
   for (HloComputation* comp : module->MakeComputationPostOrder()) {
     if (comp->IsConditionalBranchComputation()) {
@@ -2827,7 +2840,7 @@
 
   module->AddEntryComputation(main_builder.Build());
   // Should find conditional branch computations in the graph and it should
-  // point to the conditonal instruction.
+  // point to the conditional instruction.
   int num_conditional_branch_comp = 0;
   for (HloComputation* comp : module->MakeComputationPostOrder()) {
     if (comp->IsConditionalBranchComputation()) {
diff --git a/third_party/xla/xla/service/hlo_memory_scheduler.cc b/third_party/xla/xla/service/hlo_memory_scheduler.cc
index 283b82e..83e4072 100644
--- a/third_party/xla/xla/service/hlo_memory_scheduler.cc
+++ b/third_party/xla/xla/service/hlo_memory_scheduler.cc
@@ -90,11 +90,8 @@
   static absl::StatusOr<HloInstructionSequence> Run(
       HloComputation* computation,
       const TuplePointsToAnalysis& points_to_analysis,
-      const BufferValue::SizeFunction& size_function,
-      const absl::flat_hash_map<const HloComputation*, int64_t>&
-          memory_by_computation) {
-    ListScheduler scheduler(computation, points_to_analysis, size_function,
-                            memory_by_computation);
+      const BufferValue::SizeFunction& size_function) {
+    ListScheduler scheduler(computation, points_to_analysis, size_function);
     return scheduler.CreateSchedule();
   }
 
@@ -115,13 +112,10 @@
 
   ListScheduler(HloComputation* computation,
                 const TuplePointsToAnalysis& points_to_analysis,
-                const BufferValue::SizeFunction& size_function,
-                const absl::flat_hash_map<const HloComputation*, int64_t>&
-                    memory_by_computation)
+                const BufferValue::SizeFunction& size_function)
       : computation_(computation),
         points_to_analysis_(points_to_analysis),
-        size_function_(size_function),
-        memory_by_computation_(memory_by_computation) {
+        size_function_(size_function) {
     // Create a map containing the LogicalBuffer uses for each HLO
     // instruction. An HLO instruction "uses" a LogicalBuffer if the
     // LogicalBuffer is in an operand of the instruction as indicated by
@@ -242,29 +236,7 @@
         freed_bytes += size_function_(*buffer);
       }
     }
-    // We only count the memory usage of the largest subcomputation, instead of
-    // adding them all, because subcomputations won't execute in parallel.
-    int64_t max_subcomputation_bytes = 0;
-    for (const auto* c : instruction->called_computations()) {
-      auto it = memory_by_computation_.find(c);
-      if (it != memory_by_computation_.end()) {
-        int64_t subcomputation_bytes = it->second;
-        if (subcomputation_bytes > max_subcomputation_bytes) {
-          max_subcomputation_bytes = subcomputation_bytes;
-        }
-      }
-    }
-    int64_t bytes_defined;
-    if (max_subcomputation_bytes > 0 &&
-        (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
-         opcode == HloOpcode::kConditional)) {
-      // The output buffer of while/call/conditional is always aliased with the
-      // output buffer of the root instruction in the body. Don't double count.
-      bytes_defined = max_subcomputation_bytes;
-    } else {
-      bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
-    }
-    return freed_bytes - bytes_defined;
+    return freed_bytes - entry.bytes_defined;
   }
 
   // Constructs the scheduling priority of the given instruction.
@@ -392,11 +364,6 @@
   HloComputation* computation_;
   const TuplePointsToAnalysis& points_to_analysis_;
   const BufferValue::SizeFunction& size_function_;
-  // Computations are analyzed in post-order. When scheduling an instruction
-  // that includes subcomputations, such as a while loop, we use this map to
-  // look up the memory needed by subcomputations.
-  const absl::flat_hash_map<const HloComputation*, int64_t>&
-      memory_by_computation_;
 
   // A map containing the LogicalBuffers that each instruction uses.
   absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
@@ -426,19 +393,15 @@
     const HloAliasAnalysis& alias_analysis,
     const BufferValue::SizeFunction& size_function,
     const MemorySchedulerAlgorithm& algorithm,
-    const absl::flat_hash_map<const HloComputation*, int64_t>&
-        memory_by_computation,
     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
   VLOG(2) << "Computation: " << computation->name();
 
   if (algorithm) {
     return algorithm(computation, points_to_analysis, alias_analysis,
-                     size_function, memory_by_computation, postprocessor,
-                     peak_memory);
+                     size_function, postprocessor, peak_memory);
   }
   return DefaultMemoryScheduler(computation, points_to_analysis, alias_analysis,
-                                size_function, memory_by_computation,
-                                postprocessor, peak_memory);
+                                size_function, postprocessor, peak_memory);
 }
 
 }  // namespace
@@ -448,8 +411,6 @@
     const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis,
     const BufferValue::SizeFunction& size_function,
-    const absl::flat_hash_map<const HloComputation*, int64_t>&
-        memory_by_computation,
     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
   // These variables are a hack to prevent overflows.
   int64_t cumulative_total_size = 0;
@@ -526,9 +487,9 @@
   CHECK_EQ(sequence.size(), computation->instruction_count());
   if (peak_memory) {
     TF_ASSIGN_OR_RETURN(
-        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
-                          *computation, sequence, alias_analysis, size_function,
-                          &memory_by_computation));
+        *peak_memory,
+        HeapSimulator::MinimumMemoryForComputation(
+            *computation, sequence, alias_analysis, size_function));
   }
   return sequence;
 }
@@ -538,8 +499,6 @@
     const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis,
     const BufferValue::SizeFunction& size_function,
-    const absl::flat_hash_map<const HloComputation*, int64_t>&
-        memory_by_computation,
     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
   // Index of HloInstruction in the `computation`.
   absl::flat_hash_map<const HloInstruction*, int64_t> inst_index;
@@ -586,9 +545,9 @@
   CHECK_EQ(sequence.size(), computation->instruction_count());
   if (peak_memory) {
     TF_ASSIGN_OR_RETURN(
-        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
-                          *computation, sequence, alias_analysis, size_function,
-                          &memory_by_computation));
+        *peak_memory,
+        HeapSimulator::MinimumMemoryForComputation(
+            *computation, sequence, alias_analysis, size_function));
   }
 
   return sequence;
@@ -605,16 +564,14 @@
              const absl::flat_hash_set<absl::string_view>& execution_threads,
              int64_t* peak_memory) -> absl::StatusOr<HloSchedule> {
     HloSchedule schedule(module);
-    absl::flat_hash_map<const HloComputation*, int64_t> memory_by_computation;
     for (auto* computation :
          module->MakeComputationPostOrder(execution_threads)) {
       if (!computation->IsFusionComputation()) {
-        TF_ASSIGN_OR_RETURN(
-            HloInstructionSequence computation_sequence,
-            ScheduleComputationHelper(
-                computation, points_to_analysis, alias_analysis, size_func,
-                computation_scheduler, memory_by_computation, postprocessor,
-                /*peak_memory=*/nullptr));
+        TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
+                            ScheduleComputationHelper(
+                                computation, points_to_analysis, alias_analysis,
+                                size_func, computation_scheduler, postprocessor,
+                                /*peak_memory=*/nullptr));
         schedule.set_sequence(computation, std::move(computation_sequence));
       }
     }
@@ -631,20 +588,18 @@
     const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis,
     const BufferValue::SizeFunction& size_function,
-    const absl::flat_hash_map<const HloComputation*, int64_t>&
-        memory_by_computation,
     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
-  TF_ASSIGN_OR_RETURN(HloInstructionSequence sequence,
-                      ListScheduler::Run(computation, points_to_analysis,
-                                         size_function, memory_by_computation));
+  TF_ASSIGN_OR_RETURN(
+      HloInstructionSequence sequence,
+      ListScheduler::Run(computation, points_to_analysis, size_function));
   if (postprocessor) {
     sequence = postprocessor(sequence);
   }
   if (peak_memory) {
     TF_ASSIGN_OR_RETURN(
-        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
-                          *computation, sequence, alias_analysis, size_function,
-                          &memory_by_computation));
+        *peak_memory,
+        HeapSimulator::MinimumMemoryForComputation(
+            *computation, sequence, alias_analysis, size_function));
   }
   return sequence;
 }
@@ -654,8 +609,6 @@
     const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis,
     const BufferValue::SizeFunction& size_function,
-    const absl::flat_hash_map<const HloComputation*, int64_t>&
-        memory_by_computation,
     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
   HloInstructionSequence sequence(computation->MakeInstructionPostOrder());
   if (postprocessor) {
@@ -663,9 +616,9 @@
   }
   if (peak_memory) {
     TF_ASSIGN_OR_RETURN(
-        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
-                          *computation, sequence, alias_analysis, size_function,
-                          &memory_by_computation));
+        *peak_memory,
+        HeapSimulator::MinimumMemoryForComputation(
+            *computation, sequence, alias_analysis, size_function));
   }
   return sequence;
 }
@@ -675,8 +628,6 @@
     const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis,
     const BufferValue::SizeFunction& size_function,
-    const absl::flat_hash_map<const HloComputation*, int64_t>&
-        memory_by_computation,
     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
   // We try a few schedulers and choose whichever returns a lower min-memory,
   // not accounting for fragmentation.
@@ -690,24 +641,21 @@
   TF_ASSIGN_OR_RETURN(
       HloInstructionSequence list_sequence,
       ListMemoryScheduler(computation, points_to_analysis, alias_analysis,
-                          size_function, memory_by_computation, postprocessor,
-                          &list_memory));
+                          size_function, postprocessor, &list_memory));
   VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
 
   int64_t dfs_memory;
   TF_ASSIGN_OR_RETURN(
       HloInstructionSequence dfs_sequence,
       DFSMemoryScheduler(computation, points_to_analysis, alias_analysis,
-                         size_function, memory_by_computation, postprocessor,
-                         &dfs_memory));
+                         size_function, postprocessor, &dfs_memory));
   VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
 
   int64_t post_order_memory;
-  TF_ASSIGN_OR_RETURN(
-      HloInstructionSequence post_order_sequence,
-      PostOrderMemoryScheduler(computation, points_to_analysis, alias_analysis,
-                               size_function, memory_by_computation,
-                               postprocessor, &post_order_memory));
+  TF_ASSIGN_OR_RETURN(HloInstructionSequence post_order_sequence,
+                      PostOrderMemoryScheduler(
+                          computation, points_to_analysis, alias_analysis,
+                          size_function, postprocessor, &post_order_memory));
   VLOG(2) << "Min-memory post order sequence: "
           << HumanReadableNumBytes(post_order_memory);
 
@@ -815,21 +763,6 @@
   return std::move(schedule);
 }
 
-absl::StatusOr<HloInstructionSequence> ScheduleComputation(
-    HloComputation* computation, const BufferValue::SizeFunction& size_function,
-    const MemorySchedulerPostprocessor& postprocessor) {
-  CHECK(!computation->IsFusionComputation());
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
-                      TuplePointsToAnalysis::Run(computation->parent()));
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
-                      HloAliasAnalysis::Run(computation->parent()));
-  absl::flat_hash_map<const HloComputation*, int64_t> empty_map;
-  return ScheduleComputationHelper(
-      computation, *points_to_analysis, *alias_analysis, size_function,
-      /*algorithm=*/nullptr, empty_map, postprocessor,
-      /*peak_memory=*/nullptr);
-}
-
 HloMemoryScheduler::HloMemoryScheduler(
     const BufferValue::SizeFunction& size_function,
     const ModuleSchedulerAlgorithm& algorithm)
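DefaultMemoryScheduler above runs the list, DFS, and post-order schedulers and keeps whichever reports the lowest simulated min-memory, ignoring fragmentation. A minimal standalone sketch of that selection step; the Candidate type is invented and is not the real MemorySchedulerAlgorithm interface.

```cpp
// Pick the candidate schedule with the lowest simulated peak memory.
#include <cstdint>
#include <string>
#include <vector>

struct Candidate {
  std::string name;        // e.g. "list", "dfs", "post-order"
  int64_t peak_bytes = 0;  // result of a min-memory heap simulation
};

// Assumes at least one candidate; fragmentation is ignored, as in the pass.
Candidate PickSchedule(const std::vector<Candidate>& candidates) {
  Candidate best = candidates.front();
  for (const Candidate& c : candidates) {
    if (c.peak_bytes < best.peak_bytes) best = c;
  }
  return best;
}
```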
diff --git a/third_party/xla/xla/service/hlo_memory_scheduler.h b/third_party/xla/xla/service/hlo_memory_scheduler.h
index 112ced3..2fb211a 100644
--- a/third_party/xla/xla/service/hlo_memory_scheduler.h
+++ b/third_party/xla/xla/service/hlo_memory_scheduler.h
@@ -19,7 +19,6 @@
 #include <cstdint>
 #include <functional>
 
-#include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
@@ -51,7 +50,6 @@
     std::function<absl::StatusOr<HloInstructionSequence>(
         HloComputation*, const TuplePointsToAnalysis&, const HloAliasAnalysis&,
         const LogicalBuffer::SizeFunction&,
-        const absl::flat_hash_map<const HloComputation*, int64_t>&,
         const MemorySchedulerPostprocessor&,
         /*peak_memory*/ int64_t*)>;
 
@@ -73,8 +71,6 @@
     const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis,
     const LogicalBuffer::SizeFunction& size_function,
-    const absl::flat_hash_map<const HloComputation*, int64_t>&
-        memory_by_computation,
     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
 
 // DFS-order scheduler
@@ -83,8 +79,6 @@
     const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis,
     const LogicalBuffer::SizeFunction& size_function,
-    const absl::flat_hash_map<const HloComputation*, int64_t>&
-        memory_by_computation,
     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
 
 // BFS-order scheduler
@@ -102,8 +96,6 @@
     const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis,
     const LogicalBuffer::SizeFunction& size_function,
-    const absl::flat_hash_map<const HloComputation*, int64_t>&
-        memory_by_computation,
     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
 
 // Naive Post Order scheduler
@@ -112,8 +104,6 @@
     const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis,
     const LogicalBuffer::SizeFunction& size_function,
-    const absl::flat_hash_map<const HloComputation*, int64_t>&
-        memory_by_computation,
     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
 
 // The default scheduling algorithm. Runs the list scheduler, the DFS scheduler,
@@ -125,8 +115,6 @@
     const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis,
     const LogicalBuffer::SizeFunction& size_function,
-    const absl::flat_hash_map<const HloComputation*, int64_t>&
-        memory_by_computation,
     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
 
 absl::StatusOr<HloSchedule> DefaultModuleScheduler(
@@ -146,13 +134,6 @@
     const absl::flat_hash_set<absl::string_view>& execution_threads = {},
     int64_t* peak_memory = nullptr);
 
-// Computes the schedule for a single computation.
-// Currently only used by the GPU backend.
-absl::StatusOr<HloInstructionSequence> ScheduleComputation(
-    HloComputation* computation,
-    const LogicalBuffer::SizeFunction& size_function,
-    const MemorySchedulerPostprocessor& postprocessor);
-
 // A pass which schedules the HLO instructions in a module. The HloModule's
 // schedule field is set to the resulting HloSchedule using
 // HloModule::set_schedule.
diff --git a/third_party/xla/xla/service/hlo_module_test.cc b/third_party/xla/xla/service/hlo_module_test.cc
index 2910932..f237575 100644
--- a/third_party/xla/xla/service/hlo_module_test.cc
+++ b/third_party/xla/xla/service/hlo_module_test.cc
@@ -37,9 +37,9 @@
 #include "xla/test.h"
 #include "xla/tests/hlo_test_base.h"
 #include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 #include "xla/xla.pb.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/lib/strings/proto_serialization.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc
index 8f097ca..2ff0697 100644
--- a/third_party/xla/xla/service/hlo_parser.cc
+++ b/third_party/xla/xla/service/hlo_parser.cc
@@ -80,6 +80,7 @@
 #include "tsl/lib/gtl/map_util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
 
 namespace xla {
 
@@ -1915,6 +1916,7 @@
           std::vector<HloInstruction*> async_wrapped_operands;
           std::vector<Shape> async_wrapped_operand_shapes;
           Shape async_wrapped_root_shape;
+          async_wrapped_operand_shapes.reserve(operands.size());
           for (const HloInstruction* operand : operands) {
             async_wrapped_operand_shapes.push_back(operand->shape());
           }
@@ -2249,8 +2251,11 @@
     }
     case HloOpcode::kCall: {
       optional<HloComputation*> to_apply;
+      optional<bool> is_composite = false;
       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                            &to_apply};
+      attrs["is_composite"] = {/*required=*/false, AttrTy::kBool,
+                               &is_composite};
       if ((!preset_operands && !ParseOperands(&operands, builder)) ||
           !ParseAttributes(attrs, allow_attributes, shape)) {
         return nullptr;
@@ -2266,8 +2271,10 @@
           })) {
         return nullptr;
       }
-      return builder->AddInstruction(
-          HloInstruction::CreateCall(*shape, operands, *to_apply));
+
+      auto call_op = HloInstruction::CreateCall(*shape, operands, *to_apply);
+      call_op->set_is_composite(is_composite.value());
+      return builder->AddInstruction(std::move(call_op));
     }
     case HloOpcode::kReduceWindow: {
       optional<HloComputation*> reduce_computation;
@@ -3178,6 +3185,13 @@
       optional<bool> indices_are_sorted = false;
       attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
                                      &indices_are_sorted};
+      optional<std::vector<int64_t>> operand_batching_dims;
+      attrs["operand_batching_dims"] = {
+          /*required=*/false, AttrTy::kBracedInt64List, &operand_batching_dims};
+      optional<std::vector<int64_t>> start_indices_batching_dims;
+      attrs["start_indices_batching_dims"] = {/*required=*/false,
+                                              AttrTy::kBracedInt64List,
+                                              &start_indices_batching_dims};
 
       if ((!preset_operands &&
            !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
@@ -3190,7 +3204,13 @@
               /*offset_dims=*/*offset_dims,
               /*collapsed_slice_dims=*/*collapsed_slice_dims,
               /*start_index_map=*/*start_index_map,
-              /*index_vector_dim=*/*index_vector_dim);
+              /*index_vector_dim=*/*index_vector_dim,
+              /*operand_batching_dims=*/
+              operand_batching_dims ? *operand_batching_dims
+                                    : std::vector<int64_t>(),
+              /*start_indices_batching_dims=*/
+              start_indices_batching_dims ? *start_indices_batching_dims
+                                          : std::vector<int64_t>());
       if (!maybe_infer_shape([&] {
             return ShapeInference::InferGatherShape(operands[0]->shape(),
                                                     operands[1]->shape(),
@@ -3226,6 +3246,13 @@
       optional<bool> unique_indices = false;
       attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool,
                                  &unique_indices};
+      optional<std::vector<int64_t>> input_batching_dims;
+      attrs["input_batching_dims"] = {
+          /*required=*/false, AttrTy::kBracedInt64List, &input_batching_dims};
+      optional<std::vector<int64_t>> scatter_indices_batching_dims;
+      attrs["scatter_indices_batching_dims"] = {/*required=*/false,
+                                                AttrTy::kBracedInt64List,
+                                                &scatter_indices_batching_dims};
 
       if ((!preset_operands && !ParseOperands(&operands, builder)) ||
           !ParseAttributes(attrs, allow_attributes, shape)) {
@@ -3243,7 +3270,13 @@
               /*update_window_dims=*/*update_window_dims,
               /*inserted_window_dims=*/*inserted_window_dims,
               /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
-              /*index_vector_dim=*/*index_vector_dim);
+              /*index_vector_dim=*/*index_vector_dim,
+              /*input_batching_dims=*/
+              input_batching_dims ? *input_batching_dims
+                                  : std::vector<int64_t>(),
+              /*scatter_indices_batching_dims=*/
+              scatter_indices_batching_dims ? *scatter_indices_batching_dims
+                                            : std::vector<int64_t>());
 
       if (!maybe_infer_shape([&] {
             absl::InlinedVector<const Shape*, 3> arg_shapes;
@@ -3421,11 +3454,21 @@
       if (!ParseAttributeName(&attribute)) {
         return false;
       }
-      if (lexer_.GetKind() != TokKind::kString) {
+
+      std::string result;
+      if (lexer_.GetKind() == TokKind::kString) {
+        if (!ParseString(&result)) {
+          return false;
+        }
+      } else if (lexer_.GetKind() == TokKind::kLbrace) {
+        if (!ParseJsonDict(&result)) {
+          return false;
+        }
+      } else {
         return false;
       }
-      (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
-      lexer_.Lex();
+
+      (*frontend_attributes->mutable_map())[attribute] = result;
     } while (EatIfPresent(TokKind::kComma));
   }
   return ParseToken(TokKind::kRbrace,
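The gather/scatter hunks above add operand_batching_dims / start_indices_batching_dims (and their scatter counterparts) to the parser. Semantically, a batching dimension pairs one dimension of the operand with one dimension of the indices, and the gather is performed independently for each batch element. A toy sketch of that idea on nested vectors, assuming a single batching dimension and 1-D per-batch slices (not real XLA shapes or the real gather implementation):

```cpp
// operand:  [batch][n]   indices: [batch][k]   result: [batch][k]
// Each batch element gathers only from its own slice of the operand.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<std::vector<float>> BatchedGather(
    const std::vector<std::vector<float>>& operand,
    const std::vector<std::vector<int64_t>>& indices) {
  // Assumes operand.size() == indices.size() and all indices are in range.
  std::vector<std::vector<float>> result(operand.size());
  for (size_t b = 0; b < operand.size(); ++b) {  // batching dimension
    for (int64_t idx : indices[b]) {             // gather within the batch
      result[b].push_back(operand[b][idx]);
    }
  }
  return result;
}
```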
diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/service/hlo_parser_test.cc
index 4c67c01..6378f08 100644
--- a/third_party/xla/xla/service/hlo_parser_test.cc
+++ b/third_party/xla/xla/service/hlo_parser_test.cc
@@ -15,6 +15,7 @@
 
 #include "xla/service/hlo_parser.h"
 
+#include <cstdint>
 #include <memory>
 #include <string>
 #include <string_view>
@@ -22,16 +23,25 @@
 #include <vector>
 
 #include <gtest/gtest.h>
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/ascii.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/array.h"
 #include "xla/hlo/ir/collective_device_list.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
-#include "xla/hlo/ir/hlo_frontend_attributes.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
+#include "xla/service/hlo_lexer.h"
+#include "xla/service/hlo_module_config.h"
 #include "xla/service/pattern_matcher.h"
 #include "xla/service/pattern_matcher_gmock.h"
 #include "xla/shape.h"
@@ -40,6 +50,7 @@
 #include "xla/tsl/lib/core/status_test_util.h"
 #include "xla/window_util.h"
 #include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
@@ -68,7 +79,7 @@
 //
 // In general we want to avoid these because we want HLO text to be
 // round-trippable!  But nested instructions, e.g. add(sqrt(x), y), cannot be
-// round-triped without modification.
+// round-tripped without modification.
 struct NonRoundtripTestData {
   std::string test_name;
   std::string input_module_string;
@@ -463,6 +474,96 @@
 
 )"
 },
+// composite call
+{
+"CompositeCall",
+R"(HloModule CompositeCall, entry_computation_layout={()->f32[]}
+
+%add (x: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %constant = f32[] constant(2)
+  ROOT %z = f32[] add(f32[] %x, f32[] %constant)
+}
+
+ENTRY %CompositeCall.v2 () -> f32[] {
+  %constant.1 = f32[] constant(42)
+  ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+}
+
+)"
+},
+// composite call with extra frontend attributes
+{
+"CompositeCallWithExtraFrontendAttributes",
+R"(HloModule CompositeCall, entry_computation_layout={()->f32[]}
+
+%add (x: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %constant = f32[] constant(2)
+  ROOT %z = f32[] add(f32[] %x, f32[] %constant)
+}
+
+ENTRY %CompositeCall.v2 () -> f32[] {
+  %constant.1 = f32[] constant(42)
+  ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1",foo="bar"}
+}
+
+)"
+},
+// composite call with the optional composite.attributes and composite.version omitted
+{
+"CompositeCallOptionalAttributesAndVersion",
+R"(HloModule CompositeCall, entry_computation_layout={()->f32[]}
+
+%add (x: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %constant = f32[] constant(2)
+  ROOT %z = f32[] add(f32[] %x, f32[] %constant)
+}
+
+ENTRY %CompositeCall.v2 () -> f32[] {
+  %constant.1 = f32[] constant(42)
+  ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.name="foo.bar"}
+}
+
+)"
+},
+// composite call with the optional composite.attributes omitted
+{
+"CompositeCallOptionalAttributes",
+R"(HloModule CompositeCall, entry_computation_layout={()->f32[]}
+
+%add (x: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %constant = f32[] constant(2)
+  ROOT %z = f32[] add(f32[] %x, f32[] %constant)
+}
+
+ENTRY %CompositeCall.v2 () -> f32[] {
+  %constant.1 = f32[] constant(42)
+  ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.name="foo.bar",composite.version="1"}
+}
+
+)"
+},
+// composite call with the optional composite.version omitted
+{
+"CompositeCallOptionalVersion",
+R"(HloModule CompositeCall, entry_computation_layout={()->f32[]}
+
+%add (x: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %constant = f32[] constant(2)
+  ROOT %z = f32[] add(f32[] %x, f32[] %constant)
+}
+
+ENTRY %CompositeCall.v2 () -> f32[] {
+  %constant.1 = f32[] constant(42)
+  ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar"}
+}
+
+)"
+},
 // CustomCall with backend_config.
 {
 "CustomCallWithOpaque",
@@ -1062,6 +1163,18 @@
 )"
 },
 {
+"BatchGather",
+R"(HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46,512]{5,4,3,2,1,0}, s64[10,9,8,7,5,512]{5,4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0}}
+
+ENTRY %Gather (input_tensor: f32[50,49,48,47,46,512], start_indices: s64[10,9,8,7,5,512]) -> f32[10,9,8,7,30,29,28,27,26,512] {
+  %input_tensor = f32[50,49,48,47,46,512]{5,4,3,2,1,0} parameter(0)
+  %start_indices = s64[10,9,8,7,5,512]{5,4,3,2,1,0} parameter(1)
+  ROOT %gather = f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46,512]{5,4,3,2,1,0} %input_tensor, s64[10,9,8,7,5,512]{5,4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, operand_batching_dims={5}, start_indices_batching_dims={5}, index_vector_dim=4, slice_sizes={30,29,28,27,26,1}
+}
+
+)"
+},
+{
 "Scatter",
 R"(HloModule StringifyScatter, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0}, s64[10,9,8,7,5]{4,3,2,1,0}, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46]{4,3,2,1,0}}
 
@@ -1081,6 +1194,25 @@
 )"
 },
 {
+"BatchScatter",
+R"(HloModule StringifyScatter, entry_computation_layout={(f32[50,49,48,47,46,512]{5,4,3,2,1,0}, s64[10,9,8,7,5,512]{5,4,3,2,1,0}, f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46,512]{5,4,3,2,1,0}}
+
+%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
+}
+
+ENTRY %Scatter (input_tensor: f32[50,49,48,47,46,512], scatter_indices: s64[10,9,8,7,5,512], updates: f32[10,9,8,7,30,29,28,27,26,512]) -> f32[50,49,48,47,46,512] {
+  %input_tensor = f32[50,49,48,47,46,512]{5,4,3,2,1,0} parameter(0)
+  %scatter_indices = s64[10,9,8,7,5,512]{5,4,3,2,1,0} parameter(1)
+  %updates = f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} parameter(2)
+  ROOT %scatter = f32[50,49,48,47,46,512]{5,4,3,2,1,0} scatter(f32[50,49,48,47,46,512]{5,4,3,2,1,0} %input_tensor, s64[10,9,8,7,5,512]{5,4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, input_batching_dims={5}, scatter_indices_batching_dims={5}, index_vector_dim=4, to_apply=%add_F32.v3
+}
+
+)"
+},
+{
 "TupleScatter",
 R"(HloModule TupleScatter, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0}, bf16[50,49,48,47,46]{4,3,2,1,0}, s64[10,9,8,7,5]{4,3,2,1,0}, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0}, bf16[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0})->(f32[50,49,48,47,46]{4,3,2,1,0}, bf16[50,49,48,47,46]{4,3,2,1,0})}
 
diff --git a/third_party/xla/xla/service/hlo_unstacker.cc b/third_party/xla/xla/service/hlo_unstacker.cc
index c6b0971..21d0eb9 100644
--- a/third_party/xla/xla/service/hlo_unstacker.cc
+++ b/third_party/xla/xla/service/hlo_unstacker.cc
@@ -790,14 +790,8 @@
   HloInstruction* bitcast = mutable_dynamic_slicing_fusion->AddInstruction(
       HloInstruction::CreateBitcast(mutable_dynamic_slicing_fusion->shape(),
                                     new_operand));
-  HloInstruction* bitcast_fusion =
-      mutable_dynamic_slicing_fusion->AddInstruction(
-          HloInstruction::CreateFusion(mutable_dynamic_slicing_fusion->shape(),
-                                       HloInstruction::FusionKind::kLoop,
-                                       bitcast));
-
   return mutable_dynamic_slicing_fusion->ReplaceAllUsesWithDifferentShape(
-      bitcast_fusion);
+      bitcast);
 }
 
 // This function recognizes fusions with the following pattern:
@@ -1430,6 +1424,7 @@
                                   /*force_unroll=*/true, /*prepare=*/false));
     CHECK(unrolled);
   }
+  VLOG(3) << "after unstacking \n" << module->ToString();
   return true;
 }
 
diff --git a/third_party/xla/xla/service/hlo_unstacker_test.cc b/third_party/xla/xla/service/hlo_unstacker_test.cc
index 37a9843..3b00f92 100644
--- a/third_party/xla/xla/service/hlo_unstacker_test.cc
+++ b/third_party/xla/xla/service/hlo_unstacker_test.cc
@@ -34,18 +34,18 @@
 
 using UnstackerTest = HloTestBase;
 
-int64_t GetSliceCountInEntry(HloModule* module) {
-  int64_t slice_instrs_count = 0;
+int64_t GetInstrCountWithOpcodeInEntry(HloModule* module, HloOpcode opcode) {
+  int64_t instr_with_opcode_count = 0;
   for (HloInstruction* instr :
        module->entry_computation()->MakeInstructionPostOrder()) {
-    if (instr->opcode() == HloOpcode::kSlice) {
-      slice_instrs_count++;
+    if (instr->opcode() == opcode) {
+      instr_with_opcode_count++;
     }
   }
-  return slice_instrs_count;
+  return instr_with_opcode_count;
 }
 
-TEST_F(UnstackerTest, UnstackLoopSingleFusionUser) {
+TEST_F(UnstackerTest, UnstackDSFusionPattern) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] {
@@ -63,7 +63,8 @@
     p1 = s8[3,128,128] get-tuple-element(wide_p), index=2
     one = s32[] constant(1)
     inc = s32[] add(i, one)
-    %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf 
+    %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice 
+    conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf 
     ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1)
   }
 
@@ -80,7 +81,7 @@
     init = s32[] constant(0)
     while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0)
     while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body 
-    while_use = s8[3,128,128] get-tuple-element(while.out), index=2 
+    while_use = s8[3,128,128] get-tuple-element(while.out), index=2
     ROOT out = bf16[8,128] get-tuple-element(while.out), index=1
   }
   )";
@@ -90,12 +91,15 @@
   TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
   EXPECT_TRUE(unstacked);
   // Check for the creation of slice instructions.
-  EXPECT_EQ(GetSliceCountInEntry(module.get()), 3);
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3);
+  // Check that the bitcast is unfused and there are no fusions left.
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion),
+            0);
   EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
-                                      std::nullopt));
+                                      std::nullopt, false));
 }
 
-TEST_F(UnstackerTest, UnstackLoopSingleFusionUser2) {
+TEST_F(UnstackerTest, UnstackReduceFusionPattern) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   dynamic-slice.609.reduce_sub_computation {
@@ -148,10 +152,10 @@
   TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
   EXPECT_TRUE(unstacked);
   EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
-                                      std::nullopt));
+                                      std::nullopt, false));
 }
 
-TEST_F(UnstackerTest, UnstackLoopSingleFusionUserNoBitcast) {
+TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcast) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[1,128,128] {
@@ -195,15 +199,17 @@
                           ParseAndReturnVerifiedModule(hlo_string));
   auto original = module->Clone();
   TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
-  std::cout << module->ToString() << std::endl;
   EXPECT_TRUE(unstacked);
   // Check for the creation of slice instructions.
-  EXPECT_EQ(GetSliceCountInEntry(module.get()), 3);
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3);
+  // Check that all the fusions are removed.
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion),
+            0);
   EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
                                       std::nullopt, false));
 }
 
-TEST_F(UnstackerTest, UnstackLoopSingleFusionUserNoBitcastKeepFused) {
+TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcastKeepFused) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[1,128,128] {
@@ -249,33 +255,35 @@
   auto unfuse = [](HloInstruction* instruction) { return false; };
   TF_ASSERT_OK_AND_ASSIGN(bool unstacked,
                           HloUnstacker(unfuse).Run(module.get()));
-  std::cout << module->ToString() << std::endl;
   EXPECT_TRUE(unstacked);
   // Check for the creation of slice instructions.
-  EXPECT_EQ(GetSliceCountInEntry(module.get()), 0);
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 0);
+  // Check that dynamic-slices are still fused.
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion),
+            3);
   EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
                                       std::nullopt, false));
 }
 
-TEST_F(UnstackerTest, UnstackLoopSingleFusionUserDifferentLayout) {
+TEST_F(UnstackerTest, UnstackDSFusionPatternWithDifferentLayout) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   %fused_computation.30.clone (param_0.153: bf16[32,4,64,64,3], param_1.123: s32[]) -> bf16[64,4,64,3] {
-    %param_0.153 = bf16[32,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} parameter(0)
+    %param_0.153 = bf16[32,4,64,64,3]{2,1,4,3,0} parameter(0)
     %param_1.123 = s32[]{:T(128)} parameter(1)
     %constant.227 = s32[]{:T(128)} constant(0)
-    %dynamic-slice.5 = bf16[1,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} dynamic-slice(bf16[32,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} %param_0.153, s32[]{:T(128)} %param_1.123, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, /*index=5*/s32[]{:T(128)} %constant.227), dynamic_slice_sizes={1,4,64,64,3}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[]}
-    ROOT %bitcast.102 = bf16[64,4,64,3]{0,1,3,2:T(4,128)(2,1)} bitcast(bf16[1,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} %dynamic-slice.5)
+    %dynamic-slice.5 = bf16[1,4,64,64,3]{2,1,4,3,0} dynamic-slice(bf16[32,4,64,64,3]{2,1,4,3,0} %param_0.153, s32[]{:T(128)} %param_1.123, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, /*index=5*/s32[]{:T(128)} %constant.227), dynamic_slice_sizes={1,4,64,64,3}
+    ROOT %bitcast.102 = bf16[64,4,64,3]{0,1,3,2} bitcast(bf16[1,4,64,64,3]{2,1,4,3,0} %dynamic-slice.5)
   }
 
   %while.body (wide_param: (s32[], bf16[8,128], bf16[32,4,64,64,3])) -> (s32[], bf16[8,128], bf16[32,4,64,64,3]) {
     wide_p = (s32[], bf16[8,128], bf16[32,4,64,64,3]) parameter(0)
     i = s32[] get-tuple-element(wide_p), index=0
     p0 = bf16[8,128] get-tuple-element(wide_p), index=1
-    p1 = bf16[32,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} get-tuple-element(wide_p), index=2
+    p1 = bf16[32,4,64,64,3]{2,1,4,3,0} get-tuple-element(wide_p), index=2
     one = s32[] constant(1)
     inc = s32[] add(i, one)
-    %fusion.67830 = bf16[64,4,64,3]{0,1,3,2:T(4,128)(2,1)} fusion(p1, i), kind=kLoop, calls=%fused_computation.30.clone
+    %fusion.67830 = bf16[64,4,64,3]{0,1,3,2} fusion(p1, i), kind=kLoop, calls=%fused_computation.30.clone
     ROOT out = (s32[], bf16[8,128], bf16[32,4,64,64,3]) tuple(inc, p0, p1)
   }
 
@@ -291,7 +299,7 @@
     p1 = bf16[8,128] parameter(1)
     init = s32[] constant(0)
     while.input = (s32[], bf16[8,128], bf16[32,4,64,64,3]) tuple(init, p1, p0)
-    while.out = (s32[], bf16[8,128], bf16[32,4,64,64,3]) while(while.input), condition=%while.cond , body=%while.body 
+    while.out = (s32[], bf16[8,128], bf16[32,4,64,64,3]) while(while.input), condition=%while.cond , body=%while.body
     while_use = bf16[32,4,64,64,3] get-tuple-element(while.out), index=2
     ROOT out = bf16[8,128] get-tuple-element(while.out), index=1
   }
@@ -301,11 +309,17 @@
   auto original = module->Clone();
   TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
   EXPECT_TRUE(unstacked);
+  // Check for the creation of slice instructions.
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice),
+            32);
+  // Check that no fusions remain in the entry computation.
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion),
+            0);
   EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
                                       std::nullopt));
 }
 
-TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUser) {
+TEST_F(UnstackerTest, UnstackNestedDSFusionPattern) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] {
@@ -358,14 +372,14 @@
   TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
   EXPECT_TRUE(unstacked);
   // Check for the creation of slice instructions.
-  EXPECT_EQ(GetSliceCountInEntry(module.get()), 3);
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3);
   EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
                                       std::nullopt, false));
 }
 
 // Instead of slicing the entire shape, this test slices only even elements from
 // the first parameter.
-TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserDynamicIndex) {
+TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithDynamicIndex) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   %fused_computation.slice (param_0.51117: s8[6,128,128], p1: s32[]) -> s8[128,128] {
@@ -423,7 +437,7 @@
                                       std::nullopt, false));
 }
 
-TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserMultipleIndex) {
+TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithMultipleIndex) {
   std::string hlo_string = R"(
     HloModule SimpleLoop
     %fused_computation.slice.1 (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] {
@@ -497,12 +511,12 @@
   EXPECT_TRUE(unstacked);
   // Check for the creation of slice instructions. For each unstacked input, we
   // create 4 slices, 8 in total.
-  EXPECT_EQ(GetSliceCountInEntry(module.get()), 8);
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 8);
   EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
                                       std::nullopt, false));
 }
 
-TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserDiffereOperandsOrder) {
+TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithDifferentOperandsOrder) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] {
@@ -555,12 +569,12 @@
   TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
   EXPECT_TRUE(unstacked);
   // Check for the creation of slice instructions.
-  EXPECT_EQ(GetSliceCountInEntry(module.get()), 3);
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3);
   EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
                                       std::nullopt, false));
 }
 
-TEST_F(UnstackerTest, UnstackLoopMultipleNestedFusionUsersSameUnstackingComps) {
+TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithSameUnstackingComps) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   %fused_computation.slice.1 (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] {
@@ -631,12 +645,12 @@
   TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
   EXPECT_TRUE(unstacked);
   // Check for the creation of slice instructions.
-  EXPECT_EQ(GetSliceCountInEntry(module.get()), 3);
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3);
   EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
                                       std::nullopt, false));
 }
 
-TEST_F(UnstackerTest, NotUnstackLoopMultipleDifferentUnstackingComps) {
+TEST_F(UnstackerTest, NotUnstackNestedDSFusionPatternWithDifferentUnstackingComps) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   %fused_computation.slice.1 (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[1,128,128] {
@@ -691,7 +705,86 @@
   EXPECT_FALSE(unstacked);
 }
 
-TEST_F(UnstackerTest, UnstackMultipleLoops) {
+TEST_F(UnstackerTest, UnstackNestedDSFusionPatternSingleNestedLoop) {
+  std::string hlo_string = R"(
+    HloModule SimpleLoop
+    %fused_computation.slice (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] {
+      %param_0.51117 = s8[4,128,128] parameter(0)
+      p1 = s32[] parameter(1)
+      %constant.85694 = s32[] constant(0)
+      %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[4,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128}
+      ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040)
+    }
+
+    %fused_computation.inner (param_0.34523: bf16[8,128], param_1.30691: s8[4,128,128], p2: s32[]) -> bf16[8,128] {
+      %param_0.34523 = bf16[8,128] parameter(0)
+      %param_1.30691 = s8[4,128,128] parameter(1)
+      p2 = s32[] parameter(2)
+      %fusion.67830 = s8[128,128] fusion(s8[4,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice
+      ROOT %convolution.3447 = bf16[8,128] convolution(bf16[8,128] %param_0.34523, s8[128,128] %fusion.67830), dim_labels=bf_io->bf
+    }
+
+    %while.body.inner (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) {
+      wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
+      i = s32[] get-tuple-element(wide_p), index=0
+      inner_param_0 = bf16[8,128] get-tuple-element(wide_p), index=1
+      inner_param_1 = s8[4,128,128] get-tuple-element(wide_p), index=2
+      one = s32[] constant(1)
+      inc = s32[] add(i, one)
+      fusion.conv = bf16[8,128] fusion(inner_param_0, inner_param_1, i), kind=kOutput, calls=%fused_computation.inner
+      ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(inc, fusion.conv, inner_param_1)
+    }
+
+    %while.cond.inner (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] {
+      wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
+      i = s32[] get-tuple-element(wide_p), index=0
+      %constant.12857 = s32[] constant(4)
+      ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT
+    }
+
+    %while.body (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) {
+      wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
+      i = s32[] get-tuple-element(wide_p), index=0
+      param0 = bf16[8,128] get-tuple-element(wide_p), index=1
+      param1 = s8[4,128,128] get-tuple-element(wide_p), index=2
+      one = s32[] constant(2)
+      zero = s32[] constant(0)
+      mult = s32[] multiply(i, one)
+      inner.in = (s32[], bf16[8,128], s8[4,128,128]) tuple(zero, param0, param1)
+      inner.out = (s32[], bf16[8,128], s8[4,128,128]) while(inner.in), condition=%while.cond.inner, body=%while.body.inner
+      fusion.conv.inner = bf16[8,128] get-tuple-element(inner.out), index=1
+      ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(mult, fusion.conv.inner, param1)
+    }
+
+    %while.cond (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] {
+      wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
+      i = s32[] get-tuple-element(wide_p), index=0
+      %constant.12857 = s32[] constant(20)
+      add = s32[] add(%constant.12857, %constant.12857)
+      ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, add), direction=LT
+    }
+
+    ENTRY main {
+      weight = s8[4,128,128] parameter(0)
+      p1 = bf16[8,128] parameter(1)
+      init = s32[] constant(1)
+      while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init, p1, weight)
+      while.out = (s32[], bf16[8,128], s8[4,128,128]) while(while.input), condition=%while.cond , body=%while.body
+      ROOT out = bf16[8,128] get-tuple-element(while.out), index=1
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  auto original = module->Clone();
+  TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
+  EXPECT_TRUE(unstacked);
+  // Check for the creation of slice instructions.
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 4);
+  EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
+                                      std::nullopt, false));
+}
+
+TEST_F(UnstackerTest, UnstackNestedDSFusionPatternTwoNestedLoops) {
   std::string hlo_string = R"(
     HloModule SimpleLoop
     %fused_computation.slice1 (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] {
@@ -827,91 +920,12 @@
   EXPECT_TRUE(unstacked);
   // Check for the creation of slice instructions. For each loop there is one
   // unstacked input that creates 4 slices, in total 8 slices for two loops.
-  EXPECT_EQ(GetSliceCountInEntry(module.get()), 8);
+  EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 8);
   EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
                                       std::nullopt, false));
 }
 
-TEST_F(UnstackerTest, UnstackNestedLoopSingleNestedFusionUser) {
-  std::string hlo_string = R"(
-    HloModule SimpleLoop
-    %fused_computation.slice (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] {
-      %param_0.51117 = s8[4,128,128] parameter(0)
-      p1 = s32[] parameter(1)
-      %constant.85694 = s32[] constant(0)
-      %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[4,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128}
-      ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040)
-    }
-
-    %fused_computation.inner (param_0.34523: bf16[8,128], param_1.30691: s8[4,128,128], p2: s32[]) -> bf16[8,128] {
-      %param_0.34523 = bf16[8,128] parameter(0)
-      %param_1.30691 = s8[4,128,128] parameter(1)
-      p2 = s32[] parameter(2)
-      %fusion.67830 = s8[128,128] fusion(s8[4,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice
-      ROOT %convolution.3447 = bf16[8,128] convolution(bf16[8,128] %param_0.34523, s8[128,128] %fusion.67830), dim_labels=bf_io->bf
-    }
-
-    %while.body.inner (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) {
-      wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
-      i = s32[] get-tuple-element(wide_p), index=0
-      inner_param_0 = bf16[8,128] get-tuple-element(wide_p), index=1
-      inner_param_1 = s8[4,128,128] get-tuple-element(wide_p), index=2
-      one = s32[] constant(1)
-      inc = s32[] add(i, one)
-      fusion.conv = bf16[8,128] fusion(inner_param_0, inner_param_1, i), kind=kOutput, calls=%fused_computation.inner
-      ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(inc, fusion.conv, inner_param_1)
-    }
-
-    %while.cond.inner (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] {
-      wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
-      i = s32[] get-tuple-element(wide_p), index=0
-      %constant.12857 = s32[] constant(4)
-      ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT
-    }
-
-    %while.body (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) {
-      wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
-      i = s32[] get-tuple-element(wide_p), index=0
-      param0 = bf16[8,128] get-tuple-element(wide_p), index=1
-      param1 = s8[4,128,128] get-tuple-element(wide_p), index=2
-      one = s32[] constant(2)
-      zero = s32[] constant(0)
-      mult = s32[] multiply(i, one)
-      inner.in = (s32[], bf16[8,128], s8[4,128,128]) tuple(zero, param0, param1)
-      inner.out = (s32[], bf16[8,128], s8[4,128,128]) while(inner.in), condition=%while.cond.inner, body=%while.body.inner
-      fusion.conv.inner = bf16[8,128] get-tuple-element(inner.out), index=1
-      ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(mult, fusion.conv.inner, param1)
-    }
-
-    %while.cond (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] {
-      wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0)
-      i = s32[] get-tuple-element(wide_p), index=0
-      %constant.12857 = s32[] constant(20)
-      add = s32[] add(%constant.12857, %constant.12857)
-      ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, add), direction=LT
-    }
-
-    ENTRY main {
-      weight = s8[4,128,128] parameter(0)
-      p1 = bf16[8,128] parameter(1)
-      init = s32[] constant(1)
-      while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init, p1, weight)
-      while.out = (s32[], bf16[8,128], s8[4,128,128]) while(while.input), condition=%while.cond , body=%while.body
-      ROOT out = bf16[8,128] get-tuple-element(while.out), index=1
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  auto original = module->Clone();
-  TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
-  EXPECT_TRUE(unstacked);
-  // Check for the creation of slice instructions.
-  EXPECT_EQ(GetSliceCountInEntry(module.get()), 4);
-  EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
-                                      std::nullopt, false));
-}
-
-TEST_F(UnstackerTest, UnstackSingleLoopOnlyWithDSAndDUS) {
+TEST_F(UnstackerTest, UnstackDSAndDUSPattern) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   %fused_computation.slice (param_0.51117: s32[4,3], offset: s32[]) -> s32[3] {
@@ -975,7 +989,7 @@
 // Unstacking outer loop at index 1 forces to unstacked inner while at index 1
 // as well. This is because the output of the outer loop at index 1 is aliased
 // to the output of the inner while at index 1.
-TEST_F(UnstackerTest, UnstackNestedLoopWithDSAndDUS) {
+TEST_F(UnstackerTest, UnstackDSAndDUSPatternNestedLoop) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
 
@@ -1059,7 +1073,7 @@
 
 // Unstacking the first loop at index 1 forces to unstack the second loop at
 // index 1 as well.
-TEST_F(UnstackerTest, UnstackLoopFeedingLoopWithDUS) {
+TEST_F(UnstackerTest, UnstackDSAndDUSPatternLoopFeedingLoop) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
 
@@ -1076,45 +1090,43 @@
     %param_0.51117 = bf16[4,1,8,257,128] parameter(0)
     offset = s32[] parameter(1)
     zero = s32[] constant(0)
-    %dynamic-slice.22040 = bf16[1,1,8,257,128]
-    dynamic-slice(bf16[4,1,8,257,128] %param_0.51117, offset, zero, zero,
-    zero, zero), dynamic_slice_sizes={1,1,8,257,128} ROOT %bitcast.31250 =
-    bf16[1,8,257,128] bitcast(%dynamic-slice.22040)
+    %dynamic-slice.22040 = bf16[1,1,8,257,128] dynamic-slice(bf16[4,1,8,257,128] %param_0.51117, offset, zero, zero, zero, zero), dynamic_slice_sizes={1,1,8,257,128} 
+    ROOT %bitcast.31250 = bf16[1,8,257,128] bitcast(%dynamic-slice.22040)
   }
 
   first.body {
     loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) 
-    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0 
-    get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1 
+    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0
+    get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1
     constant = bf16[1,8,257,128] constant({...})
     sliced = bf16[1,8,257,128] fusion(get-tuple-element.2, get-tuple-element.1), kind=kLoop, calls=%fused_computation.slice
     tmp = bf16[1,8,257,128] add(sliced, sliced)
     one = s32[] constant(1)
-    idx = s32[] add(get-tuple-element.1, one) 
+    idx = s32[] add(get-tuple-element.1, one)
     ROOT out = tuple(idx, get-tuple-element.2)
   }
   first.condition {
-    loop_var.1 = (s32[], bf16[4,1,8,257,128])
-    parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),
-    index=0 constant.2 = s32[] constant(4) ROOT less-than = pred[]
-    compare(get-tuple-element.1, constant.2), direction=LT
+    loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) 
+    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+    constant.2 = s32[] constant(4) 
+    ROOT less-than = pred[] compare(get-tuple-element.1, constant.2), direction=LT
   }
   
   next.body {
     loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) 
-    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0 
-    get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1 
+    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0
+    get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1
     constant = bf16[1,8,257,128] constant({...})
-    update.sliced = bf16[4,1,8,257,128] fusion(get-tuple-element.2, get-tuple-element.1, constant), kind=kLoop, calls=%fused_computation.update.slice 
+    update.sliced = bf16[4,1,8,257,128] fusion(get-tuple-element.2, get-tuple-element.1, constant), kind=kLoop, calls=%fused_computation.update.slice
     one = s32[] constant(1)
-    idx = s32[] add(get-tuple-element.1, one) 
+    idx = s32[] add(get-tuple-element.1, one)
     ROOT out = tuple(idx, update.sliced)
   }
   next.condition {
-    loop_var.1 = (s32[], bf16[4,1,8,257,128])
-    parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),
-    index=0 constant.2 = s32[] constant(4) ROOT less-than = pred[]
-    compare(get-tuple-element.1, constant.2), direction=LT
+    loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0)
+    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+    constant.2 = s32[] constant(4) 
+    ROOT less-than = pred[] compare(get-tuple-element.1, constant.2), direction=LT
   }
 
   ENTRY SimpleLoop {
@@ -1138,7 +1150,7 @@
   EXPECT_TRUE(unstacked);
 }
 
-TEST_F(UnstackerTest, UnstackLoopFeedingLoopWithDUSFusionWithPad) {
+TEST_F(UnstackerTest, UnstackDUSFusionWithPadPatternLoopFeedingLoop) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   fused_computation.75.clone {
@@ -1213,7 +1225,7 @@
   EXPECT_TRUE(unstacked);
 }
 
-TEST_F(UnstackerTest, UnstackSingleLoopWithDSFusionWithAdd) {
+TEST_F(UnstackerTest, UnstackDUSFusionWithAddPattern) {
   std::string hlo_string = R"(
   HloModule SimpleLoop
   
diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc
index 67259f8..06ae990 100644
--- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc
+++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc
@@ -1757,6 +1757,7 @@
       [&](const ShapeIndex& index,
           const HloValueSemantics* semantics) -> absl::Status {
         std::vector<HloValueSemantics> semantics_vector;
+        semantics_vector.reserve(semantics_tree_vec.size());
         for (size_t i = 0; i < semantics_tree_vec.size(); ++i) {
           semantics_vector.push_back(
               *(semantics_tree_vec[i].find(index)->second));
diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc
index ebf03de..3a8d958 100644
--- a/third_party/xla/xla/service/hlo_verifier.cc
+++ b/third_party/xla/xla/service/hlo_verifier.cc
@@ -35,6 +35,7 @@
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/string_view.h"
@@ -1322,6 +1323,34 @@
   for (int64_t i = 0; i < call->to_apply()->num_parameters(); ++i) {
     TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
   }
+  if (call->is_composite()) {
+    TF_RET_CHECK(call->has_frontend_attributes())
+        << "A composite call op must have frontend attributes";
+    auto map = call->frontend_attributes().map();
+    if (auto name = map.find("composite.name");
+        name == map.end() || name->second.empty()) {
+      return InvalidArgument(
+          "A composite call op must have frontend attributes with key "
+          "composite.name whose value is non-empty");
+    }
+    if (auto attributes = map.find("composite.attributes");
+        attributes != map.end() && attributes->second.empty()) {
+      return InvalidArgument(
+          "A composite call op must have frontend attributes with key "
+          "composite.attributes whose value is default: {} or non-empty");
+    }
+    if (auto version_str = map.find("composite.version");
+        version_str != map.end()) {
+      int64_t version = 0;
+      if (!absl::SimpleAtoi(version_str->second, &version) || version < 0) {
+        return InvalidArgument(
+            "A composite call op must have frontend attributes with a "
+            "composite.version whose value is a non-negative integer but got: "
+            "%s",
+            version_str->second);
+      }
+    }
+  }
   // The shape of kCall should match the shape of the computation it calls.
   return CheckShape(call, call->to_apply()->root_instruction()->shape());
 }
@@ -1920,6 +1949,26 @@
         }
         return ShapesSame(instruction->shape(), inferred_shape, equal);
       }
+      case HloOpcode::kCopy: {
+        // Disallow host offloading copies which change FpPrecision.
+        if (opts_.IsLayoutSensitive()) {
+          if (instruction->shape().has_layout() &&
+              inferred_shape.has_layout()) {
+            int64_t instruction_memory_space =
+                instruction->shape().layout().memory_space();
+            int64_t operand_memory_space =
+                inferred_shape.layout().memory_space();
+            if (instruction_memory_space != operand_memory_space &&
+                (instruction_memory_space == Layout::kHostMemorySpace ||
+                 operand_memory_space == Layout::kHostMemorySpace)) {
+              // This is a host->device or a device->host copy.
+              return Shape::Equal().IgnoreMemorySpaceInLayout()(
+                  instruction->shape(), inferred_shape);
+            }
+          }
+        }
+        [[fallthrough]];
+      }
 
       // We allow arbitrary layout and f32->bf16 transformations on all other
       // instructions, although this may be made more strict pending discussion
@@ -2907,6 +2956,15 @@
       }
     }
 
+    if (instruction->has_to_apply() &&
+        instruction->to_apply()->execution_thread() !=
+            instruction->parent()->execution_thread()) {
+      return Internal(
+          "%s top_apply computation execution thread does not match (%s vs %s)",
+          instruction->name(), instruction->to_apply()->execution_thread(),
+          instruction->parent()->execution_thread());
+    }
+
     return absl::OkStatus();
   }
 
diff --git a/third_party/xla/xla/service/hlo_verifier_test.cc b/third_party/xla/xla/service/hlo_verifier_test.cc
index 82f712a..b7055e9 100644
--- a/third_party/xla/xla/service/hlo_verifier_test.cc
+++ b/third_party/xla/xla/service/hlo_verifier_test.cc
@@ -83,6 +83,15 @@
                     LayoutAssignment::InstructionCanChangeLayout) {}
 };
 
+class HloVerifierTestLayoutSensitiveAndAllowMixedPrecision
+    : public HloTestBase {
+ public:
+  HloVerifierTestLayoutSensitiveAndAllowMixedPrecision()
+      : HloTestBase(/*verifier_layout_sensitive=*/true,
+                    /*allow_mixed_precision_in_hlo_verifier=*/true,
+                    LayoutAssignment::InstructionCanChangeLayout) {}
+};
+
 class HloVerifierTestLayoutFusion : public HloTestBase {
  public:
   HloVerifierTestLayoutFusion()
@@ -216,8 +225,164 @@
   auto status = verifier().Run(module.get()).status();
   ASSERT_FALSE(status.ok());
   EXPECT_THAT(status.message(),
-              HasSubstr("expects parent computation thread name same as called "
-                        "computation's thread name"));
+              HasSubstr("mycall top_apply computation execution thread does "
+                        "not match (parallel_thread vs main)"));
+}
+
+TEST_F(HloVerifierTest, CompositeCall) {
+  constexpr absl::string_view hlo = R"(
+  HloModule Module
+
+  add_n {
+    x = f32[] parameter(0)
+    constant = f32[] constant(2)
+    ROOT z = f32[] add(f32[] x, f32[] constant)
+  }
+
+  ENTRY entry {
+    constant = f32[] constant(42)
+    ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.name="foo.bar",composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.version="1"}
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+  auto status = verifier().Run(module.get()).status();
+  EXPECT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTest, CompositeCallMissingFrontendAttributes) {
+  constexpr absl::string_view hlo = R"(
+  HloModule Module
+
+  add_n {
+    x = f32[] parameter(0)
+    constant = f32[] constant(2)
+    ROOT z = f32[] add(f32[] x, f32[] constant)
+  }
+
+  ENTRY entry {
+    constant = f32[] constant(42)
+    ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.message(),
+              HasSubstr("A composite call op must have frontend attributes"));
+}
+
+TEST_F(HloVerifierTest, CompositeCallOptionalAttributesAndVersion) {
+  constexpr absl::string_view hlo = R"(
+  HloModule Module
+
+  add_n {
+    x = f32[] parameter(0)
+    constant = f32[] constant(2)
+    ROOT z = f32[] add(f32[] x, f32[] constant)
+  }
+
+  ENTRY entry {
+    constant = f32[] constant(42)
+    ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.name="foo.bar"}
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+  auto status = verifier().Run(module.get()).status();
+  EXPECT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTest, CompositeCallOptionalAttributes) {
+  constexpr absl::string_view hlo = R"(
+  HloModule Module
+
+  add_n {
+    x = f32[] parameter(0)
+    constant = f32[] constant(2)
+    ROOT z = f32[] add(f32[] x, f32[] constant)
+  }
+
+  ENTRY entry {
+    constant = f32[] constant(42)
+    ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.name="foo.bar",composite.version="1"}
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+  auto status = verifier().Run(module.get()).status();
+  EXPECT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTest, CompositeCallMissingName) {
+  constexpr absl::string_view hlo = R"(
+  HloModule Module
+
+  add_n {
+    x = f32[] parameter(0)
+    constant = f32[] constant(2)
+    ROOT z = f32[] add(f32[] x, f32[] constant)
+  }
+
+  ENTRY entry {
+    constant = f32[] constant(42)
+    ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.version="1"}
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.message(),
+              HasSubstr("A composite call op must have frontend attributes "
+                        "with key composite.name whose value is non-empty"));
+}
+
+TEST_F(HloVerifierTest, CompositeCallOptionalVersion) {
+  constexpr absl::string_view hlo = R"(
+  HloModule Module
+
+  add_n {
+    x = f32[] parameter(0)
+    constant = f32[] constant(2)
+    ROOT z = f32[] add(f32[] x, f32[] constant)
+  }
+
+  ENTRY entry {
+    constant = f32[] constant(42)
+    ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar"}
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+  auto status = verifier().Run(module.get()).status();
+  EXPECT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTest, CompositeCallNonNegativeVersion) {
+  constexpr absl::string_view hlo = R"(
+  HloModule Module
+
+  add_n {
+    x = f32[] parameter(0)
+    constant = f32[] constant(2)
+    ROOT z = f32[] add(f32[] x, f32[] constant)
+  }
+
+  ENTRY entry {
+    constant = f32[] constant(42)
+    ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="-1"}
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(
+      status.message(),
+      HasSubstr("A composite call op must have frontend attributes with a "
+                "composite.version whose value is a non-negative integer"));
 }
 
 TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
@@ -2000,10 +2165,10 @@
   )";
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           ParseAndReturnUnverifiedModule(kModuleStr));
-  EXPECT_THAT(verifier().Run(module.get()).status().message(),
-              HasSubstr("Nested computations expects same computation's thread "
-                        "name: parallel_thread vs main, in called computation "
-                        "`add` vs caller computation `fused_computation`"));
+  EXPECT_THAT(
+      verifier().Run(module.get()).status().message(),
+      HasSubstr("crs0 top_apply computation execution thread does not match "
+                "(parallel_thread vs main)"));
 }
 
 TEST_F(HloVerifierTest, AllReduceVerifier) {
@@ -2639,8 +2804,8 @@
           .status();
   ASSERT_FALSE(status.ok());
   EXPECT_THAT(status.message(),
-              HasSubstr("expects parent computation thread name same as called "
-                        "computation's thread name"));
+              HasSubstr("custom top_apply computation execution thread does "
+                        "not match (parallel_thread vs main)"));
 }
 
 TEST_F(HloVerifierTest, CheckWhileThread) {
@@ -3133,6 +3298,49 @@
                         "memory space from device to host"));
 }
 
+TEST_F(HloVerifierTestLayoutSensitiveAndAllowMixedPrecision,
+       HostOffloadingCopyCannotChangeType) {
+  const char* const hlo_string = R"(
+HloModule m
+
+ENTRY main {
+  param = f32[1024,1024]{1,0:T(8,128)S(5)} parameter(0)
+  copy = bf16[1024,1024]{1,0:T(8,128)} copy(param)
+  ROOT dot = f32[1024,1024]{1,0:T(8,128)} dot(copy, copy), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.message(),
+              HasSubstr("Expected instruction to have shape equal to "
+                        "f32[1024,1024]{1,0:T(8,128)S(5)}, actual shape is "
+                        "bf16[1024,1024]{1,0:T(8,128)}"));
+}
+
+TEST_F(HloVerifierTestLayoutSensitiveAndAllowMixedPrecision,
+       HostOffloadingCopyCannotChangeLayout) {
+  const char* const hlo_string = R"(
+HloModule m
+
+ENTRY main {
+  param = f32[1024,1024]{1,0:T(8,128)S(5)} parameter(0)
+  ROOT copy = f32[1024,1024]{0,1:T(8,128)} copy(param)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.message(),
+              HasSubstr("Expected instruction to have shape equal to "
+                        "f32[1024,1024]{1,0:T(8,128)S(5)}, actual shape is "
+                        "f32[1024,1024]{0,1:T(8,128)}"));
+}
+
 TEST_F(HloVerifierTestLayoutSensitive,
        MismatchedMinorToMajorSizeAndDimensionSize) {
   const char* const hlo_string = R"(
diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc
index 5f7757b..dc59e5c 100644
--- a/third_party/xla/xla/service/latency_hiding_scheduler.cc
+++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc
@@ -486,6 +486,15 @@
       });
 }
 
+bool AsyncTracker::OccupiesSelectiveResource(const HloGraphNode* node) const {
+  return absl::c_any_of(
+      node->GetResources(), [&](const ResourcePair& resource) {
+        return resource.second == ResourceUsageType::kResourceOccupy &&
+               GetResourceHazardType(resource.first) ==
+                   ResourceHazardType::kSelective;
+      });
+}
+
 BufferInfoTracker::BufferInfoTracker(
     const HloModule* module, const HloAliasAnalysis* alias_analysis,
     const HloCostAnalysis::ShapeSizeFunction& shape_size_bytes) {
@@ -731,6 +740,25 @@
 
 namespace {
 
+// Finds the number of hops to the closest selective resource overlap in the
+// ready set that the provided node can be scheduled in between.
+int64_t GetNumHopsToClosestSelectiveOverlap(
+    const DefaultSchedulerCore::ReadyQueueSet& ready_set,
+    const HloGraphNode* node) {
+  int64_t num_hops_to_closest_selective_resource_occupier =
+      std::numeric_limits<int64_t>::max();
+  for (const HloGraphNode* n : ready_set) {
+    // Skip the node itself.
+    if (n == node) {
+      continue;
+    }
+    num_hops_to_closest_selective_resource_occupier =
+        std::min(num_hops_to_closest_selective_resource_occupier,
+                 n->GetNumHopsToClosestSelectiveResourceOccupier());
+  }
+  return num_hops_to_closest_selective_resource_occupier;
+}
+
 // Comparator for the ready set. This class represents the priority policies
 // for the nodes in the ready set. The policy can be whatever is appropriate to
 // reduce the execution time of the graph or achieve interesting properties
@@ -1002,6 +1030,31 @@
         return *value;
       }
     }
+    // If there are no selective overlaps open currently and there will be
+    // overlaps opened in the near future, hold off scheduling instructions
+    // that are valuable for selective overlaps.
+    if (sched_state_.config.enable_selective_resources &&
+        sched_state_.selective_resource_releasers.empty()) {
+      int64_t distance_to_selective_overlap_for_a =
+          GetNumHopsToClosestSelectiveOverlap(sched_state_.ready_set, a.node);
+      int64_t distance_to_selective_overlap_for_b =
+          GetNumHopsToClosestSelectiveOverlap(sched_state_.ready_set, b.node);
+      // If a is valuable for selective overlap and there is a selective
+      // overlap in the near future that a can be scheduled inside, hold off
+      // scheduling a and schedule b instead. The same logic applies in
+      // reverse.
+      int64_t max_distance =
+          sched_state_.config.max_hops_to_closest_selective_overlap;
+      if (auto value = DefaultSchedulerCore::ChooseBestCandidate(
+              (a.node->GetValuableForSelectiveOverlap() &&
+               distance_to_selective_overlap_for_a <= max_distance),
+              b,
+              (b.node->GetValuableForSelectiveOverlap() &&
+               distance_to_selective_overlap_for_b <= max_distance),
+              a, "kNotValuableForSelectiveOverlap")) {
+        return *value;
+      }
+    }
+
     if (sched_state_.config.aggressive_scheduling_policies) {
       // Favor nodes that unlock other nodes to be scheduled if possible.
       // This makes us more flexible in what we can use in scheduling.
@@ -1693,6 +1746,8 @@
             new_node_it->second->GetResources());
     new_node_it->second->releases_selective_resource_ =
         async_tracker->ReleasesSelectiveResource(new_node_it->second.get());
+    new_node_it->second->occupies_selective_resource_ =
+        async_tracker->OccupiesSelectiveResource(new_node_it->second.get());
     // Gather while instructions for subsequent send-done dependency checks.
     if (instr->opcode() == HloOpcode::kWhile) {
       while_instrs.push_back(instr);
@@ -1900,6 +1955,25 @@
   while (!stack.empty()) {
     auto* node = stack.back();
     stack.pop_back();
+    // If a node occupies a selective resource, it is the closest selective
+    // resource occupier to itself and is 0 hops away. Otherwise, the number of
+    // hops to the closest selective resource occupier is the minimum of that
+    // value over all predecessors, plus 1.
+    if (async_tracker->OccupiesSelectiveResource(node)) {
+      node->num_hops_to_closest_selective_resource_occupier_ = 0;
+    } else {
+      int64_t closest_predecessor_distance =
+          std::numeric_limits<int64_t>::max();
+      for (auto& pred : node->GetPredecessors()) {
+        closest_predecessor_distance = std::min(
+            closest_predecessor_distance,
+            pred.Target().num_hops_to_closest_selective_resource_occupier_);
+      }
+      if (closest_predecessor_distance != std::numeric_limits<int64_t>::max()) {
+        node->num_hops_to_closest_selective_resource_occupier_ =
+            closest_predecessor_distance + 1;
+      }
+    }
     if (async_tracker->IsSupportedAsyncDone(node->GetInstr())) {
       for (auto& pred : node->GetPredecessors()) {
         node->SetAsyncDepth(
diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.h b/third_party/xla/xla/service/latency_hiding_scheduler.h
index ebe1cf0..b0d8a8d 100644
--- a/third_party/xla/xla/service/latency_hiding_scheduler.h
+++ b/third_party/xla/xla/service/latency_hiding_scheduler.h
@@ -137,6 +137,7 @@
   bool resource_serializing = false;
   bool depth_based_memory_pressure_reduction = false;
   bool enable_selective_resources = false;
+  int64_t max_hops_to_closest_selective_overlap = 0;
   int64_t rerun = 0;
 };
 
@@ -284,6 +285,9 @@
   // Returns whether the provided node releases a selective resource.
   bool ReleasesSelectiveResource(const HloGraphNode* node) const;
 
+  // Returns whether the provided node occupies a selective resource.
+  bool OccupiesSelectiveResource(const HloGraphNode* node) const;
+
   inline CanonicalAsyncOp GetCanonicalAsyncOp(const HloInstruction& hlo) const {
     return get_canonical_async_op_(hlo);
   }
@@ -386,6 +390,17 @@
   bool ReleasesSelectiveResource() const {
     return releases_selective_resource_;
   }
+  bool OccupiesSelectiveResource() const {
+    return occupies_selective_resource_;
+  }
+  int64_t GetNumHopsToClosestSelectiveResourceOccupier() const {
+    return num_hops_to_closest_selective_resource_occupier_;
+  }
+  void SetNumHopsToClosestSelectiveResourceOccupier(
+      int64_t num_hops_to_closest_selective_resource_occupier) {
+    num_hops_to_closest_selective_resource_occupier_ =
+        num_hops_to_closest_selective_resource_occupier;
+  }
 
   ResourcesVector GetResources() const { return resources_; }
   bool DoesOccupyAnyResource() const {
@@ -525,6 +540,11 @@
   bool valuable_for_selective_overlap_ = true;
   // Whether this node releases a selective resource.
   bool releases_selective_resource_ = false;
+  // Whether this node occupies a selective resource.
+  bool occupies_selective_resource_ = false;
+  // Number of hops to the closest selective resource occupier.
+  int64_t num_hops_to_closest_selective_resource_occupier_ =
+      std::numeric_limits<int64_t>::max();
 };
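The new num_hops_to_closest_selective_resource_occupier_ field is filled in by the scheduler with a shortest-hop recurrence: a node that occupies a selective resource is 0 hops from itself, and every other node takes the minimum over its predecessors plus 1 (unreached nodes stay at int64 max). A stripped-down sketch of that recurrence on a toy DAG, using assumed names rather than HloGraphNode:

  #include <algorithm>
  #include <cstdint>
  #include <limits>
  #include <vector>

  struct ToyNode {
    bool occupies_selective_resource = false;
    std::vector<ToyNode*> predecessors;
    int64_t hops_to_occupier = std::numeric_limits<int64_t>::max();
  };

  // Nodes must be visited with all predecessors first (topological order),
  // which mirrors the traversal the scheduler already performs.
  void PropagateHops(std::vector<ToyNode*>& topo_ordered) {
    for (ToyNode* node : topo_ordered) {
      if (node->occupies_selective_resource) {
        node->hops_to_occupier = 0;
        continue;
      }
      int64_t closest = std::numeric_limits<int64_t>::max();
      for (ToyNode* pred : node->predecessors) {
        closest = std::min(closest, pred->hops_to_occupier);
      }
      if (closest != std::numeric_limits<int64_t>::max()) {
        node->hops_to_occupier = closest + 1;
      }
    }
  }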
 
 // Schedule graph that can be used to drive scheduling
@@ -920,7 +940,6 @@
   virtual absl::StatusOr<HloGraphNode*> FindAndExtractBestNodeAvailable(
       SchedulingState& sched_state,
       DefaultSchedulerCore::ShouldSkipNodeFunction should_skip_node);
-  bool DoesNodeReleaseSelectiveResource(const HloGraphNode* node) const;
   void DumpLatencyHidingSchedule(
       const HloComputation* computation, const HloScheduleGraph& schedule_graph,
       const std::vector<HloInstruction*>& instructions,
diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc
index 29a4f4b..8c9c290 100644
--- a/third_party/xla/xla/service/llvm_ir/ir_array.cc
+++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc
@@ -527,6 +527,7 @@
     if (!index.LinearValidOnShape(shape_)) {
       // Create a valid linear index.
       std::vector<int64_t> dimensions;
+      dimensions.reserve(shape_.rank());
       for (int64_t i = 0; i < shape_.rank(); ++i) {
         dimensions.push_back(shape_.dimensions(i));
       }
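This reserve call, like the one added in hlo_value_semantics_analysis.cc above, is the usual pattern of pre-sizing a vector when the element count is known before the fill loop, so the push_back calls never reallocate. A minimal standalone sketch with a toy helper (not XLA code):

  #include <cstdint>
  #include <vector>

  // Reserving up front means a single allocation instead of repeated growth
  // (and element moves) as the loop appends.
  std::vector<int64_t> CopyDimensions(const std::vector<int64_t>& dims) {
    std::vector<int64_t> out;
    out.reserve(dims.size());
    for (int64_t d : dims) {
      out.push_back(d);
    }
    return out;
  }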
diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.h b/third_party/xla/xla/service/llvm_ir/ir_array.h
index 9ec78b0..691f93f 100644
--- a/third_party/xla/xla/service/llvm_ir/ir_array.h
+++ b/third_party/xla/xla/service/llvm_ir/ir_array.h
@@ -250,9 +250,9 @@
   IrArray(llvm::Value* base_ptr, llvm::Type* pointee_type, Shape shape);
 
   // Default implementations of copying and moving.
-  IrArray(IrArray&& other) = default;
+  IrArray(IrArray&& other) noexcept = default;
   IrArray(const IrArray& other) = default;
-  IrArray& operator=(IrArray&& other) = default;
+  IrArray& operator=(IrArray&& other) noexcept = default;
   IrArray& operator=(const IrArray& other) = default;
 
   llvm::Value* GetBasePointer() const { return base_ptr_; }
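Marking the defaulted move operations noexcept above matters mostly for containers: std::vector relocates elements with std::move_if_noexcept, so it moves them during reallocation only when the move constructor is noexcept and falls back to copies otherwise. A small self-contained illustration with toy types (assumptions, not IrArray itself):

  #include <iostream>
  #include <type_traits>
  #include <vector>

  struct MoveNoexcept {
    MoveNoexcept() = default;
    MoveNoexcept(const MoveNoexcept&) = default;
    MoveNoexcept(MoveNoexcept&&) noexcept = default;
  };

  struct MayThrowOnMove {
    MayThrowOnMove() = default;
    MayThrowOnMove(const MayThrowOnMove&) = default;
    MayThrowOnMove(MayThrowOnMove&&) {}  // user-provided, not declared noexcept
  };

  static_assert(std::is_nothrow_move_constructible_v<MoveNoexcept>);
  static_assert(!std::is_nothrow_move_constructible_v<MayThrowOnMove>);

  int main() {
    // During reallocation, existing MoveNoexcept elements are moved, while
    // existing MayThrowOnMove elements are copied to preserve the strong
    // exception guarantee.
    std::vector<MoveNoexcept> fast(4);
    fast.push_back(MoveNoexcept{});
    std::vector<MayThrowOnMove> careful(4);
    careful.push_back(MayThrowOnMove{});
    std::cout << "done\n";
    return 0;
  }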
diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc
index 399c335..0ed7bac 100644
--- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc
+++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc
@@ -715,17 +715,6 @@
   return result;
 }
 
-static absl::Status CreateAndWriteStringToFile(
-    const std::string& directory_name, const std::string& file_name,
-    const std::string& text) {
-  std::unique_ptr<tsl::WritableFile> f;
-  TF_RETURN_IF_ERROR(tsl::Env::Default()->RecursivelyCreateDir(directory_name));
-  TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(file_name, &f));
-  TF_RETURN_IF_ERROR(f->Append(text));
-  TF_RETURN_IF_ERROR(f->Close());
-  return absl::OkStatus();
-}
-
 void DumpIrIfEnabled(const HloModule& hlo_module,
                      const llvm::Module& llvm_module, bool optimized,
                      absl::string_view filename_suffix) {
diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc
index 4611453..6737136 100644
--- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc
+++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc
@@ -1493,7 +1493,8 @@
       continue;
     }
 
-    if (interval.size > available_heap_size()) {
+    if (!options_.enable_window_prefetch &&
+        interval.size > available_heap_size()) {
       VLOG(3) << "Skip " << interval.buffer->ToShortString()
               << " because the buffer is larger than the heap size.";
       continue;
@@ -2152,6 +2153,12 @@
             options_.alternate_memory_space;
     VLOG(4) << "require_no_copy_alternate_mem_allocation = "
             << require_no_copy_alternate_mem_allocation;
+    if (require_no_copy_alternate_mem_allocation &&
+        allocation_value.size() > available_heap_size()) {
+      VLOG(3) << "Skip " << allocation_value.value()->ToShortString()
+              << " because the buffer is larger than the heap size.";
+      continue;
+    }
     if (!options_.is_position_allowed_in_alternate_mem_fn(
             allocation_value.defining_position())) {
       if (require_no_copy_alternate_mem_allocation) {
@@ -3018,8 +3025,12 @@
     const AllocationSequence& allocations, int64_t time) {
   for (auto allocation_it = allocations.rbegin();
        allocation_it != allocations.rend(); ++allocation_it) {
+    // The use case of GetLiveAllocationAt is to find the allocation that
+    // corresponds to the full buffer. Window prefetched allocations allocate
+    // only partial buffers, so we want to skip them.
     if ((*allocation_it)->start_time() <= time &&
-        (*allocation_it)->end_time() >= time) {
+        (*allocation_it)->end_time() >= time &&
+        !(*allocation_it)->is_window_prefetched_allocation()) {
       return allocation_it->get();
     }
   }
@@ -4197,6 +4208,11 @@
         << "Not trying to prefetch because use requires buffer in default mem.";
     (*prev_allocation_in_default_mem_it)->set_end_time(request.end_time);
     (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
+
+    // If the buffer is placed in default memory, we can also try window
+    // prefetching it, which will try to prefetch only a window's worth of
+    // data to alternate memory.
+    WindowPrefetch(request, **prev_allocation_in_default_mem_it);
     return Result::kSuccess;
   }
 
@@ -4286,9 +4302,28 @@
   // default memory.
   (*prev_allocation_in_default_mem_it)->set_end_time(request.end_time);
   (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
+
+  // If the buffer is placed in default memory, we can try window prefetching
+  // it, which will try to prefetch only a window's worth of data to alternate
+  // memory.
+  WindowPrefetch(request, **prev_allocation_in_default_mem_it);
   return allocation_result;
 }
 
+void MsaAlgorithm::AddAsyncCopyForWindowPrefetch(
+    Allocation& prev_allocation, HloUse use, const Chunk& chunk,
+    int64_t exclusive_start_time, int64_t inclusive_end_time,
+    AllocationSequence* allocations, AliasedOffset* aliased_offset,
+    float resource, const WindowPrefetchedAllocation::Options& options) {
+  allocations->push_back(std::make_unique<WindowPrefetchedAllocation>(
+      prev_allocation, use, chunk, exclusive_start_time, inclusive_end_time,
+      options));
+
+  RegisterAsyncCopy(MemorySpace::kAlternate, exclusive_start_time,
+                    inclusive_end_time, allocations, aliased_offset, resource,
+                    /*cross_program_prefetch_index=*/std::nullopt);
+}
+
 void MsaAlgorithm::AddAsyncCopy(
     Allocation& prev_allocation, MemorySpace memory_space,
     std::optional<Chunk> chunk, int64_t exclusive_start_time, int64_t end_time,
@@ -4306,6 +4341,16 @@
       prev_allocation, memory_space, chunk, exclusive_start_time,
       copy_done_schedule_before_time, end_time, cross_program_prefetch_index));
 
+  RegisterAsyncCopy(memory_space, exclusive_start_time,
+                    copy_done_schedule_before_time, allocations, aliased_offset,
+                    resource, cross_program_prefetch_index);
+}
+
+void MsaAlgorithm::RegisterAsyncCopy(
+    MemorySpace memory_space, int64_t exclusive_start_time,
+    int64_t copy_done_schedule_before_time, AllocationSequence* allocations,
+    AliasedOffset* aliased_offset, float resource,
+    std::optional<int> cross_program_prefetch_index) {
   // Register the additional async copy with the interval tree to keep track of
   // the limit at any given time.
   pending_async_copies_.push_back({exclusive_start_time,
@@ -4445,7 +4490,8 @@
     prev_allocation =
         request.allocation_value->allocation_sequence()->back().get();
     can_eliminate_copy =
-        (prev_allocation->memory_space() == MemorySpace::kAlternate);
+        (prev_allocation->memory_space() == MemorySpace::kAlternate &&
+         !prev_allocation->is_window_prefetched_allocation());
   }
 
   if (!can_eliminate_copy) {
@@ -4718,9 +4764,41 @@
 
 }  // namespace
 
-MsaAlgorithm::Result MsaAlgorithm::Prefetch(
+MsaAlgorithm::Result MsaAlgorithm::WindowPrefetch(
     const AllocationRequest& request,
     Allocation& prev_allocation_in_default_mem) {
+  if (!options_.enable_window_prefetch) {
+    return Result::kSuccess;
+  }
+
+  const HloUse use = request.use->hlo_use;
+  VLOG(3) << "Considering window prefetch for use=" << use.ToString();
+
+  // Get the window prefetch details for this use.
+  WindowPrefetchDetail details =
+      options_.window_prefetch_detail_fn(use.instruction);
+  for (const WindowPrefetchDetail::WindowDetail& window : details.windows()) {
+    if (window.operand() != use.operand_number) {
+      continue;
+    }
+
+    WindowPrefetchedAllocation::Options options;
+    options.bytes = window.size();
+    options.uid = window.uid();
+    options.alternate_memory_space = options_.alternate_memory_space;
+    options.notify_operand_appended_fn = options_.notify_operand_appended_fn;
+    AllocationRequest window_prefetch_request = request;
+    window_prefetch_request.window_prefetch_options = &options;
+    window_prefetch_request.size = window.size();
+    const Shape shape = ShapeUtil::MakeShape(U8, {window.size()});
+    Prefetch(window_prefetch_request, prev_allocation_in_default_mem, &shape);
+  }
+  return Result::kSuccess;
+}
+
+MsaAlgorithm::Result MsaAlgorithm::Prefetch(
+    const AllocationRequest& request,
+    Allocation& prev_allocation_in_default_mem, const Shape* shape) {
   // Try partially placing the buffer in the alternate space. The time that is
   // overlapped will be used to asynchronously copy the buffer from the
   // default memory to the alternate memory.
@@ -4743,6 +4821,10 @@
   PrefetchContext context;
   context.request = &request;
   context.prev_allocation_in_default_mem = &prev_allocation_in_default_mem;
+  // If the request carries window prefetch options, it originated from
+  // WindowPrefetch().
+  context.window_prefetch = (request.window_prefetch_options != nullptr);
+  CHECK(!context.window_prefetch || options_.enable_window_prefetch);
 
   // Create a SliceProposal and WorkingIntervals.
   SetupPrefetchWorkingIntervalsAndSliceProposal(context);
@@ -4757,8 +4839,13 @@
     return check_result;
   }
   const HloUse& use = request.use->hlo_use;
-  context.full_shape = &ShapeUtil::GetSubshape(
-      use.instruction->operand(use.operand_number)->shape(), use.operand_index);
+  if (shape != nullptr) {
+    context.full_shape = shape;
+  } else {
+    context.full_shape = &ShapeUtil::GetSubshape(
+        use.instruction->operand(use.operand_number)->shape(),
+        use.operand_index);
+  }
   // While uses might be allowed to have additional outstanding prefetches.
   context.extra_async_copy_limit =
       use.instruction->opcode() == HloOpcode::kWhile
@@ -4849,14 +4936,26 @@
             << context.unsliced_solution->prefetch_picker_debug_string;
     AddToPendingChunks(context.unsliced_solution_intervals.full,
                        context.unsliced_solution->chunk_candidate);
-    AddAsyncCopy(
-        *context.prev_allocation_in_default_mem, MemorySpace::kAlternate,
-        context.unsliced_solution->chunk_candidate,
-        context.unsliced_solution_intervals.full.start - 1,
-        context.request->end_time, context.prefetch_end_time,
-        context.request->allocation_value->mutable_allocation_sequence(),
-        context.request->preferred_offset,
-        context.unsliced_solution->prefetch_resource);
+    if (context.window_prefetch) {
+      AddAsyncCopyForWindowPrefetch(
+          *context.prev_allocation_in_default_mem, request.use->hlo_use,
+          context.unsliced_solution->chunk_candidate,
+          context.unsliced_solution_intervals.full.start - 1,
+          context.prefetch_end_time,
+          context.request->allocation_value->mutable_allocation_sequence(),
+          context.request->preferred_offset,
+          context.unsliced_solution->prefetch_resource,
+          *context.request->window_prefetch_options);
+    } else {
+      AddAsyncCopy(
+          *context.prev_allocation_in_default_mem, MemorySpace::kAlternate,
+          context.unsliced_solution->chunk_candidate,
+          context.unsliced_solution_intervals.full.start - 1,
+          context.request->end_time, context.prefetch_end_time,
+          context.request->allocation_value->mutable_allocation_sequence(),
+          context.request->preferred_offset,
+          context.unsliced_solution->prefetch_resource);
+    }
 
     request.allocation_value->allocation_sequence()->back()->AddUse(
         request.use->hlo_use);
@@ -4929,7 +5028,9 @@
       context.sliced_solution_intervals.full;
 
   // Attempt to generate a slice proposal.
-  GenerateSliceProposal(context);
+  if (!context.window_prefetch) {
+    GenerateSliceProposal(context);
+  }
 
   // Setup the full SlicedBufferIntervals for the sliced and unsliced solutions.
   // If there is no slice proposal, we will not try a sliced solution. In such a
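For orientation, the control flow introduced by the algorithm.cc hunks above can be summarized as follows (a reading aid, not code from this change; the names are the ones added in this diff):

// Reading aid only. WindowPrefetch() reuses Prefetch() with an explicit
// U8[window.size()] shape instead of the operand's full shape.
//
//   WindowPrefetch(request, prev_allocation_in_default_mem)
//     for each WindowDetail whose operand() matches the use:
//       Prefetch(window_request, prev_allocation_in_default_mem, &shape)
//         if context.window_prefetch:
//           AddAsyncCopyForWindowPrefetch(...)  // adds WindowPrefetchedAllocation
//             -> RegisterAsyncCopy(...)
//         else:
//           AddAsyncCopy(...)                   // adds a regular copy allocation
//             -> RegisterAsyncCopy(...)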
diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.h b/third_party/xla/xla/service/memory_space_assignment/algorithm.h
index 5e2073b..52d0f0e 100644
--- a/third_party/xla/xla/service/memory_space_assignment/algorithm.h
+++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.h
@@ -514,6 +514,10 @@
     absl::Span<const int64_t> all_use_times;
     // See the comment for require_copy_allocation
     HloInstruction* required_copy_allocation_for;
+    // Data structure that contains the options for making window prefetched
+    // allocations.
+    const WindowPrefetchedAllocation::Options* window_prefetch_options =
+        nullptr;
   };
 
   // This struct contains mandatory memory assignments at a given time. E.g., an
@@ -669,6 +673,11 @@
     // Data structures used to compute and store the unsliced solution.
     WorkingIntervals unsliced_solution_intervals;
     std::optional<UnslicedSolution> unsliced_solution;
+
+    // Indicates whether this is a window prefetch. A window prefetch brings in
+    // only a window worth of data and never uses sliced prefetch.
+    bool window_prefetch = false;
   };
 
   // Result of an allocation, prefetch, eviction etc. request.  The result is
@@ -860,7 +869,8 @@
 
   // Try prefetching to alternate memory space.
   Result Prefetch(const AllocationRequest& request,
-                  Allocation& prev_allocation_in_default_mem);
+                  Allocation& prev_allocation_in_default_mem,
+                  const Shape* shape = nullptr);
 
   // Helper methods used to implement Prefetch().
   //
@@ -888,6 +898,10 @@
   std::string AlternateMemoryAllocationAttemptToString(
       bool for_sliced_solution, const PrefetchContext& context) const;
 
+  // Try to prefetch a window worth of data into the alternate memory.
+  Result WindowPrefetch(const AllocationRequest& request,
+                        Allocation& prev_allocation_in_default_mem);
+
   // Find the best possible chunk candidate, where it has the longest possible
   // availability if no preferred offset is given, or at the preferred_offset if
   // it is given.
@@ -1014,6 +1028,14 @@
   void ImportRepackedSlicedAllocation(RepackAllocationBlock& block);
   absl::Status AreRepackedSlicesValid(const RepackAllocationBlock& block);
 
+  // Registers an asynchronous copy with the async-copy data structures so its
+  // state can be tracked.
+  void RegisterAsyncCopy(MemorySpace memory_space, int64_t exclusive_start_time,
+                         int64_t copy_done_schedule_before_time,
+                         AllocationSequence* allocations,
+                         AliasedOffset* aliased_offset, float resource,
+                         std::optional<int> cross_program_prefetch_index);
+
   // Adds an asynchronous copy to allocations.
   void AddAsyncCopy(
       Allocation& prev_allocation, MemorySpace memory_space,
@@ -1032,6 +1054,15 @@
       const std::vector<SliceDecision>& slice_decisions_sorted_by_start_time,
       int64_t prefetch_end_time, int64_t allocation_end_time);
 
+  // For window prefetching, adds a WindowPrefetchedAllocation to allocations.
+  // Also updates asynchronous copy data structures, prefetch_interval_tree_,
+  // and aliasing data structures.
+  void AddAsyncCopyForWindowPrefetch(
+      Allocation& prev_allocation, HloUse use, const Chunk& chunk,
+      int64_t exclusive_start_time, int64_t inclusive_end_time,
+      AllocationSequence* allocations, AliasedOffset* aliased_offset,
+      float resource, const WindowPrefetchedAllocation::Options& options);
+
   // This method is used for committing the chunk candidate but adding it to
   // pending_chunks_ so that we can "uncommit" them in case we need to roll back
   // this allocation sequence.
diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.cc b/third_party/xla/xla/service/memory_space_assignment/allocation.cc
index 50bec57..8699aab 100644
--- a/third_party/xla/xla/service/memory_space_assignment/allocation.cc
+++ b/third_party/xla/xla/service/memory_space_assignment/allocation.cc
@@ -37,6 +37,8 @@
 #include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
 #include "xla/service/heap_simulator/allocation_block.h"
 #include "xla/service/heap_simulator/heap_simulator.h"
 #include "xla/service/hlo_value.h"
@@ -854,6 +856,133 @@
   return casted_other != nullptr && (*this) == (*casted_other);
 }
 
+WindowPrefetchedAllocation::WindowPrefetchedAllocation(
+    Allocation& prev_allocation, HloUse use, const HeapSimulator::Chunk& chunk,
+    int64_t prefetch_start_schedule_after_time,
+    int64_t prefetch_done_schedule_before_time, const Options& options)
+    : Allocation(
+          {nullptr, {}}, MemorySpace::kAlternate, chunk,
+          ExclusiveToInclusiveStartTime(prefetch_start_schedule_after_time),
+          InclusiveToExclusiveEndTime(prefetch_done_schedule_before_time),
+          /*is_scoped_allocation=*/false,
+          /*cross_program_prefetch_index=*/std::nullopt),
+      options_(options),
+      prev_allocation_(prev_allocation),
+      use_(use),
+      prefetch_start_schedule_after_(prefetch_start_schedule_after_time),
+      prefetch_done_schedule_before_(prefetch_done_schedule_before_time),
+      bytes_(chunk.size) {}
+
+HloPosition WindowPrefetchedAllocation::defining_position() const {
+  HloPosition defining_position = original_defining_position();
+  if (defining_position.instruction == nullptr) {
+    return prev_allocation_.defining_position();
+  }
+  return defining_position;
+}
+
+int64_t WindowPrefetchedAllocation::earliest_available_time() const {
+  return prefetch_done_schedule_before_;
+}
+
+absl::Status WindowPrefetchedAllocation::InsertWindowPrefetchInstruction(
+    HloInstruction* producing_instruction, HloInstruction* use_instruction,
+    HloComputation* computation) {
+  // Derive the shape of the window buffer.
+  Shape shape = ShapeUtil::MakeShape(U8, {options_.bytes});
+  Layout layout = LayoutUtil::MakeLayout({0});
+  layout.set_memory_space(options_.alternate_memory_space);
+  *shape.mutable_layout() = layout;
+
+  // Insert a new parameter in the fused computation.
+  HloComputation* fused_computation =
+      use_instruction->fused_instructions_computation();
+  const int64_t num_parameters = fused_computation->num_parameters();
+  std::string name = absl::StrCat("window-buffer.", num_parameters);
+  HloInstruction* param = fused_computation->AddParameter(
+      HloInstruction::CreateParameter(num_parameters, shape, name));
+
+  // Insert async WindowPrefetch instructions as operands to the fusion.
+  HloInstruction* prefetch =
+      computation->AddInstruction(HloInstruction::CreateCustomCall(
+          shape, {producing_instruction}, "WindowPrefetch"));
+  TF_ASSIGN_OR_RETURN(prefetch_instruction_,
+                      computation->CreateAsyncInstructions(prefetch, {}));
+  use_instruction->AppendOperand(prefetch_instruction_);
+
+  // Insert an instruction that consumes the added operands and forwards the
+  // original fusion output.
+  auto get_or_create_consumer =
+      [](HloComputation* computation) -> HloInstruction* {
+    HloInstruction* root = computation->root_instruction();
+    // If the root is already a WindowPrefetchBuffer, we don't need to create
+    // a new one.
+    if (root->IsCustomCall("WindowPrefetchBuffer")) {
+      return root;
+    }
+    HloInstruction* new_root =
+        computation->AddInstruction(HloInstruction::CreateCustomCall(
+            root->shape(), {root}, "WindowPrefetchBuffer"));
+    computation->set_root_instruction(new_root);
+    return new_root;
+  };
+  HloInstruction* consumer = get_or_create_consumer(fused_computation);
+  consumer->AppendOperand(param);
+  return absl::OkStatus();
+}
+
+absl::Status WindowPrefetchedAllocation::Process() {
+  HloInstruction* producing_instruction = AddGetTupleElements();
+  HloComputation* computation = producing_instruction->parent();
+  HloInstruction* use_instruction = use_.instruction;
+  CHECK_EQ(use_instruction->opcode(), HloOpcode::kFusion);
+
+  TF_RETURN_IF_ERROR(InsertWindowPrefetchInstruction(
+      producing_instruction, use_instruction, computation));
+
+  // Notify the backend that an operand has been appended as a window prefetch
+  // buffer.
+  int64_t use_operand = use_instruction->operand_count() - 1;
+  options_.notify_operand_appended_fn(use_instruction, options_.uid,
+                                      use_operand);
+
+  // Set the original defining position to the window prefetch instruction.
+  set_original_defining_position(HloPosition{prefetch_instruction_, {}});
+  AddUse(HloUse{use_instruction, use_operand});
+  return absl::OkStatus();
+}
+
+void WindowPrefetchedAllocation::MarkIfNeeded(
+    absl::flat_hash_set<const Allocation*>& needed_allocations) const {
+  MarkNeeded(needed_allocations);
+}
+
+void WindowPrefetchedAllocation::MarkNeeded(
+    absl::flat_hash_set<const Allocation*>& needed_allocations) const {
+  needed_allocations.insert(this);
+  prev_allocation_.MarkNeeded(needed_allocations);
+}
+
+std::string WindowPrefetchedAllocation::ToString() const {
+  return absl::StrCat("WindowPrefetched Allocation");
+}
+
+bool WindowPrefetchedAllocation::operator==(
+    const WindowPrefetchedAllocation& other) const {
+  return this->base_is_equal(static_cast<const Allocation&>(other)) &&
+         prefetch_done_schedule_before() ==
+             other.prefetch_done_schedule_before() &&
+         prefetch_start_schedule_after() ==
+             other.prefetch_start_schedule_after() &&
+         prefetch() == other.prefetch() && bytes_ == other.bytes_;
+}
+
+bool WindowPrefetchedAllocation::operator==(const Allocation& other) const {
+  const WindowPrefetchedAllocation* casted_other =
+      dynamic_cast<const WindowPrefetchedAllocation*>(&other);
+  return casted_other != nullptr && (*this) == (*casted_other);
+}
+
 std::tuple<int64_t, bool, int64_t> GetAllocationSortTuple(
     const std::unique_ptr<Allocation>& allocation) {
   int64_t scheduled_on_or_before = allocation->start_time();
diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.h b/third_party/xla/xla/service/memory_space_assignment/allocation.h
index d0a4d72..bb3b324 100644
--- a/third_party/xla/xla/service/memory_space_assignment/allocation.h
+++ b/third_party/xla/xla/service/memory_space_assignment/allocation.h
@@ -18,6 +18,7 @@
 
 #include <algorithm>
 #include <cstdint>
+#include <functional>
 #include <memory>
 #include <optional>
 #include <string>
@@ -130,6 +131,7 @@
   virtual bool is_pinned_allocation() const = 0;
   virtual bool is_copy_allocation() const = 0;
   virtual bool is_sliced_copy_allocation() const = 0;
+  virtual bool is_window_prefetched_allocation() const = 0;
   // True if the allocation is for a copy or a sliced-copy.
   bool is_copy_like_allocation() const;
 
@@ -211,6 +213,7 @@
   bool is_pinned_allocation() const override { return true; }
   bool is_copy_allocation() const override { return false; }
   bool is_sliced_copy_allocation() const override { return false; }
+  bool is_window_prefetched_allocation() const override { return false; }
   absl::Status Process() override;
   absl::Status PostProcess() override { return absl::OkStatus(); }
   void MarkIfNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
@@ -249,6 +252,7 @@
   bool is_pinned_allocation() const override { return false; }
   bool is_copy_allocation() const override { return true; }
   bool is_sliced_copy_allocation() const override { return false; }
+  bool is_window_prefetched_allocation() const override { return false; }
   absl::Status Process() override;
   absl::Status PostProcess() override { return absl::OkStatus(); }
   void MarkIfNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
@@ -350,6 +354,7 @@
   bool is_pinned_allocation() const override { return false; }
   bool is_copy_allocation() const override { return false; }
   bool is_sliced_copy_allocation() const override { return true; }
+  bool is_window_prefetched_allocation() const override { return false; }
   // MemorySpaceAssignment::Process() calls Process() to create asynchronous
   // slice copies, and a bitcast-concat call to glue the slices back together.
   absl::Status Process() override;
@@ -393,6 +398,75 @@
   absl::FunctionRef<Shape(const Shape&)> get_equivalent_s8_shape_fn_;
 };
 
+// This class represents an allocation resulting from asynchronously prefetching
+// a window buffer. When a tensor is placed in the default memory, we can
+// prefetch the window buffer of the tensor to the alternate memory space. This
+// is called window prefetching.
+class WindowPrefetchedAllocation final : public Allocation {
+ public:
+  struct Options {
+    int64_t bytes = 0;
+    int64_t uid = 0;
+    int64_t alternate_memory_space = 0;
+    std::function<void(HloInstruction*, int64_t, int64_t)>
+        notify_operand_appended_fn =
+            [](const HloInstruction*, int64_t, int64_t) {};
+  };
+
+  WindowPrefetchedAllocation(Allocation& prev_allocation, HloUse use,
+                             const HeapSimulator::Chunk& chunk,
+                             int64_t prefetch_start_schedule_after_time,
+                             int64_t prefetch_done_schedule_before_time,
+                             const Options& options);
+
+  // Overridden methods
+  //
+  HloPosition defining_position() const override;
+  int64_t earliest_available_time() const override;
+  bool is_pinned_allocation() const override { return false; }
+  bool is_copy_allocation() const override { return false; }
+  bool is_sliced_copy_allocation() const override { return false; }
+  bool is_window_prefetched_allocation() const override { return true; }
+  // MemorySpaceAssignment::Process() calls Process() to create asynchronous
+  // window prefetches.
+  absl::Status Process() override;
+  absl::Status PostProcess() override { return absl::OkStatus(); }
+  // Marks the allocation as needed.
+  void MarkIfNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
+      const override;
+  void MarkNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
+      const override;
+  std::string ToString() const override;
+  bool operator==(const WindowPrefetchedAllocation& other) const;
+  bool operator==(const Allocation& other) const override;
+  int64_t bytes() const { return bytes_; }
+  int64_t prefetch_start_schedule_after() const {
+    return prefetch_start_schedule_after_;
+  }
+  int64_t prefetch_done_schedule_before() const {
+    return prefetch_done_schedule_before_;
+  }
+  HloInstruction* prefetch() const { return prefetch_instruction_; }
+
+ private:
+  // This method is called by Process() to create window prefetch instructions:
+  // an async WindowPrefetch start/done pair outside the fusion and a
+  // WindowPrefetchBuffer inside the fusion. The WindowPrefetchBuffer consumes
+  // the appended window buffer operands.
+  absl::Status InsertWindowPrefetchInstruction(
+      HloInstruction* producing_instruction, HloInstruction* use_instruction,
+      HloComputation* computation);
+
+  Options options_;
+  HloInstruction* prefetch_instruction_ = nullptr;
+  Allocation& prev_allocation_;
+  HloUse use_;
+  int64_t prefetch_start_schedule_after_;
+  int64_t prefetch_done_schedule_before_;
+  int64_t bytes_;
+};
+
 // An allocation in the default memory space that mirrors another Allocation
 // object. This is useful to model an eviction that happens before a while op
 // so that we don't need to redundantly evict the buffer after the while op as
@@ -409,6 +483,7 @@
   bool is_pinned_allocation() const override { return false; }
   bool is_copy_allocation() const override { return false; }
   bool is_sliced_copy_allocation() const override { return false; }
+  bool is_window_prefetched_allocation() const override { return false; }
   absl::Status Process() override;
   absl::Status PostProcess() override { return absl::OkStatus(); }
   void MarkIfNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
@@ -442,6 +517,7 @@
   bool is_pinned_allocation() const override { return false; }
   bool is_copy_allocation() const override { return false; }
   bool is_sliced_copy_allocation() const override { return false; }
+  bool is_window_prefetched_allocation() const override { return false; }
   absl::Status Process() override;
   absl::Status PostProcess() override;
   void MarkIfNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
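As a rough orientation for the new class above, a minimal construction sketch (illustrative assumptions only: `prev` is an existing Allocation in default memory, `use` is the HloUse being window prefetched, and the chunk/time values would come from the surrounding MSA machinery):

// Sketch under the assumptions stated above; values are illustrative.
WindowPrefetchedAllocation::Options opts;
opts.bytes = 128;                 // illustrative window size in bytes
opts.uid = 0;
opts.alternate_memory_space = 1;  // illustrative memory space id
auto allocation = std::make_unique<WindowPrefetchedAllocation>(
    prev, use, HeapSimulator::Chunk::FromOffsetSize(/*offset=*/0, opts.bytes),
    /*prefetch_start_schedule_after_time=*/10,
    /*prefetch_done_schedule_before_time=*/20, opts);
CHECK(allocation->is_window_prefetched_allocation());
CHECK_EQ(allocation->bytes(), 128);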
diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto
index 77faa69..e15d564 100644
--- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto
+++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto
@@ -46,6 +46,21 @@
   uint64 preferred_slice_size = 5;
 }
 
+// Memory space assignment details for prefetching windows of data.
+message WindowPrefetchDetail {
+  message WindowDetail {
+    // Index of the operand that is window prefetched.
+    int64 operand = 1;
+    // Window buffer size in bytes.
+    int64 size = 2;
+    // Unique identifier to distinguish the buffers that are associated with the
+    // same operand.
+    int64 uid = 3;
+  }
+
+  repeated WindowDetail windows = 1;
+}
+
 // Options for memory-bound loop optimizations in memory space assignment. If
 // enabled, this pass can optimize memory-bound unrolled loops to maximize the
 // bandwidth utilized and minimize the execution time.
diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc
index 3fb558b..40f5e1d 100644
--- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc
+++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc
@@ -8282,6 +8282,65 @@
   }
 }
 
+TEST_F(MemorySpaceAssignmentTest, WindowPrefetch) {
+  absl::string_view hlo_string = R"(
+HloModule module, is_scheduled=true
+
+%fused_computation {
+  %p0 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(0)
+  %p1 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(1)
+  %p2 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(2)
+  %add0 = bf16[64,8]{1,0:T(8,128)(2,1)} add(%p0, %p1)
+  ROOT %add1 = bf16[64,8]{1,0:T(8,128)(2,1)} add(%add0, %p2)
+}
+
+entry {
+  %p0 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(0)
+  %p1 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(1)
+  %p2 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(2)
+  ROOT fusion = bf16[64,8]{1,0:T(8,128)(2,1)} fusion(bf16[64,8]{1,0:T(8,128)(2,1)} %p0, bf16[64,8]{1,0:T(8,128)(2,1)} %p1, bf16[64,8]{1,0:T(8,128)(2,1)} %p2), kind=kLoop, calls=%fused_computation
+}
+
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  // Get info about window prefetch buffers, such as which operands they
+  // correspond to and their sizes.
+  auto window_prefetch_detail_fn = [&](const HloInstruction* instruction) {
+    WindowPrefetchDetail window_prefetch_detail;
+    const HloInstruction* fusion = FindInstruction(module.get(), "fusion");
+    if (instruction == fusion) {
+      for (int i = 0; i < 3; ++i) {
+        auto* operand = window_prefetch_detail.add_windows();
+        operand->set_operand(i);
+        operand->set_size(32);
+      }
+    }
+    return window_prefetch_detail;
+  };
+
+  Options options = DefaultMemorySpaceOptions();
+  options.enable_window_prefetch = true;
+  options.window_prefetch_detail_fn = window_prefetch_detail_fn;
+  AssignMemorySpace(module.get(), options, /*max_prefetch_interval=*/10,
+                    /*min_prefetch_interval=*/0);
+  const HloInstruction* fusion = FindInstruction(module.get(), "fusion");
+  // The fusion instruction should have 5 operands: the 3 original operands
+  // plus 2 window prefetch buffers.
+  EXPECT_EQ(fusion->operand_count(), 5);
+
+  // The root of the fusion should be a WindowPrefetchBuffer. The first operand
+  // should be the original root, and the second and third operands should be
+  // the window prefetch buffers.
+  HloInstruction* root = fusion->fused_expression_root();
+  EXPECT_TRUE(root->IsCustomCall("WindowPrefetchBuffer"));
+  EXPECT_EQ(root->operand_count(), 3);
+  EXPECT_EQ(root->operand(1), fusion->fused_parameter(3));
+  EXPECT_EQ(root->operand(2), fusion->fused_parameter(4));
+  VLOG(2) << "module: " << module->ToString();
+}
+
 using AsynchronousCopyOrderingTest = ::testing::Test;
 
 TEST_F(AsynchronousCopyOrderingTest, Simple) {
diff --git a/third_party/xla/xla/service/memory_space_assignment/options.h b/third_party/xla/xla/service/memory_space_assignment/options.h
index 3a1d848..fb9730c 100644
--- a/third_party/xla/xla/service/memory_space_assignment/options.h
+++ b/third_party/xla/xla/service/memory_space_assignment/options.h
@@ -57,6 +57,10 @@
     const absl::flat_hash_set<ShapeIndex>& /*outputs_in_alternate_memory*/)>;
 using PositionRequiresContiguousAllocationFunction =
     std::function<bool(const HloPosition&)>;
+using WindowPrefetchDetailFunction =
+    std::function<WindowPrefetchDetail(const HloInstruction*)>;
+using WindowPrefetchNotifyOperandAppendedFunction =
+    std::function<void(HloInstruction*, int64_t, int64_t)>;
 
 // The different options to be passed to the Run() API.
 struct Options {
@@ -111,6 +115,15 @@
       position_requires_contiguous_allocation_fn =
           [](const HloPosition&) { return false; };
 
+  // This function is called to get details about window prefetches.
+  WindowPrefetchDetailFunction window_prefetch_detail_fn =
+      [](const HloInstruction*) { return WindowPrefetchDetail(); };
+
+  // This function is called to notify that an operand has been appended as a
+  // window prefetch buffer.
+  WindowPrefetchNotifyOperandAppendedFunction notify_operand_appended_fn =
+      [](HloInstruction*, int64_t, int64_t) {};
+
   // If true, we will try to reduce scoped allocation buffer size for all
   // instructions if their operand/output has been allocated in alternate
   // memory.
@@ -234,6 +247,13 @@
   // Option to always spill buffers from alternate memory to default memory
   // and prefetching back to alternate memory(if needed) just in time for use.
   bool always_spill_to_default_memory = false;
+
+  // If true, enables window prefetching. Window prefetching prefetches a
+  // window of data into the alternate memory before the first use of a
+  // buffer. This allows even buffers that are too large to fit in the
+  // alternate memory to be partially prefetched, and gives MSA more
+  // flexibility in choosing the prefetch time and how much data to prefetch.
+  bool enable_window_prefetch = false;
 };
 }  // namespace memory_space_assignment
 }  // namespace xla
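A hedged sketch of how a caller might wire up the new options; the lambda bodies and the 64-byte window size are illustrative assumptions, while the Options fields and proto accessors are the ones introduced in this change:

Options options;
options.enable_window_prefetch = true;
options.window_prefetch_detail_fn = [](const HloInstruction* instr) {
  WindowPrefetchDetail detail;
  if (instr->opcode() == HloOpcode::kFusion) {
    // Illustrative: request a 64-byte window for operand 0 of every fusion.
    WindowPrefetchDetail::WindowDetail* window = detail.add_windows();
    window->set_operand(0);
    window->set_size(64);
    window->set_uid(0);
  }
  return detail;
};
options.notify_operand_appended_fn = [](HloInstruction* fusion, int64_t uid,
                                        int64_t operand_index) {
  // A real backend would record that `operand_index` of `fusion` is the window
  // buffer identified by `uid`; this sketch only logs it.
  VLOG(1) << "Window buffer " << uid << " appended as operand "
          << operand_index << " of " << fusion->name();
};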
diff --git a/third_party/xla/xla/service/pattern_matcher.h b/third_party/xla/xla/service/pattern_matcher.h
index 9b5a953..76979f0 100644
--- a/third_party/xla/xla/service/pattern_matcher.h
+++ b/third_party/xla/xla/service/pattern_matcher.h
@@ -2682,6 +2682,7 @@
 XLA_UNOP_PATTERN(Bitcast)
 XLA_UNOP_PATTERN(BitcastConvert)
 XLA_UNOP_PATTERN(Broadcast)
+XLA_UNOP_PATTERN(Cbrt)
 XLA_UNOP_PATTERN(Ceil)
 XLA_UNOP_PATTERN(Convert)
 XLA_UNOP_PATTERN(Copy)
@@ -2695,6 +2696,7 @@
 XLA_UNOP_PATTERN(CollectivePermuteStart)
 XLA_UNOP_PATTERN(CollectivePermuteDone)
 XLA_UNOP_PATTERN(Domain)
+XLA_UNOP_PATTERN(Erf)
 XLA_UNOP_PATTERN(Exp)
 XLA_UNOP_PATTERN(Expm1)
 XLA_UNOP_PATTERN(Fft)
@@ -2704,6 +2706,7 @@
 XLA_UNOP_PATTERN(Infeed)
 XLA_UNOP_PATTERN(IsFinite)
 XLA_UNOP_PATTERN(Log)
+XLA_UNOP_PATTERN(Logistic)
 XLA_UNOP_PATTERN(Not)
 XLA_UNOP_PATTERN(Negate)
 XLA_UNOP_PATTERN(OptimizationBarrier)
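The newly exposed unary matchers (Cbrt, Erf, Logistic) are used like the existing ones; a brief sketch, where `instr` is an assumed HloInstruction* obtained elsewhere:

namespace m = ::xla::match;
HloInstruction* operand = nullptr;
if (Match(instr, m::Logistic(m::Op(&operand))) ||
    Match(instr, m::Erf(m::Op(&operand))) ||
    Match(instr, m::Cbrt(m::Op(&operand)))) {
  // Matched one of the newly added unary patterns.
  VLOG(2) << "Matched unary op with operand " << operand->name();
}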
diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc
index 53ff557..bf375e8 100644
--- a/third_party/xla/xla/service/shape_inference.cc
+++ b/third_party/xla/xla/service/shape_inference.cc
@@ -3794,6 +3794,7 @@
 static absl::Status ValidateGatherDimensionNumbers(
     const Shape& input_shape, absl::Span<const int64_t> start_indices_shape,
     const GatherDimensionNumbers& dim_numbers) {
+  // Validate offset_dims in GatherDimensionNumbers.
   if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
     return InvalidArgument(
         "Output window dimensions in gather op must be ascending; got: %s.",
@@ -3834,6 +3835,7 @@
         start_indices_shape[dim_numbers.index_vector_dim()]);
   }
 
+  // Validate start_index_map in GatherDimensionNumbers.
   for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
     int64_t operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
     if (operand_dim_for_start_index_i < 0 ||
@@ -3858,6 +3860,7 @@
         StrJoin(dim_numbers.start_index_map(), ", "));
   }
 
+  // Validate collapsed_slice_dims in GatherDimensionNumbers.
   for (int64_t collapsed_dim : dim_numbers.collapsed_slice_dims()) {
     if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
       return InvalidArgument(
@@ -3881,6 +3884,69 @@
         StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
   }
 
+  // Validate operand_batching_dims and start_indices_batching_dims are of the
+  // same size.
+  if (dim_numbers.operand_batching_dims_size() !=
+      dim_numbers.start_indices_batching_dims_size()) {
+    return InvalidArgument(
+        "operand_batching_dims and start_indices_batching_dims in gather op "
+        "must be of the same size; got: %d and %d.",
+        dim_numbers.operand_batching_dims_size(),
+        dim_numbers.start_indices_batching_dims_size());
+  }
+
+  // Validate operand_batching_dims in GatherDimensionNumbers.
+  for (int64_t operand_batching_dim : dim_numbers.operand_batching_dims()) {
+    if (operand_batching_dim < 0 ||
+        operand_batching_dim >= input_shape.dimensions_size()) {
+      return InvalidArgument(
+          "Invalid operand_batching_dims set in gather op; valid range is [0, "
+          "%d), got: %d.",
+          input_shape.dimensions_size(), operand_batching_dim);
+    }
+  }
+
+  if (!absl::c_is_sorted(dim_numbers.operand_batching_dims())) {
+    return InvalidArgument(
+        "operand_batching_dims in gather op must be sorted; got: %s",
+        StrJoin(dim_numbers.operand_batching_dims(), ", "));
+  }
+
+  if (absl::c_adjacent_find(dim_numbers.operand_batching_dims()) !=
+      dim_numbers.operand_batching_dims().end()) {
+    return InvalidArgument(
+        "Repeated dimensions not allowed in operand_batching_dims in gather "
+        "op; "
+        "got: %s.",
+        StrJoin(dim_numbers.operand_batching_dims(), ", "));
+  }
+
+  // Validate start_indices_batching_dims in GatherDimensionNumbers.
+  for (int i = 0; i < dim_numbers.start_indices_batching_dims_size(); i++) {
+    int64_t start_indices_batching_dim_i =
+        dim_numbers.start_indices_batching_dims(i);
+    if (start_indices_batching_dim_i < 0 ||
+        start_indices_batching_dim_i >= start_indices_shape.size()) {
+      return InvalidArgument(
+          "Invalid start_indices_batching_dims; domain is [0, %d), got: "
+          "%d->%d.",
+          start_indices_shape.size(), i, start_indices_batching_dim_i);
+    }
+  }
+
+  std::vector<int64_t> sorted_start_indices_batching_dims(
+      dim_numbers.start_indices_batching_dims().begin(),
+      dim_numbers.start_indices_batching_dims().end());
+
+  absl::c_sort(sorted_start_indices_batching_dims);
+
+  if (absl::c_adjacent_find(sorted_start_indices_batching_dims) !=
+      sorted_start_indices_batching_dims.end()) {
+    return InvalidArgument(
+        "Repeated dimensions are not allowed in start_indices_batching_dims; "
+        "got: %s.",
+        StrJoin(dim_numbers.start_indices_batching_dims(), ", "));
+  }
   return absl::OkStatus();
 }
 
@@ -3943,13 +4009,16 @@
 
   if (slice_sizes.size() !=
       gather_dim_numbers.offset_dims_size() +
-          gather_dim_numbers.collapsed_slice_dims_size()) {
+          gather_dim_numbers.collapsed_slice_dims_size() +
+          gather_dim_numbers.operand_batching_dims_size()) {
     return InvalidArgument(
         "All components of the offset index in a gather op must either be a "
-        "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
-        "output_slice_sizes=%s, collapsed_slice_dims=%s.",
+        "offset dimension or explicitly collapsed or explicitly batched; got "
+        "len(slice_sizes)=%lu, output_slice_sizes=%s, collapsed_slice_dims=%s, "
+        "operand_batching_dims=%s.",
         slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
-        StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
+        StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","),
+        StrJoin(gather_dim_numbers.operand_batching_dims(), ","));
   }
 
   for (int i = 0; i < slice_sizes.size(); i++) {
@@ -3974,6 +4043,16 @@
     }
   }
 
+  for (int i = 0; i < gather_dim_numbers.operand_batching_dims_size(); i++) {
+    if (slice_sizes[gather_dim_numbers.operand_batching_dims(i)] > 1) {
+      return InvalidArgument(
+          "Gather op can only have operand_batching_dims with bound 1 or 0, "
+          "but bound is %d for index %d at position %d.",
+          slice_sizes[gather_dim_numbers.operand_batching_dims(i)],
+          gather_dim_numbers.operand_batching_dims(i), i);
+    }
+  }
+
   int64_t result_rank = gather_dim_numbers.offset_dims_size() +
                         (expanded_start_indices_shape.size() - 1);
   int64_t offset_dims_seen = 0;
@@ -3993,6 +4072,10 @@
                                    offset_dims_seen)) {
         offset_dims_seen++;
       }
+      while (absl::c_binary_search(gather_dim_numbers.operand_batching_dims(),
+                                   offset_dims_seen)) {
+        offset_dims_seen++;
+      }
       // Gathering an entire dynamic dimension creates dynamic dimension.
       //
       // e.g.,:
@@ -4075,7 +4158,8 @@
 
   // Validate window size.
   auto window_size = dim_numbers.update_window_dims_size() +
-                     dim_numbers.inserted_window_dims_size();
+                     dim_numbers.inserted_window_dims_size() +
+                     dim_numbers.input_batching_dims_size();
   if (window_size != operand_shape.rank()) {
     return InvalidArgument(
         "Scatter op has window of size %d; doesn't match operand of rank %d.",
@@ -4117,6 +4201,61 @@
         StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
   }
 
+  // Validate input_batching_dims and scatter_indices_batching_dims in
+  // ScatterDimensionNumbers.
+  if (dim_numbers.input_batching_dims_size() !=
+      dim_numbers.scatter_indices_batching_dims_size()) {
+    return InvalidArgument(
+        "input_batching_dims and scatter_indices_batching_dims in scatter op "
+        "must be of the same size; got: %d and %d.",
+        dim_numbers.input_batching_dims_size(),
+        dim_numbers.scatter_indices_batching_dims_size());
+  }
+
+  // Validate input_batching_dims in ScatterDimensionNumbers.
+  if (!absl::c_is_sorted(dim_numbers.input_batching_dims())) {
+    return InvalidArgument(
+        "input_batching_dims in scatter op must be sorted; got: %s.",
+        StrJoin(dim_numbers.input_batching_dims(), ", "));
+  }
+  if (absl::c_adjacent_find(dim_numbers.input_batching_dims()) !=
+      dim_numbers.input_batching_dims().end()) {
+    return InvalidArgument(
+        "input_batching_dims in scatter op must not repeat; got: %s.",
+        StrJoin(dim_numbers.input_batching_dims(), ", "));
+  }
+  for (int64_t input_batching_dim : dim_numbers.input_batching_dims()) {
+    if (input_batching_dim < 0 ||
+        input_batching_dim >= operand_shape.dimensions_size()) {
+      return InvalidArgument(
+          "Invalid input_batching_dims set in scatter op; valid range is [0, "
+          "%d), got: %d.",
+          operand_shape.dimensions_size(), input_batching_dim);
+    }
+  }
+
+  // Validate scatter_indices_batching_dims in ScatterDimensionNumbers.
+  for (int64_t scatter_indices_batching_dim :
+       dim_numbers.scatter_indices_batching_dims()) {
+    if (scatter_indices_batching_dim < 0 ||
+        scatter_indices_batching_dim >= scatter_indices_shape.size()) {
+      return InvalidArgument(
+          "Invalid scatter_indices_batching_dims set in scatter op; valid "
+          "range is [0, %d), got: %d.",
+          scatter_indices_shape.size(), scatter_indices_batching_dim);
+    }
+  }
+  std::vector<int64_t> sorted_scatter_indices_batching_dims(
+      dim_numbers.scatter_indices_batching_dims().begin(),
+      dim_numbers.scatter_indices_batching_dims().end());
+  absl::c_sort(sorted_scatter_indices_batching_dims);
+  if (absl::c_adjacent_find(sorted_scatter_indices_batching_dims) !=
+      sorted_scatter_indices_batching_dims.end()) {
+    return InvalidArgument(
+        "scatter_indices_batching_dims in scatter op must not repeat; got: %s.",
+        StrJoin(dim_numbers.scatter_indices_batching_dims(), ", "));
+  }
+
   return absl::OkStatus();
 }
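To make the new gather batching-dimension checks concrete, here is an illustrative GatherDimensionNumbers configuration that the added validation accepts; the shapes and dimension choices are assumptions, not taken from this change:

// Illustrative only: operand f32[B, N, D], start indices s32[B, K, 1],
// index_vector_dim = 2, slice_sizes = {1, 1, D}.
GatherDimensionNumbers dnums;
dnums.add_offset_dims(2);                  // the D slice survives in the output
dnums.add_collapsed_slice_dims(1);         // N is fully indexed away
dnums.add_operand_batching_dims(0);        // B on the operand ...
dnums.add_start_indices_batching_dims(0);  // ... pairs with B on the indices
dnums.add_start_index_map(1);
dnums.set_index_vector_dim(2);
// The added checks require: the two batching lists have equal size; every
// listed dimension is in range; operand_batching_dims is sorted; neither list
// repeats a dimension; slice_sizes covers offset + collapsed + batching dims;
// and each operand batching dim has slice size <= 1 (here slice_sizes[0] == 1).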
 
diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc
index 16a0152..5239d6c 100644
--- a/third_party/xla/xla/service/sharding_propagation.cc
+++ b/third_party/xla/xla/service/sharding_propagation.cc
@@ -358,7 +358,7 @@
       computation_map.find(instruction->parent()) == computation_map.end() &&
       !(is_entry_root && allow_spmd_sharding_propagation_to_output)) {
     // We don't support sharding the root instruction of a computation yet,
-    // unless the computation is a while body.
+    // unless the computation is in computation_map.
     return false;
   }
 
@@ -2886,14 +2886,23 @@
       return std::vector<HloInstruction*>{inst, callee->root_instruction()};
     } else if (inst->opcode() == HloOpcode::kParameter) {
       auto it = computation_map.find(inst->parent());
-      if (it != computation_map.end() &&
-          it->second->opcode() == HloOpcode::kConditional) {
-        HloInstruction* cond = it->second;
-        for (int64_t i = 1; i < cond->operand_count(); ++i) {
-          if (cond->called_computations()[i - 1] == inst->parent()) {
-            return std::vector<HloInstruction*>{inst, cond->mutable_operand(i)};
+      if (it != computation_map.end()) {
+        if (it->second->opcode() == HloOpcode::kConditional) {
+          HloInstruction* cond = it->second;
+          for (int64_t i = 1; i < cond->operand_count(); ++i) {
+            if (cond->called_computations()[i - 1] == inst->parent()) {
+              return std::vector<HloInstruction*>{inst,
+                                                  cond->mutable_operand(i)};
+            }
           }
         }
+        if (it->second->opcode() == HloOpcode::kCall) {
+          HloInstruction* call = it->second;
+          int64_t operand_index = inst->parameter_number();
+          CHECK_LT(operand_index, call->operand_count());
+          return std::vector<HloInstruction*>{
+              inst, call->mutable_operand(operand_index)};
+        }
       }
       return std::vector<HloInstruction*>{};
     } else {
@@ -2936,9 +2945,11 @@
               auto it = computation_map.find(instruction->parent());
               if (it != computation_map.end()) {
                 propagate_to_instruction(it->second);
-                // Propagate parameter shardings back to conditional's operands.
+                // Propagate parameter shardings back to conditional's and
+                // call's operands.
                 if (instruction->opcode() == HloOpcode::kParameter &&
-                    it->second->opcode() == HloOpcode::kConditional) {
+                    (it->second->opcode() == HloOpcode::kConditional ||
+                     it->second->opcode() == HloOpcode::kCall)) {
                   propagate_to_instruction(instruction);
                 }
               }
@@ -2954,8 +2965,8 @@
     }
   }
 
-  // Populate computation_map in order to associate while bodies to their
-  // while instructions.
+  // Populate computation_map in order to associate while bodies and conditions
+  // with their while instructions.
   for (auto computation : module->computations(execution_threads)) {
     for (auto instruction : computation->instructions()) {
       if (instruction->opcode() == HloOpcode::kWhile ||
@@ -2982,6 +2993,7 @@
         }
         if (instruction->opcode() == HloOpcode::kWhile) {
           computation_map[instruction->while_body()] = instruction;
+          computation_map[instruction->while_condition()] = instruction;
         } else {
           for (HloComputation* c : instruction->called_computations()) {
             computation_map[c] = instruction;
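As a reading aid for the computation_map changes above (not code from this change):

// After this change, computation_map associates:
//   while body      -> its while instruction
//   while condition -> its while instruction   (new)
//   computations called by conditionals, calls, and the other ops handled
//   here -> the calling instruction
// which lets parameter shardings propagate back to the operands of
// conditionals and, newly, of calls.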
diff --git a/third_party/xla/xla/service/sharding_propagation_test.cc b/third_party/xla/xla/service/sharding_propagation_test.cc
index ac04389..072f436 100644
--- a/third_party/xla/xla/service/sharding_propagation_test.cc
+++ b/third_party/xla/xla/service/sharding_propagation_test.cc
@@ -2757,6 +2757,60 @@
   }
 }
 
+TEST_F(ShardingPropagationTest, PropagateShardingInWhileCondition) {
+  const char* const hlo_string = R"(
+HloModule module
+
+%cond {
+  %vars.cond = (u32[], f32[]) parameter(0)
+  %count.cond = u32[] get-tuple-element(%vars.cond), index=0
+  %limit = u32[] constant(10)
+  ROOT %lt = pred[] compare(%count.cond, %limit), direction=LT
+}
+
+%body {
+  %vars = (u32[], f32[]) parameter(0)
+  %count = u32[] get-tuple-element(%vars), index=0
+  %acc = f32[] get-tuple-element(%vars), index=1
+
+  %one = u32[] constant(1)
+  %count.1 = u32[] add(u32[] %count, u32[] %one)
+  %acc.1 = f32[] add(f32[] %acc, f32[] %acc)
+  ROOT %tuple = (u32[], f32[]) tuple(%count.1, %acc.1)
+}
+
+ENTRY %entry {
+  %p0 = f32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+  %zero = u32[] constant(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+  %init = (u32[], f32[]) tuple(%zero, %p0)
+  ROOT %while = (u32[], f32[]) while(%init), body=%body, condition=%cond
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      ShardingPropagation(/*is_spmd=*/false, /*propagate_metadata=*/false,
+                          /*allow_spmd_sharding_propagation_to_output=*/{true})
+          .Run(module.get()));
+  EXPECT_TRUE(changed);
+  HloSharding single_sharding =
+      ParseSharding("{devices=[2,2]<=[4] last_tile_dims={manual, replicated}}")
+          .value();
+  HloSharding tuple_sharding = HloSharding::SingleTuple(
+      module->entry_computation()->root_instruction()->shape(),
+      single_sharding);
+
+  for (const HloComputation* computation : module->computations()) {
+    for (const HloInstruction* instruction : computation->instructions()) {
+      EXPECT_TRUE(instruction->has_sharding());
+      EXPECT_EQ(instruction->sharding(), instruction->shape().IsTuple()
+                                             ? tuple_sharding
+                                             : single_sharding);
+    }
+  }
+}
+
 TEST_P(ParameterizedMetadataTest, WhileGetShardingFromRecvInBody) {
   const char* const hlo_string = R"(
 HloModule module
@@ -12070,5 +12124,36 @@
           "last_tile_dim_replicate}}"));
 }
 
+TEST_F(ShardingPropagationTest, CallPropagation) {
+  const absl::string_view hlo_string = R"(
+HloModule module
+
+called_computation {
+  p0 = bf16[20,2,68096,8512] parameter(0)
+  %add_called_comp = bf16[20,2,68096,8512] add(p0, p0)
+  ROOT tuple = (bf16[20,2,68096,8512]) tuple(add_called_comp)
+}
+
+ENTRY main {
+  %param0 = bf16[20,2,68096,8512] parameter(0)
+  %add = bf16[20,2,68096,8512] add(param0, param0)
+  ROOT %call = (bf16[20,2,68096,8512]) call(add), to_apply=%called_computation, sharding={{devices=[1,1,16,64]<=[64,16]T(1,0)}}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      ShardingPropagation(
+          /*is_spmd=*/true, /*propagate_metadata=*/true,
+          /*allow_spmd_sharding_propagation_to_output=*/{false},
+          /*allow_spmd_sharding_propagation_to_parameters=*/{false})
+          .Run(module.get()));
+  XLA_VLOG_LINES(1, module->ToString());
+  EXPECT_TRUE(changed);
+  auto* add = FindInstruction(module.get(), "add");
+  ASSERT_NE(add, nullptr);
+  EXPECT_THAT(add, op::Sharding("{devices=[1,1,16,64]<=[64,16]T(1,0)}"));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/third_party/xla/xla/service/space_to_batch_converter.cc b/third_party/xla/xla/service/space_to_batch_converter.cc
index e07ee54..45f2113 100644
--- a/third_party/xla/xla/service/space_to_batch_converter.cc
+++ b/third_party/xla/xla/service/space_to_batch_converter.cc
@@ -1734,6 +1734,10 @@
   }
 
   if (consumer->opcode() == HloOpcode::kReduce) {
+    // Do not propagate through tuple outputs.
+    if (consumer->shape().IsTuple()) {
+      return false;
+    }
     // Support only the trivial case where both batch and split spatial dim are
     // being reduced
 
@@ -1741,8 +1745,13 @@
     auto result = instr_to_dim_map_[consumer->mutable_operand(0)];
     const int64_t batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
     const int64_t space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)];
-    VLOG(1) << "Checking if reduce is supported batch_dim " << batch_dim
-            << " space_dim " << space_dim << " reduce " << consumer->ToString();
+    // Support the trivial case where neither the batch nor the split spatial
+    // dim is being reduced.
+    if (!absl::c_linear_search(reduce_dims, batch_dim) &&
+        !absl::c_linear_search(reduce_dims, space_dim)) {
+      return true;
+    }
+
+    // Support only the trivial case where both batch and split spatial dim are
+    // being reduced
     return absl::c_linear_search(reduce_dims, batch_dim) &&
            absl::c_linear_search(reduce_dims, space_dim);
   }
@@ -2072,16 +2081,116 @@
   }
 
   if (consumer->opcode() == HloOpcode::kReduce) {
-    auto new_consumer = computation->AddInstruction(consumer->Clone());
-    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
-
+    auto reduce_dims = consumer->dimensions();
     auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)];
+    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
+    auto permute_dims = instr_to_dim_permute_map_[first_operand];
+
     const int64_t old_batch_dim =
         dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)];
+    const int64_t space_dim =
+        dim_map_val[DimMapper(SpaceToBatchDimMap::kSpace0)];
 
-    auto permute_dims = instr_to_dim_permute_map_[first_operand];
     const int64_t new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
+    const int64_t new_space_dim = DimLookUp(permute_dims, space_dim);
+    std::vector<int64_t> changed_dims(consumer->dimensions().size());
 
+    // Support the trivial case where neither the batch nor the split spatial
+    // dim is being reduced.
+    if (!absl::c_linear_search(reduce_dims, old_batch_dim) &&
+        !absl::c_linear_search(reduce_dims, space_dim)) {
+      for (int64_t i = 0; i < consumer->dimensions().size(); ++i) {
+        changed_dims[i] = DimLookUp(permute_dims, consumer->dimensions(i));
+      }
+
+      // Decide where the new batch and space dims land in the new reduce's
+      // output.
+      int64_t new_output_batch_dim = new_batch_dim;
+      int64_t new_output_space_dim = new_space_dim;
+      for (int64_t i = 0; i < consumer->dimensions().size(); ++i) {
+        if (changed_dims[i] < new_batch_dim) {
+          new_output_batch_dim--;
+        }
+        if (changed_dims[i] < new_space_dim) {
+          new_output_space_dim--;
+        }
+      }
+
+      // Decide where the batch and space dims land in the original reduce's
+      // output.
+      int64_t old_output_batch_dim = old_batch_dim;
+      int64_t old_output_space_dim = space_dim;
+      for (int64_t i = 0; i < consumer->dimensions().size(); ++i) {
+        if (reduce_dims[i] < old_batch_dim) {
+          old_output_batch_dim--;
+        }
+        if (reduce_dims[i] < space_dim) {
+          old_output_space_dim--;
+        }
+      }
+
+      HloInstruction* new_consumer = nullptr;
+      TF_ASSIGN_OR_RETURN(
+          new_consumer,
+          MakeReduceHlo(first_operand, consumer->mutable_operand(1),
+                        changed_dims, consumer->called_computations()[0]));
+
+      VLOG(3) << " new_output_batch_dim " << new_output_batch_dim << " size "
+              << first_operand->shape().dimensions(new_batch_dim)
+              << " new_output_space_dim " << new_output_space_dim << " size "
+              << first_operand->shape().dimensions(new_space_dim);
+
+      std::vector<int64_t> dim_map(kNumMappedDims);
+      dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] = old_output_batch_dim;
+      dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] = old_output_space_dim;
+      // We don't know where the feature dim is, so set it to -1.
+      dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] = -1;
+
+      instr_to_dim_map_[consumer] = dim_map;
+      const int64_t rank = first_operand->shape().rank();
+
+      const int64_t output_rank = new_consumer->shape().rank();
+
+      // Map each dim of the original reduce output to its input dim.
+      std::vector<int64_t> old_reduce_output_to_input(output_rank);
+      int dim_number_to_assign_old = 0;
+      for (int64_t i = 0; i < rank; ++i) {
+        if (auto it = absl::c_find(reduce_dims, i); it != reduce_dims.end()) {
+          continue;
+        }
+        old_reduce_output_to_input[i] = dim_number_to_assign_old++;
+      }
+
+      // Map each dim of the new reduce output to its (new) input dim.
+      std::vector<int64_t> new_reduce_output_to_input(output_rank);
+      int dim_number_to_assign_new = 0;
+      for (int64_t i = 0; i < rank; ++i) {
+        if (auto it = absl::c_find(changed_dims, i); it != changed_dims.end()) {
+          continue;
+        }
+        new_reduce_output_to_input[i] = dim_number_to_assign_new++;
+      }
+
+      std::vector<int64_t> new_permute_dims(output_rank);
+      // From the output-to-input dim mappings, figure out how the old output
+      // dims map to the new output dims.
+      for (int64_t i = 0; i < output_rank; ++i) {
+        new_permute_dims[i] = std::distance(
+            new_reduce_output_to_input.begin(),
+            absl::c_find(
+                new_reduce_output_to_input,
+                DimLookUp(permute_dims, old_reduce_output_to_input[i])));
+      }
+
+      instr_to_dim_permute_map_[new_consumer] = new_permute_dims;
+      old_to_new_instrs_[consumer] = new_consumer;
+
+      // Because batch and split spatial dims are not reduced, further
+      // propagation is needed.
+      return true;
+    }
+
+    HloInstruction* new_consumer =
+        computation->AddInstruction(consumer->Clone());
     auto retval = GetSpatialDimsToSplit(consumer->mutable_operand(0));
     std::vector<int64_t> old_spatial_dims = retval.first;
     std::vector<int64_t> new_spatial_dims = retval.second;
@@ -2092,7 +2201,6 @@
                            consumer->mutable_operand(1), new_batch_dim,
                            new_spatial_dims, old_batch_dim, old_spatial_dims));
 
-    std::vector<int64_t> changed_dims(new_consumer->dimensions().size());
     for (int64_t i = 0; i < new_consumer->dimensions().size(); ++i) {
       changed_dims[i] = DimLookUp(permute_dims, new_consumer->dimensions(i));
     }
diff --git a/third_party/xla/xla/service/space_to_batch_converter_test.cc b/third_party/xla/xla/service/space_to_batch_converter_test.cc
index e2ed331..dbc11e9 100644
--- a/third_party/xla/xla/service/space_to_batch_converter_test.cc
+++ b/third_party/xla/xla/service/space_to_batch_converter_test.cc
@@ -272,5 +272,84 @@
   ASSERT_TRUE(converter.Run(module.get()).value());
 }
 
+TEST_F(SpaceToBatchConverterTest, PropagateOnTrivialReduce) {
+  std::string hlo_string = R"(
+  HloModule module
+
+  %region_1.37 (Arg_0.38: f32[], Arg_1.39: f32[]) -> f32[] {
+    %Arg_0.38 = f32[] parameter(0)
+    %Arg_1.39 = f32[] parameter(1)
+    ROOT %add.40 = f32[] add(f32[] %Arg_0.38, f32[] %Arg_1.39)
+  }
+
+  ENTRY computation {
+    %p0 = bf16[7,320,800,3]{3,2,1,0} parameter(0)
+    %p1 = bf16[3,3,3,32]{3,2,1,0} parameter(1)
+    %c = f32[7,160,400,32]{3,2,1,0} convolution( %p0,  %p1),
+    window={size=3x3 stride=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f
+    %constant.5 = f32[] constant(0)
+    ROOT %reduce.41 = f32[7,160,400]{2,1,0} reduce(%c, %constant.5), dimensions={3}, to_apply=%region_1.37
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  auto computation = module->entry_computation();
+  SpaceToBatchConverter converter(
+      SpaceToBatchController{true, true, true, true, /*number_of_splits=*/8});
+  ASSERT_TRUE(converter.Run(module.get()).value());
+
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_THAT(root, op::Transpose());
+  EXPECT_THAT(root->operand(0)->operand(0)->operand(0)->operand(0),
+              op::Reduce());
+  auto new_reduce = root->operand(0)->operand(0)->operand(0)->operand(0);
+  // Make sure we propagated on the reduce with the larger batch size.
+  EXPECT_EQ(new_reduce->shape().dimensions(1),
+            // batch*number_of_splits
+            7 * 8);
+}
+
+TEST_F(SpaceToBatchConverterTest, DoNotPropagateOnTupleReduce) {
+  std::string hlo_string = R"(
+  HloModule module
+
+  %minmax_func.2717 {
+    %lhs_value.2718 = f32[] parameter(0)
+    %rhs_value.2720 = f32[] parameter(2)
+    %compare.2722 = pred[] compare(f32[] %lhs_value.2718, f32[] %rhs_value.2720), direction=GE
+    %select.2723 = f32[] select(pred[] %compare.2722, f32[] %lhs_value.2718, f32[] %rhs_value.2720)
+    %compare.2725 = pred[] compare(f32[] %lhs_value.2718, f32[] %rhs_value.2720), direction=EQ
+    %lhs_index.2719 = f32[] parameter(1)
+    %rhs_index.2721 = f32[] parameter(3)
+    %minimum.2726 = f32[] minimum(f32[] %lhs_index.2719, f32[] %rhs_index.2721)
+    %select.2724 = f32[] select(pred[] %compare.2722, f32[] %lhs_index.2719, f32[] %rhs_index.2721)
+    %select.2727 = f32[] select(pred[] %compare.2725, f32[] %minimum.2726, f32[] %select.2724)
+    ROOT %tuple.4 = (f32[], f32[]) tuple(f32[] %select.2723, f32[] %select.2727)
+  }
+
+  ENTRY computation {
+    %p0 = bf16[7,320,800,3]{3,2,1,0} parameter(0)
+    %p1 = bf16[3,3,3,32]{3,2,1,0} parameter(1)
+    %c = f32[7,160,400,32]{3,2,1,0} convolution(%p0, %p1),
+    window={size=3x3 stride=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f
+    %constant.5 = f32[] constant(0)
+    %constant.6 = f32[] constant(1)
+    ROOT %reduce.36 = (f32[7,160,400]{2,1,0}, f32[7,160,400]{2,1,0}) reduce(%c, %c,
+    %constant.5, %constant.6), dimensions={3}, to_apply=%minmax_func.2717
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  auto computation = module->entry_computation();
+  SpaceToBatchConverter converter(
+      SpaceToBatchController{true, true, true, true, /*number_of_splits=*/8});
+  ASSERT_TRUE(converter.Run(module.get()).value());
+
+  HloInstruction* root = computation->root_instruction();
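+  // Unlike the trivial-reduce case above, the tuple (variadic) reduce is not
+  // propagated on, so the root remains the original reduce.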
+  EXPECT_THAT(root, op::Reduce());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/third_party/xla/xla/service/spmd/dot_handler.cc b/third_party/xla/xla/service/spmd/dot_handler.cc
index 87db7f7..22f88cf 100644
--- a/third_party/xla/xla/service/spmd/dot_handler.cc
+++ b/third_party/xla/xla/service/spmd/dot_handler.cc
@@ -1899,7 +1899,7 @@
         has_reshape_operand(lhs) ? lhs.hlo()->operand(0) : lhs.hlo();
     auto rhs_operand =
         has_reshape_operand(rhs) ? rhs.hlo()->operand(0) : rhs.hlo();
-    for (auto loop : *windowed_dot_general_loops) {
+    for (const auto& loop : *windowed_dot_general_loops) {
       if (loop.while_loop->while_body()->name().find(
               "windowed_dot_general_body_ag") == 0) {
         auto cm_lhs = loop.while_loop->operand(0)->operand(0);
@@ -2562,7 +2562,8 @@
   };
   std::optional<GroupedSharding> other_grouped =
       try_sharding_for_other_operand(other.sharding());
-  if (!other_grouped && !other.sharding().IsReplicated()) {
+  if (!other_grouped && !other.sharding().IsReplicated() &&
+      dims_mapping.conv_spatial_dims.empty()) {
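+    // Only try the inferred dot-operand sharding when there are no convolution
+    // spatial dims; otherwise `other` falls through to being replicated below.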
     const HloSharding expected_other_sharding =
         hlo_sharding_util::InferDotOperandSharding(
             &output_sharding, &matching.sharding(), lhs_matching ? 1 : 0,
@@ -2570,9 +2571,9 @@
     // Try the expected sharding since it is no worse than the last resort
     // (replicated sharding).
     other_grouped = try_sharding_for_other_operand(expected_other_sharding);
-    if (!other_grouped) {
-      other = other.Replicate();
-    }
+  }
+  if (!other_grouped) {
+    other = other.Replicate();
   }
 
   matching = matching.Reshard(UngroupSharding(matching_grouped));
diff --git a/third_party/xla/xla/service/spmd/shardy/BUILD b/third_party/xla/xla/service/spmd/shardy/BUILD
index fcacdd7..bd15f20 100644
--- a/third_party/xla/xla/service/spmd/shardy/BUILD
+++ b/third_party/xla/xla/service/spmd/shardy/BUILD
@@ -142,8 +142,8 @@
         "//xla/service/spmd/shardy/mhlo_round_trip:mhlo_import",
         "//xla/service/spmd/shardy/mhlo_round_trip:shard_map_export",
         "//xla/service/spmd/shardy/round_trip_common:convert_sharding_custom_calls",
-        "//xla/service/spmd/shardy/round_trip_common:identity_to_pass_through_while_args",
         "//xla/service/spmd/shardy/round_trip_common:import_constants",
+        "//xla/service/spmd/shardy/round_trip_common:open_while_free_vars_sharding",
         "//xla/service/spmd/shardy/round_trip_common:shard_map_import",
         "//xla/service/spmd/shardy/sdy_round_trip:export_ops",
         "//xla/service/spmd/shardy/sdy_round_trip:export_shardings",
diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc
index b6c6a99..0ffff71 100644
--- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc
+++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc
@@ -66,7 +66,6 @@
 using ::mlir::success;
 
 using ::mlir::sdy::ConstantOp;
-using ::mlir::sdy::IdentityOp;
 using ::mlir::sdy::kShardingAttr;
 using ::mlir::sdy::ReshardOp;
 using ::mlir::sdy::ShardingConstraintOp;
@@ -88,20 +87,6 @@
   }
 };
 
-// Removes `sdy::IdentityOp`.
-class IdentityPattern : public OpConversionPattern<IdentityOp> {
- public:
-  using OpConversionPattern::OpConversionPattern;
-
- private:
-  LogicalResult matchAndRewrite(
-      IdentityOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter& rewriter) const override {
-    rewriter.replaceOp(op, adaptor.getInput());
-    return success();
-  }
-};
-
 class ReshardPattern : public OpConversionPattern<ReshardOp> {
  public:
   using OpConversionPattern::OpConversionPattern;
@@ -148,15 +133,14 @@
     // We do not expect to see ShardingConstraintOp in the input module.
     // ShardingConstraintOp should be replaced by ReshardOp before this pass.
     // Hence, we add ShardingConstraintOp as an illegal op.
-    target.addIllegalOp<ConstantOp, IdentityOp, ReshardOp,
-                        ShardingConstraintOp>();
+    target.addIllegalOp<ConstantOp, ReshardOp, ShardingConstraintOp>();
     target.addLegalOp<mhlo::ConstantOp, mhlo::CopyOp>();
     mlir::RewritePatternSet patterns(&context);
     // After converting `sdy.constant` into `mhlo.constant`, the constants
     // should not be deduped via folding. Fortunately, folding only happens in
     // greedy pattern rewriters. ExportHloShardingsPass does a simple walk,
     // which keeps the constants as is.
-    patterns.add<ConstantPattern, IdentityPattern, ReshardPattern>(&context);
+    patterns.add<ConstantPattern, ReshardPattern>(&context);
     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                   std::move(patterns)))) {
       signalPassFailure();
@@ -166,8 +150,8 @@
   StringRef getArgument() const override { return "xla-sdy-export-ops"; }
 
   StringRef getDescription() const override {
-    return "Exports Shardy ops to MHLO ops. Processes sdy::IdentityOp, "
-           "sdy::ReshardOp, and sdy::ConstantOp.";
+    return "Exports Shardy ops to MHLO ops. Processes sdy::ReshardOp and "
+           "sdy::ConstantOp.";
   }
 
   void getDependentDialects(mlir::DialectRegistry& registry) const final {
diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD b/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD
index f9fd53b..e929f61 100644
--- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD
+++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD
@@ -32,9 +32,9 @@
 )
 
 cc_library(
-    name = "identity_to_pass_through_while_args",
-    srcs = ["identity_to_pass_through_while_args.cc"],
-    hdrs = ["identity_to_pass_through_while_args.h"],
+    name = "import_constants",
+    srcs = ["import_constants.cc"],
+    hdrs = ["import_constants.h"],
     deps = [
         "//xla/mlir_hlo",
         "@llvm-project//llvm:Support",
@@ -48,9 +48,9 @@
 )
 
 cc_library(
-    name = "import_constants",
-    srcs = ["import_constants.cc"],
-    hdrs = ["import_constants.h"],
+    name = "open_while_free_vars_sharding",
+    srcs = ["open_while_free_vars_sharding.cc"],
+    hdrs = ["open_while_free_vars_sharding.h"],
     deps = [
         "//xla/mlir_hlo",
         "@llvm-project//llvm:Support",
@@ -94,8 +94,8 @@
     hdrs = ["pipeline_passes.h"],
     deps = [
         ":convert_sharding_custom_calls",
-        ":identity_to_pass_through_while_args",
         ":import_constants",
+        ":open_while_free_vars_sharding",
         ":shard_map_import",
         "//xla/mlir_hlo:mhlo_passes",
         "@llvm-project//mlir:FuncDialect",
diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.cc
deleted file mode 100644
index e1675c6..0000000
--- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.cc
+++ /dev/null
@@ -1,99 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h"
-
-#include <memory>
-
-#include "llvm/ADT/StringRef.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/TypeID.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "shardy/dialect/sdy/ir/dialect.h"
-#include "shardy/dialect/sdy/ir/utils.h"
-#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
-
-namespace xla {
-namespace sdy {
-
-namespace {
-
-using ::mlir::StringRef;
-
-using ::mlir::func::FuncOp;
-
-// For every block argument of an `mhlo::WhileOp` that is directly returned by
-// the body of the op (pass-through), add an `sdy::IdentityOp` between the block
-// argument and the return op.
-//
-// This will prevent canonicalization from replacing these block arguments with
-// the corresponding operands as free variables.
-class AddIdentityToPassThroughWhileArgsPass
-    : public mlir::PassWrapper<AddIdentityToPassThroughWhileArgsPass,
-                               mlir::OperationPass<FuncOp>> {
- public:
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      AddIdentityToPassThroughWhileArgsPass)
-
-  void runOnOperation() final {
-    FuncOp funcOp = getOperation();
-    mlir::IRRewriter rewriter(funcOp);
-
-    funcOp.walk([&](mlir::mhlo::WhileOp op) {
-      mlir::Operation* returnOp = mlir::sdy::getBodyTerminator(op);
-      rewriter.setInsertionPoint(returnOp);
-      for (mlir::Value returnValue : returnOp->getOperands()) {
-        if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(returnValue);
-            blockArg && blockArg.getOwner() == &op.getBody().front()) {
-          auto identityOp = rewriter.create<mlir::sdy::IdentityOp>(
-              returnValue.getLoc(), returnValue);
-          rewriter.replaceUsesWithIf(returnValue, identityOp,
-                                     [returnOp](mlir::OpOperand& use) {
-                                       return use.getOwner() == returnOp;
-                                     });
-        }
-      }
-    });
-  }
-
-  StringRef getArgument() const override {
-    return "xla-sdy-add-identity-to-pass-through-while-args";
-  }
-
-  StringRef getDescription() const override {
-    return "Adds an identity op between pass-through block arguments of a "
-           "while op.";
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<mlir::Pass> createAddIdentityToPassThroughWhileArgsPass() {
-  return std::make_unique<AddIdentityToPassThroughWhileArgsPass>();
-}
-
-void registerAddIdentityToPassThroughWhileArgsPass() {
-  mlir::registerPass(createAddIdentityToPassThroughWhileArgsPass);
-}
-
-}  // namespace sdy
-}  // namespace xla
diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h b/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h
deleted file mode 100644
index 5dcb51f..0000000
--- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2024 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IDENTITY_TO_PASS_THROUGH_WHILE_ARGS_H_
-#define XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IDENTITY_TO_PASS_THROUGH_WHILE_ARGS_H_
-
-#include <memory>
-
-#include "mlir/Pass/Pass.h"
-
-namespace xla {
-namespace sdy {
-
-// Creates a pass that adds an identity op between pass-through block arguments
-// of a while op.
-std::unique_ptr<mlir::Pass> createAddIdentityToPassThroughWhileArgsPass();
-
-// Registers the xla-sdy-add-identity-to-pass-through-while-args pass.
-void registerAddIdentityToPassThroughWhileArgsPass();
-
-}  // namespace sdy
-}  // namespace xla
-
-#endif  // XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IDENTITY_TO_PASS_THROUGH_WHILE_ARGS_H_
diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc
new file mode 100644
index 0000000..603b270
--- /dev/null
+++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc
@@ -0,0 +1,95 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h"
+
+#include <memory>
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "shardy/dialect/sdy/ir/dialect.h"
+#include "shardy/dialect/sdy/ir/utils.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace xla {
+namespace sdy {
+
+namespace {
+
+using ::mlir::StringRef;
+using ::mlir::func::FuncOp;
+using ::mlir::sdy::TensorShardingAttr;
+
+class OpenWhileFreeVarsShardingPass
+    : public mlir::PassWrapper<OpenWhileFreeVarsShardingPass,
+                               mlir::OperationPass<FuncOp>> {
+ public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenWhileFreeVarsShardingPass)
+
+  void runOnOperation() final {
+    FuncOp funcOp = getOperation();
+    mlir::IRRewriter rewriter(funcOp);
+
+    funcOp.walk([&](mlir::mhlo::WhileOp op) {
+      llvm::SetVector<mlir::Value> freeVars;
+      mlir::getUsedValuesDefinedAbove(op->getRegions(), freeVars);
+      rewriter.setInsertionPoint(op);
+      for (mlir::Value freeVar : freeVars) {
+        TensorShardingAttr sharding = mlir::sdy::getSharding(freeVar);
+        if (!sharding || sharding.getRank() == 0) {
+          continue;
+        }
+        auto shardingConstraint =
+            rewriter.create<mlir::sdy::ShardingConstraintOp>(
+                freeVar.getLoc(), freeVar,
+                TensorShardingAttr::getFullyOpenLike(sharding));
+        // Only replace uses in the regions of the while op.
+        rewriter.replaceUsesWithIf(
+            freeVar, shardingConstraint, [op](mlir::OpOperand& use) {
+              return op->isProperAncestor(use.getOwner());
+            });
+      }
+    });
+  }
+
+  StringRef getArgument() const override {
+    return "xla-sdy-open-while-free-vars-sharding";
+  }
+
+  StringRef getDescription() const override {
+    return "Adds a fully open sharding constraint to free variables of a "
+           "while op that already have a sharding.";
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<mlir::Pass> createOpenWhileFreeVarsShardingPass() {
+  return std::make_unique<OpenWhileFreeVarsShardingPass>();
+}
+
+void registerOpenWhileFreeVarsShardingPass() {
+  mlir::registerPass(createOpenWhileFreeVarsShardingPass);
+}
+
+}  // namespace sdy
+}  // namespace xla
diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h
new file mode 100644
index 0000000..c06776f
--- /dev/null
+++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h
@@ -0,0 +1,40 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_OPEN_WHILE_FREE_VARS_SHARDING_H_
+#define XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_OPEN_WHILE_FREE_VARS_SHARDING_H_
+
+#include <memory>
+
+#include "mlir/Pass/Pass.h"
+
+namespace xla {
+namespace sdy {
+
+// Creates a pass that adds a fully open sharding constraint to the free
+// variables of a while op that already have a user-defined sharding.
+//
+// This allows their uses in the while op to be further sharded, which is
+// important when converting to HLO, as the free variables will be lifted as
+// pass-through while operands/results.
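+//
+// For example (an illustrative sketch; the mesh and value names are made up),
+// given a free variable %fv with sharding <@mesh, [{"a"}, {}]> that is used
+// inside the regions of an mhlo.while, the pass inserts
+//
+//   %sc = sdy.sharding_constraint %fv <@mesh, [{?}, {?}]>
+//
+// right before the while op and replaces only the uses of %fv inside the
+// while op's regions with %sc.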
+std::unique_ptr<mlir::Pass> createOpenWhileFreeVarsShardingPass();
+
+// Registers the xla-sdy-open-while-free-vars-sharding pass.
+void registerOpenWhileFreeVarsShardingPass();
+
+}  // namespace sdy
+}  // namespace xla
+
+#endif  // XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_OPEN_WHILE_FREE_VARS_SHARDING_H_
diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc
index 5ddeb43..23960ab 100644
--- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc
+++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc
@@ -20,8 +20,8 @@
 #include "mlir/Transforms/Passes.h"
 #include "xla/mlir_hlo/mhlo/transforms/passes.h"
 #include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h"
-#include "xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h"
 #include "xla/service/spmd/shardy/round_trip_common/import_constants.h"
+#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h"
 #include "xla/service/spmd/shardy/round_trip_common/shard_map_import.h"
 
 namespace xla {
@@ -36,11 +36,6 @@
   // changes happen before shardings are added to operations, to ensure the
   // correct shardings are added and that they are not lost by this pass.
   pm.addNestedPass<FuncOp>(mlir::mhlo::createPrepareForExportPass());
-  // The prepare-for-export pass lifts `mhlo::WhileOp` free variables, and added
-  // them as additional operands of the op whose corresponding block arguments
-  // are directly returned by the body of the op (pass-through). To prevent
-  // canonicalization from undoing this, we add identity ops.
-  pm.addNestedPass<FuncOp>(createAddIdentityToPassThroughWhileArgsPass());
 
   // We import `mhlo.constant` ops to `sdy.constant` ops so that constants
   // aren't folded in greedy pattern rewriters, which would lift them outside of
@@ -51,13 +46,15 @@
 
   pm.addNestedPass<FuncOp>(mlir::mhlo::createFlattenTuplePass());
   // We need to canonicalize redundant mhlo::GetTupleElementOp and
-  // mhlo::GetTupleOp.
+  // mhlo::GetTupleOp. We also need to canonicalize mhlo::WhileOp before
+  // `createOpenWhileFreeVarsShardingPass`.
   pm.addPass(mlir::createCanonicalizerPass());
 }
 
 void addCommonPostImportPasses(mlir::OpPassManager& pm) {
   pm.addPass(createShardMapImportPass());
   pm.addPass(createConvertShardingCustomCallsPass());
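+  // Add a fully open sharding constraint to while-op free variables that
+  // already have a sharding, so that their uses inside the while op can be
+  // further sharded.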
+  pm.addNestedPass<FuncOp>(createOpenWhileFreeVarsShardingPass());
 }
 
 }  // namespace sdy
diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc
index c12587c..b5670e7 100644
--- a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc
+++ b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc
@@ -29,8 +29,8 @@
 #include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.h"
 #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h"
 #include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h"
-#include "xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h"
 #include "xla/service/spmd/shardy/round_trip_common/import_constants.h"
+#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h"
 #include "xla/service/spmd/shardy/round_trip_common/shard_map_import.h"
 #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h"
 #include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h"
@@ -55,7 +55,7 @@
   xla::sdy::registerMhloImportShardingsPass();
   xla::sdy::registerShardMapImportPass();
   xla::sdy::registerConvertShardingCustomCallsPass();
-  xla::sdy::registerAddIdentityToPassThroughWhileArgsPass();
+  xla::sdy::registerOpenWhileFreeVarsShardingPass();
   xla::sdy::registerImportConstantsPass();
 
   xla::sdy::registerMhloExportPipeline();
diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc
index d4e14da..b5bd21f 100644
--- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc
+++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc
@@ -61,7 +61,6 @@
 using ::mlir::success;
 
 using ::mlir::sdy::ConstantOp;
-using ::mlir::sdy::IdentityOp;
 using ::mlir::sdy::ShardingConstraintOp;
 using ::mlir::sdy::TensorShardingAttr;
 using ::mlir::sdy::TensorShardingPerValueAttr;
@@ -81,20 +80,6 @@
   }
 };
 
-// Removes `sdy::IdentityOp`.
-class IdentityPattern : public OpConversionPattern<IdentityOp> {
- public:
-  using OpConversionPattern::OpConversionPattern;
-
- private:
-  LogicalResult matchAndRewrite(
-      IdentityOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter& rewriter) const override {
-    rewriter.replaceOp(op, adaptor.getInput());
-    return success();
-  }
-};
-
 class ShardingConstraintPattern
     : public OpConversionPattern<ShardingConstraintOp> {
  public:
@@ -130,11 +115,10 @@
   void runOnOperation() final {
     mlir::MLIRContext& context = getContext();
     mlir::ConversionTarget target(context);
-    target.addIllegalOp<ConstantOp, IdentityOp, ShardingConstraintOp>();
+    target.addIllegalOp<ConstantOp, ShardingConstraintOp>();
     target.addLegalOp<mhlo::ConstantOp, mhlo::CustomCallOp>();
     mlir::RewritePatternSet patterns(&context);
-    patterns.add<ConstantPattern, IdentityPattern, ShardingConstraintPattern>(
-        &context);
+    patterns.add<ConstantPattern, ShardingConstraintPattern>(&context);
     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                   std::move(patterns)))) {
       signalPassFailure();
diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc
index 59197cc..40463d0 100644
--- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc
+++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc
@@ -535,36 +535,42 @@
 
 TEST_F(ShardyXLATest, WhileWithFreeVariables) {
   const char* const hloString = R"(
-    HloModule main
+    HloModule main, entry_computation_layout={(f32[32,96]{1,0}, f32[32,96]{1,0})->f32[32,96]{1,0}}
 
-    %region_0.6 (arg_tuple.7: (f32[32,96], s32[], s32[], s32[])) -> (f32[32,96], s32[], s32[], s32[]) {
-      %arg_tuple.7 = (f32[32,96]{1,0}, s32[], s32[], s32[]) parameter(0)
-      %get-tuple-element.8 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=0
-      %add.13 = f32[32,96]{1,0} add(f32[32,96]{1,0} %get-tuple-element.8, f32[32,96]{1,0} %get-tuple-element.8)
-      %get-tuple-element.9 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=1
-      %get-tuple-element.11 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=3
-      %add.12 = s32[] add(s32[] %get-tuple-element.9, s32[] %get-tuple-element.11)
-      %get-tuple-element.10 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=2
-      ROOT %tuple.14 = (f32[32,96]{1,0}, s32[], s32[], s32[]) tuple(f32[32,96]{1,0} %add.13, s32[] %add.12, s32[] %get-tuple-element.10, s32[] %get-tuple-element.11)
+    %region_0.7 (arg_tuple.8: (f32[32,96], s32[], s32[], s32[], f32[32,96])) -> (f32[32,96], s32[], s32[], s32[], f32[32,96]) {
+      %arg_tuple.8 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) parameter(0)
+      %get-tuple-element.9 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=0
+      %get-tuple-element.13 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=4
+      %add.15 = f32[32,96]{1,0} add(f32[32,96]{1,0} %get-tuple-element.9, f32[32,96]{1,0} %get-tuple-element.13), metadata={source_file="-" source_line=25}
+      %get-tuple-element.10 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=1
+      %get-tuple-element.12 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=3
+      %add.14 = s32[] add(s32[] %get-tuple-element.10, s32[] %get-tuple-element.12), metadata={source_file="-" source_line=24}
+      %get-tuple-element.11 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=2
+      ROOT %tuple.16 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %add.15, s32[] %add.14, s32[] %get-tuple-element.11, s32[] %get-tuple-element.12, f32[32,96]{1,0} %get-tuple-element.13)
     }
 
-    %region_1.15 (arg_tuple.16: (f32[32,96], s32[], s32[], s32[])) -> pred[] {
-      %arg_tuple.16 = (f32[32,96]{1,0}, s32[], s32[], s32[]) parameter(0)
-      %get-tuple-element.17 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=0
-      %get-tuple-element.20 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=3
-      %get-tuple-element.18 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=1
-      %get-tuple-element.19 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=2
-      ROOT %compare.21 = pred[] compare(s32[] %get-tuple-element.18, s32[] %get-tuple-element.19), direction=LT
+    %region_1.17 (arg_tuple.18: (f32[32,96], s32[], s32[], s32[], f32[32,96])) -> pred[] {
+      %arg_tuple.18 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) parameter(0)
+      %get-tuple-element.19 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=0
+      %get-tuple-element.22 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=3
+      %get-tuple-element.23 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=4
+      %get-tuple-element.20 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=1
+      %get-tuple-element.21 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=2
+      ROOT %compare.24 = pred[] compare(s32[] %get-tuple-element.20, s32[] %get-tuple-element.21), direction=LT, metadata={source_file="-" source_line=21}
     }
 
-    ENTRY %main.27 (Arg_0.1: f32[32,96]) -> f32[32,96] {
+    ENTRY %main.30 (Arg_0.1: f32[32,96], Arg_1.2: f32[32,96]) -> f32[32,96] {
       %Arg_0.1 = f32[32,96]{1,0} parameter(0), sharding={devices=[2,2]<=[4]}
-      %constant.2 = s32[] constant(0)
-      %constant.4 = s32[] constant(32)
-      %constant.3 = s32[] constant(1)
-      %tuple.5 = (f32[32,96]{1,0}, s32[], s32[], s32[]) tuple(f32[32,96]{1,0} %Arg_0.1, s32[] %constant.2, s32[] %constant.4, s32[] %constant.3)
-      %while.22 = (f32[32,96]{1,0}, s32[], s32[], s32[]) while((f32[32,96]{1,0}, s32[], s32[], s32[]) %tuple.5), condition=%region_1.15, body=%region_0.6
-      ROOT %get-tuple-element.23 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %while.22), index=0
+      %constant.3 = s32[] constant(0)
+      %constant.5 = s32[] constant(32)
+      %constant.4 = s32[] constant(1)
+      %Arg_1.2 = f32[32,96]{1,0} parameter(1), sharding={devices=[2,1,2]<=[4] last_tile_dim_replicate}
+      %tuple.6 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %Arg_0.1, s32[] %constant.3, s32[] %constant.5, s32[] %constant.4, f32[32,96]{1,0} %Arg_1.2), metadata={source_file="-" source_line=19}
+      %while.25 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) while((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %tuple.6), condition=%region_1.17, body=%region_0.7, metadata={source_file="-" source_line=19}
+      %get-tuple-element.27 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %while.25), index=1, metadata={source_file="-" source_line=19}
+      %get-tuple-element.26 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %while.25), index=0, metadata={source_file="-" source_line=19}
+      %tuple.28 = (f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %get-tuple-element.26)
+      ROOT %get-tuple-element.29 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}) %tuple.28), index=0
     })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(hloString));
@@ -575,10 +581,14 @@
   HloInstruction* whileInst =
       FindInstruction(module.get(), xla::HloOpcode::kWhile);
   EXPECT_NE(whileInst, nullptr);
-  EXPECT_THAT(
-      whileInst,
-      op::Sharding(
-          "{{devices=[2,2]<=[4]}, {replicated}, {replicated}, {replicated}}"));
+  // Verify that the sharding of parameter(1) hasn't changed.
+  EXPECT_THAT(module->entry_computation()->parameter_instruction(1),
+              op::Sharding("{devices=[2,1,2]<=[4] last_tile_dim_replicate}"));
+  // Verify the sharding of the while op, and specifically that the result that
+  // corresponds to parameter(1) is now further sharded.
+  EXPECT_THAT(whileInst,
+              op::Sharding("{{devices=[2,2]<=[4]}, {replicated}, {replicated}, "
+                           "{devices=[2,2]<=[4]}, {replicated}}"));
 }
 
 TEST_F(ShardyXLATest, ShardMap) {
diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir
index fdc7efb..b022afc 100644
--- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir
+++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir
@@ -52,38 +52,44 @@
 
 // -----
 
+// CHECK-LABEL: sdy.mesh @mesh = <"axis_0"=2, "axis_1"=2>
+
 // CHECK-LABEL: func @while_with_free_variables
-func.func @while_with_free_variables(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> {
+func.func @while_with_free_variables(
+    %arg0: tensor<32x96xf32>,
+    %arg1: tensor<32x96xf32> {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}"})
+    -> tensor<32x96xf32> {
   // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0>
   // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1>
-  // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32>
-  // CHECK-NEXT: %[[WHILE:.*]]:4 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]], %iterArg_1 = %[[C32]], %iterArg_2 = %[[C1]])
+  // CHECK-NEXT: %[[C32:.*]] = sdy.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, []>]>} dense<32>
+  // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]>
+  // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]])
   // CHECK-NEXT:   cond {
-  // CHECK-NEXT:   %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %iterArg_1
+  // CHECK-NEXT:   %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]]
   // CHECK-NEXT:   mhlo.return %[[COND]]
   // CHECK-NEXT: } do {
-  // CHECK-NEXT:   %[[ADD_0:.*]] = mhlo.add %iterArg_0, %iterArg_2
-  // CHECK-NEXT:   %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg
-  // CHECK-NEXT:   %[[IDENTITY_0:.*]] = sdy.identity %iterArg_1
-  // CHECK-NEXT:   %[[IDENTITY_1:.*]] = sdy.identity %iterArg_2
-  // CHECK-NEXT:   mhlo.return %[[ADD_1]], %[[ADD_0]], %[[IDENTITY_0]], %[[IDENTITY_1]]
+  // CHECK-NEXT:   %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]]
+  // CHECK-NEXT:   %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]]
+  // CHECK-NEXT:   mhlo.return %[[ADD_1]], %[[ADD_0]]
   // CHECK-NEXT: }
   // CHECK-NEXT: return %[[WHILE]]#0
   %0 = mhlo.constant dense<0> : tensor<i32>
   %1 = mhlo.constant dense<1> : tensor<i32>
-  %2 = mhlo.constant dense<32> : tensor<i32>
+  %2 = mhlo.constant {mhlo.sharding = "{replicated}"} dense<32> : tensor<i32>
   %3:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32>
     cond {
     %4 = mhlo.compare LT, %iterArg_0, %2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
     mhlo.return %4 : tensor<i1>
   } do {
     %4 = mhlo.add %iterArg_0, %1 : tensor<i32>
-    %5 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32>
+    %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32>
     mhlo.return %5, %4 : tensor<32x96xf32>, tensor<i32>
   }
   return %3#0 : tensor<32x96xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @while_with_sinked_constants
 func.func @while_with_sinked_constants(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> {
   // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0>
diff --git a/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir b/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir
new file mode 100644
index 0000000..b87048e
--- /dev/null
+++ b/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir
@@ -0,0 +1,93 @@
+// RUN: sdy_opt %s -xla-sdy-open-while-free-vars-sharding 2>&1 | FileCheck %s
+
+sdy.mesh @mesh1 = <"a"=2>
+sdy.mesh @mesh2 = <"b"=2>
+
+// CHECK-LABEL: func @while_with_free_variables
+func.func @while_with_free_variables(
+    %arg0: tensor<32x96xf32>,
+    %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>},
+    %arg2: tensor<32x96xf32>)
+    -> (tensor<32x96xf32>, tensor<32x96xf32>) {
+  // CHECK-NEXT: %[[C0:.*]] = mhlo.constant dense<0>
+  // CHECK-NEXT: %[[C1:.*]] = mhlo.constant dense<1>
+  // CHECK-NEXT: %[[C32:.*]] = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32>
+  // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>}
+  // CHECK-NEXT: %[[SC_0:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]>
+  // CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %[[ADD_0]] <@mesh2, [{?}, {?}]>
+  // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]])
+  // CHECK-NEXT:   cond {
+  // CHECK-NEXT:   %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]]
+  // CHECK-NEXT:   mhlo.return %[[COND]]
+  // CHECK-NEXT: } do {
+  // CHECK-NEXT:   %[[ADD_1:.*]] = mhlo.add %iterArg_0, %[[C1]]
+  // CHECK-NEXT:   %[[ADD_2:.*]] = mhlo.add %iterArg, %[[SC_0]]
+  // CHECK-NEXT:   %[[ADD_3:.*]] = mhlo.add %[[ADD_2]], %arg2
+  // CHECK-NEXT:   %[[ADD_4:.*]] = mhlo.add %[[ADD_3]], %[[SC_1]]
+  // CHECK-NEXT:   mhlo.return %[[ADD_4]], %[[ADD_1]]
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return %[[ADD_0]], %[[WHILE]]#0
+  %0 = mhlo.constant dense<0> : tensor<i32>
+  %1 = mhlo.constant dense<1> : tensor<i32>
+  %2 = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor<i32>
+  %3 = mhlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} : tensor<32x96xf32>
+  %4:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32>
+    cond {
+    %5 = mhlo.compare LT, %iterArg_0, %2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    mhlo.return %5 : tensor<i1>
+  } do {
+    %5 = mhlo.add %iterArg_0, %1 : tensor<i32>
+    %6 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32>
+    %7 = mhlo.add %6, %arg2 : tensor<32x96xf32>
+    %8 = mhlo.add %7, %3 : tensor<32x96xf32>
+    mhlo.return %8, %5 : tensor<32x96xf32>, tensor<i32>
+  }
+  return %3, %4#0 : tensor<32x96xf32>, tensor<32x96xf32>
+}
+
+// CHECK-LABEL: func @free_var_used_in_multiple_while_ops
+func.func @free_var_used_in_multiple_while_ops(
+    %arg0: tensor<32x96xf32>,
+    %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>})
+    -> tensor<32x96xf32> {
+  // CHECK-NEXT: %[[C0:.*]] = mhlo.constant dense<0>
+  // CHECK-NEXT: %[[C32:.*]] = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32>
+  // CHECK-NEXT: %[[SC_0:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]>
+  // CHECK-NEXT: %[[WHILE_0:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]])
+  // CHECK-NEXT:   cond {
+  // CHECK-NEXT:   %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]]
+  // CHECK-NEXT:   mhlo.return %[[COND]]
+  // CHECK-NEXT: } do {
+  // CHECK-NEXT:   %[[ADD_0:.*]] = mhlo.add %iterArg, %[[SC_0]]
+  // CHECK-NEXT:   mhlo.return %[[ADD_0]], %iterArg_0
+  // CHECK-NEXT: }
+  // CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]>
+  // CHECK-NEXT: %[[WHILE_1:.*]]:2 = mhlo.while(%iterArg = %[[WHILE_0]]#0, %iterArg_0 = %[[C0]])
+  // CHECK-NEXT:   cond {
+  // CHECK-NEXT:   %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]]
+  // CHECK-NEXT:   mhlo.return %[[COND]]
+  // CHECK-NEXT: } do {
+  // CHECK-NEXT:   %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC_1]]
+  // CHECK-NEXT:   mhlo.return %[[ADD_1]], %iterArg_0
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return %[[WHILE_1]]#0
+  %0 = mhlo.constant dense<0> : tensor<i32>
+  %1 = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor<i32>
+  %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32>
+    cond {
+    %4 = mhlo.compare LT, %iterArg_0, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    mhlo.return %4 : tensor<i1>
+  } do {
+    %4 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32>
+    mhlo.return %4, %iterArg_0 : tensor<32x96xf32>, tensor<i32>
+  }
+  %3:2 = mhlo.while(%iterArg = %2#0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32>
+    cond {
+    %4 = mhlo.compare LT, %iterArg_0, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    mhlo.return %4 : tensor<i1>
+  } do {
+    %4 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32>
+    mhlo.return %4, %iterArg_0 : tensor<32x96xf32>, tensor<i32>
+  }
+  return %3#0 : tensor<32x96xf32>
+}
diff --git a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir
index dcd81b2..66b227d 100644
--- a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir
+++ b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir
@@ -13,6 +13,7 @@
 // CHECK: sdy.mesh @mesh = <"a"=2, "b"=2, "c"=2>
 sdy.mesh @mesh = <"a"=2, "b"=2, "c"=2>
 
+// CHECK-LABEL: func @main
 func.func @main(
   // CHECK: %arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>})
   %arg0: tensor<8x16xf32>           {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>}
@@ -35,6 +36,7 @@
 // CHECK: sdy.mesh @mesh = <"a"=2, "b"=2>
 sdy.mesh @mesh = <"a"=2, "b"=2>
 
+// CHECK-LABEL: func @main
 func.func @main(
   // CHECK: %arg0: tensor<8x16xf32>)
   %arg0: tensor<8x16xf32>
@@ -55,6 +57,7 @@
 // CHECK: sdy.mesh @mesh = <"a"=2, "b"=2>
 sdy.mesh @mesh = <"a"=2, "b"=2>
 
+// CHECK-LABEL: func @main
 func.func @main(
   // CHECK: %arg0: tensor<8x16xf32>)
   %arg0: tensor<8x16xf32>
@@ -78,6 +81,7 @@
 // CHECK: sdy.mesh @mesh = <"a"=2, "b"=2>
 sdy.mesh @mesh = <"a"=2, "b"=2>
 
+// CHECK-LABEL: func @main
 func.func @main(
   // CHECK: %arg0: tensor<8x16xf32>)
   %arg0: tensor<8x16xf32>
@@ -97,7 +101,7 @@
 // CHECK: sdy.mesh @mesh = <"a"=2, "b"=2, "c"=2>
 sdy.mesh @mesh = <"a"=2, "b"=2, "c"=2>
 
-// CHECK:      @main(
+// CHECK-LABEL:      @main(
 // CHECK-SAME:   %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}p4]>},
 // CHECK-SAME:   %arg1: tensor<8x8xf32>, %arg2: tensor<8x8xf32>
 // CHECK-SAME:   ) -> tensor<8x8xf32> {
@@ -122,6 +126,7 @@
 // CHECK: sdy.mesh @mesh = <"data"=2>
 sdy.mesh @mesh = <"data"=2>
 
+// CHECK-LABEL: func @main
 func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
   // CHECK:  sdy.sharding_constraint %arg0 <@mesh, [{"data", ?}, {?}]> :  tensor<8x8xf32>
   %0 = sdy.sharding_constraint %arg0       <@mesh, [{"data", ?}, {?}]> :  tensor<8x8xf32>
@@ -144,6 +149,7 @@
 // CHECK: sdy.mesh @mesh_2 = <"x"=8, "y"=4>
 sdy.mesh @mesh_2 = <"x"=8, "y"=4>
 
+// CHECK-LABEL: func @main
 func.func @main(
   // CHECK: %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) {
   %arg0: tensor<8x16xf32>) ->           (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) {
@@ -157,36 +163,41 @@
 
 // -----
 
+// CHECK: sdy.mesh @mesh = <"x"=2>
+sdy.mesh @mesh = <"x"=2>
+
 // Test WhileOp with lifted free variables and sinked constants.
 
-// CHECK: func @main
-func.func @main(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> {
+// CHECK-LABEL: func @main
+func.func @main(
+    %arg0: tensor<32x96xf32>,
+    %arg1: tensor<32x96xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}, {}]>"}})
+    -> tensor<32x96xf32> {
   // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0>
   // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32>
-  // CHECK-NEXT: %[[WHILE:.*]]:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]], %iterArg_1 = %[[C32]])
+  // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]>
+  // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]])
   // CHECK-NEXT:   cond {
-  // CHECK-NEXT:   %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %iterArg_1
+  // CHECK-NEXT:   %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]]
   // CHECK-NEXT:   mhlo.return %[[COND]]
   // CHECK-NEXT: } do {
   // CHECK-DAG:    %[[C1:.*]] = sdy.constant dense<1>
   // CHECK-DAG:    %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]]
-  // CHECK-DAG:    %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg
-  // CHECK-DAG:    %[[IDENTITY:.*]] = sdy.identity %iterArg_1
-  // CHECK-NEXT:   mhlo.return %[[ADD_1]], %[[ADD_0]], %[[IDENTITY]]
+  // CHECK-DAG:    %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]]
+  // CHECK-NEXT:   mhlo.return %[[ADD_1]], %[[ADD_0]]
   // CHECK-NEXT: }
   // CHECK-NEXT: return %[[WHILE]]#0
   %0 = sdy.constant dense<0> : tensor<i32>
   %1 = sdy.constant dense<32> : tensor<i32>
-  %2:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0, %iterArg_1 = %1) : tensor<32x96xf32>, tensor<i32>, tensor<i32>
+  %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32>
     cond {
-    %3 = mhlo.compare LT, %iterArg_0, %iterArg_1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = mhlo.compare LT, %iterArg_0, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
     mhlo.return %3 : tensor<i1>
   } do {
     %3 = sdy.constant dense<1> : tensor<i32>
     %4 = mhlo.add %iterArg_0, %3 : tensor<i32>
-    %5 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32>
-    %6 = sdy.identity %iterArg_1 : tensor<i32>
-    mhlo.return %5, %4, %6 : tensor<32x96xf32>, tensor<i32>, tensor<i32>
+    %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32>
+    mhlo.return %5, %4 : tensor<32x96xf32>, tensor<i32>
   }
   return %2#0 : tensor<32x96xf32>
 }
diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir
index b180caa..89b8722 100644
--- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir
+++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir
@@ -89,11 +89,10 @@
   return %0 : tensor<8x8xf32>
 }
 
-// CHECK-LABEL: func @identity_and_constant
-func.func @identity_and_constant() -> tensor<i32> {
+// CHECK-LABEL: func @constant
+func.func @constant() -> tensor<i32> {
   // CHECK-NEXT: %[[CONST:.*]] = mhlo.constant dense<0>
   // CHECK-NEXT: return %[[CONST]]
   %0 = sdy.constant dense<0> : tensor<i32>
-  %1 = sdy.identity %0 : tensor<i32>
-  return %1 : tensor<i32>
+  return %0 : tensor<i32>
 }
diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir
index 2354d83..e782de0 100644
--- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir
+++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir
@@ -30,20 +30,22 @@
   }
 
   // CHECK-LABEL: func @while_with_free_variables
-  func.func @while_with_free_variables(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> {
+  func.func @while_with_free_variables(
+      %arg0: tensor<32x96xf32>,
+      %arg1: tensor<32x96xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}, {}]>"}})
+      -> tensor<32x96xf32> {
     // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0>
     // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1>
     // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32>
-    // CHECK-NEXT: %[[WHILE:.*]]:4 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]], %iterArg_1 = %[[C32]], %iterArg_2 = %[[C1]])
+    // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]>
+    // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]])
     // CHECK-NEXT:   cond {
-    // CHECK-NEXT:   %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %iterArg_1
+    // CHECK-NEXT:   %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]]
     // CHECK-NEXT:   mhlo.return %[[COND]]
     // CHECK-NEXT: } do {
-    // CHECK-NEXT:   %[[ADD_0:.*]] = mhlo.add %iterArg_0, %iterArg_2
-    // CHECK-NEXT:   %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg
-    // CHECK-NEXT:   %[[IDENTITY_0:.*]] = sdy.identity %iterArg_1
-    // CHECK-NEXT:   %[[IDENTITY_1:.*]] = sdy.identity %iterArg_2
-    // CHECK-NEXT:   mhlo.return %[[ADD_1]], %[[ADD_0]], %[[IDENTITY_0]], %[[IDENTITY_1]]
+    // CHECK-NEXT:   %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]]
+    // CHECK-NEXT:   %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]]
+    // CHECK-NEXT:   mhlo.return %[[ADD_1]], %[[ADD_0]]
     // CHECK-NEXT: }
     // CHECK-NEXT: return %[[WHILE]]#0
     %0 = mhlo.constant dense<0> : tensor<i32>
@@ -55,7 +57,7 @@
       mhlo.return %4 : tensor<i1>
     } do {
       %4 = mhlo.add %iterArg_0, %1 : tensor<i32>
-      %5 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32>
+      %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32>
       mhlo.return %5, %4 : tensor<32x96xf32>, tensor<i32>
     }
     return %3#0 : tensor<32x96xf32>
diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc
index c3fc8b1..c02493e 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc
@@ -571,8 +571,11 @@
              "not able to go from sharding "
           << sharding().ToString(/*include_metadata=*/true) << " to "
           << target.ToString(/*include_metadata=*/true)
-          << " without doing a full rematerialization of the tensor. You "
-             "probably want to enrich the sharding annotations to prevent "
+          << " without doing a full rematerialization of the tensor for HLO "
+             "operation: "
+          << hlo_->ToString()
+          << ". You probably want to enrich the sharding annotations to "
+             "prevent this from happening.";
     }
     return Replicate().Reshard(target);
@@ -3316,6 +3319,7 @@
     auto param = true_b.AddInstruction(HloInstruction::CreateParameter(
         /*parameter_number=*/0, operand_shape, "true_branch_param"));
     std::vector<HloInstruction*> new_operands;
+    new_operands.reserve(operands.size());
     for (int64_t i = 0; i < operands.size(); ++i) {
       new_operands.push_back(true_b.AddInstruction(
           HloInstruction::CreateGetTupleElement(*operand_shapes[i], param, i)));
@@ -4040,20 +4044,21 @@
   const HloSharding& sharding = hlo->sharding();
 
   // Shardings for the body parameter, body root, and cond parameter must be
-  // the same, and the condition root must be replicated so that all partitions
-  // follow the same control flow.
+  // the same.
   hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding);
   hlo->while_body()->parameter_instruction(0)->set_sharding(sharding);
-  const HloSharding& cond_root_sharding =
-      hlo->while_condition()->root_instruction()->sharding();
-  TF_RETURN_IF_ERROR(partitioner_
-                         ->PartitionComputation(hlo->while_condition(),
-                                                cond_root_sharding.IsManual()
-                                                    ? cond_root_sharding
-                                                    : HloSharding::Replicate(),
-                                                next_channel_id_, logger_,
-                                                call_graph_)
-                         .status());
+
+  // The condition root must be replicated so that all partitions follow the
+  // same control flow.
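+  // Using ReplicateAllDataDims (rather than HloSharding::Replicate()) keeps any
+  // manual subgroup dims in the condition root's sharding, so partially manual
+  // while conditions are supported (see the WhilePartialManual test).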
+  HloInstruction* cond_root = hlo->while_condition()->root_instruction();
+  const HloSharding cond_root_sharding =
+      hlo_sharding_util::ReplicateAllDataDims(cond_root->sharding());
+  cond_root->set_sharding(cond_root_sharding);
+  TF_RETURN_IF_ERROR(
+      partitioner_
+          ->PartitionComputation(hlo->while_condition(), cond_root_sharding,
+                                 next_channel_id_, logger_, call_graph_)
+          .status());
   TF_RETURN_IF_ERROR(partitioner_
                          ->PartitionComputation(hlo->while_body(), sharding,
                                                 next_channel_id_, logger_,
@@ -4129,6 +4134,7 @@
   if (hlo->sharding().IsManual()) {
     auto clone_from_original = [&](const HloSharding& shared_sharding) {
       std::vector<HloInstruction*> new_operands;
+      new_operands.reserve(hlo->operand_count());
       for (int64_t i = 0; i < hlo->operand_count(); ++i) {
         new_operands.push_back(
             GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo());
@@ -4310,6 +4316,7 @@
   }
   auto clone_from_original = [&](const HloSharding& shared_sharding) {
     std::vector<HloInstruction*> new_operands;
+    new_operands.reserve(hlo->operand_count());
     for (int64_t i = 0; i < hlo->operand_count(); ++i) {
       new_operands.push_back(
           GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo());
@@ -4340,6 +4347,7 @@
   TF_RET_CHECK(!hlo->sharding().IsTileMaximal());
   // Replicate the operands and run partitioned Rng on all devices.
   std::vector<HloInstruction*> new_operands;
+  new_operands.reserve(hlo->operand_count());
   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
     new_operands.push_back(GetPartitionedHlo(hlo->operand(i))
                                .Reshard(HloSharding::Replicate())
@@ -4659,6 +4667,7 @@
 
 absl::Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) {
   std::vector<HloInstruction*> new_operands;
+  new_operands.reserve(hlo->operand_count());
   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
     new_operands.push_back(
         GetPartitionedHlo(hlo->operand(i))
diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc
index e24e32f..0d1d4a7 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc
@@ -4474,6 +4474,36 @@
   EXPECT_THAT(root, AllOf(op::While(zero), op::Shape("s32[]")));
 }
 
+TEST_P(SpmdPartitioningTest, WhilePartialManual) {
+  absl::string_view hlo_string = R"(
+HloModule module
+
+LoopCond {
+  x = s32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+  const = s32[] constant(5), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+  ROOT lt = pred[] compare(x, const), direction=LT, sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+}
+
+Inc {
+  x = s32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+  const = s32[] constant(1), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+  ROOT add = s32[] add(x, const), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+}
+
+ENTRY entry {
+  zero = s32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+  ROOT while = s32[] while(zero), body=Inc, condition=LoopCond, sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+  VLOG(1) << module->ToString();
+
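+  // Partitioning a while op whose shardings are partially manual (manual +
+  // replicated last tile dims) should succeed and keep the scalar shape.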
+  auto zero = AllOf(op::Parameter(0), op::Shape("s32[]"));
+  const auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::While(zero), op::Shape("s32[]")));
+}
+
 TEST_P(SpmdPartitioningTest, TestWhileFrontendAttributes) {
   absl::string_view hlo_string = R"(
 HloModule module
@@ -9163,6 +9193,29 @@
   EXPECT_THAT(root, op::AllReduce(dot));
 }
 
+TEST_P(SpmdPartitioningTest, ReplicateLHSofConv) {
+  const char* const hlo_string = R"(
+HloModule module
+ENTRY main {
+  lhs = bf16[128,8,8,1280] parameter(0), sharding={devices=[128,1,1,1]<=[128]}
+  rhs = bf16[3,3,1280,1280] parameter(1), sharding={devices=[1,1,1,8,16]<=[16,8]T(1,0) last_tile_dim_replicate}
+  ROOT conv = bf16[128,8,8,1280] convolution(lhs, rhs), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, sharding={devices=[1,1,1,8,16]<=[16,8]T(1,0) last_tile_dim_replicate}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module, PartitionComputation(hlo_string, /*num_devices=*/128));
+  VLOG(1) << module->ToString();
+
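+  // The batch-sharded LHS is rematerialized to its full shape (broadcast +
+  // dynamic-update-slice + all-reduce), while the RHS stays sharded on its
+  // output feature dim, so the convolution is partitioned along that dim.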
+  const auto lhs = AllOf(op::Shape("bf16[128,8,8,1280]"),
+                         op::AllReduce(op::DynamicUpdateSlice(
+                             op::Broadcast(), op::Parameter(0), _, _, _, _)));
+  const auto rhs = AllOf(op::Shape("bf16[3,3,1280,160]"), op::Parameter(1));
+  const auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root,
+              AllOf(op::Shape("bf16[128,8,8,160]"), op::Convolution(lhs, rhs)));
+}
+
 TEST_P(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_TileToReplicate) {
   absl::string_view hlo_string = R"(
 HloModule module
@@ -10911,6 +10964,53 @@
                 _));
 }
 
+TEST_P(SpmdPartitioningTest, ScatterRepsOnLastTileDimDontDivideGroups) {
+  absl::string_view hlo_string = R"(
+HloModule module
+
+region.1 {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT res.1 = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+  %add.1 = f32[8,96,2048,16]{3,2,1,0} parameter(0)
+  %concatenate.1 = s32[8,96,2048,2,4]{4,3,2,1,0} parameter(1)
+  %broadcast.1 = f32[8,96,2048,2]{3,2,1,0} parameter(2)
+
+  %add.1.shard = f32[8,96,2048,16]{3,2,1,0} copy(%add.1), sharding={devices=[8,8,1,1,24]<=[8,8,24]T(1,0,2) last_tile_dim_replicate}
+  %concatenate.1.shard = s32[8,96,2048,2,4]{4,3,2,1,0} copy(%concatenate.1), sharding={devices=[8,8,1,1,1,24]<=[8,8,24]T(1,0,2) last_tile_dim_replicate}
+  %broadcast.1.shard = f32[8,96,2048,2]{3,2,1,0} copy(%broadcast.1), sharding={devices=[8,8,1,1,24]<=[8,8,24]T(1,0,2) last_tile_dim_replicate}
+
+  ROOT %scatter.44 = f32[8,96,2048,16]{3,2,1,0} scatter(
+    %add.1.shard,
+    %concatenate.1.shard,
+    %broadcast.1.shard),
+    update_window_dims={},
+    inserted_window_dims={0,1,2,3},
+    scatter_dims_to_operand_dims={0,1,2,3},
+    index_vector_dim=4,
+    to_apply=region.1,
+    sharding={devices=[8,8,1,1,24]<=[8,8,24]T(1,0,2) last_tile_dim_replicate}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module, PartitionComputation(hlo_string, /*num_devices=*/1536));
+  VLOG(1) << module->ToString();
+  // Verify scatter is partitioned properly.
+  {
+    const auto partitioned_scatter =
+        module->entry_computation()->root_instruction();
+    auto operand = AllOf(op::Shape("f32[1,12,2048,16]"));
+    auto indices = AllOf(op::Shape("s32[8,96,2048,2,4]"));
+    auto update = AllOf(op::Shape("f32[8,96,2048,2]"));
+    auto scatter = AllOf(op::Shape("f32[1,12,2048,16]"),
+                         op::Scatter(operand, indices, update));
+    EXPECT_THAT(partitioned_scatter, scatter);
+  }
+}
+
 TEST_P(SpmdPartitioningTest, ParallelDimFromOutsideConditionalPositive) {
   absl::string_view hlo_string = R"(
 HloModule module
diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_util.h b/third_party/xla/xla/service/spmd/spmd_partitioner_util.h
index 65b5d01..a982c3e 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner_util.h
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner_util.h
@@ -84,6 +84,7 @@
                                                               PrimitiveType)) {
   if (shape.IsTuple()) {
     std::vector<HloInstruction*> elements;
+    elements.reserve(ShapeUtil::TupleElementCount(shape));
     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
       elements.push_back(
           CreateConstantBase(ShapeUtil::GetTupleElementShape(shape, i),
diff --git a/third_party/xla/xla/service/stable_sort_expander.cc b/third_party/xla/xla/service/stable_sort_expander.cc
index 910ab5d..ca87dce 100644
--- a/third_party/xla/xla/service/stable_sort_expander.cc
+++ b/third_party/xla/xla/service/stable_sort_expander.cc
@@ -55,7 +55,6 @@
   HloComputation* computation = sort->parent();
 
   HloInstruction* expanded_sort = nullptr;
-  absl::flat_hash_set<int64_t> used_indices;
   int64_t iota_index = IotaOperandIndexForStableSort(*sort);
 
   // If there is currently no iota operand which we could use for making the
diff --git a/third_party/xla/xla/service/stream_pool_test.cc b/third_party/xla/xla/service/stream_pool_test.cc
index fd0a05e..2bea411 100644
--- a/third_party/xla/xla/service/stream_pool_test.cc
+++ b/third_party/xla/xla/service/stream_pool_test.cc
@@ -26,22 +26,21 @@
 
 class StreamPoolTest : public ::testing::Test {
  protected:
-  std::unique_ptr<se::StreamExecutor> NewStreamExecutor() {
+  se::StreamExecutor* NewStreamExecutor() {
     se::Platform* platform =
         se::PlatformManager::PlatformWithName("Host").value();
-    se::StreamExecutorConfig config(/*ordinal=*/0);
-    return platform->GetUncachedExecutor(config).value();
+    return platform->ExecutorForDevice(/*ordinal=*/0).value();
   }
 };
 
 TEST_F(StreamPoolTest, EmptyPool) {
-  std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
-  StreamPool pool(executor.get());
+  se::StreamExecutor* executor = NewStreamExecutor();
+  StreamPool pool(executor);
 }
 
 TEST_F(StreamPoolTest, OneStreamPool) {
-  std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
-  StreamPool pool(executor.get());
+  se::StreamExecutor* executor = NewStreamExecutor();
+  StreamPool pool(executor);
 
   // Borrow and return a stream.
   StreamPool::Ptr stream1 = pool.BorrowStream();
@@ -61,8 +60,8 @@
 }
 
 TEST_F(StreamPoolTest, TwoStreamPool) {
-  std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
-  StreamPool pool(executor.get());
+  se::StreamExecutor* executor = NewStreamExecutor();
+  StreamPool pool(executor);
 
   // Borrow two streams.
   StreamPool::Ptr stream1 = pool.BorrowStream();
diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.cc b/third_party/xla/xla/service/while_loop_fusible_sinking.cc
index d1fd7ac..07b49db 100644
--- a/third_party/xla/xla/service/while_loop_fusible_sinking.cc
+++ b/third_party/xla/xla/service/while_loop_fusible_sinking.cc
@@ -136,10 +136,6 @@
   }
 
   bool changed = false;
-
-  absl::flat_hash_map<int64_t, absl::InlinedVector<HloInstruction*, 1>>
-      conditional_gte_index_to_insts =
-          WhileUtil::GetGTEsMapForWhileConditional(*while_cond);
   std::vector<HloInstruction*> invariant_body_gtes =
       WhileUtil::GetInvariantGTEsForWhileBody(*while_body);
   std::vector<int64_t> tuple_indices;
diff --git a/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt b/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt
index 1901cf2e..0cae8b4 100644
--- a/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt
+++ b/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt
@@ -15,7 +15,7 @@
 version: 3
 results {
   device: "CUDA: 6.0, Cores: 56, GPU clock: 1.4805 GHz, Memory bandwidth: 732 GB/s, L2 cache: 4 MB"
-  hlo: "(f32[3,3]{1,0}, s8[72]{0}) custom-call(f32[3,3]{1,0}, f32[3,3]{1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"9\",\"rhs_stride\":\"9\",\"grad_x\":false,\"grad_y\":false},\"force_earliest_schedule\":false}"
+  hlo: "(f32[3,3]{1,0}, s8[72]{0}) custom-call(f32[3,3]{1,0}, f32[3,3]{1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[],\"lhs_contracting_dimensions\":[\"1\"],\"rhs_batch_dimensions\":[],\"rhs_contracting_dimensions\":[\"0\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"9\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"9\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
   result {
     gemm {
       algorithm: 13
@@ -24,7 +24,7 @@
 }
 results {
   device: "CUDA: 6.0, Cores: 56, GPU clock: 1.4805 GHz, Memory bandwidth: 732 GB/s, L2 cache: 4 MB"
-  hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,2,4,4]{3,2,1,0}, f32[1,2,3,2]{3,2,1,0}), window={size=3x2}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convForward\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"cudnn_conv_backend_config\":{\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"side_input_scale\":0,\"leakyrelu_alpha\":0},\"force_earliest_schedule\":false}"
+  hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,2,4,4]{3,2,1,0}, f32[1,2,3,2]{3,2,1,0}), window={size=3x2}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convForward\", backend_config={\"cudnn_conv_backend_config\":{\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"leakyrelu_alpha\":0,\"side_input_scale\":0},\"force_earliest_schedule\":false,\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}"
   result {
     run_time {
       nanos: 8192
diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc
index 0deca08..56dd78d 100644
--- a/third_party/xla/xla/shape_util.cc
+++ b/third_party/xla/xla/shape_util.cc
@@ -16,7 +16,6 @@
 #include "xla/shape_util.h"
 
 #include <algorithm>
-#include <array>
 #include <climits>
 #include <cstddef>
 #include <cstdint>
@@ -1982,8 +1981,9 @@
 // Returns the indices of the first elements of all consecutive subarrays of the
 // given array. For example:
 // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
-static std::vector<size_t> ConsecutiveSegments(absl::Span<const int64_t> xs) {
-  std::vector<size_t> is = {0};
+static absl::InlinedVector<size_t, 3> ConsecutiveSegments(
+    absl::Span<const int64_t> xs) {
+  absl::InlinedVector<size_t, 3> is = {0};
   for (size_t i = 1; i < xs.size(); ++i) {
     if (1 != xs[i] - xs[i - 1]) {
       is.push_back(i);
@@ -2010,19 +2010,22 @@
                                                   dimensions);
 }
 
-static std::vector<int64_t> MajorToMinorLayout(const Shape& s) {
+static absl::InlinedVector<int64_t, 3> MajorToMinorLayout(const Shape& s) {
   absl::Span<const int64_t> minor_to_major = LayoutUtil::MinorToMajor(s);
-  return std::vector<int64_t>{minor_to_major.rbegin(), minor_to_major.rend()};
+  return absl::InlinedVector<int64_t, 3>{minor_to_major.rbegin(),
+                                         minor_to_major.rend()};
 }
 
-static std::optional<Vector3> GetNormalizedTransposeShapeHelper(
+static std::optional<absl::InlinedVector<int64_t, 3>>
+GetNormalizedTransposeShapeHelper(
     const Shape& input_shape, absl::Span<int64_t const> output_to_input,
-    const Vector3& permutation) {
+    const absl::InlinedVector<int64_t, 3>& permutation) {
   // 'permutation' should not be the identity permutation.
   if (permutation[0] == 0 && permutation[1] == 1 && permutation[2] == 2) {
     return std::nullopt;
   }
-  std::vector<size_t> segments = ConsecutiveSegments(output_to_input);
+  absl::InlinedVector<size_t, 3> segments =
+      ConsecutiveSegments(output_to_input);
   if (segments.size() > 3) {
     return std::nullopt;
   }
@@ -2031,8 +2034,9 @@
       ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
           input_shape);
   Shape normalized_shape = MergeDimensions(segments, normalized_input_shape);
-  std::vector<int64_t> normalized_dims{normalized_shape.dimensions().begin(),
-                                       normalized_shape.dimensions().end()};
+  absl::InlinedVector<int64_t, 3> normalized_dims{
+      normalized_shape.dimensions().begin(),
+      normalized_shape.dimensions().end()};
   if (segments.size() == 2) {
     // If we have two segments, we know that at least one transpose is
     // happening, otherwise we would have only 1 segment.
@@ -2050,9 +2054,9 @@
     normalized_dims.insert(normalized_dims.begin() + untransposed, 1);
   } else if (segments.size() == 3) {
     // Derive the order from the segments.
-    Vector3 segment_order{output_to_input[segments[0]],
-                          output_to_input[segments[1]],
-                          output_to_input[segments[2]]};
+    absl::InlinedVector<int64_t, 3> segment_order{output_to_input[segments[0]],
+                                                  output_to_input[segments[1]],
+                                                  output_to_input[segments[2]]};
     // We expect the same relative order.
     for (int64_t i = 1; i < 3; ++i) {
       if ((segment_order[i] > segment_order[i - 1]) !=
@@ -2062,31 +2066,32 @@
     }
   }
   if (normalized_dims.size() == 3) {
-    return Vector3{normalized_dims[permutation[0]],
-                   normalized_dims[permutation[1]],
-                   normalized_dims[permutation[2]]};
+    return absl::InlinedVector<int64_t, 3>{normalized_dims[permutation[0]],
+                                           normalized_dims[permutation[1]],
+                                           normalized_dims[permutation[2]]};
   }
   return std::nullopt;
 }
 
-/* static */ std::optional<Vector3>
+/* static */ std::optional<absl::InlinedVector<int64_t, 3>>
 ShapeUtil::GetNormalizedLogicalTransposeShape(
     const Shape& input_shape, const Shape& output_shape,
-    absl::Span<int64_t const> dimensions, const Vector3& permutation) {
+    absl::Span<int64_t const> dimensions,
+    const absl::InlinedVector<int64_t, 3>& permutation) {
   if (!LayoutUtil::IsMonotonicWithDim0Major(input_shape.layout()) ||
       !LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())) {
     // Only works on default layouts.
     return std::nullopt;
   }
   // Drop degenerate dimensions.
-  std::vector<int64_t> delta(input_shape.rank() + 1, 0);
+  absl::InlinedVector<int64_t, 3> delta(input_shape.rank() + 1, 0);
   for (int i = 0; i < input_shape.rank(); ++i) {
     delta[i + 1] = delta[i];
     if (input_shape.dimensions(i) == static_cast<int64_t>(1)) {
       ++delta[i + 1];
     }
   }
-  std::vector<int64_t> new_dimensions;
+  absl::InlinedVector<int64_t, 3> new_dimensions;
   for (int i = 0; i < dimensions.size(); i++) {
     if (output_shape.dimensions(i) != 1) {
       new_dimensions.push_back(dimensions[i] - delta[dimensions[i]]);
@@ -2098,15 +2103,18 @@
       permutation);
 }
 
-/* static */ std::optional<Vector3> ShapeUtil::GetNormalizedTransposeShape(
+/* static */ std::optional<absl::InlinedVector<int64_t, 3>>
+ShapeUtil::GetNormalizedTransposeShape(
     const Shape& input_shape, const Shape& output_shape,
-    const Vector3& permutation) {
+    const absl::InlinedVector<int64_t, 3>& permutation) {
   if (!ShapeUtil::CompatibleIgnoringElementType(input_shape, output_shape)) {
     return std::nullopt;
   }
 
-  std::vector<int64_t> major_to_minor_input = MajorToMinorLayout(input_shape);
-  std::vector<int64_t> major_to_minor_output = MajorToMinorLayout(output_shape);
+  absl::InlinedVector<int64_t, 3> major_to_minor_input =
+      MajorToMinorLayout(input_shape);
+  absl::InlinedVector<int64_t, 3> major_to_minor_output =
+      MajorToMinorLayout(output_shape);
   std::vector<int64_t> output_to_input = ComposePermutations(
       InversePermutation(major_to_minor_output), major_to_minor_input);
 
diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h
index a773ade..03a809d 100644
--- a/third_party/xla/xla/shape_util.h
+++ b/third_party/xla/xla/shape_util.h
@@ -44,7 +44,6 @@
 #include "xla/primitive_util.h"
 #include "xla/printer.h"
 #include "xla/shape.h"
-#include "xla/util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/logging.h"  // IWYU pragma: keep
@@ -1030,14 +1029,17 @@
   // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the
   // normalized shape of `b` or the 0-2-1 shape. In general, the
   // permutation[0]-permutation[1]-permutation[2] shape is returned.
-  static std::optional<Vector3> GetNormalizedTransposeShape(
+  static std::optional<absl::InlinedVector<int64_t, 3>>
+  GetNormalizedTransposeShape(
       const Shape& input_shape, const Shape& output_shape,
-      const Vector3& permutation);
+      const absl::InlinedVector<int64_t, 3>& permutation);
 
   // Entry point for physical + logical transposition.
-  static std::optional<Vector3> GetNormalizedLogicalTransposeShape(
+  static std::optional<absl::InlinedVector<int64_t, 3>>
+  GetNormalizedLogicalTransposeShape(
       const Shape& input_shape, const Shape& output_shape,
-      absl::Span<int64_t const> dimensions, const Vector3& permutation);
+      absl::Span<int64_t const> dimensions,
+      const absl::InlinedVector<int64_t, 3>& permutation);
 
   // Strips device-specific information, namely tiling and memory-space
   // information, from a shape.
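
For reference, a minimal calling sketch (not part of this patch) for the updated signature; the shapes, permutation, and expected result are taken from the Transpose021Test.Simple case in the test diff below:

#include <cstdint>
#include <optional>
#include "absl/container/inlined_vector.h"
#include "xla/shape_util.h"

// Sketch only: GetNormalizedTransposeShape now takes and returns
// absl::InlinedVector<int64_t, 3> instead of Vector3.
void NormalizedTransposeShapeExample() {
  xla::Shape shape =
      xla::ShapeUtil::MakeShapeWithDenseLayout(xla::F32, {128, 64}, {1, 0});
  xla::Shape transposed =
      xla::ShapeUtil::MakeShapeWithDenseLayout(xla::F32, {128, 64}, {0, 1});
  std::optional<absl::InlinedVector<int64_t, 3>> dims =
      xla::ShapeUtil::GetNormalizedTransposeShape(
          shape, transposed, absl::InlinedVector<int64_t, 3>{0, 2, 1});
  // For this 0-2-1 case the test below expects dims to hold {1, 64, 128}.
}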
diff --git a/third_party/xla/xla/shape_util_test.cc b/third_party/xla/xla/shape_util_test.cc
index e7c1beb..a66b148 100644
--- a/third_party/xla/xla/shape_util_test.cc
+++ b/third_party/xla/xla/shape_util_test.cc
@@ -23,6 +23,7 @@
 #include <variant>
 #include <vector>
 
+#include "absl/container/inlined_vector.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
@@ -1395,8 +1396,9 @@
   Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64}, {1, 0});
   Shape transposed =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {64, 128}, {0, 1});
-  EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape(
-                              shape, transposed, Vector3{0, 2, 1}));
+  EXPECT_EQ(std::nullopt,
+            ShapeUtil::GetNormalizedTransposeShape(
+                shape, transposed, absl::InlinedVector<int64_t, 3>{0, 2, 1}));
 }
 
 TEST(Transpose021Test, NoTranspose2) {
@@ -1404,8 +1406,9 @@
       ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64, 32}, {2, 1, 0});
   Shape transposed =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 64, 128}, {0, 1, 2});
-  EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape(
-                              shape, transposed, Vector3{0, 1, 2}));
+  EXPECT_EQ(std::nullopt,
+            ShapeUtil::GetNormalizedTransposeShape(
+                shape, transposed, absl::InlinedVector<int64_t, 3>{0, 1, 2}));
 }
 
 TEST(Transpose021Test, WrongTranspose) {
@@ -1414,7 +1417,8 @@
   Shape output_shape =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {0, 1, 2});
   EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape(
-                              input_shape, output_shape, Vector3{0, 2, 1}));
+                              input_shape, output_shape,
+                              absl::InlinedVector<int64_t, 3>{0, 2, 1}));
 }
 
 TEST(Transpose021Test, WrongTranspose2) {
@@ -1422,7 +1426,8 @@
   Shape output_shape =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {0, 1});
   EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape(
-                              input_shape, output_shape, Vector3{0, 1, 2}));
+                              input_shape, output_shape,
+                              absl::InlinedVector<int64_t, 3>{0, 1, 2}));
 }
 
 TEST(Transpose021Test, WrongTranspose3) {
@@ -1430,16 +1435,17 @@
   Shape output_shape =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {0, 1});
   EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape(
-                              input_shape, output_shape, Vector3{1, 2, 0}));
+                              input_shape, output_shape,
+                              absl::InlinedVector<int64_t, 3>{1, 2, 0}));
 }
 
 TEST(Transpose021Test, Simple) {
   Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64}, {1, 0});
   Shape transposed =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64}, {0, 1});
-  EXPECT_EQ(std::make_optional(Vector3{1, 64, 128}),
-            ShapeUtil::GetNormalizedTransposeShape(shape, transposed,
-                                                   Vector3{0, 2, 1}));
+  EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{1, 64, 128}),
+            ShapeUtil::GetNormalizedTransposeShape(
+                shape, transposed, absl::InlinedVector<int64_t, 3>{0, 2, 1}));
 }
 
 TEST(Transpose021Test, Simple2) {
@@ -1447,9 +1453,10 @@
       ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {2, 1, 0});
   Shape output_shape =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {1, 2, 0});
-  EXPECT_EQ(std::make_optional(Vector3{8, 16, 32768}),
-            ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape,
-                                                   Vector3{0, 2, 1}));
+  EXPECT_EQ(
+      std::make_optional(absl::InlinedVector<int64_t, 3>{8, 16, 32768}),
+      ShapeUtil::GetNormalizedTransposeShape(
+          input_shape, output_shape, absl::InlinedVector<int64_t, 3>{0, 2, 1}));
 }
 
 TEST(Transpose021Test, Simple3) {
@@ -1457,18 +1464,20 @@
       ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {2, 1, 0});
   Shape output_shape =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {0, 1, 2});
-  EXPECT_EQ(std::make_optional(Vector3{16, 32768, 8}),
-            ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape,
-                                                   Vector3{2, 1, 0}));
+  EXPECT_EQ(
+      std::make_optional(absl::InlinedVector<int64_t, 3>{16, 32768, 8}),
+      ShapeUtil::GetNormalizedTransposeShape(
+          input_shape, output_shape, absl::InlinedVector<int64_t, 3>{2, 1, 0}));
 }
 
 TEST(Transpose021Test, Simple4) {
   Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {1, 0});
   Shape output_shape =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {0, 1});
-  EXPECT_EQ(std::make_optional(Vector3{16, 1, 8}),
-            ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape,
-                                                   Vector3{2, 1, 0}));
+  EXPECT_EQ(
+      std::make_optional(absl::InlinedVector<int64_t, 3>{16, 1, 8}),
+      ShapeUtil::GetNormalizedTransposeShape(
+          input_shape, output_shape, absl::InlinedVector<int64_t, 3>{2, 1, 0}));
 }
 
 TEST(Transpose021Test, LargeView) {
@@ -1476,9 +1485,10 @@
       F32, {8, 32, 32, 32, 16}, {4, 3, 2, 1, 0});
   Shape output_shape = ShapeUtil::MakeShapeWithDenseLayout(
       F32, {8, 32, 32, 32, 16}, {3, 2, 1, 4, 0});
-  EXPECT_EQ(std::make_optional(Vector3{8, 16, 32768}),
-            ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape,
-                                                   Vector3{0, 2, 1}));
+  EXPECT_EQ(
+      std::make_optional(absl::InlinedVector<int64_t, 3>{8, 16, 32768}),
+      ShapeUtil::GetNormalizedTransposeShape(
+          input_shape, output_shape, absl::InlinedVector<int64_t, 3>{0, 2, 1}));
 }
 
 TEST(Transpose021Test, LargeSizeOverflowTest) {
@@ -1487,7 +1497,8 @@
   Shape output_shape =
       ShapeUtil::MakeShapeWithDenseLayout(BF16, {4096, 4096, 128}, {2, 1, 0});
   EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape(
-                              input_shape, output_shape, Vector3{0, 2, 1}));
+                              input_shape, output_shape,
+                              absl::InlinedVector<int64_t, 3>{0, 2, 1}));
 }
 
 TEST(Transpose021Test, Batched) {
@@ -1495,9 +1506,9 @@
       ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 3, 64}, {2, 1, 0});
   Shape transposed =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 3, 64}, {1, 0, 2});
-  EXPECT_EQ(std::make_optional(Vector3{1, 64, 96}),
-            ShapeUtil::GetNormalizedTransposeShape(shape, transposed,
-                                                   Vector3{0, 2, 1}));
+  EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{1, 64, 96}),
+            ShapeUtil::GetNormalizedTransposeShape(
+                shape, transposed, absl::InlinedVector<int64_t, 3>{0, 2, 1}));
 }
 
 TEST(Transpose021Test, BatchedLogical) {
@@ -1506,9 +1517,10 @@
   Shape transposed =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {64, 32, 3}, {2, 1, 0});
   std::vector<int64_t> dimensions = {2, 0, 1};
-  EXPECT_EQ(std::make_optional(Vector3{1, 64, 96}),
+  EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{1, 64, 96}),
             ShapeUtil::GetNormalizedLogicalTransposeShape(
-                shape, transposed, dimensions, Vector3{0, 2, 1}));
+                shape, transposed, dimensions,
+                absl::InlinedVector<int64_t, 3>{0, 2, 1}));
 }
 
 TEST(Transpose021Test, LogicalWithDegenerateDims) {
@@ -1517,9 +1529,10 @@
   Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(
       F32, {1, 32, 1, 64, 1, 3, 1}, {6, 5, 4, 3, 2, 1, 0});
   std::vector<int64_t> dimensions = {6, 1, 4, 5, 2, 3, 0};
-  EXPECT_EQ(std::make_optional(Vector3{32, 64, 3}),
+  EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{32, 64, 3}),
             ShapeUtil::GetNormalizedLogicalTransposeShape(
-                shape, transposed, dimensions, Vector3{0, 2, 1}));
+                shape, transposed, dimensions,
+                absl::InlinedVector<int64_t, 3>{0, 2, 1}));
 }
 
 TEST(Transpose021Test, LogicalWithDegenerateLastDim) {
@@ -1528,9 +1541,10 @@
   Shape transposed =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 64, 1}, {2, 1, 0});
   std::vector<int64_t> dimensions = {2, 1, 0};
-  EXPECT_EQ(std::make_optional(Vector3{1, 32, 64}),
+  EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{1, 32, 64}),
             ShapeUtil::GetNormalizedLogicalTransposeShape(
-                shape, transposed, dimensions, Vector3{0, 2, 1}));
+                shape, transposed, dimensions,
+                absl::InlinedVector<int64_t, 3>{0, 2, 1}));
 }
 
 TEST(Transpose021Test, Large) {
@@ -1538,9 +1552,9 @@
       ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 31, 31, 65}, {3, 2, 1, 0});
   Shape transposed =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 31, 31, 65}, {2, 1, 3, 0});
-  EXPECT_EQ(std::make_optional(Vector3{8, 65, 961}),
-            ShapeUtil::GetNormalizedTransposeShape(shape, transposed,
-                                                   Vector3{0, 2, 1}));
+  EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{8, 65, 961}),
+            ShapeUtil::GetNormalizedTransposeShape(
+                shape, transposed, absl::InlinedVector<int64_t, 3>{0, 2, 1}));
 }
 
 TEST(Transpose210Test, LogicalTranspose) {
@@ -1549,9 +1563,10 @@
   Shape transposed =
       ShapeUtil::MakeShapeWithDenseLayout(F32, {13, 12, 10, 11}, {3, 2, 1, 0});
   std::vector<int64_t> dimensions = {3, 2, 0, 1};
-  EXPECT_EQ(std::make_optional(Vector3{13, 12, 110}),
+  EXPECT_EQ(std::make_optional(absl::InlinedVector<int64_t, 3>{13, 12, 110}),
             ShapeUtil::GetNormalizedLogicalTransposeShape(
-                shape, transposed, dimensions, Vector3{2, 1, 0}));
+                shape, transposed, dimensions,
+                absl::InlinedVector<int64_t, 3>{2, 1, 0}));
 }
 
 TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
diff --git a/third_party/xla/xla/sort_json.cc b/third_party/xla/xla/sort_json.cc
new file mode 100644
index 0000000..aaa1e19
--- /dev/null
+++ b/third_party/xla/xla/sort_json.cc
@@ -0,0 +1,257 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/sort_json.h"
+
+#include <algorithm>
+#include <cctype>
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace {
+
+void SkipWhitespace(absl::string_view json, size_t& index) {
+  while (index < json.size() && std::isspace(json[index])) {
+    ++index;
+  }
+}
+
+absl::Status CheckNotEndOfString(absl::string_view json, int index,
+                                 absl::string_view expected) {
+  return index < json.size()
+             ? absl::OkStatus()
+             : absl::InvalidArgumentError(absl::StrCat(
+                   "Prematurely reached end of JSON while looking for ",
+                   expected, "."));
+}
+
+absl::Status Consume(absl::string_view json, size_t& index, char c,
+                     bool optional = false) {
+  SkipWhitespace(json, index);
+  TF_RETURN_IF_ERROR(CheckNotEndOfString(json, index, std::string(1, c)));
+  if (json[index] == c) {
+    ++index;
+    SkipWhitespace(json, index);
+  } else if (!optional) {
+    return absl::InvalidArgumentError(
+        absl::StrCat("Expected '", std::string(1, c), "', but found '",
+                     std::string(1, json[index]), "'."));
+  }
+  return absl::OkStatus();
+}
+
+struct JsonArray;
+struct JsonObject;
+
+using JsonValue = std::variant<absl::string_view, std::unique_ptr<JsonObject>,
+                               std::unique_ptr<JsonArray>>;
+
+struct JsonField {
+  absl::string_view name;
+  JsonValue value;
+};
+
+template <typename T>
+struct JsonSequence {
+  std::vector<T> elements;
+};
+
+struct JsonArray : public JsonSequence<JsonValue> {};
+struct JsonObject : public JsonSequence<JsonField> {};
+
+// This parses either an array or an object.
+template <typename T, char begin, char end, const char* name, typename ElemFn>
+absl::StatusOr<std::unique_ptr<T>> ParseSequence(absl::string_view outer_json,
+                                                 size_t& index,
+                                                 ElemFn elem_fn) {
+  TF_RETURN_IF_ERROR(Consume(outer_json, index, begin));
+  TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, name));
+
+  auto seq = std::make_unique<T>();
+  while (outer_json[index] != end) {
+    TF_ASSIGN_OR_RETURN(auto elem, elem_fn(outer_json, index));
+    seq->elements.emplace_back(std::move(elem));
+    TF_RETURN_IF_ERROR(Consume(outer_json, index, ',', /*optional=*/true));
+    TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, name));
+  }
+  TF_RETURN_IF_ERROR(Consume(outer_json, index, end));
+  return seq;
+}
+
+absl::Status EnsureValidLiteralStart(char c) {
+  if (c != '"' && c != '+' && c != '-' && c != 'f' && c != 't' && c != 'n' &&
+      (c < '0' || c > '9')) {
+    return absl::InvalidArgumentError(absl::StrCat(
+        "Invalid first character of literal: '", std::string(1, c), "'."));
+  }
+  return absl::OkStatus();
+}
+
+bool HandleEscape(absl::string_view outer_json, size_t& index,
+                  bool& is_escaped) {
+  if (is_escaped) {
+    is_escaped = false;
+    ++index;
+    return true;
+  }
+
+  if (outer_json[index] == '\\') {
+    is_escaped = true;
+    ++index;
+    return true;
+  }
+  return false;
+}
+
+bool LiteralIsFinished(absl::string_view outer_json, size_t& index,
+                       bool is_string_literal) {
+  char c = outer_json[index];
+  if (is_string_literal) {
+    index += (c == '"' ? 1 : 0);
+    return c == '"';
+  }
+
+  return std::isspace(c) || c == ',' || c == '{' || c == '}' || c == '[' ||
+         c == ']' || c == ':';
+}
+
+absl::StatusOr<absl::string_view> ParseLiteral(absl::string_view outer_json,
+                                               size_t& index) {
+  SkipWhitespace(outer_json, index);
+  TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, "literal"));
+
+  auto c = outer_json[index];
+  TF_RETURN_IF_ERROR(EnsureValidLiteralStart(c));
+  bool is_string_literal = c == '"';
+  size_t start_index = index;
+  bool is_escaped = false;
+  ++index;
+
+  while (index < outer_json.size()) {
+    if (HandleEscape(outer_json, index, is_escaped)) {
+      continue;
+    }
+    if (LiteralIsFinished(outer_json, index, is_string_literal)) {
+      break;
+    }
+    ++index;
+  }
+  return outer_json.substr(start_index, index - start_index);
+}
+
+absl::StatusOr<JsonField> ParseField(absl::string_view outer_json,
+                                     size_t& index);
+
+absl::StatusOr<JsonValue> ParseValue(absl::string_view outer_json,
+                                     size_t& index) {
+  JsonValue value;
+  SkipWhitespace(outer_json, index);
+  TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, "value"));
+  auto c = outer_json[index];
+  if (c == '{') {
+    constexpr static char kObject[] = "object";
+    auto seq = ParseSequence<JsonObject, '{', '}', kObject>(outer_json, index,
+                                                            ParseField);
+    TF_ASSIGN_OR_RETURN(value, std::move(seq));
+  } else if (c == '[') {
+    constexpr static char kArray[] = "array";
+    auto seq = ParseSequence<JsonArray, '[', ']', kArray>(outer_json, index,
+                                                          ParseValue);
+    TF_ASSIGN_OR_RETURN(value, std::move(seq));
+  } else {
+    TF_ASSIGN_OR_RETURN(value, ParseLiteral(outer_json, index));
+  }
+  return value;
+}
+
+absl::StatusOr<JsonField> ParseField(absl::string_view outer_json,
+                                     size_t& index) {
+  JsonField field;
+  TF_ASSIGN_OR_RETURN(field.name, ParseLiteral(outer_json, index));
+  TF_RETURN_IF_ERROR(Consume(outer_json, index, ':'));
+  TF_ASSIGN_OR_RETURN(field.value, ParseValue(outer_json, index));
+  return field;
+}
+
+template <typename T>
+std::vector<std::string> SerializedElements(const JsonSequence<T>& seq) {
+  std::vector<std::string> result;
+  for (const auto& field : seq.elements) {
+    result.push_back("");
+    Serialize(field, result.back());
+  }
+  return result;
+}
+
+template <typename ElemT, char begin_brace, char end_brace>
+void Serialize(const JsonSequence<ElemT>& object, std::string& result) {
+  auto elems = SerializedElements(object);
+  if constexpr (std::is_same_v<ElemT, JsonField>) {
+    std::sort(elems.begin(), elems.end());
+  }
+
+  result += begin_brace;
+  bool has_preceding = false;
+  for (const auto& elem : elems) {
+    if (has_preceding) {
+      result += ',';
+    }
+    result += elem;
+    has_preceding = true;
+  }
+  result += end_brace;
+}
+
+void Serialize(const JsonValue& value, std::string& result) {
+  if (auto* lit = std::get_if<absl::string_view>(&value)) {
+    absl::StrAppend(&result, *lit);
+  } else if (auto* object = std::get_if<std::unique_ptr<JsonObject>>(&value)) {
+    Serialize<JsonField, '{', '}'>(**object, result);
+  } else if (auto* array = std::get_if<std::unique_ptr<JsonArray>>(&value)) {
+    Serialize<JsonValue, '[', ']'>(**array, result);
+  }
+}
+
+void Serialize(const JsonField& field, std::string& result) {
+  absl::StrAppend(&result, field.name, ":");
+  Serialize(field.value, result);
+}
+
+}  // namespace
+
+namespace xla {
+absl::StatusOr<std::string> SortJson(absl::string_view json) {
+  size_t index = 0;
+  TF_ASSIGN_OR_RETURN(auto value, ParseValue(json, index));
+  SkipWhitespace(json, index);
+  if (index < json.size()) {
+    return absl::InvalidArgumentError("Found trailing characters in JSON.");
+  }
+  std::string result;
+  Serialize(value, result);
+  return result;
+}
+}  // namespace xla
diff --git a/third_party/xla/xla/sort_json.h b/third_party/xla/xla/sort_json.h
new file mode 100644
index 0000000..b4283f5
--- /dev/null
+++ b/third_party/xla/xla/sort_json.h
@@ -0,0 +1,35 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SORT_JSON_H_
+#define XLA_SORT_JSON_H_
+
+#include <string>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+
+namespace xla {
+
+// Sorts the given JSON string or returns an error if the JSON could not be
+// parsed. Note that this function expects the input JSON to be valid; not all
+// forms of invalid JSON are detected and rejected. Whitespace in the input is
+// ignored entirely, and the resulting JSON contains no whitespace. Comments
+// are not supported in the input JSON.
+absl::StatusOr<std::string> SortJson(absl::string_view json);
+
+}  // namespace xla
+
+#endif  // XLA_SORT_JSON_H_
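
As a rough usage sketch (not part of this patch) of the new xla::SortJson helper declared above: object fields are ordered by their serialized text, matching the expectations in the test file below.

#include <iostream>
#include <string>

#include "absl/status/statusor.h"
#include "xla/sort_json.h"

int main() {
  // Fields are sorted recursively; whitespace is dropped from the output.
  absl::StatusOr<std::string> sorted =
      xla::SortJson(R"({"b": 2, "a": {"y": 1, "x": 0}})");
  if (sorted.ok()) {
    std::cout << *sorted << "\n";  // Prints {"a":{"x":0,"y":1},"b":2}
  }
  return 0;
}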
diff --git a/third_party/xla/xla/sort_json_test.cc b/third_party/xla/xla/sort_json_test.cc
new file mode 100644
index 0000000..f4ff0c1
--- /dev/null
+++ b/third_party/xla/xla/sort_json_test.cc
@@ -0,0 +1,51 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/sort_json.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace {
+
+using ::tsl::testing::IsOkAndHolds;
+
+TEST(SortJsonTest, SortsJson) {
+  EXPECT_THAT(SortJson(R"({"a": 1, "c": 3,"b": 2, "b": 1,})"),
+              IsOkAndHolds(R"({"a":1,"b":1,"b":2,"c":3})"));
+
+  EXPECT_THAT(SortJson(R"({"a": 1  , "c": 1,"b": 1  })"),
+              IsOkAndHolds(R"({"a":1,"b":1,"c":1})"));
+
+  EXPECT_THAT(SortJson(R"({"a": 1,"c": 3,"b": 2,"b": [3,2,1],})"),
+              IsOkAndHolds(R"({"a":1,"b":2,"b":[3,2,1],"c":3})"));
+
+  EXPECT_THAT(SortJson(R"({"aa": 1, "a": {"c": "c", "b": "b"}})"),
+              IsOkAndHolds(R"({"a":{"b":"b","c":"c"},"aa":1})"));
+
+  EXPECT_THAT(
+      SortJson(
+          R"({"x": true, "x": false, "x": null, "x": 0, "x": -0.5,"x": "a"})"),
+      IsOkAndHolds(R"({"x":"a","x":-0.5,"x":0,"x":false,"x":null,"x":true})"));
+
+  EXPECT_THAT(SortJson(R"({"a": "a}", "a": "a"})"),
+              IsOkAndHolds(R"({"a":"a","a":"a}"})"));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD
index e7d2bb7..baf2833 100644
--- a/third_party/xla/xla/stream_executor/BUILD
+++ b/third_party/xla/xla/stream_executor/BUILD
@@ -220,7 +220,6 @@
         ":stream",
         ":stream_executor_h",
         "//xla:test",
-        "@com_google_absl//absl/functional:any_invocable",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/types:span",
@@ -380,6 +379,7 @@
         ":device_memory",
         ":numeric_options",
         "//xla/stream_executor/platform",
+        "//xla/tsl/lib/strings:proto_serialization",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:btree",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -391,7 +391,6 @@
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
         "@eigen_archive//:eigen3",  # buildcleaner: keep
-        "@local_tsl//tsl/lib/strings:proto_serialization",
         "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:ml_dtypes",
         "@local_tsl//tsl/platform:status",
@@ -454,7 +453,6 @@
         ":module_spec",
         ":platform",
         ":stream",
-        "@com_google_absl//absl/functional:any_invocable",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
@@ -576,12 +574,13 @@
         ":platform",
         ":stream_executor_h",
         "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:node_hash_map",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
         "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -675,7 +674,6 @@
         ":stream",
         ":stream_executor_h",
         "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/functional:any_invocable",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
@@ -811,6 +809,19 @@
     ],
 )
 
+xla_cc_test(
+    name = "executor_cache_test",
+    srcs = ["executor_cache_test.cc"],
+    deps = [
+        ":executor_cache",
+        ":mock_stream_executor",
+        ":platform",
+        ":stream",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
 #===--------------------------------------------------------------------------------------------===#
 # Aliases for StreamExecutor platforms
 #===--------------------------------------------------------------------------------------------===#
diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD
index 1521b22..644e5ce 100644
--- a/third_party/xla/xla/stream_executor/cuda/BUILD
+++ b/third_party/xla/xla/stream_executor/cuda/BUILD
@@ -10,6 +10,7 @@
 load(
     "@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
     "if_cuda_is_configured",
+    "if_cuda_newer_than",
 )
 load(
     "//xla:xla.bzl",
@@ -225,11 +226,7 @@
     name = "cuda_driver_test",
     srcs = ["cuda_driver_test.cc"],
     backends = ["gpu"],
-    tags = [
-        # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
-        "gpu",
-        "no_rocm",
-    ],
+    tags = ["no_rocm"],
     deps = [
         ":cuda_driver",
         ":cuda_status",
@@ -428,7 +425,6 @@
         "//xla/stream_executor/gpu:gpu_driver_header",
         "//xla/stream_executor/gpu:gpu_executor_header",
         "//xla/stream_executor/gpu:gpu_semaphore",
-        "//xla/stream_executor/gpu:gpu_stream",
         "@com_google_absl//absl/status:statusor",
     ],
 )
@@ -590,6 +586,7 @@
 cc_library(
     name = "ptx_compiler",
     hdrs = ["ptx_compiler.h"],
+    tags = ["no_rocm"],
     deps = select({
         ":libnvptxcompiler_support_enabled": [":ptx_compiler_impl"],
         "//conditions:default": [":ptx_compiler_stub"],
@@ -599,11 +596,30 @@
     ],
 )
 
+xla_test(
+    name = "cuda_platform_test",
+    srcs = ["cuda_platform_test.cc"],
+    backends = ["gpu"],
+    deps = [
+        ":cuda_platform",
+        "//xla/stream_executor:platform",
+        "//xla/stream_executor:platform_manager",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
 xla_cc_test(
     name = "ptx_compiler_test",
     srcs = ["ptx_compiler_test.cc"],
-    # TODO(b/343996893): Figure out whether msan reports a false positive or not.
-    tags = ["nomsan"],
+    tags = [
+        "no_rocm",
+        # TODO(b/343996893): Figure out whether msan reports a false positive or not.
+        "nomsan",
+    ],
     deps = [
         ":ptx_compiler",
         ":ptx_compiler_support",
@@ -629,7 +645,11 @@
         "//conditions:default": [
             "LIBNVJITLINK_SUPPORT=false",
         ],
-    }),
+    }) + if_cuda_newer_than(
+        "12_0",
+        ["CUDA_SUPPORTS_NVJITLINK=true"],
+        ["CUDA_SUPPORTS_NVJITLINK=false"],
+    ),
 )
 
 cc_library(
@@ -672,13 +692,25 @@
     ],
 )
 
+# Since select() can't be nested, we need to wrap the cuda_newer_than check in a separate
+# library target.
+cc_library(
+    name = "nvjitlink_cuda_supported",
+    # Even though the macro is called `*_newer_than`, it does a greater-than-or-equal-to comparison.
+    deps = if_cuda_newer_than(
+        "12_0",
+        [":nvjitlink_impl"],
+        [":nvjitlink_stub"],
+    ),
+)
+
 cc_library(
     name = "nvjitlink",
     hdrs = [
         "nvjitlink.h",
     ],
     deps = select({
-        ":libnvjitlink_support_enabled": [":nvjitlink_impl"],
+        ":libnvjitlink_support_enabled": [":nvjitlink_cuda_supported"],
         "//conditions:default": [":nvjitlink_stub"],
     }) + [
         "//xla/stream_executor/gpu:gpu_asm_opts",
@@ -824,7 +856,6 @@
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:fingerprint",
         "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:platform_port",
         "@local_tsl//tsl/platform:statusor",
     ] + if_cuda_is_configured([":delay_kernel_cuda"]),
     alwayslink = True,
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc
index 01aa153..7f2183f 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc
@@ -20,7 +20,6 @@
 #include <cstdint>
 #include <cstdlib>
 #include <cstring>
-#include <sstream>
 #include <string>
 #include <tuple>
 #include <utility>
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
index a4337df..ce16080 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
@@ -41,6 +41,7 @@
 #include "xla/stream_executor/blas.h"
 #include "xla/stream_executor/cuda/cuda_blas.h"
 #include "xla/stream_executor/cuda/cuda_blas_utils.h"
+#include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/event_based_timer.h"
 #include "xla/stream_executor/gpu/gpu_activation.h"
 #include "xla/stream_executor/gpu/gpu_blas_lt.h"
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h
index 2fae670..3d61c81 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h
@@ -18,6 +18,7 @@
 
 #include <cstddef>
 #include <memory>
+#include <optional>
 #include <type_traits>
 #include <utility>
 #include <vector>
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc
index 46a2757..440f647 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc
@@ -40,7 +40,6 @@
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
 #include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/time/time.h"
@@ -3762,32 +3761,6 @@
 }
 
 #if CUDNN_VERSION >= 8800
-enum CudnnfMHAUid {
-  Q_ID = 400,
-  K_ID,
-  V_ID,
-  P_ID,
-  O_ID,
-  dQ_ID,
-  dK_ID,
-  dV_ID,
-  dP_ID,
-  dO_ID,
-  dS_ID,
-  dBIAS_ID,
-  BIAS_ID,
-  MASK_ID,
-  ZERO_VAL_ID,
-  ONE_VAL_ID,
-  NEG_INFINITY_ID,
-  ALPHA_SCALE_ID,
-  DROPOUT_SCALE_ID,
-  Q_SEQLEN_ID,
-  K_SEQLEN_ID,
-  D_OFFSET_ID,
-  D_SEED_ID,
-  VIRTUAL_ID = 34857
-};
 
 absl::StatusOr<cudnn_frontend::PointWiseDesc> CreatePwDesc(
     dnn::DataType dtype, cudnnPointwiseMode_t mode) {
@@ -3842,49 +3815,6 @@
   RETURN_MSG_IF_CUDNN_ERROR(pw_op_created);
   return pw_op_created;
 }
-
-// Returns a cudnn tensor that's the output of the mask op
-absl::StatusOr<cudnn_frontend::Tensor> CreateCudnnMaskFwdTensor(
-    std::vector<cudnn_frontend::Operation>& ops, absl::Span<const int64_t> dims,
-    absl::Span<const int64_t> strides, dnn::DataType dtype,
-    cudnn_frontend::Tensor& input_tensor) {
-  std::vector<int64_t> mask_dim(dims.size(), 1);
-  std::vector<int64_t> mask_stride(strides.size(), 1);
-
-  // Create the mask tensor
-  TF_ASSIGN_OR_RETURN(
-      auto mask_tensor,
-      CreateCudnnTensor(dims, strides, CudnnfMHAUid::MASK_ID, dtype, 1, -1,
-                        /*is_virtual=*/false));
-  // Create the mask output tensor
-  TF_ASSIGN_OR_RETURN(
-      auto mask_out_tensor,
-      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 400,
-                        dnn::DataType::kFloat, 1, -1,
-                        /*is_virtual=*/true));
-
-  auto mask_desc = cudnn_frontend::PointWiseDescBuilder()
-                       .setMode(CUDNN_POINTWISE_MUL)
-                       .setComputeType(CUDNN_DATA_FLOAT)
-                       .build();
-
-  // Create the mask op.
-  auto mask_op = cudnn_frontend::OperationBuilder(
-                     CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
-                     .setxDesc(input_tensor)
-                     .setbDesc(mask_tensor)
-                     .setyDesc(mask_out_tensor)
-                     .setpwDesc(mask_desc)
-                     .build();
-
-  RETURN_MSG_IF_CUDNN_ERROR(mask_op);
-
-  RETURN_MSG_IF_CUDNN_ERROR(mask_out_tensor);
-  // Add mask to op list
-  ops.push_back(std::move(mask_op));
-
-  return mask_out_tensor;
-}
 #endif  // CUDNN_VERSION >= 8800
 
 absl::StatusOr<std::unique_ptr<cudnn_frontend::OperationGraph>>
@@ -5047,7 +4977,7 @@
     const dnn::FMHAMaskKind mask_type) {
   using cudnn_frontend::graph::Tensor_attributes;
 
-#if CUDNN_VERSION >= 8904
+#if CUDNN_VERSION >= 90000
   if (VLOG_IS_ON(4)) {
     VLOG(4) << "\n bmm1_lhs(q): " << q_descriptor.ToString()
             << "\n bmm1_rhs(k): " << k_descriptor.ToString()
@@ -5075,12 +5005,14 @@
       .set_io_data_type(ioDataType)
       .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT);
 
+  auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); };
+
   std::shared_ptr<Tensor_attributes> q_tensor =
       graph.tensor(Tensor_attributes()
                        .set_name("Q")
                        .set_dim(q_descriptor.GetCudnnCompatibleDimensions(true))
                        .set_stride(q_descriptor.GetCudnnCompatibleStrides(true))
-                       .set_uid(CudnnfMHAUid::Q_ID));
+                       .set_uid(next_uid()));
   auto dim = k_descriptor.GetCudnnCompatibleDimensions(true);
 
   std::shared_ptr<Tensor_attributes> k_tensor =
@@ -5088,13 +5020,13 @@
                        .set_name("K")
                        .set_dim(k_descriptor.GetCudnnCompatibleDimensions(true))
                        .set_stride(k_descriptor.GetCudnnCompatibleStrides(true))
-                       .set_uid(CudnnfMHAUid::K_ID));
+                       .set_uid(next_uid()));
   std::shared_ptr<Tensor_attributes> v_tensor = graph.tensor(
       Tensor_attributes()
           .set_name("V")
           .set_dim(v_descriptor.GetCudnnCompatibleDimensions(false))
           .set_stride(v_descriptor.GetCudnnCompatibleStrides(false))
-          .set_uid(CudnnfMHAUid::V_ID));
+          .set_uid(next_uid()));
 
   // Setting sdpa, and is_inference
   bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL ||
@@ -5112,7 +5044,7 @@
                          .set_name("bias")
                          .set_dim(bias_descriptor->dimensions())
                          .set_stride(bias_descriptor->GetLogicalStrides())
-                         .set_uid(CudnnfMHAUid::BIAS_ID));
+                         .set_uid(next_uid()));
     sdpa_options.set_bias(bias_tensor);
   }
   // Setting actual seqlen
@@ -5126,37 +5058,38 @@
                          .set_name("seq_q")
                          .set_dim({b, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
-                         .set_uid(CudnnfMHAUid::Q_SEQLEN_ID)
+                         .set_uid(next_uid())
                          .set_data_type(cudnn_frontend::DataType_t::INT32));
     auto seq_kv_tensor =
         graph.tensor(Tensor_attributes()
                          .set_name("seq_kv")
                          .set_dim({b, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
-                         .set_uid(CudnnfMHAUid::K_SEQLEN_ID)
+                         .set_uid(next_uid())
                          .set_data_type(cudnn_frontend::DataType_t::INT32));
     sdpa_options.set_padding_mask(true);
     sdpa_options.set_seq_len_q(seq_q_tensor);
     sdpa_options.set_seq_len_kv(seq_kv_tensor);
   }
   // Setting seed and offset
+  std::shared_ptr<Tensor_attributes> seed_tensor;
+  std::shared_ptr<Tensor_attributes> offset_tensor;
   if (use_dropout) {
-    auto seed_tensor =
+    // Skip setting UIDs: pass by value tensors go at the end.
+    seed_tensor =
         graph.tensor(Tensor_attributes()
                          .set_name("seed")
                          .set_dim({1, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
                          .set_data_type(cudnn_frontend::DataType_t::INT64)
-                         .set_is_pass_by_value(true)
-                         .set_uid(CudnnfMHAUid::D_SEED_ID));
-    auto offset_tensor =
+                         .set_is_pass_by_value(true));
+    offset_tensor =
         graph.tensor(Tensor_attributes()
                          .set_name("offset")
                          .set_dim({1, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
                          .set_data_type(cudnn_frontend::DataType_t::INT64)
-                         .set_is_pass_by_value(true)
-                         .set_uid(CudnnfMHAUid::D_OFFSET_ID));
+                         .set_is_pass_by_value(true));
     sdpa_options.set_dropout((float)dropout_rate.value(), seed_tensor,
                              offset_tensor);
   }
@@ -5170,7 +5103,7 @@
       .set_output(true)
       .set_dim(o_descriptor.dimensions())
       .set_stride(o_descriptor.GetLogicalStrides())
-      .set_uid(CudnnfMHAUid::O_ID);
+      .set_uid(next_uid());
   if (stats_descriptor.has_value()) {
     cudnn_frontend::DataType_t statsType =
         ToCudnnFrontendDataType(stats_descriptor->type());
@@ -5183,11 +5116,19 @@
         .set_data_type(statsType)
         .set_dim(stat_dims)
         .set_stride(stat_strides)
-        .set_uid(CudnnfMHAUid::P_ID);
+        .set_uid(next_uid());
+  }
+  if (seed_tensor != nullptr) {
+    seed_tensor->set_uid(next_uid());
+  }
+  if (offset_tensor != nullptr) {
+    offset_tensor->set_uid(next_uid());
   }
   CudnnGraph cudnnGraph(std::move(graph));
-  TF_RETURN_IF_ERROR(cudnnGraph.Prepare(dnn_support));
-  TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, std::nullopt));
+  TF_RETURN_IF_ERROR(cudnnGraph.Prepare(
+      dnn_support, NumericOptions{/*require_determinism=*/false,
+                                  /*allow_tf32=*/true}));
+  TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt));
 
   if (VLOG_IS_ON(4)) {
     VLOG(4) << "\b flash attention operation graph: " << graph;
@@ -5195,7 +5136,7 @@
   return cudnnGraph;
 #else
   return absl::UnimplementedError(
-      "Cudnn flash attention only supported with Cudnn >= 8.9.4");
+      "Cudnn flash attention only supported with Cudnn >= 9.0.0");
 #endif
 }
 
@@ -5211,7 +5152,7 @@
     std::optional<double> dropout_rate, std::optional<int64_t> seed,
     double scale, bool use_dropout, bool use_bias, dnn::FMHAMaskKind mask_type,
     bool force_deterministic) {
-#if CUDNN_VERSION >= 8904
+#if CUDNN_VERSION >= 90000
   if (VLOG_IS_ON(4)) {
     VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): " << q_desc.ToString()
             << "\n bmm1_grad_gemm2_rhs(k): " << k_desc.ToString()
@@ -5236,41 +5177,6 @@
       .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT)
       .set_io_data_type(ioDataType);
 
-  std::shared_ptr<Tensor_attributes> q =
-      graph.tensor(Tensor_attributes()
-                       .set_name("Q")
-                       .set_dim(q_desc.GetCudnnCompatibleDimensions(false))
-                       .set_stride(q_desc.GetCudnnCompatibleStrides(false))
-                       .set_uid(CudnnfMHAUid::Q_ID)
-                       .set_data_type(ioDataType));
-  std::shared_ptr<Tensor_attributes> k =
-      graph.tensor(Tensor_attributes()
-                       .set_name("K")
-                       .set_dim(k_desc.GetCudnnCompatibleDimensions(false))
-                       .set_stride(k_desc.GetCudnnCompatibleStrides(false))
-                       .set_uid(CudnnfMHAUid::K_ID)
-                       .set_data_type(ioDataType));
-  std::shared_ptr<Tensor_attributes> v =
-      graph.tensor(Tensor_attributes()
-                       .set_name("V")
-                       .set_dim(v_desc.GetCudnnCompatibleDimensions(true))
-                       .set_stride(v_desc.GetCudnnCompatibleStrides(true))
-                       .set_uid(CudnnfMHAUid::V_ID)
-                       .set_data_type(ioDataType));
-  std::shared_ptr<Tensor_attributes> o =
-      graph.tensor(Tensor_attributes()
-                       .set_name("O")
-                       .set_dim(do_desc.GetCudnnCompatibleDimensions(false))
-                       .set_stride(do_desc.GetCudnnCompatibleStrides(false))
-                       .set_uid(CudnnfMHAUid::O_ID)
-                       .set_data_type(ioDataType));
-  std::shared_ptr<Tensor_attributes> dO =
-      graph.tensor(Tensor_attributes()
-                       .set_name("dO")
-                       .set_dim(do_desc.GetCudnnCompatibleDimensions(false))
-                       .set_stride(do_desc.GetCudnnCompatibleStrides(false))
-                       .set_uid(CudnnfMHAUid::dO_ID)
-                       .set_data_type(ioDataType));
   auto p_dims = p_desc.GetCudnnCompatibleDimensions(false);
   auto p_strides = p_desc.GetCudnnCompatibleStrides(false);
   std::vector<int64_t> p_reduction_dims(p_dims.begin(), p_dims.end() - 1);
@@ -5284,13 +5190,6 @@
     p_reduction_strides.push_back(stride / p_reduced_dim_len);
   }
   p_reduction_strides[3] = 1;
-  std::shared_ptr<Tensor_attributes> stats =
-      graph.tensor(Tensor_attributes()
-                       .set_name("stats")
-                       .set_dim(p_reduction_dims)
-                       .set_stride(p_reduction_strides)
-                       .set_uid(CudnnfMHAUid::P_ID)
-                       .set_data_type(cudnn_frontend::DataType_t::FLOAT));
   bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL ||
                    mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL;
   auto sdpa_backward_options =
@@ -5300,7 +5199,44 @@
           .set_attn_scale(scale)
           .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT);
 
-  // Setting bias
+  auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); };
+
+  std::shared_ptr<Tensor_attributes> q =
+      graph.tensor(Tensor_attributes()
+                       .set_name("Q")
+                       .set_dim(q_desc.GetCudnnCompatibleDimensions(false))
+                       .set_stride(q_desc.GetCudnnCompatibleStrides(false))
+                       .set_uid(next_uid())
+                       .set_data_type(ioDataType));
+  std::shared_ptr<Tensor_attributes> k =
+      graph.tensor(Tensor_attributes()
+                       .set_name("K")
+                       .set_dim(k_desc.GetCudnnCompatibleDimensions(false))
+                       .set_stride(k_desc.GetCudnnCompatibleStrides(false))
+                       .set_uid(next_uid())
+                       .set_data_type(ioDataType));
+  std::shared_ptr<Tensor_attributes> v =
+      graph.tensor(Tensor_attributes()
+                       .set_name("V")
+                       .set_dim(v_desc.GetCudnnCompatibleDimensions(true))
+                       .set_stride(v_desc.GetCudnnCompatibleStrides(true))
+                       .set_uid(next_uid())
+                       .set_data_type(ioDataType));
+  std::shared_ptr<Tensor_attributes> stats =
+      graph.tensor(Tensor_attributes()
+                       .set_name("stats")
+                       .set_dim(p_reduction_dims)
+                       .set_stride(p_reduction_strides)
+                       .set_uid(next_uid())
+                       .set_data_type(cudnn_frontend::DataType_t::FLOAT));
+  std::shared_ptr<Tensor_attributes> dO =
+      graph.tensor(Tensor_attributes()
+                       .set_name("dO")
+                       .set_dim(do_desc.GetCudnnCompatibleDimensions(false))
+                       .set_stride(do_desc.GetCudnnCompatibleStrides(false))
+                       .set_uid(next_uid())
+                       .set_data_type(ioDataType));
+  std::shared_ptr<Tensor_attributes> d_bias_tensor;
   if (use_bias) {
     DCHECK(bias_descriptor != std::nullopt);
     auto bias_dim = bias_descriptor->dimensions();
@@ -5313,21 +5249,29 @@
                          .set_name("bias")
                          .set_dim(bias_descriptor->dimensions())
                          .set_stride(bias_descriptor->GetLogicalStrides())
-                         .set_uid(CudnnfMHAUid::BIAS_ID));
+                         .set_uid(next_uid()));
     sdpa_backward_options.set_bias(bias_tensor);
 
     // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] are not supported for
     // dbias calculation but they are supported for forward bias calculation
+    // Set UID later: this is the last output tuple element.
     if (b == 1 && n == q_n) {
-      auto d_bias_tensor =
+      d_bias_tensor =
           graph.tensor(Tensor_attributes()
                            .set_name("dBias")
                            .set_dim(bias_descriptor->dimensions())
-                           .set_stride(bias_descriptor->GetLogicalStrides())
-                           .set_uid(CudnnfMHAUid::dBIAS_ID));
+                           .set_stride(bias_descriptor->GetLogicalStrides()));
       sdpa_backward_options.set_dbias(d_bias_tensor);
     }
   }
+  std::shared_ptr<Tensor_attributes> o =
+      graph.tensor(Tensor_attributes()
+                       .set_name("O")
+                       .set_dim(do_desc.GetCudnnCompatibleDimensions(false))
+                       .set_stride(do_desc.GetCudnnCompatibleStrides(false))
+                       .set_uid(next_uid())
+                       .set_data_type(ioDataType));
+
   // Setting actual seqlen
   bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING ||
                     mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL;
@@ -5339,38 +5283,39 @@
                          .set_name("seq_q")
                          .set_dim({b, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
-                         .set_uid(CudnnfMHAUid::Q_SEQLEN_ID)
+                         .set_uid(next_uid())
                          .set_data_type(cudnn_frontend::DataType_t::INT32));
     auto seq_kv_tensor =
         graph.tensor(Tensor_attributes()
                          .set_name("seq_kv")
                          .set_dim({b, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
-                         .set_uid(CudnnfMHAUid::K_SEQLEN_ID)
+                         .set_uid(next_uid())
                          .set_data_type(cudnn_frontend::DataType_t::INT32));
     sdpa_backward_options.set_padding_mask(true);
     sdpa_backward_options.set_seq_len_q(seq_q_tensor);
     sdpa_backward_options.set_seq_len_kv(seq_kv_tensor);
   }
   // Setting seed and offset
+  std::shared_ptr<Tensor_attributes> seed_tensor;
+  std::shared_ptr<Tensor_attributes> offset_tensor;
   if (use_dropout) {
     DCHECK(dropout_rate != std::nullopt);
-    auto seed_tensor =
+    // Skip setting UIDs: pass-by-value tensors go at the end.
+    seed_tensor =
         graph.tensor(Tensor_attributes()
                          .set_name("seed")
                          .set_dim({1, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
                          .set_data_type(cudnn_frontend::DataType_t::INT64)
-                         .set_is_pass_by_value(true)
-                         .set_uid(CudnnfMHAUid::D_SEED_ID));
-    auto offset_tensor =
+                         .set_is_pass_by_value(true));
+    offset_tensor =
         graph.tensor(Tensor_attributes()
                          .set_name("offset")
                          .set_dim({1, 1, 1, 1})
                          .set_stride({1, 1, 1, 1})
                          .set_data_type(cudnn_frontend::DataType_t::INT64)
-                         .set_is_pass_by_value(true)
-                         .set_uid(CudnnfMHAUid::D_OFFSET_ID));
+                         .set_is_pass_by_value(true));
     sdpa_backward_options.set_dropout((float)dropout_rate.value(), seed_tensor,
                                       offset_tensor);
   }
@@ -5385,25 +5330,36 @@
   dQ->set_output(true)
       .set_dim(dq_desc.dimensions())
       .set_stride(dq_desc.GetLogicalStrides())
+      .set_uid(next_uid())
       .set_name("dQ")
-      .set_uid(CudnnfMHAUid::dQ_ID)
       .set_data_type(ioDataType);
   dK->set_output(true)
       .set_dim(dk_desc.dimensions())
       .set_stride(dk_desc.GetLogicalStrides())
+      .set_uid(next_uid())
       .set_name("dK")
-      .set_uid(CudnnfMHAUid::dK_ID)
       .set_data_type(ioDataType);
   dV->set_output(true)
       .set_dim(dv_desc.dimensions())
       .set_stride(dv_desc.GetLogicalStrides())
+      .set_uid(next_uid())
       .set_name("dV")
-      .set_uid(CudnnfMHAUid::dV_ID)
       .set_data_type(ioDataType);
+  if (d_bias_tensor != nullptr) {
+    d_bias_tensor->set_uid(next_uid());
+  }
+  if (seed_tensor != nullptr) {
+    seed_tensor->set_uid(next_uid());
+  }
+  if (offset_tensor != nullptr) {
+    offset_tensor->set_uid(next_uid());
+  }
 
   CudnnGraph cudnnGraph(std::move(graph));
-  TF_RETURN_IF_ERROR(cudnnGraph.Prepare(dnn_support));
-  TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, std::nullopt));
+  TF_RETURN_IF_ERROR(
+      cudnnGraph.Prepare(dnn_support, NumericOptions{force_deterministic,
+                                                     /*allow_tf32=*/true}));
+  TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt));
 
   if (VLOG_IS_ON(4)) {
     VLOG(4) << "\b flash attention operation backward graph: " << graph;
@@ -5412,7 +5368,7 @@
   return cudnnGraph;
 #else
   return absl::UnimplementedError(
-      "Cudnn flash attention only supported with Cudnn >= 8.9.4");
+      "Cudnn flash attention only supported with Cudnn >= 9.0.0");
 #endif
 }
 
@@ -5735,8 +5691,8 @@
 }
 
 // Utility for dealing with CUDA's type-erased scaling parameters, where some
-// sets of parameters expect a void* pointing at a float while others expect it
-// to point at a double.
+// sets of parameters expect a void* pointing at a float while others expect
+// it to point at a double.
 //
 // This is rather ugly, but its purpose is to quarantine the corresponding
 // ugliness that already exists in the CUDA API.
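
A self-contained sketch of the approach the comment above describes; the class and member names here are placeholders rather than the actual scaling-parameter helper, under the assumption that the helper simply keeps the value in both widths and hands back a type-erased pointer.

#include <iostream>

// Placeholder modelling the type-erased scaling parameter idea: keep both a
// double and a float copy and return a void* to whichever representation the
// result tensor's element type requires.
class ScalingValueSketch {
 public:
  explicit ScalingValueSketch(double value)
      : as_double_(value), as_float_(static_cast<float>(value)) {}

  void* ToVoidPointer(bool element_type_is_double) {
    return element_type_is_double ? static_cast<void*>(&as_double_)
                                  : static_cast<void*>(&as_float_);
  }

 private:
  double as_double_;
  float as_float_;
};

int main() {
  ScalingValueSketch alpha(2.0);
  // A double result tensor reads 8 bytes through the pointer; every other
  // element type reads 4 bytes as a float.
  std::cout << *static_cast<double*>(alpha.ToVoidPointer(true)) << " "
            << *static_cast<float*>(alpha.ToVoidPointer(false)) << "\n";
  return 0;
}
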
@@ -5760,9 +5716,9 @@
   //
   // See
   // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
-  // for more info; the behavior for int8 result tensors is not described there,
-  // but is maintained from the existing behavior (namely, using a float scaling
-  // parameter).
+  // for more info; the behavior for int8 result tensors is not described
+  // there, but is maintained from the existing behavior (namely, using a
+  // float scaling parameter).
   void* ToVoidPointer(dnn::DataType element_type) {
     if (element_type == dnn::DataType::kDouble) {
       return &as_double_;
@@ -5834,10 +5790,11 @@
   absl::c_transform(result, std::back_inserter(raw_ptrs),
                     [](const BackendDescriptor& ptr) { return ptr.get(); });
 
-  // This API evidently does a deep copy of the descriptors into the pointers in
-  // the output array, rather than writing pointers to the descriptors into the
-  // output array.  So, this writes the memory behind each BackendDescriptor in
-  // result, rather than writing the contents of raw_ptrs.
+  // This API evidently does a deep copy of the descriptors into the pointers
+  // in the output array, rather than writing pointers to the descriptors into
+  // the output array.  So, this writes the memory behind each
+  // BackendDescriptor in result, rather than writing the contents of
+  // raw_ptrs.
   RETURN_IF_CUDNN_ERROR(cudnnBackendGetAttribute(
       desc, name, CUDNN_TYPE_BACKEND_DESCRIPTOR, n, &n, raw_ptrs.data()));
 
@@ -5873,9 +5830,9 @@
       cudnnBackendGetAttribute(engines[0].get(), CUDNN_ATTR_ENGINE_GLOBAL_INDEX,
                                CUDNN_TYPE_INT64, 1, &n, &engine_id));
 
-  // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query the
-  // number of elements in the attribute by using an output limit value of 0
-  // just returns 0; the only way to find out how many there are is to
+  // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query
+  // the number of elements in the attribute by using an output limit value of
+  // 0 just returns 0; the only way to find out how many there are is to
   // pre-allocate space for every existing knob type (as an upper bound on the
   // number of knob choices a config can have), and then look back at how many
   // were filled.
@@ -6086,103 +6043,7 @@
   std::vector<int64_t> scalar_input_uids_;
   std::vector<ScalingParam> scalar_input_values_;
 };
-#endif  // CUDNN_VERSION >= 8100
 
-template <typename Sig>
-class CudnnGraphRunner;
-// An OpRunner implemented by a cuDNN frontend graph.
-//
-// This is the class holding the implementation of ToString, GetWorkspaceSize,
-// and operator() for use by the cudnn frontend op runners.
-template <typename... Args>
-class CudnnGraphRunner<void(Args...)> : public dnn::OpRunner<void(Args...)> {
- private:
-  using Graph = cudnn_frontend::graph::Graph;
-  using Tensor_attributes = cudnn_frontend::graph::Tensor_attributes;
-
- public:
-  std::string ToString() const override { return graph_.Graph().print(); }
-
-  size_t GetWorkspaceSize() const override {
-    return graph_.Graph().get_workspace_size();
-  }
-
-  absl::StatusOr<dnn::AlgorithmDesc> ToAlgorithmDesc() const override {
-    return absl::InternalError(
-        "Unexpected call to CudnnGraphRunner::ToAlgorithmDesc");
-  }
-
-  absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result,
-                          DeviceMemoryBase scratch_memory,
-                          Args... inputs) const override {
-    if (parent_ != stream->parent()) {
-      return tsl::errors::Internal(
-          "CudnnExecutionPlanRunner cached across multiple StreamExecutors.");
-    }
-    CudnnHandle handle = cudnn_->GetHandle(parent_, stream);
-    std::unordered_map<int64_t, void*> variant_pack;
-    std::vector<void*> vec = {inputs.opaque()...};
-
-    // add device buffers to the variant pack
-    for (int i = 0; i < uids_.size(); ++i) {
-      if (uids_[i].has_value()) {
-        variant_pack[*uids_[i]] = vec[i];
-      }
-    }
-    if (dropout_rng_offset_increment_ > 0) {
-#if CUDNN_VERSION >= 8800
-      variant_pack[CudnnfMHAUid::D_SEED_ID] = (void*)&dropout_rng_seed_;
-      current_dropout_rng_offset_ += dropout_rng_offset_increment_;
-      variant_pack[CudnnfMHAUid::D_OFFSET_ID] =
-          (void*)&current_dropout_rng_offset_;
-#else
-      return absl::UnimplementedError(
-          "Cudnn dropout offset and seed are only supported with Cudnn >= "
-          "8.8.0");
-#endif  // CUDNN_VERSION >= 8800
-    }
-    int workspace = graph_.Graph().get_workspace_size();
-    if (workspace > scratch_memory.size()) {
-      return tsl::errors::Internal(
-          absl::StrFormat("CuDNN FMHA requires %d workspace, got %d workspace.",
-                          workspace, scratch_memory.size()));
-    }
-    RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.Graph().execute(
-        handle.handle(), variant_pack, scratch_memory.opaque()));
-
-    return absl::OkStatus();
-  }
-
-  static absl::StatusOr<CudnnGraphRunner> Create(
-      GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph,
-      int64_t dropout_rng_seed, int64_t dropout_rng_offset,
-      std::vector<std::optional<int64_t>> uids) {
-    return CudnnGraphRunner(parent, cudnn, std::move(graph), dropout_rng_seed,
-                            dropout_rng_offset, uids);
-  }
-
- private:
-  CudnnGraphRunner(GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph,
-                   int64_t dropout_rng_seed, int64_t dropout_rng_offset,
-                   std::vector<std::optional<int64_t>> uids)
-      : parent_(parent),
-        cudnn_(cudnn),
-        graph_(std::move(graph)),
-        dropout_rng_seed_(dropout_rng_seed),
-        current_dropout_rng_offset_(0),
-        dropout_rng_offset_increment_(dropout_rng_offset),
-        uids_(uids) {}
-  GpuExecutor* parent_;
-  CudnnAccess* cudnn_;
-  Stream* stream_;
-  CudnnGraph graph_;
-  int64_t dropout_rng_seed_;
-  mutable int64_t current_dropout_rng_offset_;
-  int64_t dropout_rng_offset_increment_;
-  std::vector<std::optional<int64_t>> uids_;
-};
-
-#if CUDNN_VERSION >= 8100
 namespace {
 
 template <typename Sig>
@@ -6968,7 +6829,8 @@
       use_fallback, out_exec_plans, /*need_side_input=*/true, numeric_options);
 #else
   return tsl::errors::Unimplemented(
-      "Cudnn execution plans for matmul are only supported with Cudnn >= 8.4.");
+      "Cudnn execution plans for matmul are only supported with Cudnn >= "
+      "8.4.");
 #endif  // CUDNN_VERSION >= 8400
 }
 
@@ -7170,139 +7032,6 @@
   return max_seq_len * max_seq_len / cudnn_mha_num_threads;
 }
 
-absl::StatusOr<std::unique_ptr<const dnn::FusedMHARunner>>
-CudnnSupport::FusedMHARunnerFromDesc(
-    Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
-    const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor,
-    const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor,
-    const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor,
-    const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor,
-    const dnn::TensorDescriptor& output_descriptor,
-    std::optional<dnn::TensorDescriptor> activation_descriptor,
-    std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
-    std::optional<double> dropout_rate, std::optional<int64_t> seed,
-    dnn::FMHAMaskKind mask_type) {
-#if CUDNN_VERSION >= 8904
-  auto cudnn = cudnn_->GetHandle(parent_, stream);
-  bool use_dropout = dropout_rate && *dropout_rate > 0.0;
-  std::vector<int64_t> intermediate_shape;
-
-  TF_ASSIGN_OR_RETURN(auto graph,
-                      GetCudnnFlashAttentionOperationGraph(
-                          *this, /*q_descriptor=*/bmm1_lhs_descriptor,
-                          /*k_descriptor=*/bmm1_rhs_descriptor,
-                          /*v_descriptor=*/bmm2_rhs_descriptor,
-                          /*o_descriptor=*/output_descriptor, bias_descriptor,
-                          /*stats_descriptor=*/activation_descriptor,
-                          /*scale=*/static_cast<float>(scale), use_dropout,
-                          dropout_rate, mask_type));
-
-  std::vector<int64_t> intermediate_bmm2_lhs_dims =
-      intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleDimensions(true);
-  intermediate_shape = intermediate_bmm2_lhs_dims;
-  int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape);
-  int64_t dropout_rng_seed = seed.has_value() ? *seed : 0;
-  std::vector<std::optional<int64_t>> uids = {
-      CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::V_ID,
-      CudnnfMHAUid::O_ID};
-  uids.emplace_back(bias_descriptor.has_value()
-                        ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::BIAS_ID)
-                        : std::nullopt);
-  uids.emplace_back(activation_descriptor.has_value()
-                        ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::P_ID)
-                        : std::nullopt);
-  bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING ||
-                    mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL;
-  uids.emplace_back(is_padding
-                        ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::Q_SEQLEN_ID)
-                        : std::nullopt);
-  uids.emplace_back(is_padding
-                        ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::K_SEQLEN_ID)
-                        : std::nullopt);
-  TF_ASSIGN_OR_RETURN(auto runner,
-                      CudnnGraphRunner<dnn::FusedMHASignature>::Create(
-                          parent_, cudnn_.get(), std::move(graph),
-                          dropout_rng_seed, dropout_rng_offset, uids));
-
-  return {std::make_unique<CudnnGraphRunner<dnn::FusedMHASignature>>(
-      std::move(runner))};
-#else
-  return absl::UnimplementedError(
-      "Cudnn flash attention are only supported with Cudnn >= 8.9.4");
-#endif  // CUDNN_VERSION >= 8904
-}
-
-absl::StatusOr<std::unique_ptr<const dnn::FusedMHABackwardRunner>>
-CudnnSupport::FusedMHABackwardRunnerFromDesc(
-    Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
-    const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor,
-    const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor,
-    const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor,
-    const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor,
-    const dnn::MatmulTensorDescriptor& d_output_descriptor,
-    const dnn::TensorDescriptor& d_bmm1_lhs_descriptor,
-    const dnn::TensorDescriptor& d_bmm1_rhs_descriptor,
-    const dnn::TensorDescriptor& d_bmm2_rhs_descriptor,
-    std::optional<dnn::TensorDescriptor> d_s_descriptor,
-    std::optional<dnn::TensorDescriptor> d_bias_descriptor,
-    std::optional<dnn::TensorDescriptor> fwd_output_descriptor,
-    std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
-    std::optional<double> dropout_rate, std::optional<int64_t> seed,
-    dnn::FMHAMaskKind mask_type, bool force_deterministic) {
-#if CUDNN_VERSION >= 8904
-  auto cudnn = cudnn_->GetHandle(parent_, stream);
-
-  bool use_dropout = dropout_rate && *dropout_rate > 0.0;
-  std::vector<int64_t> intermediate_shape;
-
-  TF_ASSIGN_OR_RETURN(
-      auto graph,
-      GetCudnnFlashAttentionBackwardOperationGraph(
-          *this, bmm1_grad_gemm1_rhs_descriptor, bmm1_grad_gemm2_rhs_descriptor,
-          bmm2_grad_gemm1_lhs_descriptor, bmm2_grad_gemm2_rhs_descriptor,
-          d_output_descriptor, d_bmm1_lhs_descriptor, d_bmm1_rhs_descriptor,
-          d_bmm2_rhs_descriptor, bias_descriptor, dropout_rate, seed, scale,
-          use_dropout, bias_descriptor != std::nullopt, mask_type,
-          force_deterministic));
-
-  std::vector<int64_t> p_dims =
-      bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleDimensions(false);
-  intermediate_shape = p_dims;
-  int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape);
-  int64_t dropout_rng_seed = seed.has_value() ? *seed : 0;
-
-  std::vector<std::optional<int64_t>> uids;
-  uids = {CudnnfMHAUid::Q_ID,  CudnnfMHAUid::K_ID,  CudnnfMHAUid::P_ID,
-          CudnnfMHAUid::V_ID,  CudnnfMHAUid::dO_ID, CudnnfMHAUid::dQ_ID,
-          CudnnfMHAUid::dK_ID, CudnnfMHAUid::dV_ID, std::nullopt};
-  uids.emplace_back(d_bias_descriptor.has_value()
-                        ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::dBIAS_ID)
-                        : std::nullopt);
-  uids.push_back(CudnnfMHAUid::O_ID);
-  uids.emplace_back(bias_descriptor.has_value()
-                        ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::BIAS_ID)
-                        : std::nullopt);
-  bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING ||
-                    mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL;
-  uids.emplace_back(is_padding
-                        ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::Q_SEQLEN_ID)
-                        : std::nullopt);
-  uids.emplace_back(is_padding
-                        ? std::optional<CudnnfMHAUid>(CudnnfMHAUid::K_SEQLEN_ID)
-                        : std::nullopt);
-  TF_ASSIGN_OR_RETURN(auto runner,
-                      CudnnGraphRunner<dnn::FusedMHABackwardSignature>::Create(
-                          parent_, cudnn_.get(), graph, dropout_rng_seed,
-                          dropout_rng_offset, uids));
-  return {std::make_unique<CudnnGraphRunner<dnn::FusedMHABackwardSignature>>(
-      std::move(runner))};
-#else
-  return absl::UnimplementedError(
-      "Cudnn flash attention bwd are only "
-      "supported with Cudnn >= 8.9.4");
-#endif  // CUDNN_VERSION >= 8904
-}
-
 bool CudnnSupport::GetRnnAlgorithms(
     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
   PreloadCudnnSubLibs(PreloadCudnnType::Rnn);
@@ -8353,11 +8082,16 @@
   return std::make_unique<CudnnGraph>(std::move(graph));
 }
 
-absl::Status CudnnGraph::Prepare(dnn::DnnSupport& dnn_support) {
+absl::Status CudnnGraph::Prepare(dnn::DnnSupport& dnn_support,
+                                 const NumericOptions& numeric_options) {
   const CudnnSupport& cudnn_support = static_cast<CudnnSupport&>(dnn_support);
   TF_ASSIGN_OR_RETURN(auto cudnn, cudnn_support.cudnn_->GetLocalHandle());
   RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.validate());
   RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.build_operation_graph(cudnn->handle()));
+  if (numeric_options.require_determinism) {
+    graph_.deselect_numeric_notes(
+        {cudnn_frontend::NumericalNote_t::NONDETERMINISTIC});
+  }
   RETURN_IF_CUDNN_FRONTEND_ERROR(
       graph_.create_execution_plans({cudnn_frontend::HeurMode_t::A}));
   RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.check_support(cudnn->handle()));
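
A hedged usage sketch of the new two-argument Prepare; it relies on the XLA/StreamExecutor types appearing in this diff rather than being a standalone program, and the wrapper name and namespace qualifications are assumptions. Requesting determinism makes Prepare deselect engine configs carrying the NONDETERMINISTIC numerical note before execution plans are created.

absl::Status PrepareAndBuildDeterministic(
    stream_executor::gpu::CudnnGraph& graph,
    stream_executor::dnn::DnnSupport& dnn_support) {
  // require_determinism filters out NONDETERMINISTIC engine configs;
  // allow_tf32 keeps TF32 math enabled.
  stream_executor::NumericOptions numeric_options{
      /*require_determinism=*/true, /*allow_tf32=*/true};
  TF_RETURN_IF_ERROR(graph.Prepare(dnn_support, numeric_options));
  // The diff passes std::nullopt for plan_id when building.
  TF_RETURN_IF_ERROR(graph.Build(dnn_support, /*plan_id=*/std::nullopt));
  return absl::OkStatus();
}
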
@@ -8382,15 +8116,30 @@
   std::unordered_map<int64_t, void*> tensor_to_ptr_map;
   absl::Span<DeviceMemoryBase> operands_without_workspace = operands;
   DeviceMemoryBase workspace;
-  if (graph_.get_workspace_size() != 0) {
+  if (graph_.get_workspace_size() > 0) {
     workspace = operands.back();
     CHECK_EQ(graph_.get_workspace_size(), workspace.size());
+  }
+  if (graph_.get_workspace_size() > 0 || operands.back().size() == 0) {
     operands_without_workspace = operands.first(operands.size() - 1);
   }
-  int operand_number = 0;
+  auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); };
   for (DeviceMemoryBase operand : operands_without_workspace) {
-    tensor_to_ptr_map[CuDnnTensorUID(operand_number++)] = operand.opaque();
+    tensor_to_ptr_map[next_uid()] = operand.opaque();
   }
+
+  if (dropout_rng_offset_increment_ > 0) {
+#if CUDNN_VERSION >= 8800
+    tensor_to_ptr_map[next_uid()] = (void*)&dropout_rng_seed_;
+    current_dropout_rng_offset_ += dropout_rng_offset_increment_;
+    tensor_to_ptr_map[next_uid()] = (void*)&current_dropout_rng_offset_;
+#else
+    return absl::UnimplementedError(
+        "Cudnn dropout offset and seed are only supported with Cudnn >= "
+        "8.8.0");
+#endif  // CUDNN_VERSION >= 8800
+  }
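
A self-contained model of the running-offset behavior added above, with the member names borrowed from the diff and everything else an illustrative assumption: each execution advances the dropout offset by a fixed increment, and the seed plus the current offset are the host values exposed through the two trailing UIDs.

#include <cstdint>
#include <iostream>

struct DropoutRngStateSketch {
  int64_t dropout_rng_seed_ = 42;
  int64_t current_dropout_rng_offset_ = 0;
  int64_t dropout_rng_offset_increment_ = 16;

  // Mirrors what Execute() does per launch when dropout is enabled.
  void AdvanceForOneExecution() {
    if (dropout_rng_offset_increment_ > 0) {
      current_dropout_rng_offset_ += dropout_rng_offset_increment_;
      // The graph would read &dropout_rng_seed_ and
      // &current_dropout_rng_offset_ via the two trailing UIDs.
    }
  }
};

int main() {
  DropoutRngStateSketch state;
  for (int i = 0; i < 3; ++i) state.AdvanceForOneExecution();
  std::cout << "offset after three executions: "
            << state.current_dropout_rng_offset_ << "\n";  // prints 48
  return 0;
}
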
+
   const CudnnSupport& dnn_support =
       static_cast<CudnnSupport&>(*stream.parent()->AsDnn());
   RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.execute(
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h
index 5208693..24d84e3 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h
@@ -60,7 +60,7 @@
   explicit CudnnGraph(cudnn_frontend::graph::Graph&& graph)
       : graph_(std::move(graph)) {}
   // Prepares a graph and checks whether it is generally supported.
-  absl::Status Prepare(dnn::DnnSupport&) override;
+  absl::Status Prepare(dnn::DnnSupport&, const NumericOptions&) override;
   // Builds single plan of the graph with given ID.
   absl::Status Build(dnn::DnnSupport&, std::optional<int64_t> plan_id) override;
   // Builds all the plans
@@ -70,6 +70,9 @@
 
  private:
   cudnn_frontend::graph::Graph graph_;
+  int64_t dropout_rng_seed_;
+  mutable int64_t current_dropout_rng_offset_;
+  int64_t dropout_rng_offset_increment_ = 0;
 };
 #endif  // CUDNN_VERSION >= 8100
 
@@ -335,37 +338,6 @@
       std::optional<dnn::TensorDescriptor> dscale_descriptor,
       std::optional<dnn::TensorDescriptor> dbias_descriptor) override;
 
-  absl::StatusOr<std::unique_ptr<const dnn::FusedMHARunner>>
-  FusedMHARunnerFromDesc(
-      Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
-      const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor,
-      const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor,
-      const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor,
-      const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor,
-      const dnn::TensorDescriptor& output_descriptor,
-      std::optional<dnn::TensorDescriptor> activation_descriptor,
-      std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
-      std::optional<double> dropout_rate, std::optional<int64_t> seed,
-      dnn::FMHAMaskKind mask_type) override;
-
-  absl::StatusOr<std::unique_ptr<const dnn::FusedMHABackwardRunner>>
-  FusedMHABackwardRunnerFromDesc(
-      Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
-      const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor,
-      const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor,
-      const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor,
-      const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor,
-      const dnn::MatmulTensorDescriptor& d_output_descriptor,
-      const dnn::TensorDescriptor& d_bmm1_lhs_descriptor,
-      const dnn::TensorDescriptor& d_bmm1_rhs_descriptor,
-      const dnn::TensorDescriptor& d_bmm2_rhs_descriptor,
-      std::optional<dnn::TensorDescriptor> d_s_descriptor,
-      std::optional<dnn::TensorDescriptor> d_bias_descriptor,
-      std::optional<dnn::TensorDescriptor> fwd_output_descriptor,
-      std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
-      std::optional<double> dropout_rate, std::optional<int64_t> seed,
-      dnn::FMHAMaskKind mask_type, bool force_deterministic);
-
   bool GetRnnAlgorithms(
       std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
 
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
index 866c1ff..e8e26e2 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
@@ -19,6 +19,7 @@
 #include <stdlib.h>
 
 #include <cstdint>
+#include <cstdlib>
 #include <cstring>
 #include <string>
 #include <utility>
@@ -27,7 +28,6 @@
 
 #include "absl/base/casts.h"
 #include "absl/base/const_init.h"
-#include "absl/base/optimization.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/debugging/leak_check.h"
 #include "absl/log/check.h"
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.h b/third_party/xla/xla/stream_executor/cuda/cuda_driver.h
index 5c04ab6..aefd896 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.h
@@ -19,16 +19,13 @@
 #define XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_
 
 #include <algorithm>
-#include <cstdint>
 #include <memory>
-#include <string>
 #include <utility>
 #include <vector>
 
 #include "absl/container/node_hash_map.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
-#include "absl/strings/str_cat.h"
 #include "absl/synchronization/mutex.h"
 #include "third_party/gpus/cuda/include/cuda.h"
 #include "xla/stream_executor/cuda/cuda_status.h"
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
index b624709..8ae2477 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
@@ -17,20 +17,17 @@
 #include <cstdint>
 #include <cstdio>
 #include <cstdlib>
-#include <ios>
 #include <memory>
 #include <optional>
 #include <string>
 #include <tuple>
 #include <utility>
 #include <variant>
-#include <vector>
 
-#include "absl/base/casts.h"
 #include "absl/numeric/int128.h"
-#include "absl/strings/str_join.h"
 #include "xla/stream_executor/blas.h"
 #include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/dnn.h"
 #include "xla/stream_executor/event.h"
 #include "xla/stream_executor/event_based_timer.h"
@@ -46,7 +43,6 @@
 #include <unistd.h>
 #endif
 
-#include "absl/functional/any_invocable.h"
 #include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
@@ -54,7 +50,6 @@
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
-#include "absl/strings/str_split.h"
 #include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/types/span.h"
@@ -75,18 +70,15 @@
 #include "xla/stream_executor/gpu/gpu_stream.h"
 #include "xla/stream_executor/gpu/gpu_timer.h"
 #include "xla/stream_executor/gpu/gpu_types.h"
-#include "xla/stream_executor/integrations/device_mem_allocator.h"
 #include "xla/stream_executor/kernel.h"
 #include "xla/stream_executor/module_spec.h"
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/plugin_registry.h"
 #include "xla/stream_executor/stream.h"
 #include "xla/stream_executor/stream_executor.h"
-#include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/fingerprint.h"
 #include "tsl/platform/logging.h"
-#include "tsl/platform/numa.h"
 #include "tsl/platform/statusor.h"
 
 // LOG(ERROR) uses a const named ERROR, so a macro with the same name is
@@ -154,9 +146,6 @@
   }
 }
 
-static std::optional<int> TryToReadNumaNode(const std::string& pci_bus_id,
-                                            int device_ordinal);
-
 absl::Status GpuExecutor::Init() {
   TF_RETURN_IF_ERROR(GpuDriver::Init());
   TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal_, &device_));
@@ -164,17 +153,6 @@
       GpuDriver::CreateContext(device_ordinal_, device_, &context_));
   TF_RETURN_IF_ERROR(
       GpuDriver::GetComputeCapability(&cc_major_, &cc_minor_, device_));
-  std::optional<int> numa_node = TryToReadNumaNode(
-      absl::AsciiStrToLower(GpuDriver::GetPCIBusID(device_ordinal_)),
-      device_ordinal_);
-  if (!numa_node || *numa_node < 0) {
-    LOG(WARNING) << "NUMA node could not be determined for device "
-                 << device_ordinal_
-                 << ", host memory allocations will not be NUMA-pinned";
-    numa_node_ = tsl::port::kNUMANoAffinity;
-  } else {
-    numa_node_ = *numa_node;
-  }
   return absl::OkStatus();
 }
 
@@ -492,83 +470,6 @@
   return absl::OkStatus();
 }
 
-absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
-                                 const BlockDim& block_dims,
-                                 const Kernel& kernel, const KernelArgs& args) {
-  return Launch(stream, thread_dims, block_dims, std::nullopt, kernel, args);
-}
-
-absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
-                                 const BlockDim& block_dims,
-                                 const ClusterDim& cluster_dims,
-                                 const Kernel& kernel, const KernelArgs& args) {
-  return Launch(stream, thread_dims, block_dims,
-                std::make_optional(cluster_dims), kernel, args);
-}
-
-absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
-                                 const BlockDim& block_dims,
-                                 const std::optional<ClusterDim>& cluster_dims,
-                                 const Kernel& kernel, const KernelArgs& args) {
-  CUstream custream = AsGpuStreamValue(stream);
-  const GpuKernel* cuda_kernel = AsGpuKernel(&kernel);
-  CUfunction cufunc = cuda_kernel->gpu_function();
-
-  if (cuda_kernel->cache_config() != KernelCacheConfig::kNoPreference) {
-    TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig(
-        cufunc, cuda_kernel->GetGpuCacheConfig()));
-  }
-
-  // Launch CUDA kernels with packed arguments.
-  auto launch = [&](const KernelArgsPackedArrayBase& packed) {
-    int32_t expected_number_of_arguments =
-        kernel.Arity() + (packed.number_of_shared_bytes() > 0);
-
-    CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments())
-        << "Kernel " << kernel.name() << " has " << packed.number_of_arguments()
-        << " arguments, but expected " << expected_number_of_arguments
-        << "; arity=" << kernel.Arity()
-        << "; number_of_shared_bytes=" << packed.number_of_shared_bytes();
-
-    void** params = const_cast<void**>(packed.argument_addresses().data());
-
-    if (cluster_dims.has_value()) {
-      return GpuDriver::LaunchKernel(
-          context_, kernel.name(), cufunc, cluster_dims->x, cluster_dims->y,
-          cluster_dims->z, block_dims.x, block_dims.y, block_dims.z,
-          thread_dims.x, thread_dims.y, thread_dims.z,
-          packed.number_of_shared_bytes(), custream, params,
-          /*extra=*/nullptr);
-    } else {
-      return GpuDriver::LaunchKernel(
-          context_, kernel.name(), cufunc, block_dims.x, block_dims.y,
-          block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z,
-          packed.number_of_shared_bytes(), custream, params,
-          /*extra=*/nullptr);
-    }
-  };
-
-  // If arguments are already packed we can just launch the kernel.
-  if (auto* packed = DynCast<KernelArgsPackedArrayBase>(&args)) {
-    return launch(*packed);
-  }
-
-  // For device memory array we rely on a custom kernel arguments packing.
-  if (auto* device_mem = DynCast<KernelArgsDeviceMemoryArray>(&args)) {
-    auto& pack = kernel.args_packing();
-    if (!pack) {
-      return absl::InternalError(
-          "Kernel is missing a custom arguments packing function for device "
-          "memory arguments array");
-    }
-
-    TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem));
-    return launch(*packed);
-  }
-
-  return absl::InternalError("Unsupported kernel arguments type");
-}
-
 DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) {
   if (memory_space == 1) {
     auto result = GpuCollectives::CollectiveMemoryAllocate(context_, size);
@@ -588,47 +489,6 @@
   GpuDriver::DeviceDeallocate(context_, mem->opaque());
 }
 
-// CUDA allocation/registration functions are necessary because the driver
-// internally sets up buffers for DMA operations (and page locks them). There's
-// no external interface for us to otherwise control these DMA settings.
-absl::StatusOr<std::unique_ptr<MemoryAllocation>>
-GpuExecutor::HostMemoryAllocate(uint64_t size) {
-  if (numa_node_ != tsl::port::kNUMANoAffinity) {
-    auto* buffer =
-        tsl::port::NUMAMalloc(numa_node_, size, /* minimum_alignment=*/16);
-    if (buffer == nullptr && size > 0) {
-      return absl::InternalError(absl::StrFormat(
-          "Failed to allocate host memory of size %d pinned to NUMA node %d",
-          size, numa_node_));
-    }
-    if (size > 0 && !GpuDriver::HostRegister(context_, buffer, size)) {
-      return absl::InternalError(
-          absl::StrFormat("Failed to register host memory of size %d pinned to "
-                          "NUMA node %d with the GPU driver",
-                          size, numa_node_));
-    }
-    return std::make_unique<HostMemoryAllocation>(buffer, size, this);
-  } else {
-    auto* buffer = GpuDriver::HostAllocate(context_, size);
-    if (buffer == nullptr && size > 0) {
-      return absl::InternalError(
-          absl::StrFormat("Failed to allocate HostMemory of size %d", size));
-    }
-    return std::make_unique<HostMemoryAllocation>(buffer, size, this);
-  }
-}
-
-void GpuExecutor::HostMemoryDeallocate(void* location, uint64_t size) {
-  if (numa_node_ != tsl::port::kNUMANoAffinity) {
-    if (size > 0) {
-      GpuDriver::HostUnregister(context_, location);
-    }
-    tsl::port::NUMAFree(location, size);
-  } else {
-    GpuDriver::HostDeallocate(context_, location);
-  }
-}
-
 bool GpuExecutor::SynchronizeAllActivity() {
   return GpuDriver::SynchronizeContext(context_);
 }
@@ -829,22 +689,22 @@
 GpuContext* GpuExecutor::gpu_context() { return context_; }
 
 // Attempts to read the NUMA node corresponding to the GPU device's PCI bus out
-// of SysFS.
+// of SysFS. Returns -1 if it cannot.
 //
 // For anything more complicated/prod-focused than this, you'll likely want to
-// turn to gsys' topology modeling. nvmlDeviceGetMemoryAffinity could also be
-// used.
-static std::optional<int> TryToReadNumaNode(const std::string& pci_bus_id,
-                                            int device_ordinal) {
+// turn to gsys' topology modeling.
+static int TryToReadNumaNode(const std::string& pci_bus_id,
+                             int device_ordinal) {
 #if defined(PLATFORM_WINDOWS)
   // Windows support for NUMA is not currently implemented. Return node 0.
   return 0;
 #else
   VLOG(2) << "trying to read NUMA node for device ordinal: " << device_ordinal;
+  static const int kUnknownNumaNode = -1;
 
   if (pci_bus_id.empty()) {
     LOG(INFO) << "no PCI bus ID for device ordinal: " << device_ordinal;
-    return std::nullopt;
+    return kUnknownNumaNode;
   }
 
   std::string filename =
@@ -857,7 +717,7 @@
   if (file == nullptr) {
     LOG(INFO) << "could not open file to read NUMA node: " << filename
               << "\nYour kernel may have been built without NUMA support.";
-    return std::nullopt;
+    return kUnknownNumaNode;
   }
 
   std::string content;
@@ -868,6 +728,17 @@
 
   int32_t value;
   if (absl::SimpleAtoi(content, &value)) {
+    if (value < 0) {  // See http://b/18228951 for details on this path.
+      LOG(INFO) << "successful NUMA node read from SysFS had negative value ("
+                << value
+                << "), but there must be at least one NUMA node"
+                   ", so returning NUMA node zero."
+                   " See more at "
+                   "https://github.com/torvalds/linux/blob/v6.0/Documentation/"
+                   "ABI/testing/sysfs-bus-pci#L344-L355";
+      fclose(file);
+      return 0;
+    }
     fclose(file);
     return value;
   }
@@ -877,7 +748,7 @@
       << content;
 
   fclose(file);
-  return std::nullopt;
+  return kUnknownNumaNode;
 #endif
 }
 
@@ -909,24 +780,8 @@
     builder.set_pci_bus_id(pci_bus_id);
 
     // Read the NUMA node corresponding to the PCI bus ID out of sysfs.
-    std::optional<int> numa_node =
-        TryToReadNumaNode(pci_bus_id, device_ordinal);
-    if (numa_node.has_value()) {
-      if (*numa_node < 0) {  // See http://b/18228951 for details on this path.
-        LOG(INFO)
-            << "successful NUMA node read from SysFS had negative value ("
-            << *numa_node
-            << "), but there must be at least one NUMA node"
-               ", so returning NUMA node zero."
-               " See more at "
-               "https://github.com/torvalds/linux/blob/v6.0/Documentation/"
-               "ABI/testing/sysfs-bus-pci#L344-L355";
-        numa_node = 0;
-      }
-    } else {
-      numa_node = -1;
-    }
-    builder.set_numa_node(*numa_node);
+    int numa_node = TryToReadNumaNode(pci_bus_id, device_ordinal);
+    builder.set_numa_node(numa_node);
   }
 
   {
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc
index bdace57..16dab63 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc
@@ -15,19 +15,14 @@
 
 #include "xla/stream_executor/cuda/cuda_platform.h"
 
-#include <algorithm>
-#include <cstdlib>
-#include <cstring>
 #include <memory>
 #include <string>
 #include <utility>
 
-#include "absl/base/call_once.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
 #include "xla/stream_executor/cuda/cuda_platform_id.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/gpu/gpu_driver.h"
@@ -35,65 +30,16 @@
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/platform/initialize.h"
 #include "xla/stream_executor/platform_manager.h"
+#include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
 
 namespace stream_executor {
 namespace gpu {
 
-CudaPlatform::CudaPlatform()
-    : name_("CUDA"), min_numa_node_(0), limit_numa_node_(0) {}
+CudaPlatform::CudaPlatform() : name_("CUDA") {}
 
 CudaPlatform::~CudaPlatform() {}
 
-// Due to legacy issues in user code, we can't currently call InpectNumaNodes
-// at module initialization time, because non-GPU programs still include this
-// plugin via various methods, so instead, it has to be init-on-reference.
-void CudaPlatform::InspectNumaNodes() {
-  // To get NUMA node information, we need to create all executors, so we can
-  // examine their device descriptions to see their bus assignments.
-  static absl::once_flag once;
-  absl::call_once(once, [&] {
-    for (int i = 0; i < VisibleDeviceCount(); i++) {
-      StreamExecutor* exec = *ExecutorForDevice(i);
-      if (i == 0) {
-        // NUMA nodes may not start at 0, so set the minimum node  based on the
-        // first executor we see.
-        min_numa_node_ = exec->GetDeviceDescription().numa_node();
-        limit_numa_node_ = min_numa_node_ + 1;
-      } else {
-        min_numa_node_ =
-            std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
-        limit_numa_node_ = std::max(
-            limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1);
-      }
-    }
-  });
-}
-
-int CudaPlatform::BusCount() {
-  InspectNumaNodes();
-  return limit_numa_node_ - min_numa_node_;
-}
-
-int CudaPlatform::DeviceToBus(int device_ordinal) {
-  StreamExecutor* exec = *ExecutorForDevice(device_ordinal);
-  return exec->GetDeviceDescription().numa_node() - min_numa_node_;
-}
-
-absl::StatusOr<StreamExecutor*> CudaPlatform::FirstExecutorForBus(
-    int bus_ordinal) {
-  InspectNumaNodes();
-  CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range";
-  for (int i = 0; i < VisibleDeviceCount(); i++) {
-    if (DeviceToBus(i) == bus_ordinal) {
-      return *ExecutorForDevice(i);
-    }
-  }
-
-  return absl::NotFoundError(
-      absl::StrFormat("Executor for bus %d not found.", bus_ordinal));
-}
-
 Platform::Id CudaPlatform::id() const { return cuda::kCudaPlatformId; }
 
 int CudaPlatform::VisibleDeviceCount() const {
@@ -118,6 +64,12 @@
   return GetExecutor(config);
 }
 
+absl::StatusOr<StreamExecutor*> CudaPlatform::FindExisting(int ordinal) {
+  StreamExecutorConfig config;
+  config.ordinal = ordinal;
+  return executor_cache_.Get(config);
+}
+
 absl::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
     const StreamExecutorConfig& config) {
   if (config.gpu_stream) {
@@ -133,24 +85,15 @@
 absl::StatusOr<std::unique_ptr<StreamExecutor>>
 CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
   auto executor = std::make_unique<GpuExecutor>(this, config.ordinal);
-  auto init_status = executor->Init();
-  if (!init_status.ok()) {
-    return absl::InternalError(absl::StrFormat(
-        "failed initializing StreamExecutor for CUDA device ordinal %d: %s",
-        config.ordinal, init_status.ToString()));
-  }
-
+  TF_RETURN_IF_ERROR(executor->Init());
   return std::move(executor);
 }
 
 }  // namespace gpu
 
 static void InitializeCudaPlatform() {
-  // Disabling leak checking, PlatformManager does not destroy its
-  // registered platforms.
-
-  std::unique_ptr<gpu::CudaPlatform> platform(new gpu::CudaPlatform);
-  TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform)));
+  TF_CHECK_OK(
+      PlatformManager::RegisterPlatform(std::make_unique<gpu::CudaPlatform>()));
 }
 
 }  // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h
index 153282b..ac33429 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h
@@ -22,7 +22,6 @@
 #include "absl/status/statusor.h"
 #include "xla/stream_executor/executor_cache.h"
 #include "xla/stream_executor/platform.h"
-#include "xla/stream_executor/platform_manager.h"
 #include "xla/stream_executor/stream_executor.h"
 
 namespace stream_executor {
@@ -41,16 +40,6 @@
   CudaPlatform();
   ~CudaPlatform() override;
 
-  // CudaPlatform-specific functionality
-  // Returns the number of distinct buses / NUMA nodes on the machine.
-  int BusCount();
-
-  // Returns the bus/NUMA node for the specified device ordinal.
-  int DeviceToBus(int device_ordinal);
-
-  // Returns the lowest-ordinal-number StreamExecutor on the specified bus.
-  absl::StatusOr<StreamExecutor*> FirstExecutorForBus(int bus_ordinal);
-
   // Platform interface implementation:
   // Returns the same value as kCudaPlatform above.
   Platform::Id id() const override;
@@ -64,32 +53,24 @@
       int ordinal) const override;
 
   absl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+  absl::StatusOr<StreamExecutor*> FindExisting(int ordinal) override;
 
   absl::StatusOr<StreamExecutor*> GetExecutor(
       const StreamExecutorConfig& config) override;
 
+  // Returns an executor constructed with the options specified in "config",
+  // without looking in or storing to the Platform's executor cache.
+  // Ownership IS transferred to the caller.
   absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
-      const StreamExecutorConfig& config) override;
+      const StreamExecutorConfig& config);
 
  private:
-  // Determines the number of NUMA nodes and the assignment of executor to each.
-  void InspectNumaNodes();
-
   // This platform's name.
   std::string name_;
 
   // Cache of created executors.
   ExecutorCache executor_cache_;
 
-  // The smallest NUMA node value for any device managed by this machine
-  // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus
-  // ordinals. The NUMA node space occupied by GPUs is assumed to be dense./
-  int min_numa_node_;
-
-  // Larger than the NUMA node value for any device managed by this machine
-  // manager.
-  int limit_numa_node_;
-
   CudaPlatform(const CudaPlatform&) = delete;
   void operator=(const CudaPlatform&) = delete;
 };
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc
new file mode 100644
index 0000000..b9621f7
--- /dev/null
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc
@@ -0,0 +1,48 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/stream_executor/cuda/cuda_platform.h"
+
+#include <gtest/gtest.h>
+#include "absl/container/flat_hash_map.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/platform_manager.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace stream_executor::gpu {
+namespace {
+
+TEST(CudaPlatformTest, FindExistingWorks) {
+  TF_ASSERT_OK_AND_ASSIGN(Platform * platform,
+                          PlatformManager::PlatformWithName("CUDA"));
+  CHECK_GT(platform->VisibleDeviceCount(), 0);
+  for (int i = 0; i < platform->VisibleDeviceCount(); ++i) {
+    EXPECT_FALSE(platform->FindExisting(i).ok());
+  }
+  absl::flat_hash_map<int, StreamExecutor*> executors;
+  for (int i = 0; i < platform->VisibleDeviceCount(); ++i) {
+    TF_ASSERT_OK_AND_ASSIGN(auto executor, platform->ExecutorForDevice(i));
+    executors[i] = executor;
+  }
+  EXPECT_EQ(executors.size(), platform->VisibleDeviceCount());
+  for (int i = 0; i < platform->VisibleDeviceCount(); ++i) {
+    TF_ASSERT_OK_AND_ASSIGN(auto executor, platform->FindExisting(i));
+    EXPECT_EQ(executor, executors[i]);
+  }
+}
+
+}  // namespace
+}  // namespace stream_executor::gpu
diff --git a/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h b/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h
index aa59af5..0a30c1a 100644
--- a/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h
+++ b/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h
@@ -29,6 +29,11 @@
     }                                                                       \
   } while (false)
 
+// UIDs for cuDNN are unique identifiers of tensors within a graph. They are
+// assigned during graph construction; graph execution then takes a
+// {uid: buffer pointer} map defining the correspondence of buffers to tensors.
+// The UID assignment scheme can be arbitrary; at the moment, for simplicity,
+// XLA uses the scheme UID = (HLO operand number + 1).
 int CuDnnTensorUID(int offset);
 
 }  // namespace gpu
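
A minimal stand-in consistent with the "UID = (HLO operand number + 1)" scheme the comment above describes; the function body is inferred from that comment, not copied from the definition, which is not part of this diff.

#include <cassert>

// Inferred stand-in; the real definition lives elsewhere in the tree.
static int CuDnnTensorUIDSketch(int hlo_operand_number) {
  return hlo_operand_number + 1;
}

int main() {
  assert(CuDnnTensorUIDSketch(0) == 1);  // first operand -> UID 1
  assert(CuDnnTensorUIDSketch(6) == 7);  // seventh operand -> UID 7
  return 0;
}
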
diff --git a/third_party/xla/xla/stream_executor/cuda/delay_kernel.h b/third_party/xla/xla/stream_executor/cuda/delay_kernel.h
index 09aad2f..016639d 100644
--- a/third_party/xla/xla/stream_executor/cuda/delay_kernel.h
+++ b/third_party/xla/xla/stream_executor/cuda/delay_kernel.h
@@ -18,7 +18,6 @@
 
 #include "absl/status/statusor.h"
 #include "xla/stream_executor/gpu/gpu_semaphore.h"
-#include "xla/stream_executor/gpu/gpu_stream.h"
 #include "xla/stream_executor/stream.h"
 
 namespace stream_executor::gpu {
diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_support.cc b/third_party/xla/xla/stream_executor/cuda/nvjitlink_support.cc
index 1803697..bedd416 100644
--- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_support.cc
+++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_support.cc
@@ -16,5 +16,7 @@
 #include "xla/stream_executor/cuda/nvjitlink_support.h"
 
 namespace stream_executor {
-bool IsLibNvJitLinkSupported() { return LIBNVJITLINK_SUPPORT; }
+bool IsLibNvJitLinkSupported() {
+  return LIBNVJITLINK_SUPPORT && CUDA_SUPPORTS_NVJITLINK;
+}
 }  // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compilation_method.h b/third_party/xla/xla/stream_executor/cuda/ptx_compilation_method.h
index d6e28e9..12d5ae5 100644
--- a/third_party/xla/xla/stream_executor/cuda/ptx_compilation_method.h
+++ b/third_party/xla/xla/stream_executor/cuda/ptx_compilation_method.h
@@ -22,6 +22,7 @@
 namespace stream_executor {
 
 enum class PtxCompilationMethod {
+  kNvJitLink,
   kNvPtxCompiler,
   kPtxas,
 };
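
A hedged usage sketch of the new enumerator, assuming a recent Abseil where absl::StrCat picks up AbslStringify and using the include path shown in this diff: the kNvJitLink case added to AbslStringify in the next hunk lets the value print as "NvJitLink".

#include <iostream>

#include "absl/strings/str_cat.h"
#include "xla/stream_executor/cuda/ptx_compilation_method.h"

int main() {
  using stream_executor::PtxCompilationMethod;
  PtxCompilationMethod method = PtxCompilationMethod::kNvJitLink;
  // AbslStringify lets StrCat (and logging/StrFormat's %v) render the enum.
  std::cout << absl::StrCat("compiling PTX with ", method) << "\n";
  return 0;
}
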
@@ -30,6 +31,9 @@
 static void AbslStringify(Sink& sink,
                           const PtxCompilationMethod& compilation_method) {
   switch (compilation_method) {
+    case PtxCompilationMethod::kNvJitLink:
+      sink.Append("NvJitLink");
+      break;
     case PtxCompilationMethod::kNvPtxCompiler:
       sink.Append("NvPtxCompiler");
       break;
diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc
index c295833..aae9406 100644
--- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc
+++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc
@@ -19,7 +19,6 @@
 
 #include <cstdint>
 #include <string>
-#include <utility>
 #include <vector>
 
 #include <gmock/gmock.h>
diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_linking_method.h b/third_party/xla/xla/stream_executor/cuda/ptx_linking_method.h
index 56dcdf1..aafc36d 100644
--- a/third_party/xla/xla/stream_executor/cuda/ptx_linking_method.h
+++ b/third_party/xla/xla/stream_executor/cuda/ptx_linking_method.h
@@ -26,11 +26,15 @@
   kNone,
   kNvLink,
   kDriver,
+  kNvJitLink,
 };
 
 template <typename Sink>
 void AbslStringify(Sink& sink, const PtxLinkingMethod& method) {
   switch (method) {
+    case PtxLinkingMethod::kNvJitLink:
+      sink.Append("NvJitLink");
+      break;
     case PtxLinkingMethod::kNvLink:
       sink.Append("NvLink");
       break;
diff --git a/third_party/xla/xla/stream_executor/dnn.cc b/third_party/xla/xla/stream_executor/dnn.cc
index 5a674a0..951b2f6 100644
--- a/third_party/xla/xla/stream_executor/dnn.cc
+++ b/third_party/xla/xla/stream_executor/dnn.cc
@@ -41,7 +41,7 @@
 #include "xla/stream_executor/data_type.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/numeric_options.h"
-#include "tsl/lib/strings/proto_serialization.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 #include "tsl/platform/ml_dtypes.h"
 #include "tsl/protobuf/dnn.pb.h"
 
@@ -249,42 +249,6 @@
   return absl::UnimplementedError("NormRunnerFromDesc not implemented.");
 }
 
-absl::StatusOr<std::unique_ptr<const dnn::FusedMHARunner>>
-DnnSupport::FusedMHARunnerFromDesc(
-    Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
-    const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor,
-    const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor,
-    const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor,
-    const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor,
-    const dnn::TensorDescriptor& output_descriptor,
-    std::optional<dnn::TensorDescriptor> activation_descriptor,
-    std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
-    std::optional<double> dropout_rate, std::optional<int64_t> seed,
-    dnn::FMHAMaskKind mask_type) {
-  return absl::UnimplementedError("FusedMHARunnerFromDesc not implemented.");
-}
-
-absl::StatusOr<std::unique_ptr<const dnn::FusedMHABackwardRunner>>
-DnnSupport::FusedMHABackwardRunnerFromDesc(
-    Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
-    const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor,
-    const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor,
-    const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor,
-    const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor,
-    const MatmulTensorDescriptor& d_output_descriptor,
-    const TensorDescriptor& d_bmm1_lhs_descriptor,
-    const TensorDescriptor& d_bmm1_rhs_descriptor,
-    const TensorDescriptor& d_bmm2_rhs_descriptor,
-    std::optional<dnn::TensorDescriptor> d_s_descriptor,
-    std::optional<dnn::TensorDescriptor> d_bias_descriptor,
-    std::optional<dnn::TensorDescriptor> fwd_output_descriptor,
-    std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
-    std::optional<double> dropout_rate, std::optional<int64_t> seed,
-    dnn::FMHAMaskKind mask_type, bool force_deterministic) {
-  return absl::UnimplementedError(
-      "FusedMHABackwardRunnerFromDesc not implemented.");
-}
-
 bool DnnSupport::GetMIOpenConvolveAlgorithms(
     dnn::ConvolutionKind /*kind*/, dnn::DataType /*element_type*/,
     dnn::DataType /*output_type*/, Stream* /*stream*/,
diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h
index 72f4603..a2e1cd6 100644
--- a/third_party/xla/xla/stream_executor/dnn.h
+++ b/third_party/xla/xla/stream_executor/dnn.h
@@ -993,30 +993,6 @@
 using NormSignature = void(std::vector<DeviceMemoryBase>);
 using NormRunner = OpRunner<NormSignature>;
 
-using FusedMHASignature = void(DeviceMemoryBase /*BMM1_inputA_data*/,
-                               DeviceMemoryBase /* BMM1_inputB_data */,
-                               DeviceMemoryBase /* BMM2_inputA_data */,
-                               DeviceMemoryBase /* output_data */,
-                               DeviceMemoryBase /* bias_data */,
-                               DeviceMemoryBase /* activation_data */,
-                               DeviceMemoryBase /* seqlen_q_data */,
-                               DeviceMemoryBase /* seqlen_k_data */);
-using FusedMHARunner = OpRunner<FusedMHASignature>;
-
-using FusedMHABackwardSignature = void(
-    DeviceMemoryBase /* BMM1_GRAD_GEMM1_inputA_data */,
-    DeviceMemoryBase /* BMM1_GRAD_GEMM2_inputB_data */,
-    DeviceMemoryBase /* BMM2_GRAD_GEMM1_inputA_data */,
-    DeviceMemoryBase /* BMM2_GRAD_GEMM2_inputB_data */,
-    DeviceMemoryBase /* d_output_data */,
-    DeviceMemoryBase /* d_BMM1_inputA_data */,
-    DeviceMemoryBase /* d_BMM1_inputB_data */,
-    DeviceMemoryBase /* d_BMM2_inputB_data */, DeviceMemoryBase /* d_S_data */,
-    DeviceMemoryBase /* d_bias_data */, DeviceMemoryBase /* fwd_output_data */,
-    DeviceMemoryBase /* bias_data */, DeviceMemoryBase /* seqlen_q_data */,
-    DeviceMemoryBase /* seqlen_k_data */);
-using FusedMHABackwardRunner = OpRunner<FusedMHABackwardSignature>;
-
 // Describes the configuration for the algorithms that will used.
 //
 // Arguments:
@@ -1257,11 +1233,7 @@
   DnnGraph() = default;
   virtual ~DnnGraph() = default;
 
-  // Returns non-OK status on hard failures (incorrectly constructed graph,
-  // anything else unexpected),
-  // false on expected ones (graph is valid but not supported),
-  // true on success.
-  virtual absl::Status Prepare(DnnSupport&) = 0;
+  virtual absl::Status Prepare(DnnSupport&, const NumericOptions&) = 0;
   virtual absl::Status Build(DnnSupport&, std::optional<int64_t> plan_id) = 0;
   virtual absl::Status Execute(Stream& stream,
                                absl::Span<DeviceMemoryBase> operands) const = 0;
@@ -1735,37 +1707,6 @@
     return absl::UnimplementedError("Graph support requires cuDNN >= 8.1.");
   };
 
-  virtual absl::StatusOr<std::unique_ptr<const FusedMHARunner>>
-  FusedMHARunnerFromDesc(
-      Stream* stream, const AlgorithmDesc& algorithm_desc,
-      const MatmulTensorDescriptor& bmm1_lhs_descriptor,
-      const MatmulTensorDescriptor& bmm1_rhs_descriptor,
-      const MatmulTensorDescriptor& bmm2_rhs_descriptor,
-      const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor,
-      const TensorDescriptor& output_descriptor,
-      std::optional<TensorDescriptor> activation_descriptor,
-      std::optional<TensorDescriptor> bias_descriptor, double scale,
-      std::optional<double> dropout_rate, std::optional<int64_t> seed,
-      dnn::FMHAMaskKind mask_type);
-
-  virtual absl::StatusOr<std::unique_ptr<const FusedMHABackwardRunner>>
-  FusedMHABackwardRunnerFromDesc(
-      Stream* stream, const AlgorithmDesc& algorithm_desc,
-      const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor,
-      const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor,
-      const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor,
-      const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor,
-      const MatmulTensorDescriptor& d_output_descriptor,
-      const TensorDescriptor& d_bmm1_lhs_descriptor,
-      const TensorDescriptor& d_bmm1_rhs_descriptor,
-      const TensorDescriptor& d_bmm2_rhs_descriptor,
-      std::optional<TensorDescriptor> d_s_descriptor,
-      std::optional<TensorDescriptor> d_bias_descriptor,
-      std::optional<TensorDescriptor> fwd_output_descriptor,
-      std::optional<TensorDescriptor> bias_descriptor, double scale,
-      std::optional<double> dropout_rate, std::optional<int64_t> seed,
-      dnn::FMHAMaskKind mask_type, bool force_deterministic);
-
   virtual bool GetMIOpenConvolveAlgorithms(
       ConvolutionKind kind, DataType element_type, DataType output_type,
       Stream* stream, const BatchDescriptor& input_descriptor,
diff --git a/third_party/xla/xla/stream_executor/executor_cache.cc b/third_party/xla/xla/stream_executor/executor_cache.cc
index eae7206..341af6f 100644
--- a/third_party/xla/xla/stream_executor/executor_cache.cc
+++ b/third_party/xla/xla/stream_executor/executor_cache.cc
@@ -25,11 +25,12 @@
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/stream_executor.h"
 #include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
 
 namespace stream_executor {
 
 ExecutorCache::ExecutorCache() = default;
-ExecutorCache::~ExecutorCache() { DestroyAllExecutors(); }
+ExecutorCache::~ExecutorCache() = default;
 
 absl::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
     const StreamExecutorConfig& config, const ExecutorFactory& factory) {
@@ -40,85 +41,36 @@
     return fast_result;
   }
 
-  Entry* entry = nullptr;
-  {
-    absl::MutexLock lock{&mutex_};
-    entry = &cache_[config.ordinal];
-    // Release the map lock; the address of 'entry' is stable because
-    // absl::node_hash_map guarantees reference stability.
-  }
-
-  // Acquire the per-Entry mutex without holding the map mutex. Initializing
-  // an Executor may be expensive, so we want to allow concurrent
-  // initialization of different entries.
-  absl::MutexLock lock{&entry->configurations_mutex};
-  for (const auto& iter : entry->configurations) {
-    VLOG(2) << "hit in cache";
-    return iter.second.get();
-  }
-
   VLOG(2) << "building executor";
-  absl::StatusOr<std::unique_ptr<StreamExecutor>> result = factory();
-  if (!result.ok()) {
-    VLOG(2) << "failed to get build executor: " << result.status();
-    // If construction failed, leave the cache Entry around, but with a null
-    // executor.
-    return result.status();
-  }
-  entry->configurations.emplace_back(config, std::move(result.value()));
-  return entry->configurations.back().second.get();
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<StreamExecutor> result, factory());
+  auto returned_executor = result.get();
+  absl::MutexLock lock(&mutex_);
+  cache_.emplace(config.ordinal, std::move(result));
+  return returned_executor;
 }
 
 absl::StatusOr<StreamExecutor*> ExecutorCache::Get(
     const StreamExecutorConfig& config) {
-  Entry* entry = nullptr;
-  {
-    absl::ReaderMutexLock lock{&mutex_};
+  absl::ReaderMutexLock lock{&mutex_};
 
-    // If gpu stream is not nullptr we have to find StreamExecutor that owns it,
-    // and return NOT_FOUND error if we can't find it.
-    if (config.gpu_stream) {
-      for (auto& [ordinal, e] : cache_) {
-        absl::ReaderMutexLock l{&e.configurations_mutex};
-        for (auto& [c, executor] : e.configurations) {
-          if (executor->FindAllocatedStream(config.gpu_stream)) {
-            return executor.get();
-          }
-        }
+  // If gpu stream is not nullptr we have to find StreamExecutor that owns it,
+  // and return NOT_FOUND error if we can't find it.
+  if (config.gpu_stream) {
+    for (auto& [ordinal, executor] : cache_) {
+      if (executor->FindAllocatedStream(config.gpu_stream)) {
+        return executor.get();
       }
-      return absl::NotFoundError(
-          absl::StrFormat("No executors own stream %p", config.gpu_stream));
     }
-
-    if (auto it = cache_.find(config.ordinal); it != cache_.end()) {
-      entry = &it->second;
-    } else {
-      return absl::NotFoundError(absl::StrFormat(
-          "No executors registered for ordinal %d", config.ordinal));
-    }
+    return absl::NotFoundError(
+        absl::StrFormat("No executors own stream %p", config.gpu_stream));
   }
 
-  absl::ReaderMutexLock lock{&entry->configurations_mutex};
-  if (entry->configurations.empty()) {
-    return absl::NotFoundError(absl::StrFormat(
-        "No executors registered for ordinal %d", config.ordinal));
+  if (auto it = cache_.find(config.ordinal); it != cache_.end()) {
+    return it->second.get();
   }
 
-  for (auto& [entry_config, entry_executor] : entry->configurations) {
-    return entry_executor.get();
-  }
-
-  return absl::NotFoundError("No executor found with a matching config.");
-}
-
-void ExecutorCache::DestroyAllExecutors() {
-  absl::MutexLock lock{&mutex_};
-  cache_.clear();
-}
-
-ExecutorCache::Entry::~Entry() {
-  absl::MutexLock lock{&configurations_mutex};
-  configurations.clear();
+  return absl::NotFoundError(absl::StrFormat(
+      "No executors registered for ordinal %d", config.ordinal));
 }
 
 }  // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/executor_cache.h b/third_party/xla/xla/stream_executor/executor_cache.h
index 6e7f32e..ae62c6d 100644
--- a/third_party/xla/xla/stream_executor/executor_cache.h
+++ b/third_party/xla/xla/stream_executor/executor_cache.h
@@ -18,11 +18,9 @@
 
 #include <functional>
 #include <memory>
-#include <utility>
-#include <vector>
 
 #include "absl/base/thread_annotations.h"
-#include "absl/container/node_hash_map.h"
+#include "absl/container/flat_hash_map.h"
 #include "absl/status/statusor.h"
 #include "absl/synchronization/mutex.h"
 #include "xla/stream_executor/platform.h"
@@ -52,33 +50,13 @@
   // has been created), or a NOT_FOUND status.
   absl::StatusOr<StreamExecutor*> Get(const StreamExecutorConfig& config);
 
-  // Destroys all Executors and clears the cache.
-  // Performs no synchronization with the executors - undefined behavior may
-  // occur if any executors are active!
-  void DestroyAllExecutors();
-
  private:
-  // Each Entry contains zero or more cached executors for a device ordinal.
-  struct Entry {
-    ~Entry();
-
-    // Mutex that guards the contents of each entry. The 'mutex_' of the
-    // ExecutorCache class protects both the 'cache_' and the existence of each
-    // Entry, but not the Entry's contents. 'configurations_mutex' protects the
-    // contents of the entry after 'mutex_' has been dropped.
-    absl::Mutex configurations_mutex;
-
-    // Vector of cached {config, executor} pairs.
-    std::vector<
-        std::pair<StreamExecutorConfig, std::unique_ptr<StreamExecutor>>>
-        configurations ABSL_GUARDED_BY(configurations_mutex);
-  };
-
-  // Maps ordinal number to a list of cached executors for that ordinal.
-  // We key off of ordinal (instead of just looking up all fields in the
-  // StreamExecutorConfig) for a slight improvement in lookup time.
+  // Protects cache_.
   absl::Mutex mutex_;
-  absl::node_hash_map<int, Entry> cache_ ABSL_GUARDED_BY(mutex_);
+
+  // Maps ordinal number to a cached executor for that ordinal.
+  absl::flat_hash_map<int, std::unique_ptr<StreamExecutor>> cache_
+      ABSL_GUARDED_BY(mutex_);
 
   ExecutorCache(const ExecutorCache&) = delete;
   void operator=(const ExecutorCache&) = delete;
diff --git a/third_party/xla/xla/stream_executor/executor_cache_test.cc b/third_party/xla/xla/stream_executor/executor_cache_test.cc
new file mode 100644
index 0000000..71e9f72
--- /dev/null
+++ b/third_party/xla/xla/stream_executor/executor_cache_test.cc
@@ -0,0 +1,128 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/stream_executor/executor_cache.h"
+
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/stream_executor/mock_stream_executor.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/stream.h"
+#include "tsl/platform/statusor.h"
+
+namespace stream_executor {
+namespace {
+
+TEST(ExecutorCacheTest, GetOnEmptyCacheFails) {
+  ExecutorCache cache;
+  StreamExecutorConfig config;
+  config.ordinal = 0;
+  EXPECT_FALSE(cache.Get(config).ok());
+}
+
+TEST(ExecutorCacheTest, GetViaStreamOnEmptyCacheFails) {
+  ExecutorCache cache;
+  StreamExecutorConfig config;
+  config.ordinal = 0;
+  config.gpu_stream = reinterpret_cast<void *>(0x1234);
+  EXPECT_FALSE(cache.Get(config).ok());
+}
+
+TEST(ExecutorCacheTest, GetOrCreateConstructsAndRepeatedlyReturns) {
+  ExecutorCache cache;
+  StreamExecutorConfig config;
+  config.ordinal = 0;
+  StreamExecutor *created = nullptr;
+  auto factory = [&created]() {
+    auto executor = std::make_unique<MockStreamExecutor>();
+    created = executor.get();
+    return executor;
+  };
+  TF_ASSERT_OK_AND_ASSIGN(auto executor, cache.GetOrCreate(config, factory));
+  EXPECT_EQ(executor, created);
+  TF_ASSERT_OK_AND_ASSIGN(auto found, cache.GetOrCreate(config, factory));
+  EXPECT_EQ(found, created);
+  TF_ASSERT_OK_AND_ASSIGN(found, cache.Get(config));
+  EXPECT_EQ(found, created);
+}
+
+TEST(ExecutorCacheTest, GetViaStreamFailsIfNotFound) {
+  ExecutorCache cache;
+  StreamExecutorConfig config;
+  config.ordinal = 0;
+  StreamExecutor *created = nullptr;
+  void *expected_stream = reinterpret_cast<void *>(0x1234);
+  auto factory = [&created, &expected_stream]() {
+    auto executor = std::make_unique<MockStreamExecutor>();
+    EXPECT_CALL(*executor, FindAllocatedStream(expected_stream))
+        .WillRepeatedly(testing::Return(nullptr));
+    created = executor.get();
+    return executor;
+  };
+
+  // Create the executor.
+  TF_ASSERT_OK_AND_ASSIGN(auto executor, cache.GetOrCreate(config, factory));
+  EXPECT_EQ(executor, created);
+  // Now look for the expected stream, and don't expect to find it.
+  config.gpu_stream = expected_stream;
+  EXPECT_FALSE(cache.Get(config).ok());
+}
+
+TEST(ExecutorCacheTest, GetViaStreamWorksOnSecondStream) {
+  ExecutorCache cache;
+  StreamExecutorConfig config;
+  config.ordinal = 0;
+  StreamExecutor *created = nullptr;
+  Stream *expected_stream = reinterpret_cast<Stream *>(0x1234);
+
+  // Create a factory that will make the second StreamExecutor find the
+  // expected_stream.
+  auto factory = [&created, &expected_stream]() {
+    static int count = 0;
+    auto executor = std::make_unique<MockStreamExecutor>();
+    if (count != 1) {
+      EXPECT_CALL(*executor, FindAllocatedStream(expected_stream))
+          .WillRepeatedly(testing::Return(nullptr));
+    } else {
+      created = executor.get();
+      EXPECT_CALL(*executor, FindAllocatedStream(expected_stream))
+          .WillRepeatedly(testing::Invoke(
+              [expected_stream](void *stream) { return expected_stream; }));
+    }
+    ++count;
+    return executor;
+  };
+
+  // Create four executors.
+  std::vector<StreamExecutor *> created_executors;
+  for (int i = 0; i < 4; ++i) {
+    config.ordinal = i;
+    TF_ASSERT_OK_AND_ASSIGN(auto executor, cache.GetOrCreate(config, factory));
+    EXPECT_NE(executor, nullptr);
+    created_executors.push_back(executor);
+  }
+  EXPECT_EQ(created_executors.size(), 4);
+  // Now look for the expected stream, and expect the second executor to be
+  // the one that owns it.
+  config.gpu_stream = expected_stream;
+  TF_ASSERT_OK_AND_ASSIGN(auto found, cache.Get(config));
+  EXPECT_EQ(found, created);
+}
+
+}  // namespace
+}  // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD
index ae1ad8c..2d763e5 100644
--- a/third_party/xla/xla/stream_executor/gpu/BUILD
+++ b/third_party/xla/xla/stream_executor/gpu/BUILD
@@ -228,14 +228,12 @@
         "//xla/stream_executor:stream_executor_h",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/functional:any_invocable",
         "@com_google_absl//absl/numeric:int128",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:platform_port",
         "@local_tsl//tsl/platform:thread_annotations",
     ],
 )
@@ -244,8 +242,7 @@
     name = "gpu_helpers_header",
     hdrs = ["gpu_helpers.h"],
     deps = [
-        ":gpu_types_header",
-        "@local_tsl//tsl/platform:logging",
+        "//xla/stream_executor:device_memory",
     ],
 )
 
@@ -311,10 +308,13 @@
     deps = [
         ":gpu_event_header",
         ":gpu_executor_header",
+        ":gpu_kernel_header",
         ":gpu_types_header",
         "//xla/stream_executor:device_memory",
         "//xla/stream_executor:event",
         "//xla/stream_executor:event_based_timer",
+        "//xla/stream_executor:kernel",
+        "//xla/stream_executor:launch_dim",
         "//xla/stream_executor:platform",
         "//xla/stream_executor:stream",
         "//xla/stream_executor:stream_common",
@@ -332,10 +332,13 @@
         ":gpu_driver_header",
         ":gpu_event_header",
         ":gpu_executor_header",
+        ":gpu_kernel_header",
         ":gpu_types_header",
         "//xla/stream_executor:device_memory",
         "//xla/stream_executor:event",
         "//xla/stream_executor:event_based_timer",
+        "//xla/stream_executor:kernel",
+        "//xla/stream_executor:launch_dim",
         "//xla/stream_executor:platform",
         "//xla/stream_executor:stream",
         "//xla/stream_executor:stream_common",
@@ -346,6 +349,7 @@
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/strings:string_view",
         "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/profiler/lib:nvtx_utils",
     ],
 )
@@ -378,7 +382,6 @@
         ":gpu_stream",
         ":gpu_types_header",
         "//xla/stream_executor",
-        "//xla/stream_executor:event",
         "//xla/stream_executor:event_based_timer",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/log",
@@ -387,7 +390,6 @@
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/time",
-        "@com_google_absl//absl/utility",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:statusor",
     ],
@@ -557,8 +559,6 @@
     name = "redzone_allocator_test",
     srcs = ["redzone_allocator_test.cc"],
     backends = ["gpu"],
-    # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
-    tags = ["gpu"],
     deps = [
         ":gpu_asm_opts",
         ":gpu_init",
@@ -604,11 +604,7 @@
     name = "gpu_cudamallocasync_allocator_test",
     srcs = ["gpu_cudamallocasync_allocator_test.cc"],
     backends = ["gpu_any"],
-    tags = [
-        # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
-        "gpu",
-        "no_rocm",
-    ],
+    tags = ["no_rocm"],
     deps = [
         ":gpu_cudamallocasync_allocator",
         ":gpu_stream",
@@ -740,8 +736,6 @@
     local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
         "TENSORFLOW_USE_ROCM=1",
     ]),
-    # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
-    tags = ["gpu"],
     deps = [
         "//xla/stream_executor",
         "//xla/stream_executor:device_memory",
@@ -765,8 +759,6 @@
     local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
         "TENSORFLOW_USE_ROCM=1",
     ]),
-    # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
-    tags = ["gpu"],
     deps = [
         "//xla/stream_executor",
         "//xla/stream_executor/host:host_platform",
@@ -789,8 +781,6 @@
     local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
         "TENSORFLOW_USE_ROCM=1",
     ]),
-    # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly
-    tags = ["gpu"],
     deps = [
         "//xla/service:platform_util",
         "//xla/stream_executor:platform",
@@ -799,7 +789,6 @@
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
-        "@local_tsl//tsl/platform:platform_port",
         "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:test",
     ] + if_cuda([
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
index b4bf7eb..20caccb 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
@@ -25,6 +25,7 @@
 #include <vector>
 
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "xla/stream_executor/blas.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/host_or_device_scalar.h"
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
index d5d9ebd..c672b11 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
@@ -1006,6 +1006,8 @@
       } else {
         TF_RETURN_IF_ERROR(retry);
       }
+    } else {
+      TF_RETURN_IF_ERROR(instantiated);
     }
 
     uint64_t end_nanos = tsl::Env::Default()->NowNanos();
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc
index cd32935..0376a5a 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc
@@ -776,6 +776,69 @@
   ASSERT_EQ(dst, expected);
 }
 
+TEST(GpuCommandBufferTest, ConditionalIfWithMemset) {
+#if CUDA_VERSION < 12040
+  GTEST_SKIP() << "ConditionalsWithMemset are not supported before 12.4.1.";
+#endif
+  Platform* platform = GpuPlatform();
+
+  StreamExecutor* executor = platform->ExecutorForDevice(0).value();
+
+  TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
+
+  int64_t length = 4;
+  int64_t byte_length = sizeof(int32_t) * length;
+
+  // Prepare arguments: a=0, pred=true
+  DeviceMemory<bool> pred = executor->AllocateArray<bool>(1, 0);
+  DeviceMemory<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
+
+  constexpr bool kTrue = true;
+  TF_ASSERT_OK(stream->Memcpy(&pred, &kTrue, 1));
+  TF_ASSERT_OK(stream->Memset32(&a, 0, byte_length));
+
+  // if (pred == true) memset(&a, ...);
+  CommandBuffer::Builder then_builder = [&](CommandBuffer* then_cmd) {
+    return then_cmd->Memset(&a, uint8_t{1}, byte_length);
+  };
+
+  // Create a command buffer with a single conditional operation.
+  TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer,
+                          executor->CreateCommandBuffer(primary));
+  TF_ASSERT_OK(cmd_buffer->If(pred, then_builder));
+  TF_ASSERT_OK(cmd_buffer->Finalize());
+
+  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
+
+  // Copy `a` data back to host.
+  std::vector<int32_t> dst(length, 42);
+  TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length));
+
+  std::vector<int32_t> expected(length, 1 << 24 | 1 << 16 | 1 << 8 | 1);
+  ASSERT_EQ(dst, expected);
+
+  // Prepare argument for graph update: b = 0
+  DeviceMemory<int32_t> b = executor->AllocateArray<int32_t>(length, 0);
+  TF_ASSERT_OK(stream->MemZero(&a, byte_length));
+
+  // if (pred == true) memset(&b, ...);
+  then_builder = [&](CommandBuffer* then_cmd) {
+    return then_cmd->Memset(&b, uint8_t{1}, byte_length);
+  };
+
+  // Update command buffer with a conditional to use new builder.
+  TF_ASSERT_OK(cmd_buffer->Update());
+  TF_ASSERT_OK(cmd_buffer->If(pred, then_builder));
+  TF_ASSERT_OK(cmd_buffer->Finalize());
+
+  TF_ASSERT_OK(cmd_buffer->Submit(stream.get()));
+
+  // Copy `b` data back to host.
+  std::fill(dst.begin(), dst.end(), 42);
+  TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length));
+  ASSERT_EQ(dst, expected);
+}
+
 TEST(GpuCommandBufferTest, ConditionalIfElse) {
   if (!IsAtLeastCuda12300()) {
     GTEST_SKIP() << "CUDA graph conditionals are not supported";
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h
index 599480c..94cff46 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h
@@ -31,6 +31,7 @@
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "xla/stream_executor/gpu/gpu_types.h"
+#include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/stream_executor.h"
 
 namespace stream_executor {
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h
index 116120d..bfb7e57 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h
@@ -1,4 +1,3 @@
-#include "xla/stream_executor/event_based_timer.h"
 /* Copyright 2019 The OpenXLA Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -31,10 +30,10 @@
 #include <unordered_map>
 #include <utility>
 #include <variant>
+#include <vector>
 
 #include "absl/base/thread_annotations.h"
 #include "absl/container/flat_hash_map.h"
-#include "absl/functional/any_invocable.h"
 #include "absl/numeric/int128.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
@@ -46,6 +45,7 @@
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/dnn.h"
 #include "xla/stream_executor/event.h"
+#include "xla/stream_executor/event_based_timer.h"
 #include "xla/stream_executor/fft.h"
 #include "xla/stream_executor/gpu/gpu_collectives.h"
 #include "xla/stream_executor/gpu/gpu_driver.h"
@@ -59,7 +59,6 @@
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/stream_executor.h"
 #include "xla/stream_executor/stream_executor_common.h"
-#include "tsl/platform/numa.h"
 #include "tsl/platform/thread_annotations.h"
 
 namespace stream_executor {
@@ -113,8 +112,7 @@
         device_ordinal_(device_ordinal),
         cc_major_(0),
         cc_minor_(0),
-        version_(0),
-        numa_node_(tsl::port::kNUMANoAffinity) {}
+        version_(0) {}
 
   // See the corresponding StreamExecutor methods for method comments on the
   // following overrides.
@@ -140,15 +138,6 @@
   absl::StatusOr<std::shared_ptr<DeviceMemoryBase>> CreateOrShareConstant(
       Stream* stream, absl::Span<const uint8_t> content) override;
 
-  absl::Status Launch(Stream* stream, const ThreadDim& thread_dims,
-                      const BlockDim& block_dims, const Kernel& kernel,
-                      const KernelArgs& args) override;
-
-  absl::Status Launch(Stream* stream, const ThreadDim& thread_dims,
-                      const BlockDim& block_dims,
-                      const ClusterDim& cluster_dims, const Kernel& kernel,
-                      const KernelArgs& args) override;
-
   DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override;
 
   void Deallocate(DeviceMemoryBase* mem) override;
@@ -169,10 +158,23 @@
     return GpuCollectives::CollectiveMemoryDeallocate(context_, location);
   }
 
+  // CUDA allocation/registration functions are necessary because the driver
+  // internally sets up buffers for DMA operations (and page locks them).
+  // There's no external interface for us to otherwise control these DMA
+  // settings.
   absl::StatusOr<std::unique_ptr<MemoryAllocation>> HostMemoryAllocate(
-      uint64_t size) override;
+      uint64_t size) override {
+    auto* buffer = GpuDriver::HostAllocate(context_, size);
+    if (buffer == nullptr && size > 0) {
+      return absl::InternalError(
+          absl::StrFormat("Failed to allocate HostMemory of size %d", size));
+    }
+    return std::make_unique<HostMemoryAllocation>(buffer, size, this);
+  }
 
-  void HostMemoryDeallocate(void* location, uint64_t size) override;
+  void HostMemoryDeallocate(void* location) override {
+    return GpuDriver::HostDeallocate(context_, location);
+  }
 
   absl::StatusOr<MemoryType> GetPointerMemorySpace(const void* ptr) override {
     return GpuDriver::GetPointerMemorySpace(
@@ -305,11 +307,6 @@
   absl::Status LoadModuleFromHsaco(const char* hsaco, GpuModuleHandle* module)
       TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
 
-  absl::Status Launch(Stream* stream, const ThreadDim& thread_dims,
-                      const BlockDim& block_dims,
-                      const std::optional<ClusterDim>& cluster_dims,
-                      const Kernel& kernel, const KernelArgs& args);
-
   bool UnloadGpuBinary(const void* gpu_binary)
       TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
 
@@ -369,9 +366,6 @@
   // GPU ISA version for device_.
   int version_;
 
-  // NUMA node for device_.
-  int numa_node_;
-
   // Type erased XLA specific state attached to GpuExecutor.
   Object xla_state_;
 
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_executor_test.cc
index 9ac7be1..c3c67bc 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_executor_test.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor_test.cc
@@ -20,7 +20,6 @@
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/platform_manager.h"
 #include "xla/stream_executor/stream_executor.h"
-#include "tsl/platform/numa.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
 
@@ -55,32 +54,4 @@
   executor->Deallocate(&mem);
 }
 
-using HostMemoryAllocateTest = GpuExecutorTest;
-
-TEST_F(HostMemoryAllocateTest, Numa) {
-  Platform* platform = GetPlatform();
-  const uint64_t kSize = 1024;
-  const int num_devices = platform->VisibleDeviceCount();
-  for (int device = 0; device < num_devices; ++device) {
-    TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor,
-                            platform->ExecutorForDevice(device));
-    ASSERT_TRUE(executor);
-    TF_ASSERT_OK_AND_ASSIGN(auto device_desc,
-                            executor->CreateDeviceDescription());
-    ASSERT_TRUE(device_desc);
-    TF_ASSERT_OK_AND_ASSIGN(auto host_ptr, executor->HostMemoryAllocate(kSize));
-    ASSERT_TRUE(host_ptr);
-    EXPECT_NE(host_ptr->opaque(), nullptr);
-    const int numa_node = tsl::port::NUMAGetMemAffinity(host_ptr->opaque());
-    if (numa_node == tsl::port::kNUMANoAffinity) {
-      // Could be because `executor` could not determine its own NUMA node, in
-      // which case numa_node() will be -1 or 0, depending on the failure mode.
-      EXPECT_LE(device_desc->numa_node(), 0);
-      EXPECT_GE(device_desc->numa_node(), -1);
-    } else {
-      EXPECT_EQ(device_desc->numa_node(), numa_node);
-    }
-  }
-}
-
 }  // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h b/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h
index 62db127..187d882 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h
@@ -23,17 +23,10 @@
 
 #include <stddef.h>
 
-#include <complex>
-#include <cstdint>
-
-#include "xla/stream_executor/gpu/gpu_types.h"
-#include "tsl/platform/logging.h"
+#include "xla/stream_executor/device_memory.h"
 
 namespace stream_executor {
 
-template <typename ElemT>
-class DeviceMemory;
-
 namespace gpu {
 
 // Converts a const DeviceMemory reference to its underlying typed pointer in
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc
index 7068265..b257ffa 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc
@@ -17,6 +17,7 @@
 
 #include <cstdint>
 #include <memory>
+#include <optional>
 #include <utility>
 #include <variant>
 
@@ -31,10 +32,14 @@
 #include "xla/stream_executor/gpu/gpu_driver.h"
 #include "xla/stream_executor/gpu/gpu_event.h"
 #include "xla/stream_executor/gpu/gpu_executor.h"
+#include "xla/stream_executor/gpu/gpu_kernel.h"
 #include "xla/stream_executor/gpu/gpu_types.h"
+#include "xla/stream_executor/kernel.h"
+#include "xla/stream_executor/launch_dim.h"
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/stream.h"
 #include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 #include "tsl/profiler/lib/nvtx_utils.h"
 
 namespace stream_executor {
@@ -195,6 +200,83 @@
   return parent_->CreateEventBasedTimer(this, use_delay_kernel);
 }
 
+absl::Status GpuStream::Launch(const ThreadDim& thread_dims,
+                               const BlockDim& block_dims, const Kernel& kernel,
+                               const KernelArgs& args) {
+  return Launch(thread_dims, block_dims, std::nullopt, kernel, args);
+}
+
+absl::Status GpuStream::Launch(const ThreadDim& thread_dims,
+                               const BlockDim& block_dims,
+                               const ClusterDim& cluster_dims,
+                               const Kernel& kernel, const KernelArgs& args) {
+  return Launch(thread_dims, block_dims, std::make_optional(cluster_dims),
+                kernel, args);
+}
+
+absl::Status GpuStream::Launch(const ThreadDim& thread_dims,
+                               const BlockDim& block_dims,
+                               const std::optional<ClusterDim>& cluster_dims,
+                               const Kernel& kernel, const KernelArgs& args) {
+  const GpuKernel* gpu_kernel = AsGpuKernel(&kernel);
+  GpuFunctionHandle function = gpu_kernel->gpu_function();
+
+  if (gpu_kernel->cache_config() != KernelCacheConfig::kNoPreference) {
+    TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig(
+        function, gpu_kernel->GetGpuCacheConfig()));
+  }
+
+  // Launch kernels with packed arguments.
+  auto launch = [this, &kernel, &cluster_dims, &thread_dims, &block_dims,
+                 &function](const KernelArgsPackedArrayBase& packed) {
+    int32_t expected_number_of_arguments =
+        kernel.Arity() + (packed.number_of_shared_bytes() > 0);
+
+    CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments())
+        << "Kernel " << kernel.name() << " has " << packed.number_of_arguments()
+        << " arguments, but expected " << expected_number_of_arguments
+        << "; arity=" << kernel.Arity()
+        << "; number_of_shared_bytes=" << packed.number_of_shared_bytes();
+
+    void** params = const_cast<void**>(packed.argument_addresses().data());
+
+    if (cluster_dims.has_value()) {
+      return GpuDriver::LaunchKernel(
+          parent_->gpu_context(), kernel.name(), function, cluster_dims->x,
+          cluster_dims->y, cluster_dims->z, block_dims.x, block_dims.y,
+          block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z,
+          packed.number_of_shared_bytes(), gpu_stream(), params,
+          /*extra=*/nullptr);
+    } else {
+      return GpuDriver::LaunchKernel(
+          parent_->gpu_context(), kernel.name(), function, block_dims.x,
+          block_dims.y, block_dims.z, thread_dims.x, thread_dims.y,
+          thread_dims.z, packed.number_of_shared_bytes(), gpu_stream(), params,
+          /*extra=*/nullptr);
+    }
+  };
+
+  // If arguments are already packed we can just launch the kernel.
+  if (auto* packed = DynCast<KernelArgsPackedArrayBase>(&args)) {
+    return launch(*packed);
+  }
+
+  // For device memory array we rely on a custom kernel arguments packing.
+  if (auto* device_mem = DynCast<KernelArgsDeviceMemoryArray>(&args)) {
+    auto& pack = kernel.args_packing();
+    if (!pack) {
+      return absl::InternalError(
+          "Kernel is missing a custom arguments packing function for device "
+          "memory arguments array");
+    }
+
+    TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem));
+    return launch(*packed);
+  }
+
+  return absl::InternalError("Unsupported kernel arguments type");
+}
+
 GpuStream* AsGpuStream(Stream* stream) {
   DCHECK(stream != nullptr);
   return static_cast<GpuStream*>(stream);
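
Kernel launches now go through the stream instead of the executor; a minimal caller
sketch against the Launch overloads declared in gpu_stream.h below. The kernel, buffer,
and launch dimensions are placeholders, and the KernelArgsDeviceMemoryArray constructor
shape is an assumption based on kernel.h:

    #include "absl/status/status.h"
    #include "xla/stream_executor/device_memory.h"
    #include "xla/stream_executor/gpu/gpu_stream.h"
    #include "xla/stream_executor/kernel.h"
    #include "xla/stream_executor/launch_dim.h"

    absl::Status LaunchOnStream(stream_executor::gpu::GpuStream* stream,
                                const stream_executor::Kernel& kernel,
                                stream_executor::DeviceMemoryBase arg) {
      // Pack a single device-memory argument; kernels with a custom
      // args_packing() callback get it applied inside Launch().
      stream_executor::KernelArgsDeviceMemoryArray args(
          {arg}, /*shared_memory_bytes=*/0);
      return stream->Launch(stream_executor::ThreadDim(32),
                            stream_executor::BlockDim(1), kernel, args);
    }
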
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h
index 4cf21ca..249fbf7 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h
@@ -34,6 +34,8 @@
 #include "xla/stream_executor/gpu/gpu_event.h"
 #include "xla/stream_executor/gpu/gpu_executor.h"
 #include "xla/stream_executor/gpu/gpu_types.h"
+#include "xla/stream_executor/kernel.h"
+#include "xla/stream_executor/launch_dim.h"
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/stream.h"
 #include "xla/stream_executor/stream_common.h"
@@ -103,8 +105,18 @@
   void set_name(absl::string_view name) override;
   absl::StatusOr<std::unique_ptr<EventBasedTimer>> CreateEventBasedTimer(
       bool use_delay_kernel) override;
+  absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims,
+                      const Kernel& k, const KernelArgs& args) override;
+  absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims,
+                      const ClusterDim& cluster_dims, const Kernel& k,
+                      const KernelArgs& args) override;
 
  private:
+  // Helper method to launch a kernel with optional cluster dimensions.
+  absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims,
+                      const std::optional<ClusterDim>& cluster_dims,
+                      const Kernel& kernel, const KernelArgs& args);
+
   GpuExecutor* parent_;         // Executor that spawned this stream.
   GpuStreamHandle gpu_stream_;  // Wrapped CUDA stream handle.
   std::variant<StreamPriority, int> stream_priority_;
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h
index be0f9a5..656dd1e 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h
@@ -21,12 +21,9 @@
 
 #include "absl/status/statusor.h"
 #include "absl/time/time.h"
-#include "xla/stream_executor/event.h"
 #include "xla/stream_executor/event_based_timer.h"
 #include "xla/stream_executor/gpu/gpu_event.h"
 #include "xla/stream_executor/gpu/gpu_semaphore.h"
-#include "xla/stream_executor/gpu/gpu_types.h"
-#include "xla/stream_executor/stream.h"
 
 namespace xla {
 namespace gpu {
diff --git a/third_party/xla/xla/stream_executor/host/BUILD b/third_party/xla/xla/stream_executor/host/BUILD
index 326b3d6..8db3751 100644
--- a/third_party/xla/xla/stream_executor/host/BUILD
+++ b/third_party/xla/xla/stream_executor/host/BUILD
@@ -81,8 +81,11 @@
     ],
     deps = [
         ":host_event",
+        ":host_kernel",
         "//xla/stream_executor:device_memory",
         "//xla/stream_executor:event",
+        "//xla/stream_executor:kernel",
+        "//xla/stream_executor:launch_dim",
         "//xla/stream_executor:stream",
         "//xla/stream_executor:stream_common",
         "@com_google_absl//absl/base:core_headers",
@@ -166,14 +169,17 @@
         ":host_event",
         ":host_kernel",
         ":host_stream",
-        "//xla/stream_executor",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:device_memory",
         "//xla/stream_executor:event",
         "//xla/stream_executor:host_memory_allocation",
+        "//xla/stream_executor:kernel",
         "//xla/stream_executor:kernel_spec",
         "//xla/stream_executor:memory_allocation",
+        "//xla/stream_executor:platform",
+        "//xla/stream_executor:stream",
         "//xla/stream_executor:stream_executor_common",
         "//xla/stream_executor:stream_executor_h",
-        "@com_google_absl//absl/functional:any_invocable",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
diff --git a/third_party/xla/xla/stream_executor/host/host_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc
index 38715ce..68562b7 100644
--- a/third_party/xla/xla/stream_executor/host/host_executor.cc
+++ b/third_party/xla/xla/stream_executor/host/host_executor.cc
@@ -22,17 +22,17 @@
 
 #include <cstdint>
 #include <memory>
+#include <optional>
 #include <string>
 #include <utility>
+#include <variant>
 #include <vector>
 
-#include "absl/functional/any_invocable.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
-#include "absl/synchronization/notification.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/event.h"
@@ -40,9 +40,8 @@
 #include "xla/stream_executor/host/host_kernel.h"
 #include "xla/stream_executor/host/host_stream.h"
 #include "xla/stream_executor/kernel_spec.h"
-#include "xla/stream_executor/launch_dim.h"
 #include "xla/stream_executor/platform.h"
-#include "xla/stream_executor/stream_executor.h"
+#include "xla/stream_executor/stream.h"
 #include "tsl/platform/cpu_info.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/mem.h"
@@ -91,26 +90,6 @@
   return absl::InternalError("No method of loading host kernel provided");
 }
 
-absl::Status HostExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
-                                  const BlockDim& block_dims,
-                                  const Kernel& kernel,
-                                  const KernelArgs& args) {
-  const HostKernel* host_kernel = AsHostKernel(&kernel);
-
-  const KernelArgsDeviceMemoryArray* device_mem =
-      DynCast<KernelArgsDeviceMemoryArray>(&args);
-
-  absl::Status result;
-  if (device_mem != nullptr) {
-    result = host_kernel->Launch(thread_dims, device_mem->device_memory_args());
-  } else {
-    result = absl::UnimplementedError(
-        "Host kernel implements Launch method only for DeviceMemoryArray "
-        "arguments.");
-  }
-  return result;
-}
-
 bool HostExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const {
   tsl::port::MemoryInfo mem_info = tsl::port::GetMemoryInfo();
   *free = (mem_info.free != INT64_MAX) ? mem_info.free : -1;
diff --git a/third_party/xla/xla/stream_executor/host/host_executor.h b/third_party/xla/xla/stream_executor/host/host_executor.h
index 5f1c5d0..55eacc5 100644
--- a/third_party/xla/xla/stream_executor/host/host_executor.h
+++ b/third_party/xla/xla/stream_executor/host/host_executor.h
@@ -13,20 +13,15 @@
 limitations under the License.
 ==============================================================================*/
 
-// Declares the HostExecutor class, which is a CPU-only implementation of
-// the StreamExecutor interface. For now, this is used for testing and to
-// examine the performance of host-based StreamExecutor code.
 #ifndef XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_
 #define XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_
 
-#include <cstddef>
 #include <cstdint>
 #include <functional>
 #include <memory>
 #include <optional>
 #include <variant>
 
-#include "absl/functional/any_invocable.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "xla/stream_executor/device_description.h"
@@ -36,24 +31,21 @@
 #include "xla/stream_executor/host_memory_allocation.h"
 #include "xla/stream_executor/kernel.h"
 #include "xla/stream_executor/kernel_spec.h"
-#include "xla/stream_executor/launch_dim.h"
 #include "xla/stream_executor/memory_allocation.h"
 #include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/stream.h"
 #include "xla/stream_executor/stream_executor_common.h"
 #include "tsl/platform/threadpool.h"
 
 namespace stream_executor {
 namespace host {
 
-// An implementation of StreamExecutor that does no communication or interaction
-// with a device, but DOES perform memory operations backed by the host.
-// Kernel invocations will fail, but host callbacks may be enqueued on this
-// executor and its associated stream, and should follow standard ordering
-// semantics.
+// Declares the HostExecutor class, which is a CPU-only implementation of
+// the StreamExecutor interface. For now, this is used for testing and to
+// examine the performance of host-based StreamExecutor code.
 //
 // This is useful for evaluating the performance of host-based or fallback
 // routines executed under the context of a GPU executor.
-// See stream_executor.h for description of the below operations.
 class HostExecutor : public StreamExecutorCommon {
  public:
   // A function that loads a kernel function from a given spec. If spec is not
@@ -73,10 +65,6 @@
   absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
       const MultiKernelLoaderSpec& spec) override;
 
-  absl::Status Launch(Stream* stream, const ThreadDim& thread_dims,
-                      const BlockDim& block_dims, const Kernel& kernel,
-                      const KernelArgs& args) override;
-
   DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override;
   void Deallocate(DeviceMemoryBase* mem) override;
 
@@ -84,11 +72,10 @@
       uint64_t size) override {
     return std::make_unique<HostMemoryAllocation>(new char[size], size, this);
   }
-  void HostMemoryDeallocate(void* mem, uint64_t size) override {
+  void HostMemoryDeallocate(void* mem) override {
     delete[] static_cast<char*>(mem);
   }
 
-  // No "synchronize all activity" implemented for this platform at the moment.
   bool SynchronizeAllActivity() override { return true; }
   absl::Status SynchronousMemZero(DeviceMemoryBase* location,
                                   uint64_t size) override;
diff --git a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc
index a99c675..08bc958 100644
--- a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc
+++ b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc
@@ -89,10 +89,10 @@
 }
 )";
 
-static absl::StatusOr<std::unique_ptr<StreamExecutor>> NewStreamExecutor() {
-  StreamExecutorConfig config(/*ordinal=*/0);
+static absl::StatusOr<StreamExecutor*> NewStreamExecutor() {
   TF_ASSIGN_OR_RETURN(auto platform, PlatformManager::PlatformWithName("Host"));
-  TF_ASSIGN_OR_RETURN(auto stream_exec, platform->GetUncachedExecutor(config));
+  TF_ASSIGN_OR_RETURN(auto stream_exec,
+                      platform->ExecutorForDevice(/*ordinal=*/0));
   return stream_exec;
 }
 
diff --git a/third_party/xla/xla/stream_executor/host/host_platform.h b/third_party/xla/xla/stream_executor/host/host_platform.h
index 25c1179..3dd90a6 100644
--- a/third_party/xla/xla/stream_executor/host/host_platform.h
+++ b/third_party/xla/xla/stream_executor/host/host_platform.h
@@ -54,10 +54,13 @@
   absl::StatusOr<StreamExecutor*> GetExecutor(
       const StreamExecutorConfig& config) override;
 
-  absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
-      const StreamExecutorConfig& config) override;
-
  private:
+  // Returns a device constructed with the options specified in "config" without
+  // looking in or storing to the Platform's executor cache.
+  // Ownership IS transferred to the caller.
+  absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+      const StreamExecutorConfig& config);
+
   // This platform's name.
   std::string name_;
 
diff --git a/third_party/xla/xla/stream_executor/host/host_stream.cc b/third_party/xla/xla/stream_executor/host/host_stream.cc
index ed6e040..76b6671 100644
--- a/third_party/xla/xla/stream_executor/host/host_stream.cc
+++ b/third_party/xla/xla/stream_executor/host/host_stream.cc
@@ -33,6 +33,9 @@
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/event.h"
 #include "xla/stream_executor/host/host_event.h"
+#include "xla/stream_executor/host/host_kernel.h"
+#include "xla/stream_executor/kernel.h"
+#include "xla/stream_executor/launch_dim.h"
 #include "xla/stream_executor/stream.h"
 #include "xla/stream_executor/stream_common.h"
 #include "tsl/platform/denormal.h"
@@ -192,6 +195,21 @@
   return status;
 }
 
-}  // namespace host
+absl::Status HostStream::Launch(const ThreadDim& thread_dims,
+                                const BlockDim& block_dims,
+                                const Kernel& kernel, const KernelArgs& args) {
+  const HostKernel* host_kernel = AsHostKernel(&kernel);
 
+  const KernelArgsDeviceMemoryArray* device_mem =
+      DynCast<KernelArgsDeviceMemoryArray>(&args);
+
+  if (device_mem != nullptr) {
+    return host_kernel->Launch(thread_dims, device_mem->device_memory_args());
+  }
+  return absl::UnimplementedError(
+      "Host kernel implements Launch method only for DeviceMemoryArray "
+      "arguments.");
+}
+
+}  // namespace host
 }  // namespace stream_executor
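
HostStream gains the matching override; it only understands device-memory argument
arrays, which on the host platform are plain wrapped pointers. A minimal sketch, with the
kernel as a placeholder and the KernelArgsDeviceMemoryArray constructor shape again an
assumption:

    #include <vector>

    #include "absl/status/status.h"
    #include "xla/stream_executor/device_memory.h"
    #include "xla/stream_executor/host/host_stream.h"
    #include "xla/stream_executor/kernel.h"
    #include "xla/stream_executor/launch_dim.h"

    absl::Status LaunchHostKernel(stream_executor::host::HostStream* stream,
                                  const stream_executor::Kernel& kernel) {
      // On the host, DeviceMemoryBase simply wraps a host pointer.
      std::vector<float> buffer(64, 0.0f);
      stream_executor::DeviceMemoryBase mem(buffer.data(),
                                            buffer.size() * sizeof(float));
      stream_executor::KernelArgsDeviceMemoryArray args(
          {mem}, /*shared_memory_bytes=*/0);
      // Any other KernelArgs subtype yields UnimplementedError (see above).
      return stream->Launch(stream_executor::ThreadDim(4),
                            stream_executor::BlockDim(1), kernel, args);
    }
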
diff --git a/third_party/xla/xla/stream_executor/host/host_stream.h b/third_party/xla/xla/stream_executor/host/host_stream.h
index ed1bbc2..563d9fc 100644
--- a/third_party/xla/xla/stream_executor/host/host_stream.h
+++ b/third_party/xla/xla/stream_executor/host/host_stream.h
@@ -13,12 +13,10 @@
 limitations under the License.
 ==============================================================================*/
 
-// Class declaration for Stream type that enqueues tasks onto a host/CPU-based
-// execution context (as opposed to a GPU device), HostExecutor.
 #ifndef XLA_STREAM_EXECUTOR_HOST_HOST_STREAM_H_
 #define XLA_STREAM_EXECUTOR_HOST_HOST_STREAM_H_
 
-#include <cstddef>
+#include <cstdint>
 #include <memory>
 #include <queue>
 
@@ -27,6 +25,10 @@
 #include "absl/status/status.h"
 #include "absl/synchronization/mutex.h"
 #include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/event.h"
+#include "xla/stream_executor/kernel.h"
+#include "xla/stream_executor/launch_dim.h"
+#include "xla/stream_executor/stream.h"
 #include "xla/stream_executor/stream_common.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/thread_annotations.h"
@@ -34,6 +36,8 @@
 namespace stream_executor {
 namespace host {
 
+// Class declaration for Stream type that enqueues tasks onto a host/CPU-based
+// execution context (as opposed to a GPU device), HostExecutor.
 class HostStream : public StreamCommon {
  public:
   explicit HostStream(StreamExecutor* executor);
@@ -65,6 +69,8 @@
                       uint64_t size) override;
   absl::Status DoHostCallbackWithStatus(
       absl::AnyInvocable<absl::Status() &&> callback) override;
+  absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims,
+                      const Kernel& kernel, const KernelArgs& args) override;
 
  private:
   bool WorkAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
diff --git a/third_party/xla/xla/stream_executor/host_memory_allocation.cc b/third_party/xla/xla/stream_executor/host_memory_allocation.cc
index 9772396..e77c5e8 100644
--- a/third_party/xla/xla/stream_executor/host_memory_allocation.cc
+++ b/third_party/xla/xla/stream_executor/host_memory_allocation.cc
@@ -27,7 +27,7 @@
 
 HostMemoryAllocation::~HostMemoryAllocation() {
   if (ptr_ != nullptr && executor_ != nullptr) {
-    executor_->HostMemoryDeallocate(ptr_, size_);
+    executor_->HostMemoryDeallocate(ptr_);
   }
 }
 
diff --git a/third_party/xla/xla/stream_executor/integrations/device_mem_allocator.h b/third_party/xla/xla/stream_executor/integrations/device_mem_allocator.h
index 8b31f8b..736b62e 100644
--- a/third_party/xla/xla/stream_executor/integrations/device_mem_allocator.h
+++ b/third_party/xla/xla/stream_executor/integrations/device_mem_allocator.h
@@ -82,7 +82,7 @@
         auto status = stream_exec_->CollectiveMemoryDeallocate(ptr);
         CHECK(status.ok()) << status.message();
       } else if (memory_type_ == MemoryType::kHost) {
-        stream_exec_->HostMemoryDeallocate(ptr, num_bytes);
+        stream_exec_->HostMemoryDeallocate(ptr);
       } else {
         DeviceMemoryBase device_ptr(ptr);
         stream_exec_->Deallocate(&device_ptr);
diff --git a/third_party/xla/xla/stream_executor/kernel_test.cc b/third_party/xla/xla/stream_executor/kernel_test.cc
index cf63e5b..a554785 100644
--- a/third_party/xla/xla/stream_executor/kernel_test.cc
+++ b/third_party/xla/xla/stream_executor/kernel_test.cc
@@ -66,15 +66,12 @@
     std::is_same_v<ArgsStorage<DeviceMemoryBase*, const DeviceMemoryBase*>,
                    std::tuple<const void*, const void*>>);
 
-static std::unique_ptr<StreamExecutor> NewStreamExecutor() {
+static StreamExecutor* NewStreamExecutor() {
   Platform* platform = PlatformManager::PlatformWithName("Host").value();
-  StreamExecutorConfig config(/*ordinal=*/0);
-  return platform->GetUncachedExecutor(config).value();
+  return platform->ExecutorForDevice(/*ordinal=*/0).value();
 }
 
 TEST(KernelTest, PackDeviceMemoryArguments) {
-  auto executor = NewStreamExecutor();
-
   DeviceMemoryBase a(reinterpret_cast<void*>(0x12345678));
   DeviceMemoryBase b(reinterpret_cast<void*>(0x87654321));
 
@@ -125,7 +122,7 @@
   MultiKernelLoaderSpec empty_spec(/*arity=*/0);
 
   auto executor = NewStreamExecutor();
-  auto kernel = TypedKernelFactory<>::Create(executor.get(), empty_spec);
+  auto kernel = TypedKernelFactory<>::Create(executor, empty_spec);
   EXPECT_FALSE(kernel.ok());
 }
 
diff --git a/third_party/xla/xla/stream_executor/lazy_op_runner.h b/third_party/xla/xla/stream_executor/lazy_op_runner.h
index c74a03e..bf964e0 100644
--- a/third_party/xla/xla/stream_executor/lazy_op_runner.h
+++ b/third_party/xla/xla/stream_executor/lazy_op_runner.h
@@ -280,76 +280,6 @@
   }
 };
 
-struct FusedMHAOp {
-  using Signature = FusedMHASignature;
-  struct Config {
-    double scale;
-    const MatmulTensorDescriptor& bmm1_lhs_descriptor;
-    const MatmulTensorDescriptor& bmm1_rhs_descriptor;
-    const MatmulTensorDescriptor& bmm2_rhs_descriptor;
-    const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor;
-    const TensorDescriptor& output_descriptor;
-    std::optional<TensorDescriptor> bias_descriptor;
-    std::optional<TensorDescriptor> activation_descriptor;
-    std::optional<double> dropout_rate;
-    std::optional<int64_t> seed;
-    FMHAMaskKind mask_type;
-  };
-
-  static absl::StatusOr<std::unique_ptr<const OpRunner<FusedMHASignature>>>
-  RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config,
-                          Stream* stream) {
-    TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream));
-    return dnn->FusedMHARunnerFromDesc(
-        stream, desc, config.bmm1_lhs_descriptor, config.bmm1_rhs_descriptor,
-        config.bmm2_rhs_descriptor, config.intermediate_bmm2_lhs_descriptor,
-        config.output_descriptor, config.activation_descriptor,
-        config.bias_descriptor, config.scale, config.dropout_rate, config.seed,
-        config.mask_type);
-  }
-};
-
-struct FusedMHABackwardOp {
-  using Signature = FusedMHABackwardSignature;
-
-  struct Config {
-    double scale;
-    const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor;
-    const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor;
-    const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor;
-    const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor;
-    const MatmulTensorDescriptor& d_output_descriptor;
-    const TensorDescriptor& d_bmm1_lhs_descriptor;
-    const TensorDescriptor& d_bmm1_rhs_descriptor;
-    const TensorDescriptor& d_bmm2_rhs_descriptor;
-    std::optional<TensorDescriptor> d_s_descriptor;
-    std::optional<TensorDescriptor> d_bias_descriptor;
-    std::optional<TensorDescriptor> fwd_output_descriptor;
-    std::optional<TensorDescriptor> bias_descriptor;
-    std::optional<double> dropout_rate;
-    std::optional<int64_t> seed;
-    FMHAMaskKind mask_type;
-    bool force_deterministic;
-  };
-
-  static absl::StatusOr<
-      std::unique_ptr<const OpRunner<FusedMHABackwardSignature>>>
-  RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config,
-                          Stream* stream) {
-    TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream));
-    return dnn->FusedMHABackwardRunnerFromDesc(
-        stream, desc, config.bmm1_grad_gemm1_rhs_descriptor,
-        config.bmm1_grad_gemm2_rhs_descriptor,
-        config.bmm2_grad_gemm1_lhs_descriptor,
-        config.bmm2_grad_gemm2_rhs_descriptor, config.d_output_descriptor,
-        config.d_bmm1_lhs_descriptor, config.d_bmm1_rhs_descriptor,
-        config.d_bmm2_rhs_descriptor, config.d_s_descriptor,
-        config.d_bias_descriptor, config.fwd_output_descriptor,
-        config.bias_descriptor, config.scale, config.dropout_rate, config.seed,
-        config.mask_type, config.force_deterministic);
-  }
-};
-
 }  // namespace dnn
 }  // namespace stream_executor
 
diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h
index 03dd111..3c69f20 100644
--- a/third_party/xla/xla/stream_executor/mock_stream_executor.h
+++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h
@@ -22,7 +22,6 @@
 #include <string>
 #include <variant>
 
-#include "absl/functional/any_invocable.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
@@ -89,8 +88,7 @@
               (override));
   MOCK_METHOD(absl::StatusOr<std::unique_ptr<MemoryAllocation>>,
               HostMemoryAllocate, (uint64_t size), (override));
-  MOCK_METHOD(void, HostMemoryDeallocate, (void* mem, uint64_t size),
-              (override));
+  MOCK_METHOD(void, HostMemoryDeallocate, (void* mem), (override));
   MOCK_METHOD(bool, SynchronizeAllActivity, (), (override));
   MOCK_METHOD(absl::Status, SynchronousMemZero,
               (DeviceMemoryBase * location, uint64_t size), (override));
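Tests that mock this method need their expectations updated to the one-argument form. A short sketch, assuming MockStreamExecutor lives in namespace stream_executor as the surrounding headers suggest; the test name and pointer value are illustrative:

    #include <gmock/gmock.h>
    #include <gtest/gtest.h>
    #include "xla/stream_executor/mock_stream_executor.h"

    TEST(HostMemoryDeallocateSketch, ExpectationMatchesPointerOnly) {
      stream_executor::MockStreamExecutor executor;
      void* fake_ptr = reinterpret_cast<void*>(0x1234);
      // The expectation now matches on the pointer alone; no size argument.
      EXPECT_CALL(executor, HostMemoryDeallocate(fake_ptr));
      executor.HostMemoryDeallocate(fake_ptr);
    }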
diff --git a/third_party/xla/xla/stream_executor/platform.h b/third_party/xla/xla/stream_executor/platform.h
index 5fbc44b..c5120e5 100644
--- a/third_party/xla/xla/stream_executor/platform.h
+++ b/third_party/xla/xla/stream_executor/platform.h
@@ -105,6 +105,13 @@
   virtual absl::StatusOr<std::unique_ptr<DeviceDescription>>
   DescriptionForDevice(int ordinal) const = 0;
 
+  // Returns a StreamExecutor for the given ordinal if one has already been
+  // created, or an error if none exists. Does not create a new
+  // context with the device.
+  virtual absl::StatusOr<StreamExecutor*> FindExisting(int ordinal) {
+    return absl::NotFoundError("Not implemented for this platform.");
+  }
+
   // Returns a device with the given ordinal on this platform with a default
   // plugin configuration or, if none can be found with the given ordinal or
   // there is an error in opening a context to communicate with the device, an
@@ -118,12 +125,6 @@
   // Ownership of the executor is NOT transferred to the caller.
   virtual absl::StatusOr<StreamExecutor*> GetExecutor(
       const StreamExecutorConfig& config) = 0;
-
-  // Returns a device constructed with the options specified in "config" without
-  // looking in or storing to the Platform's executor cache.
-  // Ownership IS transferred to the caller.
-  virtual absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
-      const StreamExecutorConfig& config) = 0;
 };
 
 }  // namespace stream_executor
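The new Platform::FindExisting entry point lets callers check for an already-initialized executor without opening a device context; the base implementation simply returns a NotFound error. A minimal sketch of the lookup-then-create pattern this enables; GetOrCreateExecutor is a hypothetical helper, not part of the interface:

    #include "absl/status/statusor.h"
    #include "xla/stream_executor/platform.h"
    #include "xla/stream_executor/stream_executor.h"

    // Hypothetical helper: reuse a cached executor when one exists, otherwise
    // fall back to ExecutorForDevice(), which creates and caches one.
    absl::StatusOr<stream_executor::StreamExecutor*> GetOrCreateExecutor(
        stream_executor::Platform* platform, int ordinal) {
      auto existing = platform->FindExisting(ordinal);
      if (existing.ok()) return existing;
      return platform->ExecutorForDevice(ordinal);
    }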
diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD
index 06a44b0..7204036 100644
--- a/third_party/xla/xla/stream_executor/rocm/BUILD
+++ b/third_party/xla/xla/stream_executor/rocm/BUILD
@@ -184,6 +184,7 @@
         "@com_google_absl//absl/functional:any_invocable",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@local_config_rocm//rocm:rocm_headers",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:fingerprint",
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
index 465dbbe..2f61eae 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
@@ -1148,6 +1148,21 @@
   return absl::OkStatus();
 }
 
+absl::Status GpuDriver::LaunchKernel(
+    GpuContext* context, absl::string_view kernel_name,
+    GpuFunctionHandle function, unsigned int cluster_dim_x,
+    unsigned int cluster_dim_y, unsigned int cluster_dim_z,
+    unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z,
+    unsigned int block_dim_x, unsigned int block_dim_y,
+    unsigned int block_dim_z, unsigned int shared_mem_bytes,
+    GpuStreamHandle stream, void** kernel_params, void** extra) {
+  if (cluster_dim_x != 1 || cluster_dim_y != 1 || cluster_dim_z != 1)
+    return absl::UnimplementedError("Not implemented for ROCm");
+  return LaunchKernel(context, kernel_name, function, grid_dim_x, grid_dim_y,
+                      grid_dim_z, block_dim_x, block_dim_y, block_dim_z,
+                      shared_mem_bytes, stream, kernel_params, extra);
+}
+
 /* static */ absl::Status GpuDriver::LoadPtx(GpuContext* context,
                                              const char* ptx_contents,
                                              hipModule_t* module) {
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
index cb096c6..76a4db7 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
@@ -333,58 +333,6 @@
   return absl::OkStatus();
 }
 
-absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
-                                 const BlockDim& block_dims,
-                                 const Kernel& kernel, const KernelArgs& args) {
-  GpuStreamHandle hipstream = AsGpuStreamValue(stream);
-  const GpuKernel* rocm_kernel = AsGpuKernel(&kernel);
-  hipFunction_t hipfunc = rocm_kernel->gpu_function();
-
-  if (rocm_kernel->cache_config() != KernelCacheConfig::kNoPreference) {
-    TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig(
-        hipfunc, rocm_kernel->GetGpuCacheConfig()));
-  }
-
-  auto launch = [&](const KernelArgsPackedArrayBase& packed) {
-    CHECK_EQ(kernel.Arity() + (args.number_of_shared_bytes() > 0),
-             packed.number_of_arguments());
-
-    void** kernel_params =
-        const_cast<void**>(packed.argument_addresses().data());
-
-    return GpuDriver::LaunchKernel(
-        GetGpuContext(stream), kernel.name(), hipfunc, block_dims.x,
-        block_dims.y, block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z,
-        args.number_of_shared_bytes(), hipstream, kernel_params, nullptr);
-  };
-
-  auto* packed_args = DynCast<KernelArgsPackedArrayBase>(&args);
-  if (packed_args) return launch(*packed_args);
-
-  if (auto* device_mem = DynCast<KernelArgsDeviceMemoryArray>(&args)) {
-    auto& pack = kernel.args_packing();
-    if (!pack) {
-      return absl::InternalError(
-          "Kernel is missing a custom arguments packing function for device "
-          "memory arguments array");
-    }
-
-    TF_ASSIGN_OR_RETURN(auto packed_args, pack(kernel, *device_mem));
-    return launch(*packed_args);
-  }
-
-  return absl::InternalError("Unsupported kernel arguments type");
-}
-
-absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
-                                 const BlockDim& block_dims,
-                                 const ClusterDim& cluster_dims,
-                                 const Kernel& kernel, const KernelArgs& args) {
-  if (cluster_dims.x != 1 || cluster_dims.y != 1 || cluster_dims.z != 1)
-    return absl::UnimplementedError("Not implemented for ROCm");
-  return Launch(stream, thread_dims, block_dims, kernel, args);
-}
-
 absl::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
                                      ModuleHandle* module_handle) {
   // In GpuExecutor we store the pointer to the  HSACO binary  as
@@ -447,20 +395,6 @@
   GpuDriver::DeviceDeallocate(context_, mem->opaque());
 }
 
-absl::StatusOr<std::unique_ptr<MemoryAllocation>>
-GpuExecutor::HostMemoryAllocate(uint64_t size) {
-  auto* buffer = GpuDriver::HostAllocate(context_, size);
-  if (buffer == nullptr && size > 0) {
-    return absl::InternalError(
-        absl::StrFormat("Failed to allocate HostMemory of size %d", size));
-  }
-  return std::make_unique<HostMemoryAllocation>(buffer, size, this);
-}
-
-void GpuExecutor::HostMemoryDeallocate(void* location, uint64_t size) {
-  return GpuDriver::HostDeallocate(context_, location);
-}
-
 bool GpuExecutor::SynchronizeAllActivity() {
   return GpuDriver::SynchronizeContext(context_);
 }
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc
index 0ac3540..a65b5df 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc
@@ -28,67 +28,10 @@
 namespace stream_executor {
 namespace gpu {
 
-ROCmPlatform::ROCmPlatform()
-    : name_("ROCM"), min_numa_node_(0), limit_numa_node_(0) {}
+ROCmPlatform::ROCmPlatform() : name_("ROCM") {}
 
 ROCmPlatform::~ROCmPlatform() {}
 
-// Due to legacy issues in user code, we can't currently call InpectNumaNodes
-// at module initialization time, because non-GPU programs still include this
-// plugin via various methods, so instead, it has to be init-on-reference.
-void ROCmPlatform::InspectNumaNodes() {
-  // To get NUMA node information, we need to create all executors, so we can
-  // examine their device descriptions to see their bus assignments.
-  absl::once_flag once;
-  absl::call_once(once, [&] {
-    StreamExecutorConfig config;
-    for (int i = 0; i < VisibleDeviceCount(); i++) {
-      config.ordinal = i;
-      StreamExecutor* exec = GetExecutor(config).value();
-      if (i == 0) {
-        // NUMA nodes may not start at 0, so set the minimum node  based on the
-        // first executor we see.
-        min_numa_node_ = exec->GetDeviceDescription().numa_node();
-        limit_numa_node_ = min_numa_node_ + 1;
-      } else {
-        min_numa_node_ =
-            std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
-        limit_numa_node_ = std::max(
-            limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1);
-      }
-    }
-  });
-}
-
-int ROCmPlatform::BusCount() {
-  InspectNumaNodes();
-  return limit_numa_node_ - min_numa_node_;
-}
-
-int ROCmPlatform::DeviceToBus(int device_ordinal) {
-  StreamExecutorConfig config;
-  config.ordinal = device_ordinal;
-  StreamExecutor* exec = GetExecutor(config).value();
-  return exec->GetDeviceDescription().numa_node() - min_numa_node_;
-}
-
-absl::StatusOr<StreamExecutor*> ROCmPlatform::FirstExecutorForBus(
-    int bus_ordinal) {
-  InspectNumaNodes();
-  CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range";
-  for (int i = 0; i < VisibleDeviceCount(); i++) {
-    if (DeviceToBus(i) == bus_ordinal) {
-      StreamExecutorConfig config;
-      config.ordinal = i;
-      return GetExecutor(config).value();
-    }
-  }
-
-  return absl::Status{
-      absl::StatusCode::kNotFound,
-      absl::StrFormat("Executor for bus %d not found.", bus_ordinal)};
-}
-
 Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; }
 
 int ROCmPlatform::VisibleDeviceCount() const {
@@ -115,6 +58,12 @@
   return GetExecutor(config);
 }
 
+absl::StatusOr<StreamExecutor*> ROCmPlatform::FindExisting(int ordinal) {
+  StreamExecutorConfig config;
+  config.ordinal = ordinal;
+  return executor_cache_.Get(config);
+}
+
 absl::StatusOr<StreamExecutor*> ROCmPlatform::GetExecutor(
     const StreamExecutorConfig& config) {
   if (config.gpu_stream) {
@@ -130,27 +79,17 @@
 absl::StatusOr<std::unique_ptr<StreamExecutor>>
 ROCmPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
   auto executor = std::make_unique<GpuExecutor>(this, config.ordinal);
-  auto init_status = executor->Init();
-  if (!init_status.ok()) {
-    return absl::Status{
-        absl::StatusCode::kInternal,
-        absl::StrFormat(
-            "failed initializing StreamExecutor for ROCM device ordinal %d: %s",
-            config.ordinal, init_status.ToString().c_str())};
-  }
-
+  TF_RETURN_IF_ERROR(executor->Init());
   return std::move(executor);
 }
 
 }  // namespace gpu
 
 static void InitializeROCmPlatform() {
-  // Disabling leak checking, PlatformManager does not destroy its
-  // registered platforms.
   auto status = PlatformManager::PlatformWithName("ROCM");
   if (!status.ok()) {
-    std::unique_ptr<gpu::ROCmPlatform> platform(new gpu::ROCmPlatform);
-    TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform)));
+    TF_CHECK_OK(PlatformManager::RegisterPlatform(
+        std::make_unique<gpu::ROCmPlatform>()));
   }
 }
 
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h
index 6d18cf4..923cda8 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h
@@ -41,16 +41,6 @@
   ROCmPlatform();
   ~ROCmPlatform() override;
 
-  // ROCmPlatform-specific functionality
-  // Returns the number of distinct buses / NUMA nodes on the machine.
-  int BusCount();
-
-  // Returns the bus/NUMA node for the specified device ordinal.
-  int DeviceToBus(int device_ordinal);
-
-  // Returns the lowest-ordinal-number StreamExecutor on the specified bus.
-  absl::StatusOr<StreamExecutor*> FirstExecutorForBus(int bus_ordinal);
-
   // Platform interface implementation:
   // Returns the same value as kROCmPlatform above.
   Platform::Id id() const override;
@@ -64,16 +54,17 @@
       int ordinal) const override;
 
   absl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+  absl::StatusOr<StreamExecutor*> FindExisting(int ordinal) override;
 
   absl::StatusOr<StreamExecutor*> GetExecutor(
       const StreamExecutorConfig& config) override;
 
-  absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
-      const StreamExecutorConfig& config) override;
-
  private:
-  // Determines the number of NUMA nodes and the assignment of executor to each.
-  void InspectNumaNodes();
+  // Returns a device constructed with the options specified in "config" without
+  // looking in or storing to the Platform's executor cache.
+  // Ownership IS transferred to the caller.
+  absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+      const StreamExecutorConfig& config);
 
   // This platform's name.
   std::string name_;
@@ -84,15 +75,6 @@
   // Cache of created executors.
   ExecutorCache executor_cache_;
 
-  // The smallest NUMA node value for any device managed by this machine
-  // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus
-  // ordinals. The NUMA node space occupied by GPUs is assumed to be dense./
-  int min_numa_node_;
-
-  // Larger than the NUMA node value for any device managed by this machine
-  // manager.
-  int limit_numa_node_;
-
   ROCmPlatform(const ROCmPlatform&) = delete;
   void operator=(const ROCmPlatform&) = delete;
 };
diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc
index 048623d..b47198a 100644
--- a/third_party/xla/xla/stream_executor/stream_common.cc
+++ b/third_party/xla/xla/stream_executor/stream_common.cc
@@ -21,7 +21,6 @@
 #include <utility>
 #include <vector>
 
-#include "absl/functional/any_invocable.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_format.h"
 #include "absl/synchronization/mutex.h"
diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h
index 3d2ade7..e3e75f6 100644
--- a/third_party/xla/xla/stream_executor/stream_common.h
+++ b/third_party/xla/xla/stream_executor/stream_common.h
@@ -28,7 +28,6 @@
 #include <vector>
 
 #include "absl/base/thread_annotations.h"
-#include "absl/functional/any_invocable.h"
 #include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h
index 49929d4..1fa4312 100644
--- a/third_party/xla/xla/stream_executor/stream_executor.h
+++ b/third_party/xla/xla/stream_executor/stream_executor.h
@@ -1,5 +1,3 @@
-#include "absl/functional/any_invocable.h"
-#include "absl/log/log.h"
 /* Copyright 2015 The OpenXLA Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -31,6 +29,7 @@
 #include <variant>
 #include <vector>
 
+#include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
@@ -200,7 +199,7 @@
       uint64_t size) = 0;
 
   // Deallocates a region of host memory allocated by HostMemoryAllocate().
-  virtual void HostMemoryDeallocate(void* mem, uint64_t size) = 0;
+  virtual void HostMemoryDeallocate(void* mem) = 0;
 
   // Returns the memory space of the given pointer.
   virtual absl::StatusOr<MemoryType> GetPointerMemorySpace(const void* ptr) {
diff --git a/third_party/xla/xla/stream_executor/stream_executor_test.cc b/third_party/xla/xla/stream_executor/stream_executor_test.cc
index 34bd459..9a2ca57 100644
--- a/third_party/xla/xla/stream_executor/stream_executor_test.cc
+++ b/third_party/xla/xla/stream_executor/stream_executor_test.cc
@@ -25,10 +25,9 @@
 
 namespace stream_executor {
 
-static absl::StatusOr<std::unique_ptr<StreamExecutor>> NewStreamExecutor() {
-  StreamExecutorConfig config(/*ordinal=*/0);
+static absl::StatusOr<StreamExecutor*> NewStreamExecutor() {
   TF_ASSIGN_OR_RETURN(auto platform, PlatformManager::PlatformWithName("Host"));
-  TF_ASSIGN_OR_RETURN(auto stream_exec, platform->GetUncachedExecutor(config));
+  TF_ASSIGN_OR_RETURN(auto stream_exec, platform->ExecutorForDevice(0));
   return stream_exec;
 }
 
diff --git a/third_party/xla/xla/stream_executor/stream_test.cc b/third_party/xla/xla/stream_executor/stream_test.cc
index 473472d..ef5294e 100644
--- a/third_party/xla/xla/stream_executor/stream_test.cc
+++ b/third_party/xla/xla/stream_executor/stream_test.cc
@@ -29,31 +29,30 @@
 
 class StreamTest : public ::testing::Test {
  protected:
-  std::unique_ptr<StreamExecutor> NewStreamExecutor() {
+  StreamExecutor* NewStreamExecutor() {
     Platform* platform = PlatformManager::PlatformWithName("Host").value();
-    StreamExecutorConfig config(/*ordinal=*/0);
-    return platform->GetUncachedExecutor(config).value();
+    return platform->ExecutorForDevice(/*ordinal=*/0).value();
   }
 };
 
 TEST_F(StreamTest, InitOk) {
-  std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+  StreamExecutor* executor = NewStreamExecutor();
   TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
 }
 
 TEST_F(StreamTest, InitWithIntPriorityOk) {
-  std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+  StreamExecutor* executor = NewStreamExecutor();
   TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream(1));
 }
 
 TEST_F(StreamTest, InitWithStreamPriorityOk) {
-  std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+  StreamExecutor* executor = NewStreamExecutor();
   TF_ASSERT_OK_AND_ASSIGN(auto stream,
                           executor->CreateStream(StreamPriority::Highest));
 }
 
 TEST_F(StreamTest, OneSubStream) {
-  std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+  StreamExecutor* executor = NewStreamExecutor();
   TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
 
   // Get and return a sub-stream. Sub-streams are always initialized.
@@ -72,7 +71,7 @@
 }
 
 TEST_F(StreamTest, TwoSubStreams) {
-  std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+  StreamExecutor* executor = NewStreamExecutor();
   TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
 
   // Get two sub-streams.
diff --git a/third_party/xla/xla/stream_executor/sycl/BUILD b/third_party/xla/xla/stream_executor/sycl/BUILD
index 8745be9..86c00a0 100644
--- a/third_party/xla/xla/stream_executor/sycl/BUILD
+++ b/third_party/xla/xla/stream_executor/sycl/BUILD
@@ -48,6 +48,7 @@
         "//xla/stream_executor/gpu:gpu_driver_header",
         "//xla/stream_executor/gpu:gpu_executor_header",
         "//xla/stream_executor/gpu:gpu_collectives_header",
+        "@local_tsl//tsl/platform:errors",
     ]),
     alwayslink = True,  # Registers itself with the PlatformManager.
 )
diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc b/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc
index 876775b..ac6da36 100644
--- a/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc
+++ b/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc
@@ -35,65 +35,16 @@
 #include "xla/stream_executor/platform/initialize.h"
 #include "xla/stream_executor/platform_manager.h"
 #include "xla/stream_executor/sycl/sycl_platform_id.h"
+#include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
 
 namespace stream_executor {
 namespace gpu {
 
-SyclPlatform::SyclPlatform()
-    : name_("SYCL"), min_numa_node_(0), limit_numa_node_(0) {}
+SyclPlatform::SyclPlatform() : name_("SYCL") {}
 
 SyclPlatform::~SyclPlatform() {}
 
-// Due to legacy issues in user code, we can't currently call InspectNumaNodes
-// at module initialization time, because non-GPU programs still include this
-// plugin via various methods, so instead, it has to be init-on-reference.
-void SyclPlatform::InspectNumaNodes() {
-  // To get NUMA node information, we need to create all executors, so we can
-  // examine their device descriptions to see their bus assignments.
-  static absl::once_flag once;
-  absl::call_once(once, [&] {
-    for (int i = 0; i < VisibleDeviceCount(); i++) {
-      StreamExecutor* exec = *ExecutorForDevice(i);
-      if (i == 0) {
-        // NUMA nodes may not start at 0, so set the minimum node  based on the
-        // first executor we see.
-        min_numa_node_ = exec->GetDeviceDescription().numa_node();
-        limit_numa_node_ = min_numa_node_ + 1;
-      } else {
-        min_numa_node_ =
-            std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
-        limit_numa_node_ = std::max(
-            limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1);
-      }
-    }
-  });
-}
-
-int SyclPlatform::BusCount() {
-  InspectNumaNodes();
-  return limit_numa_node_ - min_numa_node_;
-}
-
-int SyclPlatform::DeviceToBus(int device_ordinal) {
-  StreamExecutor* exec = *ExecutorForDevice(device_ordinal);
-  return exec->GetDeviceDescription().numa_node() - min_numa_node_;
-}
-
-absl::StatusOr<StreamExecutor*> SyclPlatform::FirstExecutorForBus(
-    int bus_ordinal) {
-  InspectNumaNodes();
-  CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range";
-  for (int i = 0; i < VisibleDeviceCount(); i++) {
-    if (DeviceToBus(i) == bus_ordinal) {
-      return *ExecutorForDevice(i);
-    }
-  }
-
-  return absl::NotFoundError(
-      absl::StrFormat("Executor for bus %d not found.", bus_ordinal));
-}
-
 Platform::Id SyclPlatform::id() const { return sycl::kSyclPlatformId; }
 
 int SyclPlatform::VisibleDeviceCount() const {
@@ -133,24 +84,15 @@
 absl::StatusOr<std::unique_ptr<StreamExecutor>>
 SyclPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
   auto executor = std::make_unique<GpuExecutor>(this, config.ordinal);
-  auto init_status = executor->Init();
-  if (!init_status.ok()) {
-    return absl::InternalError(absl::StrFormat(
-        "failed initializing StreamExecutor for SYCL device ordinal %d: %s",
-        config.ordinal, init_status.ToString()));
-  }
-
+  TF_RETURN_IF_ERROR(executor->Init());
   return std::move(executor);
 }
 
 }  // namespace gpu
 
 static void InitializeSyclPlatform() {
-  // Disabling leak checking, PlatformManager does not destroy its
-  // registered platforms.
-
-  std::unique_ptr<gpu::SyclPlatform> platform(new gpu::SyclPlatform);
-  TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform)));
+  TF_CHECK_OK(
+      PlatformManager::RegisterPlatform(std::make_unique<gpu::SyclPlatform>()));
 }
 
 }  // namespace stream_executor
diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_platform.h b/third_party/xla/xla/stream_executor/sycl/sycl_platform.h
index 0c687f4..adc6cc9 100644
--- a/third_party/xla/xla/stream_executor/sycl/sycl_platform.h
+++ b/third_party/xla/xla/stream_executor/sycl/sycl_platform.h
@@ -41,16 +41,6 @@
   SyclPlatform();
   ~SyclPlatform() override;
 
-  // SyclPlatform-specific functionality
-  // Returns the number of distinct buses / NUMA nodes on the machine.
-  int BusCount();
-
-  // Returns the bus/NUMA node for the specified device ordinal.
-  int DeviceToBus(int device_ordinal);
-
-  // Returns the lowest-ordinal-number StreamExecutor on the specified bus.
-  absl::StatusOr<StreamExecutor*> FirstExecutorForBus(int bus_ordinal);
-
   // Platform interface implementation:
   // Returns the same value as kSyclPlatform above.
   Platform::Id id() const override;
@@ -68,28 +58,19 @@
   absl::StatusOr<StreamExecutor*> GetExecutor(
       const StreamExecutorConfig& config) override;
 
+ private:
+  // Returns a device constructed with the options specified in "config" without
+  // looking in or storing to the Platform's executor cache.
+  // Ownership IS transferred to the caller.
   absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
       const StreamExecutorConfig& config) override;
 
- private:
-  // Determines the number of NUMA nodes and the assignment of executor to each.
-  void InspectNumaNodes();
-
   // This platform's name.
   std::string name_;
 
   // Cache of created executors.
   ExecutorCache executor_cache_;
 
-  // The smallest NUMA node value for any device managed by this machine
-  // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus
-  // ordinals. The NUMA node space occupied by GPUs is assumed to be dense.
-  int min_numa_node_;
-
-  // Larger than the NUMA node value for any device managed by this machine
-  // manager.
-  int limit_numa_node_;
-
   SyclPlatform(const SyclPlatform&) = delete;
   void operator=(const SyclPlatform&) = delete;
 };
diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h
index c969c6d..85646af 100644
--- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h
+++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h
@@ -137,7 +137,7 @@
       uint64_t size) override {
     LOG(FATAL) << "not yet implemented";
   }
-  void HostMemoryDeallocate(void* mem, uint64_t size) override {
+  void HostMemoryDeallocate(void* mem) override {
     LOG(FATAL) << "not yet implemented";
   }
   absl::Status SynchronousMemZero(DeviceMemoryBase* location,
diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform.h b/third_party/xla/xla/stream_executor/tpu/tpu_platform.h
index f2ea4f9..e5e616c 100644
--- a/third_party/xla/xla/stream_executor/tpu/tpu_platform.h
+++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform.h
@@ -88,11 +88,14 @@
     return GetExecutor(config);
   }
 
-  absl::StatusOr<::stream_executor::StreamExecutor*> GetExecutor(
-      const ::stream_executor::StreamExecutorConfig& config) override;
+  absl::StatusOr<::stream_executor::StreamExecutor*> FindExisting(
+      int ordinal) override {
+    stream_executor::StreamExecutorConfig config;
+    config.ordinal = ordinal;
+    return executor_cache_.Get(config);
+  }
 
-  absl::StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
-  GetUncachedExecutor(
+  absl::StatusOr<::stream_executor::StreamExecutor*> GetExecutor(
       const ::stream_executor::StreamExecutorConfig& config) override;
 
   StreamMap* stream_map() { return &stream_map_; }
@@ -118,6 +121,12 @@
   absl::Mutex& mutex() { return event_map_mu_; }
 
  private:
+  // Returns a device constructed with the options specified in "config" without
+  // looking in or storing to the Platform's executor cache.
+  // Ownership IS transferred to the caller.
+  absl::StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
+  GetUncachedExecutor(const ::stream_executor::StreamExecutorConfig& config);
+
   mutable SE_Platform* platform_;
   std::string name_;
   stream_executor::ExecutorCache executor_cache_;
diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD
index e7219fd4..6697d09 100644
--- a/third_party/xla/xla/tests/BUILD
+++ b/third_party/xla/xla/tests/BUILD
@@ -201,6 +201,7 @@
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/ir:hlo_module_group",
+        "//xla/hlo/utils:hlo_query",
         "//xla/service:backend",
         "//xla/service:computation_layout",
         "//xla/service:hlo_module_util",
@@ -1782,12 +1783,16 @@
         ":literal_test_util",
         ":test_macros_header",
         ":xla_internal_test_main",
-        "//xla:array2d",
+        "//xla:array3d",
+        "//xla:array4d",
         "//xla:literal",
+        "//xla:literal_util",
+        "//xla:shape_util",
         "//xla:xla_data_proto_cc",
         "//xla/client:xla_builder",
         "//xla/hlo/ir:hlo",
-        "@local_tsl//tsl/platform:protobuf",
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest_main",
         "@local_tsl//tsl/platform:test",
     ],
 )
@@ -1900,6 +1905,8 @@
         ":test_macros_header",
         ":test_utils",
         ":xla_internal_test_main",  # fixdeps: keep
+        "//xla:array2d",
+        "//xla:array3d",
         "//xla:literal",
         "//xla:literal_util",
         "//xla:shape_util",
@@ -1910,15 +1917,19 @@
         "//xla/ffi:attribute_map",
         "//xla/ffi:ffi_api",
         "//xla/hlo/ir:hlo",
+        "//xla/service",
         "//xla/service:custom_call_status",
         "//xla/service:custom_call_target_registry",
         "//xla/tsl/lib/core:status_test_util",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:dynamic_annotations",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
+        "@eigen_archive//:eigen3",
         "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:test",
     ],
@@ -2124,6 +2135,9 @@
 xla_test(
     name = "dynamic_reshape_test",
     srcs = ["dynamic_reshape_test.cc"],
+    backend_tags = {
+        "gpu": ["notsan"],  # TODO(b/345034145): Fix tsan error.
+    },
     disabled_backends = ["interpreter"],
     tags = ["test_xla_cpu_thunks"],
     deps = [
@@ -2311,14 +2325,9 @@
     srcs = ["collective_pipeline_parallelism_test.cc"],
     args = ["--xla_force_host_platform_device_count=4"],
     backend_tags = {
-        # This test is tagged "manual" because it requires multiple GPUs, and Forge only supports
-        # single-GPU tests.  Guitar skips "manual" tests unless they're also tagged "guitar".
         "gpu": [
-            "guitar",
-            "manual",
             "multi_gpu",
             "no_oss",
-            "notap",
         ],
         "cpu": [
             "notsan",
@@ -2357,15 +2366,9 @@
     name = "collective_ops_e2e_test",
     srcs = ["collective_ops_e2e_test.cc"],
     backend_tags = {
-        # This test is tagged "manual" because it requires multiple GPUs, and
-        # Forge only supports single-GPU tests.  Guitar skips "manual" tests
-        # unless they're also tagged "guitar".
         "gpu": [
-            "guitar",
-            "manual",
             "multi_gpu",
             "no_oss",
-            "notap",
         ],
     },
     backends = [
@@ -2409,15 +2412,9 @@
     name = "replicated_io_feed_test",
     srcs = ["replicated_io_feed_test.cc"],
     backend_tags = {
-        # This test is tagged "manual" because it requires multiple GPUs, and
-        # Forge only supports single-GPU tests.  Guitar skips "manual" tests
-        # unless they're also tagged "guitar".
         "gpu": [
-            "guitar",
-            "manual",
             "multi_gpu",
             "no_oss",
-            "notap",
         ],
     },
     backends = ["gpu"],
@@ -2790,7 +2787,10 @@
     size = "large",
     srcs = ["local_client_execute_test.cc"],
     shard_count = 30,
-    tags = ["optonly"],
+    tags = [
+        "optonly",
+        "test_xla_cpu_thunks",
+    ],
     deps = [
         ":literal_test_util",
         ":local_client_test_base",
@@ -3081,6 +3081,9 @@
 xla_test(
     name = "set_dimension_size_test",
     srcs = ["set_dimension_size_test.cc"],
+    backend_tags = {
+        "gpu": ["notsan"],  # TODO(b/345034145): Fix tsan error.
+    },
     tags = ["test_xla_cpu_thunks"],
     deps = [
         ":hlo_test_base",
@@ -3189,6 +3192,7 @@
         "//xla:types",
         "//xla/hlo/ir:hlo",
         "@com_google_absl//absl/status:statusor",
+        "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:test",
     ],
 )
diff --git a/third_party/xla/xla/tests/build_defs.bzl b/third_party/xla/xla/tests/build_defs.bzl
index 8a42642..8f528c5 100644
--- a/third_party/xla/xla/tests/build_defs.bzl
+++ b/third_party/xla/xla/tests/build_defs.bzl
@@ -320,8 +320,23 @@
     # b/317293391. For this reason, if we would create an empty `test_suite`,
     # instead create a `cc_test` with no srcs that links against `main` to have
     # more predictable behavior that avoids bugs.
+    #
+    # Due to b/317293391, we also mark the test suite `manual`, so that wildcard builds
+    # like those in the XLA CI won't try to build the test suite target. Instead, the wildcard
+    # build will build the individual test targets and therefore respect the tags on each
+    # individual test target.
+    # Example: Assume we have an `xla_test(name=my_test)` in `//xla/service/gpu` with backends `cpu`
+    # and `gpu`. This generates two test targets `//xla/service/gpu:my_test_{cpu|gpu}`. The latter
+    # has a tag `gpu`.
+    #
+    # - `bazel test --test_tag_filters=-gpu //xla/service/gpu/...` will only run the cpu test.
+    # - `bazel test //xla/service/gpu/...` will run both tests.
+    # - `bazel test //xla/service/gpu:my_test` will run both tests.
+    # Caveat:
+    # - `bazel test --test_tag_filters=-gpu //xla/service/gpu:my_test` will run both tests and
+    #   not respect the tag filter - but it's still much better than the previous behavior.
     if test_names:
-        native.test_suite(name = name, tags = tags, tests = test_names)
+        native.test_suite(name = name, tags = tags + ["manual"], tests = test_names)
     else:
         native.cc_test(name = name, deps = ["@local_tsl//tsl/platform:test_main"])
 
diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc
index f1d1c78..9908538 100644
--- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc
+++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc
@@ -154,6 +154,7 @@
     )";
 
   const int64_t kNumReplicas = 2;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
   const bool enable_async_all_reduce = GetParam();
   TF_ASSERT_OK_AND_ASSIGN(auto executable,
                           CreateExecutable(kModuleStr, kNumReplicas));
@@ -190,6 +191,7 @@
   }
   )";
   const int64_t kNumReplicas = 2;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
   const bool enable_async_all_gather = GetParam();
 
   TF_ASSERT_OK_AND_ASSIGN(auto executable,
@@ -231,6 +233,7 @@
   }
   )";
   const int64_t kNumReplicas = 2;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
   const bool enable_async_all_gather = GetParam();
 
   TF_ASSERT_OK_AND_ASSIGN(auto executable,
@@ -268,6 +271,7 @@
   }
   )";
   const int64_t kNumReplicas = 2;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
   const bool enable_async_collective_broadcast = GetParam();
   TF_ASSERT_OK_AND_ASSIGN(auto executable,
                           CreateExecutable(kModuleStr, kNumReplicas));
@@ -300,6 +304,7 @@
   }
   )";
   const int64_t kNumReplicas = 2;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
   const bool enable_async_collective_permute = GetParam();
   TF_ASSERT_OK_AND_ASSIGN(auto executable,
                           CreateExecutable(kModuleStr, kNumReplicas));
@@ -343,6 +348,7 @@
   )";
 
   const int64_t kNumReplicas = 2;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
   const bool enable_async_reduce_scatter = GetParam();
   TF_ASSERT_OK_AND_ASSIGN(auto executable,
                           CreateExecutable(kModuleStr, kNumReplicas));
@@ -376,6 +382,7 @@
   }
   )";
   const int64_t kNumReplicas = 2;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
   const bool enable_async_all_to_all = GetParam();
   TF_ASSERT_OK_AND_ASSIGN(auto executable,
                           CreateExecutable(kModuleStr, kNumReplicas));
@@ -420,6 +427,7 @@
   }
   )";
   const int64_t kNumReplicas = 2;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
   const bool enable_async_all_to_all = GetParam();
   TF_ASSERT_OK_AND_ASSIGN(auto executable,
                           CreateExecutable(kModuleStr, kNumReplicas));
@@ -472,6 +480,7 @@
    }
   )";
   const int64_t kNumReplicas = 4;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -592,6 +601,7 @@
   )";
 
   const int64_t kNumReplicas = 2;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   DebugOptions debug_options = GetDebugOptionsForTest();
   debug_options.set_xla_gpu_enable_while_loop_reduce_scatter_code_motion(true);
@@ -646,6 +656,7 @@
   }
   )";
   const int64_t kNumReplicas = 2;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -677,6 +688,7 @@
       absl::string_view hlo_text, bool disable_dot_merger = false) {
     const int64_t kNumReplicas = 1;
     const int64_t kNumPartitions = 4;
+    SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions);
 
     HloModuleConfig config =
         GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -959,6 +971,7 @@
 )";
 
   const int64_t kNumReplicas = 1;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
   const int64_t kNumPartitions = 4;
 
   HloModuleConfig config =
@@ -1052,6 +1065,7 @@
 )";
 
   const int64_t kNumReplicas = 1;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -1085,6 +1099,7 @@
 )";
 
   const int64_t kNumReplicas = 1;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
   const int64_t kNumPartitions = 4;
 
   HloModuleConfig config =
diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc
index 460864c..9cd874c 100644
--- a/third_party/xla/xla/tests/collective_ops_test.cc
+++ b/third_party/xla/xla/tests/collective_ops_test.cc
@@ -39,23 +39,17 @@
 #include "tsl/platform/env.h"
 #include "tsl/platform/threadpool.h"
 
+namespace xla {
+namespace {
+
 // Tests cross-GPU operations.
 //
 // Several tests requires at least four GPUs.  For instructions on running this
 // within Google, see go/multi-gpu-unit-test.
-
-#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x)                     \
-  if (num_devices_ < x) {                                         \
-    GTEST_SKIP() << "Test requires at least " << x << " devices"; \
-  }
-
-namespace xla {
-namespace {
-
 class CollectiveOpsTest : public HloTestBase {
  public:
-  CollectiveOpsTest() : num_devices_(backend().device_count()) {
-    VLOG(1) << "Running with " << num_devices_ << " devices";
+  CollectiveOpsTest() {
+    VLOG(1) << "Running with " << num_devices() << " devices";
   }
 
  protected:
@@ -180,9 +174,6 @@
           /*expected_value=*/to_literal({cast(-1), cast(-2), cast(-3)}));
     }
   }
-
- protected:
-  const int64_t num_devices_;
 };
 
 // Returns the non-empty subsets of {0, 1, ..., n}.  For example,
@@ -370,7 +361,7 @@
 XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) {
   const int64_t kNumElems = 1024;
 
-  for (std::vector<int64_t> devices : PowerSetOfIota(num_devices_)) {
+  for (std::vector<int64_t> devices : PowerSetOfIota(num_devices())) {
     SCOPED_TRACE(absl::StrFormat("Running on devices {%s}",
                                  absl::StrJoin(devices, ", ")));
 
@@ -494,7 +485,7 @@
   // Test a prime number so it's not all powers of 2.
   const int64_t kNumElems = 137;
   const int64_t kNumReplicas = 4;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -541,7 +532,7 @@
       }
     )";
   static constexpr int kNumReplicas = 4;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -577,19 +568,19 @@
     )";
 
   HloModuleConfig config =
-      GetModuleConfigForTest(/*replica_count=*/num_devices_);
+      GetModuleConfigForTest(/*replica_count=*/num_devices());
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(kModuleStr, config));
   TF_ASSERT_OK_AND_ASSIGN(
       std::vector<Literal> results,
       ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
-                        num_devices_,
+                        num_devices(),
                         /*use_threads=*/true, /*run_hlo_passes=*/false));
 
-  ASSERT_EQ(results.size(), num_devices_);
+  ASSERT_EQ(results.size(), num_devices());
   // sum [0, num_devices)
-  uint32_t expected = num_devices_ * (num_devices_ - 1) / 2;
-  for (int i = 0; i < num_devices_; ++i) {
+  uint32_t expected = num_devices() * (num_devices() - 1) / 2;
+  for (int i = 0; i < num_devices(); ++i) {
     LiteralTestUtil::ExpectR0Equal<uint32_t>(expected, results[i]);
   }
 }
@@ -613,22 +604,22 @@
     )";
 
   HloModuleConfig config =
-      GetModuleConfigForTest(/*replica_count=*/num_devices_);
+      GetModuleConfigForTest(/*replica_count=*/num_devices());
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(kModuleStr, config));
   TF_ASSERT_OK_AND_ASSIGN(
       std::vector<Literal> results,
       ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
-                        num_devices_,
+                        num_devices(),
                         /*use_threads=*/true, /*run_hlo_passes=*/false));
 
-  ASSERT_EQ(results.size(), num_devices_);
+  ASSERT_EQ(results.size(), num_devices());
   // sum [0, num_devices)
-  uint32_t expected0 = num_devices_ * (num_devices_ - 1) / 2;
+  uint32_t expected0 = num_devices() * (num_devices() - 1) / 2;
   // sum squares [0, num_devices)
   uint32_t expected1 =
-      num_devices_ * (num_devices_ - 1) * (2 * num_devices_ - 1) / 6;
-  for (int i = 0; i < num_devices_; ++i) {
+      num_devices() * (num_devices() - 1) * (2 * num_devices() - 1) / 6;
+  for (int i = 0; i < num_devices(); ++i) {
     std::vector<Literal> replica_results = results[i].DecomposeTuple();
     LiteralTestUtil::ExpectR0Equal<uint32_t>(expected0, replica_results[0]);
     LiteralTestUtil::ExpectR0Equal<uint32_t>(expected1, replica_results[1]);
@@ -645,18 +636,18 @@
   )";
 
   HloModuleConfig config =
-      GetModuleConfigForTest(/*replica_count=*/num_devices_);
+      GetModuleConfigForTest(/*replica_count=*/num_devices());
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           ParseAndReturnVerifiedModule(kModuleStr));
 
   TF_ASSERT_OK_AND_ASSIGN(
       std::vector<Literal> results,
       ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
-                        num_devices_,
+                        num_devices(),
                         /*use_threads=*/true, /*run_hlo_passes=*/true));
 
-  ASSERT_EQ(results.size(), num_devices_);
-  for (uint32_t i = 0; i < num_devices_; ++i) {
+  ASSERT_EQ(results.size(), num_devices());
+  for (uint32_t i = 0; i < num_devices(); ++i) {
     EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR0(i), results[i]));
   }
 }
@@ -680,7 +671,7 @@
   }
   )";
   const int64_t kNumReplicas = 4;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -716,7 +707,7 @@
   }
   )";
   const int64_t kNumReplicas = 4;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -753,7 +744,7 @@
   }
   )";
   const int64_t kNumReplicas = 4;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -789,7 +780,7 @@
   }
   )";
   const int64_t kNumReplicas = 4;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -826,7 +817,7 @@
   }
   )";
   const int64_t kNumReplicas = 4;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -864,7 +855,7 @@
     )";
 
   const int64_t kNumReplicas = 2;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -906,7 +897,7 @@
   }
   )";
   const int64_t kNumReplicas = 4;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -952,7 +943,7 @@
   }
   )";
   const int64_t kNumReplicas = 4;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -992,7 +983,7 @@
   }
   )";
   const int64_t kNumReplicas = 4;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -1024,7 +1015,7 @@
   }
   )";
   const int64_t kNumReplicas = 4;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -2003,7 +1994,7 @@
   )";
 
   const int64_t kNumReplicas = 2;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -2083,7 +2074,7 @@
   })";
 
   const int64_t kNumReplicas = 2;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -2162,7 +2153,7 @@
   })";
 
   const int64_t kNumReplicas = 2;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
@@ -2263,7 +2254,7 @@
   })";
 
   const int64_t kNumReplicas = 2;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas)
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
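The local copy of SKIP_TEST_IF_NUM_DEVICES_LESS_THAN removed above (and again in the next file) is replaced by a shared definition whose new home is not shown in this diff. Based on the removed copies, the shared macro presumably looks roughly like the sketch below, switched from the old num_devices_ member to the num_devices() accessor, for use inside HloTestBase-derived test bodies:

    // Hedged reconstruction; the actual shared definition is not part of this diff.
    #define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x)                     \
      if (num_devices() < x) {                                        \
        GTEST_SKIP() << "Test requires at least " << x << " devices"; \
      }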
diff --git a/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc b/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc
index bfcf5e1..ee84472 100644
--- a/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc
+++ b/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc
@@ -15,6 +15,7 @@
 
 #include <cstdint>
 #include <memory>
+#include <string>
 #include <utility>
 #include <vector>
 
@@ -32,28 +33,18 @@
 #include "xla/tests/verified_hlo_module.h"
 #include "tsl/platform/statusor.h"
 
+namespace xla {
+namespace {
+
 // Tests cross-GPU operations.
 //
 // Several tests requires at least four GPUs.  For instructions on running this
 // within Google, see go/multi-gpu-unit-test.
-
-// TODO: Move this to hlo_test_base.h
-#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x)                     \
-  if (num_devices_ < x) {                                         \
-    GTEST_SKIP() << "Test requires at least " << x << " devices"; \
-  }
-
-namespace xla {
-namespace {
-
 class CollectivePipelineParallelismTest : public HloTestBase {
  public:
-  CollectivePipelineParallelismTest() : num_devices_(backend().device_count()) {
-    VLOG(1) << "Running with " << num_devices_ << " devices";
+  CollectivePipelineParallelismTest() {
+    VLOG(1) << "Running with " << num_devices() << " devices";
   }
-
- protected:
-  const int64_t num_devices_;
 };
 
 XLA_TEST_F(CollectivePipelineParallelismTest,
@@ -126,26 +117,59 @@
   LiteralTestUtil::ExpectR2Equal<float>({{0, 0}, {1, 1}}, results[3]);
 }
 
-// Naive implementation of pipeline parallelism:
-//   - 4 devices
-//   - 4 microbatches
-//   - no circular repeat
-//   - no disabled collectives
-//   - no collective pipelining
-//
-// Every stage of the pipeline is a single linear layer.
-XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) {
-  const absl::string_view kModuleStr = R"(
-  HloModule test
+std::string GetModuleStrWithCommonComputations(
+    const std::string name, const std::string more_computations) {
+  static constexpr char kCommonComputationsStr[] = R"(
+  read_buffer_mb4 {
+    buffer = f32[4,16] parameter(0)
+    offset = u32[] parameter(1)
+    index = u32[] parameter(2)
+    c0 = u32[] constant(0)
+    c4 = u32[] constant(4)
+    index_ = u32[] add(index, offset)
+    index__ = u32[] remainder(index_, c4)
+    slice = f32[1,16] dynamic-slice(buffer, index__, c0),
+        dynamic_slice_sizes={1,16}
+    ROOT slice_ = f32[16] reshape(slice)
+  }
 
-  get_circ_buffer_index {
-    offset = u32[] parameter(0)
-    index = u32[] parameter(1)
-    size = u32[] parameter(2)
-    t0 = u32[] add(offset, index)
-    t1 = u32[] divide(t0, size)
-    t2 = u32[] multiply(t1, size)
-    ROOT t4 = u32[] subtract(t0, t2)
+  read_buffer_mb5 {
+    buffer = f32[5,16] parameter(0)
+    offset = u32[] parameter(1)
+    index = u32[] parameter(2)
+    c0 = u32[] constant(0)
+    c5 = u32[] constant(5)
+    index_ = u32[] add(index, offset)
+    index__ = u32[] remainder(index_, c5)
+    slice = f32[1,16] dynamic-slice(buffer, index__, c0),
+        dynamic_slice_sizes={1,16}
+    ROOT slice_ = f32[16] reshape(slice)
+  }
+
+  update_buffer_mb4 {
+    buffer = f32[4,16] parameter(0)
+    update = f32[16] parameter(1)
+    offset = u32[] parameter(2)
+    index = u32[] parameter(3)
+    c0 = u32[] constant(0)
+    c4 = u32[] constant(4)
+    index_ = u32[] add(index, offset)
+    index__ = u32[] remainder(index_, c4)
+    update_ = f32[1,16] reshape(update)
+    ROOT buffer_ = f32[4,16] dynamic-update-slice(buffer, update_, index__, c0)
+  }
+
+  update_buffer_mb5 {
+    buffer = f32[5,16] parameter(0)
+    update = f32[16] parameter(1)
+    offset = u32[] parameter(2)
+    index = u32[] parameter(3)
+    c0 = u32[] constant(0)
+    c5 = u32[] constant(5)
+    index_ = u32[] add(index, offset)
+    index__ = u32[] remainder(index_, c5)
+    update_ = f32[1,16] reshape(update)
+    ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0)
   }
 
   is_input_replica {
@@ -156,10 +180,40 @@
 
   is_output_replica {
     replica_id = u32[] replica-id()
-    c1 = u32[] constant(1)
-    ROOT predicate = pred[] compare(replica_id, c1), direction=EQ
+    c3 = u32[] constant(3)
+    ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
   }
 
+  is_read_input_mb4 {
+    is_input_replica = pred[] call(), to_apply=is_input_replica
+    i = u32[] parameter(0)
+    c4 = u32[] constant(4)
+    is_input_iteration = pred[] compare(i, c4), direction=LT
+    ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
+  }
+
+  is_read_input_mb5 {
+    is_input_replica = pred[] call(), to_apply=is_input_replica
+    i = u32[] parameter(0)
+    c5 = u32[] constant(5)
+    is_input_iteration = pred[] compare(i, c5), direction=LT
+    ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
+  }
+  )";
+  return "HloModule " + name + "\n" + kCommonComputationsStr + "\n" +
+         more_computations;
+}
+
+// Naive implementation of pipeline parallelism:
+//   - 4 devices
+//   - 4 microbatches
+//   - no circular repeat
+//   - no disabled collectives
+//   - no collective pipelining
+//
+// Every stage of the pipeline is a single linear layer.
+XLA_TEST_F(CollectivePipelineParallelismTest, NaiveBFSMicrobatch4Replica4) {
+  constexpr char kMoreComputationsStr[] = R"(
   while_condition {
     tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) parameter(0)
     i = u32[] get-tuple-element(tuple), index=4
@@ -172,36 +226,34 @@
     weights = f32[16,16] get-tuple-element(tuple), index=0
     input = f32[4,16] get-tuple-element(tuple), index=1
     output = f32[4,16] get-tuple-element(tuple), index=2
-    tmp = f32[16] get-tuple-element(tuple), index=3
+    prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=3
     i = u32[] get-tuple-element(tuple), index=4
 
-    c1 = u32[] constant(1)
     c0 = u32[] constant(0)
+    c1 = u32[] constant(1)
     c4 = u32[] constant(4)
 
-    input_idx = u32[] call(c0, i, c4), to_apply=get_circ_buffer_index
-    input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
-        dynamic_slice_sizes={1,16}
-    input_slice_ = f32[16] reshape(input_slice)
+    // Read from buffers.
+    input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb4
 
-    prev_stage_slice = f32[16] collective-permute(tmp),
+    // Shift data to the next stage in the pipeline.
+    prev_stage_slice = f32[16] collective-permute(prev_iteration_compute_res),
         source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
 
+    // Select compute argument from previous stage or from input and perform
+    // compute.
     read_input = pred[] call(), to_apply=is_input_replica
-    compute_in = f32[16] select(read_input, input_slice_, prev_stage_slice)
-
-    compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
+    compute_arg = f32[16] select(read_input, input_slice, prev_stage_slice)
+    compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
         rhs_contracting_dims={0}
 
-    output_index = u32[] call(c1, i, c4), to_apply=get_circ_buffer_index
-    output_slice = f32[1,16] reshape(compute_out)
-    output_ = f32[4,16] dynamic-update-slice(output, output_slice, output_index,
-        c0)
+    // Update buffers.
+    output_ = call(output, compute_res, c1, i), to_apply=update_buffer_mb4
 
     i_ = add(i, c1)
 
     ROOT tuple1 = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(
-        weights, input, output_, compute_out, i_)
+        weights, input, output_, compute_res, i_)
   }
 
   ENTRY main {
@@ -210,11 +262,11 @@
 
     cf0 = f32[] constant(0)
     output = f32[4,16] broadcast(cf0), dimensions={}
-    tmp = f32[16] broadcast(cf0), dimensions={}
+    prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
     c0 = u32[] constant(0)
 
     tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(weights,
-        input, output, tmp, c0)
+        input, output, prev_iteration_compute_res, c0)
     tuple_ = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) while(tuple),
         condition=while_condition, body=while_body
 
@@ -227,8 +279,11 @@
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr, config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
+                                       /*name=*/"test", kMoreComputationsStr),
+                                   config));
 
   // This pipeline consists of 4 layers, each of which is a single linear layer.
   // We assign the weights to the replicas such that the layers scale the input
@@ -260,7 +315,7 @@
   // Check pipeline output for last replica.
   // The combined effect of the pipeline is to scale the input data by 24.0.
   const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0;
-  Literal expected_output = LiteralUtil::CreateFingerprintMatixR2(
+  Literal expected_output = LiteralUtil::CreateFingerprintMatixR2<float>(
       kMicrobatches, kInputSize, kExpectedFactor);
   EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3],
                                            ErrorSpec{1e-5, 1e-5}));
@@ -274,32 +329,8 @@
 //   - no collective pipelining
 //
 // Every stage of the pipeline is a single linear layer.
-XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch5Replica4) {
-  const absl::string_view kModuleStr = R"(
-  HloModule test
-
-  get_circ_buffer_index {
-    offset = u32[] parameter(0)
-    index = u32[] parameter(1)
-    size = u32[] parameter(2)
-    t0 = u32[] add(offset, index)
-    t1 = u32[] divide(t0, size)
-    t2 = u32[] multiply(t1, size)
-    ROOT t4 = u32[] subtract(t0, t2)
-  }
-
-  is_input_replica {
-    replica_id = u32[] replica-id()
-    c0 = u32[] constant(0)
-    ROOT predicate = pred[] compare(replica_id, c0), direction=EQ
-  }
-
-  is_output_replica {
-    replica_id = u32[] replica-id()
-    c1 = u32[] constant(1)
-    ROOT predicate = pred[] compare(replica_id, c1), direction=EQ
-  }
-
+XLA_TEST_F(CollectivePipelineParallelismTest, NaiveBFSMicrobatch5Replica4) {
+  constexpr char kMoreComputationsStr[] = R"(
   while_condition {
     tuple = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) parameter(0)
     i = u32[] get-tuple-element(tuple), index=4
@@ -312,37 +343,35 @@
     weights = f32[16,16] get-tuple-element(tuple), index=0
     input = f32[5,16] get-tuple-element(tuple), index=1
     output = f32[5,16] get-tuple-element(tuple), index=2
-    tmp = f32[16] get-tuple-element(tuple), index=3
+    prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=3
     i = u32[] get-tuple-element(tuple), index=4
 
+    c0 = u32[] constant(0)
     c1 = u32[] constant(1)
     c2 = u32[] constant(2)
-    c0 = u32[] constant(0)
     c5 = u32[] constant(5)
 
-    input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index
-    input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
-        dynamic_slice_sizes={1,16}
-    input_slice_ = f32[16] reshape(input_slice)
+    // Read from buffers.
+    input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5
 
-    prev_stage_slice = f32[16] collective-permute(tmp),
+    // Shift data to the next stage in the pipeline.
+    prev_stage_slice = f32[16] collective-permute(prev_iteration_compute_res),
         source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
 
+    // Select compute argument from previous stage or from input and perform
+    // compute.
     read_input = pred[] call(), to_apply=is_input_replica
-    compute_in = f32[16] select(read_input, input_slice_, prev_stage_slice)
-
-    compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
+    compute_arg = f32[16] select(read_input, input_slice, prev_stage_slice)
+    compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
         rhs_contracting_dims={0}
 
-    output_index = u32[] call(c2, i, c5), to_apply=get_circ_buffer_index
-    output_slice = f32[1,16] reshape(compute_out)
-    output_ = f32[5,16] dynamic-update-slice(output, output_slice, output_index,
-        c0)
+    // Update buffers.
+    output_ = call(output, compute_res, c2, i), to_apply=update_buffer_mb5
 
     i_ = add(i, c1)
 
     ROOT tuple1 = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[])
-        tuple(weights, input, output_, compute_out, i_)
+        tuple(weights, input, output_, compute_res, i_)
   }
 
   ENTRY main {
@@ -351,11 +380,11 @@
 
     cf0 = f32[] constant(0)
     output = f32[5,16] broadcast(cf0), dimensions={}
-    tmp = f32[16] broadcast(cf0), dimensions={}
+    prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
     c0 = u32[] constant(0)
 
     tuple = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[])
-        tuple(weights, input, output, tmp, c0)
+        tuple(weights, input, output, prev_iteration_compute_res, c0)
     tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) while(tuple),
         condition=while_condition, body=while_body
 
@@ -368,8 +397,11 @@
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr, config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
+                                       /*name=*/"test", kMoreComputationsStr),
+                                   config));
 
   // This pipeline consists of 4 layers, each of which is a single linear layer.
   // We assign the weights to the replicas such that the layers scale the input
@@ -415,40 +447,8 @@
 //
 // Every stage of the pipeline is a single linear layer.
 XLA_TEST_F(CollectivePipelineParallelismTest,
-           NaiveDFSMicrobatch4CircularRepeat2Replica4) {
-  const absl::string_view kModuleStr = R"(
-  HloModule test
-
-  get_circ_buffer_index {
-    offset = u32[] parameter(0)
-    index = u32[] parameter(1)
-    size = u32[] parameter(2)
-    t0 = u32[] add(offset, index)
-    t1 = u32[] divide(t0, size)
-    t2 = u32[] multiply(t1, size)
-    ROOT t4 = u32[] subtract(t0, t2)
-  }
-
-  is_input_replica {
-    replica_id = u32[] replica-id()
-    c0 = u32[] constant(0)
-    ROOT predicate = pred[] compare(replica_id, c0), direction=EQ
-  }
-
-  is_output_replica {
-    replica_id = u32[] replica-id()
-    c3 = u32[] constant(3)
-    ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
-  }
-
-  is_read_input {
-    is_input_replica = pred[] call(), to_apply=is_input_replica
-    i = u32[] parameter(0)
-    c4 = u32[] constant(4)
-    is_input_iteration = pred[] compare(i, c4), direction=LT
-    ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
-  }
-
+           NaiveBFSMicrobatch4CircularRepeat2Replica4) {
+  constexpr char kMoreComputationsStr[] = R"(
   while_condition {
     tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) parameter(0)
     i = u32[] get-tuple-element(tuple), index=4
@@ -461,36 +461,35 @@
     weights = f32[16,16] get-tuple-element(tuple), index=0
     input = f32[4,16] get-tuple-element(tuple), index=1
     output = f32[4,16] get-tuple-element(tuple), index=2
-    tmp = f32[16] get-tuple-element(tuple), index=3
+    prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=3
     i = u32[] get-tuple-element(tuple), index=4
 
-    c1 = u32[] constant(1)
     c0 = u32[] constant(0)
+    c1 = u32[] constant(1)
     c4 = u32[] constant(4)
 
-    input_idx = u32[] call(c0, i, c4), to_apply=get_circ_buffer_index
-    input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
-        dynamic_slice_sizes={1,16}
-    input_slice_ = f32[16] reshape(input_slice)
+    // Read from buffers.
+    input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb4
 
-    prev_stage_slice = f32[16] collective-permute(tmp),
+    // Shift data to the next stage in the pipeline.
+    prev_stage_slice = f32[16] collective-permute(prev_iteration_compute_res),
         source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
 
-    is_read_input = pred[] call(i), to_apply=is_read_input
-    compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice)
-
-    compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
+    // Select compute argument from previous stage or from input and perform
+    // compute.
+    is_read_input = pred[] call(i), to_apply=is_read_input_mb4
+    compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice)
+    compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
         rhs_contracting_dims={0}
 
-    output_index = u32[] call(c1, i, c4), to_apply=get_circ_buffer_index
-    output_slice = f32[1,16] reshape(compute_out)
-    output_ = f32[4,16] dynamic-update-slice(output, output_slice, output_index,
-        c0)
+    // Update buffers.
+    output_ = f32[4,16] call(output, compute_res, c1, i),
+        to_apply=update_buffer_mb4
 
     i_ = add(i, c1)
 
     ROOT tuple1 = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[])
-        tuple(weights, input, output_, compute_out, i_)
+        tuple(weights, input, output_, compute_res, i_)
   }
 
   ENTRY main {
@@ -499,11 +498,11 @@
 
     cf0 = f32[] constant(0)
     output = f32[4,16] broadcast(cf0), dimensions={}
-    tmp = f32[16] broadcast(cf0), dimensions={}
+    prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
     c0 = u32[] constant(0)
 
     tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(weights,
-        input, output, tmp, c0)
+        input, output, prev_iteration_compute_res, c0)
     tuple_ = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) while(tuple),
         condition=while_condition, body=while_body
 
@@ -516,8 +515,11 @@
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr, config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
+                                       /*name=*/"test", kMoreComputationsStr),
+                                   config));
 
   // This pipeline consists of a total of 8 layers (2 per replica), each of
   // which is a single linear layer. We assign the weights to the replicas such
@@ -556,7 +558,7 @@
                                            ErrorSpec{1e-5, 1e-5}));
 }
 
-// Naive implementation if pipeline parallelism:
+// Naive implementation of pipeline parallelism:
 //   - 4 devices
 //   - 5 microbatches
 //   - 2 circular repeat
@@ -565,66 +567,8 @@
 //
 // Every stage of the pipeline is a single linear layer.
 XLA_TEST_F(CollectivePipelineParallelismTest,
-           NaiveDFSMicrobatch5CircularRepeat2Replica4) {
-  const absl::string_view kModuleStr = R"(
-  HloModule test
-
-  get_circ_buffer_index {
-    offset = u32[] parameter(0)
-    index = u32[] parameter(1)
-    size = u32[] parameter(2)
-    t0 = u32[] add(offset, index)
-    t1 = u32[] divide(t0, size)
-    t2 = u32[] multiply(t1, size)
-    ROOT t4 = u32[] subtract(t0, t2)
-  }
-
-  read_buffer {
-    buffer = f32[5,16] parameter(0)
-    offset = u32[] parameter(1)
-    index = u32[] parameter(2)
-    c0 = u32[] constant(0)
-    c5 = u32[] constant(5)
-    index_ = u32[] add(index, offset)
-    index__ = u32[] remainder(index_, c5)
-    slice = f32[1,16] dynamic-slice(buffer, index__, c0),
-        dynamic_slice_sizes={1,16}
-    ROOT slice_ = f32[16] reshape(slice)
-  }
-
-  update_buffer {
-    buffer = f32[5,16] parameter(0)
-    update = f32[16] parameter(1)
-    offset = u32[] parameter(2)
-    index = u32[] parameter(3)
-    c0 = u32[] constant(0)
-    c5 = u32[] constant(5)
-    index_ = u32[] add(index, offset)
-    index__ = u32[] remainder(index_, c5)
-    update_ = f32[1,16] reshape(update)
-    ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0)
-  }
-
-  is_input_replica {
-    replica_id = u32[] replica-id()
-    c0 = u32[] constant(0)
-    ROOT predicate = pred[] compare(replica_id, c0), direction=EQ
-  }
-
-  is_output_replica {
-    replica_id = u32[] replica-id()
-    c3 = u32[] constant(3)
-    ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
-  }
-
-  is_read_input {
-    is_input_replica = pred[] call(), to_apply=is_input_replica
-    i = u32[] parameter(0)
-    c5 = u32[] constant(5)
-    is_input_iteration = pred[] compare(i, c5), direction=LT
-    ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
-  }
-
+           NaiveBFSMicrobatch5CircularRepeat2Replica4) {
+  constexpr char kMoreComputationsStr[] = R"(
   while_condition {
     tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
         parameter(0)
@@ -640,43 +584,46 @@
     input = f32[5,16] get-tuple-element(tuple), index=1
     output = f32[5,16] get-tuple-element(tuple), index=2
     buffer = f32[5,16] get-tuple-element(tuple), index=3
-    prev_iteration_compute_out = f32[16] get-tuple-element(tuple), index=4
+    prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=4
     i = u32[] get-tuple-element(tuple), index=5
 
     c0 = u32[] constant(0)
     c1 = u32[] constant(1)
     c2 = u32[] constant(2)
     c3 = u32[] constant(3)
+    c4 = u32[] constant(4)
     c5 = u32[] constant(5)
 
-    input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index
-    input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
-        dynamic_slice_sizes={1,16}
-    input_slice_ = f32[16] reshape(input_slice)
+    // Read from buffers.
+    input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5
+    buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer_mb5
 
-    buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer
-
+    // Shift data to the next stage in the pipeline.
+    // Directly depends on the updated buffer of the previous iteration and,
+    // therefore, depends on the previous iteration's compute.
     is_output_replica = pred[] call(), to_apply=is_output_replica
     next_stage_slice = select(is_output_replica, buffer_slice,
-        prev_iteration_compute_out)
-
+        prev_iteration_compute_res)
     prev_stage_slice = f32[16] collective-permute(next_stage_slice),
         source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
 
-    is_read_input = pred[] call(i), to_apply=is_read_input
-    compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice)
-
-    compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
+    // Select compute argument from previous stage or from input and perform
+    // compute.
+    is_read_input = pred[] call(i), to_apply=is_read_input_mb5
+    compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice)
+    compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
         rhs_contracting_dims={0}
 
-    output_ = f32[5,16] call(output, compute_out, c2, i), to_apply=update_buffer
-
-    buffer_ = f32[5,16] call(buffer, compute_out, c0, i), to_apply=update_buffer
+    // Update buffers.
+    output_ = f32[5,16] call(output, compute_res, c2, i),
+        to_apply=update_buffer_mb5
+    buffer_ = f32[5,16] call(buffer, compute_res, c0, i),
+        to_apply=update_buffer_mb5
 
     i_ = add(i, c1)
 
     ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
-        tuple(weights, input, output_, buffer_, compute_out, i_)
+        tuple(weights, input, output_, buffer_, compute_res, i_)
   }
 
   ENTRY main {
@@ -686,11 +633,12 @@
     cf0 = f32[] constant(0)
     output = f32[5,16] broadcast(cf0), dimensions={}
     buffer = f32[5,16] broadcast(cf0), dimensions={}
-    prev_iteration_compute_out = f32[16] broadcast(cf0), dimensions={}
+    prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
     c0 = u32[] constant(0)
 
+    // Iterate through pipeline stages.
     tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
-        tuple(weights, input, output, buffer, prev_iteration_compute_out, c0)
+        tuple(weights, input, output, buffer, prev_iteration_compute_res, c0)
     tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
         while(tuple), condition=while_condition, body=while_body
 
@@ -703,8 +651,11 @@
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr, config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
+                                       /*name=*/"test", kMoreComputationsStr),
+                                   config));
 
   // This pipeline consists of a total of 8 layers (2 per replica), each of
   // which is a single linear layer. We assign the weights to the replicas such
@@ -754,66 +705,8 @@
 //
 // Every stage of the pipeline is a single linear layer.
 XLA_TEST_F(CollectivePipelineParallelismTest,
-           NaiveWoDirectBufferDependencyDFSMicrobatch5CircularRepeat2Replica4) {
-  const absl::string_view kModuleStr = R"(
-  HloModule test
-
-  get_circ_buffer_index {
-    offset = u32[] parameter(0)
-    index = u32[] parameter(1)
-    size = u32[] parameter(2)
-    t0 = u32[] add(offset, index)
-    t1 = u32[] divide(t0, size)
-    t2 = u32[] multiply(t1, size)
-    ROOT t4 = u32[] subtract(t0, t2)
-  }
-
-  read_buffer {
-    buffer = f32[5,16] parameter(0)
-    offset = u32[] parameter(1)
-    index = u32[] parameter(2)
-    c0 = u32[] constant(0)
-    c5 = u32[] constant(5)
-    index_ = u32[] add(index, offset)
-    index__ = u32[] remainder(index_, c5)
-    slice = f32[1,16] dynamic-slice(buffer, index__, c0),
-        dynamic_slice_sizes={1,16}
-    ROOT slice_ = f32[16] reshape(slice)
-  }
-
-  update_buffer {
-    buffer = f32[5,16] parameter(0)
-    update = f32[16] parameter(1)
-    offset = u32[] parameter(2)
-    index = u32[] parameter(3)
-    c0 = u32[] constant(0)
-    c5 = u32[] constant(5)
-    index_ = u32[] add(index, offset)
-    index__ = u32[] remainder(index_, c5)
-    update_ = f32[1,16] reshape(update)
-    ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0)
-  }
-
-  is_input_replica {
-    replica_id = u32[] replica-id()
-    c0 = u32[] constant(0)
-    ROOT predicate = pred[] compare(replica_id, c0), direction=EQ
-  }
-
-  is_output_replica {
-    replica_id = u32[] replica-id()
-    c3 = u32[] constant(3)
-    ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
-  }
-
-  is_read_input {
-    is_input_replica = pred[] call(), to_apply=is_input_replica
-    i = u32[] parameter(0)
-    c5 = u32[] constant(5)
-    is_input_iteration = pred[] compare(i, c5), direction=LT
-    ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
-  }
-
+           NaiveWoDirectBufferDependencyBFSMicrobatch5CircularRepeat2Replica4) {
+  constexpr char kMoreComputationsStr[] = R"(
   while_condition {
     tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
         parameter(0)
@@ -829,7 +722,7 @@
     input = f32[5,16] get-tuple-element(tuple), index=1
     output = f32[5,16] get-tuple-element(tuple), index=2
     buffer = f32[5,16] get-tuple-element(tuple), index=3
-    prev_iteration_compute_out = f32[16] get-tuple-element(tuple), index=4
+    prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=4
     i = u32[] get-tuple-element(tuple), index=5
 
     c0 = u32[] constant(0)
@@ -839,38 +732,36 @@
     c4 = u32[] constant(4)
     c5 = u32[] constant(5)
 
-    input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index
-    input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
-        dynamic_slice_sizes={1,16}
-    input_slice_ = f32[16] reshape(input_slice)
+    // Read from buffers before they are updated.
+    input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5
+    buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer_mb5
 
-    buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer
-
-    buffer_ = f32[5,16] call(buffer, prev_iteration_compute_out, c4, i),
-        to_apply=update_buffer
-
+    // Shift data to the next stage in the pipeline.
     // Depends on the non-updated buffer of the previous iteration and,
     // therefore, does not depend on the previous iteration's compute.
     is_output_replica = pred[] call(), to_apply=is_output_replica
     next_stage_slice = select(is_output_replica, buffer_slice,
-        prev_iteration_compute_out)
-
-
+        prev_iteration_compute_res)
     prev_stage_slice = f32[16] collective-permute(next_stage_slice),
         source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
 
-    is_read_input = pred[] call(i), to_apply=is_read_input
-    compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice)
-
-    compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
+    // Select compute argument from previous stage or from input and perform
+    // compute.
+    is_read_input = pred[] call(i), to_apply=is_read_input_mb5
+    compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice)
+    compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
         rhs_contracting_dims={0}
 
-    output_ = f32[5,16] call(output, compute_out, c2, i), to_apply=update_buffer
+    // Update buffers.
+    buffer_ = f32[5,16] call(buffer, prev_iteration_compute_res, c4, i),
+        to_apply=update_buffer_mb5
+    output_ = f32[5,16] call(output, compute_res, c2, i),
+        to_apply=update_buffer_mb5
 
     i_ = add(i, c1)
 
     ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
-        tuple(weights, input, output_, buffer_, compute_out, i_)
+        tuple(weights, input, output_, buffer_, compute_res, i_)
   }
 
   ENTRY main {
@@ -880,11 +771,12 @@
     cf0 = f32[] constant(0)
     output = f32[5,16] broadcast(cf0), dimensions={}
     buffer = f32[5,16] broadcast(cf0), dimensions={}
-    prev_iteration_compute_out = f32[16] broadcast(cf0), dimensions={}
+    prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
     c0 = u32[] constant(0)
 
+    // Iterate through pipeline stages.
     tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
-        tuple(weights, input, output, buffer, prev_iteration_compute_out, c0)
+        tuple(weights, input, output, buffer, prev_iteration_compute_res, c0)
     tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
         while(tuple), condition=while_condition, body=while_body
 
@@ -897,8 +789,11 @@
 
   HloModuleConfig config =
       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr, config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
+                                       /*name=*/"test", kMoreComputationsStr),
+                                   config));
 
   // This pipeline consists of a total of 8 layers (2 per replica), each of
   // which is a single linear layer. We assign the weights to the replicas such
diff --git a/third_party/xla/xla/tests/copy_test.cc b/third_party/xla/xla/tests/copy_test.cc
index 45d94ab..734ecbd 100644
--- a/third_party/xla/xla/tests/copy_test.cc
+++ b/third_party/xla/xla/tests/copy_test.cc
@@ -13,22 +13,31 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <cstddef>
+#include <cstdint>
 #include <memory>
 #include <utility>
+#include <vector>
 
-#include "xla/array2d.h"
+#include <gtest/gtest.h>
+#include "absl/types/span.h"
+#include "xla/array3d.h"
+#include "xla/array4d.h"
 #include "xla/client/xla_builder.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
 #include "xla/literal.h"
+#include "xla/literal_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "xla/tests/client_library_test_base.h"
 #include "xla/tests/hlo_test_base.h"
 #include "xla/tests/literal_test_util.h"
 #include "xla/tests/test_macros.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/platform/protobuf.h"
 #include "tsl/platform/test.h"
 
 namespace xla {
@@ -50,6 +59,25 @@
     EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
   }
 
+  // TODO(vsytch): Remove special handling for dynamic shapes once *all* of XLA
+  // supports those as module inputs/outputs.
+  void TestDynamicCopyOp(const Literal& literal, const Shape& bounded_shape) {
+    Literal dynamic_literal = literal.ToBoundedDynamic(bounded_shape);
+    auto builder = HloComputation::Builder(TestName());
+    auto parameter = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, dynamic_literal.shape(), "param"));
+    builder.AddInstruction(HloInstruction::CreateUnary(
+        parameter->shape(), HloOpcode::kCopy, parameter));
+    auto computation = builder.Build();
+    auto module = CreateNewVerifiedModule();
+    module->AddEntryComputation(std::move(computation));
+
+    std::vector<Literal*> args = {&dynamic_literal};
+    Literal result = ExecuteAndTransfer(std::move(module), args);
+    Literal dynamic_result = result.ToBoundedDynamic(bounded_shape);
+    EXPECT_TRUE(LiteralTestUtil::Equal(dynamic_literal, dynamic_result));
+  }
+
   void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
   void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4,
                                 absl::Span<const int64_t> permutation);
@@ -67,6 +95,47 @@
   TestCopyOp(LiteralUtil::CreateR1<uint32_t>({1, 2, 3}));
 }
 
+XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic0) {
+  // TODO(vsytch): CPU emitter doesn't handle dynamic shapes.
+  if (backend().platform()->Name() == "Host") {
+    GTEST_SKIP();
+  }
+  Shape bounded_shape =
+      ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}, {true});
+  TestDynamicCopyOp(LiteralUtil::CreateRandomLiteral<PrimitiveType::F32>(
+                        ShapeUtil::MakeShape(PrimitiveType::F32, {0}), 0, 1)
+                        .value(),
+                    bounded_shape);
+}
+
+XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic106632) {
+  // TODO(vsytch): CPU emitter doesn't handle dynamic shapes.
+  if (backend().platform()->Name() == "Host") {
+    GTEST_SKIP();
+  }
+  Shape bounded_shape =
+      ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}, {true});
+  TestDynamicCopyOp(
+      LiteralUtil::CreateRandomLiteral<PrimitiveType::F32>(
+          ShapeUtil::MakeShape(PrimitiveType::F32, {106632}), 0, 1)
+          .value(),
+      bounded_shape);
+}
+
+XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic1310720) {
+  // TODO(vsytch): CPU emitter doesn't handle dynamic shapes.
+  if (backend().platform()->Name() == "Host") {
+    GTEST_SKIP();
+  }
+  Shape bounded_shape =
+      ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}, {true});
+  TestDynamicCopyOp(
+      LiteralUtil::CreateRandomLiteral<PrimitiveType::F32>(
+          ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}), 0, 1)
+          .value(),
+      bounded_shape);
+}
+
 XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
   TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
                                     {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
diff --git a/third_party/xla/xla/tests/custom_call_test.cc b/third_party/xla/xla/tests/custom_call_test.cc
index 2d0f370..3fd7cf5 100644
--- a/third_party/xla/xla/tests/custom_call_test.cc
+++ b/third_party/xla/xla/tests/custom_call_test.cc
@@ -26,17 +26,19 @@
 #include "absl/algorithm/container.h"
 #include "absl/base/dynamic_annotations.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
+#include "absl/synchronization/blocking_counter.h"
 #include "absl/types/span.h"
-#include "xla/client/lib/constants.h"
+#include "xla/array2d.h"
+#include "xla/array3d.h"
 #include "xla/client/xla_builder.h"
 #include "xla/ffi/ffi.h"
 #include "xla/ffi/ffi_api.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/layout_util.h"
 #include "xla/literal.h"
@@ -44,6 +46,7 @@
 #include "xla/primitive_util.h"
 #include "xla/service/custom_call_status.h"
 #include "xla/service/custom_call_target_registry.h"
+#include "xla/service/service.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/tests/client_library_test_base.h"
@@ -55,6 +58,9 @@
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
 
+#define EIGEN_USE_THREADS
+#include "unsupported/Eigen/CXX11/Tensor"
+
 namespace {
 void R0F32Add2(float* out, float** in) {
   ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*));
@@ -862,6 +868,40 @@
                          "__xla_test$$HandleTupleDifferentRanks", "Host",
                          kHandleTupleDifferentRanks);
 
+static absl::Status CustomCallWithIntraOpThreadPool(
+    ffi::Result<ffi::AnyBuffer>,
+    const Eigen::ThreadPoolDevice* intra_op_thread_pool) {
+  // We use two blocking counters to ensure that the task is actually running
+  // inside a thread pool.
+  absl::BlockingCounter counter0(1);
+  absl::BlockingCounter counter1(1);
+
+  intra_op_thread_pool->getPool()->Schedule([&]() {
+    counter0.Wait();
+    counter1.DecrementCount();
+  });
+
+  // Unblock submitted task.
+  counter0.DecrementCount();
+
+  // TODO(b/356389210): It is unsafe to wait for the completion of a task
+  // submitted into an intra-op thread pool as we might be running on a thread
+  // inside the same thread pool, and this can lead to deadlocks. Custom calls
+  // should return `AsyncValue` to signal completion of all submitted tasks.
+  counter1.Wait();
+
+  return absl::OkStatus();
+}
+
+XLA_FFI_DEFINE_HANDLER(kIntraOpThreadPool, CustomCallWithIntraOpThreadPool,
+                       ffi::Ffi::Bind()
+                           .Ret<AnyBuffer>()  // unused out buffer
+                           .Ctx<ffi::IntraOpThreadPool>());
+
+XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
+                         "__xla_test$$intra_op_thread_pool", "Host",
+                         kIntraOpThreadPool);
+
 }  // namespace
 
 // __xla_test$$ConcatVectors
@@ -1610,5 +1650,19 @@
   EXPECT_EQ(result, expected);
 }
 
+XLA_TEST_F(FfiCustomCallTest, IntraOpThreadPool) {
+  auto module = CreateNewVerifiedModule();
+  auto builder = HloComputation::Builder(TestName());
+
+  builder.AddInstruction(HloInstruction::CreateCustomCall(
+      r0f32_, {}, "__xla_test$$intra_op_thread_pool", "",
+      /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI));
+
+  module->AddEntryComputation(builder.Build());
+
+  auto status = Execute(std::move(module), {}).status();
+  EXPECT_EQ(status, absl::OkStatus());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/third_party/xla/xla/tests/exhaustive/BUILD b/third_party/xla/xla/tests/exhaustive/BUILD
index 293e103..ed239be 100644
--- a/third_party/xla/xla/tests/exhaustive/BUILD
+++ b/third_party/xla/xla/tests/exhaustive/BUILD
@@ -29,6 +29,7 @@
     deps = [
         "//xla:bit_cast",
         "//xla:executable_run_options",
+        "//xla:fp_util",
         "//xla:literal",
         "//xla:literal_util",
         "//xla:shape_util",
@@ -184,6 +185,7 @@
     deps = [
         ":exhaustive_op_test_utils",
         "//xla:literal",
+        "//xla:types",
         "//xla/client:xla_builder",
         "//xla/tests:xla_internal_test_main",
         "//xla/tsl/util:command_line_flags",
diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc
index f9da77d..e3d0459 100644
--- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc
+++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc
@@ -16,8 +16,10 @@
 #include <array>
 #include <cmath>
 #include <cstdint>
-#include <cstring>
+#include <cstdlib>
+#include <limits>
 #include <tuple>
+#include <type_traits>
 #include <utility>
 
 #include "absl/log/check.h"
@@ -28,6 +30,7 @@
 #include "xla/literal.h"
 #include "xla/tests/exhaustive/exhaustive_op_test_utils.h"
 #include "xla/tests/test_macros.h"
+#include "xla/types.h"
 #include "tsl/platform/test.h"
 
 #ifdef __FAST_MATH__
@@ -115,16 +118,101 @@
   Run(AddEmptyBroadcastDimension(Sub), host_sub);
 })
 
-// TODO(bixia): Mul fails with bfloat16 on CPU.
-BINARY_TEST_16BIT(DISABLED_ON_CPU(Mul), {
-  auto host_mul = [](float x, float y) { return x * y; };
-  Run(AddEmptyBroadcastDimension(Mul), host_mul);
+// Can be thought of as an absolute error of
+// `<= |std::numeric_limits<float>::min()|`.
+double MulCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) {
+  float output = static_cast<float>(left) * static_cast<float>(right);
+
+  // Subnormals are flushed to 0 (as inputs or outputs). In these cases, we
+  // calculate 0 instead of the expected very small number, so we use the
+  // minimum normal float value as the absolute error to give some buffer.
+  auto left_is_subnormal = IsSubnormal(left);
+  auto right_is_subnormal = IsSubnormal(right);
+  auto output_is_subnormal = IsSubnormal(output);
+  if (left_is_subnormal || right_is_subnormal || output_is_subnormal) {
+    return std::numeric_limits<float>::min();
+  }
+
+  return 0.0;
+}
+
+bool MulCpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) {
+  // For BF16, multiplying a subnormal by infinity will lead to calculating 0
+  // multiplied by infinity due to subnormal flushing, which is defined to be
+  // NaN. However, the calculation in higher precision does not flush the
+  // subnormal value to 0, leading to a result of infinity.
+  auto left_is_subnormal = IsSubnormal(left);
+  auto left_is_infinite = std::isinf(left);
+  auto right_is_subnormal = IsSubnormal(right);
+  auto right_is_infinite = std::isinf(right);
+  if ((left_is_subnormal && right_is_infinite) ||
+      (left_is_infinite && right_is_subnormal)) {
+    return true;
+  }
+
+  return false;
+}
+
+BINARY_TEST_16BIT(Mul, {
+  ErrorSpecGen error_spec_gen = +[](NativeT left, NativeT right) {
+    return ErrorSpec::Builder().strict_signed_zeros().build();
+  };
+  if (IsCpu(platform_)) {
+    if constexpr (std::is_same_v<NativeT, xla::bfloat16>) {
+      error_spec_gen = +[](NativeT left, NativeT right) {
+        return ErrorSpec::Builder()
+            .abs_err(MulCpuBf16AbsErr(static_cast<xla::bfloat16>(left),
+                                      static_cast<xla::bfloat16>(right)))
+            .strict_signed_zeros()
+            .skip_comparison(MulCpuBf16Skip(static_cast<xla::bfloat16>(left),
+                                            static_cast<xla::bfloat16>(right)))
+            .build();
+      };
+    }
+  }
+  Run(
+      AddEmptyBroadcastDimension(Mul), [](float x, float y) { return x * y; },
+      error_spec_gen);
 })
 
-// TODO(bixia): Div fails with bfloat16 on CPU.
-BINARY_TEST_16BIT(DISABLED_ON_CPU(Div), {
-  auto host_div = [](float x, float y) { return x / y; };
-  Run(AddEmptyBroadcastDimension(Div), host_div);
+// Can be thought of as an absolute error of
+// `<= |std::numeric_limits<float>::min()|`.
+double DivCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) {
+  float output = static_cast<float>(left) / static_cast<float>(right);
+
+  // Subnormals are flushed to 0, so we add an absolute error margin that is
+  // larger than any subnormal.
+  auto output_is_subnormal = IsSubnormal(output);
+  if (output_is_subnormal) {
+    return std::numeric_limits<float>::min();
+  }
+
+  return 0.0;
+}
+
+BINARY_TEST_16BIT(Div, {
+  ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) {
+    return ErrorSpec::Builder().strict_signed_zeros().build();
+  };
+  if (IsCpu(platform_)) {
+    if constexpr (std::is_same_v<NativeT, xla::bfloat16>) {
+      error_spec_gen = +[](NativeT left, NativeT right) {
+        return ErrorSpec::Builder()
+            .abs_err(DivCpuBf16AbsErr(static_cast<xla::bfloat16>(left),
+                                      static_cast<xla::bfloat16>(right)))
+            .strict_signed_zeros()
+            .build();
+      };
+    }
+  }
+  if (IsGpu(platform_) && std::is_same_v<NativeT, xla::half>) {
+    error_spec_gen = +[](NativeT, NativeT) {
+      return ErrorSpec::Builder().distance_err(1).strict_signed_zeros().build();
+    };
+  }
+  Run(
+      AddEmptyBroadcastDimension(Div), [](float x, float y) { return x / y; },
+      error_spec_gen);
 })
 
 BINARY_TEST_16BIT(Max, {
@@ -135,21 +223,135 @@
   Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>);
 })
 
-// TODO(bixia): Pow fails with bfloat16 on CPU.
-BINARY_TEST_16BIT(DISABLED_ON_GPU(DISABLED_ON_CPU(Pow)), {
-  // See b/162664705.
-  known_incorrect_fn_ = [](int64_t val) {
-    Eigen::bfloat16 f;
-    uint16_t val_16 = val;
-    memcpy(&f, &val_16, 2);
-    return std::isnan(f);
+template <typename NativeT>
+bool PowCpuGpuF16Skip(NativeT left, NativeT right) {
+  // Hardware seems to always return 1 if right is 0, no matter if left is NaN.
+  if (std::isnan(left) && right == 0) {
+    return true;
+  }
+  // Hardware seems to always return 1 if left is 1, no matter if right is NaN.
+  if (left == 1 && std::isnan(right)) {
+    return true;
+  }
+  return false;
+}
+
+double PowCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) {
+  float output = std::pow(static_cast<float>(left), static_cast<float>(right));
+
+  // Output is flushed to 0 if subnormal.
+  if (IsSubnormal(output)) {
+    return std::numeric_limits<float>::min();
+  }
+
+  // TODO(b/359325328): pow computation for subnormal bases is different from
+  // std::pow.
+  //
+  // If the base is subnormal, the output computation selects a different base.
+  // The minimum value ever chosen is slightly greater than the 1e-91 used
+  // below. We return an absolute error from this value to the "real" output.
+  //
+  // Because the exponent (right) can be any floating point value, this allows
+  // an arbitrary absolute error for subnormal values.
+  if (IsSubnormal(left)) {
+    xla::bfloat16 output_as_bf16 = static_cast<xla::bfloat16>(output);
+    auto expected = std::pow(1e-91, static_cast<double>(right));
+    auto err = std::abs(expected - output_as_bf16);
+    if (!std::isnan(err)) {
+      return err;
+    }
+  }
+
+  return 0.0;
+}
+
+BINARY_TEST_16BIT(Pow, {
+  ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) {
+    return ErrorSpec::Builder().strict_signed_zeros().build();
   };
-  Run(AddEmptyBroadcastDimension(Pow), std::pow);
+  if (IsCpu(platform_)) {
+    if constexpr (std::is_same_v<NativeT, xla::half>) {
+      error_spec_gen = +[](NativeT left, NativeT right) {
+        return ErrorSpec::Builder()
+            .strict_signed_zeros()
+            .skip_comparison(PowCpuGpuF16Skip(left, right))
+            .build();
+      };
+    }
+    if constexpr (std::is_same_v<NativeT, xla::bfloat16>) {
+      error_spec_gen = +[](NativeT left, NativeT right) {
+        return ErrorSpec::Builder()
+            .abs_err(PowCpuBf16AbsErr(static_cast<xla::bfloat16>(left),
+                                      static_cast<xla::bfloat16>(right)))
+            .strict_signed_zeros()
+            .build();
+      };
+    }
+  }
+  if (IsGpu(platform_)) {
+    error_spec_gen = +[](NativeT left, NativeT right) {
+      return ErrorSpec::Builder()
+          .distance_err(1)
+          .strict_signed_zeros()
+          .skip_comparison(PowCpuGpuF16Skip(left, right))
+          .build();
+    };
+  }
+  Run(AddEmptyBroadcastDimension(Pow), std::pow, error_spec_gen);
 })
 
-// TODO(bixia): Atan2 fails with bfloat16 on CPU.
-BINARY_TEST_16BIT(DISABLED_ON_CPU(Atan2),
-                  { Run(AddEmptyBroadcastDimension(Atan2), std::atan2); })
+// Can be thought of as an absolute error of
+// `<= |std::numeric_limits<float>::min()|`.
+double Atan2CpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) {
+  float output =
+      std::atan2(static_cast<float>(left), static_cast<float>(right));
+
+  // If the output would be a subnormal float, we allow some error to account
+  // for BF16 implementation flushing subnormals to zero.
+  auto output_is_subnormal = IsSubnormal(output);
+  if (output_is_subnormal) {
+    return std::numeric_limits<float>::min();
+  }
+
+  return 0.0;
+}
+
+bool Atan2CpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) {
+  // Subnormals are flushed to 0, so the flushed computation is 0/0, which
+  // returns NaN, whereas <subnormal>/<subnormal> returns some positive number.
+  // We cannot set an error bound to compare against NaN.
+  if (IsSubnormal(left) && IsSubnormal(right)) {
+    return true;
+  }
+
+  return false;
+}
+
+BINARY_TEST_16BIT(Atan2, {
+  auto error_spec_gen = +[](NativeT, NativeT) {
+    return ErrorSpec::Builder().strict_signed_zeros().build();
+  };
+  if (IsCpu(platform_)) {
+    if constexpr (std::is_same_v<NativeT, xla::bfloat16>) {
+      error_spec_gen = +[](NativeT left, NativeT right) {
+        return ErrorSpec::Builder()
+            .abs_err(Atan2CpuBf16AbsErr(static_cast<xla::bfloat16>(left),
+                                        static_cast<xla::bfloat16>(right)))
+            .strict_signed_zeros()
+            .skip_comparison(
+                Atan2CpuBf16Skip(static_cast<xla::bfloat16>(left),
+                                 static_cast<xla::bfloat16>(right)))
+            .build();
+      };
+    }
+  }
+  if (IsGpu(platform_)) {
+    error_spec_gen = +[](NativeT, NativeT) {
+      return ErrorSpec::Builder().distance_err(1).strict_signed_zeros().build();
+    };
+  }
+  Run(AddEmptyBroadcastDimension(Atan2), std::atan2, error_spec_gen);
+})
 
 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
 INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest,
diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc
index 6c791a8..f677539 100644
--- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc
+++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc
@@ -15,6 +15,7 @@
 
 #include "xla/tests/exhaustive/exhaustive_op_test_utils.h"
 
+#include <algorithm>
 #include <array>
 #include <cmath>
 #include <cstddef>
@@ -90,6 +91,64 @@
   return IsMinNormal(value.imag());
 }
 
+/*static*/ ErrorSpec::Builder builder() { return ErrorSpecBuilder(); }
+
+ErrorSpecBuilder& ErrorSpecBuilder::abs_err(double abs_err) & {
+  spec_.abs_err = abs_err;
+  return *this;
+}
+
+ErrorSpecBuilder& ErrorSpecBuilder::rel_err(double rel_err) & {
+  spec_.rel_err = rel_err;
+  return *this;
+}
+
+ErrorSpecBuilder& ErrorSpecBuilder::distance_err(int64_t distance_err) & {
+  spec_.distance_err = distance_err;
+  return *this;
+}
+
+ErrorSpecBuilder& ErrorSpecBuilder::strict_signed_zeros(
+    bool strict_signed_zeros) & {
+  spec_.strict_signed_zeros = strict_signed_zeros;
+  return *this;
+}
+
+ErrorSpecBuilder& ErrorSpecBuilder::skip_comparison(bool skip_comparison) & {
+  spec_.skip_comparison = skip_comparison;
+  return *this;
+}
+
+ErrorSpecBuilder&& ErrorSpecBuilder::abs_err(double abs_err) && {
+  spec_.abs_err = abs_err;
+  return std::move(*this);
+}
+
+ErrorSpecBuilder&& ErrorSpecBuilder::rel_err(double rel_err) && {
+  spec_.rel_err = rel_err;
+  return std::move(*this);
+}
+
+ErrorSpecBuilder&& ErrorSpecBuilder::distance_err(int64_t distance_err) && {
+  spec_.distance_err = distance_err;
+  return std::move(*this);
+}
+
+ErrorSpecBuilder&& ErrorSpecBuilder::strict_signed_zeros(
+    bool strict_signed_zeros) && {
+  spec_.strict_signed_zeros = strict_signed_zeros;
+  return std::move(*this);
+}
+
+ErrorSpecBuilder&& ErrorSpecBuilder::skip_comparison(bool skip_comparison) && {
+  spec_.skip_comparison = skip_comparison;
+  return std::move(*this);
+}
+
+ErrorSpecBuilder::operator ErrorSpec() && { return std::move(*this).build(); }
+
+ErrorSpec ErrorSpecBuilder::build() && { return spec_; }
+
 // For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 decimal places of
 // precision to be guaranteed that we're printing the full number.
 //
@@ -491,6 +550,9 @@
     }
 
     ErrorSpec error_spec = CallErrorSpec(error_spec_gen, inputs);
+    ASSERT_GE(error_spec.abs_err, 0.0);
+    ASSERT_GE(error_spec.rel_err, 0.0);
+    ASSERT_GE(error_spec.distance_err, 0.0);
 
     if (error_spec.skip_comparison) {
       PrintSkipped(&skipped, [&] {
@@ -552,7 +614,7 @@
           result = pure_subnormal_cache[cache_loc];
         }
       } else {
-        result = result = CallOperation(evaluate_op, test_value);
+        result = CallOperation(evaluate_op, test_value);
       }
 
       if (IsClose(result, static_cast<NativeRefT>(actual), error_spec)) {
diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h
index 6a491be..6465350 100644
--- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h
+++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h
@@ -44,6 +44,7 @@
 #include "xla/client/xla_builder.h"
 #include "xla/client/xla_computation.h"
 #include "xla/executable_run_options.h"
+#include "xla/fp_util.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/primitive_util.h"
@@ -140,9 +141,40 @@
   return IsSubnormal(value) || IsMinNormal(value);
 }
 
+// Get the floating point distance (the number of representable floating point
+// values) between expected and actual.
+//
+// This is a wrapper around xla::CalculateDistanceInFloats for most types. For
+// complex types, this returns the maximum distance between the real and
+// imaginary components.
+template <typename NativeT>
+int64_t GetDistanceErr(NativeT expected, NativeT actual) {
+  if constexpr (std::is_same_v<NativeT, xla::complex64> ||
+                std::is_same_v<NativeT, xla::complex128>) {
+    return std::max(
+        CalculateDistanceInFloats(expected.real(), actual.real()),
+        CalculateDistanceInFloats(expected.imag(), actual.imag()));
+  } else {
+    return CalculateDistanceInFloats(expected, actual);
+  }
+}
+
+class ErrorSpecBuilder;
+
 struct ErrorSpec {
-  double abs_err = 0;
-  double rel_err = 0;
+  using Builder = ErrorSpecBuilder;
+
+  double abs_err = 0.0;
+  double rel_err = 0.0;
+  // The acceptable number of floating point values between the expected and
+  // actual values (also called the floating point distance).
+  //
+  // This is similar to absolute error, but the same distance_err can have
+  // different floating point values as the exponent changes. It is effectively
+  // a hybrid of absolute and relative error: it allows a fixed difference in
+  // representable values (like abs_err), but the absolute size of that
+  // difference scales with the magnitude of the number (like rel_err).
+  int64_t distance_err = 0;
   // If true, will consider -0 not near to +0 and vice versa.  Note that
   // +epsilon may still be considered close to -0, depending on the error
   // spec; this only covers the case when both `expected` and `actual` are
@@ -154,6 +186,35 @@
   bool skip_comparison = false;
 };
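The distance_err field tolerates a fixed number of representable values between expected and actual, so the absolute error it admits grows with the exponent of the value. A small illustration (not part of the diff) of how 1 ULP differs by magnitude for float:

    #include <cmath>
    #include <cstdio>

    int main() {
      // 1 ULP near 1.0f is ~1.19e-7, while 1 ULP near 1024.0f is ~1.22e-4,
      // so distance_err = 1 admits very different absolute errors.
      float ulp_at_1 = std::nextafter(1.0f, 2.0f) - 1.0f;
      float ulp_at_1024 = std::nextafter(1024.0f, 2048.0f) - 1024.0f;
      std::printf("%g %g\n", ulp_at_1, ulp_at_1024);
      return 0;
    }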
 
+// Builder pattern to construct an ErrorSpec without a proliferation of
+// constructors or requiring extensive argument name comments.
+//
+// You can use an lvalue or rvalue to call the setter functions, but you can
+// only build (explicitly or implicitly) from an rvalue (e.g. via std::move).
+class ErrorSpecBuilder {
+ public:
+  ErrorSpecBuilder() : spec_() {}
+
+  ErrorSpecBuilder& abs_err(double abs_err) &;
+  ErrorSpecBuilder& rel_err(double rel_err) &;
+  ErrorSpecBuilder& distance_err(int64_t distance_err) &;
+  ErrorSpecBuilder& strict_signed_zeros(bool strict_signed_zeros = true) &;
+  ErrorSpecBuilder& skip_comparison(bool skip_comparison = true) &;
+
+  ErrorSpecBuilder&& abs_err(double abs_err) &&;
+  ErrorSpecBuilder&& rel_err(double rel_err) &&;
+  ErrorSpecBuilder&& distance_err(int64_t distance_err) &&;
+  ErrorSpecBuilder&& strict_signed_zeros(bool strict_signed_zeros = true) &&;
+  ErrorSpecBuilder&& skip_comparison(bool skip_comparison = true) &&;
+
+  ErrorSpec build() &&;
+
+  explicit operator ErrorSpec() &&;
+
+ private:
+  ErrorSpec spec_;
+};
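A short usage sketch of the builder (illustrative only; MakeSpec is a hypothetical helper, not part of the test utilities): a temporary builder can be chained and built in one expression, while a named builder must be moved before building, matching the lvalue/rvalue overloads above.

    #include <utility>

    inline ErrorSpec MakeSpec() {
      // Temporary (rvalue) builder: chain setters and build directly.
      ErrorSpec from_temporary =
          ErrorSpec::Builder().abs_err(1e-6).strict_signed_zeros().build();
      (void)from_temporary;
      // Named (lvalue) builder: setters return lvalue references; build()
      // requires an rvalue, hence std::move.
      ErrorSpecBuilder builder;
      builder.rel_err(1e-4).distance_err(1);
      return std::move(builder).build();
    }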
+
 // Representations of the reference function passed in by the user.
 template <typename NativeRefT, size_t K>
 struct EvaluateOpWrapper {};
@@ -592,8 +653,12 @@
     double abs_err =
         std::abs(ReplaceInfWithMax(expected) - ReplaceInfWithMax(actual));
     double rel_err = abs_err / std::abs(ReplaceInfWithMax(expected));
+    // N.B.: For sub-32-bit floats, NativeRefT is `float`, so ULP comparisons
+    // will be wildly off. We convert back to NativeT for this comparison.
+    int64_t distance_err = GetDistanceErr(NativeT(expected), NativeT(actual));
 
-    return abs_err <= spec.abs_err || rel_err <= spec.rel_err;
+    return abs_err <= spec.abs_err || rel_err <= spec.rel_err ||
+           distance_err <= spec.distance_err;
   }
 
   // Converts part or all bits in an uint64_t to the value of the floating point
@@ -1089,7 +1154,7 @@
       kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits<double>::min();
   double rtol = kDefaultRelativeToleranceSlackFactor *
                 std::numeric_limits<double>::epsilon();
-  return ErrorSpec{atol, rtol};
+  return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
 }
 
 template <>
@@ -1097,7 +1162,7 @@
   double atol =
       kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits<float>::min();
   double rtol = 40 * std::numeric_limits<float>::epsilon();
-  return ErrorSpec{atol, rtol};
+  return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
 }
 
 template <>
@@ -1106,7 +1171,7 @@
       kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits<double>::min();
   double rtol = kDefaultRelativeToleranceSlackFactor *
                 std::numeric_limits<double>::epsilon();
-  return ErrorSpec{atol, rtol};
+  return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
 }
 
 template <>
@@ -1115,7 +1180,7 @@
       kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits<float>::min();
   double rtol = kDefaultRelativeToleranceSlackFactor *
                 std::numeric_limits<float>::epsilon();
-  return ErrorSpec{atol, rtol};
+  return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
 }
 
 template <>
@@ -1124,7 +1189,7 @@
                 std::numeric_limits<Eigen::half>::min();
   // epsilon for FP16 is quite large, so a slack factor of 5 suffices.
   double rtol = 5 * std::numeric_limits<Eigen::half>::epsilon();
-  return ErrorSpec{atol, rtol};
+  return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
 }
 
 template <>
@@ -1133,7 +1198,7 @@
                 std::numeric_limits<bfloat16>::min();
   // epsilon for BF16 is quite large, so a slack factor of 2 suffices.
   double rtol = 2 * std::numeric_limits<bfloat16>::epsilon();
-  return ErrorSpec{atol, rtol};
+  return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
 }
 
 template <>
@@ -1142,7 +1207,7 @@
       kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits<double>::min();
   double rtol = kDefaultRelativeToleranceSlackFactor *
                 std::numeric_limits<double>::epsilon();
-  return ErrorSpec{atol, rtol};
+  return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
 }
 
 template <>
@@ -1151,7 +1216,7 @@
       kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits<float>::min();
   double rtol = kDefaultRelativeToleranceSlackFactor *
                 std::numeric_limits<float>::epsilon();
-  return ErrorSpec{atol, rtol};
+  return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
 }
 
 template <>
@@ -1160,7 +1225,7 @@
                 std::numeric_limits<Eigen::half>::min();
   // epsilon for FP16 is quite large, so a slack factor of 5 suffices.
   double rtol = 5 * std::numeric_limits<Eigen::half>::epsilon();
-  return ErrorSpec{atol, rtol};
+  return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
 }
 
 template <>
@@ -1169,7 +1234,7 @@
                 std::numeric_limits<bfloat16>::min();
   // epsilon for BF16 is quite large, so a slack factor of 5 suffices.
   double rtol = 2 * std::numeric_limits<bfloat16>::epsilon();
-  return ErrorSpec{atol, rtol};
+  return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
 }
 
 template <PrimitiveType T, size_t N>
@@ -1227,7 +1292,13 @@
 };
 
 template <PrimitiveType T>
-using ExhaustiveBinaryTest = ExhaustiveOpTestBase<T, 2>;
+class ExhaustiveBinaryTest : public ExhaustiveOpTestBase<T, 2> {
+ public:
+  using typename ExhaustiveOpTestBase<T, 2>::ErrorSpecGen;
+  static ErrorSpecGen GetDefaultSpecGenerator() {
+    return exhaustive_op_test::GetDefaultSpecGenerator<T, 2>();
+  }
+};
 
 }  // namespace exhaustive_op_test
 }  // namespace xla
diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc
index 8c47362..6df8322 100644
--- a/third_party/xla/xla/tests/hlo_test_base.cc
+++ b/third_party/xla/xla/tests/hlo_test_base.cc
@@ -16,6 +16,7 @@
 #include "xla/tests/hlo_test_base.h"
 
 #include <functional>
+#include <iterator>
 #include <memory>
 #include <set>
 #include <string>
@@ -27,6 +28,9 @@
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
 #include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
 #include "xla/layout_util.h"
 #include "xla/service/hlo_module_util.h"
 #include "xla/service/hlo_parser.h"
@@ -1010,23 +1014,15 @@
 
 HloComputation* HloTestBase::FindComputation(HloModule* module,
                                              absl::string_view name) {
-  auto computations = module->computations();
-  auto it = absl::c_find_if(
-      computations, [&](HloComputation* c) { return c->name() == name; });
-  if (it == computations.end()) {
-    return nullptr;
-  }
-  return *it;
+  return hlo_query::FindComputation(module, name);
 }
 
 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
                                              absl::string_view name) {
-  for (const HloComputation* c : module->computations()) {
-    auto instructions = c->instructions();
-    auto it = absl::c_find_if(
-        instructions, [&](HloInstruction* i) { return i->name() == name; });
-    if (it != instructions.end()) {
-      return *it;
+  for (const HloComputation* computation : module->computations()) {
+    if (auto instruction = hlo_query::FindFirstInstruction(computation, name);
+        instruction.first != nullptr) {
+      return instruction.first;
     }
   }
   return nullptr;
@@ -1034,17 +1030,25 @@
 
 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
                                              HloOpcode opcode) {
-  for (const HloComputation* c : module->computations()) {
-    auto instructions = c->instructions();
-    auto it = absl::c_find_if(
-        instructions, [&](HloInstruction* i) { return i->opcode() == opcode; });
-    if (it != instructions.end()) {
-      return *it;
+  for (const HloComputation* computation : module->computations()) {
+    if (auto instruction = hlo_query::FindFirstInstruction(computation, opcode);
+        instruction.first != nullptr) {
+      return instruction.first;
     }
   }
   return nullptr;
 }
 
+std::vector<HloInstruction*> HloTestBase::FindInstructions(HloModule* module,
+                                                           HloOpcode opcode) {
+  std::vector<HloInstruction*> instructions;
+  for (const HloComputation* c : module->computations()) {
+    absl::c_copy_if(c->instructions(), std::back_inserter(instructions),
+                    [&](HloInstruction* i) { return i->opcode() == opcode; });
+  }
+  return instructions;
+}
+
 se::DeviceMemoryAllocator* HloTestBase::GetAllocator() {
   if (allocator_ == nullptr) {
     allocator_ = std::make_unique<se::StreamExecutorMemoryAllocator>(
diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h
index 9858ed6..4c194c8 100644
--- a/third_party/xla/xla/tests/hlo_test_base.h
+++ b/third_party/xla/xla/tests/hlo_test_base.h
@@ -25,8 +25,10 @@
 
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_module_group.h"
+#include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/service/backend.h"
 #include "xla/service/computation_layout.h"
 #include "xla/service/hlo_runner.h"
@@ -423,13 +425,19 @@
   }
 
   // Gets the computation/instruction from the given module with the given name.
-  //
+  // Note: prefer calling these helpers directly via the hlo_query.h header,
+  // since they do not depend on any test-time variables or context.
+
   // This is useful for tests which create HLOs from a string and then want to
   // inspect a particular computation or instruction.
   HloComputation* FindComputation(HloModule* module, absl::string_view name);
   HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
   // Gets the instruction from the given module with the given opcode.
   HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode);
+  // Gets all the instructions from the given module with the given opcode.
+  std::vector<HloInstruction*> FindInstructions(HloModule* module,
+                                                HloOpcode opcode);
 
   // Return an HLO verifier constructed for the test backend.
   HloVerifier& verifier() const { return *hlo_verifier_; }
@@ -438,6 +446,7 @@
 
   // Returns the backend owned by the test runner.
   Backend& backend();
+  int64_t num_devices() { return backend().device_count(); }
 
   HloRunner test_runner_;
   HloRunner reference_runner_;
@@ -513,6 +522,13 @@
       se::Platform* test_platform);
 };
 
+#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x)                      \
+  int64_t num_devices = backend().device_count();                  \
+  if (num_devices < x) {                                           \
+    GTEST_SKIP() << "Test requires at least " << x << " devices (" \
+                 << num_devices << " available)";                  \
+  }
+
 }  // namespace xla
 
 #endif  // XLA_TESTS_HLO_TEST_BASE_H_
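
A hypothetical test sketch showing how the new FindInstructions helper and the SKIP_TEST_IF_NUM_DEVICES_LESS_THAN macro added above could be used together; the fixture alias, test name, and HLO text are made up for illustration.

#include <vector>

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
namespace {

using ExampleMultiDeviceTest = HloTestBase;

TEST_F(ExampleMultiDeviceTest, FindsAllReduces) {
  // Skips (rather than fails) when fewer than 4 devices are attached.
  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(4);

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
    HloModule m
    add {
      a = f32[] parameter(0)
      b = f32[] parameter(1)
      ROOT s = f32[] add(a, b)
    }
    ENTRY e {
      p = f32[] parameter(0)
      ROOT ar = f32[] all-reduce(p), to_apply=add, replica_groups={}
    })"));

  // New helper: collects every instruction with the given opcode across all
  // computations in the module.
  std::vector<HloInstruction*> all_reduces =
      FindInstructions(module.get(), HloOpcode::kAllReduce);
  EXPECT_EQ(all_reduces.size(), 1);
}

}  // namespace
}  // namespace xla
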
diff --git a/third_party/xla/xla/tests/numerics_test.cc b/third_party/xla/xla/tests/numerics_test.cc
index 8a54242..b1bfcd9 100644
--- a/third_party/xla/xla/tests/numerics_test.cc
+++ b/third_party/xla/xla/tests/numerics_test.cc
@@ -24,6 +24,7 @@
 #include "xla/tests/hlo_test_base.h"
 #include "xla/tests/test_macros.h"
 #include "xla/types.h"
+#include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
 
 namespace xla {
@@ -86,5 +87,79 @@
       std::numeric_limits<float>::quiet_NaN(), 0));
 }
 
+// Case reported by an XLA user: the following code produced incorrect results
+// on the CPU thunks backend (due to incorrectly generated LLVM IR).
+// The HLO module is optimized for the CPU backend and may be invalid for
+// other backends.
+XLA_TEST_F(NumericsTest,
+           DISABLED_ON_GPU(DISABLED_ON_TPU(MultiplySubtractConcatTest))) {
+  const char* test_hlo = R"(
+    HloModule jit_step, is_scheduled=true
+
+    fused_computation {
+      param_0.2 = f32[1,5] parameter(0)
+      slice.11 = f32[1,1] slice(param_0.2), slice={[0:1], [1:2]}
+      slice.10 = f32[1,1] slice(param_0.2), slice={[0:1], [4:5]}
+      multiply.11 = f32[1,1] multiply(slice.11, slice.10)
+      slice.9 = f32[1,1] slice(param_0.2), slice={[0:1], [2:3]}
+      slice.8 = f32[1,1] slice(param_0.2), slice={[0:1], [3:4]}
+      multiply.10 = f32[1,1] multiply(slice.9, slice.8)
+      subtract.5 = f32[1,1] subtract(multiply.11, multiply.10)
+      slice.6 = f32[1,1] slice(param_0.2), slice={[0:1], [0:1]}
+      multiply.8 = f32[1,1] multiply(slice.6, slice.10)
+      subtract.4 = f32[1,1] subtract(slice.9, multiply.8)
+      ROOT concatenate.1 = f32[1,3] concatenate(
+        subtract.5, subtract.4, subtract.4), dimensions={1}
+    } // fused_computation
+
+    ENTRY main {
+      Arg_0.0 = f32[1,5] parameter(0)
+      ROOT fusion = f32[1,3] fusion(Arg_0.0), kind=kLoop,
+        calls=fused_computation
+    } // main
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
+                          ParseAndReturnVerifiedModule(test_hlo));
+  auto argument = LiteralUtil::CreateR2<float>(
+      {{0.261473775, -0.642940283, -0.719902277, 0.712947428, 0.543724537}});
+
+  TF_ASSERT_OK_AND_ASSIGN(auto test_result,
+                          Execute(std::move(test_module), {&argument},
+                                  /*run_hlo_passes=*/false));
+
+  // Reference HLO module. It is a subgraph of the test module that performs
+  // only the calculations needed for the test module's first output element.
+  const char* reference_hlo = R"(
+    HloModule jit_step, is_scheduled=true
+
+    fused_computation {
+      param_0.2 = f32[1,5] parameter(0)
+      slice.11 = f32[1,1] slice(param_0.2), slice={[0:1], [1:2]}
+      slice.10 = f32[1,1] slice(param_0.2), slice={[0:1], [4:5]}
+      multiply.11 = f32[1,1] multiply(slice.11, slice.10)
+      slice.9 = f32[1,1] slice(param_0.2), slice={[0:1], [2:3]}
+      slice.8 = f32[1,1] slice(param_0.2), slice={[0:1], [3:4]}
+      multiply.10 = f32[1,1] multiply(slice.9, slice.8)
+      ROOT subtract.5 = f32[1,1] subtract(multiply.11, multiply.10)
+    } // fused_computation
+
+    ENTRY main {
+      Arg_0.0 = f32[1,5] parameter(0)
+      ROOT fusion = f32[1,1] fusion(Arg_0.0), kind=kLoop,
+        calls=fused_computation
+    } // main
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto reference_module,
+                          ParseAndReturnVerifiedModule(reference_hlo));
+  TF_ASSERT_OK_AND_ASSIGN(auto reference_result,
+                          Execute(std::move(reference_module), {&argument},
+                                  /*run_hlo_passes=*/false));
+
+  // Only compare the first element.
+  EXPECT_EQ(reference_result.data<float>()[0], test_result.data<float>()[0]);
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/third_party/xla/xla/tests/replicated_io_feed_test.cc b/third_party/xla/xla/tests/replicated_io_feed_test.cc
index 9ee34a7..0164f8b 100644
--- a/third_party/xla/xla/tests/replicated_io_feed_test.cc
+++ b/third_party/xla/xla/tests/replicated_io_feed_test.cc
@@ -50,7 +50,10 @@
     result = u32[] add(infeed.data, replica_id)
     outfeed = token[] outfeed(result, infeed.token), outfeed_shape=u32[]
   })";
+
   const int kNumReplicas = 4;
+  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
+
   auto config = GetModuleConfigForTest();
   config.set_replica_count(kNumReplicas);
   std::unique_ptr<HloModule> module =
diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD
index 36e284e..6975a38 100644
--- a/third_party/xla/xla/tools/BUILD
+++ b/third_party/xla/xla/tools/BUILD
@@ -56,7 +56,6 @@
     name = "hex_floats_to_packed_literal",
     srcs = ["hex_floats_to_packed_literal.cc"],
     deps = [
-        "//xla:types",
         "//xla/tsl/util:command_line_flags",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/strings",
@@ -127,13 +126,12 @@
     name = "convert_computation",
     srcs = ["convert_computation.cc"],
     deps = [
-        "//xla:types",
         "//xla/service:hlo_proto_cc",
-        "@com_google_absl//absl/status:statusor",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:platform_port",
         "@local_tsl//tsl/platform:protobuf",
+        "@local_tsl//tsl/platform:status",
     ],
 )
 
@@ -193,19 +191,24 @@
     name = "dumped_computation_to_text",
     srcs = ["dumped_computation_to_text.cc"],
     deps = [
-        "//xla:types",
-        "//xla/client",
+        "//xla:shape_util",
+        "//xla:xla_proto_cc",
         "//xla/client:client_library",
+        "//xla/client:executable_build_options",
         "//xla/client:local_client",
         "//xla/client:xla_computation",
+        "//xla/hlo/ir:hlo",
         "//xla/service",
         "//xla/service:hlo_proto_cc",
         "//xla/service:interpreter_plugin",
+        "//xla/service:local_service",
+        "//xla/tsl/util:command_line_flags",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:platform_port",
+        "@local_tsl//tsl/platform:status",
     ],
 )
 
@@ -220,15 +223,16 @@
     name = "dumped_computation_to_operation_list",
     srcs = ["dumped_computation_to_operation_list.cc"],
     deps = [
-        "//xla:types",
-        "//xla/client",
+        "//xla:shape_util",
         "//xla/client:client_library",
+        "//xla/client:executable_build_options",
         "//xla/client:local_client",
         "//xla/client:xla_computation",
         "//xla/hlo/ir:hlo",
         "//xla/service",
         "//xla/service:hlo_proto_cc",
         "//xla/service:interpreter_plugin",
+        "//xla/service:local_service",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
@@ -237,6 +241,7 @@
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:platform_port",
+        "@local_tsl//tsl/platform:status",
     ],
 )
 
@@ -508,6 +513,19 @@
     ],
 )
 
+xla_cc_test(
+    name = "prepare_reference_module_test",
+    srcs = ["prepare_reference_module_test.cc"],
+    deps = [
+        ":prepare_reference_module",
+        "//xla:test",
+        "//xla/hlo/ir:hlo",
+        "//xla/tests:hlo_test_base",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
 tf_proto_library(
     name = "run_hlo_module_proto",
     srcs = ["run_hlo_module.proto"],
@@ -633,6 +651,7 @@
         "data/add.hlo",
         "data/add_mhlo.mlir",
         "data/add_stablehlo.mlir",
+        "data/input_literal_f32_2_2.pbtxt",
         "data/must_alias.hlo",
         "data/must_alias_with_sharding.hlo",
         ":run_hlo_module",
@@ -644,6 +663,7 @@
         "//xla/service:hlo_parser",
         "//xla/tsl/lib/core:status_test_util",
         "@com_google_absl//absl/log",
+        "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:path",
         "@local_tsl//tsl/platform:statusor",
@@ -718,8 +738,14 @@
     srcs = ["compute_cost.cc"],
     deps = [
         ":hlo_module_loader",
+        "//xla:debug_options_flags",
+        "//xla:shape_util",
         "//xla/service:hlo_cost_analysis",
+        "//xla/tsl/util:command_line_flags",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:platform_port",
+        "@local_tsl//tsl/platform:status",
     ],
 )
 
@@ -798,32 +824,23 @@
         "//xla/service/gpu:amdgpu_compiler",
         "//xla/service/gpu:amdgpu_compiler_impl",
     ]) + if_gpu_is_configured([
-        "//xla/service/gpu:autotuner_util",
         "//xla/service/gpu:executable_proto_cc",
         "//xla/service/gpu:gpu_compiler",
+        "//xla/service/gpu/autotuning:autotuner_util",
         "//xla/stream_executor/gpu:gpu_init",
         "//xla/service/gpu:gpu_symbol_repository",
     ]) + if_google(["@com_google_protobuf//:duration_cc_proto"]),
 )
 
 xla_test(
-    name = "xla_compile_lib_test",
-    srcs = ["xla_compile_lib_test.cc"],
-    backend_tags = {
-        "gpu": ["requires-gpu-nvidia"] + if_google(["config-cuda-only"]),
-    },
+    name = "xla_cpu_compile_lib_test",
+    srcs = ["xla_cpu_compile_lib_test.cc"],
     backends = [
         "cpu",
-        "gpu",
     ],
     data = [
         ":data/add.hlo",
-        "//xla/service:xla_aot_compile_test_gpu_target_config.prototxt",
-        "//xla/service/gpu:gpu_compiler_test_autotune_db.textproto",
     ],
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
-        "TENSORFLOW_USE_ROCM=1",
-    ]),
     deps = [
         ":xla_compile_lib",
         "//xla:util",
@@ -831,11 +848,7 @@
         "//xla/service:platform_util",
         "//xla/service:symbol_repository",
         "//xla/service:xla_compile_result_proto_cc_impl",
-        "//xla/service/gpu:autotuner_util",
-        "//xla/service/gpu:gpu_symbol_repository",
-        "//xla/stream_executor:device_description_proto_cc",
         "//xla/tests:hlo_test_base",
-        "//xla/tests:test_macros_header",
         "//xla/tests:xla_internal_test_main",  # fixdeps: keep
         "//xla/tsl/lib/core:status_test_util",
         "@com_google_absl//absl/strings",
@@ -843,6 +856,7 @@
         "@com_google_googletest//:gtest",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:env_time",
+        "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:path",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:statusor",
@@ -853,6 +867,44 @@
 )
 
 xla_test(
+    name = "xla_gpu_compile_lib_test",
+    srcs = ["xla_gpu_compile_lib_test.cc"],
+    backend_tags = {
+        "gpu": ["requires-gpu-nvidia"] + if_google(["config-cuda-only"]),
+    },
+    backends = [
+        "gpu",
+    ],
+    data = [
+        ":data/add.hlo",
+        "//xla/service:xla_aot_compile_test_gpu_target_config.prototxt",
+        "//xla/service/gpu:gpu_compiler_test_autotune_db.textproto",
+    ],
+    deps = [
+        ":xla_compile_lib",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:platform_util",
+        "//xla/service:symbol_repository",
+        "//xla/service:xla_compile_result_proto_cc_impl",
+        "//xla/service/gpu:gpu_symbol_repository",
+        "//xla/service/gpu/autotuning:autotuner_util",
+        "//xla/stream_executor:device_description_proto_cc",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "//xla/tsl/lib/core:status_test_util",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:env",
+        "@local_tsl//tsl/platform:path",
+        "@local_tsl//tsl/platform:status_matchers",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
+        "@local_tsl//tsl/protobuf:status_proto_cc",
+    ],
+)
+
+xla_test(
     name = "hlo_decomposer_test",
     srcs = ["hlo_decomposer_test.cc"],
     deps = [
diff --git a/third_party/xla/xla/tools/compute_cost.cc b/third_party/xla/xla/tools/compute_cost.cc
index c5e0ddc..9615ae0 100644
--- a/third_party/xla/xla/tools/compute_cost.cc
+++ b/third_party/xla/xla/tools/compute_cost.cc
@@ -21,9 +21,16 @@
 #include <string>
 #include <vector>
 
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "xla/debug_options_flags.h"
 #include "xla/service/hlo_cost_analysis.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "xla/tools/hlo_module_loader.h"
+#include "xla/tsl/util/command_line_flags.h"
 #include "tsl/platform/init_main.h"
+#include "tsl/platform/status.h"
 
 namespace {
 const char* const kUsage = R"(
diff --git a/third_party/xla/xla/tools/convert_computation.cc b/third_party/xla/xla/tools/convert_computation.cc
index f81d517..7ebc5d3 100644
--- a/third_party/xla/xla/tools/convert_computation.cc
+++ b/third_party/xla/xla/tools/convert_computation.cc
@@ -16,6 +16,7 @@
 // Usage: convert_computation <txt2bin|bin2txt> serialized_computation_proto
 //
 // bin2txt spits out the result to stdout. txt2bin modifies the file in place.
+#include "tsl/platform/status.h"
 #ifndef _WIN32
 #include <unistd.h>
 #endif
@@ -23,9 +24,7 @@
 
 #include <string>
 
-#include "absl/status/statusor.h"
 #include "xla/service/hlo.pb.h"
-#include "xla/types.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/init_main.h"
 #include "tsl/platform/logging.h"
diff --git a/third_party/xla/xla/tools/data/input_literal_f32_2_2.pbtxt b/third_party/xla/xla/tools/data/input_literal_f32_2_2.pbtxt
new file mode 100644
index 0000000..6c39d03
--- /dev/null
+++ b/third_party/xla/xla/tools/data/input_literal_f32_2_2.pbtxt
@@ -0,0 +1,20 @@
+# proto-file: third_party/tensorflow/compiler/xla/tools/run_hlo_module.proto
+# proto-message: RunHloModuleIterationLiterals
+arguments {
+  shape {
+    element_type: F32
+    dimensions: 2
+    dimensions: 2
+    layout {
+      minor_to_major: 1
+      minor_to_major: 0
+      tail_padding_alignment_in_elements: 1
+    }
+    is_dynamic_dimension: false
+    is_dynamic_dimension: false
+  }
+  f32s: 0.1
+  f32s: 0.2
+  f32s: 0.3
+  f32s: 0.4
+}
\ No newline at end of file
diff --git a/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc b/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc
index e17d487..6450210 100644
--- a/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc
+++ b/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc
@@ -22,20 +22,26 @@
 
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
 #include "absl/types/span.h"
-#include "xla/client/client.h"
 #include "xla/client/client_library.h"
+#include "xla/client/executable_build_options.h"
 #include "xla/client/local_client.h"
 #include "xla/client/xla_computation.h"
 #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/service/hlo.pb.h"
-#include "xla/service/service.h"
-#include "xla/types.h"
+#include "xla/service/local_service.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/init_main.h"
 #include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
 
 namespace xla {
 namespace tools {
diff --git a/third_party/xla/xla/tools/dumped_computation_to_text.cc b/third_party/xla/xla/tools/dumped_computation_to_text.cc
index df9116e..695d4c9 100644
--- a/third_party/xla/xla/tools/dumped_computation_to_text.cc
+++ b/third_party/xla/xla/tools/dumped_computation_to_text.cc
@@ -21,16 +21,21 @@
 
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
-#include "xla/client/client.h"
 #include "xla/client/client_library.h"
+#include "xla/client/executable_build_options.h"
 #include "xla/client/local_client.h"
 #include "xla/client/xla_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/hlo.pb.h"
-#include "xla/service/service.h"
-#include "xla/types.h"
+#include "xla/service/local_service.h"
+#include "xla/shape.h"
+#include "xla/tsl/util/command_line_flags.h"
+#include "xla/xla.pb.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/init_main.h"
 #include "tsl/platform/logging.h"
+#include "tsl/platform/status.h"
 
 namespace xla {
 namespace tools {
diff --git a/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc b/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc
index c4d591b..6388a8f 100644
--- a/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc
+++ b/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc
@@ -21,10 +21,10 @@
 #include "absl/base/casts.h"
 #include "absl/strings/string_view.h"
 #include "xla/tsl/util/command_line_flags.h"
-#include "xla/types.h"
 #include "tsl/lib/io/buffered_inputstream.h"
 #include "tsl/lib/io/random_inputstream.h"
 #include "tsl/platform/env.h"
+#include "tsl/platform/file_system.h"
 #include "tsl/platform/init_main.h"
 #include "tsl/platform/logging.h"
 #include "tsl/platform/status.h"
diff --git a/third_party/xla/xla/tools/hlo_bisect/BUILD b/third_party/xla/xla/tools/hlo_bisect/BUILD
index 3df7fbb..7f9747b 100644
--- a/third_party/xla/xla/tools/hlo_bisect/BUILD
+++ b/third_party/xla/xla/tools/hlo_bisect/BUILD
@@ -84,6 +84,7 @@
         "//xla:protobuf_util",
         "//xla:util",
         "//xla/hlo/ir:hlo",
+        "//xla/service:dump",
         "//xla/service:hlo_parser",
         "//xla/service:hlo_proto_cc",
         "//xla/service:hlo_proto_util",
diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc
index 0e38151..d4e6d0d 100644
--- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc
+++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc
@@ -25,6 +25,7 @@
 #include "xla/error_spec.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/protobuf_util.h"
+#include "xla/service/dump.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/service/hlo_parser.h"
 #include "xla/service/hlo_proto_util.h"
@@ -137,7 +138,7 @@
   HloProto proto = MakeHloProto(*module);
   if (output_format == "hlo") {
     tsl::Env* env = tsl::Env::Default();
-    TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(std::string(dir_path)));
+    TF_RETURN_IF_ERROR(CreateDirIfNeeded(std::string(dir_path), env));
     std::string file_path =
         tsl::io::JoinPath(dir_path, SanitizeFileName(file_name)) + ".hlo";
     LOG(INFO) << "Dumped HLO text to " << file_path;
@@ -148,8 +149,8 @@
                              .set_compact_operands(false))));
   } else if (output_format == "pb") {
     std::string path;
-    TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
-        proto, std::string(dir_path), file_name, &path));
+    TF_RETURN_IF_ERROR(
+        DumpProtoToDirectory(proto, std::string(dir_path), file_name, &path));
     LOG(INFO) << "Dumped HLO module proto to " << path;
 
   } else {
diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening.cc b/third_party/xla/xla/tools/hlo_control_flow_flattening.cc
index d37c6b0..e81f3ef 100644
--- a/third_party/xla/xla/tools/hlo_control_flow_flattening.cc
+++ b/third_party/xla/xla/tools/hlo_control_flow_flattening.cc
@@ -496,7 +496,7 @@
           TF_RETURN_IF_ERROR(RemoveCollective(instruction).status());
         }
         changed = true;
-      } else if (remove_comm_ &&
+      } else if ((remove_comm_ || remove_id_) &&
                  (instruction->opcode() == HloOpcode::kPartitionId ||
                   instruction->opcode() == HloOpcode::kReplicaId ||
                   (instruction->opcode() == HloOpcode::kCustomCall &&
diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening.h b/third_party/xla/xla/tools/hlo_control_flow_flattening.h
index cff9db4..450aeab 100644
--- a/third_party/xla/xla/tools/hlo_control_flow_flattening.h
+++ b/third_party/xla/xla/tools/hlo_control_flow_flattening.h
@@ -49,6 +49,8 @@
     bool flatten_while_loop = true;
     bool remove_comm = true;
     bool remove_host_transfer = false;
+    // Removes partition-id, replica-id, and slice-id.
+    bool remove_id = false;
   };
   explicit HloControlFlowFlattening(const Options& options)
       : while_execution_count_(options.while_execution_count),
@@ -57,7 +59,8 @@
         remove_infeed_outfeed_(options.remove_infeed_outfeed),
         flatten_while_loop_(options.flatten_while_loop),
         remove_host_transfer_(options.remove_host_transfer),
-        remove_comm_(options.remove_comm) {}
+        remove_comm_(options.remove_comm),
+        remove_id_(options.remove_id) {}
   ~HloControlFlowFlattening() override = default;
   absl::string_view name() const override { return "control-flow-flattening"; }
   using HloPassInterface::Run;
@@ -102,6 +105,7 @@
       HloInstruction* recv_done,
       absl::flat_hash_set<HloInstruction*>* additional_removed) const;
   bool remove_comm_;
+  bool remove_id_;
 };
 
 // Retrieves the original loop bound. If that fails, returns a default value. If bounds
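
A minimal sketch of enabling the new remove_id option by field name instead of the positional aggregate initialization used in the test below; the wrapper function is hypothetical.

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/tools/hlo_control_flow_flattening.h"

// Illustrative only: strips partition-id/replica-id (the test below expects
// them to be replaced with constants) while keeping collective ops intact.
absl::StatusOr<bool> StripIds(xla::HloModule* module) {
  xla::HloControlFlowFlattening::Options options;
  options.remove_comm = false;  // keep collectives such as all-reduce
  options.remove_id = true;     // remove partition-id, replica-id, slice-id
  xla::HloControlFlowFlattening pass(options);
  return pass.Run(module);
}
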
diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc b/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc
index ceb51be..40d16ca 100644
--- a/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc
+++ b/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc
@@ -515,6 +515,37 @@
             "replica-id.18600");
 }
 
+TEST_F(HloControlFlowFlatteningTest, RemoveReplicaIdButKeepAllReduce) {
+  absl::string_view kHloText = R"(
+  HloModule RemoveReplicaIdButKeepCollective
+
+%sum (a: f32[], b: f32[]) -> f32[] {
+    %a = f32[] parameter(0)
+    %b = f32[] parameter(1)
+    ROOT %add = f32[] add(f32[] a, f32[] b)
+  }
+  ENTRY ReplicaId {
+    replica-id.1 = u32[]{:T(128)} replica-id()
+    ROOT all-reduce.1 = u32[]{:T(128)} all-reduce(replica-id.1), to_apply=sum, replica_groups={}
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(kHloText));
+  HloControlFlowFlattening flattening(HloControlFlowFlattening::Options{
+      /*while_execution_count=*/1, /*max_outer_loop_count=*/1,
+      /*max_loop_count=*/1, /*remove_infeed_outfeed=*/false,
+      /*flatten_while_loop=*/false, /*remove_comm=*/false,
+      /*remove_host_transfer=*/false, /*remove_id=*/true});
+  EXPECT_TRUE(flattening.Run(module.get()).value());
+  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
+                           /*allow_mixed_precision=*/true)
+                   .Run(module.get())
+                   .status());
+  EXPECT_THAT(module->entry_computation()->root_instruction(), op::AllReduce());
+  EXPECT_THAT(module->entry_computation()->root_instruction()->operand(0),
+              op::Constant());
+}
+
 TEST_F(HloControlFlowFlatteningTest, CollectivePermuteInPlaceUpdate) {
   absl::string_view hlo_string = R"(
   HloModule CollectivePermuteInPlaceUpdate
diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD
index cefc272..1e1450b 100644
--- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD
+++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD
@@ -84,6 +84,7 @@
     testonly = True,
     tags = [
         "gpu",
+        "no_rocm",
         "nomac",
     ] + tf_gpu_tests_tags(),
     deps = [
diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc
index 1a490a0..3aabf56 100644
--- a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc
+++ b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc
@@ -57,43 +57,41 @@
   if (enable_mock_nccl) {
     CHECK_GT(num_nodes, 1);
     return CreateMockGpuClient(num_nodes);
-  } else {
-    if (num_nodes == 1) {
-      return CreateGpuClient({});
-    } else {
-      TF_RET_CHECK(!address.empty());
-      TF_RET_CHECK(node_id >= 0)
-          << "Node id is expected to be in range [0, num_nodes)";
-      TF_RET_CHECK(node_id < num_nodes)
-          << "Node id is expected to be in range [0, num_nodes)";
-
-      CHECK_GT(address.length(), 0);
-      // Multinode. Start service on task 0.
-      if (node_id == 0) {
-        std::string coordinator_bind_address =
-            "[::]:" + std::string(address).substr(address.rfind(':') + 1);
-        xla::CoordinationServiceImpl::Options options;
-        options.num_nodes = num_nodes;
-        auto status_or = xla::GetDistributedRuntimeService(
-            coordinator_bind_address, options);
-        TF_QCHECK_OK(status_or.status());
-        service = std::move(status_or.value());
-      }
-      xla::DistributedRuntimeClient::Options options;
-      options.node_id = node_id;
-      options.init_timeout = init_timeout;
-      distributed_client =
-          GetDistributedRuntimeClient(std::string(address), options);
-      TF_QCHECK_OK(distributed_client->Connect());
-      kv_store = GetDistributedKeyValueStore(distributed_client,
-                                             /*key_prefix=*/"gpu:");
-      GpuClientOptions gpu_client_options;
-      gpu_client_options.node_id = node_id;
-      gpu_client_options.num_nodes = num_nodes;
-      gpu_client_options.kv_store = kv_store;
-      return CreateGpuClient(std::move(gpu_client_options));
-    }
   }
+
+  if (num_nodes == 1) {
+    return CreateGpuClient({});
+  }
+
+  TF_RET_CHECK(!address.empty());
+  TF_RET_CHECK(node_id >= 0)
+      << "Node id is expected to be in range [0, num_nodes)";
+  TF_RET_CHECK(node_id < num_nodes)
+      << "Node id is expected to be in range [0, num_nodes)";
+
+  CHECK_GT(address.length(), 0);
+  // Multinode. Start service on task 0.
+  if (node_id == 0) {
+    std::string coordinator_bind_address =
+        "[::]:" + std::string(address).substr(address.rfind(':') + 1);
+    xla::CoordinationServiceImpl::Options options;
+    options.num_nodes = num_nodes;
+    TF_ASSIGN_OR_RETURN(service, xla::GetDistributedRuntimeService(
+                                     coordinator_bind_address, options));
+  }
+  xla::DistributedRuntimeClient::Options options;
+  options.node_id = node_id;
+  options.init_timeout = init_timeout;
+  distributed_client =
+      GetDistributedRuntimeClient(std::string(address), options);
+  TF_QCHECK_OK(distributed_client->Connect());
+  kv_store = GetDistributedKeyValueStore(distributed_client,
+                                         /*key_prefix=*/"gpu:");
+  GpuClientOptions gpu_client_options;
+  gpu_client_options.node_id = node_id;
+  gpu_client_options.num_nodes = num_nodes;
+  gpu_client_options.kv_store = kv_store;
+  return CreateGpuClient(std::move(gpu_client_options));
 }
 
 absl::StatusOr<PjRtEnvironment> GetPjRtClient(absl::string_view device_type,
diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/data/multiple_gemm_fusions.hlo b/third_party/xla/xla/tools/multihost_hlo_runner/data/multiple_gemm_fusions.hlo
index c745ee7..c182a59 100644
--- a/third_party/xla/xla/tools/multihost_hlo_runner/data/multiple_gemm_fusions.hlo
+++ b/third_party/xla/xla/tools/multihost_hlo_runner/data/multiple_gemm_fusions.hlo
@@ -1,35 +1,46 @@
 f1 {
-  p0 = f16[720,720,720]{2,1,0} parameter(0)
-  p1 = s8[720,720,720]{2,1,0} parameter(1)
-  c = f16[720,720,720]{2,1,0} convert(p1)
-  ROOT d1 = f16[720,720,720]{2,1,0} dot(p0, c),
+  p0 = f16[64,64,64] parameter(0)
+  p1 = s8[64,64,64] parameter(1)
+  c = f16[64,64,64] convert(p1)
+  ROOT d1 = f32[64,64,64] dot(p0, c),
     lhs_batch_dims={0}, lhs_contracting_dims={2},
     rhs_batch_dims={0}, rhs_contracting_dims={1}
 }
 
 f2 {
-  p0 = s8[720,720,720]{2,1,0} parameter(0)
-  c0 = f32[720,720,720]{2,1,0} convert(p0)
-  p1 = f16[720,720,720]{2,1,0} parameter(1)
-  c1 = f32[720,720,720]{2,1,0} convert(p1)
-  ROOT %dot.1 = f32[720,720,720]{2,1,0} dot(c0, c1),
+  p0 = s8[64,64,64] parameter(0)
+  c0 = f32[64,64,64] convert(p0)
+  p1 = f16[64,64,64] parameter(1)
+  c1 = f32[64,64,64] convert(p1)
+  ROOT d2 = f32[64,64,64] dot(c0, c1),
     lhs_batch_dims={0}, lhs_contracting_dims={2},
     rhs_batch_dims={0}, rhs_contracting_dims={1}
 }
 
+f3 {
+  p0 = f16[64,64,64] parameter(0)
+  p1 = f16[64,64,64] parameter(1)
+  ROOT d3 = f32[64,64,64] dot(p0, p1),
+    lhs_batch_dims={0}, lhs_contracting_dims={1},
+    rhs_batch_dims={0}, rhs_contracting_dims={2}
+}
+
 fa {
-  p1 = f16[720,720,720]{2,1,0} parameter(1)
-  c = f32[720,720,720]{2,1,0} convert(p1)
-  p0 = f32[720,720,720]{2,1,0} parameter(0)
-  ROOT %add.1.1 = f32[720,720,720]{2,1,0} add(c, p0)
+  p0 = f32[64,64,64] parameter(0)
+  p1 = f32[64,64,64] parameter(1)
+  p2 = f32[64,64,64] parameter(2)
+  a1 = f32[64,64,64] add(p2, p1)
+  ROOT a = f32[64,64,64] add(p0, a1)
 }
 
 ENTRY e {
-  p1 = s8[720,720,720]{2,1,0} parameter(1)
-  p0 = f16[720,720,720]{2,1,0} parameter(0)
-  f1r = f16[720,720,720]{2,1,0} fusion(p0, p1), kind=kCustom, calls=f1,
+  p0 = f16[64,64,64] parameter(0)
+  p1 = s8[64,64,64] parameter(1)
+  f1r = f32[64,64,64] fusion(p0, p1), kind=kCustom, calls=f1,
     backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
-  f2r = f32[720,720,720]{2,1,0} fusion(p1, p0), kind=kCustom, calls=f2,
+  f2r = f32[64,64,64] fusion(p1, p0), kind=kCustom, calls=f2,
     backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
-  ROOT _ = f32[720,720,720]{2,1,0} fusion(f2r, f1r), kind=kLoop, calls=fa
+  f3r = f32[64,64,64] fusion(p0, p0), kind=kCustom, calls=f3,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+  ROOT _ = f32[64,64,64] fusion(f1r, f2r, f3r), kind=kLoop, calls=fa
 }
diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
index b55c1be..405e421 100644
--- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
+++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
@@ -256,7 +256,7 @@
 
 // Name of the test binary.
 static const char* binary_name;
-constexpr int kNumNodes = 3;
+constexpr int kNumNodes = 2;
 
 TEST_F(FunctionalHloRunnerTest, ShardedAutotuningWorks) {
   if (IsTestingCpu()) {
@@ -308,13 +308,8 @@
                         env.kv_store->Get("gemm_fusion_autotuning_results_1_1",
                                           absl::Seconds(1)));
     CHECK(absl::StrContains(results1, "run_time"));
-    // First two nodes autotune two different fusions.
+    // The nodes autotune different fusions.
     CHECK_NE(results0, results1);
-    TF_ASSIGN_OR_RETURN(std::string results2,
-                        env.kv_store->Get("gemm_fusion_autotuning_results_1_2",
-                                          absl::Seconds(1)));
-    // Third node has nothing to autotune.
-    CHECK(!absl::StrContains(results2, "run_time"));
   }
   return absl::OkStatus();
 }
diff --git a/third_party/xla/xla/tools/prepare_reference_module.cc b/third_party/xla/xla/tools/prepare_reference_module.cc
index 4ce766d..82fd57a 100644
--- a/third_party/xla/xla/tools/prepare_reference_module.cc
+++ b/third_party/xla/xla/tools/prepare_reference_module.cc
@@ -34,7 +34,8 @@
     const HloModule& test_module, HloRunnerInterface* test_runner,
     const std::function<void(HloModuleConfig*)>& config_modifier_hook,
     const std::function<absl::Status(const HloModule&, HloRunnerInterface*,
-                                     HloModule*)>& module_modifier_hook) {
+                                     HloModule*)>& module_modifier_hook,
+    bool skip_despecialization) {
   DebugOptions debug_options = GetDebugOptionsFromFlags();
   // The combination of fast math and optimizations leads to unsound code
   // transformations (see third_party/tensorflow/compiler/xla/xla.proto for
@@ -51,7 +52,7 @@
   if (module_modifier_hook) {
     TF_RETURN_IF_ERROR(
         module_modifier_hook(test_module, test_runner, reference_module.get()));
-  } else {
+  } else if (!skip_despecialization) {
     TF_RETURN_IF_ERROR(Despecializer().Run(reference_module.get()).status());
   }
   return std::move(reference_module);
diff --git a/third_party/xla/xla/tools/prepare_reference_module.h b/third_party/xla/xla/tools/prepare_reference_module.h
index 4a1064d..f26e847 100644
--- a/third_party/xla/xla/tools/prepare_reference_module.h
+++ b/third_party/xla/xla/tools/prepare_reference_module.h
@@ -37,7 +37,8 @@
     const HloModule& test_module, HloRunnerInterface* test_runner,
     const std::function<void(HloModuleConfig*)>& config_modifier_hook = {},
     const std::function<absl::Status(const HloModule&, HloRunnerInterface*,
-                                     HloModule*)>& module_modifier_hook = {});
+                                     HloModule*)>& module_modifier_hook = {},
+    bool skip_despecialization = false);
 
 }  // namespace xla
 
diff --git a/third_party/xla/xla/tools/prepare_reference_module_test.cc b/third_party/xla/xla/tools/prepare_reference_module_test.cc
new file mode 100644
index 0000000..0b2ad0e
--- /dev/null
+++ b/third_party/xla/xla/tools/prepare_reference_module_test.cc
@@ -0,0 +1,76 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/tools/prepare_reference_module.h"
+
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/test.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+
+const char* const kModuleStr = R"(
+  HloModule jit_step
+
+  %fused_computation (param_0.2: f32[1,4]) -> f32[1,3] {
+    %param_0.2 = f32[1,4]{1,0} parameter(0)
+    ROOT %slice.11 = f32[1,3]{1,0} slice(f32[1,4]{1,0} %param_0.2),
+      slice={[0:1], [0:3]}
+  }
+
+  ENTRY %main.3491 (Arg_0.0: f32[1,4]) -> f32[1,3] {
+    %Arg_0.0 = f32[1,4]{1,0} parameter(0)
+    ROOT %fusion = f32[1,3]{1,0} fusion(f32[1,4]{1,0} %Arg_0.0), kind=kLoop,
+      calls=%fused_computation
+  }
+)";
+
+using PrepareReferenceModuleTest = HloTestBase;
+
+// Ideally, the 'Despecializer' pass would be mocked. Since that is not
+// feasible with the current design, the despecialization tests in this file
+// rely on Despecializer's actual behavior (it removes fusion ops from the
+// module).
+TEST_F(PrepareReferenceModuleTest, PerformDespecialization) {
+  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
+                          ParseAndReturnVerifiedModule(kModuleStr));
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto reference_module,
+      PrepareReferenceModule(*test_module, nullptr, {}, {},
+                             /*skip_despecialization=*/false));
+
+  // Fusion op should have been removed.
+  EXPECT_THAT(reference_module->ToString(),
+              Not(::testing::HasSubstr("fusion")));
+}
+
+TEST_F(PrepareReferenceModuleTest, SkipDespecialization) {
+  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
+                          ParseAndReturnVerifiedModule(kModuleStr));
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto reference_module,
+      PrepareReferenceModule(*test_module, nullptr, {}, {},
+                             /*skip_despecialization=*/true));
+
+  // Fusion op should be there.
+  EXPECT_THAT(reference_module->ToString(), ::testing::HasSubstr("fusion"));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/tools/run_hlo_module.cc b/third_party/xla/xla/tools/run_hlo_module.cc
index 690b509..22c0c02 100644
--- a/third_party/xla/xla/tools/run_hlo_module.cc
+++ b/third_party/xla/xla/tools/run_hlo_module.cc
@@ -168,6 +168,12 @@
   return std::move(result_status).value();
 }
 
+void UseCpuThunkRuntime(HloModule& module) {
+  auto debug_options = module.config().debug_options();
+  debug_options.set_xla_cpu_use_thunk_runtime(true);
+  module.mutable_config().set_debug_options(debug_options);
+}
+
 absl::Status RunAndCompareInternal(
     std::unique_ptr<HloModule> test_module,
     const BufferAssignmentProto* buffer_assignment_proto,
@@ -255,17 +261,27 @@
 
   std::unique_ptr<HloModule> reference_module;
   if (reference_runner != nullptr) {
+    // If the reference platform is the same as the test platform, we should
+    // not deoptimize (i.e. despecialize) the reference module.
+    bool skip_deoptimization = options.reference_platform == options.platform;
+
     // PrepareReferenceModule needs to know the *test* runner, in order to
     // properly match the test runner's numerics.
     TF_ASSIGN_OR_RETURN(
         reference_module,
         copy_result_on_failure(
-            PrepareReferenceModule(*test_module, test_runner,
-                                   config_modifier_hook,
-                                   reference_module_modifier_hook),
+            PrepareReferenceModule(
+                *test_module, test_runner, config_modifier_hook,
+                reference_module_modifier_hook, skip_deoptimization),
             ModuleResult::kCompilationError, reference_run_result));
   }
 
+  // Now that reference_module is ready, we can modify test_module without
+  // affecting the reference run.
+  if (options.force_use_cpu_thunk_runtime_for_test) {
+    UseCpuThunkRuntime(*test_module);
+  }
+
   TF_ASSIGN_OR_RETURN(
       auto test_result,
       copy_result_on_failure(
diff --git a/third_party/xla/xla/tools/run_hlo_module.h b/third_party/xla/xla/tools/run_hlo_module.h
index 3300f1b..66afdc5 100644
--- a/third_party/xla/xla/tools/run_hlo_module.h
+++ b/third_party/xla/xla/tools/run_hlo_module.h
@@ -40,6 +40,7 @@
   bool flatten_control_flow{false};
   bool run_test_hlo_passes{true};
   bool run_reference_hlo_passes{true};
+  bool force_use_cpu_thunk_runtime_for_test{false};
   // Using small float range by default, as otherwise all reductions
   // miscompare vs. the interpreter with inf/nan.
   bool use_large_float_range{false};
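
A short sketch of setting the new option programmatically; the RunHloModuleOptions struct name is assumed here, only the force_use_cpu_thunk_runtime_for_test field comes from this change.

#include "xla/tools/run_hlo_module.h"

// Illustrative only: force the CPU thunk runtime for the test run; the
// reference run and non-CPU platforms are unaffected.
void EnableThunkRuntimeForTest(xla::RunHloModuleOptions& opts) {
  opts.force_use_cpu_thunk_runtime_for_test = true;
}
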
diff --git a/third_party/xla/xla/tools/run_hlo_module_bin_test.cc b/third_party/xla/xla/tools/run_hlo_module_bin_test.cc
index fe82122..9b6138f 100644
--- a/third_party/xla/xla/tools/run_hlo_module_bin_test.cc
+++ b/third_party/xla/xla/tools/run_hlo_module_bin_test.cc
@@ -14,10 +14,12 @@
 ==============================================================================*/
 
 #include <memory>
+#include <optional>
 #include <string>
 #include <vector>
 
 #include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
@@ -32,22 +34,41 @@
 namespace xla {
 namespace {
 
+std::vector<std::string> make_args(
+    const std::string& run_hlo_module_bin, const std::string& file_name,
+    const std::vector<std::string>& extra_args = {},
+    std::optional<std::string> input_literals_file = std::nullopt) {
+  std::string hlo_path = file_name[0] == '/'
+                             ? file_name
+                             : tsl::io::JoinPath(tsl::testing::XlaSrcRoot(),
+                                                 "tools", "data", file_name);
+
+  std::vector<std::string> args = {run_hlo_module_bin, hlo_path,
+                                   "--platform=Host"};
+
+  args.insert(args.end(), extra_args.begin(), extra_args.end());
+
+  if (input_literals_file.has_value()) {
+    std::string input_path =
+        tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "tools", "data",
+                          input_literals_file.value());
+    args.push_back(absl::StrCat("--input_literals_file=", input_path));
+  }
+
+  return args;
+}
+
 class RunHloModuleTest : public ::testing::Test {
  protected:
   void RunHlo(const std::string& file_name,
-              const std::vector<std::string>& extra_args = {}) {
+              const std::vector<std::string>& extra_args = {},
+              std::optional<std::string> input_literals_file = std::nullopt) {
     std::string run_hlo_module_bin = tsl::io::JoinPath(
         tsl::testing::XlaSrcRoot(), "tools", "run_hlo_module");
 
-    std::string hlo_path = file_name[0] == '/'
-                               ? file_name
-                               : tsl::io::JoinPath(tsl::testing::XlaSrcRoot(),
-                                                   "tools", "data", file_name);
-
     tsl::SubProcess proc;
-    std::vector<std::string> args = {run_hlo_module_bin, hlo_path,
-                                     "--platform=Host"};
-    args.insert(args.end(), extra_args.begin(), extra_args.end());
+    auto args = make_args(run_hlo_module_bin, file_name, extra_args,
+                          input_literals_file);
     proc.SetProgram(run_hlo_module_bin, args);
     proc.SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE);
     proc.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE);
@@ -131,6 +152,22 @@
               testing::Not(testing::HasSubstr("memory allocation bug")));
 }
 
+TEST_F(RunHloModuleTest, ReadInputLiteralsFromFile) {
+  RunHlo("add.hlo",
+         /*extra_args=*/{"--print_literals=true", "--reference_platform="},
+         /*input_literals_file=*/"input_literal_f32_2_2.pbtxt");
+
+  EXPECT_TRUE(exited_normally_);
+  EXPECT_EQ(exit_status_, 0);
+
+  ASSERT_THAT(
+      stdout_output_,
+      testing::HasSubstr("{ 0.1, 0.2 },"));  // First two values of the input
+  ASSERT_THAT(
+      stdout_output_,
+      testing::HasSubstr("{ 0.2, 0.4 },"));  // First two values of the result
+}
+
 TEST_F(RunHloModuleTest, AddSnapshot) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnUnverifiedModule(R"(
diff --git a/third_party/xla/xla/tools/run_hlo_module_main.cc b/third_party/xla/xla/tools/run_hlo_module_main.cc
index 0b3ff5d..92e2e23 100644
--- a/third_party/xla/xla/tools/run_hlo_module_main.cc
+++ b/third_party/xla/xla/tools/run_hlo_module_main.cc
@@ -122,7 +122,13 @@
           "other "
           "than the reference this is necessary because some HLO passes are "
           "legalization passes which must be run prior to code generation."),
-
+      tsl::Flag(
+          "force_use_cpu_thunk_runtime_for_test",
+          &opts.force_use_cpu_thunk_runtime_for_test,
+          "Use thunk runtime for the test platform. If true, thunks runtime "
+          "will be used for the test run regardless of the "
+          "xla_cpu_use_thunk_runtime flag in XLA_FLAGS. This option doesn't "
+          "impact reference run. It is ignored for platforms other than CPU."),
       tsl::Flag("random_init_input_literals", &opts.random_init_input_literals,
                 "Initialize input literals with random numbers."
                 "Leave them uninitialized otherwise."),
@@ -252,9 +258,9 @@
                                 &input_literals_proto);
     }
 
-    for (int i = 1; i <= iteration_count; ++i) {
+    for (int i = 0; i < iteration_count; ++i) {
       if (iteration_count != 1) {
-        std::cerr << "\n=== Iteration " << i << "\n";
+        std::cerr << "\n=== Iteration " << i + 1 << "\n";
       }
       xla::RunHloModuleIterationLiterals* iteration_literals_proto = nullptr;
       if (!opts.output_literals_file.empty() ||
@@ -276,7 +282,7 @@
           opts, iteration_literals_proto,
           /*reference_module_modifier_hook=*/{},
           [&](xla::HloModuleConfig* config) {
-            config->set_seed(different_random_seeds ? i : 42);
+            config->set_seed(different_random_seeds ? i + 1 : 42);
           });
 
       if (result.ok()) {
diff --git a/third_party/xla/xla/tools/xla_compile_lib.cc b/third_party/xla/xla/tools/xla_compile_lib.cc
index fe4289b..16d4d0d 100644
--- a/third_party/xla/xla/tools/xla_compile_lib.cc
+++ b/third_party/xla/xla/tools/xla_compile_lib.cc
@@ -68,7 +68,7 @@
 #include "tsl/platform/statusor.h"
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
 #include "xla/service/gpu/executable.pb.h"
 #include "xla/service/gpu/gpu_symbol_repository.h"
 #include "xla/stream_executor/gpu/gpu_init.h"
diff --git a/third_party/xla/xla/tools/xla_compile_lib_test.cc b/third_party/xla/xla/tools/xla_compile_lib_test.cc
deleted file mode 100644
index 101282c..0000000
--- a/third_party/xla/xla/tools/xla_compile_lib_test.cc
+++ /dev/null
@@ -1,326 +0,0 @@
-/* Copyright 2023 The OpenXLA Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "xla/tools/xla_compile_lib.h"
-
-#include <memory>
-#include <optional>
-#include <string>
-#include <utility>
-
-#include "google/protobuf/duration.pb.h"
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "xla/hlo/ir/hlo_module.h"
-#include "xla/service/gpu/gpu_symbol_repository.h"
-#include "xla/service/platform_util.h"
-#include "xla/service/symbol_repository.h"
-#include "xla/service/xla_compile_result.pb.h"
-#include "xla/stream_executor/device_description.pb.h"
-#include "xla/tests/hlo_test_base.h"
-#include "xla/tests/test_macros.h"
-#include "xla/tsl/lib/core/status_test_util.h"
-#include "xla/util.h"
-#include "tsl/platform/env.h"
-#include "tsl/platform/env_time.h"
-#include "tsl/platform/path.h"
-#include "tsl/platform/status_matchers.h"
-#include "tsl/platform/statusor.h"
-#include "tsl/platform/test.h"
-#include "tsl/protobuf/error_codes.pb.h"
-#include "tsl/protobuf/status.pb.h"
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#include "xla/service/gpu/autotuner_util.h"
-#endif
-
-namespace xla {
-namespace {
-
-using ::testing::IsEmpty;
-using ::testing::IsNull;
-using ::testing::Not;
-using ::tsl::testing::IsOk;
-using ::tsl::testing::IsOkAndHolds;
-using ::tsl::testing::StatusIs;
-
-#if XLA_TEST_BACKEND_CPU
-static constexpr absl::string_view kPlatformName = "Host";
-#elif XLA_TEST_BACKEND_GPU
-static constexpr absl::string_view kPlatformName =
-#if TENSORFLOW_USE_ROCM
-    "ROCM";
-#else
-    "CUDA";
-#endif
-#endif  // XLA_TEST_BACKEND_CPU
-
-class XlaCompileLibTest : public HloTestBase {
- protected:
-  XlaCompileLibTest()
-      : HloTestBase(*PlatformUtil::GetPlatform(std::string(kPlatformName)),
-                    GetReferencePlatform()) {}
-  void SetUp() override {
-    const std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(),
-                                                   "tools", "data", "add.hlo");
-    std::string hlo;
-    TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), hlo_path, &hlo));
-    TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo));
-  }
-
-  std::unique_ptr<HloModule> module_;
-};
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(CompilesForCpu)) {
-  CompilationResult result;
-  EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kCpu,
-                                std::nullopt, result),
-              IsOkAndHolds(Not(IsEmpty())));
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithDevice)) {
-  CompilationResult result;
-  EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu,
-                                std::nullopt, result),
-              IsOkAndHolds(Not(IsEmpty())));
-  EXPECT_TRUE(result.has_hlo_module()) << result.DebugString();
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithoutDevice)) {
-  const std::string target_config_path =
-      tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service",
-                        "xla_aot_compile_test_gpu_target_config.prototxt");
-  stream_executor::GpuTargetConfigProto target_config;
-  TF_ASSERT_OK(tsl::ReadTextProto(tsl::Env::Default(), target_config_path,
-                                  &target_config));
-  CompilationResult result;
-  EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu,
-                                std::nullopt, result),
-              IsOkAndHolds(Not(IsEmpty())));
-  EXPECT_TRUE(result.has_hlo_module()) << result.DebugString();
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(ErrorsOnUnexpectedPlatform)) {
-  XlaCompileOptions options;
-  options.platform = "tpu";
-  EXPECT_THAT(XlaCompileMain(options), StatusIs(tsl::error::UNIMPLEMENTED));
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(WriteResultFilePropagatesErrors)) {
-  TimerStats stats;
-  CompilationResult result;
-  EXPECT_THAT(WriteResultFile("/does/not/exist", stats, result), Not(IsOk()));
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(WriteResultFileWritesTheFile)) {
-  std::string result_output_file;
-  ASSERT_TRUE(tsl::Env::Default()->LocalTempFilename(&result_output_file));
-
-  TimerStats stats;
-  {
-    absl::MutexLock ml(&stats.stats_mutex);
-    stats.cumulative_secs = 5.5;
-    stats.max_secs = 5.5;
-  }
-
-  CompilationResult result;
-  google::protobuf::Duration duration;
-  duration.set_seconds(5);
-  duration.set_nanos(0.5 * tsl::EnvTime::kSecondsToNanos);
-  *result.mutable_perf_stats()->mutable_compilation_duration() = duration;
-  *result.mutable_perf_stats()->mutable_total_duration() = duration;
-
-  TF_ASSERT_OK(WriteResultFile(result_output_file, stats, result));
-
-  CompilationResult got_result;
-  TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_output_file,
-                                    &got_result));
-  // Sadly EqualsProto isn't OSS, so we inspect a few fields manually.
-  // See googletest#1761 and b/229726259.
-  EXPECT_EQ(5, got_result.perf_stats().compilation_duration().seconds());
-  EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos,
-            got_result.perf_stats().compilation_duration().nanos());
-  EXPECT_EQ(5, got_result.perf_stats().total_duration().seconds());
-  EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos,
-            got_result.perf_stats().total_duration().nanos());
-}
-
-TEST_F(XlaCompileLibTest, LoadModuleErrors) {
-  EXPECT_THAT(LoadModule("/does/not/exist"), Not(IsOk()));
-}
-
-TEST_F(XlaCompileLibTest, LoadModuleLoadsTextFormat) {
-  const std::string module_file =
-      tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt");
-  TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file,
-                                      module_->ToString()));
-
-  EXPECT_THAT(LoadModule(module_file), IsOkAndHolds(Not(IsNull())));
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(MainForCpu)) {
-  const std::string module_file =
-      tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt");
-  TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file,
-                                      module_->ToString()));
-
-  const std::string output_path =
-      tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_output");
-  const std::string result_file =
-      tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_result.pb");
-
-  XlaCompileOptions options;
-  options.module_path = module_file;
-  options.output_path = output_path;
-  options.platform = "cpu";
-  options.result_output_file = result_file;
-  TF_EXPECT_OK(XlaCompileMain(options));
-
-  CompilationResult result;
-  TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result));
-  EXPECT_TRUE(result.has_status());
-  EXPECT_EQ(result.status().code(), tensorflow::error::OK);
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(MainForGpu)) {
-  const std::string module_file =
-      tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt");
-  TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file,
-                                      module_->ToString()));
-
-  const std::string output_path =
-      tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_output");
-  const std::string result_file =
-      tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_result.pb");
-
-  XlaCompileOptions options;
-  options.module_path = module_file;
-  options.output_path = output_path;
-  options.platform = "gpu";
-  options.result_output_file = result_file;
-  options.gpu_options.use_attached_device = true;
-  TF_EXPECT_OK(XlaCompileMain(options));
-
-  CompilationResult result;
-  TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result));
-  EXPECT_TRUE(result.has_status());
-  EXPECT_EQ(result.status().code(), tensorflow::error::OK);
-}
-
-TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(LoadAutotuneDataCpu)) {
-  HloModuleAndMetadata mod;
-  mod.hlo_module = std::move(module_);
-
-  EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kCpu),
-              IsOkAndHolds(false));
-}
-
-TEST_F(XlaCompileLibTest,
-       DISABLED_ON_CPU(LoadAutotuneDataGpuDataPresentAndAutotuningEnabled)) {
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-  gpu::AutotunerUtil::ClearAutotuneResults();
-
-  HloModuleAndMetadata mod;
-  mod.hlo_module = std::move(module_);
-  auto data = std::make_unique<gpu::GpuBackendSpecificData>();
-
-  AutotuneResults autotune_results;
-  TF_ASSERT_OK(tsl::ReadTextProto(
-      tsl::Env::Default(),
-      tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
-                        "gpu_compiler_test_autotune_db.textproto"),
-      &autotune_results));
-  data->autotune_results = autotune_results;
-  mod.backend_specific_data = std::move(data);
-
-  DebugOptions opts = mod.hlo_module->config().debug_options();
-  opts.set_xla_gpu_autotune_level(3);
-  mod.hlo_module->mutable_config().set_debug_options(opts);
-
-  EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
-              IsOkAndHolds(true));
-  EXPECT_FALSE(gpu::AutotunerUtil::ResultCacheIsEmpty());
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-}
-
-TEST_F(XlaCompileLibTest,
-       DISABLED_ON_CPU(LoadAutotuneDataGpuDataPresentAndAutotuningDisabled)) {
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-  gpu::AutotunerUtil::ClearAutotuneResults();
-
-  HloModuleAndMetadata mod;
-  mod.hlo_module = std::move(module_);
-  auto data = std::make_unique<gpu::GpuBackendSpecificData>();
-
-  AutotuneResults autotune_results;
-  TF_ASSERT_OK(tsl::ReadTextProto(
-      tsl::Env::Default(),
-      tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
-                        "gpu_compiler_test_autotune_db.textproto"),
-      &autotune_results));
-  data->autotune_results = autotune_results;
-  mod.backend_specific_data = std::move(data);
-
-  DebugOptions opts = mod.hlo_module->config().debug_options();
-  opts.set_xla_gpu_autotune_level(0);
-  mod.hlo_module->mutable_config().set_debug_options(opts);
-
-  EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
-              IsOkAndHolds(false));
-  EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty());
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-}
-
-TEST_F(XlaCompileLibTest,
-       DISABLED_ON_CPU(LoadAutotuneDataGpuDataNotPresentAndAutotuningEnabled)) {
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-  gpu::AutotunerUtil::ClearAutotuneResults();
-
-  HloModuleAndMetadata mod;
-  mod.hlo_module = std::move(module_);
-
-  DebugOptions opts = mod.hlo_module->config().debug_options();
-  opts.set_xla_gpu_autotune_level(3);
-  mod.hlo_module->mutable_config().set_debug_options(opts);
-
-  EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
-              IsOkAndHolds(false));
-  EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty());
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-}
-
-TEST_F(
-    XlaCompileLibTest,
-    DISABLED_ON_CPU(LoadAutotuneDataGpuDataNotPresentAndAutotuningDisabled)) {
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-  gpu::AutotunerUtil::ClearAutotuneResults();
-
-  HloModuleAndMetadata mod;
-  mod.hlo_module = std::move(module_);
-
-  DebugOptions opts = mod.hlo_module->config().debug_options();
-  opts.set_xla_gpu_autotune_level(0);
-  mod.hlo_module->mutable_config().set_debug_options(opts);
-
-  EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
-              IsOkAndHolds(false));
-  EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty());
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-}
-
-}  // namespace
-}  // namespace xla
diff --git a/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc b/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc
new file mode 100644
index 0000000..62c0673
--- /dev/null
+++ b/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc
@@ -0,0 +1,167 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "google/protobuf/duration.pb.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/synchronization/mutex.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/platform_util.h"
+#include "xla/service/symbol_repository.h"
+#include "xla/service/xla_compile_result.pb.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tools/xla_compile_lib.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/util.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/env_time.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/path.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+#include "tsl/protobuf/error_codes.pb.h"
+#include "tsl/protobuf/status.pb.h"
+
+namespace xla {
+namespace {
+
+using ::testing::IsEmpty;
+using ::testing::IsNull;
+using ::testing::Not;
+using ::tsl::testing::IsOk;
+using ::tsl::testing::IsOkAndHolds;
+using ::tsl::testing::StatusIs;
+
+class XlaCompileLibTest : public HloTestBase {
+ protected:
+  XlaCompileLibTest()
+      : HloTestBase(*PlatformUtil::GetPlatform("Host"),
+                    GetReferencePlatform()) {}
+  void SetUp() override {
+    const std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(),
+                                                   "tools", "data", "add.hlo");
+    std::string hlo;
+    TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), hlo_path, &hlo));
+    TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo));
+  }
+
+  std::unique_ptr<HloModule> module_;
+};
+
+TEST_F(XlaCompileLibTest, CompilesForCpu) {
+  CompilationResult result;
+  EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kCpu,
+                                std::nullopt, result),
+              IsOkAndHolds(Not(IsEmpty())));
+}
+
+TEST_F(XlaCompileLibTest, ErrorsOnUnexpectedPlatform) {
+  XlaCompileOptions options;
+  options.platform = "tpu";
+  EXPECT_THAT(XlaCompileMain(options), StatusIs(tsl::error::UNIMPLEMENTED));
+}
+
+TEST_F(XlaCompileLibTest, WriteResultFilePropagatesErrors) {
+  TimerStats stats;
+  CompilationResult result;
+  EXPECT_THAT(WriteResultFile("/does/not/exist", stats, result), Not(IsOk()));
+}
+
+TEST_F(XlaCompileLibTest, WriteResultFileWritesTheFile) {
+  std::string result_output_file;
+  ASSERT_TRUE(tsl::Env::Default()->LocalTempFilename(&result_output_file));
+
+  TimerStats stats;
+  {
+    absl::MutexLock ml(&stats.stats_mutex);
+    stats.cumulative_secs = 5.5;
+    stats.max_secs = 5.5;
+  }
+
+  CompilationResult result;
+  google::protobuf::Duration duration;
+  duration.set_seconds(5);
+  duration.set_nanos(0.5 * tsl::EnvTime::kSecondsToNanos);
+  *result.mutable_perf_stats()->mutable_compilation_duration() = duration;
+  *result.mutable_perf_stats()->mutable_total_duration() = duration;
+
+  TF_ASSERT_OK(WriteResultFile(result_output_file, stats, result));
+
+  CompilationResult got_result;
+  TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_output_file,
+                                    &got_result));
+  // Sadly EqualsProto isn't OSS, so we inspect a few fields manually.
+  // See googletest#1761 and b/229726259.
+  EXPECT_EQ(5, got_result.perf_stats().compilation_duration().seconds());
+  EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos,
+            got_result.perf_stats().compilation_duration().nanos());
+  EXPECT_EQ(5, got_result.perf_stats().total_duration().seconds());
+  EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos,
+            got_result.perf_stats().total_duration().nanos());
+}
+
+TEST_F(XlaCompileLibTest, LoadModuleErrors) {
+  EXPECT_THAT(LoadModule("/does/not/exist"), Not(IsOk()));
+}
+
+TEST_F(XlaCompileLibTest, LoadModuleLoadsTextFormat) {
+  const std::string module_file =
+      tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt");
+  TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file,
+                                      module_->ToString()));
+
+  EXPECT_THAT(LoadModule(module_file), IsOkAndHolds(Not(IsNull())));
+}
+
+TEST_F(XlaCompileLibTest, MainForCpu) {
+  const std::string module_file =
+      tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt");
+  TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file,
+                                      module_->ToString()));
+
+  const std::string output_path =
+      tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_output");
+  const std::string result_file =
+      tsl::io::JoinPath(tsl::testing::TmpDir(), "cpu_result.pb");
+
+  XlaCompileOptions options;
+  options.module_path = module_file;
+  options.output_path = output_path;
+  options.platform = "cpu";
+  options.result_output_file = result_file;
+  TF_EXPECT_OK(XlaCompileMain(options));
+
+  CompilationResult result;
+  TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result));
+  EXPECT_TRUE(result.has_status());
+  EXPECT_EQ(result.status().code(), tensorflow::error::OK);
+}
+
+TEST_F(XlaCompileLibTest, LoadAutotuneDataCpu) {
+  HloModuleAndMetadata mod;
+  mod.hlo_module = std::move(module_);
+
+  EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kCpu),
+              IsOkAndHolds(false));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc
new file mode 100644
index 0000000..bc34c87
--- /dev/null
+++ b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc
@@ -0,0 +1,195 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/autotuning/autotuner_util.h"
+#include "xla/service/gpu/gpu_symbol_repository.h"
+#include "xla/service/platform_util.h"
+#include "xla/service/symbol_repository.h"
+#include "xla/service/xla_compile_result.pb.h"
+#include "xla/stream_executor/device_description.pb.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tools/xla_compile_lib.h"
+#include "xla/tsl/lib/core/status_test_util.h"
+#include "xla/util.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/path.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+#include "tsl/protobuf/error_codes.pb.h"
+#include "tsl/protobuf/status.pb.h"
+
+namespace xla {
+namespace {
+
+using ::testing::IsEmpty;
+using ::testing::Not;
+using ::tsl::testing::IsOkAndHolds;
+
+class XlaCompileLibTest : public HloTestBase {
+ protected:
+  XlaCompileLibTest()
+      : HloTestBase(*PlatformUtil::GetPlatform(std::string("GPU")),
+                    GetReferencePlatform()) {}
+  void SetUp() override {
+    const std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(),
+                                                   "tools", "data", "add.hlo");
+    std::string hlo;
+    TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), hlo_path, &hlo));
+    TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo));
+  }
+
+  std::unique_ptr<HloModule> module_;
+};
+
+TEST_F(XlaCompileLibTest, CompilesForGpuWithDevice) {
+  CompilationResult result;
+  EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu,
+                                std::nullopt, result),
+              IsOkAndHolds(Not(IsEmpty())));
+  EXPECT_TRUE(result.has_hlo_module()) << result.DebugString();
+}
+
+TEST_F(XlaCompileLibTest, CompilesForGpuWithoutDevice) {
+  const std::string target_config_path =
+      tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service",
+                        "xla_aot_compile_test_gpu_target_config.prototxt");
+  stream_executor::GpuTargetConfigProto target_config;
+  TF_ASSERT_OK(tsl::ReadTextProto(tsl::Env::Default(), target_config_path,
+                                  &target_config));
+  CompilationResult result;
+  EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu,
+                                std::nullopt, result),
+              IsOkAndHolds(Not(IsEmpty())));
+  EXPECT_TRUE(result.has_hlo_module()) << result.DebugString();
+}
+
+TEST_F(XlaCompileLibTest, MainForGpu) {
+  const std::string module_file =
+      tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt");
+  TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file,
+                                      module_->ToString()));
+
+  const std::string output_path =
+      tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_output");
+  const std::string result_file =
+      tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_result.pb");
+
+  XlaCompileOptions options;
+  options.module_path = module_file;
+  options.output_path = output_path;
+  options.platform = "gpu";
+  options.result_output_file = result_file;
+  options.gpu_options.use_attached_device = true;
+  TF_EXPECT_OK(XlaCompileMain(options));
+
+  CompilationResult result;
+  TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result));
+  EXPECT_TRUE(result.has_status());
+  EXPECT_EQ(result.status().code(), tensorflow::error::OK);
+}
+
+TEST_F(XlaCompileLibTest, LoadAutotuneDataGpuDataPresentAndAutotuningEnabled) {
+  gpu::AutotunerUtil::ClearAutotuneResults();
+
+  HloModuleAndMetadata mod;
+  mod.hlo_module = std::move(module_);
+  auto data = std::make_unique<gpu::GpuBackendSpecificData>();
+
+  AutotuneResults autotune_results;
+  TF_ASSERT_OK(tsl::ReadTextProto(
+      tsl::Env::Default(),
+      tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
+                        "gpu_compiler_test_autotune_db.textproto"),
+      &autotune_results));
+  data->autotune_results = autotune_results;
+  mod.backend_specific_data = std::move(data);
+
+  DebugOptions opts = mod.hlo_module->config().debug_options();
+  opts.set_xla_gpu_autotune_level(3);
+  mod.hlo_module->mutable_config().set_debug_options(opts);
+
+  EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
+              IsOkAndHolds(true));
+  EXPECT_FALSE(gpu::AutotunerUtil::ResultCacheIsEmpty());
+}
+
+TEST_F(XlaCompileLibTest, LoadAutotuneDataGpuDataPresentAndAutotuningDisabled) {
+  gpu::AutotunerUtil::ClearAutotuneResults();
+
+  HloModuleAndMetadata mod;
+  mod.hlo_module = std::move(module_);
+  auto data = std::make_unique<gpu::GpuBackendSpecificData>();
+
+  AutotuneResults autotune_results;
+  TF_ASSERT_OK(tsl::ReadTextProto(
+      tsl::Env::Default(),
+      tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
+                        "gpu_compiler_test_autotune_db.textproto"),
+      &autotune_results));
+  data->autotune_results = autotune_results;
+  mod.backend_specific_data = std::move(data);
+
+  DebugOptions opts = mod.hlo_module->config().debug_options();
+  opts.set_xla_gpu_autotune_level(0);
+  mod.hlo_module->mutable_config().set_debug_options(opts);
+
+  EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
+              IsOkAndHolds(false));
+  EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty());
+}
+
+TEST_F(XlaCompileLibTest,
+       LoadAutotuneDataGpuDataNotPresentAndAutotuningEnabled) {
+  gpu::AutotunerUtil::ClearAutotuneResults();
+
+  HloModuleAndMetadata mod;
+  mod.hlo_module = std::move(module_);
+
+  DebugOptions opts = mod.hlo_module->config().debug_options();
+  opts.set_xla_gpu_autotune_level(3);
+  mod.hlo_module->mutable_config().set_debug_options(opts);
+
+  EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
+              IsOkAndHolds(false));
+  EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty());
+}
+
+TEST_F(XlaCompileLibTest,
+       LoadAutotuneDataGpuDataNotPresentAndAutotuningDisabled) {
+  gpu::AutotunerUtil::ClearAutotuneResults();
+
+  HloModuleAndMetadata mod;
+  mod.hlo_module = std::move(module_);
+
+  DebugOptions opts = mod.hlo_module->config().debug_options();
+  opts.set_xla_gpu_autotune_level(0);
+  mod.hlo_module->mutable_config().set_debug_options(opts);
+
+  EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu),
+              IsOkAndHolds(false));
+  EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty());
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD
index 2d1ae78..179ae72 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD
@@ -21,12 +21,34 @@
         "//xla:xla_data_proto_cc",
         "//xla/mlir_hlo",
         "//xla/service:hlo_proto_cc",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
     ],
 )
 
 cc_library(
+    name = "async_importer",
+    srcs = ["async_importer.cc"],
+    hdrs = ["async_importer.h"],
+    deps = [
+        ":attribute_importer",
+        ":hlo_utils",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/mlir_hlo",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:errors",
+    ],
+)
+
+cc_library(
     name = "custom_call_importer",
     srcs = ["custom_call_importer.cc"],
     hdrs = ["custom_call_importer.h"],
@@ -67,6 +89,7 @@
         "hlo_module_importer.h",
     ],
     deps = [
+        ":async_importer",
         ":attribute_importer",
         ":custom_call_importer",
         ":hlo_utils",
@@ -84,7 +107,6 @@
         "//xla/hlo/ir:hlo",
         "//xla/mlir_hlo",
         "//xla/service:hlo_proto_cc",
-        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
@@ -113,7 +135,11 @@
         ":hlo_module_importer",
         "//xla:status_macros",
         "//xla/mlir/utils:error_util",
+        "//xla/service/llvm_ir:llvm_util",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:errors",
     ],
 )
 
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc
new file mode 100644
index 0000000..57bc78a
--- /dev/null
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc
@@ -0,0 +1,383 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/translate/hlo_to_mhlo/async_importer.h"
+
+#include <cassert>
+#include <functional>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/translate/hlo_to_mhlo/attribute_importer.h"
+#include "xla/translate/hlo_to_mhlo/hlo_utils.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+
+namespace xla {
+
+namespace {
+
+constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
+constexpr char kShardingAttr[] = "mhlo.sharding";
+
+// ============
+// Imports an old-style async start op. For example, an HLO all-gather-start
+// instruction is imported as an async-start associated with an all-gather
+// computation.
+//
+// Eventually, old-style async ops (e.g. all-gather-start) and new-style async
+// ops (i.e. async-start, async-update and async-done) will converge on the
+// HLO side, so we decided not to introduce new MHLO ops for all-gather-start
+// and friends.
+//
+// In the end, new ops may still be added in the old style because they are not
+// compatible with the new-style async semantics, but those should be handled
+// on their own rather than by this function, which "upgrades" ops to the
+// new-style async API.
+// ============
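+//
+// For illustration only (types and most attributes elided), an HLO
+//   %ags = all-gather-start(%x)
+// is imported roughly as a private wrapper function around the sync op plus a
+// new-style async start:
+//
+//   func.func private @all_gather_(%arg0) attributes {execution_thread = "main"} {
+//     %0 = "mhlo.all_gather"(%arg0) {...}
+//     return %0
+//   }
+//   %bundle = "mhlo.async_start"(%x)
+//       {called_computation = @all_gather_, execution_thread = "main"}
+//
+// and the matching all-gather-done becomes an "mhlo.async_done" on the bundle.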
+template <typename sync_op>
+absl::StatusOr<mlir::Operation*> ImportOldStyleAsyncStart(
+    mlir::SymbolTable& symbol_table,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    const llvm::SmallVectorImpl<mlir::Value>& operands, mlir::Location loc,
+    mlir::Type result_type, mlir::OpBuilder* builder, std::string func_name,
+    std::function<absl::Status(sync_op)> mutate_op) {
+  auto context = builder->getContext();
+  if (!llvm::isa<mlir::TupleType>(result_type)) {
+    return tsl::errors::InvalidArgument(
+        "expected async_bundle tuple result type");
+  }
+  auto result_types = result_type.cast<mlir::TupleType>().getTypes();
+  if (result_types.size() < 2) {
+    return tsl::errors::InvalidArgument(
+        "async_bundle must contain at least two values");
+  }
+  auto func_type = mlir::FunctionType::get(context, Untuple(result_types[0]),
+                                           Untuple(result_types[1]));
+  auto function = mlir::func::FuncOp::create(loc, func_name, func_type);
+
+  // The new function doesn't need to be inserted at the beginning, but doing
+  // so makes testing easier and preserves the original behavior.
+  mlir::Block& block = symbol_table.getOp()->getRegion(0).front();
+  symbol_table.insert(function, mlir::Block::iterator(block.begin()));
+
+  function.setPrivate();
+  auto async_builder = mlir::OpBuilder(function.getBody());
+
+  llvm::SmallVector<mlir::NamedAttribute> async_attributes;
+  async_attributes.push_back(builder->getNamedAttr(
+      "called_computation",
+      mlir::FlatSymbolRefAttr::get(builder->getContext(), function.getName())));
+  async_attributes.push_back(builder->getNamedAttr(
+      "execution_thread", builder->getStringAttr("main")));
+
+  // Attach the frontend_attributes and sharding attributes to the async op
+  // instead of the sync op. First, semantically sharding attributes cannot be
+  // attached to the sync op since the sync op may not produce the same number
+  // of results as the sharding's tuple element count, e.g., `mhlo.send` vs. HLO
+  // `send`. Second, `mlir_hlo_to_hlo.cc` imports these attributes from the
+  // `mhlo.async_start` ops, so attaching them to the sync op will make them
+  // disappear during MHLO to HLO lowering.
+  for (auto it = attributes.begin(); it != attributes.end();) {
+    if (it->getName() == kShardingAttr ||
+        it->getName() == kFrontendAttributesAttr) {
+      async_attributes.push_back(*it);
+      it = attributes.erase(it);
+    } else {
+      ++it;
+    }
+  }
+
+  llvm::SmallVector<mlir::Location, 1> locs(Untuple(result_types[0]).size(),
+                                            loc);
+  auto sync_operand =
+      async_builder
+          .createBlock(&function.getBody(), {}, Untuple(result_types[0]), locs)
+          ->getArguments();
+  auto sync_operation = async_builder.create<sync_op>(
+      loc, Untuple(result_types[1]), sync_operand, attributes);
+  async_builder.create<mlir::func::ReturnOp>(loc, sync_operation->getResults());
+  TF_RETURN_IF_ERROR(mutate_op(sync_operation));
+
+  function->setAttr("execution_thread", builder->getStringAttr("main"));
+
+  auto bundle_result_type =
+      mlir::mhlo::AsyncBundleType::get(context, result_types);
+  return builder
+      ->create<mlir::mhlo::AsyncStartOp>(loc, bundle_result_type, operands,
+                                         async_attributes)
+      .getOperation();
+}
+
+absl::StatusOr<mlir::Operation*> ImportOldStyleAsyncDone(
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    const llvm::SmallVectorImpl<mlir::Value>& operands, mlir::Location loc,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    bool useBundleResult = false) {
+  assert(operands.size() == 1 &&
+         "*-done ops must take only a single async_bundle operand");
+  auto async_start = operands[0].getDefiningOp<mlir::mhlo::AsyncStartOp>();
+  if (!async_start) return InvalidArgument("*-done requires *-start as input");
+  attributes.push_back(builder->getNamedAttr(
+      "called_computation",
+      mlir::FlatSymbolRefAttr::get(builder->getContext(),
+                                   async_start.getCalledComputation())));
+  attributes.push_back(builder->getNamedAttr("execution_thread",
+                                             builder->getStringAttr("main")));
+
+  auto async_bundle = llvm::cast<mlir::mhlo::AsyncBundleType>(
+      async_start.getResult().getType());
+
+  auto start_tuple =
+      llvm::dyn_cast<mlir::TupleType>(async_bundle.getTypes()[1]);
+  if (start_tuple && llvm::isa<mlir::TupleType>(start_tuple.getType(0))) {
+    auto op = builder->create<mlir::mhlo::AsyncDoneOp>(loc, result_type,
+                                                       operands, attributes);
+    return {op};
+  } else {
+    if (useBundleResult) result_type = async_bundle.getTypes()[1];
+    auto op = builder->create<mlir::mhlo::AsyncDoneOp>(
+        loc, Untuple(result_type), operands, attributes);
+    return CreateTupleFromOpResults(builder, loc, op.getOperation(),
+                                    result_type);
+  }
+}
+
+}  // namespace
+
+// Op Converters
+
+absl::StatusOr<mlir::Operation*> ImportSend(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    mlir::SymbolTable& symbol_table) {
+  auto send_op = Cast<HloSendInstruction>(instruction);
+  attributes.push_back(builder->getNamedAttr(
+      "is_host_transfer", builder->getBoolAttr(send_op->is_host_transfer())));
+  if (send_op->channel_id().has_value()) {
+    ChannelHandle channel_handle;
+    channel_handle.set_handle(send_op->channel_id().value());
+    channel_handle.set_type(send_op->is_host_transfer()
+                                ? ChannelHandle::DEVICE_TO_HOST
+                                : ChannelHandle::DEVICE_TO_DEVICE);
+    attributes.push_back(ConvertChannelHandle(channel_handle, builder));
+  }
+
+  // Return async_start/done for pipelined send.
+  //
+  // Old-style `send` returns a bundle of (arg, sync flag, token) to be passed
+  // along to `send-done`. However, the new-style async ops have a shared
+  // bundle format of (args, results, scratchpad), so to rewrite the `send` and
+  // `send-done` ops to use the new-style async API, we need to reorder the
+  // arguments to be in (args, token, sync flag) order.
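+  // For instance (hypothetical shapes), a send whose HLO result type is
+  // (f32[8], u32[], token[]) is rebundled below as
+  // ((f32[8], token[]), token[], u32[]).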
+  auto result_types = result_type.cast<mlir::TupleType>().getTypes();
+  if (result_types.size() != 3)
+    return InvalidArgument("send should return a 3-tuple");
+  auto async_arg_type = mlir::TupleType::get(
+      builder->getContext(), {result_types[0], result_types[2]});
+  auto async_bundled_tuple =
+      mlir::TupleType::get(builder->getContext(),
+                           {async_arg_type, result_types[2], result_types[1]});
+  return ImportOldStyleAsyncStart<mlir::mhlo::SendOp>(
+      symbol_table, attributes, operands, loc, async_bundled_tuple, builder,
+      "send_", [](auto) { return absl::OkStatus(); });
+}
+
+absl::StatusOr<mlir::Operation*> ImportRecv(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    mlir::SymbolTable& symbol_table) {
+  auto recv_op = Cast<HloRecvInstruction>(instruction);
+  attributes.push_back(builder->getNamedAttr(
+      "is_host_transfer", builder->getBoolAttr(recv_op->is_host_transfer())));
+  if (recv_op->channel_id().has_value()) {
+    ChannelHandle channel_handle;
+    channel_handle.set_handle(recv_op->channel_id().value());
+    channel_handle.set_type(recv_op->is_host_transfer()
+                                ? ChannelHandle::HOST_TO_DEVICE
+                                : ChannelHandle::DEVICE_TO_DEVICE);
+    attributes.push_back(ConvertChannelHandle(channel_handle, builder));
+  }
+
+  // Old-style `recv` returns a bundle of (result, sync flag, token) to be
+  // passed along to `recv-done`. However, the new-style async ops have a
+  // shared bundle format of (args, results, scratchpad), so to rewrite the
+  // `recv` and `recv-done` ops to use the new-style async API, we need to
+  // reorder the arguments to be in (token, (result, token), sync flag) order,
+  // or (token, token, sync flag) if no result is received.
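+  // For instance (hypothetical shapes), a recv whose HLO result type is
+  // (f32[8], u32[], token[]) is rebundled below as
+  // (token[], (f32[8], token[]), u32[]).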
+  auto result_types = result_type.cast<mlir::TupleType>().getTypes();
+  if (result_types.size() != 3)
+    return InvalidArgument("recv should return a 3-tuple");
+
+  // TODO: Allow recv of no values, only token.
+  auto async_result_type = mlir::TupleType::get(
+      builder->getContext(), {result_types[0], result_types[2]});
+  auto async_bundled_tuple = mlir::TupleType::get(
+      builder->getContext(),
+      {result_types[2], async_result_type, result_types[1]});
+  return ImportOldStyleAsyncStart<mlir::mhlo::RecvOp>(
+      symbol_table, attributes, operands, loc, async_bundled_tuple, builder,
+      "recv_", [](auto) { return absl::OkStatus(); });
+}
+
+// Async Collectives
+
+absl::StatusOr<mlir::Operation*> ImportAllGatherStart(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    mlir::SymbolTable& symbol_table) {
+  auto all_gather_start = Cast<HloAllGatherInstruction>(instruction);
+  attributes.push_back(builder->getNamedAttr(
+      "all_gather_dim",
+      builder->getI64IntegerAttr(all_gather_start->all_gather_dimension())));
+  attributes.push_back(
+      ConvertReplicaGroups(all_gather_start->replica_groups(), builder));
+  if (all_gather_start->channel_id().has_value())
+    attributes.push_back(
+        ConvertChannelHandle(all_gather_start->channel_id().value(), builder));
+  if (all_gather_start->use_global_device_ids())
+    attributes.push_back(ConvertUseGlobalDeviceIds(builder));
+  if (all_gather_start->operands().size() > 1)
+    return InvalidArgument("Async tuple all-gather is not supported in MHLO");
+
+  if (!llvm::isa<mlir::TupleType>(result_type)) {
+    // Async AllGather's output type is bundle<input_type,output_type>.
+    // In some instances the output type is not a tuple (this seems to be the
+    // more modern form), so wrap it in a tuple for MHLO.
+    result_type = mlir::TupleType::get(builder->getContext(),
+                                       {operands[0].getType(), result_type});
+  }
+
+  return ImportOldStyleAsyncStart<mlir::mhlo::AllGatherOp>(
+      symbol_table, attributes, operands, loc, result_type, builder,
+      "all_gather_", [](auto) { return absl::OkStatus(); });
+}
+
+absl::StatusOr<mlir::Operation*> ImportAllReduceStart(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    std::function<absl::Status(mlir::mhlo::AllReduceOp)> mutate_op,
+    mlir::SymbolTable& symbol_table) {
+  auto all_reduce_start = Cast<HloAllReduceInstruction>(instruction);
+  attributes.push_back(
+      ConvertReplicaGroups(all_reduce_start->replica_groups(), builder));
+  if (all_reduce_start->channel_id().has_value())
+    attributes.push_back(
+        ConvertChannelHandle(all_reduce_start->channel_id().value(), builder));
+  if (all_reduce_start->use_global_device_ids())
+    attributes.push_back(ConvertUseGlobalDeviceIds(builder));
+  if (all_reduce_start->operands().size() > 1)
+    return InvalidArgument("Async tuple all-reduce is not supported in MHLO");
+
+  if (!llvm::isa<mlir::TupleType>(result_type)) {
+    // Async AllReduce's output type is bundle<input_type,output_type>.
+    // In some instances the output type is not a tuple (this seems to be the
+    // more modern form), so wrap it in a tuple for MHLO.
+    result_type = mlir::TupleType::get(builder->getContext(),
+                                       {operands[0].getType(), result_type});
+  }
+
+  return ImportOldStyleAsyncStart<mlir::mhlo::AllReduceOp>(
+      symbol_table, attributes, operands, loc, result_type, builder,
+      "all_reduce_", mutate_op);
+}
+
+// Collective Permute
+
+absl::StatusOr<mlir::Operation*> ImportCollectivePermuteStart(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    mlir::SymbolTable& symbol_table) {
+  attributes.push_back(
+      ConvertSourceTargetPairs(instruction->source_target_pairs(), builder));
+  if (!llvm::isa<mlir::TupleType>(result_type)) {
+    // Async CollectivePermute's output type is bundle<input_type,output_type>.
+    // In some instances the output type is not a tuple (this seems to be the
+    // more modern form), so wrap it in a tuple for MHLO.
+    result_type = mlir::TupleType::get(builder->getContext(),
+                                       {operands[0].getType(), result_type});
+  }
+  return ImportOldStyleAsyncStart<mlir::mhlo::CollectivePermuteOp>(
+      symbol_table, attributes, operands, loc, result_type, builder,
+      "collective_permute_", [&](auto) { return absl::OkStatus(); });
+}
+
+absl::StatusOr<mlir::Operation*> ImportCopyStart(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    mlir::SymbolTable& symbol_table) {
+  auto context = builder->getContext();
+  auto copy_start_instruction = Cast<HloCopyStartInstruction>(instruction);
+  if (auto cross_program_prefetch_index =
+          copy_start_instruction->cross_program_prefetch_index()) {
+    attributes.push_back(builder->getNamedAttr(
+        "cross_program_prefetch_index",
+        builder->getIntegerAttr(builder->getIntegerType(32),
+                                *cross_program_prefetch_index)));
+    // Cross-program prefetch allows copy ops to accept tuples, in which case
+    // we need to double-wrap inputs and outputs in tuples.
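+    // For example (illustrative), result types (T0, T1, T2) become
+    // ((T0), (T1), T2) below.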
+    if (operands[0].getType().isa<mlir::TupleType>()) {
+      auto result_types = result_type.cast<mlir::TupleType>().getTypes();
+      result_type = mlir::TupleType::get(
+          context,
+          {mlir::TupleType::get(context, {result_types[0]}),
+           mlir::TupleType::get(context, {result_types[1]}), result_types[2]});
+    }
+  }
+  return ImportOldStyleAsyncStart<mlir::mhlo::CopyOp>(
+      symbol_table, attributes, operands, loc, result_type, builder, "copy_",
+      [](auto) { return absl::OkStatus(); });
+}
+
+absl::StatusOr<mlir::Operation*> ImportAsyncOpDone(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder) {
+  return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
+                                 builder);
+}
+
+}  // namespace xla
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h
new file mode 100644
index 0000000..efdd487
--- /dev/null
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h
@@ -0,0 +1,88 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_
+#define XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_
+
+#include <functional>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace xla {
+
+// Op Converters
+absl::StatusOr<mlir::Operation*> ImportSend(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    mlir::SymbolTable& symbol_table);
+
+absl::StatusOr<mlir::Operation*> ImportRecv(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    mlir::SymbolTable& symbol_table);
+
+// Async Collectives
+absl::StatusOr<mlir::Operation*> ImportAllGatherStart(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    mlir::SymbolTable& symbol_table);
+
+absl::StatusOr<mlir::Operation*> ImportAllReduceStart(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    std::function<absl::Status(mlir::mhlo::AllReduceOp)> mutate_op,
+    mlir::SymbolTable& symbol_table);
+
+absl::StatusOr<mlir::Operation*> ImportCollectivePermuteStart(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    mlir::SymbolTable& symbol_table);
+
+absl::StatusOr<mlir::Operation*> ImportCopyStart(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder,
+    mlir::SymbolTable& symbol_table);
+
+absl::StatusOr<mlir::Operation*> ImportAsyncOpDone(
+    const HloInstruction* instruction, mlir::Location loc,
+    const llvm::SmallVectorImpl<mlir::Value>& operands,
+    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
+    mlir::Type result_type, mlir::OpBuilder* builder);
+
+}  // namespace xla
+
+#endif  // XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc
index 9a2fa06..fbe1a90 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc
@@ -17,13 +17,24 @@
 
 #include <sys/types.h>
 
+#include <algorithm>
+#include <cstdint>
 #include <optional>
 #include <utility>
 #include <vector>
 
+#include "absl/algorithm/container.h"
 #include "absl/status/statusor.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "xla/layout.h"
 #include "xla/layout_util.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
@@ -195,6 +206,62 @@
   }
 }
 
+mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel,
+                                          mlir::Builder* builder) {
+  return builder->getNamedAttr(
+      "channel_handle",
+      mlir::mhlo::ChannelHandleAttr::get(builder->getContext(),
+                                         channel.handle(), channel.type()));
+}
+mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id,
+                                          mlir::Builder* builder) {
+  ChannelHandle channel_handle;
+  if (channel_id) channel_handle.set_handle(*channel_id);
+  return ConvertChannelHandle(channel_handle, builder);
+}
+
+mlir::NamedAttribute ConvertReplicaGroups(
+    absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder) {
+  const int64_t num_groups = replica_groups.size();
+  // Replica groups in HLO can be non-uniform in size, for example:
+  // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
+  // tensor, pad the smaller sized replica groups with -1.
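+  // The example above is thus encoded as the 3x2 tensor
+  // [[0, -1], [1, 2], [3, -1]].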
+  const int64_t group_size = absl::c_accumulate(
+      replica_groups, static_cast<int64_t>(0),
+      [](int64_t current, const ReplicaGroup& g) {
+        return std::max<int64_t>(current, g.replica_ids_size());
+      });
+  // Initialize all elements to -1 to support non-uniform replica groups.
+  std::vector<int64_t> attr(num_groups * group_size, -1);
+  for (int i = 0; i < num_groups; ++i) {
+    int index = i * group_size;
+    for (const int64_t& id : replica_groups[i].replica_ids())
+      attr[index++] = id;
+  }
+  auto type = mlir::RankedTensorType::get({num_groups, group_size},
+                                          builder->getIntegerType(64));
+  return builder->getNamedAttr("replica_groups",
+                               mlir::DenseIntElementsAttr::get(type, attr));
+}
+
+mlir::NamedAttribute ConvertSourceTargetPairs(
+    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+    mlir::Builder* builder) {
+  std::vector<int64_t> attr(source_target_pairs.size() * 2);
+  for (const auto& p : llvm::enumerate(source_target_pairs)) {
+    attr[2 * p.index()] = p.value().first;
+    attr[2 * p.index() + 1] = p.value().second;
+  }
+  auto type = mlir::RankedTensorType::get(
+      {static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
+  return builder->getNamedAttr("source_target_pairs",
+                               mlir::DenseIntElementsAttr::get(type, attr));
+}
+
+mlir::NamedAttribute ConvertUseGlobalDeviceIds(mlir::Builder* builder) {
+  return builder->getNamedAttr("use_global_device_ids", builder->getUnitAttr());
+}
+
 absl::StatusOr<mlir::ArrayAttr> ExtractLayoutsFromShapes(
     const absl::Span<const Shape> shapes_with_layouts, mlir::Builder* builder) {
   std::vector<mlir::Attribute> layouts;
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h
index 4f1ba9e..f836814 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h
@@ -16,10 +16,14 @@
 #ifndef XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_
 #define XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_
 
+#include <cstdint>
+#include <optional>
 #include <utility>
 #include <vector>
 
 #include "absl/status/statusor.h"
+#include "absl/types/span.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 #include "xla/service/hlo.pb.h"
@@ -66,6 +70,20 @@
 absl::StatusOr<mlir::mhlo::CustomCallApiVersion> ConvertCustomCallApiVersion(
     xla::CustomCallApiVersion api_version);
 
+mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel,
+                                          mlir::Builder* builder);
+mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id,
+                                          mlir::Builder* builder);
+
+mlir::NamedAttribute ConvertReplicaGroups(
+    absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder);
+
+mlir::NamedAttribute ConvertSourceTargetPairs(
+    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+    mlir::Builder* builder);
+
+mlir::NamedAttribute ConvertUseGlobalDeviceIds(mlir::Builder* builder);
+
 // Extracts layouts from shapes and converts it into layout attributes (array of
 // rank-1 index tensors). Returns an error if any of the shapes is a tuple.
 absl::StatusOr<mlir::ArrayAttr> ExtractLayoutsFromShapes(
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
index 64e08df..6490839 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -15,7 +15,6 @@
 
 #include "xla/translate/hlo_to_mhlo/hlo_function_importer.h"
 
-#include <algorithm>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -26,7 +25,6 @@
 #include <utility>
 #include <vector>
 
-#include "absl/algorithm/container.h"
 #include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
@@ -37,6 +35,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/LogicalResult.h"
 #include "mlir/AsmParser/AsmParser.h"
@@ -71,6 +70,7 @@
 #include "xla/service/hlo.pb.h"
 #include "xla/shape_util.h"
 #include "xla/status_macros.h"
+#include "xla/translate/hlo_to_mhlo/async_importer.h"
 #include "xla/translate/hlo_to_mhlo/attribute_importer.h"
 #include "xla/translate/hlo_to_mhlo/custom_call_importer.h"
 #include "xla/translate/hlo_to_mhlo/hlo_utils.h"
@@ -90,6 +90,8 @@
 using mlir::Value;
 using mlir::func::FuncOp;
 
+#define DEBUG_TYPE "xla-translate"
+
 namespace xla {
 
 namespace {
@@ -100,7 +102,7 @@
 
 // Note: This sanitization function causes an irreversible many-to-one mapping
 // and any solution to mitigate this would cause issues with the reverse
-// direction. Longterm solution is to add a function attribute to maintain the
+// direction. Long-term solution is to add a function attribute to maintain the
 // original HLO naming.
 std::string SanitizeFunctionName(llvm::StringRef name) {
   std::string output(name);
@@ -170,115 +172,6 @@
 
 }  // namespace
 
-mlir::TypeRange Untuple(const mlir::Type& type) {
-  if (type.isa<mlir::TupleType>()) {
-    return llvm::dyn_cast<mlir::TupleType>(type).getTypes();
-  }
-  return type;
-}
-
-template <typename sync_op>
-absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportOldStyleAsyncStart(
-    llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
-    const llvm::SmallVectorImpl<mlir::Value>& operands, mlir::Location loc,
-    mlir::Type result_type, mlir::OpBuilder* func_builder,
-    std::string func_name, std::function<absl::Status(sync_op)> mutate_op) {
-  auto result_types = result_type.cast<mlir::TupleType>().getTypes();
-  if (result_types.size() < 2) {
-    return tsl::errors::InvalidArgument(
-        "async_bundle must contain at least two values");
-  }
-  auto func_type = mlir::FunctionType::get(context_, Untuple(result_types[0]),
-                                           Untuple(result_types[1]));
-  auto function = FuncOp::create(loc, func_name, func_type);
-
-  // The new function doesn't need to be inserted in the beginning but is done
-  // to make testing easier and preserve the original behavior.
-  mlir::Block& block = symbol_table_.getOp()->getRegion(0).front();
-  symbol_table_.insert(function, mlir::Block::iterator(block.begin()));
-
-  function.setPrivate();
-  auto async_builder = mlir::OpBuilder(function.getBody());
-
-  llvm::SmallVector<mlir::NamedAttribute> async_attributes;
-  async_attributes.push_back(builder_->getNamedAttr(
-      "called_computation", mlir::FlatSymbolRefAttr::get(builder_->getContext(),
-                                                         function.getName())));
-  async_attributes.push_back(builder_->getNamedAttr(
-      "execution_thread", builder_->getStringAttr("main")));
-
-  // Attach the frontend_attributes and sharding attributes to the async op
-  // instead of the sync op. First, semantically sharding attributes cannot be
-  // attached to the sync op since the sync op may not produce the same number
-  // of results as the sharding's tuple element count, e.g., `mhlo.send` vs. HLO
-  // `send`. Second, `mlir_hlo_to_hlo.cc` imports these attributes from the
-  // `mhlo.async_start` ops, so attaching them to the sync op will make them
-  // disappear during MHLO to HLO lowering.
-  for (auto it = attributes.begin(); it != attributes.end();) {
-    if (it->getName() == kShardingAttr ||
-        it->getName() == kFrontendAttributesAttr) {
-      async_attributes.push_back(*it);
-      it = attributes.erase(it);
-    } else {
-      ++it;
-    }
-  }
-
-  llvm::SmallVector<mlir::Location, 1> locs(Untuple(result_types[0]).size(),
-                                            loc);
-  auto sync_operand =
-      async_builder
-          .createBlock(&function.getBody(), {}, Untuple(result_types[0]), locs)
-          ->getArguments();
-  auto sync_operation = async_builder.create<sync_op>(
-      loc, Untuple(result_types[1]), sync_operand, attributes);
-  async_builder.create<mlir::func::ReturnOp>(loc, sync_operation->getResults());
-  TF_RETURN_IF_ERROR(mutate_op(sync_operation));
-
-  function->setAttr("execution_thread", builder_->getStringAttr("main"));
-
-  auto bundle_result_type =
-      mlir::mhlo::AsyncBundleType::get(context_, result_types);
-  return func_builder
-      ->create<mlir::mhlo::AsyncStartOp>(loc, bundle_result_type, operands,
-                                         async_attributes)
-      .getOperation();
-}
-
-absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportOldStyleAsyncDone(
-    llvm::SmallVectorImpl<NamedAttribute>& attributes,
-    const llvm::SmallVectorImpl<mlir::Value>& operands, mlir::Location loc,
-    mlir::Type result_type, mlir::OpBuilder* func_builder) {
-  if (operands.size() != 1) {
-    return InvalidArgument(
-        "async-done must take only a single async_bundle operand");
-  }
-  auto async_start = operands[0].getDefiningOp<mlir::mhlo::AsyncStartOp>();
-  if (!async_start) return InvalidArgument("*-start requires *-done as input");
-  attributes.push_back(builder_->getNamedAttr(
-      "called_computation",
-      mlir::FlatSymbolRefAttr::get(builder_->getContext(),
-                                   async_start.getCalledComputation())));
-  attributes.push_back(builder_->getNamedAttr("execution_thread",
-                                              builder_->getStringAttr("main")));
-
-  auto start_tuple = async_start.getResult()
-                         .getType()
-                         .cast<mlir::mhlo::AsyncBundleType>()
-                         .getTypes()[1]
-                         .dyn_cast<mlir::TupleType>();
-  if (start_tuple && start_tuple.getType(0).isa<mlir::TupleType>()) {
-    auto op = func_builder->create<mlir::mhlo::AsyncDoneOp>(
-        loc, result_type, operands, attributes);
-    return {op};
-  } else {
-    auto op = func_builder->create<mlir::mhlo::AsyncDoneOp>(
-        loc, Untuple(result_type), operands, attributes);
-    return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
-                                    result_type);
-  }
-}
-
 void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands(
     mlir::Operation* op, llvm::ArrayRef<mlir::Value> implicit_operands) {
   assert((mlir::dyn_cast<mlir::mhlo::IfOp>(*op) ||
@@ -296,20 +189,6 @@
   }
 }
 
-mlir::Operation* HloFunctionImporter::CreateTupleFromOpResults(
-    mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Operation* op,
-    mlir::Type type) {
-  if (!type.isa<mlir::TupleType>()) return op;
-
-  mlir::ValueRange flattened_results_ref(op->getResults());
-  auto result =
-      CreateTupleValue(func_builder, loc, flattened_results_ref, type);
-  auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
-  assert(defining_tuple_op && "builder didn't return the right type");
-  auto tupleOp = defining_tuple_op.getOperation();
-  return tupleOp;
-}
-
 static bool IsNestedTupleInData(Type type) {
   auto tuple_type = type.dyn_cast<mlir::TupleType>();
   if (!tuple_type) return false;
@@ -381,27 +260,6 @@
   return flattened_values;
 }
 
-Value HloFunctionImporter::CreateTupleValue(mlir::OpBuilder* func_builder,
-                                            mlir::Location loc,
-                                            mlir::ValueRange& flatten_values,
-                                            Type type) {
-  auto tuple_type = type.dyn_cast<mlir::TupleType>();
-  if (!tuple_type) {
-    assert(!flatten_values.empty());
-    auto retval = flatten_values.front();
-    flatten_values = flatten_values.drop_front();
-    return retval;
-  }
-
-  llvm::SmallVector<mlir::Value> flatten_sub_values;
-  for (auto child_type : tuple_type.getTypes())
-    flatten_sub_values.push_back(
-        CreateTupleValue(func_builder, loc, flatten_values, child_type));
-
-  return func_builder->create<mlir::mhlo::TupleOp>(loc, flatten_sub_values)
-      .getResult();
-}
-
 absl::StatusOr<mlir::func::FuncOp> HloFunctionImporter::ImportAsFunc(
     const HloComputation& computation, mlir::SymbolTable& symbol_table,
     std::unordered_map<const HloComputation*, FuncOp>* function_map,
@@ -761,7 +619,10 @@
     frontend_attributes.push_back(
         builder_->getNamedAttr(k, builder_->getStringAttr(v)));
   }
+
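+  // Remember where the frontend_attributes entry is placed in `attributes`
+  // so that the composite call import below can update or drop it.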
+  int frontend_attributes_index = 0;
   if (!frontend_attributes.empty()) {
+    frontend_attributes_index = attributes.size();
     attributes.push_back(builder_->getNamedAttr(
         kFrontendAttributesAttr,
         builder_->getDictionaryAttr(frontend_attributes)));
@@ -926,6 +787,68 @@
           FuncOp function,
           ImportAsFunc(*instruction->to_apply(), /*is_main=*/false));
       mlir::Operation* new_operation;
+      if (instruction->is_composite()) {
+        // TODO: b/354721812 - Support the flatten_computation_args_result_
+        // flag for composite calls.
+
+        mlir::DictionaryAttr frontend_attributes_attr =
+            builder_->getDictionaryAttr(frontend_attributes);
+        if (frontend_attributes.empty() ||
+            !frontend_attributes_attr.contains("composite.attributes") ||
+            !frontend_attributes_attr.contains("composite.name") ||
+            !frontend_attributes_attr.contains("composite.version")) {
+          return InvalidArgument(
+              "A composite call op must have frontend attributes with the "
+              "following keys: composite.attributes, composite.name, "
+              "composite.version");
+        }
+
+        llvm::SmallVector<NamedAttribute, 4> fe_attrs_without_composite_attrs;
+        for (const auto& attr : frontend_attributes) {
+          if (attr.getName() != "composite.attributes" &&
+              attr.getName() != "composite.name" &&
+              attr.getName() != "composite.version") {
+            fe_attrs_without_composite_attrs.push_back(attr);
+          }
+        }
+
+        // The frontend_attributes attribute may exist solely because of the
+        // composite-related entries. If it is empty after removing those
+        // entries, it is no longer needed, so drop it entirely; otherwise,
+        // update it in place.
+        if (fe_attrs_without_composite_attrs.empty()) {
+          attributes.erase(attributes.begin() + frontend_attributes_index);
+        } else {
+          attributes[frontend_attributes_index] = builder_->getNamedAttr(
+              kFrontendAttributesAttr,
+              builder_->getDictionaryAttr(fe_attrs_without_composite_attrs));
+        }
+
+        auto frontend_attributes_map = instruction->frontend_attributes().map();
+        mlir::StringAttr name = builder_->getStringAttr(
+            frontend_attributes_map.find("composite.name")->second);
+        mlir::Attribute composite_attributes = mlir::parseAttribute(
+            frontend_attributes_map.find("composite.attributes")->second,
+            builder_->getContext());
+        mlir::FlatSymbolRefAttr decomposition = mlir::SymbolRefAttr::get(
+            builder_->getContext(), instruction->to_apply()->name());
+        mlir::IntegerAttr version = builder_->getIntegerAttr(
+            builder_->getI32Type(),
+            std::stoi(
+                frontend_attributes_map.find("composite.version")->second));
+
+        new_operation = func_builder->create<mlir::mhlo::CompositeOp>(
+            loc, result_type, operands);
+        new_operation->setAttr("name", name);
+        new_operation->setAttr("composite_attributes", composite_attributes);
+        new_operation->setAttr("decomposition", decomposition);
+        new_operation->setAttr("version", version);
+        for (const auto& attr : attributes) {
+          new_operation->setAttr(attr.getName(), attr.getValue());
+        }
+        return new_operation;
+      }
+
       if (flatten_computation_args_result_) {
         // Flatten the tuple-typed operands.
         llvm::SmallVector<Value> flattened_operands = FlattenTupleValues(
@@ -946,7 +869,7 @@
       } else {
         new_operation =
             func_builder->create<mlir::func::CallOp>(loc, function, operands);
-        for (auto attr : attributes) {
+        for (const auto& attr : attributes) {
           new_operation->setAttr(attr.getName(), attr.getValue());
         }
       }
@@ -957,8 +880,8 @@
       attributes.push_back(ConvertReplicaGroups(
           collective_broadcast->replica_groups(), builder_));
       if (collective_broadcast->channel_id().has_value())
-        attributes.push_back(
-            ConvertChannelHandle(collective_broadcast->channel_id().value()));
+        attributes.push_back(ConvertChannelHandle(
+            collective_broadcast->channel_id().value(), builder_));
       return func_builder
           ->create<mlir::mhlo::CollectiveBroadcastOp>(loc, result_type,
                                                       operands, attributes)
@@ -970,23 +893,21 @@
       attributes.push_back(ConvertSourceTargetPairs(
           collective_permute->source_target_pairs(), builder_));
       if (collective_permute->channel_id().has_value())
-        attributes.push_back(
-            ConvertChannelHandle(collective_permute->channel_id().value()));
+        attributes.push_back(ConvertChannelHandle(
+            collective_permute->channel_id().value(), builder_));
       return func_builder
           ->create<mlir::mhlo::CollectivePermuteOp>(loc, result_type, operands,
                                                     attributes)
           .getOperation();
     }
     case HloOpcode::kCollectivePermuteStart: {
-      attributes.push_back(ConvertSourceTargetPairs(
-          instruction->source_target_pairs(), builder_));
-      return ImportOldStyleAsyncStart<mlir::mhlo::CollectivePermuteOp>(
-          attributes, operands, loc, result_type, func_builder,
-          "collective_permute_", [&](auto) { return absl::OkStatus(); });
+      return ImportCollectivePermuteStart(instruction, loc, operands,
+                                          attributes, result_type, func_builder,
+                                          symbol_table_);
     }
     case HloOpcode::kCollectivePermuteDone: {
-      return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
-                                     func_builder);
+      return ImportAsyncOpDone(instruction, loc, operands, attributes,
+                               result_type, func_builder);
     }
     case HloOpcode::kCustomCall: {
       auto custom_call = Cast<HloCustomCallInstruction>(instruction);
@@ -1310,103 +1231,31 @@
           .getOperation();
     }
     case HloOpcode::kCopyStart: {
-      auto copy_start_instruction = Cast<HloCopyStartInstruction>(instruction);
-      if (auto cross_program_prefetch_index =
-              copy_start_instruction->cross_program_prefetch_index()) {
-        attributes.push_back(builder_->getNamedAttr(
-            "cross_program_prefetch_index",
-            builder_->getIntegerAttr(builder_->getIntegerType(32),
-                                     *cross_program_prefetch_index)));
-        // Cross-program prefetch allows copy ops to accept tuples, in which
-        // case, we need to double-wrap inputs and outputs in tuples.
-        if (operands[0].getType().isa<mlir::TupleType>()) {
-          auto result_types = result_type.cast<mlir::TupleType>().getTypes();
-          result_type = mlir::TupleType::get(
-              context_, {mlir::TupleType::get(context_, {result_types[0]}),
-                         mlir::TupleType::get(context_, {result_types[1]}),
-                         result_types[2]});
-        }
-      }
-      return ImportOldStyleAsyncStart<mlir::mhlo::CopyOp>(
-          attributes, operands, loc, result_type, func_builder, "copy_",
-          [](auto) { return absl::OkStatus(); });
+      return ImportCopyStart(instruction, loc, operands, attributes,
+                             result_type, func_builder, symbol_table_);
     }
     case HloOpcode::kCopyDone: {
-      return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
-                                     func_builder);
+      return ImportAsyncOpDone(instruction, loc, operands, attributes,
+                               result_type, func_builder);
     }
     case HloOpcode::kSend: {
-      // old-style send returns a bundle of (arg, sync flag, token) to be passed
-      // along to send-done.
-      // However, the new-style async ops have a shared bundle
-      // format of (args, results, scratchpad), so to rewrite the `send` and
-      // `send-done` ops to use the new-style async API, we need to reorder the
-      // arguments to be in (args, token, sync flag) order.
-      auto result_types = result_type.cast<mlir::TupleType>().getTypes();
-      if (result_types.size() != 3)
-        return InvalidArgument("send should return a 3-tuple");
-      auto async_arg_type =
-          mlir::TupleType::get(context_, {result_types[0], result_types[2]});
-      auto async_bundled_tuple = mlir::TupleType::get(
-          context_, {async_arg_type, result_types[2], result_types[1]});
-      auto send_op = Cast<HloSendInstruction>(instruction);
-      attributes.push_back(builder_->getNamedAttr(
-          "is_host_transfer",
-          builder_->getBoolAttr(send_op->is_host_transfer())));
-      if (send_op->channel_id().has_value()) {
-        ChannelHandle channel_handle;
-        channel_handle.set_handle(send_op->channel_id().value());
-        channel_handle.set_type(send_op->is_host_transfer()
-                                    ? ChannelHandle::DEVICE_TO_HOST
-                                    : ChannelHandle::DEVICE_TO_DEVICE);
-        attributes.push_back(ConvertChannelHandle(channel_handle));
-      }
-      return ImportOldStyleAsyncStart<mlir::mhlo::SendOp>(
-          attributes, operands, loc, async_bundled_tuple, func_builder, "send_",
-          [](auto) { return absl::OkStatus(); });
+      return ImportSend(instruction, loc, operands, attributes, result_type,
+                        func_builder, symbol_table_);
     }
     case HloOpcode::kSendDone: {
-      return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
-                                     func_builder);
+      return ImportAsyncOpDone(instruction, loc, operands, attributes,
+                               result_type, func_builder);
     }
     case HloOpcode::kRecv: {
-      // Old-style `recv` returns a bundle of (result, sync flag, token) to be
-      // passed along to recv-done.
-      // However, the new-style async ops have a shared
-      // bundle format of (args, results, scratchpad), so to rewrite the `recv`
-      // and `recv-done` ops to use the new-style async API, we need to reorder
-      // the arguments to be in (token, (result, token), sync flag) order.
-      auto result_types = result_type.cast<mlir::TupleType>().getTypes();
-      if (result_types.size() != 3)
-        return InvalidArgument("recv should return a 3-tuple");
-      auto async_result_type =
-          mlir::TupleType::get(context_, {result_types[0], result_types[2]});
-      auto async_bundled_tuple = mlir::TupleType::get(
-          context_, {result_types[2], async_result_type, result_types[1]});
-      auto recv_op = Cast<HloRecvInstruction>(instruction);
-      attributes.push_back(builder_->getNamedAttr(
-          "is_host_transfer",
-          builder_->getBoolAttr(recv_op->is_host_transfer())));
-      if (recv_op->channel_id().has_value()) {
-        ChannelHandle channel_handle;
-        channel_handle.set_handle(recv_op->channel_id().value());
-        channel_handle.set_type(recv_op->is_host_transfer()
-                                    ? ChannelHandle::HOST_TO_DEVICE
-                                    : ChannelHandle::DEVICE_TO_DEVICE);
-        attributes.push_back(ConvertChannelHandle(channel_handle));
-      }
-      return ImportOldStyleAsyncStart<mlir::mhlo::RecvOp>(
-          attributes, operands, loc, async_bundled_tuple, func_builder, "recv_",
-          [](auto) { return absl::OkStatus(); });
+      return ImportRecv(instruction, loc, operands, attributes, result_type,
+                        func_builder, symbol_table_);
     }
     case HloOpcode::kRecvDone: {
-      return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
-                                     func_builder);
+      return ImportAsyncOpDone(instruction, loc, operands, attributes,
+                               result_type, func_builder);
     }
     case HloOpcode::kConditional: {
       llvm::SmallVector<Type, 4> rets;
-
-      // Flatten the tuple-typed operands.
       llvm::SmallVector<Value> flattened_operands =
           FlattenTupleValues(func_builder, loc, operands);
 
@@ -1498,9 +1347,9 @@
           ConvertReplicaGroups(all_gather->replica_groups(), builder_));
       if (all_gather->channel_id().has_value())
         attributes.push_back(
-            ConvertChannelHandle(all_gather->channel_id().value()));
+            ConvertChannelHandle(all_gather->channel_id().value(), builder_));
       if (all_gather->use_global_device_ids())
-        attributes.push_back(ConvertUseGlobalDeviceIds());
+        attributes.push_back(ConvertUseGlobalDeviceIds(builder_));
       auto all_gather_op = func_builder->create<mlir::mhlo::AllGatherOp>(
           loc, result_types, operands, attributes);
       if (result_tuple_ty) {
@@ -1512,28 +1361,12 @@
       return all_gather_op.getOperation();
     }
     case HloOpcode::kAllGatherStart: {
-      auto all_gather_start = Cast<HloAllGatherInstruction>(instruction);
-      attributes.push_back(builder_->getNamedAttr(
-          "all_gather_dim", builder_->getI64IntegerAttr(
-                                all_gather_start->all_gather_dimension())));
-      attributes.push_back(
-          ConvertReplicaGroups(all_gather_start->replica_groups(), builder_));
-      if (all_gather_start->channel_id().has_value())
-        attributes.push_back(
-            ConvertChannelHandle(all_gather_start->channel_id().value()));
-      if (all_gather_start->use_global_device_ids())
-        attributes.push_back(ConvertUseGlobalDeviceIds());
-      if (all_gather_start->operands().size() > 1)
-        return InvalidArgument(
-            "Async tuple all-gather is not supported in MHLO");
-
-      return ImportOldStyleAsyncStart<mlir::mhlo::AllGatherOp>(
-          attributes, operands, loc, result_type, func_builder, "all_gather_",
-          [](auto) { return absl::OkStatus(); });
+      return ImportAllGatherStart(instruction, loc, operands, attributes,
+                                  result_type, func_builder, symbol_table_);
     }
     case HloOpcode::kAllGatherDone: {
-      return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
-                                     func_builder);
+      return ImportAsyncOpDone(instruction, loc, operands, attributes,
+                               result_type, func_builder);
     }
     case HloOpcode::kAllReduce: {
       auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
@@ -1548,9 +1381,9 @@
           ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
       if (all_reduce->channel_id().has_value())
         attributes.push_back(
-            ConvertChannelHandle(all_reduce->channel_id().value()));
+            ConvertChannelHandle(all_reduce->channel_id().value(), builder_));
       if (all_reduce->use_global_device_ids())
-        attributes.push_back(ConvertUseGlobalDeviceIds());
+        attributes.push_back(ConvertUseGlobalDeviceIds(builder_));
       auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
           loc, result_types, operands, attributes);
       TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
@@ -1564,29 +1397,19 @@
       return all_reduce_op.getOperation();
     }
     case HloOpcode::kAllReduceStart: {
-      auto all_reduce_start = Cast<HloAllReduceInstruction>(instruction);
-      attributes.push_back(
-          ConvertReplicaGroups(all_reduce_start->replica_groups(), builder_));
-      if (all_reduce_start->channel_id().has_value())
-        attributes.push_back(
-            ConvertChannelHandle(all_reduce_start->channel_id().value()));
-      if (all_reduce_start->use_global_device_ids())
-        attributes.push_back(ConvertUseGlobalDeviceIds());
-      if (all_reduce_start->operands().size() > 1)
-        return InvalidArgument(
-            "Async tuple all-reduce is not supported in MHLO");
+      auto appendRegion = [&](mlir::mhlo::AllReduceOp all_reduce_sync) {
+        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
+                                          &all_reduce_sync.getComputation()));
+        return absl::OkStatus();
+      };
 
-      return ImportOldStyleAsyncStart<mlir::mhlo::AllReduceOp>(
-          attributes, operands, loc, result_type, func_builder, "all_reduce_",
-          [&](auto all_reduce_sync) {
-            TF_RETURN_IF_ERROR(ImportAsRegion(
-                *instruction->to_apply(), &all_reduce_sync.getComputation()));
-            return absl::OkStatus();
-          });
+      return ImportAllReduceStart(instruction, loc, operands, attributes,
+                                  result_type, func_builder, appendRegion,
+                                  symbol_table_);
     }
     case HloOpcode::kAllReduceDone: {
-      return ImportOldStyleAsyncDone(attributes, operands, loc, result_type,
-                                     func_builder);
+      return ImportAsyncOpDone(instruction, loc, operands, attributes,
+                               result_type, func_builder);
     }
     case HloOpcode::kAllToAll: {
       auto all_to_all = Cast<HloAllToAllInstruction>(instruction);
@@ -1623,7 +1446,8 @@
           replica_groups_attr);
 
       if (all_to_all->channel_id().has_value()) {
-        auto handle = ConvertChannelHandle(all_to_all->channel_id().value());
+        auto handle =
+            ConvertChannelHandle(all_to_all->channel_id().value(), builder_);
         result.setChannelHandleAttr(
             handle.getValue().cast<mlir::mhlo::ChannelHandleAttr>());
       }
@@ -1829,10 +1653,10 @@
       attributes.push_back(
           ConvertReplicaGroups(reduce_scatter->replica_groups(), builder_));
       if (reduce_scatter->channel_id().has_value())
-        attributes.push_back(
-            ConvertChannelHandle(reduce_scatter->channel_id().value()));
+        attributes.push_back(ConvertChannelHandle(
+            reduce_scatter->channel_id().value(), builder_));
       if (reduce_scatter->use_global_device_ids())
-        attributes.push_back(ConvertUseGlobalDeviceIds());
+        attributes.push_back(ConvertUseGlobalDeviceIds(builder_));
       auto reduce_scatter_op =
           func_builder->create<mlir::mhlo::ReduceScatterOp>(
               loc, result_type, operands, attributes);
@@ -2192,10 +2016,22 @@
     const HloInstruction* instruction,
     const llvm::SmallVectorImpl<mlir::Value>& operands,
     mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) {
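+  // Debug-only tracing of each imported instruction. These statements are
+  // compiled out in NDEBUG builds and, with assertions enabled, print only
+  // when LLVM debug output is turned on (for example via the standard LLVM
+  // `-debug` / `-debug-only` flags).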
+  LLVM_DEBUG(llvm::dbgs() << "Importing instruction: "
+                          << HloOpcodeString(instruction->opcode()) << '\n');
+  LLVM_DEBUG({
+    llvm::dbgs() << "  operands: (";
+    llvm::interleaveComma(operands, llvm::dbgs(),
+                          [](Value v) { llvm::dbgs() << v.getType(); });
+    llvm::dbgs() << ")\n";
+  });
   TF_ASSIGN_OR_RETURN(
       mlir::Operation * op,
       ImportInstructionImpl(instruction, operands, func_builder, mode));
-  if (op == nullptr) return op;
+  if (op == nullptr) {
+    LLVM_DEBUG(llvm::dbgs() << "  instruction skipped.\n");
+    return op;
+  }
+  LLVM_DEBUG(llvm::dbgs() << "  imported: " << *op << '\n');
 
   // See MlirToHloConversionOptions for more about layouts.
   //
@@ -2322,62 +2158,6 @@
   return builder_->getNamedAttr("padding", attr);
 }
 
-mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
-    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
-    mlir::Builder* builder) {
-  std::vector<int64_t> attr(source_target_pairs.size() * 2);
-  for (const auto& p : llvm::enumerate(source_target_pairs)) {
-    attr[2 * p.index()] = p.value().first;
-    attr[2 * p.index() + 1] = p.value().second;
-  }
-  auto type = mlir::RankedTensorType::get(
-      {static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
-  return builder->getNamedAttr("source_target_pairs",
-                               DenseIntElementsAttr::get(type, attr));
-}
-
-mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
-    absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder) {
-  const int64_t num_groups = replica_groups.size();
-  // Replica groups in HLO can be non-uniform in size, for example:
-  // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
-  // tensor, pad the smaller sized replica groups with -1.
-  const int64_t group_size = absl::c_accumulate(
-      replica_groups, int64_t(0), [](int64_t current, const ReplicaGroup& g) {
-        return std::max<int64_t>(current, g.replica_ids_size());
-      });
-  // Initialize all elements to -1 to support non-uniform replica groups.
-  std::vector<int64_t> attr(num_groups * group_size, -1);
-  for (int i = 0; i < num_groups; ++i) {
-    int index = i * group_size;
-    for (const int64_t& id : replica_groups[i].replica_ids())
-      attr[index++] = id;
-  }
-  auto type = mlir::RankedTensorType::get({num_groups, group_size},
-                                          builder->getIntegerType(64));
-  return builder->getNamedAttr("replica_groups",
-                               DenseIntElementsAttr::get(type, attr));
-}
-
-mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
-    std::optional<int64_t> channel_id) {
-  ChannelHandle channel_handle;
-  if (channel_id) channel_handle.set_handle(*channel_id);
-  return ConvertChannelHandle(channel_handle);
-}
-
-mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
-    const ChannelHandle& channel) {
-  return builder_->getNamedAttr(
-      "channel_handle", mlir::mhlo::ChannelHandleAttr::get(
-                            context_, channel.handle(), channel.type()));
-}
-
-mlir::NamedAttribute HloFunctionImporter::ConvertUseGlobalDeviceIds() {
-  return builder_->getNamedAttr("use_global_device_ids",
-                                builder_->getUnitAttr());
-}
-
 void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op,
                                            const Shape& shape,
                                            llvm::StringRef attr_name) {
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h
index 5c5a4e3..fa22a6d 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h
@@ -16,21 +16,28 @@
 #ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_
 #define XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_
 
-#include <string>
+#include <cstdint>
 #include <unordered_map>
 
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
 #include "mlir/IR/ValueRange.h"
 #include "xla/comparison_util.h"
 #include "xla/hlo/ir/hlo_input_output_alias_config.h"
@@ -90,30 +97,12 @@
   static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape,
                                llvm::StringRef attr_name);
 
-  // TODO(b/179166199): move this to attribute_importer.h.
-  // Converts XLA instruction source target pairs to MLIR attribute.
-  static mlir::NamedAttribute ConvertSourceTargetPairs(
-      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
-      mlir::Builder* builder);
-
-  // TODO(b/179166199): move this to attribute_importer.h.
-  // Converts replica groups to attribute
-  static mlir::NamedAttribute ConvertReplicaGroups(
-      absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder);
-
   // For mlir::IfOp or mlir::CaseOp, replace the uses of their region's block
   // arguments with 'implicit_operands'. Here | implicit_operands | == sum of
   // the number of arguments in all the regions in IfOp or CaseOp.
   void ReplaceBlockArgumentsWithImplicitOperands(
       mlir::Operation* op, llvm::ArrayRef<mlir::Value> implicit_operands);
 
-  // Create a TupleOp using the results of 'op' if 'type' is a mlir::TupleType.
-  // Otherwise, return 'op'.
-  mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder,
-                                            mlir::Location loc,
-                                            mlir::Operation* op,
-                                            mlir::Type type);
-
   // FlattenTupleType flattens the types in (nested) tuple-type 'type' and
   // stores them in 'flattened_types'.
   static void FlattenTupleType(
@@ -131,23 +120,6 @@
       mlir::OpBuilder* func_builder, mlir::Location loc,
       mlir::ValueRange values, std::optional<int> reserve_size = std::nullopt);
 
-  // CreateTupleValue creates a root TupleOp of (nested) tuple-type 'type' using
-  // the non-tuple-typed values in 'flatten_values'.
-  //
-  // e.g., Given 'flatten_values': [V1, V2, V3] &'type': tuple<T1,tuple<T1,T2>>,
-  //      The function returns %t2 such that:
-  //       %t1 = mhlo.tuple(V2,V3) : (T2,T3) -> tuple<T2,T3>
-  //       %t2 = mhlo.tuple(V1,%t1): (T1,tuple<T2,T3>) -> tuple<T1,tuple<T1,T2>>
-  //
-  // Note: 1. FlattenTupleValue and CreateTupleValue is a pair of functions to
-  //          resp. flatten and create tuples in the exact same order.
-  //       2. `flatten_values`, initially storing the flattened values, will be
-  //          mutated to a 0-length array by the end of function invocation.
-  static mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder,
-                                      mlir::Location loc,
-                                      mlir::ValueRange& flatten_values,
-                                      mlir::Type type);
-
  private:
   HloFunctionImporter(mlir::SymbolTable& symbol_table,
                       std::unordered_map<const HloComputation*,
@@ -221,6 +193,7 @@
   // Returns the Mlir Value for the corresponding HloInstruction.
   absl::StatusOr<mlir::Value> GetMlirValue(const HloInstruction* instruction);
 
+  // TODO(b/179166199): Move attribute converters to attribute_importer.
   // Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
   mlir::NamedAttribute ConvertComparisonDirection(
       ComparisonDirection direction);
@@ -245,43 +218,6 @@
   // padding low and padding high for each of the spatial dimensions.
   mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
 
-  // Converts channel id to attribute
-  mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id);
-
-  // Convert use global device ids flag to attribute
-  mlir::NamedAttribute ConvertUseGlobalDeviceIds();
-
-  // Converts channel handle to attribute
-  mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel);
-
-  // ============
-  // Imports an old-style async start op. E.g. an HLO all-gather-start
-  // instruction is imported as an async-start associated with an all-gather
-  // computation.
-  //
-  // Eventually, old-style async ops (e.g. all-gather-start) and new-style async
-  // ops (i.e. async-start, async-update and async-done) will converge on the
-  // HLO side, so we decided to not introduce new MHLO ops for all-gather-start
-  // and friends.
-  //
-  // In the end, there may be new ops added in the old-style because they're not
-  // compatible with the new-style async semantics, but those should be handled
-  // on their own, rather than this function which "upgrades" ops to the
-  // new-style async API.
-  // ============
-  template <typename SyncOp>
-  absl::StatusOr<mlir::Operation*> ImportOldStyleAsyncStart(
-      llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
-      const llvm::SmallVectorImpl<mlir::Value>& operands, mlir::Location loc,
-      mlir::Type result_type, mlir::OpBuilder* func_builder,
-      std::string func_name, std::function<absl::Status(SyncOp)> mutate_op);
-
-  // Imports an old-style async done op
-  absl::StatusOr<mlir::Operation*> ImportOldStyleAsyncDone(
-      llvm::SmallVectorImpl<mlir::NamedAttribute>& attributes,
-      const llvm::SmallVectorImpl<mlir::Value>& operands, mlir::Location loc,
-      mlir::Type result_type, mlir::OpBuilder* func_builder);
-
   mlir::MLIRContext* context_;
 
   // SymbolTable to which new functions should be inserted.
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc
index e8d81dc..d6dafe0 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc
@@ -15,12 +15,29 @@
 
 #include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
 
+#include "absl/status/statusor.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
 #include "xla/mlir/utils/error_util.h"
+#include "xla/service/llvm_ir/llvm_util.h"
 #include "xla/status_macros.h"
 #include "xla/translate/hlo_to_mhlo/hlo_module_importer.h"
+#include "tsl/platform/errors.h"
 
 namespace xla {
 
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToMlirHlo(
+    mlir::MLIRContext& ctx, xla::HloModuleProto const* hlo_module,
+    bool import_all_computations, bool flatten_computation_args_result) {
+  mlir::OwningOpRef<mlir::ModuleOp> module =
+      llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx));
+  TF_RETURN_IF_ERROR(ConvertHloToMlirHlo(*module, hlo_module,
+                                         import_all_computations,
+                                         flatten_computation_args_result));
+  return module;
+}
+
 absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module,
                                  xla::HloModuleProto const* hlo_module_proto,
                                  bool import_all_computation,
@@ -32,7 +49,7 @@
 }
 
 absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module,
-                                 xla::HloModule* hlo_module,
+                                 const xla::HloModule* hlo_module,
                                  bool import_all_computation,
                                  bool flatten_computation_args_result) {
   mlir::BaseScopedDiagnosticHandler diag_handler(module.getContext());
@@ -41,4 +58,15 @@
       .Import(*hlo_module);
 }
 
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToMlirHlo(
+    mlir::MLIRContext& ctx, const xla::HloModule* hlo_module,
+    bool import_all_computations, bool flatten_computation_args_result) {
+  mlir::OwningOpRef<mlir::ModuleOp> module =
+      llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx));
+  TF_RETURN_IF_ERROR(ConvertHloToMlirHlo(*module, hlo_module,
+                                         import_all_computations,
+                                         flatten_computation_args_result));
+  return module;
+}
+
 }  // namespace xla
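A minimal usage sketch of the new owning-module overload added above. The wrapper name `ImportHloForInspection`, the chosen flag values, and the include list are illustrative assumptions, not part of this change:

#include "absl/status/statusor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"

// Imports an HloModule into a freshly created MLIR module owned by the result.
// The MLIRContext must outlive the returned module.
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportHloForInspection(
    mlir::MLIRContext& context, const xla::HloModule& hlo_module) {
  return xla::ConvertHloToMlirHlo(context, &hlo_module,
                                  /*import_all_computations=*/true,
                                  /*flatten_computation_args_result=*/false);
}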
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h
index 161823a1..775d636 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h
@@ -19,6 +19,10 @@
 #include <stdbool.h>
 
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
 
 namespace mlir {
 class ModuleOp;
@@ -35,6 +39,11 @@
 //
 // If `flatten_computation_args_result` is set to true, flattens all tuple
 // arguments and result of every computation when importing them as func ops.
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToMlirHlo(
+    mlir::MLIRContext& ctx, xla::HloModuleProto const* hlo_module,
+    bool import_all_computations = false,
+    bool flatten_computation_args_result = false);
+
 absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module,
                                  xla::HloModuleProto const* hlo_module,
                                  bool import_all_computations = false,
@@ -47,8 +56,13 @@
 //
 // If `flatten_computation_args_result` is set to true, flattens all tuple
 // arguments and result of every computation when importing them as func ops.
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToMlirHlo(
+    mlir::MLIRContext& ctx, const xla::HloModule* hlo_module,
+    bool import_all_computations = false,
+    bool flatten_computation_args_result = false);
+
 absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module,
-                                 xla::HloModule* hlo_module,
+                                 const xla::HloModule* hlo_module,
                                  bool import_all_computations = false,
                                  bool flatten_computation_args_result = false);
 
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc
index 468c29a..e6004cf 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc
@@ -17,20 +17,34 @@
 
 #include "xla/translate/hlo_to_mhlo/hlo_utils.h"
 
+#include <cassert>
 #include <cstddef>
-#include <type_traits>
+#include <cstdint>
 #include <vector>
 
+#include "absl/status/statusor.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
+#include "xla/layout_util.h"
 #include "xla/literal.h"
 #include "xla/mlir/utils/type_util.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 #include "xla/primitive_util.h"
+#include "xla/shape.h"
 #include "xla/types.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 namespace {
@@ -139,4 +153,46 @@
       vector);
 }
 
+mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc,
+                             mlir::ValueRange& flatten_values,
+                             mlir::Type type) {
+  auto tuple_type = type.dyn_cast<mlir::TupleType>();
+  if (!tuple_type) {
+    assert(!flatten_values.empty());
+    auto retval = flatten_values.front();
+    flatten_values = flatten_values.drop_front();
+    return retval;
+  }
+
+  llvm::SmallVector<mlir::Value> flatten_sub_values;
+  for (auto child_type : tuple_type.getTypes())
+    flatten_sub_values.push_back(
+        CreateTupleValue(func_builder, loc, flatten_values, child_type));
+
+  return func_builder->create<mlir::mhlo::TupleOp>(loc, flatten_sub_values)
+      .getResult();
+}
+
+mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder,
+                                          mlir::Location loc,
+                                          mlir::Operation* op,
+                                          mlir::Type type) {
+  if (!type.isa<mlir::TupleType>()) return op;
+
+  mlir::ValueRange flattened_results_ref(op->getResults());
+  auto result =
+      CreateTupleValue(func_builder, loc, flattened_results_ref, type);
+  auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
+  assert(defining_tuple_op && "builder didn't return the right type");
+  auto tupleOp = defining_tuple_op.getOperation();
+  return tupleOp;
+}
+
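+// Returns the element types of 'type' if it is a TupleType; otherwise returns
+// a range containing just 'type'.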
+mlir::TypeRange Untuple(const mlir::Type& type) {
+  if (llvm::isa<mlir::TupleType>(type)) {
+    return llvm::dyn_cast<mlir::TupleType>(type).getTypes();
+  }
+  return type;
+}
+
 }  // namespace xla
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h
index 81fe60e..dd7f68a 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h
@@ -32,6 +32,10 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 #include "xla/layout.h"
 #include "xla/layout_util.h"
 #include "xla/literal.h"
@@ -169,6 +173,29 @@
   return ConvertTensorShapeToType<TypeT>(shape, builder);
 }
 
+// CreateTupleValue creates a root TupleOp of (nested) tuple-type 'type' using
+// the non-tuple-typed values in 'flatten_values'.
+//
+// e.g., given 'flatten_values': [V1, V2, V3] and 'type': tuple<T1,tuple<T2,T3>>,
+//      the function returns %t2 such that:
+//       %t1 = mhlo.tuple(V2,V3) : (T2,T3) -> tuple<T2,T3>
+//       %t2 = mhlo.tuple(V1,%t1): (T1,tuple<T2,T3>) -> tuple<T1,tuple<T2,T3>>
+//
+// Note: 1. FlattenTupleValue and CreateTupleValue are a pair of functions that
+//          respectively flatten and create tuples in the exact same order.
+//       2. `flatten_values`, initially storing the flattened values, will be
+//          mutated to a 0-length array by the end of the function invocation.
+mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc,
+                             mlir::ValueRange& flatten_values, mlir::Type type);
+
+// Create a TupleOp using the results of 'op' if 'type' is a mlir::TupleType.
+// Otherwise, return 'op'.
+mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder,
+                                          mlir::Location loc,
+                                          mlir::Operation* op, mlir::Type type);
+
+mlir::TypeRange Untuple(const mlir::Type& type);
+
 static std::pair<mlir::Attribute, mlir::ArrayAttr> GetLayoutAttribute(
     mlir::Builder& b, const Shape& shape,
     std::optional<const Layout> maybe_layout = std::nullopt) {
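A small sketch of how the relocated tuple helpers are intended to be used together. The wrapper name `RebuildTuple`, its assert, and the include list are illustrative assumptions rather than code from this change:

#include <cassert>

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "xla/translate/hlo_to_mhlo/hlo_utils.h"

// Rebuilds the (possibly nested) tuple described by 'type' from already
// flattened values, checking that every flattened value is consumed.
mlir::Value RebuildTuple(mlir::OpBuilder* func_builder, mlir::Location loc,
                         llvm::ArrayRef<mlir::Value> flattened,
                         mlir::Type type) {
  mlir::ValueRange remaining(flattened);
  // CreateTupleValue drains 'remaining' front to back while recreating the
  // tuple structure; for a non-tuple 'type' it simply returns the first value.
  mlir::Value root = xla::CreateTupleValue(func_builder, loc, remaining, type);
  assert(remaining.empty() && "'flattened' and 'type' are inconsistent");
  return root;
}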
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD b/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD
index 9c3500c..fd980a0 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD
@@ -11,6 +11,7 @@
         [
             "bool_compare.hlo",
             "case_conditional.hlo",
+            "composite_call.hlo",
             "custom_call.hlo",
             "dynamic_param.hlo",
             "entry_computation_layout.hlo",
@@ -20,11 +21,11 @@
             "if_conditional.hlo",
             "import.hlo",
             "import_async.hlo",
+            "import_async2.hlo",
             "layouts_and_names.hlo",
             "location.hlo",
             "module_attributes.hlo",
             "module_config.hlo",
-            "send_recv.hlo",
             "simple.hlo",
             "spmd_module_sharding.hlo",
             "stacktrace_to_location.hlo",
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo
new file mode 100644
index 0000000..ad3dc70
--- /dev/null
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo
@@ -0,0 +1,186 @@
+// RUN: xla-translate -split-input-file -hlo-text-to-mlir-hlo %s | FileCheck %s
+
+// dictionary-like frontend_attributes
+HloModule composite, entry_computation_layout={()->f32[]}
+
+// CHECK: func.func @main() -> tensor<f32> {
+// CHECK:    %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK:    %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor<i32>}, decomposition = @add.2, version = 1 : i32} : (tensor<f32>) -> tensor<f32>
+// CHECK:    return %1 : tensor<f32>
+// CHECK:  }
+
+// CHECK: func.func private @add.2(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK:    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK:    %1 = mhlo.add %arg0, %0 : tensor<f32>
+// CHECK:    return %1 : tensor<f32>
+// CHECK:  }
+%add.2 (Arg_0.3: f32[]) -> f32[] {
+  %Arg_0.3 = f32[] parameter(0)
+  %constant.4 = f32[] constant(2)
+  ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+}
+
+ENTRY %main.7 () -> f32[] {
+  %constant.1 = f32[] constant(42)
+  ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+}
+
+// -----
+
+// string-like frontend_attributes
+HloModule composite, entry_computation_layout={()->f32[]}
+
+// CHECK: func.func @main() -> tensor<f32> {
+// CHECK:    %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK:    %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor<i32>}, decomposition = @add.2, version = 1 : i32} : (tensor<f32>) -> tensor<f32>
+// CHECK:    return %1 : tensor<f32>
+// CHECK:  }
+
+// CHECK: func.func private @add.2(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK:    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK:    %1 = mhlo.add %arg0, %0 : tensor<f32>
+// CHECK:    return %1 : tensor<f32>
+// CHECK:  }
+%add.2 (Arg_0.3: f32[]) -> f32[] {
+  %Arg_0.3 = f32[] parameter(0)
+  %constant.4 = f32[] constant(2)
+  ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+}
+
+ENTRY %main.7 () -> f32[] {
+  %constant.1 = f32[] constant(42)
+  ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes="{n = 1 : i32, tensor = dense<1> : tensor<i32>}",composite.name="foo.bar",composite.version="1"}
+}
+
+// -----
+
+// zero-output composite
+HloModule composite, entry_computation_layout={()->()}
+
+// CHECK: func.func @main() -> tuple<> {
+// CHECK:   %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK:   %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor<i32>}, decomposition = @return.2, version = 1 : i32, xla_shape = "()"} : (tensor<f32>) -> tuple<>
+// CHECK:   return %1 : tuple<>
+// CHECK: }
+// CHECK: func.func private @return.2(%arg0: tensor<f32>) -> tuple<> {
+// CHECK:   %0 = mhlo.tuple  {xla_shape = "()"} : tuple<>
+// CHECK:   return %0 : tuple<>
+// CHECK: }
+%return.2 (Arg_0.3: f32[]) -> () {
+  %Arg_0.3 = f32[] parameter(0)
+  ROOT %tuple.4 = () tuple()
+}
+
+ENTRY %main.7 () -> () {
+  %constant.1 = f32[] constant(42)
+  ROOT %call.5 = () call(f32[] %constant.1), to_apply=%return.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+}
+
+// -----
+
+// multi-output composite
+HloModule composite, entry_computation_layout={()->(f32[], f32[])}
+
+// CHECK: func.func @main() -> tuple<tensor<f32>, tensor<f32>> {
+// CHECK:   %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK:   %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor<i32>}, decomposition = @add.2, version = 1 : i32, xla_shape = "(f32[], f32[])"} : (tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
+// CHECK:   return %1 : tuple<tensor<f32>, tensor<f32>>
+// CHECK: }
+// CHECK: func.func private @add.2(%arg0: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>> {
+// CHECK:   %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK:   %1 = mhlo.add %arg0, %0 : tensor<f32>
+// CHECK:   %2 = mhlo.tuple %1, %1 {xla_shape = "(f32[], f32[])"} : tuple<tensor<f32>, tensor<f32>>
+// CHECK:   return %2 : tuple<tensor<f32>, tensor<f32>>
+// CHECK: }
+%add.2 (Arg_0.3: f32[]) -> (f32[], f32[]) {
+  %Arg_0.3 = f32[] parameter(0)
+  %constant.4 = f32[] constant(2)
+  %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+  ROOT %tuple.6 = (f32[], f32[]) tuple(f32[] %add.5, f32[] %add.5)
+}
+
+ENTRY %main.9 () -> (f32[], f32[]) {
+  %constant.1 = f32[] constant(42)
+  ROOT %call.7 = (f32[], f32[]) call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+}
+
+// -----
+
+// optional composite attributes
+HloModule composite, entry_computation_layout={()->f32[]}
+
+// CHECK: func.func @main() -> tensor<f32> {
+// CHECK:    %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK:    %1 = mhlo.composite "foo.bar" %0 {decomposition = @add.2, version = 1 : i32} : (tensor<f32>) -> tensor<f32>
+// CHECK:    return %1 : tensor<f32>
+// CHECK:  }
+
+// CHECK: func.func private @add.2(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK:    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK:    %1 = mhlo.add %arg0, %0 : tensor<f32>
+// CHECK:    return %1 : tensor<f32>
+// CHECK:  }
+%add.2 (Arg_0.3: f32[]) -> f32[] {
+  %Arg_0.3 = f32[] parameter(0)
+  %constant.4 = f32[] constant(2)
+  ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+}
+
+ENTRY %main.7 () -> f32[] {
+  %constant.1 = f32[] constant(42)
+  ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="1"}
+}
+
+// -----
+
+// optional composite version
+HloModule composite, entry_computation_layout={()->f32[]}
+
+// CHECK: func.func @main() -> tensor<f32> {
+// CHECK:    %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK:    %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor<i32>}, decomposition = @add.2} : (tensor<f32>) -> tensor<f32>
+// CHECK:    return %1 : tensor<f32>
+// CHECK:  }
+
+// CHECK: func.func private @add.2(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK:    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK:    %1 = mhlo.add %arg0, %0 : tensor<f32>
+// CHECK:    return %1 : tensor<f32>
+// CHECK:  }
+%add.2 (Arg_0.3: f32[]) -> f32[] {
+  %Arg_0.3 = f32[] parameter(0)
+  %constant.4 = f32[] constant(2)
+  ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+}
+
+ENTRY %main.7 () -> f32[] {
+  %constant.1 = f32[] constant(42)
+  ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes="{n = 1 : i32, tensor = dense<1> : tensor<i32>}",composite.name="foo.bar",composite.version="0"}
+}
+
+// -----
+
+// optional composite attributes and version
+HloModule composite, entry_computation_layout={()->f32[]}
+
+// CHECK: func.func @main() -> tensor<f32> {
+// CHECK:    %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK:    %1 = mhlo.composite "foo.bar" %0 {decomposition = @add.2} : (tensor<f32>) -> tensor<f32>
+// CHECK:    return %1 : tensor<f32>
+// CHECK:  }
+
+// CHECK: func.func private @add.2(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK:    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK:    %1 = mhlo.add %arg0, %0 : tensor<f32>
+// CHECK:    return %1 : tensor<f32>
+// CHECK:  }
+%add.2 (Arg_0.3: f32[]) -> f32[] {
+  %Arg_0.3 = f32[] parameter(0)
+  %constant.4 = f32[] constant(2)
+  ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+}
+
+ENTRY %main.7 () -> f32[] {
+  %constant.1 = f32[] constant(42)
+  ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="0"}
+}
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo
index 7dcd16a..4e96330 100644
--- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo
@@ -1,142 +1,162 @@
-// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s
-// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -check-prefix=NO_DEAD_FUNCTION
+// RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations -split-input-file %s -o - | FileCheck %s
 
-// NO_DEAD_FUNCTION-NOT: @test
+// CHECK-LABEL: func.func private @recv_
+// CHECK:  %0:2 = "mhlo.recv"(%arg0) <{channel_handle = #mhlo.channel_handle<handle = 5, type = 3>, is_host_transfer = true}> : (!mhlo.token) -> (tensor<i32>, !mhlo.token)
 
-// CHECK: module @foobar
+// CHECK-LABEL: func.func private @send_
+// CHECK: %0 = "mhlo.send"(%arg0, %arg1) <{channel_handle = #mhlo.channel_handle<handle = 3, type = 2>, is_host_transfer = true}> : (tensor<i32>, !mhlo.token) -> !mhlo.token
+
+// CHECK-LABEL: func.func @main
+// CHECK-LITERAL:      %0 = "mhlo.async_start"(%arg0, %arg1) <{called_computation = @send_, execution_thread = "main"}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"}, mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}", xla_shape = "(s32[], u32[], token[])"} : (tensor<i32>, !mhlo.token) -> !mhlo.async_bundle<tuple<tensor<i32>, !mhlo.token>, !mhlo.token, tensor<ui32>>
+// CHECK-NEXT-LITERAL: %1 = "mhlo.async_done"(%0) {called_computation = @send_, execution_thread = "main", mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"}, mhlo.sharding = "{maximal device=0}", xla_shape = "token[]"} : (!mhlo.async_bundle<tuple<tensor<i32>, !mhlo.token>, !mhlo.token, tensor<ui32>>) -> !mhlo.token
+// CHECK-NEXT-LITERAL: %2 = "mhlo.async_start"(%1) <{called_computation = @recv_, execution_thread = "main"}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"}, mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}", xla_shape = "(s32[], u32[], token[])"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<i32>, !mhlo.token>, tensor<ui32>>
+// CHECK-NEXT-LITERAL: %3:2 = "mhlo.async_done"(%2) {called_computation = @recv_, execution_thread = "main", mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"}, mhlo.sharding = "{{maximal device=0}, {maximal device=0}}"} : (!mhlo.async_bundle<!mhlo.token, tuple<tensor<i32>, !mhlo.token>, tensor<ui32>>) -> (tensor<i32>, !mhlo.token)
 HloModule foobar
 
-// Compiler-generated functions
+ENTRY %async_send_recv_test (arg_0: s32[], arg_1: token[]) -> (s32[], token[]) {
+  %arg_0 = s32[] parameter(0)
+  %arg_1 = token[] parameter(1)
 
-// CHECK:  func private [[RECV_DTD_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} {
-  // CHECK-NEXT: "mhlo.recv"([[TOK]]
-  // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle<handle = 5, type = 1>, is_host_transfer = false}
+  %send.0 = (s32[], u32[], token[]) send(s32[] %arg_0, token[] %arg_1), channel_id=3, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_dtoh_0"}
+  %send-done.1 = token[] send-done((s32[], u32[], token[]) %send.0), channel_id=3, is_host_transfer=true, sharding={maximal device=0}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_dtoh_0"}
 
-// CHECK:  func private [[RECV_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} {
-  // CHECK-NEXT: "mhlo.recv"([[TOK]]
-  // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle<handle = 5, type = 3>, is_host_transfer = true}
+  %recv.2 = (s32[], u32[], token[]) recv(token[] %send-done.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_htod_0"}
+  %recv-done.3 = (s32[], token[]) recv-done((s32[], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_htod_0"}
 
-// CHECK:  func private [[SEND_GENSYM:@.*send.*]]([[INPUT:%.*]]: tensor<128x32xf32>, %arg1: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} {
-  // CHECK-NEXT:  "mhlo.send"([[INPUT]]
-  // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle<handle = 5, type = 2>, is_host_transfer = true}
-
-// CHECK:  func private [[COPY_GENSYM:@.*copy.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
-  // CHECK-NEXT:  mhlo.copy [[INPUT]]
-  // CHECK-SAME: cross_program_prefetch_index
-
-// CHECK: func private [[CP_GENSYM:@.*collective_permute_.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
-  // CHECK-NEXT:  "mhlo.collective_permute"([[INPUT]])
-  // CHECK-SAME{LITERAL}: <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32>
-
-// CHECK:  func private [[AR_GENSYM:@.*all_reduce.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
-  // CHECK-NEXT:  "mhlo.all_reduce"([[INPUT]])
-  // CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
-  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
-  // CHECK-SAME: use_global_device_ids
-    // CHECK: [[BLOCK:^.*]]([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
-    // CHECK: mhlo.add [[LHS]], [[RHS]]
-
-// CHECK:  func private [[AG_GENSYM:@.*all_gather.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} {
-  // CHECK-NEXT:  "mhlo.all_gather"([[INPUT]])
-  // CHECK-SAME: all_gather_dim = 1 : i64
-  // CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
-  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
-  // CHECK-SAME: use_global_device_ids
-
-// CHECK:  func @main(%arg0: tensor<f32>) -> tensor<f32> {
-ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
-  ROOT %Arg_0.1 = f32[] parameter(0)
+  %get-tuple-element.4 = s32[] get-tuple-element((s32[], token[]) %recv-done.3), index=0, sharding={maximal device=0}
+  %get-tuple-element.5 = token[] get-tuple-element((s32[], token[]) %recv-done.3), index=1, sharding={maximal device=0}
+  ROOT %tuple.6 = (s32[], token[]) tuple(s32[] %get-tuple-element.4, token[] %get-tuple-element.5)
 }
 
-// Tests
+// -----
 
-// CHECK:  func private @test_all_gather_start
-// CHECK-SAME:  ([[INPUT:%.*]]: tensor<128x32xf32>)
-%test_all_gather_start {
-  input = f32[128,32] parameter(0)
-  // CHECK-NEXT:  [[AG_START:%.*]] = "mhlo.async_start"([[INPUT]])
-  // CHECK-SAME: called_computation = [[AG_GENSYM]], execution_thread = "main"
-  ag-start = (f32[128,32], f32[128,128]) all-gather-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true
-  // CHECK-NEXT:  "mhlo.async_done"([[AG_START]])
-  ROOT ag-done = f32[128,128] all-gather-done(ag-start)
+HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,128]{1,0}}
+
+// CHECK-LABEL: func.func private @all_gather_
+// CHECK: mhlo.all_gather
+
+// CHECK-LABEL: func.func @main
+// CHECK: mhlo.async_start{{.*}}called_computation = @all_gather_
+// CHECK: mhlo.async_done
+
+ENTRY %async_all_gather_test (Arg_0.1: f32[128,32]) -> f32[128,128] {
+  %Arg_0.1 = f32[128,32] parameter(0)
+  %all-gather-start.2 = f32[128,128] all-gather-start(f32[128,32] %Arg_0.1), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7}}, constrain_layout=true, dimensions={1}, use_global_device_ids=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=16}
+  ROOT %all-gather-done.3 = f32[128,128] all-gather-done(f32[128,128] %all-gather-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=17}
 }
 
-add {
-  lhs = f32[] parameter(0)
-  rhs = f32[] parameter(1)
-  ROOT add = f32[] add(lhs, rhs)
+// -----
+
+HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}}
+
+%region_1.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] {
+  %Arg_0.3 = f32[] parameter(0)
+  %Arg_1.4 = f32[] parameter(1)
+  ROOT %maximum.5 = f32[] maximum(f32[] %Arg_0.3, f32[] %Arg_1.4), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=7}
 }
 
-// CHECK:  func private @test_all_reduce_start
-// CHECK-SAME:  ([[INPUT:%.*]]: tensor<128x32xf32>)
-%test_all_reduce_start {
-  input = f32[128,32] parameter(0)
-  // CHECK-NEXT:  [[AR_START:%.*]] = "mhlo.async_start"([[INPUT]])
-  // CHECK-SAME: called_computation = [[AR_GENSYM]], execution_thread = "main"
-  ar-start = (f32[128,32], f32[128,32]) all-reduce-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, to_apply=add, use_global_device_ids=true
-  // CHECK-NEXT:  "mhlo.async_done"([[AR_START]])
-  ROOT ar-done = f32[128,32] all-reduce-done(ar-start)
+// CHECK-LABEL: func.func private @all_reduce_
+// CHECK: mhlo.all_reduce
+
+// CHECK-LABEL: func.func @main
+// CHECK: mhlo.async_start{{.*}}called_computation = @all_reduce_
+// CHECK: mhlo.async_done
+ENTRY %async_all_reduce_test (Arg_0.1: f32[10]) -> f32[10] {
+  %Arg_0.1 = f32[10] parameter(0)
+  %all-reduce-start.6 = f32[10] all-reduce-start(f32[10] %Arg_0.1), channel_id=5, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=%region_1.2, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=22}
+  ROOT %all-reduce-done.7 = f32[10] all-reduce-done(f32[10] %all-reduce-start.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=23}
 }
 
-// CHECK:  func private @test_collective_permute
-// CHECK-SAME:  ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32>
-%test_collective_permute (input: f32[128,32]) -> f32[128,32] {
-  %input = f32[128,32]{1,0} parameter(0)
-  // CHECK-NEXT:  [[CP_START:%.*]] = "mhlo.async_start"([[ARG]])
-  // CHECK-SAME: called_computation = [[CP_GENSYM]], execution_thread = "main"
-  %cp-start = (f32[128,32]{1,0}, f32[128,32]) collective-permute-start(%input), source_target_pairs={{0,1},{1,2},{2,3}}
-  // CHECK-NEXT:  "mhlo.async_done"([[CP_START]])
-  ROOT %cp-done = f32[128,32]{1,0} collective-permute-done(%cp-start)
+// -----
+
+HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}}
+
+// CHECK-LABEL: func.func private @collective_permute_
+// CHECK: mhlo.collective_permute
+
+// CHECK-LABEL: func.func @main
+// CHECK: mhlo.async_start{{.*}}called_computation = @collective_permute_
+// CHECK: mhlo.async_done
+ENTRY %async_collective_permute_test (Arg_0.1: f32[128,32]) -> f32[128,32] {
+  %Arg_0.1 = f32[128,32] parameter(0)
+  %collective-permute-start.2 = f32[128,32] collective-permute-start(f32[128,32] %Arg_0.1), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=13}
+  ROOT %collective-permute-done.3 = f32[128,32] collective-permute-done(f32[128,32] %collective-permute-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=14}
 }
 
-// CHECK:  func private @test_copy_start
-// CHECK-SAME:  ([[INPUT:%.*]]: tensor<128x32xf32>)
-%test_copy_start {
-  input = f32[128,32] parameter(0)
-  // CHECK-NEXT:  [[COPY_START:%.*]] = "mhlo.async_start"([[INPUT]])
-  // CHECK-SAME: called_computation = [[COPY_GENSYM]], execution_thread = "main"
-  copy-start = (f32[128,32], f32[128,32], u32[]) copy-start(input), cross_program_prefetch_index=0
-  // CHECK-NEXT: "mhlo.async_done"([[COPY_START]])
-  ROOT copy-done = f32[128,32] copy-done(copy-start)
+// -----
+
+HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}}
+
+ENTRY %async_copy_test (Arg_0.1: f32[128,32]) -> f32[128,32] {
+  %Arg_0.1 = f32[128,32] parameter(0)
+  %copy-start.2 = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %Arg_0.1), cross_program_prefetch_index=0, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=10}
+  ROOT %copy-done.3 = f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %copy-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=11}
 }
 
-// CHECK:  func private @test_send
-// CHECK-SAME:  ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token)
-%test_send_start {
-  input = f32[128,32] parameter(0)
-  tok = token[] parameter(1)
-  // CHECK-NEXT:  [[SEND_START:%.*]] = "mhlo.async_start"([[INPUT]], [[TOK]])
-  // CHECK-SAME: called_computation = [[SEND_GENSYM]], execution_thread = "main"
-  // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle<tuple<tensor<128x32xf32>, !mhlo.token>, !mhlo.token, tensor<ui32>>
-  send-start = (f32[128,32], u32[], token[]) send(input, tok), channel_id=5, is_host_transfer=true
-  // CHECK-NEXT: "mhlo.async_done"([[SEND_START]])
-  ROOT send-done = token[] send-done(send-start), channel_id=5, is_host_transfer=true
+// -----
+
+HloModule main, entry_computation_layout={(token[])->(s32[3,4]{1,0}, token[])}
+
+ENTRY %async_recv_test_tuple (Arg_0.1: token[]) -> (s32[3,4], token[]) {
+  %Arg_0.1 = token[] parameter(0)
+  %recv.2 = (s32[3,4], u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=16}
+  %recv-done.3 = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17}
+  %get-tuple-element.4 = s32[3,4] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=0, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17}
+  %get-tuple-element.5 = token[] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=1, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17}
+  ROOT %tuple.6 = (s32[3,4], token[]) tuple(s32[3,4] %get-tuple-element.4, token[] %get-tuple-element.5)
 }
 
-// CHECK:  func private @test_recv
-// CHECK-SAME:  ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token)
-%test_recv_start {
-  input = f32[128,32] parameter(0)
-  tok = token[] parameter(1)
-  // CHECK-NEXT:  [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]])
-  // CHECK-SAME: called_computation = [[RECV_GENSYM]], execution_thread = "main"
-  // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<128x32xf32>, !mhlo.token>, tensor<ui32>>
-  recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5, is_host_transfer=true
-  // CHECK-NEXT: "mhlo.async_done"([[RECV_START]])
-  recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5, is_host_transfer=true
-  ROOT gte = get-tuple-element(recv-done), index=0
+// -----
+
+HloModule main, entry_computation_layout={(s32[3,4]{1,0}, token[])->token[]}
+
+ENTRY %async_send_test (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] {
+  %Arg_0.1 = s32[3,4] parameter(0)
+  %Arg_1.2 = token[] parameter(1)
+  %send.3 = (s32[3,4], u32[], token[]) send(s32[3,4] %Arg_0.1, token[] %Arg_1.2), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=16}
+  ROOT %send-done.4 = token[] send-done((s32[3,4], u32[], token[]) %send.3), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=17}
 }
 
-// CHECK:  func private @test_recv_dtd
-// CHECK-SAME:  ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token)
-%test_recv_dtd_start {
-  input = f32[128,32] parameter(0)
-  tok = token[] parameter(1)
-  // CHECK-NEXT:  [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]])
-  // CHECK-SAME: called_computation = [[RECV_DTD_GENSYM]], execution_thread = "main"
-  // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<128x32xf32>, !mhlo.token>, tensor<ui32>>
-  recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5
-  // CHECK-NEXT: "mhlo.async_done"([[RECV_START]])
-  recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5
-  ROOT gte = get-tuple-element(recv-done), index=0
-}
+
+// BROKEN: b/TODO: Async custom calls?
+
+// HloModule main, entry_computation_layout={(f32[10]{0})->(f32[20]{0})}
+
+// ENTRY %async_custom_call_test2 (Arg_0.1: f32[10]) -> (f32[20]) {
+//   %Arg_0.1 = f32[10] parameter(0)
+//   %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="bar", metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=21}
+//   %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=22}
+//   ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=23}
+// }
+
+// HloModule main, entry_computation_layout={(f32[10]{0})->(f32[20]{0})}
+
+// ENTRY %async_custom_call_test (Arg_0.1: f32[10]) -> (f32[20]) {
+//   %Arg_0.1 = f32[10] parameter(0)
+//   %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="foo", metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=16}
+//   %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=18}
+//   ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=20}
+// }
+
+
+///////////
+
+// BROKEN: b/TODO: Empty arg send/recv don't roundtrip
+
+// HloModule main, entry_computation_layout={(token[])->token[]}
+
+// ENTRY %async_send_test_empty (Arg_0.1: token[]) -> token[] {
+//   %tuple.2 = () tuple(), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15}
+//   %Arg_0.1 = token[] parameter(0)
+//   %send.3 = ((), u32[], token[]) send(() %tuple.2, token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15}
+//   ROOT %send-done.4 = token[] send-done(((), u32[], token[]) %send.3), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=16}
+// }
+
+// HloModule main, entry_computation_layout={(token[])->((), token[])}
+
+// ENTRY %async_recv_test (Arg_0.1: token[]) -> ((), token[]) {
+//   %Arg_0.1 = token[] parameter(0)
+//   %recv.2 = ((), u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=17}
+//   ROOT %recv-done.3 = ((), token[]) recv-done(((), u32[], token[]) %recv.2), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=18}
+// }
+
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo
new file mode 100644
index 0000000..7493c95
--- /dev/null
+++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo
@@ -0,0 +1,146 @@
+// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s
+// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -check-prefix=NO_DEAD_FUNCTION
+
+// It would be great to consolidate this test with `import_async.hlo`, but
+// this test is very fragile and doesn't run properly in `-split-input-file`
+// mode.
+
+// NO_DEAD_FUNCTION-NOT: @test
+
+// CHECK: module @foobar
+HloModule foobar
+
+// Compiler-generated functions
+
+// CHECK:  func private [[RECV_DTD_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} {
+  // CHECK-NEXT: "mhlo.recv"([[TOK]]
+  // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle<handle = 5, type = 1>, is_host_transfer = false}
+
+// CHECK:  func private [[RECV_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} {
+  // CHECK-NEXT: "mhlo.recv"([[TOK]]
+  // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle<handle = 5, type = 3>, is_host_transfer = true}
+
+// CHECK:  func private [[SEND_GENSYM:@.*send.*]]([[INPUT:%.*]]: tensor<128x32xf32>, %arg1: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} {
+  // CHECK-NEXT:  "mhlo.send"([[INPUT]]
+  // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle<handle = 5, type = 2>, is_host_transfer = true}
+
+// CHECK:  func private [[COPY_GENSYM:@.*copy.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
+  // CHECK-NEXT:  mhlo.copy [[INPUT]]
+  // CHECK-SAME: cross_program_prefetch_index
+
+// CHECK: func private [[CP_GENSYM:@.*collective_permute_.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
+  // CHECK-NEXT:  "mhlo.collective_permute"([[INPUT]])
+  // CHECK-SAME{LITERAL}: <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32>
+
+// CHECK:  func private [[AR_GENSYM:@.*all_reduce.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
+  // CHECK-NEXT:  "mhlo.all_reduce"([[INPUT]])
+  // CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
+  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+  // CHECK-SAME: use_global_device_ids
+    // CHECK: [[BLOCK:^.*]]([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
+    // CHECK: mhlo.add [[LHS]], [[RHS]]
+
+// CHECK:  func private [[AG_GENSYM:@.*all_gather.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} {
+  // CHECK-NEXT:  "mhlo.all_gather"([[INPUT]])
+  // CHECK-SAME: all_gather_dim = 1 : i64
+  // CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
+  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+  // CHECK-SAME: use_global_device_ids
+
+// CHECK:  func @main(%arg0: tensor<f32>) -> tensor<f32> {
+ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
+  ROOT %Arg_0.1 = f32[] parameter(0)
+}
+
+// Tests
+
+// CHECK:  func private @test_all_gather_start
+// CHECK-SAME:  ([[INPUT:%.*]]: tensor<128x32xf32>)
+%test_all_gather_start {
+  input = f32[128,32] parameter(0)
+  // CHECK-NEXT:  [[AG_START:%.*]] = "mhlo.async_start"([[INPUT]])
+  // CHECK-SAME: called_computation = [[AG_GENSYM]], execution_thread = "main"
+  ag-start = (f32[128,32], f32[128,128]) all-gather-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true
+  // CHECK-NEXT:  "mhlo.async_done"([[AG_START]])
+  ROOT ag-done = f32[128,128] all-gather-done(ag-start)
+}
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+// CHECK:  func private @test_all_reduce_start
+// CHECK-SAME:  ([[INPUT:%.*]]: tensor<128x32xf32>)
+%test_all_reduce_start {
+  input = f32[128,32] parameter(0)
+  // CHECK-NEXT:  [[AR_START:%.*]] = "mhlo.async_start"([[INPUT]])
+  // CHECK-SAME: called_computation = [[AR_GENSYM]], execution_thread = "main"
+  ar-start = (f32[128,32], f32[128,32]) all-reduce-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, to_apply=add, use_global_device_ids=true
+  // CHECK-NEXT:  "mhlo.async_done"([[AR_START]])
+  ROOT ar-done = f32[128,32] all-reduce-done(ar-start)
+}
+
+// CHECK:  func private @test_collective_permute
+// CHECK-SAME:  ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32>
+%test_collective_permute (input: f32[128,32]) -> f32[128,32] {
+  %input = f32[128,32]{1,0} parameter(0)
+  // CHECK-NEXT:  [[CP_START:%.*]] = "mhlo.async_start"([[ARG]])
+  // CHECK-SAME: called_computation = [[CP_GENSYM]], execution_thread = "main"
+  %cp-start = (f32[128,32]{1,0}, f32[128,32]) collective-permute-start(%input), source_target_pairs={{0,1},{1,2},{2,3}}
+  // CHECK-NEXT:  "mhlo.async_done"([[CP_START]])
+  ROOT %cp-done = f32[128,32]{1,0} collective-permute-done(%cp-start)
+}
+
+// CHECK:  func private @test_copy_start
+// CHECK-SAME:  ([[INPUT:%.*]]: tensor<128x32xf32>)
+%test_copy_start {
+  input = f32[128,32] parameter(0)
+  // CHECK-NEXT:  [[COPY_START:%.*]] = "mhlo.async_start"([[INPUT]])
+  // CHECK-SAME: called_computation = [[COPY_GENSYM]], execution_thread = "main"
+  copy-start = (f32[128,32], f32[128,32], u32[]) copy-start(input), cross_program_prefetch_index=0
+  // CHECK-NEXT: "mhlo.async_done"([[COPY_START]])
+  ROOT copy-done = f32[128,32] copy-done(copy-start)
+}
+
+// CHECK:  func private @test_send
+// CHECK-SAME:  ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token)
+%test_send_start {
+  input = f32[128,32] parameter(0)
+  tok = token[] parameter(1)
+  // CHECK-NEXT:  [[SEND_START:%.*]] = "mhlo.async_start"([[INPUT]], [[TOK]])
+  // CHECK-SAME: called_computation = [[SEND_GENSYM]], execution_thread = "main"
+  // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle<tuple<tensor<128x32xf32>, !mhlo.token>, !mhlo.token, tensor<ui32>>
+  send-start = (f32[128,32], u32[], token[]) send(input, tok), channel_id=5, is_host_transfer=true
+  // CHECK-NEXT: "mhlo.async_done"([[SEND_START]])
+  ROOT send-done = token[] send-done(send-start), channel_id=5, is_host_transfer=true
+}
+
+// CHECK:  func private @test_recv
+// CHECK-SAME:  ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token)
+%test_recv_start {
+  input = f32[128,32] parameter(0)
+  tok = token[] parameter(1)
+  // CHECK-NEXT:  [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]])
+  // CHECK-SAME: called_computation = [[RECV_GENSYM]], execution_thread = "main"
+  // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<128x32xf32>, !mhlo.token>, tensor<ui32>>
+  recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5, is_host_transfer=true
+  // CHECK-NEXT: "mhlo.async_done"([[RECV_START]])
+  recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5, is_host_transfer=true
+  ROOT gte = get-tuple-element(recv-done), index=0
+}
+
+// CHECK:  func private @test_recv_dtd
+// CHECK-SAME:  ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token)
+%test_recv_dtd_start {
+  input = f32[128,32] parameter(0)
+  tok = token[] parameter(1)
+  // CHECK-NEXT:  [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]])
+  // CHECK-SAME: called_computation = [[RECV_DTD_GENSYM]], execution_thread = "main"
+  // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<128x32xf32>, !mhlo.token>, tensor<ui32>>
+  recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5
+  // CHECK-NEXT: "mhlo.async_done"([[RECV_START]])
+  recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5
+  ROOT gte = get-tuple-element(recv-done), index=0
+}
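
For readers running the new test outside the lit harness: each RUN line above substitutes `%s` with the test file itself, so the two invocations expand roughly as follows (the standalone filename is illustrative; in-tree the test runs through the usual lit/FileCheck driver):

    xla-translate --print-sugar=false -hlo-text-to-mlir-hlo -hlo-import-all-computations \
        import_async2.hlo -o - | FileCheck import_async2.hlo
    xla-translate --print-sugar=false -hlo-text-to-mlir-hlo \
        import_async2.hlo -o - | FileCheck import_async2.hlo -check-prefix=NO_DEAD_FUNCTION

The second invocation omits `-hlo-import-all-computations`, which is what the `NO_DEAD_FUNCTION-NOT: @test` line relies on: computations not reachable from the entry computation are not imported, so none of the `%test_*` computations should show up as functions.
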
diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/send_recv.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/send_recv.hlo
deleted file mode 100644
index ef40699..0000000
--- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/send_recv.hlo
+++ /dev/null
@@ -1,55 +0,0 @@
-// RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s
-
-HloModule foo
-
-// CHECK:      func private @[[RECV_FUNC:[^(]*]]
-// CHECK:        mhlo.recv
-// CHECK-SAME:     channel_handle = #mhlo.channel_handle<handle = 5, type = 3>
-// CHECK-NOT:      mhlo.sharding
-
-// CHECK:      func private @[[SEND_FUNC:[^(]*]]
-// CHECK:        mhlo.send
-// CHECK-SAME:     channel_handle = #mhlo.channel_handle<handle = 3, type = 2>
-
-// CHECK:      func @main
-// CHECK:        mhlo.async_start
-// CHECK-SAME:     called_computation = @[[SEND_FUNC]]
-// CHECK-SAME:     mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"}
-// CHECK-SAME:     mhlo.sharding = "{
-// CHECK-SAME:       {maximal device=0}, {maximal device=0}, {maximal device=0}
-// CHECK-SAME:     }"
-// CHECK-SAME:     (tensor<i32>, !mhlo.token) -> !mhlo.async_bundle<tuple<tensor<i32>, !mhlo.token>, !mhlo.token, tensor<ui32>>
-// CHECK:        mhlo.async_done
-// CHECK-SAME:     called_computation = @[[SEND_FUNC]]
-// CHECK-SAME:     mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"}
-// CHECK-SAME:     mhlo.sharding = "{maximal device=0}"
-// CHECK-SAME:     (!mhlo.async_bundle<tuple<tensor<i32>, !mhlo.token>, !mhlo.token, tensor<ui32>>) -> !mhlo.token
-// CHECK:        mhlo.async_start
-// CHECK-SAME:     called_computation = @[[RECV_FUNC]]
-// CHECK-SAME:     mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"}
-// CHECK-SAME:     mhlo.sharding = "{
-// CHECK-SAME:       {maximal device=0}, {maximal device=0}, {maximal device=0}
-// CHECK-SAME:     }"
-// CHECK-SAME:     (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<i32>, !mhlo.token>, tensor<ui32>>
-// CHECK:        mhlo.async_done
-// CHECK-SAME:     called_computation = @[[RECV_FUNC]]
-// CHECK-SAME:     mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"}
-// CHECK-SAME:     mhlo.sharding = "{
-// CHECK-SAME:       {maximal device=0}, {maximal device=0}
-// CHECK-SAME:     }"
-// CHECK-SAME:     (!mhlo.async_bundle<!mhlo.token, tuple<tensor<i32>, !mhlo.token>, tensor<ui32>>) -> (tensor<i32>, !mhlo.token)
-
-ENTRY %foo (arg_0: s32[], arg_1: token[]) -> (s32[], token[]) {
-  %arg_0 = s32[] parameter(0)
-  %arg_1 = token[] parameter(1)
-
-  %send.0 = (s32[], u32[], token[]) send(s32[] %arg_0, token[] %arg_1), channel_id=3, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_dtoh_0"}
-  %send-done.1 = token[] send-done((s32[], u32[], token[]) %send.0), channel_id=3, is_host_transfer=true, sharding={maximal device=0}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_dtoh_0"}
-
-  %recv.2 = (s32[], u32[], token[]) recv(token[] %send-done.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_htod_0"}
-  %recv-done.3 = (s32[], token[]) recv-done((s32[], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_htod_0"}
-
-  %get-tuple-element.4 = s32[] get-tuple-element((s32[], token[]) %recv-done.3), index=0, sharding={maximal device=0}
-  %get-tuple-element.5 = token[] get-tuple-element((s32[], token[]) %recv-done.3), index=1, sharding={maximal device=0}
-  ROOT %tuple.6 = (s32[], token[]) tuple(s32[] %get-tuple-element.4, token[] %get-tuple-element.5)
-}
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD
index 1047e8a..40e05c8 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD
@@ -99,6 +99,7 @@
         ":type_to_shape",
         "//xla:array",
         "//xla:comparison_util",
+        "//xla:debug_options_flags",
         "//xla:literal",
         "//xla:literal_util",
         "//xla:shape_util",
@@ -106,6 +107,7 @@
         "//xla:types",
         "//xla:xla_data_proto_cc",
         "//xla/client:xla_builder",
+        "//xla/client:xla_computation",
         "//xla/client/lib:approx_topk",
         "//xla/client/lib:approx_topk_shape",
         "//xla/client/lib:matrix",
@@ -120,8 +122,11 @@
         "//xla/service:hlo_parser",
         "//xla/service:hlo_proto_cc",
         "//xla/service/gpu:backend_configs_cc",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithDialect",
@@ -133,8 +138,11 @@
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TransformUtils",
+        "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:ml_dtypes",
         "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:types",
+        "@stablehlo//:base",
         "@stablehlo//:stablehlo_ops",
     ],
 )
@@ -177,18 +185,26 @@
     deps = [
         ":mlir_hlo_to_hlo",
         ":type_to_shape",
+        "//xla:debug_options_flags",
+        "//xla:shape_util",
+        "//xla/client:xla_builder",
+        "//xla/client:xla_computation",
         "//xla/hlo/ir:hlo",
         "//xla/mlir_hlo:hlo_dialect_registration",
         "//xla/service:hlo_module_config",
         "//xla/service:hlo_proto_cc",
         "//xla/service:hlo_proto_util",
         "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Support",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h
index 2c85a82..2ecd4e3 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h
@@ -19,8 +19,9 @@
 #define XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_
 
 #include <functional>
-#include <vector>
+#include <optional>
 
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "xla/client/xla_builder.h"
 #include "xla/hlo/ir/hlo_sharding.h"
@@ -30,10 +31,10 @@
 namespace mlir {
 
 // XLA Layout preferences. Currently, when it comes to TPU, there are two
-// primary layout choices for any XLA argumetns (parameter or resource): (1)
+// primary layout choices for any XLA arguments (parameter or resource): (1)
 // CompactChunkPadded and (2) Linear. CompactChunkPadded is the native TPU
 // layout while Linear is native host (CPU) layout.
-// This enum allows the caller of XLA to progogate layout preference to the XLA
+// This enum allows the caller of XLA to propagate layout preference to the XLA
 // compiler.
 //   kNoPreference: the generic layout where the XLA compiler has the freedom
 //                  to assign any layout.
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
index 90eb1a9..287b455 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
@@ -19,26 +19,25 @@
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
-#include <iterator>
 #include <memory>
 #include <optional>
 #include <string>
 #include <utility>
 #include <vector>
 
+#include "absl/log/check.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MemoryBuffer.h"
-#include "llvm/Support/SMLoc.h"
-#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -46,9 +45,12 @@
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
@@ -56,25 +58,26 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/UseDefLists.h"
 #include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
-#include "mlir/Support/DebugStringHelper.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/RegionUtils.h"
-#include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/Base.h"
 #include "xla/array.h"
 #include "xla/client/lib/approx_topk.h"
 #include "xla/client/lib/approx_topk_shape.h"
-#include "xla/client/lib/matrix.h"
-#include "xla/client/lib/quantize.h"
+#include "xla/client/lib/matrix.h"  // IWYU pragma: keep
 #include "xla/client/lib/slicing.h"
 #include "xla/client/xla_builder.h"
+#include "xla/client/xla_computation.h"
 #include "xla/comparison_util.h"
+#include "xla/debug_options_flags.h"
 #include "xla/hlo/ir/dynamic_parameter_binding.h"
-#include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/mlir/utils/error_util.h"
@@ -88,16 +91,16 @@
 #include "xla/service/hlo_parser.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
-#include "xla/status_macros.h"
 #include "xla/translate/mhlo_to_hlo/attribute_exporter.h"
+#include "xla/translate/mhlo_to_hlo/layout_util.h"
 #include "xla/translate/mhlo_to_hlo/location_exporter.h"
 #include "xla/translate/mhlo_to_hlo/module_config_exporter.h"
 #include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h"
 #include "xla/translate/mhlo_to_hlo/type_to_shape.h"
-#include "xla/types.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/platform/ml_dtypes.h"
+#include "tsl/platform/errors.h"
 #include "tsl/platform/statusor.h"
+#include "tsl/platform/types.h"
 
 using ::int64_t;
 using ::tsl::int16;
@@ -109,18 +112,54 @@
 using ::tsl::uint64;
 using ::tsl::uint8;
 
-constexpr char kShapeIndicesAttr[] = "shape_indices";
-constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
-constexpr char kShardingAttr[] = "mhlo.sharding";
-constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
-constexpr char kReplicationAttr[] = "mhlo.is_same_data_across_replicas";
-constexpr char kParameterReplicationAttr[] = "mhlo.parameter_replication";
-constexpr char kLiteralAttr[] = "mhlo.literal";
+// Boolean attribute.
+constexpr char kJaxBufferDonor[] = "jax.buffer_donor";
 
+// BitcastOp lowering strings.
+constexpr char kResultLayout[] = "result_layout";
+constexpr char kSourceLayout[] = "source_layout";
+
+// CustomCallOp lowering strings.
+constexpr char kAggregateToTopk[] = "aggregate_to_topk";
+constexpr char kApiVersion[] = "api_version";
+constexpr char kApproxTopK[] = "ApproxTopK";
+constexpr char kBackendConfig[] = "backend_config";
+constexpr char kCallTargetName[] = "call_target_name";
+constexpr char kCalledComputations[] = "called_computations";
+constexpr char kHasSideEffect[] = "has_side_effect";
+constexpr char kIsFallback[] = "is_fallback";
+constexpr char kRecallTarget[] = "recall_target";
+constexpr char kReductionDim[] = "reduction_dim";
+constexpr char kReductionInputSizeOverride[] = "reduction_input_size_override";
+constexpr char kTopK[] = "top_k";
+
+// MHLO attributes. Module level attributes require namespacing.
+constexpr char kMhloCrossProgramPrefetches[] = "mhlo.cross_program_prefetches";
+constexpr char kMhloFrontendAttributes[] = "mhlo.frontend_attributes";
+constexpr char kMhloInputOutputAlias[] = "mhlo.input_output_alias";
+constexpr char kMhloIsDynamic[] = "mhlo.is_dynamic";
+constexpr char kMhloLiteral[] = "mhlo.literal";
+constexpr char kMhloParameterReplication[] = "mhlo.parameter_replication";
+constexpr char kMhloReplication[] = "mhlo.is_same_data_across_replicas";
+constexpr char kMhloSharding[] = "mhlo.sharding";
+constexpr char kMhloSpmdOutputSharding[] = "mhlo.spmd_output_sharding";
+constexpr char kMhloSpmdParametersShardings[] =
+    "mhlo.spmd_parameters_shardings";
+constexpr char kMhloUseAutoSpmdPartitioning[] =
+    "mhlo.use_auto_spmd_partitioning";
+
+// Miscellaneous string literals.
+constexpr char kArgEmptyTuple[] = "arg_empty_tuple";
+constexpr char kArgPrefix[] = "Arg_";
+constexpr char kArgTuple[] = "arg_tuple";
+constexpr char kDefaultLayoutAttrName[] = "xla_shape";
+constexpr char kExecutionThread[] = "execution_thread";
 // Array attribute. Same shape as infeed result, but contains a
 // minor_to_major array for every tensor.
-constexpr char kLayoutAttr[] = "layout";
-constexpr char kDefaultLayoutAttrName[] = "xla_shape";
+constexpr char kLayout[] = "layout";
+constexpr char kMain[] = "main";
+constexpr char kRegionPrefix[] = "region_";
+constexpr char kTfAliasingOutput[] = "tf.aliasing_output";
 
 // Passes through everything except for unique_ptr, on which it calls get().
 // This exists to allow the generated code to call XLA functions that take a raw
@@ -585,7 +624,7 @@
 // returns std::nullopt.
 static std::optional<xla::OpSharding> CreateOpShardingFromAttribute(
     mlir::Operation* op) {
-  auto shardingAttr = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
+  auto shardingAttr = op->getAttrOfType<mlir::StringAttr>(kMhloSharding);
   if (!shardingAttr) return std::nullopt;
   return xla::ConvertSharding(shardingAttr.getValue());
 }
@@ -606,7 +645,7 @@
     mlir::Operation* op) {
   xla::FrontendAttributes frontend_attributes;
   auto frontend_attributes_dict =
-      op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr);
+      op->getAttrOfType<mlir::DictionaryAttr>(kMhloFrontendAttributes);
   if (!frontend_attributes_dict) return frontend_attributes;
   ConstructFrontendAttributesFromAttribute(frontend_attributes_dict,
                                            frontend_attributes);
@@ -619,7 +658,7 @@
   fe_attrs->resize(function.getNumArguments(), std::nullopt);
   for (int i = 0, end = function.getNumArguments(); i < end; ++i)
     if (auto fe_attr = function.getArgAttrOfType<mlir::DictionaryAttr>(
-            i, kFrontendAttributesAttr)) {
+            i, kMhloFrontendAttributes)) {
       xla::FrontendAttributes frontend_attributes;
       ConstructFrontendAttributesFromAttribute(fe_attr, frontend_attributes);
       (*fe_attrs)[i] = frontend_attributes;
@@ -643,14 +682,14 @@
                         std::optional<xla::OpSharding>());
   for (int i = 0, end = function.getNumArguments(); i < end; ++i)
     if (auto sharding =
-            function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
+            function.getArgAttrOfType<mlir::StringAttr>(i, kMhloSharding))
       (*arg_shardings)[i] = xla::ConvertSharding(sharding.getValue());
 
   ret_shardings->resize(function.getNumResults(),
                         std::optional<xla::OpSharding>());
   for (int i = 0, end = function.getNumResults(); i < end; ++i)
     if (auto sharding =
-            function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
+            function.getResultAttrOfType<mlir::StringAttr>(i, kMhloSharding))
       (*ret_shardings)[i] = xla::ConvertSharding(sharding.getValue());
 }
 
@@ -753,7 +792,7 @@
   //
   // TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
   LogicalResult Run() {
-    auto main = module_.lookupSymbol<mlir::func::FuncOp>("main");
+    auto main = module_.lookupSymbol<mlir::func::FuncOp>(kMain);
     if (!main)
       return module_.emitError(
           "conversion requires module with `main` function");
@@ -771,8 +810,8 @@
   // Lower a `mlir::Region` to a `XlaComputation`
   LogicalResult LowerRegionAsComputation(
       mlir::Region* region, xla::XlaComputation* func,
-      std::optional<llvm::ArrayRef<mlir::Value>> implicit_operands =
-          std::nullopt,
+      llvm::ArrayRef<mlir::Value> implicit_operands = {},
+      llvm::ArrayRef<mlir::Value> implicit_results = {},
       bool ensure_single_arg = false,
       llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings = {},
       llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings = {});
@@ -786,11 +825,11 @@
       llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
       llvm::ArrayRef<std::optional<xla::FrontendAttributes>> fe_attrs,
       xla::XlaComputation* result,
-      std::optional<llvm::ArrayRef<mlir::Value>> implicit_operands =
-          std::nullopt);
+      llvm::ArrayRef<mlir::Value> implicit_operands = {},
+      llvm::ArrayRef<mlir::Value> implicit_results = {});
 
   ::xla::HloModuleProto ConsumeMainProto() {
-    auto main = module_.lookupSymbol<mlir::func::FuncOp>("main");
+    auto main = module_.lookupSymbol<mlir::func::FuncOp>(kMain);
     // This is an invariant check as Run returns failure if there is no main
     // function and so the main proto shouldn't be consumed in that case.
     CHECK(main) << "requires module to have main function";  // Crash Ok.
@@ -816,7 +855,7 @@
   LogicalResult Lower(
       mlir::Operation* inst, bool is_entry_function,
       llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
-      xla::XlaBuilder* builder,
+      llvm::ArrayRef<mlir::Value> implicit_results, xla::XlaBuilder* builder,
       ConvertToHloModule::ValueLoweringMap* value_lowering,
       xla::XlaOp* return_value);
 
@@ -916,12 +955,13 @@
 }
 
 void BuildGetTupleElementsForTupleResults(mlir::Operation* op, xla::XlaOp tuple,
-                                          OpLoweringContext ctx) {
+                                          OpLoweringContext ctx,
+                                          unsigned num_implicit_results = 0) {
   const std::optional<xla::OpSharding>& sharding = ctx.builder->sharding();
   if (sharding.has_value()) {
     bool is_tuple_sharding = sharding->type() == xla::OpSharding::TUPLE;
-    assert(!is_tuple_sharding ||
-           op->getNumResults() == sharding->tuple_shardings_size());
+    assert(!is_tuple_sharding || (op->getNumResults() + num_implicit_results ==
+                                  sharding->tuple_shardings_size()));
     for (auto [index, result] : llvm::enumerate(op->getResults())) {
       // If `sharding` is not a tuple sharding, then every `get-tuple-element`
       // gets the same sharding.
@@ -956,7 +996,8 @@
 }
 
 LogicalResult ExportXlaOp(CompositeOp, OpLoweringContext) {
-  // TODO: b/328526226 - Implement MHLO export for CompositeOp.
+  // Failure on purpose because `mhlo::CompositeOp` will be handled by
+  // special purpose logic in `ConvertToHloModule::Lower`.
   return failure();
 }
 
@@ -1610,7 +1651,7 @@
   xla::XlaComputation false_branch;
   auto& value_map = *ctx.values;
 
-  // mhlo.IfOp does not have any operands or blocks-arguments. The computation
+  // mhlo.IfOp does not have any operands or block arguments. The computation
   // inside the region-blocks use implicit captures of values defined above.
   // In order to create the xla parameters for functions corresponding to
   // IfOp regions, we need to infer a region-block's arguments, using all
@@ -1628,10 +1669,10 @@
   getUsedValuesDefinedAbove(op.getFalseBranch(), op.getFalseBranch(),
                             implicit_false_operand_set);
 
-  llvm::SmallVector<mlir::Value> implicit_true_operands(
-      implicit_true_operand_set.begin(), implicit_true_operand_set.end());
-  llvm::SmallVector<mlir::Value> implicit_false_operands(
-      implicit_false_operand_set.begin(), implicit_false_operand_set.end());
+  llvm::SmallVector<mlir::Value> implicit_true_operands =
+      implicit_true_operand_set.takeVector();
+  llvm::SmallVector<mlir::Value> implicit_false_operands =
+      implicit_false_operand_set.takeVector();
 
   llvm::SmallVector<std::optional<xla::OpSharding>> ret_shardings =
       GetResultShardings(ctx.builder->sharding(), op->getNumResults());
@@ -1657,13 +1698,13 @@
   // implicit captures operands. Also export the instructions within those
   // regions.
   if (failed(ctx.converter->LowerRegionAsComputation(
-          &op.getTrueBranch(), &true_branch,
-          llvm::ArrayRef(implicit_true_operands),
-          /*ensure_single_arg*/ true, true_arg_shardings, ret_shardings)) ||
+          &op.getTrueBranch(), &true_branch, implicit_true_operands,
+          /*implicit_results=*/{}, /*ensure_single_arg=*/true,
+          true_arg_shardings, ret_shardings)) ||
       failed(ctx.converter->LowerRegionAsComputation(
-          &op.getFalseBranch(), &false_branch,
-          llvm::ArrayRef(implicit_false_operands),
-          /*ensure_single_arg*/ true, false_arg_shardings, ret_shardings))) {
+          &op.getFalseBranch(), &false_branch, implicit_false_operands,
+          /*implicit_results=*/{}, /*ensure_single_arg=*/true,
+          false_arg_shardings, ret_shardings))) {
     return failure();
   }
 
@@ -1701,7 +1742,7 @@
   std::vector<xla::XlaComputation> computations(branches.size());
   std::vector<xla::XlaComputation*> computations_p(branches.size());
 
-  // mhlo.CaseOp does not have any operands or blocks-arguments. The computation
+  // mhlo.CaseOp does not have any operands or block arguments. The computation
   // inside the region-blocks use implicit captures of values defined above.
   // In order to create the xla parameters for functions corresponding to
   // CaseOp regions, we need to infer a region-block's arguments, using all
@@ -1715,8 +1756,8 @@
   for (unsigned i = 0; i < branches.size(); ++i) {
     llvm::SetVector<mlir::Value> implicit_operand_set;
     getUsedValuesDefinedAbove(branches[i], branches[i], implicit_operand_set);
-    llvm::SmallVector<mlir::Value> implicit_operands(
-        implicit_operand_set.begin(), implicit_operand_set.end());
+    llvm::SmallVector<mlir::Value> implicit_operands =
+        implicit_operand_set.takeVector();
 
     llvm::SmallVector<std::optional<xla::OpSharding>> ret_shardings =
         GetResultShardings(ctx.builder->sharding(), op->getNumResults());
@@ -1740,8 +1781,9 @@
     // that region.
     computations_p[i] = &computations[i];
     if (failed(ctx.converter->LowerRegionAsComputation(
-            &branches[i], computations_p[i], llvm::ArrayRef(implicit_operands),
-            /*ensure_single_arg*/ true, arg_shardings, ret_shardings)))
+            &branches[i], computations_p[i], implicit_operands,
+            /*implicit_results=*/{}, /*ensure_single_arg=*/true, arg_shardings,
+            ret_shardings)))
       return failure();
   }
 
@@ -1905,12 +1947,12 @@
   // This feature is at time of writing only used by JAX, and is tested in the
   // jax2tf backwards compatibility tests.
 
-  if (op.getCallTargetName() == "ApproxTopK") {
+  if (op.getCallTargetName() == kApproxTopK) {
     auto isSupportedAttrName = [](NamedAttribute attr) {
       auto name = attr.getName();
-      return name == "call_target_name" || name == "backend_config" ||
-             name == "api_version" || name == "called_computations" ||
-             name == "has_side_effect";
+      return name == kCallTargetName || name == kBackendConfig ||
+             name == kApiVersion || name == kCalledComputations ||
+             name == kHasSideEffect;
     };
     for (const auto& attr : op->getAttrs()) {
       if (!isSupportedAttrName(attr))
@@ -1925,9 +1967,9 @@
 
     for (auto attr : backend_config) {
       auto name = attr.getName();
-      if (!(name == "top_k" || name == "reduction_dim" ||
-            name == "recall_target" || name == "aggregate_to_topk" ||
-            name == "reduction_input_size_override" || name == "is_fallback"))
+      if (!(name == kTopK || name == kReductionDim || name == kRecallTarget ||
+            name == kAggregateToTopk || name == kReductionInputSizeOverride ||
+            name == kIsFallback))
         return op.emitOpError()
                << name.getValue() << " is not a supported backend_config"
                << " attribute for ApproxTopK";
@@ -1969,29 +2011,28 @@
                << " attribute in backend_config must be of bool type";
       return success();
     };
-    if (failed(checkI64Attr("top_k"))) return failure();
-    if (failed(checkI64Attr("reduction_dim"))) return failure();
-    if (failed(checkF32Attr("recall_target"))) return failure();
-    if (failed(checkBoolAttr("aggregate_to_topk"))) return failure();
-    if (failed(checkI64Attr("reduction_input_size_override"))) return failure();
-    bool has_is_fallback = backend_config.contains("is_fallback");
-    if (has_is_fallback && !backend_config.getAs<BoolAttr>("is_fallback"))
+    if (failed(checkI64Attr(kTopK))) return failure();
+    if (failed(checkI64Attr(kReductionDim))) return failure();
+    if (failed(checkF32Attr(kRecallTarget))) return failure();
+    if (failed(checkBoolAttr(kAggregateToTopk))) return failure();
+    if (failed(checkI64Attr(kReductionInputSizeOverride))) return failure();
+    bool has_is_fallback = backend_config.contains(kIsFallback);
+    if (has_is_fallback && !backend_config.getAs<BoolAttr>(kIsFallback))
       return op.emitOpError()
              << "is_fallback attribute in backend_config must be of bool type";
 
-    int64_t top_k = backend_config.getAs<IntegerAttr>("top_k").getInt();
+    int64_t top_k = backend_config.getAs<IntegerAttr>(kTopK).getInt();
     int64_t reduction_dim =
-        backend_config.getAs<IntegerAttr>("reduction_dim").getInt();
-    float recall_target = backend_config.getAs<FloatAttr>("recall_target")
+        backend_config.getAs<IntegerAttr>(kReductionDim).getInt();
+    float recall_target = backend_config.getAs<FloatAttr>(kRecallTarget)
                               .getValue()
                               .convertToFloat();
     bool aggregate_to_topk =
-        backend_config.getAs<BoolAttr>("aggregate_to_topk").getValue();
+        backend_config.getAs<BoolAttr>(kAggregateToTopk).getValue();
     int64_t reduction_input_size_override =
-        backend_config.getAs<IntegerAttr>("reduction_input_size_override")
-            .getInt();
+        backend_config.getAs<IntegerAttr>(kReductionInputSizeOverride).getInt();
     bool is_fallback = has_is_fallback &&
-                       backend_config.getAs<BoolAttr>("is_fallback").getValue();
+                       backend_config.getAs<BoolAttr>(kIsFallback).getValue();
 
     // (C1)
     if (args.size() % 2 != 0) {
@@ -2151,7 +2192,7 @@
 
   absl::StatusOr<xla::Literal> literal;
   const xla::Literal* literal_ptr = nullptr;
-  auto literal_attr = op->getAttrOfType<DenseElementsAttr>(kLiteralAttr);
+  auto literal_attr = op->getAttrOfType<DenseElementsAttr>(kMhloLiteral);
   if (literal_attr) {
     literal = CreateArrayLiteralFromAttr(literal_attr, {});
     if (!literal.ok()) return failure();
@@ -2712,15 +2753,57 @@
 LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
   xla::XlaComputation condition;
   xla::XlaComputation body;
+
   // If the results of the while op have a sharding, we use those shardings for
   // the corresponding arguments and return shardings in the body and condition.
   llvm::SmallVector<std::optional<xla::OpSharding>> res_shardings =
       GetResultShardings(ctx.builder->sharding(), op->getNumResults());
+
+  // mhlo.WhileOp has operands and corresponding block arguments, but the
+  // computation inside its region-blocks can also use implicit captures of
+  // values defined above.
+  // In order to create the xla parameters for functions corresponding to
+  // WhileOp regions, we need to infer the implicit region-block's arguments,
+  // using all the values used in the region but defined above.
+  //
+  // Note that the body and cond regions of WhileOp share the same block
+  // arguments, so we collect the implicit values for both in a single set.
+  llvm::SetVector<mlir::Value> implicit_operand_set;
+  getUsedValuesDefinedAbove(op->getRegions(), implicit_operand_set);
+  llvm::SmallVector<mlir::Value> implicit_operands =
+      implicit_operand_set.takeVector();
+
+  llvm::SmallVector<xla::XlaOp> implicit_args;
+  if (failed(GetXlaOps(op, implicit_operands, ctx, implicit_args)))
+    return failure();
+
+  // We need to append the shardings of the implicit values to the result
+  // shardings, since the HLO While will have those implicit values as additional
+  // operands and results.
+  llvm::SmallVector<std::optional<xla::OpSharding>> implicit_shardings;
+  if (!implicit_args.empty() && !res_shardings.empty()) {
+    // We only add implicit arg shardings if there are result shardings,
+    // otherwise it means sharding propagation hasn't been done yet.
+    implicit_shardings = GetXlaOpShardings(implicit_args);
+
+    res_shardings.append(implicit_shardings.begin(), implicit_shardings.end());
+    if (std::optional<xla::OpSharding> new_sharding =
+            CreateTupleSharding(res_shardings)) {
+      ctx.builder->SetSharding(*new_sharding);
+    }
+  }
+
+  // The body of the While needs to return the same number of values as its
+  // arguments, as they are carried over to the next iteration. Thus, we pass
+  // the `implicit_operands` as `implicit_results`, to carry them over as is.
   if (failed(ctx.converter->LowerRegionAsComputation(
-          &op.getBody(), &body, std::nullopt, /*ensure_single_arg=*/true,
-          /*arg_shardings=*/res_shardings, /*ret_shardings=*/res_shardings)) ||
+          &op.getBody(), &body, implicit_operands,
+          /*implicit_results=*/implicit_operands,
+          /*ensure_single_arg=*/true, /*arg_shardings=*/res_shardings,
+          /*ret_shardings=*/res_shardings)) ||
       failed(ctx.converter->LowerRegionAsComputation(
-          &op.getCond(), &condition, std::nullopt,
+          &op.getCond(), &condition, implicit_operands,
+          /*implicit_results=*/{},
           /*ensure_single_arg=*/true, /*arg_shardings=*/res_shardings))) {
     return failure();
   }
@@ -2729,11 +2812,12 @@
   // those operands, to be used as sole operand of xla::While.
   llvm::SmallVector<xla::XlaOp> operands;
   if (failed(GetTuple(op, op.getOperands(), ctx, operands))) return failure();
+  operands.append(implicit_args.begin(), implicit_args.end());
 
   xla::XlaOp operand = operands[0];
   if (operands.size() > 1) operand = Tuple(ctx.builder, operands);
 
-  auto whileop = xla::While(condition, body, operand);
+  xla::XlaOp whileop = xla::While(condition, body, operand);
 
   auto& value_map = *ctx.values;
   auto shape_or = whileop.builder()->GetShape(whileop);
@@ -2748,7 +2832,8 @@
   }
 
   // mhlo.WhileOp supports multiple returns, untuple all the results of XLA's.
-  BuildGetTupleElementsForTupleResults(op, whileop, ctx);
+  BuildGetTupleElementsForTupleResults(
+      op, whileop, ctx, /*num_implicit_results=*/implicit_args.size());
 
   return success();
 }
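
Concretely, the implicit-capture handling above means that an `mhlo.while` whose regions use a value defined outside the op gets that value appended as an extra operand and result of the exported HLO `while`, threaded through the body unchanged. A hand-written sketch of the rough shape of the resulting HLO (illustrative names and types, not actual exporter output):

    HloModule while_implicit_capture_sketch

    %body (body_param: (s32[], f32[])) -> (s32[], f32[]) {
      %body_param = (s32[], f32[]) parameter(0)
      %iv = s32[] get-tuple-element((s32[], f32[]) %body_param), index=0
      %captured = f32[] get-tuple-element((s32[], f32[]) %body_param), index=1
      %one = s32[] constant(1)
      %next_iv = s32[] add(s32[] %iv, s32[] %one)
      // The capture is returned as-is, mirroring `implicit_results` above.
      ROOT %carry = (s32[], f32[]) tuple(s32[] %next_iv, f32[] %captured)
    }

    %cond (cond_param: (s32[], f32[])) -> pred[] {
      %cond_param = (s32[], f32[]) parameter(0)
      %iv_again = s32[] get-tuple-element((s32[], f32[]) %cond_param), index=0
      %limit = s32[] constant(8)
      ROOT %keep_going = pred[] compare(s32[] %iv_again, s32[] %limit), direction=LT
    }

    ENTRY %main (Arg_0.1: s32[]) -> s32[] {
      %Arg_0.1 = s32[] parameter(0)
      // Stand-in for a value defined above the mhlo.while and captured by its regions.
      %captured_above = f32[] constant(0.5)
      %init = (s32[], f32[]) tuple(s32[] %Arg_0.1, f32[] %captured_above)
      %while.2 = (s32[], f32[]) while((s32[], f32[]) %init), condition=%cond, body=%body
      ROOT %result = s32[] get-tuple-element((s32[], f32[]) %while.2), index=0
    }

Element 0 of the carried tuple is the op's declared result; element 1 is the implicit capture. `BuildGetTupleElementsForTupleResults` only materializes get-tuple-elements for the declared results, and `num_implicit_results` merely widens its tuple-sharding size check accordingly.
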
@@ -2824,11 +2909,11 @@
         xla::internal::XlaBuilderFriend::GetInstruction(operand);
     xla::LayoutProto result_layout =
         ExtractLayout(op, bitcast_proto->shape().dimensions_size(),
-                      "result_layout")
+                      kResultLayout)
             .ToProto();
     xla::LayoutProto source_layout =
         ExtractLayout(op, operand_proto->shape().dimensions_size(),
-                      "source_layout")
+                      kSourceLayout)
             .ToProto();
     xla::gpu::BitcastBackendConfig bitcast_config;
     *bitcast_config.mutable_source_layout() = source_layout;
@@ -2879,7 +2964,7 @@
 
 LogicalResult ConvertLayout(mlir::Operation* op, const mlir::ArrayAttr& layout,
                             xla::ShapeProto* shape) {
-  // In the case of tuples, ShapeProtos can be nested, and so can the mlir
+  // In the case of tuples, Shape protos can be nested, and so can the mlir
   // attribute describing the layout. So recurse into the subshapes in both data
   // structures in parallel.
   if (shape->element_type() == xla::TUPLE) {
@@ -3045,7 +3130,7 @@
 LogicalResult ConvertToHloModule::Lower(
     mlir::Operation* inst, bool is_entry_function,
     llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
-    xla::XlaBuilder* builder,
+    llvm::ArrayRef<mlir::Value> implicit_results, xla::XlaBuilder* builder,
     ConvertToHloModule::ValueLoweringMap* value_lowering,
     xla::XlaOp* return_value) {
   // Explicitly fail for ops that are not supported for export.
@@ -3092,8 +3177,7 @@
     // For infeed ops stemming back to InfeedDequeueTuple, respect the
     // layout attribute, and create the corresponding layout in hlo.
     if (isa<mhlo::InfeedOp>(inst)) {
-      mlir::ArrayAttr layout =
-          inst->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr);
+      mlir::ArrayAttr layout = inst->getAttrOfType<mlir::ArrayAttr>(kLayout);
 
       if (layout) {
         // We propagate layout to the following three ops:
@@ -3222,40 +3306,51 @@
   if (isa<mhlo::ReturnOp, mlir::func::ReturnOp>(inst)) {
     // Construct the return value for the function. If there is a single value
     // returned, then return it directly, else create a tuple and return.
-    unsigned num_return_values = inst->getNumOperands();
+    unsigned num_return_values =
+        inst->getNumOperands() + implicit_results.size();
     std::optional<xla::OpSharding> ret_tuple_sharding =
         CreateTupleSharding(ret_shardings);
     if ((options_.return_tuple && is_entry_function) ||
         num_return_values != 1) {
-      std::vector<xla::XlaOp> returns(num_return_values);
-      for (OpOperand& ret : inst->getOpOperands()) {
-        unsigned index = ret.getOperandNumber();
-        xla::XlaOp operand;
-        if (failed(GetXlaOp(ret.get(), value_map, &operand, inst)))
-          return failure();
+      std::vector<xla::XlaOp> returns;
+      returns.reserve(num_return_values);
+      // NOTE: we can't use operand_range in llvm::concat.
+      for (Value ret : inst->getOperands()) {
+        xla::XlaOp& operand = returns.emplace_back();
+        if (failed(GetXlaOp(ret, value_map, &operand, inst))) return failure();
+      }
+      for (Value ret : implicit_results) {
+        xla::XlaOp& operand = returns.emplace_back();
+        if (failed(GetXlaOp(ret, value_map, &operand, inst))) return failure();
+      }
+      if (is_entry_function && ret_tuple_sharding) {
+        assert(implicit_results.empty() &&
+               "entry functions shouldn't have implicit results");
+        for (OpOperand& ret : inst->getOpOperands()) {
+          unsigned index = ret.getOperandNumber();
 
-        returns[index] = operand;
-        if (!is_entry_function || !ret_tuple_sharding) continue;
+          xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
+          absl::StatusOr<xla::XlaOp> reshape =
+              ReshapeWithCorrectRepresentationAndSharding(
+                  builder, returns[index], return_shape,
+                  options_.layout_preference_fn,
+                  options_.shape_representation_fn, ret_shardings[index],
+                  /*fast_mem=*/false);
+          if (!reshape.ok())
+            return inst->emitError() << reshape.status().message();
 
-        xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
-        absl::StatusOr<xla::XlaOp> reshape =
-            ReshapeWithCorrectRepresentationAndSharding(
-                builder, returns[index], return_shape,
-                options_.layout_preference_fn, options_.shape_representation_fn,
-                ret_shardings[index], /*fast_mem=*/false);
-        if (!reshape.ok())
-          return inst->emitError() << reshape.status().message();
-
-        returns[index] = reshape.value();
+          returns[index] = reshape.value();
+        }
       }
 
       xla::XlaScopedShardingAssignment scoped_sharding(builder,
                                                        ret_tuple_sharding);
       *return_value = xla::Tuple(builder, returns);
     } else if (num_return_values == 1) {
+      Value ret = implicit_results.empty() ? inst->getOperand(0)
+                                           : implicit_results.front();
       xla::XlaOp operand;
-      if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
-        return failure();
+      if (failed(GetXlaOp(ret, value_map, &operand, inst))) return failure();
 
       if (ret_tuple_sharding) {
         auto tuple = Tuple(builder, {operand});
@@ -3270,6 +3365,59 @@
     return success();
   }
 
+  if (auto composite_op = dyn_cast<mhlo::CompositeOp>(inst)) {
+    SmallVector<xla::XlaOp, 1> operands;
+    for (const Value& val : inst->getOperands()) {
+      xla::XlaOp operand;
+      if (failed(GetXlaOp(val, value_map, &operand, inst))) {
+        return failure();
+      }
+      operands.push_back(operand);
+    }
+
+    xla::XlaComputation computation;
+    if (failed(LowerBasicBlockAsFunction(
+            /*block=*/&module_
+                .lookupSymbol<mlir::func::FuncOp>(
+                    composite_op.getDecomposition())
+                .getBody()
+                .front(),
+            /*builder=*/
+            module_builder_
+                .CreateSubBuilder(composite_op.getDecomposition().str())
+                .get(),
+            /*is_entry_function=*/false,
+            /*ensure_single_arg=*/false,
+            /*entry_args_same_across_replicas=*/{},
+            /*arg_shardings=*/{}, /*ret_shardings=*/{},
+            /*fe_attrs=*/{}, /*result=*/&computation,
+            /*implicit_operands=*/{}))) {
+      return failure();
+    }
+
+    std::string composite_attributes;
+    llvm::raw_string_ostream(composite_attributes)
+        << composite_op.getCompositeAttributes();
+
+    xla::XlaOp composite_call = xla::CompositeCall(
+        builder, computation, operands, composite_op.getName().str(),
+        composite_attributes, composite_op.getVersion());
+
+    // Use GetTupleElement to untuple multiple outputs.
+    unsigned num_results = composite_op.getNumResults();
+    if (num_results > 1) {
+      for (unsigned i = 0; i != num_results; ++i) {
+        value_map[composite_op.getResult(i)] =
+            xla::GetTupleElement(composite_call, i);
+      }
+    } else if (num_results == 1) {
+      value_map[composite_op.getResult(0)] = composite_call;
+    }
+    *return_value = composite_call;
+
+    return success();
+  }
+
   inst->emitOpError() << "can't be translated to XLA HLO";
   return failure();
 }
@@ -3318,7 +3466,7 @@
 
   // Create a sub-builder if this is not the main function.
   std::unique_ptr<xla::XlaBuilder> builder_up;
-  bool entry_function = f.getName() == "main";
+  bool entry_function = f.getName() == kMain;
   if (!entry_function)
     builder_up = module_builder_.CreateSubBuilder(f.getName().str());
   auto& builder = entry_function ? module_builder_ : *builder_up;
@@ -3332,14 +3480,14 @@
     bool any_arg_replicated = false;
     entry_args_same_across_replicas.reserve(f.getNumArguments());
     for (int64_t i = 0; i < f.getNumArguments(); ++i) {
-      auto attr = f.getArgAttrOfType<mlir::BoolAttr>(i, kReplicationAttr);
+      auto attr = f.getArgAttrOfType<mlir::BoolAttr>(i, kMhloReplication);
       entry_args_same_across_replicas.push_back(attr != nullptr &&
                                                 attr.getValue());
       any_arg_replicated |= entry_args_same_across_replicas.back();
       // Pass the alias info to the builder so that it will build the alias info
       // into the resulting HloModule.
       auto buffer_donor =
-          f.getArgAttrOfType<mlir::BoolAttr>(i, "jax.buffer_donor");
+          f.getArgAttrOfType<mlir::BoolAttr>(i, kJaxBufferDonor);
       if (buffer_donor) {
         if (options_.use_tuple_args) {
           builder.AddBufferDonor(/*param_number=*/0, /*param_index=*/{i});
@@ -3348,7 +3496,7 @@
         }
       }
       auto aliasing_output =
-          f.getArgAttrOfType<mlir::IntegerAttr>(i, "tf.aliasing_output");
+          f.getArgAttrOfType<mlir::IntegerAttr>(i, kTfAliasingOutput);
       if (!aliasing_output) continue;
       xla::ShapeIndex output_index;
       if ((options_.return_tuple && entry_function) || f.getNumResults() != 1) {
@@ -3383,13 +3531,13 @@
     return failure();
   }
   if (auto execution_thread =
-          f->getAttrOfType<mlir::StringAttr>("execution_thread")) {
+          f->getAttrOfType<mlir::StringAttr>(kExecutionThread)) {
     computation.mutable_proto()->mutable_computations(0)->set_execution_thread(
         execution_thread.str());
   }
   for (int i = 0; i < f.getNumArguments(); ++i) {
     if (auto pr =
-            f.getArgAttrOfType<mlir::ArrayAttr>(i, kParameterReplicationAttr)) {
+            f.getArgAttrOfType<mlir::ArrayAttr>(i, kMhloParameterReplication)) {
       for (auto b : pr.getValue())
         for (auto& instr : *computation.mutable_proto()
                                 ->mutable_computations(0)
@@ -3494,8 +3642,8 @@
     llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
     llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
     llvm::ArrayRef<std::optional<xla::FrontendAttributes>> fe_attrs,
-    xla::XlaComputation* result,
-    std::optional<llvm::ArrayRef<mlir::Value>> implicit_operands) {
+    xla::XlaComputation* result, llvm::ArrayRef<mlir::Value> implicit_operands,
+    llvm::ArrayRef<mlir::Value> implicit_results) {
   // Mapping from the Value to lowered XlaOp.
   ValueLoweringMap lowering;
 
@@ -3519,7 +3667,7 @@
     // fuse all the `mlir::Location`s or join the operation name strings with
     // ";" (which is essentially the same).
     auto tuple =
-        xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
+        xla::Parameter(builder, 0, input_shape, kArgTuple, leaf_replication);
     builder->ClearSharding();
 
     for (BlockArgument& arg : block->getArguments()) {
@@ -3533,17 +3681,16 @@
       // Applicable for mhlo.IfOp or mhlo.CaseOp or mhlo.WhileOp.
       llvm::SmallVector<xla::Shape, 4> arg_shapes;
 
-      auto args_size = block->getNumArguments();
-      if (implicit_operands) args_size = implicit_operands->size();
+      // Lowering supports a mix of block arguments and implicit operands.
+      // Block arguments must be added before implicit capture operands.
+
+      auto args_size = block->getNumArguments() + implicit_operands.size();
 
       arg_shapes.reserve(args_size);
-      if (implicit_operands) {
-        for (auto implicit_operand : *implicit_operands)
-          arg_shapes.push_back(xla::TypeToShape(implicit_operand.getType()));
-      } else {
-        for (BlockArgument& arg : block->getArguments())
-          arg_shapes.push_back(xla::TypeToShape(arg.getType()));
-      }
+      for (BlockArgument& arg : block->getArguments())
+        arg_shapes.push_back(xla::TypeToShape(arg.getType()));
+      for (Value implicit_operand : implicit_operands)
+        arg_shapes.push_back(xla::TypeToShape(implicit_operand.getType()));
 
       if (args_size > 1) {
         xla::XlaScopedShardingAssignment scoped_sharding(
@@ -3554,26 +3701,23 @@
         // but not tuple params. Do the same for tuple params. To do so, either
         // fuse all the `mlir::Location`s or join the operation name strings
         // with ";" (which is essentially the same).
-        auto tuple = xla::Parameter(builder, 0,
-                                    xla::ShapeUtil::MakeTupleShape(arg_shapes),
-                                    "arg_tuple");
+        auto tuple = xla::Parameter(
+            builder, 0, xla::ShapeUtil::MakeTupleShape(arg_shapes), kArgTuple);
 
-        if (implicit_operands) {
-          for (auto [arg_index, implicit_operand] :
-               llvm::enumerate(*implicit_operands)) {
-            xla::XlaScopedShardingAssignment scoped_sharding(
-                builder, arg_shardings.empty() ? std::nullopt
-                                               : arg_shardings[arg_index]);
-            lowering[implicit_operand] = xla::GetTupleElement(tuple, arg_index);
-          }
-        } else {
-          for (BlockArgument& arg : block->getArguments()) {
-            auto num = arg.getArgNumber();
-            xla::XlaScopedShardingAssignment scoped_sharding(
-                builder,
-                arg_shardings.empty() ? std::nullopt : arg_shardings[num]);
-            lowering[arg] = xla::GetTupleElement(tuple, num);
-          }
+        for (BlockArgument& arg : block->getArguments()) {
+          auto num = arg.getArgNumber();
+          xla::XlaScopedShardingAssignment scoped_sharding(
+              builder,
+              arg_shardings.empty() ? std::nullopt : arg_shardings[num]);
+          lowering[arg] = xla::GetTupleElement(tuple, num);
+        }
+        for (auto [implicit_index, implicit_operand] :
+             llvm::enumerate(implicit_operands)) {
+          int64_t arg_index = block->getNumArguments() + implicit_index;
+          xla::XlaScopedShardingAssignment scoped_sharding(
+              builder,
+              arg_shardings.empty() ? std::nullopt : arg_shardings[arg_index]);
+          lowering[implicit_operand] = xla::GetTupleElement(tuple, arg_index);
         }
       } else if (args_size == 1) {
         // Save the location information as a name. For example JAX will set the
@@ -3581,23 +3725,17 @@
         xla::XlaScopedShardingAssignment scoped_sharding(
             builder,
             arg_shardings.empty() ? std::nullopt : arg_shardings.front());
-        if (implicit_operands) {
-          mlir::Value arg = (*implicit_operands)[0];
-          xla::XlaScopedOpMetadataAssignment op_metadata(
-              builder, GetOpNameMetadataFromLocation(arg));
-          lowering[arg] = xla::Parameter(builder, 0, arg_shapes[0], "Arg_");
-        } else {
-          mlir::BlockArgument arg = block->getArgument(0);
-          xla::XlaScopedOpMetadataAssignment op_metadata(
-              builder, GetOpNameMetadataFromLocation(arg));
-          lowering[arg] = xla::Parameter(builder, 0, arg_shapes[0], "Arg_");
-        }
+        mlir::Value arg = implicit_operands.empty() ? block->getArgument(0)
+                                                    : implicit_operands.front();
+        xla::XlaScopedOpMetadataAssignment op_metadata(
+            builder, GetOpNameMetadataFromLocation(arg));
+        lowering[arg] = xla::Parameter(builder, 0, arg_shapes[0], kArgPrefix);
       } else {
         // Applicable only for IfOp or CaseOp. No implicit operands implies no
         // xla parameters. In this case, we create an empty tuple as the
         // block-parameter.
         xla::Parameter(builder, 0, xla::ShapeUtil::MakeTupleShape(arg_shapes),
-                       "arg_empty_tuple");
+                       kArgEmptyTuple);
       }
     } else {
       for (BlockArgument& arg : block->getArguments()) {
@@ -3616,11 +3754,11 @@
         xla::XlaScopedOpMetadataAssignment op_metadata(
             builder, GetOpNameMetadataFromLocation(arg));
         if (entry_args_same_across_replicas.empty()) {
-          lowering[arg] =
-              xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
+          lowering[arg] = xla::Parameter(builder, num, shape,
+                                         absl::StrCat(kArgPrefix, num));
         } else {
           lowering[arg] = xla::Parameter(
-              builder, num, shape, absl::StrCat("Arg_", num),
+              builder, num, shape, absl::StrCat(kArgPrefix, num),
               std::vector<bool>(entry_args_same_across_replicas[num],
                                 xla::ShapeUtil::GetLeafCount(shape)));
         }
@@ -3631,8 +3769,8 @@
 
   xla::XlaOp return_value;
   for (auto& inst : *block)
-    if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
-                     &lowering, &return_value)))
+    if (failed(Lower(&inst, is_entry_function, ret_shardings, implicit_results,
+                     builder, &lowering, &return_value)))
       return failure();
 
   // Build the XlaComputation and check for failures.
@@ -3648,18 +3786,18 @@
 
 LogicalResult ConvertToHloModule::LowerRegionAsComputation(
     mlir::Region* region, xla::XlaComputation* func,
-    std::optional<llvm::ArrayRef<mlir::Value>> implicit_operands,
-    bool ensure_single_arg,
+    llvm::ArrayRef<mlir::Value> implicit_operands,
+    llvm::ArrayRef<mlir::Value> implicit_results, bool ensure_single_arg,
     llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
     llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings) {
-  std::unique_ptr<xla::XlaBuilder> builder =
-      module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
-  return LowerBasicBlockAsFunction(&region->front(), builder.get(),
-                                   /*is_entry_function=*/false,
-                                   /*ensure_single_arg*/ ensure_single_arg,
-                                   /*entry_args_same_across_replicas=*/{},
-                                   arg_shardings, ret_shardings,
-                                   /*fe_attrs=*/{}, func, implicit_operands);
+  std::unique_ptr<xla::XlaBuilder> builder = module_builder_.CreateSubBuilder(
+      absl::StrCat(kRegionPrefix, region_id_++));
+  return LowerBasicBlockAsFunction(
+      &region->front(), builder.get(),
+      /*is_entry_function=*/false,
+      /*ensure_single_arg=*/ensure_single_arg,
+      /*entry_args_same_across_replicas=*/{}, arg_shardings, ret_shardings,
+      /*fe_attrs=*/{}, func, implicit_operands, implicit_results);
 }
 
 // Runs the PrepareForExport pass on the ModuleOp.
@@ -3704,47 +3842,46 @@
 
   TF_RETURN_IF_ERROR(PrepareForExport(module));
   mlir::BaseScopedDiagnosticHandler diag_handler(module.getContext());
-  xla::XlaBuilder module_builder("main");
+  xla::XlaBuilder module_builder(kMain);
   ConvertToHloModule converter(module, module_builder, options);
   if (failed(converter.Run())) return diag_handler.ConsumeStatus();
   auto hlo_module = converter.ConsumeMainProto();
-  StringRef module_name = module.getName() ? *module.getName() : "main";
+  StringRef module_name = module.getName() ? *module.getName() : kMain;
   hlo_module.set_name(module_name.str());
-  if (auto cross_program_prefetches = module->getAttrOfType<mlir::ArrayAttr>(
-          "mhlo.cross_program_prefetches")) {
+  if (auto cross_program_prefetches =
+          module->getAttrOfType<mlir::ArrayAttr>(kMhloCrossProgramPrefetches)) {
     for (const auto& prefetch :
          Convert_cross_program_prefetches(cross_program_prefetches)) {
       *hlo_module.add_cross_program_prefetches() = std::move(prefetch);
     }
   }
-  if (auto is_dynamic =
-          module->getAttrOfType<mlir::BoolAttr>("mhlo.is_dynamic")) {
+  if (auto is_dynamic = module->getAttrOfType<mlir::BoolAttr>(kMhloIsDynamic)) {
     hlo_module.set_is_dynamic(is_dynamic.getValue());
   }
   if (auto frontend_attributes =
-          module->getAttrOfType<DictionaryAttr>(kFrontendAttributesAttr)) {
+          module->getAttrOfType<DictionaryAttr>(kMhloFrontendAttributes)) {
     ConstructFrontendAttributesFromAttribute(
         frontend_attributes, *hlo_module.mutable_frontend_attributes());
   }
-  if (auto use_auto_spmd_partitioning = module->getAttrOfType<mlir::BoolAttr>(
-          "mhlo.use_auto_spmd_partitioning")) {
+  if (auto use_auto_spmd_partitioning =
+          module->getAttrOfType<mlir::BoolAttr>(kMhloUseAutoSpmdPartitioning)) {
     hlo_module.set_use_auto_spmd_partitioning(
         use_auto_spmd_partitioning.getValue());
   }
-  if (auto spmd_output_sharding = module->getAttrOfType<mlir::StringAttr>(
-          "mhlo.spmd_output_sharding")) {
+  if (auto spmd_output_sharding =
+          module->getAttrOfType<mlir::StringAttr>(kMhloSpmdOutputSharding)) {
     *hlo_module.mutable_spmd_output_sharding() =
         *xla::ConvertSharding(spmd_output_sharding.getValue());
   }
   if (auto input_output_alias =
-          module->getAttrOfType<mlir::ArrayAttr>("mhlo.input_output_alias")) {
+          module->getAttrOfType<mlir::ArrayAttr>(kMhloInputOutputAlias)) {
     if (std::optional<xla::HloInputOutputAliasProto> input_output_alias_proto =
             xla::ConvertInputOutputAlias(input_output_alias.getValue())) {
       *hlo_module.mutable_input_output_alias() = *input_output_alias_proto;
     }
   }
   if (auto spmd_parameters_sharding = module->getAttrOfType<mlir::ArrayAttr>(
-          "mhlo.spmd_parameters_shardings")) {
+          kMhloSpmdParametersShardings)) {
     for (const auto& sharding : spmd_parameters_sharding.getValue()) {
       *hlo_module.add_spmd_parameters_shardings() = *xla::ConvertSharding(
           mlir::cast<mlir::StringAttr>(sharding).getValue());
@@ -3812,7 +3949,8 @@
     } else {
       xla::XlaOp return_value;
       if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
-                                 /*ret_shardings=*/{}, &builder, &lowering,
+                                 /*ret_shardings=*/{},
+                                 /*implicit_results=*/{}, &builder, &lowering,
                                  &return_value)))
         return diag_handler.ConsumeStatus();
     }
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc b/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc
index 88f0523..7dad7c3 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc
@@ -22,18 +22,20 @@
 namespace mlir {
 namespace mhlo {
 namespace {
-constexpr char kConfigNumPartitions[] = "mhlo.num_partitions";
-constexpr char kConfigNumReplicas[] = "mhlo.num_replicas";
+
+constexpr char kMhloNumPartitions[] = "mhlo.num_partitions";
+constexpr char kMhloNumReplicas[] = "mhlo.num_replicas";
+
 }  // namespace
 
 void ExportHloModuleConfig(xla::HloModuleConfig& config,
                            mlir::ModuleOp module) {
   if (auto num_partitions =
-          module->getAttrOfType<mlir::IntegerAttr>(kConfigNumPartitions)) {
+          module->getAttrOfType<mlir::IntegerAttr>(kMhloNumPartitions)) {
     config.set_num_partitions(num_partitions.getInt());
   }
   if (auto num_replicas =
-          module->getAttrOfType<mlir::IntegerAttr>(kConfigNumReplicas)) {
+          module->getAttrOfType<mlir::IntegerAttr>(kMhloNumReplicas)) {
     config.set_replica_count(num_replicas.getInt());
   }
 }
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD
index 37354ba..f947307 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD
@@ -12,9 +12,11 @@
         [
             "add.mlir",
             "case.mlir",
+            "composite.mlir",
             "dynamic.mlir",
             "export-with-layouts.mlir",
             "export.mlir",
+            "export_async.mlir",
             "export_and_check_layouts.mlir",
             "export_large_constants.mlir",
             "export_replicas.mlir",
@@ -36,6 +38,7 @@
             "simple.mlir",
             "unsupported_type.mlir",
             "while.mlir",
+            "while_free_vars.mlir",
         ],
         include = [
             "*.mlir",
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir
new file mode 100644
index 0000000..60c5548
--- /dev/null
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir
@@ -0,0 +1,190 @@
+// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+module @composite {
+  // CHECK: HloModule composite, entry_computation_layout={()->f32[]}
+  // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] {
+  // CHECK:   %Arg_0.3 = f32[] parameter(0)
+  // CHECK:   %constant.4 = f32[] constant(2)
+  // CHECK:   ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+  // CHECK: }
+  // CHECK: ENTRY %main.7 () -> f32[] {
+  // CHECK:   %constant.1 = f32[] constant(42)
+  // CHECK:   ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+  // CHECK: }
+  func.func @main() -> tensor<f32> {
+    %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+    %1 = mhlo.composite "foo.bar" %0 {
+      composite_attributes = {
+        n = 1 : i32,
+        tensor = dense<1> : tensor<i32>
+      },
+      decomposition = @add,
+      version = 1 : i32
+    } : (tensor<f32>) -> tensor<f32>
+    return %1 : tensor<f32>
+  }
+  func.func @add(%arg0: tensor<f32>) -> tensor<f32> {
+    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+    %1 = mhlo.add %arg0, %0 : tensor<f32>
+    return %1 : tensor<f32>
+  }
+}
+
+// -----
+
+// zero-output composite
+module @composite {
+  // CHECK: HloModule composite, entry_computation_layout={()->()}
+  // CHECK: %return.2 (Arg_0.3: f32[]) -> () {
+  // CHECK:   %Arg_0.3 = f32[] parameter(0)
+  // CHECK:   ROOT %tuple.4 = () tuple()
+  // CHECK: }
+  // CHECK: ENTRY %main.7 () -> () {
+  // CHECK:   %constant.1 = f32[] constant(42)
+  // CHECK:   %call.5 = () call(f32[] %constant.1), to_apply=%return.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+  // CHECK:   ROOT %tuple.6 = () tuple()
+  // CHECK: }
+  func.func @main() -> () {
+    %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+    "mhlo.composite"(%0) {
+      name = "foo.bar",
+      composite_attributes = {
+        n = 1 : i32,
+        tensor = dense<1> : tensor<i32>
+      },
+      decomposition = @return,
+      version = 1 : i32
+    } : (tensor<f32>) -> ()
+    return
+  }
+  func.func @return(%arg0: tensor<f32>) -> () {
+    return
+  }
+}
+
+// -----
+
+// multi-output composite
+module @composite {
+  // CHECK: HloModule composite, entry_computation_layout={()->(f32[], f32[])}
+  // CHECK: %add.2 (Arg_0.3: f32[]) -> (f32[], f32[]) {
+  // CHECK:   %Arg_0.3 = f32[] parameter(0)
+  // CHECK:   %constant.4 = f32[] constant(2)
+  // CHECK:   %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+  // CHECK:   ROOT %tuple.6 = (f32[], f32[]) tuple(f32[] %add.5, f32[] %add.5)
+  // CHECK: }
+  // CHECK: ENTRY %main.11 () -> (f32[], f32[]) {
+  // CHECK:   %constant.1 = f32[] constant(42)
+  // CHECK:   %call.7 = (f32[], f32[]) call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
+  // CHECK:   %get-tuple-element.8 = f32[] get-tuple-element((f32[], f32[]) %call.7), index=0
+  // CHECK:   %get-tuple-element.9 = f32[] get-tuple-element((f32[], f32[]) %call.7), index=1
+  // CHECK:   ROOT %tuple.10 = (f32[], f32[]) tuple(f32[] %get-tuple-element.8, f32[] %get-tuple-element.9)
+  // CHECK: }
+  func.func @main() -> (tensor<f32>, tensor<f32>) {
+    %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+    %result:2 = "mhlo.composite"(%0) {
+      name = "foo.bar",
+      composite_attributes = {
+        n = 1 : i32,
+        tensor = dense<1> : tensor<i32>
+      },
+      decomposition = @add,
+      version = 1 : i32
+    } : (tensor<f32>) -> (tensor<f32>, tensor<f32>)
+    return %result#0, %result#1 : tensor<f32>, tensor<f32>
+  }
+  func.func @add(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
+    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+    %1 = mhlo.add %arg0, %0 : tensor<f32>
+    return %1, %1 : tensor<f32>, tensor<f32>
+  }
+}
+
+// -----
+
+// optional composite attributes
+module @composite {
+  // CHECK: HloModule composite, entry_computation_layout={()->f32[]}
+  // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] {
+  // CHECK:   %Arg_0.3 = f32[] parameter(0)
+  // CHECK:   %constant.4 = f32[] constant(2)
+  // CHECK:   ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+  // CHECK: }
+  // CHECK: ENTRY %main.7 () -> f32[] {
+  // CHECK:   %constant.1 = f32[] constant(42)
+  // CHECK:   ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="1"}
+  // CHECK: }
+  func.func @main() -> tensor<f32> {
+    %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+    %1 = mhlo.composite "foo.bar" %0 {
+      decomposition = @add,
+      version = 1 : i32
+    } : (tensor<f32>) -> tensor<f32>
+    return %1 : tensor<f32>
+  }
+  func.func @add(%arg0: tensor<f32>) -> tensor<f32> {
+    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+    %1 = mhlo.add %arg0, %0 : tensor<f32>
+    return %1 : tensor<f32>
+  }
+}
+
+// -----
+
+// optional composite version
+module @composite {
+  // CHECK: HloModule composite, entry_computation_layout={()->f32[]}
+  // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] {
+  // CHECK:   %Arg_0.3 = f32[] parameter(0)
+  // CHECK:   %constant.4 = f32[] constant(2)
+  // CHECK:   ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+  // CHECK: }
+  // CHECK: ENTRY %main.7 () -> f32[] {
+  // CHECK:   %constant.1 = f32[] constant(42)
+  // CHECK:   ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="0"}
+  // CHECK: }
+  func.func @main() -> tensor<f32> {
+    %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+    %1 = mhlo.composite "foo.bar" %0 {
+      composite_attributes = {
+        n = 1 : i32,
+        tensor = dense<1> : tensor<i32>
+      },
+      decomposition = @add
+    } : (tensor<f32>) -> tensor<f32>
+    return %1 : tensor<f32>
+  }
+  func.func @add(%arg0: tensor<f32>) -> tensor<f32> {
+    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+    %1 = mhlo.add %arg0, %0 : tensor<f32>
+    return %1 : tensor<f32>
+  }
+}
+
+// -----
+
+// optional composite attributes and version
+module @composite {
+  // CHECK: HloModule composite, entry_computation_layout={()->f32[]}
+  // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] {
+  // CHECK:   %Arg_0.3 = f32[] parameter(0)
+  // CHECK:   %constant.4 = f32[] constant(2)
+  // CHECK:   ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4)
+  // CHECK: }
+  // CHECK: ENTRY %main.7 () -> f32[] {
+  // CHECK:   %constant.1 = f32[] constant(42)
+  // CHECK:   ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="0"}
+  // CHECK: }
+  func.func @main() -> tensor<f32> {
+    %0 = mhlo.constant dense<4.200000e+01> : tensor<f32>
+    %1 = mhlo.composite "foo.bar" %0 {
+      decomposition = @add
+    } : (tensor<f32>) -> tensor<f32>
+    return %1 : tensor<f32>
+  }
+  func.func @add(%arg0: tensor<f32>) -> tensor<f32> {
+    %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
+    %1 = mhlo.add %arg0, %0 : tensor<f32>
+    return %1 : tensor<f32>
+  }
+}
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir
index 680341f..3d44aff 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir
@@ -1,5 +1,5 @@
 // RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text -with-layouts -print-layouts %s | FileCheck %s
-// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text  -with-layouts -print-layouts --via-builder=true %s | FileCheck %s
+// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text -with-layouts -print-layouts --via-builder=true %s | FileCheck %s
 
 #CSR = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed),
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir
index dec3e5d..6672e62 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir
@@ -1,6 +1,16 @@
 // RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics %s | FileCheck %s
 // RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics --via-builder=true %s | FileCheck %s
 
+// CHECK: HloModule foo
+// CHECK: ENTRY %main
+module @foo {
+  func.func @main(%arg: tensor<i1>) -> tensor<i1> {
+    func.return %arg : tensor<i1>
+  }
+}
+
+// -----
+
 // CHECK:  HloModule
 func.func @main(%arg0: tensor<2xi1>) -> tensor<2xi1> {
   %0 = "mhlo.add"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
@@ -109,114 +119,6 @@
 
 // -----
 
-// CHECK:  HloModule
-func.func @all_gather_0(%arg1: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} {
-  %0 = "mhlo.all_gather"(%arg1) {
-    all_gather_dim = 1 : i64,
-    channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
-    shard_count = 4,
-    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
-    use_global_device_ids
-  } : (tensor<128x32xf32>) -> tensor<128x128xf32>
-  return %0 : tensor<128x128xf32>
-}
-
-func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> {
-  %0 = "mhlo.async_start"(%arg0) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x128xf32>>
-  %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x128xf32>>) -> tensor<128x128xf32>
-  return %1 : tensor<128x128xf32>
-}
-
-// CHECK: ENTRY
-// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0)
-// CHECK: %[[OUTPUT:.*]] = f32[128,128] all-gather-start(f32[128,32] %[[INPUT]])
-// CHECK-SAME: channel_id=1
-// CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}}
-// CHECK-SAME: dimensions={1}
-// CHECK-SAME: use_global_device_ids=true
-// CHECK: ROOT {{.*}} f32[128,128] all-gather-done(f32[128,128] %[[OUTPUT]]
-
-// -----
-
-// CHECK:  HloModule
-func.func @all_reduce_0(%arg0: tensor<10xf32>) -> tensor<10xf32> attributes {execution_thread = "main"} {
-  %0 = "mhlo.all_reduce"(%arg0) ({
-  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
-    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
-    "mhlo.return"(%max) : (tensor<f32>) -> ()
-  })
-  {
-    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
-    channel_handle = #mhlo.channel_handle<
-      handle = 5,
-      type = 2
-    >,
-    use_global_device_ids
-  } : (tensor<10xf32>) -> tensor<10xf32>
-  func.return %0 : tensor<10xf32>
-}
-
-func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
-  %0 = "mhlo.async_start"(%arg0) {called_computation = @all_reduce_0, execution_thread = "main"} : (tensor<10xf32>) -> !mhlo.async_bundle<tensor<10xf32>, tensor<10xf32>>
-  %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<10xf32>, tensor<10xf32>>) -> tensor<10xf32>
-  return %1 : tensor<10xf32>
-}
-
-// CHECK: ENTRY
-// CHECK: %[[INPUT:.*]] = f32[10] parameter(0)
-// CHECK: %[[OUTPUT:.*]] = f32[10] all-reduce-start(f32[10] %[[INPUT]])
-// CHECK-SAME:  channel_id=5
-// CHECK-SAME{LITERAL}:  replica_groups={{0,2,4,6},{1,3,5,7}}
-// CHECK-SAME: use_global_device_ids=true
-// CHECK: ROOT {{.*}} f32[10] all-reduce-done(f32[10] %[[OUTPUT]]
-
-// -----
-
-// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}}
-func.func @all_reduce_0(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) attributes {execution_thread = "main"} {
-  %0:2 = "mhlo.all_reduce"(%arg0, %arg1) ({
-  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
-    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
-    "mhlo.return"(%max) : (tensor<f32>) -> ()
-  })
-  {
-    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
-    channel_handle = #mhlo.channel_handle<
-      handle = 5,
-      type = 2
-    >,
-    use_global_device_ids
-  } : (tensor<10xf32>, tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>)
-  func.return %0#0, %0#1 : tensor<10xf32>, tensor<1xf32>
-}
-
-func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) {
-  %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_reduce_0, execution_thread = "main"} : (tensor<10xf32>, tensor<1xf32>) -> !mhlo.async_bundle<tuple<tensor<10xf32>,tensor<1xf32>>, tuple<tensor<10xf32>,tensor<1xf32>>>
-  %1:2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tuple<tensor<10xf32>,tensor<1xf32>>, tuple<tensor<10xf32>,tensor<1xf32>>>) -> (tensor<10xf32>, tensor<1xf32>)
-  return %1#0, %1#1 : tensor<10xf32>, tensor<1xf32>
-}
-
-// -----
-
-// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}}
-func.func @all_gather_0(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) attributes {execution_thread = "main"} {
-  %0:2 = "mhlo.all_gather"(%arg0, %arg1) {
-    all_gather_dim = 1 : i64,
-    channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
-    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
-    use_global_device_ids
-  } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
-  func.return %0#0, %0#1 : tensor<8x2xf32>, tensor<8x4xf32>
-}
-
-func.func @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) {
-  %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<8x2xf32>, tensor<8x4xf32>) -> !mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>
-  %1:2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
-  return %1#0, %1#1 : tensor<8x2xf32>, tensor<8x4xf32>
-}
-
-// -----
-
 func.func private @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> tuple<tensor<8x8xf32>, tensor<8x16xf32>> {
   // CHECK:      %[[ARG0:.*]] = f32[8,2] parameter(0)
   // CHECK-NEXT: %[[ARG1:.*]] = f32[8,4] parameter(1)
@@ -625,30 +527,6 @@
 // -----
 
 // CHECK:  HloModule
-func.func @collective_permute_0(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
-  %0 = "mhlo.collective_permute"(%arg0) {
-    source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
-    channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
-  } : (tensor<128x32xf32>) -> tensor<128x32xf32>
-  func.return %0 : tensor<128x32xf32>
-}
-
-func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
-  %0 = "mhlo.async_start"(%arg0) {called_computation = @collective_permute_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>
-  %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>) -> tensor<128x32xf32>
-  return %1 : tensor<128x32xf32>
-}
-
-// CHECK: ENTRY
-// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0)
-// CHECK: %[[OUTPUT:.*]] = f32[128,32] collective-permute-start(f32[128,32] %[[INPUT]])
-// CHECK-SAME:  channel_id=1
-// CHECK-SAME{LITERAL}:  source_target_pairs={{0,1},{1,2},{2,3}}
-// CHECK: ROOT {{.*}} f32[128,32] collective-permute-done(f32[128,32] %[[OUTPUT]]
-
-// -----
-
-// CHECK:  HloModule
 func.func @main(%arg0 : tensor<5x2xf32>,
            %arg1 : tensor<5x5xf32>,
            %arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
@@ -893,27 +771,6 @@
 // -----
 
 // CHECK:  HloModule
-func.func @copy_0(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
-  %0 = "mhlo.copy"(%arg0) {cross_program_prefetch_index = 0 : i32} : (tensor<128x32xf32>) -> tensor<128x32xf32>
-  func.return %0 : tensor<128x32xf32>
-}
-
-func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
-  %0 = "mhlo.async_start"(%arg0) {called_computation = @copy_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>
-  %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>) -> tensor<128x32xf32>
-  return %1 : tensor<128x32xf32>
-}
-
-// CHECK: ENTRY
-// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0)
-// CHECK: %[[OUTPUT:.*]] = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %[[INPUT]])
-// CHECK-SAME:  cross_program_prefetch_index=0
-// CHECK: ROOT {{.*}} f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %[[OUTPUT]]
-
-
-// -----
-
-// CHECK:  HloModule
 func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
   %0 = mhlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
   %1 = "mhlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
@@ -2131,67 +1988,6 @@
 // -----
 
 // CHECK:  HloModule
-
-func.func @recv_0(%token: !mhlo.token) -> (!mhlo.token) attributes {execution_thread = "main"} {
-  %0 = "mhlo.recv"(%token) {
-    channel_handle = #mhlo.channel_handle<
-      handle = 5,
-      type = 1  // Device to device channel
-    >,
-    is_host_transfer = false
-  } : (!mhlo.token) -> (!mhlo.token)
-  func.return %0 : !mhlo.token
-}
-
-func.func @main(%token: !mhlo.token) -> (!mhlo.token) {
-  %0 = "mhlo.async_start"(%token) {called_computation = @recv_0, execution_thread = "main"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>
-  %2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>) -> (!mhlo.token)
-  return %2 : !mhlo.token
-}
-
-// CHECK:  ENTRY
-// CHECK:  [[TOKEN:%.*]] = token[] parameter(0)
-// CHECK:  [[RECV:%.*]] = ((), u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5
-// CHECK:  ((), token[]) recv-done(((), u32[], token[]) [[RECV]]), channel_id=5
-
-// -----
-
-// CHECK:  HloModule
-func.func @recv_0(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) attributes {execution_thread = "main"} {
-  %0:2 = "mhlo.recv"(%token) {
-    channel_handle = #mhlo.channel_handle<
-      handle = 5,
-      type = 3  // Host to device channel
-    >,
-    is_host_transfer = true
-  } : (!mhlo.token) -> (tensor<3x4xi32>, !mhlo.token)
-  func.return %0#0, %0#1 : tensor<3x4xi32>, !mhlo.token
-}
-
-func.func @main(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) {
-  %0 = "mhlo.async_start"(%token) {called_computation = @recv_0, execution_thread = "main", mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<3x4xi32>, !mhlo.token>, tensor<i32>>
-  %1, %2 = "mhlo.async_done"(%0) {mhlo.sharding = "{{maximal device=0}, {maximal device=0}}"} : (!mhlo.async_bundle<!mhlo.token, tuple<tensor<3x4xi32>, !mhlo.token>, tensor<i32>>) -> (tensor<3x4xi32>, !mhlo.token)
-  return %1, %2 : tensor<3x4xi32>, !mhlo.token
-}
-
-// CHECK:  ENTRY
-// CHECK:  [[TOKEN:%.*]] = token[] parameter(0)
-// CHECK:  [[RECV:%.*]] = (s32[3,4], u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5, is_host_transfer
-// CHECK-SAME:  sharding={
-// CHECK-SAME:    {maximal device=0}, {maximal device=0}, {maximal device=0}
-// CHECK-SAME:  }
-// CHECK:  [[RECV_DONE:%.*]] = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) [[RECV]]), channel_id=5, is_host_transfer
-// CHECK-SAME:  sharding={
-// CHECK-SAME:    {maximal device=0}, {maximal device=0}
-// CHECK-SAME:  }
-// CHECK:  [[TUPLE0:%.*]] = s32[3,4] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=0, sharding={maximal device=0}
-// CHECK:  [[TUPLE1:%.*]] = token[] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=1, sharding={maximal device=0}
-// CHECK:  ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]])
-
-// -----
-
-
-// CHECK:  HloModule
 func.func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
   %result0, %result1 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({
     ^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>):
@@ -2466,58 +2262,6 @@
 // -----
 
 // CHECK:  HloModule
-func.func @send_0(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} {
-  %0 = "mhlo.send"(%arg, %token) {
-    channel_handle = #mhlo.channel_handle<
-      handle = 5,
-      type = 2  // Device to host channel
-    >,
-    is_host_transfer = true
-  } : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token
-  func.return %0 : !mhlo.token
-}
-
-func.func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token {
-  %0 = "mhlo.async_start"(%arg, %token) {called_computation = @send_0, execution_thread = "main"} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.async_bundle<tuple<tensor<3x4xi32>, !mhlo.token>, !mhlo.token, tensor<i32>>
-  %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tuple<tensor<3x4xi32>, !mhlo.token>, !mhlo.token, tensor<i32>>) -> !mhlo.token
-  return %1 : !mhlo.token
-}
-
-// CHECK:  ENTRY
-// CHECK:  [[ARG:%.*]] = s32[3,4] parameter(0)
-// CHECK:  [[TOKEN:%.*]] = token[] parameter(1)
-// CHECK:  [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5, is_host_transfer
-// CHECK:  ROOT
-// CHECK-SAME:  token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5
-
-// -----
-
-// CHECK:  HloModule
-func.func @send_0(%token: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} {
-  %0 = "mhlo.send"(%token) {
-    channel_handle = #mhlo.channel_handle<
-      handle = 5,
-      type = 1  // Device to device channel
-    >
-  } : (!mhlo.token) -> !mhlo.token
-  func.return %0 : !mhlo.token
-}
-
-func.func @main(%token: !mhlo.token) -> !mhlo.token {
-  %0 = "mhlo.async_start"(%token) {called_computation = @send_0, execution_thread = "main"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>
-  %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>) -> !mhlo.token
-  return %1 : !mhlo.token
-}
-
-// CHECK:  ENTRY
-// CHECK:  [[TOKEN:%.*]] = token[] parameter(0)
-// CHECK:  [[SEND:%.*]] = ((), u32[], token[]) send(() [[UNIT:%.*]], token[] [[TOKEN]]), channel_id=5
-// CHECK:  ROOT
-// CHECK-SAME:  token[] send-done(((), u32[], token[]) [[SEND]]), channel_id=5
-
-// -----
-
-// CHECK:  HloModule
 func.func @main(%arg: tensor<4x4xf32>, %size: tensor<i32>) -> tensor<4x4xf32> {
   %0 = "mhlo.set_dimension_size"(%arg, %size) {dimension = 1 : i64} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32>
   func.return %0 : tensor<4x4xf32>
@@ -2935,55 +2679,6 @@
            window_strides = dense<[3, 1]> : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x2xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>)
   func.return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32>
 }
-// -----
-
-// CHECK: HloModule
-// CHECK: [[CALLED_COMPUTATION:%AsyncOp.*]] ([[ARG:.*]]: f32[10]) -> f32[20] {
-func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32>
-  attributes {execution_thread = "thread"} {
-  %0 = "mhlo.custom_call"(%arg0) {call_target_name = "foo"} : (tensor<10xf32>) -> tensor<20xf32>
-  return %0 : tensor<20xf32>
-}
-
-// CHECK: ENTRY
-func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> {
-  // CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
-  // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]])
-  // CHECK-SAME: calls=[[CALLED_COMPUTATION]]
-  %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread = "thread"} : (tensor<10xf32>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
-  // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]])
-  %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
-  // CHECK: ROOT %{{.*}} = (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]])
-  %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> tensor<20xf32>
-  return %2 : tensor<20xf32>
-}
-
-// -----
-
-// CHECK: HloModule
-// CHECK: [[CALLED_COMPUTATION:%AsyncOp.*]] ([[ARG:.*]]: f32[10]) -> f32[20] {
-func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32>
-  attributes {execution_thread = "thread"} {
-  %1 = "mhlo.custom_call"(%arg0) {call_target_name = "bar"} : (tensor<10xf32>) -> tensor<20xf32>
-  // CHECK: custom-call
-  // CHECK-SAME: custom_call_target="bar"
-  return %1 : tensor<20xf32>
-}
-
-// CHECK: ENTRY
-func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> {
-  // CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
-  // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]]), async_execution_thread="thread", calls=[[CALLED_COMPUTATION]],
-  // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]])
-  // CHECK: ROOT
-  // CHECK-SAME: (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]])
-
-  %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread="thread"} : (tensor<10xf32>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
-  %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
-  %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> tensor<20xf32>
-  return %2 : tensor<20xf32>
-}
-
 
 // -----
 
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir
new file mode 100644
index 0000000..70bf10c
--- /dev/null
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir
@@ -0,0 +1,312 @@
+// RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics %s | FileCheck %s
+
+// CHECK:  HloModule
+func.func @all_gather_0(%arg1: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} {
+  %0 = "mhlo.all_gather"(%arg1) {
+    all_gather_dim = 1 : i64,
+    channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
+    shard_count = 4,
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+    use_global_device_ids
+  } : (tensor<128x32xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> {
+  %0 = "mhlo.async_start"(%arg0) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x128xf32>>
+  %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x128xf32>>) -> tensor<128x128xf32>
+  return %1 : tensor<128x128xf32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0)
+// CHECK: %[[OUTPUT:.*]] = f32[128,128] all-gather-start(f32[128,32] %[[INPUT]])
+// CHECK-SAME: channel_id=1
+// CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}}
+// CHECK-SAME: dimensions={1}
+// CHECK-SAME: use_global_device_ids=true
+// CHECK: ROOT {{.*}} f32[128,128] all-gather-done(f32[128,128] %[[OUTPUT]]
+
+// -----
+
+// CHECK:  HloModule
+func.func @all_reduce_0(%arg0: tensor<10xf32>) -> tensor<10xf32> attributes {execution_thread = "main"} {
+  %0 = "mhlo.all_reduce"(%arg0) ({
+  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
+    "mhlo.return"(%max) : (tensor<f32>) -> ()
+  })
+  {
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+    channel_handle = #mhlo.channel_handle<
+      handle = 5,
+      type = 2
+    >,
+    use_global_device_ids
+  } : (tensor<10xf32>) -> tensor<10xf32>
+  func.return %0 : tensor<10xf32>
+}
+
+func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+  %0 = "mhlo.async_start"(%arg0) {called_computation = @all_reduce_0, execution_thread = "main"} : (tensor<10xf32>) -> !mhlo.async_bundle<tensor<10xf32>, tensor<10xf32>>
+  %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<10xf32>, tensor<10xf32>>) -> tensor<10xf32>
+  return %1 : tensor<10xf32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[INPUT:.*]] = f32[10] parameter(0)
+// CHECK: %[[OUTPUT:.*]] = f32[10] all-reduce-start(f32[10] %[[INPUT]])
+// CHECK-SAME:  channel_id=5
+// CHECK-SAME{LITERAL}:  replica_groups={{0,2,4,6},{1,3,5,7}}
+// CHECK-SAME: use_global_device_ids=true
+// CHECK: ROOT {{.*}} f32[10] all-reduce-done(f32[10] %[[OUTPUT]]
+
+// -----
+
+// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}}
+func.func @all_reduce_0(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) attributes {execution_thread = "main"} {
+  %0:2 = "mhlo.all_reduce"(%arg0, %arg1) ({
+  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
+    "mhlo.return"(%max) : (tensor<f32>) -> ()
+  })
+  {
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+    channel_handle = #mhlo.channel_handle<
+      handle = 5,
+      type = 2
+    >,
+    use_global_device_ids
+  } : (tensor<10xf32>, tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>)
+  func.return %0#0, %0#1 : tensor<10xf32>, tensor<1xf32>
+}
+
+func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) {
+  %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_reduce_0, execution_thread = "main"} : (tensor<10xf32>, tensor<1xf32>) -> !mhlo.async_bundle<tuple<tensor<10xf32>,tensor<1xf32>>, tuple<tensor<10xf32>,tensor<1xf32>>>
+  %1:2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tuple<tensor<10xf32>,tensor<1xf32>>, tuple<tensor<10xf32>,tensor<1xf32>>>) -> (tensor<10xf32>, tensor<1xf32>)
+  return %1#0, %1#1 : tensor<10xf32>, tensor<1xf32>
+}
+
+// -----
+
+// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}}
+func.func @all_gather_0(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) attributes {execution_thread = "main"} {
+  %0:2 = "mhlo.all_gather"(%arg0, %arg1) {
+    all_gather_dim = 1 : i64,
+    channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+    use_global_device_ids
+  } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
+  func.return %0#0, %0#1 : tensor<8x2xf32>, tensor<8x4xf32>
+}
+
+func.func @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) {
+  %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<8x2xf32>, tensor<8x4xf32>) -> !mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>
+  %1:2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tuple<tensor<8x2xf32>,tensor<8x4xf32>>, tuple<tensor<8x2xf32>,tensor<8x4xf32>>>) -> (tensor<8x2xf32>, tensor<8x4xf32>)
+  return %1#0, %1#1 : tensor<8x2xf32>, tensor<8x4xf32>
+}
+
+// -----
+
+// CHECK:  HloModule
+func.func @collective_permute_0(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
+  %0 = "mhlo.collective_permute"(%arg0) {
+    source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
+    channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
+  } : (tensor<128x32xf32>) -> tensor<128x32xf32>
+  func.return %0 : tensor<128x32xf32>
+}
+
+func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
+  %0 = "mhlo.async_start"(%arg0) {called_computation = @collective_permute_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>
+  %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>) -> tensor<128x32xf32>
+  return %1 : tensor<128x32xf32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0)
+// CHECK: %[[OUTPUT:.*]] = f32[128,32] collective-permute-start(f32[128,32] %[[INPUT]])
+// CHECK-SAME:  channel_id=1
+// CHECK-SAME{LITERAL}:  source_target_pairs={{0,1},{1,2},{2,3}}
+// CHECK: ROOT {{.*}} f32[128,32] collective-permute-done(f32[128,32] %[[OUTPUT]]
+
+// -----
+
+// CHECK:  HloModule
+func.func @copy_0(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} {
+  %0 = "mhlo.copy"(%arg0) {cross_program_prefetch_index = 0 : i32} : (tensor<128x32xf32>) -> tensor<128x32xf32>
+  func.return %0 : tensor<128x32xf32>
+}
+
+func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
+  %0 = "mhlo.async_start"(%arg0) {called_computation = @copy_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>
+  %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x32xf32>>) -> tensor<128x32xf32>
+  return %1 : tensor<128x32xf32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0)
+// CHECK: %[[OUTPUT:.*]] = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %[[INPUT]])
+// CHECK-SAME:  cross_program_prefetch_index=0
+// CHECK: ROOT {{.*}} f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %[[OUTPUT]]
+
+// -----
+
+// CHECK:  HloModule
+
+func.func @recv_0(%token: !mhlo.token) -> (!mhlo.token) attributes {execution_thread = "main"} {
+  %0 = "mhlo.recv"(%token) {
+    channel_handle = #mhlo.channel_handle<
+      handle = 5,
+      type = 1  // Device to device channel
+    >,
+    is_host_transfer = false
+  } : (!mhlo.token) -> (!mhlo.token)
+  func.return %0 : !mhlo.token
+}
+
+func.func @main(%token: !mhlo.token) -> (!mhlo.token) {
+  %0 = "mhlo.async_start"(%token) {called_computation = @recv_0, execution_thread = "main"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>
+  %2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>) -> (!mhlo.token)
+  return %2 : !mhlo.token
+}
+
+// CHECK:  ENTRY
+// CHECK:  [[TOKEN:%.*]] = token[] parameter(0)
+// CHECK:  [[RECV:%.*]] = ((), u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5
+// CHECK:  ((), token[]) recv-done(((), u32[], token[]) [[RECV]]), channel_id=5
+
+// -----
+
+// CHECK:  HloModule
+func.func @recv_0(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) attributes {execution_thread = "main"} {
+  %0:2 = "mhlo.recv"(%token) {
+    channel_handle = #mhlo.channel_handle<
+      handle = 5,
+      type = 3  // Host to device channel
+    >,
+    is_host_transfer = true
+  } : (!mhlo.token) -> (tensor<3x4xi32>, !mhlo.token)
+  func.return %0#0, %0#1 : tensor<3x4xi32>, !mhlo.token
+}
+
+func.func @main(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) {
+  %0 = "mhlo.async_start"(%token) {called_computation = @recv_0, execution_thread = "main", mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, tuple<tensor<3x4xi32>, !mhlo.token>, tensor<i32>>
+  %1, %2 = "mhlo.async_done"(%0) {mhlo.sharding = "{{maximal device=0}, {maximal device=0}}"} : (!mhlo.async_bundle<!mhlo.token, tuple<tensor<3x4xi32>, !mhlo.token>, tensor<i32>>) -> (tensor<3x4xi32>, !mhlo.token)
+  return %1, %2 : tensor<3x4xi32>, !mhlo.token
+}
+
+// CHECK:  ENTRY
+// CHECK:  [[TOKEN:%.*]] = token[] parameter(0)
+// CHECK:  [[RECV:%.*]] = (s32[3,4], u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5, is_host_transfer
+// CHECK-SAME:  sharding={
+// CHECK-SAME:    {maximal device=0}, {maximal device=0}, {maximal device=0}
+// CHECK-SAME:  }
+// CHECK:  [[RECV_DONE:%.*]] = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) [[RECV]]), channel_id=5, is_host_transfer
+// CHECK-SAME:  sharding={
+// CHECK-SAME:    {maximal device=0}, {maximal device=0}
+// CHECK-SAME:  }
+// CHECK:  [[TUPLE0:%.*]] = s32[3,4] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=0, sharding={maximal device=0}
+// CHECK:  [[TUPLE1:%.*]] = token[] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=1, sharding={maximal device=0}
+// CHECK:  ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]])
+
+// -----
+
+// CHECK:  HloModule
+func.func @send_0(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} {
+  %0 = "mhlo.send"(%arg, %token) {
+    channel_handle = #mhlo.channel_handle<
+      handle = 5,
+      type = 2  // Device to host channel
+    >,
+    is_host_transfer = true
+  } : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token
+  func.return %0 : !mhlo.token
+}
+
+func.func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token {
+  %0 = "mhlo.async_start"(%arg, %token) {called_computation = @send_0, execution_thread = "main"} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.async_bundle<tuple<tensor<3x4xi32>, !mhlo.token>, !mhlo.token, tensor<i32>>
+  %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tuple<tensor<3x4xi32>, !mhlo.token>, !mhlo.token, tensor<i32>>) -> !mhlo.token
+  return %1 : !mhlo.token
+}
+
+// CHECK:  ENTRY
+// CHECK:  [[ARG:%.*]] = s32[3,4] parameter(0)
+// CHECK:  [[TOKEN:%.*]] = token[] parameter(1)
+// CHECK:  [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5, is_host_transfer
+// CHECK:  ROOT
+// CHECK-SAME:  token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5
+
+// -----
+
+// CHECK:  HloModule
+func.func @send_0(%token: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} {
+  %0 = "mhlo.send"(%token) {
+    channel_handle = #mhlo.channel_handle<
+      handle = 5,
+      type = 1  // Device to device channel
+    >
+  } : (!mhlo.token) -> !mhlo.token
+  func.return %0 : !mhlo.token
+}
+
+func.func @main(%token: !mhlo.token) -> !mhlo.token {
+  %0 = "mhlo.async_start"(%token) {called_computation = @send_0, execution_thread = "main"} : (!mhlo.token) -> !mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>
+  %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<!mhlo.token, !mhlo.token, tensor<i32>>) -> !mhlo.token
+  return %1 : !mhlo.token
+}
+
+// CHECK:  ENTRY
+// CHECK:  [[TOKEN:%.*]] = token[] parameter(0)
+// CHECK:  [[SEND:%.*]] = ((), u32[], token[]) send(() [[UNIT:%.*]], token[] [[TOKEN]]), channel_id=5
+// CHECK:  ROOT
+// CHECK-SAME:  token[] send-done(((), u32[], token[]) [[SEND]]), channel_id=5
+
+// -----
+
+// CHECK: HloModule
+// CHECK: [[CALLED_COMPUTATION:%AsyncOp.*]] ([[ARG:.*]]: f32[10]) -> f32[20] {
+func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32>
+  attributes {execution_thread = "thread"} {
+  %0 = "mhlo.custom_call"(%arg0) {call_target_name = "foo"} : (tensor<10xf32>) -> tensor<20xf32>
+  return %0 : tensor<20xf32>
+}
+
+// CHECK: ENTRY
+func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> {
+  // CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
+  // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]])
+  // CHECK-SAME: calls=[[CALLED_COMPUTATION]]
+  %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread = "thread"} : (tensor<10xf32>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
+  // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]])
+  %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
+  // CHECK: ROOT %{{.*}} = (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]])
+  %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> tensor<20xf32>
+  return %2 : tensor<20xf32>
+}
+
+// -----
+
+// CHECK: HloModule
+// CHECK: [[CALLED_COMPUTATION:%AsyncOp.*]] ([[ARG:.*]]: f32[10]) -> f32[20] {
+func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32>
+  attributes {execution_thread = "thread"} {
+  %1 = "mhlo.custom_call"(%arg0) {call_target_name = "bar"} : (tensor<10xf32>) -> tensor<20xf32>
+  // CHECK: custom-call
+  // CHECK-SAME: custom_call_target="bar"
+  return %1 : tensor<20xf32>
+}
+
+// CHECK: ENTRY
+func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> {
+  // CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
+  // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]]), async_execution_thread="thread", calls=[[CALLED_COMPUTATION]],
+  // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]])
+  // CHECK: ROOT
+  // CHECK-SAME: (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]])
+
+  %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread="thread"} : (tensor<10xf32>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
+  %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> !mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>
+  %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle<tuple<tensor<10xf32>>, tensor<20xf32>, tensor<i32>>) -> tensor<20xf32>
+  return %2 : tensor<20xf32>
+}
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir
new file mode 100644
index 0000000..3663f92
--- /dev/null
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir
@@ -0,0 +1,89 @@
+// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s -o - | FileCheck %s
+
+// This test verifies that the correct shardings are added when a while loop
+// has free variables.
+
+// CHECK-LABEL: HloModule main
+
+// CHECK:      %region_0.7 (arg_tuple.8: (s32[], f32[4], s32[], s32[], f32[4])) -> (s32[], f32[4], s32[], s32[], f32[4]) {
+// CHECK-NEXT:   %arg_tuple.8 = (s32[], f32[4], s32[], s32[], f32[4]) parameter(0)
+// CHECK-SAME:     sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}}
+// CHECK-DAG:    %get-tuple-element.12 = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %arg_tuple.8), index=3
+// CHECK-DAG:    %get-tuple-element.13 = f32[4] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %arg_tuple.8), index=4, sharding={devices=[4]<=[4]}
+// CHECK-DAG:    %add.14 = s32[] add(s32[] %get-tuple-element.9, s32[] %get-tuple-element.12)
+// CHECK-DAG:    %add.15 = f32[4] add(f32[4] %get-tuple-element.10, f32[4] %get-tuple-element.13)
+// CHECK:        ROOT %tuple.16 = (s32[], f32[4], s32[], s32[], f32[4]) tuple(s32[] %add.14, f32[4] %add.15, s32[] %get-tuple-element.11, s32[] %get-tuple-element.12, f32[4] %get-tuple-element.13)
+// CHECK-SAME:     sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}}
+
+// CHECK:      %region_1.17 (arg_tuple.18: (s32[], f32[4], s32[], s32[], f32[4])) -> pred[] {
+// CHECK-NEXT:   %arg_tuple.18 = (s32[], f32[4], s32[], s32[], f32[4]) parameter(0)
+// CHECK-SAME:     sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}}
+// CHECK:        %get-tuple-element.21 = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %arg_tuple.18), index=2
+// CHECK-NEXT:   ROOT %compare.24 = pred[] compare(s32[] %get-tuple-element.19, s32[] %get-tuple-element.21), direction=LT
+
+// CHECK:      ENTRY %main.28 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> f32[4] {
+// CHECK-NEXT:   %Arg_0.1 = s32[] parameter(0)
+// CHECK-NEXT:   %Arg_1.2 = f32[4] parameter(1)
+// CHECK-NEXT:   %constant.4 = s32[] constant(0)
+// CHECK-NEXT:   %constant.5 = s32[] constant(1)
+// CHECK-NEXT:   %Arg_2.3 = f32[4] parameter(2)
+// CHECK-NEXT:   %tuple.6 = (s32[], f32[4], s32[], s32[], f32[4]) tuple(s32[] %Arg_0.1, f32[4] %Arg_1.2, s32[] %constant.4, s32[] %constant.5, f32[4] %Arg_2.3)
+// CHECK-SAME:     sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}}
+// CHECK-NEXT:   %while.25 = (s32[], f32[4], s32[], s32[], f32[4]) while((s32[], f32[4], s32[], s32[], f32[4]) %tuple.6), condition=%region_1.17, body=%region_0.7
+// CHECK-SAME:     sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}}
+// CHECK-NEXT:   %get-tuple-element.26 = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %while.25), index=0, sharding={replicated}
+// CHECK-NEXT:   ROOT %get-tuple-element.27 = f32[4] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %while.25), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
+
+func.func @main(%arg0: tensor<i32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> tensor<4xf32> {
+  %0 = mhlo.constant dense<0> : tensor<i32>
+  %1 = mhlo.constant dense<1> : tensor<i32>
+  %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor<i32>, tensor<4xf32>
+    attributes {mhlo.sharding = "{{replicated},{devices=[2,2]<=[4] last_tile_dim_replicate}}"}
+    cond {
+    %3 = mhlo.compare LT, %iterArg, %0 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    mhlo.return %3 : tensor<i1>
+  } do {
+    %3 = mhlo.add %iterArg, %1 : tensor<i32>
+    %4 = mhlo.add %iterArg_0, %arg2 : tensor<4xf32>
+    mhlo.return %3, %4: tensor<i32>, tensor<4xf32>
+  }
+  func.return %2#1 : tensor<4xf32>
+}
+
+// -----
+
+// This test verifies that a value captured multiple times is only lifted once
+// and all its uses are replaced. Also verifies that no sharding is added to
+// region parameters or root when the while doesn't have a sharding.
+
+// CHECK-LABEL: HloModule main
+
+// CHECK:      %region_0.5 (arg_tuple.6: (s32[], f32[4], s32[])) -> (s32[], f32[4], s32[]) {
+// CHECK-NEXT:   %arg_tuple.6 = (s32[], f32[4], s32[]) parameter(0)
+// CHECK:        %get-tuple-element.9 = s32[] get-tuple-element((s32[], f32[4], s32[]) %arg_tuple.6), index=2
+// CHECK:        %add.10 = s32[] add(s32[] %get-tuple-element.7, s32[] %get-tuple-element.9)
+// CHECK:        ROOT %tuple.11 = (s32[], f32[4], s32[]) tuple(s32[] %add.10, f32[4] %get-tuple-element.8, s32[] %get-tuple-element.9)
+
+// CHECK:      %region_1.12 (arg_tuple.13: (s32[], f32[4], s32[])) -> pred[] {
+// CHECK-NEXT:   %arg_tuple.13 = (s32[], f32[4], s32[]) parameter(0)
+// CHECK:        %get-tuple-element.16 = s32[] get-tuple-element((s32[], f32[4], s32[]) %arg_tuple.13), index=2
+// CHECK:        ROOT %compare.17 = pred[] compare(s32[] %get-tuple-element.14, s32[] %get-tuple-element.16), direction=LT
+
+// CHECK:      ENTRY %main.21 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: s32[]) -> f32[4] {
+// CHECK-NEXT:   %Arg_0.1 = s32[] parameter(0)
+// CHECK-NEXT:   %Arg_1.2 = f32[4] parameter(1)
+// CHECK-NEXT:   %Arg_2.3 = s32[] parameter(2)
+// CHECK-NEXT:   %tuple.4 = (s32[], f32[4], s32[]) tuple(s32[] %Arg_0.1, f32[4] %Arg_1.2, s32[] %Arg_2.3)
+// CHECK-NEXT:   %while.18 = (s32[], f32[4], s32[]) while((s32[], f32[4], s32[]) %tuple.4), condition=%region_1.12, body=%region_0.5
+
+func.func @main(%arg0: tensor<i32>, %arg1: tensor<4xf32>, %arg2: tensor<i32>) -> tensor<4xf32> {
+  %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor<i32>, tensor<4xf32>
+    cond {
+    %3 = mhlo.compare LT, %iterArg, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    mhlo.return %3 : tensor<i1>
+  } do {
+    %3 = mhlo.add %iterArg, %arg2 : tensor<i32>
+    mhlo.return %3, %iterArg_0: tensor<i32>, tensor<4xf32>
+  }
+  func.return %2#1 : tensor<4xf32>
+}
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc b/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc
index 8cff1c9..7c07582 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc
@@ -16,26 +16,42 @@
 
 #include <memory>
 #include <utility>
+#include <vector>
 
 #include "absl/log/log.h"
+#include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/SMLoc.h"
 #include "llvm/Support/SourceMgr.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "xla/client/xla_builder.h"
+#include "xla/client/xla_computation.h"
+#include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_input_output_alias_config.h"
+#include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/mlir_hlo/mhlo/IR/register.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/service/hlo_proto_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
 #include "xla/translate/mhlo_to_hlo/type_to_shape.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 
 constexpr char kParameterReplicationAttr[] = "mhlo.parameter_replication";
 
@@ -123,6 +139,8 @@
                 mlir::cast<mlir::BoolAttr>(b).getValue());
 
   auto hlo_module = computation.proto();
+  mlir::StringRef module_name = module.getName() ? *module.getName() : "main";
+  hlo_module.set_name(module_name.str());
   hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
 
   return absl::OkStatus();
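
The hunk above names the emitted HloModule after the MLIR module, falling back to "main" when the module is anonymous. A minimal sketch of that defaulting rule, using std::optional as a stand-in for mlir::ModuleOp::getName(); the helper name is hypothetical:

#include <optional>
#include <string>

// Hypothetical helper mirroring the fallback above: prefer the MLIR module's
// symbol name, otherwise use "main".
std::string HloModuleNameFor(const std::optional<std::string>& mlir_module_name) {
  return mlir_module_name.value_or("main");
}
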
diff --git a/third_party/xla/xla/tsl/c/tsl_status.cc b/third_party/xla/xla/tsl/c/tsl_status.cc
index fea8943..75b9481 100644
--- a/third_party/xla/xla/tsl/c/tsl_status.cc
+++ b/third_party/xla/xla/tsl/c/tsl_status.cc
@@ -35,7 +35,7 @@
     return;
   }
   s->status =
-      Status(static_cast<absl::StatusCode>(code), tsl::StringPiece(msg));
+      Status(static_cast<absl::StatusCode>(code), absl::string_view(msg));
 }
 
 void TSL_SetPayload(TSL_Status* s, const char* key, const char* value) {
diff --git a/third_party/xla/xla/tsl/concurrency/BUILD b/third_party/xla/xla/tsl/concurrency/BUILD
index 0363d15..578b6ce 100644
--- a/third_party/xla/xla/tsl/concurrency/BUILD
+++ b/third_party/xla/xla/tsl/concurrency/BUILD
@@ -31,7 +31,6 @@
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:platform_port",
     ],
 )
 
@@ -72,6 +71,7 @@
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_benchmark",
         "@local_tsl//tsl/platform:test_main",
     ],
 )
diff --git a/third_party/xla/xla/tsl/concurrency/async_value.cc b/third_party/xla/xla/tsl/concurrency/async_value.cc
index dd26e04..fa3f058 100644
--- a/third_party/xla/xla/tsl/concurrency/async_value.cc
+++ b/third_party/xla/xla/tsl/concurrency/async_value.cc
@@ -63,12 +63,6 @@
 
 std::atomic<size_t> AsyncValue::total_allocated_async_values_;
 
-const AsyncValue::TypeInfo& AsyncValue::GetTypeInfo() const {
-  TypeInfoTable* type_info_table = AsyncValue::GetTypeInfoTableSingleton();
-  DCHECK_NE(type_id_, 0);
-  return (*type_info_table)[type_id_ - 1];
-}
-
 // This is called when the value is set into the ConcreteAsyncValue buffer, or
 // when the IndirectAsyncValue is forwarded to an available AsyncValue, and we
 // need to change our state and clear out the notifications. The current state
diff --git a/third_party/xla/xla/tsl/concurrency/async_value.h b/third_party/xla/xla/tsl/concurrency/async_value.h
index 372db4f..30e0d8e 100644
--- a/third_party/xla/xla/tsl/concurrency/async_value.h
+++ b/third_party/xla/xla/tsl/concurrency/async_value.h
@@ -21,17 +21,18 @@
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
-#include <iostream>
 #include <memory>
+#include <new>
 #include <type_traits>
 #include <utility>
 
+#include "absl/base/optimization.h"
 #include "absl/functional/any_invocable.h"
 #include "absl/status/status.h"
 #include "absl/types/span.h"
 #include "xla/tsl/concurrency/concurrent_vector.h"
 #include "xla/tsl/concurrency/ref_count.h"
-#include "tsl/platform/mem.h"
+#include "tsl/platform/logging.h"
 
 namespace tsl {
 
@@ -164,8 +165,8 @@
   // process. This is intended for debugging/assertions only, and shouldn't be
   // used for mainline logic in the runtime.
   static size_t GetNumAsyncValueInstances() {
-    assert(AsyncValueAllocationTrackingEnabled() &&
-           "AsyncValue instance tracking disabled!");
+    DCHECK(AsyncValueAllocationTrackingEnabled())
+        << "AsyncValue instance tracking disabled!";
     return total_allocated_async_values_.load(std::memory_order_relaxed);
   }
 
@@ -418,8 +419,9 @@
  private:
   // Information about a ConcreteAsyncValue<T> subclass.
   struct TypeInfo {
-    // Destructor returns the size of the derived AsyncValue to be deallocated.
-    using DestructorFn = size_t (*)(AsyncValue*);
+    // Destructor returns the size and alignment of the derived AsyncValue to
+    // be deallocated.
+    using DestructorFn = std::pair<size_t, std::align_val_t> (*)(AsyncValue*);
     using GetErrorFn = const absl::Status& (*)(const AsyncValue*);
     using SetErrorFn = void (*)(AsyncValue*, absl::Status);
     using HasDataFn = bool (*)(const AsyncValue*);
@@ -433,9 +435,9 @@
   template <typename Derived>
   static TypeInfo MakeTypeInfo() {
     return TypeInfo{
-        [](AsyncValue* v) {
+        [](AsyncValue* v) -> std::pair<size_t, std::align_val_t> {
           static_cast<Derived*>(v)->~Derived();
-          return sizeof(Derived);
+          return {sizeof(Derived), std::align_val_t{alignof(Derived)}};
         },
         [](const AsyncValue* v) -> const absl::Status& {
           return static_cast<const Derived*>(v)->GetError();
@@ -454,14 +456,17 @@
   template <typename T>
   const T& GetConcreteValue() const;
 
-  // Get the TypeInfo instance for this AsyncValue.
-  const TypeInfo& GetTypeInfo() const;
-
-  using TypeInfoTable = internal::ConcurrentVector<TypeInfo>;
-
   // Returns the TypeInfoTable instance (there is one per process).
+  using TypeInfoTable = internal::ConcurrentVector<TypeInfo>;
   static TypeInfoTable* GetTypeInfoTableSingleton();
 
+  // Get the TypeInfo instance for this AsyncValue.
+  const TypeInfo& GetTypeInfo() const {
+    TypeInfoTable* type_info_table = AsyncValue::GetTypeInfoTableSingleton();
+    DCHECK_NE(type_id_, 0) << "TypeId must be set";
+    return (*type_info_table)[type_id_ - 1];
+  }
+
   void EnqueueWaiter(absl::AnyInvocable<void()> waiter,
                      WaitersAndState old_value);
 
@@ -569,7 +574,7 @@
 
   // Return the underlying error. IsError() must return true.
   const absl::Status& GetError() const {
-    assert(IsError());
+    DCHECK(IsError());
     return data_store_.error();
   }
 
@@ -579,12 +584,12 @@
   }
 
   const T& get() const {
-    assert(HasData());
+    DCHECK(HasData());
     return data_store_.data();
   }
 
   T& get() {
-    assert(HasData());
+    DCHECK(HasData());
     return data_store_.data();
   }
 
@@ -629,7 +634,7 @@
     }
 
     void SetError(State s, absl::Status status) {
-      assert(s == State::kUnconstructed || s == State::kConstructed);
+      DCHECK(s == State::kUnconstructed || s == State::kConstructed);
       if (s == State::kConstructed) {
         data_.~T();
       }
@@ -677,13 +682,13 @@
     }
 
     void SetError(State s, absl::Status status) {
-      assert(!error_);
+      DCHECK(!error_);
       error_ = std::make_unique<absl::Status>(std::move(status));
     }
 
     template <typename... Args>
     void EmplaceData(Args&&... args) {
-      assert(!HasData());
+      DCHECK(!HasData());
       new (&data_) T(std::forward<Args>(args)...);
       has_data_ = true;
     }
@@ -807,8 +812,8 @@
 };
 
 inline AsyncValue::~AsyncValue() {
-  assert(waiters_and_state_.load().waiter() == nullptr &&
-         "An async value with waiters should never have refcount of zero");
+  DCHECK_EQ(waiters_and_state_.load().waiter(), nullptr)
+      << "An async value with waiters should never have refcount of zero";
   if (AsyncValueAllocationTrackingEnabled() && is_refcounted_)
     total_allocated_async_values_.fetch_sub(1, std::memory_order_relaxed);
 
@@ -853,7 +858,7 @@
 #endif
 
   if (count > 0) {
-    assert(refcount_.load(std::memory_order_relaxed) > 0);
+    DCHECK_GT(refcount_.load(std::memory_order_relaxed), 0);
     // Increasing the reference counter can always be done with
     // memory_order_relaxed: New references to an object can only be formed from
     // an existing reference, and passing an existing reference from one thread
@@ -871,7 +876,7 @@
   if (!is_refcounted_) return;
 #endif
 
-  assert(refcount_.load(std::memory_order_relaxed) > 0);
+  DCHECK_GT(refcount_.load(std::memory_order_relaxed), 0);
   // We expect that `count` argument will often equal the actual reference count
   // here; optimize for that. If `count` == reference count, only an acquire
   // barrier is needed to prevent the effects of the deletion from leaking
@@ -894,8 +899,8 @@
 const T& AsyncValue::GetConcreteValue() const {
   // Make sure both T (the stored type) and BaseT have vtable_ptr or
   // neither have the vtable_ptr.
-  assert(std::is_polymorphic<T>::value == has_vtable_);
-  assert(IsTypeIdCompatible<T>() && "Incorrect accessor");
+  DCHECK_EQ(std::is_polymorphic<T>::value, has_vtable_);
+  DCHECK(IsTypeIdCompatible<T>()) << "Incorrect accessor";
 
   const char* this_ptr = reinterpret_cast<const char*>(this);
   return *reinterpret_cast<const T*>(this_ptr + AsyncValue::kDataOffset);
@@ -909,32 +914,27 @@
   switch (kind()) {
     case Kind::kConcrete:
 #ifndef NDEBUG
-      // TODO(ezhulenev): Use `DLOG_IF` when absl logging is available.
       if (!GetTypeInfo().has_data(this)) {
-        std::cerr << "Cannot call get() when ConcreteAsyncValue"  // Crash OK
-                  << " isn't constructed; state: " << s.DebugString() << ","
-                  << " error message: "
-                  << (IsError() ? GetError().message() : "None");
-        std::abort();
+        LOG(FATAL) << "Cannot call get() when ConcreteAsyncValue"
+                   << " isn't constructed; state: " << s.DebugString() << ","
+                   << " error message: "
+                   << (IsError() ? GetError().message() : "None");
       }
 #endif  // NDEBUG
       return GetConcreteValue<T>();
     case Kind::kIndirect:
 #ifndef NDEBUG
-      // TODO(ezhulenev): Use `DLOG_IF` when absl logging is available.
       if (s != State::kConcrete) {
-        std::cerr << "Cannot call get() when IndirectAsyncValue"  // Crash OK
-                  << " isn't concrete; state: " << s.DebugString() << ","
-                  << " error message: "
-                  << (IsError() ? GetError().message() : "None");
-        std::abort();
+        LOG(FATAL) << "Cannot call get() when IndirectAsyncValue"
+                   << " isn't concrete; state: " << s.DebugString() << ","
+                   << " error message: "
+                   << (IsError() ? GetError().message() : "None");
       }
 #endif  // NDEBUG
       auto* iv_value = static_cast<const IndirectAsyncValue*>(this)->value_;
-      assert(iv_value && "Indirect value not resolved");
+      DCHECK(iv_value) << "Indirect value not resolved";
       return iv_value->get<T>();
   }
-  assert(false && "unexpected AsyncValue kind");
 }
 
 template <typename T>
@@ -943,14 +943,14 @@
 }
 
 inline void AsyncValue::SetStateConcrete() {
-  assert(IsConstructed() && kind() == Kind::kConcrete);
+  DCHECK(IsConstructed() && kind() == Kind::kConcrete);
   NotifyAvailable(State::kConcrete);
 }
 
 template <typename T, typename... Args>
 void AsyncValue::emplace(Args&&... args) {
-  assert(GetTypeId<T>() == type_id_ && "Incorrect accessor");
-  assert(IsUnconstructed() && kind() == Kind::kConcrete);
+  DCHECK_EQ(GetTypeId<T>(), type_id_) << "Incorrect accessor";
+  DCHECK(IsUnconstructed() && kind() == Kind::kConcrete);
 
   static_cast<internal::ConcreteAsyncValue<T>*>(this)->emplace(
       std::forward<Args>(args)...);
@@ -968,7 +968,7 @@
       // Unresolved IndirectAsyncValues are not errors.
       if (!iv_value) return nullptr;
 
-      assert(iv_value->kind() != Kind::kIndirect);
+      DCHECK(iv_value->kind() != Kind::kIndirect);
       return iv_value->GetErrorIfPresent();
     }
   }
@@ -976,7 +976,7 @@
 
 inline const absl::Status& AsyncValue::GetError() const {
   auto* result = GetErrorIfPresent();
-  assert(result && "Cannot call GetError() when error isn't available.");
+  DCHECK(result) << "Cannot call GetError() when error isn't available.";
   return *result;
 }
 
@@ -988,7 +988,7 @@
   auto old_value = waiters_and_state_.load(std::memory_order_acquire);
   if (old_value.state() == State::kConcrete ||
       old_value.state() == State::kError) {
-    assert(old_value.waiter() == nullptr);
+    DCHECK_EQ(old_value.waiter(), nullptr);
     waiter();
     return;
   }
@@ -1003,7 +1003,7 @@
   auto old_value = waiters_and_state_.load(std::memory_order_acquire);
   if (old_value.state() == State::kConcrete ||
       old_value.state() == State::kError) {
-    assert(old_value.waiter() == nullptr);
+    DCHECK_EQ(old_value.waiter(), nullptr);
     executor.Execute(std::forward<Waiter>(waiter));
     return;
   }
@@ -1018,17 +1018,30 @@
   // Copy `is_refcounted` flag before destroying the async value object.
   bool was_ref_counted = is_refcounted_;
 
-  if (kind() == Kind::kIndirect) {
+  if (ABSL_PREDICT_FALSE(kind() == Kind::kIndirect)) {
     // Depending on what the benchmarks say, it might make sense to remove this
     // explicit check and instead make ~IndirectAsyncValue go through the
     // GetTypeInfo().destructor case below.
     static_cast<IndirectAsyncValue*>(this)->~IndirectAsyncValue();
-    if (was_ref_counted) port::AlignedFree(this);
+    if (was_ref_counted) {
+#if defined(__cpp_sized_deallocation)
+      ::operator delete(this, sizeof(IndirectAsyncValue),
+                        std::align_val_t{alignof(IndirectAsyncValue)});
+#else   // defined(__cpp_sized_deallocation)
+      ::operator delete(this, std::align_val_t{alignof(IndirectAsyncValue)});
+#endif  // defined(__cpp_sized_deallocation)
+    }
     return;
   }
 
-  GetTypeInfo().destructor(this);
-  if (was_ref_counted) port::AlignedFree(this);
+  auto [size, alignment] = GetTypeInfo().destructor(this);
+  if (was_ref_counted) {
+#if defined(__cpp_sized_deallocation)
+    ::operator delete(this, size, alignment);
+#else   // defined(__cpp_sized_deallocation)
+    ::operator delete(this, alignment);
+#endif  // defined(__cpp_sized_deallocation)
+  }
 }
 
 inline bool AsyncValue::IsUnique() const {
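
The Destroy() hunk above replaces port::AlignedMalloc/AlignedFree with plain operator new/delete calls that carry the alignment and, when the compiler supports sized deallocation, the object size returned by the type-erased destructor. A minimal, self-contained sketch of that allocate/destroy pairing; the Payload type and helpers are hypothetical, not XLA API:

#include <cstddef>
#include <new>

struct alignas(64) Payload {
  char bytes[64];
};

Payload* MakePayload() {
  // Over-aligned allocation; the same alignment must be passed back at delete time.
  void* buf = ::operator new(sizeof(Payload), std::align_val_t{alignof(Payload)});
  return new (buf) Payload{};
}

void DestroyPayload(Payload* p) {
  p->~Payload();
#if defined(__cpp_sized_deallocation)
  // Sized + aligned deallocation lets the allocator skip a size lookup.
  ::operator delete(p, sizeof(Payload), std::align_val_t{alignof(Payload)});
#else
  ::operator delete(p, std::align_val_t{alignof(Payload)});
#endif
}
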
diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ref.h b/third_party/xla/xla/tsl/concurrency/async_value_ref.h
index 625c908..65fb655 100644
--- a/third_party/xla/xla/tsl/concurrency/async_value_ref.h
+++ b/third_party/xla/xla/tsl/concurrency/async_value_ref.h
@@ -18,6 +18,7 @@
 
 #include <algorithm>
 #include <cstddef>
+#include <new>
 #include <string_view>
 #include <type_traits>
 #include <utility>
@@ -32,7 +33,6 @@
 #include "xla/tsl/concurrency/async_value.h"
 #include "xla/tsl/concurrency/ref_count.h"
 #include "tsl/platform/logging.h"
-#include "tsl/platform/mem.h"
 
 namespace tsl {
 
@@ -88,8 +88,8 @@
   AsyncValueRef(const AsyncValueRef&) = default;
   AsyncValueRef& operator=(const AsyncValueRef&) = default;
 
-  AsyncValueRef(AsyncValueRef&&) = default;
-  AsyncValueRef& operator=(AsyncValueRef&&) = default;
+  AsyncValueRef(AsyncValueRef&&) noexcept = default;
+  AsyncValueRef& operator=(AsyncValueRef&&) noexcept = default;
 
   explicit AsyncValueRef(RCReference<AsyncValue> value)
       : value_(std::move(value)) {}
@@ -135,7 +135,7 @@
   // Return true if the AsyncValue contains a concrete value.
   bool IsConcrete() const { return value_->IsConcrete(); }
 
-  // Return true if state is kUnconstructed.
+  // Return true if state is `kUnconstructed`.
   bool IsUnconstructed() const { return value_->IsUnconstructed(); }
 
   // Return the stored value. The AsyncValueRef must be available.
@@ -876,7 +876,7 @@
 
 template <typename T, typename... Args>
 T* AllocateAndConstruct(Args&&... args) {
-  void* buf = port::AlignedMalloc(sizeof(T), alignof(T));
+  void* buf = ::operator new(sizeof(T), std::align_val_t{alignof(T)});
   return PlacementConstruct<T, Args...>(buf, std::forward<Args>(args)...);
 }
 
@@ -953,13 +953,13 @@
   AsyncValueOwningRef(const AsyncValueOwningRef&) = delete;
   AsyncValueOwningRef& operator=(const AsyncValueOwningRef&) = delete;
 
-  AsyncValueOwningRef& operator=(AsyncValueOwningRef&& other) {
+  AsyncValueOwningRef& operator=(AsyncValueOwningRef&& other) noexcept {
     Destroy();
     std::swap(value_, other.value_);
     return *this;
   }
 
-  AsyncValueOwningRef(AsyncValueOwningRef&& other) {
+  AsyncValueOwningRef(AsyncValueOwningRef&& other) noexcept {
     Destroy();
     std::swap(value_, other.value_);
   }
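
Marking the move operations above noexcept matters for container use: std::vector, for example, only moves elements during reallocation when the element's move constructor is noexcept (via std::move_if_noexcept) and copies otherwise. A small illustration with a hypothetical handle type:

#include <type_traits>
#include <utility>
#include <vector>

class Handle {
 public:
  Handle() = default;
  Handle(Handle&& other) noexcept : ptr_(std::exchange(other.ptr_, nullptr)) {}
  Handle& operator=(Handle&& other) noexcept {
    ptr_ = std::exchange(other.ptr_, nullptr);
    return *this;
  }

 private:
  void* ptr_ = nullptr;
};

static_assert(std::is_nothrow_move_constructible<Handle>::value,
              "vector growth will move, not copy, Handle elements");

int main() {
  std::vector<Handle> handles(4);
  handles.emplace_back();  // Reallocation moves the existing elements.
  return 0;
}
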
diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc b/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc
index 2c4ce86..646b05b 100644
--- a/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc
+++ b/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc
@@ -16,6 +16,7 @@
 #include "xla/tsl/concurrency/async_value_ref.h"
 
 #include <any>
+#include <array>
 #include <atomic>
 #include <cstddef>
 #include <cstdint>
@@ -30,6 +31,7 @@
 #include "xla/tsl/concurrency/async_value.h"
 #include "xla/tsl/concurrency/ref_count.h"
 #include "tsl/platform/test.h"
+#include "tsl/platform/test_benchmark.h"
 
 namespace tsl {
 
@@ -787,4 +789,25 @@
   EXPECT_EQ(counter, 1 + 2 + 3);
 }
 
+//===----------------------------------------------------------------------===//
+// Performance benchmarks below
+//===----------------------------------------------------------------------===//
+
+template <size_t size>
+static void BM_MakeConstructed(benchmark::State& state) {
+  for (auto _ : state) {
+    auto ref = MakeConstructedAsyncValueRef<std::array<char, size>>();
+    benchmark::DoNotOptimize(ref);
+  }
+}
+
+BENCHMARK(BM_MakeConstructed<1>);
+BENCHMARK(BM_MakeConstructed<4>);
+BENCHMARK(BM_MakeConstructed<8>);
+BENCHMARK(BM_MakeConstructed<16>);
+BENCHMARK(BM_MakeConstructed<32>);
+BENCHMARK(BM_MakeConstructed<64>);
+BENCHMARK(BM_MakeConstructed<128>);
+BENCHMARK(BM_MakeConstructed<256>);
+
 }  // namespace tsl
diff --git a/third_party/xla/xla/tsl/concurrency/ref_count.h b/third_party/xla/xla/tsl/concurrency/ref_count.h
index 664dd95..4ea65ee 100644
--- a/third_party/xla/xla/tsl/concurrency/ref_count.h
+++ b/third_party/xla/xla/tsl/concurrency/ref_count.h
@@ -124,7 +124,7 @@
  public:
   RCReference() : pointer_(nullptr) {}
 
-  RCReference(RCReference&& other) : pointer_(other.pointer_) {
+  RCReference(RCReference&& other) noexcept : pointer_(other.pointer_) {
     other.pointer_ = nullptr;
   }
 
@@ -132,7 +132,7 @@
     if (pointer_) pointer_->AddRef();
   }
 
-  RCReference& operator=(RCReference&& other) {
+  RCReference& operator=(RCReference&& other) noexcept {
     reset(other.pointer_);
     other.pointer_ = nullptr;
     return *this;
@@ -187,7 +187,7 @@
 
   explicit operator bool() const { return pointer_ != nullptr; }
 
-  void swap(RCReference& other) {
+  void swap(RCReference& other) noexcept {
     using std::swap;
     swap(pointer_, other.pointer_);
   }
@@ -256,7 +256,7 @@
 }
 // For ADL style swap.
 template <typename T>
-void swap(RCReference<T>& a, RCReference<T>& b) {
+void swap(RCReference<T>& a, RCReference<T>& b) noexcept {
   a.swap(b);
 }
 
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD
index 5e727c6..a198b4c 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD
@@ -13,11 +13,13 @@
 
 cc_library(
     name = "coordination_service_error_util",
+    srcs = ["coordination_service_error_util.cc"],
     hdrs = ["coordination_service_error_util.h"],
     deps = [
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:cord",
+        "@local_tsl//tsl/platform:regexp",
         "@local_tsl//tsl/protobuf:coordination_service_proto_cc",
     ],
 )
@@ -28,6 +30,7 @@
     deps = [
         ":coordination_service_error_util",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:test",
         "@local_tsl//tsl/platform:test_main",
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc
index dd53c80..c3ba0da 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc
@@ -246,10 +246,6 @@
     void Disconnect(uint64_t grace_period_duration_us);
     absl::Status RecordHeartbeat(uint64_t task_incarnation);
     int64_t TimeSinceLastHeartbeatMs();
-    // This denotes the deadline after which we stop accepting heartbeats from a
-    // disconnected task. This grace period accounts for the lag time between
-    // the service recording the state change and the agent stopping heartbeats.
-    uint64_t GetDisconnectedGracePeriodMicros();
     void SetError(absl::Status status);
     DeviceInfo GetDeviceInfo() { return devices_; }
     void CollectDeviceInfo(const DeviceInfo& devices) { devices_ = devices; }
@@ -260,6 +256,11 @@
     absl::flat_hash_set<std::string> GetOngoingBarriers();
     void JoinBarrier(std::string_view barrier_id);
     void ExitBarrier(std::string_view barrier_id);
+    // Returns true if the task has been disconnected beyond the grace period
+    // and no further agent requests are expected. Note that the grace period
+    // accounts for the lag time between the service recording the state change
+    // and the agent stopping heartbeats/error polling.
+    bool IsDisconnectedBeyondGracePeriod();
 
    private:
     // Incarnation ID for CPU:0 on remote task.
@@ -269,9 +270,10 @@
     absl::Status status_;
     absl::Mutex last_heartbeat_mu_;
     uint64_t last_heartbeat_us_ ABSL_GUARDED_BY(last_heartbeat_mu_);
-    // This denotes the deadline after which we stop accepting heartbeats from a
-    // disconnected task. This grace period accounts for the lag time between
-    // the service recording the state change and the agent stopping heartbeats.
+    // This denotes the deadline after which we stop accepting heartbeats or
+    // error polling requests from a disconnected task. This grace period
+    // accounts for the lag time between the service recording the state change
+    // and the agent stopping heartbeats/error polling.
     uint64_t disconnect_grace_period_us_ = 0;
     DeviceInfo devices_;
     // For now, we assume there won't be many simultaneous barriers so we simply
@@ -392,11 +394,6 @@
   return (Env::Default()->NowMicros() - last_heartbeat_us_) / 1000;
 }
 
-uint64_t CoordinationServiceStandaloneImpl::TaskState::
-    GetDisconnectedGracePeriodMicros() {
-  return disconnect_grace_period_us_;
-}
-
 absl::flat_hash_set<std::string>
 CoordinationServiceStandaloneImpl::TaskState::GetOngoingBarriers() {
   return ongoing_barriers_for_task_;
@@ -412,6 +409,12 @@
   ongoing_barriers_for_task_.erase(barrier_id);
 }
 
+bool CoordinationServiceStandaloneImpl::TaskState::
+    IsDisconnectedBeyondGracePeriod() {
+  return GetState() == CoordinatedTaskState::TASKSTATE_DISCONNECTED &&
+         Env::Default()->NowMicros() > disconnect_grace_period_us_;
+}
+
 void CoordinationServiceStandaloneImpl::SetDeviceAggregationFunction(
     std::function<DeviceInfo(const DeviceInfo& devices)>
         post_aggregate_device_fn) {
@@ -551,7 +554,8 @@
                 absl::StrAppend(
                     &error_message,
                     "Total Number of tasks already at the barrier: ",
-                    barrier->tasks_at_barrier.size() - pending_task_count,
+                    barrier->tasks_at_barrier.size() - pending_task_count, "/",
+                    barrier->tasks_at_barrier.size(),
                     ". Timed out task names:\n%s", pending_tasks);
               }
               const absl::Status error = MakeCoordinationError(
@@ -890,14 +894,10 @@
     }
     if (!cluster_state_[task_name]->GetStatus().ok()) {
       return cluster_state_[task_name]->GetStatus();
-    } else if (cluster_state_[task_name]->GetState() ==
-                   CoordinatedTaskState::TASKSTATE_DISCONNECTED &&
-               // We accept heartbeats for a short grace period to account for
-               // the lag time between the service recording the state change
-               // and the agent stopping heartbeats.
-               Env::Default()->NowMicros() >
-                   cluster_state_[task_name]
-                       ->GetDisconnectedGracePeriodMicros()) {
+    } else if (cluster_state_[task_name]->IsDisconnectedBeyondGracePeriod()) {
+      // We accept heartbeats for a short grace period to account for the lag
+      // time between the service recording the state change and the agent
+      // stopping heartbeats.
       return MakeCoordinationError(absl::InvalidArgumentError(absl::StrCat(
           "Task with task_name=", task_name,
           " must be registered before sending heartbeat messages")));
@@ -1193,11 +1193,26 @@
     return;
   }
 
-  if (cluster_state_[task_name]->GetState() !=
-      CoordinatedTaskState::TASKSTATE_CONNECTED) {
-    done(MakeCoordinationError(absl::InvalidArgumentError(
+  // On the agent side, the error polling thread will only be started when the
+  // task is connected, but by the time the request is processed by the service,
+  // the task state may have changed due to actions by the service or the main
+  // thread on the agent. As a way to handle this, we accept error polling for a
+  // short grace period. After the grace period, the service will return an
+  // error to the task.
+  if (cluster_state_[task_name]->IsDisconnectedBeyondGracePeriod()) {
+    done(MakeCoordinationError(absl::FailedPreconditionError(
         absl::StrCat("Task (", task_name,
-                     ") that has not been registered polling for errors."))));
+                     ") that has not been registered or has disconnected "
+                     "polling for errors."))));
+    return;
+  }
+
+  if (cluster_state_[task_name]->GetState() ==
+      CoordinatedTaskState::TASKSTATE_ERROR) {
+    done(MakeCoordinationError(absl::FailedPreconditionError(absl::StrCat(
+        "Task (", task_name,
+        ") that is already in error state polling for errors. Current error: ",
+        cluster_state_[task_name]->GetStatus().ToString()))));
     return;
   }
 
@@ -1471,9 +1486,11 @@
       return;
     }
   }
-  LOG(ERROR) << "An error is encountered. Sending the error as a response to "
-                "all error polling requests: "
-             << error;
+  if (!absl::IsCancelled(error)) {
+    VLOG(2) << "An error is encountered. Sending the error as a response to "
+               "all error polling requests: "
+            << error;
+  }
   std::vector<std::string> missing_tasks;
   {
     absl::MutexLock l(&state_mu_);
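
The service-side changes above funnel both heartbeats and error-polling requests through a single grace-period check. A minimal sketch of that rule; the types and names here are hypothetical stand-ins for the TaskState bookkeeping:

#include <cstdint>

enum class TaskState { kConnected, kDisconnected, kError };

struct TaskRecord {
  TaskState state = TaskState::kConnected;
  uint64_t disconnect_grace_deadline_us = 0;  // Absolute deadline in microseconds.
};

// Mirrors IsDisconnectedBeyondGracePeriod(): reject a request only once the
// task is disconnected *and* the grace deadline has passed, so late heartbeats
// or error polls sent before the agent notices the state change still succeed.
bool ShouldRejectRequest(const TaskRecord& task, uint64_t now_us) {
  return task.state == TaskState::kDisconnected &&
         now_us > task.disconnect_grace_deadline_us;
}
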
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc
index 8bcf451..617da59 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc
@@ -412,21 +412,19 @@
 }
 
 void CoordinationServiceAgentImpl::StartPollingForError() {
-  LOG(INFO) << "Polling error from coordination service. This thread "
-               "will run until an error is encountered or the agent is "
-               "shutdown.";
+  LOG(INFO) << "Polling for error from coordination service. This thread will "
+               "run until an error is encountered or the agent is shutdown.";
   absl::Status status = PollForError();
   CHECK(!status.ok()) << "PollForError returned OK status. Should "
                          "always return an error.";
   if (absl::IsCancelled(status)) {
-    LOG(INFO) << "Stop polling error from coordination service because "
-                 "the service or the agent is shutting down."
-              << status;
+    LOG(INFO) << "Cancelling error polling because the service or the agent is "
+                 "shutting down.";
+    // Return early; there is no need to set the error.
     return;
   }
-  LOG(INFO) << "Error returned from coordination service after polling: "
-            << status;
-
+  LOG(ERROR) << "An error is returned from coordination service (this can be "
+                "an error from this or another task).";
   SetError(status);
 }
 
@@ -440,10 +438,6 @@
   n.WaitForNotification();
   CHECK(!status.ok())
       << "PollForError returned OK status. Should always return an error.";
-  LOG(ERROR)
-      << "PollForError returned with status (this can be an error from this or "
-         "another task): "
-      << status;
   return status;
 }
 
@@ -628,7 +622,7 @@
     } else {
       LOG(ERROR)
           << "Failed to disconnect from coordination service with status: "
-          << status
+          << TrimCoordinationErrorMessage(status)
           << "\nProceeding with agent shutdown anyway. This is usually caused "
              "by an earlier error during execution. Check the logs (this task "
              "or the leader) for an earlier error to debug further.";
@@ -893,11 +887,12 @@
   assert(!error.ok());
   absl::MutexLock l(&state_mu_);
   if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) return;
+  absl::Status trimmed_error = TrimCoordinationErrorMessage(error);
 
-  LOG(ERROR) << "Coordination agent is set to ERROR: " << error;
+  LOG(ERROR) << "Coordination agent is set to ERROR: " << trimmed_error;
   state_ = CoordinatedTaskState::TASKSTATE_ERROR;
-  status_ = error;
-  error_fn_(error);
+  status_ = trimmed_error;
+  error_fn_(trimmed_error);
 }
 
 absl::Status CoordinationServiceAgentImpl::ActivateWatch(
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc
index ee2eb23..1281ea8 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc
@@ -454,6 +454,58 @@
   EXPECT_TRUE(agent_->IsError());
 }
 
+TEST_F(CoordinationServiceAgentTest, CancelledPollForErrorRequest) {
+  // Connect coordination agent.
+  PollForErrorResponse mocked_response;
+  EXPECT_CALL(*GetClient(), PollForErrorAsync(_, _, _, _))
+      .WillOnce(DoAll(SetArgPointee<2>(mocked_response),
+                      InvokeArgument<3>(absl::CancelledError("Test Error."))));
+
+  CoordinationServiceConfig config;
+  config.set_poll_for_error_from_service_at_startup(true);
+  InitializeAgent(config);
+  TF_ASSERT_OK(agent_->Connect());
+  // Wait a bit for the error polling thread to start.
+  absl::SleepFor(absl::Seconds(2));
+  // Cancelled error polling request will not set agent to error.
+  ASSERT_FALSE(agent_->IsError());
+}
+
+TEST_F(CoordinationServiceAgentTest, InvalidPollForErrorRequest) {
+  // Connect coordination agent.
+  PollForErrorResponse mocked_response;
+  EXPECT_CALL(*GetClient(), PollForErrorAsync(_, _, _, _))
+      .WillOnce(
+          DoAll(SetArgPointee<2>(mocked_response),
+                InvokeArgument<3>(absl::InvalidArgumentError("Test Error."))));
+
+  CoordinationServiceConfig config;
+  config.set_poll_for_error_from_service_at_startup(true);
+  InitializeAgent(config);
+  TF_ASSERT_OK(agent_->Connect());
+  // Wait a bit for the error polling thread to start.
+  absl::SleepFor(absl::Seconds(2));
+  ASSERT_TRUE(agent_->IsError());
+}
+
+TEST_F(CoordinationServiceAgentTest,
+       PollForErrorRequestWithFailedPrecondition) {
+  // Connect coordination agent.
+  PollForErrorResponse mocked_response;
+  EXPECT_CALL(*GetClient(), PollForErrorAsync(_, _, _, _))
+      .WillOnce(DoAll(
+          SetArgPointee<2>(mocked_response),
+          InvokeArgument<3>(absl::FailedPreconditionError("Test Error."))));
+
+  CoordinationServiceConfig config;
+  config.set_poll_for_error_from_service_at_startup(true);
+  InitializeAgent(config);
+  TF_ASSERT_OK(agent_->Connect());
+  // Wait a bit for the error polling thread to start.
+  absl::SleepFor(absl::Seconds(2));
+  ASSERT_TRUE(agent_->IsError());
+}
+
 TEST_F(CoordinationServiceAgentTest, ResetCanBeRetried) {
   // Mock reset error failing for the first time.
   EXPECT_CALL(*GetClient(), ResetTaskAsync(_, _, _))
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc
new file mode 100644
index 0000000..8fc7631
--- /dev/null
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc
@@ -0,0 +1,75 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h"
+
+#include <optional>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/str_cat.h"
+#include "tsl/platform/regexp.h"
+
+namespace tsl {
+absl::Status TrimCoordinationErrorMessage(const absl::Status& s) {
+  if (s.ok()) {
+    return s;
+  }
+  auto status_message = std::string(s.message());
+  auto additional_info_index = status_message.find("Additional GRPC");
+  // This error didn't come from gRPC, so we don't need to trim it.
+  if (additional_info_index == std::string::npos) {
+    return s;
+  }
+
+  std::optional<absl::Cord> payload =
+      s.GetPayload(CoordinationErrorPayloadKey());
+  if (!payload.has_value() && absl::IsUnavailable(s)) {
+    // This error is not provided by us, so it's probably an RPC layer error.
+    auto prefix_message =
+        "Failed to send RPC to coordination service. Either the leader task "
+        "died/restarted unexpectedly or this task is experiencing network "
+        "issues. Check earlier logs from this task and the "
+        "leader (usually slice 0 process/task/worker 0)  to debug further.\n";
+    status_message = absl::StrCat(
+        prefix_message,
+        // Replace the duplicated error message at the start with the prefix.
+        status_message.substr(additional_info_index));
+  } else {
+    // Extract RPC called.
+    std::string rpc_name;
+    // Note: it is unfortunate that we have to keep the tensorflow prefix
+    // because that's the RPC service proto namespace.
+    RE2::PartialMatch(status_message,
+                      "(/tensorflow.CoordinationService/(\\w+))", &rpc_name);
+    // Erase duplicated error message.
+    status_message = status_message.substr(0, additional_info_index);
+    absl::StrAppend(&status_message, "\nRPC: ", rpc_name);
+  }
+  auto trimmed_status = absl::Status(s.code(), status_message);
+  // Reattach payload.
+  if (payload.has_value()) {
+    trimmed_status.SetPayload(CoordinationErrorPayloadKey(), *payload);
+  }
+#if defined(PLATFORM_GOOGLE)
+  // Reattach source locations.
+  for (const auto& source_location : s.GetSourceLocations()) {
+    trimmed_status.AddSourceLocation(source_location);
+  }
+#endif
+  return trimmed_status;
+}
+}  // namespace tsl
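
The RE2::PartialMatch call above binds only the outer capture group, so the extracted value is the full "/tensorflow.CoordinationService/<Method>" path. A standalone sketch of that extraction, using re2/re2.h directly rather than the tsl/platform/regexp.h wrapper; the message text is made up:

#include <iostream>
#include <string>

#include "re2/re2.h"

int main() {
  std::string message =
      "RPC failed. Additional GRPC error information from remote target "
      "coordination_service while calling "
      "/tensorflow.CoordinationService/Heartbeat::UNKNOWN:Error received";
  std::string rpc_name;
  // One output argument: RE2 binds it to the first capture group and ignores
  // the rest, matching the call in TrimCoordinationErrorMessage().
  if (RE2::PartialMatch(message, "(/tensorflow.CoordinationService/(\\w+))",
                        &rpc_name)) {
    std::cout << "RPC: " << rpc_name << "\n";
    // Prints: RPC: /tensorflow.CoordinationService/Heartbeat
  }
  return 0;
}
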
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h
index 4555a4e..e1a3cdc 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h
@@ -55,6 +55,14 @@
                absl::Cord(payload.SerializeAsString()));
   return s;
 }
+
+// Trims the error message by replacing the `Additional GRPC error` part.
+// Note: The duplicated error message is a quirk of the underlying gRPC code
+// that we are using. Changing the shared code may hide important messages for
+// other libraries, so we trim the error message for coordination service
+// instead. See tsl/distributed_runtime/rpc/grpc_state.h for more details.
+absl::Status TrimCoordinationErrorMessage(const absl::Status& s);
+
 }  // namespace tsl
 
 #endif  // XLA_TSL_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_ERROR_UTIL_H_
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc
index 3c19fa5..535f471 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc
@@ -17,6 +17,7 @@
 #include <string>
 
 #include "absl/status/status.h"
+#include "absl/strings/match.h"
 #include "tsl/platform/test.h"
 #include "tsl/protobuf/coordination_service.pb.h"
 namespace tsl {
@@ -99,5 +100,54 @@
   EXPECT_EQ(actual_payload.is_reported_error(), payload.is_reported_error());
 }
 
+TEST(CoordinationServiceErrorUtil,
+     TrimCoordinationErrorMessage_CoordinationError) {
+  absl::Status error = MakeCoordinationError(absl::InternalError(
+      "Coordination service has stopped. RecordHeartbeat() from task: "
+      "/job:jax_worker/replica:0/task:2 failed. Additional GRPC error "
+      "information from remote target coordination_service while calling "
+      "/tensorflow.CoordinationService/Heartbeat::UNKNOWN:Error received from "
+      "peer  "
+      "{file:'third_party/grpc/src/core/lib/surface/filter_stack_call.cc', "
+      "file_line:464, created_time:'2024-08-05T13:57:51.331198242-07:00', "
+      "grpc_status:13, grpc_message:'Coordination service has stopped. "
+      "RecordHeartbeat() from task: /job:jax_worker/replica:0/task:2 failed. "
+      "'} "));
+
+  absl::Status trimmed_error = TrimCoordinationErrorMessage(error);
+  EXPECT_EQ(trimmed_error.code(), error.code());
+  EXPECT_EQ(trimmed_error.message(),
+            "Coordination service has stopped. RecordHeartbeat() from task: "
+            "/job:jax_worker/replica:0/task:2 failed. \nRPC: "
+            "/tensorflow.CoordinationService/Heartbeat");
+  // Payload exists but has no value.
+  EXPECT_EQ(trimmed_error.GetPayload(CoordinationErrorPayloadKey()).value(),
+            "");
+}
+
+TEST(CoordinationServiceErrorUtil, TrimCoordinationErrorMessage_NetworkError) {
+  absl::Status error = absl::UnavailableError(
+      "failed to connect to all addresses; last error: UNKNOWN: "
+      "ipv4:127.0.0.1:10001: Failed to connect to remote host: Connection "
+      "refused. Additional GRPC error information from remote target "
+      "coordination_service while calling "
+      "/tensorflow.CoordinationService/Heartbeat::UNKNOWN:Error received from "
+      "peer "
+      "{file:'third_party/grpc/src/core/lib/surface/filter_stack_call.cc', "
+      "file_line:464, created_time:'2024-08-05T13:57:53.123562608-07:00', "
+      "grpc_status:14, grpc_message:'failed to connect to all addresses; last "
+      "error: UNKNOWN: ipv4:127.0.0.1:10001: Failed to connect to remote host: "
+      "Connection refused'} ");
+
+  absl::Status trimmed_error = TrimCoordinationErrorMessage(error);
+  auto message = trimmed_error.message();
+  EXPECT_EQ(trimmed_error.code(), error.code());
+  EXPECT_TRUE(absl::StrContains(message, "Check earlier logs"));
+  // Message is not duplicated.
+  EXPECT_EQ(message.find("failed to connect"),
+            message.rfind("failed to connect"))
+      << trimmed_error;
+}
+
 }  // namespace
 }  // namespace tsl
diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc
index 0fb11db..2fa5001 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc
@@ -1199,6 +1199,9 @@
                         GetTaskName(GetTask(1))));  // First task at barrier.
   EXPECT_TRUE(absl::StrContains(barrier_status_0.message(),
                                 GetTaskName(GetTask(2))));  // Timed-out task.
+  EXPECT_TRUE(absl::StrContains(
+      barrier_status_0.message(),
+      "2/3"));  // Number of tasks at barrier / total number of tasks.
 }
 
 TEST_F(CoordinationBarrierTest, BarrierReturnsPreviousError) {
@@ -1820,10 +1823,77 @@
   coord_service_->PollForErrorAsync(
       task_0_, [&](const absl::Status& status) { s = status; });
 
-  EXPECT_THAT(s, StatusIs(absl::StatusCode::kInvalidArgument,
+  EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition,
                           HasSubstr("has not been registered")));
 }
 
+TEST_F(CoordinateTwoTasksTest,
+       AllowPollForErrorWithinGracePeriodIfTaskHasShutDown) {
+  EnableCoordinationService(/*has_service_to_client_connection=*/false);
+  absl::Status s;
+  ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
+  ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
+  coord_service_->ShutdownTaskAsync(task_0_,
+                                    [&](const absl::Status& status) {});
+  coord_service_->ShutdownTaskAsync(task_1_,
+                                    [&](const absl::Status& status) {});
+
+  coord_service_->PollForErrorAsync(
+      task_0_, [&](const absl::Status& status) { s = status; });
+  // Stop the service.
+  coord_service_.reset();
+  // The error polling request will still proceed because of the grace period.
+  // It will be cancelled.
+  EXPECT_THAT(s, StatusIs(absl::StatusCode::kCancelled));
+}
+
+TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorIfTaskHasShutDown) {
+  EnableCoordinationService(/*has_service_to_client_connection=*/false);
+  absl::Status s;
+  ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
+  ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
+  coord_service_->ShutdownTaskAsync(task_0_,
+                                    [&](const absl::Status& status) {});
+  coord_service_->ShutdownTaskAsync(task_1_,
+                                    [&](const absl::Status& status) {});
+
+  // Sleep past the grace period.
+  Env::Default()->SleepForMicroseconds(
+      absl::ToInt64Microseconds(2 * kHeartbeatTimeout));
+  coord_service_->PollForErrorAsync(
+      task_0_, [&](const absl::Status& status) { s = status; });
+  EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition,
+                          HasSubstr("has disconnected")));
+}
+
+TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorAfterReset) {
+  EnableCoordinationService(/*has_service_to_client_connection=*/false);
+  absl::Status s;
+  ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
+  ASSERT_OK(coord_service_->ResetTask(task_0_));
+
+  // Sleep past the grace period.
+  Env::Default()->SleepForMicroseconds(
+      absl::ToInt64Microseconds(2 * kHeartbeatTimeout));
+  coord_service_->PollForErrorAsync(
+      task_0_, [&](const absl::Status& status) { s = status; });
+  EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition,
+                          HasSubstr("has disconnected")));
+}
+
+TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorWhenInErrorState) {
+  EnableCoordinationService(/*has_service_to_client_connection=*/false);
+  absl::Status s;
+  ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
+  ASSERT_OK(coord_service_->ReportTaskError(task_0_,
+                                            absl::InternalError("test_error")));
+
+  coord_service_->PollForErrorAsync(
+      task_0_, [&](const absl::Status& status) { s = status; });
+  EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition,
+                          HasSubstr("test_error")));
+}
+
 TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorIfServiceHasStopped) {
   EnableCoordinationService(/*has_service_to_client_connection=*/false);
   ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_));
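Taken together, these tests pin down the polling contract: within the post-shutdown grace period a poll is still accepted (and later cancelled when the service goes away), while polls after the grace period, after a reset, or from a task already in error fail with kFailedPrecondition. A minimal caller-side sketch of interpreting the callback status under those assumptions (the helper name is illustrative):

#include "absl/status/status.h"

// Hypothetical helper showing how a client might interpret the status handed
// to its PollForErrorAsync callback, per the behaviors exercised above.
void HandlePolledError(const absl::Status& status) {
  if (absl::IsCancelled(status)) {
    // The coordination service shut down; nothing to report.
  } else if (absl::IsFailedPrecondition(status)) {
    // This task is unregistered, has disconnected, or is already in error.
  } else if (!status.ok()) {
    // An error reported somewhere in the cluster was propagated to this task.
  }
}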
diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc
index 4afe13f..6bd7885 100644
--- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc
+++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc
@@ -360,9 +360,7 @@
 
 CoordinationClient* NewGrpcCoordinationClient(
     std::shared_ptr<::grpc::Channel> channel) {
-  // TODO(hanyangtay): Pass in the logical task name for better logging.
-  return new GrpcCoordinationClient(
-      channel, /*target=*/"unknown_target_for_coordination_leader");
+  return new GrpcCoordinationClient(channel, /*target=*/"coordination_service");
 }
 
 }  // namespace tsl
diff --git a/third_party/xla/xla/tsl/framework/cancellation.cc b/third_party/xla/xla/tsl/framework/cancellation.cc
index d0a841f..7802eb9 100644
--- a/third_party/xla/xla/tsl/framework/cancellation.cc
+++ b/third_party/xla/xla/tsl/framework/cancellation.cc
@@ -103,7 +103,7 @@
 
 bool CancellationManager::RegisterCallbackWithErrorLogging(
     CancellationToken token, CancelCallback callback,
-    tsl::StringPiece callback_name) {
+    absl::string_view callback_name) {
   return RegisterCallbackConfig(
       token, CallbackConfiguration{callback, std::string(callback_name), true});
 }
diff --git a/third_party/xla/xla/tsl/framework/cancellation.h b/third_party/xla/xla/tsl/framework/cancellation.h
index 56076c8..38f7ebf 100644
--- a/third_party/xla/xla/tsl/framework/cancellation.h
+++ b/third_party/xla/xla/tsl/framework/cancellation.h
@@ -135,7 +135,7 @@
   // callback, which will be displayed on the log.
   bool RegisterCallbackWithErrorLogging(CancellationToken token,
                                         CancelCallback callback,
-                                        tsl::StringPiece callback_name);
+                                        absl::string_view callback_name);
 
   // Deregister the callback that, when registered, was associated
   // with the given cancellation token. Returns true iff the callback
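The signature change is mechanical (tsl::StringPiece is an alias for absl::string_view), but a short sketch of registering a named callback may help, assuming the usual CancellationManager token flow (the callback body and name are illustrative):

#include "xla/tsl/framework/cancellation.h"

// Hedged sketch: register a named cancellation callback so failures to invoke
// it can be attributed in the log.
void RegisterExampleCallback(tsl::CancellationManager* cm) {
  const tsl::CancellationToken token = cm->get_cancellation_token();
  const bool registered = cm->RegisterCallbackWithErrorLogging(
      token, [] { /* abort the in-flight work, release resources, etc. */ },
      /*callback_name=*/"example_rpc_cancellation");
  if (!registered) {
    // Cancellation has already started; run the cleanup inline instead.
  }
}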
diff --git a/third_party/xla/xla/tsl/framework/device_id_utils.cc b/third_party/xla/xla/tsl/framework/device_id_utils.cc
index a751a3a..812b119 100644
--- a/third_party/xla/xla/tsl/framework/device_id_utils.cc
+++ b/third_party/xla/xla/tsl/framework/device_id_utils.cc
@@ -29,12 +29,6 @@
 #include "tsl/platform/str_util.h"
 
 namespace tsl {
-namespace {
-int GetTfDeviceIdFromDeviceParsedName(
-    const DeviceNameUtils::ParsedName& device_name) {
-  return device_name.id;
-}
-}  // namespace
 
 void CheckValidTfDeviceId(const DeviceType& type,
                           const int visible_device_count,
@@ -62,7 +56,7 @@
     std::iota(visible_device_order->begin(), visible_device_order->end(), 0);
   } else {
     const std::vector<std::string> order_str =
-        tsl::str_util::Split(visible_device_list, ',');
+        tsl::str_util::Split(visible_device_list, ',');  // non-absl ok
     for (const std::string& platform_device_id_str : order_str) {
       int32_t platform_device_id;
       if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) {
@@ -126,7 +120,7 @@
 absl::StatusOr<int> GetPlatformDeviceIdFromDeviceParsedName(
     const DeviceNameUtils::ParsedName& device_name,
     const DeviceType& device_type) {
-  const TfDeviceId tf_device_id(GetTfDeviceIdFromDeviceParsedName(device_name));
+  const TfDeviceId tf_device_id(GetDeviceIdFromDeviceParsedName(device_name));
   PlatformDeviceId platform_device_id;
   absl::Status platform_id_status = DeviceIdManager::TfToPlatformDeviceId(
       device_type, tf_device_id, &platform_device_id);
@@ -136,15 +130,10 @@
   return platform_id_status;
 }
 
-absl::StatusOr<int> GetDeviceIdFromDeviceParsedName(
-    const DeviceNameUtils::ParsedName& device_name,
-    const DeviceType& device_type) {
-  auto platform_id =
-      GetPlatformDeviceIdFromDeviceParsedName(device_name, device_type);
-  if (platform_id.ok()) {
-    return *platform_id;
-  }
-  return GetTfDeviceIdFromDeviceParsedName(device_name);
+int GetDeviceIdFromDeviceParsedName(
+    const DeviceNameUtils::ParsedName& device_name) {
+  // This assumes that the TF device ID is the same as the PJRT local device
+  // ID.
+  return device_name.id;
 }
 
 }  // namespace tsl
diff --git a/third_party/xla/xla/tsl/framework/device_id_utils.h b/third_party/xla/xla/tsl/framework/device_id_utils.h
index d25ae1c..0da5969 100644
--- a/third_party/xla/xla/tsl/framework/device_id_utils.h
+++ b/third_party/xla/xla/tsl/framework/device_id_utils.h
@@ -60,12 +60,9 @@
     const DeviceNameUtils::ParsedName& device_name,
     const DeviceType& device_type);
 
-// TODO(b/293324740): support virtual devices.
-// Returns the corresponding PlatformDeviceId if it is found. Otherwise returns
-// the id in device_name.
-absl::StatusOr<int> GetDeviceIdFromDeviceParsedName(
-    const DeviceNameUtils::ParsedName& device_name,
-    const DeviceType& device_type);
+// Returns the id in device_name.
+int GetDeviceIdFromDeviceParsedName(
+    const DeviceNameUtils::ParsedName& device_name);
 
 }  // namespace tsl
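With the StatusOr variant folded away, resolving an id from a parsed device name is a plain field read; a minimal sketch, assuming DeviceNameUtils::ParseFullName for parsing (the helper name and error handling are illustrative):

#include "absl/strings/string_view.h"
#include "xla/tsl/framework/device_id_utils.h"

// Hedged sketch: parse a full device name and read back its local id.
int LocalDeviceIdFromName(absl::string_view full_name) {
  tsl::DeviceNameUtils::ParsedName parsed;
  if (!tsl::DeviceNameUtils::ParseFullName(full_name, &parsed)) {
    return -1;  // Illustrative error handling only.
  }
  return tsl::GetDeviceIdFromDeviceParsedName(parsed);
}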
 
diff --git a/third_party/xla/xla/tsl/framework/device_id_utils_test.cc b/third_party/xla/xla/tsl/framework/device_id_utils_test.cc
index da12c3b..e230d85 100644
--- a/third_party/xla/xla/tsl/framework/device_id_utils_test.cc
+++ b/third_party/xla/xla/tsl/framework/device_id_utils_test.cc
@@ -182,11 +182,7 @@
   DeviceNameUtils::ParsedName device_name;
   device_name.id = 0;
 
-  TF_ASSERT_OK_AND_ASSIGN(int device_id,
-                          GetDeviceIdFromDeviceParsedName(
-                              device_name, DeviceType(kTestDeviceType)));
-
-  EXPECT_EQ(device_id, 1);
+  EXPECT_EQ(GetDeviceIdFromDeviceParsedName(device_name), 0);
   DeviceIdManager::TestOnlyReset();
 }
 
@@ -194,11 +190,7 @@
   DeviceNameUtils::ParsedName device_name;
   device_name.id = 0;
 
-  TF_ASSERT_OK_AND_ASSIGN(int device_id,
-                          GetDeviceIdFromDeviceParsedName(
-                              device_name, DeviceType(kTestDeviceType)));
-
-  EXPECT_EQ(device_id, 0);
+  EXPECT_EQ(GetDeviceIdFromDeviceParsedName(device_name), 0);
 }
 
 }  // namespace
diff --git a/third_party/xla/xla/tsl/lib/strings/BUILD b/third_party/xla/xla/tsl/lib/strings/BUILD
new file mode 100644
index 0000000..03f82a3
--- /dev/null
+++ b/third_party/xla/xla/tsl/lib/strings/BUILD
@@ -0,0 +1,57 @@
+load(
+    "@local_tsl//tsl/platform:rules_cc.bzl",
+    "cc_library",
+)
+load("//xla/tsl:tsl.bzl", "internal_visibility")
+load("//xla/tsl:tsl.default.bzl", "filegroup")
+
+# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
+
+cc_library(
+    name = "proto_serialization",
+    srcs = ["proto_serialization.cc"],
+    hdrs = ["proto_serialization.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/lib/gtl:inlined_vector",
+        "@local_tsl//tsl/platform:hash",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:macros",
+        "@local_tsl//tsl/platform:protobuf",
+    ],
+)
+
+filegroup(
+    name = "mobile_srcs_only_runtime",
+    srcs = [
+        "proto_serialization.cc",
+        "proto_serialization.h",
+    ],
+    visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
+)
+
+filegroup(
+    name = "legacy_lib_strings_all_headers",
+    srcs = [
+        "proto_serialization.h",
+    ],
+    visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
+)
+
+filegroup(
+    name = "legacy_lib_string_headers",
+    srcs = [
+        "proto_serialization.h",
+    ],
+    visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
+)
+
+filegroup(
+    name = "legacy_lib_internal_public_string_headers",
+    srcs = [
+        "proto_serialization.h",
+    ],
+    visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]),
+)
diff --git a/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc
new file mode 100644
index 0000000..06ef074
--- /dev/null
+++ b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/tsl/lib/strings/proto_serialization.h"
+
+#include <cstring>
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "tsl/lib/gtl/inlined_vector.h"
+#include "tsl/platform/hash.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/macros.h"
+
+namespace tsl {
+namespace {
+
+// Helper for deterministic serialization.
+class DeterministicSerializer {
+ public:
+  explicit DeterministicSerializer(const protobuf::MessageLite& msg)
+      : DeterministicSerializer(msg, msg.ByteSizeLong()) {}
+
+  DeterministicSerializer(const protobuf::MessageLite& msg, size_t size)
+      : size_(size) {
+    char* ptr = space_;
+    if (size_ > sizeof(space_)) {
+      ptr = new char[size_];
+      alloc_.reset(ptr);
+    }
+    bool ok = SerializeToBufferDeterministic(msg, ptr, size_);
+    DCHECK(ok);
+  }
+
+  size_t size() const { return size_; }
+  const char* data() const { return alloc_ == nullptr ? space_ : alloc_.get(); }
+
+ private:
+  // Avoid InlinedVector since it causes 2x slowdown in the compilation
+  // of graphs containing large tensors in debug mode.
+  static constexpr int kInlinedBufferSize = 256;
+  const size_t size_;
+  std::unique_ptr<char[]> alloc_;
+  char space_[kInlinedBufferSize];
+};
+}  // namespace
+
+bool SerializeToStringDeterministic(const protobuf::MessageLite& msg,
+                                    string* result) {
+  const size_t size = msg.ByteSizeLong();
+  DCHECK_LE(size, static_cast<size_t>(INT_MAX));
+  *result = string(size, '\0');
+  return SerializeToBufferDeterministic(msg, const_cast<char*>(result->data()),
+                                        result->size());
+}
+
+bool SerializeToBufferDeterministic(const protobuf::MessageLite& msg,
+                                    char* buffer, size_t size) {
+  DCHECK(msg.ByteSizeLong() == size && size <= static_cast<size_t>(INT_MAX));
+  protobuf::io::ArrayOutputStream array_stream(buffer, size);
+  protobuf::io::CodedOutputStream output_stream(&array_stream);
+  output_stream.SetSerializationDeterministic(true);
+  msg.SerializeWithCachedSizes(&output_stream);
+  return !output_stream.HadError() &&
+         size == static_cast<size_t>(output_stream.ByteCount());
+}
+
+bool AreSerializedProtosEqual(const protobuf::MessageLite& x,
+                              const protobuf::MessageLite& y) {
+  const size_t size = x.ByteSizeLong();
+  if (size != y.ByteSizeLong()) return false;
+  if (size == 0) return true;
+  DeterministicSerializer x_serialized(x, size);
+  DeterministicSerializer y_serialized(y, size);
+  return memcmp(x_serialized.data(), y_serialized.data(), size) == 0;
+}
+
+uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto,
+                                uint64 seed) {
+  DeterministicSerializer serialized(proto);
+  return Hash64(serialized.data(), serialized.size(), seed);
+}
+
+uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto) {
+  DeterministicSerializer serialized(proto);
+  return Hash64(serialized.data(), serialized.size());
+}
+
+}  // namespace tsl
diff --git a/third_party/xla/xla/tsl/lib/strings/proto_serialization.h b/third_party/xla/xla/tsl/lib/strings/proto_serialization.h
new file mode 100644
index 0000000..b79e9af
--- /dev/null
+++ b/third_party/xla/xla/tsl/lib/strings/proto_serialization.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_
+#define XLA_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_
+
+#include "tsl/platform/protobuf.h"
+
+namespace tsl {
+
+// Wrapper around protocol buffer serialization that requests deterministic
+// serialization, in particular for Map fields, which serialize in a random
+// order by default. Returns true on success.
+// Serialization is guaranteed to be deterministic for a given binary only.
+// See the following for more details:
+// https://github.com/google/protobuf/blob/a1bb147e96b6f74db6cdf3c3fcb00492472dbbfa/src/google/protobuf/io/coded_stream.h#L834
+bool SerializeToStringDeterministic(const protobuf::MessageLite& msg,
+                                    string* result);
+
+// As above, but serializes into a pre-allocated buffer of the given size.
+// PRECONDITION: size == msg.ByteSizeLong() && size <= INT_MAX.
+bool SerializeToBufferDeterministic(const protobuf::MessageLite& msg,
+                                    char* buffer, size_t size);
+
+// Returns true if serializing x and y using
+// SerializeToBufferDeterministic() yields identical strings.
+bool AreSerializedProtosEqual(const protobuf::MessageLite& x,
+                              const protobuf::MessageLite& y);
+
+// Computes Hash64 of the output of SerializeToBufferDeterministic().
+uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto);
+uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto,
+                                uint64 seed);
+
+}  // namespace tsl
+
+#endif  // XLA_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_
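A minimal usage sketch for the relocated helpers, applicable to any MessageLite (the function names below are illustrative):

#include <cstdint>

#include "xla/tsl/lib/strings/proto_serialization.h"

// Hedged sketch: compare and fingerprint protos via their deterministic
// serialization.
bool SameConfig(const tsl::protobuf::MessageLite& a,
                const tsl::protobuf::MessageLite& b) {
  return tsl::AreSerializedProtosEqual(a, b);
}

uint64_t ConfigFingerprint(const tsl::protobuf::MessageLite& config) {
  // Stable only within a given binary, per the header comment above.
  return tsl::DeterministicProtoHash64(config);
}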
diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto
index b43944c..620f1b6 100644
--- a/third_party/xla/xla/xla.proto
+++ b/third_party/xla/xla/xla.proto
@@ -875,7 +875,15 @@
   // TODO(b/355487968): Remove this option when validation complete.
   bool xla_enable_command_buffers_during_profiling = 317;
 
-  // Next id: 318
+  // Limit on the number of kernel configurations (plans) to try during
+  // autotuning of cuDNN GEMM fusions. More plans make autotuning slower but
+  // can yield better performance.
+  int32 xla_gpu_cudnn_gemm_max_plans = 318;
+
+  // If enabled, uses the libnvjitlink library for PTX compilation and linking.
+  bool xla_gpu_enable_libnvjitlink = 319;
+
+  // Next id: 320
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
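For reference, a hedged sketch of setting the two new fields through the generated DebugOptions API (the chosen values and the helper name are illustrative; how the options reach a particular compilation is not shown in this hunk):

#include "xla/xla.pb.h"

// Hedged sketch: populate the new autotuning and PTX-linking options.
void ConfigureGpuCompileOptions(xla::DebugOptions* debug_options) {
  // Try at most 4 cuDNN GEMM fusion plans during autotuning.
  debug_options->set_xla_gpu_cudnn_gemm_max_plans(4);
  // Use libnvjitlink for PTX compilation and linking.
  debug_options->set_xla_gpu_enable_libnvjitlink(true);
}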
diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto
index d394361..335b59e 100644
--- a/third_party/xla/xla/xla_data.proto
+++ b/third_party/xla/xla/xla_data.proto
@@ -661,6 +661,13 @@
   // The dimension in the start_indices input that contains the starting
   // indices.
   int64 index_vector_dim = 4;
+
+  // These are the batch dimensions in the operand.
+  repeated int64 operand_batching_dims = 5;
+
+  // These are the batch dimensions in the start indices; they should have the
+  // same size as operand_batching_dims.
+  repeated int64 start_indices_batching_dims = 6;
 }
 
 // Describes the dimension numbers for a scatter operation.
@@ -675,6 +682,12 @@
 
   repeated int64 scatter_dims_to_operand_dims = 3;
   int64 index_vector_dim = 4;
+
+  // These are the batch dimensions in the input.
+  repeated int64 input_batching_dims = 5;
+
+  // These are the batch dimensions in the scatter indices.
+  repeated int64 scatter_indices_batching_dims = 6;
 }
 
 message ConvolutionDimensionNumbers {